allutrifork commited on
Commit
e6c072f
·
1 Parent(s): 6f6962c

pre resnet model added

Browse files
.gitattributes CHANGED
@@ -34,3 +34,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *.bmp filter=lfs diff=lfs merge=lfs -text
 
 
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *.bmp filter=lfs diff=lfs merge=lfs -text
37
+ *.t7 filter=lfs diff=lfs merge=lfs -text
38
+ *.pth.tar filter=lfs diff=lfs merge=lfs -text
39
+ *.txt filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -5,25 +5,85 @@ import torch
5
  from ultralytics import YOLO
6
  import numpy as np
7
  import os
 
 
 
 
8
  from PIL import __version__ as PIL_VERSION
9
  print(f"Pillow version: {PIL_VERSION}")
10
 
11
  MODEL_PATH = "model/231220_detect_lr_0001_640_brightness.pt"
 
 
12
 
13
  # Define the confidence threshold (used if not using the slider)
14
  # CONF_THRESHOLD = 0.5 # Optional: Remove if using the slider
15
 
16
- # Verify the model path
17
  if not os.path.exists(MODEL_PATH):
18
  raise FileNotFoundError(f"YOLO model not found at '{MODEL_PATH}'.")
 
 
 
 
19
 
20
  # Load the YOLO model
21
  model = YOLO(MODEL_PATH)
22
  print("YOLO model loaded.")
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  def detect_plastic_pellets(input_image, threshold=0.5):
25
  """
26
- Perform plastic pellet detection using our customized model.
27
  """
28
  if input_image is None:
29
  error_image = Image.new('RGB', (500, 100), color=(255, 0, 0))
@@ -36,7 +96,21 @@ def detect_plastic_pellets(input_image, threshold=0.5):
36
  return error_image
37
 
38
  try:
39
- print("Starting detection with threshold:", threshold)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  input_image.thumbnail((1024, 1024), Image.LANCZOS)
41
  img = np.array(input_image.convert("RGB"))
42
 
@@ -135,4 +209,4 @@ def main():
135
  demo.launch()
136
 
137
  if __name__ == "__main__":
138
- main()
 
5
  from ultralytics import YOLO
6
  import numpy as np
7
  import os
8
+ from torchvision import models, transforms
9
+ import json
10
+
11
+ # Load Pillow version
12
  from PIL import __version__ as PIL_VERSION
13
  print(f"Pillow version: {PIL_VERSION}")
14
 
15
  MODEL_PATH = "model/231220_detect_lr_0001_640_brightness.pt"
16
+ SCENE_MODEL_PATH = "model/resnet50_places365.pth.tar" # Updated path
17
+ SCENE_LABELS_PATH = "model/categories_places365.txt" # Updated path
18
 
19
  # Define the confidence threshold (used if not using the slider)
20
  # CONF_THRESHOLD = 0.5 # Optional: Remove if using the slider
21
 
22
+ # Verify the model paths
23
  if not os.path.exists(MODEL_PATH):
24
  raise FileNotFoundError(f"YOLO model not found at '{MODEL_PATH}'.")
25
+ if not os.path.exists(SCENE_MODEL_PATH):
26
+ raise FileNotFoundError(f"Scene classification model not found at '{SCENE_MODEL_PATH}'.")
27
+ if not os.path.exists(SCENE_LABELS_PATH):
28
+ raise FileNotFoundError(f"Scene classification labels not found at '{SCENE_LABELS_PATH}'.")
29
 
30
  # Load the YOLO model
31
  model = YOLO(MODEL_PATH)
32
  print("YOLO model loaded.")
33
 
34
+ # Load the scene classification model
35
+ def load_scene_classification_model():
36
+ # Load pre-trained ResNet50 model
37
+ model = models.resnet50(num_classes=365)
38
+ checkpoint = torch.load(SCENE_MODEL_PATH, map_location=torch.device('cpu'))
39
+ state_dict = {str.replace(k, 'module.', ''): v for k, v in checkpoint['state_dict'].items()}
40
+ model.load_state_dict(state_dict)
41
+ model.eval()
42
+ return model
43
+
44
+ scene_model = load_scene_classification_model()
45
+ print("Scene classification model loaded.")
46
+
47
+ # Load class labels
48
+ with open(SCENE_LABELS_PATH) as class_file:
49
+ classes = class_file.read().splitlines()
50
+ class_labels = [line.split(' ')[0][3:] for line in classes] # Adjust parsing based on the file format
51
+
52
+ # Define image transformations for scene classification
53
+ scene_transform = transforms.Compose([
54
+ transforms.Resize((224, 224)),
55
+ transforms.ToTensor(),
56
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], # ImageNet means
57
+ std=[0.229, 0.224, 0.225]) # ImageNet stds
58
+ ])
59
+
60
+ def is_beach_scene(input_image, model, class_labels, transform, threshold=0.5):
61
+ """
62
+ Classify the scene of the input image and check if it's a beach.
63
+
64
+ Args:
65
+ input_image (PIL.Image): The uploaded image.
66
+ model (torch.nn.Module): The pre-trained scene classification model.
67
+ class_labels (list): List of class labels.
68
+ transform (torchvision.transforms): Image transformations.
69
+ threshold (float): Confidence threshold for beach classification.
70
+
71
+ Returns:
72
+ bool: True if the image is classified as beach with confidence >= threshold, else False.
73
+ float: Confidence score for the beach classification.
74
+ """
75
+ image = transform(input_image).unsqueeze(0) # Add batch dimension
76
+ with torch.no_grad():
77
+ outputs = model(image)
78
+ probabilities = torch.nn.functional.softmax(outputs, dim=1)
79
+ confidence, predicted = torch.max(probabilities, 1)
80
+ predicted_class = class_labels[predicted.item()]
81
+ is_beach = predicted_class.lower() in ['beach', 'seashore', 'shore', 'oceanfront'] and confidence.item() >= threshold
82
+ return is_beach, confidence.item()
83
+
84
  def detect_plastic_pellets(input_image, threshold=0.5):
85
  """
86
+ Perform plastic pellet detection using our customized model after verifying the scene.
87
  """
88
  if input_image is None:
89
  error_image = Image.new('RGB', (500, 100), color=(255, 0, 0))
 
96
  return error_image
97
 
98
  try:
99
+ print("Starting scene classification...")
100
+ is_beach, scene_confidence = is_beach_scene(input_image, scene_model, class_labels, scene_transform, threshold=0.5)
101
+
102
+ if not is_beach:
103
+ error_image = Image.new('RGB', (500, 100), color=(255, 165, 0)) # Orange color
104
+ draw = ImageDraw.Draw(error_image)
105
+ try:
106
+ font = ImageFont.truetype("arial.ttf", size=15)
107
+ except IOError:
108
+ font = ImageFont.load_default()
109
+ message = f"Image is not recognized as a beach (Confidence: {scene_confidence:.2f}). Please upload a beach image."
110
+ draw.text((10, 40), message, fill=(0, 0, 0), font=font)
111
+ return error_image
112
+
113
+ print("Scene classification passed. Starting detection...")
114
  input_image.thumbnail((1024, 1024), Image.LANCZOS)
115
  img = np.array(input_image.convert("RGB"))
116
 
 
209
  demo.launch()
210
 
211
  if __name__ == "__main__":
212
+ main()
model/categories_places365.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2affba635eb657e7ca95f4e6cc69bd9fac29ef4c32aeb83cafdfcd06ec6a1ea6
3
+ size 6833
model/resnet50_places365.pth.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:46529c86902bd0cfb0ea562a30b2850c28d2620d96282b3db9c318e1d774f6c5
3
+ size 97270159