# Captured from a Hugging Face Space page (Space status at capture: Sleeping).
# File size: 7,120 bytes; revision f97a499.
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
import gradio as gr
import yaml
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
from models.modules.networks import PromptAttentionUNet, HighResEnhancer
from models.modules.biomedclip import BiomedCLIPTextEncoder
from monai.inferers import sliding_window_inference
from markers import breast_markers, prostatic_markers, pancreatic_markers
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---------------------------------------------------------------------------
# Stage II ("convertion") model: a PromptAttentionUNet conditioned on a
# BiomedCLIP text embedding of the marker name.
# ---------------------------------------------------------------------------
# load the cfg file for conversion
cfg_file = 'configs/{}.cfg'.format('convertion')
with open(cfg_file, 'r', encoding='utf-8') as f:
    cfg = yaml.safe_load(f)
print("successfully loaded config file: ", cfg)

convertion_ckpt = './checkpoint/stage_ii.pkl'
convertion_net = PromptAttentionUNet(
    in_channels=cfg['MODEL']['IMC_IN'],
    out_channels=cfg['MODEL']['IMC_OUT'],
    channels=(128, 256, 512, 1024, 2048),
)
prompt_model = BiomedCLIPTextEncoder(device=device)

# NOTE(review): torch.load unpickles arbitrary objects. That is acceptable only
# because this checkpoint ships with the app; never point it at uploaded files.
state_dict = torch.load(convertion_ckpt, map_location='cpu')['generator']
# Strip the leading 'module.' (presumably added by a DataParallel wrapper at
# training time). Only the prefix is removed: a blanket str.replace would also
# mangle any parameter name that merely contains 'module.' further in.
state_dict = {
    (key[len('module.'):] if key.startswith('module.') else key): value
    for key, value in state_dict.items()
}
convertion_net.load_state_dict(state_dict)
# ---------------------------------------------------------------------------
# Stage I ("translation") model: a HighResEnhancer whose output generate_imc
# treats as the predicted nuclei map.
# ---------------------------------------------------------------------------
cfg_file = 'configs/{}.cfg'.format('translation')
with open(cfg_file, 'r', encoding='utf-8') as f:
    cfg = yaml.safe_load(f)
print("successfully loaded config file: ", cfg)

translation_ckpt = './checkpoint/stage_i.pkl'
imc_net = HighResEnhancer(
    model_name=cfg['MODEL']['TIMM_MODEL'],
    in_channels=cfg['MODEL']['IMC_IN'],
    out_channels=cfg['MODEL']['IMC_OUT'],
    norm=cfg['MODEL']['NORM'],
    use_dilated_bottleneck=True,
)

# NOTE(review): torch.load unpickles arbitrary objects; only use it on the
# checkpoint bundled with the app.
state_dict = torch.load(translation_ckpt, map_location='cpu')['imc_G']
# Strip only the leading 'module.0.' prefix. str.replace would also rewrite
# matching text in the middle of a key, and with strict=False below those
# weights would be dropped silently.
state_dict = {
    (key[len('module.0.'):] if key.startswith('module.0.') else key): value
    for key, value in state_dict.items()
}
# Discard the stored Sobel filter weight before loading. Pop with a default so
# a checkpoint that lacks this key still loads instead of raising KeyError.
state_dict.pop('sobel.filter.weight', None)
# strict=False tolerates key mismatches, but report them so a bad checkpoint
# does not go unnoticed.
load_result = imc_net.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print("IMC model: missing keys:", load_result.missing_keys,
          "unexpected keys:", load_result.unexpected_keys)

