HERIUN commited on
Commit
3bca6fe
·
verified ·
1 Parent(s): 41a5265

fix dtype missmatchint input and model's weight

Browse files
Files changed (1) hide show
  1. modeling_hyperclovax.py +2 -1
modeling_hyperclovax.py CHANGED
@@ -1135,7 +1135,8 @@ class HCXVisionForCausalLM(PreTrainedModel, GenerationMixin):
1135
  inputs_embeds = (
1136
  inputs_embeds.to(self.base_model.device) if isinstance(inputs_embeds, torch.Tensor) else inputs_embeds
1137
  )
1138
-
 
1139
  # pred : torch.int64 : [batchsize, generated token_length]
1140
  pred = self.language_model.generate(
1141
  inputs_embeds=inputs_embeds,
 
1135
  inputs_embeds = (
1136
  inputs_embeds.to(self.base_model.device) if isinstance(inputs_embeds, torch.Tensor) else inputs_embeds
1137
  )
1138
+
1139
+ inputs_embeds = inputs_embeds.to(dtype=self.base_model.dtype)
1140
  # pred : torch.int64 : [batchsize, generated token_length]
1141
  pred = self.language_model.generate(
1142
  inputs_embeds=inputs_embeds,