File size: 5,570 Bytes
4ad8062
10eac88
14fcecf
acbbd8b
 
 
2f9a445
b1b2c74
6f6962c
acbbd8b
b1b2c74
acbbd8b
b1b2c74
acbbd8b
189d865
da77435
eb9bd6c
 
da77435
 
 
 
5948fc5
 
 
 
b1b2c74
eb9bd6c
da77435
5948fc5
9557a09
b1b2c74
5948fc5
 
 
 
 
 
 
 
 
 
eb9bd6c
5948fc5
 
b1b2c74
5948fc5
189d865
065086c
6f6962c
 
5948fc5
 
 
b1b2c74
 
5948fc5
 
b1b2c74
eb9bd6c
 
 
 
5948fc5
 
eb9bd6c
5948fc5
9557a09
b1b2c74
eb9bd6c
 
 
5948fc5
 
 
 
 
 
 
 
 
b1b2c74
eb9bd6c
84f2317
 
c0bf37a
84f2317
 
5948fc5
 
84f2317
 
50d49d8
84f2317
 
b1b2c74
eb9bd6c
 
 
 
 
 
 
 
b1b2c74
6f6962c
eb9bd6c
189d865
6f6962c
 
189d865
6f6962c
189d865
6f6962c
 
 
c0bf37a
50d49d8
84f2317
eb9bd6c
 
 
 
 
 
 
 
 
 
 
b1b2c74
5948fc5
 
 
b1b2c74
5948fc5
c0bf37a
50d49d8
c0bf37a
50d49d8
189d865
eb9bd6c
5948fc5
 
c0bf37a
50d49d8
2f9a445
84f2317
 
eb9bd6c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# app.py
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO
import numpy as np
import os

# Path to the custom-trained YOLO weights, relative to the working directory.
MODEL_PATH = "model/231220_detect_lr_0001_640_brightness.pt"

# Load the detector once at import time; abort start-up with a clear error
# when the weights file is missing.
if os.path.exists(MODEL_PATH):
    model = YOLO(MODEL_PATH)
    print("YOLO model loaded.")
else:
    raise FileNotFoundError(f"YOLO model not found at '{MODEL_PATH}'.")

def _load_label_font(size=15):
    """Return Arial at *size* points, or Pillow's built-in font if Arial is unavailable."""
    try:
        return ImageFont.truetype("arial.ttf", size=size)
    except IOError:
        return ImageFont.load_default()


def _render_error_image(message):
    """Return a 500x100 red banner image with *message* written on it in white."""
    banner = Image.new('RGB', (500, 100), color=(255, 0, 0))
    ImageDraw.Draw(banner).text(
        (10, 40), message, fill=(255, 255, 255), font=_load_label_font()
    )
    return banner


def detect_plastic_pellets(input_image, threshold=0.5):
    """
    Perform plastic pellet detection using our customized YOLO model.

    Args:
        input_image: PIL image to analyse, or None when nothing was uploaded.
            The image is not modified; annotations are drawn on an RGB copy.
        threshold: Minimum confidence a detection needs in order to be drawn
            and counted.

    Returns:
        A tuple ``(image, count)``: the annotated RGB image (downscaled to fit
        within 1024x1024) and the number of detections drawn. When the input
        is missing or detection fails, ``image`` is a red banner carrying the
        error message and ``count`` is 0.
    """
    if input_image is None:
        return _render_error_image("Please upload a valid image."), 0

    try:
        print("Starting detection with threshold:", threshold)

        # Work on an RGB copy: the caller's image stays untouched, and the RGB
        # colour tuples used below are valid whatever the upload's mode was
        # (greyscale, palette, RGBA, ...).
        annotated = input_image.convert("RGB")
        annotated.thumbnail((1024, 1024), Image.LANCZOS)

        # Hand the PIL image to the model directly. Ultralytics treats NumPy
        # arrays as BGR, so passing an RGB array would swap the red and blue
        # channels. The (clamped) threshold is forwarded as `conf` because the
        # model's own default floor of 0.25 would otherwise silently hide
        # detections for any lower slider value.
        model_conf = min(max(float(threshold), 0.0), 1.0)
        results = model(annotated, conf=model_conf)

        draw = ImageDraw.Draw(annotated)
        font = _load_label_font()
        box_color = (255, 0, 0)
        detection_count = 0

        for result in results:
            for box in result.boxes:
                confidence = box.conf[0].item()
                # Kept as a safety net in case the model returns boxes below
                # the requested threshold.
                if confidence < threshold:
                    continue

                x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
                cls = int(box.cls[0].item())
                name = model.names[cls] if model.names else "Object"

                draw.rectangle(((x1, y1), (x2, y2)), outline=box_color, width=2)

                label = f"{name} {confidence:.2f}"
                text_width, text_height = font.getbbox(label)[2:]
                # Place the label above the box, but never above the image.
                text_y = max(y1 - text_height, 0)
                # Size the background from the text height so it still covers
                # the label when the box touches the top edge.
                draw.rectangle(
                    ((x1, text_y), (x1 + text_width, text_y + text_height)),
                    fill=box_color,
                )
                draw.text((x1, text_y), label, fill=(255, 255, 255), font=font)

                detection_count += 1

        if detection_count == 0:
            draw.text((10, 10), "No plastic pellets detected.", fill=(255, 0, 0), font=font)

        print("Detection completed. Total detections:", detection_count)
        return annotated, detection_count

    except Exception as e:
        # UI boundary: report any failure to the user as an image instead of
        # letting the request crash.
        print(f"Detection error: {e}")
        return _render_error_image(f"Error: {e}"), 0

def main():
    """Build the Gradio interface for pellet detection and launch it."""
    with gr.Blocks(css=".gradio-container {max-width: 800px}") as demo:
        gr.Markdown(
            """
            <h1 align="center">🌊 Beach Plastic Pellet Detection Challenge</h1>
            <p align="center">Help us clean up beaches from plastic pellets! Upload your beach photos or choose from our samples, and contribute to data collection for a cleaner environment.</p>
            """
        )

        with gr.Row():
            # Left column: image input, sample images, threshold and trigger.
            with gr.Column():
                input_image = gr.Image(type="pil", label="🌊 Upload or Select Beach Image", interactive=True)
                examples = [
                    'images/image1.bmp',
                    'images/image2.bmp',
                    'images/image3.bmp',
                    'images/image4.bmp',
                    'images/image5.bmp',
                    'images/image6.bmp'
                ]
                gr.Examples(examples=examples, inputs=input_image, label="Or choose one of these images")

                # Slider for confidence threshold
                confidence_threshold = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.5,
                    step=0.05,
                    label="Confidence Threshold",
                    info="Adjust the confidence threshold for displaying detections."
                )

                submit_button = gr.Button("🔍 Detect Plastic Pellets")

            # Right column: annotated image and detection count.
            with gr.Column():
                output_image = gr.Image(
                    type="pil",
                    label="✅ Detection Result",
                    interactive=False,
                    show_download_button=True
                )
                detection_count = gr.Text(
                    value="Detections: 0",
                    label="🔢 Number of Detections",
                    interactive=False
                )

        gr.Markdown(
            """
            ---
            <p align="center">© 2024 Beach Clean-Up Initiative.</p>
            """
        )

        # Run detection on click; the same handler is exposed as the "detect"
        # API endpoint.
        submit_button.click(
            fn=detect_plastic_pellets,
            inputs=[input_image, confidence_threshold],
            outputs=[output_image, detection_count],
            api_name="detect",
            show_progress=True
        )

    demo.launch()

# Build and launch the app only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()