Ubuntu committed on
Commit c6827c1 · 1 Parent(s): 4a3d4e5

add application file

.gitignore CHANGED
@@ -1,6 +1,2 @@
  __pycache__
- model/dataset
- model/img
- model/loss
- model/PIE
  */__pycache__
img/sample/テスト画像１.JPG ADDED
img/sample/テスト画像２.JPG ADDED
img/sample/テスト画像３.JPG ADDED
model/___init__.py ADDED
File without changes
model/data.py ADDED
@@ -0,0 +1,58 @@
1
+ import torch
2
+ import numpy as np
3
+ from torchvision import transforms as T
4
+ from torch.utils.data import Dataset
5
+ from PIL import Image
6
+ from torchvision.ops.boxes import box_convert
7
+ import glob
8
+ import json
9
+
10
+ # dataset
11
+ # |_ train
12
+ # | |_ .BMP files
13
+ # | |_ annotation
14
+ # | |_ classes.txt (one class per line)
15
+ # | |_ .txt anno files (class x_center y_center width height)
16
+ # |_ test
17
+ # |_ val
18
+
19
+ class Therin(Dataset): # Therin: Intruder thermal dataset
20
+ def __init__(self, dir, set):
21
+ self._dir = dir + '/' + set
22
+ # self._imglist = glob.glob(self._dir + '/*.BMP')
23
+ self._json_path = dir + '/' + set + '.json'
24
+ with open(self._json_path) as anno_file:
25
+ coco_json = json.load(anno_file)
26
+ self._anno = coco_json["annotations"]
27
+ self._imglist = coco_json["images"]
28
+ self._transform = T.Compose([T.ToTensor()])
29
+
30
+ def __len__(self):
31
+ return len(self._imglist)
32
+
33
+ def __getitem__(self, index):
34
+ image = Image.open(self._dir + "/" + self._imglist[index]["file_name"])
35
+
36
+ boxes = np.zeros((1, 4), dtype=np.float32)
37
+ boxes[0] = self._anno[index]['bbox']
38
+ boxes = torch.as_tensor(boxes, dtype=torch.float32)
39
+ boxes = box_convert(boxes, in_fmt='xywh', out_fmt='xyxy')
40
+
41
+ gt_classes = np.zeros((1), dtype=np.int32)
42
+ gt_classes[0] = self._anno[index]['category_id']-1
43
+ gt_classes = torch.as_tensor(gt_classes, dtype=torch.int64)
44
+
45
+ image_id = self._anno[index]['image_id']
46
+ image_id = torch.as_tensor(image_id, dtype=torch.int64)
47
+
48
+ area = np.zeros((1), dtype=np.int32)
49
+ area[0] = self._anno[index]['area']
50
+ area = torch.as_tensor(area, dtype=torch.int64)
51
+
52
+ target = {"labels": gt_classes, "boxes": boxes, "image_id": image_id, "area": area, "iscrowd": torch.tensor([0])}
53
+
54
+ image = self._transform(image)
55
+ return image, target
56
+
57
+ def collate_fn(self, batch):
58
+ return tuple(zip(*batch))
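
A minimal usage sketch of the Therin dataset defined above, assuming the code is run from the model/ directory and that ./dataset/train.json plus the train images exist as described in the layout comment (the paths here are illustrative, not part of the commit):

import torch
from data import Therin

dataset = Therin("./dataset", "train")
image, target = dataset[0]
# image: CxHxW float tensor produced by T.ToTensor()
# target: {"boxes": xyxy float tensor, "labels", "image_id", "area", "iscrowd"}
print(image.shape, target["boxes"], target["labels"])

loader = torch.utils.data.DataLoader(
    dataset, batch_size=2, shuffle=True, collate_fn=dataset.collate_fn
)
images, targets = next(iter(loader))  # two tuples of length batch_size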
model/model/densenet-model-9-mAp--1.0.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:66644fd209b9317352a578b24af5e6d1cdcf7e0d34d6093588e876135732dd4e
3
+ size 225829605
model/train.py ADDED
@@ -0,0 +1,95 @@
1
+ import argparse
2
+ import os
3
+ import torch
4
+ from detector import *
5
+ from backbone import *
6
+ from loss import *
7
+ from data import Therin
8
+ import datetime
9
+ from detector.fasterRCNN import FasterRCNN
10
+ from backbone.densenet import DenseNet
11
+ from utils.engine import *
12
+ from torchvision.models.detection.backbone_utils import resnet_fpn_backbone, _resnet_fpn_extractor
13
+ from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_V2_Weights
14
+
15
+ parser = argparse.ArgumentParser("Intruder_Thermal_Dataset")
16
+
17
+ # Model Settings
18
+ parser.add_argument('--detector', type=str, default='fasterRCNN', help='detector name')
19
+ parser.add_argument('--backbone', type=str, default='densenet', help='backbone name')
20
+ parser.add_argument('--loss', type=str, default='focalloss', help='loss name')
21
+ parser.add_argument('--modelscale', type=float, default=1.0, help='model scale')
22
+
23
+ # Training Settings
24
+ parser.add_argument('--batch', type=int, default=4, help='batch size')
25
+ parser.add_argument('--epoch', type=int, default=10, help='epochs number')
26
+ parser.add_argument('--lr', type=float, default=1e-3, help='initial learning rate')
27
+ parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
28
+ parser.add_argument('--decay', type=float, default=3e-4, help='weight decay')
29
+
30
+ # Dataset Settings
31
+ parser.add_argument('--data_dir', type=str, default='./dataset', help='dataset dir')
32
+
33
+
34
+ args = parser.parse_args()
35
+
36
+ def main():
37
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
+ print("Using {} device training.".format(device.type))
39
+
40
+ timestr = datetime.datetime.now().strftime("%Y%m%d-%H%M%S%f")
41
+ print(timestr)
42
+ model_save_dir = timestr
43
+ if not os.path.exists(model_save_dir):
44
+ os.makedirs(model_save_dir)
45
+ num_classes = 5
46
+ # Load data
47
+ train_dataset = Therin(args.data_dir, 'train')
48
+ train_dataloader = torch.utils.data.DataLoader(train_dataset,
49
+ batch_size=args.batch,
50
+ shuffle=True,
51
+ num_workers=0,
52
+ collate_fn=train_dataset.collate_fn)
53
+ test_dataset = Therin(args.data_dir, 'test')
54
+ test_dataloader = torch.utils.data.DataLoader(test_dataset,
55
+ batch_size=args.batch,
56
+ shuffle=True,
57
+ num_workers=0,
58
+ collate_fn=test_dataset.collate_fn)
59
+
60
+
61
+ # Create model
62
+ backbone = resnet_fpn_backbone('resnet18', False)
63
+ model = FasterRCNN(backbone, num_classes)
64
+ model.to(device)
65
+
66
+ # Define optimizer
67
+ params = [p for p in model.parameters() if p.requires_grad]
68
+ optimizer = torch.optim.SGD(params, lr=args.lr,
69
+ momentum=args.momentum, weight_decay=args.decay)
70
+ lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
71
+ step_size=3,
72
+ gamma=0.1)
73
+
74
+ #Training
75
+ for epoch in range(args.epoch):
76
+ # train for one epoch
77
+ loss_dict, total_loss = train_one_epoch(model, optimizer, train_dataloader, device, epoch, print_freq=1)
78
+ # update the learning rate
79
+ lr_scheduler.step()
80
+ # evaluate on the test dataset
81
+ _, mAP = evaluate(model, test_dataloader, device=device)
82
+ print('validation mAp is {}'.format(mAP))
83
+ # save weights
84
+ save_files = {
85
+ 'model': model.state_dict(),
86
+ 'optimizer': optimizer.state_dict(),
87
+ 'lr_scheduler': lr_scheduler.state_dict(),
88
+ 'epoch': epoch,
89
+ 'loss_dict': loss_dict,
90
+ 'total_loss': total_loss}
91
+ torch.save(save_files,
92
+ os.path.join(model_save_dir, "{}-model-{}-mAp-{}.pth".format(args.backbone, epoch, mAP)))
93
+
94
+ if __name__ == '__main__':
95
+ main()
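
Each epoch above saves a dict bundling model, optimizer and scheduler state plus the epoch, loss dict and total loss. Below is a hedged sketch for reloading such a checkpoint; it assumes detector.fasterRCNN.FasterRCNN (not part of this commit) matches torchvision's FasterRCNN, and the file name is a placeholder for whatever "{backbone}-model-{epoch}-mAp-{mAP}.pth" file training produced:

import torch
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

backbone = resnet_fpn_backbone(backbone_name="resnet18", weights=None)  # mirrors the backbone built in main()
model = FasterRCNN(backbone, num_classes=5)

ckpt = torch.load("densenet-model-9-mAp-0.5.pth", map_location="cpu")  # hypothetical file name
model.load_state_dict(ckpt["model"])
model.eval()
print("epoch", ckpt["epoch"], "total loss", ckpt["total_loss"])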
model/utils/coco_eval.py ADDED
@@ -0,0 +1,192 @@
1
+ import copy
2
+ import io
3
+ from contextlib import redirect_stdout
4
+
5
+ import numpy as np
6
+ import pycocotools.mask as mask_util
7
+ import torch
8
+ from .utils import *
9
+ from pycocotools.coco import COCO
10
+ from pycocotools.cocoeval import COCOeval
11
+
12
+
13
+ class CocoEvaluator:
14
+ def __init__(self, coco_gt, iou_types):
15
+ if not isinstance(iou_types, (list, tuple)):
16
+ raise TypeError(f"This constructor expects iou_types of type list or tuple, instead got {type(iou_types)}")
17
+ coco_gt = copy.deepcopy(coco_gt)
18
+ self.coco_gt = coco_gt
19
+
20
+ self.iou_types = iou_types
21
+ self.coco_eval = {}
22
+ for iou_type in iou_types:
23
+ self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)
24
+
25
+ self.img_ids = []
26
+ self.eval_imgs = {k: [] for k in iou_types}
27
+
28
+ def update(self, predictions):
29
+ img_ids = list(np.unique(list(predictions.keys())))
30
+ self.img_ids.extend(img_ids)
31
+
32
+ for iou_type in self.iou_types:
33
+ results = self.prepare(predictions, iou_type)
34
+ with redirect_stdout(io.StringIO()):
35
+ coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
36
+ coco_eval = self.coco_eval[iou_type]
37
+
38
+ coco_eval.cocoDt = coco_dt
39
+ coco_eval.params.imgIds = list(img_ids)
40
+ img_ids, eval_imgs = evaluate(coco_eval)
41
+
42
+ self.eval_imgs[iou_type].append(eval_imgs)
43
+
44
+ def synchronize_between_processes(self):
45
+ for iou_type in self.iou_types:
46
+ self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
47
+ create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])
48
+
49
+ def accumulate(self):
50
+ for coco_eval in self.coco_eval.values():
51
+ coco_eval.accumulate()
52
+
53
+ def summarize(self):
54
+ for iou_type, coco_eval in self.coco_eval.items():
55
+ print(f"IoU metric: {iou_type}")
56
+ coco_eval.summarize()
57
+
58
+ def prepare(self, predictions, iou_type):
59
+ if iou_type == "bbox":
60
+ return self.prepare_for_coco_detection(predictions)
61
+ if iou_type == "segm":
62
+ return self.prepare_for_coco_segmentation(predictions)
63
+ if iou_type == "keypoints":
64
+ return self.prepare_for_coco_keypoint(predictions)
65
+ raise ValueError(f"Unknown iou type {iou_type}")
66
+
67
+ def prepare_for_coco_detection(self, predictions):
68
+ coco_results = []
69
+ for original_id, prediction in predictions.items():
70
+ if len(prediction) == 0:
71
+ continue
72
+
73
+ boxes = prediction["boxes"]
74
+ boxes = convert_to_xywh(boxes).tolist()
75
+ scores = prediction["scores"].tolist()
76
+ labels = prediction["labels"].tolist()
77
+
78
+ coco_results.extend(
79
+ [
80
+ {
81
+ "image_id": original_id,
82
+ "category_id": labels[k],
83
+ "bbox": box,
84
+ "score": scores[k],
85
+ }
86
+ for k, box in enumerate(boxes)
87
+ ]
88
+ )
89
+ return coco_results
90
+
91
+ def prepare_for_coco_segmentation(self, predictions):
92
+ coco_results = []
93
+ for original_id, prediction in predictions.items():
94
+ if len(prediction) == 0:
95
+ continue
96
+
97
+ scores = prediction["scores"]
98
+ labels = prediction["labels"]
99
+ masks = prediction["masks"]
100
+
101
+ masks = masks > 0.5
102
+
103
+ scores = prediction["scores"].tolist()
104
+ labels = prediction["labels"].tolist()
105
+
106
+ rles = [
107
+ mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] for mask in masks
108
+ ]
109
+ for rle in rles:
110
+ rle["counts"] = rle["counts"].decode("utf-8")
111
+
112
+ coco_results.extend(
113
+ [
114
+ {
115
+ "image_id": original_id,
116
+ "category_id": labels[k],
117
+ "segmentation": rle,
118
+ "score": scores[k],
119
+ }
120
+ for k, rle in enumerate(rles)
121
+ ]
122
+ )
123
+ return coco_results
124
+
125
+ def prepare_for_coco_keypoint(self, predictions):
126
+ coco_results = []
127
+ for original_id, prediction in predictions.items():
128
+ if len(prediction) == 0:
129
+ continue
130
+
131
+ boxes = prediction["boxes"]
132
+ boxes = convert_to_xywh(boxes).tolist()
133
+ scores = prediction["scores"].tolist()
134
+ labels = prediction["labels"].tolist()
135
+ keypoints = prediction["keypoints"]
136
+ keypoints = keypoints.flatten(start_dim=1).tolist()
137
+
138
+ coco_results.extend(
139
+ [
140
+ {
141
+ "image_id": original_id,
142
+ "category_id": labels[k],
143
+ "keypoints": keypoint,
144
+ "score": scores[k],
145
+ }
146
+ for k, keypoint in enumerate(keypoints)
147
+ ]
148
+ )
149
+ return coco_results
150
+
151
+
152
+ def convert_to_xywh(boxes):
153
+ xmin, ymin, xmax, ymax = boxes.unbind(1)
154
+ return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)
155
+
156
+
157
+ def merge(img_ids, eval_imgs):
158
+ all_img_ids = all_gather(img_ids)
159
+ all_eval_imgs = all_gather(eval_imgs)
160
+
161
+ merged_img_ids = []
162
+ for p in all_img_ids:
163
+ merged_img_ids.extend(p)
164
+
165
+ merged_eval_imgs = []
166
+ for p in all_eval_imgs:
167
+ merged_eval_imgs.append(p)
168
+
169
+ merged_img_ids = np.array(merged_img_ids)
170
+ merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)
171
+
172
+ # keep only unique (and in sorted order) images
173
+ merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
174
+ merged_eval_imgs = merged_eval_imgs[..., idx]
175
+
176
+ return merged_img_ids, merged_eval_imgs
177
+
178
+
179
+ def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
180
+ img_ids, eval_imgs = merge(img_ids, eval_imgs)
181
+ img_ids = list(img_ids)
182
+ eval_imgs = list(eval_imgs.flatten())
183
+
184
+ coco_eval.evalImgs = eval_imgs
185
+ coco_eval.params.imgIds = img_ids
186
+ coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
187
+
188
+
189
+ def evaluate(imgs):
190
+ with redirect_stdout(io.StringIO()):
191
+ imgs.evaluate()
192
+ return imgs.params.imgIds, np.asarray(imgs.evalImgs).reshape(-1, len(imgs.params.areaRng), len(imgs.params.imgIds))
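
For reference, this is the sequence in which CocoEvaluator is driven — the same flow used by evaluate() in engine.py further down. model, data_loader and device are assumed to already exist:

import torch
from utils.coco_eval import CocoEvaluator
from utils.coco_utils import get_coco_api_from_dataset

coco_gt = get_coco_api_from_dataset(data_loader.dataset)   # COCO-style ground truth
evaluator = CocoEvaluator(coco_gt, iou_types=["bbox"])

model.eval()
with torch.inference_mode():
    for images, targets in data_loader:
        outputs = model([img.to(device) for img in images])
        outputs = [{k: v.cpu() for k, v in o.items()} for o in outputs]
        res = {t["image_id"].item(): o for t, o in zip(targets, outputs)}  # keyed by image_id
        evaluator.update(res)

evaluator.synchronize_between_processes()
evaluator.accumulate()
evaluator.summarize()
mAP = evaluator.coco_eval["bbox"].stats[0]  # COCO AP @ IoU=0.50:0.95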
model/utils/coco_utils.py ADDED
@@ -0,0 +1,253 @@
1
+ import copy
2
+ import os
3
+
4
+ import torch
5
+ import torch.utils.data
6
+ import torchvision
7
+ from torchvision import transforms as T
8
+ from pycocotools import mask as coco_mask
9
+ from pycocotools.coco import COCO
10
+
11
+
12
+ class FilterAndRemapCocoCategories:
13
+ def __init__(self, categories, remap=True):
14
+ self.categories = categories
15
+ self.remap = remap
16
+
17
+ def __call__(self, image, target):
18
+ anno = target["annotations"]
19
+ anno = [obj for obj in anno if obj["category_id"] in self.categories]
20
+ if not self.remap:
21
+ target["annotations"] = anno
22
+ return image, target
23
+ anno = copy.deepcopy(anno)
24
+ for obj in anno:
25
+ obj["category_id"] = self.categories.index(obj["category_id"])
26
+ target["annotations"] = anno
27
+ return image, target
28
+
29
+
30
+ def convert_coco_poly_to_mask(segmentations, height, width):
31
+ masks = []
32
+ for polygons in segmentations:
33
+ rles = coco_mask.frPyObjects(polygons, height, width)
34
+ mask = coco_mask.decode(rles)
35
+ if len(mask.shape) < 3:
36
+ mask = mask[..., None]
37
+ mask = torch.as_tensor(mask, dtype=torch.uint8)
38
+ mask = mask.any(dim=2)
39
+ masks.append(mask)
40
+ if masks:
41
+ masks = torch.stack(masks, dim=0)
42
+ else:
43
+ masks = torch.zeros((0, height, width), dtype=torch.uint8)
44
+ return masks
45
+
46
+
47
+ class ConvertCocoPolysToMask:
48
+ def __call__(self, image, target):
49
+ w, h = image.size
50
+
51
+ image_id = target["image_id"]
52
+ image_id = torch.tensor([image_id])
53
+
54
+ anno = target["annotations"]
55
+
56
+ anno = [obj for obj in anno if obj["iscrowd"] == 0]
57
+
58
+ boxes = [obj["bbox"] for obj in anno]
59
+ # guard against no boxes via resizing
60
+ boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
61
+ boxes[:, 2:] += boxes[:, :2]
62
+ boxes[:, 0::2].clamp_(min=0, max=w)
63
+ boxes[:, 1::2].clamp_(min=0, max=h)
64
+
65
+ classes = [obj["category_id"] for obj in anno]
66
+ classes = torch.tensor(classes, dtype=torch.int64)
67
+
68
+ segmentations = [obj["segmentation"] for obj in anno]
69
+ masks = convert_coco_poly_to_mask(segmentations, h, w)
70
+
71
+ keypoints = None
72
+ if anno and "keypoints" in anno[0]:
73
+ keypoints = [obj["keypoints"] for obj in anno]
74
+ keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
75
+ num_keypoints = keypoints.shape[0]
76
+ if num_keypoints:
77
+ keypoints = keypoints.view(num_keypoints, -1, 3)
78
+
79
+ keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
80
+ boxes = boxes[keep]
81
+ classes = classes[keep]
82
+ masks = masks[keep]
83
+ if keypoints is not None:
84
+ keypoints = keypoints[keep]
85
+
86
+ target = {}
87
+ target["boxes"] = boxes
88
+ target["labels"] = classes
89
+ target["masks"] = masks
90
+ target["image_id"] = image_id
91
+ if keypoints is not None:
92
+ target["keypoints"] = keypoints
93
+
94
+ # for conversion to coco api
95
+ area = torch.tensor([obj["area"] for obj in anno])
96
+ iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])
97
+ target["area"] = area
98
+ target["iscrowd"] = iscrowd
99
+
100
+ return image, target
101
+
102
+
103
+ def _coco_remove_images_without_annotations(dataset, cat_list=None):
104
+ def _has_only_empty_bbox(anno):
105
+ return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
106
+
107
+ def _count_visible_keypoints(anno):
108
+ return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)
109
+
110
+ min_keypoints_per_image = 10
111
+
112
+ def _has_valid_annotation(anno):
113
+ # if it's empty, there is no annotation
114
+ if len(anno) == 0:
115
+ return False
116
+ # if all boxes have close to zero area, there is no annotation
117
+ if _has_only_empty_bbox(anno):
118
+ return False
119
+ # keypoint tasks have slightly different criteria for considering
120
+ # if an annotation is valid
121
+ if "keypoints" not in anno[0]:
122
+ return True
123
+ # for keypoint detection tasks, only consider valid images those
124
+ # containing at least min_keypoints_per_image
125
+ if _count_visible_keypoints(anno) >= min_keypoints_per_image:
126
+ return True
127
+ return False
128
+
129
+ if not isinstance(dataset, torchvision.datasets.CocoDetection):
130
+ raise TypeError(
131
+ f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
132
+ )
133
+ ids = []
134
+ for ds_idx, img_id in enumerate(dataset.ids):
135
+ ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
136
+ anno = dataset.coco.loadAnns(ann_ids)
137
+ if cat_list:
138
+ anno = [obj for obj in anno if obj["category_id"] in cat_list]
139
+ if _has_valid_annotation(anno):
140
+ ids.append(ds_idx)
141
+
142
+ dataset = torch.utils.data.Subset(dataset, ids)
143
+ return dataset
144
+
145
+
146
+ def convert_to_coco_api(ds):
147
+ coco_ds = COCO()
148
+ # annotation IDs need to start at 1, not 0, see torchvision issue #1530
149
+ ann_id = 1
150
+ dataset = {"images": [], "categories": [], "annotations": []}
151
+ categories = set()
152
+ for img_idx in range(len(ds)):
153
+ # find better way to get target
154
+ # targets = ds.get_annotations(img_idx)
155
+ img, targets = ds[img_idx]
156
+ image_id = targets["image_id"].item()
157
+ img_dict = {}
158
+ img_dict["id"] = image_id
159
+ img_dict["height"] = img.shape[-2]
160
+ img_dict["width"] = img.shape[-1]
161
+ dataset["images"].append(img_dict)
162
+ bboxes = targets["boxes"].clone()
163
+ bboxes[:, 2:] -= bboxes[:, :2]
164
+ bboxes = bboxes.tolist()
165
+ labels = targets["labels"].tolist()
166
+ areas = targets["area"].tolist()
167
+ iscrowd = targets["iscrowd"].tolist()
168
+ if "masks" in targets:
169
+ masks = targets["masks"]
170
+ # make masks Fortran contiguous for coco_mask
171
+ masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
172
+ if "keypoints" in targets:
173
+ keypoints = targets["keypoints"]
174
+ keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist()
175
+ num_objs = len(bboxes)
176
+ for i in range(num_objs-1):
177
+ print(i)
178
+ ann = {}
179
+ ann["image_id"] = image_id
180
+ ann["bbox"] = bboxes[i]
181
+ ann["category_id"] = labels[i]
182
+ categories.add(labels[i])
183
+ ann["area"] = areas[i]
184
+ ann["iscrowd"] = iscrowd[i]
185
+ ann["id"] = ann_id
186
+ if "masks" in targets:
187
+ ann["segmentation"] = coco_mask.encode(masks[i].numpy())
188
+ if "keypoints" in targets:
189
+ ann["keypoints"] = keypoints[i]
190
+ ann["num_keypoints"] = sum(k != 0 for k in keypoints[i][2::3])
191
+ dataset["annotations"].append(ann)
192
+ ann_id += 1
193
+ dataset["categories"] = [{"id": i} for i in sorted(categories)]
194
+ coco_ds.dataset = dataset
195
+ coco_ds.createIndex()
196
+ return coco_ds
197
+
198
+
199
+ def get_coco_api_from_dataset(dataset):
200
+ for _ in range(10):
201
+ if isinstance(dataset, torchvision.datasets.CocoDetection):
202
+ break
203
+ if isinstance(dataset, torch.utils.data.Subset):
204
+ dataset = dataset.dataset
205
+ if isinstance(dataset, torchvision.datasets.CocoDetection):
206
+ return dataset.coco
207
+ return convert_to_coco_api(dataset)
208
+
209
+
210
+ class CocoDetection(torchvision.datasets.CocoDetection):
211
+ def __init__(self, img_folder, ann_file, transforms):
212
+ super().__init__(img_folder, ann_file)
213
+ self._transforms = transforms
214
+
215
+ def __getitem__(self, idx):
216
+ img, target = super().__getitem__(idx)
217
+ image_id = self.ids[idx]
218
+ target = dict(image_id=image_id, annotations=target)
219
+ if self._transforms is not None:
220
+ img, target = self._transforms(img, target)
221
+ return img, target
222
+
223
+
224
+ def get_coco(root, image_set, transforms, mode="instances"):
225
+ anno_file_template = "{}_{}2017.json"
226
+ PATHS = {
227
+ "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))),
228
+ "val": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))),
229
+ # "train": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val")))
230
+ }
231
+
232
+ t = [ConvertCocoPolysToMask()]
233
+
234
+ if transforms is not None:
235
+ t.append(transforms)
236
+ transforms = T.Compose(t)
237
+
238
+ img_folder, ann_file = PATHS[image_set]
239
+ img_folder = os.path.join(root, img_folder)
240
+ ann_file = os.path.join(root, ann_file)
241
+
242
+ dataset = CocoDetection(img_folder, ann_file, transforms=transforms)
243
+
244
+ if image_set == "train":
245
+ dataset = _coco_remove_images_without_annotations(dataset)
246
+
247
+ # dataset = torch.utils.data.Subset(dataset, [i for i in range(500)])
248
+
249
+ return dataset
250
+
251
+
252
+ def get_coco_kp(root, image_set, transforms):
253
+ return get_coco(root, image_set, transforms, mode="person_keypoints")
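
Because convert_to_coco_api only needs the target dict used throughout this repo (boxes in xyxy, labels, image_id, area, iscrowd), the custom Therin dataset can be turned into a pycocotools COCO object without writing an intermediate JSON file. A small sketch, with an illustrative dataset path:

from data import Therin
from utils.coco_utils import get_coco_api_from_dataset

test_dataset = Therin("./dataset", "test")           # illustrative path
coco_gt = get_coco_api_from_dataset(test_dataset)    # falls back to convert_to_coco_api()
print(len(coco_gt.dataset["images"]), "images,",
      len(coco_gt.dataset["annotations"]), "annotations")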
model/utils/engine.py ADDED
@@ -0,0 +1,120 @@
1
+ import math
2
+ import sys
3
+ import time
4
+
5
+ import torch
6
+ import torchvision.models.detection.mask_rcnn
7
+ from .utils import *
8
+ from .coco_eval import CocoEvaluator
9
+ from .coco_utils import get_coco_api_from_dataset
10
+
11
+
12
+ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
13
+ model.train()
14
+ metric_logger = MetricLogger(delimiter=" ")
15
+ metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
16
+ header = f"Epoch: [{epoch}]"
17
+
18
+ lr_scheduler = None
19
+ if epoch == 0:
20
+ warmup_factor = 1.0 / 1000
21
+ warmup_iters = min(1000, len(data_loader) - 1)
22
+
23
+ lr_scheduler = torch.optim.lr_scheduler.LinearLR(
24
+ optimizer, start_factor=warmup_factor, total_iters=warmup_iters
25
+ )
26
+
27
+ for images, targets in metric_logger.log_every(data_loader, print_freq, header):
28
+ images = list(image.to(device) for image in images)
29
+ targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
30
+
31
+
32
+ with torch.cuda.amp.autocast(enabled=scaler is not None):
33
+ loss_dict = model(images, targets)
34
+ losses = sum(loss for loss in loss_dict.values())
35
+
36
+ # reduce losses over all GPUs for logging purposes
37
+ loss_dict_reduced = reduce_dict(loss_dict)
38
+ losses_reduced = sum(loss for loss in loss_dict_reduced.values())
39
+
40
+ loss_value = losses_reduced.item()
41
+
42
+ if not math.isfinite(loss_value):
43
+ print(f"Loss is {loss_value}, stopping training")
44
+ print(loss_dict_reduced)
45
+ sys.exit(1)
46
+
47
+ optimizer.zero_grad()
48
+ if scaler is not None:
49
+ scaler.scale(losses).backward()
50
+ scaler.step(optimizer)
51
+ scaler.update()
52
+ else:
53
+ losses.backward()
54
+ optimizer.step()
55
+
56
+ if lr_scheduler is not None:
57
+ lr_scheduler.step()
58
+
59
+ metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
60
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
61
+
62
+ return loss_dict, losses
63
+
64
+
65
+ def _get_iou_types(model):
66
+ model_without_ddp = model
67
+ if isinstance(model, torch.nn.parallel.DistributedDataParallel):
68
+ model_without_ddp = model.module
69
+ iou_types = ["bbox"]
70
+ if isinstance(model_without_ddp, torchvision.models.detection.MaskRCNN):
71
+ iou_types.append("segm")
72
+ if isinstance(model_without_ddp, torchvision.models.detection.KeypointRCNN):
73
+ iou_types.append("keypoints")
74
+ return iou_types
75
+
76
+
77
+ @torch.inference_mode()
78
+ def evaluate(model, data_loader, device):
79
+ n_threads = torch.get_num_threads()
80
+ # FIXME remove this and make paste_masks_in_image run on the GPU
81
+ torch.set_num_threads(1)
82
+ cpu_device = torch.device("cpu")
83
+ model.eval()
84
+ metric_logger = MetricLogger(delimiter=" ")
85
+ header = "Test:"
86
+
87
+ coco = get_coco_api_from_dataset(data_loader.dataset)
88
+ iou_types = _get_iou_types(model)
89
+ coco_evaluator = CocoEvaluator(coco, iou_types)
90
+
91
+ for images, targets in metric_logger.log_every(data_loader, 100, header):
92
+ images = list(img.to(device) for img in images)
93
+
94
+ if torch.cuda.is_available():
95
+ torch.cuda.synchronize()
96
+ model_time = time.time()
97
+ outputs = model(images)
98
+
99
+ outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
100
+ model_time = time.time() - model_time
101
+
102
+ res = {target["image_id"].item(): output for target, output in zip(targets, outputs)}
103
+ evaluator_time = time.time()
104
+ coco_evaluator.update(res)
105
+ evaluator_time = time.time() - evaluator_time
106
+ metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)
107
+
108
+ # gather the stats from all processes
109
+ metric_logger.synchronize_between_processes()
110
+ print("Averaged stats:", metric_logger)
111
+ coco_evaluator.synchronize_between_processes()
112
+
113
+ # accumulate predictions from all images
114
+ coco_evaluator.accumulate()
115
+ coco_evaluator.summarize()
116
+ torch.set_num_threads(n_threads)
117
+
118
+ mAP = coco_evaluator.coco_eval[iou_types[0]].stats[0]
119
+
120
+ return coco_evaluator, mAP
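
train_one_epoch takes an optional GradScaler, so mixed-precision training is a one-argument change. A hedged sketch, with model, optimizer, train_dataloader and device set up as in train.py:

import torch
from utils.engine import train_one_epoch

scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None
for epoch in range(10):
    loss_dict, total_loss = train_one_epoch(
        model, optimizer, train_dataloader, device, epoch, print_freq=10, scaler=scaler
    )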
model/utils/transforms.py ADDED
@@ -0,0 +1,597 @@
1
+ from typing import Dict, List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torchvision
5
+ from torch import nn, Tensor
6
+ from torchvision import ops
7
+ from torchvision.transforms import functional as F, InterpolationMode, transforms as T
8
+
9
+
10
+ def _flip_coco_person_keypoints(kps, width):
11
+ flip_inds = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
12
+ flipped_data = kps[:, flip_inds]
13
+ flipped_data[..., 0] = width - flipped_data[..., 0]
14
+ # Maintain COCO convention that if visibility == 0, then x, y = 0
15
+ inds = flipped_data[..., 2] == 0
16
+ flipped_data[inds] = 0
17
+ return flipped_data
18
+
19
+
20
+ class Compose:
21
+ def __init__(self, transforms):
22
+ self.transforms = transforms
23
+
24
+ def __call__(self, image, target):
25
+ for t in self.transforms:
26
+ image, target = t(image, target)
27
+ return image, target
28
+
29
+ class ToTensor(object):
30
+ def __call__(self, image, target):
31
+ image = F.to_tensor(image)
32
+ return image, target
33
+
34
+ class RandomHorizontalFlip(T.RandomHorizontalFlip):
35
+ def forward(
36
+ self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
37
+ ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
38
+ if torch.rand(1) < self.p:
39
+ image = F.hflip(image)
40
+ if target is not None:
41
+ _, _, width = F.get_dimensions(image)
42
+ target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]]
43
+ if "masks" in target:
44
+ target["masks"] = target["masks"].flip(-1)
45
+ if "keypoints" in target:
46
+ keypoints = target["keypoints"]
47
+ keypoints = _flip_coco_person_keypoints(keypoints, width)
48
+ target["keypoints"] = keypoints
49
+ return image, target
50
+
51
+
52
+ class PILToTensor(nn.Module):
53
+ def forward(
54
+ self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
55
+ ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
56
+ image = F.pil_to_tensor(image)
57
+ return image, target
58
+
59
+
60
+ class ConvertImageDtype(nn.Module):
61
+ def __init__(self, dtype: torch.dtype) -> None:
62
+ super().__init__()
63
+ self.dtype = dtype
64
+
65
+ def forward(
66
+ self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
67
+ ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
68
+ image = F.convert_image_dtype(image, self.dtype)
69
+ return image, target
70
+
71
+
72
+ class RandomIoUCrop(nn.Module):
73
+ def __init__(
74
+ self,
75
+ min_scale: float = 0.3,
76
+ max_scale: float = 1.0,
77
+ min_aspect_ratio: float = 0.5,
78
+ max_aspect_ratio: float = 2.0,
79
+ sampler_options: Optional[List[float]] = None,
80
+ trials: int = 40,
81
+ ):
82
+ super().__init__()
83
+ # Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174
84
+ self.min_scale = min_scale
85
+ self.max_scale = max_scale
86
+ self.min_aspect_ratio = min_aspect_ratio
87
+ self.max_aspect_ratio = max_aspect_ratio
88
+ if sampler_options is None:
89
+ sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]
90
+ self.options = sampler_options
91
+ self.trials = trials
92
+
93
+ def forward(
94
+ self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
95
+ ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
96
+ if target is None:
97
+ raise ValueError("The targets can't be None for this transform.")
98
+
99
+ if isinstance(image, torch.Tensor):
100
+ if image.ndimension() not in {2, 3}:
101
+ raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
102
+ elif image.ndimension() == 2:
103
+ image = image.unsqueeze(0)
104
+
105
+ _, orig_h, orig_w = F.get_dimensions(image)
106
+
107
+ while True:
108
+ # sample an option
109
+ idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
110
+ min_jaccard_overlap = self.options[idx]
111
+ if min_jaccard_overlap >= 1.0: # a value larger than 1 encodes the leave as-is option
112
+ return image, target
113
+
114
+ for _ in range(self.trials):
115
+ # check the aspect ratio limitations
116
+ r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2)
117
+ new_w = int(orig_w * r[0])
118
+ new_h = int(orig_h * r[1])
119
+ aspect_ratio = new_w / new_h
120
+ if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio):
121
+ continue
122
+
123
+ # check for 0 area crops
124
+ r = torch.rand(2)
125
+ left = int((orig_w - new_w) * r[0])
126
+ top = int((orig_h - new_h) * r[1])
127
+ right = left + new_w
128
+ bottom = top + new_h
129
+ if left == right or top == bottom:
130
+ continue
131
+
132
+ # check for any valid boxes with centers within the crop area
133
+ cx = 0.5 * (target["boxes"][:, 0] + target["boxes"][:, 2])
134
+ cy = 0.5 * (target["boxes"][:, 1] + target["boxes"][:, 3])
135
+ is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
136
+ if not is_within_crop_area.any():
137
+ continue
138
+
139
+ # check at least 1 box with jaccard limitations
140
+ boxes = target["boxes"][is_within_crop_area]
141
+ ious = torchvision.ops.boxes.box_iou(
142
+ boxes, torch.tensor([[left, top, right, bottom]], dtype=boxes.dtype, device=boxes.device)
143
+ )
144
+ if ious.max() < min_jaccard_overlap:
145
+ continue
146
+
147
+ # keep only valid boxes and perform cropping
148
+ target["boxes"] = boxes
149
+ target["labels"] = target["labels"][is_within_crop_area]
150
+ target["boxes"][:, 0::2] -= left
151
+ target["boxes"][:, 1::2] -= top
152
+ target["boxes"][:, 0::2].clamp_(min=0, max=new_w)
153
+ target["boxes"][:, 1::2].clamp_(min=0, max=new_h)
154
+ image = F.crop(image, top, left, new_h, new_w)
155
+
156
+ return image, target
157
+
158
+
159
+ class RandomZoomOut(nn.Module):
160
+ def __init__(
161
+ self, fill: Optional[List[float]] = None, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5
162
+ ):
163
+ super().__init__()
164
+ if fill is None:
165
+ fill = [0.0, 0.0, 0.0]
166
+ self.fill = fill
167
+ self.side_range = side_range
168
+ if side_range[0] < 1.0 or side_range[0] > side_range[1]:
169
+ raise ValueError(f"Invalid canvas side range provided {side_range}.")
170
+ self.p = p
171
+
172
+ @torch.jit.unused
173
+ def _get_fill_value(self, is_pil):
174
+ # type: (bool) -> int
175
+ # We fake the type to make it work on JIT
176
+ return tuple(int(x) for x in self.fill) if is_pil else 0
177
+
178
+ def forward(
179
+ self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
180
+ ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
181
+ if isinstance(image, torch.Tensor):
182
+ if image.ndimension() not in {2, 3}:
183
+ raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
184
+ elif image.ndimension() == 2:
185
+ image = image.unsqueeze(0)
186
+
187
+ if torch.rand(1) >= self.p:
188
+ return image, target
189
+
190
+ _, orig_h, orig_w = F.get_dimensions(image)
191
+
192
+ r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
193
+ canvas_width = int(orig_w * r)
194
+ canvas_height = int(orig_h * r)
195
+
196
+ r = torch.rand(2)
197
+ left = int((canvas_width - orig_w) * r[0])
198
+ top = int((canvas_height - orig_h) * r[1])
199
+ right = canvas_width - (left + orig_w)
200
+ bottom = canvas_height - (top + orig_h)
201
+
202
+ if torch.jit.is_scripting():
203
+ fill = 0
204
+ else:
205
+ fill = self._get_fill_value(F._is_pil_image(image))
206
+
207
+ image = F.pad(image, [left, top, right, bottom], fill=fill)
208
+ if isinstance(image, torch.Tensor):
209
+ # PyTorch's pad supports only integers on fill. So we need to overwrite the colour
210
+ v = torch.tensor(self.fill, device=image.device, dtype=image.dtype).view(-1, 1, 1)
211
+ image[..., :top, :] = image[..., :, :left] = image[..., (top + orig_h) :, :] = image[
212
+ ..., :, (left + orig_w) :
213
+ ] = v
214
+
215
+ if target is not None:
216
+ target["boxes"][:, 0::2] += left
217
+ target["boxes"][:, 1::2] += top
218
+
219
+ return image, target
220
+
221
+
222
+ class RandomPhotometricDistort(nn.Module):
223
+ def __init__(
224
+ self,
225
+ contrast: Tuple[float, float] = (0.5, 1.5),
226
+ saturation: Tuple[float, float] = (0.5, 1.5),
227
+ hue: Tuple[float, float] = (-0.05, 0.05),
228
+ brightness: Tuple[float, float] = (0.875, 1.125),
229
+ p: float = 0.5,
230
+ ):
231
+ super().__init__()
232
+ self._brightness = T.ColorJitter(brightness=brightness)
233
+ self._contrast = T.ColorJitter(contrast=contrast)
234
+ self._hue = T.ColorJitter(hue=hue)
235
+ self._saturation = T.ColorJitter(saturation=saturation)
236
+ self.p = p
237
+
238
+ def forward(
239
+ self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
240
+ ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
241
+ if isinstance(image, torch.Tensor):
242
+ if image.ndimension() not in {2, 3}:
243
+ raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
244
+ elif image.ndimension() == 2:
245
+ image = image.unsqueeze(0)
246
+
247
+ r = torch.rand(7)
248
+
249
+ if r[0] < self.p:
250
+ image = self._brightness(image)
251
+
252
+ contrast_before = r[1] < 0.5
253
+ if contrast_before:
254
+ if r[2] < self.p:
255
+ image = self._contrast(image)
256
+
257
+ if r[3] < self.p:
258
+ image = self._saturation(image)
259
+
260
+ if r[4] < self.p:
261
+ image = self._hue(image)
262
+
263
+ if not contrast_before:
264
+ if r[5] < self.p:
265
+ image = self._contrast(image)
266
+
267
+ if r[6] < self.p:
268
+ channels, _, _ = F.get_dimensions(image)
269
+ permutation = torch.randperm(channels)
270
+
271
+ is_pil = F._is_pil_image(image)
272
+ if is_pil:
273
+ image = F.pil_to_tensor(image)
274
+ image = F.convert_image_dtype(image)
275
+ image = image[..., permutation, :, :]
276
+ if is_pil:
277
+ image = F.to_pil_image(image)
278
+
279
+ return image, target
280
+
281
+
282
+ class ScaleJitter(nn.Module):
283
+ """Randomly resizes the image and its bounding boxes within the specified scale range.
284
+ The class implements the Scale Jitter augmentation as described in the paper
285
+ `"Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation" <https://arxiv.org/abs/2012.07177>`_.
286
+
287
+ Args:
288
+ target_size (tuple of ints): The target size for the transform provided in (height, width) format.
289
+ scale_range (tuple of ints): scaling factor interval, e.g (a, b), then scale is randomly sampled from the
290
+ range a <= scale <= b.
291
+ interpolation (InterpolationMode): Desired interpolation enum defined by
292
+ :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
293
+ """
294
+
295
+ def __init__(
296
+ self,
297
+ target_size: Tuple[int, int],
298
+ scale_range: Tuple[float, float] = (0.1, 2.0),
299
+ interpolation: InterpolationMode = InterpolationMode.BILINEAR,
300
+ ):
301
+ super().__init__()
302
+ self.target_size = target_size
303
+ self.scale_range = scale_range
304
+ self.interpolation = interpolation
305
+
306
+ def forward(
307
+ self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
308
+ ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
309
+ if isinstance(image, torch.Tensor):
310
+ if image.ndimension() not in {2, 3}:
311
+ raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
312
+ elif image.ndimension() == 2:
313
+ image = image.unsqueeze(0)
314
+
315
+ _, orig_height, orig_width = F.get_dimensions(image)
316
+
317
+ scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
318
+ r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale
319
+ new_width = int(orig_width * r)
320
+ new_height = int(orig_height * r)
321
+
322
+ image = F.resize(image, [new_height, new_width], interpolation=self.interpolation)
323
+
324
+ if target is not None:
325
+ target["boxes"][:, 0::2] *= new_width / orig_width
326
+ target["boxes"][:, 1::2] *= new_height / orig_height
327
+ if "masks" in target:
328
+ target["masks"] = F.resize(
329
+ target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST
330
+ )
331
+
332
+ return image, target
333
+
334
+
335
+ class FixedSizeCrop(nn.Module):
336
+ def __init__(self, size, fill=0, padding_mode="constant"):
337
+ super().__init__()
338
+ size = tuple(T._setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
339
+ self.crop_height = size[0]
340
+ self.crop_width = size[1]
341
+ self.fill = fill # TODO: Fill is currently respected only on PIL. Apply tensor patch.
342
+ self.padding_mode = padding_mode
343
+
344
+ def _pad(self, img, target, padding):
345
+ # Taken from the functional_tensor.py pad
346
+ if isinstance(padding, int):
347
+ pad_left = pad_right = pad_top = pad_bottom = padding
348
+ elif len(padding) == 1:
349
+ pad_left = pad_right = pad_top = pad_bottom = padding[0]
350
+ elif len(padding) == 2:
351
+ pad_left = pad_right = padding[0]
352
+ pad_top = pad_bottom = padding[1]
353
+ else:
354
+ pad_left = padding[0]
355
+ pad_top = padding[1]
356
+ pad_right = padding[2]
357
+ pad_bottom = padding[3]
358
+
359
+ padding = [pad_left, pad_top, pad_right, pad_bottom]
360
+ img = F.pad(img, padding, self.fill, self.padding_mode)
361
+ if target is not None:
362
+ target["boxes"][:, 0::2] += pad_left
363
+ target["boxes"][:, 1::2] += pad_top
364
+ if "masks" in target:
365
+ target["masks"] = F.pad(target["masks"], padding, 0, "constant")
366
+
367
+ return img, target
368
+
369
+ def _crop(self, img, target, top, left, height, width):
370
+ img = F.crop(img, top, left, height, width)
371
+ if target is not None:
372
+ boxes = target["boxes"]
373
+ boxes[:, 0::2] -= left
374
+ boxes[:, 1::2] -= top
375
+ boxes[:, 0::2].clamp_(min=0, max=width)
376
+ boxes[:, 1::2].clamp_(min=0, max=height)
377
+
378
+ is_valid = (boxes[:, 0] < boxes[:, 2]) & (boxes[:, 1] < boxes[:, 3])
379
+
380
+ target["boxes"] = boxes[is_valid]
381
+ target["labels"] = target["labels"][is_valid]
382
+ if "masks" in target:
383
+ target["masks"] = F.crop(target["masks"][is_valid], top, left, height, width)
384
+
385
+ return img, target
386
+
387
+ def forward(self, img, target=None):
388
+ _, height, width = F.get_dimensions(img)
389
+ new_height = min(height, self.crop_height)
390
+ new_width = min(width, self.crop_width)
391
+
392
+ if new_height != height or new_width != width:
393
+ offset_height = max(height - self.crop_height, 0)
394
+ offset_width = max(width - self.crop_width, 0)
395
+
396
+ r = torch.rand(1)
397
+ top = int(offset_height * r)
398
+ left = int(offset_width * r)
399
+
400
+ img, target = self._crop(img, target, top, left, new_height, new_width)
401
+
402
+ pad_bottom = max(self.crop_height - new_height, 0)
403
+ pad_right = max(self.crop_width - new_width, 0)
404
+ if pad_bottom != 0 or pad_right != 0:
405
+ img, target = self._pad(img, target, [0, 0, pad_right, pad_bottom])
406
+
407
+ return img, target
408
+
409
+
410
+ class RandomShortestSize(nn.Module):
411
+ def __init__(
412
+ self,
413
+ min_size: Union[List[int], Tuple[int], int],
414
+ max_size: int,
415
+ interpolation: InterpolationMode = InterpolationMode.BILINEAR,
416
+ ):
417
+ super().__init__()
418
+ self.min_size = [min_size] if isinstance(min_size, int) else list(min_size)
419
+ self.max_size = max_size
420
+ self.interpolation = interpolation
421
+
422
+ def forward(
423
+ self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
424
+ ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
425
+ _, orig_height, orig_width = F.get_dimensions(image)
426
+
427
+ min_size = self.min_size[torch.randint(len(self.min_size), (1,)).item()]
428
+ r = min(min_size / min(orig_height, orig_width), self.max_size / max(orig_height, orig_width))
429
+
430
+ new_width = int(orig_width * r)
431
+ new_height = int(orig_height * r)
432
+
433
+ image = F.resize(image, [new_height, new_width], interpolation=self.interpolation)
434
+
435
+ if target is not None:
436
+ target["boxes"][:, 0::2] *= new_width / orig_width
437
+ target["boxes"][:, 1::2] *= new_height / orig_height
438
+ if "masks" in target:
439
+ target["masks"] = F.resize(
440
+ target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST
441
+ )
442
+
443
+ return image, target
444
+
445
+
446
+ def _copy_paste(
447
+ image: torch.Tensor,
448
+ target: Dict[str, Tensor],
449
+ paste_image: torch.Tensor,
450
+ paste_target: Dict[str, Tensor],
451
+ blending: bool = True,
452
+ resize_interpolation: F.InterpolationMode = F.InterpolationMode.BILINEAR,
453
+ ) -> Tuple[torch.Tensor, Dict[str, Tensor]]:
454
+
455
+ # Random paste targets selection:
456
+ num_masks = len(paste_target["masks"])
457
+
458
+ if num_masks < 1:
459
+ # Such a degenerate case with num_masks=0 can happen with LSJ
460
+ # Let's just return (image, target)
461
+ return image, target
462
+
463
+ # We have to please torch script by explicitly specifying dtype as torch.long
464
+ random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device)
465
+ random_selection = torch.unique(random_selection).to(torch.long)
466
+
467
+ paste_masks = paste_target["masks"][random_selection]
468
+ paste_boxes = paste_target["boxes"][random_selection]
469
+ paste_labels = paste_target["labels"][random_selection]
470
+
471
+ masks = target["masks"]
472
+
473
+ # We resize source and paste data if they have different sizes
474
+ # This is something we introduced here as originally the algorithm works
475
+ # on equal-sized data (for example, coming from LSJ data augmentations)
476
+ size1 = image.shape[-2:]
477
+ size2 = paste_image.shape[-2:]
478
+ if size1 != size2:
479
+ paste_image = F.resize(paste_image, size1, interpolation=resize_interpolation)
480
+ paste_masks = F.resize(paste_masks, size1, interpolation=F.InterpolationMode.NEAREST)
481
+ # resize bboxes:
482
+ ratios = torch.tensor((size1[1] / size2[1], size1[0] / size2[0]), device=paste_boxes.device)
483
+ paste_boxes = paste_boxes.view(-1, 2, 2).mul(ratios).view(paste_boxes.shape)
484
+
485
+ paste_alpha_mask = paste_masks.sum(dim=0) > 0
486
+
487
+ if blending:
488
+ paste_alpha_mask = F.gaussian_blur(
489
+ paste_alpha_mask.unsqueeze(0),
490
+ kernel_size=(5, 5),
491
+ sigma=[
492
+ 2.0,
493
+ ],
494
+ )
495
+
496
+ # Copy-paste images:
497
+ image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask)
498
+
499
+ # Copy-paste masks:
500
+ masks = masks * (~paste_alpha_mask)
501
+ non_all_zero_masks = masks.sum((-1, -2)) > 0
502
+ masks = masks[non_all_zero_masks]
503
+
504
+ # Do a shallow copy of the target dict
505
+ out_target = {k: v for k, v in target.items()}
506
+
507
+ out_target["masks"] = torch.cat([masks, paste_masks])
508
+
509
+ # Copy-paste boxes and labels
510
+ boxes = ops.masks_to_boxes(masks)
511
+ out_target["boxes"] = torch.cat([boxes, paste_boxes])
512
+
513
+ labels = target["labels"][non_all_zero_masks]
514
+ out_target["labels"] = torch.cat([labels, paste_labels])
515
+
516
+ # Update additional optional keys: area and iscrowd if exist
517
+ if "area" in target:
518
+ out_target["area"] = out_target["masks"].sum((-1, -2)).to(torch.float32)
519
+
520
+ if "iscrowd" in target and "iscrowd" in paste_target:
521
+ # target['iscrowd'] size can be differ from mask size (non_all_zero_masks)
522
+ # For example, if previous transforms geometrically modifies masks/boxes/labels but
523
+ # does not update "iscrowd"
524
+ if len(target["iscrowd"]) == len(non_all_zero_masks):
525
+ iscrowd = target["iscrowd"][non_all_zero_masks]
526
+ paste_iscrowd = paste_target["iscrowd"][random_selection]
527
+ out_target["iscrowd"] = torch.cat([iscrowd, paste_iscrowd])
528
+
529
+ # Check for degenerated boxes and remove them
530
+ boxes = out_target["boxes"]
531
+ degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
532
+ if degenerate_boxes.any():
533
+ valid_targets = ~degenerate_boxes.any(dim=1)
534
+
535
+ out_target["boxes"] = boxes[valid_targets]
536
+ out_target["masks"] = out_target["masks"][valid_targets]
537
+ out_target["labels"] = out_target["labels"][valid_targets]
538
+
539
+ if "area" in out_target:
540
+ out_target["area"] = out_target["area"][valid_targets]
541
+ if "iscrowd" in out_target and len(out_target["iscrowd"]) == len(valid_targets):
542
+ out_target["iscrowd"] = out_target["iscrowd"][valid_targets]
543
+
544
+ return image, out_target
545
+
546
+
547
+ class SimpleCopyPaste(torch.nn.Module):
548
+ def __init__(self, blending=True, resize_interpolation=F.InterpolationMode.BILINEAR):
549
+ super().__init__()
550
+ self.resize_interpolation = resize_interpolation
551
+ self.blending = blending
552
+
553
+ def forward(
554
+ self, images: List[torch.Tensor], targets: List[Dict[str, Tensor]]
555
+ ) -> Tuple[List[torch.Tensor], List[Dict[str, Tensor]]]:
556
+ torch._assert(
557
+ isinstance(images, (list, tuple)) and all([isinstance(v, torch.Tensor) for v in images]),
558
+ "images should be a list of tensors",
559
+ )
560
+ torch._assert(
561
+ isinstance(targets, (list, tuple)) and len(images) == len(targets),
562
+ "targets should be a list of the same size as images",
563
+ )
564
+ for target in targets:
565
+ # Can not check for instance type dict with inside torch.jit.script
566
+ # torch._assert(isinstance(target, dict), "targets item should be a dict")
567
+ for k in ["masks", "boxes", "labels"]:
568
+ torch._assert(k in target, f"Key {k} should be present in targets")
569
+ torch._assert(isinstance(target[k], torch.Tensor), f"Value for the key {k} should be a tensor")
570
+
571
+ # images = [t1, t2, ..., tN]
572
+ # Let's define paste_images as shifted list of input images
573
+ # paste_images = [t2, t3, ..., tN, t1]
574
+ # FYI: in TF they mix data on the dataset level
575
+ images_rolled = images[-1:] + images[:-1]
576
+ targets_rolled = targets[-1:] + targets[:-1]
577
+
578
+ output_images: List[torch.Tensor] = []
579
+ output_targets: List[Dict[str, Tensor]] = []
580
+
581
+ for image, target, paste_image, paste_target in zip(images, targets, images_rolled, targets_rolled):
582
+ output_image, output_data = _copy_paste(
583
+ image,
584
+ target,
585
+ paste_image,
586
+ paste_target,
587
+ blending=self.blending,
588
+ resize_interpolation=self.resize_interpolation,
589
+ )
590
+ output_images.append(output_image)
591
+ output_targets.append(output_data)
592
+
593
+ return output_images, output_targets
594
+
595
+ def __repr__(self) -> str:
596
+ s = f"{self.__class__.__name__}(blending={self.blending}, resize_interpolation={self.resize_interpolation})"
597
+ return s
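
All of these transforms take and return an (image, target) pair, so they chain through the Compose class at the top of the file. The combination below is an illustrative train-time pipeline, not something this commit wires up anywhere; the sample image path and the dummy box are placeholders:

import torch
from PIL import Image
from utils import transforms as T   # model/utils/transforms.py

train_tfms = T.Compose([
    T.PILToTensor(),
    T.ConvertImageDtype(torch.float),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomPhotometricDistort(),
])

image = Image.open("img/sample/テスト画像１.JPG")           # one of the sample images added above
target = {"boxes": torch.tensor([[30., 40., 120., 200.]]),  # dummy xyxy box and label
          "labels": torch.tensor([1])}
image, target = train_tfms(image, target)                   # boxes are flipped along with the image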
model/utils/utils.py ADDED
@@ -0,0 +1,282 @@
1
+ import datetime
2
+ import errno
3
+ import os
4
+ import time
5
+ from collections import defaultdict, deque
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+
10
+
11
+ class SmoothedValue:
12
+ """Track a series of values and provide access to smoothed values over a
13
+ window or the global series average.
14
+ """
15
+
16
+ def __init__(self, window_size=20, fmt=None):
17
+ if fmt is None:
18
+ fmt = "{median:.4f} ({global_avg:.4f})"
19
+ self.deque = deque(maxlen=window_size)
20
+ self.total = 0.0
21
+ self.count = 0
22
+ self.fmt = fmt
23
+
24
+ def update(self, value, n=1):
25
+ self.deque.append(value)
26
+ self.count += n
27
+ self.total += value * n
28
+
29
+ def synchronize_between_processes(self):
30
+ """
31
+ Warning: does not synchronize the deque!
32
+ """
33
+ if not is_dist_avail_and_initialized():
34
+ return
35
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
36
+ dist.barrier()
37
+ dist.all_reduce(t)
38
+ t = t.tolist()
39
+ self.count = int(t[0])
40
+ self.total = t[1]
41
+
42
+ @property
43
+ def median(self):
44
+ d = torch.tensor(list(self.deque))
45
+ return d.median().item()
46
+
47
+ @property
48
+ def avg(self):
49
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
50
+ return d.mean().item()
51
+
52
+ @property
53
+ def global_avg(self):
54
+ return self.total / self.count
55
+
56
+ @property
57
+ def max(self):
58
+ return max(self.deque)
59
+
60
+ @property
61
+ def value(self):
62
+ return self.deque[-1]
63
+
64
+ def __str__(self):
65
+ return self.fmt.format(
66
+ median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
67
+ )
68
+
69
+
70
+ def all_gather(data):
71
+ """
72
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
73
+ Args:
74
+ data: any picklable object
75
+ Returns:
76
+ list[data]: list of data gathered from each rank
77
+ """
78
+ world_size = get_world_size()
79
+ if world_size == 1:
80
+ return [data]
81
+ data_list = [None] * world_size
82
+ dist.all_gather_object(data_list, data)
83
+ return data_list
84
+
85
+
86
+ def reduce_dict(input_dict, average=True):
87
+ """
88
+ Args:
89
+ input_dict (dict): all the values will be reduced
90
+ average (bool): whether to do average or sum
91
+ Reduce the values in the dictionary from all processes so that all processes
92
+ have the averaged results. Returns a dict with the same fields as
93
+ input_dict, after reduction.
94
+ """
95
+ world_size = get_world_size()
96
+ if world_size < 2:
97
+ return input_dict
98
+ with torch.inference_mode():
99
+ names = []
100
+ values = []
101
+ # sort the keys so that they are consistent across processes
102
+ for k in sorted(input_dict.keys()):
103
+ names.append(k)
104
+ values.append(input_dict[k])
105
+ values = torch.stack(values, dim=0)
106
+ dist.all_reduce(values)
107
+ if average:
108
+ values /= world_size
109
+ reduced_dict = {k: v for k, v in zip(names, values)}
110
+ return reduced_dict
111
+
112
+
113
+ class MetricLogger:
114
+ def __init__(self, delimiter="\t"):
115
+ self.meters = defaultdict(SmoothedValue)
116
+ self.delimiter = delimiter
117
+
118
+ def update(self, **kwargs):
119
+ for k, v in kwargs.items():
120
+ if isinstance(v, torch.Tensor):
121
+ v = v.item()
122
+ assert isinstance(v, (float, int))
123
+ self.meters[k].update(v)
124
+
125
+ def __getattr__(self, attr):
126
+ if attr in self.meters:
127
+ return self.meters[attr]
128
+ if attr in self.__dict__:
129
+ return self.__dict__[attr]
130
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
131
+
132
+ def __str__(self):
133
+ loss_str = []
134
+ for name, meter in self.meters.items():
135
+ loss_str.append(f"{name}: {str(meter)}")
136
+ return self.delimiter.join(loss_str)
137
+
138
+ def synchronize_between_processes(self):
139
+ for meter in self.meters.values():
140
+ meter.synchronize_between_processes()
141
+
142
+ def add_meter(self, name, meter):
143
+ self.meters[name] = meter
144
+
145
+ def log_every(self, iterable, print_freq, header=None):
146
+ i = 0
147
+ if not header:
148
+ header = ""
149
+ start_time = time.time()
150
+ end = time.time()
151
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
152
+ data_time = SmoothedValue(fmt="{avg:.4f}")
153
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
154
+ if torch.cuda.is_available():
155
+ log_msg = self.delimiter.join(
156
+ [
157
+ header,
158
+ "[{0" + space_fmt + "}/{1}]",
159
+ "eta: {eta}",
160
+ "{meters}",
161
+ "time: {time}",
162
+ "data: {data}",
163
+ "max mem: {memory:.0f}",
164
+ ]
165
+ )
166
+ else:
167
+ log_msg = self.delimiter.join(
168
+ [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
169
+ )
170
+ MB = 1024.0 * 1024.0
171
+ for obj in iterable:
172
+ data_time.update(time.time() - end)
173
+ yield obj
174
+ iter_time.update(time.time() - end)
175
+ if i % print_freq == 0 or i == len(iterable) - 1:
176
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
177
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
178
+ if torch.cuda.is_available():
179
+ print(
180
+ log_msg.format(
181
+ i,
182
+ len(iterable),
183
+ eta=eta_string,
184
+ meters=str(self),
185
+ time=str(iter_time),
186
+ data=str(data_time),
187
+ memory=torch.cuda.max_memory_allocated() / MB,
188
+ )
189
+ )
190
+ else:
191
+ print(
192
+ log_msg.format(
193
+ i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
194
+ )
195
+ )
196
+ i += 1
197
+ end = time.time()
198
+ total_time = time.time() - start_time
199
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
200
+ print(f"{header} Total time: {total_time_str} ({total_time / len(iterable):.4f} s / it)")
201
+
202
+
203
+ def collate_fn(batch):
204
+ return tuple(zip(*batch))
205
+
206
+
207
+ def mkdir(path):
208
+ try:
209
+ os.makedirs(path)
210
+ except OSError as e:
211
+ if e.errno != errno.EEXIST:
212
+ raise
213
+
214
+
215
+ def setup_for_distributed(is_master):
216
+ """
217
+ This function disables printing when not in master process
218
+ """
219
+ import builtins as __builtin__
220
+
221
+ builtin_print = __builtin__.print
222
+
223
+ def print(*args, **kwargs):
224
+ force = kwargs.pop("force", False)
225
+ if is_master or force:
226
+ builtin_print(*args, **kwargs)
227
+
228
+ __builtin__.print = print
229
+
230
+
231
+ def is_dist_avail_and_initialized():
232
+ if not dist.is_available():
233
+ return False
234
+ if not dist.is_initialized():
235
+ return False
236
+ return True
237
+
238
+
239
+ def get_world_size():
240
+ if not is_dist_avail_and_initialized():
241
+ return 1
242
+ return dist.get_world_size()
243
+
244
+
245
+ def get_rank():
246
+ if not is_dist_avail_and_initialized():
247
+ return 0
248
+ return dist.get_rank()
249
+
250
+
251
+ def is_main_process():
252
+ return get_rank() == 0
253
+
254
+
255
+ def save_on_master(*args, **kwargs):
256
+ if is_main_process():
257
+ torch.save(*args, **kwargs)
258
+
259
+
260
+ def init_distributed_mode(args):
261
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
262
+ args.rank = int(os.environ["RANK"])
263
+ args.world_size = int(os.environ["WORLD_SIZE"])
264
+ args.gpu = int(os.environ["LOCAL_RANK"])
265
+ elif "SLURM_PROCID" in os.environ:
266
+ args.rank = int(os.environ["SLURM_PROCID"])
267
+ args.gpu = args.rank % torch.cuda.device_count()
268
+ else:
269
+ print("Not using distributed mode")
270
+ args.distributed = False
271
+ return
272
+
273
+ args.distributed = True
274
+
275
+ torch.cuda.set_device(args.gpu)
276
+ args.dist_backend = "nccl"
277
+ print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
278
+ torch.distributed.init_process_group(
279
+ backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
280
+ )
281
+ torch.distributed.barrier()
282
+ setup_for_distributed(args.rank == 0)
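
MetricLogger and SmoothedValue produce the per-iteration console output seen during training and evaluation. A standalone sketch of the API (the sleep is a stand-in for real work):

import time
from utils.utils import MetricLogger, SmoothedValue

logger = MetricLogger(delimiter="  ")
logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))

for batch in logger.log_every(range(100), print_freq=10, header="Epoch: [0]"):
    time.sleep(0.01)                  # stand-in for a training step
    logger.update(loss=0.5, lr=1e-3)  # values are smoothed over a sliding window

print("Averaged stats:", logger)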
requirements.txt ADDED
@@ -0,0 +1,6 @@
1
+ numpy==1.23.4
2
+ Pillow==9.3.0
3
+ pycocotools==2.0.6
4
+ torch==1.13.0
5
+ torchvision==0.14.0
6
+ streamlit