"""Gradio demo: reconstruct a spectral cube from an RGB image with MST++.

Individual bands of the reconstruction can be browsed next to a ground-truth
cube loaded from a ``.mat`` (HDF5) file.
"""

from typing import Any, Optional

import cv2  # noqa: F401  (not used in the visible code; kept in case it is relied on)
import gradio as gr
import h5py
import numpy as np
import PIL.Image
import torch

from test_develop_code.architecture import model_generator

device = torch.device("cpu")
model = model_generator("mst_plus_plus", "mst_plus_plus.pth").to(device)
model.eval()

# One centre wavelength per band; cube[i] is rendered with wavelengths[i].
# Values are presumably nanometres (they match the breakpoints used in
# wavelength_to_rgb).
wavelengths = np.linspace(400, 700, 31)
# Highest valid band index; used as the slider maximum so the UI and the
# wavelength table cannot drift apart.
MAX_BAND_INDEX = len(wavelengths) - 1


def wavelength_to_rgb(wl: float) -> tuple[float, float, float]:
    """Map a wavelength to an approximate (R, G, B) colour, each in [0, 1].

    Piecewise-linear approximation over 380-700; anything outside that range
    maps to black (0, 0, 0).
    """
    if 380 <= wl <= 440:
        R = -(wl - 440) / (440 - 380)
        G = 0.0
        B = 1.0
    elif 440 < wl <= 490:
        R = 0.0
        G = (wl - 440) / (490 - 440)
        B = 1.0
    elif 490 < wl <= 510:
        R = 0.0
        G = 1.0
        B = -(wl - 510) / (510 - 490)
    elif 510 < wl <= 580:
        R = (wl - 510) / (580 - 510)
        G = 1.0
        B = 0.0
    elif 580 < wl <= 645:
        R = 1.0
        G = -(wl - 645) / (645 - 580)
        B = 0.0
    elif 645 < wl <= 700:
        R = 1.0
        G = 0.0
        B = 0.0
    else:
        R = G = B = 0.0
    return (max(R, 0.0), max(G, 0.0), max(B, 0.0))


def predict(img: Optional[np.ndarray]) -> Optional[np.ndarray]:
    """Reconstruct a spectral cube from an RGB image.

    Args:
        img: Image from ``gr.Image(type="numpy")`` laid out as (H, W, C), or
            ``None`` when the input component is cleared.

    Returns:
        The model output with the batch axis removed, clipped to [0, 1].
        Presumably shaped (bands, H, W) (confirm against the model), which is
        how visualize_channel indexes it. Returns ``None`` if *img* is ``None``.
    """
    if img is None:
        # Gradio fires .change with None when the image is cleared.
        return None
    # Guard against grayscale (H, W) or RGBA (H, W, 4) inputs; the model
    # consumes exactly three channels.
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)
    img = img[..., :3].astype(np.float32)
    # Gradio's default image_mode is "RGB", so no BGR->RGB conversion is needed.
    # Per-image min-max normalisation to [0, 1]; epsilon avoids 0/0 on flat images.
    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
    img = np.ascontiguousarray(np.transpose(img, (2, 0, 1)))  # HWC -> CHW
    img_tensor = torch.from_numpy(img).unsqueeze(0).to(device)
    with torch.no_grad():
        pred = model(img_tensor)
    pred = pred.squeeze(0).cpu().numpy()
    return np.clip(pred, 0, 1)


def visualize_channel(cube: Optional[np.ndarray], index: float) -> Optional[PIL.Image.Image]:
    """Render band *index* of *cube* as an image tinted by its wavelength.

    The band is min-max normalised, then multiplied by the RGB colour of
    ``wavelengths[index]``.

    Args:
        cube: Array indexed as ``cube[band]`` giving a 2-D band, or ``None``.
        index: Band index from a slider. It may arrive as a float, so it is
            converted with ``int()``.

    Returns:
        An 8-bit RGB PIL image, or ``None`` if *cube* is ``None`` (which
        clears the output component).
    """
    if cube is None:
        return None
    index = int(index)  # numpy rejects float indices
    band = cube[index]
    band = (band - band.min()) / (band.max() - band.min() + 1e-8)
    color = np.asarray(wavelength_to_rgb(wavelengths[index]))
    rgb = band[..., np.newaxis] * color  # (H, W) x (3,) -> (H, W, 3)
    return PIL.Image.fromarray((rgb * 255).astype(np.uint8))


