Zell commited on
Commit
1991a7d
·
1 Parent(s): a77a1a7

Beautify the UI with tabs

Browse files
Files changed (1) hide show
  1. app.py +55 -12
app.py CHANGED
@@ -3,34 +3,77 @@ from PIL import Image
3
  import gradio as gr
4
  from huggingface_hub import snapshot_download
5
  import os
 
 
6
 
7
 
8
  def load_model(repo_id):
9
  download_dir = snapshot_download(repo_id)
10
  print(download_dir)
11
- path = os.path.join(download_dir, "best_int8_openvino_model")
12
  print(path)
13
  detection_model = YOLO(path, task='detect')
14
  return detection_model
15
 
16
 
17
- def predict(pilimg):
18
-
19
- source = pilimg
20
- # x = np.asarray(pilimg)
21
- # print(x.shape)
22
- result = detection_model.predict(source, conf=0.5, iou=0.6)
23
  img_bgr = result[0].plot()
24
  out_pilimg = Image.fromarray(img_bgr[..., ::-1]) # RGB-order PIL image
25
-
26
  return out_pilimg
27
 
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  REPO_ID = "zell-dev/fire-smoke-detection"
30
  detection_model = load_model(REPO_ID)
31
 
32
- gr.Interface(fn=predict,
33
- inputs=gr.Image(type="pil"),
34
- outputs=gr.Image(type="pil")
35
- ).launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
 
 
3
  import gradio as gr
4
  from huggingface_hub import snapshot_download
5
  import os
6
+ import cv2
7
+ import tempfile
8
 
9
 
10
  def load_model(repo_id):
11
  download_dir = snapshot_download(repo_id)
12
  print(download_dir)
13
+ path = os.path.join(download_dir, "best_int8_openvino_model")
14
  print(path)
15
  detection_model = YOLO(path, task='detect')
16
  return detection_model
17
 
18
 
19
+ def predict_image(pilimg):
20
+ result = detection_model.predict(pilimg, conf=0.5, iou=0.6)
 
 
 
 
21
  img_bgr = result[0].plot()
22
  out_pilimg = Image.fromarray(img_bgr[..., ::-1]) # RGB-order PIL image
 
23
  return out_pilimg
24
 
25
 
26
+ def process_video(video_file):
27
+ cap = cv2.VideoCapture(video_file)
28
+ frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
29
+ frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
30
+ fps = cap.get(cv2.CAP_PROP_FPS)
31
+
32
+ # Use a temporary file to store the annotated video
33
+ temp_video = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
34
+ temp_video_path = temp_video.name
35
+ writer = cv2.VideoWriter(temp_video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))
36
+
37
+ while cap.isOpened():
38
+ ret, frame = cap.read()
39
+ if not ret:
40
+ break
41
+ # Perform object detection
42
+ results = detection_model.predict(frame, conf=0.5, iou=0.6)
43
+ annotated_frame = results[0].plot()
44
+ writer.write(annotated_frame)
45
+
46
+ cap.release()
47
+ writer.release()
48
+
49
+ return temp_video_path
50
+
51
  REPO_ID = "zell-dev/fire-smoke-detection"
52
  detection_model = load_model(REPO_ID)
53
 
54
+ # Improved UI with image and video upload
55
+ with gr.Blocks() as app:
56
+ gr.Markdown("# 🔥 Fire and Smoke Detection App - 2415336E")
57
+ gr.Markdown("Upload an image or a video to detect fire or smoke using YOLO model.")
58
+
59
+ with gr.Tabs():
60
+ with gr.Tab("Image Detection"):
61
+ gr.Markdown("### Upload an Image")
62
+ img_input = gr.Image(type="pil", label="Input Image")
63
+ img_output = gr.Image(type="pil", label="Detection Output")
64
+ img_button = gr.Button("Detect Fire/Smoke")
65
+ img_button.click(predict_image, inputs=img_input, outputs=img_output)
66
+
67
+ with gr.Tab("Video Detection"):
68
+ gr.Markdown("### Upload a video to detect fire and smoke. You can download the processed video after detection.")
69
+ video_input = gr.Video(label="Upload Video")
70
+ video_output = gr.File(label="Download Processed Video") # File download component
71
+ process_button = gr.Button("Process Video")
72
+
73
+ process_button.click(process_video, inputs=video_input, outputs=video_output)
74
+
75
+
76
+
77
+ gr.Markdown("Developed with 💻 by Zell (Feng Long)")
78
 
79
+ app.launch(share=True)