# NOTE(review): the three lines below are scrape artifacts from a Git hosting
# page (OS label, commit message, abbreviated commit hash) and are not valid
# Python; preserved as comments so the file parses.
# Ubuntu
# add application file
# c6827c1
import argparse
import os
import torch
from detector import *
from backbone import *
from loss import *
from data import Therin
import datetime
from detector.fasterRCNN import FasterRCNN
from backbone.densenet import DenseNet
from utils.engine import *
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone, _resnet_fpn_extractor
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_V2_Weights
def _build_arg_parser():
    """Build the command-line parser for model, training, and dataset options."""
    p = argparse.ArgumentParser("Intruder_Thermal_Dataset")
    # Model Settings
    p.add_argument('--detector', type=str, default='fasterRCNN', help='detector name')
    p.add_argument('--backbone', type=str, default='densenet', help='backbone name')
    p.add_argument('--loss', type=str, default='focalloss', help='loss name')
    p.add_argument('--modelscale', type=float, default=1.0, help='model scale')
    # Training Settings
    p.add_argument('--batch', type=int, default=4, help='batch size')
    p.add_argument('--epoch', type=int, default=10, help='epochs number')
    p.add_argument('--lr', type=float, default=1e-3, help='initial learning rate')
    p.add_argument('--momentum', type=float, default=0.9, help='momentum')
    p.add_argument('--decay', type=float, default=3e-4, help='weight decay')
    # Dataset Settings
    p.add_argument('--data_dir', type=str, default='./dataset', help='dataset dir')
    return p


# Module-level CLI state, kept for compatibility with existing references.
parser = _build_arg_parser()
args = parser.parse_args()
def main():
    """Train a Faster R-CNN detector on the Intruder Thermal dataset.

    Builds train/test loaders from the ``Therin`` dataset, creates a
    ResNet-18 + FPN backed Faster R-CNN, trains for ``args.epoch`` epochs
    with SGD + StepLR, evaluates mAP after each epoch, and saves a
    checkpoint (model/optimizer/scheduler state plus losses) per epoch
    into a timestamped directory.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using {} device training.".format(device.type))

    # Timestamped checkpoint directory (microsecond suffix avoids collisions).
    timestr = datetime.datetime.now().strftime("%Y%m%d-%H%M%S%f")
    print(timestr)
    model_save_dir = timestr
    # exist_ok avoids the check-then-create race of the exists()/makedirs pair.
    os.makedirs(model_save_dir, exist_ok=True)

    # NOTE(review): assumes 4 foreground classes + background — confirm
    # against the Therin label set.
    num_classes = 5

    # Load data
    train_dataset = Therin(args.data_dir, 'train')
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch,
        shuffle=True,
        num_workers=0,
        collate_fn=train_dataset.collate_fn)
    test_dataset = Therin(args.data_dir, 'test')
    # Fix: evaluation data must not be shuffled — order has no effect on mAP
    # but shuffling makes per-batch logs non-reproducible between runs.
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch,
        shuffle=False,
        num_workers=0,
        collate_fn=test_dataset.collate_fn)

    # Create model: randomly initialized ResNet-18 + FPN backbone.
    # Fix: use the keyword API (torchvision >= 0.13, which this file already
    # targets via _resnet_fpn_extractor / V2 weights imports) instead of the
    # deprecated positional pretrained flag.
    backbone = resnet_fpn_backbone(backbone_name='resnet18', weights=None)
    model = FasterRCNN(backbone, num_classes)
    model.to(device)

    # Optimize only parameters that require gradients.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=args.lr,
                                momentum=args.momentum, weight_decay=args.decay)
    # Decay LR by 10x every 3 epochs.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=3,
                                                   gamma=0.1)

    # Training loop
    for epoch in range(args.epoch):
        # train for one epoch
        loss_dict, total_loss = train_one_epoch(model, optimizer, train_dataloader,
                                                device, epoch, print_freq=1)
        # update the learning rate
        lr_scheduler.step()
        # evaluate on the test dataset
        _, mAP = evaluate(model, test_dataloader, device=device)
        print('validation mAp is {}'.format(mAP))
        # save a full resumable checkpoint for this epoch
        save_files = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'epoch': epoch,
            'loss_dict': loss_dict,
            'total_loss': total_loss}
        torch.save(save_files,
                   os.path.join(model_save_dir,
                                "{}-model-{}-mAp-{}.pth".format(args.backbone, epoch, mAP)))
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()