# Hugging Face Space file: 2303113Q / app.py
# Last commit 2335ea2 ("Update app.py") by Xphosis.
from ultralytics import YOLO
from PIL import Image
import gradio as gr
from huggingface_hub import snapshot_download
import os
import cv2
import tempfile
from tqdm import tqdm
def load_model(repo_id):
    """Fetch a model repository from the Hugging Face Hub and load its detector.

    Args:
        repo_id: Hub repository id passed to ``snapshot_download``.

    Returns:
        A ``YOLO`` model (task ``'detect'``) loaded from the repository's
        ``best_int8_openvino_model`` sub-directory.
    """
    snapshot_dir = snapshot_download(repo_id)
    weights_dir = os.path.join(snapshot_dir, 'best_int8_openvino_model')
    # Log both locations; handy when checking which snapshot was loaded.
    for location in (snapshot_dir, weights_dir):
        print(location)
    return YOLO(weights_dir, task='detect')
def detect_image(pilimg):
    """Run the YOLO detector on an uploaded image and return an annotated copy.

    Args:
        pilimg: PIL image from the ``gr.Image(type='pil')`` input, or ``None``
            when "Analyze Image" is clicked with nothing uploaded.

    Returns:
        A PIL image with the detections drawn on it, or ``None`` if no image
        was provided (which leaves the output component empty).
    """
    # Without this guard ultralytics treats ``source=None`` as "no source
    # given" and falls back to its bundled sample images, so the user would
    # see detections on a demo picture instead of nothing.
    if pilimg is None:
        return None
    results = detection_model.predict(pilimg, conf=0.5, iou=0.7)
    # ``Results.plot()`` returns a BGR numpy array (OpenCV channel order);
    # reversing the last axis yields the RGB order PIL expects.
    annotated_bgr = results[0].plot()
    return Image.fromarray(annotated_bgr[..., ::-1])
def detect_video(video_file):
    """Run the YOLO detector on every frame of a video and save an annotated copy.

    Args:
        video_file: Path of the video from the ``gr.Video`` input, or ``None``
            when "Analyse Video" is clicked with nothing uploaded.

    Returns:
        Path to a newly created ``.mp4`` file holding the annotated frames, or
        ``None`` if no video was provided. The file is not deleted here.

    Raises:
        gr.Error: If OpenCV cannot open ``video_file``.
    """
    if video_file is None:
        return None

    video_reader = cv2.VideoCapture(video_file)
    if not video_reader.isOpened():
        video_reader.release()
        raise gr.Error('Could not open the uploaded video.')

    video_writer = None
    try:
        # The reported frame count may be inaccurate for some containers, so
        # it only sizes the progress bar; the loop reads until the stream ends.
        nb_frames = int(video_reader.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_width = int(video_reader.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(video_reader.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = video_reader.get(cv2.CAP_PROP_FPS)

        # mkstemp hands back an open descriptor; close it right away so only
        # the VideoWriter touches the file (an open handle would leak and, on
        # Windows, can prevent the path from being opened again).
        fd, temp_video_path = tempfile.mkstemp(suffix='.mp4')
        os.close(fd)

        # Frames are written at half the source frame rate, so the annotated
        # video plays back at half speed (slow motion).
        video_writer = cv2.VideoWriter(
            temp_video_path,
            cv2.VideoWriter_fourcc(*'mp4v'),
            fps / 2,
            (frame_width, frame_height),
        )

        with tqdm(total=nb_frames if nb_frames > 0 else None) as progress:
            while True:
                success, frame = video_reader.read()
                if not success:
                    break
                results = detection_model.predict(frame, conf=0.5, iou=0.7)
                video_writer.write(results[0].plot())
                progress.update(1)
    finally:
        # Release both handles even if inference fails part-way through.
        video_reader.release()
        if video_writer is not None:
            video_writer.release()

    return temp_video_path
REPO_ID = 'Xphosis/PotatoOrSweetPotato'
# Loaded once at import time; detect_image/detect_video read this global.
detection_model = load_model(REPO_ID)

# Two side-by-side columns: image detection on the left, video on the right.
with gr.Blocks() as app:
    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            img_input = gr.Image(type='pil', label='Upload Image to detect Potato or Sweet Potato', height=300, min_width=300)
            img_output = gr.Image(type='pil', label='Download Image', height=300, min_width=300)
            img_button = gr.Button('Analyze Image')
            img_button.click(detect_image, inputs=img_input, outputs=img_output)
            clear_image_button = gr.Button('Clear Image')
            # Returning None resets the input component.
            clear_image_button.click(fn=lambda: None, inputs=None, outputs=img_input)
        with gr.Column(scale=1, min_width=300):
            video_input = gr.Video(label='Upload Video to detect Potato or Sweet Potato', height=300, min_width=300)
            video_output = gr.Video(label='Download Video', autoplay=True, loop=True, show_share_button=True, show_download_button=True, height=300, min_width=300)
            video_button = gr.Button('Analyse Video')
            video_button.click(detect_video, inputs=video_input, outputs=video_output)
            clear_video_button = gr.Button('Clear Video')
            # Returning None resets the input component.
            clear_video_button.click(fn=lambda: None, inputs=None, outputs=video_input)

# Start the server only when run as a script, not when the module is imported.
if __name__ == '__main__':
    app.launch()