# blip2zh-chatglm-6b / modeling_blip2chatglm.py
import copy
import os
from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import torch
from torch.nn import CrossEntropyLoss
from torch.nn.utils.rnn import pad_sequence
import warnings
from torch import Tensor, nn
from transformers import (
PreTrainedModel,
PreTrainedTokenizer,
Blip2VisionModel,
Blip2QFormerModel,
Blip2Model,
Blip2PreTrainedModel,
Blip2ForConditionalGeneration,
GenerationConfig,
)
from transformers.models.blip_2.modeling_blip_2 import (
Blip2ForConditionalGenerationModelOutput,
)
from transformers.utils import logging
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from .modeling_chatglm import (
ChatGLMForConditionalGeneration,
InvalidScoreLogitsProcessor,
)
from .configuration_blip2chatglm import Blip2ChatGLMConfig
logger = logging.get_logger(__name__)
class Blip2ChatGLMForConditionalGeneration(Blip2ForConditionalGeneration):
config_class = Blip2ChatGLMConfig
def __init__(self, config: Blip2ChatGLMConfig):
Blip2PreTrainedModel.__init__(self, config)
# NOTE: we only run Blip2PreTrainedModel.__init__ here; calling super().__init__()
# directly would fail because ChatGLM cannot be resolved by AutoModel
self.vision_model = Blip2VisionModel(config.vision_config)
self.query_tokens = nn.Parameter(
torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)
)
self.qformer = Blip2QFormerModel(config.qformer_config)
self.language_projection = nn.Linear(
config.qformer_config.hidden_size, config.text_config.hidden_size
)
self.language_model = ChatGLMForConditionalGeneration(config.text_config)
# Initialize weights and apply final processing (post_init intentionally left disabled)
# self.post_init()
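# Loading sketch (not from the original file; the repo id below is a placeholder and this
# assumes the repo's auto_map points AutoModel at this class):
#   from transformers import AutoModel, AutoTokenizer
#   model = AutoModel.from_pretrained("path/to/blip2zh-chatglm-6b", trust_remote_code=True)
#   tokenizer = AutoTokenizer.from_pretrained("path/to/blip2zh-chatglm-6b", trust_remote_code=True)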
def setup_dtype(self, vision_encoder_dtype: str = "fp32", lm_dtype: str = "fp16"):
if vision_encoder_dtype == "fp32":
self.vision_model = self.vision_model.float()
elif vision_encoder_dtype == "fp16":
self.vision_model = self.vision_model.half()
else:
raise NotImplementedError(
f"Unsupported vision_encoder_dtype: {vision_encoder_dtype}"
)
if lm_dtype == "fp32":
self.language_model = self.language_model.float()
elif lm_dtype == "fp16":
self.language_model = self.language_model.half()
elif lm_dtype == "int4":
self.language_model = self.language_model.half().quantize(4)
elif lm_dtype == "int8":
self.language_model = self.language_model.half().quantize(8)
else:
raise NotImplementedError(f"Unsupported lm_dtype: {lm_dtype}")
def forward(
self,
pixel_values: torch.FloatTensor,
input_ids: torch.LongTensor,
image_slot_offset: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]:
"""_summary_
Args:
pixel_values (torch.FloatTensor): _description_
input_ids (torch.FloatTensor): input_ids[:, :num_query_tokens] should be filled with tokenizer.unk_token_id
image_slot_offset (Optional[torch.LongTensor], optional): if not set, all vtokens are placed as prefix (image_slot_offset = torch.zeros(bsz)). Defaults to None.
attention_mask (Optional[torch.LongTensor], optional): _description_. Defaults to None.
output_attentions (Optional[bool], optional): _description_. Defaults to None.
output_hidden_states (Optional[bool], optional): _description_. Defaults to None.
labels (Optional[torch.LongTensor], optional): _description_. Defaults to None.
return_dict (Optional[bool], optional): _description_. Defaults to None.
Returns:
Union[Tuple, Blip2ForConditionalGenerationModelOutput]: _description_
"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# step 1: forward the images through the vision encoder,
# to get image embeddings of shape (batch_size, seq_len, hidden_size)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
image_embeds = vision_outputs[0]
# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
image_attention_mask = torch.ones(
image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device
)
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_outputs = self.qformer(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
query_output = query_outputs[0]
# step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output)
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
if image_slot_offset is None:
# image embeddings are placed as a prefix
# assign through .data to avoid an in-place operation on a leaf Variable
inputs_embeds.data[
:, : self.config.num_query_tokens, :
] = language_model_inputs
else:
for i, offset in enumerate(image_slot_offset):
inputs_embeds.data[
i, offset : offset + self.config.num_query_tokens, :
] = language_model_inputs[i]
outputs = self.language_model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
logits = outputs.logits if return_dict else outputs[0]
loss = None
# we compute the loss here since we need to take into account the sequence length of the query embeds
if labels is not None:
logits = logits[:, -labels.size(1) :, :]
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous().to(logits.device)
# Flatten the tokens
loss_fct = CrossEntropyLoss(reduction="mean")
loss = loss_fct(
shift_logits.view(-1, self.config.text_config.vocab_size),
shift_labels.view(-1),
)
if not return_dict:
output = (logits, vision_outputs, query_outputs, outputs)
return ((loss,) + output) if loss is not None else output
return Blip2ForConditionalGenerationModelOutput(
loss=loss,
logits=logits,
vision_outputs=vision_outputs,
qformer_outputs=query_outputs,
language_model_outputs=outputs,
)
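# Forward-pass sketch (hypothetical tensors, not from the original file): the first
# num_query_tokens positions of input_ids are unk placeholders whose embeddings are
# overwritten by the projected Q-Former output; labels drive the LM loss.
#   nv = model.config.num_query_tokens
#   prompt_ids = tokenizer("Describe this image.", add_special_tokens=True).input_ids
#   input_ids = torch.tensor([[tokenizer.unk_token_id] * nv + prompt_ids], device=model.device)
#   labels = input_ids.clone()
#   labels[:, :nv] = -100  # ignore the image slot when computing the loss
#   out = model(pixel_values=pixel_values, input_ids=input_ids, labels=labels, return_dict=True)
#   loss, logits = out.loss, out.logits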
def prepare_inputs_for_chat(
self,
tokenizer: PreTrainedTokenizer,
batch_messages: List[List[Tuple[str, str, List[Tuple[torch.Tensor, int]]]]],
max_length: int,
user_role: str = "问",
bot_role: str = "答",
):
device = self.device
nvtokens = self.config.num_query_tokens
# 1. Prepare token ids
all_images = []
all_image_slots = []
all_input_ids = []
for messages in batch_messages:
images = []
image_slots = []
input_ids = []
round_roles = [set()]
for role, qtext, qimgs in messages:
if role in round_roles[-1]:
# a new round (not the first round)
input_ids += tokenizer(
f"\n[Round {len(round_roles)}]\n{role}:",
add_special_tokens=False,
).input_ids
round_roles.append({role})
else:
round_roles[-1].add(role)
input_ids += tokenizer(
# the first message does not get a leading newline
f"\n{role}:" if len(input_ids) != 0 else f"{role}:", add_special_tokens=False
).input_ids
cur_index = 0
for qimg, img_idx in qimgs:
if img_idx > cur_index:
input_ids += tokenizer(
qtext[cur_index:img_idx], add_special_tokens=False
).input_ids
cur_index = img_idx
# image slot: these placeholder embeddings will be replaced by the projected image embeddings
image_slots.append(len(input_ids))
input_ids += [tokenizer.unk_token_id] * nvtokens
images.append(qimg)
input_ids += tokenizer(
qtext[cur_index:], add_special_tokens=False
).input_ids
if len(round_roles) == 1:
# only 1 round
if len(round_roles[0]) == 1 and user_role in round_roles[0]:
# only the user role is present: append just the tokenizer's special tokens (no bot-role prompt)
input_ids += tokenizer("").input_ids
else:
input_ids += tokenizer(f"\n{bot_role}:").input_ids
else:
# add tag for round 0
input_ids = (
tokenizer(f"[Round 0]\n", add_special_tokens=False).input_ids
+ input_ids
)
input_ids += tokenizer(f"\n{bot_role}:").input_ids
if len(input_ids) >= max_length:
image_slots_after_truncate = []
images_after_truncate = []
truncate_index = len(input_ids) - max_length
for image_slot, image in zip(image_slots, images):
# truncate from the left: keep only image slots that remain fully inside the last max_length tokens
if len(input_ids) - image_slot < max_length:
image_slots_after_truncate.append(image_slot)
images_after_truncate.append(image)
elif len(input_ids) - (image_slot + nvtokens) < max_length:
# an image slot must not be cut in the middle; push the truncation point past it
truncate_index = max(truncate_index, image_slot + nvtokens)
for i, image_slot in enumerate(image_slots_after_truncate):
image_slots_after_truncate[i] = image_slot - truncate_index
input_ids = input_ids[truncate_index:]
image_slots = image_slots_after_truncate
images = images_after_truncate
# print(tokenizer.convert_ids_to_tokens(input_ids))
all_images.extend(images)
all_image_slots.append(image_slots)
all_input_ids.append(input_ids)
# 2. Prepare image embeddings
if len(all_images) != 0:
vision_outputs = self.vision_model.forward(torch.cat(all_images, dim=0))
all_image_embeds = vision_outputs[0]
indices_or_sections = [len(chunk) for chunk in all_image_slots]
indices_or_sections = np.cumsum(indices_or_sections)
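# split the stacked image embeddings back into one chunk per conversation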
all_vtokens = []
# TODO: the Q-Former is run once per conversation here rather than as a single batch
for image_embeds in torch.tensor_split(
all_image_embeds, tuple(indices_or_sections)
):
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
device
)
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_outputs = self.qformer.forward(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
)
query_output = query_outputs[0]
all_vtokens.append(self.language_projection(query_output))
else:
all_vtokens = None
# 3. Place image embeddings into slots
input_ids = (
torch.ones(
(len(all_input_ids), max(len(ids) for ids in all_input_ids)),
dtype=torch.long,
)
* tokenizer.pad_token_id
)
for i, ids in enumerate(all_input_ids):
# pad left
input_ids[i][-len(ids) :] = torch.as_tensor(ids, dtype=torch.long)
input_ids = input_ids.to(device)
inputs_embeds = self.language_model.transformer.word_embeddings(input_ids)
if all_vtokens is not None:
for i, (image_slots, vtokens) in enumerate(
zip(all_image_slots, all_vtokens)
):
for slot, vimg in zip(image_slots, vtokens):
inputs_embeds[i][slot : slot + nvtokens, :] = vimg
return input_ids, inputs_embeds
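# Input-format sketch (illustrative; `image_processor`, `pil_image`, and the texts are
# placeholders): each conversation is a list of (role, text, images) tuples, where images
# is a list of (pixel_values, char_offset) pairs and char_offset is the position inside
# `text` at which the image's query-token slot is spliced in.
#   pixel_values = image_processor(images=pil_image, return_tensors="pt").pixel_values
#   messages = [
#       ("问", "What is in this picture?", [(pixel_values, 0)]),  # image placed before the question text
#       ("答", "A cat.", []),
#       ("问", "What is it doing?", []),
#   ]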
@torch.no_grad()
def batch_chat(
self,
tokenizer: PreTrainedTokenizer,
batch_messages: List[List[Tuple[str, str, List[Tuple[torch.Tensor, int]]]]],
max_length: int = 2048,
num_beams=1,
do_sample=True,
top_p=0.7,
temperature=0.95,
user_role: str = "问",
bot_role: str = "答",
**kwargs,
):
input_ids, inputs_embeds = self.prepare_inputs_for_chat(
tokenizer=tokenizer,
batch_messages=batch_messages,
max_length=max_length,
user_role=user_role,
bot_role=bot_role,
)
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
gen_kwargs = {
"max_length": max_length,
"num_beams": num_beams,
"do_sample": do_sample,
"top_p": top_p,
"temperature": temperature,
"logits_processor": logits_processor,
**kwargs,
}
outputs = self.language_model.generate(
input_ids=input_ids, inputs_embeds=inputs_embeds, **gen_kwargs
)
responses = []
for i, output in enumerate(outputs.tolist()):
output = output[len(input_ids[i]) :]
response = tokenizer.decode(output)
responses.append(self.language_model.process_response(response))
return responses
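# Usage sketch (conversations structured as in the example above): batch_chat pads on the
# left, runs a single generate call for the whole batch, and returns one reply per conversation.
#   replies = model.batch_chat(tokenizer, [messages], max_length=2048, top_p=0.7, temperature=0.95)
#   print(replies[0])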
@torch.no_grad()
def stream_chat(
self,
tokenizer: PreTrainedTokenizer,
messages: List[Tuple[str, str, List[Tuple[torch.Tensor, int]]]],
num_beams=5,
max_length=512,
top_p=0.9,
do_sample=True,
temperature=1,
user_role: str = "问",
bot_role: str = "答",
**kwargs,
):
input_ids, inputs_embeds = self.prepare_inputs_for_chat(
tokenizer=tokenizer,
batch_messages=[messages],
max_length=max_length,
user_role=user_role,
bot_role=bot_role,
)
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
gen_kwargs = {
"max_length": max_length,
"num_beams": num_beams,
"do_sample": do_sample,
"top_p": top_p,
"temperature": temperature,
"logits_processor": logits_processor,
**kwargs,
}
for outputs in self.language_model.stream_generate(
input_ids=input_ids, inputs_embeds=inputs_embeds, **gen_kwargs
):
outputs = outputs.tolist()[0][len(input_ids[0]) :]
response = tokenizer.decode(outputs)
response = self.language_model.process_response(response)
yield response
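# Streaming sketch: stream_chat yields the partially decoded reply after every generation
# step, so a caller can update a UI incrementally.
#   for partial in model.stream_chat(tokenizer, messages, max_length=512):
#       print(partial, end="\r", flush=True)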