#!/usr/bin/env python
# coding: utf-8
"""Gradio demo: classify an image of a cap as "Defective" or "Good".

A YOLO classification model is loaded once at import time and served through
a ``gr.Interface`` that returns the input image annotated with the predicted
class name.
"""

import functools
import os

import gradio as gr
from huggingface_hub import snapshot_download
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms
from ultralytics import YOLO

# Index -> label mapping. Only ``classes[0]`` ("Defective") is consulted below,
# to decide whether a prediction counts as a defect.
classes = {0: "Defective", 1: "Good"}

# Local PyTorch checkpoint used by ``load_model_local``.
# (An INT8 OpenVINO export, "best_int8_openvino_model", can be loaded the
# same way; see ``load_model``.)
model_path = "./best.pt"

# Resize to 224x224 (noted by the original author as one of the input sizes
# the YOLO classification model accepts), then convert to a float tensor.
# Built once here rather than on every request.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

_FONT_PATH = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
_FONT_SIZE = 30
_LABEL_X, _LABEL_Y = 300, 10  # preferred top-left corner of the label text
_LABEL_MARGIN = 10  # right-hand padding kept when the label is shifted left


def load_model_local():
    """Load the YOLO classification model from the local checkpoint ``model_path``."""
    return YOLO(model_path, task='classify')


def load_model(repo_id):
    """Download *repo_id* from the Hugging Face Hub and load its OpenVINO export.

    The downloaded snapshot is expected to contain a
    ``best_int8_openvino_model`` directory. The snapshot directory and the
    resolved model path are both printed for debugging.
    """
    download_dir = snapshot_download(repo_id)
    print(download_dir)
    path = os.path.join(download_dir, "best_int8_openvino_model")
    print(path)
    return YOLO(path, task='classify')


@functools.lru_cache(maxsize=1)
def _load_label_font():
    """Return the font used for the class label, loading it only once.

    The bold DejaVu font exists only at that path on some Linux systems.
    Without the fallback to Pillow's built-in font, a missing file would make
    every prediction fail with ``OSError``.
    """
    try:
        return ImageFont.truetype(_FONT_PATH, _FONT_SIZE)
    except OSError:
        return ImageFont.load_default()


def _label_position(draw, image, text, font):
    """Return the top-left corner at which to draw *text* on *image*.

    The preferred (300, 10) position is used when the text fits. On narrower
    images the text is shifted left so it stays on the canvas.
    """
    left, _, right, _ = draw.textbbox((0, 0), text, font=font)
    text_width = right - left
    x = min(_LABEL_X, max(0, image.width - text_width - _LABEL_MARGIN))
    return x, _LABEL_Y


def predict(pilimg):
    """Classify *pilimg* and return an RGB copy annotated with the class name.

    Args:
        pilimg: PIL image from the Gradio input, or ``None`` when the user
            submitted without uploading an image.

    Returns:
        An RGB copy of the input with the predicted class name drawn on it:
        red when the class equals ``classes[0]``, green otherwise.

    Raises:
        gr.Error: if no image was provided.
    """
    if pilimg is None:
        raise gr.Error("Please upload an image first.")

    # Convert once, up front. ToTensor keeps the image's channel count, so an
    # RGBA or grayscale input would otherwise produce a 4- or 1-channel tensor.
    # ``convert`` returns a new image, so drawing on it later leaves the
    # caller's image untouched.
    annotated_image = pilimg.convert("RGB")
    source = _PREPROCESS(annotated_image)

    result = detection_model.predict(source)
    probs = result[0].probs
    label = probs.top1  # index of the top-scoring class
    classified_type = detection_model.names[label]  # index -> class name
    print(">>> Class : ", classified_type)
    confidence = probs.top1conf  # score of the top-scoring class
    print(">>> Confidence : ", float(confidence))

    is_defective = classified_type == classes[0]
    font = _load_label_font()
    draw = ImageDraw.Draw(annotated_image)
    position = _label_position(draw, annotated_image, classified_type, font)
    draw.text(
        position,
        classified_type,
        fill="red" if is_defective else "green",
        font=font,
    )

    if is_defective:
        gr.Warning("Defect detected, BAD!.")
    else:
        gr.Info("No defect detected,GOOD!")
    return annotated_image


detection_model = load_model_local()

title = "Detect the status of the cap, DEFECTIVE or GOOD"
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=gr.Image(type="pil", label="Classification result"),
    title=title,
)

# Launch the interface
interface.launch(share=True)