"""Gradio app that classifies tyre images as "Good" or "Defective".

A Keras model with a single sigmoid output is downloaded from the
Hugging Face Hub and served through a Gradio web UI.
"""

import os

import gradio as gr
import numpy as np
import tensorflow as tf
from huggingface_hub import snapshot_download
from tensorflow.keras.preprocessing import image

# Hugging Face Model Repository ID
REPO_ID = "skngew/9053220B"  # my actual repo ID

# Index 0 is chosen when the sigmoid score is below THRESHOLD, index 1 when it
# is at or above it, so the model's raw score is the score for "Good Tyre".
CLASS_NAMES = ("Defective Tyre", "Good Tyre")
THRESHOLD = 0.5


# Load the model from Hugging Face Hub
def load_model(repo_id, filename="full_model.weights.h5"):
    """Download *repo_id* from the Hugging Face Hub and load a Keras model.

    Args:
        repo_id: Hugging Face Hub repository ID, e.g. ``"user/repo"``.
        filename: Model file inside the downloaded repository snapshot.

    Returns:
        The model returned by ``tf.keras.models.load_model``.
    """
    download_dir = snapshot_download(repo_id)
    model_path = os.path.join(download_dir, filename)
    # NOTE(review): ``load_model`` needs a file containing the full model
    # (architecture + weights). In Keras 3 the ``.weights.h5`` suffix usually
    # denotes a weights-only file — confirm what this file actually stores.
    return tf.keras.models.load_model(model_path)


# Function to preprocess the uploaded image
def preprocess_image(img, target_size=(224, 224)):
    """Convert a PIL image into a batch of one, ready for ``model.predict``.

    Args:
        img: PIL image (as delivered by ``gr.Image(type="pil")``).
        target_size: ``(width, height)`` passed to ``PIL.Image.resize``.

    Returns:
        A float array of shape ``(1, height, width, 3)`` after EfficientNet
        preprocessing.
    """
    # Force three channels: RGBA (e.g. PNG with alpha) or grayscale uploads
    # would otherwise produce 4- or 1-channel arrays; the model presumably
    # expects RGB input (EfficientNetB0) — confirm against the saved model.
    img = img.convert("RGB")
    img = img.resize(target_size)  # Resize to match model input
    img_array = image.img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)  # Add batch dimension
    return tf.keras.applications.efficientnet.preprocess_input(img_array)


# Perform inference
def predict(image_input):
    """Classify a single uploaded tyre image.

    Args:
        image_input: PIL image from the upload widget, or ``None`` when the
            user submits without uploading anything.

    Returns:
        A human-readable prediction string for the output textbox.
    """
    if image_input is None:
        return "Please upload an image."

    img_array = preprocess_image(image_input)
    # Scalar sigmoid output: the score for "Good Tyre" (class index 1).
    good_score = float(model.predict(img_array, verbose=0)[0][0])
    predicted_class_idx = int(good_score >= THRESHOLD)
    predicted_class = CLASS_NAMES[predicted_class_idx]
    # Report confidence in the *predicted* class so a clear "Defective"
    # result is not displayed with a near-zero confidence value.
    confidence = good_score if predicted_class_idx == 1 else 1.0 - good_score
    return (
        f"Predicted Class: {predicted_class} "
        f"(Confidence: {confidence:.5f}, Good score: {good_score:.5f})"
    )


# Load the model once at startup; ``predict`` reads this module-level name.
model = load_model(REPO_ID)

# Student ID
student_id = "Student ID: 9053220B"

# Markdown description to show classification threshold
threshold_info = f"""
### EfficientNetB0 (Feature Extraction)

### Classification Threshold:
- The model outputs a **Good score** between 0 and 1.
- A tyre is classified as **Good** if the Good score is **≥ {THRESHOLD}**.
- A tyre is classified as **Defective** if the Good score is **< {THRESHOLD}**.
- **Confidence** is the probability of the predicted class: the Good score for Good tyres, and 1 − Good score for Defective tyres.
"""

# Create the Gradio interface
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=gr.Textbox(label="Prediction"),
    title="Binary Classification: Good vs. Defective Tyre",
    description=student_id,
    # NOTE(review): renamed to ``flagging_mode`` in Gradio 5; switch if the
    # installed Gradio version warns about this argument.
    allow_flagging="never",
    examples=[],  # Can add examples here
)

# Show the threshold information markdown above the interface
with gr.Blocks() as app:
    gr.Markdown(threshold_info)  # Display threshold info
    interface.render()

# Launch the Gradio app only when run as a script, not on import.
if __name__ == "__main__":
    app.launch(share=True)