kimpepe committed (verified)
Commit 61a5fbb · 1 Parent(s): 3fe1f84

Update app.py

Files changed (1)
  1. app.py +54 -10
app.py CHANGED
@@ -1,26 +1,70 @@
 from ultralytics import YOLO
 from PIL import Image
 import gradio as gr
+import cv2
+import tempfile
 
 # Load YOLOv8 model
-model = YOLO("best.pt")  # Ensure best.pt is in the same directory
+model = YOLO("best.pt")  # Ensure best.pt is in the same directory or provide the correct path
 
-# Preprocess and run inference
-def predict(image):
+# Preprocess and run inference for images
+def predict_image(image):
     # Perform prediction
     results = model.predict(source=image, conf=0.5)
 
     # Annotate the image with bounding boxes
     annotated_image = results[0].plot()
 
     # Convert to PIL Image
     return Image.fromarray(annotated_image)
 
+# Preprocess and run inference for videos
+def predict_video(video):
+    # Save video to a temporary file
+    temp_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
+    with open(temp_video_path, "wb") as f:
+        f.write(video.read())
+
+    # Open the video file
+    cap = cv2.VideoCapture(temp_video_path)
+    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Codec for .mp4
+    output_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
+
+    # Get video properties
+    fps = cap.get(cv2.CAP_PROP_FPS)
+    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+    # Create video writer for output
+    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+
+    while cap.isOpened():
+        ret, frame = cap.read()
+        if not ret:
+            break  # Exit when video ends
+
+        # Perform predictions on the frame
+        results = model.predict(source=frame, conf=0.5)
+        annotated_frame = results[0].plot()  # Annotate frame
+
+        # Write the frame to the output video
+        out.write(annotated_frame)
+
+    # Release resources
+    cap.release()
+    out.release()
+
+    # Return the annotated video path
+    return output_path
+
 # Gradio interface
 gr.Interface(
-    fn=predict,
-    inputs=gr.Image(type="pil"),
-    outputs="image",
+    fn={"Image Detection": predict_image, "Video Detection": predict_video},
+    inputs=[
+        gr.Image(type="pil", label="Upload an Image"),
+        gr.Video(label="Upload a Video")
+    ],
+    outputs=["image", "video"],
     title="Hippo or Rhino Detection",
-    description="Upload an image for object detection with YOLOv8."
-).launch()
+    description="Upload an image or video for object detection using YOLOv8."
+).launch()
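
Review note: the committed code has a few issues worth flagging. gr.Interface expects fn to be a single callable, so passing a dict of functions will fail at launch; the usual way to expose two tasks is one gr.Interface per task combined with gr.TabbedInterface. Also, recent Gradio versions pass a gr.Video upload to the function as a filepath string, so video.read() would raise, and tempfile.NamedTemporaryFile(...).name without delete=False yields a path whose file can vanish once the handle is garbage-collected. A minimal corrected sketch under those assumptions (same best.pt weights, Gradio 4.x, ultralytics YOLOv8):

    import tempfile

    import cv2
    import gradio as gr
    from PIL import Image
    from ultralytics import YOLO

    # Load YOLOv8 model (assumes best.pt sits next to this script)
    model = YOLO("best.pt")

    def predict_image(image):
        results = model.predict(source=image, conf=0.5)
        # plot() returns a BGR array; reverse the channels for PIL
        return Image.fromarray(results[0].plot()[..., ::-1])

    def predict_video(video_path):
        # Gradio passes the upload as a filepath string, so OpenCV can
        # read it directly; no manual copy into a temp file is needed.
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS) or 30  # fall back if FPS is unreadable
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # delete=False keeps the file on disk after the handle closes
        output_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
        out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"),
                              fps, (width, height))

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break  # end of video
            results = model.predict(source=frame, conf=0.5)
            out.write(results[0].plot())  # plot() is BGR, as VideoWriter expects

        cap.release()
        out.release()
        return output_path

    # One Interface per task; gr.Interface's fn takes a single callable,
    # not a dict, so the two tasks are combined with TabbedInterface.
    image_tab = gr.Interface(
        fn=predict_image,
        inputs=gr.Image(type="pil", label="Upload an Image"),
        outputs="image",
        description="Upload an image for object detection using YOLOv8.",
    )
    video_tab = gr.Interface(
        fn=predict_video,
        inputs=gr.Video(label="Upload a Video"),
        outputs="video",
        description="Upload a video for object detection using YOLOv8.",
    )

    gr.TabbedInterface(
        [image_tab, video_tab],
        ["Image Detection", "Video Detection"],
        title="Hippo or Rhino Detection",
    ).launch()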
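
Two smaller caveats, assuming a typical Spaces deployment: mp4v-encoded output often will not play inline in browsers, so re-encoding the result to H.264 (for example with ffmpeg) may be needed before the video tab displays it; and per-frame model.predict(...) calls log progress for every frame by default, which passing verbose=False should silence.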