# ITI110-2024S2 / app.py
# EdwinLH's picture
# Update app.py
# f40f30c verified
import gradio as gr
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input as preprocess_vgg
from tensorflow.keras.applications.efficientnet import EfficientNetB0, preprocess_input as preprocess_efficientnet
from tensorflow.keras.preprocessing import image
import numpy as np
# Define input image size (must match training settings)
# Square (224, 224): the (width, height) order PIL's resize expects and the
# (height, width) order Keras uses coincide, so one tuple serves both.
image_size = (224,) * 2
# Function to build the VGG16 model
def build_vgg16_model():
    """Assemble and compile the VGG16-based binary tyre classifier.

    ``preprocess_vgg`` is applied inside the graph, so the model is meant to
    receive un-preprocessed pixel values. The backbone is created with
    ``weights=None``; presumably the fine-tuned checkpoint loaded afterwards
    supplies all weights (confirm against the loading code).

    Returns:
        A compiled ``keras.models.Model`` mapping an ``image_size + (3,)``
        image to one sigmoid score.
    """
    input_shape = (*image_size, 3)
    backbone = VGG16(weights=None, include_top=False, input_shape=input_shape)
    image_in = keras.layers.Input(shape=input_shape)
    net = backbone(preprocess_vgg(image_in), training=False)
    # Classification head; Flatten is used instead of GlobalAveragePooling2D.
    for layer in (
        keras.layers.Flatten(),
        keras.layers.Dropout(rate=0.5),
        keras.layers.Dense(units=256, activation="relu"),
        keras.layers.Dropout(rate=0.5),
    ):
        net = layer(net)
    score = keras.layers.Dense(units=1, activation="sigmoid")(net)
    model = keras.models.Model(inputs=[image_in], outputs=[score])
    # Compile the model (binary cross-entropy to match the single sigmoid unit).
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=0.001),
        loss="binary_crossentropy",
        metrics=["accuracy"],
    )
    return model
# Function to build the EfficientNetB0 model
def build_efficientnet_model():
    """Assemble and compile the EfficientNetB0-based binary tyre classifier.

    The backbone requests ImageNet weights (``weights="imagenet"``); these are
    presumably overwritten by the fine-tuned checkpoint loaded afterwards.
    Feature maps are reduced with global average pooling before a 256-unit
    dense head that ends in a single sigmoid unit.

    Returns:
        A compiled ``keras.models.Model`` taking ``image_size + (3,)`` images.
    """
    input_shape = (*image_size, 3)
    backbone = EfficientNetB0(input_shape=input_shape, include_top=False, weights="imagenet")
    image_in = keras.layers.Input(shape=input_shape)
    # EfficientNet includes its own preprocessing
    net = backbone(preprocess_efficientnet(image_in), training=False)
    for layer in (
        keras.layers.GlobalAveragePooling2D(),
        keras.layers.Dropout(rate=0.5),
        keras.layers.Dense(units=256, activation="relu"),
        keras.layers.Dropout(rate=0.5),
    ):
        net = layer(net)
    score = keras.layers.Dense(units=1, activation="sigmoid")(net)
    model = keras.models.Model(inputs=[image_in], outputs=[score])
    # Compile the model (binary cross-entropy to match the single sigmoid unit).
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=0.001),
        loss="binary_crossentropy",
        metrics=["accuracy"],
    )
    return model
# Dictionary to store models, keyed by the label shown in the dropdown.
models = {
    "VGG16": build_vgg16_model(),
    "EfficientNetB0": build_efficientnet_model(),
}
# Load pre-trained weights: restore each architecture from its checkpoint file.
for _arch, _weights_file in (
    ("VGG16", "VGG16_best_finetune_checkpoint.weights.h5"),
    ("EfficientNetB0", "EfficientNetB0_best_finetune_checkpoint.weights.h5"),
):
    models[_arch].load_weights(_weights_file)
del _arch, _weights_file  # keep the loop variables out of the module namespace
print("Models and weights loaded successfully!")
# VGG16 is the active model until the user picks another one.
current_model = models["VGG16"]
# Function to update the current model based on selection
def load_selected_model(model_name):
    """Make ``model_name`` the module-wide active model and report it.

    Args:
        model_name: Key into ``models`` (one of the dropdown choices).

    Returns:
        Status message for the "Model Status" textbox.

    Raises:
        KeyError: If ``model_name`` is not a key of ``models``.
    """
    global current_model
    chosen = models[model_name]
    current_model = chosen
    status = f"Loaded {model_name} model."
    return status
