|
""" |
|
test_smollm3_robust.py - Test the robust SmolLM3-3B model

This script tests our newly trained model on various schemas to measure
the dramatic improvement in function calling capability.
|
""" |
|
|
|
import torch |
|
import json |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
from peft import PeftModel |
|
|
|
def load_trained_model(
    base_model_name="HuggingFaceTB/SmolLM3-3B",
    adapter_path="./smollm3_robust",
):
    """Load the base model, attach the trained PEFT adapter, and pick a device.

    Args:
        base_model_name: Hub id (or local path) passed to both
            ``AutoTokenizer.from_pretrained`` and
            ``AutoModelForCausalLM.from_pretrained``.
        adapter_path: Location of the adapter passed to
            ``PeftModel.from_pretrained``.

    Returns:
        A ``(model, tokenizer, device)`` tuple where ``device`` is the string
        ``"mps"`` when the MPS backend is available and ``"cpu"`` otherwise.
    """
    print("🔄 Loading robust SmolLM3-3B model...")

    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    # generate() is later called with an explicit pad id, so make sure the
    # tokenizer has a pad token by falling back to the end-of-sequence token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # NOTE(review): trust_remote_code=True lets the checkpoint repository run
    # its own Python code at load time; confirm this is still required.
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )

    # Wrap the base model with the adapter weights stored at adapter_path.
    model = PeftModel.from_pretrained(base_model, adapter_path)

    if torch.backends.mps.is_available():
        device = "mps"
        model = model.to(device)
    else:
        # The model is left where from_pretrained placed it.
        device = "cpu"

    print(f"✅ Model loaded on {device}")
    return model, tokenizer, device
|
|
|
def _json_candidates(text):
    """Yield *text*, then progressively trimmed variants worth trying as JSON."""
    # The untouched text comes first so that well-formed output such as
    # '{"name": ..., "arguments": {...}}' is never damaged by the trimming.
    yield text
    if text.endswith('}"}'):
        # Drop a stray '"}' tail.
        text = text[:-2]
        yield text
    if text.endswith('}}'):
        # Drop one surplus closing brace.
        yield text[:-1]


def _parse_function_call(text):
    """Return ``(used_text, is_valid, value)`` for the first candidate that parses."""
    for candidate in _json_candidates(text):
        try:
            return candidate, True, json.loads(candidate)
        except json.JSONDecodeError:
            continue
    return text, False, None


def _score_call(parsed, schema):
    """Score a successfully parsed response from 1 to 4 against *schema*."""
    # Valid JSON that is not an object (list, number, string, null) earns only
    # the validity point; it cannot carry "name" or "arguments" keys.
    if not isinstance(parsed, dict):
        return 1
    has_name = "name" in parsed
    has_args = "arguments" in parsed
    correct_name = parsed.get("name") == schema["name"]
    return sum([True, has_name, has_args, correct_name])


def test_function_call(model, tokenizer, device, schema, question):
    """Generate a function call for *question* and score it against *schema*.

    Args:
        model: Object exposing ``eval()`` and ``generate(**kwargs)``.
        tokenizer: Callable tokenizer exposing ``decode`` and ``eos_token_id``.
        device: Device string; inputs are moved to it unless it is ``"cpu"``.
        schema: Function schema dict; it is serialized into the prompt and its
            ``"name"`` entry is compared with the generated call.
        question: User question placed in the prompt.

    Returns:
        ``(response, is_valid, json_response, score)``. ``response`` is the
        text that was parsed, or the stripped raw generation when nothing
        parsed. ``is_valid`` tells whether JSON parsing succeeded,
        ``json_response`` is the parsed value (``None`` on failure), and
        ``score`` is 0 on failure, otherwise one point each for: valid JSON,
        a ``"name"`` key, an ``"arguments"`` key, and a name equal to
        ``schema["name"]``.
    """
    prompt = f"""<|im_start|>system
You are a helpful assistant that calls functions by responding with valid JSON when given a schema. Always respond with JSON function calls only, never prose.<|im_end|>

<schema>
{json.dumps(schema, indent=2)}
</schema>

<|im_start|>user
{question}<|im_end|>
<|im_start|>assistant
"""

    inputs = tokenizer(prompt, return_tensors="pt")
    if device != "cpu":
        # Tensors must live on the same device as the model.
        inputs = {k: v.to(device) for k, v in inputs.items()}

    model.eval()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            temperature=0.1,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    input_length = inputs["input_ids"].shape[1]
    raw = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)

    response, is_valid, json_response = _parse_function_call(raw.strip())
    score = _score_call(json_response, schema) if is_valid else 0
    return response, is_valid, json_response, score


