File size: 5,482 Bytes
d0a2e2a
 
 
91211ef
 
d0a2e2a
 
 
a80a6f0
d0a2e2a
f40f30c
d0a2e2a
91211ef
 
d0a2e2a
 
91211ef
d0a2e2a
59e0382
d0a2e2a
 
 
 
 
 
 
91211ef
d0a2e2a
 
 
 
 
 
 
 
91211ef
 
 
 
 
 
 
 
0e8dbc5
91211ef
0e8dbc5
91211ef
 
 
 
 
 
 
0e8dbc5
91211ef
 
 
 
 
 
 
 
 
 
 
c05711e
 
91211ef
 
 
 
 
 
 
 
 
 
 
d0a2e2a
 
91211ef
d0a2e2a
 
91211ef
 
 
 
 
 
 
d0a2e2a
 
 
 
91211ef
 
 
 
d0a2e2a
 
 
 
 
 
 
 
fa76f86
 
d0a2e2a
 
 
 
d4456ba
 
 
 
91211ef
d0a2e2a
34be39e
505fb10
 
34be39e
 
 
505fb10
a80a6f0
91211ef
 
 
 
 
 
d0a2e2a
86b7628
 
 
 
d0a2e2a
094ca8b
 
d0a2e2a
91211ef
 
 
 
 
094ca8b
d0a2e2a
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import gradio as gr
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input as preprocess_vgg
from tensorflow.keras.applications.efficientnet import EfficientNetB0, preprocess_input as preprocess_efficientnet
from tensorflow.keras.preprocessing import image
import numpy as np


# Spatial input resolution (height, width) shared by both networks; it must
# match the resolution used when the checkpoints were trained.  The side is
# named separately to make explicit that the input is square, which is also
# why it can be passed unchanged to PIL's resize (which takes width, height).
_INPUT_SIDE = 224
image_size = (_INPUT_SIDE, _INPUT_SIDE)

# Function to build the VGG16 model
def build_vgg16_model():
    """Build and compile the VGG16-based binary (Good vs. Defective) classifier.

    The VGG16 backbone is created without pretrained weights (``weights=None``);
    the trained values are supplied afterwards via ``load_weights``.  VGG16's
    ``preprocess_input`` is applied inside the graph, so the model is fed
    un-normalised images of shape ``image_size + (3,)``.

    Returns:
        A compiled ``keras.Model`` that maps an image batch to one sigmoid
        output per image (binary cross-entropy loss, Adam with lr=1e-3).
    """
    input_shape = image_size + (3,)
    backbone = VGG16(weights=None, include_top=False, input_shape=input_shape)
    image_input = keras.layers.Input(shape=input_shape)

    # Normalisation happens in-graph; the backbone is called with training=False.
    features = backbone(preprocess_vgg(image_input), training=False)

    # Classifier head, built and applied in order.  Flatten is used here
    # instead of GlobalAveragePooling2D.
    classifier_head = (
        keras.layers.Flatten(),
        keras.layers.Dropout(rate=0.5),
        keras.layers.Dense(units=256, activation="relu"),
        keras.layers.Dropout(rate=0.5),
        keras.layers.Dense(units=1, activation="sigmoid"),
    )
    hidden = features
    for layer in classifier_head:
        hidden = layer(hidden)

    classifier = keras.models.Model(inputs=[image_input], outputs=[hidden])
    classifier.compile(
        optimizer=keras.optimizers.Adam(learning_rate=1e-3),
        loss="binary_crossentropy",
        metrics=["accuracy"],
    )
    return classifier

