Ganrong committed on
Commit
4a8f9a6
·
verified ·
1 Parent(s): 78b28ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -7
app.py CHANGED
@@ -5,7 +5,8 @@ from huggingface_hub import snapshot_download
5
  import os
6
  import cv2
7
  import numpy as np
8
- from tqdm import tqdm # For progress bar
 
9
 
10
  # Function to load the model
11
  def load_model(repo_id):
@@ -15,7 +16,7 @@ def load_model(repo_id):
15
  detection_model = YOLO(path, task="detect")
16
  return detection_model
17
 
18
- # Function to predict an image
19
  def predict_image(pilimg, conf_threshold, iou_threshold):
20
  """Process an image with user-defined thresholds."""
21
  try:
@@ -26,9 +27,9 @@ def predict_image(pilimg, conf_threshold, iou_threshold):
26
  except Exception as e:
27
  return f"Error processing image: {e}"
28
 
29
- # Function to predict a video with progress tracking
30
  def predict_video(video_file, conf_threshold, iou_threshold, start_time, end_time):
31
- """Process a video with user-defined thresholds and time range."""
32
  cap = cv2.VideoCapture(video_file)
33
 
34
  if not cap.isOpened():
@@ -39,9 +40,11 @@ def predict_video(video_file, conf_threshold, iou_threshold, start_time, end_tim
39
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
40
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
41
 
42
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
43
- output_path = "output_video.mp4"
 
44
 
 
45
  start_frame = int(start_time * fps) if start_time else 0
46
  end_frame = int(end_time * fps) if end_time else total_frames
47
 
@@ -78,6 +81,7 @@ detection_model = load_model(REPO_ID)
78
  with gr.Blocks() as demo:
79
  gr.Markdown("## Pangolin and Axolotl Detection")
80
 
 
81
  with gr.Tab("Image Input"):
82
  img_input = gr.Image(type="pil", label="Upload an Image")
83
  conf_slider_img = gr.Slider(0.1, 1.0, value=0.5, step=0.05, label="Confidence Threshold")
@@ -91,13 +95,14 @@ with gr.Blocks() as demo:
91
  outputs=img_output
92
  )
93
 
 
94
  with gr.Tab("Video Input"):
95
  video_input = gr.Video(label="Upload a Video")
96
  conf_slider_video = gr.Slider(0.1, 1.0, value=0.5, step=0.05, label="Confidence Threshold")
97
  iou_slider_video = gr.Slider(0.1, 1.0, value=0.6, step=0.05, label="IoU Threshold")
98
  start_time = gr.Number(value=0, label="Start Time (seconds)")
99
  end_time = gr.Number(value=0, label="End Time (seconds, 0 for full video)")
100
- video_output = gr.File(label="Download Processed Video")
101
  video_submit = gr.Button("Process Video")
102
 
103
  video_submit.click(
 
5
  import os
6
  import cv2
7
  import numpy as np
8
+ from tqdm import tqdm
9
+ import tempfile
10
 
11
  # Function to load the model
12
  def load_model(repo_id):
 
16
  detection_model = YOLO(path, task="detect")
17
  return detection_model
18
 
19
+ # Function to process an image
20
  def predict_image(pilimg, conf_threshold, iou_threshold):
21
  """Process an image with user-defined thresholds."""
22
  try:
 
27
  except Exception as e:
28
  return f"Error processing image: {e}"
29
 
30
+ # Function to process a video
31
  def predict_video(video_file, conf_threshold, iou_threshold, start_time, end_time):
32
+ """Process a video and return the path for displaying."""
33
  cap = cv2.VideoCapture(video_file)
34
 
35
  if not cap.isOpened():
 
40
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
41
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
42
 
43
+ # Use a temporary file to store the processed video
44
+ temp_video_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
45
+ output_path = temp_video_file.name
46
 
47
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
48
  start_frame = int(start_time * fps) if start_time else 0
49
  end_frame = int(end_time * fps) if end_time else total_frames
50
 
 
81
  with gr.Blocks() as demo:
82
  gr.Markdown("## Pangolin and Axolotl Detection")
83
 
84
+ # Image Processing Tab
85
  with gr.Tab("Image Input"):
86
  img_input = gr.Image(type="pil", label="Upload an Image")
87
  conf_slider_img = gr.Slider(0.1, 1.0, value=0.5, step=0.05, label="Confidence Threshold")
 
95
  outputs=img_output
96
  )
97
 
98
+ # Video Processing Tab
99
  with gr.Tab("Video Input"):
100
  video_input = gr.Video(label="Upload a Video")
101
  conf_slider_video = gr.Slider(0.1, 1.0, value=0.5, step=0.05, label="Confidence Threshold")
102
  iou_slider_video = gr.Slider(0.1, 1.0, value=0.6, step=0.05, label="IoU Threshold")
103
  start_time = gr.Number(value=0, label="Start Time (seconds)")
104
  end_time = gr.Number(value=0, label="End Time (seconds, 0 for full video)")
105
+ video_output = gr.Video(label="Processed Video")
106
  video_submit = gr.Button("Process Video")
107
 
108
  video_submit.click(