# ga89tiy
# last update
# f5ae994
# raw
# history blame contribute delete
# 4.7 kB
from pathlib import Path
import io
import requests
import torch
from PIL import Image
import numpy as np
from huggingface_hub import snapshot_download, hf_hub_download
from LLAVA_Biovil.llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, remap_to_uint8
from LLAVA_Biovil.llava.model.builder import load_pretrained_model
from LLAVA_Biovil.llava.conversation import SeparatorStyle, conv_vicuna_v1
from LLAVA_Biovil.llava.constants import IMAGE_TOKEN_INDEX
from utils import create_chest_xray_transform_for_inference, init_chexpert_predictor
def load_model_from_huggingface(repo_id):
    """Download a checkpoint repo from the Hugging Face Hub and load it.

    The ``main`` revision of ``repo_id`` is fetched with ``snapshot_download``
    and the resulting local directory is handed to ``load_pretrained_model``
    together with the fixed LLaVA-1.5-7B base model and the fixed checkpoint
    name; 8-bit and 4-bit loading are both disabled.

    Args:
        repo_id: Hub repository identifier of the checkpoint to download.

    Returns:
        The 4-tuple ``(tokenizer, model, image_processor, context_len)``
        produced by ``load_pretrained_model``.
    """
    # Fetch (or reuse the cached copy of) the repository files.
    checkpoint_dir = Path(snapshot_download(repo_id=repo_id, revision="main"))

    loaded = load_pretrained_model(
        checkpoint_dir,
        model_base='liuhaotian/llava-v1.5-7b',
        model_name="llava-v1.5-7b-task-lora_radialog_instruct_llava_biovil_unfrozen_2e-5_5epochs_v5_checkpoint-21000",
        load_8bit=False,
        load_4bit=False,
    )
    # Unpack explicitly so an unexpected return arity still fails here.
    tokenizer, model, image_processor, context_len = loaded
    return tokenizer, model, image_processor, context_len
def _generate_reply(conv, tokenizer, model, image_tensor, max_new_tokens=300):
    """Greedily generate the assistant reply for the current conversation.

    The prompt is rendered from ``conv``, tokenised with the image placeholder
    token, and passed to ``model.generate`` together with ``image_tensor``.
    Only the newly generated tokens (everything after the prompt) are decoded.

    Args:
        conv: Conversation object whose last message is the open
            ``("ASSISTANT", None)`` slot to be filled.
        tokenizer: Tokenizer returned by ``load_model_from_huggingface``.
        model: Model returned by ``load_model_from_huggingface``.
        image_tensor: Preprocessed image batch already on ``model.device``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded reply, stripped of surrounding whitespace and with any
        ``</s>`` marker removed.
    """
    prompt = conv.get_prompt()
    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
    ).unsqueeze(0).to(model.device)

    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    # The stopping criteria is constructed from this turn's input_ids, so it is
    # rebuilt for every prompt instead of being reused across turns.
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=False,
            use_cache=True,
            max_new_tokens=max_new_tokens,
            stopping_criteria=[stopping_criteria],
            pad_token_id=tokenizer.pad_token_id,
        )

    # Skip the prompt tokens; keep only what the model produced.
    return tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip().replace("</s>", "")


def main():
    """Run the two-turn demo: generate a report, then a lay-language version.

    Steps: download a sample chest X-ray, predict findings with the CheXpert
    classifier, ask the model for a findings section conditioned on the image
    and the predicted findings, then ask it to rewrite that report in easy
    language. Prompts and replies are printed to stdout.

    Requires network access and a CUDA device (the classifier input is moved
    to the GPU with ``.cuda()``).

    Raises:
        requests.HTTPError: If the sample image cannot be downloaded.
    """
    sample_img_path = "https://openi.nlm.nih.gov/imgs/512/294/3502/CXR3502_IM-1707-1001.png?keywords=Surgical%20Instruments,Cardiomegaly,Pulmonary%20Congestion,Diaphragm"

    # Fail fast on a hung connection or an HTTP error page instead of handing
    # non-image bytes to PIL.
    response = requests.get(sample_img_path, timeout=30)
    response.raise_for_status()
    image = Image.open(io.BytesIO(response.content))
    # presumably rescales the pixel values into the 8-bit range -- verify in mm_utils
    image = remap_to_uint8(np.array(image))
    image = Image.fromarray(image).convert("L")

    tokenizer, model, image_processor, context_len = load_model_from_huggingface(
        repo_id="ChantalPellegrini/RaDialog-interactive-radiology-report-generation"
    )
    cp_model, cp_class_names, cp_transforms = init_chexpert_predictor()

    model.config.tokenizer_padding_side = "left"

    # Predict findings; no gradients are needed for this forward pass.
    cp_image = cp_transforms(image)
    with torch.inference_mode():
        logits = cp_model(cp_image[None].half().cuda())
        preds_probs = torch.sigmoid(logits)
    # A class counts as "found" when its probability exceeds 0.5.
    findings_mask = (preds_probs > 0.5)[0].cpu().numpy()
    # assumes cp_class_names supports boolean-mask indexing (e.g. a numpy array)
    findings = cp_class_names[findings_mask].tolist()
    findings = ', '.join(findings).lower().strip()

    # First turn: report generation.
    conv = conv_vicuna_v1.copy()
    report_gen_prompt = f"<image>. Predicted Findings: {findings}. You are to act as a radiologist and write the finding section of a chest x-ray radiology report for this X-ray image and the given predicted findings. Write in the style of a radiologist, write one fluent text without enumeration, be concise and don't provide explanations or reasons."
    print("USER: ", report_gen_prompt)
    conv.append_message("USER", report_gen_prompt)
    conv.append_message("ASSISTANT", None)

    # Prepare the image for the vision encoder (resize to 512, center crop 448).
    vis_transforms_biovil = create_chest_xray_transform_for_inference(512, center_crop_size=448)
    image_tensor = vis_transforms_biovil(image).unsqueeze(0)
    image_tensor = image_tensor.to(model.device, dtype=torch.bfloat16)

    report = _generate_reply(conv, tokenizer, model, image_tensor)
    print("ASSISTANT: ", report)

    # Second turn: replace the open assistant slot with the actual report,
    # then ask for the lay-language rewrite.
    follow_up_prompt = "Translate this report to easy language for a patient to understand."
    conv.messages.pop()
    conv.append_message("ASSISTANT", report)
    conv.append_message("USER", follow_up_prompt)
    conv.append_message("ASSISTANT", None)
    print("USER: ", follow_up_prompt)

    easy_report = _generate_reply(conv, tokenizer, model, image_tensor)
    print("ASSISTANT: ", easy_report)


if __name__ == '__main__':
    main()