"""Gradio demo that predicts an IMC marker image from a brightfield sample.

Pipeline as wired below: stage I (``imc_net``) predicts a nuclei image from the
brightfield image plus one channel of an auxiliary image; stage II
(``convertion_net``) turns that nuclei image plus a BiomedCLIP text prompt for
the chosen marker into the marker image.
"""
import os

# Pin the demo to one GPU; this has to happen before torch initialises CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

import gradio as gr
import yaml
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
from models.modules.networks import PromptAttentionUNet, HighResEnhancer
from models.modules.biomedclip import BiomedCLIPTextEncoder
from monai.inferers import sliding_window_inference
from markers import breast_markers, prostatic_markers, pancreatic_markers

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _load_cfg(name):
    """Load the YAML config ``configs/<name>.cfg``, echo it to stdout and return it."""
    cfg_file = 'configs/{}.cfg'.format(name)
    with open(cfg_file, 'r') as f:
        loaded = yaml.safe_load(f)
    print("successfully loaded config file: ", loaded)
    return loaded


def _strip_prefix(state_dict, prefix):
    """Return a copy of *state_dict* with the leading *prefix* removed from every key.

    The prefix is stripped repeatedly from the start of a key (so stacked
    wrappers such as ``module.module.`` are handled), but an occurrence in the
    middle of a key -- e.g. ``submodule.weight`` when *prefix* is ``module.`` --
    is left intact; a plain ``str.replace`` would corrupt such keys.
    """
    stripped = {}
    for key, value in state_dict.items():
        while prefix and key.startswith(prefix):
            key = key[len(prefix):]
        stripped[key] = value
    return stripped


# ---------------------------------------------------------------------------
# Stage II: conversion network (nuclei image + marker text prompt -> marker).
# ---------------------------------------------------------------------------
cfg = _load_cfg('convertion')
convertion_ckpt = './checkpoint/stage_ii.pkl'
convertion_net = PromptAttentionUNet(in_channels=cfg['MODEL']['IMC_IN'],
                                     out_channels=cfg['MODEL']['IMC_OUT'],
                                     channels=(128, 256, 512, 1024, 2048))
prompt_model = BiomedCLIPTextEncoder(device=device)
# NOTE: torch.load unpickles the file -- only load trusted local checkpoints.
state_dict = torch.load(convertion_ckpt, map_location='cpu')['generator']
# Remove the leading 'module.' from the keys (presumably added by a
# DataParallel wrapper at training time -- confirm).
state_dict = _strip_prefix(state_dict, 'module.')
convertion_net.load_state_dict(state_dict)

# ---------------------------------------------------------------------------
# Stage I: translation network (brightfield + aux channel -> nuclei image).
# ---------------------------------------------------------------------------
cfg = _load_cfg('translation')
translation_ckpt = './checkpoint/stage_i.pkl'
imc_net = HighResEnhancer(model_name=cfg['MODEL']['TIMM_MODEL'],
                          in_channels=cfg['MODEL']['IMC_IN'],
                          out_channels=cfg['MODEL']['IMC_OUT'],
                          norm=cfg['MODEL']['NORM'],
                          use_dilated_bottleneck=True)
state_dict = torch.load(translation_ckpt, map_location='cpu')['imc_G']
# Remove the leading 'module.0.' from the keys.
state_dict = _strip_prefix(state_dict, 'module.0.')
# Discard the checkpoint's copy of the Sobel filter weight so the network keeps
# the value it was constructed with; tolerate checkpoints that lack the key.
_SOBEL_KEY = 'sobel.filter.weight'
state_dict.pop(_SOBEL_KEY, None)
# strict=False lets the load proceed without the Sobel key; report any other
# mismatch instead of silently running with partially loaded weights.
load_report = imc_net.load_state_dict(state_dict, strict=False)
missing_keys = [k for k in load_report.missing_keys if k != _SOBEL_KEY]
if missing_keys or load_report.unexpected_keys:
    print("warning: stage I checkpoint mismatch; missing keys: ", missing_keys,
          "; unexpected keys: ", load_report.unexpected_keys)

