abhishekchohan commited on
Commit
d15b68f
·
verified ·
1 Parent(s): 56eadf5

Update tools/gemma_tool_parser.py

Browse files
Files changed (1) hide show
  1. tools/gemma_tool_parser.py +285 -291
tools/gemma_tool_parser.py CHANGED
@@ -1,291 +1,285 @@
1
- # SPDX-License-Identifier: Apache-2.0
2
-
3
- import json
4
- import re
5
- from collections.abc import Sequence
6
- from json import JSONDecoder
7
- from typing import Union
8
-
9
- import partial_json_parser
10
- from partial_json_parser.core.options import Allow
11
- from transformers import PreTrainedTokenizerBase
12
-
13
- from vllm.entrypoints.openai.protocol import (
14
- ChatCompletionRequest,
15
- DeltaFunctionCall,
16
- DeltaMessage,
17
- DeltaToolCall,
18
- ExtractedToolCallInformation,
19
- FunctionCall,
20
- ToolCall,
21
- )
22
- from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
23
- ToolParser,
24
- ToolParserManager,
25
- )
26
- from vllm.entrypoints.openai.tool_parsers.utils import (
27
- find_common_prefix,
28
- is_complete_json,
29
- partial_json_loads,
30
- )
31
- from vllm.logger import init_logger
32
- from vllm.utils import random_uuid
33
-
34
- logger = init_logger(__name__)
35
-
36
-
37
- @ToolParserManager.register_module("gemma_json")
38
- class GemmaJsonToolParser(ToolParser):
39
- """
40
- Tool call parser for Gemma 3 models intended for use with the
41
- appropriate Gemma chat template.
42
-
43
- Used when --enable-auto-tool-choice --tool-call-parser gemma_json
44
- are all set
45
- """
46
-
47
- def __init__(self, tokenizer: PreTrainedTokenizerBase):
48
- super().__init__(tokenizer)
49
-
50
- # initialize properties used for state when parsing tool calls in
51
- # streaming mode
52
- self.prev_tool_call_arr: list[dict] = []
53
- self.current_tool_id: int = -1
54
- self.current_tool_name_sent: bool = False
55
- self.streamed_args_for_tool: list[str] = []
56
-
57
- # Gemma specific tokens
58
- self.bos_token = "<bos>"
59
- self.model_token = "<start_of_turn>model"
60
- self.user_token = "<start_of_turn>user"
61
- self.end_turn_token = "<end_of_turn>"
62
-
63
- # For JSON detection
64
- self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
65
-
66
- def extract_tool_calls(
67
- self, model_output: str, request: ChatCompletionRequest
68
- ) -> ExtractedToolCallInformation:
69
- """
70
- Extract the tool calls from a complete model response.
71
- """
72
- # case -- if the response doesn't contain JSON, return a text response
73
- if not model_output.startswith("{"):
74
- return ExtractedToolCallInformation(
75
- tools_called=False, tool_calls=[], content=model_output
76
- )
77
-
78
- try:
79
- # load the JSON, and then use it to build the Function and
80
- # Tool Call
81
- dec = JSONDecoder()
82
- function_call_arr = []
83
-
84
- start_idx = 0
85
- while start_idx < len(model_output):
86
- try:
87
- (obj, end_idx) = dec.raw_decode(model_output[start_idx:])
88
- start_idx += end_idx
89
- # Skip any separators like semicolons or commas
90
- while start_idx < len(model_output) and model_output[start_idx] in [
91
- ";",
92
- ",",
93
- " ",
94
- ]:
95
- start_idx += 1
96
- function_call_arr.append(obj)
97
- except json.JSONDecodeError:
98
- break
99
-
100
- tool_calls: list[ToolCall] = [
101
- ToolCall(
102
- type="function",
103
- function=FunctionCall(
104
- name=raw_function_call["name"],
105
- # function call args are JSON but as a string
106
- arguments=json.dumps(
107
- raw_function_call["arguments"]
108
- if "arguments" in raw_function_call
109
- else raw_function_call["parameters"]
110
- ),
111
- ),
112
- )
113
- for raw_function_call in function_call_arr
114
- ]
115
-
116
- return ExtractedToolCallInformation(
117
- tools_called=True, tool_calls=tool_calls, content=None
118
- )
119
-
120
- except Exception:
121
- logger.exception("Error in extracting tool call from response.")
122
- # return information to just treat the tool call as regular JSON
123
- return ExtractedToolCallInformation(
124
- tools_called=False, tool_calls=[], content=model_output
125
- )
126
-
127
- def extract_tool_calls_streaming(
128
- self,
129
- previous_text: str,
130
- current_text: str,
131
- delta_text: str,
132
- previous_token_ids: Sequence[int],
133
- current_token_ids: Sequence[int],
134
- delta_token_ids: Sequence[int],
135
- request: ChatCompletionRequest,
136
- ) -> Union[DeltaMessage, None]:
137
-
138
- # Skip if not JSON format
139
- if not current_text.startswith("{"):
140
- return DeltaMessage(content=delta_text)
141
-
142
- # bit mask flags for partial JSON parsing
143
- flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
144
- try:
145
- tool_call_arr = []
146
- is_complete = []
147
- try:
148
- start_idx = 0
149
- while start_idx < len(current_text):
150
- (obj, end_idx) = partial_json_loads(current_text[start_idx:], flags)
151
- is_complete.append(
152
- is_complete_json(current_text[start_idx : start_idx + end_idx])
153
- )
154
- start_idx += end_idx
155
- # Skip any separators like semicolons or commas
156
- while start_idx < len(current_text) and current_text[start_idx] in [
157
- ";",
158
- ",",
159
- " ",
160
- ]:
161
- start_idx += 1
162
-
163
- # Handle parameters field as arguments if needed
164
- if "parameters" in obj:
165
- assert (
166
- "arguments" not in obj
167
- ), "model generated both parameters and arguments"
168
- obj["arguments"] = obj["parameters"]
169
- tool_call_arr.append(obj)
170
- except partial_json_parser.core.exceptions.MalformedJSON:
171
- logger.debug("not enough tokens to parse into JSON yet")
172
- return None
173
-
174
- # select as the current tool call the one we're on the state at
175
- current_tool_call: dict = (
176
- tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
177
- )
178
-
179
- # case -- if no tokens have been streamed for the tool, e.g.
180
- # only the array brackets, stream nothing
181
- if len(tool_call_arr) == 0:
182
- return None
183
-
184
- # case: we are starting a new tool in the array
185
- # -> array has > 0 length AND length has moved past cursor
186
- elif (
187
- len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
188
- ):
189
- # if we're moving on to a new call, first make sure we
190
- # haven't missed anything in the previous one that was
191
- # auto-generated due to JSON completions, but wasn't
192
- # streamed to the client yet.
193
- if self.current_tool_id >= 0:
194
- cur_arguments = current_tool_call.get("arguments")
195
- if cur_arguments:
196
- cur_args_json = json.dumps(cur_arguments)
197
- sent = len(self.streamed_args_for_tool[self.current_tool_id])
198
- argument_diff = cur_args_json[sent:]
199
-
200
- logger.debug("got arguments diff: %s", argument_diff)
201
- delta = DeltaMessage(
202
- tool_calls=[
203
- DeltaToolCall(
204
- index=self.current_tool_id,
205
- function=DeltaFunctionCall(
206
- arguments=argument_diff
207
- ).model_dump(exclude_none=True),
208
- )
209
- ]
210
- )
211
- self.streamed_args_for_tool[
212
- self.current_tool_id
213
- ] += argument_diff
214
- else:
215
- delta = None
216
- else:
217
- delta = None
218
- # re-set stuff pertaining to progress in the current tool
219
- self.current_tool_id = len(tool_call_arr) - 1
220
- self.current_tool_name_sent = False
221
- self.streamed_args_for_tool.append("")
222
- logger.debug("starting on new tool %d", self.current_tool_id)
223
- return delta
224
-
225
- # if the current tool name hasn't been sent, send if available
226
- # - otherwise send nothing
227
- elif not self.current_tool_name_sent:
228
- function_name = current_tool_call.get("name")
229
- if function_name:
230
- delta = DeltaMessage(
231
- tool_calls=[
232
- DeltaToolCall(
233
- index=self.current_tool_id,
234
- type="function",
235
- id=f"chatcmpl-tool-{random_uuid()}",
236
- function=DeltaFunctionCall(
237
- name=function_name
238
- ).model_dump(exclude_none=True),
239
- )
240
- ]
241
- )
242
- self.current_tool_name_sent = True
243
- else:
244
- delta = None
245
-
246
- # now we know we're on the same tool call and we're streaming
247
- # arguments
248
- else:
249
- cur_arguments = current_tool_call.get("arguments")
250
- delta = None
251
-
252
- if cur_arguments:
253
- sent = len(self.streamed_args_for_tool[self.current_tool_id])
254
- cur_args_json = json.dumps(cur_arguments)
255
- prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
256
- "arguments"
257
- )
258
-
259
- argument_diff = None
260
- if is_complete[self.current_tool_id]:
261
- argument_diff = cur_args_json[sent:]
262
- elif prev_arguments:
263
- prev_args_json = json.dumps(prev_arguments)
264
- if cur_args_json != prev_args_json:
265
- prefix = find_common_prefix(prev_args_json, cur_args_json)
266
- argument_diff = prefix[sent:]
267
-
268
- if argument_diff is not None:
269
- delta = DeltaMessage(
270
- tool_calls=[
271
- DeltaToolCall(
272
- index=self.current_tool_id,
273
- function=DeltaFunctionCall(
274
- arguments=argument_diff
275
- ).model_dump(exclude_none=True),
276
- )
277
- ]
278
- )
279
- self.streamed_args_for_tool[
280
- self.current_tool_id
281
- ] += argument_diff
282
-
283
- self.prev_tool_call_arr = tool_call_arr
284
- return delta
285
-
286
- except Exception:
287
- logger.exception("Error trying to handle streaming tool call.")
288
- logger.debug(
289
- "Skipping chunk as a result of tool streaming extraction error"
290
- )
291
- return None
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import json
4
+ import re
5
+ from collections.abc import Sequence
6
+ from json import JSONDecoder
7
+ from typing import Union
8
+
9
+ import partial_json_parser
10
+ from partial_json_parser.core.options import Allow
11
+ from transformers import PreTrainedTokenizerBase
12
+
13
+ from vllm.entrypoints.openai.protocol import (
14
+ ChatCompletionRequest,
15
+ DeltaFunctionCall,
16
+ DeltaMessage,
17
+ DeltaToolCall,
18
+ ExtractedToolCallInformation,
19
+ FunctionCall,
20
+ ToolCall,
21
+ )
22
+ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
23
+ ToolParser,
24
+ ToolParserManager,
25
+ )
26
+ from vllm.entrypoints.openai.tool_parsers.utils import (
27
+ find_common_prefix,
28
+ is_complete_json,
29
+ partial_json_loads,
30
+ )
31
+ from vllm.logger import init_logger
32
+ from vllm.utils import random_uuid
33
+
34
+ logger = init_logger(__name__)
35
+
36
+
37
+ @ToolParserManager.register_module("gemma")
38
+ class GemmaJsonToolParser(ToolParser):
39
+ """
40
+ Tool call parser for Gemma 3 models intended for use with the
41
+ appropriate Gemma chat template.
42
+
43
+ Used when --enable-auto-tool-choice --tool-call-parser gemma_json
44
+ are all set
45
+ """
46
+
47
+ def __init__(self, tokenizer: PreTrainedTokenizerBase):
48
+ super().__init__(tokenizer)
49
+
50
+ # initialize properties used for state when parsing tool calls in
51
+ # streaming mode
52
+ self.prev_tool_call_arr: list[dict] = []
53
+ self.current_tool_id: int = -1
54
+ self.current_tool_name_sent: bool = False
55
+ self.streamed_args_for_tool: list[str] = []
56
+
57
+ # Gemma specific tokens
58
+ self.bos_token = "<bos>"
59
+ self.model_token = "<start_of_turn>model"
60
+ self.user_token = "<start_of_turn>user"
61
+ self.end_turn_token = "<end_of_turn>"
62
+
63
+ # For JSON detection
64
+ self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
65
+
66
+ def extract_tool_calls(
67
+ self, model_output: str, request: ChatCompletionRequest
68
+ ) -> ExtractedToolCallInformation:
69
+ """
70
+ Extract the tool calls from a complete model response.
71
+ """
72
+ # case -- if the response doesn't contain JSON, return a text response
73
+ if not model_output.startswith("{"):
74
+ return ExtractedToolCallInformation(
75
+ tools_called=False, tool_calls=[], content=model_output
76
+ )
77
+
78
+ try:
79
+ # load the JSON, and then use it to build the Function and
80
+ # Tool Call
81
+ dec = JSONDecoder()
82
+ function_call_arr = []
83
+
84
+ start_idx = 0
85
+ while start_idx < len(model_output):
86
+ try:
87
+ (obj, end_idx) = dec.raw_decode(model_output[start_idx:])
88
+ start_idx += end_idx
89
+ # Skip any separators like semicolons or commas
90
+ while start_idx < len(model_output) and model_output[start_idx] in [
91
+ ";",
92
+ ",",
93
+ " ",
94
+ ]:
95
+ start_idx += 1
96
+ function_call_arr.append(obj)
97
+ except json.JSONDecodeError:
98
+ break
99
+
100
+ tool_calls: list[ToolCall] = [
101
+ ToolCall(
102
+ type="function",
103
+ function=FunctionCall(
104
+ name=raw_function_call["name"],
105
+ # function call args are JSON but as a string
106
+ arguments=json.dumps(
107
+ raw_function_call["arguments"]
108
+ if "arguments" in raw_function_call
109
+ else raw_function_call["parameters"]
110
+ ),
111
+ ),
112
+ )
113
+ for raw_function_call in function_call_arr
114
+ ]
115
+
116
+ return ExtractedToolCallInformation(
117
+ tools_called=True, tool_calls=tool_calls, content=None
118
+ )
119
+
120
+ except Exception:
121
+ logger.exception("Error in extracting tool call from response.")
122
+ # return information to just treat the tool call as regular JSON
123
+ return ExtractedToolCallInformation(
124
+ tools_called=False, tool_calls=[], content=model_output
125
+ )
126
+
127
+ def extract_tool_calls_streaming(
128
+ self,
129
+ previous_text: str,
130
+ current_text: str,
131
+ delta_text: str,
132
+ previous_token_ids: Sequence[int],
133
+ current_token_ids: Sequence[int],
134
+ delta_token_ids: Sequence[int],
135
+ request: ChatCompletionRequest,
136
+ ) -> Union[DeltaMessage, None]:
137
+
138
+ # Skip if not JSON format
139
+ if not current_text.startswith("{"):
140
+ return DeltaMessage(content=delta_text)
141
+
142
+ # bit mask flags for partial JSON parsing
143
+ flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
144
+ try:
145
+ tool_call_arr = []
146
+ is_complete = []
147
+ try:
148
+ start_idx = 0
149
+ while start_idx < len(current_text):
150
+ (obj, end_idx) = partial_json_loads(current_text[start_idx:], flags)
151
+ is_complete.append(
152
+ is_complete_json(current_text[start_idx : start_idx + end_idx])
153
+ )
154
+ start_idx += end_idx
155
+ # Skip any separators like semicolons or commas
156
+ while start_idx < len(current_text) and current_text[start_idx] in [
157
+ ";",
158
+ ",",
159
+ " ",
160
+ ]:
161
+ start_idx += 1
162
+
163
+ # Handle parameters field as arguments if needed
164
+ if "parameters" in obj:
165
+ assert (
166
+ "arguments" not in obj
167
+ ), "model generated both parameters and arguments"
168
+ obj["arguments"] = obj["parameters"]
169
+ tool_call_arr.append(obj)
170
+ except partial_json_parser.core.exceptions.MalformedJSON:
171
+ logger.debug("not enough tokens to parse into JSON yet")
172
+ return None
173
+
174
+ # select as the current tool call the one we're on the state at
175
+ current_tool_call: dict = (
176
+ tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
177
+ )
178
+
179
+ # case -- if no tokens have been streamed for the tool, e.g.
180
+ # only the array brackets, stream nothing
181
+ if len(tool_call_arr) == 0:
182
+ return None
183
+
184
+ # case: we are starting a new tool in the array
185
+ # -> array has > 0 length AND length has moved past cursor
186
+ elif (
187
+ len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
188
+ ):
189
+ if self.current_tool_id >= 0:
190
+ cur_arguments = current_tool_call.get("arguments")
191
+ if cur_arguments:
192
+ cur_args_json = json.dumps(cur_arguments)
193
+ sent = len(self.streamed_args_for_tool[self.current_tool_id])
194
+ argument_diff = cur_args_json[sent:]
195
+
196
+ logger.debug("got arguments diff: %s", argument_diff)
197
+ delta = DeltaMessage(
198
+ tool_calls=[
199
+ DeltaToolCall(
200
+ index=self.current_tool_id,
201
+ function=DeltaFunctionCall(
202
+ arguments=argument_diff
203
+ ).model_dump(exclude_none=True),
204
+ )
205
+ ]
206
+ )
207
+ self.streamed_args_for_tool[
208
+ self.current_tool_id
209
+ ] += argument_diff
210
+ else:
211
+ delta = None
212
+ else:
213
+ delta = None
214
+ # re-set stuff pertaining to progress in the current tool
215
+ self.current_tool_id = len(tool_call_arr) - 1
216
+ self.current_tool_name_sent = False
217
+ self.streamed_args_for_tool.append("")
218
+ logger.debug("starting on new tool %d", self.current_tool_id)
219
+ return delta
220
+
221
+ # if the current tool name hasn't been sent, send if available
222
+ # - otherwise send nothing
223
+ elif not self.current_tool_name_sent:
224
+ function_name = current_tool_call.get("name")
225
+ if function_name:
226
+ delta = DeltaMessage(
227
+ tool_calls=[
228
+ DeltaToolCall(
229
+ index=self.current_tool_id,
230
+ type="function",
231
+ id=f"chatcmpl-tool-{random_uuid()}",
232
+ function=DeltaFunctionCall(
233
+ name=function_name
234
+ ).model_dump(exclude_none=True),
235
+ )
236
+ ]
237
+ )
238
+ self.current_tool_name_sent = True
239
+ else:
240
+ delta = None
241
+
242
+ else:
243
+ cur_arguments = current_tool_call.get("arguments")
244
+ delta = None
245
+
246
+ if cur_arguments:
247
+ sent = len(self.streamed_args_for_tool[self.current_tool_id])
248
+ cur_args_json = json.dumps(cur_arguments)
249
+ prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
250
+ "arguments"
251
+ )
252
+
253
+ argument_diff = None
254
+ if is_complete[self.current_tool_id]:
255
+ argument_diff = cur_args_json[sent:]
256
+ elif prev_arguments:
257
+ prev_args_json = json.dumps(prev_arguments)
258
+ if cur_args_json != prev_args_json:
259
+ prefix = find_common_prefix(prev_args_json, cur_args_json)
260
+ argument_diff = prefix[sent:]
261
+
262
+ if argument_diff is not None:
263
+ delta = DeltaMessage(
264
+ tool_calls=[
265
+ DeltaToolCall(
266
+ index=self.current_tool_id,
267
+ function=DeltaFunctionCall(
268
+ arguments=argument_diff
269
+ ).model_dump(exclude_none=True),
270
+ )
271
+ ]
272
+ )
273
+ self.streamed_args_for_tool[
274
+ self.current_tool_id
275
+ ] += argument_diff
276
+
277
+ self.prev_tool_call_arr = tool_call_arr
278
+ return delta
279
+
280
+ except Exception:
281
+ logger.exception("Error trying to handle streaming tool call.")
282
+ logger.debug(
283
+ "Skipping chunk as a result of tool streaming extraction error"
284
+ )
285
+ return None