Dynamic-Function-Calling-Agent / test_constrained_model.py
jlov7's picture
feat: Multi-tool selection and robustness testing
6639f75
raw
history blame
8.15 kB
"""
test_constrained_model.py - Test Constrained Generation with Trained Model
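This evaluates our intensively trained (LoRA fine-tuned) model with
retry-and-validate JSON generation: each sampled output is parsed and checked
against the target function's JSON schema, retrying at a higher temperature on
failure, to measure how well it avoids malformed-JSON errors such as
"Expecting ',' delimiter".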
"""
import torch
import json
import jsonschema
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from typing import Dict, List
import time
def load_trained_model():
    """Return ``(model, tokenizer)`` for SmolLM3-3B with the local LoRA adapter merged in.

    The tokenizer and base weights come from ``HuggingFaceTB/SmolLM3-3B`` (float32,
    placed on MPS when available, otherwise ``device_map="auto"``). The adapter in
    ``./smollm3_robust`` is applied and then merged into the base weights.
    """
    base_id = "HuggingFaceTB/SmolLM3-3B"
    adapter_dir = "./smollm3_robust"

    print("πŸ”„ Loading intensively trained SmolLM3-3B...")

    tok = AutoTokenizer.from_pretrained(base_id)
    # Fall back to EOS for padding when the tokenizer defines no pad token.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    placement = "mps" if torch.backends.mps.is_available() else "auto"
    base = AutoModelForCausalLM.from_pretrained(
        base_id,
        torch_dtype=torch.float32,
        device_map=placement,
    )

    print("πŸ”§ Loading LoRA adapter...")
    # Folding the adapter into the base weights avoids the adapter indirection at inference.
    merged = PeftModel.from_pretrained(base, adapter_dir).merge_and_unload()

    print("βœ… Trained model loaded successfully")
    return merged, tok
def constrained_json_generate(model, tokenizer, prompt: str, schema: Dict, max_attempts: int = 3):
    """Sample completions for *prompt* until one parses as JSON and matches *schema*.

    Despite the name, decoding is not grammar-constrained. Each attempt samples
    freely (``do_sample=True``, ``top_p=0.9``, at most 200 new tokens). The decoded
    text is then checked with ``json.loads`` and, when *schema* is truthy, with
    ``jsonschema.validate``. The temperature starts at 0.1 and rises by 0.1 per
    attempt so that retries produce different samples.

    Args:
        model: causal LM exposing ``parameters()`` and ``generate(...)``.
        tokenizer: matching tokenizer. Its ``eos_token_id`` is used both as the
            padding id and as the stop token.
        prompt: full prompt text. Only newly generated tokens are decoded.
        schema: JSON Schema for the parsed output. A falsy value skips schema
            validation.
        max_attempts: number of sampling attempts. Values below 1 make no attempt.

    Returns:
        ``(response, ok, error)``. ``response`` is the decoded text of the successful
        attempt, or of the last attempt on failure. ``ok`` says whether it passed.
        ``error`` is ``None`` on success, otherwise the last JSON-parse or
        schema-validation message. When no attempt is made the result is
        ``("", False, "Max attempts exceeded")``.
    """
    device = next(model.parameters()).device
    # The prompt is the same on every attempt, so tokenize it only once.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_len = inputs["input_ids"].shape[1]

    # Initialised so that max_attempts < 1 returns cleanly instead of hitting an unbound name.
    response = ""
    last_error = "Max attempts exceeded"
    for attempt in range(max_attempts):
        # A higher temperature on each retry gives more diverse samples.
        temperature = 0.1 + (attempt * 0.1)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=200,
                temperature=temperature,
                do_sample=True,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        # Decode only the continuation, not the echoed prompt tokens.
        response = tokenizer.decode(
            outputs[0][prompt_len:],
            skip_special_tokens=True,
        ).strip()

        try:
            parsed = json.loads(response)
            if schema:
                jsonschema.validate(parsed, schema)
        except json.JSONDecodeError as e:
            last_error = str(e)
        except jsonschema.ValidationError as e:
            last_error = f"Schema validation: {str(e)}"
        else:
            return response, True, None

    return response, False, last_error
def create_test_schemas():
    """Return the evaluated function definitions, keyed by test-case name.

    Each value holds ``name``, ``description`` and a ``parameters`` JSON-Schema
    object that lists the argument types and the required argument names.
    """
    def _function(name, description, properties, required):
        # Shared wrapper so that every definition has the same outer shape and key order.
        return {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        }

    weather = _function(
        "get_weather_forecast",
        "Get weather forecast",
        {
            "location": {"type": "string"},
            "days": {"type": "integer"},
            "units": {"type": "string"},
            "include_hourly": {"type": "boolean"},
        },
        ["location", "days"],
    )
    sentiment = _function(
        "analyze_sentiment",
        "Analyze text sentiment",
        {
            "text": {"type": "string"},
            "language": {"type": "string"},
            "include_emotions": {"type": "boolean"},
            "confidence_threshold": {"type": "number"},
        },
        ["text"],
    )
    currency = _function(
        "convert_currency",
        "Convert currency amounts",
        {
            "amount": {"type": "number"},
            "from_currency": {"type": "string"},
            "to_currency": {"type": "string"},
            "include_fees": {"type": "boolean"},
            "precision": {"type": "integer"},
        },
        ["amount", "from_currency", "to_currency"],
    )

    return {
        "weather_forecast": weather,
        "sentiment_analysis": sentiment,
        "currency_converter": currency,
    }
def create_json_schema(function_def: Dict) -> Dict:
    """Build a JSON Schema for a ``{"name": ..., "arguments": ...}`` call to *function_def*.

    The schema enforces three rules:

    * ``name`` must be exactly ``function_def["name"]`` (via ``const``).
    * ``arguments`` is validated by ``function_def["parameters"]``. That object is
      referenced, not copied.
    * Both keys are required, and no other top-level keys are allowed.
    """
    name_rule = {"type": "string", "const": function_def["name"]}
    call_properties = {"name": name_rule, "arguments": function_def["parameters"]}
    return {
        "type": "object",
        "properties": call_properties,
        "required": ["name", "arguments"],
        "additionalProperties": False,
    }
def _build_prompt(function_def: Dict, query: str) -> str:
    """Return the ChatML-style prompt asking the model to call *function_def* for *query*.

    The pretty-printed function definition sits between ``<schema>`` tags, after
    the system turn and before the user turn. It is outside any
    ``<|im_start|>``/``<|im_end|>`` turn.
    """
    # NOTE(review): the <schema> block is not inside a chat turn. It is kept as-is in
    # case the fine-tuning data used this exact layout — confirm before changing.
    return (
        "<|im_start|>system\n"
        "You are a helpful assistant that calls functions by responding with valid JSON "
        "when given a schema. Always respond with JSON function calls only, never prose.<|im_end|>\n"
        "<schema>\n"
        f"{json.dumps(function_def, indent=2)}\n"
        "</schema>\n"
        "<|im_start|>user\n"
        f"{query}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )


def _print_summary(passed: int, total: int, success_rate: float) -> None:
    """Print the pass count and the success rate against the fixed 80% target."""
    print("\nπŸ† CONSTRAINED GENERATION RESULTS")
    print("=" * 60)
    print(f"βœ… Passed: {passed}/{total} ({success_rate:.1f}%)")
    print("🎯 Target: β‰₯80%")
    if success_rate >= 80:
        print("πŸŽ‰ SUCCESS! Reached 80%+ target with constrained generation!")
    else:
        print(f"πŸ“ˆ Improvement needed: +{80 - success_rate:.1f}% to reach target")


def _save_results(results: Dict, success_rate: float, path: str = "constrained_results.json") -> None:
    """Write the run summary, per-case details and a Unix timestamp to *path* as indented JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump({
            "success_rate": success_rate,
            "passed": results["passed"],
            "total": results["total"],
            "details": results["details"],
            "timestamp": time.time(),
        }, f, indent=2)
    print(f"πŸ’Ύ Results saved to {path}")


def test_constrained_generation():
    """Run every test case through the trained model and report the JSON success rate.

    For each ``(schema_name, query)`` pair, this builds a prompt that embeds the
    function definition and calls ``constrained_json_generate`` with the matching
    call schema. A case counts as passed when the output is valid JSON that matches
    the schema. The summary is printed and written to ``constrained_results.json``.

    Returns:
        The success rate as a percentage (0-100).

    NOTE(review): because of the ``test_`` prefix, pytest will collect and run this,
    loading the full model. Newer pytest versions also likely warn about the
    non-None return value.
    """
    print("πŸ§ͺ Testing Constrained Generation with Trained Model")
    print("=" * 60)

    model, tokenizer = load_trained_model()
    schemas = create_test_schemas()

    test_cases = [
        ("weather_forecast", "Get 3-day weather for San Francisco in metric units"),
        ("sentiment_analysis", "Analyze sentiment: The product was excellent and delivery was fast"),
        ("currency_converter", "Convert 500 USD to EUR with fees included"),
        ("weather_forecast", "Give me tomorrow's weather for London with hourly details"),
        ("sentiment_analysis", "Check sentiment for I am frustrated with this service"),
        ("currency_converter", "Convert 250 EUR to CAD using rates from 2023-12-01"),
    ]
    results = {"passed": 0, "total": len(test_cases), "details": []}

    for schema_name, query in test_cases:
        print(f"\n🎯 Testing: {schema_name}")
        print(f"πŸ“ Query: {query}")

        function_def = schemas[schema_name]
        schema = create_json_schema(function_def)
        prompt = _build_prompt(function_def, query)

        response, success, error = constrained_json_generate(model, tokenizer, prompt, schema)
        print(f"πŸ€– Response: {response}")
        if success:
            print("βœ… PASS - Valid JSON with correct schema!")
            results["passed"] += 1
        else:
            print(f"❌ FAIL - {error}")

        results["details"].append({
            "schema": schema_name,
            "query": query,
            "response": response,
            "success": success,
            "error": error,
        })

    success_rate = (results["passed"] / results["total"]) * 100
    _print_summary(results["passed"], results["total"], success_rate)
    _save_results(results, success_rate)
    return success_rate
if __name__ == "__main__":
    # Direct execution runs the evaluation once. The returned percentage is bound to
    # a module-level name and is not used further by this script.
    success_rate: float = test_constrained_generation()