jwlarocque committed on
Commit ab7d699 · 1 Parent(s): fdaae10

Create DIS-SAM space

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +2 -0
  2. IS_Net/DIS5K/DIS5K-test/enhance_gt/1#Accessories#1#Bag#2339506821_83cf9f1d22_o_comp_1.png +0 -0
  3. IS_Net/DIS5K/DIS5K-test/enhance_gt/1#Accessories#1#Bag#3292738108_c51336a8be_o_comp_1.png +0 -0
  4. IS_Net/DIS5K/DIS5K-test/enhance_gt/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_0.png +0 -0
  5. IS_Net/DIS5K/DIS5K-test/enhance_gt/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_1.png +0 -0
  6. IS_Net/DIS5K/DIS5K-test/enhance_gt/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_2.png +0 -0
  7. IS_Net/DIS5K/DIS5K-test/enhance_im/1#Accessories#1#Bag#2339506821_83cf9f1d22_o_comp_1.jpg +3 -0
  8. IS_Net/DIS5K/DIS5K-test/enhance_im/1#Accessories#1#Bag#3292738108_c51336a8be_o_comp_1.jpg +3 -0
  9. IS_Net/DIS5K/DIS5K-test/enhance_im/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_0.jpg +3 -0
  10. IS_Net/DIS5K/DIS5K-test/enhance_im/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_1.jpg +3 -0
  11. IS_Net/DIS5K/DIS5K-test/enhance_im/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_2.jpg +3 -0
  12. IS_Net/DIS5K/DIS5K-test/enhance_sam/1#Accessories#1#Bag#2339506821_83cf9f1d22_o_comp_1.png +0 -0
  13. IS_Net/DIS5K/DIS5K-test/enhance_sam/1#Accessories#1#Bag#3292738108_c51336a8be_o_comp_1.png +0 -0
  14. IS_Net/DIS5K/DIS5K-test/enhance_sam/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_0.png +0 -0
  15. IS_Net/DIS5K/DIS5K-test/enhance_sam/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_1.png +0 -0
  16. IS_Net/DIS5K/DIS5K-test/enhance_sam/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_2.png +0 -0
  17. IS_Net/DIS5K/DIS5K-test/gt/1#Accessories#1#Bag#2339506821_83cf9f1d22_o.png +0 -0
  18. IS_Net/DIS5K/DIS5K-test/gt/1#Accessories#1#Bag#3292738108_c51336a8be_o.png +0 -0
  19. IS_Net/DIS5K/DIS5K-test/gt/4#Architecture#10#Pavilion#5795028920_08884db993_o.png +0 -0
  20. IS_Net/DIS5K/DIS5K-test/im/1#Accessories#1#Bag#2339506821_83cf9f1d22_o.jpg +3 -0
  21. IS_Net/DIS5K/DIS5K-test/im/1#Accessories#1#Bag#3292738108_c51336a8be_o.jpg +3 -0
  22. IS_Net/DIS5K/DIS5K-test/im/4#Architecture#10#Pavilion#5795028920_08884db993_o.jpg +3 -0
  23. IS_Net/DIS5K/DIS5K-test/mask/1#Accessories#1#Bag#2339506821_83cf9f1d22_o.png +0 -0
  24. IS_Net/DIS5K/DIS5K-test/mask/1#Accessories#1#Bag#3292738108_c51336a8be_o.png +0 -0
  25. IS_Net/DIS5K/DIS5K-test/mask/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_0.png +0 -0
  26. IS_Net/__pycache__/data_loader.cpython-311.pyc +0 -0
  27. IS_Net/basics.py +125 -0
  28. IS_Net/data_loader.py +542 -0
  29. IS_Net/datalist.py +62 -0
  30. IS_Net/models/__pycache__/isnet.cpython-311.pyc +0 -0
  31. IS_Net/models/isnet.py +640 -0
  32. IS_Net/saliency_toolbox.py +552 -0
  33. IS_Net/swd_optim/__init__.py +10 -0
  34. IS_Net/swd_optim/adai.py +116 -0
  35. IS_Net/swd_optim/adais.py +120 -0
  36. IS_Net/swd_optim/adams.py +137 -0
  37. IS_Net/swd_optim/sgds.py +82 -0
  38. IS_Net/train_valid_inference_main.py +729 -0
  39. MultiScaleDeformableAttention-1.0-py3-none-any.whl +3 -0
  40. README.md +4 -2
  41. SAM/segment_anything/__init__.py +15 -0
  42. SAM/segment_anything/__pycache__/__init__.cpython-311.pyc +0 -0
  43. SAM/segment_anything/__pycache__/automatic_mask_generator.cpython-311.pyc +0 -0
  44. SAM/segment_anything/__pycache__/build_sam.cpython-311.pyc +0 -0
  45. SAM/segment_anything/__pycache__/predictor.cpython-311.pyc +0 -0
  46. SAM/segment_anything/automatic_mask_generator.py +372 -0
  47. SAM/segment_anything/build_sam.py +111 -0
  48. SAM/segment_anything/modeling/__init__.py +11 -0
  49. SAM/segment_anything/modeling/__pycache__/__init__.cpython-311.pyc +0 -0
  50. SAM/segment_anything/modeling/__pycache__/common.cpython-311.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ MultiScaleDeformableAttention-1.0-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
IS_Net/DIS5K/DIS5K-test/enhance_gt/1#Accessories#1#Bag#2339506821_83cf9f1d22_o_comp_1.png ADDED
IS_Net/DIS5K/DIS5K-test/enhance_gt/1#Accessories#1#Bag#3292738108_c51336a8be_o_comp_1.png ADDED
IS_Net/DIS5K/DIS5K-test/enhance_gt/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_0.png ADDED
IS_Net/DIS5K/DIS5K-test/enhance_gt/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_1.png ADDED
IS_Net/DIS5K/DIS5K-test/enhance_gt/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_2.png ADDED
IS_Net/DIS5K/DIS5K-test/enhance_im/1#Accessories#1#Bag#2339506821_83cf9f1d22_o_comp_1.jpg ADDED

Git LFS Details

  • SHA256: 4cb7d3c28db6f3bc4d2227d7551b1ba85abc9d335a1c6a625777de351bb9d469
  • Pointer size: 131 Bytes
  • Size of remote file: 779 kB
IS_Net/DIS5K/DIS5K-test/enhance_im/1#Accessories#1#Bag#3292738108_c51336a8be_o_comp_1.jpg ADDED

Git LFS Details

  • SHA256: 2c4229b3b7978308ba3f28903e5a65e2bce7bfa7bd684f53c5bb23f3067dd6c4
  • Pointer size: 131 Bytes
  • Size of remote file: 146 kB
IS_Net/DIS5K/DIS5K-test/enhance_im/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_0.jpg ADDED

Git LFS Details

  • SHA256: 71fd4c8bd0e10b57142b9781bc4654f368cd91ceb8e0b4e22ff9e54ce0b2fe06
  • Pointer size: 132 Bytes
  • Size of remote file: 1.19 MB
IS_Net/DIS5K/DIS5K-test/enhance_im/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_1.jpg ADDED

Git LFS Details

  • SHA256: 71fd4c8bd0e10b57142b9781bc4654f368cd91ceb8e0b4e22ff9e54ce0b2fe06
  • Pointer size: 132 Bytes
  • Size of remote file: 1.19 MB
IS_Net/DIS5K/DIS5K-test/enhance_im/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_2.jpg ADDED

Git LFS Details

  • SHA256: 71fd4c8bd0e10b57142b9781bc4654f368cd91ceb8e0b4e22ff9e54ce0b2fe06
  • Pointer size: 132 Bytes
  • Size of remote file: 1.19 MB
IS_Net/DIS5K/DIS5K-test/enhance_sam/1#Accessories#1#Bag#2339506821_83cf9f1d22_o_comp_1.png ADDED
IS_Net/DIS5K/DIS5K-test/enhance_sam/1#Accessories#1#Bag#3292738108_c51336a8be_o_comp_1.png ADDED
IS_Net/DIS5K/DIS5K-test/enhance_sam/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_0.png ADDED
IS_Net/DIS5K/DIS5K-test/enhance_sam/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_1.png ADDED
IS_Net/DIS5K/DIS5K-test/enhance_sam/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_2.png ADDED
IS_Net/DIS5K/DIS5K-test/gt/1#Accessories#1#Bag#2339506821_83cf9f1d22_o.png ADDED
IS_Net/DIS5K/DIS5K-test/gt/1#Accessories#1#Bag#3292738108_c51336a8be_o.png ADDED
IS_Net/DIS5K/DIS5K-test/gt/4#Architecture#10#Pavilion#5795028920_08884db993_o.png ADDED
IS_Net/DIS5K/DIS5K-test/im/1#Accessories#1#Bag#2339506821_83cf9f1d22_o.jpg ADDED

Git LFS Details

  • SHA256: 4cb7d3c28db6f3bc4d2227d7551b1ba85abc9d335a1c6a625777de351bb9d469
  • Pointer size: 131 Bytes
  • Size of remote file: 779 kB
IS_Net/DIS5K/DIS5K-test/im/1#Accessories#1#Bag#3292738108_c51336a8be_o.jpg ADDED

Git LFS Details

  • SHA256: 2c4229b3b7978308ba3f28903e5a65e2bce7bfa7bd684f53c5bb23f3067dd6c4
  • Pointer size: 131 Bytes
  • Size of remote file: 146 kB
IS_Net/DIS5K/DIS5K-test/im/4#Architecture#10#Pavilion#5795028920_08884db993_o.jpg ADDED

Git LFS Details

  • SHA256: 71fd4c8bd0e10b57142b9781bc4654f368cd91ceb8e0b4e22ff9e54ce0b2fe06
  • Pointer size: 132 Bytes
  • Size of remote file: 1.19 MB
IS_Net/DIS5K/DIS5K-test/mask/1#Accessories#1#Bag#2339506821_83cf9f1d22_o.png ADDED
IS_Net/DIS5K/DIS5K-test/mask/1#Accessories#1#Bag#3292738108_c51336a8be_o.png ADDED
IS_Net/DIS5K/DIS5K-test/mask/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_0.png ADDED
IS_Net/__pycache__/data_loader.cpython-311.pyc ADDED
Binary file (34.3 kB).
 
IS_Net/basics.py ADDED
@@ -0,0 +1,125 @@
+ import os
+ # os.environ['CUDA_VISIBLE_DEVICES'] = '2'
+ from skimage import io, transform
+ import torch
+ import torchvision
+ from torch.autograd import Variable
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import Dataset, DataLoader
+ from torchvision import transforms, utils
+ import torch.optim as optim
+ from skimage.metrics import structural_similarity as ssim
+ import matplotlib.pyplot as plt
+ import numpy as np
+ from PIL import Image
+ import glob
+ import cv2
+ from scipy.stats import pearsonr
+
+ def mae_torch(pred, gt):
+     h, w = gt.shape[0:2]
+     sumError = torch.sum(torch.absolute(torch.sub(pred.float(), gt.float())))
+     maeError = torch.divide(sumError, float(h)*float(w)*255.0 + 1e-4)
+     return maeError
+
+ def maximal_f_measure_torch(pd, gt):
+     gtNum = torch.sum((gt > 128).float() * 1)  # number of ground-truth pixels with values above 128
+
+     # split predictions into positives and negatives
+     pp = pd[gt > 128]
+     nn = pd[gt <= 128]
+
+     # histograms of the positive and negative predictions
+     pp_hist = torch.histc(pp, bins=255, min=0, max=255)
+     nn_hist = torch.histc(nn, bins=255, min=0, max=255)
+
+     # flip the histograms and take cumulative sums
+     pp_hist_flip = torch.flipud(pp_hist)
+     nn_hist_flip = torch.flipud(nn_hist)
+
+     pp_hist_flip_cum = torch.cumsum(pp_hist_flip, dim=0)
+     nn_hist_flip_cum = torch.cumsum(nn_hist_flip, dim=0)
+
+     # compute precision, recall, and F-measure
+     precision = (pp_hist_flip_cum) / (pp_hist_flip_cum + nn_hist_flip_cum + 1e-4)
+     recall = (pp_hist_flip_cum) / (gtNum + 1e-4)
+     f_measure = (2 * precision * recall) / (precision + recall + 1e-4)
+
+     # find the maximal F-measure and its corresponding threshold
+     max_f_measure, threshold = torch.max(f_measure, dim=0)
+
+     return max_f_measure.item(), threshold.item()
+
+ def calculate_meam(image1, image2):
+     # histogram equalization
+     image1_equalized = cv2.equalizeHist(image1)
+     image2_equalized = cv2.equalizeHist(image2)
+
+     # Pearson correlation coefficient between the equalized images
+     correlation_coefficient, _ = pearsonr(image1_equalized.flatten(), image2_equalized.flatten())
+
+     # compute the MEAM value
+     meam_value = correlation_coefficient * np.mean(np.minimum(image1_equalized, image2_equalized))
+
+     return meam_value
+
+ def f1score_torch(pd, gt):
+     gtNum = torch.sum((gt > 128).float() * 1)  ## number of ground truth pixels
+
+     pp = pd[gt > 128]
+     nn = pd[gt <= 128]
+
+     pp_hist = torch.histc(pp, bins=255, min=0, max=255)
+     nn_hist = torch.histc(nn, bins=255, min=0, max=255)
+
+     pp_hist_flip = torch.flipud(pp_hist)
+     nn_hist_flip = torch.flipud(nn_hist)
+
+     pp_hist_flip_cum = torch.cumsum(pp_hist_flip, dim=0)
+     nn_hist_flip_cum = torch.cumsum(nn_hist_flip, dim=0)
+
+     precision = (pp_hist_flip_cum) / (pp_hist_flip_cum + nn_hist_flip_cum + 1e-4)
+     recall = (pp_hist_flip_cum) / (gtNum + 1e-4)
+     f1 = (1 + 0.3) * precision * recall / (0.3 * precision + recall + 1e-4)  # F-beta with beta^2 = 0.3
+
+     return torch.reshape(precision, (1, precision.shape[0])), torch.reshape(recall, (1, recall.shape[0])), torch.reshape(f1, (1, f1.shape[0]))
+
+ def f1_mae_torch(pred, gt, valid_dataset, idx, mybins, hypar):
+     import time
+     tic = time.time()
+
+     if(len(gt.shape) > 2):
+         gt = gt[:, :, 0]
+
+     pre, rec, f1 = f1score_torch(pred, gt)
+     mae = mae_torch(pred, gt)
+
+     if(hypar["valid_out_dir"] != ""):
+         if(not os.path.exists(hypar["valid_out_dir"])):
+             os.mkdir(hypar["valid_out_dir"])
+         dataset_folder = os.path.join(hypar["valid_out_dir"], valid_dataset.dataset["data_name"][idx])
+         if(not os.path.exists(dataset_folder)):
+             os.mkdir(dataset_folder)
+         io.imsave(os.path.join(dataset_folder, valid_dataset.dataset["im_name"][idx]+".png"), pred.cpu().data.numpy().astype(np.uint8))
+
+     return pre.cpu().data.numpy(), rec.cpu().data.numpy(), f1.cpu().data.numpy(), mae.cpu().data.numpy()
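
Below is a minimal sketch of how these metric helpers can be driven; the random tensors and the `IS_Net.basics` import path are illustrative assumptions, not part of the commit:

```python
import torch
from IS_Net.basics import f1score_torch, mae_torch, maximal_f_measure_torch

# hypothetical prediction / ground-truth pair; both HxW tensors valued in [0, 255]
pred = torch.rand(320, 320) * 255
gt = (torch.rand(320, 320) > 0.5).float() * 255

pre, rec, f1 = f1score_torch(pred, gt)      # per-threshold curves, each of shape (1, 255)
print("max F-beta:", f1.max().item())       # maximal F-beta over all 255 thresholds
print("MAE:", mae_torch(pred, gt).item())   # mean absolute error, normalized by 255
max_f, thr = maximal_f_measure_torch(pred, gt)  # same idea with beta = 1
```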
IS_Net/data_loader.py ADDED
@@ -0,0 +1,542 @@
+ ## data loader
+ ## Acknowledgement:
+ ## We would like to thank Dr. Ibrahim Almakky (https://scholar.google.co.uk/citations?user=T9MTcK0AAAAJ&hl=en)
+ ## for his help in implementing the cache mechanism of our DIS dataloader.
+ from __future__ import print_function, division
+
+ import numpy as np
+ import random
+ from copy import deepcopy
+ import json
+ from tqdm import tqdm
+ from skimage import io
+ import os
+ from glob import glob
+ import matplotlib.pyplot as plt
+ from PIL import Image, ImageOps
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+ from torchvision import transforms, utils
+ from torchvision.transforms.functional import normalize
+ import torch.nn.functional as F
+ import cv2
+ from scipy.ndimage import label
+
+ def show_gray_images(images, m=4):
+     """
+     Display a set of grayscale images.
+
+     Args:
+         images: an array of shape (n, h, w), where n is the number of images
+                 and h and w are the image height and width.
+         m: number of images shown per row (default 4).
+     """
+     n, h, w = images.shape  # number, height, and width of the input images
+     num_rows = (n + m - 1) // m  # number of rows needed
+     fig, axes = plt.subplots(num_rows, m, figsize=(m*2, num_rows*2))  # create the figure and subplots
+     plt.subplots_adjust(wspace=0.05, hspace=0.05)  # adjust spacing between subplots
+     for i in range(num_rows):
+         for j in range(m):
+             idx = i*m + j  # index of the current image
+             if idx < n:
+                 axes[i, j].imshow(images[idx], cmap='gray')  # show the image
+             axes[i, j].axis('off')  # hide the axes
+     plt.show()  # render the figure
+
+ #### --------------------- DIS dataloader cache ---------------------####
+
+ def segment_connected_components(mask):
+     # convert the mask to a PyTorch tensor
+     mask_tensor = torch.tensor(mask)
+
+     # find connected components with SciPy's label function
+     labeled_array, num_features = label(mask_tensor.numpy())
+
+     # store each connected component's binary mask in a dictionary
+     components = {}
+     for label_idx in range(1, num_features + 1):
+         component_mask = (labeled_array == label_idx)
+         components[label_idx] = component_mask.astype(int)
+
+     return components
+
+ def FillHole(im_in):
+     img = np.array(im_in, dtype=np.uint8)[0]
+     mask = np.zeros_like(img)
+     contours, _ = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+     for contour in contours:
+         cv2.drawContours(mask, [contour], -1, 255, thickness=cv2.FILLED)
+     im_out = torch.from_numpy(mask)[None, ...].float()
+     return im_out
+
+ def get_im_gt_name_dict(datasets, flag='valid'):
+     print("------------------------------", flag, "--------------------------------")
+     name_im_gt_mid_list = []
+     for i in range(len(datasets)):
+         print("--->>>", flag, " dataset ", i, "/", len(datasets), " ", datasets[i]["name"], "<<<---")
+         tmp_im_list, tmp_gt_list, tmp_mid_list = [], [], []
+         tmp_im_list = glob(datasets[i]["im_dir"]+os.sep+'*'+datasets[i]["im_ext"])
+
+         if(datasets[i]["gt_dir"]==""):
+             print('-gt-', datasets[i]["name"], datasets[i]["gt_dir"], ': ', 'No Ground Truth Found')
+             tmp_gt_list = []
+         else:
+             tmp_gt_list = [datasets[i]["gt_dir"]+os.sep+x.split(os.sep)[-1].split(datasets[i]["im_ext"])[0]+datasets[i]["gt_ext"] for x in tmp_im_list]
+
+         if(datasets[i]["mid_dir"]==""):
+             print('-mid-', datasets[i]["name"], datasets[i]["mid_dir"], ': ', 'No mid Found')
+             tmp_mid_list = []
+         else:
+             tmp_mid_list = [datasets[i]["mid_dir"]+os.sep+x.split(os.sep)[-1].split(datasets[i]["im_ext"])[0]+datasets[i]["mid_ext"] for x in tmp_im_list]
+
+         if flag=="train":  ## combine multiple training sets into one dataset
+             if len(name_im_gt_mid_list)==0:
+                 name_im_gt_mid_list.append({"dataset_name": datasets[i]["name"],
+                                             "im_path": tmp_im_list,
+                                             "gt_path": tmp_gt_list,
+                                             "mid_path": tmp_mid_list,
+                                             "im_ext": datasets[i]["im_ext"],
+                                             "gt_ext": datasets[i]["gt_ext"],
+                                             "mid_ext": datasets[i]["mid_ext"],
+                                             "cache_dir": datasets[i]["cache_dir"]})
+             else:
+                 name_im_gt_mid_list[0]["dataset_name"] = name_im_gt_mid_list[0]["dataset_name"] + "_" + datasets[i]["name"]
+                 name_im_gt_mid_list[0]["im_path"] = name_im_gt_mid_list[0]["im_path"] + tmp_im_list
+                 name_im_gt_mid_list[0]["gt_path"] = name_im_gt_mid_list[0]["gt_path"] + tmp_gt_list
+                 name_im_gt_mid_list[0]["mid_path"] = name_im_gt_mid_list[0]["mid_path"] + tmp_mid_list
+                 if datasets[i]["im_ext"]!=".jpg" or datasets[i]["gt_ext"]!=".png":
+                     print("Error: Please make sure all your images and ground truth masks are in jpg and png format respectively !!!")
+                     exit()
+                 name_im_gt_mid_list[0]["im_ext"] = ".jpg"
+                 name_im_gt_mid_list[0]["gt_ext"] = ".png"
+                 name_im_gt_mid_list[0]["mid_ext"] = ".png"
+                 name_im_gt_mid_list[0]["cache_dir"] = os.sep.join(datasets[i]["cache_dir"].split(os.sep)[0:-1])+os.sep+name_im_gt_mid_list[0]["dataset_name"]
+         else:  ## keep different validation or inference datasets as separate ones
+             name_im_gt_mid_list.append({"dataset_name": datasets[i]["name"],
+                                         "im_path": tmp_im_list,
+                                         "gt_path": tmp_gt_list,
+                                         "mid_path": tmp_mid_list,
+                                         "im_ext": datasets[i]["im_ext"],
+                                         "gt_ext": datasets[i]["gt_ext"],
+                                         "mid_ext": datasets[i]["mid_ext"],
+                                         "cache_dir": datasets[i]["cache_dir"]})
+
+     return name_im_gt_mid_list
+
+ def create_dataloaders(name_im_gt_mid_list, cache_size=[], cache_boost=True, my_transforms=[], batch_size=1, shuffle=False, is_train=True):
+     ## mode="train": return one dataloader for training
+     ## mode="valid": return a list of dataloaders for validation or testing
+
+     gos_dataloaders = []
+     gos_datasets = []
+
+     if(len(name_im_gt_mid_list)==0):
+         return gos_dataloaders, gos_datasets
+
+     num_workers_ = 0
+
+     for i in range(0, len(name_im_gt_mid_list)):
+         gos_dataset = GOSDatasetCache([name_im_gt_mid_list[i]],
+                                       cache_size=cache_size,
+                                       cache_path=name_im_gt_mid_list[i]["cache_dir"],
+                                       cache_boost=cache_boost,
+                                       transform=transforms.Compose(my_transforms),
+                                       is_train=is_train)
+         gos_dataloaders.append(DataLoader(gos_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers_))
+         gos_datasets.append(gos_dataset)
+
+     return gos_dataloaders, gos_datasets
+
+ def im_reader(im_path):
+     image = Image.open(im_path).convert('RGB')
+     corrected_image = ImageOps.exif_transpose(image)
+     return np.array(corrected_image)
+
+ def im_preprocess(im, size):
+     if len(im.shape) > 3:
+         im = im[:, :, :3]
+     if len(im.shape) < 3:
+         im = im[:, :, np.newaxis]
+     if im.shape[2] == 1:
+         im = np.repeat(im, 3, axis=2)
+     im_tensor = torch.tensor(im.copy(), dtype=torch.float32)
+     im_tensor = torch.transpose(torch.transpose(im_tensor, 1, 2), 0, 1)
+     if(len(size) < 2):
+         return im_tensor, im.shape[0:2]
+     else:
+         im_tensor = torch.unsqueeze(im_tensor, 0)
+         im_tensor = F.upsample(im_tensor, size, mode="bilinear")
+         im_tensor = torch.squeeze(im_tensor, 0)
+
+     return im_tensor.type(torch.uint8), im.shape[0:2]
+
+ def gt_preprocess(gt, size):
+     if len(gt.shape) > 2:
+         gt = gt[:, :, 0]
+
+     gt_tensor = torch.unsqueeze(torch.tensor(gt, dtype=torch.uint8), 0)
+
+     if(len(size) < 2):
+         return gt_tensor.type(torch.uint8), gt.shape[0:2]
+     else:
+         gt_tensor = torch.unsqueeze(torch.tensor(gt_tensor, dtype=torch.float32), 0)
+         gt_tensor = F.upsample(gt_tensor, size, mode="bilinear")
+         gt_tensor = torch.squeeze(gt_tensor, 0)
+
+     return gt_tensor.type(torch.uint8), gt.shape[0:2]
+
+ class GOSRandomHFlip(object):
+     def __init__(self, prob=0.25):
+         self.prob = prob
+     def __call__(self, sample):
+         imidx, image, label, shape, box, mask = sample['imidx'], sample['image'], sample['label'], sample['shape'], sample['box'], sample['mask']
+
+         # random horizontal and/or vertical flip (each case with probability `prob`)
+         randomnum = random.random()
+         if randomnum <= self.prob:
+             image = torch.flip(image, dims=[2])
+             label = torch.flip(label, dims=[2])
+             box = torch.flip(box, dims=[2])
+             mask = torch.flip(mask, dims=[2])
+         elif randomnum <= self.prob*2:
+             image = torch.flip(image, dims=[1])
+             label = torch.flip(label, dims=[1])
+             box = torch.flip(box, dims=[1])
+             mask = torch.flip(mask, dims=[1])
+         elif randomnum <= self.prob*3:
+             image = torch.flip(image, dims=[2])
+             label = torch.flip(label, dims=[2])
+             box = torch.flip(box, dims=[2])
+             mask = torch.flip(mask, dims=[2])
+             image = torch.flip(image, dims=[1])
+             label = torch.flip(label, dims=[1])
+             box = torch.flip(box, dims=[1])
+             mask = torch.flip(mask, dims=[1])
+
+         return {'imidx': imidx, 'image': image, 'label': label, 'shape': shape, 'mask': mask, 'box': box}
+
+ class GOSResize(object):
+     def __init__(self, size=[320, 320]):
+         self.size = size
+     def __call__(self, sample):
+         imidx, image, label, shape, box, mask = sample['imidx'], sample['image'], sample['label'], sample['shape'], sample['box'], sample['mask']
+
+         image = torch.squeeze(F.upsample(torch.unsqueeze(image, 0), self.size, mode='bilinear'), dim=0)
+         label = torch.squeeze(F.upsample(torch.unsqueeze(label, 0), self.size, mode='bilinear'), dim=0)
+
+         return {'imidx': imidx, 'image': image, 'label': label, 'shape': shape, 'mask': mask, 'box': box}
+
+ class GOSRandomCrop(object):
+     def __init__(self, size=[288, 288]):
+         self.size = size
+     def __call__(self, sample):
+         imidx, image, label, shape, box, mask = sample['imidx'], sample['image'], sample['label'], sample['shape'], sample['box'], sample['mask']
+
+         h, w = image.shape[1:]
+         new_h, new_w = self.size
+
+         top = np.random.randint(0, h - new_h)
+         left = np.random.randint(0, w - new_w)
+
+         image = image[:, top:top+new_h, left:left+new_w]
+         label = label[:, top:top+new_h, left:left+new_w]
+
+         return {'imidx': imidx, 'image': image, 'label': label, 'shape': shape, 'mask': mask, 'box': box}
+
+ class GOSNormalize(object):
+     def __init__(self, mean=[0.485, 0.456, 0.406, 0], std=[0.229, 0.224, 0.225, 1.0]):
+         self.mean = mean
+         self.std = std
+
+     def __call__(self, sample):
+         imidx, image, label, shape, box, mask = sample['imidx'], sample['image'], sample['label'], sample['shape'], sample['box'], sample['mask']
+         image = normalize(image, self.mean, self.std)
+         mask = normalize(mask, 0, 1)
+         box = normalize(box, 0, 1)
+
+         return {'imidx': imidx, 'image': image, 'label': label, 'shape': shape, 'mask': mask, 'box': box}
+
+ class GOSRandomthorw(object):
+     def __init__(self, ratio=0.25):
+         self.ratio = ratio
+     def __call__(self, sample):
+         imidx, image, label, shape, box, mask = sample['imidx'], sample['image'], sample['label'], sample['shape'], sample['box'], sample['mask']
+         # randomly drop the mask prompt, the box prompt, or both
+         randomnum = random.random()
+         if randomnum < self.ratio:
+             mask = torch.zeros_like(mask)
+         elif randomnum < self.ratio*2:
+             box = torch.zeros_like(box)
+         elif randomnum < self.ratio*3:
+             mask = torch.zeros_like(mask)
+             box = torch.zeros_like(box)
+
+         return {'imidx': imidx, 'image': image, 'label': label, 'shape': shape, 'mask': mask, 'box': box}
+
+ class GOSDatasetCache(Dataset):
+
+     def __init__(self, name_im_gt_mid_list, cache_size=[], cache_path='./cache', cache_file_name='dataset.json', cache_boost=False, transform=None, is_train=True):
+
+         self.is_train = is_train
+         self.cache_size = cache_size
+         self.cache_path = cache_path
+         self.cache_file_name = cache_file_name
+         self.cache_boost_name = ""
+
+         self.cache_boost = cache_boost
+
+         ## cache all the images and ground truth into a single pytorch tensor
+         self.ims_pt = None
+         self.gts_pt = None
+         self.mid_pt = None
+
+         ## we will cache the npy as well regardless of the cache_boost
+         self.cache_boost_name = cache_file_name.split('.json')[0]
+
+         self.transform = transform
+
+         self.dataset = {}
+
+         ## combine different datasets into one
+         dataset_names = []
+         dt_name_list = []   # dataset name per image
+         im_name_list = []   # image name
+         im_path_list = []   # im path
+         gt_path_list = []   # gt path
+         mid_path_list = []
+         im_ext_list = []    # im ext
+         gt_ext_list = []    # gt ext
+         mid_ext_list = []
+         for i in range(0, len(name_im_gt_mid_list)):
+             dataset_names.append(name_im_gt_mid_list[i]["dataset_name"])
+             # dataset name repeated based on the number of images in this dataset
+             dt_name_list.extend([name_im_gt_mid_list[i]["dataset_name"] for x in name_im_gt_mid_list[i]["im_path"]])
+             im_name_list.extend([x.split(os.sep)[-1].split(name_im_gt_mid_list[i]["im_ext"])[0] for x in name_im_gt_mid_list[i]["im_path"]])
+             im_path_list.extend(name_im_gt_mid_list[i]["im_path"])
+             gt_path_list.extend(name_im_gt_mid_list[i]["gt_path"])
+             mid_path_list.extend(name_im_gt_mid_list[i]["mid_path"])
+             im_ext_list.extend([name_im_gt_mid_list[i]["im_ext"] for x in name_im_gt_mid_list[i]["im_path"]])
+             gt_ext_list.extend([name_im_gt_mid_list[i]["gt_ext"] for x in name_im_gt_mid_list[i]["gt_path"]])
+             mid_ext_list.extend([name_im_gt_mid_list[i]["mid_ext"] for x in name_im_gt_mid_list[i]["mid_path"]])
+
+         self.dataset["data_name"] = dt_name_list
+         self.dataset["im_name"] = im_name_list
+         self.dataset["im_path"] = im_path_list
+         self.dataset["ori_im_path"] = deepcopy(im_path_list)
+         self.dataset["gt_path"] = gt_path_list
+         self.dataset["ori_gt_path"] = deepcopy(gt_path_list)
+         self.dataset["mid_path"] = mid_path_list
+         self.dataset["ori_mid_path"] = deepcopy(mid_path_list)
+         self.dataset["im_shp"] = []
+         self.dataset["gt_shp"] = []
+         self.dataset["mid_shp"] = []
+         self.dataset["im_ext"] = im_ext_list
+         self.dataset["gt_ext"] = gt_ext_list
+         self.dataset["mid_ext"] = mid_ext_list
+
+         self.dataset["ims_pt_dir"] = ""
+         self.dataset["gts_pt_dir"] = ""
+         self.dataset["mid_pt_dir"] = ""
+
+         self.dataset = self.manage_cache(dataset_names)
+
+     def manage_cache(self, dataset_names):
+         if not os.path.exists(self.cache_path):  # create the folder for cache
+             os.makedirs(self.cache_path)
+         cache_folder = os.path.join(self.cache_path, "_".join(dataset_names)+"_"+"x".join([str(x) for x in self.cache_size]))
+         if not os.path.exists(cache_folder):  # check if the cache files are there, if not then cache
+             return self.cache(cache_folder)
+         return self.load_cache(cache_folder)
+
+     def cache(self, cache_folder):
+         os.mkdir(cache_folder)
+         cached_dataset = deepcopy(self.dataset)
+
+         ims_pt_list = []
+         gts_pt_list = []
+         mid_pt_list = []
+         for i, im_path in tqdm(enumerate(self.dataset["im_path"]), total=len(self.dataset["im_path"])):
+
+             im_id = cached_dataset["im_name"][i]
+             im = im_reader(im_path)
+             im, im_shp = im_preprocess(im, self.cache_size)
+             im_cache_file = os.path.join(cache_folder, self.dataset["data_name"][i]+"_"+im_id+"_im.pt")
+             torch.save(im, im_cache_file)
+
+             cached_dataset["im_path"][i] = im_cache_file
+             if(self.cache_boost):
+                 ims_pt_list.append(torch.unsqueeze(im, 0))
+
+             gt = np.zeros(im.shape[0:2])
+             if len(self.dataset["gt_path"])!=0:
+                 gt = im_reader(self.dataset["gt_path"][i])
+             gt, gt_shp = gt_preprocess(gt, self.cache_size)
+             gt_cache_file = os.path.join(cache_folder, self.dataset["data_name"][i]+"_"+im_id+"_gt.pt")
+             torch.save(gt, gt_cache_file)
+             if len(self.dataset["gt_path"])>0:
+                 cached_dataset["gt_path"][i] = gt_cache_file
+             else:
+                 cached_dataset["gt_path"].append(gt_cache_file)
+             if(self.cache_boost):
+                 gts_pt_list.append(torch.unsqueeze(gt, 0))
+
+             mid = np.zeros(im.shape[0:2])
+             if len(self.dataset["mid_path"])!=0:
+                 mid = im_reader(self.dataset["mid_path"][i])
+             mid, mid_shp = gt_preprocess(mid, self.cache_size)
+             mid_cache_file = os.path.join(cache_folder, self.dataset["data_name"][i]+"_"+im_id+"_mid.pt")
+             torch.save(mid, mid_cache_file)
+             if len(self.dataset["mid_path"])>0:
+                 cached_dataset["mid_path"][i] = mid_cache_file
+             else:
+                 cached_dataset["mid_path"].append(mid_cache_file)
+             if(self.cache_boost):
+                 mid_pt_list.append(torch.unsqueeze(mid, 0))
+
+             cached_dataset["im_shp"].append(im_shp)
+             cached_dataset["gt_shp"].append(gt_shp)
+             cached_dataset["mid_shp"].append(mid_shp)
+
+         if(self.cache_boost):
+             cached_dataset["ims_pt_dir"] = os.path.join(cache_folder, self.cache_boost_name+'_ims.pt')
+             cached_dataset["gts_pt_dir"] = os.path.join(cache_folder, self.cache_boost_name+'_gts.pt')
+             cached_dataset["mid_pt_dir"] = os.path.join(cache_folder, self.cache_boost_name+'_mids.pt')
+             self.ims_pt = torch.cat(ims_pt_list, dim=0)
+             self.gts_pt = torch.cat(gts_pt_list, dim=0)
+             self.mid_pt = torch.cat(mid_pt_list, dim=0)
+             torch.save(torch.cat(ims_pt_list, dim=0), cached_dataset["ims_pt_dir"])
+             torch.save(torch.cat(gts_pt_list, dim=0), cached_dataset["gts_pt_dir"])
+             torch.save(torch.cat(mid_pt_list, dim=0), cached_dataset["mid_pt_dir"])
+
+         try:
+             json_file = open(os.path.join(cache_folder, self.cache_file_name), "w")
+             json.dump(cached_dataset, json_file)
+             json_file.close()
+         except Exception:
+             raise FileNotFoundError("Cannot create JSON")
+         return cached_dataset
+
+     def load_cache(self, cache_folder):
+         print(os.path.join(cache_folder, self.cache_file_name))
+         json_file = open(os.path.join(cache_folder, self.cache_file_name), "r")
+         dataset = json.load(json_file)
+         json_file.close()
+         ## if cache_boost is true, load the concatenated tensors into RAM;
+         ## otherwise the per-sample pytorch tensors will be loaded from disk
+         if(self.cache_boost):
+             self.ims_pt = torch.load(dataset["ims_pt_dir"], map_location='cpu')
+             self.gts_pt = torch.load(dataset["gts_pt_dir"], map_location='cpu')
+             self.mid_pt = torch.load(dataset["mid_pt_dir"], map_location='cpu')
+         return dataset
+
+     def __len__(self):
+         return len(self.dataset["im_path"])
+
+     def __getitem__(self, idx):
+
+         im = None
+         gt = None
+         mid = None
+         if(self.cache_boost and self.ims_pt is not None):
+             im = self.ims_pt[idx]
+             gt = self.gts_pt[idx]
+             mid = self.mid_pt[idx]
+         else:
+             im_pt_path = os.path.join(self.cache_path, os.sep.join(self.dataset["im_path"][idx].split(os.sep)[-2:]))
+             im = torch.load(im_pt_path)
+             gt_pt_path = os.path.join(self.cache_path, os.sep.join(self.dataset["gt_path"][idx].split(os.sep)[-2:]))
+             gt = torch.load(gt_pt_path)
+             mid_pt_path = os.path.join(self.cache_path, os.sep.join(self.dataset["mid_path"][idx].split(os.sep)[-2:]))
+             mid = torch.load(mid_pt_path)
+
+         im_shp = self.dataset["im_shp"][idx]
+
+         # build a filled bounding-box channel from the ground-truth foreground extent
+         box = torch.zeros_like(gt[0]) + gt[0]
+         rows, cols = torch.where(box > 0)
+         left = torch.min(cols)
+         top = torch.min(rows)
+         right = torch.max(cols)
+         bottom = torch.max(rows)
+         box[top:bottom, left:right] = 255
+         box[box != 255] = 0
+         box = box[None, ...]
+         gim = torch.cat([im, mid, box], dim=0)  # 3-channel image + mid mask + box = 5 channels
+
+         im = torch.divide(gim, 255.0)
+         gt = torch.divide(gt, 255.0)
+         mask = torch.divide(mid, 255.0)
+         box = torch.divide(box, 255.0)
+
+         sample = {
+             "imidx": torch.from_numpy(np.array(idx)),
+             "image": im,
+             "label": gt,
+             "mask": mask,
+             'box': box,
+             "shape": torch.from_numpy(np.array(im_shp)),
+         }
+
+         if self.transform:
+             sample = self.transform(sample)
+         return sample
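
For orientation, here is a hedged sketch of wiring this loader together; the demo dataset dict, the 5-element normalization stats, and the cache size are placeholders (the real dataset dicts live in `IS_Net/datalist.py` below, and the actual transform composition is defined in `IS_Net/train_valid_inference_main.py`):

```python
from IS_Net.data_loader import get_im_gt_name_dict, create_dataloaders, GOSNormalize

# hypothetical dataset dict in the same format as IS_Net/datalist.py
dataset_demo = {"name": "demo", "im_dir": "demo/im", "gt_dir": "demo/gt",
                "mid_dir": "demo/mask", "im_ext": ".jpg", "gt_ext": ".png",
                "mid_ext": ".png", "cache_dir": "demo-cache"}

name_list = get_im_gt_name_dict([dataset_demo], flag="valid")
# __getitem__ concatenates RGB + mid mask + box into a 5-channel "image",
# so GOSNormalize is given 5-element mean/std here (the values are placeholders)
loaders, datasets = create_dataloaders(
    name_list,
    cache_size=[1024, 1024],  # images are resized and cached at this resolution
    cache_boost=False,        # read per-sample .pt files instead of one big tensor
    my_transforms=[GOSNormalize([0.485, 0.456, 0.406, 0, 0],
                                [0.229, 0.224, 0.225, 1.0, 1.0])],
    batch_size=1, shuffle=False, is_train=False)
```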
IS_Net/datalist.py ADDED
@@ -0,0 +1,62 @@
+ dataset_test = {"name": "DIS5K-test",
+                 "im_dir": r"DIS5K/DIS5K-test/im",
+                 "gt_dir": r"DIS5K/DIS5K-test/gt",
+                 "mid_dir": r"DIS5K/DIS5K-test/mask",
+                 "im_ext": ".jpg",
+                 "gt_ext": ".png",
+                 "mid_ext": ".png",
+                 "cache_dir": r"DIS5K-Cache/DIS-test"}
+
+ dataset_tr = {"name": "DIS5K-TR-m",
+               "im_dir": r"DIS5K/DIS-TR/im",
+               "gt_dir": r"DIS5K/DIS-TR/gt",
+               "mid_dir": r"DIS5K-TR/mask",
+               "im_ext": ".jpg",
+               "gt_ext": ".png",
+               "mid_ext": ".png",
+               "cache_dir": r"DIS5K-Cache/DIS-TR-m"}
+
+ dataset_vd = {"name": "DIS5K-VD-m",
+               "im_dir": r"DIS5K/DIS-VD/im",
+               "gt_dir": r"DIS5K/DIS-VD/gt",
+               "mid_dir": r"DIS5K/DIS5K-VD/mask",
+               "im_ext": ".jpg",
+               "gt_ext": ".png",
+               "mid_ext": ".png",
+               "cache_dir": r"DIS5K-Cache/DIS-VD-m"}
+
+ dataset_te1 = {"name": "DIS5K-TE1-m",
+                "im_dir": r"DIS5K/DIS-TE1/im",
+                "gt_dir": r"DIS5K/DIS-TE1/gt",
+                "mid_dir": r"DIS5K/DIS5K-TE1/mask",
+                "im_ext": ".jpg",
+                "gt_ext": ".png",
+                "mid_ext": ".png",
+                "cache_dir": r"DIS5K-Cache/DIS-TE1-m"}
+
+ dataset_te2 = {"name": "DIS5K-TE2-m",
+                "im_dir": r"DIS5K/DIS-TE2/im",
+                "gt_dir": r"DIS5K/DIS-TE2/gt",
+                "mid_dir": r"DIS5K/DIS5K-TE2/mask",
+                "im_ext": ".jpg",
+                "gt_ext": ".png",
+                "mid_ext": ".png",
+                "cache_dir": r"DIS5K-Cache/DIS-TE2-m"}
+
+ dataset_te3 = {"name": "DIS5K-TE3-m",
+                "im_dir": r"DIS5K/DIS-TE3/im",
+                "gt_dir": r"DIS5K/DIS-TE3/gt",
+                "mid_dir": r"DIS5K/DIS5K-TE3/mask",
+                "im_ext": ".jpg",
+                "gt_ext": ".png",
+                "mid_ext": ".png",
+                "cache_dir": r"DIS5K-Cache/DIS-TE3-m"}
+
+ dataset_te4 = {"name": "DIS5K-TE4-m",
+                "im_dir": r"DIS5K/DIS-TE4/im",
+                "gt_dir": r"DIS5K/DIS-TE4/gt",
+                "mid_dir": r"DIS5K/DIS5K-TE4/mask",
+                "im_ext": ".jpg",
+                "gt_ext": ".png",
+                "mid_ext": ".png",
+                "cache_dir": r"DIS5K-Cache/DIS-TE4-m"}
IS_Net/models/__pycache__/isnet.cpython-311.pyc ADDED
Binary file (33.1 kB).
 
IS_Net/models/isnet.py ADDED
@@ -0,0 +1,640 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import models
4
+ import torch.nn.functional as F
5
+ from timm.models.layers import trunc_normal_, DropPath
6
+ import matplotlib.pyplot as plt
7
+ import monai
8
+
9
+ def iou_loss(pred, mask):
10
+ inter = (pred * mask).sum(dim=(2, 3)) #交集
11
+ union = (pred + mask).sum(dim=(2, 3)) - inter #并集-交集
12
+ iou = 1 - (inter + 1) / (union + 1)
13
+ return iou.mean()
14
+
15
+
16
+ bce_loss = nn.BCELoss(reduction='mean')
17
+
18
+ def muti_loss_fusion(preds, target):
19
+ loss0 = 0.0
20
+ loss = 0.0
21
+
22
+ for i in range(0,len(preds)):
23
+ # print("i: ", i, preds[i].shape)
24
+ if(preds[i].shape[2]!=target.shape[2] or preds[i].shape[3]!=target.shape[3]):
25
+ # tmp_target = _upsample_like(target,preds[i])
26
+ tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True)
27
+ loss = loss + 20*bce_loss(preds[i],tmp_target) + 0.5*iou_loss(preds[i],tmp_target)
28
+ # loss = loss + bce_loss(preds[i],tmp_target)+ iou_loss(preds[i],tmp_target)
29
+ # loss = loss + bce_loss(preds[i],tmp_target)
30
+ else:
31
+ loss = loss + 20*bce_loss(preds[i],target) + 0.5*iou_loss(preds[i],target)
32
+ # loss = loss + bce_loss(preds[i],target) + iou_loss(preds[i],target)
33
+ # loss = loss + bce_loss(preds[i],target)
34
+ if(i==0):
35
+ loss0 = loss
36
+ return loss0, loss
37
+
38
+ MSE_loss = nn.MSELoss(reduction='mean')
39
+ kl_loss = nn.KLDivLoss(reduction='mean')
40
+ l1_loss = nn.L1Loss(reduction='mean')
41
+ smooth_l1_loss = nn.SmoothL1Loss(reduction='mean')
42
+ def muti_loss_fusion_kl(preds, target, dfs, fs, mode='MSE'):
43
+ loss0 = 0.0
44
+ loss = 0.0
45
+
46
+ for i in range(0,len(preds)):
47
+ # print("i: ", i, preds[i].shape)
48
+ if(preds[i].shape[2]!=target.shape[2] or preds[i].shape[3]!=target.shape[3]):
49
+ # tmp_target = _upsample_like(target,preds[i])
50
+ tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True)
51
+ loss = loss + 20*bce_loss(preds[i],tmp_target) + 0.5*iou_loss(preds[i],tmp_target)
52
+ # loss = loss + bce_loss(preds[i],tmp_target) + iou_loss(preds[i],tmp_target)
53
+ # loss = loss + bce_loss(preds[i],tmp_target)
54
+ else:
55
+ loss = loss + 20*bce_loss(preds[i],target) + 0.5*iou_loss(preds[i],target)
56
+ # loss = loss + bce_loss(preds[i],target) + iou_loss(preds[i],target)
57
+ # loss = loss + bce_loss(preds[i],target)
58
+ if(i==0):
59
+ loss0 = loss
60
+
61
+ for i in range(0,len(dfs)):
62
+ if(mode=='MSE'):
63
+ loss = loss + MSE_loss(dfs[i],fs[i]) ### add the mse loss of features as additional constraints
64
+ # print("fea_loss: ", fea_loss(dfs[i],fs[i]).item())
65
+ elif(mode=='KL'):
66
+ loss = loss + kl_loss(F.log_softmax(dfs[i],dim=1),F.softmax(fs[i],dim=1))
67
+ # print("kl_loss: ", kl_loss(F.log_softmax(dfs[i],dim=1),F.softmax(fs[i],dim=1)).item())
68
+ elif(mode=='MAE'):
69
+ loss = loss + l1_loss(dfs[i],fs[i])
70
+ # print("ls_loss: ", l1_loss(dfs[i],fs[i]))
71
+ elif(mode=='SmoothL1'):
72
+ loss = loss + smooth_l1_loss(dfs[i],fs[i])
73
+ # print("SmoothL1: ", smooth_l1_loss(dfs[i],fs[i]).item())
74
+
75
+ return loss0, loss
76
+
77
+ class REBNCONV(nn.Module):
78
+ def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
79
+ super(REBNCONV,self).__init__()
80
+
81
+ self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate,stride=stride)
82
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
83
+ self.relu_s1 = nn.ReLU(inplace=True)
84
+
85
+ def forward(self,x):
86
+
87
+ hx = x
88
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
89
+
90
+ return xout
91
+
92
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
93
+ def _upsample_like(src,tar):
94
+
95
+ src = F.upsample(src,size=tar.shape[2:],mode='bilinear')
96
+
97
+ return src
98
+
99
+
100
+ ### RSU-7 ###
101
+ class RSU7(nn.Module):
102
+
103
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
104
+ super(RSU7,self).__init__()
105
+
106
+ self.in_ch = in_ch
107
+ self.mid_ch = mid_ch
108
+ self.out_ch = out_ch
109
+
110
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) ## 1 -> 1/2
111
+
112
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
113
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
114
+
115
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
116
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
117
+
118
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
119
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
120
+
121
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
122
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
123
+
124
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
125
+ self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
126
+
127
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
128
+
129
+ self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
130
+
131
+ self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
132
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
133
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
134
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
135
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
136
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
137
+
138
+ def forward(self,x):
139
+ b, c, h, w = x.shape
140
+
141
+ hx = x
142
+ hxin = self.rebnconvin(hx)
143
+
144
+ hx1 = self.rebnconv1(hxin)
145
+ hx = self.pool1(hx1)
146
+
147
+ hx2 = self.rebnconv2(hx)
148
+ hx = self.pool2(hx2)
149
+
150
+ hx3 = self.rebnconv3(hx)
151
+ hx = self.pool3(hx3)
152
+
153
+ hx4 = self.rebnconv4(hx)
154
+ hx = self.pool4(hx4)
155
+
156
+ hx5 = self.rebnconv5(hx)
157
+ hx = self.pool5(hx5)
158
+
159
+ hx6 = self.rebnconv6(hx)
160
+
161
+ hx7 = self.rebnconv7(hx6)
162
+
163
+ hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
164
+ hx6dup = _upsample_like(hx6d,hx5)
165
+
166
+ hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
167
+ hx5dup = _upsample_like(hx5d,hx4)
168
+
169
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
170
+ hx4dup = _upsample_like(hx4d,hx3)
171
+
172
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
173
+ hx3dup = _upsample_like(hx3d,hx2)
174
+
175
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
176
+ hx2dup = _upsample_like(hx2d,hx1)
177
+
178
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
179
+
180
+ return hx1d + hxin
181
+
182
+
183
+ ### RSU-6 ###
184
+ class RSU6(nn.Module):
185
+
186
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
187
+ super(RSU6,self).__init__()
188
+
189
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
190
+
191
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
192
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
193
+
194
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
195
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
196
+
197
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
198
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
199
+
200
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
201
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
202
+
203
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
204
+
205
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
206
+
207
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
208
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
209
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
210
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
211
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
212
+
213
+ def forward(self,x):
214
+
215
+ hx = x
216
+
217
+ hxin = self.rebnconvin(hx)
218
+
219
+ hx1 = self.rebnconv1(hxin)
220
+ hx = self.pool1(hx1)
221
+
222
+ hx2 = self.rebnconv2(hx)
223
+ hx = self.pool2(hx2)
224
+
225
+ hx3 = self.rebnconv3(hx)
226
+ hx = self.pool3(hx3)
227
+
228
+ hx4 = self.rebnconv4(hx)
229
+ hx = self.pool4(hx4)
230
+
231
+ hx5 = self.rebnconv5(hx)
232
+
233
+ hx6 = self.rebnconv6(hx5)
234
+
235
+
236
+ hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
237
+ hx5dup = _upsample_like(hx5d,hx4)
238
+
239
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
240
+ hx4dup = _upsample_like(hx4d,hx3)
241
+
242
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
243
+ hx3dup = _upsample_like(hx3d,hx2)
244
+
245
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
246
+ hx2dup = _upsample_like(hx2d,hx1)
247
+
248
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
249
+
250
+ return hx1d + hxin
251
+
252
+ ### RSU-5 ###
253
+ class RSU5(nn.Module):
254
+
255
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
256
+ super(RSU5,self).__init__()
257
+
258
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
259
+
260
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
261
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
262
+
263
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
264
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
265
+
266
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
267
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
268
+
269
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
270
+
271
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
272
+
273
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
274
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
275
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
276
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
277
+
278
+ def forward(self,x):
279
+
280
+ hx = x
281
+
282
+ hxin = self.rebnconvin(hx)
283
+
284
+ hx1 = self.rebnconv1(hxin)
285
+ hx = self.pool1(hx1)
286
+
287
+ hx2 = self.rebnconv2(hx)
288
+ hx = self.pool2(hx2)
289
+
290
+ hx3 = self.rebnconv3(hx)
291
+ hx = self.pool3(hx3)
292
+
293
+ hx4 = self.rebnconv4(hx)
294
+
295
+ hx5 = self.rebnconv5(hx4)
296
+
297
+ hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
298
+ hx4dup = _upsample_like(hx4d,hx3)
299
+
300
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
301
+ hx3dup = _upsample_like(hx3d,hx2)
302
+
303
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
304
+ hx2dup = _upsample_like(hx2d,hx1)
305
+
306
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
307
+
308
+ return hx1d + hxin
309
+
310
+ ### RSU-4 ###
311
+ class RSU4(nn.Module):
312
+
313
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
314
+ super(RSU4,self).__init__()
315
+
316
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
317
+
318
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
319
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
320
+
321
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
322
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
323
+
324
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
325
+
326
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
327
+
328
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
329
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
330
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
331
+
332
+ def forward(self,x):
333
+
334
+ hx = x
335
+
336
+ hxin = self.rebnconvin(hx)
337
+
338
+ hx1 = self.rebnconv1(hxin)
339
+ hx = self.pool1(hx1)
340
+
341
+ hx2 = self.rebnconv2(hx)
342
+ hx = self.pool2(hx2)
343
+
344
+ hx3 = self.rebnconv3(hx)
345
+
346
+ hx4 = self.rebnconv4(hx3)
347
+
348
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
349
+ hx3dup = _upsample_like(hx3d,hx2)
350
+
351
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
352
+ hx2dup = _upsample_like(hx2d,hx1)
353
+
354
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
355
+
356
+ return hx1d + hxin
357
+
358
+ ### RSU-4F ###
359
+ class RSU4F(nn.Module):
360
+
361
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
362
+ super(RSU4F,self).__init__()
363
+
364
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
365
+
366
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
367
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
368
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
369
+
370
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
371
+
372
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
373
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
374
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
375
+
376
+ def forward(self,x):
377
+
378
+ hx = x
379
+
380
+ hxin = self.rebnconvin(hx)
381
+
382
+ hx1 = self.rebnconv1(hxin)
383
+ hx2 = self.rebnconv2(hx1)
384
+ hx3 = self.rebnconv3(hx2)
385
+
386
+ hx4 = self.rebnconv4(hx3)
387
+
388
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
389
+ hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
390
+ hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
391
+
392
+ return hx1d + hxin
393
+
394
+
395
+ class myrebnconv(nn.Module):
396
+ def __init__(self, in_ch=3,
397
+ out_ch=1,
398
+ kernel_size=3,
399
+ stride=1,
400
+ padding=1,
401
+ dilation=1,
402
+ groups=1):
403
+ super(myrebnconv,self).__init__()
404
+
405
+ self.conv = nn.Conv2d(in_ch,
406
+ out_ch,
407
+ kernel_size=kernel_size,
408
+ stride=stride,
409
+ padding=padding,
410
+ dilation=dilation,
411
+ groups=groups)
412
+ self.bn = nn.BatchNorm2d(out_ch)
413
+ self.rl = nn.ReLU(inplace=True)
414
+
415
+ def forward(self,x):
416
+ return self.rl(self.bn(self.conv(x)))
417
+
418
+
419
+ class ISNetGTEncoder(nn.Module):
420
+
421
+ def __init__(self,in_ch=1,out_ch=1):
422
+ super(ISNetGTEncoder,self).__init__()
423
+
424
+ self.conv_in = myrebnconv(in_ch,16,3,stride=2,padding=1) # nn.Conv2d(in_ch,64,3,stride=2,padding=1)
425
+
426
+ self.stage1 = RSU7(16,16,64)
427
+ self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
428
+
429
+ self.stage2 = RSU6(64,16,64)
430
+ self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
431
+
432
+ self.stage3 = RSU5(64,32,128)
433
+ self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
434
+
435
+ self.stage4 = RSU4(128,32,256)
436
+ self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
437
+
438
+ self.stage5 = RSU4F(256,64,512)
439
+ self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
440
+
441
+ self.stage6 = RSU4F(512,64,512)
442
+
443
+
444
+ self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
445
+         self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
+         self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
+         self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
+         self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
+
+     def compute_loss(self, preds, targets):
+         return muti_loss_fusion(preds, targets)
+
+     def forward(self, x):
+         hx = x
+
+         hxin = self.conv_in(hx)
+         # hx = self.pool_in(hxin)
+
+         # stage 1
+         hx1 = self.stage1(hxin)
+         hx = self.pool12(hx1)
+
+         # stage 2
+         hx2 = self.stage2(hx)
+         hx = self.pool23(hx2)
+
+         # stage 3
+         hx3 = self.stage3(hx)
+         hx = self.pool34(hx3)
+
+         # stage 4
+         hx4 = self.stage4(hx)
+         hx = self.pool45(hx4)
+
+         # stage 5
+         hx5 = self.stage5(hx)
+         hx = self.pool56(hx5)
+
+         # stage 6
+         hx6 = self.stage6(hx)
+
+         # side outputs, upsampled to the input resolution
+         d1 = self.side1(hx1)
+         d1 = _upsample_like(d1, x)
+
+         d2 = self.side2(hx2)
+         d2 = _upsample_like(d2, x)
+
+         d3 = self.side3(hx3)
+         d3 = _upsample_like(d3, x)
+
+         d4 = self.side4(hx4)
+         d4 = _upsample_like(d4, x)
+
+         d5 = self.side5(hx5)
+         d5 = _upsample_like(d5, x)
+
+         d6 = self.side6(hx6)
+         d6 = _upsample_like(d6, x)
+
+         # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
+
+         # torch.sigmoid replaces the deprecated F.sigmoid
+         return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], [hx1, hx2, hx3, hx4, hx5, hx6]
+
+
+ class ISNetDIS(nn.Module):
+
+     def __init__(self, in_ch=3, out_ch=1):
+         super(ISNetDIS, self).__init__()
+
+         self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
+         self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage1 = RSU7(64, 32, 64)
+         self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage2 = RSU6(64, 32, 128)
+         self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage3 = RSU5(128, 64, 256)
+         self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage4 = RSU4(256, 128, 512)
+         self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage5 = RSU4F(512, 256, 512)
+         self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage6 = RSU4F(512, 256, 512)
+
+         # decoder
+         self.stage5d = RSU4F(1024, 256, 512)
+         self.stage4d = RSU4(1024, 128, 256)
+         self.stage3d = RSU5(512, 64, 128)
+         self.stage2d = RSU6(256, 32, 64)
+         self.stage1d = RSU7(128, 16, 64)
+
+         self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
+         self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
+         self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
+         self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
+
+         # self.outconv = nn.Conv2d(6*out_ch, out_ch, 1)
+
+     def compute_loss_kl(self, preds, targets, dfs, fs, mode='MSE'):
+         # return muti_loss_fusion(preds, targets)
+         return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode)
+
+     def compute_loss(self, preds, targets):
+         return muti_loss_fusion(preds, targets)
+
+     def forward(self, x):
+         hx = x
+
+         hxin = self.conv_in(hx)
+
+         # stage 1
+         hx1 = self.stage1(hxin)
+         hx = self.pool12(hx1)
+
+         # stage 2
+         hx2 = self.stage2(hx)
+         hx = self.pool23(hx2)
+
+         # stage 3
+         hx3 = self.stage3(hx)
+         hx = self.pool34(hx3)
+
+         # stage 4
+         hx4 = self.stage4(hx)
+         hx = self.pool45(hx4)
+
+         # stage 5
+         hx5 = self.stage5(hx)
+         hx = self.pool56(hx5)
+
+         # stage 6
+         hx6 = self.stage6(hx)
+
+         hx6up = _upsample_like(hx6, hx5)
+
+         # -------------------- decoder --------------------
+         hx5d = self.stage5d(torch.cat([hx6up, hx5], 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.stage4d(torch.cat([hx5dup, hx4], 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.stage3d(torch.cat([hx4dup, hx3], 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.stage2d(torch.cat([hx3dup, hx2], 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.stage1d(torch.cat([hx2dup, hx1], 1))
+
+         # side outputs, upsampled to the input resolution
+         d1 = self.side1(hx1d)
+         d1 = _upsample_like(d1, x)
+
+         d2 = self.side2(hx2d)
+         d2 = _upsample_like(d2, x)
+
+         d3 = self.side3(hx3d)
+         d3 = _upsample_like(d3, x)
+
+         d4 = self.side4(hx4d)
+         d4 = _upsample_like(d4, x)
+
+         d5 = self.side5(hx5d)
+         d5 = _upsample_like(d5, x)
+
+         d6 = self.side6(hx6)
+         d6 = _upsample_like(d6, x)
+
+         # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
+         # debugging previews of the decoder feature maps:
+         # plt.imshow(hx1d[0][0].cpu().detach().numpy(), cmap='gray'); plt.show()
+         # plt.imshow(hx2d[0][0].cpu().detach().numpy(), cmap='gray'); plt.show()
+         # plt.imshow(hx3d[0][0].cpu().detach().numpy(), cmap='gray'); plt.show()
+         # plt.imshow(hx4d[0][0].cpu().detach().numpy(), cmap='gray'); plt.show()
+         # plt.imshow(hx5d[0][0].cpu().detach().numpy(), cmap='gray'); plt.show()
+         # plt.imshow(hx6[0][0].cpu().detach().numpy(), cmap='gray'); plt.show()
+         return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
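All six side outputs are resized to the input resolution, so downstream code can supervise every decoder stage against the same ground truth. A minimal shape check, assuming `IS_Net/models/isnet.py` is importable under the path below (the import path is an assumption based on this repo's layout):

```python
# Minimal sketch: verify ISNetDIS returns six input-sized probability maps.
# The import path is an assumption based on this repo's directory layout.
import torch
from IS_Net.models.isnet import ISNetDIS

net = ISNetDIS(in_ch=3, out_ch=1).eval()
x = torch.randn(1, 3, 256, 256)  # DIS training commonly uses 1024x1024; 256 keeps the demo light
with torch.no_grad():
    side_outputs, decoder_feats = net(x)

for i, d in enumerate(side_outputs, start=1):
    assert d.shape == (1, 1, 256, 256)  # every side map matches the input size
    print(f"d{i}: min={d.min():.3f} max={d.max():.3f}")  # sigmoid => values in (0, 1)
```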
IS_Net/saliency_toolbox.py ADDED
@@ -0,0 +1,552 @@
+ import os
+ import cv2
+ import sys
+ import numpy as np
+ from glob import glob
+ from tqdm import tqdm
+ from scipy.ndimage import correlate, distance_transform_edt
+ from joblib import Parallel, delayed
+
+ eps = sys.float_info.epsilon
+
+ def calculate_once(gt_name, sm_dir, gt_threshold, beta, measures):
+     values = dict()
+     for idx in measures:
+         values[idx] = list()
+         if idx == 'Max-F':
+             values['Precision'] = list()
+             values['Recall'] = list()
+     _, name = os.path.split(gt_name)
+     sm_name = os.path.join(sm_dir, name)
+
+     if os.path.exists(sm_name):
+         gt, sm = read_and_normalize(gt_name, sm_name, gt_threshold)
+
+         if 'MAE' in measures:
+             values['MAE'].append(mean_absolute_error(gt, sm))
+         if 'E-measure' in measures:
+             values['E-measure'].append(e_measure(gt, sm))
+         if 'S-measure' in measures:
+             values['S-measure'].append(s_measure(gt, sm))
+         if 'Adp-F' in measures:
+             values['Adp-F'].append(adaptive_fmeasure(gt, sm, beta))
+         if 'Wgt-F' in measures:
+             values['Wgt-F'].append(weighted_fmeasure(gt, sm))
+         if 'Max-F' in measures:
+             prec, recall = prec_recall(gt, sm, 256)  # 256 thresholds between 0 and 1
+             values['Precision'].append(prec)
+             values['Recall'].append(recall)
+     else:
+         print("\n{} not found!".format(os.path.basename(sm_name)))
+         print('---' * 10)
+     return values
+
+ def calculate_measures(gt_dir, sm_dir, measures, save=False, beta=np.sqrt(0.3), gt_threshold=0.5, n_thread=1):
+     """
+     Calculates saliency measures for the given directories.
+
+     Parameters
+     ----------
+     gt_dir : str
+         The path to the ground truth directory
+     sm_dir : str
+         The path to the predicted saliency map directory
+     measures : list
+         List of measure names to calculate.
+         Supported measures: 'MAE'       => Mean Absolute Error
+                             'E-measure' => Enhanced-alignment measure
+                             'S-measure' => Structure-measure
+                             'Max-F'     => Maximum F-measure
+                             'Adp-F'     => Adaptive F-measure
+                             'Wgt-F'     => Weighted F-measure
+     save : str
+         If specified, the results will be saved in the 'save' directory
+     beta : float
+         The beta parameter used in the F-measure formula; default is sqrt(0.3)
+     gt_threshold : float
+         The threshold used to binarize the ground truth maps
+
+     Returns
+     -------
+     values : dictionary
+         a dict containing the results
+     """
+     values = dict()
+     for idx in measures:
+         values[idx] = list()
+         if idx == 'Max-F':
+             values['Precision'] = list()
+             values['Recall'] = list()
+
+     gt_names = glob(os.path.join(gt_dir, '*'))
+     results = Parallel(n_jobs=n_thread)(
+         delayed(calculate_once)(gt_name, sm_dir, gt_threshold, beta, measures)
+         for gt_name in tqdm(gt_names, total=len(gt_names)))
+     for i in results:
+         if 'MAE' in measures:
+             values['MAE'].append(i["MAE"])
+         if 'E-measure' in measures:
+             values['E-measure'].append(i["E-measure"])
+         if 'S-measure' in measures:
+             values['S-measure'].append(i["S-measure"])
+         if 'Adp-F' in measures:
+             values['Adp-F'].append(i["Adp-F"])
+         if 'Wgt-F' in measures:
+             values['Wgt-F'].append(i["Wgt-F"])
+         if 'Max-F' in measures:  # 256 thresholds between 0 and 1
+             values['Precision'].append(i["Precision"])
+             values['Recall'].append(i["Recall"])
+
+     if 'MAE' in measures:
+         values['MAE'] = np.mean(values['MAE'])
+
+     if 'E-measure' in measures:
+         values['E-measure'] = np.mean(values['E-measure'])
+
+     if 'S-measure' in measures:
+         values['S-measure'] = np.mean(values['S-measure'])
+
+     if 'Adp-F' in measures:
+         values['Adp-F'] = np.mean(values['Adp-F'])
+
+     if 'Wgt-F' in measures:
+         values['Wgt-F'] = np.mean(values['Wgt-F'])
+
+     if 'Max-F' in measures:
+         values['Precision'] = np.mean(np.hstack(values['Precision'][:]), 1)
+         values['Recall'] = np.mean(np.hstack(values['Recall'][:]), 1)
+         f_measures = (1 + beta ** 2) * values['Precision'] * values['Recall'] / (
+             beta ** 2 * values['Precision'] + values['Recall'])
+         values['Fmeasure_all_thresholds'] = f_measures
+         values['Max-F'] = np.max(f_measures)
+
+     if save:
+         if not os.path.isdir(save):
+             os.mkdir(save)
+         for key in values.keys():
+             np.save(os.path.join(save, key + ".npy"), values[key])
+
+     return values
+
+
+ def read_and_normalize(gt_path, sm_path, gt_threshold=0.5):
+     """
+     Reads, normalizes and binarizes/resizes a ground truth map and a saliency map.
+
+     Parameters
+     ----------
+     gt_path : str
+         The path to a ground truth map
+     sm_path : str
+         The path to a predicted saliency map
+     gt_threshold : float
+         The threshold used to binarize the ground truth map
+
+     Returns
+     -------
+     gt_img, sm_img : numpy.ndarray
+         The prepared arrays
+     """
+     gt_img = norm_img(cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE))
+     gt_img = (gt_img >= gt_threshold).astype(np.float32)
+     sm_img = norm_img(cv2.imread(sm_path, cv2.IMREAD_GRAYSCALE))
+     if sm_img.shape[0] != gt_img.shape[0] or sm_img.shape[1] != gt_img.shape[1]:
+         sm_img = cv2.resize(sm_img, (gt_img.shape[1], gt_img.shape[0]))
+
+     return gt_img, sm_img
+
+
+ def norm_img(im):
+     return cv2.normalize(im.astype('float'),
+                          None,
+                          0.0, 1.0,
+                          cv2.NORM_MINMAX)
+
+
+ # MAE (renamed from mean_square_error, which was misleading: this is the mean absolute error)
+ def mean_absolute_error(gt, sm):
+     return np.mean(np.abs(sm - gt))
+
+
+ # E-measure
+ # article: https://arxiv.org/abs/1805.10421
+ # original code [Matlab]: https://github.com/DengPingFan/E-measure
+ def e_measure(gt, sm):
+     """
+     Computes the Enhanced-alignment Measure (E-measure) between the saliency map and the ground truth.
+     article: https://arxiv.org/abs/1805.10421
+     original code [Matlab]: https://github.com/DengPingFan/E-measure
+
+     Parameters
+     ----------
+     gt : numpy.ndarray
+         The binarized ground truth map
+     sm : numpy.ndarray
+         The predicted saliency map
+
+     Returns
+     -------
+     value : float
+         The calculated E-measure
+     """
+     sm = adaptive_binary(sm)
+
+     gt = gt.astype(np.bool_)
+     sm = sm.astype(np.bool_)
+
+     dgt = gt.astype(np.float32)
+     dsm = sm.astype(np.float32)
+
+     if np.sum(dgt) == 0:  # if the gt is completely black
+         enhanced_matrix = 1.0 - dsm  # only calculate the black area of intersection
+     elif np.mean(dgt) == 1:  # if the gt is completely white
+         enhanced_matrix = dsm  # only calculate the white area of intersection
+     else:
+         # Normal case:
+         # 1. compute the alignment matrix
+         align_matrix = alignment_term(dsm, dgt)
+         # 2. compute the enhanced alignment matrix
+         enhanced_matrix = enhanced_alignment_term(align_matrix)
+
+     height, width = gt.shape
+     value = np.sum(enhanced_matrix) / (height * width - 1 + eps)
+     return value
+
+
+ def alignment_term(dsm, dgt):
+     # compute the global means
+     mu_fm = np.mean(dsm)
+     mu_gt = np.mean(dgt)
+
+     # compute the bias matrices
+     align_fm = dsm - mu_fm
+     align_gt = dgt - mu_gt
+
+     # compute the alignment matrix
+     align_matrix = 2 * (align_gt * align_fm) / (align_gt * align_gt + align_fm * align_fm + eps)
+     return align_matrix
+
+
+ def enhanced_alignment_term(align_matrix):
+     enhanced = ((align_matrix + 1) ** 2) / 4
+     return enhanced
+
+
+ def adaptive_binary(sm):
+     adaptive_threshold = 2 * np.mean(sm)
+
+     if adaptive_threshold > 1:
+         adaptive_threshold = 1
+
+     binary_sm = (sm >= adaptive_threshold).astype(np.float32)
+
+     return binary_sm
+
+
+ # S-measure
+ # article: https://www.crcv.ucf.edu/papers/iccv17/1164.pdf
+ # original code [Matlab]: https://github.com/DengPingFan/S-measure
+ def s_measure(gt, sm):
+     """
+     Computes the structural similarity (S-measure) between the saliency map and the ground truth.
+     article: https://www.crcv.ucf.edu/papers/iccv17/1164.pdf
+     original code [Matlab]: https://github.com/DengPingFan/S-measure
+
+     Parameters
+     ----------
+     gt : numpy.ndarray
+         The binarized ground truth map
+     sm : numpy.ndarray
+         The predicted saliency map
+
+     Returns
+     -------
+     value : float
+         The calculated S-measure
+     """
+     gt_mean = np.mean(gt)
+
+     if gt_mean == 0:  # if the GT is completely black
+         sm_mean = np.mean(sm)
+         measure = 1.0 - sm_mean  # only calculate the area of intersection
+     elif gt_mean == 1:  # if the GT is completely white
+         sm_mean = np.mean(sm)
+         measure = sm_mean.copy()  # only calculate the area of intersection
+     else:
+         alpha = 0.5
+         measure = alpha * s_object(sm, gt) + (1 - alpha) * s_region(sm, gt)
+         if measure < 0:
+             measure = 0
+
+     return measure
+
+
+ # note: parameter orders below were fixed to match the call sites (sm first, gt second)
+ def ssim(sm, gt):
+     gt = gt.astype(np.float32)
+
+     height, width = sm.shape
+     num_pixels = width * height
+
+     # compute the means of SM and GT
+     sm_mean = np.mean(sm)
+     gt_mean = np.mean(gt)
+
+     # compute the variances of SM and GT
+     sigma_x2 = np.sum(np.sum((sm - sm_mean) ** 2)) / (num_pixels - 1 + eps)
+     sigma_y2 = np.sum(np.sum((gt - gt_mean) ** 2)) / (num_pixels - 1 + eps)
+
+     # compute the covariance
+     sigma_xy = np.sum(np.sum((sm - sm_mean) * (gt - gt_mean))) / (num_pixels - 1 + eps)
+
+     alpha = 4 * sm_mean * gt_mean * sigma_xy
+     beta = (sm_mean ** 2 + gt_mean ** 2) * (sigma_x2 + sigma_y2)
+
+     if alpha != 0:
+         ssim_value = alpha / (beta + eps)
+     elif alpha == 0 and beta == 0:
+         ssim_value = 1.0
+     else:
+         ssim_value = 0
+
+     return ssim_value
+
+
+ def divide_sm(sm, x, y):
+     # copy the 4 regions
+     lt = sm[:y, :x]
+     rt = sm[:y, x:]
+     lb = sm[y:, :x]
+     rb = sm[y:, x:]
+
+     return lt, rt, lb, rb
+
+
+ def divide_gt(gt, x, y):
+     height, width = gt.shape
+     area = width * height
+
+     # copy the 4 regions
+     lt = gt[:y, :x]
+     rt = gt[:y, x:]
+     lb = gt[y:, :x]
+     rb = gt[y:, x:]
+
+     # the different weights (each block proportional to the GT foreground region)
+     w1 = (x * y) / area
+     w2 = ((width - x) * y) / area
+     w3 = (x * (height - y)) / area
+     w4 = 1.0 - w1 - w2 - w3
+
+     return lt, rt, lb, rb, w1, w2, w3, w4
+
+
+ def centroid(gt):
+     rows, cols = gt.shape
+
+     if np.sum(gt) == 0:
+         x = np.round(cols / 2)
+         y = np.round(rows / 2)
+     else:
+         total = np.sum(gt)
+         i = np.arange(cols).reshape(1, cols) + 1
+         j = np.arange(rows).reshape(rows, 1) + 1
+
+         x = int(np.round(np.sum(np.sum(gt, 0, keepdims=True) * i) / total))
+         y = int(np.round(np.sum(np.sum(gt, 1, keepdims=True) * j) / total))
+
+     return x, y
+
+
+ def s_region(sm, gt):
+     x, y = centroid(gt)
+     gt_1, gt_2, gt_3, gt_4, w1, w2, w3, w4 = divide_gt(gt, x, y)
+
+     sm_1, sm_2, sm_3, sm_4 = divide_sm(sm, x, y)
+
+     q1 = ssim(sm_1, gt_1)
+     q2 = ssim(sm_2, gt_2)
+     q3 = ssim(sm_3, gt_3)
+     q4 = ssim(sm_4, gt_4)
+
+     region_value = w1 * q1 + w2 * q2 + w3 * q3 + w4 * q4
+
+     return region_value
+
+
+ def object_score(sm, gt):
+     # (renamed from `object`, which shadowed the Python builtin)
+     x = np.mean(sm[gt == 1])
+     # compute the standard deviation of the foreground or background in sm
+     sigma_x = np.std(sm[gt == 1])
+     score = 2.0 * x / (x ** 2 + 1.0 + sigma_x + eps)
+     return score
+
+
+ def s_object(sm, gt):
+     # compute the similarity of the foreground at the object level
+     sm_fg = sm.copy()
+     sm_fg[gt == 0] = 0
+     o_fg = object_score(sm_fg, gt)
+
+     # compute the similarity of the background
+     sm_bg = 1.0 - sm.copy()
+     sm_bg[gt == 1] = 0
+     o_bg = object_score(sm_bg, gt == 0)
+
+     u = np.mean(gt)
+     object_value = u * o_fg + (1 - u) * o_bg
+     return object_value
+
+
+ # Weighted F-measure
+ # article: https://ieeexplore.ieee.org/document/6909433
+ # original code [Matlab]: https://cgm.technion.ac.il/Computer-Graphics-Multimedia/Software/FGEval/
+ def weighted_fmeasure(gt, sm, beta2=1):
+     """
+     Computes the Weighted F-measure between the saliency map and the ground truth.
+     article: https://ieeexplore.ieee.org/document/6909433
+     original code [Matlab]: https://cgm.technion.ac.il/Computer-Graphics-Multimedia/Software/FGEval/
+
+     Parameters
+     ----------
+     gt : numpy.ndarray
+         The binarized ground truth map
+     sm : numpy.ndarray
+         The predicted saliency map
+
+     Returns
+     -------
+     value : float
+         The calculated Weighted F-measure
+     """
+     dst, idx = distance_transform_edt(1 - gt, return_indices=True)
+
+     row_idx = idx[0][gt == 0]
+     col_idx = idx[1][gt == 0]
+
+     e = np.abs(sm - gt).astype(np.float32)
+     et = np.abs(sm - gt).astype(np.float32)
+
+     et[gt == 0] = et[row_idx, col_idx]
+
+     k = matlab_style_gauss2d(shape=(7, 7), sigma=5)
+
+     ea = correlate(et.astype(np.float32), k, mode='constant')
+     min_e_ea = np.abs(sm - gt).astype(np.float32)
+
+     min_e_ea[gt * (ea < e) == 1] = ea[gt * (ea < e) == 1]
+
+     b = np.ones_like(gt).astype(np.float32)
+     b[gt == 0] = 2 - 1 * np.exp(np.log(1 - 0.5) / 5. * dst[gt == 0])
+
+     ew = min_e_ea * b
+     tpw = np.sum(gt) - np.sum(ew[gt == 1])
+     fpw = np.sum(ew[gt == 0])
+
+     rec = 1 - np.mean(ew[gt == 1])   # weighted recall
+     prec = tpw / (eps + tpw + fpw)   # weighted precision
+
+     value = (1 + beta2) * (rec * prec) / (eps + (beta2 * rec) + prec)
+     return value
+
+
+ def matlab_style_gauss2d(shape=(3, 3), sigma=0.5):
+     """
+     2D Gaussian mask - should give the same result as MATLAB's
+     fspecial('gaussian', [shape], [sigma])
+     """
+     m, n = [(ss - 1.) / 2. for ss in shape]
+     y, x = np.ogrid[-m:m + 1, -n:n + 1]
+     h = np.exp(-(x * x + y * y) / (2. * sigma * sigma))
+     h[h < np.finfo(h.dtype).eps * h.max()] = 0
+     sumh = h.sum()
+     if sumh != 0:
+         h /= sumh
+     return h
+
+
+ # Adaptive F-measure
+ def adaptive_fmeasure(gt, sm, beta):
+     """
+     Computes the Adaptive F-measure between the saliency map and the ground truth,
+     using the adaptive binarization proposed in:
+     https://ieeexplore.ieee.org/document/5206596
+
+     Parameters
+     ----------
+     gt : numpy.ndarray
+         The binarized ground truth map
+     sm : numpy.ndarray
+         The predicted saliency map
+
+     Returns
+     -------
+     value : float
+         The calculated Adaptive F-measure
+     """
+     gt_idx = np.where(gt > 0)
+     gt_cnt = np.sum(gt)
+
+     if gt_cnt == 0:
+         prec = []
+         recall = []
+     else:
+         adaptive_threshold = 2 * np.mean(sm)
+         if adaptive_threshold > 1:
+             adaptive_threshold = 1
+         sm_binary = (sm >= adaptive_threshold).astype(np.float32)
+         hit_cnt = np.sum(sm_binary[gt_idx])
+         alg_cnt = np.sum(sm_binary)
+
+         if hit_cnt == 0:
+             prec = 0
+             recall = 0
+         else:
+             prec = hit_cnt / (alg_cnt + eps)
+             recall = hit_cnt / gt_cnt
+     value = (1 + beta ** 2) * prec * recall / ((beta ** 2 * prec + recall) + eps)
+     return value
+
+
+ def prec_recall(gt, sm, num_th):
+     """
+     Computes precision and recall at `num_th` evenly spaced binarization
+     thresholds. The results are used to calculate the Max-F measure and to
+     plot PR and F-threshold curves.
+
+     Parameters
+     ----------
+     gt : numpy.ndarray
+         The binarized ground truth map
+     sm : numpy.ndarray
+         The predicted saliency map
+     num_th : int
+         The total number of thresholds between 0 and 1
+
+     Returns
+     -------
+     prec, recall : numpy.ndarray
+         The calculated precision and recall (shape: (num_th, 1))
+     """
+     gt_idx = np.where(gt > 0)
+     gt_cnt = np.sum(gt)
+
+     if gt_cnt == 0:
+         prec = []
+         recall = []
+     else:
+         hit_cnt = np.zeros((num_th, 1), np.float32)
+         alg_cnt = np.zeros((num_th, 1), np.float32)
+         thresholds = np.linspace(0, 1, num_th)
+         for k, cur_th in enumerate(thresholds):
+             sm_binary = (sm >= cur_th).astype(np.float32)
+             hit_cnt[k] = np.sum(sm_binary[gt_idx])
+             alg_cnt[k] = np.sum(sm_binary)
+
+         prec = hit_cnt / (alg_cnt + eps)
+         recall = hit_cnt / gt_cnt
+
+     return prec, recall
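For reference, a minimal sketch of how this toolbox is typically driven. The directory paths are hypothetical placeholders, and the import path is an assumption based on this repo's layout:

```python
# Minimal sketch: score a folder of predicted maps against ground truth.
# The two directory paths below are placeholders, not paths from this repo.
from IS_Net.saliency_toolbox import calculate_measures

res = calculate_measures(
    gt_dir='path/to/gt',            # ground-truth maps
    sm_dir='path/to/predictions',   # predicted maps with matching filenames
    measures=['MAE', 'E-measure', 'S-measure', 'Max-F'],
    n_thread=4)                     # joblib workers, one image per job

print('MAE: %.4f  maxF: %.4f' % (res['MAE'], res['Max-F']))
```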
IS_Net/swd_optim/__init__.py ADDED
@@ -0,0 +1,10 @@
+
+ from .adai import Adai
+ from .adais import AdaiS
+ from .adams import AdamS
+ from .sgds import SGDS
+
+ del adai
+ del adais
+ del adams
+ del sgds
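The `del` lines mirror the pattern `torch.optim` itself uses: each `from .module import Class` also binds the submodule name in the package namespace, and deleting it keeps the package's public surface to the four optimizer classes. A small sketch of the effect, assuming the package directory is importable as `swd_optim`:

```python
# Minimal sketch of the namespace effect of the `del` lines above:
# the optimizer classes are exported, the submodule names are not.
import swd_optim

print(hasattr(swd_optim, 'Adai'))  # True  - re-exported class
print(hasattr(swd_optim, 'adai'))  # False - submodule name was deleted
```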
IS_Net/swd_optim/adai.py ADDED
@@ -0,0 +1,116 @@
+ import torch
+ from torch.optim.optimizer import Optimizer, required
+
+ class Adai(Optimizer):
+     r"""Implements the Adaptive Inertia Estimation (Adai) algorithm.
+     It was proposed in
+     `Adai: Separating the Effects of Adaptive Learning Rate and Momentum Inertia`__.
+
+     Arguments:
+         params (iterable): iterable of parameters to optimize or dicts defining
+             parameter groups
+         lr (float): learning rate
+         betas (Tuple[float, float], optional): beta0 and beta2 (default: (0.1, 0.99))
+         eps (float, optional): the inertia bound (default: 1e-03)
+         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+     """
+
+     def __init__(self, params, lr=required, betas=(0.1, 0.99), eps=1e-03,
+                  weight_decay=0):
+         if lr is not required and lr < 0.0:
+             raise ValueError("Invalid learning rate: {}".format(lr))
+         if not 0.0 <= eps:
+             raise ValueError("Invalid epsilon value: {}".format(eps))
+         if not 0.0 <= betas[0]:
+             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+         if not 0.0 <= betas[1] < 1.0:
+             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+         if not 0.0 <= weight_decay:
+             raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+         super(Adai, self).__init__(params, defaults)
+
+     def __setstate__(self, state):
+         super(Adai, self).__setstate__(state)
+
+     @torch.no_grad()
+     def step(self, closure=None):
+         """Performs a single optimization step.
+
+         Arguments:
+             closure (callable, optional): A closure that reevaluates the model
+                 and returns the loss.
+         """
+         loss = None
+         if closure is not None:
+             loss = closure()
+
+         param_size = 0
+         exp_avg_sq_hat_sum = 0.
+
+         # First pass: update the second-moment estimates and accumulate their
+         # bias-corrected sum over all parameters.
+         for group in self.param_groups:
+             for p in group['params']:
+                 if p.grad is None:
+                     continue
+                 param_size += p.numel()
+                 grad = p.grad.data
+
+                 state = self.state[p]
+
+                 # State initialization
+                 if len(state) == 0:
+                     state['step'] = 0
+                     # Exponential moving average of gradient values
+                     state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
+                     # Exponential moving average of squared gradient values
+                     state['exp_avg_sq'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
+                     # Cumulative products of beta1
+                     state['beta1_prod'] = torch.ones_like(p.data, memory_format=torch.preserve_format)
+
+                 state['step'] += 1
+
+                 exp_avg_sq = state['exp_avg_sq']
+                 beta0, beta2 = group['betas']
+
+                 bias_correction2 = 1 - beta2 ** state['step']
+
+                 if group['weight_decay'] != 0:
+                     # L2 weight decay folded into the gradient
+                     # (keyword form replaces the deprecated add_(Number, Tensor) overload)
+                     grad.add_(p.data, alpha=group['weight_decay'])
+
+                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+                 exp_avg_sq_hat_sum += exp_avg_sq.sum() / bias_correction2
+
+         # Calculate the mean of all elements in exp_avg_sq_hat
+         exp_avg_sq_hat_mean = exp_avg_sq_hat_sum / param_size
+
+         # Second pass: derive the element-wise inertia beta1 from the global
+         # mean and apply the bias-corrected momentum update.
+         for group in self.param_groups:
+             for p in group['params']:
+                 if p.grad is None:
+                     continue
+                 grad = p.grad.data
+
+                 state = self.state[p]
+
+                 exp_avg = state['exp_avg']
+                 exp_avg_sq = state['exp_avg_sq']
+                 beta1_prod = state['beta1_prod']
+                 beta0, beta2 = group['betas']
+
+                 bias_correction2 = 1 - beta2 ** state['step']
+
+                 exp_avg_sq_hat = exp_avg_sq / bias_correction2
+                 beta1 = (1. - (exp_avg_sq_hat / exp_avg_sq_hat_mean).mul(beta0)).clamp(0., 1 - group['eps'])
+
+                 beta1_prod.mul_(beta1)
+                 bias_correction1 = 1 - beta1_prod
+
+                 exp_avg.mul_(beta1).addcmul_(1 - beta1, grad)
+                 exp_avg_hat = exp_avg / bias_correction1
+
+                 step_size = group['lr']
+                 p.data.add_(exp_avg_hat, alpha=-step_size)
+
+         return loss
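The two-pass structure matters: the first loop freezes a parameter-wide mean of the bias-corrected second moment, and the second loop turns each element's deviation from that mean into a momentum coefficient, `beta1 = clamp(1 - beta0 * v_hat / mean(v_hat), 0, 1 - eps)`. A toy sketch of just that rule, outside the optimizer, using the default `beta0=0.1` and `eps=1e-3`:

```python
# Toy sketch of Adai's inertia rule on a fabricated second-moment tensor:
# elements with below-average gradient variance get beta1 near 1 (high inertia),
# elements with above-average variance get smaller beta1 (low inertia).
import torch

beta0, eps = 0.1, 1e-3
v_hat = torch.tensor([0.01, 0.10, 1.00])  # fabricated bias-corrected second moments
beta1 = (1. - beta0 * v_hat / v_hat.mean()).clamp(0., 1 - eps)
print(beta1)  # ~[0.9973, 0.9730, 0.7297]: inertia falls as variance rises
```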
IS_Net/swd_optim/adais.py ADDED
@@ -0,0 +1,120 @@
+ import torch
+ from torch.optim.optimizer import Optimizer, required
+
+
+ class AdaiS(Optimizer):
+     r"""Implements Adai with stable/decoupled weight decay (AdaiS/AdaiW).
+     It is based on
+     `Adai: Separating the Effects of Adaptive Learning Rate and Momentum Inertia`
+     and
+     `Stable Weight Decay Regularization`__.
+
+     Arguments:
+         params (iterable): iterable of parameters to optimize or dicts defining
+             parameter groups
+         lr (float): learning rate
+         betas (Tuple[float, float], optional): beta0 and beta2 (default: (0.1, 0.99))
+         eps (float, optional): the inertia bound (default: 1e-03)
+         weight_decay (float, optional): weight decay (default: 0)
+     """
+
+     def __init__(self, params, lr=required, betas=(0.1, 0.99), eps=1e-03,
+                  weight_decay=0):
+         if lr is not required and lr < 0.0:
+             raise ValueError("Invalid learning rate: {}".format(lr))
+         if not 0.0 <= eps:
+             raise ValueError("Invalid epsilon value: {}".format(eps))
+         if not 0.0 <= betas[0]:
+             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+         if not 0.0 <= betas[1] < 1.0:
+             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+         if not 0.0 <= weight_decay:
+             raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+         super(AdaiS, self).__init__(params, defaults)
+
+     def __setstate__(self, state):
+         super(AdaiS, self).__setstate__(state)
+
+     @torch.no_grad()
+     def step(self, closure=None):
+         """Performs a single optimization step.
+
+         Arguments:
+             closure (callable, optional): A closure that reevaluates the model
+                 and returns the loss.
+         """
+         loss = None
+         if closure is not None:
+             loss = closure()
+
+         param_size = 0
+         exp_avg_sq_hat_sum = 0.
+
+         # First pass: update the second-moment estimates and accumulate their
+         # bias-corrected sum over all parameters.
+         for group in self.param_groups:
+             for p in group['params']:
+                 if p.grad is None:
+                     continue
+                 param_size += p.numel()
+                 grad = p.grad.data
+
+                 state = self.state[p]
+
+                 # State initialization
+                 if len(state) == 0:
+                     state['step'] = 0
+                     # Exponential moving average of gradient values
+                     state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
+                     # Exponential moving average of squared gradient values
+                     state['exp_avg_sq'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
+                     # Cumulative products of beta1
+                     state['beta1_prod'] = torch.ones_like(p.data, memory_format=torch.preserve_format)
+
+                 exp_avg_sq = state['exp_avg_sq']
+                 beta0, beta2 = group['betas']
+
+                 state['step'] += 1
+                 bias_correction2 = 1 - beta2 ** state['step']
+
+                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+                 exp_avg_sq_hat = exp_avg_sq / bias_correction2
+
+                 exp_avg_sq_hat_sum += exp_avg_sq_hat.sum()
+
+         # Calculate the mean of all elements in exp_avg_sq_hat
+         exp_avg_sq_hat_mean = exp_avg_sq_hat_sum / param_size
+
+         # Second pass: decoupled weight decay, then the inertia update.
+         for group in self.param_groups:
+             for p in group['params']:
+                 if p.grad is None:
+                     continue
+                 grad = p.grad.data
+
+                 # Perform stable/decoupled weight decay
+                 if group['weight_decay'] != 0:
+                     p.data.mul_(1 - group['lr'] * group['weight_decay'])
+
+                 state = self.state[p]
+
+                 exp_avg = state['exp_avg']
+                 exp_avg_sq = state['exp_avg_sq']
+                 beta0, beta2 = group['betas']
+                 beta1_prod = state['beta1_prod']
+                 bias_correction2 = 1 - beta2 ** state['step']
+
+                 exp_avg_sq_hat = exp_avg_sq / bias_correction2
+
+                 beta1 = (1. - (exp_avg_sq_hat / exp_avg_sq_hat_mean).mul(beta0)).clamp(0., 1 - group['eps'])
+
+                 beta1_prod.mul_(beta1)
+                 bias_correction1 = 1 - beta1_prod
+
+                 exp_avg.mul_(beta1).addcmul_(1 - beta1, grad)
+                 exp_avg_hat = exp_avg.div(bias_correction1)
+
+                 step_size = group['lr']
+                 p.data.add_(exp_avg_hat, alpha=-step_size)
+
+         return loss
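Unlike `Adai`, which folds the decay into the gradient as an L2 term, `AdaiS` shrinks the weights directly before the momentum step, so the decay is decoupled from the adaptive update. A scalar sketch of the difference between the two decay styles, with made-up numbers:

```python
# Scalar sketch: L2-style decay (Adai) perturbs the gradient, while decoupled
# decay (AdaiS) shrinks the weight itself by the same first-order amount.
lr, wd = 0.1, 0.01
p, grad = 2.0, 0.5

p_l2 = p - lr * (grad + wd * p)         # Adai: decay enters through the gradient
p_dec = p * (1 - lr * wd) - lr * grad   # AdaiS: decay applied to the weight directly
print(p_l2, p_dec)  # identical here (1.948), but they diverge once the gradient
                    # is rescaled by the inertia/momentum machinery inside the optimizer
```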
IS_Net/swd_optim/adams.py ADDED
@@ -0,0 +1,137 @@
+ import math
+ import torch
+ from torch.optim.optimizer import Optimizer
+
+
+ class AdamS(Optimizer):
+     r"""Implements the Adam with stable weight decay (AdamS) algorithm.
+     It was proposed in
+     `Stable Weight Decay Regularization`__.
+
+     Arguments:
+         params (iterable): iterable of parameters to optimize or dicts defining
+             parameter groups
+         lr (float, optional): learning rate (default: 1e-3)
+         betas (Tuple[float, float], optional): coefficients used for computing
+             running averages of gradient and its square (default: (0.9, 0.999))
+         eps (float, optional): term added to the denominator to improve
+             numerical stability (default: 1e-8)
+         weight_decay (float, optional): weight decay coefficient (default: 1e-4)
+         amsgrad (boolean, optional): whether to use the AMSGrad variant of this
+             algorithm from the paper `On the Convergence of Adam and Beyond`_
+             (default: False)
+     """
+
+     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+                  weight_decay=1e-4, amsgrad=False):
+         if not 0.0 <= lr:
+             raise ValueError("Invalid learning rate: {}".format(lr))
+         if not 0.0 <= eps:
+             raise ValueError("Invalid epsilon value: {}".format(eps))
+         if not 0.0 <= betas[0] < 1.0:
+             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+         if not 0.0 <= betas[1] < 1.0:
+             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+         if not 0.0 <= weight_decay:
+             raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+         defaults = dict(lr=lr, betas=betas, eps=eps,
+                         weight_decay=weight_decay, amsgrad=amsgrad)
+         super(AdamS, self).__init__(params, defaults)
+
+     def __setstate__(self, state):
+         super(AdamS, self).__setstate__(state)
+         for group in self.param_groups:
+             group.setdefault('amsgrad', False)
+
+     @torch.no_grad()
+     def step(self, closure=None):
+         """Performs a single optimization step.
+
+         Arguments:
+             closure (callable, optional): A closure that reevaluates the model
+                 and returns the loss.
+         """
+         loss = None
+         if closure is not None:
+             with torch.enable_grad():
+                 loss = closure()
+
+         param_size = 0
+         exp_avg_sq_hat_sum = 0.
+
+         # First pass: update the moment estimates and accumulate the
+         # bias-corrected second moments over all parameters.
+         for group in self.param_groups:
+             for p in group['params']:
+                 if p.grad is None:
+                     continue
+                 param_size += p.numel()
+
+                 # Perform optimization step
+                 grad = p.grad
+                 if grad.is_sparse:
+                     raise RuntimeError('AdamS does not support sparse gradients')
+                 amsgrad = group['amsgrad']
+
+                 state = self.state[p]
+
+                 # State initialization
+                 if len(state) == 0:
+                     state['step'] = 0
+                     # Exponential moving average of gradient values
+                     state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                     # Exponential moving average of squared gradient values
+                     state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                     if amsgrad:
+                         # Maintains max of all exp. moving avg. of sq. grad. values
+                         state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+
+                 beta1, beta2 = group['betas']
+                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+
+                 state['step'] += 1
+                 bias_correction2 = 1 - beta2 ** state['step']
+
+                 # Decay the first and second moment running average coefficients
+                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+                 if amsgrad:
+                     max_exp_avg_sq = state['max_exp_avg_sq']
+                     # Maintains the maximum of all 2nd moment running avgs. until now
+                     torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
+                     # Use the max. for normalizing the running avg. of the gradient
+                     exp_avg_sq_hat = max_exp_avg_sq / bias_correction2
+                 else:
+                     exp_avg_sq_hat = exp_avg_sq / bias_correction2
+
+                 exp_avg_sq_hat_sum += exp_avg_sq_hat.sum()
+
+         # Calculate the sqrt of the mean of all elements in exp_avg_sq_hat
+         exp_avg_mean_sqrt = math.sqrt(exp_avg_sq_hat_sum / param_size)
+
+         # Second pass: apply the stable weight decay and the Adam update.
+         for group in self.param_groups:
+             for p in group['params']:
+                 if p.grad is None:
+                     continue
+
+                 state = self.state[p]
+                 amsgrad = group['amsgrad']  # re-read per group (the first-pass value could leak across groups)
+
+                 # Perform stable weight decay
+                 if group['weight_decay'] != 0:
+                     p.data.mul_(1 - group['weight_decay'] * group['lr'] / exp_avg_mean_sqrt)
+
+                 beta1, beta2 = group['betas']
+                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                 bias_correction1 = 1 - beta1 ** state['step']
+                 bias_correction2 = 1 - beta2 ** state['step']
+
+                 if amsgrad:
+                     max_exp_avg_sq = state['max_exp_avg_sq']
+                     exp_avg_sq_hat = max_exp_avg_sq / bias_correction2
+                 else:
+                     exp_avg_sq_hat = exp_avg_sq / bias_correction2
+
+                 denom = exp_avg_sq_hat.sqrt().add(group['eps'])
+
+                 step_size = group['lr'] / bias_correction1
+                 p.addcdiv_(exp_avg, denom, value=-step_size)
+
+         return loss
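The "stable" part is the single global scalar `exp_avg_mean_sqrt`: every weight is decayed by `wd * lr / sqrt(mean(v_hat))` rather than `wd * lr`, so the effective decay strength no longer fades as the gradient magnitudes shrink. A toy sketch of the factor, with fabricated second-moment values:

```python
# Toy sketch of AdamS's stable decay factor, using fabricated second moments.
import math
import torch

lr, wd = 1e-3, 1e-4
v_hat = torch.tensor([1e-6, 4e-6, 9e-6])            # fabricated bias-corrected second moments
exp_avg_mean_sqrt = math.sqrt(v_hat.mean().item())  # one scalar shared by the whole model

plain_factor = 1 - wd * lr                       # decoupled (AdamW-style) decay
stable_factor = 1 - wd * lr / exp_avg_mean_sqrt  # AdamS rescales by gradient magnitude
print(plain_factor, stable_factor)               # the stable decay is much stronger
                                                 # when gradients are small
```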
IS_Net/swd_optim/sgds.py ADDED
@@ -0,0 +1,82 @@
+
+ import torch
+ from torch.optim.optimizer import Optimizer, required
+
+
+ class SGDS(Optimizer):
+     r"""Implements stochastic gradient descent with stable weight decay (SGDS).
+     It was proposed in
+     `Stable Weight Decay Regularization`__.
+
+     Args:
+         params (iterable): iterable of parameters to optimize or dicts defining
+             parameter groups
+         lr (float): learning rate
+         momentum (float, optional): momentum factor (default: 0)
+         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+         dampening (float, optional): dampening for momentum (default: 0)
+         nesterov (bool, optional): enables Nesterov momentum (default: False)
+     """
+
+     def __init__(self, params, lr=required, momentum=0, dampening=0,
+                  weight_decay=0, nesterov=False):
+         if lr is not required and lr < 0.0:
+             raise ValueError("Invalid learning rate: {}".format(lr))
+         if momentum < 0.0:
+             raise ValueError("Invalid momentum value: {}".format(momentum))
+         if weight_decay < 0.0:
+             raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+
+         defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
+                         weight_decay=weight_decay, nesterov=nesterov)
+         if nesterov and (momentum <= 0 or dampening != 0):
+             raise ValueError("Nesterov momentum requires a momentum and zero dampening")
+         super(SGDS, self).__init__(params, defaults)
+
+     def __setstate__(self, state):
+         super(SGDS, self).__setstate__(state)
+         for group in self.param_groups:
+             group.setdefault('nesterov', False)
+
+     @torch.no_grad()
+     def step(self, closure=None):
+         """Performs a single optimization step.
+
+         Arguments:
+             closure (callable, optional): A closure that reevaluates the model
+                 and returns the loss.
+         """
+         loss = None
+         if closure is not None:
+             with torch.enable_grad():
+                 loss = closure()
+
+         for group in self.param_groups:
+             momentum = group['momentum']
+             dampening = group['dampening']
+             nesterov = group['nesterov']
+
+             for p in group['params']:
+                 if p.grad is None:
+                     continue
+                 d_p = p.grad
+
+                 # Perform stable weight decay, pre-scaled by the long-run
+                 # amplification factor of the momentum buffer
+                 if group['weight_decay'] != 0:
+                     bias_correction = (1 - dampening) / (1 - momentum)
+                     p.data.mul_(1 - bias_correction * group['lr'] * group['weight_decay'])
+
+                 if momentum != 0:
+                     param_state = self.state[p]
+                     if 'momentum_buffer' not in param_state:
+                         buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
+                     else:
+                         buf = param_state['momentum_buffer']
+                         buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
+                     if nesterov:
+                         d_p = d_p.add(buf, alpha=momentum)
+                     else:
+                         d_p = buf
+
+                 p.add_(d_p, alpha=-group['lr'])
+
+         return loss
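SGDS keeps SGD's momentum machinery and only changes where the decay is applied: with momentum, the decay is pre-scaled by `(1 - dampening) / (1 - momentum)` to compensate for how momentum amplifies the effective step. A drop-in usage sketch; the model is a toy stand-in, and `swd_optim` is assumed to be importable:

```python
# Minimal sketch: SGDS takes the same constructor arguments as torch.optim.SGD,
# so swapping it in is a one-line change. `net` is a toy stand-in model.
import torch
import torch.nn as nn
from swd_optim import SGDS

net = nn.Linear(8, 1)
opt = SGDS(net.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-4)

loss = net(torch.randn(4, 8)).pow(2).mean()
opt.zero_grad()
loss.backward()
opt.step()  # with momentum=0.9, dampening=0 the decay factor is 1 - 10 * lr * wd
```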
IS_Net/train_valid_inference_main.py ADDED
@@ -0,0 +1,729 @@
+ import os
+ import time
+ import numpy as np
+ from skimage import io
+ import matplotlib.pyplot as plt
+ import torch, gc
+ import torch.nn as nn
+ from torch.autograd import Variable
+ import torch.optim as optim
+ import torch.nn.functional as F
+ from data_loader import get_im_gt_name_dict, create_dataloaders, GOSRandomHFlip, GOSResize, GOSRandomCrop, GOSNormalize  # GOSDatasetCache,
+ # from data_loader_cache import get_im_gt_name_dict, create_dataloaders, GOSRandomHFlip, GOSResize, GOSRandomCrop, GOSNormalize  # GOSDatasetCache,
+ from basics import f1_mae_torch  # normPRED, GOSPRF1ScoresCache, f1score_torch,
+ from models.isnet import ISNetGTEncoder, ISNetDIS
+ from torch.cuda.amp import autocast, GradScaler
+ from datalist import *
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ def get_gt_encoder(train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar, train_dataloaders_val, train_datasets_val):  # model_path, model_save_fre, max_ite=1000000
+
+     torch.manual_seed(hypar["seed"])
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed(hypar["seed"])
+
+     print("define gt encoder ...")
+     net = ISNetGTEncoder()  # UNETGTENCODERCombine()
+     # if(hypar["model_digit"]=="half"):
+     #     net.half()
+
+     ## load an existing gt encoder if one is specified
+     if hypar["gt_encoder_model"] != "":
+         model_path = hypar["model_path"] + "/" + hypar["gt_encoder_model"]
+         if torch.cuda.is_available():
+             net.load_state_dict(torch.load(model_path))
+             net.cuda()
+         else:
+             net.load_state_dict(torch.load(model_path, map_location="cpu"))
+         print("gt encoder restored from the saved weights ...")
+         return net
+
+     if torch.cuda.is_available():
+         net.cuda()
+
+     print("--- define optimizer for GT Encoder ---")
+     # optimizer = lion.Lion(net.parameters(), lr=1e-4, betas=(0.9, 0.99))
+     optimizer = optim.AdamW(net.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-8, weight_decay=0)
+     # optimizer = optim.SGD(net.parameters(), lr=1e-4)
+
+     model_path = hypar["model_path"]
+     model_save_fre = hypar["model_save_fre"]
+     max_ite = hypar["max_ite"]
+     batch_size_train = hypar["batch_size_train"]
+     batch_size_valid = hypar["batch_size_valid"]
+
+     if not os.path.exists(model_path):
+         os.mkdir(model_path)
+
+     ite_num = hypar["start_ite"]  # count the total iteration number
+     ite_num4val = 0
+     running_loss = 0.0  # count the total loss
+     running_tar_loss = 0.0  # count the target output loss
+     last_f1 = [0 for x in range(len(valid_dataloaders))]
+
+     train_num = train_datasets[0].__len__()
+
+     net.train()
+
+     start_last = time.time()
+     gos_dataloader = train_dataloaders[0]
+     epoch_num = hypar["max_epoch_num"]  # the epoch budget comes from hypar["max_epoch_num"]
+     notgood_cnt = 0
+     for epoch in range(epoch_num):
+
+         for i, data in enumerate(gos_dataloader):
+
+             if ite_num >= max_ite:
+                 print("Training Reached the Maximal Iteration Number ", max_ite)
+                 exit()
+
+             # start_read = time.time()
+             ite_num = ite_num + 1
+             ite_num4val = ite_num4val + 1
+
+             # get the inputs
+             labels = data['label']
+
+             if hypar["model_digit"] == "full":
+                 labels = labels.type(torch.FloatTensor)
+             else:
+                 labels = labels.type(torch.HalfTensor)
+
+             # wrap them in Variable
+             if torch.cuda.is_available():
+                 labels_v = Variable(labels.cuda(), requires_grad=False)
+             else:
+                 labels_v = Variable(labels, requires_grad=False)
+
+             # print("time lapse for data preparation: ", time.time()-start_read, ' s')
+
+             # zero the parameter gradients
+             start_inf_loss_back = time.time()
+             optimizer.zero_grad()
+
+             # plt.imshow(labels_v[0][0].cpu(), cmap='gray'); plt.show()
+             # with autocast():
+             ds, fs = net(labels_v)  # net(inputs_v)
+             loss2, loss = net.compute_loss(ds, labels_v)
+             # scaler.scale(loss).backward()
+             # scaler.step(optimizer)
+             # scaler.update()
+
+             # orthogonality (ORTHO) regularization on all non-bias weights
+             reg = 1e-8
+             orth_loss = torch.zeros(1).to(device)
+             for name, param in net.named_parameters():
+                 if 'bias' not in name:
+                     param_flat = param.view(param.shape[0], -1)
+                     sym = torch.mm(param_flat, torch.t(param_flat))
+                     sym -= torch.eye(param_flat.shape[0]).to(param.device)
+                     orth_loss = orth_loss + (reg * sym.abs().sum())
+             loss = loss + orth_loss
+             loss.backward()
+             optimizer.step()
+
+             running_loss += loss.item()
+             running_tar_loss += loss2.item()
+
+             # del outputs, loss
+             del ds, loss2, loss
+             end_inf_loss_back = time.time() - start_inf_loss_back
+
+             print("GT Encoder Training>>>" + model_path.split('/')[-1] + " - [epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f, time-per-iter: %3f s, time_read: %3f" % (
+                 epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val, time.time() - start_last, time.time() - start_last - end_inf_loss_back))
+             start_last = time.time()
+
+             if ite_num % model_save_fre == 0:  # validate every model_save_fre iterations
+                 notgood_cnt += 1
+                 net.eval()
+                 tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch)
+                 # tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid_gt_encoder(net, train_dataloaders_val, train_datasets_val, hypar, epoch)
+
+                 net.train()  # resume training
+
+                 tmp_out = 0
+                 print("last_f1:", last_f1, np.mean(last_f1))
+                 print("tmp_f1:", tmp_f1, np.mean(tmp_f1))
+                 # for fi in range(len(last_f1)):
+                 if np.mean(tmp_f1) > np.mean(last_f1):
+                     tmp_out = 1
+                 print("tmp_out:", tmp_out)
+                 if tmp_out:
+                     notgood_cnt = 0
+                     last_f1 = tmp_f1
+                     tmp_f1_str = [str(round(f1x, 4)) for f1x in tmp_f1]
+                     tmp_mae_str = [str(round(mx, 4)) for mx in tmp_mae]
+                     maxf1 = '_'.join(tmp_f1_str)
+                     meanM = '_'.join(tmp_mae_str)
+                     model_name = "/GTENCODER-gpu_itr_" + str(ite_num) + \
+                                  "_traLoss_" + str(np.round(running_loss / ite_num4val, 4)) + \
+                                  "_traTarLoss_" + str(np.round(running_tar_loss / ite_num4val, 4)) + \
+                                  "_valLoss_" + str(np.round(val_loss / (i_val + 1), 4)) + \
+                                  "_valTarLoss_" + str(np.round(tar_loss / (i_val + 1), 4)) + \
+                                  "_maxF1_" + maxf1 + \
+                                  "_mae_" + meanM + \
+                                  "_time_" + str(np.round(np.mean(np.array(tmp_time)) / batch_size_valid, 6)) + ".pth"
+                     torch.save(net.state_dict(), model_path + model_name)
+
+                 running_loss = 0.0
+                 running_tar_loss = 0.0
+                 ite_num4val = 0
+
+                 if np.mean(tmp_f1) > 0.99:
+                     print("GT encoder is well-trained and obtained ...")
+                     return net
+
+                 if notgood_cnt >= hypar["early_stop"]:
+                     print("No improvements in the last " + str(notgood_cnt) + " validation periods, so training stopped!")
+                     exit()
+
+     print("Training Reaches The Maximum Epoch Number")
+     return net
+
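Both training loops add the same soft orthogonality penalty: for every non-bias weight `W` (flattened to 2-D), they accumulate `reg * sum(|W @ W.T - I|)`, nudging each layer's rows toward an orthonormal set. The same loop, factored into a standalone helper (a sketch; `orth_regularization` is a hypothetical name, not a function in this repo):

```python
# Sketch: the orthogonality penalty inlined in get_gt_encoder()/train(),
# factored into a reusable helper. `orth_regularization` is a hypothetical name.
import torch
import torch.nn as nn

def orth_regularization(net: nn.Module, reg: float = 1e-8) -> torch.Tensor:
    """Soft orthogonality: reg * sum(|W @ W.T - I|) over all non-bias parameters."""
    penalty = None
    for name, param in net.named_parameters():
        if 'bias' in name:
            continue
        w = param.view(param.shape[0], -1)  # flatten conv kernels to rows
        gram = w @ w.t() - torch.eye(w.shape[0], device=param.device)
        term = reg * gram.abs().sum()
        penalty = term if penalty is None else penalty + term
    return penalty if penalty is not None else torch.zeros(())

# loss = task_loss + orth_regularization(net)  # reproduces the inlined loops
```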
+ def valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
+     net.eval()
+     print("Validating...")
+     epoch_num = hypar["max_epoch_num"]
+
+     val_loss = 0.0
+     tar_loss = 0.0
+
+     tmp_f1 = []
+     tmp_mae = []
+     tmp_time = []
+
+     start_valid = time.time()
+     for k in range(len(valid_dataloaders)):
+
+         valid_dataloader = valid_dataloaders[k]
+         valid_dataset = valid_datasets[k]
+
+         val_num = valid_dataset.__len__()
+         mybins = np.arange(0, 256)
+         PRE = np.zeros((val_num, len(mybins) - 1))
+         REC = np.zeros((val_num, len(mybins) - 1))
+         F1 = np.zeros((val_num, len(mybins) - 1))
+         MAE = np.zeros((val_num))
+
+         val_cnt = 0.0
+         i_val = None
+
+         for i_val, data_val in enumerate(valid_dataloader):
+
+             # imidx_val, inputs_val, labels_val, shapes_val = data_val['imidx'], data_val['image'], data_val['label'], data_val['shape']
+             imidx_val, labels_val, shapes_val = data_val['imidx'], data_val['label'], data_val['shape']
+             if hypar["model_digit"] == "full":
+                 labels_val = labels_val.type(torch.FloatTensor)
+             else:
+                 labels_val = labels_val.type(torch.HalfTensor)
+
+             # wrap them in Variable
+             if torch.cuda.is_available():
+                 labels_val_v = Variable(labels_val.cuda(), requires_grad=False)
+             else:
+                 labels_val_v = Variable(labels_val, requires_grad=False)
+             # with autocast():
+             t_start = time.time()
+             ds_val = net(labels_val_v)[0]
+             t_end = time.time() - t_start
+             tmp_time.append(t_end)
+
+             # loss2_val, loss_val = muti_loss_fusion(ds_val, labels_val_v)
+             loss2_val, loss_val = net.compute_loss(ds_val, labels_val_v)
+
+             # compute the F-measure
+             for t in range(hypar["batch_size_valid"]):
+                 val_cnt = val_cnt + 1.0
+                 print("num of val: ", val_cnt)
+                 i_test = imidx_val[t].data.numpy()
+
+                 pred_val = ds_val[0][t, :, :, :].float()  # B x 1 x H x W
+
+                 ## recover the prediction spatial size to the original image size
+                 # (F.interpolate replaces the deprecated F.upsample)
+                 pred_val = torch.squeeze(F.interpolate(torch.unsqueeze(pred_val, 0), (shapes_val[t][0], shapes_val[t][1]), mode='bilinear'))
+
+                 ma = torch.max(pred_val)
+                 mi = torch.min(pred_val)
+                 pred_val = (pred_val - mi) / (ma - mi)  # max = 1
+                 # pred_val = normPRED(pred_val)
+
+                 gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test]))  # max = 255
+                 if gt.max() == 1:
+                     gt = gt * 255
+                 with torch.no_grad():
+                     gt = torch.tensor(gt).to(device)
+
+                 pre, rec, f1, mae = f1_mae_torch(pred_val * 255, gt, valid_dataset, i_test, mybins, hypar)
+
+                 PRE[i_test, :] = pre
+                 REC[i_test, :] = rec
+                 F1[i_test, :] = f1
+                 MAE[i_test] = mae
+
+                 del ds_val, gt
+                 gc.collect()
+                 torch.cuda.empty_cache()
+
+             # if(loss_val.data[0]>1):
+             val_loss += loss_val.item()
+             tar_loss += loss2_val.item()
+
+             print("[validating: %5d/%5d] val_ls:%f, tar_ls: %f, f1: %f, mae: %f, time: %f" % (i_val, val_num, val_loss / (i_val + 1), tar_loss / (i_val + 1), np.amax(F1[i_test, :]), MAE[i_test], t_end))
+
+             del loss2_val, loss_val
+
+         print('============================')
+         PRE_m = np.mean(PRE, 0)
+         REC_m = np.mean(REC, 0)
+         f1_m = (1 + 0.3) * PRE_m * REC_m / (0.3 * PRE_m + REC_m + 1e-8)
+         # print('--------------:', np.mean(f1_m))
+         tmp_f1.append(np.amax(f1_m))
+         tmp_mae.append(np.mean(MAE))
+         print("The max F1 score: %f" % (np.max(f1_m)))
+         print("MAE: ", np.mean(MAE))
+
+     # print('[epoch: %3d/%3d, ite: %5d] tra_ls: %3f, val_ls: %3f, tar_ls: %3f, maxf1: %3f, val_time: %6f' % (epoch + 1, epoch_num, ite_num, running_loss / ite_num4val, val_loss/val_cnt, tar_loss/val_cnt, tmp_f1[-1], time.time()-start_valid))
+
+     return tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time
+
+ def train(net, optimizer, train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar, train_dataloaders_val, train_datasets_val):  # model_path, model_save_fre, max_ite=1000000
+
+     if hypar["interm_sup"]:
+         print("Get the gt encoder ...")
+         featurenet = get_gt_encoder(train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar, train_dataloaders_val, train_datasets_val)
+         ## freeze the weights of the gt encoder
+         for param in featurenet.parameters():
+             param.requires_grad = False
+
+     # scaler = GradScaler()
+     model_path = hypar["model_path"]
+     model_save_fre = hypar["model_save_fre"]
+     max_ite = hypar["max_ite"]
+     batch_size_train = hypar["batch_size_train"]
+     batch_size_valid = hypar["batch_size_valid"]
+
+     if not os.path.exists(model_path):
+         os.mkdir(model_path)
+
+     ite_num = hypar["start_ite"]  # count the total iteration number
+     ite_num4val = 0
+     running_loss = 0.0  # count the total loss
+     running_tar_loss = 0.0  # count the target output loss
+     last_mae = [1 for x in range(len(valid_dataloaders))]
+     last_f1 = [0 for x in range(len(valid_dataloaders))]
+
+     train_num = train_datasets[0].__len__()
+
+     net.train()
+
+     start_last = time.time()
+     gos_dataloader = train_dataloaders[0]
+     epoch_num = hypar["max_epoch_num"]  # the epoch budget comes from hypar["max_epoch_num"]
+     notgood_cnt = 0
+     for epoch in range(epoch_num):
+
+         for i, data in enumerate(gos_dataloader):
+
+             if ite_num >= max_ite:
+                 print("Training Reached the Maximal Iteration Number ", max_ite)
+                 exit()
+
+             # start_read = time.time()
+             ite_num = ite_num + 1
+             ite_num4val = ite_num4val + 1
+
+             # get the inputs
+             inputs, labels = data['image'], data['label']
+             locations = data['location_blocks']
+             if hypar["model_digit"] == "full":
+                 inputs = inputs.type(torch.FloatTensor)
+                 labels = labels.type(torch.FloatTensor)
+                 locations = locations.type(torch.FloatTensor)
+             else:
+                 inputs = inputs.type(torch.HalfTensor)
+                 labels = labels.type(torch.HalfTensor)
+                 locations = locations.type(torch.HalfTensor)
+
+             # wrap them in Variable
+             if torch.cuda.is_available():
+                 inputs_v, labels_v = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(), requires_grad=False)
+                 locations_v = Variable(locations.cuda(), requires_grad=False)
+             else:
+                 inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False)
+                 locations_v = Variable(locations, requires_grad=False)
+
+             # print("time lapse for data preparation: ", time.time()-start_read, ' s')
+
+             # zero the parameter gradients
+             start_inf_loss_back = time.time()
+             optimizer.zero_grad()
+             if hypar["interm_sup"]:
+                 # forward + backward + optimize, with intermediate supervision:
+                 # extract the gt encodings and match the decoder features to them
+                 _, fs = featurenet(labels_v)
+                 ds, dfs = net(inputs_v)
+                 loss2, loss = net.compute_loss_kl(ds, labels_v, dfs, fs, mode='MSE')
+                 # loss2, loss = net.compute_loss_kl(ds, labels_v, dfs, fs, mode='cosin')
+             else:
+                 # forward + backward + optimize
+                 ds, _ = net(inputs_v)
+                 loss2, loss = net.compute_loss(ds, labels_v)
+             # with torch.autograd.detect_anomaly():
+             # scaler.scale(loss).backward()
+
+             # orthogonality (ORTHO) regularization on all non-bias weights
+             reg = 1e-8
+             orth_loss = torch.zeros(1).to(device)
+             for name, param in net.named_parameters():
+                 if 'bias' not in name:
+                     param_flat = param.view(param.shape[0], -1)
+                     sym = torch.mm(param_flat, torch.t(param_flat))
+                     sym -= torch.eye(param_flat.shape[0]).to(device)
+                     orth_loss = orth_loss + (reg * sym.abs().sum())
+             loss = loss + orth_loss
+             loss.backward()
+             # scaler.step(optimizer)
+             # scaler.update()
+             optimizer.step()
+             # torch.cuda.empty_cache()
+
+             # print statistics
+             running_loss += loss.item()
+             running_tar_loss += loss2.item()
+
+             # del outputs, loss
+             del ds, loss2, loss
+             end_inf_loss_back = time.time() - start_inf_loss_back
+
+             print(">>>" + model_path.split('/')[-1] + " - [epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f, time-per-iter: %3f s, time_read: %3f" % (
+                 epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val, time.time() - start_last, time.time() - start_last - end_inf_loss_back))
+             start_last = time.time()
+
+             if ite_num % model_save_fre == 0:  # validate every model_save_fre iterations
+                 notgood_cnt += 1
+                 net.eval()
+                 tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid(net, valid_dataloaders, valid_datasets, hypar, epoch)
+                 torch.cuda.empty_cache()
+                 net.train()  # resume training
+
+                 tmp_out = 0
+                 print("last_f1:", last_f1, np.mean(last_f1))
+                 print("tmp_f1:", tmp_f1, np.mean(tmp_f1))
+                 if np.mean(tmp_mae) < np.mean(last_mae):
+                     last_mae = tmp_mae
+                     tmp_out = 1
+                 if np.mean(tmp_f1) > np.mean(last_f1):
+                     last_f1 = tmp_f1
+                     tmp_out = 1
+                 print("tmp_out:", tmp_out)
+                 if tmp_out:
+                     notgood_cnt = 0
+                     # last_f1 = tmp_f1
+                     tmp_f1_str = [str(round(f1x, 4)) for f1x in tmp_f1]
+                     tmp_mae_str = [str(round(mx, 4)) for mx in tmp_mae]
+                     maxf1 = '_'.join(tmp_f1_str)
+                     meanM = '_'.join(tmp_mae_str)
+                     model_name = "/gpu_itr_" + str(ite_num) + \
+                                  "_traLoss_" + str(np.round(running_loss / ite_num4val, 4)) + \
+                                  "_traTarLoss_" + str(np.round(running_tar_loss / ite_num4val, 4)) + \
+                                  "_valLoss_" + str(np.round(val_loss / (i_val + 1), 4)) + \
+                                  "_valTarLoss_" + str(np.round(tar_loss / (i_val + 1), 4)) + \
+                                  "_maxF1_" + maxf1 + \
+                                  "_mae_" + meanM + \
+                                  "_time_" + str(np.round(np.mean(np.array(tmp_time)) / batch_size_valid, 6)) + ".pth"
+                     torch.save(net.state_dict(), model_path + model_name)
+
+                 running_loss = 0.0
+                 running_tar_loss = 0.0
+                 ite_num4val = 0
+
+                 if notgood_cnt >= hypar["early_stop"]:
+                     print("No improvements in the last " + str(notgood_cnt) + " validation periods, so training stopped!")
+                     exit()
+
+     print("Training Reaches The Maximum Epoch Number")
+
+ def valid(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
455
+ net.eval()
456
+ print("Validating...")
457
+ epoch_num = hypar["max_epoch_num"]
458
+
459
+ val_loss = 0.0
460
+ tar_loss = 0.0
461
+ val_cnt = 0.0
462
+
463
+ tmp_f1 = []
464
+ tmp_mae = []
465
+ tmp_time = []
466
+
467
+ start_valid = time.time()
468
+
469
+ for k in range(len(valid_dataloaders)):
470
+
471
+ valid_dataloader = valid_dataloaders[k]
472
+ valid_dataset = valid_datasets[k]
473
+
474
+ val_num = valid_dataset.__len__()
475
+ mybins = np.arange(0,256)
476
+ PRE = np.zeros((val_num,len(mybins)-1))
477
+ REC = np.zeros((val_num,len(mybins)-1))
478
+ F1 = np.zeros((val_num,len(mybins)-1))
479
+ MAE = np.zeros((val_num))
480
+
481
+ for i_val, data_val in enumerate(valid_dataloader):
482
+ val_cnt = val_cnt + 1.0
483
+ imidx_val, inputs_val, labels_val, shapes_val = data_val['imidx'], data_val['image'], data_val['label'], data_val['shape']
484
+
485
+ if(hypar["model_digit"]=="full"):
486
+ inputs_val = inputs_val.type(torch.FloatTensor)
487
+ labels_val = labels_val.type(torch.FloatTensor)
488
+ else:
489
+ inputs_val = inputs_val.type(torch.HalfTensor)
490
+ labels_val = labels_val.type(torch.HalfTensor)
491
+
492
+ # wrap them in Variable
493
+ if torch.cuda.is_available():
494
+ inputs_val_v, labels_val_v = Variable(inputs_val.cuda(), requires_grad=False), Variable(labels_val.cuda(), requires_grad=False)
495
+ else:
496
+ inputs_val_v, labels_val_v = Variable(inputs_val, requires_grad=False), Variable(labels_val,requires_grad=False)
497
+ # with autocast():
498
+ t_start = time.time()
499
+ ds_val = net(inputs_val_v)[0]
500
+ # plt.imshow(inputs_val_v[0][0].cpu().detach())
501
+ # plt.show()
502
+ # print(inputs_val_v.cpu().detach().shape)
503
+ t_end = time.time()-t_start
504
+ tmp_time.append(t_end)
505
+
506
+ # loss2_val, loss_val = muti_loss_fusion(ds_val, labels_val_v)
507
+ loss2_val, loss_val = net.compute_loss(ds_val, labels_val_v)
508
+
509
+ # compute F measure
510
+ for t in range(hypar["batch_size_valid"]):
511
+ i_test = imidx_val[t].data.numpy()
512
+
513
+ pred_val = ds_val[0][t,:,:,:].float() # B x 1 x H x W
514
+
515
+ ## recover the prediction to the original image's spatial size
516
+ pred_val = torch.squeeze(F.interpolate(torch.unsqueeze(pred_val,0),(shapes_val[t][0],shapes_val[t][1]),mode='bilinear'))
517
+
518
+ # pred_val = normPRED(pred_val)
519
+ ma = torch.max(pred_val)
520
+ mi = torch.min(pred_val)
521
+ pred_val = (pred_val-mi)/(ma-mi) # max = 1
522
+
523
+ gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
524
+ if gt.max()==1:
525
+ gt=gt*255
526
+
527
+ with torch.no_grad():
528
+ gt = torch.tensor(gt).to(device)
529
+
530
+ pre,rec,f1,mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar)
531
+
532
+
533
+ PRE[i_test,:]=pre
534
+ REC[i_test,:] = rec
535
+ F1[i_test,:] = f1
536
+ MAE[i_test] = mae
537
+
538
+ del ds_val, gt
539
+ gc.collect()
540
+ torch.cuda.empty_cache()
541
+
542
+ # if(loss_val.data[0]>1):
543
+ val_loss += loss_val.item()#data[0]
544
+ tar_loss += loss2_val.item()#data[0]
545
+
546
+ print("[validating: %5d/%5d] val_ls:%f, tar_ls: %f, f1: %f, mae: %f, time: %f"% (i_val, val_num, val_loss / (i_val + 1), tar_loss / (i_val + 1), np.amax(F1[i_test,:]), MAE[i_test],t_end))
547
+
548
+ del loss2_val, loss_val
549
+
550
+ print('============================')
551
+ PRE_m = np.mean(PRE,0)
552
+ REC_m = np.mean(REC,0)
553
+ f1_m = (1+0.3)*PRE_m*REC_m/(0.3*PRE_m+REC_m+1e-8)
554
+
555
+ tmp_f1.append(np.amax(f1_m))
556
+ tmp_mae.append(np.mean(MAE))
557
+
558
+ return tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time
559
+
560
+ def main(train_datasets,
561
+ valid_datasets,
562
+ hypar): # mode: "train" or "valid"
563
+
564
+ ### --- Step 1: Build datasets and dataloaders ---
565
+ dataloaders_train = []
566
+ dataloaders_valid = []
567
+
568
+ if(hypar["mode"]=="train"):
569
+ print("--- create training dataloader ---")
570
+ ## collect training dataset
571
+ train_nm_im_gt_list = get_im_gt_name_dict(train_datasets, flag="train")
572
+ ## build dataloader for training datasets
573
+ train_dataloaders, train_datasets = create_dataloaders(train_nm_im_gt_list,
574
+ cache_size = hypar["cache_size"],
575
+ cache_boost = hypar["cache_boost_train"],
576
+ my_transforms = [
577
+ GOSRandomHFlip(), ## horizontal-flip augmentation; comment out this line to disable it
578
+ # GOSResize(hypar["input_size"]),
579
+ # GOSRandomCrop(hypar["crop_size"]), ## this line can be uncommented for randomcrop augmentation
580
+ # GOSNormalize([0.5,0.5,0.5,0,0,0,0,0],[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0]),
581
+ GOSNormalize([0.5,0.5,0.5,0,0],[1.0,1.0,1.0,1.0,1.0]),
582
+ # GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
583
+ # GOSNormalize([123.675, 116.28, 103.53],[58.395, 57.12, 57.375])
584
+ ],
585
+ batch_size = hypar["batch_size_train"],
586
+ shuffle = True,
587
+ is_train=True)
588
+ train_dataloaders_val, train_datasets_val = create_dataloaders(train_nm_im_gt_list,
589
+ cache_size = hypar["cache_size"],
590
+ cache_boost = hypar["cache_boost_train"],
591
+ my_transforms = [
592
+ # GOSNormalize([0.5,0.5,0.5,0,0,0,0,0],[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0]),
593
+ GOSNormalize([0.5,0.5,0.5,0,0],[1.0,1.0,1.0,1.0,1.0]),
594
+ # GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
595
+ # GOSNormalize([123.675, 116.28, 103.53],[58.395, 57.12, 57.375])
596
+ ],
597
+ batch_size = hypar["batch_size_valid"],
598
+ shuffle = False,
599
+ is_train=False)
600
+ print(len(train_dataloaders), " train dataloaders created")
601
+
602
+ print("--- create valid dataloader ---")
603
+ ## build dataloader for validation or testing
604
+ valid_nm_im_gt_list = get_im_gt_name_dict(valid_datasets, flag="valid")
605
+ ## build dataloaders for the validation datasets
606
+ valid_dataloaders, valid_datasets = create_dataloaders(valid_nm_im_gt_list,
607
+ cache_size = hypar["cache_size"],
608
+ cache_boost = hypar["cache_boost_valid"],
609
+ my_transforms = [
610
+ # GOSNormalize([0.5,0.5,0.5,0,0,0,0,0],[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0]),
611
+ GOSNormalize([0.5,0.5,0.5,0,0],[1.0,1.0,1.0,1.0,1.0]),
612
+ # GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
613
+ # GOSNormalize([123.675, 116.28, 103.53],[58.395, 57.12, 57.375])
614
+ # GOSResize(hypar["input_size"])
615
+ ],
616
+ batch_size=hypar["batch_size_valid"],
617
+ shuffle=False,
618
+ is_train=False)
619
+ print(len(valid_dataloaders), " valid dataloaders created")
620
+ # print(valid_datasets[0]["data_name"])
621
+
622
+ ### --- Step 2: Build Model and Optimizer ---
623
+ print("--- build model ---")
624
+ net = hypar["model"]#GOSNETINC(3,1)
625
+
626
+ # convert to half precision
627
+ # if(hypar["model_digit"]=="half"):
628
+ # net.half()
629
+
630
+ if torch.cuda.is_available():
631
+ net.cuda()
632
+
633
+ if(hypar["restore_model"]!=""):
634
+ print("restore model from:")
635
+ print(hypar["model_path"]+"/"+hypar["restore_model"])
636
+ if torch.cuda.is_available():
637
+ net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"]),strict=False)
638
+ else:
639
+ net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"],map_location="cpu"),strict=False)
640
+
641
+ print("--- define optimizer ---")
642
+ # optimizer = optim.AdamW(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
643
+ optimizer = optim.AdamW(net.parameters(), lr=4e-5, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
644
+ ### --- Step 3: Train or Valid Model ---
645
+ if(hypar["mode"]=="train"):
646
+ train(net,
647
+ optimizer,
648
+ train_dataloaders,
649
+ train_datasets,
650
+ valid_dataloaders,
651
+ valid_datasets,
652
+ hypar,
653
+ train_dataloaders_val, train_datasets_val)
654
+ else:
655
+ valid(net,
656
+ valid_dataloaders,
657
+ valid_datasets,
658
+ hypar)
659
+
660
+
661
+ if __name__ == "__main__":
662
+
663
+ ### --------------- STEP 1: Configuring the Train, Valid and Test datasets ---------------
664
+ ## configure the train, valid and inference datasets
665
+ train_datasets, valid_datasets = [], []
666
+
667
+ valid_datasets = [dataset_test] ## users can create multiple dictionaries to set a list of datasets as validation or inference sets
668
+ train_datasets = [dataset_test] ## users can create multiple dictionaries to set a list of datasets as training sets
669
+
670
+
671
+ ### --------------- STEP 2: Configuring the hyperparameters for training, validation and inferencing ---------------
672
+ hypar = {}
673
+
674
+ ## -- 2.1. configure the model saving or restoring path --
675
+ hypar["mode"] = "train"
676
+ ## "train": for training,
677
+ ## "valid": for validation and inferening,
678
+ ## in "valid" mode, it will calculate the accuracy as well as save the prediciton results into the "hypar["valid_out_dir"]", which shouldn't be ""
679
+ ## otherwise only accuracy will bee calculated and no predictions will be saved
680
+ hypar["interm_sup"] = True ## in-dicate if activate intermediate feature supervision
681
+
682
+ if hypar["mode"] == "train":
683
+ hypar["valid_out_dir"] = "" ## for "train" model leave it as "", for "valid"("inference") mode: set it according to your local directory
684
+ hypar["model_path"] ="./saved_models" ## model weights saving (or restoring) path
685
+ hypar["restore_model"] = "" ## name of the segmentation model weights .pth for resume training process from last stop or for the inferencing
686
+ hypar["start_ite"] = 0 ## start iteration for the training, can be changed to match the restored training process
687
+ hypar["gt_encoder_model"] = ""
688
+ else: ## configure the segmentation output path and the to-be-used model weights path
689
+ hypar["valid_out_dir"] = "./your-results/"##".D:/Code/Design_for_graduation/DIS-main/IS-Net/DIS5K-Results-test" ## output inferenced segmentation maps into this fold
690
+ hypar["model_path"] = "./saved_models" ## load trained weights from this path
691
+ hypar["restore_model"] = "gpu_itr_102000_traLoss_2.5701_traTarLoss_0.0248_valLoss_2.3643_valTarLoss_0.3743_maxF1_0.8063_mae_0.0825_time_0.015695.pth"##"isnet.pth" ## name of the to-be-loaded weights
692
+
693
+ # if hypar["restore_model"]!="":
694
+ # hypar["start_ite"] = int(hypar["restore_model"].split("_")[2])
695
+
696
+ ## -- 2.2. choose floating point accuracy --
697
+ hypar["model_digit"] = "full" ## indicates "half" or "full" accuracy of float number
698
+ hypar["seed"] = 0
699
+
700
+ ## -- 2.3. cache data spatial size --
701
+ ## To handle large input images, which take a long time to load during training,
702
+ # we introduce a cache mechanism that pre-converts and resizes the jpg and png images into .pt files
703
+ hypar["cache_size"] = [1024, 1024] ## cached input spatial resolution, can be configured to a different size
704
+ hypar["cache_boost_train"] = False ## "True" or "False", indicates whether to load all the training data into RAM; True greatly speeds up training but requires more RAM
705
+ hypar["cache_boost_valid"] = False ## "True" or "False", indicates whether to load all the validation data into RAM; True greatly speeds up validation but requires more RAM
706
+
707
+ ## --- 2.4. data augmentation parameters ---
708
+ hypar["input_size"] = [1024, 1024] ## mdoel input spatial size, usually use the same value hypar["cache_size"], which means we don't further resize the images
709
+ hypar["crop_size"] = [1024, 1024] ## random crop size from the input, it is usually set as smaller than hypar["cache_size"], e.g., [920,920] for data augmentation
710
+ hypar["random_flip_h"] = 1 ## horizontal flip, currently hard coded in the datader and it is not in use
711
+ hypar["random_flip_v"] = 1 ## vertical flip , currently not in use
712
+
713
+ ## --- 2.5. define model ---
714
+ print("building model...")
715
+ hypar["model"] = ISNetDIS(in_ch=5) #U2NETFASTFEATURESUP()
716
+ hypar["early_stop"] = 20 ## stop the training when no improvement in the past 20 validation periods, smaller numbers can be used here e.g., 5 or 10.
717
+ hypar["model_save_fre"] = 3000 ## valid and save model weights every 2000 iterations
718
+
719
+ hypar["batch_size_train"] = 6 ## batch size for training
720
+ hypar["batch_size_valid"] = 1 ## batch size for validation and inferencing
721
+ print("batch size: ", hypar["batch_size_train"])
722
+
723
+ hypar["max_ite"] = 50000000 ## if early stop couldn't stop the training process, stop it by the max_ite_num
724
+ hypar["max_epoch_num"] = 500000 ## if early stop and max_ite couldn't stop the training process, stop it by the max_epoch_num
725
+
726
+ main(train_datasets,
727
+ valid_datasets,
728
+ hypar=hypar)
729
+
MultiScaleDeformableAttention-1.0-py3-none-any.whl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:152caec7860d1f39f644ac5eed946b5a4eecfad40764396345b3d0e516921b17
3
+ size 2048806
README.md CHANGED
@@ -9,6 +9,8 @@ app_file: app.py
  pinned: false
  license: mit
  short_description: SAM-prompted dichotomous segmentation. No affiliation.
+ python_version: 3.11
+ preload_from_hub:
+ - jwlarocque/DIS-SAM DIS-SAM-checkpoint.pth
+ - andzhang01/segment_anything sam_vit_l_0b3195.pth
  ---
- 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
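Note on the new front-matter: files listed under preload_from_hub are fetched into the Space's Hugging Face cache at build time, so at runtime the same checkpoints can be resolved without re-downloading. A sketch using huggingface_hub (repo ids as declared above):

from huggingface_hub import hf_hub_download

# resolves to the copies preloaded when the Space was built
dis_sam_ckpt = hf_hub_download("jwlarocque/DIS-SAM", "DIS-SAM-checkpoint.pth")
sam_ckpt = hf_hub_download("andzhang01/segment_anything", "sam_vit_l_0b3195.pth")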
 
 
SAM/segment_anything/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .build_sam import (
8
+ build_sam,
9
+ build_sam_vit_h,
10
+ build_sam_vit_l,
11
+ build_sam_vit_b,
12
+ sam_model_registry,
13
+ )
14
+ from .predictor import SamPredictor
15
+ from .automatic_mask_generator import SamAutomaticMaskGenerator
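The vendored package keeps upstream segment-anything's public API. Assuming the predictor matches upstream, prompted prediction looks like this (checkpoint path, image path, and click coordinates are illustrative):

import numpy as np
from PIL import Image
from SAM.segment_anything import sam_model_registry, SamPredictor

sam = sam_model_registry["vit_l"](checkpoint="sam_vit_l_0b3195.pth")
predictor = SamPredictor(sam)
image = np.array(Image.open("example.jpg").convert("RGB"))  # HWC uint8
predictor.set_image(image)
masks, scores, logits = predictor.predict(
    point_coords=np.array([[512, 512]]),  # one illustrative foreground click
    point_labels=np.array([1]),           # 1 = foreground, 0 = background
    multimask_output=True,
)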
SAM/segment_anything/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (549 Bytes).
 
SAM/segment_anything/__pycache__/automatic_mask_generator.cpython-311.pyc ADDED
Binary file (18.3 kB).
 
SAM/segment_anything/__pycache__/build_sam.cpython-311.pyc ADDED
Binary file (3.22 kB).
 
SAM/segment_anything/__pycache__/predictor.cpython-311.pyc ADDED
Binary file (14.1 kB).
 
SAM/segment_anything/automatic_mask_generator.py ADDED
@@ -0,0 +1,372 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torchvision.ops.boxes import batched_nms, box_area # type: ignore
10
+
11
+ from typing import Any, Dict, List, Optional, Tuple
12
+
13
+ from .modeling import Sam
14
+ from .predictor import SamPredictor
15
+ from .utils.amg import (
16
+ MaskData,
17
+ area_from_rle,
18
+ batch_iterator,
19
+ batched_mask_to_box,
20
+ box_xyxy_to_xywh,
21
+ build_all_layer_point_grids,
22
+ calculate_stability_score,
23
+ coco_encode_rle,
24
+ generate_crop_boxes,
25
+ is_box_near_crop_edge,
26
+ mask_to_rle_pytorch,
27
+ remove_small_regions,
28
+ rle_to_mask,
29
+ uncrop_boxes_xyxy,
30
+ uncrop_masks,
31
+ uncrop_points,
32
+ )
33
+
34
+
35
+ class SamAutomaticMaskGenerator:
36
+ def __init__(
37
+ self,
38
+ model: Sam,
39
+ points_per_side: Optional[int] = 32,
40
+ points_per_batch: int = 64,
41
+ pred_iou_thresh: float = 0.88,
42
+ stability_score_thresh: float = 0.95,
43
+ stability_score_offset: float = 1.0,
44
+ box_nms_thresh: float = 0.7,
45
+ crop_n_layers: int = 0,
46
+ crop_nms_thresh: float = 0.7,
47
+ crop_overlap_ratio: float = 512 / 1500,
48
+ crop_n_points_downscale_factor: int = 1,
49
+ point_grids: Optional[List[np.ndarray]] = None,
50
+ min_mask_region_area: int = 0,
51
+ output_mode: str = "binary_mask",
52
+ ) -> None:
53
+ """
54
+ Using a SAM model, generates masks for the entire image.
55
+ Generates a grid of point prompts over the image, then filters
56
+ low quality and duplicate masks. The default settings are chosen
57
+ for SAM with a ViT-H backbone.
58
+
59
+ Arguments:
60
+ model (Sam): The SAM model to use for mask prediction.
61
+ points_per_side (int or None): The number of points to be sampled
62
+ along one side of the image. The total number of points is
63
+ points_per_side**2. If None, 'point_grids' must provide explicit
64
+ point sampling.
65
+ points_per_batch (int): Sets the number of points run simultaneously
66
+ by the model. Higher numbers may be faster but use more GPU memory.
67
+ pred_iou_thresh (float): A filtering threshold in [0,1], using the
68
+ model's predicted mask quality.
69
+ stability_score_thresh (float): A filtering threshold in [0,1], using
70
+ the stability of the mask under changes to the cutoff used to binarize
71
+ the model's mask predictions.
72
+ stability_score_offset (float): The amount to shift the cutoff when
73
+ calculating the stability score.
74
+ box_nms_thresh (float): The box IoU cutoff used by non-maximal
75
+ suppression to filter duplicate masks.
76
+ crop_n_layers (int): If >0, mask prediction will be run again on
77
+ crops of the image. Sets the number of layers to run, where each
78
+ layer has 2**i_layer number of image crops.
79
+ crop_nms_thresh (float): The box IoU cutoff used by non-maximal
80
+ suppression to filter duplicate masks between different crops.
81
+ crop_overlap_ratio (float): Sets the degree to which crops overlap.
82
+ In the first crop layer, crops will overlap by this fraction of
83
+ the image length. Later layers with more crops scale down this overlap.
84
+ crop_n_points_downscale_factor (int): The number of points-per-side
85
+ sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
86
+ point_grids (list(np.ndarray) or None): A list over explicit grids
87
+ of points used for sampling, normalized to [0,1]. The nth grid in the
88
+ list is used in the nth crop layer. Exclusive with points_per_side.
89
+ min_mask_region_area (int): If >0, postprocessing will be applied
90
+ to remove disconnected regions and holes in masks with area smaller
91
+ than min_mask_region_area. Requires opencv.
92
+ output_mode (str): The form masks are returned in. Can be 'binary_mask',
93
+ 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
94
+ For large resolutions, 'binary_mask' may consume large amounts of
95
+ memory.
96
+ """
97
+
98
+ assert (points_per_side is None) != (
99
+ point_grids is None
100
+ ), "Exactly one of points_per_side or point_grid must be provided."
101
+ if points_per_side is not None:
102
+ self.point_grids = build_all_layer_point_grids(
103
+ points_per_side,
104
+ crop_n_layers,
105
+ crop_n_points_downscale_factor,
106
+ )
107
+ elif point_grids is not None:
108
+ self.point_grids = point_grids
109
+ else:
110
+ raise ValueError("Can't have both points_per_side and point_grid be None.")
111
+
112
+ assert output_mode in [
113
+ "binary_mask",
114
+ "uncompressed_rle",
115
+ "coco_rle",
116
+ ], f"Unknown output_mode {output_mode}."
117
+ if output_mode == "coco_rle":
118
+ from pycocotools import mask as mask_utils # type: ignore # noqa: F401
119
+
120
+ if min_mask_region_area > 0:
121
+ import cv2 # type: ignore # noqa: F401
122
+
123
+ self.predictor = SamPredictor(model)
124
+ self.points_per_batch = points_per_batch
125
+ self.pred_iou_thresh = pred_iou_thresh
126
+ self.stability_score_thresh = stability_score_thresh
127
+ self.stability_score_offset = stability_score_offset
128
+ self.box_nms_thresh = box_nms_thresh
129
+ self.crop_n_layers = crop_n_layers
130
+ self.crop_nms_thresh = crop_nms_thresh
131
+ self.crop_overlap_ratio = crop_overlap_ratio
132
+ self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
133
+ self.min_mask_region_area = min_mask_region_area
134
+ self.output_mode = output_mode
135
+
136
+ @torch.no_grad()
137
+ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
138
+ """
139
+ Generates masks for the given image.
140
+
141
+ Arguments:
142
+ image (np.ndarray): The image to generate masks for, in HWC uint8 format.
143
+
144
+ Returns:
145
+ list(dict(str, any)): A list over records for masks. Each record is
146
+ a dict containing the following keys:
147
+ segmentation (dict(str, any) or np.ndarray): The mask. If
148
+ output_mode='binary_mask', is an array of shape HW. Otherwise,
149
+ is a dictionary containing the RLE.
150
+ bbox (list(float)): The box around the mask, in XYWH format.
151
+ area (int): The area in pixels of the mask.
152
+ predicted_iou (float): The model's own prediction of the mask's
153
+ quality. This is filtered by the pred_iou_thresh parameter.
154
+ point_coords (list(list(float))): The point coordinates input
155
+ to the model to generate this mask.
156
+ stability_score (float): A measure of the mask's quality. This
157
+ is filtered on using the stability_score_thresh parameter.
158
+ crop_box (list(float)): The crop of the image used to generate
159
+ the mask, given in XYWH format.
160
+ """
161
+
162
+ # Generate masks
163
+ mask_data = self._generate_masks(image)
164
+
165
+ # Filter small disconnected regions and holes in masks
166
+ if self.min_mask_region_area > 0:
167
+ mask_data = self.postprocess_small_regions(
168
+ mask_data,
169
+ self.min_mask_region_area,
170
+ max(self.box_nms_thresh, self.crop_nms_thresh),
171
+ )
172
+
173
+ # Encode masks
174
+ if self.output_mode == "coco_rle":
175
+ mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
176
+ elif self.output_mode == "binary_mask":
177
+ mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
178
+ else:
179
+ mask_data["segmentations"] = mask_data["rles"]
180
+
181
+ # Write mask records
182
+ curr_anns = []
183
+ for idx in range(len(mask_data["segmentations"])):
184
+ ann = {
185
+ "segmentation": mask_data["segmentations"][idx],
186
+ "area": area_from_rle(mask_data["rles"][idx]),
187
+ "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
188
+ "predicted_iou": mask_data["iou_preds"][idx].item(),
189
+ "point_coords": [mask_data["points"][idx].tolist()],
190
+ "stability_score": mask_data["stability_score"][idx].item(),
191
+ "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
192
+ }
193
+ curr_anns.append(ann)
194
+
195
+ return curr_anns
196
+
197
+ def _generate_masks(self, image: np.ndarray) -> MaskData:
198
+ orig_size = image.shape[:2]
199
+ crop_boxes, layer_idxs = generate_crop_boxes(
200
+ orig_size, self.crop_n_layers, self.crop_overlap_ratio
201
+ )
202
+
203
+ # Iterate over image crops
204
+ data = MaskData()
205
+ for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
206
+ crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
207
+ data.cat(crop_data)
208
+
209
+ # Remove duplicate masks between crops
210
+ if len(crop_boxes) > 1:
211
+ # Prefer masks from smaller crops
212
+ scores = 1 / box_area(data["crop_boxes"])
213
+ scores = scores.to(data["boxes"].device)
214
+ keep_by_nms = batched_nms(
215
+ data["boxes"].float(),
216
+ scores,
217
+ torch.zeros_like(data["boxes"][:, 0]), # categories
218
+ iou_threshold=self.crop_nms_thresh,
219
+ )
220
+ data.filter(keep_by_nms)
221
+
222
+ data.to_numpy()
223
+ return data
224
+
225
+ def _process_crop(
226
+ self,
227
+ image: np.ndarray,
228
+ crop_box: List[int],
229
+ crop_layer_idx: int,
230
+ orig_size: Tuple[int, ...],
231
+ ) -> MaskData:
232
+ # Crop the image and calculate embeddings
233
+ x0, y0, x1, y1 = crop_box
234
+ cropped_im = image[y0:y1, x0:x1, :]
235
+ cropped_im_size = cropped_im.shape[:2]
236
+ self.predictor.set_image(cropped_im)
237
+
238
+ # Get points for this crop
239
+ points_scale = np.array(cropped_im_size)[None, ::-1]
240
+ points_for_image = self.point_grids[crop_layer_idx] * points_scale
241
+
242
+ # Generate masks for this crop in batches
243
+ data = MaskData()
244
+ for (points,) in batch_iterator(self.points_per_batch, points_for_image):
245
+ batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
246
+ data.cat(batch_data)
247
+ del batch_data
248
+ self.predictor.reset_image()
249
+
250
+ # Remove duplicates within this crop.
251
+ keep_by_nms = batched_nms(
252
+ data["boxes"].float(),
253
+ data["iou_preds"],
254
+ torch.zeros_like(data["boxes"][:, 0]), # categories
255
+ iou_threshold=self.box_nms_thresh,
256
+ )
257
+ data.filter(keep_by_nms)
258
+
259
+ # Return to the original image frame
260
+ data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
261
+ data["points"] = uncrop_points(data["points"], crop_box)
262
+ data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
263
+
264
+ return data
265
+
266
+ def _process_batch(
267
+ self,
268
+ points: np.ndarray,
269
+ im_size: Tuple[int, ...],
270
+ crop_box: List[int],
271
+ orig_size: Tuple[int, ...],
272
+ ) -> MaskData:
273
+ orig_h, orig_w = orig_size
274
+
275
+ # Run model on this batch
276
+ transformed_points = self.predictor.transform.apply_coords(points, im_size)
277
+ in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
278
+ in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
279
+ masks, iou_preds, _ = self.predictor.predict_torch(
280
+ in_points[:, None, :],
281
+ in_labels[:, None],
282
+ multimask_output=True,
283
+ return_logits=True,
284
+ )
285
+
286
+ # Serialize predictions and store in MaskData
287
+ data = MaskData(
288
+ masks=masks.flatten(0, 1),
289
+ iou_preds=iou_preds.flatten(0, 1),
290
+ points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
291
+ )
292
+ del masks
293
+
294
+ # Filter by predicted IoU
295
+ if self.pred_iou_thresh > 0.0:
296
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
297
+ data.filter(keep_mask)
298
+
299
+ # Calculate stability score
300
+ data["stability_score"] = calculate_stability_score(
301
+ data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
302
+ )
303
+ if self.stability_score_thresh > 0.0:
304
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
305
+ data.filter(keep_mask)
306
+
307
+ # Threshold masks and calculate boxes
308
+ data["masks"] = data["masks"] > self.predictor.model.mask_threshold
309
+ data["boxes"] = batched_mask_to_box(data["masks"])
310
+
311
+ # Filter boxes that touch crop boundaries
312
+ keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
313
+ if not torch.all(keep_mask):
314
+ data.filter(keep_mask)
315
+
316
+ # Compress to RLE
317
+ data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
318
+ data["rles"] = mask_to_rle_pytorch(data["masks"])
319
+ del data["masks"]
320
+
321
+ return data
322
+
323
+ @staticmethod
324
+ def postprocess_small_regions(
325
+ mask_data: MaskData, min_area: int, nms_thresh: float
326
+ ) -> MaskData:
327
+ """
328
+ Removes small disconnected regions and holes in masks, then reruns
329
+ box NMS to remove any new duplicates.
330
+
331
+ Edits mask_data in place.
332
+
333
+ Requires open-cv as a dependency.
334
+ """
335
+ if len(mask_data["rles"]) == 0:
336
+ return mask_data
337
+
338
+ # Filter small disconnected regions and holes
339
+ new_masks = []
340
+ scores = []
341
+ for rle in mask_data["rles"]:
342
+ mask = rle_to_mask(rle)
343
+
344
+ mask, changed = remove_small_regions(mask, min_area, mode="holes")
345
+ unchanged = not changed
346
+ mask, changed = remove_small_regions(mask, min_area, mode="islands")
347
+ unchanged = unchanged and not changed
348
+
349
+ new_masks.append(torch.as_tensor(mask).unsqueeze(0))
350
+ # Give score=0 to changed masks and score=1 to unchanged masks
351
+ # so NMS will prefer ones that didn't need postprocessing
352
+ scores.append(float(unchanged))
353
+
354
+ # Recalculate boxes and remove any new duplicates
355
+ masks = torch.cat(new_masks, dim=0)
356
+ boxes = batched_mask_to_box(masks)
357
+ keep_by_nms = batched_nms(
358
+ boxes.float(),
359
+ torch.as_tensor(scores),
360
+ torch.zeros_like(boxes[:, 0]), # categories
361
+ iou_threshold=nms_thresh,
362
+ )
363
+
364
+ # Only recalculate RLEs for masks that have changed
365
+ for i_mask in keep_by_nms:
366
+ if scores[i_mask] == 0.0:
367
+ mask_torch = masks[i_mask].unsqueeze(0)
368
+ mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
369
+ mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
370
+ mask_data.filter(keep_by_nms)
371
+
372
+ return mask_data
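Per the generate() docstring above, automatic (prompt-free) mask generation is a two-liner once the model is built; a sketch (checkpoint name as in the README, image path illustrative):

import numpy as np
from PIL import Image
from SAM.segment_anything import sam_model_registry, SamAutomaticMaskGenerator

sam = sam_model_registry["vit_l"](checkpoint="sam_vit_l_0b3195.pth")
mask_generator = SamAutomaticMaskGenerator(sam, pred_iou_thresh=0.88)  # defaults are tuned for ViT-H
image = np.array(Image.open("example.jpg").convert("RGB"))  # HWC uint8
masks = mask_generator.generate(image)  # list of dicts: segmentation, bbox, area, predicted_iou, ...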
SAM/segment_anything/build_sam.py ADDED
@@ -0,0 +1,111 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+
9
+ from functools import partial
10
+
11
+ from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
12
+
13
+
14
+ def build_sam_vit_h(checkpoint=None, device="cpu"):
15
+ return _build_sam(
16
+ encoder_embed_dim=1280,
17
+ encoder_depth=32,
18
+ encoder_num_heads=16,
19
+ encoder_global_attn_indexes=[7, 15, 23, 31],
20
+ checkpoint=checkpoint,
21
+ device=device,
22
+ )
23
+
24
+
25
+ build_sam = build_sam_vit_h
26
+
27
+
28
+ def build_sam_vit_l(checkpoint=None, device="cpu"):
29
+ return _build_sam(
30
+ encoder_embed_dim=1024,
31
+ encoder_depth=24,
32
+ encoder_num_heads=16,
33
+ encoder_global_attn_indexes=[5, 11, 17, 23],
34
+ checkpoint=checkpoint,
35
+ device=device,
36
+ )
37
+
38
+
39
+ def build_sam_vit_b(checkpoint=None, device="cpu"):
40
+ return _build_sam(
41
+ encoder_embed_dim=768,
42
+ encoder_depth=12,
43
+ encoder_num_heads=12,
44
+ encoder_global_attn_indexes=[2, 5, 8, 11],
45
+ checkpoint=checkpoint,
46
+ device=device,
47
+ )
48
+
49
+
50
+ sam_model_registry = {
51
+ "default": build_sam_vit_h,
52
+ "vit_h": build_sam_vit_h,
53
+ "vit_l": build_sam_vit_l,
54
+ "vit_b": build_sam_vit_b,
55
+ }
56
+
57
+
58
+ def _build_sam(
59
+ encoder_embed_dim,
60
+ encoder_depth,
61
+ encoder_num_heads,
62
+ encoder_global_attn_indexes,
63
+ checkpoint=None,
64
+ device="cpu"
65
+ ):
66
+ prompt_embed_dim = 256
67
+ image_size = 1024
68
+ vit_patch_size = 16
69
+ image_embedding_size = image_size // vit_patch_size
70
+ sam = Sam(
71
+ image_encoder=ImageEncoderViT(
72
+ depth=encoder_depth,
73
+ embed_dim=encoder_embed_dim,
74
+ img_size=image_size,
75
+ mlp_ratio=4,
76
+ norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
77
+ num_heads=encoder_num_heads,
78
+ patch_size=vit_patch_size,
79
+ qkv_bias=True,
80
+ use_rel_pos=True,
81
+ global_attn_indexes=encoder_global_attn_indexes,
82
+ window_size=14,
83
+ out_chans=prompt_embed_dim,
84
+ ),
85
+ prompt_encoder=PromptEncoder(
86
+ embed_dim=prompt_embed_dim,
87
+ image_embedding_size=(image_embedding_size, image_embedding_size),
88
+ input_image_size=(image_size, image_size),
89
+ mask_in_chans=16,
90
+ ),
91
+ mask_decoder=MaskDecoder(
92
+ num_multimask_outputs=3,
93
+ transformer=TwoWayTransformer(
94
+ depth=2,
95
+ embedding_dim=prompt_embed_dim,
96
+ mlp_dim=2048,
97
+ num_heads=8,
98
+ ),
99
+ transformer_dim=prompt_embed_dim,
100
+ iou_head_depth=3,
101
+ iou_head_hidden_dim=256,
102
+ ),
103
+ pixel_mean=[123.675, 116.28, 103.53],
104
+ pixel_std=[58.395, 57.12, 57.375],
105
+ )
106
+ sam.eval()
107
+ if checkpoint is not None:
108
+ with open(checkpoint, "rb") as f:
109
+ state_dict = torch.load(f, map_location=device)
110
+ sam.load_state_dict(state_dict)
111
+ return sam
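Unlike upstream, the builders above accept a device argument that is passed to torch.load as map_location, so GPU-saved checkpoints load cleanly on a CPU-only Space. For example (paths illustrative):

import torch
from SAM.segment_anything.build_sam import build_sam_vit_l

device = "cuda" if torch.cuda.is_available() else "cpu"
sam = build_sam_vit_l(checkpoint="sam_vit_l_0b3195.pth", device=device)
sam.to(device)  # map_location only affects loading; move the module explicitly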
SAM/segment_anything/modeling/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .sam import Sam
8
+ from .image_encoder import ImageEncoderViT
9
+ from .mask_decoder import MaskDecoder
10
+ from .prompt_encoder import PromptEncoder
11
+ from .transformer import TwoWayTransformer
SAM/segment_anything/modeling/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (516 Bytes).
 
SAM/segment_anything/modeling/__pycache__/common.cpython-311.pyc ADDED
Binary file (3.24 kB).