# Function to build the EfficientNetB0 model
def build_efficientnet_model():
    """Build and compile the EfficientNetB0-based binary (Good vs. Defective) classifier.

    The backbone is created with ImageNet weights, its feature map is reduced
    with ``GlobalAveragePooling2D``, and EfficientNet's ``preprocess_input`` is
    applied in-graph before the backbone.

    NOTE(review): keep ``weights="imagenet"`` in sync with how the checkpoint
    was trained.  Keras may build the backbone's input-normalisation stem
    differently depending on this argument, so switching it to ``None`` (to
    skip the download that ``load_weights`` later overwrites) could change
    predictions -- verify before changing it.

    Returns:
        A compiled ``keras.Model`` that maps an image batch of shape
        ``image_size + (3,)`` to one sigmoid output per image (binary
        cross-entropy loss, Adam with lr=1e-3).
    """
    input_shape = image_size + (3,)
    backbone = EfficientNetB0(input_shape=input_shape, include_top=False, weights="imagenet")
    image_input = keras.layers.Input(shape=input_shape)

    # The backbone is called with training=False (inference-mode normalisation layers).
    features = backbone(preprocess_efficientnet(image_input), training=False)

    # Classifier head, built and applied in order.
    classifier_head = (
        keras.layers.GlobalAveragePooling2D(),
        keras.layers.Dropout(rate=0.5),
        keras.layers.Dense(units=256, activation="relu"),
        keras.layers.Dropout(rate=0.5),
        keras.layers.Dense(units=1, activation="sigmoid"),
    )
    hidden = features
    for layer in classifier_head:
        hidden = layer(hidden)

    classifier = keras.models.Model(inputs=[image_input], outputs=[hidden])
    classifier.compile(
        optimizer=keras.optimizers.Adam(learning_rate=1e-3),
        loss="binary_crossentropy",
        metrics=["accuracy"],
    )
    return classifier

# Registry of available models, keyed by the name shown in the UI.
# Insertion order matters: the dropdown lists its choices in this order.
models = {}
models["VGG16"] = build_vgg16_model()
models["EfficientNetB0"] = build_efficientnet_model()

# Load the fine-tuned weights for each model (paths relative to the CWD).
for _model_name, _weights_path in (
    ("VGG16", "VGG16_best_finetune_checkpoint.weights.h5"),
    ("EfficientNetB0", "EfficientNetB0_best_finetune_checkpoint.weights.h5"),
):
    models[_model_name].load_weights(_weights_path)

print("Models and weights loaded successfully!")

# Default selection (matches the dropdown's initial value); replaced by
# load_selected_model() whenever the dropdown changes.
current_model = models["VGG16"]

# Function to update the current model based on selection
def load_selected_model(model_name):
    """Make ``models[model_name]`` the module-wide ``current_model``.

    Wired to the model dropdown's ``change`` event.  ``current_model`` is a
    single module-level global, so the selection is shared by every session
    served by this process rather than being per-user.

    Args:
        model_name: Key into ``models`` (``"VGG16"`` or ``"EfficientNetB0"``).

    Returns:
        A status message for the "Model Status" textbox.

    Raises:
        KeyError: If ``model_name`` is not a key of ``models``; in that case
            ``current_model`` is left unchanged.
    """
    global current_model
    chosen = models[model_name]
    current_model = chosen
    return f"Loaded {model_name} model."

# Preprocessing function
def preprocess_img(img, model_name):
    """Turn a PIL image into a single-image batch ready for ``model.predict``.

    The image is converted to 3-channel RGB, resized to ``image_size`` and
    given a leading batch axis.  Model-specific normalisation is deliberately
    NOT applied here: both ``build_vgg16_model`` and ``build_efficientnet_model``
    already call their backbone's ``preprocess_input`` inside the network
    graph.  Applying it again here would normalise the pixels twice.  For
    VGG16 that would flip the channel order back and subtract the ImageNet
    mean a second time.

    Args:
        img: A PIL image, as delivered by ``gr.Image(type="pil")``.
        model_name: Name of the selected model.  No longer needed, because
            preprocessing lives inside the models; kept so existing callers
            keep working.

    Returns:
        A float32 array of shape ``(1, *image_size, 3)`` holding raw 0-255
        pixel values.
    """
    # Uploads may be RGBA (e.g. PNG with transparency), greyscale or palette
    # images; the networks' Input layers expect exactly three channels.
    if img.mode != "RGB":
        img = img.convert("RGB")

    # PIL takes (width, height); image_size is square, so the order is moot.
    img = img.resize(image_size)  # Resize to match training
    img_array = image.img_to_array(img)

    return np.expand_dims(img_array, axis=0)  # Add batch dimension

