allutrifork commited on
Commit
eb9bd6c
Β·
1 Parent(s): 065086c

counting added

Browse files
Files changed (1) hide show
  1. app.py +35 -17
app.py CHANGED
@@ -1,12 +1,9 @@
1
  # app.py
2
  import gradio as gr
3
  from PIL import Image, ImageDraw, ImageFont
4
- import torch
5
  from ultralytics import YOLO
6
  import numpy as np
7
  import os
8
- from PIL import __version__ as PIL_VERSION
9
- print(f"Pillow version: {PIL_VERSION}")
10
 
11
  MODEL_PATH = "model/231220_detect_lr_0001_640_brightness.pt"
12
 
@@ -17,7 +14,8 @@ print("YOLO model loaded.")
17
 
18
  def detect_plastic_pellets(input_image, threshold=0.5):
19
  """
20
- Perform plastic pellet detection using our customized model.
 
21
  """
22
  if input_image is None:
23
  error_image = Image.new('RGB', (500, 100), color=(255, 0, 0))
@@ -27,7 +25,7 @@ def detect_plastic_pellets(input_image, threshold=0.5):
27
  except IOError:
28
  font = ImageFont.load_default()
29
  draw.text((10, 40), "Please upload a valid image.", fill=(255, 255, 255), font=font)
30
- return error_image
31
 
32
  try:
33
  print("Starting detection with threshold:", threshold)
@@ -42,6 +40,7 @@ def detect_plastic_pellets(input_image, threshold=0.5):
42
  font = ImageFont.load_default()
43
 
44
  detection_made = False
 
45
 
46
  for result in results:
47
  for box in result.boxes:
@@ -58,17 +57,19 @@ def detect_plastic_pellets(input_image, threshold=0.5):
58
 
59
  label = f"{name} {confidence:.2f}"
60
  text_width, text_height = font.getbbox(label)[2:]
61
- draw.rectangle(((x1, y1 - text_height), (x1 + text_width, y1)), fill=color)
62
- draw.text((x1, y1 - text_height), label, fill=(255, 255, 255), font=font)
 
 
63
 
64
  detection_made = True
 
65
 
66
  if not detection_made:
67
  draw.text((10, 10), "No plastic pellets detected.", fill=(255, 0, 0), font=font)
68
- return input_image
69
-
70
- print("Detection completed.")
71
- return input_image
72
 
73
  except Exception as e:
74
  print(f"Detection error: {str(e)}")
@@ -79,7 +80,7 @@ def detect_plastic_pellets(input_image, threshold=0.5):
79
  except IOError:
80
  font = ImageFont.load_default()
81
  draw.text((10, 40), f"Error: {str(e)}", fill=(255, 255, 255), font=font)
82
- return error_image
83
 
84
  def main():
85
  with gr.Blocks(css=".gradio-container {max-width: 800px}") as demo:
@@ -93,10 +94,17 @@ def main():
93
  with gr.Row():
94
  with gr.Column():
95
  input_image = gr.Image(type="pil", label="🌊 Upload or Select Beach Image", interactive=True)
96
- examples = ['images/image1.bmp', 'images/image2.bmp', 'images/image3.bmp', 'images/image4.bmp', 'images/image5.bmp', 'images/image6.bmp']
 
 
 
 
 
 
 
97
  gr.Examples(examples=examples, inputs=input_image, label="Or choose one of these images")
98
 