# Both networks are used for inference only.
convertion_net.eval().to(device)
imc_net.eval().to(device)
# Demo-sample metadata, plus one sub-table per tissue source. The sub-tables
# feed the tissue-specific dropdowns.
df = pd.read_csv('./test_data/test_metadata.csv')
breast_df = df.loc[df['source'].eq('BreastCancer_V2')]
prostatic_df = df.loc[df['source'].eq('ProstaticCancer_V2')]
pancreatic_df = df.loc[df['source'].eq('PancreaticCancer_V2')]
def load_image(pair_index):
    """Load the stored image pair for the demo sample selected in the UI.

    Args:
        pair_index: value of the metadata ``name`` column (the Brightfield
            Selector's current value). May be ``None`` when nothing is selected.

    Returns:
        A pair of outputs for the (brightfield, aux) image components. They
        hold ``data[:, :, 0]`` and ``data[:, :, 1]`` of the sample's stored
        ``arr_0`` array, or ``None, None`` (clearing both images) when no
        sample matches ``pair_index``.
    """
    matches = df[df['name'] == pair_index]
    if matches.empty:
        # Nothing selected, or a stale name: clear the images instead of
        # failing inside np.load.
        return None, None
    # Take the first row explicitly. .squeeze() returns a DataFrame rather
    # than a row when the name is duplicated, which breaks item['path'].
    path = matches.iloc[0]['path']
    # np.load on an .npz returns an NpzFile that keeps its file handle open;
    # the context manager closes it once the array has been read.
    with np.load(path) as archive:
        data = archive['arr_0']
    x1 = data[:, :, 0]
    x2 = data[:, :, 1]
    return gr.Image(value=x1), gr.Image(value=x2)
def generate_imc(x1, x2, marker_name):
    """Predict an IMC visualization for one brightfield sample and one marker.

    Stage I stacks ``x1`` with channel 2 of ``x2`` and runs ``imc_net`` via
    sliding-window inference to get a nuclei prediction. Stage II feeds that
    prediction, together with the text embedding of ``marker_name``, to
    ``convertion_net`` to get the marker prediction.

    Args:
        x1: brightfield image from the UI, H x W x C with values in [0, 255]
            (assumed uint8 — TODO confirm).
        x2: hidden auxiliary image, H x W x >=3; only channel 2 is used.
        marker_name: marker string used as the text prompt.

    Returns:
        ``gr.Image`` holding a uint8 image built from [marker, zeros, nuclei]
        stacked along the channel axis.

    Raises:
        gr.Error: if no image or no marker has been selected yet.
    """
    if x1 is None or x2 is None:
        raise gr.Error("Please select a brightfield image first.")
    if not marker_name:
        raise gr.Error("Please select a marker first.")

    # Pure inference: disabling autograd keeps each sliding window's
    # activations from being retained for a backward pass that never happens.
    with torch.no_grad():
        # stage I: rescale [0, 255] -> [0, 1] -> [-1, 1], HWC -> 1CHW
        inputs = np.concatenate([x1, x2[:, :, 2:3]], axis=2) / 255.0
        inputs = torch.from_numpy(inputs.transpose(2, 0, 1)).unsqueeze(0).float()
        inputs = 2 * inputs - 1
        nuclei_out = torch.tanh(
            sliding_window_inference(
                inputs.to(device),
                roi_size=(320, 320),
                sw_batch_size=2,
                predictor=imc_net,
                overlap=0.5,
            )
        )
        # stage II: tanh already puts the stage-I output in [-1, 1], the range
        # stage II is fed, so it is used directly without a CPU round trip.
        prompt_in = torch.as_tensor(prompt_model([marker_name])).to(device)
        marker_out = torch.tanh(convertion_net(nuclei_out, prompt_in))

    # (1, C, H, W) tensors in [-1, 1] -> H x W x C arrays in [0, 1]
    pred_nuclei = (nuclei_out[0].cpu().numpy().transpose(1, 2, 0) + 1) / 2
    marker = (marker_out[0].cpu().numpy().transpose(1, 2, 0) + 1) / 2

    # visualization: marker, an all-zero channel, nuclei. This reads as R/G/B
    # when each prediction has a single channel (TODO confirm IMC_OUT == 1).
    vis = np.concatenate(
        [marker, np.zeros_like(pred_nuclei, dtype=np.float32), pred_nuclei],
        axis=2,
    )
    # [0, 1] -> [0, 255] as uint8
    vis = (vis * 255).astype(np.uint8)
    return gr.Image(value=vis)
