# Mbi2Spi / app.py
# Uploaded by hsiangyualex in "Upload 64 files" (commit f97a499, verified; 7.12 kB).
import os

# Pick the GPU before torch is imported, so the setting is in place before CUDA
# initialises. setdefault keeps the original default ('3') but no longer
# overrides a CUDA_VISIBLE_DEVICES the caller already set. Forcing '3' on a host
# with fewer than four GPUs hides every GPU, and `device` below then silently
# falls back to the CPU.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '3')

import gradio as gr
import yaml
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd

from models.modules.networks import PromptAttentionUNet, HighResEnhancer
from models.modules.biomedclip import BiomedCLIPTextEncoder
from monai.inferers import sliding_window_inference
from markers import breast_markers, prostatic_markers, pancreatic_markers

# Run on the (current) visible CUDA device when one is available, else on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# ---- Stage II: the "convertion" network (nuclei map + marker prompt -> marker) ----
cfg_file = 'configs/%s.cfg' % 'convertion'
with open(cfg_file, 'r') as f:
    cfg = yaml.safe_load(f)
print("successfully loaded config file: ", cfg)

convertion_ckpt = './checkpoint/stage_ii.pkl'
convertion_net = PromptAttentionUNet(
    in_channels=cfg['MODEL']['IMC_IN'],
    out_channels=cfg['MODEL']['IMC_OUT'],
    channels=(128, 256, 512, 1024, 2048),
)
# Text encoder that turns a marker name into the prompt consumed by convertion_net.
prompt_model = BiomedCLIPTextEncoder(device=device)

# NOTE(review): torch.load unpickles the checkpoint; only point this at trusted files.
# The network weights are stored under the 'generator' entry. Every occurrence of
# 'module.' (not only a leading prefix) is stripped so the keys match the bare
# network; the prefix presumably comes from a DataParallel/DDP wrapper - confirm.
state_dict = dict(
    (key.replace('module.', ''), tensor)
    for key, tensor in torch.load(convertion_ckpt, map_location='cpu')['generator'].items()
)
convertion_net.load_state_dict(state_dict)
# ---- Stage I: the "translation" network (imc_net) ----
# generate_imc runs this model first and treats its output as the nuclei map.
cfg_file = 'configs/{}.cfg'.format('translation')
with open(cfg_file, 'r') as f:
    cfg = yaml.safe_load(f)
print("successfully loaded config file: ", cfg)
translation_ckpt = './checkpoint/stage_i.pkl'
imc_net = HighResEnhancer(
    model_name=cfg['MODEL']['TIMM_MODEL'],
    in_channels=cfg['MODEL']['IMC_IN'],
    out_channels=cfg['MODEL']['IMC_OUT'],
    norm=cfg['MODEL']['NORM'],
    use_dilated_bottleneck=True,
)
# Load the IMC generator weights (stored under 'imc_G').
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
state_dict = torch.load(translation_ckpt, map_location='cpu')['imc_G']
# Remove every occurrence of 'module.0.' so the keys match the bare network
# (presumably added by a wrapper around the model at training time - confirm).
state_dict = {k.replace('module.0.', ''): v for k, v in state_dict.items()}
# Drop the checkpoint's Sobel filter weights so the model keeps its own (if it
# defines one). Pop with a default so a checkpoint without this key still loads
# instead of raising KeyError.
_SOBEL_KEY = 'sobel.filter.weight'
state_dict.pop(_SOBEL_KEY, None)
# strict=False lets the load proceed despite the dropped key, but it would also
# silently hide any other mismatch. Report everything except the expected Sobel key.
_load_report = imc_net.load_state_dict(state_dict, strict=False)
_missing = [k for k in _load_report.missing_keys if k != _SOBEL_KEY]
if _missing or _load_report.unexpected_keys:
    print("warning: imc_net checkpoint mismatch; missing keys:", _missing,
          "unexpected keys:", _load_report.unexpected_keys)
