fixes the asserion error when num_beams > 1

#42
by freewym - opened
Files changed (1) hide show
  1. modeling_phi4mm.py +1 -1
modeling_phi4mm.py CHANGED
@@ -2096,7 +2096,7 @@ class Phi4MMForCausalLM(Phi4MMPreTrainedModel, GenerationMixin):
2096
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
2097
 
2098
  if isinstance(input_mode, torch.Tensor):
2099
- assert len(input_mode) == 1
2100
  input_mode = input_mode[0].item()
2101
  input_mode = InputMode(input_mode)
2102
 
 
2096
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
2097
 
2098
  if isinstance(input_mode, torch.Tensor):
2099
+ # len(input_mode) == num_beams in beam search, and all elements of input_mode should have the same value
2100
  input_mode = input_mode[0].item()
2101
  input_mode = InputMode(input_mode)
2102