diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..cd94a6842e224cd9dd69a0b655edadd5d1b602eb 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +MultiScaleDeformableAttention-1.0-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text +*.jpg filter=lfs diff=lfs merge=lfs -text diff --git a/IS_Net/DIS5K/DIS5K-test/enhance_gt/1#Accessories#1#Bag#2339506821_83cf9f1d22_o_comp_1.png b/IS_Net/DIS5K/DIS5K-test/enhance_gt/1#Accessories#1#Bag#2339506821_83cf9f1d22_o_comp_1.png new file mode 100644 index 0000000000000000000000000000000000000000..64763f11b0ca95dbfdff95166e52785dffb85ffe Binary files /dev/null and b/IS_Net/DIS5K/DIS5K-test/enhance_gt/1#Accessories#1#Bag#2339506821_83cf9f1d22_o_comp_1.png differ diff --git a/IS_Net/DIS5K/DIS5K-test/enhance_gt/1#Accessories#1#Bag#3292738108_c51336a8be_o_comp_1.png b/IS_Net/DIS5K/DIS5K-test/enhance_gt/1#Accessories#1#Bag#3292738108_c51336a8be_o_comp_1.png new file mode 100644 index 0000000000000000000000000000000000000000..c1c884768b6e88af6a60b3c442505ca29be2b5d4 Binary files /dev/null and b/IS_Net/DIS5K/DIS5K-test/enhance_gt/1#Accessories#1#Bag#3292738108_c51336a8be_o_comp_1.png differ diff --git a/IS_Net/DIS5K/DIS5K-test/enhance_gt/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_0.png b/IS_Net/DIS5K/DIS5K-test/enhance_gt/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_0.png new file mode 100644 index 0000000000000000000000000000000000000000..526774eea92fc36c6041081c4602273db5ee9ee4 Binary files /dev/null and b/IS_Net/DIS5K/DIS5K-test/enhance_gt/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_0.png differ diff --git a/IS_Net/DIS5K/DIS5K-test/enhance_gt/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_1.png b/IS_Net/DIS5K/DIS5K-test/enhance_gt/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_1.png new file mode 100644 index 0000000000000000000000000000000000000000..e9974ef8193af2605a21c8e93e1957a9278c012d Binary files /dev/null and b/IS_Net/DIS5K/DIS5K-test/enhance_gt/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_1.png differ diff --git a/IS_Net/DIS5K/DIS5K-test/enhance_gt/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_2.png b/IS_Net/DIS5K/DIS5K-test/enhance_gt/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_2.png new file mode 100644 index 0000000000000000000000000000000000000000..493cbd08b1bf33ac895cd7f437db9a2f4d5c6658 Binary files /dev/null and b/IS_Net/DIS5K/DIS5K-test/enhance_gt/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_2.png differ diff --git a/IS_Net/DIS5K/DIS5K-test/enhance_im/1#Accessories#1#Bag#2339506821_83cf9f1d22_o_comp_1.jpg b/IS_Net/DIS5K/DIS5K-test/enhance_im/1#Accessories#1#Bag#2339506821_83cf9f1d22_o_comp_1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ed04494996862aead69bd958215b5d27420e8725 --- /dev/null +++ b/IS_Net/DIS5K/DIS5K-test/enhance_im/1#Accessories#1#Bag#2339506821_83cf9f1d22_o_comp_1.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4cb7d3c28db6f3bc4d2227d7551b1ba85abc9d335a1c6a625777de351bb9d469 +size 778870 diff --git a/IS_Net/DIS5K/DIS5K-test/enhance_im/1#Accessories#1#Bag#3292738108_c51336a8be_o_comp_1.jpg b/IS_Net/DIS5K/DIS5K-test/enhance_im/1#Accessories#1#Bag#3292738108_c51336a8be_o_comp_1.jpg new file mode 100644 index 
0000000000000000000000000000000000000000..f403c7ad4eae41cbdb85a18957df3d8fe09f593e --- /dev/null +++ b/IS_Net/DIS5K/DIS5K-test/enhance_im/1#Accessories#1#Bag#3292738108_c51336a8be_o_comp_1.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2c4229b3b7978308ba3f28903e5a65e2bce7bfa7bd684f53c5bb23f3067dd6c4 +size 146389 diff --git a/IS_Net/DIS5K/DIS5K-test/enhance_im/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_0.jpg b/IS_Net/DIS5K/DIS5K-test/enhance_im/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e9fe922d5759d68b35d2f842c0a08d0c4104bc8e --- /dev/null +++ b/IS_Net/DIS5K/DIS5K-test/enhance_im/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_0.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:71fd4c8bd0e10b57142b9781bc4654f368cd91ceb8e0b4e22ff9e54ce0b2fe06 +size 1193002 diff --git a/IS_Net/DIS5K/DIS5K-test/enhance_im/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_1.jpg b/IS_Net/DIS5K/DIS5K-test/enhance_im/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e9fe922d5759d68b35d2f842c0a08d0c4104bc8e --- /dev/null +++ b/IS_Net/DIS5K/DIS5K-test/enhance_im/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_1.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:71fd4c8bd0e10b57142b9781bc4654f368cd91ceb8e0b4e22ff9e54ce0b2fe06 +size 1193002 diff --git a/IS_Net/DIS5K/DIS5K-test/enhance_im/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_2.jpg b/IS_Net/DIS5K/DIS5K-test/enhance_im/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e9fe922d5759d68b35d2f842c0a08d0c4104bc8e --- /dev/null +++ b/IS_Net/DIS5K/DIS5K-test/enhance_im/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_2.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:71fd4c8bd0e10b57142b9781bc4654f368cd91ceb8e0b4e22ff9e54ce0b2fe06 +size 1193002 diff --git a/IS_Net/DIS5K/DIS5K-test/enhance_sam/1#Accessories#1#Bag#2339506821_83cf9f1d22_o_comp_1.png b/IS_Net/DIS5K/DIS5K-test/enhance_sam/1#Accessories#1#Bag#2339506821_83cf9f1d22_o_comp_1.png new file mode 100644 index 0000000000000000000000000000000000000000..b0e245df3ce6d512ccc4432910bafeb5d7740aa5 Binary files /dev/null and b/IS_Net/DIS5K/DIS5K-test/enhance_sam/1#Accessories#1#Bag#2339506821_83cf9f1d22_o_comp_1.png differ diff --git a/IS_Net/DIS5K/DIS5K-test/enhance_sam/1#Accessories#1#Bag#3292738108_c51336a8be_o_comp_1.png b/IS_Net/DIS5K/DIS5K-test/enhance_sam/1#Accessories#1#Bag#3292738108_c51336a8be_o_comp_1.png new file mode 100644 index 0000000000000000000000000000000000000000..c39f2073c2e8eb925b69f971e4de294674ff5600 Binary files /dev/null and b/IS_Net/DIS5K/DIS5K-test/enhance_sam/1#Accessories#1#Bag#3292738108_c51336a8be_o_comp_1.png differ diff --git a/IS_Net/DIS5K/DIS5K-test/enhance_sam/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_0.png b/IS_Net/DIS5K/DIS5K-test/enhance_sam/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_0.png new file mode 100644 index 0000000000000000000000000000000000000000..1c0e91d0bbb570217bb5bf8095da2fc85366446e Binary files /dev/null and b/IS_Net/DIS5K/DIS5K-test/enhance_sam/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_0.png differ diff --git a/IS_Net/DIS5K/DIS5K-test/enhance_sam/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_1.png 
b/IS_Net/DIS5K/DIS5K-test/enhance_sam/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_1.png new file mode 100644 index 0000000000000000000000000000000000000000..27c6dfc7c10dec2e7cc0f78d84e90ad80d9d796a Binary files /dev/null and b/IS_Net/DIS5K/DIS5K-test/enhance_sam/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_1.png differ diff --git a/IS_Net/DIS5K/DIS5K-test/enhance_sam/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_2.png b/IS_Net/DIS5K/DIS5K-test/enhance_sam/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_2.png new file mode 100644 index 0000000000000000000000000000000000000000..4607f7cce85ed9504b99abe5b4c406cc214f6273 Binary files /dev/null and b/IS_Net/DIS5K/DIS5K-test/enhance_sam/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_2.png differ diff --git a/IS_Net/DIS5K/DIS5K-test/gt/1#Accessories#1#Bag#2339506821_83cf9f1d22_o.png b/IS_Net/DIS5K/DIS5K-test/gt/1#Accessories#1#Bag#2339506821_83cf9f1d22_o.png new file mode 100644 index 0000000000000000000000000000000000000000..261a06cae51479380545fd25f128301975a1ee63 Binary files /dev/null and b/IS_Net/DIS5K/DIS5K-test/gt/1#Accessories#1#Bag#2339506821_83cf9f1d22_o.png differ diff --git a/IS_Net/DIS5K/DIS5K-test/gt/1#Accessories#1#Bag#3292738108_c51336a8be_o.png b/IS_Net/DIS5K/DIS5K-test/gt/1#Accessories#1#Bag#3292738108_c51336a8be_o.png new file mode 100644 index 0000000000000000000000000000000000000000..6df2787252910f782a4e63bb00667811b8a46256 Binary files /dev/null and b/IS_Net/DIS5K/DIS5K-test/gt/1#Accessories#1#Bag#3292738108_c51336a8be_o.png differ diff --git a/IS_Net/DIS5K/DIS5K-test/gt/4#Architecture#10#Pavilion#5795028920_08884db993_o.png b/IS_Net/DIS5K/DIS5K-test/gt/4#Architecture#10#Pavilion#5795028920_08884db993_o.png new file mode 100644 index 0000000000000000000000000000000000000000..9cdf21dae36eee7cae8078a6574857a41de07f3a Binary files /dev/null and b/IS_Net/DIS5K/DIS5K-test/gt/4#Architecture#10#Pavilion#5795028920_08884db993_o.png differ diff --git a/IS_Net/DIS5K/DIS5K-test/im/1#Accessories#1#Bag#2339506821_83cf9f1d22_o.jpg b/IS_Net/DIS5K/DIS5K-test/im/1#Accessories#1#Bag#2339506821_83cf9f1d22_o.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ed04494996862aead69bd958215b5d27420e8725 --- /dev/null +++ b/IS_Net/DIS5K/DIS5K-test/im/1#Accessories#1#Bag#2339506821_83cf9f1d22_o.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4cb7d3c28db6f3bc4d2227d7551b1ba85abc9d335a1c6a625777de351bb9d469 +size 778870 diff --git a/IS_Net/DIS5K/DIS5K-test/im/1#Accessories#1#Bag#3292738108_c51336a8be_o.jpg b/IS_Net/DIS5K/DIS5K-test/im/1#Accessories#1#Bag#3292738108_c51336a8be_o.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f403c7ad4eae41cbdb85a18957df3d8fe09f593e --- /dev/null +++ b/IS_Net/DIS5K/DIS5K-test/im/1#Accessories#1#Bag#3292738108_c51336a8be_o.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2c4229b3b7978308ba3f28903e5a65e2bce7bfa7bd684f53c5bb23f3067dd6c4 +size 146389 diff --git a/IS_Net/DIS5K/DIS5K-test/im/4#Architecture#10#Pavilion#5795028920_08884db993_o.jpg b/IS_Net/DIS5K/DIS5K-test/im/4#Architecture#10#Pavilion#5795028920_08884db993_o.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e9fe922d5759d68b35d2f842c0a08d0c4104bc8e --- /dev/null +++ b/IS_Net/DIS5K/DIS5K-test/im/4#Architecture#10#Pavilion#5795028920_08884db993_o.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:71fd4c8bd0e10b57142b9781bc4654f368cd91ceb8e0b4e22ff9e54ce0b2fe06 +size 
1193002
diff --git a/IS_Net/DIS5K/DIS5K-test/mask/1#Accessories#1#Bag#2339506821_83cf9f1d22_o.png b/IS_Net/DIS5K/DIS5K-test/mask/1#Accessories#1#Bag#2339506821_83cf9f1d22_o.png
new file mode 100644
index 0000000000000000000000000000000000000000..b0e245df3ce6d512ccc4432910bafeb5d7740aa5
Binary files /dev/null and b/IS_Net/DIS5K/DIS5K-test/mask/1#Accessories#1#Bag#2339506821_83cf9f1d22_o.png differ
diff --git a/IS_Net/DIS5K/DIS5K-test/mask/1#Accessories#1#Bag#3292738108_c51336a8be_o.png b/IS_Net/DIS5K/DIS5K-test/mask/1#Accessories#1#Bag#3292738108_c51336a8be_o.png
new file mode 100644
index 0000000000000000000000000000000000000000..c39f2073c2e8eb925b69f971e4de294674ff5600
Binary files /dev/null and b/IS_Net/DIS5K/DIS5K-test/mask/1#Accessories#1#Bag#3292738108_c51336a8be_o.png differ
diff --git a/IS_Net/DIS5K/DIS5K-test/mask/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_0.png b/IS_Net/DIS5K/DIS5K-test/mask/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_0.png
new file mode 100644
index 0000000000000000000000000000000000000000..1c0e91d0bbb570217bb5bf8095da2fc85366446e
Binary files /dev/null and b/IS_Net/DIS5K/DIS5K-test/mask/4#Architecture#10#Pavilion#5795028920_08884db993_o_comp_0.png differ
diff --git a/IS_Net/__pycache__/data_loader.cpython-311.pyc b/IS_Net/__pycache__/data_loader.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c84cd860816a17a7fe9f9d34b7db5bed2c9ab183
Binary files /dev/null and b/IS_Net/__pycache__/data_loader.cpython-311.pyc differ
diff --git a/IS_Net/basics.py b/IS_Net/basics.py
new file mode 100644
index 0000000000000000000000000000000000000000..1fca2f6cec280184f5384db6ec5ccae1b34b247a
--- /dev/null
+++ b/IS_Net/basics.py
@@ -0,0 +1,125 @@
+import os
+# os.environ['CUDA_VISIBLE_DEVICES'] = '2'
+from skimage import io, transform
+import torch
+import torchvision
+from torch.autograd import Variable
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms, utils
+import torch.optim as optim
+from skimage.metrics import structural_similarity as ssim
+import matplotlib.pyplot as plt
+import numpy as np
+from PIL import Image
+import glob
+import cv2
+from scipy.stats import pearsonr
+
+def mae_torch(pred,gt):
+
+    h,w = gt.shape[0:2]
+    sumError = torch.sum(torch.absolute(torch.sub(pred.float(), gt.float())))
+    maeError = torch.divide(sumError,float(h)*float(w)*255.0+1e-4)
+
+    return maeError
+
+import torch
+
+def maximal_f_measure_torch(pd, gt):
+    gtNum = torch.sum((gt > 128).float() * 1)  # number of ground-truth pixels with value greater than 128
+
+    # split the prediction into positive and negative samples
+    pp = pd[gt > 128]
+    nn = pd[gt <= 128]
+
+    # histograms of the positive and negative samples
+    pp_hist = torch.histc(pp, bins=255, min=0, max=255)
+    nn_hist = torch.histc(nn, bins=255, min=0, max=255)
+
+    # flip the histograms and compute cumulative sums
+    pp_hist_flip = torch.flipud(pp_hist)
+    nn_hist_flip = torch.flipud(nn_hist)
+
+    pp_hist_flip_cum = torch.cumsum(pp_hist_flip, dim=0)
+    nn_hist_flip_cum = torch.cumsum(nn_hist_flip, dim=0)
+
+    # compute precision, recall and F-measure
+    precision = (pp_hist_flip_cum) / (pp_hist_flip_cum + nn_hist_flip_cum + 1e-4)
+    recall = (pp_hist_flip_cum) / (gtNum + 1e-4)
+    f_measure = (2 * precision * recall) / (precision + recall + 1e-4)
+
+    # find the maximal F-measure and its corresponding threshold
+    max_f_measure, threshold = torch.max(f_measure, dim=0)
+
+    return max_f_measure.item(), threshold.item()
+
+def calculate_meam(image1, image2):
+    # histogram equalization
+    image1_equalized = cv2.equalizeHist(image1)
+    image2_equalized = cv2.equalizeHist(image2)
+
+    # Pearson correlation coefficient
+    correlation_coefficient, _ = pearsonr(image1_equalized.flatten(), image2_equalized.flatten())
+
+    # MEAM value
+    meam_value = correlation_coefficient * np.mean(np.minimum(image1_equalized, image2_equalized))
+
+    return meam_value
+
+def f1score_torch(pd,gt):
+
+    # print(gt.shape)
+    gtNum = torch.sum((gt>128).float()*1) ## number of ground truth pixels
+
+    pp = pd[gt>128]
+    nn = pd[gt<=128]
+
+    pp_hist =torch.histc(pp,bins=255,min=0,max=255)
+    nn_hist = torch.histc(nn,bins=255,min=0,max=255)
+
+
+    pp_hist_flip = torch.flipud(pp_hist)
+    nn_hist_flip = torch.flipud(nn_hist)
+
+    pp_hist_flip_cum = torch.cumsum(pp_hist_flip, dim=0)
+    nn_hist_flip_cum = torch.cumsum(nn_hist_flip, dim=0)
+
+    precision = (pp_hist_flip_cum)/(pp_hist_flip_cum + nn_hist_flip_cum + 1e-4)#torch.divide(pp_hist_flip_cum,torch.sum(torch.sum(pp_hist_flip_cum, nn_hist_flip_cum), 1e-4))
+    recall = (pp_hist_flip_cum)/(gtNum + 1e-4)
+    f1 = (1+0.3)*precision*recall/(0.3*precision+recall + 1e-4)
+
+    return torch.reshape(precision,(1,precision.shape[0])),torch.reshape(recall,(1,recall.shape[0])),torch.reshape(f1,(1,f1.shape[0]))
+
+
+def f1_mae_torch(pred, gt, valid_dataset, idx, mybins, hypar):
+
+    import time
+    tic = time.time()
+
+    if(len(gt.shape)>2):
+        gt = gt[:,:,0]
+    # if pred.shape != gt.shape:
+    #     plt.imshow(pred.cpu().detach().numpy())
+    #     plt.show()
+    #     plt.imshow(gt.cpu().detach().numpy())
+    #     plt.show()
+    #     pred = pred.transpose(1,0)
+    # print(pred.shape,gt.shape)
+    # print(valid_dataset.dataset["im_name"][idx]+".png")
+    pre, rec, f1 = f1score_torch(pred,gt)
+    mae = mae_torch(pred,gt)
+
+    # hypar["valid_out_dir"] = hypar["valid_out_dir"]+"-eval" ###
+    if(hypar["valid_out_dir"]!=""):
+        if(not os.path.exists(hypar["valid_out_dir"])):
+            os.mkdir(hypar["valid_out_dir"])
+        dataset_folder = os.path.join(hypar["valid_out_dir"],valid_dataset.dataset["data_name"][idx])
+        if(not os.path.exists(dataset_folder)):
+            os.mkdir(dataset_folder)
+        io.imsave(os.path.join(dataset_folder,valid_dataset.dataset["im_name"][idx]+".png"),pred.cpu().data.numpy().astype(np.uint8))
+    # print(valid_dataset.dataset["im_name"][idx]+".png")
+    # print("time for evaluation : ", time.time()-tic)
+
+    return pre.cpu().data.numpy(), rec.cpu().data.numpy(), f1.cpu().data.numpy(), mae.cpu().data.numpy()
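For context, a minimal usage sketch of the metric helpers added in IS_Net/basics.py above. It is illustrative only and not part of the patch: the dummy tensors, the assumption that the script is run from the IS_Net directory (so basics.py is importable as a top-level module), and the printed summary are all mine.

```python
# Illustrative sketch: exercise mae_torch / f1score_torch / maximal_f_measure_torch
# on random data in the 0..255 range the helpers expect.
import torch

from basics import mae_torch, f1score_torch, maximal_f_measure_torch  # assumes cwd == IS_Net/

pred = torch.randint(0, 256, (320, 320)).float()      # fake saliency map, values 0..255
gt = (torch.rand(320, 320) > 0.5).float() * 255.0     # fake binary ground truth, 0/255

mae = mae_torch(pred, gt)                       # mean absolute error, normalised by 255
pre, rec, f1 = f1score_torch(pred, gt)          # 1x255 precision / recall / F-beta curves
max_f, thr = maximal_f_measure_torch(pred, gt)  # maximal F-measure and its threshold index

print(mae.item(), f1.max().item(), max_f, thr)
```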
diff --git a/IS_Net/data_loader.py b/IS_Net/data_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c7001a3b004e13b77a0cdc7cdb53f1bb39ccbfb
--- /dev/null
+++ b/IS_Net/data_loader.py
@@ -0,0 +1,542 @@
+## data loader
+## Acknowledgement:
+## We would like to thank Dr. Ibrahim Almakky (https://scholar.google.co.uk/citations?user=T9MTcK0AAAAJ&hl=en)
+## for his help in implementing the cache mechanism of our DIS dataloader.
+from __future__ import print_function, division
+
+import numpy as np
+import random
+from copy import deepcopy
+import json
+from tqdm import tqdm
+from skimage import io
+import os
+from glob import glob
+import matplotlib.pyplot as plt
+from PIL import Image, ImageOps
+import torch
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms, utils
+from torchvision.transforms.functional import normalize
+import torch.nn.functional as F
+import cv2
+from scipy.ndimage import label
+
+def show_gray_images(images, m=4):
+    """
+    Display a batch of grayscale images.
+
+    Args:
+        images: an array of shape (n, h, w), where n is the number of images and h and w are the image height and width.
+        m: number of images shown per row, 4 by default.
+
+    Returns:
+        None
+    """
+    n, h, w = images.shape  # number, height and width of the input images
+    num_rows = (n + m - 1) // m  # number of rows needed
+    fig, axes = plt.subplots(num_rows, m, figsize=(m*2, num_rows*2))  # create the figure and subplots
+    plt.subplots_adjust(wspace=0.05, hspace=0.05)  # adjust the spacing between subplots
+    for i in range(num_rows):
+        for j in range(m):
+            idx = i*m + j  # index of the current image
+            if idx < n:
+                axes[i, j].imshow(images[idx], cmap='gray')  # show the image
+            axes[i, j].axis('off')  # hide the axes
+    plt.show()  # display the figure
+#### --------------------- DIS dataloader cache ---------------------####
+
+def segment_connected_components(mask):
+    # convert the mask to a PyTorch tensor
+    mask_tensor = torch.tensor(mask)
+
+    # find the connected components with Scipy's label function
+    labeled_array, num_features = label(mask_tensor.numpy())
+
+    # store the pixel mask of each connected component in a dictionary
+    components = {}
+    for label_idx in range(1, num_features + 1):
+        component_mask = (labeled_array == label_idx)
+        components[label_idx] = component_mask.astype(int)
+
+    return components
+
+def FillHole(im_in):
+    img = np.array(im_in,dtype=np.uint8)[0]
+    mask = np.zeros_like(img)
+    contours, _ = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+    for contour in contours:
+        cv2.drawContours(mask, [contour], -1, 255, thickness=cv2.FILLED)
+    im_out = torch.from_numpy(mask)[None,...].float()
+    return im_out
+
+def get_im_gt_name_dict(datasets, flag='valid'):
+    print("------------------------------", flag, "--------------------------------")
+    name_im_gt_mid_list = []
+    for i in range(len(datasets)):
+        print("--->>>", flag, " dataset ",i,"/",len(datasets)," ",datasets[i]["name"],"<<<---")
+        tmp_im_list, tmp_gt_list, tmp_mid_list = [], [], []
+        tmp_im_list = glob(datasets[i]["im_dir"]+os.sep+'*'+datasets[i]["im_ext"])
+
+        # img_name_dict[im_dirs[i][0]] = tmp_im_list
+        # print('-im-',datasets[i]["name"],datasets[i]["im_dir"], ': ',len(tmp_im_list))
+
+        if(datasets[i]["gt_dir"]==""):
+            print('-gt-', datasets[i]["name"], datasets[i]["gt_dir"], ': ', 'No Ground Truth Found')
+            tmp_gt_list = []
+        else:
+            tmp_gt_list = [datasets[i]["gt_dir"]+os.sep+x.split(os.sep)[-1].split(datasets[i]["im_ext"])[0]+datasets[i]["gt_ext"] for x in tmp_im_list]
+
+            # lbl_name_dict[im_dirs[i][0]] = tmp_gt_list
+            # print('-gt-', datasets[i]["name"],datasets[i]["gt_dir"], ': ',len(tmp_gt_list))
+
+        if(datasets[i]["mid_dir"]==""):
+            print('-mid-', datasets[i]["name"], datasets[i]["mid_dir"], ': ', 'No mid Found')
+            tmp_mid_list = []
+        else:
+            tmp_mid_list = [datasets[i]["mid_dir"]+os.sep+x.split(os.sep)[-1].split(datasets[i]["im_ext"])[0]+datasets[i]["mid_ext"] for x in tmp_im_list]
+
+            # lbl_name_dict[im_dirs[i][0]] = tmp_gt_list
+            # print('-mid-', datasets[i]["name"],datasets[i]["mid_dir"], ': ',len(tmp_gt_list))
+
+
+
+        if flag=="train": ## combine multiple training sets into one dataset
+            if len(name_im_gt_mid_list)==0:
+                name_im_gt_mid_list.append({"dataset_name":datasets[i]["name"],
+                                            "im_path":tmp_im_list,
+                                            "gt_path":tmp_gt_list,
"mid_path":tmp_mid_list, + "im_ext":datasets[i]["im_ext"], + "gt_ext":datasets[i]["gt_ext"], + "mid_ext":datasets[i]["mid_ext"], + "cache_dir":datasets[i]["cache_dir"]}) + else: + name_im_gt_mid_list[0]["dataset_name"] = name_im_gt_mid_list[0]["dataset_name"] + "_" + datasets[i]["name"] + name_im_gt_mid_list[0]["im_path"] = name_im_gt_mid_list[0]["im_path"] + tmp_im_list + name_im_gt_mid_list[0]["gt_path"] = name_im_gt_mid_list[0]["gt_path"] + tmp_gt_list + name_im_gt_mid_list[0]["mid_path"] = name_im_gt_mid_list[0]["mid_path"] + tmp_mid_list + if datasets[i]["im_ext"]!=".jpg" or datasets[i]["gt_ext"]!=".png": + print("Error: Please make sure all you images and ground truth masks are in jpg and png format respectively !!!") + exit() + name_im_gt_mid_list[0]["im_ext"] = ".jpg" + name_im_gt_mid_list[0]["gt_ext"] = ".png" + name_im_gt_mid_list[0]["mid_ext"] = ".png" + name_im_gt_mid_list[0]["cache_dir"] = os.sep.join(datasets[i]["cache_dir"].split(os.sep)[0:-1])+os.sep+name_im_gt_mid_list[0]["dataset_name"] + else: ## keep different validation or inference datasets as separate ones + name_im_gt_mid_list.append({"dataset_name":datasets[i]["name"], + "im_path":tmp_im_list, + "gt_path":tmp_gt_list, + "mid_path":tmp_mid_list, + "im_ext":datasets[i]["im_ext"], + "gt_ext":datasets[i]["gt_ext"], + "mid_ext":datasets[i]["mid_ext"], + "cache_dir":datasets[i]["cache_dir"]}) + + return name_im_gt_mid_list + +def create_dataloaders(name_im_gt_mid_list, cache_size=[], cache_boost=True, my_transforms=[], batch_size=1, shuffle=False,is_train=True): + ## model="train": return one dataloader for training + ## model="valid": return a list of dataloaders for validation or testing + + gos_dataloaders = [] + gos_datasets = [] + + if(len(name_im_gt_mid_list)==0): + return gos_dataloaders, gos_datasets + + num_workers_ = 0 + # if(batch_size>1): + # num_workers_ = 2 + # if(batch_size>4): + # num_workers_ = 4 + # if(batch_size>8): + # num_workers_ = 8 + + for i in range(0,len(name_im_gt_mid_list)): + gos_dataset = GOSDatasetCache([name_im_gt_mid_list[i]], + cache_size = cache_size, + cache_path = name_im_gt_mid_list[i]["cache_dir"], + cache_boost = cache_boost, + transform = transforms.Compose(my_transforms), + is_train=is_train) + gos_dataloaders.append(DataLoader(gos_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers_)) + gos_datasets.append(gos_dataset) + + return gos_dataloaders, gos_datasets + +def im_reader(im_path): + image = Image.open(im_path).convert('RGB') + corrected_image = ImageOps.exif_transpose(image) + # return plt.imread(im_path) + return np.array(corrected_image) + +def im_preprocess(im,size): + if len(im.shape) > 3: + im = im[:,:,:3] + if len(im.shape) < 3: + im = im[:, :, np.newaxis] + if im.shape[2] == 1: + im = np.repeat(im, 3, axis=2) + im_tensor = torch.tensor(im.copy(), dtype=torch.float32) + im_tensor = torch.transpose(torch.transpose(im_tensor,1,2),0,1) + if(len(size)<2): + return im_tensor, im.shape[0:2] + else: + im_tensor = torch.unsqueeze(im_tensor,0) + im_tensor = F.upsample(im_tensor, size, mode="bilinear") + im_tensor = torch.squeeze(im_tensor,0) + + return im_tensor.type(torch.uint8), im.shape[0:2] + +def gt_preprocess(gt,size): + if len(gt.shape) > 2: + gt = gt[:, :, 0] + + gt_tensor = torch.unsqueeze(torch.tensor(gt, dtype=torch.uint8),0) + + if(len(size)<2): + return gt_tensor.type(torch.uint8), gt.shape[0:2] + else: + gt_tensor = torch.unsqueeze(torch.tensor(gt_tensor, dtype=torch.float32),0) + gt_tensor = F.upsample(gt_tensor, size, mode="bilinear") + 
gt_tensor = torch.squeeze(gt_tensor,0) + + return gt_tensor.type(torch.uint8), gt.shape[0:2] + # return gt_tensor, gt.shape[0:2] + +class GOSRandomHFlip(object): + def __init__(self,prob=0.25): + self.prob = prob + def __call__(self,sample): + imidx, image, label, shape, box, mask = sample['imidx'], sample['image'], sample['label'], sample['shape'], sample['box'], sample['mask'] + + # random horizontal flip + randomnum = random.random() + if randomnum <= self.prob: + image = torch.flip(image,dims=[2]) + label = torch.flip(label,dims=[2]) + box = torch.flip(box,dims=[2]) + mask = torch.flip(mask,dims=[2]) + elif randomnum <= self.prob*2: + image = torch.flip(image,dims=[1]) + label = torch.flip(label,dims=[1]) + box = torch.flip(box,dims=[1]) + mask = torch.flip(mask,dims=[1]) + elif randomnum <= self.prob*3: + image = torch.flip(image,dims=[2]) + label = torch.flip(label,dims=[2]) + box = torch.flip(box,dims=[2]) + mask = torch.flip(mask,dims=[2]) + image = torch.flip(image,dims=[1]) + label = torch.flip(label,dims=[1]) + box = torch.flip(box,dims=[1]) + mask = torch.flip(mask,dims=[1]) + + return {'imidx':imidx,'image':image, 'label':label, 'shape':shape, 'mask':mask, 'box':box} + +class GOSResize(object): + def __init__(self,size=[320,320]): + self.size = size + def __call__(self,sample): + imidx, image, label, shape, box, mask = sample['imidx'], sample['image'], sample['label'], sample['shape'], sample['box'], sample['mask'] + + # import time + # start = time.time() + + image = torch.squeeze(F.upsample(torch.unsqueeze(image,0),self.size,mode='bilinear'),dim=0) + label = torch.squeeze(F.upsample(torch.unsqueeze(label,0),self.size,mode='bilinear'),dim=0) + + # print("time for resize: ", time.time()-start) + + return {'imidx':imidx,'image':image, 'label':label, 'shape':shape, 'mask':mask, 'box':box} + +class GOSRandomCrop(object): + def __init__(self,size=[288,288]): + self.size = size + def __call__(self,sample): + imidx, image, label, shape, box, mask = sample['imidx'], sample['image'], sample['label'], sample['shape'], sample['box'], sample['mask'] + + h, w = image.shape[1:] + new_h, new_w = self.size + + top = np.random.randint(0, h - new_h) + left = np.random.randint(0, w - new_w) + + image = image[:,top:top+new_h,left:left+new_w] + label = label[:,top:top+new_h,left:left+new_w] + + return {'imidx':imidx,'image':image, 'label':label, 'shape':shape, 'mask':mask, 'box':box} + + +class GOSNormalize(object): + def __init__(self, mean=[0.485,0.456,0.406,0], std=[0.229,0.224,0.225,1.0]): + self.mean = mean + self.std = std + + def __call__(self,sample): + + imidx, image, label, shape, box, mask = sample['imidx'], sample['image'], sample['label'], sample['shape'], sample['box'], sample['mask'] + # print(image.shape) + image = normalize(image,self.mean,self.std) + mask = normalize(mask,0,1) + box = normalize(box,0,1) + + return {'imidx':imidx,'image':image, 'label':label, 'shape':shape, 'mask':mask, 'box':box} + +class GOSRandomthorw(object): + def __init__(self,ratio=0.25): + self.ratio = ratio + def __call__(self,sample): + imidx, image, label, shape, box, mask = sample['imidx'], sample['image'], sample['label'], sample['shape'], sample['box'], sample['mask'] + randomnum = random.random() + if randomnum < self.ratio: + mask = torch.zeros_like(mask) + elif randomnum < self.ratio*2: + box = torch.zeros_like(box) + elif randomnum < self.ratio*3: + mask = torch.zeros_like(mask) + box = torch.zeros_like(box) + + return {'imidx':imidx,'image':image, 'label':label, 'shape':shape, 'mask':mask, 
'box':box} + +class GOSDatasetCache(Dataset): + + def __init__(self, name_im_gt_mid_list, cache_size=[], cache_path='./cache', cache_file_name='dataset.json', cache_boost=False, transform=None, is_train=True): + + self.is_train = is_train + self.cache_size = cache_size + self.cache_path = cache_path + self.cache_file_name = cache_file_name + self.cache_boost_name = "" + + self.cache_boost = cache_boost + # self.ims_npy = None + # self.gts_npy = None + + ## cache all the images and ground truth into a single pytorch tensor + self.ims_pt = None + self.gts_pt = None + self.mid_pt = None + + ## we will cache the npy as well regardless of the cache_boost + # if(self.cache_boost): + self.cache_boost_name = cache_file_name.split('.json')[0] + + self.transform = transform + + self.dataset = {} + + ## combine different datasets into one + dataset_names = [] + dt_name_list = [] # dataset name per image + im_name_list = [] # image name + im_path_list = [] # im path + gt_path_list = [] # gt path + mid_path_list = [] + im_ext_list = [] # im ext + gt_ext_list = [] # gt ext + mid_ext_list = [] + for i in range(0,len(name_im_gt_mid_list)): + dataset_names.append(name_im_gt_mid_list[i]["dataset_name"]) + # dataset name repeated based on the number of images in this dataset + dt_name_list.extend([name_im_gt_mid_list[i]["dataset_name"] for x in name_im_gt_mid_list[i]["im_path"]]) + im_name_list.extend([x.split(os.sep)[-1].split(name_im_gt_mid_list[i]["im_ext"])[0] for x in name_im_gt_mid_list[i]["im_path"]]) + im_path_list.extend(name_im_gt_mid_list[i]["im_path"]) + gt_path_list.extend(name_im_gt_mid_list[i]["gt_path"]) + mid_path_list.extend(name_im_gt_mid_list[i]["mid_path"]) + im_ext_list.extend([name_im_gt_mid_list[i]["im_ext"] for x in name_im_gt_mid_list[i]["im_path"]]) + gt_ext_list.extend([name_im_gt_mid_list[i]["gt_ext"] for x in name_im_gt_mid_list[i]["gt_path"]]) + mid_ext_list.extend([name_im_gt_mid_list[i]["mid_ext"] for x in name_im_gt_mid_list[i]["mid_path"]]) + + + self.dataset["data_name"] = dt_name_list + self.dataset["im_name"] = im_name_list + self.dataset["im_path"] = im_path_list + self.dataset["ori_im_path"] = deepcopy(im_path_list) + self.dataset["gt_path"] = gt_path_list + self.dataset["ori_gt_path"] = deepcopy(gt_path_list) + self.dataset["mid_path"] = mid_path_list + self.dataset["ori_mid_path"] = deepcopy(mid_path_list) + self.dataset["im_shp"] = [] + self.dataset["gt_shp"] = [] + self.dataset["mid_shp"] = [] + self.dataset["im_ext"] = im_ext_list + self.dataset["gt_ext"] = gt_ext_list + self.dataset["mid_ext"] = mid_ext_list + + + self.dataset["ims_pt_dir"] = "" + self.dataset["gts_pt_dir"] = "" + self.dataset["mid_pt_dir"] = "" + + self.dataset = self.manage_cache(dataset_names) + + def manage_cache(self,dataset_names): + if not os.path.exists(self.cache_path): # create the folder for cache + os.makedirs(self.cache_path) + cache_folder = os.path.join(self.cache_path, "_".join(dataset_names)+"_"+"x".join([str(x) for x in self.cache_size])) + # if cache_folder.__len__() > 100: cache_folder = cache_folder[:100] + if not os.path.exists(cache_folder): # check if the cache files are there, if not then cache + return self.cache(cache_folder) + return self.load_cache(cache_folder) + + def cache(self,cache_folder): + os.mkdir(cache_folder) + cached_dataset = deepcopy(self.dataset) + + # ims_list = [] + # gts_list = [] + ims_pt_list = [] + gts_pt_list = [] + mid_pt_list = [] + for i, im_path in tqdm(enumerate(self.dataset["im_path"]), total=len(self.dataset["im_path"])): + + im_id = 
cached_dataset["im_name"][i] + # print("im_path: ", im_path) + im = im_reader(im_path) + im, im_shp = im_preprocess(im,self.cache_size) + im_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_im.pt") + torch.save(im,im_cache_file) + + cached_dataset["im_path"][i] = im_cache_file + if(self.cache_boost): + ims_pt_list.append(torch.unsqueeze(im,0)) + # ims_list.append(im.cpu().data.numpy().astype(np.uint8)) + + gt = np.zeros(im.shape[0:2]) + if len(self.dataset["gt_path"])!=0: + gt = im_reader(self.dataset["gt_path"][i]) + gt, gt_shp = gt_preprocess(gt,self.cache_size) + gt_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_gt.pt") + torch.save(gt,gt_cache_file) + if len(self.dataset["gt_path"])>0: + cached_dataset["gt_path"][i] = gt_cache_file + else: + cached_dataset["gt_path"].append(gt_cache_file) + if(self.cache_boost): + gts_pt_list.append(torch.unsqueeze(gt,0)) + + mid = np.zeros(im.shape[0:2]) + if len(self.dataset["mid_path"])!=0: + mid = im_reader(self.dataset["mid_path"][i]) + mid, mid_shp = gt_preprocess(mid,self.cache_size) + mid_cache_file = os.path.join(cache_folder,self.dataset["data_name"][i]+"_"+im_id + "_mid.pt") + torch.save(mid,mid_cache_file) + if len(self.dataset["mid_path"])>0: + cached_dataset["mid_path"][i] = mid_cache_file + else: + cached_dataset["mid_path"].append(mid_cache_file) + if(self.cache_boost): + mid_pt_list.append(torch.unsqueeze(mid,0)) + + # gts_list.append(gt.cpu().data.numpy().astype(np.uint8)) + + # im_shp_cache_file = os.path.join(cache_folder,im_id + "_im_shp.pt") + # torch.save(gt_shp, shp_cache_file) + cached_dataset["im_shp"].append(im_shp) + # self.dataset["im_shp"].append(im_shp) + + # shp_cache_file = os.path.join(cache_folder,im_id + "_gt_shp.pt") + # torch.save(gt_shp, shp_cache_file) + cached_dataset["gt_shp"].append(gt_shp) + # self.dataset["gt_shp"].append(gt_shp) + + cached_dataset["mid_shp"].append(mid_shp) + + if(self.cache_boost): + cached_dataset["ims_pt_dir"] = os.path.join(cache_folder, self.cache_boost_name+'_ims.pt') + cached_dataset["gts_pt_dir"] = os.path.join(cache_folder, self.cache_boost_name+'_gts.pt') + cached_dataset["mid_pt_dir"] = os.path.join(cache_folder, self.cache_boost_name+'_mids.pt') + self.ims_pt = torch.cat(ims_pt_list,dim=0) + self.gts_pt = torch.cat(gts_pt_list,dim=0) + self.mid_pt = torch.cat(mid_pt_list,dim=0) + torch.save(torch.cat(ims_pt_list,dim=0),cached_dataset["ims_pt_dir"]) + torch.save(torch.cat(gts_pt_list,dim=0),cached_dataset["gts_pt_dir"]) + torch.save(torch.cat(mid_pt_list,dim=0),cached_dataset["mid_pt_dir"]) + + try: + json_file = open(os.path.join(cache_folder, self.cache_file_name),"w") + json.dump(cached_dataset, json_file) + json_file.close() + except Exception: + raise FileNotFoundError("Cannot create JSON") + return cached_dataset + + def load_cache(self, cache_folder): + print(os.path.join(cache_folder,self.cache_file_name)) + json_file = open(os.path.join(cache_folder,self.cache_file_name),"r") + dataset = json.load(json_file) + json_file.close() + ## if cache_boost is true, we will load the image npy and ground truth npy into the RAM + ## otherwise the pytorch tensor will be loaded + if(self.cache_boost): + # self.ims_npy = np.load(dataset["ims_npy_dir"]) + # self.gts_npy = np.load(dataset["gts_npy_dir"]) + self.ims_pt = torch.load(dataset["ims_pt_dir"], map_location='cpu') + self.gts_pt = torch.load(dataset["gts_pt_dir"], map_location='cpu') + self.mid_pt = torch.load(dataset["mid_pt_dir"], map_location='cpu') + return 
dataset + + def __len__(self): + return len(self.dataset["im_path"]) + + def __getitem__(self, idx): + + im = None + gt = None + mid = None + if(self.cache_boost and self.ims_pt is not None): + + # start = time.time() + im = self.ims_pt[idx]#.type(torch.float32) + gt = self.gts_pt[idx]#.type(torch.float32) + mid = self.mid_pt[idx]#.type(torch.float32) + # print(idx, 'time for pt loading: ', time.time()-start) + + else: + # import time + # start = time.time() + # print("tensor***") + im_pt_path = os.path.join(self.cache_path,os.sep.join(self.dataset["im_path"][idx].split(os.sep)[-2:])) + im = torch.load(im_pt_path)#(self.dataset["im_path"][idx]) + gt_pt_path = os.path.join(self.cache_path,os.sep.join(self.dataset["gt_path"][idx].split(os.sep)[-2:])) + gt = torch.load(gt_pt_path)#(self.dataset["gt_path"][idx]) + mid_pt_path = os.path.join(self.cache_path,os.sep.join(self.dataset["mid_path"][idx].split(os.sep)[-2:])) + mid = torch.load(mid_pt_path)#(self.dataset["gt_path"][idx]) + # print(idx,'time for tensor loading: ', time.time()-start) + + + im_shp = self.dataset["im_shp"][idx] + # print("time for loading im and gt: ", time.time()-start) + + box = torch.zeros_like(gt[0])+gt[0] + rows, cols = torch.where(box>0) + left = torch.min(cols) + top = torch.min(rows) + right = torch.max(cols) + bottom = torch.max(rows) + box[top:bottom,left:right] = 255 + box[box!=255] = 0 + box = box[None,...] + gim = torch.cat([im,mid,box],dim=0) + + # start_time = time.time() + im = torch.divide(gim,255.0) + gt = torch.divide(gt,255.0) + mask = torch.divide(mid,255.0) + box = torch.divide(box,255.0) + + + sample = { + "imidx": torch.from_numpy(np.array(idx)), + "image": im, + "label": gt, + "mask": mask, + 'box': box, + "shape": torch.from_numpy(np.array(im_shp)), + } + + if self.transform: + sample = self.transform(sample) + return sample diff --git a/IS_Net/datalist.py b/IS_Net/datalist.py new file mode 100644 index 0000000000000000000000000000000000000000..25af8601adbc95a2e0e6e8dfb8d24078fc7eb290 --- /dev/null +++ b/IS_Net/datalist.py @@ -0,0 +1,62 @@ +dataset_test = {"name": "DIS5K-test", + "im_dir": r"DIS5K/DIS5K-test/im", + "gt_dir": r"DIS5K/DIS5K-test/gt", + "mid_dir":r"DIS5K/DIS5K-test/mask", + "im_ext": ".jpg", + "gt_ext": ".png", + "mid_ext": ".png", + "cache_dir":r"DIS5K-Cache/DIS-test"} + +dataset_tr = {"name": "DIS5K-TR-m", + "im_dir": r"DIS5K/DIS-TR/im", + "gt_dir": r"DIS5K/DIS-TR/gt", + "mid_dir":r"DIS5K-TR/mask", + "im_ext": ".jpg", + "gt_ext": ".png", + "mid_ext": ".png", + "cache_dir":r"DIS5K-Cache/DIS-TR-m"} + +dataset_vd = {"name": "DIS5K-VD-m", + "im_dir": r"DIS5K/DIS-VD/im", + "gt_dir": r"DIS5K/DIS-VD/gt", + "mid_dir":r"DIS5K/DIS5K-VD/mask", + "im_ext": ".jpg", + "gt_ext": ".png", + "mid_ext": ".png", + "cache_dir":r"DIS5K-Cache/DIS-VD-m"} + +dataset_te1 = {"name": "DIS5K-TE1-m", + "im_dir": r"DIS5K/DIS-TE1/im", + "gt_dir": r"DIS5K/DIS-TE1/gt", + "mid_dir":r"DIS5K/DIS5K-TE1/mask", + "im_ext": ".jpg", + "gt_ext": ".png", + "mid_ext": ".png", + "cache_dir":r"DIS5K-Cache/DIS-TE1-m"} + +dataset_te2 = {"name": "DIS5K-TE2-m", + "im_dir": r"DIS5K/DIS-TE2/im", + "gt_dir": r"DIS5K/DIS-TE2/gt", + "mid_dir":r"DIS5K/DIS5K-TE2/mask", + "im_ext": ".jpg", + "gt_ext": ".png", + "mid_ext": ".png", + "cache_dir":r"DIS5K-Cache/DIS-TE2-m"} + +dataset_te3 = {"name": "DIS5K-TE3-m", + "im_dir": r"DIS5K/DIS-TE3/im", + "gt_dir": r"DIS5K/DIS-TE3/gt", + "mid_dir":r"DIS5K/DIS5K-TE3/mask", + "im_ext": ".jpg", + "gt_ext": ".png", + "mid_ext": ".png", + "cache_dir":r"DIS5K-Cache/DIS-TE3-m"} + +dataset_te4 = {"name": 
"DIS5K-TE4-m", + "im_dir": r"DIS5K/DIS-TE4/im", + "gt_dir": r"DIS5K/DIS-TE4/gt", + "mid_dir":r"DIS5K/DIS5K-TE4/mask", + "im_ext": ".jpg", + "gt_ext": ".png", + "mid_ext": ".png", + "cache_dir":r"DIS5K-Cache/DIS-TE4-m"} \ No newline at end of file diff --git a/IS_Net/models/__pycache__/isnet.cpython-311.pyc b/IS_Net/models/__pycache__/isnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e65468f59acb0bd3b3f141b66542e2517d480dd Binary files /dev/null and b/IS_Net/models/__pycache__/isnet.cpython-311.pyc differ diff --git a/IS_Net/models/isnet.py b/IS_Net/models/isnet.py new file mode 100644 index 0000000000000000000000000000000000000000..ff2ebcb06b17e0ce204134684a1879dd86cb0138 --- /dev/null +++ b/IS_Net/models/isnet.py @@ -0,0 +1,640 @@ +import torch +import torch.nn as nn +from torchvision import models +import torch.nn.functional as F +from timm.models.layers import trunc_normal_, DropPath +import matplotlib.pyplot as plt +import monai + +def iou_loss(pred, mask): + inter = (pred * mask).sum(dim=(2, 3)) #交集 + union = (pred + mask).sum(dim=(2, 3)) - inter #并集-交集 + iou = 1 - (inter + 1) / (union + 1) + return iou.mean() + + +bce_loss = nn.BCELoss(reduction='mean') + +def muti_loss_fusion(preds, target): + loss0 = 0.0 + loss = 0.0 + + for i in range(0,len(preds)): + # print("i: ", i, preds[i].shape) + if(preds[i].shape[2]!=target.shape[2] or preds[i].shape[3]!=target.shape[3]): + # tmp_target = _upsample_like(target,preds[i]) + tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True) + loss = loss + 20*bce_loss(preds[i],tmp_target) + 0.5*iou_loss(preds[i],tmp_target) + # loss = loss + bce_loss(preds[i],tmp_target)+ iou_loss(preds[i],tmp_target) + # loss = loss + bce_loss(preds[i],tmp_target) + else: + loss = loss + 20*bce_loss(preds[i],target) + 0.5*iou_loss(preds[i],target) + # loss = loss + bce_loss(preds[i],target) + iou_loss(preds[i],target) + # loss = loss + bce_loss(preds[i],target) + if(i==0): + loss0 = loss + return loss0, loss + +MSE_loss = nn.MSELoss(reduction='mean') +kl_loss = nn.KLDivLoss(reduction='mean') +l1_loss = nn.L1Loss(reduction='mean') +smooth_l1_loss = nn.SmoothL1Loss(reduction='mean') +def muti_loss_fusion_kl(preds, target, dfs, fs, mode='MSE'): + loss0 = 0.0 + loss = 0.0 + + for i in range(0,len(preds)): + # print("i: ", i, preds[i].shape) + if(preds[i].shape[2]!=target.shape[2] or preds[i].shape[3]!=target.shape[3]): + # tmp_target = _upsample_like(target,preds[i]) + tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True) + loss = loss + 20*bce_loss(preds[i],tmp_target) + 0.5*iou_loss(preds[i],tmp_target) + # loss = loss + bce_loss(preds[i],tmp_target) + iou_loss(preds[i],tmp_target) + # loss = loss + bce_loss(preds[i],tmp_target) + else: + loss = loss + 20*bce_loss(preds[i],target) + 0.5*iou_loss(preds[i],target) + # loss = loss + bce_loss(preds[i],target) + iou_loss(preds[i],target) + # loss = loss + bce_loss(preds[i],target) + if(i==0): + loss0 = loss + + for i in range(0,len(dfs)): + if(mode=='MSE'): + loss = loss + MSE_loss(dfs[i],fs[i]) ### add the mse loss of features as additional constraints + # print("fea_loss: ", fea_loss(dfs[i],fs[i]).item()) + elif(mode=='KL'): + loss = loss + kl_loss(F.log_softmax(dfs[i],dim=1),F.softmax(fs[i],dim=1)) + # print("kl_loss: ", kl_loss(F.log_softmax(dfs[i],dim=1),F.softmax(fs[i],dim=1)).item()) + elif(mode=='MAE'): + loss = loss + l1_loss(dfs[i],fs[i]) + # print("ls_loss: ", l1_loss(dfs[i],fs[i])) + 
elif(mode=='SmoothL1'): + loss = loss + smooth_l1_loss(dfs[i],fs[i]) + # print("SmoothL1: ", smooth_l1_loss(dfs[i],fs[i]).item()) + + return loss0, loss + +class REBNCONV(nn.Module): + def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1): + super(REBNCONV,self).__init__() + + self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate,stride=stride) + self.bn_s1 = nn.BatchNorm2d(out_ch) + self.relu_s1 = nn.ReLU(inplace=True) + + def forward(self,x): + + hx = x + xout = self.relu_s1(self.bn_s1(self.conv_s1(hx))) + + return xout + +## upsample tensor 'src' to have the same spatial size with tensor 'tar' +def _upsample_like(src,tar): + + src = F.upsample(src,size=tar.shape[2:],mode='bilinear') + + return src + + +### RSU-7 ### +class RSU7(nn.Module): + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512): + super(RSU7,self).__init__() + + self.in_ch = in_ch + self.mid_ch = mid_ch + self.out_ch = out_ch + + self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) ## 1 -> 1/2 + + self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) + self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1) + + self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2) + + self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) + + def forward(self,x): + b, c, h, w = x.shape + + hx = x + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + hx = self.pool4(hx4) + + hx5 = self.rebnconv5(hx) + hx = self.pool5(hx5) + + hx6 = self.rebnconv6(hx) + + hx7 = self.rebnconv7(hx6) + + hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1)) + hx6dup = _upsample_like(hx6d,hx5) + + hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1)) + hx5dup = _upsample_like(hx5d,hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1)) + hx4dup = _upsample_like(hx4d,hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1)) + hx3dup = _upsample_like(hx3d,hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) + hx2dup = _upsample_like(hx2d,hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) + + return hx1d + hxin + + +### RSU-6 ### +class RSU6(nn.Module): + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU6,self).__init__() + + self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) + + self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) + self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + 
self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1) + + self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2) + + self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) + + def forward(self,x): + + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + hx = self.pool4(hx4) + + hx5 = self.rebnconv5(hx) + + hx6 = self.rebnconv6(hx5) + + + hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1)) + hx5dup = _upsample_like(hx5d,hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1)) + hx4dup = _upsample_like(hx4d,hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1)) + hx3dup = _upsample_like(hx3d,hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) + hx2dup = _upsample_like(hx2d,hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) + + return hx1d + hxin + +### RSU-5 ### +class RSU5(nn.Module): + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU5,self).__init__() + + self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) + + self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) + self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1) + + self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2) + + self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) + + def forward(self,x): + + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + + hx5 = self.rebnconv5(hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1)) + hx4dup = _upsample_like(hx4d,hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1)) + hx3dup = _upsample_like(hx3d,hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) + hx2dup = _upsample_like(hx2d,hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) + + return hx1d + hxin + +### RSU-4 ### +class RSU4(nn.Module): + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU4,self).__init__() + + self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) + + self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) + self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1) + self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1) + + self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2) + + self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) + self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) + + def forward(self,x): + + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + + hx4 = self.rebnconv4(hx3) + + hx3d = 
self.rebnconv3d(torch.cat((hx4,hx3),1)) + hx3dup = _upsample_like(hx3d,hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) + hx2dup = _upsample_like(hx2d,hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) + + return hx1d + hxin + +### RSU-4F ### +class RSU4F(nn.Module): + + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU4F,self).__init__() + + self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) + + self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1) + self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2) + self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4) + + self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8) + + self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4) + self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2) + self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) + + def forward(self,x): + + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx2 = self.rebnconv2(hx1) + hx3 = self.rebnconv3(hx2) + + hx4 = self.rebnconv4(hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1)) + hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1)) + hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1)) + + return hx1d + hxin + + +class myrebnconv(nn.Module): + def __init__(self, in_ch=3, + out_ch=1, + kernel_size=3, + stride=1, + padding=1, + dilation=1, + groups=1): + super(myrebnconv,self).__init__() + + self.conv = nn.Conv2d(in_ch, + out_ch, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups) + self.bn = nn.BatchNorm2d(out_ch) + self.rl = nn.ReLU(inplace=True) + + def forward(self,x): + return self.rl(self.bn(self.conv(x))) + + +class ISNetGTEncoder(nn.Module): + + def __init__(self,in_ch=1,out_ch=1): + super(ISNetGTEncoder,self).__init__() + + self.conv_in = myrebnconv(in_ch,16,3,stride=2,padding=1) # nn.Conv2d(in_ch,64,3,stride=2,padding=1) + + self.stage1 = RSU7(16,16,64) + self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage2 = RSU6(64,16,64) + self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage3 = RSU5(64,32,128) + self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage4 = RSU4(128,32,256) + self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage5 = RSU4F(256,64,512) + self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage6 = RSU4F(512,64,512) + + + self.side1 = nn.Conv2d(64,out_ch,3,padding=1) + self.side2 = nn.Conv2d(64,out_ch,3,padding=1) + self.side3 = nn.Conv2d(128,out_ch,3,padding=1) + self.side4 = nn.Conv2d(256,out_ch,3,padding=1) + self.side5 = nn.Conv2d(512,out_ch,3,padding=1) + self.side6 = nn.Conv2d(512,out_ch,3,padding=1) + + def compute_loss(self, preds, targets): + + return muti_loss_fusion(preds,targets) + + def forward(self,x): + + hx = x + + hxin = self.conv_in(hx) + # hx = self.pool_in(hxin) + + #stage 1 + hx1 = self.stage1(hxin) + hx = self.pool12(hx1) + + #stage 2 + hx2 = self.stage2(hx) + hx = self.pool23(hx2) + + #stage 3 + hx3 = self.stage3(hx) + hx = self.pool34(hx3) + + #stage 4 + hx4 = self.stage4(hx) + hx = self.pool45(hx4) + + #stage 5 + hx5 = self.stage5(hx) + hx = self.pool56(hx5) + + #stage 6 + hx6 = self.stage6(hx) + + + #side output + d1 = self.side1(hx1) + d1 = _upsample_like(d1,x) + + d2 = self.side2(hx2) + d2 = _upsample_like(d2,x) + + d3 = self.side3(hx3) + d3 = _upsample_like(d3,x) + + d4 = self.side4(hx4) + d4 = _upsample_like(d4,x) + + d5 = self.side5(hx5) + d5 = _upsample_like(d5,x) + + d6 = self.side6(hx6) + d6 = _upsample_like(d6,x) + + # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1)) + + return 
[F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)], [hx1,hx2,hx3,hx4,hx5,hx6] + +class ISNetDIS(nn.Module): + + def __init__(self,in_ch=3,out_ch=1): + super(ISNetDIS,self).__init__() + + self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1) + self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + + self.stage1 = RSU7(64,32,64) + self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage2 = RSU6(64,32,128) + self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage3 = RSU5(128,64,256) + self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage4 = RSU4(256,128,512) + self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage5 = RSU4F(512,256,512) + self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True) + + self.stage6 = RSU4F(512,256,512) + + # decoder + self.stage5d = RSU4F(1024,256,512) + self.stage4d = RSU4(1024,128,256) + self.stage3d = RSU5(512,64,128) + self.stage2d = RSU6(256,32,64) + self.stage1d = RSU7(128,16,64) + + self.side1 = nn.Conv2d(64,out_ch,3,padding=1) + self.side2 = nn.Conv2d(64,out_ch,3,padding=1) + self.side3 = nn.Conv2d(128,out_ch,3,padding=1) + self.side4 = nn.Conv2d(256,out_ch,3,padding=1) + self.side5 = nn.Conv2d(512,out_ch,3,padding=1) + self.side6 = nn.Conv2d(512,out_ch,3,padding=1) + + # self.outconv = nn.Conv2d(6*out_ch,out_ch,1) + + def compute_loss_kl(self, preds, targets, dfs, fs, mode='MSE'): + + # return muti_loss_fusion(preds,targets) + return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode) + + def compute_loss(self, preds, targets): + + # return muti_loss_fusion(preds,targets) + return muti_loss_fusion(preds, targets) + + def forward(self,x): + + hx = x + + hxin = self.conv_in(hx) + + #stage 1 + hx1 = self.stage1(hxin) + hx = self.pool12(hx1) + + #stage 2 + hx2 = self.stage2(hx) + hx = self.pool23(hx2) + + #stage 3 + hx3 = self.stage3(hx) + hx = self.pool34(hx3) + + #stage 4 + hx4 = self.stage4(hx) + hx = self.pool45(hx4) + + #stage 5 + hx5 = self.stage5(hx) + hx = self.pool56(hx5) + + #stage 6 + hx6 = self.stage6(hx) + + hx6up = _upsample_like(hx6,hx5) + + #-------------------- decoder -------------------- + hx5d = self.stage5d(torch.cat([hx6up,hx5],1)) + hx5dup = _upsample_like(hx5d,hx4) + + hx4d = self.stage4d(torch.cat([hx5dup,hx4],1)) + hx4dup = _upsample_like(hx4d,hx3) + + hx3d = self.stage3d(torch.cat([hx4dup,hx3],1)) + hx3dup = _upsample_like(hx3d,hx2) + + hx2d = self.stage2d(torch.cat([hx3dup,hx2],1)) + hx2dup = _upsample_like(hx2d,hx1) + + hx1d = self.stage1d(torch.cat([hx2dup,hx1],1)) + + + #side output + d1 = self.side1(hx1d) + d1 = _upsample_like(d1,x) + + d2 = self.side2(hx2d) + d2 = _upsample_like(d2,x) + + d3 = self.side3(hx3d) + d3 = _upsample_like(d3,x) + + d4 = self.side4(hx4d) + d4 = _upsample_like(d4,x) + + d5 = self.side5(hx5d) + d5 = _upsample_like(d5,x) + + d6 = self.side6(hx6) + d6 = _upsample_like(d6,x) + + # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1)) + # plt.imshow(hx1d[0][0].cpu().detach().numpy(),cmap='gray') + # plt.show() + # plt.imshow(hx2d[0][0].cpu().detach().numpy(),cmap='gray') + # plt.show() + # plt.imshow(hx3d[0][0].cpu().detach().numpy(),cmap='gray') + # plt.show() + # plt.imshow(hx4d[0][0].cpu().detach().numpy(),cmap='gray') + # plt.show() + # plt.imshow(hx5d[0][0].cpu().detach().numpy(),cmap='gray') + # plt.show() + # plt.imshow(hx6[0][0].cpu().detach().numpy(),cmap='gray') + # plt.show() + return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), 
F.sigmoid(d6)],[hx1d,hx2d,hx3d,hx4d,hx5d,hx6] diff --git a/IS_Net/saliency_toolbox.py b/IS_Net/saliency_toolbox.py new file mode 100644 index 0000000000000000000000000000000000000000..16ab611728b866bfd2bd1e879fbc33f7e65f612d --- /dev/null +++ b/IS_Net/saliency_toolbox.py @@ -0,0 +1,552 @@ +import os +import cv2 +import sys +import numpy as np +from glob import glob +from tqdm import tqdm +from scipy.ndimage import correlate +from scipy.ndimage.morphology import distance_transform_edt +from joblib import Parallel, delayed + +eps = sys.float_info.epsilon + +def calcualte_once(gt_name,sm_dir,gt_threshold,beta,measures): + values = dict() + for idx in measures: + values[idx] = list() + if idx == 'Max-F': + values['Precision'] = list() + values['Recall'] = list() + _, name = os.path.split(gt_name) + sm_name = os.path.join(sm_dir, name) + + if os.path.exists(sm_name): + + gt, sm = read_and_normalize(gt_name, sm_name, gt_threshold) + + if 'MAE' in measures: + values['MAE'].append(mean_square_error(gt, sm)) + if 'E-measure' in measures: + values['E-measure'].append(e_measure(gt, sm)) + if 'S-measure' in measures: + values['S-measure'].append(s_measure(gt, sm)) + if 'Adp-F' in measures: + values['Adp-F'].append(adaptive_fmeasure(gt, sm, beta)) + if 'Wgt-F' in measures: + values['Wgt-F'].append(weighted_fmeasure(gt, sm)) + if 'Max-F' in measures: + prec, recall = prec_recall(gt, sm, 256) # 256 thresholds between 0 and 1 + values['Precision'].append(prec) + values['Recall'].append(recall) + else: + print("\n{} not found!".format(os.path.basename(sm_name))) + print('---' * 10) + return values + +def calculate_measures(gt_dir, sm_dir, measures, save=False, beta=np.sqrt(0.3), gt_threshold=0.5, n_thread=1): + """ + function that calculates Saliency measures for given directories + + arameters + ---------- + gt_dir : str + The path to the ground truth directory + sm_dir : str + The path to the predicted saliency map directory + measures : list + list of measure names which need to be calculated + supported measures: 'MAE' => Mean Squared Error + 'E-measure' => Enhanced-alignment measure + 'S-measure' => Structure-measure + 'Max-F' => Maximum F-measure + 'Adp-F' => Adaptive F-measure + 'Wgt-F' => Weighted F-measure + save : str + If spesified, the results will be saved in 'save' directory + beta : float + beta parameter that is used in F-measure formula. default is sqrt(0.3) + gt_threshold : float + The threshold that is used to binrize ground truth maps. 
+ + Returns + ------- + values : dictionary + a dict containing the results + """ + + values = dict() + for idx in measures: + values[idx] = list() + if idx == 'Max-F': + values['Precision'] = list() + values['Recall'] = list() + + results = Parallel(n_jobs=n_thread)(delayed(calcualte_once)(gt_name,sm_dir,gt_threshold,beta,measures) for gt_name in tqdm(glob(os.path.join(gt_dir, '*')), total=len(glob(os.path.join(gt_dir, '*'))))) + for i in results: + if 'MAE' in measures: + values['MAE'].append(i["MAE"]) + if 'E-measure' in measures: + values['E-measure'].append(i["E-measure"]) + if 'S-measure' in measures: + values['S-measure'].append(i["S-measure"]) + if 'Adp-F' in measures: + values['Adp-F'].append(i["Adp-F"]) + if 'Wgt-F' in measures: + values['Wgt-F'].append(i["Wgt-F"]) + if 'Max-F' in measures: # 256 thresholds between 0 and 1 + values['Precision'].append(i["Precision"]) + values['Recall'].append(i["Recall"]) + + if 'MAE' in measures: + values['MAE'] = np.mean(values['MAE']) + + if 'E-measure' in measures: + values['E-measure'] = np.mean(values['E-measure']) + + if 'S-measure' in measures: + values['S-measure'] = np.mean(values['S-measure']) + + if 'Adp-F' in measures: + values['Adp-F'] = np.mean(values['Adp-F']) + + if 'Wgt-F' in measures: + values['Wgt-F'] = np.mean(values['Wgt-F']) + + if 'Max-F' in measures: + values['Precision'] = np.mean(np.hstack(values['Precision'][:]), 1) + values['Recall'] = np.mean(np.hstack(values['Recall'][:]), 1) + f_measures = (1 + beta ** 2) * values['Precision'] * values['Recall'] / ( + beta ** 2 * values['Precision'] + values['Recall']) + values['Fmeasure_all_thresholds'] = f_measures + values['Max-F'] = np.max(f_measures) + + if save: + if not os.path.isdir(save): + os.mkdir(save) + for key in values.keys(): + np.save(os.path.join(save, key + ".npy"), values[key]) + + return values + + +def read_and_normalize(gt_path, sm_path, gt_threshold=0.5): + """ + function that reads, normalizes and crops a ground truth and a saliency map + + parameters + ---------- + gt_path : str + The path to a ground truth map + sm_path : str + The path to a predicted saliency map + gt_threshold : float + The threshold that is used to binrize ground truth maps. 
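A minimal usage sketch for calculate_measures() defined above (directory names, measure list and thread count are illustrative placeholders, assuming IS_Net/ is on the import path):

from saliency_toolbox import calculate_measures

res = calculate_measures(gt_dir='DIS5K/DIS-VD/gt',          # ground-truth masks
                         sm_dir='DIS5K-Results/DIS-VD',     # predicted maps with matching file names
                         measures=['MAE', 'Max-F', 'S-measure', 'E-measure', 'Wgt-F'],
                         save=False,
                         n_thread=4)
print('MAE %.4f | Max-F %.4f | S %.4f' % (res['MAE'], res['Max-F'], res['S-measure']))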
+ + Returns + ------- + gt_img, sm_img : numpy.ndarray + The prepared arrays + """ + gt_img = norm_img(cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)) + gt_img = (gt_img >= gt_threshold).astype(np.float32) + sm_img = norm_img(cv2.imread(sm_path, cv2.IMREAD_GRAYSCALE)) + if sm_img.shape[0] != gt_img.shape[0] or sm_img.shape[1] != gt_img.shape[1]: + sm_img = cv2.resize(sm_img, (gt_img.shape[1], gt_img.shape[0])) + + return gt_img, sm_img + + +def norm_img(im): + return cv2.normalize(im.astype('float'), + None, + 0.0, 1.0, + cv2.NORM_MINMAX) + + +# MAE +def mean_square_error(gt, sm): + return np.mean(np.abs(sm - gt)) + + +# E-measure +# article: https://arxiv.org/abs/1805.10421 +# original code [Matlab]: https://github.com/DengPingFan/E-measure +def e_measure(gt, sm): + """ + This fucntion computes the Enhanced-alignment Measure (E-Measure) between the saliency map and the ground truth + article: https://arxiv.org/abs/1805.10421 + original code [Matlab]: https://github.com/DengPingFan/E-measure + + parameters + ---------- + gt : numpy.ndarray + The path to the ground truth directory + sm : numpy.ndarray + The path to the predicted saliency map directory + + Returns + ------- + value : float + The calculated E-masure + """ + sm = adptive_binary(sm) + + gt = gt.astype(np.bool_) + sm = sm.astype(np.bool_) + + dgt = gt.astype(np.float32) + dsm = sm.astype(np.float32) + + if np.sum(dgt) == 0: # if the gt is completely black + enhanced_matrix = 1.0 - dsm # only calculate the black area of intersection + elif np.mean(dgt) == 1: # if the gt is completely white + enhanced_matrix = dsm # only calcualte the white area of intersection + else: + # Normal case: + # 1.compute alignment matrix + align_matrix = alignment_term(dsm, dgt) + # 2.compute enhanced alignment matrix + enhanced_matrix = enhanced_alignment_term(align_matrix) + + height, width = gt.shape + value = np.sum(enhanced_matrix) / (height * width - 1 + eps) + return value + + +def alignment_term(dgt, dsm): + # compute global mean + mu_fm = np.mean(dsm) + mu_gt = np.mean(dgt) + + # compute the bias matrix + align_fm = dsm - mu_fm + align_gt = dgt - mu_gt + + # compute alignment matrix + align_Matrix = 2 * (align_gt * align_fm) / (align_gt * align_gt + align_fm * align_fm + eps) + return align_Matrix + + +def enhanced_alignment_term(align_matrix): + enhanced = ((align_matrix + 1) ** 2) / 4 + return enhanced + + +def adptive_binary(sm): + adaptive_threshold = 2 * np.mean(sm) + + if adaptive_threshold > 1: + adaptive_threshold = 1 + + binary_sm = (sm >= adaptive_threshold).astype(np.float32) + + return binary_sm + + +# S-Measure +# article: https://www.crcv.ucf.edu/papers/iccv17/1164.pdf +# Matlab code: https://github.com/DengPingFan/S-measure +def s_measure(gt, sm): + """ + This fucntion computes the structural similarity (S-Measure) between the saliency map and the ground truth + article: https://www.crcv.ucf.edu/papers/iccv17/1164.pdf + original code [Matlab]: https://github.com/DengPingFan/S-measure + + parameters + ---------- + gt : numpy.ndarray + The path to the ground truth directory + sm : numpy.ndarray + The path to the predicted saliency map directory + + Returns + ------- + value : float + The calculated S-masure + """ + gt_mean = np.mean(gt) + + if gt_mean == 0: # if the GT is completely black + sm_mean = np.mean(sm) + measure = 1.0 - sm_mean # only calculate the area of intersection + elif gt_mean == 1: # if the GT is completely white + sm_mean = np.mean(sm) + measure = sm_mean.copy() # only calcualte the area of intersection + else: + 
alpha = 0.5 + measure = alpha * s_object(sm, gt) + (1 - alpha) * s_region(sm, gt) + if measure < 0: + measure = 0 + + return measure + + +def ssim(gt, sm): + gt = gt.astype(np.float32) + + height, width = sm.shape + num_pixels = width * height + + # Compute the mean of SM,GT + sm_mean = np.mean(sm) + gt_mean = np.mean(gt) + + # Compute the variance of SM,GT + sigma_x2 = np.sum(np.sum((sm - sm_mean) ** 2)) / (num_pixels - 1 + eps) + sigma_y2 = np.sum(np.sum((gt - gt_mean) ** 2)) / (num_pixels - 1 + eps) + + # Compute the covariance + sigma_xy = np.sum(np.sum((sm - sm_mean) * (gt - gt_mean))) / (num_pixels - 1 + eps) + + alpha = 4 * sm_mean * gt_mean * sigma_xy + beta = (sm_mean ** 2 + gt_mean ** 2) * (sigma_x2 + sigma_y2) + + if alpha != 0: + ssim_value = alpha / (beta + eps) + elif alpha == 0 and beta == 0: + ssim_value = 1.0 + else: + ssim_value = 0 + + return ssim_value + + +def divide_sm(sm, x, y): + # copy the 4 regions + lt = sm[:y, :x] + rt = sm[:y, x:] + lb = sm[y:, :x] + rb = sm[y:, x:] + + return lt, rt, lb, rb + + +def divide_gt(gt, x, y): + height, width = gt.shape + area = width * height + + # copy the 4 regions + lt = gt[:y, :x] + rt = gt[:y, x:] + lb = gt[y:, :x] + rb = gt[y:, x:] + + # The different weight (each block proportional to the GT foreground region). + w1 = (x * y) / area + w2 = ((width - x) * y) / area + w3 = (x * (height - y)) / area + w4 = 1.0 - w1 - w2 - w3 + + return lt, rt, lb, rb, w1, w2, w3, w4 + + +def centroid(gt): + # col + rows, cols = gt.shape + + if np.sum(gt) == 0: + x = np.round(cols / 2) + y = np.round(rows / 2) + else: + total = np.sum(gt) + i = np.arange(cols).reshape(1, cols) + 1 + j = np.arange(rows).reshape(rows, 1) + 1 + + x = int(np.round(np.sum(np.sum(gt, 0, keepdims=True) * i) / total)) + y = int(np.round(np.sum(np.sum(gt, 1, keepdims=True) * j) / total)) + + return x, y + + +def s_region(gt, sm): + x, y = centroid(gt) + gt_1, gt_2, gt_3, gt_4, w1, w2, w3, w4 = divide_gt(gt, x, y) + + sm_1, sm_2, sm_3, sm_4 = divide_sm(sm, x, y) + + q1 = ssim(sm_1, gt_1) + q2 = ssim(sm_2, gt_2) + q3 = ssim(sm_3, gt_3) + q4 = ssim(sm_4, gt_4) + + region_value = w1 * q1 + w2 * q2 + w3 * q3 + w4 * q4 + + return region_value + + +def object(gt, sm): + x = np.mean(sm[gt == 1]) + # compute the standard deviations of the foreground or background in sm + sigma_x = np.std(sm[gt == 1]) + score = 2.0 * x / (x ** 2 + 1.0 + sigma_x + eps) + return score + + +def s_object(gt, sm): + # compute the similarity of the foreground in the object level + + sm_fg = sm.copy() + sm_fg[gt == 0] = 0 + o_fg = object(sm_fg, gt) + + # compute the similarity of the background + sm_bg = 1.0 - sm.copy() + sm_bg[gt == 1] = 0 + o_bg = object(sm_bg, gt == 0) + + u = np.mean(gt) + object_value = u * o_fg + (1 - u) * o_bg + return object_value + + + +# Weighted F-Measure +# article: https://ieeexplore.ieee.org/document/6909433 +# Matlab code: https://cgm.technion.ac.il/Computer-Graphics-Multimedia/Software/FGEval/ +def weighted_fmeasure(gt, sm, beta2=1): + """ + This fucntion computes Weighted F-Measure between the saliency map and the ground truth + article: https://ieeexplore.ieee.org/document/6909433 + original code [Matlab]: https://cgm.technion.ac.il/Computer-Graphics-Multimedia/Software/FGEval/ + + parameters + ---------- + gt : numpy.ndarray + The path to the ground truth directory + sm : numpy.ndarray + The path to the predicted saliency map directory + + Returns + ------- + value : float + The calculated Weighted F-Measure + """ + dst, idx = distance_transform_edt(1 - gt, 
return_indices=True) + + raw_idx = idx[0][gt == 0] + col_idx = idx[1][gt == 0] + + e = np.abs(sm - gt).astype(np.float32) + et = np.abs(sm - gt).astype(np.float32) + + et[gt == 0] = et[raw_idx, col_idx] + + k = matlab_style_gauss2d(shape=(7, 7), sigma=5) + + ea = correlate(et.astype(np.float32), k, mode='constant') + min_e_ea = np.abs(sm - gt).astype(np.float32) + + min_e_ea[gt * (ea < e) == 1] = ea[gt * (ea < e) == 1] + + b = np.ones_like(gt).astype(np.float32) + b[gt == 0] = 2 - 1 * np.exp(np.log(1 - 0.5) / 5. * dst[gt == 0]) + + ew = min_e_ea * b + tpw = np.sum(gt) - np.sum(ew[gt == 1]) + fpw = np.sum(ew[gt == 0]) + + rec = 1 - np.mean(ew[gt == 1]) # Weighed Recall + prec = tpw / (eps + tpw + fpw) # Weighted Precision + + value = (1 + beta2) * (rec * prec) / (eps + (beta2 * rec) + prec) + return value + +def matlab_style_gauss2d(shape=(3, 3), sigma=0.5): + """ + 2D gaussian mask - should give the same result as MATLAB's + fspecial('gaussian',[shape],[sigma]) + """ + m, n = [(ss - 1.) / 2. for ss in shape] + y, x = np.ogrid[-m:m + 1, -n:n + 1] + h = np.exp(-(x * x + y * y) / (2. * sigma * sigma)) + h[h < np.finfo(h.dtype).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h /= sumh + return h + + + +# Adaptive F-measure + +def adaptive_fmeasure(gt, sm, beta): + """ + This fucntion computes Adaptive F-measure between the saliency map and the ground truth using + the binary method proposed in: + https://ieeexplore.ieee.org/document/5206596 + + parameters + ---------- + gt : numpy.ndarray + The path to the ground truth directory + sm : numpy.ndarray + The path to the predicted saliency map directory + + Returns + ------- + value : float + The calculated Adaptive F-measure + """ + gt_idx = np.where(gt > 0) + gt_cnt = np.sum(gt) + + if gt_cnt == 0: + prec = [] + recall = [] + else: + adaptive_threshold = 2 * np.mean(sm) + if adaptive_threshold > 1: + adaptive_threshold = 1 + sm_binary = (sm >= adaptive_threshold).astype(np.float32) + hit_cnt = np.sum(sm_binary[gt_idx]) + alg_cnt = np.sum(sm_binary) + + if hit_cnt == 0: + prec = 0 + recall = 0 + else: + prec = hit_cnt / (alg_cnt + eps) + recall = hit_cnt / gt_cnt + value = (1 + beta ** 2) * prec * recall / ((beta ** 2 * prec + recall) + eps) + return value + + + +def prec_recall(gt, sm, num_th): + """ + This fucntion computes Adaptive F-measure between the saliency map and the ground truth using + the binary method proposed in: + https://ieeexplore.ieee.org/document/5206596 + The results of this dunction will be used to calculate Max-F measure and plot PR and F-Threshold Curves + parameters + ---------- + gt : numpy.ndarray + The path to the ground truth directory + sm : numpy.ndarray + The path to the predicted saliency map directory + num_th : interger + The total number of thresholds between 0 and 1 + Returns + ------- + prec, recall: numpy.ndarray + The calculated Precision and Recall (shape: (num_th,1)) + """ + gt_idx = np.where(gt > 0) + gt_cnt = np.sum(gt) + + if gt_cnt == 0: + prec = [] + recall = [] + else: + hit_cnt = np.zeros((num_th, 1), np.float32) + alg_cnt = np.zeros((num_th, 1), np.float32) + thresholds = np.linspace(0, 1, num_th) + for k, curTh in enumerate(thresholds): + sm_binary = (sm >= curTh).astype(np.float32) + hit_cnt[k] = np.sum(sm_binary[gt_idx]) + alg_cnt[k] = np.sum(sm_binary) + + prec = hit_cnt / (alg_cnt + eps) + recall = hit_cnt / gt_cnt + + return prec, recall diff --git a/IS_Net/swd_optim/__init__.py b/IS_Net/swd_optim/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..faf1eceb0e680ee291a8bda5aed18d3fa499df34 --- /dev/null +++ b/IS_Net/swd_optim/__init__.py @@ -0,0 +1,10 @@ + +from .adai import Adai +from .adais import AdaiS +from .adams import AdamS +from .sgds import SGDS + +del adai +del adais +del adams +del sgds diff --git a/IS_Net/swd_optim/adai.py b/IS_Net/swd_optim/adai.py new file mode 100644 index 0000000000000000000000000000000000000000..a833d2a107623ab286f783d2f3b4e467d8be2028 --- /dev/null +++ b/IS_Net/swd_optim/adai.py @@ -0,0 +1,116 @@ +import torch +from torch.optim.optimizer import Optimizer, required + +class Adai(Optimizer): + r"""Implements Adaptive Inertia Estimation (Adai) algorithm. + It has be proposed in + `Adai: Separating the Effects of Adaptive Learning Rate and Momentum Inertia`__. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float): learning rate + betas (Tuple[float, float], optional): beta0 and beta2 (default: (0.1, 0.99)) + eps (float, optional): the inertia bound (default: 1e-03) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + + """ + + def __init__(self, params, lr=required, betas=(0.1, 0.99), eps=1e-03, + weight_decay=0): + if lr is not required and lr < 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0]: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + super(Adai, self).__init__(params, defaults) + + + def __setstate__(self, state): + super(Adai, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + param_size = 0 + exp_avg_sq_hat_sum = 0. 
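        # Sketch of the two-pass step() below: pass 1 updates each parameter's EMA of
        # squared gradients and accumulates their bias-corrected sum so a global mean
        # (exp_avg_sq_hat_mean) can be formed; pass 2 converts each parameter's relative
        # second moment into an inertia beta1 = clamp(1 - beta0 * v_hat / mean(v_hat), 0, 1 - eps)
        # and applies the momentum update p -= lr * exp_avg / (1 - prod(beta1)).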
+ + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + param_size += p.numel() + grad = p.grad.data + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p.data, memory_format=torch.preserve_format) + # Cumulative products of beta1 + state['beta1_prod'] = torch.ones_like(p.data, memory_format=torch.preserve_format) + + state['step'] += 1 + + exp_avg_sq = state['exp_avg_sq'] + beta0, beta2 = group['betas'] + + bias_correction2 = 1 - beta2 ** state['step'] + + if group['weight_decay'] != 0: + grad.add_(group['weight_decay'], p.data) + + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + + exp_avg_sq_hat_sum += exp_avg_sq.sum() / bias_correction2 + + # Calculate the mean of all elements in exp_avg_sq_hat + exp_avg_sq_hat_mean = exp_avg_sq_hat_sum / param_size + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + + state = self.state[p] + + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + beta1_prod = state['beta1_prod'] + beta0, beta2 = group['betas'] + + bias_correction2 = 1 - beta2 ** state['step'] + + exp_avg_sq_hat = exp_avg_sq / bias_correction2 + beta1 = (1. - (exp_avg_sq_hat / exp_avg_sq_hat_mean).mul(beta0)).clamp(0., 1 - group['eps']) + + beta1_prod.mul_(beta1) + bias_correction1 = 1 - beta1_prod + + exp_avg.mul_(beta1).addcmul_(1 - beta1, grad) + exp_avg_hat = exp_avg / bias_correction1 + + step_size = group['lr'] + p.data.add_(-step_size, exp_avg_hat) + + return loss \ No newline at end of file diff --git a/IS_Net/swd_optim/adais.py b/IS_Net/swd_optim/adais.py new file mode 100644 index 0000000000000000000000000000000000000000..77f3156ca85efa569987a9e6fe8982bf95af40ff --- /dev/null +++ b/IS_Net/swd_optim/adais.py @@ -0,0 +1,120 @@ +import torch +from torch.optim.optimizer import Optimizer, required + + +class AdaiS(Optimizer): + r"""Implements Adai with stable/decoupled weight decay (AdaiS/AdaiW). + It is based on + `Adai: Separating the Effects of Adaptive Learning Rate and Momentum Inertia` + and + `Stable Weight Decay Regularization`__. 
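A minimal construction sketch for the optimizers in this swd_optim package (hyperparameter values are illustrative only; the training script elsewhere in this repo uses AdamW instead), assuming IS_Net/ is on the import path:

import torch.nn as nn
from swd_optim import Adai, AdaiS, AdamS, SGDS

net = nn.Conv2d(3, 1, 3)   # stand-in module; any nn.Module works
optimizer = AdaiS(net.parameters(), lr=1e-4, betas=(0.1, 0.99), eps=1e-3, weight_decay=5e-4)
# The other variants keep their usual counterparts' signatures:
#   Adai(net.parameters(), lr=1e-4, betas=(0.1, 0.99), weight_decay=0)
#   AdamS(net.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-4)
#   SGDS(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)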
+ + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate + betas (Tuple[float, float], optional): beta0 and beta2 (default: (0.1, 0.99)) + eps (float, optional): the inertia bound (default: 1e-03) + weight_decay (float, optional): weight decay (default: 0) + + """ + + def __init__(self, params, lr=required, betas=(0.1, 0.99), eps=1e-03, + weight_decay=0): + if lr is not required and lr < 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0]: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + super(AdaiS, self).__init__(params, defaults) + + + def __setstate__(self, state): + super(AdaiS, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + param_size = 0 + exp_avg_sq_hat_sum = 0. + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + param_size += p.numel() + grad = p.grad.data + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p.data, memory_format=torch.preserve_format) + # Cumulative products of beta1 + state['beta1_prod'] = torch.ones_like(p.data, memory_format=torch.preserve_format) + + exp_avg_sq = state['exp_avg_sq'] + beta0, beta2 = group['betas'] + + state['step'] += 1 + bias_correction2 = 1 - beta2 ** state['step'] + + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + + exp_avg_sq_hat = exp_avg_sq / bias_correction2 + + exp_avg_sq_hat_sum += exp_avg_sq_hat.sum() + + # Calculate the mean of all elements in exp_avg_sq_hat + exp_avg_sq_hat_mean = exp_avg_sq_hat_sum / param_size + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + + # Perform stable/decoupled weight decay + if group['weight_decay'] !=0: + p.data.mul_(1 - group['lr'] * group['weight_decay']) + + state = self.state[p] + + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + beta0, beta2 = group['betas'] + beta1_prod = state['beta1_prod'] + bias_correction2 = 1 - beta2 ** state['step'] + + exp_avg_sq_hat = exp_avg_sq / bias_correction2 + + beta1 = (1. 
- (exp_avg_sq_hat / exp_avg_sq_hat_mean).mul(beta0)).clamp(0., 1 - group['eps']) + + beta1_prod.mul_(beta1) + bias_correction1 = 1 - beta1_prod + + exp_avg.mul_(beta1).addcmul_(1 - beta1, grad) + exp_avg_hat = exp_avg.div(bias_correction1) + + step_size = group['lr'] + p.data.add_(-step_size, exp_avg_hat) + + return loss diff --git a/IS_Net/swd_optim/adams.py b/IS_Net/swd_optim/adams.py new file mode 100644 index 0000000000000000000000000000000000000000..bdc93fc244b4d0ac635748e66c1619505ba42beb --- /dev/null +++ b/IS_Net/swd_optim/adams.py @@ -0,0 +1,137 @@ +import math +import torch +from torch.optim.optimizer import Optimizer + + +class AdamS(Optimizer): + r"""Implements Adam with stable weight decay (AdamS) algorithm. + It has be proposed in + `Stable Weight Decay Regularization`__. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-4) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=1e-4, amsgrad=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad) + super(AdamS, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdamS, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + param_size = 0 + exp_avg_sq_hat_sum = 0. + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + param_size += p.numel() + + # Perform optimization step + grad = p.grad + if grad.is_sparse: + raise RuntimeError('AdamS does not support sparse gradients') + amsgrad = group['amsgrad'] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. 
values + state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + beta1, beta2 = group['betas'] + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + + state['step'] += 1 + bias_correction2 = 1 - beta2 ** state['step'] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + if amsgrad: + max_exp_avg_sq = state['max_exp_avg_sq'] + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. of gradient + exp_avg_sq_hat = max_exp_avg_sq / bias_correction2 + else: + exp_avg_sq_hat = exp_avg_sq / bias_correction2 + + exp_avg_sq_hat_sum += exp_avg_sq_hat.sum() + + # Calculate the sqrt of the mean of all elements in exp_avg_sq_hat + exp_avg_mean_sqrt = math.sqrt(exp_avg_sq_hat_sum / param_size) + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + state = self.state[p] + + #Perform stable weight decay + if group['weight_decay'] !=0: + p.data.mul_(1 - group['weight_decay'] * group['lr'] / exp_avg_mean_sqrt) + + beta1, beta2 = group['betas'] + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + if amsgrad: + max_exp_avg_sq = state['max_exp_avg_sq'] + exp_avg_sq_hat = max_exp_avg_sq / bias_correction2 + else: + exp_avg_sq_hat = exp_avg_sq / bias_correction2 + + denom = exp_avg_sq_hat.sqrt().add(group['eps']) + + step_size = group['lr'] / bias_correction1 + p.addcdiv_(exp_avg, denom, value= - step_size) + + return loss diff --git a/IS_Net/swd_optim/sgds.py b/IS_Net/swd_optim/sgds.py new file mode 100644 index 0000000000000000000000000000000000000000..754ad07c0b81051f29dc88194c5fb41bf9698fce --- /dev/null +++ b/IS_Net/swd_optim/sgds.py @@ -0,0 +1,82 @@ + +import torch +from torch.optim.optimizer import Optimizer, required + + +class SGDS(Optimizer): + r"""Implements stochastic gradient descent with stable weight decay (SGDS). + It has be proposed in + `Stable Weight Decay Regularization`__. 
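    # Shared idea ("stable weight decay") in AdamS above and SGDS below: instead of adding an
    # L2 term to the gradient, the weights are shrunk in place each step --
    #   AdamS: p *= 1 - lr * weight_decay / exp_avg_mean_sqrt   (normalized by the global RMS
    #          of the bias-corrected second moments),
    #   SGDS:  p *= 1 - lr * weight_decay * (1 - dampening) / (1 - momentum).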
+ + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float): learning rate + momentum (float, optional): momentum factor (default: 0) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + dampening (float, optional): dampening for momentum (default: 0) + nesterov (bool, optional): enables Nesterov momentum (default: False) + """ + + def __init__(self, params, lr=required, momentum=0, dampening=0, + weight_decay=0, nesterov=False): + if lr is not required and lr < 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if momentum < 0.0: + raise ValueError("Invalid momentum value: {}".format(momentum)) + if weight_decay < 0.0: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + + defaults = dict(lr=lr, momentum=momentum, dampening=dampening, + weight_decay=weight_decay, nesterov=nesterov) + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + super(SGDS, self).__init__(params, defaults) + + def __setstate__(self, state): + super(SGDS, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('nesterov', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + momentum = group['momentum'] + dampening = group['dampening'] + nesterov = group['nesterov'] + + for p in group['params']: + if p.grad is None: + continue + d_p = p.grad + + # Perform stable weight decay + if group['weight_decay'] !=0: + bias_correction = (1 - dampening) / (1 - momentum) + p.data.mul_(1 - bias_correction * group['lr'] * group['weight_decay']) + + if momentum != 0: + param_state = self.state[p] + if 'momentum_buffer' not in param_state: + buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() + else: + buf = param_state['momentum_buffer'] + buf.mul_(momentum).add_(d_p, alpha=1 - dampening) + if nesterov: + d_p = d_p.add(buf, alpha=momentum) + else: + d_p = buf + + p.add_(d_p, alpha=-group['lr']) + + return loss diff --git a/IS_Net/train_valid_inference_main.py b/IS_Net/train_valid_inference_main.py new file mode 100644 index 0000000000000000000000000000000000000000..1ed1073d2b50bf1c31d4cb028333f902f1252613 --- /dev/null +++ b/IS_Net/train_valid_inference_main.py @@ -0,0 +1,729 @@ +import os +import time +import numpy as np +from skimage import io +import time +import matplotlib.pyplot as plt +import torch, gc +import torch.nn as nn +from torch.autograd import Variable +import torch.optim as optim +import torch.nn.functional as F +from data_loader import get_im_gt_name_dict, create_dataloaders, GOSRandomHFlip, GOSResize, GOSRandomCrop, GOSNormalize #GOSDatasetCache, +# from data_loader_cache import get_im_gt_name_dict, create_dataloaders, GOSRandomHFlip, GOSResize, GOSRandomCrop, GOSNormalize #GOSDatasetCache, +from basics import f1_mae_torch #normPRED, GOSPRF1ScoresCache,f1score_torch, +from models.isnet import ISNetGTEncoder, ISNetDIS +from torch.cuda.amp import autocast, GradScaler +from datalist import * +device = 'cuda' if torch.cuda.is_available() else 'cpu' + +def get_gt_encoder(train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar, train_dataloaders_val, train_datasets_val): #model_path, 
model_save_fre, max_ite=1000000): + + torch.manual_seed(hypar["seed"]) + if torch.cuda.is_available(): + torch.cuda.manual_seed(hypar["seed"]) + + print("define gt encoder ...") + net = ISNetGTEncoder() #UNETGTENCODERCombine() + # if(hypar["model_digit"]=="half"): + # net.half() + ## load the existing model gt encoder + if(hypar["gt_encoder_model"]!=""): + model_path = hypar["model_path"]+"/"+hypar["gt_encoder_model"] + if torch.cuda.is_available(): + net.load_state_dict(torch.load(model_path)) + net.cuda() + else: + net.load_state_dict(torch.load(model_path,map_location="cpu")) + print("gt encoder restored from the saved weights ...") + return net ############ + + if torch.cuda.is_available(): + net.cuda() + + print("--- define optimizer for GT Encoder---") + # optimizer = lion.Lion(net.parameters(), lr=1e-4, betas=(0.9, 0.99)) + optimizer = optim.AdamW(net.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-8, weight_decay=0) + # optimizer = optim.SGD(net.parameters(), lr=1e-4) + + model_path = hypar["model_path"] + model_save_fre = hypar["model_save_fre"] + max_ite = hypar["max_ite"] + batch_size_train = hypar["batch_size_train"] + batch_size_valid = hypar["batch_size_valid"] + + if(not os.path.exists(model_path)): + os.mkdir(model_path) + + ite_num = hypar["start_ite"] # count the total iteration number + ite_num4val = 0 # + running_loss = 0.0 # count the toal loss + running_tar_loss = 0.0 # count the target output loss + last_f1 = [0 for x in range(len(valid_dataloaders))] + + train_num = train_datasets[0].__len__() + + net.train() + + start_last = time.time() + gos_dataloader = train_dataloaders[0] + epoch_num = hypar["max_epoch_num"] + notgood_cnt = 0 + for epoch in range(epoch_num): ## set the epoch num as 100000 + + for i, data in enumerate(gos_dataloader): + + if(ite_num >= max_ite): + print("Training Reached the Maximal Iteration Number ", max_ite) + exit() + + # start_read = time.time() + ite_num = ite_num + 1 + ite_num4val = ite_num4val + 1 + + # get the inputs + labels = data['label'] + + if(hypar["model_digit"]=="full"): + labels = labels.type(torch.FloatTensor) + else: + labels = labels.type(torch.HalfTensor) + + # wrap them in Variable + if torch.cuda.is_available(): + labels_v = Variable(labels.cuda(), requires_grad=False) + else: + labels_v = Variable(labels, requires_grad=False) + + # print("time lapse for data preparation: ", time.time()-start_read, ' s') + + # y zero the parameter gradients + start_inf_loss_back = time.time() + optimizer.zero_grad() + + # plt.imshow(labels_v[0][0].cpu(),cmap='gray') + # plt.show() + # with autocast(): + ds, fs = net(labels_v)#net(inputs_v) + loss2, loss = net.compute_loss(ds, labels_v) + # scaler.scale(loss).backward() + # loss.backward() + # scaler.step(optimizer) + # scaler.update() + #ORTHO Loss + reg = 1e-8 + orth_loss = torch.zeros(1).to(device) + for name, param in net.named_parameters(): + if 'bias' not in name: + param_flat = param.view(param.shape[0], -1) + sym = torch.mm(param_flat, torch.t(param_flat)) + sym -= torch.eye(param_flat.shape[0]).to(param.device) + orth_loss = orth_loss + (reg * sym.abs().sum()) + loss = loss + orth_loss + loss.backward() + optimizer.step() + + running_loss += loss.item() + running_tar_loss += loss2.item() + + # del outputs, loss + del ds, loss2, loss + end_inf_loss_back = time.time()-start_inf_loss_back + + print("GT Encoder Training>>>"+model_path.split('/')[-1]+" - [epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f, time-per-iter: %3f s, time_read: %3f" % ( + epoch + 1, 
epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val, time.time()-start_last, time.time()-start_last-end_inf_loss_back)) + start_last = time.time() + + if ite_num % model_save_fre == 0: # validate every 2000 iterations + notgood_cnt += 1 + net.eval() + tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch) + # tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid_gt_encoder(net, train_dataloaders_val, train_datasets_val, hypar, epoch) + + net.train() # resume train + + tmp_out = 0 + print("last_f1:",last_f1,np.mean(last_f1)) + print("tmp_f1:",tmp_f1,np.mean(tmp_f1)) + # for fi in range(len(last_f1)): + if(np.mean(tmp_f1)>np.mean(last_f1)): + tmp_out = 1 + print("tmp_out:",tmp_out) + if(tmp_out): + notgood_cnt = 0 + last_f1 = tmp_f1 + tmp_f1_str = [str(round(f1x,4)) for f1x in tmp_f1] + tmp_mae_str = [str(round(mx,4)) for mx in tmp_mae] + maxf1 = '_'.join(tmp_f1_str) + meanM = '_'.join(tmp_mae_str) + # .cpu().detach().numpy() + model_name = "/GTENCODER-gpu_itr_"+str(ite_num)+\ + "_traLoss_"+str(np.round(running_loss / ite_num4val,4))+\ + "_traTarLoss_"+str(np.round(running_tar_loss / ite_num4val,4))+\ + "_valLoss_"+str(np.round(val_loss /(i_val+1),4))+\ + "_valTarLoss_"+str(np.round(tar_loss /(i_val+1),4)) + \ + "_maxF1_" + maxf1 + \ + "_mae_" + meanM + \ + "_time_" + str(np.round(np.mean(np.array(tmp_time))/batch_size_valid,6))+".pth" + torch.save(net.state_dict(), model_path + model_name) + + running_loss = 0.0 + running_tar_loss = 0.0 + ite_num4val = 0 + + if(np.mean(tmp_f1)>0.99): + print("GT encoder is well-trained and obtained...") + return net + + if(notgood_cnt >= hypar["early_stop"]): + print("No improvements in the last "+str(notgood_cnt)+" validation periods, so training stopped !") + exit() + print("Training Reaches The Maximum Epoch Number") + return net + +def valid_gt_encoder(net, valid_dataloaders, valid_datasets, hypar, epoch=0): + net.eval() + print("Validating...") + epoch_num = hypar["max_epoch_num"] + + val_loss = 0.0 + tar_loss = 0.0 + + + tmp_f1 = [] + tmp_mae = [] + tmp_time = [] + + start_valid = time.time() + for k in range(len(valid_dataloaders)): + + valid_dataloader = valid_dataloaders[k] + valid_dataset = valid_datasets[k] + + val_num = valid_dataset.__len__() + mybins = np.arange(0,256) + PRE = np.zeros((val_num,len(mybins)-1)) + REC = np.zeros((val_num,len(mybins)-1)) + F1 = np.zeros((val_num,len(mybins)-1)) + MAE = np.zeros((val_num)) + + val_cnt = 0.0 + i_val = None + + for i_val, data_val in enumerate(valid_dataloader): + + # imidx_val, inputs_val, labels_val, shapes_val = data_val['imidx'], data_val['image'], data_val['label'], data_val['shape'] + imidx_val, labels_val, shapes_val = data_val['imidx'], data_val['label'], data_val['shape'] + if(hypar["model_digit"]=="full"): + labels_val = labels_val.type(torch.FloatTensor) + else: + labels_val = labels_val.type(torch.HalfTensor) + + # wrap them in Variable + if torch.cuda.is_available(): + labels_val_v = Variable(labels_val.cuda(), requires_grad=False) + else: + labels_val_v = Variable(labels_val,requires_grad=False) + # with autocast(): + t_start = time.time() + ds_val = net(labels_val_v)[0] + t_end = time.time()-t_start + tmp_time.append(t_end) + + # loss2_val, loss_val = muti_loss_fusion(ds_val, labels_val_v) + loss2_val, loss_val = net.compute_loss(ds_val, labels_val_v) + + # compute F measure + for t in range(hypar["batch_size_valid"]): + val_cnt = val_cnt + 
1.0 + print("num of val: ", val_cnt) + i_test = imidx_val[t].data.numpy() + + pred_val = ds_val[0][t,:,:,:].float() # B x 1 x H x W + + ## recover the prediction spatial size to the orignal image size + pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val,0),(shapes_val[t][0],shapes_val[t][1]),mode='bilinear')) + + ma = torch.max(pred_val) + mi = torch.min(pred_val) + pred_val = (pred_val-mi)/(ma-mi) # max = 1 + # pred_val = normPRED(pred_val) + + gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255 + if gt.max()==1: + gt=gt*255 + with torch.no_grad(): + gt = torch.tensor(gt).to(device) + + pre,rec,f1,mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar) + + PRE[i_test,:]=pre + REC[i_test,:] = rec + F1[i_test,:] = f1 + MAE[i_test] = mae + + del ds_val, gt + gc.collect() + torch.cuda.empty_cache() + + # if(loss_val.data[0]>1): + val_loss += loss_val.item()#data[0] + tar_loss += loss2_val.item()#data[0] + + print("[validating: %5d/%5d] val_ls:%f, tar_ls: %f, f1: %f, mae: %f, time: %f"% (i_val, val_num, val_loss / (i_val + 1), tar_loss / (i_val + 1), np.amax(F1[i_test,:]), MAE[i_test],t_end)) + + del loss2_val, loss_val + + print('============================') + PRE_m = np.mean(PRE,0) + REC_m = np.mean(REC,0) + f1_m = (1+0.3)*PRE_m*REC_m/(0.3*PRE_m+REC_m+1e-8) + # print('--------------:', np.mean(f1_m)) + tmp_f1.append(np.amax(f1_m)) + tmp_mae.append(np.mean(MAE)) + print("The max F1 Score: %f"%(np.max(f1_m))) + print("MAE: ", np.mean(MAE)) + + # print('[epoch: %3d/%3d, ite: %5d] tra_ls: %3f, val_ls: %3f, tar_ls: %3f, maxf1: %3f, val_time: %6f'% (epoch + 1, epoch_num, ite_num, running_loss / ite_num4val, val_loss/val_cnt, tar_loss/val_cnt, tmp_f1[-1], time.time()-start_valid)) + + return tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time + +def train(net, optimizer, train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar,train_dataloaders_val, train_datasets_val): #model_path, model_save_fre, max_ite=1000000): + + if hypar["interm_sup"]: + print("Get the gt encoder ...") + featurenet = get_gt_encoder(train_dataloaders, train_datasets, valid_dataloaders, valid_datasets, hypar,train_dataloaders_val, train_datasets_val) + ## freeze the weights of gt encoder + for param in featurenet.parameters(): + param.requires_grad=False + + # scaler = GradScaler() + model_path = hypar["model_path"] + model_save_fre = hypar["model_save_fre"] + max_ite = hypar["max_ite"] + batch_size_train = hypar["batch_size_train"] + batch_size_valid = hypar["batch_size_valid"] + + if(not os.path.exists(model_path)): + os.mkdir(model_path) + + ite_num = hypar["start_ite"] # count the toal iteration number + ite_num4val = 0 # + running_loss = 0.0 # count the toal loss + running_tar_loss = 0.0 # count the target output loss + last_mae = [1 for x in range(len(valid_dataloaders))] + last_f1 = [0 for x in range(len(valid_dataloaders))] + + train_num = train_datasets[0].__len__() + + net.train() + + start_last = time.time() + gos_dataloader = train_dataloaders[0] + epoch_num = hypar["max_epoch_num"] + notgood_cnt = 0 + for epoch in range(epoch_num): ## set the epoch num as 100000 + + for i, data in enumerate(gos_dataloader): + + if(ite_num >= max_ite): + print("Training Reached the Maximal Iteration Number ", max_ite) + exit() + + # start_read = time.time() + ite_num = ite_num + 1 + ite_num4val = ite_num4val + 1 + + # get the inputs + inputs, labels = data['image'], data['label'] + locations = data['location_blocks'] + 
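            # Flow of the rest of this iteration (sketch): cast image/label/location tensors
            # to the configured precision and move them to the GPU, then either
            #   (a) interm_sup=True: run the frozen GT encoder on the labels and supervise the
            #       decoder features via compute_loss_kl(ds, labels, dfs, fs), or
            #   (b) interm_sup=False: plain deep supervision via compute_loss(ds, labels);
            # finally a soft orthogonality penalty reg * |W W^T - I| over all non-bias weights
            # is added to the loss before backward() and optimizer.step().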
if(hypar["model_digit"]=="full"): + inputs = inputs.type(torch.FloatTensor) + labels = labels.type(torch.FloatTensor) + locations = locations.type(torch.FloatTensor) + else: + inputs = inputs.type(torch.HalfTensor) + labels = labels.type(torch.HalfTensor) + locations = locations.type(torch.HalfTensor) + + # wrap them in Variable + if torch.cuda.is_available(): + inputs_v, labels_v = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(), requires_grad=False) + locations_v = Variable(locations.cuda(), requires_grad=False) + else: + inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False) + locations_v = Variable(locations, requires_grad=False) + + # print("time lapse for data preparation: ", time.time()-start_read, ' s') + + # y zero the parameter gradients + start_inf_loss_back = time.time() + optimizer.zero_grad() + if hypar["interm_sup"]: + # with autocast(): + # forward + backward + optimize + _,fs = featurenet(labels_v) + ds,dfs = net(inputs_v) + ## extract the gt encodings + loss2, loss = net.compute_loss_kl(ds, labels_v, dfs, fs, mode='MSE') + # loss2, loss = net.compute_loss_kl(ds, labels_v, dfs, fs, mode='cosin') + # print(next(featurenet.parameters()).dtype,next(net.parameters()).dtype,labels_v.dtype,fs[0][0].dtype) + # print(ds[0][0].dtype,dfs[0][0].dtype) + # print(loss2.dtype,loss.dtype) + else: + # with autocast(): + # forward + backward + optimize + ds,_ = net(inputs_v) + loss2, loss = net.compute_loss(ds, labels_v) + # loss.backward() + # with torch.autograd.detect_anomaly(): + # scaler.scale(loss).backward() + #ORTHO Loss + reg = 1e-8 + orth_loss = torch.zeros(1).to(device) + for name, param in net.named_parameters(): + if 'bias' not in name: + param_flat = param.view(param.shape[0], -1) + sym = torch.mm(param_flat, torch.t(param_flat)) + sym -= torch.eye(param_flat.shape[0]).to(device) + orth_loss = orth_loss + (reg * sym.abs().sum()) + loss = loss + orth_loss + loss.backward() + # scaler.step(optimizer) + # scaler.update() + optimizer.step() + # torch.cuda.empty_cache() + + # # print statistics + running_loss += loss.item() + running_tar_loss += loss2.item() + + # del outputs, loss + del ds, loss2, loss + end_inf_loss_back = time.time()-start_inf_loss_back + + print(">>>"+model_path.split('/')[-1]+" - [epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f, time-per-iter: %3f s, time_read: %3f" % ( + epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val, time.time()-start_last, time.time()-start_last-end_inf_loss_back)) + start_last = time.time() + + if ite_num % model_save_fre == 0: # validate every 2000 iterations + notgood_cnt += 1 + net.eval() + tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid(net, valid_dataloaders, valid_datasets, hypar, epoch) + torch.cuda.empty_cache() + net.train() # resume train + + tmp_out = 0 + print("last_f1:",last_f1,np.mean(last_f1)) + print("tmp_f1:",tmp_f1,np.mean(tmp_f1)) + if np.mean(tmp_mae)np.mean(last_f1): + last_f1 = tmp_f1 + tmp_out = 1 + print("tmp_out:",tmp_out) + if(tmp_out): + notgood_cnt = 0 + # last_f1 = tmp_f1 + tmp_f1_str = [str(round(f1x,4)) for f1x in tmp_f1] + tmp_mae_str = [str(round(mx,4)) for mx in tmp_mae] + maxf1 = '_'.join(tmp_f1_str) + meanM = '_'.join(tmp_mae_str) + # .cpu().detach().numpy() + model_name = "/gpu_itr_"+str(ite_num)+\ + "_traLoss_"+str(np.round(running_loss / ite_num4val,4))+\ + "_traTarLoss_"+str(np.round(running_tar_loss / ite_num4val,4))+\ + 
"_valLoss_"+str(np.round(val_loss /(i_val+1),4))+\ + "_valTarLoss_"+str(np.round(tar_loss /(i_val+1),4)) + \ + "_maxF1_" + maxf1 + \ + "_mae_" + meanM + \ + "_time_" + str(np.round(np.mean(np.array(tmp_time))/batch_size_valid,6))+".pth" + torch.save(net.state_dict(), model_path + model_name) + + running_loss = 0.0 + running_tar_loss = 0.0 + ite_num4val = 0 + + if(notgood_cnt >= hypar["early_stop"]): + print("No improvements in the last "+str(notgood_cnt)+" validation periods, so training stopped !") + exit() + + print("Training Reaches The Maximum Epoch Number") + +def valid(net, valid_dataloaders, valid_datasets, hypar, epoch=0): + net.eval() + print("Validating...") + epoch_num = hypar["max_epoch_num"] + + val_loss = 0.0 + tar_loss = 0.0 + val_cnt = 0.0 + + tmp_f1 = [] + tmp_mae = [] + tmp_time = [] + + start_valid = time.time() + + for k in range(len(valid_dataloaders)): + + valid_dataloader = valid_dataloaders[k] + valid_dataset = valid_datasets[k] + + val_num = valid_dataset.__len__() + mybins = np.arange(0,256) + PRE = np.zeros((val_num,len(mybins)-1)) + REC = np.zeros((val_num,len(mybins)-1)) + F1 = np.zeros((val_num,len(mybins)-1)) + MAE = np.zeros((val_num)) + + for i_val, data_val in enumerate(valid_dataloader): + val_cnt = val_cnt + 1.0 + imidx_val, inputs_val, labels_val, shapes_val = data_val['imidx'], data_val['image'], data_val['label'], data_val['shape'] + + if(hypar["model_digit"]=="full"): + inputs_val = inputs_val.type(torch.FloatTensor) + labels_val = labels_val.type(torch.FloatTensor) + else: + inputs_val = inputs_val.type(torch.HalfTensor) + labels_val = labels_val.type(torch.HalfTensor) + + # wrap them in Variable + if torch.cuda.is_available(): + inputs_val_v, labels_val_v = Variable(inputs_val.cuda(), requires_grad=False), Variable(labels_val.cuda(), requires_grad=False) + else: + inputs_val_v, labels_val_v = Variable(inputs_val, requires_grad=False), Variable(labels_val,requires_grad=False) + # with autocast(): + t_start = time.time() + ds_val = net(inputs_val_v)[0] + # plt.imshow(inputs_val_v[0][0].cpu().detach()) + # plt.show() + # print(inputs_val_v.cpu().detach().shape) + t_end = time.time()-t_start + tmp_time.append(t_end) + + # loss2_val, loss_val = muti_loss_fusion(ds_val, labels_val_v) + loss2_val, loss_val = net.compute_loss(ds_val, labels_val_v) + + # compute F measure + for t in range(hypar["batch_size_valid"]): + i_test = imidx_val[t].data.numpy() + + pred_val = ds_val[0][t,:,:,:].float() # B x 1 x H x W + + ## recover the prediction spatial size to the orignal image size + pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val,0),(shapes_val[t][0],shapes_val[t][1]),mode='bilinear')) + + # pred_val = normPRED(pred_val) + ma = torch.max(pred_val) + mi = torch.min(pred_val) + pred_val = (pred_val-mi)/(ma-mi) # max = 1 + + gt = np.squeeze(io.imread(valid_dataset.dataset["ori_gt_path"][i_test])) # max = 255 + if gt.max()==1: + gt=gt*255 + + with torch.no_grad(): + gt = torch.tensor(gt).to(device) + + pre,rec,f1,mae = f1_mae_torch(pred_val*255, gt, valid_dataset, i_test, mybins, hypar) + + + PRE[i_test,:]=pre + REC[i_test,:] = rec + F1[i_test,:] = f1 + MAE[i_test] = mae + + del ds_val, gt + gc.collect() + torch.cuda.empty_cache() + + # if(loss_val.data[0]>1): + val_loss += loss_val.item()#data[0] + tar_loss += loss2_val.item()#data[0] + + print("[validating: %5d/%5d] val_ls:%f, tar_ls: %f, f1: %f, mae: %f, time: %f"% (i_val, val_num, val_loss / (i_val + 1), tar_loss / (i_val + 1), np.amax(F1[i_test,:]), MAE[i_test],t_end)) + + del loss2_val, loss_val 
+ + print('============================') + PRE_m = np.mean(PRE,0) + REC_m = np.mean(REC,0) + f1_m = (1+0.3)*PRE_m*REC_m/(0.3*PRE_m+REC_m+1e-8) + + tmp_f1.append(np.amax(f1_m)) + tmp_mae.append(np.mean(MAE)) + + return tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time + +def main(train_datasets, + valid_datasets, + hypar): # model: "train", "test" + + ### --- Step 1: Build datasets and dataloaders --- + dataloaders_train = [] + dataloaders_valid = [] + + if(hypar["mode"]=="train"): + print("--- create training dataloader ---") + ## collect training dataset + train_nm_im_gt_list = get_im_gt_name_dict(train_datasets, flag="train") + ## build dataloader for training datasets + train_dataloaders, train_datasets = create_dataloaders(train_nm_im_gt_list, + cache_size = hypar["cache_size"], + cache_boost = hypar["cache_boost_train"], + my_transforms = [ + GOSRandomHFlip(), ## this line can be uncommented for horizontal flip augmetation + # GOSResize(hypar["input_size"]), + # GOSRandomCrop(hypar["crop_size"]), ## this line can be uncommented for randomcrop augmentation + # GOSNormalize([0.5,0.5,0.5,0,0,0,0,0],[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0]), + GOSNormalize([0.5,0.5,0.5,0,0],[1.0,1.0,1.0,1.0,1.0]), + # GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]), + # GOSNormalize([123.675, 116.28, 103.53],[58.395, 57.12, 57.375]) + ], + batch_size = hypar["batch_size_train"], + shuffle = True, + is_train=True) + train_dataloaders_val, train_datasets_val = create_dataloaders(train_nm_im_gt_list, + cache_size = hypar["cache_size"], + cache_boost = hypar["cache_boost_train"], + my_transforms = [ + # GOSNormalize([0.5,0.5,0.5,0,0,0,0,0],[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0]), + GOSNormalize([0.5,0.5,0.5,0,0],[1.0,1.0,1.0,1.0,1.0]), + # GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]), + # GOSNormalize([123.675, 116.28, 103.53],[58.395, 57.12, 57.375]) + ], + batch_size = hypar["batch_size_valid"], + shuffle = False, + is_train=False) + print(len(train_dataloaders), " train dataloaders created") + + print("--- create valid dataloader ---") + ## build dataloader for validation or testing + valid_nm_im_gt_list = get_im_gt_name_dict(valid_datasets, flag="valid") + ## build dataloader for training datasets + valid_dataloaders, valid_datasets = create_dataloaders(valid_nm_im_gt_list, + cache_size = hypar["cache_size"], + cache_boost = hypar["cache_boost_valid"], + my_transforms = [ + # GOSNormalize([0.5,0.5,0.5,0,0,0,0,0],[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0]), + GOSNormalize([0.5,0.5,0.5,0,0],[1.0,1.0,1.0,1.0,1.0]), + # GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0]), + # GOSNormalize([123.675, 116.28, 103.53],[58.395, 57.12, 57.375]) + # GOSResize(hypar["input_size"]) + ], + batch_size=hypar["batch_size_valid"], + shuffle=False, + is_train=False) + print(len(valid_dataloaders), " valid dataloaders created") + # print(valid_datasets[0]["data_name"]) + + ### --- Step 2: Build Model and Optimizer --- + print("--- build model ---") + net = hypar["model"]#GOSNETINC(3,1) + + # convert to half precision + # if(hypar["model_digit"]=="half"): + # net.half() + + if torch.cuda.is_available(): + net.cuda() + + if(hypar["restore_model"]!=""): + print("restore model from:") + print(hypar["model_path"]+"/"+hypar["restore_model"]) + if torch.cuda.is_available(): + net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"]),strict=False) + else: + net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"],map_location="cpu"),strict=False) + + print("--- define optimizer ---") + # optimizer = 
optim.AdamW(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) + optimizer = optim.AdamW(net.parameters(), lr=4e-5, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) + ### --- Step 3: Train or Valid Model --- + if(hypar["mode"]=="train"): + train(net, + optimizer, + train_dataloaders, + train_datasets, + valid_dataloaders, + valid_datasets, + hypar, + train_dataloaders_val, train_datasets_val) + else: + valid(net, + valid_dataloaders, + valid_datasets, + hypar) + + +if __name__ == "__main__": + + ### --------------- STEP 1: Configuring the Train, Valid and Test datasets --------------- + ## configure the train, valid and inference datasets + train_datasets, valid_datasets = [], [] + + valid_datasets = [dataset_test] ## users can create mutiple dictionary for setting a list of datasets as vaidation sets or inference sets + train_datasets = [dataset_test] ## users can create mutiple dictionary for setting a list of datasets as training set + + + ### --------------- STEP 2: Configuring the hyperparamters for Training, validation and inferencing --------------- + hypar = {} + + ## -- 2.1. configure the model saving or restoring path -- + hypar["mode"] = "train" + ## "train": for training, + ## "valid": for validation and inferening, + ## in "valid" mode, it will calculate the accuracy as well as save the prediciton results into the "hypar["valid_out_dir"]", which shouldn't be "" + ## otherwise only accuracy will bee calculated and no predictions will be saved + hypar["interm_sup"] = True ## in-dicate if activate intermediate feature supervision + + if hypar["mode"] == "train": + hypar["valid_out_dir"] = "" ## for "train" model leave it as "", for "valid"("inference") mode: set it according to your local directory + hypar["model_path"] ="./saved_models" ## model weights saving (or restoring) path + hypar["restore_model"] = "" ## name of the segmentation model weights .pth for resume training process from last stop or for the inferencing + hypar["start_ite"] = 0 ## start iteration for the training, can be changed to match the restored training process + hypar["gt_encoder_model"] = "" + else: ## configure the segmentation output path and the to-be-used model weights path + hypar["valid_out_dir"] = "./your-results/"##".D:/Code/Design_for_graduation/DIS-main/IS-Net/DIS5K-Results-test" ## output inferenced segmentation maps into this fold + hypar["model_path"] = "./saved_models" ## load trained weights from this path + hypar["restore_model"] = "gpu_itr_102000_traLoss_2.5701_traTarLoss_0.0248_valLoss_2.3643_valTarLoss_0.3743_maxF1_0.8063_mae_0.0825_time_0.015695.pth"##"isnet.pth" ## name of the to-be-loaded weights + + # if hypar["restore_model"]!="": + # hypar["start_ite"] = int(hypar["restore_model"].split("_")[2]) + + ## -- 2.2. choose floating point accuracy -- + hypar["model_digit"] = "full" ## indicates "half" or "full" accuracy of float number + hypar["seed"] = 0 + + ## -- 2.3. 
cache data spatial size -- + ## To handle large size input images, which take a lot of time for loading in training, + # we introduce the cache mechanism for pre-convering and resizing the jpg and png images into .pt file + hypar["cache_size"] = [1024, 1024] ## cached input spatial resolution, can be configured into different size + hypar["cache_boost_train"] = False ## "True" or "False", indicates wheather to load all the training datasets into RAM, True will greatly speed the training process while requires more RAM + hypar["cache_boost_valid"] = False ## "True" or "False", indicates wheather to load all the validation datasets into RAM, True will greatly speed the training process while requires more RAM + + ## --- 2.4. data augmentation parameters --- + hypar["input_size"] = [1024, 1024] ## mdoel input spatial size, usually use the same value hypar["cache_size"], which means we don't further resize the images + hypar["crop_size"] = [1024, 1024] ## random crop size from the input, it is usually set as smaller than hypar["cache_size"], e.g., [920,920] for data augmentation + hypar["random_flip_h"] = 1 ## horizontal flip, currently hard coded in the datader and it is not in use + hypar["random_flip_v"] = 1 ## vertical flip , currently not in use + + ## --- 2.5. define model --- + print("building model...") + hypar["model"] = ISNetDIS(in_ch=5) #U2NETFASTFEATURESUP() + hypar["early_stop"] = 20 ## stop the training when no improvement in the past 20 validation periods, smaller numbers can be used here e.g., 5 or 10. + hypar["model_save_fre"] = 3000 ## valid and save model weights every 2000 iterations + + hypar["batch_size_train"] = 6 ## batch size for training + hypar["batch_size_valid"] = 1 ## batch size for validation and inferencing + print("batch size: ", hypar["batch_size_train"]) + + hypar["max_ite"] = 50000000 ## if early stop couldn't stop the training process, stop it by the max_ite_num + hypar["max_epoch_num"] = 500000 ## if early stop and max_ite couldn't stop the training process, stop it by the max_epoch_num + + main(train_datasets, + valid_datasets, + hypar=hypar) + diff --git a/MultiScaleDeformableAttention-1.0-py3-none-any.whl b/MultiScaleDeformableAttention-1.0-py3-none-any.whl new file mode 100644 index 0000000000000000000000000000000000000000..a70ffc187e3e03d8730736f46f5c3dd9e98506f9 --- /dev/null +++ b/MultiScaleDeformableAttention-1.0-py3-none-any.whl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:152caec7860d1f39f644ac5eed946b5a4eecfad40764396345b3d0e516921b17 +size 2048806 diff --git a/README.md b/README.md index c5c03a5b6c962b52507db95bfc30e2b781395f2f..ec3d012b5a33c1f4df6ad429180826f1f5aef6c7 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,8 @@ app_file: app.py pinned: false license: mit short_description: SAM-prompted dichotomous segmentation. No affiliation. +python_version: 3.11 +preload_from_hub: + - jwlarocque/DIS-SAM DIS-SAM-checkpoint.pth + - andzhang01/segment_anything sam_vit_l_0b3195.pth --- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference \ No newline at end of file diff --git a/SAM/segment_anything/__init__.py b/SAM/segment_anything/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..34383d83f5e76bc801f31b20e5651e383be348b6 --- /dev/null +++ b/SAM/segment_anything/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+ +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .build_sam import ( + build_sam, + build_sam_vit_h, + build_sam_vit_l, + build_sam_vit_b, + sam_model_registry, +) +from .predictor import SamPredictor +from .automatic_mask_generator import SamAutomaticMaskGenerator diff --git a/SAM/segment_anything/__pycache__/__init__.cpython-311.pyc b/SAM/segment_anything/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c201c3514e9d3aba34274f8e56b032c8723341f Binary files /dev/null and b/SAM/segment_anything/__pycache__/__init__.cpython-311.pyc differ diff --git a/SAM/segment_anything/__pycache__/automatic_mask_generator.cpython-311.pyc b/SAM/segment_anything/__pycache__/automatic_mask_generator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a13b22d1efa2308ecccda3b93531d2044f7029a6 Binary files /dev/null and b/SAM/segment_anything/__pycache__/automatic_mask_generator.cpython-311.pyc differ diff --git a/SAM/segment_anything/__pycache__/build_sam.cpython-311.pyc b/SAM/segment_anything/__pycache__/build_sam.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..653771ea04b22ac865aac84386f05038229bbaaf Binary files /dev/null and b/SAM/segment_anything/__pycache__/build_sam.cpython-311.pyc differ diff --git a/SAM/segment_anything/__pycache__/predictor.cpython-311.pyc b/SAM/segment_anything/__pycache__/predictor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1120416848bf4e22cbb22a9c874035783b50a815 Binary files /dev/null and b/SAM/segment_anything/__pycache__/predictor.cpython-311.pyc differ diff --git a/SAM/segment_anything/automatic_mask_generator.py b/SAM/segment_anything/automatic_mask_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..d5a8c969207f119feff7087f94e044403acdff00 --- /dev/null +++ b/SAM/segment_anything/automatic_mask_generator.py @@ -0,0 +1,372 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torchvision.ops.boxes import batched_nms, box_area # type: ignore + +from typing import Any, Dict, List, Optional, Tuple + +from .modeling import Sam +from .predictor import SamPredictor +from .utils.amg import ( + MaskData, + area_from_rle, + batch_iterator, + batched_mask_to_box, + box_xyxy_to_xywh, + build_all_layer_point_grids, + calculate_stability_score, + coco_encode_rle, + generate_crop_boxes, + is_box_near_crop_edge, + mask_to_rle_pytorch, + remove_small_regions, + rle_to_mask, + uncrop_boxes_xyxy, + uncrop_masks, + uncrop_points, +) + + +class SamAutomaticMaskGenerator: + def __init__( + self, + model: Sam, + points_per_side: Optional[int] = 32, + points_per_batch: int = 64, + pred_iou_thresh: float = 0.88, + stability_score_thresh: float = 0.95, + stability_score_offset: float = 1.0, + box_nms_thresh: float = 0.7, + crop_n_layers: int = 0, + crop_nms_thresh: float = 0.7, + crop_overlap_ratio: float = 512 / 1500, + crop_n_points_downscale_factor: int = 1, + point_grids: Optional[List[np.ndarray]] = None, + min_mask_region_area: int = 0, + output_mode: str = "binary_mask", + ) -> None: + """ + Using a SAM model, generates masks for the entire image. 
+ Generates a grid of point prompts over the image, then filters + low quality and duplicate masks. The default settings are chosen + for SAM with a ViT-H backbone. + + Arguments: + model (Sam): The SAM model to use for mask prediction. + points_per_side (int or None): The number of points to be sampled + along one side of the image. The total number of points is + points_per_side**2. If None, 'point_grids' must provide explicit + point sampling. + points_per_batch (int): Sets the number of points run simultaneously + by the model. Higher numbers may be faster but use more GPU memory. + pred_iou_thresh (float): A filtering threshold in [0,1], using the + model's predicted mask quality. + stability_score_thresh (float): A filtering threshold in [0,1], using + the stability of the mask under changes to the cutoff used to binarize + the model's mask predictions. + stability_score_offset (float): The amount to shift the cutoff when + calculated the stability score. + box_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks. + crop_n_layers (int): If >0, mask prediction will be run again on + crops of the image. Sets the number of layers to run, where each + layer has 2**i_layer number of image crops. + crop_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks between different crops. + crop_overlap_ratio (float): Sets the degree to which crops overlap. + In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + crop_n_points_downscale_factor (int): The number of points-per-side + sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + point_grids (list(np.ndarray) or None): A list over explicit grids + of points used for sampling, normalized to [0,1]. The nth grid in the + list is used in the nth crop layer. Exclusive with points_per_side. + min_mask_region_area (int): If >0, postprocessing will be applied + to remove disconnected regions and holes in masks with area smaller + than min_mask_region_area. Requires opencv. + output_mode (str): The form masks are returned in. Can be 'binary_mask', + 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. + For large resolutions, 'binary_mask' may consume large amounts of + memory. + """ + + assert (points_per_side is None) != ( + point_grids is None + ), "Exactly one of points_per_side or point_grid must be provided." + if points_per_side is not None: + self.point_grids = build_all_layer_point_grids( + points_per_side, + crop_n_layers, + crop_n_points_downscale_factor, + ) + elif point_grids is not None: + self.point_grids = point_grids + else: + raise ValueError("Can't have both points_per_side and point_grid be None.") + + assert output_mode in [ + "binary_mask", + "uncompressed_rle", + "coco_rle", + ], f"Unknown output_mode {output_mode}." 
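As a usage sketch for the generator configured above: it assumes the vendored package is importable as SAM.segment_anything and that a ViT-L checkpoint (for example the sam_vit_l_0b3195.pth preloaded via the README) sits next to the script; the image path is a placeholder.

import cv2  # only needed here to load an example image
from SAM.segment_anything import sam_model_registry, SamAutomaticMaskGenerator

sam = sam_model_registry["vit_l"](checkpoint="sam_vit_l_0b3195.pth")
mask_generator = SamAutomaticMaskGenerator(
    sam,
    points_per_side=32,
    pred_iou_thresh=0.90,         # slightly stricter than the 0.88 default
    stability_score_thresh=0.95,
    min_mask_region_area=100,     # drop tiny regions and holes (requires opencv)
)
image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)  # placeholder path
masks = mask_generator.generate(image)  # list of per-mask record dicts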
+ if output_mode == "coco_rle": + from pycocotools import mask as mask_utils # type: ignore # noqa: F401 + + if min_mask_region_area > 0: + import cv2 # type: ignore # noqa: F401 + + self.predictor = SamPredictor(model) + self.points_per_batch = points_per_batch + self.pred_iou_thresh = pred_iou_thresh + self.stability_score_thresh = stability_score_thresh + self.stability_score_offset = stability_score_offset + self.box_nms_thresh = box_nms_thresh + self.crop_n_layers = crop_n_layers + self.crop_nms_thresh = crop_nms_thresh + self.crop_overlap_ratio = crop_overlap_ratio + self.crop_n_points_downscale_factor = crop_n_points_downscale_factor + self.min_mask_region_area = min_mask_region_area + self.output_mode = output_mode + + @torch.no_grad() + def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: + """ + Generates masks for the given image. + + Arguments: + image (np.ndarray): The image to generate masks for, in HWC uint8 format. + + Returns: + list(dict(str, any)): A list over records for masks. Each record is + a dict containing the following keys: + segmentation (dict(str, any) or np.ndarray): The mask. If + output_mode='binary_mask', is an array of shape HW. Otherwise, + is a dictionary containing the RLE. + bbox (list(float)): The box around the mask, in XYWH format. + area (int): The area in pixels of the mask. + predicted_iou (float): The model's own prediction of the mask's + quality. This is filtered by the pred_iou_thresh parameter. + point_coords (list(list(float))): The point coordinates input + to the model to generate this mask. + stability_score (float): A measure of the mask's quality. This + is filtered on using the stability_score_thresh parameter. + crop_box (list(float)): The crop of the image used to generate + the mask, given in XYWH format. 
+ """ + + # Generate masks + mask_data = self._generate_masks(image) + + # Filter small disconnected regions and holes in masks + if self.min_mask_region_area > 0: + mask_data = self.postprocess_small_regions( + mask_data, + self.min_mask_region_area, + max(self.box_nms_thresh, self.crop_nms_thresh), + ) + + # Encode masks + if self.output_mode == "coco_rle": + mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] + elif self.output_mode == "binary_mask": + mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] + else: + mask_data["segmentations"] = mask_data["rles"] + + # Write mask records + curr_anns = [] + for idx in range(len(mask_data["segmentations"])): + ann = { + "segmentation": mask_data["segmentations"][idx], + "area": area_from_rle(mask_data["rles"][idx]), + "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), + "predicted_iou": mask_data["iou_preds"][idx].item(), + "point_coords": [mask_data["points"][idx].tolist()], + "stability_score": mask_data["stability_score"][idx].item(), + "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), + } + curr_anns.append(ann) + + return curr_anns + + def _generate_masks(self, image: np.ndarray) -> MaskData: + orig_size = image.shape[:2] + crop_boxes, layer_idxs = generate_crop_boxes( + orig_size, self.crop_n_layers, self.crop_overlap_ratio + ) + + # Iterate over image crops + data = MaskData() + for crop_box, layer_idx in zip(crop_boxes, layer_idxs): + crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) + data.cat(crop_data) + + # Remove duplicate masks between crops + if len(crop_boxes) > 1: + # Prefer masks from smaller crops + scores = 1 / box_area(data["crop_boxes"]) + scores = scores.to(data["boxes"].device) + keep_by_nms = batched_nms( + data["boxes"].float(), + scores, + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.crop_nms_thresh, + ) + data.filter(keep_by_nms) + + data.to_numpy() + return data + + def _process_crop( + self, + image: np.ndarray, + crop_box: List[int], + crop_layer_idx: int, + orig_size: Tuple[int, ...], + ) -> MaskData: + # Crop the image and calculate embeddings + x0, y0, x1, y1 = crop_box + cropped_im = image[y0:y1, x0:x1, :] + cropped_im_size = cropped_im.shape[:2] + self.predictor.set_image(cropped_im) + + # Get points for this crop + points_scale = np.array(cropped_im_size)[None, ::-1] + points_for_image = self.point_grids[crop_layer_idx] * points_scale + + # Generate masks for this crop in batches + data = MaskData() + for (points,) in batch_iterator(self.points_per_batch, points_for_image): + batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size) + data.cat(batch_data) + del batch_data + self.predictor.reset_image() + + # Remove duplicates within this crop. 
+ keep_by_nms = batched_nms( + data["boxes"].float(), + data["iou_preds"], + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.box_nms_thresh, + ) + data.filter(keep_by_nms) + + # Return to the original image frame + data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) + data["points"] = uncrop_points(data["points"], crop_box) + data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) + + return data + + def _process_batch( + self, + points: np.ndarray, + im_size: Tuple[int, ...], + crop_box: List[int], + orig_size: Tuple[int, ...], + ) -> MaskData: + orig_h, orig_w = orig_size + + # Run model on this batch + transformed_points = self.predictor.transform.apply_coords(points, im_size) + in_points = torch.as_tensor(transformed_points, device=self.predictor.device) + in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) + masks, iou_preds, _ = self.predictor.predict_torch( + in_points[:, None, :], + in_labels[:, None], + multimask_output=True, + return_logits=True, + ) + + # Serialize predictions and store in MaskData + data = MaskData( + masks=masks.flatten(0, 1), + iou_preds=iou_preds.flatten(0, 1), + points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), + ) + del masks + + # Filter by predicted IoU + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + # Calculate stability score + data["stability_score"] = calculate_stability_score( + data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + + # Threshold masks and calculate boxes + data["masks"] = data["masks"] > self.predictor.model.mask_threshold + data["boxes"] = batched_mask_to_box(data["masks"]) + + # Filter boxes that touch crop boundaries + keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) + if not torch.all(keep_mask): + data.filter(keep_mask) + + # Compress to RLE + data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) + data["rles"] = mask_to_rle_pytorch(data["masks"]) + del data["masks"] + + return data + + @staticmethod + def postprocess_small_regions( + mask_data: MaskData, min_area: int, nms_thresh: float + ) -> MaskData: + """ + Removes small disconnected regions and holes in masks, then reruns + box NMS to remove any new duplicates. + + Edits mask_data in place. + + Requires open-cv as a dependency. 
+ """ + if len(mask_data["rles"]) == 0: + return mask_data + + # Filter small disconnected regions and holes + new_masks = [] + scores = [] + for rle in mask_data["rles"]: + mask = rle_to_mask(rle) + + mask, changed = remove_small_regions(mask, min_area, mode="holes") + unchanged = not changed + mask, changed = remove_small_regions(mask, min_area, mode="islands") + unchanged = unchanged and not changed + + new_masks.append(torch.as_tensor(mask).unsqueeze(0)) + # Give score=0 to changed masks and score=1 to unchanged masks + # so NMS will prefer ones that didn't need postprocessing + scores.append(float(unchanged)) + + # Recalculate boxes and remove any new duplicates + masks = torch.cat(new_masks, dim=0) + boxes = batched_mask_to_box(masks) + keep_by_nms = batched_nms( + boxes.float(), + torch.as_tensor(scores), + torch.zeros_like(boxes[:, 0]), # categories + iou_threshold=nms_thresh, + ) + + # Only recalculate RLEs for masks that have changed + for i_mask in keep_by_nms: + if scores[i_mask] == 0.0: + mask_torch = masks[i_mask].unsqueeze(0) + mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] + mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly + mask_data.filter(keep_by_nms) + + return mask_data diff --git a/SAM/segment_anything/build_sam.py b/SAM/segment_anything/build_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..0e9b257cc99c2ea23f5fdd69838ec4cdd4c78d16 --- /dev/null +++ b/SAM/segment_anything/build_sam.py @@ -0,0 +1,111 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from functools import partial + +from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer + + +def build_sam_vit_h(checkpoint=None, device="cpu"): + return _build_sam( + encoder_embed_dim=1280, + encoder_depth=32, + encoder_num_heads=16, + encoder_global_attn_indexes=[7, 15, 23, 31], + checkpoint=checkpoint, + device=device, + ) + + +build_sam = build_sam_vit_h + + +def build_sam_vit_l(checkpoint=None, device="cpu"): + return _build_sam( + encoder_embed_dim=1024, + encoder_depth=24, + encoder_num_heads=16, + encoder_global_attn_indexes=[5, 11, 17, 23], + checkpoint=checkpoint, + device=device, + ) + + +def build_sam_vit_b(checkpoint=None, device="cpu"): + return _build_sam( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + checkpoint=checkpoint, + device=device, + ) + + +sam_model_registry = { + "default": build_sam_vit_h, + "vit_h": build_sam_vit_h, + "vit_l": build_sam_vit_l, + "vit_b": build_sam_vit_b, +} + + +def _build_sam( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, + device="cpu" +): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + sam = Sam( + image_encoder=ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + 
input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + sam.eval() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f, map_location=device) + sam.load_state_dict(state_dict) + return sam diff --git a/SAM/segment_anything/modeling/__init__.py b/SAM/segment_anything/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..38e906243d898d7fc071c0fe218338c5cace3ea1 --- /dev/null +++ b/SAM/segment_anything/modeling/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .sam import Sam +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder +from .transformer import TwoWayTransformer diff --git a/SAM/segment_anything/modeling/__pycache__/__init__.cpython-311.pyc b/SAM/segment_anything/modeling/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1d7cfc9528d799e2225c79595b0bbba8f1c0248 Binary files /dev/null and b/SAM/segment_anything/modeling/__pycache__/__init__.cpython-311.pyc differ diff --git a/SAM/segment_anything/modeling/__pycache__/common.cpython-311.pyc b/SAM/segment_anything/modeling/__pycache__/common.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2c5d482fcbf50d176f596a43dde959db643342f Binary files /dev/null and b/SAM/segment_anything/modeling/__pycache__/common.cpython-311.pyc differ diff --git a/SAM/segment_anything/modeling/__pycache__/image_encoder.cpython-311.pyc b/SAM/segment_anything/modeling/__pycache__/image_encoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fcffc3b2646b8304d16cb8f1bcc50fe96ab37bf Binary files /dev/null and b/SAM/segment_anything/modeling/__pycache__/image_encoder.cpython-311.pyc differ diff --git a/SAM/segment_anything/modeling/__pycache__/mask_decoder.cpython-311.pyc b/SAM/segment_anything/modeling/__pycache__/mask_decoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52512b0185eaa3996b37eab849b52e0f085dc035 Binary files /dev/null and b/SAM/segment_anything/modeling/__pycache__/mask_decoder.cpython-311.pyc differ diff --git a/SAM/segment_anything/modeling/__pycache__/prompt_encoder.cpython-311.pyc b/SAM/segment_anything/modeling/__pycache__/prompt_encoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2bebbea3f9f2f583dac251a8ee8217226129fa96 Binary files /dev/null and b/SAM/segment_anything/modeling/__pycache__/prompt_encoder.cpython-311.pyc differ diff --git a/SAM/segment_anything/modeling/__pycache__/sam.cpython-311.pyc b/SAM/segment_anything/modeling/__pycache__/sam.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbbd16eb24652f254d2e19990b66a0924e551bd2 Binary files /dev/null and b/SAM/segment_anything/modeling/__pycache__/sam.cpython-311.pyc differ diff --git a/SAM/segment_anything/modeling/__pycache__/transformer.cpython-311.pyc 
b/SAM/segment_anything/modeling/__pycache__/transformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5475ee2a0f82304aefec94c296bf04fca6564516 Binary files /dev/null and b/SAM/segment_anything/modeling/__pycache__/transformer.cpython-311.pyc differ diff --git a/SAM/segment_anything/modeling/common.py b/SAM/segment_anything/modeling/common.py new file mode 100644 index 0000000000000000000000000000000000000000..2bf15236a3eb24d8526073bc4fa2b274cccb3f96 --- /dev/null +++ b/SAM/segment_anything/modeling/common.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from typing import Type + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/SAM/segment_anything/modeling/image_encoder.py b/SAM/segment_anything/modeling/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..66351d9d7c589be693f4b3485901d3bdfed54d4a --- /dev/null +++ b/SAM/segment_anything/modeling/image_encoder.py @@ -0,0 +1,395 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Type + +from .common import LayerNorm2d, MLPBlock + + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. 
+ depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. + """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + + for blk in self.blocks: + x = blk(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + + return x + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. 
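A quick shape sanity check for the ImageEncoderViT defined above, using the ViT-B settings from build_sam but a smaller img_size (an assumption made only to keep the example light); weights are random.

import torch
from functools import partial
from SAM.segment_anything.modeling import ImageEncoderViT

enc = ImageEncoderViT(
    img_size=256, patch_size=16, embed_dim=768, depth=12, num_heads=12,
    mlp_ratio=4, out_chans=256, qkv_bias=True,
    norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
    use_rel_pos=True, global_attn_indexes=[2, 5, 8, 11], window_size=14,
)
with torch.no_grad():
    feats = enc(torch.zeros(1, 3, 256, 256))
print(feats.shape)  # torch.Size([1, 256, 16, 16]): B x out_chans x H/16 x W/16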
+ """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + +def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. 
+ (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. 
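A round-trip sketch for the two helpers above, checking that window_unpartition exactly inverts window_partition when H and W are not multiples of the window size.

import torch
from SAM.segment_anything.modeling.image_encoder import window_partition, window_unpartition

x = torch.randn(2, 30, 45, 8)              # B x H x W x C, neither 30 nor 45 divisible by 14
windows, pad_hw = window_partition(x, 14)  # pads to (42, 56), then cuts 3 x 4 windows
print(windows.shape)                       # torch.Size([24, 14, 14, 8])
y = window_unpartition(windows, 14, pad_hw, (30, 45))
assert torch.equal(x, y)                   # padding is stripped, values are unchanged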
+ """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/SAM/segment_anything/modeling/mask_decoder.py b/SAM/segment_anything/modeling/mask_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2fdb03d535a91fa725d1ec4e92a7a1f217dfe0 --- /dev/null +++ b/SAM/segment_anything/modeling/mask_decoder.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import List, Tuple, Type + +from .common import LayerNorm2d + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. 
+ + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth + ) + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + """ + masks, iou_pred = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + ) + + # Select the correct mask or masks for output + if multimask_output: + mask_slice = slice(1, None) + else: + mask_slice = slice(0, 1) + masks = masks[:, mask_slice, :, :] + iou_pred = iou_pred[:, mask_slice] + + # Prepare output + return masks, iou_pred + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. 
See 'forward' for more details.""" + # Concatenate output tokens + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + upscaled_embedding = self.output_upscaling(src) + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/SAM/segment_anything/modeling/prompt_encoder.py b/SAM/segment_anything/modeling/prompt_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c3143f4f8e02ddd7ca8587b40ff5d47c3a6b7ef3 --- /dev/null +++ b/SAM/segment_anything/modeling/prompt_encoder.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch import nn + +from typing import Any, Optional, Tuple, Type + +from .common import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. 
+ """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. 
+ + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C diff --git a/SAM/segment_anything/modeling/sam.py b/SAM/segment_anything/modeling/sam.py new file mode 100644 index 0000000000000000000000000000000000000000..8074cff6b40addc6b66f7ab4962218eef20da13c --- /dev/null +++ b/SAM/segment_anything/modeling/sam.py @@ -0,0 +1,174 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
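A shape sketch for the PromptEncoder defined above, using SAM's usual sizes (256-d embeddings over a 64x64 embedding grid for a 1024x1024 input) and a single foreground point prompt; weights are random.

import torch
from SAM.segment_anything.modeling import PromptEncoder

pe = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(1024, 1024),
    mask_in_chans=16,
)
points = torch.tensor([[[512.0, 512.0]]])  # B=1, N=1 point, (x, y) in the input frame
labels = torch.tensor([[1]])               # 1 = foreground, 0 = background
sparse, dense = pe(points=(points, labels), boxes=None, masks=None)
print(sparse.shape)  # torch.Size([1, 2, 256]): the point plus one padding point
print(dense.shape)   # torch.Size([1, 256, 64, 64]): the learned no-mask embedding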
+ +import torch +from torch import nn +from torch.nn import functional as F + +from typing import Any, Dict, List, Tuple + +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder + + +class Sam(nn.Module): + mask_threshold: float = 0.0 + image_format: str = "RGB" + + def __init__( + self, + image_encoder: ImageEncoderViT, + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: List[float] = [123.675, 116.28, 103.53], + pixel_std: List[float] = [58.395, 57.12, 57.375], + ) -> None: + """ + SAM predicts object masks from an image and input prompts. + + Arguments: + image_encoder (ImageEncoderViT): The backbone used to encode the + image into image embeddings that allow for efficient mask prediction. + prompt_encoder (PromptEncoder): Encodes various types of input prompts. + mask_decoder (MaskDecoder): Predicts masks from the image embeddings + and encoded prompts. + pixel_mean (list(float)): Mean values for normalizing pixels in the input image. + pixel_std (list(float)): Std values for normalizing pixels in the input image. + """ + super().__init__() + self.image_encoder = image_encoder + self.prompt_encoder = prompt_encoder + self.mask_decoder = mask_decoder + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + @property + def device(self) -> Any: + return self.pixel_mean.device + + @torch.no_grad() + def forward( + self, + batched_input: List[Dict[str, Any]], + multimask_output: bool, + ) -> List[Dict[str, torch.Tensor]]: + """ + Predicts masks end-to-end from provided images and prompts. + If prompts are not known in advance, using SamPredictor is + recommended over calling the model directly. + + Arguments: + batched_input (list(dict)): A list over input images, each a + dictionary with the following keys. A prompt key can be + excluded if it is not present. + 'image': The image as a torch tensor in 3xHxW format, + already transformed for input to the model. + 'original_size': (tuple(int, int)) The original size of + the image before transformation, as (H, W). + 'point_coords': (torch.Tensor) Batched point prompts for + this image, with shape BxNx2. Already transformed to the + input frame of the model. + 'point_labels': (torch.Tensor) Batched labels for point prompts, + with shape BxN. + 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. + Already transformed to the input frame of the model. + 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, + in the form Bx1xHxW. + multimask_output (bool): Whether the model should predict multiple + disambiguating masks, or return a single mask. + + Returns: + (list(dict)): A list over input images, where each element is + as dictionary with the following keys. + 'masks': (torch.Tensor) Batched binary mask predictions, + with shape BxCxHxW, where B is the number of input prompts, + C is determined by multimask_output, and (H, W) is the + original size of the image. + 'iou_predictions': (torch.Tensor) The model's predictions + of mask quality, in shape BxC. + 'low_res_logits': (torch.Tensor) Low resolution logits with + shape BxCxHxW, where H=W=256. Can be passed as mask input + to subsequent iterations of prediction. 
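To make the batched_input format above concrete, a sketch of one entry with a single box prompt; sam is assumed to be a model returned by one of the build_sam_vit_* constructors, and the sizes are chosen so that a (2048, 1024) image resizes to the 1024-long-side model frame.

import torch

batched_input = [
    {
        "image": torch.zeros(3, 1024, 512),                    # 3xHxW, already resized to the model frame
        "original_size": (2048, 1024),                         # (H, W) before resizing
        "boxes": torch.tensor([[50.0, 100.0, 400.0, 900.0]]),  # Bx4 XYXY, in the model frame
    }
]
# outputs = sam(batched_input, multimask_output=False)
# outputs[0]["masks"]: boolean masks of shape (1, 1, 2048, 1024)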
+ """ + input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) + image_embeddings = self.image_encoder(input_images) + + outputs = [] + for image_record, curr_embedding in zip(batched_input, image_embeddings): + if "point_coords" in image_record: + points = (image_record["point_coords"], image_record["point_labels"]) + else: + points = None + sparse_embeddings, dense_embeddings = self.prompt_encoder( + points=points, + boxes=image_record.get("boxes", None), + masks=image_record.get("mask_inputs", None), + ) + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=curr_embedding.unsqueeze(0), + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + masks = self.postprocess_masks( + low_res_masks, + input_size=image_record["image"].shape[-2:], + original_size=image_record["original_size"], + ) + masks = masks > self.mask_threshold + outputs.append( + { + "masks": masks, + "iou_predictions": iou_predictions, + "low_res_logits": low_res_masks, + } + ) + return outputs + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: Tuple[int, ...], + original_size: Tuple[int, ...], + ) -> torch.Tensor: + """ + Remove padding and upscale masks to the original image size. + + Arguments: + masks (torch.Tensor): Batched masks from the mask_decoder, + in BxCxHxW format. + input_size (tuple(int, int)): The size of the image input to the + model, in (H, W) format. Used to remove padding. + original_size (tuple(int, int)): The original size of the image + before resizing for input to the model, in (H, W) format. + + Returns: + (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) + is given by original_size. + """ + masks = F.interpolate( + masks, + (self.image_encoder.img_size, self.image_encoder.img_size), + mode="bilinear", + align_corners=False, + ) + masks = masks[..., : input_size[0], : input_size[1]] + masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) + return masks + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.image_encoder.img_size - h + padw = self.image_encoder.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x diff --git a/SAM/segment_anything/modeling/transformer.py b/SAM/segment_anything/modeling/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..28fafea52288603fea275f3a100790471825c34a --- /dev/null +++ b/SAM/segment_anything/modeling/transformer.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import Tensor, nn + +import math +from typing import Tuple, Type + +from .common import MLPBlock + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. 
+ + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. 
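A dummy-tensor sketch of the shape contract documented in TwoWayTransformer.forward above; weights are random and the sizes are arbitrary.

import torch
from SAM.segment_anything.modeling import TwoWayTransformer

t = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)
image_embedding = torch.randn(1, 256, 64, 64)  # B x C x H x W
image_pe = torch.randn(1, 256, 64, 64)         # same shape as image_embedding
point_embedding = torch.randn(1, 7, 256)       # B x N_points x C
with torch.no_grad():
    queries, keys = t(image_embedding, image_pe, point_embedding)
print(queries.shape)  # torch.Size([1, 7, 256]): processed point embeddings
print(keys.shape)     # torch.Size([1, 4096, 256]): flattened image tokens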
+ + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 
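# Editor's note (illustrative, not part of the upstream file): with the values this
# transformer is built with elsewhere in the decoder, e.g. embedding_dim=256, num_heads=8
# and downsample_rate=2 for the cross-attention layers, the projections below map tokens
# to internal_dim = 256 // 2 = 128, giving 128 // 8 = 16 channels per head. The assert
# above therefore checks divisibility of the downsampled dimension, which is the quantity
# actually split across heads.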
+ + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/SAM/segment_anything/predictor.py b/SAM/segment_anything/predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..c276fd233a77fe08eb41d4c143ce119bf242626e --- /dev/null +++ b/SAM/segment_anything/predictor.py @@ -0,0 +1,268 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +from .modeling import Sam + +from typing import Optional, Tuple + +from .utils.transforms import ResizeLongestSide + + +class SamPredictor: + def __init__( + self, + sam_model: Sam, + ) -> None: + """ + Uses SAM to calculate the image embedding for an image, and then + allow repeated, efficient mask prediction given prompts. + + Arguments: + sam_model (Sam): The model to use for mask prediction. + """ + super().__init__() + self.model = sam_model + self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) + self.reset_image() + + def set_image( + self, + image: np.ndarray, + image_format: str = "RGB", + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. + + Arguments: + image (np.ndarray): The image for calculating masks. Expects an + image in HWC uint8 format, with pixel values in [0, 255]. + image_format (str): The color format of the image, in ['RGB', 'BGR']. + """ + assert image_format in [ + "RGB", + "BGR", + ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." + if image_format != self.model.image_format: + image = image[..., ::-1] + + # Transform the image to the form expected by the model + input_image = self.transform.apply_image(image) + input_image_torch = torch.as_tensor(input_image, device=self.device) + input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] + + self.set_torch_image(input_image_torch, image.shape[:2]) + + @torch.no_grad() + def set_torch_image( + self, + transformed_image: torch.Tensor, + original_image_size: Tuple[int, ...], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. 
Expects the input + image to be already transformed to the format expected by the model. + + Arguments: + transformed_image (torch.Tensor): The input image, with shape + 1x3xHxW, which has been transformed with ResizeLongestSide. + original_image_size (tuple(int, int)): The size of the image + before transformation, in (H, W) format. + """ + assert ( + len(transformed_image.shape) == 4 + and transformed_image.shape[1] == 3 + and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size + ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." + self.reset_image() + + self.original_size = original_image_size + self.input_size = tuple(transformed_image.shape[-2:]) + input_image = self.model.preprocess(transformed_image) + self.features = self.model.image_encoder(input_image) + self.is_image_set = True + + def predict( + self, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Predict masks for the given input prompts, using the currently set image. + + Arguments: + point_coords (np.ndarray or None): A Nx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (np.ndarray or None): A length N array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A length 4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + # Transform input prompts + coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." 
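# Editor's sketch: a minimal usage example for the predict() API documented above, kept
# outside the file itself. The checkpoint filename, image path and box coordinates are
# placeholders; the registry call mirrors the one used in app.py later in this diff.
import cv2
import numpy as np
import torch

from SAM.segment_anything import sam_model_registry, SamPredictor

device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry["vit_l"](checkpoint="sam_vit_l_0b3195.pth", device=device)
sam.to(device=device)
predictor = SamPredictor(sam)

image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)  # HWC uint8, RGB
predictor.set_image(image)                                          # embed once, prompt many times
masks, scores, low_res_logits = predictor.predict(
    box=np.array([50, 60, 400, 500]),  # XYXY box prompt in original-image pixels
    multimask_output=True,             # three candidate masks plus quality scores
)
best_mask = masks[np.argmax(scores)]   # CxHxW booleans -> pick the highest-scoring mask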
+ point_coords = self.transform.apply_coords(point_coords, self.original_size) + coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) + labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) + coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] + if box is not None: + box = self.transform.apply_boxes(box, self.original_size) + box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) + box_torch = box_torch[None, :] + if mask_input is not None: + mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) + mask_input_torch = mask_input_torch[None, :, :, :] + + masks, iou_predictions, low_res_masks = self.predict_torch( + coords_torch, + labels_torch, + box_torch, + mask_input_torch, + multimask_output, + return_logits=return_logits, + ) + + masks_np = masks[0].detach().cpu().numpy() + iou_predictions_np = iou_predictions[0].detach().cpu().numpy() + low_res_masks_np = low_res_masks[0].detach().cpu().numpy() + return masks_np, iou_predictions_np, low_res_masks_np + + @torch.no_grad() + def predict_torch( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using ResizeLongestSide. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + boxes (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) 
before mask prediction.") + + if point_coords is not None: + points = (point_coords, point_labels) + else: + points = None + + # Embed prompts + sparse_embeddings, dense_embeddings = self.model.prompt_encoder( + points=points, + boxes=boxes, + masks=mask_input, + ) + + # Predict masks + low_res_masks, iou_predictions = self.model.mask_decoder( + image_embeddings=self.features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + + masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) + + if not return_logits: + masks = masks > self.model.mask_threshold + + return masks, iou_predictions, low_res_masks + + def get_image_embedding(self) -> torch.Tensor: + """ + Returns the image embeddings for the currently set image, with + shape 1xCxHxW, where C is the embedding dimension and (H,W) are + the embedding spatial dimension of SAM (typically C=256, H=W=64). + """ + if not self.is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) to generate an embedding." + ) + assert self.features is not None, "Features must exist if an image has been set." + return self.features + + @property + def device(self) -> torch.device: + return self.model.device + + def reset_image(self) -> None: + """Resets the currently set image.""" + self.is_image_set = False + self.features = None + self.orig_h = None + self.orig_w = None + self.input_h = None + self.input_w = None diff --git a/SAM/segment_anything/utils/__init__.py b/SAM/segment_anything/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/SAM/segment_anything/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/SAM/segment_anything/utils/__pycache__/__init__.cpython-311.pyc b/SAM/segment_anything/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9393e4ca6682f067d04d936221216deaf387b9ef Binary files /dev/null and b/SAM/segment_anything/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/SAM/segment_anything/utils/__pycache__/amg.cpython-311.pyc b/SAM/segment_anything/utils/__pycache__/amg.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..343e1d7c5c72c132b76b85e4e62f072eb73e2645 Binary files /dev/null and b/SAM/segment_anything/utils/__pycache__/amg.cpython-311.pyc differ diff --git a/SAM/segment_anything/utils/__pycache__/transforms.cpython-311.pyc b/SAM/segment_anything/utils/__pycache__/transforms.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0796e0d6cb7d9f85b9bf742fe1a0953e6c485c14 Binary files /dev/null and b/SAM/segment_anything/utils/__pycache__/transforms.cpython-311.pyc differ diff --git a/SAM/segment_anything/utils/amg.py b/SAM/segment_anything/utils/amg.py new file mode 100644 index 0000000000000000000000000000000000000000..be064071ef399fea96c673ad173689656c23534a --- /dev/null +++ b/SAM/segment_anything/utils/amg.py @@ -0,0 +1,346 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+ +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +import math +from copy import deepcopy +from itertools import product +from typing import Any, Dict, Generator, ItemsView, List, Tuple + + +class MaskData: + """ + A structure for storing masks and their related data in batched format. + Implements basic filtering and concatenation. + """ + + def __init__(self, **kwargs) -> None: + for v in kwargs.values(): + assert isinstance( + v, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats = dict(**kwargs) + + def __setitem__(self, key: str, item: Any) -> None: + assert isinstance( + item, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats[key] = item + + def __delitem__(self, key: str) -> None: + del self._stats[key] + + def __getitem__(self, key: str) -> Any: + return self._stats[key] + + def items(self) -> ItemsView[str, Any]: + return self._stats.items() + + def filter(self, keep: torch.Tensor) -> None: + for k, v in self._stats.items(): + if v is None: + self._stats[k] = None + elif isinstance(v, torch.Tensor): + self._stats[k] = v[torch.as_tensor(keep, device=v.device)] + elif isinstance(v, np.ndarray): + self._stats[k] = v[keep.detach().cpu().numpy()] + elif isinstance(v, list) and keep.dtype == torch.bool: + self._stats[k] = [a for i, a in enumerate(v) if keep[i]] + elif isinstance(v, list): + self._stats[k] = [v[i] for i in keep] + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def cat(self, new_stats: "MaskData") -> None: + for k, v in new_stats.items(): + if k not in self._stats or self._stats[k] is None: + self._stats[k] = deepcopy(v) + elif isinstance(v, torch.Tensor): + self._stats[k] = torch.cat([self._stats[k], v], dim=0) + elif isinstance(v, np.ndarray): + self._stats[k] = np.concatenate([self._stats[k], v], axis=0) + elif isinstance(v, list): + self._stats[k] = self._stats[k] + deepcopy(v) + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def to_numpy(self) -> None: + for k, v in self._stats.items(): + if isinstance(v, torch.Tensor): + self._stats[k] = v.detach().cpu().numpy() + + +def is_box_near_crop_edge( + boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 +) -> torch.Tensor: + """Filter masks at the edge of a crop, but not at the edge of the original image.""" + crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) + orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) + boxes = uncrop_boxes_xyxy(boxes, crop_box).float() + near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) + near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) + near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) + return torch.any(near_crop_edge, dim=1) + + +def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: + box_xywh = deepcopy(box_xyxy) + box_xywh[2] = box_xywh[2] - box_xywh[0] + box_xywh[3] = box_xywh[3] - box_xywh[1] + return box_xywh + + +def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: + assert len(args) > 0 and all( + len(a) == len(args[0]) for a in args + ), "Batched iteration must have inputs of all the same size." 
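# Editor's sketch (not part of amg.py): batch_iterator() walks equally sized argument
# arrays in fixed-size chunks; SAM's automatic mask generator uses it to feed point
# prompts to the model batch by batch. The arrays below are hypothetical.
import numpy as np

from SAM.segment_anything.utils.amg import batch_iterator

points = np.random.rand(150, 2)   # 150 normalized (x, y) prompt points
labels = np.ones(150, dtype=int)  # all labelled as foreground
for pts_chunk, lbl_chunk in batch_iterator(64, points, labels):
    print(len(pts_chunk))         # 64, 64, then 22 for the final partial batch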
+ n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) + for b in range(n_batches): + yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] + + +def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: + """ + Encodes masks to an uncompressed RLE, in the format expected by + pycoco tools. + """ + # Put in fortran order and flatten h,w + b, h, w = tensor.shape + tensor = tensor.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = tensor[:, 1:] ^ tensor[:, :-1] + change_indices = diff.nonzero() + + # Encode run length + out = [] + for i in range(b): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + cur_idxs = torch.cat( + [ + torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), + cur_idxs + 1, + torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), + ] + ) + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if tensor[i, 0] == 0 else [0] + counts.extend(btw_idxs.detach().cpu().tolist()) + out.append({"size": [h, w], "counts": counts}) + return out + + +def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: + """Compute a binary mask from an uncompressed RLE.""" + h, w = rle["size"] + mask = np.empty(h * w, dtype=bool) + idx = 0 + parity = False + for count in rle["counts"]: + mask[idx : idx + count] = parity + idx += count + parity ^= True + mask = mask.reshape(w, h) + return mask.transpose() # Put in C order + + +def area_from_rle(rle: Dict[str, Any]) -> int: + return sum(rle["counts"][1::2]) + + +def calculate_stability_score( + masks: torch.Tensor, mask_threshold: float, threshold_offset: float +) -> torch.Tensor: + """ + Computes the stability score for a batch of masks. The stability + score is the IoU between the binary masks obtained by thresholding + the predicted mask logits at high and low values. + """ + # One mask is always contained inside the other. + # Save memory by preventing unnecessary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + unions = ( + (masks > (mask_threshold - threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + return intersections / unions + + +def build_point_grid(n_per_side: int) -> np.ndarray: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = np.linspace(offset, 1 - offset, n_per_side) + points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = np.tile(points_one_side[:, None], (1, n_per_side)) + points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) + return points + + +def build_all_layer_point_grids( + n_per_side: int, n_layers: int, scale_per_layer: int +) -> List[np.ndarray]: + """Generates point grids for all crop layers.""" + points_by_layer = [] + for i in range(n_layers + 1): + n_points = int(n_per_side / (scale_per_layer**i)) + points_by_layer.append(build_point_grid(n_points)) + return points_by_layer + + +def generate_crop_boxes( + im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float +) -> Tuple[List[List[int]], List[int]]: + """ + Generates a list of crop boxes of different sizes. Each layer + has (2**i)**2 boxes for the ith layer. 
+ """ + crop_boxes, layer_idxs = [], [] + im_h, im_w = im_size + short_side = min(im_h, im_w) + + # Original image + crop_boxes.append([0, 0, im_w, im_h]) + layer_idxs.append(0) + + def crop_len(orig_len, n_crops, overlap): + return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) + + for i_layer in range(n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_w = crop_len(im_w, n_crops_per_side, overlap) + crop_h = crop_len(im_h, n_crops_per_side, overlap) + + crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] + + # Crops in XYWH format + for x0, y0 in product(crop_box_x0, crop_box_y0): + box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = offset.unsqueeze(1) + return boxes + offset + + +def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0]], device=points.device) + # Check if points has a channel dimension + if len(points.shape) == 3: + offset = offset.unsqueeze(1) + return points + offset + + +def uncrop_masks( + masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int +) -> torch.Tensor: + x0, y0, x1, y1 = crop_box + if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) + pad = (x0, pad_x - x0, y0, pad_y - y0) + return torch.nn.functional.pad(masks, pad, value=0) + + +def remove_small_regions( + mask: np.ndarray, area_thresh: float, mode: str +) -> Tuple[np.ndarray, bool]: + """ + Removes small disconnected regions and holes in a mask. Returns the + mask and an indicator of if the mask has been modified. + """ + import cv2 # type: ignore + + assert mode in ["holes", "islands"] + correct_holes = mode == "holes" + working_mask = (correct_holes ^ mask).astype(np.uint8) + n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) + sizes = stats[:, -1][1:] # Row 0 is background label + small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] + if len(small_regions) == 0: + return mask, False + fill_labels = [0] + small_regions + if not correct_holes: + fill_labels = [i for i in range(n_labels) if i not in fill_labels] + # If every region is below threshold, keep largest + if len(fill_labels) == 0: + fill_labels = [int(np.argmax(sizes)) + 1] + mask = np.isin(regions, fill_labels) + return mask, True + + +def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: + from pycocotools import mask as mask_utils # type: ignore + + h, w = uncompressed_rle["size"] + rle = mask_utils.frPyObjects(uncompressed_rle, h, w) + rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json + return rle + + +def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: + """ + Calculates boxes in XYXY format around masks. Return [0,0,0,0] for + an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. 
+ """ + # torch.max below raises an error on empty inputs, just skip in this case + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to CxHxW + shape = masks.shape + h, w = shape[-2:] + if len(shape) > 2: + masks = masks.flatten(0, -3) + else: + masks = masks.unsqueeze(0) + + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords, dim=-1) + in_height_coords = in_height_coords + h * (~in_height) + top_edges, _ = torch.min(in_height_coords, dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords, dim=-1) + in_width_coords = in_width_coords + w * (~in_width) + left_edges, _ = torch.min(in_width_coords, dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + if len(shape) > 2: + out = out.reshape(*shape[:-2], 4) + else: + out = out[0] + + return out diff --git a/SAM/segment_anything/utils/onnx.py b/SAM/segment_anything/utils/onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..3196bdf4b782e6eeb3da4ad66ef3c7b1741535fe --- /dev/null +++ b/SAM/segment_anything/utils/onnx.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from typing import Tuple + +from ..modeling import Sam +from .amg import calculate_stability_score + + +class SamOnnxModel(nn.Module): + """ + This model should not be called directly, but is used in ONNX export. + It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, + with some functions modified to enable model tracing. Also supports extra + options controlling what information. See the ONNX export script for details. 
+ """ + + def __init__( + self, + model: Sam, + return_single_mask: bool, + use_stability_score: bool = False, + return_extra_metrics: bool = False, + ) -> None: + super().__init__() + self.mask_decoder = model.mask_decoder + self.model = model + self.img_size = model.image_encoder.img_size + self.return_single_mask = return_single_mask + self.use_stability_score = use_stability_score + self.stability_score_offset = 1.0 + self.return_extra_metrics = return_extra_metrics + + @staticmethod + def resize_longest_image_size( + input_image_size: torch.Tensor, longest_side: int + ) -> torch.Tensor: + input_image_size = input_image_size.to(torch.float32) + scale = longest_side / torch.max(input_image_size) + transformed_size = scale * input_image_size + transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) + return transformed_size + + def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: + point_coords = point_coords + 0.5 + point_coords = point_coords / self.img_size + point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) + point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) + + point_embedding = point_embedding * (point_labels != -1) + point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( + point_labels == -1 + ) + + for i in range(self.model.prompt_encoder.num_point_embeddings): + point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ + i + ].weight * (point_labels == i) + + return point_embedding + + def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: + mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) + mask_embedding = mask_embedding + ( + 1 - has_mask_input + ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) + return mask_embedding + + def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: + masks = F.interpolate( + masks, + size=(self.img_size, self.img_size), + mode="bilinear", + align_corners=False, + ) + + prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) + masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore + + orig_im_size = orig_im_size.to(torch.int64) + h, w = orig_im_size[0], orig_im_size[1] + masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) + return masks + + def select_masks( + self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Determine if we should return the multiclick mask or not from the number of points. + # The reweighting is used to avoid control flow. 
+ score_reweight = torch.tensor( + [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] + ).to(iou_preds.device) + score = iou_preds + (num_points - 2.5) * score_reweight + best_idx = torch.argmax(score, dim=1) + masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) + iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) + + return masks, iou_preds + + @torch.no_grad() + def forward( + self, + image_embeddings: torch.Tensor, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + mask_input: torch.Tensor, + has_mask_input: torch.Tensor, + orig_im_size: torch.Tensor, + ): + sparse_embedding = self._embed_points(point_coords, point_labels) + dense_embedding = self._embed_masks(mask_input, has_mask_input) + + masks, scores = self.model.mask_decoder.predict_masks( + image_embeddings=image_embeddings, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embedding, + dense_prompt_embeddings=dense_embedding, + ) + + if self.use_stability_score: + scores = calculate_stability_score( + masks, self.model.mask_threshold, self.stability_score_offset + ) + + if self.return_single_mask: + masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) + + upscaled_masks = self.mask_postprocessing(masks, orig_im_size) + + if self.return_extra_metrics: + stability_scores = calculate_stability_score( + upscaled_masks, self.model.mask_threshold, self.stability_score_offset + ) + areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) + return upscaled_masks, scores, stability_scores, areas, masks + + return upscaled_masks, scores, masks diff --git a/SAM/segment_anything/utils/transforms.py b/SAM/segment_anything/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..c08ba1e3db751f3a5483a003be38c69c2cf2df85 --- /dev/null +++ b/SAM/segment_anything/utils/transforms.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch.nn import functional as F +from torchvision.transforms.functional import resize, to_pil_image # type: ignore + +from copy import deepcopy +from typing import Tuple + + +class ResizeLongestSide: + """ + Resizes images to the longest side 'target_length', as well as provides + methods for resizing coordinates and boxes. Provides methods for + transforming both numpy array and batched torch tensors. + """ + + def __init__(self, target_length: int) -> None: + self.target_length = target_length + + def apply_image(self, image: np.ndarray) -> np.ndarray: + """ + Expects a numpy array with shape HxWxC in uint8 format. + """ + target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) + return np.array(resize(to_pil_image(image), target_size)) + + def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array of length 2 in the final dimension. Requires the + original image size in (H, W) format. 
+ """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).astype(float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array shape Bx4. Requires the original image size + in (H, W) format. + """ + boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: + """ + Expects batched images with shape BxCxHxW and float format. This + transformation may not exactly match apply_image. apply_image is + the transformation expected by the model. + """ + # Expects an image in BCHW format. May not exactly match apply_image. + target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) + return F.interpolate( + image, target_size, mode="bilinear", align_corners=False, antialias=True + ) + + def apply_coords_torch( + self, coords: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).to(torch.float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes_torch( + self, boxes: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with shape Bx4. Requires the original image + size in (H, W) format. + """ + boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + @staticmethod + def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: + """ + Compute the output size given input size and target long side length. 
+ """ + scale = long_side_length * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + return (newh, neww) diff --git a/app.py b/app.py index 652dc459175b17dc4acd2394b8b8ad1edd2394a5..bc2b4a89685b13b76dfcd34736ddfa69a7d0f7cc 100644 --- a/app.py +++ b/app.py @@ -1,154 +1,237 @@ import gradio as gr import numpy as np -import random - -# import spaces #[uncomment to use ZeroGPU] -from diffusers import DiffusionPipeline +import numpy as np import torch +import matplotlib.pyplot as plt +import cv2 + +from PIL import Image +import torch.nn as nn +from torch.autograd import Variable +from torchvision import transforms +import torch.nn.functional as F +import gdown +import os + +from io import BytesIO +from IS_Net.data_loader import normalize, im_reader, im_preprocess +from IS_Net.models.isnet import ISNetGTEncoder, ISNetDIS + +from SAM.segment_anything import sam_model_registry, SamPredictor device = "cuda" if torch.cuda.is_available() else "cpu" -model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use - -if torch.cuda.is_available(): - torch_dtype = torch.float16 -else: - torch_dtype = torch.float32 - -pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype) -pipe = pipe.to(device) - -MAX_SEED = np.iinfo(np.int32).max -MAX_IMAGE_SIZE = 1024 - - -# @spaces.GPU #[uncomment to use ZeroGPU] -def infer( - prompt, - negative_prompt, - seed, - randomize_seed, - width, - height, - guidance_scale, - num_inference_steps, - progress=gr.Progress(track_tqdm=True), -): - if randomize_seed: - seed = random.randint(0, MAX_SEED) - - generator = torch.Generator().manual_seed(seed) - - image = pipe( - prompt=prompt, - negative_prompt=negative_prompt, - guidance_scale=guidance_scale, - num_inference_steps=num_inference_steps, - width=width, - height=height, - generator=generator, - ).images[0] - - return image, seed - - -examples = [ - "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", - "An astronaut riding a green horse", - "A delicious ceviche cheesecake slice", -] - -css = """ -#col-container { - margin: 0 auto; - max-width: 640px; -} -""" - -with gr.Blocks(css=css) as demo: - with gr.Column(elem_id="col-container"): - gr.Markdown(" # Text-to-Image Gradio Template") - - with gr.Row(): - prompt = gr.Text( - label="Prompt", - show_label=False, - max_lines=1, - placeholder="Enter your prompt", - container=False, - ) - - run_button = gr.Button("Run", scale=0, variant="primary") - - result = gr.Image(label="Result", show_label=False) - - with gr.Accordion("Advanced Settings", open=False): - negative_prompt = gr.Text( - label="Negative prompt", - max_lines=1, - placeholder="Enter a negative prompt", - visible=False, - ) - - seed = gr.Slider( - label="Seed", - minimum=0, - maximum=MAX_SEED, - step=1, - value=0, - ) - - randomize_seed = gr.Checkbox(label="Randomize seed", value=True) - - with gr.Row(): - width = gr.Slider( - label="Width", - minimum=256, - maximum=MAX_IMAGE_SIZE, - step=32, - value=1024, # Replace with defaults that work for your model - ) - - height = gr.Slider( - label="Height", - minimum=256, - maximum=MAX_IMAGE_SIZE, - step=32, - value=1024, # Replace with defaults that work for your model - ) - - with gr.Row(): - guidance_scale = gr.Slider( - label="Guidance scale", - minimum=0.0, - maximum=10.0, - step=0.1, - value=0.0, # Replace with defaults that work for your model - ) - - num_inference_steps = gr.Slider( - label="Number of inference steps", - minimum=1, - 
maximum=50, - step=1, - value=2, # Replace with defaults that work for your model - ) - - gr.Examples(examples=examples, inputs=[prompt]) - gr.on( - triggers=[run_button.click, prompt.submit], - fn=infer, - inputs=[ - prompt, - negative_prompt, - seed, - randomize_seed, - width, - height, - guidance_scale, - num_inference_steps, - ], - outputs=[result, seed], + +def show_gray_images(images, m=8, alpha=3): + n, h, w = images.shape + num_rows = (n + m - 1) // m + fig, axes = plt.subplots(num_rows, m, figsize=(m * 2*alpha, num_rows * 2*alpha)) + plt.subplots_adjust(wspace=0.05, hspace=0.05) + for i in range(num_rows): + for j in range(m): + idx = i*m + j + if (m == 1 or num_rows == 1) and idx < n: + axes[idx].imshow(images[idx], cmap='gray') + axes[idx].axis('off') + elif idx < n: + axes[i, j].imshow(images[idx], cmap='gray') + axes[i, j].axis('off') + plt.show() + +def show_mask(mask, ax, random_color=False): + if random_color: + color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) + else: + color = np.array([30/255, 144/255, 255/255, 0.6]) + h, w = mask.shape[-2:] + mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) + ax.imshow(mask_image) + +def show_points(coords, labels, ax, marker_size=375): + pos_points = coords[labels==1] + neg_points = coords[labels==0] + ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) + ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) + +def show_box(box, ax): + x0, y0 = box[0], box[1] + w, h = box[2] - box[0], box[3] - box[1] + ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2)) + + + +sam_checkpoint = r"~/.cache/huggingface/hub/sam_vit_l_0b3195.pth" +model_type = "vit_l" + +sam = sam_model_registry[model_type](checkpoint=sam_checkpoint, device=device) +sam.to(device=device) + +predictor = SamPredictor(sam) + + +class GOSNormalize(object): + ''' + Normalize the Image using torch.transforms + ''' + def __init__(self, mean=[0.485,0.456,0.406,0], std=[0.229,0.224,0.225,1.0]): + self.mean = mean + self.std = std + + def __call__(self,image): + image = normalize(image,self.mean,self.std) + return image + +transform = transforms.Compose([GOSNormalize([0.5,0.5,0.5,0,0],[1.0,1.0,1.0,1.0,1.0])]) + +def build_model(hypar,device): + net = hypar["model"]#GOSNETINC(3,1) + + # convert to half precision + if(hypar["model_digit"]=="half"): + net.half() + for layer in net.modules(): + if isinstance(layer, nn.BatchNorm2d): + layer.float() + + net.to(device) + + if(hypar["restore_model"]!=""): + net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"],map_location=device)) + net.to(device) + net.eval() + return net + +def get_box(input_box,size): + + # Initialize an all-zero image + image = torch.zeros(size) + + # Fill the box region with white (value 255) + image[input_box[1]:input_box[3],input_box[0]:input_box[2]] = 255 + return image + +def get_box_from_mask(gt): + gt = torch.from_numpy(np.array(gt)) + box = torch.zeros_like(gt)+gt + box = box.float() + rows, cols = torch.where(box>0) + left = torch.min(cols) + top = torch.min(rows) + right = torch.max(cols) + bottom = torch.max(rows) + box[top:bottom,left:right] = 255 + box[box!=255] = 0 + return box + +def predict_one(net, image, mask, box, transforms, hypar, device): + ''' + Given an Image, predict the mask + ''' + with torch.no_grad(): + image = torch.from_numpy(np.array(image)) + mask = torch.from_numpy(np.array(mask)) + box = torch.from_numpy(np.array(box)) + if mask.max()==1: + mask = mask.type(torch.float32)*255.0 + # for i in [image,mask[...,None],box[...,None]]: + # print(i.shape) + inputs_val_v = torch.cat([image,mask[...,None],box[...,None]],dim=2) + inputs_val_v = inputs_val_v.permute(2,0,1)[None,...] + shapes_val = inputs_val_v.shape[-2:] + + inputs_val_v = F.upsample(inputs_val_v,(hypar["input_size"]),mode='bilinear') + box = inputs_val_v[0][-1] + box[box>127] = 255 + box[box<=127] = 0 + inputs_val_v[0][-1] = box + # plt.imshow(inputs_val_v[0][-1]) + # plt.show() + inputs_val_v = inputs_val_v.divide(255.0) + # print(shapes_val) + net.eval() + + if(hypar["model_digit"]=="full"): + inputs_val_v = inputs_val_v.type(torch.FloatTensor) + else: + inputs_val_v = inputs_val_v.type(torch.HalfTensor) + + + inputs_val_v = Variable(inputs_val_v, requires_grad=False).to(device) # wrap inputs in Variable + inputs_val_v = transforms(inputs_val_v) + # print(inputs_val_v.shape) + ds_val = net(inputs_val_v)[0][0] + # print(ds_val.shape) + ## recover the prediction spatial size to the original image size + pred_val = F.upsample(ds_val,(shapes_val),mode='bilinear')[0][0] + # print(pred_val.shape) + ma = torch.max(pred_val) + mi = torch.min(pred_val) + pred_val = (pred_val-mi)/(ma-mi) # max = 1 + + if device == 'cuda': torch.cuda.empty_cache() + refined_mask = (pred_val.detach().cpu().numpy()*255).astype(np.uint8) + # refined_mask[refined_mask>127] = 255 + # refined_mask[refined_mask<=127] = 0 + # refined_mask = 1 - refined_mask.astype(np.byte) + ret, binary = cv2.threshold(refined_mask, 0, 255, cv2.THRESH_OTSU) + return binary # it is the mask we need + +hypar = {} # parameters for inference + +hypar["model_path"] ="~/.cache/huggingface/hub" +hypar["restore_model"] = "DIS-SAM-checkpoint.pth" +hypar["model_digit"] = "full" +hypar["input_size"] = [1024, 1024] +hypar["model"] = ISNetDIS(in_ch=5) +net = build_model(hypar, device) + +def bbox_from_str(bbox_str: str): + if not bbox_str: + return None + split = bbox_str.strip().split(",") + if len(split) == 4: + try: + bbox = [int(x) for x in split] + return np.array(bbox) + except ValueError: + return None + else: + return None + +def predict(input_img: np.ndarray, bbox_str: str): + predictor.set_image(input_img) + + input_label = np.array([1]) + bbox = bbox_from_str(bbox_str) + input_box = bbox if bbox is not None else np.array([0, 0, input_img.shape[1], input_img.shape[0]]) + + masks, scores, logits = predictor.predict( + box=input_box, + point_labels=input_label, + multimask_output=True, ) + mask = masks[0] + DIS_mask = mask + DIS_box = get_box_from_mask(DIS_mask) + refined_mask = predict_one(net,input_img,DIS_mask,DIS_box,transform,hypar,device) + + mask_gray = (mask * 255).astype(np.uint8) + refined_mask_gray = refined_mask.astype(np.uint8) + return mask_gray, refined_mask_gray + +gradio_app = gr.Interface( + predict, + inputs=[ + gr.Image(label="Select Image", sources=['upload', 'webcam'], type="numpy"), + gr.Textbox(label="Bounding Box Prompt (pixels)", placeholder="x1,y1,x2,y2")], + outputs=[gr.Image(label="SAM Mask", type="numpy", image_mode="L"), gr.Image(label="DIS-SAM Mask", type="numpy", image_mode="L")], + title="DIS-SAM", + examples=[ + ["./images/wire_shelf.jpg", "20,100,480,660"], + ["./images/radio_telescope.jpg", "1130,320,4000,2920"], + ["./images/bridge.jpg", ""], + ["./images/tree.jpg", "70,110,2290,1800"] + ] +) if __name__ == "__main__": - demo.launch() + gradio_app.launch() diff --git a/images/bridge.jpg b/images/bridge.jpg new file mode 100644 index
0000000000000000000000000000000000000000..9c5ff783b7238e6217621cef217b69df7f8bfb4f --- /dev/null +++ b/images/bridge.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b5646db16c5fe0d5ef3c8e6d253504b3f1ea96a400f2c3a30b16febe400a2681 +size 193530 diff --git a/images/radio_telescope.jpg b/images/radio_telescope.jpg new file mode 100644 index 0000000000000000000000000000000000000000..dfc857bc6674dbcc5582f78ac07b6b8ccb804f76 --- /dev/null +++ b/images/radio_telescope.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ba679549f072c809e9a62ba643d370ad17bf3cd267ab3412b28acd5dd22f6c18 +size 6522735 diff --git a/images/stairs.jpg b/images/stairs.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0da9536d55ed887ddca194ae12bc53f58779ddf4 --- /dev/null +++ b/images/stairs.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2a5d8a6bc2576d9c42e77b1e8b93daf5ffece46d720a9f6f3f2ee271d6c1ec42 +size 64901 diff --git a/images/tree.jpg b/images/tree.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9eac03516e30a71c898848c347862db13859df1a --- /dev/null +++ b/images/tree.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c9b51e38673aed17108751024238963d804ba0f17ea5a9c127d2017f6f841091 +size 5143326 diff --git a/images/wire_shelf.jpg b/images/wire_shelf.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fed9e3e0a4b643dda3f121b40756bd578c4e9e75 --- /dev/null +++ b/images/wire_shelf.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e5c4ae8ecd20bc3c2f868e0d0851f83480043e1848e18c3f2a1ee455fd9074a9 +size 80377 diff --git a/requirements.txt b/requirements.txt index 73d01db64bd054c7d21fd0bd79b3af087c468809..ebefbd63e101e9661fb79553c7b0d7547c3274e0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,163 @@ -accelerate -diffusers -invisible_watermark -torch -transformers -xformers \ No newline at end of file +absl-py==1.4.0 +addict==2.4.0 +albumentations==1.4.3 +aliyun-python-sdk-core==2.15.0 +aliyun-python-sdk-kms==2.16.2 +asttokens #@ file:///home/conda/feedstock_root/build_artifacts/asttokens_1670263926556/work +backcall #@ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work +backports.functools-lru-cache #@ file:///home/conda/feedstock_root/build_artifacts/backports.functools_lru_cache_1687772187254/work +beautifulsoup4==4.12.2 +cachetools==5.3.1 +certifi==2022.12.7 +cffi==1.16.0 +charset-normalizer==2.1.1 +click==8.1.7 +colorama #@ file:///home/conda/feedstock_root/build_artifacts/colorama_1666700638685/work +comm #@ file:///home/conda/feedstock_root/build_artifacts/comm_1691044910542/work +contourpy==1.1.0 +crcmod==1.7 +cryptography==42.0.5 +cycler==0.11.0 +debugpy #@ file:///C:/b/abs_c0y1fjipt2/croot/debugpy_1690906864587/work +decorator #@ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work +dgl==1.1.2 +diffusers==0.27.2 +einops==0.6.1 +entmax==1.3 +executing #@ file:///home/conda/feedstock_root/build_artifacts/executing_1667317341051/work +filelock==3.9.0 +fonttools==4.42.0 +fsspec==2024.3.1 +gdown==4.7.1 +google-auth==2.22.0 +google-auth-oauthlib==1.0.0 +grpcio==1.57.0 +huggingface-hub==0.22.2 +idna==3.4 +imageio==2.31.5 +importlib-metadata #@ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1688754491823/work +#install==1.3.5 +intel-openmp==2021.4.0 +ipykernel #@ file:///D:/bld/ipykernel_1690311464685/work +ipython #@ 
file:///D:/bld/ipython_1685727936079/work +jedi #@ file:///home/conda/feedstock_root/build_artifacts/jedi_1690896916983/work +Jinja2==3.1.2 +jmespath==0.10.0 +joblib==1.3.2 +jupyter_client #@ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1687700988094/work +jupyter_core #@ file:///D:/bld/jupyter_core_1686775880418/work +kiwisolver==1.4.4 +lazy_loader==0.3 +Markdown==3.4.4 +markdown-it-py==3.0.0 +MarkupSafe==2.1.2 +matlab==0.1 +matplotlib==3.7.2 +matplotlib-inline #@ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1660814786464/work +mdurl==0.1.2 +mkl==2021.4.0 +mmcls==0.25.0 +mmcv-full==1.7.2 +mmengine==0.10.3 +mne==1.4.2 +model-index==0.1.11 +monai==1.3.0 +mpmath==1.2.1 +./MultiScaleDeformableAttention-1.0-py3-none-any.whl +munch==4.0.0 +nest-asyncio #@ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1664684991461/work +networkx==3.0rc1 +ninja==1.11.1.1 +numpy==1.26.4 +oauthlib==3.2.2 +opencv-python==4.8.0.76 +opencv-python-headless==4.9.0.80 +opendatalab==0.0.10 +openmim==0.3.9 +openxlab==0.0.38 +ordered-set==4.1.0 +oss2==2.17.0 +packaging #@ file:///home/conda/feedstock_root/build_artifacts/packaging_1681337016113/work +pandas==2.0.3 +parso #@ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work +pickleshare #@ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work +Pillow==9.3.0 +pix2tex==0.1.2 +platformdirs #@ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1690813113769/work +pooch==1.7.0 +prompt-toolkit #@ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1688565951714/work +protobuf==4.24.2 +psutil #@ file:///C:/ci_311_rebuilds/psutil_1679005906571/work +pure-eval #@ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work +pyasn1==0.5.0 +pyasn1-modules==0.3.0 +pycocotools==2.0.7 +pycparser==2.22 +pycryptodome==3.20.0 +Pygments #@ file:///home/conda/feedstock_root/build_artifacts/pygments_1681904169130/work +pynput==1.7.6 +pyparsing==3.0.9 +PyQt6==6.6.1 +PyQt6-Qt6==6.6.2 +PyQt6-sip==13.6.0 +PyQt6-WebEngine==6.6.0 +PyQt6-WebEngine-Qt6==6.6.2 +pyreadline3==3.4.1 +PySide6==6.6.3.1 +PySide6_Addons==6.6.3.1 +PySide6_Essentials==6.6.3.1 +PySocks==1.7.1 +python-dateutil #@ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1626286286081/work +pytz==2023.3 +PyWavelets==1.4.1 +#pywin32 +PyYAML==6.0.1 +pyzmq #@ file:///C:/b/abs_655zk4a3s8/croot/pyzmq_1686601465034/work +regex==2023.12.25 +requests==2.28.2 +requests-oauthlib==1.3.1 +rich==13.4.2 +rsa==4.9 +safetensors==0.4.2 +scikit-image==0.22.0 +scikit-learn==1.4.1.post1 +scipy==1.11.1 +screeninfo==0.8.1 +seaborn==0.13.1 +shapely==2.0.1 +shiboken6==6.6.3.1 +six #@ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work +soupsieve==2.5 +stack-data #@ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work +sympy==1.11.1 +tabulate==0.9.0 +tbb==2021.12.0 +tensorboard==2.14.0 +tensorboard-data-server==0.7.1 +tensorboardX==2.6.2.2 +termcolor==2.4.0 +threadpoolctl==3.2.0 +tifffile==2023.9.26 +timm==0.5.4 +tokenizers==0.15.2 +tomli==2.0.1 +torch==2.3.0 +torch-geometric==2.3.1 +torchaudio==2.3.0 +torchsampler==0.1.2 +torchsummary==1.5.1 +torchvision==0.18.0 +tornado #@ file:///C:/b/abs_61jhmrrua1/croot/tornado_1690848767317/work +tqdm==4.65.0 +traitlets #@ file:///home/conda/feedstock_root/build_artifacts/traitlets_1675110562325/work +transformers==4.39.3 +typing_extensions==4.10.0 +tzdata==2023.3 +urllib3==1.26.13 +wcwidth #@ 
file:///home/conda/feedstock_root/build_artifacts/wcwidth_1673864653149/work +Werkzeug==2.3.7 +x-transformers==0.15.0 +xformers==0.0.26.post1 +yacs==0.1.8 +yapf==0.40.2 +zipp #@ file:///home/conda/feedstock_root/build_artifacts/zipp_1689374466814/work