"""
constrained_generator.py - JSON Constrained Generation

Constrained decoding helpers that steer a causal language model towards
syntactically valid JSON objects:

1. Token-by-token validation of the partial output as a JSON prefix
2. Logit masking so that only syntax-preserving tokens can be sampled
3. Beam search with post-hoc filtering of the candidates against a schema
4. Schema construction for function-call style outputs

Only JSON *syntax* is enforced while decoding. The JSON schema is applied
afterwards, to finished outputs.
"""

import json
import re
from typing import Any, Dict, Iterable, List, Optional

import jsonschema
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- JSON prefix scanning -------------------------------------------------

# JSON (RFC 8259) only allows these four characters as insignificant
# whitespace between tokens.
_JSON_WHITESPACE = " \t\n\r"
_HEX_DIGITS = "0123456789abcdefABCDEF"
_SIMPLE_ESCAPES = '"\\/bfnrt'
_NUMBER_START = "-0123456789"
_NUMBER_CHARS = "+-0123456789.eE"
_LITERALS = ("true", "false", "null")

# A complete JSON number.
_NUMBER_RE = re.compile(r"-?(?:0|[1-9][0-9]*)(?:\.[0-9]+)?(?:[eE][+-]?[0-9]+)?")
# Anything that can still grow into a complete JSON number ("-", "1.", "2e+").
_NUMBER_PREFIX_RE = re.compile(
    r"-?(?:(?:0|[1-9][0-9]*)"
    r"(?:\.(?:[0-9]+(?:[eE][+-]?[0-9]*)?)?|[eE][+-]?[0-9]*)?)?"
)

# Sentinels returned by the scalar scanners instead of a next index.
_TEXT_ENDED = -1  # ran out of text; what was seen is valid so far
_INVALID = -2  # the text cannot be extended into valid JSON

# Scanner states: what the grammar expects next.
_TOP = "top"  # the opening "{" of the document
_KEY_OR_CLOSE = "key_or_close"  # right after "{"
_KEY = "key"  # after "," inside an object
_COLON = "colon"  # after an object key
_VALUE = "value"  # after ":" or after "," inside an array
_VALUE_OR_CLOSE = "value_or_close"  # right after "["
_AFTER_VALUE = "after_value"  # "," or the enclosing closer
_DONE = "done"  # top-level object closed


def _scan_string(text: str, start: int) -> int:
    """Scan the string literal whose opening quote is at ``start``.

    Returns the index just past the closing quote, ``_TEXT_ENDED`` if the
    text stops inside the (so far valid) string, or ``_INVALID``.
    """
    n = len(text)
    i = start + 1
    while i < n:
        ch = text[i]
        if ch == '"':
            return i + 1
        if ch == "\\":
            if i + 1 >= n:
                return _TEXT_ENDED
            escape = text[i + 1]
            if escape == "u":
                digits = text[i + 2:i + 6]
                if any(d not in _HEX_DIGITS for d in digits):
                    return _INVALID
                if len(digits) < 4:
                    return _TEXT_ENDED
                i += 6
            elif escape in _SIMPLE_ESCAPES:
                i += 2
            else:
                return _INVALID
        elif ch < " ":
            # Raw control characters are not allowed inside JSON strings.
            return _INVALID
        else:
            i += 1
    return _TEXT_ENDED


def _scan_number(text: str, start: int) -> int:
    """Scan the number starting at ``start`` (same return contract as strings)."""
    n = len(text)
    end = start
    while end < n and text[end] in _NUMBER_CHARS:
        end += 1
    token = text[start:end]
    if end == n:
        # Last token of the text: it may still be growing.
        return _TEXT_ENDED if _NUMBER_PREFIX_RE.fullmatch(token) else _INVALID
    return end if _NUMBER_RE.fullmatch(token) else _INVALID


def _scan_literal(text: str, start: int) -> int:
    """Scan ``true`` / ``false`` / ``null`` starting at ``start``."""
    n = len(text)
    end = start
    while end < n and text[end].isalpha():
        end += 1
    word = text[start:end]
    if end == n:
        if any(literal.startswith(word) for literal in _LITERALS):
            return _TEXT_ENDED
        return _INVALID
    return end if word in _LITERALS else _INVALID