convertion_net.eval().to(device)
imc_net.eval().to(device)

# Demo-sample metadata: load_image() looks rows up by 'name' and reads 'path';
# the per-tissue subsets below are filtered on 'source'.
df = pd.read_csv('./test_data/test_metadata.csv')
breast_df = df[df['source'] == 'BreastCancer_V2']
prostatic_df = df[df['source'] == 'ProstaticCancer_V2']
pancreatic_df = df[df['source'] == 'PancreaticCancer_V2']


def load_image(pair_index):
    """Load the demo sample named *pair_index* into the two image components.

    Looks the name up in ``df['name']``, reads ``arr_0`` from the ``.npz`` file
    at that row's ``path`` and returns ``data[:, :, 0]`` (brightfield) and
    ``data[:, :, 1]`` (auxiliary image) as ``gr.Image`` updates. If several rows
    share the name the first one is used; if none match (e.g. the selector was
    reset to ``None``) both images are cleared.
    """
    matches = df[df['name'] == pair_index]
    if matches.empty:
        return gr.Image(value=None), gr.Image(value=None)
    item = matches.iloc[0]
    # The NpzFile keeps the archive open until closed: read the array, then release it.
    with np.load(item['path']) as archive:
        data = archive['arr_0']
    x1 = data[:, :, 0]
    x2 = data[:, :, 1]
    return gr.Image(value=x1), gr.Image(value=x2)


def generate_imc(x1, x2, marker_name):
    """Run the two-stage pipeline and return the visualisation as a ``gr.Image``.

    Args:
        x1: brightfield image from the ``gr.Image(type="numpy")`` component;
            divided by 255, so presumably (H, W, 3) uint8 -- TODO confirm.
        x2: auxiliary image; only channel index 2 is used, appended to ``x1``
            as an extra stage I input channel.
        marker_name: marker label, encoded by ``prompt_model`` as the stage II prompt.

    Returns:
        ``gr.Image`` holding a uint8 array made by concatenating, on the channel
        axis, the predicted marker, a zero block shaped like the nuclei
        prediction, and the predicted nuclei (with single-channel outputs this
        reads as marker=red, nuclei=blue -- assumption, confirm IMC_OUT).

    Raises:
        gr.Error: if no image or no marker has been selected yet.
    """
    if x1 is None or x2 is None:
        raise gr.Error("Please select a brightfield image first.")
    if not marker_name:
        raise gr.Error("Please select a marker first.")

    # Inference only: without no_grad every forward pass would build an
    # autograd graph and keep its activations alive in memory.
    with torch.no_grad():
        # ---- stage I: brightfield (+ aux channel) -> nuclei ----
        stacked = np.concatenate([x1, x2[:, :, 2:3]], axis=2) / 255.0  # [0, 1]
        inputs = torch.from_numpy(stacked.transpose(2, 0, 1)).unsqueeze(0).float()
        inputs = 2 * inputs - 1  # rescale to [-1, 1]
        nuclei = torch.tanh(sliding_window_inference(inputs.to(device), roi_size=(320, 320),
                                                     sw_batch_size=2, predictor=imc_net,
                                                     overlap=0.5))
        pred_nuclei = ((nuclei[0] + 1) / 2).permute(1, 2, 0).cpu().numpy()  # HWC in [0, 1]

        # ---- stage II: nuclei + marker prompt -> marker ----
        # The tanh output is already in [-1, 1], the range stage II is fed, so it
        # is passed on directly instead of round-tripping through [0, 1] on the CPU.
        prompt_in = torch.as_tensor(prompt_model([marker_name])).to(device)
        marker_out = torch.tanh(convertion_net(nuclei, prompt_in))
        marker = ((marker_out[0] + 1) / 2).permute(1, 2, 0).cpu().numpy()  # HWC in [0, 1]

    # visualisation: [marker | zeros | nuclei] stacked on the channel axis
    vis = np.concatenate([marker, np.zeros_like(pred_nuclei, dtype=np.float32), pred_nuclei], axis=2)
    # scale to [0, 255] and convert to uint8
    vis = (vis * 255).astype(np.uint8)
    return gr.Image(value=vis)
