Commit ab7d699 · 1 parent: fdaae10
Create DIS-SAM space
This view is limited to 50 files because it contains too many changes.
- .gitattributes +2 -0
- IS_Net/DIS5K/DIS5K-test/enhance_gt/1#Accessories#1#Bag#2339506821_83cf9f1d22_o_comp_1.png +0 -0
- IS_Net/DIS5K/DIS5K-test/enhance_gt/1#Accessories#1#Bag#3292738108_c51336a8be_o_comp_1.png +0 -0
- IS_Net/DIS5K/DIS5K-test/enhance_gt/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_0.png +0 -0
- IS_Net/DIS5K/DIS5K-test/enhance_gt/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_1.png +0 -0
- IS_Net/DIS5K/DIS5K-test/enhance_gt/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_2.png +0 -0
- IS_Net/DIS5K/DIS5K-test/enhance_im/1#Accessories#1#Bag#2339506821_83cf9f1d22_o_comp_1.jpg +3 -0
- IS_Net/DIS5K/DIS5K-test/enhance_im/1#Accessories#1#Bag#3292738108_c51336a8be_o_comp_1.jpg +3 -0
- IS_Net/DIS5K/DIS5K-test/enhance_im/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_0.jpg +3 -0
- IS_Net/DIS5K/DIS5K-test/enhance_im/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_1.jpg +3 -0
- IS_Net/DIS5K/DIS5K-test/enhance_im/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_2.jpg +3 -0
- IS_Net/DIS5K/DIS5K-test/enhance_sam/1#Accessories#1#Bag#2339506821_83cf9f1d22_o_comp_1.png +0 -0
- IS_Net/DIS5K/DIS5K-test/enhance_sam/1#Accessories#1#Bag#3292738108_c51336a8be_o_comp_1.png +0 -0
- IS_Net/DIS5K/DIS5K-test/enhance_sam/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_0.png +0 -0
- IS_Net/DIS5K/DIS5K-test/enhance_sam/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_1.png +0 -0
- IS_Net/DIS5K/DIS5K-test/enhance_sam/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_2.png +0 -0
- IS_Net/DIS5K/DIS5K-test/gt/1#Accessories#1#Bag#2339506821_83cf9f1d22_o.png +0 -0
- IS_Net/DIS5K/DIS5K-test/gt/1#Accessories#1#Bag#3292738108_c51336a8be_o.png +0 -0
- IS_Net/DIS5K/DIS5K-test/gt/4#Architecture#10#Pavilion#5795028920_08884db993_o.png +0 -0
- IS_Net/DIS5K/DIS5K-test/im/1#Accessories#1#Bag#2339506821_83cf9f1d22_o.jpg +3 -0
- IS_Net/DIS5K/DIS5K-test/im/1#Accessories#1#Bag#3292738108_c51336a8be_o.jpg +3 -0
- IS_Net/DIS5K/DIS5K-test/im/4#Architecture#10#Pavilion#5795028920_08884db993_o.jpg +3 -0
- IS_Net/DIS5K/DIS5K-test/mask/1#Accessories#1#Bag#2339506821_83cf9f1d22_o.png +0 -0
- IS_Net/DIS5K/DIS5K-test/mask/1#Accessories#1#Bag#3292738108_c51336a8be_o.png +0 -0
- IS_Net/DIS5K/DIS5K-test/mask/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_0.png +0 -0
- IS_Net/__pycache__/data_loader.cpython-311.pyc +0 -0
- IS_Net/basics.py +125 -0
- IS_Net/data_loader.py +542 -0
- IS_Net/datalist.py +62 -0
- IS_Net/models/__pycache__/isnet.cpython-311.pyc +0 -0
- IS_Net/models/isnet.py +640 -0
- IS_Net/saliency_toolbox.py +552 -0
- IS_Net/swd_optim/__init__.py +10 -0
- IS_Net/swd_optim/adai.py +116 -0
- IS_Net/swd_optim/adais.py +120 -0
- IS_Net/swd_optim/adams.py +137 -0
- IS_Net/swd_optim/sgds.py +82 -0
- IS_Net/train_valid_inference_main.py +729 -0
- MultiScaleDeformableAttention-1.0-py3-none-any.whl +3 -0
- README.md +4 -2
- SAM/segment_anything/__init__.py +15 -0
- SAM/segment_anything/__pycache__/__init__.cpython-311.pyc +0 -0
- SAM/segment_anything/__pycache__/automatic_mask_generator.cpython-311.pyc +0 -0
- SAM/segment_anything/__pycache__/build_sam.cpython-311.pyc +0 -0
- SAM/segment_anything/__pycache__/predictor.cpython-311.pyc +0 -0
- SAM/segment_anything/automatic_mask_generator.py +372 -0
- SAM/segment_anything/build_sam.py +111 -0
- SAM/segment_anything/modeling/__init__.py +11 -0
- SAM/segment_anything/modeling/__pycache__/__init__.cpython-311.pyc +0 -0
- SAM/segment_anything/modeling/__pycache__/common.cpython-311.pyc +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+MultiScaleDeformableAttention-1.0-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
IS_Net/DIS5K/DIS5K-test/enhance_gt/1#Accessories#1#Bag#2339506821_83cf9f1d22_o_comp_1.png
ADDED
IS_Net/DIS5K/DIS5K-test/enhance_gt/1#Accessories#1#Bag#3292738108_c51336a8be_o_comp_1.png
ADDED
IS_Net/DIS5K/DIS5K-test/enhance_gt/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_0.png
ADDED
IS_Net/DIS5K/DIS5K-test/enhance_gt/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_1.png
ADDED
IS_Net/DIS5K/DIS5K-test/enhance_gt/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_2.png
ADDED
IS_Net/DIS5K/DIS5K-test/enhance_im/1#Accessories#1#Bag#2339506821_83cf9f1d22_o_comp_1.jpg
ADDED (Git LFS)
IS_Net/DIS5K/DIS5K-test/enhance_im/1#Accessories#1#Bag#3292738108_c51336a8be_o_comp_1.jpg
ADDED (Git LFS)
IS_Net/DIS5K/DIS5K-test/enhance_im/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_0.jpg
ADDED (Git LFS)
IS_Net/DIS5K/DIS5K-test/enhance_im/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_1.jpg
ADDED (Git LFS)
IS_Net/DIS5K/DIS5K-test/enhance_im/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_2.jpg
ADDED (Git LFS)
IS_Net/DIS5K/DIS5K-test/enhance_sam/1#Accessories#1#Bag#2339506821_83cf9f1d22_o_comp_1.png
ADDED
IS_Net/DIS5K/DIS5K-test/enhance_sam/1#Accessories#1#Bag#3292738108_c51336a8be_o_comp_1.png
ADDED
IS_Net/DIS5K/DIS5K-test/enhance_sam/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_0.png
ADDED
IS_Net/DIS5K/DIS5K-test/enhance_sam/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_1.png
ADDED
IS_Net/DIS5K/DIS5K-test/enhance_sam/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_2.png
ADDED
IS_Net/DIS5K/DIS5K-test/gt/1#Accessories#1#Bag#2339506821_83cf9f1d22_o.png
ADDED
IS_Net/DIS5K/DIS5K-test/gt/1#Accessories#1#Bag#3292738108_c51336a8be_o.png
ADDED
IS_Net/DIS5K/DIS5K-test/gt/4#Architecture#10#Pavilion#5795028920_08884db993_o.png
ADDED
IS_Net/DIS5K/DIS5K-test/im/1#Accessories#1#Bag#2339506821_83cf9f1d22_o.jpg
ADDED (Git LFS)
IS_Net/DIS5K/DIS5K-test/im/1#Accessories#1#Bag#3292738108_c51336a8be_o.jpg
ADDED (Git LFS)
IS_Net/DIS5K/DIS5K-test/im/4#Architecture#10#Pavilion#5795028920_08884db993_o.jpg
ADDED (Git LFS)
IS_Net/DIS5K/DIS5K-test/mask/1#Accessories#1#Bag#2339506821_83cf9f1d22_o.png
ADDED
IS_Net/DIS5K/DIS5K-test/mask/1#Accessories#1#Bag#3292738108_c51336a8be_o.png
ADDED
IS_Net/DIS5K/DIS5K-test/mask/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_0.png
ADDED
IS_Net/__pycache__/data_loader.cpython-311.pyc
ADDED
Binary file (34.3 kB)
IS_Net/basics.py
ADDED
@@ -0,0 +1,125 @@
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '2'
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.optim as optim
from skimage.metrics import structural_similarity as ssim
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import glob
import cv2
from scipy.stats import pearsonr


def mae_torch(pred, gt):
    h, w = gt.shape[0:2]
    sumError = torch.sum(torch.absolute(torch.sub(pred.float(), gt.float())))
    maeError = torch.divide(sumError, float(h) * float(w) * 255.0 + 1e-4)
    return maeError


def maximal_f_measure_torch(pd, gt):
    gtNum = torch.sum((gt > 128).float() * 1)  # number of ground-truth pixels with value > 128

    # split the predictions into positives and negatives
    pp = pd[gt > 128]
    nn = pd[gt <= 128]

    # histograms of the positive and negative predictions
    pp_hist = torch.histc(pp, bins=255, min=0, max=255)
    nn_hist = torch.histc(nn, bins=255, min=0, max=255)

    # flip the histograms and take cumulative sums
    pp_hist_flip = torch.flipud(pp_hist)
    nn_hist_flip = torch.flipud(nn_hist)

    pp_hist_flip_cum = torch.cumsum(pp_hist_flip, dim=0)
    nn_hist_flip_cum = torch.cumsum(nn_hist_flip, dim=0)

    # precision, recall, and F-measure per threshold
    precision = pp_hist_flip_cum / (pp_hist_flip_cum + nn_hist_flip_cum + 1e-4)
    recall = pp_hist_flip_cum / (gtNum + 1e-4)
    f_measure = (2 * precision * recall) / (precision + recall + 1e-4)

    # maximal F-measure and the threshold that attains it
    max_f_measure, threshold = torch.max(f_measure, dim=0)

    return max_f_measure.item(), threshold.item()


def calculate_meam(image1, image2):
    # histogram equalization
    image1_equalized = cv2.equalizeHist(image1)
    image2_equalized = cv2.equalizeHist(image2)

    # Pearson correlation coefficient between the equalized images
    correlation_coefficient, _ = pearsonr(image1_equalized.flatten(), image2_equalized.flatten())

    # MEAM value
    meam_value = correlation_coefficient * np.mean(np.minimum(image1_equalized, image2_equalized))

    return meam_value


def f1score_torch(pd, gt):
    gtNum = torch.sum((gt > 128).float() * 1)  ## number of ground truth pixels

    pp = pd[gt > 128]
    nn = pd[gt <= 128]

    pp_hist = torch.histc(pp, bins=255, min=0, max=255)
    nn_hist = torch.histc(nn, bins=255, min=0, max=255)

    pp_hist_flip = torch.flipud(pp_hist)
    nn_hist_flip = torch.flipud(nn_hist)

    pp_hist_flip_cum = torch.cumsum(pp_hist_flip, dim=0)
    nn_hist_flip_cum = torch.cumsum(nn_hist_flip, dim=0)

    precision = pp_hist_flip_cum / (pp_hist_flip_cum + nn_hist_flip_cum + 1e-4)
    recall = pp_hist_flip_cum / (gtNum + 1e-4)
    f1 = (1 + 0.3) * precision * recall / (0.3 * precision + recall + 1e-4)  # F-beta with beta^2 = 0.3

    return torch.reshape(precision, (1, precision.shape[0])), torch.reshape(recall, (1, recall.shape[0])), torch.reshape(f1, (1, f1.shape[0]))


def f1_mae_torch(pred, gt, valid_dataset, idx, mybins, hypar):
    if len(gt.shape) > 2:
        gt = gt[:, :, 0]

    pre, rec, f1 = f1score_torch(pred, gt)
    mae = mae_torch(pred, gt)

    if hypar["valid_out_dir"] != "":
        if not os.path.exists(hypar["valid_out_dir"]):
            os.mkdir(hypar["valid_out_dir"])
        dataset_folder = os.path.join(hypar["valid_out_dir"], valid_dataset.dataset["data_name"][idx])
        if not os.path.exists(dataset_folder):
            os.mkdir(dataset_folder)
        io.imsave(os.path.join(dataset_folder, valid_dataset.dataset["im_name"][idx] + ".png"), pred.cpu().data.numpy().astype(np.uint8))

    return pre.cpu().data.numpy(), rec.cpu().data.numpy(), f1.cpu().data.numpy(), mae.cpu().data.numpy()
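A minimal usage sketch of these metrics (not part of the commit; the import path and the dummy tensors are assumptions). Predictions and ground truth are single-channel maps on the 0-255 scale, and f1score_torch returns 255-entry precision/recall/F-beta curves, one value per binarization threshold:

# Illustrative only: assumes basics.py is importable from the working directory.
import torch
from basics import mae_torch, f1score_torch, maximal_f_measure_torch

pred = torch.randint(0, 256, (320, 320)).float()    # hypothetical model output, 0-255
gt = (torch.rand(320, 320) > 0.5).float() * 255.0   # hypothetical binary ground truth

mae = mae_torch(pred, gt)                 # scalar mean absolute error (normalized by 255)
pre, rec, f1 = f1score_torch(pred, gt)    # each of shape (1, 255), one entry per threshold
max_f, thr = maximal_f_measure_torch(pred, gt)
print(mae.item(), f1.max().item(), max_f, thr)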
IS_Net/data_loader.py
ADDED
@@ -0,0 +1,542 @@
## data loader
## Acknowledgement:
## We would like to thank Dr. Ibrahim Almakky (https://scholar.google.co.uk/citations?user=T9MTcK0AAAAJ&hl=en)
## for his help in implementing the cache mechanism of our DIS dataloader.
from __future__ import print_function, division

import numpy as np
import random
from copy import deepcopy
import json
from tqdm import tqdm
from skimage import io
import os
from glob import glob
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torchvision.transforms.functional import normalize
import torch.nn.functional as F
import cv2
from scipy.ndimage import label

def show_gray_images(images, m=4):
    """
    Display a set of grayscale images.

    Args:
        images: an array of shape (n, h, w), where n is the number of images
            and h, w are the image height and width.
        m: number of images shown per row (default 4).

    Returns:
        None
    """
    n, h, w = images.shape       # number, height, and width of the input images
    num_rows = (n + m - 1) // m  # number of rows needed
    fig, axes = plt.subplots(num_rows, m, figsize=(m*2, num_rows*2))  # create the canvas and subplots
    plt.subplots_adjust(wspace=0.05, hspace=0.05)                     # tighten the spacing between subplots
    for i in range(num_rows):
        for j in range(m):
            idx = i*m + j        # index of the current image
            if idx < n:
                axes[i, j].imshow(images[idx], cmap='gray')
                axes[i, j].axis('off')
    plt.show()

#### --------------------- DIS dataloader cache ---------------------####

def segment_connected_components(mask):
    # convert the mask to a PyTorch tensor
    mask_tensor = torch.tensor(mask)

    # find connected components with SciPy's label function
    labeled_array, num_features = label(mask_tensor.numpy())

    # store each connected component's pixels in a dict
    components = {}
    for label_idx in range(1, num_features + 1):
        component_mask = (labeled_array == label_idx)
        components[label_idx] = component_mask.astype(int)

    return components

def FillHole(im_in):
    img = np.array(im_in, dtype=np.uint8)[0]
    mask = np.zeros_like(img)
    contours, _ = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    for contour in contours:
        cv2.drawContours(mask, [contour], -1, 255, thickness=cv2.FILLED)
    im_out = torch.from_numpy(mask)[None, ...].float()
    return im_out

def get_im_gt_name_dict(datasets, flag='valid'):
    print("------------------------------", flag, "--------------------------------")
    name_im_gt_mid_list = []
    for i in range(len(datasets)):
        print("--->>>", flag, " dataset ", i, "/", len(datasets), " ", datasets[i]["name"], "<<<---")
        tmp_im_list, tmp_gt_list, tmp_mid_list = [], [], []
        tmp_im_list = glob(datasets[i]["im_dir"]+os.sep+'*'+datasets[i]["im_ext"])

        if datasets[i]["gt_dir"] == "":
            print('-gt-', datasets[i]["name"], datasets[i]["gt_dir"], ': ', 'No Ground Truth Found')
            tmp_gt_list = []
        else:
            tmp_gt_list = [datasets[i]["gt_dir"]+os.sep+x.split(os.sep)[-1].split(datasets[i]["im_ext"])[0]+datasets[i]["gt_ext"] for x in tmp_im_list]

        if datasets[i]["mid_dir"] == "":
            print('-mid-', datasets[i]["name"], datasets[i]["mid_dir"], ': ', 'No mid Found')
            tmp_mid_list = []
        else:
            tmp_mid_list = [datasets[i]["mid_dir"]+os.sep+x.split(os.sep)[-1].split(datasets[i]["im_ext"])[0]+datasets[i]["mid_ext"] for x in tmp_im_list]

        if flag == "train":  ## combine multiple training sets into one dataset
            if len(name_im_gt_mid_list) == 0:
                name_im_gt_mid_list.append({"dataset_name": datasets[i]["name"],
                                            "im_path": tmp_im_list,
                                            "gt_path": tmp_gt_list,
                                            "mid_path": tmp_mid_list,
                                            "im_ext": datasets[i]["im_ext"],
                                            "gt_ext": datasets[i]["gt_ext"],
                                            "mid_ext": datasets[i]["mid_ext"],
                                            "cache_dir": datasets[i]["cache_dir"]})
            else:
                name_im_gt_mid_list[0]["dataset_name"] = name_im_gt_mid_list[0]["dataset_name"] + "_" + datasets[i]["name"]
                name_im_gt_mid_list[0]["im_path"] = name_im_gt_mid_list[0]["im_path"] + tmp_im_list
                name_im_gt_mid_list[0]["gt_path"] = name_im_gt_mid_list[0]["gt_path"] + tmp_gt_list
                name_im_gt_mid_list[0]["mid_path"] = name_im_gt_mid_list[0]["mid_path"] + tmp_mid_list
                if datasets[i]["im_ext"] != ".jpg" or datasets[i]["gt_ext"] != ".png":
                    print("Error: Please make sure all your images and ground truth masks are in jpg and png format respectively !!!")
                    exit()
                name_im_gt_mid_list[0]["im_ext"] = ".jpg"
                name_im_gt_mid_list[0]["gt_ext"] = ".png"
                name_im_gt_mid_list[0]["mid_ext"] = ".png"
                name_im_gt_mid_list[0]["cache_dir"] = os.sep.join(datasets[i]["cache_dir"].split(os.sep)[0:-1])+os.sep+name_im_gt_mid_list[0]["dataset_name"]
        else:  ## keep different validation or inference datasets as separate ones
            name_im_gt_mid_list.append({"dataset_name": datasets[i]["name"],
                                        "im_path": tmp_im_list,
                                        "gt_path": tmp_gt_list,
                                        "mid_path": tmp_mid_list,
                                        "im_ext": datasets[i]["im_ext"],
                                        "gt_ext": datasets[i]["gt_ext"],
                                        "mid_ext": datasets[i]["mid_ext"],
                                        "cache_dir": datasets[i]["cache_dir"]})

    return name_im_gt_mid_list

def create_dataloaders(name_im_gt_mid_list, cache_size=[], cache_boost=True, my_transforms=[], batch_size=1, shuffle=False, is_train=True):
    ## is_train=True: return one dataloader for training
    ## is_train=False: return a list of dataloaders for validation or testing

    gos_dataloaders = []
    gos_datasets = []

    if len(name_im_gt_mid_list) == 0:
        return gos_dataloaders, gos_datasets

    num_workers_ = 0

    for i in range(0, len(name_im_gt_mid_list)):
        gos_dataset = GOSDatasetCache([name_im_gt_mid_list[i]],
                                      cache_size=cache_size,
                                      cache_path=name_im_gt_mid_list[i]["cache_dir"],
                                      cache_boost=cache_boost,
                                      transform=transforms.Compose(my_transforms),
                                      is_train=is_train)
        gos_dataloaders.append(DataLoader(gos_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers_))
        gos_datasets.append(gos_dataset)

    return gos_dataloaders, gos_datasets

def im_reader(im_path):
    image = Image.open(im_path).convert('RGB')
    corrected_image = ImageOps.exif_transpose(image)  # honor EXIF orientation
    return np.array(corrected_image)

def im_preprocess(im, size):
    if len(im.shape) > 3:
        im = im[:, :, :3]
    if len(im.shape) < 3:
        im = im[:, :, np.newaxis]
    if im.shape[2] == 1:
        im = np.repeat(im, 3, axis=2)
    im_tensor = torch.tensor(im.copy(), dtype=torch.float32)
    im_tensor = torch.transpose(torch.transpose(im_tensor, 1, 2), 0, 1)
    if len(size) < 2:
        return im_tensor, im.shape[0:2]
    else:
        im_tensor = torch.unsqueeze(im_tensor, 0)
        im_tensor = F.upsample(im_tensor, size, mode="bilinear")
        im_tensor = torch.squeeze(im_tensor, 0)

    return im_tensor.type(torch.uint8), im.shape[0:2]

def gt_preprocess(gt, size):
    if len(gt.shape) > 2:
        gt = gt[:, :, 0]

    gt_tensor = torch.unsqueeze(torch.tensor(gt, dtype=torch.uint8), 0)

    if len(size) < 2:
        return gt_tensor.type(torch.uint8), gt.shape[0:2]
    else:
        gt_tensor = torch.unsqueeze(torch.tensor(gt_tensor, dtype=torch.float32), 0)
        gt_tensor = F.upsample(gt_tensor, size, mode="bilinear")
        gt_tensor = torch.squeeze(gt_tensor, 0)

    return gt_tensor.type(torch.uint8), gt.shape[0:2]

class GOSRandomHFlip(object):
    def __init__(self, prob=0.25):
        self.prob = prob

    def __call__(self, sample):
        imidx, image, label, shape, box, mask = sample['imidx'], sample['image'], sample['label'], sample['shape'], sample['box'], sample['mask']

        # random horizontal, vertical, or combined flip
        randomnum = random.random()
        if randomnum <= self.prob:
            image = torch.flip(image, dims=[2])
            label = torch.flip(label, dims=[2])
            box = torch.flip(box, dims=[2])
            mask = torch.flip(mask, dims=[2])
        elif randomnum <= self.prob*2:
            image = torch.flip(image, dims=[1])
            label = torch.flip(label, dims=[1])
            box = torch.flip(box, dims=[1])
            mask = torch.flip(mask, dims=[1])
        elif randomnum <= self.prob*3:
            image = torch.flip(image, dims=[2])
            label = torch.flip(label, dims=[2])
            box = torch.flip(box, dims=[2])
            mask = torch.flip(mask, dims=[2])
            image = torch.flip(image, dims=[1])
            label = torch.flip(label, dims=[1])
            box = torch.flip(box, dims=[1])
            mask = torch.flip(mask, dims=[1])

        return {'imidx': imidx, 'image': image, 'label': label, 'shape': shape, 'mask': mask, 'box': box}

class GOSResize(object):
    def __init__(self, size=[320, 320]):
        self.size = size

    def __call__(self, sample):
        imidx, image, label, shape, box, mask = sample['imidx'], sample['image'], sample['label'], sample['shape'], sample['box'], sample['mask']

        image = torch.squeeze(F.upsample(torch.unsqueeze(image, 0), self.size, mode='bilinear'), dim=0)
        label = torch.squeeze(F.upsample(torch.unsqueeze(label, 0), self.size, mode='bilinear'), dim=0)

        return {'imidx': imidx, 'image': image, 'label': label, 'shape': shape, 'mask': mask, 'box': box}

class GOSRandomCrop(object):
    def __init__(self, size=[288, 288]):
        self.size = size

    def __call__(self, sample):
        imidx, image, label, shape, box, mask = sample['imidx'], sample['image'], sample['label'], sample['shape'], sample['box'], sample['mask']

        h, w = image.shape[1:]
        new_h, new_w = self.size

        top = np.random.randint(0, h - new_h)
        left = np.random.randint(0, w - new_w)

        image = image[:, top:top+new_h, left:left+new_w]
        label = label[:, top:top+new_h, left:left+new_w]

        return {'imidx': imidx, 'image': image, 'label': label, 'shape': shape, 'mask': mask, 'box': box}

class GOSNormalize(object):
    def __init__(self, mean=[0.485, 0.456, 0.406, 0], std=[0.229, 0.224, 0.225, 1.0]):
        self.mean = mean
        self.std = std

    def __call__(self, sample):
        imidx, image, label, shape, box, mask = sample['imidx'], sample['image'], sample['label'], sample['shape'], sample['box'], sample['mask']
        image = normalize(image, self.mean, self.std)
        mask = normalize(mask, 0, 1)
        box = normalize(box, 0, 1)

        return {'imidx': imidx, 'image': image, 'label': label, 'shape': shape, 'mask': mask, 'box': box}

class GOSRandomthorw(object):
    # randomly zero out the mask and/or box guidance channels
    def __init__(self, ratio=0.25):
        self.ratio = ratio

    def __call__(self, sample):
        imidx, image, label, shape, box, mask = sample['imidx'], sample['image'], sample['label'], sample['shape'], sample['box'], sample['mask']
        randomnum = random.random()
        if randomnum < self.ratio:
            mask = torch.zeros_like(mask)
        elif randomnum < self.ratio*2:
            box = torch.zeros_like(box)
        elif randomnum < self.ratio*3:
            mask = torch.zeros_like(mask)
            box = torch.zeros_like(box)

        return {'imidx': imidx, 'image': image, 'label': label, 'shape': shape, 'mask': mask, 'box': box}

class GOSDatasetCache(Dataset):

    def __init__(self, name_im_gt_mid_list, cache_size=[], cache_path='./cache', cache_file_name='dataset.json', cache_boost=False, transform=None, is_train=True):

        self.is_train = is_train
        self.cache_size = cache_size
        self.cache_path = cache_path
        self.cache_file_name = cache_file_name
        self.cache_boost_name = ""

        self.cache_boost = cache_boost

        ## cache all the images and ground truth into a single pytorch tensor
        self.ims_pt = None
        self.gts_pt = None
        self.mid_pt = None

        ## we will cache the per-sample tensors as well regardless of cache_boost
        self.cache_boost_name = cache_file_name.split('.json')[0]

        self.transform = transform

        self.dataset = {}

        ## combine different datasets into one
        dataset_names = []
        dt_name_list = []   # dataset name per image
        im_name_list = []   # image name
        im_path_list = []   # im path
        gt_path_list = []   # gt path
        mid_path_list = []
        im_ext_list = []    # im ext
        gt_ext_list = []    # gt ext
        mid_ext_list = []
        for i in range(0, len(name_im_gt_mid_list)):
            dataset_names.append(name_im_gt_mid_list[i]["dataset_name"])
            # dataset name repeated based on the number of images in this dataset
            dt_name_list.extend([name_im_gt_mid_list[i]["dataset_name"] for x in name_im_gt_mid_list[i]["im_path"]])
            im_name_list.extend([x.split(os.sep)[-1].split(name_im_gt_mid_list[i]["im_ext"])[0] for x in name_im_gt_mid_list[i]["im_path"]])
            im_path_list.extend(name_im_gt_mid_list[i]["im_path"])
            gt_path_list.extend(name_im_gt_mid_list[i]["gt_path"])
            mid_path_list.extend(name_im_gt_mid_list[i]["mid_path"])
            im_ext_list.extend([name_im_gt_mid_list[i]["im_ext"] for x in name_im_gt_mid_list[i]["im_path"]])
            gt_ext_list.extend([name_im_gt_mid_list[i]["gt_ext"] for x in name_im_gt_mid_list[i]["gt_path"]])
            mid_ext_list.extend([name_im_gt_mid_list[i]["mid_ext"] for x in name_im_gt_mid_list[i]["mid_path"]])

        self.dataset["data_name"] = dt_name_list
        self.dataset["im_name"] = im_name_list
        self.dataset["im_path"] = im_path_list
        self.dataset["ori_im_path"] = deepcopy(im_path_list)
        self.dataset["gt_path"] = gt_path_list
        self.dataset["ori_gt_path"] = deepcopy(gt_path_list)
        self.dataset["mid_path"] = mid_path_list
        self.dataset["ori_mid_path"] = deepcopy(mid_path_list)
        self.dataset["im_shp"] = []
        self.dataset["gt_shp"] = []
        self.dataset["mid_shp"] = []
        self.dataset["im_ext"] = im_ext_list
        self.dataset["gt_ext"] = gt_ext_list
        self.dataset["mid_ext"] = mid_ext_list

        self.dataset["ims_pt_dir"] = ""
        self.dataset["gts_pt_dir"] = ""
        self.dataset["mid_pt_dir"] = ""

        self.dataset = self.manage_cache(dataset_names)

    def manage_cache(self, dataset_names):
        if not os.path.exists(self.cache_path):  # create the folder for cache
            os.makedirs(self.cache_path)
        cache_folder = os.path.join(self.cache_path, "_".join(dataset_names)+"_"+"x".join([str(x) for x in self.cache_size]))
        if not os.path.exists(cache_folder):  # check if the cache files are there, if not then cache
            return self.cache(cache_folder)
        return self.load_cache(cache_folder)

    def cache(self, cache_folder):
        os.mkdir(cache_folder)
        cached_dataset = deepcopy(self.dataset)

        ims_pt_list = []
        gts_pt_list = []
        mid_pt_list = []
        for i, im_path in tqdm(enumerate(self.dataset["im_path"]), total=len(self.dataset["im_path"])):

            im_id = cached_dataset["im_name"][i]
            im = im_reader(im_path)
            im, im_shp = im_preprocess(im, self.cache_size)
            im_cache_file = os.path.join(cache_folder, self.dataset["data_name"][i]+"_"+im_id+"_im.pt")
            torch.save(im, im_cache_file)

            cached_dataset["im_path"][i] = im_cache_file
            if self.cache_boost:
                ims_pt_list.append(torch.unsqueeze(im, 0))

            gt = np.zeros(im.shape[0:2])
            if len(self.dataset["gt_path"]) != 0:
                gt = im_reader(self.dataset["gt_path"][i])
            gt, gt_shp = gt_preprocess(gt, self.cache_size)
            gt_cache_file = os.path.join(cache_folder, self.dataset["data_name"][i]+"_"+im_id+"_gt.pt")
            torch.save(gt, gt_cache_file)
            if len(self.dataset["gt_path"]) > 0:
                cached_dataset["gt_path"][i] = gt_cache_file
            else:
                cached_dataset["gt_path"].append(gt_cache_file)
            if self.cache_boost:
                gts_pt_list.append(torch.unsqueeze(gt, 0))

            mid = np.zeros(im.shape[0:2])
            if len(self.dataset["mid_path"]) != 0:
                mid = im_reader(self.dataset["mid_path"][i])
            mid, mid_shp = gt_preprocess(mid, self.cache_size)
            mid_cache_file = os.path.join(cache_folder, self.dataset["data_name"][i]+"_"+im_id+"_mid.pt")
            torch.save(mid, mid_cache_file)
            if len(self.dataset["mid_path"]) > 0:
                cached_dataset["mid_path"][i] = mid_cache_file
            else:
                cached_dataset["mid_path"].append(mid_cache_file)
            if self.cache_boost:
                mid_pt_list.append(torch.unsqueeze(mid, 0))

            cached_dataset["im_shp"].append(im_shp)
            cached_dataset["gt_shp"].append(gt_shp)
            cached_dataset["mid_shp"].append(mid_shp)

        if self.cache_boost:
            cached_dataset["ims_pt_dir"] = os.path.join(cache_folder, self.cache_boost_name+'_ims.pt')
            cached_dataset["gts_pt_dir"] = os.path.join(cache_folder, self.cache_boost_name+'_gts.pt')
            cached_dataset["mid_pt_dir"] = os.path.join(cache_folder, self.cache_boost_name+'_mids.pt')
            self.ims_pt = torch.cat(ims_pt_list, dim=0)
            self.gts_pt = torch.cat(gts_pt_list, dim=0)
            self.mid_pt = torch.cat(mid_pt_list, dim=0)
            torch.save(torch.cat(ims_pt_list, dim=0), cached_dataset["ims_pt_dir"])
            torch.save(torch.cat(gts_pt_list, dim=0), cached_dataset["gts_pt_dir"])
            torch.save(torch.cat(mid_pt_list, dim=0), cached_dataset["mid_pt_dir"])

        try:
            json_file = open(os.path.join(cache_folder, self.cache_file_name), "w")
            json.dump(cached_dataset, json_file)
            json_file.close()
        except Exception:
            raise FileNotFoundError("Cannot create JSON")
        return cached_dataset

    def load_cache(self, cache_folder):
        print(os.path.join(cache_folder, self.cache_file_name))
        json_file = open(os.path.join(cache_folder, self.cache_file_name), "r")
        dataset = json.load(json_file)
        json_file.close()
        ## if cache_boost is true, we will load the concatenated tensors into RAM
        ## otherwise the per-sample pytorch tensors will be loaded on demand
        if self.cache_boost:
            self.ims_pt = torch.load(dataset["ims_pt_dir"], map_location='cpu')
            self.gts_pt = torch.load(dataset["gts_pt_dir"], map_location='cpu')
            self.mid_pt = torch.load(dataset["mid_pt_dir"], map_location='cpu')
        return dataset

    def __len__(self):
        return len(self.dataset["im_path"])

    def __getitem__(self, idx):

        im = None
        gt = None
        mid = None
        if self.cache_boost and self.ims_pt is not None:
            im = self.ims_pt[idx]
            gt = self.gts_pt[idx]
            mid = self.mid_pt[idx]
        else:
            im_pt_path = os.path.join(self.cache_path, os.sep.join(self.dataset["im_path"][idx].split(os.sep)[-2:]))
            im = torch.load(im_pt_path)
            gt_pt_path = os.path.join(self.cache_path, os.sep.join(self.dataset["gt_path"][idx].split(os.sep)[-2:]))
            gt = torch.load(gt_pt_path)
            mid_pt_path = os.path.join(self.cache_path, os.sep.join(self.dataset["mid_path"][idx].split(os.sep)[-2:]))
            mid = torch.load(mid_pt_path)

        im_shp = self.dataset["im_shp"][idx]

        # derive a bounding-box channel from the ground truth
        box = torch.zeros_like(gt[0]) + gt[0]
        rows, cols = torch.where(box > 0)
        left = torch.min(cols)
        top = torch.min(rows)
        right = torch.max(cols)
        bottom = torch.max(rows)
        box[top:bottom, left:right] = 255
        box[box != 255] = 0
        box = box[None, ...]
        gim = torch.cat([im, mid, box], dim=0)  # guided input: RGB + mask + box channels

        im = torch.divide(gim, 255.0)
        gt = torch.divide(gt, 255.0)
        mask = torch.divide(mid, 255.0)
        box = torch.divide(box, 255.0)

        sample = {
            "imidx": torch.from_numpy(np.array(idx)),
            "image": im,
            "label": gt,
            "mask": mask,
            'box': box,
            "shape": torch.from_numpy(np.array(im_shp)),
        }

        if self.transform:
            sample = self.transform(sample)
        return sample
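The __getitem__ above concatenates the RGB image with the mask and a box channel derived from the ground truth; this toy sketch (illustrative only, not part of the commit) replays the same min/max logic on a small mask:

# Illustrative only: how the box guidance channel is derived from a gt mask.
import torch

gt = torch.zeros(8, 8)
gt[2:5, 3:7] = 255                    # a small foreground blob

box = torch.zeros_like(gt) + gt
rows, cols = torch.where(box > 0)
top, bottom = torch.min(rows), torch.max(rows)
left, right = torch.min(cols), torch.max(cols)
box[top:bottom, left:right] = 255     # note: half-open slice, as in the code above
box[box != 255] = 0
print(box.long())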
IS_Net/datalist.py
ADDED
@@ -0,0 +1,62 @@
dataset_test = {"name": "DIS5K-test",
                "im_dir": r"DIS5K/DIS5K-test/im",
                "gt_dir": r"DIS5K/DIS5K-test/gt",
                "mid_dir": r"DIS5K/DIS5K-test/mask",
                "im_ext": ".jpg",
                "gt_ext": ".png",
                "mid_ext": ".png",
                "cache_dir": r"DIS5K-Cache/DIS-test"}

dataset_tr = {"name": "DIS5K-TR-m",
              "im_dir": r"DIS5K/DIS-TR/im",
              "gt_dir": r"DIS5K/DIS-TR/gt",
              "mid_dir": r"DIS5K-TR/mask",
              "im_ext": ".jpg",
              "gt_ext": ".png",
              "mid_ext": ".png",
              "cache_dir": r"DIS5K-Cache/DIS-TR-m"}

dataset_vd = {"name": "DIS5K-VD-m",
              "im_dir": r"DIS5K/DIS-VD/im",
              "gt_dir": r"DIS5K/DIS-VD/gt",
              "mid_dir": r"DIS5K/DIS5K-VD/mask",
              "im_ext": ".jpg",
              "gt_ext": ".png",
              "mid_ext": ".png",
              "cache_dir": r"DIS5K-Cache/DIS-VD-m"}

dataset_te1 = {"name": "DIS5K-TE1-m",
               "im_dir": r"DIS5K/DIS-TE1/im",
               "gt_dir": r"DIS5K/DIS-TE1/gt",
               "mid_dir": r"DIS5K/DIS5K-TE1/mask",
               "im_ext": ".jpg",
               "gt_ext": ".png",
               "mid_ext": ".png",
               "cache_dir": r"DIS5K-Cache/DIS-TE1-m"}

dataset_te2 = {"name": "DIS5K-TE2-m",
               "im_dir": r"DIS5K/DIS-TE2/im",
               "gt_dir": r"DIS5K/DIS-TE2/gt",
               "mid_dir": r"DIS5K/DIS5K-TE2/mask",
               "im_ext": ".jpg",
               "gt_ext": ".png",
               "mid_ext": ".png",
               "cache_dir": r"DIS5K-Cache/DIS-TE2-m"}

dataset_te3 = {"name": "DIS5K-TE3-m",
               "im_dir": r"DIS5K/DIS-TE3/im",
               "gt_dir": r"DIS5K/DIS-TE3/gt",
               "mid_dir": r"DIS5K/DIS5K-TE3/mask",
               "im_ext": ".jpg",
               "gt_ext": ".png",
               "mid_ext": ".png",
               "cache_dir": r"DIS5K-Cache/DIS-TE3-m"}

dataset_te4 = {"name": "DIS5K-TE4-m",
               "im_dir": r"DIS5K/DIS-TE4/im",
               "gt_dir": r"DIS5K/DIS-TE4/gt",
               "mid_dir": r"DIS5K/DIS5K-TE4/mask",
               "im_ext": ".jpg",
               "gt_ext": ".png",
               "mid_ext": ".png",
               "cache_dir": r"DIS5K-Cache/DIS-TE4-m"}
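A hedged sketch (not part of the commit) of how these dicts are presumably consumed by get_im_gt_name_dict and create_dataloaders from IS_Net/data_loader.py; the module import paths, cache size, and the 5-channel normalization statistics are assumptions:

# Illustrative only: wiring a datalist entry into the data_loader.py pipeline.
# Assumes it is run from inside IS_Net/ so the local modules are importable.
from data_loader import get_im_gt_name_dict, create_dataloaders, GOSNormalize
from datalist import dataset_test

name_list = get_im_gt_name_dict([dataset_test], flag="valid")
loaders, datasets = create_dataloaders(
    name_list,
    cache_size=[1024, 1024],   # assumed cached spatial size
    cache_boost=False,         # load per-sample .pt files lazily
    my_transforms=[GOSNormalize(mean=[0.485, 0.456, 0.406, 0, 0],
                                std=[0.229, 0.224, 0.225, 1.0, 1.0])],  # 5 channels
    batch_size=1, shuffle=False, is_train=False)

batch = next(iter(loaders[0]))
print(batch["image"].shape)    # (1, 5, 1024, 1024): RGB + mask + box channels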
IS_Net/models/__pycache__/isnet.cpython-311.pyc
ADDED
Binary file (33.1 kB)
IS_Net/models/isnet.py
ADDED
@@ -0,0 +1,640 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torchvision import models
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from timm.models.layers import trunc_normal_, DropPath
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import monai
|
8 |
+
|
9 |
+
def iou_loss(pred, mask):
|
10 |
+
inter = (pred * mask).sum(dim=(2, 3)) #交集
|
11 |
+
union = (pred + mask).sum(dim=(2, 3)) - inter #并集-交集
|
12 |
+
iou = 1 - (inter + 1) / (union + 1)
|
13 |
+
return iou.mean()
|
14 |
+
|
15 |
+
|
16 |
+
bce_loss = nn.BCELoss(reduction='mean')
|
17 |
+
|
18 |
+
def muti_loss_fusion(preds, target):
|
19 |
+
loss0 = 0.0
|
20 |
+
loss = 0.0
|
21 |
+
|
22 |
+
for i in range(0,len(preds)):
|
23 |
+
# print("i: ", i, preds[i].shape)
|
24 |
+
if(preds[i].shape[2]!=target.shape[2] or preds[i].shape[3]!=target.shape[3]):
|
25 |
+
# tmp_target = _upsample_like(target,preds[i])
|
26 |
+
tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True)
|
27 |
+
loss = loss + 20*bce_loss(preds[i],tmp_target) + 0.5*iou_loss(preds[i],tmp_target)
|
28 |
+
# loss = loss + bce_loss(preds[i],tmp_target)+ iou_loss(preds[i],tmp_target)
|
29 |
+
# loss = loss + bce_loss(preds[i],tmp_target)
|
30 |
+
else:
|
31 |
+
loss = loss + 20*bce_loss(preds[i],target) + 0.5*iou_loss(preds[i],target)
|
32 |
+
# loss = loss + bce_loss(preds[i],target) + iou_loss(preds[i],target)
|
33 |
+
# loss = loss + bce_loss(preds[i],target)
|
34 |
+
if(i==0):
|
35 |
+
loss0 = loss
|
36 |
+
return loss0, loss
|
37 |
+
|
38 |
+
MSE_loss = nn.MSELoss(reduction='mean')
|
39 |
+
kl_loss = nn.KLDivLoss(reduction='mean')
|
40 |
+
l1_loss = nn.L1Loss(reduction='mean')
|
41 |
+
smooth_l1_loss = nn.SmoothL1Loss(reduction='mean')
|
42 |
+
def muti_loss_fusion_kl(preds, target, dfs, fs, mode='MSE'):
|
43 |
+
loss0 = 0.0
|
44 |
+
loss = 0.0
|
45 |
+
|
46 |
+
for i in range(0,len(preds)):
|
47 |
+
# print("i: ", i, preds[i].shape)
|
48 |
+
if(preds[i].shape[2]!=target.shape[2] or preds[i].shape[3]!=target.shape[3]):
|
49 |
+
# tmp_target = _upsample_like(target,preds[i])
|
50 |
+
tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True)
|
51 |
+
loss = loss + 20*bce_loss(preds[i],tmp_target) + 0.5*iou_loss(preds[i],tmp_target)
|
52 |
+
# loss = loss + bce_loss(preds[i],tmp_target) + iou_loss(preds[i],tmp_target)
|
53 |
+
# loss = loss + bce_loss(preds[i],tmp_target)
|
54 |
+
else:
|
55 |
+
loss = loss + 20*bce_loss(preds[i],target) + 0.5*iou_loss(preds[i],target)
|
56 |
+
# loss = loss + bce_loss(preds[i],target) + iou_loss(preds[i],target)
|
57 |
+
# loss = loss + bce_loss(preds[i],target)
|
58 |
+
if(i==0):
|
59 |
+
loss0 = loss
|
60 |
+
|
61 |
+
for i in range(0,len(dfs)):
|
62 |
+
if(mode=='MSE'):
|
63 |
+
loss = loss + MSE_loss(dfs[i],fs[i]) ### add the mse loss of features as additional constraints
|
64 |
+
# print("fea_loss: ", fea_loss(dfs[i],fs[i]).item())
|
65 |
+
elif(mode=='KL'):
|
66 |
+
loss = loss + kl_loss(F.log_softmax(dfs[i],dim=1),F.softmax(fs[i],dim=1))
|
67 |
+
# print("kl_loss: ", kl_loss(F.log_softmax(dfs[i],dim=1),F.softmax(fs[i],dim=1)).item())
|
68 |
+
elif(mode=='MAE'):
|
69 |
+
loss = loss + l1_loss(dfs[i],fs[i])
|
70 |
+
# print("ls_loss: ", l1_loss(dfs[i],fs[i]))
|
71 |
+
elif(mode=='SmoothL1'):
|
72 |
+
loss = loss + smooth_l1_loss(dfs[i],fs[i])
|
73 |
+
# print("SmoothL1: ", smooth_l1_loss(dfs[i],fs[i]).item())
|
74 |
+
|
75 |
+
return loss0, loss
|
76 |
+
|
77 |
+
class REBNCONV(nn.Module):
|
78 |
+
def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
|
79 |
+
super(REBNCONV,self).__init__()
|
80 |
+
|
81 |
+
self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate,stride=stride)
|
82 |
+
self.bn_s1 = nn.BatchNorm2d(out_ch)
|
83 |
+
self.relu_s1 = nn.ReLU(inplace=True)
|
84 |
+
|
85 |
+
def forward(self,x):
|
86 |
+
|
87 |
+
hx = x
|
88 |
+
xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
|
89 |
+
|
90 |
+
return xout
|
91 |
+
|
92 |
+
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
|
93 |
+
def _upsample_like(src,tar):
|
94 |
+
|
95 |
+
src = F.upsample(src,size=tar.shape[2:],mode='bilinear')
|
96 |
+
|
97 |
+
return src
|
98 |
+
|
99 |
+
|
100 |
+
### RSU-7 ###
|
101 |
+
class RSU7(nn.Module):
|
102 |
+
|
103 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
|
104 |
+
super(RSU7,self).__init__()
|
105 |
+
|
106 |
+
self.in_ch = in_ch
|
107 |
+
self.mid_ch = mid_ch
|
108 |
+
self.out_ch = out_ch
|
109 |
+
|
110 |
+
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) ## 1 -> 1/2
|
111 |
+
|
112 |
+
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
|
113 |
+
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
114 |
+
|
115 |
+
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
116 |
+
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
117 |
+
|
118 |
+
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
119 |
+
self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
120 |
+
|
121 |
+
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
122 |
+
self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
123 |
+
|
124 |
+
self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
125 |
+
self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
126 |
+
|
127 |
+
self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
128 |
+
|
129 |
+
self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
|
130 |
+
|
131 |
+
self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
132 |
+
self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
133 |
+
self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
134 |
+
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
135 |
+
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
136 |
+
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
|
137 |
+
|
138 |
+
def forward(self,x):
|
139 |
+
b, c, h, w = x.shape
|
140 |
+
|
141 |
+
hx = x
|
142 |
+
hxin = self.rebnconvin(hx)
|
143 |
+
|
144 |
+
hx1 = self.rebnconv1(hxin)
|
145 |
+
hx = self.pool1(hx1)
|
146 |
+
|
147 |
+
hx2 = self.rebnconv2(hx)
|
148 |
+
hx = self.pool2(hx2)
|
149 |
+
|
150 |
+
hx3 = self.rebnconv3(hx)
|
151 |
+
hx = self.pool3(hx3)
|
152 |
+
|
153 |
+
hx4 = self.rebnconv4(hx)
|
154 |
+
hx = self.pool4(hx4)
|
155 |
+
|
156 |
+
hx5 = self.rebnconv5(hx)
|
157 |
+
hx = self.pool5(hx5)
|
158 |
+
|
159 |
+
hx6 = self.rebnconv6(hx)
|
160 |
+
|
161 |
+
hx7 = self.rebnconv7(hx6)
|
162 |
+
|
163 |
+
hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
|
164 |
+
hx6dup = _upsample_like(hx6d,hx5)
|
165 |
+
|
166 |
+
hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
|
167 |
+
hx5dup = _upsample_like(hx5d,hx4)
|
168 |
+
|
169 |
+
hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
|
170 |
+
hx4dup = _upsample_like(hx4d,hx3)
|
171 |
+
|
172 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
|
173 |
+
hx3dup = _upsample_like(hx3d,hx2)
|
174 |
+
|
175 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
|
176 |
+
hx2dup = _upsample_like(hx2d,hx1)
|
177 |
+
|
178 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
|
179 |
+
|
180 |
+
return hx1d + hxin
|
181 |
+
|
182 |
+
|
183 |
+
### RSU-6 ###
|
184 |
+
class RSU6(nn.Module):
|
185 |
+
|
186 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
187 |
+
super(RSU6,self).__init__()
|
188 |
+
|
189 |
+
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
|
190 |
+
|
191 |
+
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
|
192 |
+
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
193 |
+
|
194 |
+
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
195 |
+
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
196 |
+
|
197 |
+
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
198 |
+
self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
199 |
+
|
200 |
+
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
201 |
+
self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
202 |
+
|
203 |
+
self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
204 |
+
|
205 |
+
self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
|
206 |
+
|
207 |
+
self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
208 |
+
self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
209 |
+
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
210 |
+
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
211 |
+
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
|
212 |
+
|
213 |
+
def forward(self,x):
|
214 |
+
|
215 |
+
hx = x
|
216 |
+
|
217 |
+
hxin = self.rebnconvin(hx)
|
218 |
+
|
219 |
+
hx1 = self.rebnconv1(hxin)
|
220 |
+
hx = self.pool1(hx1)
|
221 |
+
|
222 |
+
hx2 = self.rebnconv2(hx)
|
223 |
+
hx = self.pool2(hx2)
|
224 |
+
|
225 |
+
hx3 = self.rebnconv3(hx)
|
226 |
+
hx = self.pool3(hx3)
|
227 |
+
|
228 |
+
hx4 = self.rebnconv4(hx)
|
229 |
+
hx = self.pool4(hx4)
|
230 |
+
|
231 |
+
hx5 = self.rebnconv5(hx)
|
232 |
+
|
233 |
+
hx6 = self.rebnconv6(hx5)
|
234 |
+
|
235 |
+
|
236 |
+
hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
|
237 |
+
hx5dup = _upsample_like(hx5d,hx4)
|
238 |
+
|
239 |
+
hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
|
240 |
+
hx4dup = _upsample_like(hx4d,hx3)
|
241 |
+
|
242 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
|
243 |
+
hx3dup = _upsample_like(hx3d,hx2)
|
244 |
+
|
245 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
|
246 |
+
hx2dup = _upsample_like(hx2d,hx1)
|
247 |
+
|
248 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
|
249 |
+
|
250 |
+
return hx1d + hxin
|
251 |
+
|
252 |
+
### RSU-5 ###
|
253 |
+
class RSU5(nn.Module):
|
254 |
+
|
255 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
256 |
+
super(RSU5,self).__init__()
|
257 |
+
|
258 |
+
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
|
259 |
+
|
260 |
+
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
|
261 |
+
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
262 |
+
|
263 |
+
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
264 |
+
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
265 |
+
|
266 |
+
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
267 |
+
self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
268 |
+
|
269 |
+
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
270 |
+
|
271 |
+
self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
|
272 |
+
|
273 |
+
self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
274 |
+
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
275 |
+
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
276 |
+
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
|
277 |
+
|
278 |
+
def forward(self,x):
|
279 |
+
|
280 |
+
hx = x
|
281 |
+
|
282 |
+
hxin = self.rebnconvin(hx)
|
283 |
+
|
284 |
+
hx1 = self.rebnconv1(hxin)
|
285 |
+
hx = self.pool1(hx1)
|
286 |
+
|
287 |
+
hx2 = self.rebnconv2(hx)
|
288 |
+
hx = self.pool2(hx2)
|
289 |
+
|
290 |
+
hx3 = self.rebnconv3(hx)
|
291 |
+
hx = self.pool3(hx3)
|
292 |
+
|
293 |
+
hx4 = self.rebnconv4(hx)
|
294 |
+
|
295 |
+
hx5 = self.rebnconv5(hx4)
|
296 |
+
|
297 |
+
hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
|
298 |
+
hx4dup = _upsample_like(hx4d,hx3)
|
299 |
+
|
300 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
|
301 |
+
hx3dup = _upsample_like(hx3d,hx2)
|
302 |
+
|
303 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
|
304 |
+
hx2dup = _upsample_like(hx2d,hx1)
|
305 |
+
|
306 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
|
307 |
+
|
308 |
+
return hx1d + hxin
|
309 |
+
|
310 |
+
### RSU-4 ###
|
311 |
+
class RSU4(nn.Module):
|
312 |
+
|
313 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
314 |
+
super(RSU4,self).__init__()
|
315 |
+
|
316 |
+
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
|
317 |
+
|
318 |
+
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
|
319 |
+
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
320 |
+
|
321 |
+
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
322 |
+
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
|
323 |
+
|
324 |
+
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
|
325 |
+
|
326 |
+
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
|
327 |
+
|
328 |
+
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
329 |
+
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
330 |
+
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
|
331 |
+
|
332 |
+
def forward(self,x):
|
333 |
+
|
334 |
+
hx = x
|
335 |
+
|
336 |
+
hxin = self.rebnconvin(hx)
|
337 |
+
|
338 |
+
hx1 = self.rebnconv1(hxin)
|
339 |
+
hx = self.pool1(hx1)
|
340 |
+
|
341 |
+
hx2 = self.rebnconv2(hx)
|
342 |
+
hx = self.pool2(hx2)
|
343 |
+
|
344 |
+
hx3 = self.rebnconv3(hx)
|
345 |
+
|
346 |
+
hx4 = self.rebnconv4(hx3)
|
347 |
+
|
348 |
+
hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
|
349 |
+
hx3dup = _upsample_like(hx3d,hx2)
|
350 |
+
|
351 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
|
352 |
+
hx2dup = _upsample_like(hx2d,hx1)
|
353 |
+
|
354 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
|
355 |
+
|
356 |
+
return hx1d + hxin
|
357 |
+
|
358 |
+
### RSU-4F ###
class RSU4F(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)

    def forward(self, x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))

        return hx1d + hxin

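
Unlike RSU-4 through RSU-7, RSU-4F drops the pooling/upsampling pairs and widens the receptive field purely through dilation (dirate 1, 2, 4, 8), so every intermediate tensor stays at the input resolution; this is why the networks below use it only at their deepest, lowest-resolution stages. A quick sketch of that property (illustrative only):

import torch

block = RSU4F(in_ch=512, mid_ch=256, out_ch=512)
x = torch.randn(1, 512, 16, 16)
print(block(x).shape)  # torch.Size([1, 512, 16, 16]) -- no pooling anywhere
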
class myrebnconv(nn.Module):
    def __init__(self, in_ch=3,
                 out_ch=1,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 dilation=1,
                 groups=1):
        super(myrebnconv, self).__init__()

        self.conv = nn.Conv2d(in_ch,
                              out_ch,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              dilation=dilation,
                              groups=groups)
        self.bn = nn.BatchNorm2d(out_ch)
        self.rl = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.rl(self.bn(self.conv(x)))

class ISNetGTEncoder(nn.Module):

    def __init__(self, in_ch=1, out_ch=1):
        super(ISNetGTEncoder, self).__init__()

        self.conv_in = myrebnconv(in_ch, 16, 3, stride=2, padding=1)  # nn.Conv2d(in_ch,64,3,stride=2,padding=1)

        self.stage1 = RSU7(16, 16, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 16, 64)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(64, 32, 128)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(128, 32, 256)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(256, 64, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 64, 512)

        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

    def compute_loss(self, preds, targets):
        return muti_loss_fusion(preds, targets)

    def forward(self, x):

        hx = x

        hxin = self.conv_in(hx)
        # hx = self.pool_in(hxin)

        # stage 1
        hx1 = self.stage1(hxin)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)

        # side outputs, upsampled to the input resolution
        d1 = self.side1(hx1)
        d1 = _upsample_like(d1, x)

        d2 = self.side2(hx2)
        d2 = _upsample_like(d2, x)

        d3 = self.side3(hx3)
        d3 = _upsample_like(d3, x)

        d4 = self.side4(hx4)
        d4 = _upsample_like(d4, x)

        d5 = self.side5(hx5)
        d5 = _upsample_like(d5, x)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, x)

        # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))

        return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3),
                torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], [hx1, hx2, hx3, hx4, hx5, hx6]

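
The GT encoder consumes 1-channel ground-truth masks and returns six sigmoid side outputs plus the six stage feature maps; in IS-Net-style training the latter serve as targets for the segmentation network's decoder features (the dfs/fs arguments of compute_loss_kl below). A minimal sketch, assuming muti_loss_fusion is defined earlier in this file:

import torch

gt_net = ISNetGTEncoder(in_ch=1, out_ch=1).eval()
masks = (torch.rand(2, 1, 512, 512) > 0.5).float()  # a batch of binary masks
with torch.no_grad():
    side_preds, stage_feats = gt_net(masks)
print(len(side_preds), len(stage_feats))  # 6 6
print(side_preds[0].shape)                # torch.Size([2, 1, 512, 512])
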
class ISNetDIS(nn.Module):

    def __init__(self, in_ch=3, out_ch=1):
        super(ISNetDIS, self).__init__()

        self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
        self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # defined but not used in forward

        self.stage1 = RSU7(64, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

        # self.outconv = nn.Conv2d(6*out_ch, out_ch, 1)

    def compute_loss_kl(self, preds, targets, dfs, fs, mode='MSE'):
        # return muti_loss_fusion(preds,targets)
        return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode)

    def compute_loss(self, preds, targets):
        return muti_loss_fusion(preds, targets)

    def forward(self, x):

        hx = x

        hxin = self.conv_in(hx)

        # stage 1
        hx1 = self.stage1(hxin)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)

        hx6up = _upsample_like(hx6, hx5)

        # -------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat([hx6up, hx5], 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat([hx5dup, hx4], 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat([hx4dup, hx3], 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat([hx3dup, hx2], 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat([hx2dup, hx1], 1))

        # side outputs, upsampled to the input resolution
        d1 = self.side1(hx1d)
        d1 = _upsample_like(d1, x)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, x)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, x)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, x)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, x)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, x)

        # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
        # plt.imshow(hx1d[0][0].cpu().detach().numpy(), cmap='gray'); plt.show()
        # plt.imshow(hx2d[0][0].cpu().detach().numpy(), cmap='gray'); plt.show()
        # plt.imshow(hx3d[0][0].cpu().detach().numpy(), cmap='gray'); plt.show()
        # plt.imshow(hx4d[0][0].cpu().detach().numpy(), cmap='gray'); plt.show()
        # plt.imshow(hx5d[0][0].cpu().detach().numpy(), cmap='gray'); plt.show()
        # plt.imshow(hx6[0][0].cpu().detach().numpy(), cmap='gray'); plt.show()
        return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3),
                torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]

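
A minimal end-to-end sketch of running the segmentation network (the checkpoint path and shapes are illustrative; the real pipeline also resizes and normalizes the input image):

import torch

net = ISNetDIS(in_ch=3, out_ch=1).eval()
# net.load_state_dict(torch.load("saved_models/isnet.pth", map_location="cpu"))  # hypothetical path
image = torch.randn(1, 3, 1024, 1024)  # a normalized RGB batch
with torch.no_grad():
    side_preds, decoder_feats = net(image)
mask = side_preds[0]  # d1: the finest side output, already passed through sigmoid
print(mask.shape)     # torch.Size([1, 1, 1024, 1024])
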
IS_Net/saliency_toolbox.py
ADDED
@@ -0,0 +1,552 @@
import os
import cv2
import sys
import numpy as np
from glob import glob
from tqdm import tqdm
from scipy.ndimage import correlate
from scipy.ndimage import distance_transform_edt
from joblib import Parallel, delayed

eps = sys.float_info.epsilon


def calcualte_once(gt_name, sm_dir, gt_threshold, beta, measures):
    values = dict()
    for idx in measures:
        values[idx] = list()
        if idx == 'Max-F':
            values['Precision'] = list()
            values['Recall'] = list()
    _, name = os.path.split(gt_name)
    sm_name = os.path.join(sm_dir, name)

    if os.path.exists(sm_name):

        gt, sm = read_and_normalize(gt_name, sm_name, gt_threshold)

        if 'MAE' in measures:
            values['MAE'].append(mean_square_error(gt, sm))
        if 'E-measure' in measures:
            values['E-measure'].append(e_measure(gt, sm))
        if 'S-measure' in measures:
            values['S-measure'].append(s_measure(gt, sm))
        if 'Adp-F' in measures:
            values['Adp-F'].append(adaptive_fmeasure(gt, sm, beta))
        if 'Wgt-F' in measures:
            values['Wgt-F'].append(weighted_fmeasure(gt, sm))
        if 'Max-F' in measures:
            prec, recall = prec_recall(gt, sm, 256)  # 256 thresholds between 0 and 1
            values['Precision'].append(prec)
            values['Recall'].append(recall)
    else:
        print("\n{} not found!".format(os.path.basename(sm_name)))
    print('---' * 10)
    return values


def calculate_measures(gt_dir, sm_dir, measures, save=False, beta=np.sqrt(0.3), gt_threshold=0.5, n_thread=1):
    """
    Calculates saliency measures for the given directories.

    Parameters
    ----------
    gt_dir : str
        The path to the ground truth directory
    sm_dir : str
        The path to the predicted saliency map directory
    measures : list
        List of the measure names to calculate.
        Supported measures: 'MAE'       => Mean Absolute Error
                            'E-measure' => Enhanced-alignment measure
                            'S-measure' => Structure-measure
                            'Max-F'     => Maximum F-measure
                            'Adp-F'     => Adaptive F-measure
                            'Wgt-F'     => Weighted F-measure
    save : str
        If specified, the results will be saved in the 'save' directory
    beta : float
        The beta parameter used in the F-measure formula; the default is sqrt(0.3)
    gt_threshold : float
        The threshold used to binarize the ground truth maps.

    Returns
    -------
    values : dictionary
        a dict containing the results
    """

    values = dict()
    for idx in measures:
        values[idx] = list()
        if idx == 'Max-F':
            values['Precision'] = list()
            values['Recall'] = list()

    results = Parallel(n_jobs=n_thread)(
        delayed(calcualte_once)(gt_name, sm_dir, gt_threshold, beta, measures)
        for gt_name in tqdm(glob(os.path.join(gt_dir, '*')), total=len(glob(os.path.join(gt_dir, '*')))))
    for i in results:
        if 'MAE' in measures:
            values['MAE'].append(i["MAE"])
        if 'E-measure' in measures:
            values['E-measure'].append(i["E-measure"])
        if 'S-measure' in measures:
            values['S-measure'].append(i["S-measure"])
        if 'Adp-F' in measures:
            values['Adp-F'].append(i["Adp-F"])
        if 'Wgt-F' in measures:
            values['Wgt-F'].append(i["Wgt-F"])
        if 'Max-F' in measures:  # 256 thresholds between 0 and 1
            values['Precision'].append(i["Precision"])
            values['Recall'].append(i["Recall"])

    if 'MAE' in measures:
        values['MAE'] = np.mean(values['MAE'])

    if 'E-measure' in measures:
        values['E-measure'] = np.mean(values['E-measure'])

    if 'S-measure' in measures:
        values['S-measure'] = np.mean(values['S-measure'])

    if 'Adp-F' in measures:
        values['Adp-F'] = np.mean(values['Adp-F'])

    if 'Wgt-F' in measures:
        values['Wgt-F'] = np.mean(values['Wgt-F'])

    if 'Max-F' in measures:
        values['Precision'] = np.mean(np.hstack(values['Precision'][:]), 1)
        values['Recall'] = np.mean(np.hstack(values['Recall'][:]), 1)
        f_measures = (1 + beta ** 2) * values['Precision'] * values['Recall'] / (
            beta ** 2 * values['Precision'] + values['Recall'])
        values['Fmeasure_all_thresholds'] = f_measures
        values['Max-F'] = np.max(f_measures)

    if save:
        if not os.path.isdir(save):
            os.mkdir(save)
        for key in values.keys():
            np.save(os.path.join(save, key + ".npy"), values[key])

    return values


def read_and_normalize(gt_path, sm_path, gt_threshold=0.5):
    """
    Reads, normalizes and resizes a ground truth map and a saliency map.

    Parameters
    ----------
    gt_path : str
        The path to a ground truth map
    sm_path : str
        The path to a predicted saliency map
    gt_threshold : float
        The threshold used to binarize the ground truth map.

    Returns
    -------
    gt_img, sm_img : numpy.ndarray
        The prepared arrays
    """
    gt_img = norm_img(cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE))
    gt_img = (gt_img >= gt_threshold).astype(np.float32)
    sm_img = norm_img(cv2.imread(sm_path, cv2.IMREAD_GRAYSCALE))
    if sm_img.shape[0] != gt_img.shape[0] or sm_img.shape[1] != gt_img.shape[1]:
        sm_img = cv2.resize(sm_img, (gt_img.shape[1], gt_img.shape[0]))

    return gt_img, sm_img


def norm_img(im):
    return cv2.normalize(im.astype('float'),
                         None,
                         0.0, 1.0,
                         cv2.NORM_MINMAX)


# MAE (note: despite its name, this function computes the mean absolute error)
def mean_square_error(gt, sm):
    return np.mean(np.abs(sm - gt))


# E-measure
# article: https://arxiv.org/abs/1805.10421
# original code [Matlab]: https://github.com/DengPingFan/E-measure
def e_measure(gt, sm):
    """
    Computes the Enhanced-alignment Measure (E-measure) between the saliency map and the ground truth.
    article: https://arxiv.org/abs/1805.10421
    original code [Matlab]: https://github.com/DengPingFan/E-measure

    Parameters
    ----------
    gt : numpy.ndarray
        The binarized ground truth map
    sm : numpy.ndarray
        The predicted saliency map

    Returns
    -------
    value : float
        The calculated E-measure
    """
    sm = adptive_binary(sm)

    gt = gt.astype(np.bool_)
    sm = sm.astype(np.bool_)

    dgt = gt.astype(np.float32)
    dsm = sm.astype(np.float32)

    if np.sum(dgt) == 0:  # if the gt is completely black
        enhanced_matrix = 1.0 - dsm  # only calculate the black area of the intersection
    elif np.mean(dgt) == 1:  # if the gt is completely white
        enhanced_matrix = dsm  # only calculate the white area of the intersection
    else:
        # Normal case:
        # 1. compute the alignment matrix
        align_matrix = alignment_term(dsm, dgt)
        # 2. compute the enhanced alignment matrix
        enhanced_matrix = enhanced_alignment_term(align_matrix)

    height, width = gt.shape
    value = np.sum(enhanced_matrix) / (height * width - 1 + eps)
    return value


def alignment_term(dgt, dsm):
    # compute the global means
    mu_fm = np.mean(dsm)
    mu_gt = np.mean(dgt)

    # compute the bias matrices
    align_fm = dsm - mu_fm
    align_gt = dgt - mu_gt

    # compute the alignment matrix
    align_Matrix = 2 * (align_gt * align_fm) / (align_gt * align_gt + align_fm * align_fm + eps)
    return align_Matrix


def enhanced_alignment_term(align_matrix):
    enhanced = ((align_matrix + 1) ** 2) / 4
    return enhanced


def adptive_binary(sm):
    adaptive_threshold = 2 * np.mean(sm)

    if adaptive_threshold > 1:
        adaptive_threshold = 1

    binary_sm = (sm >= adaptive_threshold).astype(np.float32)

    return binary_sm


# S-measure
# article: https://www.crcv.ucf.edu/papers/iccv17/1164.pdf
# Matlab code: https://github.com/DengPingFan/S-measure
def s_measure(gt, sm):
    """
    Computes the structural similarity (S-measure) between the saliency map and the ground truth.
    article: https://www.crcv.ucf.edu/papers/iccv17/1164.pdf
    original code [Matlab]: https://github.com/DengPingFan/S-measure

    Parameters
    ----------
    gt : numpy.ndarray
        The binarized ground truth map
    sm : numpy.ndarray
        The predicted saliency map

    Returns
    -------
    value : float
        The calculated S-measure
    """
    gt_mean = np.mean(gt)

    if gt_mean == 0:  # if the GT is completely black
        sm_mean = np.mean(sm)
        measure = 1.0 - sm_mean  # only calculate the area of the intersection
    elif gt_mean == 1:  # if the GT is completely white
        sm_mean = np.mean(sm)
        measure = sm_mean.copy()  # only calculate the area of the intersection
    else:
        alpha = 0.5
        measure = alpha * s_object(sm, gt) + (1 - alpha) * s_region(sm, gt)
        if measure < 0:
            measure = 0

    return measure


def ssim(gt, sm):
    gt = gt.astype(np.float32)

    height, width = sm.shape
    num_pixels = width * height

    # Compute the means of SM and GT
    sm_mean = np.mean(sm)
    gt_mean = np.mean(gt)

    # Compute the variances of SM and GT
    sigma_x2 = np.sum(np.sum((sm - sm_mean) ** 2)) / (num_pixels - 1 + eps)
    sigma_y2 = np.sum(np.sum((gt - gt_mean) ** 2)) / (num_pixels - 1 + eps)

    # Compute the covariance
    sigma_xy = np.sum(np.sum((sm - sm_mean) * (gt - gt_mean))) / (num_pixels - 1 + eps)

    alpha = 4 * sm_mean * gt_mean * sigma_xy
    beta = (sm_mean ** 2 + gt_mean ** 2) * (sigma_x2 + sigma_y2)

    if alpha != 0:
        ssim_value = alpha / (beta + eps)
    elif alpha == 0 and beta == 0:
        ssim_value = 1.0
    else:
        ssim_value = 0

    return ssim_value


def divide_sm(sm, x, y):
    # copy the 4 regions
    lt = sm[:y, :x]
    rt = sm[:y, x:]
    lb = sm[y:, :x]
    rb = sm[y:, x:]

    return lt, rt, lb, rb


def divide_gt(gt, x, y):
    height, width = gt.shape
    area = width * height

    # copy the 4 regions
    lt = gt[:y, :x]
    rt = gt[:y, x:]
    lb = gt[y:, :x]
    rb = gt[y:, x:]

    # the weight of each block is proportional to its share of the total area
    w1 = (x * y) / area
    w2 = ((width - x) * y) / area
    w3 = (x * (height - y)) / area
    w4 = 1.0 - w1 - w2 - w3

    return lt, rt, lb, rb, w1, w2, w3, w4


def centroid(gt):
    # compute the centroid of the map; fall back to the image center if it is empty
    rows, cols = gt.shape

    if np.sum(gt) == 0:
        x = np.round(cols / 2)
        y = np.round(rows / 2)
    else:
        total = np.sum(gt)
        i = np.arange(cols).reshape(1, cols) + 1
        j = np.arange(rows).reshape(rows, 1) + 1

        x = int(np.round(np.sum(np.sum(gt, 0, keepdims=True) * i) / total))
        y = int(np.round(np.sum(np.sum(gt, 1, keepdims=True) * j) / total))

    return x, y


def s_region(gt, sm):
    x, y = centroid(gt)
    gt_1, gt_2, gt_3, gt_4, w1, w2, w3, w4 = divide_gt(gt, x, y)

    sm_1, sm_2, sm_3, sm_4 = divide_sm(sm, x, y)

    q1 = ssim(sm_1, gt_1)
    q2 = ssim(sm_2, gt_2)
    q3 = ssim(sm_3, gt_3)
    q4 = ssim(sm_4, gt_4)

    region_value = w1 * q1 + w2 * q2 + w3 * q3 + w4 * q4

    return region_value


def object(gt, sm):
    x = np.mean(sm[gt == 1])
    # compute the standard deviation of the foreground or background in sm
    sigma_x = np.std(sm[gt == 1])
    score = 2.0 * x / (x ** 2 + 1.0 + sigma_x + eps)
    return score


def s_object(gt, sm):
    # compute the similarity of the foreground at the object level

    sm_fg = sm.copy()
    sm_fg[gt == 0] = 0
    o_fg = object(sm_fg, gt)

    # compute the similarity of the background
    sm_bg = 1.0 - sm.copy()
    sm_bg[gt == 1] = 0
    o_bg = object(sm_bg, gt == 0)

    u = np.mean(gt)
    object_value = u * o_fg + (1 - u) * o_bg
    return object_value


# Weighted F-measure
# article: https://ieeexplore.ieee.org/document/6909433
# Matlab code: https://cgm.technion.ac.il/Computer-Graphics-Multimedia/Software/FGEval/
def weighted_fmeasure(gt, sm, beta2=1):
    """
    Computes the Weighted F-measure between the saliency map and the ground truth.
    article: https://ieeexplore.ieee.org/document/6909433
    original code [Matlab]: https://cgm.technion.ac.il/Computer-Graphics-Multimedia/Software/FGEval/

    Parameters
    ----------
    gt : numpy.ndarray
        The binarized ground truth map
    sm : numpy.ndarray
        The predicted saliency map

    Returns
    -------
    value : float
        The calculated Weighted F-measure
    """
    dst, idx = distance_transform_edt(1 - gt, return_indices=True)

    raw_idx = idx[0][gt == 0]
    col_idx = idx[1][gt == 0]

    e = np.abs(sm - gt).astype(np.float32)
    et = np.abs(sm - gt).astype(np.float32)

    et[gt == 0] = et[raw_idx, col_idx]

    k = matlab_style_gauss2d(shape=(7, 7), sigma=5)

    ea = correlate(et.astype(np.float32), k, mode='constant')
    min_e_ea = np.abs(sm - gt).astype(np.float32)

    min_e_ea[gt * (ea < e) == 1] = ea[gt * (ea < e) == 1]

    b = np.ones_like(gt).astype(np.float32)
    b[gt == 0] = 2 - 1 * np.exp(np.log(1 - 0.5) / 5. * dst[gt == 0])

    ew = min_e_ea * b
    tpw = np.sum(gt) - np.sum(ew[gt == 1])
    fpw = np.sum(ew[gt == 0])

    rec = 1 - np.mean(ew[gt == 1])   # weighted recall
    prec = tpw / (eps + tpw + fpw)   # weighted precision

    value = (1 + beta2) * (rec * prec) / (eps + (beta2 * rec) + prec)
    return value


def matlab_style_gauss2d(shape=(3, 3), sigma=0.5):
    """
    2D gaussian mask - should give the same result as MATLAB's
    fspecial('gaussian', [shape], [sigma])
    """
    m, n = [(ss - 1.) / 2. for ss in shape]
    y, x = np.ogrid[-m:m + 1, -n:n + 1]
    h = np.exp(-(x * x + y * y) / (2. * sigma * sigma))
    h[h < np.finfo(h.dtype).eps * h.max()] = 0
    sumh = h.sum()
    if sumh != 0:
        h /= sumh
    return h


# Adaptive F-measure

def adaptive_fmeasure(gt, sm, beta):
    """
    Computes the Adaptive F-measure between the saliency map and the ground truth,
    using the adaptive binarization proposed in:
    https://ieeexplore.ieee.org/document/5206596

    Parameters
    ----------
    gt : numpy.ndarray
        The binarized ground truth map
    sm : numpy.ndarray
        The predicted saliency map

    Returns
    -------
    value : float
        The calculated Adaptive F-measure
    """
    gt_idx = np.where(gt > 0)
    gt_cnt = np.sum(gt)

    if gt_cnt == 0:
        # NOTE: with an empty ground truth, prec/recall stay empty lists and the
        # final formula below is undefined; callers are expected to pass non-empty GTs
        prec = []
        recall = []
    else:
        adaptive_threshold = 2 * np.mean(sm)
        if adaptive_threshold > 1:
            adaptive_threshold = 1
        sm_binary = (sm >= adaptive_threshold).astype(np.float32)
        hit_cnt = np.sum(sm_binary[gt_idx])
        alg_cnt = np.sum(sm_binary)

        if hit_cnt == 0:
            prec = 0
            recall = 0
        else:
            prec = hit_cnt / (alg_cnt + eps)
            recall = hit_cnt / gt_cnt
    value = (1 + beta ** 2) * prec * recall / ((beta ** 2 * prec + recall) + eps)
    return value


def prec_recall(gt, sm, num_th):
    """
    Computes the precision and recall of the saliency map against the ground truth
    at num_th evenly spaced thresholds, following the protocol of:
    https://ieeexplore.ieee.org/document/5206596
    The results are used to calculate the Max-F measure and to plot PR and F-threshold curves.

    Parameters
    ----------
    gt : numpy.ndarray
        The binarized ground truth map
    sm : numpy.ndarray
        The predicted saliency map
    num_th : integer
        The total number of thresholds between 0 and 1

    Returns
    -------
    prec, recall: numpy.ndarray
        The calculated Precision and Recall (shape: (num_th, 1))
    """
    gt_idx = np.where(gt > 0)
    gt_cnt = np.sum(gt)

    if gt_cnt == 0:
        prec = []
        recall = []
    else:
        hit_cnt = np.zeros((num_th, 1), np.float32)
        alg_cnt = np.zeros((num_th, 1), np.float32)
        thresholds = np.linspace(0, 1, num_th)
        for k, curTh in enumerate(thresholds):
            sm_binary = (sm >= curTh).astype(np.float32)
            hit_cnt[k] = np.sum(sm_binary[gt_idx])
            alg_cnt[k] = np.sum(sm_binary)

        prec = hit_cnt / (alg_cnt + eps)
        recall = hit_cnt / gt_cnt

    return prec, recall
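
Hypothetical usage of the toolbox (directory names are placeholders; each returned measure is a scalar, except the Precision/Recall/Fmeasure_all_thresholds arrays that accompany Max-F):

if __name__ == '__main__':
    res = calculate_measures(
        gt_dir='IS_Net/DIS5K/DIS5K-test/gt',
        sm_dir='your_results_dir',  # placeholder
        measures=['MAE', 'S-measure', 'Max-F'],
        n_thread=4)
    print('MAE: %.4f  S: %.4f  maxF: %.4f' % (res['MAE'], res['S-measure'], res['Max-F']))
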
IS_Net/swd_optim/__init__.py
ADDED
@@ -0,0 +1,10 @@
from .adai import Adai
from .adais import AdaiS
from .adams import AdamS
from .sgds import SGDS

# Remove the submodule names so only the optimizer classes are exported.
del adai
del adais
del adams
del sgds
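
With the package laid out this way, callers import the optimizer classes directly from swd_optim (the import root depends on how IS_Net is placed on sys.path):

from swd_optim import Adai, AdaiS, AdamS, SGDS  # assuming the IS_Net directory is on sys.path
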
IS_Net/swd_optim/adai.py
ADDED
@@ -0,0 +1,116 @@
import torch
from torch.optim.optimizer import Optimizer, required


class Adai(Optimizer):
    r"""Implements the Adaptive Inertia Estimation (Adai) algorithm.
    It has been proposed in
    `Adai: Separating the Effects of Adaptive Learning Rate and Momentum Inertia`__.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        betas (Tuple[float, float], optional): beta0 and beta2 (default: (0.1, 0.99))
        eps (float, optional): the inertia bound (default: 1e-03)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
    """

    def __init__(self, params, lr=required, betas=(0.1, 0.99), eps=1e-03,
                 weight_decay=0):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0]:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(Adai, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(Adai, self).__setstate__(state)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        param_size = 0
        exp_avg_sq_hat_sum = 0.

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                param_size += p.numel()
                grad = p.grad.data

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
                    # Cumulative products of beta1
                    state['beta1_prod'] = torch.ones_like(p.data, memory_format=torch.preserve_format)

                state['step'] += 1

                exp_avg_sq = state['exp_avg_sq']
                beta0, beta2 = group['betas']

                bias_correction2 = 1 - beta2 ** state['step']

                if group['weight_decay'] != 0:
                    grad.add_(p.data, alpha=group['weight_decay'])

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                exp_avg_sq_hat_sum += exp_avg_sq.sum() / bias_correction2

        # Calculate the mean of all elements in exp_avg_sq_hat
        exp_avg_sq_hat_mean = exp_avg_sq_hat_sum / param_size

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data

                state = self.state[p]

                exp_avg = state['exp_avg']
                exp_avg_sq = state['exp_avg_sq']
                beta1_prod = state['beta1_prod']
                beta0, beta2 = group['betas']

                bias_correction2 = 1 - beta2 ** state['step']

                exp_avg_sq_hat = exp_avg_sq / bias_correction2
                # Per-parameter inertia: a small second moment yields a large beta1 (more inertia)
                beta1 = (1. - (exp_avg_sq_hat / exp_avg_sq_hat_mean).mul(beta0)).clamp(0., 1 - group['eps'])

                beta1_prod.mul_(beta1)
                bias_correction1 = 1 - beta1_prod

                exp_avg.mul_(beta1).addcmul_(1 - beta1, grad)
                exp_avg_hat = exp_avg / bias_correction1

                step_size = group['lr']
                p.data.add_(exp_avg_hat, alpha=-step_size)

        return loss
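
A drop-in usage sketch (the model and hyperparameters are illustrative placeholders):

import torch
import torch.nn as nn

model = nn.Linear(10, 1)
opt = Adai(model.parameters(), lr=0.1, betas=(0.1, 0.99), eps=1e-3, weight_decay=5e-4)
loss = model(torch.randn(8, 10)).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()
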
IS_Net/swd_optim/adais.py
ADDED
@@ -0,0 +1,120 @@
import torch
from torch.optim.optimizer import Optimizer, required


class AdaiS(Optimizer):
    r"""Implements Adai with stable/decoupled weight decay (AdaiS/AdaiW).
    It is based on
    `Adai: Separating the Effects of Adaptive Learning Rate and Momentum Inertia`
    and
    `Stable Weight Decay Regularization`__.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate
        betas (Tuple[float, float], optional): beta0 and beta2 (default: (0.1, 0.99))
        eps (float, optional): the inertia bound (default: 1e-03)
        weight_decay (float, optional): weight decay (default: 0)
    """

    def __init__(self, params, lr=required, betas=(0.1, 0.99), eps=1e-03,
                 weight_decay=0):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0]:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(AdaiS, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdaiS, self).__setstate__(state)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        param_size = 0
        exp_avg_sq_hat_sum = 0.
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                param_size += p.numel()
                grad = p.grad.data

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
                    # Cumulative products of beta1
                    state['beta1_prod'] = torch.ones_like(p.data, memory_format=torch.preserve_format)

                exp_avg_sq = state['exp_avg_sq']
                beta0, beta2 = group['betas']

                state['step'] += 1
                bias_correction2 = 1 - beta2 ** state['step']

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                exp_avg_sq_hat = exp_avg_sq / bias_correction2

                exp_avg_sq_hat_sum += exp_avg_sq_hat.sum()

        # Calculate the mean of all elements in exp_avg_sq_hat
        exp_avg_sq_hat_mean = exp_avg_sq_hat_sum / param_size

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data

                # Perform stable/decoupled weight decay
                if group['weight_decay'] != 0:
                    p.data.mul_(1 - group['lr'] * group['weight_decay'])

                state = self.state[p]

                exp_avg = state['exp_avg']
                exp_avg_sq = state['exp_avg_sq']
                beta0, beta2 = group['betas']
                beta1_prod = state['beta1_prod']
                bias_correction2 = 1 - beta2 ** state['step']

                exp_avg_sq_hat = exp_avg_sq / bias_correction2

                beta1 = (1. - (exp_avg_sq_hat / exp_avg_sq_hat_mean).mul(beta0)).clamp(0., 1 - group['eps'])

                beta1_prod.mul_(beta1)
                bias_correction1 = 1 - beta1_prod

                exp_avg.mul_(beta1).addcmul_(1 - beta1, grad)
                exp_avg_hat = exp_avg.div(bias_correction1)

                step_size = group['lr']
                p.data.add_(exp_avg_hat, alpha=-step_size)

        return loss
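
Functionally, AdaiS differs from Adai only in where weight decay enters the update; the per-parameter inertia (beta1) machinery is identical. A schematic comparison (construction reuses the hypothetical model from the Adai sketch above):

# Adai : grad <- grad + weight_decay * p       (L2 penalty folded into the gradient)
# AdaiS: p    <- p * (1 - lr * weight_decay)   (decoupled decay applied to the weights)
opt = AdaiS(model.parameters(), lr=0.1, betas=(0.1, 0.99), weight_decay=5e-4)
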
IS_Net/swd_optim/adams.py
ADDED
@@ -0,0 +1,137 @@
import math
import torch
from torch.optim.optimizer import Optimizer


class AdamS(Optimizer):
    r"""Implements the Adam with stable weight decay (AdamS) algorithm.
    It has been proposed in
    `Stable Weight Decay Regularization`__.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 1e-4)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=1e-4, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(AdamS, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamS, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        param_size = 0
        exp_avg_sq_hat_sum = 0.

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                param_size += p.numel()

                # Perform optimization step
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('AdamS does not support sparse gradients')
                amsgrad = group['amsgrad']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                beta1, beta2 = group['betas']
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                state['step'] += 1
                bias_correction2 = 1 - beta2 ** state['step']

                # Decay the first and second moment running average coefficients
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                    # Maintains the maximum of all 2nd moment running averages so far
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing the running avg. of the gradient
                    exp_avg_sq_hat = max_exp_avg_sq / bias_correction2
                else:
                    exp_avg_sq_hat = exp_avg_sq / bias_correction2

                exp_avg_sq_hat_sum += exp_avg_sq_hat.sum()

        # Calculate the sqrt of the mean of all elements in exp_avg_sq_hat
        exp_avg_mean_sqrt = math.sqrt(exp_avg_sq_hat_sum / param_size)

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                # Perform stable weight decay
                if group['weight_decay'] != 0:
                    p.data.mul_(1 - group['weight_decay'] * group['lr'] / exp_avg_mean_sqrt)

                beta1, beta2 = group['betas']
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                    exp_avg_sq_hat = max_exp_avg_sq / bias_correction2
                else:
                    exp_avg_sq_hat = exp_avg_sq / bias_correction2

                denom = exp_avg_sq_hat.sqrt().add(group['eps'])

                step_size = group['lr'] / bias_correction1
                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
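
The "stable" part of AdamS is that the decoupled decay is divided by exp_avg_mean_sqrt, the square root of the mean bias-corrected second moment over all parameters, so schematically p <- p * (1 - lr * wd / sqrt(mean(v_hat))) and the effective decay strength tracks the current gradient scale. Illustrative construction (model as in the earlier Adai sketch):

opt = AdamS(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-4)
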
IS_Net/swd_optim/sgds.py
ADDED
@@ -0,0 +1,82 @@
import torch
from torch.optim.optimizer import Optimizer, required


class SGDS(Optimizer):
    r"""Implements stochastic gradient descent with stable weight decay (SGDS).
    It has been proposed in
    `Stable Weight Decay Regularization`__.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
    """

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGDS, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SGDS, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad

                # Perform stable weight decay
                if group['weight_decay'] != 0:
                    bias_correction = (1 - dampening) / (1 - momentum)
                    p.data.mul_(1 - bias_correction * group['lr'] * group['weight_decay'])

                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                p.add_(d_p, alpha=-group['lr'])

        return loss
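
Here the decay is rescaled by (1 - dampening) / (1 - momentum), which is the steady-state gain of the momentum buffer, so SGDS with momentum decays weights at a rate comparable to plain SGD. Illustrative construction (model as in the earlier Adai sketch):

opt = SGDS(model.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-4)
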
IS_Net/train_valid_inference_main.py
ADDED
@@ -0,0 +1,729 @@
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import numpy as np
|
4 |
+
from skimage import io
|
5 |
+
import time
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import torch, gc
|
8 |
+
import torch.nn as nn
|
9 |
+
from torch.autograd import Variable
|
10 |
+
import torch.optim as optim
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from data_loader import get_im_gt_name_dict, create_dataloaders, GOSRandomHFlip, GOSResize, GOSRandomCrop, GOSNormalize #GOSDatasetCache,
|
13 |
+
# from data_loader_cache import get_im_gt_name_dict, create_dataloaders, GOSRandomHFlip, GOSResize, GOSRandomCrop, GOSNormalize #GOSDatasetCache,
|
14 |
+
from basics import f1_mae_torch #normPRED, GOSPRF1ScoresCache,f1score_torch,
|
15 |
+
from models.isnet import ISNetGTEncoder, ISNetDIS
|
16 |
+
from torch.cuda.amp import autocast, GradScaler
|
17 |
+
from datalist import *
|
18 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
19 |
+
|
20 |
+
def get_gt_encoder(train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar, train_dataloaders_val, train_datasets_val): #model_path, model_save_fre, max_ite=1000000):
|
21 |
+
|
22 |
+
torch.manual_seed(hypar["seed"])
|
23 |
+
if torch.cuda.is_available():
|
24 |
+
torch.cuda.manual_seed(hypar["seed"])
|
25 |
+
|
26 |
+
print("define gt encoder ...")
|
27 |
+
net = ISNetGTEncoder() #UNETGTENCODERCombine()
|
28 |
+
# if(hypar["model_digit"]=="half"):
|
29 |
+
# net.half()
|
30 |
+
## load the existing model gt encoder
|
31 |
+
if(hypar["gt_encoder_model"]!=""):
|
32 |
+
model_path = hypar["model_path"]+"/"+hypar["gt_encoder_model"]
|
33 |
+
if torch.cuda.is_available():
|
34 |
+
net.load_state_dict(torch.load(model_path))
|
35 |
+
net.cuda()
|
36 |
+
else:
|
37 |
+
net.load_state_dict(torch.load(model_path,map_location="cpu"))
|
38 |
+
print("gt encoder restored from the saved weights ...")
|
39 |
+
return net ############
|
40 |
+
|
41 |
+
if torch.cuda.is_available():
|
42 |
+
net.cuda()
|
43 |
+
|
44 |
+
print("--- define optimizer for GT Encoder---")
|
45 |
+
# optimizer = lion.Lion(net.parameters(), lr=1e-4, betas=(0.9, 0.99))
|
46 |
+
optimizer = optim.AdamW(net.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-8, weight_decay=0)
|
47 |
+
# optimizer = optim.SGD(net.parameters(), lr=1e-4)
|
48 |
+
|
49 |
+
model_path = hypar["model_path"]
|
50 |
+
model_save_fre = hypar["model_save_fre"]
|
51 |
+
max_ite = hypar["max_ite"]
|
52 |
+
batch_size_train = hypar["batch_size_train"]
|
53 |
+
batch_size_valid = hypar["batch_size_valid"]
|
54 |
+
|
55 |
+
if(not os.path.exists(model_path)):
|
56 |
+
os.mkdir(model_path)
|
57 |
+
|
58 |
+
ite_num = hypar["start_ite"] # count the total iteration number
|
59 |
+
ite_num4val = 0 #
|
60 |
+
running_loss = 0.0 # count the toal loss
|
61 |
+
running_tar_loss = 0.0 # count the target output loss
|
62 |
+
last_f1 = [0 for x in range(len(valid_dataloaders))]
|
63 |
+
|
64 |
+
train_num = train_datasets[0].__len__()
|
65 |
+
|
66 |
+
net.train()
|
67 |
+
|
68 |
+
start_last = time.time()
|
69 |
+
gos_dataloader = train_dataloaders[0]
|
70 |
+
epoch_num = hypar["max_epoch_num"]
|
71 |
+
notgood_cnt = 0
|
72 |
+
for epoch in range(epoch_num): ## set the epoch num as 100000
|
73 |
+
|
74 |
+
for i, data in enumerate(gos_dataloader):
|
75 |
+
|
76 |
+
if(ite_num >= max_ite):
|
77 |
+
print("Training Reached the Maximal Iteration Number ", max_ite)
|
78 |
+
exit()
|
79 |
+
|
80 |
+
# start_read = time.time()
|
81 |
+
ite_num = ite_num + 1
|
82 |
+
ite_num4val = ite_num4val + 1
|
83 |
+
|
84 |
+
# get the inputs
|
85 |
+
labels = data['label']
|
86 |
+
|
87 |
+
if(hypar["model_digit"]=="full"):
|
88 |
+
labels = labels.type(torch.FloatTensor)
|
89 |
+
else:
|
90 |
+
labels = labels.type(torch.HalfTensor)
|
91 |
+
|
92 |
+
# wrap them in Variable
|
93 |
+
if torch.cuda.is_available():
|
94 |
+
labels_v = Variable(labels.cuda(), requires_grad=False)
|
95 |
+
else:
|
96 |
+
labels_v = Variable(labels, requires_grad=False)
|
97 |
+
|
98 |
+
# print("time lapse for data preparation: ", time.time()-start_read, ' s')
|
99 |
+
|
100 |
+
# zero the parameter gradients
|
101 |
+
start_inf_loss_back = time.time()
|
102 |
+
optimizer.zero_grad()
|
103 |
+
|
104 |
+
# plt.imshow(labels_v[0][0].cpu(),cmap='gray')
|
105 |
+
# plt.show()
|
106 |
+
# with autocast():
|
107 |
+
ds, fs = net(labels_v)#net(inputs_v)
|
108 |
+
loss2, loss = net.compute_loss(ds, labels_v)
|
109 |
+
# scaler.scale(loss).backward()
|
110 |
+
# loss.backward()
|
111 |
+
# scaler.step(optimizer)
|
112 |
+
# scaler.update()
|
113 |
+
#ORTHO Loss
|
114 |
+
reg = 1e-8
|
115 |
+
orth_loss = torch.zeros(1).to(device)
|
116 |
+
for name, param in net.named_parameters():
|
117 |
+
if 'bias' not in name:
|
118 |
+
param_flat = param.view(param.shape[0], -1)
|
119 |
+
sym = torch.mm(param_flat, torch.t(param_flat))
|
120 |
+
sym -= torch.eye(param_flat.shape[0]).to(param.device)
|
121 |
+
orth_loss = orth_loss + (reg * sym.abs().sum())
|
122 |
+
loss = loss + orth_loss
|
123 |
+
loss.backward()
|
124 |
+
optimizer.step()
|
125 |
+
|
126 |
+
running_loss += loss.item()
|
127 |
+
running_tar_loss += loss2.item()
|
128 |
+
|
129 |
+
# del outputs, loss
|
130 |
+
del ds, loss2, loss
|
131 |
+
end_inf_loss_back = time.time()-start_inf_loss_back
|
132 |
+
|
133 |
+
print("GT Encoder Training>>>"+model_path.split('/')[-1]+" - [epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f, time-per-iter: %3f s, time_read: %3f" % (
|
134 |
+
epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val, time.time()-start_last, time.time()-start_last-end_inf_loss_back))
|
135 |
+
start_last = time.time()
|
136 |
+
|
137 |
+
if ite_num % model_save_fre == 0: # validate every model_save_fre iterations
|
138 |
+
notgood_cnt += 1
|
139 |
+
net.eval()
|
140 |
+
tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch)
|
141 |
+
# tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid_gt_encoder(net, train_dataloaders_val, train_datasets_val, hypar, epoch)
|
142 |
+
|
143 |
+
net.train() # resume train
|
144 |
+
|
145 |
+
tmp_out = 0
|
146 |
+
print("last_f1:",last_f1,np.mean(last_f1))
|
147 |
+
print("tmp_f1:",tmp_f1,np.mean(tmp_f1))
|
148 |
+
# for fi in range(len(last_f1)):
|
149 |
+
if(np.mean(tmp_f1)>np.mean(last_f1)):
|
150 |
+
tmp_out = 1
|
151 |
+
print("tmp_out:",tmp_out)
|
152 |
+
if(tmp_out):
|
153 |
+
notgood_cnt = 0
|
154 |
+
last_f1 = tmp_f1
|
155 |
+
tmp_f1_str = [str(round(f1x,4)) for f1x in tmp_f1]
|
156 |
+
tmp_mae_str = [str(round(mx,4)) for mx in tmp_mae]
|
157 |
+
maxf1 = '_'.join(tmp_f1_str)
|
158 |
+
meanM = '_'.join(tmp_mae_str)
|
159 |
+
# .cpu().detach().numpy()
|
160 |
+
model_name = "/GTENCODER-gpu_itr_"+str(ite_num)+\
|
161 |
+
"_traLoss_"+str(np.round(running_loss / ite_num4val,4))+\
|
162 |
+
"_traTarLoss_"+str(np.round(running_tar_loss / ite_num4val,4))+\
|
163 |
+
"_valLoss_"+str(np.round(val_loss /(i_val+1),4))+\
|
164 |
+
"_valTarLoss_"+str(np.round(tar_loss /(i_val+1),4)) + \
|
165 |
+
"_maxF1_" + maxf1 + \
|
166 |
+
"_mae_" + meanM + \
|
167 |
+
"_time_" + str(np.round(np.mean(np.array(tmp_time))/batch_size_valid,6))+".pth"
|
168 |
+
torch.save(net.state_dict(), model_path + model_name)
|
169 |
+
|
170 |
+
running_loss = 0.0
|
171 |
+
running_tar_loss = 0.0
|
172 |
+
ite_num4val = 0
|
173 |
+
|
174 |
+
if(np.mean(tmp_f1)>0.99):
|
175 |
+
print("GT encoder is well-trained and obtained...")
|
176 |
+
return net
|
177 |
+
|
178 |
+
if(notgood_cnt >= hypar["early_stop"]):
|
179 |
+
print("No improvements in the last "+str(notgood_cnt)+" validation periods, so training stopped !")
|
180 |
+
exit()
|
181 |
+
print("Training Reaches The Maximum Epoch Number")
|
182 |
+
return net
|
183 |
+
|
184 |
+
def valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
|
185 |
+
net.eval()
|
186 |
+
print("Validating...")
|
187 |
+
epoch_num = hypar["max_epoch_num"]
|
188 |
+
|
189 |
+
val_loss = 0.0
|
190 |
+
tar_loss = 0.0
|
191 |
+
|
192 |
+
|
193 |
+
tmp_f1 = []
|
194 |
+
tmp_mae = []
|
195 |
+
tmp_time = []
|
196 |
+
|
197 |
+
start_valid = time.time()
|
198 |
+
for k in range(len(valid_dataloaders)):
|
199 |
+
|
200 |
+
valid_dataloader = valid_dataloaders[k]
|
201 |
+
valid_dataset = valid_datasets[k]
|
202 |
+
|
203 |
+
val_num = valid_dataset.__len__()
|
204 |
+
mybins = np.arange(0,256)
|
205 |
+
PRE = np.zeros((val_num,len(mybins)-1))
|
206 |
+
REC = np.zeros((val_num,len(mybins)-1))
|
207 |
+
F1 = np.zeros((val_num,len(mybins)-1))
|
208 |
+
MAE = np.zeros((val_num))
|
209 |
+
|
210 |
+
val_cnt = 0.0
|
211 |
+
i_val = None
|
212 |
+
|
213 |
+
for i_val, data_val in enumerate(valid_dataloader):
|
214 |
+
|
215 |
+
# imidx_val, inputs_val, labels_val, shapes_val = data_val['imidx'], data_val['image'], data_val['label'], data_val['shape']
|
216 |
+
imidx_val, labels_val, shapes_val = data_val['imidx'], data_val['label'], data_val['shape']
|
217 |
+
if(hypar["model_digit"]=="full"):
|
218 |
+
labels_val = labels_val.type(torch.FloatTensor)
|
219 |
+
else:
|
220 |
+
labels_val = labels_val.type(torch.HalfTensor)
|
221 |
+
|
222 |
+
# wrap them in Variable
|
223 |
+
if torch.cuda.is_available():
|
224 |
+
labels_val_v = Variable(labels_val.cuda(), requires_grad=False)
|
225 |
+
else:
|
226 |
+
labels_val_v = Variable(labels_val,requires_grad=False)
|
227 |
+
# with autocast():
|
228 |
+
t_start = time.time()
|
229 |
+
ds_val = net(labels_val_v)[0]
|
230 |
+
t_end = time.time()-t_start
|
231 |
+
tmp_time.append(t_end)
|
232 |
+
|
233 |
+
# loss2_val, loss_val = muti_loss_fusion(ds_val, labels_val_v)
|
234 |
+
loss2_val, loss_val = net.compute_loss(ds_val, labels_val_v)
|
235 |
+
|
236 |
+
# compute F measure
|
237 |
+
for t in range(hypar["batch_size_valid"]):
|
238 |
+
val_cnt = val_cnt + 1.0
|
239 |
+
print("num of val: ", val_cnt)
|
240 |
+
i_test = imidx_val[t].data.numpy()
|
241 |
+
|
242 |
+
pred_val = ds_val[0][t,:,:,:].float() # B x 1 x H x W
|
243 |
+
|
244 |
+
## recover the prediction spatial size to the original image size
|
245 |
+
pred_val = torch.squeeze(F.interpolate(torch.unsqueeze(pred_val,0),(shapes_val[t][0],shapes_val[t][1]),mode='bilinear')) # F.upsample is deprecated; F.interpolate is the drop-in replacement
|
246 |
+
|
247 |
+
ma = torch.max(pred_val)
|
248 |
+
mi = torch.min(pred_val)
|
249 |
+
pred_val = (pred_val-mi)/(ma-mi) # max = 1
|
250 |
+
# pred_val = normPRED(pred_val)
|
251 |
+
|
252 |
+
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
|
253 |
+
if gt.max()==1:
|
254 |
+
gt=gt*255
|
255 |
+
with torch.no_grad():
|
256 |
+
gt = torch.tensor(gt).to(device)
|
257 |
+
|
258 |
+
pre,rec,f1,mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar)
|
259 |
+
|
260 |
+
PRE[i_test,:]=pre
|
261 |
+
REC[i_test,:] = rec
|
262 |
+
F1[i_test,:] = f1
|
263 |
+
MAE[i_test] = mae
|
264 |
+
|
265 |
+
del ds_val, gt
|
266 |
+
gc.collect()
|
267 |
+
torch.cuda.empty_cache()
|
268 |
+
|
269 |
+
# if(loss_val.data[0]>1):
|
270 |
+
val_loss += loss_val.item()#data[0]
|
271 |
+
tar_loss += loss2_val.item()#data[0]
|
272 |
+
|
273 |
+
print("[validating: %5d/%5d] val_ls:%f, tar_ls: %f, f1: %f, mae: %f, time: %f"% (i_val, val_num, val_loss / (i_val + 1), tar_loss / (i_val + 1), np.amax(F1[i_test,:]), MAE[i_test],t_end))
|
274 |
+
|
275 |
+
del loss2_val, loss_val
|
276 |
+
|
277 |
+
print('============================')
|
278 |
+
PRE_m = np.mean(PRE,0)
|
279 |
+
REC_m = np.mean(REC,0)
|
280 |
+
f1_m = (1+0.3)*PRE_m*REC_m/(0.3*PRE_m+REC_m+1e-8)
|
281 |
+
# print('--------------:', np.mean(f1_m))
|
282 |
+
tmp_f1.append(np.amax(f1_m))
|
283 |
+
tmp_mae.append(np.mean(MAE))
|
284 |
+
print("The max F1 Score: %f"%(np.max(f1_m)))
|
285 |
+
print("MAE: ", np.mean(MAE))
|
286 |
+
|
287 |
+
# print('[epoch: %3d/%3d, ite: %5d] tra_ls: %3f, val_ls: %3f, tar_ls: %3f, maxf1: %3f, val_time: %6f'% (epoch + 1, epoch_num, ite_num, running_loss / ite_num4val, val_loss/val_cnt, tar_loss/val_cnt, tmp_f1[-1], time.time()-start_valid))
|
288 |
+
|
289 |
+
return tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time
|
290 |
+
|
291 |
+
def train(net, optimizer, train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar,train_dataloaders_val, train_datasets_val): #model_path, model_save_fre, max_ite=1000000):
|
292 |
+
|
293 |
+
if hypar["interm_sup"]:
|
294 |
+
print("Get the gt encoder ...")
|
295 |
+
featurenet = get_gt_encoder(train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar,train_dataloaders_val, train_datasets_val)
|
296 |
+
## freeze the weights of gt encoder
|
297 |
+
for param in featurenet.parameters():
|
298 |
+
param.requires_grad=False
|
299 |
+
|
300 |
+
# scaler = GradScaler()
|
301 |
+
model_path = hypar["model_path"]
|
302 |
+
model_save_fre = hypar["model_save_fre"]
|
303 |
+
max_ite = hypar["max_ite"]
|
304 |
+
batch_size_train = hypar["batch_size_train"]
|
305 |
+
batch_size_valid = hypar["batch_size_valid"]
|
306 |
+
|
307 |
+
if(not os.path.exists(model_path)):
|
308 |
+
os.mkdir(model_path)
|
309 |
+
|
310 |
+
ite_num = hypar["start_ite"] # count the toal iteration number
|
311 |
+
ite_num4val = 0 # iterations since the last validation
|
312 |
+
running_loss = 0.0 # count the total loss
|
313 |
+
running_tar_loss = 0.0 # count the target output loss
|
314 |
+
last_mae = [1 for x in range(len(valid_dataloaders))]
|
315 |
+
last_f1 = [0 for x in range(len(valid_dataloaders))]
|
316 |
+
|
317 |
+
train_num = train_datasets[0].__len__()
|
318 |
+
|
319 |
+
net.train()
|
320 |
+
|
321 |
+
start_last = time.time()
|
322 |
+
gos_dataloader = train_dataloaders[0]
|
323 |
+
epoch_num = hypar["max_epoch_num"]
|
324 |
+
notgood_cnt = 0
|
325 |
+
for epoch in range(epoch_num): ## set the epoch num as 100000
|
326 |
+
|
327 |
+
for i, data in enumerate(gos_dataloader):
|
328 |
+
|
329 |
+
if(ite_num >= max_ite):
|
330 |
+
print("Training Reached the Maximal Iteration Number ", max_ite)
|
331 |
+
exit()
|
332 |
+
|
333 |
+
# start_read = time.time()
|
334 |
+
ite_num = ite_num + 1
|
335 |
+
ite_num4val = ite_num4val + 1
|
336 |
+
|
337 |
+
# get the inputs
|
338 |
+
inputs, labels = data['image'], data['label']
|
339 |
+
locations = data['location_blocks']
|
340 |
+
if(hypar["model_digit"]=="full"):
|
341 |
+
inputs = inputs.type(torch.FloatTensor)
|
342 |
+
labels = labels.type(torch.FloatTensor)
|
343 |
+
locations = locations.type(torch.FloatTensor)
|
344 |
+
else:
|
345 |
+
inputs = inputs.type(torch.HalfTensor)
|
346 |
+
labels = labels.type(torch.HalfTensor)
|
347 |
+
locations = locations.type(torch.HalfTensor)
|
348 |
+
|
349 |
+
# wrap them in Variable
|
350 |
+
if torch.cuda.is_available():
|
351 |
+
inputs_v, labels_v = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(), requires_grad=False)
|
352 |
+
locations_v = Variable(locations.cuda(), requires_grad=False)
|
353 |
+
else:
|
354 |
+
inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False)
|
355 |
+
locations_v = Variable(locations, requires_grad=False)
|
356 |
+
|
357 |
+
# print("time lapse for data preparation: ", time.time()-start_read, ' s')
|
358 |
+
|
359 |
+
# zero the parameter gradients
|
360 |
+
start_inf_loss_back = time.time()
|
361 |
+
optimizer.zero_grad()
|
362 |
+
if hypar["interm_sup"]:
|
363 |
+
# with autocast():
|
364 |
+
# forward + backward + optimize
|
365 |
+
_,fs = featurenet(labels_v)
|
366 |
+
ds,dfs = net(inputs_v)
|
367 |
+
## extract the gt encodings
|
368 |
+
loss2, loss = net.compute_loss_kl(ds, labels_v, dfs, fs, mode='MSE')
|
369 |
+
# loss2, loss = net.compute_loss_kl(ds, labels_v, dfs, fs, mode='cosin')
|
370 |
+
# print(next(featurenet.parameters()).dtype,next(net.parameters()).dtype,labels_v.dtype,fs[0][0].dtype)
|
371 |
+
# print(ds[0][0].dtype,dfs[0][0].dtype)
|
372 |
+
# print(loss2.dtype,loss.dtype)
|
373 |
+
else:
|
374 |
+
# with autocast():
|
375 |
+
# forward + backward + optimize
|
376 |
+
ds,_ = net(inputs_v)
|
377 |
+
loss2, loss = net.compute_loss(ds, labels_v)
|
378 |
+
# loss.backward()
|
379 |
+
# with torch.autograd.detect_anomaly():
|
380 |
+
# scaler.scale(loss).backward()
|
381 |
+
#ORTHO Loss
|
382 |
+
reg = 1e-8
|
383 |
+
orth_loss = torch.zeros(1).to(device)
|
384 |
+
for name, param in net.named_parameters():
|
385 |
+
if 'bias' not in name:
|
386 |
+
param_flat = param.view(param.shape[0], -1)
|
387 |
+
sym = torch.mm(param_flat, torch.t(param_flat))
|
388 |
+
sym -= torch.eye(param_flat.shape[0]).to(device)
|
389 |
+
orth_loss = orth_loss + (reg * sym.abs().sum())
|
390 |
+
loss = loss + orth_loss
|
391 |
+
loss.backward()
|
392 |
+
# scaler.step(optimizer)
|
393 |
+
# scaler.update()
|
394 |
+
optimizer.step()
|
395 |
+
# torch.cuda.empty_cache()
|
396 |
+
|
397 |
+
# # print statistics
|
398 |
+
running_loss += loss.item()
|
399 |
+
running_tar_loss += loss2.item()
|
400 |
+
|
401 |
+
# del outputs, loss
|
402 |
+
del ds, loss2, loss
|
403 |
+
end_inf_loss_back = time.time()-start_inf_loss_back
|
404 |
+
|
405 |
+
print(">>>"+model_path.split('/')[-1]+" - [epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f, time-per-iter: %3f s, time_read: %3f" % (
|
406 |
+
epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val, time.time()-start_last, time.time()-start_last-end_inf_loss_back))
|
407 |
+
start_last = time.time()
|
408 |
+
|
409 |
+
if ite_num % model_save_fre == 0: # validate every model_save_fre iterations
|
410 |
+
notgood_cnt += 1
|
411 |
+
net.eval()
|
412 |
+
tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid(net, valid_dataloaders, valid_datasets, hypar, epoch)
|
413 |
+
torch.cuda.empty_cache()
|
414 |
+
net.train() # resume train
|
415 |
+
|
416 |
+
tmp_out = 0
|
417 |
+
print("last_f1:",last_f1,np.mean(last_f1))
|
418 |
+
print("tmp_f1:",tmp_f1,np.mean(tmp_f1))
|
419 |
+
if np.mean(tmp_mae)<np.mean(last_mae):
|
420 |
+
last_mae = tmp_mae
|
421 |
+
tmp_out = 1
|
422 |
+
if np.mean(tmp_f1)>np.mean(last_f1):
|
423 |
+
last_f1 = tmp_f1
|
424 |
+
tmp_out = 1
|
425 |
+
print("tmp_out:",tmp_out)
|
426 |
+
if(tmp_out):
|
427 |
+
notgood_cnt = 0
|
428 |
+
# last_f1 = tmp_f1
|
429 |
+
tmp_f1_str = [str(round(f1x,4)) for f1x in tmp_f1]
|
430 |
+
tmp_mae_str = [str(round(mx,4)) for mx in tmp_mae]
|
431 |
+
maxf1 = '_'.join(tmp_f1_str)
|
432 |
+
meanM = '_'.join(tmp_mae_str)
|
433 |
+
# .cpu().detach().numpy()
|
434 |
+
model_name = "/gpu_itr_"+str(ite_num)+\
|
435 |
+
"_traLoss_"+str(np.round(running_loss / ite_num4val,4))+\
|
436 |
+
"_traTarLoss_"+str(np.round(running_tar_loss / ite_num4val,4))+\
|
437 |
+
"_valLoss_"+str(np.round(val_loss /(i_val+1),4))+\
|
438 |
+
"_valTarLoss_"+str(np.round(tar_loss /(i_val+1),4)) + \
|
439 |
+
"_maxF1_" + maxf1 + \
|
440 |
+
"_mae_" + meanM + \
|
441 |
+
"_time_" + str(np.round(np.mean(np.array(tmp_time))/batch_size_valid,6))+".pth"
|
442 |
+
torch.save(net.state_dict(), model_path + model_name)
|
443 |
+
|
444 |
+
running_loss = 0.0
|
445 |
+
running_tar_loss = 0.0
|
446 |
+
ite_num4val = 0
|
447 |
+
|
448 |
+
if(notgood_cnt >= hypar["early_stop"]):
|
449 |
+
print("No improvements in the last "+str(notgood_cnt)+" validation periods, so training stopped !")
|
450 |
+
exit()
|
451 |
+
|
452 |
+
print("Training Reaches The Maximum Epoch Number")
|
453 |
+
|
454 |
+
def valid(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
|
455 |
+
net.eval()
|
456 |
+
print("Validating...")
|
457 |
+
epoch_num = hypar["max_epoch_num"]
|
458 |
+
|
459 |
+
val_loss = 0.0
|
460 |
+
tar_loss = 0.0
|
461 |
+
val_cnt = 0.0
|
462 |
+
|
463 |
+
tmp_f1 = []
|
464 |
+
tmp_mae = []
|
465 |
+
tmp_time = []
|
466 |
+
|
467 |
+
start_valid = time.time()
|
468 |
+
|
469 |
+
for k in range(len(valid_dataloaders)):
|
470 |
+
|
471 |
+
valid_dataloader = valid_dataloaders[k]
|
472 |
+
valid_dataset = valid_datasets[k]
|
473 |
+
|
474 |
+
val_num = valid_dataset.__len__()
|
475 |
+
mybins = np.arange(0,256)
|
476 |
+
PRE = np.zeros((val_num,len(mybins)-1))
|
477 |
+
REC = np.zeros((val_num,len(mybins)-1))
|
478 |
+
F1 = np.zeros((val_num,len(mybins)-1))
|
479 |
+
MAE = np.zeros((val_num))
|
480 |
+
|
481 |
+
for i_val, data_val in enumerate(valid_dataloader):
|
482 |
+
val_cnt = val_cnt + 1.0
|
483 |
+
imidx_val, inputs_val, labels_val, shapes_val = data_val['imidx'], data_val['image'], data_val['label'], data_val['shape']
|
484 |
+
|
485 |
+
if(hypar["model_digit"]=="full"):
|
486 |
+
inputs_val = inputs_val.type(torch.FloatTensor)
|
487 |
+
labels_val = labels_val.type(torch.FloatTensor)
|
488 |
+
else:
|
489 |
+
inputs_val = inputs_val.type(torch.HalfTensor)
|
490 |
+
labels_val = labels_val.type(torch.HalfTensor)
|
491 |
+
|
492 |
+
# wrap them in Variable
|
493 |
+
if torch.cuda.is_available():
|
494 |
+
inputs_val_v, labels_val_v = Variable(inputs_val.cuda(), requires_grad=False), Variable(labels_val.cuda(), requires_grad=False)
|
495 |
+
else:
|
496 |
+
inputs_val_v, labels_val_v = Variable(inputs_val, requires_grad=False), Variable(labels_val,requires_grad=False)
|
497 |
+
# with autocast():
|
498 |
+
t_start = time.time()
|
499 |
+
ds_val = net(inputs_val_v)[0]
|
500 |
+
# plt.imshow(inputs_val_v[0][0].cpu().detach())
|
501 |
+
# plt.show()
|
502 |
+
# print(inputs_val_v.cpu().detach().shape)
|
503 |
+
t_end = time.time()-t_start
|
504 |
+
tmp_time.append(t_end)
|
505 |
+
|
506 |
+
# loss2_val, loss_val = muti_loss_fusion(ds_val, labels_val_v)
|
507 |
+
loss2_val, loss_val = net.compute_loss(ds_val, labels_val_v)
|
508 |
+
|
509 |
+
# compute F measure
|
510 |
+
for t in range(hypar["batch_size_valid"]):
|
511 |
+
i_test = imidx_val[t].data.numpy()
|
512 |
+
|
513 |
+
pred_val = ds_val[0][t,:,:,:].float() # B x 1 x H x W
|
514 |
+
|
515 |
+
## recover the prediction spatial size to the original image size
|
516 |
+
pred_val = torch.squeeze(F.interpolate(torch.unsqueeze(pred_val,0),(shapes_val[t][0],shapes_val[t][1]),mode='bilinear')) # F.upsample is deprecated; F.interpolate is the drop-in replacement
|
517 |
+
|
518 |
+
# pred_val = normPRED(pred_val)
|
519 |
+
ma = torch.max(pred_val)
|
520 |
+
mi = torch.min(pred_val)
|
521 |
+
pred_val = (pred_val-mi)/(ma-mi) # max = 1
|
522 |
+
|
523 |
+
gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255
|
524 |
+
if gt.max()==1:
|
525 |
+
gt=gt*255
|
526 |
+
|
527 |
+
with torch.no_grad():
|
528 |
+
gt = torch.tensor(gt).to(device)
|
529 |
+
|
530 |
+
pre,rec,f1,mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar)
|
531 |
+
|
532 |
+
|
533 |
+
PRE[i_test,:]=pre
|
534 |
+
REC[i_test,:] = rec
|
535 |
+
F1[i_test,:] = f1
|
536 |
+
MAE[i_test] = mae
|
537 |
+
|
538 |
+
del ds_val, gt
|
539 |
+
gc.collect()
|
540 |
+
torch.cuda.empty_cache()
|
541 |
+
|
542 |
+
# if(loss_val.data[0]>1):
|
543 |
+
val_loss += loss_val.item()#data[0]
|
544 |
+
tar_loss += loss2_val.item()#data[0]
|
545 |
+
|
546 |
+
print("[validating: %5d/%5d] val_ls:%f, tar_ls: %f, f1: %f, mae: %f, time: %f"% (i_val, val_num, val_loss / (i_val + 1), tar_loss / (i_val + 1), np.amax(F1[i_test,:]), MAE[i_test],t_end))
|
547 |
+
|
548 |
+
del loss2_val, loss_val
|
549 |
+
|
550 |
+
print('============================')
|
551 |
+
PRE_m = np.mean(PRE,0)
|
552 |
+
REC_m = np.mean(REC,0)
|
553 |
+
f1_m = (1+0.3)*PRE_m*REC_m/(0.3*PRE_m+REC_m+1e-8)
|
554 |
+
|
555 |
+
tmp_f1.append(np.amax(f1_m))
|
556 |
+
tmp_mae.append(np.mean(MAE))
|
557 |
+
|
558 |
+
return tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time
|
559 |
+
|
560 |
+
def main(train_datasets,
|
561 |
+
valid_datasets,
|
562 |
+
hypar): # model: "train", "test"
|
563 |
+
|
564 |
+
### --- Step 1: Build datasets and dataloaders ---
|
565 |
+
dataloaders_train = []
|
566 |
+
dataloaders_valid = []
|
567 |
+
|
568 |
+
if(hypar["mode"]=="train"):
|
569 |
+
print("--- create training dataloader ---")
|
570 |
+
## collect training dataset
|
571 |
+
train_nm_im_gt_list = get_im_gt_name_dict(train_datasets, flag="train")
|
572 |
+
## build dataloader for training datasets
|
573 |
+
train_dataloaders, train_datasets = create_dataloaders(train_nm_im_gt_list,
|
574 |
+
cache_size = hypar["cache_size"],
|
575 |
+
cache_boost = hypar["cache_boost_train"],
|
576 |
+
my_transforms = [
|
577 |
+
GOSRandomHFlip(), ## random horizontal flip augmentation (comment out to disable)
|
578 |
+
# GOSResize(hypar["input_size"]),
|
579 |
+
# GOSRandomCrop(hypar["crop_size"]), ## this line can be uncommented for randomcrop augmentation
|
580 |
+
# GOSNormalize([0.5,0.5,0.5,0,0,0,0,0],[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0]),
|
581 |
+
GOSNormalize([0.5,0.5,0.5,0,0],[1.0,1.0,1.0,1.0,1.0]),
|
582 |
+
# GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
|
583 |
+
# GOSNormalize([123.675, 116.28, 103.53],[58.395, 57.12, 57.375])
|
584 |
+
],
|
585 |
+
batch_size = hypar["batch_size_train"],
|
586 |
+
shuffle = True,
|
587 |
+
is_train=True)
|
588 |
+
train_dataloaders_val, train_datasets_val = create_dataloaders(train_nm_im_gt_list,
|
589 |
+
cache_size = hypar["cache_size"],
|
590 |
+
cache_boost = hypar["cache_boost_train"],
|
591 |
+
my_transforms = [
|
592 |
+
# GOSNormalize([0.5,0.5,0.5,0,0,0,0,0],[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0]),
|
593 |
+
GOSNormalize([0.5,0.5,0.5,0,0],[1.0,1.0,1.0,1.0,1.0]),
|
594 |
+
# GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
|
595 |
+
# GOSNormalize([123.675, 116.28, 103.53],[58.395, 57.12, 57.375])
|
596 |
+
],
|
597 |
+
batch_size = hypar["batch_size_valid"],
|
598 |
+
shuffle = False,
|
599 |
+
is_train=False)
|
600 |
+
print(len(train_dataloaders), " train dataloaders created")
|
601 |
+
|
602 |
+
print("--- create valid dataloader ---")
|
603 |
+
## build dataloader for validation or testing
|
604 |
+
valid_nm_im_gt_list = get_im_gt_name_dict(valid_datasets, flag="valid")
|
605 |
+
## build dataloader for training datasets
|
606 |
+
valid_dataloaders, valid_datasets = create_dataloaders(valid_nm_im_gt_list,
|
607 |
+
cache_size = hypar["cache_size"],
|
608 |
+
cache_boost = hypar["cache_boost_valid"],
|
609 |
+
my_transforms = [
|
610 |
+
# GOSNormalize([0.5,0.5,0.5,0,0,0,0,0],[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0]),
|
611 |
+
GOSNormalize([0.5,0.5,0.5,0,0],[1.0,1.0,1.0,1.0,1.0]),
|
612 |
+
# GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]),
|
613 |
+
# GOSNormalize([123.675, 116.28, 103.53],[58.395, 57.12, 57.375])
|
614 |
+
# GOSResize(hypar["input_size"])
|
615 |
+
],
|
616 |
+
batch_size=hypar["batch_size_valid"],
|
617 |
+
shuffle=False,
|
618 |
+
is_train=False)
|
619 |
+
print(len(valid_dataloaders), " valid dataloaders created")
|
620 |
+
# print(valid_datasets[0]["data_name"])
|
621 |
+
|
622 |
+
### --- Step 2: Build Model and Optimizer ---
|
623 |
+
print("--- build model ---")
|
624 |
+
net = hypar["model"]#GOSNETINC(3,1)
|
625 |
+
|
626 |
+
# convert to half precision
|
627 |
+
# if(hypar["model_digit"]=="half"):
|
628 |
+
# net.half()
|
629 |
+
|
630 |
+
if torch.cuda.is_available():
|
631 |
+
net.cuda()
|
632 |
+
|
633 |
+
if(hypar["restore_model"]!=""):
|
634 |
+
print("restore model from:")
|
635 |
+
print(hypar["model_path"]+"/"+hypar["restore_model"])
|
636 |
+
if torch.cuda.is_available():
|
637 |
+
net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"]),strict=False)
|
638 |
+
else:
|
639 |
+
net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"],map_location="cpu"),strict=False)
|
640 |
+
|
641 |
+
print("--- define optimizer ---")
|
642 |
+
# optimizer = optim.AdamW(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
|
643 |
+
optimizer = optim.AdamW(net.parameters(), lr=4e-5, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
|
644 |
+
### --- Step 3: Train or Valid Model ---
|
645 |
+
if(hypar["mode"]=="train"):
|
646 |
+
train(net,
|
647 |
+
optimizer,
|
648 |
+
train_dataloaders,
|
649 |
+
train_datasets,
|
650 |
+
valid_dataloaders,
|
651 |
+
valid_datasets,
|
652 |
+
hypar,
|
653 |
+
train_dataloaders_val, train_datasets_val)
|
654 |
+
else:
|
655 |
+
valid(net,
|
656 |
+
valid_dataloaders,
|
657 |
+
valid_datasets,
|
658 |
+
hypar)
|
659 |
+
|
660 |
+
|
661 |
+
if __name__ == "__main__":
|
662 |
+
|
663 |
+
### --------------- STEP 1: Configuring the Train, Valid and Test datasets ---------------
|
664 |
+
## configure the train, valid and inference datasets
|
665 |
+
train_datasets, valid_datasets = [], []
|
666 |
+
|
667 |
+
valid_datasets = [dataset_test] ## users can create multiple dictionaries to set a list of datasets as validation or inference sets
|
668 |
+
train_datasets = [dataset_test] ## users can create multiple dictionaries to set a list of datasets as the training set
|
669 |
+
|
670 |
+
|
671 |
+
### --------------- STEP 2: Configuring the hyperparameters for Training, validation and inferencing ---------------
|
672 |
+
hypar = {}
|
673 |
+
|
674 |
+
## -- 2.1. configure the model saving or restoring path --
|
675 |
+
hypar["mode"] = "train"
|
676 |
+
## "train": for training,
|
677 |
+
## "valid": for validation and inferening,
|
678 |
+
## in "valid" mode, it will calculate the accuracy as well as save the prediciton results into the "hypar["valid_out_dir"]", which shouldn't be ""
|
679 |
+
## otherwise only accuracy will be calculated and no predictions will be saved
|
680 |
+
hypar["interm_sup"] = True ## in-dicate if activate intermediate feature supervision
|
681 |
+
|
682 |
+
if hypar["mode"] == "train":
|
683 |
+
hypar["valid_out_dir"] = "" ## for "train" model leave it as "", for "valid"("inference") mode: set it according to your local directory
|
684 |
+
hypar["model_path"] ="./saved_models" ## model weights saving (or restoring) path
|
685 |
+
hypar["restore_model"] = "" ## name of the segmentation model weights .pth for resume training process from last stop or for the inferencing
|
686 |
+
hypar["start_ite"] = 0 ## start iteration for the training, can be changed to match the restored training process
|
687 |
+
hypar["gt_encoder_model"] = ""
|
688 |
+
else: ## configure the segmentation output path and the to-be-used model weights path
|
689 |
+
hypar["valid_out_dir"] = "./your-results/"##".D:/Code/Design_for_graduation/DIS-main/IS-Net/DIS5K-Results-test" ## output inferenced segmentation maps into this fold
|
690 |
+
hypar["model_path"] = "./saved_models" ## load trained weights from this path
|
691 |
+
hypar["restore_model"] = "gpu_itr_102000_traLoss_2.5701_traTarLoss_0.0248_valLoss_2.3643_valTarLoss_0.3743_maxF1_0.8063_mae_0.0825_time_0.015695.pth"##"isnet.pth" ## name of the to-be-loaded weights
|
692 |
+
|
693 |
+
# if hypar["restore_model"]!="":
|
694 |
+
# hypar["start_ite"] = int(hypar["restore_model"].split("_")[2])
|
695 |
+
|
696 |
+
## -- 2.2. choose floating point accuracy --
|
697 |
+
hypar["model_digit"] = "full" ## indicates "half" or "full" accuracy of float number
|
698 |
+
hypar["seed"] = 0
|
699 |
+
|
700 |
+
## -- 2.3. cache data spatial size --
|
701 |
+
## To handle large input images, which take a long time to load during training,
|
702 |
+
# we introduce a cache mechanism that pre-converts and resizes the jpg and png images into .pt files
|
703 |
+
hypar["cache_size"] = [1024, 1024] ## cached input spatial resolution, can be configured into different size
|
704 |
+
hypar["cache_boost_train"] = False ## "True" or "False", indicates wheather to load all the training datasets into RAM, True will greatly speed the training process while requires more RAM
|
705 |
+
hypar["cache_boost_valid"] = False ## "True" or "False", indicates wheather to load all the validation datasets into RAM, True will greatly speed the training process while requires more RAM
|
706 |
+
|
707 |
+
## --- 2.4. data augmentation parameters ---
|
708 |
+
hypar["input_size"] = [1024, 1024] ## mdoel input spatial size, usually use the same value hypar["cache_size"], which means we don't further resize the images
|
709 |
+
hypar["crop_size"] = [1024, 1024] ## random crop size from the input, it is usually set as smaller than hypar["cache_size"], e.g., [920,920] for data augmentation
|
710 |
+
hypar["random_flip_h"] = 1 ## horizontal flip, currently hard coded in the datader and it is not in use
|
711 |
+
hypar["random_flip_v"] = 1 ## vertical flip , currently not in use
|
712 |
+
|
713 |
+
## --- 2.5. define model ---
|
714 |
+
print("building model...")
|
715 |
+
hypar["model"] = ISNetDIS(in_ch=5) #U2NETFASTFEATURESUP()
|
716 |
+
hypar["early_stop"] = 20 ## stop the training when no improvement in the past 20 validation periods, smaller numbers can be used here e.g., 5 or 10.
|
717 |
+
hypar["model_save_fre"] = 3000 ## valid and save model weights every 2000 iterations
|
718 |
+
|
719 |
+
hypar["batch_size_train"] = 6 ## batch size for training
|
720 |
+
hypar["batch_size_valid"] = 1 ## batch size for validation and inferencing
|
721 |
+
print("batch size: ", hypar["batch_size_train"])
|
722 |
+
|
723 |
+
hypar["max_ite"] = 50000000 ## if early stop couldn't stop the training process, stop it by the max_ite_num
|
724 |
+
hypar["max_epoch_num"] = 500000 ## if early stop and max_ite couldn't stop the training process, stop it by the max_epoch_num
|
725 |
+
|
726 |
+
main(train_datasets,
|
727 |
+
valid_datasets,
|
728 |
+
hypar=hypar)
|
729 |
+
|
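Note on the ORTHO Loss blocks above: both get_gt_encoder() and train() add a soft orthogonality penalty, reg * sum|W Wᵀ - I|, over every non-bias parameter before backpropagation. Below is a minimal standalone sketch of that regularizer; the helper name and the nn.Linear stand-in are illustrative, not part of this repo.

    import torch
    import torch.nn as nn

    def orthogonality_penalty(model: nn.Module, reg: float = 1e-8) -> torch.Tensor:
        """Soft orthogonality penalty: reg * sum |W W^T - I| over non-bias params."""
        penalty = torch.zeros(1)
        for name, param in model.named_parameters():
            if 'bias' in name:
                continue
            w = param.view(param.shape[0], -1)                 # flatten each weight to 2-D
            gram = w @ w.t()                                   # row-wise Gram matrix
            gram = gram - torch.eye(w.shape[0], device=param.device)
            penalty = penalty.to(param.device) + reg * gram.abs().sum()
        return penalty

    # usage: add the penalty to the task loss before calling backward()
    net = nn.Linear(8, 4)                                      # illustrative stand-in
    task_loss = net(torch.randn(2, 8)).pow(2).mean()
    (task_loss + orthogonality_penalty(net)).backward()

With reg at 1e-8 the penalty acts as a gentle decorrelation pressure on the filters rather than a hard constraint.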
MultiScaleDeformableAttention-1.0-py3-none-any.whl
ADDED
@@ -0,0 +1,3 @@
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:152caec7860d1f39f644ac5eed946b5a4eecfad40764396345b3d0e516921b17
|
3 |
+
size 2048806
|
README.md
CHANGED
@@ -9,6 +9,8 @@ app_file: app.py
|
9 |
pinned: false
|
10 |
license: mit
|
11 |
short_description: SAM-prompted dichotomous segmentation. No affiliation.
|
12 |
---
|
13 |
-
|
14 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
9 |
pinned: false
|
10 |
license: mit
|
11 |
short_description: SAM-prompted dichotomous segmentation. No affiliation.
|
12 |
+
python_version: 3.11
|
13 |
+
preload_from_hub:
|
14 |
+
- jwlarocque/DIS-SAM DIS-SAM-checkpoint.pth
|
15 |
+
- andzhang01/segment_anything sam_vit_l_0b3195.pth
|
16 |
---
|
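The two preload_from_hub entries added above pin the checkpoints the Space needs at build time. Outside Spaces, the same files can be fetched with huggingface_hub; a sketch, assuming both repos remain public under these names:

    from huggingface_hub import hf_hub_download

    # returns local cache paths to the same weights the Space preloads
    dis_sam_ckpt = hf_hub_download(repo_id="jwlarocque/DIS-SAM", filename="DIS-SAM-checkpoint.pth")
    sam_ckpt = hf_hub_download(repo_id="andzhang01/segment_anything", filename="sam_vit_l_0b3195.pth")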
SAM/segment_anything/__init__.py
ADDED
@@ -0,0 +1,15 @@
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from .build_sam import (
|
8 |
+
build_sam,
|
9 |
+
build_sam_vit_h,
|
10 |
+
build_sam_vit_l,
|
11 |
+
build_sam_vit_b,
|
12 |
+
sam_model_registry,
|
13 |
+
)
|
14 |
+
from .predictor import SamPredictor
|
15 |
+
from .automatic_mask_generator import SamAutomaticMaskGenerator
|
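This __init__.py re-exports the builders, the predictor, and the automatic mask generator at the package root, so downstream code (assuming the repo root is on sys.path) only needs:

    from SAM.segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator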
SAM/segment_anything/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (549 Bytes).
|
SAM/segment_anything/__pycache__/automatic_mask_generator.cpython-311.pyc
ADDED
Binary file (18.3 kB).
|
SAM/segment_anything/__pycache__/build_sam.cpython-311.pyc
ADDED
Binary file (3.22 kB).
|
SAM/segment_anything/__pycache__/predictor.cpython-311.pyc
ADDED
Binary file (14.1 kB).
|
SAM/segment_anything/automatic_mask_generator.py
ADDED
@@ -0,0 +1,372 @@
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from torchvision.ops.boxes import batched_nms, box_area # type: ignore
|
10 |
+
|
11 |
+
from typing import Any, Dict, List, Optional, Tuple
|
12 |
+
|
13 |
+
from .modeling import Sam
|
14 |
+
from .predictor import SamPredictor
|
15 |
+
from .utils.amg import (
|
16 |
+
MaskData,
|
17 |
+
area_from_rle,
|
18 |
+
batch_iterator,
|
19 |
+
batched_mask_to_box,
|
20 |
+
box_xyxy_to_xywh,
|
21 |
+
build_all_layer_point_grids,
|
22 |
+
calculate_stability_score,
|
23 |
+
coco_encode_rle,
|
24 |
+
generate_crop_boxes,
|
25 |
+
is_box_near_crop_edge,
|
26 |
+
mask_to_rle_pytorch,
|
27 |
+
remove_small_regions,
|
28 |
+
rle_to_mask,
|
29 |
+
uncrop_boxes_xyxy,
|
30 |
+
uncrop_masks,
|
31 |
+
uncrop_points,
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
class SamAutomaticMaskGenerator:
|
36 |
+
def __init__(
|
37 |
+
self,
|
38 |
+
model: Sam,
|
39 |
+
points_per_side: Optional[int] = 32,
|
40 |
+
points_per_batch: int = 64,
|
41 |
+
pred_iou_thresh: float = 0.88,
|
42 |
+
stability_score_thresh: float = 0.95,
|
43 |
+
stability_score_offset: float = 1.0,
|
44 |
+
box_nms_thresh: float = 0.7,
|
45 |
+
crop_n_layers: int = 0,
|
46 |
+
crop_nms_thresh: float = 0.7,
|
47 |
+
crop_overlap_ratio: float = 512 / 1500,
|
48 |
+
crop_n_points_downscale_factor: int = 1,
|
49 |
+
point_grids: Optional[List[np.ndarray]] = None,
|
50 |
+
min_mask_region_area: int = 0,
|
51 |
+
output_mode: str = "binary_mask",
|
52 |
+
) -> None:
|
53 |
+
"""
|
54 |
+
Using a SAM model, generates masks for the entire image.
|
55 |
+
Generates a grid of point prompts over the image, then filters
|
56 |
+
low quality and duplicate masks. The default settings are chosen
|
57 |
+
for SAM with a ViT-H backbone.
|
58 |
+
|
59 |
+
Arguments:
|
60 |
+
model (Sam): The SAM model to use for mask prediction.
|
61 |
+
points_per_side (int or None): The number of points to be sampled
|
62 |
+
along one side of the image. The total number of points is
|
63 |
+
points_per_side**2. If None, 'point_grids' must provide explicit
|
64 |
+
point sampling.
|
65 |
+
points_per_batch (int): Sets the number of points run simultaneously
|
66 |
+
by the model. Higher numbers may be faster but use more GPU memory.
|
67 |
+
pred_iou_thresh (float): A filtering threshold in [0,1], using the
|
68 |
+
model's predicted mask quality.
|
69 |
+
stability_score_thresh (float): A filtering threshold in [0,1], using
|
70 |
+
the stability of the mask under changes to the cutoff used to binarize
|
71 |
+
the model's mask predictions.
|
72 |
+
stability_score_offset (float): The amount to shift the cutoff when
|
73 |
+
calculating the stability score.
|
74 |
+
box_nms_thresh (float): The box IoU cutoff used by non-maximal
|
75 |
+
suppression to filter duplicate masks.
|
76 |
+
crop_n_layers (int): If >0, mask prediction will be run again on
|
77 |
+
crops of the image. Sets the number of layers to run, where each
|
78 |
+
layer has 2**i_layer number of image crops.
|
79 |
+
crop_nms_thresh (float): The box IoU cutoff used by non-maximal
|
80 |
+
suppression to filter duplicate masks between different crops.
|
81 |
+
crop_overlap_ratio (float): Sets the degree to which crops overlap.
|
82 |
+
In the first crop layer, crops will overlap by this fraction of
|
83 |
+
the image length. Later layers with more crops scale down this overlap.
|
84 |
+
crop_n_points_downscale_factor (int): The number of points-per-side
|
85 |
+
sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
86 |
+
point_grids (list(np.ndarray) or None): A list over explicit grids
|
87 |
+
of points used for sampling, normalized to [0,1]. The nth grid in the
|
88 |
+
list is used in the nth crop layer. Exclusive with points_per_side.
|
89 |
+
min_mask_region_area (int): If >0, postprocessing will be applied
|
90 |
+
to remove disconnected regions and holes in masks with area smaller
|
91 |
+
than min_mask_region_area. Requires opencv.
|
92 |
+
output_mode (str): The form masks are returned in. Can be 'binary_mask',
|
93 |
+
'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
|
94 |
+
For large resolutions, 'binary_mask' may consume large amounts of
|
95 |
+
memory.
|
96 |
+
"""
|
97 |
+
|
98 |
+
assert (points_per_side is None) != (
|
99 |
+
point_grids is None
|
100 |
+
), "Exactly one of points_per_side or point_grid must be provided."
|
101 |
+
if points_per_side is not None:
|
102 |
+
self.point_grids = build_all_layer_point_grids(
|
103 |
+
points_per_side,
|
104 |
+
crop_n_layers,
|
105 |
+
crop_n_points_downscale_factor,
|
106 |
+
)
|
107 |
+
elif point_grids is not None:
|
108 |
+
self.point_grids = point_grids
|
109 |
+
else:
|
110 |
+
raise ValueError("Can't have both points_per_side and point_grid be None.")
|
111 |
+
|
112 |
+
assert output_mode in [
|
113 |
+
"binary_mask",
|
114 |
+
"uncompressed_rle",
|
115 |
+
"coco_rle",
|
116 |
+
], f"Unknown output_mode {output_mode}."
|
117 |
+
if output_mode == "coco_rle":
|
118 |
+
from pycocotools import mask as mask_utils # type: ignore # noqa: F401
|
119 |
+
|
120 |
+
if min_mask_region_area > 0:
|
121 |
+
import cv2 # type: ignore # noqa: F401
|
122 |
+
|
123 |
+
self.predictor = SamPredictor(model)
|
124 |
+
self.points_per_batch = points_per_batch
|
125 |
+
self.pred_iou_thresh = pred_iou_thresh
|
126 |
+
self.stability_score_thresh = stability_score_thresh
|
127 |
+
self.stability_score_offset = stability_score_offset
|
128 |
+
self.box_nms_thresh = box_nms_thresh
|
129 |
+
self.crop_n_layers = crop_n_layers
|
130 |
+
self.crop_nms_thresh = crop_nms_thresh
|
131 |
+
self.crop_overlap_ratio = crop_overlap_ratio
|
132 |
+
self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
|
133 |
+
self.min_mask_region_area = min_mask_region_area
|
134 |
+
self.output_mode = output_mode
|
135 |
+
|
136 |
+
@torch.no_grad()
|
137 |
+
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
138 |
+
"""
|
139 |
+
Generates masks for the given image.
|
140 |
+
|
141 |
+
Arguments:
|
142 |
+
image (np.ndarray): The image to generate masks for, in HWC uint8 format.
|
143 |
+
|
144 |
+
Returns:
|
145 |
+
list(dict(str, any)): A list over records for masks. Each record is
|
146 |
+
a dict containing the following keys:
|
147 |
+
segmentation (dict(str, any) or np.ndarray): The mask. If
|
148 |
+
output_mode='binary_mask', is an array of shape HW. Otherwise,
|
149 |
+
is a dictionary containing the RLE.
|
150 |
+
bbox (list(float)): The box around the mask, in XYWH format.
|
151 |
+
area (int): The area in pixels of the mask.
|
152 |
+
predicted_iou (float): The model's own prediction of the mask's
|
153 |
+
quality. This is filtered by the pred_iou_thresh parameter.
|
154 |
+
point_coords (list(list(float))): The point coordinates input
|
155 |
+
to the model to generate this mask.
|
156 |
+
stability_score (float): A measure of the mask's quality. This
|
157 |
+
is filtered on using the stability_score_thresh parameter.
|
158 |
+
crop_box (list(float)): The crop of the image used to generate
|
159 |
+
the mask, given in XYWH format.
|
160 |
+
"""
|
161 |
+
|
162 |
+
# Generate masks
|
163 |
+
mask_data = self._generate_masks(image)
|
164 |
+
|
165 |
+
# Filter small disconnected regions and holes in masks
|
166 |
+
if self.min_mask_region_area > 0:
|
167 |
+
mask_data = self.postprocess_small_regions(
|
168 |
+
mask_data,
|
169 |
+
self.min_mask_region_area,
|
170 |
+
max(self.box_nms_thresh, self.crop_nms_thresh),
|
171 |
+
)
|
172 |
+
|
173 |
+
# Encode masks
|
174 |
+
if self.output_mode == "coco_rle":
|
175 |
+
mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
|
176 |
+
elif self.output_mode == "binary_mask":
|
177 |
+
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
|
178 |
+
else:
|
179 |
+
mask_data["segmentations"] = mask_data["rles"]
|
180 |
+
|
181 |
+
# Write mask records
|
182 |
+
curr_anns = []
|
183 |
+
for idx in range(len(mask_data["segmentations"])):
|
184 |
+
ann = {
|
185 |
+
"segmentation": mask_data["segmentations"][idx],
|
186 |
+
"area": area_from_rle(mask_data["rles"][idx]),
|
187 |
+
"bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
|
188 |
+
"predicted_iou": mask_data["iou_preds"][idx].item(),
|
189 |
+
"point_coords": [mask_data["points"][idx].tolist()],
|
190 |
+
"stability_score": mask_data["stability_score"][idx].item(),
|
191 |
+
"crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
|
192 |
+
}
|
193 |
+
curr_anns.append(ann)
|
194 |
+
|
195 |
+
return curr_anns
|
196 |
+
|
197 |
+
def _generate_masks(self, image: np.ndarray) -> MaskData:
|
198 |
+
orig_size = image.shape[:2]
|
199 |
+
crop_boxes, layer_idxs = generate_crop_boxes(
|
200 |
+
orig_size, self.crop_n_layers, self.crop_overlap_ratio
|
201 |
+
)
|
202 |
+
|
203 |
+
# Iterate over image crops
|
204 |
+
data = MaskData()
|
205 |
+
for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
|
206 |
+
crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
|
207 |
+
data.cat(crop_data)
|
208 |
+
|
209 |
+
# Remove duplicate masks between crops
|
210 |
+
if len(crop_boxes) > 1:
|
211 |
+
# Prefer masks from smaller crops
|
212 |
+
scores = 1 / box_area(data["crop_boxes"])
|
213 |
+
scores = scores.to(data["boxes"].device)
|
214 |
+
keep_by_nms = batched_nms(
|
215 |
+
data["boxes"].float(),
|
216 |
+
scores,
|
217 |
+
torch.zeros_like(data["boxes"][:, 0]), # categories
|
218 |
+
iou_threshold=self.crop_nms_thresh,
|
219 |
+
)
|
220 |
+
data.filter(keep_by_nms)
|
221 |
+
|
222 |
+
data.to_numpy()
|
223 |
+
return data
|
224 |
+
|
225 |
+
def _process_crop(
|
226 |
+
self,
|
227 |
+
image: np.ndarray,
|
228 |
+
crop_box: List[int],
|
229 |
+
crop_layer_idx: int,
|
230 |
+
orig_size: Tuple[int, ...],
|
231 |
+
) -> MaskData:
|
232 |
+
# Crop the image and calculate embeddings
|
233 |
+
x0, y0, x1, y1 = crop_box
|
234 |
+
cropped_im = image[y0:y1, x0:x1, :]
|
235 |
+
cropped_im_size = cropped_im.shape[:2]
|
236 |
+
self.predictor.set_image(cropped_im)
|
237 |
+
|
238 |
+
# Get points for this crop
|
239 |
+
points_scale = np.array(cropped_im_size)[None, ::-1]
|
240 |
+
points_for_image = self.point_grids[crop_layer_idx] * points_scale
|
241 |
+
|
242 |
+
# Generate masks for this crop in batches
|
243 |
+
data = MaskData()
|
244 |
+
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
|
245 |
+
batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
|
246 |
+
data.cat(batch_data)
|
247 |
+
del batch_data
|
248 |
+
self.predictor.reset_image()
|
249 |
+
|
250 |
+
# Remove duplicates within this crop.
|
251 |
+
keep_by_nms = batched_nms(
|
252 |
+
data["boxes"].float(),
|
253 |
+
data["iou_preds"],
|
254 |
+
torch.zeros_like(data["boxes"][:, 0]), # categories
|
255 |
+
iou_threshold=self.box_nms_thresh,
|
256 |
+
)
|
257 |
+
data.filter(keep_by_nms)
|
258 |
+
|
259 |
+
# Return to the original image frame
|
260 |
+
data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
|
261 |
+
data["points"] = uncrop_points(data["points"], crop_box)
|
262 |
+
data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
|
263 |
+
|
264 |
+
return data
|
265 |
+
|
266 |
+
def _process_batch(
|
267 |
+
self,
|
268 |
+
points: np.ndarray,
|
269 |
+
im_size: Tuple[int, ...],
|
270 |
+
crop_box: List[int],
|
271 |
+
orig_size: Tuple[int, ...],
|
272 |
+
) -> MaskData:
|
273 |
+
orig_h, orig_w = orig_size
|
274 |
+
|
275 |
+
# Run model on this batch
|
276 |
+
transformed_points = self.predictor.transform.apply_coords(points, im_size)
|
277 |
+
in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
|
278 |
+
in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
|
279 |
+
masks, iou_preds, _ = self.predictor.predict_torch(
|
280 |
+
in_points[:, None, :],
|
281 |
+
in_labels[:, None],
|
282 |
+
multimask_output=True,
|
283 |
+
return_logits=True,
|
284 |
+
)
|
285 |
+
|
286 |
+
# Serialize predictions and store in MaskData
|
287 |
+
data = MaskData(
|
288 |
+
masks=masks.flatten(0, 1),
|
289 |
+
iou_preds=iou_preds.flatten(0, 1),
|
290 |
+
points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
|
291 |
+
)
|
292 |
+
del masks
|
293 |
+
|
294 |
+
# Filter by predicted IoU
|
295 |
+
if self.pred_iou_thresh > 0.0:
|
296 |
+
keep_mask = data["iou_preds"] > self.pred_iou_thresh
|
297 |
+
data.filter(keep_mask)
|
298 |
+
|
299 |
+
# Calculate stability score
|
300 |
+
data["stability_score"] = calculate_stability_score(
|
301 |
+
data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
|
302 |
+
)
|
303 |
+
if self.stability_score_thresh > 0.0:
|
304 |
+
keep_mask = data["stability_score"] >= self.stability_score_thresh
|
305 |
+
data.filter(keep_mask)
|
306 |
+
|
307 |
+
# Threshold masks and calculate boxes
|
308 |
+
data["masks"] = data["masks"] > self.predictor.model.mask_threshold
|
309 |
+
data["boxes"] = batched_mask_to_box(data["masks"])
|
310 |
+
|
311 |
+
# Filter boxes that touch crop boundaries
|
312 |
+
keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
|
313 |
+
if not torch.all(keep_mask):
|
314 |
+
data.filter(keep_mask)
|
315 |
+
|
316 |
+
# Compress to RLE
|
317 |
+
data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
|
318 |
+
data["rles"] = mask_to_rle_pytorch(data["masks"])
|
319 |
+
del data["masks"]
|
320 |
+
|
321 |
+
return data
|
322 |
+
|
323 |
+
@staticmethod
|
324 |
+
def postprocess_small_regions(
|
325 |
+
mask_data: MaskData, min_area: int, nms_thresh: float
|
326 |
+
) -> MaskData:
|
327 |
+
"""
|
328 |
+
Removes small disconnected regions and holes in masks, then reruns
|
329 |
+
box NMS to remove any new duplicates.
|
330 |
+
|
331 |
+
Edits mask_data in place.
|
332 |
+
|
333 |
+
Requires open-cv as a dependency.
|
334 |
+
"""
|
335 |
+
if len(mask_data["rles"]) == 0:
|
336 |
+
return mask_data
|
337 |
+
|
338 |
+
# Filter small disconnected regions and holes
|
339 |
+
new_masks = []
|
340 |
+
scores = []
|
341 |
+
for rle in mask_data["rles"]:
|
342 |
+
mask = rle_to_mask(rle)
|
343 |
+
|
344 |
+
mask, changed = remove_small_regions(mask, min_area, mode="holes")
|
345 |
+
unchanged = not changed
|
346 |
+
mask, changed = remove_small_regions(mask, min_area, mode="islands")
|
347 |
+
unchanged = unchanged and not changed
|
348 |
+
|
349 |
+
new_masks.append(torch.as_tensor(mask).unsqueeze(0))
|
350 |
+
# Give score=0 to changed masks and score=1 to unchanged masks
|
351 |
+
# so NMS will prefer ones that didn't need postprocessing
|
352 |
+
scores.append(float(unchanged))
|
353 |
+
|
354 |
+
# Recalculate boxes and remove any new duplicates
|
355 |
+
masks = torch.cat(new_masks, dim=0)
|
356 |
+
boxes = batched_mask_to_box(masks)
|
357 |
+
keep_by_nms = batched_nms(
|
358 |
+
boxes.float(),
|
359 |
+
torch.as_tensor(scores),
|
360 |
+
torch.zeros_like(boxes[:, 0]), # categories
|
361 |
+
iou_threshold=nms_thresh,
|
362 |
+
)
|
363 |
+
|
364 |
+
# Only recalculate RLEs for masks that have changed
|
365 |
+
for i_mask in keep_by_nms:
|
366 |
+
if scores[i_mask] == 0.0:
|
367 |
+
mask_torch = masks[i_mask].unsqueeze(0)
|
368 |
+
mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
|
369 |
+
mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
|
370 |
+
mask_data.filter(keep_by_nms)
|
371 |
+
|
372 |
+
return mask_data
|
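A minimal usage sketch for the generator above, matching the generate() docstring; the checkpoint filename mirrors the weight preloaded in README.md, and the image path is a placeholder:

    from skimage import io
    from SAM.segment_anything import sam_model_registry, SamAutomaticMaskGenerator

    sam = sam_model_registry["vit_l"](checkpoint="sam_vit_l_0b3195.pth")  # assumed local path
    generator = SamAutomaticMaskGenerator(sam, points_per_side=32, pred_iou_thresh=0.88)
    image = io.imread("example.jpg")          # HWC uint8, as generate() expects
    masks = generator.generate(image)         # list of dicts: segmentation, bbox, predicted_iou, ...
    print(len(masks), masks[0]["bbox"], masks[0]["predicted_iou"])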
SAM/segment_anything/build_sam.py
ADDED
@@ -0,0 +1,111 @@
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from functools import partial
|
10 |
+
|
11 |
+
from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
|
12 |
+
|
13 |
+
|
14 |
+
def build_sam_vit_h(checkpoint=None, device="cpu"):
|
15 |
+
return _build_sam(
|
16 |
+
encoder_embed_dim=1280,
|
17 |
+
encoder_depth=32,
|
18 |
+
encoder_num_heads=16,
|
19 |
+
encoder_global_attn_indexes=[7, 15, 23, 31],
|
20 |
+
checkpoint=checkpoint,
|
21 |
+
device=device,
|
22 |
+
)
|
23 |
+
|
24 |
+
|
25 |
+
build_sam = build_sam_vit_h
|
26 |
+
|
27 |
+
|
28 |
+
def build_sam_vit_l(checkpoint=None, device="cpu"):
|
29 |
+
return _build_sam(
|
30 |
+
encoder_embed_dim=1024,
|
31 |
+
encoder_depth=24,
|
32 |
+
encoder_num_heads=16,
|
33 |
+
encoder_global_attn_indexes=[5, 11, 17, 23],
|
34 |
+
checkpoint=checkpoint,
|
35 |
+
device=device,
|
36 |
+
)
|
37 |
+
|
38 |
+
|
39 |
+
def build_sam_vit_b(checkpoint=None, device="cpu"):
|
40 |
+
return _build_sam(
|
41 |
+
encoder_embed_dim=768,
|
42 |
+
encoder_depth=12,
|
43 |
+
encoder_num_heads=12,
|
44 |
+
encoder_global_attn_indexes=[2, 5, 8, 11],
|
45 |
+
checkpoint=checkpoint,
|
46 |
+
device=device,
|
47 |
+
)
|
48 |
+
|
49 |
+
|
50 |
+
sam_model_registry = {
|
51 |
+
"default": build_sam_vit_h,
|
52 |
+
"vit_h": build_sam_vit_h,
|
53 |
+
"vit_l": build_sam_vit_l,
|
54 |
+
"vit_b": build_sam_vit_b,
|
55 |
+
}
|
56 |
+
|
57 |
+
|
58 |
+
def _build_sam(
|
59 |
+
encoder_embed_dim,
|
60 |
+
encoder_depth,
|
61 |
+
encoder_num_heads,
|
62 |
+
encoder_global_attn_indexes,
|
63 |
+
checkpoint=None,
|
64 |
+
device="cpu"
|
65 |
+
):
|
66 |
+
prompt_embed_dim = 256
|
67 |
+
image_size = 1024
|
68 |
+
vit_patch_size = 16
|
69 |
+
image_embedding_size = image_size // vit_patch_size
|
70 |
+
sam = Sam(
|
71 |
+
image_encoder=ImageEncoderViT(
|
72 |
+
depth=encoder_depth,
|
73 |
+
embed_dim=encoder_embed_dim,
|
74 |
+
img_size=image_size,
|
75 |
+
mlp_ratio=4,
|
76 |
+
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
|
77 |
+
num_heads=encoder_num_heads,
|
78 |
+
patch_size=vit_patch_size,
|
79 |
+
qkv_bias=True,
|
80 |
+
use_rel_pos=True,
|
81 |
+
global_attn_indexes=encoder_global_attn_indexes,
|
82 |
+
window_size=14,
|
83 |
+
out_chans=prompt_embed_dim,
|
84 |
+
),
|
85 |
+
prompt_encoder=PromptEncoder(
|
86 |
+
embed_dim=prompt_embed_dim,
|
87 |
+
image_embedding_size=(image_embedding_size, image_embedding_size),
|
88 |
+
input_image_size=(image_size, image_size),
|
89 |
+
mask_in_chans=16,
|
90 |
+
),
|
91 |
+
mask_decoder=MaskDecoder(
|
92 |
+
num_multimask_outputs=3,
|
93 |
+
transformer=TwoWayTransformer(
|
94 |
+
depth=2,
|
95 |
+
embedding_dim=prompt_embed_dim,
|
96 |
+
mlp_dim=2048,
|
97 |
+
num_heads=8,
|
98 |
+
),
|
99 |
+
transformer_dim=prompt_embed_dim,
|
100 |
+
iou_head_depth=3,
|
101 |
+
iou_head_hidden_dim=256,
|
102 |
+
),
|
103 |
+
pixel_mean=[123.675, 116.28, 103.53],
|
104 |
+
pixel_std=[58.395, 57.12, 57.375],
|
105 |
+
)
|
106 |
+
sam.eval()
|
107 |
+
if checkpoint is not None:
|
108 |
+
with open(checkpoint, "rb") as f:
|
109 |
+
state_dict = torch.load(f, map_location=device)
|
110 |
+
sam.load_state_dict(state_dict)
|
111 |
+
return sam
|
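Unlike upstream segment-anything, the builders in this file take a device argument, but it is only used as map_location when reading the checkpoint; the module itself is still constructed on CPU, so move it explicitly. A sketch:

    import torch
    from SAM.segment_anything.build_sam import sam_model_registry

    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam = sam_model_registry["vit_l"](checkpoint="sam_vit_l_0b3195.pth", device=device)
    sam.to(device)   # load_state_dict copies into CPU-built params, so move the module too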
SAM/segment_anything/modeling/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from .sam import Sam
|
8 |
+
from .image_encoder import ImageEncoderViT
|
9 |
+
from .mask_decoder import MaskDecoder
|
10 |
+
from .prompt_encoder import PromptEncoder
|
11 |
+
from .transformer import TwoWayTransformer
|
SAM/segment_anything/modeling/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (516 Bytes).
|
SAM/segment_anything/modeling/__pycache__/common.cpython-311.pyc
ADDED
Binary file (3.24 kB).
|