Spaces:
Sleeping
Sleeping
import torch | |
import gradio as gr | |
from torchvision import transforms | |
from huggingface_hub import hf_hub_download | |
import segmentation_models_pytorch as smp | |
import numpy as np | |
# Number of segmentation classes predicted by the model. Per the original
# author's note this comes from the dataset's label_colors.txt (4 classes).
NUM_CLASSES = 4

# Class index -> RGB color used when rendering a predicted mask.
# The class names in the comments follow the original author's example and
# are not verified against label_colors.txt here.
COLOR_MAPPING = dict(
    enumerate(
        [
            [0, 0, 0],  # 0: background (black)
            [255, 0, 124],  # 1: oil
            [255, 204, 51],  # 2: others
            [51, 221, 255],  # 3: water
        ]
    )
)
def colorize_mask(mask, color_mapping=None):
    """Convert a 2D class-index mask into an RGB color image.

    Args:
        mask (np.ndarray | array-like): 2D array of class indices, shape (H, W).
        color_mapping (dict | None): Mapping of class index -> RGB triple.
            Defaults to the module-level ``COLOR_MAPPING`` when ``None``.

    Returns:
        np.ndarray: uint8 array of shape (H, W, 3). Pixels whose class index
        is not a key of the mapping are left black (0, 0, 0).

    Raises:
        ValueError: If ``mask`` is not 2-dimensional.
    """
    mask = np.asarray(mask)
    if mask.ndim != 2:
        raise ValueError(f"mask must be 2D (H, W), got shape {mask.shape}")
    if color_mapping is None:
        # Resolved at call time so callers that pass an explicit mapping do
        # not depend on the global.
        color_mapping = COLOR_MAPPING
    color_mask = np.zeros((*mask.shape, 3), dtype=np.uint8)
    for cls, color in color_mapping.items():
        color_mask[mask == cls] = color
    return color_mask
# Download the trained weights (a state dict) from the Hugging Face Hub;
# hf_hub_download returns the local path of the downloaded file.
model_path = hf_hub_download(
    repo_id="TheArchitect416/oil-spill-segmentation-model",
    filename="model.pth",
)

# Build the network. The architecture must match the one used in training,
# otherwise the (strict) load_state_dict below fails on mismatched keys.
# encoder_weights=None: the checkpoint overwrites every parameter, so
# fetching ImageNet weights first would only add a network download (and a
# possible startup failure) for weights that are immediately discarded.
model = smp.Unet(
    encoder_name="resnet34",  # Encoder used during training.
    encoder_weights=None,
    in_channels=3,  # RGB images.
    classes=NUM_CLASSES,  # Number of segmentation classes.
)

# Load the checkpoint onto the CPU. weights_only=True restricts unpickling to
# tensors and plain containers, which is all a state dict needs, instead of
# allowing arbitrary objects from the downloaded file. The state dict is not
# kept in a module-level name so its tensors can be freed after loading.
model.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=True))
# Switch dropout/batch-norm layers to inference behavior.
model.eval()
# Image preprocessing applied before inference. It is meant to mirror the
# training-time pipeline (assumed; not verifiable from this file):
#   1. resize to a fixed 256x256 input (aspect ratio is not preserved),
#   2. convert the PIL image to a float CHW tensor,
#   3. normalize each channel with the ImageNet mean/std.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)
# Inference entry point used by the Gradio interface.
def predict(image):
    """Segment an image and return a color-coded class mask.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without providing an image.

    Returns:
        np.ndarray | None: uint8 RGB image (H x W x 3) at the 256x256
        preprocessing resolution, with each pixel colored by its predicted
        class, or ``None`` when no image was provided.
    """
    if image is None:
        # Gradio passes None for an empty input component. Returning None
        # clears the output instead of crashing inside preprocessing.
        return None
    # The model was built with in_channels=3. RGBA, grayscale or palette
    # images would otherwise produce 4- or 1-channel tensors and fail in the
    # forward pass.
    image = image.convert("RGB")
    # Add a batch dimension: shape [1, 3, 256, 256].
    input_tensor = preprocess(image).unsqueeze(0)
    with torch.no_grad():
        logits = model(input_tensor)
    # Per-pixel argmax over the class dimension gives a 2D class-index mask.
    pred_mask = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy().astype(np.uint8)
    return colorize_mask(pred_mask)
# Web UI: a PIL image goes in, the color-coded class mask (a numpy array)
# comes out.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="numpy"),
    title="Oil Spill Segmentation",
    description="Segment oil spills in aerial images.",
)

# Executed at import time, outside the __main__ guard, so the version is
# printed whether the module is run directly or imported.
print("Gradio version:", gr.__version__)

if __name__ == "__main__":
    # Enable request queuing, then serve on all interfaces at port 7860.
    iface.queue()
    iface.launch(server_name="0.0.0.0", server_port=7860)