ZunYin committed · verified
Commit 7f0b8e1 · 1 Parent(s): 1c92ef5

Update app.py

Files changed (1)
  1. app.py +91 -4
app.py CHANGED
@@ -1,7 +1,94 @@
  import gradio as gr
-
- def greet(name):
-     return "Hello " + name + "!!"
-
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import cv2
+ import requests
+ import os
+ from ultralytics import YOLO
+
+ # Define file URLs for the example images and video
+ file_urls = [
+     'https://drive.google.com/file/d/1rvuphnn3BV4NdILrQE72jU7fxA79SiYn/view?usp=sharing',  # Image
+     'https://drive.google.com/file/d/16gu9cLamGFrM5DRd1WJyk_6Xt9v0S7go/view?usp=sharing',  # Image
+     'https://drive.google.com/file/d/1UgZi54js65f5qGhNF3nGLZwIN5nrUek6/view?usp=sharing',  # Video
+ ]
+
+ # Helper function to download a file if it is not already present
+ def download_file(url, save_name):
+     if not os.path.exists(save_name):
+         file = requests.get(url)
+         open(save_name, 'wb').write(file.content)
+
+ # Download the example files. The Drive share links do not end in ".mp4",
+ # so the video is identified by its position in the list (the last entry).
+ for i, url in enumerate(file_urls):
+     if i == len(file_urls) - 1:
+         download_file(url, "video.mp4")
+     else:
+         download_file(url, f"image_{i}.jpg")
+
+ # Load the YOLO model from local weights
+ model = YOLO('best.pt')
+
+ # Define example paths for Gradio
+ image_examples = [["image_0.jpg"], ["image_1.jpg"]]
+ video_examples = [["video.mp4"]]
+
+ # Function for processing images
+ def show_preds_image(image_path):
+     results = model.predict(source=image_path)
+     annotated_image = results[0].plot()  # YOLO's built-in plot() draws the detections (BGR)
+     return cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
+
+ # Function for processing videos
+ def show_preds_video(video_path):
+     cap = cv2.VideoCapture(video_path)
+     fps = cap.get(cv2.CAP_PROP_FPS) or 20  # Fall back to 20 FPS if the source reports none
+     out_frames = []  # List to store annotated frames (BGR)
+
+     while cap.isOpened():
+         ret, frame = cap.read()
+         if not ret:
+             break
+
+         results = model.predict(source=frame)
+         out_frames.append(results[0].plot())
+
+     cap.release()
+
+     if not out_frames:
+         return None  # Nothing could be decoded from the input
+
+     # Save the annotated video
+     output_path = "annotated_video.mp4"
+     height, width, _ = out_frames[0].shape
+     fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+     writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+
+     for frame in out_frames:
+         writer.write(frame)  # plot() returns BGR, which VideoWriter expects
+
+     writer.release()
+     return output_path
+
+ # Gradio interfaces
+ inputs_image = gr.Image(type="filepath", label="Input Image")
+ outputs_image = gr.Image(type="numpy", label="Output Image")
+ interface_image = gr.Interface(
+     fn=show_preds_image,
+     inputs=inputs_image,
+     outputs=outputs_image,
+     title="Tiger & Ibex Detector - Image",
+     examples=image_examples,
+ )
+
+ inputs_video = gr.Video(label="Input Video")  # gr.Video no longer accepts a type argument
+ outputs_video = gr.Video(label="Annotated Output")
+ interface_video = gr.Interface(
+     fn=show_preds_video,
+     inputs=inputs_video,
+     outputs=outputs_video,
+     title="Tiger & Ibex Detector - Video",
+     examples=video_examples,
+ )
+
+ # Combine both interfaces into a tabbed app
+ gr.TabbedInterface(
+     [interface_image, interface_video],
+     tab_names=['Image Inference', 'Video Inference']
+ ).launch(share=True)