Spaces:
Sleeping
Sleeping
import os | |
os.environ['CUDA_VISIBLE_DEVICES'] = '3' | |
import gradio as gr | |
import yaml | |
import torch | |
import torch.nn.functional as F | |
import numpy as np | |
import pandas as pd | |
from models.modules.networks import PromptAttentionUNet, HighResEnhancer | |
from models.modules.biomedclip import BiomedCLIPTextEncoder | |
from monai.inferers import sliding_window_inference | |
from markers import breast_markers, prostatic_markers, pancreatic_markers | |
# Use CUDA when torch reports it as available, otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---- Stage II: conversion model -------------------------------------------
# (the "convertion" spelling is kept in identifiers and paths because other
# code and files on disk refer to it)
# load the cfg file for conversion
cfg_file = 'configs/{}.cfg'.format('convertion')
with open(cfg_file, 'r') as f:
    cfg = yaml.safe_load(f)
print("successfully loaded config file: ", cfg)

# conversion models
convertion_ckpt = './checkpoint/stage_ii.pkl'
convertion_net = PromptAttentionUNet(
    in_channels=cfg['MODEL']['IMC_IN'],
    out_channels=cfg['MODEL']['IMC_OUT'],
    channels=(128, 256, 512, 1024, 2048),
)
prompt_model = BiomedCLIPTextEncoder(device=device)

# load state_dict
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True if the checkpoint allows it).
state_dict = torch.load(convertion_ckpt, map_location='cpu')['generator']
# Strip 'module.' from the keys. str.replace removes every occurrence, not
# just a leading prefix -- presumably the keys only carry it as a prefix left
# by a parallel wrapper; TODO confirm.
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
convertion_net.load_state_dict(state_dict)

# ---- Stage I: translation model -------------------------------------------
# `cfg_file` and `cfg` are rebound here; the conversion config is not needed
# past this point.
cfg_file = 'configs/{}.cfg'.format('translation')
with open(cfg_file, 'r') as f:
    cfg = yaml.safe_load(f)
print("successfully loaded config file: ", cfg)

translation_ckpt = './checkpoint/stage_i.pkl'
imc_net = HighResEnhancer(
    model_name=cfg['MODEL']['TIMM_MODEL'],
    in_channels=cfg['MODEL']['IMC_IN'],
    out_channels=cfg['MODEL']['IMC_OUT'],
    norm=cfg['MODEL']['NORM'],
    use_dilated_bottleneck=True,
)

# load state_dict for IMC (same trusted-checkpoint caveat as above)
state_dict = torch.load(translation_ckpt, map_location='cpu')['imc_G']
# strip 'module.0.' from the keys (every occurrence, see note above)
state_dict = {k.replace('module.0.', ''): v for k, v in state_dict.items()}
# Discard the checkpoint's "sobel.filter.weight" entry so it is never loaded.
# A default is supplied so a checkpoint that simply lacks the entry does not
# abort start-up with a KeyError.
state_dict.pop('sobel.filter.weight', None)
# strict=False: keys missing from / unexpected in the checkpoint are tolerated.
imc_net.load_state_dict(state_dict, strict=False)

# Inference only: switch both networks to eval mode and move them to `device`.
convertion_net.eval().to(device)
imc_net.eval().to(device)

