# oil-spill-api / app.py
# Author: TheArchitect416 — last commit 0ba9aa3 "Update app.py" (verified)
import torch
import gradio as gr
from torchvision import transforms
from huggingface_hub import hf_hub_download
import segmentation_models_pytorch as smp
import numpy as np
# Number of segmentation classes the model outputs (4, per the project's
# label_colors.txt).
NUM_CLASSES = 4

# Class index -> RGB color used when rendering a predicted mask.
# Per the original labeling comments: 0 = background (black),
# 1 = oil (255, 0, 124), 2 = others (255, 204, 51), 3 = water (51, 221, 255).
COLOR_MAPPING = dict(
    enumerate(
        [
            [0, 0, 0],
            [255, 0, 124],
            [255, 204, 51],
            [51, 221, 255],
        ]
    )
)
def colorize_mask(mask, color_mapping=None):
    """
    Convert a mask of class indices into an RGB color image.

    Args:
        mask (array-like): Integer class indices. Typically 2D (H x W), but any
            shape is accepted; the color channel is appended as a trailing axis.
        color_mapping (dict, optional): Mapping from class index to an RGB
            triple. Defaults to the module-level COLOR_MAPPING.

    Returns:
        np.ndarray: uint8 array of shape ``mask.shape + (3,)``. Pixels whose
        class index is not a key of the mapping stay black (0, 0, 0).
    """
    # Resolve the default lazily so callers can pass their own palette.
    if color_mapping is None:
        color_mapping = COLOR_MAPPING
    mask = np.asarray(mask)
    # Shape-agnostic output: works for 2D masks as before, and also for
    # batched / higher-rank masks (the original `h, w = mask.shape` did not).
    color_mask = np.zeros(mask.shape + (3,), dtype=np.uint8)
    for cls, color in color_mapping.items():
        color_mask[mask == cls] = color
    return color_mask
# Download the trained model's state dict from the Hugging Face Hub.
model_path = hf_hub_download(
    repo_id="TheArchitect416/oil-spill-segmentation-model",
    filename="model.pth",
)

# Build the network. The architecture (U-Net with a resnet34 encoder, 3 input
# channels, NUM_CLASSES output channels) must match the one used in training,
# otherwise the strict load_state_dict call below raises.
# encoder_weights=None: every parameter is overwritten by the trained state
# dict right afterwards, so fetching ImageNet weights here would only add a
# redundant download at startup.
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,
    in_channels=3,
    classes=NUM_CLASSES,
)

# Load the trained weights as-is (no key remapping is performed).
# weights_only=True limits unpickling to tensors and plain containers, which is
# all a state dict needs, instead of executing arbitrary pickled objects from
# the downloaded file.
model.load_state_dict(
    torch.load(model_path, map_location="cpu", weights_only=True)
)
# Evaluation mode: disables dropout and makes BatchNorm use running statistics.
model.eval()
# Input preprocessing; this should mirror the transforms used during training.
# Steps: resize to the 256x256 resolution the model is run at, convert the PIL
# image to a float tensor, then normalize each channel with the ImageNet
# mean/std statistics.
preprocess = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)
# Inference function wired into the Gradio interface.
def predict(image, *, match_input_size=False):
    """
    Segment an image and return the predicted mask rendered in color.

    Args:
        image (PIL.Image.Image | None): Image to segment. Any PIL mode is
            accepted; it is converted to RGB first so the 3-channel
            normalization in ``preprocess`` applies. ``None`` (e.g. the Gradio
            form submitted with no image) returns ``None``.
        match_input_size (bool): If True, upsample the predicted class mask
            (nearest neighbour) to the input image's width and height so it
            lines up with the original image. Defaults to False, which returns
            the mask at the preprocessing resolution (256x256).

    Returns:
        np.ndarray | None: uint8 color image of shape (H, W, 3), or ``None``
        when ``image`` is None.
    """
    if image is None:
        return None
    # Grayscale / RGBA / palette images would not match the 3 per-channel
    # mean/std values used by Normalize; force 3-channel RGB.
    image = image.convert("RGB")
    width, height = image.size  # PIL reports (width, height)

    input_tensor = preprocess(image).unsqueeze(0)  # shape: [1, 3, 256, 256]
    with torch.no_grad():
        output = model(input_tensor)
    # Most likely class per pixel (class scores are on dim 1) -> 2D index mask.
    pred_mask = torch.argmax(output, dim=1).squeeze(0).cpu().numpy().astype(np.uint8)

    if match_input_size:
        # Nearest-neighbour upsampling on the uint8 label map (rather than the
        # float scores) keeps memory at one byte per output pixel.
        mask_h, mask_w = pred_mask.shape
        rows = np.arange(height) * mask_h // height
        cols = np.arange(width) * mask_w // width
        pred_mask = pred_mask[rows[:, None], cols[None, :]]

    return colorize_mask(pred_mask)
# Gradio UI: a PIL image goes in, and the colorized segmentation mask returned
# by predict (a numpy array) is displayed.
iface = gr.Interface(
    title="Oil Spill Segmentation",
    description="Segment oil spills in aerial images.",
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="numpy"),
)
# Report the installed Gradio version at startup (presumably to help debug
# version-specific behavior).
print("Gradio version:", gr.__version__)
# Start the web server only when this file is executed directly.
if __name__ == "__main__":
    # Enable request queuing, then listen on all interfaces at port 7860.
    iface.queue().launch(server_name="0.0.0.0", server_port=7860)