def _scan_scalar(text: str, start: int) -> int:
    """Dispatch to the scanner matching the first character of a scalar."""
    ch = text[start]
    if ch == '"':
        return _scan_string(text, start)
    if ch in _NUMBER_START:
        return _scan_number(text, start)
    if ch in "tfn":
        return _scan_literal(text, start)
    return _INVALID


def _is_json_object_prefix(text: str) -> bool:
    """Return True if ``text`` can be extended into a valid JSON object.

    A single left-to-right pass tracks the open containers on a stack and
    the kind of token the grammar expects next. Reaching the end of the
    text without a violation means the text is a valid prefix.
    """
    stack: List[str] = []
    state = _TOP
    i, n = 0, len(text)
    while i < n:
        ch = text[i]
        if ch in _JSON_WHITESPACE:
            i += 1
            continue

        if state == _TOP:
            if ch != "{":
                return False
            stack.append(ch)
            state = _KEY_OR_CLOSE
            i += 1
        elif state in (_KEY_OR_CLOSE, _KEY):
            if ch == '"':
                i = _scan_string(text, i)
                if i < 0:
                    return i == _TEXT_ENDED
                state = _COLON
            elif ch == "}" and state == _KEY_OR_CLOSE:
                stack.pop()
                state = _AFTER_VALUE if stack else _DONE
                i += 1
            else:
                return False
        elif state == _COLON:
            if ch != ":":
                return False
            state = _VALUE
            i += 1
        elif state in (_VALUE, _VALUE_OR_CLOSE):
            if ch == "{":
                stack.append(ch)
                state = _KEY_OR_CLOSE
                i += 1
            elif ch == "[":
                stack.append(ch)
                state = _VALUE_OR_CLOSE
                i += 1
            elif ch == "]" and state == _VALUE_OR_CLOSE:
                stack.pop()
                state = _AFTER_VALUE if stack else _DONE
                i += 1
            else:
                i = _scan_scalar(text, i)
                if i < 0:
                    return i == _TEXT_ENDED
                state = _AFTER_VALUE
        elif state == _AFTER_VALUE:
            container = stack[-1]
            if ch == ",":
                state = _KEY if container == "{" else _VALUE
            elif (container, ch) in (("{", "}"), ("[", "]")):
                stack.pop()
                state = _AFTER_VALUE if stack else _DONE
            else:
                return False
            i += 1
        else:
            # _DONE: only whitespace may follow the closed top-level object.
            return False
    return True


# --- Generator --------------------------------------------------------------


