# Source listing metadata: file size 8,882 bytes, commit 6639f75.
"""
constrained_generator.py - JSON Schema Constrained Generation
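This implements constrained decoding to steer the model toward valid JSON output:
1. Token-by-token masking of candidates that would break the JSON syntax
2. A closing-brace fallback when no candidate token keeps the JSON valid
3. Beam search whose candidates are filtered against the JSON schema afterwards
4. Schema validation of the finished output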
"""
import torch
import json
import jsonschema
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Dict, Any, Optional
import re
class ConstrainedJSONGenerator:
    """Generate JSON objects from a causal language model.

    ``generate_constrained`` decodes token by token and masks every candidate
    whose text would stop the output from being a prefix of a JSON object.
    ``generate_with_beam_search`` runs ordinary (unconstrained) beam search and
    then returns the first candidate that parses and satisfies the schema.
    """

    # Characters JSON treats as insignificant whitespace between tokens.
    _JSON_WHITESPACE = " \t\n\r"
    # JSON literal values, keyed by the character that starts them.
    _LITERALS = {"t": "true", "f": "false", "n": "null"}
    # Number DFA: state -> {character class -> next state}. "start" is the
    # state before the first character; see _number_char_class for classes.
    _NUMBER_STEPS = {
        "start": {"-": "sign", "0": "zero", "1-9": "int"},
        "sign": {"0": "zero", "1-9": "int"},
        "zero": {".": "dot", "e": "exp"},
        "int": {"0": "int", "1-9": "int", ".": "dot", "e": "exp"},
        "dot": {"0": "frac", "1-9": "frac"},
        "frac": {"0": "frac", "1-9": "frac", "e": "exp"},
        "exp": {"+": "exp_sign", "-": "exp_sign", "0": "exp_digits", "1-9": "exp_digits"},
        "exp_sign": {"0": "exp_digits", "1-9": "exp_digits"},
        "exp_digits": {"0": "exp_digits", "1-9": "exp_digits"},
    }
    # States in which the characters read so far already form a complete
    # number ("-", "1.", "1e" and "1e+" are only prefixes of numbers).
    _NUMBER_FINAL = frozenset({"zero", "int", "frac", "exp_digits"})

    def __init__(self, model, tokenizer, device="mps"):
        """Wrap ``model`` and ``tokenizer``; prompt tensors are moved to ``device``.

        The model is switched to eval mode.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.model.eval()

    @staticmethod
    def _number_char_class(ch: str) -> Optional[str]:
        """Map a character to its _NUMBER_STEPS class, or None if it cannot occur in a number."""
        if ch == "0":
            return "0"
        if ch in "123456789":
            return "1-9"
        if ch in "eE":
            return "e"
        return ch if ch in "+-." else None

    @classmethod
    def _is_json_object_prefix(cls, text: str) -> bool:
        """Return True if ``text`` can be extended into a complete JSON object.

        Single left-to-right pass of a pushdown automaton: ``stack`` holds the
        open containers ('{' or '[') and ``state`` names what may come next.
        The top-level value must be an object; after it closes only whitespace
        is accepted.
        """
        ws = cls._JSON_WHITESPACE
        stack: List[str] = []
        state = "value"
        key_string = False   # whether the open string is an object key
        hex_left = 0         # hex digits still expected in a \uXXXX escape
        literal = ""         # literal (true/false/null) being matched
        literal_pos = 0
        number_state = ""
        i = 0
        while i < len(text):
            ch = text[i]
            if state == "string":
                if ch == '"':
                    state = "colon" if key_string else "after_value"
                elif ch == "\\":
                    state = "escape"
                elif ch < " ":
                    # Raw control characters must be escaped inside strings.
                    return False
            elif state == "escape":
                if ch == "u":
                    state, hex_left = "unicode", 4
                elif ch in '"\\/bfnrt':
                    state = "string"
                else:
                    return False
            elif state == "unicode":
                if ch not in "0123456789abcdefABCDEF":
                    return False
                hex_left -= 1
                if hex_left == 0:
                    state = "string"
            elif state == "literal":
                if ch != literal[literal_pos]:
                    return False
                literal_pos += 1
                if literal_pos == len(literal):
                    state = "after_value"
            elif state == "number":
                nxt = cls._NUMBER_STEPS[number_state].get(cls._number_char_class(ch))
                if nxt is not None:
                    number_state = nxt
                elif number_state in cls._NUMBER_FINAL:
                    # The number ended; re-examine ``ch`` as the token after a value.
                    state = "after_value"
                    continue
                else:
                    return False
            elif ch in ws:
                pass  # whitespace is allowed between structural tokens
            elif state in ("value", "array_start"):
                if ch == "]" and state == "array_start":
                    stack.pop()
                    state = "after_value" if stack else "done"
                elif ch == "{":
                    stack.append("{")
                    state = "object_start"
                elif not stack:
                    return False  # the top-level value must be an object
                elif ch == "[":
                    stack.append("[")
                    state = "array_start"
                elif ch == '"':
                    state, key_string = "string", False
                elif ch in cls._LITERALS:
                    literal, literal_pos, state = cls._LITERALS[ch], 1, "literal"
                else:
                    first = cls._NUMBER_STEPS["start"].get(cls._number_char_class(ch))
                    if first is None:
                        return False
                    state, number_state = "number", first
            elif state in ("object_start", "key"):
                if ch == '"':
                    state, key_string = "string", True
                elif ch == "}" and state == "object_start":
                    stack.pop()
                    state = "after_value" if stack else "done"
                else:
                    return False
            elif state == "colon":
                if ch != ":":
                    return False
                state = "value"
            elif state == "after_value":
                in_object = stack[-1] == "{"
                if ch == ",":
                    state = "key" if in_object else "value"
                elif ch == ("}" if in_object else "]"):
                    stack.pop()
                    state = "after_value" if stack else "done"
                else:
                    return False
            else:
                # "done": the top-level object is closed; only whitespace may follow.
                return False
            i += 1
        return True

    def is_valid_json_prefix(self, text: str) -> bool:
        """Check if text could be the start of valid JSON.

        Surrounding whitespace is ignored; empty text counts as a valid prefix.
        The JSON must be an object (``{...}``); numbers, strings, escapes and
        the ``true``/``false``/``null`` literals are checked against the JSON
        grammar, and nothing but whitespace may follow the closing brace.
        """
        return self._is_json_object_prefix(text.strip())

    def get_valid_next_tokens(
        self,
        current_text: str,
        schema: Dict,
        logits: Optional[torch.Tensor] = None,
        max_tokens: Optional[int] = 51,
    ) -> List[int]:
        """Get tokens that would keep JSON valid.

        Args:
            current_text: Text generated so far.
            schema: Currently unused; kept for interface compatibility.
            logits: Optional 1-D scores over the vocabulary. When given,
                candidates are examined best-first so the returned set holds
                the model's most likely valid tokens; otherwise token ids are
                scanned in ascending order.
            max_tokens: Stop after this many valid tokens (``None`` scans all).

        Returns:
            Token ids whose decoded text keeps ``current_text`` a valid JSON
            object prefix. Special/pad tokens and tokens that decode to an
            empty string (no progress) are never returned.
        """
        tokenizer = self.tokenizer
        skip_ids = set(getattr(tokenizer, "all_special_ids", None) or [])
        if tokenizer.pad_token_id is not None:
            skip_ids.add(tokenizer.pad_token_id)
        vocab_size = len(tokenizer)
        if logits is None:
            candidate_ids = range(vocab_size)
        else:
            # Logit rows can be padded beyond the tokenizer's vocabulary.
            ranked = torch.argsort(logits, descending=True).tolist()
            candidate_ids = [t for t in ranked if t < vocab_size]

        valid_tokens = []
        for token_id in candidate_ids:
            if token_id in skip_ids:
                continue
            piece = tokenizer.decode([token_id])
            if not piece:
                continue
            if self.is_valid_json_prefix(current_text + piece):
                valid_tokens.append(token_id)
                if max_tokens is not None and len(valid_tokens) >= max_tokens:
                    break
        return valid_tokens

    def generate_constrained(
        self,
        prompt: str,
        schema: Dict,
        max_length: int = 200,
        do_sample: bool = True,
    ) -> str:
        """Generate text with JSON constraints.

        At every step the next token is drawn only from candidates that keep
        the output a valid prefix of a JSON object. Generation stops once the
        object is complete, when no candidate (not even a closing ``}``) is
        valid, or after ``max_length`` new tokens, so an unfinished object can
        still be returned.

        Args:
            prompt: Prompt text fed to the model.
            schema: Passed through to ``get_valid_next_tokens``; the schema is
                not enforced during decoding, so callers should check the
                result with ``validate_against_schema``.
            max_length: Maximum number of generated tokens.
            do_sample: Sample from the masked distribution (default) or pick
                the highest-scoring valid token.

        Returns:
            The generated text with surrounding whitespace stripped.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        input_ids = inputs["input_ids"]
        generated_text = ""
        past = None
        for _ in range(max_length):
            with torch.no_grad():
                if past is None:
                    outputs = self.model(input_ids, use_cache=True)
                else:
                    # Reuse the KV cache: only the newest token needs a forward pass.
                    outputs = self.model(input_ids[:, -1:], past_key_values=past, use_cache=True)
            past = getattr(outputs, "past_key_values", None)
            logits = outputs.logits[0, -1, :]  # scores for the next token

            valid_tokens = self.get_valid_next_tokens(generated_text, schema, logits=logits)
            if valid_tokens:
                mask = torch.full_like(logits, float("-inf"))
                mask[valid_tokens] = 0
                masked_logits = logits + mask
                if do_sample:
                    probs = torch.softmax(masked_logits, dim=-1)
                    next_token_id = int(torch.multinomial(probs, 1).item())
                else:
                    next_token_id = int(torch.argmax(masked_logits).item())
            else:
                # No candidate keeps the output valid: close the current object
                # if that is legal, otherwise give up instead of emitting junk.
                # add_special_tokens=False keeps a BOS token out of position 0.
                closing_ids = self.tokenizer.encode("}", add_special_tokens=False)
                if not closing_ids or not self.is_valid_json_prefix(
                    generated_text + self.tokenizer.decode(closing_ids[:1])
                ):
                    break
                next_token_id = closing_ids[0]

            input_ids = torch.cat(
                [input_ids, torch.tensor([[next_token_id]], device=input_ids.device)],
                dim=1,
            )
            generated_text += self.tokenizer.decode([next_token_id])

            # Once the object parses, only whitespace could follow, so stop.
            try:
                json.loads(generated_text.strip())
            except json.JSONDecodeError:
                continue
            break
        return generated_text.strip()

    def validate_against_schema(self, data: Any, schema: Dict) -> bool:
        """Validate JSON data against schema.

        Returns False on a ``jsonschema.ValidationError``; an invalid schema
        (``jsonschema.SchemaError``) propagates to the caller.
        """
        try:
            jsonschema.validate(data, schema)
            return True
        except jsonschema.ValidationError:
            return False

    def generate_with_beam_search(
        self,
        prompt: str,
        schema: Dict,
        num_beams: int = 3,
        max_new_tokens: int = 150,
    ) -> str:
        """Generate with beam search, then pick a candidate that fits the schema.

        Decoding itself is unconstrained. All ``num_beams`` returned sequences
        are decoded; the first that parses as JSON and validates against
        ``schema`` is returned, else the first candidate (or "" if none).
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        prompt_len = inputs["input_ids"].shape[1]
        with torch.no_grad():
            # No temperature: it is ignored (and warned about) when do_sample=False.
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_beams=num_beams,
                early_stopping=True,
                do_sample=False,
                pad_token_id=self.tokenizer.eos_token_id,
                num_return_sequences=num_beams,
            )

        candidates = [
            self.tokenizer.decode(sequence[prompt_len:], skip_special_tokens=True).strip()
            for sequence in outputs
        ]
        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            if self.validate_against_schema(parsed, schema):
                return candidate
        return candidates[0] if candidates else ""
def create_json_schema_from_function(function_def: Dict) -> Dict:
    """Build the JSON Schema that a call to ``function_def`` must satisfy.

    The schema accepts exactly ``{"name": <function name>, "arguments": {...}}``:
    ``name`` must equal ``function_def["name"]``, ``arguments`` must satisfy
    ``function_def["parameters"]`` (used as-is, not copied), and no other
    top-level keys are permitted. Raises KeyError if either field is missing.
    """
    name_rule = {"type": "string", "const": function_def["name"]}
    argument_rule = function_def["parameters"]
    schema = {"type": "object"}
    schema["properties"] = {"name": name_rule, "arguments": argument_rule}
    schema["required"] = ["name", "arguments"]
    schema["additionalProperties"] = False
    return schema
def test_constrained_generation():
    """Smoke-test both generation strategies on a weather function-call prompt.

    Loads HuggingFaceTB/SmolLM3-3B, asks it for a ``get_weather`` call, and
    prints whether the constrained and the beam-search outputs parse as JSON
    and satisfy the function-call schema.
    """
    print("🧪 Testing Constrained JSON Generation...")
    # Load model
    model_name = "HuggingFaceTB/SmolLM3-3B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        device_map="mps" if torch.backends.mps.is_available() else "auto",
    )
    # Send inputs to wherever the weights were placed; the generator's default
    # device ("mps") is wrong whenever the "auto" device map was used.
    generator = ConstrainedJSONGenerator(model, tokenizer, device=model.device)

    # Test schema
    function_def = {
        "name": "get_weather",
        "description": "Get weather forecast",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string"},
                "days": {"type": "integer"},
            },
            "required": ["location", "days"],
        },
    }
    schema = create_json_schema_from_function(function_def)
    prompt = f"""<|im_start|>system
You are a helpful assistant that calls functions by responding with valid JSON when given a schema. Always respond with JSON function calls only, never prose.<|im_end|>
<schema>
{json.dumps(function_def, indent=2)}
</schema>
<|im_start|>user
Get 3-day weather for New York<|im_end|>
<|im_start|>assistant
"""

    def report(text, success_message, failure_label):
        """Print whether ``text`` parses as JSON and satisfies ``schema``."""
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError as e:
            print(f"❌ {failure_label}: {e}")
            return
        # validate_against_schema signals a mismatch through its return value,
        # not an exception, so the result must be checked explicitly.
        if generator.validate_against_schema(parsed, schema):
            print(f"✅ {success_message}")
        else:
            print(f"❌ {failure_label}: output does not match the function-call schema")

    # Test constrained generation
    print("🎯 Testing constrained generation...")
    result = generator.generate_constrained(prompt, schema)
    print(f"🚀 Constrained result: {result}")
    report(result, "Valid JSON with correct schema!", "Validation failed")

    # Test beam search
    print("🎯 Testing beam search...")
    beam_result = generator.generate_with_beam_search(prompt, schema)
    print(f"🚀 Beam result: {beam_result}")
    report(beam_result, "Beam search produced valid JSON!", "Beam validation failed")
# Run the smoke test only when executed as a script, not on import.
if __name__ == "__main__":
    test_constrained_generation()