allutrifork commited on
Commit
9557a09
·
1 Parent(s): 189d865

resnet deleted

Browse files
app.py CHANGED
@@ -5,113 +5,27 @@ import torch
5
  from ultralytics import YOLO
6
  import numpy as np
7
  import os
8
- from torchvision import models, transforms
9
- import re
10
- import logging
11
-
12
- # Configure logging
13
- logging.basicConfig(filename='app.log', level=logging.INFO,
14
- format='%(asctime)s:%(levelname)s:%(message)s')
15
-
16
- # Load Pillow version
17
  from PIL import __version__ as PIL_VERSION
18
  print(f"Pillow version: {PIL_VERSION}")
19
 
20
- # Paths to models and labels
21
  MODEL_PATH = "model/231220_detect_lr_0001_640_brightness.pt"
22
- SCENE_MODEL_PATH = "model/resnet50_places365.pth.tar" # Updated path
23
- SCENE_LABELS_PATH = "model/categories_places365.txt" # Updated path
24
 
25
- # Verify the model paths
 
 
 
26
  if not os.path.exists(MODEL_PATH):
27
  raise FileNotFoundError(f"YOLO model not found at '{MODEL_PATH}'.")
28
- if not os.path.exists(SCENE_MODEL_PATH):
29
- raise FileNotFoundError(f"Scene classification model not found at '{SCENE_MODEL_PATH}'.")
30
- if not os.path.exists(SCENE_LABELS_PATH):
31
- raise FileNotFoundError(f"Scene classification labels not found at '{SCENE_LABELS_PATH}'.")
32
 
33
  # Load the YOLO model
34
  model = YOLO(MODEL_PATH)
35
  print("YOLO model loaded.")
36
 
37
- # Load the scene classification model
38
- def load_scene_classification_model():
39
- # Load pre-trained ResNet50 model
40
- scene_model = models.resnet50(num_classes=365)
41
- checkpoint = torch.load(SCENE_MODEL_PATH, map_location=torch.device('cpu'))
42
- # Remove 'module.' prefix if present
43
- state_dict = {k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()}
44
- scene_model.load_state_dict(state_dict)
45
- scene_model.eval()
46
- return scene_model
47
-
48
- scene_model = load_scene_classification_model()
49
- print("Scene classification model loaded.")
50
-
51
- # Load class labels
52
- with open(SCENE_LABELS_PATH) as class_file:
53
- classes = class_file.read().splitlines()
54
-
55
- # Correct parsing of class labels
56
- # Each line is in the format '/a/beach 48', so we extract 'beach'
57
- class_labels = [line.split(' ')[0][3:].lower() for line in classes]
58
-
59
- # Debug: Print some class labels to verify parsing
60
- print("Sample Class Labels:")
61
- for idx in range(10):
62
- print(f"{idx}: {class_labels[idx]}")
63
-
64
- # Define image transformations for scene classification
65
- scene_transform = transforms.Compose([
66
- transforms.Resize((224, 224)),
67
- transforms.ToTensor(),
68
- transforms.Normalize(mean=[0.485, 0.456, 0.406], # ImageNet means
69
- std=[0.229, 0.224, 0.225]) # ImageNet stds
70
- ])
71
-
72
- def is_beach_scene(input_image, model, class_labels, transform, threshold=0.2):
73
- """
74
- Classify the scene of the input image and check if it's a beach.
75
-
76
- Args:
77
- input_image (PIL.Image): The uploaded image.
78
- model (torch.nn.Module): The pre-trained scene classification model.
79
- class_labels (list): List of class labels.
80
- transform (torchvision.transforms): Image transformations.
81
- threshold (float): Confidence threshold for beach classification.
82
-
83
- Returns:
84
- bool: True if the image is classified as beach with confidence >= threshold, else False.
85
- float: Confidence score for the beach classification.
86
- """
87
- image = transform(input_image).unsqueeze(0) # Add batch dimension
88
- with torch.no_grad():
89
- outputs = model(image)
90
- probabilities = torch.nn.functional.softmax(outputs, dim=1)
91
- confidence, predicted = torch.max(probabilities, 1)
92
- predicted_class = class_labels[predicted.item()]
93
- predicted_class_lower = predicted_class.lower()
94
-
95
- # Check if 'beach' or 'sand' is in the predicted class and exclude 'desert'
96
- is_beach = (('beach' in predicted_class_lower or 'sand' in predicted_class_lower) and
97
- ('desert' not in predicted_class_lower) and
98
- confidence.item() >= threshold)
99
-
100
- # Log the classification result
101
- logging.info(f"Predicted Class: {predicted_class}, Confidence: {confidence.item():.4f}, Is Beach: {is_beach}")
102
-
103
- # Debug: Print predicted class and confidence
104
- print(f"Predicted Class: {predicted_class}, Confidence: {confidence.item():.4f}")
105
- print(f"Is Beach: {is_beach}")
106
-
107
- return is_beach, confidence.item()
108
-
109
  def detect_plastic_pellets(input_image, threshold=0.5):
110
  """
111
- Perform plastic pellet detection using our customized model after verifying the scene.
112
  """
113
  if input_image is None:
114
- logging.warning("No image uploaded.")
115
  error_image = Image.new('RGB', (500, 100), color=(255, 0, 0))
