qiqiyuan committed
Commit 3701ecb · verified · 1 Parent(s): 178ef7b

Create app.py

Files changed (1)
  1. app.py +89 -0
app.py ADDED
@@ -0,0 +1,89 @@
+ from ultralytics import YOLO
+ from PIL import Image
+ import gradio as gr
+ from huggingface_hub import snapshot_download
+ import os
+ import cv2
+ import tempfile  # used by predict_video to stage the annotated output file
+
+ def load_model(repo_id, model_filename="best_int8_openvino_model"):
+     """
+     Loads a YOLO model from Hugging Face Hub.
+
+     Args:
+         repo_id: The ID of the Hugging Face Hub repository.
+         model_filename: The filename of the YOLO model within the repository.
+
+     Returns:
+         The loaded YOLO model.
+     """
+     download_dir = snapshot_download(repo_id)
+     model_path = os.path.join(download_dir, model_filename)
+     detection_model = YOLO(model_path, task='detect')
+     return detection_model
+
+ def predict_image(pilimg, conf_thresh, iou_thresh):
+     """
+     Performs object detection on the input image.
+
+     Args:
+         pilimg: The input image as a PIL Image object.
+         conf_thresh: The confidence threshold for object detection.
+         iou_thresh: The IoU threshold for non-maximum suppression.
+
+     Returns:
+         The processed image with detected objects highlighted.
+     """
+     source = pilimg
+     result = detection_model(source, conf=conf_thresh, iou=iou_thresh)
+     img_bgr = result[0].plot()
+     out_pilimg = Image.fromarray(img_bgr[..., ::-1])  # Convert BGR to RGB
+     return out_pilimg
+
+ def predict_video(video_path, conf_thresh, iou_thresh):
+     """
+     Performs object detection on a video.
+
+     Args:
+         video_path: Path to the video file.
+         conf_thresh: The confidence threshold for object detection.
+         iou_thresh: The IoU threshold for non-maximum suppression.
+
+     Returns:
+         The path to the annotated output video.
+     """
+     cap = cv2.VideoCapture(video_path)
+     fps = cap.get(cv2.CAP_PROP_FPS) or 25
+     width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+     # Write annotated frames to a temporary mp4 so the Gradio Video output can play it.
+     out_path = os.path.join(tempfile.mkdtemp(), "annotated.mp4")
+     writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
+
+     while True:
+         ret, frame = cap.read()
+         if not ret:
+             break
+         # OpenCV frames are BGR; Ultralytics accepts BGR numpy arrays directly.
+         result = detection_model(frame, conf=conf_thresh, iou=iou_thresh)
+         writer.write(result[0].plot())  # plot() returns an annotated BGR array
+
+     cap.release()
+     writer.release()
+     return out_path
+
+ REPO_ID = "qiqiyuan/glasses_and_mouth"
+ detection_model = load_model(REPO_ID)
+
+ image_iface = gr.Interface(
+     fn=predict_image,
+     inputs=[
+         gr.Image(type="pil", label="Image"),
+         gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="Confidence Threshold"),
+         gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="IoU Threshold"),
+     ],
+     outputs=gr.Image(type="pil", label="Image Output"),
+     description="Upload an image to detect glasses and mouth.",
+     examples=[["examples/image1.jpg", 0.5, 0.5]]  # Add example images
+ )
+
+ video_iface = gr.Interface(
+     fn=predict_video,
+     inputs=[
+         gr.Video(label="Video"),
+         gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="Confidence Threshold"),
+         gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="IoU Threshold"),
+     ],
+     outputs=gr.Video(label="Video Output"),
+     description="Upload a video to detect glasses and mouth.",
+     examples=[["examples/video1.mp4", 0.5, 0.5]]  # Add example videos
+ )
+
+ # gr.Interface takes a single function, so the image and video predictors get one tab each.
+ iface = gr.TabbedInterface(
+     [image_iface, video_iface],
+     ["Image", "Video"],
+     title="Object Detection with YOLO"
+ )
+
+ iface.launch(share=True)