# plastic-pellet / app.py
# Author: allutrifork
# Commit eb9bd6c: counting added
# app.py
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO
import numpy as np
import os
MODEL_PATH = "model/231220_detect_lr_0001_640_brightness.pt"

# Fail fast at import time: the app cannot run without its trained weights,
# so a missing file is reported before the UI is ever built.
if os.path.exists(MODEL_PATH):
    model = YOLO(MODEL_PATH)
    print("YOLO model loaded.")
else:
    raise FileNotFoundError(f"YOLO model not found at '{MODEL_PATH}'.")
def _load_font(size=15):
    """Return Arial at *size*, or Pillow's built-in default font if Arial cannot be loaded."""
    try:
        return ImageFont.truetype("arial.ttf", size=size)
    except IOError:
        # arial.ttf is often missing (e.g. on Linux hosts); degrade instead of failing.
        return ImageFont.load_default()


def _error_image(message):
    """Return a 500x100 red placeholder image with *message* written in white."""
    error_image = Image.new('RGB', (500, 100), color=(255, 0, 0))
    draw = ImageDraw.Draw(error_image)
    draw.text((10, 40), message, fill=(255, 255, 255), font=_load_font())
    return error_image


def detect_plastic_pellets(input_image, threshold=0.5):
    """
    Perform plastic pellet detection using our customized YOLO model.

    Args:
        input_image: PIL image to analyse, or None when nothing was uploaded.
            The caller's image is left untouched; all work happens on a copy.
        threshold: Minimum confidence (0-1) for a detection to be drawn and
            counted. It is also forwarded to the model, so values below the
            model's own default cut-off take effect.

    Returns:
        An ``(image, count)`` tuple: the annotated RGB image (downscaled to fit
        1024x1024) and the number of detections at or above *threshold*.
        On missing input or any processing error, a red error image and 0.
    """
    if input_image is None:
        return _error_image("Please upload a valid image."), 0  # Returning 0 detections
    try:
        print("Starting detection with threshold:", threshold)
        # convert() always returns a new image, so the caller's image is never
        # resized or drawn on. Converting before resizing also lets palette
        # images resample with LANCZOS, and guarantees the RGB colour tuples
        # used for drawing below are valid for this image's mode.
        image = input_image.convert("RGB")
        image.thumbnail((1024, 1024), Image.LANCZOS)

        # Pass the PIL image itself: Ultralytics treats PIL input as RGB but a
        # numpy array as BGR, so an RGB array would reach the model with its
        # red and blue channels swapped. Forwarding conf=threshold lets slider
        # values below the model's default confidence cut-off surface boxes.
        results = model(image, conf=threshold)

        draw = ImageDraw.Draw(image)
        font = _load_font()
        color = (255, 0, 0)
        detection_count = 0
        for result in results:
            for box in result.boxes:
                confidence = box.conf[0].item()
                if confidence < threshold:
                    continue
                x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
                cls = int(box.cls[0].item())
                name = model.names[cls] if model.names else "Object"
                draw.rectangle(((x1, y1), (x2, y2)), outline=color, width=2)

                label = f"{name} {confidence:.2f}"
                text_width, text_height = font.getbbox(label)[2:]
                # Put the label just above the box, clamped so it never goes
                # above the image. The background is sized from text_y so it
                # still covers the whole label when clamped at the top edge.
                text_y = max(y1 - text_height, 0)
                draw.rectangle(
                    ((x1, text_y), (x1 + text_width, text_y + text_height)),
                    fill=color,
                )
                draw.text((x1, text_y), label, fill=(255, 255, 255), font=font)
                detection_count += 1

        if detection_count == 0:
            draw.text((10, 10), "No plastic pellets detected.", fill=(255, 0, 0), font=font)
        print("Detection completed. Total detections:", detection_count)
        return image, detection_count
    except Exception as e:
        # UI boundary: surface the failure in the output image rather than
        # crashing the request handler.
        print(f"Detection error: {str(e)}")
        return _error_image(f"Error: {str(e)}"), 0  # Returning 0 detections on error
def _build_demo():
    """Build and return the Gradio Blocks UI for pellet detection without launching it."""
    with gr.Blocks(css=".gradio-container {max-width: 800px}") as demo:
        gr.Markdown(
            """
            <h1 align="center">🌊 Beach Plastic Pellet Detection Challenge</h1>
            <p align="center">Help us clean up beaches from plastic pellets! Upload your beach photos or choose from our samples, and contribute to data collection for a cleaner environment.</p>
            """
        )
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="🌊 Upload or Select Beach Image", interactive=True)
                examples = [
                    'images/image1.bmp',
                    'images/image2.bmp',
                    'images/image3.bmp',
                    'images/image4.bmp',
                    'images/image5.bmp',
                    'images/image6.bmp',
                ]
                gr.Examples(examples=examples, inputs=input_image, label="Or choose one of these images")
                # Slider for the confidence threshold passed to detect_plastic_pellets.
                confidence_threshold = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.5,
                    step=0.05,
                    label="Confidence Threshold",
                    info="Adjust the confidence threshold for displaying detections.",
                )
                # Emoji restored from mojibake (UTF-8 bytes mis-decoded as Windows-125x).
                submit_button = gr.Button("🔍 Detect Plastic Pellets")
            with gr.Column():
                output_image = gr.Image(
                    type="pil",
                    label="✅ Detection Result",
                    interactive=False,
                    show_download_button=True,
                )
                # detect_plastic_pellets returns a bare integer count, so the
                # initial placeholder uses the same format; the label gives context.
                count_output = gr.Text(
                    value="0",
                    label="🔒 Number of Detections",
                    interactive=False,
                )
        gr.Markdown(
            """
            ---
            <p align="center">© 2024 Beach Clean-Up Initiative.</p>
            """
        )
        submit_button.click(
            fn=detect_plastic_pellets,
            inputs=[input_image, confidence_threshold],
            outputs=[output_image, count_output],
            api_name="detect",
            show_progress=True,
        )
    return demo


def main():
    """Build the Gradio interface and launch the web server."""
    _build_demo().launch()
if __name__ == "__main__":
    # Launch the app only when run as a script, never as a side effect of import.
    main()