class ConstrainedJSONGenerator:
    """Generate JSON objects from a causal LM by masking invalid tokens.

    Args:
        model: A causal language model; called as ``model(input_ids)`` and
            expected to expose ``.logits`` and ``.generate``.
        tokenizer: The tokenizer matching ``model``.
        device: Device the encoded prompt is moved to. It should be the
            device the model lives on.
    """

    def __init__(self, model, tokenizer, device="mps"):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        # Decoded text per token id, filled lazily so the vocabulary is
        # never decoded more than once.
        self._token_text_cache: Dict[int, str] = {}
        self.model.eval()

    def _token_text(self, token_id: int) -> str:
        """Return the decoded text of a single token, using the cache."""
        text = self._token_text_cache.get(token_id)
        if text is None:
            text = self.tokenizer.decode([token_id])
            self._token_text_cache[token_id] = text
        return text

    def is_valid_json_prefix(self, text: str) -> bool:
        """Check whether ``text`` could be the start of a valid JSON object.

        Empty or whitespace-only text is a valid prefix. Anything else must
        start with ``{`` and follow the JSON grammar up to the point where
        the text stops; the last token may be incomplete (an unterminated
        string, ``tru``, ``1.`` ...).
        """
        return _is_json_object_prefix(text)

    def get_valid_next_tokens(
        self,
        current_text: str,
        schema: Dict,
        candidate_ids: Optional[Iterable[int]] = None,
    ) -> List[int]:
        """Return the token ids that keep ``current_text`` a valid JSON prefix.

        Args:
            current_text: The JSON text generated so far.
            schema: Accepted for interface compatibility; only JSON syntax
                is checked here, not the schema.
            candidate_ids: Token ids to test, in the order they should be
                reported. Defaults to the whole vocabulary.

        Returns:
            The accepted token ids, in candidate order. Special tokens and
            tokens that decode to an empty string are never accepted.
        """
        vocab_size = len(self.tokenizer)
        if candidate_ids is None:
            candidate_ids = range(vocab_size)

        # Special tokens decode to text like "<|im_end|>", which would be
        # accepted inside a JSON string, so exclude them explicitly.
        excluded = set(self.tokenizer.all_special_ids)
        excluded.add(self.tokenizer.pad_token_id)

        valid_tokens = []
        for token_id in candidate_ids:
            if token_id in excluded or not 0 <= token_id < vocab_size:
                continue
            token_text = self._token_text(token_id)
            # An empty token would make no progress and always "validate".
            if token_text and self.is_valid_json_prefix(current_text + token_text):
                valid_tokens.append(token_id)
        return valid_tokens

    def generate_constrained(
        self,
        prompt: str,
        schema: Dict,
        max_length: int = 200,
        candidate_pool: Optional[int] = 256,
    ) -> str:
        """Sample a JSON object token by token, masking invalid tokens.

        Args:
            prompt: Text the model is conditioned on.
            schema: Passed through to ``get_valid_next_tokens``. It is not
                enforced during decoding; validate the result separately
                with ``validate_against_schema``.
            max_length: Maximum number of tokens to generate.
            candidate_pool: How many of the model's highest-scoring tokens
                are tested for validity at each step. If none of them is
                valid, the whole vocabulary is scanned. ``None`` always
                scans the whole vocabulary.

        Returns:
            The generated text, stripped. It is a complete JSON object
            unless ``max_length`` was reached or no valid token existed.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        input_ids = inputs["input_ids"]
        generated_text = ""

        for _ in range(max_length):
            with torch.no_grad():
                logits = self.model(input_ids).logits[0, -1, :]

            # Test the model's preferred tokens first instead of walking
            # the vocabulary in id order.
            valid_tokens: List[int] = []
            if candidate_pool is not None and candidate_pool > 0:
                pool_size = min(candidate_pool, logits.shape[-1])
                ranked_ids = torch.topk(logits, pool_size).indices.tolist()
                valid_tokens = self.get_valid_next_tokens(
                    generated_text, schema, candidate_ids=ranked_ids
                )
            if not valid_tokens:
                valid_tokens = self.get_valid_next_tokens(generated_text, schema)
            if not valid_tokens:
                # Nothing in the vocabulary extends the text validly.
                break

            # Send every invalid token to -inf so it gets zero probability.
            mask = torch.full_like(logits, float("-inf"))
            mask[valid_tokens] = 0
            probs = torch.softmax(logits + mask, dim=-1)
            next_token_id = int(torch.multinomial(probs, 1).item())

            input_ids = torch.cat(
                [
                    input_ids,
                    torch.tensor([[next_token_id]], device=input_ids.device),
                ],
                dim=1,
            )
            generated_text += self._token_text(next_token_id)

            try:
                json.loads(generated_text)
            except json.JSONDecodeError:
                continue
            # The top-level object is closed; only whitespace could follow.
            break

        return generated_text.strip()

    def validate_against_schema(self, data: Any, schema: Dict) -> bool:
        """Return True if ``data`` validates against ``schema``.

        A ``jsonschema.ValidationError`` is reported as False; any other
        exception raised by ``jsonschema.validate`` propagates.
        """
        try:
            jsonschema.validate(data, schema)
        except jsonschema.ValidationError:
            return False
        return True

    def generate_with_beam_search(
        self,
        prompt: str,
        schema: Dict,
        num_beams: int = 3,
        max_new_tokens: int = 150,
    ) -> str:
        """Generate with beam search and keep the first schema-valid candidate.

        Decoding itself is unconstrained: all ``num_beams`` sequences are
        generated, then filtered by JSON parsing and schema validation.

        Args:
            prompt: Text the model is conditioned on.
            schema: JSON schema the returned candidate should satisfy.
            num_beams: Beam width, also the number of returned candidates.
            max_new_tokens: Maximum number of tokens generated per beam.

        Returns:
            The first candidate that parses and validates. If none does,
            the first candidate; an empty string if there are none.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        prompt_length = inputs["input_ids"].shape[1]

        # No ``temperature`` here: it only applies when sampling, and this
        # call uses do_sample=False.
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_beams=num_beams,
                early_stopping=True,
                do_sample=False,
                pad_token_id=self.tokenizer.eos_token_id,
                num_return_sequences=num_beams,
            )

        # Drop the prompt tokens so only the continuation is decoded.
        candidates = [
            self.tokenizer.decode(
                output[prompt_length:], skip_special_tokens=True
            ).strip()
            for output in outputs
        ]

        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            if self.validate_against_schema(parsed, schema):
                return candidate

        # No valid JSON found: fall back to the first candidate.
        return candidates[0] if candidates else ""


