Spaces:
Sleeping
Sleeping
import gradio as gr | |
import tensorflow as tf | |
from tensorflow import keras | |
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input as preprocess_vgg | |
from tensorflow.keras.applications.efficientnet import EfficientNetB0, preprocess_input as preprocess_efficientnet | |
from tensorflow.keras.preprocessing import image | |
import numpy as np | |
# Spatial size fed to both networks; it has to agree with the size that was
# used when the models were trained.
image_size = (224, 224)
def build_vgg16_model():
    """Assemble and compile the VGG16-based binary classifier.

    The VGG16 convolutional stack is instantiated without its classification
    head and without pre-trained weights. On top of it sits a head made of
    Flatten -> Dropout(0.5) -> Dense(256, relu) -> Dropout(0.5) ->
    Dense(1, sigmoid).

    Returns:
        The compiled Keras model (binary cross-entropy loss, Adam optimizer
        with learning rate 0.001, accuracy metric).
    """
    input_shape = image_size + (3,)
    backbone = VGG16(weights=None, include_top=False, input_shape=input_shape)

    inputs = keras.layers.Input(shape=input_shape)
    # The VGG input preprocessing is part of the model graph itself.
    features = preprocess_vgg(inputs)
    # The backbone is always called with training=False.
    features = backbone(features, training=False)

    # Classification head; Flatten is used instead of GlobalAveragePooling2D.
    head_layers = [
        keras.layers.Flatten(),
        keras.layers.Dropout(rate=0.5),
        keras.layers.Dense(units=256, activation="relu"),
        keras.layers.Dropout(rate=0.5),
    ]
    for layer in head_layers:
        features = layer(features)
    outputs = keras.layers.Dense(units=1, activation="sigmoid")(features)

    model = keras.models.Model(inputs=[inputs], outputs=[outputs])
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=0.001),
        loss="binary_crossentropy",
        metrics=["accuracy"],
    )
    return model
def build_efficientnet_model():
    """Assemble and compile the EfficientNetB0-based binary classifier.

    The EfficientNetB0 backbone is instantiated without its classification
    head, using the "imagenet" weights. On top of it sits a head made of
    GlobalAveragePooling2D -> Dropout(0.5) -> Dense(256, relu) ->
    Dropout(0.5) -> Dense(1, sigmoid).

    Returns:
        The compiled Keras model (binary cross-entropy loss, Adam optimizer
        with learning rate 0.001, accuracy metric).
    """
    input_shape = image_size + (3,)
    backbone = EfficientNetB0(
        input_shape=input_shape, include_top=False, weights="imagenet"
    )

    inputs = keras.layers.Input(shape=input_shape)
    # The EfficientNet preprocessing call is part of the model graph itself.
    features = preprocess_efficientnet(inputs)
    # The backbone is always called with training=False.
    features = backbone(features, training=False)

    # Classification head.
    head_layers = [
        keras.layers.GlobalAveragePooling2D(),
        keras.layers.Dropout(rate=0.5),
        keras.layers.Dense(units=256, activation="relu"),
        keras.layers.Dropout(rate=0.5),
    ]
    for layer in head_layers:
        features = layer(features)
    outputs = keras.layers.Dense(units=1, activation="sigmoid")(features)

    model = keras.models.Model(inputs=[inputs], outputs=[outputs])
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=0.001),
        loss="binary_crossentropy",
        metrics=["accuracy"],
    )
    return model
# Registry of the selectable classifiers, keyed by model name.
models = dict(
    VGG16=build_vgg16_model(),
    EfficientNetB0=build_efficientnet_model(),
)

# Restore the fine-tuned checkpoint of each network.
models["VGG16"].load_weights("VGG16_best_finetune_checkpoint.weights.h5")
models["EfficientNetB0"].load_weights(
    "EfficientNetB0_best_finetune_checkpoint.weights.h5"
)
print("Models and weights loaded successfully!")

# VGG16 is the active model until another one is selected.
current_model = models["VGG16"]
def load_selected_model(model_name):
    """Make the named model the active one and report it.

    Rebinds the module-level ``current_model`` to ``models[model_name]``.

    Args:
        model_name: Key of the model in ``models``.

    Returns:
        A status message naming the model that is now active.

    Raises:
        KeyError: If ``model_name`` is not a key of ``models``.
    """
    global current_model
    selected = models[model_name]
    current_model = selected
    return f"Loaded {model_name} model."
