|
#!/usr/bin/env python |
|
# coding: utf-8 |
|
|
|
|
|
from ultralytics import YOLO |
|
from PIL import Image, ImageDraw, ImageFont |
|
import gradio as gr |
|
from huggingface_hub import snapshot_download |
|
import os |
|
from torchvision import transforms |
|
|
|
# Class index -> human-readable label. predict() treats a prediction whose
# name equals classes[0] as a defect.
classes = dict(enumerate(("Defective", "Good")))

# Local weights file loaded by load_model_local().
# Alternative that was tried here: "best_int8_openvino_model" (the directory
# name load_model() looks for inside a Hugging Face Hub snapshot).
model_path = "./best.pt"
|
|
|
def load_model_local():
    """Load the YOLO classifier from the local weights file ``model_path``.

    Returns:
        The ``YOLO`` model object, constructed with ``task='classify'``.
    """
    return YOLO(model_path, task='classify')
|
|
|
def load_model(repo_id):
    """Download a Hugging Face Hub repo snapshot and load the classifier from it.

    The model is read from the ``best_int8_openvino_model`` subdirectory of
    the downloaded snapshot (presumably an INT8 OpenVINO export, judging by
    the name; confirm against the repo contents).

    Args:
        repo_id: Hugging Face Hub repository id passed to ``snapshot_download``.

    Returns:
        The ``YOLO`` model object, constructed with ``task='classify'``.
    """
    snapshot_dir = snapshot_download(repo_id)
    print(snapshot_dir)

    weights_dir = os.path.join(snapshot_dir, "best_int8_openvino_model")
    print(weights_dir)

    return YOLO(weights_dir, task='classify')
|
|
|
# Bold TrueType font for the result label. This path is specific to Linux
# hosts with DejaVu installed; _load_label_font() falls back when it is missing.
_LABEL_FONT_PATH = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
_LABEL_FONT_SIZE = 30
# Preferred label position in pixels (the original layout).
_LABEL_X = 300
_LABEL_Y = 10
# X position used instead when the image is too narrow for _LABEL_X.
_LABEL_FALLBACK_X = 10


def _load_label_font():
    """Return the bold label font, or PIL's built-in default font as a fallback.

    ``ImageFont.truetype`` raises ``OSError`` when the font file cannot be
    opened. Annotating with the default font is preferable to failing the
    whole request on such a host.
    """
    try:
        return ImageFont.truetype(_LABEL_FONT_PATH, _LABEL_FONT_SIZE)
    except OSError:
        return ImageFont.load_default()


def predict(pilimg):
    """Classify a cap image and return a copy annotated with the predicted class.

    Uses the module-level ``detection_model`` for inference.

    Args:
        pilimg: Input PIL image (from the ``gr.Image(type="pil")`` component).
            It is converted to RGB first, so grayscale/RGBA/palette images
            are handled as well.

    Returns:
        An RGB copy of the input with the predicted class name drawn on it.
        The label is red when the class equals ``classes[0]`` and green
        otherwise. The caller's image is not modified.

    Raises:
        gr.Error: If no image was supplied.
    """
    if pilimg is None:
        # Submitting the form without an image yields None; report it in the
        # UI instead of crashing inside the transform.
        raise gr.Error("Please upload an image first.")

    # ToTensor() emits one channel per image band, so RGBA / "L" / "P" inputs
    # would not give the 3-channel tensor the model expects. convert() also
    # returns a new image, which we reuse as the annotation canvas.
    rgb_image = pilimg.convert("RGB")

    # Resize to 224x224 (one of the input sizes the YOLO classifier accepts)
    # and convert to a float tensor.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    source = transform(rgb_image)

    result = detection_model.predict(source)
    probs = result[0].probs

    # top1 is the index of the most likely class; names maps index -> name.
    classified_type = detection_model.names[probs.top1]
    print(">>> Class : ", classified_type)

    confidence = probs.top1conf  # confidence of the top class
    print(">>> Confidence : ", confidence)

    is_defective = classified_type == classes[0]

    draw = ImageDraw.Draw(rgb_image)
    # On images no wider than _LABEL_X the label would land entirely
    # off-canvas, so move it to the left margin instead.
    text_x = _LABEL_X if rgb_image.width > _LABEL_X else _LABEL_FALLBACK_X
    draw.text(
        (text_x, _LABEL_Y),
        classified_type,
        fill="red" if is_defective else "green",
        font=_load_label_font(),
    )

    if is_defective:
        gr.Warning("Defect detected, BAD!.")
    else:
        gr.Info("No defect detected,GOOD!")

    return rgb_image
|
|
|
# Load the classifier once at import time; predict() reads this module-level name.
detection_model = load_model_local()

# Gradio UI: one PIL image in, the annotated PIL image out.
title = "Detect the status of the cap, DEFECTIVE or GOOD"
interface = gr.Interface(
    title=title,
    fn=predict,
    inputs=gr.Image(label="Input Image", type="pil"),
    outputs=gr.Image(label="Classification result", type="pil"),
)

# Start the web app; share=True asks Gradio to also create a public share link.
interface.launch(share=True)
|
|
|
|
|
|