|
import numpy as np
|
|
import torch
|
|
import torchvision.transforms as transforms
|
|
import torch.nn as nn
|
|
from torch.utils.data import DataLoader, Dataset
|
|
from torchvision import transforms, datasets, models
|
|
|
|
|
|
class RetinaNet(nn.Module):
    """Image classifier built on a ResNet-50 backbone with a small MLP head.

    Despite its name, this module is not the RetinaNet object detector: it
    produces one logit vector per image (no anchors, no box regression).

    Args:
        num_classes: Number of output logits produced by the head.
        pretrained: If True (default, matching the original behavior), load
            torchvision's ``ResNet50_Weights.IMAGENET1K_V1`` weights, which
            may trigger a download. Pass False to build the backbone with
            random weights (e.g. offline use or tests).
        trainable_layers: Names of direct children of the ResNet-50 backbone
            (e.g. ``"layer3"``, ``"layer4"``) whose parameters stay
            trainable; every other backbone parameter is frozen. A single
            string is accepted as shorthand for a one-element tuple. Pass
            ``("layer3",)`` to reproduce the earlier behavior where only
            ``layer3`` was fine-tuned.
        hidden_dim: Width of the hidden layer in the classification head.
        dropout: Dropout probability applied after the hidden activation.

    Raises:
        ValueError: If ``trainable_layers`` names something that is not a
            direct child of the backbone.

    Note:
        ``requires_grad = False`` only stops gradient updates. BatchNorm
        layers inside frozen stages still update their running statistics
        while the module is in ``train()`` mode.
    """

    def __init__(
        self,
        num_classes: int = 2,
        pretrained: bool = True,
        trainable_layers=("layer3", "layer4"),
        hidden_dim: int = 512,
        dropout: float = 0.5,
    ):
        super().__init__()

        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        self.backbone = models.resnet50(weights=weights)

        # Freeze the whole backbone, then re-enable gradients only for the
        # requested stages. layer4 is trainable by default: previously it was
        # "frozen" a second time with requires_grad = False, a no-op that
        # left the top stage frozen while the stage below it trained.
        for param in self.backbone.parameters():
            param.requires_grad = False

        if isinstance(trainable_layers, str):
            # Avoid iterating a bare string character by character.
            trainable_layers = (trainable_layers,)

        children = dict(self.backbone.named_children())
        for name in trainable_layers:
            if name not in children:
                raise ValueError(
                    f"trainable_layers entry {name!r} is not a child of the "
                    f"ResNet-50 backbone; valid names: {sorted(children)}"
                )
            for param in children[name].parameters():
                param.requires_grad = True

        # Size the head from the backbone rather than hard-coding 2048.
        # The Sequential layout (indices 0-3) is unchanged so existing
        # checkpoints still load.
        in_features = self.backbone.fc.in_features
        self.classifier = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of images.

        Args:
            x: Image batch. Presumably (batch, 3, H, W), normalized as the
                ImageNet-pretrained ResNet-50 expects. TODO: confirm with
                callers.

        Returns:
            Logits of shape (batch, num_classes).
        """
        # Run the backbone stage by stage instead of calling self.backbone(x)
        # so that its own ImageNet `fc` layer is skipped and our head is used.
        backbone = self.backbone
        x = backbone.conv1(x)
        x = backbone.bn1(x)
        x = backbone.relu(x)
        x = backbone.maxpool(x)

        x = backbone.layer1(x)
        x = backbone.layer2(x)
        x = backbone.layer3(x)
        x = backbone.layer4(x)

        # Global average pool, then flatten everything except the batch dim.
        x = backbone.avgpool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)