# Put both networks in eval mode and move them to the selected device.
convertion_net.eval().to(device)
imc_net.eval().to(device)
# Demo-data index. This file reads its 'source' (dataset tag), 'name' (dropdown
# label) and 'path' (location of the .npz sample) columns.
df = pd.read_csv('./test_data/test_metadata.csv')
# One row subset per tissue type, selected by the dataset tag in 'source'.
pancreatic_df = df.loc[df['source'].eq('PancreaticCancer_V2')]
prostatic_df = df.loc[df['source'].eq('ProstaticCancer_V2')]
breast_df = df.loc[df['source'].eq('BreastCancer_V2')]
def load_image(pair_index):
    """Load the demo image pair registered under ``pair_index`` in ``df``.

    Looks up the metadata row whose ``name`` equals ``pair_index``, reads the
    ``arr_0`` array from the ``.npz`` file at its ``path``, and splits it
    along the third axis.

    Args:
        pair_index: value of the ``name`` column (the dropdown selection), or
            None when nothing is selected.

    Returns:
        Two gr.Image updates. The first holds ``data[:, :, 0]`` (shown as the
        brightfield image); the second holds ``data[:, :, 1]`` (the hidden
        auxiliary image). Both are empty when ``pair_index`` is None.

    Raises:
        gr.Error: if no metadata row has that name.
    """
    if pair_index is None:
        return gr.Image(value=None), gr.Image(value=None)
    paths = df.loc[df['name'] == pair_index, 'path']
    if paths.empty:
        raise gr.Error(f"No demo image named {pair_index!r} in the metadata.")
    # Take the first match explicitly. squeeze() returned a DataFrame (and the
    # load then failed) whenever a name appeared more than once.
    path = paths.iloc[0]
    # np.load returns an NpzFile that keeps the archive open; the context manager
    # closes it deterministically once arr_0 has been read into memory.
    with np.load(path) as archive:
        data = archive['arr_0']
    return gr.Image(value=data[:, :, 0]), gr.Image(value=data[:, :, 1])
def generate_imc(x1, x2, marker_name):
    """Predict a marker image from the selected image pair and marker name.

    Stage I runs ``imc_net`` (via sliding-window inference) on ``x1`` plus
    channel 2 of ``x2`` to obtain a nuclei map. Stage II runs
    ``convertion_net`` on that map, conditioned on the ``prompt_model``
    embedding of ``marker_name``.

    Args:
        x1: brightfield image from the gr.Image component (numpy; presumably
            uint8 HxWxC - TODO confirm channel count against cfg IMC_IN).
        x2: hidden auxiliary image. Only ``x2[:, :, 2:3]`` is used.
        marker_name: marker label passed to the text encoder.

    Returns:
        gr.Image whose value is a uint8 array built by stacking
        [marker, zeros, nuclei] on the channel axis. This is a 3-channel RGB
        image when each prediction has one channel.

    Raises:
        gr.Error: if no image pair or marker has been selected yet.
    """
    if x1 is None or x2 is None:
        raise gr.Error("Please select a brightfield image first.")
    if not marker_name:
        raise gr.Error("Please select a marker first.")

    # Pure inference. no_grad stops autograd from recording both forward passes;
    # calling .detach() afterwards does not avoid that memory cost.
    with torch.no_grad():
        # ---- stage I: imc_net predicts the nuclei map ----
        # Input = x1 plus channel 2 of x2, scaled [0, 255] -> [0, 1] -> [-1, 1].
        stacked = np.concatenate([x1, x2[:, :, 2:3]], axis=2) / 255.0
        inputs = torch.from_numpy(stacked.transpose(2, 0, 1)).unsqueeze(0).float()
        inputs = 2 * inputs - 1
        # Run imc_net over overlapping 320x320 windows and stitch the result.
        output = sliding_window_inference(
            inputs.to(device),
            roi_size=(320, 320),
            sw_batch_size=2,
            predictor=imc_net,
            overlap=0.5,
        )
        output = torch.tanh(output)
        pred_nuclei = output[0].detach().cpu().numpy().transpose(1, 2, 0)
        pred_nuclei = (pred_nuclei + 1) / 2  # [-1, 1] -> [0, 1]

        # ---- stage II: convertion_net maps nuclei + marker prompt to the marker ----
        nuclei_inputs = torch.from_numpy(pred_nuclei).permute(2, 0, 1).unsqueeze(0).float()
        nuclei_inputs = 2 * nuclei_inputs - 1  # [0, 1] -> [-1, 1]
        prompt_in = torch.as_tensor(prompt_model([marker_name])).to(device)
        output = torch.tanh(convertion_net(nuclei_inputs.to(device), prompt_in))
        marker = output[0].detach().cpu().numpy().transpose(1, 2, 0)
        marker = (marker + 1) / 2  # [-1, 1] -> [0, 1]

    # Visualisation: [marker, zeros, nuclei] on the channel axis, [0, 1] -> uint8.
    vis = np.concatenate([marker, np.zeros_like(pred_nuclei, dtype=np.float32), pred_nuclei], axis=2)
    vis = (vis * 255).astype(np.uint8)
    return gr.Image(value=vis)