def create_json_schema_from_function(function_def: Dict) -> Dict:
    """Create a JSON schema for validating calls to one function.

    Args:
        function_def: Function description with a ``"name"`` and,
            optionally, a ``"parameters"`` JSON schema. A missing
            ``"parameters"`` entry is treated as "any object".

    Returns:
        A schema for ``{"name": <that name>, "arguments": {...}}`` that
        requires both keys and allows no others.
    """
    return {
        "type": "object",
        "properties": {
            "name": {
                "type": "string",
                "const": function_def["name"],
            },
            "arguments": function_def.get("parameters", {"type": "object"}),
        },
        "required": ["name", "arguments"],
        "additionalProperties": False,
    }


def test_constrained_generation():
    """Run both generators on a sample weather function call.

    Downloads the model, so this is a manual smoke test rather than a
    unit test.
    """
    print("🧪 Testing Constrained JSON Generation...")

    # Load model
    model_name = "HuggingFaceTB/SmolLM3-3B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        device_map="mps" if torch.backends.mps.is_available() else "auto",
    )

    # Use the device the model was loaded on, not the "mps" default.
    generator = ConstrainedJSONGenerator(model, tokenizer, device=model.device)

    # Test schema
    function_def = {
        "name": "get_weather",
        "description": "Get weather forecast",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string"},
                "days": {"type": "integer"},
            },
            "required": ["location", "days"],
        },
    }

    schema = create_json_schema_from_function(function_def)

    prompt = (
        "<|im_start|>system\n"
        "You are a helpful assistant that calls functions by responding "
        "with valid JSON when given a schema.\n"
        "Always respond with JSON function calls only, never prose.<|im_end|>\n"
        f"{json.dumps(function_def, indent=2)}\n"
        "<|im_start|>user\n"
        "Get 3-day weather for New York<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    # Test constrained generation
    print("🎯 Testing constrained generation...")
    result = generator.generate_constrained(prompt, schema)
    print(f"🤖 Constrained result: {result}")

    # Validate with jsonschema directly so a schema mismatch is reported
    # as a failure, with its message.
    try:
        parsed = json.loads(result)
        jsonschema.validate(parsed, schema)
        print("✅ Valid JSON with correct schema!")
    except (json.JSONDecodeError, jsonschema.ValidationError) as e:
        print(f"❌ Validation failed: {e}")

    # Test beam search
    print("🎯 Testing beam search...")
    beam_result = generator.generate_with_beam_search(prompt, schema)
    print(f"🤖 Beam result: {beam_result}")

    try:
        parsed = json.loads(beam_result)
        jsonschema.validate(parsed, schema)
        print("✅ Beam search produced valid JSON!")
    except (json.JSONDecodeError, jsonschema.ValidationError) as e:
        print(f"❌ Beam validation failed: {e}")


if __name__ == "__main__":
    test_constrained_generation()