import gradio as gr
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input as preprocess_vgg
from tensorflow.keras.applications.efficientnet import EfficientNetB0, preprocess_input as preprocess_efficientnet
from tensorflow.keras.preprocessing import image
import numpy as np

# Define input image size (must match training settings)
image_size = (224, 224)


def _add_classifier_head(features):
    """Attach the binary-classification head shared by both models.

    Dropout(0.5) -> Dense(256, relu) -> Dropout(0.5) -> Dense(1, sigmoid).
    The layer types and their order must not change, otherwise the saved
    ``*.weights.h5`` checkpoints will no longer match the model.
    """
    x = keras.layers.Dropout(rate=0.5)(features)
    x = keras.layers.Dense(units=256, activation="relu")(x)
    x = keras.layers.Dropout(rate=0.5)(x)
    return keras.layers.Dense(units=1, activation="sigmoid")(x)


def _compile_binary_classifier(model):
    """Compile *model* with binary cross-entropy / Adam and return it."""
    model.compile(
        loss="binary_crossentropy",
        optimizer=keras.optimizers.Adam(learning_rate=0.001),
        metrics=["accuracy"],
    )
    return model


def build_vgg16_model(weights=None):
    """Build and compile the VGG16-based binary classifier.

    Args:
        weights: Forwarded to ``VGG16``. The default ``None`` (random
            initialisation) is enough when a fine-tuned checkpoint is loaded
            afterwards with ``load_weights``.

    Returns:
        A compiled ``keras.Model`` taking images of shape
        ``image_size + (3,)`` and producing one sigmoid output per image.
    """
    base_model = VGG16(weights=weights, include_top=False, input_shape=image_size + (3,))

    inputs = keras.layers.Input(shape=image_size + (3,))
    # VGG16 preprocessing is part of the graph: callers must feed
    # un-preprocessed pixels and must NOT apply preprocess_vgg again.
    x = preprocess_vgg(inputs)
    x = base_model(x, training=False)
    # Flatten instead of GlobalAveragePooling2D: the next Dense kernel's
    # shape depends on this choice, so it must match the checkpoint.
    x = keras.layers.Flatten()(x)
    outputs = _add_classifier_head(x)

    model = keras.models.Model(inputs=[inputs], outputs=[outputs])
    return _compile_binary_classifier(model)


def build_efficientnet_model(weights="imagenet"):
    """Build and compile the EfficientNetB0-based binary classifier.

    Args:
        weights: Forwarded to ``EfficientNetB0``. Defaults to ``"imagenet"``
            (unchanged behaviour); pass ``None`` to skip the ImageNet
            download when a fine-tuned checkpoint is loaded over it anyway.

    Returns:
        A compiled ``keras.Model`` taking images of shape
        ``image_size + (3,)`` and producing one sigmoid output per image.
    """
    base_model = EfficientNetB0(input_shape=image_size + (3,), include_top=False, weights=weights)

    inputs = keras.layers.Input(shape=image_size + (3,))
    # EfficientNet includes its own preprocessing; Keras documents this
    # function as a pass-through, so it is harmless to keep here.
    x = preprocess_efficientnet(inputs)
    x = base_model(x, training=False)
    x = keras.layers.GlobalAveragePooling2D()(x)
    outputs = _add_classifier_head(x)

    model = keras.models.Model(inputs=[inputs], outputs=[outputs])
    return _compile_binary_classifier(model)
# Dictionary to store models
models = {
    "VGG16": build_vgg16_model(),
    "EfficientNetB0": build_efficientnet_model(),
}

# Load pre-trained weights
models["VGG16"].load_weights("VGG16_best_finetune_checkpoint.weights.h5")
models["EfficientNetB0"].load_weights("EfficientNetB0_best_finetune_checkpoint.weights.h5")
print("Models and weights loaded successfully!")

# Class index -> label for the single sigmoid output (index 1 == "Good").
CLASS_LABELS = {0: "Defective", 1: "Good"}

# Set the default model
current_model = models["VGG16"]


def load_selected_model(model_name):
    """Record *model_name* as the current model and return a status message.

    NOTE: ``current_model`` is a module-level global and is therefore shared
    by every Gradio session. ``predict`` does not rely on it; it picks the
    model from the ``model_name`` passed with each request instead.
    """
    global current_model
    current_model = models[model_name]
    return f"Loaded {model_name} model."


def preprocess_img(img, model_name):
    """Turn a PIL image into a ``(1, H, W, 3)`` float batch for the models.

    Both models already apply their ``preprocess_input`` inside the Keras
    graph (see ``build_vgg16_model`` / ``build_efficientnet_model``), so no
    model-specific preprocessing is done here. Applying it again would
    preprocess VGG16 inputs twice (channel flip + mean subtraction twice).

    Args:
        img: A PIL image as delivered by ``gr.Image(type="pil")``.
        model_name: Accepted for backward compatibility; not used.
    """
    # Force 3 channels: uploads may be RGBA (PNG with alpha) or greyscale,
    # which would otherwise yield a 4- or 1-channel array.
    img = img.convert("RGB").resize(image_size)  # Resize to match training
    img_array = image.img_to_array(img)
    return np.expand_dims(img_array, axis=0)  # Add batch dimension


def predict(img, model_name):
    """Classify *img* with the model selected by *model_name*.

    Returns:
        A multi-line string with the predicted class and both confidences,
        or a short prompt if no image was uploaded.
    """
    if img is None:
        return "Please upload an image first."

    # Choose the model per request instead of via the shared global, so
    # concurrent users with different dropdown selections don't interfere.
    model = models.get(model_name, current_model)

    img_array = preprocess_img(img, model_name)
    predictions = model.predict(img_array, verbose=0)
    confidence = float(predictions[0][0])  # Sigmoid output for "Good"

    predicted_class = 1 if confidence > 0.5 else 0
    return (
        f"Predicted Class: {CLASS_LABELS[predicted_class]}\n"
        f"Confidence (Good): {confidence:.8f}\n"
        f"Confidence (Defective): {1 - confidence:.8f}"
    )


def clear():
    """Reset the image input and the classification output."""
    return None, ""


# Gradio Interface with Model Selection
with gr.Blocks() as interface:
    gr.Markdown(
        "## Fine-tuned Defect Tyre Classification\n\n"
        "This Gradio-based application allows users to classify tyre images as **'Good'** or **'Defective'** using a "
        "fine-tuned deep learning model. Users can select between two models (**VGG16** and **EfficientNetB0**) "
        "via a dropdown menu. The selected model is dynamically loaded and applied to the uploaded image, "
        "with predictions and confidence scores displayed as output.\n\n"
        "Upload a tyre image and select a model (**VGG16** or **EfficientNetB0**) to classify defects."
    )

    # Dropdown to select model
    model_dropdown = gr.Dropdown(
        choices=list(models.keys()), value="VGG16", label="Select Model"
    )
    model_status = gr.Textbox(label="Model Status", interactive=False)

    with gr.Row():
        input_image = gr.Image(type="pil", label="Upload Image")
        output_text = gr.Textbox(label="Classification Output", interactive=False)

    with gr.Row():
        classify_button = gr.Button("Classify")
        clear_button = gr.Button("Clear")

    # Link dropdown to model selection function
    model_dropdown.change(fn=load_selected_model, inputs=model_dropdown, outputs=model_status)

    # Link buttons to functions
    classify_button.click(fn=predict, inputs=[input_image, model_dropdown], outputs=output_text)
    clear_button.click(fn=clear, inputs=[], outputs=[input_image, output_text])


# Run the app
if __name__ == "__main__":
    interface.launch()