szyezhu commited on
Commit
019d164
·
1 Parent(s): 2066db5

Upload 46 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/main.png filter=lfs diff=lfs merge=lfs -text
37
+ assets/non_cherry_picky.png filter=lfs diff=lfs merge=lfs -text
38
+ assets/strength_space.png filter=lfs diff=lfs merge=lfs -text
39
+ assets/teaser.png filter=lfs diff=lfs merge=lfs -text
40
+ assets/unconditional.png filter=lfs diff=lfs merge=lfs -text
assets/geo.png ADDED
assets/geo_white.png ADDED
assets/main.png ADDED

Git LFS Details

  • SHA256: 15eb1da3306b6965c62fb13f978e8074eb2b52bd94ee523b1a4262995b702da0
  • Pointer size: 132 Bytes
  • Size of remote file: 3.92 MB
assets/mixing_traj.png ADDED
assets/mixing_traj_white.png ADDED
assets/non_cherry_picky.png ADDED

Git LFS Details

  • SHA256: 9717a8a7412b4da6a76f13993d9422db7bb76afc24ca2f87c82956da4c177da8
  • Pointer size: 132 Bytes
  • Size of remote file: 9.45 MB
assets/strength_space.png ADDED

Git LFS Details

  • SHA256: 6be7e9920de0585cd4e246c9246e1e71b94b673b23d7b86dfe436f84bbf66ad7
  • Pointer size: 132 Bytes
  • Size of remote file: 2.27 MB
assets/teaser.png ADDED

Git LFS Details

  • SHA256: e8fc6904459ed380844bb2d00f8a9b0a6c34c15b575ed0c939eaeb9fe7d47692
  • Pointer size: 132 Bytes
  • Size of remote file: 1.32 MB
assets/unconditional.png ADDED

Git LFS Details

  • SHA256: d3cc39beb6510e1b0316b3f93cb9e4880c1a4c2b1cfe12e1d6707470663943f7
  • Pointer size: 132 Bytes
  • Size of remote file: 4.79 MB
boundarydiffusion.py ADDED
@@ -0,0 +1,713 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from glob import glob
3
+ from tqdm import tqdm
4
+ import os
5
+ import numpy as np
6
+ import cv2
7
+ from PIL import Image
8
+ import torch
9
+ from torch import nn
10
+ import torchvision.utils as tvu
11
+ from sklearn import svm
12
+ import pickle
13
+ import torch.optim as optim
14
+
15
+ from models.ddpm.diffusion import DDPM
16
+ from models.improved_ddpm.script_util import i_DDPM
17
+ from utils.text_dic import SRC_TRG_TXT_DIC
18
+ from utils.diffusion_utils import get_beta_schedule, denoising_step
19
+ from datasets.data_utils import get_dataset, get_dataloader
20
+ from configs.paths_config import DATASET_PATHS, MODEL_PATHS, HYBRID_MODEL_PATHS, HYBRID_CONFIG
21
+ from datasets.imagenet_dic import IMAGENET_DIC
22
+ from utils.align_utils import run_alignment
23
+ from utils.distance_utils import euclidean_distance, cosine_similarity
24
+
25
+
26
+
27
+ def compute_radius(x):
28
+ x = torch.pow(x, 2)
29
+ r = torch.sum(x)
30
+ r = torch.sqrt(r)
31
+ return r
32
+
33
+
34
+
35
+
36
+ class BoundaryDiffusion(object):
37
+ def __init__(self, args, config, device=None):
38
+ self.args = args
39
+ self.config = config
40
+ if device is None:
41
+ device = torch.device(
42
+ "cuda") if torch.cuda.is_available() else torch.device("cpu")
43
+ self.device = device
44
+
45
+ self.model_var_type = config.model.var_type
46
+ betas = get_beta_schedule(
47
+ beta_start=config.diffusion.beta_start,
48
+ beta_end=config.diffusion.beta_end,
49
+ num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps
50
+ )
51
+ self.betas = torch.from_numpy(betas).float().to(self.device)
52
+ self.num_timesteps = betas.shape[0]
53
+
54
+ alphas = 1.0 - betas
55
+ alphas_cumprod = np.cumprod(alphas, axis=0)
56
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
57
+ posterior_variance = betas * \
58
+ (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
59
+ if self.model_var_type == "fixedlarge":
60
+ self.logvar = np.log(np.append(posterior_variance[1], betas[1:]))
61
+
62
+ elif self.model_var_type == 'fixedsmall':
63
+ self.logvar = np.log(np.maximum(posterior_variance, 1e-20))
64
+
65
+ if self.args.edit_attr is None:
66
+ self.src_txts = self.args.src_txts
67
+ self.trg_txts = self.args.trg_txts
68
+ else:
69
+ self.src_txts = SRC_TRG_TXT_DIC[self.args.edit_attr][0]
70
+ self.trg_txts = SRC_TRG_TXT_DIC[self.args.edit_attr][1]
71
+
72
+
73
+ def unconditional(self):
74
+ print(self.args.exp)
75
+
76
+ # ----------- Model -----------#
77
+ if self.config.data.dataset == "LSUN":
78
+ if self.config.data.category == "bedroom":
79
+ url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/bedroom.ckpt"
80
+ elif self.config.data.category == "church_outdoor":
81
+ url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/church_outdoor.ckpt"
82
+ elif self.config.data.dataset == "CelebA_HQ":
83
+ url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/celeba_hq.ckpt"
84
+ elif self.config.data.dataset == "AFHQ":
85
+ pass
86
+ else:
87
+ raise ValueError
88
+
89
+ if self.config.data.dataset in ["CelebA_HQ", "LSUN"]:
90
+ model = DDPM(self.config)
91
+ if self.args.model_path:
92
+ init_ckpt = torch.load(self.args.model_path)
93
+ else:
94
+ init_ckpt = torch.hub.load_state_dict_from_url(url, map_location=self.device)
95
+ learn_sigma = False
96
+ print("Original diffusion Model loaded.")
97
+ elif self.config.data.dataset in ["FFHQ", "AFHQ"]:
98
+ model = i_DDPM(self.config.data.dataset)
99
+ if self.args.model_path:
100
+ init_ckpt = torch.load(self.args.model_path)
101
+ else:
102
+ init_ckpt = torch.load(MODEL_PATHS[self.config.data.dataset])
103
+ learn_sigma = True
104
+ print("Improved diffusion Model loaded.")
105
+ else:
106
+ print('Not implemented dataset')
107
+ raise ValueError
108
+ model.load_state_dict(init_ckpt)
109
+ model.to(self.device)
110
+ model = torch.nn.DataParallel(model)
111
+ model.eval()
112
+
113
+ # ----------- Precompute Latents -----------#
114
+ seq_inv = np.linspace(0, 1, 999) * 999
115
+ seq_inv = [int(s) for s in list(seq_inv)]
116
+ seq_inv_next = [-1] + list(seq_inv[:-1])
117
+
118
+ ###---- boundaries---####
119
+ # ---------- Load boundary ----------#
120
+ classifier = pickle.load(open('./boundary/smile_boundary_h.sav', 'rb'))
121
+ a = classifier.coef_.reshape(1, 512*8*8).astype(np.float32)
122
+ # a = a / np.linalg.norm(a)
123
+
124
+ z_classifier = pickle.load(open('./boundary/smile_boundary_z.sav', 'rb'))
125
+ z_a = z_classifier.coef_.reshape(1, 3*256*256).astype(np.float32)
126
+ z_a = z_a / np.linalg.norm(z_a) # normalized boundary
127
+
128
+ x_lat = torch.randn(1, 3, 256, 256, device=self.device)
129
+ n = 1
130
+ print("get the sampled latent encodings x_T!")
131
+
132
+ with torch.no_grad():
133
+ with tqdm(total=len(seq_inv), desc=f"Generative process") as progress_bar:
134
+ for it, (i, j) in enumerate(zip(reversed((seq_inv)), reversed((seq_inv_next)))):
135
+ t = (torch.ones(n) * i).to(self.device)
136
+ t_next = (torch.ones(n) * j).to(self.device)
137
+ # print("check t and t_next:", t, t_next)
138
+ if t == self.args.t_0:
139
+ break
140
+ x_lat, h_lat = denoising_step(x_lat, t=t, t_next=t_next, models=model,
141
+ logvars=self.logvar,
142
+ # sampling_type=self.args.sample_type,
143
+ sampling_type='ddim',
144
+ b=self.betas,
145
+ eta=0.0,
146
+ learn_sigma=learn_sigma,
147
+ )
148
+
149
+ progress_bar.update(1)
150
+
151
+
152
+
153
+
154
+ # ----- Editing space ------ #
155
+ start_distance = self.args.start_distance
156
+ end_distance = self.args.end_distance
157
+ edit_img_number = self.args.edit_img_number
158
+ linspace = np.linspace(start_distance, end_distance, edit_img_number)
159
+ latent_code = h_lat.cpu().view(1,-1).numpy()
160
+ linspace = linspace - latent_code.dot(a.T)
161
+ linspace = linspace.reshape(-1, 1).astype(np.float32)
162
+ edit_h_seq = latent_code + linspace * a
163
+
164
+
165
+ z_linspace = np.linspace(start_distance, end_distance, edit_img_number)
166
+ z_latent_code = x_lat.cpu().view(1,-1).numpy()
167
+ z_linspace = z_linspace - z_latent_code.dot(z_a.T)
168
+ z_linspace = z_linspace.reshape(-1, 1).astype(np.float32)
169
+ edit_z_seq = z_latent_code + z_linspace * z_a
170
+
171
+
172
+ for k in range(edit_img_number):
173
+ time_in_start = time.time()
174
+ seq_inv = np.linspace(0, 1, self.args.n_inv_step) * self.args.t_0
175
+ seq_inv = [int(s) for s in list(seq_inv)]
176
+ seq_inv_next = [-1] + list(seq_inv[:-1])
177
+
178
+ with tqdm(total=len(seq_inv), desc="Generative process {}".format(it)) as progress_bar:
179
+ edit_h = torch.from_numpy(edit_h_seq[k]).to(self.device).view(-1, 512, 8, 8)
180
+ edit_z = torch.from_numpy(edit_z_seq[k]).to(self.device).view(-1, 3, 256, 256)
181
+ for i, j in zip(reversed(seq_inv), reversed(seq_inv_next)):
182
+ t = (torch.ones(n) * i).to(self.device)
183
+ t_next = (torch.ones(n) * j).to(self.device)
184
+ edit_z, edit_h = denoising_step(edit_z, t=t, t_next=t_next, models=model,
185
+ logvars=self.logvar,
186
+ sampling_type=self.args.sample_type,
187
+ b=self.betas,
188
+ eta = 1.0,
189
+ learn_sigma=learn_sigma,
190
+ ratio=self.args.model_ratio,
191
+ hybrid=self.args.hybrid_noise,
192
+ hybrid_config=HYBRID_CONFIG,
193
+ edit_h=edit_h,
194
+ )
195
+
196
+ save_edit = "unconditioned_smile_"+str(k)+".png"
197
+ tvu.save_image((edit_z + 1) * 0.5, os.path.join("edit_output",save_edit))
198
+ time_in_end = time.time()
199
+ print(f"Editing for 1 image takes {time_in_end - time_in_start:.4f}s")
200
+ return
201
+
202
+
203
+ def radius(self):
204
+ print(self.args.exp)
205
+
206
+ # ----------- Model -----------#
207
+ if self.config.data.dataset == "LSUN":
208
+ if self.config.data.category == "bedroom":
209
+ url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/bedroom.ckpt"
210
+ elif self.config.data.category == "church_outdoor":
211
+ url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/church_outdoor.ckpt"
212
+ elif self.config.data.dataset == "CelebA_HQ":
213
+ url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/celeba_hq.ckpt"
214
+ elif self.config.data.dataset == "AFHQ":
215
+ pass
216
+ else:
217
+ raise ValueError
218
+
219
+ if self.config.data.dataset in ["CelebA_HQ", "LSUN"]:
220
+ model = DDPM(self.config)
221
+ if self.args.model_path:
222
+ init_ckpt = torch.load(self.args.model_path)
223
+ else:
224
+ init_ckpt = torch.hub.load_state_dict_from_url(url, map_location=self.device)
225
+ learn_sigma = False
226
+ print("Original diffusion Model loaded.")
227
+ elif self.config.data.dataset in ["FFHQ", "AFHQ"]:
228
+ model = i_DDPM(self.config.data.dataset)
229
+ if self.args.model_path:
230
+ init_ckpt = torch.load(self.args.model_path)
231
+ else:
232
+ init_ckpt = torch.load(MODEL_PATHS[self.config.data.dataset])
233
+ learn_sigma = True
234
+ print("Improved diffusion Model loaded.")
235
+ else:
236
+ print('Not implemented dataset')
237
+ raise ValueError
238
+ model.load_state_dict(init_ckpt)
239
+ model.to(self.device)
240
+ model = torch.nn.DataParallel(model)
241
+ model.eval()
242
+
243
+
244
+ # ---------- Prepare the seq --------- #
245
+
246
+ # seq_inv = np.linspace(0, 1, self.args.n_inv_step) * self.args.t_0
247
+ seq_inv = np.linspace(0, 1, 999) * 999
248
+ seq_inv = [int(s) for s in list(seq_inv)]
249
+ seq_inv_next = [-1] + list(seq_inv[:-1])
250
+
251
+ n = 1
252
+ with torch.no_grad():
253
+ er = 0
254
+ x_rand = torch.randn(100, 3, 256, 256, device=self.device)
255
+ for idx in range(100):
256
+ x = x_rand[idx, :, :, :].unsqueeze(0)
257
+
258
+ with tqdm(total=len(seq_inv), desc=f"Generative process") as progress_bar:
259
+ for it, (i, j) in enumerate(zip(reversed((seq_inv)), reversed((seq_inv_next)))):
260
+ t = (torch.ones(n) * i).to(self.device)
261
+ t_next = (torch.ones(n) * j).to(self.device)
262
+ if t == 500:
263
+ break
264
+ x, _ = denoising_step(x, t=t, t_next=t_next, models=model,
265
+ logvars=self.logvar,
266
+ # sampling_type=self.args.sample_type,
267
+ sampling_type='ddim',
268
+ b=self.betas,
269
+ eta=0.0,
270
+ learn_sigma=learn_sigma,
271
+ )
272
+
273
+ progress_bar.update(1)
274
+ r_x = compute_radius(x)
275
+
276
+ er += r_x
277
+ print("Check radius at step :", er/100)
278
+
279
+
280
+ return
281
+
282
+
283
+
284
+
285
+
286
+
287
+ def boundary_search(self):
288
+ print(self.args.exp)
289
+
290
+ # ----------- Model -----------#
291
+ if self.config.data.dataset == "LSUN":
292
+ if self.config.data.category == "bedroom":
293
+ url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/bedroom.ckpt"
294
+ elif self.config.data.category == "church_outdoor":
295
+ url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/church_outdoor.ckpt"
296
+ elif self.config.data.dataset == "CelebA_HQ":
297
+ url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/celeba_hq.ckpt"
298
+ elif self.config.data.dataset == "AFHQ":
299
+ pass
300
+ else:
301
+ raise ValueError
302
+
303
+ if self.config.data.dataset in ["CelebA_HQ", "LSUN"]:
304
+ model = DDPM(self.config)
305
+ if self.args.model_path:
306
+ init_ckpt = torch.load(self.args.model_path)
307
+ else:
308
+ init_ckpt = torch.hub.load_state_dict_from_url(url, map_location=self.device)
309
+ learn_sigma = False
310
+ print("Original diffusion Model loaded.")
311
+ elif self.config.data.dataset in ["FFHQ", "AFHQ"]:
312
+ model = i_DDPM(self.config.data.dataset)
313
+ if self.args.model_path:
314
+ init_ckpt = torch.load(self.args.model_path)
315
+ else:
316
+ init_ckpt = torch.load(MODEL_PATHS[self.config.data.dataset])
317
+ learn_sigma = True
318
+ print("Improved diffusion Model loaded.")
319
+ else:
320
+ print('Not implemented dataset')
321
+ raise ValueError
322
+ model.load_state_dict(init_ckpt)
323
+ model.to(self.device)
324
+ model = torch.nn.DataParallel(model)
325
+ model.eval()
326
+
327
+
328
+ # ----------- Precompute Latents -----------#
329
+ print("Prepare identity latent")
330
+ seq_inv = np.linspace(0, 1, self.args.n_inv_step) * self.args.t_0
331
+ seq_inv = [int(s) for s in list(seq_inv)]
332
+ seq_inv_next = [-1] + list(seq_inv[:-1])
333
+
334
+
335
+ n = self.args.bs_train
336
+ img_lat_pairs_dic = {}
337
+ for mode in ['train', 'test']:
338
+ img_lat_pairs = []
339
+ pairs_path = os.path.join('precomputed/',
340
+ f'{self.config.data.category}_{mode}_t{self.args.t_0}_nim{self.args.n_precomp_img}_ninv{self.args.n_inv_step}_pairs.pth')
341
+ print(pairs_path)
342
+ if os.path.exists(pairs_path):
343
+ print(f'{mode} pairs exists')
344
+ img_lat_pairs_dic[mode] = torch.load(pairs_path)
345
+ for step, (x0, x_id, x_lat, mid_h, label) in enumerate(img_lat_pairs_dic[mode]):
346
+ tvu.save_image((x0 + 1) * 0.5, os.path.join(self.args.image_folder, f'{mode}_{step}_0_orig.png'))
347
+ tvu.save_image((x_id + 1) * 0.5, os.path.join(self.args.image_folder,
348
+ f'{mode}_{step}_1_rec_ninv{self.args.n_inv_step}.png'))
349
+ if step == self.args.n_precomp_img - 1:
350
+ break
351
+ continue
352
+ else:
353
+ train_dataset, test_dataset = get_dataset(self.config.data.dataset, DATASET_PATHS, self.config)
354
+ loader_dic = get_dataloader(train_dataset, test_dataset, bs_train=self.args.bs_train,
355
+ num_workers=self.config.data.num_workers)
356
+ loader = loader_dic[mode]
357
+
358
+ for step, (img, label) in enumerate(loader):
359
+ # for step, img in enumerate(loader):
360
+
361
+ x0 = img.to(self.config.device)
362
+ tvu.save_image((x0 + 1) * 0.5, os.path.join(self.args.image_folder, f'{mode}_{step}_0_orig.png'))
363
+
364
+ x = x0.clone()
365
+ model.eval()
366
+ label = label.to(self.config.device)
367
+
368
+ # print("check x and label:", x.size(), label)
369
+
370
+
371
+
372
+ with torch.no_grad():
373
+ with tqdm(total=len(seq_inv), desc=f"Inversion process {mode} {step}") as progress_bar:
374
+ for it, (i, j) in enumerate(zip((seq_inv_next[1:]), (seq_inv[1:]))):
375
+ t = (torch.ones(n) * i).to(self.device)
376
+ t_prev = (torch.ones(n) * j).to(self.device)
377
+
378
+ x, mid_h_g = denoising_step(x, t=t, t_next=t_prev, models=model,
379
+ logvars=self.logvar,
380
+ sampling_type='ddim',
381
+ b=self.betas,
382
+ eta=0,
383
+ learn_sigma=learn_sigma)
384
+
385
+ progress_bar.update(1)
386
+ x_lat = x.clone()
387
+ tvu.save_image((x_lat + 1) * 0.5, os.path.join(self.args.image_folder,
388
+ f'{mode}_{step}_1_lat_ninv{self.args.n_inv_step}.png'))
389
+
390
+ with tqdm(total=len(seq_inv), desc=f"Generative process {mode} {step}") as progress_bar:
391
+ for it, (i, j) in enumerate(zip(reversed((seq_inv)), reversed((seq_inv_next)))):
392
+ t = (torch.ones(n) * i).to(self.device)
393
+ t_next = (torch.ones(n) * j).to(self.device)
394
+
395
+ x, _ = denoising_step(x, t=t, t_next=t_next, models=model,
396
+ logvars=self.logvar,
397
+ sampling_type=self.args.sample_type,
398
+ b=self.betas,
399
+ learn_sigma=learn_sigma,
400
+ # edit_h = mid_h,
401
+ )
402
+
403
+ progress_bar.update(1)
404
+
405
+ img_lat_pairs.append([x0, x.detach().clone(), x_lat.detach().clone(), mid_h_g.detach().clone(), label])
406
+ # img_lat_pairs.append([x0, x.detach().clone(), x_lat.detach().clone(), mid_h_g.detach().clone()])
407
+ tvu.save_image((x + 1) * 0.5, os.path.join(self.args.image_folder,
408
+ f'{mode}_{step}_1_rec_ninv{self.args.n_inv_step}.png'))
409
+ if step == self.args.n_precomp_img - 1:
410
+ break
411
+
412
+ img_lat_pairs_dic[mode] = img_lat_pairs
413
+ pairs_path = os.path.join('precomputed/',
414
+ f'{self.config.data.category}_{mode}_t{self.args.t_0}_nim{self.args.n_precomp_img}_ninv{self.args.n_inv_step}_pairs.pth')
415
+ torch.save(img_lat_pairs, pairs_path)
416
+
417
+ # ----------- Training boundaries -----------#
418
+ print("Start boundary search")
419
+ print(f"Sampling type: {self.args.sample_type.upper()} with eta {self.args.eta}")
420
+ if self.args.n_train_step != 0:
421
+ seq_train = np.linspace(0, 1, self.args.n_train_step) * self.args.t_0
422
+ seq_train = [int(s) for s in list(seq_train)]
423
+ print('Uniform skip type')
424
+ else:
425
+ seq_train = list(range(self.args.t_0))
426
+ print('No skip')
427
+ seq_train_next = [-1] + list(seq_train[:-1])
428
+
429
+ seq_test = np.linspace(0, 1, self.args.n_test_step) * self.args.t_0
430
+ seq_test = [int(s) for s in list(seq_test)]
431
+ seq_test_next = [-1] + list(seq_test[:-1])
432
+
433
+
434
+ for src_txt, trg_txt in zip(self.src_txts, self.trg_txts):
435
+ print(f"CHANGE {src_txt} TO {trg_txt}")
436
+ time_in_start = time.time()
437
+
438
+ clf_h = svm.SVC(kernel='linear')
439
+ clf_z = svm.SVC(kernel='linear')
440
+ # print("clf model:",clf)
441
+
442
+ exp_id = os.path.split(self.args.exp)[-1]
443
+ save_name_h = f'boundary/{exp_id}_{trg_txt.replace(" ", "_")}_h.sav'
444
+ save_name_z = f'boundary/{exp_id}_{trg_txt.replace(" ", "_")}_z.sav'
445
+ n_train = len(img_lat_pairs_dic['train'])
446
+
447
+ train_data_z = np.empty([n_train, 3*256*256])
448
+ train_data_h = np.empty([n_train, 512*8*8])
449
+ train_label = np.empty([n_train,], dtype=int)
450
+
451
+
452
+ for step, (x0, x_id, x_lat, mid_h, label) in enumerate(img_lat_pairs_dic['train']):
453
+ train_data_h[step, :] = mid_h.view(1,-1).cpu().numpy()
454
+ train_data_z[step, :] = x_lat.view(1,-1).cpu().numpy()
455
+ train_label[step] = label.cpu().numpy()
456
+
457
+
458
+ classifier_h = clf_h.fit(train_data_h, train_label)
459
+ classifier_z = clf_z.fit(train_data_z, train_label)
460
+ print(np.shape(train_data_h), np.shape(train_data_z), np.shape(train_label))
461
+ # a = classifier.coef_.reshape(1, 512*8*8).astype(np.float32)
462
+ # a = classifier.coef_.reshape(1, 3*256*256).astype(np.float32)
463
+ # a = a / np.linalg.norm(a)
464
+ time_in_end = time.time()
465
+ print(f"Finding boundary takes {time_in_end - time_in_start:.4f}s")
466
+ print("Finishing boudary seperation!")
467
+
468
+ # boudary_save_h = 'smiling_boundary_h.sav'
469
+ # boudary_save_z = 'smiling_boundary_z.sav'
470
+ pickle.dump(classifier_h, open(save_name_h, 'wb'))
471
+ pickle.dump(classifier_z, open(save_name_z, 'wb'))
472
+
473
+ # test the accuracy ##
474
+ n_test = len(img_lat_pairs_dic['test'])
475
+ test_data_h = np.empty([n_test, 512*8*8])
476
+ test_data_z = np.empty([n_test, 3*256*256])
477
+ test_lable = np.empty([n_test,], dtype=int)
478
+ for step, (x0, x_id, x_lat, mid_h, label) in enumerate(img_lat_pairs_dic['test']):
479
+ test_data_h[step, :] = mid_h.view(1,-1).cpu().numpy()
480
+ test_data_z[step, :] = x_lat.view(1,-1).cpu().numpy()
481
+ test_lable[step] = label.cpu().numpy()
482
+ classifier_h = pickle.load(open(save_name_h, 'rb'))
483
+ classifier_z = pickle.load(open(save_name_z, 'rb'))
484
+ print("Boundary loaded!")
485
+ val_prediction_h = classifier_h.predict(test_data_h)
486
+ val_prediction_z = classifier_z.predict(test_data_z)
487
+ correct_num_h = np.sum(test_lable == val_prediction_h)
488
+ correct_num_z = np.sum(test_lable == val_prediction_z)
489
+ # print(val_prediction_h, test_lable)
490
+ print("Validation accuracy on h and z spaces:", correct_num_h/n_test, correct_num_z/n_test)
491
+ print("total training and testing", n_train, n_test)
492
+
493
+
494
+ return None
495
+
496
+
497
+
498
+
499
+ def edit_image_boundary(self):
500
+ # ----------- Data -----------#
501
+ n = self.args.bs_test
502
+
503
+
504
+ if self.args.align_face and self.config.data.dataset in ["FFHQ", "CelebA_HQ"]:
505
+ try:
506
+ img = run_alignment(self.args.img_path, output_size=self.config.data.image_size)
507
+ except:
508
+ img = Image.open(self.args.img_path).convert("RGB")
509
+ else:
510
+ img = Image.open(self.args.img_path).convert("RGB")
511
+ img = img.resize((self.config.data.image_size, self.config.data.image_size), Image.ANTIALIAS)
512
+ img = np.array(img)/255
513
+ img = torch.from_numpy(img).type(torch.FloatTensor).permute(2, 0, 1).unsqueeze(dim=0).repeat(n, 1, 1, 1)
514
+ img = img.to(self.config.device)
515
+ tvu.save_image(img, os.path.join(self.args.image_folder, f'0_orig.png'))
516
+ x0 = (img - 0.5) * 2.
517
+
518
+ # ----------- Models -----------#
519
+ if self.config.data.dataset == "LSUN":
520
+ if self.config.data.category == "bedroom":
521
+ url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/bedroom.ckpt"
522
+ elif self.config.data.category == "church_outdoor":
523
+ url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/church_outdoor.ckpt"
524
+ elif self.config.data.dataset == "CelebA_HQ":
525
+ url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/celeba_hq.ckpt"
526
+ elif self.config.data.dataset in ["FFHQ", "AFHQ", "IMAGENET"]:
527
+ pass
528
+ else:
529
+ raise ValueError
530
+
531
+ if self.config.data.dataset in ["CelebA_HQ", "LSUN"]:
532
+ model = DDPM(self.config)
533
+ if self.args.model_path:
534
+ init_ckpt = torch.load(self.args.model_path)
535
+ else:
536
+ init_ckpt = torch.hub.load_state_dict_from_url(url, map_location=self.device)
537
+ learn_sigma = False
538
+ print("Original diffusion Model loaded.")
539
+ elif self.config.data.dataset in ["FFHQ", "AFHQ"]:
540
+ model = i_DDPM(self.config.data.dataset)
541
+ if self.args.model_path:
542
+ init_ckpt = torch.load(self.args.model_path)
543
+ else:
544
+ init_ckpt = torch.load(MODEL_PATHS[self.config.data.dataset])
545
+ learn_sigma = True
546
+ print("Improved diffusion Model loaded.")
547
+ else:
548
+ print('Not implemented dataset')
549
+ raise ValueError
550
+ model.load_state_dict(init_ckpt)
551
+ model.to(self.device)
552
+ model = torch.nn.DataParallel(model)
553
+ model.eval()
554
+
555
+ # ---------- Load boundary ----------#
556
+
557
+ boundary_h = pickle.load(open('./boundary/smile_boundary_h.sav', 'rb'))
558
+ a = boundary_h.coef_.reshape(1, 512*8*8).astype(np.float32)
559
+ a = a / np.linalg.norm(a)
560
+
561
+ boundary_z = pickle.load(open('./boundary/smile_boundary_z.sav', 'rb'))
562
+ z_a = boundary_z.coef_.reshape(1, 3*256*256).astype(np.float32)
563
+ z_a = z_a / np.linalg.norm(z_a) # normalized boundary
564
+
565
+
566
+ print("Boundary loaded! In shape:", np.shape(a), np.shape(z_a))
567
+
568
+
569
+ with torch.no_grad():
570
+ #---------------- Invert Image to Latent in case of Deterministic Inversion process -------------------#
571
+ if self.args.deterministic_inv:
572
+ x_lat_path = os.path.join(self.args.image_folder, f'x_lat_t{self.args.t_0}_ninv{self.args.n_inv_step}.pth')
573
+ h_lat_path = os.path.join(self.args.image_folder, f'h_lat_t{self.args.t_0}_ninv{self.args.n_inv_step}.pth')
574
+ if not os.path.exists(x_lat_path):
575
+ seq_inv = np.linspace(0, 1, self.args.n_inv_step) * self.args.t_0
576
+ seq_inv = [int(s) for s in list(seq_inv)]
577
+ seq_inv_next = [-1] + list(seq_inv[:-1])
578
+
579
+ x = x0.clone()
580
+ with tqdm(total=len(seq_inv), desc=f"Inversion process ") as progress_bar:
581
+ for it, (i, j) in enumerate(zip((seq_inv_next[1:]), (seq_inv[1:]))):
582
+ t = (torch.ones(n) * i).to(self.device)
583
+ t_prev = (torch.ones(n) * j).to(self.device)
584
+
585
+ x, mid_h_g = denoising_step(x, t=t, t_next=t_prev, models=model,
586
+ logvars=self.logvar,
587
+ sampling_type='ddim',
588
+ b=self.betas,
589
+ eta=0,
590
+ learn_sigma=learn_sigma,
591
+ ratio=0,
592
+ )
593
+
594
+
595
+ progress_bar.update(1)
596
+ x_lat = x.clone()
597
+ h_lat = mid_h_g.clone()
598
+ torch.save(x_lat, x_lat_path)
599
+ torch.save(h_lat, h_lat_path)
600
+
601
+ else:
602
+ print('Latent exists.')
603
+ x_lat = torch.load(x_lat_path)
604
+ h_lat = torch.load(h_lat_path)
605
+ print("Finish inversion for the given image!", h_lat.size())
606
+
607
+
608
+ # ----------- Generative Process -----------#
609
+ print(f"Sampling type: {self.args.sample_type.upper()} with eta {self.args.eta}, "
610
+ f" Steps: {self.args.n_test_step}/{self.args.t_0}")
611
+
612
+
613
+ # ----- Editing space ------ #
614
+ start_distance = self.args.start_distance
615
+ end_distance = self.args.end_distance
616
+ edit_img_number = self.args.edit_img_number
617
+ # [-100, 100]
618
+ linspace = np.linspace(start_distance, end_distance, edit_img_number)
619
+ latent_code = h_lat.cpu().view(1,-1).numpy()
620
+ linspace = linspace - latent_code.dot(a.T)
621
+ linspace = linspace.reshape(-1, 1).astype(np.float32)
622
+ edit_h_seq = latent_code + linspace * a
623
+
624
+
625
+ z_linspace = np.linspace(start_distance, end_distance, edit_img_number)
626
+ z_latent_code = x_lat.cpu().view(1,-1).numpy()
627
+ z_linspace = z_linspace - z_latent_code.dot(z_a.T)
628
+ z_linspace = z_linspace.reshape(-1, 1).astype(np.float32)
629
+ edit_z_seq = z_latent_code + z_linspace * z_a
630
+
631
+
632
+ if self.args.n_test_step != 0:
633
+ seq_test = np.linspace(0, 1, self.args.n_test_step) * self.args.t_0
634
+ seq_test = [int(s) for s in list(seq_test)]
635
+ print('Uniform skip type')
636
+ else:
637
+ seq_test = list(range(self.args.t_0))
638
+ print('No skip')
639
+ seq_test_next = [-1] + list(seq_test[:-1])
640
+
641
+ for it in range(self.args.n_iter):
642
+ if self.args.deterministic_inv:
643
+ x = x_lat.clone()
644
+ else:
645
+ e = torch.randn_like(x0)
646
+ a = (1 - self.betas).cumprod(dim=0)
647
+ x = x0 * a[self.args.t_0 - 1].sqrt() + e * (1.0 - a[self.args.t_0 - 1]).sqrt()
648
+ tvu.save_image((x + 1) * 0.5, os.path.join(self.args.image_folder,
649
+ f'1_lat_ninv{self.args.n_inv_step}.png'))
650
+
651
+
652
+ for k in range(edit_img_number):
653
+ time_in_start = time.time()
654
+
655
+ with tqdm(total=len(seq_test), desc="Generative process {}".format(it)) as progress_bar:
656
+ edit_h = torch.from_numpy(edit_h_seq[k]).to(self.device).view(-1, 512, 8, 8)
657
+ edit_z = torch.from_numpy(edit_z_seq[k]).to(self.device).view(-1, 3, 256, 256)
658
+ for i, j in zip(reversed(seq_test), reversed(seq_test_next)):
659
+ t = (torch.ones(n) * i).to(self.device)
660
+ t_next = (torch.ones(n) * j).to(self.device)
661
+
662
+ edit_z, edit_h = denoising_step(edit_z, t=t, t_next=t_next, models=model,
663
+ logvars=self.logvar,
664
+ sampling_type=self.args.sample_type,
665
+ b=self.betas,
666
+ eta = 1.0,
667
+ learn_sigma=learn_sigma,
668
+ ratio=self.args.model_ratio,
669
+ hybrid=self.args.hybrid_noise,
670
+ hybrid_config=HYBRID_CONFIG,
671
+ edit_h=edit_h,
672
+ )
673
+
674
+
675
+ x0 = x.clone()
676
+ save_edit = "edited_"+str(k)+".png"
677
+ tvu.save_image((edit_z + 1) * 0.5, os.path.join("edit_output",save_edit))
678
+ time_in_end = time.time()
679
+ print(f"Editing for 1 image takes {time_in_end - time_in_start:.4f}s")
680
+
681
+
682
+ # this is for recons
683
+ with tqdm(total=len(seq_test), desc="Generative process {}".format(it)) as progress_bar:
684
+ for i, j in zip(reversed(seq_test), reversed(seq_test_next)):
685
+ t = (torch.ones(n) * i).to(self.device)
686
+ t_next = (torch.ones(n) * j).to(self.device)
687
+ x_lat, _ = denoising_step(x_lat, t=t, t_next=t_next, models=model,
688
+ logvars=self.logvar,
689
+ sampling_type=self.args.sample_type,
690
+ b=self.betas,
691
+ # eta=self.args.eta,
692
+ eta = 0.0,
693
+ learn_sigma=learn_sigma,
694
+ ratio=self.args.model_ratio,
695
+ hybrid=self.args.hybrid_noise,
696
+ hybrid_config=HYBRID_CONFIG,
697
+ edit_h=None,
698
+ )
699
+
700
+ # added intermediate step vis
701
+ if (i - 99) % 100 == 0:
702
+ tvu.save_image((x + 1) * 0.5, os.path.join(self.args.image_folder,
703
+ f'2_lat_t{self.args.t_0}_ninv{self.args.n_inv_step}_ngen{self.args.n_test_step}_{i}_it{it}.png'))
704
+ progress_bar.update(1)
705
+
706
+ x0 = x.clone()
707
+ save_edit = "recons.png"
708
+ tvu.save_image((x_lat + 1) * 0.5, os.path.join("edit_output",save_edit))
709
+
710
+ return None
711
+
712
+
713
+
configs/afhq.yml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ dataset: "AFHQ"
3
+ category: "dog"
4
+ image_size: 256
5
+ channels: 3
6
+ logit_transform: false
7
+ uniform_dequantization: false
8
+ gaussian_dequantization: false
9
+ random_flip: true
10
+ rescaled: true
11
+ num_workers: 0
12
+
13
+ model:
14
+ type: "simple"
15
+ in_channels: 3
16
+ out_ch: 3
17
+ ch: 128
18
+ ch_mult: [1, 1, 2, 2, 4, 4]
19
+ num_res_blocks: 2
20
+ attn_resolutions: [16, ]
21
+ dropout: 0.0
22
+ var_type: fixedsmall
23
+ ema_rate: 0.999
24
+ ema: True
25
+ resamp_with_conv: True
26
+
27
+ diffusion:
28
+ beta_schedule: linear
29
+ beta_start: 0.0001
30
+ beta_end: 0.02
31
+ num_diffusion_timesteps: 1000
32
+
33
+ sampling:
34
+ batch_size: 4
35
+ last_only: True
configs/bedroom.yml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ dataset: "LSUN"
3
+ category: "bedroom"
4
+ image_size: 256
5
+ channels: 3
6
+ logit_transform: false
7
+ uniform_dequantization: false
8
+ gaussian_dequantization: false
9
+ random_flip: true
10
+ rescaled: true
11
+ num_workers: 0
12
+
13
+ model:
14
+ type: "simple"
15
+ in_channels: 3
16
+ out_ch: 3
17
+ ch: 128
18
+ ch_mult: [1, 1, 2, 2, 4, 4]
19
+ num_res_blocks: 2
20
+ attn_resolutions: [16, ]
21
+ dropout: 0.0
22
+ var_type: fixedsmall
23
+ ema_rate: 0.999
24
+ ema: True
25
+ resamp_with_conv: True
26
+
27
+ diffusion:
28
+ beta_schedule: linear
29
+ beta_start: 0.0001
30
+ beta_end: 0.02
31
+ num_diffusion_timesteps: 1000
32
+
33
+ sampling:
34
+ batch_size: 4
35
+ last_only: True
configs/celeba.yml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ dataset: "CelebA_HQ"
3
+ category: "CelebA_HQ"
4
+ image_size: 256
5
+ channels: 3
6
+ logit_transform: false
7
+ uniform_dequantization: false
8
+ gaussian_dequantization: false
9
+ random_flip: true
10
+ rescaled: true
11
+ num_workers: 0
12
+
13
+ model:
14
+ type: "simple"
15
+ in_channels: 3
16
+ out_ch: 3
17
+ ch: 128
18
+ ch_mult: [1, 1, 2, 2, 4, 4]
19
+ num_res_blocks: 2
20
+ attn_resolutions: [16, ]
21
+ dropout: 0.0
22
+ var_type: fixedsmall
23
+ ema_rate: 0.999
24
+ ema: True
25
+ resamp_with_conv: True
26
+
27
+ diffusion:
28
+ beta_schedule: linear
29
+ beta_start: 0.0001
30
+ beta_end: 0.02
31
+ num_diffusion_timesteps: 1000
32
+
33
+ sampling:
34
+ batch_size: 4
35
+ last_only: True
configs/church.yml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ dataset: "LSUN"
3
+ category: "church_outdoor"
4
+ image_size: 256
5
+ channels: 3
6
+ logit_transform: false
7
+ uniform_dequantization: false
8
+ gaussian_dequantization: false
9
+ random_flip: true
10
+ rescaled: true
11
+ num_workers: 0
12
+
13
+ model:
14
+ type: "simple"
15
+ in_channels: 3
16
+ out_ch: 3
17
+ ch: 128
18
+ ch_mult: [1, 1, 2, 2, 4, 4]
19
+ num_res_blocks: 2
20
+ attn_resolutions: [16, ]
21
+ dropout: 0.0
22
+ var_type: fixedsmall
23
+ ema_rate: 0.999
24
+ ema: True
25
+ resamp_with_conv: True
26
+
27
+ diffusion:
28
+ beta_schedule: linear
29
+ beta_start: 0.0001
30
+ beta_end: 0.02
31
+ num_diffusion_timesteps: 1000
32
+
33
+ sampling:
34
+ batch_size: 4
35
+ last_only: True
configs/imagenet.yml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ dataset: "IMAGENET"
3
+ category: "IMAGENET"
4
+ image_size: 512
5
+ channels: 3
6
+ logit_transform: false
7
+ uniform_dequantization: false
8
+ gaussian_dequantization: false
9
+ random_flip: true
10
+ rescaled: true
11
+ num_workers: 0
12
+
13
+ model:
14
+ type: "simple"
15
+ in_channels: 3
16
+ out_ch: 3
17
+ ch: 128
18
+ ch_mult: [1, 1, 2, 2, 4, 4]
19
+ num_res_blocks: 2
20
+ attn_resolutions: [16, ]
21
+ dropout: 0.0
22
+ var_type: fixedsmall
23
+ ema_rate: 0.999
24
+ ema: True
25
+ resamp_with_conv: True
26
+
27
+ diffusion:
28
+ beta_schedule: linear
29
+ beta_start: 0.0001
30
+ beta_end: 0.02
31
+ num_diffusion_timesteps: 1000
32
+
33
+ sampling:
34
+ batch_size: 4
35
+ last_only: True
configs/paths_config.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ DATASET_PATHS = {
2
+ 'FFHQ': '/n/fs/visualai-scr/Data/CelebA-HQ/',
3
+ 'CelebA_HQ': '/n/fs/visualai-scr/Data/CelebA-HQ/',
4
+ 'AFHQ': '/n/fs/visualai-scr/Data/AFHQ-Dog/',
5
+ 'LSUN': '/n/fs/yz-diff/dataset/',
6
+ 'IMAGENET': 'data/imagenet/',
7
+ }
8
+
9
+ MODEL_PATHS = {
10
+ 'AFHQ': "pretrained/afhqdog_p2.pt",
11
+ 'FFHQ': "pretrained/ffhq_10m.pt",
12
+ 'ir_se50': 'pretrained/model_ir_se50.pth',
13
+ 'IMAGENET': "pretrained/512x512_diffusion.pt",
14
+ 'shape_predictor': "pretrained/shape_predictor_68_face_landmarks.dat.bz2",
15
+ }
16
+
17
+
18
+ HYBRID_MODEL_PATHS = [
19
+ './checkpoint/human_face/curly_hair_t401.pth',
20
+ './checkpoint/human_face/with_makeup_t401.pth',
21
+ ]
22
+
23
+ HYBRID_CONFIG = \
24
+ { 300: [0.4, 0.6, 0],
25
+ 0: [0.15, 0.15, 0.7]}
data_download.sh ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Modified version of download.sh in https://github.com/naver-ai/StyleMapGAN
3
+ """
4
+
5
+ DATASET=$1
6
+ BASE_DIR=$2
7
+
8
+ if [ $DATASET == "celeba_hq" ]; then
9
+ URL="https://docs.google.com/uc?export=download&id=1R72NB79CX0MpnmWSli2SMu-Wp-M0xI-o"
10
+ DATASET_FOLDER="/n/fs/visualai-scr/Data/CelebA-HQ"
11
+ ZIP_FILE=$DATASET_FOLDER/celeba_hq_raw.zip
12
+ elif [ $DATASET == "afhq" ]; then
13
+ URL="https://docs.google.com/uc?export=download&id=1Pf4f6Y27lQX9y9vjeSQnoOQntw_ln7il"
14
+ DATASET_FOLDER="./data/afhq"
15
+ ZIP_FILE=$DATASET_FOLDER/afhq_raw.zip
16
+ else
17
+ echo "Unknown DATASET"
18
+ exit 1
19
+ fi
20
+ mkdir -p $DATASET_FOLDER
21
+
22
+ # wget --no-check-certificate -r $URL -O $ZIP_FILE
23
+
24
+ # wget --load-cookies ~/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies ~/cookies.txt --keep-session-cookies --no-check-certificate $URL -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1R72NB79CX0MpnmWSli2SMu-Wp-M0xI-o" -O $ZIP_FILE && rm -rf ~/cookies.txt
25
+ # unzip $ZIP_FILE -d $DATASET_FOLDER
26
+ # rm $ZIP_FILE
27
+
28
+ # raw images to LMDB format
29
+ TARGET_SIZE=256,1024
30
+ for DATASET_TYPE in "train" "test" "val"; do
31
+ python utils/prepare_lmdb_data.py --out $DATASET_FOLDER/LMDB_$DATASET_TYPE --size $TARGET_SIZE $DATASET_FOLDER/raw_images/$DATASET_TYPE --attr gender
32
+ done
33
+
34
+
35
+
36
+
37
+ wget --load-cookies ~/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies ~/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1R72NB79CX0MpnmWSli2SMu-Wp-M0xI-o' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1R72NB79CX0MpnmWSli2SMu-Wp-M0xI-o" -O a.zip && rm -rf ~/cookies.txt
38
+
datasets/AFHQ_dataset.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from glob import glob
3
+ import os
4
+ from torch.utils.data import Dataset
5
+ import torchvision.transforms as tfs
6
+
7
+ class AFHQ_dataset(Dataset):
8
+ def __init__(self, image_root, transform=None, mode='train', animal_class='dog', img_size=256):
9
+ super().__init__()
10
+ self.image_paths = glob(os.path.join(image_root, mode, animal_class, '*.jpg'))
11
+ self.transform = transform
12
+ self.img_size = img_size
13
+
14
+ def __getitem__(self, index):
15
+ image_path = self.image_paths[index]
16
+ x = Image.open(image_path)
17
+ x = x.resize((self.img_size, self.img_size))
18
+ if self.transform is not None:
19
+ x = self.transform(x)
20
+ return x
21
+
22
+ def __len__(self):
23
+ return len(self.image_paths)
24
+
25
+
26
+ ################################################################################
27
+
28
+ def get_afhq_dataset(data_root, config):
29
+ train_transform = tfs.Compose([tfs.ToTensor(),
30
+ tfs.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5),
31
+ inplace=True)])
32
+
33
+ test_transform = tfs.Compose([tfs.ToTensor(),
34
+ tfs.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5),
35
+ inplace=True)])
36
+
37
+ train_dataset = AFHQ_dataset(data_root, transform=train_transform, mode='train', animal_class='dog',
38
+ img_size=config.data.image_size)
39
+ test_dataset = AFHQ_dataset(data_root, transform=test_transform, mode='val', animal_class='dog',
40
+ img_size=config.data.image_size)
41
+
42
+ return train_dataset, test_dataset
datasets/CelebA_HQ_dataset.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset
2
+ import lmdb
3
+ from io import BytesIO
4
+ from PIL import Image
5
+ import torchvision.transforms as tfs
6
+ import os
7
+
8
+ class MultiResolutionDataset(Dataset):
9
+ def __init__(self, path, transform, resolution=256):
10
+ self.env = lmdb.open(
11
+ path,
12
+ max_readers=32,
13
+ readonly=True,
14
+ lock=False,
15
+ readahead=False,
16
+ meminit=False,
17
+ # attribute=,
18
+ )
19
+
20
+ if not self.env:
21
+ raise IOError("Cannot open lmdb dataset", path)
22
+
23
+ with self.env.begin(write=False) as txn:
24
+ self.length = int(txn.get("length".encode("utf-8")).decode("utf-8"))
25
+
26
+ self.resolution = resolution
27
+ self.transform = transform
28
+
29
+ attr_file_path = '/n/fs/yz-diff/inversion/list_attr_celeba.txt'
30
+ self.labels = file_to_list(attr_file_path)
31
+
32
+
33
+ def __len__(self):
34
+ return self.length
35
+
36
+ def __getitem__(self, index):
37
+ with self.env.begin(write=False) as txn:
38
+ key = f"{self.resolution}-{str(index).zfill(5)}".encode("utf-8")
39
+ key_label = f"{str(index).zfill(5)}".encode("utf-8")
40
+ print("check key:", key, key_label)
41
+ img_bytes = txn.get(key)
42
+ img_id = int(txn.get(key_label).decode("utf-8"))
43
+
44
+ buffer = BytesIO(img_bytes)
45
+ img = Image.open(buffer)
46
+ img = self.transform(img)
47
+
48
+ attr_label = self.labels[img_id-1].split()
49
+ # map the attr to the index position
50
+ label = int(attr_label[32])
51
+ print("check img_id and label:", img_id, label)
52
+
53
+
54
+ return img, label
55
+
56
+
57
+ ################################################################################
58
+
59
+ def get_celeba_dataset(data_root, config):
60
+ train_transform = tfs.Compose([tfs.ToTensor(),
61
+ tfs.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5),
62
+ inplace=True)])
63
+
64
+ test_transform = tfs.Compose([tfs.ToTensor(),
65
+ tfs.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5),
66
+ inplace=True)])
67
+
68
+ train_dataset = MultiResolutionDataset(os.path.join(data_root, 'LMDB_train'),
69
+ train_transform, config.data.image_size)
70
+ test_dataset = MultiResolutionDataset(os.path.join(data_root, 'LMDB_test'),
71
+ test_transform, config.data.image_size)
72
+
73
+
74
+ return train_dataset, test_dataset
75
+
76
+
77
+
78
+ def file_to_list(filename):
79
+ with open(filename, encoding='utf-8') as f:
80
+ files = f.readlines()
81
+ files = [f.rstrip() for f in files]
82
+ return files
83
+
datasets/CelebA_HQ_dataset_with_label.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset
2
+ import lmdb
3
+ from io import BytesIO
4
+ from PIL import Image
5
+ import torchvision.transforms as tfs
6
+ import os
7
+
8
+ class MultiResolutionDataset(Dataset):
9
+ def __init__(self, path, transform, resolution=256):
10
+ self.env = lmdb.open(
11
+ path,
12
+ max_readers=32,
13
+ readonly=True,
14
+ lock=False,
15
+ readahead=False,
16
+ meminit=False,
17
+ )
18
+
19
+ if not self.env:
20
+ raise IOError("Cannot open lmdb dataset", path)
21
+
22
+ with self.env.begin(write=False) as txn:
23
+ self.length = int(txn.get("length".encode("utf-8")).decode("utf-8"))
24
+
25
+ self.resolution = resolution
26
+ self.transform = transform
27
+
28
+ def __len__(self):
29
+ return self.length
30
+
31
+ def __getitem__(self, index):
32
+ with self.env.begin(write=False) as txn:
33
+ key = f"{self.resolution}-{str(index).zfill(5)}".encode("utf-8")
34
+ img_bytes = txn.get(key)
35
+
36
+ buffer = BytesIO(img_bytes)
37
+ img = Image.open(buffer)
38
+ img = self.transform(img)
39
+
40
+ return img
41
+
42
+
43
+ ################################################################################
44
+
45
+ def get_celeba_dataset(data_root, config):
46
+ train_transform = tfs.Compose([tfs.ToTensor(),
47
+ tfs.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5),
48
+ inplace=True)])
49
+
50
+ test_transform = tfs.Compose([tfs.ToTensor(),
51
+ tfs.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5),
52
+ inplace=True)])
53
+
54
+ train_dataset = MultiResolutionDataset(os.path.join(data_root, 'LMDB_train'),
55
+ train_transform, config.data.image_size)
56
+ test_dataset = MultiResolutionDataset(os.path.join(data_root, 'LMDB_test'),
57
+ test_transform, config.data.image_size)
58
+
59
+
60
+ return train_dataset, test_dataset
61
+
62
+
63
+
datasets/IMAGENET_dataset.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from glob import glob
3
+ import os
4
+ from torch.utils.data import Dataset
5
+ import math
6
+ import numpy as np
7
+ import random
8
+ from .imagenet_dic import IMAGENET_DIC
9
+
10
+ def get_imagenet_dataset(data_root, config, class_num=None, random_crop=True, random_flip=False):
11
+ train_dataset = IMAGENET_dataset(data_root, mode='train', class_num=class_num, img_size=config.data.image_size,
12
+ random_crop=random_crop, random_flip=random_flip)
13
+ test_dataset = IMAGENET_dataset(data_root, mode='val', class_num=class_num, img_size=config.data.image_size,
14
+ random_crop=random_crop, random_flip=random_flip)
15
+
16
+ return train_dataset, test_dataset
17
+
18
+
19
+ ###################################################################
20
+
21
+
22
+ class IMAGENET_dataset(Dataset):
23
+ def __init__(self, image_root, mode='val', class_num=None, img_size=512, random_crop=True, random_flip=False):
24
+ super().__init__()
25
+ if class_num is not None:
26
+ self.data_dir = os.path.join(image_root, mode, IMAGENET_DIC[str(class_num)][0], '*.JPEG')
27
+ self.image_paths = sorted(glob(self.data_dir))
28
+ else:
29
+ self.data_dir = os.path.join(image_root, mode, '*', '*.JPEG')
30
+ self.image_paths = sorted(glob(self.data_dir))
31
+ self.img_size = img_size
32
+ self.random_crop = random_crop
33
+ self.random_flip = random_flip
34
+ self.class_num = class_num
35
+
36
+ def __getitem__(self, index):
37
+ f = self.image_paths[index]
38
+ pil_image = Image.open(f)
39
+ pil_image.load()
40
+ pil_image = pil_image.convert("RGB")
41
+
42
+ if self.random_crop:
43
+ arr = random_crop_arr(pil_image, self.img_size)
44
+ else:
45
+ arr = center_crop_arr(pil_image, self.img_size)
46
+
47
+ if self.random_flip and random.random() < 0.5:
48
+ arr = arr[:, ::-1]
49
+
50
+ arr = arr.astype(np.float32) / 127.5 - 1
51
+
52
+ # y = [self.class_num, IMAGENET_DIC[str(self.class_num)][0], IMAGENET_DIC[str(self.class_num)][1]]
53
+ # y = self.class_num
54
+
55
+ return np.transpose(arr, [2, 0, 1])#, y
56
+
57
+ def __len__(self):
58
+ return len(self.image_paths)
59
+
60
+
61
+ def center_crop_arr(pil_image, image_size):
62
+ # We are not on a new enough PIL to support the `reducing_gap`
63
+ # argument, which uses BOX downsampling at powers of two first.
64
+ # Thus, we do it by hand to improve downsample quality.
65
+ while min(*pil_image.size) >= 2 * image_size:
66
+ pil_image = pil_image.resize(
67
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
68
+ )
69
+
70
+ scale = image_size / min(*pil_image.size)
71
+ pil_image = pil_image.resize(
72
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
73
+ )
74
+
75
+ arr = np.array(pil_image)
76
+ crop_y = (arr.shape[0] - image_size) // 2
77
+ crop_x = (arr.shape[1] - image_size) // 2
78
+ return arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]
79
+
80
+
81
+ def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
82
+ min_smaller_dim_size = math.ceil(image_size / max_crop_frac)
83
+ max_smaller_dim_size = math.ceil(image_size / min_crop_frac)
84
+ smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1)
85
+
86
+ # We are not on a new enough PIL to support the `reducing_gap`
87
+ # argument, which uses BOX downsampling at powers of two first.
88
+ # Thus, we do it by hand to improve downsample quality.
89
+ while min(*pil_image.size) >= 2 * smaller_dim_size:
90
+ pil_image = pil_image.resize(
91
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
92
+ )
93
+
94
+ scale = smaller_dim_size / min(*pil_image.size)
95
+ pil_image = pil_image.resize(
96
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
97
+ )
98
+
99
+ arr = np.array(pil_image)
100
+ crop_y = random.randrange(arr.shape[0] - image_size + 1)
101
+ crop_x = random.randrange(arr.shape[1] - image_size + 1)
102
+ return arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]
datasets/LSUN_dataset.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ from collections.abc import Iterable
3
+ from torchvision.datasets.utils import verify_str_arg, iterable_to_str
4
+
5
+
6
+ from PIL import Image
7
+ import io
8
+ import pickle
9
+ import os
10
+ import torch
11
+ import torch.utils.data as data
12
+ import torchvision.transforms as tfs
13
+
14
+ class VisionDataset(data.Dataset):
15
+ _repr_indent = 4
16
+
17
+ def __init__(self, root, transforms=None, transform=None, target_transform=None):
18
+ if isinstance(root, torch._six.string_classes):
19
+ root = os.path.expanduser(root)
20
+ self.root = root
21
+
22
+ has_transforms = transforms is not None
23
+ has_separate_transform = transform is not None or target_transform is not None
24
+ if has_transforms and has_separate_transform:
25
+ raise ValueError("Only transforms or transform/target_transform can "
26
+ "be passed as argument")
27
+
28
+ # for backwards-compatibility
29
+ self.transform = transform
30
+ self.target_transform = target_transform
31
+
32
+ if has_separate_transform:
33
+ transforms = StandardTransform(transform, target_transform)
34
+ self.transforms = transforms
35
+
36
+ def __getitem__(self, index):
37
+ raise NotImplementedError
38
+
39
+ def __len__(self):
40
+ raise NotImplementedError
41
+
42
+ def __repr__(self):
43
+ head = "Dataset " + self.__class__.__name__
44
+ body = ["Number of datapoints: {}".format(self.__len__())]
45
+ if self.root is not None:
46
+ body.append("Root location: {}".format(self.root))
47
+ body += self.extra_repr().splitlines()
48
+ if hasattr(self, 'transform') and self.transform is not None:
49
+ body += self._format_transform_repr(self.transform,
50
+ "Transforms: ")
51
+ if hasattr(self, 'target_transform') and self.target_transform is not None:
52
+ body += self._format_transform_repr(self.target_transform,
53
+ "Target transforms: ")
54
+ lines = [head] + [" " * self._repr_indent + line for line in body]
55
+ return '\n'.join(lines)
56
+
57
+ def _format_transform_repr(self, transform, head):
58
+ lines = transform.__repr__().splitlines()
59
+ return (["{}{}".format(head, lines[0])] +
60
+ ["{}{}".format(" " * len(head), line) for line in lines[1:]])
61
+
62
+ def extra_repr(self):
63
+ return ""
64
+
65
+
66
+ class StandardTransform(object):
67
+ def __init__(self, transform=None, target_transform=None):
68
+ self.transform = transform
69
+ self.target_transform = target_transform
70
+
71
+ def __call__(self, input, target):
72
+ if self.transform is not None:
73
+ input = self.transform(input)
74
+ if self.target_transform is not None:
75
+ target = self.target_transform(target)
76
+ return input, target
77
+
78
+ def _format_transform_repr(self, transform, head):
79
+ lines = transform.__repr__().splitlines()
80
+ return (["{}{}".format(head, lines[0])] +
81
+ ["{}{}".format(" " * len(head), line) for line in lines[1:]])
82
+
83
+ def __repr__(self):
84
+ body = [self.__class__.__name__]
85
+ if self.transform is not None:
86
+ body += self._format_transform_repr(self.transform,
87
+ "Transform: ")
88
+ if self.target_transform is not None:
89
+ body += self._format_transform_repr(self.target_transform,
90
+ "Target transform: ")
91
+
92
+ return '\n'.join(body)
93
+
94
+ ################################################################
95
+
96
+ class LSUNClass(VisionDataset):
97
+ def __init__(self, root, transform=None, target_transform=None):
98
+ import lmdb
99
+
100
+ super(LSUNClass, self).__init__(
101
+ root, transform=transform, target_transform=target_transform
102
+ )
103
+
104
+ self.env = lmdb.open(
105
+ root,
106
+ max_readers=1,
107
+ readonly=True,
108
+ lock=False,
109
+ readahead=False,
110
+ meminit=False,
111
+ )
112
+ with self.env.begin(write=False) as txn:
113
+ self.length = txn.stat()["entries"]
114
+ root_split = root.split("/")
115
+ cache_file = os.path.join("/".join(root_split[:-1]), f"_cache_{root_split[-1]}")
116
+ if os.path.isfile(cache_file):
117
+ self.keys = pickle.load(open(cache_file, "rb"))
118
+ else:
119
+ with self.env.begin(write=False) as txn:
120
+ self.keys = [key for key, _ in txn.cursor()]
121
+ pickle.dump(self.keys, open(cache_file, "wb"))
122
+
123
+ def __getitem__(self, index):
124
+ img, target = None, None
125
+ env = self.env
126
+ with env.begin(write=False) as txn:
127
+ imgbuf = txn.get(self.keys[index])
128
+
129
+ buf = io.BytesIO()
130
+ buf.write(imgbuf)
131
+ buf.seek(0)
132
+ img = Image.open(buf).convert("RGB")
133
+
134
+ if self.transform is not None:
135
+ img = self.transform(img)
136
+
137
+ if self.target_transform is not None:
138
+ target = self.target_transform(target)
139
+
140
+ return img, target
141
+
142
+ def __len__(self):
143
+ return self.length
144
+
145
+
146
+
147
+ class LSUN(VisionDataset):
148
+ """
149
+ `LSUN <https://www.yf.io/p/lsun>`_ dataset.
150
+
151
+ Args:
152
+ root (string): Root directory for the database files.
153
+ classes (string or list): One of {'train', 'val', 'test'} or a list of
154
+ categories to load. e,g. ['bedroom_train', 'church_outdoor_train'].
155
+ transform (callable, optional): A function/transform that takes in an PIL image
156
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
157
+ target_transform (callable, optional): A function/transform that takes in the
158
+ target and transforms it.
159
+ """
160
+
161
+ def __init__(self, root, classes="train", transform=None, target_transform=None):
162
+ super(LSUN, self).__init__(
163
+ root, transform=transform, target_transform=target_transform
164
+ )
165
+ self.classes = self._verify_classes(classes)
166
+
167
+ # for each class, create an LSUNClassDataset
168
+ self.dbs = []
169
+ for c in self.classes:
170
+ self.dbs.append(
171
+ LSUNClass(root=root + "/" + c + "_lmdb", transform=transform)
172
+ )
173
+
174
+ self.indices = []
175
+ count = 0
176
+ for db in self.dbs:
177
+ count += len(db)
178
+ self.indices.append(count)
179
+
180
+ self.length = count
181
+
182
+ def _verify_classes(self, classes):
183
+ categories = [
184
+ "bedroom",
185
+ "bridge",
186
+ "church_outdoor",
187
+ "classroom",
188
+ "conference_room",
189
+ "dining_room",
190
+ "kitchen",
191
+ "living_room",
192
+ "restaurant",
193
+ "tower",
194
+ ]
195
+ dset_opts = ["train", "val", "test"]
196
+
197
+ try:
198
+ verify_str_arg(classes, "classes", dset_opts)
199
+ if classes == "test":
200
+ classes = [classes]
201
+ else:
202
+ classes = [c + "_" + classes for c in categories]
203
+ except ValueError:
204
+ if not isinstance(classes, Iterable):
205
+ msg = (
206
+ "Expected type str or Iterable for argument classes, "
207
+ "but got type {}."
208
+ )
209
+ raise ValueError(msg.format(type(classes)))
210
+
211
+ classes = list(classes)
212
+ msg_fmtstr = (
213
+ "Expected type str for elements in argument classes, "
214
+ "but got type {}."
215
+ )
216
+ for c in classes:
217
+ verify_str_arg(c, custom_msg=msg_fmtstr.format(type(c)))
218
+ c_short = c.split("_")
219
+ category, dset_opt = "_".join(c_short[:-1]), c_short[-1]
220
+
221
+ msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
222
+ msg = msg_fmtstr.format(
223
+ category, "LSUN class", iterable_to_str(categories)
224
+ )
225
+ verify_str_arg(category, valid_values=categories, custom_msg=msg)
226
+
227
+ msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts))
228
+ verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg)
229
+
230
+ return classes
231
+
232
+ def __getitem__(self, index):
233
+ """
234
+ Args:
235
+ index (int): Index
236
+
237
+ Returns:
238
+ tuple: Tuple (image, target) where target is the index of the target category.
239
+ """
240
+ target = 0
241
+ sub = 0
242
+ for ind in self.indices:
243
+ if index < ind:
244
+ break
245
+ target += 1
246
+ sub = ind
247
+
248
+ db = self.dbs[target]
249
+ index = index - sub
250
+
251
+ if self.target_transform is not None:
252
+ target = self.target_transform(target)
253
+
254
+ img, _ = db[index]
255
+ return img  # , target
256
+
257
+ def __len__(self):
258
+ return self.length
259
+
260
+ def extra_repr(self):
261
+ return "Classes: {classes}".format(**self.__dict__)
262
+
263
+
264
+
265
+
266
+
267
+ ################################################################
268
+
269
+ def get_lsun_dataset(data_root, config):
270
+
271
+ train_folder = "{}_train".format(config.data.category)
272
+ val_folder = "{}_val".format(config.data.category)
273
+
274
+ train_dataset = LSUN(
275
+ root=os.path.join(data_root),
276
+ classes=[train_folder],
277
+ transform=tfs.Compose(
278
+ [
279
+ tfs.Resize(config.data.image_size),
280
+ tfs.CenterCrop(config.data.image_size),
281
+ tfs.ToTensor(),
282
+ tfs.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5),
283
+ inplace=True)
284
+ ]
285
+ ),
286
+ )
287
+
288
+ test_dataset = LSUN(
289
+ root=os.path.join(data_root),
290
+ classes=[val_folder],
291
+ transform=tfs.Compose(
292
+ [
293
+ tfs.Resize(config.data.image_size),
294
+ tfs.CenterCrop(config.data.image_size),
295
+ tfs.ToTensor(),
296
+ tfs.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5),
297
+ inplace=True)
298
+ ,
299
+ ]
300
+ ),
301
+ )
302
+
303
+
304
+ return train_dataset, test_dataset
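For reference, a minimal usage sketch of the loader above (illustrative, not part of the uploaded files). The data-root path is a placeholder, and the stand-in config only mimics the two fields this function reads (config.data.category and config.data.image_size); the LMDB folders <category>_train_lmdb and <category>_val_lmdb must already exist under the root.

from types import SimpleNamespace

# Stand-in config with only the fields get_lsun_dataset touches (assumption).
config = SimpleNamespace(data=SimpleNamespace(category="church_outdoor", image_size=256))
train_ds, val_ds = get_lsun_dataset("/path/to/lsun", config)   # placeholder path
img = train_ds[0]               # a 3 x 256 x 256 tensor normalized to [-1, 1]
print(len(train_ds), img.shape)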
datasets/celeba_attr.txt ADDED
@@ -0,0 +1,40 @@
1
+ 5_o_Clock_Shadow
2
+ Arched_Eyebrows
3
+ Attractive
4
+ Bags_Under_Eyes
5
+ Bald
6
+ Bangs
7
+ Big_Lips
8
+ Big_Nose
9
+ Black_Hair
10
+ Blond_Hair
11
+ Blurry
12
+ Brown_Hair
13
+ Bushy_Eyebrows
14
+ Chubby
15
+ Double_Chin
16
+ Eyeglasses
17
+ Goatee
18
+ Gray_Hair
19
+ Heavy_Makeup
20
+ High_Cheekbones
21
+ Male
22
+ Mouth_Slightly_Open
23
+ Mustache
24
+ Narrow_Eyes
25
+ No_Beard
26
+ Oval_Face
27
+ Pale_Skin
28
+ Pointy_Nose
29
+ Receding_Hairline
30
+ Rosy_Cheeks
31
+ Sideburns
32
+ Smiling
33
+ Straight_Hair
34
+ Wavy_Hair
35
+ Wearing_Earrings
36
+ Wearing_Hat
37
+ Wearing_Lipstick
38
+ Wearing_Necklace
39
+ Wearing_Necktie
40
+ Young
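The file above lists the 40 standard CelebA attribute names, one per line in canonical order. A hypothetical reader (how the repository consumes this file is not shown in this commit):

# Hypothetical: 0-based attribute index -> name, e.g. index 31 -> "Smiling".
with open("datasets/celeba_attr.txt") as f:
    CELEBA_ATTRS = [line.strip() for line in f if line.strip()]
assert len(CELEBA_ATTRS) == 40 and CELEBA_ATTRS[31] == "Smiling"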
datasets/data_utils.py ADDED
@@ -0,0 +1,44 @@
1
+ from .AFHQ_dataset import get_afhq_dataset
2
+ from .CelebA_HQ_dataset import get_celeba_dataset
3
+ from .LSUN_dataset import get_lsun_dataset
4
+ from torch.utils.data import DataLoader
5
+ from .IMAGENET_dataset import get_imagenet_dataset
6
+
7
+ def get_dataset(dataset_type, dataset_paths, config, target_class_num=None, gender=None):
8
+ if dataset_type == 'AFHQ':
9
+ train_dataset, test_dataset = get_afhq_dataset(dataset_paths['AFHQ'], config)
10
+ elif dataset_type == "LSUN":
11
+ train_dataset, test_dataset = get_lsun_dataset(dataset_paths['LSUN'], config)
12
+ elif dataset_type == "CelebA_HQ":
13
+ train_dataset, test_dataset = get_celeba_dataset(dataset_paths['CelebA_HQ'], config)
14
+ elif dataset_type == "IMAGENET":
15
+ train_dataset, test_dataset = get_imagenet_dataset(dataset_paths['IMAGENET'], config, class_num=target_class_num)
16
+ else:
17
+ raise ValueError
18
+
19
+ return train_dataset, test_dataset
20
+
21
+
22
+ def get_dataloader(train_dataset, test_dataset, bs_train=1, num_workers=0):
23
+ train_loader = DataLoader(
24
+ train_dataset,
25
+ batch_size=bs_train,
26
+ drop_last=True,
27
+ shuffle=True,
28
+ sampler=None,
29
+ num_workers=num_workers,
30
+ pin_memory=True,
31
+ )
32
+ test_loader = DataLoader(
33
+ test_dataset,
34
+ batch_size=1,
35
+ drop_last=True,
36
+ sampler=None,
37
+ shuffle=True,
38
+ num_workers=num_workers,
39
+ pin_memory=True,
40
+ )
41
+
42
+ return {'train': train_loader, 'test': test_loader}
43
+
44
+
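A short, hypothetical example of wiring the two helpers together; the dataset_paths entry and the config object are stand-ins for the project's real configuration files, which are defined elsewhere in the repository.

dataset_paths = {'CelebA_HQ': '/path/to/celeba_hq'}           # placeholder path
train_ds, test_ds = get_dataset('CelebA_HQ', dataset_paths, config)
loaders = get_dataloader(train_ds, test_ds, bs_train=4, num_workers=2)
batch = next(iter(loaders['train']))                          # one shuffled training batch of size 4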
datasets/imagenet_dic.py ADDED
@@ -0,0 +1,408 @@
1
+ IMAGENET_DIC = {"0": ["n01440764", "tench"], "1": ["n01443537", "goldfish"], "2": ["n01484850", "great_white_shark"],
2
+ "3": ["n01491361", "tiger_shark"], "4": ["n01494475", "hammerhead"], "5": ["n01496331", "electric_ray"],
3
+ "6": ["n01498041", "stingray"], "7": ["n01514668", "cock"], "8": ["n01514859", "hen"],
4
+ "9": ["n01518878", "ostrich"], "10": ["n01530575", "brambling"], "11": ["n01531178", "goldfinch"],
5
+ "12": ["n01532829", "house_finch"], "13": ["n01534433", "junco"], "14": ["n01537544", "indigo_bunting"],
6
+ "15": ["n01558993", "robin"], "16": ["n01560419", "bulbul"], "17": ["n01580077", "jay"],
7
+ "18": ["n01582220", "magpie"], "19": ["n01592084", "chickadee"], "20": ["n01601694", "water_ouzel"],
8
+ "21": ["n01608432", "kite"], "22": ["n01614925", "bald_eagle"], "23": ["n01616318", "vulture"],
9
+ "24": ["n01622779", "great_grey_owl"], "25": ["n01629819", "European_fire_salamander"],
10
+ "26": ["n01630670", "common_newt"], "27": ["n01631663", "eft"],
11
+ "28": ["n01632458", "spotted_salamander"], "29": ["n01632777", "axolotl"],
12
+ "30": ["n01641577", "bullfrog"], "31": ["n01644373", "tree_frog"], "32": ["n01644900", "tailed_frog"],
13
+ "33": ["n01664065", "loggerhead"], "34": ["n01665541", "leatherback_turtle"],
14
+ "35": ["n01667114", "mud_turtle"], "36": ["n01667778", "terrapin"], "37": ["n01669191", "box_turtle"],
15
+ "38": ["n01675722", "banded_gecko"], "39": ["n01677366", "common_iguana"],
16
+ "40": ["n01682714", "American_chameleon"], "41": ["n01685808", "whiptail"],
17
+ "42": ["n01687978", "agama"], "43": ["n01688243", "frilled_lizard"],
18
+ "44": ["n01689811", "alligator_lizard"], "45": ["n01692333", "Gila_monster"],
19
+ "46": ["n01693334", "green_lizard"], "47": ["n01694178", "African_chameleon"],
20
+ "48": ["n01695060", "Komodo_dragon"], "49": ["n01697457", "African_crocodile"],
21
+ "50": ["n01698640", "American_alligator"], "51": ["n01704323", "triceratops"],
22
+ "52": ["n01728572", "thunder_snake"], "53": ["n01728920", "ringneck_snake"],
23
+ "54": ["n01729322", "hognose_snake"], "55": ["n01729977", "green_snake"],
24
+ "56": ["n01734418", "king_snake"], "57": ["n01735189", "garter_snake"],
25
+ "58": ["n01737021", "water_snake"], "59": ["n01739381", "vine_snake"],
26
+ "60": ["n01740131", "night_snake"], "61": ["n01742172", "boa_constrictor"],
27
+ "62": ["n01744401", "rock_python"], "63": ["n01748264", "Indian_cobra"],
28
+ "64": ["n01749939", "green_mamba"], "65": ["n01751748", "sea_snake"],
29
+ "66": ["n01753488", "horned_viper"], "67": ["n01755581", "diamondback"],
30
+ "68": ["n01756291", "sidewinder"], "69": ["n01768244", "trilobite"], "70": ["n01770081", "harvestman"],
31
+ "71": ["n01770393", "scorpion"], "72": ["n01773157", "black_and_gold_garden_spider"],
32
+ "73": ["n01773549", "barn_spider"], "74": ["n01773797", "garden_spider"],
33
+ "75": ["n01774384", "black_widow"], "76": ["n01774750", "tarantula"],
34
+ "77": ["n01775062", "wolf_spider"], "78": ["n01776313", "tick"], "79": ["n01784675", "centipede"],
35
+ "80": ["n01795545", "black_grouse"], "81": ["n01796340", "ptarmigan"],
36
+ "82": ["n01797886", "ruffed_grouse"], "83": ["n01798484", "prairie_chicken"],
37
+ "84": ["n01806143", "peacock"], "85": ["n01806567", "quail"], "86": ["n01807496", "partridge"],
38
+ "87": ["n01817953", "African_grey"], "88": ["n01818515", "macaw"],
39
+ "89": ["n01819313", "sulphur-crested_cockatoo"], "90": ["n01820546", "lorikeet"],
40
+ "91": ["n01824575", "coucal"], "92": ["n01828970", "bee_eater"], "93": ["n01829413", "hornbill"],
41
+ "94": ["n01833805", "hummingbird"], "95": ["n01843065", "jacamar"], "96": ["n01843383", "toucan"],
42
+ "97": ["n01847000", "drake"], "98": ["n01855032", "red-breasted_merganser"],
43
+ "99": ["n01855672", "goose"], "100": ["n01860187", "black_swan"], "101": ["n01871265", "tusker"],
44
+ "102": ["n01872401", "echidna"], "103": ["n01873310", "platypus"], "104": ["n01877812", "wallaby"],
45
+ "105": ["n01882714", "koala"], "106": ["n01883070", "wombat"], "107": ["n01910747", "jellyfish"],
46
+ "108": ["n01914609", "sea_anemone"], "109": ["n01917289", "brain_coral"],
47
+ "110": ["n01924916", "flatworm"], "111": ["n01930112", "nematode"], "112": ["n01943899", "conch"],
48
+ "113": ["n01944390", "snail"], "114": ["n01945685", "slug"], "115": ["n01950731", "sea_slug"],
49
+ "116": ["n01955084", "chiton"], "117": ["n01968897", "chambered_nautilus"],
50
+ "118": ["n01978287", "Dungeness_crab"], "119": ["n01978455", "rock_crab"],
51
+ "120": ["n01980166", "fiddler_crab"], "121": ["n01981276", "king_crab"],
52
+ "122": ["n01983481", "American_lobster"], "123": ["n01984695", "spiny_lobster"],
53
+ "124": ["n01985128", "crayfish"], "125": ["n01986214", "hermit_crab"], "126": ["n01990800", "isopod"],
54
+ "127": ["n02002556", "white_stork"], "128": ["n02002724", "black_stork"],
55
+ "129": ["n02006656", "spoonbill"], "130": ["n02007558", "flamingo"],
56
+ "131": ["n02009229", "little_blue_heron"], "132": ["n02009912", "American_egret"],
57
+ "133": ["n02011460", "bittern"], "134": ["n02012849", "crane"], "135": ["n02013706", "limpkin"],
58
+ "136": ["n02017213", "European_gallinule"], "137": ["n02018207", "American_coot"],
59
+ "138": ["n02018795", "bustard"], "139": ["n02025239", "ruddy_turnstone"],
60
+ "140": ["n02027492", "red-backed_sandpiper"], "141": ["n02028035", "redshank"],
61
+ "142": ["n02033041", "dowitcher"], "143": ["n02037110", "oystercatcher"],
62
+ "144": ["n02051845", "pelican"], "145": ["n02056570", "king_penguin"],
63
+ "146": ["n02058221", "albatross"], "147": ["n02066245", "grey_whale"],
64
+ "148": ["n02071294", "killer_whale"], "149": ["n02074367", "dugong"], "150": ["n02077923", "sea_lion"],
65
+ "151": ["n02085620", "Chihuahua"], "152": ["n02085782", "Japanese_spaniel"],
66
+ "153": ["n02085936", "Maltese_dog"], "154": ["n02086079", "Pekinese"], "155": ["n02086240", "Shih-Tzu"],
67
+ "156": ["n02086646", "Blenheim_spaniel"], "157": ["n02086910", "papillon"],
68
+ "158": ["n02087046", "toy_terrier"], "159": ["n02087394", "Rhodesian_ridgeback"],
69
+ "160": ["n02088094", "Afghan_hound"], "161": ["n02088238", "basset"], "162": ["n02088364", "beagle"],
70
+ "163": ["n02088466", "bloodhound"], "164": ["n02088632", "bluetick"],
71
+ "165": ["n02089078", "black-and-tan_coonhound"], "166": ["n02089867", "Walker_hound"],
72
+ "167": ["n02089973", "English_foxhound"], "168": ["n02090379", "redbone"],
73
+ "169": ["n02090622", "borzoi"], "170": ["n02090721", "Irish_wolfhound"],
74
+ "171": ["n02091032", "Italian_greyhound"], "172": ["n02091134", "whippet"],
75
+ "173": ["n02091244", "Ibizan_hound"], "174": ["n02091467", "Norwegian_elkhound"],
76
+ "175": ["n02091635", "otterhound"], "176": ["n02091831", "Saluki"],
77
+ "177": ["n02092002", "Scottish_deerhound"], "178": ["n02092339", "Weimaraner"],
78
+ "179": ["n02093256", "Staffordshire_bullterrier"],
79
+ "180": ["n02093428", "American_Staffordshire_terrier"], "181": ["n02093647", "Bedlington_terrier"],
80
+ "182": ["n02093754", "Border_terrier"], "183": ["n02093859", "Kerry_blue_terrier"],
81
+ "184": ["n02093991", "Irish_terrier"], "185": ["n02094114", "Norfolk_terrier"],
82
+ "186": ["n02094258", "Norwich_terrier"], "187": ["n02094433", "Yorkshire_terrier"],
83
+ "188": ["n02095314", "wire-haired_fox_terrier"], "189": ["n02095570", "Lakeland_terrier"],
84
+ "190": ["n02095889", "Sealyham_terrier"], "191": ["n02096051", "Airedale"],
85
+ "192": ["n02096177", "cairn"], "193": ["n02096294", "Australian_terrier"],
86
+ "194": ["n02096437", "Dandie_Dinmont"], "195": ["n02096585", "Boston_bull"],
87
+ "196": ["n02097047", "miniature_schnauzer"], "197": ["n02097130", "giant_schnauzer"],
88
+ "198": ["n02097209", "standard_schnauzer"], "199": ["n02097298", "Scotch_terrier"],
89
+ "200": ["n02097474", "Tibetan_terrier"], "201": ["n02097658", "silky_terrier"],
90
+ "202": ["n02098105", "soft-coated_wheaten_terrier"],
91
+ "203": ["n02098286", "West_Highland_white_terrier"], "204": ["n02098413", "Lhasa"],
92
+ "205": ["n02099267", "flat-coated_retriever"], "206": ["n02099429", "curly-coated_retriever"],
93
+ "207": ["n02099601", "golden_retriever"], "208": ["n02099712", "Labrador_retriever"],
94
+ "209": ["n02099849", "Chesapeake_Bay_retriever"], "210": ["n02100236", "German_short-haired_pointer"],
95
+ "211": ["n02100583", "vizsla"], "212": ["n02100735", "English_setter"],
96
+ "213": ["n02100877", "Irish_setter"], "214": ["n02101006", "Gordon_setter"],
97
+ "215": ["n02101388", "Brittany_spaniel"], "216": ["n02101556", "clumber"],
98
+ "217": ["n02102040", "English_springer"], "218": ["n02102177", "Welsh_springer_spaniel"],
99
+ "219": ["n02102318", "cocker_spaniel"], "220": ["n02102480", "Sussex_spaniel"],
100
+ "221": ["n02102973", "Irish_water_spaniel"], "222": ["n02104029", "kuvasz"],
101
+ "223": ["n02104365", "schipperke"], "224": ["n02105056", "groenendael"],
102
+ "225": ["n02105162", "malinois"], "226": ["n02105251", "briard"], "227": ["n02105412", "kelpie"],
103
+ "228": ["n02105505", "komondor"], "229": ["n02105641", "Old_English_sheepdog"],
104
+ "230": ["n02105855", "Shetland_sheepdog"], "231": ["n02106030", "collie"],
105
+ "232": ["n02106166", "Border_collie"], "233": ["n02106382", "Bouvier_des_Flandres"],
106
+ "234": ["n02106550", "Rottweiler"], "235": ["n02106662", "German_shepherd"],
107
+ "236": ["n02107142", "Doberman"], "237": ["n02107312", "miniature_pinscher"],
108
+ "238": ["n02107574", "Greater_Swiss_Mountain_dog"], "239": ["n02107683", "Bernese_mountain_dog"],
109
+ "240": ["n02107908", "Appenzeller"], "241": ["n02108000", "EntleBucher"], "242": ["n02108089", "boxer"],
110
+ "243": ["n02108422", "bull_mastiff"], "244": ["n02108551", "Tibetan_mastiff"],
111
+ "245": ["n02108915", "French_bulldog"], "246": ["n02109047", "Great_Dane"],
112
+ "247": ["n02109525", "Saint_Bernard"], "248": ["n02109961", "Eskimo_dog"],
113
+ "249": ["n02110063", "malamute"], "250": ["n02110185", "Siberian_husky"],
114
+ "251": ["n02110341", "dalmatian"], "252": ["n02110627", "affenpinscher"],
115
+ "253": ["n02110806", "basenji"], "254": ["n02110958", "pug"], "255": ["n02111129", "Leonberg"],
116
+ "256": ["n02111277", "Newfoundland"], "257": ["n02111500", "Great_Pyrenees"],
117
+ "258": ["n02111889", "Samoyed"], "259": ["n02112018", "Pomeranian"], "260": ["n02112137", "chow"],
118
+ "261": ["n02112350", "keeshond"], "262": ["n02112706", "Brabancon_griffon"],
119
+ "263": ["n02113023", "Pembroke"], "264": ["n02113186", "Cardigan"], "265": ["n02113624", "toy_poodle"],
120
+ "266": ["n02113712", "miniature_poodle"], "267": ["n02113799", "standard_poodle"],
121
+ "268": ["n02113978", "Mexican_hairless"], "269": ["n02114367", "timber_wolf"],
122
+ "270": ["n02114548", "white_wolf"], "271": ["n02114712", "red_wolf"], "272": ["n02114855", "coyote"],
123
+ "273": ["n02115641", "dingo"], "274": ["n02115913", "dhole"],
124
+ "275": ["n02116738", "African_hunting_dog"], "276": ["n02117135", "hyena"],
125
+ "277": ["n02119022", "red_fox"], "278": ["n02119789", "kit_fox"], "279": ["n02120079", "Arctic_fox"],
126
+ "280": ["n02120505", "grey_fox"], "281": ["n02123045", "tabby"], "282": ["n02123159", "tiger_cat"],
127
+ "283": ["n02123394", "Persian_cat"], "284": ["n02123597", "Siamese_cat"],
128
+ "285": ["n02124075", "Egyptian_cat"], "286": ["n02125311", "cougar"], "287": ["n02127052", "lynx"],
129
+ "288": ["n02128385", "leopard"], "289": ["n02128757", "snow_leopard"], "290": ["n02128925", "jaguar"],
130
+ "291": ["n02129165", "lion"], "292": ["n02129604", "tiger"], "293": ["n02130308", "cheetah"],
131
+ "294": ["n02132136", "brown_bear"], "295": ["n02133161", "American_black_bear"],
132
+ "296": ["n02134084", "ice_bear"], "297": ["n02134418", "sloth_bear"], "298": ["n02137549", "mongoose"],
133
+ "299": ["n02138441", "meerkat"], "300": ["n02165105", "tiger_beetle"], "301": ["n02165456", "ladybug"],
134
+ "302": ["n02167151", "ground_beetle"], "303": ["n02168699", "long-horned_beetle"],
135
+ "304": ["n02169497", "leaf_beetle"], "305": ["n02172182", "dung_beetle"],
136
+ "306": ["n02174001", "rhinoceros_beetle"], "307": ["n02177972", "weevil"], "308": ["n02190166", "fly"],
137
+ "309": ["n02206856", "bee"], "310": ["n02219486", "ant"], "311": ["n02226429", "grasshopper"],
138
+ "312": ["n02229544", "cricket"], "313": ["n02231487", "walking_stick"],
139
+ "314": ["n02233338", "cockroach"], "315": ["n02236044", "mantis"], "316": ["n02256656", "cicada"],
140
+ "317": ["n02259212", "leafhopper"], "318": ["n02264363", "lacewing"], "319": ["n02268443", "dragonfly"],
141
+ "320": ["n02268853", "damselfly"], "321": ["n02276258", "admiral"], "322": ["n02277742", "ringlet"],
142
+ "323": ["n02279972", "monarch"], "324": ["n02280649", "cabbage_butterfly"],
143
+ "325": ["n02281406", "sulphur_butterfly"], "326": ["n02281787", "lycaenid"],
144
+ "327": ["n02317335", "starfish"], "328": ["n02319095", "sea_urchin"],
145
+ "329": ["n02321529", "sea_cucumber"], "330": ["n02325366", "wood_rabbit"], "331": ["n02326432", "hare"],
146
+ "332": ["n02328150", "Angora"], "333": ["n02342885", "hamster"], "334": ["n02346627", "porcupine"],
147
+ "335": ["n02356798", "fox_squirrel"], "336": ["n02361337", "marmot"], "337": ["n02363005", "beaver"],
148
+ "338": ["n02364673", "guinea_pig"], "339": ["n02389026", "sorrel"], "340": ["n02391049", "zebra"],
149
+ "341": ["n02395406", "hog"], "342": ["n02396427", "wild_boar"], "343": ["n02397096", "warthog"],
150
+ "344": ["n02398521", "hippopotamus"], "345": ["n02403003", "ox"], "346": ["n02408429", "water_buffalo"],
151
+ "347": ["n02410509", "bison"], "348": ["n02412080", "ram"], "349": ["n02415577", "bighorn"],
152
+ "350": ["n02417914", "ibex"], "351": ["n02422106", "hartebeest"], "352": ["n02422699", "impala"],
153
+ "353": ["n02423022", "gazelle"], "354": ["n02437312", "Arabian_camel"], "355": ["n02437616", "llama"],
154
+ "356": ["n02441942", "weasel"], "357": ["n02442845", "mink"], "358": ["n02443114", "polecat"],
155
+ "359": ["n02443484", "black-footed_ferret"], "360": ["n02444819", "otter"],
156
+ "361": ["n02445715", "skunk"], "362": ["n02447366", "badger"], "363": ["n02454379", "armadillo"],
157
+ "364": ["n02457408", "three-toed_sloth"], "365": ["n02480495", "orangutan"],
158
+ "366": ["n02480855", "gorilla"], "367": ["n02481823", "chimpanzee"], "368": ["n02483362", "gibbon"],
159
+ "369": ["n02483708", "siamang"], "370": ["n02484975", "guenon"], "371": ["n02486261", "patas"],
160
+ "372": ["n02486410", "baboon"], "373": ["n02487347", "macaque"], "374": ["n02488291", "langur"],
161
+ "375": ["n02488702", "colobus"], "376": ["n02489166", "proboscis_monkey"],
162
+ "377": ["n02490219", "marmoset"], "378": ["n02492035", "capuchin"],
163
+ "379": ["n02492660", "howler_monkey"], "380": ["n02493509", "titi"],
164
+ "381": ["n02493793", "spider_monkey"], "382": ["n02494079", "squirrel_monkey"],
165
+ "383": ["n02497673", "Madagascar_cat"], "384": ["n02500267", "indri"],
166
+ "385": ["n02504013", "Indian_elephant"], "386": ["n02504458", "African_elephant"],
167
+ "387": ["n02509815", "lesser_panda"], "388": ["n02510455", "giant_panda"],
168
+ "389": ["n02514041", "barracouta"], "390": ["n02526121", "eel"], "391": ["n02536864", "coho"],
169
+ "392": ["n02606052", "rock_beauty"], "393": ["n02607072", "anemone_fish"],
170
+ "394": ["n02640242", "sturgeon"], "395": ["n02641379", "gar"], "396": ["n02643566", "lionfish"],
171
+ "397": ["n02655020", "puffer"], "398": ["n02666196", "abacus"], "399": ["n02667093", "abaya"],
172
+ "400": ["n02669723", "academic_gown"], "401": ["n02672831", "accordion"],
173
+ "402": ["n02676566", "acoustic_guitar"], "403": ["n02687172", "aircraft_carrier"],
174
+ "404": ["n02690373", "airliner"], "405": ["n02692877", "airship"], "406": ["n02699494", "altar"],
175
+ "407": ["n02701002", "ambulance"], "408": ["n02704792", "amphibian"],
176
+ "409": ["n02708093", "analog_clock"], "410": ["n02727426", "apiary"], "411": ["n02730930", "apron"],
177
+ "412": ["n02747177", "ashcan"], "413": ["n02749479", "assault_rifle"], "414": ["n02769748", "backpack"],
178
+ "415": ["n02776631", "bakery"], "416": ["n02777292", "balance_beam"], "417": ["n02782093", "balloon"],
179
+ "418": ["n02783161", "ballpoint"], "419": ["n02786058", "Band_Aid"], "420": ["n02787622", "banjo"],
180
+ "421": ["n02788148", "bannister"], "422": ["n02790996", "barbell"],
181
+ "423": ["n02791124", "barber_chair"], "424": ["n02791270", "barbershop"], "425": ["n02793495", "barn"],
182
+ "426": ["n02794156", "barometer"], "427": ["n02795169", "barrel"], "428": ["n02797295", "barrow"],
183
+ "429": ["n02799071", "baseball"], "430": ["n02802426", "basketball"], "431": ["n02804414", "bassinet"],
184
+ "432": ["n02804610", "bassoon"], "433": ["n02807133", "bathing_cap"],
185
+ "434": ["n02808304", "bath_towel"], "435": ["n02808440", "bathtub"],
186
+ "436": ["n02814533", "beach_wagon"], "437": ["n02814860", "beacon"], "438": ["n02815834", "beaker"],
187
+ "439": ["n02817516", "bearskin"], "440": ["n02823428", "beer_bottle"],
188
+ "441": ["n02823750", "beer_glass"], "442": ["n02825657", "bell_cote"], "443": ["n02834397", "bib"],
189
+ "444": ["n02835271", "bicycle-built-for-two"], "445": ["n02837789", "bikini"],
190
+ "446": ["n02840245", "binder"], "447": ["n02841315", "binoculars"], "448": ["n02843684", "birdhouse"],
191
+ "449": ["n02859443", "boathouse"], "450": ["n02860847", "bobsled"], "451": ["n02865351", "bolo_tie"],
192
+ "452": ["n02869837", "bonnet"], "453": ["n02870880", "bookcase"], "454": ["n02871525", "bookshop"],
193
+ "455": ["n02877765", "bottlecap"], "456": ["n02879718", "bow"], "457": ["n02883205", "bow_tie"],
194
+ "458": ["n02892201", "brass"], "459": ["n02892767", "brassiere"], "460": ["n02894605", "breakwater"],
195
+ "461": ["n02895154", "breastplate"], "462": ["n02906734", "broom"], "463": ["n02909870", "bucket"],
196
+ "464": ["n02910353", "buckle"], "465": ["n02916936", "bulletproof_vest"],
197
+ "466": ["n02917067", "bullet_train"], "467": ["n02927161", "butcher_shop"], "468": ["n02930766", "cab"],
198
+ "469": ["n02939185", "caldron"], "470": ["n02948072", "candle"], "471": ["n02950826", "cannon"],
199
+ "472": ["n02951358", "canoe"], "473": ["n02951585", "can_opener"], "474": ["n02963159", "cardigan"],
200
+ "475": ["n02965783", "car_mirror"], "476": ["n02966193", "carousel"],
201
+ "477": ["n02966687", "carpenter's_kit"], "478": ["n02971356", "carton"],
202
+ "479": ["n02974003", "car_wheel"], "480": ["n02977058", "cash_machine"],
203
+ "481": ["n02978881", "cassette"], "482": ["n02979186", "cassette_player"],
204
+ "483": ["n02980441", "castle"], "484": ["n02981792", "catamaran"], "485": ["n02988304", "CD_player"],
205
+ "486": ["n02992211", "cello"], "487": ["n02992529", "cellular_telephone"],
206
+ "488": ["n02999410", "chain"], "489": ["n03000134", "chainlink_fence"],
207
+ "490": ["n03000247", "chain_mail"], "491": ["n03000684", "chain_saw"], "492": ["n03014705", "chest"],
208
+ "493": ["n03016953", "chiffonier"], "494": ["n03017168", "chime"],
209
+ "495": ["n03018349", "china_cabinet"], "496": ["n03026506", "Christmas_stocking"],
210
+ "497": ["n03028079", "church"], "498": ["n03032252", "cinema"], "499": ["n03041632", "cleaver"],
211
+ "500": ["n03042490", "cliff_dwelling"], "501": ["n03045698", "cloak"], "502": ["n03047690", "clog"],
212
+ "503": ["n03062245", "cocktail_shaker"], "504": ["n03063599", "coffee_mug"],
213
+ "505": ["n03063689", "coffeepot"], "506": ["n03065424", "coil"],
214
+ "507": ["n03075370", "combination_lock"], "508": ["n03085013", "computer_keyboard"],
215
+ "509": ["n03089624", "confectionery"], "510": ["n03095699", "container_ship"],
216
+ "511": ["n03100240", "convertible"], "512": ["n03109150", "corkscrew"], "513": ["n03110669", "cornet"],
217
+ "514": ["n03124043", "cowboy_boot"], "515": ["n03124170", "cowboy_hat"], "516": ["n03125729", "cradle"],
218
+ "517": ["n03126707", "crane"], "518": ["n03127747", "crash_helmet"], "519": ["n03127925", "crate"],
219
+ "520": ["n03131574", "crib"], "521": ["n03133878", "Crock_Pot"], "522": ["n03134739", "croquet_ball"],
220
+ "523": ["n03141823", "crutch"], "524": ["n03146219", "cuirass"], "525": ["n03160309", "dam"],
221
+ "526": ["n03179701", "desk"], "527": ["n03180011", "desktop_computer"],
222
+ "528": ["n03187595", "dial_telephone"], "529": ["n03188531", "diaper"],
223
+ "530": ["n03196217", "digital_clock"], "531": ["n03197337", "digital_watch"],
224
+ "532": ["n03201208", "dining_table"], "533": ["n03207743", "dishrag"],
225
+ "534": ["n03207941", "dishwasher"], "535": ["n03208938", "disk_brake"], "536": ["n03216828", "dock"],
226
+ "537": ["n03218198", "dogsled"], "538": ["n03220513", "dome"], "539": ["n03223299", "doormat"],
227
+ "540": ["n03240683", "drilling_platform"], "541": ["n03249569", "drum"],
228
+ "542": ["n03250847", "drumstick"], "543": ["n03255030", "dumbbell"], "544": ["n03259280", "Dutch_oven"],
229
+ "545": ["n03271574", "electric_fan"], "546": ["n03272010", "electric_guitar"],
230
+ "547": ["n03272562", "electric_locomotive"], "548": ["n03290653", "entertainment_center"],
231
+ "549": ["n03291819", "envelope"], "550": ["n03297495", "espresso_maker"],
232
+ "551": ["n03314780", "face_powder"], "552": ["n03325584", "feather_boa"], "553": ["n03337140", "file"],
233
+ "554": ["n03344393", "fireboat"], "555": ["n03345487", "fire_engine"],
234
+ "556": ["n03347037", "fire_screen"], "557": ["n03355925", "flagpole"], "558": ["n03372029", "flute"],
235
+ "559": ["n03376595", "folding_chair"], "560": ["n03379051", "football_helmet"],
236
+ "561": ["n03384352", "forklift"], "562": ["n03388043", "fountain"],
237
+ "563": ["n03388183", "fountain_pen"], "564": ["n03388549", "four-poster"],
238
+ "565": ["n03393912", "freight_car"], "566": ["n03394916", "French_horn"],
239
+ "567": ["n03400231", "frying_pan"], "568": ["n03404251", "fur_coat"],
240
+ "569": ["n03417042", "garbage_truck"], "570": ["n03424325", "gasmask"],
241
+ "571": ["n03425413", "gas_pump"], "572": ["n03443371", "goblet"], "573": ["n03444034", "go-kart"],
242
+ "574": ["n03445777", "golf_ball"], "575": ["n03445924", "golfcart"], "576": ["n03447447", "gondola"],
243
+ "577": ["n03447721", "gong"], "578": ["n03450230", "gown"], "579": ["n03452741", "grand_piano"],
244
+ "580": ["n03457902", "greenhouse"], "581": ["n03459775", "grille"],
245
+ "582": ["n03461385", "grocery_store"], "583": ["n03467068", "guillotine"],
246
+ "584": ["n03476684", "hair_slide"], "585": ["n03476991", "hair_spray"],
247
+ "586": ["n03478589", "half_track"], "587": ["n03481172", "hammer"], "588": ["n03482405", "hamper"],
248
+ "589": ["n03483316", "hand_blower"], "590": ["n03485407", "hand-held_computer"],
249
+ "591": ["n03485794", "handkerchief"], "592": ["n03492542", "hard_disc"],
250
+ "593": ["n03494278", "harmonica"], "594": ["n03495258", "harp"], "595": ["n03496892", "harvester"],
251
+ "596": ["n03498962", "hatchet"], "597": ["n03527444", "holster"], "598": ["n03529860", "home_theater"],
252
+ "599": ["n03530642", "honeycomb"], "600": ["n03532672", "hook"], "601": ["n03534580", "hoopskirt"],
253
+ "602": ["n03535780", "horizontal_bar"], "603": ["n03538406", "horse_cart"],
254
+ "604": ["n03544143", "hourglass"], "605": ["n03584254", "iPod"], "606": ["n03584829", "iron"],
255
+ "607": ["n03590841", "jack-o'-lantern"], "608": ["n03594734", "jean"], "609": ["n03594945", "jeep"],
256
+ "610": ["n03595614", "jersey"], "611": ["n03598930", "jigsaw_puzzle"],
257
+ "612": ["n03599486", "jinrikisha"], "613": ["n03602883", "joystick"], "614": ["n03617480", "kimono"],
258
+ "615": ["n03623198", "knee_pad"], "616": ["n03627232", "knot"], "617": ["n03630383", "lab_coat"],
259
+ "618": ["n03633091", "ladle"], "619": ["n03637318", "lampshade"], "620": ["n03642806", "laptop"],
260
+ "621": ["n03649909", "lawn_mower"], "622": ["n03657121", "lens_cap"],
261
+ "623": ["n03658185", "letter_opener"], "624": ["n03661043", "library"],
262
+ "625": ["n03662601", "lifeboat"], "626": ["n03666591", "lighter"], "627": ["n03670208", "limousine"],
263
+ "628": ["n03673027", "liner"], "629": ["n03676483", "lipstick"], "630": ["n03680355", "Loafer"],
264
+ "631": ["n03690938", "lotion"], "632": ["n03691459", "loudspeaker"], "633": ["n03692522", "loupe"],
265
+ "634": ["n03697007", "lumbermill"], "635": ["n03706229", "magnetic_compass"],
266
+ "636": ["n03709823", "mailbag"], "637": ["n03710193", "mailbox"], "638": ["n03710637", "maillot"],
267
+ "639": ["n03710721", "maillot"], "640": ["n03717622", "manhole_cover"], "641": ["n03720891", "maraca"],
268
+ "642": ["n03721384", "marimba"], "643": ["n03724870", "mask"], "644": ["n03729826", "matchstick"],
269
+ "645": ["n03733131", "maypole"], "646": ["n03733281", "maze"], "647": ["n03733805", "measuring_cup"],
270
+ "648": ["n03742115", "medicine_chest"], "649": ["n03743016", "megalith"],
271
+ "650": ["n03759954", "microphone"], "651": ["n03761084", "microwave"],
272
+ "652": ["n03763968", "military_uniform"], "653": ["n03764736", "milk_can"],
273
+ "654": ["n03769881", "minibus"], "655": ["n03770439", "miniskirt"], "656": ["n03770679", "minivan"],
274
+ "657": ["n03773504", "missile"], "658": ["n03775071", "mitten"], "659": ["n03775546", "mixing_bowl"],
275
+ "660": ["n03776460", "mobile_home"], "661": ["n03777568", "Model_T"], "662": ["n03777754", "modem"],
276
+ "663": ["n03781244", "monastery"], "664": ["n03782006", "monitor"], "665": ["n03785016", "moped"],
277
+ "666": ["n03786901", "mortar"], "667": ["n03787032", "mortarboard"], "668": ["n03788195", "mosque"],
278
+ "669": ["n03788365", "mosquito_net"], "670": ["n03791053", "motor_scooter"],
279
+ "671": ["n03792782", "mountain_bike"], "672": ["n03792972", "mountain_tent"],
280
+ "673": ["n03793489", "mouse"], "674": ["n03794056", "mousetrap"], "675": ["n03796401", "moving_van"],
281
+ "676": ["n03803284", "muzzle"], "677": ["n03804744", "nail"], "678": ["n03814639", "neck_brace"],
282
+ "679": ["n03814906", "necklace"], "680": ["n03825788", "nipple"], "681": ["n03832673", "notebook"],
283
+ "682": ["n03837869", "obelisk"], "683": ["n03838899", "oboe"], "684": ["n03840681", "ocarina"],
284
+ "685": ["n03841143", "odometer"], "686": ["n03843555", "oil_filter"], "687": ["n03854065", "organ"],
285
+ "688": ["n03857828", "oscilloscope"], "689": ["n03866082", "overskirt"], "690": ["n03868242", "oxcart"],
286
+ "691": ["n03868863", "oxygen_mask"], "692": ["n03871628", "packet"], "693": ["n03873416", "paddle"],
287
+ "694": ["n03874293", "paddlewheel"], "695": ["n03874599", "padlock"],
288
+ "696": ["n03876231", "paintbrush"], "697": ["n03877472", "pajama"], "698": ["n03877845", "palace"],
289
+ "699": ["n03884397", "panpipe"], "700": ["n03887697", "paper_towel"], "701": ["n03888257", "parachute"],
290
+ "702": ["n03888605", "parallel_bars"], "703": ["n03891251", "park_bench"],
291
+ "704": ["n03891332", "parking_meter"], "705": ["n03895866", "passenger_car"],
292
+ "706": ["n03899768", "patio"], "707": ["n03902125", "pay-phone"], "708": ["n03903868", "pedestal"],
293
+ "709": ["n03908618", "pencil_box"], "710": ["n03908714", "pencil_sharpener"],
294
+ "711": ["n03916031", "perfume"], "712": ["n03920288", "Petri_dish"],
295
+ "713": ["n03924679", "photocopier"], "714": ["n03929660", "pick"], "715": ["n03929855", "pickelhaube"],
296
+ "716": ["n03930313", "picket_fence"], "717": ["n03930630", "pickup"], "718": ["n03933933", "pier"],
297
+ "719": ["n03935335", "piggy_bank"], "720": ["n03937543", "pill_bottle"], "721": ["n03938244", "pillow"],
298
+ "722": ["n03942813", "ping-pong_ball"], "723": ["n03944341", "pinwheel"],
299
+ "724": ["n03947888", "pirate"], "725": ["n03950228", "pitcher"], "726": ["n03954731", "plane"],
300
+ "727": ["n03956157", "planetarium"], "728": ["n03958227", "plastic_bag"],
301
+ "729": ["n03961711", "plate_rack"], "730": ["n03967562", "plow"], "731": ["n03970156", "plunger"],
302
+ "732": ["n03976467", "Polaroid_camera"], "733": ["n03976657", "pole"],
303
+ "734": ["n03977966", "police_van"], "735": ["n03980874", "poncho"], "736": ["n03982430", "pool_table"],
304
+ "737": ["n03983396", "pop_bottle"], "738": ["n03991062", "pot"], "739": ["n03992509", "potter's_wheel"],
305
+ "740": ["n03995372", "power_drill"], "741": ["n03998194", "prayer_rug"],
306
+ "742": ["n04004767", "printer"], "743": ["n04005630", "prison"], "744": ["n04008634", "projectile"],
307
+ "745": ["n04009552", "projector"], "746": ["n04019541", "puck"], "747": ["n04023962", "punching_bag"],
308
+ "748": ["n04026417", "purse"], "749": ["n04033901", "quill"], "750": ["n04033995", "quilt"],
309
+ "751": ["n04037443", "racer"], "752": ["n04039381", "racket"], "753": ["n04040759", "radiator"],
310
+ "754": ["n04041544", "radio"], "755": ["n04044716", "radio_telescope"],
311
+ "756": ["n04049303", "rain_barrel"], "757": ["n04065272", "recreational_vehicle"],
312
+ "758": ["n04067472", "reel"], "759": ["n04069434", "reflex_camera"],
313
+ "760": ["n04070727", "refrigerator"], "761": ["n04074963", "remote_control"],
314
+ "762": ["n04081281", "restaurant"], "763": ["n04086273", "revolver"], "764": ["n04090263", "rifle"],
315
+ "765": ["n04099969", "rocking_chair"], "766": ["n04111531", "rotisserie"],
316
+ "767": ["n04116512", "rubber_eraser"], "768": ["n04118538", "rugby_ball"], "769": ["n04118776", "rule"],
317
+ "770": ["n04120489", "running_shoe"], "771": ["n04125021", "safe"], "772": ["n04127249", "safety_pin"],
318
+ "773": ["n04131690", "saltshaker"], "774": ["n04133789", "sandal"], "775": ["n04136333", "sarong"],
319
+ "776": ["n04141076", "sax"], "777": ["n04141327", "scabbard"], "778": ["n04141975", "scale"],
320
+ "779": ["n04146614", "school_bus"], "780": ["n04147183", "schooner"],
321
+ "781": ["n04149813", "scoreboard"], "782": ["n04152593", "screen"], "783": ["n04153751", "screw"],
322
+ "784": ["n04154565", "screwdriver"], "785": ["n04162706", "seat_belt"],
323
+ "786": ["n04179913", "sewing_machine"], "787": ["n04192698", "shield"],
324
+ "788": ["n04200800", "shoe_shop"], "789": ["n04201297", "shoji"],
325
+ "790": ["n04204238", "shopping_basket"], "791": ["n04204347", "shopping_cart"],
326
+ "792": ["n04208210", "shovel"], "793": ["n04209133", "shower_cap"],
327
+ "794": ["n04209239", "shower_curtain"], "795": ["n04228054", "ski"], "796": ["n04229816", "ski_mask"],
328
+ "797": ["n04235860", "sleeping_bag"], "798": ["n04238763", "slide_rule"],
329
+ "799": ["n04239074", "sliding_door"], "800": ["n04243546", "slot"], "801": ["n04251144", "snorkel"],
330
+ "802": ["n04252077", "snowmobile"], "803": ["n04252225", "snowplow"],
331
+ "804": ["n04254120", "soap_dispenser"], "805": ["n04254680", "soccer_ball"],
332
+ "806": ["n04254777", "sock"], "807": ["n04258138", "solar_dish"], "808": ["n04259630", "sombrero"],
333
+ "809": ["n04263257", "soup_bowl"], "810": ["n04264628", "space_bar"],
334
+ "811": ["n04265275", "space_heater"], "812": ["n04266014", "space_shuttle"],
335
+ "813": ["n04270147", "spatula"], "814": ["n04273569", "speedboat"], "815": ["n04275548", "spider_web"],
336
+ "816": ["n04277352", "spindle"], "817": ["n04285008", "sports_car"], "818": ["n04286575", "spotlight"],
337
+ "819": ["n04296562", "stage"], "820": ["n04310018", "steam_locomotive"],
338
+ "821": ["n04311004", "steel_arch_bridge"], "822": ["n04311174", "steel_drum"],
339
+ "823": ["n04317175", "stethoscope"], "824": ["n04325704", "stole"], "825": ["n04326547", "stone_wall"],
340
+ "826": ["n04328186", "stopwatch"], "827": ["n04330267", "stove"], "828": ["n04332243", "strainer"],
341
+ "829": ["n04335435", "streetcar"], "830": ["n04336792", "stretcher"],
342
+ "831": ["n04344873", "studio_couch"], "832": ["n04346328", "stupa"], "833": ["n04347754", "submarine"],
343
+ "834": ["n04350905", "suit"], "835": ["n04355338", "sundial"], "836": ["n04355933", "sunglass"],
344
+ "837": ["n04356056", "sunglasses"], "838": ["n04357314", "sunscreen"],
345
+ "839": ["n04366367", "suspension_bridge"], "840": ["n04367480", "swab"],
346
+ "841": ["n04370456", "sweatshirt"], "842": ["n04371430", "swimming_trunks"],
347
+ "843": ["n04371774", "swing"], "844": ["n04372370", "switch"], "845": ["n04376876", "syringe"],
348
+ "846": ["n04380533", "table_lamp"], "847": ["n04389033", "tank"], "848": ["n04392985", "tape_player"],
349
+ "849": ["n04398044", "teapot"], "850": ["n04399382", "teddy"], "851": ["n04404412", "television"],
350
+ "852": ["n04409515", "tennis_ball"], "853": ["n04417672", "thatch"],
351
+ "854": ["n04418357", "theater_curtain"], "855": ["n04423845", "thimble"],
352
+ "856": ["n04428191", "thresher"], "857": ["n04429376", "throne"], "858": ["n04435653", "tile_roof"],
353
+ "859": ["n04442312", "toaster"], "860": ["n04443257", "tobacco_shop"],
354
+ "861": ["n04447861", "toilet_seat"], "862": ["n04456115", "torch"], "863": ["n04458633", "totem_pole"],
355
+ "864": ["n04461696", "tow_truck"], "865": ["n04462240", "toyshop"], "866": ["n04465501", "tractor"],
356
+ "867": ["n04467665", "trailer_truck"], "868": ["n04476259", "tray"],
357
+ "869": ["n04479046", "trench_coat"], "870": ["n04482393", "tricycle"], "871": ["n04483307", "trimaran"],
358
+ "872": ["n04485082", "tripod"], "873": ["n04486054", "triumphal_arch"],
359
+ "874": ["n04487081", "trolleybus"], "875": ["n04487394", "trombone"], "876": ["n04493381", "tub"],
360
+ "877": ["n04501370", "turnstile"], "878": ["n04505470", "typewriter_keyboard"],
361
+ "879": ["n04507155", "umbrella"], "880": ["n04509417", "unicycle"], "881": ["n04515003", "upright"],
362
+ "882": ["n04517823", "vacuum"], "883": ["n04522168", "vase"], "884": ["n04523525", "vault"],
363
+ "885": ["n04525038", "velvet"], "886": ["n04525305", "vending_machine"],
364
+ "887": ["n04532106", "vestment"], "888": ["n04532670", "viaduct"], "889": ["n04536866", "violin"],
365
+ "890": ["n04540053", "volleyball"], "891": ["n04542943", "waffle_iron"],
366
+ "892": ["n04548280", "wall_clock"], "893": ["n04548362", "wallet"], "894": ["n04550184", "wardrobe"],
367
+ "895": ["n04552348", "warplane"], "896": ["n04553703", "washbasin"], "897": ["n04554684", "washer"],
368
+ "898": ["n04557648", "water_bottle"], "899": ["n04560804", "water_jug"],
369
+ "900": ["n04562935", "water_tower"], "901": ["n04579145", "whiskey_jug"],
370
+ "902": ["n04579432", "whistle"], "903": ["n04584207", "wig"], "904": ["n04589890", "window_screen"],
371
+ "905": ["n04590129", "window_shade"], "906": ["n04591157", "Windsor_tie"],
372
+ "907": ["n04591713", "wine_bottle"], "908": ["n04592741", "wing"], "909": ["n04596742", "wok"],
373
+ "910": ["n04597913", "wooden_spoon"], "911": ["n04599235", "wool"], "912": ["n04604644", "worm_fence"],
374
+ "913": ["n04606251", "wreck"], "914": ["n04612504", "yawl"], "915": ["n04613696", "yurt"],
375
+ "916": ["n06359193", "web_site"], "917": ["n06596364", "comic_book"],
376
+ "918": ["n06785654", "crossword_puzzle"], "919": ["n06794110", "street_sign"],
377
+ "920": ["n06874185", "traffic_light"], "921": ["n07248320", "book_jacket"],
378
+ "922": ["n07565083", "menu"], "923": ["n07579787", "plate"], "924": ["n07583066", "guacamole"],
379
+ "925": ["n07584110", "consomme"], "926": ["n07590611", "hot_pot"], "927": ["n07613480", "trifle"],
380
+ "928": ["n07614500", "ice_cream"], "929": ["n07615774", "ice_lolly"],
381
+ "930": ["n07684084", "French_loaf"], "931": ["n07693725", "bagel"], "932": ["n07695742", "pretzel"],
382
+ "933": ["n07697313", "cheeseburger"], "934": ["n07697537", "hotdog"],
383
+ "935": ["n07711569", "mashed_potato"], "936": ["n07714571", "head_cabbage"],
384
+ "937": ["n07714990", "broccoli"], "938": ["n07715103", "cauliflower"], "939": ["n07716358", "zucchini"],
385
+ "940": ["n07716906", "spaghetti_squash"], "941": ["n07717410", "acorn_squash"],
386
+ "942": ["n07717556", "butternut_squash"], "943": ["n07718472", "cucumber"],
387
+ "944": ["n07718747", "artichoke"], "945": ["n07720875", "bell_pepper"], "946": ["n07730033", "cardoon"],
388
+ "947": ["n07734744", "mushroom"], "948": ["n07742313", "Granny_Smith"],
389
+ "949": ["n07745940", "strawberry"], "950": ["n07747607", "orange"], "951": ["n07749582", "lemon"],
390
+ "952": ["n07753113", "fig"], "953": ["n07753275", "pineapple"], "954": ["n07753592", "banana"],
391
+ "955": ["n07754684", "jackfruit"], "956": ["n07760859", "custard_apple"],
392
+ "957": ["n07768694", "pomegranate"], "958": ["n07802026", "hay"], "959": ["n07831146", "carbonara"],
393
+ "960": ["n07836838", "chocolate_sauce"], "961": ["n07860988", "dough"],
394
+ "962": ["n07871810", "meat_loaf"], "963": ["n07873807", "pizza"], "964": ["n07875152", "potpie"],
395
+ "965": ["n07880968", "burrito"], "966": ["n07892512", "red_wine"], "967": ["n07920052", "espresso"],
396
+ "968": ["n07930864", "cup"], "969": ["n07932039", "eggnog"], "970": ["n09193705", "alp"],
397
+ "971": ["n09229709", "bubble"], "972": ["n09246464", "cliff"], "973": ["n09256479", "coral_reef"],
398
+ "974": ["n09288635", "geyser"], "975": ["n09332890", "lakeside"], "976": ["n09399592", "promontory"],
399
+ "977": ["n09421951", "sandbar"], "978": ["n09428293", "seashore"], "979": ["n09468604", "valley"],
400
+ "980": ["n09472597", "volcano"], "981": ["n09835506", "ballplayer"], "982": ["n10148035", "groom"],
401
+ "983": ["n10565667", "scuba_diver"], "984": ["n11879895", "rapeseed"], "985": ["n11939491", "daisy"],
402
+ "986": ["n12057211", "yellow_lady's_slipper"], "987": ["n12144580", "corn"],
403
+ "988": ["n12267677", "acorn"], "989": ["n12620546", "hip"], "990": ["n12768682", "buckeye"],
404
+ "991": ["n12985857", "coral_fungus"], "992": ["n12998815", "agaric"], "993": ["n13037406", "gyromitra"],
405
+ "994": ["n13040303", "stinkhorn"], "995": ["n13044778", "earthstar"],
406
+ "996": ["n13052670", "hen-of-the-woods"], "997": ["n13054560", "bolete"], "998": ["n13133613", "ear"],
407
+ "999": ["n15075141", "toilet_tissue"]}
408
+
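IMAGENET_DIC maps each class index (stored as a string) to a [synset_id, human_readable_name] pair. Two illustrative lookups, not taken from the repo:

synset, name = IMAGENET_DIC["207"]                                   # ("n02099601", "golden_retriever")
idx = next(k for k, v in IMAGENET_DIC.items() if v[1] == "tiger")    # "292"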
imgs/img1.jpg ADDED
losses/clip_loss.py ADDED
@@ -0,0 +1,299 @@
1
+ import torch
2
+ import torchvision.transforms as transforms
3
+ import numpy as np
4
+
5
+ import clip
6
+ from PIL import Image
7
+
8
+ from utils.text_templates import imagenet_templates, part_templates, imagenet_templates_small
9
+
10
+
11
+ class DirectionLoss(torch.nn.Module):
12
+
13
+ def __init__(self, loss_type='mse'):
14
+ super(DirectionLoss, self).__init__()
15
+
16
+ self.loss_type = loss_type
17
+
18
+ self.loss_func = {
19
+ 'mse': torch.nn.MSELoss,
20
+ 'cosine': torch.nn.CosineSimilarity,
21
+ 'mae': torch.nn.L1Loss
22
+ }[loss_type]()
23
+
24
+ def forward(self, x, y):
25
+ if self.loss_type == "cosine":
26
+ return 1. - self.loss_func(x, y)
27
+
28
+ return self.loss_func(x, y)
29
+
30
+ class CLIPLoss(torch.nn.Module):
31
+ def __init__(self, device, lambda_direction=1., lambda_patch=0., lambda_global=0., lambda_manifold=0., lambda_texture=0., patch_loss_type='mae', direction_loss_type='cosine', clip_model='ViT-B/16'):
32
+ super(CLIPLoss, self).__init__()
33
+
34
+ self.device = device
35
+ self.model, clip_preprocess = clip.load(clip_model, device=self.device)
36
+
37
+ self.clip_preprocess = clip_preprocess
38
+
39
+ self.preprocess = transforms.Compose([transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])] + # Un-normalize from [-1.0, 1.0] (GAN output) to [0, 1].
40
+ clip_preprocess.transforms[:2] + # to match CLIP input scale assumptions
41
+ clip_preprocess.transforms[4:]) # + skip convert PIL to tensor
42
+
43
+ self.target_direction = None
44
+ self.patch_text_directions = None
45
+
46
+ self.patch_loss = DirectionLoss(patch_loss_type)
47
+ self.direction_loss = DirectionLoss(direction_loss_type)
48
+ self.patch_direction_loss = torch.nn.CosineSimilarity(dim=2)
49
+
50
+ self.lambda_global = lambda_global
51
+ self.lambda_patch = lambda_patch
52
+ self.lambda_direction = lambda_direction
53
+ self.lambda_manifold = lambda_manifold
54
+ self.lambda_texture = lambda_texture
55
+
56
+ self.src_text_features = None
57
+ self.target_text_features = None
58
+ self.angle_loss = torch.nn.L1Loss()
59
+
60
+ self.model_cnn, preprocess_cnn = clip.load("RN50", device=self.device)
61
+ self.preprocess_cnn = transforms.Compose([transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])] + # Un-normalize from [-1.0, 1.0] (GAN output) to [0, 1].
62
+ preprocess_cnn.transforms[:2] + # to match CLIP input scale assumptions
63
+ preprocess_cnn.transforms[4:]) # + skip convert PIL to tensor
64
+
65
+ self.texture_loss = torch.nn.MSELoss()
66
+
67
+ def tokenize(self, strings: list):
68
+ return clip.tokenize(strings).to(self.device)
69
+
70
+ def encode_text(self, tokens: list) -> torch.Tensor:
71
+ return self.model.encode_text(tokens)
72
+
73
+ def encode_images(self, images: torch.Tensor) -> torch.Tensor:
74
+ images = self.preprocess(images).to(self.device)
75
+ return self.model.encode_image(images)
76
+
77
+ def encode_images_with_cnn(self, images: torch.Tensor) -> torch.Tensor:
78
+ images = self.preprocess_cnn(images).to(self.device)
79
+ return self.model_cnn.encode_image(images)
80
+
81
+ def distance_with_templates(self, img: torch.Tensor, class_str: str, templates=imagenet_templates) -> torch.Tensor:
82
+
83
+ text_features = self.get_text_features(class_str, templates)
84
+ image_features = self.get_image_features(img)
85
+
86
+ similarity = image_features @ text_features.T
87
+
88
+ return 1. - similarity
89
+
90
+ def get_text_features(self, class_str: str, templates=imagenet_templates, norm: bool = True) -> torch.Tensor:
91
+ template_text = self.compose_text_with_templates(class_str, templates)
92
+
93
+ tokens = clip.tokenize(template_text).to(self.device)
94
+
95
+ text_features = self.encode_text(tokens).detach()
96
+
97
+ if norm:
98
+ text_features /= text_features.norm(dim=-1, keepdim=True)
99
+
100
+ return text_features
101
+
102
+ def get_image_features(self, img: torch.Tensor, norm: bool = True) -> torch.Tensor:
103
+ image_features = self.encode_images(img)
104
+
105
+ if norm:
106
+ image_features /= image_features.clone().norm(dim=-1, keepdim=True)
107
+
108
+ return image_features
109
+
110
+ def compute_text_direction(self, source_class: str, target_class: str) -> torch.Tensor:
111
+ source_features = self.get_text_features(source_class)
112
+ target_features = self.get_text_features(target_class)
113
+
114
+ text_direction = (target_features - source_features).mean(axis=0, keepdim=True)
115
+ text_direction /= text_direction.norm(dim=-1, keepdim=True)
116
+
117
+ return text_direction
118
+
119
+ def compute_img2img_direction(self, source_images: torch.Tensor, target_images: list) -> torch.Tensor:
120
+ with torch.no_grad():
121
+
122
+ src_encoding = self.get_image_features(source_images)
123
+ src_encoding = src_encoding.mean(dim=0, keepdim=True)
124
+
125
+ target_encodings = []
126
+ for target_img in target_images:
127
+ preprocessed = self.clip_preprocess(Image.open(target_img)).unsqueeze(0).to(self.device)
128
+
129
+ encoding = self.model.encode_image(preprocessed)
130
+ encoding /= encoding.norm(dim=-1, keepdim=True)
131
+
132
+ target_encodings.append(encoding)
133
+
134
+ target_encoding = torch.cat(target_encodings, axis=0)
135
+ target_encoding = target_encoding.mean(dim=0, keepdim=True)
136
+
137
+ direction = target_encoding - src_encoding
138
+ direction /= direction.norm(dim=-1, keepdim=True)
139
+
140
+ return direction
141
+
142
+ def set_text_features(self, source_class: str, target_class: str) -> None:
143
+ source_features = self.get_text_features(source_class).mean(axis=0, keepdim=True)
144
+ self.src_text_features = source_features / source_features.norm(dim=-1, keepdim=True)
145
+
146
+ target_features = self.get_text_features(target_class).mean(axis=0, keepdim=True)
147
+ self.target_text_features = target_features / target_features.norm(dim=-1, keepdim=True)
148
+
149
+ def clip_angle_loss(self, src_img: torch.Tensor, source_class: str, target_img: torch.Tensor, target_class: str) -> torch.Tensor:
150
+ if self.src_text_features is None:
151
+ self.set_text_features(source_class, target_class)
152
+
153
+ cos_text_angle = self.target_text_features @ self.src_text_features.T
154
+ text_angle = torch.acos(cos_text_angle)
155
+
156
+ src_img_features = self.get_image_features(src_img).unsqueeze(2)
157
+ target_img_features = self.get_image_features(target_img).unsqueeze(1)
158
+
159
+ cos_img_angle = torch.clamp(target_img_features @ src_img_features, min=-1.0, max=1.0)
160
+ img_angle = torch.acos(cos_img_angle)
161
+
162
+ text_angle = text_angle.unsqueeze(0).repeat(img_angle.size()[0], 1, 1)
163
+ cos_text_angle = cos_text_angle.unsqueeze(0).repeat(img_angle.size()[0], 1, 1)
164
+
165
+ return self.angle_loss(cos_img_angle, cos_text_angle)
166
+
167
+ def compose_text_with_templates(self, text: str, templates=imagenet_templates) -> list:
168
+ return [template.format(text) for template in templates]
169
+
170
+ def clip_directional_loss(self, src_img: torch.Tensor, source_class: str, target_img: torch.Tensor, target_class: str) -> torch.Tensor:
171
+
172
+ if self.target_direction is None:
173
+ self.target_direction = self.compute_text_direction(source_class, target_class)
174
+
175
+ src_encoding = self.get_image_features(src_img)
176
+ target_encoding = self.get_image_features(target_img)
177
+
178
+ edit_direction = (target_encoding - src_encoding)
179
+ edit_direction /= (edit_direction.clone().norm(dim=-1, keepdim=True) + 1e-7)
180
+ return self.direction_loss(edit_direction, self.target_direction).mean()
181
+
182
+ def global_clip_loss(self, img: torch.Tensor, text) -> torch.Tensor:
183
+ if not isinstance(text, list):
184
+ text = [text]
185
+
186
+ tokens = clip.tokenize(text).to(self.device)
187
+ image = self.preprocess(img)
188
+
189
+ logits_per_image, _ = self.model(image, tokens)
190
+
191
+ return (1. - logits_per_image / 100).mean()
192
+
193
+ def random_patch_centers(self, img_shape, num_patches, size):
194
+ batch_size, channels, height, width = img_shape
195
+
196
+ half_size = size // 2
197
+ patch_centers = np.concatenate([np.random.randint(half_size, width - half_size, size=(batch_size * num_patches, 1)),
198
+ np.random.randint(half_size, height - half_size, size=(batch_size * num_patches, 1))], axis=1)
199
+
200
+ return patch_centers
201
+
202
+ def generate_patches(self, img: torch.Tensor, patch_centers, size):
203
+ batch_size = img.shape[0]
204
+ num_patches = len(patch_centers) // batch_size
205
+ half_size = size // 2
206
+
207
+ patches = []
208
+
209
+ for batch_idx in range(batch_size):
210
+ for patch_idx in range(num_patches):
211
+
212
+ center_x = patch_centers[batch_idx * num_patches + patch_idx][0]
213
+ center_y = patch_centers[batch_idx * num_patches + patch_idx][1]
214
+
215
+ patch = img[batch_idx:batch_idx+1, :, center_y - half_size:center_y + half_size, center_x - half_size:center_x + half_size]
216
+
217
+ patches.append(patch)
218
+
219
+ patches = torch.cat(patches, axis=0)
220
+
221
+ return patches
222
+
223
+ def patch_scores(self, img: torch.Tensor, class_str: str, patch_centers, patch_size: int) -> torch.Tensor:
224
+
225
+ parts = self.compose_text_with_templates(class_str, part_templates)
226
+ tokens = clip.tokenize(parts).to(self.device)
227
+ text_features = self.encode_text(tokens).detach()
228
+
229
+ patches = self.generate_patches(img, patch_centers, patch_size)
230
+ image_features = self.get_image_features(patches)
231
+
232
+ similarity = image_features @ text_features.T
233
+
234
+ return similarity
235
+
236
+ def clip_patch_similarity(self, src_img: torch.Tensor, source_class: str, target_img: torch.Tensor, target_class: str) -> torch.Tensor:
237
+ patch_size = 196 #TODO remove magic number
238
+
239
+ patch_centers = self.random_patch_centers(src_img.shape, 4, patch_size) #TODO remove magic number
240
+
241
+ src_scores = self.patch_scores(src_img, source_class, patch_centers, patch_size)
242
+ target_scores = self.patch_scores(target_img, target_class, patch_centers, patch_size)
243
+
244
+ return self.patch_loss(src_scores, target_scores)
245
+
246
+ def patch_directional_loss(self, src_img: torch.Tensor, source_class: str, target_img: torch.Tensor, target_class: str) -> torch.Tensor:
247
+
248
+ if self.patch_text_directions is None:
249
+ src_part_classes = self.compose_text_with_templates(source_class, part_templates)
250
+ target_part_classes = self.compose_text_with_templates(target_class, part_templates)
251
+
252
+ parts_classes = list(zip(src_part_classes, target_part_classes))
253
+
254
+ self.patch_text_directions = torch.cat([self.compute_text_direction(pair[0], pair[1]) for pair in parts_classes], dim=0)
255
+
256
+ patch_size = 510 # TODO remove magic numbers
257
+
258
+ patch_centers = self.random_patch_centers(src_img.shape, 1, patch_size)
259
+
260
+ patches = self.generate_patches(src_img, patch_centers, patch_size)
261
+ src_features = self.get_image_features(patches)
262
+
263
+ patches = self.generate_patches(target_img, patch_centers, patch_size)
264
+ target_features = self.get_image_features(patches)
265
+
266
+ edit_direction = (target_features - src_features)
267
+ edit_direction /= edit_direction.clone().norm(dim=-1, keepdim=True)
268
+
269
+ cosine_dists = 1. - self.patch_direction_loss(edit_direction.unsqueeze(1), self.patch_text_directions.unsqueeze(0))
270
+
271
+ patch_class_scores = cosine_dists * (edit_direction @ self.patch_text_directions.T).softmax(dim=-1)
272
+
273
+ return patch_class_scores.mean()
274
+
275
+ def cnn_feature_loss(self, src_img: torch.Tensor, target_img: torch.Tensor) -> torch.Tensor:
276
+ src_features = self.encode_images_with_cnn(src_img)
277
+ target_features = self.encode_images_with_cnn(target_img)
278
+
279
+ return self.texture_loss(src_features, target_features)
280
+
281
+ def forward(self, src_img: torch.Tensor, source_class: str, target_img: torch.Tensor, target_class: str, texture_image: torch.Tensor = None):
282
+ clip_loss = 0.0
283
+
284
+ if self.lambda_global:
285
+ clip_loss += self.lambda_global * self.global_clip_loss(target_img, [f"a {target_class}"])
286
+
287
+ if self.lambda_patch:
288
+ clip_loss += self.lambda_patch * self.patch_directional_loss(src_img, source_class, target_img, target_class)
289
+
290
+ if self.lambda_direction:
291
+ clip_loss += self.lambda_direction * self.clip_directional_loss(src_img, source_class, target_img, target_class)
292
+
293
+ if self.lambda_manifold:
294
+ clip_loss += self.lambda_manifold * self.clip_angle_loss(src_img, source_class, target_img, target_class)
295
+
296
+ if self.lambda_texture and (texture_image is not None):
297
+ clip_loss += self.lambda_texture * self.cnn_feature_loss(texture_image, target_img)
298
+
299
+ return clip_loss
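A minimal sketch of how CLIPLoss might be driven (an assumption about the calling code, not taken from this commit). Images are expected in [-1, 1] with shape (N, 3, H, W) on the same device as the loss module; with the default weights only the directional term is active, and constructing the loss loads both the ViT-B/16 and RN50 CLIP checkpoints.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_loss_fn = CLIPLoss(device, lambda_direction=1.0, clip_model="ViT-B/16")

src = torch.rand(1, 3, 256, 256, device=device) * 2 - 1                             # source image, placeholder
edited = (torch.rand(1, 3, 256, 256, device=device) * 2 - 1).requires_grad_(True)   # edited image, placeholder

loss = clip_loss_fn(src, "face", edited, "smiling face")   # directional CLIP loss only
loss.backward()                                            # gradients flow into `edited`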
losses/id_loss.py ADDED
@@ -0,0 +1,35 @@
1
+ import torch
2
+ from torch import nn
3
+ from configs.paths_config import MODEL_PATHS
4
+ from models.insight_face.model_irse import Backbone, MobileFaceNet
5
+
6
+
7
+ class IDLoss(nn.Module):
8
+ def __init__(self, use_mobile_id=False):
9
+ super(IDLoss, self).__init__()
10
+ print('Loading ResNet ArcFace')
11
+ self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
12
+ self.facenet.load_state_dict(torch.load(MODEL_PATHS['ir_se50']))
13
+
14
+ self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
15
+ self.facenet.eval()
16
+
17
+ def extract_feats(self, x):
18
+ x = x[:, :, 35:223, 32:220] # Crop interesting region
19
+ x = self.face_pool(x)
20
+ x_feats = self.facenet(x)
21
+ return x_feats
22
+
23
+ def forward(self, x, x_hat):
24
+ n_samples = x.shape[0]
25
+ x_feats = self.extract_feats(x)
26
+ x_feats = x_feats.detach()
27
+
28
+ x_hat_feats = self.extract_feats(x_hat)
29
+ losses = []
30
+ for i in range(n_samples):
31
+ loss_sample = 1 - x_hat_feats[i].dot(x_feats[i])
32
+ losses.append(loss_sample.unsqueeze(0))
33
+
34
+ losses = torch.cat(losses, dim=0)
35
+ return losses
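IDLoss returns one value per sample, 1 minus the cosine similarity of ArcFace embeddings, so identity preservation is encouraged by minimizing the mean. A minimal sketch, assuming the ir_se50 checkpoint referenced by MODEL_PATHS is available locally:

    # Sketch; requires the ArcFace weights at MODEL_PATHS['ir_se50'].
    import torch

    id_loss_fn = IDLoss()
    x = torch.randn(2, 3, 256, 256)        # original faces
    x_hat = torch.randn(2, 3, 256, 256)    # edited faces
    per_sample = id_loss_fn(x, x_hat)      # shape [2]
    loss = per_sample.mean()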
main.py ADDED
@@ -0,0 +1,275 @@
1
+ import argparse
2
+ import traceback
3
+ import logging
4
+ import yaml
5
+ import sys
6
+ import os
7
+ import torch
8
+ import numpy as np
9
+
10
+ from boundarydiffusion import BoundaryDiffusion
11
+ from configs.paths_config import HYBRID_MODEL_PATHS
12
+
13
+ def parse_args_and_config():
14
+ parser = argparse.ArgumentParser(description=globals()['__doc__'])
15
+
16
+ # Mode
17
+ parser.add_argument('--radius', action='store_true')
18
+ parser.add_argument('--unconditional', action='store_true')
19
+ parser.add_argument('--boundary_search', action='store_true')
20
+ parser.add_argument('--diffusion_hyperplane', action='store_true')
21
+ parser.add_argument('--clip_finetune', action='store_true')
22
+ parser.add_argument('--clip_latent_optim', action='store_true')
23
+ parser.add_argument('--edit_images_from_dataset', action='store_true')
24
+ parser.add_argument('--edit_one_image', action='store_true')
25
+ parser.add_argument('--unseen2unseen', action='store_true')
26
+ parser.add_argument('--clip_finetune_eff', action='store_true')
27
+ parser.add_argument('--edit_one_image_eff', action='store_true')
28
+ parser.add_argument('--edit_image_boundary', action='store_true')
29
+
30
+ # Default
31
+ parser.add_argument('--config', type=str, required=True, help='Path to the config file')
32
+ parser.add_argument('--seed', type=int, default=1006, help='Random seed')
33
+ parser.add_argument('--exp', type=str, default='./runs/', help='Path for saving running related data.')
34
+ parser.add_argument('--comment', type=str, default='', help='A string for experiment comment')
35
+ parser.add_argument('--verbose', type=str, default='info', help='Verbose level: info | debug | warning | critical')
36
+ parser.add_argument('--ni', type=int, default=1, help="No interaction. Suitable for Slurm Job launcher")
37
+ parser.add_argument('--align_face', type=int, default=1, help='align face or not')
38
+
39
+ # Text
40
+ parser.add_argument('--edit_attr', type=str, default=None, help='Attribute to edit defined in ./utils/text_dic.py')
41
+ parser.add_argument('--src_txts', type=str, action='append', help='Source text e.g. Face')
42
+ parser.add_argument('--trg_txts', type=str, action='append', help='Target text e.g. Angry Face')
43
+ parser.add_argument('--target_class_num', type=str, default=None)
44
+
45
+ # Sampling
46
+ parser.add_argument('--t_0', type=int, default=400, help='Return step in [0, 1000)')
47
+ parser.add_argument('--n_inv_step', type=int, default=40, help='# of steps in the generative process for inversion')
48
+ parser.add_argument('--n_train_step', type=int, default=6, help='# of steps in the generative process for training')
49
+ parser.add_argument('--n_test_step', type=int, default=40, help='# of steps in the generative process for testing')
50
+ parser.add_argument('--sample_type', type=str, default='ddim', help='ddpm for Markovian sampling, ddim for non-Markovian sampling')
51
+ parser.add_argument('--eta', type=float, default=0.0, help='Controls the variance of the generative process')
52
+ parser.add_argument('--start_distance', type=float, default=-150.0, help='Starting distance of the editing space')
53
+ parser.add_argument('--end_distance', type=float, default=150.0, help='Ending distance of the editing space')
54
+ parser.add_argument('--edit_img_number', type=int, default=20, help='Number of editing images')
55
+
56
+ # Train & Test
57
+ parser.add_argument('--do_train', type=int, default=1, help='Whether to train or not during CLIP finetuning')
58
+ parser.add_argument('--do_test', type=int, default=1, help='Whether to test or not during CLIP finetuning')
59
+ parser.add_argument('--save_train_image', type=int, default=1, help='Whether to save training results during CLIP finetuning')
60
+ parser.add_argument('--bs_train', type=int, default=1, help='Training batch size during CLIP finetuning')
61
+ parser.add_argument('--bs_test', type=int, default=1, help='Test batch size during CLIP finetuning')
62
+ parser.add_argument('--n_precomp_img', type=int, default=100, help='# of images to precompute latents')
63
+ parser.add_argument('--n_train_img', type=int, default=50, help='# of training images')
64
+ parser.add_argument('--n_test_img', type=int, default=10, help='# of test images')
65
+ parser.add_argument('--model_path', type=str, default=None, help='Test model path')
66
+ parser.add_argument('--img_path', type=str, default=None, help='Image path to test')
67
+ parser.add_argument('--deterministic_inv', type=int, default=1, help='Whether to use deterministic inversion during inference')
68
+ parser.add_argument('--hybrid_noise', type=int, default=0, help='Whether to change multiple attributes by mixing multiple models')
69
+ parser.add_argument('--model_ratio', type=float, default=1, help='Degree of change, noise ratio from original and finetuned model.')
70
+
71
+
72
+ # Loss & Optimization
73
+ parser.add_argument('--clip_loss_w', type=int, default=0, help='Weights of CLIP loss')
74
+ parser.add_argument('--l1_loss_w', type=float, default=0, help='Weights of L1 loss')
75
+ parser.add_argument('--id_loss_w', type=float, default=0, help='Weights of ID loss')
76
+ parser.add_argument('--clip_model_name', type=str, default='ViT-B/16', help='ViT-B/16, ViT-B/32, RN50x16 etc')
77
+ parser.add_argument('--lr_clip_finetune', type=float, default=2e-6, help='Initial learning rate for finetuning')
78
+ parser.add_argument('--lr_clip_lat_opt', type=float, default=2e-2, help='Initial learning rate for latent optim')
79
+ parser.add_argument('--n_iter', type=int, default=1, help='# of iterations of a generative process with `n_train_img` images')
80
+ parser.add_argument('--scheduler', type=int, default=1, help='Whether to increase the learning rate')
81
+ parser.add_argument('--sch_gamma', type=float, default=1.3, help='Scheduler gamma')
82
+
83
+ args = parser.parse_args()
84
+
85
+ # parse config file
86
+ with open(os.path.join('configs', args.config), 'r') as f:
87
+ config = yaml.safe_load(f)
88
+ new_config = dict2namespace(config)
89
+
90
+ if args.diffusion_hyperplane:
91
+ if args.edit_attr is not None:
92
+ args.exp = args.exp + f'_SP_{new_config.data.category}_{args.edit_attr}_t{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}_id{args.id_loss_w}_l1{args.l1_loss_w}_lr{args.lr_clip_finetune}'
93
+ else:
94
+ args.exp = args.exp + f'_SP_{new_config.data.category}_{args.trg_txts}_t{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}_id{args.id_loss_w}_l1{args.l1_loss_w}_lr{args.lr_clip_finetune}'
95
+ elif args.radius:
96
+ if args.edit_attr is not None:
97
+ args.exp = args.exp + f'_R_{new_config.data.category}_{args.edit_attr}_t{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}_id{args.id_loss_w}_l1{args.l1_loss_w}_lr{args.lr_clip_finetune}'
98
+ else:
99
+ args.exp = args.exp + f'_R_{new_config.data.category}_{args.trg_txts}_t{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}_id{args.id_loss_w}_l1{args.l1_loss_w}_lr{args.lr_clip_finetune}'
100
+ elif args.unconditional:
101
+ if args.edit_attr is not None:
102
+ args.exp = args.exp + f'_UN_{new_config.data.category}_{args.edit_attr}_t{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}_id{args.id_loss_w}_l1{args.l1_loss_w}_lr{args.lr_clip_finetune}'
103
+ else:
104
+ args.exp = args.exp + f'_UN_{new_config.data.category}_{args.trg_txts}_t{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}_id{args.id_loss_w}_l1{args.l1_loss_w}_lr{args.lr_clip_finetune}'
105
+ elif args.boundary_search:
106
+ if args.edit_attr is not None:
107
+ args.exp = args.exp + f'_BCLIP_{new_config.data.category}_{args.edit_attr}_t{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}_id{args.id_loss_w}_l1{args.l1_loss_w}_lr{args.lr_clip_finetune}'
108
+ else:
109
+ args.exp = args.exp + f'_BCLIP_{new_config.data.category}_{args.trg_txts}_t{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}_id{args.id_loss_w}_l1{args.l1_loss_w}_lr{args.lr_clip_finetune}'
110
+ elif args.clip_finetune or args.clip_finetune_eff:
111
+ if args.edit_attr is not None:
112
+ args.exp = args.exp + f'_FT_{new_config.data.category}_{args.edit_attr}_t{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}_id{args.id_loss_w}_l1{args.l1_loss_w}_lr{args.lr_clip_finetune}'
113
+ else:
114
+ args.exp = args.exp + f'_FT_{new_config.data.category}_{args.trg_txts}_t{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}_id{args.id_loss_w}_l1{args.l1_loss_w}_lr{args.lr_clip_finetune}'
115
+ elif args.clip_latent_optim:
116
+ if args.edit_attr is not None:
117
+ args.exp = args.exp + f'_LO_{new_config.data.category}_{args.img_path.split("/")[-1].split(".")[0]}_{args.edit_attr}_t{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}_id{args.id_loss_w}_l1{args.l1_loss_w}_lr{args.lr_clip_lat_opt}'
118
+ else:
119
+ args.exp = args.exp + f'_LO_{new_config.data.category}_{args.img_path.split("/")[-1].split(".")[0]}_{args.trg_txts}_t{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}_id{args.id_loss_w}_l1{args.l1_loss_w}_lr{args.lr_clip_lat_opt}'
120
+ elif args.edit_images_from_dataset:
121
+ if args.model_path:
122
+ args.exp = args.exp + f'_ED_{new_config.data.category}_t{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}_{os.path.split(args.model_path)[-1].replace(".pth","")}'
123
+ elif args.hybrid_noise:
124
+ hb_str = '_'
125
+ for i, model_name in enumerate(HYBRID_MODEL_PATHS):
126
+ hb_str = hb_str + model_name.split('_')[1]
127
+ if i != len(HYBRID_MODEL_PATHS) - 1:
128
+ hb_str = hb_str + '_'
129
+ args.exp = args.exp + f'_ED_{new_config.data.category}_t{args.t_0}_ninv{args.n_train_step}_ngen{args.n_train_step}' + hb_str
130
+ else:
131
+ args.exp = args.exp + f'_ED_{new_config.data.category}_t{args.t_0}_ninv{args.n_train_step}_ngen{args.n_train_step}_orig'
132
+
133
+ elif args.edit_image_boundary:
134
+ if args.model_path:
135
+ args.exp = args.exp + f'_E1_t{args.t_0}_{new_config.data.category}_{args.img_path.split("/")[-1].split(".")[0]}_t{args.t_0}_ninv{args.n_inv_step}_{os.path.split(args.model_path)[-1].replace(".pth", "")}'
136
+ elif args.hybrid_noise:
137
+ hb_str = '_'
138
+ for i, model_name in enumerate(HYBRID_MODEL_PATHS):
139
+ hb_str = hb_str + model_name.split('_')[1]
140
+ if i != len(HYBRID_MODEL_PATHS) - 1:
141
+ hb_str = hb_str + '_'
142
+ args.exp = args.exp + f'_E1_{new_config.data.category}_{args.img_path.split("/")[-1].split(".")[0]}_t{args.t_0}_ninv{args.n_train_step}' + hb_str
143
+ else:
144
+ args.exp = args.exp + f'_E1_{new_config.data.category}_{args.img_path.split("/")[-1].split(".")[0]}_t{args.t_0}_ninv{args.n_train_step}_orig'
145
+
146
+
147
+ elif args.edit_one_image:
148
+ if args.model_path:
149
+ args.exp = args.exp + f'_E1_t{args.t_0}_{new_config.data.category}_{args.img_path.split("/")[-1].split(".")[0]}_t{args.t_0}_ninv{args.n_inv_step}_{os.path.split(args.model_path)[-1].replace(".pth", "")}'
150
+ elif args.hybrid_noise:
151
+ hb_str = '_'
152
+ for i, model_name in enumerate(HYBRID_MODEL_PATHS):
153
+ hb_str = hb_str + model_name.split('_')[1]
154
+ if i != len(HYBRID_MODEL_PATHS) - 1:
155
+ hb_str = hb_str + '_'
156
+ args.exp = args.exp + f'_E1_{new_config.data.category}_{args.img_path.split("/")[-1].split(".")[0]}_t{args.t_0}_ninv{args.n_train_step}' + hb_str
157
+ else:
158
+ args.exp = args.exp + f'_E1_{new_config.data.category}_{args.img_path.split("/")[-1].split(".")[0]}_t{args.t_0}_ninv{args.n_train_step}_orig'
159
+
160
+ elif args.unseen2unseen:
161
+ if args.model_path:
162
+ args.exp = args.exp + f'_U2U_t{args.t_0}_{new_config.data.category}_{args.img_path.split("/")[-1].split(".")[0]}_t{args.t_0}_ninv{args.n_inv_step}_ngen{args.n_train_step}_{os.path.split(args.model_path)[-1].replace(".pth", "")}'
163
+ elif args.hybrid_noise:
164
+ hb_str = '_'
165
+ for i, model_name in enumerate(HYBRID_MODEL_PATHS):
166
+ hb_str = hb_str + model_name.split('_')[1]
167
+ if i != len(HYBRID_MODEL_PATHS) - 1:
168
+ hb_str = hb_str + '_'
169
+ args.exp = args.exp + f'_U2U_{new_config.data.category}_{args.img_path.split("/")[-1].split(".")[0]}_t{args.t_0}_ninv{args.n_train_step}_ngen{args.n_train_step}' + hb_str
170
+ else:
171
+ args.exp = args.exp + f'_U2U_{new_config.data.category}_{args.img_path.split("/")[-1].split(".")[0]}_t{args.t_0}_ninv{args.n_train_step}_ngen{args.n_train_step}_orig'
172
+
173
+ elif getattr(args, 'recon_exp', False):  # 'recon_exp' is not declared as an argparse flag above; guard to avoid AttributeError
174
+ args.exp = args.exp + f'_REC_{new_config.data.category}_{args.img_path.split("/")[-1].split(".")[0]}_t{args.t_0}_ninv{args.n_train_step}'
175
+ elif getattr(args, 'find_best_image', False):  # likewise not declared above
176
+ args.exp = args.exp + f'_FOpt_{new_config.data.category}_{args.trg_txts[0]}_t{args.t_0}_ninv{args.n_train_step}'
177
+
178
+
179
+ level = getattr(logging, args.verbose.upper(), None)
180
+ if not isinstance(level, int):
181
+ raise ValueError('level {} not supported'.format(args.verbose))
182
+
183
+ handler1 = logging.StreamHandler()
184
+ formatter = logging.Formatter('%(levelname)s - %(filename)s - %(asctime)s - %(message)s')
185
+ handler1.setFormatter(formatter)
186
+ logger = logging.getLogger()
187
+ logger.addHandler(handler1)
188
+ logger.setLevel(level)
189
+
190
+ os.makedirs(args.exp, exist_ok=True)
191
+ os.makedirs('checkpoint', exist_ok=True)
192
+ os.makedirs('precomputed', exist_ok=True)
193
+ os.makedirs('runs', exist_ok=True)
194
+ os.makedirs(args.exp, exist_ok=True)
195
+ args.image_folder = os.path.join(args.exp, 'image_samples')
196
+ if not os.path.exists(args.image_folder):
197
+ os.makedirs(args.image_folder)
198
+ else:
199
+ overwrite = False
200
+ if args.ni:
201
+ overwrite = True
202
+ else:
203
+ response = input("Image folder already exists. Overwrite? (Y/N)")
204
+ if response.upper() == 'Y':
205
+ overwrite = True
206
+
207
+ if overwrite:
208
+ # shutil.rmtree(args.image_folder)
209
+ os.makedirs(args.image_folder, exist_ok=True)
210
+ else:
211
+ print("Output image folder exists. Program halted.")
212
+ sys.exit(0)
213
+
214
+ # add device
215
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
216
+ logging.info("Using device: {}".format(device))
217
+ new_config.device = device
218
+
219
+ # set random seed
220
+ torch.manual_seed(args.seed)
221
+ np.random.seed(args.seed)
222
+ if torch.cuda.is_available():
223
+ torch.cuda.manual_seed_all(args.seed)
224
+
225
+ torch.backends.cudnn.benchmark = True
226
+
227
+ return args, new_config
228
+
229
+
230
+ def dict2namespace(config):
231
+ namespace = argparse.Namespace()
232
+ for key, value in config.items():
233
+ if isinstance(value, dict):
234
+ new_value = dict2namespace(value)
235
+ else:
236
+ new_value = value
237
+ setattr(namespace, key, new_value)
238
+ return namespace
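# Usage sketch: dict2namespace converts the nested dict parsed from the YAML config
# into nested argparse.Namespace objects so values can be read by attribute access:
#   cfg = dict2namespace({'data': {'image_size': 256}})
#   cfg.data.image_size  # -> 256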
239
+
240
+
241
+ def main():
242
+ args, config = parse_args_and_config()
243
+ print(">" * 80)
244
+ logging.info("Exp instance id = {}".format(os.getpid()))
245
+ logging.info("Exp comment = {}".format(args.comment))
246
+ logging.info("Config =")
247
+ print("<" * 80)
248
+
249
+
250
+ runner = BoundaryDiffusion(args, config)
251
+ try:
252
+ if args.clip_finetune:
253
+ runner.clip_finetune()
254
+ elif args.radius:
255
+ runner.radius()
256
+ elif args.unconditional:
257
+ runner.unconditional()
258
+ elif args.diffusion_hyperplane:
259
+ runner.diffusion_hyperplane()
260
+ elif args.boundary_search:
261
+ runner.boundary_search()
262
+ elif args.edit_image_boundary:
263
+ runner.edit_image_boundary()
264
+ else:
265
+ print('Choose one mode!')
266
+ raise ValueError
267
+ except Exception:
268
+ logging.error(traceback.format_exc())
269
+
270
+
271
+ return 0
272
+
273
+
274
+ if __name__ == '__main__':
275
+ sys.exit(main())
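A typical invocation selects exactly one mode flag plus a config file; the sketch below uses only flags declared in parse_args_and_config(), while the config name and image path are placeholders:

    # Sketch; 'celeba.yml' and 'imgs/example.png' are placeholder paths.
    import sys
    sys.argv = ['main.py', '--edit_image_boundary',
                '--config', 'celeba.yml',
                '--exp', './runs/demo',
                '--t_0', '500', '--n_inv_step', '40', '--n_test_step', '40',
                '--img_path', 'imgs/example.png']
    args, config = parse_args_and_config()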
models/ddpm/diffusion.py ADDED
@@ -0,0 +1,348 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
+ def get_timestep_embedding(timesteps, embedding_dim):
7
+ """
8
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
9
+ From Fairseq.
10
+ Build sinusoidal embeddings.
11
+ This matches the implementation in tensor2tensor, but differs slightly
12
+ from the description in Section 3.5 of "Attention Is All You Need".
13
+ """
14
+ assert len(timesteps.shape) == 1
15
+
16
+ half_dim = embedding_dim // 2
17
+ emb = math.log(10000) / (half_dim - 1)
18
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
19
+ emb = emb.to(device=timesteps.device)
20
+ emb = timesteps.float()[:, None] * emb[None, :]
21
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
22
+ if embedding_dim % 2 == 1: # zero pad
23
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
24
+ return emb
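# Shape note: for N timesteps this returns an [N, embedding_dim] tensor of concatenated
# sin/cos features, e.g. get_timestep_embedding(torch.arange(4), 128) has shape (4, 128).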
25
+
26
+
27
+ def nonlinearity(x):
28
+ # swish
29
+ return x * torch.sigmoid(x)
30
+
31
+
32
+ def Normalize(in_channels):
33
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
34
+
35
+
36
+ class Upsample(nn.Module):
37
+ def __init__(self, in_channels, with_conv):
38
+ super().__init__()
39
+ self.with_conv = with_conv
40
+ if self.with_conv:
41
+ self.conv = torch.nn.Conv2d(in_channels,
42
+ in_channels,
43
+ kernel_size=3,
44
+ stride=1,
45
+ padding=1)
46
+
47
+ def forward(self, x):
48
+ x = torch.nn.functional.interpolate(
49
+ x, scale_factor=2.0, mode="nearest")
50
+ if self.with_conv:
51
+ x = self.conv(x)
52
+ return x
53
+
54
+
55
+ class Downsample(nn.Module):
56
+ def __init__(self, in_channels, with_conv):
57
+ super().__init__()
58
+ self.with_conv = with_conv
59
+ if self.with_conv:
60
+ # no asymmetric padding in torch conv, must do it ourselves
61
+ self.conv = torch.nn.Conv2d(in_channels,
62
+ in_channels,
63
+ kernel_size=3,
64
+ stride=2,
65
+ padding=0)
66
+
67
+ def forward(self, x):
68
+ if self.with_conv:
69
+ pad = (0, 1, 0, 1)
70
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
71
+ x = self.conv(x)
72
+ else:
73
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
74
+ return x
75
+
76
+
77
+ class ResnetBlock(nn.Module):
78
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
79
+ dropout, temb_channels=512):
80
+ super().__init__()
81
+ self.in_channels = in_channels
82
+ out_channels = in_channels if out_channels is None else out_channels
83
+ self.out_channels = out_channels
84
+ self.use_conv_shortcut = conv_shortcut
85
+
86
+ self.norm1 = Normalize(in_channels)
87
+ self.conv1 = torch.nn.Conv2d(in_channels,
88
+ out_channels,
89
+ kernel_size=3,
90
+ stride=1,
91
+ padding=1)
92
+ self.temb_proj = torch.nn.Linear(temb_channels,
93
+ out_channels)
94
+ self.norm2 = Normalize(out_channels)
95
+ self.dropout = torch.nn.Dropout(dropout)
96
+ self.conv2 = torch.nn.Conv2d(out_channels,
97
+ out_channels,
98
+ kernel_size=3,
99
+ stride=1,
100
+ padding=1)
101
+ if self.in_channels != self.out_channels:
102
+ if self.use_conv_shortcut:
103
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
104
+ out_channels,
105
+ kernel_size=3,
106
+ stride=1,
107
+ padding=1)
108
+ else:
109
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
110
+ out_channels,
111
+ kernel_size=1,
112
+ stride=1,
113
+ padding=0)
114
+
115
+ def forward(self, x, temb):
116
+ h = x
117
+ h = self.norm1(h)
118
+ h = nonlinearity(h)
119
+ h = self.conv1(h)
120
+
121
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
122
+
123
+ h = self.norm2(h)
124
+ h = nonlinearity(h)
125
+ h = self.dropout(h)
126
+ h = self.conv2(h)
127
+
128
+ if self.in_channels != self.out_channels:
129
+ if self.use_conv_shortcut:
130
+ x = self.conv_shortcut(x)
131
+ else:
132
+ x = self.nin_shortcut(x)
133
+
134
+ return x + h
135
+
136
+
137
+ class AttnBlock(nn.Module):
138
+ def __init__(self, in_channels):
139
+ super().__init__()
140
+ self.in_channels = in_channels
141
+
142
+ self.norm = Normalize(in_channels)
143
+ self.q = torch.nn.Conv2d(in_channels,
144
+ in_channels,
145
+ kernel_size=1,
146
+ stride=1,
147
+ padding=0)
148
+ self.k = torch.nn.Conv2d(in_channels,
149
+ in_channels,
150
+ kernel_size=1,
151
+ stride=1,
152
+ padding=0)
153
+ self.v = torch.nn.Conv2d(in_channels,
154
+ in_channels,
155
+ kernel_size=1,
156
+ stride=1,
157
+ padding=0)
158
+ self.proj_out = torch.nn.Conv2d(in_channels,
159
+ in_channels,
160
+ kernel_size=1,
161
+ stride=1,
162
+ padding=0)
163
+
164
+ def forward(self, x):
165
+ h_ = x
166
+ h_ = self.norm(h_)
167
+ q = self.q(h_)
168
+ k = self.k(h_)
169
+ v = self.v(h_)
170
+
171
+ # compute attention
172
+ b, c, h, w = q.shape
173
+ q = q.reshape(b, c, h * w)
174
+ q = q.permute(0, 2, 1) # b,hw,c
175
+ k = k.reshape(b, c, h * w) # b,c,hw
176
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
177
+ w_ = w_ * (int(c) ** (-0.5))
178
+ w_ = torch.nn.functional.softmax(w_, dim=2)
179
+
180
+ # attend to values
181
+ v = v.reshape(b, c, h * w)
182
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
183
+ # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
184
+ h_ = torch.bmm(v, w_)
185
+ h_ = h_.reshape(b, c, h, w)
186
+
187
+ h_ = self.proj_out(h_)
188
+
189
+ return x + h_
190
+
191
+
192
+ class DDPM(nn.Module):
193
+ def __init__(self, config):
194
+ super().__init__()
195
+ self.config = config
196
+ ch, out_ch, ch_mult = config.model.ch, config.model.out_ch, tuple(config.model.ch_mult)
197
+ num_res_blocks = config.model.num_res_blocks
198
+ attn_resolutions = config.model.attn_resolutions
199
+ dropout = config.model.dropout
200
+ in_channels = config.model.in_channels
201
+ resolution = config.data.image_size
202
+ resamp_with_conv = config.model.resamp_with_conv
203
+
204
+ self.ch = ch
205
+ self.temb_ch = self.ch * 4
206
+ self.num_resolutions = len(ch_mult)
207
+ self.num_res_blocks = num_res_blocks
208
+ self.resolution = resolution
209
+ self.in_channels = in_channels
210
+
211
+ # timestep embedding
212
+ self.temb = nn.Module()
213
+ self.temb.dense = nn.ModuleList([
214
+ torch.nn.Linear(self.ch,
215
+ self.temb_ch),
216
+ torch.nn.Linear(self.temb_ch,
217
+ self.temb_ch),
218
+ ])
219
+
220
+ # downsampling
221
+ self.conv_in = torch.nn.Conv2d(in_channels,
222
+ self.ch,
223
+ kernel_size=3,
224
+ stride=1,
225
+ padding=1)
226
+
227
+ curr_res = resolution
228
+ in_ch_mult = (1,) + ch_mult
229
+ self.down = nn.ModuleList()
230
+ block_in = None
231
+ for i_level in range(self.num_resolutions):
232
+ block = nn.ModuleList()
233
+ attn = nn.ModuleList()
234
+ block_in = ch * in_ch_mult[i_level]
235
+ block_out = ch * ch_mult[i_level]
236
+ for i_block in range(self.num_res_blocks):
237
+ block.append(ResnetBlock(in_channels=block_in,
238
+ out_channels=block_out,
239
+ temb_channels=self.temb_ch,
240
+ dropout=dropout))
241
+ block_in = block_out
242
+ if curr_res in attn_resolutions:
243
+ attn.append(AttnBlock(block_in))
244
+ down = nn.Module()
245
+ down.block = block
246
+ down.attn = attn
247
+ if i_level != self.num_resolutions - 1:
248
+ down.downsample = Downsample(block_in, resamp_with_conv)
249
+ curr_res = curr_res // 2
250
+ self.down.append(down)
251
+
252
+ # middle
253
+ self.mid = nn.Module()
254
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
255
+ out_channels=block_in,
256
+ temb_channels=self.temb_ch,
257
+ dropout=dropout)
258
+ self.mid.attn_1 = AttnBlock(block_in)
259
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
260
+ out_channels=block_in,
261
+ temb_channels=self.temb_ch,
262
+ dropout=dropout)
263
+
264
+ # upsampling
265
+ self.up = nn.ModuleList()
266
+ for i_level in reversed(range(self.num_resolutions)):
267
+ block = nn.ModuleList()
268
+ attn = nn.ModuleList()
269
+ block_out = ch * ch_mult[i_level]
270
+ skip_in = ch * ch_mult[i_level]
271
+ for i_block in range(self.num_res_blocks + 1):
272
+ if i_block == self.num_res_blocks:
273
+ skip_in = ch * in_ch_mult[i_level]
274
+ block.append(ResnetBlock(in_channels=block_in + skip_in,
275
+ out_channels=block_out,
276
+ temb_channels=self.temb_ch,
277
+ dropout=dropout))
278
+ block_in = block_out
279
+ if curr_res in attn_resolutions:
280
+ attn.append(AttnBlock(block_in))
281
+ up = nn.Module()
282
+ up.block = block
283
+ up.attn = attn
284
+ if i_level != 0:
285
+ up.upsample = Upsample(block_in, resamp_with_conv)
286
+ curr_res = curr_res * 2
287
+ self.up.insert(0, up) # prepend to get consistent order
288
+
289
+ # end
290
+ self.norm_out = Normalize(block_in)
291
+ self.conv_out = torch.nn.Conv2d(block_in,
292
+ out_ch,
293
+ kernel_size=3,
294
+ stride=1,
295
+ padding=1)
296
+
297
+ def forward(self, x, t, edit_h=None):
298
+ assert x.shape[2] == x.shape[3] == self.resolution
299
+
300
+ # print("check input in U-NET:", x.size()) # [1,3,256,256]
301
+
302
+ # timestep embedding
303
+ temb = get_timestep_embedding(t, self.ch)
304
+ temb = self.temb.dense[0](temb)
305
+ temb = nonlinearity(temb)
306
+ temb = self.temb.dense[1](temb)
307
+
308
+ # downsampling
309
+ hs = [self.conv_in(x)]
310
+ for i_level in range(self.num_resolutions):
311
+ for i_block in range(self.num_res_blocks):
312
+ h = self.down[i_level].block[i_block](hs[-1], temb)
313
+ if len(self.down[i_level].attn) > 0:
314
+ h = self.down[i_level].attn[i_block](h)
315
+ hs.append(h)
316
+ if i_level != self.num_resolutions - 1:
317
+ hs.append(self.down[i_level].downsample(hs[-1]))
318
+
319
+ # middle
320
+ h = hs[-1]
321
+ h = self.mid.block_1(h, temb)
322
+ h = self.mid.attn_1(h)
323
+ h = self.mid.block_2(h, temb)
324
+ mid_h = h.detach().clone() # get the bottleneck h space embedding
325
+ # print("check Unet:", mid_h.size()) # [1, 512, 8, 8]
326
+ # exit()
327
+ if edit_h is not None:
328
+ h = edit_h
329
+
330
+ # upsampling
331
+ for i_level in reversed(range(self.num_resolutions)):
332
+ for i_block in range(self.num_res_blocks + 1):
333
+ h = self.up[i_level].block[i_block](
334
+ torch.cat([h, hs.pop()], dim=1), temb)
335
+ if len(self.up[i_level].attn) > 0:
336
+ h = self.up[i_level].attn[i_block](h)
337
+ if i_level != 0:
338
+ h = self.up[i_level].upsample(h)
339
+ # print("check UNET upsampled:", h.size()) # [1, 128, 256, 256]
340
+
341
+ # end
342
+ h = self.norm_out(h)
343
+ h = nonlinearity(h)
344
+ h = self.conv_out(h)
345
+ # print("check U-NET output:", h.size(), mid_h.size()) # [1,3,256,256]
346
+ # exit()
347
+
348
+ return mid_h, h
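forward() returns both the bottleneck feature map mid_h (the h-space used for semantic editing) and the usual noise prediction, and passing edit_h swaps in an edited bottleneck before upsampling. A minimal sketch, where config is a placeholder carrying the model.* and data.* fields read in __init__:

    # Sketch; `config` is a placeholder with the fields read in DDPM.__init__.
    import torch

    model = DDPM(config).eval()
    x_t = torch.randn(1, config.model.in_channels, config.data.image_size, config.data.image_size)
    t = torch.tensor([400])
    mid_h, eps = model(x_t, t)                         # h-space feature and noise prediction
    shifted_h = mid_h + 0.1 * torch.randn_like(mid_h)  # e.g. move along an editing direction
    _, eps_edit = model(x_t, t, edit_h=shifted_h)      # regenerate with the edited bottleneck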
models/improved_ddpm/fp16_util.py ADDED
@@ -0,0 +1,236 @@
1
+ """
2
+ Helpers to train with 16-bit precision.
3
+ """
4
+
5
+ import numpy as np
6
+ import torch as th
7
+ import torch.nn as nn
8
+ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
9
+
10
+ from . import logger
11
+
12
+ INITIAL_LOG_LOSS_SCALE = 20.0
13
+
14
+
15
+ def convert_module_to_f16(l):
16
+ """
17
+ Convert primitive modules to float16.
18
+ """
19
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
20
+ l.weight.data = l.weight.data.half()
21
+ if l.bias is not None:
22
+ l.bias.data = l.bias.data.half()
23
+
24
+
25
+ def convert_module_to_f32(l):
26
+ """
27
+ Convert primitive modules to float32, undoing convert_module_to_f16().
28
+ """
29
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
30
+ l.weight.data = l.weight.data.float()
31
+ if l.bias is not None:
32
+ l.bias.data = l.bias.data.float()
33
+
34
+
35
+ def make_master_params(param_groups_and_shapes):
36
+ """
37
+ Copy model parameters into a (differently-shaped) list of full-precision
38
+ parameters.
39
+ """
40
+ master_params = []
41
+ for param_group, shape in param_groups_and_shapes:
42
+ master_param = nn.Parameter(
43
+ _flatten_dense_tensors(
44
+ [param.detach().float() for (_, param) in param_group]
45
+ ).view(shape)
46
+ )
47
+ master_param.requires_grad = True
48
+ master_params.append(master_param)
49
+ return master_params
50
+
51
+
52
+ def model_grads_to_master_grads(param_groups_and_shapes, master_params):
53
+ """
54
+ Copy the gradients from the model parameters into the master parameters
55
+ from make_master_params().
56
+ """
57
+ for master_param, (param_group, shape) in zip(
58
+ master_params, param_groups_and_shapes
59
+ ):
60
+ master_param.grad = _flatten_dense_tensors(
61
+ [param_grad_or_zeros(param) for (_, param) in param_group]
62
+ ).view(shape)
63
+
64
+
65
+ def master_params_to_model_params(param_groups_and_shapes, master_params):
66
+ """
67
+ Copy the master parameter data back into the model parameters.
68
+ """
69
+ # Without copying to a list, if a generator is passed, this will
70
+ # silently not copy any parameters.
71
+ for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
72
+ for (_, param), unflat_master_param in zip(
73
+ param_group, unflatten_master_params(param_group, master_param.view(-1))
74
+ ):
75
+ param.detach().copy_(unflat_master_param)
76
+
77
+
78
+ def unflatten_master_params(param_group, master_param):
79
+ return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group])
80
+
81
+
82
+ def get_param_groups_and_shapes(named_model_params):
83
+ named_model_params = list(named_model_params)
84
+ scalar_vector_named_params = (
85
+ [(n, p) for (n, p) in named_model_params if p.ndim <= 1],
86
+ (-1),
87
+ )
88
+ matrix_named_params = (
89
+ [(n, p) for (n, p) in named_model_params if p.ndim > 1],
90
+ (1, -1),
91
+ )
92
+ return [scalar_vector_named_params, matrix_named_params]
93
+
94
+
95
+ def master_params_to_state_dict(
96
+ model, param_groups_and_shapes, master_params, use_fp16
97
+ ):
98
+ if use_fp16:
99
+ state_dict = model.state_dict()
100
+ for master_param, (param_group, _) in zip(
101
+ master_params, param_groups_and_shapes
102
+ ):
103
+ for (name, _), unflat_master_param in zip(
104
+ param_group, unflatten_master_params(param_group, master_param.view(-1))
105
+ ):
106
+ assert name in state_dict
107
+ state_dict[name] = unflat_master_param
108
+ else:
109
+ state_dict = model.state_dict()
110
+ for i, (name, _value) in enumerate(model.named_parameters()):
111
+ assert name in state_dict
112
+ state_dict[name] = master_params[i]
113
+ return state_dict
114
+
115
+
116
+ def state_dict_to_master_params(model, state_dict, use_fp16):
117
+ if use_fp16:
118
+ named_model_params = [
119
+ (name, state_dict[name]) for name, _ in model.named_parameters()
120
+ ]
121
+ param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
122
+ master_params = make_master_params(param_groups_and_shapes)
123
+ else:
124
+ master_params = [state_dict[name] for name, _ in model.named_parameters()]
125
+ return master_params
126
+
127
+
128
+ def zero_master_grads(master_params):
129
+ for param in master_params:
130
+ param.grad = None
131
+
132
+
133
+ def zero_grad(model_params):
134
+ for param in model_params:
135
+ # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
136
+ if param.grad is not None:
137
+ param.grad.detach_()
138
+ param.grad.zero_()
139
+
140
+
141
+ def param_grad_or_zeros(param):
142
+ if param.grad is not None:
143
+ return param.grad.data.detach()
144
+ else:
145
+ return th.zeros_like(param)
146
+
147
+
148
+ class MixedPrecisionTrainer:
149
+ def __init__(
150
+ self,
151
+ *,
152
+ model,
153
+ use_fp16=False,
154
+ fp16_scale_growth=1e-3,
155
+ initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
156
+ ):
157
+ self.model = model
158
+ self.use_fp16 = use_fp16
159
+ self.fp16_scale_growth = fp16_scale_growth
160
+
161
+ self.model_params = list(self.model.parameters())
162
+ self.master_params = self.model_params
163
+ self.param_groups_and_shapes = None
164
+ self.lg_loss_scale = initial_lg_loss_scale
165
+
166
+ if self.use_fp16:
167
+ self.param_groups_and_shapes = get_param_groups_and_shapes(
168
+ self.model.named_parameters()
169
+ )
170
+ self.master_params = make_master_params(self.param_groups_and_shapes)
171
+ self.model.convert_to_fp16()
172
+
173
+ def zero_grad(self):
174
+ zero_grad(self.model_params)
175
+
176
+ def backward(self, loss: th.Tensor):
177
+ if self.use_fp16:
178
+ loss_scale = 2 ** self.lg_loss_scale
179
+ (loss * loss_scale).backward()
180
+ else:
181
+ loss.backward()
182
+
183
+ def optimize(self, opt: th.optim.Optimizer):
184
+ if self.use_fp16:
185
+ return self._optimize_fp16(opt)
186
+ else:
187
+ return self._optimize_normal(opt)
188
+
189
+ def _optimize_fp16(self, opt: th.optim.Optimizer):
190
+ logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
191
+ model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
192
+ grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
193
+ if check_overflow(grad_norm):
194
+ self.lg_loss_scale -= 1
195
+ logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
196
+ zero_master_grads(self.master_params)
197
+ return False
198
+
199
+ logger.logkv_mean("grad_norm", grad_norm)
200
+ logger.logkv_mean("param_norm", param_norm)
201
+
202
+ self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale))
203
+ opt.step()
204
+ zero_master_grads(self.master_params)
205
+ master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
206
+ self.lg_loss_scale += self.fp16_scale_growth
207
+ return True
208
+
209
+ def _optimize_normal(self, opt: th.optim.Optimizer):
210
+ grad_norm, param_norm = self._compute_norms()
211
+ logger.logkv_mean("grad_norm", grad_norm)
212
+ logger.logkv_mean("param_norm", param_norm)
213
+ opt.step()
214
+ return True
215
+
216
+ def _compute_norms(self, grad_scale=1.0):
217
+ grad_norm = 0.0
218
+ param_norm = 0.0
219
+ for p in self.master_params:
220
+ with th.no_grad():
221
+ param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
222
+ if p.grad is not None:
223
+ grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
224
+ return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)
225
+
226
+ def master_params_to_state_dict(self, master_params):
227
+ return master_params_to_state_dict(
228
+ self.model, self.param_groups_and_shapes, master_params, self.use_fp16
229
+ )
230
+
231
+ def state_dict_to_master_params(self, state_dict):
232
+ return state_dict_to_master_params(self.model, state_dict, self.use_fp16)
233
+
234
+
235
+ def check_overflow(value):
236
+ return (value == float("inf")) or (value == -float("inf")) or (value != value)
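The intended training-step pattern for MixedPrecisionTrainer looks roughly like the sketch below; model, batch and compute_loss are placeholders:

    # Sketch; `model`, `batch` and `compute_loss` are placeholders.
    import torch as th

    trainer = MixedPrecisionTrainer(model=model, use_fp16=False)
    opt = th.optim.AdamW(trainer.master_params, lr=1e-4)

    trainer.zero_grad()
    loss = compute_loss(model, batch)   # any scalar loss
    trainer.backward(loss)
    took_step = trainer.optimize(opt)   # False when an fp16 step overflows and is skipped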
models/improved_ddpm/logger.py ADDED
@@ -0,0 +1,451 @@
1
+ """
2
+ Logger based on OpenAI baselines to avoid extra RL-based dependencies:
3
+ https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import os.path as osp
9
+ import json
10
+ import time
11
+ import datetime
12
+ import tempfile
13
+ import warnings
14
+ from collections import defaultdict
15
+ from contextlib import contextmanager
16
+
17
+ DEBUG = 10
18
+ INFO = 20
19
+ WARN = 30
20
+ ERROR = 40
21
+
22
+ DISABLED = 50
23
+
24
+
25
+ class KVWriter(object):
26
+ def writekvs(self, kvs):
27
+ raise NotImplementedError
28
+
29
+
30
+ class SeqWriter(object):
31
+ def writeseq(self, seq):
32
+ raise NotImplementedError
33
+
34
+
35
+ class HumanOutputFormat(KVWriter, SeqWriter):
36
+ def __init__(self, filename_or_file):
37
+ if isinstance(filename_or_file, str):
38
+ self.file = open(filename_or_file, "wt")
39
+ self.own_file = True
40
+ else:
41
+ assert hasattr(filename_or_file, "read"), (
42
+ "expected file or str, got %s" % filename_or_file
43
+ )
44
+ self.file = filename_or_file
45
+ self.own_file = False
46
+
47
+ def writekvs(self, kvs):
48
+ # Create strings for printing
49
+ key2str = {}
50
+ for (key, val) in sorted(kvs.items()):
51
+ if hasattr(val, "__float__"):
52
+ valstr = "%-8.3g" % val
53
+ else:
54
+ valstr = str(val)
55
+ key2str[self._truncate(key)] = self._truncate(valstr)
56
+
57
+ # Find max widths
58
+ if len(key2str) == 0:
59
+ print("WARNING: tried to write empty key-value dict")
60
+ return
61
+ else:
62
+ keywidth = max(map(len, key2str.keys()))
63
+ valwidth = max(map(len, key2str.values()))
64
+
65
+ # Write out the data
66
+ dashes = "-" * (keywidth + valwidth + 7)
67
+ lines = [dashes]
68
+ for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
69
+ lines.append(
70
+ "| %s%s | %s%s |"
71
+ % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
72
+ )
73
+ lines.append(dashes)
74
+ self.file.write("\n".join(lines) + "\n")
75
+
76
+ # Flush the output to the file
77
+ self.file.flush()
78
+
79
+ def _truncate(self, s):
80
+ maxlen = 30
81
+ return s[: maxlen - 3] + "..." if len(s) > maxlen else s
82
+
83
+ def writeseq(self, seq):
84
+ seq = list(seq)
85
+ for (i, elem) in enumerate(seq):
86
+ self.file.write(elem)
87
+ if i < len(seq) - 1: # add space unless this is the last one
88
+ self.file.write(" ")
89
+ self.file.write("\n")
90
+ self.file.flush()
91
+
92
+ def close(self):
93
+ if self.own_file:
94
+ self.file.close()
95
+
96
+
97
+ class JSONOutputFormat(KVWriter):
98
+ def __init__(self, filename):
99
+ self.file = open(filename, "wt")
100
+
101
+ def writekvs(self, kvs):
102
+ for k, v in sorted(kvs.items()):
103
+ if hasattr(v, "dtype"):
104
+ kvs[k] = float(v)
105
+ self.file.write(json.dumps(kvs) + "\n")
106
+ self.file.flush()
107
+
108
+ def close(self):
109
+ self.file.close()
110
+
111
+
112
+ class CSVOutputFormat(KVWriter):
113
+ def __init__(self, filename):
114
+ self.file = open(filename, "w+t")
115
+ self.keys = []
116
+ self.sep = ","
117
+
118
+ def writekvs(self, kvs):
119
+ # Add our current row to the history
120
+ extra_keys = list(kvs.keys() - self.keys)
121
+ extra_keys.sort()
122
+ if extra_keys:
123
+ self.keys.extend(extra_keys)
124
+ self.file.seek(0)
125
+ lines = self.file.readlines()
126
+ self.file.seek(0)
127
+ for (i, k) in enumerate(self.keys):
128
+ if i > 0:
129
+ self.file.write(",")
130
+ self.file.write(k)
131
+ self.file.write("\n")
132
+ for line in lines[1:]:
133
+ self.file.write(line[:-1])
134
+ self.file.write(self.sep * len(extra_keys))
135
+ self.file.write("\n")
136
+ for (i, k) in enumerate(self.keys):
137
+ if i > 0:
138
+ self.file.write(",")
139
+ v = kvs.get(k)
140
+ if v is not None:
141
+ self.file.write(str(v))
142
+ self.file.write("\n")
143
+ self.file.flush()
144
+
145
+ def close(self):
146
+ self.file.close()
147
+
148
+
149
+ def make_output_format(format, ev_dir, log_suffix=""):
150
+ os.makedirs(ev_dir, exist_ok=True)
151
+ if format == "stdout":
152
+ return HumanOutputFormat(sys.stdout)
153
+ elif format == "log":
154
+ return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
155
+ elif format == "json":
156
+ return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
157
+ elif format == "csv":
158
+ return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
159
+ else:
160
+ raise ValueError("Unknown format specified: %s" % (format,))
161
+
162
+
163
+ # ================================================================
164
+ # API
165
+ # ================================================================
166
+
167
+
168
+ def logkv(key, val):
169
+ """
170
+ Log a value of some diagnostic
171
+ Call this once for each diagnostic quantity, each iteration
172
+ If called many times, last value will be used.
173
+ """
174
+ get_current().logkv(key, val)
175
+
176
+
177
+ def logkv_mean(key, val):
178
+ """
179
+ The same as logkv(), but if called many times, values averaged.
180
+ """
181
+ get_current().logkv_mean(key, val)
182
+
183
+
184
+ def logkvs(d):
185
+ """
186
+ Log a dictionary of key-value pairs
187
+ """
188
+ for (k, v) in d.items():
189
+ logkv(k, v)
190
+
191
+
192
+ def dumpkvs():
193
+ """
194
+ Write all of the diagnostics from the current iteration
195
+ """
196
+ return get_current().dumpkvs()
197
+
198
+
199
+ def getkvs():
200
+ return get_current().name2val
201
+
202
+
203
+ def log(*args, level=INFO):
204
+ """
205
+ Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
206
+ """
207
+ get_current().log(*args, level=level)
208
+
209
+
210
+ def debug(*args):
211
+ log(*args, level=DEBUG)
212
+
213
+
214
+ def info(*args):
215
+ log(*args, level=INFO)
216
+
217
+
218
+ def warn(*args):
219
+ log(*args, level=WARN)
220
+
221
+
222
+ def error(*args):
223
+ log(*args, level=ERROR)
224
+
225
+
226
+ def set_level(level):
227
+ """
228
+ Set logging threshold on current logger.
229
+ """
230
+ get_current().set_level(level)
231
+
232
+
233
+ def set_comm(comm):
234
+ get_current().set_comm(comm)
235
+
236
+
237
+ def get_dir():
238
+ """
239
+ Get directory that log files are being written to.
240
+ will be None if there is no output directory (i.e., if you didn't call start)
241
+ """
242
+ return get_current().get_dir()
243
+
244
+
245
+ record_tabular = logkv
246
+ dump_tabular = dumpkvs
247
+
248
+
249
+ @contextmanager
250
+ def profile_kv(scopename):
251
+ logkey = "wait_" + scopename
252
+ tstart = time.time()
253
+ try:
254
+ yield
255
+ finally:
256
+ get_current().name2val[logkey] += time.time() - tstart
257
+
258
+
259
+ def profile(n):
260
+ """
261
+ Usage:
262
+ @profile("my_func")
263
+ def my_func(): code
264
+ """
265
+
266
+ def decorator_with_name(func):
267
+ def func_wrapper(*args, **kwargs):
268
+ with profile_kv(n):
269
+ return func(*args, **kwargs)
270
+
271
+ return func_wrapper
272
+
273
+ return decorator_with_name
274
+
275
+
276
+ # ================================================================
277
+ # Backend
278
+ # ================================================================
279
+
280
+
281
+ def get_current():
282
+ if Logger.CURRENT is None:
283
+ _configure_default_logger()
284
+
285
+ return Logger.CURRENT
286
+
287
+
288
+ class Logger(object):
289
+ DEFAULT = None # A logger with no output files. (See right below class definition)
290
+ # So that you can still log to the terminal without setting up any output files
291
+ CURRENT = None # Current logger being used by the free functions above
292
+
293
+ def __init__(self, dir, output_formats, comm=None):
294
+ self.name2val = defaultdict(float) # values this iteration
295
+ self.name2cnt = defaultdict(int)
296
+ self.level = INFO
297
+ self.dir = dir
298
+ self.output_formats = output_formats
299
+ self.comm = comm
300
+
301
+ # Logging API, forwarded
302
+ # ----------------------------------------
303
+ def logkv(self, key, val):
304
+ self.name2val[key] = val
305
+
306
+ def logkv_mean(self, key, val):
307
+ oldval, cnt = self.name2val[key], self.name2cnt[key]
308
+ self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
309
+ self.name2cnt[key] = cnt + 1
310
+
311
+ def dumpkvs(self):
312
+ if self.comm is None:
313
+ d = self.name2val
314
+ else:
315
+ d = mpi_weighted_mean(
316
+ self.comm,
317
+ {
318
+ name: (val, self.name2cnt.get(name, 1))
319
+ for (name, val) in self.name2val.items()
320
+ },
321
+ )
322
+ if self.comm.rank != 0:
323
+ d["dummy"] = 1 # so we don't get a warning about empty dict
324
+ out = d.copy() # Return the dict for unit testing purposes
325
+ for fmt in self.output_formats:
326
+ if isinstance(fmt, KVWriter):
327
+ fmt.writekvs(d)
328
+ self.name2val.clear()
329
+ self.name2cnt.clear()
330
+ return out
331
+
332
+ def log(self, *args, level=INFO):
333
+ if self.level <= level:
334
+ self._do_log(args)
335
+
336
+ # Configuration
337
+ # ----------------------------------------
338
+ def set_level(self, level):
339
+ self.level = level
340
+
341
+ def set_comm(self, comm):
342
+ self.comm = comm
343
+
344
+ def get_dir(self):
345
+ return self.dir
346
+
347
+ def close(self):
348
+ for fmt in self.output_formats:
349
+ fmt.close()
350
+
351
+ # Misc
352
+ # ----------------------------------------
353
+ def _do_log(self, args):
354
+ for fmt in self.output_formats:
355
+ if isinstance(fmt, SeqWriter):
356
+ fmt.writeseq(map(str, args))
357
+
358
+
359
+ def get_rank_without_mpi_import():
360
+ # check environment variables here instead of importing mpi4py
361
+ # to avoid calling MPI_Init() when this module is imported
362
+ for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]:
363
+ if varname in os.environ:
364
+ return int(os.environ[varname])
365
+ return 0
366
+
367
+
368
+ def mpi_weighted_mean(comm, local_name2valcount):
369
+ """
370
+ Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
371
+ Perform a weighted average over dicts that are each on a different node
372
+ Input: local_name2valcount: dict mapping key -> (value, count)
373
+ Returns: key -> mean
374
+ """
375
+ all_name2valcount = comm.gather(local_name2valcount)
376
+ if comm.rank == 0:
377
+ name2sum = defaultdict(float)
378
+ name2count = defaultdict(float)
379
+ for n2vc in all_name2valcount:
380
+ for (name, (val, count)) in n2vc.items():
381
+ try:
382
+ val = float(val)
383
+ except ValueError:
384
+ if comm.rank == 0:
385
+ warnings.warn(
386
+ "WARNING: tried to compute mean on non-float {}={}".format(
387
+ name, val
388
+ )
389
+ )
390
+ else:
391
+ name2sum[name] += val * count
392
+ name2count[name] += count
393
+ return {name: name2sum[name] / name2count[name] for name in name2sum}
394
+ else:
395
+ return {}
396
+
397
+
398
+ def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
399
+ """
400
+ If comm is provided, average all numerical stats across that comm
401
+ """
402
+ if dir is None:
403
+ dir = os.getenv("OPENAI_LOGDIR")
404
+ if dir is None:
405
+ dir = osp.join(
406
+ tempfile.gettempdir(),
407
+ datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"),
408
+ )
409
+ assert isinstance(dir, str)
410
+ dir = os.path.expanduser(dir)
411
+ os.makedirs(os.path.expanduser(dir), exist_ok=True)
412
+
413
+ rank = get_rank_without_mpi_import()
414
+ if rank > 0:
415
+ log_suffix = log_suffix + "-rank%03i" % rank
416
+
417
+ if format_strs is None:
418
+ if rank == 0:
419
+ format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",")
420
+ else:
421
+ format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
422
+ format_strs = filter(None, format_strs)
423
+ output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
424
+
425
+ Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
426
+ if output_formats:
427
+ log("Logging to %s" % dir)
428
+
429
+
430
+ def _configure_default_logger():
431
+ configure()
432
+ Logger.DEFAULT = Logger.CURRENT
433
+
434
+
435
+ def reset():
436
+ if Logger.CURRENT is not Logger.DEFAULT:
437
+ Logger.CURRENT.close()
438
+ Logger.CURRENT = Logger.DEFAULT
439
+ log("Reset logger")
440
+
441
+
442
+ @contextmanager
443
+ def scoped_configure(dir=None, format_strs=None, comm=None):
444
+ prevlogger = Logger.CURRENT
445
+ configure(dir=dir, format_strs=format_strs, comm=comm)
446
+ try:
447
+ yield
448
+ finally:
449
+ Logger.CURRENT.close()
450
+ Logger.CURRENT = prevlogger
451
+
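A minimal sketch of the module-level API above (the output directory is a placeholder; OPENAI_LOGDIR / OPENAI_LOG_FORMAT can be used instead of explicit arguments):

    from models.improved_ddpm import logger   # import path follows the repo layout

    logger.configure(dir='./runs/log_demo', format_strs=['stdout', 'csv'])
    for step in range(3):
        logger.logkv('step', step)
        logger.logkv_mean('loss', 0.1 * step)
        logger.dumpkvs()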
models/improved_ddpm/nn.py ADDED
@@ -0,0 +1,170 @@
1
+ """
2
+ Various utilities for neural networks.
3
+ """
4
+
5
+ import math
6
+
7
+ import torch as th
8
+ import torch.nn as nn
9
+
10
+
11
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
12
+ class SiLU(nn.Module):
13
+ def forward(self, x):
14
+ return x * th.sigmoid(x)
15
+
16
+
17
+ class GroupNorm32(nn.GroupNorm):
18
+ def forward(self, x):
19
+ return super().forward(x.float()).type(x.dtype)
20
+
21
+
22
+ def conv_nd(dims, *args, **kwargs):
23
+ """
24
+ Create a 1D, 2D, or 3D convolution module.
25
+ """
26
+ if dims == 1:
27
+ return nn.Conv1d(*args, **kwargs)
28
+ elif dims == 2:
29
+ return nn.Conv2d(*args, **kwargs)
30
+ elif dims == 3:
31
+ return nn.Conv3d(*args, **kwargs)
32
+ raise ValueError(f"unsupported dimensions: {dims}")
33
+
34
+
35
+ def linear(*args, **kwargs):
36
+ """
37
+ Create a linear module.
38
+ """
39
+ return nn.Linear(*args, **kwargs)
40
+
41
+
42
+ def avg_pool_nd(dims, *args, **kwargs):
43
+ """
44
+ Create a 1D, 2D, or 3D average pooling module.
45
+ """
46
+ if dims == 1:
47
+ return nn.AvgPool1d(*args, **kwargs)
48
+ elif dims == 2:
49
+ return nn.AvgPool2d(*args, **kwargs)
50
+ elif dims == 3:
51
+ return nn.AvgPool3d(*args, **kwargs)
52
+ raise ValueError(f"unsupported dimensions: {dims}")
53
+
54
+
55
+ def update_ema(target_params, source_params, rate=0.99):
56
+ """
57
+ Update target parameters to be closer to those of source parameters using
58
+ an exponential moving average.
59
+
60
+ :param target_params: the target parameter sequence.
61
+ :param source_params: the source parameter sequence.
62
+ :param rate: the EMA rate (closer to 1 means slower).
63
+ """
64
+ for targ, src in zip(target_params, source_params):
65
+ targ.detach().mul_(rate).add_(src, alpha=1 - rate)
66
+
67
+
68
+ def zero_module(module):
69
+ """
70
+ Zero out the parameters of a module and return it.
71
+ """
72
+ for p in module.parameters():
73
+ p.detach().zero_()
74
+ return module
75
+
76
+
77
+ def scale_module(module, scale):
78
+ """
79
+ Scale the parameters of a module and return it.
80
+ """
81
+ for p in module.parameters():
82
+ p.detach().mul_(scale)
83
+ return module
84
+
85
+
86
+ def mean_flat(tensor):
87
+ """
88
+ Take the mean over all non-batch dimensions.
89
+ """
90
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
91
+
92
+
93
+ def normalization(channels):
94
+ """
95
+ Make a standard normalization layer.
96
+
97
+ :param channels: number of input channels.
98
+ :return: an nn.Module for normalization.
99
+ """
100
+ return GroupNorm32(32, channels)
101
+
102
+
103
+ def timestep_embedding(timesteps, dim, max_period=10000):
104
+ """
105
+ Create sinusoidal timestep embeddings.
106
+
107
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
108
+ These may be fractional.
109
+ :param dim: the dimension of the output.
110
+ :param max_period: controls the minimum frequency of the embeddings.
111
+ :return: an [N x dim] Tensor of positional embeddings.
112
+ """
113
+ half = dim // 2
114
+ freqs = th.exp(
115
+ -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
116
+ ).to(device=timesteps.device)
117
+ args = timesteps[:, None].float() * freqs[None]
118
+ embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
119
+ if dim % 2:
120
+ embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
121
+ return embedding
122
+
123
+
124
+ def checkpoint(func, inputs, params, flag):
125
+ """
126
+ Evaluate a function without caching intermediate activations, allowing for
127
+ reduced memory at the expense of extra compute in the backward pass.
128
+
129
+ :param func: the function to evaluate.
130
+ :param inputs: the argument sequence to pass to `func`.
131
+ :param params: a sequence of parameters `func` depends on but does not
132
+ explicitly take as arguments.
133
+ :param flag: if False, disable gradient checkpointing.
134
+ """
135
+ if flag:
136
+ args = tuple(inputs) + tuple(params)
137
+ return CheckpointFunction.apply(func, len(inputs), *args)
138
+ else:
139
+ return func(*inputs)
140
+
141
+
142
+ class CheckpointFunction(th.autograd.Function):
143
+ @staticmethod
144
+ def forward(ctx, run_function, length, *args):
145
+ ctx.run_function = run_function
146
+ ctx.input_tensors = list(args[:length])
147
+ ctx.input_params = list(args[length:])
148
+ with th.no_grad():
149
+ output_tensors = ctx.run_function(*ctx.input_tensors)
150
+ return output_tensors
151
+
152
+ @staticmethod
153
+ def backward(ctx, *output_grads):
154
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
155
+ with th.enable_grad():
156
+ # Fixes a bug where the first op in run_function modifies the
157
+ # Tensor storage in place, which is not allowed for detach()'d
158
+ # Tensors.
159
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
160
+ output_tensors = ctx.run_function(*shallow_copies)
161
+ input_grads = th.autograd.grad(
162
+ output_tensors,
163
+ ctx.input_tensors + ctx.input_params,
164
+ output_grads,
165
+ allow_unused=True,
166
+ )
167
+ del ctx.input_tensors
168
+ del ctx.input_params
169
+ del output_tensors
170
+ return (None, None) + input_grads
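checkpoint() trades compute for memory by re-running the wrapped function during the backward pass when flag is True; a minimal sketch:

    import torch as th
    import torch.nn as nn

    layer = nn.Linear(16, 16)
    x = th.randn(2, 16, requires_grad=True)

    def run(inp):
        return layer(inp).relu()

    out = checkpoint(run, (x,), list(layer.parameters()), True)
    out.sum().backward()   # gradients reach both x and the layer's parameters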
models/improved_ddpm/script_util.py ADDED
@@ -0,0 +1,109 @@
1
+ from .unet import UNetModel
2
+
3
+ NUM_CLASSES = 1000
4
+
5
+ AFHQ_DICT = dict(
6
+ attention_resolutions="16",
7
+ class_cond=False,
8
+ dropout=0.0,
9
+ image_size=256,
10
+ learn_sigma=True,
11
+ num_channels=128,
12
+ num_head_channels=64,
13
+ num_res_blocks=1,
14
+ resblock_updown=True,
15
+ use_fp16=False,
16
+ use_scale_shift_norm=True,
17
+ num_heads=4,
18
+ num_heads_upsample=-1,
19
+ channel_mult="",
20
+ use_checkpoint=False,
21
+ use_new_attention_order=False,
22
+ )
23
+
24
+
25
+ IMAGENET_DICT = dict(
26
+ attention_resolutions="32,16,8",
27
+ class_cond=True,
28
+ image_size=512,
29
+ learn_sigma=True,
30
+ num_channels=256,
31
+ num_head_channels=64,
32
+ num_res_blocks=2,
33
+ resblock_updown=True,
34
+ use_fp16=False,
35
+ use_scale_shift_norm=True,
36
+ dropout=0.0,
37
+ num_heads=4,
38
+ num_heads_upsample=-1,
39
+ channel_mult="",
40
+ use_checkpoint=False,
41
+ use_new_attention_order=False,
42
+ )
43
+
44
+
45
+ def create_model(
46
+ image_size,
47
+ num_channels,
48
+ num_res_blocks,
49
+ channel_mult="",
50
+ learn_sigma=False,
51
+ class_cond=False,
52
+ use_checkpoint=False,
53
+ attention_resolutions="16",
54
+ num_heads=1,
55
+ num_head_channels=-1,
56
+ num_heads_upsample=-1,
57
+ use_scale_shift_norm=False,
58
+ dropout=0,
59
+ resblock_updown=False,
60
+ use_fp16=False,
61
+ use_new_attention_order=False,
62
+ ):
63
+ if channel_mult == "":
64
+ if image_size == 512:
65
+ channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
66
+ elif image_size == 256:
67
+ channel_mult = (1, 1, 2, 2, 4, 4)
68
+ elif image_size == 128:
69
+ channel_mult = (1, 1, 2, 3, 4)
70
+ elif image_size == 64:
71
+ channel_mult = (1, 2, 3, 4)
72
+ else:
73
+ raise ValueError(f"unsupported image size: {image_size}")
74
+ else:
75
+ channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(","))
76
+
77
+ attention_ds = []
78
+ for res in attention_resolutions.split(","):
79
+ attention_ds.append(image_size // int(res))
80
+
81
+ return UNetModel(
82
+ image_size=image_size,
83
+ in_channels=3,
84
+ model_channels=num_channels,
85
+ out_channels=(3 if not learn_sigma else 6),
86
+ num_res_blocks=num_res_blocks,
87
+ attention_resolutions=tuple(attention_ds),
88
+ dropout=dropout,
89
+ channel_mult=channel_mult,
90
+ num_classes=(NUM_CLASSES if class_cond else None),
91
+ use_checkpoint=use_checkpoint,
92
+ use_fp16=use_fp16,
93
+ num_heads=num_heads,
94
+ num_head_channels=num_head_channels,
95
+ num_heads_upsample=num_heads_upsample,
96
+ use_scale_shift_norm=use_scale_shift_norm,
97
+ resblock_updown=resblock_updown,
98
+ use_new_attention_order=use_new_attention_order,
99
+ )
100
+
101
+
102
+ def i_DDPM(dataset_name = 'AFHQ'):
103
+ if dataset_name in ['AFHQ', 'FFHQ']:
104
+ return create_model(**AFHQ_DICT)
105
+ elif dataset_name == 'IMAGENET':
106
+ return create_model(**IMAGENET_DICT)
107
+ else:
108
+ print('Not implemented.')
109
+ exit()
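A short instantiation sketch for the factory above (the checkpoint path and loading step are illustrative assumptions, not part of this file):

```python
import torch
from models.improved_ddpm.script_util import i_DDPM

model = i_DDPM('FFHQ')   # builds create_model(**AFHQ_DICT): 256x256, learn_sigma=True -> 6 output channels
# state = torch.load('pretrained/ffhq_weights.pt', map_location='cpu')  # hypothetical checkpoint path
# model.load_state_dict(state)
model.eval()
```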
models/improved_ddpm/unet.py ADDED
@@ -0,0 +1,677 @@
1
+ """
2
+ Codebase for "Improved Denoising Diffusion Probabilistic Models".
3
+ """
4
+
5
+
6
+ from abc import abstractmethod
7
+
8
+ import math
9
+
10
+ import numpy as np
11
+ import torch as th
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ from .fp16_util import convert_module_to_f16, convert_module_to_f32
16
+ from .nn import (
17
+ checkpoint,
18
+ conv_nd,
19
+ linear,
20
+ avg_pool_nd,
21
+ zero_module,
22
+ normalization,
23
+ timestep_embedding,
24
+ )
25
+
26
+
27
+ class AttentionPool2d(nn.Module):
28
+ """
29
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
30
+ """
31
+
32
+ def __init__(
33
+ self,
34
+ spacial_dim: int,
35
+ embed_dim: int,
36
+ num_heads_channels: int,
37
+ output_dim: int = None,
38
+ ):
39
+ super().__init__()
40
+ self.positional_embedding = nn.Parameter(
41
+ th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5
42
+ )
43
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
44
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
45
+ self.num_heads = embed_dim // num_heads_channels
46
+ self.attention = QKVAttention(self.num_heads)
47
+
48
+ def forward(self, x):
49
+ b, c, *_spatial = x.shape
50
+ x = x.reshape(b, c, -1) # NC(HW)
51
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
52
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
53
+ x = self.qkv_proj(x)
54
+ x = self.attention(x)
55
+ x = self.c_proj(x)
56
+ return x[:, :, 0]
57
+
58
+
59
+ class TimestepBlock(nn.Module):
60
+ """
61
+ Any module where forward() takes timestep embeddings as a second argument.
62
+ """
63
+
64
+ @abstractmethod
65
+ def forward(self, x, emb):
66
+ """
67
+ Apply the module to `x` given `emb` timestep embeddings.
68
+ """
69
+
70
+
71
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
72
+ """
73
+ A sequential module that passes timestep embeddings to the children that
74
+ support it as an extra input.
75
+ """
76
+
77
+ def forward(self, x, emb):
78
+ for layer in self:
79
+ if isinstance(layer, TimestepBlock):
80
+ x = layer(x, emb)
81
+ else:
82
+ x = layer(x)
83
+ return x
84
+
85
+
86
+ class Upsample(nn.Module):
87
+ """
88
+ An upsampling layer with an optional convolution.
89
+
90
+ :param channels: channels in the inputs and outputs.
91
+ :param use_conv: a bool determining if a convolution is applied.
92
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
93
+ upsampling occurs in the inner-two dimensions.
94
+ """
95
+
96
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
97
+ super().__init__()
98
+ self.channels = channels
99
+ self.out_channels = out_channels or channels
100
+ self.use_conv = use_conv
101
+ self.dims = dims
102
+ if use_conv:
103
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
104
+
105
+ def forward(self, x):
106
+ assert x.shape[1] == self.channels
107
+ if self.dims == 3:
108
+ x = F.interpolate(
109
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
110
+ )
111
+ else:
112
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
113
+ if self.use_conv:
114
+ x = self.conv(x)
115
+ return x
116
+
117
+
118
+ class Downsample(nn.Module):
119
+ """
120
+ A downsampling layer with an optional convolution.
121
+
122
+ :param channels: channels in the inputs and outputs.
123
+ :param use_conv: a bool determining if a convolution is applied.
124
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
125
+ downsampling occurs in the inner-two dimensions.
126
+ """
127
+
128
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
129
+ super().__init__()
130
+ self.channels = channels
131
+ self.out_channels = out_channels or channels
132
+ self.use_conv = use_conv
133
+ self.dims = dims
134
+ stride = 2 if dims != 3 else (1, 2, 2)
135
+ if use_conv:
136
+ self.op = conv_nd(
137
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=1
138
+ )
139
+ else:
140
+ assert self.channels == self.out_channels
141
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
142
+
143
+ def forward(self, x):
144
+ assert x.shape[1] == self.channels
145
+ return self.op(x)
146
+
147
+
148
+ class ResBlock(TimestepBlock):
149
+ """
150
+ A residual block that can optionally change the number of channels.
151
+
152
+ :param channels: the number of input channels.
153
+ :param emb_channels: the number of timestep embedding channels.
154
+ :param dropout: the rate of dropout.
155
+ :param out_channels: if specified, the number of out channels.
156
+ :param use_conv: if True and out_channels is specified, use a spatial
157
+ convolution instead of a smaller 1x1 convolution to change the
158
+ channels in the skip connection.
159
+ :param dims: determines if the signal is 1D, 2D, or 3D.
160
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
161
+ :param up: if True, use this block for upsampling.
162
+ :param down: if True, use this block for downsampling.
163
+ """
164
+
165
+ def __init__(
166
+ self,
167
+ channels,
168
+ emb_channels,
169
+ dropout,
170
+ out_channels=None,
171
+ use_conv=False,
172
+ use_scale_shift_norm=False,
173
+ dims=2,
174
+ use_checkpoint=False,
175
+ up=False,
176
+ down=False,
177
+ ):
178
+ super().__init__()
179
+ self.channels = channels
180
+ self.emb_channels = emb_channels
181
+ self.dropout = dropout
182
+ self.out_channels = out_channels or channels
183
+ self.use_conv = use_conv
184
+ self.use_checkpoint = use_checkpoint
185
+ self.use_scale_shift_norm = use_scale_shift_norm
186
+
187
+ self.in_layers = nn.Sequential(
188
+ normalization(channels),
189
+ nn.SiLU(),
190
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
191
+ )
192
+
193
+ self.updown = up or down
194
+
195
+ if up:
196
+ self.h_upd = Upsample(channels, False, dims)
197
+ self.x_upd = Upsample(channels, False, dims)
198
+ elif down:
199
+ self.h_upd = Downsample(channels, False, dims)
200
+ self.x_upd = Downsample(channels, False, dims)
201
+ else:
202
+ self.h_upd = self.x_upd = nn.Identity()
203
+
204
+ self.emb_layers = nn.Sequential(
205
+ nn.SiLU(),
206
+ linear(
207
+ emb_channels,
208
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
209
+ ),
210
+ )
211
+ self.out_layers = nn.Sequential(
212
+ normalization(self.out_channels),
213
+ nn.SiLU(),
214
+ nn.Dropout(p=dropout),
215
+ zero_module(
216
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
217
+ ),
218
+ )
219
+
220
+ if self.out_channels == channels:
221
+ self.skip_connection = nn.Identity()
222
+ elif use_conv:
223
+ self.skip_connection = conv_nd(
224
+ dims, channels, self.out_channels, 3, padding=1
225
+ )
226
+ else:
227
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
228
+
229
+ def forward(self, x, emb):
230
+ """
231
+ Apply the block to a Tensor, conditioned on a timestep embedding.
232
+
233
+ :param x: an [N x C x ...] Tensor of features.
234
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
235
+ :return: an [N x C x ...] Tensor of outputs.
236
+ """
237
+ return checkpoint(
238
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
239
+ )
240
+
241
+ def _forward(self, x, emb):
242
+ if self.updown:
243
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
244
+ h = in_rest(x)
245
+ h = self.h_upd(h)
246
+ x = self.x_upd(x)
247
+ h = in_conv(h)
248
+ else:
249
+ h = self.in_layers(x)
250
+ emb_out = self.emb_layers(emb).type(h.dtype)
251
+ while len(emb_out.shape) < len(h.shape):
252
+ emb_out = emb_out[..., None]
253
+ if self.use_scale_shift_norm:
254
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
255
+ scale, shift = th.chunk(emb_out, 2, dim=1)
256
+ h = out_norm(h) * (1 + scale) + shift
257
+ h = out_rest(h)
258
+ else:
259
+ h = h + emb_out
260
+ h = self.out_layers(h)
261
+ return self.skip_connection(x) + h
262
+
263
+
264
+ class AttentionBlock(nn.Module):
265
+ """
266
+ An attention block that allows spatial positions to attend to each other.
267
+
268
+ Originally ported from here, but adapted to the N-d case.
269
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
270
+ """
271
+
272
+ def __init__(
273
+ self,
274
+ channels,
275
+ num_heads=1,
276
+ num_head_channels=-1,
277
+ use_checkpoint=False,
278
+ use_new_attention_order=False,
279
+ ):
280
+ super().__init__()
281
+ self.channels = channels
282
+ if num_head_channels == -1:
283
+ self.num_heads = num_heads
284
+ else:
285
+ assert (
286
+ channels % num_head_channels == 0
287
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
288
+ self.num_heads = channels // num_head_channels
289
+ self.use_checkpoint = use_checkpoint
290
+ self.norm = normalization(channels)
291
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
292
+ if use_new_attention_order:
293
+ # split qkv before split heads
294
+ self.attention = QKVAttention(self.num_heads)
295
+ else:
296
+ # split heads before split qkv
297
+ self.attention = QKVAttentionLegacy(self.num_heads)
298
+
299
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
300
+
301
+ def forward(self, x):
302
+ return checkpoint(self._forward, (x,), self.parameters(), True)
303
+
304
+ def _forward(self, x):
305
+ b, c, *spatial = x.shape
306
+ x = x.reshape(b, c, -1)
307
+ qkv = self.qkv(self.norm(x))
308
+ h = self.attention(qkv)
309
+ h = self.proj_out(h)
310
+ return (x + h).reshape(b, c, *spatial)
311
+
312
+
313
+ def count_flops_attn(model, _x, y):
314
+ """
315
+ A counter for the `thop` package to count the operations in an
316
+ attention operation.
317
+ Meant to be used like:
318
+ macs, params = thop.profile(
319
+ model,
320
+ inputs=(inputs, timestamps),
321
+ custom_ops={QKVAttention: QKVAttention.count_flops},
322
+ )
323
+ """
324
+ b, c, *spatial = y[0].shape
325
+ num_spatial = int(np.prod(spatial))
326
+ # We perform two matmuls with the same number of ops.
327
+ # The first computes the weight matrix, the second computes
328
+ # the combination of the value vectors.
329
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
330
+ model.total_ops += th.DoubleTensor([matmul_ops])
331
+
332
+
333
+ class QKVAttentionLegacy(nn.Module):
334
+ """
335
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
336
+ """
337
+
338
+ def __init__(self, n_heads):
339
+ super().__init__()
340
+ self.n_heads = n_heads
341
+
342
+ def forward(self, qkv):
343
+ """
344
+ Apply QKV attention.
345
+
346
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
347
+ :return: an [N x (H * C) x T] tensor after attention.
348
+ """
349
+ bs, width, length = qkv.shape
350
+ assert width % (3 * self.n_heads) == 0
351
+ ch = width // (3 * self.n_heads)
352
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
353
+ scale = 1 / math.sqrt(math.sqrt(ch))
354
+ weight = th.einsum(
355
+ "bct,bcs->bts", q * scale, k * scale
356
+ ) # More stable with f16 than dividing afterwards
357
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
358
+ a = th.einsum("bts,bcs->bct", weight, v)
359
+ return a.reshape(bs, -1, length)
360
+
361
+ @staticmethod
362
+ def count_flops(model, _x, y):
363
+ return count_flops_attn(model, _x, y)
364
+
365
+
366
+ class QKVAttention(nn.Module):
367
+ """
368
+ A module which performs QKV attention and splits in a different order.
369
+ """
370
+
371
+ def __init__(self, n_heads):
372
+ super().__init__()
373
+ self.n_heads = n_heads
374
+
375
+ def forward(self, qkv):
376
+ """
377
+ Apply QKV attention.
378
+
379
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
380
+ :return: an [N x (H * C) x T] tensor after attention.
381
+ """
382
+ bs, width, length = qkv.shape
383
+ assert width % (3 * self.n_heads) == 0
384
+ ch = width // (3 * self.n_heads)
385
+ q, k, v = qkv.chunk(3, dim=1)
386
+ scale = 1 / math.sqrt(math.sqrt(ch))
387
+ weight = th.einsum(
388
+ "bct,bcs->bts",
389
+ (q * scale).view(bs * self.n_heads, ch, length),
390
+ (k * scale).view(bs * self.n_heads, ch, length),
391
+ ) # More stable with f16 than dividing afterwards
392
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
393
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
394
+ return a.reshape(bs, -1, length)
395
+
396
+ @staticmethod
397
+ def count_flops(model, _x, y):
398
+ return count_flops_attn(model, _x, y)
399
+
400
+
401
+ class UNetModel(nn.Module):
402
+ """
403
+ The full UNet model with attention and timestep embedding.
404
+
405
+ :param in_channels: channels in the input Tensor.
406
+ :param model_channels: base channel count for the model.
407
+ :param out_channels: channels in the output Tensor.
408
+ :param num_res_blocks: number of residual blocks per downsample.
409
+ :param attention_resolutions: a collection of downsample rates at which
410
+ attention will take place. May be a set, list, or tuple.
411
+ For example, if this contains 4, then at 4x downsampling, attention
412
+ will be used.
413
+ :param dropout: the dropout probability.
414
+ :param channel_mult: channel multiplier for each level of the UNet.
415
+ :param conv_resample: if True, use learned convolutions for upsampling and
416
+ downsampling.
417
+ :param dims: determines if the signal is 1D, 2D, or 3D.
418
+ :param num_classes: if specified (as an int), then this model will be
419
+ class-conditional with `num_classes` classes.
420
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
421
+ :param num_heads: the number of attention heads in each attention layer.
422
+ :param num_heads_channels: if specified, ignore num_heads and instead use
423
+ a fixed channel width per attention head.
424
+ :param num_heads_upsample: works with num_heads to set a different number
425
+ of heads for upsampling. Deprecated.
426
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
427
+ :param resblock_updown: use residual blocks for up/downsampling.
428
+ :param use_new_attention_order: use a different attention pattern for potentially
429
+ increased efficiency.
430
+ """
431
+
432
+ def __init__(
433
+ self,
434
+ image_size,
435
+ in_channels,
436
+ model_channels,
437
+ out_channels,
438
+ num_res_blocks,
439
+ attention_resolutions,
440
+ dropout=0,
441
+ channel_mult=(1, 2, 4, 8),
442
+ conv_resample=True,
443
+ dims=2,
444
+ num_classes=None,
445
+ use_checkpoint=False,
446
+ use_fp16=False,
447
+ num_heads=1,
448
+ num_head_channels=-1,
449
+ num_heads_upsample=-1,
450
+ use_scale_shift_norm=False,
451
+ resblock_updown=False,
452
+ use_new_attention_order=False,
453
+ ):
454
+ super().__init__()
455
+
456
+ if num_heads_upsample == -1:
457
+ num_heads_upsample = num_heads
458
+
459
+ self.image_size = image_size
460
+ self.in_channels = in_channels
461
+ self.model_channels = model_channels
462
+ self.out_channels = out_channels
463
+ self.num_res_blocks = num_res_blocks
464
+ self.attention_resolutions = attention_resolutions
465
+ self.dropout = dropout
466
+ self.channel_mult = channel_mult
467
+ self.conv_resample = conv_resample
468
+ self.num_classes = num_classes
469
+ self.use_checkpoint = use_checkpoint
470
+ self.dtype = th.float16 if use_fp16 else th.float32
471
+ self.num_heads = num_heads
472
+ self.num_head_channels = num_head_channels
473
+ self.num_heads_upsample = num_heads_upsample
474
+
475
+ time_embed_dim = model_channels * 4
476
+ self.time_embed = nn.Sequential(
477
+ linear(model_channels, time_embed_dim),
478
+ nn.SiLU(),
479
+ linear(time_embed_dim, time_embed_dim),
480
+ )
481
+
482
+ if self.num_classes is not None:
483
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
484
+
485
+ ch = input_ch = int(channel_mult[0] * model_channels)
486
+ self.input_blocks = nn.ModuleList(
487
+ [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
488
+ )
489
+ self._feature_size = ch
490
+ input_block_chans = [ch]
491
+ ds = 1
492
+ for level, mult in enumerate(channel_mult):
493
+ for _ in range(num_res_blocks):
494
+ layers = [
495
+ ResBlock(
496
+ ch,
497
+ time_embed_dim,
498
+ dropout,
499
+ out_channels=int(mult * model_channels),
500
+ dims=dims,
501
+ use_checkpoint=use_checkpoint,
502
+ use_scale_shift_norm=use_scale_shift_norm,
503
+ )
504
+ ]
505
+ ch = int(mult * model_channels)
506
+ if ds in attention_resolutions:
507
+ layers.append(
508
+ AttentionBlock(
509
+ ch,
510
+ use_checkpoint=use_checkpoint,
511
+ num_heads=num_heads,
512
+ num_head_channels=num_head_channels,
513
+ use_new_attention_order=use_new_attention_order,
514
+ )
515
+ )
516
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
517
+ self._feature_size += ch
518
+ input_block_chans.append(ch)
519
+ if level != len(channel_mult) - 1:
520
+ out_ch = ch
521
+ self.input_blocks.append(
522
+ TimestepEmbedSequential(
523
+ ResBlock(
524
+ ch,
525
+ time_embed_dim,
526
+ dropout,
527
+ out_channels=out_ch,
528
+ dims=dims,
529
+ use_checkpoint=use_checkpoint,
530
+ use_scale_shift_norm=use_scale_shift_norm,
531
+ down=True,
532
+ )
533
+ if resblock_updown
534
+ else Downsample(
535
+ ch, conv_resample, dims=dims, out_channels=out_ch
536
+ )
537
+ )
538
+ )
539
+ ch = out_ch
540
+ input_block_chans.append(ch)
541
+ ds *= 2
542
+ self._feature_size += ch
543
+
544
+ self.middle_block = TimestepEmbedSequential(
545
+ ResBlock(
546
+ ch,
547
+ time_embed_dim,
548
+ dropout,
549
+ dims=dims,
550
+ use_checkpoint=use_checkpoint,
551
+ use_scale_shift_norm=use_scale_shift_norm,
552
+ ),
553
+ AttentionBlock(
554
+ ch,
555
+ use_checkpoint=use_checkpoint,
556
+ num_heads=num_heads,
557
+ num_head_channels=num_head_channels,
558
+ use_new_attention_order=use_new_attention_order,
559
+ ),
560
+ ResBlock(
561
+ ch,
562
+ time_embed_dim,
563
+ dropout,
564
+ dims=dims,
565
+ use_checkpoint=use_checkpoint,
566
+ use_scale_shift_norm=use_scale_shift_norm,
567
+ ),
568
+ )
569
+ self._feature_size += ch
570
+
571
+ self.output_blocks = nn.ModuleList([])
572
+ for level, mult in list(enumerate(channel_mult))[::-1]:
573
+ for i in range(num_res_blocks + 1):
574
+ ich = input_block_chans.pop()
575
+ layers = [
576
+ ResBlock(
577
+ ch + ich,
578
+ time_embed_dim,
579
+ dropout,
580
+ out_channels=int(model_channels * mult),
581
+ dims=dims,
582
+ use_checkpoint=use_checkpoint,
583
+ use_scale_shift_norm=use_scale_shift_norm,
584
+ )
585
+ ]
586
+ ch = int(model_channels * mult)
587
+ if ds in attention_resolutions:
588
+ layers.append(
589
+ AttentionBlock(
590
+ ch,
591
+ use_checkpoint=use_checkpoint,
592
+ num_heads=num_heads_upsample,
593
+ num_head_channels=num_head_channels,
594
+ use_new_attention_order=use_new_attention_order,
595
+ )
596
+ )
597
+ if level and i == num_res_blocks:
598
+ out_ch = ch
599
+ layers.append(
600
+ ResBlock(
601
+ ch,
602
+ time_embed_dim,
603
+ dropout,
604
+ out_channels=out_ch,
605
+ dims=dims,
606
+ use_checkpoint=use_checkpoint,
607
+ use_scale_shift_norm=use_scale_shift_norm,
608
+ up=True,
609
+ )
610
+ if resblock_updown
611
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
612
+ )
613
+ ds //= 2
614
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
615
+ self._feature_size += ch
616
+
617
+ self.out = nn.Sequential(
618
+ normalization(ch),
619
+ nn.SiLU(),
620
+ zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
621
+ )
622
+
623
+ def convert_to_fp16(self):
624
+ """
625
+ Convert the torso of the model to float16.
626
+ """
627
+ self.input_blocks.apply(convert_module_to_f16)
628
+ self.middle_block.apply(convert_module_to_f16)
629
+ self.output_blocks.apply(convert_module_to_f16)
630
+
631
+ def convert_to_fp32(self):
632
+ """
633
+ Convert the torso of the model to float32.
634
+ """
635
+ self.input_blocks.apply(convert_module_to_f32)
636
+ self.middle_block.apply(convert_module_to_f32)
637
+ self.output_blocks.apply(convert_module_to_f32)
638
+
639
+ def forward(self, x, timesteps, y=None, ref_img=None, edit_h=None):
640
+ """
641
+ Apply the model to an input batch.
642
+
643
+ :param x: an [N x C x ...] Tensor of inputs.
644
+ :param timesteps: a 1-D batch of timesteps.
645
+ :param y: an [N] Tensor of labels, if class-conditional.
646
+ :return: an [N x C x ...] Tensor of outputs.
647
+ """
648
+ # assert (y is not None) == (
649
+ # self.num_classes is not None
650
+ # ), "must specify y if and only if the model is class-conditional"
651
+
652
+ hs = []
653
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
654
+
655
+ # if self.num_classes is not None:
656
+ # assert y.shape == (x.shape[0],)
657
+ # emb = emb + self.label_emb(y)
658
+
659
+ h = x.type(self.dtype)
660
+ for module in self.input_blocks:
661
+ h = module(h, emb)
662
+ hs.append(h)
663
+ h = self.middle_block(h, emb)
664
+ mid_h = h.detach().clone() # get the bottleneck h space embedding
665
+ # print("check Unet:", mid_h.size()) # [1, 512, 8, 8]
666
+ # exit()
667
+ if edit_h is not None:
668
+ h = edit_h
669
+
670
+ for module in self.output_blocks:
671
+ h = th.cat([h, hs.pop()], dim=1)
672
+ h = module(h, emb)
673
+ h = h.type(x.dtype)
674
+ # print("check U-NET output:", h.size(), mid_h.size(), self.out(h).size()) # [1,3,256,256]
675
+ # exit()
676
+
677
+ return mid_h, self.out(h)
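A sketch of how this forward pass is typically driven (the untrained weights and the 0.1-scaled perturbation are illustrative assumptions): it returns the bottleneck activation (`mid_h`, the h-space feature) alongside the usual prediction, and an edited bottleneck can be pushed back in through `edit_h`.

```python
import torch as th
from models.improved_ddpm.script_util import i_DDPM

model = i_DDPM('FFHQ').eval()          # weight loading omitted
x_t = th.randn(1, 3, 256, 256)
t = th.tensor([500])
with th.no_grad():
    mid_h, out = model(x_t, t)                      # mid_h: [1, 512, 8, 8]; out: [1, 6, 256, 256] (eps, sigma)
    edited = mid_h + 0.1 * th.randn_like(mid_h)     # stand-in for a semantic boundary edit
    _, out_edit = model(x_t, t, edit_h=edited)      # decode from the edited bottleneck
```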
models/insight_face/__init__.py ADDED
File without changes
models/insight_face/helpers.py ADDED
@@ -0,0 +1,178 @@
1
+ from collections import namedtuple
2
+ import torch
3
+ from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
4
+
5
+ """
6
+ ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
7
+ """
8
+
9
+
10
+
11
+
12
+ class Conv_block(Module):
13
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
14
+ super(Conv_block, self).__init__()
15
+ self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups, stride=stride, padding=padding, bias=False)
16
+ self.bn = BatchNorm2d(out_c)
17
+ self.prelu = PReLU(out_c)
18
+ def forward(self, x):
19
+ x = self.conv(x)
20
+ x = self.bn(x)
21
+ x = self.prelu(x)
22
+ return x
23
+
24
+ class Linear_block(Module):
25
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
26
+ super(Linear_block, self).__init__()
27
+ self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups, stride=stride, padding=padding, bias=False)
28
+ self.bn = BatchNorm2d(out_c)
29
+ def forward(self, x):
30
+ x = self.conv(x)
31
+ x = self.bn(x)
32
+ return x
33
+
34
+ class Depth_Wise(Module):
35
+ def __init__(self, in_c, out_c, residual = False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
36
+ super(Depth_Wise, self).__init__()
37
+ self.conv = Conv_block(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
38
+ self.conv_dw = Conv_block(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride)
39
+ self.project = Linear_block(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
40
+ self.residual = residual
41
+ def forward(self, x):
42
+ if self.residual:
43
+ short_cut = x
44
+ x = self.conv(x)
45
+ x = self.conv_dw(x)
46
+ x = self.project(x)
47
+ if self.residual:
48
+ output = short_cut + x
49
+ else:
50
+ output = x
51
+ return output
52
+
53
+ class Residual(Module):
54
+ def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
55
+ super(Residual, self).__init__()
56
+ modules = []
57
+ for _ in range(num_block):
58
+ modules.append(Depth_Wise(c, c, residual=True, kernel=kernel, padding=padding, stride=stride, groups=groups))
59
+ self.model = Sequential(*modules)
60
+ def forward(self, x):
61
+ return self.model(x)
62
+
63
+
64
+
65
+
66
+ ######################################################################################
67
+
68
+
69
+ class Flatten(Module):
70
+ def forward(self, input):
71
+ return input.view(input.size(0), -1)
72
+
73
+
74
+ def l2_norm(input, axis=1):
75
+ norm = torch.norm(input, 2, axis, True)
76
+ output = torch.div(input, norm)
77
+ return output
78
+
79
+
80
+ class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
81
+ """ A named tuple describing a ResNet block. """
82
+
83
+
84
+ def get_block(in_channel, depth, num_units, stride=2):
85
+ return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
86
+
87
+
88
+ def get_blocks(num_layers):
89
+ if num_layers == 50:
90
+ blocks = [
91
+ get_block(in_channel=64, depth=64, num_units=3),
92
+ get_block(in_channel=64, depth=128, num_units=4),
93
+ get_block(in_channel=128, depth=256, num_units=14),
94
+ get_block(in_channel=256, depth=512, num_units=3)
95
+ ]
96
+ elif num_layers == 100:
97
+ blocks = [
98
+ get_block(in_channel=64, depth=64, num_units=3),
99
+ get_block(in_channel=64, depth=128, num_units=13),
100
+ get_block(in_channel=128, depth=256, num_units=30),
101
+ get_block(in_channel=256, depth=512, num_units=3)
102
+ ]
103
+ elif num_layers == 152:
104
+ blocks = [
105
+ get_block(in_channel=64, depth=64, num_units=3),
106
+ get_block(in_channel=64, depth=128, num_units=8),
107
+ get_block(in_channel=128, depth=256, num_units=36),
108
+ get_block(in_channel=256, depth=512, num_units=3)
109
+ ]
110
+ else:
111
+ raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
112
+ return blocks
113
+
114
+
115
+ class SEModule(Module):
116
+ def __init__(self, channels, reduction):
117
+ super(SEModule, self).__init__()
118
+ self.avg_pool = AdaptiveAvgPool2d(1)
119
+ self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
120
+ self.relu = ReLU(inplace=True)
121
+ self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
122
+ self.sigmoid = Sigmoid()
123
+
124
+ def forward(self, x):
125
+ module_input = x
126
+ x = self.avg_pool(x)
127
+ x = self.fc1(x)
128
+ x = self.relu(x)
129
+ x = self.fc2(x)
130
+ x = self.sigmoid(x)
131
+ return module_input * x
132
+
133
+
134
+ class bottleneck_IR(Module):
135
+ def __init__(self, in_channel, depth, stride):
136
+ super(bottleneck_IR, self).__init__()
137
+ if in_channel == depth:
138
+ self.shortcut_layer = MaxPool2d(1, stride)
139
+ else:
140
+ self.shortcut_layer = Sequential(
141
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
142
+ BatchNorm2d(depth)
143
+ )
144
+ self.res_layer = Sequential(
145
+ BatchNorm2d(in_channel),
146
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
147
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
148
+ )
149
+
150
+ def forward(self, x):
151
+ shortcut = self.shortcut_layer(x)
152
+ res = self.res_layer(x)
153
+ return res + shortcut
154
+
155
+
156
+ class bottleneck_IR_SE(Module):
157
+ def __init__(self, in_channel, depth, stride):
158
+ super(bottleneck_IR_SE, self).__init__()
159
+ if in_channel == depth:
160
+ self.shortcut_layer = MaxPool2d(1, stride)
161
+ else:
162
+ self.shortcut_layer = Sequential(
163
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
164
+ BatchNorm2d(depth)
165
+ )
166
+ self.res_layer = Sequential(
167
+ BatchNorm2d(in_channel),
168
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
169
+ PReLU(depth),
170
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
171
+ BatchNorm2d(depth),
172
+ SEModule(depth, 16)
173
+ )
174
+
175
+ def forward(self, x):
176
+ shortcut = self.shortcut_layer(x)
177
+ res = self.res_layer(x)
178
+ return res + shortcut
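A tiny sketch of the block-plan helper above:

```python
from models.insight_face.helpers import get_blocks

blocks = get_blocks(50)                  # four stages of Bottleneck specs for the 50-layer backbone
print([len(stage) for stage in blocks])  # [3, 4, 14, 3]
```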
models/insight_face/model_irse.py ADDED
@@ -0,0 +1,124 @@
1
+ from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
2
+ from models.insight_face.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
3
+ from models.insight_face.helpers import Conv_block, Linear_block, Depth_Wise, Residual
4
+ """
5
+ Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
6
+ """
7
+
8
+
9
+ class MobileFaceNet(Module):
10
+ def __init__(self, embedding_size):
11
+ super(MobileFaceNet, self).__init__()
12
+ self.conv1 = Conv_block(3, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1))
13
+ self.conv2_dw = Conv_block(64, 64, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
14
+ self.conv_23 = Depth_Wise(64, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128)
15
+ self.conv_3 = Residual(64, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
16
+ self.conv_34 = Depth_Wise(64, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256)
17
+ self.conv_4 = Residual(128, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
18
+ self.conv_45 = Depth_Wise(128, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512)
19
+ self.conv_5 = Residual(128, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
20
+ self.conv_6_sep = Conv_block(128, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
21
+ self.conv_6_dw = Linear_block(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0))
22
+ self.conv_6_flatten = Flatten()
23
+ self.linear = Linear(512, embedding_size, bias=False)
24
+ self.bn = BatchNorm1d(embedding_size)
25
+
26
+ def forward(self, x):
27
+ out = self.conv1(x)
28
+ out = self.conv2_dw(out)
29
+ out = self.conv_23(out)
30
+ out = self.conv_3(out)
31
+ out = self.conv_34(out)
32
+ out = self.conv_4(out)
33
+ out = self.conv_45(out)
34
+ out = self.conv_5(out)
35
+ out = self.conv_6_sep(out)
36
+ out = self.conv_6_dw(out)
37
+ out = self.conv_6_flatten(out)
38
+ out = self.linear(out)
39
+ out = self.bn(out)
40
+ return l2_norm(out)
41
+
42
+
43
+
44
+
45
+
46
+
47
+ ######################################################################################
48
+
49
+ class Backbone(Module):
50
+ def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
51
+ super(Backbone, self).__init__()
52
+ assert input_size in [112, 224], "input_size should be 112 or 224"
53
+ assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
54
+ assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
55
+ blocks = get_blocks(num_layers)
56
+ if mode == 'ir':
57
+ unit_module = bottleneck_IR
58
+ elif mode == 'ir_se':
59
+ unit_module = bottleneck_IR_SE
60
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
61
+ BatchNorm2d(64),
62
+ PReLU(64))
63
+ if input_size == 112:
64
+ self.output_layer = Sequential(BatchNorm2d(512),
65
+ Dropout(drop_ratio),
66
+ Flatten(),
67
+ Linear(512 * 7 * 7, 512),
68
+ BatchNorm1d(512, affine=affine))
69
+ else:
70
+ self.output_layer = Sequential(BatchNorm2d(512),
71
+ Dropout(drop_ratio),
72
+ Flatten(),
73
+ Linear(512 * 14 * 14, 512),
74
+ BatchNorm1d(512, affine=affine))
75
+
76
+ modules = []
77
+ for block in blocks:
78
+ for bottleneck in block:
79
+ modules.append(unit_module(bottleneck.in_channel,
80
+ bottleneck.depth,
81
+ bottleneck.stride))
82
+ self.body = Sequential(*modules)
83
+
84
+ def forward(self, x):
85
+ x = self.input_layer(x)
86
+ x = self.body(x)
87
+ x = self.output_layer(x)
88
+ return l2_norm(x)
89
+
90
+
91
+ def IR_50(input_size):
92
+ """Constructs a ir-50 model."""
93
+ model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
94
+ return model
95
+
96
+
97
+ def IR_101(input_size):
98
+ """Constructs a ir-101 model."""
99
+ model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
100
+ return model
101
+
102
+
103
+ def IR_152(input_size):
104
+ """Constructs a ir-152 model."""
105
+ model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
106
+ return model
107
+
108
+
109
+ def IR_SE_50(input_size):
110
+ """Constructs a ir_se-50 model."""
111
+ model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
112
+ return model
113
+
114
+
115
+ def IR_SE_101(input_size):
116
+ """Constructs a ir_se-101 model."""
117
+ model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
118
+ return model
119
+
120
+
121
+ def IR_SE_152(input_size):
122
+ """Constructs a ir_se-152 model."""
123
+ model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
124
+ return model
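A usage sketch for the identity backbone (the checkpoint path is a hypothetical placeholder): it is normally loaded frozen and used to extract 512-d face embeddings, e.g. for an identity-preservation loss.

```python
import torch
from models.insight_face.model_irse import IR_SE_50

facenet = IR_SE_50(input_size=112)
# facenet.load_state_dict(torch.load('pretrained/model_ir_se50.pth', map_location='cpu'))  # hypothetical path
facenet.eval()
with torch.no_grad():
    feats = facenet(torch.randn(2, 3, 112, 112))    # L2-normalized embeddings
print(feats.shape)   # torch.Size([2, 512])
```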
requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ cmake>=3.22.3
2
+ lmdb>=1.2.1
3
+ numpy>=1.19.2
4
+ Pillow>=8.4.0
5
+ PyYAML>=6.0
6
+ tqdm>=4.55.1
7
+ opencv_python>=4.5.2.52
8
+ ftfy>=6.0.3
9
+ regex>=2021.10.23
10
+ dlib>=19.22.1
utils/align_utils.py ADDED
@@ -0,0 +1,213 @@
1
+ """
2
+ brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset)
3
+ author: lzhbrian (https://lzhbrian.me)
4
+ date: 2020.1.5
5
+ note: code is heavily borrowed from
6
+ https://github.com/NVlabs/ffhq-dataset
7
+ http://dlib.net/face_landmark_detection.py.html
8
+
9
+ requirements:
10
+ apt install cmake
11
+ conda install Pillow numpy scipy
12
+ pip install dlib
13
+ # download face landmark model from:
14
+ # http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
15
+ """
16
+ from argparse import ArgumentParser
17
+ import time
18
+ import numpy as np
19
+ import PIL
20
+ import PIL.Image
21
+ import os
22
+ import scipy
23
+ import scipy.ndimage
24
+ import dlib
25
+ import multiprocessing as mp
26
+ import math
27
+
28
+ from configs.paths_config import MODEL_PATHS
29
+
30
+ SHAPE_PREDICTOR_PATH = MODEL_PATHS["shape_predictor"]
31
+
32
+
33
+ def run_alignment(image_path, output_size):
34
+ if not os.path.exists("pretrained/shape_predictor_68_face_landmarks.dat"):
35
+ print('Downloading files for aligning face image...')
36
+ os.system(f'wget -P pretrained/ http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2')
37
+ os.system('bzip2 -dk pretrained/shape_predictor_68_face_landmarks.dat.bz2')
38
+ print('Done.')
39
+ predictor = dlib.shape_predictor("pretrained/shape_predictor_68_face_landmarks.dat")
40
+ aligned_image = align_face(filepath=image_path, predictor=predictor, output_size=output_size, transform_size=output_size)
41
+ print("Aligned image has shape: {}".format(aligned_image.size))
42
+ return aligned_image
43
+
44
+
45
+ def get_landmark(filepath, predictor):
46
+ """get landmark with dlib
47
+ :return: np.array shape=(68, 2)
48
+ """
49
+ detector = dlib.get_frontal_face_detector()
50
+
51
+ img = dlib.load_rgb_image(filepath)
52
+ dets = detector(img, 1)
53
+
54
+ for k, d in enumerate(dets):
55
+ shape = predictor(img, d)
56
+
57
+ t = list(shape.parts())
58
+ a = []
59
+ for tt in t:
60
+ a.append([tt.x, tt.y])
61
+ lm = np.array(a)
62
+ return lm
63
+
64
+
65
+ def align_face(filepath, predictor, output_size=256, transform_size=256):
66
+ """
67
+ :param filepath: str
68
+ :return: PIL Image
69
+ """
70
+
71
+ lm = get_landmark(filepath, predictor)
72
+
73
+ lm_chin = lm[0: 17] # left-right
74
+ lm_eyebrow_left = lm[17: 22] # left-right
75
+ lm_eyebrow_right = lm[22: 27] # left-right
76
+ lm_nose = lm[27: 31] # top-down
77
+ lm_nostrils = lm[31: 36] # top-down
78
+ lm_eye_left = lm[36: 42] # left-clockwise
79
+ lm_eye_right = lm[42: 48] # left-clockwise
80
+ lm_mouth_outer = lm[48: 60] # left-clockwise
81
+ lm_mouth_inner = lm[60: 68] # left-clockwise
82
+
83
+ # Calculate auxiliary vectors.
84
+ eye_left = np.mean(lm_eye_left, axis=0)
85
+ eye_right = np.mean(lm_eye_right, axis=0)
86
+ eye_avg = (eye_left + eye_right) * 0.5
87
+ eye_to_eye = eye_right - eye_left
88
+ mouth_left = lm_mouth_outer[0]
89
+ mouth_right = lm_mouth_outer[6]
90
+ mouth_avg = (mouth_left + mouth_right) * 0.5
91
+ eye_to_mouth = mouth_avg - eye_avg
92
+
93
+ # Choose oriented crop rectangle.
94
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
95
+ x /= np.hypot(*x)
96
+ x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
97
+ y = np.flipud(x) * [-1, 1]
98
+ c = eye_avg + eye_to_mouth * 0.1
99
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
100
+ qsize = np.hypot(*x) * 2
101
+
102
+ # read image
103
+ img = PIL.Image.open(filepath)
104
+ enable_padding = True
105
+
106
+ # Shrink.
107
+ shrink = int(np.floor(qsize / output_size * 0.5))
108
+ if shrink > 1:
109
+ rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
110
+ img = img.resize(rsize, PIL.Image.ANTIALIAS)
111
+ quad /= shrink
112
+ qsize /= shrink
113
+
114
+ # Crop.
115
+ border = max(int(np.rint(qsize * 0.1)), 3)
116
+ crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
117
+ int(np.ceil(max(quad[:, 1]))))
118
+ crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
119
+ min(crop[3] + border, img.size[1]))
120
+ if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
121
+ img = img.crop(crop)
122
+ quad -= crop[0:2]
123
+
124
+ # Pad.
125
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
126
+ int(np.ceil(max(quad[:, 1]))))
127
+ pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
128
+ max(pad[3] - img.size[1] + border, 0))
129
+ if enable_padding and max(pad) > border - 4:
130
+ pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
131
+ img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
132
+ h, w, _ = img.shape
133
+ y, x, _ = np.ogrid[:h, :w, :1]
134
+ mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
135
+ 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))
136
+ blur = qsize * 0.02
137
+ img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
138
+ img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
139
+ img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
140
+ quad += pad[:2]
141
+
142
+ # Transform.
143
+ img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
144
+ if output_size < transform_size:
145
+ img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
146
+
147
+ # Save aligned image.
148
+ return img
149
+
150
+
151
+ def chunks(lst, n):
152
+ """Yield successive n-sized chunks from lst."""
153
+ for i in range(0, len(lst), n):
154
+ yield lst[i:i + n]
155
+
156
+
157
+ def extract_on_paths(file_paths):
158
+ predictor = dlib.shape_predictor(SHAPE_PREDICTOR_PATH)
159
+ pid = mp.current_process().name
160
+ print('\t{} is starting to extract on #{} images'.format(pid, len(file_paths)))
161
+ tot_count = len(file_paths)
162
+ count = 0
163
+ for file_path, res_path in file_paths:
164
+ count += 1
165
+ if count % 100 == 0:
166
+ print('{} done with {}/{}'.format(pid, count, tot_count))
167
+ try:
168
+ res = align_face(file_path, predictor)
169
+ res = res.convert('RGB')
170
+ os.makedirs(os.path.dirname(res_path), exist_ok=True)
171
+ res.save(res_path)
172
+ except Exception:
173
+ continue
174
+ print('\tDone!')
175
+
176
+
177
+ def parse_args():
178
+ parser = ArgumentParser(add_help=False)
179
+ parser.add_argument('--num_threads', type=int, default=1)
180
+ parser.add_argument('--root_path', type=str, default='')
181
+ args = parser.parse_args()
182
+ return args
183
+
184
+
185
+ def run(args):
186
+ root_path = args.root_path
187
+ out_crops_path = root_path + '_crops'
188
+ if not os.path.exists(out_crops_path):
189
+ os.makedirs(out_crops_path, exist_ok=True)
190
+
191
+ file_paths = []
192
+ for root, dirs, files in os.walk(root_path):
193
+ for file in files:
194
+ file_path = os.path.join(root, file)
195
+ fname = os.path.join(out_crops_path, os.path.relpath(file_path, root_path))
196
+ res_path = '{}.jpg'.format(os.path.splitext(fname)[0])
197
+ if os.path.splitext(file_path)[1] == '.txt' or os.path.exists(res_path):
198
+ continue
199
+ file_paths.append((file_path, res_path))
200
+
201
+ file_chunks = list(chunks(file_paths, int(math.ceil(len(file_paths) / args.num_threads))))
202
+ print(len(file_chunks))
203
+ pool = mp.Pool(args.num_threads)
204
+ print('Running on {} paths\nHere we goooo'.format(len(file_paths)))
205
+ tic = time.time()
206
+ pool.map(extract_on_paths, file_chunks)
207
+ toc = time.time()
208
+ print('Mischief managed in {}s'.format(toc - tic))
209
+
210
+
211
+ if __name__ == '__main__':
212
+ args = parse_args()
213
+ run(args)
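A minimal call sketch (the image path is a placeholder): `run_alignment` downloads the dlib landmark model on first use and returns the FFHQ-style crop as a PIL image.

```python
from utils.align_utils import run_alignment

aligned = run_alignment('imgs/example.jpg', output_size=256)   # hypothetical input path
aligned.save('imgs/example_aligned.png')
```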
utils/celeba_attr.txt ADDED
@@ -0,0 +1,40 @@
1
+ 5_o_Clock_Shadow
2
+ Arched_Eyebrows
3
+ Attractive
4
+ Bags_Under_Eyes
5
+ Bald
6
+ Bangs
7
+ Big_Lips
8
+ Big_Nose
9
+ Black_Hair
10
+ Blond_Hair
11
+ Blurry
12
+ Brown_Hair
13
+ Bushy_Eyebrows
14
+ Chubby
15
+ Double_Chin
16
+ Eyeglasses
17
+ Goatee
18
+ Gray_Hair
19
+ Heavy_Makeup
20
+ High_Cheekbones
21
+ Male
22
+ Mouth_Slightly_Open
23
+ Mustache
24
+ Narrow_Eyes
25
+ No_Beard
26
+ Oval_Face
27
+ Pale_Skin
28
+ Pointy_Nose
29
+ Receding_Hairline
30
+ Rosy_Cheeks
31
+ Sideburns
32
+ Smiling
33
+ Straight_Hair
34
+ Wavy_Hair
35
+ Wearing_Earrings
36
+ Wearing_Hat
37
+ Wearing_Lipstick
38
+ Wearing_Necklace
39
+ Wearing_Necktie
40
+ Young
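A small sketch relating this attribute list to list_attr_celeba.txt (the file location is an assumption): column 0 of each annotation row is the image name, so the k-th attribute here sits at split index k, which is why utils/prepare_lmdb_data.py reads 'Male' with attr_label[21].

```python
with open('utils/celeba_attr.txt') as f:
    attrs = [line.strip() for line in f if line.strip()]

attr_index = {name: i + 1 for i, name in enumerate(attrs)}
print(attr_index['Male'])   # 21
```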
utils/colab_utils.py ADDED
@@ -0,0 +1,36 @@
1
+ from pydrive.auth import GoogleAuth
2
+ from pydrive.drive import GoogleDrive
3
+ from google.colab import auth
4
+ from oauth2client.client import GoogleCredentials
5
+ import os
6
+
7
+
8
+ class GoogleDrive_Dowonloader(object):
9
+ def __init__(self, use_pydrive):
10
+ self.use_pydrive = use_pydrive
11
+
12
+ if self.use_pydrive:
13
+ self.authenticate()
14
+
15
+ def authenticate(self):
16
+ auth.authenticate_user()
17
+ gauth = GoogleAuth()
18
+ gauth.credentials = GoogleCredentials.get_application_default()
19
+ self.drive = GoogleDrive(gauth)
20
+
21
+ def ensure_file_exists(self, file_id, file_dst):
22
+ if not os.path.isfile(file_dst):
23
+ if self.use_pydrive:
24
+ print(f'Downloading {file_dst} ...')
25
+ downloaded = self.drive.CreateFile({'id':file_id})
26
+ downloaded.FetchMetadata(fetch_all=True)
27
+ downloaded.GetContentFile(file_dst)
28
+ print('Finished')
29
+ else:
30
+ from gdown import download as drive_download
31
+ drive_download(f'https://drive.google.com/uc?id={file_id}', file_dst, quiet=False)
32
+ else:
33
+ print(f'{file_dst} exists.')
34
+
35
+
36
+
utils/diffusion_utils.py ADDED
@@ -0,0 +1,134 @@
 
1
+ import numpy as np
2
+ import torch
3
+
4
+
5
+ def get_beta_schedule(*, beta_start, beta_end, num_diffusion_timesteps):
6
+ betas = np.linspace(beta_start, beta_end,
7
+ num_diffusion_timesteps, dtype=np.float64)
8
+ assert betas.shape == (num_diffusion_timesteps,)
9
+ return betas
10
+
11
+
12
+ def extract(a, t, x_shape):
13
+ """Extract coefficients from a based on t and reshape to make it
14
+ broadcastable with x_shape."""
15
+ bs, = t.shape
16
+ assert x_shape[0] == bs
17
+ out = torch.gather(torch.tensor(a, dtype=torch.float, device=t.device), 0, t.long())
18
+ assert out.shape == (bs,)
19
+ out = out.reshape((bs,) + (1,) * (len(x_shape) - 1))
20
+ return out
21
+
22
+
23
+ def denoising_step(xt, t, t_next, *,
24
+ models,
25
+ logvars,
26
+ b,
27
+ sampling_type='ddpm',
28
+ eta=0.0,
29
+ learn_sigma=False,
30
+ hybrid=False,
31
+ hybrid_config=None,
32
+ ratio=1.0,
33
+ out_x0_t=False,
34
+ edit_h=None,
35
+ ):
36
+
37
+ # Compute noise and variance
38
+ if type(models) != list:
39
+ model = models
40
+ if edit_h is None:
41
+ mid_h, et = model(xt, t)
42
+ # print("check mid_h and et:", mid_h.size(), et.size())
43
+ else:
44
+ mid_h, et = model(xt, t, edit_h=edit_h)  # pass by keyword so it binds to the edit_h argument rather than an earlier positional
45
+ # print("Denoising for editing!")
46
+ if learn_sigma:
47
+ et, logvar_learned = torch.split(et, et.shape[1] // 2, dim=1)
48
+ logvar = logvar_learned
49
+ # print("split et:", et.size())
50
+ else:
51
+ logvar = extract(logvars, t, xt.shape)
52
+ else:
53
+ if not hybrid:
54
+ et = 0
55
+ logvar = 0
56
+ if ratio != 0.0:
57
+ et_i = ratio * models[1](xt, t)
58
+ if learn_sigma:
59
+ et_i, logvar_learned = torch.split(et_i, et_i.shape[1] // 2, dim=1)
60
+ logvar += logvar_learned
61
+ else:
62
+ logvar += ratio * extract(logvars, t, xt.shape)
63
+ et += et_i
64
+
65
+ if ratio != 1.0:
66
+ et_i = (1 - ratio) * models[0](xt, t)
67
+ if learn_sigma:
68
+ et_i, logvar_learned = torch.split(et_i, et_i.shape[1] // 2, dim=1)
69
+ logvar += logvar_learned
70
+ else:
71
+ logvar += (1 - ratio) * extract(logvars, t, xt.shape)
72
+ et += et_i
73
+
74
+ else:
75
+ for thr in list(hybrid_config.keys()):
76
+ if t.item() >= thr:
77
+ et = 0
78
+ logvar = 0
79
+ for i, ratio in enumerate(hybrid_config[thr]):
80
+ ratio /= sum(hybrid_config[thr])
81
+ et_i = models[i+1](xt, t)
82
+ if learn_sigma:
83
+ et_i, logvar_learned = torch.split(et_i, et_i.shape[1] // 2, dim=1)
84
+ logvar_i = logvar_learned
85
+ else:
86
+ logvar_i = extract(logvars, t, xt.shape)
87
+ et += ratio * et_i
88
+ logvar += ratio * logvar_i
89
+ break
90
+
91
+ # Compute the next x
92
+ bt = extract(b, t, xt.shape)
93
+ at = extract((1.0 - b).cumprod(dim=0), t, xt.shape)
94
+
95
+ if t_next.sum() == -t_next.shape[0]:
96
+ at_next = torch.ones_like(at)
97
+ else:
98
+ at_next = extract((1.0 - b).cumprod(dim=0), t_next, xt.shape)
99
+
100
+ xt_next = torch.zeros_like(xt)
101
+ if sampling_type == 'ddpm':
102
+ weight = bt / torch.sqrt(1 - at)
103
+
104
+ mean = 1 / torch.sqrt(1.0 - bt) * (xt - weight * et)
105
+ noise = torch.randn_like(xt)
106
+ mask = 1 - (t == 0).float()
107
+ mask = mask.reshape((xt.shape[0],) + (1,) * (len(xt.shape) - 1))
108
+ xt_next = mean + mask * torch.exp(0.5 * logvar) * noise
109
+ xt_next = xt_next.float()
110
+
111
+ elif sampling_type == 'ddim':
112
+ # print("check ddim incersion:", et.size())
113
+ x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
114
+ if eta == 0:
115
+ xt_next = at_next.sqrt() * x0_t + (1 - at_next).sqrt() * et
116
+ elif at > (at_next):
117
+ print('Inversion process is only possible with eta = 0')
118
+ raise ValueError
119
+ else:
120
+ c1 = eta * ((1 - at / (at_next)) * (1 - at_next) / (1 - at)).sqrt()
121
+ c2 = ((1 - at_next) - c1 ** 2).sqrt()
122
+ xt_next = at_next.sqrt() * x0_t + c2 * et + c1 * torch.randn_like(xt)
123
+
124
+
125
+
126
+ # print("check out:", xt_next.size(), mid_h.size(), x0_t.size())
127
+ if out_x0_t == True:
128
+ # print("three output!")
129
+ return xt_next, x0_t, mid_h
130
+ else:
131
+ # print("two output!")
132
+ return xt_next, mid_h
133
+
134
+
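A sketch of a deterministic DDIM sampling loop built on `denoising_step` (the beta schedule, the 40-step trajectory, and the untrained model are illustrative assumptions):

```python
import numpy as np
import torch
from models.improved_ddpm.script_util import i_DDPM
from utils.diffusion_utils import get_beta_schedule, denoising_step

model = i_DDPM('FFHQ').eval()                      # weight loading omitted
betas = torch.from_numpy(get_beta_schedule(beta_start=0.0001, beta_end=0.02,
                                           num_diffusion_timesteps=1000)).float()
logvars = np.log(np.maximum(betas.numpy(), 1e-20)) # unused when learn_sigma=True, but required

seq = list(range(0, 1000, 25))                     # 40 DDIM steps
seq_prev = [-1] + seq[:-1]
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    for t_cur, t_prev in zip(reversed(seq), reversed(seq_prev)):
        t = torch.full((x.size(0),), t_cur)
        t_next = torch.full((x.size(0),), t_prev)
        x, mid_h = denoising_step(x, t, t_next, models=model, logvars=logvars, b=betas,
                                  sampling_type='ddim', eta=0.0, learn_sigma=True)
```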
utils/prepare_lmdb_data.py ADDED
@@ -0,0 +1,140 @@
1
+ """
2
+ Refer to https://github.com/rosinality/stylegan2-pytorch/blob/master/prepare_data.py
3
+ """
4
+
5
+ import argparse
6
+ from io import BytesIO
7
+ import multiprocessing
8
+ from functools import partial
9
+ import os, glob, sys
10
+
11
+ from PIL import Image
12
+ import lmdb
13
+ from tqdm import tqdm
14
+ from torchvision import datasets
15
+ from torchvision.transforms import functional as trans_fn
16
+
17
+
18
+ def resize_and_convert(img, size, resample, quality=100):
19
+ img = trans_fn.resize(img, (size, size), resample)
20
+ # img = trans_fn.center_crop(img, size)
21
+ buffer = BytesIO()
22
+ img.save(buffer, format="jpeg", quality=quality)
23
+ val = buffer.getvalue()
24
+
25
+ return val
26
+
27
+
28
+ def resize_multiple(
29
+ img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100
30
+ ):
31
+ imgs = []
32
+
33
+ for size in sizes:
34
+ imgs.append(resize_and_convert(img, size, resample, quality))
35
+
36
+ return imgs
37
+
38
+
39
+ def resize_worker(img_file, sizes, resample):
40
+ i, file, img_id = img_file
41
+ # print("check resize_worker:", i, file, img_id)
42
+ img = Image.open(file)
43
+ img = img.convert("RGB")
44
+ out = resize_multiple(img, sizes=sizes, resample=resample)
45
+
46
+ return i, out, img_id
47
+
48
+
49
+ def file_to_list(filename):
50
+ with open(filename, encoding='utf-8') as f:
51
+ files = f.readlines()
52
+ files = [f.rstrip() for f in files]
53
+ return files
54
+
55
+
56
+
57
+ def prepare(
58
+ env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS
59
+ ):
60
+ resize_fn = partial(resize_worker, sizes=sizes, resample=resample)
61
+ files = sorted(dataset.imgs, key=lambda x: x[0])
62
+ files = [(i, file, file.split('/')[-1].split('.')[0]) for i, (file, label) in enumerate(files)]
63
+ total = 0
64
+
65
+ with multiprocessing.Pool(n_worker) as pool:
66
+ for i, imgs, img_id in tqdm(pool.imap_unordered(resize_fn, files)):
67
+ key_label = f"{str(i).zfill(5)}".encode("utf-8")
68
+ for size, img in zip(sizes, imgs):
69
+ key = f"{size}-{str(i).zfill(5)}".encode("utf-8")
70
+ with env.begin(write=True) as txn:
71
+ txn.put(key, img)
72
+ txn.put(key_label, str(img_id).encode("utf-8"))
73
+
74
+ total += 1
75
+
76
+ with env.begin(write=True) as txn:
77
+ txn.put("length".encode("utf-8"), str(total).encode("utf-8"))
78
+
79
+
80
+ def prepare_attr(
81
+ env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, label_attr='gender'
82
+ ):
83
+ resize_fn = partial(resize_worker, sizes=sizes, resample=resample)
84
+ files = sorted(dataset.imgs, key=lambda x: x[0])
85
+ attr_file_path = '/n/fs/yz-diff/inversion/list_attr_celeba.txt'
86
+ labels = file_to_list(attr_file_path)
87
+ attr_dict = {}
88
+ files_attr = []
89
+ for i, (file, split) in enumerate(files):
90
+ img_id = int(file.split('/')[-1].split('.')[0])
91
+ # print("check i, file, and split:", i, file, split, img_id)
92
+ attr_label = labels[img_id-1].split()
93
+ label = int(attr_label[21])
94
+ # print("check attr_label:", attr_label, len(attr_label), label)
95
+ files_attr.append((i, file, label))
96
+ # exit()
97
+
98
+ files = files_attr
99
+ # files = [(i, file) for i, (file, label) in enumerate(files)]
100
+ total = 0
101
+
102
+
103
+ with multiprocessing.Pool(n_worker) as pool:
104
+ for i, imgs, label in tqdm(pool.imap_unordered(resize_fn, files)):
105
+ # print("check i, imgs, label:", label)
106
+ for size, img in zip(sizes, imgs):
107
+ key = f"{size}-{str(i).zfill(5)}".encode("utf-8")
108
+ key_label = f"{'label'}-{str(i).zfill(5)}".encode("utf-8")
109
+
110
+ with env.begin(write=True) as txn:
111
+ txn.put(key, img)
112
+ txn.put(key_label, str(label).encode("utf-8"))
113
+
114
+ total += 1
115
+
116
+ with env.begin(write=True) as txn:
117
+ txn.put("length".encode("utf-8"), str(total).encode("utf-8"))
118
+
119
+
120
+ if __name__ == "__main__":
121
+ parser = argparse.ArgumentParser()
122
+ parser.add_argument("--out", type=str)
123
+ parser.add_argument("--size", type=str, default="128,256,512,1024")
124
+ parser.add_argument("--n_worker", type=int, default=5)
125
+ parser.add_argument("--resample", type=str, default="bilinear")
126
+ parser.add_argument("--attr", type=str)
127
+ parser.add_argument("path", type=str)
128
+
129
+ args = parser.parse_args()
130
+
131
+ resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR}
132
+ resample = resample_map[args.resample]
133
+
134
+ sizes = [int(s.strip()) for s in args.size.split(",")]
135
+ print(f"Make dataset of image sizes:", ", ".join(str(s) for s in sizes))
136
+
137
+ imgset = datasets.ImageFolder(args.path)
138
+
139
+ with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env:
140
+ prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample)
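After building the database with prepare(), entries can be read back as follows (the LMDB path is a placeholder); keys follow the f"{size}-{str(i).zfill(5)}" scheme written above.

```python
from io import BytesIO

import lmdb
from PIL import Image

with lmdb.open('celeba_hq_lmdb', readonly=True, lock=False, readahead=False) as env:   # hypothetical path
    with env.begin(write=False) as txn:
        length = int(txn.get("length".encode("utf-8")).decode("utf-8"))
        img_bytes = txn.get(f"256-{str(0).zfill(5)}".encode("utf-8"))

img = Image.open(BytesIO(img_bytes))
print(length, img.size)
```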
utils/text_dic.py ADDED
@@ -0,0 +1,123 @@
+ SRC_TRG_TXT_DIC = {
+     # Human face
+     'tanned': (['face'],
+                ['tanned face']),
+     'pale': (['face'],
+              ['pale face']),
+     'makeup': (['person'],
+                ['person with makeup']),
+     'no_makeup': (['person'],
+                   ['person without makeup']),
+     'old': (['person'],
+             ['old person']),
+     'young': (['person'],
+               ['young person']),
+     'beards': (['person'],
+                ['person with beards']),
+     'angry': (['face'],
+               ['angry face']),
+     'surprised': (['face'],
+                   ['surprised face']),
+     'smiling': (['face'],
+                 ['smiling face']),
+     'blond_hair': (['person'],
+                    ['person with blond hair']),
+     'red_hair': (['person'],
+                  ['person with red hair']),
+     'grey_hair': (['person'],
+                   ['person with grey hair']),
+     'curly_hair': (['person'],
+                    ['person with curly hair']),
+
+     'nicolas': (['Person'],
+                 ['Nicolas Cage']),
+     'zuckerberg': (['Person'],
+                    ['Mark Zuckerberg']),
+     'benedict': (['Person'],
+                  ['Benedict Cumberbatch']),
+     'gogh': (['photo'],
+              ['painting by Gogh']),
+     'frida': (['photo'],
+               ['self-portrait by Frida Kahlo']),
+     'modigliani': (['photo'],
+                    ['Painting in Modigliani style']),
+     'sketch': (['photo'],
+                ['sketch']),
+     'watercolor': (['photo'],
+                    ['Watercolor Art with Thick Brushstrokes']),
+     'elf': (['Human'],
+             ['Tolkien elf']),
+     'super_saiyan': (['Human'],
+                      ['Super saiyan']),
+     'pixar': (['Human'],
+               ['3D render in the style of Pixar']),
+     'neanderthal': (['Human'],
+                     ['Neanderthal']),
+     'zombie': (['Human'],
+                ['Zombie']),
+     'jocker': (['Human'],
+                ['The Joker']),
+
+     # Dog face
+     'dog_nicolas': (['Dog'],
+                     ['Nicolas Cage']),
+     'dog_yorkshire': (['Dog'],
+                       ['Yorkshire Terrier']),
+     'dog_smiling': (['Dog'],
+                     ['Smiling Dog']),
+     'dog_zombie': (['Dog'],
+                    ['Zombie']),
+     'dog_super_saiyan': (['Dog'],
+                          ['Super saiyan']),
+     'dog_venom': (['Dog'],
+                   ['Venom']),
+     'dog_bear': (['Dog'],
+                  ['Bear']),
+     'dog_fox': (['Dog'],
+                 ['Fox']),
+     'dog_wolf': (['Dog'],
+                  ['Wolf']),
+     'dog_hamster': (['Dog'],
+                     ['Hamster']),
+
+     # Church
+     'church_snow': (['Church'],
+                     ['Snow Covered Church']),
+     'church_night': (['Church'],
+                      ['Church at night']),
+     'church_red_brick': (['Church'],
+                          ['Red brick wall Church']),
+     'church_golden': (['Church'],
+                       ['Golden Church']),
+     'church_wooden_house': (['Church'],
+                             ['Wooden House']),
+     'church_gothic': (['Church'],
+                       ['Gothic Church']),
+     'church_ancient_tower': (['Church'],
+                              ['Ancient traditional Asian tower']),
+     'church_temple': (['Church'],
+                       ['Temple']),
+     'church_factory': (['church'],
+                        ['factory with chimneys']),
+     'church_department_store': (['church'],
+                                 ['department store']),
+
+     # Bedroom
+     'bedroom_blue': (['Bedroom'],
+                      ['Blue tone Bedroom']),
+     'bedroom_green': (['Bedroom'],
+                       ['Green tone Bedroom']),
+     'bedroom_golden': (['Bedroom'],
+                        ['Golden Bedroom']),
+     'bedroom_princess': (['Bedroom'],
+                          ['Princess Bedroom']),
+     'bedroom_palace': (['Bedroom'],
+                        ['Palace Bedroom']),
+     'bedroom_wooden': (['Bedroom'],
+                        ['Wooden Bedroom']),
+ }
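
For context, a minimal sketch of how these entries are typically consumed; the import path follows this file's location, while the downstream use (building a source/target prompt pair for a CLIP-guided edit) is an assumption rather than code shown in this upload. The key 'smiling' is just an example.

from utils.text_dic import SRC_TRG_TXT_DIC

# Each entry maps an edit name to (source prompts, target prompts).
src_txts, trg_txts = SRC_TRG_TXT_DIC['smiling']
print(src_txts[0], '->', trg_txts[0])  # face -> smiling face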
utils/text_templates.py ADDED
@@ -0,0 +1,129 @@
+ imagenet_templates = [
+     'a bad photo of a {}.',
+     'a sculpture of a {}.',
+     'a photo of the hard to see {}.',
+     'a low resolution photo of the {}.',
+     'a rendering of a {}.',
+     'graffiti of a {}.',
+     'a bad photo of the {}.',
+     'a cropped photo of the {}.',
+     'a tattoo of a {}.',
+     'the embroidered {}.',
+     'a photo of a hard to see {}.',
+     'a bright photo of a {}.',
+     'a photo of a clean {}.',
+     'a photo of a dirty {}.',
+     'a dark photo of the {}.',
+     'a drawing of a {}.',
+     'a photo of my {}.',
+     'the plastic {}.',
+     'a photo of the cool {}.',
+     'a close-up photo of a {}.',
+     'a black and white photo of the {}.',
+     'a painting of the {}.',
+     'a painting of a {}.',
+     'a pixelated photo of the {}.',
+     'a sculpture of the {}.',
+     'a bright photo of the {}.',
+     'a cropped photo of a {}.',
+     'a plastic {}.',
+     'a photo of the dirty {}.',
+     'a jpeg corrupted photo of a {}.',
+     'a blurry photo of the {}.',
+     'a photo of the {}.',
+     'a good photo of the {}.',
+     'a rendering of the {}.',
+     'a {} in a video game.',
+     'a photo of one {}.',
+     'a doodle of a {}.',
+     'a close-up photo of the {}.',
+     'a photo of a {}.',
+     'the origami {}.',
+     'the {} in a video game.',
+     'a sketch of a {}.',
+     'a doodle of the {}.',
+     'a origami {}.',
+     'a low resolution photo of a {}.',
+     'the toy {}.',
+     'a rendition of the {}.',
+     'a photo of the clean {}.',
+     'a photo of a large {}.',
+     'a rendition of a {}.',
+     'a photo of a nice {}.',
+     'a photo of a weird {}.',
+     'a blurry photo of a {}.',
+     'a cartoon {}.',
+     'art of a {}.',
+     'a sketch of the {}.',
+     'a embroidered {}.',
+     'a pixelated photo of a {}.',
+     'itap of the {}.',
+     'a jpeg corrupted photo of the {}.',
+     'a good photo of a {}.',
+     'a plushie {}.',
+     'a photo of the nice {}.',
+     'a photo of the small {}.',
+     'a photo of the weird {}.',
+     'the cartoon {}.',
+     'art of the {}.',
+     'a drawing of the {}.',
+     'a photo of the large {}.',
+     'a black and white photo of a {}.',
+     'the plushie {}.',
+     'a dark photo of a {}.',
+     'itap of a {}.',
+     'graffiti of the {}.',
+     'a toy {}.',
+     'itap of my {}.',
+     'a photo of a cool {}.',
+     'a photo of a small {}.',
+     'a tattoo of the {}.',
+ ]
+
+ part_templates = [
+     'the paw of a {}.',
+     'the nose of a {}.',
+     'the eye of the {}.',
+     'the ears of a {}.',
+     'an eye of a {}.',
+     'the tongue of a {}.',
+     'the fur of the {}.',
+     'colorful {} fur.',
+     'a snout of a {}.',
+     'the teeth of the {}.',
+     'the {}s fangs.',
+     'a claw of the {}.',
+     'the face of the {}',
+     'a neck of a {}',
+     'the head of the {}',
+ ]
+
+ imagenet_templates_small = [
+     'a photo of a {}.',
+     'a rendering of a {}.',
+     'a cropped photo of the {}.',
+     'the photo of a {}.',
+     'a photo of a clean {}.',
+     'a photo of a dirty {}.',
+     'a dark photo of the {}.',
+     'a photo of my {}.',
+     'a photo of the cool {}.',
+     'a close-up photo of a {}.',
+     'a bright photo of the {}.',
+     'a cropped photo of a {}.',
+     'a photo of the {}.',
+     'a good photo of the {}.',
+     'a photo of one {}.',
+     'a close-up photo of the {}.',
+     'a rendition of the {}.',
+     'a photo of the clean {}.',
+     'a rendition of a {}.',
+     'a photo of a nice {}.',
+     'a good photo of a {}.',
+     'a photo of the nice {}.',
+     'a photo of the small {}.',
+     'a photo of the weird {}.',
+     'a photo of the large {}.',
+     'a photo of a cool {}.',
+     'a photo of a small {}.',
+ ]
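
A minimal sketch of the usual way such template lists are consumed: average CLIP text embeddings of one class name over every template phrasing. The openai/CLIP package calls below (clip.load, clip.tokenize, encode_text) are real, but the model name, device choice, and the helper name template_embedding are placeholder assumptions; this is not code shipped in this upload.

import clip
import torch

from utils.text_templates import imagenet_templates_small

device = "cuda" if torch.cuda.is_available() else "cpu"
model, _ = clip.load("ViT-B/32", device=device)  # placeholder backbone choice

def template_embedding(class_name, templates=imagenet_templates_small):
    # Fill every template with the class name and average the normalized text features.
    prompts = [t.format(class_name) for t in templates]
    tokens = clip.tokenize(prompts).to(device)
    with torch.no_grad():
        feats = model.encode_text(tokens)
    feats = feats / feats.norm(dim=-1, keepdim=True)
    return feats.mean(dim=0)

text_feat = template_embedding("smiling face")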