# load the metadata for demo data, then split it per tissue via the 'source'
# column
df = pd.read_csv('./test_data/test_metadata.csv')
breast_df = df[df['source'] == 'BreastCancer_V2']
prostatic_df = df[df['source'] == 'ProstaticCancer_V2']
pancreatic_df = df[df['source'] == 'PancreaticCancer_V2']
def load_image(pair_index):
    """Load the image pair registered under ``pair_index`` in the metadata.

    The row of ``df`` whose ``name`` column equals ``pair_index`` supplies the
    ``path`` of an ``.npz`` archive; its ``arr_0`` array is split by indexing
    positions 0 and 1 along the third axis.

    Args:
        pair_index: value to look up in the ``name`` column of ``df``.

    Returns:
        A pair of ``gr.Image`` objects carrying the two slices, in the order
        (index 0, index 1).

    Raises:
        gr.Error: if no metadata row has that name (including ``None``).
    """
    paths = df.loc[df['name'] == pair_index, 'path']
    if paths.empty:
        raise gr.Error("No demo image named {!r} in the metadata.".format(pair_index))

    # Take the first match so a duplicated name still resolves to one file.
    # The archive is closed as soon as the array has been read.
    with np.load(paths.iloc[0]) as archive:
        data = archive['arr_0']

    # NOTE(review): presumably arr_0 stacks the two images along axis 2
    # (e.g. H x W x 2 x C) -- confirm against how the test data was written.
    x1 = data[:, :, 0]
    x2 = data[:, :, 1]
    return gr.Image(value=x1), gr.Image(value=x2)
def generate_imc(x1, x2, marker_name):
    """Run both model stages on an image pair and return a rendered result.

    Stage I concatenates ``x1`` with channel index 2 of ``x2``, rescales the
    result from [0, 255] to [-1, 1] and runs ``imc_net`` over it with
    sliding-window inference. Stage II feeds the stage-I output together with
    the ``prompt_model`` embedding of ``marker_name`` to ``convertion_net``.
    Both outputs pass through tanh and are mapped to [0, 1].

    Args:
        x1: H x W x C array (the brightfield component's value).
        x2: H x W x C array (the hidden auxiliary component's value); only
            channel index 2 is used.
        marker_name: marker label handed to ``prompt_model``.

    Returns:
        ``gr.Image`` wrapping a uint8 array built by concatenating, along the
        channel axis, the stage-II output, zeros shaped like the stage-I
        output, and the stage-I output.

    Raises:
        gr.Error: if an image or the marker has not been selected yet.
    """
    # Without this guard a missing selection fails deep inside numpy or the
    # text encoder with an unhelpful message.
    if x1 is None or x2 is None or not marker_name:
        raise gr.Error("Please select an image and a marker before predicting.")

    # stage I
    inputs = np.concatenate([x1, x2[:, :, 2:3]], axis=2)
    # normalize to [0, 1]
    inputs = inputs / 255.0
    # to tensor: HWC -> 1 x C x H x W
    inputs = torch.from_numpy(inputs.transpose(2, 0, 1)).unsqueeze(0).float()
    # rescale to [-1, 1]
    inputs = 2 * inputs - 1
    # Pure inference: no autograd graph is needed, so do not build one.
    with torch.no_grad():
        output = sliding_window_inference(
            inputs.to(device),
            roi_size=(320, 320),
            sw_batch_size=2,
            predictor=imc_net,
            overlap=0.5,
        )
        output = torch.tanh(output)
    # to numpy, back to HWC
    pred_nuclei = output[0].detach().cpu().numpy().transpose(1, 2, 0)
    pred_nuclei = (pred_nuclei + 1) / 2  # normalize to [0, 1]

    # stage II
    nuclei_inputs = torch.from_numpy(pred_nuclei).permute(2, 0, 1).unsqueeze(0).float()
    # rescale to [-1, 1]
    nuclei_inputs = 2 * nuclei_inputs - 1
    with torch.no_grad():
        prompt_in = torch.as_tensor(prompt_model([marker_name])).to(device)
        output = torch.tanh(convertion_net(nuclei_inputs.to(device), prompt_in))
    marker = output[0].detach().cpu().numpy().transpose(1, 2, 0)
    marker = (marker + 1) / 2  # normalize to [0, 1]

    # visualization
    # NOTE(review): presumably each network emits one channel, so this stacks
    # to a 3-channel image (marker, empty, nuclei) -- confirm.
    vis = np.concatenate(
        [marker, np.zeros_like(pred_nuclei, dtype=np.float32), pred_nuclei],
        axis=2,
    )
    # normalize to [0, 255] and convert to uint8
    vis = (vis * 255).astype(np.uint8)
    return gr.Image(value=vis)