def update_dropdown_by_tissue(selected_category):
    """Refresh the image and marker dropdowns for the selected tissue type.

    Args:
        selected_category: the tissue dropdown's value ("Breast",
            "Pancreatic" or "Prostatic"). May be ``None`` when it is cleared.

    Returns:
        ``[image_selector, marker_selector]`` dropdown updates. For a known
        tissue, they list that tissue's samples and markers with the first
        entry preselected. Otherwise both are emptied and disabled.
    """
    tissue_sources = {
        "Breast": (breast_df, breast_markers),
        "Pancreatic": (pancreatic_df, pancreatic_markers),
        "Prostatic": (prostatic_df, prostatic_markers),
    }
    if selected_category not in tissue_sources:
        # Unknown or cleared selection: return explicit empty updates rather
        # than falling through with unbound locals (UnboundLocalError).
        return [
            gr.Dropdown(choices=[], value=None, interactive=False),
            gr.Dropdown(choices=[], value=None, interactive=False),
        ]

    tissue_df, markers = tissue_sources[selected_category]
    names = tissue_df['name'].values.tolist()
    # Guard the preselection so an empty table or marker list does not raise
    # IndexError.
    image_selector = gr.Dropdown(
        choices=names,
        value=names[0] if names else None,
        interactive=True,
    )
    marker_selector = gr.Dropdown(
        choices=markers,
        value=markers[0] if len(markers) > 0 else None,
        interactive=True,
    )
    return [image_selector, marker_selector]
def create_gradio_ui():
    """Assemble the Mbi2Spi demo (inputs on the left, prediction on the right).

    Returns:
        The built ``gr.Blocks`` app; the caller is responsible for launching it.
    """
    with gr.Blocks() as app:
        with gr.Tab("Mbi2Spi"):
            with gr.Row():
                # ---- left column: sample preview and selectors ----
                with gr.Column(scale=1):
                    with gr.Row():
                        # The visible preview plus a hidden companion image;
                        # both are passed to generate_imc when predicting.
                        brightfield_view = gr.Image(label="Brightfield Image", type="numpy", interactive=False)
                        hidden_aux = gr.Image(type="numpy", visible=False, interactive=False)
                    with gr.Row():
                        with gr.Column():
                            tissue_dd = gr.Dropdown(choices=["Breast", "Pancreatic", "Prostatic"], label="Select Tissue Type")
                            marker_dd = gr.Dropdown(label="Marker Selector", interactive=False)
                        with gr.Column():
                            sample_dd = gr.Dropdown(label="Brightfield Selector", interactive=False)
                        # Choosing a tissue repopulates the sample and marker lists.
                        tissue_dd.change(update_dropdown_by_tissue, inputs=tissue_dd, outputs=[sample_dd, marker_dd])
                # ---- right column: prediction output ----
                with gr.Column(scale=1):
                    result_view = gr.Image(label="Generated Image", type="numpy")
                    predict_btn = gr.Button("Predict IMC")
            # Choosing a sample loads its image pair into the two image components.
            sample_dd.change(load_image, inputs=sample_dd, outputs=[brightfield_view, hidden_aux])
            # The button runs both prediction stages on the current selection.
            predict_btn.click(generate_imc, inputs=[brightfield_view, hidden_aux, marker_dd], outputs=result_view)
    return app
if __name__ == '__main__':
    # Build the Blocks app, keep it under the module-level name `demo`, and
    # serve it with show_error enabled.
    demo = create_gradio_ui()
    demo.launch(show_error=True)
# (end of captured file)