File size: 5,408 Bytes
6639f75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
"""
test_model.py - Test our trained dynamic function-calling agent

This script loads the trained LoRA adapter and tests it on various schemas
to demonstrate zero-shot function calling capability.
"""

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import json

def load_trained_model(
    base_model_name="HuggingFaceTB/SmolLM2-1.7B-Instruct",
    adapter_path="./smollm_tool_adapter/checkpoint-6",
):
    """Load the base model and apply the trained LoRA adapter on top of it.

    Args:
        base_model_name: Hugging Face Hub id (or local path) of the base model
            the adapter was trained against.
        adapter_path: Directory holding the saved PEFT adapter checkpoint.

    Returns:
        ``(model, tokenizer)`` -- the adapter-wrapped model, put in eval mode,
        and the base model's tokenizer.
    """
    print("🔄 Loading trained model...")
    use_cuda = torch.cuda.is_available()

    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    if tokenizer.pad_token is None:
        # Reuse EOS as the pad token when the tokenizer does not define one.
        tokenizer.pad_token = tokenizer.eos_token

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        # Half precision only on GPU; CPU inference stays in float32.
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        device_map="auto" if use_cuda else None,
        # NOTE(review): trust_remote_code executes code shipped with the model
        # repo; confirm it is actually required for this checkpoint.
        trust_remote_code=True,
    )

    model = PeftModel.from_pretrained(base_model, adapter_path)
    # Inference only: make sure dropout and other training-time layers are off.
    model.eval()

    print("✅ Model loaded successfully!")
    return model, tokenizer

def _build_prompt(schema, question):
    """Render *schema* and *question* into the ChatML-style evaluation prompt.

    The text is kept verbatim. The schema block sits between the system and
    user turns; presumably this mirrors the training format -- TODO confirm
    before changing the layout.
    """
    return f"""<|im_start|>system
You are a helpful assistant that calls functions by responding with valid JSON when given a schema. Always respond with JSON function calls only, never prose.<|im_end|>

<schema>
{json.dumps(schema, indent=2)}
</schema>

<|im_start|>user
{question}<|im_end|>
<|im_start|>assistant
"""


def test_function_call(model, tokenizer, schema, question, *, max_new_tokens=100):
    """Ask *model* to answer *question* with a JSON function call for *schema*.

    Args:
        model: Causal LM exposing ``generate`` and ``device``, such as the
            adapter-wrapped model returned by ``load_trained_model``.
        tokenizer: Tokenizer matching *model*.
        schema: Function schema, serialised with ``json.dumps`` into the prompt.
        question: The user message.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        ``(response, is_valid_json, parsed)``: the whitespace-stripped
        completion, whether it parses as JSON, and the parsed value
        (``None`` when parsing fails).
    """
    prompt = _build_prompt(schema, question)

    # The tokenizer returns CPU tensors. When the model was loaded onto a GPU,
    # the inputs must follow it there or generate() fails on a device mismatch.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        # Low-temperature sampling: near-greedy, but not deterministic.
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.1,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # generate() returns prompt + completion; decode only the new tokens.
    response = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    ).strip()

    try:
        json_response = json.loads(response)
    except json.JSONDecodeError:
        return response, False, None
    return response, True, json_response


# The name matches pytest's "test_*" pattern and this file is test_model.py,
# so without this flag pytest would collect the helper and fail on fixtures.
test_function_call.__test__ = False

def _default_test_cases():
    """Return the schema/question pairs evaluated by ``main``.

    Each case is a dict with ``name`` (display label), ``schema`` (the
    function schema shown to the model) and ``question`` (the user turn).
    By their labels, one schema was seen in training and the others are new.
    """
    return [
        {
            "name": "Trained Schema: Stock Price",
            "schema": {
                "name": "get_stock_price",
                "description": "Return the latest price for a given ticker symbol.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "ticker": {"type": "string"}
                    },
                    "required": ["ticker"]
                }
            },
            "question": "What's Microsoft trading at?"
        },
        {
            "name": "NEW Schema: Database Query",
            "schema": {
                "name": "query_database",
                "description": "Execute a SQL query on the database.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {"type": "string"},
                        "timeout": {"type": "number"}
                    },
                    "required": ["query"]
                }
            },
            "question": "Find all users who signed up last week"
        },
        {
            "name": "NEW Schema: File Operations",
            "schema": {
                "name": "create_file",
                "description": "Create a new file with content.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "filename": {"type": "string"},
                        "content": {"type": "string"},
                        "overwrite": {"type": "boolean"}
                    },
                    "required": ["filename", "content"]
                }
            },
            "question": "Create a file called report.txt with the content 'Meeting notes'"
        }
    ]


def main():
    """Run the adapter on the built-in cases and print a pass/fail summary.

    A case counts as valid when the model's reply parses as JSON; the run
    passes when at least 80% of cases are valid.
    """
    pass_threshold = 0.8

    print("🧪 Testing Dynamic Function-Calling Agent")
    print("=" * 50)

    model, tokenizer = load_trained_model()
    # Mix of a schema labelled as seen in training and unseen schemas.
    test_cases = _default_test_cases()

    valid_count = 0
    for i, test_case in enumerate(test_cases, 1):
        print(f"\n📋 Test {i}: {test_case['name']}")
        print(f"❓ Question: {test_case['question']}")

        response, is_valid, json_obj = test_function_call(
            model, tokenizer, test_case["schema"], test_case["question"]
        )
        print(f"🤖 Model response: {response}")

        # NOTE(review): only JSON validity is scored; the function name and
        # arguments are not compared against the schema.
        if is_valid:
            print(f"✅ Valid JSON: {json_obj}")
            valid_count += 1
        else:
            print("❌ Invalid JSON")

        print("-" * 40)

    total_count = len(test_cases)
    # Guard the ratio so an empty case list reports 0% instead of crashing.
    pass_rate = valid_count / total_count if total_count else 0.0

    print("\n📊 Results Summary:")
    print(f"✅ Valid JSON responses: {valid_count}/{total_count} ({pass_rate:.1%})")
    print(f"🎯 Success criteria: ≥{pass_threshold:.0%} valid calls")
    print(f"🏆 Result: {'PASS' if pass_rate >= pass_threshold else 'NEEDS IMPROVEMENT'}")

# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()