99
- # Add a slider for confidence threshold
100
  confidence_threshold = gr.Slider(
101
  minimum=0.0,
102
  maximum=1.0,
@@ -109,7 +117,17 @@ def main():
109
  submit_button = gr.Button("πŸ” Detect Plastic Pellets")
110
 
111
  with gr.Column():
112
- output_image = gr.Image(type="pil", label="βœ… Detection Result", interactive=False, show_download_button=True)
 
 
 
 
 
 
 
 
 
 
113
 
114
  gr.Markdown(
115
  """
@@ -121,7 +139,7 @@ def main():
121
  submit_button.click(
122
  fn=detect_plastic_pellets,
123
  inputs=[input_image, confidence_threshold],
124
- outputs=output_image,
125
  api_name="detect",
126
  show_progress=True
127
  )
@@ -129,4 +147,4 @@ def main():
129
  demo.launch()
130
 
131
  if __name__ == "__main__":
132
- main()
 
1
  # app.py
2
  import gradio as gr
3
  from PIL import Image, ImageDraw, ImageFont
 
4
  from ultralytics import YOLO
5
  import numpy as np
6
  import os
 
 
7
 
8
  MODEL_PATH = "model/231220_detect_lr_0001_640_brightness.pt"
9
 
 
14
 
15
  def detect_plastic_pellets(input_image, threshold=0.5):
16
  """
17
+ Perform plastic pellet detection using our customized YOLO model.
18
+ Returns the processed image and the number of detections.
19
  """
20
  if input_image is None:
21
  error_image = Image.new('RGB', (500, 100), color=(255, 0, 0))
 
25
  except IOError:
26
  font = ImageFont.load_default()
27
  draw.text((10, 40), "Please upload a valid image.", fill=(255, 255, 255), font=font)
28
+ return error_image, 0 # Returning 0 detections
29
 
30
  try:
31
  print("Starting detection with threshold:", threshold)
 
40
  font = ImageFont.load_default()
41
 
42
  detection_made = False
43
+ detection_count = 0 # Initialize detection count
44
 
45
  for result in results:
46
  for box in result.boxes:
 
57
 
58
  label = f"{name} {confidence:.2f}"
59
  text_width, text_height = font.getbbox(label)[2:]
60
+ # Ensure text does not go above the image
61
+ text_y = max(y1 - text_height, 0)
62
+ draw.rectangle(((x1, text_y), (x1 + text_width, y1)), fill=color)
63
+ draw.text((x1, text_y), label, fill=(255, 255, 255), font=font)
64
 
65
  detection_made = True
66
+ detection_count += 1 # Increment detection count
67
 
68
  if not detection_made:
69
  draw.text((10, 10), "No plastic pellets detected.", fill=(255, 0, 0), font=font)
70
+
71
+ print("Detection completed. Total detections:", detection_count)
72
+ return input_image, detection_count
 
73
 
74
  except Exception as e:
75
  print(f"Detection error: {str(e)}")
 
80
  except IOError:
81
  font = ImageFont.load_default()
82
  draw.text((10, 40), f"Error: {str(e)}", fill=(255, 255, 255), font=font)
83
+ return error_image, 0 # Returning 0 detections on error
84
 
85
  def main():
86
  with gr.Blocks(css=".gradio-container {max-width: 800px}") as demo:
 
94
  with gr.Row():
95
  with gr.Column():
96
  input_image = gr.Image(type="pil", label="🌊 Upload or Select Beach Image", interactive=True)
97
+ examples = [
98
+ 'images/image1.bmp',
99
+ 'images/image2.bmp',
100
+ 'images/image3.bmp',
101
+ 'images/image4.bmp',
102
+ 'images/image5.bmp',
103
+ 'images/image6.bmp'
104
+ ]
105
  gr.Examples(examples=examples, inputs=input_image, label="Or choose one of these images")
106
 
107
+ # Slider for confidence threshold
108
  confidence_threshold = gr.Slider(
109
  minimum=0.0,
110
  maximum=1.0,
 
117
  submit_button = gr.Button("πŸ” Detect Plastic Pellets")
118
 
119
  with gr.Column():
120
+ output_image = gr.Image(
121
+ type="pil",
122
+ label="βœ… Detection Result",
123
+ interactive=False,
124
+ show_download_button=True
125
+ )
126
+ detection_count = gr.Text(
127
+ value="Detections: 0",
128
+ label="πŸ”’ Number of Detections",
129
+ interactive=False
130
+ )
131
 
132
  gr.Markdown(
133
  """
 
139
  submit_button.click(
140
  fn=detect_plastic_pellets,
141
  inputs=[input_image, confidence_threshold],
142
+ outputs=[output_image, detection_count],
143
  api_name="detect",
144
  show_progress=True
145
  )
 
147
  demo.launch()
148
 
149
  if __name__ == "__main__":
150
+ main()