jlov7 commited on
Commit
1b5bd3c
Β·
1 Parent(s): 5410dc5

fix: add timeout protection and optimize inference for HF Spaces

Browse files
Files changed (1) hide show
  1. test_constrained_model.py +117 -51
test_constrained_model.py CHANGED
@@ -14,70 +14,136 @@ from typing import Dict, List
14
  import time
15
 
16
  def load_trained_model():
17
- """Load our intensively trained model."""
18
- print("πŸ”„ Loading SmolLM3-3B (base model for demo)...")
19
 
20
  # Load base model
21
  base_model_name = "HuggingFaceTB/SmolLM3-3B"
22
- tokenizer = AutoTokenizer.from_pretrained(base_model_name)
23
- if tokenizer.pad_token is None:
24
- tokenizer.pad_token = tokenizer.eos_token
25
 
26
- model = AutoModelForCausalLM.from_pretrained(
27
- base_model_name,
28
- torch_dtype=torch.float32,
29
- device_map="mps" if torch.backends.mps.is_available() else "auto"
30
- )
31
-
32
- # Note: Using base model for demo (LoRA adapter not included to keep repo size small)
33
- print("πŸ”§ Using base model (LoRA adapter excluded for size constraints)...")
34
- # For production deployment, upload LoRA adapter to HF Hub and load from there
35
-
36
- print("βœ… Trained model loaded successfully")
37
- return model, tokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  def constrained_json_generate(model, tokenizer, prompt: str, schema: Dict, max_attempts: int = 3):
40
  """Generate JSON with multiple attempts and validation."""
41
  device = next(model.parameters()).device
42
 
43
  for attempt in range(max_attempts):
44
- # Generate with different temperatures for diversity
45
- temperature = 0.1 + (attempt * 0.1)
46
-
47
- inputs = tokenizer(prompt, return_tensors="pt").to(device)
48
-
49
- with torch.no_grad():
50
- outputs = model.generate(
51
- **inputs,
52
- max_new_tokens=200,
53
- temperature=temperature,
54
- do_sample=True,
55
- top_p=0.9,
56
- pad_token_id=tokenizer.eos_token_id,
57
- eos_token_id=tokenizer.eos_token_id
58
- )
59
-
60
- # Decode response
61
- response = tokenizer.decode(
62
- outputs[0][inputs['input_ids'].shape[1]:],
63
- skip_special_tokens=True
64
- ).strip()
65
-
66
- # Try to parse as JSON
67
  try:
68
- parsed = json.loads(response)
69
- # Validate against schema if provided
70
- if schema:
71
- jsonschema.validate(parsed, schema)
72
- return response, True, None
73
- except json.JSONDecodeError as e:
74
- if attempt == max_attempts - 1:
75
- return response, False, str(e)
76
- except jsonschema.ValidationError as e:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  if attempt == max_attempts - 1:
78
- return response, False, f"Schema validation: {str(e)}"
 
79
 
80
- return response, False, "Max attempts exceeded"
81
 
82
  def create_test_schemas():
83
  """Create the test schemas we're evaluating against."""
 
14
  import time
15
 
16
  def load_trained_model():
17
+ """Load our model - tries fine-tuned first, falls back to base model."""
18
+ print("πŸ”„ Loading SmolLM3-3B Function-Calling Agent...")
19
 
20
  # Load base model
21
  base_model_name = "HuggingFaceTB/SmolLM3-3B"
 
 
 
22
 
