|
""" |
|
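test_constrained_model_spaces.py - Constrained JSON generation optimized for Hugging Face Spaces.

Loads SmolLM3-3B (merging a function-calling LoRA adapter when one is available)
and generates function-call JSON under tight limits - a small token budget and a
short per-attempt timeout - to suit the constrained Hugging Face Spaces
environment. Each output is validated against a JSON Schema.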
|
""" |
|
|
|
import torch |
|
import json |
|
import jsonschema |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
from typing import Dict |
|
import time |
|
import threading |
|
|
|
class TimeoutException(Exception):
    """Exception type for operations that exceed their time budget.

    NOTE(review): nothing in the visible code raises this; timeouts in
    constrained_json_generate are reported through its return value instead.
    Confirm whether external callers rely on it before removing.
    """
|
|
|
def _merge_first_adapter(model, adapter_paths):
    """Merge the first loadable LoRA adapter into ``model``; return None if none loads.

    ``adapter_paths[0]`` is treated as the Hugging Face Hub source and the
    remaining entries as local paths (this only changes the progress messages).
    Per-adapter failures are printed and skipped.
    """
    try:
        # Import once, outside the loop: a missing package is not an
        # "adapter not found" condition and must not be reported once per path.
        from peft import PeftModel
    except ImportError as e:
        print(f"⚠️ peft is not available, skipping fine-tuned adapters: {e}")
        return None

    for i, adapter_path in enumerate(adapter_paths):
        from_hub = i == 0
        if from_hub:
            print("🔄 Loading fine-tuned adapter from Hugging Face Hub...")
        else:
            print(f"🔄 Trying local path: {adapter_path}")

        try:
            # Bind the result to a separate name so a failed merge never leaves
            # the base model reference pointing at a PeftModel wrapper that the
            # next attempt would then wrap again.
            # NOTE(review): PEFT injects LoRA layers into the base model in
            # place, so a failure after injection may still leave it modified.
            merged = PeftModel.from_pretrained(model, adapter_path).merge_and_unload()
        except Exception as e:
            if from_hub:
                print(f"⚠️ Hub adapter not found: {e}")
            else:
                print(f"⚠️ Path {adapter_path} failed: {e}")
            continue

        if from_hub:
            print("✅ Fine-tuned model loaded successfully from Hub!")
        else:
            print(f"✅ Fine-tuned model loaded successfully from {adapter_path}!")
        return merged

    return None


def load_trained_model():
    """Load SmolLM3-3B and merge a function-calling LoRA adapter if one is available.

    The tokenizer and base model are loaded from ``HuggingFaceTB/SmolLM3-3B``.
    Adapter sources are then tried in order (Hugging Face Hub repo first, then
    local directories); the first one that loads is merged into the base
    weights. If none loads, the base model is used unchanged.

    Returns:
        tuple: ``(model, tokenizer)``.

    Raises:
        Exception: any error from loading the tokenizer or base model is
            printed and re-raised. Adapter failures are reported and skipped.
    """
    print("🚀 Loading SmolLM3-3B Function-Calling Agent...")

    base_model_name = "HuggingFaceTB/SmolLM3-3B"

    try:
        print("🔄 Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        if tokenizer.pad_token is None:
            # Ensure a pad token exists; reuse EOS when the tokenizer defines none.
            tokenizer.pad_token = tokenizer.eos_token

        print("🔄 Loading base model...")
        model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            low_cpu_mem_usage=True,
        )

        # First entry is a Hub repo id; the rest are local fallback directories.
        adapter_paths = [
            "jlov7/SmolLM3-Function-Calling-LoRA",
            "./model_files",
            "./smollm3_robust",
            "./hub_upload",
        ]

        merged_model = _merge_first_adapter(model, adapter_paths)
        if merged_model is None:
            print("🔧 Using base model with optimized prompting")
        else:
            model = merged_model

        print("✅ Model loaded successfully")
        return model, tokenizer

    except Exception as e:
        print(f"❌ Error loading model: {e}")
        raise
|
|
|
def _extract_json_object(text: str) -> str:
    """Return the first balanced ``{...}`` span in ``text``.

    Braces inside JSON string literals (including escaped quotes) are ignored
    while balancing. If ``text`` lacks either ``{`` or ``}`` it is returned
    unchanged. If the object is never closed (e.g. a truncated generation),
    the remainder from the first ``{`` is returned so the caller still sees
    what the model produced instead of an empty string.
    """
    if "{" not in text or "}" not in text:
        return text

    start = text.find("{")
    depth = 0
    in_string = False
    escaped = False
    for pos in range(start, len(text)):
        char = text[pos]
        if in_string:
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == '"':
                in_string = False
        elif char == '"':
            in_string = True
        elif char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return text[start:pos + 1]
    return text[start:]


def constrained_json_generate(model, tokenizer, prompt: str, schema: Dict, max_attempts: int = 2,
                              max_new_tokens: int = 25, timeout: float = 4.0):
    """Generate a JSON function call from ``prompt`` and validate it against ``schema``.

    Each attempt samples up to ``max_new_tokens`` tokens in a background
    daemon thread and waits at most ``timeout`` seconds for it. Temperature
    starts at 0.1 and rises by 0.2 per attempt. An attempt is retried after a
    JSON parse/validation failure or an unexpected exception; a timeout or an
    error raised inside ``generate()`` ends the call immediately.

    Args:
        model: causal LM exposing ``parameters()`` and ``generate()``.
        tokenizer: tokenizer matching ``model``.
        prompt: fully formatted prompt text.
        schema: JSON Schema the parsed output must satisfy.
        max_attempts: number of generation attempts.
        max_new_tokens: generation budget per attempt.
            NOTE(review): 25 tokens may be too few for a complete function
            call; raise it if outputs come back truncated.
        timeout: seconds to wait for each attempt's generation.

    Returns:
        tuple: ``(json_str, success, error)`` - the extracted JSON text (or
        ``""``), whether it parsed and validated, and an error message, or
        ``None`` on success.
    """
    device = next(model.parameters()).device

    for attempt in range(max_attempts):
        try:
            # A little more randomness on each retry to escape a repeated bad output.
            temperature = 0.1 + (attempt * 0.2)
            inputs = tokenizer(prompt, return_tensors="pt").to(device)

            # One-slot holders the worker thread writes into.
            result = [None]
            error = [None]

            def generate_with_timeout():
                try:
                    with torch.no_grad():
                        result[0] = model.generate(
                            **inputs,
                            max_new_tokens=max_new_tokens,
                            temperature=temperature,
                            do_sample=True,
                            pad_token_id=tokenizer.eos_token_id,
                            eos_token_id=tokenizer.eos_token_id,
                            num_return_sequences=1,
                            use_cache=True,
                            repetition_penalty=1.2,
                            # Let generate() stop itself near the deadline so an
                            # abandoned daemon thread does not keep computing
                            # after we have stopped waiting for it.
                            max_time=timeout,
                        )
                except Exception as e:
                    # Never empty, so the error check below cannot miss it.
                    error[0] = str(e) or type(e).__name__

            thread = threading.Thread(target=generate_with_timeout, daemon=True)
            thread.start()
            thread.join(timeout=timeout)

            if thread.is_alive():
                return "", False, f"Generation timed out (attempt {attempt + 1})"
            if error[0] is not None:
                return "", False, f"Generation error: {error[0]}"
            if result[0] is None:
                return "", False, f"Generation failed (attempt {attempt + 1})"

            # Decode only the newly generated tokens, not the echoed prompt.
            prompt_len = inputs["input_ids"].shape[1]
            response = tokenizer.decode(result[0][0][prompt_len:], skip_special_tokens=True).strip()
            json_str = _extract_json_object(response)

            try:
                parsed = json.loads(json_str)
                jsonschema.validate(parsed, schema)
                return json_str, True, None
            except (json.JSONDecodeError, jsonschema.ValidationError) as e:
                if attempt == max_attempts - 1:
                    return json_str, False, f"JSON validation failed: {str(e)}"
                continue

        except Exception as e:
            if attempt == max_attempts - 1:
                return "", False, f"Generation error: {str(e)}"
            continue

    # Only reachable when max_attempts <= 0.
    return "", False, "All generation attempts failed"
|
|
|
def create_json_schema(function_def: Dict) -> Dict:
    """Wrap a function definition in a JSON Schema describing a function-call object.

    The schema accepts ``{"name": <name>, "arguments": {...}}`` where ``name``
    must equal ``function_def["name"]`` and ``arguments`` is validated against
    ``function_def["parameters"]`` (embedded as-is, not copied). Both keys are
    required.
    """
    name_rule = {"type": "string", "enum": [function_def["name"]]}
    argument_rule = function_def["parameters"]
    return dict(
        type="object",
        properties={"name": name_rule, "arguments": argument_rule},
        required=["name", "arguments"],
    )
|
|
|
def create_test_schemas():
    """Build the fixture function definitions used by the smoke test, keyed by short name.

    A fresh dict is returned on every call, so callers may mutate it freely.
    """
    weather_params = {
        "type": "object",
        "properties": {
            "location": {"type": "string"},
            "days": {"type": "integer"},
        },
        "required": ["location", "days"],
    }
    weather_forecast = {
        "name": "get_weather_forecast",
        "description": "Get weather forecast",
        "parameters": weather_params,
    }
    return {"weather_forecast": weather_forecast}
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: load the model and generate one validated function call.
    import sys

    print("🧪 Testing SPACES-optimized model...")
    try:
        model, tokenizer = load_trained_model()

        test_schema = create_test_schemas()["weather_forecast"]
        schema = create_json_schema(test_schema)

        # Embed the same definition the output is validated against, so the
        # prompt and the validation schema cannot drift apart.
        prompt = f"""<|im_start|>system
You are a helpful assistant that calls functions by responding with valid JSON when given a schema. Always respond with JSON function calls only, never prose.<|im_end|>

<schema>
{json.dumps(test_schema)}
</schema>

<|im_start|>user
Get weather for Tokyo for 5 days<|im_end|>
<|im_start|>assistant
"""

        result, success, error = constrained_json_generate(model, tokenizer, prompt, schema)
        print(f"✅ Result: {result}")
        print(f"✅ Success: {success}")
        if error:
            print(f"⚠️ Error: {error}")

    except Exception as e:
        print(f"❌ Test failed: {e}")
        # Report failure to the shell/CI instead of exiting with status 0.
        sys.exit(1)