# Update the image and marker dropdowns based on the tissue dropdown's selection.
def update_dropdown_by_tissue(selected_category):
    """Return refreshed ``[image_selector, marker_selector]`` dropdowns for a tissue.

    Args:
        selected_category: the tissue dropdown value -- "Breast", "Pancreatic",
            "Prostatic", or anything else (e.g. ``None`` if it is cleared).

    Returns:
        Two ``gr.Dropdown`` updates: the demo image names of that tissue and its
        marker list, each preselecting its first entry (or ``None`` if the list
        is empty). For an unrecognised tissue both dropdowns are reset to empty
        and non-interactive, matching their initial state in the UI.
    """
    tissue_options = {
        "Breast": (breast_df, breast_markers),
        "Pancreatic": (pancreatic_df, pancreatic_markers),
        "Prostatic": (prostatic_df, prostatic_markers),
    }
    if selected_category not in tissue_options:
        return [gr.Dropdown(choices=[], value=None, interactive=False),
                gr.Dropdown(choices=[], value=None, interactive=False)]

    tissue_df, markers = tissue_options[selected_category]
    image_names = tissue_df['name'].values.tolist()
    image_selector = gr.Dropdown(choices=image_names,
                                 value=image_names[0] if len(image_names) > 0 else None,
                                 interactive=True)
    marker_selector = gr.Dropdown(choices=markers,
                                  value=markers[0] if len(markers) > 0 else None,
                                  interactive=True)
    return [image_selector, marker_selector]


# Create the Gradio interface
def create_gradio_ui():
    """Build and return the demo ``gr.Blocks`` app (not launched).

    Wiring: changing the tissue refreshes the image and marker selectors;
    choosing an image loads it into the brightfield viewer and the hidden
    auxiliary component; "Predict IMC" runs ``generate_imc`` on those two
    images and the selected marker.
    """
    with gr.Blocks() as demo:
        with gr.Tab("Mbi2Spi"):
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Row():
                        # image visualizer; `aux` is hidden and only feeds generate_imc
                        brightfield = gr.Image(label="Brightfield Image", type="numpy", interactive=False)
                        aux = gr.Image(type="numpy", visible=False, interactive=False)
                    with gr.Row():
                        with gr.Column():
                            # tissue selector (Breast, Pancreatic, Prostatic)
                            tissue_selector = gr.Dropdown(choices=["Breast", "Pancreatic", "Prostatic"],
                                                          label="Select Tissue Type")
                            # marker selector, populated once a tissue is chosen
                            marker_selector = gr.Dropdown(label="Marker Selector", interactive=False)
                        with gr.Column():
                            # image selector, populated once a tissue is chosen
                            image_selector = gr.Dropdown(label="Brightfield Selector", interactive=False)
                    # update the image and marker selectors based on the tissue type
                    tissue_selector.change(update_dropdown_by_tissue, inputs=tissue_selector,
                                           outputs=[image_selector, marker_selector])
                with gr.Column(scale=1):
                    output_image = gr.Image(label="Generated Image", type="numpy")
                    button1 = gr.Button("Predict IMC")
            # Load the selected sample into the brightfield viewer and the hidden aux image
            image_selector.change(load_image, inputs=image_selector, outputs=[brightfield, aux])
            # Event handler for button click
            button1.click(generate_imc, inputs=[brightfield, aux, marker_selector], outputs=output_image)
    return demo


# Launch the demo
if __name__ == '__main__':
    demo = create_gradio_ui()
    demo.launch(show_error=True)