import ast import json import re from collections.abc import Sequence from typing import Union import partial_json_parser from partial_json_parser.core.options import Allow from vllm.entrypoints.openai.protocol import ( ChatCompletionRequest, DeltaFunctionCall, DeltaMessage, DeltaToolCall, ExtractedToolCallInformation, FunctionCall, ToolCall, ) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, ToolParserManager, ) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import random_uuid logger = init_logger(__name__) @ToolParserManager.register_module("llama_nemotron_xml") class LlamaNemotronXMLToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) self.current_tool_name_sent: bool = False self.prev_tool_call_arr: list[dict] = [] self.current_tool_id: int = -1 # Potentially for streaming self.streamed_args_for_tool: list[str] = [] # Potentially for streaming self.tool_call_start_token: str = "" self.tool_call_end_token: str = "" # Regex to find full ... blocks and capture their content self.tool_call_block_regex = re.compile(r"(.*?)", re.DOTALL) # Regex to find ... within a tool_call block content self.name_regex = re.compile(r"(.*?)", re.DOTALL) # Regex to find value pairs within the tool_call block content (excluding tags) self.param_regex = re.compile(r"<([^/>\s]+)>(.*?)", re.DOTALL) def extract_tool_calls( self, model_output: str, request: ChatCompletionRequest, ) -> ExtractedToolCallInformation: tool_call_start_index = model_output.find(self.tool_call_start_token) if tool_call_start_index == -1: return ExtractedToolCallInformation( tools_called=False, tool_calls=[], content=model_output, ) content = model_output[:tool_call_start_index].strip() tool_calls_str_content = model_output[tool_call_start_index:] parsed_tool_calls = [] try: # Find all occurrences of ... xml_tool_call_contents = self.tool_call_block_regex.findall(tool_calls_str_content) for tool_content_str in xml_tool_call_contents: name_match = self.name_regex.search(tool_content_str) if not name_match: logger.warning(f"Could not find tool name in XML block: {tool_content_str}") continue tool_name = name_match.group(1).strip() parsed_arguments = {} # Find all parameter tags in the tool_call content, excluding the tag param_matches = self.param_regex.finditer(tool_content_str) for match in param_matches: param_name = match.group(1).strip() param_value_str = match.group(2).strip() # Skip the tag since it's not a parameter if param_name == "tool": continue target_type = None # Try to get type from request.tools schema if request.tools: for tool_def in request.tools: if tool_def.function.name == tool_name: if tool_def.function.parameters and \ isinstance(tool_def.function.parameters, dict) and \ "properties" in tool_def.function.parameters and \ isinstance(tool_def.function.parameters["properties"], dict) and \ param_name in tool_def.function.parameters["properties"] and \ isinstance(tool_def.function.parameters["properties"][param_name], dict): target_type = tool_def.function.parameters["properties"][param_name].get("type") break typed_param_value = param_value_str # Default to string if target_type: try: if target_type == "string": typed_param_value = param_value_str elif target_type == "integer": typed_param_value = int(param_value_str) elif target_type == "number": typed_param_value = float(param_value_str) elif target_type == "boolean": typed_param_value = param_value_str.lower() == 'true' elif target_type in ["object", "array"]: try: typed_param_value = json.loads(param_value_str) except json.JSONDecodeError: # Fallback for non-strict JSON like Python dict/list string typed_param_value = ast.literal_eval(param_value_str) else: # Unknown type, keep as string typed_param_value = param_value_str except (ValueError, SyntaxError, json.JSONDecodeError) as e: logger.warning( f"Could not convert param '{param_name}' with value '{param_value_str}' " f"to type '{target_type}'. Error: {e}. Using string value." ) typed_param_value = param_value_str else: # No schema type, try ast.literal_eval try: # For values like "true", "123", "['a', 'b']" # ast.literal_eval('some_string_without_quotes') will raise SyntaxError if (param_value_str.startswith("'") and param_value_str.endswith("'")) or \ (param_value_str.startswith('"') and param_value_str.endswith('"')) or \ (param_value_str.startswith('[') and param_value_str.endswith(']')) or \ (param_value_str.startswith('{') and param_value_str.endswith('}')) or \ param_value_str.lower() in ['true', 'false', 'none'] or \ param_value_str.replace('.', '', 1).isdigit() or \ (param_value_str.startswith('-') and param_value_str[1:].replace('.', '', 1).isdigit()): typed_param_value = ast.literal_eval(param_value_str) else: # It's likely a plain string not meant for ast.literal_eval typed_param_value = param_value_str except (ValueError, SyntaxError): typed_param_value = param_value_str # Keep as string if ast.literal_eval fails parsed_arguments[param_name] = typed_param_value parsed_tool_calls.append(ToolCall( id=f"call_{random_uuid()}", type="function", function=FunctionCall( name=tool_name, arguments=json.dumps(parsed_arguments, ensure_ascii=False), ), )) return ExtractedToolCallInformation( tools_called=len(parsed_tool_calls) > 0, tool_calls=parsed_tool_calls, content=content if content else None, ) except Exception: logger.exception(f"Error in extracting XML tool call from response. Response: {model_output}") # Fallback to original model output if parsing fails catastrophically return ExtractedToolCallInformation( tools_called=False, tool_calls=[], content=model_output, ) def extract_tool_calls_streaming( self, previous_text: str, current_text: str, delta_text: str, previous_token_ids: Sequence[int], current_token_ids: Sequence[int], delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: raise NotImplementedError("Tool calling is not supported in streaming mode!") @ToolParserManager.register_module("llama_nemotron_json") class LlamaNemotronJSONToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) self.current_tool_name_sent: bool = False self.prev_tool_call_arr: list[dict] = [] self.current_tool_id: int = -1 self.streamed_args_for_tool: list[str] = [] self.tool_call_start_token: str = "" self.tool_call_end_token: str = "" self.tool_call_regex = re.compile(r"(.*?)", re.DOTALL) def extract_tool_calls( self, model_output: str, request: ChatCompletionRequest, ) -> ExtractedToolCallInformation: if self.tool_call_start_token not in model_output: return ExtractedToolCallInformation( tools_called=False, tool_calls=[], content=model_output, ) else: try: str_tool_calls = self.tool_call_regex.findall(model_output)[0].strip() if not str_tool_calls.startswith("["): str_tool_calls = "[" + str_tool_calls if not str_tool_calls.endswith("]"): str_tool_calls = "]" + str_tool_calls json_tool_calls = json.loads(str_tool_calls) tool_calls = [] for tool_call in json_tool_calls: try: tool_calls.append(ToolCall( type="function", function=FunctionCall( name=tool_call["name"], arguments=json.dumps(tool_call["arguments"], ensure_ascii=False) \ if isinstance(tool_call["arguments"], dict) else tool_call["arguments"], ), )) except: continue content = model_output[:model_output.rfind(self.tool_call_start_token)] return ExtractedToolCallInformation( tools_called=True, tool_calls=tool_calls, content=content if content else None, ) except Exception: logger.exception(f"Error in extracting tool call from response. Response: {model_output}") return ExtractedToolCallInformation( tools_called=False, tool_calls=[], content=model_output, ) def extract_tool_calls_streaming( self, previous_text: str, current_text: str, delta_text: str, previous_token_ids: Sequence[int], current_token_ids: Sequence[int], delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: raise NotImplementedError("Tool calling is not supported in streaming mode!") @ToolParserManager.register_module("llama_nemotron_pythonic") class LlamaNemotronPythonicToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) self.current_tool_name_sent: bool = False self.prev_tool_call_arr: list[dict] = [] self.current_tool_id: int = -1 self.streamed_args_for_tool: list[str] = [] self.tool_call_start_token: str = "" self.tool_call_end_token: str = "" self.tool_call_regex = re.compile(r"(.*?)", re.DOTALL) # Regex to parse pythonic function calls: function_name(arg1="value1", arg2=123, arg3=True) self.function_call_regex = re.compile(r"(\w+)\((.*?)\)$", re.DOTALL) def parse_function_arguments(self, args_str: str) -> dict: """Parse pythonic function arguments string into a dictionary""" if not args_str.strip(): return {} # Use ast.parse to safely parse the function call arguments # We'll construct a temporary function call and parse it try: # Create a dummy function call to parse arguments dummy_code = f"dummy_func({args_str})" parsed = ast.parse(dummy_code, mode='eval') # Extract arguments from the AST call_node = parsed.body if not isinstance(call_node, ast.Call): return {} arguments = {} # Handle keyword arguments for keyword in call_node.keywords: if keyword.arg is None: # **kwargs continue # Convert AST value to Python value try: value = ast.literal_eval(keyword.value) arguments[keyword.arg] = value except (ValueError, TypeError): # If literal_eval fails, try to get the raw value if isinstance(keyword.value, ast.Name): arguments[keyword.arg] = keyword.value.id elif isinstance(keyword.value, ast.Constant): arguments[keyword.arg] = keyword.value.value else: # Fallback: convert to string arguments[keyword.arg] = ast.unparse(keyword.value) # Handle positional arguments (less common in tool calls but supported) for i, arg in enumerate(call_node.args): try: value = ast.literal_eval(arg) arguments[f"arg_{i}"] = value except (ValueError, TypeError): if isinstance(arg, ast.Name): arguments[f"arg_{i}"] = arg.id elif isinstance(arg, ast.Constant): arguments[f"arg_{i}"] = arg.value else: arguments[f"arg_{i}"] = ast.unparse(arg) return arguments except (SyntaxError, ValueError) as e: logger.warning(f"Failed to parse function arguments '{args_str}': {e}") return {} def extract_tool_calls( self, model_output: str, request: ChatCompletionRequest, ) -> ExtractedToolCallInformation: if self.tool_call_start_token not in model_output: return ExtractedToolCallInformation( tools_called=False, tool_calls=[], content=model_output, ) tool_call_start_index = model_output.find(self.tool_call_start_token) content = model_output[:tool_call_start_index].strip() try: # Extract content between tags tool_call_matches = self.tool_call_regex.findall(model_output) if not tool_call_matches: return ExtractedToolCallInformation( tools_called=False, tool_calls=[], content=model_output, ) tool_calls_content = tool_call_matches[0].strip() # Split by lines to get individual function calls function_lines = [line.strip() for line in tool_calls_content.split('\n') if line.strip()] parsed_tool_calls = [] for func_line in function_lines: # Parse each function call match = self.function_call_regex.match(func_line) if not match: logger.warning(f"Could not parse function call: {func_line}") continue function_name = match.group(1) args_str = match.group(2) # Parse arguments parsed_arguments = self.parse_function_arguments(args_str) # Apply type conversion based on schema if available if request.tools: for tool_def in request.tools: if tool_def.function.name == function_name: schema_properties = {} if (tool_def.function.parameters and isinstance(tool_def.function.parameters, dict) and "properties" in tool_def.function.parameters and isinstance(tool_def.function.parameters["properties"], dict)): schema_properties = tool_def.function.parameters["properties"] # Convert arguments based on schema types for arg_name, arg_value in parsed_arguments.items(): if arg_name in schema_properties: param_info = schema_properties[arg_name] target_type = param_info.get("type") try: if target_type == "string" and not isinstance(arg_value, str): parsed_arguments[arg_name] = str(arg_value) elif target_type == "integer" and not isinstance(arg_value, int): parsed_arguments[arg_name] = int(arg_value) elif target_type == "number" and not isinstance(arg_value, (int, float)): parsed_arguments[arg_name] = float(arg_value) elif target_type == "boolean" and not isinstance(arg_value, bool): if isinstance(arg_value, str): parsed_arguments[arg_name] = arg_value.lower() in ['true', '1', 'yes'] else: parsed_arguments[arg_name] = bool(arg_value) elif target_type in ["object", "array"]: if isinstance(arg_value, str): try: parsed_arguments[arg_name] = json.loads(arg_value) except json.JSONDecodeError: # Keep as string if JSON parsing fails pass except (ValueError, TypeError) as e: logger.warning(f"Type conversion failed for {arg_name}: {e}") # Keep original value if conversion fails break parsed_tool_calls.append(ToolCall( id=f"call_{random_uuid()}", type="function", function=FunctionCall( name=function_name, arguments=json.dumps(parsed_arguments, ensure_ascii=False), ), )) return ExtractedToolCallInformation( tools_called=len(parsed_tool_calls) > 0, tool_calls=parsed_tool_calls, content=content if content else None, ) except Exception: logger.exception(f"Error in extracting pythonic tool call from response. Response: {model_output}") return ExtractedToolCallInformation( tools_called=False, tool_calls=[], content=model_output, ) def extract_tool_calls_streaming( self, previous_text: str, current_text: str, delta_text: str, previous_token_ids: Sequence[int], current_token_ids: Sequence[int], delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: raise NotImplementedError("Tool calling is not supported in streaming mode!")