23
+ try:
24
+ print("πŸ”„ Loading tokenizer...")
25
+ tokenizer = AutoTokenizer.from_pretrained(base_model_name)
26
+ if tokenizer.pad_token is None:
27
+ tokenizer.pad_token = tokenizer.eos_token
28
+
29
+ print("πŸ”„ Loading base model...")
30
+ # Use smaller data type for Hugging Face Spaces
31
+ model = AutoModelForCausalLM.from_pretrained(
32
+ base_model_name,
33
+ torch_dtype=torch.float16, # Use float16 for better memory usage
34
+ device_map="auto",
35
+ low_cpu_mem_usage=True # Reduce memory usage during loading
36
+ )
37
+
38
+ # Try to load fine-tuned adapter from Hugging Face Hub
39
+ try:
40
+ print("πŸ”„ Attempting to load fine-tuned adapter...")
41
+ # from peft import PeftModel # Uncomment when adapter is available
42
+ # model = PeftModel.from_pretrained(model, "jlov7/SmolLM3-Function-Calling-LoRA")
43
+ # model = model.merge_and_unload()
44
+ # print("βœ… Fine-tuned model loaded successfully!")
45
+ print("πŸ”§ Fine-tuned adapter not yet available - using base model with optimized prompting")
46
+ except Exception as e:
47
+ print(f"⚠️ Could not load fine-tuned adapter: {e}")
48
+ print("πŸ”§ Using base model with optimized prompting")
49
+
50
+ print("βœ… Model loaded successfully")
51
+ return model, tokenizer
52
+
53
+ except Exception as e:
54
+ print(f"❌ Error loading model: {e}")
55
+ raise
56
 
57
  def constrained_json_generate(model, tokenizer, prompt: str, schema: Dict, max_attempts: int = 3):
58
  """Generate JSON with multiple attempts and validation."""
59
  device = next(model.parameters()).device
60
 
61
  for attempt in range(max_attempts):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  try:
63
+ # Generate with different temperatures for diversity
64
+ temperature = 0.1 + (attempt * 0.1)
65
+
66
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
67
+
68
+ # Simple timeout protection using threading (cross-platform)
69
+ import threading
70
+
71
+ result = [None]
72
+ error = [None]
73
+
74
+ def generate_with_timeout():
75
+ try:
76
+ with torch.no_grad():
77
+ outputs = model.generate(
78
+ **inputs,
79
+ max_new_tokens=100, # Reduced for faster generation
80
+ temperature=temperature,
81
+ do_sample=True,
82
+ pad_token_id=tokenizer.eos_token_id,
83
+ eos_token_id=tokenizer.eos_token_id,
84
+ num_return_sequences=1,
85
+ use_cache=True
86
+ )
87
+
88
+ # Extract generated text
89
+ generated_ids = outputs[0][inputs['input_ids'].shape[1]:]
90
+ response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
91
+
92
+ # Try to extract JSON from response
93
+ if "{" in response and "}" in response:
94
+ # Find the first complete JSON object
95
+ start = response.find("{")
96
+ bracket_count = 0
97
+ end = start
98
+
99
+ for i, char in enumerate(response[start:], start):
100
+ if char == "{":
101
+ bracket_count += 1
102
+ elif char == "}":
103
+ bracket_count -= 1
104
+ if bracket_count == 0:
105
+ end = i + 1
106
+ break
107
+
108
+ json_str = response[start:end]
109
+ result[0] = json_str
110
+ else:
111
+ result[0] = response
112
+
113
+ except Exception as e:
114
+ error[0] = str(e)
115
+
116
+ # Start generation in a separate thread with timeout
117
+ thread = threading.Thread(target=generate_with_timeout)
118
+ thread.daemon = True
119
+ thread.start()
120
+ thread.join(timeout=20) # 20-second timeout
121
+
122
+ if thread.is_alive():
123
+ return "", False, f"Generation timed out (attempt {attempt + 1})"
124
+
125
+ if error[0]:
126
+ if attempt == max_attempts - 1:
127
+ return "", False, f"Generation error: {error[0]}"
128
+ continue
129
+
130
+ if result[0]:
131
+ # Validate JSON and schema
132
+ try:
133
+ parsed = json.loads(result[0])
134
+ jsonschema.validate(parsed, schema)
135
+ return result[0], True, None
136
+ except (json.JSONDecodeError, jsonschema.ValidationError) as e:
137
+ if attempt == max_attempts - 1:
138
+ return result[0], False, f"JSON validation failed: {str(e)}"
139
+ continue
140
+
141
+ except Exception as e:
142
  if attempt == max_attempts - 1:
143
+ return "", False, f"Generation error: {str(e)}"
144
+ continue
145
 
146
+ return "", False, "All generation attempts failed"
147
 
148
  def create_test_schemas():
149
  """Create the test schemas we're evaluating against."""