wookimchye's picture
Rename app.py to app.py.Bak
6ee5646 verified
#!/usr/bin/env python
# coding: utf-8
from ultralytics import YOLO
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
from huggingface_hub import snapshot_download
import os
from torchvision import transforms
# Class-index -> label mapping; predict() compares the predicted class name
# against classes[0] to decide how the result is reported.
classes = {
    0: "Defective",
    1: "Good",
}

# Weights file handed to YOLO by load_model_local().
# NOTE: "best_int8_openvino_model" is an alternative path that was left
# disabled here.
model_path = "./best.pt"
def load_model_local():
    """Create the YOLO classifier from the module-level ``model_path``.

    Returns:
        The ``YOLO`` instance constructed with ``task='classify'``.
    """
    return YOLO(model_path, task='classify')
def load_model(repo_id):
    """Download a Hub repository snapshot and load the classifier inside it.

    Args:
        repo_id: Repository identifier passed to ``snapshot_download``.

    Returns:
        A ``YOLO`` instance created with ``task='classify'`` from the
        ``best_int8_openvino_model`` entry of the downloaded snapshot.
    """
    snapshot_dir = snapshot_download(repo_id)
    print(snapshot_dir)

    weights_dir = os.path.join(snapshot_dir, "best_int8_openvino_model")
    print(weights_dir)

    return YOLO(weights_dir, task='classify')
def _load_label_font(size=30):
    """Return the bold DejaVu font at *size*, or Pillow's default font if it is missing."""
    font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
    try:
        return ImageFont.truetype(font_path, size)
    except OSError:
        # The font file is not installed on every host; degrade to the
        # built-in font instead of failing the whole request.
        return ImageFont.load_default()


def _label_origin(draw, text, font, image_width, preferred_x=300, y=10, margin=10):
    """Return the (x, y) where *text* is drawn, kept inside the image width."""
    text_width = draw.textlength(text, font=font)
    # Keep the original x=300 when the text fits; otherwise shift it left so
    # it stays visible on narrow images (never further left than x=0).
    rightmost_x = image_width - text_width - margin
    return (max(0, min(preferred_x, rightmost_x)), y)


def predict(pilimg):
    """Classify a cap image and return a copy annotated with the predicted class.

    Args:
        pilimg: PIL image supplied by the Gradio input component, or ``None``
            when the form is submitted without an image.

    Returns:
        An RGB copy of the input with the predicted class name drawn on it:
        red when the class equals ``classes[0]``, green otherwise.

    Raises:
        gr.Error: If no image was provided.
    """
    if pilimg is None:
        raise gr.Error("Please provide an image to classify.")

    # convert() returns a new image, so the caller's image is never modified,
    # and the model always receives three channels (grayscale/RGBA inputs too).
    rgb_image = pilimg.convert("RGB")

    # Resize to 224x224 and convert to a float tensor before inference.
    to_model_input = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    # Tensor sources are documented as BCHW, so add the batch dimension
    # explicitly instead of relying on the library's warn-and-fix path.
    batch = to_model_input(rgb_image).unsqueeze(0)

    result = detection_model.predict(batch)  # Make prediction

    # Top-1 class index, its name in the model's own mapping, and confidence.
    probs = result[0].probs
    classified_type = detection_model.names[probs.top1]
    print (">>> Class : ", classified_type)
    confidence = probs.top1conf
    print(">>> Confidence : ", confidence)

    annotated_image = rgb_image
    draw = ImageDraw.Draw(annotated_image)
    font = _load_label_font(30)
    origin = _label_origin(draw, classified_type, font, annotated_image.width)

    if classified_type == classes[0]:
        draw.text(origin, classified_type, fill="red", font=font)
        gr.Warning("Defect detected, BAD!.")
    else:
        draw.text(origin, classified_type, fill="green", font=font)
        gr.Info("No defect detected,GOOD!")

    return annotated_image
# Load the classifier once at start-up; predict() reads this module-level name.
detection_model = load_model_local()

title = "Detect the status of the cap, DEFECTIVE or GOOD"

# One PIL image in, one annotated PIL image out.
input_component = gr.Image(type="pil", label="Input Image")
output_component = gr.Image(type="pil", label="Classification result")

interface = gr.Interface(
    fn=predict,
    inputs=input_component,
    outputs=output_component,
    title=title,
)

# Start the web app with Gradio's share option enabled.
interface.launch(share=True)