# Refresh the image and marker dropdowns when the tissue type changes.
def update_dropdown_by_tissue(selected_category):
    """Return dropdown updates listing the images and markers for a tissue.

    Args:
        selected_category: one of the tissue_selector choices ("Breast",
            "Pancreatic", "Prostatic").

    Returns:
        [image_selector, marker_selector] gr.Dropdown updates. Each lists that
        tissue's options with the first one pre-selected (None when the list
        is empty). For any other value (e.g. None), both dropdowns are emptied
        and disabled. The original if/elif chain raised UnboundLocalError here.
    """
    tissue_options = {
        "Breast": (breast_df, breast_markers),
        "Pancreatic": (pancreatic_df, pancreatic_markers),
        "Prostatic": (prostatic_df, prostatic_markers),
    }
    if selected_category not in tissue_options:
        return [
            gr.Dropdown(choices=[], value=None, interactive=False),
            gr.Dropdown(choices=[], value=None, interactive=False),
        ]
    tissue_df, markers = tissue_options[selected_category]
    image_names = tissue_df['name'].values.tolist()
    image_selector = gr.Dropdown(
        choices=image_names,
        value=image_names[0] if len(image_names) else None,
        interactive=True,
    )
    marker_selector = gr.Dropdown(
        choices=markers,
        value=markers[0] if len(markers) else None,
        interactive=True,
    )
    return [image_selector, marker_selector]
# Build the Gradio interface.
def create_gradio_ui():
    """Assemble the Mbi2Spi Gradio Blocks app (not launched).

    Layout: one "Mbi2Spi" tab with two equal-scale columns.
      * Left column: the brightfield image (plus a hidden auxiliary image),
        followed by the tissue, marker and image dropdowns.
      * Right column: the generated image and the "Predict IMC" button.

    Events:
      * tissue dropdown change -> update_dropdown_by_tissue fills the image
        and marker dropdowns;
      * image dropdown change  -> load_image fills the brightfield/aux images;
      * button click           -> generate_imc renders the prediction.
    """
    with gr.Blocks() as blocks:
        with gr.Tab("Mbi2Spi"):
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Row():
                        brightfield_view = gr.Image(label="Brightfield Image", type="numpy", interactive=False)
                        # Second image of the pair: never shown, only fed to generate_imc.
                        aux_view = gr.Image(type="numpy", visible=False, interactive=False)
                    with gr.Row():
                        with gr.Column():
                            tissue_dropdown = gr.Dropdown(
                                choices=["Breast", "Pancreatic", "Prostatic"],
                                label="Select Tissue Type",
                            )
                            marker_dropdown = gr.Dropdown(label="Marker Selector", interactive=False)
                        with gr.Column():
                            image_dropdown = gr.Dropdown(label="Brightfield Selector", interactive=False)
                    tissue_dropdown.change(
                        update_dropdown_by_tissue,
                        inputs=tissue_dropdown,
                        outputs=[image_dropdown, marker_dropdown],
                    )
                with gr.Column(scale=1):
                    generated_view = gr.Image(label="Generated Image", type="numpy")
                    predict_button = gr.Button("Predict IMC")
            image_dropdown.change(load_image, inputs=image_dropdown, outputs=[brightfield_view, aux_view])
            predict_button.click(
                generate_imc,
                inputs=[brightfield_view, aux_view, marker_dropdown],
                outputs=generated_view,
            )
    return blocks
# Script entry point: build the UI and start the server.
if __name__ == '__main__':
    demo = create_gradio_ui()
    # show_error=True surfaces exceptions raised in event handlers in the browser.
    demo.launch(show_error=True)