royyy74 committed
Commit 011b368 · verified · 1 Parent(s): 613218d

Update app.py

Files changed (1): app.py +42 -3
app.py CHANGED
@@ -8,6 +8,7 @@ import os
 import torchvision.transforms.functional as TVF
 import io
 import json  # For parsing extra_options_json
+from tempfile import TemporaryDirectory  # For offload_folder
 
 from fastapi import FastAPI, File, UploadFile, Form, HTTPException
 from pydantic import BaseModel
@@ -164,8 +165,46 @@ assert isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTr
 # LLM
 print("Loading LLM")
 print("Loading VLM's custom text model")
-# Use device_map="auto" to allow accelerate to handle model placement, including CPU
-text_model = AutoModelForCausalLM.from_pretrained(CHECKPOINT_PATH / "text_model", device_map="auto", torch_dtype=torch.bfloat16)
+# Use device_map="auto" and load_in_8bit for quantization to reduce memory footprint
+try:
+    text_model = AutoModelForCausalLM.from_pretrained(
+        CHECKPOINT_PATH / "text_model",
+        device_map="auto",
+        load_in_8bit=True  # Enable 8-bit quantization
+        # torch_dtype is generally not specified with load_in_8bit,
+        # as bitsandbytes handles the underlying types.
+    )
+except Exception as e:
+    print(f"Failed to load model with 8-bit quantization: {e}")
+    print("Attempting to load without 8-bit quantization (this may fail due to memory or require offloading)...")
+    # Fallback or alternative loading strategy can be placed here if needed.
+    # For now, let it re-raise or try a different approach if the primary fails.
+    # As a simple fallback, try the original loading, which might hit the offload error.
+    try:
+        text_model = AutoModelForCausalLM.from_pretrained(
+            CHECKPOINT_PATH / "text_model",
+            device_map="auto",
+            torch_dtype=torch.bfloat16  # Try with bfloat16 first
+        )
+    except ValueError as ve:
+        if "offload_dir" in str(ve):  # Check if the error is about needing offload_dir
+            print(f"Original loading failed with ValueError (likely needing offload_dir): {ve}")
+            print("Attempting to load model with disk offloading...")
+            model_offload_dir = TemporaryDirectory().name
+            text_model = AutoModelForCausalLM.from_pretrained(
+                CHECKPOINT_PATH / "text_model",
+                device_map="auto",
+                torch_dtype=torch.bfloat16,  # Keep bfloat16 if possible
+                offload_folder=model_offload_dir,
+                offload_state_dict=True  # Recommended when offloading
+            )
+            print(f"Model loaded with offloading to {model_offload_dir}. WARNING: This will be very slow.")
+        else:
+            raise  # Re-raise other ValueErrors
+    except Exception as final_e:  # Catch any other exceptions during the last fallback attempt
+        print(f"All model loading attempts failed. Last error: {final_e}")
+        raise
+
 text_model.eval()
 
 # Image Adapter
@@ -365,4 +404,4 @@ async def caption_image_endpoint(
 
 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+    uvicorn.run(app, host="0.0.0.0", port=8000)
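A side note on the 8-bit path in this commit: recent transformers releases deprecate passing load_in_8bit=True directly to from_pretrained in favor of a BitsAndBytesConfig object. Below is a minimal sketch of the equivalent call, assuming a transformers version that ships BitsAndBytesConfig, a CUDA device, and bitsandbytes installed; CHECKPOINT_PATH is the path app.py already defines.

    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    # 8-bit weight quantization; bitsandbytes handles the underlying
    # types, so torch_dtype is left unset, as in the commit.
    quant_config = BitsAndBytesConfig(load_in_8bit=True)
    text_model = AutoModelForCausalLM.from_pretrained(
        CHECKPOINT_PATH / "text_model",  # same checkpoint path as in app.py
        device_map="auto",
        quantization_config=quant_config,
    )

The loading behavior is unchanged; the config object simply survives the kwarg deprecation.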
 
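One caveat with the disk-offload fallback above: TemporaryDirectory().name drops the only reference to the TemporaryDirectory object, so CPython's refcounting runs its cleanup finalizer immediately and deletes the directory it just created; accelerate may recreate the path when it writes offloaded weights, but relying on that is fragile. A safer sketch, using only the standard library, is to keep the object referenced for the process lifetime, or to use tempfile.mkdtemp(), which performs no automatic cleanup (the prefix below is just an illustrative choice):

    import tempfile

    # Option 1: hold a long-lived reference so the cleanup finalizer
    # cannot run while the offload folder is still in use.
    _offload_tmp = tempfile.TemporaryDirectory()
    model_offload_dir = _offload_tmp.name

    # Option 2: mkdtemp() creates the directory but never deletes it.
    model_offload_dir = tempfile.mkdtemp(prefix="model_offload_")

Either way, offload_folder=model_offload_dir stays valid for as long as the server runs.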