fix dtype missmatchint input and model's weight
Browse files- modeling_hyperclovax.py +2 -1
modeling_hyperclovax.py
CHANGED
@@ -1135,7 +1135,8 @@ class HCXVisionForCausalLM(PreTrainedModel, GenerationMixin):
|
|
1135 |
inputs_embeds = (
|
1136 |
inputs_embeds.to(self.base_model.device) if isinstance(inputs_embeds, torch.Tensor) else inputs_embeds
|
1137 |
)
|
1138 |
-
|
|
|
1139 |
# pred : torch.int64 : [batchsize, generated token_length]
|
1140 |
pred = self.language_model.generate(
|
1141 |
inputs_embeds=inputs_embeds,
|
|
|
1135 |
inputs_embeds = (
|
1136 |
inputs_embeds.to(self.base_model.device) if isinstance(inputs_embeds, torch.Tensor) else inputs_embeds
|
1137 |
)
|
1138 |
+
|
1139 |
+
inputs_embeds = inputs_embeds.to(dtype=self.base_model.dtype)
|
1140 |
# pred : torch.int64 : [batchsize, generated token_length]
|
1141 |
pred = self.language_model.generate(
|
1142 |
inputs_embeds=inputs_embeds,
|