|
import functools
import logging
import os
import shutil
import tempfile
from pathlib import Path
from typing import Tuple, List, Optional

import cv2
import gradio as gr
import numpy as np
import tqdm
from PIL import Image
from transformers import pipeline
from transformers.image_utils import load_image
|
|
|
|
|
# D-FINE checkpoints on the Hugging Face Hub, all published under the
# "ustc-community" organisation. Order matters: the first entry is the default.
CHECKPOINTS = [
    f"ustc-community/dfine_{variant}"
    for variant in (
        "m_obj365",
        "n_coco",
        "s_coco",
        "m_coco",
        "l_coco",
        "x_coco",
        "s_obj365",
        "l_obj365",
        "x_obj365",
        "s_obj2coco",
        "m_obj2coco",
        "l_obj2coco_e25",
        "x_obj2coco",
    )
]

# Upper bound on how many frames of an uploaded video are processed.
MAX_NUM_FRAMES = 300
DEFAULT_CHECKPOINT = CHECKPOINTS[0]
DEFAULT_CONFIDENCE_THRESHOLD = 0.3

# Example rows for the image tab: either a local file or a remote URL.
IMAGE_EXAMPLES = [
    dict(path="./image.jpg", use_url=False, url="", label="Local Image"),
    dict(
        path=None,
        use_url=True,
        url="https://live.staticflickr.com/65535/33021460783_1646d43c54_b.jpg",
        label="Flickr Image",
    ),
]

# Example rows for the video tab.
VIDEO_EXAMPLES = [dict(path="./video.mp4", label="Local Video")]

# Lower-cased file extensions accepted by the video tab.
ALLOWED_VIDEO_EXTENSIONS = {".avi", ".mov", ".mp4"}
|
|
|
# Process-wide logging: timestamped INFO-level messages.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Directory (relative to the working directory) where annotated videos are
# written; created eagerly at import time so later writes cannot fail on it.
VIDEO_OUTPUT_DIR = Path("static") / "videos"
VIDEO_OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
|
|
|
|
|
@functools.lru_cache(maxsize=2)
def _get_pipeline(checkpoint: str):
    """Build the CPU object-detection pipeline for ``checkpoint``, memoised.

    Loading a model is by far the most expensive step of a request. Without
    caching it was repeated for every image request and for every video frame.
    ``lru_cache`` does not cache raised exceptions, so a failed load is retried
    on the next call. ``maxsize`` bounds how many models stay resident in memory.
    """
    return pipeline(
        "object-detection",
        model=checkpoint,
        image_processor=checkpoint,
        device="cpu",
    )


def detect_objects(
    image: Optional[Image.Image],
    checkpoint: str,
    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
    use_url: bool = False,
    url: str = "",
) -> Tuple[
    Optional[Tuple[Image.Image, List[Tuple[Tuple[int, int, int, int], str]]]],
    gr.Markdown,
]:
    """Detect objects in an uploaded image or in an image fetched from a URL.

    Args:
        image: Image from the upload widget. Ignored when ``use_url`` is True
            and ``url`` is non-empty.
        checkpoint: Hugging Face model id used for both model and processor.
        confidence_threshold: Minimum score for a detection to be kept.
        use_url: Load the image from ``url`` instead of using ``image``.
        url: Image URL, consulted only when ``use_url`` is True.

    Returns:
        ``(result, message)``.
        - ``result`` is ``(image, annotations)``. ``annotations`` is a list of
          ``((xmin, ymin, xmax, ymax), label)`` pairs clipped to the image
          bounds. ``result`` is ``None`` when an error occurred.
        - ``message`` is a Markdown component. It is visible with error or
          warning text, and hidden when detections were found.
    """
    if use_url and url:
        try:
            input_image = load_image(url)
        except Exception as e:
            logger.error(f"Failed to load image from URL {url}: {str(e)}")
            return None, gr.Markdown(
                f"**Error**: Failed to load image from URL: {str(e)}", visible=True
            )
    elif image is not None:
        if not isinstance(image, Image.Image):
            logger.error("Input image is not a PIL Image")
            return None, gr.Markdown("**Error**: Invalid image format.", visible=True)
        input_image = image
    else:
        return None, gr.Markdown(
            "**Error**: Please provide an image or URL.", visible=True
        )

    # Normalise to 3-channel RGB so RGBA / grayscale / palette inputs do not
    # reach the image processor with an unexpected channel layout.
    if input_image.mode != "RGB":
        input_image = input_image.convert("RGB")

    try:
        pipe = _get_pipeline(checkpoint)
    except Exception as e:
        logger.exception(f"Failed to initialize model pipeline for {checkpoint}: {str(e)}")
        return None, gr.Markdown(
            f"**Error**: Failed to load model: {str(e)}", visible=True
        )

    try:
        results = pipe(input_image, threshold=confidence_threshold)
    except Exception as e:
        logger.exception(f"Object detection failed with {checkpoint}: {str(e)}")
        return None, gr.Markdown(
            f"**Error**: Object detection failed: {str(e)}", visible=True
        )

    img_width, img_height = input_image.size
    annotations = []
    for result in results:
        score = result["score"]
        # Defensive re-check; the pipeline was already given the threshold.
        if score < confidence_threshold:
            continue
        label = f"{result['label']} ({score:.2f})"
        box = result["box"]

        # Clip to the image; drop boxes that become empty after clipping.
        bbox_xmin = max(0, int(box["xmin"]))
        bbox_ymin = max(0, int(box["ymin"]))
        bbox_xmax = min(img_width, int(box["xmax"]))
        bbox_ymax = min(img_height, int(box["ymax"]))
        if bbox_xmax <= bbox_xmin or bbox_ymax <= bbox_ymin:
            continue
        annotations.append(((bbox_xmin, bbox_ymin, bbox_xmax, bbox_ymax), label))

    if not annotations:
        return (input_image, []), gr.Markdown(
            "**Warning**: No objects detected above the confidence threshold. Try lowering the threshold.",
            visible=True,
        )

    return (input_image, annotations), gr.Markdown(visible=False)
