eksemyashkina committed
Commit f096e52 · 1 Parent(s): 8d55205

Added files

app.py ADDED
@@ -0,0 +1,75 @@
+ from typing import Dict
+ import gradio as gr
+ import json
+ import PIL.Image, PIL.ImageOps
+ import torch
+ import torchvision.transforms.functional as F
+
+ from src.models.resnet50 import ResNet
+ from src.models.mobilenet_v2 import MobileNetV2
+
+
+ num_classes = 30
+ model1 = ResNet(weights_path="weights/checkpoint-best-resnet.pth", num_classes=num_classes)
+ model1.eval()
+ model2 = MobileNetV2(weights_path="weights/checkpoint-best-mobilenet.pth", num_classes=num_classes)
+ model2.eval()
+
+
+ with open("labels.json", "r") as f:
+     class_labels = json.load(f)
+ label_mapping = {v: k for k, v in class_labels.items()}
+
+
+ def predict(img, model_choice) -> Dict[str, float]:
+     model = model1 if model_choice == "ResNet" else model2
+     width, height = img.size
+     max_dim = max(width, height)
+     padding = (0, 0, max_dim - width, max_dim - height)  # pad right/bottom with white so the image becomes a max_dim x max_dim square
+     img = PIL.ImageOps.expand(img, padding, (255, 255, 255))
+     img = img.resize((224, 224))
+     img = F.to_tensor(img)
+     img = F.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+     img = img.unsqueeze(0)
+
+     with torch.inference_mode():
+         logits = model(img)
+         probs = torch.nn.functional.softmax(logits, dim=1)
+         top_probs, top_indices = probs[0].topk(3)
+
+     top_classes = {label_mapping[idx.item()]: prob.item() for idx, prob in zip(top_indices, top_probs)}
+     return top_classes
+
+
+ examples = [
+     ["assets/banana.jpg"],
+     ["assets/pineapple.jpg"],
+     ["assets/mango.jpg"],
+     ["assets/melon.jpg"],
+     ["assets/orange.jpg"],
+     ["assets/eggplant.jpg"],
+     ["assets/black.jpg"],
+     ["assets/white.jpg"]
+ ]
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown("## Plant Classification")
+     with gr.Row():
+         with gr.Column():
+             pic = gr.Image(label="Upload Plant Image", type="pil", height=300, width=300)
+             model_choice = gr.Dropdown(choices=["ResNet", "MobileNetV2"], label="Select Model", value="ResNet")
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     predict_btn = gr.Button("Predict")
+                 with gr.Column(scale=1):
+                     clear_btn = gr.Button("Clear")
+
+         with gr.Column():
+             output = gr.Label(label="Top 3 Predicted Classes")
+
+     predict_btn.click(fn=predict, inputs=[pic, model_choice], outputs=output, api_name="predict")
+     clear_btn.click(lambda: (None, None), outputs=[pic, output])
+     gr.Examples(examples=examples, inputs=[pic])
+
+ demo.launch()
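The click handler above registers the endpoint under api_name="predict", so the demo can also be queried without the browser UI. A minimal sketch using gradio_client, assuming the app is already running locally on Gradio's default port and that gradio_client is installed (this snippet is illustrative and not part of the commit):

```python
# Illustrative only: call the running app.py server from Python.
from gradio_client import Client, handle_file

client = Client("http://127.0.0.1:7860/")
result = client.predict(
    handle_file("assets/banana.jpg"),  # the image input (pic)
    "ResNet",                          # the model_choice dropdown value
    api_name="/predict",
)
print(result)  # label plus top-3 confidences, as produced by predict()
```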
assets/banana.jpg ADDED
assets/black.jpg ADDED
assets/eggplant.jpg ADDED
assets/mango.jpg ADDED
assets/melon.jpg ADDED
assets/orange.jpg ADDED
assets/pineapple.jpg ADDED
assets/white.jpg ADDED
labels.json ADDED
@@ -0,0 +1,32 @@
+ {
+     "aloevera": 0,
+     "banana": 1,
+     "bilimbi": 2,
+     "cantaloupe": 3,
+     "cassava": 4,
+     "coconut": 5,
+     "corn": 6,
+     "cucumber": 7,
+     "curcuma": 8,
+     "eggplant": 9,
+     "galangal": 10,
+     "ginger": 11,
+     "guava": 12,
+     "kale": 13,
+     "longbeans": 14,
+     "mango": 15,
+     "melon": 16,
+     "orange": 17,
+     "paddy": 18,
+     "papaya": 19,
+     "peperchili": 20,
+     "pineapple": 21,
+     "pomelo": 22,
+     "shallot": 23,
+     "soybeans": 24,
+     "spinach": 25,
+     "sweetpotatoes": 26,
+     "tobacco": 27,
+     "waterapple": 28,
+     "watermelon": 29
+ }
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch==2.4.1
+ torchvision==0.19.1
+ kaggle==1.6.17
+ wandb==0.18.5
+ gradio==5.4.0
src/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (2.26 kB).
 
src/__pycache__/utils.cpython-310.pyc ADDED
Binary file (713 Bytes).
 
src/dataset.py ADDED
@@ -0,0 +1,48 @@
+ from typing import List, Dict, Tuple, Callable
+ from pathlib import Path
+ from torch.utils.data import Dataset
+ import PIL.Image
+ import torch
+
+
+ class PlantsDataset(Dataset):
+     def __init__(
+         self,
+         root: str,
+         labels: Dict[str, int],
+         transform: Callable,
+         load_to_ram: bool = True,
+     ) -> None:
+         super().__init__()
+         self.root = root
+         self.labels = labels
+         self.transform = transform
+         self.load_to_ram = load_to_ram
+
+         self.data = [
+             {
+                 "path": x.as_posix(),
+                 "label": self.labels[x.parent.name],
+                 "image": PIL.Image.open(x).convert("RGB") if self.load_to_ram else None,
+             }
+             for x in sorted(Path(self.root).glob("**/*.jpg"))
+         ]
+
+     def __len__(self) -> int:
+         return len(self.data)
+
+     def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
+         item = self.data[idx]
+         if self.load_to_ram:
+             image = item["image"]
+         else:
+             image = PIL.Image.open(item["path"]).convert("RGB")
+         image = self.transform(image)
+         label = torch.tensor(item["label"], dtype=torch.long)
+         return (image, label)
+
+
+ def collate_fn(items: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
+     images = torch.cat([item[0].unsqueeze(0) for item in items])
+     labels = torch.cat([item[1].unsqueeze(0) for item in items])
+     return (images, labels)
src/models/__pycache__/mobilenet_v2.cpython-310.pyc ADDED
Binary file (5.32 kB).
 
src/models/__pycache__/mobilenet_v2.cpython-312.pyc ADDED
Binary file (9.01 kB).
 
src/models/__pycache__/resnet50.cpython-310.pyc ADDED
Binary file (5.16 kB).
 
src/models/__pycache__/resnet50.cpython-312.pyc ADDED
Binary file (9.77 kB).
 
src/models/mobilenet_v2.py ADDED
@@ -0,0 +1,151 @@
+ from typing import Callable, List, Union
+ from pathlib import Path
+ import PIL.Image
+ import torch
+ from torch import nn
+ import torchvision.transforms.functional as F
+
+
+ class Conv2dNormActivation(nn.Module):
+     def __init__(
+         self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1,
+         padding: int | None = None, groups: int = 1, norm_layer: Callable[..., torch.nn.Module] = nn.BatchNorm2d,
+         activation_layer: Callable[..., torch.nn.Module] = nn.ReLU, bias: bool | None = False,
+     ) -> None:
+         super().__init__()
+         if padding is None:
+             padding = (kernel_size - 1) // 2
+         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias, groups=groups)
+         self.norm = norm_layer(out_channels)
+         self.activation = activation_layer()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.conv(x)
+         x = self.norm(x)
+         x = self.activation(x)
+         return x
+
+
+ class InvertedResidual(nn.Module):
+     def __init__(
+         self, inp: int, oup: int, stride: int, expand_ratio: int,
+     ) -> None:
+         super().__init__()
+         self.stride = stride
+
+         hidden_dim = int(round(inp * expand_ratio))
+         self.use_res_connect = self.stride == 1 and inp == oup
+
+         layers = []
+         if expand_ratio != 1:
+             layers.append(Conv2dNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=nn.BatchNorm2d, activation_layer=nn.ReLU6))
+         layers.extend(
+             [
+                 Conv2dNormActivation(
+                     hidden_dim,
+                     hidden_dim,
+                     stride=stride,
+                     groups=hidden_dim,
+                     norm_layer=nn.BatchNorm2d,
+                     activation_layer=nn.ReLU6,
+                 ),
+                 nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+                 nn.BatchNorm2d(oup),
+             ]
+         )
+         self.conv = nn.Sequential(*layers)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if self.use_res_connect:
+             return x + self.conv(x)
+         else:
+             return self.conv(x)
+
+
+ class MobileNetV2(nn.Module):
+     def __init__(
+         self,
+         num_classes: int = 1000,
+         weights_path: str | None = None,
+     ) -> None:
+         super().__init__()
+
+         if weights_path is not None and not Path(weights_path).exists():
+             raise FileNotFoundError(weights_path)
+
+         input_channel = 32
+         last_channel = 1280
+         inverted_residual_setting = [
+             # t, c, n, s
+             [1, 16, 1, 1],
+             [6, 24, 2, 2],
+             [6, 32, 3, 2],
+             [6, 64, 4, 2],
+             [6, 96, 3, 1],
+             [6, 160, 3, 2],
+             [6, 320, 1, 1],
+         ]
+
+         features = [Conv2dNormActivation(3, input_channel, stride=2, norm_layer=nn.BatchNorm2d, activation_layer=nn.ReLU6)]
+         for t, c, n, s in inverted_residual_setting:
+             output_channel = c
+             for i in range(n):
+                 stride = s if i == 0 else 1
+                 features.append(InvertedResidual(input_channel, output_channel, stride, expand_ratio=t))
+                 input_channel = output_channel
+         features.append(
+             Conv2dNormActivation(
+                 input_channel, last_channel, kernel_size=1, norm_layer=nn.BatchNorm2d, activation_layer=nn.ReLU6
+             )
+         )
+         self.features = nn.Sequential(*features)
+
+         self.classifier = nn.Sequential(
+             nn.Dropout(p=0.2),
+             nn.Linear(last_channel, num_classes),
+         )
+
+         if weights_path:
+             self.load_pretrained_weights(weights_path)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.features(x)
+         x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
+         x = torch.flatten(x, 1)
+         x = self.classifier(x)
+         return x
+
+     def load_pretrained_weights(self, weights_path: str) -> None:
+         state_dict = torch.load(weights_path, map_location="cpu")
+         model_state_dict = self.state_dict()
+         new_state_dict = {}
+         for key1, key2 in zip(model_state_dict.keys(), state_dict.keys()):
+             new_state_dict[key1] = state_dict[key2]
+         self.load_state_dict(new_state_dict)
+
+     @torch.inference_mode()
+     def predict(self, x: torch.Tensor, top_k: int | None) -> Union[List[int], List[List[int]]]:
+         output = self.forward(x)
+         probs = torch.nn.functional.softmax(output, dim=1)
+         if top_k is not None:
+             preds = torch.topk(probs, dim=1, k=top_k).indices
+             return preds.tolist()
+         else:
+             pred = torch.argmax(probs, dim=1)
+             return pred.tolist()
+
+
+ if __name__ == "__main__":
+     model = MobileNetV2(weights_path="weights/mobilenet_v2-b0353104.pth")
+     num_params = sum([p.numel() for p in model.parameters()])
+     print(f"params: {num_params/1e6:.2f} M")
+
+     model.eval()
+     image = PIL.Image.open("assets/cat.jpg").convert("RGB")
+     image = F.resize(image, (224, 224))
+     image = F.to_tensor(image)
+     image = F.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+     image = image.unsqueeze(0)
+     predicted_class = model.predict(image, top_k=10)
+     print(f"predicted class: {predicted_class}")
+     # https://deeplearning.cms.waikato.ac.nz/user-guide/class-maps/IMAGENET/
src/models/resnet50.py ADDED
@@ -0,0 +1,147 @@
+ from typing import Union, List
+ from pathlib import Path
+ import PIL.Image
+ import torch
+ from torch import nn
+ import torchvision.transforms.functional as F
+
+
+ class Bottleneck(nn.Module):
+     expansion: int = 4
+
+     def __init__(
+         self,
+         inplanes: int,
+         planes: int,
+         stride: int = 1,
+         downsample: nn.Module | None = None,
+         groups: int = 1,
+         dilation: int = 1,
+     ) -> None:
+         super().__init__()
+         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(planes)
+         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=groups, dilation=dilation, bias=False)
+         self.bn2 = nn.BatchNorm2d(planes)
+         self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, stride=1, bias=False)
+         self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = downsample
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         identity = x
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+         out = self.conv2(out)
+         out = self.bn2(out)
+         out = self.relu(out)
+         out = self.conv3(out)
+         out = self.bn3(out)
+         if self.downsample is not None:
+             identity = self.downsample(x)
+         out += identity
+         out = self.relu(out)
+         return out
+
+
+ class ResNet(nn.Module):
+     @property
+     def expansion(self):
+         return 4
+     def __init__(
+         self,
+         num_classes: int = 1000,
+         weights_path: str | None = None,
+     ) -> None:
+         super().__init__()
+
+         if weights_path is not None and not Path(weights_path).exists():
+             raise FileNotFoundError(weights_path)
+
+         self.inplanes = 64
+         self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
+         self.bn1 = nn.BatchNorm2d(self.inplanes)
+         self.relu = nn.ReLU(inplace=True)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+         self.layer1 = self._make_layer(Bottleneck, 64, 3)
+         self.layer2 = self._make_layer(Bottleneck, 128, 4, stride=2)
+         self.layer3 = self._make_layer(Bottleneck, 256, 6, stride=2)
+         self.layer4 = self._make_layer(Bottleneck, 512, 3, stride=2)
+         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+         self.fc = nn.Linear(512 * Bottleneck.expansion, num_classes)
+
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+
+         if weights_path:
+             self.load_pretrained_weights(weights_path)
+
+     def _make_layer(
+         self,
+         block: Bottleneck,
+         planes: int,
+         blocks: int,
+         stride: int = 1,
+     ) -> nn.Sequential:
+         downsample = None
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
+                 nn.BatchNorm2d(planes * block.expansion),
+             )
+         layers = []
+         layers.append(block(self.inplanes, planes, stride, downsample))
+         self.inplanes = planes * block.expansion
+         for _ in range(1, blocks):
+             layers.append(block(self.inplanes, planes))
+         return nn.Sequential(*layers)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.relu(x)
+         x = self.maxpool(x)
+         x = self.layer1(x)
+         x = self.layer2(x)
+         x = self.layer3(x)
+         x = self.layer4(x)
+         x = self.avgpool(x)
+         x = torch.flatten(x, 1)
+         x = self.fc(x)
+         return x
+
+     def load_pretrained_weights(self, weights_path: str) -> None:
+         state_dict = torch.load(weights_path, map_location="cpu")
+         self.load_state_dict(state_dict)
+
+     @torch.inference_mode()
+     def predict(self, x: torch.Tensor, top_k: int | None) -> Union[List[int], List[List[int]]]:
+         output = self.forward(x)
+         probs = torch.nn.functional.softmax(output, dim=1)
+         if top_k is not None:
+             preds = torch.topk(probs, dim=1, k=top_k).indices
+             return preds.tolist()
+         else:
+             pred = torch.argmax(probs, dim=1)
+             return pred.tolist()
+
+
+ if __name__ == "__main__":
+     model = ResNet(weights_path="weights/resnet50-0676ba61.pth")
+     num_params = sum([p.numel() for p in model.parameters()])
+     print(f"params: {num_params/1e6:.2f} M")
+
+     model.eval()
+     image = PIL.Image.open("assets/cat.jpg").convert("RGB")
+     image = F.resize(image, (224, 224))
+     image = F.to_tensor(image)
+     image = F.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+     image = image.unsqueeze(0)
+     predicted_class = model.predict(image, top_k=10)
+     print(f"predicted class: {predicted_class}")
+     # https://deeplearning.cms.waikato.ac.nz/user-guide/class-maps/IMAGENET/
src/train.py ADDED
@@ -0,0 +1,152 @@
+ from pathlib import Path
+ from tqdm import tqdm
+ import numpy as np
+ import argparse
+ import json
+ import wandb
+ import pickle
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import DataLoader
+
+ import models
+ from models.resnet50 import ResNet
+ from models.mobilenet_v2 import MobileNetV2
+ from dataset import PlantsDataset
+ from utils import train_transform, test_transform, EMA
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Train a model on plant dataset")
+     parser.add_argument("--train-root", type=str, default="data/plants/train", help="Path to the training data")
+     parser.add_argument("--test-root", type=str, default="data/plants/test", help="Path to the testing data")
+     parser.add_argument("--load-to-ram", type=bool, default=False, help="Load dataset to RAM")
+     parser.add_argument("--batch-size", type=int, default=32, help="Batch size for training and testing")
+     parser.add_argument("--pin-memory", type=bool, default=True, help="Pin memory for DataLoader")
+     parser.add_argument("--num-workers", type=int, default=1, help="Number of workers for DataLoader")
+     parser.add_argument("--num-epochs", type=int, default=10, help="Number of training epochs")
+     parser.add_argument("--learning-rate", type=float, default=1e-4, help="Learning rate for the optimizer")
+     parser.add_argument("--weights-path", type=str, default="weights/mobilenet_v2-b0353104.pth", choices=["weights/resnet50-0676ba61.pth", "weights/mobilenet_v2-b0353104.pth"], help="Path to the pre-trained weights")
+     parser.add_argument("--project-name", type=str, default="plants_classifier", help="WandB project name")
+     parser.add_argument("--optimizer", type=str, default="AdamW", help="Optimizer type")
+     parser.add_argument("--criterion", type=str, default="CrossEntropyLoss", help="Loss function type")
+     parser.add_argument("--labels-path", type=str, default="labels.json", help="Path to the labels json file")
+     parser.add_argument("--max-norm", type=float, default=1.0, help="Maximum gradient norm for clipping")
+     parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to run the training on")
+     parser.add_argument("--model", type=str, default="mobilenet", choices=["resnet", "mobilenet"], help="Model class name")
+     parser.add_argument("--save-frequency", type=int, default=4, help="Frequency of saving model weights")
+     parser.add_argument("--logs-dir", type=str, default="resnet-logs", choices=["resnet-logs", "mobilenet-logs"], help="Directory for WandB logs and checkpoints")
+     return parser.parse_args()
+
+
+ def main() -> None:
+     args = parse_args()
+
+     with open(args.labels_path, "r") as fp:
+         labels = json.load(fp)
+     num_classes = len(labels)
+
+     logs_dir = Path(args.logs_dir)
+     logs_dir.mkdir(exist_ok=True)
+
+     wandb.init(project=args.project_name, dir=logs_dir)
+
+     train_dataset = PlantsDataset(root=args.train_root, load_to_ram=args.load_to_ram, transform=train_transform, labels=labels)
+     test_dataset = PlantsDataset(root=args.test_root, load_to_ram=args.load_to_ram, transform=test_transform, labels=labels)
+
+     train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, pin_memory=args.pin_memory, num_workers=args.num_workers)
+     test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=args.pin_memory, num_workers=args.num_workers)
+
+     device = torch.device(args.device)
+
+     if args.model == "resnet":
+         model = ResNet(weights_path=args.weights_path)
+         model.fc = nn.Linear(512 * model.expansion, num_classes)
+         nn.init.xavier_uniform_(model.fc.weight)
+         for name, param in model.named_parameters():
+             if "layer4" in name or "fc" in name:
+                 param.requires_grad = True
+             else:
+                 param.requires_grad = False
+     elif args.model == "mobilenet":
+         model = MobileNetV2(weights_path=args.weights_path)
+         num_ftrs = model.classifier[1].in_features
+         model.classifier[1] = nn.Linear(num_ftrs, num_classes)
+         nn.init.xavier_uniform_(model.classifier[1].weight)
+         for name, param in model.named_parameters():
+             if "classifier" in name or "features.18" in name or "features.17" in name:
+                 param.requires_grad = True
+             else:
+                 param.requires_grad = False
+     model = model.to(device)
+
+     optimizer_class = getattr(torch.optim, args.optimizer)
+     optimizer = optimizer_class(model.parameters(), lr=args.learning_rate)
+     criterion_class = getattr(nn, args.criterion)
+     criterion = criterion_class()
+     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.num_epochs)
+
+     best_accuracy = 0
+
+     train_loss_ema, train_accuracy_ema, grad_norm_ema = EMA(), EMA(), EMA()
+     for epoch in range(1, args.num_epochs + 1):
+         model.train()
+         pbar = tqdm(train_loader, desc=f"Train epoch {epoch}/{args.num_epochs}")
+         for images, labels in pbar:
+             images = images.to(device)
+             labels = labels.to(device)
+             optimizer.zero_grad()
+             logits = model(images)
+             loss = criterion(logits, labels)
+             loss.backward()
+             grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.max_norm).item()
+             optimizer.step()
+             train_loss = loss.item()
+             train_accuracy = (logits.argmax(dim=1) == labels).sum().item() / logits.shape[0]
+             pbar.set_postfix({"loss": train_loss_ema(train_loss), "accuracy": train_accuracy_ema(train_accuracy), "grad_norm": grad_norm_ema(grad_norm)})
+             wandb.log(
+                 {
+                     "train/epoch": epoch,
+                     "train/loss": train_loss,
+                     "train/accuracy": train_accuracy,
+                     "train/learning_rate": optimizer.param_groups[0]["lr"],
+                     "train/grad_norm": grad_norm,
+                 }
+             )
+
+         model.eval()
+         test_loss, test_accuracy = 0.0, 0.0
+         with torch.no_grad():
+             pbar = tqdm(test_loader, desc=f"Val epoch {epoch}/{args.num_epochs}")
+             for images, labels in pbar:
+                 images = images.to(device)
+                 labels = labels.to(device)
+                 logits = model(images)
+                 loss = criterion(logits, labels)
+                 test_loss += loss.item()
+                 test_accuracy += (logits.argmax(dim=1) == labels).sum().item()
+         test_loss /= len(test_loader)
+         test_accuracy /= len(test_loader.dataset)
+         print(f"loss: {test_loss:.3f}, accuracy: {test_accuracy:.3f}")
+
+         wandb.log(
+             {
+                 "val/epoch": epoch,
+                 "val/test_loss": test_loss,
+                 "val/test_accuracy": test_accuracy,
+             }
+         )
+
+         scheduler.step()
+
+         if test_accuracy > best_accuracy:
+             best_accuracy = test_accuracy
+             torch.save(model.state_dict(), logs_dir / f"checkpoint-best-{epoch:09}.pth")
+         elif epoch % args.save_frequency == 0:
+             torch.save(model.state_dict(), logs_dir / f"checkpoint-{epoch:09}.pth")
+
+     wandb.finish()
+
+
+ if __name__ == "__main__":
+     main()
src/utils.py ADDED
@@ -0,0 +1,32 @@
+ import torchvision.transforms as T
+
+ mean = [0.485, 0.456, 0.406]
+ std = [0.229, 0.224, 0.225]
+
+ train_transform = T.Compose([
+     T.RandomRotation(degrees=15),
+     T.RandomResizedCrop(224, scale=(0.5, 1.0)),
+     T.RandomHorizontalFlip(),
+     T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
+     T.ToTensor(),
+     T.Normalize(mean=mean, std=std),
+ ])
+
+ test_transform = T.Compose([
+     T.Resize(256),
+     T.CenterCrop(224),
+     T.ToTensor(),
+     T.Normalize(mean=mean, std=std),
+ ])
+
+ class EMA:
+     def __init__(self, alpha: float = 0.9) -> None:
+         self.value = None
+         self.alpha = alpha
+
+     def __call__(self, value: float) -> float:
+         if self.value is None:
+             self.value = value
+         else:
+             self.value = self.alpha * self.value + (1 - self.alpha) * value
+         return self.value
weights/checkpoint-best-mobilenet.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:126f0aab718f57f32e7b5c05898bc65f767c393d2ecb1dcc4a50d220d33a9b80
+ size 9300442
weights/checkpoint-best-resnet.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3660126925b0296a4b2a64d0248eeea8d36ab3f5bb9e596c89a76d789c11470e
+ size 94601530
weights/download_checkpoints.sh ADDED
@@ -0,0 +1,2 @@
+ wget https://huggingface.co/eksemyashkina/plants-classification/resolve/main/checkpoint-best-mobilenet.pth
+ wget https://huggingface.co/eksemyashkina/plants-classification/resolve/main/checkpoint-best-resnet.pth
weights/download_pretrained.py ADDED
File without changes
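The script is committed empty. A minimal sketch of what it might contain, assuming the intent mirrors download_checkpoints.sh and fetches the ImageNet-pretrained backbones that src/train.py expects under weights/ (the download URLs below are an assumption based on the standard torchvision file names, not part of the commit):

```python
# Hypothetical sketch: download the ImageNet-pretrained backbones into weights/.
from pathlib import Path
import torch

PRETRAINED = {
    "resnet50-0676ba61.pth": "https://download.pytorch.org/models/resnet50-0676ba61.pth",
    "mobilenet_v2-b0353104.pth": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
}

if __name__ == "__main__":
    out_dir = Path(__file__).parent  # the weights/ directory
    for name, url in PRETRAINED.items():
        dst = out_dir / name
        if not dst.exists():
            torch.hub.download_url_to_file(url, str(dst))
```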
weights/mobilenet_v2-b0353104.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b03531047ffacf1e2488318dcd2aba1126cde36e3bfe1aa5cb07700aeeee9889
+ size 14212972
weights/resnet50-0676ba61.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0676ba61b6795bbe1773cffd859882e5e297624d384b6993f7c9e683e722fb8a
+ size 102530333