Update app.py
app.py
CHANGED
@@ -119,6 +119,10 @@ class ImageAdapter(nn.Module):
 
 
 
+# Determine device
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(f"Using device: {device}")
+
 # Load CLIP
 print("Loading CLIP")
 clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
@@ -134,7 +138,9 @@ del checkpoint
 
 clip_model.eval()
 clip_model.requires_grad_(False)
-
+if device.type == 'cuda':
+    clip_model = clip_model.to(dtype=torch.bfloat16) # Convert to bfloat16 if on CUDA
+clip_model.to(device)
 
 
 # Tokenizer
@@ -145,21 +151,25 @@ assert isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTr
 # LLM
 print("Loading LLM")
 print("Loading VLM's custom text model")
-
+# Use device_map="auto" to allow accelerate to handle model placement, including CPU
+text_model = AutoModelForCausalLM.from_pretrained(CHECKPOINT_PATH / "text_model", device_map="auto", torch_dtype=torch.bfloat16)
 text_model.eval()
 
 # Image Adapter
 print("Loading image adapter")
 image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size, False, False, 38, False)
-image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu"))
+image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu")) # Load to CPU first
 image_adapter.eval()
-
+if device.type == 'cuda':
+    image_adapter = image_adapter.to(dtype=torch.bfloat16) # Convert to bfloat16 if on CUDA
+image_adapter.to(device)
 
 
-@spaces.GPU()
+@spaces.GPU() # We keep this decorator for now, assuming GPU is preferred if available
 @torch.no_grad()
 def stream_chat(input_image: Image.Image, caption_type: str, caption_length: str | int, extra_options: list[str], name_input: str, custom_prompt: str) -> tuple[str, str]:
-
+    if device.type == "cuda":
+        torch.cuda.empty_cache()
 
     # 'any' means no length specified
     length = None if caption_length == "any" else caption_length
@@ -201,14 +211,32 @@ def stream_chat(input_image: Image.Image, caption_type: str, caption_length: str
     image = input_image.resize((384, 384), Image.LANCZOS)
     pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
     pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
-
+    if device.type == 'cuda':
+        pixel_values = pixel_values.to(device, dtype=torch.bfloat16)
+    else:
+        pixel_values = pixel_values.to(device) # CPU will use float32
 
     # Embed image
     # This results in Batch x Image Tokens x Features
-
+    # For CPU, autocast can use bfloat16 if available and beneficial, or can be disabled.
+    # For simplicity here, we'll enable it for CPU with bfloat16 if PyTorch supports it, else float32.
+    # Note: True CPU mixed precision benefits depend on CPU architecture and PyTorch version.
+    autocast_device_type = device.type
+    autocast_kwargs = {'enabled': True}
+    if autocast_device_type == 'cpu':
+        # Check if bfloat16 is supported for CPU autocast, otherwise default to float32 by not specifying dtype
+        # This check might be more involved depending on PyTorch version; for now, let's assume it handles it.
+        # Or, explicitly set dtype if needed: autocast_kwargs['dtype'] = torch.bfloat16
+        pass
+
+
+    with torch.amp.autocast_mode.autocast(autocast_device_type, **autocast_kwargs):
         vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
         embedded_images = image_adapter(vision_outputs.hidden_states)
-        embedded_images
+        # embedded_images are already on the correct device due to image_adapter.to(device)
+        # and operations within adapter should respect input tensor's device.
+        # Explicitly moving again to be safe, though may be redundant.
+        embedded_images = embedded_images.to(device)
 
     # Build the conversation
     convo = [
@@ -228,34 +256,38 @@ def stream_chat(input_image: Image.Image, caption_type: str, caption_length: str
 
     # Tokenize the conversation
     # prompt_str is tokenized separately so we can do the calculations below
-    convo_tokens = tokenizer.encode(convo_string, return_tensors="pt", add_special_tokens=False, truncation=False)
-    prompt_tokens = tokenizer.encode(prompt_str, return_tensors="pt", add_special_tokens=False, truncation=False)
+    convo_tokens = tokenizer.encode(convo_string, return_tensors="pt", add_special_tokens=False, truncation=False).to(device)
+    prompt_tokens = tokenizer.encode(prompt_str, return_tensors="pt", add_special_tokens=False, truncation=False).to(device)
     assert isinstance(convo_tokens, torch.Tensor) and isinstance(prompt_tokens, torch.Tensor)
     convo_tokens = convo_tokens.squeeze(0) # Squeeze just to make the following easier
     prompt_tokens = prompt_tokens.squeeze(0)
 
     # Calculate where to inject the image
-
+    # Ensure convo_tokens is on the CPU for this kind of operation if it involves list conversion or complex indexing not ideal for GPU
+    eot_id_indices = (convo_tokens.cpu() == tokenizer.convert_tokens_to_ids("<|eot_id|>")).nonzero(as_tuple=True)[0].tolist()
     assert len(eot_id_indices) == 2, f"Expected 2 <|eot_id|> tokens, got {len(eot_id_indices)}"
 
     preamble_len = eot_id_indices[1] - prompt_tokens.shape[0] # Number of tokens before the prompt
 
     # Embed the tokens
-    convo_embeds = text_model.model.embed_tokens(convo_tokens.unsqueeze(0).to(
+    convo_embeds = text_model.model.embed_tokens(convo_tokens.unsqueeze(0).to(text_model.device)) # Ensure tokens are on same device as text_model
 
     # Construct the input
+    # Ensure all parts are on the same device before concatenation
     input_embeds = torch.cat([
         convo_embeds[:, :preamble_len], # Part before the prompt
-        embedded_images.to(dtype=convo_embeds.dtype), # Image
+        embedded_images.to(dtype=convo_embeds.dtype, device=convo_embeds.device), # Image, ensure same dtype and device
         convo_embeds[:, preamble_len:], # The prompt and anything after it
-    ], dim=1)
+    ], dim=1)
+    # input_embeds will be on the device of convo_embeds (i.e. text_model.device)
 
     input_ids = torch.cat([
         convo_tokens[:preamble_len].unsqueeze(0),
-        torch.zeros((1, embedded_images.shape[1]), dtype=torch.long), # Dummy tokens for the image
+        torch.zeros((1, embedded_images.shape[1]), dtype=torch.long, device=convo_tokens.device), # Dummy tokens for the image
         convo_tokens[preamble_len:].unsqueeze(0),
-    ], dim=1)
-
+    ], dim=1)
+    # input_ids will be on the device of convo_tokens
+    attention_mask = torch.ones_like(input_ids) # Will be on the same device as input_ids
 
     # Debugging
     print(f"Input to model: {repr(tokenizer.decode(input_ids[0]))}")
@@ -277,6 +309,9 @@ def stream_chat(input_image: Image.Image, caption_type: str, caption_length: str
 with gr.Blocks() as demo:
     gr.HTML(TITLE)
 
+    if device.type == 'cpu':
+        gr.Markdown("**Warning: Running on CPU.** Captions may take a very long time to generate (potentially several minutes). For faster performance, please use a Space with GPU hardware.")
+
     with gr.Row():
         with gr.Column():
             input_image = gr.Image(type="pil", label="Input Image")
@@ -333,4 +368,4 @@ with gr.Blocks() as demo:
 
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
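Note on the CPU autocast branch: the added block leaves the bfloat16 decision as a `pass`, with a comment that the check "might be more involved depending on PyTorch version". Below is a minimal sketch of how that probe could be made explicit; the helper name `cpu_bf16_autocast_supported` and the try/except probe are illustrative assumptions, not part of this commit, and older PyTorch builds may fall back silently rather than raise.

```python
import torch

def cpu_bf16_autocast_supported() -> bool:
    """Probe whether this PyTorch build can run a bfloat16 autocast region on CPU."""
    try:
        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            torch.mm(torch.ones(2, 2), torch.ones(2, 2))  # tiny op inside the autocast region
        return True
    except Exception:
        return False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
autocast_kwargs = {"enabled": True}
if device.type == "cpu" and cpu_bf16_autocast_supported():
    autocast_kwargs["dtype"] = torch.bfloat16  # the option the commit leaves commented out

# Same call shape as in app.py:
# with torch.amp.autocast_mode.autocast(device.type, **autocast_kwargs):
#     ...
```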
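Note on `.to(text_model.device)`: the commit embeds the conversation tokens on the text model's own device rather than the global `device`, because `device_map="auto"` hands placement to accelerate and the model may not sit where `device` points. A small inspection sketch under that assumption; `MODEL_ID` is a placeholder checkpoint, not the `CHECKPOINT_PATH / "text_model"` weights app.py loads, and `accelerate` must be installed for `device_map="auto"`.

```python
import torch
from transformers import AutoModelForCausalLM

MODEL_ID = "gpt2"  # placeholder model just for the demonstration

text_model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16)

# Where accelerate actually placed the weights: {'': 'cpu'} on a CPU-only box,
# or a per-module map when weights are spread across devices.
print(text_model.hf_device_map)

# Inputs to the embedding layer must live on the model's device, which is why the
# diff moves token ids with .to(text_model.device) instead of the global `device`.
print(text_model.device)
```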