Upload 46 files
- .gitattributes +5 -0
- assets/geo.png +0 -0
- assets/geo_white.png +0 -0
- assets/main.png +3 -0
- assets/mixing_traj.png +0 -0
- assets/mixing_traj_white.png +0 -0
- assets/non_cherry_picky.png +3 -0
- assets/strength_space.png +3 -0
- assets/teaser.png +3 -0
- assets/unconditional.png +3 -0
- boundarydiffusion.py +713 -0
- configs/afhq.yml +35 -0
- configs/bedroom.yml +35 -0
- configs/celeba.yml +35 -0
- configs/church.yml +35 -0
- configs/imagenet.yml +35 -0
- configs/paths_config.py +25 -0
- data_download.sh +38 -0
- datasets/AFHQ_dataset.py +42 -0
- datasets/CelebA_HQ_dataset.py +83 -0
- datasets/CelebA_HQ_dataset_with_label.py +63 -0
- datasets/IMAGENET_dataset.py +102 -0
- datasets/LSUN_dataset.py +304 -0
- datasets/celeba_attr.txt +40 -0
- datasets/data_utils.py +44 -0
- datasets/imagenet_dic.py +408 -0
- imgs/img1.jpg +0 -0
- losses/clip_loss.py +299 -0
- losses/id_loss.py +35 -0
- main.py +275 -0
- models/ddpm/diffusion.py +348 -0
- models/improved_ddpm/fp16_util.py +236 -0
- models/improved_ddpm/logger.py +451 -0
- models/improved_ddpm/nn.py +170 -0
- models/improved_ddpm/script_util.py +109 -0
- models/improved_ddpm/unet.py +677 -0
- models/insight_face/__init__.py +0 -0
- models/insight_face/helpers.py +178 -0
- models/insight_face/model_irse.py +124 -0
- requirements.txt +10 -0
- utils/align_utils.py +213 -0
- utils/celeba_attr.txt +40 -0
- utils/colab_utils.py +36 -0
- utils/diffusion_utils.py +134 -0
- utils/prepare_lmdb_data.py +140 -0
- utils/text_dic.py +123 -0
- utils/text_templates.py +129 -0
.gitattributes
CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/main.png filter=lfs diff=lfs merge=lfs -text
+assets/non_cherry_picky.png filter=lfs diff=lfs merge=lfs -text
+assets/strength_space.png filter=lfs diff=lfs merge=lfs -text
+assets/teaser.png filter=lfs diff=lfs merge=lfs -text
+assets/unconditional.png filter=lfs diff=lfs merge=lfs -text
assets/geo.png
ADDED
assets/geo_white.png
ADDED
assets/main.png
ADDED (stored with Git LFS)
assets/mixing_traj.png
ADDED
assets/mixing_traj_white.png
ADDED
assets/non_cherry_picky.png
ADDED (stored with Git LFS)
assets/strength_space.png
ADDED (stored with Git LFS)
assets/teaser.png
ADDED (stored with Git LFS)
assets/unconditional.png
ADDED (stored with Git LFS)
boundarydiffusion.py
ADDED
@@ -0,0 +1,713 @@
import time
from glob import glob
from tqdm import tqdm
import os
import numpy as np
import cv2
from PIL import Image
import torch
from torch import nn
import torchvision.utils as tvu
from sklearn import svm
import pickle
import torch.optim as optim

from models.ddpm.diffusion import DDPM
from models.improved_ddpm.script_util import i_DDPM
from utils.text_dic import SRC_TRG_TXT_DIC
from utils.diffusion_utils import get_beta_schedule, denoising_step
from datasets.data_utils import get_dataset, get_dataloader
from configs.paths_config import DATASET_PATHS, MODEL_PATHS, HYBRID_MODEL_PATHS, HYBRID_CONFIG
from datasets.imagenet_dic import IMAGENET_DIC
from utils.align_utils import run_alignment
from utils.distance_utils import euclidean_distance, cosine_similarity


def compute_radius(x):
    # L2 norm of the flattened tensor: sqrt(sum_i x_i^2)
    x = torch.pow(x, 2)
    r = torch.sum(x)
    r = torch.sqrt(r)
    return r
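
# Editor's note: compute_radius(x) is equivalent to torch.linalg.norm(x) on the
# flattened tensor. For a standard Gaussian latent of dimension d (here
# d = 3*256*256 = 196608) the norm concentrates near sqrt(d) ~ 443.4, e.g.
#   compute_radius(torch.randn(1, 3, 256, 256)) ~ 443.4,
# which is the reference scale for the radius() experiment below.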


class BoundaryDiffusion(object):
    def __init__(self, args, config, device=None):
        self.args = args
        self.config = config
        if device is None:
            device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.device = device

        self.model_var_type = config.model.var_type
        betas = get_beta_schedule(
            beta_start=config.diffusion.beta_start,
            beta_end=config.diffusion.beta_end,
            num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps
        )
        self.betas = torch.from_numpy(betas).float().to(self.device)
        self.num_timesteps = betas.shape[0]

        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
        posterior_variance = betas * \
            (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        if self.model_var_type == "fixedlarge":
            self.logvar = np.log(np.append(posterior_variance[1], betas[1:]))
        elif self.model_var_type == 'fixedsmall':
            self.logvar = np.log(np.maximum(posterior_variance, 1e-20))

        if self.args.edit_attr is None:
            self.src_txts = self.args.src_txts
            self.trg_txts = self.args.trg_txts
        else:
            self.src_txts = SRC_TRG_TXT_DIC[self.args.edit_attr][0]
            self.trg_txts = SRC_TRG_TXT_DIC[self.args.edit_attr][1]

+
def unconditional(self):
|
74 |
+
print(self.args.exp)
|
75 |
+
|
76 |
+
# ----------- Model -----------#
|
77 |
+
if self.config.data.dataset == "LSUN":
|
78 |
+
if self.config.data.category == "bedroom":
|
79 |
+
url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/bedroom.ckpt"
|
80 |
+
elif self.config.data.category == "church_outdoor":
|
81 |
+
url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/church_outdoor.ckpt"
|
82 |
+
elif self.config.data.dataset == "CelebA_HQ":
|
83 |
+
url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/celeba_hq.ckpt"
|
84 |
+
elif self.config.data.dataset == "AFHQ":
|
85 |
+
pass
|
86 |
+
else:
|
87 |
+
raise ValueError
|
88 |
+
|
89 |
+
if self.config.data.dataset in ["CelebA_HQ", "LSUN"]:
|
90 |
+
model = DDPM(self.config)
|
91 |
+
if self.args.model_path:
|
92 |
+
init_ckpt = torch.load(self.args.model_path)
|
93 |
+
else:
|
94 |
+
init_ckpt = torch.hub.load_state_dict_from_url(url, map_location=self.device)
|
95 |
+
learn_sigma = False
|
96 |
+
print("Original diffusion Model loaded.")
|
97 |
+
elif self.config.data.dataset in ["FFHQ", "AFHQ"]:
|
98 |
+
model = i_DDPM(self.config.data.dataset)
|
99 |
+
if self.args.model_path:
|
100 |
+
init_ckpt = torch.load(self.args.model_path)
|
101 |
+
else:
|
102 |
+
init_ckpt = torch.load(MODEL_PATHS[self.config.data.dataset])
|
103 |
+
learn_sigma = True
|
104 |
+
print("Improved diffusion Model loaded.")
|
105 |
+
else:
|
106 |
+
print('Not implemented dataset')
|
107 |
+
raise ValueError
|
108 |
+
model.load_state_dict(init_ckpt)
|
109 |
+
model.to(self.device)
|
110 |
+
model = torch.nn.DataParallel(model)
|
111 |
+
model.eval()
|
112 |
+
|
113 |
+
# ----------- Precompute Latents -----------#
|
114 |
+
seq_inv = np.linspace(0, 1, 999) * 999
|
115 |
+
seq_inv = [int(s) for s in list(seq_inv)]
|
116 |
+
seq_inv_next = [-1] + list(seq_inv[:-1])
|
117 |
+
|
118 |
+
###---- boundaries---####
|
119 |
+
# ---------- Load boundary ----------#
|
120 |
+
classifier = pickle.load(open('./boundary/smile_boundary_h.sav', 'rb'))
|
121 |
+
a = classifier.coef_.reshape(1, 512*8*8).astype(np.float32)
|
122 |
+
# a = a / np.linalg.norm(a)
|
123 |
+
|
124 |
+
z_classifier = pickle.load(open('./boundary/smile_boundary_z.sav', 'rb'))
|
125 |
+
z_a = z_classifier.coef_.reshape(1, 3*256*256).astype(np.float32)
|
126 |
+
z_a = z_a / np.linalg.norm(z_a) # normalized boundary
|
127 |
+
|
128 |
+
x_lat = torch.randn(1, 3, 256, 256, device=self.device)
|
129 |
+
n = 1
|
130 |
+
print("get the sampled latent encodings x_T!")
|
131 |
+
|
132 |
+
with torch.no_grad():
|
133 |
+
with tqdm(total=len(seq_inv), desc=f"Generative process") as progress_bar:
|
134 |
+
for it, (i, j) in enumerate(zip(reversed((seq_inv)), reversed((seq_inv_next)))):
|
135 |
+
t = (torch.ones(n) * i).to(self.device)
|
136 |
+
t_next = (torch.ones(n) * j).to(self.device)
|
137 |
+
# print("check t and t_next:", t, t_next)
|
138 |
+
if t == self.args.t_0:
|
139 |
+
break
|
140 |
+
x_lat, h_lat = denoising_step(x_lat, t=t, t_next=t_next, models=model,
|
141 |
+
logvars=self.logvar,
|
142 |
+
# sampling_type=self.args.sample_type,
|
143 |
+
sampling_type='ddim',
|
144 |
+
b=self.betas,
|
145 |
+
eta=0.0,
|
146 |
+
learn_sigma=learn_sigma,
|
147 |
+
)
|
148 |
+
|
149 |
+
progress_bar.update(1)
|
150 |
+
|
151 |
+
|
152 |
+
|
153 |
+
|
154 |
+
# ----- Editing space ------ #
|
155 |
+
start_distance = self.args.start_distance
|
156 |
+
end_distance = self.args.end_distance
|
157 |
+
edit_img_number = self.args.edit_img_number
|
158 |
+
linspace = np.linspace(start_distance, end_distance, edit_img_number)
|
159 |
+
latent_code = h_lat.cpu().view(1,-1).numpy()
|
160 |
+
linspace = linspace - latent_code.dot(a.T)
|
161 |
+
linspace = linspace.reshape(-1, 1).astype(np.float32)
|
162 |
+
edit_h_seq = latent_code + linspace * a
|
163 |
+
|
164 |
+
|
165 |
+
z_linspace = np.linspace(start_distance, end_distance, edit_img_number)
|
166 |
+
z_latent_code = x_lat.cpu().view(1,-1).numpy()
|
167 |
+
z_linspace = z_linspace - z_latent_code.dot(z_a.T)
|
168 |
+
z_linspace = z_linspace.reshape(-1, 1).astype(np.float32)
|
169 |
+
edit_z_seq = z_latent_code + z_linspace * z_a
|
170 |
+
|
171 |
+
|
172 |
+
for k in range(edit_img_number):
|
173 |
+
time_in_start = time.time()
|
174 |
+
seq_inv = np.linspace(0, 1, self.args.n_inv_step) * self.args.t_0
|
175 |
+
seq_inv = [int(s) for s in list(seq_inv)]
|
176 |
+
seq_inv_next = [-1] + list(seq_inv[:-1])
|
177 |
+
|
178 |
+
with tqdm(total=len(seq_inv), desc="Generative process {}".format(it)) as progress_bar:
|
179 |
+
edit_h = torch.from_numpy(edit_h_seq[k]).to(self.device).view(-1, 512, 8, 8)
|
180 |
+
edit_z = torch.from_numpy(edit_z_seq[k]).to(self.device).view(-1, 3, 256, 256)
|
181 |
+
for i, j in zip(reversed(seq_inv), reversed(seq_inv_next)):
|
182 |
+
t = (torch.ones(n) * i).to(self.device)
|
183 |
+
t_next = (torch.ones(n) * j).to(self.device)
|
184 |
+
edit_z, edit_h = denoising_step(edit_z, t=t, t_next=t_next, models=model,
|
185 |
+
logvars=self.logvar,
|
186 |
+
sampling_type=self.args.sample_type,
|
187 |
+
b=self.betas,
|
188 |
+
eta = 1.0,
|
189 |
+
learn_sigma=learn_sigma,
|
190 |
+
ratio=self.args.model_ratio,
|
191 |
+
hybrid=self.args.hybrid_noise,
|
192 |
+
hybrid_config=HYBRID_CONFIG,
|
193 |
+
edit_h=edit_h,
|
194 |
+
)
|
195 |
+
|
196 |
+
save_edit = "unconditioned_smile_"+str(k)+".png"
|
197 |
+
tvu.save_image((edit_z + 1) * 0.5, os.path.join("edit_output",save_edit))
|
198 |
+
time_in_end = time.time()
|
199 |
+
print(f"Editing for 1 image takes {time_in_end - time_in_start:.4f}s")
|
200 |
+
return
|
201 |
+
|
202 |
+
|
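    # Editing-space note: with a unit-norm boundary normal a, latent_code.dot(a.T)
    # is the signed distance of the code from the SVM hyperplane, so
    #   edit = latent_code + (target_distance - latent_code.dot(a.T)) * a
    # places the edited code at exactly the target signed distance. The same
    # construction is applied independently in h-space (a) and z-space (z_a).
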
    def radius(self):
        print(self.args.exp)

        # ----------- Model -----------#
        if self.config.data.dataset == "LSUN":
            if self.config.data.category == "bedroom":
                url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/bedroom.ckpt"
            elif self.config.data.category == "church_outdoor":
                url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/church_outdoor.ckpt"
        elif self.config.data.dataset == "CelebA_HQ":
            url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/celeba_hq.ckpt"
        elif self.config.data.dataset == "AFHQ":
            pass
        else:
            raise ValueError

        if self.config.data.dataset in ["CelebA_HQ", "LSUN"]:
            model = DDPM(self.config)
            if self.args.model_path:
                init_ckpt = torch.load(self.args.model_path)
            else:
                init_ckpt = torch.hub.load_state_dict_from_url(url, map_location=self.device)
            learn_sigma = False
            print("Original diffusion model loaded.")
        elif self.config.data.dataset in ["FFHQ", "AFHQ"]:
            model = i_DDPM(self.config.data.dataset)
            if self.args.model_path:
                init_ckpt = torch.load(self.args.model_path)
            else:
                init_ckpt = torch.load(MODEL_PATHS[self.config.data.dataset])
            learn_sigma = True
            print("Improved diffusion model loaded.")
        else:
            print('Dataset not implemented')
            raise ValueError
        model.load_state_dict(init_ckpt)
        model.to(self.device)
        model = torch.nn.DataParallel(model)
        model.eval()

        # ---------- Prepare the seq --------- #
        # seq_inv = np.linspace(0, 1, self.args.n_inv_step) * self.args.t_0
        seq_inv = np.linspace(0, 1, 999) * 999
        seq_inv = [int(s) for s in list(seq_inv)]
        seq_inv_next = [-1] + list(seq_inv[:-1])

        n = 1
        with torch.no_grad():
            er = 0
            x_rand = torch.randn(100, 3, 256, 256, device=self.device)
            for idx in range(100):
                x = x_rand[idx, :, :, :].unsqueeze(0)

                with tqdm(total=len(seq_inv), desc="Generative process") as progress_bar:
                    for it, (i, j) in enumerate(zip(reversed(seq_inv), reversed(seq_inv_next))):
                        t = (torch.ones(n) * i).to(self.device)
                        t_next = (torch.ones(n) * j).to(self.device)
                        if t == 500:
                            break
                        x, _ = denoising_step(x, t=t, t_next=t_next, models=model,
                                              logvars=self.logvar,
                                              # sampling_type=self.args.sample_type,
                                              sampling_type='ddim',
                                              b=self.betas,
                                              eta=0.0,
                                              learn_sigma=learn_sigma,
                                              )
                        progress_bar.update(1)
                r_x = compute_radius(x)
                er += r_x
            print("Check radius at step:", er / 100)

        return

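    # Sanity check for radius(): under the forward marginal
    #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,
    # the norm sits near sqrt(alpha_bar_t * ||x_0||^2 + (1 - alpha_bar_t) * d)
    # with d = 3*256*256, so the printed average over 100 samples can be
    # compared against the Gaussian shell radius sqrt(d) ~ 443.
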
    def boundary_search(self):
        print(self.args.exp)

        # ----------- Model -----------#
        if self.config.data.dataset == "LSUN":
            if self.config.data.category == "bedroom":
                url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/bedroom.ckpt"
            elif self.config.data.category == "church_outdoor":
                url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/church_outdoor.ckpt"
        elif self.config.data.dataset == "CelebA_HQ":
            url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/celeba_hq.ckpt"
        elif self.config.data.dataset == "AFHQ":
            pass
        else:
            raise ValueError

        if self.config.data.dataset in ["CelebA_HQ", "LSUN"]:
            model = DDPM(self.config)
            if self.args.model_path:
                init_ckpt = torch.load(self.args.model_path)
            else:
                init_ckpt = torch.hub.load_state_dict_from_url(url, map_location=self.device)
            learn_sigma = False
            print("Original diffusion model loaded.")
        elif self.config.data.dataset in ["FFHQ", "AFHQ"]:
            model = i_DDPM(self.config.data.dataset)
            if self.args.model_path:
                init_ckpt = torch.load(self.args.model_path)
            else:
                init_ckpt = torch.load(MODEL_PATHS[self.config.data.dataset])
            learn_sigma = True
            print("Improved diffusion model loaded.")
        else:
            print('Dataset not implemented')
            raise ValueError
        model.load_state_dict(init_ckpt)
        model.to(self.device)
        model = torch.nn.DataParallel(model)
        model.eval()

        # ----------- Precompute Latents -----------#
        print("Prepare identity latent")
        seq_inv = np.linspace(0, 1, self.args.n_inv_step) * self.args.t_0
        seq_inv = [int(s) for s in list(seq_inv)]
        seq_inv_next = [-1] + list(seq_inv[:-1])

        n = self.args.bs_train
        img_lat_pairs_dic = {}
        for mode in ['train', 'test']:
            img_lat_pairs = []
            pairs_path = os.path.join('precomputed/',
                                      f'{self.config.data.category}_{mode}_t{self.args.t_0}_nim{self.args.n_precomp_img}_ninv{self.args.n_inv_step}_pairs.pth')
            print(pairs_path)
            if os.path.exists(pairs_path):
                print(f'{mode} pairs exist')
                img_lat_pairs_dic[mode] = torch.load(pairs_path)
                for step, (x0, x_id, x_lat, mid_h, label) in enumerate(img_lat_pairs_dic[mode]):
                    tvu.save_image((x0 + 1) * 0.5, os.path.join(self.args.image_folder, f'{mode}_{step}_0_orig.png'))
                    tvu.save_image((x_id + 1) * 0.5, os.path.join(self.args.image_folder,
                                                                  f'{mode}_{step}_1_rec_ninv{self.args.n_inv_step}.png'))
                    if step == self.args.n_precomp_img - 1:
                        break
                continue
            else:
                train_dataset, test_dataset = get_dataset(self.config.data.dataset, DATASET_PATHS, self.config)
                loader_dic = get_dataloader(train_dataset, test_dataset, bs_train=self.args.bs_train,
                                            num_workers=self.config.data.num_workers)
                loader = loader_dic[mode]

                for step, (img, label) in enumerate(loader):
                    x0 = img.to(self.config.device)
                    tvu.save_image((x0 + 1) * 0.5, os.path.join(self.args.image_folder, f'{mode}_{step}_0_orig.png'))

                    x = x0.clone()
                    model.eval()
                    label = label.to(self.config.device)

                    with torch.no_grad():
                        with tqdm(total=len(seq_inv), desc=f"Inversion process {mode} {step}") as progress_bar:
                            for it, (i, j) in enumerate(zip(seq_inv_next[1:], seq_inv[1:])):
                                t = (torch.ones(n) * i).to(self.device)
                                t_prev = (torch.ones(n) * j).to(self.device)

                                x, mid_h_g = denoising_step(x, t=t, t_next=t_prev, models=model,
                                                            logvars=self.logvar,
                                                            sampling_type='ddim',
                                                            b=self.betas,
                                                            eta=0,
                                                            learn_sigma=learn_sigma)
                                progress_bar.update(1)
                        x_lat = x.clone()
                        tvu.save_image((x_lat + 1) * 0.5, os.path.join(self.args.image_folder,
                                                                       f'{mode}_{step}_1_lat_ninv{self.args.n_inv_step}.png'))

                        with tqdm(total=len(seq_inv), desc=f"Generative process {mode} {step}") as progress_bar:
                            for it, (i, j) in enumerate(zip(reversed(seq_inv), reversed(seq_inv_next))):
                                t = (torch.ones(n) * i).to(self.device)
                                t_next = (torch.ones(n) * j).to(self.device)

                                x, _ = denoising_step(x, t=t, t_next=t_next, models=model,
                                                      logvars=self.logvar,
                                                      sampling_type=self.args.sample_type,
                                                      b=self.betas,
                                                      learn_sigma=learn_sigma,
                                                      # edit_h=mid_h,
                                                      )
                                progress_bar.update(1)

                    img_lat_pairs.append([x0, x.detach().clone(), x_lat.detach().clone(), mid_h_g.detach().clone(), label])
                    tvu.save_image((x + 1) * 0.5, os.path.join(self.args.image_folder,
                                                               f'{mode}_{step}_1_rec_ninv{self.args.n_inv_step}.png'))
                    if step == self.args.n_precomp_img - 1:
                        break

                img_lat_pairs_dic[mode] = img_lat_pairs
                pairs_path = os.path.join('precomputed/',
                                          f'{self.config.data.category}_{mode}_t{self.args.t_0}_nim{self.args.n_precomp_img}_ninv{self.args.n_inv_step}_pairs.pth')
                torch.save(img_lat_pairs, pairs_path)

        # ----------- Training boundaries -----------#
        print("Start boundary search")
        print(f"Sampling type: {self.args.sample_type.upper()} with eta {self.args.eta}")
        if self.args.n_train_step != 0:
            seq_train = np.linspace(0, 1, self.args.n_train_step) * self.args.t_0
            seq_train = [int(s) for s in list(seq_train)]
            print('Uniform skip type')
        else:
            seq_train = list(range(self.args.t_0))
            print('No skip')
        seq_train_next = [-1] + list(seq_train[:-1])

        seq_test = np.linspace(0, 1, self.args.n_test_step) * self.args.t_0
        seq_test = [int(s) for s in list(seq_test)]
        seq_test_next = [-1] + list(seq_test[:-1])

        for src_txt, trg_txt in zip(self.src_txts, self.trg_txts):
            print(f"CHANGE {src_txt} TO {trg_txt}")
            time_in_start = time.time()

            clf_h = svm.SVC(kernel='linear')
            clf_z = svm.SVC(kernel='linear')

            exp_id = os.path.split(self.args.exp)[-1]
            save_name_h = f'boundary/{exp_id}_{trg_txt.replace(" ", "_")}_h.sav'
            save_name_z = f'boundary/{exp_id}_{trg_txt.replace(" ", "_")}_z.sav'
            n_train = len(img_lat_pairs_dic['train'])

            train_data_z = np.empty([n_train, 3*256*256])
            train_data_h = np.empty([n_train, 512*8*8])
            train_label = np.empty([n_train, ], dtype=int)

            for step, (x0, x_id, x_lat, mid_h, label) in enumerate(img_lat_pairs_dic['train']):
                train_data_h[step, :] = mid_h.view(1, -1).cpu().numpy()
                train_data_z[step, :] = x_lat.view(1, -1).cpu().numpy()
                train_label[step] = label.cpu().numpy()

            classifier_h = clf_h.fit(train_data_h, train_label)
            classifier_z = clf_z.fit(train_data_z, train_label)
            print(np.shape(train_data_h), np.shape(train_data_z), np.shape(train_label))
            time_in_end = time.time()
            print(f"Finding boundary takes {time_in_end - time_in_start:.4f}s")
            print("Finished boundary separation!")

            pickle.dump(classifier_h, open(save_name_h, 'wb'))
            pickle.dump(classifier_z, open(save_name_z, 'wb'))

            # test the accuracy
            n_test = len(img_lat_pairs_dic['test'])
            test_data_h = np.empty([n_test, 512*8*8])
            test_data_z = np.empty([n_test, 3*256*256])
            test_label = np.empty([n_test, ], dtype=int)
            for step, (x0, x_id, x_lat, mid_h, label) in enumerate(img_lat_pairs_dic['test']):
                test_data_h[step, :] = mid_h.view(1, -1).cpu().numpy()
                test_data_z[step, :] = x_lat.view(1, -1).cpu().numpy()
                test_label[step] = label.cpu().numpy()
            classifier_h = pickle.load(open(save_name_h, 'rb'))
            classifier_z = pickle.load(open(save_name_z, 'rb'))
            print("Boundary loaded!")
            val_prediction_h = classifier_h.predict(test_data_h)
            val_prediction_z = classifier_z.predict(test_data_z)
            correct_num_h = np.sum(test_label == val_prediction_h)
            correct_num_z = np.sum(test_label == val_prediction_z)
            print("Validation accuracy on h and z spaces:", correct_num_h / n_test, correct_num_z / n_test)
            print("total training and testing", n_train, n_test)

        return None

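    # The .sav files written above are what unconditional() and
    # edit_image_boundary() consume: each stores a fitted linear svm.SVC whose
    # coef_ (reshaped to 1 x 512*8*8 for h-space or 1 x 3*256*256 for z-space,
    # then normalized) becomes the editing direction.
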
    def edit_image_boundary(self):
        # ----------- Data -----------#
        n = self.args.bs_test

        if self.args.align_face and self.config.data.dataset in ["FFHQ", "CelebA_HQ"]:
            try:
                img = run_alignment(self.args.img_path, output_size=self.config.data.image_size)
            except Exception:
                img = Image.open(self.args.img_path).convert("RGB")
        else:
            img = Image.open(self.args.img_path).convert("RGB")
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        img = img.resize((self.config.data.image_size, self.config.data.image_size), Image.LANCZOS)
        img = np.array(img) / 255
        img = torch.from_numpy(img).type(torch.FloatTensor).permute(2, 0, 1).unsqueeze(dim=0).repeat(n, 1, 1, 1)
        img = img.to(self.config.device)
        tvu.save_image(img, os.path.join(self.args.image_folder, '0_orig.png'))
        x0 = (img - 0.5) * 2.

        # ----------- Models -----------#
        if self.config.data.dataset == "LSUN":
            if self.config.data.category == "bedroom":
                url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/bedroom.ckpt"
            elif self.config.data.category == "church_outdoor":
                url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/church_outdoor.ckpt"
        elif self.config.data.dataset == "CelebA_HQ":
            url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/celeba_hq.ckpt"
        elif self.config.data.dataset in ["FFHQ", "AFHQ", "IMAGENET"]:
            pass
        else:
            raise ValueError

        if self.config.data.dataset in ["CelebA_HQ", "LSUN"]:
            model = DDPM(self.config)
            if self.args.model_path:
                init_ckpt = torch.load(self.args.model_path)
            else:
                init_ckpt = torch.hub.load_state_dict_from_url(url, map_location=self.device)
            learn_sigma = False
            print("Original diffusion model loaded.")
        elif self.config.data.dataset in ["FFHQ", "AFHQ"]:
            model = i_DDPM(self.config.data.dataset)
            if self.args.model_path:
                init_ckpt = torch.load(self.args.model_path)
            else:
                init_ckpt = torch.load(MODEL_PATHS[self.config.data.dataset])
            learn_sigma = True
            print("Improved diffusion model loaded.")
        else:
            print('Dataset not implemented')
            raise ValueError
        model.load_state_dict(init_ckpt)
        model.to(self.device)
        model = torch.nn.DataParallel(model)
        model.eval()

        # ---------- Load boundary ----------#
        boundary_h = pickle.load(open('./boundary/smile_boundary_h.sav', 'rb'))
        a = boundary_h.coef_.reshape(1, 512*8*8).astype(np.float32)
        a = a / np.linalg.norm(a)

        boundary_z = pickle.load(open('./boundary/smile_boundary_z.sav', 'rb'))
        z_a = boundary_z.coef_.reshape(1, 3*256*256).astype(np.float32)
        z_a = z_a / np.linalg.norm(z_a)  # normalized boundary

        print("Boundary loaded! In shape:", np.shape(a), np.shape(z_a))

        with torch.no_grad():
            # ---- Invert image to latent in case of deterministic inversion ---- #
            if self.args.deterministic_inv:
                x_lat_path = os.path.join(self.args.image_folder, f'x_lat_t{self.args.t_0}_ninv{self.args.n_inv_step}.pth')
                h_lat_path = os.path.join(self.args.image_folder, f'h_lat_t{self.args.t_0}_ninv{self.args.n_inv_step}.pth')
                if not os.path.exists(x_lat_path):
                    seq_inv = np.linspace(0, 1, self.args.n_inv_step) * self.args.t_0
                    seq_inv = [int(s) for s in list(seq_inv)]
                    seq_inv_next = [-1] + list(seq_inv[:-1])

                    x = x0.clone()
                    with tqdm(total=len(seq_inv), desc="Inversion process") as progress_bar:
                        for it, (i, j) in enumerate(zip(seq_inv_next[1:], seq_inv[1:])):
                            t = (torch.ones(n) * i).to(self.device)
                            t_prev = (torch.ones(n) * j).to(self.device)

                            x, mid_h_g = denoising_step(x, t=t, t_next=t_prev, models=model,
                                                        logvars=self.logvar,
                                                        sampling_type='ddim',
                                                        b=self.betas,
                                                        eta=0,
                                                        learn_sigma=learn_sigma,
                                                        ratio=0,
                                                        )
                            progress_bar.update(1)
                    x_lat = x.clone()
                    h_lat = mid_h_g.clone()
                    torch.save(x_lat, x_lat_path)
                    torch.save(h_lat, h_lat_path)
                else:
                    print('Latent exists.')
                    x_lat = torch.load(x_lat_path)
                    h_lat = torch.load(h_lat_path)
                    print("Finished inversion for the given image!", h_lat.size())

            # ----------- Generative Process -----------#
            print(f"Sampling type: {self.args.sample_type.upper()} with eta {self.args.eta}, "
                  f" Steps: {self.args.n_test_step}/{self.args.t_0}")

            # ----- Editing space ------ #
            start_distance = self.args.start_distance
            end_distance = self.args.end_distance
            edit_img_number = self.args.edit_img_number
            # [-100, 100]
            linspace = np.linspace(start_distance, end_distance, edit_img_number)
            latent_code = h_lat.cpu().view(1, -1).numpy()
            linspace = linspace - latent_code.dot(a.T)
            linspace = linspace.reshape(-1, 1).astype(np.float32)
            edit_h_seq = latent_code + linspace * a

            z_linspace = np.linspace(start_distance, end_distance, edit_img_number)
            z_latent_code = x_lat.cpu().view(1, -1).numpy()
            z_linspace = z_linspace - z_latent_code.dot(z_a.T)
            z_linspace = z_linspace.reshape(-1, 1).astype(np.float32)
            edit_z_seq = z_latent_code + z_linspace * z_a

            if self.args.n_test_step != 0:
                seq_test = np.linspace(0, 1, self.args.n_test_step) * self.args.t_0
                seq_test = [int(s) for s in list(seq_test)]
                print('Uniform skip type')
            else:
                seq_test = list(range(self.args.t_0))
                print('No skip')
            seq_test_next = [-1] + list(seq_test[:-1])

            for it in range(self.args.n_iter):
                if self.args.deterministic_inv:
                    x = x_lat.clone()
                else:
                    e = torch.randn_like(x0)
                    # Use a distinct name here; the original overwrote the boundary
                    # direction `a` loaded above with the alphas cumprod.
                    alpha_bar = (1 - self.betas).cumprod(dim=0)
                    x = x0 * alpha_bar[self.args.t_0 - 1].sqrt() + e * (1.0 - alpha_bar[self.args.t_0 - 1]).sqrt()
                tvu.save_image((x + 1) * 0.5, os.path.join(self.args.image_folder,
                                                           f'1_lat_ninv{self.args.n_inv_step}.png'))

                for k in range(edit_img_number):
                    time_in_start = time.time()

                    with tqdm(total=len(seq_test), desc="Generative process {}".format(it)) as progress_bar:
                        edit_h = torch.from_numpy(edit_h_seq[k]).to(self.device).view(-1, 512, 8, 8)
                        edit_z = torch.from_numpy(edit_z_seq[k]).to(self.device).view(-1, 3, 256, 256)
                        for i, j in zip(reversed(seq_test), reversed(seq_test_next)):
                            t = (torch.ones(n) * i).to(self.device)
                            t_next = (torch.ones(n) * j).to(self.device)

                            edit_z, edit_h = denoising_step(edit_z, t=t, t_next=t_next, models=model,
                                                            logvars=self.logvar,
                                                            sampling_type=self.args.sample_type,
                                                            b=self.betas,
                                                            eta=1.0,
                                                            learn_sigma=learn_sigma,
                                                            ratio=self.args.model_ratio,
                                                            hybrid=self.args.hybrid_noise,
                                                            hybrid_config=HYBRID_CONFIG,
                                                            edit_h=edit_h,
                                                            )

                    x0 = x.clone()
                    save_edit = "edited_" + str(k) + ".png"
                    tvu.save_image((edit_z + 1) * 0.5, os.path.join("edit_output", save_edit))
                    time_in_end = time.time()
                    print(f"Editing for 1 image takes {time_in_end - time_in_start:.4f}s")

                # this is for recons
                with tqdm(total=len(seq_test), desc="Generative process {}".format(it)) as progress_bar:
                    for i, j in zip(reversed(seq_test), reversed(seq_test_next)):
                        t = (torch.ones(n) * i).to(self.device)
                        t_next = (torch.ones(n) * j).to(self.device)
                        x_lat, _ = denoising_step(x_lat, t=t, t_next=t_next, models=model,
                                                  logvars=self.logvar,
                                                  sampling_type=self.args.sample_type,
                                                  b=self.betas,
                                                  # eta=self.args.eta,
                                                  eta=0.0,
                                                  learn_sigma=learn_sigma,
                                                  ratio=self.args.model_ratio,
                                                  hybrid=self.args.hybrid_noise,
                                                  hybrid_config=HYBRID_CONFIG,
                                                  edit_h=None,
                                                  )

                        # added intermediate step vis
                        if (i - 99) % 100 == 0:
                            tvu.save_image((x + 1) * 0.5, os.path.join(self.args.image_folder,
                                                                       f'2_lat_t{self.args.t_0}_ninv{self.args.n_inv_step}_ngen{self.args.n_test_step}_{i}_it{it}.png'))
                        progress_bar.update(1)

                x0 = x.clone()
                save_edit = "recons.png"
                tvu.save_image((x_lat + 1) * 0.5, os.path.join("edit_output", save_edit))

        return None
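Note: denoising_step itself lives in utils/diffusion_utils.py (listed above but not shown in this view). For orientation, the deterministic passes above (eta=0) follow the standard DDIM update; the sketch below is a textbook restatement under that assumption, not the repo's implementation:

import torch

def ddim_step(x_t, eps, alpha_bar_t, alpha_bar_next):
    # Predict x_0 from the noise estimate, then step deterministically (eta = 0).
    x0_pred = (x_t - (1 - alpha_bar_t).sqrt() * eps) / alpha_bar_t.sqrt()
    return alpha_bar_next.sqrt() * x0_pred + (1 - alpha_bar_next).sqrt() * eps

Inversion runs the same update with the time order reversed, which is why edit_image_boundary can cache x_lat/h_lat once and reconstruct the input when edit_h is None.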
configs/afhq.yml
ADDED
@@ -0,0 +1,35 @@
data:
  dataset: "AFHQ"
  category: "dog"
  image_size: 256
  channels: 3
  logit_transform: false
  uniform_dequantization: false
  gaussian_dequantization: false
  random_flip: true
  rescaled: true
  num_workers: 0

model:
  type: "simple"
  in_channels: 3
  out_ch: 3
  ch: 128
  ch_mult: [1, 1, 2, 2, 4, 4]
  num_res_blocks: 2
  attn_resolutions: [16, ]
  dropout: 0.0
  var_type: fixedsmall
  ema_rate: 0.999
  ema: True
  resamp_with_conv: True

diffusion:
  beta_schedule: linear
  beta_start: 0.0001
  beta_end: 0.02
  num_diffusion_timesteps: 1000

sampling:
  batch_size: 4
  last_only: True
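Note: these YAML files are read into attribute-style configs (the code above accesses config.data.image_size, config.diffusion.beta_start, and so on). A minimal loading sketch; the dict2namespace helper is illustrative, not necessarily what main.py uses:

import argparse
import yaml

def dict2namespace(d):
    ns = argparse.Namespace()
    for k, v in d.items():
        setattr(ns, k, dict2namespace(v) if isinstance(v, dict) else v)
    return ns

with open('configs/afhq.yml') as f:
    config = dict2namespace(yaml.safe_load(f))
print(config.data.image_size, config.diffusion.num_diffusion_timesteps)  # 256 1000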
configs/bedroom.yml
ADDED
@@ -0,0 +1,35 @@
data:
  dataset: "LSUN"
  category: "bedroom"
  image_size: 256
  channels: 3
  logit_transform: false
  uniform_dequantization: false
  gaussian_dequantization: false
  random_flip: true
  rescaled: true
  num_workers: 0

model:
  type: "simple"
  in_channels: 3
  out_ch: 3
  ch: 128
  ch_mult: [1, 1, 2, 2, 4, 4]
  num_res_blocks: 2
  attn_resolutions: [16, ]
  dropout: 0.0
  var_type: fixedsmall
  ema_rate: 0.999
  ema: True
  resamp_with_conv: True

diffusion:
  beta_schedule: linear
  beta_start: 0.0001
  beta_end: 0.02
  num_diffusion_timesteps: 1000

sampling:
  batch_size: 4
  last_only: True
configs/celeba.yml
ADDED
@@ -0,0 +1,35 @@
data:
  dataset: "CelebA_HQ"
  category: "CelebA_HQ"
  image_size: 256
  channels: 3
  logit_transform: false
  uniform_dequantization: false
  gaussian_dequantization: false
  random_flip: true
  rescaled: true
  num_workers: 0

model:
  type: "simple"
  in_channels: 3
  out_ch: 3
  ch: 128
  ch_mult: [1, 1, 2, 2, 4, 4]
  num_res_blocks: 2
  attn_resolutions: [16, ]
  dropout: 0.0
  var_type: fixedsmall
  ema_rate: 0.999
  ema: True
  resamp_with_conv: True

diffusion:
  beta_schedule: linear
  beta_start: 0.0001
  beta_end: 0.02
  num_diffusion_timesteps: 1000

sampling:
  batch_size: 4
  last_only: True
configs/church.yml
ADDED
@@ -0,0 +1,35 @@
data:
  dataset: "LSUN"
  category: "church_outdoor"
  image_size: 256
  channels: 3
  logit_transform: false
  uniform_dequantization: false
  gaussian_dequantization: false
  random_flip: true
  rescaled: true
  num_workers: 0

model:
  type: "simple"
  in_channels: 3
  out_ch: 3
  ch: 128
  ch_mult: [1, 1, 2, 2, 4, 4]
  num_res_blocks: 2
  attn_resolutions: [16, ]
  dropout: 0.0
  var_type: fixedsmall
  ema_rate: 0.999
  ema: True
  resamp_with_conv: True

diffusion:
  beta_schedule: linear
  beta_start: 0.0001
  beta_end: 0.02
  num_diffusion_timesteps: 1000

sampling:
  batch_size: 4
  last_only: True
configs/imagenet.yml
ADDED
@@ -0,0 +1,35 @@
data:
  dataset: "IMAGENET"
  category: "IMAGENET"
  image_size: 512
  channels: 3
  logit_transform: false
  uniform_dequantization: false
  gaussian_dequantization: false
  random_flip: true
  rescaled: true
  num_workers: 0

model:
  type: "simple"
  in_channels: 3
  out_ch: 3
  ch: 128
  ch_mult: [1, 1, 2, 2, 4, 4]
  num_res_blocks: 2
  attn_resolutions: [16, ]
  dropout: 0.0
  var_type: fixedsmall
  ema_rate: 0.999
  ema: True
  resamp_with_conv: True

diffusion:
  beta_schedule: linear
  beta_start: 0.0001
  beta_end: 0.02
  num_diffusion_timesteps: 1000

sampling:
  batch_size: 4
  last_only: True
configs/paths_config.py
ADDED
@@ -0,0 +1,25 @@
DATASET_PATHS = {
    'FFHQ': '/n/fs/visualai-scr/Data/CelebA-HQ/',
    'CelebA_HQ': '/n/fs/visualai-scr/Data/CelebA-HQ/',
    'AFHQ': '/n/fs/visualai-scr/Data/AFHQ-Dog/',
    'LSUN': '/n/fs/yz-diff/dataset/',
    'IMAGENET': 'data/imagenet/',
}

MODEL_PATHS = {
    'AFHQ': "pretrained/afhqdog_p2.pt",
    'FFHQ': "pretrained/ffhq_10m.pt",
    'ir_se50': 'pretrained/model_ir_se50.pth',
    'IMAGENET': "pretrained/512x512_diffusion.pt",
    'shape_predictor': "pretrained/shape_predictor_68_face_landmarks.dat.bz2",
}


HYBRID_MODEL_PATHS = [
    './checkpoint/human_face/curly_hair_t401.pth',
    './checkpoint/human_face/with_makeup_t401.pth',
]

HYBRID_CONFIG = \
    {300: [0.4, 0.6, 0],
     0: [0.15, 0.15, 0.7]}
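Note: a plausible reading of HYBRID_CONFIG is a map from timestep thresholds to per-model mixing weights over the HYBRID_MODEL_PATHS checkpoints plus the base model, where the weight vector whose key is the largest threshold not exceeding t applies. This interpretation is an assumption (denoising_step, which consumes hybrid_config, is not shown in this view):

HYBRID_CONFIG = {300: [0.4, 0.6, 0],
                 0: [0.15, 0.15, 0.7]}

def weights_at(t, config=HYBRID_CONFIG):
    # Largest threshold not exceeding t wins.
    for threshold in sorted(config, reverse=True):
        if t >= threshold:
            return config[threshold]

print(weights_at(401), weights_at(120))  # [0.4, 0.6, 0] [0.15, 0.15, 0.7]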
data_download.sh
ADDED
@@ -0,0 +1,38 @@
# Modified version of download.sh in https://github.com/naver-ai/StyleMapGAN

DATASET=$1
BASE_DIR=$2   # NOTE: currently unused; the dataset folders below are hardcoded

if [ $DATASET == "celeba_hq" ]; then
    URL="https://docs.google.com/uc?export=download&id=1R72NB79CX0MpnmWSli2SMu-Wp-M0xI-o"
    DATASET_FOLDER="/n/fs/visualai-scr/Data/CelebA-HQ"
    ZIP_FILE=$DATASET_FOLDER/celeba_hq_raw.zip
elif [ $DATASET == "afhq" ]; then
    URL="https://docs.google.com/uc?export=download&id=1Pf4f6Y27lQX9y9vjeSQnoOQntw_ln7il"
    DATASET_FOLDER="./data/afhq"
    ZIP_FILE=$DATASET_FOLDER/afhq_raw.zip
else
    echo "Unknown DATASET"
    exit 1
fi
mkdir -p $DATASET_FOLDER

# wget --no-check-certificate -r $URL -O $ZIP_FILE
# wget --load-cookies ~/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies ~/cookies.txt --keep-session-cookies --no-check-certificate $URL -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1R72NB79CX0MpnmWSli2SMu-Wp-M0xI-o" -O $ZIP_FILE && rm -rf ~/cookies.txt
# unzip $ZIP_FILE -d $DATASET_FOLDER
# rm $ZIP_FILE

# raw images to LMDB format
TARGET_SIZE=256,1024
for DATASET_TYPE in "train" "test" "val"; do
    python utils/prepare_lmdb_data.py --out $DATASET_FOLDER/LMDB_$DATASET_TYPE --size $TARGET_SIZE $DATASET_FOLDER/raw_images/$DATASET_TYPE --attr gender
done

wget --load-cookies ~/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies ~/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1R72NB79CX0MpnmWSli2SMu-Wp-M0xI-o' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1R72NB79CX0MpnmWSli2SMu-Wp-M0xI-o" -O a.zip && rm -rf ~/cookies.txt
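Usage note: the script takes the dataset name as its first positional argument, e.g. "bash data_download.sh celeba_hq .". As flagged in the comment above, BASE_DIR is read from $2 but never used, and the download/unzip steps are commented out, so only the LMDB conversion loop and the final wget actually run.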
datasets/AFHQ_dataset.py
ADDED
@@ -0,0 +1,42 @@
from PIL import Image
from glob import glob
import os
from torch.utils.data import Dataset
import torchvision.transforms as tfs


class AFHQ_dataset(Dataset):
    def __init__(self, image_root, transform=None, mode='train', animal_class='dog', img_size=256):
        super().__init__()
        self.image_paths = glob(os.path.join(image_root, mode, animal_class, '*.jpg'))
        self.transform = transform
        self.img_size = img_size

    def __getitem__(self, index):
        image_path = self.image_paths[index]
        x = Image.open(image_path)
        x = x.resize((self.img_size, self.img_size))
        if self.transform is not None:
            x = self.transform(x)
        return x

    def __len__(self):
        return len(self.image_paths)


################################################################################

def get_afhq_dataset(data_root, config):
    train_transform = tfs.Compose([tfs.ToTensor(),
                                   tfs.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5),
                                                 inplace=True)])

    test_transform = tfs.Compose([tfs.ToTensor(),
                                  tfs.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5),
                                                inplace=True)])

    train_dataset = AFHQ_dataset(data_root, transform=train_transform, mode='train', animal_class='dog',
                                 img_size=config.data.image_size)
    test_dataset = AFHQ_dataset(data_root, transform=test_transform, mode='val', animal_class='dog',
                                img_size=config.data.image_size)

    return train_dataset, test_dataset
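Note: a minimal usage sketch for the dataset above; the SimpleNamespace stands in for the real config object (values mirror configs/afhq.yml) and the data root is a placeholder:

from types import SimpleNamespace
from torch.utils.data import DataLoader
from datasets.AFHQ_dataset import get_afhq_dataset

config = SimpleNamespace(data=SimpleNamespace(image_size=256))
train_ds, val_ds = get_afhq_dataset('/path/to/afhq', config)  # expects train/dog and val/dog subfolders
loader = DataLoader(train_ds, batch_size=4, shuffle=True)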
datasets/CelebA_HQ_dataset.py
ADDED
@@ -0,0 +1,83 @@
from torch.utils.data import Dataset
import lmdb
from io import BytesIO
from PIL import Image
import torchvision.transforms as tfs
import os


class MultiResolutionDataset(Dataset):
    def __init__(self, path, transform, resolution=256):
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        if not self.env:
            raise IOError("Cannot open lmdb dataset", path)

        with self.env.begin(write=False) as txn:
            self.length = int(txn.get("length".encode("utf-8")).decode("utf-8"))

        self.resolution = resolution
        self.transform = transform

        attr_file_path = '/n/fs/yz-diff/inversion/list_attr_celeba.txt'
        self.labels = file_to_list(attr_file_path)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        with self.env.begin(write=False) as txn:
            key = f"{self.resolution}-{str(index).zfill(5)}".encode("utf-8")
            key_label = f"{str(index).zfill(5)}".encode("utf-8")
            print("check key:", key, key_label)
            img_bytes = txn.get(key)
            img_id = int(txn.get(key_label).decode("utf-8"))

        buffer = BytesIO(img_bytes)
        img = Image.open(buffer)
        img = self.transform(img)

        # Map the attribute to its column: field 0 is the image name, fields 1-40
        # are the +/-1 attribute flags, so index 32 selects the 32nd attribute
        # (Smiling, in the standard CelebA ordering).
        attr_label = self.labels[img_id - 1].split()
        label = int(attr_label[32])
        print("check img_id and label:", img_id, label)

        return img, label


################################################################################

def get_celeba_dataset(data_root, config):
    train_transform = tfs.Compose([tfs.ToTensor(),
                                   tfs.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5),
                                                 inplace=True)])

    test_transform = tfs.Compose([tfs.ToTensor(),
                                  tfs.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5),
                                                inplace=True)])

    train_dataset = MultiResolutionDataset(os.path.join(data_root, 'LMDB_train'),
                                           train_transform, config.data.image_size)
    test_dataset = MultiResolutionDataset(os.path.join(data_root, 'LMDB_test'),
                                          test_transform, config.data.image_size)

    return train_dataset, test_dataset


def file_to_list(filename):
    with open(filename, encoding='utf-8') as f:
        files = f.readlines()
    files = [f.rstrip() for f in files]
    return files
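Note: the label lookup in __getitem__ above assumes the standard list_attr_celeba.txt layout, where each row is an image name followed by forty +/-1 attribute flags. A self-contained sketch of that parsing (the column constant 32 is taken from the code above):

def parse_celeba_attr(row, attr_index=32):
    # row: "<image_name> <attr_1> ... <attr_40>", flags are +1 or -1
    fields = row.split()
    return int(fields[attr_index])

print(parse_celeba_attr("000001.jpg " + " ".join(["-1"] * 40)))  # -1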
datasets/CelebA_HQ_dataset_with_label.py
ADDED
@@ -0,0 +1,63 @@
from torch.utils.data import Dataset
import lmdb
from io import BytesIO
from PIL import Image
import torchvision.transforms as tfs
import os


# Variant of the CelebA-HQ LMDB dataset that returns images only (despite the
# file name, no attribute label is read here).
class MultiResolutionDataset(Dataset):
    def __init__(self, path, transform, resolution=256):
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        if not self.env:
            raise IOError("Cannot open lmdb dataset", path)

        with self.env.begin(write=False) as txn:
            self.length = int(txn.get("length".encode("utf-8")).decode("utf-8"))

        self.resolution = resolution
        self.transform = transform

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        with self.env.begin(write=False) as txn:
            key = f"{self.resolution}-{str(index).zfill(5)}".encode("utf-8")
            img_bytes = txn.get(key)

        buffer = BytesIO(img_bytes)
        img = Image.open(buffer)
        img = self.transform(img)

        return img


################################################################################

def get_celeba_dataset(data_root, config):
    train_transform = tfs.Compose([tfs.ToTensor(),
                                   tfs.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5),
                                                 inplace=True)])

    test_transform = tfs.Compose([tfs.ToTensor(),
                                  tfs.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5),
                                                inplace=True)])

    train_dataset = MultiResolutionDataset(os.path.join(data_root, 'LMDB_train'),
                                           train_transform, config.data.image_size)
    test_dataset = MultiResolutionDataset(os.path.join(data_root, 'LMDB_test'),
                                          test_transform, config.data.image_size)

    return train_dataset, test_dataset
datasets/IMAGENET_dataset.py
ADDED
@@ -0,0 +1,102 @@
from PIL import Image
from glob import glob
import os
from torch.utils.data import Dataset
import math
import numpy as np
import random
from .imagenet_dic import IMAGENET_DIC


def get_imagenet_dataset(data_root, config, class_num=None, random_crop=True, random_flip=False):
    train_dataset = IMAGENET_dataset(data_root, mode='train', class_num=class_num, img_size=config.data.image_size,
                                     random_crop=random_crop, random_flip=random_flip)
    test_dataset = IMAGENET_dataset(data_root, mode='val', class_num=class_num, img_size=config.data.image_size,
                                    random_crop=random_crop, random_flip=random_flip)

    return train_dataset, test_dataset


###################################################################

class IMAGENET_dataset(Dataset):
    def __init__(self, image_root, mode='val', class_num=None, img_size=512, random_crop=True, random_flip=False):
        super().__init__()
        if class_num is not None:
            self.data_dir = os.path.join(image_root, mode, IMAGENET_DIC[str(class_num)][0], '*.JPEG')
            self.image_paths = sorted(glob(self.data_dir))
        else:
            self.data_dir = os.path.join(image_root, mode, '*', '*.JPEG')
            self.image_paths = sorted(glob(self.data_dir))
        self.img_size = img_size
        self.random_crop = random_crop
        self.random_flip = random_flip
        self.class_num = class_num

    def __getitem__(self, index):
        f = self.image_paths[index]
        pil_image = Image.open(f)
        pil_image.load()
        pil_image = pil_image.convert("RGB")

        if self.random_crop:
            arr = random_crop_arr(pil_image, self.img_size)
        else:
            arr = center_crop_arr(pil_image, self.img_size)

        if self.random_flip and random.random() < 0.5:
            arr = arr[:, ::-1]

        arr = arr.astype(np.float32) / 127.5 - 1

        # y = [self.class_num, IMAGENET_DIC[str(self.class_num)][0], IMAGENET_DIC[str(self.class_num)][1]]

        return np.transpose(arr, [2, 0, 1])  # , y

    def __len__(self):
        return len(self.image_paths)


def center_crop_arr(pil_image, image_size):
    # We are not on a new enough PIL to support the `reducing_gap`
    # argument, which uses BOX downsampling at powers of two first.
    # Thus, we do it by hand to improve downsample quality.
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )

    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]


def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
    min_smaller_dim_size = math.ceil(image_size / max_crop_frac)
    max_smaller_dim_size = math.ceil(image_size / min_crop_frac)
    smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1)

    # We are not on a new enough PIL to support the `reducing_gap`
    # argument, which uses BOX downsampling at powers of two first.
    # Thus, we do it by hand to improve downsample quality.
    while min(*pil_image.size) >= 2 * smaller_dim_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    scale = smaller_dim_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )

    arr = np.array(pil_image)
    crop_y = random.randrange(arr.shape[0] - image_size + 1)
    crop_x = random.randrange(arr.shape[1] - image_size + 1)
    return arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]
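Note: a quick check of the crop helper above on a synthetic image (the shorter side is scaled to the target size, then center-cropped):

import numpy as np
from PIL import Image
from datasets.IMAGENET_dataset import center_crop_arr

img = Image.fromarray(np.zeros((384, 640, 3), dtype=np.uint8))  # 640x384 RGB
arr = center_crop_arr(img, 256)
print(arr.shape)  # (256, 256, 3)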
datasets/LSUN_dataset.py
ADDED
@@ -0,0 +1,304 @@
import os.path
from collections.abc import Iterable
from torchvision.datasets.utils import verify_str_arg, iterable_to_str


from PIL import Image
import io
import pickle
import os
import torch
import torch.utils.data as data
import torchvision.transforms as tfs

class VisionDataset(data.Dataset):
    _repr_indent = 4

    def __init__(self, root, transforms=None, transform=None, target_transform=None):
        # NOTE: the original check used `torch._six.string_classes`, which was
        # removed in recent PyTorch releases; a plain `str` check is equivalent.
        if isinstance(root, str):
            root = os.path.expanduser(root)
        self.root = root

        has_transforms = transforms is not None
        has_separate_transform = transform is not None or target_transform is not None
        if has_transforms and has_separate_transform:
            raise ValueError("Only transforms or transform/target_transform can "
                             "be passed as argument")

        # for backwards-compatibility
        self.transform = transform
        self.target_transform = target_transform

        if has_separate_transform:
            transforms = StandardTransform(transform, target_transform)
        self.transforms = transforms

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __repr__(self):
        head = "Dataset " + self.__class__.__name__
        body = ["Number of datapoints: {}".format(self.__len__())]
        if self.root is not None:
            body.append("Root location: {}".format(self.root))
        body += self.extra_repr().splitlines()
        if hasattr(self, 'transform') and self.transform is not None:
            body += self._format_transform_repr(self.transform,
                                                "Transforms: ")
        if hasattr(self, 'target_transform') and self.target_transform is not None:
            body += self._format_transform_repr(self.target_transform,
                                                "Target transforms: ")
        lines = [head] + [" " * self._repr_indent + line for line in body]
        return '\n'.join(lines)

    def _format_transform_repr(self, transform, head):
        lines = transform.__repr__().splitlines()
        return (["{}{}".format(head, lines[0])] +
                ["{}{}".format(" " * len(head), line) for line in lines[1:]])

    def extra_repr(self):
        return ""


class StandardTransform(object):
    def __init__(self, transform=None, target_transform=None):
        self.transform = transform
        self.target_transform = target_transform

    def __call__(self, input, target):
        if self.transform is not None:
            input = self.transform(input)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return input, target

    def _format_transform_repr(self, transform, head):
        lines = transform.__repr__().splitlines()
        return (["{}{}".format(head, lines[0])] +
                ["{}{}".format(" " * len(head), line) for line in lines[1:]])

    def __repr__(self):
        body = [self.__class__.__name__]
        if self.transform is not None:
            body += self._format_transform_repr(self.transform,
                                                "Transform: ")
        if self.target_transform is not None:
            body += self._format_transform_repr(self.target_transform,
                                                "Target transform: ")

        return '\n'.join(body)

################################################################

class LSUNClass(VisionDataset):
    def __init__(self, root, transform=None, target_transform=None):
        import lmdb

        super(LSUNClass, self).__init__(
            root, transform=transform, target_transform=target_transform
        )

        self.env = lmdb.open(
            root,
            max_readers=1,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
        with self.env.begin(write=False) as txn:
            self.length = txn.stat()["entries"]
        root_split = root.split("/")
        cache_file = os.path.join("/".join(root_split[:-1]), f"_cache_{root_split[-1]}")
        if os.path.isfile(cache_file):
            self.keys = pickle.load(open(cache_file, "rb"))
        else:
            with self.env.begin(write=False) as txn:
                self.keys = [key for key, _ in txn.cursor()]
            pickle.dump(self.keys, open(cache_file, "wb"))

    def __getitem__(self, index):
        img, target = None, None
        env = self.env
        with env.begin(write=False) as txn:
            imgbuf = txn.get(self.keys[index])

        buf = io.BytesIO()
        buf.write(imgbuf)
        buf.seek(0)
        img = Image.open(buf).convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return self.length



class LSUN(VisionDataset):
    """
    `LSUN <https://www.yf.io/p/lsun>`_ dataset.

    Args:
        root (string): Root directory for the database files.
        classes (string or list): One of {'train', 'val', 'test'} or a list of
            categories to load, e.g. ['bedroom_train', 'church_outdoor_train'].
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(self, root, classes="train", transform=None, target_transform=None):
        super(LSUN, self).__init__(
            root, transform=transform, target_transform=target_transform
        )
        self.classes = self._verify_classes(classes)

        # for each class, create an LSUNClassDataset
        self.dbs = []
        for c in self.classes:
            self.dbs.append(
                LSUNClass(root=root + "/" + c + "_lmdb", transform=transform)
            )

        self.indices = []
        count = 0
        for db in self.dbs:
            count += len(db)
            self.indices.append(count)

        self.length = count

    def _verify_classes(self, classes):
        categories = [
            "bedroom",
            "bridge",
            "church_outdoor",
            "classroom",
            "conference_room",
            "dining_room",
            "kitchen",
            "living_room",
            "restaurant",
            "tower",
        ]
        dset_opts = ["train", "val", "test"]

        try:
            verify_str_arg(classes, "classes", dset_opts)
            if classes == "test":
                classes = [classes]
            else:
                classes = [c + "_" + classes for c in categories]
        except ValueError:
            if not isinstance(classes, Iterable):
                msg = (
                    "Expected type str or Iterable for argument classes, "
                    "but got type {}."
                )
                raise ValueError(msg.format(type(classes)))

            classes = list(classes)
            msg_fmtstr = (
                "Expected type str for elements in argument classes, "
                "but got type {}."
            )
            for c in classes:
                verify_str_arg(c, custom_msg=msg_fmtstr.format(type(c)))
                c_short = c.split("_")
                category, dset_opt = "_".join(c_short[:-1]), c_short[-1]

                msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
                msg = msg_fmtstr.format(
                    category, "LSUN class", iterable_to_str(categories)
                )
                verify_str_arg(category, valid_values=categories, custom_msg=msg)

                msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts))
                verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg)

        return classes

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target) where target is the index of the target category.
        """
        target = 0
        sub = 0
        for ind in self.indices:
            if index < ind:
                break
            target += 1
            sub = ind

        db = self.dbs[target]
        index = index - sub

        if self.target_transform is not None:
            target = self.target_transform(target)

        img, _ = db[index]
        return img  # , target

    def __len__(self):
        return self.length

    def extra_repr(self):
        return "Classes: {classes}".format(**self.__dict__)


################################################################

def get_lsun_dataset(data_root, config):

    train_folder = "{}_train".format(config.data.category)
    val_folder = "{}_val".format(config.data.category)

    train_dataset = LSUN(
        root=os.path.join(data_root),
        classes=[train_folder],
        transform=tfs.Compose(
            [
                tfs.Resize(config.data.image_size),
                tfs.CenterCrop(config.data.image_size),
                tfs.ToTensor(),
                tfs.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5),
                              inplace=True)
            ]
        ),
    )

    test_dataset = LSUN(
        root=os.path.join(data_root),
        classes=[val_folder],
        transform=tfs.Compose(
            [
                tfs.Resize(config.data.image_size),
                tfs.CenterCrop(config.data.image_size),
                tfs.ToTensor(),
                tfs.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5),
                              inplace=True),
            ]
        ),
    )

    return train_dataset, test_dataset
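A minimal usage sketch for get_lsun_dataset (hypothetical LSUN root containing bedroom_train_lmdb and bedroom_val_lmdb; the SimpleNamespace stands in for the parsed configs/bedroom.yml):

    from types import SimpleNamespace
    from datasets.LSUN_dataset import get_lsun_dataset

    config = SimpleNamespace(data=SimpleNamespace(category='bedroom', image_size=256))
    train_set, val_set = get_lsun_dataset('/path/to/lsun', config)
    img = train_set[0]             # tensor of shape (3, 256, 256), normalized to [-1, 1]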
datasets/celeba_attr.txt
ADDED
@@ -0,0 +1,40 @@
5_o_Clock_Shadow
Arched_Eyebrows
Attractive
Bags_Under_Eyes
Bald
Bangs
Big_Lips
Big_Nose
Black_Hair
Blond_Hair
Blurry
Brown_Hair
Bushy_Eyebrows
Chubby
Double_Chin
Eyeglasses
Goatee
Gray_Hair
Heavy_Makeup
High_Cheekbones
Male
Mouth_Slightly_Open
Mustache
Narrow_Eyes
No_Beard
Oval_Face
Pale_Skin
Pointy_Nose
Receding_Hairline
Rosy_Cheeks
Sideburns
Smiling
Straight_Hair
Wavy_Hair
Wearing_Earrings
Wearing_Hat
Wearing_Lipstick
Wearing_Necklace
Wearing_Necktie
Young
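The line order above is significant: line i presumably names attribute column i of the CelebA annotations. A small sketch (assuming this file path) that builds a name-to-index map:

    with open('datasets/celeba_attr.txt') as f:
        attrs = [line.strip() for line in f if line.strip()]

    attr_to_idx = {name: i for i, name in enumerate(attrs)}
    print(attr_to_idx['Smiling'])  # -> 31 (0-based index into the 40 attributes)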
datasets/data_utils.py
ADDED
@@ -0,0 +1,44 @@
from .AFHQ_dataset import get_afhq_dataset
from .CelebA_HQ_dataset import get_celeba_dataset
from .LSUN_dataset import get_lsun_dataset
from torch.utils.data import DataLoader
from .IMAGENET_dataset import get_imagenet_dataset

def get_dataset(dataset_type, dataset_paths, config, target_class_num=None, gender=None):
    if dataset_type == 'AFHQ':
        train_dataset, test_dataset = get_afhq_dataset(dataset_paths['AFHQ'], config)
    elif dataset_type == "LSUN":
        train_dataset, test_dataset = get_lsun_dataset(dataset_paths['LSUN'], config)
    elif dataset_type == "CelebA_HQ":
        train_dataset, test_dataset = get_celeba_dataset(dataset_paths['CelebA_HQ'], config)
    elif dataset_type == "IMAGENET":
        train_dataset, test_dataset = get_imagenet_dataset(dataset_paths['IMAGENET'], config, class_num=target_class_num)
    else:
        raise ValueError(f"Unknown dataset type: {dataset_type}")

    return train_dataset, test_dataset


def get_dataloader(train_dataset, test_dataset, bs_train=1, num_workers=0):
    train_loader = DataLoader(
        train_dataset,
        batch_size=bs_train,
        drop_last=True,
        shuffle=True,
        sampler=None,
        num_workers=num_workers,
        pin_memory=True,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=1,
        drop_last=True,
        sampler=None,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )

    return {'train': train_loader, 'test': test_loader}
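A minimal end-to-end sketch of these two helpers (hypothetical dataset path and a config stand-in; the real values come from configs/paths_config.py and configs/*.yml, and get_celeba_dataset may require more config fields than shown):

    from types import SimpleNamespace
    from datasets.data_utils import get_dataset, get_dataloader

    dataset_paths = {'CelebA_HQ': '/path/to/celeba_hq'}             # placeholder path
    config = SimpleNamespace(data=SimpleNamespace(image_size=256))  # minimal stand-in
    train_set, test_set = get_dataset('CelebA_HQ', dataset_paths, config)
    loaders = get_dataloader(train_set, test_set, bs_train=4, num_workers=2)
    batch = next(iter(loaders['train']))   # batch of 4 training images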
datasets/imagenet_dic.py
ADDED
@@ -0,0 +1,408 @@
IMAGENET_DIC = {"0": ["n01440764", "tench"], "1": ["n01443537", "goldfish"], "2": ["n01484850", "great_white_shark"],
                "3": ["n01491361", "tiger_shark"], "4": ["n01494475", "hammerhead"], "5": ["n01496331", "electric_ray"],
                "6": ["n01498041", "stingray"], "7": ["n01514668", "cock"], "8": ["n01514859", "hen"],
                "9": ["n01518878", "ostrich"], "10": ["n01530575", "brambling"], "11": ["n01531178", "goldfinch"],
                "12": ["n01532829", "house_finch"], "13": ["n01534433", "junco"], "14": ["n01537544", "indigo_bunting"],
                "15": ["n01558993", "robin"], "16": ["n01560419", "bulbul"], "17": ["n01580077", "jay"],
                "18": ["n01582220", "magpie"], "19": ["n01592084", "chickadee"], "20": ["n01601694", "water_ouzel"],
                "21": ["n01608432", "kite"], "22": ["n01614925", "bald_eagle"], "23": ["n01616318", "vulture"],
                "24": ["n01622779", "great_grey_owl"], "25": ["n01629819", "European_fire_salamander"],
                "26": ["n01630670", "common_newt"], "27": ["n01631663", "eft"],
                "28": ["n01632458", "spotted_salamander"], "29": ["n01632777", "axolotl"],
                "30": ["n01641577", "bullfrog"], "31": ["n01644373", "tree_frog"], "32": ["n01644900", "tailed_frog"],
                "33": ["n01664065", "loggerhead"], "34": ["n01665541", "leatherback_turtle"],
                "35": ["n01667114", "mud_turtle"], "36": ["n01667778", "terrapin"], "37": ["n01669191", "box_turtle"],
                "38": ["n01675722", "banded_gecko"], "39": ["n01677366", "common_iguana"],
                "40": ["n01682714", "American_chameleon"], "41": ["n01685808", "whiptail"],
                "42": ["n01687978", "agama"], "43": ["n01688243", "frilled_lizard"],
                "44": ["n01689811", "alligator_lizard"], "45": ["n01692333", "Gila_monster"],
                "46": ["n01693334", "green_lizard"], "47": ["n01694178", "African_chameleon"],
                "48": ["n01695060", "Komodo_dragon"], "49": ["n01697457", "African_crocodile"],
                "50": ["n01698640", "American_alligator"], "51": ["n01704323", "triceratops"],
                "52": ["n01728572", "thunder_snake"], "53": ["n01728920", "ringneck_snake"],
                "54": ["n01729322", "hognose_snake"], "55": ["n01729977", "green_snake"],
                "56": ["n01734418", "king_snake"], "57": ["n01735189", "garter_snake"],
                "58": ["n01737021", "water_snake"], "59": ["n01739381", "vine_snake"],
                "60": ["n01740131", "night_snake"], "61": ["n01742172", "boa_constrictor"],
                "62": ["n01744401", "rock_python"], "63": ["n01748264", "Indian_cobra"],
                "64": ["n01749939", "green_mamba"], "65": ["n01751748", "sea_snake"],
                "66": ["n01753488", "horned_viper"], "67": ["n01755581", "diamondback"],
                "68": ["n01756291", "sidewinder"], "69": ["n01768244", "trilobite"], "70": ["n01770081", "harvestman"],
                "71": ["n01770393", "scorpion"], "72": ["n01773157", "black_and_gold_garden_spider"],
                "73": ["n01773549", "barn_spider"], "74": ["n01773797", "garden_spider"],
                "75": ["n01774384", "black_widow"], "76": ["n01774750", "tarantula"],
                "77": ["n01775062", "wolf_spider"], "78": ["n01776313", "tick"], "79": ["n01784675", "centipede"],
                "80": ["n01795545", "black_grouse"], "81": ["n01796340", "ptarmigan"],
                "82": ["n01797886", "ruffed_grouse"], "83": ["n01798484", "prairie_chicken"],
                "84": ["n01806143", "peacock"], "85": ["n01806567", "quail"], "86": ["n01807496", "partridge"],
                "87": ["n01817953", "African_grey"], "88": ["n01818515", "macaw"],
                "89": ["n01819313", "sulphur-crested_cockatoo"], "90": ["n01820546", "lorikeet"],
                "91": ["n01824575", "coucal"], "92": ["n01828970", "bee_eater"], "93": ["n01829413", "hornbill"],
                "94": ["n01833805", "hummingbird"], "95": ["n01843065", "jacamar"], "96": ["n01843383", "toucan"],
                "97": ["n01847000", "drake"], "98": ["n01855032", "red-breasted_merganser"],
                "99": ["n01855672", "goose"], "100": ["n01860187", "black_swan"], "101": ["n01871265", "tusker"],
                "102": ["n01872401", "echidna"], "103": ["n01873310", "platypus"], "104": ["n01877812", "wallaby"],
                "105": ["n01882714", "koala"], "106": ["n01883070", "wombat"], "107": ["n01910747", "jellyfish"],
                "108": ["n01914609", "sea_anemone"], "109": ["n01917289", "brain_coral"],
                "110": ["n01924916", "flatworm"], "111": ["n01930112", "nematode"], "112": ["n01943899", "conch"],
                "113": ["n01944390", "snail"], "114": ["n01945685", "slug"], "115": ["n01950731", "sea_slug"],
                "116": ["n01955084", "chiton"], "117": ["n01968897", "chambered_nautilus"],
                "118": ["n01978287", "Dungeness_crab"], "119": ["n01978455", "rock_crab"],
                "120": ["n01980166", "fiddler_crab"], "121": ["n01981276", "king_crab"],
                "122": ["n01983481", "American_lobster"], "123": ["n01984695", "spiny_lobster"],
                "124": ["n01985128", "crayfish"], "125": ["n01986214", "hermit_crab"], "126": ["n01990800", "isopod"],
                "127": ["n02002556", "white_stork"], "128": ["n02002724", "black_stork"],
                "129": ["n02006656", "spoonbill"], "130": ["n02007558", "flamingo"],
                "131": ["n02009229", "little_blue_heron"], "132": ["n02009912", "American_egret"],
                "133": ["n02011460", "bittern"], "134": ["n02012849", "crane"], "135": ["n02013706", "limpkin"],
                "136": ["n02017213", "European_gallinule"], "137": ["n02018207", "American_coot"],
                "138": ["n02018795", "bustard"], "139": ["n02025239", "ruddy_turnstone"],
                "140": ["n02027492", "red-backed_sandpiper"], "141": ["n02028035", "redshank"],
                "142": ["n02033041", "dowitcher"], "143": ["n02037110", "oystercatcher"],
                "144": ["n02051845", "pelican"], "145": ["n02056570", "king_penguin"],
                "146": ["n02058221", "albatross"], "147": ["n02066245", "grey_whale"],
                "148": ["n02071294", "killer_whale"], "149": ["n02074367", "dugong"], "150": ["n02077923", "sea_lion"],
                "151": ["n02085620", "Chihuahua"], "152": ["n02085782", "Japanese_spaniel"],
                "153": ["n02085936", "Maltese_dog"], "154": ["n02086079", "Pekinese"], "155": ["n02086240", "Shih-Tzu"],
                "156": ["n02086646", "Blenheim_spaniel"], "157": ["n02086910", "papillon"],
                "158": ["n02087046", "toy_terrier"], "159": ["n02087394", "Rhodesian_ridgeback"],
                "160": ["n02088094", "Afghan_hound"], "161": ["n02088238", "basset"], "162": ["n02088364", "beagle"],
                "163": ["n02088466", "bloodhound"], "164": ["n02088632", "bluetick"],
                "165": ["n02089078", "black-and-tan_coonhound"], "166": ["n02089867", "Walker_hound"],
                "167": ["n02089973", "English_foxhound"], "168": ["n02090379", "redbone"],
                "169": ["n02090622", "borzoi"], "170": ["n02090721", "Irish_wolfhound"],
                "171": ["n02091032", "Italian_greyhound"], "172": ["n02091134", "whippet"],
                "173": ["n02091244", "Ibizan_hound"], "174": ["n02091467", "Norwegian_elkhound"],
                "175": ["n02091635", "otterhound"], "176": ["n02091831", "Saluki"],
                "177": ["n02092002", "Scottish_deerhound"], "178": ["n02092339", "Weimaraner"],
                "179": ["n02093256", "Staffordshire_bullterrier"],
                "180": ["n02093428", "American_Staffordshire_terrier"], "181": ["n02093647", "Bedlington_terrier"],
                "182": ["n02093754", "Border_terrier"], "183": ["n02093859", "Kerry_blue_terrier"],
                "184": ["n02093991", "Irish_terrier"], "185": ["n02094114", "Norfolk_terrier"],
                "186": ["n02094258", "Norwich_terrier"], "187": ["n02094433", "Yorkshire_terrier"],
                "188": ["n02095314", "wire-haired_fox_terrier"], "189": ["n02095570", "Lakeland_terrier"],
                "190": ["n02095889", "Sealyham_terrier"], "191": ["n02096051", "Airedale"],
                "192": ["n02096177", "cairn"], "193": ["n02096294", "Australian_terrier"],
                "194": ["n02096437", "Dandie_Dinmont"], "195": ["n02096585", "Boston_bull"],
                "196": ["n02097047", "miniature_schnauzer"], "197": ["n02097130", "giant_schnauzer"],
                "198": ["n02097209", "standard_schnauzer"], "199": ["n02097298", "Scotch_terrier"],
                "200": ["n02097474", "Tibetan_terrier"], "201": ["n02097658", "silky_terrier"],
                "202": ["n02098105", "soft-coated_wheaten_terrier"],
                "203": ["n02098286", "West_Highland_white_terrier"], "204": ["n02098413", "Lhasa"],
                "205": ["n02099267", "flat-coated_retriever"], "206": ["n02099429", "curly-coated_retriever"],
                "207": ["n02099601", "golden_retriever"], "208": ["n02099712", "Labrador_retriever"],
                "209": ["n02099849", "Chesapeake_Bay_retriever"], "210": ["n02100236", "German_short-haired_pointer"],
                "211": ["n02100583", "vizsla"], "212": ["n02100735", "English_setter"],
                "213": ["n02100877", "Irish_setter"], "214": ["n02101006", "Gordon_setter"],
                "215": ["n02101388", "Brittany_spaniel"], "216": ["n02101556", "clumber"],
                "217": ["n02102040", "English_springer"], "218": ["n02102177", "Welsh_springer_spaniel"],
                "219": ["n02102318", "cocker_spaniel"], "220": ["n02102480", "Sussex_spaniel"],
                "221": ["n02102973", "Irish_water_spaniel"], "222": ["n02104029", "kuvasz"],
                "223": ["n02104365", "schipperke"], "224": ["n02105056", "groenendael"],
                "225": ["n02105162", "malinois"], "226": ["n02105251", "briard"], "227": ["n02105412", "kelpie"],
                "228": ["n02105505", "komondor"], "229": ["n02105641", "Old_English_sheepdog"],
                "230": ["n02105855", "Shetland_sheepdog"], "231": ["n02106030", "collie"],
                "232": ["n02106166", "Border_collie"], "233": ["n02106382", "Bouvier_des_Flandres"],
                "234": ["n02106550", "Rottweiler"], "235": ["n02106662", "German_shepherd"],
                "236": ["n02107142", "Doberman"], "237": ["n02107312", "miniature_pinscher"],
                "238": ["n02107574", "Greater_Swiss_Mountain_dog"], "239": ["n02107683", "Bernese_mountain_dog"],
                "240": ["n02107908", "Appenzeller"], "241": ["n02108000", "EntleBucher"], "242": ["n02108089", "boxer"],
                "243": ["n02108422", "bull_mastiff"], "244": ["n02108551", "Tibetan_mastiff"],
                "245": ["n02108915", "French_bulldog"], "246": ["n02109047", "Great_Dane"],
                "247": ["n02109525", "Saint_Bernard"], "248": ["n02109961", "Eskimo_dog"],
                "249": ["n02110063", "malamute"], "250": ["n02110185", "Siberian_husky"],
                "251": ["n02110341", "dalmatian"], "252": ["n02110627", "affenpinscher"],
                "253": ["n02110806", "basenji"], "254": ["n02110958", "pug"], "255": ["n02111129", "Leonberg"],
                "256": ["n02111277", "Newfoundland"], "257": ["n02111500", "Great_Pyrenees"],
                "258": ["n02111889", "Samoyed"], "259": ["n02112018", "Pomeranian"], "260": ["n02112137", "chow"],
                "261": ["n02112350", "keeshond"], "262": ["n02112706", "Brabancon_griffon"],
                "263": ["n02113023", "Pembroke"], "264": ["n02113186", "Cardigan"], "265": ["n02113624", "toy_poodle"],
                "266": ["n02113712", "miniature_poodle"], "267": ["n02113799", "standard_poodle"],
                "268": ["n02113978", "Mexican_hairless"], "269": ["n02114367", "timber_wolf"],
                "270": ["n02114548", "white_wolf"], "271": ["n02114712", "red_wolf"], "272": ["n02114855", "coyote"],
                "273": ["n02115641", "dingo"], "274": ["n02115913", "dhole"],
                "275": ["n02116738", "African_hunting_dog"], "276": ["n02117135", "hyena"],
                "277": ["n02119022", "red_fox"], "278": ["n02119789", "kit_fox"], "279": ["n02120079", "Arctic_fox"],
                "280": ["n02120505", "grey_fox"], "281": ["n02123045", "tabby"], "282": ["n02123159", "tiger_cat"],
                "283": ["n02123394", "Persian_cat"], "284": ["n02123597", "Siamese_cat"],
                "285": ["n02124075", "Egyptian_cat"], "286": ["n02125311", "cougar"], "287": ["n02127052", "lynx"],
                "288": ["n02128385", "leopard"], "289": ["n02128757", "snow_leopard"], "290": ["n02128925", "jaguar"],
                "291": ["n02129165", "lion"], "292": ["n02129604", "tiger"], "293": ["n02130308", "cheetah"],
                "294": ["n02132136", "brown_bear"], "295": ["n02133161", "American_black_bear"],
                "296": ["n02134084", "ice_bear"], "297": ["n02134418", "sloth_bear"], "298": ["n02137549", "mongoose"],
                "299": ["n02138441", "meerkat"], "300": ["n02165105", "tiger_beetle"], "301": ["n02165456", "ladybug"],
                "302": ["n02167151", "ground_beetle"], "303": ["n02168699", "long-horned_beetle"],
                "304": ["n02169497", "leaf_beetle"], "305": ["n02172182", "dung_beetle"],
                "306": ["n02174001", "rhinoceros_beetle"], "307": ["n02177972", "weevil"], "308": ["n02190166", "fly"],
                "309": ["n02206856", "bee"], "310": ["n02219486", "ant"], "311": ["n02226429", "grasshopper"],
                "312": ["n02229544", "cricket"], "313": ["n02231487", "walking_stick"],
                "314": ["n02233338", "cockroach"], "315": ["n02236044", "mantis"], "316": ["n02256656", "cicada"],
                "317": ["n02259212", "leafhopper"], "318": ["n02264363", "lacewing"], "319": ["n02268443", "dragonfly"],
                "320": ["n02268853", "damselfly"], "321": ["n02276258", "admiral"], "322": ["n02277742", "ringlet"],
                "323": ["n02279972", "monarch"], "324": ["n02280649", "cabbage_butterfly"],
                "325": ["n02281406", "sulphur_butterfly"], "326": ["n02281787", "lycaenid"],
                "327": ["n02317335", "starfish"], "328": ["n02319095", "sea_urchin"],
                "329": ["n02321529", "sea_cucumber"], "330": ["n02325366", "wood_rabbit"], "331": ["n02326432", "hare"],
                "332": ["n02328150", "Angora"], "333": ["n02342885", "hamster"], "334": ["n02346627", "porcupine"],
                "335": ["n02356798", "fox_squirrel"], "336": ["n02361337", "marmot"], "337": ["n02363005", "beaver"],
                "338": ["n02364673", "guinea_pig"], "339": ["n02389026", "sorrel"], "340": ["n02391049", "zebra"],
                "341": ["n02395406", "hog"], "342": ["n02396427", "wild_boar"], "343": ["n02397096", "warthog"],
                "344": ["n02398521", "hippopotamus"], "345": ["n02403003", "ox"], "346": ["n02408429", "water_buffalo"],
                "347": ["n02410509", "bison"], "348": ["n02412080", "ram"], "349": ["n02415577", "bighorn"],
                "350": ["n02417914", "ibex"], "351": ["n02422106", "hartebeest"], "352": ["n02422699", "impala"],
                "353": ["n02423022", "gazelle"], "354": ["n02437312", "Arabian_camel"], "355": ["n02437616", "llama"],
                "356": ["n02441942", "weasel"], "357": ["n02442845", "mink"], "358": ["n02443114", "polecat"],
                "359": ["n02443484", "black-footed_ferret"], "360": ["n02444819", "otter"],
                "361": ["n02445715", "skunk"], "362": ["n02447366", "badger"], "363": ["n02454379", "armadillo"],
                "364": ["n02457408", "three-toed_sloth"], "365": ["n02480495", "orangutan"],
                "366": ["n02480855", "gorilla"], "367": ["n02481823", "chimpanzee"], "368": ["n02483362", "gibbon"],
                "369": ["n02483708", "siamang"], "370": ["n02484975", "guenon"], "371": ["n02486261", "patas"],
                "372": ["n02486410", "baboon"], "373": ["n02487347", "macaque"], "374": ["n02488291", "langur"],
                "375": ["n02488702", "colobus"], "376": ["n02489166", "proboscis_monkey"],
                "377": ["n02490219", "marmoset"], "378": ["n02492035", "capuchin"],
                "379": ["n02492660", "howler_monkey"], "380": ["n02493509", "titi"],
                "381": ["n02493793", "spider_monkey"], "382": ["n02494079", "squirrel_monkey"],
                "383": ["n02497673", "Madagascar_cat"], "384": ["n02500267", "indri"],
                "385": ["n02504013", "Indian_elephant"], "386": ["n02504458", "African_elephant"],
                "387": ["n02509815", "lesser_panda"], "388": ["n02510455", "giant_panda"],
                "389": ["n02514041", "barracouta"], "390": ["n02526121", "eel"], "391": ["n02536864", "coho"],
                "392": ["n02606052", "rock_beauty"], "393": ["n02607072", "anemone_fish"],
                "394": ["n02640242", "sturgeon"], "395": ["n02641379", "gar"], "396": ["n02643566", "lionfish"],
                "397": ["n02655020", "puffer"], "398": ["n02666196", "abacus"], "399": ["n02667093", "abaya"],
                "400": ["n02669723", "academic_gown"], "401": ["n02672831", "accordion"],
                "402": ["n02676566", "acoustic_guitar"], "403": ["n02687172", "aircraft_carrier"],
                "404": ["n02690373", "airliner"], "405": ["n02692877", "airship"], "406": ["n02699494", "altar"],
                "407": ["n02701002", "ambulance"], "408": ["n02704792", "amphibian"],
                "409": ["n02708093", "analog_clock"], "410": ["n02727426", "apiary"], "411": ["n02730930", "apron"],
                "412": ["n02747177", "ashcan"], "413": ["n02749479", "assault_rifle"], "414": ["n02769748", "backpack"],
                "415": ["n02776631", "bakery"], "416": ["n02777292", "balance_beam"], "417": ["n02782093", "balloon"],
                "418": ["n02783161", "ballpoint"], "419": ["n02786058", "Band_Aid"], "420": ["n02787622", "banjo"],
                "421": ["n02788148", "bannister"], "422": ["n02790996", "barbell"],
                "423": ["n02791124", "barber_chair"], "424": ["n02791270", "barbershop"], "425": ["n02793495", "barn"],
                "426": ["n02794156", "barometer"], "427": ["n02795169", "barrel"], "428": ["n02797295", "barrow"],
                "429": ["n02799071", "baseball"], "430": ["n02802426", "basketball"], "431": ["n02804414", "bassinet"],
                "432": ["n02804610", "bassoon"], "433": ["n02807133", "bathing_cap"],
                "434": ["n02808304", "bath_towel"], "435": ["n02808440", "bathtub"],
                "436": ["n02814533", "beach_wagon"], "437": ["n02814860", "beacon"], "438": ["n02815834", "beaker"],
                "439": ["n02817516", "bearskin"], "440": ["n02823428", "beer_bottle"],
                "441": ["n02823750", "beer_glass"], "442": ["n02825657", "bell_cote"], "443": ["n02834397", "bib"],
                "444": ["n02835271", "bicycle-built-for-two"], "445": ["n02837789", "bikini"],
                "446": ["n02840245", "binder"], "447": ["n02841315", "binoculars"], "448": ["n02843684", "birdhouse"],
                "449": ["n02859443", "boathouse"], "450": ["n02860847", "bobsled"], "451": ["n02865351", "bolo_tie"],
                "452": ["n02869837", "bonnet"], "453": ["n02870880", "bookcase"], "454": ["n02871525", "bookshop"],
                "455": ["n02877765", "bottlecap"], "456": ["n02879718", "bow"], "457": ["n02883205", "bow_tie"],
                "458": ["n02892201", "brass"], "459": ["n02892767", "brassiere"], "460": ["n02894605", "breakwater"],
                "461": ["n02895154", "breastplate"], "462": ["n02906734", "broom"], "463": ["n02909870", "bucket"],
                "464": ["n02910353", "buckle"], "465": ["n02916936", "bulletproof_vest"],
                "466": ["n02917067", "bullet_train"], "467": ["n02927161", "butcher_shop"], "468": ["n02930766", "cab"],
                "469": ["n02939185", "caldron"], "470": ["n02948072", "candle"], "471": ["n02950826", "cannon"],
                "472": ["n02951358", "canoe"], "473": ["n02951585", "can_opener"], "474": ["n02963159", "cardigan"],
                "475": ["n02965783", "car_mirror"], "476": ["n02966193", "carousel"],
                "477": ["n02966687", "carpenter's_kit"], "478": ["n02971356", "carton"],
                "479": ["n02974003", "car_wheel"], "480": ["n02977058", "cash_machine"],
                "481": ["n02978881", "cassette"], "482": ["n02979186", "cassette_player"],
                "483": ["n02980441", "castle"], "484": ["n02981792", "catamaran"], "485": ["n02988304", "CD_player"],
                "486": ["n02992211", "cello"], "487": ["n02992529", "cellular_telephone"],
                "488": ["n02999410", "chain"], "489": ["n03000134", "chainlink_fence"],
                "490": ["n03000247", "chain_mail"], "491": ["n03000684", "chain_saw"], "492": ["n03014705", "chest"],
                "493": ["n03016953", "chiffonier"], "494": ["n03017168", "chime"],
                "495": ["n03018349", "china_cabinet"], "496": ["n03026506", "Christmas_stocking"],
                "497": ["n03028079", "church"], "498": ["n03032252", "cinema"], "499": ["n03041632", "cleaver"],
                "500": ["n03042490", "cliff_dwelling"], "501": ["n03045698", "cloak"], "502": ["n03047690", "clog"],
                "503": ["n03062245", "cocktail_shaker"], "504": ["n03063599", "coffee_mug"],
                "505": ["n03063689", "coffeepot"], "506": ["n03065424", "coil"],
                "507": ["n03075370", "combination_lock"], "508": ["n03085013", "computer_keyboard"],
                "509": ["n03089624", "confectionery"], "510": ["n03095699", "container_ship"],
                "511": ["n03100240", "convertible"], "512": ["n03109150", "corkscrew"], "513": ["n03110669", "cornet"],
                "514": ["n03124043", "cowboy_boot"], "515": ["n03124170", "cowboy_hat"], "516": ["n03125729", "cradle"],
                "517": ["n03126707", "crane"], "518": ["n03127747", "crash_helmet"], "519": ["n03127925", "crate"],
                "520": ["n03131574", "crib"], "521": ["n03133878", "Crock_Pot"], "522": ["n03134739", "croquet_ball"],
                "523": ["n03141823", "crutch"], "524": ["n03146219", "cuirass"], "525": ["n03160309", "dam"],
                "526": ["n03179701", "desk"], "527": ["n03180011", "desktop_computer"],
                "528": ["n03187595", "dial_telephone"], "529": ["n03188531", "diaper"],
                "530": ["n03196217", "digital_clock"], "531": ["n03197337", "digital_watch"],
                "532": ["n03201208", "dining_table"], "533": ["n03207743", "dishrag"],
                "534": ["n03207941", "dishwasher"], "535": ["n03208938", "disk_brake"], "536": ["n03216828", "dock"],
                "537": ["n03218198", "dogsled"], "538": ["n03220513", "dome"], "539": ["n03223299", "doormat"],
                "540": ["n03240683", "drilling_platform"], "541": ["n03249569", "drum"],
                "542": ["n03250847", "drumstick"], "543": ["n03255030", "dumbbell"], "544": ["n03259280", "Dutch_oven"],
                "545": ["n03271574", "electric_fan"], "546": ["n03272010", "electric_guitar"],
                "547": ["n03272562", "electric_locomotive"], "548": ["n03290653", "entertainment_center"],
                "549": ["n03291819", "envelope"], "550": ["n03297495", "espresso_maker"],
                "551": ["n03314780", "face_powder"], "552": ["n03325584", "feather_boa"], "553": ["n03337140", "file"],
                "554": ["n03344393", "fireboat"], "555": ["n03345487", "fire_engine"],
                "556": ["n03347037", "fire_screen"], "557": ["n03355925", "flagpole"], "558": ["n03372029", "flute"],
                "559": ["n03376595", "folding_chair"], "560": ["n03379051", "football_helmet"],
                "561": ["n03384352", "forklift"], "562": ["n03388043", "fountain"],
                "563": ["n03388183", "fountain_pen"], "564": ["n03388549", "four-poster"],
                "565": ["n03393912", "freight_car"], "566": ["n03394916", "French_horn"],
                "567": ["n03400231", "frying_pan"], "568": ["n03404251", "fur_coat"],
                "569": ["n03417042", "garbage_truck"], "570": ["n03424325", "gasmask"],
                "571": ["n03425413", "gas_pump"], "572": ["n03443371", "goblet"], "573": ["n03444034", "go-kart"],
                "574": ["n03445777", "golf_ball"], "575": ["n03445924", "golfcart"], "576": ["n03447447", "gondola"],
                "577": ["n03447721", "gong"], "578": ["n03450230", "gown"], "579": ["n03452741", "grand_piano"],
                "580": ["n03457902", "greenhouse"], "581": ["n03459775", "grille"],
                "582": ["n03461385", "grocery_store"], "583": ["n03467068", "guillotine"],
                "584": ["n03476684", "hair_slide"], "585": ["n03476991", "hair_spray"],
                "586": ["n03478589", "half_track"], "587": ["n03481172", "hammer"], "588": ["n03482405", "hamper"],
                "589": ["n03483316", "hand_blower"], "590": ["n03485407", "hand-held_computer"],
                "591": ["n03485794", "handkerchief"], "592": ["n03492542", "hard_disc"],
                "593": ["n03494278", "harmonica"], "594": ["n03495258", "harp"], "595": ["n03496892", "harvester"],
                "596": ["n03498962", "hatchet"], "597": ["n03527444", "holster"], "598": ["n03529860", "home_theater"],
                "599": ["n03530642", "honeycomb"], "600": ["n03532672", "hook"], "601": ["n03534580", "hoopskirt"],
                "602": ["n03535780", "horizontal_bar"], "603": ["n03538406", "horse_cart"],
                "604": ["n03544143", "hourglass"], "605": ["n03584254", "iPod"], "606": ["n03584829", "iron"],
                "607": ["n03590841", "jack-o'-lantern"], "608": ["n03594734", "jean"], "609": ["n03594945", "jeep"],
                "610": ["n03595614", "jersey"], "611": ["n03598930", "jigsaw_puzzle"],
                "612": ["n03599486", "jinrikisha"], "613": ["n03602883", "joystick"], "614": ["n03617480", "kimono"],
                "615": ["n03623198", "knee_pad"], "616": ["n03627232", "knot"], "617": ["n03630383", "lab_coat"],
                "618": ["n03633091", "ladle"], "619": ["n03637318", "lampshade"], "620": ["n03642806", "laptop"],
                "621": ["n03649909", "lawn_mower"], "622": ["n03657121", "lens_cap"],
                "623": ["n03658185", "letter_opener"], "624": ["n03661043", "library"],
                "625": ["n03662601", "lifeboat"], "626": ["n03666591", "lighter"], "627": ["n03670208", "limousine"],
                "628": ["n03673027", "liner"], "629": ["n03676483", "lipstick"], "630": ["n03680355", "Loafer"],
                "631": ["n03690938", "lotion"], "632": ["n03691459", "loudspeaker"], "633": ["n03692522", "loupe"],
                "634": ["n03697007", "lumbermill"], "635": ["n03706229", "magnetic_compass"],
                "636": ["n03709823", "mailbag"], "637": ["n03710193", "mailbox"], "638": ["n03710637", "maillot"],
                "639": ["n03710721", "maillot"], "640": ["n03717622", "manhole_cover"], "641": ["n03720891", "maraca"],
                "642": ["n03721384", "marimba"], "643": ["n03724870", "mask"], "644": ["n03729826", "matchstick"],
                "645": ["n03733131", "maypole"], "646": ["n03733281", "maze"], "647": ["n03733805", "measuring_cup"],
                "648": ["n03742115", "medicine_chest"], "649": ["n03743016", "megalith"],
                "650": ["n03759954", "microphone"], "651": ["n03761084", "microwave"],
                "652": ["n03763968", "military_uniform"], "653": ["n03764736", "milk_can"],
                "654": ["n03769881", "minibus"], "655": ["n03770439", "miniskirt"], "656": ["n03770679", "minivan"],
                "657": ["n03773504", "missile"], "658": ["n03775071", "mitten"], "659": ["n03775546", "mixing_bowl"],
                "660": ["n03776460", "mobile_home"], "661": ["n03777568", "Model_T"], "662": ["n03777754", "modem"],
                "663": ["n03781244", "monastery"], "664": ["n03782006", "monitor"], "665": ["n03785016", "moped"],
                "666": ["n03786901", "mortar"], "667": ["n03787032", "mortarboard"], "668": ["n03788195", "mosque"],
                "669": ["n03788365", "mosquito_net"], "670": ["n03791053", "motor_scooter"],
                "671": ["n03792782", "mountain_bike"], "672": ["n03792972", "mountain_tent"],
                "673": ["n03793489", "mouse"], "674": ["n03794056", "mousetrap"], "675": ["n03796401", "moving_van"],
                "676": ["n03803284", "muzzle"], "677": ["n03804744", "nail"], "678": ["n03814639", "neck_brace"],
                "679": ["n03814906", "necklace"], "680": ["n03825788", "nipple"], "681": ["n03832673", "notebook"],
                "682": ["n03837869", "obelisk"], "683": ["n03838899", "oboe"], "684": ["n03840681", "ocarina"],
                "685": ["n03841143", "odometer"], "686": ["n03843555", "oil_filter"], "687": ["n03854065", "organ"],
                "688": ["n03857828", "oscilloscope"], "689": ["n03866082", "overskirt"], "690": ["n03868242", "oxcart"],
                "691": ["n03868863", "oxygen_mask"], "692": ["n03871628", "packet"], "693": ["n03873416", "paddle"],
                "694": ["n03874293", "paddlewheel"], "695": ["n03874599", "padlock"],
                "696": ["n03876231", "paintbrush"], "697": ["n03877472", "pajama"], "698": ["n03877845", "palace"],
                "699": ["n03884397", "panpipe"], "700": ["n03887697", "paper_towel"], "701": ["n03888257", "parachute"],
                "702": ["n03888605", "parallel_bars"], "703": ["n03891251", "park_bench"],
                "704": ["n03891332", "parking_meter"], "705": ["n03895866", "passenger_car"],
                "706": ["n03899768", "patio"], "707": ["n03902125", "pay-phone"], "708": ["n03903868", "pedestal"],
                "709": ["n03908618", "pencil_box"], "710": ["n03908714", "pencil_sharpener"],
                "711": ["n03916031", "perfume"], "712": ["n03920288", "Petri_dish"],
                "713": ["n03924679", "photocopier"], "714": ["n03929660", "pick"], "715": ["n03929855", "pickelhaube"],
                "716": ["n03930313", "picket_fence"], "717": ["n03930630", "pickup"], "718": ["n03933933", "pier"],
                "719": ["n03935335", "piggy_bank"], "720": ["n03937543", "pill_bottle"], "721": ["n03938244", "pillow"],
                "722": ["n03942813", "ping-pong_ball"], "723": ["n03944341", "pinwheel"],
                "724": ["n03947888", "pirate"], "725": ["n03950228", "pitcher"], "726": ["n03954731", "plane"],
                "727": ["n03956157", "planetarium"], "728": ["n03958227", "plastic_bag"],
                "729": ["n03961711", "plate_rack"], "730": ["n03967562", "plow"], "731": ["n03970156", "plunger"],
                "732": ["n03976467", "Polaroid_camera"], "733": ["n03976657", "pole"],
                "734": ["n03977966", "police_van"], "735": ["n03980874", "poncho"], "736": ["n03982430", "pool_table"],
                "737": ["n03983396", "pop_bottle"], "738": ["n03991062", "pot"], "739": ["n03992509", "potter's_wheel"],
                "740": ["n03995372", "power_drill"], "741": ["n03998194", "prayer_rug"],
                "742": ["n04004767", "printer"], "743": ["n04005630", "prison"], "744": ["n04008634", "projectile"],
                "745": ["n04009552", "projector"], "746": ["n04019541", "puck"], "747": ["n04023962", "punching_bag"],
                "748": ["n04026417", "purse"], "749": ["n04033901", "quill"], "750": ["n04033995", "quilt"],
                "751": ["n04037443", "racer"], "752": ["n04039381", "racket"], "753": ["n04040759", "radiator"],
                "754": ["n04041544", "radio"], "755": ["n04044716", "radio_telescope"],
                "756": ["n04049303", "rain_barrel"], "757": ["n04065272", "recreational_vehicle"],
                "758": ["n04067472", "reel"], "759": ["n04069434", "reflex_camera"],
                "760": ["n04070727", "refrigerator"], "761": ["n04074963", "remote_control"],
                "762": ["n04081281", "restaurant"], "763": ["n04086273", "revolver"], "764": ["n04090263", "rifle"],
                "765": ["n04099969", "rocking_chair"], "766": ["n04111531", "rotisserie"],
                "767": ["n04116512", "rubber_eraser"], "768": ["n04118538", "rugby_ball"], "769": ["n04118776", "rule"],
                "770": ["n04120489", "running_shoe"], "771": ["n04125021", "safe"], "772": ["n04127249", "safety_pin"],
                "773": ["n04131690", "saltshaker"], "774": ["n04133789", "sandal"], "775": ["n04136333", "sarong"],
                "776": ["n04141076", "sax"], "777": ["n04141327", "scabbard"], "778": ["n04141975", "scale"],
                "779": ["n04146614", "school_bus"], "780": ["n04147183", "schooner"],
                "781": ["n04149813", "scoreboard"], "782": ["n04152593", "screen"], "783": ["n04153751", "screw"],
                "784": ["n04154565", "screwdriver"], "785": ["n04162706", "seat_belt"],
                "786": ["n04179913", "sewing_machine"], "787": ["n04192698", "shield"],
                "788": ["n04200800", "shoe_shop"], "789": ["n04201297", "shoji"],
                "790": ["n04204238", "shopping_basket"], "791": ["n04204347", "shopping_cart"],
                "792": ["n04208210", "shovel"], "793": ["n04209133", "shower_cap"],
                "794": ["n04209239", "shower_curtain"], "795": ["n04228054", "ski"], "796": ["n04229816", "ski_mask"],
                "797": ["n04235860", "sleeping_bag"], "798": ["n04238763", "slide_rule"],
                "799": ["n04239074", "sliding_door"], "800": ["n04243546", "slot"], "801": ["n04251144", "snorkel"],
                "802": ["n04252077", "snowmobile"], "803": ["n04252225", "snowplow"],
                "804": ["n04254120", "soap_dispenser"], "805": ["n04254680", "soccer_ball"],
                "806": ["n04254777", "sock"], "807": ["n04258138", "solar_dish"], "808": ["n04259630", "sombrero"],
                "809": ["n04263257", "soup_bowl"], "810": ["n04264628", "space_bar"],
                "811": ["n04265275", "space_heater"], "812": ["n04266014", "space_shuttle"],
                "813": ["n04270147", "spatula"], "814": ["n04273569", "speedboat"], "815": ["n04275548", "spider_web"],
                "816": ["n04277352", "spindle"], "817": ["n04285008", "sports_car"], "818": ["n04286575", "spotlight"],
                "819": ["n04296562", "stage"], "820": ["n04310018", "steam_locomotive"],
                "821": ["n04311004", "steel_arch_bridge"], "822": ["n04311174", "steel_drum"],
                "823": ["n04317175", "stethoscope"], "824": ["n04325704", "stole"], "825": ["n04326547", "stone_wall"],
                "826": ["n04328186", "stopwatch"], "827": ["n04330267", "stove"], "828": ["n04332243", "strainer"],
                "829": ["n04335435", "streetcar"], "830": ["n04336792", "stretcher"],
                "831": ["n04344873", "studio_couch"], "832": ["n04346328", "stupa"], "833": ["n04347754", "submarine"],
                "834": ["n04350905", "suit"], "835": ["n04355338", "sundial"], "836": ["n04355933", "sunglass"],
                "837": ["n04356056", "sunglasses"], "838": ["n04357314", "sunscreen"],
                "839": ["n04366367", "suspension_bridge"], "840": ["n04367480", "swab"],
                "841": ["n04370456", "sweatshirt"], "842": ["n04371430", "swimming_trunks"],
                "843": ["n04371774", "swing"], "844": ["n04372370", "switch"], "845": ["n04376876", "syringe"],
                "846": ["n04380533", "table_lamp"], "847": ["n04389033", "tank"], "848": ["n04392985", "tape_player"],
                "849": ["n04398044", "teapot"], "850": ["n04399382", "teddy"], "851": ["n04404412", "television"],
                "852": ["n04409515", "tennis_ball"], "853": ["n04417672", "thatch"],
                "854": ["n04418357", "theater_curtain"], "855": ["n04423845", "thimble"],
                "856": ["n04428191", "thresher"], "857": ["n04429376", "throne"], "858": ["n04435653", "tile_roof"],
                "859": ["n04442312", "toaster"], "860": ["n04443257", "tobacco_shop"],
                "861": ["n04447861", "toilet_seat"], "862": ["n04456115", "torch"], "863": ["n04458633", "totem_pole"],
                "864": ["n04461696", "tow_truck"], "865": ["n04462240", "toyshop"], "866": ["n04465501", "tractor"],
                "867": ["n04467665", "trailer_truck"], "868": ["n04476259", "tray"],
                "869": ["n04479046", "trench_coat"], "870": ["n04482393", "tricycle"], "871": ["n04483307", "trimaran"],
                "872": ["n04485082", "tripod"], "873": ["n04486054", "triumphal_arch"],
                "874": ["n04487081", "trolleybus"], "875": ["n04487394", "trombone"], "876": ["n04493381", "tub"],
                "877": ["n04501370", "turnstile"], "878": ["n04505470", "typewriter_keyboard"],
                "879": ["n04507155", "umbrella"], "880": ["n04509417", "unicycle"], "881": ["n04515003", "upright"],
                "882": ["n04517823", "vacuum"], "883": ["n04522168", "vase"], "884": ["n04523525", "vault"],
                "885": ["n04525038", "velvet"], "886": ["n04525305", "vending_machine"],
                "887": ["n04532106", "vestment"], "888": ["n04532670", "viaduct"], "889": ["n04536866", "violin"],
                "890": ["n04540053", "volleyball"], "891": ["n04542943", "waffle_iron"],
                "892": ["n04548280", "wall_clock"], "893": ["n04548362", "wallet"], "894": ["n04550184", "wardrobe"],
                "895": ["n04552348", "warplane"], "896": ["n04553703", "washbasin"], "897": ["n04554684", "washer"],
                "898": ["n04557648", "water_bottle"], "899": ["n04560804", "water_jug"],
                "900": ["n04562935", "water_tower"], "901": ["n04579145", "whiskey_jug"],
                "902": ["n04579432", "whistle"], "903": ["n04584207", "wig"], "904": ["n04589890", "window_screen"],
                "905": ["n04590129", "window_shade"], "906": ["n04591157", "Windsor_tie"],
                "907": ["n04591713", "wine_bottle"], "908": ["n04592741", "wing"], "909": ["n04596742", "wok"],
                "910": ["n04597913", "wooden_spoon"], "911": ["n04599235", "wool"], "912": ["n04604644", "worm_fence"],
                "913": ["n04606251", "wreck"], "914": ["n04612504", "yawl"], "915": ["n04613696", "yurt"],
                "916": ["n06359193", "web_site"], "917": ["n06596364", "comic_book"],
                "918": ["n06785654", "crossword_puzzle"], "919": ["n06794110", "street_sign"],
                "920": ["n06874185", "traffic_light"], "921": ["n07248320", "book_jacket"],
                "922": ["n07565083", "menu"], "923": ["n07579787", "plate"], "924": ["n07583066", "guacamole"],
                "925": ["n07584110", "consomme"], "926": ["n07590611", "hot_pot"], "927": ["n07613480", "trifle"],
                "928": ["n07614500", "ice_cream"], "929": ["n07615774", "ice_lolly"],
                "930": ["n07684084", "French_loaf"], "931": ["n07693725", "bagel"], "932": ["n07695742", "pretzel"],
                "933": ["n07697313", "cheeseburger"], "934": ["n07697537", "hotdog"],
                "935": ["n07711569", "mashed_potato"], "936": ["n07714571", "head_cabbage"],
                "937": ["n07714990", "broccoli"], "938": ["n07715103", "cauliflower"], "939": ["n07716358", "zucchini"],
                "940": ["n07716906", "spaghetti_squash"], "941": ["n07717410", "acorn_squash"],
                "942": ["n07717556", "butternut_squash"], "943": ["n07718472", "cucumber"],
                "944": ["n07718747", "artichoke"], "945": ["n07720875", "bell_pepper"], "946": ["n07730033", "cardoon"],
                "947": ["n07734744", "mushroom"], "948": ["n07742313", "Granny_Smith"],
                "949": ["n07745940", "strawberry"], "950": ["n07747607", "orange"], "951": ["n07749582", "lemon"],
                "952": ["n07753113", "fig"], "953": ["n07753275", "pineapple"], "954": ["n07753592", "banana"],
                "955": ["n07754684", "jackfruit"], "956": ["n07760859", "custard_apple"],
                "957": ["n07768694", "pomegranate"], "958": ["n07802026", "hay"], "959": ["n07831146", "carbonara"],
                "960": ["n07836838", "chocolate_sauce"], "961": ["n07860988", "dough"],
                "962": ["n07871810", "meat_loaf"], "963": ["n07873807", "pizza"], "964": ["n07875152", "potpie"],
                "965": ["n07880968", "burrito"], "966": ["n07892512", "red_wine"], "967": ["n07920052", "espresso"],
                "968": ["n07930864", "cup"], "969": ["n07932039", "eggnog"], "970": ["n09193705", "alp"],
                "971": ["n09229709", "bubble"], "972": ["n09246464", "cliff"], "973": ["n09256479", "coral_reef"],
                "974": ["n09288635", "geyser"], "975": ["n09332890", "lakeside"], "976": ["n09399592", "promontory"],
                "977": ["n09421951", "sandbar"], "978": ["n09428293", "seashore"], "979": ["n09468604", "valley"],
                "980": ["n09472597", "volcano"], "981": ["n09835506", "ballplayer"], "982": ["n10148035", "groom"],
                "983": ["n10565667", "scuba_diver"], "984": ["n11879895", "rapeseed"], "985": ["n11939491", "daisy"],
                "986": ["n12057211", "yellow_lady's_slipper"], "987": ["n12144580", "corn"],
                "988": ["n12267677", "acorn"], "989": ["n12620546", "hip"], "990": ["n12768682", "buckeye"],
                "991": ["n12985857", "coral_fungus"], "992": ["n12998815", "agaric"], "993": ["n13037406", "gyromitra"],
                "994": ["n13040303", "stinkhorn"], "995": ["n13044778", "earthstar"],
                "996": ["n13052670", "hen-of-the-woods"], "997": ["n13054560", "bolete"], "998": ["n13133613", "ear"],
                "999": ["n15075141", "toilet_tissue"]}
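A quick lookup sketch for the mapping above (each value is [WordNet synset ID, human-readable label]; note that a few labels, e.g. "maillot" and "crane", repeat, so the reverse map below keeps only the last index for such names):

    from datasets.imagenet_dic import IMAGENET_DIC

    synset, label = IMAGENET_DIC["207"]      # keys are string class indices
    print(synset, label)                     # n02099601 golden_retriever

    label_to_idx = {v[1]: int(k) for k, v in IMAGENET_DIC.items()}
    print(label_to_idx["golden_retriever"])  # 207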
imgs/img1.jpg
ADDED
![]() |
losses/clip_loss.py
ADDED
@@ -0,0 +1,299 @@
import torch
import torchvision.transforms as transforms
import numpy as np

import clip
from PIL import Image

from utils.text_templates import imagenet_templates, part_templates, imagenet_templates_small


class DirectionLoss(torch.nn.Module):

    def __init__(self, loss_type='mse'):
        super(DirectionLoss, self).__init__()

        self.loss_type = loss_type

        self.loss_func = {
            'mse': torch.nn.MSELoss,
            'cosine': torch.nn.CosineSimilarity,
            'mae': torch.nn.L1Loss
        }[loss_type]()

    def forward(self, x, y):
        if self.loss_type == "cosine":
            return 1. - self.loss_func(x, y)

        return self.loss_func(x, y)


class CLIPLoss(torch.nn.Module):
    def __init__(self, device, lambda_direction=1., lambda_patch=0., lambda_global=0., lambda_manifold=0., lambda_texture=0., patch_loss_type='mae', direction_loss_type='cosine', clip_model='ViT-B/16'):
        super(CLIPLoss, self).__init__()

        self.device = device
        self.model, clip_preprocess = clip.load(clip_model, device=self.device)

        self.clip_preprocess = clip_preprocess

        self.preprocess = transforms.Compose([transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])] +  # un-normalize from [-1.0, 1.0] (GAN output) to [0, 1]
                                             clip_preprocess.transforms[:2] +  # to match CLIP input scale assumptions
                                             clip_preprocess.transforms[4:])   # skip the PIL-to-tensor conversion

        self.target_direction = None
        self.patch_text_directions = None

        self.patch_loss = DirectionLoss(patch_loss_type)
        self.direction_loss = DirectionLoss(direction_loss_type)
        self.patch_direction_loss = torch.nn.CosineSimilarity(dim=2)

        self.lambda_global = lambda_global
        self.lambda_patch = lambda_patch
        self.lambda_direction = lambda_direction
        self.lambda_manifold = lambda_manifold
        self.lambda_texture = lambda_texture

        self.src_text_features = None
        self.target_text_features = None
        self.angle_loss = torch.nn.L1Loss()

        self.model_cnn, preprocess_cnn = clip.load("RN50", device=self.device)
        self.preprocess_cnn = transforms.Compose([transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])] +  # un-normalize from [-1.0, 1.0] (GAN output) to [0, 1]
                                                 preprocess_cnn.transforms[:2] +  # to match CLIP input scale assumptions
                                                 preprocess_cnn.transforms[4:])   # skip the PIL-to-tensor conversion

        self.texture_loss = torch.nn.MSELoss()

    def tokenize(self, strings: list):
        return clip.tokenize(strings).to(self.device)

    def encode_text(self, tokens: list) -> torch.Tensor:
        return self.model.encode_text(tokens)

    def encode_images(self, images: torch.Tensor) -> torch.Tensor:
        images = self.preprocess(images).to(self.device)
        return self.model.encode_image(images)

    def encode_images_with_cnn(self, images: torch.Tensor) -> torch.Tensor:
        images = self.preprocess_cnn(images).to(self.device)
        return self.model_cnn.encode_image(images)

    def distance_with_templates(self, img: torch.Tensor, class_str: str, templates=imagenet_templates) -> torch.Tensor:
        text_features = self.get_text_features(class_str, templates)
        image_features = self.get_image_features(img)

        similarity = image_features @ text_features.T

        return 1. - similarity

    def get_text_features(self, class_str: str, templates=imagenet_templates, norm: bool = True) -> torch.Tensor:
        template_text = self.compose_text_with_templates(class_str, templates)

        tokens = clip.tokenize(template_text).to(self.device)

        text_features = self.encode_text(tokens).detach()

        if norm:
            text_features /= text_features.norm(dim=-1, keepdim=True)

        return text_features

    def get_image_features(self, img: torch.Tensor, norm: bool = True) -> torch.Tensor:
        image_features = self.encode_images(img)

        if norm:
            image_features /= image_features.clone().norm(dim=-1, keepdim=True)

        return image_features

    def compute_text_direction(self, source_class: str, target_class: str) -> torch.Tensor:
        source_features = self.get_text_features(source_class)
        target_features = self.get_text_features(target_class)

        text_direction = (target_features - source_features).mean(axis=0, keepdim=True)
        text_direction /= text_direction.norm(dim=-1, keepdim=True)

        return text_direction

    def compute_img2img_direction(self, source_images: torch.Tensor, target_images: list) -> torch.Tensor:
        with torch.no_grad():

            src_encoding = self.get_image_features(source_images)
            src_encoding = src_encoding.mean(dim=0, keepdim=True)

            target_encodings = []
            for target_img in target_images:
                preprocessed = self.clip_preprocess(Image.open(target_img)).unsqueeze(0).to(self.device)

                encoding = self.model.encode_image(preprocessed)
                encoding /= encoding.norm(dim=-1, keepdim=True)

                target_encodings.append(encoding)

            target_encoding = torch.cat(target_encodings, axis=0)
            target_encoding = target_encoding.mean(dim=0, keepdim=True)

            direction = target_encoding - src_encoding
            direction /= direction.norm(dim=-1, keepdim=True)

        return direction

    def set_text_features(self, source_class: str, target_class: str) -> None:
        source_features = self.get_text_features(source_class).mean(axis=0, keepdim=True)
        self.src_text_features = source_features / source_features.norm(dim=-1, keepdim=True)

        target_features = self.get_text_features(target_class).mean(axis=0, keepdim=True)
        self.target_text_features = target_features / target_features.norm(dim=-1, keepdim=True)

    def clip_angle_loss(self, src_img: torch.Tensor, source_class: str, target_img: torch.Tensor, target_class: str) -> torch.Tensor:
        if self.src_text_features is None:
            self.set_text_features(source_class, target_class)

        cos_text_angle = self.target_text_features @ self.src_text_features.T
        text_angle = torch.acos(cos_text_angle)

        src_img_features = self.get_image_features(src_img).unsqueeze(2)
        target_img_features = self.get_image_features(target_img).unsqueeze(1)

        cos_img_angle = torch.clamp(target_img_features @ src_img_features, min=-1.0, max=1.0)
        img_angle = torch.acos(cos_img_angle)

        text_angle = text_angle.unsqueeze(0).repeat(img_angle.size()[0], 1, 1)
        cos_text_angle = cos_text_angle.unsqueeze(0).repeat(img_angle.size()[0], 1, 1)

        return self.angle_loss(cos_img_angle, cos_text_angle)

    def compose_text_with_templates(self, text: str, templates=imagenet_templates) -> list:
        return [template.format(text) for template in templates]

    def clip_directional_loss(self, src_img: torch.Tensor, source_class: str, target_img: torch.Tensor, target_class: str) -> torch.Tensor:
        if self.target_direction is None:
            self.target_direction = self.compute_text_direction(source_class, target_class)

        src_encoding = self.get_image_features(src_img)
        target_encoding = self.get_image_features(target_img)

        edit_direction = (target_encoding - src_encoding)
        edit_direction /= (edit_direction.clone().norm(dim=-1, keepdim=True) + 1e-7)
        return self.direction_loss(edit_direction, self.target_direction).mean()

    def global_clip_loss(self, img: torch.Tensor, text) -> torch.Tensor:
        if not isinstance(text, list):
            text = [text]

        tokens = clip.tokenize(text).to(self.device)
        image = self.preprocess(img)

        logits_per_image, _ = self.model(image, tokens)

        return (1. - logits_per_image / 100).mean()

    def random_patch_centers(self, img_shape, num_patches, size):
        batch_size, channels, height, width = img_shape

        half_size = size // 2
        patch_centers = np.concatenate([np.random.randint(half_size, width - half_size, size=(batch_size * num_patches, 1)),
                                        np.random.randint(half_size, height - half_size, size=(batch_size * num_patches, 1))], axis=1)

        return patch_centers

    def generate_patches(self, img: torch.Tensor, patch_centers, size):
        batch_size = img.shape[0]
        num_patches = len(patch_centers) // batch_size
        half_size = size // 2

        patches = []

        for batch_idx in range(batch_size):
            for patch_idx in range(num_patches):

                center_x = patch_centers[batch_idx * num_patches + patch_idx][0]
                center_y = patch_centers[batch_idx * num_patches + patch_idx][1]

                patch = img[batch_idx:batch_idx+1, :, center_y - half_size:center_y + half_size, center_x - half_size:center_x + half_size]

                patches.append(patch)

        patches = torch.cat(patches, axis=0)

        return patches

    def patch_scores(self, img: torch.Tensor, class_str: str, patch_centers, patch_size: int) -> torch.Tensor:
        parts = self.compose_text_with_templates(class_str, part_templates)
        tokens = clip.tokenize(parts).to(self.device)
        text_features = self.encode_text(tokens).detach()

        patches = self.generate_patches(img, patch_centers, patch_size)
        image_features = self.get_image_features(patches)

        similarity = image_features @ text_features.T

        return similarity

    def clip_patch_similarity(self, src_img: torch.Tensor, source_class: str, target_img: torch.Tensor, target_class: str) -> torch.Tensor:
        patch_size = 196  # TODO remove magic number

        patch_centers = self.random_patch_centers(src_img.shape, 4, patch_size)  # TODO remove magic number

        src_scores = self.patch_scores(src_img, source_class, patch_centers, patch_size)
        target_scores = self.patch_scores(target_img, target_class, patch_centers, patch_size)

        return self.patch_loss(src_scores, target_scores)

    def patch_directional_loss(self, src_img: torch.Tensor, source_class: str, target_img: torch.Tensor, target_class: str) -> torch.Tensor:
        if self.patch_text_directions is None:
            src_part_classes = self.compose_text_with_templates(source_class, part_templates)
            target_part_classes = self.compose_text_with_templates(target_class, part_templates)

            parts_classes = list(zip(src_part_classes, target_part_classes))

            self.patch_text_directions = torch.cat([self.compute_text_direction(pair[0], pair[1]) for pair in parts_classes], dim=0)

        patch_size = 510  # TODO remove magic numbers

        patch_centers = self.random_patch_centers(src_img.shape, 1, patch_size)

        patches = self.generate_patches(src_img, patch_centers, patch_size)
        src_features = self.get_image_features(patches)

        patches = self.generate_patches(target_img, patch_centers, patch_size)
        target_features = self.get_image_features(patches)

        edit_direction = (target_features - src_features)
        edit_direction /= edit_direction.clone().norm(dim=-1, keepdim=True)

        cosine_dists = 1. - self.patch_direction_loss(edit_direction.unsqueeze(1), self.patch_text_directions.unsqueeze(0))

        patch_class_scores = cosine_dists * (edit_direction @ self.patch_text_directions.T).softmax(dim=-1)

        return patch_class_scores.mean()

    def cnn_feature_loss(self, src_img: torch.Tensor, target_img: torch.Tensor) -> torch.Tensor:
        src_features = self.encode_images_with_cnn(src_img)
        target_features = self.encode_images_with_cnn(target_img)

        return self.texture_loss(src_features, target_features)

    def forward(self, src_img: torch.Tensor, source_class: str, target_img: torch.Tensor, target_class: str, texture_image: torch.Tensor = None):
        clip_loss = 0.0

        if self.lambda_global:
            clip_loss += self.lambda_global * self.global_clip_loss(target_img, [f"a {target_class}"])

        if self.lambda_patch:
            clip_loss += self.lambda_patch * self.patch_directional_loss(src_img, source_class, target_img, target_class)

        if self.lambda_direction:
            clip_loss += self.lambda_direction * self.clip_directional_loss(src_img, source_class, target_img, target_class)

        if self.lambda_manifold:
            clip_loss += self.lambda_manifold * self.clip_angle_loss(src_img, source_class, target_img, target_class)

        if self.lambda_texture and (texture_image is not None):
            clip_loss += self.lambda_texture * self.cnn_feature_loss(texture_image, target_img)

        return clip_loss
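A minimal usage sketch for the DirectionLoss helper above, assuming only PyTorch (no CLIP download needed). The cosine variant returns 1 - cos(x, y) per row, so parallel direction vectors score 0 and opposite ones score 2:

    import torch
    from losses.clip_loss import DirectionLoss

    loss_fn = DirectionLoss(loss_type='cosine')
    x = torch.randn(4, 512)  # e.g. CLIP-space image edit directions
    y = torch.randn(4, 512)  # e.g. target text directions
    print(loss_fn(x, y).mean())  # scalar: 0 when x and y are parallel, 2 when opposite

CLIPLoss.forward combines the weighted terms the same way, but instantiating it additionally downloads the ViT-B/16 and RN50 CLIP checkpoints.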
losses/id_loss.py
ADDED
@@ -0,0 +1,35 @@
import torch
from torch import nn
from configs.paths_config import MODEL_PATHS
from models.insight_face.model_irse import Backbone, MobileFaceNet


class IDLoss(nn.Module):
    def __init__(self, use_mobile_id=False):
        super(IDLoss, self).__init__()
        print('Loading ResNet ArcFace')
        self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
        self.facenet.load_state_dict(torch.load(MODEL_PATHS['ir_se50']))

        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
        self.facenet.eval()

    def extract_feats(self, x):
        x = x[:, :, 35:223, 32:220]  # crop the interesting (face) region
        x = self.face_pool(x)
        x_feats = self.facenet(x)
        return x_feats

    def forward(self, x, x_hat):
        n_samples = x.shape[0]
        x_feats = self.extract_feats(x)
        x_feats = x_feats.detach()

        x_hat_feats = self.extract_feats(x_hat)
        losses = []
        for i in range(n_samples):
            loss_sample = 1 - x_hat_feats[i].dot(x_feats[i])
            losses.append(loss_sample.unsqueeze(0))

        losses = torch.cat(losses, dim=0)
        return losses
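A sketch of how IDLoss is typically called, assuming the ir_se50 ArcFace checkpoint referenced in configs/paths_config.py has been downloaded. Inputs are 256x256 images in [-1, 1], so the hard-coded face crop above applies:

    import torch
    from losses.id_loss import IDLoss

    id_loss = IDLoss()                          # loads the ir_se50 checkpoint (assumed present)
    x = torch.rand(2, 3, 256, 256) * 2 - 1      # original images in [-1, 1]
    x_hat = torch.rand(2, 3, 256, 256) * 2 - 1  # edited images
    per_image = id_loss(x, x_hat)               # shape (2,): 1 - cosine identity similarity
    print(per_image.mean().item())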
main.py
ADDED
@@ -0,0 +1,275 @@
import argparse
import traceback
import logging
import yaml
import sys
import os
import torch
import numpy as np

from boundarydiffusion import BoundaryDiffusion
from configs.paths_config import HYBRID_MODEL_PATHS


def parse_args_and_config():
    parser = argparse.ArgumentParser(description=globals()['__doc__'])

    # Mode
    parser.add_argument('--radius', action='store_true')
    parser.add_argument('--unconditional', action='store_true')
    parser.add_argument('--boundary_search', action='store_true')
    parser.add_argument('--diffusion_hyperplane', action='store_true')
    parser.add_argument('--clip_finetune', action='store_true')
    parser.add_argument('--clip_latent_optim', action='store_true')
    parser.add_argument('--edit_images_from_dataset', action='store_true')
    parser.add_argument('--edit_one_image', action='store_true')
    parser.add_argument('--unseen2unseen', action='store_true')
    parser.add_argument('--clip_finetune_eff', action='store_true')
    parser.add_argument('--edit_one_image_eff', action='store_true')
    parser.add_argument('--edit_image_boundary', action='store_true')

    # Default
    parser.add_argument('--config', type=str, required=True, help='Path to the config file')
    parser.add_argument('--seed', type=int, default=1006, help='Random seed')
    parser.add_argument('--exp', type=str, default='./runs/', help='Path for saving running related data.')
    parser.add_argument('--comment', type=str, default='', help='A string for experiment comment')
    parser.add_argument('--verbose', type=str, default='info', help='Verbose level: info | debug | warning | critical')
    parser.add_argument('--ni', type=int, default=1, help="No interaction. Suitable for Slurm job launcher")
    parser.add_argument('--align_face', type=int, default=1, help='Whether to align the face')

    # Text
    parser.add_argument('--edit_attr', type=str, default=None, help='Attribute to edit, defined in ./utils/text_dic.py')
    parser.add_argument('--src_txts', type=str, action='append', help='Source text, e.g. Face')
    parser.add_argument('--trg_txts', type=str, action='append', help='Target text, e.g. Angry Face')
    parser.add_argument('--target_class_num', type=str, default=None)

    # Sampling
    parser.add_argument('--t_0', type=int, default=400, help='Return step in [0, 1000)')
    parser.add_argument('--n_inv_step', type=int, default=40, help='# of steps of the generative process for inversion')
    parser.add_argument('--n_train_step', type=int, default=6, help='# of steps of the generative process for training')
    parser.add_argument('--n_test_step', type=int, default=40, help='# of steps of the generative process for testing')
    parser.add_argument('--sample_type', type=str, default='ddim', help='ddpm for Markovian sampling, ddim for non-Markovian sampling')
    parser.add_argument('--eta', type=float, default=0.0, help='Controls the variance of the generative process')
    parser.add_argument('--start_distance', type=float, default=-150.0, help='Starting distance of the editing space')
    parser.add_argument('--end_distance', type=float, default=150.0, help='Ending distance of the editing space')
    parser.add_argument('--edit_img_number', type=int, default=20, help='Number of editing images')

    # Train & Test
    parser.add_argument('--do_train', type=int, default=1, help='Whether to train during CLIP finetuning')
    parser.add_argument('--do_test', type=int, default=1, help='Whether to test during CLIP finetuning')
    parser.add_argument('--save_train_image', type=int, default=1, help='Whether to save training results during CLIP finetuning')
    parser.add_argument('--bs_train', type=int, default=1, help='Training batch size during CLIP finetuning')
    parser.add_argument('--bs_test', type=int, default=1, help='Test batch size during CLIP finetuning')
    parser.add_argument('--n_precomp_img', type=int, default=100, help='# of images to precompute latents')
    parser.add_argument('--n_train_img', type=int, default=50, help='# of training images')
    parser.add_argument('--n_test_img', type=int, default=10, help='# of test images')
    parser.add_argument('--model_path', type=str, default=None, help='Test model path')
    parser.add_argument('--img_path', type=str, default=None, help='Image path to test')
    parser.add_argument('--deterministic_inv', type=int, default=1, help='Whether to use deterministic inversion during inference')
    parser.add_argument('--hybrid_noise', type=int, default=0, help='Whether to change multiple attributes by mixing multiple models')
    parser.add_argument('--model_ratio', type=float, default=1, help='Degree of change; noise ratio between the original and finetuned models')

    # Loss & Optimization
    parser.add_argument('--clip_loss_w', type=int, default=0, help='Weight of the CLIP loss')
    parser.add_argument('--l1_loss_w', type=float, default=0, help='Weight of the L1 loss')
    parser.add_argument('--id_loss_w', type=float, default=0, help='Weight of the ID loss')
    parser.add_argument('--clip_model_name', type=str, default='ViT-B/16', help='ViT-B/16, ViT-B/32, RN50x16 etc.')
    parser.add_argument('--lr_clip_finetune', type=float, default=2e-6, help='Initial learning rate for finetuning')
    parser.add_argument('--lr_clip_lat_opt', type=float, default=2e-2, help='Initial learning rate for latent optimization')
    parser.add_argument('--n_iter', type=int, default=1, help='# of iterations of a generative process with `n_train_img` images')
    parser.add_argument('--scheduler', type=int, default=1, help='Whether to increase the learning rate')
    parser.add_argument('--sch_gamma', type=float, default=1.3, help='Scheduler gamma')

    args = parser.parse_args()

    # parse config file
    with open(os.path.join('configs', args.config), 'r') as f:
        config = yaml.safe_load(f)
    new_config = dict2namespace(config)

    if args.diffusion_hyperplane:
        if args.edit_attr is not None:
            args.exp = args.exp + f'_SP_{new_config.data.category}_{args.edit_attr}_t{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}_id{args.id_loss_w}_l1{args.l1_loss_w}_lr{args.lr_clip_finetune}'
        else:
            args.exp = args.exp + f'_SP_{new_config.data.category}_{args.trg_txts}_t{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}_id{args.id_loss_w}_l1{args.l1_loss_w}_lr{args.lr_clip_finetune}'
    elif args.radius:
        if args.edit_attr is not None:
            args.exp = args.exp + f'_R_{new_config.data.category}_{args.edit_attr}_t{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}_id{args.id_loss_w}_l1{args.l1_loss_w}_lr{args.lr_clip_finetune}'
        else:
            args.exp = args.exp + f'_R_{new_config.data.category}_{args.trg_txts}_t{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}_id{args.id_loss_w}_l1{args.l1_loss_w}_lr{args.lr_clip_finetune}'
    elif args.unconditional:
        if args.edit_attr is not None:
            args.exp = args.exp + f'_UN_{new_config.data.category}_{args.edit_attr}_t{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}_id{args.id_loss_w}_l1{args.l1_loss_w}_lr{args.lr_clip_finetune}'
        else:
            args.exp = args.exp + f'_UN_{new_config.data.category}_{args.trg_txts}_t{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}_id{args.id_loss_w}_l1{args.l1_loss_w}_lr{args.lr_clip_finetune}'
    elif args.boundary_search:
        if args.edit_attr is not None:
            args.exp = args.exp + f'_BCLIP_{new_config.data.category}_{args.edit_attr}_t{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}_id{args.id_loss_w}_l1{args.l1_loss_w}_lr{args.lr_clip_finetune}'
        else:
            args.exp = args.exp + f'_BCLIP_{new_config.data.category}_{args.trg_txts}_t{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}_id{args.id_loss_w}_l1{args.l1_loss_w}_lr{args.lr_clip_finetune}'
    elif args.clip_finetune or args.clip_finetune_eff:
        if args.edit_attr is not None:
            args.exp = args.exp + f'_FT_{new_config.data.category}_{args.edit_attr}_t{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}_id{args.id_loss_w}_l1{args.l1_loss_w}_lr{args.lr_clip_finetune}'
        else:
            args.exp = args.exp + f'_FT_{new_config.data.category}_{args.trg_txts}_t{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}_id{args.id_loss_w}_l1{args.l1_loss_w}_lr{args.lr_clip_finetune}'
    elif args.clip_latent_optim:
        if args.edit_attr is not None:
            args.exp = args.exp + f'_LO_{new_config.data.category}_{args.img_path.split("/")[-1].split(".")[0]}_{args.edit_attr}_t{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}_id{args.id_loss_w}_l1{args.l1_loss_w}_lr{args.lr_clip_lat_opt}'
        else:
            args.exp = args.exp + f'_LO_{new_config.data.category}_{args.img_path.split("/")[-1].split(".")[0]}_{args.trg_txts}_t{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}_id{args.id_loss_w}_l1{args.l1_loss_w}_lr{args.lr_clip_lat_opt}'
    elif args.edit_images_from_dataset:
        if args.model_path:
            args.exp = args.exp + f'_ED_{new_config.data.category}_t{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}_{os.path.split(args.model_path)[-1].replace(".pth","")}'
        elif args.hybrid_noise:
            hb_str = '_'
            for i, model_name in enumerate(HYBRID_MODEL_PATHS):
                hb_str = hb_str + model_name.split('_')[1]
                if i != len(HYBRID_MODEL_PATHS) - 1:
                    hb_str = hb_str + '_'
            args.exp = args.exp + f'_ED_{new_config.data.category}_t{args.t_0}_ninv{args.n_train_step}_ngen{args.n_train_step}' + hb_str
        else:
            args.exp = args.exp + f'_ED_{new_config.data.category}_t{args.t_0}_ninv{args.n_train_step}_ngen{args.n_train_step}_orig'

    elif args.edit_image_boundary:
        if args.model_path:
            args.exp = args.exp + f'_E1_t{args.t_0}_{new_config.data.category}_{args.img_path.split("/")[-1].split(".")[0]}_t{args.t_0}_ninv{args.n_inv_step}_{os.path.split(args.model_path)[-1].replace(".pth", "")}'
        elif args.hybrid_noise:
            hb_str = '_'
            for i, model_name in enumerate(HYBRID_MODEL_PATHS):
                hb_str = hb_str + model_name.split('_')[1]
                if i != len(HYBRID_MODEL_PATHS) - 1:
                    hb_str = hb_str + '_'
            args.exp = args.exp + f'_E1_{new_config.data.category}_{args.img_path.split("/")[-1].split(".")[0]}_t{args.t_0}_ninv{args.n_train_step}' + hb_str
        else:
            args.exp = args.exp + f'_E1_{new_config.data.category}_{args.img_path.split("/")[-1].split(".")[0]}_t{args.t_0}_ninv{args.n_train_step}_orig'

    elif args.edit_one_image:
        if args.model_path:
            args.exp = args.exp + f'_E1_t{args.t_0}_{new_config.data.category}_{args.img_path.split("/")[-1].split(".")[0]}_t{args.t_0}_ninv{args.n_inv_step}_{os.path.split(args.model_path)[-1].replace(".pth", "")}'
        elif args.hybrid_noise:
            hb_str = '_'
            for i, model_name in enumerate(HYBRID_MODEL_PATHS):
                hb_str = hb_str + model_name.split('_')[1]
                if i != len(HYBRID_MODEL_PATHS) - 1:
                    hb_str = hb_str + '_'
            args.exp = args.exp + f'_E1_{new_config.data.category}_{args.img_path.split("/")[-1].split(".")[0]}_t{args.t_0}_ninv{args.n_train_step}' + hb_str
        else:
            args.exp = args.exp + f'_E1_{new_config.data.category}_{args.img_path.split("/")[-1].split(".")[0]}_t{args.t_0}_ninv{args.n_train_step}_orig'

    elif args.unseen2unseen:
        if args.model_path:
            args.exp = args.exp + f'_U2U_t{args.t_0}_{new_config.data.category}_{args.img_path.split("/")[-1].split(".")[0]}_t{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}_{os.path.split(args.model_path)[-1].replace(".pth", "")}'
        elif args.hybrid_noise:
            hb_str = '_'
            for i, model_name in enumerate(HYBRID_MODEL_PATHS):
                hb_str = hb_str + model_name.split('_')[1]
                if i != len(HYBRID_MODEL_PATHS) - 1:
                    hb_str = hb_str + '_'
            args.exp = args.exp + f'_U2U_{new_config.data.category}_{args.img_path.split("/")[-1].split(".")[0]}_t{args.t_0}_ninv{args.n_train_step}_ngen{args.n_train_step}' + hb_str
        else:
            args.exp = args.exp + f'_U2U_{new_config.data.category}_{args.img_path.split("/")[-1].split(".")[0]}_t{args.t_0}_ninv{args.n_train_step}_ngen{args.n_train_step}_orig'

    # NOTE: --recon_exp and --find_best_image are referenced below but are not
    # registered as arguments above, so these branches are only reachable if
    # the corresponding flags are added to the parser.
    elif args.recon_exp:
        args.exp = args.exp + f'_REC_{new_config.data.category}_{args.img_path.split("/")[-1].split(".")[0]}_t{args.t_0}_ninv{args.n_train_step}'
    elif args.find_best_image:
        args.exp = args.exp + f'_FOpt_{new_config.data.category}_{args.trg_txts[0]}_t{args.t_0}_ninv{args.n_train_step}'

    level = getattr(logging, args.verbose.upper(), None)
    if not isinstance(level, int):
        raise ValueError('level {} not supported'.format(args.verbose))

    handler1 = logging.StreamHandler()
    formatter = logging.Formatter('%(levelname)s - %(filename)s - %(asctime)s - %(message)s')
    handler1.setFormatter(formatter)
    logger = logging.getLogger()
    logger.addHandler(handler1)
    logger.setLevel(level)

    os.makedirs(args.exp, exist_ok=True)
    os.makedirs('checkpoint', exist_ok=True)
    os.makedirs('precomputed', exist_ok=True)
    os.makedirs('runs', exist_ok=True)
    args.image_folder = os.path.join(args.exp, 'image_samples')
    if not os.path.exists(args.image_folder):
        os.makedirs(args.image_folder)
    else:
        overwrite = False
        if args.ni:
            overwrite = True
        else:
            response = input("Image folder already exists. Overwrite? (Y/N)")
            if response.upper() == 'Y':
                overwrite = True

        if overwrite:
            # shutil.rmtree(args.image_folder)
            os.makedirs(args.image_folder, exist_ok=True)
        else:
            print("Output image folder exists. Program halted.")
            sys.exit(0)

    # add device
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    logging.info("Using device: {}".format(device))
    new_config.device = device

    # set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    torch.backends.cudnn.benchmark = True

    return args, new_config


def dict2namespace(config):
    namespace = argparse.Namespace()
    for key, value in config.items():
        if isinstance(value, dict):
            new_value = dict2namespace(value)
        else:
            new_value = value
        setattr(namespace, key, new_value)
    return namespace


def main():
    args, config = parse_args_and_config()
    print(">" * 80)
    logging.info("Exp instance id = {}".format(os.getpid()))
    logging.info("Exp comment = {}".format(args.comment))
    logging.info("Config =")
    print("<" * 80)

    runner = BoundaryDiffusion(args, config)
    try:
        if args.clip_finetune:
            runner.clip_finetune()
        elif args.radius:
            runner.radius()
        elif args.unconditional:
            runner.unconditional()
        elif args.diffusion_hyperplane:
            runner.diffusion_hyperplane()
        elif args.boundary_search:
            runner.boundary_search()
        elif args.edit_image_boundary:
            runner.edit_image_boundary()
        else:
            print('Choose one mode!')
            raise ValueError
    except Exception:
        logging.error(traceback.format_exc())

    return 0


if __name__ == '__main__':
    sys.exit(main())
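To illustrate dict2namespace, a small self-contained sketch; the YAML snippet is hypothetical, loosely modeled on the config files under configs/, and the helper is repeated inline so the example runs without importing main.py:

    import argparse
    import yaml

    def dict2namespace(config):  # same helper as in main.py
        namespace = argparse.Namespace()
        for key, value in config.items():
            setattr(namespace, key, dict2namespace(value) if isinstance(value, dict) else value)
        return namespace

    raw = yaml.safe_load("""
    data:
      category: CELEBA   # hypothetical values
      image_size: 256
    """)
    config = dict2namespace(raw)
    print(config.data.image_size)  # 256 -- nested dicts become nested attribute access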
models/ddpm/diffusion.py
ADDED
@@ -0,0 +1,348 @@
import math
import torch
import torch.nn as nn


def get_timestep_embedding(timesteps, embedding_dim):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models:
    From Fairseq.
    Build sinusoidal embeddings.
    This matches the implementation in tensor2tensor, but differs slightly
    from the description in Section 3.5 of "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(
            x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=2,
                                        padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        self.temb_proj = torch.nn.Linear(temb_channels,
                                         out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)      # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)        # b,hw,hw   w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)    # b,hw,hw (first hw of k, second of q)
        # b,c,hw (hw of q): h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = torch.bmm(v, w_)
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_


class DDPM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        ch, out_ch, ch_mult = config.model.ch, config.model.out_ch, tuple(config.model.ch_mult)
        num_res_blocks = config.model.num_res_blocks
        attn_resolutions = config.model.attn_resolutions
        dropout = config.model.dropout
        in_channels = config.model.in_channels
        resolution = config.data.image_size
        resamp_with_conv = config.model.resamp_with_conv

        self.ch = ch
        self.temb_ch = self.ch * 4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # timestep embedding
        self.temb = nn.Module()
        self.temb.dense = nn.ModuleList([
            torch.nn.Linear(self.ch,
                            self.temb_ch),
            torch.nn.Linear(self.temb_ch,
                            self.temb_ch),
        ])

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + ch_mult
        self.down = nn.ModuleList()
        block_in = None
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            skip_in = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                if i_block == self.num_res_blocks:
                    skip_in = ch * in_ch_mult[i_level]
                block.append(ResnetBlock(in_channels=block_in + skip_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x, t, edit_h=None):
        assert x.shape[2] == x.shape[3] == self.resolution  # input e.g. [1, 3, 256, 256]

        # timestep embedding
        temb = get_timestep_embedding(t, self.ch)
        temb = self.temb.dense[0](temb)
        temb = nonlinearity(temb)
        temb = self.temb.dense[1](temb)

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)
        mid_h = h.detach().clone()  # bottleneck h-space embedding, e.g. [1, 512, 8, 8]
        if edit_h is not None:      # optionally replace the bottleneck with an edited embedding
            h = edit_h

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](
                    torch.cat([h, hs.pop()], dim=1), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)  # output e.g. [1, 3, 256, 256]

        return mid_h, h
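A quick shape check for get_timestep_embedding, runnable with PyTorch alone:

    import torch
    from models.ddpm.diffusion import get_timestep_embedding

    t = torch.tensor([0, 10, 999])        # one diffusion timestep per batch element
    emb = get_timestep_embedding(t, 128)  # sinusoidal embedding, sin and cos halves concatenated
    print(emb.shape)                      # torch.Size([3, 128])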
models/improved_ddpm/fp16_util.py
ADDED
@@ -0,0 +1,236 @@
"""
Helpers to train with 16-bit precision.
"""

import numpy as np
import torch as th
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from . import logger

INITIAL_LOG_LOSS_SCALE = 20.0


def convert_module_to_f16(l):
    """
    Convert primitive modules to float16.
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.half()
        if l.bias is not None:
            l.bias.data = l.bias.data.half()


def convert_module_to_f32(l):
    """
    Convert primitive modules to float32, undoing convert_module_to_f16().
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.float()
        if l.bias is not None:
            l.bias.data = l.bias.data.float()


def make_master_params(param_groups_and_shapes):
    """
    Copy model parameters into a (differently-shaped) list of full-precision
    parameters.
    """
    master_params = []
    for param_group, shape in param_groups_and_shapes:
        master_param = nn.Parameter(
            _flatten_dense_tensors(
                [param.detach().float() for (_, param) in param_group]
            ).view(shape)
        )
        master_param.requires_grad = True
        master_params.append(master_param)
    return master_params


def model_grads_to_master_grads(param_groups_and_shapes, master_params):
    """
    Copy the gradients from the model parameters into the master parameters
    from make_master_params().
    """
    for master_param, (param_group, shape) in zip(
        master_params, param_groups_and_shapes
    ):
        master_param.grad = _flatten_dense_tensors(
            [param_grad_or_zeros(param) for (_, param) in param_group]
        ).view(shape)


def master_params_to_model_params(param_groups_and_shapes, master_params):
    """
    Copy the master parameter data back into the model parameters.
    """
    # Without copying to a list, if a generator is passed, this will
    # silently not copy any parameters.
    for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
        for (_, param), unflat_master_param in zip(
            param_group, unflatten_master_params(param_group, master_param.view(-1))
        ):
            param.detach().copy_(unflat_master_param)


def unflatten_master_params(param_group, master_param):
    return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group])


def get_param_groups_and_shapes(named_model_params):
    named_model_params = list(named_model_params)
    scalar_vector_named_params = (
        [(n, p) for (n, p) in named_model_params if p.ndim <= 1],
        (-1),
    )
    matrix_named_params = (
        [(n, p) for (n, p) in named_model_params if p.ndim > 1],
        (1, -1),
    )
    return [scalar_vector_named_params, matrix_named_params]


def master_params_to_state_dict(
    model, param_groups_and_shapes, master_params, use_fp16
):
    if use_fp16:
        state_dict = model.state_dict()
        for master_param, (param_group, _) in zip(
            master_params, param_groups_and_shapes
        ):
            for (name, _), unflat_master_param in zip(
                param_group, unflatten_master_params(param_group, master_param.view(-1))
            ):
                assert name in state_dict
                state_dict[name] = unflat_master_param
    else:
        state_dict = model.state_dict()
        for i, (name, _value) in enumerate(model.named_parameters()):
            assert name in state_dict
            state_dict[name] = master_params[i]
    return state_dict


def state_dict_to_master_params(model, state_dict, use_fp16):
    if use_fp16:
        named_model_params = [
            (name, state_dict[name]) for name, _ in model.named_parameters()
        ]
        param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
        master_params = make_master_params(param_groups_and_shapes)
    else:
        master_params = [state_dict[name] for name, _ in model.named_parameters()]
    return master_params


def zero_master_grads(master_params):
    for param in master_params:
        param.grad = None


def zero_grad(model_params):
    for param in model_params:
        # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
        if param.grad is not None:
            param.grad.detach_()
            param.grad.zero_()


def param_grad_or_zeros(param):
    if param.grad is not None:
        return param.grad.data.detach()
    else:
        return th.zeros_like(param)


class MixedPrecisionTrainer:
    def __init__(
        self,
        *,
        model,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
    ):
        self.model = model
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth

        self.model_params = list(self.model.parameters())
        self.master_params = self.model_params
        self.param_groups_and_shapes = None
        self.lg_loss_scale = initial_lg_loss_scale

        if self.use_fp16:
            self.param_groups_and_shapes = get_param_groups_and_shapes(
                self.model.named_parameters()
            )
            self.master_params = make_master_params(self.param_groups_and_shapes)
            self.model.convert_to_fp16()

    def zero_grad(self):
        zero_grad(self.model_params)

    def backward(self, loss: th.Tensor):
        if self.use_fp16:
            loss_scale = 2 ** self.lg_loss_scale
            (loss * loss_scale).backward()
        else:
            loss.backward()

    def optimize(self, opt: th.optim.Optimizer):
        if self.use_fp16:
            return self._optimize_fp16(opt)
        else:
            return self._optimize_normal(opt)

    def _optimize_fp16(self, opt: th.optim.Optimizer):
        logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
        model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
        grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
        if check_overflow(grad_norm):
            self.lg_loss_scale -= 1
            logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
            zero_master_grads(self.master_params)
            return False

        logger.logkv_mean("grad_norm", grad_norm)
        logger.logkv_mean("param_norm", param_norm)

        self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale))
        opt.step()
        zero_master_grads(self.master_params)
        master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
        self.lg_loss_scale += self.fp16_scale_growth
        return True

    def _optimize_normal(self, opt: th.optim.Optimizer):
        grad_norm, param_norm = self._compute_norms()
        logger.logkv_mean("grad_norm", grad_norm)
        logger.logkv_mean("param_norm", param_norm)
        opt.step()
        return True

    def _compute_norms(self, grad_scale=1.0):
        grad_norm = 0.0
        param_norm = 0.0
        for p in self.master_params:
            with th.no_grad():
                param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
                if p.grad is not None:
                    grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
        return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)

    def master_params_to_state_dict(self, master_params):
        return master_params_to_state_dict(
            self.model, self.param_groups_and_shapes, master_params, self.use_fp16
        )

    def state_dict_to_master_params(self, state_dict):
        return state_dict_to_master_params(self.model, state_dict, self.use_fp16)


def check_overflow(value):
    return (value == float("inf")) or (value == -float("inf")) or (value != value)
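A minimal training-step sketch for MixedPrecisionTrainer using the fp32 path (use_fp16=False), so no CUDA or special model support is assumed; the fp16 path additionally requires the model to implement convert_to_fp16():

    import torch
    from models.improved_ddpm.fp16_util import MixedPrecisionTrainer

    model = torch.nn.Linear(8, 1)  # toy model standing in for the U-Net
    trainer = MixedPrecisionTrainer(model=model, use_fp16=False)
    opt = torch.optim.Adam(trainer.master_params, lr=1e-3)

    x, y = torch.randn(16, 8), torch.randn(16, 1)
    trainer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    trainer.backward(loss)  # would scale the loss by 2**lg_loss_scale if use_fp16 were True
    trainer.optimize(opt)   # logs grad/param norms, then steps the optimizer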
models/improved_ddpm/logger.py
ADDED
@@ -0,0 +1,451 @@
"""
Logger based on OpenAI baselines to avoid extra RL-based dependencies:
https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
"""

import os
import sys
import os.path as osp
import json
import time
import datetime
import tempfile
import warnings
from collections import defaultdict
from contextlib import contextmanager

DEBUG = 10
INFO = 20
WARN = 30
ERROR = 40

DISABLED = 50


class KVWriter(object):
    def writekvs(self, kvs):
        raise NotImplementedError


class SeqWriter(object):
    def writeseq(self, seq):
        raise NotImplementedError


class HumanOutputFormat(KVWriter, SeqWriter):
    def __init__(self, filename_or_file):
        if isinstance(filename_or_file, str):
            self.file = open(filename_or_file, "wt")
            self.own_file = True
        else:
            assert hasattr(filename_or_file, "read"), (
                "expected file or str, got %s" % filename_or_file
            )
            self.file = filename_or_file
            self.own_file = False

    def writekvs(self, kvs):
        # Create strings for printing
        key2str = {}
        for (key, val) in sorted(kvs.items()):
            if hasattr(val, "__float__"):
                valstr = "%-8.3g" % val
            else:
                valstr = str(val)
            key2str[self._truncate(key)] = self._truncate(valstr)

        # Find max widths
        if len(key2str) == 0:
            print("WARNING: tried to write empty key-value dict")
            return
        else:
            keywidth = max(map(len, key2str.keys()))
            valwidth = max(map(len, key2str.values()))

        # Write out the data
        dashes = "-" * (keywidth + valwidth + 7)
        lines = [dashes]
        for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
            lines.append(
                "| %s%s | %s%s |"
                % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
            )
        lines.append(dashes)
        self.file.write("\n".join(lines) + "\n")

        # Flush the output to the file
        self.file.flush()

    def _truncate(self, s):
        maxlen = 30
        return s[: maxlen - 3] + "..." if len(s) > maxlen else s

    def writeseq(self, seq):
        seq = list(seq)
        for (i, elem) in enumerate(seq):
            self.file.write(elem)
            if i < len(seq) - 1:  # add space unless this is the last one
                self.file.write(" ")
        self.file.write("\n")
        self.file.flush()

    def close(self):
        if self.own_file:
            self.file.close()


class JSONOutputFormat(KVWriter):
    def __init__(self, filename):
        self.file = open(filename, "wt")

    def writekvs(self, kvs):
        for k, v in sorted(kvs.items()):
            if hasattr(v, "dtype"):
                kvs[k] = float(v)
        self.file.write(json.dumps(kvs) + "\n")
        self.file.flush()

    def close(self):
        self.file.close()


class CSVOutputFormat(KVWriter):
    def __init__(self, filename):
        self.file = open(filename, "w+t")
        self.keys = []
        self.sep = ","

    def writekvs(self, kvs):
        # Add our current row to the history
        extra_keys = list(kvs.keys() - self.keys)
        extra_keys.sort()
        if extra_keys:
            self.keys.extend(extra_keys)
            self.file.seek(0)
            lines = self.file.readlines()
            self.file.seek(0)
            for (i, k) in enumerate(self.keys):
                if i > 0:
                    self.file.write(",")
                self.file.write(k)
            self.file.write("\n")
            for line in lines[1:]:
                self.file.write(line[:-1])
                self.file.write(self.sep * len(extra_keys))
                self.file.write("\n")
        for (i, k) in enumerate(self.keys):
            if i > 0:
                self.file.write(",")
            v = kvs.get(k)
            if v is not None:
                self.file.write(str(v))
        self.file.write("\n")
        self.file.flush()

    def close(self):
        self.file.close()


def make_output_format(format, ev_dir, log_suffix=""):
    os.makedirs(ev_dir, exist_ok=True)
    if format == "stdout":
        return HumanOutputFormat(sys.stdout)
    elif format == "log":
        return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
    elif format == "json":
        return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
    elif format == "csv":
        return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
    else:
        raise ValueError("Unknown format specified: %s" % (format,))


# ================================================================
# API
# ================================================================


def logkv(key, val):
    """
    Log a value of some diagnostic
    Call this once for each diagnostic quantity, each iteration
    If called many times, last value will be used.
    """
    get_current().logkv(key, val)


def logkv_mean(key, val):
    """
    The same as logkv(), but if called many times, values averaged.
    """
    get_current().logkv_mean(key, val)


def logkvs(d):
    """
    Log a dictionary of key-value pairs
    """
    for (k, v) in d.items():
        logkv(k, v)


def dumpkvs():
    """
    Write all of the diagnostics from the current iteration
    """
    return get_current().dumpkvs()


def getkvs():
    return get_current().name2val


def log(*args, level=INFO):
    """
    Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
    """
    get_current().log(*args, level=level)


def debug(*args):
    log(*args, level=DEBUG)


def info(*args):
    log(*args, level=INFO)


def warn(*args):
    log(*args, level=WARN)


def error(*args):
    log(*args, level=ERROR)


def set_level(level):
    """
    Set logging threshold on current logger.
    """
    get_current().set_level(level)


def set_comm(comm):
    get_current().set_comm(comm)


def get_dir():
    """
    Get directory that log files are being written to.
    will be None if there is no output directory (i.e., if you didn't call start)
    """
    return get_current().get_dir()


record_tabular = logkv
dump_tabular = dumpkvs


@contextmanager
def profile_kv(scopename):
    logkey = "wait_" + scopename
    tstart = time.time()
    try:
        yield
    finally:
        get_current().name2val[logkey] += time.time() - tstart


def profile(n):
    """
    Usage:
    @profile("my_func")
    def my_func(): code
    """

    def decorator_with_name(func):
        def func_wrapper(*args, **kwargs):
            with profile_kv(n):
                return func(*args, **kwargs)

        return func_wrapper

    return decorator_with_name


# ================================================================
# Backend
# ================================================================


def get_current():
    if Logger.CURRENT is None:
        _configure_default_logger()

    return Logger.CURRENT


class Logger(object):
    DEFAULT = None  # A logger with no output files. (See right below class definition)
    # So that you can still log to the terminal without setting up any output files
    CURRENT = None  # Current logger being used by the free functions above

    def __init__(self, dir, output_formats, comm=None):
        self.name2val = defaultdict(float)  # values this iteration
        self.name2cnt = defaultdict(int)
        self.level = INFO
        self.dir = dir
        self.output_formats = output_formats
        self.comm = comm

    # Logging API, forwarded
    # ----------------------------------------
    def logkv(self, key, val):
        self.name2val[key] = val

    def logkv_mean(self, key, val):
        oldval, cnt = self.name2val[key], self.name2cnt[key]
        self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
        self.name2cnt[key] = cnt + 1

    def dumpkvs(self):
        if self.comm is None:
            d = self.name2val
        else:
            d = mpi_weighted_mean(
                self.comm,
                {
                    name: (val, self.name2cnt.get(name, 1))
                    for (name, val) in self.name2val.items()
                },
            )
            if self.comm.rank != 0:
                d["dummy"] = 1  # so we don't get a warning about empty dict
        out = d.copy()  # Return the dict for unit testing purposes
        for fmt in self.output_formats:
            if isinstance(fmt, KVWriter):
                fmt.writekvs(d)
        self.name2val.clear()
        self.name2cnt.clear()
        return out

    def log(self, *args, level=INFO):
        if self.level <= level:
            self._do_log(args)

    # Configuration
    # ----------------------------------------
    def set_level(self, level):
        self.level = level

    def set_comm(self, comm):
        self.comm = comm

    def get_dir(self):
        return self.dir

    def close(self):
        for fmt in self.output_formats:
            fmt.close()

    # Misc
    # ----------------------------------------
    def _do_log(self, args):
        for fmt in self.output_formats:
            if isinstance(fmt, SeqWriter):
                fmt.writeseq(map(str, args))


def get_rank_without_mpi_import():
    # check environment variables here instead of importing mpi4py
    # to avoid calling MPI_Init() when this module is imported
    for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]:
        if varname in os.environ:
            return int(os.environ[varname])
    return 0


def mpi_weighted_mean(comm, local_name2valcount):
    """
    Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
    Perform a weighted average over dicts that are each on a different node
    Input: local_name2valcount: dict mapping key -> (value, count)
    Returns: key -> mean
    """
    all_name2valcount = comm.gather(local_name2valcount)
    if comm.rank == 0:
        name2sum = defaultdict(float)
        name2count = defaultdict(float)
        for n2vc in all_name2valcount:
            for (name, (val, count)) in n2vc.items():
                try:
                    val = float(val)
                except ValueError:
                    if comm.rank == 0:
                        warnings.warn(
                            "WARNING: tried to compute mean on non-float {}={}".format(
                                name, val
                            )
                        )
                else:
                    name2sum[name] += val * count
                    name2count[name] += count
        return {name: name2sum[name] / name2count[name] for name in name2sum}
    else:
        return {}


def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
    """
    If comm is provided, average all numerical stats across that comm
    """
    if dir is None:
        dir = os.getenv("OPENAI_LOGDIR")
    if dir is None:
        dir = osp.join(
            tempfile.gettempdir(),
            datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"),
        )
    assert isinstance(dir, str)
    dir = os.path.expanduser(dir)
    os.makedirs(os.path.expanduser(dir), exist_ok=True)

    rank = get_rank_without_mpi_import()
    if rank > 0:
        log_suffix = log_suffix + "-rank%03i" % rank

    if format_strs is None:
        if rank == 0:
            format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",")
        else:
            format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
    format_strs = filter(None, format_strs)
    output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]

    Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
    if output_formats:
        log("Logging to %s" % dir)


def _configure_default_logger():
    configure()
    Logger.DEFAULT = Logger.CURRENT


def reset():
    if Logger.CURRENT is not Logger.DEFAULT:
        Logger.CURRENT.close()
        Logger.CURRENT = Logger.DEFAULT
        log("Reset logger")


@contextmanager
def scoped_configure(dir=None, format_strs=None, comm=None):
    prevlogger = Logger.CURRENT
    configure(dir=dir, format_strs=format_strs, comm=comm)
    try:
        yield
    finally:
        Logger.CURRENT.close()
        Logger.CURRENT = prevlogger
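A short usage sketch of the logger's module-level API (not part of the commit; the import path assumes this repository's layout):

# Hedged sketch: configure once, then log scalars each iteration.
from models.improved_ddpm import logger

logger.configure(dir="logs/run0", format_strs=["stdout", "csv"])
for step in range(3):
    logger.logkv("step", step)
    logger.logkv_mean("loss", 0.1 * step)  # averaged across repeated calls in one iteration
    logger.dumpkvs()                       # writes and clears the iteration's key-values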
models/improved_ddpm/nn.py
ADDED
@@ -0,0 +1,170 @@
"""
Various utilities for neural networks.
"""

import math

import torch as th
import torch.nn as nn


# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    def forward(self, x):
        return x * th.sigmoid(x)


class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    return nn.Linear(*args, **kwargs)


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def update_ema(target_params, source_params, rate=0.99):
    """
    Update target parameters to be closer to those of source parameters using
    an exponential moving average.

    :param target_params: the target parameter sequence.
    :param source_params: the source parameter sequence.
    :param rate: the EMA rate (closer to 1 means slower).
    """
    for targ, src in zip(target_params, source_params):
        targ.detach().mul_(rate).add_(src, alpha=1 - rate)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def scale_module(module, scale):
    """
    Scale the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().mul_(scale)
    return module


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def normalization(channels):
    """
    Make a standard normalization layer.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(32, channels)


def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = th.exp(
        -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
    if dim % 2:
        embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


class CheckpointFunction(th.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        with th.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with th.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = th.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads
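A quick shape check for `timestep_embedding` (illustrative only; the import path assumes this repository's layout):

# Hedged sketch: the sinusoidal embedding maps [N] timesteps to [N x dim] features.
import torch as th
from models.improved_ddpm.nn import timestep_embedding

t = th.tensor([0.0, 10.0, 999.0])     # timesteps may be fractional
emb = timestep_embedding(t, dim=128)  # cos features in the first half, sin in the second
assert emb.shape == (3, 128)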
models/improved_ddpm/script_util.py
ADDED
@@ -0,0 +1,109 @@
from .unet import UNetModel

NUM_CLASSES = 1000

AFHQ_DICT = dict(
    attention_resolutions="16",
    class_cond=False,
    dropout=0.0,
    image_size=256,
    learn_sigma=True,
    num_channels=128,
    num_head_channels=64,
    num_res_blocks=1,
    resblock_updown=True,
    use_fp16=False,
    use_scale_shift_norm=True,
    num_heads=4,
    num_heads_upsample=-1,
    channel_mult="",
    use_checkpoint=False,
    use_new_attention_order=False,
)


IMAGENET_DICT = dict(
    attention_resolutions="32,16,8",
    class_cond=True,
    image_size=512,
    learn_sigma=True,
    num_channels=256,
    num_head_channels=64,
    num_res_blocks=2,
    resblock_updown=True,
    use_fp16=False,
    use_scale_shift_norm=True,
    dropout=0.0,
    num_heads=4,
    num_heads_upsample=-1,
    channel_mult="",
    use_checkpoint=False,
    use_new_attention_order=False,
)


def create_model(
    image_size,
    num_channels,
    num_res_blocks,
    channel_mult="",
    learn_sigma=False,
    class_cond=False,
    use_checkpoint=False,
    attention_resolutions="16",
    num_heads=1,
    num_head_channels=-1,
    num_heads_upsample=-1,
    use_scale_shift_norm=False,
    dropout=0,
    resblock_updown=False,
    use_fp16=False,
    use_new_attention_order=False,
):
    if channel_mult == "":
        if image_size == 512:
            channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
        elif image_size == 256:
            channel_mult = (1, 1, 2, 2, 4, 4)
        elif image_size == 128:
            channel_mult = (1, 1, 2, 3, 4)
        elif image_size == 64:
            channel_mult = (1, 2, 3, 4)
        else:
            raise ValueError(f"unsupported image size: {image_size}")
    else:
        channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(","))

    attention_ds = []
    for res in attention_resolutions.split(","):
        attention_ds.append(image_size // int(res))

    return UNetModel(
        image_size=image_size,
        in_channels=3,
        model_channels=num_channels,
        out_channels=(3 if not learn_sigma else 6),
        num_res_blocks=num_res_blocks,
        attention_resolutions=tuple(attention_ds),
        dropout=dropout,
        channel_mult=channel_mult,
        num_classes=(NUM_CLASSES if class_cond else None),
        use_checkpoint=use_checkpoint,
        use_fp16=use_fp16,
        num_heads=num_heads,
        num_head_channels=num_head_channels,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        resblock_updown=resblock_updown,
        use_new_attention_order=use_new_attention_order,
    )


def i_DDPM(dataset_name='AFHQ'):
    if dataset_name in ['AFHQ', 'FFHQ']:
        return create_model(**AFHQ_DICT)
    elif dataset_name == 'IMAGENET':
        return create_model(**IMAGENET_DICT)
    else:
        print('Not implemented.')
        exit()
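`i_DDPM` is a thin factory over `create_model`; a hedged sketch of instantiating the presets (the import path assumes this repository's layout):

# Hedged sketch: build the improved-DDPM U-Net from the preset dicts.
from models.improved_ddpm.script_util import i_DDPM, create_model, AFHQ_DICT

model = i_DDPM('AFHQ')              # 256px preset; learn_sigma=True gives 6 output channels
model2 = create_model(**AFHQ_DICT)  # the equivalent explicit call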
models/improved_ddpm/unet.py
ADDED
@@ -0,0 +1,677 @@
"""
Codebase for "Improved Denoising Diffusion Probabilistic Models".
"""

from abc import abstractmethod

import math

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F

from .fp16_util import convert_module_to_f16, convert_module_to_f32
from .nn import (
    checkpoint,
    conv_nd,
    linear,
    avg_pool_nd,
    zero_module,
    normalization,
    timestep_embedding,
)


class AttentionPool2d(nn.Module):
    """
    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        self.positional_embedding = nn.Parameter(
            th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5
        )
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        b, c, *_spatial = x.shape
        x = x.reshape(b, c, -1)  # NC(HW)
        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
        x = self.qkv_proj(x)
        x = self.attention(x)
        x = self.c_proj(x)
        return x[:, :, 0]


class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            else:
                x = layer(x)
        return x


class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            x = F.interpolate(
                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
            )
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=1
            )
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)


class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.

    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.

        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )

    def _forward(self, x, emb):
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h


class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        if use_new_attention_order:
            # split qkv before split heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        return checkpoint(self._forward, (x,), self.parameters(), True)

    def _forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)


def count_flops_attn(model, _x, y):
    """
    A counter for the `thop` package to count the operations in an
    attention operation.
    Meant to be used like:
        macs, params = thop.profile(
            model,
            inputs=(inputs, timestamps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )
    """
    b, c, *spatial = y[0].shape
    num_spatial = int(np.prod(spatial))
    # We perform two matmuls with the same number of ops.
    # The first computes the weight matrix, the second computes
    # the combination of the value vectors.
    matmul_ops = 2 * b * (num_spatial ** 2) * c
    model.total_ops += th.DoubleTensor([matmul_ops])


class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.

    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_heads_channels: if specified, ignore num_heads and instead use
                               a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
                               of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
                                    increased efficiency.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)

        ch = input_ch = int(channel_mult[0] * model_channels)
        self.input_blocks = nn.ModuleList(
            [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
        )
        self._feature_size = ch
        input_block_chans = [ch]
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=int(mult * model_channels),
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = int(mult * model_channels)
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                use_new_attention_order=use_new_attention_order,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=int(model_channels * mult),
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = int(model_channels * mult)
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads_upsample,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                if level and i == num_res_blocks:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
        )

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.output_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    def forward(self, x, timesteps, y=None, ref_img=None, edit_h=None):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        # assert (y is not None) == (
        #     self.num_classes is not None
        # ), "must specify y if and only if the model is class-conditional"

        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        # if self.num_classes is not None:
        #     assert y.shape == (x.shape[0],)
        #     emb = emb + self.label_emb(y)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)
        h = self.middle_block(h, emb)
        mid_h = h.detach().clone()  # get the bottleneck h-space embedding
        # print("check Unet:", mid_h.size())  # [1, 512, 8, 8]
        if edit_h is not None:
            h = edit_h

        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb)
        h = h.type(x.dtype)
        # print("check U-NET output:", h.size(), mid_h.size(), self.out(h).size())  # [1, 3, 256, 256]

        return mid_h, self.out(h)
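Note the non-standard `forward` signature above: it returns the bottleneck activation `mid_h` alongside the prediction, and `edit_h` lets the caller substitute an edited bottleneck. A hedged sketch of one such call (the AFHQ preset and import path are assumptions; shapes follow the in-code comments):

# Hedged sketch: expose the h-space, perturb it, and decode from the edited bottleneck.
import torch as th
from models.improved_ddpm.script_util import i_DDPM

model = i_DDPM('AFHQ').eval()
x_t = th.randn(1, 3, 256, 256)  # stand-in for a noisy image at some diffusion step
t = th.tensor([500])
with th.no_grad():
    mid_h, out = model(x_t, t)                    # mid_h: [1, 512, 8, 8] per the comments above
    edited = mid_h + 0.1 * th.randn_like(mid_h)   # stand-in for a semantic h-space edit
    _, out_edited = model(x_t, t, edit_h=edited)  # decode with the substituted bottleneck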
models/insight_face/__init__.py
ADDED
File without changes
models/insight_face/helpers.py
ADDED
@@ -0,0 +1,178 @@
from collections import namedtuple
import torch
from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module

"""
ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
"""


class Conv_block(Module):
    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(Conv_block, self).__init__()
        self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups, stride=stride, padding=padding, bias=False)
        self.bn = BatchNorm2d(out_c)
        self.prelu = PReLU(out_c)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.prelu(x)
        return x


class Linear_block(Module):
    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(Linear_block, self).__init__()
        self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups, stride=stride, padding=padding, bias=False)
        self.bn = BatchNorm2d(out_c)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


class Depth_Wise(Module):
    def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
        super(Depth_Wise, self).__init__()
        self.conv = Conv_block(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        self.conv_dw = Conv_block(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride)
        self.project = Linear_block(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        self.residual = residual

    def forward(self, x):
        if self.residual:
            short_cut = x
        x = self.conv(x)
        x = self.conv_dw(x)
        x = self.project(x)
        if self.residual:
            output = short_cut + x
        else:
            output = x
        return output


class Residual(Module):
    def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
        super(Residual, self).__init__()
        modules = []
        for _ in range(num_block):
            modules.append(Depth_Wise(c, c, residual=True, kernel=kernel, padding=padding, stride=stride, groups=groups))
        self.model = Sequential(*modules)

    def forward(self, x):
        return self.model(x)


######################################################################################


class Flatten(Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


def l2_norm(input, axis=1):
    norm = torch.norm(input, 2, axis, True)
    output = torch.div(input, norm)
    return output


class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
    """ A named tuple describing a ResNet block. """


def get_block(in_channel, depth, num_units, stride=2):
    return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]


def get_blocks(num_layers):
    if num_layers == 50:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=4),
            get_block(in_channel=128, depth=256, num_units=14),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 100:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=13),
            get_block(in_channel=128, depth=256, num_units=30),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 152:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=8),
            get_block(in_channel=128, depth=256, num_units=36),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    else:
        raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
    return blocks


class SEModule(Module):
    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        self.avg_pool = AdaptiveAvgPool2d(1)
        self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
        self.relu = ReLU(inplace=True)
        self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
        self.sigmoid = Sigmoid()

    def forward(self, x):
        module_input = x
        x = self.avg_pool(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return module_input * x


class bottleneck_IR(Module):
    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth)
            )
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
        )

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut


class bottleneck_IR_SE(Module):
    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR_SE, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth)
            )
        self.res_layer = Sequential(
|
167 |
+
BatchNorm2d(in_channel),
|
168 |
+
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
|
169 |
+
PReLU(depth),
|
170 |
+
Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
|
171 |
+
BatchNorm2d(depth),
|
172 |
+
SEModule(depth, 16)
|
173 |
+
)
|
174 |
+
|
175 |
+
def forward(self, x):
|
176 |
+
shortcut = self.shortcut_layer(x)
|
177 |
+
res = self.res_layer(x)
|
178 |
+
return res + shortcut
|
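The block definitions above are consumed through get_blocks; a minimal sketch (not part of the upload) of what the IR-50 layout looks like, assuming this file is importable as models.insight_face.helpers:

# Illustrative sketch: inspect the IR-50 stage layout built from Bottleneck tuples.
from models.insight_face.helpers import get_blocks

blocks = get_blocks(50)
for stage, block in enumerate(blocks):
    # each entry is a Bottleneck(in_channel, depth, stride) namedtuple;
    # the first unit of a stage downsamples (stride 2), the rest keep stride 1
    print(f"stage {stage}: {len(block)} units, first={block[0]}")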
models/insight_face/model_irse.py
ADDED
@@ -0,0 +1,124 @@
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
from models.insight_face.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
from models.insight_face.helpers import Conv_block, Linear_block, Depth_Wise, Residual
"""
Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
"""


class MobileFaceNet(Module):
    def __init__(self, embedding_size):
        super(MobileFaceNet, self).__init__()
        self.conv1 = Conv_block(3, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1))
        self.conv2_dw = Conv_block(64, 64, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
        self.conv_23 = Depth_Wise(64, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128)
        self.conv_3 = Residual(64, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv_34 = Depth_Wise(64, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256)
        self.conv_4 = Residual(128, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv_45 = Depth_Wise(128, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512)
        self.conv_5 = Residual(128, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv_6_sep = Conv_block(128, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
        self.conv_6_dw = Linear_block(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0))
        self.conv_6_flatten = Flatten()
        self.linear = Linear(512, embedding_size, bias=False)
        self.bn = BatchNorm1d(embedding_size)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2_dw(out)
        out = self.conv_23(out)
        out = self.conv_3(out)
        out = self.conv_34(out)
        out = self.conv_4(out)
        out = self.conv_45(out)
        out = self.conv_5(out)
        out = self.conv_6_sep(out)
        out = self.conv_6_dw(out)
        out = self.conv_6_flatten(out)
        out = self.linear(out)
        out = self.bn(out)
        return l2_norm(out)


######################################################################################

class Backbone(Module):
    def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
        super(Backbone, self).__init__()
        assert input_size in [112, 224], "input_size should be 112 or 224"
        assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
        assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE
        self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
                                      BatchNorm2d(64),
                                      PReLU(64))
        if input_size == 112:
            self.output_layer = Sequential(BatchNorm2d(512),
                                           Dropout(drop_ratio),
                                           Flatten(),
                                           Linear(512 * 7 * 7, 512),
                                           BatchNorm1d(512, affine=affine))
        else:
            self.output_layer = Sequential(BatchNorm2d(512),
                                           Dropout(drop_ratio),
                                           Flatten(),
                                           Linear(512 * 14 * 14, 512),
                                           BatchNorm1d(512, affine=affine))

        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(unit_module(bottleneck.in_channel,
                                           bottleneck.depth,
                                           bottleneck.stride))
        self.body = Sequential(*modules)

    def forward(self, x):
        x = self.input_layer(x)
        x = self.body(x)
        x = self.output_layer(x)
        return l2_norm(x)


def IR_50(input_size):
    """Constructs a ir-50 model."""
    model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
    return model


def IR_101(input_size):
    """Constructs a ir-101 model."""
    model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
    return model


def IR_152(input_size):
    """Constructs a ir-152 model."""
    model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
    return model


def IR_SE_50(input_size):
    """Constructs a ir_se-50 model."""
    model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
    return model


def IR_SE_101(input_size):
    """Constructs a ir_se-101 model."""
    model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
    return model


def IR_SE_152(input_size):
    """Constructs a ir_se-152 model."""
    model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
    return model
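A short usage sketch (not part of the upload) of the factory functions above, extracting the 512-d identity embedding the way the identity loss in losses/id_loss.py would; the checkpoint path is a placeholder assumption:

# Illustrative sketch: IR-SE-50 as an identity-embedding extractor.
import torch
from models.insight_face.model_irse import IR_SE_50

facenet = IR_SE_50(input_size=112)
# facenet.load_state_dict(torch.load('pretrained/model_ir_se50.pth'))  # hypothetical path
facenet.eval()

with torch.no_grad():
    face = torch.randn(1, 3, 112, 112)   # stands in for a 112x112 aligned face crop
    embedding = facenet(face)            # L2-normalized, shape (1, 512)
print(embedding.shape, embedding.norm(dim=1))  # norms are 1 by construction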
requirements.txt
ADDED
@@ -0,0 +1,10 @@
cmake>=3.22.3
lmdb>=1.2.1
numpy>=1.19.2
Pillow>=8.4.0
PyYAML>=6.0
tqdm>=4.55.1
opencv_python>=4.5.2.52
ftfy>=6.0.3
regex>=2021.10.23
dlib>=19.22.1
utils/align_utils.py
ADDED
@@ -0,0 +1,213 @@
"""
brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset)
author: lzhbrian (https://lzhbrian.me)
date: 2020.1.5
note: code is heavily borrowed from
    https://github.com/NVlabs/ffhq-dataset
    http://dlib.net/face_landmark_detection.py.html

requirements:
    apt install cmake
    conda install Pillow numpy scipy
    pip install dlib
    # download face landmark model from:
    # http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
"""
from argparse import ArgumentParser
import time
import numpy as np
import PIL
import PIL.Image
import os
import scipy
import scipy.ndimage
import dlib
import multiprocessing as mp
import math

from configs.paths_config import MODEL_PATHS

SHAPE_PREDICTOR_PATH = MODEL_PATHS["shape_predictor"]


def run_alignment(image_path, output_size):
    if not os.path.exists("pretrained/shape_predictor_68_face_landmarks.dat"):
        print('Downloading files for aligning face image...')
        os.system('wget -P pretrained/ http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2')
        os.system('bzip2 -dk pretrained/shape_predictor_68_face_landmarks.dat.bz2')
        print('Done.')
    predictor = dlib.shape_predictor("pretrained/shape_predictor_68_face_landmarks.dat")
    aligned_image = align_face(filepath=image_path, predictor=predictor, output_size=output_size, transform_size=output_size)
    print("Aligned image has shape: {}".format(aligned_image.size))
    return aligned_image


def get_landmark(filepath, predictor):
    """get landmark with dlib
    :return: np.array shape=(68, 2)
    """
    detector = dlib.get_frontal_face_detector()

    img = dlib.load_rgb_image(filepath)
    dets = detector(img, 1)

    # NOTE: assumes at least one face is detected; landmarks come from the last detection
    for k, d in enumerate(dets):
        shape = predictor(img, d)

    t = list(shape.parts())
    a = []
    for tt in t:
        a.append([tt.x, tt.y])
    lm = np.array(a)
    return lm


def align_face(filepath, predictor, output_size=256, transform_size=256):
    """
    :param filepath: str
    :return: PIL Image
    """

    lm = get_landmark(filepath, predictor)

    lm_chin = lm[0: 17]  # left-right
    lm_eyebrow_left = lm[17: 22]  # left-right
    lm_eyebrow_right = lm[22: 27]  # left-right
    lm_nose = lm[27: 31]  # top-down
    lm_nostrils = lm[31: 36]  # top-down
    lm_eye_left = lm[36: 42]  # left-clockwise
    lm_eye_right = lm[42: 48]  # left-clockwise
    lm_mouth_outer = lm[48: 60]  # left-clockwise
    lm_mouth_inner = lm[60: 68]  # left-clockwise

    # Calculate auxiliary vectors.
    eye_left = np.mean(lm_eye_left, axis=0)
    eye_right = np.mean(lm_eye_right, axis=0)
    eye_avg = (eye_left + eye_right) * 0.5
    eye_to_eye = eye_right - eye_left
    mouth_left = lm_mouth_outer[0]
    mouth_right = lm_mouth_outer[6]
    mouth_avg = (mouth_left + mouth_right) * 0.5
    eye_to_mouth = mouth_avg - eye_avg

    # Choose oriented crop rectangle.
    x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
    x /= np.hypot(*x)
    x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
    y = np.flipud(x) * [-1, 1]
    c = eye_avg + eye_to_mouth * 0.1
    quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
    qsize = np.hypot(*x) * 2

    # read image
    img = PIL.Image.open(filepath)
    enable_padding = True

    # Shrink.
    shrink = int(np.floor(qsize / output_size * 0.5))
    if shrink > 1:
        rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
        img = img.resize(rsize, PIL.Image.ANTIALIAS)
        quad /= shrink
        qsize /= shrink

    # Crop.
    border = max(int(np.rint(qsize * 0.1)), 3)
    crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
            int(np.ceil(max(quad[:, 1]))))
    crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
            min(crop[3] + border, img.size[1]))
    if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
        img = img.crop(crop)
        quad -= crop[0:2]

    # Pad.
    pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
           int(np.ceil(max(quad[:, 1]))))
    pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
           max(pad[3] - img.size[1] + border, 0))
    if enable_padding and max(pad) > border - 4:
        pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
        img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
        h, w, _ = img.shape
        y, x, _ = np.ogrid[:h, :w, :1]
        mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
                          1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))
        blur = qsize * 0.02
        img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
        img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
        img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
        quad += pad[:2]

    # Transform.
    img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
    if output_size < transform_size:
        img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)

    # Save aligned image.
    return img


def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]


def extract_on_paths(file_paths):
    predictor = dlib.shape_predictor(SHAPE_PREDICTOR_PATH)
    pid = mp.current_process().name
    print('\t{} is starting to extract on #{} images'.format(pid, len(file_paths)))
    tot_count = len(file_paths)
    count = 0
    for file_path, res_path in file_paths:
        count += 1
        if count % 100 == 0:
            print('{} done with {}/{}'.format(pid, count, tot_count))
        try:
            res = align_face(file_path, predictor)
            res = res.convert('RGB')
            os.makedirs(os.path.dirname(res_path), exist_ok=True)
            res.save(res_path)
        except Exception:
            continue
    print('\tDone!')


def parse_args():
    parser = ArgumentParser(add_help=False)
    parser.add_argument('--num_threads', type=int, default=1)
    parser.add_argument('--root_path', type=str, default='')
    args = parser.parse_args()
    return args


def run(args):
    root_path = args.root_path
    out_crops_path = root_path + '_crops'
    if not os.path.exists(out_crops_path):
        os.makedirs(out_crops_path, exist_ok=True)

    file_paths = []
    for root, dirs, files in os.walk(root_path):
        for file in files:
            file_path = os.path.join(root, file)
            fname = os.path.join(out_crops_path, os.path.relpath(file_path, root_path))
            res_path = '{}.jpg'.format(os.path.splitext(fname)[0])
            if os.path.splitext(file_path)[1] == '.txt' or os.path.exists(res_path):
                continue
            file_paths.append((file_path, res_path))

    file_chunks = list(chunks(file_paths, int(math.ceil(len(file_paths) / args.num_threads))))
    print(len(file_chunks))
    pool = mp.Pool(args.num_threads)
    print('Running on {} paths\nHere we goooo'.format(len(file_paths)))
    tic = time.time()
    pool.map(extract_on_paths, file_chunks)
    toc = time.time()
    print('Mischief managed in {}s'.format(toc - tic))


if __name__ == '__main__':
    args = parse_args()
    run(args)
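A usage sketch (not part of the upload) for the run_alignment entry point above, using the sample image shipped in imgs/; the output size is assumed to match the diffusion model's resolution:

# Illustrative sketch: align a single portrait before inversion.
from utils.align_utils import run_alignment

aligned = run_alignment('imgs/img1.jpg', output_size=256)  # returns a PIL image
aligned.save('imgs/img1_aligned.jpg')  # hypothetical output path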
utils/celeba_attr.txt
ADDED
@@ -0,0 +1,40 @@
5_o_Clock_Shadow
Arched_Eyebrows
Attractive
Bags_Under_Eyes
Bald
Bangs
Big_Lips
Big_Nose
Black_Hair
Blond_Hair
Blurry
Brown_Hair
Bushy_Eyebrows
Chubby
Double_Chin
Eyeglasses
Goatee
Gray_Hair
Heavy_Makeup
High_Cheekbones
Male
Mouth_Slightly_Open
Mustache
Narrow_Eyes
No_Beard
Oval_Face
Pale_Skin
Pointy_Nose
Receding_Hairline
Rosy_Cheeks
Sideburns
Smiling
Straight_Hair
Wavy_Hair
Wearing_Earrings
Wearing_Hat
Wearing_Lipstick
Wearing_Necklace
Wearing_Necktie
Young
utils/colab_utils.py
ADDED
@@ -0,0 +1,36 @@
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
import os


class GoogleDrive_Dowonloader(object):
    def __init__(self, use_pydrive):
        self.use_pydrive = use_pydrive

        if self.use_pydrive:
            self.authenticate()

    def authenticate(self):
        auth.authenticate_user()
        gauth = GoogleAuth()
        gauth.credentials = GoogleCredentials.get_application_default()
        self.drive = GoogleDrive(gauth)

    def ensure_file_exists(self, file_id, file_dst):
        if not os.path.isfile(file_dst):
            if self.use_pydrive:
                print(f'Downloading {file_dst} ...')
                downloaded = self.drive.CreateFile({'id': file_id})
                downloaded.FetchMetadata(fetch_all=True)
                downloaded.GetContentFile(file_dst)
                print('Finished')
            else:
                from gdown import download as drive_download
                drive_download(f'https://drive.google.com/uc?id={file_id}', file_dst, quiet=False)
        else:
            print(f'{file_dst} exists.')
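A usage sketch (not part of the upload) for fetching a pretrained checkpoint inside Colab; the Drive file id and destination path are placeholders, not real values:

# Illustrative sketch: download a checkpoint once, skip if it already exists.
from utils.colab_utils import GoogleDrive_Dowonloader

downloader = GoogleDrive_Dowonloader(use_pydrive=True)  # authenticates in Colab
downloader.ensure_file_exists('FILE_ID_PLACEHOLDER', 'pretrained/celeba_hq.ckpt')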
utils/diffusion_utils.py
ADDED
@@ -0,0 +1,134 @@
import numpy as np
import torch


def get_beta_schedule(*, beta_start, beta_end, num_diffusion_timesteps):
    betas = np.linspace(beta_start, beta_end,
                        num_diffusion_timesteps, dtype=np.float64)
    assert betas.shape == (num_diffusion_timesteps,)
    return betas


def extract(a, t, x_shape):
    """Extract coefficients from a based on t and reshape to make it
    broadcastable with x_shape."""
    bs, = t.shape
    assert x_shape[0] == bs
    out = torch.gather(torch.tensor(a, dtype=torch.float, device=t.device), 0, t.long())
    assert out.shape == (bs,)
    out = out.reshape((bs,) + (1,) * (len(x_shape) - 1))
    return out


def denoising_step(xt, t, t_next, *,
                   models,
                   logvars,
                   b,
                   sampling_type='ddpm',
                   eta=0.0,
                   learn_sigma=False,
                   hybrid=False,
                   hybrid_config=None,
                   ratio=1.0,
                   out_x0_t=False,
                   edit_h=None,
                   ):

    # Compute noise and variance
    mid_h = None  # only set when a single model returning (mid_h, et) is used
    if not isinstance(models, list):
        model = models
        if edit_h is None:
            mid_h, et = model(xt, t)
        else:
            # denoising step conditioned on an edited bottleneck feature
            mid_h, et = model(xt, t, edit_h)
        if learn_sigma:
            et, logvar_learned = torch.split(et, et.shape[1] // 2, dim=1)
            logvar = logvar_learned
        else:
            logvar = extract(logvars, t, xt.shape)
    else:
        if not hybrid:
            et = 0
            logvar = 0
            if ratio != 0.0:
                et_i = ratio * models[1](xt, t)
                if learn_sigma:
                    et_i, logvar_learned = torch.split(et_i, et_i.shape[1] // 2, dim=1)
                    logvar += logvar_learned
                else:
                    logvar += ratio * extract(logvars, t, xt.shape)
                et += et_i

            if ratio != 1.0:
                et_i = (1 - ratio) * models[0](xt, t)
                if learn_sigma:
                    et_i, logvar_learned = torch.split(et_i, et_i.shape[1] // 2, dim=1)
                    logvar += logvar_learned
                else:
                    logvar += (1 - ratio) * extract(logvars, t, xt.shape)
                et += et_i

        else:
            for thr in list(hybrid_config.keys()):
                if t.item() >= thr:
                    et = 0
                    logvar = 0
                    for i, ratio in enumerate(hybrid_config[thr]):
                        ratio /= sum(hybrid_config[thr])
                        et_i = models[i + 1](xt, t)
                        if learn_sigma:
                            et_i, logvar_learned = torch.split(et_i, et_i.shape[1] // 2, dim=1)
                            logvar_i = logvar_learned
                        else:
                            logvar_i = extract(logvars, t, xt.shape)
                        et += ratio * et_i
                        logvar += ratio * logvar_i
                    break

    # Compute the next x
    bt = extract(b, t, xt.shape)
    at = extract((1.0 - b).cumprod(dim=0), t, xt.shape)

    if t_next.sum() == -t_next.shape[0]:
        at_next = torch.ones_like(at)
    else:
        at_next = extract((1.0 - b).cumprod(dim=0), t_next, xt.shape)

    xt_next = torch.zeros_like(xt)
    x0_t = None  # only computed in the DDIM branch
    if sampling_type == 'ddpm':
        weight = bt / torch.sqrt(1 - at)

        mean = 1 / torch.sqrt(1.0 - bt) * (xt - weight * et)
        noise = torch.randn_like(xt)
        mask = 1 - (t == 0).float()
        mask = mask.reshape((xt.shape[0],) + (1,) * (len(xt.shape) - 1))
        xt_next = mean + mask * torch.exp(0.5 * logvar) * noise
        xt_next = xt_next.float()

    elif sampling_type == 'ddim':
        x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
        if eta == 0:
            xt_next = at_next.sqrt() * x0_t + (1 - at_next).sqrt() * et
        elif at > (at_next):
            print('Inversion process is only possible with eta = 0')
            raise ValueError
        else:
            c1 = eta * ((1 - at / (at_next)) * (1 - at_next) / (1 - at)).sqrt()
            c2 = ((1 - at_next) - c1 ** 2).sqrt()
            xt_next = at_next.sqrt() * x0_t + c2 * et + c1 * torch.randn_like(xt)

    if out_x0_t:
        return xt_next, x0_t, mid_h
    else:
        return xt_next, mid_h
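A sketch (not part of the upload) of how denoising_step drives a deterministic DDIM trajectory. It assumes the repo's modified U-Net, which returns a (mid_h, et) pair, and that betas and logvars are torch tensors prepared by the runner (see main.py / boundarydiffusion.py); seq is a coarse timestep schedule:

# Illustrative sketch: deterministic DDIM generation built on denoising_step.
import torch
from utils.diffusion_utils import denoising_step

@torch.no_grad()
def ddim_generate(model, x_T, betas, logvars, n_steps=40):
    seq = list(range(0, len(betas), len(betas) // n_steps))
    x = x_T
    # pair each timestep with its predecessor; t_next = -1 means fully denoised,
    # matching the t_next.sum() == -t_next.shape[0] check in denoising_step
    for i, j in zip(reversed(seq), reversed([-1] + seq[:-1])):
        t = torch.full((x.size(0),), i, device=x.device)
        t_next = torch.full((x.size(0),), j, device=x.device)
        # eta=0 gives the deterministic DDIM update; mid_h is the U-Net
        # bottleneck feature that the boundary-based editing operates on
        x, mid_h = denoising_step(x, t, t_next, models=model, logvars=logvars,
                                  b=betas, sampling_type='ddim', eta=0.0)
    return x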
utils/prepare_lmdb_data.py
ADDED
@@ -0,0 +1,140 @@
"""
Refer to https://github.com/rosinality/stylegan2-pytorch/blob/master/prepare_data.py
"""

import argparse
from io import BytesIO
import multiprocessing
from functools import partial
import os, glob, sys

from PIL import Image
import lmdb
from tqdm import tqdm
from torchvision import datasets
from torchvision.transforms import functional as trans_fn


def resize_and_convert(img, size, resample, quality=100):
    img = trans_fn.resize(img, (size, size), resample)
    # img = trans_fn.center_crop(img, size)
    buffer = BytesIO()
    img.save(buffer, format="jpeg", quality=quality)
    val = buffer.getvalue()

    return val


def resize_multiple(
    img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100
):
    imgs = []

    for size in sizes:
        imgs.append(resize_and_convert(img, size, resample, quality))

    return imgs


def resize_worker(img_file, sizes, resample):
    i, file, img_id = img_file
    img = Image.open(file)
    img = img.convert("RGB")
    out = resize_multiple(img, sizes=sizes, resample=resample)

    return i, out, img_id


def file_to_list(filename):
    with open(filename, encoding='utf-8') as f:
        files = f.readlines()
    files = [f.rstrip() for f in files]
    return files


def prepare(
    env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS
):
    resize_fn = partial(resize_worker, sizes=sizes, resample=resample)
    files = sorted(dataset.imgs, key=lambda x: x[0])
    files = [(i, file, file.split('/')[-1].split('.')[0]) for i, (file, label) in enumerate(files)]
    total = 0

    with multiprocessing.Pool(n_worker) as pool:
        for i, imgs, img_id in tqdm(pool.imap_unordered(resize_fn, files)):
            key_label = f"{str(i).zfill(5)}".encode("utf-8")
            for size, img in zip(sizes, imgs):
                key = f"{size}-{str(i).zfill(5)}".encode("utf-8")
                with env.begin(write=True) as txn:
                    txn.put(key, img)
                    txn.put(key_label, str(img_id).encode("utf-8"))

            total += 1

    with env.begin(write=True) as txn:
        txn.put("length".encode("utf-8"), str(total).encode("utf-8"))


def prepare_attr(
    env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, label_attr='gender'
):
    resize_fn = partial(resize_worker, sizes=sizes, resample=resample)
    files = sorted(dataset.imgs, key=lambda x: x[0])
    attr_file_path = '/n/fs/yz-diff/inversion/list_attr_celeba.txt'
    labels = file_to_list(attr_file_path)
    attr_dict = {}
    files_attr = []
    for i, (file, split) in enumerate(files):
        img_id = int(file.split('/')[-1].split('.')[0])
        attr_label = labels[img_id - 1].split()
        # attribute index 21 in list_attr_celeba.txt corresponds to 'Male'
        label = int(attr_label[21])
        files_attr.append((i, file, label))

    files = files_attr
    total = 0

    with multiprocessing.Pool(n_worker) as pool:
        for i, imgs, label in tqdm(pool.imap_unordered(resize_fn, files)):
            for size, img in zip(sizes, imgs):
                key = f"{size}-{str(i).zfill(5)}".encode("utf-8")
                key_label = f"{'label'}-{str(i).zfill(5)}".encode("utf-8")

                with env.begin(write=True) as txn:
                    txn.put(key, img)
                    txn.put(key_label, str(label).encode("utf-8"))

            total += 1

    with env.begin(write=True) as txn:
        txn.put("length".encode("utf-8"), str(total).encode("utf-8"))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--out", type=str)
    parser.add_argument("--size", type=str, default="128,256,512,1024")
    parser.add_argument("--n_worker", type=int, default=5)
    parser.add_argument("--resample", type=str, default="bilinear")
    parser.add_argument("--attr", type=str)
    parser.add_argument("path", type=str)

    args = parser.parse_args()

    resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR}
    resample = resample_map[args.resample]

    sizes = [int(s.strip()) for s in args.size.split(",")]
    print("Make dataset of image sizes:", ", ".join(str(s) for s in sizes))

    imgset = datasets.ImageFolder(args.path)

    with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env:
        prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample)
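A read-back sketch (not part of the upload) matching the "{size}-{index:05d}" key scheme written above; the database path is a placeholder assumption:

# Illustrative sketch: read one record back from the prepared LMDB.
from io import BytesIO

import lmdb
from PIL import Image

env = lmdb.open('celeba_hq_lmdb', readonly=True, lock=False)  # hypothetical path
with env.begin(write=False) as txn:
    length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))
    img_bytes = txn.get('256-00000'.encode('utf-8'))  # size 256, first image
img = Image.open(BytesIO(img_bytes))
print(length, img.size)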
utils/text_dic.py
ADDED
@@ -0,0 +1,123 @@
SRC_TRG_TXT_DIC = {
    # Human face
    'tanned': (['face'], ['tanned face']),
    'pale': (['face'], ['pale face']),
    'makeup': (['person'], ['person with makeup']),
    'no_makeup': (['person'], ['person without makeup']),
    'old': (['person'], ['old person']),
    'young': (['person'], ['young person']),
    'beards': (['person'], ['person with beards']),
    'angry': (['face'], ['angry face']),
    'surprised': (['face'], ['surprised face']),
    'smiling': (['face'], ['smiling face']),
    'blond_hair': (['person'], ['person with blond hair']),
    'red_hair': (['person'], ['person with red hair']),
    'grey_hair': (['person'], ['person with grey hair']),
    'curly_hair': (['person'], ['person with curly hair']),

    'nicolas': (['Person'], ['Nicolas Cage']),
    'zuckerberg': (['Person'], ['Mark Zuckerberg']),
    'benedict': (['Person'], ['Benedict Cumberbatch']),
    'gogh': (['photo'], ['painting by Gogh']),
    'frida': (['photo'], ['self-portrait by Frida Kahlo']),
    'modigliani': (['photo'], ['Painting in Modigliani style']),
    'sketch': (['photo'], ['sketch']),
    'watercolor': (['photo'], ['Watercolor Art with Thick Brushstrokes']),
    'elf': (['Human'], ['Tolkien elf']),
    'super_saiyan': (['Human'], ['Super saiyan']),
    'pixar': (['Human'], ['3D render in the style of Pixar']),
    'neanderthal': (['Human'], ['Neanderthal']),
    'zombie': (['Human'], ['Zombie']),
    'jocker': (['Human'], ['The Joker']),

    # Dog face
    'dog_nicolas': (['Dog'], ['Nicolas Cage']),
    'dog_yorkshire': (['Dog'], ['Yorkshire Terrier']),
    'dog_smiling': (['Dog'], ['Smiling Dog']),
    'dog_zombie': (['Dog'], ['Zombie']),
    'dog_super_saiyan': (['Dog'], ['Super saiyan']),
    'dog_venom': (['Dog'], ['Venom']),
    'dog_bear': (['Dog'], ['Bear']),
    'dog_fox': (['Dog'], ['Fox']),
    'dog_wolf': (['Dog'], ['Wolf']),
    'dog_hamster': (['Dog'], ['Hamster']),

    # Church
    'church_snow': (['Church'], ['Snow Covered Church']),
    'church_night': (['Church'], ['Church at night']),
    'church_red_brick': (['Church'], ['Red brick wall Church']),
    'church_golden': (['Church'], ['Golden Church']),
    'church_wooden_house': (['Church'], ['Wooden House']),
    'church_gothic': (['Church'], ['Gothic Church']),
    'church_ancient_tower': (['Church'], ['Ancient traditional Asian tower']),
    'church_temple': (['Church'], ['Temple']),
    'church_factory': (['church'], ['factory with chimneys']),
    'church_department_store': (['church'], ['department store']),

    # Bedroom
    'bedroom_blue': (['Bedroom'], ['Blue tone Bedroom']),
    'bedroom_green': (['Bedroom'], ['Green tone Bedroom']),
    'bedroom_golden': (['Bedroom'], ['Golden Bedroom']),
    'bedroom_princess': (['Bedroom'], ['Princess Bedroom']),
    'bedroom_palace': (['Bedroom'], ['Palace Bedroom']),
    'bedroom_wooden': (['Bedroom'], ['Wooden Bedroom']),
}
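A lookup sketch (not part of the upload); each entry maps an edit name to a (source prompts, target prompts) pair intended for the CLIP-guided editing loss (losses/clip_loss.py):

# Illustrative sketch: fetch the prompt pair for a named edit.
from utils.text_dic import SRC_TRG_TXT_DIC

src_txts, trg_txts = SRC_TRG_TXT_DIC['smiling']
print(src_txts, '->', trg_txts)  # ['face'] -> ['smiling face']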
utils/text_templates.py
ADDED
@@ -0,0 +1,129 @@
imagenet_templates = [
    'a bad photo of a {}.',
    'a sculpture of a {}.',
    'a photo of the hard to see {}.',
    'a low resolution photo of the {}.',
    'a rendering of a {}.',
    'graffiti of a {}.',
    'a bad photo of the {}.',
    'a cropped photo of the {}.',
    'a tattoo of a {}.',
    'the embroidered {}.',
    'a photo of a hard to see {}.',
    'a bright photo of a {}.',
    'a photo of a clean {}.',
    'a photo of a dirty {}.',
    'a dark photo of the {}.',
    'a drawing of a {}.',
    'a photo of my {}.',
    'the plastic {}.',
    'a photo of the cool {}.',
    'a close-up photo of a {}.',
    'a black and white photo of the {}.',
    'a painting of the {}.',
    'a painting of a {}.',
    'a pixelated photo of the {}.',
    'a sculpture of the {}.',
    'a bright photo of the {}.',
    'a cropped photo of a {}.',
    'a plastic {}.',
    'a photo of the dirty {}.',
    'a jpeg corrupted photo of a {}.',
    'a blurry photo of the {}.',
    'a photo of the {}.',
    'a good photo of the {}.',
    'a rendering of the {}.',
    'a {} in a video game.',
    'a photo of one {}.',
    'a doodle of a {}.',
    'a close-up photo of the {}.',
    'a photo of a {}.',
    'the origami {}.',
    'the {} in a video game.',
    'a sketch of a {}.',
    'a doodle of the {}.',
    'a origami {}.',
    'a low resolution photo of a {}.',
    'the toy {}.',
    'a rendition of the {}.',
    'a photo of the clean {}.',
    'a photo of a large {}.',
    'a rendition of a {}.',
    'a photo of a nice {}.',
    'a photo of a weird {}.',
    'a blurry photo of a {}.',
    'a cartoon {}.',
    'art of a {}.',
    'a sketch of the {}.',
    'a embroidered {}.',
    'a pixelated photo of a {}.',
    'itap of the {}.',
    'a jpeg corrupted photo of the {}.',
    'a good photo of a {}.',
    'a plushie {}.',
    'a photo of the nice {}.',
    'a photo of the small {}.',
    'a photo of the weird {}.',
    'the cartoon {}.',
    'art of the {}.',
    'a drawing of the {}.',
    'a photo of the large {}.',
    'a black and white photo of a {}.',
    'the plushie {}.',
    'a dark photo of a {}.',
    'itap of a {}.',
    'graffiti of the {}.',
    'a toy {}.',
    'itap of my {}.',
    'a photo of a cool {}.',
    'a photo of a small {}.',
    'a tattoo of the {}.',
]

part_templates = [
    'the paw of a {}.',
    'the nose of a {}.',
    'the eye of the {}.',
    'the ears of a {}.',
    'an eye of a {}.',
    'the tongue of a {}.',
    'the fur of the {}.',
    'colorful {} fur.',
    'a snout of a {}.',
    'the teeth of the {}.',
    'the {}s fangs.',
    'a claw of the {}.',
    'the face of the {}',
    'a neck of a {}',
    'the head of the {}',
]

imagenet_templates_small = [
    'a photo of a {}.',
    'a rendering of a {}.',
    'a cropped photo of the {}.',
    'the photo of a {}.',
    'a photo of a clean {}.',
    'a photo of a dirty {}.',
    'a dark photo of the {}.',
    'a photo of my {}.',
    'a photo of the cool {}.',
    'a close-up photo of a {}.',
    'a bright photo of the {}.',
    'a cropped photo of a {}.',
    'a photo of the {}.',
    'a good photo of the {}.',
    'a photo of one {}.',
    'a close-up photo of the {}.',
    'a rendition of the {}.',
    'a photo of the clean {}.',
    'a rendition of a {}.',
    'a photo of a nice {}.',
    'a good photo of a {}.',
    'a photo of the nice {}.',
    'a photo of the small {}.',
    'a photo of the weird {}.',
    'a photo of the large {}.',
    'a photo of a cool {}.',
    'a photo of a small {}.',
]
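A sketch (not part of the upload) of the usual way such template lists are consumed: average one class name over all prompts to get a robust CLIP text embedding, then take differences as edit directions. Assumes the official clip package is installed:

# Illustrative sketch: template-averaged CLIP text embedding.
import clip
import torch

from utils.text_templates import imagenet_templates

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model, _ = clip.load('ViT-B/32', device=device)

def template_embedding(text):
    prompts = [t.format(text) for t in imagenet_templates]
    tokens = clip.tokenize(prompts).to(device)
    with torch.no_grad():
        feats = model.encode_text(tokens)
    feats = feats / feats.norm(dim=-1, keepdim=True)  # normalize per prompt
    return feats.mean(dim=0, keepdim=True)            # average over templates

# an edit direction in CLIP space, e.g. for the 'smiling' entry in text_dic.py
direction = template_embedding('smiling face') - template_embedding('face')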