|
|
|
|
|
def annotate_frame(
    image: Image.Image, annotations: List[Tuple[Tuple[int, int, int, int], str]]
) -> np.ndarray:
    """Draw detection boxes and labels on a copy of ``image`` for OpenCV output.

    Args:
        image: Frame to annotate. Converted to RGB first if it is in another
            mode. It is not modified.
        annotations: ``((xmin, ymin, xmax, ymax), label)`` pairs in pixel
            coordinates, as produced by ``detect_objects``.

    Returns:
        A new BGR array with white boxes and white-backed black labels. The
        channel order matches what ``cv2.VideoWriter.write`` expects.
    """
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    image_bgr = cv2.cvtColor(np.asarray(rgb_image), cv2.COLOR_RGB2BGR)

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    thickness = 1
    padding = 4

    for (xmin, ymin, xmax, ymax), label in annotations:
        cv2.rectangle(image_bgr, (xmin, ymin), (xmax, ymax), (255, 255, 255), 2)

        (text_w, text_h), _ = cv2.getTextSize(label, font, font_scale, thickness)
        # Put the label just above the box. If that would be clipped by the
        # top edge of the frame (box touching y=0), draw it inside the box.
        label_top = ymin - text_h - padding
        if label_top < 0:
            label_top = ymin
        label_bottom = label_top + text_h + padding

        cv2.rectangle(
            image_bgr,
            (xmin, label_top),
            (xmin + text_w, label_bottom),
            (255, 255, 255),
            -1,
        )
        cv2.putText(
            image_bgr,
            label,
            (xmin, label_bottom - padding),
            font,
            font_scale,
            (0, 0, 0),
            thickness,
        )

    return image_bgr
