Fix code to comport with newer Transformers library
#41
by ctranslate2-4you · opened

modeling_GOT.py  CHANGED  +39 -54
@@ -393,59 +393,46 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
-        # Omit tokens covered by past_key_values
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device)
+
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
                 cache_length = past_key_values.get_seq_length()
-                past_length = past_key_values.seen_tokens
-                max_cache_length = past_key_values.get_max_length()
+                current_length = cache_length
+                max_cache_shape = past_key_values.get_max_cache_shape()
+                max_cache_length = max_cache_shape[1] if max_cache_shape else None
             else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
+                cache_length = past_key_values[0][0].shape[2]
+                current_length = cache_length
                 max_cache_length = None
 
-            # Keep only the unprocessed tokens:
-            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
-            # input)
-            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-            # input_ids based on the past_length.
-            elif past_length < input_ids.shape[1]:
-                input_ids = input_ids[:, past_length:]
-            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-
-            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
-            if (
-                max_cache_length is not None
-                and attention_mask is not None
-                and cache_length + input_ids.shape[1] > max_cache_length
-            ):
-                attention_mask = attention_mask[:, -max_cache_length:]
+            if attention_mask.shape[1] > input_ids.shape[1]:
+                input_ids = input_ids[:, -(attention_mask.shape[1] - cache_length):]
+            elif cache_length < input_ids.shape[1]:
+                input_ids = input_ids[:, cache_length:]
+
+            if max_cache_length is not None and attention_mask is not None:
+                if cache_length + input_ids.shape[1] > max_cache_length:
+                    attention_mask = attention_mask[:, -max_cache_length:]
 
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
+                position_ids = position_ids[:, -input_ids.shape[1]:]
+
+        model_inputs = {
+            "input_ids": input_ids if inputs_embeds is None or past_key_values is not None else None,
+            "inputs_embeds": inputs_embeds if past_key_values is None else None,
+            "past_key_values": past_key_values,
+            "position_ids": position_ids,
+            "attention_mask": attention_mask,
+            "images": kwargs.get("images", None),
+            "use_cache": kwargs.get("use_cache", True)
+        }
 
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
-            model_inputs = {"inputs_embeds": inputs_embeds}
-        else:
-            model_inputs = {"input_ids": input_ids}
-
-        model_inputs.update(
-            {
-                "position_ids": position_ids,
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "attention_mask": attention_mask,
-                "images": kwargs.get("images", None),
-            }
-        )
         return model_inputs
 
     def initialize_vision_tokenizer(
@@ -536,7 +523,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
 
         conv_mpt = Conversation(
             system="""<|im_start|>system
-        You should follow the instructions carefully and explain your answers in detail.""",
+You should follow the instructions carefully and explain your answers in detail.""",
             # system = None,
             roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
             version="mpt",
@@ -728,7 +715,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         return processed_images
 
 
-    def chat_crop(self, tokenizer, image_file, ocr_type, render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
+    def chat_crop(self, tokenizer, image_file, ocr_type, render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag=False):
         # Model
         self.disable_torch_init()
         multi_page=False
@@ -778,21 +765,18 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                 image_tensor_1 = image_processor_high(image)
                 image_list.append(image_tensor_1)
 
-
         image_list = torch.stack(image_list)
 
-        print('====new images batch size======: \n',image_list.shape)
-
+        # print('====new images batch size======: \n',image_list.shape)
 
         if use_im_start_end:
             qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len*ll + DEFAULT_IM_END_TOKEN + '\n' + qs
         else:
             qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
 
-
         conv_mpt = Conversation(
             system="""<|im_start|>system
-        You should follow the instructions carefully and explain your answers in detail.""",
+You should follow the instructions carefully and explain your answers in detail.""",
             # system = None,
             roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
             version="mpt",
@@ -811,8 +795,8 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
             print(prompt)
 
         inputs = tokenizer([prompt])
-
         input_ids = torch.as_tensor(inputs.input_ids).cuda()
+        attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device)
 
         stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
         keywords = [stop_str]
@@ -824,25 +808,26 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                 output_ids = self.generate(
                     input_ids,
                     images=[image_list.half().cuda()],
+                    attention_mask=attention_mask,
                     do_sample=False,
-                    num_beams = 1,
-                    # no_repeat_ngram_size = 20,
                     streamer=streamer,
+                    num_beams=1,
                     max_new_tokens=4096,
                     stopping_criteria=[stopping_criteria]
-                    )
+                )
+
         else:
             with torch.autocast("cuda", dtype=torch.bfloat16):
                 output_ids = self.generate(
                     input_ids,
                     images=[image_list.half().cuda()],
+                    attention_mask=attention_mask,
                     do_sample=False,
-                    num_beams = 1,
-                    # no_repeat_ngram_size = 20,
                     # streamer=streamer,
+                    num_beams=1,
                     max_new_tokens=4096,
                     stopping_criteria=[stopping_criteria]
-                    )
+                )
 
         outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
 
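For background on the cache handling rewritten in the first hunk: newer Transformers releases report the cached length through Cache.get_seq_length(), while the seen_tokens attribute and get_max_length() that the old code read are deprecated or gone. A minimal probe of that API, assuming transformers >= 4.36 with DynamicCache exported; the hasattr checks are only illustration, not part of the patch:

    from transformers import DynamicCache

    cache = DynamicCache()
    # Newer accessor the patched code relies on: 0 for an empty cache.
    print(cache.get_seq_length())
    # Older accessors the patch stops using; probe rather than call, since they
    # may be deprecated or absent depending on the installed version.
    print(hasattr(cache, "seen_tokens"), hasattr(cache, "get_max_length"))
    # Accessor the patch queries for a maximum cache size, where available.
    print(hasattr(cache, "get_max_cache_shape"))

The default DynamicCache has no fixed capacity, so the max_cache_length cropping branch mainly matters for bounded cache types such as StaticCache.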
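The explicit attention_mask added in chat_crop (and the all-ones fallback in prepare_inputs_for_generation) is valid because the prompt built there is a single, unpadded sequence: every position is a real token, and the position ids derived from the mask are simply 0..N-1. A small sketch with made-up token ids (not from the GOT tokenizer):

    import torch

    input_ids = torch.tensor([[11, 22, 33, 44, 55]])               # hypothetical unpadded prompt
    attention_mask = torch.ones_like(input_ids, dtype=torch.long)  # every position is a real token
    position_ids = attention_mask.long().cumsum(-1) - 1            # 0, 1, 2, 3, 4
    position_ids.masked_fill_(attention_mask == 0, 1)              # no-op here; matters with left padding
    print(position_ids)  # tensor([[0, 1, 2, 3, 4]])

With left-padded batches the mask contains zeros and the cumsum shifts the positions accordingly; passing the mask to generate() explicitly also helps avoid its missing-attention-mask warnings.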
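A quick end-to-end check of the patched chat_crop path, assuming the usage pattern from the GOT-OCR2_0 model card; the checkpoint id and image path below are placeholders, so substitute the repo that carries this patch:

    from transformers import AutoModel, AutoTokenizer

    model_id = "ucaslcl/GOT-OCR2_0"  # placeholder; point this at the repo the PR targets
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModel.from_pretrained(
        model_id,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        device_map="cuda",
        use_safetensors=True,
        pad_token_id=tokenizer.eos_token_id,
    ).eval().cuda()

    # chat_crop is the method touched above; ocr_type="format" requests formatted OCR output.
    result = model.chat_crop(tokenizer, "page.jpg", ocr_type="format")
    print(result)

On a newer Transformers install this exercises both the updated cache handling and the explicit attention_mask path added in this diff.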