# Page residue from the Hugging Face file view (kept as comments so the file parses):
# Conn-Cerberus's picture
# Update app.py
# 3b6b171 verified
# raw
# history blame contribute delete
# 3.07 kB
import inspect

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import models, transforms
# Load model: a ResNet-18 whose final fully-connected layer is swapped for a
# head producing one logit per entry in `classes`, then the fine-tuned
# weights are restored from disk onto the CPU.
classes = ['benign', 'malignant']
model = models.resnet18()
model.fc = torch.nn.Linear(
    in_features=model.fc.in_features,
    out_features=len(classes),
)
model.load_state_dict(
    torch.load("skin_cancer_resnet18_version1.pt", map_location="cpu")
)
# Switch layers such as BatchNorm to inference behaviour.
model.eval()
# Transform: resize every image to a fixed 224x224 and convert the PIL image
# into a float tensor (ToTensor scales 8-bit pixel values into [0, 1]).
# NOTE(review): there is no Normalize step; presumably this matches how the
# checkpoint was trained. Confirm against the training pipeline.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)
# Grad-CAM taps the second convolution of residual block index 1 in layer3.
target_layer = model.layer3[1].conv2
# Module-level slots that forward_hook / backward_hook overwrite on each pass.
activations = gradients = None
def forward_hook(module, input, output):
    """Forward hook: remember the hooked layer's output for Grad-CAM.

    Stores a detached copy (no autograd history) of ``output`` in the
    module-level ``activations``; ``module`` and ``input`` are unused.
    """
    global activations
    detached_output = output.detach()
    activations = detached_output
def backward_hook(module, grad_input, grad_output):
    """Backward hook: remember the gradient flowing into the layer's output.

    Only the first element of the ``grad_output`` tuple is kept, detached,
    in the module-level ``gradients``; ``module`` and ``grad_input`` are unused.
    """
    global gradients
    output_grad = grad_output[0]
    gradients = output_grad.detach()
# Attach the Grad-CAM hooks. `register_full_backward_hook` replaces the
# deprecated `register_backward_hook`, whose grad_output is only reliable for
# modules that execute a single autograd operation.
target_layer.register_forward_hook(forward_hook)
target_layer.register_full_backward_hook(backward_hook)
# Grad-CAM function
def generate_gradcam(input_tensor, class_idx):
    """Compute a Grad-CAM heatmap explaining one class score.

    Runs a forward and backward pass through ``model`` so the hooks on
    ``target_layer`` populate the module-level ``activations`` and
    ``gradients``. Each activation channel is then weighted by its
    spatially averaged gradient.

    Args:
        input_tensor: batched image tensor fed to ``model``. Only batch item 0
            is used for the map (assumes shape (1, C, H, W); TODO confirm for
            other callers).
        class_idx: index of the output logit to explain.

    Returns:
        A float numpy array of shape (H, W), matching the input's spatial
        size, rescaled to [0, 1]. If no region contributes positively to the
        class score, the all-zero map is returned instead of NaNs.

    Raises:
        RuntimeError: if the hooks did not capture activations or gradients.
    """
    model.zero_grad()
    output = model(input_tensor)
    class_score = output[0, class_idx]
    class_score.backward()
    if activations is None or gradients is None:
        raise RuntimeError("Grad-CAM hooks did not capture activations/gradients")

    # Average the gradients over batch and space: one importance weight per channel.
    pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])  # [C]
    weighted_activations = activations[0] * pooled_gradients[:, None, None]  # [C, H, W]
    cam = torch.sum(weighted_activations, dim=0).cpu().numpy()

    # Keep only positive evidence and upsample to the input resolution
    # (cv2.resize takes (width, height)).
    cam = np.maximum(cam, 0)
    height, width = input_tensor.shape[-2:]
    cam = cv2.resize(cam, (int(width), int(height)))

    # Min-max normalise. Guard the all-zero case, which would otherwise
    # divide by zero and produce NaNs.
    cam -= cam.min()
    peak = cam.max()
    if peak > 0:
        cam /= peak
    return cam
# Full pipeline
def predict(img):
    """Classify a lesion image and return its Grad-CAM overlay.

    Args:
        img: PIL image from the Gradio input, or None when nothing was uploaded.

    Returns:
        A ``(label_probs, overlay_img)`` tuple:
        - a dict mapping each class name in ``classes`` to its softmax probability;
        - an RGB PIL image with the Grad-CAM heatmap blended over the resized input.

    Raises:
        gr.Error: when no image was provided, so Gradio shows a message
            instead of a traceback.
    """
    global activations, gradients
    if img is None:
        raise gr.Error("Please upload an image.")
    # Clear hook state so a failed capture cannot silently reuse a previous image's maps.
    activations = None
    gradients = None

    img = img.convert("RGB")
    input_tensor = transform(img).unsqueeze(0)

    # Probabilities only. Grad-CAM runs its own forward/backward pass,
    # so no autograd graph is needed here.
    with torch.no_grad():
        output = model(input_tensor)
    probs = F.softmax(output[0], dim=0)
    pred_class = int(torch.argmax(probs))
    cam = generate_gradcam(input_tensor, pred_class)

    # applyColorMap returns BGR. Convert the heatmap to RGB so it matches
    # the PIL image before blending; the blended result is then already RGB.
    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # Resize the photo to the heatmap's size so addWeighted gets matching shapes.
    cam_height, cam_width = heatmap.shape[:2]
    img_np = np.array(img.resize((cam_width, cam_height)))
    overlay = cv2.addWeighted(img_np, 0.6, heatmap, 0.4, 0)
    overlay_img = Image.fromarray(overlay)

    label_probs = {name: float(p) for name, p in zip(classes, probs.tolist())}
    return label_probs, overlay_img
# Gradio interface
title = "🧠 Soma: Skin Cancer Classifier + Grad-CAM"
description = """
🔐 Privacy Disclaimer
🛡️ Your data is private.
This tool does not store your images, personal information, or results. All image analysis is performed temporarily in memory, and no files are saved or shared.
Your uploaded image is used only for the current prediction, then immediately discarded.
⚠️ This is not a diagnostic tool. For medical concerns, always consult a healthcare professional.
"""
# The notice above promises nothing is saved. Gradio's default flagging mode,
# however, adds a "Flag" button that writes the input image and outputs to a
# flagging directory on disk, so disable flagging explicitly. Gradio 5
# renamed `allow_flagging` to `flagging_mode`; use whichever keyword the
# installed version's Interface accepts.
_interface_params = inspect.signature(gr.Interface.__init__).parameters
_flagging_option = (
    "flagging_mode" if "flagging_mode" in _interface_params else "allow_flagging"
)
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Lesion Image"),
    outputs=[
        gr.Label(num_top_classes=2, label="Prediction"),
        gr.Image(type="pil", label="Grad-CAM Visualisation"),
    ],
    title=title,
    description=description,
    **{_flagging_option: "never"},
)
demo.launch()