#!/usr/bin/env python
# coding: utf-8
"""Gradio demo that labels a cap image as GOOD or DEFECTIVE.

Pipeline implemented below:
1. A Keras ResNet50-based binary classifier scores the image.
2. If the image is classified "Defective", a YOLOv8 detector is run and its
   detections are drawn onto the image.
3. The annotated image and a short summary string are returned to the UI.
"""
import os

import tensorflow as tf
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
from tensorflow.keras.preprocessing import image
from ultralytics import YOLO
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
from huggingface_hub import snapshot_download
from torchvision import transforms

# Class labels of the binary classifier: index 1 when score >= threshold, else 0.
classes = {0: "Defective", 1: "Good"}

model_path = "ResNet50_Classification.h5"  # Trained ResNet50 classification model
best_yolo_model = "best.pt"  # Trained YOLOv8 detection model

# Bold TrueType font for the overlay label. If it cannot be opened on this
# host, PIL's built-in default font is used instead (see _load_label_font).
FONT_PATH = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
FONT_SIZE = 30

classification_model = tf.keras.models.load_model(model_path)
detection_model = YOLO(best_yolo_model, task="detect")


def _load_label_font(size=FONT_SIZE):
    """Return the bold label font, falling back to PIL's default font.

    ImageFont.truetype raises OSError when the font file cannot be opened
    (e.g. the DejaVu path does not exist on this machine); the demo should
    still run in that case, just with a plainer label.
    """
    try:
        return ImageFont.truetype(FONT_PATH, size)
    except OSError:
        return ImageFont.load_default()


# Loaded once at import instead of on every request.
LABEL_FONT = _load_label_font()


def preprocess_image(pilimg):
    """Convert a PIL image into a one-image batch for the ResNet50 classifier.

    Args:
        pilimg: Input PIL image of any mode; it is converted to RGB so the
            array always has 3 channels (RGBA/grayscale/palette inputs would
            otherwise produce 4 or 1 channels).

    Returns:
        numpy array of shape (1, 224, 224, 3) with the default channels-last
        Keras configuration; pixel values are left in the 0-255 range.
    """
    img = pilimg.convert("RGB").resize((224, 224))  # ResNet50 input size
    img_array = image.img_to_array(img)
    # NOTE(review): resnet50.preprocess_input is imported but never applied.
    # Confirm the saved model was trained on raw 0-255 pixels (or contains its
    # own preprocessing layer) before changing this.
    return np.expand_dims(img_array, axis=0)  # add batch dimension


def classify_image(pilimg, threshold=0.5):
    """Classify *pilimg* as "Good" or "Defective".

    Args:
        pilimg: Input PIL image.
        threshold: Scores >= threshold map to classes[1] ("Good"), lower
            scores to classes[0] ("Defective").

    Returns:
        The predicted label string from ``classes``.
    """
    img_array = preprocess_image(pilimg)
    # First output unit of the first (only) batch item.
    classify_result = classification_model.predict(img_array)[0][0]
    print(">>> Result : ", classify_result)
    predicted_class = classes[int(classify_result >= threshold)]
    print(">>> predicted_class : ", predicted_class)
    return predicted_class


def detect_defect(img, conf=0.4, iou=0.5):
    """Run the YOLOv8 detector on *img* and return its result list.

    Args:
        img: Image accepted by ``YOLO.predict`` (a PIL image here).
        conf: Minimum confidence for a detection to be kept.
        iou: IoU threshold used by non-maximum suppression.
    """
    return detection_model.predict(img, conf=conf, iou=iou)


def _draw_label(pil_img, text, position, color):
    """Draw *text* onto *pil_img* in place with the shared label font; return it."""
    ImageDraw.Draw(pil_img).text(position, text, fill=color, font=LABEL_FONT)
    return pil_img


def process_image(pilimg):
    """Classify the image and, if it is defective, localize the defect.

    Args:
        pilimg: PIL image supplied by the Gradio input, or None when the
            user submitted without uploading anything.

    Returns:
        Tuple of (annotated PIL image or None, summary string), matching the
        two Gradio outputs.
    """
    if pilimg is None:
        return None, "Please upload an image first."

    # Classification first; detection only runs for "Defective" images.
    if classify_image(pilimg) == "Good":
        out_pilimg = _draw_label(pilimg.convert("RGB"), "Good", (250, 10), "green")
        return out_pilimg, "No defect is detected, the cap is GOOD!"

    detection_result = detect_defect(pilimg)
    img_bgr = detection_result[0].plot()
    # plot() output is BGR; reverse the channel axis to get an RGB PIL image.
    out_pilimg = Image.fromarray(img_bgr[..., ::-1])
    _draw_label(out_pilimg, "Defective", (300, 10), "red")

    detections = detection_result[0].boxes.data
    if len(detections) > 0:
        summary_str = "Defect is detected, the cap is BAD!"
    else:
        summary_str = "The cap is classified as Defective but the defect cannot be detected!"
    return out_pilimg, summary_str


title = "Detect the status of the cap: DEFECTIVE or GOOD"

interface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=[
        gr.Image(type="pil", label="Classification/Detection result"),
        gr.Markdown(label="Classification/Detection Summary"),
    ],
    title=title,
)

# Launch the interface
interface.launch(share=True)