116
  draw = ImageDraw.Draw(error_image)
117
  try:
@@ -122,24 +36,7 @@ def detect_plastic_pellets(input_image, threshold=0.5):
122
  return error_image
123
 
124
  try:
125
- print("Starting scene classification...")
126
- logging.info("Starting scene classification...")
127
- is_beach, scene_confidence = is_beach_scene(input_image, scene_model, class_labels, scene_transform, threshold=0.2)
128
-
129
- if not is_beach:
130
- logging.warning("Image not recognized as a beach.")
131
- error_image = Image.new('RGB', (500, 150), color=(255, 165, 0)) # Increased height for more text
132
- draw = ImageDraw.Draw(error_image)
133
- try:
134
- font = ImageFont.truetype("arial.ttf", size=15)
135
- except IOError:
136
- font = ImageFont.load_default()
137
- message = f"Image is not recognized as a beach.\nConfidence: {scene_confidence:.2f}"
138
- draw.text((10, 40), message, fill=(0, 0, 0), font=font)
139
- return error_image
140
-
141
- print("Scene classification passed. Starting detection...")
142
- logging.info("Scene classification passed. Starting detection...")
143
  input_image.thumbnail((1024, 1024), Image.LANCZOS)
144
  img = np.array(input_image.convert("RGB"))
145
 
@@ -172,20 +69,14 @@ def detect_plastic_pellets(input_image, threshold=0.5):
172
 
173
  detection_made = True
174
 
175
- if detection_made:
176
- logging.info("Plastic pellets detected.")
177
- print("Plastic pellets detected.")
178
- else:
179
- logging.info("No plastic pellets detected.")
180
  draw.text((10, 10), "No plastic pellets detected.", fill=(255, 0, 0), font=font)
181
  return input_image
182
 
183
  print("Detection completed.")
184
- logging.info("Detection completed.")
185
  return input_image
186
 
187
  except Exception as e:
188
- logging.error(f"Detection error: {str(e)}")
189
  print(f"Detection error: {str(e)}")
190
  error_image = Image.new('RGB', (500, 100), color=(255, 0, 0))
191
  draw = ImageDraw.Draw(error_image)
@@ -244,4 +135,4 @@ def main():
244
  demo.launch()
245
 
246
  if __name__ == "__main__":
247
- main()
 
5
  from ultralytics import YOLO
6
  import numpy as np
7
  import os
 
 
 
 
 
 
 
 
 
8
  from PIL import __version__ as PIL_VERSION
9
  print(f"Pillow version: {PIL_VERSION}")
10
 
 
11
  MODEL_PATH = "model/231220_detect_lr_0001_640_brightness.pt"
 
 
12
 
13
+ # Define the confidence threshold (used if not using the slider)
14
+ # CONF_THRESHOLD = 0.5 # Optional: Remove if using the slider
15
+
16
+ # Verify the model path
17
  if not os.path.exists(MODEL_PATH):
18
  raise FileNotFoundError(f"YOLO model not found at '{MODEL_PATH}'.")
 
 
 
 
19
 
20
  # Load the YOLO model
21
  model = YOLO(MODEL_PATH)
22
  print("YOLO model loaded.")
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  def detect_plastic_pellets(input_image, threshold=0.5):
25
  """
26
+ Perform plastic pellet detection using our customized model.
27
  """
28
  if input_image is None:
 
29
  error_image = Image.new('RGB', (500, 100), color=(255, 0, 0))
30
  draw = ImageDraw.Draw(error_image)
31
  try:
 
36
  return error_image
37
 
38
  try:
39
+ print("Starting detection with threshold:", threshold)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  input_image.thumbnail((1024, 1024), Image.LANCZOS)
41
  img = np.array(input_image.convert("RGB"))
42
 
 
69
 
70
  detection_made = True
71
 
72
+ if not detection_made:
 
 
 
 
73
  draw.text((10, 10), "No plastic pellets detected.", fill=(255, 0, 0), font=font)
74
  return input_image
75
 
76
  print("Detection completed.")
 
77
  return input_image
78
 
79
  except Exception as e:
 
80
  print(f"Detection error: {str(e)}")
81
  error_image = Image.new('RGB', (500, 100), color=(255, 0, 0))
82
  draw = ImageDraw.Draw(error_image)
 
135
  demo.launch()
136
 
137
  if __name__ == "__main__":
138
+ main()
model/categories_places365.txt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:2affba635eb657e7ca95f4e6cc69bd9fac29ef4c32aeb83cafdfcd06ec6a1ea6
3
- size 6833
 
 
 
 
model/resnet50_places365.pth.tar DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:46529c86902bd0cfb0ea562a30b2850c28d2620d96282b3db9c318e1d774f6c5
3
- size 97270159
 
 
 
 
requirements.txt CHANGED
@@ -1,6 +1,3 @@
1
- gradio>=3.38.0
2
- torch>=2.0.0
3
- torchvision>=0.15.1
4
- ultralytics>=8.0.0
5
- pillow>=10.0.0
6
- numpy>=1.23.0
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0689b23c5d7d1c089c59d97ac59bee19bec098c7857c300e9df9815cc1840d63
3
+ size 96