grayphite committed
Commit 13b7b20 · verified · 1 Parent(s): c918735

Update app.py

Files changed (1)
  1. app.py +419 -73
app.py CHANGED
@@ -1,69 +1,187 @@
  import gradio as gr
  import torch
  from PIL import Image
- from transformers import AutoProcessor, LlavaForConditionalGeneration
- from io import BytesIO
  import requests
  import json
  import time

- # Load processor and model
- processor = AutoProcessor.from_pretrained("liuhaotian/llava-v1.5-7b")

- model = LlavaForConditionalGeneration.from_pretrained(
-     "liuhaotian/llava-v1.5-7b",
-     torch_dtype=torch.float16,
-     device_map="auto"
- )

- # Core inference function
- def generate_response(user_message, system_prompt=None, image=None, max_tokens=1024, temperature=0.7):
-     if system_prompt:
-         prompt = f"<image>\n{system_prompt}\n{user_message}"
-     else:
-         prompt = f"<image>\n{user_message}"

-     inputs = processor(prompt, image, return_tensors="pt").to(model.device)
-
-     with torch.inference_mode():
-         output = model.generate(
-             **inputs,
-             max_new_tokens=max_tokens,
-             do_sample=True,
-             temperature=temperature,
          )

-     response_text = processor.decode(output[0], skip_special_tokens=True)
-     return response_text

- # API-style function for programmatic access
- def api_endpoint(request: gr.Request):
      try:
-         data = request.json
-         user_message = data.get("user_message", "")
-         system_prompt = data.get("system_prompt", None)
-         image_url = data.get("image_url", None)
-         max_tokens = data.get("max_tokens", 1024)
-         temperature = data.get("temperature", 0.7)

-         image_data = None
          if image_url:
-             image_response = requests.get(image_url)
-             image_data = Image.open(BytesIO(image_response.content)).convert("RGB")
-
          response_text = generate_response(
-             user_message=user_message,
              system_prompt=system_prompt,
-             image=image_data,
              max_tokens=max_tokens,
              temperature=temperature
          )
-
-         return gr.Response(json.dumps({
              "id": f"chatcmpl-{int(time.time())}",
-             "object": "chat.completion",
              "created": int(time.time()),
-             "model": "llava-1.5-7b",
              "choices": [{
                  "message": {
                      "role": "assistant",
@@ -71,40 +189,268 @@ def api_endpoint(request: gr.Request):
                  },
                  "index": 0,
                  "finish_reason": "stop"
-             }]
-         }), media_type="application/json")
-
      except Exception as e:
-         return gr.Response(json.dumps({"error": str(e)}), media_type="application/json")
-

- # Gradio UI
- with gr.Blocks() as demo:
-     gr.Markdown("# 🔍 LLaVA API Demo")

-     with gr.Tab("Test UI"):
          with gr.Row():
-             with gr.Column():
-                 user_message = gr.Textbox(label="User Message", lines=3)
-                 system_prompt = gr.Textbox(label="System Prompt (Optional)", lines=2)
-                 image_input = gr.Image(label="Image (Optional)", type="pil")
-                 max_tokens = gr.Slider(label="Max Tokens", minimum=1, maximum=2048, value=1024, step=1)
-                 temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, value=0.7, step=0.1)
-                 submit_btn = gr.Button("Generate Response")
-             with gr.Column():
-                 output = gr.Textbox(label="Response", lines=10)
-
-     def on_submit(message, system, image, tokens, temp):
-         return generate_response(message, system, image, tokens, temp)
-
      submit_btn.click(
-         fn=on_submit,
-         inputs=[user_message, system_prompt, image_input, max_tokens, temperature],
-         outputs=output
      )

- # API endpoint
- demo.api("/api")(api_endpoint)
-
- # Launch
- demo.launch()
 
 
 
  import gradio as gr
  import torch
  from PIL import Image
  import requests
+ from io import BytesIO
  import json
  import time
+ import os
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+ from transformers import CLIPVisionModel, CLIPImageProcessor
+ import warnings
+ warnings.filterwarnings("ignore")

+ print("🚀 Starting LLaVA deployment...")

+ # Check GPU availability
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"💻 Using device: {device}")

+ # Global variables for model components
+ tokenizer = None
+ model = None
+ image_processor = None
+ vision_tower = None

+ def load_model():
+     """Load LLaVA model components"""
+     global tokenizer, model, image_processor, vision_tower
+
+     try:
+         print("📦 Loading tokenizer...")
+         # Use the smaller 7B model for free tier
+         model_path = "liuhaotian/llava-v1.5-7b"
+
+         tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+         print("🧠 Loading language model...")
+         model = AutoModelForCausalLM.from_pretrained(
+             model_path,
+             torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+             low_cpu_mem_usage=True,
+             device_map="auto" if device == "cuda" else None
          )
+
+         print("👁️ Loading vision components...")
+         # Load vision tower
+         vision_tower = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14-336")
+         image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
+
+         if device == "cuda":
+             vision_tower = vision_tower.to(device)
+
+         print("✅ Model loaded successfully!")
+         return True
+
+     except Exception as e:
+         print(f"❌ Error loading model: {str(e)}")
+         return False

+ def process_image(image):
+     """Process image for the model"""
+     if image is None:
+         return None
+
+     try:
+         # Convert to RGB if needed
+         if image.mode != 'RGB':
+             image = image.convert('RGB')
+
+         # Process image
+         image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values']
+
+         if device == "cuda":
+             image_tensor = image_tensor.to(device)
+
+         # Get image features
+         with torch.no_grad():
+             image_features = vision_tower(image_tensor).last_hidden_state
+
+         return image_features
+
+     except Exception as e:
+         print(f"Error processing image: {str(e)}")
+         return None

+ def generate_response(message, image=None, system_prompt="", max_tokens=1024, temperature=0.7):
+     """Generate response using LLaVA"""
+     global tokenizer, model, image_processor, vision_tower
+
+     if model is None:
+         return "❌ Model not loaded. Please wait for initialization."
+
      try:
+         # Process image if provided
+         image_features = None
+         if image is not None:
+             image_features = process_image(image)
+             if image_features is None:
+                 return "❌ Error processing image."
+
+         # Prepare prompt
+         if system_prompt:
+             full_prompt = f"System: {system_prompt}\n\nUser: {message}\n\nAssistant:"
+         else:
+             if image is not None:
+                 full_prompt = f"USER: <image>\n{message}\nASSISTANT:"
+             else:
+                 full_prompt = f"USER: {message}\nASSISTANT:"
+
+         # Tokenize
+         inputs = tokenizer(full_prompt, return_tensors="pt")
+
+         if device == "cuda":
+             inputs = {k: v.to(device) for k, v in inputs.items()}
+
+         # Generate
+         with torch.no_grad():
+             if image_features is not None:
+                 # For multimodal input, we need to handle image features
+                 # This is a simplified version - real LLaVA has more complex integration
+                 outputs = model.generate(
+                     **inputs,
+                     max_new_tokens=max_tokens,
+                     temperature=temperature,
+                     do_sample=True,
+                     pad_token_id=tokenizer.eos_token_id
+                 )
+             else:
+                 # Text-only generation
+                 outputs = model.generate(
+                     **inputs,
+                     max_new_tokens=max_tokens,
+                     temperature=temperature,
+                     do_sample=True,
+                     pad_token_id=tokenizer.eos_token_id
+                 )
+
+         # Decode response
+         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+         # Clean up response (remove the input prompt)
+         response = response[len(full_prompt):].strip()
+
+         return response
+
+     except Exception as e:
+         return f"❌ Error generating response: {str(e)}"

+ def api_endpoint(request_json):
+     """API endpoint for programmatic access"""
+     try:
+         data = json.loads(request_json)
+
+         message = data.get("message", "")
+         system_prompt = data.get("system_prompt", "")
+         image_url = data.get("image_url", None)
+         max_tokens = int(data.get("max_tokens", 1024))
+         temperature = float(data.get("temperature", 0.7))
+
+         # Process image if URL provided
+         image = None
          if image_url:
+             try:
+                 response = requests.get(image_url, timeout=10)
+                 if response.status_code == 200:
+                     image = Image.open(BytesIO(response.content))
+             except Exception as e:
+                 return json.dumps({"error": f"Failed to load image: {str(e)}"})
+
+         # Generate response
          response_text = generate_response(
+             message=message,
+             image=image,
              system_prompt=system_prompt,
              max_tokens=max_tokens,
              temperature=temperature
          )
+
+         # Return API response
+         return json.dumps({
              "id": f"chatcmpl-{int(time.time())}",
+             "object": "chat.completion",
              "created": int(time.time()),
+             "model": "llava-v1.5-7b",
              "choices": [{
                  "message": {
                      "role": "assistant",
                  },
                  "index": 0,
                  "finish_reason": "stop"
+             }],
+             "usage": {
+                 "prompt_tokens": 0, # Simplified
+                 "completion_tokens": 0, # Simplified
+                 "total_tokens": 0 # Simplified
+             }
+         })
+
      except Exception as e:
+         return json.dumps({"error": str(e)})

+ # Initialize model on startup
+ print("🔄 Initializing model...")
+ model_loaded = load_model()

+ # Create Gradio interface
+ with gr.Blocks(title="LLaVA - Large Language and Vision Assistant", theme=gr.themes.Soft()) as demo:
+     gr.Markdown("""
+     # 🦙 LLaVA - Large Language and Vision Assistant
+
+     An open-source chatbot trained by fine-tuning LLaMA/Vicuna on GPT-generated multimodal instruction-following data.
+
+     **Features:**
+     - 💬 Text-based conversation
+     - 🖼️ Image understanding and description
+     - 🔧 API endpoint for integration
+     """)
+
+     with gr.Tab("💬 Chat Interface"):
          with gr.Row():
+             with gr.Column(scale=1):
+                 image_input = gr.Image(
+                     type="pil",
+                     label="📸 Upload Image (Optional)",
+                     height=300
+                 )
+                 system_prompt = gr.Textbox(
+                     label="🎯 System Prompt (Optional)",
+                     placeholder="You are a helpful assistant that can analyze images...",
+                     lines=2
+                 )
+
+             with gr.Column(scale=2):
+                 chatbot = gr.Chatbot(
+                     label="💭 Conversation",
+                     height=400
+                 )
+
+                 msg = gr.Textbox(
+                     label="✍️ Your Message",
+                     placeholder="Type your message here... You can ask about the uploaded image!",
+                     lines=2
+                 )
+
+                 with gr.Row():
+                     submit_btn = gr.Button("🚀 Send", variant="primary")
+                     clear_btn = gr.Button("🗑️ Clear", variant="secondary")
+
+         with gr.Accordion("⚙️ Advanced Settings", open=False):
+             max_tokens = gr.Slider(
+                 minimum=1,
+                 maximum=2048,
+                 value=1024,
+                 step=1,
+                 label="📝 Max Tokens"
+             )
+             temperature = gr.Slider(
+                 minimum=0.1,
+                 maximum=2.0,
+                 value=0.7,
+                 step=0.1,
+                 label="🌡️ Temperature"
+             )
+
+ with gr.Tab("πŸ”Œ API Documentation"):
267
+ gr.Markdown("""
268
+ ## API Endpoint Usage
269
+
270
+ **Endpoint**: `https://your-space-name.hf.space/api/predict`
271
+
272
+ **Method**: POST
273
+
274
+ ### Request Format:
275
+ ```json
276
+ {
277
+ "data": [
278
+ "{
279
+ \"message\": \"Describe this image in detail\",
280
+ \"system_prompt\": \"You are a helpful assistant\",
281
+ \"image_url\": \"https://example.com/image.jpg\",
282
+ \"max_tokens\": 1024,
283
+ \"temperature\": 0.7
284
+ }"
285
+ ]
286
+ }
287
+ ```
288
+
289
+ ### Response Format:
290
+ ```json
291
+ {
292
+ "data": [
293
+ "{
294
+ \"id\": \"chatcmpl-123456789\",
295
+ \"object\": \"chat.completion\",
296
+ \"created\": 1683123456,
297
+ \"model\": \"llava-v1.5-7b\",
298
+ \"choices\": [
299
+ {
300
+ \"message\": {
301
+ \"role\": \"assistant\",
302
+ \"content\": \"This image shows...\"
303
+ },
304
+ \"index\": 0,
305
+ \"finish_reason\": \"stop\"
306
+ }
307
+ ]
308
+ }"
309
+ ]
310
+ }
311
+ ```
312
+
313
+ ### Python Client Example:
314
+ ```python
315
+ import requests
316
+ import json
317
+
318
+ def query_llava(message, image_url=None, system_prompt=""):
319
+ payload = {
320
+ "data": [json.dumps({
321
+ "message": message,
322
+ "image_url": image_url,
323
+ "system_prompt": system_prompt,
324
+ "max_tokens": 1024,
325
+ "temperature": 0.7
326
+ })]
327
+ }
328
+
329
+ response = requests.post(
330
+ "https://your-space-name.hf.space/api/predict",
331
+ json=payload
332
+ )
333
+
334
+ if response.status_code == 200:
335
+ result = response.json()
336
+ api_response = json.loads(result["data"][0])
337
+ return api_response["choices"][0]["message"]["content"]
338
+ else:
339
+ return f"Error: {response.status_code}"
340
+
341
+ # Example usage
342
+ result = query_llava(
343
+ "What do you see in this image?",
344
+ image_url="https://example.com/image.jpg"
345
+ )
346
+ print(result)
347
+ ```
348
+ """)
349
+
350
+ # API testing interface
351
+ gr.Markdown("### πŸ§ͺ Test API")
352
+ api_input = gr.Textbox(
353
+ label="πŸ“ API Request (JSON)",
354
+ placeholder='{"message": "Hello!", "max_tokens": 1024}',
355
+ lines=4
356
+ )
357
+ api_output = gr.Textbox(
358
+ label="πŸ“€ API Response",
359
+ lines=8
360
+ )
361
+ api_test_btn = gr.Button("πŸ§ͺ Test API", variant="primary")
362
+
363
+ with gr.Tab("ℹ️ About"):
364
+ gr.Markdown("""
365
+ ## About LLaVA
366
+
367
+ **LLaVA (Large Language and Vision Assistant)** is an open-source multimodal AI assistant that combines:
368
+
369
+ - 🧠 **Language Understanding**: Based on Vicuna/LLaMA architecture
370
+ - πŸ‘οΈ **Vision Capabilities**: Uses CLIP vision encoder
371
+ - πŸ”— **Multimodal Integration**: Connects vision and language seamlessly
372
+
373
+ ### Key Features:
374
+ - **Visual Question Answering**: Ask questions about images
375
+ - **Image Description**: Get detailed descriptions of uploaded images
376
+ - **General Conversation**: Chat about any topic
377
+ - **API Integration**: Easy integration with your applications
378
+
379
+ ### Model Information:
380
+ - **Base Model**: LLaVA-v1.5-7B
381
+ - **Vision Encoder**: CLIP ViT-L/14@336px
382
+ - **Language Model**: Vicuna-7B
383
+ - **Training Data**: LLaVA-Instruct-150K
384
+
385
+ ### Citation:
386
+ ```
387
+ @misc{liu2023llava,
388
+ title={Visual Instruction Tuning},
389
+ author={Haotian Liu and Chunyuan Li and Qingyang Wu and Yong Jae Lee},
390
+ year={2023},
391
+ eprint={2304.08485},
392
+ archivePrefix={arXiv},
393
+ primaryClass={cs.CV}
394
+ }
395
+ ```
396
+
397
+ **GitHub**: [https://github.com/haotian-liu/LLaVA](https://github.com/haotian-liu/LLaVA)
398
+ """)
399
+
400
+ # Event handlers
401
+ def respond(message, chat_history, image, system_prompt, max_tokens, temperature):
402
+ if not message.strip():
403
+ return "", chat_history
404
+
405
+ # Add user message to chat
406
+ chat_history.append([message, None])
407
+
408
+ # Generate response
409
+ response = generate_response(
410
+ message=message,
411
+ image=image,
412
+ system_prompt=system_prompt if system_prompt.strip() else "",
413
+ max_tokens=int(max_tokens),
414
+ temperature=temperature
415
+ )
416
+
417
+ # Add assistant response to chat
418
+ chat_history[-1][1] = response
419
+
420
+ return "", chat_history
421
+
422
+ def clear_chat():
423
+ return None, []
424
+
425
+ # Connect event handlers
426
  submit_btn.click(
427
+ respond,
428
+ [msg, chatbot, image_input, system_prompt, max_tokens, temperature],
429
+ [msg, chatbot]
430
+ )
431
+
432
+ msg.submit(
433
+ respond,
434
+ [msg, chatbot, image_input, system_prompt, max_tokens, temperature],
435
+ [msg, chatbot]
436
+ )
437
+
438
+ clear_btn.click(clear_chat, outputs=[chatbot, msg])
439
+
440
+ api_test_btn.click(api_endpoint, inputs=api_input, outputs=api_output)
441
+
442
+ # Add API endpoint
443
+ api_interface = gr.Interface(
444
+ fn=api_endpoint,
445
+ inputs=gr.Textbox(),
446
+ outputs=gr.Textbox(),
447
+ api_name="predict"
448
  )
449
 
450
+ # Launch the app
451
+ if __name__ == "__main__":
452
+ demo.launch(
453
+ server_name="0.0.0.0",
454
+ server_port=7860,
455
+ share=False
456
+ )
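
The new `api_endpoint` accepts a JSON string and returns a JSON string, so besides the raw `requests` example documented in the API tab above, the endpoint can also be reached with the official `gradio_client` package. The sketch below is a hedged illustration, not part of the commit: the Space id and the `api_name` are placeholders, since the exact endpoint name depends on how Gradio registers the function (check the Space's "Use via API" panel for the real values).

```python
# Minimal sketch, assuming a public Space and an endpoint exposed as "/predict".
import json
from gradio_client import Client

# Hypothetical Space id; replace with the actual one.
client = Client("your-username/your-space-name")

request_payload = json.dumps({
    "message": "What do you see in this image?",
    "image_url": "https://example.com/image.jpg",
    "max_tokens": 512,
    "temperature": 0.7,
})

# api_name is an assumption; Gradio may register the endpoint under a different name.
raw_response = client.predict(request_payload, api_name="/predict")

# api_endpoint returns a JSON string in an OpenAI-style chat.completion shape.
parsed = json.loads(raw_response)
choice = parsed.get("choices", [{}])[0]
print(choice.get("message", {}).get("content", raw_response))
```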