# ChatNT / text_generation.py
import numpy as np
import torch
from transformers import AutoTokenizer, Pipeline


class TextGenerationPipeline(Pipeline):
    """Hugging Face pipeline that answers an English question about one or more
    DNA sequences with ChatNT, generating the reply one English token at a time."""

    def __init__(self, model, **kwargs):  # type: ignore
super().__init__(model=model, **kwargs)
# Load tokenizers
model_name = "InstaDeepAI/ChatNT"
self.english_tokenizer = AutoTokenizer.from_pretrained(
model_name, subfolder="english_tokenizer"
)
self.bio_tokenizer = AutoTokenizer.from_pretrained(
model_name, subfolder="bio_tokenizer"
        )

    def _sanitize_parameters(self, **kwargs) -> tuple[dict, dict, dict]:
        """Split kwargs between the preprocess, forward and postprocess steps."""
preprocess_kwargs = {}
forward_kwargs = {}
postprocess_kwargs = {} # type: ignore
if "max_num_tokens_to_decode" in kwargs:
forward_kwargs["max_num_tokens_to_decode"] = kwargs[
"max_num_tokens_to_decode"
]
if "english_tokens_max_length" in kwargs:
preprocess_kwargs["english_tokens_max_length"] = kwargs[
"english_tokens_max_length"
]
if "bio_tokens_max_length" in kwargs:
preprocess_kwargs["bio_tokens_max_length"] = kwargs["bio_tokens_max_length"]
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(
        self,
        inputs: dict,
        english_tokens_max_length: int = 512,
        bio_tokens_max_length: int = 512,
    ) -> dict:
        """Build the ChatNT chat prompt around the English question and tokenize it
        along with the DNA sequences in `inputs["dna_sequences"]`."""
english_sequence = inputs["english_sequence"]
dna_sequences = inputs["dna_sequences"]
context = "A chat between a curious user and an artificial intelligence assistant that can handle bio sequences. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: " # noqa
        # Make sure exactly one space separates the question from "ASSISTANT:"
        space = "" if english_sequence.endswith(" ") else " "
        english_sequence = context + english_sequence + space + "ASSISTANT:"
english_tokens = self.english_tokenizer(
english_sequence,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=english_tokens_max_length,
).input_ids
bio_tokens = self.bio_tokenizer(
dna_sequences,
return_tensors="pt",
padding="max_length",
max_length=bio_tokens_max_length,
truncation=True,
        ).input_ids.unsqueeze(0)  # add a batch dimension: (1, num_sequences, seq_len)
        return {"english_tokens": english_tokens, "bio_tokens": bio_tokens}

    def _forward(self, model_inputs: dict, max_num_tokens_to_decode: int = 50) -> dict:
        """Greedily decode up to `max_num_tokens_to_decode` English tokens, stopping
        early at <eos> or when the prompt buffer runs out of pad positions."""
english_tokens = model_inputs["english_tokens"].clone()
bio_tokens = model_inputs["bio_tokens"].clone()
projected_bio_embeddings = None
actual_num_steps = 0
with torch.no_grad():
for _ in range(max_num_tokens_to_decode):
                # Stop if there is no pad position left to write the next token into
                if (
                    self.english_tokenizer.pad_token_id
                    not in english_tokens[0].cpu().numpy()
                ):
                    break
# Predictions
outs = self.model(
multi_omics_tokens_ids=(english_tokens, bio_tokens),
projection_english_tokens_ids=english_tokens,
projected_bio_embeddings=projected_bio_embeddings,
)
projected_bio_embeddings = outs["projected_bio_embeddings"]
logits = outs["logits"].detach().cpu().numpy()
                # The logits at the last filled position predict the token that
                # goes into the first pad position
                first_idx_pad_token = np.where(
                    english_tokens[0].cpu() == self.english_tokenizer.pad_token_id
                )[0][0]
                predicted_token = np.argmax(logits[0, first_idx_pad_token - 1])
                # Stop at <eos>, otherwise write the predicted token and keep going
                if predicted_token == self.english_tokenizer.eos_token_id:
                    break
                english_tokens[0, first_idx_pad_token] = predicted_token
                actual_num_steps += 1
# Get the position where generation started
idx_begin_generation = np.where(
model_inputs["english_tokens"][0].cpu()
== self.english_tokenizer.pad_token_id
)[0][0]
# Get generated tokens
generated_tokens = english_tokens[
0, idx_begin_generation : idx_begin_generation + actual_num_steps
]
return {
"generated_tokens": generated_tokens,
        }

    def postprocess(self, model_outputs: dict) -> str:
        """Decode the generated token ids back into an English string."""
generated_tokens = model_outputs["generated_tokens"]
generated_sequence: str = self.english_tokenizer.decode(generated_tokens)
return generated_sequence
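

# -----------------------------------------------------------------------------
# Usage sketch (illustrative only, kept as comments so importing this module
# stays side-effect free). It assumes the ChatNT weights can be loaded with
# `AutoModel.from_pretrained(..., trust_remote_code=True)` and that the English
# question marks where the DNA sequences go; check the InstaDeepAI/ChatNT model
# card for the exact loading call and prompt format before relying on this.
#
# from transformers import AutoModel
#
# model = AutoModel.from_pretrained("InstaDeepAI/ChatNT", trust_remote_code=True)
# pipe = TextGenerationPipeline(model=model)
# answer = pipe(
#     {
#         "english_sequence": "Is there a donor splice site in this sequence <DNA>?",
#         "dna_sequences": ["ATGGCGCTAGCTAGCTAGCTAGCATCGAT"],
#     },
#     max_num_tokens_to_decode=100,
# )
# print(answer)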