Amazingldl's picture
Update app.py
5195576
raw
history blame contribute delete
1.99 kB
import gradio as gr
import numpy as np
import torch
from typing import List
from PIL import Image, ImageDraw
from transformers import OwlViTProcessor, OwlViTForObjectDetection
# Hugging Face checkpoint used for both the processor and the detector.
_OWLVIT_CHECKPOINT = "google/owlvit-base-patch32"

# Loaded once at import time and shared by every call to `inference`.
processor = OwlViTProcessor.from_pretrained(_OWLVIT_CHECKPOINT)
model = OwlViTForObjectDetection.from_pretrained(_OWLVIT_CHECKPOINT)
def pro_process(labelstring):
    """Split a comma-separated query string into a list of clean labels.

    Whitespace around each entry is stripped, and blank entries are dropped.
    Blank entries come from inputs such as ``""``, ``"cat,,dog"`` or a
    trailing comma. Every returned label is therefore a non-empty string.
    Empty or ``None`` input yields an empty list.

    >>> pro_process(" cat , remote control, ")
    ['cat', 'remote control']
    """
    if not labelstring:
        return []
    stripped = (part.strip() for part in labelstring.split(","))
    return [label for label in stripped if label]
def inference(img: Image.Image, labels: str, threshold: float = 0.1) -> Image.Image:
    """Detect the comma-separated text queries in *img* and draw the results.

    Args:
        img: Image from the ``gr.Image(type="pil")`` input. It is ``None``
            when nothing has been uploaded.
        labels: Raw comma-separated query string from the textbox, e.g.
            ``"cat, remote control"``.
        threshold: Minimum detection score passed to the post-processor.
            Defaults to 0.1.

    Returns:
        An RGB copy of *img* with a red box and a ``"<label>: <score>"``
        caption drawn for each detection. The input image is not modified.

    Raises:
        gr.Error: if no image was provided or no non-empty label was given.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    # Ignore blank queries; the processor cannot be called with an empty list.
    queries = [q for q in pro_process(labels) if q]
    if not queries:
        raise gr.Error("Please enter at least one label.")

    # Work on an RGB copy so the caller's image is never drawn on in place.
    canvas = img.convert("RGB")

    inputs = processor(text=queries, images=canvas, return_tensors="pt")
    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)

    # One (height, width) row per image; PIL's .size is (width, height).
    target_sizes = torch.Tensor([canvas.size[::-1]])
    result = processor.post_process_object_detection(
        outputs=outputs, target_sizes=target_sizes, threshold=threshold
    )[0]  # only one image in the batch

    draw = ImageDraw.Draw(canvas)
    for box, score, label_index in zip(result["boxes"], result["scores"], result["labels"]):
        xmin, ymin, xmax, ymax = (round(coord, 2) for coord in box.tolist())
        draw.rectangle((xmin, ymin, xmax, ymax), outline="red", width=1)
        caption = f"{queries[int(label_index)]}: {round(float(score), 2)}"
        draw.text((xmin, ymin), caption, fill="white")
    return canvas
# UI layout: the left column holds the image, the label textbox and the run
# button; the right column shows the annotated result. Clicking the button
# calls `inference(image, labels)` and writes its return value to the result.
with gr.Blocks(title="Zero-shot object detection", theme="freddyaboulton/dracula_revamped") as demo:
    gr.Markdown("## Zero-shot object detection")

    with gr.Row():
        with gr.Column():
            in_img = gr.Image(label="Input Image", type="pil")
            in_labels = gr.Textbox(label="Input labels, comma apart")
            inference_btn = gr.Button("Inference", variant="primary")
        with gr.Column():
            out_img = gr.Image(label="Result", interactive=False)

    inference_btn.click(
        fn=inference,
        inputs=[in_img, in_labels],
        outputs=[out_img],
    )
if __name__ == "__main__":
    # Turn on Gradio's request queue (queue() returns the same app), then serve.
    queued_app = demo.queue()
    queued_app.launch()