# Author: jlov7 — commit 6639f75 "feat: Multi-tool selection and robustness testing" (5.41 kB)
"""
test_model.py - Test our trained dynamic function-calling agent
This script loads the trained LoRA adapter and tests it on various schemas
to demonstrate zero-shot function calling capability.
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import json
def load_trained_model(
    base_model_name="HuggingFaceTB/SmolLM2-1.7B-Instruct",
    adapter_path="./smollm_tool_adapter/checkpoint-6",
):
    """Load the base model and tokenizer, then attach the trained adapter.

    Args:
        base_model_name: Hub id (or local path) passed to
            ``AutoTokenizer.from_pretrained`` and
            ``AutoModelForCausalLM.from_pretrained``.
        adapter_path: Path to the saved adapter, passed to
            ``PeftModel.from_pretrained``.

    Returns:
        ``(model, tokenizer)``: the base model wrapped by ``PeftModel`` and
        the base model's tokenizer.
    """
    print("🔄 Loading trained model...")

    # Load base model and tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    if tokenizer.pad_token is None:
        # Fall back to EOS as the pad token when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token

    use_cuda = torch.cuda.is_available()
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        # float16 + automatic placement on GPU; float32 on the default device otherwise.
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        device_map="auto" if use_cuda else None,
        # NOTE(review): executes model code from the Hub; confirm the base
        # model actually needs this before keeping it enabled.
        trust_remote_code=True,
    )

    # Wrap the base model with the trained adapter weights.
    model = PeftModel.from_pretrained(base_model, adapter_path)
    print("✅ Model loaded successfully!")
    return model, tokenizer
def test_function_call(
    model,
    tokenizer,
    schema,
    question,
    *,
    max_new_tokens=100,
    do_sample=True,
    temperature=0.1,
):
    """Ask the model to answer ``question`` with a call to the given ``schema``.

    The prompt layout below (system turn, ``<schema>`` block, user turn, open
    assistant turn) presumably mirrors the adapter's training format; confirm
    against the training data before changing it.

    Args:
        model: Causal LM with ``generate`` and ``device`` (e.g. the model
            returned by ``load_trained_model``).
        tokenizer: Tokenizer matching ``model``.
        schema: JSON-serialisable function schema embedded in the prompt.
        question: User request the model should turn into a function call.
        max_new_tokens: Upper bound on generated tokens.
        do_sample: Sample when True, greedy-decode when False.
        temperature: Sampling temperature; only passed when ``do_sample`` is True.

    Returns:
        ``(response, is_valid_json, json_response)``: the stripped generated
        text, whether it parses as JSON, and the parsed value (``None`` if
        parsing failed).
    """
    prompt = f"""<|im_start|>system
You are a helpful assistant that calls functions by responding with valid JSON when given a schema. Always respond with JSON function calls only, never prose.<|im_end|>
<schema>
{json.dumps(schema, indent=2)}
</schema>
<|im_start|>user
{question}<|im_end|>
<|im_start|>assistant
"""
    # The tokenizer returns CPU tensors; move them to wherever the model's
    # weights live (e.g. CUDA when loaded with device_map="auto").
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "pad_token_id": tokenizer.eos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if do_sample:
        # Temperature is only meaningful when sampling.
        generation_kwargs["temperature"] = temperature

    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_kwargs)

    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_length = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    ).strip()

    # Validate by parsing; only JSON syntax errors count as "invalid".
    try:
        json_response = json.loads(response)
    except json.JSONDecodeError:
        return response, False, None
    return response, True, json_response


# The ``test_`` prefix would make pytest collect this helper as a test (and
# fail on the missing ``model`` fixture); opt it out explicitly.
test_function_call.__test__ = False
def _build_test_cases():
    """Return the evaluation cases: one schema labelled as trained, two labelled as new."""
    return [
        {
            "name": "Trained Schema: Stock Price",
            "schema": {
                "name": "get_stock_price",
                "description": "Return the latest price for a given ticker symbol.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "ticker": {"type": "string"}
                    },
                    "required": ["ticker"]
                }
            },
            "question": "What's Microsoft trading at?"
        },
        {
            "name": "NEW Schema: Database Query",
            "schema": {
                "name": "query_database",
                "description": "Execute a SQL query on the database.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {"type": "string"},
                        "timeout": {"type": "number"}
                    },
                    "required": ["query"]
                }
            },
            "question": "Find all users who signed up last week"
        },
        {
            "name": "NEW Schema: File Operations",
            "schema": {
                "name": "create_file",
                "description": "Create a new file with content.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "filename": {"type": "string"},
                        "content": {"type": "string"},
                        "overwrite": {"type": "boolean"}
                    },
                    "required": ["filename", "content"]
                }
            },
            "question": "Create a file called report.txt with the content 'Meeting notes'"
        },
    ]


def _print_summary(valid_count, total_count, success_threshold):
    """Print the valid-JSON rate and the PASS / NEEDS IMPROVEMENT verdict."""
    # Guard against an empty case list instead of dividing by zero.
    success_rate = valid_count / total_count if total_count else 0.0
    print("\n📊 Results Summary:")
    print(f"✅ Valid JSON responses: {valid_count}/{total_count} ({success_rate * 100:.1f}%)")
    print(f"🎯 Success criteria: ≥{success_threshold:.0%} valid calls")
    print(f"🏆 Result: {'PASS' if success_rate >= success_threshold else 'NEEDS IMPROVEMENT'}")


def main(success_threshold=0.8):
    """Run every test case through the trained model and report JSON validity.

    Args:
        success_threshold: Minimum fraction of responses that must parse as
            JSON for the run to be reported as PASS (default 0.8, i.e. 80%).
    """
    print("🧪 Testing Dynamic Function-Calling Agent")
    print("=" * 50)

    model, tokenizer = load_trained_model()

    # Mix of the trained schema and schemas the adapter has not seen.
    test_cases = _build_test_cases()

    valid_count = 0
    for i, test_case in enumerate(test_cases, 1):
        print(f"\n📋 Test {i}: {test_case['name']}")
        print(f"❓ Question: {test_case['question']}")
        response, is_valid, json_obj = test_function_call(
            model, tokenizer, test_case["schema"], test_case["question"]
        )
        print(f"🤖 Model response: {response}")
        if is_valid:
            print(f"✅ Valid JSON: {json_obj}")
            valid_count += 1
        else:
            print("❌ Invalid JSON")
        print("-" * 40)

    _print_summary(valid_count, len(test_cases), success_threshold)
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()