Update modeling_qwen2_vl.py
Browse files- modeling_qwen2_vl.py +3 -3
modeling_qwen2_vl.py
CHANGED
|
@@ -395,13 +395,13 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
|
|
| 395 |
return rotary_pos_emb
|
| 396 |
|
| 397 |
@auto_docstring
|
| 398 |
-
def forward(self,
|
| 399 |
r"""
|
| 400 |
grid_thw (`torch.LongTensor` of shape `(num_images, 3)`):
|
| 401 |
The temporal, height and width dimensions of feature shape for each image. Each row contains [t, h, w] values.
|
| 402 |
"""
|
| 403 |
-
hidden_states = self.patch_embed(
|
| 404 |
-
rotary_pos_emb = self.rot_pos_emb(
|
| 405 |
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
|
| 406 |
position_embeddings = (emb.cos(), emb.sin())
|
| 407 |
|
|
|
|
| 395 |
return rotary_pos_emb
|
| 396 |
|
| 397 |
@auto_docstring
|
| 398 |
+
def forward(self, pixel_values: torch.Tensor, image_grid_thw: torch.Tensor) -> torch.Tensor:
|
| 399 |
r"""
|
| 400 |
grid_thw (`torch.LongTensor` of shape `(num_images, 3)`):
|
| 401 |
The temporal, height and width dimensions of feature shape for each image. Each row contains [t, h, w] values.
|
| 402 |
"""
|
| 403 |
+
hidden_states = self.patch_embed(pixel_values)
|
| 404 |
+
rotary_pos_emb = self.rot_pos_emb(image_grid_thw)
|
| 405 |
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
|
| 406 |
position_embeddings = (emb.cos(), emb.sin())
|
| 407 |
|