Update modeling_hyperclovax.py
Browse files- modeling_hyperclovax.py +1 -3
modeling_hyperclovax.py
CHANGED
@@ -1132,9 +1132,7 @@ class HCXVisionForCausalLM(PreTrainedModel, GenerationMixin):
|
|
1132 |
first_last_frames_slows=first_last_frames_slows,
|
1133 |
is_videos=is_videos,
|
1134 |
)
|
1135 |
-
inputs_embeds = (
|
1136 |
-
inputs_embeds.to(self.base_model.device) if isinstance(inputs_embeds, torch.Tensor) else inputs_embeds
|
1137 |
-
)
|
1138 |
|
1139 |
# pred : torch.int64 : [batchsize, generated token_length]
|
1140 |
pred = self.language_model.generate(
|
|
|
1132 |
first_last_frames_slows=first_last_frames_slows,
|
1133 |
is_videos=is_videos,
|
1134 |
)
|
1135 |
+
inputs_embeds = inputs_embeds.to(device=self.language_model.device, dtype=self.language_model.dtype)
|
|
|
|
|
1136 |
|
1137 |
# pred : torch.int64 : [batchsize, generated token_length]
|
1138 |
pred = self.language_model.generate(
|