# Preprocessing function
def preprocess_img(img, model_name):
    """Turn an uploaded PIL image into a one-image batch for the models.

    Both ``build_vgg16_model`` and ``build_efficientnet_model`` apply their
    backbone's ``preprocess_input`` inside the graph, so this function only
    resizes and batches the image. It deliberately does not preprocess again.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` upload widget.
        model_name: Name of the selected model. Kept for backward
            compatibility; the preprocessing no longer depends on it.

    Returns:
        Float array of shape ``(1,) + image_size + (3,)`` holding the raw
        (unscaled) RGB pixel values.
    """
    # Uploads may be RGBA (e.g. PNG) or grayscale; the models expect exactly
    # 3 channels, and img_to_array would otherwise yield 4 or 1.
    img = img.convert("RGB")
    img = img.resize(image_size)  # Resize to match training
    img_array = image.img_to_array(img)
    # Do NOT call preprocess_vgg / preprocess_efficientnet here. The models
    # already do it, and applying VGG's preprocessing twice swaps RGB/BGR twice
    # and subtracts the ImageNet means twice.
    # NOTE(review): this assumes training fed raw pixels to the model, as its
    # in-graph preprocessing implies. Confirm against the training pipeline.
    return np.expand_dims(img_array, axis=0)  # Add batch dimension
# Prediction function
def predict(img, model_name):
    """Classify a tyre image as "Good" or "Defective".

    Args:
        img: PIL image from the upload widget, or ``None`` when the user
            clicks Classify without uploading anything.
        model_name: Dropdown selection. It picks the entry of ``models`` used
            for this request and falls back to ``current_model`` when it is
            not a known key.

    Returns:
        Multi-line text with the predicted class and both confidences, or a
        prompt to upload an image.
    """
    if img is None:
        return "Please upload an image to classify."
    # Use the model named in this request rather than only the process-wide
    # global. With several users, one person's dropdown change would otherwise
    # switch the model for everyone, and the model could disagree with the
    # preprocessing chosen from model_name.
    model = models.get(model_name, current_model)
    img_array = preprocess_img(img, model_name)
    # verbose=0: avoid printing a Keras progress bar on every request.
    predictions = model.predict(img_array, verbose=0)
    confidence = float(predictions[0][0])  # sigmoid output, read as P(Good)
    # Class labels
    class_labels = {0: "Defective", 1: "Good"}
    predicted_class = 1 if confidence > 0.5 else 0
    result_text = (
        f"Predicted Class: {class_labels[predicted_class]}\n"
        f"Confidence (Good): {confidence:.8f}\n"
        f"Confidence (Defective): {1 - confidence:.8f}"
    )
    return result_text
# Function to clear input and output
def clear():
    """Reset the UI: no uploaded image and an empty classification text.

    Returns:
        ``(None, "")``, in the order of the outputs ``[input_image, output_text]``.
    """
    blank_image = None
    blank_text = ""
    return blank_image, blank_text
# Gradio Interface with Model Selection
# Components are laid out in the order they are created inside the Blocks
# context: description, model dropdown, status box, an image/output row,
# then a button row. Event listeners are attached inside the same context.
with gr.Blocks() as interface:
    gr.Markdown(
        "## Fine-tuned Defect Tyre Classification\n\n"
        "This Gradio-based application allows users to classify tyre images as **'Good'** or **'Defective'** using a "
        "fine-tuned deep learning model. Users can select between two models (**VGG16** and **EfficientNetB0**) "
        "via a dropdown menu. The selected model is dynamically loaded and applied to the uploaded image, "
        "with predictions and confidence scores displayed as output.\n\n"
        "Upload a tyre image and select a model (**VGG16** or **EfficientNetB0**) to classify defects."
    )
    # Dropdown to select model; choices are the keys of `models`, default "VGG16".
    model_dropdown = gr.Dropdown(
        choices=list(models.keys()), value="VGG16", label="Select Model"
    )
    # Read-only box showing the message returned by load_selected_model.
    model_status = gr.Textbox(label="Model Status", interactive=False)
    with gr.Row():
        # type="pil": event handlers receive a PIL image (or None if empty).
        input_image = gr.Image(type="pil", label="Upload Image")
        output_text = gr.Textbox(label="Classification Output", interactive=False)
    with gr.Row():
        classify_button = gr.Button("Classify")
        clear_button = gr.Button("Clear")
    # Link dropdown to model selection function
    model_dropdown.change(fn=load_selected_model, inputs=model_dropdown, outputs=model_status)
    # Link buttons to functions.
    # predict receives the dropdown value alongside the uploaded image.
    classify_button.click(fn=predict, inputs=[input_image, model_dropdown], outputs=output_text)
    # clear takes no inputs; its two return values reset the image and the text.
    clear_button.click(fn=clear, inputs=[], outputs=[input_image, output_text])
# Run the app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    interface.launch()