# Dynamic-Function-Calling-Agent / constrained_generator.py
# jlov7's picture
# feat: Multi-tool selection and robustness testing
# 6639f75
"""
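constrained_generator.py - JSON-constrained generation

Helpers for steering a causal language model towards valid JSON function calls:
1. Token-by-token filtering: only tokens that keep the output a valid JSON prefix are sampled
2. Fallback that appends a closing brace when no candidate token is valid
3. Beam search whose candidates are checked after generation (JSON parsing + schema validation)
4. Schema construction and validation for function-call objects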
"""
import torch
import json
import jsonschema
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Dict, Any, Optional
import re
class ConstrainedJSONGenerator:
    """Generate JSON objects from a causal language model.

    Two strategies are offered:

    * ``generate_constrained`` samples one token at a time and masks every
      candidate token that would make the output stop being a prefix of a
      JSON object.
    * ``generate_with_beam_search`` runs ordinary beam search and then picks
      the first candidate that parses and validates against the schema.

    The model is switched to eval mode on construction.
    """

    # Whitespace characters that the JSON grammar allows between tokens.
    _JSON_WS = " \t\n\r"
    _LITERALS = ("true", "false", "null")

    # Sentinel results of the ``_scan_*`` helpers (real results are indices >= 0).
    _INCOMPLETE = -1  # input ended while the value was still well-formed
    _INVALID = -2     # input can never become valid JSON

    # String contents: any char except quote, backslash and control chars,
    # or one of the escape sequences defined by the JSON grammar.
    _STRING_BODY = r'(?:[^"\\\x00-\x1f]|\\(?:["\\/bfnrt]|u[0-9a-fA-F]{4}))*'
    _STRING_RE = re.compile('"' + _STRING_BODY + '"')
    # Unterminated string reaching the end of input, possibly mid-escape.
    _PARTIAL_STRING_RE = re.compile(
        '"' + _STRING_BODY + r'(?:\\(?:u[0-9a-fA-F]{0,3})?)?\Z'
    )
    _NUMBER_RE = re.compile(r'-?(?:0|[1-9][0-9]*)(?:\.[0-9]+)?(?:[eE][+-]?[0-9]+)?')
    # Number (or the start of one, e.g. "-", "1.", "2e+") reaching end of input.
    _PARTIAL_NUMBER_RE = re.compile(
        r'-?(?:(?:0|[1-9][0-9]*)(?:\.[0-9]*|(?:\.[0-9]+)?[eE][+-]?[0-9]*)?)?\Z'
    )

    def __init__(self, model, tokenizer, device="mps"):
        """Store the model/tokenizer pair and put the model in eval mode.

        Args:
            model: Causal LM; called as ``model(input_ids)`` and
                ``model.generate(...)``.
            tokenizer: Tokenizer matching ``model``.
            device: Device the encoded prompt is moved to.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.model.eval()

    # ------------------------------------------------------------------
    # JSON prefix scanning
    # ------------------------------------------------------------------
    def _skip_ws(self, s: str, i: int) -> int:
        """Return the index of the first non-whitespace char at or after ``i``."""
        n = len(s)
        while i < n and s[i] in self._JSON_WS:
            i += 1
        return i

    def _scan_string(self, s: str, i: int) -> int:
        """Scan the string starting at ``s[i] == '"'``."""
        match = self._STRING_RE.match(s, i)
        if match:
            return match.end()
        if self._PARTIAL_STRING_RE.match(s, i):
            return self._INCOMPLETE
        return self._INVALID

    def _scan_number(self, s: str, i: int) -> int:
        """Scan the number starting at ``s[i]`` (a digit or ``-``)."""
        # A number that runs to the end of the input may still grow
        # ("1" -> "12", "1." -> "1.5"), so it counts as incomplete.
        if self._PARTIAL_NUMBER_RE.match(s, i):
            return self._INCOMPLETE
        match = self._NUMBER_RE.match(s, i)
        return match.end() if match else self._INVALID

    def _scan_value(self, s: str, i: int) -> int:
        """Scan one JSON value starting at index ``i``.

        Returns the index just past the value, ``_INCOMPLETE`` when the input
        ends before the value does, or ``_INVALID`` on a grammar violation.
        """
        if i >= len(s):
            return self._INCOMPLETE
        ch = s[i]
        if ch == "{":
            return self._scan_object(s, i)
        if ch == "[":
            return self._scan_array(s, i)
        if ch == '"':
            return self._scan_string(s, i)
        if ch in "-0123456789":
            return self._scan_number(s, i)
        rest = s[i:]
        for literal in self._LITERALS:
            if s.startswith(literal, i):
                return i + len(literal)
            if literal.startswith(rest):
                # e.g. "tr" at the end of the input
                return self._INCOMPLETE
        return self._INVALID

    def _scan_object(self, s: str, i: int) -> int:
        """Scan the object starting at ``s[i] == '{'``."""
        n = len(s)
        i = self._skip_ws(s, i + 1)
        if i >= n:
            return self._INCOMPLETE
        if s[i] == "}":
            return i + 1
        while True:
            if s[i] != '"':
                return self._INVALID
            i = self._scan_string(s, i)
            if i < 0:
                return i
            i = self._skip_ws(s, i)
            if i >= n:
                return self._INCOMPLETE
            if s[i] != ":":
                return self._INVALID
            i = self._scan_value(s, self._skip_ws(s, i + 1))
            if i < 0:
                return i
            i = self._skip_ws(s, i)
            if i >= n:
                return self._INCOMPLETE
            if s[i] == "}":
                return i + 1
            if s[i] != ",":
                return self._INVALID
            i = self._skip_ws(s, i + 1)
            if i >= n:
                return self._INCOMPLETE

    def _scan_array(self, s: str, i: int) -> int:
        """Scan the array starting at ``s[i] == '['``."""
        n = len(s)
        i = self._skip_ws(s, i + 1)
        if i >= n:
            return self._INCOMPLETE
        if s[i] == "]":
            return i + 1
        while True:
            i = self._scan_value(s, i)
            if i < 0:
                return i
            i = self._skip_ws(s, i)
            if i >= n:
                return self._INCOMPLETE
            if s[i] == "]":
                return i + 1
            if s[i] != ",":
                return self._INVALID
            i = self._skip_ws(s, i + 1)
            if i >= n:
                return self._INCOMPLETE

    def is_valid_json_prefix(self, text: str) -> bool:
        """Return True if ``text`` is a JSON object or could be extended to one.

        Empty / whitespace-only text is a valid prefix. Anything else must
        start with ``{``. Unlike a check based on ``json.loads`` error
        messages, this walks the JSON grammar, so ``{``, ``{"na`` and
        ``{"a": tr`` are accepted while ``{"a": 1 x`` is rejected.
        """
        # Only strip the left side: trailing whitespace can sit inside an
        # unterminated string or escape, where it is significant.
        text = text.lstrip()
        if not text:
            return True
        if text[0] != "{":
            return False
        try:
            end = self._scan_value(text, 0)
        except RecursionError:
            # Pathologically deep nesting: treat as not generable.
            return False
        if end == self._INCOMPLETE:
            return True
        if end == self._INVALID:
            return False
        # Complete object: only JSON whitespace may follow it.
        return self._skip_ws(text, end) == len(text)

    # ------------------------------------------------------------------
    # Token filtering
    # ------------------------------------------------------------------
    def _token_text(self, token_id: int) -> str:
        """Decode a single token id, memoising the result on the instance."""
        cache = self.__dict__.setdefault("_token_text_cache", {})
        text = cache.get(token_id)
        if text is None:
            text = cache[token_id] = self.tokenizer.decode([token_id])
        return text

    def get_valid_next_tokens(
        self,
        current_text: str,
        schema: Dict,
        max_candidates: Optional[int] = 50,
    ) -> List[int]:
        """Return token ids whose text keeps ``current_text`` a valid JSON prefix.

        Token ids are scanned in increasing order and the scan stops as soon
        as more than ``max_candidates`` valid ids have been collected, so the
        result is biased towards low ids. Pass ``max_candidates=None`` to scan
        the whole vocabulary (slow).

        Args:
            current_text: Text generated so far.
            schema: Currently unused; filtering is purely syntactic.
            max_candidates: Early-termination threshold, or None for no limit.
        """
        pad_id = self.tokenizer.pad_token_id
        valid_tokens: List[int] = []
        # len(tokenizer) is the full vocabulary size including added tokens.
        for token_id in range(len(self.tokenizer)):
            if token_id == pad_id:
                continue
            if self.is_valid_json_prefix(current_text + self._token_text(token_id)):
                valid_tokens.append(token_id)
                if max_candidates is not None and len(valid_tokens) > max_candidates:
                    break
        return valid_tokens

    # ------------------------------------------------------------------
    # Generation
    # ------------------------------------------------------------------
    def generate_constrained(self, prompt: str, schema: Dict, max_length: int = 200) -> str:
        """Sample up to ``max_length`` tokens, masking tokens that break JSON.

        Generation stops as soon as the accumulated text parses as JSON: a
        complete top-level object cannot be extended into a different valid
        object, so further steps could only add whitespace. The result is
        not guaranteed to satisfy ``schema``; use ``validate_against_schema``
        on the parsed result to check.

        Returns:
            The generated text, stripped. It may be incomplete JSON if
            ``max_length`` is reached first.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        input_ids = inputs["input_ids"].clone()
        generated_text = ""

        for _ in range(max_length):
            with torch.no_grad():
                logits = self.model(input_ids).logits[0, -1, :]

            # Drop ids the model has no logit for (tokenizer larger than the
            # model's output layer).
            vocab_limit = logits.shape[-1]
            valid_tokens = [
                token_id
                for token_id in self.get_valid_next_tokens(generated_text, schema)
                if token_id < vocab_limit
            ]

            if valid_tokens:
                # -inf everywhere except the allowed ids, then sample.
                mask = torch.full_like(logits, float("-inf"))
                mask[valid_tokens] = 0
                probs = torch.softmax(logits + mask, dim=-1)
                next_token_id = torch.multinomial(probs, 1).item()
            elif generated_text.strip().endswith("}"):
                break
            else:
                # Dead end: force a closing brace. Special tokens are disabled
                # so that [0] is the brace itself, not a BOS token.
                closing_ids = self.tokenizer.encode("}", add_special_tokens=False)
                if not closing_ids:
                    break
                next_token_id = closing_ids[0]

            next_token = torch.tensor(
                [[next_token_id]], dtype=input_ids.dtype, device=input_ids.device
            )
            input_ids = torch.cat([input_ids, next_token], dim=1)
            generated_text += self.tokenizer.decode([next_token_id])

            try:
                json.loads(generated_text.strip())
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                continue
            break

        return generated_text.strip()

    def validate_against_schema(self, data: Dict, schema: Dict) -> bool:
        """Return True if ``data`` satisfies ``schema`` per ``jsonschema.validate``.

        Only ``jsonschema.ValidationError`` is turned into ``False``; other
        errors raised by ``jsonschema.validate`` propagate.
        """
        try:
            jsonschema.validate(data, schema)
        except jsonschema.ValidationError:
            return False
        return True

    def generate_with_beam_search(
        self,
        prompt: str,
        schema: Dict,
        num_beams: int = 3,
        max_new_tokens: int = 150,
    ) -> str:
        """Run beam search and return the first schema-valid candidate.

        The search itself is unconstrained; candidates are checked after
        generation. If none parses and validates, the first candidate is
        returned as-is (or ``""`` when there are no candidates).
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            # No `temperature`: it has no effect when do_sample=False.
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_beams=num_beams,
                early_stopping=True,
                do_sample=False,
                pad_token_id=self.tokenizer.eos_token_id,
                num_return_sequences=num_beams,
            )

        # Strip the prompt tokens from each returned sequence before decoding.
        prompt_length = inputs["input_ids"].shape[1]
        candidates = [
            self.tokenizer.decode(sequence[prompt_length:], skip_special_tokens=True).strip()
            for sequence in outputs
        ]

        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except ValueError:
                continue
            if self.validate_against_schema(parsed, schema):
                return candidate

        return candidates[0] if candidates else ""
def create_json_schema_from_function(function_def: Dict) -> Dict:
    """Build the JSON schema that a call to ``function_def`` must satisfy.

    The resulting schema accepts exactly one shape: an object with a
    ``"name"`` equal to the function's name and an ``"arguments"`` value
    matching the function's parameter schema; no other keys are allowed.

    Args:
        function_def: Function description with a ``"name"`` key and,
            optionally, a ``"parameters"`` JSON schema. When ``"parameters"``
            is missing (a function that takes no arguments), an object
            schema with no properties is used.

    Returns:
        A JSON schema dict. The parameter schema is referenced, not copied.

    Raises:
        KeyError: If ``function_def`` has no ``"name"``.
    """
    parameters = function_def.get("parameters")
    if parameters is None:
        parameters = {"type": "object", "properties": {}}

    return {
        "type": "object",
        "properties": {
            "name": {
                "type": "string",
                "const": function_def["name"],
            },
            "arguments": parameters,
        },
        "required": ["name", "arguments"],
        "additionalProperties": False,
    }
def test_constrained_generation():
    """Smoke-test both generation strategies against a weather-function schema.

    Loads the model named below, builds a chat-style prompt, runs
    ``generate_constrained`` and ``generate_with_beam_search`` and prints
    whether each output parses as JSON and satisfies the schema.
    """
    print("πŸ§ͺ Testing Constrained JSON Generation...")

    # Load model
    model_name = "HuggingFaceTB/SmolLM3-3B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        device_map="mps" if torch.backends.mps.is_available() else "auto"
    )

    # Use the device the model actually landed on; the generator's default
    # ("mps") is wrong when the model was loaded with device_map="auto".
    generator = ConstrainedJSONGenerator(model, tokenizer, device=model.device)

    # Test schema
    function_def = {
        "name": "get_weather",
        "description": "Get weather forecast",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string"},
                "days": {"type": "integer"}
            },
            "required": ["location", "days"]
        }
    }
    schema = create_json_schema_from_function(function_def)

    prompt = f"""<|im_start|>system
You are a helpful assistant that calls functions by responding with valid JSON when given a schema. Always respond with JSON function calls only, never prose.<|im_end|>
<schema>
{json.dumps(function_def, indent=2)}
</schema>
<|im_start|>user
Get 3-day weather for New York<|im_end|>
<|im_start|>assistant
"""

    # Test constrained generation
    print("🎯 Testing constrained generation...")
    result = generator.generate_constrained(prompt, schema)
    print(f"πŸ€– Constrained result: {result}")

    # Validate result. validate_against_schema reports failure by returning
    # False, so its result must be checked rather than discarded.
    try:
        parsed = json.loads(result)
        if not generator.validate_against_schema(parsed, schema):
            raise ValueError("output does not match the schema")
        print("βœ… Valid JSON with correct schema!")
    except Exception as e:
        print(f"❌ Validation failed: {e}")

    # Test beam search
    print("🎯 Testing beam search...")
    beam_result = generator.generate_with_beam_search(prompt, schema)
    print(f"πŸ€– Beam result: {beam_result}")

    try:
        parsed = json.loads(beam_result)
        if not generator.validate_against_schema(parsed, schema):
            raise ValueError("output does not match the schema")
        print("βœ… Beam search produced valid JSON!")
    except Exception as e:
        print(f"❌ Beam validation failed: {e}")
# Script entry point: run the smoke test only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    test_constrained_generation()