|
""" |
|
test_model.py - Test our trained dynamic function-calling agent |
|
|
|
This script loads the trained LoRA adapter and tests it on various schemas |
|
to demonstrate zero-shot function calling capability. |
|
""" |
|
|
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
from peft import PeftModel |
|
import json |
|
|
|
def load_trained_model(
    base_model_name="HuggingFaceTB/SmolLM2-1.7B-Instruct",
    adapter_path="./smollm_tool_adapter/checkpoint-6",
):
    """Load the base model and apply the trained LoRA adapter.

    Args:
        base_model_name: Hugging Face Hub id (or local path) of the base
            causal-LM the adapter was trained on.
        adapter_path: Directory containing the saved PEFT adapter checkpoint.

    Returns:
        A ``(model, tokenizer)`` tuple, where ``model`` is the base model
        wrapped by ``PeftModel`` and switched to eval mode.
    """
    print("Loading trained model...")

    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    if tokenizer.pad_token is None:
        # The tokenizer may ship without a pad token; reuse EOS so that any
        # padding-aware call does not fail.
        tokenizer.pad_token = tokenizer.eos_token

    use_cuda = torch.cuda.is_available()
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        # Half precision on GPU to save memory; fp16 is poorly supported on
        # CPU, so fall back to fp32 there.
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        # Let accelerate place weights on the GPU(s); plain CPU load otherwise.
        device_map="auto" if use_cuda else None,
        # NOTE(review): SmolLM2 uses a natively supported architecture, so
        # executing Hub-provided code is probably unnecessary — consider
        # dropping this flag.
        trust_remote_code=True,
    )

    model = PeftModel.from_pretrained(base_model, adapter_path)
    # Make inference mode explicit (disables dropout in the adapter layers).
    model.eval()

    print("Model loaded successfully!")
    return model, tokenizer
|
|
|
def test_function_call(model, tokenizer, schema, question, max_new_tokens=100):
    """Prompt the model with a function schema and a question; parse its reply.

    Args:
        model: A causal LM exposing ``generate`` and ``device`` (e.g. the
            ``PeftModel`` returned by ``load_trained_model``).
        tokenizer: The tokenizer matching ``model``.
        schema: JSON-serialisable function schema shown to the model.
        question: The user request the model should turn into a call.
        max_new_tokens: Upper bound on generated tokens (default 100).

    Returns:
        ``(response, is_valid_json, json_response)``: the stripped generated
        text, whether it parses as JSON, and the parsed value (``None`` when
        parsing fails).
    """
    system_prompt = (
        "You are a helpful assistant that calls functions by responding with "
        "valid JSON when given a schema. Always respond with JSON function "
        "calls only, never prose."
    )
    # NOTE(review): the <schema> section sits between the system and user
    # turns, outside any chat markers. This presumably mirrors the training
    # data format — keep the two in sync.
    prompt = (
        "<|im_start|>system\n"
        f"{system_prompt}<|im_end|>\n"
        "\n"
        "<schema>\n"
        f"{json.dumps(schema, indent=2)}\n"
        "</schema>\n"
        "\n"
        "<|im_start|>user\n"
        f"{question}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    # Inputs must live on the same device as the model's weights, otherwise
    # generation fails when the model has been placed on a GPU.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        # Greedy decoding keeps the evaluation reproducible across runs.
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

    try:
        json_response = json.loads(response)
    except json.JSONDecodeError:
        return response, False, None
    return response, True, json_response


# The module docstring names this file test_model.py, so pytest would collect
# this helper as a test and fail on its missing "model" fixture; opt out.
test_function_call.__test__ = False
|
|
|
def _default_test_cases():
    """Return the built-in evaluation cases.

    Each case is a dict with a display ``name``, a function ``schema`` and a
    natural-language ``question``. The first case is labelled as a schema seen
    during training; the others are labelled NEW, to probe zero-shot use.
    """
    return [
        {
            "name": "Trained Schema: Stock Price",
            "schema": {
                "name": "get_stock_price",
                "description": "Return the latest price for a given ticker symbol.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "ticker": {"type": "string"},
                    },
                    "required": ["ticker"],
                },
            },
            "question": "What's Microsoft trading at?",
        },
        {
            "name": "NEW Schema: Database Query",
            "schema": {
                "name": "query_database",
                "description": "Execute a SQL query on the database.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {"type": "string"},
                        "timeout": {"type": "number"},
                    },
                    "required": ["query"],
                },
            },
            "question": "Find all users who signed up last week",
        },
        {
            "name": "NEW Schema: File Operations",
            "schema": {
                "name": "create_file",
                "description": "Create a new file with content.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "filename": {"type": "string"},
                        "content": {"type": "string"},
                        "overwrite": {"type": "boolean"},
                    },
                    "required": ["filename", "content"],
                },
            },
            "question": "Create a file called report.txt with the content 'Meeting notes'",
        },
    ]


def main():
    """Run every test case through the trained model and print a summary.

    A case counts as a success when the model's reply parses as JSON; the
    run passes when at least 80% of the cases succeed.
    """
    pass_threshold = 0.8

    print("Testing Dynamic Function-Calling Agent")
    print("=" * 50)

    model, tokenizer = load_trained_model()
    test_cases = _default_test_cases()

    valid_count = 0
    for i, case in enumerate(test_cases, 1):
        print(f"\nTest {i}: {case['name']}")
        print(f"Question: {case['question']}")

        response, is_valid, json_obj = test_function_call(
            model, tokenizer, case["schema"], case["question"]
        )
        print(f"Model response: {response}")

        if is_valid:
            print(f"Valid JSON: {json_obj}")
            valid_count += 1
        else:
            print("Invalid JSON")

        print("-" * 40)

    total_count = len(test_cases)
    # Avoid ZeroDivisionError should the case list ever be empty.
    pass_rate = valid_count / total_count if total_count else 0.0

    print("\nResults Summary:")
    print(f"Valid JSON responses: {valid_count}/{total_count} ({pass_rate:.1%})")
    print(f"Success criteria: >={pass_threshold:.0%} valid calls")
    print(f"Result: {'PASS' if pass_rate >= pass_threshold else 'NEEDS IMPROVEMENT'}")
|
|
|
# Script entry point: run the evaluation only when this file is executed
# directly, never as a side effect of importing it.
if __name__ == "__main__":
    main()