shilinxu commited on
Commit
6843f07
·
verified ·
1 Parent(s): a7c0ad8

Update modeling_qwen2_vl.py

Browse files
Files changed (1) hide show
  1. modeling_qwen2_vl.py +3 -3
modeling_qwen2_vl.py CHANGED
@@ -395,13 +395,13 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
395
  return rotary_pos_emb
396
 
397
  @auto_docstring
398
- def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
399
  r"""
400
  grid_thw (`torch.LongTensor` of shape `(num_images, 3)`):
401
  The temporal, height and width dimensions of feature shape for each image. Each row contains [t, h, w] values.
402
  """
403
- hidden_states = self.patch_embed(hidden_states)
404
- rotary_pos_emb = self.rot_pos_emb(grid_thw)
405
  emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
406
  position_embeddings = (emb.cos(), emb.sin())
407
 
 
395
  return rotary_pos_emb
396
 
397
  @auto_docstring
398
+ def forward(self, pixel_values: torch.Tensor, image_grid_thw: torch.Tensor) -> torch.Tensor:
399
  r"""
400
  grid_thw (`torch.LongTensor` of shape `(num_images, 3)`):
401
  The temporal, height and width dimensions of feature shape for each image. Each row contains [t, h, w] values.
402
  """
403
+ hidden_states = self.patch_embed(pixel_values)
404
+ rotary_pos_emb = self.rot_pos_emb(image_grid_thw)
405
  emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
406
  position_embeddings = (emb.cos(), emb.sin())
407