|
|
|
|
|
def process_video(
    video_path: str,
    checkpoint: str,
    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> Tuple[Optional[str], gr.Markdown]:
    """Run object detection on up to ``MAX_NUM_FRAMES`` frames of a video.

    Each frame is passed through ``detect_objects`` and annotated with
    ``annotate_frame``. The result is written as an ``mp4v`` MP4 into
    ``VIDEO_OUTPUT_DIR``.

    Args:
        video_path: Path to the uploaded video file.
        checkpoint: Hugging Face model id forwarded to ``detect_objects``.
        confidence_threshold: Minimum detection score.
        progress: Gradio progress tracker. ``track_tqdm=True`` makes the tqdm
            loop below drive the UI progress bar.

    Returns:
        ``(output_path, message)``. ``output_path`` is the saved video path,
        or ``None`` on failure. ``message`` is a Markdown component that is
        visible with error/warning text and hidden on success.
    """
    if not video_path or not os.path.isfile(video_path):
        logger.error(f"Invalid video path: {video_path}")
        return None, gr.Markdown(
            "**Error**: Please provide a valid video file.", visible=True
        )

    ext = os.path.splitext(video_path)[1].lower()
    if ext not in ALLOWED_VIDEO_EXTENSIONS:
        logger.error(f"Unsupported video format: {ext}")
        return None, gr.Markdown(
            "**Error**: Unsupported video format. Use MP4, AVI, or MOV.", visible=True
        )

    # Tracked so the ``finally`` clause can release/remove whatever was
    # acquired, on every exit path (early return, success, or exception).
    cap = None
    writer = None
    temp_path = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            logger.error(f"Failed to open video: {video_path}")
            return None, gr.Markdown(
                "**Error**: Failed to open video file.", visible=True
            )

        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps or fps <= 0:
            # Missing FPS metadata would make VideoWriter fail or emit a
            # broken file; fall back to a common frame rate.
            fps = 30.0
        num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Some containers report 0 or -1 frames; then read until EOF or the cap.
        frames_to_read = (
            MAX_NUM_FRAMES if num_frames <= 0 else min(MAX_NUM_FRAMES, num_frames)
        )

        # Reserve a unique path, then close the handle so VideoWriter can open
        # the file itself (an open handle can block this on some platforms).
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
            temp_path = temp_file.name

        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(temp_path, fourcc, fps, (width, height))
        if not writer.isOpened():
            logger.error("Failed to initialize video writer")
            return None, gr.Markdown(
                "**Error**: Failed to initialize video writer.", visible=True
            )

        frame_count = 0
        for _ in tqdm.tqdm(range(frames_to_read), desc="Processing video"):
            ok, frame = cap.read()
            if not ok:
                break
            pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            detection, message = detect_objects(
                pil_image, checkpoint, confidence_threshold, use_url=False, url=""
            )
            if detection is None:
                # detect_objects returns (None, message) only on errors that
                # would repeat for every frame (e.g. model load failure), so
                # abort and surface its message instead of retrying.
                logger.error("Aborting video processing: frame detection failed")
                return None, message
            annotated_image, annotations = detection
            writer.write(annotate_frame(annotated_image, annotations))
            frame_count += 1

        # Finalise the container before the file is moved.
        writer.release()
        writer = None

        if frame_count == 0:
            logger.warning("No valid frames processed in video")
            return None, gr.Markdown(
                "**Warning**: No valid frames processed. Try a different video or threshold.",
                visible=True,
            )

        output_path = VIDEO_OUTPUT_DIR / f"output_{os.path.basename(temp_path)}"
        shutil.move(temp_path, output_path)
        temp_path = None  # Moved; nothing left to clean up.
        logger.info(f"Video saved to {output_path}")
        return str(output_path), gr.Markdown(visible=False)

    except Exception as e:
        logger.exception(f"Video processing failed: {str(e)}")
        return None, gr.Markdown(
            f"**Error**: Video processing failed: {str(e)}", visible=True
        )
    finally:
        if writer is not None:
            writer.release()
        if cap is not None:
            cap.release()
        if temp_path is not None and os.path.exists(temp_path):
            os.unlink(temp_path)
|
|
|
|
|
def _create_model_controls() -> List[gr.components.Component]:
    """Build the checkpoint dropdown and confidence slider shared by both tabs.

    Defining these once keeps the image and video tabs' model choices,
    defaults and slider range from drifting apart.
    """
    return [
        gr.Dropdown(
            choices=CHECKPOINTS,
            label="Select Model Checkpoint",
            value=DEFAULT_CHECKPOINT,
            elem_classes="input-component",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=DEFAULT_CONFIDENCE_THRESHOLD,
            step=0.1,
            label="Confidence Threshold",
            elem_classes="input-component",
        ),
    ]


def create_image_inputs() -> List[gr.components.Component]:
    """Create the image-tab inputs.

    Returns:
        ``[image, use_url checkbox, url textbox, checkpoint dropdown,
        confidence slider]``, in that creation/layout order. The URL textbox
        starts hidden; the UI toggles it from the checkbox.
    """
    return [
        gr.Image(
            label="Upload Image",
            type="pil",
            sources=["upload", "webcam"],
            interactive=True,
            elem_classes="input-component",
        ),
        gr.Checkbox(label="Use Image URL Instead", value=False),
        gr.Textbox(
            label="Image URL",
            placeholder="https://example.com/image.jpg",
            visible=False,
            elem_classes="input-component",
        ),
        *_create_model_controls(),
    ]


def create_video_inputs() -> List[gr.components.Component]:
    """Create the video-tab inputs.

    Returns:
        ``[video, checkpoint dropdown, confidence slider]``, in that
        creation/layout order.
    """
    return [
        gr.Video(
            label="Upload Video",
            sources=["upload"],
            interactive=True,
            format="mp4",
            elem_classes="input-component",
        ),
        *_create_model_controls(),
    ]
|
|
|
|
|
def create_button_row(is_image: bool) -> List[gr.Button]:
    """Create a tab's action buttons: a primary detect and a secondary clear.

    The button captions are prefixed with "Image" or "Video" depending on
    ``is_image``. The detect button is created (and therefore laid out) first.
    """
    tab_name = "Video"
    if is_image:
        tab_name = "Image"

    detect_button = gr.Button(
        f"{tab_name} Detect Objects",
        variant="primary",
        elem_classes="action-button",
    )
    clear_button = gr.Button(
        f"{tab_name} Clear",
        variant="secondary",
        elem_classes="action-button",
    )
    return [detect_button, clear_button]
|
|
|
|
|
|
|
# Build the two-tab (image / video) UI and wire its events. Event listeners
# must be registered inside the Blocks context.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Real-Time Object Detection Demo
        Experience state-of-the-art object detection with USTC's Dfine models. Upload an image or video,
        provide a URL, or try an example below. Select a model and adjust the confidence threshold to see detections in real time!
        """,
        elem_classes="header-text",
    )

    with gr.Tabs():
        with gr.Tab("Image"):
            with gr.Row():
                with gr.Column(scale=1, min_width=300):
                    with gr.Group():
                        (
                            image_input,
                            use_url,
                            url_input,
                            image_checkpoint,
                            image_confidence_threshold,
                        ) = create_image_inputs()
                    image_detect_button, image_clear_button = create_button_row(
                        is_image=True
                    )
                with gr.Column(scale=2):
                    image_output = gr.AnnotatedImage(
                        label="Detection Results",
                        show_label=True,
                        color_map=None,
                        elem_classes="output-component",
                    )
                    image_error_message = gr.Markdown(
                        visible=False, elem_classes="error-text"
                    )

            # Positional order of detect_objects(image, checkpoint,
            # confidence_threshold, use_url, url). Shared by the examples and
            # the detect button so both always call fn with matching arguments
            # (previously the examples passed use_url/url in the checkpoint
            # and threshold positions).
            image_detect_inputs = [
                image_input,
                image_checkpoint,
                image_confidence_threshold,
                use_url,
                url_input,
            ]

            gr.Examples(
                examples=[
                    [
                        example["path"],
                        DEFAULT_CHECKPOINT,
                        DEFAULT_CONFIDENCE_THRESHOLD,
                        example["use_url"],
                        example["url"],
                    ]
                    for example in IMAGE_EXAMPLES
                ],
                inputs=image_detect_inputs,
                outputs=[image_output, image_error_message],
                fn=detect_objects,
                cache_examples=False,
                label="Select an image example to populate inputs",
            )

        with gr.Tab("Video"):
            gr.Markdown(
                f"The input video will be truncated to {MAX_NUM_FRAMES} frames."
            )
            with gr.Row():
                with gr.Column(scale=1, min_width=300):
                    with gr.Group():
                        video_input, video_checkpoint, video_confidence_threshold = (
                            create_video_inputs()
                        )
                    video_detect_button, video_clear_button = create_button_row(
                        is_image=False
                    )
                with gr.Column(scale=2):
                    video_output = gr.Video(
                        label="Detection Results",
                        format="mp4",
                        elem_classes="output-component",
                    )
                    video_error_message = gr.Markdown(
                        visible=False, elem_classes="error-text"
                    )

            # Positional order of process_video(video_path, checkpoint,
            # confidence_threshold); shared by the examples and the button.
            video_detect_inputs = [
                video_input,
                video_checkpoint,
                video_confidence_threshold,
            ]

            gr.Examples(
                examples=[
                    [example["path"], DEFAULT_CHECKPOINT, DEFAULT_CONFIDENCE_THRESHOLD]
                    for example in VIDEO_EXAMPLES
                ],
                inputs=video_detect_inputs,
                outputs=[video_output, video_error_message],
                fn=process_video,
                cache_examples=False,
                label="Select a video example to populate inputs",
            )

    # Show the URL textbox only while "Use Image URL Instead" is checked.
    use_url.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_url,
        outputs=url_input,
    )

    # Reset every image-tab input to its default and clear the results.
    image_clear_button.click(
        fn=lambda: (
            None,
            False,
            "",
            DEFAULT_CHECKPOINT,
            DEFAULT_CONFIDENCE_THRESHOLD,
            None,
            gr.Markdown(visible=False),
        ),
        outputs=[
            image_input,
            use_url,
            url_input,
            image_checkpoint,
            image_confidence_threshold,
            image_output,
            image_error_message,
        ],
    )

    # Reset every video-tab input to its default and clear the results.
    video_clear_button.click(
        fn=lambda: (
            None,
            DEFAULT_CHECKPOINT,
            DEFAULT_CONFIDENCE_THRESHOLD,
            None,
            gr.Markdown(visible=False),
        ),
        outputs=[
            video_input,
            video_checkpoint,
            video_confidence_threshold,
            video_output,
            video_error_message,
        ],
    )

    image_detect_button.click(
        fn=detect_objects,
        inputs=image_detect_inputs,
        outputs=[image_output, image_error_message],
    )

    video_detect_button.click(
        fn=process_video,
        inputs=video_detect_inputs,
        outputs=[video_output, video_error_message],
    )
|
|
|
if __name__ == "__main__":
    # Queue requests (at most 20 waiting) so long-running video jobs can report
    # progress, then start the server. queue() returns the same Blocks object.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()
|
|