# Prediction function
def predict(img, model_name):
    """Classify one tyre image as "Good" or "Defective".

    Args:
        img: PIL image from the upload widget, or ``None`` when nothing has
            been uploaded.
        model_name: Name of the model selected in the dropdown; used to pick
            the network from ``models``.  If it is not a key of ``models``,
            the module-level ``current_model`` is used instead.

    Returns:
        A multi-line string with the predicted class and the confidence for
        each class, or a prompt asking the user to upload an image.
    """
    if img is None:
        # Clicking "Classify" before uploading would otherwise crash inside
        # preprocess_img with an AttributeError on None.
        return "Please upload an image first."

    # Resolve the model from this request's own dropdown value rather than
    # the module-level ``current_model``.  That global is shared by every
    # session, so another user changing their dropdown would otherwise swap
    # the network used here.
    model = models.get(model_name, current_model)

    img_array = preprocess_img(img, model_name)
    predictions = model.predict(img_array, verbose=0)

    # Single sigmoid unit: its output is treated as the score of class 1 ("Good").
    confidence = float(predictions[0][0])

    # Class labels
    class_labels = {0: "Defective", 1: "Good"}
    predicted_class = 1 if confidence > 0.5 else 0

    return (
        f"Predicted Class: {class_labels[predicted_class]}\n"
        f"Confidence (Good): {confidence:.8f}\n"
        f"Confidence (Defective): {1 - confidence:.8f}"
    )

# Function to clear input and output
def clear():
    """Reset the UI: empty the image upload and blank the result textbox.

    Returns:
        ``(None, "")`` -- one value per output component, in the order
        (image input, classification text).
    """
    cleared_image = None
    cleared_text = ""
    return cleared_image, cleared_text

# Gradio Interface with Model Selection
_APP_DESCRIPTION = (
    "## Fine-tuned Defect Tyre Classification\n\n"
    "This Gradio-based application allows users to classify tyre images as **'Good'** or **'Defective'** using a "
    "fine-tuned deep learning model. Users can select between two models (**VGG16** and **EfficientNetB0**) "
    "via a dropdown menu. The selected model is dynamically loaded and applied to the uploaded image, "
    "with predictions and confidence scores displayed as output.\n\n"
    "Upload a tyre image and select a model (**VGG16** or **EfficientNetB0**) to classify defects."
)

with gr.Blocks() as interface:
    # Header and usage instructions.
    gr.Markdown(_APP_DESCRIPTION)

    # Model picker; its choices are the keys of ``models``, in insertion order.
    model_dropdown = gr.Dropdown(
        label="Select Model",
        choices=list(models),
        value="VGG16",
    )
    model_status = gr.Textbox(interactive=False, label="Model Status")

    # Image upload next to the read-only result box.
    with gr.Row():
        input_image = gr.Image(label="Upload Image", type="pil")
        output_text = gr.Textbox(interactive=False, label="Classification Output")

    with gr.Row():
        classify_button = gr.Button("Classify")
        clear_button = gr.Button("Clear")

    # Event wiring.  Keep this registration order: Gradio indexes events in
    # the order they are registered.
    model_dropdown.change(fn=load_selected_model, inputs=model_dropdown, outputs=model_status)
    classify_button.click(fn=predict, inputs=[input_image, model_dropdown], outputs=output_text)
    clear_button.click(fn=clear, inputs=[], outputs=[input_image, output_text])

# Run the app
if __name__ == "__main__":
    interface.launch()