eksemyashkina commited on
Commit
e6d79e8
·
verified ·
1 Parent(s): b5d81c7

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -0
app.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import gradio as gr
4
+ import cv2
5
+ import h5py
6
+ from test_develop_code.architecture import model_generator
7
+ import PIL.Image
8
+
9
+
10
+ device = torch.device("cpu")
11
+ model = model_generator("mst_plus_plus", "mst_plus_plus.pth").to(device)
12
+ model.eval()
13
+ wavelengths = np.linspace(400, 700, 31)
14
+
15
+
16
+ def wavelength_to_rgb(wl: float) -> tuple[float]:
17
+ if 380 <= wl <= 440:
18
+ R = -(wl - 440) / (440 - 380)
19
+ G = 0.0
20
+ B = 1.0
21
+ elif 440 < wl <= 490:
22
+ R = 0.0
23
+ G = (wl - 440) / (490 - 440)
24
+ B = 1.0
25
+ elif 490 < wl <= 510:
26
+ R = 0.0
27
+ G = 1.0
28
+ B = -(wl - 510) / (510 - 490)
29
+ elif 510 < wl <= 580:
30
+ R = (wl - 510) / (580 - 510)
31
+ G = 1.0
32
+ B = 0.0
33
+ elif 580 < wl <= 645:
34
+ R = 1.0
35
+ G = -(wl - 645) / (645 - 580)
36
+ B = 0.0
37
+ elif 645 < wl <= 700:
38
+ R = 1.0
39
+ G = 0.0
40
+ B = 0.0
41
+ else:
42
+ R = G = B = 0.0
43
+ return (max(R, 0.0), max(G, 0.0), max(B, 0.0))
44
+
45
+
46
+ def predict(img: np.ndarray) -> np.ndarray:
47
+ # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
48
+ img = img.astype(np.float32)
49
+ img = (img - img.min()) / (img.max() - img.min() + 1e-8)
50
+ img = np.transpose(img, (2, 0, 1))
51
+ img_tensor = torch.from_numpy(img).unsqueeze(0).to(device)
52
+ with torch.no_grad():
53
+ pred = model(img_tensor)
54
+ pred = pred.squeeze(0).cpu().numpy()
55
+ pred = np.clip(pred, 0, 1)
56
+ return pred
57
+
58
+
59
+ def visualize_channel(cube: np.ndarray, index: int) -> PIL.Image.Image:
60
+ if cube is None:
61
+ return None
62
+ band = cube[index]
63
+ band = (band - band.min()) / (band.max() - band.min() + 1e-8)
64
+ color = wavelength_to_rgb(wavelengths[index])
65
+ rgb = np.stack([band * c for c in color], axis=-1)
66
+ rgb = (rgb * 255).astype(np.uint8)
67
+ return PIL.Image.fromarray(rgb)
68
+
69
+
70
+ def load_mat(mat_file: gr.File) -> np.ndarray:
71
+ with h5py.File(mat_file.name, "r") as f:
72
+ cube = np.array(f["cube"])
73
+ cube = np.transpose(cube, (0, 2, 1))
74
+ cube = np.clip(cube, 0, 1)
75
+ return cube
76
+
77
+
78
+ with gr.Blocks() as demo:
79
+ gr.Markdown("## Spectral Reconstruction")
80
+ with gr.Row():
81
+ with gr.Column():
82
+ rgb_input = gr.Image(type="numpy", label="Upload RGB Image")
83
+ pred_state = gr.State()
84
+ with gr.Column():
85
+ pred_output = gr.Image(label="Prediction Visualization")
86
+ pred_slider = gr.Slider(minimum=0, maximum=30, step=1, label="Channel (Prediction)", value=0)
87
+ with gr.Row():
88
+ with gr.Column():
89
+ mat_input = gr.File(label="Upload .mat file (Ground Truth)")
90
+ gt_state = gr.State()
91
+ with gr.Column():
92
+ gt_output = gr.Image(label="Ground Truth Visualization")
93
+ gt_slider = gr.Slider(minimum=0, maximum=30, step=1, label="Channel (Ground Truth)", value=0)
94
+ rgb_input.change(fn=predict, inputs=rgb_input, outputs=pred_state)
95
+ pred_slider.change(fn=visualize_channel, inputs=[pred_state, pred_slider], outputs=pred_output)
96
+ mat_input.change(fn=load_mat, inputs=mat_input, outputs=gt_state)
97
+ gt_slider.change(fn=visualize_channel, inputs=[gt_state, gt_slider], outputs=gt_output)
98
+ gr.Examples(
99
+ examples=[
100
+ ["assets/ARAD_1K_0001.jpg", 0, "assets/ARAD_1K_0001.mat", 0],
101
+ ["assets/ARAD_1K_0002.jpg", 0, "assets/ARAD_1K_0002.mat", 0],
102
+ ["assets/ARAD_1K_0003.jpg", 0, "assets/ARAD_1K_0003.mat", 0],
103
+ ["assets/ARAD_1K_0004.jpg", 0, "assets/ARAD_1K_0004.mat", 0],
104
+ ["assets/ARAD_1K_0005.jpg", 0, "assets/ARAD_1K_0005.mat", 0],
105
+ ],
106
+ inputs=[rgb_input, pred_slider, mat_input, gt_slider],
107
+ outputs=[pred_output, gt_output],
108
+ label="Try Examples"
109
+ )
110
+
111
+
112
+ if __name__ == "__main__":
113
+ demo.launch()