|
""" |
|
constrained_generator.py - JSON Schema Constrained Generation

This module implements constrained decoding to force valid JSON output:

1. Token-by-token validation against JSON schema
2. Backtracking on invalid JSON syntax
3. Beam search with JSON constraints
4. Schema-aware generation
|
""" |
|
|
|
import json
import re
from typing import Any, Dict, List, Optional

import jsonschema
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
class ConstrainedJSONGenerator:
    """Generate JSON objects from a causal language model.

    Two strategies are offered:

    * ``generate_constrained`` samples one token at a time and masks out
      every token whose text would stop the output from being a prefix of a
      JSON object.
    * ``generate_with_beam_search`` runs ordinary beam search and returns the
      first beam that parses as JSON and validates against the schema.

    Token masking is purely syntactic; the schema is only consulted by
    ``validate_against_schema`` and by the beam-search post-filter.
    """

    # Sentinel "positions" returned by the ``_scan_*`` helpers.
    _INCOMPLETE = -1  # input ended inside the token; still a valid prefix
    _INVALID = -2     # the token can never become valid JSON

    _WHITESPACE = " \t\n\r"  # the only insignificant whitespace JSON allows
    _NUMBER_CHARS = frozenset("+-.0123456789eE")
    _HEX_DIGITS = frozenset("0123456789abcdefABCDEF")
    _SIMPLE_ESCAPES = frozenset('"\\/bfnrt')
    _LITERALS = ("true", "false", "null")

    # A complete JSON number, and any string that can still grow into one.
    _NUMBER_RE = re.compile(r"-?(?:0|[1-9][0-9]*)(?:\.[0-9]+)?(?:[eE][+-]?[0-9]+)?")
    _NUMBER_PREFIX_RE = re.compile(
        r"-?(?:(?:0|[1-9][0-9]*)"
        r"(?:\.(?:[0-9]+(?:[eE][+-]?[0-9]*)?)?|[eE][+-]?[0-9]*)?)?"
    )

    def __init__(self, model, tokenizer, device="mps"):
        """Store the model/tokenizer pair and put the model in eval mode.

        Args:
            model: Causal LM, called as ``model(input_ids)`` and through
                ``model.generate``.
            tokenizer: Tokenizer used to encode prompts and decode tokens.
            device: Device that tokenized prompts and newly sampled token ids
                are moved to; it should be the device the model runs on.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self._token_text_cache: Dict[int, str] = {}
        self.model.eval()

    def is_valid_json_prefix(self, text: str) -> bool:
        """Return True if ``text`` can be extended into a JSON object.

        Empty or whitespace-only text is accepted. Otherwise the first
        non-whitespace character must be ``{`` and everything after it must
        follow the JSON grammar up to the point where the text stops.
        A complete object (optionally followed by JSON whitespace) is also
        accepted.
        """
        # Only leading whitespace is dropped: trailing characters may sit
        # inside an unfinished string, where e.g. a raw newline is illegal.
        body = text.lstrip()
        if not body:
            return True
        if not body.startswith("{"):
            return False
        return self._is_json_prefix(body)

    @classmethod
    def _is_json_prefix(cls, text: str) -> bool:
        """Check ``text`` against the JSON grammar, tolerating early end of input."""
        stack: List[str] = []  # currently open containers, "{" or "["
        state = "value"        # what the grammar expects next
        pos, end = 0, len(text)

        while pos < end:
            ch = text[pos]
            if ch in cls._WHITESPACE:
                pos += 1
                continue
            if state == "done":
                # Non-whitespace after the top-level value is closed.
                return False

            if ch in "}]":
                if ch == "}":
                    opener, empty_state = "{", "key_or_end"
                else:
                    opener, empty_state = "[", "value_or_end"
                # A closer is legal right after a value or in an empty
                # container, and only if it matches the innermost opener.
                if state not in ("after_value", empty_state):
                    return False
                if not stack or stack[-1] != opener:
                    return False
                stack.pop()
                state = "after_value" if stack else "done"
                pos += 1
                continue

            if state == "colon":
                if ch != ":":
                    return False
                state = "value"
                pos += 1
                continue

            if state == "after_value":
                if ch != ",":
                    return False
                state = "key" if stack[-1] == "{" else "value"
                pos += 1
                continue

            wants_key = state in ("key", "key_or_end")
            wants_value = not wants_key  # "value" or "value_or_end"

            if wants_value and ch in "{[":
                stack.append(ch)
                state = "key_or_end" if ch == "{" else "value_or_end"
                pos += 1
                continue

            if ch == '"':
                pos = cls._scan_string(text, pos + 1)
            elif wants_value and ch in "-0123456789":
                pos = cls._scan_number(text, pos)
            elif wants_value and ch in "tfn":
                pos = cls._scan_literal(text, pos)
            else:
                return False

            if pos == cls._INVALID:
                return False
            if pos == cls._INCOMPLETE:
                return True
            if wants_key:
                state = "colon"
            else:
                state = "after_value" if stack else "done"

        return True

    @classmethod
    def _scan_string(cls, text: str, pos: int) -> int:
        """Scan a string body starting just after its opening quote.

        Returns the index after the closing quote, ``_INCOMPLETE`` if the
        input ends first, or ``_INVALID`` on an illegal character or escape.
        """
        end = len(text)
        while pos < end:
            ch = text[pos]
            if ch == '"':
                return pos + 1
            if ch == "\\":
                if pos + 1 >= end:
                    return cls._INCOMPLETE
                escape = text[pos + 1]
                if escape == "u":
                    digits = text[pos + 2:pos + 6]
                    if any(d not in cls._HEX_DIGITS for d in digits):
                        return cls._INVALID
                    if len(digits) < 4:
                        return cls._INCOMPLETE
                    pos += 6
                elif escape in cls._SIMPLE_ESCAPES:
                    pos += 2
                else:
                    return cls._INVALID
            elif ch < " ":
                # Raw control characters must be escaped inside JSON strings.
                return cls._INVALID
            else:
                pos += 1
        return cls._INCOMPLETE

    @classmethod
    def _scan_number(cls, text: str, pos: int) -> int:
        """Scan a number starting at ``pos``; same return convention as ``_scan_string``."""
        end = len(text)
        stop = pos
        while stop < end and text[stop] in cls._NUMBER_CHARS:
            stop += 1
        token = text[pos:stop]
        if stop == end:
            # The number may still be growing, so a partial form is enough.
            if cls._NUMBER_PREFIX_RE.fullmatch(token):
                return cls._INCOMPLETE
            return cls._INVALID
        return stop if cls._NUMBER_RE.fullmatch(token) else cls._INVALID

    @classmethod
    def _scan_literal(cls, text: str, pos: int) -> int:
        """Scan ``true``/``false``/``null`` at ``pos``; same return convention as ``_scan_string``."""
        # The longest literal has five characters; a shorter slice means the
        # input ended here.
        rest = text[pos:pos + 5]
        for word in cls._LITERALS:
            if rest.startswith(word):
                return pos + len(word)
            if word.startswith(rest):
                return cls._INCOMPLETE
        return cls._INVALID

    def _token_text(self, token_id: int) -> str:
        """Return the decoded text of one token, memoised per instance."""
        # ``setdefault`` keeps this working on instances built without __init__.
        cache = self.__dict__.setdefault("_token_text_cache", {})
        text = cache.get(token_id)
        if text is None:
            text = cache[token_id] = self.tokenizer.decode([token_id])
        return text

    def get_valid_next_tokens(
        self,
        current_text: str,
        schema: Dict,
        candidate_ids: Optional[List[int]] = None,
        max_tokens: Optional[int] = 50,
    ) -> List[int]:
        """Return token ids whose text keeps ``current_text`` a JSON prefix.

        Args:
            current_text: Text generated so far.
            schema: Accepted for interface compatibility; not used here, the
                check is syntactic only.
            candidate_ids: Token ids to try, in priority order. Defaults to
                every id in the tokenizer's vocabulary in ascending order.
            max_tokens: Stop after this many valid tokens have been found;
                ``None`` removes the cap.

        Returns:
            Valid token ids in the order they were tried. The pad token,
            ids outside the tokenizer's vocabulary, and tokens that decode to
            an empty string (they would add no text) are never returned.
        """
        vocab_size = len(self.tokenizer)
        if candidate_ids is None:
            candidate_ids = range(vocab_size)
        pad_id = self.tokenizer.pad_token_id

        valid_tokens: List[int] = []
        for token_id in candidate_ids:
            if token_id == pad_id or not 0 <= token_id < vocab_size:
                continue
            piece = self._token_text(token_id)
            if not piece:
                continue
            if self.is_valid_json_prefix(current_text + piece):
                valid_tokens.append(token_id)
                if max_tokens is not None and len(valid_tokens) >= max_tokens:
                    break
        return valid_tokens

    def generate_constrained(self, prompt: str, schema: Dict, max_length: int = 200) -> str:
        """Sample a JSON object token by token under a syntactic mask.

        At each step the model's next-token logits are restricted to the
        highest-scoring tokens that keep the output a valid JSON prefix, and
        one of them is sampled. Generation stops as soon as the accumulated
        text parses as JSON or after ``max_length`` steps.

        Args:
            prompt: Text fed to the model before generation starts.
            schema: Forwarded to ``get_valid_next_tokens``. Schema conformance
                is not enforced while decoding; validate the result with
                ``validate_against_schema`` if it matters.
            max_length: Maximum number of tokens to generate.

        Returns:
            The generated text, stripped. It is not guaranteed to be complete
            JSON if the step limit is reached first.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        input_ids = inputs["input_ids"].clone()
        generated_text = ""

        for _ in range(max_length):
            with torch.no_grad():
                logits = self.model(input_ids).logits[0, -1, :]

            # Try tokens from most to least likely so the capped candidate
            # list holds the model's preferred valid continuations.
            ranked_ids = torch.argsort(logits.detach().cpu(), descending=True).tolist()
            valid_tokens = self.get_valid_next_tokens(
                generated_text, schema, candidate_ids=ranked_ids
            )

            if valid_tokens:
                mask = torch.full_like(logits, float("-inf"))
                mask[valid_tokens] = 0
                probs = torch.softmax(logits + mask, dim=-1)
                next_token_id = torch.multinomial(probs, 1).item()
            elif generated_text.strip().endswith("}"):
                break
            else:
                # Nothing fits: force a closing brace. Special tokens are
                # disabled so the first id is the brace, not a BOS marker.
                closing_ids = self.tokenizer.encode("}", add_special_tokens=False)
                if not closing_ids:
                    break
                next_token_id = closing_ids[0]

            input_ids = torch.cat(
                [input_ids, torch.tensor([[next_token_id]], device=self.device)],
                dim=1,
            )
            generated_text += self._token_text(next_token_id)

            try:
                json.loads(generated_text.strip())
            except ValueError:
                # Not a complete document yet (JSONDecodeError is a ValueError).
                continue
            # The top-level object is closed; only whitespace could follow,
            # so further sampling cannot change the stripped result.
            break

        return generated_text.strip()

    def validate_against_schema(self, data: Dict, schema: Dict) -> bool:
        """Return True if ``data`` validates against ``schema``.

        ``jsonschema.ValidationError`` is reported as False; any other
        exception raised by ``jsonschema.validate`` propagates.
        """
        try:
            jsonschema.validate(data, schema)
            return True
        except jsonschema.ValidationError:
            return False

    def generate_with_beam_search(
        self,
        prompt: str,
        schema: Dict,
        num_beams: int = 3,
        max_new_tokens: int = 150,
    ) -> str:
        """Run beam search and return the first schema-valid beam.

        Decoding itself is unconstrained; the JSON and schema checks are
        applied to the finished beams.

        Args:
            prompt: Text fed to the model.
            schema: JSON schema each candidate is validated against.
            num_beams: Beam width, also the number of returned sequences.
            max_new_tokens: Generation budget per beam.

        Returns:
            The first candidate that parses as JSON and validates against
            ``schema``; otherwise the first candidate as-is, or ``""`` if no
            sequences were returned.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        prompt_length = inputs["input_ids"].shape[1]

        with torch.no_grad():
            # No ``temperature``: it has no effect when ``do_sample=False``.
            sequences = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_beams=num_beams,
                num_return_sequences=num_beams,
                early_stopping=True,
                do_sample=False,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Drop the prompt tokens so only the continuation is decoded.
        candidates = [
            self.tokenizer.decode(
                sequence[prompt_length:], skip_special_tokens=True
            ).strip()
            for sequence in sequences
        ]

        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except ValueError:
                continue
            if self.validate_against_schema(parsed, schema):
                return candidate

        return candidates[0] if candidates else ""
|
|
|
def create_json_schema_from_function(function_def: Dict) -> Dict:
    """Build the JSON schema a call to ``function_def`` must satisfy.

    The resulting schema describes an object with exactly two required keys:
    ``name``, pinned to the function's name, and ``arguments``, validated by
    the function's own parameter schema.

    Args:
        function_def: Function description with a ``"name"`` key and,
            optionally, a ``"parameters"`` key holding a JSON schema. When
            ``"parameters"`` is absent or ``None``, an empty object schema is
            used so parameterless functions are supported.

    Returns:
        The schema as a dict. The parameter schema is embedded by reference,
        not copied.

    Raises:
        KeyError: If ``function_def`` has no ``"name"``.
    """
    parameters = function_def.get("parameters")
    if parameters is None:
        parameters = {"type": "object", "properties": {}}

    return {
        "type": "object",
        "properties": {
            "name": {
                "type": "string",
                "const": function_def["name"],
            },
            "arguments": parameters,
        },
        "required": ["name", "arguments"],
        "additionalProperties": False,
    }
|
|
|
def test_constrained_generation():
    """Smoke-test both generation strategies on a sample function schema.

    Loads ``HuggingFaceTB/SmolLM3-3B``, asks it for a ``get_weather`` call
    with constrained sampling and with beam search, and prints whether each
    result parses as JSON and validates against the schema. Nothing is
    returned or asserted; the outcome is reported on stdout only.

    NOTE(review): the status glyphs in the printed messages were rebuilt from
    mis-encoded bytes; confirm them against the upstream copy of this file.
    """
    print("🧪 Testing Constrained JSON Generation...")

    model_name = "HuggingFaceTB/SmolLM3-3B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        device_map="mps" if torch.backends.mps.is_available() else "auto",
    )

    # Use the device the model was actually placed on; the generator's "mps"
    # default would break on machines where the model was loaded elsewhere.
    generator = ConstrainedJSONGenerator(model, tokenizer, device=model.device)

    function_def = {
        "name": "get_weather",
        "description": "Get weather forecast",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string"},
                "days": {"type": "integer"},
            },
            "required": ["location", "days"],
        },
    }

    schema = create_json_schema_from_function(function_def)

    prompt = f"""<|im_start|>system
You are a helpful assistant that calls functions by responding with valid JSON when given a schema. Always respond with JSON function calls only, never prose.<|im_end|>

<schema>
{json.dumps(function_def, indent=2)}
</schema>

<|im_start|>user
Get 3-day weather for New York<|im_end|>
<|im_start|>assistant
"""

    def report(text, success_message, failure_label):
        """Print whether ``text`` is JSON that validates against ``schema``."""
        try:
            parsed = json.loads(text)
            # validate_against_schema reports a mismatch by returning False,
            # so its result has to be checked rather than discarded.
            if not generator.validate_against_schema(parsed, schema):
                raise ValueError("output does not conform to the schema")
            print(success_message)
        except Exception as e:
            print(f"{failure_label}: {e}")

    print("🎯 Testing constrained generation...")
    result = generator.generate_constrained(prompt, schema)
    print(f"🤖 Constrained result: {result}")
    report(result, "✅ Valid JSON with correct schema!", "❌ Validation failed")

    print("🎯 Testing beam search...")
    beam_result = generator.generate_with_beam_search(prompt, schema)
    print(f"🤖 Beam result: {beam_result}")
    report(beam_result, "✅ Beam search produced valid JSON!", "❌ Beam validation failed")
|
|
|
if __name__ == "__main__":
    # Executed as a script: run the end-to-end smoke test.
    test_constrained_generation()