Commit · f096e52
1 Parent(s): 8d55205
Added files
- app.py +75 -0
- assets/banana.jpg +0 -0
- assets/black.jpg +0 -0
- assets/eggplant.jpg +0 -0
- assets/mango.jpg +0 -0
- assets/melon.jpg +0 -0
- assets/orange.jpg +0 -0
- assets/pineapple.jpg +0 -0
- assets/white.jpg +0 -0
- labels.json +32 -0
- requirements.txt +5 -0
- src/__pycache__/dataset.cpython-310.pyc +0 -0
- src/__pycache__/utils.cpython-310.pyc +0 -0
- src/dataset.py +48 -0
- src/models/__pycache__/mobilenet_v2.cpython-310.pyc +0 -0
- src/models/__pycache__/mobilenet_v2.cpython-312.pyc +0 -0
- src/models/__pycache__/resnet50.cpython-310.pyc +0 -0
- src/models/__pycache__/resnet50.cpython-312.pyc +0 -0
- src/models/mobilenet_v2.py +151 -0
- src/models/resnet50.py +147 -0
- src/train.py +152 -0
- src/utils.py +32 -0
- weights/checkpoint-best-mobilenet.pth +3 -0
- weights/checkpoint-best-resnet.pth +3 -0
- weights/download_checkpoints.sh +2 -0
- weights/download_pretrained.py +0 -0
- weights/mobilenet_v2-b0353104.pth +3 -0
- weights/resnet50-0676ba61.pth +3 -0
app.py
ADDED
@@ -0,0 +1,75 @@
from typing import Dict
import gradio as gr
import json
import PIL.Image, PIL.ImageOps
import torch
import torchvision.transforms.functional as F

from src.models.resnet50 import ResNet
from src.models.mobilenet_v2 import MobileNetV2


num_classes = 30
model1 = ResNet(weights_path="weights/checkpoint-best-resnet.pth", num_classes=num_classes)
model1.eval()
model2 = MobileNetV2(weights_path="weights/checkpoint-best-mobilenet.pth", num_classes=num_classes)
model2.eval()


with open("labels.json", "r") as f:
    class_labels = json.load(f)
label_mapping = {v: k for k, v in class_labels.items()}


def predict(img, model_choice) -> Dict[str, float]:
    model = model1 if model_choice == "ResNet" else model2
    width, height = img.size
    max_dim = max(width, height)
    padding = (max_dim - width, max_dim - height)
    img = PIL.ImageOps.expand(img, padding, (255, 255, 255))
    img = img.resize((224, 224))
    img = F.to_tensor(img)
    img = F.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    img = img.unsqueeze(0)

    with torch.inference_mode():
        logits = model.forward(img)
    probs = torch.nn.functional.softmax(logits, dim=1)
    top_probs, top_indices = probs[0].topk(3)

    top_classes = {label_mapping[idx.item()]: prob.item() for idx, prob in zip(top_indices, top_probs)}
    return top_classes


examples = [
    ["assets/banana.jpg"],
    ["assets/pineapple.jpg"],
    ["assets/mango.jpg"],
    ["assets/melon.jpg"],
    ["assets/orange.jpg"],
    ["assets/eggplant.jpg"],
    ["assets/black.jpg"],
    ["assets/white.jpg"]
]


with gr.Blocks() as demo:
    gr.Markdown("## Plant Classification")
    with gr.Row():
        with gr.Column():
            pic = gr.Image(label="Upload Plant Image", type="pil", height=300, width=300)
            model_choice = gr.Dropdown(choices=["ResNet", "MobileNetV2"], label="Select Model", value="ResNet")
            with gr.Row():
                with gr.Column(scale=1):
                    predict_btn = gr.Button("Predict")
                with gr.Column(scale=1):
                    clear_btn = gr.Button("Clear")

        with gr.Column():
            output = gr.Label(label="Top 3 Predicted Classes")

    predict_btn.click(fn=predict, inputs=[pic, model_choice], outputs=output, api_name="predict")
    clear_btn.click(lambda: (None, None), outputs=[pic, output])
    gr.Examples(examples=examples, inputs=[pic])

demo.launch()
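For a quick sanity check outside the Gradio UI, the following sketch (not part of the commit) loads one of the fine-tuned checkpoints and reproduces the preprocessing from predict() above, minus the square-padding step; it assumes the repository root is the working directory.

# Hedged smoke test for the checkpoints used by app.py; illustrative only.
import json
import PIL.Image
import torch
import torchvision.transforms.functional as F
from src.models.resnet50 import ResNet

with open("labels.json") as fp:
    idx_to_name = {v: k for k, v in json.load(fp).items()}  # index -> class name

model = ResNet(weights_path="weights/checkpoint-best-resnet.pth", num_classes=30).eval()

img = PIL.Image.open("assets/banana.jpg").convert("RGB").resize((224, 224))
x = F.normalize(F.to_tensor(img), mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).unsqueeze(0)
with torch.inference_mode():
    probs = torch.softmax(model(x), dim=1)[0]
print({idx_to_name[i.item()]: round(p.item(), 3) for p, i in zip(*probs.topk(3))})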
assets/banana.jpg
ADDED
assets/black.jpg
ADDED
assets/eggplant.jpg
ADDED
assets/mango.jpg
ADDED
assets/melon.jpg
ADDED
assets/orange.jpg
ADDED
assets/pineapple.jpg
ADDED
assets/white.jpg
ADDED
labels.json
ADDED
@@ -0,0 +1,32 @@
{
    "aloevera": 0,
    "banana": 1,
    "bilimbi": 2,
    "cantaloupe": 3,
    "cassava": 4,
    "coconut": 5,
    "corn": 6,
    "cucumber": 7,
    "curcuma": 8,
    "eggplant": 9,
    "galangal": 10,
    "ginger": 11,
    "guava": 12,
    "kale": 13,
    "longbeans": 14,
    "mango": 15,
    "melon": 16,
    "orange": 17,
    "paddy": 18,
    "papaya": 19,
    "peperchili": 20,
    "pineapple": 21,
    "pomelo": 22,
    "shallot": 23,
    "soybeans": 24,
    "spinach": 25,
    "sweetpotatoes": 26,
    "tobacco": 27,
    "waterapple": 28,
    "watermelon": 29
}
requirements.txt
ADDED
@@ -0,0 +1,5 @@
torch==2.4.1
torchvision==0.19.1
kaggle==1.6.17
wandb==0.18.5
gradio==5.4.0
src/__pycache__/dataset.cpython-310.pyc
ADDED
Binary file (2.26 kB)
src/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (713 Bytes)
src/dataset.py
ADDED
@@ -0,0 +1,48 @@
from typing import List, Dict, Tuple, Callable
from pathlib import Path
from torch.utils.data import Dataset
import PIL.Image
import torch


class PlantsDataset(Dataset):
    def __init__(
        self,
        root: str,
        labels: Dict[str, int],
        transform: Callable,
        load_to_ram: bool = True,
    ) -> None:
        super().__init__()
        self.root = root
        self.labels = labels
        self.transform = transform
        self.load_to_ram = load_to_ram

        self.data = [
            {
                "path": x.as_posix(),
                "label": self.labels[x.parent.name],
                "image": PIL.Image.open(x).convert("RGB") if self.load_to_ram else None,
            }
            for x in sorted(Path(self.root).glob("**/*.jpg"))
        ]

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        item = self.data[idx]
        if self.load_to_ram:
            image = item["image"]
        else:
            image = PIL.Image.open(item["path"]).convert("RGB")
        image = self.transform(image)
        label = torch.tensor(item["label"], dtype=torch.long)
        return (image, label)


def collate_fn(items: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
    images = torch.cat([item[0].unsqueeze(0) for item in items])
    labels = torch.cat([item[1].unsqueeze(0) for item in items])
    return (images, labels)
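A short usage sketch for the dataset and collate function above (not part of the commit). The data path is the default from train.py and is an assumption about where the Kaggle dataset is unpacked; train.py itself relies on the default DataLoader collation, so passing collate_fn is optional.

# Hypothetical wiring of PlantsDataset; paths and batch size are assumptions.
import json
from torch.utils.data import DataLoader
from src.dataset import PlantsDataset, collate_fn
from src.utils import train_transform

with open("labels.json") as fp:
    labels = json.load(fp)  # class name -> integer index

dataset = PlantsDataset(root="data/plants/train", labels=labels, transform=train_transform, load_to_ram=False)
loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
images, targets = next(iter(loader))  # images: (32, 3, 224, 224), targets: (32,)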
src/models/__pycache__/mobilenet_v2.cpython-310.pyc
ADDED
Binary file (5.32 kB)
src/models/__pycache__/mobilenet_v2.cpython-312.pyc
ADDED
Binary file (9.01 kB)
src/models/__pycache__/resnet50.cpython-310.pyc
ADDED
Binary file (5.16 kB)
src/models/__pycache__/resnet50.cpython-312.pyc
ADDED
Binary file (9.77 kB)
src/models/mobilenet_v2.py
ADDED
@@ -0,0 +1,151 @@
from typing import Callable, List, Union
from pathlib import Path
import PIL.Image
import torch
from torch import nn
import torchvision.transforms.functional as F


class Conv2dNormActivation(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1,
        padding: int | None = None, groups: int = 1, norm_layer: Callable[..., torch.nn.Module] = nn.BatchNorm2d,
        activation_layer: Callable[..., torch.nn.Module] = nn.ReLU, bias: bool | None = False,
    ) -> None:
        super().__init__()
        if padding is None:
            padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias, groups=groups)
        self.norm = norm_layer(out_channels)
        self.activation = activation_layer()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = self.norm(x)
        x = self.activation(x)
        return x


class InvertedResidual(nn.Module):
    def __init__(
        self, inp: int, oup: int, stride: int, expand_ratio: int,
    ) -> None:
        super().__init__()
        self.stride = stride

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            layers.append(Conv2dNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=nn.BatchNorm2d, activation_layer=nn.ReLU6))
        layers.extend(
            [
                Conv2dNormActivation(
                    hidden_dim,
                    hidden_dim,
                    stride=stride,
                    groups=hidden_dim,
                    norm_layer=nn.BatchNorm2d,
                    activation_layer=nn.ReLU6,
                ),
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            ]
        )
        self.conv = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileNetV2(nn.Module):
    def __init__(
        self,
        num_classes: int = 1000,
        weights_path: str | None = None,
    ) -> None:
        super().__init__()

        if weights_path is not None and not Path(weights_path).exists():
            raise FileNotFoundError(weights_path)

        input_channel = 32
        last_channel = 1280
        inverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        features = [Conv2dNormActivation(3, input_channel, stride=2, norm_layer=nn.BatchNorm2d, activation_layer=nn.ReLU6)]
        for t, c, n, s in inverted_residual_setting:
            output_channel = c
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(InvertedResidual(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel
        features.append(
            Conv2dNormActivation(
                input_channel, last_channel, kernel_size=1, norm_layer=nn.BatchNorm2d, activation_layer=nn.ReLU6
            )
        )
        self.features = nn.Sequential(*features)

        self.classifier = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(last_channel, num_classes),
        )

        if weights_path:
            self.load_pretrained_weights(weights_path)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

    def load_pretrained_weights(self, weights_path: str) -> None:
        # The torchvision checkpoint uses different parameter names, so keys are matched by position.
        state_dict = torch.load(weights_path, map_location="cpu")
        model_state_dict = self.state_dict()
        new_state_dict = {}
        for key1, key2 in zip(model_state_dict.keys(), state_dict.keys()):
            new_state_dict[key1] = state_dict[key2]
        self.load_state_dict(new_state_dict)

    @torch.inference_mode()
    def predict(self, x: torch.Tensor, top_k: int | None) -> Union[List[int], List[List[int]]]:
        output = self.forward(x)
        probs = torch.nn.functional.softmax(output, dim=1)
        if top_k is not None:
            preds = torch.topk(probs, dim=1, k=top_k).indices
            return preds.tolist()
        else:
            pred = torch.argmax(probs, dim=1)
            return pred.tolist()


if __name__ == "__main__":
    model = MobileNetV2(weights_path="weights/mobilenet_v2-b0353104.pth")
    num_params = sum([p.numel() for p in model.parameters()])
    print(f"params: {num_params/1e6:.2f} M")

    model.eval()
    image = PIL.Image.open("assets/cat.jpg").convert("RGB")
    image = F.resize(image, (224, 224))
    image = F.to_tensor(image)
    image = F.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    image = image.unsqueeze(0)
    predicted_class = model.predict(image, top_k=10)
    print(f"predicted class: {predicted_class}")
    # https://deeplearning.cms.waikato.ac.nz/user-guide/class-maps/IMAGENET/
src/models/resnet50.py
ADDED
@@ -0,0 +1,147 @@
from typing import Union, List
from pathlib import Path
import PIL.Image
import torch
from torch import nn
import torchvision.transforms.functional as F


class Bottleneck(nn.Module):
    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: nn.Module | None = None,
        groups: int = 1,
        dilation: int = 1,
    ) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=groups, dilation=dilation, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out


class ResNet(nn.Module):
    @property
    def expansion(self):
        return 4

    def __init__(
        self,
        num_classes: int = 1000,
        weights_path: str | None = None,
    ) -> None:
        super().__init__()

        if weights_path is not None and not Path(weights_path).exists():
            raise FileNotFoundError(weights_path)

        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(Bottleneck, 64, 3)
        self.layer2 = self._make_layer(Bottleneck, 128, 4, stride=2)
        self.layer3 = self._make_layer(Bottleneck, 256, 6, stride=2)
        self.layer4 = self._make_layer(Bottleneck, 512, 3, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * Bottleneck.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if weights_path:
            self.load_pretrained_weights(weights_path)

    def _make_layer(
        self,
        block: Bottleneck,
        planes: int,
        blocks: int,
        stride: int = 1,
    ) -> nn.Sequential:
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

    def load_pretrained_weights(self, weights_path: str) -> None:
        state_dict = torch.load(weights_path, map_location="cpu")
        self.load_state_dict(state_dict)

    @torch.inference_mode()
    def predict(self, x: torch.Tensor, top_k: int | None) -> Union[List[int], List[List[int]]]:
        output = self.forward(x)
        probs = torch.nn.functional.softmax(output, dim=1)
        if top_k is not None:
            preds = torch.topk(probs, dim=1, k=top_k).indices
            return preds.tolist()
        else:
            pred = torch.argmax(probs, dim=1)
            return pred.tolist()


if __name__ == "__main__":
    model = ResNet(weights_path="weights/resnet50-0676ba61.pth")
    num_params = sum([p.numel() for p in model.parameters()])
    print(f"params: {num_params/1e6:.2f} M")

    model.eval()
    image = PIL.Image.open("assets/cat.jpg").convert("RGB")
    image = F.resize(image, (224, 224))
    image = F.to_tensor(image)
    image = F.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    image = image.unsqueeze(0)
    predicted_class = model.predict(image, top_k=10)
    print(f"predicted class: {predicted_class}")
    # https://deeplearning.cms.waikato.ac.nz/user-guide/class-maps/IMAGENET/
src/train.py
ADDED
@@ -0,0 +1,152 @@
from pathlib import Path
from tqdm import tqdm
import numpy as np
import argparse
import json
import wandb
import pickle
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import models
from models.resnet50 import ResNet
from models.mobilenet_v2 import MobileNetV2
from dataset import PlantsDataset
from utils import train_transform, test_transform, EMA


def parse_args():
    parser = argparse.ArgumentParser(description="Train a model on plant dataset")
    parser.add_argument("--train-root", type=str, default="data/plants/train", help="Path to the training data")
    parser.add_argument("--test-root", type=str, default="data/plants/test", help="Path to the testing data")
    parser.add_argument("--load-to-ram", type=bool, default=False, help="Load dataset to RAM")
    parser.add_argument("--batch-size", type=int, default=32, help="Batch size for training and testing")
    parser.add_argument("--pin-memory", type=bool, default=True, help="Pin memory for DataLoader")
    parser.add_argument("--num-workers", type=int, default=1, help="Number of workers for DataLoader")
    parser.add_argument("--num-epochs", type=int, default=10, help="Number of training epochs")
    parser.add_argument("--learning-rate", type=float, default=1e-4, help="Learning rate for the optimizer")
    parser.add_argument("--weights-path", type=str, default="weights/mobilenet_v2-b0353104.pth", choices=["weights/resnet50-0676ba61.pth", "weights/mobilenet_v2-b0353104.pth"], help="Path to the pre-trained weights")
    parser.add_argument("--project-name", type=str, default="plants_classifier", help="WandB project name")
    parser.add_argument("--optimizer", type=str, default="AdamW", help="Optimizer type")
    parser.add_argument("--criterion", type=str, default="CrossEntropyLoss", help="Loss function type")
    parser.add_argument("--labels-path", type=str, default="labels.json", help="Path to the labels json file")
    parser.add_argument("--max-norm", type=float, default=1.0, help="Maximum gradient norm for clipping")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to run the training on")
    parser.add_argument("--model", type=str, default="mobilenet", choices=["resnet", "mobilenet"], help="Model class name")
    parser.add_argument("--save-frequency", type=int, default=4, help="Frequency of saving model weights")
    parser.add_argument("--logs-dir", type=str, default="resnet-logs", choices=["resnet-logs", "mobilenet-logs"], help="Directory for training logs and checkpoints")
    return parser.parse_args()


def main() -> None:
    args = parse_args()

    with open(args.labels_path, "r") as fp:
        labels = json.load(fp)
    num_classes = len(labels)

    logs_dir = Path(args.logs_dir)
    logs_dir.mkdir(exist_ok=True)

    wandb.init(project=args.project_name, dir=logs_dir)

    train_dataset = PlantsDataset(root=args.train_root, load_to_ram=args.load_to_ram, transform=train_transform, labels=labels)
    test_dataset = PlantsDataset(root=args.test_root, load_to_ram=args.load_to_ram, transform=test_transform, labels=labels)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, pin_memory=args.pin_memory, num_workers=args.num_workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=args.pin_memory, num_workers=args.num_workers)

    device = torch.device(args.device)

    if args.model == "resnet":
        model = ResNet(weights_path=args.weights_path)
        model.fc = nn.Linear(512 * model.expansion, num_classes)
        nn.init.xavier_uniform_(model.fc.weight)
        # Fine-tune only layer4 and the new classification head; freeze everything else.
        for name, param in model.named_parameters():
            if "layer4" in name or "fc" in name:
                param.requires_grad = True
            else:
                param.requires_grad = False
    elif args.model == "mobilenet":
        model = MobileNetV2(weights_path=args.weights_path)
        num_ftrs = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_ftrs, num_classes)
        nn.init.xavier_uniform_(model.classifier[1].weight)
        # Fine-tune the classifier and the last two feature blocks; freeze everything else.
        for name, param in model.named_parameters():
            if "classifier" in name or "features.18" in name or "features.17" in name:
                param.requires_grad = True
            else:
                param.requires_grad = False
    model = model.to(device)

    optimizer_class = getattr(torch.optim, args.optimizer)
    optimizer = optimizer_class(model.parameters(), lr=args.learning_rate)
    criterion_class = getattr(nn, args.criterion)
    criterion = criterion_class()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.num_epochs)

    best_accuracy = 0

    train_loss_ema, train_accuracy_ema, grad_norm_ema = EMA(), EMA(), EMA()
    for epoch in range(1, args.num_epochs + 1):
        model.train()
        pbar = tqdm(train_loader, desc=f"Train epoch {epoch}/{args.num_epochs}")
        for images, labels in pbar:
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.max_norm).item()
            optimizer.step()
            train_loss = loss.item()
            train_accuracy = (logits.argmax(dim=1) == labels).sum().item() / logits.shape[0]
            pbar.set_postfix({"loss": train_loss_ema(train_loss), "accuracy": train_accuracy_ema(train_accuracy), "grad_norm": grad_norm_ema(grad_norm)})
            wandb.log(
                {
                    "train/epoch": epoch,
                    "train/loss": train_loss,
                    "train/accuracy": train_accuracy,
                    "train/learning_rate": optimizer.param_groups[0]["lr"],
                    "train/grad_norm": grad_norm,
                }
            )

        model.eval()
        test_loss, test_accuracy = 0.0, 0.0
        with torch.no_grad():
            pbar = tqdm(test_loader, desc=f"Val epoch {epoch}/{args.num_epochs}")
            for images, labels in pbar:
                images = images.to(device)
                labels = labels.to(device)
                logits = model(images)
                loss = criterion(logits, labels)
                test_loss += loss.item()
                test_accuracy += (logits.argmax(dim=1) == labels).sum().item()
        test_loss /= len(test_loader)
        test_accuracy /= len(test_loader.dataset)
        print(f"loss: {test_loss:.3f}, accuracy: {test_accuracy:.3f}")

        wandb.log(
            {
                "val/epoch": epoch,
                "val/test_loss": test_loss,
                "val/test_accuracy": test_accuracy,
            }
        )

        scheduler.step()

        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            torch.save(model.state_dict(), logs_dir / f"checkpoint-best-{epoch:09}.pth")
        elif epoch % args.save_frequency == 0:
            torch.save(model.state_dict(), logs_dir / f"checkpoint-{epoch:09}.pth")

    wandb.finish()


if __name__ == "__main__":
    main()
src/utils.py
ADDED
@@ -0,0 +1,32 @@
import torchvision.transforms as T

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

train_transform = T.Compose([
    T.RandomRotation(degrees=15),
    T.RandomResizedCrop(224, scale=(0.5, 1.0)),
    T.RandomHorizontalFlip(),
    T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
    T.ToTensor(),
    T.Normalize(mean=mean, std=std),
])

test_transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=mean, std=std),
])


class EMA:
    def __init__(self, alpha: float = 0.9) -> None:
        self.value = None
        self.alpha = alpha

    def __call__(self, value: float) -> float:
        if self.value is None:
            self.value = value
        else:
            self.value = self.alpha * self.value + (1 - self.alpha) * value
        return self.value
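The EMA helper above is what train.py uses to smooth the loss, accuracy, and gradient-norm values shown in the tqdm postfix; a minimal illustration follows (the numbers are made up, not measurements from this repository).

# Illustrative only: exponential smoothing of a noisy metric with alpha=0.9.
from src.utils import EMA

loss_ema = EMA(alpha=0.9)
for raw in [2.0, 1.0, 1.5]:
    print(loss_ema(raw))  # 2.0, then 0.9*2.0 + 0.1*1.0 = 1.9, then 0.9*1.9 + 0.1*1.5 = 1.86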
weights/checkpoint-best-mobilenet.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:126f0aab718f57f32e7b5c05898bc65f767c393d2ecb1dcc4a50d220d33a9b80
size 9300442
weights/checkpoint-best-resnet.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3660126925b0296a4b2a64d0248eeea8d36ab3f5bb9e596c89a76d789c11470e
size 94601530
weights/download_checkpoints.sh
ADDED
@@ -0,0 +1,2 @@
wget https://huggingface.co/eksemyashkina/plants-classification/resolve/main/checkpoint-best-mobilenet.pth
wget https://huggingface.co/eksemyashkina/plants-classification/resolve/main/checkpoint-best-resnet.pth
weights/download_pretrained.py
ADDED
File without changes
weights/mobilenet_v2-b0353104.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b03531047ffacf1e2488318dcd2aba1126cde36e3bfe1aa5cb07700aeeee9889
size 14212972
weights/resnet50-0676ba61.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0676ba61b6795bbe1773cffd859882e5e297624d384b6993f7c9e683e722fb8a
size 102530333