Fix code to comport with newer Transformers library

#41
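The patch replaces cache introspection that newer Transformers releases no longer expose (`past_key_values.seen_tokens`, `past_key_values.get_max_length()`) with methods the current `Cache` API still provides, and it passes an explicit `attention_mask` through to `generate()`. A minimal sketch of the replacement calls, assuming a recent Transformers release (the variable names below are illustrative, not taken from the patch):

    from transformers.cache_utils import DynamicCache

    cache = DynamicCache()
    cached_len = cache.get_seq_length()      # replaces reading `cache.seen_tokens`
    max_shape = cache.get_max_cache_shape()  # replaces `cache.get_max_length()`; returns None for a DynamicCache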
Files changed (1)
  1. modeling_GOT.py +39 -54
modeling_GOT.py CHANGED
@@ -393,59 +393,46 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
-        # Omit tokens covered by past_key_values
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device)
+
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
                 cache_length = past_key_values.get_seq_length()
-                past_length = past_key_values.seen_tokens
-                max_cache_length = past_key_values.get_max_length()
+                current_length = cache_length
+                max_cache_shape = past_key_values.get_max_cache_shape()
+                max_cache_length = max_cache_shape[1] if max_cache_shape else None
             else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
+                cache_length = past_key_values[0][0].shape[2]
+                current_length = cache_length
                 max_cache_length = None
 
-            # Keep only the unprocessed tokens:
-            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
-            # input)
-            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-            # input_ids based on the past_length.
-            elif past_length < input_ids.shape[1]:
-                input_ids = input_ids[:, past_length:]
-            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-
-            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
-            if (
-                max_cache_length is not None
-                and attention_mask is not None
-                and cache_length + input_ids.shape[1] > max_cache_length
-            ):
-                attention_mask = attention_mask[:, -max_cache_length:]
+            if attention_mask.shape[1] > input_ids.shape[1]:
+                input_ids = input_ids[:, -(attention_mask.shape[1] - cache_length):]
+            elif cache_length < input_ids.shape[1]:
+                input_ids = input_ids[:, cache_length:]
+
+            if max_cache_length is not None and attention_mask is not None:
+                if cache_length + input_ids.shape[1] > max_cache_length:
+                    attention_mask = attention_mask[:, -max_cache_length:]
 
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
-            model_inputs = {"inputs_embeds": inputs_embeds}
-        else:
-            model_inputs = {"input_ids": input_ids}
-
-        model_inputs.update(
-            {
-                "position_ids": position_ids,
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "attention_mask": attention_mask,
-                "images": kwargs.get("images", None),
-            }
-        )
+                position_ids = position_ids[:, -input_ids.shape[1]:]
+
+        model_inputs = {
+            "input_ids": input_ids if inputs_embeds is None or past_key_values is not None else None,
+            "inputs_embeds": inputs_embeds if past_key_values is None else None,
+            "past_key_values": past_key_values,
+            "position_ids": position_ids,
+            "attention_mask": attention_mask,
+            "images": kwargs.get("images", None),
+            "use_cache": kwargs.get("use_cache", True)
+        }
+
         return model_inputs
 
     def initialize_vision_tokenizer(
@@ -536,7 +523,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
 
         conv_mpt = Conversation(
             system="""<|im_start|>system
-You should follow the instructions carefully and explain your answers in detail.""",
+You should follow the instructions carefully and explain your answers in detail.""",
             # system = None,
             roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
             version="mpt",
@@ -728,7 +715,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         return processed_images
 
 
-    def chat_crop(self, tokenizer, image_file, ocr_type, render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
+    def chat_crop(self, tokenizer, image_file, ocr_type, render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag=False):
         # Model
         self.disable_torch_init()
         multi_page=False
@@ -778,21 +765,18 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
             image_tensor_1 = image_processor_high(image)
             image_list.append(image_tensor_1)
 
-
         image_list = torch.stack(image_list)
 
-        print('====new images batch size======: \n',image_list.shape)
-
+        # print('====new images batch size======: \n',image_list.shape)
 
         if use_im_start_end:
             qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len*ll + DEFAULT_IM_END_TOKEN + '\n' + qs
         else:
             qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
 
-
         conv_mpt = Conversation(
             system="""<|im_start|>system
-You should follow the instructions carefully and explain your answers in detail.""",
+You should follow the instructions carefully and explain your answers in detail.""",
             # system = None,
             roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
             version="mpt",
@@ -811,8 +795,8 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
             print(prompt)
 
         inputs = tokenizer([prompt])
-
         input_ids = torch.as_tensor(inputs.input_ids).cuda()
+        attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device)
 
         stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
         keywords = [stop_str]
@@ -824,25 +808,26 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                 output_ids = self.generate(
                     input_ids,
                     images=[image_list.half().cuda()],
+                    attention_mask=attention_mask,
                     do_sample=False,
-                    num_beams = 1,
-                    # no_repeat_ngram_size = 20,
                     streamer=streamer,
+                    num_beams=1,
                     max_new_tokens=4096,
                     stopping_criteria=[stopping_criteria]
-                    )
+                )
+
         else:
             with torch.autocast("cuda", dtype=torch.bfloat16):
                 output_ids = self.generate(
                     input_ids,
                     images=[image_list.half().cuda()],
+                    attention_mask=attention_mask,
                     do_sample=False,
-                    num_beams = 1,
-                    # no_repeat_ngram_size = 20,
                     # streamer=streamer,
+                    num_beams=1,
                     max_new_tokens=4096,
                     stopping_criteria=[stopping_criteria]
-                    )
+                )
 
         outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
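For a quick smoke test of the patched `chat_crop` path, a usage sketch along the usual remote-code loading pattern; the model id `ucaslcl/GOT-OCR2_0`, the image path, and `ocr_type='format'` are assumptions for illustration, not part of this change:

    from transformers import AutoModel, AutoTokenizer

    # Model id assumed; point at your own checkout of the patched modeling_GOT.py if needed.
    tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
    model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True,
                                      low_cpu_mem_usage=True, use_safetensors=True).eval().cuda()

    # chat_crop tiles a large page into crops and runs OCR over them as one batch.
    result = model.chat_crop(tokenizer, 'page.jpg', ocr_type='format')
    print(result)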