Text Generation · Transformers · PyTorch · chatts · feature-extraction · conversational · custom_code
xiezhe24 committed (verified)
Commit 1dd465d · Parent: db0db5e

Update modeling_qwen2.py (#7)


- Update modeling_qwen2.py (1f8297368338fdac2bf4306e7f6dd98d26d20d46)

Files changed (1):
  1. modeling_qwen2.py +9 -2
modeling_qwen2.py CHANGED:

@@ -1450,6 +1450,9 @@ class Qwen2TSForCausalLM(Qwen2PreTrainedModel):
             attention_mask=attention_mask
         )
 
+    def _extract_past_from_model_output(self, outputs: ModelOutput):
+        return "past_key_values", outputs.past_key_values
+
     def _update_model_kwargs_for_generation(
         self,
         outputs: ModelOutput,
@@ -1505,8 +1508,12 @@ class Qwen2TSForCausalLM(Qwen2PreTrainedModel):
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
-                past_length = past_key_values.seen_tokens
-                max_cache_length = past_key_values.get_max_length()
+                past_length = past_key_values.seen_tokens
+                max_cache_length = (
+                    past_key_values.get_max_length()
+                    if hasattr(past_key_values, "get_max_length")
+                    else past_key_values.get_max_cache_shape()
+                )
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None
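For context, a plausible reading of this patch (not an official note from the repo): `Cache.get_max_length()` was deprecated in newer transformers releases in favor of `Cache.get_max_cache_shape()`, so the `hasattr` guard lets this custom model code run on either side of that rename. Likewise, the `_extract_past_from_model_output` override pins the `(cache_name, cache)` return shape that the model's custom `_update_model_kwargs_for_generation` consumes, insulating it from signature drift in that private `GenerationMixin` hook. Below is a minimal, self-contained sketch of the same guard pattern; the helper name `max_cache_length_compat` is illustrative, not from the repo.

```python
from typing import Optional

from transformers.cache_utils import DynamicCache


def max_cache_length_compat(cache) -> Optional[int]:
    # Older transformers expose Cache.get_max_length(); newer releases
    # replace it with Cache.get_max_cache_shape(). Probe for the legacy
    # accessor first, exactly as the patched modeling_qwen2.py does.
    if hasattr(cache, "get_max_length"):
        return cache.get_max_length()
    return cache.get_max_cache_shape()


cache = DynamicCache()
# A DynamicCache grows on demand and has no fixed capacity, so either
# accessor reports no maximum.
print(max_cache_length_compat(cache))  # None
```

The same probe-then-fall-back shape is a common way to keep `custom_code` checkpoints loadable across a range of transformers versions, since remote code cannot pin the library version installed on the user's machine.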