def preprocess_img(img, model_name):
    """Turn an uploaded PIL image into a one-image batch of raw pixel values.

    Both models built in this file call their own ``preprocess_input`` as the
    first step of the model graph, so this function must hand over *raw*
    pixels. Calling ``preprocess_input`` here as well would preprocess the
    image twice; for VGG16 that flips the channel order back and subtracts the
    ImageNet mean a second time.

    NOTE(review): this assumes the models were trained on raw images, with
    preprocessing done only inside the model graph. Confirm against the
    training pipeline.

    Args:
        img: PIL image to classify.
        model_name: Name of the selected model. Kept for interface
            compatibility; no model-specific work is needed here any more.

    Returns:
        A float array of shape ``(1,) + image_size + (3,)``.
    """
    # Guarantee three channels even for grayscale / RGBA / palette images.
    if img.mode != "RGB":
        img = img.convert("RGB")
    img = img.resize(image_size)  # Resize to match training
    img_array = image.img_to_array(img)
    # Add the batch dimension expected by ``Model.predict``.
    return np.expand_dims(img_array, axis=0)
def predict(img, model_name):
    """Classify a tyre image with the model selected in the dropdown.

    Args:
        img: PIL image from the upload widget, or ``None`` when no image is
            present (nothing uploaded yet, or the input was cleared).
        model_name: Key of the model in ``models`` to run.

    Returns:
        A multi-line text with the predicted class and the confidence for
        both classes, or a short hint when no image was supplied.
    """
    # Without this guard a missing image fails on ``img.resize``.
    if img is None:
        return "Please upload an image before classifying."

    # Run the model that was actually requested instead of relying on the
    # shared global, which is only updated by the dropdown change event and
    # can therefore disagree with ``model_name``. Unknown names fall back to
    # the globally active model.
    model = models.get(model_name, current_model)

    img_array = preprocess_img(img, model_name)
    predictions = model.predict(img_array)
    confidence = float(predictions[0][0])  # Sigmoid output as a plain float

    # The sigmoid output is reported as the confidence for "Good" (label 1).
    class_labels = {0: "Defective", 1: "Good"}
    predicted_class = 1 if confidence > 0.5 else 0

    return (
        f"Predicted Class: {class_labels[predicted_class]}\n"
        f"Confidence (Good): {confidence:.8f}\n"
        f"Confidence (Defective): {1 - confidence:.8f}"
    )
def clear():
    """Return the reset values for the image input and the output textbox."""
    cleared_image = None
    cleared_text = ""
    return cleared_image, cleared_text
# Gradio UI: description, model picker, image upload, result box and buttons.
with gr.Blocks() as interface:
    gr.Markdown(
        "## Fine-tuned Defect Tyre Classification\n\n"
        "This Gradio-based application allows users to classify tyre images as **'Good'** or **'Defective'** using a "
        "fine-tuned deep learning model. Users can select between two models (**VGG16** and **EfficientNetB0**) "
        "via a dropdown menu. The selected model is dynamically loaded and applied to the uploaded image, "
        "with predictions and confidence scores displayed as output.\n\n"
        "Upload a tyre image and select a model (**VGG16** or **EfficientNetB0**) to classify defects."
    )

    # Model picker; its choices are the keys of the model registry.
    model_dropdown = gr.Dropdown(
        label="Select Model",
        value="VGG16",
        choices=list(models),
    )
    model_status = gr.Textbox(interactive=False, label="Model Status")

    # Image on the left, classification text on the right.
    with gr.Row():
        input_image = gr.Image(label="Upload Image", type="pil")
        output_text = gr.Textbox(interactive=False, label="Classification Output")

    with gr.Row():
        classify_button = gr.Button("Classify")
        clear_button = gr.Button("Clear")

    # Changing the dropdown switches the active model and shows the status.
    model_dropdown.change(
        fn=load_selected_model,
        inputs=model_dropdown,
        outputs=model_status,
    )

    # Button wiring.
    classify_button.click(
        fn=predict,
        inputs=[input_image, model_dropdown],
        outputs=output_text,
    )
    clear_button.click(
        fn=clear,
        inputs=[],
        outputs=[input_image, output_text],
    )

# Start the app only when the file is executed as a script.
if __name__ == "__main__":
    interface.launch()