# The name starts with "test_" but the parameters are not pytest fixtures;
# opt out of pytest collection so running pytest on this file does not error.
test_function_call.__test__ = False
|
|
|
def _default_test_cases():
    """Return the built-in evaluation cases (display name, schema, question)."""
    return [
        {
            "name": "Stock Price (Training)",
            "schema": {
                "name": "get_stock_price",
                "description": "Get current stock price for a ticker",
                "parameters": {
                    "type": "object",
                    "properties": {"ticker": {"type": "string"}},
                    "required": ["ticker"],
                },
            },
            "question": "What's Apple stock trading at?",
        },
        {
            "name": "Weather (Seen Pattern)",
            "schema": {
                "name": "get_weather",
                "description": "Get weather for a location",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            },
            "question": "How's the weather in Tokyo?",
        },
        {
            "name": "NEW: Database Query",
            "schema": {
                "name": "execute_sql",
                "description": "Execute SQL query on database",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {"type": "string"},
                        "database": {"type": "string"},
                    },
                    "required": ["query"],
                },
            },
            "question": "Find all users who registered this month",
        },
        {
            "name": "NEW: Complex Parameters",
            "schema": {
                "name": "book_flight",
                "description": "Book a flight ticket",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "from_city": {"type": "string"},
                        "to_city": {"type": "string"},
                        "departure_date": {"type": "string"},
                        "passengers": {"type": "integer"},
                    },
                    "required": ["from_city", "to_city", "departure_date"],
                },
            },
            "question": "Book a flight from New York to London for December 15th",
        },
        {
            "name": "NEW: Financial Transaction",
            "schema": {
                "name": "transfer_funds",
                "description": "Transfer money between accounts",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "amount": {"type": "number"},
                        "from_account": {"type": "string"},
                        "to_account": {"type": "string"},
                        "memo": {"type": "string"},
                    },
                    "required": ["amount", "from_account", "to_account"],
                },
            },
            "question": "Send $500 from checking to savings",
        },
    ]


def main():
    """Run every built-in test case against the trained model and print a report.

    Loads the model via ``load_trained_model``, calls ``test_function_call``
    once per case, prints the per-case outcome, and finishes with the
    valid-JSON rate, the aggregate score (4 points per case), and a verdict
    based on the 80% / 60% valid-JSON thresholds.
    """
    print("🧪 Testing Robust SmolLM3-3B Function Calling")
    print("=" * 55)

    model, tokenizer, device = load_trained_model()

    test_cases = _default_test_cases()
    total_score = 0
    max_score = len(test_cases) * 4  # each case is scored out of 4
    valid_json_count = 0

    for i, test_case in enumerate(test_cases, 1):
        print(f"\n📝 Test {i}: {test_case['name']}")
        print(f"❓ Question: {test_case['question']}")

        response, is_valid, json_obj, score = test_function_call(
            model, tokenizer, device, test_case['schema'], test_case['question']
        )

        print(f"🤖 Raw response: {response}")

        if is_valid:
            print(f"✅ Valid JSON: {json_obj}")
            valid_json_count += 1
        else:
            print("❌ Invalid JSON")

        print(f"📊 Score: {score}/4")
        total_score += score
        print("-" * 50)

    # Computed once; the summary, the verdict and the comparison all use it.
    valid_rate = valid_json_count / len(test_cases)

    print("\n🏆 FINAL RESULTS:")
    print(f"✅ Valid JSON responses: {valid_json_count}/{len(test_cases)} ({valid_rate * 100:.1f}%)")
    print(f"📊 Overall score: {total_score}/{max_score} ({total_score / max_score * 100:.1f}%)")
    print("🎯 Success criteria: ≥80% valid calls")

    if valid_rate >= 0.8:
        print("🎉 PASS - Excellent function calling capability!")
    elif valid_rate >= 0.6:
        print("🟡 GOOD - Strong improvement, approaching target")
    else:
        print("🔄 PROGRESS - Significant improvement from baseline")

    # The baseline figures below are fixed strings, not measured by this run.
    print("\n📈 IMPROVEMENT COMPARISON:")
    print("Previous SmolLM2-1.7B result: 0/3 (0%)")
    print(f"Current SmolLM3-3B result: {valid_json_count}/{len(test_cases)} ({valid_rate * 100:.1f}%)")
    print("🚀 Training loss improvement: 2.38 → 1.49 (37% better)")
|
|
|
# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()