Yongdong Wang committed
Commit 792bd1c · 1 Parent(s): 83496e3

Add multi-model support.

Files changed (1)
  1. app.py +78 -31
app.py CHANGED
@@ -19,45 +19,74 @@ import warnings
 import os
 warnings.filterwarnings("ignore")

-# Model configuration
-MODEL_NAME = "meta-llama/Llama-3.1-8B"
-LORA_MODEL = "YongdongWang/llama3.1-8b-lora-qlora-dart-llm"
+# Model configurations
+MODEL_CONFIGS = {
+    "1B": {
+        "name": "Dart-llm-model-1B",
+        "base_model": "meta-llama/Llama-3.2-1B",
+        "lora_model": "YongdongWang/llama-3.2-1b-lora-qlora-dart-llm"
+    },
+    "3B": {
+        "name": "Dart-llm-model-3B",
+        "base_model": "meta-llama/Llama-3.2-3B",
+        "lora_model": "YongdongWang/llama-3.2-3b-lora-qlora-dart-llm"
+    },
+    "8B": {
+        "name": "Dart-llm-model-8B",
+        "base_model": "meta-llama/Llama-3.1-8B",
+        "lora_model": "YongdongWang/llama-3.1-8b-lora-qlora-dart-llm"
+    }
+}
+
+DEFAULT_MODEL = "1B"  # Set 1B as default

 # Global variables to store model and tokenizer
 model = None
 tokenizer = None
+current_model_config = None
 model_loaded = False

-def load_model_and_tokenizer():
+def load_model_and_tokenizer(selected_model=DEFAULT_MODEL):
     """Load tokenizer - executed on CPU"""
-    global tokenizer, model_loaded
+    global tokenizer, model_loaded, current_model_config

-    if model_loaded:
+    if model_loaded and current_model_config == selected_model:
         return

-    print("πŸ”„ Loading tokenizer...")
+    print(f"πŸ”„ Loading tokenizer for {MODEL_CONFIGS[selected_model]['name']}...")

     # Load tokenizer (on CPU)
+    base_model = MODEL_CONFIGS[selected_model]["base_model"]
     tokenizer = AutoTokenizer.from_pretrained(
-        MODEL_NAME,
+        base_model,
         use_fast=False,
         trust_remote_code=True
     )
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token

+    current_model_config = selected_model
     model_loaded = True
     print("βœ… Tokenizer loaded successfully!")

 @spaces.GPU(duration=60)  # Request GPU for loading model at startup
-def load_model_on_gpu():
+def load_model_on_gpu(selected_model=DEFAULT_MODEL):
     """Load model on GPU"""
     global model

-    if model is not None:
+    # If model is already loaded and it's the same model, return it
+    if model is not None and current_model_config == selected_model:
         return model

-    print("πŸ”„ Loading model on GPU...")
+    # Clear existing model if switching
+    if model is not None:
+        print("πŸ—‘οΈ Clearing existing model from GPU...")
+        del model
+        torch.cuda.empty_cache()
+        model = None
+
+    model_config = MODEL_CONFIGS[selected_model]
+    print(f"πŸ”„ Loading {model_config['name']} on GPU...")

     try:
         # 4-bit quantization configuration
@@ -70,7 +99,7 @@ def load_model_on_gpu():

         # Load base model
         base_model = AutoModelForCausalLM.from_pretrained(
-            MODEL_NAME,
+            model_config["base_model"],
             quantization_config=bnb_config,
             device_map="auto",
             torch_dtype=torch.float16,
@@ -82,13 +111,13 @@ def load_model_on_gpu():
         # Load LoRA adapter
         model = PeftModel.from_pretrained(
             base_model,
-            LORA_MODEL,
+            model_config["lora_model"],
             torch_dtype=torch.float16,
             use_safetensors=True
         )
         model.eval()

-        print("βœ… Model loaded on GPU successfully!")
+        print(f"βœ… {model_config['name']} loaded on GPU successfully!")
         return model

     except Exception as load_error:
@@ -122,17 +151,17 @@ def process_json_in_response(response):
     return response

 @spaces.GPU(duration=60)  # GPU inference
-def generate_response_gpu(prompt, max_tokens=512):
+def generate_response_gpu(prompt, max_tokens=512, selected_model=DEFAULT_MODEL):
     """Generate response - executed on GPU"""
     global model

     # Ensure tokenizer is loaded
-    if tokenizer is None:
-        load_model_and_tokenizer()
+    if tokenizer is None or current_model_config != selected_model:
+        load_model_and_tokenizer(selected_model)

     # Ensure model is loaded on GPU
-    if model is None:
-        model = load_model_on_gpu()
+    if model is None or current_model_config != selected_model:
+        model = load_model_on_gpu(selected_model)

     if model is None:
         return "❌ Model failed to load. Please check the Space logs."
@@ -184,18 +213,18 @@ def generate_response_gpu(prompt, max_tokens=512):
     except Exception as generation_error:
         return f"❌ Generation Error: {str(generation_error)}"

-def chat_interface(message, history, max_tokens):
+def chat_interface(message, history, max_tokens, selected_model):
     """Chat interface - runs on CPU, calls GPU functions"""
     if not message.strip():
         return history, ""

     # Initialize tokenizer (if needed)
-    if tokenizer is None:
-        load_model_and_tokenizer()
+    if tokenizer is None or current_model_config != selected_model:
+        load_model_and_tokenizer(selected_model)

     try:
         # Call GPU function to generate response
-        response = generate_response_gpu(message, max_tokens)
+        response = generate_response_gpu(message, max_tokens, selected_model)
         history.append((message, response))
         return history, ""
     except Exception as chat_error:
@@ -203,12 +232,12 @@ def chat_interface(message, history, max_tokens):
         history.append((message, error_msg))
         return history, ""

-# Load tokenizer at startup
-load_model_and_tokenizer()
+# Load tokenizer at startup with default model
+load_model_and_tokenizer(DEFAULT_MODEL)

 # Create Gradio application
 with gr.Blocks(
-    title="Robot Task Planning - Llama 3.1 8B",
+    title="Robot Task Planning - DART-LLM Multi-Model",
     theme=gr.themes.Soft(),
     css="""
     .gradio-container {
@@ -218,13 +247,20 @@ with gr.Blocks(
     """
 ) as app:
     gr.Markdown("""
-    # πŸ€– Llama 3.1 8B - Robot Task Planning
+    # πŸ€– DART-LLM Multi-Model - Robot Task Planning
+
+    Choose from **three fine-tuned models** specialized for **robot task planning** using QLoRA technique:

-    This is a fine-tuned version of Meta's Llama 3.1 8B model specialized for **robot task planning** using QLoRA technique.
+    - **πŸš€ Dart-llm-model-1B** (Default): Fastest inference, optimized for speed
+    - **βš–οΈ Dart-llm-model-3B**: Balanced performance and quality
+    - **🎯 Dart-llm-model-8B**: Best quality output, higher latency

     **Capabilities**: Convert natural language robot commands into structured task sequences for excavators, dump trucks, and other construction robots.

-    **Model**: [YongdongWang/llama3.1-8b-lora-qlora-dart-llm](https://huggingface.co/YongdongWang/llama3.1-8b-lora-qlora-dart-llm)
+    **Models**:
+    - [YongdongWang/llama-3.2-1b-lora-qlora-dart-llm](https://huggingface.co/YongdongWang/llama-3.2-1b-lora-qlora-dart-llm) (Default)
+    - [YongdongWang/llama-3.2-3b-lora-qlora-dart-llm](https://huggingface.co/YongdongWang/llama-3.2-3b-lora-qlora-dart-llm)
+    - [YongdongWang/llama-3.1-8b-lora-qlora-dart-llm](https://huggingface.co/YongdongWang/llama-3.1-8b-lora-qlora-dart-llm)

     ⚑ **Using ZeroGPU**: This Space uses dynamic GPU allocation (Nvidia H200). First generation might take a bit longer.
     """)
@@ -256,6 +292,14 @@ with gr.Blocks(
         with gr.Column(scale=1):
             gr.Markdown("### βš™οΈ Generation Settings")

+            model_selector = gr.Dropdown(
+                choices=[(config["name"], key) for key, config in MODEL_CONFIGS.items()],
+                value=DEFAULT_MODEL,
+                label="Model Size",
+                info="Select model size (1B = fastest, 8B = best quality)",
+                interactive=True
+            )
+
             max_tokens = gr.Slider(
                 minimum=50,
                 maximum=5000,
@@ -270,6 +314,9 @@
             - **Hardware**: ZeroGPU (Dynamic Nvidia H200)
             - **Status**: Ready
             - **Note**: First generation allocates GPU resources
+            - **Dart-llm-model-1B**: Fastest inference (Default)
+            - **Dart-llm-model-3B**: Balanced speed/quality
+            - **Dart-llm-model-8B**: Best quality, slower
             """)

     # Example conversations
@@ -289,13 +336,13 @@ with gr.Blocks(
     # Event handling
     msg.submit(
         chat_interface,
-        inputs=[msg, chatbot, max_tokens],
+        inputs=[msg, chatbot, max_tokens, model_selector],
         outputs=[chatbot, msg]
     )

     send_btn.click(
         chat_interface,
-        inputs=[msg, chatbot, max_tokens],
+        inputs=[msg, chatbot, max_tokens, model_selector],
         outputs=[chatbot, msg]
     )
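For reference, the reload guard this commit adds around `model` and `current_model_config` can be exercised on its own. The snippet below is a minimal sketch, not part of app.py: `load_backbone` and `ensure_model` are hypothetical stand-ins for the real 4-bit quantization + LoRA loading path, and only the guard logic mirrors the diff above.

# Minimal sketch of the model-switching guard (hypothetical reduction of app.py's logic;
# load_backbone stands in for the AutoModelForCausalLM + PeftModel loading path).
import gc

MODEL_CONFIGS = {
    "1B": {"name": "Dart-llm-model-1B", "base_model": "meta-llama/Llama-3.2-1B"},
    "8B": {"name": "Dart-llm-model-8B", "base_model": "meta-llama/Llama-3.1-8B"},
}

model = None
current_model_config = None

def load_backbone(repo_id):
    # Placeholder: app.py loads the base model with BitsAndBytesConfig, then the LoRA adapter.
    return f"<model from {repo_id}>"

def ensure_model(selected_model):
    """Reload only when the dropdown selection actually changes."""
    global model, current_model_config
    if model is not None and current_model_config == selected_model:
        return model  # same selection: reuse the cached model
    if model is not None:
        del model     # switching: release the previous model first
        gc.collect()  # app.py additionally calls torch.cuda.empty_cache()
        model = None
    model = load_backbone(MODEL_CONFIGS[selected_model]["base_model"])
    current_model_config = selected_model
    return model

print(ensure_model("1B"))  # loads 1B
print(ensure_model("1B"))  # cached: no reload
print(ensure_model("8B"))  # clears 1B, then loads 8B

The same guard appears in both `load_model_and_tokenizer` and `load_model_on_gpu` in the diff, which is why `chat_interface` can pass the dropdown value straight through to the GPU functions.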