# Function to update the image and marker dropdowns based on the tissue dropdown's selection
def update_dropdown_by_tissue(selected_category):
    """Build the image and marker dropdowns for the chosen tissue type.

    Args:
        selected_category: "Breast", "Pancreatic" or "Prostatic"; any other
            value (including ``None``) yields empty, non-interactive
            dropdowns instead of failing.

    Returns:
        ``[image_selector, marker_selector]`` as ``gr.Dropdown`` objects, each
        preselecting its first choice when one exists.
    """
    # Tissue label -> (metadata subset, marker list) for that tissue.
    sources = {
        "Breast": (breast_df, breast_markers),
        "Pancreatic": (pancreatic_df, pancreatic_markers),
        "Prostatic": (prostatic_df, prostatic_markers),
    }
    selection = sources.get(selected_category)
    if selection is None:
        # Nothing (or something unknown) selected: hand back disabled, empty
        # dropdowns rather than leaving the results undefined.
        return [
            gr.Dropdown(choices=[], value=None, interactive=False),
            gr.Dropdown(choices=[], value=None, interactive=False),
        ]

    tissue_df, markers = selection
    names = tissue_df['name'].values.tolist()
    # len() rather than truthiness: the marker container's type is defined
    # outside this file.
    first_name = names[0] if len(names) > 0 else None
    first_marker = markers[0] if len(markers) > 0 else None

    image_selector = gr.Dropdown(choices=names, value=first_name, interactive=True)
    marker_selector = gr.Dropdown(choices=markers, value=first_marker, interactive=True)
    return [image_selector, marker_selector]
# Create the Gradio interface
def create_gradio_ui():
    """Assemble the Blocks demo, wire its event handlers and return it.

    NOTE(review): the nesting of the layout containers below was reconstructed
    from flattened text; confirm it matches the intended layout (in particular
    which column holds the predict button).
    """
    with gr.Blocks() as demo:
        with gr.Tab("Mbi2Spi"):
            with gr.Row():
                # Left half: input preview plus the three selectors.
                with gr.Column(scale=1):
                    with gr.Row():
                        # image visualizer
                        brightfield_view = gr.Image(
                            label="Brightfield Image",
                            type="numpy",
                            interactive=False,
                        )
                        # Hidden component that carries the second image of
                        # the pair through to the prediction handler.
                        aux_view = gr.Image(
                            type="numpy",
                            visible=False,
                            interactive=False,
                        )
                    with gr.Row():
                        with gr.Column():
                            # tissue selector (Breast, Pancreatic, Prostatic)
                            tissue_selector = gr.Dropdown(
                                choices=["Breast", "Pancreatic", "Prostatic"],
                                label="Select Tissue Type",
                            )
                            # marker selector (disabled until a tissue is chosen)
                            marker_selector = gr.Dropdown(
                                label="Marker Selector",
                                interactive=False,
                            )
                        with gr.Column():
                            # image selector (disabled until a tissue is chosen)
                            image_selector = gr.Dropdown(
                                label="Brightfield Selector",
                                interactive=False,
                            )
                    # Choosing a tissue repopulates both dependent dropdowns.
                    tissue_selector.change(
                        fn=update_dropdown_by_tissue,
                        inputs=tissue_selector,
                        outputs=[image_selector, marker_selector],
                    )
                # Right half: generated result and the trigger button.
                with gr.Column(scale=1):
                    generated_view = gr.Image(label="Generated Image", type="numpy")
                    predict_button = gr.Button("Predict IMC")
            # Choosing an image loads the pair into the visible and hidden
            # image components.
            image_selector.change(
                fn=load_image,
                inputs=image_selector,
                outputs=[brightfield_view, aux_view],
            )
            # Event handler for button click
            predict_button.click(
                fn=generate_imc,
                inputs=[brightfield_view, aux_view, marker_selector],
                outputs=generated_view,
            )
    return demo
# Launch the demo
if __name__ == '__main__':
    # Only start the web server when executed as a script; show_error=True
    # asks Gradio to surface handler errors in the browser.
    demo = create_gradio_ui()
    demo.launch(show_error=True)