""" test_constrained_model.py - Test Constrained Generation with Trained Model This tests our intensively trained model using constrained JSON generation to force valid outputs and solve the "Expecting ',' delimiter" issues. """ import torch import json import jsonschema from transformers import AutoTokenizer, AutoModelForCausalLM # from peft import PeftModel # Not needed for base model demo from typing import Dict, List import time def load_trained_model(): """Load our model - tries fine-tuned first, falls back to base model.""" print("๐Ÿ”„ Loading SmolLM3-3B Function-Calling Agent...") # Load base model base_model_name = "HuggingFaceTB/SmolLM3-3B" try: print("๐Ÿ”„ Loading tokenizer...") tokenizer = AutoTokenizer.from_pretrained(base_model_name) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token print("๐Ÿ”„ Loading base model...") # Use smaller data type for Hugging Face Spaces model = AutoModelForCausalLM.from_pretrained( base_model_name, torch_dtype=torch.float16, # Use float16 for better memory usage device_map="auto", low_cpu_mem_usage=True # Reduce memory usage during loading ) # Try to load fine-tuned adapter - local first, then Hub try: print("๐Ÿ”„ Attempting to load fine-tuned adapter locally...") from peft import PeftModel model = PeftModel.from_pretrained(model, "./smollm3_robust") model = model.merge_and_unload() print("โœ… Fine-tuned model loaded successfully from local files!") except Exception as e: try: print(f"โš ๏ธ Local adapter failed: {e}") print("๐Ÿ”„ Trying Hugging Face Hub...") model = PeftModel.from_pretrained(model, "jlov7/SmolLM3-Function-Calling-LoRA") model = model.merge_and_unload() print("โœ… Fine-tuned model loaded successfully from Hub!") except Exception as e2: print(f"โš ๏ธ Could not load fine-tuned adapter: {e2}") print("๐Ÿ”ง Using base model with optimized prompting") print("โœ… Model loaded successfully") return model, tokenizer except Exception as e: print(f"โŒ Error loading model: {e}") raise def constrained_json_generate(model, tokenizer, prompt: str, schema: Dict, max_attempts: int = 3): """Generate JSON with multiple attempts and validation.""" device = next(model.parameters()).device for attempt in range(max_attempts): try: # Generate with different temperatures for diversity temperature = 0.1 + (attempt * 0.1) inputs = tokenizer(prompt, return_tensors="pt").to(device) # Simple timeout protection using threading (cross-platform) import threading result = [None] error = [None] def generate_with_timeout(): try: with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=100, # Reduced for faster generation temperature=temperature, do_sample=True, pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id, num_return_sequences=1, use_cache=True ) # Extract generated text generated_ids = outputs[0][inputs['input_ids'].shape[1]:] response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip() # Try to extract JSON from response if "{" in response and "}" in response: # Find the first complete JSON object start = response.find("{") bracket_count = 0 end = start for i, char in enumerate(response[start:], start): if char == "{": bracket_count += 1 elif char == "}": bracket_count -= 1 if bracket_count == 0: end = i + 1 break json_str = response[start:end] result[0] = json_str else: result[0] = response except Exception as e: error[0] = str(e) # Start generation in a separate thread with timeout thread = threading.Thread(target=generate_with_timeout) thread.daemon = True thread.start() thread.join(timeout=20) # 20-second timeout if thread.is_alive(): return "", False, f"Generation timed out (attempt {attempt + 1})" if error[0]: if attempt == max_attempts - 1: return "", False, f"Generation error: {error[0]}" continue if result[0]: # Validate JSON and schema try: parsed = json.loads(result[0]) jsonschema.validate(parsed, schema) return result[0], True, None except (json.JSONDecodeError, jsonschema.ValidationError) as e: if attempt == max_attempts - 1: return result[0], False, f"JSON validation failed: {str(e)}" continue except Exception as e: if attempt == max_attempts - 1: return "", False, f"Generation error: {str(e)}" continue return "", False, "All generation attempts failed" def create_test_schemas(): """Create the test schemas we're evaluating against.""" return { "weather_forecast": { "name": "get_weather_forecast", "description": "Get weather forecast", "parameters": { "type": "object", "properties": { "location": {"type": "string"}, "days": {"type": "integer"}, "units": {"type": "string"}, "include_hourly": {"type": "boolean"} }, "required": ["location", "days"] } }, "sentiment_analysis": { "name": "analyze_sentiment", "description": "Analyze text sentiment", "parameters": { "type": "object", "properties": { "text": {"type": "string"}, "language": {"type": "string"}, "include_emotions": {"type": "boolean"}, "confidence_threshold": {"type": "number"} }, "required": ["text"] } }, "currency_converter": { "name": "convert_currency", "description": "Convert currency amounts", "parameters": { "type": "object", "properties": { "amount": {"type": "number"}, "from_currency": {"type": "string"}, "to_currency": {"type": "string"}, "include_fees": {"type": "boolean"}, "precision": {"type": "integer"} }, "required": ["amount", "from_currency", "to_currency"] } } } def create_json_schema(function_def: Dict) -> Dict: """Create JSON schema for validation.""" return { "type": "object", "properties": { "name": { "type": "string", "const": function_def["name"] }, "arguments": function_def["parameters"] }, "required": ["name", "arguments"], "additionalProperties": False } def test_constrained_generation(): """Test constrained generation on our problem schemas.""" print("๐Ÿงช Testing Constrained Generation with Trained Model") print("=" * 60) # Load trained model model, tokenizer = load_trained_model() # Get test schemas schemas = create_test_schemas() test_cases = [ ("weather_forecast", "Get 3-day weather for San Francisco in metric units"), ("sentiment_analysis", "Analyze sentiment: The product was excellent and delivery was fast"), ("currency_converter", "Convert 500 USD to EUR with fees included"), ("weather_forecast", "Give me tomorrow's weather for London with hourly details"), ("sentiment_analysis", "Check sentiment for I am frustrated with this service"), ("currency_converter", "Convert 250 EUR to CAD using rates from 2023-12-01") ] results = {"passed": 0, "total": len(test_cases), "details": []} for schema_name, query in test_cases: print(f"\n๐ŸŽฏ Testing: {schema_name}") print(f"๐Ÿ“ Query: {query}") # Create prompt function_def = schemas[schema_name] schema = create_json_schema(function_def) prompt = f"""<|im_start|>system You are a helpful assistant that calls functions by responding with valid JSON when given a schema. Always respond with JSON function calls only, never prose.<|im_end|> {json.dumps(function_def, indent=2)} <|im_start|>user {query}<|im_end|> <|im_start|>assistant """ # Test constrained generation response, success, error = constrained_json_generate(model, tokenizer, prompt, schema) print(f"๐Ÿค– Response: {response}") if success: print("โœ… PASS - Valid JSON with correct schema!") results["passed"] += 1 else: print(f"โŒ FAIL - {error}") results["details"].append({ "schema": schema_name, "query": query, "response": response, "success": success, "error": error }) # Calculate success rate success_rate = (results["passed"] / results["total"]) * 100 print(f"\n๐Ÿ† CONSTRAINED GENERATION RESULTS") print("=" * 60) print(f"โœ… Passed: {results['passed']}/{results['total']} ({success_rate:.1f}%)") print(f"๐ŸŽฏ Target: โ‰ฅ80%") if success_rate >= 80: print("๐ŸŽ‰ SUCCESS! Reached 80%+ target with constrained generation!") else: print(f"๐Ÿ“ˆ Improvement needed: +{80 - success_rate:.1f}% to reach target") # Save results with open("constrained_results.json", "w") as f: json.dump({ "success_rate": success_rate, "passed": results["passed"], "total": results["total"], "details": results["details"], "timestamp": time.time() }, f, indent=2) print(f"๐Ÿ’พ Results saved to constrained_results.json") return success_rate if __name__ == "__main__": success_rate = test_constrained_generation()