Update app.py
Browse files
app.py
CHANGED
@@ -8,6 +8,7 @@ import os
|
|
8 |
import torchvision.transforms.functional as TVF
|
9 |
import io
|
10 |
import json # For parsing extra_options_json
|
|
|
11 |
|
12 |
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
|
13 |
from pydantic import BaseModel
|
@@ -164,8 +165,46 @@ assert isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTr
|
|
164 |
# LLM
|
165 |
print("Loading LLM")
|
166 |
print("Loading VLM's custom text model")
|
167 |
-
# Use device_map="auto"
|
168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
text_model.eval()
|
170 |
|
171 |
# Image Adapter
|
@@ -365,4 +404,4 @@ async def caption_image_endpoint(
|
|
365 |
|
366 |
if __name__ == "__main__":
|
367 |
import uvicorn
|
368 |
-
uvicorn.run(app, host="0.0.0.0", port=
|
|
|
8 |
import torchvision.transforms.functional as TVF
|
9 |
import io
|
10 |
import json # For parsing extra_options_json
|
11 |
+
from tempfile import TemporaryDirectory # For offload_folder
|
12 |
|
13 |
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
|
14 |
from pydantic import BaseModel
|
|
|
165 |
# LLM
|
166 |
print("Loading LLM")
|
167 |
print("Loading VLM's custom text model")
|
168 |
+
# Use device_map="auto" and load_in_8bit for quantization to reduce memory footprint
|
169 |
+
try:
|
170 |
+
text_model = AutoModelForCausalLM.from_pretrained(
|
171 |
+
CHECKPOINT_PATH / "text_model",
|
172 |
+
device_map="auto",
|
173 |
+
load_in_8bit=True # Enable 8-bit quantization
|
174 |
+
# torch_dtype is generally not specified with load_in_8bit,
|
175 |
+
# as bitsandbytes handles the underlying types.
|
176 |
+
)
|
177 |
+
except Exception as e:
|
178 |
+
print(f"Failed to load model with 8-bit quantization: {e}")
|
179 |
+
print("Attempting to load without 8-bit quantization (this may fail due to memory or require offloading)...")
|
180 |
+
# Fallback or alternative loading strategy can be placed here if needed
|
181 |
+
# For now, let it re-raise or try a different approach if the primary fails.
|
182 |
+
# As a simple fallback for now, try original loading which might hit the offload error
|
183 |
+
try:
|
184 |
+
text_model = AutoModelForCausalLM.from_pretrained(
|
185 |
+
CHECKPOINT_PATH / "text_model",
|
186 |
+
device_map="auto",
|
187 |
+
torch_dtype=torch.bfloat16 # Try with bfloat16 first
|
188 |
+
)
|
189 |
+
except ValueError as ve:
|
190 |
+
if "offload_dir" in str(ve): # Check if the error is about needing offload_dir
|
191 |
+
print(f"Original loading failed with ValueError (likely needing offload_dir): {ve}")
|
192 |
+
print("Attempting to load model with disk offloading...")
|
193 |
+
model_offload_dir = TemporaryDirectory().name
|
194 |
+
text_model = AutoModelForCausalLM.from_pretrained(
|
195 |
+
CHECKPOINT_PATH / "text_model",
|
196 |
+
device_map="auto",
|
197 |
+
torch_dtype=torch.bfloat16, # Keep bfloat16 if possible
|
198 |
+
offload_folder=model_offload_dir,
|
199 |
+
offload_state_dict=True # Recommended when offloading
|
200 |
+
)
|
201 |
+
print(f"Model loaded with offloading to {model_offload_dir}. WARNING: This will be very slow.")
|
202 |
+
else:
|
203 |
+
raise # Re-raise other ValueErrors
|
204 |
+
except Exception as final_e: # Catch any other exceptions during the last fallback attempt
|
205 |
+
print(f"All model loading attempts failed. Last error: {final_e}")
|
206 |
+
raise
|
207 |
+
|
208 |
text_model.eval()
|
209 |
|
210 |
# Image Adapter
|
|
|
404 |
|
405 |
if __name__ == "__main__":
|
406 |
import uvicorn
|
407 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|