import gradio as gr
import spaces
import argparse
import cv2
from PIL import Image
import numpy as np
import warnings
import torch

warnings.filterwarnings("ignore")

# Replace custom imports with Transformers
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
# Add supervision for better visualization
import supervision as sv

# Model ID for Hugging Face
model_id = "IDEA-Research/grounding-dino-base"

# Load model and processor using Transformers
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
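
# ZeroGPU: request a GPU for each inference call. Assumption: the Space runs on
# ZeroGPU, which the otherwise-unused `spaces` import suggests.
@spaces.GPU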
def run_grounding(input_image, grounding_caption, box_threshold, text_threshold):
    # Convert numpy array to PIL Image if needed
    if isinstance(input_image, np.ndarray):
        if input_image.ndim == 3:
            input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
        input_image = Image.fromarray(input_image)

    init_image = input_image.convert("RGB")
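
    # Grounding DINO expects lowercase text queries separated by periods,
    # e.g. "a cat. a remote control." (see the prompt hint in the UI below).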
    # Process input using transformers
    inputs = processor(images=init_image, text=grounding_caption, return_tensors="pt").to(device)

    # Run inference
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process results
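    # (filters boxes by box_threshold, matches query phrases via text_threshold,
    #  and rescales boxes to pixel coordinates using target_sizes = (height, width))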
    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        target_sizes=[init_image.size[::-1]]
    )
    result = results[0]

    # Convert image for supervision visualization
    image_np = np.array(init_image)

    # Create detections for supervision
    boxes = []
    labels = []
    confidences = []
    class_ids = []
    for i, (box, score, label) in enumerate(zip(result["boxes"], result["scores"], result["labels"])):
        # Convert box to xyxy format
        xyxy = box.tolist()
        boxes.append(xyxy)
        labels.append(label)
        confidences.append(float(score))
        class_ids.append(i)  # Use index as class_id (integer)

    # Create Detections object for supervision
    if boxes:
        detections = sv.Detections(
            xyxy=np.array(boxes),
            confidence=np.array(confidences),
            class_id=np.array(class_ids, dtype=np.int32),  # Ensure it's an integer array
        )
        text_scale = sv.calculate_optimal_text_scale(resolution_wh=init_image.size)
        line_thickness = sv.calculate_optimal_line_thickness(resolution_wh=init_image.size)

        # Create annotators
        box_annotator = sv.BoxAnnotator(
            thickness=2,
            color=sv.ColorPalette.DEFAULT,
        )
        label_annotator = sv.LabelAnnotator(
            color=sv.ColorPalette.DEFAULT,
            text_color=sv.Color.WHITE,
            text_scale=text_scale,
            text_thickness=line_thickness,
            text_padding=3
        )

        # Create formatted labels for each detection
        formatted_labels = [
            f"{label}: {conf:.2f}"
            for label, conf in zip(labels, confidences)
        ]

        # Apply annotations to the image
        annotated_image = box_annotator.annotate(scene=image_np, detections=detections)
        annotated_image = label_annotator.annotate(
            scene=annotated_image,
            detections=detections,
            labels=formatted_labels
        )
    else:
        annotated_image = image_np

    # Convert back to PIL Image
    image_with_box = Image.fromarray(annotated_image)
    return image_with_box
if __name__ == "__main__":
    parser = argparse.ArgumentParser("Grounding DINO demo", add_help=True)
    parser.add_argument("--debug", action="store_true", help="using debug mode")
    parser.add_argument("--share", action="store_true", help="share the app")
    args = parser.parse_args()

    css = """
    #mkd {
        height: 500px;
        overflow: auto;
        border: 1px solid #ccc;
    }
    """

    with gr.Blocks(css=css) as demo:
        gr.Markdown("<h1><center>Grounding DINO Base</center></h1>")
        gr.Markdown("<h3><center>Open-World Detection with <a href='https://github.com/IDEA-Research/GroundingDINO'>Grounding DINO</a></center></h3>")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="pil")
                grounding_caption = gr.Textbox(label="Detection Prompt (VERY important: text queries need to be lowercase and end with a dot, e.g. a cat. a remote control.)", value="a person. a car.")
                run_button = gr.Button("Run")
                with gr.Accordion("Advanced options", open=False):
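                    # box_threshold filters out low-confidence boxes; text_threshold sets how
                    # strong a phrase match must be before it is attached to a box's label.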
                    box_threshold = gr.Slider(
                        minimum=0.0, maximum=1.0, value=0.3, step=0.001,
                        label="Box Threshold"
                    )
                    text_threshold = gr.Slider(
                        minimum=0.0, maximum=1.0, value=0.25, step=0.001,
                        label="Text Threshold"
                    )
            with gr.Column():
                gallery = gr.Image(
                    label="Detection Result",
                    type="pil"
                )

        run_button.click(
            fn=run_grounding,
            inputs=[input_image, grounding_caption, box_threshold, text_threshold],
            outputs=[gallery]
        )

        gr.Examples(
            examples=[
                ["000000039769.jpg", "a cat. a remote control.", 0.3, 0.25],
                ["KakaoTalk_20250430_163200504.jpg", "cup. screen. hand.", 0.3, 0.25]
            ],
            inputs=[input_image, grounding_caption, box_threshold, text_threshold],
            outputs=[gallery],
            fn=run_grounding,
            cache_examples=True,
        )

    demo.launch(share=args.share, debug=args.debug, show_error=True)
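
# Usage note (assumption: this file is the Space's app.py): run locally with
# `python app.py`, or `python app.py --share` for a public Gradio link; on
# Hugging Face Spaces the demo is launched automatically.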