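"""Gradio demo for two-stage IMC marker prediction.

Stage I (HighResEnhancer) translates a brightfield image plus an auxiliary
channel into a nuclei image; stage II (PromptAttentionUNet) converts the
nuclei image into a marker image conditioned on a BiomedCLIP text prompt.
"""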
import os
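# pin the demo to a single GPU; must be set before torch initializes CUDA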
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
import gradio as gr
import yaml
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
from models.modules.networks import PromptAttentionUNet, HighResEnhancer
from models.modules.biomedclip import BiomedCLIPTextEncoder
from monai.inferers import sliding_window_inference
from markers import breast_markers, prostatic_markers, pancreatic_markers

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# load the config for the conversion (stage II) model
# note: the config file on disk keeps the original 'convertion' spelling
cfg_file = 'configs/convertion.cfg'
with open(cfg_file, 'r') as f:
    cfg = yaml.safe_load(f)
    print("successfully loaded config file: ", cfg)

# conversion model (stage II) and its text-prompt encoder
convertion_ckpt = './checkpoint/stage_ii.pkl'
convertion_net = PromptAttentionUNet(in_channels=cfg['MODEL']['IMC_IN'], out_channels=cfg['MODEL']['IMC_OUT'], channels=(128, 256, 512, 1024, 2048))
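# text encoder that turns a marker name into a prompt embedding for stage II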
prompt_model = BiomedCLIPTextEncoder(device=device)

# load state_dict
state_dict = torch.load(convertion_ckpt, map_location='cpu')['generator']
# strip the DataParallel 'module.' prefix from the checkpoint keys
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
convertion_net.load_state_dict(state_dict)

# load the translation (stage I) model
cfg_file = 'configs/translation.cfg'
with open(cfg_file, 'r') as f:
    cfg = yaml.safe_load(f)
    print("successfully loaded config file: ", cfg)
translation_ckpt = './checkpoint/stage_i.pkl'

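# stage I generator: brightfield + auxiliary channel -> nuclei image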
imc_net = HighResEnhancer(model_name=cfg['MODEL']['TIMM_MODEL'],
                          in_channels=cfg['MODEL']['IMC_IN'],
                          out_channels=cfg['MODEL']['IMC_OUT'],
                          norm=cfg['MODEL']['NORM'],
                          use_dilated_bottleneck=True)

# load state_dict for IMC
state_dict = torch.load(translation_ckpt, map_location='cpu')['imc_G']
# strip the DataParallel 'module.0.' prefix from the checkpoint keys
state_dict = {k.replace('module.0.', ''): v for k, v in state_dict.items()}
# drop the fixed Sobel filter weights; the model re-creates them,
# so the remaining key mismatch is loaded with strict=False
state_dict.pop('sobel.filter.weight', None)
imc_net.load_state_dict(state_dict, strict=False)

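# inference only: put both generators in eval mode on the selected device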
convertion_net.eval().to(device)
imc_net.eval().to(device)

# load the metadata for demo data
df = pd.read_csv('./test_data/test_metadata.csv')
breast_df = df[df['source'] == 'BreastCancer_V2']
prostatic_df = df[df['source'] == 'ProstaticCancer_V2']
pancreatic_df = df[df['source'] == 'PancreaticCancer_V2']


def load_image(pair_index):
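    """Look up a demo pair by name and return its two image channels."""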
    # select the item from the dataframe and convert to Series using `squeeze()`
    item = df[df['name'] == pair_index].squeeze()
    data = np.load(item['path'])['arr_0']
    x1 = data[:, :, 0]
    x2 = data[:, :, 1]
    return gr.Image(value=x1), gr.Image(value=x2)


@torch.no_grad()  # inference only; no autograd graphs needed
def generate_imc(x1, x2, marker_name):
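    """Run the two-stage pipeline: brightfield -> nuclei -> marker image."""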
    # stage I: translate brightfield + one auxiliary channel into a nuclei image
    inputs = np.concatenate([x1, x2[:, :, 2:3]], axis=2)
    # normalize to [0, 1]
    inputs = inputs / 255.0
    # to tensor
    inputs = torch.from_numpy(inputs.transpose(2, 0, 1)).unsqueeze(0).float()
    # rescale to [-1, 1]
    inputs = 2 * inputs - 1
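    # tile the image into 320x320 windows with 50% overlap for memory-friendly inference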
    output = sliding_window_inference(inputs.to(device), roi_size=(320, 320), sw_batch_size=2, predictor=imc_net, overlap=0.5)
    output = torch.tanh(output)
    # to numpy
    pred_nuclei = output[0].detach().cpu().numpy().transpose(1, 2, 0)
    pred_nuclei = (pred_nuclei + 1) / 2  # normalize to [0, 1]
    # stage II: convert the predicted nuclei image into the requested marker
    nuclei_inputs = torch.from_numpy(pred_nuclei).permute(2, 0, 1).unsqueeze(0).float()
    # rescale to [-1, 1]
    nuclei_inputs = 2 * nuclei_inputs - 1
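    # encode the marker name into a text prompt embedding for the conversion network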
    prompt_in = torch.as_tensor(prompt_model([marker_name])).to(device)
    output = torch.tanh(convertion_net(nuclei_inputs.to(device), prompt_in))
    marker = output[0].detach().cpu().numpy().transpose(1, 2, 0)
    marker = (marker + 1) / 2 # normalize to [0, 1]
    # visualization: marker in the red channel, predicted nuclei in the blue channel
    vis = np.concatenate([marker, np.zeros_like(pred_nuclei, dtype=np.float32), pred_nuclei], axis=2)
    # normalize to [0, 255] and convert to uint8
    vis = (vis * 255).astype(np.uint8)
    return gr.Image(value=vis)


# Update the image and marker dropdowns to match the selected tissue type
def update_dropdown_by_tissue(selected_category):
    options = {
        "Breast": (breast_df, breast_markers),
        "Pancreatic": (pancreatic_df, pancreatic_markers),
        "Prostatic": (prostatic_df, prostatic_markers),
    }
    tissue_df, markers = options[selected_category]
    names = tissue_df['name'].values.tolist()
    image_selector = gr.Dropdown(choices=names, value=names[0], interactive=True)
    marker_selector = gr.Dropdown(choices=markers, value=markers[0], interactive=True)
    return [image_selector, marker_selector]


# Create the Gradio interface
def create_gradio_ui():
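    """Build the Gradio Blocks UI: selectors on the left, prediction on the right."""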
    with gr.Blocks() as demo:
        with gr.Tab("Mbi2Spi"):
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Row():
                        # image visualizer
                        brightfield = gr.Image(label="Brightfield Image", type="numpy", interactive=False)
                        aux = gr.Image(type="numpy", visible=False, interactive=False)
                    
                    with gr.Row():
                        with gr.Column():
                            # tissue selector (Breast, Pancreatic, Prostatic)
                            tissue_selector = gr.Dropdown(choices=["Breast", "Pancreatic", "Prostatic"], label="Select Tissue Type")
                            # marker selector
                            marker_selector = gr.Dropdown(label="Marker Selector", interactive=False)
                            
                        with gr.Column():
                            # image selector
                            image_selector = gr.Dropdown(label="Brightfield Selector", interactive=False)
                            # update the image selector based on the tissue type
                            tissue_selector.change(update_dropdown_by_tissue, inputs=tissue_selector, outputs=[image_selector, marker_selector])

                with gr.Column(scale=1):
                    output_image = gr.Image(label="Generated Image", type="numpy")
                    button1 = gr.Button("Predict IMC")

                    # Load the selected pair and update the brightfield and auxiliary (infrared) images
                    image_selector.change(load_image, inputs=image_selector, outputs=[brightfield, aux])

                    # Event handler for button click
                    button1.click(generate_imc, inputs=[brightfield, aux, marker_selector], outputs=output_image)

    return demo

# Launch the demo
if __name__ == '__main__':
    demo = create_gradio_ui()
    demo.launch(show_error=True)