# segmentation / gradio_app.py
# Last commit by Alex: "update endpoint path" (f8b9c38), 1.84 kB
import gradio as gr
import requests
from PIL import Image
import io
import base64
import logging
from app import ModelManager
# Configure the root logger at INFO level. basicConfig is a no-op when the
# root logger already has handlers, so an embedding app's setup is respected.
logging.basicConfig(level="INFO")
# Logger named after this module; used for error reporting in this file.
logger = logging.getLogger(name=__name__)
def _download_image(url: str, timeout: float = 30.0):
    """Download ``url`` and return it as an RGB ``PIL.Image.Image``.

    Args:
        url: HTTP(S) address of the image.
        timeout: Seconds to wait for the connection and for each read before
            giving up, so a stalled host cannot hang the UI worker forever.

    Raises:
        ValueError: If the server answers with a status other than 200.
        requests.RequestException: On network errors or timeouts.
        PIL.UnidentifiedImageError: If the body is not a readable image.
    """
    # Read the body via ``response.content`` rather than streaming
    # ``response.raw``: ``content`` transparently undoes gzip/deflate content
    # encoding, whereas ``raw`` would hand PIL the still-compressed bytes.
    # The ``with`` block guarantees the connection is released.
    with requests.get(url, timeout=timeout) as response:
        if response.status_code != 200:
            raise ValueError(
                f"Could not download image from URL (HTTP {response.status_code})"
            )
        # convert() forces PIL to fully decode while the buffer is alive.
        return Image.open(io.BytesIO(response.content)).convert("RGB")


def _decode_base64_image(encoded: str):
    """Decode a base64-encoded image string into a ``PIL.Image.Image``.

    Accepts either a bare base64 payload or a data URL such as
    ``"data:image/png;base64,<payload>"``; everything up to the first comma
    is treated as the data-URL header and discarded.
    """
    header, sep, payload = encoded.partition(",")
    # No comma means there is no data-URL header: the whole string is payload.
    return Image.open(io.BytesIO(base64.b64decode(payload if sep else header)))


def process_image(url: str):
    """Segment clothing in the image found at ``url``.

    Args:
        url: Image address typed into the Gradio textbox.

    Returns:
        ``(original_image, mask_image, info_text)``. On any failure both
        images are ``None`` and ``info_text`` starts with ``"Error: "`` so
        the UI shows the problem instead of raising.
    """
    url = (url or "").strip()
    if not url:
        # Reject empty input up front, before touching the model manager.
        return None, None, "Error: Please enter an image URL"
    try:
        # Initialize model manager (will load models if not already loaded)
        model_manager = ModelManager()
        image = _download_image(url)
        result = model_manager.process_clothes_image(image)
        # NOTE(review): assumes result["mask"] is a base64 image string
        # (optionally data-URL prefixed) and result["size"] is printable —
        # confirm against ModelManager.process_clothes_image.
        mask_image = _decode_base64_image(result["mask"])
        return image, mask_image, f"Processed image size: {result['size']}"
    except Exception as e:  # UI boundary: surface every failure to the user
        # exception() records the traceback, not just the message.
        logger.exception("Error processing image: %s", e)
        return None, None, f"Error: {str(e)}"
def _build_interface():
    """Assemble the Gradio UI: one URL textbox in; original image, mask and
    a status line out, all driven by ``process_image``."""
    url_box = gr.Textbox(label="Image URL", placeholder="Enter the URL of the image")
    result_components = [
        gr.Image(label="Original Image"),
        gr.Image(label="Segmentation Mask"),
        gr.Textbox(label="Processing Info"),
    ]
    # NOTE(review): these look like placeholder URLs (example.com) that will
    # not return an image; replace with real sample images — TODO confirm.
    sample_urls = [
        "https://example.com/path/to/clothing/image.jpg",
        "https://another-example.com/fashion/photo.jpg",
    ]
    return gr.Interface(
        fn=process_image,
        inputs=url_box,
        outputs=result_components,
        title="Clothes Segmentation",
        description="Enter an image URL to generate a segmentation mask for clothing items.",
        examples=[[sample] for sample in sample_urls],
        allow_flagging="never",
    )


# Create Gradio interface
iface = _build_interface()
def main():
    """Serve the Gradio UI on port 7861, a different port than the FastAPI app."""
    iface.launch(server_port=7861)


if __name__ == "__main__":
    main()