import numpy as np
import torch
from transformers import AutoTokenizer, Pipeline


class TextGenerationPipeline(Pipeline):
    """Text-generation pipeline for ChatNT, pairing an English prompt with DNA sequences."""

    def __init__(self, model, **kwargs):
        super().__init__(model=model, **kwargs)

        # ChatNT uses two tokenizers: one for the English conversation and one for
        # the DNA (bio) sequences, stored in separate subfolders of the repository.
        model_name = "InstaDeepAI/ChatNT"
        self.english_tokenizer = AutoTokenizer.from_pretrained(
            model_name, subfolder="english_tokenizer"
        )
        self.bio_tokenizer = AutoTokenizer.from_pretrained(
            model_name, subfolder="bio_tokenizer"
        )

    def _sanitize_parameters(self, **kwargs) -> tuple[dict, dict, dict]:
        # Route user-provided kwargs to the preprocess, forward and postprocess steps.
        preprocess_kwargs = {}
        forward_kwargs = {}
        postprocess_kwargs = {}

        if "max_num_tokens_to_decode" in kwargs:
            forward_kwargs["max_num_tokens_to_decode"] = kwargs[
                "max_num_tokens_to_decode"
            ]
        if "english_tokens_max_length" in kwargs:
            preprocess_kwargs["english_tokens_max_length"] = kwargs[
                "english_tokens_max_length"
            ]
        if "bio_tokens_max_length" in kwargs:
            preprocess_kwargs["bio_tokens_max_length"] = kwargs["bio_tokens_max_length"]

        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(
        self,
        inputs: dict,
        english_tokens_max_length: int = 512,
        bio_tokens_max_length: int = 512,
    ) -> dict:
        english_sequence = inputs["english_sequence"]
        dna_sequences = inputs["dna_sequences"]

        # Wrap the user's question in the chat template expected by the model.
        context = (
            "A chat between a curious user and an artificial intelligence assistant "
            "that can handle bio sequences. The assistant gives helpful, detailed, "
            "and polite answers to the user's questions. USER: "
        )
        space = " "
        if english_sequence[-1] == " ":
            space = ""
        english_sequence = context + english_sequence + space + "ASSISTANT:"

        # Pad the English prompt to a fixed length; the pad positions are filled
        # with generated tokens in _forward.
        english_tokens = self.english_tokenizer(
            english_sequence,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=english_tokens_max_length,
        ).input_ids
        # Tokenize the DNA sequences and add a batch dimension:
        # (num_sequences, seq_len) -> (1, num_sequences, seq_len).
        bio_tokens = self.bio_tokenizer(
            dna_sequences,
            return_tensors="pt",
            padding="max_length",
            max_length=bio_tokens_max_length,
            truncation=True,
        ).input_ids.unsqueeze(0)

        return {"english_tokens": english_tokens, "bio_tokens": bio_tokens}

    def _forward(self, model_inputs: dict, max_num_tokens_to_decode: int = 50) -> dict:
        english_tokens = model_inputs["english_tokens"].clone()
        bio_tokens = model_inputs["bio_tokens"].clone()
        projected_bio_embeddings = None

        # Greedy decoding: each new token is written into the first pad position
        # of the English sequence.
        actual_num_steps = 0
        with torch.no_grad():
            for _ in range(max_num_tokens_to_decode):
                # Stop if there is no pad position left to fill.
                if (
                    self.english_tokenizer.pad_token_id
                    not in english_tokens[0].cpu().numpy()
                ):
                    break

                outs = self.model(
                    multi_omics_tokens_ids=(english_tokens, bio_tokens),
                    projection_english_tokens_ids=english_tokens,
                    projected_bio_embeddings=projected_bio_embeddings,
                )
                # Reuse the projected DNA embeddings so they are only computed once.
                projected_bio_embeddings = outs["projected_bio_embeddings"]
                logits = outs["logits"].detach().cpu().numpy()

                # Predict the next token from the logits at the last non-pad position.
                first_idx_pad_token = np.where(
                    english_tokens[0].cpu() == self.english_tokenizer.pad_token_id
                )[0][0]
                predicted_token = np.argmax(logits[0, first_idx_pad_token - 1])

                # Stop at end-of-sequence, otherwise append the token and continue.
                if predicted_token == self.english_tokenizer.eos_token_id:
                    break
                english_tokens[0, first_idx_pad_token] = predicted_token
                actual_num_steps += 1

        # Generation started at the first pad position of the original prompt.
        idx_begin_generation = np.where(
            model_inputs["english_tokens"][0].cpu()
            == self.english_tokenizer.pad_token_id
        )[0][0]
        generated_tokens = english_tokens[
            0, idx_begin_generation : idx_begin_generation + actual_num_steps
        ]

        return {
            "generated_tokens": generated_tokens,
        }

    def postprocess(self, model_outputs: dict) -> str:
        generated_tokens = model_outputs["generated_tokens"]
        generated_sequence: str = self.english_tokenizer.decode(generated_tokens)
        return generated_sequence
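

# The block below is a minimal usage sketch, not part of the original pipeline. It
# assumes the ChatNT repository registers this class as its custom pipeline so that
# transformers.pipeline(..., trust_remote_code=True) resolves to it, and that the
# English prompt marks the position of each DNA sequence with a <DNA> placeholder
# token; the DNA string here is a short stand-in, not a real example.
if __name__ == "__main__":
    from transformers import pipeline

    chatnt_pipeline = pipeline(
        model="InstaDeepAI/ChatNT",
        trust_remote_code=True,  # assumption: loads the repo's custom pipeline code
    )
    answer = chatnt_pipeline(
        {
            "english_sequence": "Is there any evidence of a promoter in this sequence: <DNA>?",
            "dna_sequences": ["ATGCGTACGTAGCTAGCTAGCTAGGCTAACGTTAGC"],
        },
        max_num_tokens_to_decode=100,
    )
    print(answer)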