Spaces:
Sleeping
Sleeping
Zell
commited on
Commit
·
1991a7d
1
Parent(s):
a77a1a7
Beautify the UI with tabs
Browse files
app.py
CHANGED
@@ -3,34 +3,77 @@ from PIL import Image
|
|
3 |
import gradio as gr
|
4 |
from huggingface_hub import snapshot_download
|
5 |
import os
|
|
|
|
|
6 |
|
7 |
|
8 |
def load_model(repo_id):
|
9 |
download_dir = snapshot_download(repo_id)
|
10 |
print(download_dir)
|
11 |
-
path
|
12 |
print(path)
|
13 |
detection_model = YOLO(path, task='detect')
|
14 |
return detection_model
|
15 |
|
16 |
|
17 |
-
def
|
18 |
-
|
19 |
-
source = pilimg
|
20 |
-
# x = np.asarray(pilimg)
|
21 |
-
# print(x.shape)
|
22 |
-
result = detection_model.predict(source, conf=0.5, iou=0.6)
|
23 |
img_bgr = result[0].plot()
|
24 |
out_pilimg = Image.fromarray(img_bgr[..., ::-1]) # RGB-order PIL image
|
25 |
-
|
26 |
return out_pilimg
|
27 |
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
REPO_ID = "zell-dev/fire-smoke-detection"
|
30 |
detection_model = load_model(REPO_ID)
|
31 |
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
|
|
|
3 |
import gradio as gr
|
4 |
from huggingface_hub import snapshot_download
|
5 |
import os
|
6 |
+
import cv2
|
7 |
+
import tempfile
|
8 |
|
9 |
|
10 |
def load_model(repo_id):
|
11 |
download_dir = snapshot_download(repo_id)
|
12 |
print(download_dir)
|
13 |
+
path = os.path.join(download_dir, "best_int8_openvino_model")
|
14 |
print(path)
|
15 |
detection_model = YOLO(path, task='detect')
|
16 |
return detection_model
|
17 |
|
18 |
|
19 |
+
def predict_image(pilimg):
|
20 |
+
result = detection_model.predict(pilimg, conf=0.5, iou=0.6)
|
|
|
|
|
|
|
|
|
21 |
img_bgr = result[0].plot()
|
22 |
out_pilimg = Image.fromarray(img_bgr[..., ::-1]) # RGB-order PIL image
|
|
|
23 |
return out_pilimg
|
24 |
|
25 |
|
26 |
+
def process_video(video_file):
|
27 |
+
cap = cv2.VideoCapture(video_file)
|
28 |
+
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
29 |
+
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
30 |
+
fps = cap.get(cv2.CAP_PROP_FPS)
|
31 |
+
|
32 |
+
# Use a temporary file to store the annotated video
|
33 |
+
temp_video = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
|
34 |
+
temp_video_path = temp_video.name
|
35 |
+
writer = cv2.VideoWriter(temp_video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))
|
36 |
+
|
37 |
+
while cap.isOpened():
|
38 |
+
ret, frame = cap.read()
|
39 |
+
if not ret:
|
40 |
+
break
|
41 |
+
# Perform object detection
|
42 |
+
results = detection_model.predict(frame, conf=0.5, iou=0.6)
|
43 |
+
annotated_frame = results[0].plot()
|
44 |
+
writer.write(annotated_frame)
|
45 |
+
|
46 |
+
cap.release()
|
47 |
+
writer.release()
|
48 |
+
|
49 |
+
return temp_video_path
|
50 |
+
|
51 |
REPO_ID = "zell-dev/fire-smoke-detection"
|
52 |
detection_model = load_model(REPO_ID)
|
53 |
|
54 |
+
# Improved UI with image and video upload
|
55 |
+
with gr.Blocks() as app:
|
56 |
+
gr.Markdown("# 🔥 Fire and Smoke Detection App - 2415336E")
|
57 |
+
gr.Markdown("Upload an image or a video to detect fire or smoke using YOLO model.")
|
58 |
+
|
59 |
+
with gr.Tabs():
|
60 |
+
with gr.Tab("Image Detection"):
|
61 |
+
gr.Markdown("### Upload an Image")
|
62 |
+
img_input = gr.Image(type="pil", label="Input Image")
|
63 |
+
img_output = gr.Image(type="pil", label="Detection Output")
|
64 |
+
img_button = gr.Button("Detect Fire/Smoke")
|
65 |
+
img_button.click(predict_image, inputs=img_input, outputs=img_output)
|
66 |
+
|
67 |
+
with gr.Tab("Video Detection"):
|
68 |
+
gr.Markdown("### Upload a video to detect fire and smoke. You can download the processed video after detection.")
|
69 |
+
video_input = gr.Video(label="Upload Video")
|
70 |
+
video_output = gr.File(label="Download Processed Video") # File download component
|
71 |
+
process_button = gr.Button("Process Video")
|
72 |
+
|
73 |
+
process_button.click(process_video, inputs=video_input, outputs=video_output)
|
74 |
+
|
75 |
+
|
76 |
+
|
77 |
+
gr.Markdown("Developed with 💻 by Zell (Feng Long)")
|
78 |
|
79 |
+
app.launch(share=True)
|