import numpy as np
import torch
import torchvision.transforms as transforms
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets, models


# Define model
class RetinaNet(nn.Module):
    """Image classifier: ImageNet-pretrained ResNet-50 plus a small MLP head.

    The backbone's convolutional stages are used as a feature extractor;
    its own ``fc`` layer is never called in ``forward``.  The features are
    fed to ``self.classifier``, which ends in a plain ``nn.Linear`` and
    therefore returns raw, un-normalised scores (no sigmoid/softmax).

    NOTE(review): despite the class name, this is a classification network,
    not the RetinaNet object detector.
    """

    def __init__(self, num_classes: int = 2, trainable_layers=("layer3",)):
        """Build the backbone, set which stages are trainable, add the head.

        Args:
            num_classes: Number of outputs of the final linear layer.
            trainable_layers: Names of backbone sub-modules (for example
                ``"layer3"`` or ``"layer4"``) whose parameters should be
                trainable.  Every other backbone parameter stays frozen.
                A single string or ``None`` is also accepted.  The default
                keeps the historical behaviour: only ``layer3`` is
                trainable and ``layer4`` remains frozen.

        Raises:
            ValueError: If a name in ``trainable_layers`` is not a
                sub-module of the backbone.
        """
        super().__init__()
        self.backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

        # Freeze the whole backbone first; selected stages are re-enabled below.
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Normalise the argument so that a bare string or None also works.
        if trainable_layers is None:
            trainable_layers = ()
        elif isinstance(trainable_layers, str):
            trainable_layers = (trainable_layers,)

        for layer_name in trainable_layers:
            stage = getattr(self.backbone, layer_name, None)
            if not isinstance(stage, nn.Module):
                raise ValueError(
                    f"Unknown backbone layer {layer_name!r} in trainable_layers"
                )
            for param in stage.parameters():
                param.requires_grad = True

        # Width of the pooled backbone features; read from the backbone's own
        # (unused) fc layer instead of hard-coding it.
        feature_dim = self.backbone.fc.in_features

        # Modified classifier head
        self.classifier = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone stages and the classifier head on ``x``.

        Args:
            x: Input batch handed to the backbone's first convolution
                (presumably ``(N, 3, H, W)`` images -- confirm with caller).

        Returns:
            The classifier output, one row per batch element with
            ``num_classes`` raw scores.
        """
        # Stem
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        # Residual stages
        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)

        # Global pooling, then flatten everything except the batch dimension.
        x = self.backbone.avgpool(x)
        x = torch.flatten(x, 1)

        x = self.classifier(x)
        return x