royyy74 committed
Commit 656168e · verified · 1 Parent(s): 8a4f7ca

Update app.py

Files changed (1):
  1. app.py +54 -19
app.py CHANGED
@@ -119,6 +119,10 @@ class ImageAdapter(nn.Module):
 
 
 
+# Determine device
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(f"Using device: {device}")
+
 # Load CLIP
 print("Loading CLIP")
 clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
@@ -134,7 +138,9 @@ del checkpoint
 
 clip_model.eval()
 clip_model.requires_grad_(False)
-clip_model.to("cuda")
+if device.type == 'cuda':
+    clip_model = clip_model.to(dtype=torch.bfloat16) # Convert to bfloat16 if on CUDA
+clip_model.to(device)
 
 
 # Tokenizer
@@ -145,21 +151,25 @@ assert isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTr
 # LLM
 print("Loading LLM")
 print("Loading VLM's custom text model")
-text_model = AutoModelForCausalLM.from_pretrained(CHECKPOINT_PATH / "text_model", device_map=0, torch_dtype=torch.bfloat16)
+# Use device_map="auto" to allow accelerate to handle model placement, including CPU
+text_model = AutoModelForCausalLM.from_pretrained(CHECKPOINT_PATH / "text_model", device_map="auto", torch_dtype=torch.bfloat16)
 text_model.eval()
 
 # Image Adapter
 print("Loading image adapter")
 image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size, False, False, 38, False)
-image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu"))
+image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu")) # Load to CPU first
 image_adapter.eval()
-image_adapter.to("cuda")
+if device.type == 'cuda':
+    image_adapter = image_adapter.to(dtype=torch.bfloat16) # Convert to bfloat16 if on CUDA
+image_adapter.to(device)
 
 
-@spaces.GPU()
+@spaces.GPU() # We keep this decorator for now, assuming GPU is preferred if available
 @torch.no_grad()
 def stream_chat(input_image: Image.Image, caption_type: str, caption_length: str | int, extra_options: list[str], name_input: str, custom_prompt: str) -> tuple[str, str]:
-    torch.cuda.empty_cache()
+    if device.type == "cuda":
+        torch.cuda.empty_cache()
 
     # 'any' means no length specified
     length = None if caption_length == "any" else caption_length
@@ -201,14 +211,32 @@ def stream_chat(input_image: Image.Image, caption_type: str, caption_length: str
     image = input_image.resize((384, 384), Image.LANCZOS)
     pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
     pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
-    pixel_values = pixel_values.to('cuda')
+    if device.type == 'cuda':
+        pixel_values = pixel_values.to(device, dtype=torch.bfloat16)
+    else:
+        pixel_values = pixel_values.to(device) # CPU will use float32
 
     # Embed image
     # This results in Batch x Image Tokens x Features
-    with torch.amp.autocast_mode.autocast('cuda', enabled=True):
+    # For CPU, autocast can use bfloat16 if available and beneficial, or can be disabled.
+    # For simplicity here, we'll enable it for CPU with bfloat16 if PyTorch supports it, else float32.
+    # Note: True CPU mixed precision benefits depend on CPU architecture and PyTorch version.
+    autocast_device_type = device.type
+    autocast_kwargs = {'enabled': True}
+    if autocast_device_type == 'cpu':
+        # Check if bfloat16 is supported for CPU autocast, otherwise default to float32 by not specifying dtype
+        # This check might be more involved depending on PyTorch version; for now, let's assume it handles it.
+        # Or, explicitly set dtype if needed: autocast_kwargs['dtype'] = torch.bfloat16
+        pass
+
+
+    with torch.amp.autocast_mode.autocast(autocast_device_type, **autocast_kwargs):
         vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
         embedded_images = image_adapter(vision_outputs.hidden_states)
-    embedded_images = embedded_images.to('cuda')
+    # embedded_images are already on the correct device due to image_adapter.to(device)
+    # and operations within adapter should respect input tensor's device.
+    # Explicitly moving again to be safe, though may be redundant.
+    embedded_images = embedded_images.to(device)
 
     # Build the conversation
     convo = [
@@ -228,34 +256,38 @@ def stream_chat(input_image: Image.Image, caption_type: str, caption_length: str
 
     # Tokenize the conversation
     # prompt_str is tokenized separately so we can do the calculations below
-    convo_tokens = tokenizer.encode(convo_string, return_tensors="pt", add_special_tokens=False, truncation=False)
-    prompt_tokens = tokenizer.encode(prompt_str, return_tensors="pt", add_special_tokens=False, truncation=False)
+    convo_tokens = tokenizer.encode(convo_string, return_tensors="pt", add_special_tokens=False, truncation=False).to(device)
+    prompt_tokens = tokenizer.encode(prompt_str, return_tensors="pt", add_special_tokens=False, truncation=False).to(device)
     assert isinstance(convo_tokens, torch.Tensor) and isinstance(prompt_tokens, torch.Tensor)
     convo_tokens = convo_tokens.squeeze(0) # Squeeze just to make the following easier
    prompt_tokens = prompt_tokens.squeeze(0)
 
     # Calculate where to inject the image
-    eot_id_indices = (convo_tokens == tokenizer.convert_tokens_to_ids("<|eot_id|>")).nonzero(as_tuple=True)[0].tolist()
+    # Ensure convo_tokens is on the CPU for this kind of operation if it involves list conversion or complex indexing not ideal for GPU
+    eot_id_indices = (convo_tokens.cpu() == tokenizer.convert_tokens_to_ids("<|eot_id|>")).nonzero(as_tuple=True)[0].tolist()
     assert len(eot_id_indices) == 2, f"Expected 2 <|eot_id|> tokens, got {len(eot_id_indices)}"
 
     preamble_len = eot_id_indices[1] - prompt_tokens.shape[0] # Number of tokens before the prompt
 
     # Embed the tokens
-    convo_embeds = text_model.model.embed_tokens(convo_tokens.unsqueeze(0).to('cuda'))
+    convo_embeds = text_model.model.embed_tokens(convo_tokens.unsqueeze(0).to(text_model.device)) # Ensure tokens are on same device as text_model
 
     # Construct the input
+    # Ensure all parts are on the same device before concatenation
     input_embeds = torch.cat([
         convo_embeds[:, :preamble_len], # Part before the prompt
-        embedded_images.to(dtype=convo_embeds.dtype), # Image
+        embedded_images.to(dtype=convo_embeds.dtype, device=convo_embeds.device), # Image, ensure same dtype and device
         convo_embeds[:, preamble_len:], # The prompt and anything after it
-    ], dim=1).to('cuda')
+    ], dim=1)
+    # input_embeds will be on the device of convo_embeds (i.e. text_model.device)
 
     input_ids = torch.cat([
         convo_tokens[:preamble_len].unsqueeze(0),
-        torch.zeros((1, embedded_images.shape[1]), dtype=torch.long), # Dummy tokens for the image (TODO: Should probably use a special token here so as not to confuse any generation algorithms that might be inspecting the input)
+        torch.zeros((1, embedded_images.shape[1]), dtype=torch.long, device=convo_tokens.device), # Dummy tokens for the image
         convo_tokens[preamble_len:].unsqueeze(0),
-    ], dim=1).to('cuda')
-    attention_mask = torch.ones_like(input_ids)
+    ], dim=1)
+    # input_ids will be on the device of convo_tokens
+    attention_mask = torch.ones_like(input_ids) # Will be on the same device as input_ids
 
     # Debugging
     print(f"Input to model: {repr(tokenizer.decode(input_ids[0]))}")
@@ -277,6 +309,9 @@ def stream_chat(input_image: Image.Image, caption_type: str, caption_length: str
 with gr.Blocks() as demo:
     gr.HTML(TITLE)
 
+    if device.type == 'cpu':
+        gr.Markdown("**Warning: Running on CPU.** Captions may take a very long time to generate (potentially several minutes). For faster performance, please use a Space with GPU hardware.")
+
     with gr.Row():
         with gr.Column():
             input_image = gr.Image(type="pil", label="Input Image")
@@ -333,4 +368,4 @@ with gr.Blocks() as demo:
 
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
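
Note on the CPU autocast question this commit leaves open ("This check might be more involved depending on PyTorch version"): one way to resolve it is to probe at startup whether the current PyTorch build actually produces bfloat16 results under CPU autocast. The sketch below is illustrative and not part of app.py; the helper name cpu_bf16_autocast_ok is an assumption, and it relies only on stock torch.autocast.

import torch

def cpu_bf16_autocast_ok() -> bool:
    # Run a tiny autocast-eligible op (matmul) under CPU bfloat16 autocast
    # and check whether the result actually comes back in bfloat16.
    try:
        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            y = torch.ones(8, 8) @ torch.ones(8, 8)
        return y.dtype == torch.bfloat16
    except RuntimeError:
        return False

# If the probe succeeds, the commit's autocast_kwargs could opt in explicitly:
#     autocast_kwargs['dtype'] = torch.bfloat16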
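
Likewise, switching to device_map="auto" means accelerate decides where the text model's weights end up, and the rest of the diff reads that back through text_model.device. A quick sanity check after loading, again a hedged sketch rather than part of the commit, could confirm the placement matches the chosen device:

from collections import Counter

# Count parameters per device; expect a single entry such as {'cuda:0': N} or {'cpu': N}.
print(Counter(str(p.device) for p in text_model.parameters()))
# text_model.device reports the device of the model's parameters; it is where
# the token embeddings above are computed before concatenation with the image tokens.
print(text_model.device)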