fix dtype missmatchint input and model's weight
#15
by
HERIUN
- opened
- modeling_hyperclovax.py +2 -1
modeling_hyperclovax.py
CHANGED
|
@@ -1135,7 +1135,8 @@ class HCXVisionForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 1135 |
inputs_embeds = (
|
| 1136 |
inputs_embeds.to(self.base_model.device) if isinstance(inputs_embeds, torch.Tensor) else inputs_embeds
|
| 1137 |
)
|
| 1138 |
-
|
|
|
|
| 1139 |
# pred : torch.int64 : [batchsize, generated token_length]
|
| 1140 |
pred = self.language_model.generate(
|
| 1141 |
inputs_embeds=inputs_embeds,
|
|
|
|
| 1135 |
inputs_embeds = (
|
| 1136 |
inputs_embeds.to(self.base_model.device) if isinstance(inputs_embeds, torch.Tensor) else inputs_embeds
|
| 1137 |
)
|
| 1138 |
+
|
| 1139 |
+
inputs_embeds = inputs_embeds.to(dtype=self.base_model.dtype)
|
| 1140 |
# pred : torch.int64 : [batchsize, generated token_length]
|
| 1141 |
pred = self.language_model.generate(
|
| 1142 |
inputs_embeds=inputs_embeds,
|