wookimchye committed
Commit 1b803a3 · verified · 1 Parent(s): 6ee5646

Upload app.py

Browse files
Files changed (1)
  1. app.py +214 -0
app.py ADDED
@@ -0,0 +1,214 @@
+ #!/usr/bin/env python
+ # coding: utf-8
+
+ # In[8]:
+
+
+ import tensorflow as tf
+ from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
+ from tensorflow.keras.preprocessing import image
+ from ultralytics import YOLO
+ import numpy as np
+ from PIL import Image, ImageDraw, ImageFont
+ import gradio as gr
+ from huggingface_hub import snapshot_download
+ import os
+ from torchvision import transforms
+
+ # Define the class labels
+ classes = {0: "Defective", 1: "Good"}
+
+ model_path = "ResNet50_Classification.h5"  # Trained ResNet50 classification model
+
+ best_yolo_model = "best.pt"  # Trained YOLOv8 detection model
+
+ classification_model = tf.keras.models.load_model(model_path)
+
+ detection_model = YOLO(best_yolo_model, task='detect')
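+ # Two-stage pipeline (wired up in process_image below): the ResNet50 model
+ # classifies each image as Good or Defective; only images classified as
+ # Defective are passed to the YOLOv8 detector to localize the defect.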
+
+
+ # Image preprocessing with ResNet50's preprocess_input (defined but not used below)
+ def preprocess_imageXX(image):
+     image = image.resize((224, 224))  # Resize to the input size of ResNet50
+     image = np.array(image)  # Convert to numpy array
+     image = preprocess_input(image)  # Preprocess for ResNet50
+     image = np.expand_dims(image, axis=0)  # Add batch dimension
+     return image
+
+ # Image preprocessing used by classify_image: raw pixel values, no preprocess_input
+ def preprocess_image(pilimg):
+     img = pilimg.resize((224, 224))  # Resize to the input size of ResNet50
+     img_array = image.img_to_array(img)
+     img_array = np.expand_dims(img_array, axis=0)  # Add batch dimension
+     return img_array
+
+ def classify_image(pilimg):
+     img_array = preprocess_image(pilimg)  # Preprocess the input image
+     classify_result = classification_model.predict(img_array)[0][0]  # Get prediction probability
+     print(">>> Result : ", classify_result)
+
+     # Scores >= 0.5 map to class 1 ("Good"), lower scores to class 0 ("Defective")
+     predicted_class = "Good" if classify_result >= 0.5 else "Defective"
+     print(">>> predicted_class : ", predicted_class)
+
+     return predicted_class
+
+ def detect_defect(img):
+     # conf is the minimum detection confidence, iou the NMS IoU threshold
+     detection_result = detection_model.predict(img, conf=0.4, iou=0.5)
+
+     return detection_result
+
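+ # detect_defect returns the Ultralytics Results list; process_image below uses
+ # result[0].plot() to draw the predicted boxes (a BGR array) and result[0].boxes.data
+ # to check whether any defect was actually localized.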
+
+ def process_image(pilimg):
+     # Perform classification first, then perform detection if Defective
+     classification = classify_image(pilimg)
+
+     if classification == "Good":
+         out_pilimg = pilimg.convert("RGB")
+         draw = ImageDraw.Draw(out_pilimg)
+         font = ImageFont.truetype("arialbd.ttf", 30)  # arialbd.ttf is the bold Arial face
+
+         draw.text((250, 10), "Good", fill="green", font=font)
+         gr.Info("No defect detected, GOOD!", duration=3)
+
+     else:  # Defective
+         detection_result = detect_defect(pilimg)
+         img_bgr = detection_result[0].plot()  # Annotated image as a BGR array
+         out_pilimg = Image.fromarray(img_bgr[..., ::-1])  # Convert BGR to an RGB-order PIL image
+
+         draw = ImageDraw.Draw(out_pilimg)
+         font = ImageFont.truetype("arialbd.ttf", 30)  # arialbd.ttf is the bold Arial face
+
+         draw.text((250, 10), "Defective", fill="red", font=font)
+
+         detections = detection_result[0].boxes.data  # Get detections
+         if len(detections) > 0:
+             gr.Warning("Defect detected, BAD!", duration=3)
+         else:
+             gr.Warning("Classified as Defective but defect cannot be detected, ERROR!")
+
+     return out_pilimg
+
+ title = "Detect the status of the cap, DEFECTIVE or GOOD"
+ interface = gr.Interface(
+     fn=process_image,
+     inputs=gr.Image(type="pil", label="Input Image"),
+     outputs=gr.Image(type="pil", label="Classification result"),
+     title=title,
+ )
+
+ # Launch the interface
+ interface.launch(share=True)
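+ # Note: launch() normally blocks when this file runs as a script, so the
+ # alternative app below (In[37]), which uses a single YOLOv8 classification
+ # model instead of the two-stage pipeline, only runs after this interface is closed.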
+
+
+ # In[37]:
+
+
+ from ultralytics import YOLO
+ from PIL import Image, ImageDraw, ImageFont
+ import gradio as gr
+ from huggingface_hub import snapshot_download
+ import os
+ from torchvision import transforms
+
+ classes = {0: "Defective", 1: "Good"}
+
+ model_path = "best_int8_openvino_model"  # YOLOv8 classification model exported to INT8 OpenVINO
+
+ def load_model_local():
+     detection_model = YOLO(model_path, task='classify')  # Load the model
+     return detection_model
+
+ def load_model(repo_id):
+     download_dir = snapshot_download(repo_id)
+     print(download_dir)
+     path = os.path.join(download_dir, "best_int8_openvino_model")
+     print(path)
+     detection_model = YOLO(path, task='classify')
+     return detection_model
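+ # Note: load_model(repo_id) is an unused alternative to load_model_local(); it
+ # downloads the "best_int8_openvino_model" export from a Hugging Face Hub repo
+ # via snapshot_download instead of loading it from the local directory.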
+
+ def predict(pilimg):
+     source = pilimg
+
+     # Resize the input image and convert it to a tensor before prediction
+     transform = transforms.Compose([
+         transforms.Resize((224, 224)),
+         transforms.ToTensor(),
+     ])
+
+     source = transform(source)  # 224x224 is one of the input sizes accepted by the YOLO classification model
+
+     #result = detection_model.predict(source, conf=0.55, iou=0.6)
+     result = detection_model.predict(source)  # Make prediction
+     # Get the top prediction
+     label = result[0].probs.top1
+
+     class_names = detection_model.names  # Class names mapping (dict-like)
+     classified_type = class_names[label]  # Map numeric label to class name
+     print(">>> Class : ", classified_type)
+
+     confidence = result[0].probs.top1conf  # Get the top class confidence
+     print(">>> Confidence : ", confidence)
+
+     annotated_image = pilimg.convert("RGB")
+     draw = ImageDraw.Draw(annotated_image)
+     font = ImageFont.truetype("arialbd.ttf", 30)  # arialbd.ttf is the bold Arial face
+
+     if classified_type == classes[0]:  # Defective
+         draw.text((300, 10), classified_type, fill="red", font=font)
+         gr.Warning("Defect detected, BAD!")
+     else:
+         draw.text((300, 10), classified_type, fill="green", font=font)
+         gr.Info("No defect detected, GOOD!")
+
+     #draw.text((300, 10), classified_type, fill="red", font=font)
+
+     return annotated_image
+
+ detection_model = load_model_local()
+
+ title = "Detect the status of the cap, DEFECTIVE or GOOD"
+ interface = gr.Interface(
+     fn=predict,
+     inputs=gr.Image(type="pil", label="Input Image"),
+     outputs=gr.Image(type="pil", label="Classification result"),
+     title=title,
+ )
+
+ # Launch the interface
+ interface.launch(share=True)
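+ # Note: share=True requests a temporary public gradio.live link for local runs;
+ # when the app is hosted on Hugging Face Spaces it is typically unnecessary.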