developer0hye's picture
Update app.py
795585e verified
raw
history blame contribute delete
6 kB
import gradio as gr
import spaces
import argparse
import cv2
from PIL import Image
import numpy as np
import warnings
import torch
warnings.filterwarnings("ignore")
# Replace custom imports with Transformers
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
# Add supervision for better visualization
import supervision as sv
# Hugging Face Hub checkpoint of the Grounding DINO (base) detector.
model_id = "IDEA-Research/grounding-dino-base"

# Run on the GPU whenever PyTorch reports CUDA is available, else fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The processor prepares both the image and the text prompt for the model;
# the model itself is loaded once at import time and moved to `device`.
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id)
model = model.to(device)
def _to_rgb_pil(input_image):
    """Return *input_image* (a PIL image or NumPy array) as an RGB PIL image.

    A 3-D NumPy array is passed through ``cv2.COLOR_BGR2RGB`` first, i.e. it is
    treated as OpenCV-style BGR channel order.
    """
    if isinstance(input_image, np.ndarray):
        # NOTE(review): Gradio hands out numpy images in RGB order, so this swap
        # is only correct for OpenCV-sourced (BGR) arrays. The app's gr.Image
        # uses type="pil", so this branch is not reached from the UI — confirm intent.
        if input_image.ndim == 3:
            input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
        input_image = Image.fromarray(input_image)
    return input_image.convert("RGB")


def _normalize_caption(caption):
    """Return *caption* stripped, lowercased and terminated with a period.

    Grounding DINO expects its text queries lowercased and ending with ".";
    doing this here means users no longer have to get it exactly right.

    Raises:
        gr.Error: if the caption is empty or only whitespace.
    """
    text = (caption or "").strip().lower()
    if not text:
        raise gr.Error("Please enter a detection prompt, e.g. 'a cat. a remote control.'")
    if not text.endswith("."):
        text += "."
    return text


def _annotate(image, result):
    """Draw the boxes and "label: score" captions from *result* onto *image*.

    Args:
        image: RGB PIL image that the detections refer to.
        result: One entry of ``post_process_grounded_object_detection`` output,
            holding "boxes", "scores" and "labels" (and "text_labels" in newer
            transformers versions).

    Returns:
        A new PIL image; it is a plain copy of *image* when there are no detections.
    """
    image_np = np.array(image)
    boxes = result["boxes"]
    if len(boxes) == 0:
        return Image.fromarray(image_np)

    # Newer transformers return the phrase strings under "text_labels" and
    # deprecate "labels"; older versions only provide "labels".
    labels = result["text_labels"] if "text_labels" in result else result["labels"]
    confidences = [float(score) for score in result["scores"]]

    detections = sv.Detections(
        xyxy=np.array([box.tolist() for box in boxes]),
        confidence=np.array(confidences),
        # The detection index is used as class id, so every box draws with its own palette colour.
        class_id=np.arange(len(confidences), dtype=np.int32),
    )

    # Scale label text and stroke to the image resolution (PIL size is (w, h)).
    text_scale = sv.calculate_optimal_text_scale(resolution_wh=image.size)
    line_thickness = sv.calculate_optimal_line_thickness(resolution_wh=image.size)

    box_annotator = sv.BoxAnnotator(
        thickness=2,
        color=sv.ColorPalette.DEFAULT,
    )
    label_annotator = sv.LabelAnnotator(
        color=sv.ColorPalette.DEFAULT,
        text_color=sv.Color.WHITE,
        text_scale=text_scale,
        text_thickness=line_thickness,
        text_padding=3,
    )

    formatted_labels = [f"{label}: {conf:.2f}" for label, conf in zip(labels, confidences)]
    annotated = box_annotator.annotate(scene=image_np, detections=detections)
    annotated = label_annotator.annotate(
        scene=annotated,
        detections=detections,
        labels=formatted_labels,
    )
    return Image.fromarray(annotated)


@spaces.GPU
def run_grounding(input_image, grounding_caption, box_threshold, text_threshold):
    """Detect the phrases in *grounding_caption* in *input_image* and draw them.

    Args:
        input_image: PIL image (the app's gr.Image uses type="pil") or a NumPy
            array; converted to RGB by ``_to_rgb_pil``.
        grounding_caption: Text prompt made of period-separated phrases, e.g.
            "a cat. a remote control."; normalized by ``_normalize_caption``.
        box_threshold: Forwarded to ``post_process_grounded_object_detection``
            as ``box_threshold``.
        text_threshold: Forwarded to ``post_process_grounded_object_detection``
            as ``text_threshold``.

    Returns:
        A PIL image with boxes and "label: score" captions, or the plain RGB
        image when nothing passes the thresholds.

    Raises:
        gr.Error: if no image was supplied or the prompt is empty.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image.")
    init_image = _to_rgb_pil(input_image)
    caption = _normalize_caption(grounding_caption)

    inputs = processor(images=init_image, text=caption, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        # PIL size is (width, height); post-processing expects (height, width).
        target_sizes=[init_image.size[::-1]],
    )
    return _annotate(init_image, results[0])
if __name__ == "__main__":
    # Command-line switches forwarded to demo.launch().
    parser = argparse.ArgumentParser("Grounding DINO demo", add_help=True)
    parser.add_argument("--debug", action="store_true", help="using debug mode")
    parser.add_argument("--share", action="store_true", help="share the app")
    args = parser.parse_args()

    # NOTE(review): no component below sets elem_id="mkd", so this rule is
    # currently unused; kept in case it is wired up later.
    css = """
#mkd {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
"""

    with gr.Blocks(css=css) as demo:
        # Headings: the closing tags were previously written as opening tags,
        # which produced unbalanced HTML.
        gr.Markdown("<h1><center>Grounding DINO Base</center></h1>")
        gr.Markdown("<h3><center>Open-World Detection with <a href='https://github.com/IDEA-Research/GroundingDINO'>Grounding DINO</a></center></h3>")

        with gr.Row():
            # Left column: inputs and thresholds.
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="pil")
                grounding_caption = gr.Textbox(label="Detection Prompt(VERY important: text queries need to be lowercased + end with a dot, example: a cat. a remote control.)", value="a person. a car.")
                run_button = gr.Button("Run")
                with gr.Accordion("Advanced options", open=False):
                    box_threshold = gr.Slider(
                        minimum=0.0, maximum=1.0, value=0.3, step=0.001,
                        label="Box Threshold",
                    )
                    text_threshold = gr.Slider(
                        minimum=0.0, maximum=1.0, value=0.25, step=0.001,
                        label="Text Threshold",
                    )
            # Right column: the annotated result (a single image, not a gallery).
            with gr.Column():
                output_image = gr.Image(
                    label="Detection Result",
                    type="pil",
                )

        run_button.click(
            fn=run_grounding,
            inputs=[input_image, grounding_caption, box_threshold, text_threshold],
            outputs=[output_image],
        )

        # Example rows reference image files that are expected next to this script;
        # with cache_examples=True their outputs are computed ahead of time.
        gr.Examples(
            examples=[
                ["000000039769.jpg", "a cat. a remote control.", 0.3, 0.25],
                ["KakaoTalk_20250430_163200504.jpg", "cup. screen. hand.", 0.3, 0.25],
            ],
            inputs=[input_image, grounding_caption, box_threshold, text_threshold],
            outputs=[output_image],
            fn=run_grounding,
            cache_examples=True,
        )

    demo.launch(share=args.share, debug=args.debug, show_error=True)