KingNish committed
Commit c40e1ba · verified · 1 Parent(s): 050df43

Update modeling/bagel/bagel.py

Files changed (1)
  1. modeling/bagel/bagel.py +9 -22
modeling/bagel/bagel.py CHANGED
@@ -897,13 +897,9 @@ class Bagel(PreTrainedModel):
         the behavior of the original batch generation function, including the handling
         of start tokens and the end-of-sequence token.
         """
+        step = 0
         curr_tokens = packed_start_tokens
-
-        for _ in range(max_length):
-            # The original function would append `curr_tokens` to a list at this point.
-            # Instead, we yield it to the caller, enabling streaming.
-            yield curr_tokens
-
+        while step < max_length:
             packed_text_embedding = self.language_model.model.embed_tokens(curr_tokens)
             query_lens = torch.ones_like(curr_tokens)
             packed_query_indexes = torch.cumsum(key_values_lens, dim=0) + torch.arange(
@@ -912,9 +908,6 @@ class Bagel(PreTrainedModel):
                 dtype=key_values_lens.dtype
             )
 
-            # This block modifies packed_key_value_indexes before the forward pass,
-            # preserving the specific logic for NaViT-style packed inputs.
-            # The typo 'uppacked' is kept to match the original source code.
             uppacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0))
             for i in range(len(uppacked)):
                 uppacked[i] += i
@@ -940,20 +933,12 @@ class Bagel(PreTrainedModel):
             packed_query_sequence = output.packed_query_sequence
             pred_logits = self.language_model.lm_head(packed_query_sequence)
 
-            # Sample the next token
             if do_sample:
                 probs = nn.functional.softmax(pred_logits / temperature, dim=-1)
-                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+                curr_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
             else:
-                next_tokens = torch.argmax(pred_logits, dim=-1)
-
-            # The stop condition is checked on the newly generated token. If it's the
-            # end token, we break the loop. This token will not be yielded.
-            if end_token_id is not None and next_tokens[0] == end_token_id: # only support batch=1
-                break
+                curr_tokens = torch.argmax(pred_logits, dim=-1)
 
-            # This block updates the state variables for the next iteration. It reads
-            # the already-modified `packed_key_value_indexes` and updates it further.
             uppacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0))
             for i in range(len(uppacked)):
                 uppacked[i] = torch.cat(
@@ -962,10 +947,12 @@ class Bagel(PreTrainedModel):
             packed_key_value_indexes = torch.cat(uppacked, dim=0)
             key_values_lens = key_values_lens + 1
             packed_query_position_ids = packed_query_position_ids + 1
-
-            # The newly generated token becomes the input for the next loop iteration.
-            curr_tokens = next_tokens
+            step += 1
+
+            yield curr_tokens # Yield each token as it's generated
 
+            if end_token_id is not None and curr_tokens[0] == end_token_id: # only support batch=1
+                break
     # for evaluation
     @torch.no_grad()
     def chat(
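
For reference, a minimal sketch of how a caller might consume the streaming generator edited above. The method name `generate_text`, the `model` and `tokenizer` objects, and the keyword arguments are assumptions for illustration; only the one-tensor-per-step yield behavior and the batch=1 limitation come from the diff itself.

import torch

@torch.no_grad()
def stream_decode(model, tokenizer, end_token_id=None, **generation_kwargs):
    # `model.generate_text` is a hypothetical name for the generator changed in this
    # commit; it is assumed to yield a 1-element tensor of token ids each step (batch=1).
    generated_ids = []
    for curr_tokens in model.generate_text(**generation_kwargs):
        token_id = curr_tokens[0].item()
        # After this commit the end-of-sequence token is yielded before the loop
        # breaks, so the consumer may want to filter it out when decoding.
        if end_token_id is not None and token_id == end_token_id:
            break
        generated_ids.append(token_id)
        print(tokenizer.decode([token_id]), end="", flush=True)  # incremental output
    return tokenizer.decode(generated_ids)

Note the behavioral difference visible in the last hunk: the old code checked the stop condition before yielding, so the end token was never emitted, while the new code yields each token first and then breaks, so the end token is also streamed to the caller.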