def load_mat(mat_file: Any) -> Optional[np.ndarray]:
    """Load the ``"cube"`` dataset from an HDF5-based ``.mat`` upload.

    Args:
        mat_file: Value produced by ``gr.File``. This is either a filepath
            string (Gradio 4 default) or a tempfile-like object with a
            ``.name`` attribute (Gradio 3), or ``None`` when cleared.

    Returns:
        The cube with its last two axes swapped, clipped to [0, 1], or
        ``None`` if no file is given.
    """
    if mat_file is None:
        return None
    path = getattr(mat_file, "name", mat_file)
    with h5py.File(path, "r") as f:
        cube = np.array(f["cube"])
    # Swap the spatial axes. This is presumably because MATLAB v7.3 files are
    # stored column-major, so h5py sees them transposed; confirm against the
    # dataset.
    cube = np.transpose(cube, (0, 2, 1))
    return np.clip(cube, 0, 1)


with gr.Blocks() as demo:
    gr.Markdown("## Spectral Reconstruction")

    with gr.Row():
        with gr.Column():
            rgb_input = gr.Image(type="numpy", label="Upload RGB Image")
            pred_state = gr.State()
        with gr.Column():
            pred_output = gr.Image(label="Prediction Visualization")
            pred_slider = gr.Slider(
                minimum=0,
                maximum=MAX_BAND_INDEX,
                step=1,
                label="Channel (Prediction)",
                value=0,
            )

    with gr.Row():
        with gr.Column():
            mat_input = gr.File(label="Upload .mat file (Ground Truth)", file_types=[".mat"])
            gt_state = gr.State()
        with gr.Column():
            gt_output = gr.Image(label="Ground Truth Visualization")
            gt_slider = gr.Slider(
                minimum=0,
                maximum=MAX_BAND_INDEX,
                step=1,
                label="Channel (Ground Truth)",
                value=0,
            )

    # After computing/loading a cube, immediately render the currently selected
    # band. Otherwise nothing is shown until the user moves a slider, and
    # examples never render because their slider value (0) triggers no change.
    rgb_input.change(fn=predict, inputs=rgb_input, outputs=pred_state).then(
        fn=visualize_channel, inputs=[pred_state, pred_slider], outputs=pred_output
    )
    pred_slider.change(fn=visualize_channel, inputs=[pred_state, pred_slider], outputs=pred_output)

    mat_input.change(fn=load_mat, inputs=mat_input, outputs=gt_state).then(
        fn=visualize_channel, inputs=[gt_state, gt_slider], outputs=gt_output
    )
    gt_slider.change(fn=visualize_channel, inputs=[gt_state, gt_slider], outputs=gt_output)

    # Examples only populate the inputs; the .change/.then chains above do the
    # rendering. No `outputs=` is given because Examples has no `fn` to
    # produce them.
    gr.Examples(
        examples=[
            ["assets/ARAD_1K_0001.jpg", 0, "assets/ARAD_1K_0001.mat", 0],
            ["assets/ARAD_1K_0002.jpg", 0, "assets/ARAD_1K_0002.mat", 0],
            ["assets/ARAD_1K_0003.jpg", 0, "assets/ARAD_1K_0003.mat", 0],
            ["assets/ARAD_1K_0004.jpg", 0, "assets/ARAD_1K_0004.mat", 0],
            ["assets/ARAD_1K_0005.jpg", 0, "assets/ARAD_1K_0005.mat", 0],
        ],
        inputs=[rgb_input, pred_slider, mat_input, gt_slider],
        label="Try Examples",
    )


if __name__ == "__main__":
    demo.launch()