Spaces:
Running
on
Zero
Running
on
Zero
main
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +12 -0
- LICENSE +21 -0
- README.md +2 -2
- flux_inference_recraft.py +442 -0
- flux_minimal_inference.py +576 -0
- flux_minimal_inference_asylora.py +583 -0
- flux_train_network.py +588 -0
- flux_train_network_asylora.py +591 -0
- flux_train_recraft.py +713 -0
- gradio_app.py +372 -0
- gradio_app_asy.py +329 -0
- library/__init__.py +0 -0
- library/adafactor_fused.py +138 -0
- library/attention_processors.py +227 -0
- library/config_util.py +716 -0
- library/custom_offloading_utils.py +227 -0
- library/custom_train_functions.py +559 -0
- library/deepspeed_utils.py +139 -0
- library/device_utils.py +84 -0
- library/flux_models.py +1237 -0
- library/flux_train_utils.py +582 -0
- library/flux_train_utils_recraft.py +659 -0
- library/flux_utils.py +472 -0
- library/huggingface_util.py +84 -0
- library/hypernetwork.py +223 -0
- library/ipex/__init__.py +180 -0
- library/ipex/attention.py +177 -0
- library/ipex/diffusers.py +312 -0
- library/ipex/gradscaler.py +183 -0
- library/ipex/hijacks.py +313 -0
- library/lpw_stable_diffusion.py +1233 -0
- library/model_util.py +1356 -0
- library/original_unet.py +1919 -0
- library/sai_model_spec.py +334 -0
- library/sd3_models.py +1413 -0
- library/sd3_train_utils.py +945 -0
- library/sd3_utils.py +302 -0
- library/sdxl_lpw_stable_diffusion.py +1271 -0
- library/sdxl_model_util.py +583 -0
- library/sdxl_original_control_net.py +272 -0
- library/sdxl_original_unet.py +1292 -0
- library/sdxl_train_util.py +382 -0
- library/slicing_vae.py +682 -0
- library/strategy_base.py +570 -0
- library/strategy_flux.py +271 -0
- library/strategy_sd.py +171 -0
- library/strategy_sd3.py +420 -0
- library/strategy_sdxl.py +306 -0
- library/train_util.py +0 -0
- library/utils.py +582 -0
.gitignore
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__
|
2 |
+
test
|
3 |
+
*.egg-info
|
4 |
+
.vscode
|
5 |
+
.gradio
|
6 |
+
wandb
|
7 |
+
Merge
|
8 |
+
asy_results
|
9 |
+
recraft_results
|
10 |
+
drop
|
11 |
+
SplitAsy
|
12 |
+
example*
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2025 Show Lab
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -4,8 +4,8 @@ emoji: 🖼
|
|
4 |
colorFrom: purple
|
5 |
colorTo: red
|
6 |
sdk: gradio
|
7 |
-
sdk_version:
|
8 |
-
app_file:
|
9 |
pinned: false
|
10 |
license: mit
|
11 |
short_description: Generate high quality images from prmopts
|
|
|
4 |
colorFrom: purple
|
5 |
colorTo: red
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.38.0
|
8 |
+
app_file: gradio_app_asy.py
|
9 |
pinned: false
|
10 |
license: mit
|
11 |
short_description: Generate high quality images from prmopts
|
flux_inference_recraft.py
ADDED
@@ -0,0 +1,442 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import copy
|
3 |
+
import math
|
4 |
+
import random
|
5 |
+
from typing import Any
|
6 |
+
import pdb
|
7 |
+
import os
|
8 |
+
|
9 |
+
import time
|
10 |
+
from PIL import Image, ImageOps
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from accelerate import Accelerator
|
14 |
+
from library.device_utils import clean_memory_on_device
|
15 |
+
from safetensors.torch import load_file
|
16 |
+
from networks import lora_flux
|
17 |
+
|
18 |
+
from library import flux_models, flux_train_utils_recraft as flux_train_utils, flux_utils, sd3_train_utils, \
|
19 |
+
strategy_base, strategy_flux, train_util
|
20 |
+
from torchvision import transforms
|
21 |
+
import train_network
|
22 |
+
from library.utils import setup_logging
|
23 |
+
from diffusers.utils import load_image
|
24 |
+
import numpy as np
|
25 |
+
|
26 |
+
setup_logging()
|
27 |
+
import logging
|
28 |
+
|
29 |
+
logger = logging.getLogger(__name__)
|
30 |
+
|
31 |
+
|
32 |
+
def load_target_model(
|
33 |
+
fp8_base: bool,
|
34 |
+
pretrained_model_name_or_path: str,
|
35 |
+
disable_mmap_load_safetensors: bool,
|
36 |
+
clip_l_path: str,
|
37 |
+
fp8_base_unet: bool,
|
38 |
+
t5xxl_path: str,
|
39 |
+
ae_path: str,
|
40 |
+
weight_dtype: torch.dtype,
|
41 |
+
accelerator: Accelerator
|
42 |
+
):
|
43 |
+
# Determine the loading data type
|
44 |
+
loading_dtype = None if fp8_base else weight_dtype
|
45 |
+
|
46 |
+
# Load the main model to the accelerator's device
|
47 |
+
_, model = flux_utils.load_flow_model(
|
48 |
+
pretrained_model_name_or_path,
|
49 |
+
# loading_dtype,
|
50 |
+
torch.float8_e4m3fn,
|
51 |
+
# accelerator.device, # Changed from "cpu" to accelerator.device
|
52 |
+
"cpu",
|
53 |
+
disable_mmap=disable_mmap_load_safetensors
|
54 |
+
)
|
55 |
+
|
56 |
+
if fp8_base:
|
57 |
+
# Check dtype of the model
|
58 |
+
if model.dtype in {torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz}:
|
59 |
+
raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
|
60 |
+
elif model.dtype == torch.float8_e4m3fn:
|
61 |
+
logger.info("Loaded fp8 FLUX model")
|
62 |
+
|
63 |
+
# Load the CLIP model to the accelerator's device
|
64 |
+
clip_l = flux_utils.load_clip_l(
|
65 |
+
clip_l_path,
|
66 |
+
weight_dtype,
|
67 |
+
# accelerator.device, # Changed from "cpu" to accelerator.device
|
68 |
+
"cpu",
|
69 |
+
disable_mmap=disable_mmap_load_safetensors
|
70 |
+
)
|
71 |
+
clip_l.eval()
|
72 |
+
|
73 |
+
# Determine the loading data type for T5XXL
|
74 |
+
if fp8_base and not fp8_base_unet:
|
75 |
+
loading_dtype_t5xxl = None # as is
|
76 |
+
else:
|
77 |
+
loading_dtype_t5xxl = weight_dtype
|
78 |
+
|
79 |
+
# Load the T5XXL model to the accelerator's device
|
80 |
+
t5xxl = flux_utils.load_t5xxl(
|
81 |
+
t5xxl_path,
|
82 |
+
loading_dtype_t5xxl,
|
83 |
+
# accelerator.device, # Changed from "cpu" to accelerator.device
|
84 |
+
"cpu",
|
85 |
+
disable_mmap=disable_mmap_load_safetensors
|
86 |
+
)
|
87 |
+
t5xxl.eval()
|
88 |
+
|
89 |
+
if fp8_base and not fp8_base_unet:
|
90 |
+
# Check dtype of the T5XXL model
|
91 |
+
if t5xxl.dtype in {torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz}:
|
92 |
+
raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
|
93 |
+
elif t5xxl.dtype == torch.float8_e4m3fn:
|
94 |
+
logger.info("Loaded fp8 T5XXL model")
|
95 |
+
|
96 |
+
# Load the AE model to the accelerator's device
|
97 |
+
ae = flux_utils.load_ae(
|
98 |
+
ae_path,
|
99 |
+
weight_dtype,
|
100 |
+
# accelerator.device, # Changed from "cpu" to accelerator.device
|
101 |
+
"cpu",
|
102 |
+
disable_mmap=disable_mmap_load_safetensors
|
103 |
+
)
|
104 |
+
|
105 |
+
# # Wrap models with Accelerator for potential distributed setups
|
106 |
+
# model, clip_l, t5xxl, ae = accelerator.prepare(model, clip_l, t5xxl, ae)
|
107 |
+
|
108 |
+
return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
|
109 |
+
|
110 |
+
|
111 |
+
import torchvision.transforms as transforms
|
112 |
+
|
113 |
+
|
114 |
+
class ResizeWithPadding:
|
115 |
+
def __init__(self, size, fill=255):
|
116 |
+
self.size = size
|
117 |
+
self.fill = fill
|
118 |
+
|
119 |
+
def __call__(self, img):
|
120 |
+
if isinstance(img, np.ndarray):
|
121 |
+
img = Image.fromarray(img)
|
122 |
+
elif not isinstance(img, Image.Image):
|
123 |
+
raise TypeError("Input must be a PIL Image or a NumPy array")
|
124 |
+
|
125 |
+
width, height = img.size
|
126 |
+
|
127 |
+
if width == height:
|
128 |
+
img = img.resize((self.size, self.size), Image.LANCZOS)
|
129 |
+
else:
|
130 |
+
max_dim = max(width, height)
|
131 |
+
|
132 |
+
new_img = Image.new("RGB", (max_dim, max_dim), (self.fill, self.fill, self.fill))
|
133 |
+
new_img.paste(img, ((max_dim - width) // 2, (max_dim - height) // 2))
|
134 |
+
|
135 |
+
img = new_img.resize((self.size, self.size), Image.LANCZOS)
|
136 |
+
|
137 |
+
return img
|
138 |
+
|
139 |
+
|
140 |
+
def sample(args, accelerator, vae, text_encoder, flux, output_dir, sample_images, sample_prompts):
|
141 |
+
def encode_images_to_latents(vae, images):
|
142 |
+
# Get image dimensions
|
143 |
+
b, c, h, w = images.shape
|
144 |
+
num_split = 2 if args.frame_num == 4 else 3
|
145 |
+
# Split the image into three parts
|
146 |
+
img_parts = [images[:, :, :, i * w // num_split:(i + 1) * w // num_split] for i in range(num_split)]
|
147 |
+
# Encode each part
|
148 |
+
latents = [vae.encode(img) for img in img_parts]
|
149 |
+
# Concatenate latents in the latent space to reconstruct the full image
|
150 |
+
latents = torch.cat(latents, dim=-1)
|
151 |
+
return latents
|
152 |
+
|
153 |
+
def encode_images_to_latents2(vae, images):
|
154 |
+
latents = vae.encode(images)
|
155 |
+
return latents
|
156 |
+
|
157 |
+
# Directly use precomputed conditions
|
158 |
+
conditions = {}
|
159 |
+
with torch.no_grad():
|
160 |
+
for image_path, prompt_dict in zip(sample_images, sample_prompts):
|
161 |
+
prompt = prompt_dict.get("prompt", "")
|
162 |
+
if prompt not in conditions:
|
163 |
+
logger.info(f"Cache conditions for image: {image_path} with prompt: {prompt}")
|
164 |
+
resize_transform = ResizeWithPadding(size=512, fill=255) if args.frame_num == 4 else ResizeWithPadding(size=352, fill=255)
|
165 |
+
img_transforms = transforms.Compose([
|
166 |
+
resize_transform,
|
167 |
+
transforms.ToTensor(),
|
168 |
+
transforms.Normalize([0.5], [0.5]),
|
169 |
+
])
|
170 |
+
# Load and preprocess image
|
171 |
+
image = img_transforms(np.array(load_image(image_path), dtype=np.uint8)).unsqueeze(0).to(
|
172 |
+
# accelerator.device, # Move image to CUDA
|
173 |
+
vae.device,
|
174 |
+
dtype=vae.dtype
|
175 |
+
)
|
176 |
+
latents = encode_images_to_latents2(vae, image)
|
177 |
+
|
178 |
+
# Log the shape of latents
|
179 |
+
logger.debug(f"Encoded latents shape for prompt '{prompt}': {latents.shape}")
|
180 |
+
# Store conditions on CUDA
|
181 |
+
# conditions[prompt] = latents[:,:,latents.shape[2]//2:latents.shape[2], :latents.shape[3]//2].to("cpu")
|
182 |
+
conditions[prompt] = latents.to("cpu")
|
183 |
+
|
184 |
+
sample_conditions = conditions
|
185 |
+
|
186 |
+
if sample_conditions is not None:
|
187 |
+
conditions = {k: v for k, v in sample_conditions.items()} # Already on CUDA
|
188 |
+
|
189 |
+
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
|
190 |
+
text_encoder[0].to(accelerator.device)
|
191 |
+
text_encoder[1].to(accelerator.device)
|
192 |
+
|
193 |
+
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(512)
|
194 |
+
text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(True)
|
195 |
+
|
196 |
+
with accelerator.autocast(), torch.no_grad():
|
197 |
+
for prompt_dict in sample_prompts:
|
198 |
+
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
|
199 |
+
if p not in sample_prompts_te_outputs:
|
200 |
+
logger.info(f"Cache Text Encoder outputs for prompt: {p}")
|
201 |
+
tokens_and_masks = tokenize_strategy.tokenize(p)
|
202 |
+
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
|
203 |
+
tokenize_strategy, text_encoder, tokens_and_masks, True
|
204 |
+
)
|
205 |
+
|
206 |
+
logger.info(f"Generating image")
|
207 |
+
save_dir = output_dir
|
208 |
+
os.makedirs(save_dir, exist_ok=True)
|
209 |
+
|
210 |
+
with torch.no_grad(), accelerator.autocast():
|
211 |
+
for prompt_dict in sample_prompts:
|
212 |
+
sample_image_inference(
|
213 |
+
args,
|
214 |
+
accelerator,
|
215 |
+
flux,
|
216 |
+
text_encoder,
|
217 |
+
vae,
|
218 |
+
save_dir,
|
219 |
+
prompt_dict,
|
220 |
+
sample_prompts_te_outputs,
|
221 |
+
None,
|
222 |
+
conditions
|
223 |
+
)
|
224 |
+
|
225 |
+
clean_memory_on_device(accelerator.device)
|
226 |
+
|
227 |
+
|
228 |
+
def sample_image_inference(
|
229 |
+
args,
|
230 |
+
accelerator: Accelerator,
|
231 |
+
flux: flux_models.Flux,
|
232 |
+
text_encoder,
|
233 |
+
ae: flux_models.AutoEncoder,
|
234 |
+
save_dir,
|
235 |
+
prompt_dict,
|
236 |
+
sample_prompts_te_outputs,
|
237 |
+
prompt_replacement,
|
238 |
+
sample_images_ae_outputs
|
239 |
+
):
|
240 |
+
# Extract parameters from prompt_dict
|
241 |
+
sample_steps = prompt_dict.get("sample_steps", 20)
|
242 |
+
width = prompt_dict.get("width", 1024) if args.frame_num == 4 else prompt_dict.get("width", 1056)
|
243 |
+
height = prompt_dict.get("height", 1024) if args.frame_num == 4 else prompt_dict.get("height", 1056)
|
244 |
+
scale = prompt_dict.get("scale", 1.0)
|
245 |
+
seed = prompt_dict.get("seed")
|
246 |
+
prompt: str = prompt_dict.get("prompt", "")
|
247 |
+
|
248 |
+
if prompt_replacement is not None:
|
249 |
+
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
250 |
+
|
251 |
+
if seed is not None:
|
252 |
+
torch.manual_seed(seed)
|
253 |
+
torch.cuda.manual_seed(seed)
|
254 |
+
else:
|
255 |
+
# True random sample image generation
|
256 |
+
torch.seed()
|
257 |
+
torch.cuda.seed()
|
258 |
+
|
259 |
+
# Ensure height and width are divisible by 16
|
260 |
+
height = max(64, height - height % 16)
|
261 |
+
width = max(64, width - width % 16)
|
262 |
+
logger.info(f"prompt: {prompt}")
|
263 |
+
logger.info(f"height: {height}")
|
264 |
+
logger.info(f"width: {width}")
|
265 |
+
logger.info(f"sample_steps: {sample_steps}")
|
266 |
+
logger.info(f"scale: {scale}")
|
267 |
+
if seed is not None:
|
268 |
+
logger.info(f"seed: {seed}")
|
269 |
+
|
270 |
+
# Encode prompts
|
271 |
+
# Assuming that TokenizeStrategy and TextEncodingStrategy are compatible with Accelerator
|
272 |
+
text_encoder_conds = []
|
273 |
+
if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
|
274 |
+
text_encoder_conds = sample_prompts_te_outputs[prompt]
|
275 |
+
logger.info(f"Using cached text encoder outputs for prompt: {prompt}")
|
276 |
+
|
277 |
+
if sample_images_ae_outputs and prompt in sample_images_ae_outputs:
|
278 |
+
ae_outputs = sample_images_ae_outputs[prompt]
|
279 |
+
else:
|
280 |
+
ae_outputs = None
|
281 |
+
|
282 |
+
# ae_outputs = torch.load('ae_outputs.pth', map_location='cuda:0')
|
283 |
+
|
284 |
+
# text_encoder_conds = torch.load('text_encoder_conds.pth', map_location='cuda:0')
|
285 |
+
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
286 |
+
|
287 |
+
# 打印调试信息
|
288 |
+
logger.debug(
|
289 |
+
f"l_pooled shape: {l_pooled.shape}, t5_out shape: {t5_out.shape}, txt_ids shape: {txt_ids.shape}, t5_attn_mask shape: {t5_attn_mask.shape}")
|
290 |
+
|
291 |
+
# 采样图像
|
292 |
+
weight_dtype = ae.dtype # TODO: give dtype as argument
|
293 |
+
packed_latent_height = height // 16
|
294 |
+
packed_latent_width = width // 16
|
295 |
+
|
296 |
+
# 打印调试信息
|
297 |
+
logger.debug(f"packed_latent_height: {packed_latent_height}, packed_latent_width: {packed_latent_width}")
|
298 |
+
|
299 |
+
# 准备噪声张量在 CUDA 上
|
300 |
+
noise = torch.randn(
|
301 |
+
1,
|
302 |
+
packed_latent_height * packed_latent_width,
|
303 |
+
16 * 2 * 2,
|
304 |
+
device=accelerator.device,
|
305 |
+
dtype=weight_dtype,
|
306 |
+
generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
|
307 |
+
)
|
308 |
+
|
309 |
+
timesteps = flux_train_utils.get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True
|
310 |
+
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(
|
311 |
+
accelerator.device, dtype=weight_dtype
|
312 |
+
)
|
313 |
+
t5_attn_mask = t5_attn_mask.to(accelerator.device)
|
314 |
+
|
315 |
+
clip_l, t5xxl = text_encoder
|
316 |
+
# ae.to("cpu")
|
317 |
+
clip_l.to("cpu")
|
318 |
+
t5xxl.to("cpu")
|
319 |
+
|
320 |
+
clean_memory_on_device(accelerator.device)
|
321 |
+
flux.to("cuda")
|
322 |
+
|
323 |
+
for param in flux.parameters():
|
324 |
+
param.requires_grad = False
|
325 |
+
|
326 |
+
# 执行去噪
|
327 |
+
with accelerator.autocast(), torch.no_grad():
|
328 |
+
x = flux_train_utils.denoise(args, flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps,
|
329 |
+
guidance=scale, t5_attn_mask=t5_attn_mask, ae_outputs=ae_outputs)
|
330 |
+
|
331 |
+
# 打印x的形状
|
332 |
+
logger.debug(f"x shape after denoise: {x.shape}")
|
333 |
+
|
334 |
+
x = x.float()
|
335 |
+
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
|
336 |
+
|
337 |
+
# 将潜在向量转换为图像
|
338 |
+
# clean_memory_on_device(accelerator.device)
|
339 |
+
ae.to(accelerator.device)
|
340 |
+
with accelerator.autocast(), torch.no_grad():
|
341 |
+
x = ae.decode(x)
|
342 |
+
ae.to("cpu")
|
343 |
+
clean_memory_on_device(accelerator.device)
|
344 |
+
|
345 |
+
x = x.clamp(-1, 1)
|
346 |
+
x = x.permute(0, 2, 3, 1)
|
347 |
+
image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
|
348 |
+
|
349 |
+
# 生成唯一的文件名
|
350 |
+
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
|
351 |
+
seed_suffix = "" if seed is None else f"_{seed}"
|
352 |
+
i: int = prompt_dict.get("enum", 0) # Ensure 'enum' exists
|
353 |
+
img_filename = f"{ts_str}{seed_suffix}_{i}.png" # Added 'i' to filename for uniqueness
|
354 |
+
image.save(os.path.join(save_dir, img_filename))
|
355 |
+
|
356 |
+
|
357 |
+
def setup_argparse():
|
358 |
+
parser = argparse.ArgumentParser(description="FLUX-Controlnet-Inpainting Inference Script")
|
359 |
+
|
360 |
+
# Paths
|
361 |
+
parser.add_argument('--base_flux_checkpoint', type=str, required=True,
|
362 |
+
help='Path to BASE_FLUX_CHECKPOINT')
|
363 |
+
parser.add_argument('--lora_weights_path', type=str, required=True,
|
364 |
+
help='Path to LORA_WEIGHTS_PATH')
|
365 |
+
parser.add_argument('--clip_l_path', type=str, required=True,
|
366 |
+
help='Path to CLIP_L_PATH')
|
367 |
+
parser.add_argument('--t5xxl_path', type=str, required=True,
|
368 |
+
help='Path to T5XXL_PATH')
|
369 |
+
parser.add_argument('--ae_path', type=str, required=True,
|
370 |
+
help='Path to AE_PATH')
|
371 |
+
parser.add_argument('--sample_images_file', type=str, required=True,
|
372 |
+
help='Path to SAMPLE_IMAGES_FILE')
|
373 |
+
parser.add_argument('--sample_prompts_file', type=str, required=True,
|
374 |
+
help='Path to SAMPLE_PROMPTS_FILE')
|
375 |
+
parser.add_argument('--output_dir', type=str, required=True,
|
376 |
+
help='Directory to save OUTPUT_DIR')
|
377 |
+
parser.add_argument('--frame_num', type=int, choices=[4, 9], required=True,
|
378 |
+
help="The number of steps in the generated step diagram (choose 4 or 9)")
|
379 |
+
|
380 |
+
return parser.parse_args()
|
381 |
+
|
382 |
+
|
383 |
+
def main(args):
|
384 |
+
accelerator = Accelerator(mixed_precision='bf16', device_placement=True)
|
385 |
+
|
386 |
+
BASE_FLUX_CHECKPOINT = args.base_flux_checkpoint
|
387 |
+
LORA_WEIGHTS_PATH = args.lora_weights_path
|
388 |
+
CLIP_L_PATH = args.clip_l_path
|
389 |
+
T5XXL_PATH = args.t5xxl_path
|
390 |
+
AE_PATH = args.ae_path
|
391 |
+
|
392 |
+
SAMPLE_IMAGES_FILE = args.sample_images_file
|
393 |
+
SAMPLE_PROMPTS_FILE = args.sample_prompts_file
|
394 |
+
OUTPUT_DIR = args.output_dir
|
395 |
+
|
396 |
+
with open(SAMPLE_IMAGES_FILE, "r", encoding="utf-8") as f:
|
397 |
+
image_lines = f.readlines()
|
398 |
+
sample_images = [line.strip() for line in image_lines if line.strip() and not line.strip().startswith("#")]
|
399 |
+
|
400 |
+
sample_prompts = train_util.load_prompts(SAMPLE_PROMPTS_FILE)
|
401 |
+
|
402 |
+
# Load models onto CUDA via Accelerator
|
403 |
+
_, [clip_l, t5xxl], ae, model = load_target_model(
|
404 |
+
fp8_base=True,
|
405 |
+
pretrained_model_name_or_path=BASE_FLUX_CHECKPOINT,
|
406 |
+
disable_mmap_load_safetensors=False,
|
407 |
+
clip_l_path=CLIP_L_PATH,
|
408 |
+
fp8_base_unet=False,
|
409 |
+
t5xxl_path=T5XXL_PATH,
|
410 |
+
ae_path=AE_PATH,
|
411 |
+
weight_dtype=torch.bfloat16,
|
412 |
+
accelerator=accelerator
|
413 |
+
)
|
414 |
+
|
415 |
+
model.eval()
|
416 |
+
clip_l.eval()
|
417 |
+
t5xxl.eval()
|
418 |
+
ae.eval()
|
419 |
+
|
420 |
+
# LoRA
|
421 |
+
multiplier = 1.0
|
422 |
+
weights_sd = load_file(LORA_WEIGHTS_PATH)
|
423 |
+
lora_model, _ = lora_flux.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd,
|
424 |
+
True)
|
425 |
+
|
426 |
+
lora_model.apply_to([clip_l, t5xxl], model)
|
427 |
+
info = lora_model.load_state_dict(weights_sd, strict=True)
|
428 |
+
logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
|
429 |
+
lora_model.eval()
|
430 |
+
lora_model.to("cuda")
|
431 |
+
|
432 |
+
# Set text encoders
|
433 |
+
text_encoder = [clip_l, t5xxl]
|
434 |
+
|
435 |
+
sample(args, accelerator, vae=ae, text_encoder=text_encoder, flux=model, output_dir=OUTPUT_DIR,
|
436 |
+
sample_images=sample_images, sample_prompts=sample_prompts)
|
437 |
+
|
438 |
+
|
439 |
+
if __name__ == "__main__":
|
440 |
+
args = setup_argparse()
|
441 |
+
|
442 |
+
main(args)
|
flux_minimal_inference.py
ADDED
@@ -0,0 +1,576 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Minimum Inference Code for FLUX
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import datetime
|
5 |
+
import math
|
6 |
+
import os
|
7 |
+
import random
|
8 |
+
from typing import Callable, List, Optional
|
9 |
+
import einops
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from tqdm import tqdm
|
14 |
+
from PIL import Image
|
15 |
+
import accelerate
|
16 |
+
from transformers import CLIPTextModel
|
17 |
+
from safetensors.torch import load_file
|
18 |
+
|
19 |
+
from library import device_utils
|
20 |
+
from library.device_utils import init_ipex, get_preferred_device
|
21 |
+
from networks import oft_flux
|
22 |
+
|
23 |
+
init_ipex()
|
24 |
+
|
25 |
+
|
26 |
+
from library.utils import setup_logging, str_to_dtype
|
27 |
+
|
28 |
+
setup_logging()
|
29 |
+
import logging
|
30 |
+
|
31 |
+
logger = logging.getLogger(__name__)
|
32 |
+
|
33 |
+
import networks.lora_flux as lora_flux
|
34 |
+
from library import flux_models, flux_utils, sd3_utils, strategy_flux
|
35 |
+
|
36 |
+
|
37 |
+
def time_shift(mu: float, sigma: float, t: torch.Tensor):
|
38 |
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
39 |
+
|
40 |
+
|
41 |
+
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
|
42 |
+
m = (y2 - y1) / (x2 - x1)
|
43 |
+
b = y1 - m * x1
|
44 |
+
return lambda x: m * x + b
|
45 |
+
|
46 |
+
|
47 |
+
def get_schedule(
|
48 |
+
num_steps: int,
|
49 |
+
image_seq_len: int,
|
50 |
+
base_shift: float = 0.5,
|
51 |
+
max_shift: float = 1.15,
|
52 |
+
shift: bool = True,
|
53 |
+
) -> list[float]:
|
54 |
+
# extra step for zero
|
55 |
+
timesteps = torch.linspace(1, 0, num_steps + 1)
|
56 |
+
|
57 |
+
# shifting the schedule to favor high timesteps for higher signal images
|
58 |
+
if shift:
|
59 |
+
# eastimate mu based on linear estimation between two points
|
60 |
+
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
61 |
+
timesteps = time_shift(mu, 1.0, timesteps)
|
62 |
+
|
63 |
+
return timesteps.tolist()
|
64 |
+
|
65 |
+
|
66 |
+
def denoise(
|
67 |
+
model: flux_models.Flux,
|
68 |
+
img: torch.Tensor,
|
69 |
+
img_ids: torch.Tensor,
|
70 |
+
txt: torch.Tensor,
|
71 |
+
txt_ids: torch.Tensor,
|
72 |
+
vec: torch.Tensor,
|
73 |
+
timesteps: list[float],
|
74 |
+
guidance: float = 4.0,
|
75 |
+
t5_attn_mask: Optional[torch.Tensor] = None,
|
76 |
+
neg_txt: Optional[torch.Tensor] = None,
|
77 |
+
neg_vec: Optional[torch.Tensor] = None,
|
78 |
+
neg_t5_attn_mask: Optional[torch.Tensor] = None,
|
79 |
+
cfg_scale: Optional[float] = None,
|
80 |
+
):
|
81 |
+
# this is ignored for schnell
|
82 |
+
logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}")
|
83 |
+
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
84 |
+
|
85 |
+
# prepare classifier free guidance
|
86 |
+
if neg_txt is not None and neg_vec is not None:
|
87 |
+
b_img_ids = torch.cat([img_ids, img_ids], dim=0)
|
88 |
+
b_txt_ids = torch.cat([txt_ids, txt_ids], dim=0)
|
89 |
+
b_txt = torch.cat([neg_txt, txt], dim=0)
|
90 |
+
b_vec = torch.cat([neg_vec, vec], dim=0)
|
91 |
+
if t5_attn_mask is not None and neg_t5_attn_mask is not None:
|
92 |
+
b_t5_attn_mask = torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0)
|
93 |
+
else:
|
94 |
+
b_t5_attn_mask = None
|
95 |
+
else:
|
96 |
+
b_img_ids = img_ids
|
97 |
+
b_txt_ids = txt_ids
|
98 |
+
b_txt = txt
|
99 |
+
b_vec = vec
|
100 |
+
b_t5_attn_mask = t5_attn_mask
|
101 |
+
|
102 |
+
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
103 |
+
t_vec = torch.full((b_img_ids.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
104 |
+
|
105 |
+
# classifier free guidance
|
106 |
+
if neg_txt is not None and neg_vec is not None:
|
107 |
+
b_img = torch.cat([img, img], dim=0)
|
108 |
+
else:
|
109 |
+
b_img = img
|
110 |
+
|
111 |
+
pred = model(
|
112 |
+
img=b_img,
|
113 |
+
img_ids=b_img_ids,
|
114 |
+
txt=b_txt,
|
115 |
+
txt_ids=b_txt_ids,
|
116 |
+
y=b_vec,
|
117 |
+
timesteps=t_vec,
|
118 |
+
guidance=guidance_vec,
|
119 |
+
txt_attention_mask=b_t5_attn_mask,
|
120 |
+
)
|
121 |
+
|
122 |
+
# classifier free guidance
|
123 |
+
if neg_txt is not None and neg_vec is not None:
|
124 |
+
pred_uncond, pred = torch.chunk(pred, 2, dim=0)
|
125 |
+
pred = pred_uncond + cfg_scale * (pred - pred_uncond)
|
126 |
+
|
127 |
+
img = img + (t_prev - t_curr) * pred
|
128 |
+
|
129 |
+
return img
|
130 |
+
|
131 |
+
|
132 |
+
def do_sample(
|
133 |
+
accelerator: Optional[accelerate.Accelerator],
|
134 |
+
model: flux_models.Flux,
|
135 |
+
img: torch.Tensor,
|
136 |
+
img_ids: torch.Tensor,
|
137 |
+
l_pooled: torch.Tensor,
|
138 |
+
t5_out: torch.Tensor,
|
139 |
+
txt_ids: torch.Tensor,
|
140 |
+
num_steps: int,
|
141 |
+
guidance: float,
|
142 |
+
t5_attn_mask: Optional[torch.Tensor],
|
143 |
+
is_schnell: bool,
|
144 |
+
device: torch.device,
|
145 |
+
flux_dtype: torch.dtype,
|
146 |
+
neg_l_pooled: Optional[torch.Tensor] = None,
|
147 |
+
neg_t5_out: Optional[torch.Tensor] = None,
|
148 |
+
neg_t5_attn_mask: Optional[torch.Tensor] = None,
|
149 |
+
cfg_scale: Optional[float] = None,
|
150 |
+
):
|
151 |
+
logger.info(f"num_steps: {num_steps}")
|
152 |
+
timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell)
|
153 |
+
|
154 |
+
# denoise initial noise
|
155 |
+
if accelerator:
|
156 |
+
with accelerator.autocast(), torch.no_grad():
|
157 |
+
x = denoise(
|
158 |
+
model,
|
159 |
+
img,
|
160 |
+
img_ids,
|
161 |
+
t5_out,
|
162 |
+
txt_ids,
|
163 |
+
l_pooled,
|
164 |
+
timesteps,
|
165 |
+
guidance,
|
166 |
+
t5_attn_mask,
|
167 |
+
neg_t5_out,
|
168 |
+
neg_l_pooled,
|
169 |
+
neg_t5_attn_mask,
|
170 |
+
cfg_scale,
|
171 |
+
)
|
172 |
+
else:
|
173 |
+
with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad():
|
174 |
+
x = denoise(
|
175 |
+
model,
|
176 |
+
img,
|
177 |
+
img_ids,
|
178 |
+
t5_out,
|
179 |
+
txt_ids,
|
180 |
+
l_pooled,
|
181 |
+
timesteps,
|
182 |
+
guidance,
|
183 |
+
t5_attn_mask,
|
184 |
+
neg_t5_out,
|
185 |
+
neg_l_pooled,
|
186 |
+
neg_t5_attn_mask,
|
187 |
+
cfg_scale,
|
188 |
+
)
|
189 |
+
|
190 |
+
return x
|
191 |
+
|
192 |
+
|
193 |
+
def generate_image(
|
194 |
+
model,
|
195 |
+
clip_l: CLIPTextModel,
|
196 |
+
t5xxl,
|
197 |
+
ae,
|
198 |
+
prompt: str,
|
199 |
+
seed: Optional[int],
|
200 |
+
image_width: int,
|
201 |
+
image_height: int,
|
202 |
+
steps: Optional[int],
|
203 |
+
guidance: float,
|
204 |
+
negative_prompt: Optional[str],
|
205 |
+
cfg_scale: float,
|
206 |
+
):
|
207 |
+
seed = seed if seed is not None else random.randint(0, 2**32 - 1)
|
208 |
+
logger.info(f"Seed: {seed}")
|
209 |
+
|
210 |
+
# make first noise with packed shape
|
211 |
+
# original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2
|
212 |
+
packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16)
|
213 |
+
noise_dtype = torch.float32 if is_fp8(dtype) else dtype
|
214 |
+
noise = torch.randn(
|
215 |
+
1,
|
216 |
+
packed_latent_height * packed_latent_width,
|
217 |
+
16 * 2 * 2,
|
218 |
+
device=device,
|
219 |
+
dtype=noise_dtype,
|
220 |
+
generator=torch.Generator(device=device).manual_seed(seed),
|
221 |
+
)
|
222 |
+
|
223 |
+
# prepare img and img ids
|
224 |
+
|
225 |
+
# this is needed only for img2img
|
226 |
+
# img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
227 |
+
# if img.shape[0] == 1 and bs > 1:
|
228 |
+
# img = repeat(img, "1 ... -> bs ...", bs=bs)
|
229 |
+
|
230 |
+
# txt2img only needs img_ids
|
231 |
+
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width)
|
232 |
+
|
233 |
+
# prepare fp8 models
|
234 |
+
if is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared):
|
235 |
+
logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}")
|
236 |
+
clip_l.to(clip_l_dtype) # fp8
|
237 |
+
clip_l.text_model.embeddings.to(dtype=torch.bfloat16)
|
238 |
+
clip_l.fp8_prepared = True
|
239 |
+
|
240 |
+
if is_fp8(t5xxl_dtype) and (not hasattr(t5xxl, "fp8_prepared") or not t5xxl.fp8_prepared):
|
241 |
+
logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}")
|
242 |
+
|
243 |
+
def prepare_fp8(text_encoder, target_dtype):
|
244 |
+
def forward_hook(module):
|
245 |
+
def forward(hidden_states):
|
246 |
+
hidden_gelu = module.act(module.wi_0(hidden_states))
|
247 |
+
hidden_linear = module.wi_1(hidden_states)
|
248 |
+
hidden_states = hidden_gelu * hidden_linear
|
249 |
+
hidden_states = module.dropout(hidden_states)
|
250 |
+
|
251 |
+
hidden_states = module.wo(hidden_states)
|
252 |
+
return hidden_states
|
253 |
+
|
254 |
+
return forward
|
255 |
+
|
256 |
+
for module in text_encoder.modules():
|
257 |
+
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
|
258 |
+
# print("set", module.__class__.__name__, "to", target_dtype)
|
259 |
+
module.to(target_dtype)
|
260 |
+
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
|
261 |
+
# print("set", module.__class__.__name__, "hooks")
|
262 |
+
module.forward = forward_hook(module)
|
263 |
+
|
264 |
+
t5xxl.to(t5xxl_dtype)
|
265 |
+
prepare_fp8(t5xxl.encoder, torch.bfloat16)
|
266 |
+
t5xxl.fp8_prepared = True
|
267 |
+
|
268 |
+
# prepare embeddings
|
269 |
+
logger.info("Encoding prompts...")
|
270 |
+
clip_l = clip_l.to(device)
|
271 |
+
t5xxl = t5xxl.to(device)
|
272 |
+
|
273 |
+
def encode(prpt: str):
|
274 |
+
tokens_and_masks = tokenize_strategy.tokenize(prpt)
|
275 |
+
with torch.no_grad():
|
276 |
+
if is_fp8(clip_l_dtype):
|
277 |
+
with accelerator.autocast():
|
278 |
+
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
|
279 |
+
else:
|
280 |
+
with torch.autocast(device_type=device.type, dtype=clip_l_dtype):
|
281 |
+
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
|
282 |
+
|
283 |
+
if is_fp8(t5xxl_dtype):
|
284 |
+
with accelerator.autocast():
|
285 |
+
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
|
286 |
+
tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
|
287 |
+
)
|
288 |
+
else:
|
289 |
+
with torch.autocast(device_type=device.type, dtype=t5xxl_dtype):
|
290 |
+
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
|
291 |
+
tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
|
292 |
+
)
|
293 |
+
return l_pooled, t5_out, txt_ids, t5_attn_mask
|
294 |
+
|
295 |
+
l_pooled, t5_out, txt_ids, t5_attn_mask = encode(prompt)
|
296 |
+
if negative_prompt:
|
297 |
+
neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode(negative_prompt)
|
298 |
+
else:
|
299 |
+
neg_l_pooled, neg_t5_out, neg_t5_attn_mask = None, None, None
|
300 |
+
|
301 |
+
# NaN check
|
302 |
+
if torch.isnan(l_pooled).any():
|
303 |
+
raise ValueError("NaN in l_pooled")
|
304 |
+
if torch.isnan(t5_out).any():
|
305 |
+
raise ValueError("NaN in t5_out")
|
306 |
+
|
307 |
+
if args.offload:
|
308 |
+
clip_l = clip_l.cpu()
|
309 |
+
t5xxl = t5xxl.cpu()
|
310 |
+
# del clip_l, t5xxl
|
311 |
+
device_utils.clean_memory()
|
312 |
+
|
313 |
+
# generate image
|
314 |
+
logger.info("Generating image...")
|
315 |
+
model = model.to(device)
|
316 |
+
if steps is None:
|
317 |
+
steps = 4 if is_schnell else 50
|
318 |
+
|
319 |
+
img_ids = img_ids.to(device)
|
320 |
+
t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None
|
321 |
+
|
322 |
+
x = do_sample(
|
323 |
+
accelerator,
|
324 |
+
model,
|
325 |
+
noise,
|
326 |
+
img_ids,
|
327 |
+
l_pooled,
|
328 |
+
t5_out,
|
329 |
+
txt_ids,
|
330 |
+
steps,
|
331 |
+
guidance,
|
332 |
+
t5_attn_mask,
|
333 |
+
is_schnell,
|
334 |
+
device,
|
335 |
+
flux_dtype,
|
336 |
+
neg_l_pooled,
|
337 |
+
neg_t5_out,
|
338 |
+
neg_t5_attn_mask,
|
339 |
+
cfg_scale,
|
340 |
+
)
|
341 |
+
if args.offload:
|
342 |
+
model = model.cpu()
|
343 |
+
# del model
|
344 |
+
device_utils.clean_memory()
|
345 |
+
|
346 |
+
# unpack
|
347 |
+
x = x.float()
|
348 |
+
x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
|
349 |
+
|
350 |
+
# decode
|
351 |
+
logger.info("Decoding image...")
|
352 |
+
ae = ae.to(device)
|
353 |
+
with torch.no_grad():
|
354 |
+
if is_fp8(ae_dtype):
|
355 |
+
with accelerator.autocast():
|
356 |
+
x = ae.decode(x)
|
357 |
+
else:
|
358 |
+
with torch.autocast(device_type=device.type, dtype=ae_dtype):
|
359 |
+
x = ae.decode(x)
|
360 |
+
if args.offload:
|
361 |
+
ae = ae.cpu()
|
362 |
+
|
363 |
+
x = x.clamp(-1, 1)
|
364 |
+
x = x.permute(0, 2, 3, 1)
|
365 |
+
img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
|
366 |
+
|
367 |
+
# save image
|
368 |
+
output_dir = args.output_dir
|
369 |
+
os.makedirs(output_dir, exist_ok=True)
|
370 |
+
output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
|
371 |
+
img.save(output_path)
|
372 |
+
|
373 |
+
logger.info(f"Saved image to {output_path}")
|
374 |
+
|
375 |
+
|
376 |
+
if __name__ == "__main__":
|
377 |
+
target_height = 768 # 1024
|
378 |
+
target_width = 1360 # 1024
|
379 |
+
|
380 |
+
# steps = 50 # 28 # 50
|
381 |
+
# guidance_scale = 5
|
382 |
+
# seed = 1 # None # 1
|
383 |
+
|
384 |
+
device = get_preferred_device()
|
385 |
+
|
386 |
+
parser = argparse.ArgumentParser()
|
387 |
+
parser.add_argument("--ckpt_path", type=str, required=True)
|
388 |
+
parser.add_argument("--clip_l", type=str, required=False)
|
389 |
+
parser.add_argument("--t5xxl", type=str, required=False)
|
390 |
+
parser.add_argument("--ae", type=str, required=False)
|
391 |
+
parser.add_argument("--apply_t5_attn_mask", action="store_true")
|
392 |
+
parser.add_argument("--prompt", type=str, default="A photo of a cat")
|
393 |
+
parser.add_argument("--output_dir", type=str, default=".")
|
394 |
+
parser.add_argument("--dtype", type=str, default="bfloat16", help="base dtype")
|
395 |
+
parser.add_argument("--clip_l_dtype", type=str, default=None, help="dtype for clip_l")
|
396 |
+
parser.add_argument("--ae_dtype", type=str, default=None, help="dtype for ae")
|
397 |
+
parser.add_argument("--t5xxl_dtype", type=str, default=None, help="dtype for t5xxl")
|
398 |
+
parser.add_argument("--flux_dtype", type=str, default=None, help="dtype for flux")
|
399 |
+
parser.add_argument("--seed", type=int, default=None)
|
400 |
+
parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev")
|
401 |
+
parser.add_argument("--guidance", type=float, default=3.5)
|
402 |
+
parser.add_argument("--negative_prompt", type=str, default=None)
|
403 |
+
parser.add_argument("--cfg_scale", type=float, default=1.0)
|
404 |
+
parser.add_argument("--offload", action="store_true", help="Offload to CPU")
|
405 |
+
parser.add_argument(
|
406 |
+
"--lora_weights",
|
407 |
+
type=str,
|
408 |
+
nargs="*",
|
409 |
+
default=[],
|
410 |
+
help="LoRA weights, only supports networks.lora_flux and lora_oft, each argument is a `path;multiplier` (semi-colon separated)",
|
411 |
+
)
|
412 |
+
parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
|
413 |
+
parser.add_argument("--width", type=int, default=target_width)
|
414 |
+
parser.add_argument("--height", type=int, default=target_height)
|
415 |
+
parser.add_argument("--interactive", action="store_true")
|
416 |
+
args = parser.parse_args()
|
417 |
+
|
418 |
+
seed = args.seed
|
419 |
+
steps = args.steps
|
420 |
+
guidance_scale = args.guidance
|
421 |
+
|
422 |
+
def is_fp8(dt):
|
423 |
+
return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]
|
424 |
+
|
425 |
+
dtype = str_to_dtype(args.dtype)
|
426 |
+
clip_l_dtype = str_to_dtype(args.clip_l_dtype, dtype)
|
427 |
+
t5xxl_dtype = str_to_dtype(args.t5xxl_dtype, dtype)
|
428 |
+
ae_dtype = str_to_dtype(args.ae_dtype, dtype)
|
429 |
+
flux_dtype = str_to_dtype(args.flux_dtype, dtype)
|
430 |
+
|
431 |
+
logger.info(f"Dtypes for clip_l, t5xxl, ae, flux: {clip_l_dtype}, {t5xxl_dtype}, {ae_dtype}, {flux_dtype}")
|
432 |
+
|
433 |
+
loading_device = "cpu" if args.offload else device
|
434 |
+
|
435 |
+
use_fp8 = [is_fp8(d) for d in [dtype, clip_l_dtype, t5xxl_dtype, ae_dtype, flux_dtype]]
|
436 |
+
if any(use_fp8):
|
437 |
+
accelerator = accelerate.Accelerator(mixed_precision="bf16")
|
438 |
+
else:
|
439 |
+
accelerator = None
|
440 |
+
|
441 |
+
# load clip_l
|
442 |
+
logger.info(f"Loading clip_l from {args.clip_l}...")
|
443 |
+
clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device)
|
444 |
+
clip_l.eval()
|
445 |
+
|
446 |
+
logger.info(f"Loading t5xxl from {args.t5xxl}...")
|
447 |
+
t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device)
|
448 |
+
t5xxl.eval()
|
449 |
+
|
450 |
+
# if is_fp8(clip_l_dtype):
|
451 |
+
# clip_l = accelerator.prepare(clip_l)
|
452 |
+
# if is_fp8(t5xxl_dtype):
|
453 |
+
# t5xxl = accelerator.prepare(t5xxl)
|
454 |
+
|
455 |
+
# DiT
|
456 |
+
is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device)
|
457 |
+
model.eval()
|
458 |
+
logger.info(f"Casting model to {flux_dtype}")
|
459 |
+
model.to(flux_dtype) # make sure model is dtype
|
460 |
+
# if is_fp8(flux_dtype):
|
461 |
+
# model = accelerator.prepare(model)
|
462 |
+
# if args.offload:
|
463 |
+
# model = model.to("cpu")
|
464 |
+
|
465 |
+
t5xxl_max_length = 256 if is_schnell else 512
|
466 |
+
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
|
467 |
+
encoding_strategy = strategy_flux.FluxTextEncodingStrategy()
|
468 |
+
|
469 |
+
# AE
|
470 |
+
ae = flux_utils.load_ae(args.ae, ae_dtype, loading_device)
|
471 |
+
ae.eval()
|
472 |
+
# if is_fp8(ae_dtype):
|
473 |
+
# ae = accelerator.prepare(ae)
|
474 |
+
|
475 |
+
# LoRA
|
476 |
+
lora_models: List[lora_flux.LoRANetwork] = []
|
477 |
+
for weights_file in args.lora_weights:
|
478 |
+
if ";" in weights_file:
|
479 |
+
weights_file, multiplier = weights_file.split(";")
|
480 |
+
multiplier = float(multiplier)
|
481 |
+
else:
|
482 |
+
multiplier = 1.0
|
483 |
+
|
484 |
+
weights_sd = load_file(weights_file)
|
485 |
+
is_lora = is_oft = False
|
486 |
+
for key in weights_sd.keys():
|
487 |
+
if key.startswith("lora"):
|
488 |
+
is_lora = True
|
489 |
+
if key.startswith("oft"):
|
490 |
+
is_oft = True
|
491 |
+
if is_lora or is_oft:
|
492 |
+
break
|
493 |
+
|
494 |
+
module = lora_flux if is_lora else oft_flux
|
495 |
+
lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)
|
496 |
+
|
497 |
+
if args.merge_lora_weights:
|
498 |
+
lora_model.merge_to([clip_l, t5xxl], model, weights_sd)
|
499 |
+
else:
|
500 |
+
lora_model.apply_to([clip_l, t5xxl], model)
|
501 |
+
info = lora_model.load_state_dict(weights_sd, strict=True)
|
502 |
+
logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
|
503 |
+
lora_model.eval()
|
504 |
+
lora_model.to(device)
|
505 |
+
|
506 |
+
lora_models.append(lora_model)
|
507 |
+
|
508 |
+
if not args.interactive:
|
509 |
+
generate_image(
|
510 |
+
model,
|
511 |
+
clip_l,
|
512 |
+
t5xxl,
|
513 |
+
ae,
|
514 |
+
args.prompt,
|
515 |
+
args.seed,
|
516 |
+
args.width,
|
517 |
+
args.height,
|
518 |
+
args.steps,
|
519 |
+
args.guidance,
|
520 |
+
args.negative_prompt,
|
521 |
+
args.cfg_scale,
|
522 |
+
)
|
523 |
+
else:
|
524 |
+
# loop for interactive
|
525 |
+
width = target_width
|
526 |
+
height = target_height
|
527 |
+
steps = None
|
528 |
+
guidance = args.guidance
|
529 |
+
cfg_scale = args.cfg_scale
|
530 |
+
|
531 |
+
while True:
|
532 |
+
print(
|
533 |
+
"Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed> --g <guidance> --m <multipliers for LoRA>"
|
534 |
+
" --n <negative prompt>, `-` for empty negative prompt --c <cfg_scale>"
|
535 |
+
)
|
536 |
+
prompt = input()
|
537 |
+
if prompt == "":
|
538 |
+
break
|
539 |
+
|
540 |
+
# parse options
|
541 |
+
options = prompt.split("--")
|
542 |
+
prompt = options[0].strip()
|
543 |
+
seed = None
|
544 |
+
negative_prompt = None
|
545 |
+
for opt in options[1:]:
|
546 |
+
try:
|
547 |
+
opt = opt.strip()
|
548 |
+
if opt.startswith("w"):
|
549 |
+
width = int(opt[1:].strip())
|
550 |
+
elif opt.startswith("h"):
|
551 |
+
height = int(opt[1:].strip())
|
552 |
+
elif opt.startswith("s"):
|
553 |
+
steps = int(opt[1:].strip())
|
554 |
+
elif opt.startswith("d"):
|
555 |
+
seed = int(opt[1:].strip())
|
556 |
+
elif opt.startswith("g"):
|
557 |
+
guidance = float(opt[1:].strip())
|
558 |
+
elif opt.startswith("m"):
|
559 |
+
mutipliers = opt[1:].strip().split(",")
|
560 |
+
if len(mutipliers) != len(lora_models):
|
561 |
+
logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
|
562 |
+
continue
|
563 |
+
for i, lora_model in enumerate(lora_models):
|
564 |
+
lora_model.set_multiplier(float(mutipliers[i]))
|
565 |
+
elif opt.startswith("n"):
|
566 |
+
negative_prompt = opt[1:].strip()
|
567 |
+
if negative_prompt == "-":
|
568 |
+
negative_prompt = ""
|
569 |
+
elif opt.startswith("c"):
|
570 |
+
cfg_scale = float(opt[1:].strip())
|
571 |
+
except ValueError as e:
|
572 |
+
logger.error(f"Invalid option: {opt}, {e}")
|
573 |
+
|
574 |
+
generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance, negative_prompt, cfg_scale)
|
575 |
+
|
576 |
+
logger.info("Done!")
|
flux_minimal_inference_asylora.py
ADDED
@@ -0,0 +1,583 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# Minimum Inference Code for FLUX
|
3 |
+
|
4 |
+
import argparse
|
5 |
+
import datetime
|
6 |
+
import math
|
7 |
+
import os
|
8 |
+
import random
|
9 |
+
from typing import Callable, List, Optional
|
10 |
+
import einops
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
import torch
|
14 |
+
from tqdm import tqdm
|
15 |
+
from PIL import Image
|
16 |
+
import accelerate
|
17 |
+
from transformers import CLIPTextModel
|
18 |
+
from safetensors.torch import load_file
|
19 |
+
|
20 |
+
from library import device_utils
|
21 |
+
from library.device_utils import init_ipex, get_preferred_device
|
22 |
+
from networks import oft_flux
|
23 |
+
|
24 |
+
init_ipex()
|
25 |
+
|
26 |
+
|
27 |
+
from library.utils import setup_logging, str_to_dtype
|
28 |
+
|
29 |
+
setup_logging()
|
30 |
+
import logging
|
31 |
+
|
32 |
+
logger = logging.getLogger(__name__)
|
33 |
+
|
34 |
+
import networks.asylora_flux as lora_flux
|
35 |
+
from library import flux_models, flux_utils, sd3_utils, strategy_flux
|
36 |
+
|
37 |
+
|
38 |
+
def time_shift(mu: float, sigma: float, t: torch.Tensor):
|
39 |
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
40 |
+
|
41 |
+
|
42 |
+
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
|
43 |
+
m = (y2 - y1) / (x2 - x1)
|
44 |
+
b = y1 - m * x1
|
45 |
+
return lambda x: m * x + b
|
46 |
+
|
47 |
+
|
48 |
+
def get_schedule(
|
49 |
+
num_steps: int,
|
50 |
+
image_seq_len: int,
|
51 |
+
base_shift: float = 0.5,
|
52 |
+
max_shift: float = 1.15,
|
53 |
+
shift: bool = True,
|
54 |
+
) -> list[float]:
|
55 |
+
# extra step for zero
|
56 |
+
timesteps = torch.linspace(1, 0, num_steps + 1)
|
57 |
+
|
58 |
+
# shifting the schedule to favor high timesteps for higher signal images
|
59 |
+
if shift:
|
60 |
+
# eastimate mu based on linear estimation between two points
|
61 |
+
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
62 |
+
timesteps = time_shift(mu, 1.0, timesteps)
|
63 |
+
|
64 |
+
return timesteps.tolist()
|
65 |
+
|
66 |
+
|
67 |
+
def denoise(
|
68 |
+
model: flux_models.Flux,
|
69 |
+
img: torch.Tensor,
|
70 |
+
img_ids: torch.Tensor,
|
71 |
+
txt: torch.Tensor,
|
72 |
+
txt_ids: torch.Tensor,
|
73 |
+
vec: torch.Tensor,
|
74 |
+
timesteps: list[float],
|
75 |
+
guidance: float = 4.0,
|
76 |
+
t5_attn_mask: Optional[torch.Tensor] = None,
|
77 |
+
neg_txt: Optional[torch.Tensor] = None,
|
78 |
+
neg_vec: Optional[torch.Tensor] = None,
|
79 |
+
neg_t5_attn_mask: Optional[torch.Tensor] = None,
|
80 |
+
cfg_scale: Optional[float] = None,
|
81 |
+
):
|
82 |
+
# this is ignored for schnell
|
83 |
+
logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}")
|
84 |
+
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
85 |
+
|
86 |
+
# prepare classifier free guidance
|
87 |
+
if neg_txt is not None and neg_vec is not None:
|
88 |
+
b_img_ids = torch.cat([img_ids, img_ids], dim=0)
|
89 |
+
b_txt_ids = torch.cat([txt_ids, txt_ids], dim=0)
|
90 |
+
b_txt = torch.cat([neg_txt, txt], dim=0)
|
91 |
+
b_vec = torch.cat([neg_vec, vec], dim=0)
|
92 |
+
if t5_attn_mask is not None and neg_t5_attn_mask is not None:
|
93 |
+
b_t5_attn_mask = torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0)
|
94 |
+
else:
|
95 |
+
b_t5_attn_mask = None
|
96 |
+
else:
|
97 |
+
b_img_ids = img_ids
|
98 |
+
b_txt_ids = txt_ids
|
99 |
+
b_txt = txt
|
100 |
+
b_vec = vec
|
101 |
+
b_t5_attn_mask = t5_attn_mask
|
102 |
+
|
103 |
+
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
104 |
+
t_vec = torch.full((b_img_ids.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
105 |
+
|
106 |
+
# classifier free guidance
|
107 |
+
if neg_txt is not None and neg_vec is not None:
|
108 |
+
b_img = torch.cat([img, img], dim=0)
|
109 |
+
else:
|
110 |
+
b_img = img
|
111 |
+
|
112 |
+
pred = model(
|
113 |
+
img=b_img,
|
114 |
+
img_ids=b_img_ids,
|
115 |
+
txt=b_txt,
|
116 |
+
txt_ids=b_txt_ids,
|
117 |
+
y=b_vec,
|
118 |
+
timesteps=t_vec,
|
119 |
+
guidance=guidance_vec,
|
120 |
+
txt_attention_mask=b_t5_attn_mask,
|
121 |
+
)
|
122 |
+
|
123 |
+
# classifier free guidance
|
124 |
+
if neg_txt is not None and neg_vec is not None:
|
125 |
+
pred_uncond, pred = torch.chunk(pred, 2, dim=0)
|
126 |
+
pred = pred_uncond + cfg_scale * (pred - pred_uncond)
|
127 |
+
|
128 |
+
img = img + (t_prev - t_curr) * pred
|
129 |
+
|
130 |
+
return img
|
131 |
+
|
132 |
+
|
133 |
+
def do_sample(
|
134 |
+
accelerator: Optional[accelerate.Accelerator],
|
135 |
+
model: flux_models.Flux,
|
136 |
+
img: torch.Tensor,
|
137 |
+
img_ids: torch.Tensor,
|
138 |
+
l_pooled: torch.Tensor,
|
139 |
+
t5_out: torch.Tensor,
|
140 |
+
txt_ids: torch.Tensor,
|
141 |
+
num_steps: int,
|
142 |
+
guidance: float,
|
143 |
+
t5_attn_mask: Optional[torch.Tensor],
|
144 |
+
is_schnell: bool,
|
145 |
+
device: torch.device,
|
146 |
+
flux_dtype: torch.dtype,
|
147 |
+
neg_l_pooled: Optional[torch.Tensor] = None,
|
148 |
+
neg_t5_out: Optional[torch.Tensor] = None,
|
149 |
+
neg_t5_attn_mask: Optional[torch.Tensor] = None,
|
150 |
+
cfg_scale: Optional[float] = None,
|
151 |
+
):
|
152 |
+
logger.info(f"num_steps: {num_steps}")
|
153 |
+
timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell)
|
154 |
+
|
155 |
+
# denoise initial noise
|
156 |
+
if accelerator:
|
157 |
+
with accelerator.autocast(), torch.no_grad():
|
158 |
+
x = denoise(
|
159 |
+
model,
|
160 |
+
img,
|
161 |
+
img_ids,
|
162 |
+
t5_out,
|
163 |
+
txt_ids,
|
164 |
+
l_pooled,
|
165 |
+
timesteps,
|
166 |
+
guidance,
|
167 |
+
t5_attn_mask,
|
168 |
+
neg_t5_out,
|
169 |
+
neg_l_pooled,
|
170 |
+
neg_t5_attn_mask,
|
171 |
+
cfg_scale,
|
172 |
+
)
|
173 |
+
else:
|
174 |
+
with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad():
|
175 |
+
x = denoise(
|
176 |
+
model,
|
177 |
+
img,
|
178 |
+
img_ids,
|
179 |
+
t5_out,
|
180 |
+
txt_ids,
|
181 |
+
l_pooled,
|
182 |
+
timesteps,
|
183 |
+
guidance,
|
184 |
+
t5_attn_mask,
|
185 |
+
neg_t5_out,
|
186 |
+
neg_l_pooled,
|
187 |
+
neg_t5_attn_mask,
|
188 |
+
cfg_scale,
|
189 |
+
)
|
190 |
+
|
191 |
+
return x
|
192 |
+
|
193 |
+
|
194 |
+
def generate_image(
|
195 |
+
model,
|
196 |
+
clip_l: CLIPTextModel,
|
197 |
+
t5xxl,
|
198 |
+
ae,
|
199 |
+
prompt: str,
|
200 |
+
seed: Optional[int],
|
201 |
+
image_width: int,
|
202 |
+
image_height: int,
|
203 |
+
steps: Optional[int],
|
204 |
+
guidance: float,
|
205 |
+
negative_prompt: Optional[str],
|
206 |
+
cfg_scale: float,
|
207 |
+
):
|
208 |
+
seed = seed if seed is not None else random.randint(0, 2**32 - 1)
|
209 |
+
logger.info(f"Seed: {seed}")
|
210 |
+
|
211 |
+
# make first noise with packed shape
|
212 |
+
# original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2
|
213 |
+
packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16)
|
214 |
+
noise_dtype = torch.float32 if is_fp8(dtype) else dtype
|
215 |
+
noise = torch.randn(
|
216 |
+
1,
|
217 |
+
packed_latent_height * packed_latent_width,
|
218 |
+
16 * 2 * 2,
|
219 |
+
device=device,
|
220 |
+
dtype=noise_dtype,
|
221 |
+
generator=torch.Generator(device=device).manual_seed(seed),
|
222 |
+
)
|
223 |
+
|
224 |
+
# prepare img and img ids
|
225 |
+
|
226 |
+
# this is needed only for img2img
|
227 |
+
# img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
228 |
+
# if img.shape[0] == 1 and bs > 1:
|
229 |
+
# img = repeat(img, "1 ... -> bs ...", bs=bs)
|
230 |
+
|
231 |
+
# txt2img only needs img_ids
|
232 |
+
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width)
|
233 |
+
|
234 |
+
# prepare fp8 models
|
235 |
+
if is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared):
|
236 |
+
logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}")
|
237 |
+
clip_l.to(clip_l_dtype) # fp8
|
238 |
+
clip_l.text_model.embeddings.to(dtype=torch.bfloat16)
|
239 |
+
clip_l.fp8_prepared = True
|
240 |
+
|
241 |
+
if is_fp8(t5xxl_dtype) and (not hasattr(t5xxl, "fp8_prepared") or not t5xxl.fp8_prepared):
|
242 |
+
logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}")
|
243 |
+
|
244 |
+
def prepare_fp8(text_encoder, target_dtype):
|
245 |
+
def forward_hook(module):
|
246 |
+
def forward(hidden_states):
|
247 |
+
hidden_gelu = module.act(module.wi_0(hidden_states))
|
248 |
+
hidden_linear = module.wi_1(hidden_states)
|
249 |
+
hidden_states = hidden_gelu * hidden_linear
|
250 |
+
hidden_states = module.dropout(hidden_states)
|
251 |
+
|
252 |
+
hidden_states = module.wo(hidden_states)
|
253 |
+
return hidden_states
|
254 |
+
|
255 |
+
return forward
|
256 |
+
|
257 |
+
for module in text_encoder.modules():
|
258 |
+
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
|
259 |
+
# print("set", module.__class__.__name__, "to", target_dtype)
|
260 |
+
module.to(target_dtype)
|
261 |
+
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
|
262 |
+
# print("set", module.__class__.__name__, "hooks")
|
263 |
+
module.forward = forward_hook(module)
|
264 |
+
|
265 |
+
t5xxl.to(t5xxl_dtype)
|
266 |
+
prepare_fp8(t5xxl.encoder, torch.bfloat16)
|
267 |
+
t5xxl.fp8_prepared = True
|
268 |
+
|
269 |
+
# prepare embeddings
|
270 |
+
logger.info("Encoding prompts...")
|
271 |
+
clip_l = clip_l.to(device)
|
272 |
+
t5xxl = t5xxl.to(device)
|
273 |
+
|
274 |
+
def encode(prpt: str):
|
275 |
+
tokens_and_masks = tokenize_strategy.tokenize(prpt)
|
276 |
+
with torch.no_grad():
|
277 |
+
if is_fp8(clip_l_dtype):
|
278 |
+
with accelerator.autocast():
|
279 |
+
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
|
280 |
+
else:
|
281 |
+
with torch.autocast(device_type=device.type, dtype=clip_l_dtype):
|
282 |
+
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
|
283 |
+
|
284 |
+
if is_fp8(t5xxl_dtype):
|
285 |
+
with accelerator.autocast():
|
286 |
+
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
|
287 |
+
tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
|
288 |
+
)
|
289 |
+
else:
|
290 |
+
with torch.autocast(device_type=device.type, dtype=t5xxl_dtype):
|
291 |
+
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
|
292 |
+
tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
|
293 |
+
)
|
294 |
+
return l_pooled, t5_out, txt_ids, t5_attn_mask
|
295 |
+
|
296 |
+
l_pooled, t5_out, txt_ids, t5_attn_mask = encode(prompt)
|
297 |
+
if negative_prompt:
|
298 |
+
neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode(negative_prompt)
|
299 |
+
else:
|
300 |
+
neg_l_pooled, neg_t5_out, neg_t5_attn_mask = None, None, None
|
301 |
+
|
302 |
+
# NaN check
|
303 |
+
if torch.isnan(l_pooled).any():
|
304 |
+
raise ValueError("NaN in l_pooled")
|
305 |
+
if torch.isnan(t5_out).any():
|
306 |
+
raise ValueError("NaN in t5_out")
|
307 |
+
|
308 |
+
if args.offload:
|
309 |
+
clip_l = clip_l.cpu()
|
310 |
+
t5xxl = t5xxl.cpu()
|
311 |
+
# del clip_l, t5xxl
|
312 |
+
device_utils.clean_memory()
|
313 |
+
|
314 |
+
# generate image
|
315 |
+
logger.info("Generating image...")
|
316 |
+
model = model.to(device)
|
317 |
+
if steps is None:
|
318 |
+
steps = 4 if is_schnell else 50
|
319 |
+
|
320 |
+
img_ids = img_ids.to(device)
|
321 |
+
t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None
|
322 |
+
|
323 |
+
x = do_sample(
|
324 |
+
accelerator,
|
325 |
+
model,
|
326 |
+
noise,
|
327 |
+
img_ids,
|
328 |
+
l_pooled,
|
329 |
+
t5_out,
|
330 |
+
txt_ids,
|
331 |
+
steps,
|
332 |
+
guidance,
|
333 |
+
t5_attn_mask,
|
334 |
+
is_schnell,
|
335 |
+
device,
|
336 |
+
flux_dtype,
|
337 |
+
neg_l_pooled,
|
338 |
+
neg_t5_out,
|
339 |
+
neg_t5_attn_mask,
|
340 |
+
cfg_scale,
|
341 |
+
)
|
342 |
+
if args.offload:
|
343 |
+
model = model.cpu()
|
344 |
+
# del model
|
345 |
+
device_utils.clean_memory()
|
346 |
+
|
347 |
+
# unpack
|
348 |
+
x = x.float()
|
349 |
+
x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
|
350 |
+
|
351 |
+
# decode
|
352 |
+
logger.info("Decoding image...")
|
353 |
+
ae = ae.to(device)
|
354 |
+
with torch.no_grad():
|
355 |
+
if is_fp8(ae_dtype):
|
356 |
+
with accelerator.autocast():
|
357 |
+
x = ae.decode(x)
|
358 |
+
else:
|
359 |
+
with torch.autocast(device_type=device.type, dtype=ae_dtype):
|
360 |
+
x = ae.decode(x)
|
361 |
+
if args.offload:
|
362 |
+
ae = ae.cpu()
|
363 |
+
|
364 |
+
x = x.clamp(-1, 1)
|
365 |
+
x = x.permute(0, 2, 3, 1)
|
366 |
+
img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
|
367 |
+
|
368 |
+
# save image
|
369 |
+
output_dir = args.output_dir
|
370 |
+
os.makedirs(output_dir, exist_ok=True)
|
371 |
+
output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
|
372 |
+
img.save(output_path)
|
373 |
+
|
374 |
+
logger.info(f"Saved image to {output_path}")
|
375 |
+
|
376 |
+
|
377 |
+
if __name__ == "__main__":
|
378 |
+
target_height = 768 # 1024
|
379 |
+
target_width = 1360 # 1024
|
380 |
+
|
381 |
+
# steps = 50 # 28 # 50
|
382 |
+
# guidance_scale = 5
|
383 |
+
# seed = 1 # None # 1
|
384 |
+
|
385 |
+
device = get_preferred_device()
|
386 |
+
|
387 |
+
parser = argparse.ArgumentParser()
|
388 |
+
parser.add_argument("--lora_ups_num", type=int, required=True)
|
389 |
+
parser.add_argument("--lora_up_cur", type=int, required=True)
|
390 |
+
parser.add_argument("--ckpt_path", type=str, required=True)
|
391 |
+
parser.add_argument("--clip_l", type=str, required=False)
|
392 |
+
parser.add_argument("--t5xxl", type=str, required=False)
|
393 |
+
parser.add_argument("--ae", type=str, required=False)
|
394 |
+
parser.add_argument("--apply_t5_attn_mask", action="store_true")
|
395 |
+
parser.add_argument("--prompt", type=str, default="A photo of a cat")
|
396 |
+
parser.add_argument("--output_dir", type=str, default=".")
|
397 |
+
parser.add_argument("--dtype", type=str, default="bfloat16", help="base dtype")
|
398 |
+
parser.add_argument("--clip_l_dtype", type=str, default=None, help="dtype for clip_l")
|
399 |
+
parser.add_argument("--ae_dtype", type=str, default=None, help="dtype for ae")
|
400 |
+
parser.add_argument("--t5xxl_dtype", type=str, default=None, help="dtype for t5xxl")
|
401 |
+
parser.add_argument("--flux_dtype", type=str, default=None, help="dtype for flux")
|
402 |
+
parser.add_argument("--seed", type=int, default=None)
|
403 |
+
parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev")
|
404 |
+
parser.add_argument("--guidance", type=float, default=3.5)
|
405 |
+
parser.add_argument("--negative_prompt", type=str, default=None)
|
406 |
+
parser.add_argument("--cfg_scale", type=float, default=1.0)
|
407 |
+
parser.add_argument("--offload", action="store_true", help="Offload to CPU")
|
408 |
+
parser.add_argument(
|
409 |
+
"--lora_weights",
|
410 |
+
type=str,
|
411 |
+
nargs="*",
|
412 |
+
default=[],
|
413 |
+
help="LoRA weights, only supports networks.lora_flux and lora_oft, each argument is a `path;multiplier` (semi-colon separated)",
|
414 |
+
)
|
415 |
+
parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
|
416 |
+
parser.add_argument("--width", type=int, default=target_width)
|
417 |
+
parser.add_argument("--height", type=int, default=target_height)
|
418 |
+
parser.add_argument("--interactive", action="store_true")
|
419 |
+
args = parser.parse_args()
|
420 |
+
|
421 |
+
seed = args.seed
|
422 |
+
steps = args.steps
|
423 |
+
guidance_scale = args.guidance
|
424 |
+
lora_ups_num = args.lora_ups_num
|
425 |
+
lora_up_cur = args.lora_up_cur
|
426 |
+
|
427 |
+
def is_fp8(dt):
|
428 |
+
return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]
|
429 |
+
|
430 |
+
dtype = str_to_dtype(args.dtype)
|
431 |
+
clip_l_dtype = str_to_dtype(args.clip_l_dtype, dtype)
|
432 |
+
t5xxl_dtype = str_to_dtype(args.t5xxl_dtype, dtype)
|
433 |
+
ae_dtype = str_to_dtype(args.ae_dtype, dtype)
|
434 |
+
flux_dtype = str_to_dtype(args.flux_dtype, dtype)
|
435 |
+
|
436 |
+
logger.info(f"Dtypes for clip_l, t5xxl, ae, flux: {clip_l_dtype}, {t5xxl_dtype}, {ae_dtype}, {flux_dtype}")
|
437 |
+
|
438 |
+
loading_device = "cpu" if args.offload else device
|
439 |
+
|
440 |
+
use_fp8 = [is_fp8(d) for d in [dtype, clip_l_dtype, t5xxl_dtype, ae_dtype, flux_dtype]]
|
441 |
+
if any(use_fp8):
|
442 |
+
accelerator = accelerate.Accelerator(mixed_precision="bf16")
|
443 |
+
else:
|
444 |
+
accelerator = None
|
445 |
+
|
446 |
+
# load clip_l
|
447 |
+
logger.info(f"Loading clip_l from {args.clip_l}...")
|
448 |
+
clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device)
|
449 |
+
clip_l.eval()
|
450 |
+
|
451 |
+
logger.info(f"Loading t5xxl from {args.t5xxl}...")
|
452 |
+
t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device)
|
453 |
+
t5xxl.eval()
|
454 |
+
|
455 |
+
# if is_fp8(clip_l_dtype):
|
456 |
+
# clip_l = accelerator.prepare(clip_l)
|
457 |
+
# if is_fp8(t5xxl_dtype):
|
458 |
+
# t5xxl = accelerator.prepare(t5xxl)
|
459 |
+
|
460 |
+
# DiT
|
461 |
+
is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device)
|
462 |
+
model.eval()
|
463 |
+
logger.info(f"Casting model to {flux_dtype}")
|
464 |
+
model.to(flux_dtype) # make sure model is dtype
|
465 |
+
# if is_fp8(flux_dtype):
|
466 |
+
# model = accelerator.prepare(model)
|
467 |
+
# if args.offload:
|
468 |
+
# model = model.to("cpu")
|
469 |
+
|
470 |
+
t5xxl_max_length = 256 if is_schnell else 512
|
471 |
+
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
|
472 |
+
encoding_strategy = strategy_flux.FluxTextEncodingStrategy()
|
473 |
+
|
474 |
+
# AE
|
475 |
+
ae = flux_utils.load_ae(args.ae, ae_dtype, loading_device)
|
476 |
+
ae.eval()
|
477 |
+
# if is_fp8(ae_dtype):
|
478 |
+
# ae = accelerator.prepare(ae)
|
479 |
+
|
480 |
+
# LoRA
|
481 |
+
lora_models: List[lora_flux.LoRANetwork] = []
|
482 |
+
for weights_file in args.lora_weights:
|
483 |
+
if ";" in weights_file:
|
484 |
+
weights_file, multiplier = weights_file.split(";")
|
485 |
+
multiplier = float(multiplier)
|
486 |
+
else:
|
487 |
+
multiplier = 1.0
|
488 |
+
|
489 |
+
weights_sd = load_file(weights_file)
|
490 |
+
is_lora = is_oft = False
|
491 |
+
for key in weights_sd.keys():
|
492 |
+
if key.startswith("lora"):
|
493 |
+
is_lora = True
|
494 |
+
if key.startswith("oft"):
|
495 |
+
is_oft = True
|
496 |
+
if is_lora or is_oft:
|
497 |
+
break
|
498 |
+
|
499 |
+
module = lora_flux if is_lora else oft_flux
|
500 |
+
lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True, lora_ups_num)
|
501 |
+
for sub_lora in lora_model.unet_loras:
|
502 |
+
sub_lora.set_lora_up_cur(lora_up_cur-1)
|
503 |
+
|
504 |
+
if args.merge_lora_weights:
|
505 |
+
lora_model.merge_to([clip_l, t5xxl], model, weights_sd)
|
506 |
+
else:
|
507 |
+
lora_model.apply_to([clip_l, t5xxl], model)
|
508 |
+
info = lora_model.load_state_dict(weights_sd, strict=True)
|
509 |
+
logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
|
510 |
+
lora_model.eval()
|
511 |
+
lora_model.to(device)
|
512 |
+
|
513 |
+
lora_models.append(lora_model)
|
514 |
+
|
515 |
+
if not args.interactive:
|
516 |
+
generate_image(
|
517 |
+
model,
|
518 |
+
clip_l,
|
519 |
+
t5xxl,
|
520 |
+
ae,
|
521 |
+
args.prompt,
|
522 |
+
args.seed,
|
523 |
+
args.width,
|
524 |
+
args.height,
|
525 |
+
args.steps,
|
526 |
+
args.guidance,
|
527 |
+
args.negative_prompt,
|
528 |
+
args.cfg_scale,
|
529 |
+
)
|
530 |
+
else:
|
531 |
+
# loop for interactive
|
532 |
+
width = target_width
|
533 |
+
height = target_height
|
534 |
+
steps = None
|
535 |
+
guidance = args.guidance
|
536 |
+
cfg_scale = args.cfg_scale
|
537 |
+
|
538 |
+
while True:
|
539 |
+
print(
|
540 |
+
"Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed> --g <guidance> --m <multipliers for LoRA>"
|
541 |
+
" --n <negative prompt>, `-` for empty negative prompt --c <cfg_scale>"
|
542 |
+
)
|
543 |
+
prompt = input()
|
544 |
+
if prompt == "":
|
545 |
+
break
|
546 |
+
|
547 |
+
# parse options
|
548 |
+
options = prompt.split("--")
|
549 |
+
prompt = options[0].strip()
|
550 |
+
seed = None
|
551 |
+
negative_prompt = None
|
552 |
+
for opt in options[1:]:
|
553 |
+
try:
|
554 |
+
opt = opt.strip()
|
555 |
+
if opt.startswith("w"):
|
556 |
+
width = int(opt[1:].strip())
|
557 |
+
elif opt.startswith("h"):
|
558 |
+
height = int(opt[1:].strip())
|
559 |
+
elif opt.startswith("s"):
|
560 |
+
steps = int(opt[1:].strip())
|
561 |
+
elif opt.startswith("d"):
|
562 |
+
seed = int(opt[1:].strip())
|
563 |
+
elif opt.startswith("g"):
|
564 |
+
guidance = float(opt[1:].strip())
|
565 |
+
elif opt.startswith("m"):
|
566 |
+
mutipliers = opt[1:].strip().split(",")
|
567 |
+
if len(mutipliers) != len(lora_models):
|
568 |
+
logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
|
569 |
+
continue
|
570 |
+
for i, lora_model in enumerate(lora_models):
|
571 |
+
lora_model.set_multiplier(float(mutipliers[i]))
|
572 |
+
elif opt.startswith("n"):
|
573 |
+
negative_prompt = opt[1:].strip()
|
574 |
+
if negative_prompt == "-":
|
575 |
+
negative_prompt = ""
|
576 |
+
elif opt.startswith("c"):
|
577 |
+
cfg_scale = float(opt[1:].strip())
|
578 |
+
except ValueError as e:
|
579 |
+
logger.error(f"Invalid option: {opt}, {e}")
|
580 |
+
|
581 |
+
generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance, negative_prompt, cfg_scale)
|
582 |
+
|
583 |
+
logger.info("Done!")
|
flux_train_network.py
ADDED
@@ -0,0 +1,588 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import copy
|
3 |
+
import math
|
4 |
+
import random
|
5 |
+
from typing import Any, Optional, Union
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from accelerate import Accelerator
|
9 |
+
|
10 |
+
from library.device_utils import clean_memory_on_device, init_ipex
|
11 |
+
|
12 |
+
init_ipex()
|
13 |
+
|
14 |
+
import train_network
|
15 |
+
from library import (
|
16 |
+
flux_models,
|
17 |
+
flux_train_utils,
|
18 |
+
flux_utils,
|
19 |
+
sd3_train_utils,
|
20 |
+
strategy_base,
|
21 |
+
strategy_flux,
|
22 |
+
train_util,
|
23 |
+
)
|
24 |
+
from library.utils import setup_logging
|
25 |
+
|
26 |
+
setup_logging()
|
27 |
+
import logging
|
28 |
+
|
29 |
+
logger = logging.getLogger(__name__)
|
30 |
+
|
31 |
+
|
32 |
+
class FluxNetworkTrainer(train_network.NetworkTrainer):
|
33 |
+
def __init__(self):
|
34 |
+
super().__init__()
|
35 |
+
self.sample_prompts_te_outputs = None
|
36 |
+
self.is_schnell: Optional[bool] = None
|
37 |
+
self.is_swapping_blocks: bool = False
|
38 |
+
|
39 |
+
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
|
40 |
+
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
|
41 |
+
# sdxl_train_util.verify_sdxl_training_args(args)
|
42 |
+
|
43 |
+
if args.fp8_base_unet:
|
44 |
+
args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1
|
45 |
+
|
46 |
+
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
47 |
+
logger.warning(
|
48 |
+
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
|
49 |
+
)
|
50 |
+
args.cache_text_encoder_outputs = True
|
51 |
+
|
52 |
+
if args.cache_text_encoder_outputs:
|
53 |
+
assert (
|
54 |
+
train_dataset_group.is_text_encoder_output_cacheable()
|
55 |
+
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
56 |
+
|
57 |
+
# prepare CLIP-L/T5XXL training flags
|
58 |
+
self.train_clip_l = not args.network_train_unet_only
|
59 |
+
self.train_t5xxl = False # default is False even if args.network_train_unet_only is False
|
60 |
+
|
61 |
+
if args.max_token_length is not None:
|
62 |
+
logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
|
63 |
+
|
64 |
+
assert (
|
65 |
+
args.blocks_to_swap is None or args.blocks_to_swap == 0
|
66 |
+
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
|
67 |
+
|
68 |
+
# deprecated split_mode option
|
69 |
+
if args.split_mode:
|
70 |
+
if args.blocks_to_swap is not None:
|
71 |
+
logger.warning(
|
72 |
+
"split_mode is deprecated. Because `--blocks_to_swap` is set, `--split_mode` is ignored."
|
73 |
+
" / split_modeは非推奨です。`--blocks_to_swap`が設定されているため、`--split_mode`は無視されます。"
|
74 |
+
)
|
75 |
+
else:
|
76 |
+
logger.warning(
|
77 |
+
"split_mode is deprecated. Please use `--blocks_to_swap` instead. `--blocks_to_swap 18` is automatically set."
|
78 |
+
" / split_modeは非推奨です。代わりに`--blocks_to_swap`を使用してください。`--blocks_to_swap 18`が自動的に設定されました。"
|
79 |
+
)
|
80 |
+
args.blocks_to_swap = 18 # 18 is safe for most cases
|
81 |
+
|
82 |
+
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
|
83 |
+
if val_dataset_group is not None:
|
84 |
+
val_dataset_group.verify_bucket_reso_steps(32) # TODO check this
|
85 |
+
|
86 |
+
def load_target_model(self, args, weight_dtype, accelerator):
|
87 |
+
# currently offload to cpu for some models
|
88 |
+
|
89 |
+
# if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
|
90 |
+
loading_dtype = None if args.fp8_base else weight_dtype
|
91 |
+
|
92 |
+
# if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
|
93 |
+
self.is_schnell, model = flux_utils.load_flow_model(
|
94 |
+
args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
|
95 |
+
)
|
96 |
+
if args.fp8_base:
|
97 |
+
# check dtype of model
|
98 |
+
if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
|
99 |
+
raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
|
100 |
+
elif model.dtype == torch.float8_e4m3fn:
|
101 |
+
logger.info("Loaded fp8 FLUX model")
|
102 |
+
else:
|
103 |
+
logger.info(
|
104 |
+
"Cast FLUX model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint."
|
105 |
+
" / FLUXモデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。"
|
106 |
+
)
|
107 |
+
model.to(torch.float8_e4m3fn)
|
108 |
+
|
109 |
+
# if args.split_mode:
|
110 |
+
# model = self.prepare_split_model(model, weight_dtype, accelerator)
|
111 |
+
|
112 |
+
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
|
113 |
+
if self.is_swapping_blocks:
|
114 |
+
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
|
115 |
+
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
|
116 |
+
model.enable_block_swap(args.blocks_to_swap, accelerator.device)
|
117 |
+
|
118 |
+
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
119 |
+
clip_l.eval()
|
120 |
+
|
121 |
+
# if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
|
122 |
+
if args.fp8_base and not args.fp8_base_unet:
|
123 |
+
loading_dtype = None # as is
|
124 |
+
else:
|
125 |
+
loading_dtype = weight_dtype
|
126 |
+
|
127 |
+
# loading t5xxl to cpu takes a long time, so we should load to gpu in future
|
128 |
+
t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
129 |
+
t5xxl.eval()
|
130 |
+
if args.fp8_base and not args.fp8_base_unet:
|
131 |
+
# check dtype of model
|
132 |
+
if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
|
133 |
+
raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
|
134 |
+
elif t5xxl.dtype == torch.float8_e4m3fn:
|
135 |
+
logger.info("Loaded fp8 T5XXL model")
|
136 |
+
|
137 |
+
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
138 |
+
|
139 |
+
return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
|
140 |
+
|
141 |
+
def get_tokenize_strategy(self, args):
|
142 |
+
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
|
143 |
+
|
144 |
+
if args.t5xxl_max_token_length is None:
|
145 |
+
if is_schnell:
|
146 |
+
t5xxl_max_token_length = 256
|
147 |
+
else:
|
148 |
+
t5xxl_max_token_length = 512
|
149 |
+
else:
|
150 |
+
t5xxl_max_token_length = args.t5xxl_max_token_length
|
151 |
+
|
152 |
+
logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
|
153 |
+
return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)
|
154 |
+
|
155 |
+
def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
|
156 |
+
return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]
|
157 |
+
|
158 |
+
def get_latents_caching_strategy(self, args):
|
159 |
+
latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)
|
160 |
+
return latents_caching_strategy
|
161 |
+
|
162 |
+
def get_text_encoding_strategy(self, args):
|
163 |
+
return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)
|
164 |
+
|
165 |
+
def post_process_network(self, args, accelerator, network, text_encoders, unet):
|
166 |
+
# check t5xxl is trained or not
|
167 |
+
self.train_t5xxl = network.train_t5xxl
|
168 |
+
|
169 |
+
if self.train_t5xxl and args.cache_text_encoder_outputs:
|
170 |
+
raise ValueError(
|
171 |
+
"T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません"
|
172 |
+
)
|
173 |
+
|
174 |
+
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
|
175 |
+
if args.cache_text_encoder_outputs:
|
176 |
+
if self.train_clip_l and not self.train_t5xxl:
|
177 |
+
return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached
|
178 |
+
else:
|
179 |
+
return None # no text encoders are needed for encoding because both are cached
|
180 |
+
else:
|
181 |
+
return text_encoders # both CLIP-L and T5XXL are needed for encoding
|
182 |
+
|
183 |
+
def get_text_encoders_train_flags(self, args, text_encoders):
|
184 |
+
return [self.train_clip_l, self.train_t5xxl]
|
185 |
+
|
186 |
+
def get_text_encoder_outputs_caching_strategy(self, args):
|
187 |
+
if args.cache_text_encoder_outputs:
|
188 |
+
# if the text encoders is trained, we need tokenization, so is_partial is True
|
189 |
+
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
|
190 |
+
args.cache_text_encoder_outputs_to_disk,
|
191 |
+
args.text_encoder_batch_size,
|
192 |
+
args.skip_cache_check,
|
193 |
+
is_partial=self.train_clip_l or self.train_t5xxl,
|
194 |
+
apply_t5_attn_mask=args.apply_t5_attn_mask,
|
195 |
+
)
|
196 |
+
else:
|
197 |
+
return None
|
198 |
+
|
199 |
+
def cache_text_encoder_outputs_if_needed(
|
200 |
+
self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
|
201 |
+
):
|
202 |
+
if args.cache_text_encoder_outputs:
|
203 |
+
if not args.lowram:
|
204 |
+
# メモリ消費を減らす
|
205 |
+
logger.info("move vae and unet to cpu to save memory")
|
206 |
+
org_vae_device = vae.device
|
207 |
+
org_unet_device = unet.device
|
208 |
+
vae.to("cpu")
|
209 |
+
unet.to("cpu")
|
210 |
+
clean_memory_on_device(accelerator.device)
|
211 |
+
|
212 |
+
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
|
213 |
+
logger.info("move text encoders to gpu")
|
214 |
+
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
|
215 |
+
text_encoders[1].to(accelerator.device)
|
216 |
+
|
217 |
+
if text_encoders[1].dtype == torch.float8_e4m3fn:
|
218 |
+
# if we load fp8 weights, the model is already fp8, so we use it as is
|
219 |
+
self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
|
220 |
+
else:
|
221 |
+
# otherwise, we need to convert it to target dtype
|
222 |
+
text_encoders[1].to(weight_dtype)
|
223 |
+
|
224 |
+
with accelerator.autocast():
|
225 |
+
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
|
226 |
+
|
227 |
+
# cache sample prompts
|
228 |
+
if args.sample_prompts is not None:
|
229 |
+
logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
|
230 |
+
|
231 |
+
tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
|
232 |
+
text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
|
233 |
+
|
234 |
+
prompts = train_util.load_prompts(args.sample_prompts)
|
235 |
+
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
|
236 |
+
with accelerator.autocast(), torch.no_grad():
|
237 |
+
for prompt_dict in prompts:
|
238 |
+
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
|
239 |
+
if p not in sample_prompts_te_outputs:
|
240 |
+
logger.info(f"cache Text Encoder outputs for prompt: {p}")
|
241 |
+
tokens_and_masks = tokenize_strategy.tokenize(p)
|
242 |
+
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
|
243 |
+
tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask
|
244 |
+
)
|
245 |
+
self.sample_prompts_te_outputs = sample_prompts_te_outputs
|
246 |
+
|
247 |
+
accelerator.wait_for_everyone()
|
248 |
+
|
249 |
+
# move back to cpu
|
250 |
+
if not self.is_train_text_encoder(args):
|
251 |
+
logger.info("move CLIP-L back to cpu")
|
252 |
+
text_encoders[0].to("cpu")
|
253 |
+
logger.info("move t5XXL back to cpu")
|
254 |
+
text_encoders[1].to("cpu")
|
255 |
+
clean_memory_on_device(accelerator.device)
|
256 |
+
|
257 |
+
if not args.lowram:
|
258 |
+
logger.info("move vae and unet back to original device")
|
259 |
+
vae.to(org_vae_device)
|
260 |
+
unet.to(org_unet_device)
|
261 |
+
else:
|
262 |
+
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
|
263 |
+
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
|
264 |
+
text_encoders[1].to(accelerator.device)
|
265 |
+
|
266 |
+
# def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
267 |
+
# noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
268 |
+
|
269 |
+
# # get size embeddings
|
270 |
+
# orig_size = batch["original_sizes_hw"]
|
271 |
+
# crop_size = batch["crop_top_lefts"]
|
272 |
+
# target_size = batch["target_sizes_hw"]
|
273 |
+
# embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
|
274 |
+
|
275 |
+
# # concat embeddings
|
276 |
+
# encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
|
277 |
+
# vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
|
278 |
+
# text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
|
279 |
+
|
280 |
+
# noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
281 |
+
# return noise_pred
|
282 |
+
|
283 |
+
def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
|
284 |
+
text_encoders = text_encoder # for compatibility
|
285 |
+
text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
|
286 |
+
|
287 |
+
flux_train_utils.sample_images(
|
288 |
+
accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
|
289 |
+
)
|
290 |
+
# return
|
291 |
+
|
292 |
+
"""
|
293 |
+
class FluxUpperLowerWrapper(torch.nn.Module):
|
294 |
+
def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device):
|
295 |
+
super().__init__()
|
296 |
+
self.flux_upper = flux_upper
|
297 |
+
self.flux_lower = flux_lower
|
298 |
+
self.target_device = device
|
299 |
+
|
300 |
+
def prepare_block_swap_before_forward(self):
|
301 |
+
pass
|
302 |
+
|
303 |
+
def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None):
|
304 |
+
self.flux_lower.to("cpu")
|
305 |
+
clean_memory_on_device(self.target_device)
|
306 |
+
self.flux_upper.to(self.target_device)
|
307 |
+
img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask)
|
308 |
+
self.flux_upper.to("cpu")
|
309 |
+
clean_memory_on_device(self.target_device)
|
310 |
+
self.flux_lower.to(self.target_device)
|
311 |
+
return self.flux_lower(img, txt, vec, pe, txt_attention_mask)
|
312 |
+
|
313 |
+
wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device)
|
314 |
+
clean_memory_on_device(accelerator.device)
|
315 |
+
flux_train_utils.sample_images(
|
316 |
+
accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs
|
317 |
+
)
|
318 |
+
clean_memory_on_device(accelerator.device)
|
319 |
+
"""
|
320 |
+
|
321 |
+
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
|
322 |
+
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
|
323 |
+
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
324 |
+
return noise_scheduler
|
325 |
+
|
326 |
+
def encode_images_to_latents(self, args, accelerator, vae, images):
|
327 |
+
return vae.encode(images)
|
328 |
+
|
329 |
+
def shift_scale_latents(self, args, latents):
|
330 |
+
return latents
|
331 |
+
|
332 |
+
def get_noise_pred_and_target(
|
333 |
+
self,
|
334 |
+
args,
|
335 |
+
accelerator,
|
336 |
+
noise_scheduler,
|
337 |
+
latents,
|
338 |
+
batch,
|
339 |
+
text_encoder_conds,
|
340 |
+
unet: flux_models.Flux,
|
341 |
+
network,
|
342 |
+
weight_dtype,
|
343 |
+
train_unet,
|
344 |
+
is_train=True
|
345 |
+
):
|
346 |
+
# Sample noise that we'll add to the latents
|
347 |
+
noise = torch.randn_like(latents)
|
348 |
+
bsz = latents.shape[0]
|
349 |
+
|
350 |
+
# get noisy model input and timesteps
|
351 |
+
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
|
352 |
+
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
|
353 |
+
)
|
354 |
+
|
355 |
+
# pack latents and get img_ids
|
356 |
+
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
|
357 |
+
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
|
358 |
+
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
|
359 |
+
|
360 |
+
# get guidance
|
361 |
+
# ensure guidance_scale in args is float
|
362 |
+
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
|
363 |
+
|
364 |
+
# ensure the hidden state will require grad
|
365 |
+
if args.gradient_checkpointing:
|
366 |
+
noisy_model_input.requires_grad_(True)
|
367 |
+
for t in text_encoder_conds:
|
368 |
+
if t is not None and t.dtype.is_floating_point:
|
369 |
+
t.requires_grad_(True)
|
370 |
+
img_ids.requires_grad_(True)
|
371 |
+
guidance_vec.requires_grad_(True)
|
372 |
+
|
373 |
+
# Predict the noise residual
|
374 |
+
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
375 |
+
if not args.apply_t5_attn_mask:
|
376 |
+
t5_attn_mask = None
|
377 |
+
|
378 |
+
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
|
379 |
+
# if not args.split_mode:
|
380 |
+
# normal forward
|
381 |
+
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
382 |
+
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
383 |
+
model_pred = unet(
|
384 |
+
img=img,
|
385 |
+
img_ids=img_ids,
|
386 |
+
txt=t5_out,
|
387 |
+
txt_ids=txt_ids,
|
388 |
+
y=l_pooled,
|
389 |
+
timesteps=timesteps / 1000,
|
390 |
+
guidance=guidance_vec,
|
391 |
+
txt_attention_mask=t5_attn_mask,
|
392 |
+
)
|
393 |
+
"""
|
394 |
+
else:
|
395 |
+
# split forward to reduce memory usage
|
396 |
+
assert network.train_blocks == "single", "train_blocks must be single for split mode"
|
397 |
+
with accelerator.autocast():
|
398 |
+
# move flux lower to cpu, and then move flux upper to gpu
|
399 |
+
unet.to("cpu")
|
400 |
+
clean_memory_on_device(accelerator.device)
|
401 |
+
self.flux_upper.to(accelerator.device)
|
402 |
+
|
403 |
+
# upper model does not require grad
|
404 |
+
with torch.no_grad():
|
405 |
+
intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
|
406 |
+
img=packed_noisy_model_input,
|
407 |
+
img_ids=img_ids,
|
408 |
+
txt=t5_out,
|
409 |
+
txt_ids=txt_ids,
|
410 |
+
y=l_pooled,
|
411 |
+
timesteps=timesteps / 1000,
|
412 |
+
guidance=guidance_vec,
|
413 |
+
txt_attention_mask=t5_attn_mask,
|
414 |
+
)
|
415 |
+
|
416 |
+
# move flux upper back to cpu, and then move flux lower to gpu
|
417 |
+
self.flux_upper.to("cpu")
|
418 |
+
clean_memory_on_device(accelerator.device)
|
419 |
+
unet.to(accelerator.device)
|
420 |
+
|
421 |
+
# lower model requires grad
|
422 |
+
intermediate_img.requires_grad_(True)
|
423 |
+
intermediate_txt.requires_grad_(True)
|
424 |
+
vec.requires_grad_(True)
|
425 |
+
pe.requires_grad_(True)
|
426 |
+
|
427 |
+
with torch.set_grad_enabled(is_train and train_unet):
|
428 |
+
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
|
429 |
+
"""
|
430 |
+
|
431 |
+
return model_pred
|
432 |
+
|
433 |
+
model_pred = call_dit(
|
434 |
+
img=packed_noisy_model_input,
|
435 |
+
img_ids=img_ids,
|
436 |
+
t5_out=t5_out,
|
437 |
+
txt_ids=txt_ids,
|
438 |
+
l_pooled=l_pooled,
|
439 |
+
timesteps=timesteps,
|
440 |
+
guidance_vec=guidance_vec,
|
441 |
+
t5_attn_mask=t5_attn_mask,
|
442 |
+
)
|
443 |
+
|
444 |
+
# unpack latents
|
445 |
+
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
|
446 |
+
|
447 |
+
# apply model prediction type
|
448 |
+
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
|
449 |
+
|
450 |
+
# flow matching loss: this is different from SD3
|
451 |
+
target = noise - latents
|
452 |
+
|
453 |
+
# differential output preservation
|
454 |
+
if "custom_attributes" in batch:
|
455 |
+
diff_output_pr_indices = []
|
456 |
+
for i, custom_attributes in enumerate(batch["custom_attributes"]):
|
457 |
+
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
|
458 |
+
diff_output_pr_indices.append(i)
|
459 |
+
|
460 |
+
if len(diff_output_pr_indices) > 0:
|
461 |
+
network.set_multiplier(0.0)
|
462 |
+
unet.prepare_block_swap_before_forward()
|
463 |
+
with torch.no_grad():
|
464 |
+
model_pred_prior = call_dit(
|
465 |
+
img=packed_noisy_model_input[diff_output_pr_indices],
|
466 |
+
img_ids=img_ids[diff_output_pr_indices],
|
467 |
+
t5_out=t5_out[diff_output_pr_indices],
|
468 |
+
txt_ids=txt_ids[diff_output_pr_indices],
|
469 |
+
l_pooled=l_pooled[diff_output_pr_indices],
|
470 |
+
timesteps=timesteps[diff_output_pr_indices],
|
471 |
+
guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
|
472 |
+
t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
|
473 |
+
)
|
474 |
+
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
|
475 |
+
|
476 |
+
model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width)
|
477 |
+
model_pred_prior, _ = flux_train_utils.apply_model_prediction_type(
|
478 |
+
args,
|
479 |
+
model_pred_prior,
|
480 |
+
noisy_model_input[diff_output_pr_indices],
|
481 |
+
sigmas[diff_output_pr_indices] if sigmas is not None else None,
|
482 |
+
)
|
483 |
+
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
|
484 |
+
|
485 |
+
return model_pred, target, timesteps, weighting
|
486 |
+
|
487 |
+
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
488 |
+
return loss
|
489 |
+
|
490 |
+
def get_sai_model_spec(self, args):
|
491 |
+
return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev")
|
492 |
+
|
493 |
+
def update_metadata(self, metadata, args):
|
494 |
+
metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
|
495 |
+
metadata["ss_weighting_scheme"] = args.weighting_scheme
|
496 |
+
metadata["ss_logit_mean"] = args.logit_mean
|
497 |
+
metadata["ss_logit_std"] = args.logit_std
|
498 |
+
metadata["ss_mode_scale"] = args.mode_scale
|
499 |
+
metadata["ss_guidance_scale"] = args.guidance_scale
|
500 |
+
metadata["ss_timestep_sampling"] = args.timestep_sampling
|
501 |
+
metadata["ss_sigmoid_scale"] = args.sigmoid_scale
|
502 |
+
metadata["ss_model_prediction_type"] = args.model_prediction_type
|
503 |
+
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
|
504 |
+
|
505 |
+
def is_text_encoder_not_needed_for_training(self, args):
|
506 |
+
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
|
507 |
+
|
508 |
+
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
|
509 |
+
if index == 0: # CLIP-L
|
510 |
+
return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
|
511 |
+
else: # T5XXL
|
512 |
+
text_encoder.encoder.embed_tokens.requires_grad_(True)
|
513 |
+
|
514 |
+
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
|
515 |
+
if index == 0: # CLIP-L
|
516 |
+
logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
|
517 |
+
text_encoder.to(te_weight_dtype) # fp8
|
518 |
+
text_encoder.text_model.embeddings.to(dtype=weight_dtype)
|
519 |
+
else: # T5XXL
|
520 |
+
|
521 |
+
def prepare_fp8(text_encoder, target_dtype):
|
522 |
+
def forward_hook(module):
|
523 |
+
def forward(hidden_states):
|
524 |
+
hidden_gelu = module.act(module.wi_0(hidden_states))
|
525 |
+
hidden_linear = module.wi_1(hidden_states)
|
526 |
+
hidden_states = hidden_gelu * hidden_linear
|
527 |
+
hidden_states = module.dropout(hidden_states)
|
528 |
+
|
529 |
+
hidden_states = module.wo(hidden_states)
|
530 |
+
return hidden_states
|
531 |
+
|
532 |
+
return forward
|
533 |
+
|
534 |
+
for module in text_encoder.modules():
|
535 |
+
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
|
536 |
+
# print("set", module.__class__.__name__, "to", target_dtype)
|
537 |
+
module.to(target_dtype)
|
538 |
+
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
|
539 |
+
# print("set", module.__class__.__name__, "hooks")
|
540 |
+
module.forward = forward_hook(module)
|
541 |
+
|
542 |
+
if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
|
543 |
+
logger.info(f"T5XXL already prepared for fp8")
|
544 |
+
else:
|
545 |
+
logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
|
546 |
+
text_encoder.to(te_weight_dtype) # fp8
|
547 |
+
prepare_fp8(text_encoder, weight_dtype)
|
548 |
+
|
549 |
+
def prepare_unet_with_accelerator(
|
550 |
+
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
|
551 |
+
) -> torch.nn.Module:
|
552 |
+
if not self.is_swapping_blocks:
|
553 |
+
return super().prepare_unet_with_accelerator(args, accelerator, unet)
|
554 |
+
|
555 |
+
# if we doesn't swap blocks, we can move the model to device
|
556 |
+
flux: flux_models.Flux = unet
|
557 |
+
flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks])
|
558 |
+
accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
|
559 |
+
accelerator.unwrap_model(flux).prepare_block_swap_before_forward()
|
560 |
+
|
561 |
+
return flux
|
562 |
+
|
563 |
+
|
564 |
+
def setup_parser() -> argparse.ArgumentParser:
|
565 |
+
parser = train_network.setup_parser()
|
566 |
+
train_util.add_dit_training_arguments(parser)
|
567 |
+
flux_train_utils.add_flux_train_arguments(parser)
|
568 |
+
|
569 |
+
parser.add_argument(
|
570 |
+
"--split_mode",
|
571 |
+
action="store_true",
|
572 |
+
# help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
|
573 |
+
# + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要",
|
574 |
+
help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead."
|
575 |
+
" / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。",
|
576 |
+
)
|
577 |
+
return parser
|
578 |
+
|
579 |
+
|
580 |
+
if __name__ == "__main__":
|
581 |
+
parser = setup_parser()
|
582 |
+
|
583 |
+
args = parser.parse_args()
|
584 |
+
train_util.verify_command_line_training_args(args)
|
585 |
+
args = train_util.read_config_from_file(args, parser)
|
586 |
+
|
587 |
+
trainer = FluxNetworkTrainer()
|
588 |
+
trainer.train(args)
|
flux_train_network_asylora.py
ADDED
@@ -0,0 +1,591 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
|
3 |
+
|
4 |
+
import argparse
|
5 |
+
import copy
|
6 |
+
import math
|
7 |
+
import random
|
8 |
+
from typing import Any, Optional
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from accelerate import Accelerator
|
12 |
+
from library.device_utils import init_ipex, clean_memory_on_device
|
13 |
+
|
14 |
+
init_ipex()
|
15 |
+
|
16 |
+
from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util
|
17 |
+
import train_network_asylora
|
18 |
+
from library.utils import setup_logging
|
19 |
+
|
20 |
+
setup_logging()
|
21 |
+
import logging
|
22 |
+
import re
|
23 |
+
|
24 |
+
logger = logging.getLogger(__name__)
|
25 |
+
|
26 |
+
|
27 |
+
class FluxNetworkTrainer(train_network_asylora.NetworkTrainer):
|
28 |
+
def __init__(self):
|
29 |
+
super().__init__()
|
30 |
+
self.sample_prompts_te_outputs = None
|
31 |
+
self.is_schnell: Optional[bool] = None
|
32 |
+
self.is_swapping_blocks: bool = False
|
33 |
+
|
34 |
+
def assert_extra_args(self, args, train_dataset_group):
|
35 |
+
super().assert_extra_args(args, train_dataset_group)
|
36 |
+
# sdxl_train_util.verify_sdxl_training_args(args)
|
37 |
+
|
38 |
+
if args.fp8_base_unet:
|
39 |
+
args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1
|
40 |
+
|
41 |
+
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
42 |
+
logger.warning(
|
43 |
+
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled"
|
44 |
+
)
|
45 |
+
args.cache_text_encoder_outputs = True
|
46 |
+
|
47 |
+
if args.cache_text_encoder_outputs:
|
48 |
+
assert (
|
49 |
+
train_dataset_group.is_text_encoder_output_cacheable()
|
50 |
+
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
51 |
+
|
52 |
+
# prepare CLIP-L/T5XXL training flags
|
53 |
+
self.train_clip_l = not args.network_train_unet_only
|
54 |
+
self.train_t5xxl = False # default is False even if args.network_train_unet_only is False
|
55 |
+
|
56 |
+
if args.max_token_length is not None:
|
57 |
+
logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
|
58 |
+
|
59 |
+
assert (
|
60 |
+
args.blocks_to_swap is None or args.blocks_to_swap == 0
|
61 |
+
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
|
62 |
+
|
63 |
+
# deprecated split_mode option
|
64 |
+
if args.split_mode:
|
65 |
+
if args.blocks_to_swap is not None:
|
66 |
+
logger.warning(
|
67 |
+
"split_mode is deprecated. Because `--blocks_to_swap` is set, `--split_mode` is ignored."
|
68 |
+
" / split_modeは非推奨です。`--blocks_to_swap`が設定されているため、`--split_mode`は無視されます。"
|
69 |
+
)
|
70 |
+
else:
|
71 |
+
logger.warning(
|
72 |
+
"split_mode is deprecated. Please use `--blocks_to_swap` instead. `--blocks_to_swap 18` is automatically set."
|
73 |
+
" / split_modeは非推奨です。代わりに`--blocks_to_swap`を使用してください。`--blocks_to_swap 18`が自動的に設定されました。"
|
74 |
+
)
|
75 |
+
args.blocks_to_swap = 18 # 18 is safe for most cases
|
76 |
+
|
77 |
+
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
|
78 |
+
|
79 |
+
def load_target_model(self, args, weight_dtype, accelerator):
|
80 |
+
# currently offload to cpu for some models
|
81 |
+
|
82 |
+
# if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
|
83 |
+
loading_dtype = None if args.fp8_base else weight_dtype
|
84 |
+
|
85 |
+
# if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
|
86 |
+
self.is_schnell, model = flux_utils.load_flow_model(
|
87 |
+
args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
|
88 |
+
)
|
89 |
+
if args.fp8_base:
|
90 |
+
# check dtype of model
|
91 |
+
if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
|
92 |
+
raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
|
93 |
+
elif model.dtype == torch.float8_e4m3fn:
|
94 |
+
logger.info("Loaded fp8 FLUX model")
|
95 |
+
else:
|
96 |
+
logger.info(
|
97 |
+
"Cast FLUX model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint."
|
98 |
+
" / FLUXモデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。"
|
99 |
+
)
|
100 |
+
model.to(torch.float8_e4m3fn)
|
101 |
+
|
102 |
+
# if args.split_mode:
|
103 |
+
# model = self.prepare_split_model(model, weight_dtype, accelerator)
|
104 |
+
|
105 |
+
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
|
106 |
+
if self.is_swapping_blocks:
|
107 |
+
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
|
108 |
+
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
|
109 |
+
model.enable_block_swap(args.blocks_to_swap, accelerator.device)
|
110 |
+
|
111 |
+
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
112 |
+
clip_l.eval()
|
113 |
+
|
114 |
+
# if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
|
115 |
+
if args.fp8_base and not args.fp8_base_unet:
|
116 |
+
loading_dtype = None # as is
|
117 |
+
else:
|
118 |
+
loading_dtype = weight_dtype
|
119 |
+
|
120 |
+
# loading t5xxl to cpu takes a long time, so we should load to gpu in future
|
121 |
+
t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
122 |
+
t5xxl.eval()
|
123 |
+
if args.fp8_base and not args.fp8_base_unet:
|
124 |
+
# check dtype of model
|
125 |
+
if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
|
126 |
+
raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
|
127 |
+
elif t5xxl.dtype == torch.float8_e4m3fn:
|
128 |
+
logger.info("Loaded fp8 T5XXL model")
|
129 |
+
|
130 |
+
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
131 |
+
|
132 |
+
return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
|
133 |
+
|
134 |
+
def get_tokenize_strategy(self, args):
|
135 |
+
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
|
136 |
+
|
137 |
+
if args.t5xxl_max_token_length is None:
|
138 |
+
if is_schnell:
|
139 |
+
t5xxl_max_token_length = 256
|
140 |
+
else:
|
141 |
+
t5xxl_max_token_length = 512
|
142 |
+
else:
|
143 |
+
t5xxl_max_token_length = args.t5xxl_max_token_length
|
144 |
+
|
145 |
+
logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
|
146 |
+
return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)
|
147 |
+
|
148 |
+
def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
|
149 |
+
return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]
|
150 |
+
|
151 |
+
def get_latents_caching_strategy(self, args):
|
152 |
+
latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)
|
153 |
+
return latents_caching_strategy
|
154 |
+
|
155 |
+
def get_text_encoding_strategy(self, args):
|
156 |
+
return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)
|
157 |
+
|
158 |
+
def post_process_network(self, args, accelerator, network, text_encoders, unet):
|
159 |
+
# check t5xxl is trained or not
|
160 |
+
self.train_t5xxl = network.train_t5xxl
|
161 |
+
|
162 |
+
if self.train_t5xxl and args.cache_text_encoder_outputs:
|
163 |
+
raise ValueError(
|
164 |
+
"T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません"
|
165 |
+
)
|
166 |
+
|
167 |
+
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
|
168 |
+
if args.cache_text_encoder_outputs:
|
169 |
+
if self.train_clip_l and not self.train_t5xxl:
|
170 |
+
return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached
|
171 |
+
else:
|
172 |
+
return None # no text encoders are needed for encoding because both are cached
|
173 |
+
else:
|
174 |
+
return text_encoders # both CLIP-L and T5XXL are needed for encoding
|
175 |
+
|
176 |
+
def get_text_encoders_train_flags(self, args, text_encoders):
|
177 |
+
return [self.train_clip_l, self.train_t5xxl]
|
178 |
+
|
179 |
+
def get_text_encoder_outputs_caching_strategy(self, args):
|
180 |
+
if args.cache_text_encoder_outputs:
|
181 |
+
# if the text encoders is trained, we need tokenization, so is_partial is True
|
182 |
+
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
|
183 |
+
args.cache_text_encoder_outputs_to_disk,
|
184 |
+
args.text_encoder_batch_size,
|
185 |
+
args.skip_cache_check,
|
186 |
+
is_partial=self.train_clip_l or self.train_t5xxl,
|
187 |
+
apply_t5_attn_mask=args.apply_t5_attn_mask,
|
188 |
+
)
|
189 |
+
else:
|
190 |
+
return None
|
191 |
+
|
192 |
+
def cache_text_encoder_outputs_if_needed(
|
193 |
+
self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
|
194 |
+
):
|
195 |
+
if args.cache_text_encoder_outputs:
|
196 |
+
if not args.lowram:
|
197 |
+
# メモリ消費を減らす
|
198 |
+
logger.info("move vae and unet to cpu to save memory")
|
199 |
+
org_vae_device = vae.device
|
200 |
+
org_unet_device = unet.device
|
201 |
+
vae.to("cpu")
|
202 |
+
unet.to("cpu")
|
203 |
+
clean_memory_on_device(accelerator.device)
|
204 |
+
|
205 |
+
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
|
206 |
+
logger.info("move text encoders to gpu")
|
207 |
+
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
|
208 |
+
text_encoders[1].to(accelerator.device)
|
209 |
+
|
210 |
+
if text_encoders[1].dtype == torch.float8_e4m3fn:
|
211 |
+
# if we load fp8 weights, the model is already fp8, so we use it as is
|
212 |
+
self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
|
213 |
+
else:
|
214 |
+
# otherwise, we need to convert it to target dtype
|
215 |
+
text_encoders[1].to(weight_dtype)
|
216 |
+
|
217 |
+
with accelerator.autocast():
|
218 |
+
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
|
219 |
+
|
220 |
+
# cache sample prompts
|
221 |
+
if args.sample_prompts is not None:
|
222 |
+
logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
|
223 |
+
|
224 |
+
tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
|
225 |
+
text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
|
226 |
+
|
227 |
+
prompts = train_util.load_prompts(args.sample_prompts)
|
228 |
+
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
|
229 |
+
with accelerator.autocast(), torch.no_grad():
|
230 |
+
for prompt_dict in prompts:
|
231 |
+
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
|
232 |
+
if p not in sample_prompts_te_outputs:
|
233 |
+
logger.info(f"cache Text Encoder outputs for prompt: {p}")
|
234 |
+
tokens_and_masks = tokenize_strategy.tokenize(p)
|
235 |
+
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
|
236 |
+
tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask
|
237 |
+
)
|
238 |
+
self.sample_prompts_te_outputs = sample_prompts_te_outputs
|
239 |
+
|
240 |
+
accelerator.wait_for_everyone()
|
241 |
+
|
242 |
+
# move back to cpu
|
243 |
+
if not self.is_train_text_encoder(args):
|
244 |
+
logger.info("move CLIP-L back to cpu")
|
245 |
+
text_encoders[0].to("cpu")
|
246 |
+
logger.info("move t5XXL back to cpu")
|
247 |
+
text_encoders[1].to("cpu")
|
248 |
+
clean_memory_on_device(accelerator.device)
|
249 |
+
|
250 |
+
if not args.lowram:
|
251 |
+
logger.info("move vae and unet back to original device")
|
252 |
+
vae.to(org_vae_device)
|
253 |
+
unet.to(org_unet_device)
|
254 |
+
else:
|
255 |
+
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
|
256 |
+
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
|
257 |
+
text_encoders[1].to(accelerator.device)
|
258 |
+
|
259 |
+
# def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
260 |
+
# noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
261 |
+
|
262 |
+
# # get size embeddings
|
263 |
+
# orig_size = batch["original_sizes_hw"]
|
264 |
+
# crop_size = batch["crop_top_lefts"]
|
265 |
+
# target_size = batch["target_sizes_hw"]
|
266 |
+
# embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
|
267 |
+
|
268 |
+
# # concat embeddings
|
269 |
+
# encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
|
270 |
+
# vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
|
271 |
+
# text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
|
272 |
+
|
273 |
+
# noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
274 |
+
# return noise_pred
|
275 |
+
|
276 |
+
def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
|
277 |
+
text_encoders = text_encoder # for compatibility
|
278 |
+
text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
|
279 |
+
|
280 |
+
flux_train_utils.sample_images(
|
281 |
+
accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
|
282 |
+
)
|
283 |
+
# return
|
284 |
+
|
285 |
+
"""
|
286 |
+
class FluxUpperLowerWrapper(torch.nn.Module):
|
287 |
+
def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device):
|
288 |
+
super().__init__()
|
289 |
+
self.flux_upper = flux_upper
|
290 |
+
self.flux_lower = flux_lower
|
291 |
+
self.target_device = device
|
292 |
+
|
293 |
+
def prepare_block_swap_before_forward(self):
|
294 |
+
pass
|
295 |
+
|
296 |
+
def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None):
|
297 |
+
self.flux_lower.to("cpu")
|
298 |
+
clean_memory_on_device(self.target_device)
|
299 |
+
self.flux_upper.to(self.target_device)
|
300 |
+
img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask)
|
301 |
+
self.flux_upper.to("cpu")
|
302 |
+
clean_memory_on_device(self.target_device)
|
303 |
+
self.flux_lower.to(self.target_device)
|
304 |
+
return self.flux_lower(img, txt, vec, pe, txt_attention_mask)
|
305 |
+
|
306 |
+
wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device)
|
307 |
+
clean_memory_on_device(accelerator.device)
|
308 |
+
flux_train_utils.sample_images(
|
309 |
+
accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs
|
310 |
+
)
|
311 |
+
clean_memory_on_device(accelerator.device)
|
312 |
+
"""
|
313 |
+
|
314 |
+
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
|
315 |
+
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
|
316 |
+
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
317 |
+
return noise_scheduler
|
318 |
+
|
319 |
+
def encode_images_to_latents(self, args, accelerator, vae, images):
|
320 |
+
return vae.encode(images)
|
321 |
+
|
322 |
+
def shift_scale_latents(self, args, latents):
|
323 |
+
return latents
|
324 |
+
|
325 |
+
def get_noise_pred_and_target(
|
326 |
+
self,
|
327 |
+
args,
|
328 |
+
accelerator,
|
329 |
+
noise_scheduler,
|
330 |
+
latents,
|
331 |
+
batch,
|
332 |
+
text_encoder_conds,
|
333 |
+
unet: flux_models.Flux,
|
334 |
+
network,
|
335 |
+
weight_dtype,
|
336 |
+
train_unet,
|
337 |
+
):
|
338 |
+
# Sample noise that we'll add to the latents
|
339 |
+
noise = torch.randn_like(latents)
|
340 |
+
bsz = latents.shape[0]
|
341 |
+
|
342 |
+
# get noisy model input and timesteps
|
343 |
+
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
|
344 |
+
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
|
345 |
+
)
|
346 |
+
|
347 |
+
# pack latents and get img_ids
|
348 |
+
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
|
349 |
+
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
|
350 |
+
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
|
351 |
+
|
352 |
+
# get guidance
|
353 |
+
# ensure guidance_scale in args is float
|
354 |
+
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
|
355 |
+
|
356 |
+
# ensure the hidden state will require grad
|
357 |
+
if args.gradient_checkpointing:
|
358 |
+
noisy_model_input.requires_grad_(True)
|
359 |
+
for t in text_encoder_conds:
|
360 |
+
if t is not None and t.dtype.is_floating_point:
|
361 |
+
t.requires_grad_(True)
|
362 |
+
img_ids.requires_grad_(True)
|
363 |
+
guidance_vec.requires_grad_(True)
|
364 |
+
|
365 |
+
# Predict the noise residual
|
366 |
+
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
367 |
+
if not args.apply_t5_attn_mask:
|
368 |
+
t5_attn_mask = None
|
369 |
+
|
370 |
+
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
|
371 |
+
# if not args.split_mode:
|
372 |
+
# normal forward
|
373 |
+
with accelerator.autocast():
|
374 |
+
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
375 |
+
model_pred = unet(
|
376 |
+
img=img,
|
377 |
+
img_ids=img_ids,
|
378 |
+
txt=t5_out,
|
379 |
+
txt_ids=txt_ids,
|
380 |
+
y=l_pooled,
|
381 |
+
timesteps=timesteps / 1000,
|
382 |
+
guidance=guidance_vec,
|
383 |
+
txt_attention_mask=t5_attn_mask
|
384 |
+
)
|
385 |
+
"""
|
386 |
+
else:
|
387 |
+
# split forward to reduce memory usage
|
388 |
+
assert network.train_blocks == "single", "train_blocks must be single for split mode"
|
389 |
+
with accelerator.autocast():
|
390 |
+
# move flux lower to cpu, and then move flux upper to gpu
|
391 |
+
unet.to("cpu")
|
392 |
+
clean_memory_on_device(accelerator.device)
|
393 |
+
self.flux_upper.to(accelerator.device)
|
394 |
+
|
395 |
+
# upper model does not require grad
|
396 |
+
with torch.no_grad():
|
397 |
+
intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
|
398 |
+
img=packed_noisy_model_input,
|
399 |
+
img_ids=img_ids,
|
400 |
+
txt=t5_out,
|
401 |
+
txt_ids=txt_ids,
|
402 |
+
y=l_pooled,
|
403 |
+
timesteps=timesteps / 1000,
|
404 |
+
guidance=guidance_vec,
|
405 |
+
txt_attention_mask=t5_attn_mask,
|
406 |
+
)
|
407 |
+
|
408 |
+
# move flux upper back to cpu, and then move flux lower to gpu
|
409 |
+
self.flux_upper.to("cpu")
|
410 |
+
clean_memory_on_device(accelerator.device)
|
411 |
+
unet.to(accelerator.device)
|
412 |
+
|
413 |
+
# lower model requires grad
|
414 |
+
intermediate_img.requires_grad_(True)
|
415 |
+
intermediate_txt.requires_grad_(True)
|
416 |
+
vec.requires_grad_(True)
|
417 |
+
pe.requires_grad_(True)
|
418 |
+
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
|
419 |
+
"""
|
420 |
+
|
421 |
+
return model_pred
|
422 |
+
|
423 |
+
# 获取数据集分类编号 文本
|
424 |
+
# lora_category = batch["captions"][0].split(",")[0][3:]
|
425 |
+
# assert lora_category.isdigit(), f"lora_category 不是整数,值为: {lora_category}, {batch['captions'][0]}"
|
426 |
+
# lora_category = int(lora_category)
|
427 |
+
|
428 |
+
prompt_cur = batch["captions"][0]
|
429 |
+
match = re.search(r'--lora_up_cur (\d+)', prompt_cur)
|
430 |
+
assert match, "Pattern '--lora_up_cur' not found"
|
431 |
+
lora_category = int(match.group(1))
|
432 |
+
|
433 |
+
for lora in network.unet_loras:
|
434 |
+
lora.set_lora_up_cur(lora_category-1)
|
435 |
+
|
436 |
+
model_pred = call_dit(
|
437 |
+
img=packed_noisy_model_input,
|
438 |
+
img_ids=img_ids,
|
439 |
+
t5_out=t5_out,
|
440 |
+
txt_ids=txt_ids,
|
441 |
+
l_pooled=l_pooled,
|
442 |
+
timesteps=timesteps,
|
443 |
+
guidance_vec=guidance_vec,
|
444 |
+
t5_attn_mask=t5_attn_mask
|
445 |
+
)
|
446 |
+
|
447 |
+
# unpack latents
|
448 |
+
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
|
449 |
+
|
450 |
+
# apply model prediction type
|
451 |
+
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
|
452 |
+
|
453 |
+
# flow matching loss: this is different from SD3
|
454 |
+
target = noise - latents
|
455 |
+
|
456 |
+
# differential output preservation
|
457 |
+
if "custom_attributes" in batch:
|
458 |
+
diff_output_pr_indices = []
|
459 |
+
for i, custom_attributes in enumerate(batch["custom_attributes"]):
|
460 |
+
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
|
461 |
+
diff_output_pr_indices.append(i)
|
462 |
+
|
463 |
+
if len(diff_output_pr_indices) > 0:
|
464 |
+
network.set_multiplier(0.0)
|
465 |
+
unet.prepare_block_swap_before_forward()
|
466 |
+
with torch.no_grad():
|
467 |
+
model_pred_prior = call_dit(
|
468 |
+
img=packed_noisy_model_input[diff_output_pr_indices],
|
469 |
+
img_ids=img_ids[diff_output_pr_indices],
|
470 |
+
t5_out=t5_out[diff_output_pr_indices],
|
471 |
+
txt_ids=txt_ids[diff_output_pr_indices],
|
472 |
+
l_pooled=l_pooled[diff_output_pr_indices],
|
473 |
+
timesteps=timesteps[diff_output_pr_indices],
|
474 |
+
guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
|
475 |
+
t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
|
476 |
+
)
|
477 |
+
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
|
478 |
+
|
479 |
+
model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width)
|
480 |
+
model_pred_prior, _ = flux_train_utils.apply_model_prediction_type(
|
481 |
+
args,
|
482 |
+
model_pred_prior,
|
483 |
+
noisy_model_input[diff_output_pr_indices],
|
484 |
+
sigmas[diff_output_pr_indices] if sigmas is not None else None,
|
485 |
+
)
|
486 |
+
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
|
487 |
+
|
488 |
+
return model_pred, target, timesteps, None, weighting
|
489 |
+
|
490 |
+
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
491 |
+
return loss
|
492 |
+
|
493 |
+
def get_sai_model_spec(self, args):
|
494 |
+
return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev")
|
495 |
+
|
496 |
+
def update_metadata(self, metadata, args):
|
497 |
+
metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
|
498 |
+
metadata["ss_weighting_scheme"] = args.weighting_scheme
|
499 |
+
metadata["ss_logit_mean"] = args.logit_mean
|
500 |
+
metadata["ss_logit_std"] = args.logit_std
|
501 |
+
metadata["ss_mode_scale"] = args.mode_scale
|
502 |
+
metadata["ss_guidance_scale"] = args.guidance_scale
|
503 |
+
metadata["ss_timestep_sampling"] = args.timestep_sampling
|
504 |
+
metadata["ss_sigmoid_scale"] = args.sigmoid_scale
|
505 |
+
metadata["ss_model_prediction_type"] = args.model_prediction_type
|
506 |
+
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
|
507 |
+
|
508 |
+
def is_text_encoder_not_needed_for_training(self, args):
|
509 |
+
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
|
510 |
+
|
511 |
+
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
|
512 |
+
if index == 0: # CLIP-L
|
513 |
+
return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
|
514 |
+
else: # T5XXL
|
515 |
+
text_encoder.encoder.embed_tokens.requires_grad_(True)
|
516 |
+
|
517 |
+
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
|
518 |
+
if index == 0: # CLIP-L
|
519 |
+
logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
|
520 |
+
text_encoder.to(te_weight_dtype) # fp8
|
521 |
+
text_encoder.text_model.embeddings.to(dtype=weight_dtype)
|
522 |
+
else: # T5XXL
|
523 |
+
|
524 |
+
def prepare_fp8(text_encoder, target_dtype):
|
525 |
+
def forward_hook(module):
|
526 |
+
def forward(hidden_states):
|
527 |
+
hidden_gelu = module.act(module.wi_0(hidden_states))
|
528 |
+
hidden_linear = module.wi_1(hidden_states)
|
529 |
+
hidden_states = hidden_gelu * hidden_linear
|
530 |
+
hidden_states = module.dropout(hidden_states)
|
531 |
+
|
532 |
+
hidden_states = module.wo(hidden_states)
|
533 |
+
return hidden_states
|
534 |
+
|
535 |
+
return forward
|
536 |
+
|
537 |
+
for module in text_encoder.modules():
|
538 |
+
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
|
539 |
+
# print("set", module.__class__.__name__, "to", target_dtype)
|
540 |
+
module.to(target_dtype)
|
541 |
+
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
|
542 |
+
# print("set", module.__class__.__name__, "hooks")
|
543 |
+
module.forward = forward_hook(module)
|
544 |
+
|
545 |
+
if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
|
546 |
+
logger.info(f"T5XXL already prepared for fp8")
|
547 |
+
else:
|
548 |
+
logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
|
549 |
+
text_encoder.to(te_weight_dtype) # fp8
|
550 |
+
prepare_fp8(text_encoder, weight_dtype)
|
551 |
+
|
552 |
+
def prepare_unet_with_accelerator(
|
553 |
+
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
|
554 |
+
) -> torch.nn.Module:
|
555 |
+
if not self.is_swapping_blocks:
|
556 |
+
return super().prepare_unet_with_accelerator(args, accelerator, unet)
|
557 |
+
|
558 |
+
# if we doesn't swap blocks, we can move the model to device
|
559 |
+
flux: flux_models.Flux = unet
|
560 |
+
flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks])
|
561 |
+
accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
|
562 |
+
accelerator.unwrap_model(flux).prepare_block_swap_before_forward()
|
563 |
+
|
564 |
+
return flux
|
565 |
+
|
566 |
+
|
567 |
+
def setup_parser() -> argparse.ArgumentParser:
|
568 |
+
parser = train_network_asylora.setup_parser()
|
569 |
+
train_util.add_dit_training_arguments(parser)
|
570 |
+
flux_train_utils.add_flux_train_arguments(parser)
|
571 |
+
|
572 |
+
parser.add_argument(
|
573 |
+
"--split_mode",
|
574 |
+
action="store_true",
|
575 |
+
# help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
|
576 |
+
# + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要",
|
577 |
+
help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead."
|
578 |
+
" / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。",
|
579 |
+
)
|
580 |
+
return parser
|
581 |
+
|
582 |
+
|
583 |
+
if __name__ == "__main__":
|
584 |
+
parser = setup_parser()
|
585 |
+
|
586 |
+
args = parser.parse_args()
|
587 |
+
train_util.verify_command_line_training_args(args)
|
588 |
+
args = train_util.read_config_from_file(args, parser)
|
589 |
+
|
590 |
+
trainer = FluxNetworkTrainer()
|
591 |
+
trainer.train(args)
|
flux_train_recraft.py
ADDED
@@ -0,0 +1,713 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import copy
|
3 |
+
import math
|
4 |
+
import random
|
5 |
+
from typing import Any
|
6 |
+
import pdb
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from accelerate import Accelerator
|
10 |
+
from library.device_utils import init_ipex, clean_memory_on_device
|
11 |
+
|
12 |
+
init_ipex()
|
13 |
+
|
14 |
+
from library import flux_models, flux_train_utils_recraft as flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util
|
15 |
+
from torchvision import transforms
|
16 |
+
import train_network
|
17 |
+
from library.utils import setup_logging
|
18 |
+
from diffusers.utils import load_image
|
19 |
+
import numpy as np
|
20 |
+
from PIL import Image, ImageOps
|
21 |
+
|
22 |
+
setup_logging()
|
23 |
+
import logging
|
24 |
+
|
25 |
+
logger = logging.getLogger(__name__)
|
26 |
+
|
27 |
+
# NUM_SPLIT = 2
|
28 |
+
|
29 |
+
class ResizeWithPadding:
|
30 |
+
def __init__(self, size, fill=255):
|
31 |
+
self.size = size
|
32 |
+
self.fill = fill
|
33 |
+
|
34 |
+
def __call__(self, img):
|
35 |
+
if isinstance(img, np.ndarray):
|
36 |
+
img = Image.fromarray(img)
|
37 |
+
elif not isinstance(img, Image.Image):
|
38 |
+
raise TypeError("Input must be a PIL Image or a NumPy array")
|
39 |
+
|
40 |
+
width, height = img.size
|
41 |
+
|
42 |
+
if width == height:
|
43 |
+
img = img.resize((self.size, self.size), Image.LANCZOS)
|
44 |
+
else:
|
45 |
+
max_dim = max(width, height)
|
46 |
+
|
47 |
+
new_img = Image.new("RGB", (max_dim, max_dim), (self.fill, self.fill, self.fill))
|
48 |
+
new_img.paste(img, ((max_dim - width) // 2, (max_dim - height) // 2))
|
49 |
+
|
50 |
+
img = new_img.resize((self.size, self.size), Image.LANCZOS)
|
51 |
+
|
52 |
+
return img
|
53 |
+
|
54 |
+
class FluxNetworkTrainer(train_network.NetworkTrainer):
|
55 |
+
def __init__(self):
|
56 |
+
super().__init__()
|
57 |
+
self.sample_prompts_te_outputs = None
|
58 |
+
self.sample_conditions = None
|
59 |
+
self.is_schnell: Optional[bool] = None
|
60 |
+
|
61 |
+
def assert_extra_args(self, args, train_dataset_group):
|
62 |
+
super().assert_extra_args(args, train_dataset_group)
|
63 |
+
# sdxl_train_util.verify_sdxl_training_args(args)
|
64 |
+
|
65 |
+
if args.fp8_base_unet:
|
66 |
+
args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1
|
67 |
+
|
68 |
+
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
69 |
+
logger.warning(
|
70 |
+
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
|
71 |
+
)
|
72 |
+
args.cache_text_encoder_outputs = True
|
73 |
+
|
74 |
+
if args.cache_text_encoder_outputs:
|
75 |
+
assert (
|
76 |
+
train_dataset_group.is_text_encoder_output_cacheable()
|
77 |
+
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
78 |
+
|
79 |
+
# prepare CLIP-L/T5XXL training flags
|
80 |
+
self.train_clip_l = not args.network_train_unet_only
|
81 |
+
self.train_t5xxl = False # default is False even if args.network_train_unet_only is False
|
82 |
+
|
83 |
+
if args.max_token_length is not None:
|
84 |
+
logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
|
85 |
+
|
86 |
+
assert not args.split_mode or not args.cpu_offload_checkpointing, (
|
87 |
+
"split_mode and cpu_offload_checkpointing cannot be used together"
|
88 |
+
" / split_modeとcpu_offload_checkpointingは同時に使用できません"
|
89 |
+
)
|
90 |
+
|
91 |
+
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
|
92 |
+
|
93 |
+
def load_target_model(self, args, weight_dtype, accelerator):
|
94 |
+
# currently offload to cpu for some models
|
95 |
+
|
96 |
+
# if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
|
97 |
+
loading_dtype = None if args.fp8_base else weight_dtype
|
98 |
+
|
99 |
+
# if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
|
100 |
+
self.is_schnell, model = flux_utils.load_flow_model(
|
101 |
+
args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
|
102 |
+
)
|
103 |
+
if args.fp8_base:
|
104 |
+
# check dtype of model
|
105 |
+
if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
|
106 |
+
raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
|
107 |
+
elif model.dtype == torch.float8_e4m3fn:
|
108 |
+
logger.info("Loaded fp8 FLUX model")
|
109 |
+
|
110 |
+
if args.split_mode:
|
111 |
+
model = self.prepare_split_model(model, weight_dtype, accelerator)
|
112 |
+
|
113 |
+
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
114 |
+
clip_l.eval()
|
115 |
+
|
116 |
+
# if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
|
117 |
+
if args.fp8_base and not args.fp8_base_unet:
|
118 |
+
loading_dtype = None # as is
|
119 |
+
else:
|
120 |
+
loading_dtype = weight_dtype
|
121 |
+
|
122 |
+
# loading t5xxl to cpu takes a long time, so we should load to gpu in future
|
123 |
+
t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
124 |
+
t5xxl.eval()
|
125 |
+
if args.fp8_base and not args.fp8_base_unet:
|
126 |
+
# check dtype of model
|
127 |
+
if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
|
128 |
+
raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
|
129 |
+
elif t5xxl.dtype == torch.float8_e4m3fn:
|
130 |
+
logger.info("Loaded fp8 T5XXL model")
|
131 |
+
|
132 |
+
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
133 |
+
|
134 |
+
return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
|
135 |
+
|
136 |
+
def prepare_split_model(self, model, weight_dtype, accelerator):
|
137 |
+
from accelerate import init_empty_weights
|
138 |
+
|
139 |
+
logger.info("prepare split model")
|
140 |
+
with init_empty_weights():
|
141 |
+
flux_upper = flux_models.FluxUpper(model.params)
|
142 |
+
flux_lower = flux_models.FluxLower(model.params)
|
143 |
+
sd = model.state_dict()
|
144 |
+
|
145 |
+
# lower (trainable)
|
146 |
+
logger.info("load state dict for lower")
|
147 |
+
flux_lower.load_state_dict(sd, strict=False, assign=True)
|
148 |
+
flux_lower.to(dtype=weight_dtype)
|
149 |
+
|
150 |
+
# upper (frozen)
|
151 |
+
logger.info("load state dict for upper")
|
152 |
+
flux_upper.load_state_dict(sd, strict=False, assign=True)
|
153 |
+
|
154 |
+
logger.info("prepare upper model")
|
155 |
+
target_dtype = torch.float8_e4m3fn if args.fp8_base else weight_dtype
|
156 |
+
flux_upper.to(accelerator.device, dtype=target_dtype)
|
157 |
+
flux_upper.eval()
|
158 |
+
|
159 |
+
if args.fp8_base:
|
160 |
+
# this is required to run on fp8
|
161 |
+
flux_upper = accelerator.prepare(flux_upper)
|
162 |
+
|
163 |
+
flux_upper.to("cpu")
|
164 |
+
|
165 |
+
self.flux_upper = flux_upper
|
166 |
+
del model # we don't need model anymore
|
167 |
+
clean_memory_on_device(accelerator.device)
|
168 |
+
|
169 |
+
logger.info("split model prepared")
|
170 |
+
|
171 |
+
return flux_lower
|
172 |
+
|
173 |
+
def get_tokenize_strategy(self, args):
|
174 |
+
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
|
175 |
+
|
176 |
+
if args.t5xxl_max_token_length is None:
|
177 |
+
if is_schnell:
|
178 |
+
t5xxl_max_token_length = 256
|
179 |
+
else:
|
180 |
+
t5xxl_max_token_length = 512
|
181 |
+
else:
|
182 |
+
t5xxl_max_token_length = args.t5xxl_max_token_length
|
183 |
+
|
184 |
+
logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
|
185 |
+
return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)
|
186 |
+
|
187 |
+
def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
|
188 |
+
return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]
|
189 |
+
|
190 |
+
def get_latents_caching_strategy(self, args):
|
191 |
+
latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)
|
192 |
+
return latents_caching_strategy
|
193 |
+
|
194 |
+
def get_text_encoding_strategy(self, args):
|
195 |
+
return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)
|
196 |
+
|
197 |
+
def post_process_network(self, args, accelerator, network, text_encoders, unet):
|
198 |
+
# check t5xxl is trained or not
|
199 |
+
self.train_t5xxl = network.train_t5xxl
|
200 |
+
|
201 |
+
if self.train_t5xxl and args.cache_text_encoder_outputs:
|
202 |
+
raise ValueError(
|
203 |
+
"T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません"
|
204 |
+
)
|
205 |
+
|
206 |
+
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
|
207 |
+
if args.cache_text_encoder_outputs:
|
208 |
+
if self.train_clip_l and not self.train_t5xxl:
|
209 |
+
return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached
|
210 |
+
else:
|
211 |
+
return None # no text encoders are needed for encoding because both are cached
|
212 |
+
else:
|
213 |
+
return text_encoders # both CLIP-L and T5XXL are needed for encoding
|
214 |
+
|
215 |
+
def get_text_encoders_train_flags(self, args, text_encoders):
|
216 |
+
return [self.train_clip_l, self.train_t5xxl]
|
217 |
+
|
218 |
+
def get_text_encoder_outputs_caching_strategy(self, args):
|
219 |
+
if args.cache_text_encoder_outputs:
|
220 |
+
# if the text encoders is trained, we need tokenization, so is_partial is True
|
221 |
+
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
|
222 |
+
args.cache_text_encoder_outputs_to_disk,
|
223 |
+
args.text_encoder_batch_size,
|
224 |
+
args.skip_cache_check,
|
225 |
+
is_partial=self.train_clip_l or self.train_t5xxl,
|
226 |
+
apply_t5_attn_mask=args.apply_t5_attn_mask,
|
227 |
+
)
|
228 |
+
else:
|
229 |
+
return None
|
230 |
+
|
231 |
+
def cache_text_encoder_outputs_if_needed(
|
232 |
+
self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
|
233 |
+
):
|
234 |
+
if args.cache_text_encoder_outputs:
|
235 |
+
if not args.lowram:
|
236 |
+
# メモリ消費を減らす
|
237 |
+
logger.info("move vae and unet to cpu to save memory")
|
238 |
+
org_vae_device = vae.device
|
239 |
+
org_unet_device = unet.device
|
240 |
+
vae.to("cpu")
|
241 |
+
unet.to("cpu")
|
242 |
+
clean_memory_on_device(accelerator.device)
|
243 |
+
|
244 |
+
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
|
245 |
+
logger.info("move text encoders to gpu")
|
246 |
+
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
|
247 |
+
text_encoders[1].to(accelerator.device)
|
248 |
+
|
249 |
+
if text_encoders[1].dtype == torch.float8_e4m3fn:
|
250 |
+
# if we load fp8 weights, the model is already fp8, so we use it as is
|
251 |
+
self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
|
252 |
+
else:
|
253 |
+
# otherwise, we need to convert it to target dtype
|
254 |
+
text_encoders[1].to(weight_dtype)
|
255 |
+
|
256 |
+
with accelerator.autocast():
|
257 |
+
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
|
258 |
+
|
259 |
+
# cache sample prompts
|
260 |
+
if args.sample_prompts is not None:
|
261 |
+
logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
|
262 |
+
|
263 |
+
tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
|
264 |
+
text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
|
265 |
+
|
266 |
+
prompts = train_util.load_prompts(args.sample_prompts)
|
267 |
+
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
|
268 |
+
with accelerator.autocast(), torch.no_grad():
|
269 |
+
for prompt_dict in prompts:
|
270 |
+
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
|
271 |
+
if p not in sample_prompts_te_outputs:
|
272 |
+
logger.info(f"cache Text Encoder outputs for prompt: {p}")
|
273 |
+
tokens_and_masks = tokenize_strategy.tokenize(p)
|
274 |
+
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
|
275 |
+
tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask
|
276 |
+
)
|
277 |
+
self.sample_prompts_te_outputs = sample_prompts_te_outputs
|
278 |
+
|
279 |
+
# 添加conditions缓存逻辑
|
280 |
+
if args.sample_images is not None:
|
281 |
+
logger.info(f"cache conditions for sample images: {args.sample_images}")
|
282 |
+
|
283 |
+
# lc03lc
|
284 |
+
resize_transform = ResizeWithPadding(size=512, fill=255) if args.frame_num == 4 else ResizeWithPadding(size=352, fill=255)
|
285 |
+
img_transforms = transforms.Compose([
|
286 |
+
resize_transform,
|
287 |
+
transforms.ToTensor(),
|
288 |
+
transforms.Normalize([0.5], [0.5]),
|
289 |
+
])
|
290 |
+
|
291 |
+
if args.sample_images.endswith(".txt"):
|
292 |
+
with open(args.sample_images, "r", encoding="utf-8") as f:
|
293 |
+
lines = f.readlines()
|
294 |
+
sample_images = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
|
295 |
+
else:
|
296 |
+
raise NotImplementedError(f"sample_images file format not supported: {args.sample_images}")
|
297 |
+
|
298 |
+
prompts = train_util.load_prompts(args.sample_prompts)
|
299 |
+
conditions = {} # key: prompt, value: latents
|
300 |
+
|
301 |
+
with torch.no_grad():
|
302 |
+
for image, prompt_dict in zip(sample_images, prompts):
|
303 |
+
prompt = prompt_dict.get("prompt", "")
|
304 |
+
if prompt not in conditions:
|
305 |
+
logger.info(f"cache conditions for image: {image} with prompt: {prompt}")
|
306 |
+
image = img_transforms(np.array(load_image(image), dtype=np.uint8)).unsqueeze(0).to(vae.device, dtype=vae.dtype)
|
307 |
+
latents = self.encode_images_to_latents2(args, accelerator, vae, image)
|
308 |
+
# lc03lc
|
309 |
+
conditions[prompt] = latents
|
310 |
+
# if args.frame_num == 4:
|
311 |
+
# conditions[prompt] = latents[:,:,2*latents.shape[2]//3:latents.shape[2], 2*latents.shape[3]//3:latents.shape[3]].to("cpu")
|
312 |
+
# else:
|
313 |
+
# conditions[prompt] = latents[:,:,latents.shape[2]//2:latents.shape[2], :latents.shape[3]//2].to("cpu")
|
314 |
+
|
315 |
+
self.sample_conditions = conditions
|
316 |
+
|
317 |
+
accelerator.wait_for_everyone()
|
318 |
+
|
319 |
+
# move back to cpu
|
320 |
+
if not self.is_train_text_encoder(args):
|
321 |
+
logger.info("move CLIP-L back to cpu")
|
322 |
+
text_encoders[0].to("cpu")
|
323 |
+
logger.info("move t5XXL back to cpu")
|
324 |
+
text_encoders[1].to("cpu")
|
325 |
+
clean_memory_on_device(accelerator.device)
|
326 |
+
|
327 |
+
if not args.lowram:
|
328 |
+
logger.info("move vae and unet back to original device")
|
329 |
+
vae.to(org_vae_device)
|
330 |
+
unet.to(org_unet_device)
|
331 |
+
else:
|
332 |
+
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
|
333 |
+
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
|
334 |
+
text_encoders[1].to(accelerator.device)
|
335 |
+
|
336 |
+
# def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
337 |
+
# noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
338 |
+
|
339 |
+
# # get size embeddings
|
340 |
+
# orig_size = batch["original_sizes_hw"]
|
341 |
+
# crop_size = batch["crop_top_lefts"]
|
342 |
+
# target_size = batch["target_sizes_hw"]
|
343 |
+
# embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
|
344 |
+
|
345 |
+
# # concat embeddings
|
346 |
+
# encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
|
347 |
+
# vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
|
348 |
+
# text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
|
349 |
+
|
350 |
+
# noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
351 |
+
# return noise_pred
|
352 |
+
|
353 |
+
def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
|
354 |
+
text_encoders = text_encoder # for compatibility
|
355 |
+
text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
|
356 |
+
# 直接使用预先计算的conditions
|
357 |
+
conditions = None
|
358 |
+
if self.sample_conditions is not None:
|
359 |
+
conditions = {k: v.to(accelerator.device) for k, v in self.sample_conditions.items()}
|
360 |
+
|
361 |
+
if not args.split_mode:
|
362 |
+
flux_train_utils.sample_images(
|
363 |
+
accelerator, args, epoch, global_step, flux, ae, text_encoder, self.sample_prompts_te_outputs, None, conditions
|
364 |
+
)
|
365 |
+
return
|
366 |
+
|
367 |
+
class FluxUpperLowerWrapper(torch.nn.Module):
|
368 |
+
def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device):
|
369 |
+
super().__init__()
|
370 |
+
self.flux_upper = flux_upper
|
371 |
+
self.flux_lower = flux_lower
|
372 |
+
self.target_device = device
|
373 |
+
|
374 |
+
def prepare_block_swap_before_forward(self):
|
375 |
+
pass
|
376 |
+
|
377 |
+
def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None):
|
378 |
+
self.flux_lower.to("cpu")
|
379 |
+
clean_memory_on_device(self.target_device)
|
380 |
+
self.flux_upper.to(self.target_device)
|
381 |
+
img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask)
|
382 |
+
self.flux_upper.to("cpu")
|
383 |
+
clean_memory_on_device(self.target_device)
|
384 |
+
self.flux_lower.to(self.target_device)
|
385 |
+
return self.flux_lower(img, txt, vec, pe, txt_attention_mask)
|
386 |
+
|
387 |
+
wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device)
|
388 |
+
clean_memory_on_device(accelerator.device)
|
389 |
+
flux_train_utils.sample_images(
|
390 |
+
accelerator, args, epoch, global_step, wrapper, ae, text_encoder, self.sample_prompts_te_outputs, conditions
|
391 |
+
)
|
392 |
+
clean_memory_on_device(accelerator.device)
|
393 |
+
|
394 |
+
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
|
395 |
+
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
|
396 |
+
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
397 |
+
return noise_scheduler
|
398 |
+
|
399 |
+
def encode_images_to_latents(self, args, accelerator, vae, images):
|
400 |
+
# 获取图像尺寸
|
401 |
+
b, c, h, w = images.shape
|
402 |
+
|
403 |
+
# num_split = NUM_SPLIT
|
404 |
+
num_split = 2 if args.frame_num == 4 else 3
|
405 |
+
# 将图像分成三个部分
|
406 |
+
img_parts = [images[:,:,:,i*w//num_split:(i+1)*w//num_split] for i in range(num_split)]
|
407 |
+
# 分别编码
|
408 |
+
latents = [vae.encode(img) for img in img_parts]
|
409 |
+
# 在latent空间拼接回完整图像
|
410 |
+
latents = torch.cat(latents, dim=-1)
|
411 |
+
|
412 |
+
return latents
|
413 |
+
|
414 |
+
def encode_images_to_latents2(self, args, accelerator, vae, images):
|
415 |
+
# 获取图像尺寸
|
416 |
+
b, c, h, w = images.shape
|
417 |
+
# num_split = NUM_SPLIT
|
418 |
+
num_split = 2 if args.frame_num == 4 else 3
|
419 |
+
latents = vae.encode(images)
|
420 |
+
return latents
|
421 |
+
|
422 |
+
def encode_images_to_latents3(self, args, accelerator, vae, images):
|
423 |
+
b, c, h, w = images.shape
|
424 |
+
# Number of splits along each dimension
|
425 |
+
num_split = 3
|
426 |
+
# Check if the image can be evenly divided into 3x3 grid
|
427 |
+
assert h % num_split == 0 and w % num_split == 0, "Image dimensions must be divisible by 3."
|
428 |
+
|
429 |
+
# Height and width of each split
|
430 |
+
split_h, split_w = h // num_split, w // num_split
|
431 |
+
|
432 |
+
# Store latents for each split
|
433 |
+
latents = []
|
434 |
+
|
435 |
+
for i in range(num_split):
|
436 |
+
for j in range(num_split):
|
437 |
+
# Extract the (i, j) sub-image
|
438 |
+
img_part = images[:, :, i * split_h:(i + 1) * split_h, j * split_w:(j + 1) * split_w]
|
439 |
+
# Encode the sub-image using VAE
|
440 |
+
latent = vae.encode(img_part)
|
441 |
+
# Append the latent
|
442 |
+
latents.append(latent)
|
443 |
+
|
444 |
+
# Combine latents into a 3x3 grid in the latent space
|
445 |
+
# Latents list -> Tensor [num_split^2, b, latent_dim, h', w']
|
446 |
+
latents = torch.stack(latents, dim=0)
|
447 |
+
|
448 |
+
# Reshape into a 3x3 grid
|
449 |
+
# Shape: [num_split, num_split, b, latent_dim, h', w']
|
450 |
+
latents = latents.view(num_split, num_split, b, *latents.shape[2:])
|
451 |
+
|
452 |
+
# Combine the 3x3 grid along height and width in latent space
|
453 |
+
# Concatenate along width for each row, then concatenate rows along height
|
454 |
+
latents = torch.cat([torch.cat(latents[i], dim=-1) for i in range(num_split)], dim=-2)
|
455 |
+
|
456 |
+
# Final shape: [b, latent_dim, h', w']
|
457 |
+
return latents
|
458 |
+
|
459 |
+
def shift_scale_latents(self, args, latents):
|
460 |
+
return latents
|
461 |
+
|
462 |
+
def get_noise_pred_and_target(
|
463 |
+
self,
|
464 |
+
args,
|
465 |
+
accelerator,
|
466 |
+
noise_scheduler,
|
467 |
+
latents,
|
468 |
+
batch,
|
469 |
+
text_encoder_conds,
|
470 |
+
unet: flux_models.Flux,
|
471 |
+
network,
|
472 |
+
weight_dtype,
|
473 |
+
train_unet,
|
474 |
+
):
|
475 |
+
# Sample noise that we'll add to the latents
|
476 |
+
noise = torch.randn_like(latents)
|
477 |
+
bsz = latents.shape[0]
|
478 |
+
|
479 |
+
# get noisy model input and timesteps
|
480 |
+
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
|
481 |
+
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
|
482 |
+
)
|
483 |
+
|
484 |
+
# pack latents and get img_ids
|
485 |
+
# yiren ? need modify?
|
486 |
+
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
|
487 |
+
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
|
488 |
+
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
|
489 |
+
|
490 |
+
# get guidance
|
491 |
+
# ensure guidance_scale in args is float
|
492 |
+
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
|
493 |
+
|
494 |
+
# ensure the hidden state will require grad
|
495 |
+
if args.gradient_checkpointing:
|
496 |
+
noisy_model_input.requires_grad_(True)
|
497 |
+
for t in text_encoder_conds:
|
498 |
+
if t is not None and t.dtype.is_floating_point:
|
499 |
+
t.requires_grad_(True)
|
500 |
+
img_ids.requires_grad_(True)
|
501 |
+
guidance_vec.requires_grad_(True)
|
502 |
+
|
503 |
+
# Predict the noise residual
|
504 |
+
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
505 |
+
if not args.apply_t5_attn_mask:
|
506 |
+
t5_attn_mask = None
|
507 |
+
|
508 |
+
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
|
509 |
+
if not args.split_mode:
|
510 |
+
# normal forward
|
511 |
+
with accelerator.autocast():
|
512 |
+
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
513 |
+
model_pred = unet(
|
514 |
+
img=img,
|
515 |
+
img_ids=img_ids,
|
516 |
+
txt=t5_out,
|
517 |
+
txt_ids=txt_ids,
|
518 |
+
y=l_pooled,
|
519 |
+
timesteps=timesteps / 1000,
|
520 |
+
guidance=guidance_vec,
|
521 |
+
txt_attention_mask=t5_attn_mask,
|
522 |
+
)
|
523 |
+
else:
|
524 |
+
# split forward to reduce memory usage
|
525 |
+
assert network.train_blocks == "single", "train_blocks must be single for split mode"
|
526 |
+
with accelerator.autocast():
|
527 |
+
# move flux lower to cpu, and then move flux upper to gpu
|
528 |
+
unet.to("cpu")
|
529 |
+
clean_memory_on_device(accelerator.device)
|
530 |
+
self.flux_upper.to(accelerator.device)
|
531 |
+
|
532 |
+
# upper model does not require grad
|
533 |
+
with torch.no_grad():
|
534 |
+
intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
|
535 |
+
img=packed_noisy_model_input,
|
536 |
+
img_ids=img_ids,
|
537 |
+
txt=t5_out,
|
538 |
+
txt_ids=txt_ids,
|
539 |
+
y=l_pooled,
|
540 |
+
timesteps=timesteps / 1000,
|
541 |
+
guidance=guidance_vec,
|
542 |
+
txt_attention_mask=t5_attn_mask,
|
543 |
+
)
|
544 |
+
|
545 |
+
# move flux upper back to cpu, and then move flux lower to gpu
|
546 |
+
self.flux_upper.to("cpu")
|
547 |
+
clean_memory_on_device(accelerator.device)
|
548 |
+
unet.to(accelerator.device)
|
549 |
+
|
550 |
+
# lower model requires grad
|
551 |
+
intermediate_img.requires_grad_(True)
|
552 |
+
intermediate_txt.requires_grad_(True)
|
553 |
+
vec.requires_grad_(True)
|
554 |
+
pe.requires_grad_(True)
|
555 |
+
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
|
556 |
+
|
557 |
+
return model_pred
|
558 |
+
|
559 |
+
model_pred = call_dit(
|
560 |
+
img=packed_noisy_model_input,
|
561 |
+
img_ids=img_ids,
|
562 |
+
t5_out=t5_out,
|
563 |
+
txt_ids=txt_ids,
|
564 |
+
l_pooled=l_pooled,
|
565 |
+
timesteps=timesteps,
|
566 |
+
guidance_vec=guidance_vec,
|
567 |
+
t5_attn_mask=t5_attn_mask,
|
568 |
+
)
|
569 |
+
|
570 |
+
# unpack latents
|
571 |
+
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
|
572 |
+
|
573 |
+
# apply model prediction type
|
574 |
+
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
|
575 |
+
|
576 |
+
# flow matching loss: this is different from SD3
|
577 |
+
target = noise - latents
|
578 |
+
|
579 |
+
# differential output preservation
|
580 |
+
if "custom_attributes" in batch:
|
581 |
+
diff_output_pr_indices = []
|
582 |
+
for i, custom_attributes in enumerate(batch["custom_attributes"]):
|
583 |
+
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
|
584 |
+
diff_output_pr_indices.append(i)
|
585 |
+
|
586 |
+
if len(diff_output_pr_indices) > 0:
|
587 |
+
network.set_multiplier(0.0)
|
588 |
+
with torch.no_grad():
|
589 |
+
model_pred_prior = call_dit(
|
590 |
+
img=packed_noisy_model_input[diff_output_pr_indices],
|
591 |
+
img_ids=img_ids[diff_output_pr_indices],
|
592 |
+
t5_out=t5_out[diff_output_pr_indices],
|
593 |
+
txt_ids=txt_ids[diff_output_pr_indices],
|
594 |
+
l_pooled=l_pooled[diff_output_pr_indices],
|
595 |
+
timesteps=timesteps[diff_output_pr_indices],
|
596 |
+
guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
|
597 |
+
t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
|
598 |
+
)
|
599 |
+
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
|
600 |
+
|
601 |
+
model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width)
|
602 |
+
model_pred_prior, _ = flux_train_utils.apply_model_prediction_type(
|
603 |
+
args,
|
604 |
+
model_pred_prior,
|
605 |
+
noisy_model_input[diff_output_pr_indices],
|
606 |
+
sigmas[diff_output_pr_indices] if sigmas is not None else None,
|
607 |
+
)
|
608 |
+
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
|
609 |
+
|
610 |
+
# elimilate the loss in the left top quarter of the image
|
611 |
+
h, w = target.shape[2], target.shape[3]
|
612 |
+
# num_split = NUM_SPLIT
|
613 |
+
num_split = 2 if args.frame_num == 4 else 3
|
614 |
+
# target[:, :, :, :w//num_split] = model_pred[:, :, :, :w//num_split]
|
615 |
+
# target[:, :, :, :w//num_split] = model_pred[:, :, :, :w//num_split]
|
616 |
+
target[:, :, 2*h//num_split:h, 2*w//num_split:w] = model_pred[:, :, 2*h//num_split:h, 2*w//num_split:w]
|
617 |
+
|
618 |
+
|
619 |
+
return model_pred, target, timesteps, None, weighting
|
620 |
+
|
621 |
+
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
622 |
+
return loss
|
623 |
+
|
624 |
+
def get_sai_model_spec(self, args):
|
625 |
+
return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev")
|
626 |
+
|
627 |
+
def update_metadata(self, metadata, args):
|
628 |
+
metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
|
629 |
+
metadata["ss_weighting_scheme"] = args.weighting_scheme
|
630 |
+
metadata["ss_logit_mean"] = args.logit_mean
|
631 |
+
metadata["ss_logit_std"] = args.logit_std
|
632 |
+
metadata["ss_mode_scale"] = args.mode_scale
|
633 |
+
metadata["ss_guidance_scale"] = args.guidance_scale
|
634 |
+
metadata["ss_timestep_sampling"] = args.timestep_sampling
|
635 |
+
metadata["ss_sigmoid_scale"] = args.sigmoid_scale
|
636 |
+
metadata["ss_model_prediction_type"] = args.model_prediction_type
|
637 |
+
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
|
638 |
+
|
639 |
+
def is_text_encoder_not_needed_for_training(self, args):
|
640 |
+
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
|
641 |
+
|
642 |
+
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
|
643 |
+
if index == 0: # CLIP-L
|
644 |
+
return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
|
645 |
+
else: # T5XXL
|
646 |
+
text_encoder.encoder.embed_tokens.requires_grad_(True)
|
647 |
+
|
648 |
+
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
|
649 |
+
if index == 0: # CLIP-L
|
650 |
+
logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
|
651 |
+
text_encoder.to(te_weight_dtype) # fp8
|
652 |
+
text_encoder.text_model.embeddings.to(dtype=weight_dtype)
|
653 |
+
else: # T5XXL
|
654 |
+
|
655 |
+
def prepare_fp8(text_encoder, target_dtype):
|
656 |
+
def forward_hook(module):
|
657 |
+
def forward(hidden_states):
|
658 |
+
hidden_gelu = module.act(module.wi_0(hidden_states))
|
659 |
+
hidden_linear = module.wi_1(hidden_states)
|
660 |
+
hidden_states = hidden_gelu * hidden_linear
|
661 |
+
hidden_states = module.dropout(hidden_states)
|
662 |
+
|
663 |
+
hidden_states = module.wo(hidden_states)
|
664 |
+
return hidden_states
|
665 |
+
|
666 |
+
return forward
|
667 |
+
|
668 |
+
for module in text_encoder.modules():
|
669 |
+
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
|
670 |
+
# print("set", module.__class__.__name__, "to", target_dtype)
|
671 |
+
module.to(target_dtype)
|
672 |
+
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
|
673 |
+
# print("set", module.__class__.__name__, "hooks")
|
674 |
+
module.forward = forward_hook(module)
|
675 |
+
|
676 |
+
if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
|
677 |
+
logger.info(f"T5XXL already prepared for fp8")
|
678 |
+
else:
|
679 |
+
logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
|
680 |
+
text_encoder.to(te_weight_dtype) # fp8
|
681 |
+
prepare_fp8(text_encoder, weight_dtype)
|
682 |
+
|
683 |
+
|
684 |
+
def setup_parser() -> argparse.ArgumentParser:
|
685 |
+
parser = train_network.setup_parser()
|
686 |
+
flux_train_utils.add_flux_train_arguments(parser)
|
687 |
+
|
688 |
+
parser.add_argument(
|
689 |
+
"--split_mode",
|
690 |
+
action="store_true",
|
691 |
+
help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
|
692 |
+
+ "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要",
|
693 |
+
)
|
694 |
+
|
695 |
+
parser.add_argument(
|
696 |
+
'--frame_num',
|
697 |
+
type=int,
|
698 |
+
choices=[4, 9],
|
699 |
+
required=True,
|
700 |
+
help="The number of steps in the generated step diagram (choose 4 or 9)"
|
701 |
+
)
|
702 |
+
return parser
|
703 |
+
|
704 |
+
|
705 |
+
if __name__ == "__main__":
|
706 |
+
parser = setup_parser()
|
707 |
+
|
708 |
+
args = parser.parse_args()
|
709 |
+
train_util.verify_command_line_training_args(args)
|
710 |
+
args = train_util.read_config_from_file(args, parser)
|
711 |
+
|
712 |
+
trainer = FluxNetworkTrainer()
|
713 |
+
trainer.train(args)
|
gradio_app.py
ADDED
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# import spaces
|
2 |
+
import gradio as gr
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
from accelerate import Accelerator
|
7 |
+
import os
|
8 |
+
import time
|
9 |
+
from torchvision import transforms
|
10 |
+
from safetensors.torch import load_file
|
11 |
+
from networks import lora_flux
|
12 |
+
from library import flux_utils, flux_train_utils_recraft as flux_train_utils, strategy_flux
|
13 |
+
import logging
|
14 |
+
from huggingface_hub import login
|
15 |
+
from huggingface_hub import hf_hub_download
|
16 |
+
|
17 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
18 |
+
|
19 |
+
# Set up logger
|
20 |
+
logger = logging.getLogger(__name__)
|
21 |
+
logging.basicConfig(level=logging.DEBUG)
|
22 |
+
|
23 |
+
accelerator = Accelerator(mixed_precision='bf16', device_placement=True)
|
24 |
+
|
25 |
+
# hf_token = os.getenv("HF_TOKEN")
|
26 |
+
# login(token=hf_token)
|
27 |
+
|
28 |
+
# # Model paths dynamically retrieved using selected model
|
29 |
+
# model_paths = {
|
30 |
+
# 'Wood Sculpture': {
|
31 |
+
# 'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
|
32 |
+
# 'BASE_FILE': "flux_merge_lora/flux_merge_4f_wood-fp16.safetensors",
|
33 |
+
# 'LORA_REPO': "showlab/makeanything",
|
34 |
+
# 'LORA_FILE': "recraft/recraft_4f_wood_sculpture.safetensors",
|
35 |
+
# "Frame": 4
|
36 |
+
# },
|
37 |
+
# 'LEGO': {
|
38 |
+
# 'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
|
39 |
+
# 'BASE_FILE': "flux_merge_lora/flux_merge_9f_lego-fp16.safetensors",
|
40 |
+
# 'LORA_REPO': "showlab/makeanything",
|
41 |
+
# 'LORA_FILE': "recraft/recraft_9f_lego.safetensors",
|
42 |
+
# "Frame": 9
|
43 |
+
# },
|
44 |
+
# 'Sketch': {
|
45 |
+
# 'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
|
46 |
+
# 'BASE_FILE': "flux_merge_lora/flux_merge_9f_portrait-fp16.safetensors",
|
47 |
+
# 'LORA_REPO': "showlab/makeanything",
|
48 |
+
# 'LORA_FILE': "recraft/recraft_9f_sketch.safetensors",
|
49 |
+
# "Frame": 9
|
50 |
+
# },
|
51 |
+
# 'Portrait': {
|
52 |
+
# 'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
|
53 |
+
# 'BASE_FILE': "flux_merge_lora/flux_merge_9f_sketch-fp16.safetensors",
|
54 |
+
# 'LORA_REPO': "showlab/makeanything",
|
55 |
+
# 'LORA_FILE': "recraft/recraft_9f_portrait.safetensors",
|
56 |
+
# "Frame": 9
|
57 |
+
# }
|
58 |
+
# }
|
59 |
+
|
60 |
+
# # Common paths
|
61 |
+
# clip_repo_id = "comfyanonymous/flux_text_encoders"
|
62 |
+
# t5xxl_file = "t5xxl_fp16.safetensors"
|
63 |
+
# clip_l_file = "clip_l.safetensors"
|
64 |
+
# ae_repo_id = "black-forest-labs/FLUX.1-dev"
|
65 |
+
# ae_file = "ae.safetensors"
|
66 |
+
|
67 |
+
model_paths = {
|
68 |
+
'Wood Sculpture': {
|
69 |
+
'BASE_FLUX_CHECKPOINT': "/tiamat-NAS/songyiren/FYP/liucheng/makeanything_models/makeanything/flux_merge_lora/flux_merge_4f_wood_sculpture-fp8_e4m3fn.safetensors",
|
70 |
+
'LORA_WEIGHTS_PATH': "/tiamat-NAS/songyiren/FYP/liucheng/makeanything_models/makeanything/recraft/recraft_4f_wood_sculpture.safetensors",
|
71 |
+
'Frame': 4
|
72 |
+
},
|
73 |
+
'LEGO': {
|
74 |
+
'BASE_FLUX_CHECKPOINT': "/tiamat-NAS/songyiren/FYP/liucheng/makeanything_models/makeanything/flux_merge_lora/flux_merge_9f_lego-fp8_e4m3fn.safetensors",
|
75 |
+
'LORA_WEIGHTS_PATH': "/tiamat-NAS/songyiren/FYP/liucheng/makeanything_models/makeanything/recraft/recraft_9f_lego.safetensors",
|
76 |
+
'Frame': 9
|
77 |
+
},
|
78 |
+
'Sketch': {
|
79 |
+
'BASE_FLUX_CHECKPOINT': "/tiamat-NAS/songyiren/FYP/liucheng/makeanything_models/makeanything/flux_merge_lora/flux_merge_9f_sketch-fp8_e4m3fn.safetensors",
|
80 |
+
'LORA_WEIGHTS_PATH': "/tiamat-NAS/songyiren/FYP/liucheng/makeanything_models/makeanything/recraft/recraft_9f_sketch.safetensors",
|
81 |
+
'Frame': 9
|
82 |
+
},
|
83 |
+
'Portrait': {
|
84 |
+
'BASE_FLUX_CHECKPOINT': "/tiamat-NAS/songyiren/FYP/liucheng/makeanything_models/makeanything/flux_merge_lora/flux_merge_9f_portrait-fp8_e4m3fn.safetensors",
|
85 |
+
'LORA_WEIGHTS_PATH': "/tiamat-NAS/songyiren/FYP/liucheng/makeanything_models/makeanything/recraft/recraft_9f_portrait.safetensors",
|
86 |
+
'Frame': 9
|
87 |
+
}
|
88 |
+
}
|
89 |
+
CLIP_L_PATH = "/tiamat-NAS/hailong/storage_backup/models/stabilityai/stable-diffusion-3-medium/text_encoders/clip_l.safetensors"
|
90 |
+
T5XXL_PATH = "/tiamat-NAS/songyiren/FYP/liucheng/ComfyUI/models/clip/t5xxl_fp16.safetensors"
|
91 |
+
AE_PATH = "/tiamat-vePFS/share_data/storage/huggingface/models/black-forest-labs/FLUX.1-dev/ae.safetensors"
|
92 |
+
|
93 |
+
|
94 |
+
# Model placeholders
|
95 |
+
model = None
|
96 |
+
clip_l = None
|
97 |
+
t5xxl = None
|
98 |
+
ae = None
|
99 |
+
lora_model = None
|
100 |
+
|
101 |
+
# Function to load a file from Hugging Face Hub
|
102 |
+
def download_file(repo_id, file_name):
|
103 |
+
return hf_hub_download(repo_id=repo_id, filename=file_name)
|
104 |
+
|
105 |
+
# Load model function with dynamic paths based on the selected model
|
106 |
+
def load_target_model(selected_model):
|
107 |
+
global model, clip_l, t5xxl, ae, lora_model
|
108 |
+
model_path = model_paths[selected_model]
|
109 |
+
BASE_FLUX_CHECKPOINT = model_path['BASE_FLUX_CHECKPOINT']
|
110 |
+
LORA_WEIGHTS_PATH = model_path['LORA_WEIGHTS_PATH']
|
111 |
+
|
112 |
+
logger.info("Loading models...")
|
113 |
+
try:
|
114 |
+
if model is None is None or clip_l is None or t5xxl is None or ae is None:
|
115 |
+
clip_l = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cpu", disable_mmap=False)
|
116 |
+
clip_l.eval()
|
117 |
+
t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cpu", disable_mmap=False)
|
118 |
+
t5xxl.eval()
|
119 |
+
ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False)
|
120 |
+
logger.info("Models loaded successfully.")
|
121 |
+
# Load models
|
122 |
+
_, model = flux_utils.load_flow_model(
|
123 |
+
BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cpu", disable_mmap=False
|
124 |
+
)
|
125 |
+
# Load LoRA weights
|
126 |
+
multiplier = 1.0
|
127 |
+
weights_sd = load_file(LORA_WEIGHTS_PATH)
|
128 |
+
lora_model, _ = lora_flux.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)
|
129 |
+
lora_model.apply_to([clip_l, t5xxl], model)
|
130 |
+
info = lora_model.load_state_dict(weights_sd, strict=True)
|
131 |
+
logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
|
132 |
+
lora_model.eval()
|
133 |
+
|
134 |
+
logger.info("Models loaded successfully.")
|
135 |
+
return "Models loaded successfully. Using Recraft: {}".format(selected_model)
|
136 |
+
|
137 |
+
except Exception as e:
|
138 |
+
logger.error(f"Error loading models: {e}")
|
139 |
+
return f"Error loading models: {e}"
|
140 |
+
|
141 |
+
# Image pre-processing (resize and padding)
|
142 |
+
class ResizeWithPadding:
|
143 |
+
def __init__(self, size, fill=255):
|
144 |
+
self.size = size
|
145 |
+
self.fill = fill
|
146 |
+
|
147 |
+
def __call__(self, img):
|
148 |
+
if isinstance(img, np.ndarray):
|
149 |
+
img = Image.fromarray(img)
|
150 |
+
elif not isinstance(img, Image.Image):
|
151 |
+
raise TypeError("Input must be a PIL Image or a NumPy array")
|
152 |
+
|
153 |
+
width, height = img.size
|
154 |
+
max_dim = max(width, height)
|
155 |
+
new_img = Image.new("RGB", (max_dim, max_dim), (self.fill, self.fill, self.fill))
|
156 |
+
new_img.paste(img, ((max_dim - width) // 2, (max_dim - height) // 2))
|
157 |
+
img = new_img.resize((self.size, self.size), Image.LANCZOS)
|
158 |
+
return img
|
159 |
+
|
160 |
+
# The function to generate image from a prompt and conditional image
|
161 |
+
# @spaces.GPU(duration=180)
|
162 |
+
def infer(prompt, sample_image, recraft_model, seed=0):
|
163 |
+
global model, clip_l, t5xxl, ae, lora_model
|
164 |
+
if model is None or lora_model is None or clip_l is None or t5xxl is None or ae is None:
|
165 |
+
logger.error("Models not loaded. Please load the models first.")
|
166 |
+
return None
|
167 |
+
|
168 |
+
model_path = model_paths[recraft_model]
|
169 |
+
frame_num = model_path['Frame']
|
170 |
+
|
171 |
+
logger.info(f"Started generating image with prompt: {prompt}")
|
172 |
+
|
173 |
+
lora_model.to("cuda")
|
174 |
+
|
175 |
+
model.eval()
|
176 |
+
clip_l.eval()
|
177 |
+
t5xxl.eval()
|
178 |
+
ae.eval()
|
179 |
+
|
180 |
+
# # Load models
|
181 |
+
# model, [clip_l, t5xxl], ae = load_target_model()
|
182 |
+
|
183 |
+
# # LoRA
|
184 |
+
# multiplier = 1.0
|
185 |
+
# weights_sd = load_file(LORA_WEIGHTS_PATH)
|
186 |
+
# lora_model, _ = lora_flux.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd,
|
187 |
+
# True)
|
188 |
+
|
189 |
+
# lora_model.apply_to([clip_l, t5xxl], model)
|
190 |
+
# info = lora_model.load_state_dict(weights_sd, strict=True)
|
191 |
+
# logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
|
192 |
+
# lora_model.eval()
|
193 |
+
# lora_model.to(device)
|
194 |
+
|
195 |
+
logger.info(f"Using seed: {seed}")
|
196 |
+
|
197 |
+
# Preprocess the conditional image
|
198 |
+
resize_transform = ResizeWithPadding(size=512) if frame_num == 4 else ResizeWithPadding(size=352)
|
199 |
+
img_transforms = transforms.Compose([
|
200 |
+
resize_transform,
|
201 |
+
transforms.ToTensor(),
|
202 |
+
transforms.Normalize([0.5], [0.5]),
|
203 |
+
])
|
204 |
+
image = img_transforms(np.array(sample_image, dtype=np.uint8)).unsqueeze(0).to(
|
205 |
+
device=device,
|
206 |
+
dtype=torch.bfloat16
|
207 |
+
)
|
208 |
+
logger.debug("Conditional image preprocessed.")
|
209 |
+
|
210 |
+
# Encode the image to latents
|
211 |
+
ae.to(device)
|
212 |
+
latents = ae.encode(image)
|
213 |
+
logger.debug("Image encoded to latents.")
|
214 |
+
|
215 |
+
conditions = {}
|
216 |
+
# conditions[prompt] = latents.to("cpu")
|
217 |
+
conditions[prompt] = latents
|
218 |
+
|
219 |
+
|
220 |
+
# ae.to("cpu")
|
221 |
+
clip_l.to(device)
|
222 |
+
t5xxl.to(device)
|
223 |
+
|
224 |
+
# Encode the prompt
|
225 |
+
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(512)
|
226 |
+
text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(True)
|
227 |
+
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
228 |
+
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, True)
|
229 |
+
|
230 |
+
logger.debug("Prompt encoded.")
|
231 |
+
|
232 |
+
# Prepare the noise and other parameters
|
233 |
+
width = 1024 if frame_num == 4 else 1056
|
234 |
+
height = 1024 if frame_num == 4 else 1056
|
235 |
+
|
236 |
+
height = max(64, height - height % 16)
|
237 |
+
width = max(64, width - width % 16)
|
238 |
+
|
239 |
+
packed_latent_height = height // 16
|
240 |
+
packed_latent_width = width // 16
|
241 |
+
|
242 |
+
torch.manual_seed(seed)
|
243 |
+
noise = torch.randn(1, packed_latent_height * packed_latent_width, 16 * 2 * 2, device=device, dtype=torch.float16)
|
244 |
+
logger.debug("Noise prepared.")
|
245 |
+
|
246 |
+
# Generate the image
|
247 |
+
timesteps = flux_train_utils.get_schedule(20, noise.shape[1], shift=True) # Sample steps = 20
|
248 |
+
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(device)
|
249 |
+
|
250 |
+
t5_attn_mask = t5_attn_mask.to(device)
|
251 |
+
ae_outputs = conditions[prompt]
|
252 |
+
|
253 |
+
logger.debug("Image generation parameters set.")
|
254 |
+
|
255 |
+
args = lambda: None
|
256 |
+
args.frame_num = frame_num
|
257 |
+
|
258 |
+
# clip_l.to("cpu")
|
259 |
+
# t5xxl.to("cpu")
|
260 |
+
|
261 |
+
model.to(device)
|
262 |
+
|
263 |
+
print(f"Model device: {model.device}")
|
264 |
+
print(f"Noise device: {noise.device}")
|
265 |
+
print(f"Image IDs device: {img_ids.device}")
|
266 |
+
print(f"T5 output device: {t5_out.device}")
|
267 |
+
print(f"Text IDs device: {txt_ids.device}")
|
268 |
+
print(f"L pooled device: {l_pooled.device}")
|
269 |
+
|
270 |
+
# Run the denoising process
|
271 |
+
with accelerator.autocast(), torch.no_grad():
|
272 |
+
x = flux_train_utils.denoise(
|
273 |
+
args, model, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=1.0, t5_attn_mask=t5_attn_mask, ae_outputs=ae_outputs
|
274 |
+
)
|
275 |
+
logger.debug("Denoising process completed.")
|
276 |
+
|
277 |
+
# Decode the final image
|
278 |
+
x = x.float()
|
279 |
+
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
|
280 |
+
# model.to("cpu")
|
281 |
+
ae.to(device)
|
282 |
+
with accelerator.autocast(), torch.no_grad():
|
283 |
+
x = ae.decode(x)
|
284 |
+
logger.debug("Latents decoded into image.")
|
285 |
+
# ae.to("cpu")
|
286 |
+
|
287 |
+
# Convert the tensor to an image
|
288 |
+
x = x.clamp(-1, 1)
|
289 |
+
x = x.permute(0, 2, 3, 1)
|
290 |
+
generated_image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
|
291 |
+
|
292 |
+
logger.info("Image generation completed.")
|
293 |
+
return generated_image
|
294 |
+
|
295 |
+
# Gradio interface
|
296 |
+
with gr.Blocks() as demo:
|
297 |
+
gr.Markdown("## Recraft Generation")
|
298 |
+
|
299 |
+
with gr.Row():
|
300 |
+
with gr.Column(scale=1):
|
301 |
+
# Dropdown for selecting the recraft model
|
302 |
+
recraft_model = gr.Dropdown(
|
303 |
+
label="Select Recraft Model",
|
304 |
+
choices=["Wood Sculpture", "LEGO", "Sketch", "Portrait"],
|
305 |
+
value="Wood Sculpture"
|
306 |
+
)
|
307 |
+
|
308 |
+
# Load Model Button
|
309 |
+
load_button = gr.Button("Load Model")
|
310 |
+
|
311 |
+
with gr.Column(scale=1):
|
312 |
+
# Status message box
|
313 |
+
status_box = gr.Textbox(label="Status", placeholder="Model loading status", interactive=False, value="Model not loaded", lines=3)
|
314 |
+
|
315 |
+
with gr.Row():
|
316 |
+
with gr.Column(scale=0.5):
|
317 |
+
# Input for the prompt
|
318 |
+
prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here", lines=8)
|
319 |
+
seed = gr.Slider(0, np.iinfo(np.int32).max, step=1, label="Seed", value=42)
|
320 |
+
|
321 |
+
with gr.Column(scale=0.5):
|
322 |
+
# File upload for image
|
323 |
+
sample_image = gr.Image(label="Upload a Conditional Image", type="pil")
|
324 |
+
run_button = gr.Button("Generate Image")
|
325 |
+
|
326 |
+
with gr.Column(scale=1):
|
327 |
+
# Output result
|
328 |
+
result_image = gr.Image(label="Generated Image", interactive=False)
|
329 |
+
|
330 |
+
# Load model button action
|
331 |
+
load_button.click(fn=load_target_model, inputs=[recraft_model], outputs=[status_box])
|
332 |
+
|
333 |
+
# Run Button
|
334 |
+
run_button.click(fn=infer, inputs=[prompt, sample_image, recraft_model, seed], outputs=[result_image])
|
335 |
+
|
336 |
+
gr.Markdown("### Examples")
|
337 |
+
examples = [
|
338 |
+
[
|
339 |
+
"sks14, 2*2 puzzle of 4 sub-images, step-by-step wood sculpture carving process", # prompt
|
340 |
+
"./gradio_examples/wood_sculpture.png",
|
341 |
+
"Wood Sculpture", # recraft_model
|
342 |
+
12345 # seed
|
343 |
+
],
|
344 |
+
[
|
345 |
+
"sks1, 3*3 puzzle of 9 sub-images, step-by-step lego model construction process", # prompt
|
346 |
+
"./gradio_examples/lego.png",
|
347 |
+
"LEGO", # recraft_model
|
348 |
+
42 # seed
|
349 |
+
],
|
350 |
+
[
|
351 |
+
"sks6, 3*3 puzzle of 9 sub-images, step-by-step portrait painting process", # prompt
|
352 |
+
"./gradio_examples/portrait.png",
|
353 |
+
"Portrait", # recraft_model
|
354 |
+
999 # seed
|
355 |
+
],
|
356 |
+
[
|
357 |
+
"sks10, 3*3 puzzle of 9 sub-images, step-by-step sketch painting process,", # prompt
|
358 |
+
"./gradio_examples/sketch.png",
|
359 |
+
"Sketch",
|
360 |
+
2023
|
361 |
+
]
|
362 |
+
]
|
363 |
+
|
364 |
+
gr.Examples(
|
365 |
+
examples=examples,
|
366 |
+
inputs=[prompt, sample_image, recraft_model, seed],
|
367 |
+
outputs=[result_image],
|
368 |
+
cache_examples=False
|
369 |
+
)
|
370 |
+
|
371 |
+
# Launch the Gradio app
|
372 |
+
demo.launch(server_port=8289, server_name="0.0.0.0", share=True)
|
gradio_app_asy.py
ADDED
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# import spaces
|
2 |
+
import gradio as gr
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
from accelerate import Accelerator
|
7 |
+
import os
|
8 |
+
import time
|
9 |
+
import math
|
10 |
+
import json
|
11 |
+
from torchvision import transforms
|
12 |
+
from safetensors.torch import load_file
|
13 |
+
from networks import asylora_flux as lora_flux
|
14 |
+
from library import flux_utils, strategy_flux
|
15 |
+
import flux_minimal_inference_asylora as flux_train_utils
|
16 |
+
import logging
|
17 |
+
from huggingface_hub import login
|
18 |
+
from huggingface_hub import hf_hub_download
|
19 |
+
|
20 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
21 |
+
|
22 |
+
# Set up logger
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
logging.basicConfig(level=logging.DEBUG)
|
25 |
+
|
26 |
+
accelerator = Accelerator(mixed_precision='bf16', device_placement=True)
|
27 |
+
|
28 |
+
# hf_token = os.getenv("HF_TOKEN")
|
29 |
+
# login(token=hf_token)
|
30 |
+
|
31 |
+
# # Model paths dynamically retrieved using selected model
|
32 |
+
# model_paths = {
|
33 |
+
# 'Wood Sculpture': {
|
34 |
+
# 'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
|
35 |
+
# 'BASE_FILE': "flux_merge_lora/flux_merge_4f_wood-fp16.safetensors",
|
36 |
+
# 'LORA_REPO': "showlab/makeanything",
|
37 |
+
# 'LORA_FILE': "recraft/recraft_4f_wood_sculpture.safetensors",
|
38 |
+
# "Frame": 4
|
39 |
+
# },
|
40 |
+
# 'LEGO': {
|
41 |
+
# 'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
|
42 |
+
# 'BASE_FILE': "flux_merge_lora/flux_merge_9f_lego-fp16.safetensors",
|
43 |
+
# 'LORA_REPO': "showlab/makeanything",
|
44 |
+
# 'LORA_FILE': "recraft/recraft_9f_lego.safetensors",
|
45 |
+
# "Frame": 9
|
46 |
+
# },
|
47 |
+
# 'Sketch': {
|
48 |
+
# 'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
|
49 |
+
# 'BASE_FILE': "flux_merge_lora/flux_merge_9f_portrait-fp16.safetensors",
|
50 |
+
# 'LORA_REPO': "showlab/makeanything",
|
51 |
+
# 'LORA_FILE': "recraft/recraft_9f_sketch.safetensors",
|
52 |
+
# "Frame": 9
|
53 |
+
# },
|
54 |
+
# 'Portrait': {
|
55 |
+
# 'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
|
56 |
+
# 'BASE_FILE': "flux_merge_lora/flux_merge_9f_sketch-fp16.safetensors",
|
57 |
+
# 'LORA_REPO': "showlab/makeanything",
|
58 |
+
# 'LORA_FILE': "recraft/recraft_9f_portrait.safetensors",
|
59 |
+
# "Frame": 9
|
60 |
+
# }
|
61 |
+
# }
|
62 |
+
|
63 |
+
# # Common paths
|
64 |
+
# clip_repo_id = "comfyanonymous/flux_text_encoders"
|
65 |
+
# t5xxl_file = "t5xxl_fp16.safetensors"
|
66 |
+
# clip_l_file = "clip_l.safetensors"
|
67 |
+
# ae_repo_id = "black-forest-labs/FLUX.1-dev"
|
68 |
+
# ae_file = "ae.safetensors"
|
69 |
+
|
70 |
+
domain_index = {
|
71 |
+
'LEGO': 1, 'Cook': 2, 'Painting': 3, 'Icon': 4, 'Landscape illustration': 5,
|
72 |
+
'Portrait': 6, 'Transformer': 7, 'Sand art': 8, 'Illustration': 9, 'Sketch': 10,
|
73 |
+
'Clay toys': 11, 'Clay sculpture': 12, 'Zbrush Modeling': 13, 'Wood sculpture': 14,
|
74 |
+
'Ink painting': 15, 'Pencil sketch': 16, 'Fabric toys': 17, 'Oil painting': 18,
|
75 |
+
'Jade Carving': 19, 'Line draw': 20, 'Emoji': 21
|
76 |
+
}
|
77 |
+
|
78 |
+
lora_paths = {
|
79 |
+
"9 frame": "/tiamat-NAS/songyiren/FYP/liucheng/makeanything_models/makeanything/asymmetric_lora/asymmetric_lora_9f_general.safetensors",
|
80 |
+
"4 frame": "/tiamat-NAS/songyiren/FYP/liucheng/makeanything_models/makeanything/asymmetric_lora/asymmetric_lora_4f_general.safetensors"
|
81 |
+
}
|
82 |
+
BASE_FLUX_CHECKPOINT = "/tiamat-NAS/songyiren/FYP/liucheng/ComfyUI/models/unet/flux1-dev-fp8.safetensors"
|
83 |
+
# LORA_WEIGHTS_PATH="/tiamat-NAS/songyiren/FYP/liucheng/makeanything_models/makeanything/asymmetric_lora/asymmetric_lora_9f_general.safetensors"
|
84 |
+
CLIP_L_PATH = "/tiamat-NAS/hailong/storage_backup/models/stabilityai/stable-diffusion-3-medium/text_encoders/clip_l.safetensors"
|
85 |
+
T5XXL_PATH = "/tiamat-NAS/songyiren/FYP/liucheng/ComfyUI/models/clip/t5xxl_fp8_e4m3fn.safetensors"
|
86 |
+
AE_PATH = "/tiamat-vePFS/share_data/storage/huggingface/models/black-forest-labs/FLUX.1-dev/ae.safetensors"
|
87 |
+
|
88 |
+
|
89 |
+
# Model placeholders
|
90 |
+
model = None
|
91 |
+
clip_l = None
|
92 |
+
t5xxl = None
|
93 |
+
ae = None
|
94 |
+
lora_model = None
|
95 |
+
|
96 |
+
# Function to load a file from Hugging Face Hub
|
97 |
+
def download_file(repo_id, file_name):
|
98 |
+
return hf_hub_download(repo_id=repo_id, filename=file_name)
|
99 |
+
|
100 |
+
# Load model function with dynamic paths based on the selected model
|
101 |
+
def load_target_model(frame, domain):
|
102 |
+
global model, clip_l, t5xxl, ae, lora_model
|
103 |
+
|
104 |
+
logger.info("Loading models...")
|
105 |
+
# try:
|
106 |
+
if model is None is None or clip_l is None or t5xxl is None or ae is None:
|
107 |
+
clip_l = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cpu", disable_mmap=False)
|
108 |
+
clip_l.eval()
|
109 |
+
t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cpu", disable_mmap=False)
|
110 |
+
t5xxl.eval()
|
111 |
+
ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False)
|
112 |
+
logger.info("Models loaded successfully.")
|
113 |
+
# Load models
|
114 |
+
_, model = flux_utils.load_flow_model(
|
115 |
+
BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cpu", disable_mmap=False
|
116 |
+
)
|
117 |
+
# Load LoRA weights
|
118 |
+
LORA_WEIGHTS_PATH = lora_paths[frame]
|
119 |
+
multiplier = 1.0
|
120 |
+
weights_sd = load_file(LORA_WEIGHTS_PATH)
|
121 |
+
lora_ups_num = 10 if frame=="9 frame" else 21
|
122 |
+
lora_model, _ = lora_flux.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True, lora_ups_num=lora_ups_num)
|
123 |
+
for sub_lora in lora_model.unet_loras:
|
124 |
+
sub_lora.set_lora_up_cur(domain_index[domain]-1)
|
125 |
+
|
126 |
+
lora_model.apply_to([clip_l, t5xxl], model)
|
127 |
+
info = lora_model.load_state_dict(weights_sd, strict=True)
|
128 |
+
logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
|
129 |
+
lora_model.eval()
|
130 |
+
|
131 |
+
logger.info("Models loaded successfully.")
|
132 |
+
return "Models loaded successfully. Using Frame: {}, Damain: {}".format(frame, domain)
|
133 |
+
|
134 |
+
# except Exception as e:
|
135 |
+
# logger.error(f"Error loading models: {e}")
|
136 |
+
# return f"Error loading models: {e}"
|
137 |
+
|
138 |
+
# The function to generate image from a prompt and conditional image
|
139 |
+
# @spaces.GPU(duration=180)
|
140 |
+
def infer(prompt, frame, seed=0):
|
141 |
+
global model, clip_l, t5xxl, ae, lora_model
|
142 |
+
if model is None or lora_model is None or clip_l is None or t5xxl is None or ae is None:
|
143 |
+
logger.error("Models not loaded. Please load the models first.")
|
144 |
+
return None
|
145 |
+
|
146 |
+
frame_num = int(frame[0:1])
|
147 |
+
|
148 |
+
logger.info(f"Started generating image with prompt: {prompt}")
|
149 |
+
|
150 |
+
lora_model.to("cuda")
|
151 |
+
|
152 |
+
model.eval()
|
153 |
+
clip_l.eval()
|
154 |
+
t5xxl.eval()
|
155 |
+
ae.eval()
|
156 |
+
|
157 |
+
logger.info(f"Using seed: {seed}")
|
158 |
+
|
159 |
+
# ae.to("cpu")
|
160 |
+
clip_l.to(device)
|
161 |
+
t5xxl.to(device)
|
162 |
+
|
163 |
+
# Encode the prompt
|
164 |
+
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(512)
|
165 |
+
text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(True)
|
166 |
+
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
167 |
+
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, True)
|
168 |
+
|
169 |
+
logger.debug("Prompt encoded.")
|
170 |
+
|
171 |
+
# Prepare the noise and other parameters
|
172 |
+
width = 1024 if frame_num == 4 else 1056
|
173 |
+
height = 1024 if frame_num == 4 else 1056
|
174 |
+
|
175 |
+
packed_latent_height, packed_latent_width = math.ceil(height / 16), math.ceil(width / 16)
|
176 |
+
|
177 |
+
torch.manual_seed(seed)
|
178 |
+
noise = torch.randn(1, packed_latent_height * packed_latent_width, 16 * 2 * 2, device=device, dtype=torch.float16)
|
179 |
+
logger.debug("Noise prepared.")
|
180 |
+
|
181 |
+
|
182 |
+
# Generate the image
|
183 |
+
timesteps = flux_train_utils.get_schedule(20, noise.shape[1], shift=True) # Sample steps = 20
|
184 |
+
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(device)
|
185 |
+
|
186 |
+
t5_attn_mask = t5_attn_mask.to(device)
|
187 |
+
|
188 |
+
logger.debug("Image generation parameters set.")
|
189 |
+
|
190 |
+
args = lambda: None
|
191 |
+
args.frame_num = frame_num
|
192 |
+
|
193 |
+
# clip_l.to("cpu")
|
194 |
+
# t5xxl.to("cpu")
|
195 |
+
|
196 |
+
model.to(device)
|
197 |
+
|
198 |
+
print(f"Model device: {model.device}")
|
199 |
+
print(f"Noise device: {noise.device}")
|
200 |
+
print(f"Image IDs device: {img_ids.device}")
|
201 |
+
print(f"T5 output device: {t5_out.device}")
|
202 |
+
print(f"Text IDs device: {txt_ids.device}")
|
203 |
+
print(f"L pooled device: {l_pooled.device}")
|
204 |
+
|
205 |
+
# Run the denoising process
|
206 |
+
with accelerator.autocast(), torch.no_grad():
|
207 |
+
x = flux_train_utils.denoise(
|
208 |
+
model,
|
209 |
+
noise,
|
210 |
+
img_ids,
|
211 |
+
t5_out,
|
212 |
+
txt_ids,
|
213 |
+
l_pooled,
|
214 |
+
timesteps,
|
215 |
+
guidance=4.0,
|
216 |
+
t5_attn_mask=t5_attn_mask,
|
217 |
+
cfg_scale=1.0,
|
218 |
+
)
|
219 |
+
logger.debug("Denoising process completed.")
|
220 |
+
|
221 |
+
# Decode the final image
|
222 |
+
x = x.float()
|
223 |
+
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
|
224 |
+
# model.to("cpu")
|
225 |
+
ae.to(device)
|
226 |
+
with accelerator.autocast(), torch.no_grad():
|
227 |
+
x = ae.decode(x)
|
228 |
+
logger.debug("Latents decoded into image.")
|
229 |
+
# ae.to("cpu")
|
230 |
+
|
231 |
+
# Convert the tensor to an image
|
232 |
+
x = x.clamp(-1, 1)
|
233 |
+
x = x.permute(0, 2, 3, 1)
|
234 |
+
generated_image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
|
235 |
+
|
236 |
+
logger.info("Image generation completed.")
|
237 |
+
return generated_image
|
238 |
+
|
239 |
+
def update_domains(floor):
|
240 |
+
domains_dict = {
|
241 |
+
"4 frame": [
|
242 |
+
"LEGO", "Cook", "Painting", "Icon", "Landscape illustration",
|
243 |
+
"Portrait", "Transformer", "Sand art", "Illustration", "Sketch",
|
244 |
+
"Clay toys", "Clay sculpture", "Zbrush Modeling", "Wood sculpture", "Ink painting",
|
245 |
+
"Pencil sketch", "Fabric toys", "Oil painting", "Jade Carving", "Line draw", "Emoji"
|
246 |
+
],
|
247 |
+
"9 frame": [
|
248 |
+
"LEGO", "Cook", "Painting", "Icon", "Landscape illustration",
|
249 |
+
"Portrait", "Transformer", "Sand art", "Illustration", "Sketch"
|
250 |
+
]
|
251 |
+
}
|
252 |
+
return gr.Dropdown.update(choices=domains_dict[floor], label="Select Domains")
|
253 |
+
|
254 |
+
# Gradio interface
|
255 |
+
with gr.Blocks() as demo:
|
256 |
+
gr.Markdown("## Asymmertric LoRA Generation")
|
257 |
+
|
258 |
+
with gr.Row():
|
259 |
+
with gr.Column(scale=1):
|
260 |
+
with gr.Row():
|
261 |
+
with gr.Column(scale=1):
|
262 |
+
frame_selector = gr.Radio(choices=["4 frame", "9 frame"], label="Select Floor")
|
263 |
+
with gr.Column(scale=2):
|
264 |
+
domain_selector = gr.Dropdown(choices=[], label="Select Domains")
|
265 |
+
|
266 |
+
# Load Model Button
|
267 |
+
load_button = gr.Button("Load Model")
|
268 |
+
|
269 |
+
with gr.Column(scale=1):
|
270 |
+
# Status message box
|
271 |
+
status_box = gr.Textbox(label="Status", placeholder="Model loading status", interactive=False, value="Model not loaded", lines=3)
|
272 |
+
|
273 |
+
with gr.Row():
|
274 |
+
with gr.Column(scale=1):
|
275 |
+
# Input for the prompt
|
276 |
+
prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here", lines=8)
|
277 |
+
with gr.Row():
|
278 |
+
seed = gr.Slider(0, np.iinfo(np.int32).max, step=1, label="Seed", value=42)
|
279 |
+
run_button = gr.Button("Generate Image")
|
280 |
+
|
281 |
+
with gr.Column(scale=1):
|
282 |
+
# Output result
|
283 |
+
result_image = gr.Image(label="Generated Image", interactive=False)
|
284 |
+
|
285 |
+
frame_selector.change(update_domains, inputs=frame_selector, outputs=domain_selector)
|
286 |
+
|
287 |
+
# Load model button action
|
288 |
+
load_button.click(fn=load_target_model, inputs=[frame_selector, domain_selector], outputs=[status_box])
|
289 |
+
|
290 |
+
# Run Button
|
291 |
+
run_button.click(fn=infer, inputs=[prompt, frame_selector, seed], outputs=[result_image])
|
292 |
+
|
293 |
+
# gr.Markdown("### Examples")
|
294 |
+
# examples = [
|
295 |
+
# [
|
296 |
+
# "sks14, 2*2 puzzle of 4 sub-images, step-by-step wood sculpture carving process", # prompt
|
297 |
+
# "./gradio_examples/wood_sculpture.png",
|
298 |
+
# "Wood Sculpture", # recraft_model
|
299 |
+
# 12345 # seed
|
300 |
+
# ],
|
301 |
+
# [
|
302 |
+
# "sks1, 3*3 puzzle of 9 sub-images, step-by-step lego model construction process", # prompt
|
303 |
+
# "./gradio_examples/lego.png",
|
304 |
+
# "LEGO", # recraft_model
|
305 |
+
# 42 # seed
|
306 |
+
# ],
|
307 |
+
# [
|
308 |
+
# "sks6, 3*3 puzzle of 9 sub-images, step-by-step portrait painting process", # prompt
|
309 |
+
# "./gradio_examples/portrait.png",
|
310 |
+
# "Portrait", # recraft_model
|
311 |
+
# 999 # seed
|
312 |
+
# ],
|
313 |
+
# [
|
314 |
+
# "sks10, 3*3 puzzle of 9 sub-images, step-by-step sketch painting process,", # prompt
|
315 |
+
# "./gradio_examples/sketch.png",
|
316 |
+
# "Sketch",
|
317 |
+
# 2023
|
318 |
+
# ]
|
319 |
+
# ]
|
320 |
+
|
321 |
+
# gr.Examples(
|
322 |
+
# examples=examples,
|
323 |
+
# inputs=[prompt, sample_image, recraft_model, seed],
|
324 |
+
# outputs=[result_image],
|
325 |
+
# cache_examples=False
|
326 |
+
# )
|
327 |
+
|
328 |
+
# Launch the Gradio app
|
329 |
+
demo.launch(server_port=8289, server_name="0.0.0.0", share=True)
|
library/__init__.py
ADDED
File without changes
|
library/adafactor_fused.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from transformers import Adafactor
|
4 |
+
|
5 |
+
# stochastic rounding for bfloat16
|
6 |
+
# The implementation was provided by 2kpr. Thank you very much!
|
7 |
+
|
8 |
+
def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
|
9 |
+
"""
|
10 |
+
copies source into target using stochastic rounding
|
11 |
+
|
12 |
+
Args:
|
13 |
+
target: the target tensor with dtype=bfloat16
|
14 |
+
source: the target tensor with dtype=float32
|
15 |
+
"""
|
16 |
+
# create a random 16 bit integer
|
17 |
+
result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
|
18 |
+
|
19 |
+
# add the random number to the lower 16 bit of the mantissa
|
20 |
+
result.add_(source.view(dtype=torch.int32))
|
21 |
+
|
22 |
+
# mask off the lower 16 bit of the mantissa
|
23 |
+
result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32
|
24 |
+
|
25 |
+
# copy the higher 16 bit into the target tensor
|
26 |
+
target.copy_(result.view(dtype=torch.float32))
|
27 |
+
|
28 |
+
del result
|
29 |
+
|
30 |
+
|
31 |
+
@torch.no_grad()
|
32 |
+
def adafactor_step_param(self, p, group):
|
33 |
+
if p.grad is None:
|
34 |
+
return
|
35 |
+
grad = p.grad
|
36 |
+
if grad.dtype in {torch.float16, torch.bfloat16}:
|
37 |
+
grad = grad.float()
|
38 |
+
if grad.is_sparse:
|
39 |
+
raise RuntimeError("Adafactor does not support sparse gradients.")
|
40 |
+
|
41 |
+
state = self.state[p]
|
42 |
+
grad_shape = grad.shape
|
43 |
+
|
44 |
+
factored, use_first_moment = Adafactor._get_options(group, grad_shape)
|
45 |
+
# State Initialization
|
46 |
+
if len(state) == 0:
|
47 |
+
state["step"] = 0
|
48 |
+
|
49 |
+
if use_first_moment:
|
50 |
+
# Exponential moving average of gradient values
|
51 |
+
state["exp_avg"] = torch.zeros_like(grad)
|
52 |
+
if factored:
|
53 |
+
state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
|
54 |
+
state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
|
55 |
+
else:
|
56 |
+
state["exp_avg_sq"] = torch.zeros_like(grad)
|
57 |
+
|
58 |
+
state["RMS"] = 0
|
59 |
+
else:
|
60 |
+
if use_first_moment:
|
61 |
+
state["exp_avg"] = state["exp_avg"].to(grad)
|
62 |
+
if factored:
|
63 |
+
state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
|
64 |
+
state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
|
65 |
+
else:
|
66 |
+
state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
|
67 |
+
|
68 |
+
p_data_fp32 = p
|
69 |
+
if p.dtype in {torch.float16, torch.bfloat16}:
|
70 |
+
p_data_fp32 = p_data_fp32.float()
|
71 |
+
|
72 |
+
state["step"] += 1
|
73 |
+
state["RMS"] = Adafactor._rms(p_data_fp32)
|
74 |
+
lr = Adafactor._get_lr(group, state)
|
75 |
+
|
76 |
+
beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
|
77 |
+
update = (grad**2) + group["eps"][0]
|
78 |
+
if factored:
|
79 |
+
exp_avg_sq_row = state["exp_avg_sq_row"]
|
80 |
+
exp_avg_sq_col = state["exp_avg_sq_col"]
|
81 |
+
|
82 |
+
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
|
83 |
+
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
|
84 |
+
|
85 |
+
# Approximation of exponential moving average of square of gradient
|
86 |
+
update = Adafactor._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
|
87 |
+
update.mul_(grad)
|
88 |
+
else:
|
89 |
+
exp_avg_sq = state["exp_avg_sq"]
|
90 |
+
|
91 |
+
exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
|
92 |
+
update = exp_avg_sq.rsqrt().mul_(grad)
|
93 |
+
|
94 |
+
update.div_((Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
|
95 |
+
update.mul_(lr)
|
96 |
+
|
97 |
+
if use_first_moment:
|
98 |
+
exp_avg = state["exp_avg"]
|
99 |
+
exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
|
100 |
+
update = exp_avg
|
101 |
+
|
102 |
+
if group["weight_decay"] != 0:
|
103 |
+
p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
|
104 |
+
|
105 |
+
p_data_fp32.add_(-update)
|
106 |
+
|
107 |
+
# if p.dtype in {torch.float16, torch.bfloat16}:
|
108 |
+
# p.copy_(p_data_fp32)
|
109 |
+
|
110 |
+
if p.dtype == torch.bfloat16:
|
111 |
+
copy_stochastic_(p, p_data_fp32)
|
112 |
+
elif p.dtype == torch.float16:
|
113 |
+
p.copy_(p_data_fp32)
|
114 |
+
|
115 |
+
|
116 |
+
@torch.no_grad()
|
117 |
+
def adafactor_step(self, closure=None):
|
118 |
+
"""
|
119 |
+
Performs a single optimization step
|
120 |
+
|
121 |
+
Arguments:
|
122 |
+
closure (callable, optional): A closure that reevaluates the model
|
123 |
+
and returns the loss.
|
124 |
+
"""
|
125 |
+
loss = None
|
126 |
+
if closure is not None:
|
127 |
+
loss = closure()
|
128 |
+
|
129 |
+
for group in self.param_groups:
|
130 |
+
for p in group["params"]:
|
131 |
+
adafactor_step_param(self, p, group)
|
132 |
+
|
133 |
+
return loss
|
134 |
+
|
135 |
+
|
136 |
+
def patch_adafactor_fused(optimizer: Adafactor):
|
137 |
+
optimizer.step_param = adafactor_step_param.__get__(optimizer)
|
138 |
+
optimizer.step = adafactor_step.__get__(optimizer)
|
library/attention_processors.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Any
|
3 |
+
from einops import rearrange
|
4 |
+
import torch
|
5 |
+
from diffusers.models.attention_processor import Attention
|
6 |
+
|
7 |
+
|
8 |
+
# flash attention forwards and backwards
|
9 |
+
|
10 |
+
# https://arxiv.org/abs/2205.14135
|
11 |
+
|
12 |
+
EPSILON = 1e-6
|
13 |
+
|
14 |
+
|
15 |
+
class FlashAttentionFunction(torch.autograd.function.Function):
|
16 |
+
@staticmethod
|
17 |
+
@torch.no_grad()
|
18 |
+
def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
|
19 |
+
"""Algorithm 2 in the paper"""
|
20 |
+
|
21 |
+
device = q.device
|
22 |
+
dtype = q.dtype
|
23 |
+
max_neg_value = -torch.finfo(q.dtype).max
|
24 |
+
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
25 |
+
|
26 |
+
o = torch.zeros_like(q)
|
27 |
+
all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
|
28 |
+
all_row_maxes = torch.full(
|
29 |
+
(*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device
|
30 |
+
)
|
31 |
+
|
32 |
+
scale = q.shape[-1] ** -0.5
|
33 |
+
|
34 |
+
if mask is None:
|
35 |
+
mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
|
36 |
+
else:
|
37 |
+
mask = rearrange(mask, "b n -> b 1 1 n")
|
38 |
+
mask = mask.split(q_bucket_size, dim=-1)
|
39 |
+
|
40 |
+
row_splits = zip(
|
41 |
+
q.split(q_bucket_size, dim=-2),
|
42 |
+
o.split(q_bucket_size, dim=-2),
|
43 |
+
mask,
|
44 |
+
all_row_sums.split(q_bucket_size, dim=-2),
|
45 |
+
all_row_maxes.split(q_bucket_size, dim=-2),
|
46 |
+
)
|
47 |
+
|
48 |
+
for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
|
49 |
+
q_start_index = ind * q_bucket_size - qk_len_diff
|
50 |
+
|
51 |
+
col_splits = zip(
|
52 |
+
k.split(k_bucket_size, dim=-2),
|
53 |
+
v.split(k_bucket_size, dim=-2),
|
54 |
+
)
|
55 |
+
|
56 |
+
for k_ind, (kc, vc) in enumerate(col_splits):
|
57 |
+
k_start_index = k_ind * k_bucket_size
|
58 |
+
|
59 |
+
attn_weights = (
|
60 |
+
torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
|
61 |
+
)
|
62 |
+
|
63 |
+
if row_mask is not None:
|
64 |
+
attn_weights.masked_fill_(~row_mask, max_neg_value)
|
65 |
+
|
66 |
+
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
67 |
+
causal_mask = torch.ones(
|
68 |
+
(qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
|
69 |
+
).triu(q_start_index - k_start_index + 1)
|
70 |
+
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
71 |
+
|
72 |
+
block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
|
73 |
+
attn_weights -= block_row_maxes
|
74 |
+
exp_weights = torch.exp(attn_weights)
|
75 |
+
|
76 |
+
if row_mask is not None:
|
77 |
+
exp_weights.masked_fill_(~row_mask, 0.0)
|
78 |
+
|
79 |
+
block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(
|
80 |
+
min=EPSILON
|
81 |
+
)
|
82 |
+
|
83 |
+
new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
|
84 |
+
|
85 |
+
exp_values = torch.einsum(
|
86 |
+
"... i j, ... j d -> ... i d", exp_weights, vc
|
87 |
+
)
|
88 |
+
|
89 |
+
exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
|
90 |
+
exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
|
91 |
+
|
92 |
+
new_row_sums = (
|
93 |
+
exp_row_max_diff * row_sums
|
94 |
+
+ exp_block_row_max_diff * block_row_sums
|
95 |
+
)
|
96 |
+
|
97 |
+
oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_(
|
98 |
+
(exp_block_row_max_diff / new_row_sums) * exp_values
|
99 |
+
)
|
100 |
+
|
101 |
+
row_maxes.copy_(new_row_maxes)
|
102 |
+
row_sums.copy_(new_row_sums)
|
103 |
+
|
104 |
+
ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
|
105 |
+
ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
|
106 |
+
|
107 |
+
return o
|
108 |
+
|
109 |
+
@staticmethod
|
110 |
+
@torch.no_grad()
|
111 |
+
def backward(ctx, do):
|
112 |
+
"""Algorithm 4 in the paper"""
|
113 |
+
|
114 |
+
causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
|
115 |
+
q, k, v, o, l, m = ctx.saved_tensors
|
116 |
+
|
117 |
+
device = q.device
|
118 |
+
|
119 |
+
max_neg_value = -torch.finfo(q.dtype).max
|
120 |
+
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
121 |
+
|
122 |
+
dq = torch.zeros_like(q)
|
123 |
+
dk = torch.zeros_like(k)
|
124 |
+
dv = torch.zeros_like(v)
|
125 |
+
|
126 |
+
row_splits = zip(
|
127 |
+
q.split(q_bucket_size, dim=-2),
|
128 |
+
o.split(q_bucket_size, dim=-2),
|
129 |
+
do.split(q_bucket_size, dim=-2),
|
130 |
+
mask,
|
131 |
+
l.split(q_bucket_size, dim=-2),
|
132 |
+
m.split(q_bucket_size, dim=-2),
|
133 |
+
dq.split(q_bucket_size, dim=-2),
|
134 |
+
)
|
135 |
+
|
136 |
+
for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
|
137 |
+
q_start_index = ind * q_bucket_size - qk_len_diff
|
138 |
+
|
139 |
+
col_splits = zip(
|
140 |
+
k.split(k_bucket_size, dim=-2),
|
141 |
+
v.split(k_bucket_size, dim=-2),
|
142 |
+
dk.split(k_bucket_size, dim=-2),
|
143 |
+
dv.split(k_bucket_size, dim=-2),
|
144 |
+
)
|
145 |
+
|
146 |
+
for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
|
147 |
+
k_start_index = k_ind * k_bucket_size
|
148 |
+
|
149 |
+
attn_weights = (
|
150 |
+
torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
|
151 |
+
)
|
152 |
+
|
153 |
+
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
154 |
+
causal_mask = torch.ones(
|
155 |
+
(qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
|
156 |
+
).triu(q_start_index - k_start_index + 1)
|
157 |
+
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
158 |
+
|
159 |
+
exp_attn_weights = torch.exp(attn_weights - mc)
|
160 |
+
|
161 |
+
if row_mask is not None:
|
162 |
+
exp_attn_weights.masked_fill_(~row_mask, 0.0)
|
163 |
+
|
164 |
+
p = exp_attn_weights / lc
|
165 |
+
|
166 |
+
dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
|
167 |
+
dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
|
168 |
+
|
169 |
+
D = (doc * oc).sum(dim=-1, keepdims=True)
|
170 |
+
ds = p * scale * (dp - D)
|
171 |
+
|
172 |
+
dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
|
173 |
+
dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
|
174 |
+
|
175 |
+
dqc.add_(dq_chunk)
|
176 |
+
dkc.add_(dk_chunk)
|
177 |
+
dvc.add_(dv_chunk)
|
178 |
+
|
179 |
+
return dq, dk, dv, None, None, None, None
|
180 |
+
|
181 |
+
|
182 |
+
class FlashAttnProcessor:
|
183 |
+
def __call__(
|
184 |
+
self,
|
185 |
+
attn: Attention,
|
186 |
+
hidden_states,
|
187 |
+
encoder_hidden_states=None,
|
188 |
+
attention_mask=None,
|
189 |
+
) -> Any:
|
190 |
+
q_bucket_size = 512
|
191 |
+
k_bucket_size = 1024
|
192 |
+
|
193 |
+
h = attn.heads
|
194 |
+
q = attn.to_q(hidden_states)
|
195 |
+
|
196 |
+
encoder_hidden_states = (
|
197 |
+
encoder_hidden_states
|
198 |
+
if encoder_hidden_states is not None
|
199 |
+
else hidden_states
|
200 |
+
)
|
201 |
+
encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype)
|
202 |
+
|
203 |
+
if hasattr(attn, "hypernetwork") and attn.hypernetwork is not None:
|
204 |
+
context_k, context_v = attn.hypernetwork.forward(
|
205 |
+
hidden_states, encoder_hidden_states
|
206 |
+
)
|
207 |
+
context_k = context_k.to(hidden_states.dtype)
|
208 |
+
context_v = context_v.to(hidden_states.dtype)
|
209 |
+
else:
|
210 |
+
context_k = encoder_hidden_states
|
211 |
+
context_v = encoder_hidden_states
|
212 |
+
|
213 |
+
k = attn.to_k(context_k)
|
214 |
+
v = attn.to_v(context_v)
|
215 |
+
del encoder_hidden_states, hidden_states
|
216 |
+
|
217 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
|
218 |
+
|
219 |
+
out = FlashAttentionFunction.apply(
|
220 |
+
q, k, v, attention_mask, False, q_bucket_size, k_bucket_size
|
221 |
+
)
|
222 |
+
|
223 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
224 |
+
|
225 |
+
out = attn.to_out[0](out)
|
226 |
+
out = attn.to_out[1](out)
|
227 |
+
return out
|
library/config_util.py
ADDED
@@ -0,0 +1,716 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from dataclasses import (
|
3 |
+
asdict,
|
4 |
+
dataclass,
|
5 |
+
)
|
6 |
+
import functools
|
7 |
+
import random
|
8 |
+
from textwrap import dedent, indent
|
9 |
+
import json
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
# from toolz import curry
|
13 |
+
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
14 |
+
|
15 |
+
import toml
|
16 |
+
import voluptuous
|
17 |
+
from voluptuous import (
|
18 |
+
Any,
|
19 |
+
ExactSequence,
|
20 |
+
MultipleInvalid,
|
21 |
+
Object,
|
22 |
+
Required,
|
23 |
+
Schema,
|
24 |
+
)
|
25 |
+
from transformers import CLIPTokenizer
|
26 |
+
|
27 |
+
from . import train_util
|
28 |
+
from .train_util import (
|
29 |
+
DreamBoothSubset,
|
30 |
+
FineTuningSubset,
|
31 |
+
ControlNetSubset,
|
32 |
+
DreamBoothDataset,
|
33 |
+
FineTuningDataset,
|
34 |
+
ControlNetDataset,
|
35 |
+
DatasetGroup,
|
36 |
+
)
|
37 |
+
from .utils import setup_logging
|
38 |
+
|
39 |
+
setup_logging()
|
40 |
+
import logging
|
41 |
+
|
42 |
+
logger = logging.getLogger(__name__)
|
43 |
+
|
44 |
+
|
45 |
+
def add_config_arguments(parser: argparse.ArgumentParser):
|
46 |
+
parser.add_argument(
|
47 |
+
"--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル"
|
48 |
+
)
|
49 |
+
|
50 |
+
|
51 |
+
# TODO: inherit Params class in Subset, Dataset
|
52 |
+
|
53 |
+
|
54 |
+
@dataclass
|
55 |
+
class BaseSubsetParams:
|
56 |
+
image_dir: Optional[str] = None
|
57 |
+
num_repeats: int = 1
|
58 |
+
shuffle_caption: bool = False
|
59 |
+
caption_separator: str = (",",)
|
60 |
+
keep_tokens: int = 0
|
61 |
+
keep_tokens_separator: str = (None,)
|
62 |
+
secondary_separator: Optional[str] = None
|
63 |
+
enable_wildcard: bool = False
|
64 |
+
color_aug: bool = False
|
65 |
+
flip_aug: bool = False
|
66 |
+
face_crop_aug_range: Optional[Tuple[float, float]] = None
|
67 |
+
random_crop: bool = False
|
68 |
+
caption_prefix: Optional[str] = None
|
69 |
+
caption_suffix: Optional[str] = None
|
70 |
+
caption_dropout_rate: float = 0.0
|
71 |
+
caption_dropout_every_n_epochs: int = 0
|
72 |
+
caption_tag_dropout_rate: float = 0.0
|
73 |
+
token_warmup_min: int = 1
|
74 |
+
token_warmup_step: float = 0
|
75 |
+
custom_attributes: Optional[Dict[str, Any]] = None
|
76 |
+
|
77 |
+
|
78 |
+
@dataclass
|
79 |
+
class DreamBoothSubsetParams(BaseSubsetParams):
|
80 |
+
is_reg: bool = False
|
81 |
+
class_tokens: Optional[str] = None
|
82 |
+
caption_extension: str = ".caption"
|
83 |
+
cache_info: bool = False
|
84 |
+
alpha_mask: bool = False
|
85 |
+
|
86 |
+
|
87 |
+
@dataclass
|
88 |
+
class FineTuningSubsetParams(BaseSubsetParams):
|
89 |
+
metadata_file: Optional[str] = None
|
90 |
+
alpha_mask: bool = False
|
91 |
+
|
92 |
+
|
93 |
+
@dataclass
|
94 |
+
class ControlNetSubsetParams(BaseSubsetParams):
|
95 |
+
conditioning_data_dir: str = None
|
96 |
+
caption_extension: str = ".caption"
|
97 |
+
cache_info: bool = False
|
98 |
+
|
99 |
+
|
100 |
+
@dataclass
|
101 |
+
class BaseDatasetParams:
|
102 |
+
resolution: Optional[Tuple[int, int]] = None
|
103 |
+
network_multiplier: float = 1.0
|
104 |
+
debug_dataset: bool = False
|
105 |
+
|
106 |
+
|
107 |
+
@dataclass
|
108 |
+
class DreamBoothDatasetParams(BaseDatasetParams):
|
109 |
+
batch_size: int = 1
|
110 |
+
enable_bucket: bool = False
|
111 |
+
min_bucket_reso: int = 256
|
112 |
+
max_bucket_reso: int = 1024
|
113 |
+
bucket_reso_steps: int = 64
|
114 |
+
bucket_no_upscale: bool = False
|
115 |
+
prior_loss_weight: float = 1.0
|
116 |
+
|
117 |
+
|
118 |
+
@dataclass
|
119 |
+
class FineTuningDatasetParams(BaseDatasetParams):
|
120 |
+
batch_size: int = 1
|
121 |
+
enable_bucket: bool = False
|
122 |
+
min_bucket_reso: int = 256
|
123 |
+
max_bucket_reso: int = 1024
|
124 |
+
bucket_reso_steps: int = 64
|
125 |
+
bucket_no_upscale: bool = False
|
126 |
+
|
127 |
+
|
128 |
+
@dataclass
|
129 |
+
class ControlNetDatasetParams(BaseDatasetParams):
|
130 |
+
batch_size: int = 1
|
131 |
+
enable_bucket: bool = False
|
132 |
+
min_bucket_reso: int = 256
|
133 |
+
max_bucket_reso: int = 1024
|
134 |
+
bucket_reso_steps: int = 64
|
135 |
+
bucket_no_upscale: bool = False
|
136 |
+
|
137 |
+
|
138 |
+
@dataclass
|
139 |
+
class SubsetBlueprint:
|
140 |
+
params: Union[DreamBoothSubsetParams, FineTuningSubsetParams]
|
141 |
+
|
142 |
+
|
143 |
+
@dataclass
|
144 |
+
class DatasetBlueprint:
|
145 |
+
is_dreambooth: bool
|
146 |
+
is_controlnet: bool
|
147 |
+
params: Union[DreamBoothDatasetParams, FineTuningDatasetParams]
|
148 |
+
subsets: Sequence[SubsetBlueprint]
|
149 |
+
|
150 |
+
|
151 |
+
@dataclass
|
152 |
+
class DatasetGroupBlueprint:
|
153 |
+
datasets: Sequence[DatasetBlueprint]
|
154 |
+
|
155 |
+
|
156 |
+
@dataclass
|
157 |
+
class Blueprint:
|
158 |
+
dataset_group: DatasetGroupBlueprint
|
159 |
+
|
160 |
+
|
161 |
+
class ConfigSanitizer:
|
162 |
+
# @curry
|
163 |
+
@staticmethod
|
164 |
+
def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
|
165 |
+
Schema(ExactSequence([klass, klass]))(value)
|
166 |
+
return tuple(value)
|
167 |
+
|
168 |
+
# @curry
|
169 |
+
@staticmethod
|
170 |
+
def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
|
171 |
+
Schema(Any(klass, ExactSequence([klass, klass])))(value)
|
172 |
+
try:
|
173 |
+
Schema(klass)(value)
|
174 |
+
return (value, value)
|
175 |
+
except:
|
176 |
+
return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
|
177 |
+
|
178 |
+
# subset schema
|
179 |
+
SUBSET_ASCENDABLE_SCHEMA = {
|
180 |
+
"color_aug": bool,
|
181 |
+
"face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float),
|
182 |
+
"flip_aug": bool,
|
183 |
+
"num_repeats": int,
|
184 |
+
"random_crop": bool,
|
185 |
+
"shuffle_caption": bool,
|
186 |
+
"keep_tokens": int,
|
187 |
+
"keep_tokens_separator": str,
|
188 |
+
"secondary_separator": str,
|
189 |
+
"caption_separator": str,
|
190 |
+
"enable_wildcard": bool,
|
191 |
+
"token_warmup_min": int,
|
192 |
+
"token_warmup_step": Any(float, int),
|
193 |
+
"caption_prefix": str,
|
194 |
+
"caption_suffix": str,
|
195 |
+
"custom_attributes": dict,
|
196 |
+
}
|
197 |
+
# DO means DropOut
|
198 |
+
DO_SUBSET_ASCENDABLE_SCHEMA = {
|
199 |
+
"caption_dropout_every_n_epochs": int,
|
200 |
+
"caption_dropout_rate": Any(float, int),
|
201 |
+
"caption_tag_dropout_rate": Any(float, int),
|
202 |
+
}
|
203 |
+
# DB means DreamBooth
|
204 |
+
DB_SUBSET_ASCENDABLE_SCHEMA = {
|
205 |
+
"caption_extension": str,
|
206 |
+
"class_tokens": str,
|
207 |
+
"cache_info": bool,
|
208 |
+
}
|
209 |
+
DB_SUBSET_DISTINCT_SCHEMA = {
|
210 |
+
Required("image_dir"): str,
|
211 |
+
"is_reg": bool,
|
212 |
+
"alpha_mask": bool,
|
213 |
+
}
|
214 |
+
# FT means FineTuning
|
215 |
+
FT_SUBSET_DISTINCT_SCHEMA = {
|
216 |
+
Required("metadata_file"): str,
|
217 |
+
"image_dir": str,
|
218 |
+
"alpha_mask": bool,
|
219 |
+
}
|
220 |
+
CN_SUBSET_ASCENDABLE_SCHEMA = {
|
221 |
+
"caption_extension": str,
|
222 |
+
"cache_info": bool,
|
223 |
+
}
|
224 |
+
CN_SUBSET_DISTINCT_SCHEMA = {
|
225 |
+
Required("image_dir"): str,
|
226 |
+
Required("conditioning_data_dir"): str,
|
227 |
+
}
|
228 |
+
|
229 |
+
# datasets schema
|
230 |
+
DATASET_ASCENDABLE_SCHEMA = {
|
231 |
+
"batch_size": int,
|
232 |
+
"bucket_no_upscale": bool,
|
233 |
+
"bucket_reso_steps": int,
|
234 |
+
"enable_bucket": bool,
|
235 |
+
"max_bucket_reso": int,
|
236 |
+
"min_bucket_reso": int,
|
237 |
+
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
238 |
+
"network_multiplier": float,
|
239 |
+
}
|
240 |
+
|
241 |
+
# options handled by argparse but not handled by user config
|
242 |
+
ARGPARSE_SPECIFIC_SCHEMA = {
|
243 |
+
"debug_dataset": bool,
|
244 |
+
"max_token_length": Any(None, int),
|
245 |
+
"prior_loss_weight": Any(float, int),
|
246 |
+
}
|
247 |
+
# for handling default None value of argparse
|
248 |
+
ARGPARSE_NULLABLE_OPTNAMES = [
|
249 |
+
"face_crop_aug_range",
|
250 |
+
"resolution",
|
251 |
+
]
|
252 |
+
# prepare map because option name may differ among argparse and user config
|
253 |
+
ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = {
|
254 |
+
"train_batch_size": "batch_size",
|
255 |
+
"dataset_repeats": "num_repeats",
|
256 |
+
}
|
257 |
+
|
258 |
+
def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
|
259 |
+
assert support_dreambooth or support_finetuning or support_controlnet, (
|
260 |
+
"Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more."
|
261 |
+
+ " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。"
|
262 |
+
)
|
263 |
+
|
264 |
+
self.db_subset_schema = self.__merge_dict(
|
265 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
266 |
+
self.DB_SUBSET_DISTINCT_SCHEMA,
|
267 |
+
self.DB_SUBSET_ASCENDABLE_SCHEMA,
|
268 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
269 |
+
)
|
270 |
+
|
271 |
+
self.ft_subset_schema = self.__merge_dict(
|
272 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
273 |
+
self.FT_SUBSET_DISTINCT_SCHEMA,
|
274 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
275 |
+
)
|
276 |
+
|
277 |
+
self.cn_subset_schema = self.__merge_dict(
|
278 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
279 |
+
self.CN_SUBSET_DISTINCT_SCHEMA,
|
280 |
+
self.CN_SUBSET_ASCENDABLE_SCHEMA,
|
281 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
282 |
+
)
|
283 |
+
|
284 |
+
self.db_dataset_schema = self.__merge_dict(
|
285 |
+
self.DATASET_ASCENDABLE_SCHEMA,
|
286 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
287 |
+
self.DB_SUBSET_ASCENDABLE_SCHEMA,
|
288 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
289 |
+
{"subsets": [self.db_subset_schema]},
|
290 |
+
)
|
291 |
+
|
292 |
+
self.ft_dataset_schema = self.__merge_dict(
|
293 |
+
self.DATASET_ASCENDABLE_SCHEMA,
|
294 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
295 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
296 |
+
{"subsets": [self.ft_subset_schema]},
|
297 |
+
)
|
298 |
+
|
299 |
+
self.cn_dataset_schema = self.__merge_dict(
|
300 |
+
self.DATASET_ASCENDABLE_SCHEMA,
|
301 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
302 |
+
self.CN_SUBSET_ASCENDABLE_SCHEMA,
|
303 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
304 |
+
{"subsets": [self.cn_subset_schema]},
|
305 |
+
)
|
306 |
+
|
307 |
+
if support_dreambooth and support_finetuning:
|
308 |
+
|
309 |
+
def validate_flex_dataset(dataset_config: dict):
|
310 |
+
subsets_config = dataset_config.get("subsets", [])
|
311 |
+
|
312 |
+
if support_controlnet and all(["conditioning_data_dir" in subset for subset in subsets_config]):
|
313 |
+
return Schema(self.cn_dataset_schema)(dataset_config)
|
314 |
+
# check dataset meets FT style
|
315 |
+
# NOTE: all FT subsets should have "metadata_file"
|
316 |
+
elif all(["metadata_file" in subset for subset in subsets_config]):
|
317 |
+
return Schema(self.ft_dataset_schema)(dataset_config)
|
318 |
+
# check dataset meets DB style
|
319 |
+
# NOTE: all DB subsets should have no "metadata_file"
|
320 |
+
elif all(["metadata_file" not in subset for subset in subsets_config]):
|
321 |
+
return Schema(self.db_dataset_schema)(dataset_config)
|
322 |
+
else:
|
323 |
+
raise voluptuous.Invalid(
|
324 |
+
"DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。"
|
325 |
+
)
|
326 |
+
|
327 |
+
self.dataset_schema = validate_flex_dataset
|
328 |
+
elif support_dreambooth:
|
329 |
+
if support_controlnet:
|
330 |
+
self.dataset_schema = self.cn_dataset_schema
|
331 |
+
else:
|
332 |
+
self.dataset_schema = self.db_dataset_schema
|
333 |
+
elif support_finetuning:
|
334 |
+
self.dataset_schema = self.ft_dataset_schema
|
335 |
+
elif support_controlnet:
|
336 |
+
self.dataset_schema = self.cn_dataset_schema
|
337 |
+
|
338 |
+
self.general_schema = self.__merge_dict(
|
339 |
+
self.DATASET_ASCENDABLE_SCHEMA,
|
340 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
341 |
+
self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {},
|
342 |
+
self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {},
|
343 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
344 |
+
)
|
345 |
+
|
346 |
+
self.user_config_validator = Schema(
|
347 |
+
{
|
348 |
+
"general": self.general_schema,
|
349 |
+
"datasets": [self.dataset_schema],
|
350 |
+
}
|
351 |
+
)
|
352 |
+
|
353 |
+
self.argparse_schema = self.__merge_dict(
|
354 |
+
self.general_schema,
|
355 |
+
self.ARGPARSE_SPECIFIC_SCHEMA,
|
356 |
+
{optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES},
|
357 |
+
{a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()},
|
358 |
+
)
|
359 |
+
|
360 |
+
self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)
|
361 |
+
|
362 |
+
def sanitize_user_config(self, user_config: dict) -> dict:
|
363 |
+
try:
|
364 |
+
return self.user_config_validator(user_config)
|
365 |
+
except MultipleInvalid:
|
366 |
+
# TODO: エラー発生時のメッセージをわかりやすくする
|
367 |
+
logger.error("Invalid user config / ユーザ設定の形式が正しくないようです")
|
368 |
+
raise
|
369 |
+
|
370 |
+
# NOTE: In nature, argument parser result is not needed to be sanitize
|
371 |
+
# However this will help us to detect program bug
|
372 |
+
def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
|
373 |
+
try:
|
374 |
+
return self.argparse_config_validator(argparse_namespace)
|
375 |
+
except MultipleInvalid:
|
376 |
+
# XXX: this should be a bug
|
377 |
+
logger.error(
|
378 |
+
"Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
|
379 |
+
)
|
380 |
+
raise
|
381 |
+
|
382 |
+
# NOTE: value would be overwritten by latter dict if there is already the same key
|
383 |
+
@staticmethod
|
384 |
+
def __merge_dict(*dict_list: dict) -> dict:
|
385 |
+
merged = {}
|
386 |
+
for schema in dict_list:
|
387 |
+
# merged |= schema
|
388 |
+
for k, v in schema.items():
|
389 |
+
merged[k] = v
|
390 |
+
return merged
|
391 |
+
|
392 |
+
|
393 |
+
class BlueprintGenerator:
|
394 |
+
BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {}
|
395 |
+
|
396 |
+
def __init__(self, sanitizer: ConfigSanitizer):
|
397 |
+
self.sanitizer = sanitizer
|
398 |
+
|
399 |
+
# runtime_params is for parameters which is only configurable on runtime, such as tokenizer
|
400 |
+
def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
|
401 |
+
sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
|
402 |
+
sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)
|
403 |
+
|
404 |
+
# convert argparse namespace to dict like config
|
405 |
+
# NOTE: it is ok to have extra entries in dict
|
406 |
+
optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME
|
407 |
+
argparse_config = {
|
408 |
+
optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items()
|
409 |
+
}
|
410 |
+
|
411 |
+
general_config = sanitized_user_config.get("general", {})
|
412 |
+
|
413 |
+
dataset_blueprints = []
|
414 |
+
for dataset_config in sanitized_user_config.get("datasets", []):
|
415 |
+
# NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets
|
416 |
+
subsets = dataset_config.get("subsets", [])
|
417 |
+
is_dreambooth = all(["metadata_file" not in subset for subset in subsets])
|
418 |
+
is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets])
|
419 |
+
if is_controlnet:
|
420 |
+
subset_params_klass = ControlNetSubsetParams
|
421 |
+
dataset_params_klass = ControlNetDatasetParams
|
422 |
+
elif is_dreambooth:
|
423 |
+
subset_params_klass = DreamBoothSubsetParams
|
424 |
+
dataset_params_klass = DreamBoothDatasetParams
|
425 |
+
else:
|
426 |
+
subset_params_klass = FineTuningSubsetParams
|
427 |
+
dataset_params_klass = FineTuningDatasetParams
|
428 |
+
|
429 |
+
subset_blueprints = []
|
430 |
+
for subset_config in subsets:
|
431 |
+
params = self.generate_params_by_fallbacks(
|
432 |
+
subset_params_klass, [subset_config, dataset_config, general_config, argparse_config, runtime_params]
|
433 |
+
)
|
434 |
+
subset_blueprints.append(SubsetBlueprint(params))
|
435 |
+
|
436 |
+
params = self.generate_params_by_fallbacks(
|
437 |
+
dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]
|
438 |
+
)
|
439 |
+
dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints))
|
440 |
+
|
441 |
+
dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
|
442 |
+
|
443 |
+
return Blueprint(dataset_group_blueprint)
|
444 |
+
|
445 |
+
@staticmethod
|
446 |
+
def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
|
447 |
+
name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
|
448 |
+
search_value = BlueprintGenerator.search_value
|
449 |
+
default_params = asdict(param_klass())
|
450 |
+
param_names = default_params.keys()
|
451 |
+
|
452 |
+
params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names}
|
453 |
+
|
454 |
+
return param_klass(**params)
|
455 |
+
|
456 |
+
@staticmethod
|
457 |
+
def search_value(key: str, fallbacks: Sequence[dict], default_value=None):
|
458 |
+
for cand in fallbacks:
|
459 |
+
value = cand.get(key)
|
460 |
+
if value is not None:
|
461 |
+
return value
|
462 |
+
|
463 |
+
return default_value
|
464 |
+
|
465 |
+
|
466 |
+
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
|
467 |
+
datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
|
468 |
+
|
469 |
+
for dataset_blueprint in dataset_group_blueprint.datasets:
|
470 |
+
if dataset_blueprint.is_controlnet:
|
471 |
+
subset_klass = ControlNetSubset
|
472 |
+
dataset_klass = ControlNetDataset
|
473 |
+
elif dataset_blueprint.is_dreambooth:
|
474 |
+
subset_klass = DreamBoothSubset
|
475 |
+
dataset_klass = DreamBoothDataset
|
476 |
+
else:
|
477 |
+
subset_klass = FineTuningSubset
|
478 |
+
dataset_klass = FineTuningDataset
|
479 |
+
|
480 |
+
subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
|
481 |
+
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
|
482 |
+
datasets.append(dataset)
|
483 |
+
|
484 |
+
# print info
|
485 |
+
info = ""
|
486 |
+
for i, dataset in enumerate(datasets):
|
487 |
+
is_dreambooth = isinstance(dataset, DreamBoothDataset)
|
488 |
+
is_controlnet = isinstance(dataset, ControlNetDataset)
|
489 |
+
info += dedent(
|
490 |
+
f"""\
|
491 |
+
[Dataset {i}]
|
492 |
+
batch_size: {dataset.batch_size}
|
493 |
+
resolution: {(dataset.width, dataset.height)}
|
494 |
+
enable_bucket: {dataset.enable_bucket}
|
495 |
+
network_multiplier: {dataset.network_multiplier}
|
496 |
+
"""
|
497 |
+
)
|
498 |
+
|
499 |
+
if dataset.enable_bucket:
|
500 |
+
info += indent(
|
501 |
+
dedent(
|
502 |
+
f"""\
|
503 |
+
min_bucket_reso: {dataset.min_bucket_reso}
|
504 |
+
max_bucket_reso: {dataset.max_bucket_reso}
|
505 |
+
bucket_reso_steps: {dataset.bucket_reso_steps}
|
506 |
+
bucket_no_upscale: {dataset.bucket_no_upscale}
|
507 |
+
\n"""
|
508 |
+
),
|
509 |
+
" ",
|
510 |
+
)
|
511 |
+
else:
|
512 |
+
info += "\n"
|
513 |
+
|
514 |
+
for j, subset in enumerate(dataset.subsets):
|
515 |
+
info += indent(
|
516 |
+
dedent(
|
517 |
+
f"""\
|
518 |
+
[Subset {j} of Dataset {i}]
|
519 |
+
image_dir: "{subset.image_dir}"
|
520 |
+
image_count: {subset.img_count}
|
521 |
+
num_repeats: {subset.num_repeats}
|
522 |
+
shuffle_caption: {subset.shuffle_caption}
|
523 |
+
keep_tokens: {subset.keep_tokens}
|
524 |
+
keep_tokens_separator: {subset.keep_tokens_separator}
|
525 |
+
caption_separator: {subset.caption_separator}
|
526 |
+
secondary_separator: {subset.secondary_separator}
|
527 |
+
enable_wildcard: {subset.enable_wildcard}
|
528 |
+
caption_dropout_rate: {subset.caption_dropout_rate}
|
529 |
+
caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs}
|
530 |
+
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
|
531 |
+
caption_prefix: {subset.caption_prefix}
|
532 |
+
caption_suffix: {subset.caption_suffix}
|
533 |
+
color_aug: {subset.color_aug}
|
534 |
+
flip_aug: {subset.flip_aug}
|
535 |
+
face_crop_aug_range: {subset.face_crop_aug_range}
|
536 |
+
random_crop: {subset.random_crop}
|
537 |
+
token_warmup_min: {subset.token_warmup_min}
|
538 |
+
token_warmup_step: {subset.token_warmup_step}
|
539 |
+
alpha_mask: {subset.alpha_mask}
|
540 |
+
custom_attributes: {subset.custom_attributes}
|
541 |
+
"""
|
542 |
+
),
|
543 |
+
" ",
|
544 |
+
)
|
545 |
+
|
546 |
+
if is_dreambooth:
|
547 |
+
info += indent(
|
548 |
+
dedent(
|
549 |
+
f"""\
|
550 |
+
is_reg: {subset.is_reg}
|
551 |
+
class_tokens: {subset.class_tokens}
|
552 |
+
caption_extension: {subset.caption_extension}
|
553 |
+
\n"""
|
554 |
+
),
|
555 |
+
" ",
|
556 |
+
)
|
557 |
+
elif not is_controlnet:
|
558 |
+
info += indent(
|
559 |
+
dedent(
|
560 |
+
f"""\
|
561 |
+
metadata_file: {subset.metadata_file}
|
562 |
+
\n"""
|
563 |
+
),
|
564 |
+
" ",
|
565 |
+
)
|
566 |
+
|
567 |
+
logger.info(f"{info}")
|
568 |
+
|
569 |
+
# make buckets first because it determines the length of dataset
|
570 |
+
# and set the same seed for all datasets
|
571 |
+
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
|
572 |
+
for i, dataset in enumerate(datasets):
|
573 |
+
logger.info(f"[Dataset {i}]")
|
574 |
+
dataset.make_buckets()
|
575 |
+
dataset.set_seed(seed)
|
576 |
+
|
577 |
+
return DatasetGroup(datasets)
|
578 |
+
|
579 |
+
|
580 |
+
def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
|
581 |
+
def extract_dreambooth_params(name: str) -> Tuple[int, str]:
|
582 |
+
tokens = name.split("_")
|
583 |
+
try:
|
584 |
+
n_repeats = int(tokens[0])
|
585 |
+
except ValueError as e:
|
586 |
+
logger.warning(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
|
587 |
+
return 0, ""
|
588 |
+
caption_by_folder = "_".join(tokens[1:])
|
589 |
+
return n_repeats, caption_by_folder
|
590 |
+
|
591 |
+
def generate(base_dir: Optional[str], is_reg: bool):
|
592 |
+
if base_dir is None:
|
593 |
+
return []
|
594 |
+
|
595 |
+
base_dir: Path = Path(base_dir)
|
596 |
+
if not base_dir.is_dir():
|
597 |
+
return []
|
598 |
+
|
599 |
+
subsets_config = []
|
600 |
+
for subdir in base_dir.iterdir():
|
601 |
+
if not subdir.is_dir():
|
602 |
+
continue
|
603 |
+
|
604 |
+
num_repeats, class_tokens = extract_dreambooth_params(subdir.name)
|
605 |
+
if num_repeats < 1:
|
606 |
+
continue
|
607 |
+
|
608 |
+
subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens}
|
609 |
+
subsets_config.append(subset_config)
|
610 |
+
|
611 |
+
return subsets_config
|
612 |
+
|
613 |
+
subsets_config = []
|
614 |
+
subsets_config += generate(train_data_dir, False)
|
615 |
+
subsets_config += generate(reg_data_dir, True)
|
616 |
+
|
617 |
+
return subsets_config
|
618 |
+
|
619 |
+
|
620 |
+
def generate_controlnet_subsets_config_by_subdirs(
|
621 |
+
train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt"
|
622 |
+
):
|
623 |
+
def generate(base_dir: Optional[str]):
|
624 |
+
if base_dir is None:
|
625 |
+
return []
|
626 |
+
|
627 |
+
base_dir: Path = Path(base_dir)
|
628 |
+
if not base_dir.is_dir():
|
629 |
+
return []
|
630 |
+
|
631 |
+
subsets_config = []
|
632 |
+
subset_config = {
|
633 |
+
"image_dir": train_data_dir,
|
634 |
+
"conditioning_data_dir": conditioning_data_dir,
|
635 |
+
"caption_extension": caption_extension,
|
636 |
+
"num_repeats": 1,
|
637 |
+
}
|
638 |
+
subsets_config.append(subset_config)
|
639 |
+
|
640 |
+
return subsets_config
|
641 |
+
|
642 |
+
subsets_config = []
|
643 |
+
subsets_config += generate(train_data_dir)
|
644 |
+
|
645 |
+
return subsets_config
|
646 |
+
|
647 |
+
|
648 |
+
def load_user_config(file: str) -> dict:
|
649 |
+
file: Path = Path(file)
|
650 |
+
if not file.is_file():
|
651 |
+
raise ValueError(f"file not found / ファイルが見つかりません: {file}")
|
652 |
+
|
653 |
+
if file.name.lower().endswith(".json"):
|
654 |
+
try:
|
655 |
+
with open(file, "r") as f:
|
656 |
+
config = json.load(f)
|
657 |
+
except Exception:
|
658 |
+
logger.error(
|
659 |
+
f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
|
660 |
+
)
|
661 |
+
raise
|
662 |
+
elif file.name.lower().endswith(".toml"):
|
663 |
+
try:
|
664 |
+
config = toml.load(file)
|
665 |
+
except Exception:
|
666 |
+
logger.error(
|
667 |
+
f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
|
668 |
+
)
|
669 |
+
raise
|
670 |
+
else:
|
671 |
+
raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
|
672 |
+
|
673 |
+
return config
|
674 |
+
|
675 |
+
|
676 |
+
# for config test
|
677 |
+
if __name__ == "__main__":
|
678 |
+
parser = argparse.ArgumentParser()
|
679 |
+
parser.add_argument("--support_dreambooth", action="store_true")
|
680 |
+
parser.add_argument("--support_finetuning", action="store_true")
|
681 |
+
parser.add_argument("--support_controlnet", action="store_true")
|
682 |
+
parser.add_argument("--support_dropout", action="store_true")
|
683 |
+
parser.add_argument("dataset_config")
|
684 |
+
config_args, remain = parser.parse_known_args()
|
685 |
+
|
686 |
+
parser = argparse.ArgumentParser()
|
687 |
+
train_util.add_dataset_arguments(
|
688 |
+
parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout
|
689 |
+
)
|
690 |
+
train_util.add_training_arguments(parser, config_args.support_dreambooth)
|
691 |
+
argparse_namespace = parser.parse_args(remain)
|
692 |
+
train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)
|
693 |
+
|
694 |
+
logger.info("[argparse_namespace]")
|
695 |
+
logger.info(f"{vars(argparse_namespace)}")
|
696 |
+
|
697 |
+
user_config = load_user_config(config_args.dataset_config)
|
698 |
+
|
699 |
+
logger.info("")
|
700 |
+
logger.info("[user_config]")
|
701 |
+
logger.info(f"{user_config}")
|
702 |
+
|
703 |
+
sanitizer = ConfigSanitizer(
|
704 |
+
config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout
|
705 |
+
)
|
706 |
+
sanitized_user_config = sanitizer.sanitize_user_config(user_config)
|
707 |
+
|
708 |
+
logger.info("")
|
709 |
+
logger.info("[sanitized_user_config]")
|
710 |
+
logger.info(f"{sanitized_user_config}")
|
711 |
+
|
712 |
+
blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
|
713 |
+
|
714 |
+
logger.info("")
|
715 |
+
logger.info("[blueprint]")
|
716 |
+
logger.info(f"{blueprint}")
|
library/custom_offloading_utils.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from concurrent.futures import ThreadPoolExecutor
|
2 |
+
import time
|
3 |
+
from typing import Optional
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from library.device_utils import clean_memory_on_device
|
8 |
+
|
9 |
+
|
10 |
+
def synchronize_device(device: torch.device):
|
11 |
+
if device.type == "cuda":
|
12 |
+
torch.cuda.synchronize()
|
13 |
+
elif device.type == "xpu":
|
14 |
+
torch.xpu.synchronize()
|
15 |
+
elif device.type == "mps":
|
16 |
+
torch.mps.synchronize()
|
17 |
+
|
18 |
+
|
19 |
+
def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
|
20 |
+
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
|
21 |
+
|
22 |
+
weight_swap_jobs = []
|
23 |
+
|
24 |
+
# This is not working for all cases (e.g. SD3), so we need to find the corresponding modules
|
25 |
+
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
|
26 |
+
# print(module_to_cpu.__class__, module_to_cuda.__class__)
|
27 |
+
# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
|
28 |
+
# weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
|
29 |
+
|
30 |
+
modules_to_cpu = {k: v for k, v in layer_to_cpu.named_modules()}
|
31 |
+
for module_to_cuda_name, module_to_cuda in layer_to_cuda.named_modules():
|
32 |
+
if hasattr(module_to_cuda, "weight") and module_to_cuda.weight is not None:
|
33 |
+
module_to_cpu = modules_to_cpu.get(module_to_cuda_name, None)
|
34 |
+
if module_to_cpu is not None and module_to_cpu.weight.shape == module_to_cuda.weight.shape:
|
35 |
+
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
|
36 |
+
else:
|
37 |
+
if module_to_cuda.weight.data.device.type != device.type:
|
38 |
+
# print(
|
39 |
+
# f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device"
|
40 |
+
# )
|
41 |
+
module_to_cuda.weight.data = module_to_cuda.weight.data.to(device)
|
42 |
+
|
43 |
+
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
|
44 |
+
|
45 |
+
stream = torch.cuda.Stream()
|
46 |
+
with torch.cuda.stream(stream):
|
47 |
+
# cuda to cpu
|
48 |
+
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
49 |
+
cuda_data_view.record_stream(stream)
|
50 |
+
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
|
51 |
+
|
52 |
+
stream.synchronize()
|
53 |
+
|
54 |
+
# cpu to cuda
|
55 |
+
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
56 |
+
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
|
57 |
+
module_to_cuda.weight.data = cuda_data_view
|
58 |
+
|
59 |
+
stream.synchronize()
|
60 |
+
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
|
61 |
+
|
62 |
+
|
63 |
+
def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
|
64 |
+
"""
|
65 |
+
not tested
|
66 |
+
"""
|
67 |
+
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
|
68 |
+
|
69 |
+
weight_swap_jobs = []
|
70 |
+
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
|
71 |
+
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
|
72 |
+
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
|
73 |
+
|
74 |
+
# device to cpu
|
75 |
+
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
76 |
+
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
|
77 |
+
|
78 |
+
synchronize_device()
|
79 |
+
|
80 |
+
# cpu to device
|
81 |
+
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
82 |
+
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
|
83 |
+
module_to_cuda.weight.data = cuda_data_view
|
84 |
+
|
85 |
+
synchronize_device()
|
86 |
+
|
87 |
+
|
88 |
+
def weighs_to_device(layer: nn.Module, device: torch.device):
|
89 |
+
for module in layer.modules():
|
90 |
+
if hasattr(module, "weight") and module.weight is not None:
|
91 |
+
module.weight.data = module.weight.data.to(device, non_blocking=True)
|
92 |
+
|
93 |
+
|
94 |
+
class Offloader:
|
95 |
+
"""
|
96 |
+
common offloading class
|
97 |
+
"""
|
98 |
+
|
99 |
+
def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
|
100 |
+
self.num_blocks = num_blocks
|
101 |
+
self.blocks_to_swap = blocks_to_swap
|
102 |
+
self.device = device
|
103 |
+
self.debug = debug
|
104 |
+
|
105 |
+
self.thread_pool = ThreadPoolExecutor(max_workers=1)
|
106 |
+
self.futures = {}
|
107 |
+
self.cuda_available = device.type == "cuda"
|
108 |
+
|
109 |
+
def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module):
|
110 |
+
if self.cuda_available:
|
111 |
+
swap_weight_devices_cuda(self.device, block_to_cpu, block_to_cuda)
|
112 |
+
else:
|
113 |
+
swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda)
|
114 |
+
|
115 |
+
def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda):
|
116 |
+
def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
|
117 |
+
if self.debug:
|
118 |
+
start_time = time.perf_counter()
|
119 |
+
print(f"Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}")
|
120 |
+
|
121 |
+
self.swap_weight_devices(block_to_cpu, block_to_cuda)
|
122 |
+
|
123 |
+
if self.debug:
|
124 |
+
print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
|
125 |
+
return bidx_to_cpu, bidx_to_cuda # , event
|
126 |
+
|
127 |
+
block_to_cpu = blocks[block_idx_to_cpu]
|
128 |
+
block_to_cuda = blocks[block_idx_to_cuda]
|
129 |
+
|
130 |
+
self.futures[block_idx_to_cuda] = self.thread_pool.submit(
|
131 |
+
move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda
|
132 |
+
)
|
133 |
+
|
134 |
+
def _wait_blocks_move(self, block_idx):
|
135 |
+
if block_idx not in self.futures:
|
136 |
+
return
|
137 |
+
|
138 |
+
if self.debug:
|
139 |
+
print(f"Wait for block {block_idx}")
|
140 |
+
start_time = time.perf_counter()
|
141 |
+
|
142 |
+
future = self.futures.pop(block_idx)
|
143 |
+
_, bidx_to_cuda = future.result()
|
144 |
+
|
145 |
+
assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}"
|
146 |
+
|
147 |
+
if self.debug:
|
148 |
+
print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
|
149 |
+
|
150 |
+
|
151 |
+
class ModelOffloader(Offloader):
|
152 |
+
"""
|
153 |
+
supports forward offloading
|
154 |
+
"""
|
155 |
+
|
156 |
+
def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
|
157 |
+
super().__init__(num_blocks, blocks_to_swap, device, debug)
|
158 |
+
|
159 |
+
# register backward hooks
|
160 |
+
self.remove_handles = []
|
161 |
+
for i, block in enumerate(blocks):
|
162 |
+
hook = self.create_backward_hook(blocks, i)
|
163 |
+
if hook is not None:
|
164 |
+
handle = block.register_full_backward_hook(hook)
|
165 |
+
self.remove_handles.append(handle)
|
166 |
+
|
167 |
+
def __del__(self):
|
168 |
+
for handle in self.remove_handles:
|
169 |
+
handle.remove()
|
170 |
+
|
171 |
+
def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
|
172 |
+
# -1 for 0-based index
|
173 |
+
num_blocks_propagated = self.num_blocks - block_index - 1
|
174 |
+
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
|
175 |
+
waiting = block_index > 0 and block_index <= self.blocks_to_swap
|
176 |
+
|
177 |
+
if not swapping and not waiting:
|
178 |
+
return None
|
179 |
+
|
180 |
+
# create hook
|
181 |
+
block_idx_to_cpu = self.num_blocks - num_blocks_propagated
|
182 |
+
block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
|
183 |
+
block_idx_to_wait = block_index - 1
|
184 |
+
|
185 |
+
def backward_hook(module, grad_input, grad_output):
|
186 |
+
if self.debug:
|
187 |
+
print(f"Backward hook for block {block_index}")
|
188 |
+
|
189 |
+
if swapping:
|
190 |
+
self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
|
191 |
+
if waiting:
|
192 |
+
self._wait_blocks_move(block_idx_to_wait)
|
193 |
+
return None
|
194 |
+
|
195 |
+
return backward_hook
|
196 |
+
|
197 |
+
def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
|
198 |
+
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
199 |
+
return
|
200 |
+
|
201 |
+
if self.debug:
|
202 |
+
print("Prepare block devices before forward")
|
203 |
+
|
204 |
+
for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
|
205 |
+
b.to(self.device)
|
206 |
+
weighs_to_device(b, self.device) # make sure weights are on device
|
207 |
+
|
208 |
+
for b in blocks[self.num_blocks - self.blocks_to_swap :]:
|
209 |
+
b.to(self.device) # move block to device first
|
210 |
+
weighs_to_device(b, "cpu") # make sure weights are on cpu
|
211 |
+
|
212 |
+
synchronize_device(self.device)
|
213 |
+
clean_memory_on_device(self.device)
|
214 |
+
|
215 |
+
def wait_for_block(self, block_idx: int):
|
216 |
+
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
217 |
+
return
|
218 |
+
self._wait_blocks_move(block_idx)
|
219 |
+
|
220 |
+
def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int):
|
221 |
+
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
222 |
+
return
|
223 |
+
if block_idx >= self.blocks_to_swap:
|
224 |
+
return
|
225 |
+
block_idx_to_cpu = block_idx
|
226 |
+
block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx
|
227 |
+
self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
|
library/custom_train_functions.py
ADDED
@@ -0,0 +1,559 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import argparse
|
3 |
+
import random
|
4 |
+
import re
|
5 |
+
from typing import List, Optional, Union
|
6 |
+
from .utils import setup_logging
|
7 |
+
|
8 |
+
setup_logging()
|
9 |
+
import logging
|
10 |
+
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
|
14 |
+
def prepare_scheduler_for_custom_training(noise_scheduler, device):
|
15 |
+
if hasattr(noise_scheduler, "all_snr"):
|
16 |
+
return
|
17 |
+
|
18 |
+
alphas_cumprod = noise_scheduler.alphas_cumprod
|
19 |
+
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
|
20 |
+
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
|
21 |
+
alpha = sqrt_alphas_cumprod
|
22 |
+
sigma = sqrt_one_minus_alphas_cumprod
|
23 |
+
all_snr = (alpha / sigma) ** 2
|
24 |
+
|
25 |
+
noise_scheduler.all_snr = all_snr.to(device)
|
26 |
+
|
27 |
+
|
28 |
+
def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
|
29 |
+
# fix beta: zero terminal SNR
|
30 |
+
logger.info(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
|
31 |
+
|
32 |
+
def enforce_zero_terminal_snr(betas):
|
33 |
+
# Convert betas to alphas_bar_sqrt
|
34 |
+
alphas = 1 - betas
|
35 |
+
alphas_bar = alphas.cumprod(0)
|
36 |
+
alphas_bar_sqrt = alphas_bar.sqrt()
|
37 |
+
|
38 |
+
# Store old values.
|
39 |
+
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
40 |
+
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
41 |
+
# Shift so last timestep is zero.
|
42 |
+
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
43 |
+
# Scale so first timestep is back to old value.
|
44 |
+
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
45 |
+
|
46 |
+
# Convert alphas_bar_sqrt to betas
|
47 |
+
alphas_bar = alphas_bar_sqrt**2
|
48 |
+
alphas = alphas_bar[1:] / alphas_bar[:-1]
|
49 |
+
alphas = torch.cat([alphas_bar[0:1], alphas])
|
50 |
+
betas = 1 - alphas
|
51 |
+
return betas
|
52 |
+
|
53 |
+
betas = noise_scheduler.betas
|
54 |
+
betas = enforce_zero_terminal_snr(betas)
|
55 |
+
alphas = 1.0 - betas
|
56 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
57 |
+
|
58 |
+
# logger.info(f"original: {noise_scheduler.betas}")
|
59 |
+
# logger.info(f"fixed: {betas}")
|
60 |
+
|
61 |
+
noise_scheduler.betas = betas
|
62 |
+
noise_scheduler.alphas = alphas
|
63 |
+
noise_scheduler.alphas_cumprod = alphas_cumprod
|
64 |
+
|
65 |
+
|
66 |
+
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
|
67 |
+
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
|
68 |
+
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
|
69 |
+
if v_prediction:
|
70 |
+
snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
|
71 |
+
else:
|
72 |
+
snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
|
73 |
+
loss = loss * snr_weight
|
74 |
+
return loss
|
75 |
+
|
76 |
+
|
77 |
+
def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
|
78 |
+
scale = get_snr_scale(timesteps, noise_scheduler)
|
79 |
+
loss = loss * scale
|
80 |
+
return loss
|
81 |
+
|
82 |
+
|
83 |
+
def get_snr_scale(timesteps, noise_scheduler):
|
84 |
+
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
|
85 |
+
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
|
86 |
+
scale = snr_t / (snr_t + 1)
|
87 |
+
# # show debug info
|
88 |
+
# logger.info(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
|
89 |
+
return scale
|
90 |
+
|
91 |
+
|
92 |
+
def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
|
93 |
+
scale = get_snr_scale(timesteps, noise_scheduler)
|
94 |
+
# logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
|
95 |
+
loss = loss + loss / scale * v_pred_like_loss
|
96 |
+
return loss
|
97 |
+
|
98 |
+
|
99 |
+
def apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False):
|
100 |
+
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
|
101 |
+
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
|
102 |
+
if v_prediction:
|
103 |
+
weight = 1 / (snr_t + 1)
|
104 |
+
else:
|
105 |
+
weight = 1 / torch.sqrt(snr_t)
|
106 |
+
loss = weight * loss
|
107 |
+
return loss
|
108 |
+
|
109 |
+
|
110 |
+
# TODO train_utilと分散しているのでどちらかに寄せる
|
111 |
+
|
112 |
+
|
113 |
+
def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
|
114 |
+
parser.add_argument(
|
115 |
+
"--min_snr_gamma",
|
116 |
+
type=float,
|
117 |
+
default=None,
|
118 |
+
help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
|
119 |
+
)
|
120 |
+
parser.add_argument(
|
121 |
+
"--scale_v_pred_loss_like_noise_pred",
|
122 |
+
action="store_true",
|
123 |
+
help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
|
124 |
+
)
|
125 |
+
parser.add_argument(
|
126 |
+
"--v_pred_like_loss",
|
127 |
+
type=float,
|
128 |
+
default=None,
|
129 |
+
help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけ��ものをlossに加算する",
|
130 |
+
)
|
131 |
+
parser.add_argument(
|
132 |
+
"--debiased_estimation_loss",
|
133 |
+
action="store_true",
|
134 |
+
help="debiased estimation loss / debiased estimation loss",
|
135 |
+
)
|
136 |
+
if support_weighted_captions:
|
137 |
+
parser.add_argument(
|
138 |
+
"--weighted_captions",
|
139 |
+
action="store_true",
|
140 |
+
default=False,
|
141 |
+
help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
|
142 |
+
)
|
143 |
+
|
144 |
+
|
145 |
+
re_attention = re.compile(
|
146 |
+
r"""
|
147 |
+
\\\(|
|
148 |
+
\\\)|
|
149 |
+
\\\[|
|
150 |
+
\\]|
|
151 |
+
\\\\|
|
152 |
+
\\|
|
153 |
+
\(|
|
154 |
+
\[|
|
155 |
+
:([+-]?[.\d]+)\)|
|
156 |
+
\)|
|
157 |
+
]|
|
158 |
+
[^\\()\[\]:]+|
|
159 |
+
:
|
160 |
+
""",
|
161 |
+
re.X,
|
162 |
+
)
|
163 |
+
|
164 |
+
|
165 |
+
def parse_prompt_attention(text):
|
166 |
+
"""
|
167 |
+
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
|
168 |
+
Accepted tokens are:
|
169 |
+
(abc) - increases attention to abc by a multiplier of 1.1
|
170 |
+
(abc:3.12) - increases attention to abc by a multiplier of 3.12
|
171 |
+
[abc] - decreases attention to abc by a multiplier of 1.1
|
172 |
+
\( - literal character '('
|
173 |
+
\[ - literal character '['
|
174 |
+
\) - literal character ')'
|
175 |
+
\] - literal character ']'
|
176 |
+
\\ - literal character '\'
|
177 |
+
anything else - just text
|
178 |
+
>>> parse_prompt_attention('normal text')
|
179 |
+
[['normal text', 1.0]]
|
180 |
+
>>> parse_prompt_attention('an (important) word')
|
181 |
+
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
|
182 |
+
>>> parse_prompt_attention('(unbalanced')
|
183 |
+
[['unbalanced', 1.1]]
|
184 |
+
>>> parse_prompt_attention('\(literal\]')
|
185 |
+
[['(literal]', 1.0]]
|
186 |
+
>>> parse_prompt_attention('(unnecessary)(parens)')
|
187 |
+
[['unnecessaryparens', 1.1]]
|
188 |
+
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
|
189 |
+
[['a ', 1.0],
|
190 |
+
['house', 1.5730000000000004],
|
191 |
+
[' ', 1.1],
|
192 |
+
['on', 1.0],
|
193 |
+
[' a ', 1.1],
|
194 |
+
['hill', 0.55],
|
195 |
+
[', sun, ', 1.1],
|
196 |
+
['sky', 1.4641000000000006],
|
197 |
+
['.', 1.1]]
|
198 |
+
"""
|
199 |
+
|
200 |
+
res = []
|
201 |
+
round_brackets = []
|
202 |
+
square_brackets = []
|
203 |
+
|
204 |
+
round_bracket_multiplier = 1.1
|
205 |
+
square_bracket_multiplier = 1 / 1.1
|
206 |
+
|
207 |
+
def multiply_range(start_position, multiplier):
|
208 |
+
for p in range(start_position, len(res)):
|
209 |
+
res[p][1] *= multiplier
|
210 |
+
|
211 |
+
for m in re_attention.finditer(text):
|
212 |
+
text = m.group(0)
|
213 |
+
weight = m.group(1)
|
214 |
+
|
215 |
+
if text.startswith("\\"):
|
216 |
+
res.append([text[1:], 1.0])
|
217 |
+
elif text == "(":
|
218 |
+
round_brackets.append(len(res))
|
219 |
+
elif text == "[":
|
220 |
+
square_brackets.append(len(res))
|
221 |
+
elif weight is not None and len(round_brackets) > 0:
|
222 |
+
multiply_range(round_brackets.pop(), float(weight))
|
223 |
+
elif text == ")" and len(round_brackets) > 0:
|
224 |
+
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
225 |
+
elif text == "]" and len(square_brackets) > 0:
|
226 |
+
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
227 |
+
else:
|
228 |
+
res.append([text, 1.0])
|
229 |
+
|
230 |
+
for pos in round_brackets:
|
231 |
+
multiply_range(pos, round_bracket_multiplier)
|
232 |
+
|
233 |
+
for pos in square_brackets:
|
234 |
+
multiply_range(pos, square_bracket_multiplier)
|
235 |
+
|
236 |
+
if len(res) == 0:
|
237 |
+
res = [["", 1.0]]
|
238 |
+
|
239 |
+
# merge runs of identical weights
|
240 |
+
i = 0
|
241 |
+
while i + 1 < len(res):
|
242 |
+
if res[i][1] == res[i + 1][1]:
|
243 |
+
res[i][0] += res[i + 1][0]
|
244 |
+
res.pop(i + 1)
|
245 |
+
else:
|
246 |
+
i += 1
|
247 |
+
|
248 |
+
return res
|
249 |
+
|
250 |
+
|
251 |
+
def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
|
252 |
+
r"""
|
253 |
+
Tokenize a list of prompts and return its tokens with weights of each token.
|
254 |
+
|
255 |
+
No padding, starting or ending token is included.
|
256 |
+
"""
|
257 |
+
tokens = []
|
258 |
+
weights = []
|
259 |
+
truncated = False
|
260 |
+
for text in prompt:
|
261 |
+
texts_and_weights = parse_prompt_attention(text)
|
262 |
+
text_token = []
|
263 |
+
text_weight = []
|
264 |
+
for word, weight in texts_and_weights:
|
265 |
+
# tokenize and discard the starting and the ending token
|
266 |
+
token = tokenizer(word).input_ids[1:-1]
|
267 |
+
text_token += token
|
268 |
+
# copy the weight by length of token
|
269 |
+
text_weight += [weight] * len(token)
|
270 |
+
# stop if the text is too long (longer than truncation limit)
|
271 |
+
if len(text_token) > max_length:
|
272 |
+
truncated = True
|
273 |
+
break
|
274 |
+
# truncate
|
275 |
+
if len(text_token) > max_length:
|
276 |
+
truncated = True
|
277 |
+
text_token = text_token[:max_length]
|
278 |
+
text_weight = text_weight[:max_length]
|
279 |
+
tokens.append(text_token)
|
280 |
+
weights.append(text_weight)
|
281 |
+
if truncated:
|
282 |
+
logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
|
283 |
+
return tokens, weights
|
284 |
+
|
285 |
+
|
286 |
+
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
|
287 |
+
r"""
|
288 |
+
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
|
289 |
+
"""
|
290 |
+
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
|
291 |
+
weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
|
292 |
+
for i in range(len(tokens)):
|
293 |
+
tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
|
294 |
+
if no_boseos_middle:
|
295 |
+
weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
|
296 |
+
else:
|
297 |
+
w = []
|
298 |
+
if len(weights[i]) == 0:
|
299 |
+
w = [1.0] * weights_length
|
300 |
+
else:
|
301 |
+
for j in range(max_embeddings_multiples):
|
302 |
+
w.append(1.0) # weight for starting token in this chunk
|
303 |
+
w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
|
304 |
+
w.append(1.0) # weight for ending token in this chunk
|
305 |
+
w += [1.0] * (weights_length - len(w))
|
306 |
+
weights[i] = w[:]
|
307 |
+
|
308 |
+
return tokens, weights
|
309 |
+
|
310 |
+
|
311 |
+
def get_unweighted_text_embeddings(
|
312 |
+
tokenizer,
|
313 |
+
text_encoder,
|
314 |
+
text_input: torch.Tensor,
|
315 |
+
chunk_length: int,
|
316 |
+
clip_skip: int,
|
317 |
+
eos: int,
|
318 |
+
pad: int,
|
319 |
+
no_boseos_middle: Optional[bool] = True,
|
320 |
+
):
|
321 |
+
"""
|
322 |
+
When the length of tokens is a multiple of the capacity of the text encoder,
|
323 |
+
it should be split into chunks and sent to the text encoder individually.
|
324 |
+
"""
|
325 |
+
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
|
326 |
+
if max_embeddings_multiples > 1:
|
327 |
+
text_embeddings = []
|
328 |
+
for i in range(max_embeddings_multiples):
|
329 |
+
# extract the i-th chunk
|
330 |
+
text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
|
331 |
+
|
332 |
+
# cover the head and the tail by the starting and the ending tokens
|
333 |
+
text_input_chunk[:, 0] = text_input[0, 0]
|
334 |
+
if pad == eos: # v1
|
335 |
+
text_input_chunk[:, -1] = text_input[0, -1]
|
336 |
+
else: # v2
|
337 |
+
for j in range(len(text_input_chunk)):
|
338 |
+
if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
|
339 |
+
text_input_chunk[j, -1] = eos
|
340 |
+
if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
|
341 |
+
text_input_chunk[j, 1] = eos
|
342 |
+
|
343 |
+
if clip_skip is None or clip_skip == 1:
|
344 |
+
text_embedding = text_encoder(text_input_chunk)[0]
|
345 |
+
else:
|
346 |
+
enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
|
347 |
+
text_embedding = enc_out["hidden_states"][-clip_skip]
|
348 |
+
text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
|
349 |
+
|
350 |
+
if no_boseos_middle:
|
351 |
+
if i == 0:
|
352 |
+
# discard the ending token
|
353 |
+
text_embedding = text_embedding[:, :-1]
|
354 |
+
elif i == max_embeddings_multiples - 1:
|
355 |
+
# discard the starting token
|
356 |
+
text_embedding = text_embedding[:, 1:]
|
357 |
+
else:
|
358 |
+
# discard both starting and ending tokens
|
359 |
+
text_embedding = text_embedding[:, 1:-1]
|
360 |
+
|
361 |
+
text_embeddings.append(text_embedding)
|
362 |
+
text_embeddings = torch.concat(text_embeddings, axis=1)
|
363 |
+
else:
|
364 |
+
if clip_skip is None or clip_skip == 1:
|
365 |
+
text_embeddings = text_encoder(text_input)[0]
|
366 |
+
else:
|
367 |
+
enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
|
368 |
+
text_embeddings = enc_out["hidden_states"][-clip_skip]
|
369 |
+
text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
|
370 |
+
return text_embeddings
|
371 |
+
|
372 |
+
|
373 |
+
def get_weighted_text_embeddings(
|
374 |
+
tokenizer,
|
375 |
+
text_encoder,
|
376 |
+
prompt: Union[str, List[str]],
|
377 |
+
device,
|
378 |
+
max_embeddings_multiples: Optional[int] = 3,
|
379 |
+
no_boseos_middle: Optional[bool] = False,
|
380 |
+
clip_skip=None,
|
381 |
+
):
|
382 |
+
r"""
|
383 |
+
Prompts can be assigned with local weights using brackets. For example,
|
384 |
+
prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
|
385 |
+
and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
|
386 |
+
|
387 |
+
Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
|
388 |
+
|
389 |
+
Args:
|
390 |
+
prompt (`str` or `List[str]`):
|
391 |
+
The prompt or prompts to guide the image generation.
|
392 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
393 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
394 |
+
no_boseos_middle (`bool`, *optional*, defaults to `False`):
|
395 |
+
If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
|
396 |
+
ending token in each of the chunk in the middle.
|
397 |
+
skip_parsing (`bool`, *optional*, defaults to `False`):
|
398 |
+
Skip the parsing of brackets.
|
399 |
+
skip_weighting (`bool`, *optional*, defaults to `False`):
|
400 |
+
Skip the weighting. When the parsing is skipped, it is forced True.
|
401 |
+
"""
|
402 |
+
max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
403 |
+
if isinstance(prompt, str):
|
404 |
+
prompt = [prompt]
|
405 |
+
|
406 |
+
prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2)
|
407 |
+
|
408 |
+
# round up the longest length of tokens to a multiple of (model_max_length - 2)
|
409 |
+
max_length = max([len(token) for token in prompt_tokens])
|
410 |
+
|
411 |
+
max_embeddings_multiples = min(
|
412 |
+
max_embeddings_multiples,
|
413 |
+
(max_length - 1) // (tokenizer.model_max_length - 2) + 1,
|
414 |
+
)
|
415 |
+
max_embeddings_multiples = max(1, max_embeddings_multiples)
|
416 |
+
max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
417 |
+
|
418 |
+
# pad the length of tokens and weights
|
419 |
+
bos = tokenizer.bos_token_id
|
420 |
+
eos = tokenizer.eos_token_id
|
421 |
+
pad = tokenizer.pad_token_id
|
422 |
+
prompt_tokens, prompt_weights = pad_tokens_and_weights(
|
423 |
+
prompt_tokens,
|
424 |
+
prompt_weights,
|
425 |
+
max_length,
|
426 |
+
bos,
|
427 |
+
eos,
|
428 |
+
no_boseos_middle=no_boseos_middle,
|
429 |
+
chunk_length=tokenizer.model_max_length,
|
430 |
+
)
|
431 |
+
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device)
|
432 |
+
|
433 |
+
# get the embeddings
|
434 |
+
text_embeddings = get_unweighted_text_embeddings(
|
435 |
+
tokenizer,
|
436 |
+
text_encoder,
|
437 |
+
prompt_tokens,
|
438 |
+
tokenizer.model_max_length,
|
439 |
+
clip_skip,
|
440 |
+
eos,
|
441 |
+
pad,
|
442 |
+
no_boseos_middle=no_boseos_middle,
|
443 |
+
)
|
444 |
+
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device)
|
445 |
+
|
446 |
+
# assign weights to the prompts and normalize in the sense of mean
|
447 |
+
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
448 |
+
text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1)
|
449 |
+
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
450 |
+
text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
451 |
+
|
452 |
+
return text_embeddings
|
453 |
+
|
454 |
+
|
455 |
+
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
|
456 |
+
def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
|
457 |
+
b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
|
458 |
+
u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
|
459 |
+
for i in range(iterations):
|
460 |
+
r = random.random() * 2 + 2 # Rather than always going 2x,
|
461 |
+
wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
|
462 |
+
noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i
|
463 |
+
if wn == 1 or hn == 1:
|
464 |
+
break # Lowest resolution is 1x1
|
465 |
+
return noise / noise.std() # Scaled back to roughly unit variance
|
466 |
+
|
467 |
+
|
468 |
+
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
469 |
+
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
|
470 |
+
if noise_offset is None:
|
471 |
+
return noise
|
472 |
+
if adaptive_noise_scale is not None:
|
473 |
+
# latent shape: (batch_size, channels, height, width)
|
474 |
+
# abs mean value for each channel
|
475 |
+
latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True))
|
476 |
+
|
477 |
+
# multiply adaptive noise scale to the mean value and add it to the noise offset
|
478 |
+
noise_offset = noise_offset + adaptive_noise_scale * latent_mean
|
479 |
+
noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative
|
480 |
+
|
481 |
+
noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
482 |
+
return noise
|
483 |
+
|
484 |
+
|
485 |
+
def apply_masked_loss(loss, batch):
|
486 |
+
if "conditioning_images" in batch:
|
487 |
+
# conditioning image is -1 to 1. we need to convert it to 0 to 1
|
488 |
+
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
|
489 |
+
mask_image = mask_image / 2 + 0.5
|
490 |
+
# print(f"conditioning_image: {mask_image.shape}")
|
491 |
+
elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
|
492 |
+
# alpha mask is 0 to 1
|
493 |
+
mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension
|
494 |
+
# print(f"mask_image: {mask_image.shape}, {mask_image.mean()}")
|
495 |
+
else:
|
496 |
+
return loss
|
497 |
+
|
498 |
+
# resize to the same size as the loss
|
499 |
+
mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
|
500 |
+
loss = loss * mask_image
|
501 |
+
return loss
|
502 |
+
|
503 |
+
|
504 |
+
"""
|
505 |
+
##########################################
|
506 |
+
# Perlin Noise
|
507 |
+
def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
|
508 |
+
delta = (res[0] / shape[0], res[1] / shape[1])
|
509 |
+
d = (shape[0] // res[0], shape[1] // res[1])
|
510 |
+
|
511 |
+
grid = (
|
512 |
+
torch.stack(
|
513 |
+
torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
|
514 |
+
dim=-1,
|
515 |
+
)
|
516 |
+
% 1
|
517 |
+
)
|
518 |
+
angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
|
519 |
+
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
|
520 |
+
|
521 |
+
tile_grads = (
|
522 |
+
lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
|
523 |
+
.repeat_interleave(d[0], 0)
|
524 |
+
.repeat_interleave(d[1], 1)
|
525 |
+
)
|
526 |
+
dot = lambda grad, shift: (
|
527 |
+
torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
|
528 |
+
* grad[: shape[0], : shape[1]]
|
529 |
+
).sum(dim=-1)
|
530 |
+
|
531 |
+
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
|
532 |
+
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
|
533 |
+
n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
|
534 |
+
n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
|
535 |
+
t = fade(grid[: shape[0], : shape[1]])
|
536 |
+
return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
|
537 |
+
|
538 |
+
|
539 |
+
def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
|
540 |
+
noise = torch.zeros(shape, device=device)
|
541 |
+
frequency = 1
|
542 |
+
amplitude = 1
|
543 |
+
for _ in range(octaves):
|
544 |
+
noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
|
545 |
+
frequency *= 2
|
546 |
+
amplitude *= persistence
|
547 |
+
return noise
|
548 |
+
|
549 |
+
|
550 |
+
def perlin_noise(noise, device, octaves):
|
551 |
+
_, c, w, h = noise.shape
|
552 |
+
perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
|
553 |
+
noise_perlin = []
|
554 |
+
for _ in range(c):
|
555 |
+
noise_perlin.append(perlin())
|
556 |
+
noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h)
|
557 |
+
noise += noise_perlin # broadcast for each batch
|
558 |
+
return noise / noise.std() # Scaled back to roughly unit variance
|
559 |
+
"""
|
library/deepspeed_utils.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import torch
|
4 |
+
from accelerate import DeepSpeedPlugin, Accelerator
|
5 |
+
|
6 |
+
from .utils import setup_logging
|
7 |
+
|
8 |
+
setup_logging()
|
9 |
+
import logging
|
10 |
+
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
|
14 |
+
def add_deepspeed_arguments(parser: argparse.ArgumentParser):
|
15 |
+
# DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
|
16 |
+
parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
|
17 |
+
parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="Possible options are 0,1,2,3.")
|
18 |
+
parser.add_argument(
|
19 |
+
"--offload_optimizer_device",
|
20 |
+
type=str,
|
21 |
+
default=None,
|
22 |
+
choices=[None, "cpu", "nvme"],
|
23 |
+
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.",
|
24 |
+
)
|
25 |
+
parser.add_argument(
|
26 |
+
"--offload_optimizer_nvme_path",
|
27 |
+
type=str,
|
28 |
+
default=None,
|
29 |
+
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
|
30 |
+
)
|
31 |
+
parser.add_argument(
|
32 |
+
"--offload_param_device",
|
33 |
+
type=str,
|
34 |
+
default=None,
|
35 |
+
choices=[None, "cpu", "nvme"],
|
36 |
+
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.",
|
37 |
+
)
|
38 |
+
parser.add_argument(
|
39 |
+
"--offload_param_nvme_path",
|
40 |
+
type=str,
|
41 |
+
default=None,
|
42 |
+
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
|
43 |
+
)
|
44 |
+
parser.add_argument(
|
45 |
+
"--zero3_init_flag",
|
46 |
+
action="store_true",
|
47 |
+
help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models."
|
48 |
+
"Only applicable with ZeRO Stage-3.",
|
49 |
+
)
|
50 |
+
parser.add_argument(
|
51 |
+
"--zero3_save_16bit_model",
|
52 |
+
action="store_true",
|
53 |
+
help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.",
|
54 |
+
)
|
55 |
+
parser.add_argument(
|
56 |
+
"--fp16_master_weights_and_gradients",
|
57 |
+
action="store_true",
|
58 |
+
help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32.",
|
59 |
+
)
|
60 |
+
|
61 |
+
|
62 |
+
def prepare_deepspeed_args(args: argparse.Namespace):
|
63 |
+
if not args.deepspeed:
|
64 |
+
return
|
65 |
+
|
66 |
+
# To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
|
67 |
+
args.max_data_loader_n_workers = 1
|
68 |
+
|
69 |
+
|
70 |
+
def prepare_deepspeed_plugin(args: argparse.Namespace):
|
71 |
+
if not args.deepspeed:
|
72 |
+
return None
|
73 |
+
|
74 |
+
try:
|
75 |
+
import deepspeed
|
76 |
+
except ImportError as e:
|
77 |
+
logger.error(
|
78 |
+
"deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed"
|
79 |
+
)
|
80 |
+
exit(1)
|
81 |
+
|
82 |
+
deepspeed_plugin = DeepSpeedPlugin(
|
83 |
+
zero_stage=args.zero_stage,
|
84 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
85 |
+
gradient_clipping=args.max_grad_norm,
|
86 |
+
offload_optimizer_device=args.offload_optimizer_device,
|
87 |
+
offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
|
88 |
+
offload_param_device=args.offload_param_device,
|
89 |
+
offload_param_nvme_path=args.offload_param_nvme_path,
|
90 |
+
zero3_init_flag=args.zero3_init_flag,
|
91 |
+
zero3_save_16bit_model=args.zero3_save_16bit_model,
|
92 |
+
)
|
93 |
+
deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
|
94 |
+
deepspeed_plugin.deepspeed_config["train_batch_size"] = (
|
95 |
+
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
|
96 |
+
)
|
97 |
+
deepspeed_plugin.set_mixed_precision(args.mixed_precision)
|
98 |
+
if args.mixed_precision.lower() == "fp16":
|
99 |
+
deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
|
100 |
+
if args.full_fp16 or args.fp16_master_weights_and_gradients:
|
101 |
+
if args.offload_optimizer_device == "cpu" and args.zero_stage == 2:
|
102 |
+
deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True
|
103 |
+
logger.info("[DeepSpeed] full fp16 enable.")
|
104 |
+
else:
|
105 |
+
logger.info(
|
106 |
+
"[DeepSpeed]full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage."
|
107 |
+
)
|
108 |
+
|
109 |
+
if args.offload_optimizer_device is not None:
|
110 |
+
logger.info("[DeepSpeed] start to manually build cpu_adam.")
|
111 |
+
deepspeed.ops.op_builder.CPUAdamBuilder().load()
|
112 |
+
logger.info("[DeepSpeed] building cpu_adam done.")
|
113 |
+
|
114 |
+
return deepspeed_plugin
|
115 |
+
|
116 |
+
|
117 |
+
# Accelerate library does not support multiple models for deepspeed. So, we need to wrap multiple models into a single model.
|
118 |
+
def prepare_deepspeed_model(args: argparse.Namespace, **models):
|
119 |
+
# remove None from models
|
120 |
+
models = {k: v for k, v in models.items() if v is not None}
|
121 |
+
|
122 |
+
class DeepSpeedWrapper(torch.nn.Module):
|
123 |
+
def __init__(self, **kw_models) -> None:
|
124 |
+
super().__init__()
|
125 |
+
self.models = torch.nn.ModuleDict()
|
126 |
+
|
127 |
+
for key, model in kw_models.items():
|
128 |
+
if isinstance(model, list):
|
129 |
+
model = torch.nn.ModuleList(model)
|
130 |
+
assert isinstance(
|
131 |
+
model, torch.nn.Module
|
132 |
+
), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
|
133 |
+
self.models.update(torch.nn.ModuleDict({key: model}))
|
134 |
+
|
135 |
+
def get_models(self):
|
136 |
+
return self.models
|
137 |
+
|
138 |
+
ds_model = DeepSpeedWrapper(**models)
|
139 |
+
return ds_model
|
library/device_utils.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import gc
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
try:
|
7 |
+
HAS_CUDA = torch.cuda.is_available()
|
8 |
+
except Exception:
|
9 |
+
HAS_CUDA = False
|
10 |
+
|
11 |
+
try:
|
12 |
+
HAS_MPS = torch.backends.mps.is_available()
|
13 |
+
except Exception:
|
14 |
+
HAS_MPS = False
|
15 |
+
|
16 |
+
try:
|
17 |
+
import intel_extension_for_pytorch as ipex # noqa
|
18 |
+
|
19 |
+
HAS_XPU = torch.xpu.is_available()
|
20 |
+
except Exception:
|
21 |
+
HAS_XPU = False
|
22 |
+
|
23 |
+
|
24 |
+
def clean_memory():
|
25 |
+
gc.collect()
|
26 |
+
if HAS_CUDA:
|
27 |
+
torch.cuda.empty_cache()
|
28 |
+
if HAS_XPU:
|
29 |
+
torch.xpu.empty_cache()
|
30 |
+
if HAS_MPS:
|
31 |
+
torch.mps.empty_cache()
|
32 |
+
|
33 |
+
|
34 |
+
def clean_memory_on_device(device: torch.device):
|
35 |
+
r"""
|
36 |
+
Clean memory on the specified device, will be called from training scripts.
|
37 |
+
"""
|
38 |
+
gc.collect()
|
39 |
+
|
40 |
+
# device may "cuda" or "cuda:0", so we need to check the type of device
|
41 |
+
if device.type == "cuda":
|
42 |
+
torch.cuda.empty_cache()
|
43 |
+
if device.type == "xpu":
|
44 |
+
torch.xpu.empty_cache()
|
45 |
+
if device.type == "mps":
|
46 |
+
torch.mps.empty_cache()
|
47 |
+
|
48 |
+
|
49 |
+
@functools.lru_cache(maxsize=None)
|
50 |
+
def get_preferred_device() -> torch.device:
|
51 |
+
r"""
|
52 |
+
Do not call this function from training scripts. Use accelerator.device instead.
|
53 |
+
"""
|
54 |
+
if HAS_CUDA:
|
55 |
+
device = torch.device("cuda")
|
56 |
+
elif HAS_XPU:
|
57 |
+
device = torch.device("xpu")
|
58 |
+
elif HAS_MPS:
|
59 |
+
device = torch.device("mps")
|
60 |
+
else:
|
61 |
+
device = torch.device("cpu")
|
62 |
+
print(f"get_preferred_device() -> {device}")
|
63 |
+
return device
|
64 |
+
|
65 |
+
|
66 |
+
def init_ipex():
|
67 |
+
"""
|
68 |
+
Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`.
|
69 |
+
|
70 |
+
This function should run right after importing torch and before doing anything else.
|
71 |
+
|
72 |
+
If IPEX is not available, this function does nothing.
|
73 |
+
"""
|
74 |
+
try:
|
75 |
+
if HAS_XPU:
|
76 |
+
from library.ipex import ipex_init
|
77 |
+
|
78 |
+
is_initialized, error_message = ipex_init()
|
79 |
+
if not is_initialized:
|
80 |
+
print("failed to initialize ipex:", error_message)
|
81 |
+
else:
|
82 |
+
return
|
83 |
+
except Exception as e:
|
84 |
+
print("failed to initialize ipex:", e)
|
library/flux_models.py
ADDED
@@ -0,0 +1,1237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copy from FLUX repo: https://github.com/black-forest-labs/flux
|
2 |
+
# license: Apache-2.0 License
|
3 |
+
|
4 |
+
|
5 |
+
from concurrent.futures import Future, ThreadPoolExecutor
|
6 |
+
from dataclasses import dataclass
|
7 |
+
import math
|
8 |
+
import os
|
9 |
+
import time
|
10 |
+
from typing import Dict, List, Optional, Union
|
11 |
+
|
12 |
+
from library import utils
|
13 |
+
from library.device_utils import init_ipex, clean_memory_on_device
|
14 |
+
|
15 |
+
init_ipex()
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from einops import rearrange
|
19 |
+
from torch import Tensor, nn
|
20 |
+
from torch.utils.checkpoint import checkpoint
|
21 |
+
from library import custom_offloading_utils
|
22 |
+
|
23 |
+
# USE_REENTRANT = True
|
24 |
+
|
25 |
+
|
26 |
+
@dataclass
|
27 |
+
class FluxParams:
|
28 |
+
in_channels: int
|
29 |
+
vec_in_dim: int
|
30 |
+
context_in_dim: int
|
31 |
+
hidden_size: int
|
32 |
+
mlp_ratio: float
|
33 |
+
num_heads: int
|
34 |
+
depth: int
|
35 |
+
depth_single_blocks: int
|
36 |
+
axes_dim: list[int]
|
37 |
+
theta: int
|
38 |
+
qkv_bias: bool
|
39 |
+
guidance_embed: bool
|
40 |
+
|
41 |
+
|
42 |
+
# region autoencoder
|
43 |
+
|
44 |
+
|
45 |
+
@dataclass
|
46 |
+
class AutoEncoderParams:
|
47 |
+
resolution: int
|
48 |
+
in_channels: int
|
49 |
+
ch: int
|
50 |
+
out_ch: int
|
51 |
+
ch_mult: list[int]
|
52 |
+
num_res_blocks: int
|
53 |
+
z_channels: int
|
54 |
+
scale_factor: float
|
55 |
+
shift_factor: float
|
56 |
+
|
57 |
+
|
58 |
+
def swish(x: Tensor) -> Tensor:
|
59 |
+
return x * torch.sigmoid(x)
|
60 |
+
|
61 |
+
|
62 |
+
class AttnBlock(nn.Module):
|
63 |
+
def __init__(self, in_channels: int):
|
64 |
+
super().__init__()
|
65 |
+
self.in_channels = in_channels
|
66 |
+
|
67 |
+
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
68 |
+
|
69 |
+
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
70 |
+
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
71 |
+
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
72 |
+
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
73 |
+
|
74 |
+
def attention(self, h_: Tensor) -> Tensor:
|
75 |
+
h_ = self.norm(h_)
|
76 |
+
q = self.q(h_)
|
77 |
+
k = self.k(h_)
|
78 |
+
v = self.v(h_)
|
79 |
+
|
80 |
+
b, c, h, w = q.shape
|
81 |
+
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
|
82 |
+
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
|
83 |
+
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
|
84 |
+
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
|
85 |
+
|
86 |
+
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
87 |
+
|
88 |
+
def forward(self, x: Tensor) -> Tensor:
|
89 |
+
return x + self.proj_out(self.attention(x))
|
90 |
+
|
91 |
+
|
92 |
+
class ResnetBlock(nn.Module):
|
93 |
+
def __init__(self, in_channels: int, out_channels: int):
|
94 |
+
super().__init__()
|
95 |
+
self.in_channels = in_channels
|
96 |
+
out_channels = in_channels if out_channels is None else out_channels
|
97 |
+
self.out_channels = out_channels
|
98 |
+
|
99 |
+
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
100 |
+
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
101 |
+
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
|
102 |
+
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
103 |
+
if self.in_channels != self.out_channels:
|
104 |
+
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
105 |
+
|
106 |
+
def forward(self, x):
|
107 |
+
h = x
|
108 |
+
h = self.norm1(h)
|
109 |
+
h = swish(h)
|
110 |
+
h = self.conv1(h)
|
111 |
+
|
112 |
+
h = self.norm2(h)
|
113 |
+
h = swish(h)
|
114 |
+
h = self.conv2(h)
|
115 |
+
|
116 |
+
if self.in_channels != self.out_channels:
|
117 |
+
x = self.nin_shortcut(x)
|
118 |
+
|
119 |
+
return x + h
|
120 |
+
|
121 |
+
|
122 |
+
class Downsample(nn.Module):
|
123 |
+
def __init__(self, in_channels: int):
|
124 |
+
super().__init__()
|
125 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
126 |
+
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
127 |
+
|
128 |
+
def forward(self, x: Tensor):
|
129 |
+
pad = (0, 1, 0, 1)
|
130 |
+
x = nn.functional.pad(x, pad, mode="constant", value=0)
|
131 |
+
x = self.conv(x)
|
132 |
+
return x
|
133 |
+
|
134 |
+
|
135 |
+
class Upsample(nn.Module):
|
136 |
+
def __init__(self, in_channels: int):
|
137 |
+
super().__init__()
|
138 |
+
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
139 |
+
|
140 |
+
def forward(self, x: Tensor):
|
141 |
+
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
142 |
+
x = self.conv(x)
|
143 |
+
return x
|
144 |
+
|
145 |
+
|
146 |
+
class Encoder(nn.Module):
|
147 |
+
def __init__(
|
148 |
+
self,
|
149 |
+
resolution: int,
|
150 |
+
in_channels: int,
|
151 |
+
ch: int,
|
152 |
+
ch_mult: list[int],
|
153 |
+
num_res_blocks: int,
|
154 |
+
z_channels: int,
|
155 |
+
):
|
156 |
+
super().__init__()
|
157 |
+
self.ch = ch
|
158 |
+
self.num_resolutions = len(ch_mult)
|
159 |
+
self.num_res_blocks = num_res_blocks
|
160 |
+
self.resolution = resolution
|
161 |
+
self.in_channels = in_channels
|
162 |
+
# downsampling
|
163 |
+
self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
164 |
+
|
165 |
+
curr_res = resolution
|
166 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
167 |
+
self.in_ch_mult = in_ch_mult
|
168 |
+
self.down = nn.ModuleList()
|
169 |
+
block_in = self.ch
|
170 |
+
for i_level in range(self.num_resolutions):
|
171 |
+
block = nn.ModuleList()
|
172 |
+
attn = nn.ModuleList()
|
173 |
+
block_in = ch * in_ch_mult[i_level]
|
174 |
+
block_out = ch * ch_mult[i_level]
|
175 |
+
for _ in range(self.num_res_blocks):
|
176 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
177 |
+
block_in = block_out
|
178 |
+
down = nn.Module()
|
179 |
+
down.block = block
|
180 |
+
down.attn = attn
|
181 |
+
if i_level != self.num_resolutions - 1:
|
182 |
+
down.downsample = Downsample(block_in)
|
183 |
+
curr_res = curr_res // 2
|
184 |
+
self.down.append(down)
|
185 |
+
|
186 |
+
# middle
|
187 |
+
self.mid = nn.Module()
|
188 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
189 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
190 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
191 |
+
|
192 |
+
# end
|
193 |
+
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
194 |
+
self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
|
195 |
+
|
196 |
+
def forward(self, x: Tensor) -> Tensor:
|
197 |
+
# downsampling
|
198 |
+
hs = [self.conv_in(x)]
|
199 |
+
for i_level in range(self.num_resolutions):
|
200 |
+
for i_block in range(self.num_res_blocks):
|
201 |
+
h = self.down[i_level].block[i_block](hs[-1])
|
202 |
+
if len(self.down[i_level].attn) > 0:
|
203 |
+
h = self.down[i_level].attn[i_block](h)
|
204 |
+
hs.append(h)
|
205 |
+
if i_level != self.num_resolutions - 1:
|
206 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
207 |
+
|
208 |
+
# middle
|
209 |
+
h = hs[-1]
|
210 |
+
h = self.mid.block_1(h)
|
211 |
+
h = self.mid.attn_1(h)
|
212 |
+
h = self.mid.block_2(h)
|
213 |
+
# end
|
214 |
+
h = self.norm_out(h)
|
215 |
+
h = swish(h)
|
216 |
+
h = self.conv_out(h)
|
217 |
+
return h
|
218 |
+
|
219 |
+
|
220 |
+
class Decoder(nn.Module):
|
221 |
+
def __init__(
|
222 |
+
self,
|
223 |
+
ch: int,
|
224 |
+
out_ch: int,
|
225 |
+
ch_mult: list[int],
|
226 |
+
num_res_blocks: int,
|
227 |
+
in_channels: int,
|
228 |
+
resolution: int,
|
229 |
+
z_channels: int,
|
230 |
+
):
|
231 |
+
super().__init__()
|
232 |
+
self.ch = ch
|
233 |
+
self.num_resolutions = len(ch_mult)
|
234 |
+
self.num_res_blocks = num_res_blocks
|
235 |
+
self.resolution = resolution
|
236 |
+
self.in_channels = in_channels
|
237 |
+
self.ffactor = 2 ** (self.num_resolutions - 1)
|
238 |
+
|
239 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
240 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
241 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
242 |
+
self.z_shape = (1, z_channels, curr_res, curr_res)
|
243 |
+
|
244 |
+
# z to block_in
|
245 |
+
self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
246 |
+
|
247 |
+
# middle
|
248 |
+
self.mid = nn.Module()
|
249 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
250 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
251 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
252 |
+
|
253 |
+
# upsampling
|
254 |
+
self.up = nn.ModuleList()
|
255 |
+
for i_level in reversed(range(self.num_resolutions)):
|
256 |
+
block = nn.ModuleList()
|
257 |
+
attn = nn.ModuleList()
|
258 |
+
block_out = ch * ch_mult[i_level]
|
259 |
+
for _ in range(self.num_res_blocks + 1):
|
260 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
261 |
+
block_in = block_out
|
262 |
+
up = nn.Module()
|
263 |
+
up.block = block
|
264 |
+
up.attn = attn
|
265 |
+
if i_level != 0:
|
266 |
+
up.upsample = Upsample(block_in)
|
267 |
+
curr_res = curr_res * 2
|
268 |
+
self.up.insert(0, up) # prepend to get consistent order
|
269 |
+
|
270 |
+
# end
|
271 |
+
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
272 |
+
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
273 |
+
|
274 |
+
def forward(self, z: Tensor) -> Tensor:
|
275 |
+
# z to block_in
|
276 |
+
h = self.conv_in(z)
|
277 |
+
|
278 |
+
# middle
|
279 |
+
h = self.mid.block_1(h)
|
280 |
+
h = self.mid.attn_1(h)
|
281 |
+
h = self.mid.block_2(h)
|
282 |
+
|
283 |
+
# upsampling
|
284 |
+
for i_level in reversed(range(self.num_resolutions)):
|
285 |
+
for i_block in range(self.num_res_blocks + 1):
|
286 |
+
h = self.up[i_level].block[i_block](h)
|
287 |
+
if len(self.up[i_level].attn) > 0:
|
288 |
+
h = self.up[i_level].attn[i_block](h)
|
289 |
+
if i_level != 0:
|
290 |
+
h = self.up[i_level].upsample(h)
|
291 |
+
|
292 |
+
# end
|
293 |
+
h = self.norm_out(h)
|
294 |
+
h = swish(h)
|
295 |
+
h = self.conv_out(h)
|
296 |
+
return h
|
297 |
+
|
298 |
+
|
299 |
+
class DiagonalGaussian(nn.Module):
|
300 |
+
def __init__(self, sample: bool = True, chunk_dim: int = 1):
|
301 |
+
super().__init__()
|
302 |
+
self.sample = sample
|
303 |
+
self.chunk_dim = chunk_dim
|
304 |
+
|
305 |
+
def forward(self, z: Tensor) -> Tensor:
|
306 |
+
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
|
307 |
+
if self.sample:
|
308 |
+
std = torch.exp(0.5 * logvar)
|
309 |
+
return mean + std * torch.randn_like(mean)
|
310 |
+
else:
|
311 |
+
return mean
|
312 |
+
|
313 |
+
|
314 |
+
class AutoEncoder(nn.Module):
|
315 |
+
def __init__(self, params: AutoEncoderParams):
|
316 |
+
super().__init__()
|
317 |
+
self.encoder = Encoder(
|
318 |
+
resolution=params.resolution,
|
319 |
+
in_channels=params.in_channels,
|
320 |
+
ch=params.ch,
|
321 |
+
ch_mult=params.ch_mult,
|
322 |
+
num_res_blocks=params.num_res_blocks,
|
323 |
+
z_channels=params.z_channels,
|
324 |
+
)
|
325 |
+
self.decoder = Decoder(
|
326 |
+
resolution=params.resolution,
|
327 |
+
in_channels=params.in_channels,
|
328 |
+
ch=params.ch,
|
329 |
+
out_ch=params.out_ch,
|
330 |
+
ch_mult=params.ch_mult,
|
331 |
+
num_res_blocks=params.num_res_blocks,
|
332 |
+
z_channels=params.z_channels,
|
333 |
+
)
|
334 |
+
self.reg = DiagonalGaussian()
|
335 |
+
|
336 |
+
self.scale_factor = params.scale_factor
|
337 |
+
self.shift_factor = params.shift_factor
|
338 |
+
|
339 |
+
@property
|
340 |
+
def device(self) -> torch.device:
|
341 |
+
return next(self.parameters()).device
|
342 |
+
|
343 |
+
@property
|
344 |
+
def dtype(self) -> torch.dtype:
|
345 |
+
return next(self.parameters()).dtype
|
346 |
+
|
347 |
+
def encode(self, x: Tensor) -> Tensor:
|
348 |
+
z = self.reg(self.encoder(x))
|
349 |
+
z = self.scale_factor * (z - self.shift_factor)
|
350 |
+
return z
|
351 |
+
|
352 |
+
def decode(self, z: Tensor) -> Tensor:
|
353 |
+
z = z / self.scale_factor + self.shift_factor
|
354 |
+
return self.decoder(z)
|
355 |
+
|
356 |
+
def forward(self, x: Tensor) -> Tensor:
|
357 |
+
return self.decode(self.encode(x))
|
358 |
+
|
359 |
+
|
360 |
+
# endregion
|
361 |
+
# region config
|
362 |
+
|
363 |
+
|
364 |
+
@dataclass
|
365 |
+
class ModelSpec:
|
366 |
+
params: FluxParams
|
367 |
+
ae_params: AutoEncoderParams
|
368 |
+
ckpt_path: str | None
|
369 |
+
ae_path: str | None
|
370 |
+
# repo_id: str | None
|
371 |
+
# repo_flow: str | None
|
372 |
+
# repo_ae: str | None
|
373 |
+
|
374 |
+
|
375 |
+
configs = {
|
376 |
+
"dev": ModelSpec(
|
377 |
+
# repo_id="black-forest-labs/FLUX.1-dev",
|
378 |
+
# repo_flow="flux1-dev.sft",
|
379 |
+
# repo_ae="ae.sft",
|
380 |
+
ckpt_path=None, # os.getenv("FLUX_DEV"),
|
381 |
+
params=FluxParams(
|
382 |
+
in_channels=64,
|
383 |
+
vec_in_dim=768,
|
384 |
+
context_in_dim=4096,
|
385 |
+
hidden_size=3072,
|
386 |
+
mlp_ratio=4.0,
|
387 |
+
num_heads=24,
|
388 |
+
depth=19,
|
389 |
+
depth_single_blocks=38,
|
390 |
+
axes_dim=[16, 56, 56],
|
391 |
+
theta=10_000,
|
392 |
+
qkv_bias=True,
|
393 |
+
guidance_embed=True,
|
394 |
+
),
|
395 |
+
ae_path=None, # os.getenv("AE"),
|
396 |
+
ae_params=AutoEncoderParams(
|
397 |
+
resolution=256,
|
398 |
+
in_channels=3,
|
399 |
+
ch=128,
|
400 |
+
out_ch=3,
|
401 |
+
ch_mult=[1, 2, 4, 4],
|
402 |
+
num_res_blocks=2,
|
403 |
+
z_channels=16,
|
404 |
+
scale_factor=0.3611,
|
405 |
+
shift_factor=0.1159,
|
406 |
+
),
|
407 |
+
),
|
408 |
+
"schnell": ModelSpec(
|
409 |
+
# repo_id="black-forest-labs/FLUX.1-schnell",
|
410 |
+
# repo_flow="flux1-schnell.sft",
|
411 |
+
# repo_ae="ae.sft",
|
412 |
+
ckpt_path=None, # os.getenv("FLUX_SCHNELL"),
|
413 |
+
params=FluxParams(
|
414 |
+
in_channels=64,
|
415 |
+
vec_in_dim=768,
|
416 |
+
context_in_dim=4096,
|
417 |
+
hidden_size=3072,
|
418 |
+
mlp_ratio=4.0,
|
419 |
+
num_heads=24,
|
420 |
+
depth=19,
|
421 |
+
depth_single_blocks=38,
|
422 |
+
axes_dim=[16, 56, 56],
|
423 |
+
theta=10_000,
|
424 |
+
qkv_bias=True,
|
425 |
+
guidance_embed=False,
|
426 |
+
),
|
427 |
+
ae_path=None, # os.getenv("AE"),
|
428 |
+
ae_params=AutoEncoderParams(
|
429 |
+
resolution=256,
|
430 |
+
in_channels=3,
|
431 |
+
ch=128,
|
432 |
+
out_ch=3,
|
433 |
+
ch_mult=[1, 2, 4, 4],
|
434 |
+
num_res_blocks=2,
|
435 |
+
z_channels=16,
|
436 |
+
scale_factor=0.3611,
|
437 |
+
shift_factor=0.1159,
|
438 |
+
),
|
439 |
+
),
|
440 |
+
}
|
441 |
+
|
442 |
+
|
443 |
+
# endregion
|
444 |
+
|
445 |
+
# region math
|
446 |
+
|
447 |
+
|
448 |
+
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
|
449 |
+
q, k = apply_rope(q, k, pe)
|
450 |
+
|
451 |
+
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
452 |
+
x = rearrange(x, "B H L D -> B L (H D)")
|
453 |
+
|
454 |
+
return x
|
455 |
+
|
456 |
+
|
457 |
+
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
458 |
+
assert dim % 2 == 0
|
459 |
+
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
460 |
+
omega = 1.0 / (theta**scale)
|
461 |
+
out = torch.einsum("...n,d->...nd", pos, omega)
|
462 |
+
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
463 |
+
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
464 |
+
return out.float()
|
465 |
+
|
466 |
+
|
467 |
+
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
|
468 |
+
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
469 |
+
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
470 |
+
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
471 |
+
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
472 |
+
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
473 |
+
|
474 |
+
|
475 |
+
# endregion
|
476 |
+
|
477 |
+
|
478 |
+
# region layers
|
479 |
+
|
480 |
+
|
481 |
+
# for cpu_offload_checkpointing
|
482 |
+
|
483 |
+
|
484 |
+
def to_cuda(x):
|
485 |
+
if isinstance(x, torch.Tensor):
|
486 |
+
return x.cuda()
|
487 |
+
elif isinstance(x, (list, tuple)):
|
488 |
+
return [to_cuda(elem) for elem in x]
|
489 |
+
elif isinstance(x, dict):
|
490 |
+
return {k: to_cuda(v) for k, v in x.items()}
|
491 |
+
else:
|
492 |
+
return x
|
493 |
+
|
494 |
+
|
495 |
+
def to_cpu(x):
|
496 |
+
if isinstance(x, torch.Tensor):
|
497 |
+
return x.cpu()
|
498 |
+
elif isinstance(x, (list, tuple)):
|
499 |
+
return [to_cpu(elem) for elem in x]
|
500 |
+
elif isinstance(x, dict):
|
501 |
+
return {k: to_cpu(v) for k, v in x.items()}
|
502 |
+
else:
|
503 |
+
return x
|
504 |
+
|
505 |
+
|
506 |
+
class EmbedND(nn.Module):
|
507 |
+
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
|
508 |
+
super().__init__()
|
509 |
+
self.dim = dim
|
510 |
+
self.theta = theta
|
511 |
+
self.axes_dim = axes_dim
|
512 |
+
|
513 |
+
def forward(self, ids: Tensor) -> Tensor:
|
514 |
+
n_axes = ids.shape[-1]
|
515 |
+
emb = torch.cat(
|
516 |
+
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
517 |
+
dim=-3,
|
518 |
+
)
|
519 |
+
|
520 |
+
return emb.unsqueeze(1)
|
521 |
+
|
522 |
+
|
523 |
+
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
|
524 |
+
"""
|
525 |
+
Create sinusoidal timestep embeddings.
|
526 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
527 |
+
These may be fractional.
|
528 |
+
:param dim: the dimension of the output.
|
529 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
530 |
+
:return: an (N, D) Tensor of positional embeddings.
|
531 |
+
"""
|
532 |
+
t = time_factor * t
|
533 |
+
half = dim // 2
|
534 |
+
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
|
535 |
+
|
536 |
+
args = t[:, None].float() * freqs[None]
|
537 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
538 |
+
if dim % 2:
|
539 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
540 |
+
if torch.is_floating_point(t):
|
541 |
+
embedding = embedding.to(t)
|
542 |
+
return embedding
|
543 |
+
|
544 |
+
|
545 |
+
class MLPEmbedder(nn.Module):
|
546 |
+
def __init__(self, in_dim: int, hidden_dim: int):
|
547 |
+
super().__init__()
|
548 |
+
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
|
549 |
+
self.silu = nn.SiLU()
|
550 |
+
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
551 |
+
|
552 |
+
self.gradient_checkpointing = False
|
553 |
+
|
554 |
+
def enable_gradient_checkpointing(self):
|
555 |
+
self.gradient_checkpointing = True
|
556 |
+
|
557 |
+
def disable_gradient_checkpointing(self):
|
558 |
+
self.gradient_checkpointing = False
|
559 |
+
|
560 |
+
def _forward(self, x: Tensor) -> Tensor:
|
561 |
+
return self.out_layer(self.silu(self.in_layer(x)))
|
562 |
+
|
563 |
+
def forward(self, *args, **kwargs):
|
564 |
+
if self.training and self.gradient_checkpointing:
|
565 |
+
return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
|
566 |
+
else:
|
567 |
+
return self._forward(*args, **kwargs)
|
568 |
+
|
569 |
+
# def forward(self, x):
|
570 |
+
# if self.training and self.gradient_checkpointing:
|
571 |
+
# def create_custom_forward(func):
|
572 |
+
# def custom_forward(*inputs):
|
573 |
+
# return func(*inputs)
|
574 |
+
# return custom_forward
|
575 |
+
# return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, use_reentrant=USE_REENTRANT)
|
576 |
+
# else:
|
577 |
+
# return self._forward(x)
|
578 |
+
|
579 |
+
|
580 |
+
class RMSNorm(torch.nn.Module):
|
581 |
+
def __init__(self, dim: int):
|
582 |
+
super().__init__()
|
583 |
+
self.scale = nn.Parameter(torch.ones(dim))
|
584 |
+
|
585 |
+
def forward(self, x: Tensor):
|
586 |
+
x_dtype = x.dtype
|
587 |
+
x = x.float()
|
588 |
+
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
589 |
+
# return (x * rrms).to(dtype=x_dtype) * self.scale
|
590 |
+
return ((x * rrms) * self.scale.float()).to(dtype=x_dtype)
|
591 |
+
|
592 |
+
|
593 |
+
class QKNorm(torch.nn.Module):
|
594 |
+
def __init__(self, dim: int):
|
595 |
+
super().__init__()
|
596 |
+
self.query_norm = RMSNorm(dim)
|
597 |
+
self.key_norm = RMSNorm(dim)
|
598 |
+
|
599 |
+
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
|
600 |
+
q = self.query_norm(q)
|
601 |
+
k = self.key_norm(k)
|
602 |
+
return q.to(v), k.to(v)
|
603 |
+
|
604 |
+
|
605 |
+
class SelfAttention(nn.Module):
|
606 |
+
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
|
607 |
+
super().__init__()
|
608 |
+
self.num_heads = num_heads
|
609 |
+
head_dim = dim // num_heads
|
610 |
+
|
611 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
612 |
+
self.norm = QKNorm(head_dim)
|
613 |
+
self.proj = nn.Linear(dim, dim)
|
614 |
+
|
615 |
+
# this is not called from DoubleStreamBlock/SingleStreamBlock because they uses attention function directly
|
616 |
+
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
|
617 |
+
qkv = self.qkv(x)
|
618 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
619 |
+
q, k = self.norm(q, k, v)
|
620 |
+
x = attention(q, k, v, pe=pe)
|
621 |
+
x = self.proj(x)
|
622 |
+
return x
|
623 |
+
|
624 |
+
|
625 |
+
@dataclass
|
626 |
+
class ModulationOut:
|
627 |
+
shift: Tensor
|
628 |
+
scale: Tensor
|
629 |
+
gate: Tensor
|
630 |
+
|
631 |
+
|
632 |
+
class Modulation(nn.Module):
|
633 |
+
def __init__(self, dim: int, double: bool):
|
634 |
+
super().__init__()
|
635 |
+
self.is_double = double
|
636 |
+
self.multiplier = 6 if double else 3
|
637 |
+
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
|
638 |
+
|
639 |
+
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
|
640 |
+
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
|
641 |
+
|
642 |
+
return (
|
643 |
+
ModulationOut(*out[:3]),
|
644 |
+
ModulationOut(*out[3:]) if self.is_double else None,
|
645 |
+
)
|
646 |
+
|
647 |
+
|
648 |
+
class DoubleStreamBlock(nn.Module):
|
649 |
+
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
|
650 |
+
super().__init__()
|
651 |
+
|
652 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
653 |
+
self.num_heads = num_heads
|
654 |
+
self.hidden_size = hidden_size
|
655 |
+
self.img_mod = Modulation(hidden_size, double=True)
|
656 |
+
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
657 |
+
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
658 |
+
|
659 |
+
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
660 |
+
self.img_mlp = nn.Sequential(
|
661 |
+
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
662 |
+
nn.GELU(approximate="tanh"),
|
663 |
+
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
664 |
+
)
|
665 |
+
|
666 |
+
self.txt_mod = Modulation(hidden_size, double=True)
|
667 |
+
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
668 |
+
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
669 |
+
|
670 |
+
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
671 |
+
self.txt_mlp = nn.Sequential(
|
672 |
+
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
673 |
+
nn.GELU(approximate="tanh"),
|
674 |
+
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
675 |
+
)
|
676 |
+
|
677 |
+
self.gradient_checkpointing = False
|
678 |
+
self.cpu_offload_checkpointing = False
|
679 |
+
|
680 |
+
def enable_gradient_checkpointing(self, cpu_offload: bool = False):
|
681 |
+
self.gradient_checkpointing = True
|
682 |
+
self.cpu_offload_checkpointing = cpu_offload
|
683 |
+
|
684 |
+
def disable_gradient_checkpointing(self):
|
685 |
+
self.gradient_checkpointing = False
|
686 |
+
self.cpu_offload_checkpointing = False
|
687 |
+
|
688 |
+
def _forward(
|
689 |
+
self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None
|
690 |
+
) -> tuple[Tensor, Tensor]:
|
691 |
+
img_mod1, img_mod2 = self.img_mod(vec)
|
692 |
+
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
693 |
+
|
694 |
+
# prepare image for attention
|
695 |
+
img_modulated = self.img_norm1(img)
|
696 |
+
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
697 |
+
img_qkv = self.img_attn.qkv(img_modulated)
|
698 |
+
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
699 |
+
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
700 |
+
|
701 |
+
# prepare txt for attention
|
702 |
+
txt_modulated = self.txt_norm1(txt)
|
703 |
+
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
704 |
+
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
705 |
+
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
706 |
+
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
707 |
+
|
708 |
+
# run actual attention
|
709 |
+
q = torch.cat((txt_q, img_q), dim=2)
|
710 |
+
k = torch.cat((txt_k, img_k), dim=2)
|
711 |
+
v = torch.cat((txt_v, img_v), dim=2)
|
712 |
+
|
713 |
+
# make attention mask if not None
|
714 |
+
attn_mask = None
|
715 |
+
if txt_attention_mask is not None:
|
716 |
+
# F.scaled_dot_product_attention expects attn_mask to be bool for binary mask
|
717 |
+
attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len
|
718 |
+
attn_mask = torch.cat(
|
719 |
+
(attn_mask, torch.ones(attn_mask.shape[0], img.shape[1], device=attn_mask.device, dtype=torch.bool)), dim=1
|
720 |
+
) # b, seq_len + img_len
|
721 |
+
|
722 |
+
# broadcast attn_mask to all heads
|
723 |
+
attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1)
|
724 |
+
|
725 |
+
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
|
726 |
+
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
727 |
+
|
728 |
+
# calculate the img blocks
|
729 |
+
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
730 |
+
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
731 |
+
|
732 |
+
# calculate the txt blocks
|
733 |
+
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
734 |
+
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
735 |
+
return img, txt
|
736 |
+
|
737 |
+
def forward(
|
738 |
+
self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None
|
739 |
+
) -> tuple[Tensor, Tensor]:
|
740 |
+
if self.training and self.gradient_checkpointing:
|
741 |
+
if not self.cpu_offload_checkpointing:
|
742 |
+
return checkpoint(self._forward, img, txt, vec, pe, txt_attention_mask, use_reentrant=False)
|
743 |
+
# cpu offload checkpointing
|
744 |
+
|
745 |
+
def create_custom_forward(func):
|
746 |
+
def custom_forward(*inputs):
|
747 |
+
cuda_inputs = to_cuda(inputs)
|
748 |
+
outputs = func(*cuda_inputs)
|
749 |
+
return to_cpu(outputs)
|
750 |
+
|
751 |
+
return custom_forward
|
752 |
+
|
753 |
+
return torch.utils.checkpoint.checkpoint(
|
754 |
+
create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask, use_reentrant=False
|
755 |
+
)
|
756 |
+
|
757 |
+
else:
|
758 |
+
return self._forward(img, txt, vec, pe, txt_attention_mask)
|
759 |
+
|
760 |
+
|
761 |
+
class SingleStreamBlock(nn.Module):
|
762 |
+
"""
|
763 |
+
A DiT block with parallel linear layers as described in
|
764 |
+
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
765 |
+
"""
|
766 |
+
|
767 |
+
def __init__(
|
768 |
+
self,
|
769 |
+
hidden_size: int,
|
770 |
+
num_heads: int,
|
771 |
+
mlp_ratio: float = 4.0,
|
772 |
+
qk_scale: float | None = None,
|
773 |
+
):
|
774 |
+
super().__init__()
|
775 |
+
self.hidden_dim = hidden_size
|
776 |
+
self.num_heads = num_heads
|
777 |
+
head_dim = hidden_size // num_heads
|
778 |
+
self.scale = qk_scale or head_dim**-0.5
|
779 |
+
|
780 |
+
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
781 |
+
# qkv and mlp_in
|
782 |
+
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
783 |
+
# proj and mlp_out
|
784 |
+
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
785 |
+
|
786 |
+
self.norm = QKNorm(head_dim)
|
787 |
+
|
788 |
+
self.hidden_size = hidden_size
|
789 |
+
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
790 |
+
|
791 |
+
self.mlp_act = nn.GELU(approximate="tanh")
|
792 |
+
self.modulation = Modulation(hidden_size, double=False)
|
793 |
+
|
794 |
+
self.gradient_checkpointing = False
|
795 |
+
self.cpu_offload_checkpointing = False
|
796 |
+
|
797 |
+
def enable_gradient_checkpointing(self, cpu_offload: bool = False):
|
798 |
+
self.gradient_checkpointing = True
|
799 |
+
self.cpu_offload_checkpointing = cpu_offload
|
800 |
+
|
801 |
+
def disable_gradient_checkpointing(self):
|
802 |
+
self.gradient_checkpointing = False
|
803 |
+
self.cpu_offload_checkpointing = False
|
804 |
+
|
805 |
+
def _forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
|
806 |
+
mod, _ = self.modulation(vec)
|
807 |
+
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
808 |
+
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
809 |
+
|
810 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
811 |
+
q, k = self.norm(q, k, v)
|
812 |
+
|
813 |
+
# make attention mask if not None
|
814 |
+
attn_mask = None
|
815 |
+
if txt_attention_mask is not None:
|
816 |
+
# F.scaled_dot_product_attention expects attn_mask to be bool for binary mask
|
817 |
+
attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len
|
818 |
+
attn_mask = torch.cat(
|
819 |
+
(
|
820 |
+
attn_mask,
|
821 |
+
torch.ones(
|
822 |
+
attn_mask.shape[0], x.shape[1] - txt_attention_mask.shape[1], device=attn_mask.device, dtype=torch.bool
|
823 |
+
),
|
824 |
+
),
|
825 |
+
dim=1,
|
826 |
+
) # b, seq_len + img_len = x_len
|
827 |
+
|
828 |
+
# broadcast attn_mask to all heads
|
829 |
+
attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1)
|
830 |
+
|
831 |
+
# compute attention
|
832 |
+
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
|
833 |
+
|
834 |
+
# compute activation in mlp stream, cat again and run second linear layer
|
835 |
+
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
836 |
+
return x + mod.gate * output
|
837 |
+
|
838 |
+
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
|
839 |
+
if self.training and self.gradient_checkpointing:
|
840 |
+
if not self.cpu_offload_checkpointing:
|
841 |
+
return checkpoint(self._forward, x, vec, pe, txt_attention_mask, use_reentrant=False)
|
842 |
+
|
843 |
+
# cpu offload checkpointing
|
844 |
+
|
845 |
+
def create_custom_forward(func):
|
846 |
+
def custom_forward(*inputs):
|
847 |
+
cuda_inputs = to_cuda(inputs)
|
848 |
+
outputs = func(*cuda_inputs)
|
849 |
+
return to_cpu(outputs)
|
850 |
+
|
851 |
+
return custom_forward
|
852 |
+
|
853 |
+
return torch.utils.checkpoint.checkpoint(
|
854 |
+
create_custom_forward(self._forward), x, vec, pe, txt_attention_mask, use_reentrant=False
|
855 |
+
)
|
856 |
+
else:
|
857 |
+
return self._forward(x, vec, pe, txt_attention_mask)
|
858 |
+
|
859 |
+
|
860 |
+
class LastLayer(nn.Module):
|
861 |
+
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
|
862 |
+
super().__init__()
|
863 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
864 |
+
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
865 |
+
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
|
866 |
+
|
867 |
+
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
868 |
+
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
869 |
+
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
870 |
+
x = self.linear(x)
|
871 |
+
return x
|
872 |
+
|
873 |
+
|
874 |
+
# endregion
|
875 |
+
|
876 |
+
|
877 |
+
class Flux(nn.Module):
|
878 |
+
"""
|
879 |
+
Transformer model for flow matching on sequences.
|
880 |
+
"""
|
881 |
+
|
882 |
+
def __init__(self, params: FluxParams):
|
883 |
+
super().__init__()
|
884 |
+
|
885 |
+
self.params = params
|
886 |
+
self.in_channels = params.in_channels
|
887 |
+
self.out_channels = self.in_channels
|
888 |
+
if params.hidden_size % params.num_heads != 0:
|
889 |
+
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
|
890 |
+
pe_dim = params.hidden_size // params.num_heads
|
891 |
+
if sum(params.axes_dim) != pe_dim:
|
892 |
+
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
893 |
+
self.hidden_size = params.hidden_size
|
894 |
+
self.num_heads = params.num_heads
|
895 |
+
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
896 |
+
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
897 |
+
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
898 |
+
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
|
899 |
+
self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
|
900 |
+
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
|
901 |
+
|
902 |
+
self.double_blocks = nn.ModuleList(
|
903 |
+
[
|
904 |
+
DoubleStreamBlock(
|
905 |
+
self.hidden_size,
|
906 |
+
self.num_heads,
|
907 |
+
mlp_ratio=params.mlp_ratio,
|
908 |
+
qkv_bias=params.qkv_bias,
|
909 |
+
)
|
910 |
+
for _ in range(params.depth)
|
911 |
+
]
|
912 |
+
)
|
913 |
+
|
914 |
+
self.single_blocks = nn.ModuleList(
|
915 |
+
[
|
916 |
+
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
|
917 |
+
for _ in range(params.depth_single_blocks)
|
918 |
+
]
|
919 |
+
)
|
920 |
+
|
921 |
+
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
|
922 |
+
|
923 |
+
self.gradient_checkpointing = False
|
924 |
+
self.cpu_offload_checkpointing = False
|
925 |
+
self.blocks_to_swap = None
|
926 |
+
|
927 |
+
self.offloader_double = None
|
928 |
+
self.offloader_single = None
|
929 |
+
self.num_double_blocks = len(self.double_blocks)
|
930 |
+
self.num_single_blocks = len(self.single_blocks)
|
931 |
+
|
932 |
+
@property
|
933 |
+
def device(self):
|
934 |
+
return next(self.parameters()).device
|
935 |
+
|
936 |
+
@property
|
937 |
+
def dtype(self):
|
938 |
+
return next(self.parameters()).dtype
|
939 |
+
|
940 |
+
def enable_gradient_checkpointing(self, cpu_offload: bool = False):
|
941 |
+
self.gradient_checkpointing = True
|
942 |
+
self.cpu_offload_checkpointing = cpu_offload
|
943 |
+
|
944 |
+
self.time_in.enable_gradient_checkpointing()
|
945 |
+
self.vector_in.enable_gradient_checkpointing()
|
946 |
+
if self.guidance_in.__class__ != nn.Identity:
|
947 |
+
self.guidance_in.enable_gradient_checkpointing()
|
948 |
+
|
949 |
+
for block in self.double_blocks + self.single_blocks:
|
950 |
+
block.enable_gradient_checkpointing(cpu_offload=cpu_offload)
|
951 |
+
|
952 |
+
print(f"FLUX: Gradient checkpointing enabled. CPU offload: {cpu_offload}")
|
953 |
+
|
954 |
+
def disable_gradient_checkpointing(self):
|
955 |
+
self.gradient_checkpointing = False
|
956 |
+
self.cpu_offload_checkpointing = False
|
957 |
+
|
958 |
+
self.time_in.disable_gradient_checkpointing()
|
959 |
+
self.vector_in.disable_gradient_checkpointing()
|
960 |
+
if self.guidance_in.__class__ != nn.Identity:
|
961 |
+
self.guidance_in.disable_gradient_checkpointing()
|
962 |
+
|
963 |
+
for block in self.double_blocks + self.single_blocks:
|
964 |
+
block.disable_gradient_checkpointing()
|
965 |
+
|
966 |
+
print("FLUX: Gradient checkpointing disabled.")
|
967 |
+
|
968 |
+
def enable_block_swap(self, num_blocks: int, device: torch.device):
|
969 |
+
self.blocks_to_swap = num_blocks
|
970 |
+
double_blocks_to_swap = num_blocks // 2
|
971 |
+
single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2
|
972 |
+
|
973 |
+
assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, (
|
974 |
+
f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. "
|
975 |
+
f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks."
|
976 |
+
)
|
977 |
+
|
978 |
+
self.offloader_double = custom_offloading_utils.ModelOffloader(
|
979 |
+
self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True
|
980 |
+
)
|
981 |
+
self.offloader_single = custom_offloading_utils.ModelOffloader(
|
982 |
+
self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True
|
983 |
+
)
|
984 |
+
print(
|
985 |
+
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
|
986 |
+
)
|
987 |
+
|
988 |
+
def move_to_device_except_swap_blocks(self, device: torch.device):
|
989 |
+
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
|
990 |
+
if self.blocks_to_swap:
|
991 |
+
save_double_blocks = self.double_blocks
|
992 |
+
save_single_blocks = self.single_blocks
|
993 |
+
self.double_blocks = None
|
994 |
+
self.single_blocks = None
|
995 |
+
|
996 |
+
self.to(device)
|
997 |
+
|
998 |
+
if self.blocks_to_swap:
|
999 |
+
self.double_blocks = save_double_blocks
|
1000 |
+
self.single_blocks = save_single_blocks
|
1001 |
+
|
1002 |
+
def prepare_block_swap_before_forward(self):
|
1003 |
+
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
1004 |
+
return
|
1005 |
+
self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
|
1006 |
+
self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
|
1007 |
+
|
1008 |
+
def forward(
|
1009 |
+
self,
|
1010 |
+
img: Tensor,
|
1011 |
+
img_ids: Tensor,
|
1012 |
+
txt: Tensor,
|
1013 |
+
txt_ids: Tensor,
|
1014 |
+
timesteps: Tensor,
|
1015 |
+
y: Tensor,
|
1016 |
+
guidance: Tensor | None = None,
|
1017 |
+
txt_attention_mask: Tensor | None = None,
|
1018 |
+
) -> Tensor:
|
1019 |
+
if img.ndim != 3 or txt.ndim != 3:
|
1020 |
+
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
1021 |
+
|
1022 |
+
# running on sequences img
|
1023 |
+
img = self.img_in(img)
|
1024 |
+
vec = self.time_in(timestep_embedding(timesteps, 256))
|
1025 |
+
if self.params.guidance_embed:
|
1026 |
+
if guidance is None:
|
1027 |
+
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
1028 |
+
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
1029 |
+
vec = vec + self.vector_in(y)
|
1030 |
+
txt = self.txt_in(txt)
|
1031 |
+
|
1032 |
+
ids = torch.cat((txt_ids, img_ids), dim=1)
|
1033 |
+
pe = self.pe_embedder(ids)
|
1034 |
+
|
1035 |
+
if not self.blocks_to_swap:
|
1036 |
+
for block in self.double_blocks:
|
1037 |
+
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
1038 |
+
img = torch.cat((txt, img), 1)
|
1039 |
+
for block in self.single_blocks:
|
1040 |
+
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
1041 |
+
else:
|
1042 |
+
for block_idx, block in enumerate(self.double_blocks):
|
1043 |
+
self.offloader_double.wait_for_block(block_idx)
|
1044 |
+
|
1045 |
+
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
1046 |
+
|
1047 |
+
self.offloader_double.submit_move_blocks(self.double_blocks, block_idx)
|
1048 |
+
|
1049 |
+
img = torch.cat((txt, img), 1)
|
1050 |
+
|
1051 |
+
for block_idx, block in enumerate(self.single_blocks):
|
1052 |
+
self.offloader_single.wait_for_block(block_idx)
|
1053 |
+
|
1054 |
+
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
1055 |
+
|
1056 |
+
self.offloader_single.submit_move_blocks(self.single_blocks, block_idx)
|
1057 |
+
|
1058 |
+
img = img[:, txt.shape[1] :, ...]
|
1059 |
+
|
1060 |
+
if self.training and self.cpu_offload_checkpointing:
|
1061 |
+
img = img.to(self.device)
|
1062 |
+
vec = vec.to(self.device)
|
1063 |
+
|
1064 |
+
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
1065 |
+
|
1066 |
+
return img
|
1067 |
+
|
1068 |
+
|
1069 |
+
"""
|
1070 |
+
class FluxUpper(nn.Module):
|
1071 |
+
""
|
1072 |
+
Transformer model for flow matching on sequences.
|
1073 |
+
""
|
1074 |
+
|
1075 |
+
def __init__(self, params: FluxParams):
|
1076 |
+
super().__init__()
|
1077 |
+
|
1078 |
+
self.params = params
|
1079 |
+
self.in_channels = params.in_channels
|
1080 |
+
self.out_channels = self.in_channels
|
1081 |
+
if params.hidden_size % params.num_heads != 0:
|
1082 |
+
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
|
1083 |
+
pe_dim = params.hidden_size // params.num_heads
|
1084 |
+
if sum(params.axes_dim) != pe_dim:
|
1085 |
+
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
1086 |
+
self.hidden_size = params.hidden_size
|
1087 |
+
self.num_heads = params.num_heads
|
1088 |
+
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
1089 |
+
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
1090 |
+
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
1091 |
+
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
|
1092 |
+
self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
|
1093 |
+
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
|
1094 |
+
|
1095 |
+
self.double_blocks = nn.ModuleList(
|
1096 |
+
[
|
1097 |
+
DoubleStreamBlock(
|
1098 |
+
self.hidden_size,
|
1099 |
+
self.num_heads,
|
1100 |
+
mlp_ratio=params.mlp_ratio,
|
1101 |
+
qkv_bias=params.qkv_bias,
|
1102 |
+
)
|
1103 |
+
for _ in range(params.depth)
|
1104 |
+
]
|
1105 |
+
)
|
1106 |
+
|
1107 |
+
self.gradient_checkpointing = False
|
1108 |
+
|
1109 |
+
@property
|
1110 |
+
def device(self):
|
1111 |
+
return next(self.parameters()).device
|
1112 |
+
|
1113 |
+
@property
|
1114 |
+
def dtype(self):
|
1115 |
+
return next(self.parameters()).dtype
|
1116 |
+
|
1117 |
+
def enable_gradient_checkpointing(self):
|
1118 |
+
self.gradient_checkpointing = True
|
1119 |
+
|
1120 |
+
self.time_in.enable_gradient_checkpointing()
|
1121 |
+
self.vector_in.enable_gradient_checkpointing()
|
1122 |
+
if self.guidance_in.__class__ != nn.Identity:
|
1123 |
+
self.guidance_in.enable_gradient_checkpointing()
|
1124 |
+
|
1125 |
+
for block in self.double_blocks:
|
1126 |
+
block.enable_gradient_checkpointing()
|
1127 |
+
|
1128 |
+
print("FLUX: Gradient checkpointing enabled.")
|
1129 |
+
|
1130 |
+
def disable_gradient_checkpointing(self):
|
1131 |
+
self.gradient_checkpointing = False
|
1132 |
+
|
1133 |
+
self.time_in.disable_gradient_checkpointing()
|
1134 |
+
self.vector_in.disable_gradient_checkpointing()
|
1135 |
+
if self.guidance_in.__class__ != nn.Identity:
|
1136 |
+
self.guidance_in.disable_gradient_checkpointing()
|
1137 |
+
|
1138 |
+
for block in self.double_blocks:
|
1139 |
+
block.disable_gradient_checkpointing()
|
1140 |
+
|
1141 |
+
print("FLUX: Gradient checkpointing disabled.")
|
1142 |
+
|
1143 |
+
def forward(
|
1144 |
+
self,
|
1145 |
+
img: Tensor,
|
1146 |
+
img_ids: Tensor,
|
1147 |
+
txt: Tensor,
|
1148 |
+
txt_ids: Tensor,
|
1149 |
+
timesteps: Tensor,
|
1150 |
+
y: Tensor,
|
1151 |
+
guidance: Tensor | None = None,
|
1152 |
+
txt_attention_mask: Tensor | None = None,
|
1153 |
+
) -> Tensor:
|
1154 |
+
if img.ndim != 3 or txt.ndim != 3:
|
1155 |
+
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
1156 |
+
|
1157 |
+
# running on sequences img
|
1158 |
+
img = self.img_in(img)
|
1159 |
+
vec = self.time_in(timestep_embedding(timesteps, 256))
|
1160 |
+
if self.params.guidance_embed:
|
1161 |
+
if guidance is None:
|
1162 |
+
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
1163 |
+
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
1164 |
+
vec = vec + self.vector_in(y)
|
1165 |
+
txt = self.txt_in(txt)
|
1166 |
+
|
1167 |
+
ids = torch.cat((txt_ids, img_ids), dim=1)
|
1168 |
+
pe = self.pe_embedder(ids)
|
1169 |
+
|
1170 |
+
for block in self.double_blocks:
|
1171 |
+
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
1172 |
+
|
1173 |
+
return img, txt, vec, pe
|
1174 |
+
|
1175 |
+
|
1176 |
+
class FluxLower(nn.Module):
|
1177 |
+
""
|
1178 |
+
Transformer model for flow matching on sequences.
|
1179 |
+
""
|
1180 |
+
|
1181 |
+
def __init__(self, params: FluxParams):
|
1182 |
+
super().__init__()
|
1183 |
+
self.hidden_size = params.hidden_size
|
1184 |
+
self.num_heads = params.num_heads
|
1185 |
+
self.out_channels = params.in_channels
|
1186 |
+
|
1187 |
+
self.single_blocks = nn.ModuleList(
|
1188 |
+
[
|
1189 |
+
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
|
1190 |
+
for _ in range(params.depth_single_blocks)
|
1191 |
+
]
|
1192 |
+
)
|
1193 |
+
|
1194 |
+
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
|
1195 |
+
|
1196 |
+
self.gradient_checkpointing = False
|
1197 |
+
|
1198 |
+
@property
|
1199 |
+
def device(self):
|
1200 |
+
return next(self.parameters()).device
|
1201 |
+
|
1202 |
+
@property
|
1203 |
+
def dtype(self):
|
1204 |
+
return next(self.parameters()).dtype
|
1205 |
+
|
1206 |
+
def enable_gradient_checkpointing(self):
|
1207 |
+
self.gradient_checkpointing = True
|
1208 |
+
|
1209 |
+
for block in self.single_blocks:
|
1210 |
+
block.enable_gradient_checkpointing()
|
1211 |
+
|
1212 |
+
print("FLUX: Gradient checkpointing enabled.")
|
1213 |
+
|
1214 |
+
def disable_gradient_checkpointing(self):
|
1215 |
+
self.gradient_checkpointing = False
|
1216 |
+
|
1217 |
+
for block in self.single_blocks:
|
1218 |
+
block.disable_gradient_checkpointing()
|
1219 |
+
|
1220 |
+
print("FLUX: Gradient checkpointing disabled.")
|
1221 |
+
|
1222 |
+
def forward(
|
1223 |
+
self,
|
1224 |
+
img: Tensor,
|
1225 |
+
txt: Tensor,
|
1226 |
+
vec: Tensor | None = None,
|
1227 |
+
pe: Tensor | None = None,
|
1228 |
+
txt_attention_mask: Tensor | None = None,
|
1229 |
+
) -> Tensor:
|
1230 |
+
img = torch.cat((txt, img), 1)
|
1231 |
+
for block in self.single_blocks:
|
1232 |
+
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
1233 |
+
img = img[:, txt.shape[1] :, ...]
|
1234 |
+
|
1235 |
+
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
1236 |
+
return img
|
1237 |
+
"""
|
library/flux_train_utils.py
ADDED
@@ -0,0 +1,582 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import math
|
3 |
+
import os
|
4 |
+
import numpy as np
|
5 |
+
import toml
|
6 |
+
import json
|
7 |
+
import time
|
8 |
+
from typing import Callable, Dict, List, Optional, Tuple, Union
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from accelerate import Accelerator, PartialState
|
12 |
+
from transformers import CLIPTextModel
|
13 |
+
from tqdm import tqdm
|
14 |
+
from PIL import Image
|
15 |
+
from safetensors.torch import save_file
|
16 |
+
|
17 |
+
from library import flux_models, flux_utils, strategy_base, train_util
|
18 |
+
from library.device_utils import init_ipex, clean_memory_on_device
|
19 |
+
|
20 |
+
init_ipex()
|
21 |
+
|
22 |
+
from .utils import setup_logging, mem_eff_save_file
|
23 |
+
|
24 |
+
setup_logging()
|
25 |
+
import logging
|
26 |
+
|
27 |
+
logger = logging.getLogger(__name__)
|
28 |
+
|
29 |
+
|
30 |
+
# region sample images
|
31 |
+
|
32 |
+
|
33 |
+
def sample_images(
|
34 |
+
accelerator: Accelerator,
|
35 |
+
args: argparse.Namespace,
|
36 |
+
epoch,
|
37 |
+
steps,
|
38 |
+
flux,
|
39 |
+
ae,
|
40 |
+
text_encoders,
|
41 |
+
sample_prompts_te_outputs,
|
42 |
+
prompt_replacement=None,
|
43 |
+
):
|
44 |
+
if steps == 0:
|
45 |
+
if not args.sample_at_first:
|
46 |
+
return
|
47 |
+
else:
|
48 |
+
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
|
49 |
+
return
|
50 |
+
if args.sample_every_n_epochs is not None:
|
51 |
+
# sample_every_n_steps は無視する
|
52 |
+
if epoch is None or epoch % args.sample_every_n_epochs != 0:
|
53 |
+
return
|
54 |
+
else:
|
55 |
+
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
|
56 |
+
return
|
57 |
+
|
58 |
+
logger.info("")
|
59 |
+
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
|
60 |
+
if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
|
61 |
+
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
|
62 |
+
return
|
63 |
+
|
64 |
+
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
|
65 |
+
|
66 |
+
# unwrap unet and text_encoder(s)
|
67 |
+
flux = accelerator.unwrap_model(flux)
|
68 |
+
if text_encoders is not None:
|
69 |
+
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
|
70 |
+
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
|
71 |
+
|
72 |
+
prompts = train_util.load_prompts(args.sample_prompts)
|
73 |
+
|
74 |
+
save_dir = args.output_dir + "/sample"
|
75 |
+
os.makedirs(save_dir, exist_ok=True)
|
76 |
+
|
77 |
+
# save random state to restore later
|
78 |
+
rng_state = torch.get_rng_state()
|
79 |
+
cuda_rng_state = None
|
80 |
+
try:
|
81 |
+
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
|
82 |
+
except Exception:
|
83 |
+
pass
|
84 |
+
|
85 |
+
if distributed_state.num_processes <= 1:
|
86 |
+
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
|
87 |
+
with torch.no_grad(), accelerator.autocast():
|
88 |
+
for prompt_dict in prompts:
|
89 |
+
sample_image_inference(
|
90 |
+
accelerator,
|
91 |
+
args,
|
92 |
+
flux,
|
93 |
+
text_encoders,
|
94 |
+
ae,
|
95 |
+
save_dir,
|
96 |
+
prompt_dict,
|
97 |
+
epoch,
|
98 |
+
steps,
|
99 |
+
sample_prompts_te_outputs,
|
100 |
+
prompt_replacement,
|
101 |
+
)
|
102 |
+
else:
|
103 |
+
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
|
104 |
+
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
|
105 |
+
per_process_prompts = [] # list of lists
|
106 |
+
for i in range(distributed_state.num_processes):
|
107 |
+
per_process_prompts.append(prompts[i :: distributed_state.num_processes])
|
108 |
+
|
109 |
+
with torch.no_grad():
|
110 |
+
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
|
111 |
+
for prompt_dict in prompt_dict_lists[0]:
|
112 |
+
sample_image_inference(
|
113 |
+
accelerator,
|
114 |
+
args,
|
115 |
+
flux,
|
116 |
+
text_encoders,
|
117 |
+
ae,
|
118 |
+
save_dir,
|
119 |
+
prompt_dict,
|
120 |
+
epoch,
|
121 |
+
steps,
|
122 |
+
sample_prompts_te_outputs,
|
123 |
+
prompt_replacement,
|
124 |
+
)
|
125 |
+
|
126 |
+
torch.set_rng_state(rng_state)
|
127 |
+
if cuda_rng_state is not None:
|
128 |
+
torch.cuda.set_rng_state(cuda_rng_state)
|
129 |
+
|
130 |
+
clean_memory_on_device(accelerator.device)
|
131 |
+
|
132 |
+
|
133 |
+
def sample_image_inference(
|
134 |
+
accelerator: Accelerator,
|
135 |
+
args: argparse.Namespace,
|
136 |
+
flux: flux_models.Flux,
|
137 |
+
text_encoders: Optional[List[CLIPTextModel]],
|
138 |
+
ae: flux_models.AutoEncoder,
|
139 |
+
save_dir,
|
140 |
+
prompt_dict,
|
141 |
+
epoch,
|
142 |
+
steps,
|
143 |
+
sample_prompts_te_outputs,
|
144 |
+
prompt_replacement,
|
145 |
+
):
|
146 |
+
assert isinstance(prompt_dict, dict)
|
147 |
+
# negative_prompt = prompt_dict.get("negative_prompt")
|
148 |
+
sample_steps = prompt_dict.get("sample_steps", 20)
|
149 |
+
width = prompt_dict.get("width", 512)
|
150 |
+
height = prompt_dict.get("height", 512)
|
151 |
+
scale = prompt_dict.get("scale", 3.5)
|
152 |
+
seed = prompt_dict.get("seed")
|
153 |
+
# controlnet_image = prompt_dict.get("controlnet_image")
|
154 |
+
prompt: str = prompt_dict.get("prompt", "")
|
155 |
+
# sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
|
156 |
+
|
157 |
+
if prompt_replacement is not None:
|
158 |
+
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
159 |
+
# if negative_prompt is not None:
|
160 |
+
# negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
161 |
+
|
162 |
+
if seed is not None:
|
163 |
+
torch.manual_seed(seed)
|
164 |
+
torch.cuda.manual_seed(seed)
|
165 |
+
else:
|
166 |
+
# True random sample image generation
|
167 |
+
torch.seed()
|
168 |
+
torch.cuda.seed()
|
169 |
+
|
170 |
+
# if negative_prompt is None:
|
171 |
+
# negative_prompt = ""
|
172 |
+
|
173 |
+
height = max(64, height - height % 16) # round to divisible by 16
|
174 |
+
width = max(64, width - width % 16) # round to divisible by 16
|
175 |
+
logger.info(f"prompt: {prompt}")
|
176 |
+
# logger.info(f"negative_prompt: {negative_prompt}")
|
177 |
+
logger.info(f"height: {height}")
|
178 |
+
logger.info(f"width: {width}")
|
179 |
+
logger.info(f"sample_steps: {sample_steps}")
|
180 |
+
logger.info(f"scale: {scale}")
|
181 |
+
# logger.info(f"sample_sampler: {sampler_name}")
|
182 |
+
if seed is not None:
|
183 |
+
logger.info(f"seed: {seed}")
|
184 |
+
|
185 |
+
# encode prompts
|
186 |
+
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
|
187 |
+
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
|
188 |
+
|
189 |
+
text_encoder_conds = []
|
190 |
+
if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
|
191 |
+
text_encoder_conds = sample_prompts_te_outputs[prompt]
|
192 |
+
print(f"Using cached text encoder outputs for prompt: {prompt}")
|
193 |
+
if text_encoders is not None:
|
194 |
+
print(f"Encoding prompt: {prompt}")
|
195 |
+
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
196 |
+
# strategy has apply_t5_attn_mask option
|
197 |
+
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
|
198 |
+
|
199 |
+
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
|
200 |
+
if len(text_encoder_conds) == 0:
|
201 |
+
text_encoder_conds = encoded_text_encoder_conds
|
202 |
+
else:
|
203 |
+
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
|
204 |
+
for i in range(len(encoded_text_encoder_conds)):
|
205 |
+
if encoded_text_encoder_conds[i] is not None:
|
206 |
+
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
207 |
+
|
208 |
+
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
209 |
+
|
210 |
+
# sample image
|
211 |
+
weight_dtype = ae.dtype # TOFO give dtype as argument
|
212 |
+
packed_latent_height = height // 16
|
213 |
+
packed_latent_width = width // 16
|
214 |
+
noise = torch.randn(
|
215 |
+
1,
|
216 |
+
packed_latent_height * packed_latent_width,
|
217 |
+
16 * 2 * 2,
|
218 |
+
device=accelerator.device,
|
219 |
+
dtype=weight_dtype,
|
220 |
+
generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
|
221 |
+
)
|
222 |
+
timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True
|
223 |
+
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
|
224 |
+
t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None
|
225 |
+
|
226 |
+
with accelerator.autocast(), torch.no_grad():
|
227 |
+
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask)
|
228 |
+
|
229 |
+
x = x.float()
|
230 |
+
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
|
231 |
+
|
232 |
+
# latent to image
|
233 |
+
clean_memory_on_device(accelerator.device)
|
234 |
+
org_vae_device = ae.device # will be on cpu
|
235 |
+
ae.to(accelerator.device) # distributed_state.device is same as accelerator.device
|
236 |
+
with accelerator.autocast(), torch.no_grad():
|
237 |
+
x = ae.decode(x)
|
238 |
+
ae.to(org_vae_device)
|
239 |
+
clean_memory_on_device(accelerator.device)
|
240 |
+
|
241 |
+
x = x.clamp(-1, 1)
|
242 |
+
x = x.permute(0, 2, 3, 1)
|
243 |
+
image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
|
244 |
+
|
245 |
+
# adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
|
246 |
+
# but adding 'enum' to the filename should be enough
|
247 |
+
|
248 |
+
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
|
249 |
+
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
|
250 |
+
seed_suffix = "" if seed is None else f"_{seed}"
|
251 |
+
i: int = prompt_dict["enum"]
|
252 |
+
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
|
253 |
+
image.save(os.path.join(save_dir, img_filename))
|
254 |
+
|
255 |
+
# send images to wandb if enabled
|
256 |
+
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
|
257 |
+
wandb_tracker = accelerator.get_tracker("wandb")
|
258 |
+
|
259 |
+
import wandb
|
260 |
+
|
261 |
+
# not to commit images to avoid inconsistency between training and logging steps
|
262 |
+
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
|
263 |
+
|
264 |
+
|
265 |
+
def time_shift(mu: float, sigma: float, t: torch.Tensor):
|
266 |
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
267 |
+
|
268 |
+
|
269 |
+
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
|
270 |
+
m = (y2 - y1) / (x2 - x1)
|
271 |
+
b = y1 - m * x1
|
272 |
+
return lambda x: m * x + b
|
273 |
+
|
274 |
+
|
275 |
+
def get_schedule(
|
276 |
+
num_steps: int,
|
277 |
+
image_seq_len: int,
|
278 |
+
base_shift: float = 0.5,
|
279 |
+
max_shift: float = 1.15,
|
280 |
+
shift: bool = True,
|
281 |
+
) -> list[float]:
|
282 |
+
# extra step for zero
|
283 |
+
timesteps = torch.linspace(1, 0, num_steps + 1)
|
284 |
+
|
285 |
+
# shifting the schedule to favor high timesteps for higher signal images
|
286 |
+
if shift:
|
287 |
+
# eastimate mu based on linear estimation between two points
|
288 |
+
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
289 |
+
timesteps = time_shift(mu, 1.0, timesteps)
|
290 |
+
|
291 |
+
return timesteps.tolist()
|
292 |
+
|
293 |
+
|
294 |
+
def denoise(
|
295 |
+
model: flux_models.Flux,
|
296 |
+
img: torch.Tensor,
|
297 |
+
img_ids: torch.Tensor,
|
298 |
+
txt: torch.Tensor,
|
299 |
+
txt_ids: torch.Tensor,
|
300 |
+
vec: torch.Tensor,
|
301 |
+
timesteps: list[float],
|
302 |
+
guidance: float = 4.0,
|
303 |
+
t5_attn_mask: Optional[torch.Tensor] = None,
|
304 |
+
):
|
305 |
+
# this is ignored for schnell
|
306 |
+
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
307 |
+
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
308 |
+
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
309 |
+
model.prepare_block_swap_before_forward()
|
310 |
+
pred = model(
|
311 |
+
img=img,
|
312 |
+
img_ids=img_ids,
|
313 |
+
txt=txt,
|
314 |
+
txt_ids=txt_ids,
|
315 |
+
y=vec,
|
316 |
+
timesteps=t_vec,
|
317 |
+
guidance=guidance_vec,
|
318 |
+
txt_attention_mask=t5_attn_mask,
|
319 |
+
)
|
320 |
+
|
321 |
+
img = img + (t_prev - t_curr) * pred
|
322 |
+
|
323 |
+
model.prepare_block_swap_before_forward()
|
324 |
+
return img
|
325 |
+
|
326 |
+
|
327 |
+
# endregion
|
328 |
+
|
329 |
+
|
330 |
+
# region train
|
331 |
+
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
|
332 |
+
sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
|
333 |
+
schedule_timesteps = noise_scheduler.timesteps.to(device)
|
334 |
+
timesteps = timesteps.to(device)
|
335 |
+
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
336 |
+
|
337 |
+
sigma = sigmas[step_indices].flatten()
|
338 |
+
while len(sigma.shape) < n_dim:
|
339 |
+
sigma = sigma.unsqueeze(-1)
|
340 |
+
return sigma
|
341 |
+
|
342 |
+
|
343 |
+
def compute_density_for_timestep_sampling(
|
344 |
+
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
|
345 |
+
):
|
346 |
+
"""Compute the density for sampling the timesteps when doing SD3 training.
|
347 |
+
|
348 |
+
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
349 |
+
|
350 |
+
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
351 |
+
"""
|
352 |
+
if weighting_scheme == "logit_normal":
|
353 |
+
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
|
354 |
+
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
|
355 |
+
u = torch.nn.functional.sigmoid(u)
|
356 |
+
elif weighting_scheme == "mode":
|
357 |
+
u = torch.rand(size=(batch_size,), device="cpu")
|
358 |
+
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
|
359 |
+
else:
|
360 |
+
u = torch.rand(size=(batch_size,), device="cpu")
|
361 |
+
return u
|
362 |
+
|
363 |
+
|
364 |
+
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
365 |
+
"""Computes loss weighting scheme for SD3 training.
|
366 |
+
|
367 |
+
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
368 |
+
|
369 |
+
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
370 |
+
"""
|
371 |
+
if weighting_scheme == "sigma_sqrt":
|
372 |
+
weighting = (sigmas**-2.0).float()
|
373 |
+
elif weighting_scheme == "cosmap":
|
374 |
+
bot = 1 - 2 * sigmas + 2 * sigmas**2
|
375 |
+
weighting = 2 / (math.pi * bot)
|
376 |
+
else:
|
377 |
+
weighting = torch.ones_like(sigmas)
|
378 |
+
return weighting
|
379 |
+
|
380 |
+
|
381 |
+
def get_noisy_model_input_and_timesteps(
|
382 |
+
args, noise_scheduler, latents, noise, device, dtype
|
383 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
384 |
+
bsz, _, h, w = latents.shape
|
385 |
+
sigmas = None
|
386 |
+
|
387 |
+
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
|
388 |
+
# Simple random t-based noise sampling
|
389 |
+
if args.timestep_sampling == "sigmoid":
|
390 |
+
# https://github.com/XLabs-AI/x-flux/tree/main
|
391 |
+
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
|
392 |
+
else:
|
393 |
+
t = torch.rand((bsz,), device=device)
|
394 |
+
|
395 |
+
timesteps = t * 1000.0
|
396 |
+
t = t.view(-1, 1, 1, 1)
|
397 |
+
noisy_model_input = (1 - t) * latents + t * noise
|
398 |
+
elif args.timestep_sampling == "shift":
|
399 |
+
shift = args.discrete_flow_shift
|
400 |
+
logits_norm = torch.randn(bsz, device=device)
|
401 |
+
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
|
402 |
+
timesteps = logits_norm.sigmoid()
|
403 |
+
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
|
404 |
+
|
405 |
+
t = timesteps.view(-1, 1, 1, 1)
|
406 |
+
timesteps = timesteps * 1000.0
|
407 |
+
noisy_model_input = (1 - t) * latents + t * noise
|
408 |
+
elif args.timestep_sampling == "flux_shift":
|
409 |
+
logits_norm = torch.randn(bsz, device=device)
|
410 |
+
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
|
411 |
+
timesteps = logits_norm.sigmoid()
|
412 |
+
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
|
413 |
+
timesteps = time_shift(mu, 1.0, timesteps)
|
414 |
+
|
415 |
+
t = timesteps.view(-1, 1, 1, 1)
|
416 |
+
timesteps = timesteps * 1000.0
|
417 |
+
noisy_model_input = (1 - t) * latents + t * noise
|
418 |
+
else:
|
419 |
+
# Sample a random timestep for each image
|
420 |
+
# for weighting schemes where we sample timesteps non-uniformly
|
421 |
+
u = compute_density_for_timestep_sampling(
|
422 |
+
weighting_scheme=args.weighting_scheme,
|
423 |
+
batch_size=bsz,
|
424 |
+
logit_mean=args.logit_mean,
|
425 |
+
logit_std=args.logit_std,
|
426 |
+
mode_scale=args.mode_scale,
|
427 |
+
)
|
428 |
+
indices = (u * noise_scheduler.config.num_train_timesteps).long()
|
429 |
+
timesteps = noise_scheduler.timesteps[indices].to(device=device)
|
430 |
+
|
431 |
+
# Add noise according to flow matching.
|
432 |
+
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
|
433 |
+
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
|
434 |
+
|
435 |
+
return noisy_model_input, timesteps, sigmas
|
436 |
+
|
437 |
+
|
438 |
+
def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
|
439 |
+
weighting = None
|
440 |
+
if args.model_prediction_type == "raw":
|
441 |
+
pass
|
442 |
+
elif args.model_prediction_type == "additive":
|
443 |
+
# add the model_pred to the noisy_model_input
|
444 |
+
model_pred = model_pred + noisy_model_input
|
445 |
+
elif args.model_prediction_type == "sigma_scaled":
|
446 |
+
# apply sigma scaling
|
447 |
+
model_pred = model_pred * (-sigmas) + noisy_model_input
|
448 |
+
|
449 |
+
# these weighting schemes use a uniform timestep sampling
|
450 |
+
# and instead post-weight the loss
|
451 |
+
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
|
452 |
+
|
453 |
+
return model_pred, weighting
|
454 |
+
|
455 |
+
|
456 |
+
def save_models(
|
457 |
+
ckpt_path: str,
|
458 |
+
flux: flux_models.Flux,
|
459 |
+
sai_metadata: Optional[dict],
|
460 |
+
save_dtype: Optional[torch.dtype] = None,
|
461 |
+
use_mem_eff_save: bool = False,
|
462 |
+
):
|
463 |
+
state_dict = {}
|
464 |
+
|
465 |
+
def update_sd(prefix, sd):
|
466 |
+
for k, v in sd.items():
|
467 |
+
key = prefix + k
|
468 |
+
if save_dtype is not None and v.dtype != save_dtype:
|
469 |
+
v = v.detach().clone().to("cpu").to(save_dtype)
|
470 |
+
state_dict[key] = v
|
471 |
+
|
472 |
+
update_sd("", flux.state_dict())
|
473 |
+
|
474 |
+
if not use_mem_eff_save:
|
475 |
+
save_file(state_dict, ckpt_path, metadata=sai_metadata)
|
476 |
+
else:
|
477 |
+
mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata)
|
478 |
+
|
479 |
+
|
480 |
+
def save_flux_model_on_train_end(
|
481 |
+
args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, flux: flux_models.Flux
|
482 |
+
):
|
483 |
+
def sd_saver(ckpt_file, epoch_no, global_step):
|
484 |
+
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
|
485 |
+
save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
|
486 |
+
|
487 |
+
train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
|
488 |
+
|
489 |
+
|
490 |
+
# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している
|
491 |
+
# on_epoch_end: Trueならepoch終了時、Falseならstep経過時
|
492 |
+
def save_flux_model_on_epoch_end_or_stepwise(
|
493 |
+
args: argparse.Namespace,
|
494 |
+
on_epoch_end: bool,
|
495 |
+
accelerator,
|
496 |
+
save_dtype: torch.dtype,
|
497 |
+
epoch: int,
|
498 |
+
num_train_epochs: int,
|
499 |
+
global_step: int,
|
500 |
+
flux: flux_models.Flux,
|
501 |
+
):
|
502 |
+
def sd_saver(ckpt_file, epoch_no, global_step):
|
503 |
+
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
|
504 |
+
save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
|
505 |
+
|
506 |
+
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
|
507 |
+
args,
|
508 |
+
on_epoch_end,
|
509 |
+
accelerator,
|
510 |
+
True,
|
511 |
+
True,
|
512 |
+
epoch,
|
513 |
+
num_train_epochs,
|
514 |
+
global_step,
|
515 |
+
sd_saver,
|
516 |
+
None,
|
517 |
+
)
|
518 |
+
|
519 |
+
|
520 |
+
# endregion
|
521 |
+
|
522 |
+
|
523 |
+
def add_flux_train_arguments(parser: argparse.ArgumentParser):
|
524 |
+
parser.add_argument(
|
525 |
+
"--clip_l",
|
526 |
+
type=str,
|
527 |
+
help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス(*.sftまたは*.safetensors)、float16が前提",
|
528 |
+
)
|
529 |
+
parser.add_argument(
|
530 |
+
"--t5xxl",
|
531 |
+
type=str,
|
532 |
+
help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)、float16が前提",
|
533 |
+
)
|
534 |
+
parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)")
|
535 |
+
parser.add_argument(
|
536 |
+
"--t5xxl_max_token_length",
|
537 |
+
type=int,
|
538 |
+
default=None,
|
539 |
+
help="maximum token length for T5-XXL. if omitted, 256 for schnell and 512 for dev"
|
540 |
+
" / T5-XXLの最大トークン長。省略された場合、schnellの場合は256、devの場合は512",
|
541 |
+
)
|
542 |
+
parser.add_argument(
|
543 |
+
"--apply_t5_attn_mask",
|
544 |
+
action="store_true",
|
545 |
+
help="apply attention mask to T5-XXL encode and FLUX double blocks / T5-XXLエンコードとFLUXダブルブロックにアテンションマスクを適用する",
|
546 |
+
)
|
547 |
+
|
548 |
+
parser.add_argument(
|
549 |
+
"--guidance_scale",
|
550 |
+
type=float,
|
551 |
+
default=3.5,
|
552 |
+
help="the FLUX.1 dev variant is a guidance distilled model",
|
553 |
+
)
|
554 |
+
|
555 |
+
parser.add_argument(
|
556 |
+
"--timestep_sampling",
|
557 |
+
choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
|
558 |
+
default="sigma",
|
559 |
+
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and FLUX.1 shifting."
|
560 |
+
" / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、FLUX.1のシフト。",
|
561 |
+
)
|
562 |
+
parser.add_argument(
|
563 |
+
"--sigmoid_scale",
|
564 |
+
type=float,
|
565 |
+
default=1.0,
|
566 |
+
help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。',
|
567 |
+
)
|
568 |
+
parser.add_argument(
|
569 |
+
"--model_prediction_type",
|
570 |
+
choices=["raw", "additive", "sigma_scaled"],
|
571 |
+
default="sigma_scaled",
|
572 |
+
help="How to interpret and process the model prediction: "
|
573 |
+
"raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)."
|
574 |
+
" / モデル予測の解釈と処理方法:"
|
575 |
+
"raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。",
|
576 |
+
)
|
577 |
+
parser.add_argument(
|
578 |
+
"--discrete_flow_shift",
|
579 |
+
type=float,
|
580 |
+
default=3.0,
|
581 |
+
help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
|
582 |
+
)
|
library/flux_train_utils_recraft.py
ADDED
@@ -0,0 +1,659 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import math
|
3 |
+
import os
|
4 |
+
import numpy as np
|
5 |
+
import toml
|
6 |
+
import json
|
7 |
+
import time
|
8 |
+
from typing import Callable, Dict, List, Optional, Tuple, Union
|
9 |
+
import pdb
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from accelerate import Accelerator, PartialState
|
13 |
+
from transformers import CLIPTextModel
|
14 |
+
from tqdm import tqdm
|
15 |
+
from PIL import Image
|
16 |
+
from safetensors.torch import save_file
|
17 |
+
|
18 |
+
from library import flux_models, flux_utils, strategy_base, train_util
|
19 |
+
from library.device_utils import init_ipex, clean_memory_on_device
|
20 |
+
|
21 |
+
init_ipex()
|
22 |
+
|
23 |
+
from .utils import setup_logging, mem_eff_save_file
|
24 |
+
|
25 |
+
setup_logging()
|
26 |
+
import logging
|
27 |
+
|
28 |
+
logger = logging.getLogger(__name__)
|
29 |
+
|
30 |
+
|
31 |
+
# region sample images
|
32 |
+
|
33 |
+
def sample_images(
|
34 |
+
accelerator: Accelerator,
|
35 |
+
args: argparse.Namespace,
|
36 |
+
epoch,
|
37 |
+
steps,
|
38 |
+
flux,
|
39 |
+
ae,
|
40 |
+
text_encoders,
|
41 |
+
sample_prompts_te_outputs,
|
42 |
+
prompt_replacement=None,
|
43 |
+
sample_images_ae_outputs=None
|
44 |
+
):
|
45 |
+
if steps == 0:
|
46 |
+
if not args.sample_at_first:
|
47 |
+
return
|
48 |
+
else:
|
49 |
+
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
|
50 |
+
return
|
51 |
+
if args.sample_every_n_epochs is not None:
|
52 |
+
# sample_every_n_steps は無視する
|
53 |
+
if epoch is None or epoch % args.sample_every_n_epochs != 0:
|
54 |
+
return
|
55 |
+
else:
|
56 |
+
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
|
57 |
+
return
|
58 |
+
|
59 |
+
logger.info("")
|
60 |
+
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
|
61 |
+
if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
|
62 |
+
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
|
63 |
+
return
|
64 |
+
|
65 |
+
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
|
66 |
+
|
67 |
+
# unwrap unet and text_encoder(s)
|
68 |
+
flux = accelerator.unwrap_model(flux)
|
69 |
+
if text_encoders is not None:
|
70 |
+
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
|
71 |
+
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
|
72 |
+
|
73 |
+
prompts = train_util.load_prompts(args.sample_prompts)
|
74 |
+
|
75 |
+
save_dir = args.output_dir + "/sample"
|
76 |
+
os.makedirs(save_dir, exist_ok=True)
|
77 |
+
|
78 |
+
# save random state to restore later
|
79 |
+
rng_state = torch.get_rng_state()
|
80 |
+
cuda_rng_state = None
|
81 |
+
try:
|
82 |
+
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
|
83 |
+
except Exception:
|
84 |
+
pass
|
85 |
+
|
86 |
+
if distributed_state.num_processes <= 1:
|
87 |
+
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
|
88 |
+
with torch.no_grad(), accelerator.autocast():
|
89 |
+
for prompt_dict in prompts:
|
90 |
+
sample_image_inference(
|
91 |
+
accelerator,
|
92 |
+
args,
|
93 |
+
flux,
|
94 |
+
text_encoders,
|
95 |
+
ae,
|
96 |
+
save_dir,
|
97 |
+
prompt_dict,
|
98 |
+
epoch,
|
99 |
+
steps,
|
100 |
+
sample_prompts_te_outputs,
|
101 |
+
prompt_replacement,
|
102 |
+
sample_images_ae_outputs
|
103 |
+
)
|
104 |
+
else:
|
105 |
+
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
|
106 |
+
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
|
107 |
+
per_process_prompts = [] # list of lists
|
108 |
+
for i in range(distributed_state.num_processes):
|
109 |
+
per_process_prompts.append(prompts[i :: distributed_state.num_processes])
|
110 |
+
|
111 |
+
with torch.no_grad():
|
112 |
+
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
|
113 |
+
for prompt_dict in prompt_dict_lists[0]:
|
114 |
+
sample_image_inference(
|
115 |
+
accelerator,
|
116 |
+
args,
|
117 |
+
flux,
|
118 |
+
text_encoders,
|
119 |
+
ae,
|
120 |
+
save_dir,
|
121 |
+
prompt_dict,
|
122 |
+
epoch,
|
123 |
+
steps,
|
124 |
+
sample_prompts_te_outputs,
|
125 |
+
prompt_replacement,
|
126 |
+
sample_images_ae_outputs
|
127 |
+
)
|
128 |
+
|
129 |
+
torch.set_rng_state(rng_state)
|
130 |
+
if cuda_rng_state is not None:
|
131 |
+
torch.cuda.set_rng_state(cuda_rng_state)
|
132 |
+
|
133 |
+
clean_memory_on_device(accelerator.device)
|
134 |
+
|
135 |
+
|
136 |
+
def sample_image_inference(
|
137 |
+
accelerator: Accelerator,
|
138 |
+
args: argparse.Namespace,
|
139 |
+
flux: flux_models.Flux,
|
140 |
+
text_encoders: Optional[List[CLIPTextModel]],
|
141 |
+
ae: flux_models.AutoEncoder,
|
142 |
+
save_dir,
|
143 |
+
prompt_dict,
|
144 |
+
epoch,
|
145 |
+
steps,
|
146 |
+
sample_prompts_te_outputs,
|
147 |
+
prompt_replacement,
|
148 |
+
sample_images_ae_outputs
|
149 |
+
):
|
150 |
+
assert isinstance(prompt_dict, dict)
|
151 |
+
# negative_prompt = prompt_dict.get("negative_prompt")
|
152 |
+
sample_steps = prompt_dict.get("sample_steps", 20)
|
153 |
+
width = prompt_dict.get("width", 1024) if args.frame_num==4 else prompt_dict.get("width", 1056)
|
154 |
+
height = prompt_dict.get("height", 1024) if args.frame_num==4 else prompt_dict.get("height", 1056)
|
155 |
+
scale = prompt_dict.get("scale", 1.0)
|
156 |
+
seed = prompt_dict.get("seed")
|
157 |
+
# controlnet_image = prompt_dict.get("controlnet_image")
|
158 |
+
prompt: str = prompt_dict.get("prompt", "")
|
159 |
+
# sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
|
160 |
+
|
161 |
+
if prompt_replacement is not None:
|
162 |
+
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
163 |
+
# if negative_prompt is not None:
|
164 |
+
# negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
165 |
+
|
166 |
+
if seed is not None:
|
167 |
+
torch.manual_seed(seed)
|
168 |
+
torch.cuda.manual_seed(seed)
|
169 |
+
else:
|
170 |
+
# True random sample image generation
|
171 |
+
torch.seed()
|
172 |
+
torch.cuda.seed()
|
173 |
+
|
174 |
+
# if negative_prompt is None:
|
175 |
+
# negative_prompt = ""
|
176 |
+
|
177 |
+
height = max(64, height - height % 16) # round to divisible by 16
|
178 |
+
width = max(64, width - width % 16) # round to divisible by 16
|
179 |
+
logger.info(f"prompt: {prompt}")
|
180 |
+
# logger.info(f"negative_prompt: {negative_prompt}")
|
181 |
+
logger.info(f"height: {height}")
|
182 |
+
logger.info(f"width: {width}")
|
183 |
+
logger.info(f"sample_steps: {sample_steps}")
|
184 |
+
logger.info(f"scale: {scale}")
|
185 |
+
# logger.info(f"sample_sampler: {sampler_name}")
|
186 |
+
if seed is not None:
|
187 |
+
logger.info(f"seed: {seed}")
|
188 |
+
|
189 |
+
# encode prompts
|
190 |
+
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
|
191 |
+
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
|
192 |
+
|
193 |
+
text_encoder_conds = []
|
194 |
+
if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
|
195 |
+
text_encoder_conds = sample_prompts_te_outputs[prompt]
|
196 |
+
print(f"Using cached text encoder outputs for prompt: {prompt}")
|
197 |
+
if text_encoders is not None:
|
198 |
+
print(f"Encoding prompt: {prompt}")
|
199 |
+
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
200 |
+
# strategy has apply_t5_attn_mask option
|
201 |
+
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
|
202 |
+
|
203 |
+
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
|
204 |
+
if len(text_encoder_conds) == 0:
|
205 |
+
text_encoder_conds = encoded_text_encoder_conds
|
206 |
+
else:
|
207 |
+
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
|
208 |
+
for i in range(len(encoded_text_encoder_conds)):
|
209 |
+
if encoded_text_encoder_conds[i] is not None:
|
210 |
+
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
211 |
+
|
212 |
+
if sample_images_ae_outputs and prompt in sample_images_ae_outputs:
|
213 |
+
ae_outputs = sample_images_ae_outputs[prompt]
|
214 |
+
else:
|
215 |
+
ae_outputs = None
|
216 |
+
|
217 |
+
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
218 |
+
|
219 |
+
# sample image
|
220 |
+
weight_dtype = ae.dtype # TOFO give dtype as argument
|
221 |
+
packed_latent_height = height // 16
|
222 |
+
packed_latent_width = width // 16
|
223 |
+
noise = torch.randn(
|
224 |
+
1,
|
225 |
+
packed_latent_height * packed_latent_width,
|
226 |
+
16 * 2 * 2,
|
227 |
+
device=accelerator.device,
|
228 |
+
dtype=weight_dtype,
|
229 |
+
generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
|
230 |
+
)
|
231 |
+
timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True
|
232 |
+
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
|
233 |
+
t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None
|
234 |
+
|
235 |
+
with accelerator.autocast(), torch.no_grad():
|
236 |
+
x = denoise(args, flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, ae_outputs=ae_outputs)
|
237 |
+
|
238 |
+
x = x.float()
|
239 |
+
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
|
240 |
+
|
241 |
+
# latent to image
|
242 |
+
clean_memory_on_device(accelerator.device)
|
243 |
+
org_vae_device = ae.device # will be on cpu
|
244 |
+
ae.to(accelerator.device) # distributed_state.device is same as accelerator.device
|
245 |
+
with accelerator.autocast(), torch.no_grad():
|
246 |
+
x = ae.decode(x)
|
247 |
+
ae.to(org_vae_device)
|
248 |
+
clean_memory_on_device(accelerator.device)
|
249 |
+
|
250 |
+
x = x.clamp(-1, 1)
|
251 |
+
x = x.permute(0, 2, 3, 1)
|
252 |
+
image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
|
253 |
+
|
254 |
+
# adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
|
255 |
+
# but adding 'enum' to the filename should be enough
|
256 |
+
|
257 |
+
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
|
258 |
+
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
|
259 |
+
seed_suffix = "" if seed is None else f"_{seed}"
|
260 |
+
i: int = prompt_dict["enum"]
|
261 |
+
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
|
262 |
+
image.save(os.path.join(save_dir, img_filename))
|
263 |
+
|
264 |
+
# send images to wandb if enabled
|
265 |
+
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
|
266 |
+
wandb_tracker = accelerator.get_tracker("wandb")
|
267 |
+
|
268 |
+
import wandb
|
269 |
+
# not to commit images to avoid inconsistency between training and logging steps
|
270 |
+
wandb_tracker.log(
|
271 |
+
{f"sample_{i}": wandb.Image(
|
272 |
+
image,
|
273 |
+
caption=prompt # positive prompt as a caption
|
274 |
+
)},
|
275 |
+
commit=False
|
276 |
+
)
|
277 |
+
|
278 |
+
|
279 |
+
def time_shift(mu: float, sigma: float, t: torch.Tensor):
|
280 |
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
281 |
+
|
282 |
+
|
283 |
+
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
|
284 |
+
m = (y2 - y1) / (x2 - x1)
|
285 |
+
b = y1 - m * x1
|
286 |
+
return lambda x: m * x + b
|
287 |
+
|
288 |
+
|
289 |
+
def get_schedule(
|
290 |
+
num_steps: int,
|
291 |
+
image_seq_len: int,
|
292 |
+
base_shift: float = 0.5,
|
293 |
+
max_shift: float = 1.15,
|
294 |
+
shift: bool = True,
|
295 |
+
) -> list[float]:
|
296 |
+
# extra step for zero
|
297 |
+
timesteps = torch.linspace(1, 0, num_steps + 1)
|
298 |
+
|
299 |
+
# shifting the schedule to favor high timesteps for higher signal images
|
300 |
+
if shift:
|
301 |
+
# eastimate mu based on linear estimation between two points
|
302 |
+
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
303 |
+
timesteps = time_shift(mu, 1.0, timesteps)
|
304 |
+
|
305 |
+
return timesteps.tolist()
|
306 |
+
|
307 |
+
|
308 |
+
def denoise(
|
309 |
+
args: argparse.Namespace,
|
310 |
+
model: flux_models.Flux,
|
311 |
+
img: torch.Tensor,
|
312 |
+
img_ids: torch.Tensor,
|
313 |
+
txt: torch.Tensor,
|
314 |
+
txt_ids: torch.Tensor,
|
315 |
+
vec: torch.Tensor,
|
316 |
+
timesteps: list[float],
|
317 |
+
guidance: float = 4.0,
|
318 |
+
t5_attn_mask: Optional[torch.Tensor] = None,
|
319 |
+
ae_outputs: torch.Tensor = None,
|
320 |
+
):
|
321 |
+
# this is ignored for schnell
|
322 |
+
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
323 |
+
img_ids = img_ids.to(img.device)
|
324 |
+
txt_ids = txt_ids.to(img.device)
|
325 |
+
vec = vec.to(img.device)
|
326 |
+
txt = txt.to(img.device)
|
327 |
+
|
328 |
+
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
329 |
+
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
330 |
+
model.prepare_block_swap_before_forward()
|
331 |
+
if args.frame_num == 4:
|
332 |
+
packed_latent_height, packed_latent_width = ae_outputs.shape[2]*2 // 2, ae_outputs.shape[3]*2 // 2
|
333 |
+
img = flux_utils.unpack_latents(img, packed_latent_height, packed_latent_width)
|
334 |
+
img[:,:, img.shape[2] // 2: img.shape[2], :img.shape[3] // 2] = ae_outputs
|
335 |
+
else:
|
336 |
+
packed_latent_height, packed_latent_width = ae_outputs.shape[2]*3 // 2, ae_outputs.shape[3]*3 // 2
|
337 |
+
img = flux_utils.unpack_latents(img, packed_latent_height, packed_latent_width)
|
338 |
+
img[:,:, 2*img.shape[2] // 3: img.shape[2], 2*img.shape[3] // 3:img.shape[3]] = ae_outputs
|
339 |
+
|
340 |
+
img = flux_utils.pack_latents(img)
|
341 |
+
pred = model(
|
342 |
+
img=img,
|
343 |
+
img_ids=img_ids,
|
344 |
+
txt=txt,
|
345 |
+
txt_ids=txt_ids,
|
346 |
+
y=vec,
|
347 |
+
timesteps=t_vec,
|
348 |
+
guidance=guidance_vec,
|
349 |
+
txt_attention_mask=t5_attn_mask,
|
350 |
+
)
|
351 |
+
|
352 |
+
img = img + (t_prev - t_curr) * pred
|
353 |
+
|
354 |
+
model.prepare_block_swap_before_forward()
|
355 |
+
return img
|
356 |
+
|
357 |
+
|
358 |
+
# endregion
|
359 |
+
|
360 |
+
|
361 |
+
# region train
|
362 |
+
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
|
363 |
+
sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
|
364 |
+
schedule_timesteps = noise_scheduler.timesteps.to(device)
|
365 |
+
timesteps = timesteps.to(device)
|
366 |
+
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
367 |
+
|
368 |
+
sigma = sigmas[step_indices].flatten()
|
369 |
+
while len(sigma.shape) < n_dim:
|
370 |
+
sigma = sigma.unsqueeze(-1)
|
371 |
+
return sigma
|
372 |
+
|
373 |
+
|
374 |
+
def compute_density_for_timestep_sampling(
|
375 |
+
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
|
376 |
+
):
|
377 |
+
"""Compute the density for sampling the timesteps when doing SD3 training.
|
378 |
+
|
379 |
+
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
380 |
+
|
381 |
+
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
382 |
+
"""
|
383 |
+
if weighting_scheme == "logit_normal":
|
384 |
+
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
|
385 |
+
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
|
386 |
+
u = torch.nn.functional.sigmoid(u)
|
387 |
+
elif weighting_scheme == "mode":
|
388 |
+
u = torch.rand(size=(batch_size,), device="cpu")
|
389 |
+
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
|
390 |
+
else:
|
391 |
+
u = torch.rand(size=(batch_size,), device="cpu")
|
392 |
+
return u
|
393 |
+
|
394 |
+
|
395 |
+
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
396 |
+
"""Computes loss weighting scheme for SD3 training.
|
397 |
+
|
398 |
+
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
399 |
+
|
400 |
+
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
401 |
+
"""
|
402 |
+
if weighting_scheme == "sigma_sqrt":
|
403 |
+
weighting = (sigmas**-2.0).float()
|
404 |
+
elif weighting_scheme == "cosmap":
|
405 |
+
bot = 1 - 2 * sigmas + 2 * sigmas**2
|
406 |
+
weighting = 2 / (math.pi * bot)
|
407 |
+
else:
|
408 |
+
weighting = torch.ones_like(sigmas)
|
409 |
+
return weighting
|
410 |
+
|
411 |
+
|
412 |
+
def get_noisy_model_input_and_timesteps(
|
413 |
+
args, noise_scheduler, latents, noise, device, dtype
|
414 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
415 |
+
bsz, _, h, w = latents.shape
|
416 |
+
sigmas = None
|
417 |
+
|
418 |
+
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
|
419 |
+
# Simple random t-based noise sampling
|
420 |
+
if args.timestep_sampling == "sigmoid":
|
421 |
+
# https://github.com/XLabs-AI/x-flux/tree/main
|
422 |
+
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
|
423 |
+
else:
|
424 |
+
t = torch.rand((bsz,), device=device)
|
425 |
+
|
426 |
+
timesteps = t * 1000.0
|
427 |
+
t = t.view(-1, 1, 1, 1)
|
428 |
+
noisy_model_input = (1 - t) * latents + t * noise
|
429 |
+
elif args.timestep_sampling == "shift":
|
430 |
+
shift = args.discrete_flow_shift
|
431 |
+
logits_norm = torch.randn(bsz, device=device)
|
432 |
+
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
|
433 |
+
timesteps = logits_norm.sigmoid()
|
434 |
+
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
|
435 |
+
|
436 |
+
t = timesteps.view(-1, 1, 1, 1)
|
437 |
+
timesteps = timesteps * 1000.0
|
438 |
+
noisy_model_input = (1 - t) * latents + t * noise
|
439 |
+
elif args.timestep_sampling == "flux_shift":
|
440 |
+
logits_norm = torch.randn(bsz, device=device)
|
441 |
+
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
|
442 |
+
timesteps = logits_norm.sigmoid()
|
443 |
+
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
|
444 |
+
timesteps = time_shift(mu, 1.0, timesteps)
|
445 |
+
|
446 |
+
t = timesteps.view(-1, 1, 1, 1)
|
447 |
+
timesteps = timesteps * 1000.0
|
448 |
+
noisy_model_input = (1 - t) * latents + t * noise
|
449 |
+
else:
|
450 |
+
# Sample a random timestep for each image
|
451 |
+
# for weighting schemes where we sample timesteps non-uniformly
|
452 |
+
u = compute_density_for_timestep_sampling(
|
453 |
+
weighting_scheme=args.weighting_scheme,
|
454 |
+
batch_size=bsz,
|
455 |
+
logit_mean=args.logit_mean,
|
456 |
+
logit_std=args.logit_std,
|
457 |
+
mode_scale=args.mode_scale,
|
458 |
+
)
|
459 |
+
indices = (u * noise_scheduler.config.num_train_timesteps).long()
|
460 |
+
timesteps = noise_scheduler.timesteps[indices].to(device=device)
|
461 |
+
|
462 |
+
# Add noise according to flow matching.
|
463 |
+
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
|
464 |
+
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
|
465 |
+
|
466 |
+
# 替换部分区域为原始latents
|
467 |
+
h, w = noisy_model_input.shape[2], noisy_model_input.shape[3]
|
468 |
+
# import pdb; pdb.set_trace()
|
469 |
+
if args.frame_num == 4:
|
470 |
+
noisy_model_input[:, :, h//2 : h, w//2 : w] = latents[:, :, h//2:h, w//2:w]
|
471 |
+
else:
|
472 |
+
noisy_model_input[:, :, 2*h//3 : h, 2*w//3 : w] = latents[:, :, 2*h//3:h, 2*w//3:w]
|
473 |
+
|
474 |
+
|
475 |
+
return noisy_model_input, timesteps, sigmas
|
476 |
+
|
477 |
+
|
478 |
+
def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
|
479 |
+
weighting = None
|
480 |
+
if args.model_prediction_type == "raw":
|
481 |
+
pass
|
482 |
+
elif args.model_prediction_type == "additive":
|
483 |
+
# add the model_pred to the noisy_model_input
|
484 |
+
model_pred = model_pred + noisy_model_input
|
485 |
+
elif args.model_prediction_type == "sigma_scaled":
|
486 |
+
# apply sigma scaling
|
487 |
+
model_pred = model_pred * (-sigmas) + noisy_model_input
|
488 |
+
|
489 |
+
# these weighting schemes use a uniform timestep sampling
|
490 |
+
# and instead post-weight the loss
|
491 |
+
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
|
492 |
+
|
493 |
+
return model_pred, weighting
|
494 |
+
|
495 |
+
|
496 |
+
def save_models(
|
497 |
+
ckpt_path: str,
|
498 |
+
flux: flux_models.Flux,
|
499 |
+
sai_metadata: Optional[dict],
|
500 |
+
save_dtype: Optional[torch.dtype] = None,
|
501 |
+
use_mem_eff_save: bool = False,
|
502 |
+
):
|
503 |
+
state_dict = {}
|
504 |
+
|
505 |
+
def update_sd(prefix, sd):
|
506 |
+
for k, v in sd.items():
|
507 |
+
key = prefix + k
|
508 |
+
if save_dtype is not None and v.dtype != save_dtype:
|
509 |
+
v = v.detach().clone().to("cpu").to(save_dtype)
|
510 |
+
state_dict[key] = v
|
511 |
+
|
512 |
+
update_sd("", flux.state_dict())
|
513 |
+
|
514 |
+
if not use_mem_eff_save:
|
515 |
+
save_file(state_dict, ckpt_path, metadata=sai_metadata)
|
516 |
+
else:
|
517 |
+
mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata)
|
518 |
+
|
519 |
+
|
520 |
+
def save_flux_model_on_train_end(
|
521 |
+
args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, flux: flux_models.Flux
|
522 |
+
):
|
523 |
+
def sd_saver(ckpt_file, epoch_no, global_step):
|
524 |
+
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
|
525 |
+
save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
|
526 |
+
|
527 |
+
train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
|
528 |
+
|
529 |
+
|
530 |
+
# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している
|
531 |
+
# on_epoch_end: Trueならepoch終了時、Falseならstep経過時
|
532 |
+
def save_flux_model_on_epoch_end_or_stepwise(
|
533 |
+
args: argparse.Namespace,
|
534 |
+
on_epoch_end: bool,
|
535 |
+
accelerator,
|
536 |
+
save_dtype: torch.dtype,
|
537 |
+
epoch: int,
|
538 |
+
num_train_epochs: int,
|
539 |
+
global_step: int,
|
540 |
+
flux: flux_models.Flux,
|
541 |
+
):
|
542 |
+
def sd_saver(ckpt_file, epoch_no, global_step):
|
543 |
+
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
|
544 |
+
save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
|
545 |
+
|
546 |
+
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
|
547 |
+
args,
|
548 |
+
on_epoch_end,
|
549 |
+
accelerator,
|
550 |
+
True,
|
551 |
+
True,
|
552 |
+
epoch,
|
553 |
+
num_train_epochs,
|
554 |
+
global_step,
|
555 |
+
sd_saver,
|
556 |
+
None,
|
557 |
+
)
|
558 |
+
|
559 |
+
|
560 |
+
# endregion
|
561 |
+
|
562 |
+
|
563 |
+
def add_flux_train_arguments(parser: argparse.ArgumentParser):
|
564 |
+
parser.add_argument(
|
565 |
+
"--clip_l",
|
566 |
+
type=str,
|
567 |
+
help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス(*.sftまたは*.safetensors)、float16が前提",
|
568 |
+
)
|
569 |
+
parser.add_argument(
|
570 |
+
"--t5xxl",
|
571 |
+
type=str,
|
572 |
+
help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)、float16が前提",
|
573 |
+
)
|
574 |
+
parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)")
|
575 |
+
parser.add_argument(
|
576 |
+
"--t5xxl_max_token_length",
|
577 |
+
type=int,
|
578 |
+
default=None,
|
579 |
+
help="maximum token length for T5-XXL. if omitted, 256 for schnell and 512 for dev"
|
580 |
+
" / T5-XXLの最大トークン長。省略された場合、schnellの場合は256、devの場合は512",
|
581 |
+
)
|
582 |
+
parser.add_argument(
|
583 |
+
"--apply_t5_attn_mask",
|
584 |
+
action="store_true",
|
585 |
+
help="apply attention mask to T5-XXL encode and FLUX double blocks / T5-XXLエンコードとFLUXダブルブロックにアテンションマスクを適用する",
|
586 |
+
)
|
587 |
+
parser.add_argument(
|
588 |
+
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
|
589 |
+
)
|
590 |
+
parser.add_argument(
|
591 |
+
"--cache_text_encoder_outputs_to_disk",
|
592 |
+
action="store_true",
|
593 |
+
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
|
594 |
+
)
|
595 |
+
parser.add_argument(
|
596 |
+
"--text_encoder_batch_size",
|
597 |
+
type=int,
|
598 |
+
default=None,
|
599 |
+
help="text encoder batch size (default: None, use dataset's batch size)"
|
600 |
+
+ " / text encoderのバッチサイズ(デフォルト: None, データセットのバッチサイズを使用)",
|
601 |
+
)
|
602 |
+
parser.add_argument(
|
603 |
+
"--disable_mmap_load_safetensors",
|
604 |
+
action="store_true",
|
605 |
+
help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる",
|
606 |
+
)
|
607 |
+
|
608 |
+
# copy from Diffusers
|
609 |
+
parser.add_argument(
|
610 |
+
"--weighting_scheme",
|
611 |
+
type=str,
|
612 |
+
default="none",
|
613 |
+
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
|
614 |
+
)
|
615 |
+
parser.add_argument(
|
616 |
+
"--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
|
617 |
+
)
|
618 |
+
parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.")
|
619 |
+
parser.add_argument(
|
620 |
+
"--mode_scale",
|
621 |
+
type=float,
|
622 |
+
default=1.29,
|
623 |
+
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
|
624 |
+
)
|
625 |
+
parser.add_argument(
|
626 |
+
"--guidance_scale",
|
627 |
+
type=float,
|
628 |
+
default=3.5,
|
629 |
+
help="the FLUX.1 dev variant is a guidance distilled model",
|
630 |
+
)
|
631 |
+
|
632 |
+
parser.add_argument(
|
633 |
+
"--timestep_sampling",
|
634 |
+
choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
|
635 |
+
default="sigma",
|
636 |
+
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and FLUX.1 shifting."
|
637 |
+
" / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、FLUX.1のシフト。",
|
638 |
+
)
|
639 |
+
parser.add_argument(
|
640 |
+
"--sigmoid_scale",
|
641 |
+
type=float,
|
642 |
+
default=1.0,
|
643 |
+
help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。',
|
644 |
+
)
|
645 |
+
parser.add_argument(
|
646 |
+
"--model_prediction_type",
|
647 |
+
choices=["raw", "additive", "sigma_scaled"],
|
648 |
+
default="sigma_scaled",
|
649 |
+
help="How to interpret and process the model prediction: "
|
650 |
+
"raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)."
|
651 |
+
" / モデル予測の解釈と処理方法:"
|
652 |
+
"raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。",
|
653 |
+
)
|
654 |
+
parser.add_argument(
|
655 |
+
"--discrete_flow_shift",
|
656 |
+
type=float,
|
657 |
+
default=3.0,
|
658 |
+
help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
|
659 |
+
)
|
library/flux_utils.py
ADDED
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import replace
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
from typing import List, Optional, Tuple, Union
|
5 |
+
import einops
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from safetensors.torch import load_file
|
9 |
+
from safetensors import safe_open
|
10 |
+
from accelerate import init_empty_weights
|
11 |
+
from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config
|
12 |
+
|
13 |
+
from library.utils import setup_logging
|
14 |
+
|
15 |
+
setup_logging()
|
16 |
+
import logging
|
17 |
+
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
from library import flux_models
|
21 |
+
from library.utils import load_safetensors
|
22 |
+
|
23 |
+
MODEL_VERSION_FLUX_V1 = "flux1"
|
24 |
+
MODEL_NAME_DEV = "dev"
|
25 |
+
MODEL_NAME_SCHNELL = "schnell"
|
26 |
+
|
27 |
+
|
28 |
+
def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
|
29 |
+
"""
|
30 |
+
チェックポイントの状態を分析し、DiffusersかBFLか、devかschnellか、ブロック数を計算して返す。
|
31 |
+
|
32 |
+
Args:
|
33 |
+
ckpt_path (str): チェックポイントファイルまたはディレクトリのパス。
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
Tuple[bool, bool, Tuple[int, int], List[str]]:
|
37 |
+
- bool: Diffusersかどうかを示すフラグ。
|
38 |
+
- bool: Schnellかどうかを示すフラグ。
|
39 |
+
- Tuple[int, int]: ダブルブロックとシングルブロックの数。
|
40 |
+
- List[str]: チェックポイントに含まれるキーのリスト。
|
41 |
+
"""
|
42 |
+
# check the state dict: Diffusers or BFL, dev or schnell, number of blocks
|
43 |
+
logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell")
|
44 |
+
|
45 |
+
if os.path.isdir(ckpt_path): # if ckpt_path is a directory, it is Diffusers
|
46 |
+
ckpt_path = os.path.join(ckpt_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors")
|
47 |
+
if "00001-of-00003" in ckpt_path:
|
48 |
+
ckpt_paths = [ckpt_path.replace("00001-of-00003", f"0000{i}-of-00003") for i in range(1, 4)]
|
49 |
+
else:
|
50 |
+
ckpt_paths = [ckpt_path]
|
51 |
+
|
52 |
+
keys = []
|
53 |
+
for ckpt_path in ckpt_paths:
|
54 |
+
with safe_open(ckpt_path, framework="pt") as f:
|
55 |
+
keys.extend(f.keys())
|
56 |
+
|
57 |
+
# if the key has annoying prefix, remove it
|
58 |
+
if keys[0].startswith("model.diffusion_model."):
|
59 |
+
keys = [key.replace("model.diffusion_model.", "") for key in keys]
|
60 |
+
|
61 |
+
is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys
|
62 |
+
is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys)
|
63 |
+
|
64 |
+
# check number of double and single blocks
|
65 |
+
if not is_diffusers:
|
66 |
+
max_double_block_index = max(
|
67 |
+
[int(key.split(".")[1]) for key in keys if key.startswith("double_blocks.") and key.endswith(".img_attn.proj.bias")]
|
68 |
+
)
|
69 |
+
max_single_block_index = max(
|
70 |
+
[int(key.split(".")[1]) for key in keys if key.startswith("single_blocks.") and key.endswith(".modulation.lin.bias")]
|
71 |
+
)
|
72 |
+
else:
|
73 |
+
max_double_block_index = max(
|
74 |
+
[
|
75 |
+
int(key.split(".")[1])
|
76 |
+
for key in keys
|
77 |
+
if key.startswith("transformer_blocks.") and key.endswith(".attn.add_k_proj.bias")
|
78 |
+
]
|
79 |
+
)
|
80 |
+
max_single_block_index = max(
|
81 |
+
[
|
82 |
+
int(key.split(".")[1])
|
83 |
+
for key in keys
|
84 |
+
if key.startswith("single_transformer_blocks.") and key.endswith(".attn.to_k.bias")
|
85 |
+
]
|
86 |
+
)
|
87 |
+
|
88 |
+
num_double_blocks = max_double_block_index + 1
|
89 |
+
num_single_blocks = max_single_block_index + 1
|
90 |
+
|
91 |
+
return is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths
|
92 |
+
|
93 |
+
|
94 |
+
def load_flow_model(
|
95 |
+
ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
|
96 |
+
) -> Tuple[bool, flux_models.Flux]:
|
97 |
+
is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
|
98 |
+
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
|
99 |
+
|
100 |
+
# build model
|
101 |
+
logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint")
|
102 |
+
with torch.device("meta"):
|
103 |
+
params = flux_models.configs[name].params
|
104 |
+
|
105 |
+
# set the number of blocks
|
106 |
+
if params.depth != num_double_blocks:
|
107 |
+
logger.info(f"Setting the number of double blocks from {params.depth} to {num_double_blocks}")
|
108 |
+
params = replace(params, depth=num_double_blocks)
|
109 |
+
if params.depth_single_blocks != num_single_blocks:
|
110 |
+
logger.info(f"Setting the number of single blocks from {params.depth_single_blocks} to {num_single_blocks}")
|
111 |
+
params = replace(params, depth_single_blocks=num_single_blocks)
|
112 |
+
|
113 |
+
model = flux_models.Flux(params)
|
114 |
+
if dtype is not None:
|
115 |
+
model = model.to(dtype)
|
116 |
+
|
117 |
+
# load_sft doesn't support torch.device
|
118 |
+
logger.info(f"Loading state dict from {ckpt_path}")
|
119 |
+
sd = {}
|
120 |
+
for ckpt_path in ckpt_paths:
|
121 |
+
sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype))
|
122 |
+
|
123 |
+
# convert Diffusers to BFL
|
124 |
+
if is_diffusers:
|
125 |
+
logger.info("Converting Diffusers to BFL")
|
126 |
+
sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
|
127 |
+
logger.info("Converted Diffusers to BFL")
|
128 |
+
|
129 |
+
# if the key has annoying prefix, remove it
|
130 |
+
for key in list(sd.keys()):
|
131 |
+
new_key = key.replace("model.diffusion_model.", "")
|
132 |
+
if new_key == key:
|
133 |
+
break # the model doesn't have annoying prefix
|
134 |
+
sd[new_key] = sd.pop(key)
|
135 |
+
|
136 |
+
info = model.load_state_dict(sd, strict=False, assign=True)
|
137 |
+
logger.info(f"Loaded Flux: {info}")
|
138 |
+
return is_schnell, model
|
139 |
+
|
140 |
+
|
141 |
+
def load_ae(
|
142 |
+
ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
|
143 |
+
) -> flux_models.AutoEncoder:
|
144 |
+
logger.info("Building AutoEncoder")
|
145 |
+
with torch.device("meta"):
|
146 |
+
# dev and schnell have the same AE params
|
147 |
+
ae = flux_models.AutoEncoder(flux_models.configs[MODEL_NAME_DEV].ae_params).to(dtype)
|
148 |
+
|
149 |
+
logger.info(f"Loading state dict from {ckpt_path}")
|
150 |
+
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
|
151 |
+
info = ae.load_state_dict(sd, strict=False, assign=True)
|
152 |
+
logger.info(f"Loaded AE: {info}")
|
153 |
+
return ae
|
154 |
+
|
155 |
+
|
156 |
+
def load_clip_l(
|
157 |
+
ckpt_path: Optional[str],
|
158 |
+
dtype: torch.dtype,
|
159 |
+
device: Union[str, torch.device],
|
160 |
+
disable_mmap: bool = False,
|
161 |
+
state_dict: Optional[dict] = None,
|
162 |
+
) -> CLIPTextModel:
|
163 |
+
logger.info("Building CLIP-L")
|
164 |
+
CLIPL_CONFIG = {
|
165 |
+
"_name_or_path": "clip-vit-large-patch14/",
|
166 |
+
"architectures": ["CLIPModel"],
|
167 |
+
"initializer_factor": 1.0,
|
168 |
+
"logit_scale_init_value": 2.6592,
|
169 |
+
"model_type": "clip",
|
170 |
+
"projection_dim": 768,
|
171 |
+
# "text_config": {
|
172 |
+
"_name_or_path": "",
|
173 |
+
"add_cross_attention": False,
|
174 |
+
"architectures": None,
|
175 |
+
"attention_dropout": 0.0,
|
176 |
+
"bad_words_ids": None,
|
177 |
+
"bos_token_id": 0,
|
178 |
+
"chunk_size_feed_forward": 0,
|
179 |
+
"cross_attention_hidden_size": None,
|
180 |
+
"decoder_start_token_id": None,
|
181 |
+
"diversity_penalty": 0.0,
|
182 |
+
"do_sample": False,
|
183 |
+
"dropout": 0.0,
|
184 |
+
"early_stopping": False,
|
185 |
+
"encoder_no_repeat_ngram_size": 0,
|
186 |
+
"eos_token_id": 2,
|
187 |
+
"finetuning_task": None,
|
188 |
+
"forced_bos_token_id": None,
|
189 |
+
"forced_eos_token_id": None,
|
190 |
+
"hidden_act": "quick_gelu",
|
191 |
+
"hidden_size": 768,
|
192 |
+
"id2label": {"0": "LABEL_0", "1": "LABEL_1"},
|
193 |
+
"initializer_factor": 1.0,
|
194 |
+
"initializer_range": 0.02,
|
195 |
+
"intermediate_size": 3072,
|
196 |
+
"is_decoder": False,
|
197 |
+
"is_encoder_decoder": False,
|
198 |
+
"label2id": {"LABEL_0": 0, "LABEL_1": 1},
|
199 |
+
"layer_norm_eps": 1e-05,
|
200 |
+
"length_penalty": 1.0,
|
201 |
+
"max_length": 20,
|
202 |
+
"max_position_embeddings": 77,
|
203 |
+
"min_length": 0,
|
204 |
+
"model_type": "clip_text_model",
|
205 |
+
"no_repeat_ngram_size": 0,
|
206 |
+
"num_attention_heads": 12,
|
207 |
+
"num_beam_groups": 1,
|
208 |
+
"num_beams": 1,
|
209 |
+
"num_hidden_layers": 12,
|
210 |
+
"num_return_sequences": 1,
|
211 |
+
"output_attentions": False,
|
212 |
+
"output_hidden_states": False,
|
213 |
+
"output_scores": False,
|
214 |
+
"pad_token_id": 1,
|
215 |
+
"prefix": None,
|
216 |
+
"problem_type": None,
|
217 |
+
"projection_dim": 768,
|
218 |
+
"pruned_heads": {},
|
219 |
+
"remove_invalid_values": False,
|
220 |
+
"repetition_penalty": 1.0,
|
221 |
+
"return_dict": True,
|
222 |
+
"return_dict_in_generate": False,
|
223 |
+
"sep_token_id": None,
|
224 |
+
"task_specific_params": None,
|
225 |
+
"temperature": 1.0,
|
226 |
+
"tie_encoder_decoder": False,
|
227 |
+
"tie_word_embeddings": True,
|
228 |
+
"tokenizer_class": None,
|
229 |
+
"top_k": 50,
|
230 |
+
"top_p": 1.0,
|
231 |
+
"torch_dtype": None,
|
232 |
+
"torchscript": False,
|
233 |
+
"transformers_version": "4.16.0.dev0",
|
234 |
+
"use_bfloat16": False,
|
235 |
+
"vocab_size": 49408,
|
236 |
+
"hidden_act": "gelu",
|
237 |
+
"hidden_size": 1280,
|
238 |
+
"intermediate_size": 5120,
|
239 |
+
"num_attention_heads": 20,
|
240 |
+
"num_hidden_layers": 32,
|
241 |
+
# },
|
242 |
+
# "text_config_dict": {
|
243 |
+
"hidden_size": 768,
|
244 |
+
"intermediate_size": 3072,
|
245 |
+
"num_attention_heads": 12,
|
246 |
+
"num_hidden_layers": 12,
|
247 |
+
"projection_dim": 768,
|
248 |
+
# },
|
249 |
+
# "torch_dtype": "float32",
|
250 |
+
# "transformers_version": None,
|
251 |
+
}
|
252 |
+
config = CLIPConfig(**CLIPL_CONFIG)
|
253 |
+
with init_empty_weights():
|
254 |
+
clip = CLIPTextModel._from_config(config)
|
255 |
+
|
256 |
+
if state_dict is not None:
|
257 |
+
sd = state_dict
|
258 |
+
else:
|
259 |
+
logger.info(f"Loading state dict from {ckpt_path}")
|
260 |
+
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
|
261 |
+
info = clip.load_state_dict(sd, strict=False, assign=True)
|
262 |
+
logger.info(f"Loaded CLIP-L: {info}")
|
263 |
+
return clip
|
264 |
+
|
265 |
+
|
266 |
+
def load_t5xxl(
|
267 |
+
ckpt_path: str,
|
268 |
+
dtype: Optional[torch.dtype],
|
269 |
+
device: Union[str, torch.device],
|
270 |
+
disable_mmap: bool = False,
|
271 |
+
state_dict: Optional[dict] = None,
|
272 |
+
) -> T5EncoderModel:
|
273 |
+
T5_CONFIG_JSON = """
|
274 |
+
{
|
275 |
+
"architectures": [
|
276 |
+
"T5EncoderModel"
|
277 |
+
],
|
278 |
+
"classifier_dropout": 0.0,
|
279 |
+
"d_ff": 10240,
|
280 |
+
"d_kv": 64,
|
281 |
+
"d_model": 4096,
|
282 |
+
"decoder_start_token_id": 0,
|
283 |
+
"dense_act_fn": "gelu_new",
|
284 |
+
"dropout_rate": 0.1,
|
285 |
+
"eos_token_id": 1,
|
286 |
+
"feed_forward_proj": "gated-gelu",
|
287 |
+
"initializer_factor": 1.0,
|
288 |
+
"is_encoder_decoder": true,
|
289 |
+
"is_gated_act": true,
|
290 |
+
"layer_norm_epsilon": 1e-06,
|
291 |
+
"model_type": "t5",
|
292 |
+
"num_decoder_layers": 24,
|
293 |
+
"num_heads": 64,
|
294 |
+
"num_layers": 24,
|
295 |
+
"output_past": true,
|
296 |
+
"pad_token_id": 0,
|
297 |
+
"relative_attention_max_distance": 128,
|
298 |
+
"relative_attention_num_buckets": 32,
|
299 |
+
"tie_word_embeddings": false,
|
300 |
+
"torch_dtype": "float16",
|
301 |
+
"transformers_version": "4.41.2",
|
302 |
+
"use_cache": true,
|
303 |
+
"vocab_size": 32128
|
304 |
+
}
|
305 |
+
"""
|
306 |
+
config = json.loads(T5_CONFIG_JSON)
|
307 |
+
config = T5Config(**config)
|
308 |
+
with init_empty_weights():
|
309 |
+
t5xxl = T5EncoderModel._from_config(config)
|
310 |
+
|
311 |
+
if state_dict is not None:
|
312 |
+
sd = state_dict
|
313 |
+
else:
|
314 |
+
logger.info(f"Loading state dict from {ckpt_path}")
|
315 |
+
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
|
316 |
+
info = t5xxl.load_state_dict(sd, strict=False, assign=True)
|
317 |
+
logger.info(f"Loaded T5xxl: {info}")
|
318 |
+
return t5xxl
|
319 |
+
|
320 |
+
|
321 |
+
def get_t5xxl_actual_dtype(t5xxl: T5EncoderModel) -> torch.dtype:
|
322 |
+
# nn.Embedding is the first layer, but it could be casted to bfloat16 or float32
|
323 |
+
return t5xxl.encoder.block[0].layer[0].SelfAttention.q.weight.dtype
|
324 |
+
|
325 |
+
|
326 |
+
def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int):
|
327 |
+
img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3)
|
328 |
+
img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None]
|
329 |
+
img_ids[..., 2] = img_ids[..., 2] + torch.arange(packed_latent_width)[None, :]
|
330 |
+
img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
|
331 |
+
return img_ids
|
332 |
+
|
333 |
+
|
334 |
+
def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor:
|
335 |
+
"""
|
336 |
+
x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2
|
337 |
+
"""
|
338 |
+
x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
|
339 |
+
return x
|
340 |
+
|
341 |
+
|
342 |
+
def pack_latents(x: torch.Tensor) -> torch.Tensor:
|
343 |
+
"""
|
344 |
+
x: [b c (h ph) (w pw)] -> [b (h w) (c ph pw)], ph=2, pw=2
|
345 |
+
"""
|
346 |
+
x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
347 |
+
return x
|
348 |
+
|
349 |
+
|
350 |
+
# region Diffusers
|
351 |
+
|
352 |
+
NUM_DOUBLE_BLOCKS = 19
|
353 |
+
NUM_SINGLE_BLOCKS = 38
|
354 |
+
|
355 |
+
BFL_TO_DIFFUSERS_MAP = {
|
356 |
+
"time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"],
|
357 |
+
"time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"],
|
358 |
+
"time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"],
|
359 |
+
"time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"],
|
360 |
+
"vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"],
|
361 |
+
"vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"],
|
362 |
+
"vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"],
|
363 |
+
"vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"],
|
364 |
+
"guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"],
|
365 |
+
"guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"],
|
366 |
+
"guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"],
|
367 |
+
"guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"],
|
368 |
+
"txt_in.weight": ["context_embedder.weight"],
|
369 |
+
"txt_in.bias": ["context_embedder.bias"],
|
370 |
+
"img_in.weight": ["x_embedder.weight"],
|
371 |
+
"img_in.bias": ["x_embedder.bias"],
|
372 |
+
"double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"],
|
373 |
+
"double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"],
|
374 |
+
"double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"],
|
375 |
+
"double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"],
|
376 |
+
"double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"],
|
377 |
+
"double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"],
|
378 |
+
"double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"],
|
379 |
+
"double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"],
|
380 |
+
"double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"],
|
381 |
+
"double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"],
|
382 |
+
"double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"],
|
383 |
+
"double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"],
|
384 |
+
"double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"],
|
385 |
+
"double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"],
|
386 |
+
"double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"],
|
387 |
+
"double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"],
|
388 |
+
"double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"],
|
389 |
+
"double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"],
|
390 |
+
"double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"],
|
391 |
+
"double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"],
|
392 |
+
"double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"],
|
393 |
+
"double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"],
|
394 |
+
"double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"],
|
395 |
+
"double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"],
|
396 |
+
"single_blocks.().modulation.lin.weight": ["norm.linear.weight"],
|
397 |
+
"single_blocks.().modulation.lin.bias": ["norm.linear.bias"],
|
398 |
+
"single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"],
|
399 |
+
"single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"],
|
400 |
+
"single_blocks.().linear2.weight": ["proj_out.weight"],
|
401 |
+
"single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"],
|
402 |
+
"single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"],
|
403 |
+
"single_blocks.().linear2.weight": ["proj_out.weight"],
|
404 |
+
"single_blocks.().linear2.bias": ["proj_out.bias"],
|
405 |
+
"final_layer.linear.weight": ["proj_out.weight"],
|
406 |
+
"final_layer.linear.bias": ["proj_out.bias"],
|
407 |
+
"final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"],
|
408 |
+
"final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"],
|
409 |
+
}
|
410 |
+
|
411 |
+
|
412 |
+
def make_diffusers_to_bfl_map(num_double_blocks: int, num_single_blocks: int) -> dict[str, tuple[int, str]]:
|
413 |
+
# make reverse map from diffusers map
|
414 |
+
diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key)
|
415 |
+
for b in range(num_double_blocks):
|
416 |
+
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
|
417 |
+
if key.startswith("double_blocks."):
|
418 |
+
block_prefix = f"transformer_blocks.{b}."
|
419 |
+
for i, weight in enumerate(weights):
|
420 |
+
diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}"))
|
421 |
+
for b in range(num_single_blocks):
|
422 |
+
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
|
423 |
+
if key.startswith("single_blocks."):
|
424 |
+
block_prefix = f"single_transformer_blocks.{b}."
|
425 |
+
for i, weight in enumerate(weights):
|
426 |
+
diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}"))
|
427 |
+
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
|
428 |
+
if not (key.startswith("double_blocks.") or key.startswith("single_blocks.")):
|
429 |
+
for i, weight in enumerate(weights):
|
430 |
+
diffusers_to_bfl_map[weight] = (i, key)
|
431 |
+
return diffusers_to_bfl_map
|
432 |
+
|
433 |
+
|
434 |
+
def convert_diffusers_sd_to_bfl(
|
435 |
+
diffusers_sd: dict[str, torch.Tensor], num_double_blocks: int = NUM_DOUBLE_BLOCKS, num_single_blocks: int = NUM_SINGLE_BLOCKS
|
436 |
+
) -> dict[str, torch.Tensor]:
|
437 |
+
diffusers_to_bfl_map = make_diffusers_to_bfl_map(num_double_blocks, num_single_blocks)
|
438 |
+
|
439 |
+
# iterate over three safetensors files to reduce memory usage
|
440 |
+
flux_sd = {}
|
441 |
+
for diffusers_key, tensor in diffusers_sd.items():
|
442 |
+
if diffusers_key in diffusers_to_bfl_map:
|
443 |
+
index, bfl_key = diffusers_to_bfl_map[diffusers_key]
|
444 |
+
if bfl_key not in flux_sd:
|
445 |
+
flux_sd[bfl_key] = []
|
446 |
+
flux_sd[bfl_key].append((index, tensor))
|
447 |
+
else:
|
448 |
+
logger.error(f"Error: Key not found in diffusers_to_bfl_map: {diffusers_key}")
|
449 |
+
raise KeyError(f"Key not found in diffusers_to_bfl_map: {diffusers_key}")
|
450 |
+
|
451 |
+
# concat tensors if multiple tensors are mapped to a single key, sort by index
|
452 |
+
for key, values in flux_sd.items():
|
453 |
+
if len(values) == 1:
|
454 |
+
flux_sd[key] = values[0][1]
|
455 |
+
else:
|
456 |
+
flux_sd[key] = torch.cat([value[1] for value in sorted(values, key=lambda x: x[0])])
|
457 |
+
|
458 |
+
# special case for final_layer.adaLN_modulation.1.weight and final_layer.adaLN_modulation.1.bias
|
459 |
+
def swap_scale_shift(weight):
|
460 |
+
shift, scale = weight.chunk(2, dim=0)
|
461 |
+
new_weight = torch.cat([scale, shift], dim=0)
|
462 |
+
return new_weight
|
463 |
+
|
464 |
+
if "final_layer.adaLN_modulation.1.weight" in flux_sd:
|
465 |
+
flux_sd["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.weight"])
|
466 |
+
if "final_layer.adaLN_modulation.1.bias" in flux_sd:
|
467 |
+
flux_sd["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.bias"])
|
468 |
+
|
469 |
+
return flux_sd
|
470 |
+
|
471 |
+
|
472 |
+
# endregion
|
library/huggingface_util.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union, BinaryIO
|
2 |
+
from huggingface_hub import HfApi
|
3 |
+
from pathlib import Path
|
4 |
+
import argparse
|
5 |
+
import os
|
6 |
+
from library.utils import fire_in_thread
|
7 |
+
from library.utils import setup_logging
|
8 |
+
setup_logging()
|
9 |
+
import logging
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
|
13 |
+
api = HfApi(
|
14 |
+
token=token,
|
15 |
+
)
|
16 |
+
try:
|
17 |
+
api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
|
18 |
+
return True
|
19 |
+
except:
|
20 |
+
return False
|
21 |
+
|
22 |
+
|
23 |
+
def upload(
|
24 |
+
args: argparse.Namespace,
|
25 |
+
src: Union[str, Path, bytes, BinaryIO],
|
26 |
+
dest_suffix: str = "",
|
27 |
+
force_sync_upload: bool = False,
|
28 |
+
):
|
29 |
+
repo_id = args.huggingface_repo_id
|
30 |
+
repo_type = args.huggingface_repo_type
|
31 |
+
token = args.huggingface_token
|
32 |
+
path_in_repo = args.huggingface_path_in_repo + dest_suffix if args.huggingface_path_in_repo is not None else None
|
33 |
+
private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public"
|
34 |
+
api = HfApi(token=token)
|
35 |
+
if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
|
36 |
+
try:
|
37 |
+
api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
|
38 |
+
except Exception as e: # とりあえずRepositoryNotFoundErrorは確認したが他にあると困るので
|
39 |
+
logger.error("===========================================")
|
40 |
+
logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
|
41 |
+
logger.error("===========================================")
|
42 |
+
|
43 |
+
is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir())
|
44 |
+
|
45 |
+
def uploader():
|
46 |
+
try:
|
47 |
+
if is_folder:
|
48 |
+
api.upload_folder(
|
49 |
+
repo_id=repo_id,
|
50 |
+
repo_type=repo_type,
|
51 |
+
folder_path=src,
|
52 |
+
path_in_repo=path_in_repo,
|
53 |
+
)
|
54 |
+
else:
|
55 |
+
api.upload_file(
|
56 |
+
repo_id=repo_id,
|
57 |
+
repo_type=repo_type,
|
58 |
+
path_or_fileobj=src,
|
59 |
+
path_in_repo=path_in_repo,
|
60 |
+
)
|
61 |
+
except Exception as e: # RuntimeErrorを確認済みだが他にあると困るので
|
62 |
+
logger.error("===========================================")
|
63 |
+
logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
|
64 |
+
logger.error("===========================================")
|
65 |
+
|
66 |
+
if args.async_upload and not force_sync_upload:
|
67 |
+
fire_in_thread(uploader)
|
68 |
+
else:
|
69 |
+
uploader()
|
70 |
+
|
71 |
+
|
72 |
+
def list_dir(
|
73 |
+
repo_id: str,
|
74 |
+
subfolder: str,
|
75 |
+
repo_type: str,
|
76 |
+
revision: str = "main",
|
77 |
+
token: str = None,
|
78 |
+
):
|
79 |
+
api = HfApi(
|
80 |
+
token=token,
|
81 |
+
)
|
82 |
+
repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
|
83 |
+
file_list = [file for file in repo_info.siblings if file.rfilename.startswith(subfolder)]
|
84 |
+
return file_list
|
library/hypernetwork.py
ADDED
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from diffusers.models.attention_processor import (
|
4 |
+
Attention,
|
5 |
+
AttnProcessor2_0,
|
6 |
+
SlicedAttnProcessor,
|
7 |
+
XFormersAttnProcessor
|
8 |
+
)
|
9 |
+
|
10 |
+
try:
|
11 |
+
import xformers.ops
|
12 |
+
except:
|
13 |
+
xformers = None
|
14 |
+
|
15 |
+
|
16 |
+
loaded_networks = []
|
17 |
+
|
18 |
+
|
19 |
+
def apply_single_hypernetwork(
|
20 |
+
hypernetwork, hidden_states, encoder_hidden_states
|
21 |
+
):
|
22 |
+
context_k, context_v = hypernetwork.forward(hidden_states, encoder_hidden_states)
|
23 |
+
return context_k, context_v
|
24 |
+
|
25 |
+
|
26 |
+
def apply_hypernetworks(context_k, context_v, layer=None):
|
27 |
+
if len(loaded_networks) == 0:
|
28 |
+
return context_v, context_v
|
29 |
+
for hypernetwork in loaded_networks:
|
30 |
+
context_k, context_v = hypernetwork.forward(context_k, context_v)
|
31 |
+
|
32 |
+
context_k = context_k.to(dtype=context_k.dtype)
|
33 |
+
context_v = context_v.to(dtype=context_k.dtype)
|
34 |
+
|
35 |
+
return context_k, context_v
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
def xformers_forward(
|
40 |
+
self: XFormersAttnProcessor,
|
41 |
+
attn: Attention,
|
42 |
+
hidden_states: torch.Tensor,
|
43 |
+
encoder_hidden_states: torch.Tensor = None,
|
44 |
+
attention_mask: torch.Tensor = None,
|
45 |
+
):
|
46 |
+
batch_size, sequence_length, _ = (
|
47 |
+
hidden_states.shape
|
48 |
+
if encoder_hidden_states is None
|
49 |
+
else encoder_hidden_states.shape
|
50 |
+
)
|
51 |
+
|
52 |
+
attention_mask = attn.prepare_attention_mask(
|
53 |
+
attention_mask, sequence_length, batch_size
|
54 |
+
)
|
55 |
+
|
56 |
+
query = attn.to_q(hidden_states)
|
57 |
+
|
58 |
+
if encoder_hidden_states is None:
|
59 |
+
encoder_hidden_states = hidden_states
|
60 |
+
elif attn.norm_cross:
|
61 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
62 |
+
|
63 |
+
context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
|
64 |
+
|
65 |
+
key = attn.to_k(context_k)
|
66 |
+
value = attn.to_v(context_v)
|
67 |
+
|
68 |
+
query = attn.head_to_batch_dim(query).contiguous()
|
69 |
+
key = attn.head_to_batch_dim(key).contiguous()
|
70 |
+
value = attn.head_to_batch_dim(value).contiguous()
|
71 |
+
|
72 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
73 |
+
query,
|
74 |
+
key,
|
75 |
+
value,
|
76 |
+
attn_bias=attention_mask,
|
77 |
+
op=self.attention_op,
|
78 |
+
scale=attn.scale,
|
79 |
+
)
|
80 |
+
hidden_states = hidden_states.to(query.dtype)
|
81 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
82 |
+
|
83 |
+
# linear proj
|
84 |
+
hidden_states = attn.to_out[0](hidden_states)
|
85 |
+
# dropout
|
86 |
+
hidden_states = attn.to_out[1](hidden_states)
|
87 |
+
return hidden_states
|
88 |
+
|
89 |
+
|
90 |
+
def sliced_attn_forward(
|
91 |
+
self: SlicedAttnProcessor,
|
92 |
+
attn: Attention,
|
93 |
+
hidden_states: torch.Tensor,
|
94 |
+
encoder_hidden_states: torch.Tensor = None,
|
95 |
+
attention_mask: torch.Tensor = None,
|
96 |
+
):
|
97 |
+
batch_size, sequence_length, _ = (
|
98 |
+
hidden_states.shape
|
99 |
+
if encoder_hidden_states is None
|
100 |
+
else encoder_hidden_states.shape
|
101 |
+
)
|
102 |
+
attention_mask = attn.prepare_attention_mask(
|
103 |
+
attention_mask, sequence_length, batch_size
|
104 |
+
)
|
105 |
+
|
106 |
+
query = attn.to_q(hidden_states)
|
107 |
+
dim = query.shape[-1]
|
108 |
+
query = attn.head_to_batch_dim(query)
|
109 |
+
|
110 |
+
if encoder_hidden_states is None:
|
111 |
+
encoder_hidden_states = hidden_states
|
112 |
+
elif attn.norm_cross:
|
113 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
114 |
+
|
115 |
+
context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
|
116 |
+
|
117 |
+
key = attn.to_k(context_k)
|
118 |
+
value = attn.to_v(context_v)
|
119 |
+
key = attn.head_to_batch_dim(key)
|
120 |
+
value = attn.head_to_batch_dim(value)
|
121 |
+
|
122 |
+
batch_size_attention, query_tokens, _ = query.shape
|
123 |
+
hidden_states = torch.zeros(
|
124 |
+
(batch_size_attention, query_tokens, dim // attn.heads),
|
125 |
+
device=query.device,
|
126 |
+
dtype=query.dtype,
|
127 |
+
)
|
128 |
+
|
129 |
+
for i in range(batch_size_attention // self.slice_size):
|
130 |
+
start_idx = i * self.slice_size
|
131 |
+
end_idx = (i + 1) * self.slice_size
|
132 |
+
|
133 |
+
query_slice = query[start_idx:end_idx]
|
134 |
+
key_slice = key[start_idx:end_idx]
|
135 |
+
attn_mask_slice = (
|
136 |
+
attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
137 |
+
)
|
138 |
+
|
139 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
140 |
+
|
141 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
142 |
+
|
143 |
+
hidden_states[start_idx:end_idx] = attn_slice
|
144 |
+
|
145 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
146 |
+
|
147 |
+
# linear proj
|
148 |
+
hidden_states = attn.to_out[0](hidden_states)
|
149 |
+
# dropout
|
150 |
+
hidden_states = attn.to_out[1](hidden_states)
|
151 |
+
|
152 |
+
return hidden_states
|
153 |
+
|
154 |
+
|
155 |
+
def v2_0_forward(
|
156 |
+
self: AttnProcessor2_0,
|
157 |
+
attn: Attention,
|
158 |
+
hidden_states,
|
159 |
+
encoder_hidden_states=None,
|
160 |
+
attention_mask=None,
|
161 |
+
):
|
162 |
+
batch_size, sequence_length, _ = (
|
163 |
+
hidden_states.shape
|
164 |
+
if encoder_hidden_states is None
|
165 |
+
else encoder_hidden_states.shape
|
166 |
+
)
|
167 |
+
inner_dim = hidden_states.shape[-1]
|
168 |
+
|
169 |
+
if attention_mask is not None:
|
170 |
+
attention_mask = attn.prepare_attention_mask(
|
171 |
+
attention_mask, sequence_length, batch_size
|
172 |
+
)
|
173 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
174 |
+
# (batch, heads, source_length, target_length)
|
175 |
+
attention_mask = attention_mask.view(
|
176 |
+
batch_size, attn.heads, -1, attention_mask.shape[-1]
|
177 |
+
)
|
178 |
+
|
179 |
+
query = attn.to_q(hidden_states)
|
180 |
+
|
181 |
+
if encoder_hidden_states is None:
|
182 |
+
encoder_hidden_states = hidden_states
|
183 |
+
elif attn.norm_cross:
|
184 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
185 |
+
|
186 |
+
context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
|
187 |
+
|
188 |
+
key = attn.to_k(context_k)
|
189 |
+
value = attn.to_v(context_v)
|
190 |
+
|
191 |
+
head_dim = inner_dim // attn.heads
|
192 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
193 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
194 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
195 |
+
|
196 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
197 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
198 |
+
hidden_states = F.scaled_dot_product_attention(
|
199 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
200 |
+
)
|
201 |
+
|
202 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(
|
203 |
+
batch_size, -1, attn.heads * head_dim
|
204 |
+
)
|
205 |
+
hidden_states = hidden_states.to(query.dtype)
|
206 |
+
|
207 |
+
# linear proj
|
208 |
+
hidden_states = attn.to_out[0](hidden_states)
|
209 |
+
# dropout
|
210 |
+
hidden_states = attn.to_out[1](hidden_states)
|
211 |
+
return hidden_states
|
212 |
+
|
213 |
+
|
214 |
+
def replace_attentions_for_hypernetwork():
|
215 |
+
import diffusers.models.attention_processor
|
216 |
+
|
217 |
+
diffusers.models.attention_processor.XFormersAttnProcessor.__call__ = (
|
218 |
+
xformers_forward
|
219 |
+
)
|
220 |
+
diffusers.models.attention_processor.SlicedAttnProcessor.__call__ = (
|
221 |
+
sliced_attn_forward
|
222 |
+
)
|
223 |
+
diffusers.models.attention_processor.AttnProcessor2_0.__call__ = v2_0_forward
|
library/ipex/__init__.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import contextlib
|
4 |
+
import torch
|
5 |
+
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
6 |
+
from .hijacks import ipex_hijacks
|
7 |
+
|
8 |
+
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
9 |
+
|
10 |
+
def ipex_init(): # pylint: disable=too-many-statements
|
11 |
+
try:
|
12 |
+
if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked:
|
13 |
+
return True, "Skipping IPEX hijack"
|
14 |
+
else:
|
15 |
+
# Replace cuda with xpu:
|
16 |
+
torch.cuda.current_device = torch.xpu.current_device
|
17 |
+
torch.cuda.current_stream = torch.xpu.current_stream
|
18 |
+
torch.cuda.device = torch.xpu.device
|
19 |
+
torch.cuda.device_count = torch.xpu.device_count
|
20 |
+
torch.cuda.device_of = torch.xpu.device_of
|
21 |
+
torch.cuda.get_device_name = torch.xpu.get_device_name
|
22 |
+
torch.cuda.get_device_properties = torch.xpu.get_device_properties
|
23 |
+
torch.cuda.init = torch.xpu.init
|
24 |
+
torch.cuda.is_available = torch.xpu.is_available
|
25 |
+
torch.cuda.is_initialized = torch.xpu.is_initialized
|
26 |
+
torch.cuda.is_current_stream_capturing = lambda: False
|
27 |
+
torch.cuda.set_device = torch.xpu.set_device
|
28 |
+
torch.cuda.stream = torch.xpu.stream
|
29 |
+
torch.cuda.synchronize = torch.xpu.synchronize
|
30 |
+
torch.cuda.Event = torch.xpu.Event
|
31 |
+
torch.cuda.Stream = torch.xpu.Stream
|
32 |
+
torch.cuda.FloatTensor = torch.xpu.FloatTensor
|
33 |
+
torch.Tensor.cuda = torch.Tensor.xpu
|
34 |
+
torch.Tensor.is_cuda = torch.Tensor.is_xpu
|
35 |
+
torch.nn.Module.cuda = torch.nn.Module.xpu
|
36 |
+
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
|
37 |
+
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
|
38 |
+
torch.cuda._initialized = torch.xpu.lazy_init._initialized
|
39 |
+
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
|
40 |
+
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
|
41 |
+
torch.cuda._tls = torch.xpu.lazy_init._tls
|
42 |
+
torch.cuda.threading = torch.xpu.lazy_init.threading
|
43 |
+
torch.cuda.traceback = torch.xpu.lazy_init.traceback
|
44 |
+
torch.cuda.Optional = torch.xpu.Optional
|
45 |
+
torch.cuda.__cached__ = torch.xpu.__cached__
|
46 |
+
torch.cuda.__loader__ = torch.xpu.__loader__
|
47 |
+
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
|
48 |
+
torch.cuda.Tuple = torch.xpu.Tuple
|
49 |
+
torch.cuda.streams = torch.xpu.streams
|
50 |
+
torch.cuda._lazy_new = torch.xpu._lazy_new
|
51 |
+
torch.cuda.FloatStorage = torch.xpu.FloatStorage
|
52 |
+
torch.cuda.Any = torch.xpu.Any
|
53 |
+
torch.cuda.__doc__ = torch.xpu.__doc__
|
54 |
+
torch.cuda.default_generators = torch.xpu.default_generators
|
55 |
+
torch.cuda.HalfTensor = torch.xpu.HalfTensor
|
56 |
+
torch.cuda._get_device_index = torch.xpu._get_device_index
|
57 |
+
torch.cuda.__path__ = torch.xpu.__path__
|
58 |
+
torch.cuda.Device = torch.xpu.Device
|
59 |
+
torch.cuda.IntTensor = torch.xpu.IntTensor
|
60 |
+
torch.cuda.ByteStorage = torch.xpu.ByteStorage
|
61 |
+
torch.cuda.set_stream = torch.xpu.set_stream
|
62 |
+
torch.cuda.BoolStorage = torch.xpu.BoolStorage
|
63 |
+
torch.cuda.os = torch.xpu.os
|
64 |
+
torch.cuda.torch = torch.xpu.torch
|
65 |
+
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
|
66 |
+
torch.cuda.Union = torch.xpu.Union
|
67 |
+
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
|
68 |
+
torch.cuda.ShortTensor = torch.xpu.ShortTensor
|
69 |
+
torch.cuda.LongTensor = torch.xpu.LongTensor
|
70 |
+
torch.cuda.IntStorage = torch.xpu.IntStorage
|
71 |
+
torch.cuda.LongStorage = torch.xpu.LongStorage
|
72 |
+
torch.cuda.__annotations__ = torch.xpu.__annotations__
|
73 |
+
torch.cuda.__package__ = torch.xpu.__package__
|
74 |
+
torch.cuda.__builtins__ = torch.xpu.__builtins__
|
75 |
+
torch.cuda.CharTensor = torch.xpu.CharTensor
|
76 |
+
torch.cuda.List = torch.xpu.List
|
77 |
+
torch.cuda._lazy_init = torch.xpu._lazy_init
|
78 |
+
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
|
79 |
+
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
|
80 |
+
torch.cuda.ByteTensor = torch.xpu.ByteTensor
|
81 |
+
torch.cuda.StreamContext = torch.xpu.StreamContext
|
82 |
+
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
|
83 |
+
torch.cuda.ShortStorage = torch.xpu.ShortStorage
|
84 |
+
torch.cuda._lazy_call = torch.xpu._lazy_call
|
85 |
+
torch.cuda.HalfStorage = torch.xpu.HalfStorage
|
86 |
+
torch.cuda.random = torch.xpu.random
|
87 |
+
torch.cuda._device = torch.xpu._device
|
88 |
+
torch.cuda.classproperty = torch.xpu.classproperty
|
89 |
+
torch.cuda.__name__ = torch.xpu.__name__
|
90 |
+
torch.cuda._device_t = torch.xpu._device_t
|
91 |
+
torch.cuda.warnings = torch.xpu.warnings
|
92 |
+
torch.cuda.__spec__ = torch.xpu.__spec__
|
93 |
+
torch.cuda.BoolTensor = torch.xpu.BoolTensor
|
94 |
+
torch.cuda.CharStorage = torch.xpu.CharStorage
|
95 |
+
torch.cuda.__file__ = torch.xpu.__file__
|
96 |
+
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
|
97 |
+
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
|
98 |
+
|
99 |
+
# Memory:
|
100 |
+
torch.cuda.memory = torch.xpu.memory
|
101 |
+
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
|
102 |
+
torch.xpu.empty_cache = lambda: None
|
103 |
+
torch.cuda.empty_cache = torch.xpu.empty_cache
|
104 |
+
torch.cuda.memory_stats = torch.xpu.memory_stats
|
105 |
+
torch.cuda.memory_summary = torch.xpu.memory_summary
|
106 |
+
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
|
107 |
+
torch.cuda.memory_allocated = torch.xpu.memory_allocated
|
108 |
+
torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
|
109 |
+
torch.cuda.memory_reserved = torch.xpu.memory_reserved
|
110 |
+
torch.cuda.memory_cached = torch.xpu.memory_reserved
|
111 |
+
torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved
|
112 |
+
torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved
|
113 |
+
torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats
|
114 |
+
torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats
|
115 |
+
torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats
|
116 |
+
torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
|
117 |
+
torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
|
118 |
+
|
119 |
+
# RNG:
|
120 |
+
torch.cuda.get_rng_state = torch.xpu.get_rng_state
|
121 |
+
torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
|
122 |
+
torch.cuda.set_rng_state = torch.xpu.set_rng_state
|
123 |
+
torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all
|
124 |
+
torch.cuda.manual_seed = torch.xpu.manual_seed
|
125 |
+
torch.cuda.manual_seed_all = torch.xpu.manual_seed_all
|
126 |
+
torch.cuda.seed = torch.xpu.seed
|
127 |
+
torch.cuda.seed_all = torch.xpu.seed_all
|
128 |
+
torch.cuda.initial_seed = torch.xpu.initial_seed
|
129 |
+
|
130 |
+
# AMP:
|
131 |
+
torch.cuda.amp = torch.xpu.amp
|
132 |
+
torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
|
133 |
+
torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
|
134 |
+
|
135 |
+
if not hasattr(torch.cuda.amp, "common"):
|
136 |
+
torch.cuda.amp.common = contextlib.nullcontext()
|
137 |
+
torch.cuda.amp.common.amp_definitely_not_available = lambda: False
|
138 |
+
|
139 |
+
try:
|
140 |
+
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
|
141 |
+
except Exception: # pylint: disable=broad-exception-caught
|
142 |
+
try:
|
143 |
+
from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
|
144 |
+
gradscaler_init()
|
145 |
+
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
|
146 |
+
except Exception: # pylint: disable=broad-exception-caught
|
147 |
+
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
|
148 |
+
|
149 |
+
# C
|
150 |
+
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
|
151 |
+
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
|
152 |
+
ipex._C._DeviceProperties.major = 2024
|
153 |
+
ipex._C._DeviceProperties.minor = 0
|
154 |
+
|
155 |
+
# Fix functions with ipex:
|
156 |
+
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
|
157 |
+
torch._utils._get_available_device_type = lambda: "xpu"
|
158 |
+
torch.has_cuda = True
|
159 |
+
torch.cuda.has_half = True
|
160 |
+
torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
|
161 |
+
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
|
162 |
+
torch.backends.cuda.is_built = lambda *args, **kwargs: True
|
163 |
+
torch.version.cuda = "12.1"
|
164 |
+
torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1]
|
165 |
+
torch.cuda.get_device_properties.major = 12
|
166 |
+
torch.cuda.get_device_properties.minor = 1
|
167 |
+
torch.cuda.ipc_collect = lambda *args, **kwargs: None
|
168 |
+
torch.cuda.utilization = lambda *args, **kwargs: 0
|
169 |
+
|
170 |
+
ipex_hijacks()
|
171 |
+
if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None:
|
172 |
+
try:
|
173 |
+
from .diffusers import ipex_diffusers
|
174 |
+
ipex_diffusers()
|
175 |
+
except Exception: # pylint: disable=broad-exception-caught
|
176 |
+
pass
|
177 |
+
torch.cuda.is_xpu_hijacked = True
|
178 |
+
except Exception as e:
|
179 |
+
return False, e
|
180 |
+
return True, None
|
library/ipex/attention.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
4 |
+
from functools import cache
|
5 |
+
|
6 |
+
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
7 |
+
|
8 |
+
# ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers
|
9 |
+
|
10 |
+
sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
|
11 |
+
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
|
12 |
+
|
13 |
+
# Find something divisible with the input_tokens
|
14 |
+
@cache
|
15 |
+
def find_slice_size(slice_size, slice_block_size):
|
16 |
+
while (slice_size * slice_block_size) > attention_slice_rate:
|
17 |
+
slice_size = slice_size // 2
|
18 |
+
if slice_size <= 1:
|
19 |
+
slice_size = 1
|
20 |
+
break
|
21 |
+
return slice_size
|
22 |
+
|
23 |
+
# Find slice sizes for SDPA
|
24 |
+
@cache
|
25 |
+
def find_sdpa_slice_sizes(query_shape, query_element_size):
|
26 |
+
if len(query_shape) == 3:
|
27 |
+
batch_size_attention, query_tokens, shape_three = query_shape
|
28 |
+
shape_four = 1
|
29 |
+
else:
|
30 |
+
batch_size_attention, query_tokens, shape_three, shape_four = query_shape
|
31 |
+
|
32 |
+
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
|
33 |
+
block_size = batch_size_attention * slice_block_size
|
34 |
+
|
35 |
+
split_slice_size = batch_size_attention
|
36 |
+
split_2_slice_size = query_tokens
|
37 |
+
split_3_slice_size = shape_three
|
38 |
+
|
39 |
+
do_split = False
|
40 |
+
do_split_2 = False
|
41 |
+
do_split_3 = False
|
42 |
+
|
43 |
+
if block_size > sdpa_slice_trigger_rate:
|
44 |
+
do_split = True
|
45 |
+
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
|
46 |
+
if split_slice_size * slice_block_size > attention_slice_rate:
|
47 |
+
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
|
48 |
+
do_split_2 = True
|
49 |
+
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
|
50 |
+
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
|
51 |
+
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
|
52 |
+
do_split_3 = True
|
53 |
+
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
|
54 |
+
|
55 |
+
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
56 |
+
|
57 |
+
# Find slice sizes for BMM
|
58 |
+
@cache
|
59 |
+
def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape):
|
60 |
+
batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2]
|
61 |
+
slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size
|
62 |
+
block_size = batch_size_attention * slice_block_size
|
63 |
+
|
64 |
+
split_slice_size = batch_size_attention
|
65 |
+
split_2_slice_size = input_tokens
|
66 |
+
split_3_slice_size = mat2_atten_shape
|
67 |
+
|
68 |
+
do_split = False
|
69 |
+
do_split_2 = False
|
70 |
+
do_split_3 = False
|
71 |
+
|
72 |
+
if block_size > attention_slice_rate:
|
73 |
+
do_split = True
|
74 |
+
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
|
75 |
+
if split_slice_size * slice_block_size > attention_slice_rate:
|
76 |
+
slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size
|
77 |
+
do_split_2 = True
|
78 |
+
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
|
79 |
+
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
|
80 |
+
slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size
|
81 |
+
do_split_3 = True
|
82 |
+
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
|
83 |
+
|
84 |
+
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
85 |
+
|
86 |
+
|
87 |
+
original_torch_bmm = torch.bmm
|
88 |
+
def torch_bmm_32_bit(input, mat2, *, out=None):
|
89 |
+
if input.device.type != "xpu":
|
90 |
+
return original_torch_bmm(input, mat2, out=out)
|
91 |
+
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape)
|
92 |
+
|
93 |
+
# Slice BMM
|
94 |
+
if do_split:
|
95 |
+
batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2]
|
96 |
+
hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
|
97 |
+
for i in range(batch_size_attention // split_slice_size):
|
98 |
+
start_idx = i * split_slice_size
|
99 |
+
end_idx = (i + 1) * split_slice_size
|
100 |
+
if do_split_2:
|
101 |
+
for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
102 |
+
start_idx_2 = i2 * split_2_slice_size
|
103 |
+
end_idx_2 = (i2 + 1) * split_2_slice_size
|
104 |
+
if do_split_3:
|
105 |
+
for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name
|
106 |
+
start_idx_3 = i3 * split_3_slice_size
|
107 |
+
end_idx_3 = (i3 + 1) * split_3_slice_size
|
108 |
+
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm(
|
109 |
+
input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
110 |
+
mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
111 |
+
out=out
|
112 |
+
)
|
113 |
+
else:
|
114 |
+
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
|
115 |
+
input[start_idx:end_idx, start_idx_2:end_idx_2],
|
116 |
+
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
|
117 |
+
out=out
|
118 |
+
)
|
119 |
+
else:
|
120 |
+
hidden_states[start_idx:end_idx] = original_torch_bmm(
|
121 |
+
input[start_idx:end_idx],
|
122 |
+
mat2[start_idx:end_idx],
|
123 |
+
out=out
|
124 |
+
)
|
125 |
+
torch.xpu.synchronize(input.device)
|
126 |
+
else:
|
127 |
+
return original_torch_bmm(input, mat2, out=out)
|
128 |
+
return hidden_states
|
129 |
+
|
130 |
+
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
131 |
+
def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
|
132 |
+
if query.device.type != "xpu":
|
133 |
+
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
|
134 |
+
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size())
|
135 |
+
|
136 |
+
# Slice SDPA
|
137 |
+
if do_split:
|
138 |
+
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
|
139 |
+
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
|
140 |
+
for i in range(batch_size_attention // split_slice_size):
|
141 |
+
start_idx = i * split_slice_size
|
142 |
+
end_idx = (i + 1) * split_slice_size
|
143 |
+
if do_split_2:
|
144 |
+
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
145 |
+
start_idx_2 = i2 * split_2_slice_size
|
146 |
+
end_idx_2 = (i2 + 1) * split_2_slice_size
|
147 |
+
if do_split_3:
|
148 |
+
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
|
149 |
+
start_idx_3 = i3 * split_3_slice_size
|
150 |
+
end_idx_3 = (i3 + 1) * split_3_slice_size
|
151 |
+
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention(
|
152 |
+
query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
153 |
+
key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
154 |
+
value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
155 |
+
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
|
156 |
+
dropout_p=dropout_p, is_causal=is_causal, **kwargs
|
157 |
+
)
|
158 |
+
else:
|
159 |
+
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
|
160 |
+
query[start_idx:end_idx, start_idx_2:end_idx_2],
|
161 |
+
key[start_idx:end_idx, start_idx_2:end_idx_2],
|
162 |
+
value[start_idx:end_idx, start_idx_2:end_idx_2],
|
163 |
+
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
|
164 |
+
dropout_p=dropout_p, is_causal=is_causal, **kwargs
|
165 |
+
)
|
166 |
+
else:
|
167 |
+
hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
|
168 |
+
query[start_idx:end_idx],
|
169 |
+
key[start_idx:end_idx],
|
170 |
+
value[start_idx:end_idx],
|
171 |
+
attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
|
172 |
+
dropout_p=dropout_p, is_causal=is_causal, **kwargs
|
173 |
+
)
|
174 |
+
torch.xpu.synchronize(query.device)
|
175 |
+
else:
|
176 |
+
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
|
177 |
+
return hidden_states
|
library/ipex/diffusers.py
ADDED
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
4 |
+
import diffusers #0.24.0 # pylint: disable=import-error
|
5 |
+
from diffusers.models.attention_processor import Attention
|
6 |
+
from diffusers.utils import USE_PEFT_BACKEND
|
7 |
+
from functools import cache
|
8 |
+
|
9 |
+
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
10 |
+
|
11 |
+
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
|
12 |
+
|
13 |
+
@cache
|
14 |
+
def find_slice_size(slice_size, slice_block_size):
|
15 |
+
while (slice_size * slice_block_size) > attention_slice_rate:
|
16 |
+
slice_size = slice_size // 2
|
17 |
+
if slice_size <= 1:
|
18 |
+
slice_size = 1
|
19 |
+
break
|
20 |
+
return slice_size
|
21 |
+
|
22 |
+
@cache
|
23 |
+
def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None):
|
24 |
+
if len(query_shape) == 3:
|
25 |
+
batch_size_attention, query_tokens, shape_three = query_shape
|
26 |
+
shape_four = 1
|
27 |
+
else:
|
28 |
+
batch_size_attention, query_tokens, shape_three, shape_four = query_shape
|
29 |
+
if slice_size is not None:
|
30 |
+
batch_size_attention = slice_size
|
31 |
+
|
32 |
+
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
|
33 |
+
block_size = batch_size_attention * slice_block_size
|
34 |
+
|
35 |
+
split_slice_size = batch_size_attention
|
36 |
+
split_2_slice_size = query_tokens
|
37 |
+
split_3_slice_size = shape_three
|
38 |
+
|
39 |
+
do_split = False
|
40 |
+
do_split_2 = False
|
41 |
+
do_split_3 = False
|
42 |
+
|
43 |
+
if query_device_type != "xpu":
|
44 |
+
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
45 |
+
|
46 |
+
if block_size > attention_slice_rate:
|
47 |
+
do_split = True
|
48 |
+
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
|
49 |
+
if split_slice_size * slice_block_size > attention_slice_rate:
|
50 |
+
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
|
51 |
+
do_split_2 = True
|
52 |
+
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
|
53 |
+
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
|
54 |
+
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
|
55 |
+
do_split_3 = True
|
56 |
+
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
|
57 |
+
|
58 |
+
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
59 |
+
|
60 |
+
class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
|
61 |
+
r"""
|
62 |
+
Processor for implementing sliced attention.
|
63 |
+
|
64 |
+
Args:
|
65 |
+
slice_size (`int`, *optional*):
|
66 |
+
The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
|
67 |
+
`attention_head_dim` must be a multiple of the `slice_size`.
|
68 |
+
"""
|
69 |
+
|
70 |
+
def __init__(self, slice_size):
|
71 |
+
self.slice_size = slice_size
|
72 |
+
|
73 |
+
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
|
74 |
+
encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
|
75 |
+
|
76 |
+
residual = hidden_states
|
77 |
+
|
78 |
+
input_ndim = hidden_states.ndim
|
79 |
+
|
80 |
+
if input_ndim == 4:
|
81 |
+
batch_size, channel, height, width = hidden_states.shape
|
82 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
83 |
+
|
84 |
+
batch_size, sequence_length, _ = (
|
85 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
86 |
+
)
|
87 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
88 |
+
|
89 |
+
if attn.group_norm is not None:
|
90 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
91 |
+
|
92 |
+
query = attn.to_q(hidden_states)
|
93 |
+
dim = query.shape[-1]
|
94 |
+
query = attn.head_to_batch_dim(query)
|
95 |
+
|
96 |
+
if encoder_hidden_states is None:
|
97 |
+
encoder_hidden_states = hidden_states
|
98 |
+
elif attn.norm_cross:
|
99 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
100 |
+
|
101 |
+
key = attn.to_k(encoder_hidden_states)
|
102 |
+
value = attn.to_v(encoder_hidden_states)
|
103 |
+
key = attn.head_to_batch_dim(key)
|
104 |
+
value = attn.head_to_batch_dim(value)
|
105 |
+
|
106 |
+
batch_size_attention, query_tokens, shape_three = query.shape
|
107 |
+
hidden_states = torch.zeros(
|
108 |
+
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
|
109 |
+
)
|
110 |
+
|
111 |
+
####################################################################
|
112 |
+
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
113 |
+
_, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size)
|
114 |
+
|
115 |
+
for i in range(batch_size_attention // split_slice_size):
|
116 |
+
start_idx = i * split_slice_size
|
117 |
+
end_idx = (i + 1) * split_slice_size
|
118 |
+
if do_split_2:
|
119 |
+
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
120 |
+
start_idx_2 = i2 * split_2_slice_size
|
121 |
+
end_idx_2 = (i2 + 1) * split_2_slice_size
|
122 |
+
if do_split_3:
|
123 |
+
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
|
124 |
+
start_idx_3 = i3 * split_3_slice_size
|
125 |
+
end_idx_3 = (i3 + 1) * split_3_slice_size
|
126 |
+
|
127 |
+
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
128 |
+
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
129 |
+
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
|
130 |
+
|
131 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
132 |
+
del query_slice
|
133 |
+
del key_slice
|
134 |
+
del attn_mask_slice
|
135 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
|
136 |
+
|
137 |
+
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
|
138 |
+
del attn_slice
|
139 |
+
else:
|
140 |
+
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
|
141 |
+
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
|
142 |
+
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
|
143 |
+
|
144 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
145 |
+
del query_slice
|
146 |
+
del key_slice
|
147 |
+
del attn_mask_slice
|
148 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
|
149 |
+
|
150 |
+
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
|
151 |
+
del attn_slice
|
152 |
+
torch.xpu.synchronize(query.device)
|
153 |
+
else:
|
154 |
+
query_slice = query[start_idx:end_idx]
|
155 |
+
key_slice = key[start_idx:end_idx]
|
156 |
+
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
157 |
+
|
158 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
159 |
+
del query_slice
|
160 |
+
del key_slice
|
161 |
+
del attn_mask_slice
|
162 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
163 |
+
|
164 |
+
hidden_states[start_idx:end_idx] = attn_slice
|
165 |
+
del attn_slice
|
166 |
+
####################################################################
|
167 |
+
|
168 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
169 |
+
|
170 |
+
# linear proj
|
171 |
+
hidden_states = attn.to_out[0](hidden_states)
|
172 |
+
# dropout
|
173 |
+
hidden_states = attn.to_out[1](hidden_states)
|
174 |
+
|
175 |
+
if input_ndim == 4:
|
176 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
177 |
+
|
178 |
+
if attn.residual_connection:
|
179 |
+
hidden_states = hidden_states + residual
|
180 |
+
|
181 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
182 |
+
|
183 |
+
return hidden_states
|
184 |
+
|
185 |
+
|
186 |
+
class AttnProcessor:
|
187 |
+
r"""
|
188 |
+
Default processor for performing attention-related computations.
|
189 |
+
"""
|
190 |
+
|
191 |
+
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
|
192 |
+
encoder_hidden_states=None, attention_mask=None,
|
193 |
+
temb=None, scale: float = 1.0) -> torch.Tensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
|
194 |
+
|
195 |
+
residual = hidden_states
|
196 |
+
|
197 |
+
args = () if USE_PEFT_BACKEND else (scale,)
|
198 |
+
|
199 |
+
if attn.spatial_norm is not None:
|
200 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
201 |
+
|
202 |
+
input_ndim = hidden_states.ndim
|
203 |
+
|
204 |
+
if input_ndim == 4:
|
205 |
+
batch_size, channel, height, width = hidden_states.shape
|
206 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
207 |
+
|
208 |
+
batch_size, sequence_length, _ = (
|
209 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
210 |
+
)
|
211 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
212 |
+
|
213 |
+
if attn.group_norm is not None:
|
214 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
215 |
+
|
216 |
+
query = attn.to_q(hidden_states, *args)
|
217 |
+
|
218 |
+
if encoder_hidden_states is None:
|
219 |
+
encoder_hidden_states = hidden_states
|
220 |
+
elif attn.norm_cross:
|
221 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
222 |
+
|
223 |
+
key = attn.to_k(encoder_hidden_states, *args)
|
224 |
+
value = attn.to_v(encoder_hidden_states, *args)
|
225 |
+
|
226 |
+
query = attn.head_to_batch_dim(query)
|
227 |
+
key = attn.head_to_batch_dim(key)
|
228 |
+
value = attn.head_to_batch_dim(value)
|
229 |
+
|
230 |
+
####################################################################
|
231 |
+
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
232 |
+
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
|
233 |
+
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
|
234 |
+
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type)
|
235 |
+
|
236 |
+
if do_split:
|
237 |
+
for i in range(batch_size_attention // split_slice_size):
|
238 |
+
start_idx = i * split_slice_size
|
239 |
+
end_idx = (i + 1) * split_slice_size
|
240 |
+
if do_split_2:
|
241 |
+
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
242 |
+
start_idx_2 = i2 * split_2_slice_size
|
243 |
+
end_idx_2 = (i2 + 1) * split_2_slice_size
|
244 |
+
if do_split_3:
|
245 |
+
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
|
246 |
+
start_idx_3 = i3 * split_3_slice_size
|
247 |
+
end_idx_3 = (i3 + 1) * split_3_slice_size
|
248 |
+
|
249 |
+
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
250 |
+
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
251 |
+
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
|
252 |
+
|
253 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
254 |
+
del query_slice
|
255 |
+
del key_slice
|
256 |
+
del attn_mask_slice
|
257 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
|
258 |
+
|
259 |
+
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
|
260 |
+
del attn_slice
|
261 |
+
else:
|
262 |
+
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
|
263 |
+
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
|
264 |
+
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
|
265 |
+
|
266 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
267 |
+
del query_slice
|
268 |
+
del key_slice
|
269 |
+
del attn_mask_slice
|
270 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
|
271 |
+
|
272 |
+
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
|
273 |
+
del attn_slice
|
274 |
+
else:
|
275 |
+
query_slice = query[start_idx:end_idx]
|
276 |
+
key_slice = key[start_idx:end_idx]
|
277 |
+
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
278 |
+
|
279 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
280 |
+
del query_slice
|
281 |
+
del key_slice
|
282 |
+
del attn_mask_slice
|
283 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
284 |
+
|
285 |
+
hidden_states[start_idx:end_idx] = attn_slice
|
286 |
+
del attn_slice
|
287 |
+
torch.xpu.synchronize(query.device)
|
288 |
+
else:
|
289 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
290 |
+
hidden_states = torch.bmm(attention_probs, value)
|
291 |
+
####################################################################
|
292 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
293 |
+
|
294 |
+
# linear proj
|
295 |
+
hidden_states = attn.to_out[0](hidden_states, *args)
|
296 |
+
# dropout
|
297 |
+
hidden_states = attn.to_out[1](hidden_states)
|
298 |
+
|
299 |
+
if input_ndim == 4:
|
300 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
301 |
+
|
302 |
+
if attn.residual_connection:
|
303 |
+
hidden_states = hidden_states + residual
|
304 |
+
|
305 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
306 |
+
|
307 |
+
return hidden_states
|
308 |
+
|
309 |
+
def ipex_diffusers():
|
310 |
+
#ARC GPUs can't allocate more than 4GB to a single block:
|
311 |
+
diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor
|
312 |
+
diffusers.models.attention_processor.AttnProcessor = AttnProcessor
|
library/ipex/gradscaler.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import defaultdict
|
2 |
+
import torch
|
3 |
+
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
4 |
+
import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import
|
5 |
+
|
6 |
+
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
7 |
+
|
8 |
+
device_supports_fp64 = torch.xpu.has_fp64_dtype()
|
9 |
+
OptState = ipex.cpu.autocast._grad_scaler.OptState
|
10 |
+
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
|
11 |
+
_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
|
12 |
+
|
13 |
+
def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument
|
14 |
+
per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
|
15 |
+
per_device_found_inf = _MultiDeviceReplicator(found_inf)
|
16 |
+
|
17 |
+
# To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
|
18 |
+
# There could be hundreds of grads, so we'd like to iterate through them just once.
|
19 |
+
# However, we don't know their devices or dtypes in advance.
|
20 |
+
|
21 |
+
# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
|
22 |
+
# Google says mypy struggles with defaultdicts type annotations.
|
23 |
+
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
|
24 |
+
# sync grad to master weight
|
25 |
+
if hasattr(optimizer, "sync_grad"):
|
26 |
+
optimizer.sync_grad()
|
27 |
+
with torch.no_grad():
|
28 |
+
for group in optimizer.param_groups:
|
29 |
+
for param in group["params"]:
|
30 |
+
if param.grad is None:
|
31 |
+
continue
|
32 |
+
if (not allow_fp16) and param.grad.dtype == torch.float16:
|
33 |
+
raise ValueError("Attempting to unscale FP16 gradients.")
|
34 |
+
if param.grad.is_sparse:
|
35 |
+
# is_coalesced() == False means the sparse grad has values with duplicate indices.
|
36 |
+
# coalesce() deduplicates indices and adds all values that have the same index.
|
37 |
+
# For scaled fp16 values, there's a good chance coalescing will cause overflow,
|
38 |
+
# so we should check the coalesced _values().
|
39 |
+
if param.grad.dtype is torch.float16:
|
40 |
+
param.grad = param.grad.coalesce()
|
41 |
+
to_unscale = param.grad._values()
|
42 |
+
else:
|
43 |
+
to_unscale = param.grad
|
44 |
+
|
45 |
+
# -: is there a way to split by device and dtype without appending in the inner loop?
|
46 |
+
to_unscale = to_unscale.to("cpu")
|
47 |
+
per_device_and_dtype_grads[to_unscale.device][
|
48 |
+
to_unscale.dtype
|
49 |
+
].append(to_unscale)
|
50 |
+
|
51 |
+
for _, per_dtype_grads in per_device_and_dtype_grads.items():
|
52 |
+
for grads in per_dtype_grads.values():
|
53 |
+
core._amp_foreach_non_finite_check_and_unscale_(
|
54 |
+
grads,
|
55 |
+
per_device_found_inf.get("cpu"),
|
56 |
+
per_device_inv_scale.get("cpu"),
|
57 |
+
)
|
58 |
+
|
59 |
+
return per_device_found_inf._per_device_tensors
|
60 |
+
|
61 |
+
def unscale_(self, optimizer):
|
62 |
+
"""
|
63 |
+
Divides ("unscales") the optimizer's gradient tensors by the scale factor.
|
64 |
+
:meth:`unscale_` is optional, serving cases where you need to
|
65 |
+
:ref:`modify or inspect gradients<working-with-unscaled-gradients>`
|
66 |
+
between the backward pass(es) and :meth:`step`.
|
67 |
+
If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
|
68 |
+
Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
|
69 |
+
...
|
70 |
+
scaler.scale(loss).backward()
|
71 |
+
scaler.unscale_(optimizer)
|
72 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
|
73 |
+
scaler.step(optimizer)
|
74 |
+
scaler.update()
|
75 |
+
Args:
|
76 |
+
optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
|
77 |
+
.. warning::
|
78 |
+
:meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
|
79 |
+
and only after all gradients for that optimizer's assigned parameters have been accumulated.
|
80 |
+
Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
|
81 |
+
.. warning::
|
82 |
+
:meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
|
83 |
+
"""
|
84 |
+
if not self._enabled:
|
85 |
+
return
|
86 |
+
|
87 |
+
self._check_scale_growth_tracker("unscale_")
|
88 |
+
|
89 |
+
optimizer_state = self._per_optimizer_states[id(optimizer)]
|
90 |
+
|
91 |
+
if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise
|
92 |
+
raise RuntimeError(
|
93 |
+
"unscale_() has already been called on this optimizer since the last update()."
|
94 |
+
)
|
95 |
+
elif optimizer_state["stage"] is OptState.STEPPED:
|
96 |
+
raise RuntimeError("unscale_() is being called after step().")
|
97 |
+
|
98 |
+
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
|
99 |
+
assert self._scale is not None
|
100 |
+
if device_supports_fp64:
|
101 |
+
inv_scale = self._scale.double().reciprocal().float()
|
102 |
+
else:
|
103 |
+
inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
|
104 |
+
found_inf = torch.full(
|
105 |
+
(1,), 0.0, dtype=torch.float32, device=self._scale.device
|
106 |
+
)
|
107 |
+
|
108 |
+
optimizer_state["found_inf_per_device"] = self._unscale_grads_(
|
109 |
+
optimizer, inv_scale, found_inf, False
|
110 |
+
)
|
111 |
+
optimizer_state["stage"] = OptState.UNSCALED
|
112 |
+
|
113 |
+
def update(self, new_scale=None):
|
114 |
+
"""
|
115 |
+
Updates the scale factor.
|
116 |
+
If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
|
117 |
+
to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
|
118 |
+
the scale is multiplied by ``growth_factor`` to increase it.
|
119 |
+
Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
|
120 |
+
used directly, it's used to fill GradScaler's internal scale tensor. So if
|
121 |
+
``new_scale`` was a tensor, later in-place changes to that tensor will not further
|
122 |
+
affect the scale GradScaler uses internally.)
|
123 |
+
Args:
|
124 |
+
new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor.
|
125 |
+
.. warning::
|
126 |
+
:meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
|
127 |
+
been invoked for all optimizers used this iteration.
|
128 |
+
"""
|
129 |
+
if not self._enabled:
|
130 |
+
return
|
131 |
+
|
132 |
+
_scale, _growth_tracker = self._check_scale_growth_tracker("update")
|
133 |
+
|
134 |
+
if new_scale is not None:
|
135 |
+
# Accept a new user-defined scale.
|
136 |
+
if isinstance(new_scale, float):
|
137 |
+
self._scale.fill_(new_scale) # type: ignore[union-attr]
|
138 |
+
else:
|
139 |
+
reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False."
|
140 |
+
assert isinstance(new_scale, torch.FloatTensor), reason # type: ignore[attr-defined]
|
141 |
+
assert new_scale.numel() == 1, reason
|
142 |
+
assert new_scale.requires_grad is False, reason
|
143 |
+
self._scale.copy_(new_scale) # type: ignore[union-attr]
|
144 |
+
else:
|
145 |
+
# Consume shared inf/nan data collected from optimizers to update the scale.
|
146 |
+
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
|
147 |
+
found_infs = [
|
148 |
+
found_inf.to(device="cpu", non_blocking=True)
|
149 |
+
for state in self._per_optimizer_states.values()
|
150 |
+
for found_inf in state["found_inf_per_device"].values()
|
151 |
+
]
|
152 |
+
|
153 |
+
assert len(found_infs) > 0, "No inf checks were recorded prior to update."
|
154 |
+
|
155 |
+
found_inf_combined = found_infs[0]
|
156 |
+
if len(found_infs) > 1:
|
157 |
+
for i in range(1, len(found_infs)):
|
158 |
+
found_inf_combined += found_infs[i]
|
159 |
+
|
160 |
+
to_device = _scale.device
|
161 |
+
_scale = _scale.to("cpu")
|
162 |
+
_growth_tracker = _growth_tracker.to("cpu")
|
163 |
+
|
164 |
+
core._amp_update_scale_(
|
165 |
+
_scale,
|
166 |
+
_growth_tracker,
|
167 |
+
found_inf_combined,
|
168 |
+
self._growth_factor,
|
169 |
+
self._backoff_factor,
|
170 |
+
self._growth_interval,
|
171 |
+
)
|
172 |
+
|
173 |
+
_scale = _scale.to(to_device)
|
174 |
+
_growth_tracker = _growth_tracker.to(to_device)
|
175 |
+
# To prepare for next iteration, clear the data collected from optimizers this iteration.
|
176 |
+
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
|
177 |
+
|
178 |
+
def gradscaler_init():
|
179 |
+
torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
|
180 |
+
torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_
|
181 |
+
torch.xpu.amp.GradScaler.unscale_ = unscale_
|
182 |
+
torch.xpu.amp.GradScaler.update = update
|
183 |
+
return torch.xpu.amp.GradScaler
|
library/ipex/hijacks.py
ADDED
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from functools import wraps
|
3 |
+
from contextlib import nullcontext
|
4 |
+
import torch
|
5 |
+
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
device_supports_fp64 = torch.xpu.has_fp64_dtype()
|
9 |
+
|
10 |
+
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
|
11 |
+
|
12 |
+
class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
|
13 |
+
def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
|
14 |
+
if isinstance(device_ids, list) and len(device_ids) > 1:
|
15 |
+
print("IPEX backend doesn't support DataParallel on multiple XPU devices")
|
16 |
+
return module.to("xpu")
|
17 |
+
|
18 |
+
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
|
19 |
+
return nullcontext()
|
20 |
+
|
21 |
+
@property
|
22 |
+
def is_cuda(self):
|
23 |
+
return self.device.type == 'xpu' or self.device.type == 'cuda'
|
24 |
+
|
25 |
+
def check_device(device):
|
26 |
+
return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
|
27 |
+
|
28 |
+
def return_xpu(device):
|
29 |
+
return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
|
30 |
+
|
31 |
+
|
32 |
+
# Autocast
|
33 |
+
original_autocast_init = torch.amp.autocast_mode.autocast.__init__
|
34 |
+
@wraps(torch.amp.autocast_mode.autocast.__init__)
|
35 |
+
def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None):
|
36 |
+
if device_type == "cuda":
|
37 |
+
return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
|
38 |
+
else:
|
39 |
+
return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
|
40 |
+
|
41 |
+
# Latent Antialias CPU Offload:
|
42 |
+
original_interpolate = torch.nn.functional.interpolate
|
43 |
+
@wraps(torch.nn.functional.interpolate)
|
44 |
+
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
|
45 |
+
if antialias or align_corners is not None or mode == 'bicubic':
|
46 |
+
return_device = tensor.device
|
47 |
+
return_dtype = tensor.dtype
|
48 |
+
return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
|
49 |
+
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias).to(return_device, dtype=return_dtype)
|
50 |
+
else:
|
51 |
+
return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
|
52 |
+
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
|
53 |
+
|
54 |
+
|
55 |
+
# Diffusers Float64 (Alchemist GPUs doesn't support 64 bit):
|
56 |
+
original_from_numpy = torch.from_numpy
|
57 |
+
@wraps(torch.from_numpy)
|
58 |
+
def from_numpy(ndarray):
|
59 |
+
if ndarray.dtype == float:
|
60 |
+
return original_from_numpy(ndarray.astype('float32'))
|
61 |
+
else:
|
62 |
+
return original_from_numpy(ndarray)
|
63 |
+
|
64 |
+
original_as_tensor = torch.as_tensor
|
65 |
+
@wraps(torch.as_tensor)
|
66 |
+
def as_tensor(data, dtype=None, device=None):
|
67 |
+
if check_device(device):
|
68 |
+
device = return_xpu(device)
|
69 |
+
if isinstance(data, np.ndarray) and data.dtype == float and not (
|
70 |
+
(isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)):
|
71 |
+
return original_as_tensor(data, dtype=torch.float32, device=device)
|
72 |
+
else:
|
73 |
+
return original_as_tensor(data, dtype=dtype, device=device)
|
74 |
+
|
75 |
+
|
76 |
+
if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None:
|
77 |
+
original_torch_bmm = torch.bmm
|
78 |
+
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
79 |
+
else:
|
80 |
+
# 32 bit attention workarounds for Alchemist:
|
81 |
+
try:
|
82 |
+
from .attention import torch_bmm_32_bit as original_torch_bmm
|
83 |
+
from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention
|
84 |
+
except Exception: # pylint: disable=broad-exception-caught
|
85 |
+
original_torch_bmm = torch.bmm
|
86 |
+
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
87 |
+
|
88 |
+
|
89 |
+
# Data Type Errors:
|
90 |
+
@wraps(torch.bmm)
|
91 |
+
def torch_bmm(input, mat2, *, out=None):
|
92 |
+
if input.dtype != mat2.dtype:
|
93 |
+
mat2 = mat2.to(input.dtype)
|
94 |
+
return original_torch_bmm(input, mat2, out=out)
|
95 |
+
|
96 |
+
@wraps(torch.nn.functional.scaled_dot_product_attention)
|
97 |
+
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
|
98 |
+
if query.dtype != key.dtype:
|
99 |
+
key = key.to(dtype=query.dtype)
|
100 |
+
if query.dtype != value.dtype:
|
101 |
+
value = value.to(dtype=query.dtype)
|
102 |
+
if attn_mask is not None and query.dtype != attn_mask.dtype:
|
103 |
+
attn_mask = attn_mask.to(dtype=query.dtype)
|
104 |
+
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
|
105 |
+
|
106 |
+
# A1111 FP16
|
107 |
+
original_functional_group_norm = torch.nn.functional.group_norm
|
108 |
+
@wraps(torch.nn.functional.group_norm)
|
109 |
+
def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
|
110 |
+
if weight is not None and input.dtype != weight.data.dtype:
|
111 |
+
input = input.to(dtype=weight.data.dtype)
|
112 |
+
if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
|
113 |
+
bias.data = bias.data.to(dtype=weight.data.dtype)
|
114 |
+
return original_functional_group_norm(input, num_groups, weight=weight, bias=bias, eps=eps)
|
115 |
+
|
116 |
+
# A1111 BF16
|
117 |
+
original_functional_layer_norm = torch.nn.functional.layer_norm
|
118 |
+
@wraps(torch.nn.functional.layer_norm)
|
119 |
+
def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
|
120 |
+
if weight is not None and input.dtype != weight.data.dtype:
|
121 |
+
input = input.to(dtype=weight.data.dtype)
|
122 |
+
if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
|
123 |
+
bias.data = bias.data.to(dtype=weight.data.dtype)
|
124 |
+
return original_functional_layer_norm(input, normalized_shape, weight=weight, bias=bias, eps=eps)
|
125 |
+
|
126 |
+
# Training
|
127 |
+
original_functional_linear = torch.nn.functional.linear
|
128 |
+
@wraps(torch.nn.functional.linear)
|
129 |
+
def functional_linear(input, weight, bias=None):
|
130 |
+
if input.dtype != weight.data.dtype:
|
131 |
+
input = input.to(dtype=weight.data.dtype)
|
132 |
+
if bias is not None and bias.data.dtype != weight.data.dtype:
|
133 |
+
bias.data = bias.data.to(dtype=weight.data.dtype)
|
134 |
+
return original_functional_linear(input, weight, bias=bias)
|
135 |
+
|
136 |
+
original_functional_conv2d = torch.nn.functional.conv2d
|
137 |
+
@wraps(torch.nn.functional.conv2d)
|
138 |
+
def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
139 |
+
if input.dtype != weight.data.dtype:
|
140 |
+
input = input.to(dtype=weight.data.dtype)
|
141 |
+
if bias is not None and bias.data.dtype != weight.data.dtype:
|
142 |
+
bias.data = bias.data.to(dtype=weight.data.dtype)
|
143 |
+
return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
144 |
+
|
145 |
+
# A1111 Embedding BF16
|
146 |
+
original_torch_cat = torch.cat
|
147 |
+
@wraps(torch.cat)
|
148 |
+
def torch_cat(tensor, *args, **kwargs):
|
149 |
+
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
|
150 |
+
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
|
151 |
+
else:
|
152 |
+
return original_torch_cat(tensor, *args, **kwargs)
|
153 |
+
|
154 |
+
# SwinIR BF16:
|
155 |
+
original_functional_pad = torch.nn.functional.pad
|
156 |
+
@wraps(torch.nn.functional.pad)
|
157 |
+
def functional_pad(input, pad, mode='constant', value=None):
|
158 |
+
if mode == 'reflect' and input.dtype == torch.bfloat16:
|
159 |
+
return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16)
|
160 |
+
else:
|
161 |
+
return original_functional_pad(input, pad, mode=mode, value=value)
|
162 |
+
|
163 |
+
|
164 |
+
original_torch_tensor = torch.tensor
|
165 |
+
@wraps(torch.tensor)
|
166 |
+
def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
|
167 |
+
if check_device(device):
|
168 |
+
device = return_xpu(device)
|
169 |
+
if not device_supports_fp64:
|
170 |
+
if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device):
|
171 |
+
if dtype == torch.float64:
|
172 |
+
dtype = torch.float32
|
173 |
+
elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)):
|
174 |
+
dtype = torch.float32
|
175 |
+
return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)
|
176 |
+
|
177 |
+
original_Tensor_to = torch.Tensor.to
|
178 |
+
@wraps(torch.Tensor.to)
|
179 |
+
def Tensor_to(self, device=None, *args, **kwargs):
|
180 |
+
if check_device(device):
|
181 |
+
return original_Tensor_to(self, return_xpu(device), *args, **kwargs)
|
182 |
+
else:
|
183 |
+
return original_Tensor_to(self, device, *args, **kwargs)
|
184 |
+
|
185 |
+
original_Tensor_cuda = torch.Tensor.cuda
|
186 |
+
@wraps(torch.Tensor.cuda)
|
187 |
+
def Tensor_cuda(self, device=None, *args, **kwargs):
|
188 |
+
if check_device(device):
|
189 |
+
return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs)
|
190 |
+
else:
|
191 |
+
return original_Tensor_cuda(self, device, *args, **kwargs)
|
192 |
+
|
193 |
+
original_Tensor_pin_memory = torch.Tensor.pin_memory
|
194 |
+
@wraps(torch.Tensor.pin_memory)
|
195 |
+
def Tensor_pin_memory(self, device=None, *args, **kwargs):
|
196 |
+
if device is None:
|
197 |
+
device = "xpu"
|
198 |
+
if check_device(device):
|
199 |
+
return original_Tensor_pin_memory(self, return_xpu(device), *args, **kwargs)
|
200 |
+
else:
|
201 |
+
return original_Tensor_pin_memory(self, device, *args, **kwargs)
|
202 |
+
|
203 |
+
original_UntypedStorage_init = torch.UntypedStorage.__init__
|
204 |
+
@wraps(torch.UntypedStorage.__init__)
|
205 |
+
def UntypedStorage_init(*args, device=None, **kwargs):
|
206 |
+
if check_device(device):
|
207 |
+
return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs)
|
208 |
+
else:
|
209 |
+
return original_UntypedStorage_init(*args, device=device, **kwargs)
|
210 |
+
|
211 |
+
original_UntypedStorage_cuda = torch.UntypedStorage.cuda
|
212 |
+
@wraps(torch.UntypedStorage.cuda)
|
213 |
+
def UntypedStorage_cuda(self, device=None, *args, **kwargs):
|
214 |
+
if check_device(device):
|
215 |
+
return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs)
|
216 |
+
else:
|
217 |
+
return original_UntypedStorage_cuda(self, device, *args, **kwargs)
|
218 |
+
|
219 |
+
original_torch_empty = torch.empty
|
220 |
+
@wraps(torch.empty)
|
221 |
+
def torch_empty(*args, device=None, **kwargs):
|
222 |
+
if check_device(device):
|
223 |
+
return original_torch_empty(*args, device=return_xpu(device), **kwargs)
|
224 |
+
else:
|
225 |
+
return original_torch_empty(*args, device=device, **kwargs)
|
226 |
+
|
227 |
+
original_torch_randn = torch.randn
|
228 |
+
@wraps(torch.randn)
|
229 |
+
def torch_randn(*args, device=None, dtype=None, **kwargs):
|
230 |
+
if dtype == bytes:
|
231 |
+
dtype = None
|
232 |
+
if check_device(device):
|
233 |
+
return original_torch_randn(*args, device=return_xpu(device), **kwargs)
|
234 |
+
else:
|
235 |
+
return original_torch_randn(*args, device=device, **kwargs)
|
236 |
+
|
237 |
+
original_torch_ones = torch.ones
|
238 |
+
@wraps(torch.ones)
|
239 |
+
def torch_ones(*args, device=None, **kwargs):
|
240 |
+
if check_device(device):
|
241 |
+
return original_torch_ones(*args, device=return_xpu(device), **kwargs)
|
242 |
+
else:
|
243 |
+
return original_torch_ones(*args, device=device, **kwargs)
|
244 |
+
|
245 |
+
original_torch_zeros = torch.zeros
|
246 |
+
@wraps(torch.zeros)
|
247 |
+
def torch_zeros(*args, device=None, **kwargs):
|
248 |
+
if check_device(device):
|
249 |
+
return original_torch_zeros(*args, device=return_xpu(device), **kwargs)
|
250 |
+
else:
|
251 |
+
return original_torch_zeros(*args, device=device, **kwargs)
|
252 |
+
|
253 |
+
original_torch_linspace = torch.linspace
|
254 |
+
@wraps(torch.linspace)
|
255 |
+
def torch_linspace(*args, device=None, **kwargs):
|
256 |
+
if check_device(device):
|
257 |
+
return original_torch_linspace(*args, device=return_xpu(device), **kwargs)
|
258 |
+
else:
|
259 |
+
return original_torch_linspace(*args, device=device, **kwargs)
|
260 |
+
|
261 |
+
original_torch_Generator = torch.Generator
|
262 |
+
@wraps(torch.Generator)
|
263 |
+
def torch_Generator(device=None):
|
264 |
+
if check_device(device):
|
265 |
+
return original_torch_Generator(return_xpu(device))
|
266 |
+
else:
|
267 |
+
return original_torch_Generator(device)
|
268 |
+
|
269 |
+
original_torch_load = torch.load
|
270 |
+
@wraps(torch.load)
|
271 |
+
def torch_load(f, map_location=None, *args, **kwargs):
|
272 |
+
if map_location is None:
|
273 |
+
map_location = "xpu"
|
274 |
+
if check_device(map_location):
|
275 |
+
return original_torch_load(f, *args, map_location=return_xpu(map_location), **kwargs)
|
276 |
+
else:
|
277 |
+
return original_torch_load(f, *args, map_location=map_location, **kwargs)
|
278 |
+
|
279 |
+
|
280 |
+
# Hijack Functions:
|
281 |
+
def ipex_hijacks():
|
282 |
+
torch.tensor = torch_tensor
|
283 |
+
torch.Tensor.to = Tensor_to
|
284 |
+
torch.Tensor.cuda = Tensor_cuda
|
285 |
+
torch.Tensor.pin_memory = Tensor_pin_memory
|
286 |
+
torch.UntypedStorage.__init__ = UntypedStorage_init
|
287 |
+
torch.UntypedStorage.cuda = UntypedStorage_cuda
|
288 |
+
torch.empty = torch_empty
|
289 |
+
torch.randn = torch_randn
|
290 |
+
torch.ones = torch_ones
|
291 |
+
torch.zeros = torch_zeros
|
292 |
+
torch.linspace = torch_linspace
|
293 |
+
torch.Generator = torch_Generator
|
294 |
+
torch.load = torch_load
|
295 |
+
|
296 |
+
torch.backends.cuda.sdp_kernel = return_null_context
|
297 |
+
torch.nn.DataParallel = DummyDataParallel
|
298 |
+
torch.UntypedStorage.is_cuda = is_cuda
|
299 |
+
torch.amp.autocast_mode.autocast.__init__ = autocast_init
|
300 |
+
|
301 |
+
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
|
302 |
+
torch.nn.functional.group_norm = functional_group_norm
|
303 |
+
torch.nn.functional.layer_norm = functional_layer_norm
|
304 |
+
torch.nn.functional.linear = functional_linear
|
305 |
+
torch.nn.functional.conv2d = functional_conv2d
|
306 |
+
torch.nn.functional.interpolate = interpolate
|
307 |
+
torch.nn.functional.pad = functional_pad
|
308 |
+
|
309 |
+
torch.bmm = torch_bmm
|
310 |
+
torch.cat = torch_cat
|
311 |
+
if not device_supports_fp64:
|
312 |
+
torch.from_numpy = from_numpy
|
313 |
+
torch.as_tensor = as_tensor
|
library/lpw_stable_diffusion.py
ADDED
@@ -0,0 +1,1233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copy from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
|
2 |
+
# and modify to support SD2.x
|
3 |
+
|
4 |
+
import inspect
|
5 |
+
import re
|
6 |
+
from typing import Callable, List, Optional, Union
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import PIL.Image
|
10 |
+
import torch
|
11 |
+
from packaging import version
|
12 |
+
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
13 |
+
|
14 |
+
import diffusers
|
15 |
+
from diffusers import SchedulerMixin, StableDiffusionPipeline
|
16 |
+
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
17 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
|
18 |
+
from diffusers.utils import logging
|
19 |
+
|
20 |
+
try:
|
21 |
+
from diffusers.utils import PIL_INTERPOLATION
|
22 |
+
except ImportError:
|
23 |
+
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
24 |
+
PIL_INTERPOLATION = {
|
25 |
+
"linear": PIL.Image.Resampling.BILINEAR,
|
26 |
+
"bilinear": PIL.Image.Resampling.BILINEAR,
|
27 |
+
"bicubic": PIL.Image.Resampling.BICUBIC,
|
28 |
+
"lanczos": PIL.Image.Resampling.LANCZOS,
|
29 |
+
"nearest": PIL.Image.Resampling.NEAREST,
|
30 |
+
}
|
31 |
+
else:
|
32 |
+
PIL_INTERPOLATION = {
|
33 |
+
"linear": PIL.Image.LINEAR,
|
34 |
+
"bilinear": PIL.Image.BILINEAR,
|
35 |
+
"bicubic": PIL.Image.BICUBIC,
|
36 |
+
"lanczos": PIL.Image.LANCZOS,
|
37 |
+
"nearest": PIL.Image.NEAREST,
|
38 |
+
}
|
39 |
+
# ------------------------------------------------------------------------------
|
40 |
+
|
41 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
42 |
+
|
43 |
+
re_attention = re.compile(
|
44 |
+
r"""
|
45 |
+
\\\(|
|
46 |
+
\\\)|
|
47 |
+
\\\[|
|
48 |
+
\\]|
|
49 |
+
\\\\|
|
50 |
+
\\|
|
51 |
+
\(|
|
52 |
+
\[|
|
53 |
+
:([+-]?[.\d]+)\)|
|
54 |
+
\)|
|
55 |
+
]|
|
56 |
+
[^\\()\[\]:]+|
|
57 |
+
:
|
58 |
+
""",
|
59 |
+
re.X,
|
60 |
+
)
|
61 |
+
|
62 |
+
|
63 |
+
def parse_prompt_attention(text):
|
64 |
+
"""
|
65 |
+
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
|
66 |
+
Accepted tokens are:
|
67 |
+
(abc) - increases attention to abc by a multiplier of 1.1
|
68 |
+
(abc:3.12) - increases attention to abc by a multiplier of 3.12
|
69 |
+
[abc] - decreases attention to abc by a multiplier of 1.1
|
70 |
+
\( - literal character '('
|
71 |
+
\[ - literal character '['
|
72 |
+
\) - literal character ')'
|
73 |
+
\] - literal character ']'
|
74 |
+
\\ - literal character '\'
|
75 |
+
anything else - just text
|
76 |
+
>>> parse_prompt_attention('normal text')
|
77 |
+
[['normal text', 1.0]]
|
78 |
+
>>> parse_prompt_attention('an (important) word')
|
79 |
+
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
|
80 |
+
>>> parse_prompt_attention('(unbalanced')
|
81 |
+
[['unbalanced', 1.1]]
|
82 |
+
>>> parse_prompt_attention('\(literal\]')
|
83 |
+
[['(literal]', 1.0]]
|
84 |
+
>>> parse_prompt_attention('(unnecessary)(parens)')
|
85 |
+
[['unnecessaryparens', 1.1]]
|
86 |
+
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
|
87 |
+
[['a ', 1.0],
|
88 |
+
['house', 1.5730000000000004],
|
89 |
+
[' ', 1.1],
|
90 |
+
['on', 1.0],
|
91 |
+
[' a ', 1.1],
|
92 |
+
['hill', 0.55],
|
93 |
+
[', sun, ', 1.1],
|
94 |
+
['sky', 1.4641000000000006],
|
95 |
+
['.', 1.1]]
|
96 |
+
"""
|
97 |
+
|
98 |
+
res = []
|
99 |
+
round_brackets = []
|
100 |
+
square_brackets = []
|
101 |
+
|
102 |
+
round_bracket_multiplier = 1.1
|
103 |
+
square_bracket_multiplier = 1 / 1.1
|
104 |
+
|
105 |
+
def multiply_range(start_position, multiplier):
|
106 |
+
for p in range(start_position, len(res)):
|
107 |
+
res[p][1] *= multiplier
|
108 |
+
|
109 |
+
for m in re_attention.finditer(text):
|
110 |
+
text = m.group(0)
|
111 |
+
weight = m.group(1)
|
112 |
+
|
113 |
+
if text.startswith("\\"):
|
114 |
+
res.append([text[1:], 1.0])
|
115 |
+
elif text == "(":
|
116 |
+
round_brackets.append(len(res))
|
117 |
+
elif text == "[":
|
118 |
+
square_brackets.append(len(res))
|
119 |
+
elif weight is not None and len(round_brackets) > 0:
|
120 |
+
multiply_range(round_brackets.pop(), float(weight))
|
121 |
+
elif text == ")" and len(round_brackets) > 0:
|
122 |
+
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
123 |
+
elif text == "]" and len(square_brackets) > 0:
|
124 |
+
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
125 |
+
else:
|
126 |
+
res.append([text, 1.0])
|
127 |
+
|
128 |
+
for pos in round_brackets:
|
129 |
+
multiply_range(pos, round_bracket_multiplier)
|
130 |
+
|
131 |
+
for pos in square_brackets:
|
132 |
+
multiply_range(pos, square_bracket_multiplier)
|
133 |
+
|
134 |
+
if len(res) == 0:
|
135 |
+
res = [["", 1.0]]
|
136 |
+
|
137 |
+
# merge runs of identical weights
|
138 |
+
i = 0
|
139 |
+
while i + 1 < len(res):
|
140 |
+
if res[i][1] == res[i + 1][1]:
|
141 |
+
res[i][0] += res[i + 1][0]
|
142 |
+
res.pop(i + 1)
|
143 |
+
else:
|
144 |
+
i += 1
|
145 |
+
|
146 |
+
return res
|
147 |
+
|
148 |
+
|
149 |
+
def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
|
150 |
+
r"""
|
151 |
+
Tokenize a list of prompts and return its tokens with weights of each token.
|
152 |
+
|
153 |
+
No padding, starting or ending token is included.
|
154 |
+
"""
|
155 |
+
tokens = []
|
156 |
+
weights = []
|
157 |
+
truncated = False
|
158 |
+
for text in prompt:
|
159 |
+
texts_and_weights = parse_prompt_attention(text)
|
160 |
+
text_token = []
|
161 |
+
text_weight = []
|
162 |
+
for word, weight in texts_and_weights:
|
163 |
+
# tokenize and discard the starting and the ending token
|
164 |
+
token = pipe.tokenizer(word).input_ids[1:-1]
|
165 |
+
text_token += token
|
166 |
+
# copy the weight by length of token
|
167 |
+
text_weight += [weight] * len(token)
|
168 |
+
# stop if the text is too long (longer than truncation limit)
|
169 |
+
if len(text_token) > max_length:
|
170 |
+
truncated = True
|
171 |
+
break
|
172 |
+
# truncate
|
173 |
+
if len(text_token) > max_length:
|
174 |
+
truncated = True
|
175 |
+
text_token = text_token[:max_length]
|
176 |
+
text_weight = text_weight[:max_length]
|
177 |
+
tokens.append(text_token)
|
178 |
+
weights.append(text_weight)
|
179 |
+
if truncated:
|
180 |
+
logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
|
181 |
+
return tokens, weights
|
182 |
+
|
183 |
+
|
184 |
+
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
|
185 |
+
r"""
|
186 |
+
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
|
187 |
+
"""
|
188 |
+
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
|
189 |
+
weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
|
190 |
+
for i in range(len(tokens)):
|
191 |
+
tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
|
192 |
+
if no_boseos_middle:
|
193 |
+
weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
|
194 |
+
else:
|
195 |
+
w = []
|
196 |
+
if len(weights[i]) == 0:
|
197 |
+
w = [1.0] * weights_length
|
198 |
+
else:
|
199 |
+
for j in range(max_embeddings_multiples):
|
200 |
+
w.append(1.0) # weight for starting token in this chunk
|
201 |
+
w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
|
202 |
+
w.append(1.0) # weight for ending token in this chunk
|
203 |
+
w += [1.0] * (weights_length - len(w))
|
204 |
+
weights[i] = w[:]
|
205 |
+
|
206 |
+
return tokens, weights
|
207 |
+
|
208 |
+
|
209 |
+
def get_unweighted_text_embeddings(
|
210 |
+
pipe: StableDiffusionPipeline,
|
211 |
+
text_input: torch.Tensor,
|
212 |
+
chunk_length: int,
|
213 |
+
clip_skip: int,
|
214 |
+
eos: int,
|
215 |
+
pad: int,
|
216 |
+
no_boseos_middle: Optional[bool] = True,
|
217 |
+
):
|
218 |
+
"""
|
219 |
+
When the length of tokens is a multiple of the capacity of the text encoder,
|
220 |
+
it should be split into chunks and sent to the text encoder individually.
|
221 |
+
"""
|
222 |
+
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
|
223 |
+
if max_embeddings_multiples > 1:
|
224 |
+
text_embeddings = []
|
225 |
+
for i in range(max_embeddings_multiples):
|
226 |
+
# extract the i-th chunk
|
227 |
+
text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
|
228 |
+
|
229 |
+
# cover the head and the tail by the starting and the ending tokens
|
230 |
+
text_input_chunk[:, 0] = text_input[0, 0]
|
231 |
+
if pad == eos: # v1
|
232 |
+
text_input_chunk[:, -1] = text_input[0, -1]
|
233 |
+
else: # v2
|
234 |
+
for j in range(len(text_input_chunk)):
|
235 |
+
if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
|
236 |
+
text_input_chunk[j, -1] = eos
|
237 |
+
if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
|
238 |
+
text_input_chunk[j, 1] = eos
|
239 |
+
|
240 |
+
if clip_skip is None or clip_skip == 1:
|
241 |
+
text_embedding = pipe.text_encoder(text_input_chunk)[0]
|
242 |
+
else:
|
243 |
+
enc_out = pipe.text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
|
244 |
+
text_embedding = enc_out["hidden_states"][-clip_skip]
|
245 |
+
text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)
|
246 |
+
|
247 |
+
if no_boseos_middle:
|
248 |
+
if i == 0:
|
249 |
+
# discard the ending token
|
250 |
+
text_embedding = text_embedding[:, :-1]
|
251 |
+
elif i == max_embeddings_multiples - 1:
|
252 |
+
# discard the starting token
|
253 |
+
text_embedding = text_embedding[:, 1:]
|
254 |
+
else:
|
255 |
+
# discard both starting and ending tokens
|
256 |
+
text_embedding = text_embedding[:, 1:-1]
|
257 |
+
|
258 |
+
text_embeddings.append(text_embedding)
|
259 |
+
text_embeddings = torch.concat(text_embeddings, axis=1)
|
260 |
+
else:
|
261 |
+
if clip_skip is None or clip_skip == 1:
|
262 |
+
text_embeddings = pipe.text_encoder(text_input)[0]
|
263 |
+
else:
|
264 |
+
enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True)
|
265 |
+
text_embeddings = enc_out["hidden_states"][-clip_skip]
|
266 |
+
text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings)
|
267 |
+
return text_embeddings
|
268 |
+
|
269 |
+
|
270 |
+
def get_weighted_text_embeddings(
|
271 |
+
pipe: StableDiffusionPipeline,
|
272 |
+
prompt: Union[str, List[str]],
|
273 |
+
uncond_prompt: Optional[Union[str, List[str]]] = None,
|
274 |
+
max_embeddings_multiples: Optional[int] = 3,
|
275 |
+
no_boseos_middle: Optional[bool] = False,
|
276 |
+
skip_parsing: Optional[bool] = False,
|
277 |
+
skip_weighting: Optional[bool] = False,
|
278 |
+
clip_skip=None,
|
279 |
+
):
|
280 |
+
r"""
|
281 |
+
Prompts can be assigned with local weights using brackets. For example,
|
282 |
+
prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
|
283 |
+
and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
|
284 |
+
|
285 |
+
Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
|
286 |
+
|
287 |
+
Args:
|
288 |
+
pipe (`StableDiffusionPipeline`):
|
289 |
+
Pipe to provide access to the tokenizer and the text encoder.
|
290 |
+
prompt (`str` or `List[str]`):
|
291 |
+
The prompt or prompts to guide the image generation.
|
292 |
+
uncond_prompt (`str` or `List[str]`):
|
293 |
+
The unconditional prompt or prompts for guide the image generation. If unconditional prompt
|
294 |
+
is provided, the embeddings of prompt and uncond_prompt are concatenated.
|
295 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
296 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
297 |
+
no_boseos_middle (`bool`, *optional*, defaults to `False`):
|
298 |
+
If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
|
299 |
+
ending token in each of the chunk in the middle.
|
300 |
+
skip_parsing (`bool`, *optional*, defaults to `False`):
|
301 |
+
Skip the parsing of brackets.
|
302 |
+
skip_weighting (`bool`, *optional*, defaults to `False`):
|
303 |
+
Skip the weighting. When the parsing is skipped, it is forced True.
|
304 |
+
"""
|
305 |
+
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
306 |
+
if isinstance(prompt, str):
|
307 |
+
prompt = [prompt]
|
308 |
+
|
309 |
+
if not skip_parsing:
|
310 |
+
prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
|
311 |
+
if uncond_prompt is not None:
|
312 |
+
if isinstance(uncond_prompt, str):
|
313 |
+
uncond_prompt = [uncond_prompt]
|
314 |
+
uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
|
315 |
+
else:
|
316 |
+
prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
|
317 |
+
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
|
318 |
+
if uncond_prompt is not None:
|
319 |
+
if isinstance(uncond_prompt, str):
|
320 |
+
uncond_prompt = [uncond_prompt]
|
321 |
+
uncond_tokens = [
|
322 |
+
token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
|
323 |
+
]
|
324 |
+
uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
|
325 |
+
|
326 |
+
# round up the longest length of tokens to a multiple of (model_max_length - 2)
|
327 |
+
max_length = max([len(token) for token in prompt_tokens])
|
328 |
+
if uncond_prompt is not None:
|
329 |
+
max_length = max(max_length, max([len(token) for token in uncond_tokens]))
|
330 |
+
|
331 |
+
max_embeddings_multiples = min(
|
332 |
+
max_embeddings_multiples,
|
333 |
+
(max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
|
334 |
+
)
|
335 |
+
max_embeddings_multiples = max(1, max_embeddings_multiples)
|
336 |
+
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
337 |
+
|
338 |
+
# pad the length of tokens and weights
|
339 |
+
bos = pipe.tokenizer.bos_token_id
|
340 |
+
eos = pipe.tokenizer.eos_token_id
|
341 |
+
pad = pipe.tokenizer.pad_token_id
|
342 |
+
prompt_tokens, prompt_weights = pad_tokens_and_weights(
|
343 |
+
prompt_tokens,
|
344 |
+
prompt_weights,
|
345 |
+
max_length,
|
346 |
+
bos,
|
347 |
+
eos,
|
348 |
+
no_boseos_middle=no_boseos_middle,
|
349 |
+
chunk_length=pipe.tokenizer.model_max_length,
|
350 |
+
)
|
351 |
+
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
|
352 |
+
if uncond_prompt is not None:
|
353 |
+
uncond_tokens, uncond_weights = pad_tokens_and_weights(
|
354 |
+
uncond_tokens,
|
355 |
+
uncond_weights,
|
356 |
+
max_length,
|
357 |
+
bos,
|
358 |
+
eos,
|
359 |
+
no_boseos_middle=no_boseos_middle,
|
360 |
+
chunk_length=pipe.tokenizer.model_max_length,
|
361 |
+
)
|
362 |
+
uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
|
363 |
+
|
364 |
+
# get the embeddings
|
365 |
+
text_embeddings = get_unweighted_text_embeddings(
|
366 |
+
pipe,
|
367 |
+
prompt_tokens,
|
368 |
+
pipe.tokenizer.model_max_length,
|
369 |
+
clip_skip,
|
370 |
+
eos,
|
371 |
+
pad,
|
372 |
+
no_boseos_middle=no_boseos_middle,
|
373 |
+
)
|
374 |
+
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
|
375 |
+
if uncond_prompt is not None:
|
376 |
+
uncond_embeddings = get_unweighted_text_embeddings(
|
377 |
+
pipe,
|
378 |
+
uncond_tokens,
|
379 |
+
pipe.tokenizer.model_max_length,
|
380 |
+
clip_skip,
|
381 |
+
eos,
|
382 |
+
pad,
|
383 |
+
no_boseos_middle=no_boseos_middle,
|
384 |
+
)
|
385 |
+
uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
|
386 |
+
|
387 |
+
# assign weights to the prompts and normalize in the sense of mean
|
388 |
+
# TODO: should we normalize by chunk or in a whole (current implementation)?
|
389 |
+
if (not skip_parsing) and (not skip_weighting):
|
390 |
+
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
391 |
+
text_embeddings *= prompt_weights.unsqueeze(-1)
|
392 |
+
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
393 |
+
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
394 |
+
if uncond_prompt is not None:
|
395 |
+
previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
396 |
+
uncond_embeddings *= uncond_weights.unsqueeze(-1)
|
397 |
+
current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
398 |
+
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
399 |
+
|
400 |
+
if uncond_prompt is not None:
|
401 |
+
return text_embeddings, uncond_embeddings
|
402 |
+
return text_embeddings, None
|
403 |
+
|
404 |
+
|
405 |
+
def preprocess_image(image):
|
406 |
+
w, h = image.size
|
407 |
+
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
408 |
+
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
|
409 |
+
image = np.array(image).astype(np.float32) / 255.0
|
410 |
+
image = image[None].transpose(0, 3, 1, 2)
|
411 |
+
image = torch.from_numpy(image)
|
412 |
+
return 2.0 * image - 1.0
|
413 |
+
|
414 |
+
|
415 |
+
def preprocess_mask(mask, scale_factor=8):
|
416 |
+
mask = mask.convert("L")
|
417 |
+
w, h = mask.size
|
418 |
+
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
419 |
+
mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
|
420 |
+
mask = np.array(mask).astype(np.float32) / 255.0
|
421 |
+
mask = np.tile(mask, (4, 1, 1))
|
422 |
+
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
|
423 |
+
mask = 1 - mask # repaint white, keep black
|
424 |
+
mask = torch.from_numpy(mask)
|
425 |
+
return mask
|
426 |
+
|
427 |
+
|
428 |
+
def prepare_controlnet_image(
|
429 |
+
image: PIL.Image.Image,
|
430 |
+
width: int,
|
431 |
+
height: int,
|
432 |
+
batch_size: int,
|
433 |
+
num_images_per_prompt: int,
|
434 |
+
device: torch.device,
|
435 |
+
dtype: torch.dtype,
|
436 |
+
do_classifier_free_guidance: bool = False,
|
437 |
+
guess_mode: bool = False,
|
438 |
+
):
|
439 |
+
if not isinstance(image, torch.Tensor):
|
440 |
+
if isinstance(image, PIL.Image.Image):
|
441 |
+
image = [image]
|
442 |
+
|
443 |
+
if isinstance(image[0], PIL.Image.Image):
|
444 |
+
images = []
|
445 |
+
|
446 |
+
for image_ in image:
|
447 |
+
image_ = image_.convert("RGB")
|
448 |
+
image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
|
449 |
+
image_ = np.array(image_)
|
450 |
+
image_ = image_[None, :]
|
451 |
+
images.append(image_)
|
452 |
+
|
453 |
+
image = images
|
454 |
+
|
455 |
+
image = np.concatenate(image, axis=0)
|
456 |
+
image = np.array(image).astype(np.float32) / 255.0
|
457 |
+
image = image.transpose(0, 3, 1, 2)
|
458 |
+
image = torch.from_numpy(image)
|
459 |
+
elif isinstance(image[0], torch.Tensor):
|
460 |
+
image = torch.cat(image, dim=0)
|
461 |
+
|
462 |
+
image_batch_size = image.shape[0]
|
463 |
+
|
464 |
+
if image_batch_size == 1:
|
465 |
+
repeat_by = batch_size
|
466 |
+
else:
|
467 |
+
# image batch size is the same as prompt batch size
|
468 |
+
repeat_by = num_images_per_prompt
|
469 |
+
|
470 |
+
image = image.repeat_interleave(repeat_by, dim=0)
|
471 |
+
|
472 |
+
image = image.to(device=device, dtype=dtype)
|
473 |
+
|
474 |
+
if do_classifier_free_guidance and not guess_mode:
|
475 |
+
image = torch.cat([image] * 2)
|
476 |
+
|
477 |
+
return image
|
478 |
+
|
479 |
+
|
480 |
+
class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
481 |
+
r"""
|
482 |
+
Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
|
483 |
+
weighting in prompt.
|
484 |
+
|
485 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
486 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
487 |
+
|
488 |
+
Args:
|
489 |
+
vae ([`AutoencoderKL`]):
|
490 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
491 |
+
text_encoder ([`CLIPTextModel`]):
|
492 |
+
Frozen text-encoder. Stable Diffusion uses the text portion of
|
493 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
494 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
495 |
+
tokenizer (`CLIPTokenizer`):
|
496 |
+
Tokenizer of class
|
497 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
498 |
+
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
499 |
+
scheduler ([`SchedulerMixin`]):
|
500 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
501 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
502 |
+
safety_checker ([`StableDiffusionSafetyChecker`]):
|
503 |
+
Classification module that estimates whether generated images could be considered offensive or harmful.
|
504 |
+
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
|
505 |
+
feature_extractor ([`CLIPFeatureExtractor`]):
|
506 |
+
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
507 |
+
"""
|
508 |
+
|
509 |
+
# if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
|
510 |
+
|
511 |
+
def __init__(
|
512 |
+
self,
|
513 |
+
vae: AutoencoderKL,
|
514 |
+
text_encoder: CLIPTextModel,
|
515 |
+
tokenizer: CLIPTokenizer,
|
516 |
+
unet: UNet2DConditionModel,
|
517 |
+
scheduler: SchedulerMixin,
|
518 |
+
# clip_skip: int,
|
519 |
+
safety_checker: StableDiffusionSafetyChecker,
|
520 |
+
feature_extractor: CLIPFeatureExtractor,
|
521 |
+
requires_safety_checker: bool = True,
|
522 |
+
image_encoder: CLIPVisionModelWithProjection = None,
|
523 |
+
clip_skip: int = 1,
|
524 |
+
):
|
525 |
+
super().__init__(
|
526 |
+
vae=vae,
|
527 |
+
text_encoder=text_encoder,
|
528 |
+
tokenizer=tokenizer,
|
529 |
+
unet=unet,
|
530 |
+
scheduler=scheduler,
|
531 |
+
safety_checker=safety_checker,
|
532 |
+
feature_extractor=feature_extractor,
|
533 |
+
requires_safety_checker=requires_safety_checker,
|
534 |
+
image_encoder=image_encoder,
|
535 |
+
)
|
536 |
+
self.custom_clip_skip = clip_skip
|
537 |
+
self.__init__additional__()
|
538 |
+
|
539 |
+
def __init__additional__(self):
|
540 |
+
if not hasattr(self, "vae_scale_factor"):
|
541 |
+
setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
|
542 |
+
|
543 |
+
@property
|
544 |
+
def _execution_device(self):
|
545 |
+
r"""
|
546 |
+
Returns the device on which the pipeline's models will be executed. After calling
|
547 |
+
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
548 |
+
hooks.
|
549 |
+
"""
|
550 |
+
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
|
551 |
+
return self.device
|
552 |
+
for module in self.unet.modules():
|
553 |
+
if (
|
554 |
+
hasattr(module, "_hf_hook")
|
555 |
+
and hasattr(module._hf_hook, "execution_device")
|
556 |
+
and module._hf_hook.execution_device is not None
|
557 |
+
):
|
558 |
+
return torch.device(module._hf_hook.execution_device)
|
559 |
+
return self.device
|
560 |
+
|
561 |
+
def _encode_prompt(
|
562 |
+
self,
|
563 |
+
prompt,
|
564 |
+
device,
|
565 |
+
num_images_per_prompt,
|
566 |
+
do_classifier_free_guidance,
|
567 |
+
negative_prompt,
|
568 |
+
max_embeddings_multiples,
|
569 |
+
):
|
570 |
+
r"""
|
571 |
+
Encodes the prompt into text encoder hidden states.
|
572 |
+
|
573 |
+
Args:
|
574 |
+
prompt (`str` or `list(int)`):
|
575 |
+
prompt to be encoded
|
576 |
+
device: (`torch.device`):
|
577 |
+
torch device
|
578 |
+
num_images_per_prompt (`int`):
|
579 |
+
number of images that should be generated per prompt
|
580 |
+
do_classifier_free_guidance (`bool`):
|
581 |
+
whether to use classifier free guidance or not
|
582 |
+
negative_prompt (`str` or `List[str]`):
|
583 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
584 |
+
if `guidance_scale` is less than `1`).
|
585 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
586 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
587 |
+
"""
|
588 |
+
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
589 |
+
|
590 |
+
if negative_prompt is None:
|
591 |
+
negative_prompt = [""] * batch_size
|
592 |
+
elif isinstance(negative_prompt, str):
|
593 |
+
negative_prompt = [negative_prompt] * batch_size
|
594 |
+
if batch_size != len(negative_prompt):
|
595 |
+
raise ValueError(
|
596 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
597 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
598 |
+
" the batch size of `prompt`."
|
599 |
+
)
|
600 |
+
|
601 |
+
text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
|
602 |
+
pipe=self,
|
603 |
+
prompt=prompt,
|
604 |
+
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
|
605 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
606 |
+
clip_skip=self.custom_clip_skip,
|
607 |
+
)
|
608 |
+
bs_embed, seq_len, _ = text_embeddings.shape
|
609 |
+
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
|
610 |
+
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
611 |
+
|
612 |
+
if do_classifier_free_guidance:
|
613 |
+
bs_embed, seq_len, _ = uncond_embeddings.shape
|
614 |
+
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
|
615 |
+
uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
616 |
+
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
617 |
+
|
618 |
+
return text_embeddings
|
619 |
+
|
620 |
+
def check_inputs(self, prompt, height, width, strength, callback_steps):
|
621 |
+
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
622 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
623 |
+
|
624 |
+
if strength < 0 or strength > 1:
|
625 |
+
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
626 |
+
|
627 |
+
if height % 8 != 0 or width % 8 != 0:
|
628 |
+
logger.info(f'{height} {width}')
|
629 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
630 |
+
|
631 |
+
if (callback_steps is None) or (
|
632 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
633 |
+
):
|
634 |
+
raise ValueError(
|
635 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}."
|
636 |
+
)
|
637 |
+
|
638 |
+
def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
|
639 |
+
if is_text2img:
|
640 |
+
return self.scheduler.timesteps.to(device), num_inference_steps
|
641 |
+
else:
|
642 |
+
# get the original timestep using init_timestep
|
643 |
+
offset = self.scheduler.config.get("steps_offset", 0)
|
644 |
+
init_timestep = int(num_inference_steps * strength) + offset
|
645 |
+
init_timestep = min(init_timestep, num_inference_steps)
|
646 |
+
|
647 |
+
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
648 |
+
timesteps = self.scheduler.timesteps[t_start:].to(device)
|
649 |
+
return timesteps, num_inference_steps - t_start
|
650 |
+
|
651 |
+
def run_safety_checker(self, image, device, dtype):
|
652 |
+
if self.safety_checker is not None:
|
653 |
+
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
|
654 |
+
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))
|
655 |
+
else:
|
656 |
+
has_nsfw_concept = None
|
657 |
+
return image, has_nsfw_concept
|
658 |
+
|
659 |
+
def decode_latents(self, latents):
|
660 |
+
latents = 1 / 0.18215 * latents
|
661 |
+
image = self.vae.decode(latents).sample
|
662 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
663 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
664 |
+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
665 |
+
return image
|
666 |
+
|
667 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
668 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
669 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
670 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
671 |
+
# and should be between [0, 1]
|
672 |
+
|
673 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
674 |
+
extra_step_kwargs = {}
|
675 |
+
if accepts_eta:
|
676 |
+
extra_step_kwargs["eta"] = eta
|
677 |
+
|
678 |
+
# check if the scheduler accepts generator
|
679 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
680 |
+
if accepts_generator:
|
681 |
+
extra_step_kwargs["generator"] = generator
|
682 |
+
return extra_step_kwargs
|
683 |
+
|
684 |
+
def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None):
|
685 |
+
if image is None:
|
686 |
+
shape = (
|
687 |
+
batch_size,
|
688 |
+
self.unet.in_channels,
|
689 |
+
height // self.vae_scale_factor,
|
690 |
+
width // self.vae_scale_factor,
|
691 |
+
)
|
692 |
+
|
693 |
+
if latents is None:
|
694 |
+
if device.type == "mps":
|
695 |
+
# randn does not work reproducibly on mps
|
696 |
+
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
697 |
+
else:
|
698 |
+
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
699 |
+
else:
|
700 |
+
if latents.shape != shape:
|
701 |
+
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
702 |
+
latents = latents.to(device)
|
703 |
+
|
704 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
705 |
+
latents = latents * self.scheduler.init_noise_sigma
|
706 |
+
return latents, None, None
|
707 |
+
else:
|
708 |
+
init_latent_dist = self.vae.encode(image).latent_dist
|
709 |
+
init_latents = init_latent_dist.sample(generator=generator)
|
710 |
+
init_latents = 0.18215 * init_latents
|
711 |
+
init_latents = torch.cat([init_latents] * batch_size, dim=0)
|
712 |
+
init_latents_orig = init_latents
|
713 |
+
shape = init_latents.shape
|
714 |
+
|
715 |
+
# add noise to latents using the timesteps
|
716 |
+
if device.type == "mps":
|
717 |
+
noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
718 |
+
else:
|
719 |
+
noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
720 |
+
latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
721 |
+
return latents, init_latents_orig, noise
|
722 |
+
|
723 |
+
@torch.no_grad()
|
724 |
+
def __call__(
|
725 |
+
self,
|
726 |
+
prompt: Union[str, List[str]],
|
727 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
728 |
+
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
|
729 |
+
mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
|
730 |
+
height: int = 512,
|
731 |
+
width: int = 512,
|
732 |
+
num_inference_steps: int = 50,
|
733 |
+
guidance_scale: float = 7.5,
|
734 |
+
strength: float = 0.8,
|
735 |
+
num_images_per_prompt: Optional[int] = 1,
|
736 |
+
eta: float = 0.0,
|
737 |
+
generator: Optional[torch.Generator] = None,
|
738 |
+
latents: Optional[torch.FloatTensor] = None,
|
739 |
+
max_embeddings_multiples: Optional[int] = 3,
|
740 |
+
output_type: Optional[str] = "pil",
|
741 |
+
return_dict: bool = True,
|
742 |
+
controlnet=None,
|
743 |
+
controlnet_image=None,
|
744 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
745 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
746 |
+
callback_steps: int = 1,
|
747 |
+
):
|
748 |
+
r"""
|
749 |
+
Function invoked when calling the pipeline for generation.
|
750 |
+
|
751 |
+
Args:
|
752 |
+
prompt (`str` or `List[str]`):
|
753 |
+
The prompt or prompts to guide the image generation.
|
754 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
755 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
756 |
+
if `guidance_scale` is less than `1`).
|
757 |
+
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
758 |
+
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
759 |
+
process.
|
760 |
+
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
|
761 |
+
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
|
762 |
+
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
|
763 |
+
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
|
764 |
+
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
|
765 |
+
height (`int`, *optional*, defaults to 512):
|
766 |
+
The height in pixels of the generated image.
|
767 |
+
width (`int`, *optional*, defaults to 512):
|
768 |
+
The width in pixels of the generated image.
|
769 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
770 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
771 |
+
expense of slower inference.
|
772 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
773 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
774 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
775 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
776 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
777 |
+
usually at the expense of lower image quality.
|
778 |
+
strength (`float`, *optional*, defaults to 0.8):
|
779 |
+
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
|
780 |
+
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The
|
781 |
+
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
|
782 |
+
noise will be maximum and the denoising process will run for the full number of iterations specified in
|
783 |
+
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
|
784 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
785 |
+
The number of images to generate per prompt.
|
786 |
+
eta (`float`, *optional*, defaults to 0.0):
|
787 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
788 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
789 |
+
generator (`torch.Generator`, *optional*):
|
790 |
+
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
791 |
+
deterministic.
|
792 |
+
latents (`torch.FloatTensor`, *optional*):
|
793 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
794 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
795 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
796 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
797 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
798 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
799 |
+
The output format of the generate image. Choose between
|
800 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
801 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
802 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
803 |
+
plain tuple.
|
804 |
+
controlnet (`diffusers.ControlNetModel`, *optional*):
|
805 |
+
A controlnet model to be used for the inference. If not provided, controlnet will be disabled.
|
806 |
+
controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
|
807 |
+
`Image`, or tensor representing an image batch, to be used as the starting point for the controlnet
|
808 |
+
inference.
|
809 |
+
callback (`Callable`, *optional*):
|
810 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
811 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
812 |
+
is_cancelled_callback (`Callable`, *optional*):
|
813 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
814 |
+
`True`, the inference will be cancelled.
|
815 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
816 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
817 |
+
called at every step.
|
818 |
+
|
819 |
+
Returns:
|
820 |
+
`None` if cancelled by `is_cancelled_callback`,
|
821 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
822 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
823 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
824 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
825 |
+
(nsfw) content, according to the `safety_checker`.
|
826 |
+
"""
|
827 |
+
if controlnet is not None and controlnet_image is None:
|
828 |
+
raise ValueError("controlnet_image must be provided if controlnet is not None.")
|
829 |
+
|
830 |
+
# 0. Default height and width to unet
|
831 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
832 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
833 |
+
|
834 |
+
# 1. Check inputs. Raise error if not correct
|
835 |
+
self.check_inputs(prompt, height, width, strength, callback_steps)
|
836 |
+
|
837 |
+
# 2. Define call parameters
|
838 |
+
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
839 |
+
device = self._execution_device
|
840 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
841 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
842 |
+
# corresponds to doing no classifier free guidance.
|
843 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
844 |
+
|
845 |
+
# 3. Encode input prompt
|
846 |
+
text_embeddings = self._encode_prompt(
|
847 |
+
prompt,
|
848 |
+
device,
|
849 |
+
num_images_per_prompt,
|
850 |
+
do_classifier_free_guidance,
|
851 |
+
negative_prompt,
|
852 |
+
max_embeddings_multiples,
|
853 |
+
)
|
854 |
+
dtype = text_embeddings.dtype
|
855 |
+
|
856 |
+
# 4. Preprocess image and mask
|
857 |
+
if isinstance(image, PIL.Image.Image):
|
858 |
+
image = preprocess_image(image)
|
859 |
+
if image is not None:
|
860 |
+
image = image.to(device=self.device, dtype=dtype)
|
861 |
+
if isinstance(mask_image, PIL.Image.Image):
|
862 |
+
mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
|
863 |
+
if mask_image is not None:
|
864 |
+
mask = mask_image.to(device=self.device, dtype=dtype)
|
865 |
+
mask = torch.cat([mask] * batch_size * num_images_per_prompt)
|
866 |
+
else:
|
867 |
+
mask = None
|
868 |
+
|
869 |
+
if controlnet_image is not None:
|
870 |
+
controlnet_image = prepare_controlnet_image(
|
871 |
+
controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False
|
872 |
+
)
|
873 |
+
|
874 |
+
# 5. set timesteps
|
875 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
876 |
+
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
|
877 |
+
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
878 |
+
|
879 |
+
# 6. Prepare latent variables
|
880 |
+
latents, init_latents_orig, noise = self.prepare_latents(
|
881 |
+
image,
|
882 |
+
latent_timestep,
|
883 |
+
batch_size * num_images_per_prompt,
|
884 |
+
height,
|
885 |
+
width,
|
886 |
+
dtype,
|
887 |
+
device,
|
888 |
+
generator,
|
889 |
+
latents,
|
890 |
+
)
|
891 |
+
|
892 |
+
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
893 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
894 |
+
|
895 |
+
# 8. Denoising loop
|
896 |
+
for i, t in enumerate(self.progress_bar(timesteps)):
|
897 |
+
# expand the latents if we are doing classifier free guidance
|
898 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
899 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
900 |
+
|
901 |
+
unet_additional_args = {}
|
902 |
+
if controlnet is not None:
|
903 |
+
down_block_res_samples, mid_block_res_sample = controlnet(
|
904 |
+
latent_model_input,
|
905 |
+
t,
|
906 |
+
encoder_hidden_states=text_embeddings,
|
907 |
+
controlnet_cond=controlnet_image,
|
908 |
+
conditioning_scale=1.0,
|
909 |
+
guess_mode=False,
|
910 |
+
return_dict=False,
|
911 |
+
)
|
912 |
+
unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
|
913 |
+
unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample
|
914 |
+
|
915 |
+
# predict the noise residual
|
916 |
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, **unet_additional_args).sample
|
917 |
+
|
918 |
+
# perform guidance
|
919 |
+
if do_classifier_free_guidance:
|
920 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
921 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
922 |
+
|
923 |
+
# compute the previous noisy sample x_t -> x_t-1
|
924 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
925 |
+
|
926 |
+
if mask is not None:
|
927 |
+
# masking
|
928 |
+
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
|
929 |
+
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
930 |
+
|
931 |
+
# call the callback, if provided
|
932 |
+
if i % callback_steps == 0:
|
933 |
+
if callback is not None:
|
934 |
+
callback(i, t, latents)
|
935 |
+
if is_cancelled_callback is not None and is_cancelled_callback():
|
936 |
+
return None
|
937 |
+
|
938 |
+
return latents
|
939 |
+
|
940 |
+
def latents_to_image(self, latents):
|
941 |
+
# 9. Post-processing
|
942 |
+
image = self.decode_latents(latents.to(self.vae.dtype))
|
943 |
+
image = self.numpy_to_pil(image)
|
944 |
+
return image
|
945 |
+
|
946 |
+
def text2img(
|
947 |
+
self,
|
948 |
+
prompt: Union[str, List[str]],
|
949 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
950 |
+
height: int = 512,
|
951 |
+
width: int = 512,
|
952 |
+
num_inference_steps: int = 50,
|
953 |
+
guidance_scale: float = 7.5,
|
954 |
+
num_images_per_prompt: Optional[int] = 1,
|
955 |
+
eta: float = 0.0,
|
956 |
+
generator: Optional[torch.Generator] = None,
|
957 |
+
latents: Optional[torch.FloatTensor] = None,
|
958 |
+
max_embeddings_multiples: Optional[int] = 3,
|
959 |
+
output_type: Optional[str] = "pil",
|
960 |
+
return_dict: bool = True,
|
961 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
962 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
963 |
+
callback_steps: int = 1,
|
964 |
+
):
|
965 |
+
r"""
|
966 |
+
Function for text-to-image generation.
|
967 |
+
Args:
|
968 |
+
prompt (`str` or `List[str]`):
|
969 |
+
The prompt or prompts to guide the image generation.
|
970 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
971 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
972 |
+
if `guidance_scale` is less than `1`).
|
973 |
+
height (`int`, *optional*, defaults to 512):
|
974 |
+
The height in pixels of the generated image.
|
975 |
+
width (`int`, *optional*, defaults to 512):
|
976 |
+
The width in pixels of the generated image.
|
977 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
978 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
979 |
+
expense of slower inference.
|
980 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
981 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
982 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
983 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
984 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
985 |
+
usually at the expense of lower image quality.
|
986 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
987 |
+
The number of images to generate per prompt.
|
988 |
+
eta (`float`, *optional*, defaults to 0.0):
|
989 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
990 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
991 |
+
generator (`torch.Generator`, *optional*):
|
992 |
+
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
993 |
+
deterministic.
|
994 |
+
latents (`torch.FloatTensor`, *optional*):
|
995 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
996 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
997 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
998 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
999 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
1000 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
1001 |
+
The output format of the generate image. Choose between
|
1002 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
1003 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1004 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
1005 |
+
plain tuple.
|
1006 |
+
callback (`Callable`, *optional*):
|
1007 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
1008 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
1009 |
+
is_cancelled_callback (`Callable`, *optional*):
|
1010 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
1011 |
+
`True`, the inference will be cancelled.
|
1012 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
1013 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
1014 |
+
called at every step.
|
1015 |
+
Returns:
|
1016 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
1017 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
1018 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
1019 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
1020 |
+
(nsfw) content, according to the `safety_checker`.
|
1021 |
+
"""
|
1022 |
+
return self.__call__(
|
1023 |
+
prompt=prompt,
|
1024 |
+
negative_prompt=negative_prompt,
|
1025 |
+
height=height,
|
1026 |
+
width=width,
|
1027 |
+
num_inference_steps=num_inference_steps,
|
1028 |
+
guidance_scale=guidance_scale,
|
1029 |
+
num_images_per_prompt=num_images_per_prompt,
|
1030 |
+
eta=eta,
|
1031 |
+
generator=generator,
|
1032 |
+
latents=latents,
|
1033 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
1034 |
+
output_type=output_type,
|
1035 |
+
return_dict=return_dict,
|
1036 |
+
callback=callback,
|
1037 |
+
is_cancelled_callback=is_cancelled_callback,
|
1038 |
+
callback_steps=callback_steps,
|
1039 |
+
)
|
1040 |
+
|
1041 |
+
def img2img(
|
1042 |
+
self,
|
1043 |
+
image: Union[torch.FloatTensor, PIL.Image.Image],
|
1044 |
+
prompt: Union[str, List[str]],
|
1045 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
1046 |
+
strength: float = 0.8,
|
1047 |
+
num_inference_steps: Optional[int] = 50,
|
1048 |
+
guidance_scale: Optional[float] = 7.5,
|
1049 |
+
num_images_per_prompt: Optional[int] = 1,
|
1050 |
+
eta: Optional[float] = 0.0,
|
1051 |
+
generator: Optional[torch.Generator] = None,
|
1052 |
+
max_embeddings_multiples: Optional[int] = 3,
|
1053 |
+
output_type: Optional[str] = "pil",
|
1054 |
+
return_dict: bool = True,
|
1055 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
1056 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
1057 |
+
callback_steps: int = 1,
|
1058 |
+
):
|
1059 |
+
r"""
|
1060 |
+
Function for image-to-image generation.
|
1061 |
+
Args:
|
1062 |
+
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
1063 |
+
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
1064 |
+
process.
|
1065 |
+
prompt (`str` or `List[str]`):
|
1066 |
+
The prompt or prompts to guide the image generation.
|
1067 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
1068 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
1069 |
+
if `guidance_scale` is less than `1`).
|
1070 |
+
strength (`float`, *optional*, defaults to 0.8):
|
1071 |
+
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
|
1072 |
+
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The
|
1073 |
+
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
|
1074 |
+
noise will be maximum and the denoising process will run for the full number of iterations specified in
|
1075 |
+
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
|
1076 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
1077 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
1078 |
+
expense of slower inference. This parameter will be modulated by `strength`.
|
1079 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
1080 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
1081 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
1082 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
1083 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
1084 |
+
usually at the expense of lower image quality.
|
1085 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
1086 |
+
The number of images to generate per prompt.
|
1087 |
+
eta (`float`, *optional*, defaults to 0.0):
|
1088 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
1089 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
1090 |
+
generator (`torch.Generator`, *optional*):
|
1091 |
+
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
1092 |
+
deterministic.
|
1093 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
1094 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
1095 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
1096 |
+
The output format of the generate image. Choose between
|
1097 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
1098 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1099 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
1100 |
+
plain tuple.
|
1101 |
+
callback (`Callable`, *optional*):
|
1102 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
1103 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
1104 |
+
is_cancelled_callback (`Callable`, *optional*):
|
1105 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
1106 |
+
`True`, the inference will be cancelled.
|
1107 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
1108 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
1109 |
+
called at every step.
|
1110 |
+
Returns:
|
1111 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
1112 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
1113 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
1114 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
1115 |
+
(nsfw) content, according to the `safety_checker`.
|
1116 |
+
"""
|
1117 |
+
return self.__call__(
|
1118 |
+
prompt=prompt,
|
1119 |
+
negative_prompt=negative_prompt,
|
1120 |
+
image=image,
|
1121 |
+
num_inference_steps=num_inference_steps,
|
1122 |
+
guidance_scale=guidance_scale,
|
1123 |
+
strength=strength,
|
1124 |
+
num_images_per_prompt=num_images_per_prompt,
|
1125 |
+
eta=eta,
|
1126 |
+
generator=generator,
|
1127 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
1128 |
+
output_type=output_type,
|
1129 |
+
return_dict=return_dict,
|
1130 |
+
callback=callback,
|
1131 |
+
is_cancelled_callback=is_cancelled_callback,
|
1132 |
+
callback_steps=callback_steps,
|
1133 |
+
)
|
1134 |
+
|
1135 |
+
def inpaint(
|
1136 |
+
self,
|
1137 |
+
image: Union[torch.FloatTensor, PIL.Image.Image],
|
1138 |
+
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
|
1139 |
+
prompt: Union[str, List[str]],
|
1140 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
1141 |
+
strength: float = 0.8,
|
1142 |
+
num_inference_steps: Optional[int] = 50,
|
1143 |
+
guidance_scale: Optional[float] = 7.5,
|
1144 |
+
num_images_per_prompt: Optional[int] = 1,
|
1145 |
+
eta: Optional[float] = 0.0,
|
1146 |
+
generator: Optional[torch.Generator] = None,
|
1147 |
+
max_embeddings_multiples: Optional[int] = 3,
|
1148 |
+
output_type: Optional[str] = "pil",
|
1149 |
+
return_dict: bool = True,
|
1150 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
1151 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
1152 |
+
callback_steps: int = 1,
|
1153 |
+
):
|
1154 |
+
r"""
|
1155 |
+
Function for inpaint.
|
1156 |
+
Args:
|
1157 |
+
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
1158 |
+
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
1159 |
+
process. This is the image whose masked region will be inpainted.
|
1160 |
+
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
|
1161 |
+
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
|
1162 |
+
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
|
1163 |
+
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
|
1164 |
+
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
|
1165 |
+
prompt (`str` or `List[str]`):
|
1166 |
+
The prompt or prompts to guide the image generation.
|
1167 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
1168 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
1169 |
+
if `guidance_scale` is less than `1`).
|
1170 |
+
strength (`float`, *optional*, defaults to 0.8):
|
1171 |
+
Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
|
1172 |
+
is 1, the denoising process will be run on the masked area for the full number of iterations specified
|
1173 |
+
in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
|
1174 |
+
noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
|
1175 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
1176 |
+
The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
|
1177 |
+
the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
|
1178 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
1179 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
1180 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
1181 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
1182 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
1183 |
+
usually at the expense of lower image quality.
|
1184 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
1185 |
+
The number of images to generate per prompt.
|
1186 |
+
eta (`float`, *optional*, defaults to 0.0):
|
1187 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
1188 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
1189 |
+
generator (`torch.Generator`, *optional*):
|
1190 |
+
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
1191 |
+
deterministic.
|
1192 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
1193 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
1194 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
1195 |
+
The output format of the generate image. Choose between
|
1196 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
1197 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1198 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
1199 |
+
plain tuple.
|
1200 |
+
callback (`Callable`, *optional*):
|
1201 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
1202 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
1203 |
+
is_cancelled_callback (`Callable`, *optional*):
|
1204 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
1205 |
+
`True`, the inference will be cancelled.
|
1206 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
1207 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
1208 |
+
called at every step.
|
1209 |
+
Returns:
|
1210 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
1211 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
1212 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
1213 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
1214 |
+
(nsfw) content, according to the `safety_checker`.
|
1215 |
+
"""
|
1216 |
+
return self.__call__(
|
1217 |
+
prompt=prompt,
|
1218 |
+
negative_prompt=negative_prompt,
|
1219 |
+
image=image,
|
1220 |
+
mask_image=mask_image,
|
1221 |
+
num_inference_steps=num_inference_steps,
|
1222 |
+
guidance_scale=guidance_scale,
|
1223 |
+
strength=strength,
|
1224 |
+
num_images_per_prompt=num_images_per_prompt,
|
1225 |
+
eta=eta,
|
1226 |
+
generator=generator,
|
1227 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
1228 |
+
output_type=output_type,
|
1229 |
+
return_dict=return_dict,
|
1230 |
+
callback=callback,
|
1231 |
+
is_cancelled_callback=is_cancelled_callback,
|
1232 |
+
callback_steps=callback_steps,
|
1233 |
+
)
|
library/model_util.py
ADDED
@@ -0,0 +1,1356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# v1: split from train_db_fixed.py.
|
2 |
+
# v2: support safetensors
|
3 |
+
|
4 |
+
import math
|
5 |
+
import os
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from library.device_utils import init_ipex
|
9 |
+
init_ipex()
|
10 |
+
|
11 |
+
import diffusers
|
12 |
+
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
|
13 |
+
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
|
14 |
+
from safetensors.torch import load_file, save_file
|
15 |
+
from library.original_unet import UNet2DConditionModel
|
16 |
+
from library.utils import setup_logging
|
17 |
+
setup_logging()
|
18 |
+
import logging
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
# DiffUsers版StableDiffusionのモデルパラメータ
|
22 |
+
NUM_TRAIN_TIMESTEPS = 1000
|
23 |
+
BETA_START = 0.00085
|
24 |
+
BETA_END = 0.0120
|
25 |
+
|
26 |
+
UNET_PARAMS_MODEL_CHANNELS = 320
|
27 |
+
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
|
28 |
+
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
|
29 |
+
UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32`
|
30 |
+
UNET_PARAMS_IN_CHANNELS = 4
|
31 |
+
UNET_PARAMS_OUT_CHANNELS = 4
|
32 |
+
UNET_PARAMS_NUM_RES_BLOCKS = 2
|
33 |
+
UNET_PARAMS_CONTEXT_DIM = 768
|
34 |
+
UNET_PARAMS_NUM_HEADS = 8
|
35 |
+
# UNET_PARAMS_USE_LINEAR_PROJECTION = False
|
36 |
+
|
37 |
+
VAE_PARAMS_Z_CHANNELS = 4
|
38 |
+
VAE_PARAMS_RESOLUTION = 256
|
39 |
+
VAE_PARAMS_IN_CHANNELS = 3
|
40 |
+
VAE_PARAMS_OUT_CH = 3
|
41 |
+
VAE_PARAMS_CH = 128
|
42 |
+
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
|
43 |
+
VAE_PARAMS_NUM_RES_BLOCKS = 2
|
44 |
+
|
45 |
+
# V2
|
46 |
+
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
|
47 |
+
V2_UNET_PARAMS_CONTEXT_DIM = 1024
|
48 |
+
# V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True
|
49 |
+
|
50 |
+
# Diffusersの設定を読み込むための参照モデル
|
51 |
+
DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
|
52 |
+
DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
|
53 |
+
|
54 |
+
|
55 |
+
# region StableDiffusion->Diffusersの変換コード
|
56 |
+
# convert_original_stable_diffusion_to_diffusers をコピーして修正している(ASL 2.0)
|
57 |
+
|
58 |
+
|
59 |
+
def shave_segments(path, n_shave_prefix_segments=1):
|
60 |
+
"""
|
61 |
+
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
62 |
+
"""
|
63 |
+
if n_shave_prefix_segments >= 0:
|
64 |
+
return ".".join(path.split(".")[n_shave_prefix_segments:])
|
65 |
+
else:
|
66 |
+
return ".".join(path.split(".")[:n_shave_prefix_segments])
|
67 |
+
|
68 |
+
|
69 |
+
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
70 |
+
"""
|
71 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
72 |
+
"""
|
73 |
+
mapping = []
|
74 |
+
for old_item in old_list:
|
75 |
+
new_item = old_item.replace("in_layers.0", "norm1")
|
76 |
+
new_item = new_item.replace("in_layers.2", "conv1")
|
77 |
+
|
78 |
+
new_item = new_item.replace("out_layers.0", "norm2")
|
79 |
+
new_item = new_item.replace("out_layers.3", "conv2")
|
80 |
+
|
81 |
+
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
82 |
+
new_item = new_item.replace("skip_connection", "conv_shortcut")
|
83 |
+
|
84 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
85 |
+
|
86 |
+
mapping.append({"old": old_item, "new": new_item})
|
87 |
+
|
88 |
+
return mapping
|
89 |
+
|
90 |
+
|
91 |
+
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
92 |
+
"""
|
93 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
94 |
+
"""
|
95 |
+
mapping = []
|
96 |
+
for old_item in old_list:
|
97 |
+
new_item = old_item
|
98 |
+
|
99 |
+
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
|
100 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
101 |
+
|
102 |
+
mapping.append({"old": old_item, "new": new_item})
|
103 |
+
|
104 |
+
return mapping
|
105 |
+
|
106 |
+
|
107 |
+
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
108 |
+
"""
|
109 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
110 |
+
"""
|
111 |
+
mapping = []
|
112 |
+
for old_item in old_list:
|
113 |
+
new_item = old_item
|
114 |
+
|
115 |
+
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
116 |
+
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
117 |
+
|
118 |
+
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
119 |
+
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
120 |
+
|
121 |
+
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
122 |
+
|
123 |
+
mapping.append({"old": old_item, "new": new_item})
|
124 |
+
|
125 |
+
return mapping
|
126 |
+
|
127 |
+
|
128 |
+
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
129 |
+
"""
|
130 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
131 |
+
"""
|
132 |
+
mapping = []
|
133 |
+
for old_item in old_list:
|
134 |
+
new_item = old_item
|
135 |
+
|
136 |
+
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
137 |
+
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
138 |
+
|
139 |
+
if diffusers.__version__ < "0.17.0":
|
140 |
+
new_item = new_item.replace("q.weight", "query.weight")
|
141 |
+
new_item = new_item.replace("q.bias", "query.bias")
|
142 |
+
|
143 |
+
new_item = new_item.replace("k.weight", "key.weight")
|
144 |
+
new_item = new_item.replace("k.bias", "key.bias")
|
145 |
+
|
146 |
+
new_item = new_item.replace("v.weight", "value.weight")
|
147 |
+
new_item = new_item.replace("v.bias", "value.bias")
|
148 |
+
|
149 |
+
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
150 |
+
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
151 |
+
else:
|
152 |
+
new_item = new_item.replace("q.weight", "to_q.weight")
|
153 |
+
new_item = new_item.replace("q.bias", "to_q.bias")
|
154 |
+
|
155 |
+
new_item = new_item.replace("k.weight", "to_k.weight")
|
156 |
+
new_item = new_item.replace("k.bias", "to_k.bias")
|
157 |
+
|
158 |
+
new_item = new_item.replace("v.weight", "to_v.weight")
|
159 |
+
new_item = new_item.replace("v.bias", "to_v.bias")
|
160 |
+
|
161 |
+
new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
|
162 |
+
new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
|
163 |
+
|
164 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
165 |
+
|
166 |
+
mapping.append({"old": old_item, "new": new_item})
|
167 |
+
|
168 |
+
return mapping
|
169 |
+
|
170 |
+
|
171 |
+
def assign_to_checkpoint(
|
172 |
+
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
173 |
+
):
|
174 |
+
"""
|
175 |
+
This does the final conversion step: take locally converted weights and apply a global renaming
|
176 |
+
to them. It splits attention layers, and takes into account additional replacements
|
177 |
+
that may arise.
|
178 |
+
|
179 |
+
Assigns the weights to the new checkpoint.
|
180 |
+
"""
|
181 |
+
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
182 |
+
|
183 |
+
# Splits the attention layers into three variables.
|
184 |
+
if attention_paths_to_split is not None:
|
185 |
+
for path, path_map in attention_paths_to_split.items():
|
186 |
+
old_tensor = old_checkpoint[path]
|
187 |
+
channels = old_tensor.shape[0] // 3
|
188 |
+
|
189 |
+
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
190 |
+
|
191 |
+
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
192 |
+
|
193 |
+
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
194 |
+
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
195 |
+
|
196 |
+
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
197 |
+
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
198 |
+
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
199 |
+
|
200 |
+
for path in paths:
|
201 |
+
new_path = path["new"]
|
202 |
+
|
203 |
+
# These have already been assigned
|
204 |
+
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
205 |
+
continue
|
206 |
+
|
207 |
+
# Global renaming happens here
|
208 |
+
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
209 |
+
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
210 |
+
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
|
211 |
+
|
212 |
+
if additional_replacements is not None:
|
213 |
+
for replacement in additional_replacements:
|
214 |
+
new_path = new_path.replace(replacement["old"], replacement["new"])
|
215 |
+
|
216 |
+
# proj_attn.weight has to be converted from conv 1D to linear
|
217 |
+
reshaping = False
|
218 |
+
if diffusers.__version__ < "0.17.0":
|
219 |
+
if "proj_attn.weight" in new_path:
|
220 |
+
reshaping = True
|
221 |
+
else:
|
222 |
+
if ".attentions." in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2:
|
223 |
+
reshaping = True
|
224 |
+
|
225 |
+
if reshaping:
|
226 |
+
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
|
227 |
+
else:
|
228 |
+
checkpoint[new_path] = old_checkpoint[path["old"]]
|
229 |
+
|
230 |
+
|
231 |
+
def conv_attn_to_linear(checkpoint):
|
232 |
+
keys = list(checkpoint.keys())
|
233 |
+
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
234 |
+
for key in keys:
|
235 |
+
if ".".join(key.split(".")[-2:]) in attn_keys:
|
236 |
+
if checkpoint[key].ndim > 2:
|
237 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
238 |
+
elif "proj_attn.weight" in key:
|
239 |
+
if checkpoint[key].ndim > 2:
|
240 |
+
checkpoint[key] = checkpoint[key][:, :, 0]
|
241 |
+
|
242 |
+
|
243 |
+
def linear_transformer_to_conv(checkpoint):
|
244 |
+
keys = list(checkpoint.keys())
|
245 |
+
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
246 |
+
for key in keys:
|
247 |
+
if ".".join(key.split(".")[-2:]) in tf_keys:
|
248 |
+
if checkpoint[key].ndim == 2:
|
249 |
+
checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
|
250 |
+
|
251 |
+
|
252 |
+
def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
253 |
+
"""
|
254 |
+
Takes a state dict and a config, and returns a converted checkpoint.
|
255 |
+
"""
|
256 |
+
|
257 |
+
# extract state_dict for UNet
|
258 |
+
unet_state_dict = {}
|
259 |
+
unet_key = "model.diffusion_model."
|
260 |
+
keys = list(checkpoint.keys())
|
261 |
+
for key in keys:
|
262 |
+
if key.startswith(unet_key):
|
263 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
264 |
+
|
265 |
+
new_checkpoint = {}
|
266 |
+
|
267 |
+
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
|
268 |
+
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
|
269 |
+
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
|
270 |
+
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
|
271 |
+
|
272 |
+
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
273 |
+
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
274 |
+
|
275 |
+
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
|
276 |
+
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
|
277 |
+
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
|
278 |
+
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
|
279 |
+
|
280 |
+
# Retrieves the keys for the input blocks only
|
281 |
+
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
282 |
+
input_blocks = {
|
283 |
+
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in range(num_input_blocks)
|
284 |
+
}
|
285 |
+
|
286 |
+
# Retrieves the keys for the middle blocks only
|
287 |
+
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
288 |
+
middle_blocks = {
|
289 |
+
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in range(num_middle_blocks)
|
290 |
+
}
|
291 |
+
|
292 |
+
# Retrieves the keys for the output blocks only
|
293 |
+
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
294 |
+
output_blocks = {
|
295 |
+
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in range(num_output_blocks)
|
296 |
+
}
|
297 |
+
|
298 |
+
for i in range(1, num_input_blocks):
|
299 |
+
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
300 |
+
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
301 |
+
|
302 |
+
resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key]
|
303 |
+
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
304 |
+
|
305 |
+
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
306 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
307 |
+
f"input_blocks.{i}.0.op.weight"
|
308 |
+
)
|
309 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
|
310 |
+
|
311 |
+
paths = renew_resnet_paths(resnets)
|
312 |
+
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
313 |
+
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
314 |
+
|
315 |
+
if len(attentions):
|
316 |
+
paths = renew_attention_paths(attentions)
|
317 |
+
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
318 |
+
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
319 |
+
|
320 |
+
resnet_0 = middle_blocks[0]
|
321 |
+
attentions = middle_blocks[1]
|
322 |
+
resnet_1 = middle_blocks[2]
|
323 |
+
|
324 |
+
resnet_0_paths = renew_resnet_paths(resnet_0)
|
325 |
+
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
|
326 |
+
|
327 |
+
resnet_1_paths = renew_resnet_paths(resnet_1)
|
328 |
+
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
|
329 |
+
|
330 |
+
attentions_paths = renew_attention_paths(attentions)
|
331 |
+
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
332 |
+
assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
333 |
+
|
334 |
+
for i in range(num_output_blocks):
|
335 |
+
block_id = i // (config["layers_per_block"] + 1)
|
336 |
+
layer_in_block_id = i % (config["layers_per_block"] + 1)
|
337 |
+
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
338 |
+
output_block_list = {}
|
339 |
+
|
340 |
+
for layer in output_block_layers:
|
341 |
+
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
342 |
+
if layer_id in output_block_list:
|
343 |
+
output_block_list[layer_id].append(layer_name)
|
344 |
+
else:
|
345 |
+
output_block_list[layer_id] = [layer_name]
|
346 |
+
|
347 |
+
if len(output_block_list) > 1:
|
348 |
+
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
349 |
+
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
350 |
+
|
351 |
+
resnet_0_paths = renew_resnet_paths(resnets)
|
352 |
+
paths = renew_resnet_paths(resnets)
|
353 |
+
|
354 |
+
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
355 |
+
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
356 |
+
|
357 |
+
# オリジナル:
|
358 |
+
# if ["conv.weight", "conv.bias"] in output_block_list.values():
|
359 |
+
# index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
|
360 |
+
|
361 |
+
# biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが
|
362 |
+
for l in output_block_list.values():
|
363 |
+
l.sort()
|
364 |
+
|
365 |
+
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
366 |
+
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
367 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
368 |
+
f"output_blocks.{i}.{index}.conv.bias"
|
369 |
+
]
|
370 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
371 |
+
f"output_blocks.{i}.{index}.conv.weight"
|
372 |
+
]
|
373 |
+
|
374 |
+
# Clear attentions as they have been attributed above.
|
375 |
+
if len(attentions) == 2:
|
376 |
+
attentions = []
|
377 |
+
|
378 |
+
if len(attentions):
|
379 |
+
paths = renew_attention_paths(attentions)
|
380 |
+
meta_path = {
|
381 |
+
"old": f"output_blocks.{i}.1",
|
382 |
+
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
383 |
+
}
|
384 |
+
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
385 |
+
else:
|
386 |
+
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
387 |
+
for path in resnet_0_paths:
|
388 |
+
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
389 |
+
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
390 |
+
|
391 |
+
new_checkpoint[new_path] = unet_state_dict[old_path]
|
392 |
+
|
393 |
+
# SDのv2では1*1のconv2dがlinearに変わっている
|
394 |
+
# 誤って Diffusers 側を conv2d のままにしてしまったので、変換必要
|
395 |
+
if v2 and not config.get("use_linear_projection", False):
|
396 |
+
linear_transformer_to_conv(new_checkpoint)
|
397 |
+
|
398 |
+
return new_checkpoint
|
399 |
+
|
400 |
+
|
401 |
+
def convert_ldm_vae_checkpoint(checkpoint, config):
|
402 |
+
# extract state dict for VAE
|
403 |
+
vae_state_dict = {}
|
404 |
+
vae_key = "first_stage_model."
|
405 |
+
keys = list(checkpoint.keys())
|
406 |
+
for key in keys:
|
407 |
+
if key.startswith(vae_key):
|
408 |
+
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
409 |
+
# if len(vae_state_dict) == 0:
|
410 |
+
# # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict
|
411 |
+
# vae_state_dict = checkpoint
|
412 |
+
|
413 |
+
new_checkpoint = {}
|
414 |
+
|
415 |
+
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
416 |
+
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
417 |
+
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
418 |
+
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
|
419 |
+
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
|
420 |
+
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
|
421 |
+
|
422 |
+
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
|
423 |
+
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
|
424 |
+
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
|
425 |
+
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
|
426 |
+
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
|
427 |
+
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
|
428 |
+
|
429 |
+
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
430 |
+
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
431 |
+
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
|
432 |
+
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
|
433 |
+
|
434 |
+
# Retrieves the keys for the encoder down blocks only
|
435 |
+
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
436 |
+
down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)}
|
437 |
+
|
438 |
+
# Retrieves the keys for the decoder up blocks only
|
439 |
+
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
440 |
+
up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
|
441 |
+
|
442 |
+
for i in range(num_down_blocks):
|
443 |
+
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
444 |
+
|
445 |
+
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
446 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
447 |
+
f"encoder.down.{i}.downsample.conv.weight"
|
448 |
+
)
|
449 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
|
450 |
+
f"encoder.down.{i}.downsample.conv.bias"
|
451 |
+
)
|
452 |
+
|
453 |
+
paths = renew_vae_resnet_paths(resnets)
|
454 |
+
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
455 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
456 |
+
|
457 |
+
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
458 |
+
num_mid_res_blocks = 2
|
459 |
+
for i in range(1, num_mid_res_blocks + 1):
|
460 |
+
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
461 |
+
|
462 |
+
paths = renew_vae_resnet_paths(resnets)
|
463 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
464 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
465 |
+
|
466 |
+
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
467 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
468 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
469 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
470 |
+
conv_attn_to_linear(new_checkpoint)
|
471 |
+
|
472 |
+
for i in range(num_up_blocks):
|
473 |
+
block_id = num_up_blocks - 1 - i
|
474 |
+
resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key]
|
475 |
+
|
476 |
+
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
477 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
478 |
+
f"decoder.up.{block_id}.upsample.conv.weight"
|
479 |
+
]
|
480 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
481 |
+
f"decoder.up.{block_id}.upsample.conv.bias"
|
482 |
+
]
|
483 |
+
|
484 |
+
paths = renew_vae_resnet_paths(resnets)
|
485 |
+
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
486 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
487 |
+
|
488 |
+
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
489 |
+
num_mid_res_blocks = 2
|
490 |
+
for i in range(1, num_mid_res_blocks + 1):
|
491 |
+
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
492 |
+
|
493 |
+
paths = renew_vae_resnet_paths(resnets)
|
494 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
495 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
496 |
+
|
497 |
+
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
498 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
499 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
500 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
501 |
+
conv_attn_to_linear(new_checkpoint)
|
502 |
+
return new_checkpoint
|
503 |
+
|
504 |
+
|
505 |
+
def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False):
|
506 |
+
"""
|
507 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
508 |
+
"""
|
509 |
+
# unet_params = original_config.model.params.unet_config.params
|
510 |
+
|
511 |
+
block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
|
512 |
+
|
513 |
+
down_block_types = []
|
514 |
+
resolution = 1
|
515 |
+
for i in range(len(block_out_channels)):
|
516 |
+
block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
|
517 |
+
down_block_types.append(block_type)
|
518 |
+
if i != len(block_out_channels) - 1:
|
519 |
+
resolution *= 2
|
520 |
+
|
521 |
+
up_block_types = []
|
522 |
+
for i in range(len(block_out_channels)):
|
523 |
+
block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
|
524 |
+
up_block_types.append(block_type)
|
525 |
+
resolution //= 2
|
526 |
+
|
527 |
+
config = dict(
|
528 |
+
sample_size=UNET_PARAMS_IMAGE_SIZE,
|
529 |
+
in_channels=UNET_PARAMS_IN_CHANNELS,
|
530 |
+
out_channels=UNET_PARAMS_OUT_CHANNELS,
|
531 |
+
down_block_types=tuple(down_block_types),
|
532 |
+
up_block_types=tuple(up_block_types),
|
533 |
+
block_out_channels=tuple(block_out_channels),
|
534 |
+
layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
|
535 |
+
cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
|
536 |
+
attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
|
537 |
+
# use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION,
|
538 |
+
)
|
539 |
+
if v2 and use_linear_projection_in_v2:
|
540 |
+
config["use_linear_projection"] = True
|
541 |
+
|
542 |
+
return config
|
543 |
+
|
544 |
+
|
545 |
+
def create_vae_diffusers_config():
|
546 |
+
"""
|
547 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
548 |
+
"""
|
549 |
+
# vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
550 |
+
# _ = original_config.model.params.first_stage_config.params.embed_dim
|
551 |
+
block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
|
552 |
+
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
553 |
+
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
554 |
+
|
555 |
+
config = dict(
|
556 |
+
sample_size=VAE_PARAMS_RESOLUTION,
|
557 |
+
in_channels=VAE_PARAMS_IN_CHANNELS,
|
558 |
+
out_channels=VAE_PARAMS_OUT_CH,
|
559 |
+
down_block_types=tuple(down_block_types),
|
560 |
+
up_block_types=tuple(up_block_types),
|
561 |
+
block_out_channels=tuple(block_out_channels),
|
562 |
+
latent_channels=VAE_PARAMS_Z_CHANNELS,
|
563 |
+
layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
|
564 |
+
)
|
565 |
+
return config
|
566 |
+
|
567 |
+
|
568 |
+
def convert_ldm_clip_checkpoint_v1(checkpoint):
|
569 |
+
keys = list(checkpoint.keys())
|
570 |
+
text_model_dict = {}
|
571 |
+
for key in keys:
|
572 |
+
if key.startswith("cond_stage_model.transformer"):
|
573 |
+
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
574 |
+
|
575 |
+
# remove position_ids for newer transformer, which causes error :(
|
576 |
+
if "text_model.embeddings.position_ids" in text_model_dict:
|
577 |
+
text_model_dict.pop("text_model.embeddings.position_ids")
|
578 |
+
|
579 |
+
return text_model_dict
|
580 |
+
|
581 |
+
|
582 |
+
def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
|
583 |
+
# 嫌になるくらい違うぞ!
|
584 |
+
def convert_key(key):
|
585 |
+
if not key.startswith("cond_stage_model"):
|
586 |
+
return None
|
587 |
+
|
588 |
+
# common conversion
|
589 |
+
key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
|
590 |
+
key = key.replace("cond_stage_model.model.", "text_model.")
|
591 |
+
|
592 |
+
if "resblocks" in key:
|
593 |
+
# resblocks conversion
|
594 |
+
key = key.replace(".resblocks.", ".layers.")
|
595 |
+
if ".ln_" in key:
|
596 |
+
key = key.replace(".ln_", ".layer_norm")
|
597 |
+
elif ".mlp." in key:
|
598 |
+
key = key.replace(".c_fc.", ".fc1.")
|
599 |
+
key = key.replace(".c_proj.", ".fc2.")
|
600 |
+
elif ".attn.out_proj" in key:
|
601 |
+
key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
|
602 |
+
elif ".attn.in_proj" in key:
|
603 |
+
key = None # 特殊なので後で処理する
|
604 |
+
else:
|
605 |
+
raise ValueError(f"unexpected key in SD: {key}")
|
606 |
+
elif ".positional_embedding" in key:
|
607 |
+
key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
|
608 |
+
elif ".text_projection" in key:
|
609 |
+
key = None # 使われない???
|
610 |
+
elif ".logit_scale" in key:
|
611 |
+
key = None # 使われない???
|
612 |
+
elif ".token_embedding" in key:
|
613 |
+
key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
|
614 |
+
elif ".ln_final" in key:
|
615 |
+
key = key.replace(".ln_final", ".final_layer_norm")
|
616 |
+
return key
|
617 |
+
|
618 |
+
keys = list(checkpoint.keys())
|
619 |
+
new_sd = {}
|
620 |
+
for key in keys:
|
621 |
+
# remove resblocks 23
|
622 |
+
if ".resblocks.23." in key:
|
623 |
+
continue
|
624 |
+
new_key = convert_key(key)
|
625 |
+
if new_key is None:
|
626 |
+
continue
|
627 |
+
new_sd[new_key] = checkpoint[key]
|
628 |
+
|
629 |
+
# attnの変換
|
630 |
+
for key in keys:
|
631 |
+
if ".resblocks.23." in key:
|
632 |
+
continue
|
633 |
+
if ".resblocks" in key and ".attn.in_proj_" in key:
|
634 |
+
# 三つに分割
|
635 |
+
values = torch.chunk(checkpoint[key], 3)
|
636 |
+
|
637 |
+
key_suffix = ".weight" if "weight" in key else ".bias"
|
638 |
+
key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
|
639 |
+
key_pfx = key_pfx.replace("_weight", "")
|
640 |
+
key_pfx = key_pfx.replace("_bias", "")
|
641 |
+
key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
|
642 |
+
new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
|
643 |
+
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
|
644 |
+
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
|
645 |
+
|
646 |
+
# rename or add position_ids
|
647 |
+
ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
|
648 |
+
if ANOTHER_POSITION_IDS_KEY in new_sd:
|
649 |
+
# waifu diffusion v1.4
|
650 |
+
position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
|
651 |
+
del new_sd[ANOTHER_POSITION_IDS_KEY]
|
652 |
+
else:
|
653 |
+
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
|
654 |
+
|
655 |
+
new_sd["text_model.embeddings.position_ids"] = position_ids
|
656 |
+
return new_sd
|
657 |
+
|
658 |
+
|
659 |
+
# endregion
|
660 |
+
|
661 |
+
|
662 |
+
# region Diffusers->StableDiffusion の変換コード
|
663 |
+
# convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0)
|
664 |
+
|
665 |
+
|
666 |
+
def conv_transformer_to_linear(checkpoint):
|
667 |
+
keys = list(checkpoint.keys())
|
668 |
+
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
669 |
+
for key in keys:
|
670 |
+
if ".".join(key.split(".")[-2:]) in tf_keys:
|
671 |
+
if checkpoint[key].ndim > 2:
|
672 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
673 |
+
|
674 |
+
|
675 |
+
def convert_unet_state_dict_to_sd(v2, unet_state_dict):
|
676 |
+
unet_conversion_map = [
|
677 |
+
# (stable-diffusion, HF Diffusers)
|
678 |
+
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
679 |
+
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
680 |
+
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
681 |
+
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
682 |
+
("input_blocks.0.0.weight", "conv_in.weight"),
|
683 |
+
("input_blocks.0.0.bias", "conv_in.bias"),
|
684 |
+
("out.0.weight", "conv_norm_out.weight"),
|
685 |
+
("out.0.bias", "conv_norm_out.bias"),
|
686 |
+
("out.2.weight", "conv_out.weight"),
|
687 |
+
("out.2.bias", "conv_out.bias"),
|
688 |
+
]
|
689 |
+
|
690 |
+
unet_conversion_map_resnet = [
|
691 |
+
# (stable-diffusion, HF Diffusers)
|
692 |
+
("in_layers.0", "norm1"),
|
693 |
+
("in_layers.2", "conv1"),
|
694 |
+
("out_layers.0", "norm2"),
|
695 |
+
("out_layers.3", "conv2"),
|
696 |
+
("emb_layers.1", "time_emb_proj"),
|
697 |
+
("skip_connection", "conv_shortcut"),
|
698 |
+
]
|
699 |
+
|
700 |
+
unet_conversion_map_layer = []
|
701 |
+
for i in range(4):
|
702 |
+
# loop over downblocks/upblocks
|
703 |
+
|
704 |
+
for j in range(2):
|
705 |
+
# loop over resnets/attentions for downblocks
|
706 |
+
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
707 |
+
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
708 |
+
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
709 |
+
|
710 |
+
if i < 3:
|
711 |
+
# no attention layers in down_blocks.3
|
712 |
+
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
713 |
+
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
714 |
+
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
715 |
+
|
716 |
+
for j in range(3):
|
717 |
+
# loop over resnets/attentions for upblocks
|
718 |
+
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
719 |
+
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
720 |
+
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
721 |
+
|
722 |
+
if i > 0:
|
723 |
+
# no attention layers in up_blocks.0
|
724 |
+
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
725 |
+
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
726 |
+
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
727 |
+
|
728 |
+
if i < 3:
|
729 |
+
# no downsample in down_blocks.3
|
730 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
731 |
+
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
732 |
+
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
733 |
+
|
734 |
+
# no upsample in up_blocks.3
|
735 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
736 |
+
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
|
737 |
+
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
738 |
+
|
739 |
+
hf_mid_atn_prefix = "mid_block.attentions.0."
|
740 |
+
sd_mid_atn_prefix = "middle_block.1."
|
741 |
+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
742 |
+
|
743 |
+
for j in range(2):
|
744 |
+
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
745 |
+
sd_mid_res_prefix = f"middle_block.{2*j}."
|
746 |
+
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
747 |
+
|
748 |
+
# buyer beware: this is a *brittle* function,
|
749 |
+
# and correct output requires that all of these pieces interact in
|
750 |
+
# the exact order in which I have arranged them.
|
751 |
+
mapping = {k: k for k in unet_state_dict.keys()}
|
752 |
+
for sd_name, hf_name in unet_conversion_map:
|
753 |
+
mapping[hf_name] = sd_name
|
754 |
+
for k, v in mapping.items():
|
755 |
+
if "resnets" in k:
|
756 |
+
for sd_part, hf_part in unet_conversion_map_resnet:
|
757 |
+
v = v.replace(hf_part, sd_part)
|
758 |
+
mapping[k] = v
|
759 |
+
for k, v in mapping.items():
|
760 |
+
for sd_part, hf_part in unet_conversion_map_layer:
|
761 |
+
v = v.replace(hf_part, sd_part)
|
762 |
+
mapping[k] = v
|
763 |
+
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
|
764 |
+
|
765 |
+
if v2:
|
766 |
+
conv_transformer_to_linear(new_state_dict)
|
767 |
+
|
768 |
+
return new_state_dict
|
769 |
+
|
770 |
+
|
771 |
+
def controlnet_conversion_map():
|
772 |
+
unet_conversion_map = [
|
773 |
+
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
774 |
+
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
775 |
+
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
776 |
+
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
777 |
+
("input_blocks.0.0.weight", "conv_in.weight"),
|
778 |
+
("input_blocks.0.0.bias", "conv_in.bias"),
|
779 |
+
("middle_block_out.0.weight", "controlnet_mid_block.weight"),
|
780 |
+
("middle_block_out.0.bias", "controlnet_mid_block.bias"),
|
781 |
+
]
|
782 |
+
|
783 |
+
unet_conversion_map_resnet = [
|
784 |
+
("in_layers.0", "norm1"),
|
785 |
+
("in_layers.2", "conv1"),
|
786 |
+
("out_layers.0", "norm2"),
|
787 |
+
("out_layers.3", "conv2"),
|
788 |
+
("emb_layers.1", "time_emb_proj"),
|
789 |
+
("skip_connection", "conv_shortcut"),
|
790 |
+
]
|
791 |
+
|
792 |
+
unet_conversion_map_layer = []
|
793 |
+
for i in range(4):
|
794 |
+
for j in range(2):
|
795 |
+
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
796 |
+
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
797 |
+
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
798 |
+
|
799 |
+
if i < 3:
|
800 |
+
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
801 |
+
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
802 |
+
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
803 |
+
|
804 |
+
if i < 3:
|
805 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
806 |
+
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
807 |
+
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
808 |
+
|
809 |
+
hf_mid_atn_prefix = "mid_block.attentions.0."
|
810 |
+
sd_mid_atn_prefix = "middle_block.1."
|
811 |
+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
812 |
+
|
813 |
+
for j in range(2):
|
814 |
+
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
815 |
+
sd_mid_res_prefix = f"middle_block.{2*j}."
|
816 |
+
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
817 |
+
|
818 |
+
controlnet_cond_embedding_names = ["conv_in"] + [f"blocks.{i}" for i in range(6)] + ["conv_out"]
|
819 |
+
for i, hf_prefix in enumerate(controlnet_cond_embedding_names):
|
820 |
+
hf_prefix = f"controlnet_cond_embedding.{hf_prefix}."
|
821 |
+
sd_prefix = f"input_hint_block.{i*2}."
|
822 |
+
unet_conversion_map_layer.append((sd_prefix, hf_prefix))
|
823 |
+
|
824 |
+
for i in range(12):
|
825 |
+
hf_prefix = f"controlnet_down_blocks.{i}."
|
826 |
+
sd_prefix = f"zero_convs.{i}.0."
|
827 |
+
unet_conversion_map_layer.append((sd_prefix, hf_prefix))
|
828 |
+
|
829 |
+
return unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer
|
830 |
+
|
831 |
+
|
832 |
+
def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
|
833 |
+
unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
|
834 |
+
|
835 |
+
mapping = {k: k for k in controlnet_state_dict.keys()}
|
836 |
+
for sd_name, diffusers_name in unet_conversion_map:
|
837 |
+
mapping[diffusers_name] = sd_name
|
838 |
+
for k, v in mapping.items():
|
839 |
+
if "resnets" in k:
|
840 |
+
for sd_part, diffusers_part in unet_conversion_map_resnet:
|
841 |
+
v = v.replace(diffusers_part, sd_part)
|
842 |
+
mapping[k] = v
|
843 |
+
for k, v in mapping.items():
|
844 |
+
for sd_part, diffusers_part in unet_conversion_map_layer:
|
845 |
+
v = v.replace(diffusers_part, sd_part)
|
846 |
+
mapping[k] = v
|
847 |
+
new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
|
848 |
+
return new_state_dict
|
849 |
+
|
850 |
+
|
851 |
+
def convert_controlnet_state_dict_to_diffusers(controlnet_state_dict):
|
852 |
+
unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
|
853 |
+
|
854 |
+
mapping = {k: k for k in controlnet_state_dict.keys()}
|
855 |
+
for sd_name, diffusers_name in unet_conversion_map:
|
856 |
+
mapping[sd_name] = diffusers_name
|
857 |
+
for k, v in mapping.items():
|
858 |
+
for sd_part, diffusers_part in unet_conversion_map_layer:
|
859 |
+
v = v.replace(sd_part, diffusers_part)
|
860 |
+
mapping[k] = v
|
861 |
+
for k, v in mapping.items():
|
862 |
+
if "resnets" in v:
|
863 |
+
for sd_part, diffusers_part in unet_conversion_map_resnet:
|
864 |
+
v = v.replace(sd_part, diffusers_part)
|
865 |
+
mapping[k] = v
|
866 |
+
new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
|
867 |
+
return new_state_dict
|
868 |
+
|
869 |
+
|
870 |
+
# ================#
|
871 |
+
# VAE Conversion #
|
872 |
+
# ================#
|
873 |
+
|
874 |
+
|
875 |
+
def reshape_weight_for_sd(w):
|
876 |
+
# convert HF linear weights to SD conv2d weights
|
877 |
+
return w.reshape(*w.shape, 1, 1)
|
878 |
+
|
879 |
+
|
880 |
+
def convert_vae_state_dict(vae_state_dict):
|
881 |
+
vae_conversion_map = [
|
882 |
+
# (stable-diffusion, HF Diffusers)
|
883 |
+
("nin_shortcut", "conv_shortcut"),
|
884 |
+
("norm_out", "conv_norm_out"),
|
885 |
+
("mid.attn_1.", "mid_block.attentions.0."),
|
886 |
+
]
|
887 |
+
|
888 |
+
for i in range(4):
|
889 |
+
# down_blocks have two resnets
|
890 |
+
for j in range(2):
|
891 |
+
hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
|
892 |
+
sd_down_prefix = f"encoder.down.{i}.block.{j}."
|
893 |
+
vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
|
894 |
+
|
895 |
+
if i < 3:
|
896 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
|
897 |
+
sd_downsample_prefix = f"down.{i}.downsample."
|
898 |
+
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
|
899 |
+
|
900 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
901 |
+
sd_upsample_prefix = f"up.{3-i}.upsample."
|
902 |
+
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
903 |
+
|
904 |
+
# up_blocks have three resnets
|
905 |
+
# also, up blocks in hf are numbered in reverse from sd
|
906 |
+
for j in range(3):
|
907 |
+
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
|
908 |
+
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
|
909 |
+
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
910 |
+
|
911 |
+
# this part accounts for mid blocks in both the encoder and the decoder
|
912 |
+
for i in range(2):
|
913 |
+
hf_mid_res_prefix = f"mid_block.resnets.{i}."
|
914 |
+
sd_mid_res_prefix = f"mid.block_{i+1}."
|
915 |
+
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
916 |
+
|
917 |
+
if diffusers.__version__ < "0.17.0":
|
918 |
+
vae_conversion_map_attn = [
|
919 |
+
# (stable-diffusion, HF Diffusers)
|
920 |
+
("norm.", "group_norm."),
|
921 |
+
("q.", "query."),
|
922 |
+
("k.", "key."),
|
923 |
+
("v.", "value."),
|
924 |
+
("proj_out.", "proj_attn."),
|
925 |
+
]
|
926 |
+
else:
|
927 |
+
vae_conversion_map_attn = [
|
928 |
+
# (stable-diffusion, HF Diffusers)
|
929 |
+
("norm.", "group_norm."),
|
930 |
+
("q.", "to_q."),
|
931 |
+
("k.", "to_k."),
|
932 |
+
("v.", "to_v."),
|
933 |
+
("proj_out.", "to_out.0."),
|
934 |
+
]
|
935 |
+
|
936 |
+
mapping = {k: k for k in vae_state_dict.keys()}
|
937 |
+
for k, v in mapping.items():
|
938 |
+
for sd_part, hf_part in vae_conversion_map:
|
939 |
+
v = v.replace(hf_part, sd_part)
|
940 |
+
mapping[k] = v
|
941 |
+
for k, v in mapping.items():
|
942 |
+
if "attentions" in k:
|
943 |
+
for sd_part, hf_part in vae_conversion_map_attn:
|
944 |
+
v = v.replace(hf_part, sd_part)
|
945 |
+
mapping[k] = v
|
946 |
+
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
|
947 |
+
weights_to_convert = ["q", "k", "v", "proj_out"]
|
948 |
+
for k, v in new_state_dict.items():
|
949 |
+
for weight_name in weights_to_convert:
|
950 |
+
if f"mid.attn_1.{weight_name}.weight" in k:
|
951 |
+
# logger.info(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
|
952 |
+
new_state_dict[k] = reshape_weight_for_sd(v)
|
953 |
+
|
954 |
+
return new_state_dict
|
955 |
+
|
956 |
+
|
957 |
+
# endregion
|
958 |
+
|
959 |
+
# region 自作のモデル読み書きなど
|
960 |
+
|
961 |
+
|
962 |
+
def is_safetensors(path):
|
963 |
+
return os.path.splitext(path)[1].lower() == ".safetensors"
|
964 |
+
|
965 |
+
|
966 |
+
def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
|
967 |
+
# text encoderの格納形式が違うモデルに対応する ('text_model'がない)
|
968 |
+
TEXT_ENCODER_KEY_REPLACEMENTS = [
|
969 |
+
("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."),
|
970 |
+
("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."),
|
971 |
+
("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."),
|
972 |
+
]
|
973 |
+
|
974 |
+
if is_safetensors(ckpt_path):
|
975 |
+
checkpoint = None
|
976 |
+
state_dict = load_file(ckpt_path) # , device) # may causes error
|
977 |
+
else:
|
978 |
+
checkpoint = torch.load(ckpt_path, map_location=device)
|
979 |
+
if "state_dict" in checkpoint:
|
980 |
+
state_dict = checkpoint["state_dict"]
|
981 |
+
else:
|
982 |
+
state_dict = checkpoint
|
983 |
+
checkpoint = None
|
984 |
+
|
985 |
+
key_reps = []
|
986 |
+
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
|
987 |
+
for key in state_dict.keys():
|
988 |
+
if key.startswith(rep_from):
|
989 |
+
new_key = rep_to + key[len(rep_from) :]
|
990 |
+
key_reps.append((key, new_key))
|
991 |
+
|
992 |
+
for key, new_key in key_reps:
|
993 |
+
state_dict[new_key] = state_dict[key]
|
994 |
+
del state_dict[key]
|
995 |
+
|
996 |
+
return checkpoint, state_dict
|
997 |
+
|
998 |
+
|
999 |
+
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
|
1000 |
+
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=True):
|
1001 |
+
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
|
1002 |
+
|
1003 |
+
# Convert the UNet2DConditionModel model.
|
1004 |
+
unet_config = create_unet_diffusers_config(v2, unet_use_linear_projection_in_v2)
|
1005 |
+
converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
|
1006 |
+
|
1007 |
+
unet = UNet2DConditionModel(**unet_config).to(device)
|
1008 |
+
info = unet.load_state_dict(converted_unet_checkpoint)
|
1009 |
+
logger.info(f"loading u-net: {info}")
|
1010 |
+
|
1011 |
+
# Convert the VAE model.
|
1012 |
+
vae_config = create_vae_diffusers_config()
|
1013 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
|
1014 |
+
|
1015 |
+
vae = AutoencoderKL(**vae_config).to(device)
|
1016 |
+
info = vae.load_state_dict(converted_vae_checkpoint)
|
1017 |
+
logger.info(f"loading vae: {info}")
|
1018 |
+
|
1019 |
+
# convert text_model
|
1020 |
+
if v2:
|
1021 |
+
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
|
1022 |
+
cfg = CLIPTextConfig(
|
1023 |
+
vocab_size=49408,
|
1024 |
+
hidden_size=1024,
|
1025 |
+
intermediate_size=4096,
|
1026 |
+
num_hidden_layers=23,
|
1027 |
+
num_attention_heads=16,
|
1028 |
+
max_position_embeddings=77,
|
1029 |
+
hidden_act="gelu",
|
1030 |
+
layer_norm_eps=1e-05,
|
1031 |
+
dropout=0.0,
|
1032 |
+
attention_dropout=0.0,
|
1033 |
+
initializer_range=0.02,
|
1034 |
+
initializer_factor=1.0,
|
1035 |
+
pad_token_id=1,
|
1036 |
+
bos_token_id=0,
|
1037 |
+
eos_token_id=2,
|
1038 |
+
model_type="clip_text_model",
|
1039 |
+
projection_dim=512,
|
1040 |
+
torch_dtype="float32",
|
1041 |
+
transformers_version="4.25.0.dev0",
|
1042 |
+
)
|
1043 |
+
text_model = CLIPTextModel._from_config(cfg)
|
1044 |
+
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
1045 |
+
else:
|
1046 |
+
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
|
1047 |
+
|
1048 |
+
# logging.set_verbosity_error() # don't show annoying warning
|
1049 |
+
# text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
|
1050 |
+
# logging.set_verbosity_warning()
|
1051 |
+
# logger.info(f"config: {text_model.config}")
|
1052 |
+
cfg = CLIPTextConfig(
|
1053 |
+
vocab_size=49408,
|
1054 |
+
hidden_size=768,
|
1055 |
+
intermediate_size=3072,
|
1056 |
+
num_hidden_layers=12,
|
1057 |
+
num_attention_heads=12,
|
1058 |
+
max_position_embeddings=77,
|
1059 |
+
hidden_act="quick_gelu",
|
1060 |
+
layer_norm_eps=1e-05,
|
1061 |
+
dropout=0.0,
|
1062 |
+
attention_dropout=0.0,
|
1063 |
+
initializer_range=0.02,
|
1064 |
+
initializer_factor=1.0,
|
1065 |
+
pad_token_id=1,
|
1066 |
+
bos_token_id=0,
|
1067 |
+
eos_token_id=2,
|
1068 |
+
model_type="clip_text_model",
|
1069 |
+
projection_dim=768,
|
1070 |
+
torch_dtype="float32",
|
1071 |
+
)
|
1072 |
+
text_model = CLIPTextModel._from_config(cfg)
|
1073 |
+
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
1074 |
+
logger.info(f"loading text encoder: {info}")
|
1075 |
+
|
1076 |
+
return text_model, vae, unet
|
1077 |
+
|
1078 |
+
|
1079 |
+
def get_model_version_str_for_sd1_sd2(v2, v_parameterization):
|
1080 |
+
# only for reference
|
1081 |
+
version_str = "sd"
|
1082 |
+
if v2:
|
1083 |
+
version_str += "_v2"
|
1084 |
+
else:
|
1085 |
+
version_str += "_v1"
|
1086 |
+
if v_parameterization:
|
1087 |
+
version_str += "_v"
|
1088 |
+
return version_str
|
1089 |
+
|
1090 |
+
|
1091 |
+
def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
|
1092 |
+
def convert_key(key):
|
1093 |
+
# position_idsの除去
|
1094 |
+
if ".position_ids" in key:
|
1095 |
+
return None
|
1096 |
+
|
1097 |
+
# common
|
1098 |
+
key = key.replace("text_model.encoder.", "transformer.")
|
1099 |
+
key = key.replace("text_model.", "")
|
1100 |
+
if "layers" in key:
|
1101 |
+
# resblocks conversion
|
1102 |
+
key = key.replace(".layers.", ".resblocks.")
|
1103 |
+
if ".layer_norm" in key:
|
1104 |
+
key = key.replace(".layer_norm", ".ln_")
|
1105 |
+
elif ".mlp." in key:
|
1106 |
+
key = key.replace(".fc1.", ".c_fc.")
|
1107 |
+
key = key.replace(".fc2.", ".c_proj.")
|
1108 |
+
elif ".self_attn.out_proj" in key:
|
1109 |
+
key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
|
1110 |
+
elif ".self_attn." in key:
|
1111 |
+
key = None # 特殊なので後で処理する
|
1112 |
+
else:
|
1113 |
+
raise ValueError(f"unexpected key in DiffUsers model: {key}")
|
1114 |
+
elif ".position_embedding" in key:
|
1115 |
+
key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
|
1116 |
+
elif ".token_embedding" in key:
|
1117 |
+
key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
|
1118 |
+
elif "final_layer_norm" in key:
|
1119 |
+
key = key.replace("final_layer_norm", "ln_final")
|
1120 |
+
return key
|
1121 |
+
|
1122 |
+
keys = list(checkpoint.keys())
|
1123 |
+
new_sd = {}
|
1124 |
+
for key in keys:
|
1125 |
+
new_key = convert_key(key)
|
1126 |
+
if new_key is None:
|
1127 |
+
continue
|
1128 |
+
new_sd[new_key] = checkpoint[key]
|
1129 |
+
|
1130 |
+
# attnの変換
|
1131 |
+
for key in keys:
|
1132 |
+
if "layers" in key and "q_proj" in key:
|
1133 |
+
# 三つを結合
|
1134 |
+
key_q = key
|
1135 |
+
key_k = key.replace("q_proj", "k_proj")
|
1136 |
+
key_v = key.replace("q_proj", "v_proj")
|
1137 |
+
|
1138 |
+
value_q = checkpoint[key_q]
|
1139 |
+
value_k = checkpoint[key_k]
|
1140 |
+
value_v = checkpoint[key_v]
|
1141 |
+
value = torch.cat([value_q, value_k, value_v])
|
1142 |
+
|
1143 |
+
new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
|
1144 |
+
new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
|
1145 |
+
new_sd[new_key] = value
|
1146 |
+
|
1147 |
+
# 最後の層などを捏造するか
|
1148 |
+
if make_dummy_weights:
|
1149 |
+
logger.info("make dummy weights for resblock.23, text_projection and logit scale.")
|
1150 |
+
keys = list(new_sd.keys())
|
1151 |
+
for key in keys:
|
1152 |
+
if key.startswith("transformer.resblocks.22."):
|
1153 |
+
new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる
|
1154 |
+
|
1155 |
+
# Diffusersに含まれない重みを作っておく
|
1156 |
+
new_sd["text_projection"] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
|
1157 |
+
new_sd["logit_scale"] = torch.tensor(1)
|
1158 |
+
|
1159 |
+
return new_sd
|
1160 |
+
|
1161 |
+
|
1162 |
+
def save_stable_diffusion_checkpoint(
|
1163 |
+
v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, metadata, save_dtype=None, vae=None
|
1164 |
+
):
|
1165 |
+
if ckpt_path is not None:
|
1166 |
+
# epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む
|
1167 |
+
checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
1168 |
+
if checkpoint is None: # safetensors または state_dictのckpt
|
1169 |
+
checkpoint = {}
|
1170 |
+
strict = False
|
1171 |
+
else:
|
1172 |
+
strict = True
|
1173 |
+
if "state_dict" in state_dict:
|
1174 |
+
del state_dict["state_dict"]
|
1175 |
+
else:
|
1176 |
+
# 新しく作る
|
1177 |
+
assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
|
1178 |
+
checkpoint = {}
|
1179 |
+
state_dict = {}
|
1180 |
+
strict = False
|
1181 |
+
|
1182 |
+
def update_sd(prefix, sd):
|
1183 |
+
for k, v in sd.items():
|
1184 |
+
key = prefix + k
|
1185 |
+
assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
|
1186 |
+
if save_dtype is not None:
|
1187 |
+
v = v.detach().clone().to("cpu").to(save_dtype)
|
1188 |
+
state_dict[key] = v
|
1189 |
+
|
1190 |
+
# Convert the UNet model
|
1191 |
+
unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
|
1192 |
+
update_sd("model.diffusion_model.", unet_state_dict)
|
1193 |
+
|
1194 |
+
# Convert the text encoder model
|
1195 |
+
if v2:
|
1196 |
+
make_dummy = ckpt_path is None # 参照元のcheckpointがない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる
|
1197 |
+
text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
|
1198 |
+
update_sd("cond_stage_model.model.", text_enc_dict)
|
1199 |
+
else:
|
1200 |
+
text_enc_dict = text_encoder.state_dict()
|
1201 |
+
update_sd("cond_stage_model.transformer.", text_enc_dict)
|
1202 |
+
|
1203 |
+
# Convert the VAE
|
1204 |
+
if vae is not None:
|
1205 |
+
vae_dict = convert_vae_state_dict(vae.state_dict())
|
1206 |
+
update_sd("first_stage_model.", vae_dict)
|
1207 |
+
|
1208 |
+
# Put together new checkpoint
|
1209 |
+
key_count = len(state_dict.keys())
|
1210 |
+
new_ckpt = {"state_dict": state_dict}
|
1211 |
+
|
1212 |
+
# epoch and global_step are sometimes not int
|
1213 |
+
try:
|
1214 |
+
if "epoch" in checkpoint:
|
1215 |
+
epochs += checkpoint["epoch"]
|
1216 |
+
if "global_step" in checkpoint:
|
1217 |
+
steps += checkpoint["global_step"]
|
1218 |
+
except:
|
1219 |
+
pass
|
1220 |
+
|
1221 |
+
new_ckpt["epoch"] = epochs
|
1222 |
+
new_ckpt["global_step"] = steps
|
1223 |
+
|
1224 |
+
if is_safetensors(output_file):
|
1225 |
+
# TODO Tensor以外のdictの値を削除したほうがいいか
|
1226 |
+
save_file(state_dict, output_file, metadata)
|
1227 |
+
else:
|
1228 |
+
torch.save(new_ckpt, output_file)
|
1229 |
+
|
1230 |
+
return key_count
|
1231 |
+
|
1232 |
+
|
1233 |
+
def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
|
1234 |
+
if pretrained_model_name_or_path is None:
|
1235 |
+
# load default settings for v1/v2
|
1236 |
+
if v2:
|
1237 |
+
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
|
1238 |
+
else:
|
1239 |
+
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
|
1240 |
+
|
1241 |
+
scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
|
1242 |
+
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
|
1243 |
+
if vae is None:
|
1244 |
+
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
|
1245 |
+
|
1246 |
+
# original U-Net cannot be saved, so we need to convert it to the Diffusers version
|
1247 |
+
# TODO this consumes a lot of memory
|
1248 |
+
diffusers_unet = diffusers.UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
|
1249 |
+
diffusers_unet.load_state_dict(unet.state_dict())
|
1250 |
+
|
1251 |
+
pipeline = StableDiffusionPipeline(
|
1252 |
+
unet=diffusers_unet,
|
1253 |
+
text_encoder=text_encoder,
|
1254 |
+
vae=vae,
|
1255 |
+
scheduler=scheduler,
|
1256 |
+
tokenizer=tokenizer,
|
1257 |
+
safety_checker=None,
|
1258 |
+
feature_extractor=None,
|
1259 |
+
requires_safety_checker=None,
|
1260 |
+
)
|
1261 |
+
pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
|
1262 |
+
|
1263 |
+
|
1264 |
+
VAE_PREFIX = "first_stage_model."
|
1265 |
+
|
1266 |
+
|
1267 |
+
def load_vae(vae_id, dtype):
|
1268 |
+
logger.info(f"load VAE: {vae_id}")
|
1269 |
+
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
|
1270 |
+
# Diffusers local/remote
|
1271 |
+
try:
|
1272 |
+
vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
|
1273 |
+
except EnvironmentError as e:
|
1274 |
+
logger.error(f"exception occurs in loading vae: {e}")
|
1275 |
+
logger.error("retry with subfolder='vae'")
|
1276 |
+
vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
|
1277 |
+
return vae
|
1278 |
+
|
1279 |
+
# local
|
1280 |
+
vae_config = create_vae_diffusers_config()
|
1281 |
+
|
1282 |
+
if vae_id.endswith(".bin"):
|
1283 |
+
# SD 1.5 VAE on Huggingface
|
1284 |
+
converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
|
1285 |
+
else:
|
1286 |
+
# StableDiffusion
|
1287 |
+
vae_model = load_file(vae_id, "cpu") if is_safetensors(vae_id) else torch.load(vae_id, map_location="cpu")
|
1288 |
+
vae_sd = vae_model["state_dict"] if "state_dict" in vae_model else vae_model
|
1289 |
+
|
1290 |
+
# vae only or full model
|
1291 |
+
full_model = False
|
1292 |
+
for vae_key in vae_sd:
|
1293 |
+
if vae_key.startswith(VAE_PREFIX):
|
1294 |
+
full_model = True
|
1295 |
+
break
|
1296 |
+
if not full_model:
|
1297 |
+
sd = {}
|
1298 |
+
for key, value in vae_sd.items():
|
1299 |
+
sd[VAE_PREFIX + key] = value
|
1300 |
+
vae_sd = sd
|
1301 |
+
del sd
|
1302 |
+
|
1303 |
+
# Convert the VAE model.
|
1304 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
|
1305 |
+
|
1306 |
+
vae = AutoencoderKL(**vae_config)
|
1307 |
+
vae.load_state_dict(converted_vae_checkpoint)
|
1308 |
+
return vae
|
1309 |
+
|
1310 |
+
|
1311 |
+
# endregion
|
1312 |
+
|
1313 |
+
|
1314 |
+
def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
|
1315 |
+
max_width, max_height = max_reso
|
1316 |
+
max_area = max_width * max_height
|
1317 |
+
|
1318 |
+
resos = set()
|
1319 |
+
|
1320 |
+
width = int(math.sqrt(max_area) // divisible) * divisible
|
1321 |
+
resos.add((width, width))
|
1322 |
+
|
1323 |
+
width = min_size
|
1324 |
+
while width <= max_size:
|
1325 |
+
height = min(max_size, int((max_area // width) // divisible) * divisible)
|
1326 |
+
if height >= min_size:
|
1327 |
+
resos.add((width, height))
|
1328 |
+
resos.add((height, width))
|
1329 |
+
|
1330 |
+
# # make additional resos
|
1331 |
+
# if width >= height and width - divisible >= min_size:
|
1332 |
+
# resos.add((width - divisible, height))
|
1333 |
+
# resos.add((height, width - divisible))
|
1334 |
+
# if height >= width and height - divisible >= min_size:
|
1335 |
+
# resos.add((width, height - divisible))
|
1336 |
+
# resos.add((height - divisible, width))
|
1337 |
+
|
1338 |
+
width += divisible
|
1339 |
+
|
1340 |
+
resos = list(resos)
|
1341 |
+
resos.sort()
|
1342 |
+
return resos
|
1343 |
+
|
1344 |
+
|
1345 |
+
if __name__ == "__main__":
|
1346 |
+
resos = make_bucket_resolutions((512, 768))
|
1347 |
+
logger.info(f"{len(resos)}")
|
1348 |
+
logger.info(f"{resos}")
|
1349 |
+
aspect_ratios = [w / h for w, h in resos]
|
1350 |
+
logger.info(f"{aspect_ratios}")
|
1351 |
+
|
1352 |
+
ars = set()
|
1353 |
+
for ar in aspect_ratios:
|
1354 |
+
if ar in ars:
|
1355 |
+
logger.error(f"error! duplicate ar: {ar}")
|
1356 |
+
ars.add(ar)
|
library/original_unet.py
ADDED
@@ -0,0 +1,1919 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Diffusers 0.10.2からStable Diffusionに必要な部分だけを持ってくる
|
2 |
+
# 条件分岐等で不要な部分は削除している
|
3 |
+
# コードの多くはDiffusersからコピーしている
|
4 |
+
# 制約として、モデルのstate_dictがDiffusers 0.10.2のものと同じ形式である必要がある
|
5 |
+
|
6 |
+
# Copy from Diffusers 0.10.2 for Stable Diffusion. Most of the code is copied from Diffusers.
|
7 |
+
# Unnecessary parts are deleted by condition branching.
|
8 |
+
# As a constraint, the state_dict of the model must be in the same format as that of Diffusers 0.10.2
|
9 |
+
|
10 |
+
"""
|
11 |
+
v1.5とv2.1の相違点は
|
12 |
+
- attention_head_dimがintかlist[int]か
|
13 |
+
- cross_attention_dimが768か1024か
|
14 |
+
- use_linear_projection: trueがない(=False, 1.5)かあるか
|
15 |
+
- upcast_attentionがFalse(1.5)かTrue(2.1)か
|
16 |
+
- (以下は多分無視していい)
|
17 |
+
- sample_sizeが64か96か
|
18 |
+
- dual_cross_attentionがあるかないか
|
19 |
+
- num_class_embedsがあるかないか
|
20 |
+
- only_cross_attentionがあるかないか
|
21 |
+
|
22 |
+
v1.5
|
23 |
+
{
|
24 |
+
"_class_name": "UNet2DConditionModel",
|
25 |
+
"_diffusers_version": "0.6.0",
|
26 |
+
"act_fn": "silu",
|
27 |
+
"attention_head_dim": 8,
|
28 |
+
"block_out_channels": [
|
29 |
+
320,
|
30 |
+
640,
|
31 |
+
1280,
|
32 |
+
1280
|
33 |
+
],
|
34 |
+
"center_input_sample": false,
|
35 |
+
"cross_attention_dim": 768,
|
36 |
+
"down_block_types": [
|
37 |
+
"CrossAttnDownBlock2D",
|
38 |
+
"CrossAttnDownBlock2D",
|
39 |
+
"CrossAttnDownBlock2D",
|
40 |
+
"DownBlock2D"
|
41 |
+
],
|
42 |
+
"downsample_padding": 1,
|
43 |
+
"flip_sin_to_cos": true,
|
44 |
+
"freq_shift": 0,
|
45 |
+
"in_channels": 4,
|
46 |
+
"layers_per_block": 2,
|
47 |
+
"mid_block_scale_factor": 1,
|
48 |
+
"norm_eps": 1e-05,
|
49 |
+
"norm_num_groups": 32,
|
50 |
+
"out_channels": 4,
|
51 |
+
"sample_size": 64,
|
52 |
+
"up_block_types": [
|
53 |
+
"UpBlock2D",
|
54 |
+
"CrossAttnUpBlock2D",
|
55 |
+
"CrossAttnUpBlock2D",
|
56 |
+
"CrossAttnUpBlock2D"
|
57 |
+
]
|
58 |
+
}
|
59 |
+
|
60 |
+
v2.1
|
61 |
+
{
|
62 |
+
"_class_name": "UNet2DConditionModel",
|
63 |
+
"_diffusers_version": "0.10.0.dev0",
|
64 |
+
"act_fn": "silu",
|
65 |
+
"attention_head_dim": [
|
66 |
+
5,
|
67 |
+
10,
|
68 |
+
20,
|
69 |
+
20
|
70 |
+
],
|
71 |
+
"block_out_channels": [
|
72 |
+
320,
|
73 |
+
640,
|
74 |
+
1280,
|
75 |
+
1280
|
76 |
+
],
|
77 |
+
"center_input_sample": false,
|
78 |
+
"cross_attention_dim": 1024,
|
79 |
+
"down_block_types": [
|
80 |
+
"CrossAttnDownBlock2D",
|
81 |
+
"CrossAttnDownBlock2D",
|
82 |
+
"CrossAttnDownBlock2D",
|
83 |
+
"DownBlock2D"
|
84 |
+
],
|
85 |
+
"downsample_padding": 1,
|
86 |
+
"dual_cross_attention": false,
|
87 |
+
"flip_sin_to_cos": true,
|
88 |
+
"freq_shift": 0,
|
89 |
+
"in_channels": 4,
|
90 |
+
"layers_per_block": 2,
|
91 |
+
"mid_block_scale_factor": 1,
|
92 |
+
"norm_eps": 1e-05,
|
93 |
+
"norm_num_groups": 32,
|
94 |
+
"num_class_embeds": null,
|
95 |
+
"only_cross_attention": false,
|
96 |
+
"out_channels": 4,
|
97 |
+
"sample_size": 96,
|
98 |
+
"up_block_types": [
|
99 |
+
"UpBlock2D",
|
100 |
+
"CrossAttnUpBlock2D",
|
101 |
+
"CrossAttnUpBlock2D",
|
102 |
+
"CrossAttnUpBlock2D"
|
103 |
+
],
|
104 |
+
"use_linear_projection": true,
|
105 |
+
"upcast_attention": true
|
106 |
+
}
|
107 |
+
"""
|
108 |
+
|
109 |
+
import math
|
110 |
+
from types import SimpleNamespace
|
111 |
+
from typing import Dict, Optional, Tuple, Union
|
112 |
+
import torch
|
113 |
+
from torch import nn
|
114 |
+
from torch.nn import functional as F
|
115 |
+
from einops import rearrange
|
116 |
+
from library.utils import setup_logging
|
117 |
+
setup_logging()
|
118 |
+
import logging
|
119 |
+
logger = logging.getLogger(__name__)
|
120 |
+
|
121 |
+
BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280)
|
122 |
+
TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0]
|
123 |
+
TIME_EMBED_DIM = BLOCK_OUT_CHANNELS[0] * 4
|
124 |
+
IN_CHANNELS: int = 4
|
125 |
+
OUT_CHANNELS: int = 4
|
126 |
+
LAYERS_PER_BLOCK: int = 2
|
127 |
+
LAYERS_PER_BLOCK_UP: int = LAYERS_PER_BLOCK + 1
|
128 |
+
TIME_EMBED_FLIP_SIN_TO_COS: bool = True
|
129 |
+
TIME_EMBED_FREQ_SHIFT: int = 0
|
130 |
+
NORM_GROUPS: int = 32
|
131 |
+
NORM_EPS: float = 1e-5
|
132 |
+
TRANSFORMER_NORM_NUM_GROUPS = 32
|
133 |
+
|
134 |
+
DOWN_BLOCK_TYPES = ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]
|
135 |
+
UP_BLOCK_TYPES = ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"]
|
136 |
+
|
137 |
+
|
138 |
+
# region memory efficient attention
|
139 |
+
|
140 |
+
# FlashAttentionを使うCrossAttention
|
141 |
+
# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
|
142 |
+
# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
|
143 |
+
|
144 |
+
# constants
|
145 |
+
|
146 |
+
EPSILON = 1e-6
|
147 |
+
|
148 |
+
# helper functions
|
149 |
+
|
150 |
+
|
151 |
+
def exists(val):
|
152 |
+
return val is not None
|
153 |
+
|
154 |
+
|
155 |
+
def default(val, d):
|
156 |
+
return val if exists(val) else d
|
157 |
+
|
158 |
+
|
159 |
+
# flash attention forwards and backwards
|
160 |
+
|
161 |
+
# https://arxiv.org/abs/2205.14135
|
162 |
+
|
163 |
+
|
164 |
+
class FlashAttentionFunction(torch.autograd.Function):
|
165 |
+
@staticmethod
|
166 |
+
@torch.no_grad()
|
167 |
+
def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
|
168 |
+
"""Algorithm 2 in the paper"""
|
169 |
+
|
170 |
+
device = q.device
|
171 |
+
dtype = q.dtype
|
172 |
+
max_neg_value = -torch.finfo(q.dtype).max
|
173 |
+
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
174 |
+
|
175 |
+
o = torch.zeros_like(q)
|
176 |
+
all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
|
177 |
+
all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
|
178 |
+
|
179 |
+
scale = q.shape[-1] ** -0.5
|
180 |
+
|
181 |
+
if not exists(mask):
|
182 |
+
mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
|
183 |
+
else:
|
184 |
+
mask = rearrange(mask, "b n -> b 1 1 n")
|
185 |
+
mask = mask.split(q_bucket_size, dim=-1)
|
186 |
+
|
187 |
+
row_splits = zip(
|
188 |
+
q.split(q_bucket_size, dim=-2),
|
189 |
+
o.split(q_bucket_size, dim=-2),
|
190 |
+
mask,
|
191 |
+
all_row_sums.split(q_bucket_size, dim=-2),
|
192 |
+
all_row_maxes.split(q_bucket_size, dim=-2),
|
193 |
+
)
|
194 |
+
|
195 |
+
for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
|
196 |
+
q_start_index = ind * q_bucket_size - qk_len_diff
|
197 |
+
|
198 |
+
col_splits = zip(
|
199 |
+
k.split(k_bucket_size, dim=-2),
|
200 |
+
v.split(k_bucket_size, dim=-2),
|
201 |
+
)
|
202 |
+
|
203 |
+
for k_ind, (kc, vc) in enumerate(col_splits):
|
204 |
+
k_start_index = k_ind * k_bucket_size
|
205 |
+
|
206 |
+
attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
|
207 |
+
|
208 |
+
if exists(row_mask):
|
209 |
+
attn_weights.masked_fill_(~row_mask, max_neg_value)
|
210 |
+
|
211 |
+
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
212 |
+
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
|
213 |
+
q_start_index - k_start_index + 1
|
214 |
+
)
|
215 |
+
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
216 |
+
|
217 |
+
block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
|
218 |
+
attn_weights -= block_row_maxes
|
219 |
+
exp_weights = torch.exp(attn_weights)
|
220 |
+
|
221 |
+
if exists(row_mask):
|
222 |
+
exp_weights.masked_fill_(~row_mask, 0.0)
|
223 |
+
|
224 |
+
block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
|
225 |
+
|
226 |
+
new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
|
227 |
+
|
228 |
+
exp_values = torch.einsum("... i j, ... j d -> ... i d", exp_weights, vc)
|
229 |
+
|
230 |
+
exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
|
231 |
+
exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
|
232 |
+
|
233 |
+
new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
|
234 |
+
|
235 |
+
oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
|
236 |
+
|
237 |
+
row_maxes.copy_(new_row_maxes)
|
238 |
+
row_sums.copy_(new_row_sums)
|
239 |
+
|
240 |
+
ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
|
241 |
+
ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
|
242 |
+
|
243 |
+
return o
|
244 |
+
|
245 |
+
@staticmethod
|
246 |
+
@torch.no_grad()
|
247 |
+
def backward(ctx, do):
|
248 |
+
"""Algorithm 4 in the paper"""
|
249 |
+
|
250 |
+
causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
|
251 |
+
q, k, v, o, l, m = ctx.saved_tensors
|
252 |
+
|
253 |
+
device = q.device
|
254 |
+
|
255 |
+
max_neg_value = -torch.finfo(q.dtype).max
|
256 |
+
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
257 |
+
|
258 |
+
dq = torch.zeros_like(q)
|
259 |
+
dk = torch.zeros_like(k)
|
260 |
+
dv = torch.zeros_like(v)
|
261 |
+
|
262 |
+
row_splits = zip(
|
263 |
+
q.split(q_bucket_size, dim=-2),
|
264 |
+
o.split(q_bucket_size, dim=-2),
|
265 |
+
do.split(q_bucket_size, dim=-2),
|
266 |
+
mask,
|
267 |
+
l.split(q_bucket_size, dim=-2),
|
268 |
+
m.split(q_bucket_size, dim=-2),
|
269 |
+
dq.split(q_bucket_size, dim=-2),
|
270 |
+
)
|
271 |
+
|
272 |
+
for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
|
273 |
+
q_start_index = ind * q_bucket_size - qk_len_diff
|
274 |
+
|
275 |
+
col_splits = zip(
|
276 |
+
k.split(k_bucket_size, dim=-2),
|
277 |
+
v.split(k_bucket_size, dim=-2),
|
278 |
+
dk.split(k_bucket_size, dim=-2),
|
279 |
+
dv.split(k_bucket_size, dim=-2),
|
280 |
+
)
|
281 |
+
|
282 |
+
for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
|
283 |
+
k_start_index = k_ind * k_bucket_size
|
284 |
+
|
285 |
+
attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
|
286 |
+
|
287 |
+
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
288 |
+
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
|
289 |
+
q_start_index - k_start_index + 1
|
290 |
+
)
|
291 |
+
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
292 |
+
|
293 |
+
exp_attn_weights = torch.exp(attn_weights - mc)
|
294 |
+
|
295 |
+
if exists(row_mask):
|
296 |
+
exp_attn_weights.masked_fill_(~row_mask, 0.0)
|
297 |
+
|
298 |
+
p = exp_attn_weights / lc
|
299 |
+
|
300 |
+
dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
|
301 |
+
dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
|
302 |
+
|
303 |
+
D = (doc * oc).sum(dim=-1, keepdims=True)
|
304 |
+
ds = p * scale * (dp - D)
|
305 |
+
|
306 |
+
dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
|
307 |
+
dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
|
308 |
+
|
309 |
+
dqc.add_(dq_chunk)
|
310 |
+
dkc.add_(dk_chunk)
|
311 |
+
dvc.add_(dv_chunk)
|
312 |
+
|
313 |
+
return dq, dk, dv, None, None, None, None
|
314 |
+
|
315 |
+
|
316 |
+
# endregion
|
317 |
+
|
318 |
+
|
319 |
+
def get_parameter_dtype(parameter: torch.nn.Module):
|
320 |
+
return next(parameter.parameters()).dtype
|
321 |
+
|
322 |
+
|
323 |
+
def get_parameter_device(parameter: torch.nn.Module):
|
324 |
+
return next(parameter.parameters()).device
|
325 |
+
|
326 |
+
|
327 |
+
def get_timestep_embedding(
|
328 |
+
timesteps: torch.Tensor,
|
329 |
+
embedding_dim: int,
|
330 |
+
flip_sin_to_cos: bool = False,
|
331 |
+
downscale_freq_shift: float = 1,
|
332 |
+
scale: float = 1,
|
333 |
+
max_period: int = 10000,
|
334 |
+
):
|
335 |
+
"""
|
336 |
+
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
337 |
+
|
338 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
339 |
+
These may be fractional.
|
340 |
+
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
|
341 |
+
embeddings. :return: an [N x dim] Tensor of positional embeddings.
|
342 |
+
"""
|
343 |
+
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
344 |
+
|
345 |
+
half_dim = embedding_dim // 2
|
346 |
+
exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
|
347 |
+
exponent = exponent / (half_dim - downscale_freq_shift)
|
348 |
+
|
349 |
+
emb = torch.exp(exponent)
|
350 |
+
emb = timesteps[:, None].float() * emb[None, :]
|
351 |
+
|
352 |
+
# scale embeddings
|
353 |
+
emb = scale * emb
|
354 |
+
|
355 |
+
# concat sine and cosine embeddings
|
356 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
357 |
+
|
358 |
+
# flip sine and cosine embeddings
|
359 |
+
if flip_sin_to_cos:
|
360 |
+
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
361 |
+
|
362 |
+
# zero pad
|
363 |
+
if embedding_dim % 2 == 1:
|
364 |
+
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
365 |
+
return emb
|
366 |
+
|
367 |
+
|
368 |
+
# Deep Shrink: We do not common this function, because minimize dependencies.
|
369 |
+
def resize_like(x, target, mode="bicubic", align_corners=False):
|
370 |
+
org_dtype = x.dtype
|
371 |
+
if org_dtype == torch.bfloat16:
|
372 |
+
x = x.to(torch.float32)
|
373 |
+
|
374 |
+
if x.shape[-2:] != target.shape[-2:]:
|
375 |
+
if mode == "nearest":
|
376 |
+
x = F.interpolate(x, size=target.shape[-2:], mode=mode)
|
377 |
+
else:
|
378 |
+
x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners)
|
379 |
+
|
380 |
+
if org_dtype == torch.bfloat16:
|
381 |
+
x = x.to(org_dtype)
|
382 |
+
return x
|
383 |
+
|
384 |
+
|
385 |
+
class SampleOutput:
|
386 |
+
def __init__(self, sample):
|
387 |
+
self.sample = sample
|
388 |
+
|
389 |
+
|
390 |
+
class TimestepEmbedding(nn.Module):
|
391 |
+
def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None):
|
392 |
+
super().__init__()
|
393 |
+
|
394 |
+
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
395 |
+
self.act = None
|
396 |
+
if act_fn == "silu":
|
397 |
+
self.act = nn.SiLU()
|
398 |
+
elif act_fn == "mish":
|
399 |
+
self.act = nn.Mish()
|
400 |
+
|
401 |
+
if out_dim is not None:
|
402 |
+
time_embed_dim_out = out_dim
|
403 |
+
else:
|
404 |
+
time_embed_dim_out = time_embed_dim
|
405 |
+
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
|
406 |
+
|
407 |
+
def forward(self, sample):
|
408 |
+
sample = self.linear_1(sample)
|
409 |
+
|
410 |
+
if self.act is not None:
|
411 |
+
sample = self.act(sample)
|
412 |
+
|
413 |
+
sample = self.linear_2(sample)
|
414 |
+
return sample
|
415 |
+
|
416 |
+
|
417 |
+
class Timesteps(nn.Module):
|
418 |
+
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
|
419 |
+
super().__init__()
|
420 |
+
self.num_channels = num_channels
|
421 |
+
self.flip_sin_to_cos = flip_sin_to_cos
|
422 |
+
self.downscale_freq_shift = downscale_freq_shift
|
423 |
+
|
424 |
+
def forward(self, timesteps):
|
425 |
+
t_emb = get_timestep_embedding(
|
426 |
+
timesteps,
|
427 |
+
self.num_channels,
|
428 |
+
flip_sin_to_cos=self.flip_sin_to_cos,
|
429 |
+
downscale_freq_shift=self.downscale_freq_shift,
|
430 |
+
)
|
431 |
+
return t_emb
|
432 |
+
|
433 |
+
|
434 |
+
class ResnetBlock2D(nn.Module):
|
435 |
+
def __init__(
|
436 |
+
self,
|
437 |
+
in_channels,
|
438 |
+
out_channels,
|
439 |
+
):
|
440 |
+
super().__init__()
|
441 |
+
self.in_channels = in_channels
|
442 |
+
self.out_channels = out_channels
|
443 |
+
|
444 |
+
self.norm1 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=in_channels, eps=NORM_EPS, affine=True)
|
445 |
+
|
446 |
+
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
447 |
+
|
448 |
+
self.time_emb_proj = torch.nn.Linear(TIME_EMBED_DIM, out_channels)
|
449 |
+
|
450 |
+
self.norm2 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=out_channels, eps=NORM_EPS, affine=True)
|
451 |
+
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
452 |
+
|
453 |
+
# if non_linearity == "swish":
|
454 |
+
self.nonlinearity = lambda x: F.silu(x)
|
455 |
+
|
456 |
+
self.use_in_shortcut = self.in_channels != self.out_channels
|
457 |
+
|
458 |
+
self.conv_shortcut = None
|
459 |
+
if self.use_in_shortcut:
|
460 |
+
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
461 |
+
|
462 |
+
def forward(self, input_tensor, temb):
|
463 |
+
hidden_states = input_tensor
|
464 |
+
|
465 |
+
hidden_states = self.norm1(hidden_states)
|
466 |
+
hidden_states = self.nonlinearity(hidden_states)
|
467 |
+
|
468 |
+
hidden_states = self.conv1(hidden_states)
|
469 |
+
|
470 |
+
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
|
471 |
+
hidden_states = hidden_states + temb
|
472 |
+
|
473 |
+
hidden_states = self.norm2(hidden_states)
|
474 |
+
hidden_states = self.nonlinearity(hidden_states)
|
475 |
+
|
476 |
+
hidden_states = self.conv2(hidden_states)
|
477 |
+
|
478 |
+
if self.conv_shortcut is not None:
|
479 |
+
input_tensor = self.conv_shortcut(input_tensor)
|
480 |
+
|
481 |
+
output_tensor = input_tensor + hidden_states
|
482 |
+
|
483 |
+
return output_tensor
|
484 |
+
|
485 |
+
|
486 |
+
class DownBlock2D(nn.Module):
|
487 |
+
def __init__(
|
488 |
+
self,
|
489 |
+
in_channels: int,
|
490 |
+
out_channels: int,
|
491 |
+
add_downsample=True,
|
492 |
+
):
|
493 |
+
super().__init__()
|
494 |
+
|
495 |
+
self.has_cross_attention = False
|
496 |
+
resnets = []
|
497 |
+
|
498 |
+
for i in range(LAYERS_PER_BLOCK):
|
499 |
+
in_channels = in_channels if i == 0 else out_channels
|
500 |
+
resnets.append(
|
501 |
+
ResnetBlock2D(
|
502 |
+
in_channels=in_channels,
|
503 |
+
out_channels=out_channels,
|
504 |
+
)
|
505 |
+
)
|
506 |
+
self.resnets = nn.ModuleList(resnets)
|
507 |
+
|
508 |
+
if add_downsample:
|
509 |
+
self.downsamplers = [Downsample2D(out_channels, out_channels=out_channels)]
|
510 |
+
else:
|
511 |
+
self.downsamplers = None
|
512 |
+
|
513 |
+
self.gradient_checkpointing = False
|
514 |
+
|
515 |
+
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
516 |
+
pass
|
517 |
+
|
518 |
+
def set_use_sdpa(self, sdpa):
|
519 |
+
pass
|
520 |
+
|
521 |
+
def forward(self, hidden_states, temb=None):
|
522 |
+
output_states = ()
|
523 |
+
|
524 |
+
for resnet in self.resnets:
|
525 |
+
if self.training and self.gradient_checkpointing:
|
526 |
+
|
527 |
+
def create_custom_forward(module):
|
528 |
+
def custom_forward(*inputs):
|
529 |
+
return module(*inputs)
|
530 |
+
|
531 |
+
return custom_forward
|
532 |
+
|
533 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
534 |
+
else:
|
535 |
+
hidden_states = resnet(hidden_states, temb)
|
536 |
+
|
537 |
+
output_states += (hidden_states,)
|
538 |
+
|
539 |
+
if self.downsamplers is not None:
|
540 |
+
for downsampler in self.downsamplers:
|
541 |
+
hidden_states = downsampler(hidden_states)
|
542 |
+
|
543 |
+
output_states += (hidden_states,)
|
544 |
+
|
545 |
+
return hidden_states, output_states
|
546 |
+
|
547 |
+
|
548 |
+
class Downsample2D(nn.Module):
|
549 |
+
def __init__(self, channels, out_channels):
|
550 |
+
super().__init__()
|
551 |
+
|
552 |
+
self.channels = channels
|
553 |
+
self.out_channels = out_channels
|
554 |
+
|
555 |
+
self.conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=1)
|
556 |
+
|
557 |
+
def forward(self, hidden_states):
|
558 |
+
assert hidden_states.shape[1] == self.channels
|
559 |
+
hidden_states = self.conv(hidden_states)
|
560 |
+
|
561 |
+
return hidden_states
|
562 |
+
|
563 |
+
|
564 |
+
class CrossAttention(nn.Module):
|
565 |
+
def __init__(
|
566 |
+
self,
|
567 |
+
query_dim: int,
|
568 |
+
cross_attention_dim: Optional[int] = None,
|
569 |
+
heads: int = 8,
|
570 |
+
dim_head: int = 64,
|
571 |
+
upcast_attention: bool = False,
|
572 |
+
):
|
573 |
+
super().__init__()
|
574 |
+
inner_dim = dim_head * heads
|
575 |
+
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
576 |
+
self.upcast_attention = upcast_attention
|
577 |
+
|
578 |
+
self.scale = dim_head**-0.5
|
579 |
+
self.heads = heads
|
580 |
+
|
581 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
582 |
+
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False)
|
583 |
+
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False)
|
584 |
+
|
585 |
+
self.to_out = nn.ModuleList([])
|
586 |
+
self.to_out.append(nn.Linear(inner_dim, query_dim))
|
587 |
+
# no dropout here
|
588 |
+
|
589 |
+
self.use_memory_efficient_attention_xformers = False
|
590 |
+
self.use_memory_efficient_attention_mem_eff = False
|
591 |
+
self.use_sdpa = False
|
592 |
+
|
593 |
+
# Attention processor
|
594 |
+
self.processor = None
|
595 |
+
|
596 |
+
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
597 |
+
self.use_memory_efficient_attention_xformers = xformers
|
598 |
+
self.use_memory_efficient_attention_mem_eff = mem_eff
|
599 |
+
|
600 |
+
def set_use_sdpa(self, sdpa):
|
601 |
+
self.use_sdpa = sdpa
|
602 |
+
|
603 |
+
def reshape_heads_to_batch_dim(self, tensor):
|
604 |
+
batch_size, seq_len, dim = tensor.shape
|
605 |
+
head_size = self.heads
|
606 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
607 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
|
608 |
+
return tensor
|
609 |
+
|
610 |
+
def reshape_batch_dim_to_heads(self, tensor):
|
611 |
+
batch_size, seq_len, dim = tensor.shape
|
612 |
+
head_size = self.heads
|
613 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
614 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
615 |
+
return tensor
|
616 |
+
|
617 |
+
def set_processor(self):
|
618 |
+
return self.processor
|
619 |
+
|
620 |
+
def get_processor(self):
|
621 |
+
return self.processor
|
622 |
+
|
623 |
+
def forward(self, hidden_states, context=None, mask=None, **kwargs):
|
624 |
+
if self.processor is not None:
|
625 |
+
(
|
626 |
+
hidden_states,
|
627 |
+
encoder_hidden_states,
|
628 |
+
attention_mask,
|
629 |
+
) = translate_attention_names_from_diffusers(
|
630 |
+
hidden_states=hidden_states, context=context, mask=mask, **kwargs
|
631 |
+
)
|
632 |
+
return self.processor(
|
633 |
+
attn=self,
|
634 |
+
hidden_states=hidden_states,
|
635 |
+
encoder_hidden_states=context,
|
636 |
+
attention_mask=mask,
|
637 |
+
**kwargs
|
638 |
+
)
|
639 |
+
if self.use_memory_efficient_attention_xformers:
|
640 |
+
return self.forward_memory_efficient_xformers(hidden_states, context, mask)
|
641 |
+
if self.use_memory_efficient_attention_mem_eff:
|
642 |
+
return self.forward_memory_efficient_mem_eff(hidden_states, context, mask)
|
643 |
+
if self.use_sdpa:
|
644 |
+
return self.forward_sdpa(hidden_states, context, mask)
|
645 |
+
|
646 |
+
query = self.to_q(hidden_states)
|
647 |
+
context = context if context is not None else hidden_states
|
648 |
+
key = self.to_k(context)
|
649 |
+
value = self.to_v(context)
|
650 |
+
|
651 |
+
query = self.reshape_heads_to_batch_dim(query)
|
652 |
+
key = self.reshape_heads_to_batch_dim(key)
|
653 |
+
value = self.reshape_heads_to_batch_dim(value)
|
654 |
+
|
655 |
+
hidden_states = self._attention(query, key, value)
|
656 |
+
|
657 |
+
# linear proj
|
658 |
+
hidden_states = self.to_out[0](hidden_states)
|
659 |
+
# hidden_states = self.to_out[1](hidden_states) # no dropout
|
660 |
+
return hidden_states
|
661 |
+
|
662 |
+
def _attention(self, query, key, value):
|
663 |
+
if self.upcast_attention:
|
664 |
+
query = query.float()
|
665 |
+
key = key.float()
|
666 |
+
|
667 |
+
attention_scores = torch.baddbmm(
|
668 |
+
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
|
669 |
+
query,
|
670 |
+
key.transpose(-1, -2),
|
671 |
+
beta=0,
|
672 |
+
alpha=self.scale,
|
673 |
+
)
|
674 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
675 |
+
|
676 |
+
# cast back to the original dtype
|
677 |
+
attention_probs = attention_probs.to(value.dtype)
|
678 |
+
|
679 |
+
# compute attention output
|
680 |
+
hidden_states = torch.bmm(attention_probs, value)
|
681 |
+
|
682 |
+
# reshape hidden_states
|
683 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
684 |
+
return hidden_states
|
685 |
+
|
686 |
+
# TODO support Hypernetworks
|
687 |
+
def forward_memory_efficient_xformers(self, x, context=None, mask=None):
|
688 |
+
import xformers.ops
|
689 |
+
|
690 |
+
h = self.heads
|
691 |
+
q_in = self.to_q(x)
|
692 |
+
context = context if context is not None else x
|
693 |
+
context = context.to(x.dtype)
|
694 |
+
k_in = self.to_k(context)
|
695 |
+
v_in = self.to_v(context)
|
696 |
+
|
697 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
|
698 |
+
del q_in, k_in, v_in
|
699 |
+
|
700 |
+
q = q.contiguous()
|
701 |
+
k = k.contiguous()
|
702 |
+
v = v.contiguous()
|
703 |
+
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
|
704 |
+
|
705 |
+
out = rearrange(out, "b n h d -> b n (h d)", h=h)
|
706 |
+
|
707 |
+
out = self.to_out[0](out)
|
708 |
+
return out
|
709 |
+
|
710 |
+
def forward_memory_efficient_mem_eff(self, x, context=None, mask=None):
|
711 |
+
flash_func = FlashAttentionFunction
|
712 |
+
|
713 |
+
q_bucket_size = 512
|
714 |
+
k_bucket_size = 1024
|
715 |
+
|
716 |
+
h = self.heads
|
717 |
+
q = self.to_q(x)
|
718 |
+
context = context if context is not None else x
|
719 |
+
context = context.to(x.dtype)
|
720 |
+
k = self.to_k(context)
|
721 |
+
v = self.to_v(context)
|
722 |
+
del context, x
|
723 |
+
|
724 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
|
725 |
+
|
726 |
+
out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
|
727 |
+
|
728 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
729 |
+
|
730 |
+
out = self.to_out[0](out)
|
731 |
+
return out
|
732 |
+
|
733 |
+
def forward_sdpa(self, x, context=None, mask=None):
|
734 |
+
h = self.heads
|
735 |
+
q_in = self.to_q(x)
|
736 |
+
context = context if context is not None else x
|
737 |
+
context = context.to(x.dtype)
|
738 |
+
k_in = self.to_k(context)
|
739 |
+
v_in = self.to_v(context)
|
740 |
+
|
741 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in))
|
742 |
+
del q_in, k_in, v_in
|
743 |
+
|
744 |
+
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
745 |
+
|
746 |
+
out = rearrange(out, "b h n d -> b n (h d)", h=h)
|
747 |
+
|
748 |
+
out = self.to_out[0](out)
|
749 |
+
return out
|
750 |
+
|
751 |
+
def translate_attention_names_from_diffusers(
|
752 |
+
hidden_states: torch.FloatTensor,
|
753 |
+
context: Optional[torch.FloatTensor] = None,
|
754 |
+
mask: Optional[torch.FloatTensor] = None,
|
755 |
+
# HF naming
|
756 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
757 |
+
attention_mask: Optional[torch.FloatTensor] = None
|
758 |
+
):
|
759 |
+
# translate from hugging face diffusers
|
760 |
+
context = context if context is not None else encoder_hidden_states
|
761 |
+
|
762 |
+
# translate from hugging face diffusers
|
763 |
+
mask = mask if mask is not None else attention_mask
|
764 |
+
|
765 |
+
return hidden_states, context, mask
|
766 |
+
|
767 |
+
# feedforward
|
768 |
+
class GEGLU(nn.Module):
|
769 |
+
r"""
|
770 |
+
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
|
771 |
+
|
772 |
+
Parameters:
|
773 |
+
dim_in (`int`): The number of channels in the input.
|
774 |
+
dim_out (`int`): The number of channels in the output.
|
775 |
+
"""
|
776 |
+
|
777 |
+
def __init__(self, dim_in: int, dim_out: int):
|
778 |
+
super().__init__()
|
779 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
780 |
+
|
781 |
+
def gelu(self, gate):
|
782 |
+
if gate.device.type != "mps":
|
783 |
+
return F.gelu(gate)
|
784 |
+
# mps: gelu is not implemented for float16
|
785 |
+
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
|
786 |
+
|
787 |
+
def forward(self, hidden_states):
|
788 |
+
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
|
789 |
+
return hidden_states * self.gelu(gate)
|
790 |
+
|
791 |
+
|
792 |
+
class FeedForward(nn.Module):
|
793 |
+
def __init__(
|
794 |
+
self,
|
795 |
+
dim: int,
|
796 |
+
):
|
797 |
+
super().__init__()
|
798 |
+
inner_dim = int(dim * 4) # mult is always 4
|
799 |
+
|
800 |
+
self.net = nn.ModuleList([])
|
801 |
+
# project in
|
802 |
+
self.net.append(GEGLU(dim, inner_dim))
|
803 |
+
# project dropout
|
804 |
+
self.net.append(nn.Identity()) # nn.Dropout(0)) # dummy for dropout with 0
|
805 |
+
# project out
|
806 |
+
self.net.append(nn.Linear(inner_dim, dim))
|
807 |
+
|
808 |
+
def forward(self, hidden_states):
|
809 |
+
for module in self.net:
|
810 |
+
hidden_states = module(hidden_states)
|
811 |
+
return hidden_states
|
812 |
+
|
813 |
+
|
814 |
+
class BasicTransformerBlock(nn.Module):
|
815 |
+
def __init__(
|
816 |
+
self, dim: int, num_attention_heads: int, attention_head_dim: int, cross_attention_dim: int, upcast_attention: bool = False
|
817 |
+
):
|
818 |
+
super().__init__()
|
819 |
+
|
820 |
+
# 1. Self-Attn
|
821 |
+
self.attn1 = CrossAttention(
|
822 |
+
query_dim=dim,
|
823 |
+
cross_attention_dim=None,
|
824 |
+
heads=num_attention_heads,
|
825 |
+
dim_head=attention_head_dim,
|
826 |
+
upcast_attention=upcast_attention,
|
827 |
+
)
|
828 |
+
self.ff = FeedForward(dim)
|
829 |
+
|
830 |
+
# 2. Cross-Attn
|
831 |
+
self.attn2 = CrossAttention(
|
832 |
+
query_dim=dim,
|
833 |
+
cross_attention_dim=cross_attention_dim,
|
834 |
+
heads=num_attention_heads,
|
835 |
+
dim_head=attention_head_dim,
|
836 |
+
upcast_attention=upcast_attention,
|
837 |
+
)
|
838 |
+
|
839 |
+
self.norm1 = nn.LayerNorm(dim)
|
840 |
+
self.norm2 = nn.LayerNorm(dim)
|
841 |
+
|
842 |
+
# 3. Feed-forward
|
843 |
+
self.norm3 = nn.LayerNorm(dim)
|
844 |
+
|
845 |
+
def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool):
|
846 |
+
self.attn1.set_use_memory_efficient_attention(xformers, mem_eff)
|
847 |
+
self.attn2.set_use_memory_efficient_attention(xformers, mem_eff)
|
848 |
+
|
849 |
+
def set_use_sdpa(self, sdpa: bool):
|
850 |
+
self.attn1.set_use_sdpa(sdpa)
|
851 |
+
self.attn2.set_use_sdpa(sdpa)
|
852 |
+
|
853 |
+
def forward(self, hidden_states, context=None, timestep=None):
|
854 |
+
# 1. Self-Attention
|
855 |
+
norm_hidden_states = self.norm1(hidden_states)
|
856 |
+
|
857 |
+
hidden_states = self.attn1(norm_hidden_states) + hidden_states
|
858 |
+
|
859 |
+
# 2. Cross-Attention
|
860 |
+
norm_hidden_states = self.norm2(hidden_states)
|
861 |
+
hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
|
862 |
+
|
863 |
+
# 3. Feed-forward
|
864 |
+
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
865 |
+
|
866 |
+
return hidden_states
|
867 |
+
|
868 |
+
|
869 |
+
class Transformer2DModel(nn.Module):
|
870 |
+
def __init__(
|
871 |
+
self,
|
872 |
+
num_attention_heads: int = 16,
|
873 |
+
attention_head_dim: int = 88,
|
874 |
+
in_channels: Optional[int] = None,
|
875 |
+
cross_attention_dim: Optional[int] = None,
|
876 |
+
use_linear_projection: bool = False,
|
877 |
+
upcast_attention: bool = False,
|
878 |
+
):
|
879 |
+
super().__init__()
|
880 |
+
self.in_channels = in_channels
|
881 |
+
self.num_attention_heads = num_attention_heads
|
882 |
+
self.attention_head_dim = attention_head_dim
|
883 |
+
inner_dim = num_attention_heads * attention_head_dim
|
884 |
+
self.use_linear_projection = use_linear_projection
|
885 |
+
|
886 |
+
self.norm = torch.nn.GroupNorm(num_groups=TRANSFORMER_NORM_NUM_GROUPS, num_channels=in_channels, eps=1e-6, affine=True)
|
887 |
+
|
888 |
+
if use_linear_projection:
|
889 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
890 |
+
else:
|
891 |
+
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
892 |
+
|
893 |
+
self.transformer_blocks = nn.ModuleList(
|
894 |
+
[
|
895 |
+
BasicTransformerBlock(
|
896 |
+
inner_dim,
|
897 |
+
num_attention_heads,
|
898 |
+
attention_head_dim,
|
899 |
+
cross_attention_dim=cross_attention_dim,
|
900 |
+
upcast_attention=upcast_attention,
|
901 |
+
)
|
902 |
+
]
|
903 |
+
)
|
904 |
+
|
905 |
+
if use_linear_projection:
|
906 |
+
self.proj_out = nn.Linear(in_channels, inner_dim)
|
907 |
+
else:
|
908 |
+
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
909 |
+
|
910 |
+
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
911 |
+
for transformer in self.transformer_blocks:
|
912 |
+
transformer.set_use_memory_efficient_attention(xformers, mem_eff)
|
913 |
+
|
914 |
+
def set_use_sdpa(self, sdpa):
|
915 |
+
for transformer in self.transformer_blocks:
|
916 |
+
transformer.set_use_sdpa(sdpa)
|
917 |
+
|
918 |
+
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
|
919 |
+
# 1. Input
|
920 |
+
batch, _, height, weight = hidden_states.shape
|
921 |
+
residual = hidden_states
|
922 |
+
|
923 |
+
hidden_states = self.norm(hidden_states)
|
924 |
+
if not self.use_linear_projection:
|
925 |
+
hidden_states = self.proj_in(hidden_states)
|
926 |
+
inner_dim = hidden_states.shape[1]
|
927 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
928 |
+
else:
|
929 |
+
inner_dim = hidden_states.shape[1]
|
930 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
931 |
+
hidden_states = self.proj_in(hidden_states)
|
932 |
+
|
933 |
+
# 2. Blocks
|
934 |
+
for block in self.transformer_blocks:
|
935 |
+
hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
|
936 |
+
|
937 |
+
# 3. Output
|
938 |
+
if not self.use_linear_projection:
|
939 |
+
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
940 |
+
hidden_states = self.proj_out(hidden_states)
|
941 |
+
else:
|
942 |
+
hidden_states = self.proj_out(hidden_states)
|
943 |
+
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
944 |
+
|
945 |
+
output = hidden_states + residual
|
946 |
+
|
947 |
+
if not return_dict:
|
948 |
+
return (output,)
|
949 |
+
|
950 |
+
return SampleOutput(sample=output)
|
951 |
+
|
952 |
+
|
953 |
+
class CrossAttnDownBlock2D(nn.Module):
|
954 |
+
def __init__(
|
955 |
+
self,
|
956 |
+
in_channels: int,
|
957 |
+
out_channels: int,
|
958 |
+
add_downsample=True,
|
959 |
+
cross_attention_dim=1280,
|
960 |
+
attn_num_head_channels=1,
|
961 |
+
use_linear_projection=False,
|
962 |
+
upcast_attention=False,
|
963 |
+
):
|
964 |
+
super().__init__()
|
965 |
+
self.has_cross_attention = True
|
966 |
+
resnets = []
|
967 |
+
attentions = []
|
968 |
+
|
969 |
+
self.attn_num_head_channels = attn_num_head_channels
|
970 |
+
|
971 |
+
for i in range(LAYERS_PER_BLOCK):
|
972 |
+
in_channels = in_channels if i == 0 else out_channels
|
973 |
+
|
974 |
+
resnets.append(ResnetBlock2D(in_channels=in_channels, out_channels=out_channels))
|
975 |
+
attentions.append(
|
976 |
+
Transformer2DModel(
|
977 |
+
attn_num_head_channels,
|
978 |
+
out_channels // attn_num_head_channels,
|
979 |
+
in_channels=out_channels,
|
980 |
+
cross_attention_dim=cross_attention_dim,
|
981 |
+
use_linear_projection=use_linear_projection,
|
982 |
+
upcast_attention=upcast_attention,
|
983 |
+
)
|
984 |
+
)
|
985 |
+
self.attentions = nn.ModuleList(attentions)
|
986 |
+
self.resnets = nn.ModuleList(resnets)
|
987 |
+
|
988 |
+
if add_downsample:
|
989 |
+
self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels)])
|
990 |
+
else:
|
991 |
+
self.downsamplers = None
|
992 |
+
|
993 |
+
self.gradient_checkpointing = False
|
994 |
+
|
995 |
+
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
996 |
+
for attn in self.attentions:
|
997 |
+
attn.set_use_memory_efficient_attention(xformers, mem_eff)
|
998 |
+
|
999 |
+
def set_use_sdpa(self, sdpa):
|
1000 |
+
for attn in self.attentions:
|
1001 |
+
attn.set_use_sdpa(sdpa)
|
1002 |
+
|
1003 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
1004 |
+
output_states = ()
|
1005 |
+
|
1006 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
1007 |
+
if self.training and self.gradient_checkpointing:
|
1008 |
+
|
1009 |
+
def create_custom_forward(module, return_dict=None):
|
1010 |
+
def custom_forward(*inputs):
|
1011 |
+
if return_dict is not None:
|
1012 |
+
return module(*inputs, return_dict=return_dict)
|
1013 |
+
else:
|
1014 |
+
return module(*inputs)
|
1015 |
+
|
1016 |
+
return custom_forward
|
1017 |
+
|
1018 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
1019 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1020 |
+
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
1021 |
+
)[0]
|
1022 |
+
else:
|
1023 |
+
hidden_states = resnet(hidden_states, temb)
|
1024 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
1025 |
+
|
1026 |
+
output_states += (hidden_states,)
|
1027 |
+
|
1028 |
+
if self.downsamplers is not None:
|
1029 |
+
for downsampler in self.downsamplers:
|
1030 |
+
hidden_states = downsampler(hidden_states)
|
1031 |
+
|
1032 |
+
output_states += (hidden_states,)
|
1033 |
+
|
1034 |
+
return hidden_states, output_states
|
1035 |
+
|
1036 |
+
|
1037 |
+
class UNetMidBlock2DCrossAttn(nn.Module):
|
1038 |
+
def __init__(
|
1039 |
+
self,
|
1040 |
+
in_channels: int,
|
1041 |
+
attn_num_head_channels=1,
|
1042 |
+
cross_attention_dim=1280,
|
1043 |
+
use_linear_projection=False,
|
1044 |
+
):
|
1045 |
+
super().__init__()
|
1046 |
+
|
1047 |
+
self.has_cross_attention = True
|
1048 |
+
self.attn_num_head_channels = attn_num_head_channels
|
1049 |
+
|
1050 |
+
# Middle block has two resnets and one attention
|
1051 |
+
resnets = [
|
1052 |
+
ResnetBlock2D(
|
1053 |
+
in_channels=in_channels,
|
1054 |
+
out_channels=in_channels,
|
1055 |
+
),
|
1056 |
+
ResnetBlock2D(
|
1057 |
+
in_channels=in_channels,
|
1058 |
+
out_channels=in_channels,
|
1059 |
+
),
|
1060 |
+
]
|
1061 |
+
attentions = [
|
1062 |
+
Transformer2DModel(
|
1063 |
+
attn_num_head_channels,
|
1064 |
+
in_channels // attn_num_head_channels,
|
1065 |
+
in_channels=in_channels,
|
1066 |
+
cross_attention_dim=cross_attention_dim,
|
1067 |
+
use_linear_projection=use_linear_projection,
|
1068 |
+
)
|
1069 |
+
]
|
1070 |
+
|
1071 |
+
self.attentions = nn.ModuleList(attentions)
|
1072 |
+
self.resnets = nn.ModuleList(resnets)
|
1073 |
+
|
1074 |
+
self.gradient_checkpointing = False
|
1075 |
+
|
1076 |
+
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
1077 |
+
for attn in self.attentions:
|
1078 |
+
attn.set_use_memory_efficient_attention(xformers, mem_eff)
|
1079 |
+
|
1080 |
+
def set_use_sdpa(self, sdpa):
|
1081 |
+
for attn in self.attentions:
|
1082 |
+
attn.set_use_sdpa(sdpa)
|
1083 |
+
|
1084 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
1085 |
+
for i, resnet in enumerate(self.resnets):
|
1086 |
+
attn = None if i == 0 else self.attentions[i - 1]
|
1087 |
+
|
1088 |
+
if self.training and self.gradient_checkpointing:
|
1089 |
+
|
1090 |
+
def create_custom_forward(module, return_dict=None):
|
1091 |
+
def custom_forward(*inputs):
|
1092 |
+
if return_dict is not None:
|
1093 |
+
return module(*inputs, return_dict=return_dict)
|
1094 |
+
else:
|
1095 |
+
return module(*inputs)
|
1096 |
+
|
1097 |
+
return custom_forward
|
1098 |
+
|
1099 |
+
if attn is not None:
|
1100 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1101 |
+
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
1102 |
+
)[0]
|
1103 |
+
|
1104 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
1105 |
+
else:
|
1106 |
+
if attn is not None:
|
1107 |
+
hidden_states = attn(hidden_states, encoder_hidden_states).sample
|
1108 |
+
hidden_states = resnet(hidden_states, temb)
|
1109 |
+
|
1110 |
+
return hidden_states
|
1111 |
+
|
1112 |
+
|
1113 |
+
class Upsample2D(nn.Module):
|
1114 |
+
def __init__(self, channels, out_channels):
|
1115 |
+
super().__init__()
|
1116 |
+
self.channels = channels
|
1117 |
+
self.out_channels = out_channels
|
1118 |
+
self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
|
1119 |
+
|
1120 |
+
def forward(self, hidden_states, output_size):
|
1121 |
+
assert hidden_states.shape[1] == self.channels
|
1122 |
+
|
1123 |
+
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
1124 |
+
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
|
1125 |
+
# https://github.com/pytorch/pytorch/issues/86679
|
1126 |
+
dtype = hidden_states.dtype
|
1127 |
+
if dtype == torch.bfloat16:
|
1128 |
+
hidden_states = hidden_states.to(torch.float32)
|
1129 |
+
|
1130 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
1131 |
+
if hidden_states.shape[0] >= 64:
|
1132 |
+
hidden_states = hidden_states.contiguous()
|
1133 |
+
|
1134 |
+
# if `output_size` is passed we force the interpolation output size and do not make use of `scale_factor=2`
|
1135 |
+
if output_size is None:
|
1136 |
+
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
1137 |
+
else:
|
1138 |
+
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
|
1139 |
+
|
1140 |
+
# If the input is bfloat16, we cast back to bfloat16
|
1141 |
+
if dtype == torch.bfloat16:
|
1142 |
+
hidden_states = hidden_states.to(dtype)
|
1143 |
+
|
1144 |
+
hidden_states = self.conv(hidden_states)
|
1145 |
+
|
1146 |
+
return hidden_states
|
1147 |
+
|
1148 |
+
|
1149 |
+
class UpBlock2D(nn.Module):
|
1150 |
+
def __init__(
|
1151 |
+
self,
|
1152 |
+
in_channels: int,
|
1153 |
+
prev_output_channel: int,
|
1154 |
+
out_channels: int,
|
1155 |
+
add_upsample=True,
|
1156 |
+
):
|
1157 |
+
super().__init__()
|
1158 |
+
|
1159 |
+
self.has_cross_attention = False
|
1160 |
+
resnets = []
|
1161 |
+
|
1162 |
+
for i in range(LAYERS_PER_BLOCK_UP):
|
1163 |
+
res_skip_channels = in_channels if (i == LAYERS_PER_BLOCK_UP - 1) else out_channels
|
1164 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1165 |
+
|
1166 |
+
resnets.append(
|
1167 |
+
ResnetBlock2D(
|
1168 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1169 |
+
out_channels=out_channels,
|
1170 |
+
)
|
1171 |
+
)
|
1172 |
+
|
1173 |
+
self.resnets = nn.ModuleList(resnets)
|
1174 |
+
|
1175 |
+
if add_upsample:
|
1176 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)])
|
1177 |
+
else:
|
1178 |
+
self.upsamplers = None
|
1179 |
+
|
1180 |
+
self.gradient_checkpointing = False
|
1181 |
+
|
1182 |
+
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
1183 |
+
pass
|
1184 |
+
|
1185 |
+
def set_use_sdpa(self, sdpa):
|
1186 |
+
pass
|
1187 |
+
|
1188 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
|
1189 |
+
for resnet in self.resnets:
|
1190 |
+
# pop res hidden states
|
1191 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1192 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1193 |
+
|
1194 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1195 |
+
|
1196 |
+
if self.training and self.gradient_checkpointing:
|
1197 |
+
|
1198 |
+
def create_custom_forward(module):
|
1199 |
+
def custom_forward(*inputs):
|
1200 |
+
return module(*inputs)
|
1201 |
+
|
1202 |
+
return custom_forward
|
1203 |
+
|
1204 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
1205 |
+
else:
|
1206 |
+
hidden_states = resnet(hidden_states, temb)
|
1207 |
+
|
1208 |
+
if self.upsamplers is not None:
|
1209 |
+
for upsampler in self.upsamplers:
|
1210 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
1211 |
+
|
1212 |
+
return hidden_states
|
1213 |
+
|
1214 |
+
|
1215 |
+
class CrossAttnUpBlock2D(nn.Module):
|
1216 |
+
def __init__(
|
1217 |
+
self,
|
1218 |
+
in_channels: int,
|
1219 |
+
out_channels: int,
|
1220 |
+
prev_output_channel: int,
|
1221 |
+
attn_num_head_channels=1,
|
1222 |
+
cross_attention_dim=1280,
|
1223 |
+
add_upsample=True,
|
1224 |
+
use_linear_projection=False,
|
1225 |
+
upcast_attention=False,
|
1226 |
+
):
|
1227 |
+
super().__init__()
|
1228 |
+
resnets = []
|
1229 |
+
attentions = []
|
1230 |
+
|
1231 |
+
self.has_cross_attention = True
|
1232 |
+
self.attn_num_head_channels = attn_num_head_channels
|
1233 |
+
|
1234 |
+
for i in range(LAYERS_PER_BLOCK_UP):
|
1235 |
+
res_skip_channels = in_channels if (i == LAYERS_PER_BLOCK_UP - 1) else out_channels
|
1236 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
1237 |
+
|
1238 |
+
resnets.append(
|
1239 |
+
ResnetBlock2D(
|
1240 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
1241 |
+
out_channels=out_channels,
|
1242 |
+
)
|
1243 |
+
)
|
1244 |
+
attentions.append(
|
1245 |
+
Transformer2DModel(
|
1246 |
+
attn_num_head_channels,
|
1247 |
+
out_channels // attn_num_head_channels,
|
1248 |
+
in_channels=out_channels,
|
1249 |
+
cross_attention_dim=cross_attention_dim,
|
1250 |
+
use_linear_projection=use_linear_projection,
|
1251 |
+
upcast_attention=upcast_attention,
|
1252 |
+
)
|
1253 |
+
)
|
1254 |
+
|
1255 |
+
self.attentions = nn.ModuleList(attentions)
|
1256 |
+
self.resnets = nn.ModuleList(resnets)
|
1257 |
+
|
1258 |
+
if add_upsample:
|
1259 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)])
|
1260 |
+
else:
|
1261 |
+
self.upsamplers = None
|
1262 |
+
|
1263 |
+
self.gradient_checkpointing = False
|
1264 |
+
|
1265 |
+
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
1266 |
+
for attn in self.attentions:
|
1267 |
+
attn.set_use_memory_efficient_attention(xformers, mem_eff)
|
1268 |
+
|
1269 |
+
def set_use_sdpa(self, sdpa):
|
1270 |
+
for attn in self.attentions:
|
1271 |
+
attn.set_use_sdpa(sdpa)
|
1272 |
+
|
1273 |
+
def forward(
|
1274 |
+
self,
|
1275 |
+
hidden_states,
|
1276 |
+
res_hidden_states_tuple,
|
1277 |
+
temb=None,
|
1278 |
+
encoder_hidden_states=None,
|
1279 |
+
upsample_size=None,
|
1280 |
+
):
|
1281 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
1282 |
+
# pop res hidden states
|
1283 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1284 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1285 |
+
|
1286 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1287 |
+
|
1288 |
+
if self.training and self.gradient_checkpointing:
|
1289 |
+
|
1290 |
+
def create_custom_forward(module, return_dict=None):
|
1291 |
+
def custom_forward(*inputs):
|
1292 |
+
if return_dict is not None:
|
1293 |
+
return module(*inputs, return_dict=return_dict)
|
1294 |
+
else:
|
1295 |
+
return module(*inputs)
|
1296 |
+
|
1297 |
+
return custom_forward
|
1298 |
+
|
1299 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
1300 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1301 |
+
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
1302 |
+
)[0]
|
1303 |
+
else:
|
1304 |
+
hidden_states = resnet(hidden_states, temb)
|
1305 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
1306 |
+
|
1307 |
+
if self.upsamplers is not None:
|
1308 |
+
for upsampler in self.upsamplers:
|
1309 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
1310 |
+
|
1311 |
+
return hidden_states
|
1312 |
+
|
1313 |
+
|
1314 |
+
def get_down_block(
|
1315 |
+
down_block_type,
|
1316 |
+
in_channels,
|
1317 |
+
out_channels,
|
1318 |
+
add_downsample,
|
1319 |
+
attn_num_head_channels,
|
1320 |
+
cross_attention_dim,
|
1321 |
+
use_linear_projection,
|
1322 |
+
upcast_attention,
|
1323 |
+
):
|
1324 |
+
if down_block_type == "DownBlock2D":
|
1325 |
+
return DownBlock2D(
|
1326 |
+
in_channels=in_channels,
|
1327 |
+
out_channels=out_channels,
|
1328 |
+
add_downsample=add_downsample,
|
1329 |
+
)
|
1330 |
+
elif down_block_type == "CrossAttnDownBlock2D":
|
1331 |
+
return CrossAttnDownBlock2D(
|
1332 |
+
in_channels=in_channels,
|
1333 |
+
out_channels=out_channels,
|
1334 |
+
add_downsample=add_downsample,
|
1335 |
+
cross_attention_dim=cross_attention_dim,
|
1336 |
+
attn_num_head_channels=attn_num_head_channels,
|
1337 |
+
use_linear_projection=use_linear_projection,
|
1338 |
+
upcast_attention=upcast_attention,
|
1339 |
+
)
|
1340 |
+
|
1341 |
+
|
1342 |
+
def get_up_block(
|
1343 |
+
up_block_type,
|
1344 |
+
in_channels,
|
1345 |
+
out_channels,
|
1346 |
+
prev_output_channel,
|
1347 |
+
add_upsample,
|
1348 |
+
attn_num_head_channels,
|
1349 |
+
cross_attention_dim=None,
|
1350 |
+
use_linear_projection=False,
|
1351 |
+
upcast_attention=False,
|
1352 |
+
):
|
1353 |
+
if up_block_type == "UpBlock2D":
|
1354 |
+
return UpBlock2D(
|
1355 |
+
in_channels=in_channels,
|
1356 |
+
prev_output_channel=prev_output_channel,
|
1357 |
+
out_channels=out_channels,
|
1358 |
+
add_upsample=add_upsample,
|
1359 |
+
)
|
1360 |
+
elif up_block_type == "CrossAttnUpBlock2D":
|
1361 |
+
return CrossAttnUpBlock2D(
|
1362 |
+
in_channels=in_channels,
|
1363 |
+
out_channels=out_channels,
|
1364 |
+
prev_output_channel=prev_output_channel,
|
1365 |
+
attn_num_head_channels=attn_num_head_channels,
|
1366 |
+
cross_attention_dim=cross_attention_dim,
|
1367 |
+
add_upsample=add_upsample,
|
1368 |
+
use_linear_projection=use_linear_projection,
|
1369 |
+
upcast_attention=upcast_attention,
|
1370 |
+
)
|
1371 |
+
|
1372 |
+
|
1373 |
+
class UNet2DConditionModel(nn.Module):
|
1374 |
+
_supports_gradient_checkpointing = True
|
1375 |
+
|
1376 |
+
def __init__(
|
1377 |
+
self,
|
1378 |
+
sample_size: Optional[int] = None,
|
1379 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
1380 |
+
cross_attention_dim: int = 1280,
|
1381 |
+
use_linear_projection: bool = False,
|
1382 |
+
upcast_attention: bool = False,
|
1383 |
+
**kwargs,
|
1384 |
+
):
|
1385 |
+
super().__init__()
|
1386 |
+
assert sample_size is not None, "sample_size must be specified"
|
1387 |
+
logger.info(
|
1388 |
+
f"UNet2DConditionModel: {sample_size}, {attention_head_dim}, {cross_attention_dim}, {use_linear_projection}, {upcast_attention}"
|
1389 |
+
)
|
1390 |
+
|
1391 |
+
# 外部からの参照用に定義しておく
|
1392 |
+
self.in_channels = IN_CHANNELS
|
1393 |
+
self.out_channels = OUT_CHANNELS
|
1394 |
+
|
1395 |
+
self.sample_size = sample_size
|
1396 |
+
self.prepare_config(sample_size=sample_size)
|
1397 |
+
|
1398 |
+
# state_dictの書式が変わるのでmoduleの持ち方は変えられない
|
1399 |
+
|
1400 |
+
# input
|
1401 |
+
self.conv_in = nn.Conv2d(IN_CHANNELS, BLOCK_OUT_CHANNELS[0], kernel_size=3, padding=(1, 1))
|
1402 |
+
|
1403 |
+
# time
|
1404 |
+
self.time_proj = Timesteps(BLOCK_OUT_CHANNELS[0], TIME_EMBED_FLIP_SIN_TO_COS, TIME_EMBED_FREQ_SHIFT)
|
1405 |
+
|
1406 |
+
self.time_embedding = TimestepEmbedding(TIMESTEP_INPUT_DIM, TIME_EMBED_DIM)
|
1407 |
+
|
1408 |
+
self.down_blocks = nn.ModuleList([])
|
1409 |
+
self.mid_block = None
|
1410 |
+
self.up_blocks = nn.ModuleList([])
|
1411 |
+
|
1412 |
+
if isinstance(attention_head_dim, int):
|
1413 |
+
attention_head_dim = (attention_head_dim,) * 4
|
1414 |
+
|
1415 |
+
# down
|
1416 |
+
output_channel = BLOCK_OUT_CHANNELS[0]
|
1417 |
+
for i, down_block_type in enumerate(DOWN_BLOCK_TYPES):
|
1418 |
+
input_channel = output_channel
|
1419 |
+
output_channel = BLOCK_OUT_CHANNELS[i]
|
1420 |
+
is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1
|
1421 |
+
|
1422 |
+
down_block = get_down_block(
|
1423 |
+
down_block_type,
|
1424 |
+
in_channels=input_channel,
|
1425 |
+
out_channels=output_channel,
|
1426 |
+
add_downsample=not is_final_block,
|
1427 |
+
attn_num_head_channels=attention_head_dim[i],
|
1428 |
+
cross_attention_dim=cross_attention_dim,
|
1429 |
+
use_linear_projection=use_linear_projection,
|
1430 |
+
upcast_attention=upcast_attention,
|
1431 |
+
)
|
1432 |
+
self.down_blocks.append(down_block)
|
1433 |
+
|
1434 |
+
# mid
|
1435 |
+
self.mid_block = UNetMidBlock2DCrossAttn(
|
1436 |
+
in_channels=BLOCK_OUT_CHANNELS[-1],
|
1437 |
+
attn_num_head_channels=attention_head_dim[-1],
|
1438 |
+
cross_attention_dim=cross_attention_dim,
|
1439 |
+
use_linear_projection=use_linear_projection,
|
1440 |
+
)
|
1441 |
+
|
1442 |
+
# count how many layers upsample the images
|
1443 |
+
self.num_upsamplers = 0
|
1444 |
+
|
1445 |
+
# up
|
1446 |
+
reversed_block_out_channels = list(reversed(BLOCK_OUT_CHANNELS))
|
1447 |
+
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
1448 |
+
output_channel = reversed_block_out_channels[0]
|
1449 |
+
for i, up_block_type in enumerate(UP_BLOCK_TYPES):
|
1450 |
+
is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1
|
1451 |
+
|
1452 |
+
prev_output_channel = output_channel
|
1453 |
+
output_channel = reversed_block_out_channels[i]
|
1454 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(BLOCK_OUT_CHANNELS) - 1)]
|
1455 |
+
|
1456 |
+
# add upsample block for all BUT final layer
|
1457 |
+
if not is_final_block:
|
1458 |
+
add_upsample = True
|
1459 |
+
self.num_upsamplers += 1
|
1460 |
+
else:
|
1461 |
+
add_upsample = False
|
1462 |
+
|
1463 |
+
up_block = get_up_block(
|
1464 |
+
up_block_type,
|
1465 |
+
in_channels=input_channel,
|
1466 |
+
out_channels=output_channel,
|
1467 |
+
prev_output_channel=prev_output_channel,
|
1468 |
+
add_upsample=add_upsample,
|
1469 |
+
attn_num_head_channels=reversed_attention_head_dim[i],
|
1470 |
+
cross_attention_dim=cross_attention_dim,
|
1471 |
+
use_linear_projection=use_linear_projection,
|
1472 |
+
upcast_attention=upcast_attention,
|
1473 |
+
)
|
1474 |
+
self.up_blocks.append(up_block)
|
1475 |
+
prev_output_channel = output_channel
|
1476 |
+
|
1477 |
+
# out
|
1478 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=BLOCK_OUT_CHANNELS[0], num_groups=NORM_GROUPS, eps=NORM_EPS)
|
1479 |
+
self.conv_act = nn.SiLU()
|
1480 |
+
self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
|
1481 |
+
|
1482 |
+
# region diffusers compatibility
|
1483 |
+
def prepare_config(self, *args, **kwargs):
|
1484 |
+
self.config = SimpleNamespace(**kwargs)
|
1485 |
+
|
1486 |
+
@property
|
1487 |
+
def dtype(self) -> torch.dtype:
|
1488 |
+
# `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
1489 |
+
return get_parameter_dtype(self)
|
1490 |
+
|
1491 |
+
@property
|
1492 |
+
def device(self) -> torch.device:
|
1493 |
+
# `torch.device`: The device on which the module is (assuming that all the module parameters are on the same device).
|
1494 |
+
return get_parameter_device(self)
|
1495 |
+
|
1496 |
+
def set_attention_slice(self, slice_size):
|
1497 |
+
raise NotImplementedError("Attention slicing is not supported for this model.")
|
1498 |
+
|
1499 |
+
def is_gradient_checkpointing(self) -> bool:
|
1500 |
+
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
|
1501 |
+
|
1502 |
+
def enable_gradient_checkpointing(self):
|
1503 |
+
self.set_gradient_checkpointing(value=True)
|
1504 |
+
|
1505 |
+
def disable_gradient_checkpointing(self):
|
1506 |
+
self.set_gradient_checkpointing(value=False)
|
1507 |
+
|
1508 |
+
def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None:
|
1509 |
+
modules = self.down_blocks + [self.mid_block] + self.up_blocks
|
1510 |
+
for module in modules:
|
1511 |
+
module.set_use_memory_efficient_attention(xformers, mem_eff)
|
1512 |
+
|
1513 |
+
def set_use_sdpa(self, sdpa: bool) -> None:
|
1514 |
+
modules = self.down_blocks + [self.mid_block] + self.up_blocks
|
1515 |
+
for module in modules:
|
1516 |
+
module.set_use_sdpa(sdpa)
|
1517 |
+
|
1518 |
+
def set_gradient_checkpointing(self, value=False):
|
1519 |
+
modules = self.down_blocks + [self.mid_block] + self.up_blocks
|
1520 |
+
for module in modules:
|
1521 |
+
logger.info(f"{module.__class__.__name__} {module.gradient_checkpointing} -> {value}")
|
1522 |
+
module.gradient_checkpointing = value
|
1523 |
+
|
1524 |
+
# endregion
|
1525 |
+
|
1526 |
+
def forward(
|
1527 |
+
self,
|
1528 |
+
sample: torch.FloatTensor,
|
1529 |
+
timestep: Union[torch.Tensor, float, int],
|
1530 |
+
encoder_hidden_states: torch.Tensor,
|
1531 |
+
class_labels: Optional[torch.Tensor] = None,
|
1532 |
+
return_dict: bool = True,
|
1533 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
1534 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
1535 |
+
) -> Union[Dict, Tuple]:
|
1536 |
+
r"""
|
1537 |
+
Args:
|
1538 |
+
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
1539 |
+
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
1540 |
+
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
1541 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1542 |
+
Whether or not to return a dict instead of a plain tuple.
|
1543 |
+
|
1544 |
+
Returns:
|
1545 |
+
`SampleOutput` or `tuple`:
|
1546 |
+
`SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
|
1547 |
+
"""
|
1548 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
1549 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
1550 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
1551 |
+
# on the fly if necessary.
|
1552 |
+
# デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある
|
1553 |
+
# ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する
|
1554 |
+
# 多分画質が悪くなるので、64で割り切れるようにしておくのが良い
|
1555 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
1556 |
+
|
1557 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
1558 |
+
# 64で割り切れないときはupsamplerにサイズを伝える
|
1559 |
+
forward_upsample_size = False
|
1560 |
+
upsample_size = None
|
1561 |
+
|
1562 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
1563 |
+
# logger.info("Forward upsample size to force interpolation output size.")
|
1564 |
+
forward_upsample_size = True
|
1565 |
+
|
1566 |
+
# 1. time
|
1567 |
+
timesteps = timestep
|
1568 |
+
timesteps = self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理
|
1569 |
+
|
1570 |
+
t_emb = self.time_proj(timesteps)
|
1571 |
+
|
1572 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
1573 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
1574 |
+
# there might be better ways to encapsulate this.
|
1575 |
+
# timestepsは重みを含まないので常にfloat32のテンソルを返す
|
1576 |
+
# しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある
|
1577 |
+
# time_projでキャストしておけばいいんじゃね?
|
1578 |
+
t_emb = t_emb.to(dtype=self.dtype)
|
1579 |
+
emb = self.time_embedding(t_emb)
|
1580 |
+
|
1581 |
+
# 2. pre-process
|
1582 |
+
sample = self.conv_in(sample)
|
1583 |
+
|
1584 |
+
down_block_res_samples = (sample,)
|
1585 |
+
for downsample_block in self.down_blocks:
|
1586 |
+
# downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
|
1587 |
+
# まあこちらのほうがわかりやすいかもしれない
|
1588 |
+
if downsample_block.has_cross_attention:
|
1589 |
+
sample, res_samples = downsample_block(
|
1590 |
+
hidden_states=sample,
|
1591 |
+
temb=emb,
|
1592 |
+
encoder_hidden_states=encoder_hidden_states,
|
1593 |
+
)
|
1594 |
+
else:
|
1595 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
1596 |
+
|
1597 |
+
down_block_res_samples += res_samples
|
1598 |
+
|
1599 |
+
# skip connectionにControlNetの出力を追加する
|
1600 |
+
if down_block_additional_residuals is not None:
|
1601 |
+
down_block_res_samples = list(down_block_res_samples)
|
1602 |
+
for i in range(len(down_block_res_samples)):
|
1603 |
+
down_block_res_samples[i] += down_block_additional_residuals[i]
|
1604 |
+
down_block_res_samples = tuple(down_block_res_samples)
|
1605 |
+
|
1606 |
+
# 4. mid
|
1607 |
+
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
|
1608 |
+
|
1609 |
+
# ControlNetの出力を追加する
|
1610 |
+
if mid_block_additional_residual is not None:
|
1611 |
+
sample += mid_block_additional_residual
|
1612 |
+
|
1613 |
+
# 5. up
|
1614 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
1615 |
+
is_final_block = i == len(self.up_blocks) - 1
|
1616 |
+
|
1617 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
1618 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection
|
1619 |
+
|
1620 |
+
# if we have not reached the final block and need to forward the upsample size, we do it here
|
1621 |
+
# 前述のように最後のブロック以外ではupsample_sizeを伝える
|
1622 |
+
if not is_final_block and forward_upsample_size:
|
1623 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
1624 |
+
|
1625 |
+
if upsample_block.has_cross_attention:
|
1626 |
+
sample = upsample_block(
|
1627 |
+
hidden_states=sample,
|
1628 |
+
temb=emb,
|
1629 |
+
res_hidden_states_tuple=res_samples,
|
1630 |
+
encoder_hidden_states=encoder_hidden_states,
|
1631 |
+
upsample_size=upsample_size,
|
1632 |
+
)
|
1633 |
+
else:
|
1634 |
+
sample = upsample_block(
|
1635 |
+
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
1636 |
+
)
|
1637 |
+
|
1638 |
+
# 6. post-process
|
1639 |
+
sample = self.conv_norm_out(sample)
|
1640 |
+
sample = self.conv_act(sample)
|
1641 |
+
sample = self.conv_out(sample)
|
1642 |
+
|
1643 |
+
if not return_dict:
|
1644 |
+
return (sample,)
|
1645 |
+
|
1646 |
+
return SampleOutput(sample=sample)
|
1647 |
+
|
1648 |
+
def handle_unusual_timesteps(self, sample, timesteps):
|
1649 |
+
r"""
|
1650 |
+
timestampsがTensorでない場合、Tensorに変換する。またOnnx/Core MLと互換性のあるようにbatchサイズまでbroadcastする。
|
1651 |
+
"""
|
1652 |
+
if not torch.is_tensor(timesteps):
|
1653 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
1654 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
1655 |
+
is_mps = sample.device.type == "mps"
|
1656 |
+
if isinstance(timesteps, float):
|
1657 |
+
dtype = torch.float32 if is_mps else torch.float64
|
1658 |
+
else:
|
1659 |
+
dtype = torch.int32 if is_mps else torch.int64
|
1660 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
1661 |
+
elif len(timesteps.shape) == 0:
|
1662 |
+
timesteps = timesteps[None].to(sample.device)
|
1663 |
+
|
1664 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
1665 |
+
timesteps = timesteps.expand(sample.shape[0])
|
1666 |
+
|
1667 |
+
return timesteps
|
1668 |
+
|
1669 |
+
|
1670 |
+
class InferUNet2DConditionModel:
|
1671 |
+
def __init__(self, original_unet: UNet2DConditionModel):
|
1672 |
+
self.delegate = original_unet
|
1673 |
+
|
1674 |
+
# override original model's forward method: because forward is not called by `__call__`
|
1675 |
+
# overriding `__call__` is not enough, because nn.Module.forward has a special handling
|
1676 |
+
self.delegate.forward = self.forward
|
1677 |
+
|
1678 |
+
# override original model's up blocks' forward method
|
1679 |
+
for up_block in self.delegate.up_blocks:
|
1680 |
+
if up_block.__class__.__name__ == "UpBlock2D":
|
1681 |
+
|
1682 |
+
def resnet_wrapper(func, block):
|
1683 |
+
def forward(*args, **kwargs):
|
1684 |
+
return func(block, *args, **kwargs)
|
1685 |
+
|
1686 |
+
return forward
|
1687 |
+
|
1688 |
+
up_block.forward = resnet_wrapper(self.up_block_forward, up_block)
|
1689 |
+
|
1690 |
+
elif up_block.__class__.__name__ == "CrossAttnUpBlock2D":
|
1691 |
+
|
1692 |
+
def cross_attn_up_wrapper(func, block):
|
1693 |
+
def forward(*args, **kwargs):
|
1694 |
+
return func(block, *args, **kwargs)
|
1695 |
+
|
1696 |
+
return forward
|
1697 |
+
|
1698 |
+
up_block.forward = cross_attn_up_wrapper(self.cross_attn_up_block_forward, up_block)
|
1699 |
+
|
1700 |
+
# Deep Shrink
|
1701 |
+
self.ds_depth_1 = None
|
1702 |
+
self.ds_depth_2 = None
|
1703 |
+
self.ds_timesteps_1 = None
|
1704 |
+
self.ds_timesteps_2 = None
|
1705 |
+
self.ds_ratio = None
|
1706 |
+
|
1707 |
+
# call original model's methods
|
1708 |
+
def __getattr__(self, name):
|
1709 |
+
return getattr(self.delegate, name)
|
1710 |
+
|
1711 |
+
def __call__(self, *args, **kwargs):
|
1712 |
+
return self.delegate(*args, **kwargs)
|
1713 |
+
|
1714 |
+
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
|
1715 |
+
if ds_depth_1 is None:
|
1716 |
+
logger.info("Deep Shrink is disabled.")
|
1717 |
+
self.ds_depth_1 = None
|
1718 |
+
self.ds_timesteps_1 = None
|
1719 |
+
self.ds_depth_2 = None
|
1720 |
+
self.ds_timesteps_2 = None
|
1721 |
+
self.ds_ratio = None
|
1722 |
+
else:
|
1723 |
+
logger.info(
|
1724 |
+
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
|
1725 |
+
)
|
1726 |
+
self.ds_depth_1 = ds_depth_1
|
1727 |
+
self.ds_timesteps_1 = ds_timesteps_1
|
1728 |
+
self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
|
1729 |
+
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
|
1730 |
+
self.ds_ratio = ds_ratio
|
1731 |
+
|
1732 |
+
def up_block_forward(self, _self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
|
1733 |
+
for resnet in _self.resnets:
|
1734 |
+
# pop res hidden states
|
1735 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1736 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1737 |
+
|
1738 |
+
# Deep Shrink
|
1739 |
+
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
|
1740 |
+
hidden_states = resize_like(hidden_states, res_hidden_states)
|
1741 |
+
|
1742 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1743 |
+
hidden_states = resnet(hidden_states, temb)
|
1744 |
+
|
1745 |
+
if _self.upsamplers is not None:
|
1746 |
+
for upsampler in _self.upsamplers:
|
1747 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
1748 |
+
|
1749 |
+
return hidden_states
|
1750 |
+
|
1751 |
+
def cross_attn_up_block_forward(
|
1752 |
+
self,
|
1753 |
+
_self,
|
1754 |
+
hidden_states,
|
1755 |
+
res_hidden_states_tuple,
|
1756 |
+
temb=None,
|
1757 |
+
encoder_hidden_states=None,
|
1758 |
+
upsample_size=None,
|
1759 |
+
):
|
1760 |
+
for resnet, attn in zip(_self.resnets, _self.attentions):
|
1761 |
+
# pop res hidden states
|
1762 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
1763 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
1764 |
+
|
1765 |
+
# Deep Shrink
|
1766 |
+
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
|
1767 |
+
hidden_states = resize_like(hidden_states, res_hidden_states)
|
1768 |
+
|
1769 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1770 |
+
hidden_states = resnet(hidden_states, temb)
|
1771 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
1772 |
+
|
1773 |
+
if _self.upsamplers is not None:
|
1774 |
+
for upsampler in _self.upsamplers:
|
1775 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
1776 |
+
|
1777 |
+
return hidden_states
|
1778 |
+
|
1779 |
+
def forward(
|
1780 |
+
self,
|
1781 |
+
sample: torch.FloatTensor,
|
1782 |
+
timestep: Union[torch.Tensor, float, int],
|
1783 |
+
encoder_hidden_states: torch.Tensor,
|
1784 |
+
class_labels: Optional[torch.Tensor] = None,
|
1785 |
+
return_dict: bool = True,
|
1786 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
1787 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
1788 |
+
) -> Union[Dict, Tuple]:
|
1789 |
+
r"""
|
1790 |
+
current implementation is a copy of `UNet2DConditionModel.forward()` with Deep Shrink.
|
1791 |
+
"""
|
1792 |
+
|
1793 |
+
r"""
|
1794 |
+
Args:
|
1795 |
+
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
1796 |
+
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
1797 |
+
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
1798 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1799 |
+
Whether or not to return a dict instead of a plain tuple.
|
1800 |
+
|
1801 |
+
Returns:
|
1802 |
+
`SampleOutput` or `tuple`:
|
1803 |
+
`SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
|
1804 |
+
"""
|
1805 |
+
|
1806 |
+
_self = self.delegate
|
1807 |
+
|
1808 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
1809 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
1810 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
1811 |
+
# on the fly if necessary.
|
1812 |
+
# デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある
|
1813 |
+
# ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する
|
1814 |
+
# 多分画質が悪くなるので、64で割り切れるようにしておくのが良い
|
1815 |
+
default_overall_up_factor = 2**_self.num_upsamplers
|
1816 |
+
|
1817 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
1818 |
+
# 64で割り切れないときはupsamplerにサイズを伝える
|
1819 |
+
forward_upsample_size = False
|
1820 |
+
upsample_size = None
|
1821 |
+
|
1822 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
1823 |
+
# logger.info("Forward upsample size to force interpolation output size.")
|
1824 |
+
forward_upsample_size = True
|
1825 |
+
|
1826 |
+
# 1. time
|
1827 |
+
timesteps = timestep
|
1828 |
+
timesteps = _self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理
|
1829 |
+
|
1830 |
+
t_emb = _self.time_proj(timesteps)
|
1831 |
+
|
1832 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
1833 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
1834 |
+
# there might be better ways to encapsulate this.
|
1835 |
+
# timestepsは重みを含まないので常にfloat32のテンソルを返す
|
1836 |
+
# しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある
|
1837 |
+
# time_projでキャストしておけばいいんじゃね?
|
1838 |
+
t_emb = t_emb.to(dtype=_self.dtype)
|
1839 |
+
emb = _self.time_embedding(t_emb)
|
1840 |
+
|
1841 |
+
# 2. pre-process
|
1842 |
+
sample = _self.conv_in(sample)
|
1843 |
+
|
1844 |
+
down_block_res_samples = (sample,)
|
1845 |
+
for depth, downsample_block in enumerate(_self.down_blocks):
|
1846 |
+
# Deep Shrink
|
1847 |
+
if self.ds_depth_1 is not None:
|
1848 |
+
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
|
1849 |
+
self.ds_depth_2 is not None
|
1850 |
+
and depth == self.ds_depth_2
|
1851 |
+
and timesteps[0] < self.ds_timesteps_1
|
1852 |
+
and timesteps[0] >= self.ds_timesteps_2
|
1853 |
+
):
|
1854 |
+
org_dtype = sample.dtype
|
1855 |
+
if org_dtype == torch.bfloat16:
|
1856 |
+
sample = sample.to(torch.float32)
|
1857 |
+
sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
|
1858 |
+
|
1859 |
+
# downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
|
1860 |
+
# まあこちらのほうがわかりやすいかもしれない
|
1861 |
+
if downsample_block.has_cross_attention:
|
1862 |
+
sample, res_samples = downsample_block(
|
1863 |
+
hidden_states=sample,
|
1864 |
+
temb=emb,
|
1865 |
+
encoder_hidden_states=encoder_hidden_states,
|
1866 |
+
)
|
1867 |
+
else:
|
1868 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
1869 |
+
|
1870 |
+
down_block_res_samples += res_samples
|
1871 |
+
|
1872 |
+
# skip connectionにControlNetの出力を追加する
|
1873 |
+
if down_block_additional_residuals is not None:
|
1874 |
+
down_block_res_samples = list(down_block_res_samples)
|
1875 |
+
for i in range(len(down_block_res_samples)):
|
1876 |
+
down_block_res_samples[i] += down_block_additional_residuals[i]
|
1877 |
+
down_block_res_samples = tuple(down_block_res_samples)
|
1878 |
+
|
1879 |
+
# 4. mid
|
1880 |
+
sample = _self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
|
1881 |
+
|
1882 |
+
# ControlNetの出力を追加する
|
1883 |
+
if mid_block_additional_residual is not None:
|
1884 |
+
sample += mid_block_additional_residual
|
1885 |
+
|
1886 |
+
# 5. up
|
1887 |
+
for i, upsample_block in enumerate(_self.up_blocks):
|
1888 |
+
is_final_block = i == len(_self.up_blocks) - 1
|
1889 |
+
|
1890 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
1891 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection
|
1892 |
+
|
1893 |
+
# if we have not reached the final block and need to forward the upsample size, we do it here
|
1894 |
+
# 前述のように最後のブロック以外ではupsample_sizeを伝える
|
1895 |
+
if not is_final_block and forward_upsample_size:
|
1896 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
1897 |
+
|
1898 |
+
if upsample_block.has_cross_attention:
|
1899 |
+
sample = upsample_block(
|
1900 |
+
hidden_states=sample,
|
1901 |
+
temb=emb,
|
1902 |
+
res_hidden_states_tuple=res_samples,
|
1903 |
+
encoder_hidden_states=encoder_hidden_states,
|
1904 |
+
upsample_size=upsample_size,
|
1905 |
+
)
|
1906 |
+
else:
|
1907 |
+
sample = upsample_block(
|
1908 |
+
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
1909 |
+
)
|
1910 |
+
|
1911 |
+
# 6. post-process
|
1912 |
+
sample = _self.conv_norm_out(sample)
|
1913 |
+
sample = _self.conv_act(sample)
|
1914 |
+
sample = _self.conv_out(sample)
|
1915 |
+
|
1916 |
+
if not return_dict:
|
1917 |
+
return (sample,)
|
1918 |
+
|
1919 |
+
return SampleOutput(sample=sample)
|
library/sai_model_spec.py
ADDED
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# based on https://github.com/Stability-AI/ModelSpec
|
2 |
+
import datetime
|
3 |
+
import hashlib
|
4 |
+
from io import BytesIO
|
5 |
+
import os
|
6 |
+
from typing import List, Optional, Tuple, Union
|
7 |
+
import safetensors
|
8 |
+
from library.utils import setup_logging
|
9 |
+
|
10 |
+
setup_logging()
|
11 |
+
import logging
|
12 |
+
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
r"""
|
16 |
+
# Metadata Example
|
17 |
+
metadata = {
|
18 |
+
# === Must ===
|
19 |
+
"modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
|
20 |
+
"modelspec.architecture": "stable-diffusion-xl-v1-base", # Architecture, reference the ID of the original model of the arch to match the ID
|
21 |
+
"modelspec.implementation": "sgm",
|
22 |
+
"modelspec.title": "Example Model Version 1.0", # Clean, human-readable title. May use your own phrasing/language/etc
|
23 |
+
# === Should ===
|
24 |
+
"modelspec.author": "Example Corp", # Your name or company name
|
25 |
+
"modelspec.description": "This is my example model to show you how to do it!", # Describe the model in your own words/language/etc. Focus on what users need to know
|
26 |
+
"modelspec.date": "2023-07-20", # ISO-8601 compliant date of when the model was created
|
27 |
+
# === Can ===
|
28 |
+
"modelspec.license": "ExampleLicense-1.0", # eg CreativeML Open RAIL, etc.
|
29 |
+
"modelspec.usage_hint": "Use keyword 'example'" # In your own language, very short hints about how the user should use the model
|
30 |
+
}
|
31 |
+
"""
|
32 |
+
|
33 |
+
BASE_METADATA = {
|
34 |
+
# === Must ===
|
35 |
+
"modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
|
36 |
+
"modelspec.architecture": None,
|
37 |
+
"modelspec.implementation": None,
|
38 |
+
"modelspec.title": None,
|
39 |
+
"modelspec.resolution": None,
|
40 |
+
# === Should ===
|
41 |
+
"modelspec.description": None,
|
42 |
+
"modelspec.author": None,
|
43 |
+
"modelspec.date": None,
|
44 |
+
# === Can ===
|
45 |
+
"modelspec.license": None,
|
46 |
+
"modelspec.tags": None,
|
47 |
+
"modelspec.merged_from": None,
|
48 |
+
"modelspec.prediction_type": None,
|
49 |
+
"modelspec.timestep_range": None,
|
50 |
+
"modelspec.encoder_layer": None,
|
51 |
+
}
|
52 |
+
|
53 |
+
# 別に使うやつだけ定義
|
54 |
+
MODELSPEC_TITLE = "modelspec.title"
|
55 |
+
|
56 |
+
ARCH_SD_V1 = "stable-diffusion-v1"
|
57 |
+
ARCH_SD_V2_512 = "stable-diffusion-v2-512"
|
58 |
+
ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
|
59 |
+
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
|
60 |
+
ARCH_SD3_M = "stable-diffusion-3" # may be followed by "-m" or "-5-large" etc.
|
61 |
+
# ARCH_SD3_UNKNOWN = "stable-diffusion-3"
|
62 |
+
ARCH_FLUX_1_DEV = "flux-1-dev"
|
63 |
+
ARCH_FLUX_1_UNKNOWN = "flux-1"
|
64 |
+
|
65 |
+
ADAPTER_LORA = "lora"
|
66 |
+
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
|
67 |
+
|
68 |
+
IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
|
69 |
+
IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI"
|
70 |
+
IMPL_DIFFUSERS = "diffusers"
|
71 |
+
IMPL_FLUX = "https://github.com/black-forest-labs/flux"
|
72 |
+
|
73 |
+
PRED_TYPE_EPSILON = "epsilon"
|
74 |
+
PRED_TYPE_V = "v"
|
75 |
+
|
76 |
+
|
77 |
+
def load_bytes_in_safetensors(tensors):
|
78 |
+
bytes = safetensors.torch.save(tensors)
|
79 |
+
b = BytesIO(bytes)
|
80 |
+
|
81 |
+
b.seek(0)
|
82 |
+
header = b.read(8)
|
83 |
+
n = int.from_bytes(header, "little")
|
84 |
+
|
85 |
+
offset = n + 8
|
86 |
+
b.seek(offset)
|
87 |
+
|
88 |
+
return b.read()
|
89 |
+
|
90 |
+
|
91 |
+
def precalculate_safetensors_hashes(state_dict):
|
92 |
+
# calculate each tensor one by one to reduce memory usage
|
93 |
+
hash_sha256 = hashlib.sha256()
|
94 |
+
for tensor in state_dict.values():
|
95 |
+
single_tensor_sd = {"tensor": tensor}
|
96 |
+
bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd)
|
97 |
+
hash_sha256.update(bytes_for_tensor)
|
98 |
+
|
99 |
+
return f"0x{hash_sha256.hexdigest()}"
|
100 |
+
|
101 |
+
|
102 |
+
def update_hash_sha256(metadata: dict, state_dict: dict):
|
103 |
+
raise NotImplementedError
|
104 |
+
|
105 |
+
|
106 |
+
def build_metadata(
|
107 |
+
state_dict: Optional[dict],
|
108 |
+
v2: bool,
|
109 |
+
v_parameterization: bool,
|
110 |
+
sdxl: bool,
|
111 |
+
lora: bool,
|
112 |
+
textual_inversion: bool,
|
113 |
+
timestamp: float,
|
114 |
+
title: Optional[str] = None,
|
115 |
+
reso: Optional[Union[int, Tuple[int, int]]] = None,
|
116 |
+
is_stable_diffusion_ckpt: Optional[bool] = None,
|
117 |
+
author: Optional[str] = None,
|
118 |
+
description: Optional[str] = None,
|
119 |
+
license: Optional[str] = None,
|
120 |
+
tags: Optional[str] = None,
|
121 |
+
merged_from: Optional[str] = None,
|
122 |
+
timesteps: Optional[Tuple[int, int]] = None,
|
123 |
+
clip_skip: Optional[int] = None,
|
124 |
+
sd3: Optional[str] = None,
|
125 |
+
flux: Optional[str] = None,
|
126 |
+
):
|
127 |
+
"""
|
128 |
+
sd3: only supports "m", flux: only supports "dev"
|
129 |
+
"""
|
130 |
+
# if state_dict is None, hash is not calculated
|
131 |
+
|
132 |
+
metadata = {}
|
133 |
+
metadata.update(BASE_METADATA)
|
134 |
+
|
135 |
+
# TODO メモリを消費せずかつ正しいハッシュ計算の方法がわかったら実装する
|
136 |
+
# if state_dict is not None:
|
137 |
+
# hash = precalculate_safetensors_hashes(state_dict)
|
138 |
+
# metadata["modelspec.hash_sha256"] = hash
|
139 |
+
|
140 |
+
if sdxl:
|
141 |
+
arch = ARCH_SD_XL_V1_BASE
|
142 |
+
elif sd3 is not None:
|
143 |
+
arch = ARCH_SD3_M + "-" + sd3
|
144 |
+
elif flux is not None:
|
145 |
+
if flux == "dev":
|
146 |
+
arch = ARCH_FLUX_1_DEV
|
147 |
+
else:
|
148 |
+
arch = ARCH_FLUX_1_UNKNOWN
|
149 |
+
elif v2:
|
150 |
+
if v_parameterization:
|
151 |
+
arch = ARCH_SD_V2_768_V
|
152 |
+
else:
|
153 |
+
arch = ARCH_SD_V2_512
|
154 |
+
else:
|
155 |
+
arch = ARCH_SD_V1
|
156 |
+
|
157 |
+
if lora:
|
158 |
+
arch += f"/{ADAPTER_LORA}"
|
159 |
+
elif textual_inversion:
|
160 |
+
arch += f"/{ADAPTER_TEXTUAL_INVERSION}"
|
161 |
+
|
162 |
+
metadata["modelspec.architecture"] = arch
|
163 |
+
|
164 |
+
if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
|
165 |
+
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
|
166 |
+
|
167 |
+
if flux is not None:
|
168 |
+
# Flux
|
169 |
+
impl = IMPL_FLUX
|
170 |
+
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
|
171 |
+
# Stable Diffusion ckpt, TI, SDXL LoRA
|
172 |
+
impl = IMPL_STABILITY_AI
|
173 |
+
else:
|
174 |
+
# v1/v2 LoRA or Diffusers
|
175 |
+
impl = IMPL_DIFFUSERS
|
176 |
+
metadata["modelspec.implementation"] = impl
|
177 |
+
|
178 |
+
if title is None:
|
179 |
+
if lora:
|
180 |
+
title = "LoRA"
|
181 |
+
elif textual_inversion:
|
182 |
+
title = "TextualInversion"
|
183 |
+
else:
|
184 |
+
title = "Checkpoint"
|
185 |
+
title += f"@{timestamp}"
|
186 |
+
metadata[MODELSPEC_TITLE] = title
|
187 |
+
|
188 |
+
if author is not None:
|
189 |
+
metadata["modelspec.author"] = author
|
190 |
+
else:
|
191 |
+
del metadata["modelspec.author"]
|
192 |
+
|
193 |
+
if description is not None:
|
194 |
+
metadata["modelspec.description"] = description
|
195 |
+
else:
|
196 |
+
del metadata["modelspec.description"]
|
197 |
+
|
198 |
+
if merged_from is not None:
|
199 |
+
metadata["modelspec.merged_from"] = merged_from
|
200 |
+
else:
|
201 |
+
del metadata["modelspec.merged_from"]
|
202 |
+
|
203 |
+
if license is not None:
|
204 |
+
metadata["modelspec.license"] = license
|
205 |
+
else:
|
206 |
+
del metadata["modelspec.license"]
|
207 |
+
|
208 |
+
if tags is not None:
|
209 |
+
metadata["modelspec.tags"] = tags
|
210 |
+
else:
|
211 |
+
del metadata["modelspec.tags"]
|
212 |
+
|
213 |
+
# remove microsecond from time
|
214 |
+
int_ts = int(timestamp)
|
215 |
+
|
216 |
+
# time to iso-8601 compliant date
|
217 |
+
date = datetime.datetime.fromtimestamp(int_ts).isoformat()
|
218 |
+
metadata["modelspec.date"] = date
|
219 |
+
|
220 |
+
if reso is not None:
|
221 |
+
# comma separated to tuple
|
222 |
+
if isinstance(reso, str):
|
223 |
+
reso = tuple(map(int, reso.split(",")))
|
224 |
+
if len(reso) == 1:
|
225 |
+
reso = (reso[0], reso[0])
|
226 |
+
else:
|
227 |
+
# resolution is defined in dataset, so use default
|
228 |
+
if sdxl or sd3 is not None or flux is not None:
|
229 |
+
reso = 1024
|
230 |
+
elif v2 and v_parameterization:
|
231 |
+
reso = 768
|
232 |
+
else:
|
233 |
+
reso = 512
|
234 |
+
if isinstance(reso, int):
|
235 |
+
reso = (reso, reso)
|
236 |
+
|
237 |
+
metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"
|
238 |
+
|
239 |
+
if flux is not None:
|
240 |
+
del metadata["modelspec.prediction_type"]
|
241 |
+
elif v_parameterization:
|
242 |
+
metadata["modelspec.prediction_type"] = PRED_TYPE_V
|
243 |
+
else:
|
244 |
+
metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON
|
245 |
+
|
246 |
+
if timesteps is not None:
|
247 |
+
if isinstance(timesteps, str) or isinstance(timesteps, int):
|
248 |
+
timesteps = (timesteps, timesteps)
|
249 |
+
if len(timesteps) == 1:
|
250 |
+
timesteps = (timesteps[0], timesteps[0])
|
251 |
+
metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}"
|
252 |
+
else:
|
253 |
+
del metadata["modelspec.timestep_range"]
|
254 |
+
|
255 |
+
if clip_skip is not None:
|
256 |
+
metadata["modelspec.encoder_layer"] = f"{clip_skip}"
|
257 |
+
else:
|
258 |
+
del metadata["modelspec.encoder_layer"]
|
259 |
+
|
260 |
+
# # assert all values are filled
|
261 |
+
# assert all([v is not None for v in metadata.values()]), metadata
|
262 |
+
if not all([v is not None for v in metadata.values()]):
|
263 |
+
logger.error(f"Internal error: some metadata values are None: {metadata}")
|
264 |
+
|
265 |
+
return metadata
|
266 |
+
|
267 |
+
|
268 |
+
# region utils
|
269 |
+
|
270 |
+
|
271 |
+
def get_title(metadata: dict) -> Optional[str]:
|
272 |
+
return metadata.get(MODELSPEC_TITLE, None)
|
273 |
+
|
274 |
+
|
275 |
+
def load_metadata_from_safetensors(model: str) -> dict:
|
276 |
+
if not model.endswith(".safetensors"):
|
277 |
+
return {}
|
278 |
+
|
279 |
+
with safetensors.safe_open(model, framework="pt") as f:
|
280 |
+
metadata = f.metadata()
|
281 |
+
if metadata is None:
|
282 |
+
metadata = {}
|
283 |
+
return metadata
|
284 |
+
|
285 |
+
|
286 |
+
def build_merged_from(models: List[str]) -> str:
|
287 |
+
def get_title(model: str):
|
288 |
+
metadata = load_metadata_from_safetensors(model)
|
289 |
+
title = metadata.get(MODELSPEC_TITLE, None)
|
290 |
+
if title is None:
|
291 |
+
title = os.path.splitext(os.path.basename(model))[0] # use filename
|
292 |
+
return title
|
293 |
+
|
294 |
+
titles = [get_title(model) for model in models]
|
295 |
+
return ", ".join(titles)
|
296 |
+
|
297 |
+
|
298 |
+
# endregion
|
299 |
+
|
300 |
+
|
301 |
+
r"""
|
302 |
+
if __name__ == "__main__":
|
303 |
+
import argparse
|
304 |
+
import torch
|
305 |
+
from safetensors.torch import load_file
|
306 |
+
from library import train_util
|
307 |
+
|
308 |
+
parser = argparse.ArgumentParser()
|
309 |
+
parser.add_argument("--ckpt", type=str, required=True)
|
310 |
+
args = parser.parse_args()
|
311 |
+
|
312 |
+
print(f"Loading {args.ckpt}")
|
313 |
+
state_dict = load_file(args.ckpt)
|
314 |
+
|
315 |
+
print(f"Calculating metadata")
|
316 |
+
metadata = get(state_dict, False, False, False, False, "sgm", False, False, "title", "date", 256, 1000, 0)
|
317 |
+
print(metadata)
|
318 |
+
del state_dict
|
319 |
+
|
320 |
+
# by reference implementation
|
321 |
+
with open(args.ckpt, mode="rb") as file_data:
|
322 |
+
file_hash = hashlib.sha256()
|
323 |
+
head_len = struct.unpack("Q", file_data.read(8)) # int64 header length prefix
|
324 |
+
header = json.loads(file_data.read(head_len[0])) # header itself, json string
|
325 |
+
content = (
|
326 |
+
file_data.read()
|
327 |
+
) # All other content is tightly packed tensors. Copy to RAM for simplicity, but you can avoid this read with a more careful FS-dependent impl.
|
328 |
+
file_hash.update(content)
|
329 |
+
# ===== Update the hash for modelspec =====
|
330 |
+
by_ref = f"0x{file_hash.hexdigest()}"
|
331 |
+
print(by_ref)
|
332 |
+
print("is same?", by_ref == metadata["modelspec.hash_sha256"])
|
333 |
+
|
334 |
+
"""
|
library/sd3_models.py
ADDED
@@ -0,0 +1,1413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# some modules/classes are copied and modified from https://github.com/mcmonkey4eva/sd3-ref
|
2 |
+
# the original code is licensed under the MIT License
|
3 |
+
|
4 |
+
# and some module/classes are contributed from KohakuBlueleaf. Thanks for the contribution!
|
5 |
+
|
6 |
+
from ast import Tuple
|
7 |
+
from concurrent.futures import ThreadPoolExecutor
|
8 |
+
from dataclasses import dataclass
|
9 |
+
from functools import partial
|
10 |
+
import math
|
11 |
+
from types import SimpleNamespace
|
12 |
+
from typing import Dict, List, Optional, Union
|
13 |
+
import einops
|
14 |
+
import numpy as np
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from torch.utils.checkpoint import checkpoint
|
19 |
+
from transformers import CLIPTokenizer, T5TokenizerFast
|
20 |
+
|
21 |
+
from library import custom_offloading_utils
|
22 |
+
from library.device_utils import clean_memory_on_device
|
23 |
+
|
24 |
+
from .utils import setup_logging
|
25 |
+
|
26 |
+
setup_logging()
|
27 |
+
import logging
|
28 |
+
|
29 |
+
logger = logging.getLogger(__name__)
|
30 |
+
|
31 |
+
|
32 |
+
memory_efficient_attention = None
|
33 |
+
try:
|
34 |
+
import xformers
|
35 |
+
except:
|
36 |
+
pass
|
37 |
+
|
38 |
+
try:
|
39 |
+
from xformers.ops import memory_efficient_attention
|
40 |
+
except:
|
41 |
+
memory_efficient_attention = None
|
42 |
+
|
43 |
+
|
44 |
+
# region mmdit
|
45 |
+
|
46 |
+
|
47 |
+
@dataclass
|
48 |
+
class SD3Params:
|
49 |
+
patch_size: int
|
50 |
+
depth: int
|
51 |
+
num_patches: int
|
52 |
+
pos_embed_max_size: int
|
53 |
+
adm_in_channels: int
|
54 |
+
qk_norm: Optional[str]
|
55 |
+
x_block_self_attn_layers: list[int]
|
56 |
+
context_embedder_in_features: int
|
57 |
+
context_embedder_out_features: int
|
58 |
+
model_type: str
|
59 |
+
|
60 |
+
|
61 |
+
def get_2d_sincos_pos_embed(
|
62 |
+
embed_dim,
|
63 |
+
grid_size,
|
64 |
+
scaling_factor=None,
|
65 |
+
offset=None,
|
66 |
+
):
|
67 |
+
grid_h = np.arange(grid_size, dtype=np.float32)
|
68 |
+
grid_w = np.arange(grid_size, dtype=np.float32)
|
69 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
70 |
+
grid = np.stack(grid, axis=0)
|
71 |
+
if scaling_factor is not None:
|
72 |
+
grid = grid / scaling_factor
|
73 |
+
if offset is not None:
|
74 |
+
grid = grid - offset
|
75 |
+
|
76 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
77 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
78 |
+
return pos_embed
|
79 |
+
|
80 |
+
|
81 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
82 |
+
assert embed_dim % 2 == 0
|
83 |
+
|
84 |
+
# use half of dimensions to encode grid_h
|
85 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
86 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
87 |
+
|
88 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
89 |
+
return emb
|
90 |
+
|
91 |
+
|
92 |
+
def get_scaled_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, sample_size=64, base_size=16):
|
93 |
+
"""
|
94 |
+
This function is contributed by KohakuBlueleaf. Thanks for the contribution!
|
95 |
+
|
96 |
+
Creates scaled 2D sinusoidal positional embeddings that maintain consistent relative positions
|
97 |
+
when the resolution differs from the training resolution.
|
98 |
+
|
99 |
+
Args:
|
100 |
+
embed_dim (int): Dimension of the positional embedding.
|
101 |
+
grid_size (int or tuple): Size of the position grid (H, W). If int, assumes square grid.
|
102 |
+
cls_token (bool): Whether to include class token. Defaults to False.
|
103 |
+
extra_tokens (int): Number of extra tokens (e.g., cls_token). Defaults to 0.
|
104 |
+
sample_size (int): Reference resolution (typically training resolution). Defaults to 64.
|
105 |
+
base_size (int): Base grid size used during training. Defaults to 16.
|
106 |
+
|
107 |
+
Returns:
|
108 |
+
numpy.ndarray: Positional embeddings of shape (H*W, embed_dim) or
|
109 |
+
(H*W + extra_tokens, embed_dim) if cls_token is True.
|
110 |
+
"""
|
111 |
+
# Convert grid_size to tuple if it's an integer
|
112 |
+
if isinstance(grid_size, int):
|
113 |
+
grid_size = (grid_size, grid_size)
|
114 |
+
|
115 |
+
# Create normalized grid coordinates (0 to 1)
|
116 |
+
grid_h = np.arange(grid_size[0], dtype=np.float32) / grid_size[0]
|
117 |
+
grid_w = np.arange(grid_size[1], dtype=np.float32) / grid_size[1]
|
118 |
+
|
119 |
+
# Calculate scaling factors for height and width
|
120 |
+
# This ensures that the central region matches the original resolution's embeddings
|
121 |
+
scale_h = base_size * grid_size[0] / (sample_size)
|
122 |
+
scale_w = base_size * grid_size[1] / (sample_size)
|
123 |
+
|
124 |
+
# Calculate shift values to center the original resolution's embedding region
|
125 |
+
# This ensures that the central sample_size x sample_size region has similar
|
126 |
+
# positional embeddings to the original resolution
|
127 |
+
shift_h = 1 * scale_h * (grid_size[0] - sample_size) / (2 * grid_size[0])
|
128 |
+
shift_w = 1 * scale_w * (grid_size[1] - sample_size) / (2 * grid_size[1])
|
129 |
+
|
130 |
+
# Apply scaling and shifting to create the final grid coordinates
|
131 |
+
grid_h = grid_h * scale_h - shift_h
|
132 |
+
grid_w = grid_w * scale_w - shift_w
|
133 |
+
|
134 |
+
# Create 2D grid using meshgrid (note: w goes first)
|
135 |
+
grid = np.meshgrid(grid_w, grid_h)
|
136 |
+
grid = np.stack(grid, axis=0)
|
137 |
+
|
138 |
+
# # Calculate the starting indices for the central region
|
139 |
+
# # This is used for debugging/visualization of the central region
|
140 |
+
# st_h = (grid_size[0] - sample_size) // 2
|
141 |
+
# st_w = (grid_size[1] - sample_size) // 2
|
142 |
+
# print(grid[:, st_h : st_h + sample_size, st_w : st_w + sample_size])
|
143 |
+
|
144 |
+
# Reshape grid for positional embedding calculation
|
145 |
+
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
|
146 |
+
|
147 |
+
# Generate the sinusoidal positional embeddings
|
148 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
149 |
+
|
150 |
+
# Add zeros for extra tokens (e.g., [CLS] token) if required
|
151 |
+
if cls_token and extra_tokens > 0:
|
152 |
+
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
153 |
+
|
154 |
+
return pos_embed
|
155 |
+
|
156 |
+
|
157 |
+
# if __name__ == "__main__":
|
158 |
+
# # This is what you get when you load SD3.5 state dict
|
159 |
+
# pos_emb = torch.from_numpy(get_scaled_2d_sincos_pos_embed(
|
160 |
+
# 1536, [384, 384], sample_size=64, base_size=16
|
161 |
+
# )).float().unsqueeze(0)
|
162 |
+
|
163 |
+
|
164 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
165 |
+
"""
|
166 |
+
embed_dim: output dimension for each position
|
167 |
+
pos: a list of positions to be encoded: size (M,)
|
168 |
+
out: (M, D)
|
169 |
+
"""
|
170 |
+
assert embed_dim % 2 == 0
|
171 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
172 |
+
omega /= embed_dim / 2.0
|
173 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
174 |
+
|
175 |
+
pos = pos.reshape(-1) # (M,)
|
176 |
+
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
177 |
+
|
178 |
+
emb_sin = np.sin(out) # (M, D/2)
|
179 |
+
emb_cos = np.cos(out) # (M, D/2)
|
180 |
+
|
181 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
182 |
+
return emb
|
183 |
+
|
184 |
+
|
185 |
+
def get_1d_sincos_pos_embed_from_grid_torch(
|
186 |
+
embed_dim,
|
187 |
+
pos,
|
188 |
+
device=None,
|
189 |
+
dtype=torch.float32,
|
190 |
+
):
|
191 |
+
omega = torch.arange(embed_dim // 2, device=device, dtype=dtype)
|
192 |
+
omega *= 2.0 / embed_dim
|
193 |
+
omega = 1.0 / 10000**omega
|
194 |
+
out = torch.outer(pos.reshape(-1), omega)
|
195 |
+
emb = torch.cat([out.sin(), out.cos()], dim=1)
|
196 |
+
return emb
|
197 |
+
|
198 |
+
|
199 |
+
def get_2d_sincos_pos_embed_torch(
|
200 |
+
embed_dim,
|
201 |
+
w,
|
202 |
+
h,
|
203 |
+
val_center=7.5,
|
204 |
+
val_magnitude=7.5,
|
205 |
+
device=None,
|
206 |
+
dtype=torch.float32,
|
207 |
+
):
|
208 |
+
small = min(h, w)
|
209 |
+
val_h = (h / small) * val_magnitude
|
210 |
+
val_w = (w / small) * val_magnitude
|
211 |
+
grid_h, grid_w = torch.meshgrid(
|
212 |
+
torch.linspace(-val_h + val_center, val_h + val_center, h, device=device, dtype=dtype),
|
213 |
+
torch.linspace(-val_w + val_center, val_w + val_center, w, device=device, dtype=dtype),
|
214 |
+
indexing="ij",
|
215 |
+
)
|
216 |
+
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
|
217 |
+
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
|
218 |
+
emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
|
219 |
+
return emb
|
220 |
+
|
221 |
+
|
222 |
+
def modulate(x, shift, scale):
|
223 |
+
if shift is None:
|
224 |
+
shift = torch.zeros_like(scale)
|
225 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
226 |
+
|
227 |
+
|
228 |
+
def default(x, default_value):
|
229 |
+
if x is None:
|
230 |
+
return default_value
|
231 |
+
return x
|
232 |
+
|
233 |
+
|
234 |
+
def timestep_embedding(t, dim, max_period=10000):
|
235 |
+
half = dim // 2
|
236 |
+
# freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
|
237 |
+
# device=t.device, dtype=t.dtype
|
238 |
+
# )
|
239 |
+
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
|
240 |
+
args = t[:, None].float() * freqs[None]
|
241 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
242 |
+
if dim % 2:
|
243 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
244 |
+
if torch.is_floating_point(t):
|
245 |
+
embedding = embedding.to(dtype=t.dtype)
|
246 |
+
return embedding
|
247 |
+
|
248 |
+
|
249 |
+
class PatchEmbed(nn.Module):
|
250 |
+
def __init__(
|
251 |
+
self,
|
252 |
+
img_size=256,
|
253 |
+
patch_size=4,
|
254 |
+
in_channels=3,
|
255 |
+
embed_dim=512,
|
256 |
+
norm_layer=None,
|
257 |
+
flatten=True,
|
258 |
+
bias=True,
|
259 |
+
strict_img_size=True,
|
260 |
+
dynamic_img_pad=False,
|
261 |
+
):
|
262 |
+
# dynamic_img_pad and norm is omitted in SD3.5
|
263 |
+
super().__init__()
|
264 |
+
self.patch_size = patch_size
|
265 |
+
self.flatten = flatten
|
266 |
+
self.strict_img_size = strict_img_size
|
267 |
+
self.dynamic_img_pad = dynamic_img_pad
|
268 |
+
if img_size is not None:
|
269 |
+
self.img_size = img_size
|
270 |
+
self.grid_size = img_size // patch_size
|
271 |
+
self.num_patches = self.grid_size**2
|
272 |
+
else:
|
273 |
+
self.img_size = None
|
274 |
+
self.grid_size = None
|
275 |
+
self.num_patches = None
|
276 |
+
|
277 |
+
self.proj = nn.Conv2d(in_channels, embed_dim, patch_size, patch_size, bias=bias)
|
278 |
+
self.norm = nn.Identity() if norm_layer is None else norm_layer(embed_dim)
|
279 |
+
|
280 |
+
def forward(self, x):
|
281 |
+
B, C, H, W = x.shape
|
282 |
+
|
283 |
+
if self.dynamic_img_pad:
|
284 |
+
# Pad input so we won't have partial patch
|
285 |
+
pad_h = (self.patch_size - H % self.patch_size) % self.patch_size
|
286 |
+
pad_w = (self.patch_size - W % self.patch_size) % self.patch_size
|
287 |
+
x = nn.functional.pad(x, (0, pad_w, 0, pad_h), mode="reflect")
|
288 |
+
x = self.proj(x)
|
289 |
+
if self.flatten:
|
290 |
+
x = x.flatten(2).transpose(1, 2)
|
291 |
+
x = self.norm(x)
|
292 |
+
return x
|
293 |
+
|
294 |
+
|
295 |
+
# FinalLayer in mmdit.py
|
296 |
+
class UnPatch(nn.Module):
|
297 |
+
def __init__(self, hidden_size=512, patch_size=4, out_channels=3):
|
298 |
+
super().__init__()
|
299 |
+
self.patch_size = patch_size
|
300 |
+
self.c = out_channels
|
301 |
+
|
302 |
+
# eps is default in mmdit.py
|
303 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
304 |
+
self.linear = nn.Linear(hidden_size, patch_size**2 * out_channels)
|
305 |
+
self.adaLN_modulation = nn.Sequential(
|
306 |
+
nn.SiLU(),
|
307 |
+
nn.Linear(hidden_size, 2 * hidden_size),
|
308 |
+
)
|
309 |
+
|
310 |
+
def forward(self, x: torch.Tensor, cmod, H=None, W=None):
|
311 |
+
b, n, _ = x.shape
|
312 |
+
p = self.patch_size
|
313 |
+
c = self.c
|
314 |
+
if H is None and W is None:
|
315 |
+
w = h = int(n**0.5)
|
316 |
+
assert h * w == n
|
317 |
+
else:
|
318 |
+
h = H // p if H else n // (W // p)
|
319 |
+
w = W // p if W else n // h
|
320 |
+
assert h * w == n
|
321 |
+
|
322 |
+
shift, scale = self.adaLN_modulation(cmod).chunk(2, dim=-1)
|
323 |
+
x = modulate(self.norm_final(x), shift, scale)
|
324 |
+
x = self.linear(x)
|
325 |
+
|
326 |
+
x = x.view(b, h, w, p, p, c)
|
327 |
+
x = x.permute(0, 5, 1, 3, 2, 4).contiguous()
|
328 |
+
x = x.view(b, c, h * p, w * p)
|
329 |
+
return x
|
330 |
+
|
331 |
+
|
332 |
+
class MLP(nn.Module):
|
333 |
+
def __init__(
|
334 |
+
self,
|
335 |
+
in_features,
|
336 |
+
hidden_features=None,
|
337 |
+
out_features=None,
|
338 |
+
act_layer=lambda: nn.GELU(),
|
339 |
+
norm_layer=None,
|
340 |
+
bias=True,
|
341 |
+
use_conv=False,
|
342 |
+
):
|
343 |
+
super().__init__()
|
344 |
+
out_features = out_features or in_features
|
345 |
+
hidden_features = hidden_features or in_features
|
346 |
+
self.use_conv = use_conv
|
347 |
+
|
348 |
+
layer = partial(nn.Conv1d, kernel_size=1) if use_conv else nn.Linear
|
349 |
+
|
350 |
+
self.fc1 = layer(in_features, hidden_features, bias=bias)
|
351 |
+
self.fc2 = layer(hidden_features, out_features, bias=bias)
|
352 |
+
self.act = act_layer()
|
353 |
+
self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity()
|
354 |
+
|
355 |
+
def forward(self, x):
|
356 |
+
x = self.fc1(x)
|
357 |
+
x = self.act(x)
|
358 |
+
x = self.norm(x)
|
359 |
+
x = self.fc2(x)
|
360 |
+
return x
|
361 |
+
|
362 |
+
|
363 |
+
class TimestepEmbedding(nn.Module):
|
364 |
+
def __init__(self, hidden_size, freq_embed_size=256):
|
365 |
+
super().__init__()
|
366 |
+
self.mlp = nn.Sequential(
|
367 |
+
nn.Linear(freq_embed_size, hidden_size),
|
368 |
+
nn.SiLU(),
|
369 |
+
nn.Linear(hidden_size, hidden_size),
|
370 |
+
)
|
371 |
+
self.freq_embed_size = freq_embed_size
|
372 |
+
|
373 |
+
def forward(self, t, dtype=None, **kwargs):
|
374 |
+
t_freq = timestep_embedding(t, self.freq_embed_size).to(dtype)
|
375 |
+
t_emb = self.mlp(t_freq)
|
376 |
+
return t_emb
|
377 |
+
|
378 |
+
|
379 |
+
class Embedder(nn.Module):
|
380 |
+
def __init__(self, input_dim, hidden_size):
|
381 |
+
super().__init__()
|
382 |
+
self.mlp = nn.Sequential(
|
383 |
+
nn.Linear(input_dim, hidden_size),
|
384 |
+
nn.SiLU(),
|
385 |
+
nn.Linear(hidden_size, hidden_size),
|
386 |
+
)
|
387 |
+
|
388 |
+
def forward(self, x):
|
389 |
+
return self.mlp(x)
|
390 |
+
|
391 |
+
|
392 |
+
def rmsnorm(x, eps=1e-6):
|
393 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
|
394 |
+
|
395 |
+
|
396 |
+
class RMSNorm(torch.nn.Module):
|
397 |
+
def __init__(
|
398 |
+
self,
|
399 |
+
dim: int,
|
400 |
+
elementwise_affine: bool = False,
|
401 |
+
eps: float = 1e-6,
|
402 |
+
device=None,
|
403 |
+
dtype=None,
|
404 |
+
):
|
405 |
+
"""
|
406 |
+
Initialize the RMSNorm normalization layer.
|
407 |
+
Args:
|
408 |
+
dim (int): The dimension of the input tensor.
|
409 |
+
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
|
410 |
+
Attributes:
|
411 |
+
eps (float): A small value added to the denominator for numerical stability.
|
412 |
+
weight (nn.Parameter): Learnable scaling parameter.
|
413 |
+
"""
|
414 |
+
super().__init__()
|
415 |
+
self.eps = eps
|
416 |
+
self.learnable_scale = elementwise_affine
|
417 |
+
if self.learnable_scale:
|
418 |
+
self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
|
419 |
+
else:
|
420 |
+
self.register_parameter("weight", None)
|
421 |
+
|
422 |
+
def forward(self, x):
|
423 |
+
"""
|
424 |
+
Forward pass through the RMSNorm layer.
|
425 |
+
Args:
|
426 |
+
x (torch.Tensor): The input tensor.
|
427 |
+
Returns:
|
428 |
+
torch.Tensor: The output tensor after applying RMSNorm.
|
429 |
+
"""
|
430 |
+
x = rmsnorm(x, eps=self.eps)
|
431 |
+
if self.learnable_scale:
|
432 |
+
return x * self.weight.to(device=x.device, dtype=x.dtype)
|
433 |
+
else:
|
434 |
+
return x
|
435 |
+
|
436 |
+
|
437 |
+
class SwiGLUFeedForward(nn.Module):
|
438 |
+
def __init__(
|
439 |
+
self,
|
440 |
+
dim: int,
|
441 |
+
hidden_dim: int,
|
442 |
+
multiple_of: int,
|
443 |
+
ffn_dim_multiplier: float = None,
|
444 |
+
):
|
445 |
+
super().__init__()
|
446 |
+
hidden_dim = int(2 * hidden_dim / 3)
|
447 |
+
# custom dim factor multiplier
|
448 |
+
if ffn_dim_multiplier is not None:
|
449 |
+
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
450 |
+
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
451 |
+
|
452 |
+
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
453 |
+
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
454 |
+
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
455 |
+
|
456 |
+
def forward(self, x):
|
457 |
+
return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
|
458 |
+
|
459 |
+
|
460 |
+
# Linears for SelfAttention in mmdit.py
|
461 |
+
class AttentionLinears(nn.Module):
|
462 |
+
def __init__(
|
463 |
+
self,
|
464 |
+
dim: int,
|
465 |
+
num_heads: int = 8,
|
466 |
+
qkv_bias: bool = False,
|
467 |
+
pre_only: bool = False,
|
468 |
+
qk_norm: Optional[str] = None,
|
469 |
+
):
|
470 |
+
super().__init__()
|
471 |
+
self.num_heads = num_heads
|
472 |
+
self.head_dim = dim // num_heads
|
473 |
+
|
474 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
475 |
+
if not pre_only:
|
476 |
+
self.proj = nn.Linear(dim, dim)
|
477 |
+
self.pre_only = pre_only
|
478 |
+
|
479 |
+
if qk_norm == "rms":
|
480 |
+
self.ln_q = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6)
|
481 |
+
self.ln_k = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6)
|
482 |
+
elif qk_norm == "ln":
|
483 |
+
self.ln_q = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6)
|
484 |
+
self.ln_k = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6)
|
485 |
+
elif qk_norm is None:
|
486 |
+
self.ln_q = nn.Identity()
|
487 |
+
self.ln_k = nn.Identity()
|
488 |
+
else:
|
489 |
+
raise ValueError(qk_norm)
|
490 |
+
|
491 |
+
def pre_attention(self, x: torch.Tensor) -> torch.Tensor:
|
492 |
+
"""
|
493 |
+
output:
|
494 |
+
q, k, v: [B, L, D]
|
495 |
+
"""
|
496 |
+
B, L, C = x.shape
|
497 |
+
qkv: torch.Tensor = self.qkv(x)
|
498 |
+
q, k, v = qkv.reshape(B, L, -1, self.head_dim).chunk(3, dim=2)
|
499 |
+
q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1)
|
500 |
+
k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1)
|
501 |
+
return (q, k, v)
|
502 |
+
|
503 |
+
def post_attention(self, x: torch.Tensor) -> torch.Tensor:
|
504 |
+
assert not self.pre_only
|
505 |
+
x = self.proj(x)
|
506 |
+
return x
|
507 |
+
|
508 |
+
|
509 |
+
MEMORY_LAYOUTS = {
|
510 |
+
"torch": (
|
511 |
+
lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim).transpose(1, 2),
|
512 |
+
lambda x: x.transpose(1, 2).reshape(x.shape[0], x.shape[2], -1),
|
513 |
+
lambda x: (1, x, 1, 1),
|
514 |
+
),
|
515 |
+
"xformers": (
|
516 |
+
lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim),
|
517 |
+
lambda x: x.reshape(x.shape[0], x.shape[1], -1),
|
518 |
+
lambda x: (1, 1, x, 1),
|
519 |
+
),
|
520 |
+
"math": (
|
521 |
+
lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim).transpose(1, 2),
|
522 |
+
lambda x: x.transpose(1, 2).reshape(x.shape[0], x.shape[2], -1),
|
523 |
+
lambda x: (1, x, 1, 1),
|
524 |
+
),
|
525 |
+
}
|
526 |
+
# ATTN_FUNCTION = {
|
527 |
+
# "torch": F.scaled_dot_product_attention,
|
528 |
+
# "xformers": memory_efficient_attention,
|
529 |
+
# }
|
530 |
+
|
531 |
+
|
532 |
+
def vanilla_attention(q, k, v, mask, scale=None):
|
533 |
+
if scale is None:
|
534 |
+
scale = math.sqrt(q.size(-1))
|
535 |
+
scores = torch.bmm(q, k.transpose(-1, -2)) / scale
|
536 |
+
if mask is not None:
|
537 |
+
mask = einops.rearrange(mask, "b ... -> b (...)")
|
538 |
+
max_neg_value = -torch.finfo(scores.dtype).max
|
539 |
+
mask = einops.repeat(mask, "b j -> (b h) j", h=q.size(-3))
|
540 |
+
scores = scores.masked_fill(~mask, max_neg_value)
|
541 |
+
p_attn = F.softmax(scores, dim=-1)
|
542 |
+
return torch.bmm(p_attn, v)
|
543 |
+
|
544 |
+
|
545 |
+
def attention(q, k, v, head_dim, mask=None, scale=None, mode="xformers"):
|
546 |
+
"""
|
547 |
+
q, k, v: [B, L, D]
|
548 |
+
"""
|
549 |
+
pre_attn_layout = MEMORY_LAYOUTS[mode][0]
|
550 |
+
post_attn_layout = MEMORY_LAYOUTS[mode][1]
|
551 |
+
q = pre_attn_layout(q, head_dim)
|
552 |
+
k = pre_attn_layout(k, head_dim)
|
553 |
+
v = pre_attn_layout(v, head_dim)
|
554 |
+
|
555 |
+
# scores = ATTN_FUNCTION[mode](q, k.to(q), v.to(q), mask, scale=scale)
|
556 |
+
if mode == "torch":
|
557 |
+
assert scale is None
|
558 |
+
scores = F.scaled_dot_product_attention(q, k.to(q), v.to(q), mask) # , scale=scale)
|
559 |
+
elif mode == "xformers":
|
560 |
+
scores = memory_efficient_attention(q, k.to(q), v.to(q), mask, scale=scale)
|
561 |
+
else:
|
562 |
+
scores = vanilla_attention(q, k.to(q), v.to(q), mask, scale=scale)
|
563 |
+
|
564 |
+
scores = post_attn_layout(scores)
|
565 |
+
return scores
|
566 |
+
|
567 |
+
|
568 |
+
# DismantledBlock in mmdit.py
|
569 |
+
class SingleDiTBlock(nn.Module):
|
570 |
+
"""
|
571 |
+
A DiT block with gated adaptive layer norm (adaLN) conditioning.
|
572 |
+
"""
|
573 |
+
|
574 |
+
def __init__(
|
575 |
+
self,
|
576 |
+
hidden_size: int,
|
577 |
+
num_heads: int,
|
578 |
+
mlp_ratio: float = 4.0,
|
579 |
+
attn_mode: str = "xformers",
|
580 |
+
qkv_bias: bool = False,
|
581 |
+
pre_only: bool = False,
|
582 |
+
rmsnorm: bool = False,
|
583 |
+
scale_mod_only: bool = False,
|
584 |
+
swiglu: bool = False,
|
585 |
+
qk_norm: Optional[str] = None,
|
586 |
+
x_block_self_attn: bool = False,
|
587 |
+
**block_kwargs,
|
588 |
+
):
|
589 |
+
super().__init__()
|
590 |
+
assert attn_mode in MEMORY_LAYOUTS
|
591 |
+
self.attn_mode = attn_mode
|
592 |
+
if not rmsnorm:
|
593 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
594 |
+
else:
|
595 |
+
self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
596 |
+
self.attn = AttentionLinears(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, pre_only=pre_only, qk_norm=qk_norm)
|
597 |
+
|
598 |
+
self.x_block_self_attn = x_block_self_attn
|
599 |
+
if self.x_block_self_attn:
|
600 |
+
assert not pre_only
|
601 |
+
assert not scale_mod_only
|
602 |
+
self.attn2 = AttentionLinears(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, pre_only=False, qk_norm=qk_norm)
|
603 |
+
|
604 |
+
if not pre_only:
|
605 |
+
if not rmsnorm:
|
606 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
607 |
+
else:
|
608 |
+
self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
609 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
610 |
+
if not pre_only:
|
611 |
+
if not swiglu:
|
612 |
+
self.mlp = MLP(
|
613 |
+
in_features=hidden_size,
|
614 |
+
hidden_features=mlp_hidden_dim,
|
615 |
+
act_layer=lambda: nn.GELU(approximate="tanh"),
|
616 |
+
)
|
617 |
+
else:
|
618 |
+
self.mlp = SwiGLUFeedForward(
|
619 |
+
dim=hidden_size,
|
620 |
+
hidden_dim=mlp_hidden_dim,
|
621 |
+
multiple_of=256,
|
622 |
+
)
|
623 |
+
self.scale_mod_only = scale_mod_only
|
624 |
+
if self.x_block_self_attn:
|
625 |
+
n_mods = 9
|
626 |
+
elif not scale_mod_only:
|
627 |
+
n_mods = 6 if not pre_only else 2
|
628 |
+
else:
|
629 |
+
n_mods = 4 if not pre_only else 1
|
630 |
+
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, n_mods * hidden_size))
|
631 |
+
self.pre_only = pre_only
|
632 |
+
|
633 |
+
def pre_attention(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
634 |
+
if not self.pre_only:
|
635 |
+
if not self.scale_mod_only:
|
636 |
+
(shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(6, dim=-1)
|
637 |
+
else:
|
638 |
+
shift_msa = None
|
639 |
+
shift_mlp = None
|
640 |
+
(scale_msa, gate_msa, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(4, dim=-1)
|
641 |
+
qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
|
642 |
+
return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp)
|
643 |
+
else:
|
644 |
+
if not self.scale_mod_only:
|
645 |
+
(shift_msa, scale_msa) = self.adaLN_modulation(c).chunk(2, dim=-1)
|
646 |
+
else:
|
647 |
+
shift_msa = None
|
648 |
+
scale_msa = self.adaLN_modulation(c)
|
649 |
+
qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
|
650 |
+
return qkv, None
|
651 |
+
|
652 |
+
def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
653 |
+
assert self.x_block_self_attn
|
654 |
+
(shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2) = self.adaLN_modulation(
|
655 |
+
c
|
656 |
+
).chunk(9, dim=1)
|
657 |
+
x_norm = self.norm1(x)
|
658 |
+
qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa))
|
659 |
+
qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2))
|
660 |
+
return qkv, qkv2, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2)
|
661 |
+
|
662 |
+
def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp):
|
663 |
+
assert not self.pre_only
|
664 |
+
x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
|
665 |
+
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
666 |
+
return x
|
667 |
+
|
668 |
+
def post_attention_x(self, attn, attn2, x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2, attn1_dropout: float = 0.0):
|
669 |
+
assert not self.pre_only
|
670 |
+
if attn1_dropout > 0.0:
|
671 |
+
# Use torch.bernoulli to implement dropout, only dropout the batch dimension
|
672 |
+
attn1_dropout = torch.bernoulli(torch.full((attn.size(0), 1, 1), 1 - attn1_dropout, device=attn.device))
|
673 |
+
attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn) * attn1_dropout
|
674 |
+
else:
|
675 |
+
attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
|
676 |
+
x = x + attn_
|
677 |
+
attn2_ = gate_msa2.unsqueeze(1) * self.attn2.post_attention(attn2)
|
678 |
+
x = x + attn2_
|
679 |
+
mlp_ = gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
680 |
+
x = x + mlp_
|
681 |
+
return x
|
682 |
+
|
683 |
+
|
684 |
+
# JointBlock + block_mixing in mmdit.py
|
685 |
+
class MMDiTBlock(nn.Module):
|
686 |
+
def __init__(self, *args, **kwargs):
|
687 |
+
super().__init__()
|
688 |
+
pre_only = kwargs.pop("pre_only")
|
689 |
+
x_block_self_attn = kwargs.pop("x_block_self_attn")
|
690 |
+
|
691 |
+
self.context_block = SingleDiTBlock(*args, pre_only=pre_only, **kwargs)
|
692 |
+
self.x_block = SingleDiTBlock(*args, pre_only=False, x_block_self_attn=x_block_self_attn, **kwargs)
|
693 |
+
|
694 |
+
self.head_dim = self.x_block.attn.head_dim
|
695 |
+
self.mode = self.x_block.attn_mode
|
696 |
+
self.gradient_checkpointing = False
|
697 |
+
|
698 |
+
def enable_gradient_checkpointing(self):
|
699 |
+
self.gradient_checkpointing = True
|
700 |
+
|
701 |
+
def _forward(self, context, x, c):
|
702 |
+
ctx_qkv, ctx_intermediate = self.context_block.pre_attention(context, c)
|
703 |
+
|
704 |
+
if self.x_block.x_block_self_attn:
|
705 |
+
x_qkv, x_qkv2, x_intermediates = self.x_block.pre_attention_x(x, c)
|
706 |
+
else:
|
707 |
+
x_qkv, x_intermediates = self.x_block.pre_attention(x, c)
|
708 |
+
|
709 |
+
ctx_len = ctx_qkv[0].size(1)
|
710 |
+
|
711 |
+
q = torch.concat((ctx_qkv[0], x_qkv[0]), dim=1)
|
712 |
+
k = torch.concat((ctx_qkv[1], x_qkv[1]), dim=1)
|
713 |
+
v = torch.concat((ctx_qkv[2], x_qkv[2]), dim=1)
|
714 |
+
|
715 |
+
attn = attention(q, k, v, head_dim=self.head_dim, mode=self.mode)
|
716 |
+
ctx_attn_out = attn[:, :ctx_len]
|
717 |
+
x_attn_out = attn[:, ctx_len:]
|
718 |
+
|
719 |
+
if self.x_block.x_block_self_attn:
|
720 |
+
x_q2, x_k2, x_v2 = x_qkv2
|
721 |
+
attn2 = attention(x_q2, x_k2, x_v2, self.x_block.attn2.num_heads, mode=self.mode)
|
722 |
+
x = self.x_block.post_attention_x(x_attn_out, attn2, *x_intermediates)
|
723 |
+
else:
|
724 |
+
x = self.x_block.post_attention(x_attn_out, *x_intermediates)
|
725 |
+
|
726 |
+
if not self.context_block.pre_only:
|
727 |
+
context = self.context_block.post_attention(ctx_attn_out, *ctx_intermediate)
|
728 |
+
else:
|
729 |
+
context = None
|
730 |
+
|
731 |
+
return context, x
|
732 |
+
|
733 |
+
def forward(self, *args, **kwargs):
|
734 |
+
if self.training and self.gradient_checkpointing:
|
735 |
+
return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
|
736 |
+
else:
|
737 |
+
return self._forward(*args, **kwargs)
|
738 |
+
|
739 |
+
|
740 |
+
class MMDiT(nn.Module):
|
741 |
+
"""
|
742 |
+
Diffusion model with a Transformer backbone.
|
743 |
+
"""
|
744 |
+
|
745 |
+
# prepare pos_embed for latent size * 2
|
746 |
+
POS_EMBED_MAX_RATIO = 1.5
|
747 |
+
|
748 |
+
def __init__(
|
749 |
+
self,
|
750 |
+
input_size: int = 32,
|
751 |
+
patch_size: int = 2,
|
752 |
+
in_channels: int = 4,
|
753 |
+
depth: int = 28,
|
754 |
+
# hidden_size: Optional[int] = None,
|
755 |
+
# num_heads: Optional[int] = None,
|
756 |
+
mlp_ratio: float = 4.0,
|
757 |
+
learn_sigma: bool = False,
|
758 |
+
adm_in_channels: Optional[int] = None,
|
759 |
+
context_embedder_in_features: Optional[int] = None,
|
760 |
+
context_embedder_out_features: Optional[int] = None,
|
761 |
+
use_checkpoint: bool = False,
|
762 |
+
register_length: int = 0,
|
763 |
+
attn_mode: str = "torch",
|
764 |
+
rmsnorm: bool = False,
|
765 |
+
scale_mod_only: bool = False,
|
766 |
+
swiglu: bool = False,
|
767 |
+
out_channels: Optional[int] = None,
|
768 |
+
pos_embed_scaling_factor: Optional[float] = None,
|
769 |
+
pos_embed_offset: Optional[float] = None,
|
770 |
+
pos_embed_max_size: Optional[int] = None,
|
771 |
+
num_patches=None,
|
772 |
+
qk_norm: Optional[str] = None,
|
773 |
+
x_block_self_attn_layers: Optional[list[int]] = [],
|
774 |
+
qkv_bias: bool = True,
|
775 |
+
pos_emb_random_crop_rate: float = 0.0,
|
776 |
+
use_scaled_pos_embed: bool = False,
|
777 |
+
pos_embed_latent_sizes: Optional[list[int]] = None,
|
778 |
+
model_type: str = "sd3m",
|
779 |
+
):
|
780 |
+
super().__init__()
|
781 |
+
self._model_type = model_type
|
782 |
+
self.learn_sigma = learn_sigma
|
783 |
+
self.in_channels = in_channels
|
784 |
+
default_out_channels = in_channels * 2 if learn_sigma else in_channels
|
785 |
+
self.out_channels = default(out_channels, default_out_channels)
|
786 |
+
self.patch_size = patch_size
|
787 |
+
self.pos_embed_scaling_factor = pos_embed_scaling_factor
|
788 |
+
self.pos_embed_offset = pos_embed_offset
|
789 |
+
self.pos_embed_max_size = pos_embed_max_size
|
790 |
+
self.x_block_self_attn_layers = x_block_self_attn_layers
|
791 |
+
self.pos_emb_random_crop_rate = pos_emb_random_crop_rate
|
792 |
+
self.gradient_checkpointing = use_checkpoint
|
793 |
+
|
794 |
+
# hidden_size = default(hidden_size, 64 * depth)
|
795 |
+
# num_heads = default(num_heads, hidden_size // 64)
|
796 |
+
|
797 |
+
# apply magic --> this defines a head_size of 64
|
798 |
+
self.hidden_size = 64 * depth
|
799 |
+
num_heads = depth
|
800 |
+
|
801 |
+
self.num_heads = num_heads
|
802 |
+
|
803 |
+
self.enable_scaled_pos_embed(use_scaled_pos_embed, pos_embed_latent_sizes)
|
804 |
+
|
805 |
+
self.x_embedder = PatchEmbed(
|
806 |
+
input_size,
|
807 |
+
patch_size,
|
808 |
+
in_channels,
|
809 |
+
self.hidden_size,
|
810 |
+
bias=True,
|
811 |
+
strict_img_size=self.pos_embed_max_size is None,
|
812 |
+
)
|
813 |
+
self.t_embedder = TimestepEmbedding(self.hidden_size)
|
814 |
+
|
815 |
+
self.y_embedder = None
|
816 |
+
if adm_in_channels is not None:
|
817 |
+
assert isinstance(adm_in_channels, int)
|
818 |
+
self.y_embedder = Embedder(adm_in_channels, self.hidden_size)
|
819 |
+
|
820 |
+
if context_embedder_in_features is not None:
|
821 |
+
self.context_embedder = nn.Linear(context_embedder_in_features, context_embedder_out_features)
|
822 |
+
else:
|
823 |
+
self.context_embedder = nn.Identity()
|
824 |
+
|
825 |
+
self.register_length = register_length
|
826 |
+
if self.register_length > 0:
|
827 |
+
self.register = nn.Parameter(torch.randn(1, register_length, self.hidden_size))
|
828 |
+
|
829 |
+
# num_patches = self.x_embedder.num_patches
|
830 |
+
# Will use fixed sin-cos embedding:
|
831 |
+
# just use a buffer already
|
832 |
+
if num_patches is not None:
|
833 |
+
self.register_buffer(
|
834 |
+
"pos_embed",
|
835 |
+
torch.empty(1, num_patches, self.hidden_size),
|
836 |
+
)
|
837 |
+
else:
|
838 |
+
self.pos_embed = None
|
839 |
+
|
840 |
+
self.use_checkpoint = use_checkpoint
|
841 |
+
self.joint_blocks = nn.ModuleList(
|
842 |
+
[
|
843 |
+
MMDiTBlock(
|
844 |
+
self.hidden_size,
|
845 |
+
num_heads,
|
846 |
+
mlp_ratio=mlp_ratio,
|
847 |
+
attn_mode=attn_mode,
|
848 |
+
qkv_bias=qkv_bias,
|
849 |
+
pre_only=i == depth - 1,
|
850 |
+
rmsnorm=rmsnorm,
|
851 |
+
scale_mod_only=scale_mod_only,
|
852 |
+
swiglu=swiglu,
|
853 |
+
qk_norm=qk_norm,
|
854 |
+
x_block_self_attn=(i in self.x_block_self_attn_layers),
|
855 |
+
)
|
856 |
+
for i in range(depth)
|
857 |
+
]
|
858 |
+
)
|
859 |
+
for block in self.joint_blocks:
|
860 |
+
block.gradient_checkpointing = use_checkpoint
|
861 |
+
|
862 |
+
self.final_layer = UnPatch(self.hidden_size, patch_size, self.out_channels)
|
863 |
+
# self.initialize_weights()
|
864 |
+
|
865 |
+
self.blocks_to_swap = None
|
866 |
+
self.offloader = None
|
867 |
+
self.num_blocks = len(self.joint_blocks)
|
868 |
+
|
869 |
+
def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Optional[list[int]]):
|
870 |
+
self.use_scaled_pos_embed = use_scaled_pos_embed
|
871 |
+
|
872 |
+
if self.use_scaled_pos_embed:
|
873 |
+
# remove pos_embed to free up memory up to 0.4 GB
|
874 |
+
self.pos_embed = None
|
875 |
+
|
876 |
+
# remove duplicates and sort latent sizes in ascending order
|
877 |
+
latent_sizes = list(set(latent_sizes))
|
878 |
+
latent_sizes = sorted(latent_sizes)
|
879 |
+
|
880 |
+
patched_sizes = [latent_size // self.patch_size for latent_size in latent_sizes]
|
881 |
+
|
882 |
+
# calculate value range for each latent area: this is used to determine the pos_emb size from the latent shape
|
883 |
+
max_areas = []
|
884 |
+
for i in range(1, len(patched_sizes)):
|
885 |
+
prev_area = patched_sizes[i - 1] ** 2
|
886 |
+
area = patched_sizes[i] ** 2
|
887 |
+
max_areas.append((prev_area + area) // 2)
|
888 |
+
|
889 |
+
# area of the last latent size, if the latent size exceeds this, error will be raised
|
890 |
+
max_areas.append(int((patched_sizes[-1] * MMDiT.POS_EMBED_MAX_RATIO) ** 2))
|
891 |
+
# print("max_areas", max_areas)
|
892 |
+
|
893 |
+
self.resolution_area_to_latent_size = [(area, latent_size) for area, latent_size in zip(max_areas, patched_sizes)]
|
894 |
+
|
895 |
+
self.resolution_pos_embeds = {}
|
896 |
+
for patched_size in patched_sizes:
|
897 |
+
grid_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO)
|
898 |
+
pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, grid_size, sample_size=patched_size)
|
899 |
+
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)
|
900 |
+
self.resolution_pos_embeds[patched_size] = pos_embed
|
901 |
+
# print(f"pos_embed for {patched_size}x{patched_size} latent size: {pos_embed.shape}")
|
902 |
+
|
903 |
+
else:
|
904 |
+
self.resolution_area_to_latent_size = None
|
905 |
+
self.resolution_pos_embeds = None
|
906 |
+
|
907 |
+
@property
|
908 |
+
def model_type(self):
|
909 |
+
return self._model_type
|
910 |
+
|
911 |
+
@property
|
912 |
+
def device(self):
|
913 |
+
return next(self.parameters()).device
|
914 |
+
|
915 |
+
@property
|
916 |
+
def dtype(self):
|
917 |
+
return next(self.parameters()).dtype
|
918 |
+
|
919 |
+
def enable_gradient_checkpointing(self):
|
920 |
+
self.gradient_checkpointing = True
|
921 |
+
for block in self.joint_blocks:
|
922 |
+
block.enable_gradient_checkpointing()
|
923 |
+
|
924 |
+
def disable_gradient_checkpointing(self):
|
925 |
+
self.gradient_checkpointing = False
|
926 |
+
for block in self.joint_blocks:
|
927 |
+
block.disable_gradient_checkpointing()
|
928 |
+
|
929 |
+
def initialize_weights(self):
|
930 |
+
# TODO: Init context_embedder?
|
931 |
+
# Initialize transformer layers:
|
932 |
+
def _basic_init(module):
|
933 |
+
if isinstance(module, nn.Linear):
|
934 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
935 |
+
if module.bias is not None:
|
936 |
+
nn.init.constant_(module.bias, 0)
|
937 |
+
|
938 |
+
self.apply(_basic_init)
|
939 |
+
|
940 |
+
# Initialize (and freeze) pos_embed by sin-cos embedding
|
941 |
+
if self.pos_embed is not None:
|
942 |
+
pos_embed = get_2d_sincos_pos_embed(
|
943 |
+
self.pos_embed.shape[-1],
|
944 |
+
int(self.pos_embed.shape[-2] ** 0.5),
|
945 |
+
scaling_factor=self.pos_embed_scaling_factor,
|
946 |
+
)
|
947 |
+
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
|
948 |
+
|
949 |
+
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d)
|
950 |
+
w = self.x_embedder.proj.weight.data
|
951 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
952 |
+
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
953 |
+
|
954 |
+
if getattr(self, "y_embedder", None) is not None:
|
955 |
+
nn.init.normal_(self.y_embedder.mlp[0].weight, std=0.02)
|
956 |
+
nn.init.normal_(self.y_embedder.mlp[2].weight, std=0.02)
|
957 |
+
|
958 |
+
# Initialize timestep embedding MLP:
|
959 |
+
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
960 |
+
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
961 |
+
|
962 |
+
# Zero-out adaLN modulation layers in DiT blocks:
|
963 |
+
for block in self.joint_blocks:
|
964 |
+
nn.init.constant_(block.x_block.adaLN_modulation[-1].weight, 0)
|
965 |
+
nn.init.constant_(block.x_block.adaLN_modulation[-1].bias, 0)
|
966 |
+
nn.init.constant_(block.context_block.adaLN_modulation[-1].weight, 0)
|
967 |
+
nn.init.constant_(block.context_block.adaLN_modulation[-1].bias, 0)
|
968 |
+
|
969 |
+
# Zero-out output layers:
|
970 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
971 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
972 |
+
nn.init.constant_(self.final_layer.linear.weight, 0)
|
973 |
+
nn.init.constant_(self.final_layer.linear.bias, 0)
|
974 |
+
|
975 |
+
def set_pos_emb_random_crop_rate(self, rate: float):
|
976 |
+
self.pos_emb_random_crop_rate = rate
|
977 |
+
|
978 |
+
def cropped_pos_embed(self, h, w, device=None, random_crop: bool = False):
|
979 |
+
p = self.x_embedder.patch_size
|
980 |
+
# patched size
|
981 |
+
h = (h + 1) // p
|
982 |
+
w = (w + 1) // p
|
983 |
+
if self.pos_embed is None: # should not happen
|
984 |
+
return get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=device)
|
985 |
+
assert self.pos_embed_max_size is not None
|
986 |
+
assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
|
987 |
+
assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size)
|
988 |
+
|
989 |
+
if not random_crop:
|
990 |
+
top = (self.pos_embed_max_size - h) // 2
|
991 |
+
left = (self.pos_embed_max_size - w) // 2
|
992 |
+
else:
|
993 |
+
top = torch.randint(0, self.pos_embed_max_size - h + 1, (1,)).item()
|
994 |
+
left = torch.randint(0, self.pos_embed_max_size - w + 1, (1,)).item()
|
995 |
+
|
996 |
+
spatial_pos_embed = self.pos_embed.reshape(
|
997 |
+
1,
|
998 |
+
self.pos_embed_max_size,
|
999 |
+
self.pos_embed_max_size,
|
1000 |
+
self.pos_embed.shape[-1],
|
1001 |
+
)
|
1002 |
+
spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
|
1003 |
+
spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
|
1004 |
+
return spatial_pos_embed
|
1005 |
+
|
1006 |
+
def cropped_scaled_pos_embed(self, h, w, device=None, dtype=None, random_crop: bool = False):
|
1007 |
+
p = self.x_embedder.patch_size
|
1008 |
+
# patched size
|
1009 |
+
h = (h + 1) // p
|
1010 |
+
w = (w + 1) // p
|
1011 |
+
|
1012 |
+
# select pos_embed size based on area
|
1013 |
+
area = h * w
|
1014 |
+
patched_size = None
|
1015 |
+
for area_, patched_size_ in self.resolution_area_to_latent_size:
|
1016 |
+
if area <= area_:
|
1017 |
+
patched_size = patched_size_
|
1018 |
+
break
|
1019 |
+
if patched_size is None:
|
1020 |
+
raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.")
|
1021 |
+
|
1022 |
+
pos_embed = self.resolution_pos_embeds[patched_size]
|
1023 |
+
pos_embed_size = round(math.sqrt(pos_embed.shape[1]))
|
1024 |
+
if h > pos_embed_size or w > pos_embed_size:
|
1025 |
+
# # fallback to normal pos_embed
|
1026 |
+
# return self.cropped_pos_embed(h * p, w * p, device=device, random_crop=random_crop)
|
1027 |
+
# extend pos_embed size
|
1028 |
+
logger.warning(
|
1029 |
+
f"Using normal pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide."
|
1030 |
+
)
|
1031 |
+
pos_embed_size = max(h, w)
|
1032 |
+
pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, pos_embed_size, sample_size=patched_size)
|
1033 |
+
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)
|
1034 |
+
self.resolution_pos_embeds[patched_size] = pos_embed
|
1035 |
+
logger.info(f"Updated pos_embed for size {pos_embed_size}x{pos_embed_size}")
|
1036 |
+
|
1037 |
+
if not random_crop:
|
1038 |
+
top = (pos_embed_size - h) // 2
|
1039 |
+
left = (pos_embed_size - w) // 2
|
1040 |
+
else:
|
1041 |
+
top = torch.randint(0, pos_embed_size - h + 1, (1,)).item()
|
1042 |
+
left = torch.randint(0, pos_embed_size - w + 1, (1,)).item()
|
1043 |
+
|
1044 |
+
if pos_embed.device != device:
|
1045 |
+
pos_embed = pos_embed.to(device)
|
1046 |
+
# which is better to update device, or transfer every time to device? -> 64x64 emb is 96*96*1536*4=56MB. It's okay to update device.
|
1047 |
+
self.resolution_pos_embeds[patched_size] = pos_embed # update device
|
1048 |
+
if pos_embed.dtype != dtype:
|
1049 |
+
pos_embed = pos_embed.to(dtype)
|
1050 |
+
self.resolution_pos_embeds[patched_size] = pos_embed # update dtype
|
1051 |
+
|
1052 |
+
spatial_pos_embed = pos_embed.reshape(1, pos_embed_size, pos_embed_size, pos_embed.shape[-1])
|
1053 |
+
spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
|
1054 |
+
spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
|
1055 |
+
# print(
|
1056 |
+
# f"patched size: {h}x{w}, pos_embed size: {pos_embed_size}, pos_embed shape: {pos_embed.shape}, top: {top}, left: {left}"
|
1057 |
+
# )
|
1058 |
+
return spatial_pos_embed
|
1059 |
+
|
1060 |
+
def enable_block_swap(self, num_blocks: int, device: torch.device):
|
1061 |
+
self.blocks_to_swap = num_blocks
|
1062 |
+
|
1063 |
+
assert (
|
1064 |
+
self.blocks_to_swap <= self.num_blocks - 2
|
1065 |
+
), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks."
|
1066 |
+
|
1067 |
+
self.offloader = custom_offloading_utils.ModelOffloader(
|
1068 |
+
self.joint_blocks, self.num_blocks, self.blocks_to_swap, device # , debug=True
|
1069 |
+
)
|
1070 |
+
print(f"SD3: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.")
|
1071 |
+
|
1072 |
+
def move_to_device_except_swap_blocks(self, device: torch.device):
|
1073 |
+
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
|
1074 |
+
if self.blocks_to_swap:
|
1075 |
+
save_blocks = self.joint_blocks
|
1076 |
+
self.joint_blocks = None
|
1077 |
+
|
1078 |
+
self.to(device)
|
1079 |
+
|
1080 |
+
if self.blocks_to_swap:
|
1081 |
+
self.joint_blocks = save_blocks
|
1082 |
+
|
1083 |
+
def prepare_block_swap_before_forward(self):
|
1084 |
+
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
1085 |
+
return
|
1086 |
+
self.offloader.prepare_block_devices_before_forward(self.joint_blocks)
|
1087 |
+
|
1088 |
+
def forward(
|
1089 |
+
self,
|
1090 |
+
x: torch.Tensor,
|
1091 |
+
t: torch.Tensor,
|
1092 |
+
y: Optional[torch.Tensor] = None,
|
1093 |
+
context: Optional[torch.Tensor] = None,
|
1094 |
+
) -> torch.Tensor:
|
1095 |
+
"""
|
1096 |
+
Forward pass of DiT.
|
1097 |
+
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
1098 |
+
t: (N,) tensor of diffusion timesteps
|
1099 |
+
y: (N, D) tensor of class labels
|
1100 |
+
"""
|
1101 |
+
pos_emb_random_crop = (
|
1102 |
+
False if self.pos_emb_random_crop_rate == 0.0 else torch.rand(1).item() < self.pos_emb_random_crop_rate
|
1103 |
+
)
|
1104 |
+
|
1105 |
+
B, C, H, W = x.shape
|
1106 |
+
|
1107 |
+
# x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype)
|
1108 |
+
if not self.use_scaled_pos_embed:
|
1109 |
+
pos_embed = self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype)
|
1110 |
+
else:
|
1111 |
+
# print(f"Using scaled pos_embed for size {H}x{W}")
|
1112 |
+
pos_embed = self.cropped_scaled_pos_embed(H, W, device=x.device, dtype=x.dtype, random_crop=pos_emb_random_crop)
|
1113 |
+
x = self.x_embedder(x) + pos_embed
|
1114 |
+
del pos_embed
|
1115 |
+
|
1116 |
+
c = self.t_embedder(t, dtype=x.dtype) # (N, D)
|
1117 |
+
if y is not None and self.y_embedder is not None:
|
1118 |
+
y = self.y_embedder(y) # (N, D)
|
1119 |
+
c = c + y # (N, D)
|
1120 |
+
|
1121 |
+
if context is not None:
|
1122 |
+
context = self.context_embedder(context)
|
1123 |
+
|
1124 |
+
if self.register_length > 0:
|
1125 |
+
context = torch.cat(
|
1126 |
+
(einops.repeat(self.register, "1 ... -> b ...", b=x.shape[0]), default(context, torch.Tensor([]).type_as(x))), 1
|
1127 |
+
)
|
1128 |
+
|
1129 |
+
if not self.blocks_to_swap:
|
1130 |
+
for block in self.joint_blocks:
|
1131 |
+
context, x = block(context, x, c)
|
1132 |
+
else:
|
1133 |
+
for block_idx, block in enumerate(self.joint_blocks):
|
1134 |
+
self.offloader.wait_for_block(block_idx)
|
1135 |
+
|
1136 |
+
context, x = block(context, x, c)
|
1137 |
+
|
1138 |
+
self.offloader.submit_move_blocks(self.joint_blocks, block_idx)
|
1139 |
+
|
1140 |
+
x = self.final_layer(x, c, H, W) # Our final layer combined UnPatchify
|
1141 |
+
return x[:, :, :H, :W]
|
1142 |
+
|
1143 |
+
|
1144 |
+
def create_sd3_mmdit(params: SD3Params, attn_mode: str = "torch") -> MMDiT:
|
1145 |
+
mmdit = MMDiT(
|
1146 |
+
input_size=None,
|
1147 |
+
pos_embed_max_size=params.pos_embed_max_size,
|
1148 |
+
patch_size=params.patch_size,
|
1149 |
+
in_channels=16,
|
1150 |
+
adm_in_channels=params.adm_in_channels,
|
1151 |
+
context_embedder_in_features=params.context_embedder_in_features,
|
1152 |
+
context_embedder_out_features=params.context_embedder_out_features,
|
1153 |
+
depth=params.depth,
|
1154 |
+
mlp_ratio=4,
|
1155 |
+
qk_norm=params.qk_norm,
|
1156 |
+
x_block_self_attn_layers=params.x_block_self_attn_layers,
|
1157 |
+
num_patches=params.num_patches,
|
1158 |
+
attn_mode=attn_mode,
|
1159 |
+
model_type=params.model_type,
|
1160 |
+
)
|
1161 |
+
return mmdit
|
1162 |
+
|
1163 |
+
|
1164 |
+
# endregion
|
1165 |
+
|
1166 |
+
# region VAE
|
1167 |
+
|
1168 |
+
VAE_SCALE_FACTOR = 1.5305
|
1169 |
+
VAE_SHIFT_FACTOR = 0.0609
|
1170 |
+
|
1171 |
+
|
1172 |
+
def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None):
|
1173 |
+
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
|
1174 |
+
|
1175 |
+
|
1176 |
+
class ResnetBlock(torch.nn.Module):
|
1177 |
+
def __init__(self, *, in_channels, out_channels=None, dtype=torch.float32, device=None):
|
1178 |
+
super().__init__()
|
1179 |
+
self.in_channels = in_channels
|
1180 |
+
out_channels = in_channels if out_channels is None else out_channels
|
1181 |
+
self.out_channels = out_channels
|
1182 |
+
|
1183 |
+
self.norm1 = Normalize(in_channels, dtype=dtype, device=device)
|
1184 |
+
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
|
1185 |
+
self.norm2 = Normalize(out_channels, dtype=dtype, device=device)
|
1186 |
+
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
|
1187 |
+
if self.in_channels != self.out_channels:
|
1188 |
+
self.nin_shortcut = torch.nn.Conv2d(
|
1189 |
+
in_channels, out_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device
|
1190 |
+
)
|
1191 |
+
else:
|
1192 |
+
self.nin_shortcut = None
|
1193 |
+
self.swish = torch.nn.SiLU(inplace=True)
|
1194 |
+
|
1195 |
+
def forward(self, x):
|
1196 |
+
hidden = x
|
1197 |
+
hidden = self.norm1(hidden)
|
1198 |
+
hidden = self.swish(hidden)
|
1199 |
+
hidden = self.conv1(hidden)
|
1200 |
+
hidden = self.norm2(hidden)
|
1201 |
+
hidden = self.swish(hidden)
|
1202 |
+
hidden = self.conv2(hidden)
|
1203 |
+
if self.in_channels != self.out_channels:
|
1204 |
+
x = self.nin_shortcut(x)
|
1205 |
+
return x + hidden
|
1206 |
+
|
1207 |
+
|
1208 |
+
class AttnBlock(torch.nn.Module):
|
1209 |
+
def __init__(self, in_channels, dtype=torch.float32, device=None):
|
1210 |
+
super().__init__()
|
1211 |
+
self.norm = Normalize(in_channels, dtype=dtype, device=device)
|
1212 |
+
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
|
1213 |
+
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
|
1214 |
+
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
|
1215 |
+
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
|
1216 |
+
|
1217 |
+
def forward(self, x):
|
1218 |
+
hidden = self.norm(x)
|
1219 |
+
q = self.q(hidden)
|
1220 |
+
k = self.k(hidden)
|
1221 |
+
v = self.v(hidden)
|
1222 |
+
b, c, h, w = q.shape
|
1223 |
+
q, k, v = map(lambda x: einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v))
|
1224 |
+
hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default
|
1225 |
+
hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
1226 |
+
hidden = self.proj_out(hidden)
|
1227 |
+
return x + hidden
|
1228 |
+
|
1229 |
+
|
1230 |
+
class Downsample(torch.nn.Module):
|
1231 |
+
def __init__(self, in_channels, dtype=torch.float32, device=None):
|
1232 |
+
super().__init__()
|
1233 |
+
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0, dtype=dtype, device=device)
|
1234 |
+
|
1235 |
+
def forward(self, x):
|
1236 |
+
pad = (0, 1, 0, 1)
|
1237 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
1238 |
+
x = self.conv(x)
|
1239 |
+
return x
|
1240 |
+
|
1241 |
+
|
1242 |
+
class Upsample(torch.nn.Module):
|
1243 |
+
def __init__(self, in_channels, dtype=torch.float32, device=None):
|
1244 |
+
super().__init__()
|
1245 |
+
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
|
1246 |
+
|
1247 |
+
def forward(self, x):
|
1248 |
+
org_dtype = x.dtype
|
1249 |
+
if x.dtype == torch.bfloat16:
|
1250 |
+
x = x.to(torch.float32)
|
1251 |
+
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
1252 |
+
if x.dtype != org_dtype:
|
1253 |
+
x = x.to(org_dtype)
|
1254 |
+
x = self.conv(x)
|
1255 |
+
return x
|
1256 |
+
|
1257 |
+
|
1258 |
+
class VAEEncoder(torch.nn.Module):
|
1259 |
+
def __init__(
|
1260 |
+
self, ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2, in_channels=3, z_channels=16, dtype=torch.float32, device=None
|
1261 |
+
):
|
1262 |
+
super().__init__()
|
1263 |
+
self.num_resolutions = len(ch_mult)
|
1264 |
+
self.num_res_blocks = num_res_blocks
|
1265 |
+
# downsampling
|
1266 |
+
self.conv_in = torch.nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
|
1267 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
1268 |
+
self.in_ch_mult = in_ch_mult
|
1269 |
+
self.down = torch.nn.ModuleList()
|
1270 |
+
for i_level in range(self.num_resolutions):
|
1271 |
+
block = torch.nn.ModuleList()
|
1272 |
+
attn = torch.nn.ModuleList()
|
1273 |
+
block_in = ch * in_ch_mult[i_level]
|
1274 |
+
block_out = ch * ch_mult[i_level]
|
1275 |
+
for i_block in range(num_res_blocks):
|
1276 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
|
1277 |
+
block_in = block_out
|
1278 |
+
down = torch.nn.Module()
|
1279 |
+
down.block = block
|
1280 |
+
down.attn = attn
|
1281 |
+
if i_level != self.num_resolutions - 1:
|
1282 |
+
down.downsample = Downsample(block_in, dtype=dtype, device=device)
|
1283 |
+
self.down.append(down)
|
1284 |
+
# middle
|
1285 |
+
self.mid = torch.nn.Module()
|
1286 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
|
1287 |
+
self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
|
1288 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
|
1289 |
+
# end
|
1290 |
+
self.norm_out = Normalize(block_in, dtype=dtype, device=device)
|
1291 |
+
self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
|
1292 |
+
self.swish = torch.nn.SiLU(inplace=True)
|
1293 |
+
|
1294 |
+
def forward(self, x):
|
1295 |
+
# downsampling
|
1296 |
+
hs = [self.conv_in(x)]
|
1297 |
+
for i_level in range(self.num_resolutions):
|
1298 |
+
for i_block in range(self.num_res_blocks):
|
1299 |
+
h = self.down[i_level].block[i_block](hs[-1])
|
1300 |
+
hs.append(h)
|
1301 |
+
if i_level != self.num_resolutions - 1:
|
1302 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
1303 |
+
# middle
|
1304 |
+
h = hs[-1]
|
1305 |
+
h = self.mid.block_1(h)
|
1306 |
+
h = self.mid.attn_1(h)
|
1307 |
+
h = self.mid.block_2(h)
|
1308 |
+
# end
|
1309 |
+
h = self.norm_out(h)
|
1310 |
+
h = self.swish(h)
|
1311 |
+
h = self.conv_out(h)
|
1312 |
+
return h
|
1313 |
+
|
1314 |
+
|
1315 |
+
class VAEDecoder(torch.nn.Module):
|
1316 |
+
def __init__(
|
1317 |
+
self,
|
1318 |
+
ch=128,
|
1319 |
+
out_ch=3,
|
1320 |
+
ch_mult=(1, 2, 4, 4),
|
1321 |
+
num_res_blocks=2,
|
1322 |
+
resolution=256,
|
1323 |
+
z_channels=16,
|
1324 |
+
dtype=torch.float32,
|
1325 |
+
device=None,
|
1326 |
+
):
|
1327 |
+
super().__init__()
|
1328 |
+
self.num_resolutions = len(ch_mult)
|
1329 |
+
self.num_res_blocks = num_res_blocks
|
1330 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
1331 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
1332 |
+
# z to block_in
|
1333 |
+
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
|
1334 |
+
# middle
|
1335 |
+
self.mid = torch.nn.Module()
|
1336 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
|
1337 |
+
self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
|
1338 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
|
1339 |
+
# upsampling
|
1340 |
+
self.up = torch.nn.ModuleList()
|
1341 |
+
for i_level in reversed(range(self.num_resolutions)):
|
1342 |
+
block = torch.nn.ModuleList()
|
1343 |
+
block_out = ch * ch_mult[i_level]
|
1344 |
+
for i_block in range(self.num_res_blocks + 1):
|
1345 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
|
1346 |
+
block_in = block_out
|
1347 |
+
up = torch.nn.Module()
|
1348 |
+
up.block = block
|
1349 |
+
if i_level != 0:
|
1350 |
+
up.upsample = Upsample(block_in, dtype=dtype, device=device)
|
1351 |
+
curr_res = curr_res * 2
|
1352 |
+
self.up.insert(0, up) # prepend to get consistent order
|
1353 |
+
# end
|
1354 |
+
self.norm_out = Normalize(block_in, dtype=dtype, device=device)
|
1355 |
+
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
|
1356 |
+
self.swish = torch.nn.SiLU(inplace=True)
|
1357 |
+
|
1358 |
+
def forward(self, z):
|
1359 |
+
# z to block_in
|
1360 |
+
hidden = self.conv_in(z)
|
1361 |
+
# middle
|
1362 |
+
hidden = self.mid.block_1(hidden)
|
1363 |
+
hidden = self.mid.attn_1(hidden)
|
1364 |
+
hidden = self.mid.block_2(hidden)
|
1365 |
+
# upsampling
|
1366 |
+
for i_level in reversed(range(self.num_resolutions)):
|
1367 |
+
for i_block in range(self.num_res_blocks + 1):
|
1368 |
+
hidden = self.up[i_level].block[i_block](hidden)
|
1369 |
+
if i_level != 0:
|
1370 |
+
hidden = self.up[i_level].upsample(hidden)
|
1371 |
+
# end
|
1372 |
+
hidden = self.norm_out(hidden)
|
1373 |
+
hidden = self.swish(hidden)
|
1374 |
+
hidden = self.conv_out(hidden)
|
1375 |
+
return hidden
|
1376 |
+
|
1377 |
+
|
1378 |
+
class SDVAE(torch.nn.Module):
|
1379 |
+
def __init__(self, dtype=torch.float32, device=None):
|
1380 |
+
super().__init__()
|
1381 |
+
self.encoder = VAEEncoder(dtype=dtype, device=device)
|
1382 |
+
self.decoder = VAEDecoder(dtype=dtype, device=device)
|
1383 |
+
|
1384 |
+
@property
|
1385 |
+
def device(self):
|
1386 |
+
return next(self.parameters()).device
|
1387 |
+
|
1388 |
+
@property
|
1389 |
+
def dtype(self):
|
1390 |
+
return next(self.parameters()).dtype
|
1391 |
+
|
1392 |
+
# @torch.autocast("cuda", dtype=torch.float16)
|
1393 |
+
def decode(self, latent):
|
1394 |
+
return self.decoder(latent)
|
1395 |
+
|
1396 |
+
# @torch.autocast("cuda", dtype=torch.float16)
|
1397 |
+
def encode(self, image):
|
1398 |
+
hidden = self.encoder(image)
|
1399 |
+
mean, logvar = torch.chunk(hidden, 2, dim=1)
|
1400 |
+
logvar = torch.clamp(logvar, -30.0, 20.0)
|
1401 |
+
std = torch.exp(0.5 * logvar)
|
1402 |
+
return mean + std * torch.randn_like(mean)
|
1403 |
+
|
1404 |
+
@staticmethod
|
1405 |
+
def process_in(latent):
|
1406 |
+
return (latent - VAE_SHIFT_FACTOR) * VAE_SCALE_FACTOR
|
1407 |
+
|
1408 |
+
@staticmethod
|
1409 |
+
def process_out(latent):
|
1410 |
+
return (latent / VAE_SCALE_FACTOR) + VAE_SHIFT_FACTOR
|
1411 |
+
|
1412 |
+
|
1413 |
+
# endregion
|
library/sd3_train_utils.py
ADDED
@@ -0,0 +1,945 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import math
|
3 |
+
import os
|
4 |
+
import toml
|
5 |
+
import json
|
6 |
+
import time
|
7 |
+
from typing import Dict, List, Optional, Tuple, Union
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from safetensors.torch import save_file
|
11 |
+
from accelerate import Accelerator, PartialState
|
12 |
+
from tqdm import tqdm
|
13 |
+
from PIL import Image
|
14 |
+
from transformers import CLIPTextModelWithProjection, T5EncoderModel
|
15 |
+
|
16 |
+
from library.device_utils import init_ipex, clean_memory_on_device
|
17 |
+
|
18 |
+
init_ipex()
|
19 |
+
|
20 |
+
# from transformers import CLIPTokenizer
|
21 |
+
# from library import model_util
|
22 |
+
# , sdxl_model_util, train_util, sdxl_original_unet
|
23 |
+
# from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
|
24 |
+
from .utils import setup_logging
|
25 |
+
|
26 |
+
setup_logging()
|
27 |
+
import logging
|
28 |
+
|
29 |
+
logger = logging.getLogger(__name__)
|
30 |
+
|
31 |
+
from library import sd3_models, sd3_utils, strategy_base, train_util
|
32 |
+
|
33 |
+
|
34 |
+
def save_models(
|
35 |
+
ckpt_path: str,
|
36 |
+
mmdit: Optional[sd3_models.MMDiT],
|
37 |
+
vae: Optional[sd3_models.SDVAE],
|
38 |
+
clip_l: Optional[CLIPTextModelWithProjection],
|
39 |
+
clip_g: Optional[CLIPTextModelWithProjection],
|
40 |
+
t5xxl: Optional[T5EncoderModel],
|
41 |
+
sai_metadata: Optional[dict],
|
42 |
+
save_dtype: Optional[torch.dtype] = None,
|
43 |
+
):
|
44 |
+
r"""
|
45 |
+
Save models to checkpoint file. Only supports unified checkpoint format.
|
46 |
+
"""
|
47 |
+
|
48 |
+
state_dict = {}
|
49 |
+
|
50 |
+
def update_sd(prefix, sd):
|
51 |
+
for k, v in sd.items():
|
52 |
+
key = prefix + k
|
53 |
+
if save_dtype is not None:
|
54 |
+
v = v.detach().clone().to("cpu").to(save_dtype)
|
55 |
+
state_dict[key] = v
|
56 |
+
|
57 |
+
update_sd("model.diffusion_model.", mmdit.state_dict())
|
58 |
+
update_sd("first_stage_model.", vae.state_dict())
|
59 |
+
|
60 |
+
# do not support unified checkpoint format for now
|
61 |
+
# if clip_l is not None:
|
62 |
+
# update_sd("text_encoders.clip_l.", clip_l.state_dict())
|
63 |
+
# if clip_g is not None:
|
64 |
+
# update_sd("text_encoders.clip_g.", clip_g.state_dict())
|
65 |
+
# if t5xxl is not None:
|
66 |
+
# update_sd("text_encoders.t5xxl.", t5xxl.state_dict())
|
67 |
+
|
68 |
+
save_file(state_dict, ckpt_path, metadata=sai_metadata)
|
69 |
+
|
70 |
+
if clip_l is not None:
|
71 |
+
clip_l_path = ckpt_path.replace(".safetensors", "_clip_l.safetensors")
|
72 |
+
save_file(clip_l.state_dict(), clip_l_path)
|
73 |
+
if clip_g is not None:
|
74 |
+
clip_g_path = ckpt_path.replace(".safetensors", "_clip_g.safetensors")
|
75 |
+
save_file(clip_g.state_dict(), clip_g_path)
|
76 |
+
if t5xxl is not None:
|
77 |
+
t5xxl_path = ckpt_path.replace(".safetensors", "_t5xxl.safetensors")
|
78 |
+
t5xxl_state_dict = t5xxl.state_dict()
|
79 |
+
|
80 |
+
# replace "shared.weight" with copy of it to avoid annoying shared tensor error on safetensors.save_file
|
81 |
+
shared_weight = t5xxl_state_dict["shared.weight"]
|
82 |
+
shared_weight_copy = shared_weight.detach().clone()
|
83 |
+
t5xxl_state_dict["shared.weight"] = shared_weight_copy
|
84 |
+
|
85 |
+
save_file(t5xxl_state_dict, t5xxl_path)
|
86 |
+
|
87 |
+
|
88 |
+
def save_sd3_model_on_train_end(
|
89 |
+
args: argparse.Namespace,
|
90 |
+
save_dtype: torch.dtype,
|
91 |
+
epoch: int,
|
92 |
+
global_step: int,
|
93 |
+
clip_l: Optional[CLIPTextModelWithProjection],
|
94 |
+
clip_g: Optional[CLIPTextModelWithProjection],
|
95 |
+
t5xxl: Optional[T5EncoderModel],
|
96 |
+
mmdit: sd3_models.MMDiT,
|
97 |
+
vae: sd3_models.SDVAE,
|
98 |
+
):
|
99 |
+
def sd_saver(ckpt_file, epoch_no, global_step):
|
100 |
+
sai_metadata = train_util.get_sai_model_spec(
|
101 |
+
None, args, False, False, False, is_stable_diffusion_ckpt=True, sd3=mmdit.model_type
|
102 |
+
)
|
103 |
+
save_models(ckpt_file, mmdit, vae, clip_l, clip_g, t5xxl, sai_metadata, save_dtype)
|
104 |
+
|
105 |
+
train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
|
106 |
+
|
107 |
+
|
108 |
+
# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している
|
109 |
+
# on_epoch_end: Trueならepoch終了時、Falseならstep経過時
|
110 |
+
def save_sd3_model_on_epoch_end_or_stepwise(
|
111 |
+
args: argparse.Namespace,
|
112 |
+
on_epoch_end: bool,
|
113 |
+
accelerator,
|
114 |
+
save_dtype: torch.dtype,
|
115 |
+
epoch: int,
|
116 |
+
num_train_epochs: int,
|
117 |
+
global_step: int,
|
118 |
+
clip_l: Optional[CLIPTextModelWithProjection],
|
119 |
+
clip_g: Optional[CLIPTextModelWithProjection],
|
120 |
+
t5xxl: Optional[T5EncoderModel],
|
121 |
+
mmdit: sd3_models.MMDiT,
|
122 |
+
vae: sd3_models.SDVAE,
|
123 |
+
):
|
124 |
+
def sd_saver(ckpt_file, epoch_no, global_step):
|
125 |
+
sai_metadata = train_util.get_sai_model_spec(
|
126 |
+
None, args, False, False, False, is_stable_diffusion_ckpt=True, sd3=mmdit.model_type
|
127 |
+
)
|
128 |
+
save_models(ckpt_file, mmdit, vae, clip_l, clip_g, t5xxl, sai_metadata, save_dtype)
|
129 |
+
|
130 |
+
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
|
131 |
+
args,
|
132 |
+
on_epoch_end,
|
133 |
+
accelerator,
|
134 |
+
True,
|
135 |
+
True,
|
136 |
+
epoch,
|
137 |
+
num_train_epochs,
|
138 |
+
global_step,
|
139 |
+
sd_saver,
|
140 |
+
None,
|
141 |
+
)
|
142 |
+
|
143 |
+
|
144 |
+
def add_sd3_training_arguments(parser: argparse.ArgumentParser):
|
145 |
+
parser.add_argument(
|
146 |
+
"--clip_l",
|
147 |
+
type=str,
|
148 |
+
required=False,
|
149 |
+
help="CLIP-L model path. if not specified, use ckpt's state_dict / CLIP-Lモデルのパス。指定しない場合はckptのstate_dictを使用",
|
150 |
+
)
|
151 |
+
parser.add_argument(
|
152 |
+
"--clip_g",
|
153 |
+
type=str,
|
154 |
+
required=False,
|
155 |
+
help="CLIP-G model path. if not specified, use ckpt's state_dict / CLIP-Gモデルのパス。指定しない場合はckptのstate_dictを使用",
|
156 |
+
)
|
157 |
+
parser.add_argument(
|
158 |
+
"--t5xxl",
|
159 |
+
type=str,
|
160 |
+
required=False,
|
161 |
+
help="T5-XXL model path. if not specified, use ckpt's state_dict / T5-XXLモデルのパス。指定しない場合はckptのstate_dictを使用",
|
162 |
+
)
|
163 |
+
parser.add_argument(
|
164 |
+
"--save_clip",
|
165 |
+
action="store_true",
|
166 |
+
help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません",
|
167 |
+
)
|
168 |
+
parser.add_argument(
|
169 |
+
"--save_t5xxl",
|
170 |
+
action="store_true",
|
171 |
+
help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません",
|
172 |
+
)
|
173 |
+
|
174 |
+
parser.add_argument(
|
175 |
+
"--t5xxl_device",
|
176 |
+
type=str,
|
177 |
+
default=None,
|
178 |
+
help="[DOES NOT WORK] not supported yet. T5-XXL device. if not specified, use accelerator's device / T5-XXLデバイス。指定しない場合はacceleratorのデバイスを使用",
|
179 |
+
)
|
180 |
+
parser.add_argument(
|
181 |
+
"--t5xxl_dtype",
|
182 |
+
type=str,
|
183 |
+
default=None,
|
184 |
+
help="[DOES NOT WORK] not supported yet. T5-XXL dtype. if not specified, use default dtype (from mixed precision) / T5-XXL dtype。指定しない場合はデフォルトのdtype(mixed precisionから)を使用",
|
185 |
+
)
|
186 |
+
|
187 |
+
parser.add_argument(
|
188 |
+
"--t5xxl_max_token_length",
|
189 |
+
type=int,
|
190 |
+
default=256,
|
191 |
+
help="maximum token length for T5-XXL. 256 is the default value / T5-XXLの最大トークン長。デフォルトは256",
|
192 |
+
)
|
193 |
+
parser.add_argument(
|
194 |
+
"--apply_lg_attn_mask",
|
195 |
+
action="store_true",
|
196 |
+
help="apply attention mask (zero embs) to CLIP-L and G / CLIP-LとGにアテンションマスク(ゼロ埋め)を適用する",
|
197 |
+
)
|
198 |
+
parser.add_argument(
|
199 |
+
"--apply_t5_attn_mask",
|
200 |
+
action="store_true",
|
201 |
+
help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する",
|
202 |
+
)
|
203 |
+
parser.add_argument(
|
204 |
+
"--clip_l_dropout_rate",
|
205 |
+
type=float,
|
206 |
+
default=0.0,
|
207 |
+
help="Dropout rate for CLIP-L encoder, default is 0.0 / CLIP-Lエンコーダのドロップアウト率、デフォルトは0.0",
|
208 |
+
)
|
209 |
+
parser.add_argument(
|
210 |
+
"--clip_g_dropout_rate",
|
211 |
+
type=float,
|
212 |
+
default=0.0,
|
213 |
+
help="Dropout rate for CLIP-G encoder, default is 0.0 / CLIP-Gエンコーダのドロップアウト率、デフォルトは0.0",
|
214 |
+
)
|
215 |
+
parser.add_argument(
|
216 |
+
"--t5_dropout_rate",
|
217 |
+
type=float,
|
218 |
+
default=0.0,
|
219 |
+
help="Dropout rate for T5 encoder, default is 0.0 / T5エンコーダのドロップアウト率、デフォルトは0.0",
|
220 |
+
)
|
221 |
+
parser.add_argument(
|
222 |
+
"--pos_emb_random_crop_rate",
|
223 |
+
type=float,
|
224 |
+
default=0.0,
|
225 |
+
help="Random crop rate for positional embeddings, default is 0.0. Only for SD3.5M"
|
226 |
+
" / 位置埋め込みのランダムクロップ率、デフォルトは0.0。SD3.5M以外では予期しない動作になります",
|
227 |
+
)
|
228 |
+
parser.add_argument(
|
229 |
+
"--enable_scaled_pos_embed",
|
230 |
+
action="store_true",
|
231 |
+
help="Scale position embeddings for each resolution during multi-resolution training. Only for SD3.5M"
|
232 |
+
" / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります",
|
233 |
+
)
|
234 |
+
|
235 |
+
# Dependencies of Diffusers noise sampler has been removed for clarity in training
|
236 |
+
|
237 |
+
parser.add_argument(
|
238 |
+
"--training_shift",
|
239 |
+
type=float,
|
240 |
+
default=1.0,
|
241 |
+
help="Discrete flow shift for training timestep distribution adjustment, applied in addition to the weighting scheme, default is 1.0. /タイムステップ分布のための離散フローシフト、重み付けスキームの上に適用される、デフォルトは1.0。",
|
242 |
+
)
|
243 |
+
|
244 |
+
|
245 |
+
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
|
246 |
+
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
|
247 |
+
if args.v_parameterization:
|
248 |
+
logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
|
249 |
+
|
250 |
+
if args.clip_skip is not None:
|
251 |
+
logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
|
252 |
+
|
253 |
+
# if args.multires_noise_iterations:
|
254 |
+
# logger.info(
|
255 |
+
# f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります"
|
256 |
+
# )
|
257 |
+
# else:
|
258 |
+
# if args.noise_offset is None:
|
259 |
+
# args.noise_offset = DEFAULT_NOISE_OFFSET
|
260 |
+
# elif args.noise_offset != DEFAULT_NOISE_OFFSET:
|
261 |
+
# logger.info(
|
262 |
+
# f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
|
263 |
+
# )
|
264 |
+
# logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
|
265 |
+
|
266 |
+
assert (
|
267 |
+
not hasattr(args, "weighted_captions") or not args.weighted_captions
|
268 |
+
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
|
269 |
+
|
270 |
+
if supportTextEncoderCaching:
|
271 |
+
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
272 |
+
args.cache_text_encoder_outputs = True
|
273 |
+
logger.warning(
|
274 |
+
"cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
|
275 |
+
+ "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
|
276 |
+
)
|
277 |
+
|
278 |
+
|
279 |
+
# temporary copied from sd3_minimal_inferece.py
|
280 |
+
|
281 |
+
|
282 |
+
def get_all_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps):
|
283 |
+
start = sampling.timestep(sampling.sigma_max)
|
284 |
+
end = sampling.timestep(sampling.sigma_min)
|
285 |
+
timesteps = torch.linspace(start, end, steps)
|
286 |
+
sigs = []
|
287 |
+
for x in range(len(timesteps)):
|
288 |
+
ts = timesteps[x]
|
289 |
+
sigs.append(sampling.sigma(ts))
|
290 |
+
sigs += [0.0]
|
291 |
+
return torch.FloatTensor(sigs)
|
292 |
+
|
293 |
+
|
294 |
+
def max_denoise(model_sampling, sigmas):
|
295 |
+
max_sigma = float(model_sampling.sigma_max)
|
296 |
+
sigma = float(sigmas[0])
|
297 |
+
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
298 |
+
|
299 |
+
|
300 |
+
def do_sample(
|
301 |
+
height: int,
|
302 |
+
width: int,
|
303 |
+
seed: int,
|
304 |
+
cond: Tuple[torch.Tensor, torch.Tensor],
|
305 |
+
neg_cond: Tuple[torch.Tensor, torch.Tensor],
|
306 |
+
mmdit: sd3_models.MMDiT,
|
307 |
+
steps: int,
|
308 |
+
guidance_scale: float,
|
309 |
+
dtype: torch.dtype,
|
310 |
+
device: str,
|
311 |
+
):
|
312 |
+
latent = torch.zeros(1, 16, height // 8, width // 8, device=device)
|
313 |
+
latent = latent.to(dtype).to(device)
|
314 |
+
|
315 |
+
# noise = get_noise(seed, latent).to(device)
|
316 |
+
if seed is not None:
|
317 |
+
generator = torch.manual_seed(seed)
|
318 |
+
else:
|
319 |
+
generator = None
|
320 |
+
noise = (
|
321 |
+
torch.randn(latent.size(), dtype=torch.float32, layout=latent.layout, generator=generator, device="cpu")
|
322 |
+
.to(latent.dtype)
|
323 |
+
.to(device)
|
324 |
+
)
|
325 |
+
|
326 |
+
model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3
|
327 |
+
|
328 |
+
sigmas = get_all_sigmas(model_sampling, steps).to(device)
|
329 |
+
|
330 |
+
noise_scaled = model_sampling.noise_scaling(sigmas[0], noise, latent, max_denoise(model_sampling, sigmas))
|
331 |
+
|
332 |
+
c_crossattn = torch.cat([cond[0], neg_cond[0]]).to(device).to(dtype)
|
333 |
+
y = torch.cat([cond[1], neg_cond[1]]).to(device).to(dtype)
|
334 |
+
|
335 |
+
x = noise_scaled.to(device).to(dtype)
|
336 |
+
# print(x.shape)
|
337 |
+
|
338 |
+
# with torch.no_grad():
|
339 |
+
for i in tqdm(range(len(sigmas) - 1)):
|
340 |
+
sigma_hat = sigmas[i]
|
341 |
+
|
342 |
+
timestep = model_sampling.timestep(sigma_hat).float()
|
343 |
+
timestep = torch.FloatTensor([timestep, timestep]).to(device)
|
344 |
+
|
345 |
+
x_c_nc = torch.cat([x, x], dim=0)
|
346 |
+
# print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape)
|
347 |
+
|
348 |
+
mmdit.prepare_block_swap_before_forward()
|
349 |
+
model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y)
|
350 |
+
model_output = model_output.float()
|
351 |
+
batched = model_sampling.calculate_denoised(sigma_hat, model_output, x)
|
352 |
+
|
353 |
+
pos_out, neg_out = batched.chunk(2)
|
354 |
+
denoised = neg_out + (pos_out - neg_out) * guidance_scale
|
355 |
+
# print(denoised.shape)
|
356 |
+
|
357 |
+
# d = to_d(x, sigma_hat, denoised)
|
358 |
+
dims_to_append = x.ndim - sigma_hat.ndim
|
359 |
+
sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append]
|
360 |
+
# print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape)
|
361 |
+
"""Converts a denoiser output to a Karras ODE derivative."""
|
362 |
+
d = (x - denoised) / sigma_hat_dims
|
363 |
+
|
364 |
+
dt = sigmas[i + 1] - sigma_hat
|
365 |
+
|
366 |
+
# Euler method
|
367 |
+
x = x + d * dt
|
368 |
+
x = x.to(dtype)
|
369 |
+
|
370 |
+
mmdit.prepare_block_swap_before_forward()
|
371 |
+
return x
|
372 |
+
|
373 |
+
|
374 |
+
def sample_images(
|
375 |
+
accelerator: Accelerator,
|
376 |
+
args: argparse.Namespace,
|
377 |
+
epoch,
|
378 |
+
steps,
|
379 |
+
mmdit,
|
380 |
+
vae,
|
381 |
+
text_encoders,
|
382 |
+
sample_prompts_te_outputs,
|
383 |
+
prompt_replacement=None,
|
384 |
+
):
|
385 |
+
if steps == 0:
|
386 |
+
if not args.sample_at_first:
|
387 |
+
return
|
388 |
+
else:
|
389 |
+
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
|
390 |
+
return
|
391 |
+
if args.sample_every_n_epochs is not None:
|
392 |
+
# sample_every_n_steps は無視する
|
393 |
+
if epoch is None or epoch % args.sample_every_n_epochs != 0:
|
394 |
+
return
|
395 |
+
else:
|
396 |
+
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
|
397 |
+
return
|
398 |
+
|
399 |
+
logger.info("")
|
400 |
+
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
|
401 |
+
if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
|
402 |
+
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
|
403 |
+
return
|
404 |
+
|
405 |
+
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
|
406 |
+
|
407 |
+
# unwrap unet and text_encoder(s)
|
408 |
+
mmdit = accelerator.unwrap_model(mmdit)
|
409 |
+
text_encoders = None if text_encoders is None else [accelerator.unwrap_model(te) for te in text_encoders]
|
410 |
+
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
|
411 |
+
|
412 |
+
prompts = train_util.load_prompts(args.sample_prompts)
|
413 |
+
|
414 |
+
save_dir = args.output_dir + "/sample"
|
415 |
+
os.makedirs(save_dir, exist_ok=True)
|
416 |
+
|
417 |
+
# save random state to restore later
|
418 |
+
rng_state = torch.get_rng_state()
|
419 |
+
cuda_rng_state = None
|
420 |
+
try:
|
421 |
+
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
|
422 |
+
except Exception:
|
423 |
+
pass
|
424 |
+
|
425 |
+
if distributed_state.num_processes <= 1:
|
426 |
+
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
|
427 |
+
with torch.no_grad(), accelerator.autocast():
|
428 |
+
for prompt_dict in prompts:
|
429 |
+
sample_image_inference(
|
430 |
+
accelerator,
|
431 |
+
args,
|
432 |
+
mmdit,
|
433 |
+
text_encoders,
|
434 |
+
vae,
|
435 |
+
save_dir,
|
436 |
+
prompt_dict,
|
437 |
+
epoch,
|
438 |
+
steps,
|
439 |
+
sample_prompts_te_outputs,
|
440 |
+
prompt_replacement,
|
441 |
+
)
|
442 |
+
else:
|
443 |
+
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
|
444 |
+
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
|
445 |
+
per_process_prompts = [] # list of lists
|
446 |
+
for i in range(distributed_state.num_processes):
|
447 |
+
per_process_prompts.append(prompts[i :: distributed_state.num_processes])
|
448 |
+
|
449 |
+
with torch.no_grad():
|
450 |
+
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
|
451 |
+
for prompt_dict in prompt_dict_lists[0]:
|
452 |
+
sample_image_inference(
|
453 |
+
accelerator,
|
454 |
+
args,
|
455 |
+
mmdit,
|
456 |
+
text_encoders,
|
457 |
+
vae,
|
458 |
+
save_dir,
|
459 |
+
prompt_dict,
|
460 |
+
epoch,
|
461 |
+
steps,
|
462 |
+
sample_prompts_te_outputs,
|
463 |
+
prompt_replacement,
|
464 |
+
)
|
465 |
+
|
466 |
+
torch.set_rng_state(rng_state)
|
467 |
+
if cuda_rng_state is not None:
|
468 |
+
torch.cuda.set_rng_state(cuda_rng_state)
|
469 |
+
|
470 |
+
clean_memory_on_device(accelerator.device)
|
471 |
+
|
472 |
+
|
473 |
+
def sample_image_inference(
|
474 |
+
accelerator: Accelerator,
|
475 |
+
args: argparse.Namespace,
|
476 |
+
mmdit: sd3_models.MMDiT,
|
477 |
+
text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]],
|
478 |
+
vae: sd3_models.SDVAE,
|
479 |
+
save_dir,
|
480 |
+
prompt_dict,
|
481 |
+
epoch,
|
482 |
+
steps,
|
483 |
+
sample_prompts_te_outputs,
|
484 |
+
prompt_replacement,
|
485 |
+
):
|
486 |
+
assert isinstance(prompt_dict, dict)
|
487 |
+
negative_prompt = prompt_dict.get("negative_prompt")
|
488 |
+
sample_steps = prompt_dict.get("sample_steps", 30)
|
489 |
+
width = prompt_dict.get("width", 512)
|
490 |
+
height = prompt_dict.get("height", 512)
|
491 |
+
scale = prompt_dict.get("scale", 7.5)
|
492 |
+
seed = prompt_dict.get("seed")
|
493 |
+
# controlnet_image = prompt_dict.get("controlnet_image")
|
494 |
+
prompt: str = prompt_dict.get("prompt", "")
|
495 |
+
# sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
|
496 |
+
|
497 |
+
if prompt_replacement is not None:
|
498 |
+
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
499 |
+
if negative_prompt is not None:
|
500 |
+
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
501 |
+
|
502 |
+
if seed is not None:
|
503 |
+
torch.manual_seed(seed)
|
504 |
+
torch.cuda.manual_seed(seed)
|
505 |
+
else:
|
506 |
+
# True random sample image generation
|
507 |
+
torch.seed()
|
508 |
+
torch.cuda.seed()
|
509 |
+
|
510 |
+
if negative_prompt is None:
|
511 |
+
negative_prompt = ""
|
512 |
+
|
513 |
+
height = max(64, height - height % 8) # round to divisible by 8
|
514 |
+
width = max(64, width - width % 8) # round to divisible by 8
|
515 |
+
logger.info(f"prompt: {prompt}")
|
516 |
+
logger.info(f"negative_prompt: {negative_prompt}")
|
517 |
+
logger.info(f"height: {height}")
|
518 |
+
logger.info(f"width: {width}")
|
519 |
+
logger.info(f"sample_steps: {sample_steps}")
|
520 |
+
logger.info(f"scale: {scale}")
|
521 |
+
# logger.info(f"sample_sampler: {sampler_name}")
|
522 |
+
if seed is not None:
|
523 |
+
logger.info(f"seed: {seed}")
|
524 |
+
|
525 |
+
# encode prompts
|
526 |
+
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
|
527 |
+
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
|
528 |
+
|
529 |
+
def encode_prompt(prpt):
|
530 |
+
text_encoder_conds = []
|
531 |
+
if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
|
532 |
+
text_encoder_conds = sample_prompts_te_outputs[prpt]
|
533 |
+
print(f"Using cached text encoder outputs for prompt: {prpt}")
|
534 |
+
if text_encoders is not None:
|
535 |
+
print(f"Encoding prompt: {prpt}")
|
536 |
+
tokens_and_masks = tokenize_strategy.tokenize(prpt)
|
537 |
+
# strategy has apply_t5_attn_mask option
|
538 |
+
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
|
539 |
+
|
540 |
+
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
|
541 |
+
if len(text_encoder_conds) == 0:
|
542 |
+
text_encoder_conds = encoded_text_encoder_conds
|
543 |
+
else:
|
544 |
+
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
|
545 |
+
for i in range(len(encoded_text_encoder_conds)):
|
546 |
+
if encoded_text_encoder_conds[i] is not None:
|
547 |
+
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
548 |
+
return text_encoder_conds
|
549 |
+
|
550 |
+
lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encode_prompt(prompt)
|
551 |
+
cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
|
552 |
+
|
553 |
+
# encode negative prompts
|
554 |
+
lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encode_prompt(negative_prompt)
|
555 |
+
neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
|
556 |
+
|
557 |
+
# sample image
|
558 |
+
clean_memory_on_device(accelerator.device)
|
559 |
+
with accelerator.autocast(), torch.no_grad():
|
560 |
+
# mmdit may be fp8, so we need weight_dtype here. vae is always in that dtype.
|
561 |
+
latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, vae.dtype, accelerator.device)
|
562 |
+
|
563 |
+
# latent to image
|
564 |
+
clean_memory_on_device(accelerator.device)
|
565 |
+
org_vae_device = vae.device # will be on cpu
|
566 |
+
vae.to(accelerator.device)
|
567 |
+
latents = vae.process_out(latents.to(vae.device, dtype=vae.dtype))
|
568 |
+
image = vae.decode(latents)
|
569 |
+
vae.to(org_vae_device)
|
570 |
+
clean_memory_on_device(accelerator.device)
|
571 |
+
|
572 |
+
image = image.float()
|
573 |
+
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
|
574 |
+
decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
|
575 |
+
decoded_np = decoded_np.astype(np.uint8)
|
576 |
+
|
577 |
+
image = Image.fromarray(decoded_np)
|
578 |
+
# adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
|
579 |
+
# but adding 'enum' to the filename should be enough
|
580 |
+
|
581 |
+
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
|
582 |
+
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
|
583 |
+
seed_suffix = "" if seed is None else f"_{seed}"
|
584 |
+
i: int = prompt_dict["enum"]
|
585 |
+
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
|
586 |
+
image.save(os.path.join(save_dir, img_filename))
|
587 |
+
|
588 |
+
# send images to wandb if enabled
|
589 |
+
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
|
590 |
+
wandb_tracker = accelerator.get_tracker("wandb")
|
591 |
+
|
592 |
+
import wandb
|
593 |
+
|
594 |
+
# not to commit images to avoid inconsistency between training and logging steps
|
595 |
+
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
|
596 |
+
|
597 |
+
|
598 |
+
# region Diffusers
|
599 |
+
|
600 |
+
|
601 |
+
from dataclasses import dataclass
|
602 |
+
from typing import Optional, Tuple, Union
|
603 |
+
|
604 |
+
import numpy as np
|
605 |
+
import torch
|
606 |
+
|
607 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
608 |
+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
609 |
+
from diffusers.utils.torch_utils import randn_tensor
|
610 |
+
from diffusers.utils import BaseOutput
|
611 |
+
|
612 |
+
|
613 |
+
@dataclass
|
614 |
+
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
|
615 |
+
"""
|
616 |
+
Output class for the scheduler's `step` function output.
|
617 |
+
|
618 |
+
Args:
|
619 |
+
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
620 |
+
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
621 |
+
denoising loop.
|
622 |
+
"""
|
623 |
+
|
624 |
+
prev_sample: torch.FloatTensor
|
625 |
+
|
626 |
+
|
627 |
+
class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
628 |
+
"""
|
629 |
+
Euler scheduler.
|
630 |
+
|
631 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
632 |
+
methods the library implements for all schedulers such as loading and saving.
|
633 |
+
|
634 |
+
Args:
|
635 |
+
num_train_timesteps (`int`, defaults to 1000):
|
636 |
+
The number of diffusion steps to train the model.
|
637 |
+
timestep_spacing (`str`, defaults to `"linspace"`):
|
638 |
+
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
639 |
+
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
640 |
+
shift (`float`, defaults to 1.0):
|
641 |
+
The shift value for the timestep schedule.
|
642 |
+
"""
|
643 |
+
|
644 |
+
_compatibles = []
|
645 |
+
order = 1
|
646 |
+
|
647 |
+
@register_to_config
|
648 |
+
def __init__(
|
649 |
+
self,
|
650 |
+
num_train_timesteps: int = 1000,
|
651 |
+
shift: float = 1.0,
|
652 |
+
):
|
653 |
+
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
|
654 |
+
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
|
655 |
+
|
656 |
+
sigmas = timesteps / num_train_timesteps
|
657 |
+
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
658 |
+
|
659 |
+
self.timesteps = sigmas * num_train_timesteps
|
660 |
+
|
661 |
+
self._step_index = None
|
662 |
+
self._begin_index = None
|
663 |
+
|
664 |
+
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
665 |
+
self.sigma_min = self.sigmas[-1].item()
|
666 |
+
self.sigma_max = self.sigmas[0].item()
|
667 |
+
|
668 |
+
@property
|
669 |
+
def step_index(self):
|
670 |
+
"""
|
671 |
+
The index counter for current timestep. It will increase 1 after each scheduler step.
|
672 |
+
"""
|
673 |
+
return self._step_index
|
674 |
+
|
675 |
+
@property
|
676 |
+
def begin_index(self):
|
677 |
+
"""
|
678 |
+
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
679 |
+
"""
|
680 |
+
return self._begin_index
|
681 |
+
|
682 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
683 |
+
def set_begin_index(self, begin_index: int = 0):
|
684 |
+
"""
|
685 |
+
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
686 |
+
|
687 |
+
Args:
|
688 |
+
begin_index (`int`):
|
689 |
+
The begin index for the scheduler.
|
690 |
+
"""
|
691 |
+
self._begin_index = begin_index
|
692 |
+
|
693 |
+
def scale_noise(
|
694 |
+
self,
|
695 |
+
sample: torch.FloatTensor,
|
696 |
+
timestep: Union[float, torch.FloatTensor],
|
697 |
+
noise: Optional[torch.FloatTensor] = None,
|
698 |
+
) -> torch.FloatTensor:
|
699 |
+
"""
|
700 |
+
Forward process in flow-matching
|
701 |
+
|
702 |
+
Args:
|
703 |
+
sample (`torch.FloatTensor`):
|
704 |
+
The input sample.
|
705 |
+
timestep (`int`, *optional*):
|
706 |
+
The current timestep in the diffusion chain.
|
707 |
+
|
708 |
+
Returns:
|
709 |
+
`torch.FloatTensor`:
|
710 |
+
A scaled input sample.
|
711 |
+
"""
|
712 |
+
if self.step_index is None:
|
713 |
+
self._init_step_index(timestep)
|
714 |
+
|
715 |
+
sigma = self.sigmas[self.step_index]
|
716 |
+
sample = sigma * noise + (1.0 - sigma) * sample
|
717 |
+
|
718 |
+
return sample
|
719 |
+
|
720 |
+
def _sigma_to_t(self, sigma):
|
721 |
+
return sigma * self.config.num_train_timesteps
|
722 |
+
|
723 |
+
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
724 |
+
"""
|
725 |
+
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
726 |
+
|
727 |
+
Args:
|
728 |
+
num_inference_steps (`int`):
|
729 |
+
The number of diffusion steps used when generating samples with a pre-trained model.
|
730 |
+
device (`str` or `torch.device`, *optional*):
|
731 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
732 |
+
"""
|
733 |
+
self.num_inference_steps = num_inference_steps
|
734 |
+
|
735 |
+
timesteps = np.linspace(self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps)
|
736 |
+
|
737 |
+
sigmas = timesteps / self.config.num_train_timesteps
|
738 |
+
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
|
739 |
+
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
|
740 |
+
|
741 |
+
timesteps = sigmas * self.config.num_train_timesteps
|
742 |
+
self.timesteps = timesteps.to(device=device)
|
743 |
+
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
744 |
+
|
745 |
+
self._step_index = None
|
746 |
+
self._begin_index = None
|
747 |
+
|
748 |
+
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
749 |
+
if schedule_timesteps is None:
|
750 |
+
schedule_timesteps = self.timesteps
|
751 |
+
|
752 |
+
indices = (schedule_timesteps == timestep).nonzero()
|
753 |
+
|
754 |
+
# The sigma index that is taken for the **very** first `step`
|
755 |
+
# is always the second index (or the last index if there is only 1)
|
756 |
+
# This way we can ensure we don't accidentally skip a sigma in
|
757 |
+
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
758 |
+
pos = 1 if len(indices) > 1 else 0
|
759 |
+
|
760 |
+
return indices[pos].item()
|
761 |
+
|
762 |
+
def _init_step_index(self, timestep):
|
763 |
+
if self.begin_index is None:
|
764 |
+
if isinstance(timestep, torch.Tensor):
|
765 |
+
timestep = timestep.to(self.timesteps.device)
|
766 |
+
self._step_index = self.index_for_timestep(timestep)
|
767 |
+
else:
|
768 |
+
self._step_index = self._begin_index
|
769 |
+
|
770 |
+
def step(
|
771 |
+
self,
|
772 |
+
model_output: torch.FloatTensor,
|
773 |
+
timestep: Union[float, torch.FloatTensor],
|
774 |
+
sample: torch.FloatTensor,
|
775 |
+
s_churn: float = 0.0,
|
776 |
+
s_tmin: float = 0.0,
|
777 |
+
s_tmax: float = float("inf"),
|
778 |
+
s_noise: float = 1.0,
|
779 |
+
generator: Optional[torch.Generator] = None,
|
780 |
+
return_dict: bool = True,
|
781 |
+
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
|
782 |
+
"""
|
783 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
784 |
+
process from the learned model outputs (most often the predicted noise).
|
785 |
+
|
786 |
+
Args:
|
787 |
+
model_output (`torch.FloatTensor`):
|
788 |
+
The direct output from learned diffusion model.
|
789 |
+
timestep (`float`):
|
790 |
+
The current discrete timestep in the diffusion chain.
|
791 |
+
sample (`torch.FloatTensor`):
|
792 |
+
A current instance of a sample created by the diffusion process.
|
793 |
+
s_churn (`float`):
|
794 |
+
s_tmin (`float`):
|
795 |
+
s_tmax (`float`):
|
796 |
+
s_noise (`float`, defaults to 1.0):
|
797 |
+
Scaling factor for noise added to the sample.
|
798 |
+
generator (`torch.Generator`, *optional*):
|
799 |
+
A random number generator.
|
800 |
+
return_dict (`bool`):
|
801 |
+
Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
|
802 |
+
tuple.
|
803 |
+
|
804 |
+
Returns:
|
805 |
+
[`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
|
806 |
+
If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
|
807 |
+
returned, otherwise a tuple is returned where the first element is the sample tensor.
|
808 |
+
"""
|
809 |
+
|
810 |
+
if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor):
|
811 |
+
raise ValueError(
|
812 |
+
(
|
813 |
+
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
814 |
+
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
815 |
+
" one of the `scheduler.timesteps` as a timestep."
|
816 |
+
),
|
817 |
+
)
|
818 |
+
|
819 |
+
if self.step_index is None:
|
820 |
+
self._init_step_index(timestep)
|
821 |
+
|
822 |
+
# Upcast to avoid precision issues when computing prev_sample
|
823 |
+
sample = sample.to(torch.float32)
|
824 |
+
|
825 |
+
sigma = self.sigmas[self.step_index]
|
826 |
+
|
827 |
+
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
|
828 |
+
|
829 |
+
noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator)
|
830 |
+
|
831 |
+
eps = noise * s_noise
|
832 |
+
sigma_hat = sigma * (gamma + 1)
|
833 |
+
|
834 |
+
if gamma > 0:
|
835 |
+
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
|
836 |
+
|
837 |
+
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
838 |
+
# NOTE: "original_sample" should not be an expected prediction_type but is left in for
|
839 |
+
# backwards compatibility
|
840 |
+
|
841 |
+
# if self.config.prediction_type == "vector_field":
|
842 |
+
|
843 |
+
denoised = sample - model_output * sigma
|
844 |
+
# 2. Convert to an ODE derivative
|
845 |
+
derivative = (sample - denoised) / sigma_hat
|
846 |
+
|
847 |
+
dt = self.sigmas[self.step_index + 1] - sigma_hat
|
848 |
+
|
849 |
+
prev_sample = sample + derivative * dt
|
850 |
+
# Cast sample back to model compatible dtype
|
851 |
+
prev_sample = prev_sample.to(model_output.dtype)
|
852 |
+
|
853 |
+
# upon completion increase step index by one
|
854 |
+
self._step_index += 1
|
855 |
+
|
856 |
+
if not return_dict:
|
857 |
+
return (prev_sample,)
|
858 |
+
|
859 |
+
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
|
860 |
+
|
861 |
+
def __len__(self):
|
862 |
+
return self.config.num_train_timesteps
|
863 |
+
|
864 |
+
|
865 |
+
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
|
866 |
+
sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
|
867 |
+
schedule_timesteps = noise_scheduler.timesteps.to(device)
|
868 |
+
timesteps = timesteps.to(device)
|
869 |
+
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
870 |
+
|
871 |
+
sigma = sigmas[step_indices].flatten()
|
872 |
+
while len(sigma.shape) < n_dim:
|
873 |
+
sigma = sigma.unsqueeze(-1)
|
874 |
+
return sigma
|
875 |
+
|
876 |
+
|
877 |
+
def compute_density_for_timestep_sampling(
|
878 |
+
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
|
879 |
+
):
|
880 |
+
"""Compute the density for sampling the timesteps when doing SD3 training.
|
881 |
+
|
882 |
+
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
883 |
+
|
884 |
+
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
885 |
+
"""
|
886 |
+
if weighting_scheme == "logit_normal":
|
887 |
+
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
|
888 |
+
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
|
889 |
+
u = torch.nn.functional.sigmoid(u)
|
890 |
+
elif weighting_scheme == "mode":
|
891 |
+
u = torch.rand(size=(batch_size,), device="cpu")
|
892 |
+
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
|
893 |
+
else:
|
894 |
+
u = torch.rand(size=(batch_size,), device="cpu")
|
895 |
+
return u
|
896 |
+
|
897 |
+
|
898 |
+
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
899 |
+
"""Computes loss weighting scheme for SD3 training.
|
900 |
+
|
901 |
+
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
902 |
+
|
903 |
+
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
904 |
+
"""
|
905 |
+
if weighting_scheme == "sigma_sqrt":
|
906 |
+
weighting = (sigmas**-2.0).float()
|
907 |
+
elif weighting_scheme == "cosmap":
|
908 |
+
bot = 1 - 2 * sigmas + 2 * sigmas**2
|
909 |
+
weighting = 2 / (math.pi * bot)
|
910 |
+
else:
|
911 |
+
weighting = torch.ones_like(sigmas)
|
912 |
+
return weighting
|
913 |
+
|
914 |
+
|
915 |
+
# endregion
|
916 |
+
|
917 |
+
|
918 |
+
def get_noisy_model_input_and_timesteps(args, latents, noise, device, dtype) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
919 |
+
bsz = latents.shape[0]
|
920 |
+
|
921 |
+
# Sample a random timestep for each image
|
922 |
+
# for weighting schemes where we sample timesteps non-uniformly
|
923 |
+
u = compute_density_for_timestep_sampling(
|
924 |
+
weighting_scheme=args.weighting_scheme,
|
925 |
+
batch_size=bsz,
|
926 |
+
logit_mean=args.logit_mean,
|
927 |
+
logit_std=args.logit_std,
|
928 |
+
mode_scale=args.mode_scale,
|
929 |
+
)
|
930 |
+
t_min = args.min_timestep if args.min_timestep is not None else 0
|
931 |
+
t_max = args.max_timestep if args.max_timestep is not None else 1000
|
932 |
+
shift = args.training_shift
|
933 |
+
|
934 |
+
# weighting shift, value >1 will shift distribution to noisy side (focus more on overall structure), value <1 will shift towards less-noisy side (focus more on details)
|
935 |
+
u = (u * shift) / (1 + (shift - 1) * u)
|
936 |
+
|
937 |
+
indices = (u * (t_max - t_min) + t_min).long()
|
938 |
+
timesteps = indices.to(device=device, dtype=dtype)
|
939 |
+
|
940 |
+
# sigmas according to flowmatching
|
941 |
+
sigmas = timesteps / 1000
|
942 |
+
sigmas = sigmas.view(-1, 1, 1, 1)
|
943 |
+
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
|
944 |
+
|
945 |
+
return noisy_model_input, timesteps, sigmas
|
library/sd3_utils.py
ADDED
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
import math
|
3 |
+
import re
|
4 |
+
from typing import Dict, List, Optional, Union
|
5 |
+
import torch
|
6 |
+
import safetensors
|
7 |
+
from safetensors.torch import load_file
|
8 |
+
from accelerate import init_empty_weights
|
9 |
+
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPConfig, CLIPTextConfig
|
10 |
+
|
11 |
+
from .utils import setup_logging
|
12 |
+
|
13 |
+
setup_logging()
|
14 |
+
import logging
|
15 |
+
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
from library import sd3_models
|
19 |
+
|
20 |
+
# TODO move some of functions to model_util.py
|
21 |
+
from library import sdxl_model_util
|
22 |
+
|
23 |
+
# region models
|
24 |
+
|
25 |
+
# TODO remove dependency on flux_utils
|
26 |
+
from library.utils import load_safetensors
|
27 |
+
from library.flux_utils import load_t5xxl as flux_utils_load_t5xxl
|
28 |
+
|
29 |
+
|
30 |
+
def analyze_state_dict_state(state_dict: Dict, prefix: str = ""):
|
31 |
+
logger.info(f"Analyzing state dict state...")
|
32 |
+
|
33 |
+
# analyze configs
|
34 |
+
patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2]
|
35 |
+
depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64
|
36 |
+
num_patches = state_dict[f"{prefix}pos_embed"].shape[1]
|
37 |
+
pos_embed_max_size = round(math.sqrt(num_patches))
|
38 |
+
adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1]
|
39 |
+
context_shape = state_dict[f"{prefix}context_embedder.weight"].shape
|
40 |
+
qk_norm = "rms" if f"{prefix}joint_blocks.0.context_block.attn.ln_k.weight" in state_dict.keys() else None
|
41 |
+
|
42 |
+
# x_block_self_attn_layers.append(int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1]))
|
43 |
+
x_block_self_attn_layers = []
|
44 |
+
re_attn = re.compile(r"\.(\d+)\.x_block\.attn2\.ln_k\.weight")
|
45 |
+
for key in list(state_dict.keys()):
|
46 |
+
m = re_attn.search(key)
|
47 |
+
if m:
|
48 |
+
x_block_self_attn_layers.append(int(m.group(1)))
|
49 |
+
|
50 |
+
context_embedder_in_features = context_shape[1]
|
51 |
+
context_embedder_out_features = context_shape[0]
|
52 |
+
|
53 |
+
# only supports 3-5-large, medium or 3-medium
|
54 |
+
if qk_norm is not None:
|
55 |
+
if len(x_block_self_attn_layers) == 0:
|
56 |
+
model_type = "3-5-large"
|
57 |
+
else:
|
58 |
+
model_type = "3-5-medium"
|
59 |
+
else:
|
60 |
+
model_type = "3-medium"
|
61 |
+
|
62 |
+
params = sd3_models.SD3Params(
|
63 |
+
patch_size=patch_size,
|
64 |
+
depth=depth,
|
65 |
+
num_patches=num_patches,
|
66 |
+
pos_embed_max_size=pos_embed_max_size,
|
67 |
+
adm_in_channels=adm_in_channels,
|
68 |
+
qk_norm=qk_norm,
|
69 |
+
x_block_self_attn_layers=x_block_self_attn_layers,
|
70 |
+
context_embedder_in_features=context_embedder_in_features,
|
71 |
+
context_embedder_out_features=context_embedder_out_features,
|
72 |
+
model_type=model_type,
|
73 |
+
)
|
74 |
+
logger.info(f"Analyzed state dict state: {params}")
|
75 |
+
return params
|
76 |
+
|
77 |
+
|
78 |
+
def load_mmdit(
|
79 |
+
state_dict: Dict, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device], attn_mode: str = "torch"
|
80 |
+
) -> sd3_models.MMDiT:
|
81 |
+
mmdit_sd = {}
|
82 |
+
|
83 |
+
mmdit_prefix = "model.diffusion_model."
|
84 |
+
for k in list(state_dict.keys()):
|
85 |
+
if k.startswith(mmdit_prefix):
|
86 |
+
mmdit_sd[k[len(mmdit_prefix) :]] = state_dict.pop(k)
|
87 |
+
|
88 |
+
# load MMDiT
|
89 |
+
logger.info("Building MMDit")
|
90 |
+
params = analyze_state_dict_state(mmdit_sd)
|
91 |
+
with init_empty_weights():
|
92 |
+
mmdit = sd3_models.create_sd3_mmdit(params, attn_mode)
|
93 |
+
|
94 |
+
logger.info("Loading state dict...")
|
95 |
+
info = mmdit.load_state_dict(mmdit_sd, strict=False, assign=True)
|
96 |
+
logger.info(f"Loaded MMDiT: {info}")
|
97 |
+
return mmdit
|
98 |
+
|
99 |
+
|
100 |
+
def load_clip_l(
|
101 |
+
clip_l_path: Optional[str],
|
102 |
+
dtype: Optional[Union[str, torch.dtype]],
|
103 |
+
device: Union[str, torch.device],
|
104 |
+
disable_mmap: bool = False,
|
105 |
+
state_dict: Optional[Dict] = None,
|
106 |
+
):
|
107 |
+
clip_l_sd = None
|
108 |
+
if clip_l_path is None:
|
109 |
+
if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
|
110 |
+
# found clip_l: remove prefix "text_encoders.clip_l."
|
111 |
+
logger.info("clip_l is included in the checkpoint")
|
112 |
+
clip_l_sd = {}
|
113 |
+
prefix = "text_encoders.clip_l."
|
114 |
+
for k in list(state_dict.keys()):
|
115 |
+
if k.startswith(prefix):
|
116 |
+
clip_l_sd[k[len(prefix) :]] = state_dict.pop(k)
|
117 |
+
elif clip_l_path is None:
|
118 |
+
logger.info("clip_l is not included in the checkpoint and clip_l_path is not provided")
|
119 |
+
return None
|
120 |
+
|
121 |
+
# load clip_l
|
122 |
+
logger.info("Building CLIP-L")
|
123 |
+
config = CLIPTextConfig(
|
124 |
+
vocab_size=49408,
|
125 |
+
hidden_size=768,
|
126 |
+
intermediate_size=3072,
|
127 |
+
num_hidden_layers=12,
|
128 |
+
num_attention_heads=12,
|
129 |
+
max_position_embeddings=77,
|
130 |
+
hidden_act="quick_gelu",
|
131 |
+
layer_norm_eps=1e-05,
|
132 |
+
dropout=0.0,
|
133 |
+
attention_dropout=0.0,
|
134 |
+
initializer_range=0.02,
|
135 |
+
initializer_factor=1.0,
|
136 |
+
pad_token_id=1,
|
137 |
+
bos_token_id=0,
|
138 |
+
eos_token_id=2,
|
139 |
+
model_type="clip_text_model",
|
140 |
+
projection_dim=768,
|
141 |
+
# torch_dtype="float32",
|
142 |
+
# transformers_version="4.25.0.dev0",
|
143 |
+
)
|
144 |
+
with init_empty_weights():
|
145 |
+
clip = CLIPTextModelWithProjection(config)
|
146 |
+
|
147 |
+
if clip_l_sd is None:
|
148 |
+
logger.info(f"Loading state dict from {clip_l_path}")
|
149 |
+
clip_l_sd = load_safetensors(clip_l_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
|
150 |
+
|
151 |
+
if "text_projection.weight" not in clip_l_sd:
|
152 |
+
logger.info("Adding text_projection.weight to clip_l_sd")
|
153 |
+
clip_l_sd["text_projection.weight"] = torch.eye(768, dtype=dtype, device=device)
|
154 |
+
|
155 |
+
info = clip.load_state_dict(clip_l_sd, strict=False, assign=True)
|
156 |
+
logger.info(f"Loaded CLIP-L: {info}")
|
157 |
+
return clip
|
158 |
+
|
159 |
+
|
160 |
+
def load_clip_g(
|
161 |
+
clip_g_path: Optional[str],
|
162 |
+
dtype: Optional[Union[str, torch.dtype]],
|
163 |
+
device: Union[str, torch.device],
|
164 |
+
disable_mmap: bool = False,
|
165 |
+
state_dict: Optional[Dict] = None,
|
166 |
+
):
|
167 |
+
clip_g_sd = None
|
168 |
+
if state_dict is not None:
|
169 |
+
if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
|
170 |
+
# found clip_g: remove prefix "text_encoders.clip_g."
|
171 |
+
logger.info("clip_g is included in the checkpoint")
|
172 |
+
clip_g_sd = {}
|
173 |
+
prefix = "text_encoders.clip_g."
|
174 |
+
for k in list(state_dict.keys()):
|
175 |
+
if k.startswith(prefix):
|
176 |
+
clip_g_sd[k[len(prefix) :]] = state_dict.pop(k)
|
177 |
+
elif clip_g_path is None:
|
178 |
+
logger.info("clip_g is not included in the checkpoint and clip_g_path is not provided")
|
179 |
+
return None
|
180 |
+
|
181 |
+
# load clip_g
|
182 |
+
logger.info("Building CLIP-G")
|
183 |
+
config = CLIPTextConfig(
|
184 |
+
vocab_size=49408,
|
185 |
+
hidden_size=1280,
|
186 |
+
intermediate_size=5120,
|
187 |
+
num_hidden_layers=32,
|
188 |
+
num_attention_heads=20,
|
189 |
+
max_position_embeddings=77,
|
190 |
+
hidden_act="gelu",
|
191 |
+
layer_norm_eps=1e-05,
|
192 |
+
dropout=0.0,
|
193 |
+
attention_dropout=0.0,
|
194 |
+
initializer_range=0.02,
|
195 |
+
initializer_factor=1.0,
|
196 |
+
pad_token_id=1,
|
197 |
+
bos_token_id=0,
|
198 |
+
eos_token_id=2,
|
199 |
+
model_type="clip_text_model",
|
200 |
+
projection_dim=1280,
|
201 |
+
# torch_dtype="float32",
|
202 |
+
# transformers_version="4.25.0.dev0",
|
203 |
+
)
|
204 |
+
with init_empty_weights():
|
205 |
+
clip = CLIPTextModelWithProjection(config)
|
206 |
+
|
207 |
+
if clip_g_sd is None:
|
208 |
+
logger.info(f"Loading state dict from {clip_g_path}")
|
209 |
+
clip_g_sd = load_safetensors(clip_g_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
|
210 |
+
info = clip.load_state_dict(clip_g_sd, strict=False, assign=True)
|
211 |
+
logger.info(f"Loaded CLIP-G: {info}")
|
212 |
+
return clip
|
213 |
+
|
214 |
+
|
215 |
+
def load_t5xxl(
|
216 |
+
t5xxl_path: Optional[str],
|
217 |
+
dtype: Optional[Union[str, torch.dtype]],
|
218 |
+
device: Union[str, torch.device],
|
219 |
+
disable_mmap: bool = False,
|
220 |
+
state_dict: Optional[Dict] = None,
|
221 |
+
):
|
222 |
+
t5xxl_sd = None
|
223 |
+
if state_dict is not None:
|
224 |
+
if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict:
|
225 |
+
# found t5xxl: remove prefix "text_encoders.t5xxl."
|
226 |
+
logger.info("t5xxl is included in the checkpoint")
|
227 |
+
t5xxl_sd = {}
|
228 |
+
prefix = "text_encoders.t5xxl."
|
229 |
+
for k in list(state_dict.keys()):
|
230 |
+
if k.startswith(prefix):
|
231 |
+
t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k)
|
232 |
+
elif t5xxl_path is None:
|
233 |
+
logger.info("t5xxl is not included in the checkpoint and t5xxl_path is not provided")
|
234 |
+
return None
|
235 |
+
|
236 |
+
return flux_utils_load_t5xxl(t5xxl_path, dtype, device, disable_mmap, state_dict=t5xxl_sd)
|
237 |
+
|
238 |
+
|
239 |
+
def load_vae(
|
240 |
+
vae_path: Optional[str],
|
241 |
+
vae_dtype: Optional[Union[str, torch.dtype]],
|
242 |
+
device: Optional[Union[str, torch.device]],
|
243 |
+
disable_mmap: bool = False,
|
244 |
+
state_dict: Optional[Dict] = None,
|
245 |
+
):
|
246 |
+
vae_sd = {}
|
247 |
+
if vae_path:
|
248 |
+
logger.info(f"Loading VAE from {vae_path}...")
|
249 |
+
vae_sd = load_safetensors(vae_path, device, disable_mmap)
|
250 |
+
else:
|
251 |
+
# remove prefix "first_stage_model."
|
252 |
+
vae_sd = {}
|
253 |
+
vae_prefix = "first_stage_model."
|
254 |
+
for k in list(state_dict.keys()):
|
255 |
+
if k.startswith(vae_prefix):
|
256 |
+
vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k)
|
257 |
+
|
258 |
+
logger.info("Building VAE")
|
259 |
+
vae = sd3_models.SDVAE(vae_dtype, device)
|
260 |
+
logger.info("Loading state dict...")
|
261 |
+
info = vae.load_state_dict(vae_sd)
|
262 |
+
logger.info(f"Loaded VAE: {info}")
|
263 |
+
vae.to(device=device, dtype=vae_dtype) # make sure it's in the right device and dtype
|
264 |
+
return vae
|
265 |
+
|
266 |
+
|
267 |
+
# endregion
|
268 |
+
|
269 |
+
|
270 |
+
class ModelSamplingDiscreteFlow:
|
271 |
+
"""Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models"""
|
272 |
+
|
273 |
+
def __init__(self, shift=1.0):
|
274 |
+
self.shift = shift
|
275 |
+
timesteps = 1000
|
276 |
+
self.sigmas = self.sigma(torch.arange(1, timesteps + 1, 1))
|
277 |
+
|
278 |
+
@property
|
279 |
+
def sigma_min(self):
|
280 |
+
return self.sigmas[0]
|
281 |
+
|
282 |
+
@property
|
283 |
+
def sigma_max(self):
|
284 |
+
return self.sigmas[-1]
|
285 |
+
|
286 |
+
def timestep(self, sigma):
|
287 |
+
return sigma * 1000
|
288 |
+
|
289 |
+
def sigma(self, timestep: torch.Tensor):
|
290 |
+
timestep = timestep / 1000.0
|
291 |
+
if self.shift == 1.0:
|
292 |
+
return timestep
|
293 |
+
return self.shift * timestep / (1 + (self.shift - 1) * timestep)
|
294 |
+
|
295 |
+
def calculate_denoised(self, sigma, model_output, model_input):
|
296 |
+
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
297 |
+
return model_input - model_output * sigma
|
298 |
+
|
299 |
+
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
300 |
+
# assert max_denoise is False, "max_denoise not implemented"
|
301 |
+
# max_denoise is always True, I'm not sure why it's there
|
302 |
+
return sigma * noise + (1.0 - sigma) * latent_image
|
library/sdxl_lpw_stable_diffusion.py
ADDED
@@ -0,0 +1,1271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copy from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
|
2 |
+
# and modify to support SD2.x
|
3 |
+
|
4 |
+
import inspect
|
5 |
+
import re
|
6 |
+
from typing import Callable, List, Optional, Union
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import PIL.Image
|
10 |
+
import torch
|
11 |
+
from packaging import version
|
12 |
+
from tqdm import tqdm
|
13 |
+
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
14 |
+
|
15 |
+
from diffusers import SchedulerMixin, StableDiffusionPipeline
|
16 |
+
from diffusers.models import AutoencoderKL
|
17 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
18 |
+
from diffusers.utils import logging
|
19 |
+
from PIL import Image
|
20 |
+
|
21 |
+
from library import (
|
22 |
+
sdxl_model_util,
|
23 |
+
sdxl_train_util,
|
24 |
+
strategy_base,
|
25 |
+
strategy_sdxl,
|
26 |
+
train_util,
|
27 |
+
sdxl_original_unet,
|
28 |
+
sdxl_original_control_net,
|
29 |
+
)
|
30 |
+
|
31 |
+
|
32 |
+
try:
|
33 |
+
from diffusers.utils import PIL_INTERPOLATION
|
34 |
+
except ImportError:
|
35 |
+
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
36 |
+
PIL_INTERPOLATION = {
|
37 |
+
"linear": PIL.Image.Resampling.BILINEAR,
|
38 |
+
"bilinear": PIL.Image.Resampling.BILINEAR,
|
39 |
+
"bicubic": PIL.Image.Resampling.BICUBIC,
|
40 |
+
"lanczos": PIL.Image.Resampling.LANCZOS,
|
41 |
+
"nearest": PIL.Image.Resampling.NEAREST,
|
42 |
+
}
|
43 |
+
else:
|
44 |
+
PIL_INTERPOLATION = {
|
45 |
+
"linear": PIL.Image.LINEAR,
|
46 |
+
"bilinear": PIL.Image.BILINEAR,
|
47 |
+
"bicubic": PIL.Image.BICUBIC,
|
48 |
+
"lanczos": PIL.Image.LANCZOS,
|
49 |
+
"nearest": PIL.Image.NEAREST,
|
50 |
+
}
|
51 |
+
# ------------------------------------------------------------------------------
|
52 |
+
|
53 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
54 |
+
|
55 |
+
re_attention = re.compile(
|
56 |
+
r"""
|
57 |
+
\\\(|
|
58 |
+
\\\)|
|
59 |
+
\\\[|
|
60 |
+
\\]|
|
61 |
+
\\\\|
|
62 |
+
\\|
|
63 |
+
\(|
|
64 |
+
\[|
|
65 |
+
:([+-]?[.\d]+)\)|
|
66 |
+
\)|
|
67 |
+
]|
|
68 |
+
[^\\()\[\]:]+|
|
69 |
+
:
|
70 |
+
""",
|
71 |
+
re.X,
|
72 |
+
)
|
73 |
+
|
74 |
+
|
75 |
+
def parse_prompt_attention(text):
|
76 |
+
"""
|
77 |
+
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
|
78 |
+
Accepted tokens are:
|
79 |
+
(abc) - increases attention to abc by a multiplier of 1.1
|
80 |
+
(abc:3.12) - increases attention to abc by a multiplier of 3.12
|
81 |
+
[abc] - decreases attention to abc by a multiplier of 1.1
|
82 |
+
\( - literal character '('
|
83 |
+
\[ - literal character '['
|
84 |
+
\) - literal character ')'
|
85 |
+
\] - literal character ']'
|
86 |
+
\\ - literal character '\'
|
87 |
+
anything else - just text
|
88 |
+
>>> parse_prompt_attention('normal text')
|
89 |
+
[['normal text', 1.0]]
|
90 |
+
>>> parse_prompt_attention('an (important) word')
|
91 |
+
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
|
92 |
+
>>> parse_prompt_attention('(unbalanced')
|
93 |
+
[['unbalanced', 1.1]]
|
94 |
+
>>> parse_prompt_attention('\(literal\]')
|
95 |
+
[['(literal]', 1.0]]
|
96 |
+
>>> parse_prompt_attention('(unnecessary)(parens)')
|
97 |
+
[['unnecessaryparens', 1.1]]
|
98 |
+
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
|
99 |
+
[['a ', 1.0],
|
100 |
+
['house', 1.5730000000000004],
|
101 |
+
[' ', 1.1],
|
102 |
+
['on', 1.0],
|
103 |
+
[' a ', 1.1],
|
104 |
+
['hill', 0.55],
|
105 |
+
[', sun, ', 1.1],
|
106 |
+
['sky', 1.4641000000000006],
|
107 |
+
['.', 1.1]]
|
108 |
+
"""
|
109 |
+
|
110 |
+
res = []
|
111 |
+
round_brackets = []
|
112 |
+
square_brackets = []
|
113 |
+
|
114 |
+
round_bracket_multiplier = 1.1
|
115 |
+
square_bracket_multiplier = 1 / 1.1
|
116 |
+
|
117 |
+
def multiply_range(start_position, multiplier):
|
118 |
+
for p in range(start_position, len(res)):
|
119 |
+
res[p][1] *= multiplier
|
120 |
+
|
121 |
+
for m in re_attention.finditer(text):
|
122 |
+
text = m.group(0)
|
123 |
+
weight = m.group(1)
|
124 |
+
|
125 |
+
if text.startswith("\\"):
|
126 |
+
res.append([text[1:], 1.0])
|
127 |
+
elif text == "(":
|
128 |
+
round_brackets.append(len(res))
|
129 |
+
elif text == "[":
|
130 |
+
square_brackets.append(len(res))
|
131 |
+
elif weight is not None and len(round_brackets) > 0:
|
132 |
+
multiply_range(round_brackets.pop(), float(weight))
|
133 |
+
elif text == ")" and len(round_brackets) > 0:
|
134 |
+
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
135 |
+
elif text == "]" and len(square_brackets) > 0:
|
136 |
+
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
137 |
+
else:
|
138 |
+
res.append([text, 1.0])
|
139 |
+
|
140 |
+
for pos in round_brackets:
|
141 |
+
multiply_range(pos, round_bracket_multiplier)
|
142 |
+
|
143 |
+
for pos in square_brackets:
|
144 |
+
multiply_range(pos, square_bracket_multiplier)
|
145 |
+
|
146 |
+
if len(res) == 0:
|
147 |
+
res = [["", 1.0]]
|
148 |
+
|
149 |
+
# merge runs of identical weights
|
150 |
+
i = 0
|
151 |
+
while i + 1 < len(res):
|
152 |
+
if res[i][1] == res[i + 1][1]:
|
153 |
+
res[i][0] += res[i + 1][0]
|
154 |
+
res.pop(i + 1)
|
155 |
+
else:
|
156 |
+
i += 1
|
157 |
+
|
158 |
+
return res
|
159 |
+
|
160 |
+
|
161 |
+
def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
|
162 |
+
r"""
|
163 |
+
Tokenize a list of prompts and return its tokens with weights of each token.
|
164 |
+
|
165 |
+
No padding, starting or ending token is included.
|
166 |
+
"""
|
167 |
+
tokens = []
|
168 |
+
weights = []
|
169 |
+
truncated = False
|
170 |
+
for text in prompt:
|
171 |
+
texts_and_weights = parse_prompt_attention(text)
|
172 |
+
text_token = []
|
173 |
+
text_weight = []
|
174 |
+
for word, weight in texts_and_weights:
|
175 |
+
# tokenize and discard the starting and the ending token
|
176 |
+
token = pipe.tokenizer(word).input_ids[1:-1]
|
177 |
+
text_token += token
|
178 |
+
# copy the weight by length of token
|
179 |
+
text_weight += [weight] * len(token)
|
180 |
+
# stop if the text is too long (longer than truncation limit)
|
181 |
+
if len(text_token) > max_length:
|
182 |
+
truncated = True
|
183 |
+
break
|
184 |
+
# truncate
|
185 |
+
if len(text_token) > max_length:
|
186 |
+
truncated = True
|
187 |
+
text_token = text_token[:max_length]
|
188 |
+
text_weight = text_weight[:max_length]
|
189 |
+
tokens.append(text_token)
|
190 |
+
weights.append(text_weight)
|
191 |
+
if truncated:
|
192 |
+
logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
|
193 |
+
return tokens, weights
|
194 |
+
|
195 |
+
|
196 |
+
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
|
197 |
+
r"""
|
198 |
+
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
|
199 |
+
"""
|
200 |
+
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
|
201 |
+
weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
|
202 |
+
for i in range(len(tokens)):
|
203 |
+
tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i]))
|
204 |
+
if no_boseos_middle:
|
205 |
+
weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
|
206 |
+
else:
|
207 |
+
w = []
|
208 |
+
if len(weights[i]) == 0:
|
209 |
+
w = [1.0] * weights_length
|
210 |
+
else:
|
211 |
+
for j in range(max_embeddings_multiples):
|
212 |
+
w.append(1.0) # weight for starting token in this chunk
|
213 |
+
w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
|
214 |
+
w.append(1.0) # weight for ending token in this chunk
|
215 |
+
w += [1.0] * (weights_length - len(w))
|
216 |
+
weights[i] = w[:]
|
217 |
+
|
218 |
+
return tokens, weights
|
219 |
+
|
220 |
+
|
221 |
+
def get_hidden_states(text_encoder, input_ids, is_sdxl_text_encoder2: bool, eos_token_id, device):
|
222 |
+
if not is_sdxl_text_encoder2:
|
223 |
+
# text_encoder1: same as SD1/2
|
224 |
+
enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True)
|
225 |
+
hidden_states = enc_out["hidden_states"][11]
|
226 |
+
pool = None
|
227 |
+
else:
|
228 |
+
# text_encoder2
|
229 |
+
enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True)
|
230 |
+
hidden_states = enc_out["hidden_states"][-2] # penuultimate layer
|
231 |
+
# pool = enc_out["text_embeds"]
|
232 |
+
pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], input_ids, eos_token_id)
|
233 |
+
hidden_states = hidden_states.to(device)
|
234 |
+
if pool is not None:
|
235 |
+
pool = pool.to(device)
|
236 |
+
return hidden_states, pool
|
237 |
+
|
238 |
+
|
239 |
+
def get_unweighted_text_embeddings(
|
240 |
+
pipe: StableDiffusionPipeline,
|
241 |
+
text_input: torch.Tensor,
|
242 |
+
chunk_length: int,
|
243 |
+
clip_skip: int,
|
244 |
+
eos: int,
|
245 |
+
pad: int,
|
246 |
+
is_sdxl_text_encoder2: bool,
|
247 |
+
no_boseos_middle: Optional[bool] = True,
|
248 |
+
):
|
249 |
+
"""
|
250 |
+
When the length of tokens is a multiple of the capacity of the text encoder,
|
251 |
+
it should be split into chunks and sent to the text encoder individually.
|
252 |
+
"""
|
253 |
+
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
|
254 |
+
text_pool = None
|
255 |
+
if max_embeddings_multiples > 1:
|
256 |
+
text_embeddings = []
|
257 |
+
for i in range(max_embeddings_multiples):
|
258 |
+
# extract the i-th chunk
|
259 |
+
text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
|
260 |
+
|
261 |
+
# cover the head and the tail by the starting and the ending tokens
|
262 |
+
text_input_chunk[:, 0] = text_input[0, 0]
|
263 |
+
if pad == eos: # v1
|
264 |
+
text_input_chunk[:, -1] = text_input[0, -1]
|
265 |
+
else: # v2
|
266 |
+
for j in range(len(text_input_chunk)):
|
267 |
+
if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
|
268 |
+
text_input_chunk[j, -1] = eos
|
269 |
+
if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
|
270 |
+
text_input_chunk[j, 1] = eos
|
271 |
+
|
272 |
+
text_embedding, current_text_pool = get_hidden_states(
|
273 |
+
pipe.text_encoder, text_input_chunk, is_sdxl_text_encoder2, eos, pipe.device
|
274 |
+
)
|
275 |
+
if text_pool is None:
|
276 |
+
text_pool = current_text_pool
|
277 |
+
|
278 |
+
if no_boseos_middle:
|
279 |
+
if i == 0:
|
280 |
+
# discard the ending token
|
281 |
+
text_embedding = text_embedding[:, :-1]
|
282 |
+
elif i == max_embeddings_multiples - 1:
|
283 |
+
# discard the starting token
|
284 |
+
text_embedding = text_embedding[:, 1:]
|
285 |
+
else:
|
286 |
+
# discard both starting and ending tokens
|
287 |
+
text_embedding = text_embedding[:, 1:-1]
|
288 |
+
|
289 |
+
text_embeddings.append(text_embedding)
|
290 |
+
text_embeddings = torch.concat(text_embeddings, axis=1)
|
291 |
+
else:
|
292 |
+
text_embeddings, text_pool = get_hidden_states(pipe.text_encoder, text_input, is_sdxl_text_encoder2, eos, pipe.device)
|
293 |
+
return text_embeddings, text_pool
|
294 |
+
|
295 |
+
|
296 |
+
def get_weighted_text_embeddings(
|
297 |
+
pipe, # : SdxlStableDiffusionLongPromptWeightingPipeline,
|
298 |
+
prompt: Union[str, List[str]],
|
299 |
+
uncond_prompt: Optional[Union[str, List[str]]] = None,
|
300 |
+
max_embeddings_multiples: Optional[int] = 3,
|
301 |
+
no_boseos_middle: Optional[bool] = False,
|
302 |
+
skip_parsing: Optional[bool] = False,
|
303 |
+
skip_weighting: Optional[bool] = False,
|
304 |
+
clip_skip=None,
|
305 |
+
is_sdxl_text_encoder2=False,
|
306 |
+
):
|
307 |
+
r"""
|
308 |
+
Prompts can be assigned with local weights using brackets. For example,
|
309 |
+
prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
|
310 |
+
and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
|
311 |
+
|
312 |
+
Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
|
313 |
+
|
314 |
+
Args:
|
315 |
+
pipe (`StableDiffusionPipeline`):
|
316 |
+
Pipe to provide access to the tokenizer and the text encoder.
|
317 |
+
prompt (`str` or `List[str]`):
|
318 |
+
The prompt or prompts to guide the image generation.
|
319 |
+
uncond_prompt (`str` or `List[str]`):
|
320 |
+
The unconditional prompt or prompts for guide the image generation. If unconditional prompt
|
321 |
+
is provided, the embeddings of prompt and uncond_prompt are concatenated.
|
322 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
323 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
324 |
+
no_boseos_middle (`bool`, *optional*, defaults to `False`):
|
325 |
+
If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
|
326 |
+
ending token in each of the chunk in the middle.
|
327 |
+
skip_parsing (`bool`, *optional*, defaults to `False`):
|
328 |
+
Skip the parsing of brackets.
|
329 |
+
skip_weighting (`bool`, *optional*, defaults to `False`):
|
330 |
+
Skip the weighting. When the parsing is skipped, it is forced True.
|
331 |
+
"""
|
332 |
+
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
333 |
+
if isinstance(prompt, str):
|
334 |
+
prompt = [prompt]
|
335 |
+
|
336 |
+
if not skip_parsing:
|
337 |
+
prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
|
338 |
+
if uncond_prompt is not None:
|
339 |
+
if isinstance(uncond_prompt, str):
|
340 |
+
uncond_prompt = [uncond_prompt]
|
341 |
+
uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
|
342 |
+
else:
|
343 |
+
prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
|
344 |
+
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
|
345 |
+
if uncond_prompt is not None:
|
346 |
+
if isinstance(uncond_prompt, str):
|
347 |
+
uncond_prompt = [uncond_prompt]
|
348 |
+
uncond_tokens = [
|
349 |
+
token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
|
350 |
+
]
|
351 |
+
uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
|
352 |
+
|
353 |
+
# round up the longest length of tokens to a multiple of (model_max_length - 2)
|
354 |
+
max_length = max([len(token) for token in prompt_tokens])
|
355 |
+
if uncond_prompt is not None:
|
356 |
+
max_length = max(max_length, max([len(token) for token in uncond_tokens]))
|
357 |
+
|
358 |
+
max_embeddings_multiples = min(
|
359 |
+
max_embeddings_multiples,
|
360 |
+
(max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
|
361 |
+
)
|
362 |
+
max_embeddings_multiples = max(1, max_embeddings_multiples)
|
363 |
+
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
364 |
+
|
365 |
+
# pad the length of tokens and weights
|
366 |
+
bos = pipe.tokenizer.bos_token_id
|
367 |
+
eos = pipe.tokenizer.eos_token_id
|
368 |
+
pad = pipe.tokenizer.pad_token_id
|
369 |
+
prompt_tokens, prompt_weights = pad_tokens_and_weights(
|
370 |
+
prompt_tokens,
|
371 |
+
prompt_weights,
|
372 |
+
max_length,
|
373 |
+
bos,
|
374 |
+
eos,
|
375 |
+
pad,
|
376 |
+
no_boseos_middle=no_boseos_middle,
|
377 |
+
chunk_length=pipe.tokenizer.model_max_length,
|
378 |
+
)
|
379 |
+
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
|
380 |
+
if uncond_prompt is not None:
|
381 |
+
uncond_tokens, uncond_weights = pad_tokens_and_weights(
|
382 |
+
uncond_tokens,
|
383 |
+
uncond_weights,
|
384 |
+
max_length,
|
385 |
+
bos,
|
386 |
+
eos,
|
387 |
+
pad,
|
388 |
+
no_boseos_middle=no_boseos_middle,
|
389 |
+
chunk_length=pipe.tokenizer.model_max_length,
|
390 |
+
)
|
391 |
+
uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
|
392 |
+
|
393 |
+
# get the embeddings
|
394 |
+
text_embeddings, text_pool = get_unweighted_text_embeddings(
|
395 |
+
pipe,
|
396 |
+
prompt_tokens,
|
397 |
+
pipe.tokenizer.model_max_length,
|
398 |
+
clip_skip,
|
399 |
+
eos,
|
400 |
+
pad,
|
401 |
+
is_sdxl_text_encoder2,
|
402 |
+
no_boseos_middle=no_boseos_middle,
|
403 |
+
)
|
404 |
+
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
|
405 |
+
|
406 |
+
if uncond_prompt is not None:
|
407 |
+
uncond_embeddings, uncond_pool = get_unweighted_text_embeddings(
|
408 |
+
pipe,
|
409 |
+
uncond_tokens,
|
410 |
+
pipe.tokenizer.model_max_length,
|
411 |
+
clip_skip,
|
412 |
+
eos,
|
413 |
+
pad,
|
414 |
+
is_sdxl_text_encoder2,
|
415 |
+
no_boseos_middle=no_boseos_middle,
|
416 |
+
)
|
417 |
+
uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
|
418 |
+
|
419 |
+
# assign weights to the prompts and normalize in the sense of mean
|
420 |
+
# TODO: should we normalize by chunk or in a whole (current implementation)?
|
421 |
+
if (not skip_parsing) and (not skip_weighting):
|
422 |
+
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
423 |
+
text_embeddings *= prompt_weights.unsqueeze(-1)
|
424 |
+
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
425 |
+
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
426 |
+
if uncond_prompt is not None:
|
427 |
+
previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
428 |
+
uncond_embeddings *= uncond_weights.unsqueeze(-1)
|
429 |
+
current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
430 |
+
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
431 |
+
|
432 |
+
if uncond_prompt is not None:
|
433 |
+
return text_embeddings, text_pool, uncond_embeddings, uncond_pool
|
434 |
+
return text_embeddings, text_pool, None, None
|
435 |
+
|
436 |
+
|
437 |
+
def preprocess_image(image):
|
438 |
+
w, h = image.size
|
439 |
+
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
440 |
+
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
|
441 |
+
image = np.array(image).astype(np.float32) / 255.0
|
442 |
+
image = image[None].transpose(0, 3, 1, 2)
|
443 |
+
image = torch.from_numpy(image)
|
444 |
+
return 2.0 * image - 1.0
|
445 |
+
|
446 |
+
|
447 |
+
def preprocess_mask(mask, scale_factor=8):
|
448 |
+
mask = mask.convert("L")
|
449 |
+
w, h = mask.size
|
450 |
+
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
451 |
+
mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
|
452 |
+
mask = np.array(mask).astype(np.float32) / 255.0
|
453 |
+
mask = np.tile(mask, (4, 1, 1))
|
454 |
+
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
|
455 |
+
mask = 1 - mask # repaint white, keep black
|
456 |
+
mask = torch.from_numpy(mask)
|
457 |
+
return mask
|
458 |
+
|
459 |
+
|
460 |
+
def prepare_controlnet_image(
|
461 |
+
image: PIL.Image.Image,
|
462 |
+
width: int,
|
463 |
+
height: int,
|
464 |
+
batch_size: int,
|
465 |
+
num_images_per_prompt: int,
|
466 |
+
device: torch.device,
|
467 |
+
dtype: torch.dtype,
|
468 |
+
do_classifier_free_guidance: bool = False,
|
469 |
+
guess_mode: bool = False,
|
470 |
+
):
|
471 |
+
if not isinstance(image, torch.Tensor):
|
472 |
+
if isinstance(image, PIL.Image.Image):
|
473 |
+
image = [image]
|
474 |
+
|
475 |
+
if isinstance(image[0], PIL.Image.Image):
|
476 |
+
images = []
|
477 |
+
|
478 |
+
for image_ in image:
|
479 |
+
image_ = image_.convert("RGB")
|
480 |
+
image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
|
481 |
+
image_ = np.array(image_)
|
482 |
+
image_ = image_[None, :]
|
483 |
+
images.append(image_)
|
484 |
+
|
485 |
+
image = images
|
486 |
+
|
487 |
+
image = np.concatenate(image, axis=0)
|
488 |
+
image = np.array(image).astype(np.float32) / 255.0
|
489 |
+
image = image.transpose(0, 3, 1, 2)
|
490 |
+
image = torch.from_numpy(image)
|
491 |
+
elif isinstance(image[0], torch.Tensor):
|
492 |
+
image = torch.cat(image, dim=0)
|
493 |
+
|
494 |
+
image_batch_size = image.shape[0]
|
495 |
+
|
496 |
+
if image_batch_size == 1:
|
497 |
+
repeat_by = batch_size
|
498 |
+
else:
|
499 |
+
# image batch size is the same as prompt batch size
|
500 |
+
repeat_by = num_images_per_prompt
|
501 |
+
|
502 |
+
image = image.repeat_interleave(repeat_by, dim=0)
|
503 |
+
|
504 |
+
image = image.to(device=device, dtype=dtype)
|
505 |
+
|
506 |
+
if do_classifier_free_guidance and not guess_mode:
|
507 |
+
image = torch.cat([image] * 2)
|
508 |
+
|
509 |
+
return image
|
510 |
+
|
511 |
+
|
512 |
+
class SdxlStableDiffusionLongPromptWeightingPipeline:
|
513 |
+
r"""
|
514 |
+
Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
|
515 |
+
weighting in prompt.
|
516 |
+
|
517 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
518 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
519 |
+
|
520 |
+
Args:
|
521 |
+
vae ([`AutoencoderKL`]):
|
522 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
523 |
+
text_encoder ([`CLIPTextModel`]):
|
524 |
+
Frozen text-encoder. Stable Diffusion uses the text portion of
|
525 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
526 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
527 |
+
tokenizer (`CLIPTokenizer`):
|
528 |
+
Tokenizer of class
|
529 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
530 |
+
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
531 |
+
scheduler ([`SchedulerMixin`]):
|
532 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
533 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
534 |
+
safety_checker ([`StableDiffusionSafetyChecker`]):
|
535 |
+
Classification module that estimates whether generated images could be considered offensive or harmful.
|
536 |
+
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
|
537 |
+
feature_extractor ([`CLIPFeatureExtractor`]):
|
538 |
+
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
539 |
+
"""
|
540 |
+
|
541 |
+
# if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
|
542 |
+
|
543 |
+
def __init__(
|
544 |
+
self,
|
545 |
+
vae: AutoencoderKL,
|
546 |
+
text_encoder: List[CLIPTextModel],
|
547 |
+
tokenizer: List[CLIPTokenizer],
|
548 |
+
unet: Union[sdxl_original_unet.SdxlUNet2DConditionModel, sdxl_original_control_net.SdxlControlledUNet],
|
549 |
+
scheduler: SchedulerMixin,
|
550 |
+
# clip_skip: int,
|
551 |
+
safety_checker: StableDiffusionSafetyChecker,
|
552 |
+
feature_extractor: CLIPFeatureExtractor,
|
553 |
+
requires_safety_checker: bool = True,
|
554 |
+
clip_skip: int = 1,
|
555 |
+
):
|
556 |
+
# clip skip is ignored currently
|
557 |
+
self.tokenizer = tokenizer[0]
|
558 |
+
self.text_encoder = text_encoder[0]
|
559 |
+
self.unet = unet
|
560 |
+
self.scheduler = scheduler
|
561 |
+
self.safety_checker = safety_checker
|
562 |
+
self.feature_extractor = feature_extractor
|
563 |
+
self.requires_safety_checker = requires_safety_checker
|
564 |
+
self.vae = vae
|
565 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
566 |
+
self.progress_bar = lambda x: tqdm(x, leave=False)
|
567 |
+
|
568 |
+
self.clip_skip = clip_skip
|
569 |
+
self.tokenizers = tokenizer
|
570 |
+
self.text_encoders = text_encoder
|
571 |
+
|
572 |
+
# self.__init__additional__()
|
573 |
+
|
574 |
+
# def __init__additional__(self):
|
575 |
+
# if not hasattr(self, "vae_scale_factor"):
|
576 |
+
# setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
|
577 |
+
|
578 |
+
def to(self, device=None, dtype=None):
|
579 |
+
if device is not None:
|
580 |
+
self.device = device
|
581 |
+
# self.vae.to(device=self.device)
|
582 |
+
if dtype is not None:
|
583 |
+
self.dtype = dtype
|
584 |
+
|
585 |
+
# do not move Text Encoders to device, because Text Encoder should be on CPU
|
586 |
+
|
587 |
+
@property
|
588 |
+
def _execution_device(self):
|
589 |
+
r"""
|
590 |
+
Returns the device on which the pipeline's models will be executed. After calling
|
591 |
+
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
592 |
+
hooks.
|
593 |
+
"""
|
594 |
+
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
|
595 |
+
return self.device
|
596 |
+
for module in self.unet.modules():
|
597 |
+
if (
|
598 |
+
hasattr(module, "_hf_hook")
|
599 |
+
and hasattr(module._hf_hook, "execution_device")
|
600 |
+
and module._hf_hook.execution_device is not None
|
601 |
+
):
|
602 |
+
return torch.device(module._hf_hook.execution_device)
|
603 |
+
return self.device
|
604 |
+
|
605 |
+
def check_inputs(self, prompt, height, width, strength, callback_steps):
|
606 |
+
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
607 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
608 |
+
|
609 |
+
if strength < 0 or strength > 1:
|
610 |
+
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
611 |
+
|
612 |
+
if height % 8 != 0 or width % 8 != 0:
|
613 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
614 |
+
|
615 |
+
if (callback_steps is None) or (
|
616 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
617 |
+
):
|
618 |
+
raise ValueError(
|
619 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}."
|
620 |
+
)
|
621 |
+
|
622 |
+
def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
|
623 |
+
if is_text2img:
|
624 |
+
return self.scheduler.timesteps.to(device), num_inference_steps
|
625 |
+
else:
|
626 |
+
# get the original timestep using init_timestep
|
627 |
+
offset = self.scheduler.config.get("steps_offset", 0)
|
628 |
+
init_timestep = int(num_inference_steps * strength) + offset
|
629 |
+
init_timestep = min(init_timestep, num_inference_steps)
|
630 |
+
|
631 |
+
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
632 |
+
timesteps = self.scheduler.timesteps[t_start:].to(device)
|
633 |
+
return timesteps, num_inference_steps - t_start
|
634 |
+
|
635 |
+
def run_safety_checker(self, image, device, dtype):
|
636 |
+
if self.safety_checker is not None:
|
637 |
+
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
|
638 |
+
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))
|
639 |
+
else:
|
640 |
+
has_nsfw_concept = None
|
641 |
+
return image, has_nsfw_concept
|
642 |
+
|
643 |
+
def decode_latents(self, latents):
|
644 |
+
with torch.no_grad():
|
645 |
+
latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents
|
646 |
+
|
647 |
+
# print("post_quant_conv dtype:", self.vae.post_quant_conv.weight.dtype) # torch.float32
|
648 |
+
# x = torch.nn.functional.conv2d(latents, self.vae.post_quant_conv.weight.detach(), stride=1, padding=0)
|
649 |
+
# print("latents dtype:", latents.dtype, "x dtype:", x.dtype) # torch.float32, torch.float16
|
650 |
+
# self.vae.to("cpu")
|
651 |
+
# self.vae.set_use_memory_efficient_attention_xformers(False)
|
652 |
+
# image = self.vae.decode(latents.to("cpu")).sample
|
653 |
+
|
654 |
+
image = self.vae.decode(latents.to(self.vae.dtype)).sample
|
655 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
656 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
657 |
+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
658 |
+
return image
|
659 |
+
|
660 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
661 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
662 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
663 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
664 |
+
# and should be between [0, 1]
|
665 |
+
|
666 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
667 |
+
extra_step_kwargs = {}
|
668 |
+
if accepts_eta:
|
669 |
+
extra_step_kwargs["eta"] = eta
|
670 |
+
|
671 |
+
# check if the scheduler accepts generator
|
672 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
673 |
+
if accepts_generator:
|
674 |
+
extra_step_kwargs["generator"] = generator
|
675 |
+
return extra_step_kwargs
|
676 |
+
|
677 |
+
def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None):
|
678 |
+
if image is None:
|
679 |
+
shape = (
|
680 |
+
batch_size,
|
681 |
+
self.unet.in_channels,
|
682 |
+
height // self.vae_scale_factor,
|
683 |
+
width // self.vae_scale_factor,
|
684 |
+
)
|
685 |
+
|
686 |
+
if latents is None:
|
687 |
+
if device.type == "mps":
|
688 |
+
# randn does not work reproducibly on mps
|
689 |
+
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
690 |
+
else:
|
691 |
+
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
692 |
+
else:
|
693 |
+
if latents.shape != shape:
|
694 |
+
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
695 |
+
latents = latents.to(device)
|
696 |
+
|
697 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
698 |
+
latents = latents * self.scheduler.init_noise_sigma
|
699 |
+
return latents, None, None
|
700 |
+
else:
|
701 |
+
init_latent_dist = self.vae.encode(image).latent_dist
|
702 |
+
init_latents = init_latent_dist.sample(generator=generator)
|
703 |
+
init_latents = sdxl_model_util.VAE_SCALE_FACTOR * init_latents
|
704 |
+
init_latents = torch.cat([init_latents] * batch_size, dim=0)
|
705 |
+
init_latents_orig = init_latents
|
706 |
+
shape = init_latents.shape
|
707 |
+
|
708 |
+
# add noise to latents using the timesteps
|
709 |
+
if device.type == "mps":
|
710 |
+
noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
711 |
+
else:
|
712 |
+
noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
713 |
+
latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
714 |
+
return latents, init_latents_orig, noise
|
715 |
+
|
716 |
+
@torch.no_grad()
|
717 |
+
def __call__(
|
718 |
+
self,
|
719 |
+
prompt: Union[str, List[str]],
|
720 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
721 |
+
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
|
722 |
+
mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
|
723 |
+
height: int = 512,
|
724 |
+
width: int = 512,
|
725 |
+
num_inference_steps: int = 50,
|
726 |
+
guidance_scale: float = 7.5,
|
727 |
+
strength: float = 0.8,
|
728 |
+
num_images_per_prompt: Optional[int] = 1,
|
729 |
+
eta: float = 0.0,
|
730 |
+
generator: Optional[torch.Generator] = None,
|
731 |
+
latents: Optional[torch.FloatTensor] = None,
|
732 |
+
max_embeddings_multiples: Optional[int] = 3,
|
733 |
+
output_type: Optional[str] = "pil",
|
734 |
+
return_dict: bool = True,
|
735 |
+
controlnet: sdxl_original_control_net.SdxlControlNet = None,
|
736 |
+
controlnet_image=None,
|
737 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
738 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
739 |
+
callback_steps: int = 1,
|
740 |
+
):
|
741 |
+
r"""
|
742 |
+
Function invoked when calling the pipeline for generation.
|
743 |
+
|
744 |
+
Args:
|
745 |
+
prompt (`str` or `List[str]`):
|
746 |
+
The prompt or prompts to guide the image generation.
|
747 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
748 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
749 |
+
if `guidance_scale` is less than `1`).
|
750 |
+
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
751 |
+
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
752 |
+
process.
|
753 |
+
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
|
754 |
+
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
|
755 |
+
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
|
756 |
+
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
|
757 |
+
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
|
758 |
+
height (`int`, *optional*, defaults to 512):
|
759 |
+
The height in pixels of the generated image.
|
760 |
+
width (`int`, *optional*, defaults to 512):
|
761 |
+
The width in pixels of the generated image.
|
762 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
763 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
764 |
+
expense of slower inference.
|
765 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
766 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
767 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
768 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
769 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
770 |
+
usually at the expense of lower image quality.
|
771 |
+
strength (`float`, *optional*, defaults to 0.8):
|
772 |
+
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
|
773 |
+
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The
|
774 |
+
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
|
775 |
+
noise will be maximum and the denoising process will run for the full number of iterations specified in
|
776 |
+
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
|
777 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
778 |
+
The number of images to generate per prompt.
|
779 |
+
eta (`float`, *optional*, defaults to 0.0):
|
780 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
781 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
782 |
+
generator (`torch.Generator`, *optional*):
|
783 |
+
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
784 |
+
deterministic.
|
785 |
+
latents (`torch.FloatTensor`, *optional*):
|
786 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
787 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
788 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
789 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
790 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
791 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
792 |
+
The output format of the generate image. Choose between
|
793 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
794 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
795 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
796 |
+
plain tuple.
|
797 |
+
controlnet (`diffusers.ControlNetModel`, *optional*):
|
798 |
+
A controlnet model to be used for the inference. If not provided, controlnet will be disabled.
|
799 |
+
controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
|
800 |
+
`Image`, or tensor representing an image batch, to be used as the starting point for the controlnet
|
801 |
+
inference.
|
802 |
+
callback (`Callable`, *optional*):
|
803 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
804 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
805 |
+
is_cancelled_callback (`Callable`, *optional*):
|
806 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
807 |
+
`True`, the inference will be cancelled.
|
808 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
809 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
810 |
+
called at every step.
|
811 |
+
|
812 |
+
Returns:
|
813 |
+
`None` if cancelled by `is_cancelled_callback`,
|
814 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
815 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
816 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
817 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
818 |
+
(nsfw) content, according to the `safety_checker`.
|
819 |
+
"""
|
820 |
+
if controlnet is not None and controlnet_image is None:
|
821 |
+
raise ValueError("controlnet_image must be provided if controlnet is not None.")
|
822 |
+
|
823 |
+
# 0. Default height and width to unet
|
824 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
825 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
826 |
+
|
827 |
+
# 1. Check inputs. Raise error if not correct
|
828 |
+
self.check_inputs(prompt, height, width, strength, callback_steps)
|
829 |
+
|
830 |
+
# 2. Define call parameters
|
831 |
+
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
832 |
+
device = self._execution_device
|
833 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
834 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
835 |
+
# corresponds to doing no classifier free guidance.
|
836 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
837 |
+
|
838 |
+
# 3. Encode input prompt
|
839 |
+
tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
|
840 |
+
encoding_strategy: strategy_sdxl.SdxlTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
|
841 |
+
|
842 |
+
text_input_ids, text_weights = tokenize_strategy.tokenize_with_weights(prompt)
|
843 |
+
hidden_states_1, hidden_states_2, text_pool = encoding_strategy.encode_tokens_with_weights(
|
844 |
+
tokenize_strategy, self.text_encoders, text_input_ids, text_weights
|
845 |
+
)
|
846 |
+
text_embeddings = torch.cat([hidden_states_1, hidden_states_2], dim=-1)
|
847 |
+
|
848 |
+
if do_classifier_free_guidance:
|
849 |
+
input_ids, weights = tokenize_strategy.tokenize_with_weights(negative_prompt or "")
|
850 |
+
hidden_states_1, hidden_states_2, uncond_pool = encoding_strategy.encode_tokens_with_weights(
|
851 |
+
tokenize_strategy, self.text_encoders, input_ids, weights
|
852 |
+
)
|
853 |
+
uncond_embeddings = torch.cat([hidden_states_1, hidden_states_2], dim=-1)
|
854 |
+
else:
|
855 |
+
uncond_embeddings = None
|
856 |
+
uncond_pool = None
|
857 |
+
|
858 |
+
unet_dtype = self.unet.dtype
|
859 |
+
dtype = unet_dtype
|
860 |
+
if hasattr(dtype, "itemsize") and dtype.itemsize == 1: # fp8
|
861 |
+
dtype = torch.float16
|
862 |
+
self.unet.to(dtype)
|
863 |
+
|
864 |
+
# 4. Preprocess image and mask
|
865 |
+
if isinstance(image, PIL.Image.Image):
|
866 |
+
image = preprocess_image(image)
|
867 |
+
if image is not None:
|
868 |
+
image = image.to(device=self.device, dtype=dtype)
|
869 |
+
if isinstance(mask_image, PIL.Image.Image):
|
870 |
+
mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
|
871 |
+
if mask_image is not None:
|
872 |
+
mask = mask_image.to(device=self.device, dtype=dtype)
|
873 |
+
mask = torch.cat([mask] * batch_size * num_images_per_prompt)
|
874 |
+
else:
|
875 |
+
mask = None
|
876 |
+
|
877 |
+
# ControlNet is not working yet in SDXL, but keep the code here for future use
|
878 |
+
if controlnet_image is not None:
|
879 |
+
controlnet_image = prepare_controlnet_image(
|
880 |
+
controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False
|
881 |
+
)
|
882 |
+
|
883 |
+
# 5. set timesteps
|
884 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
885 |
+
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
|
886 |
+
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
887 |
+
|
888 |
+
# 6. Prepare latent variables
|
889 |
+
latents, init_latents_orig, noise = self.prepare_latents(
|
890 |
+
image,
|
891 |
+
latent_timestep,
|
892 |
+
batch_size * num_images_per_prompt,
|
893 |
+
height,
|
894 |
+
width,
|
895 |
+
dtype,
|
896 |
+
device,
|
897 |
+
generator,
|
898 |
+
latents,
|
899 |
+
)
|
900 |
+
|
901 |
+
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
902 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
903 |
+
|
904 |
+
# create size embs and concat embeddings for SDXL
|
905 |
+
orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(device, dtype)
|
906 |
+
crop_size = torch.zeros_like(orig_size)
|
907 |
+
target_size = orig_size
|
908 |
+
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(device, dtype)
|
909 |
+
|
910 |
+
# make conditionings
|
911 |
+
text_pool = text_pool.to(device, dtype)
|
912 |
+
if do_classifier_free_guidance:
|
913 |
+
text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(device, dtype)
|
914 |
+
|
915 |
+
uncond_pool = uncond_pool.to(device, dtype)
|
916 |
+
cond_vector = torch.cat([text_pool, embs], dim=1).to(dtype)
|
917 |
+
uncond_vector = torch.cat([uncond_pool, embs], dim=1).to(dtype)
|
918 |
+
vector_embedding = torch.cat([uncond_vector, cond_vector])
|
919 |
+
else:
|
920 |
+
text_embedding = text_embeddings.to(device, dtype)
|
921 |
+
vector_embedding = torch.cat([text_pool, embs], dim=1)
|
922 |
+
|
923 |
+
# 8. Denoising loop
|
924 |
+
for i, t in enumerate(self.progress_bar(timesteps)):
|
925 |
+
# expand the latents if we are doing classifier free guidance
|
926 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
927 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
928 |
+
|
929 |
+
# FIXME SD1 ControlNet is not working
|
930 |
+
|
931 |
+
# predict the noise residual
|
932 |
+
if controlnet is not None:
|
933 |
+
input_resi_add, mid_add = controlnet(latent_model_input, t, text_embedding, vector_embedding, controlnet_image)
|
934 |
+
noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding, input_resi_add, mid_add)
|
935 |
+
else:
|
936 |
+
noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding)
|
937 |
+
noise_pred = noise_pred.to(dtype) # U-Net changes dtype in LoRA training
|
938 |
+
|
939 |
+
# perform guidance
|
940 |
+
if do_classifier_free_guidance:
|
941 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
942 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
943 |
+
|
944 |
+
# compute the previous noisy sample x_t -> x_t-1
|
945 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
946 |
+
|
947 |
+
if mask is not None:
|
948 |
+
# masking
|
949 |
+
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
|
950 |
+
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
951 |
+
|
952 |
+
# call the callback, if provided
|
953 |
+
if i % callback_steps == 0:
|
954 |
+
if callback is not None:
|
955 |
+
callback(i, t, latents)
|
956 |
+
if is_cancelled_callback is not None and is_cancelled_callback():
|
957 |
+
return None
|
958 |
+
|
959 |
+
self.unet.to(unet_dtype)
|
960 |
+
return latents
|
961 |
+
|
962 |
+
def latents_to_image(self, latents):
|
963 |
+
# 9. Post-processing
|
964 |
+
image = self.decode_latents(latents.to(self.vae.dtype))
|
965 |
+
image = self.numpy_to_pil(image)
|
966 |
+
return image
|
967 |
+
|
968 |
+
# copy from pil_utils.py
|
969 |
+
def numpy_to_pil(self, images: np.ndarray) -> Image.Image:
|
970 |
+
"""
|
971 |
+
Convert a numpy image or a batch of images to a PIL image.
|
972 |
+
"""
|
973 |
+
if images.ndim == 3:
|
974 |
+
images = images[None, ...]
|
975 |
+
images = (images * 255).round().astype("uint8")
|
976 |
+
if images.shape[-1] == 1:
|
977 |
+
# special case for grayscale (single channel) images
|
978 |
+
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
|
979 |
+
else:
|
980 |
+
pil_images = [Image.fromarray(image) for image in images]
|
981 |
+
|
982 |
+
return pil_images
|
983 |
+
|
984 |
+
def text2img(
|
985 |
+
self,
|
986 |
+
prompt: Union[str, List[str]],
|
987 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
988 |
+
height: int = 512,
|
989 |
+
width: int = 512,
|
990 |
+
num_inference_steps: int = 50,
|
991 |
+
guidance_scale: float = 7.5,
|
992 |
+
num_images_per_prompt: Optional[int] = 1,
|
993 |
+
eta: float = 0.0,
|
994 |
+
generator: Optional[torch.Generator] = None,
|
995 |
+
latents: Optional[torch.FloatTensor] = None,
|
996 |
+
max_embeddings_multiples: Optional[int] = 3,
|
997 |
+
output_type: Optional[str] = "pil",
|
998 |
+
return_dict: bool = True,
|
999 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
1000 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
1001 |
+
callback_steps: int = 1,
|
1002 |
+
):
|
1003 |
+
r"""
|
1004 |
+
Function for text-to-image generation.
|
1005 |
+
Args:
|
1006 |
+
prompt (`str` or `List[str]`):
|
1007 |
+
The prompt or prompts to guide the image generation.
|
1008 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
1009 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
1010 |
+
if `guidance_scale` is less than `1`).
|
1011 |
+
height (`int`, *optional*, defaults to 512):
|
1012 |
+
The height in pixels of the generated image.
|
1013 |
+
width (`int`, *optional*, defaults to 512):
|
1014 |
+
The width in pixels of the generated image.
|
1015 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
1016 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
1017 |
+
expense of slower inference.
|
1018 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
1019 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
1020 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
1021 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
1022 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
1023 |
+
usually at the expense of lower image quality.
|
1024 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
1025 |
+
The number of images to generate per prompt.
|
1026 |
+
eta (`float`, *optional*, defaults to 0.0):
|
1027 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
1028 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
1029 |
+
generator (`torch.Generator`, *optional*):
|
1030 |
+
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
1031 |
+
deterministic.
|
1032 |
+
latents (`torch.FloatTensor`, *optional*):
|
1033 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
1034 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
1035 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
1036 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
1037 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
1038 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
1039 |
+
The output format of the generate image. Choose between
|
1040 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
1041 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1042 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
1043 |
+
plain tuple.
|
1044 |
+
callback (`Callable`, *optional*):
|
1045 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
1046 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
1047 |
+
is_cancelled_callback (`Callable`, *optional*):
|
1048 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
1049 |
+
`True`, the inference will be cancelled.
|
1050 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
1051 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
1052 |
+
called at every step.
|
1053 |
+
Returns:
|
1054 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
1055 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
1056 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
1057 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
1058 |
+
(nsfw) content, according to the `safety_checker`.
|
1059 |
+
"""
|
1060 |
+
return self.__call__(
|
1061 |
+
prompt=prompt,
|
1062 |
+
negative_prompt=negative_prompt,
|
1063 |
+
height=height,
|
1064 |
+
width=width,
|
1065 |
+
num_inference_steps=num_inference_steps,
|
1066 |
+
guidance_scale=guidance_scale,
|
1067 |
+
num_images_per_prompt=num_images_per_prompt,
|
1068 |
+
eta=eta,
|
1069 |
+
generator=generator,
|
1070 |
+
latents=latents,
|
1071 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
1072 |
+
output_type=output_type,
|
1073 |
+
return_dict=return_dict,
|
1074 |
+
callback=callback,
|
1075 |
+
is_cancelled_callback=is_cancelled_callback,
|
1076 |
+
callback_steps=callback_steps,
|
1077 |
+
)
|
1078 |
+
|
1079 |
+
def img2img(
|
1080 |
+
self,
|
1081 |
+
image: Union[torch.FloatTensor, PIL.Image.Image],
|
1082 |
+
prompt: Union[str, List[str]],
|
1083 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
1084 |
+
strength: float = 0.8,
|
1085 |
+
num_inference_steps: Optional[int] = 50,
|
1086 |
+
guidance_scale: Optional[float] = 7.5,
|
1087 |
+
num_images_per_prompt: Optional[int] = 1,
|
1088 |
+
eta: Optional[float] = 0.0,
|
1089 |
+
generator: Optional[torch.Generator] = None,
|
1090 |
+
max_embeddings_multiples: Optional[int] = 3,
|
1091 |
+
output_type: Optional[str] = "pil",
|
1092 |
+
return_dict: bool = True,
|
1093 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
1094 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
1095 |
+
callback_steps: int = 1,
|
1096 |
+
):
|
1097 |
+
r"""
|
1098 |
+
Function for image-to-image generation.
|
1099 |
+
Args:
|
1100 |
+
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
1101 |
+
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
1102 |
+
process.
|
1103 |
+
prompt (`str` or `List[str]`):
|
1104 |
+
The prompt or prompts to guide the image generation.
|
1105 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
1106 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
1107 |
+
if `guidance_scale` is less than `1`).
|
1108 |
+
strength (`float`, *optional*, defaults to 0.8):
|
1109 |
+
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
|
1110 |
+
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The
|
1111 |
+
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
|
1112 |
+
noise will be maximum and the denoising process will run for the full number of iterations specified in
|
1113 |
+
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
|
1114 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
1115 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
1116 |
+
expense of slower inference. This parameter will be modulated by `strength`.
|
1117 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
1118 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
1119 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
1120 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
1121 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
1122 |
+
usually at the expense of lower image quality.
|
1123 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
1124 |
+
The number of images to generate per prompt.
|
1125 |
+
eta (`float`, *optional*, defaults to 0.0):
|
1126 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
1127 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
1128 |
+
generator (`torch.Generator`, *optional*):
|
1129 |
+
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
1130 |
+
deterministic.
|
1131 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
1132 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
1133 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
1134 |
+
The output format of the generate image. Choose between
|
1135 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
1136 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1137 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
1138 |
+
plain tuple.
|
1139 |
+
callback (`Callable`, *optional*):
|
1140 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
1141 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
1142 |
+
is_cancelled_callback (`Callable`, *optional*):
|
1143 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
1144 |
+
`True`, the inference will be cancelled.
|
1145 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
1146 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
1147 |
+
called at every step.
|
1148 |
+
Returns:
|
1149 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
1150 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
1151 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
1152 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
1153 |
+
(nsfw) content, according to the `safety_checker`.
|
1154 |
+
"""
|
1155 |
+
return self.__call__(
|
1156 |
+
prompt=prompt,
|
1157 |
+
negative_prompt=negative_prompt,
|
1158 |
+
image=image,
|
1159 |
+
num_inference_steps=num_inference_steps,
|
1160 |
+
guidance_scale=guidance_scale,
|
1161 |
+
strength=strength,
|
1162 |
+
num_images_per_prompt=num_images_per_prompt,
|
1163 |
+
eta=eta,
|
1164 |
+
generator=generator,
|
1165 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
1166 |
+
output_type=output_type,
|
1167 |
+
return_dict=return_dict,
|
1168 |
+
callback=callback,
|
1169 |
+
is_cancelled_callback=is_cancelled_callback,
|
1170 |
+
callback_steps=callback_steps,
|
1171 |
+
)
|
1172 |
+
|
1173 |
+
def inpaint(
|
1174 |
+
self,
|
1175 |
+
image: Union[torch.FloatTensor, PIL.Image.Image],
|
1176 |
+
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
|
1177 |
+
prompt: Union[str, List[str]],
|
1178 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
1179 |
+
strength: float = 0.8,
|
1180 |
+
num_inference_steps: Optional[int] = 50,
|
1181 |
+
guidance_scale: Optional[float] = 7.5,
|
1182 |
+
num_images_per_prompt: Optional[int] = 1,
|
1183 |
+
eta: Optional[float] = 0.0,
|
1184 |
+
generator: Optional[torch.Generator] = None,
|
1185 |
+
max_embeddings_multiples: Optional[int] = 3,
|
1186 |
+
output_type: Optional[str] = "pil",
|
1187 |
+
return_dict: bool = True,
|
1188 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
1189 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
1190 |
+
callback_steps: int = 1,
|
1191 |
+
):
|
1192 |
+
r"""
|
1193 |
+
Function for inpaint.
|
1194 |
+
Args:
|
1195 |
+
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
1196 |
+
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
1197 |
+
process. This is the image whose masked region will be inpainted.
|
1198 |
+
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
|
1199 |
+
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
|
1200 |
+
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
|
1201 |
+
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
|
1202 |
+
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
|
1203 |
+
prompt (`str` or `List[str]`):
|
1204 |
+
The prompt or prompts to guide the image generation.
|
1205 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
1206 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
1207 |
+
if `guidance_scale` is less than `1`).
|
1208 |
+
strength (`float`, *optional*, defaults to 0.8):
|
1209 |
+
Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
|
1210 |
+
is 1, the denoising process will be run on the masked area for the full number of iterations specified
|
1211 |
+
in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
|
1212 |
+
noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
|
1213 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
1214 |
+
The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
|
1215 |
+
the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
|
1216 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
1217 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
1218 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
1219 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
1220 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
1221 |
+
usually at the expense of lower image quality.
|
1222 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
1223 |
+
The number of images to generate per prompt.
|
1224 |
+
eta (`float`, *optional*, defaults to 0.0):
|
1225 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
1226 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
1227 |
+
generator (`torch.Generator`, *optional*):
|
1228 |
+
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
1229 |
+
deterministic.
|
1230 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
1231 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
1232 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
1233 |
+
The output format of the generate image. Choose between
|
1234 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
1235 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1236 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
1237 |
+
plain tuple.
|
1238 |
+
callback (`Callable`, *optional*):
|
1239 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
1240 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
1241 |
+
is_cancelled_callback (`Callable`, *optional*):
|
1242 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
1243 |
+
`True`, the inference will be cancelled.
|
1244 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
1245 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
1246 |
+
called at every step.
|
1247 |
+
Returns:
|
1248 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
1249 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
1250 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
1251 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
1252 |
+
(nsfw) content, according to the `safety_checker`.
|
1253 |
+
"""
|
1254 |
+
return self.__call__(
|
1255 |
+
prompt=prompt,
|
1256 |
+
negative_prompt=negative_prompt,
|
1257 |
+
image=image,
|
1258 |
+
mask_image=mask_image,
|
1259 |
+
num_inference_steps=num_inference_steps,
|
1260 |
+
guidance_scale=guidance_scale,
|
1261 |
+
strength=strength,
|
1262 |
+
num_images_per_prompt=num_images_per_prompt,
|
1263 |
+
eta=eta,
|
1264 |
+
generator=generator,
|
1265 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
1266 |
+
output_type=output_type,
|
1267 |
+
return_dict=return_dict,
|
1268 |
+
callback=callback,
|
1269 |
+
is_cancelled_callback=is_cancelled_callback,
|
1270 |
+
callback_steps=callback_steps,
|
1271 |
+
)
|
library/sdxl_model_util.py
ADDED
@@ -0,0 +1,583 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import safetensors
|
3 |
+
from accelerate import init_empty_weights
|
4 |
+
from accelerate.utils.modeling import set_module_tensor_to_device
|
5 |
+
from safetensors.torch import load_file, save_file
|
6 |
+
from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
|
7 |
+
from typing import List
|
8 |
+
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
|
9 |
+
from library import model_util
|
10 |
+
from library import sdxl_original_unet
|
11 |
+
from library.utils import setup_logging
|
12 |
+
|
13 |
+
setup_logging()
|
14 |
+
import logging
|
15 |
+
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
VAE_SCALE_FACTOR = 0.13025
|
19 |
+
MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0"
|
20 |
+
|
21 |
+
# Diffusersの設定を読み込むための参照モデル
|
22 |
+
DIFFUSERS_REF_MODEL_ID_SDXL = "stabilityai/stable-diffusion-xl-base-1.0"
|
23 |
+
|
24 |
+
DIFFUSERS_SDXL_UNET_CONFIG = {
|
25 |
+
"act_fn": "silu",
|
26 |
+
"addition_embed_type": "text_time",
|
27 |
+
"addition_embed_type_num_heads": 64,
|
28 |
+
"addition_time_embed_dim": 256,
|
29 |
+
"attention_head_dim": [5, 10, 20],
|
30 |
+
"block_out_channels": [320, 640, 1280],
|
31 |
+
"center_input_sample": False,
|
32 |
+
"class_embed_type": None,
|
33 |
+
"class_embeddings_concat": False,
|
34 |
+
"conv_in_kernel": 3,
|
35 |
+
"conv_out_kernel": 3,
|
36 |
+
"cross_attention_dim": 2048,
|
37 |
+
"cross_attention_norm": None,
|
38 |
+
"down_block_types": ["DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"],
|
39 |
+
"downsample_padding": 1,
|
40 |
+
"dual_cross_attention": False,
|
41 |
+
"encoder_hid_dim": None,
|
42 |
+
"encoder_hid_dim_type": None,
|
43 |
+
"flip_sin_to_cos": True,
|
44 |
+
"freq_shift": 0,
|
45 |
+
"in_channels": 4,
|
46 |
+
"layers_per_block": 2,
|
47 |
+
"mid_block_only_cross_attention": None,
|
48 |
+
"mid_block_scale_factor": 1,
|
49 |
+
"mid_block_type": "UNetMidBlock2DCrossAttn",
|
50 |
+
"norm_eps": 1e-05,
|
51 |
+
"norm_num_groups": 32,
|
52 |
+
"num_attention_heads": None,
|
53 |
+
"num_class_embeds": None,
|
54 |
+
"only_cross_attention": False,
|
55 |
+
"out_channels": 4,
|
56 |
+
"projection_class_embeddings_input_dim": 2816,
|
57 |
+
"resnet_out_scale_factor": 1.0,
|
58 |
+
"resnet_skip_time_act": False,
|
59 |
+
"resnet_time_scale_shift": "default",
|
60 |
+
"sample_size": 128,
|
61 |
+
"time_cond_proj_dim": None,
|
62 |
+
"time_embedding_act_fn": None,
|
63 |
+
"time_embedding_dim": None,
|
64 |
+
"time_embedding_type": "positional",
|
65 |
+
"timestep_post_act": None,
|
66 |
+
"transformer_layers_per_block": [1, 2, 10],
|
67 |
+
"up_block_types": ["CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"],
|
68 |
+
"upcast_attention": False,
|
69 |
+
"use_linear_projection": True,
|
70 |
+
}
|
71 |
+
|
72 |
+
|
73 |
+
def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
|
74 |
+
SDXL_KEY_PREFIX = "conditioner.embedders.1.model."
|
75 |
+
|
76 |
+
# SD2のと、基本的には同じ。logit_scaleを後で使うので、それを追加で返す
|
77 |
+
# logit_scaleはcheckpointの保存時に使用する
|
78 |
+
def convert_key(key):
|
79 |
+
# common conversion
|
80 |
+
key = key.replace(SDXL_KEY_PREFIX + "transformer.", "text_model.encoder.")
|
81 |
+
key = key.replace(SDXL_KEY_PREFIX, "text_model.")
|
82 |
+
|
83 |
+
if "resblocks" in key:
|
84 |
+
# resblocks conversion
|
85 |
+
key = key.replace(".resblocks.", ".layers.")
|
86 |
+
if ".ln_" in key:
|
87 |
+
key = key.replace(".ln_", ".layer_norm")
|
88 |
+
elif ".mlp." in key:
|
89 |
+
key = key.replace(".c_fc.", ".fc1.")
|
90 |
+
key = key.replace(".c_proj.", ".fc2.")
|
91 |
+
elif ".attn.out_proj" in key:
|
92 |
+
key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
|
93 |
+
elif ".attn.in_proj" in key:
|
94 |
+
key = None # 特殊なので後で処理する
|
95 |
+
else:
|
96 |
+
raise ValueError(f"unexpected key in SD: {key}")
|
97 |
+
elif ".positional_embedding" in key:
|
98 |
+
key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
|
99 |
+
elif ".text_projection" in key:
|
100 |
+
key = key.replace("text_model.text_projection", "text_projection.weight")
|
101 |
+
elif ".logit_scale" in key:
|
102 |
+
key = None # 後で処理する
|
103 |
+
elif ".token_embedding" in key:
|
104 |
+
key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
|
105 |
+
elif ".ln_final" in key:
|
106 |
+
key = key.replace(".ln_final", ".final_layer_norm")
|
107 |
+
# ckpt from comfy has this key: text_model.encoder.text_model.embeddings.position_ids
|
108 |
+
elif ".embeddings.position_ids" in key:
|
109 |
+
key = None # remove this key: position_ids is not used in newer transformers
|
110 |
+
return key
|
111 |
+
|
112 |
+
keys = list(checkpoint.keys())
|
113 |
+
new_sd = {}
|
114 |
+
for key in keys:
|
115 |
+
new_key = convert_key(key)
|
116 |
+
if new_key is None:
|
117 |
+
continue
|
118 |
+
new_sd[new_key] = checkpoint[key]
|
119 |
+
|
120 |
+
# attnの変換
|
121 |
+
for key in keys:
|
122 |
+
if ".resblocks" in key and ".attn.in_proj_" in key:
|
123 |
+
# 三つに分割
|
124 |
+
values = torch.chunk(checkpoint[key], 3)
|
125 |
+
|
126 |
+
key_suffix = ".weight" if "weight" in key else ".bias"
|
127 |
+
key_pfx = key.replace(SDXL_KEY_PREFIX + "transformer.resblocks.", "text_model.encoder.layers.")
|
128 |
+
key_pfx = key_pfx.replace("_weight", "")
|
129 |
+
key_pfx = key_pfx.replace("_bias", "")
|
130 |
+
key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
|
131 |
+
new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
|
132 |
+
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
|
133 |
+
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
|
134 |
+
|
135 |
+
# logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す
|
136 |
+
logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)
|
137 |
+
|
138 |
+
# temporary workaround for text_projection.weight.weight for Playground-v2
|
139 |
+
if "text_projection.weight.weight" in new_sd:
|
140 |
+
logger.info("convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
|
141 |
+
new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"]
|
142 |
+
del new_sd["text_projection.weight.weight"]
|
143 |
+
|
144 |
+
return new_sd, logit_scale
|
145 |
+
|
146 |
+
|
147 |
+
# load state_dict without allocating new tensors
|
148 |
+
def _load_state_dict_on_device(model, state_dict, device, dtype=None):
|
149 |
+
# dtype will use fp32 as default
|
150 |
+
missing_keys = list(model.state_dict().keys() - state_dict.keys())
|
151 |
+
unexpected_keys = list(state_dict.keys() - model.state_dict().keys())
|
152 |
+
|
153 |
+
# similar to model.load_state_dict()
|
154 |
+
if not missing_keys and not unexpected_keys:
|
155 |
+
for k in list(state_dict.keys()):
|
156 |
+
set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype)
|
157 |
+
return "<All keys matched successfully>"
|
158 |
+
|
159 |
+
# error_msgs
|
160 |
+
error_msgs: List[str] = []
|
161 |
+
if missing_keys:
|
162 |
+
error_msgs.insert(0, "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys)))
|
163 |
+
if unexpected_keys:
|
164 |
+
error_msgs.insert(0, "Unexpected key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in unexpected_keys)))
|
165 |
+
|
166 |
+
raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)))
|
167 |
+
|
168 |
+
|
169 |
+
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None, disable_mmap=False):
|
170 |
+
# model_version is reserved for future use
|
171 |
+
# dtype is used for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching
|
172 |
+
|
173 |
+
# Load the state dict
|
174 |
+
if model_util.is_safetensors(ckpt_path):
|
175 |
+
checkpoint = None
|
176 |
+
if disable_mmap:
|
177 |
+
state_dict = safetensors.torch.load(open(ckpt_path, "rb").read())
|
178 |
+
else:
|
179 |
+
try:
|
180 |
+
state_dict = load_file(ckpt_path, device=map_location)
|
181 |
+
except:
|
182 |
+
state_dict = load_file(ckpt_path) # prevent device invalid Error
|
183 |
+
epoch = None
|
184 |
+
global_step = None
|
185 |
+
else:
|
186 |
+
checkpoint = torch.load(ckpt_path, map_location=map_location)
|
187 |
+
if "state_dict" in checkpoint:
|
188 |
+
state_dict = checkpoint["state_dict"]
|
189 |
+
epoch = checkpoint.get("epoch", 0)
|
190 |
+
global_step = checkpoint.get("global_step", 0)
|
191 |
+
else:
|
192 |
+
state_dict = checkpoint
|
193 |
+
epoch = 0
|
194 |
+
global_step = 0
|
195 |
+
checkpoint = None
|
196 |
+
|
197 |
+
# U-Net
|
198 |
+
logger.info("building U-Net")
|
199 |
+
with init_empty_weights():
|
200 |
+
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
|
201 |
+
|
202 |
+
logger.info("loading U-Net from checkpoint")
|
203 |
+
unet_sd = {}
|
204 |
+
for k in list(state_dict.keys()):
|
205 |
+
if k.startswith("model.diffusion_model."):
|
206 |
+
unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
|
207 |
+
info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype)
|
208 |
+
logger.info(f"U-Net: {info}")
|
209 |
+
|
210 |
+
# Text Encoders
|
211 |
+
logger.info("building text encoders")
|
212 |
+
|
213 |
+
# Text Encoder 1 is same to Stability AI's SDXL
|
214 |
+
text_model1_cfg = CLIPTextConfig(
|
215 |
+
vocab_size=49408,
|
216 |
+
hidden_size=768,
|
217 |
+
intermediate_size=3072,
|
218 |
+
num_hidden_layers=12,
|
219 |
+
num_attention_heads=12,
|
220 |
+
max_position_embeddings=77,
|
221 |
+
hidden_act="quick_gelu",
|
222 |
+
layer_norm_eps=1e-05,
|
223 |
+
dropout=0.0,
|
224 |
+
attention_dropout=0.0,
|
225 |
+
initializer_range=0.02,
|
226 |
+
initializer_factor=1.0,
|
227 |
+
pad_token_id=1,
|
228 |
+
bos_token_id=0,
|
229 |
+
eos_token_id=2,
|
230 |
+
model_type="clip_text_model",
|
231 |
+
projection_dim=768,
|
232 |
+
# torch_dtype="float32",
|
233 |
+
# transformers_version="4.25.0.dev0",
|
234 |
+
)
|
235 |
+
with init_empty_weights():
|
236 |
+
text_model1 = CLIPTextModel._from_config(text_model1_cfg)
|
237 |
+
|
238 |
+
# Text Encoder 2 is different from Stability AI's SDXL. SDXL uses open clip, but we use the model from HuggingFace.
|
239 |
+
# Note: Tokenizer from HuggingFace is different from SDXL. We must use open clip's tokenizer.
|
240 |
+
text_model2_cfg = CLIPTextConfig(
|
241 |
+
vocab_size=49408,
|
242 |
+
hidden_size=1280,
|
243 |
+
intermediate_size=5120,
|
244 |
+
num_hidden_layers=32,
|
245 |
+
num_attention_heads=20,
|
246 |
+
max_position_embeddings=77,
|
247 |
+
hidden_act="gelu",
|
248 |
+
layer_norm_eps=1e-05,
|
249 |
+
dropout=0.0,
|
250 |
+
attention_dropout=0.0,
|
251 |
+
initializer_range=0.02,
|
252 |
+
initializer_factor=1.0,
|
253 |
+
pad_token_id=1,
|
254 |
+
bos_token_id=0,
|
255 |
+
eos_token_id=2,
|
256 |
+
model_type="clip_text_model",
|
257 |
+
projection_dim=1280,
|
258 |
+
# torch_dtype="float32",
|
259 |
+
# transformers_version="4.25.0.dev0",
|
260 |
+
)
|
261 |
+
with init_empty_weights():
|
262 |
+
text_model2 = CLIPTextModelWithProjection(text_model2_cfg)
|
263 |
+
|
264 |
+
logger.info("loading text encoders from checkpoint")
|
265 |
+
te1_sd = {}
|
266 |
+
te2_sd = {}
|
267 |
+
for k in list(state_dict.keys()):
|
268 |
+
if k.startswith("conditioner.embedders.0.transformer."):
|
269 |
+
te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k)
|
270 |
+
elif k.startswith("conditioner.embedders.1.model."):
|
271 |
+
te2_sd[k] = state_dict.pop(k)
|
272 |
+
|
273 |
+
# 最新の transformers では position_ids を含むとエラーになるので削除 / remove position_ids for latest transformers
|
274 |
+
if "text_model.embeddings.position_ids" in te1_sd:
|
275 |
+
te1_sd.pop("text_model.embeddings.position_ids")
|
276 |
+
|
277 |
+
info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32
|
278 |
+
logger.info(f"text encoder 1: {info1}")
|
279 |
+
|
280 |
+
converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
|
281 |
+
info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location) # remain fp32
|
282 |
+
logger.info(f"text encoder 2: {info2}")
|
283 |
+
|
284 |
+
# prepare vae
|
285 |
+
logger.info("building VAE")
|
286 |
+
vae_config = model_util.create_vae_diffusers_config()
|
287 |
+
with init_empty_weights():
|
288 |
+
vae = AutoencoderKL(**vae_config)
|
289 |
+
|
290 |
+
logger.info("loading VAE from checkpoint")
|
291 |
+
converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config)
|
292 |
+
info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype)
|
293 |
+
logger.info(f"VAE: {info}")
|
294 |
+
|
295 |
+
ckpt_info = (epoch, global_step) if epoch is not None else None
|
296 |
+
return text_model1, text_model2, vae, unet, logit_scale, ckpt_info
|
297 |
+
|
298 |
+
|
299 |
+
def make_unet_conversion_map():
|
300 |
+
unet_conversion_map_layer = []
|
301 |
+
|
302 |
+
for i in range(3): # num_blocks is 3 in sdxl
|
303 |
+
# loop over downblocks/upblocks
|
304 |
+
for j in range(2):
|
305 |
+
# loop over resnets/attentions for downblocks
|
306 |
+
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
307 |
+
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
308 |
+
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
309 |
+
|
310 |
+
if i < 3:
|
311 |
+
# no attention layers in down_blocks.3
|
312 |
+
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
313 |
+
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
314 |
+
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
315 |
+
|
316 |
+
for j in range(3):
|
317 |
+
# loop over resnets/attentions for upblocks
|
318 |
+
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
319 |
+
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
320 |
+
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
321 |
+
|
322 |
+
# if i > 0: commentout for sdxl
|
323 |
+
# no attention layers in up_blocks.0
|
324 |
+
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
325 |
+
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
326 |
+
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
327 |
+
|
328 |
+
if i < 3:
|
329 |
+
# no downsample in down_blocks.3
|
330 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
331 |
+
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
332 |
+
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
333 |
+
|
334 |
+
# no upsample in up_blocks.3
|
335 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
336 |
+
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
|
337 |
+
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
338 |
+
|
339 |
+
hf_mid_atn_prefix = "mid_block.attentions.0."
|
340 |
+
sd_mid_atn_prefix = "middle_block.1."
|
341 |
+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
342 |
+
|
343 |
+
for j in range(2):
|
344 |
+
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
345 |
+
sd_mid_res_prefix = f"middle_block.{2*j}."
|
346 |
+
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
347 |
+
|
348 |
+
unet_conversion_map_resnet = [
|
349 |
+
# (stable-diffusion, HF Diffusers)
|
350 |
+
("in_layers.0.", "norm1."),
|
351 |
+
("in_layers.2.", "conv1."),
|
352 |
+
("out_layers.0.", "norm2."),
|
353 |
+
("out_layers.3.", "conv2."),
|
354 |
+
("emb_layers.1.", "time_emb_proj."),
|
355 |
+
("skip_connection.", "conv_shortcut."),
|
356 |
+
]
|
357 |
+
|
358 |
+
unet_conversion_map = []
|
359 |
+
for sd, hf in unet_conversion_map_layer:
|
360 |
+
if "resnets" in hf:
|
361 |
+
for sd_res, hf_res in unet_conversion_map_resnet:
|
362 |
+
unet_conversion_map.append((sd + sd_res, hf + hf_res))
|
363 |
+
else:
|
364 |
+
unet_conversion_map.append((sd, hf))
|
365 |
+
|
366 |
+
for j in range(2):
|
367 |
+
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
|
368 |
+
sd_time_embed_prefix = f"time_embed.{j*2}."
|
369 |
+
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
|
370 |
+
|
371 |
+
for j in range(2):
|
372 |
+
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
|
373 |
+
sd_label_embed_prefix = f"label_emb.0.{j*2}."
|
374 |
+
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
|
375 |
+
|
376 |
+
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
|
377 |
+
unet_conversion_map.append(("out.0.", "conv_norm_out."))
|
378 |
+
unet_conversion_map.append(("out.2.", "conv_out."))
|
379 |
+
|
380 |
+
return unet_conversion_map
|
381 |
+
|
382 |
+
|
383 |
+
def convert_diffusers_unet_state_dict_to_sdxl(du_sd):
|
384 |
+
unet_conversion_map = make_unet_conversion_map()
|
385 |
+
|
386 |
+
conversion_map = {hf: sd for sd, hf in unet_conversion_map}
|
387 |
+
return convert_unet_state_dict(du_sd, conversion_map)
|
388 |
+
|
389 |
+
|
390 |
+
def convert_unet_state_dict(src_sd, conversion_map):
|
391 |
+
converted_sd = {}
|
392 |
+
for src_key, value in src_sd.items():
|
393 |
+
# さすがに全部回すのは時間がかかるので右から要素を削りつつprefixを探す
|
394 |
+
src_key_fragments = src_key.split(".")[:-1] # remove weight/bias
|
395 |
+
while len(src_key_fragments) > 0:
|
396 |
+
src_key_prefix = ".".join(src_key_fragments) + "."
|
397 |
+
if src_key_prefix in conversion_map:
|
398 |
+
converted_prefix = conversion_map[src_key_prefix]
|
399 |
+
converted_key = converted_prefix + src_key[len(src_key_prefix) :]
|
400 |
+
converted_sd[converted_key] = value
|
401 |
+
break
|
402 |
+
src_key_fragments.pop(-1)
|
403 |
+
assert len(src_key_fragments) > 0, f"key {src_key} not found in conversion map"
|
404 |
+
|
405 |
+
return converted_sd
|
406 |
+
|
407 |
+
|
408 |
+
def convert_sdxl_unet_state_dict_to_diffusers(sd):
|
409 |
+
unet_conversion_map = make_unet_conversion_map()
|
410 |
+
|
411 |
+
conversion_dict = {sd: hf for sd, hf in unet_conversion_map}
|
412 |
+
return convert_unet_state_dict(sd, conversion_dict)
|
413 |
+
|
414 |
+
|
415 |
+
def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale):
|
416 |
+
def convert_key(key):
|
417 |
+
# position_idsの除去
|
418 |
+
if ".position_ids" in key:
|
419 |
+
return None
|
420 |
+
|
421 |
+
# common
|
422 |
+
key = key.replace("text_model.encoder.", "transformer.")
|
423 |
+
key = key.replace("text_model.", "")
|
424 |
+
if "layers" in key:
|
425 |
+
# resblocks conversion
|
426 |
+
key = key.replace(".layers.", ".resblocks.")
|
427 |
+
if ".layer_norm" in key:
|
428 |
+
key = key.replace(".layer_norm", ".ln_")
|
429 |
+
elif ".mlp." in key:
|
430 |
+
key = key.replace(".fc1.", ".c_fc.")
|
431 |
+
key = key.replace(".fc2.", ".c_proj.")
|
432 |
+
elif ".self_attn.out_proj" in key:
|
433 |
+
key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
|
434 |
+
elif ".self_attn." in key:
|
435 |
+
key = None # 特殊なので後で処理する
|
436 |
+
else:
|
437 |
+
raise ValueError(f"unexpected key in DiffUsers model: {key}")
|
438 |
+
elif ".position_embedding" in key:
|
439 |
+
key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
|
440 |
+
elif ".token_embedding" in key:
|
441 |
+
key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
|
442 |
+
elif "text_projection" in key: # no dot in key
|
443 |
+
key = key.replace("text_projection.weight", "text_projection")
|
444 |
+
elif "final_layer_norm" in key:
|
445 |
+
key = key.replace("final_layer_norm", "ln_final")
|
446 |
+
return key
|
447 |
+
|
448 |
+
keys = list(checkpoint.keys())
|
449 |
+
new_sd = {}
|
450 |
+
for key in keys:
|
451 |
+
new_key = convert_key(key)
|
452 |
+
if new_key is None:
|
453 |
+
continue
|
454 |
+
new_sd[new_key] = checkpoint[key]
|
455 |
+
|
456 |
+
# attnの変換
|
457 |
+
for key in keys:
|
458 |
+
if "layers" in key and "q_proj" in key:
|
459 |
+
# 三つを結合
|
460 |
+
key_q = key
|
461 |
+
key_k = key.replace("q_proj", "k_proj")
|
462 |
+
key_v = key.replace("q_proj", "v_proj")
|
463 |
+
|
464 |
+
value_q = checkpoint[key_q]
|
465 |
+
value_k = checkpoint[key_k]
|
466 |
+
value_v = checkpoint[key_v]
|
467 |
+
value = torch.cat([value_q, value_k, value_v])
|
468 |
+
|
469 |
+
new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
|
470 |
+
new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
|
471 |
+
new_sd[new_key] = value
|
472 |
+
|
473 |
+
if logit_scale is not None:
|
474 |
+
new_sd["logit_scale"] = logit_scale
|
475 |
+
|
476 |
+
return new_sd
|
477 |
+
|
478 |
+
|
479 |
+
def save_stable_diffusion_checkpoint(
|
480 |
+
output_file,
|
481 |
+
text_encoder1,
|
482 |
+
text_encoder2,
|
483 |
+
unet,
|
484 |
+
epochs,
|
485 |
+
steps,
|
486 |
+
ckpt_info,
|
487 |
+
vae,
|
488 |
+
logit_scale,
|
489 |
+
metadata,
|
490 |
+
save_dtype=None,
|
491 |
+
):
|
492 |
+
state_dict = {}
|
493 |
+
|
494 |
+
def update_sd(prefix, sd):
|
495 |
+
for k, v in sd.items():
|
496 |
+
key = prefix + k
|
497 |
+
if save_dtype is not None:
|
498 |
+
v = v.detach().clone().to("cpu").to(save_dtype)
|
499 |
+
state_dict[key] = v
|
500 |
+
|
501 |
+
# Convert the UNet model
|
502 |
+
update_sd("model.diffusion_model.", unet.state_dict())
|
503 |
+
|
504 |
+
# Convert the text encoders
|
505 |
+
update_sd("conditioner.embedders.0.transformer.", text_encoder1.state_dict())
|
506 |
+
|
507 |
+
text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(text_encoder2.state_dict(), logit_scale)
|
508 |
+
update_sd("conditioner.embedders.1.model.", text_enc2_dict)
|
509 |
+
|
510 |
+
# Convert the VAE
|
511 |
+
vae_dict = model_util.convert_vae_state_dict(vae.state_dict())
|
512 |
+
update_sd("first_stage_model.", vae_dict)
|
513 |
+
|
514 |
+
# Put together new checkpoint
|
515 |
+
key_count = len(state_dict.keys())
|
516 |
+
new_ckpt = {"state_dict": state_dict}
|
517 |
+
|
518 |
+
# epoch and global_step are sometimes not int
|
519 |
+
if ckpt_info is not None:
|
520 |
+
epochs += ckpt_info[0]
|
521 |
+
steps += ckpt_info[1]
|
522 |
+
|
523 |
+
new_ckpt["epoch"] = epochs
|
524 |
+
new_ckpt["global_step"] = steps
|
525 |
+
|
526 |
+
if model_util.is_safetensors(output_file):
|
527 |
+
save_file(state_dict, output_file, metadata)
|
528 |
+
else:
|
529 |
+
torch.save(new_ckpt, output_file)
|
530 |
+
|
531 |
+
return key_count
|
532 |
+
|
533 |
+
|
534 |
+
def save_diffusers_checkpoint(
|
535 |
+
output_dir, text_encoder1, text_encoder2, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False, save_dtype=None
|
536 |
+
):
|
537 |
+
from diffusers import StableDiffusionXLPipeline
|
538 |
+
|
539 |
+
# convert U-Net
|
540 |
+
unet_sd = unet.state_dict()
|
541 |
+
du_unet_sd = convert_sdxl_unet_state_dict_to_diffusers(unet_sd)
|
542 |
+
|
543 |
+
diffusers_unet = UNet2DConditionModel(**DIFFUSERS_SDXL_UNET_CONFIG)
|
544 |
+
if save_dtype is not None:
|
545 |
+
diffusers_unet.to(save_dtype)
|
546 |
+
diffusers_unet.load_state_dict(du_unet_sd)
|
547 |
+
|
548 |
+
# create pipeline to save
|
549 |
+
if pretrained_model_name_or_path is None:
|
550 |
+
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_SDXL
|
551 |
+
|
552 |
+
scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
|
553 |
+
tokenizer1 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
|
554 |
+
tokenizer2 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer_2")
|
555 |
+
if vae is None:
|
556 |
+
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
|
557 |
+
|
558 |
+
# prevent local path from being saved
|
559 |
+
def remove_name_or_path(model):
|
560 |
+
if hasattr(model, "config"):
|
561 |
+
model.config._name_or_path = None
|
562 |
+
model.config._name_or_path = None
|
563 |
+
|
564 |
+
remove_name_or_path(diffusers_unet)
|
565 |
+
remove_name_or_path(text_encoder1)
|
566 |
+
remove_name_or_path(text_encoder2)
|
567 |
+
remove_name_or_path(scheduler)
|
568 |
+
remove_name_or_path(tokenizer1)
|
569 |
+
remove_name_or_path(tokenizer2)
|
570 |
+
remove_name_or_path(vae)
|
571 |
+
|
572 |
+
pipeline = StableDiffusionXLPipeline(
|
573 |
+
unet=diffusers_unet,
|
574 |
+
text_encoder=text_encoder1,
|
575 |
+
text_encoder_2=text_encoder2,
|
576 |
+
vae=vae,
|
577 |
+
scheduler=scheduler,
|
578 |
+
tokenizer=tokenizer1,
|
579 |
+
tokenizer_2=tokenizer2,
|
580 |
+
)
|
581 |
+
if save_dtype is not None:
|
582 |
+
pipeline.to(None, save_dtype)
|
583 |
+
pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
|
library/sdxl_original_control_net.py
ADDED
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# some parts are modified from Diffusers library (Apache License 2.0)
|
2 |
+
|
3 |
+
import math
|
4 |
+
from types import SimpleNamespace
|
5 |
+
from typing import Any, Optional
|
6 |
+
import torch
|
7 |
+
import torch.utils.checkpoint
|
8 |
+
from torch import nn
|
9 |
+
from torch.nn import functional as F
|
10 |
+
from einops import rearrange
|
11 |
+
from library.utils import setup_logging
|
12 |
+
|
13 |
+
setup_logging()
|
14 |
+
import logging
|
15 |
+
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
from library import sdxl_original_unet
|
19 |
+
from library.sdxl_model_util import convert_sdxl_unet_state_dict_to_diffusers, convert_diffusers_unet_state_dict_to_sdxl
|
20 |
+
|
21 |
+
|
22 |
+
class ControlNetConditioningEmbedding(nn.Module):
|
23 |
+
def __init__(self):
|
24 |
+
super().__init__()
|
25 |
+
|
26 |
+
dims = [16, 32, 96, 256]
|
27 |
+
|
28 |
+
self.conv_in = nn.Conv2d(3, dims[0], kernel_size=3, padding=1)
|
29 |
+
self.blocks = nn.ModuleList([])
|
30 |
+
|
31 |
+
for i in range(len(dims) - 1):
|
32 |
+
channel_in = dims[i]
|
33 |
+
channel_out = dims[i + 1]
|
34 |
+
self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
|
35 |
+
self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
|
36 |
+
|
37 |
+
self.conv_out = nn.Conv2d(dims[-1], 320, kernel_size=3, padding=1)
|
38 |
+
nn.init.zeros_(self.conv_out.weight) # zero module weight
|
39 |
+
nn.init.zeros_(self.conv_out.bias) # zero module bias
|
40 |
+
|
41 |
+
def forward(self, x):
|
42 |
+
x = self.conv_in(x)
|
43 |
+
x = F.silu(x)
|
44 |
+
for block in self.blocks:
|
45 |
+
x = block(x)
|
46 |
+
x = F.silu(x)
|
47 |
+
x = self.conv_out(x)
|
48 |
+
return x
|
49 |
+
|
50 |
+
|
51 |
+
class SdxlControlNet(sdxl_original_unet.SdxlUNet2DConditionModel):
|
52 |
+
def __init__(self, multiplier: Optional[float] = None, **kwargs):
|
53 |
+
super().__init__(**kwargs)
|
54 |
+
self.multiplier = multiplier
|
55 |
+
|
56 |
+
# remove unet layers
|
57 |
+
self.output_blocks = nn.ModuleList([])
|
58 |
+
del self.out
|
59 |
+
|
60 |
+
self.controlnet_cond_embedding = ControlNetConditioningEmbedding()
|
61 |
+
|
62 |
+
dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280]
|
63 |
+
self.controlnet_down_blocks = nn.ModuleList([])
|
64 |
+
for dim in dims:
|
65 |
+
self.controlnet_down_blocks.append(nn.Conv2d(dim, dim, kernel_size=1))
|
66 |
+
nn.init.zeros_(self.controlnet_down_blocks[-1].weight) # zero module weight
|
67 |
+
nn.init.zeros_(self.controlnet_down_blocks[-1].bias) # zero module bias
|
68 |
+
|
69 |
+
self.controlnet_mid_block = nn.Conv2d(1280, 1280, kernel_size=1)
|
70 |
+
nn.init.zeros_(self.controlnet_mid_block.weight) # zero module weight
|
71 |
+
nn.init.zeros_(self.controlnet_mid_block.bias) # zero module bias
|
72 |
+
|
73 |
+
def init_from_unet(self, unet: sdxl_original_unet.SdxlUNet2DConditionModel):
|
74 |
+
unet_sd = unet.state_dict()
|
75 |
+
unet_sd = {k: v for k, v in unet_sd.items() if not k.startswith("out")}
|
76 |
+
sd = super().state_dict()
|
77 |
+
sd.update(unet_sd)
|
78 |
+
info = super().load_state_dict(sd, strict=True, assign=True)
|
79 |
+
return info
|
80 |
+
|
81 |
+
def load_state_dict(self, state_dict: dict, strict: bool = True, assign: bool = True) -> Any:
|
82 |
+
# convert state_dict to SAI format
|
83 |
+
unet_sd = {}
|
84 |
+
for k in list(state_dict.keys()):
|
85 |
+
if not k.startswith("controlnet_"):
|
86 |
+
unet_sd[k] = state_dict.pop(k)
|
87 |
+
unet_sd = convert_diffusers_unet_state_dict_to_sdxl(unet_sd)
|
88 |
+
state_dict.update(unet_sd)
|
89 |
+
super().load_state_dict(state_dict, strict=strict, assign=assign)
|
90 |
+
|
91 |
+
def state_dict(self, destination=None, prefix="", keep_vars=False):
|
92 |
+
# convert state_dict to Diffusers format
|
93 |
+
state_dict = super().state_dict(destination, prefix, keep_vars)
|
94 |
+
control_net_sd = {}
|
95 |
+
for k in list(state_dict.keys()):
|
96 |
+
if k.startswith("controlnet_"):
|
97 |
+
control_net_sd[k] = state_dict.pop(k)
|
98 |
+
state_dict = convert_sdxl_unet_state_dict_to_diffusers(state_dict)
|
99 |
+
state_dict.update(control_net_sd)
|
100 |
+
return state_dict
|
101 |
+
|
102 |
+
def forward(
|
103 |
+
self,
|
104 |
+
x: torch.Tensor,
|
105 |
+
timesteps: Optional[torch.Tensor] = None,
|
106 |
+
context: Optional[torch.Tensor] = None,
|
107 |
+
y: Optional[torch.Tensor] = None,
|
108 |
+
cond_image: Optional[torch.Tensor] = None,
|
109 |
+
**kwargs,
|
110 |
+
) -> torch.Tensor:
|
111 |
+
# broadcast timesteps to batch dimension
|
112 |
+
timesteps = timesteps.expand(x.shape[0])
|
113 |
+
|
114 |
+
t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0)
|
115 |
+
t_emb = t_emb.to(x.dtype)
|
116 |
+
emb = self.time_embed(t_emb)
|
117 |
+
|
118 |
+
assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
|
119 |
+
assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
|
120 |
+
emb = emb + self.label_emb(y)
|
121 |
+
|
122 |
+
def call_module(module, h, emb, context):
|
123 |
+
x = h
|
124 |
+
for layer in module:
|
125 |
+
if isinstance(layer, sdxl_original_unet.ResnetBlock2D):
|
126 |
+
x = layer(x, emb)
|
127 |
+
elif isinstance(layer, sdxl_original_unet.Transformer2DModel):
|
128 |
+
x = layer(x, context)
|
129 |
+
else:
|
130 |
+
x = layer(x)
|
131 |
+
return x
|
132 |
+
|
133 |
+
h = x
|
134 |
+
multiplier = self.multiplier if self.multiplier is not None else 1.0
|
135 |
+
hs = []
|
136 |
+
for i, module in enumerate(self.input_blocks):
|
137 |
+
h = call_module(module, h, emb, context)
|
138 |
+
if i == 0:
|
139 |
+
h = self.controlnet_cond_embedding(cond_image) + h
|
140 |
+
hs.append(self.controlnet_down_blocks[i](h) * multiplier)
|
141 |
+
|
142 |
+
h = call_module(self.middle_block, h, emb, context)
|
143 |
+
h = self.controlnet_mid_block(h) * multiplier
|
144 |
+
|
145 |
+
return hs, h
|
146 |
+
|
147 |
+
|
148 |
+
class SdxlControlledUNet(sdxl_original_unet.SdxlUNet2DConditionModel):
|
149 |
+
"""
|
150 |
+
This class is for training purpose only.
|
151 |
+
"""
|
152 |
+
|
153 |
+
def __init__(self, **kwargs):
|
154 |
+
super().__init__(**kwargs)
|
155 |
+
|
156 |
+
def forward(self, x, timesteps=None, context=None, y=None, input_resi_add=None, mid_add=None, **kwargs):
|
157 |
+
# broadcast timesteps to batch dimension
|
158 |
+
timesteps = timesteps.expand(x.shape[0])
|
159 |
+
|
160 |
+
hs = []
|
161 |
+
t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0)
|
162 |
+
t_emb = t_emb.to(x.dtype)
|
163 |
+
emb = self.time_embed(t_emb)
|
164 |
+
|
165 |
+
assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
|
166 |
+
assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
|
167 |
+
emb = emb + self.label_emb(y)
|
168 |
+
|
169 |
+
def call_module(module, h, emb, context):
|
170 |
+
x = h
|
171 |
+
for layer in module:
|
172 |
+
if isinstance(layer, sdxl_original_unet.ResnetBlock2D):
|
173 |
+
x = layer(x, emb)
|
174 |
+
elif isinstance(layer, sdxl_original_unet.Transformer2DModel):
|
175 |
+
x = layer(x, context)
|
176 |
+
else:
|
177 |
+
x = layer(x)
|
178 |
+
return x
|
179 |
+
|
180 |
+
h = x
|
181 |
+
for module in self.input_blocks:
|
182 |
+
h = call_module(module, h, emb, context)
|
183 |
+
hs.append(h)
|
184 |
+
|
185 |
+
h = call_module(self.middle_block, h, emb, context)
|
186 |
+
h = h + mid_add
|
187 |
+
|
188 |
+
for module in self.output_blocks:
|
189 |
+
resi = hs.pop() + input_resi_add.pop()
|
190 |
+
h = torch.cat([h, resi], dim=1)
|
191 |
+
h = call_module(module, h, emb, context)
|
192 |
+
|
193 |
+
h = h.type(x.dtype)
|
194 |
+
h = call_module(self.out, h, emb, context)
|
195 |
+
|
196 |
+
return h
|
197 |
+
|
198 |
+
|
199 |
+
if __name__ == "__main__":
|
200 |
+
import time
|
201 |
+
|
202 |
+
logger.info("create unet")
|
203 |
+
unet = SdxlControlledUNet()
|
204 |
+
unet.to("cuda", torch.bfloat16)
|
205 |
+
unet.set_use_sdpa(True)
|
206 |
+
unet.set_gradient_checkpointing(True)
|
207 |
+
unet.train()
|
208 |
+
|
209 |
+
logger.info("create control_net")
|
210 |
+
control_net = SdxlControlNet()
|
211 |
+
control_net.to("cuda")
|
212 |
+
control_net.set_use_sdpa(True)
|
213 |
+
control_net.set_gradient_checkpointing(True)
|
214 |
+
control_net.train()
|
215 |
+
|
216 |
+
logger.info("Initialize control_net from unet")
|
217 |
+
control_net.init_from_unet(unet)
|
218 |
+
|
219 |
+
unet.requires_grad_(False)
|
220 |
+
control_net.requires_grad_(True)
|
221 |
+
|
222 |
+
# 使用メモリ量確認用の疑似学習ループ
|
223 |
+
logger.info("preparing optimizer")
|
224 |
+
|
225 |
+
# optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working
|
226 |
+
|
227 |
+
import bitsandbytes
|
228 |
+
|
229 |
+
optimizer = bitsandbytes.adam.Adam8bit(control_net.parameters(), lr=1e-3) # not working
|
230 |
+
# optimizer = bitsandbytes.optim.RMSprop8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2
|
231 |
+
# optimizer=bitsandbytes.optim.Adagrad8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2
|
232 |
+
|
233 |
+
# import transformers
|
234 |
+
# optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True) # working at 22.2GB with torch2
|
235 |
+
|
236 |
+
scaler = torch.cuda.amp.GradScaler(enabled=True)
|
237 |
+
|
238 |
+
logger.info("start training")
|
239 |
+
steps = 10
|
240 |
+
batch_size = 1
|
241 |
+
|
242 |
+
for step in range(steps):
|
243 |
+
logger.info(f"step {step}")
|
244 |
+
if step == 1:
|
245 |
+
time_start = time.perf_counter()
|
246 |
+
|
247 |
+
x = torch.randn(batch_size, 4, 128, 128).cuda() # 1024x1024
|
248 |
+
t = torch.randint(low=0, high=1000, size=(batch_size,), device="cuda")
|
249 |
+
txt = torch.randn(batch_size, 77, 2048).cuda()
|
250 |
+
vector = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda()
|
251 |
+
cond_img = torch.rand(batch_size, 3, 1024, 1024).cuda()
|
252 |
+
|
253 |
+
with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
|
254 |
+
input_resi_add, mid_add = control_net(x, t, txt, vector, cond_img)
|
255 |
+
output = unet(x, t, txt, vector, input_resi_add, mid_add)
|
256 |
+
target = torch.randn_like(output)
|
257 |
+
loss = torch.nn.functional.mse_loss(output, target)
|
258 |
+
|
259 |
+
scaler.scale(loss).backward()
|
260 |
+
scaler.step(optimizer)
|
261 |
+
scaler.update()
|
262 |
+
optimizer.zero_grad(set_to_none=True)
|
263 |
+
|
264 |
+
time_end = time.perf_counter()
|
265 |
+
logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")
|
266 |
+
|
267 |
+
logger.info("finish training")
|
268 |
+
sd = control_net.state_dict()
|
269 |
+
|
270 |
+
from safetensors.torch import save_file
|
271 |
+
|
272 |
+
save_file(sd, r"E:\Work\SD\Tmp\sdxl\ctrl\control_net.safetensors")
|
library/sdxl_original_unet.py
ADDED
@@ -0,0 +1,1292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Diffusersのコードをベースとした sd_xl_baseのU-Net
|
2 |
+
# state dictの形式をSDXLに合わせてある
|
3 |
+
|
4 |
+
"""
|
5 |
+
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
6 |
+
params:
|
7 |
+
adm_in_channels: 2816
|
8 |
+
num_classes: sequential
|
9 |
+
use_checkpoint: True
|
10 |
+
in_channels: 4
|
11 |
+
out_channels: 4
|
12 |
+
model_channels: 320
|
13 |
+
attention_resolutions: [4, 2]
|
14 |
+
num_res_blocks: 2
|
15 |
+
channel_mult: [1, 2, 4]
|
16 |
+
num_head_channels: 64
|
17 |
+
use_spatial_transformer: True
|
18 |
+
use_linear_in_transformer: True
|
19 |
+
transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
|
20 |
+
context_dim: 2048
|
21 |
+
spatial_transformer_attn_type: softmax-xformers
|
22 |
+
legacy: False
|
23 |
+
"""
|
24 |
+
|
25 |
+
import math
|
26 |
+
from types import SimpleNamespace
|
27 |
+
from typing import Any, Optional
|
28 |
+
import torch
|
29 |
+
import torch.utils.checkpoint
|
30 |
+
from torch import nn
|
31 |
+
from torch.nn import functional as F
|
32 |
+
from einops import rearrange
|
33 |
+
from library.utils import setup_logging
|
34 |
+
|
35 |
+
setup_logging()
|
36 |
+
import logging
|
37 |
+
|
38 |
+
logger = logging.getLogger(__name__)
|
39 |
+
|
40 |
+
IN_CHANNELS: int = 4
|
41 |
+
OUT_CHANNELS: int = 4
|
42 |
+
ADM_IN_CHANNELS: int = 2816
|
43 |
+
CONTEXT_DIM: int = 2048
|
44 |
+
MODEL_CHANNELS: int = 320
|
45 |
+
TIME_EMBED_DIM = 320 * 4
|
46 |
+
|
47 |
+
USE_REENTRANT = True
|
48 |
+
|
49 |
+
# region memory efficient attention
|
50 |
+
|
51 |
+
# FlashAttentionを使うCrossAttention
|
52 |
+
# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
|
53 |
+
# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
|
54 |
+
|
55 |
+
# constants
|
56 |
+
|
57 |
+
EPSILON = 1e-6
|
58 |
+
|
59 |
+
# helper functions
|
60 |
+
|
61 |
+
|
62 |
+
def exists(val):
|
63 |
+
return val is not None
|
64 |
+
|
65 |
+
|
66 |
+
def default(val, d):
|
67 |
+
return val if exists(val) else d
|
68 |
+
|
69 |
+
|
70 |
+
# flash attention forwards and backwards
|
71 |
+
|
72 |
+
# https://arxiv.org/abs/2205.14135
|
73 |
+
|
74 |
+
|
75 |
+
class FlashAttentionFunction(torch.autograd.Function):
|
76 |
+
@staticmethod
|
77 |
+
@torch.no_grad()
|
78 |
+
def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
|
79 |
+
"""Algorithm 2 in the paper"""
|
80 |
+
|
81 |
+
device = q.device
|
82 |
+
dtype = q.dtype
|
83 |
+
max_neg_value = -torch.finfo(q.dtype).max
|
84 |
+
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
85 |
+
|
86 |
+
o = torch.zeros_like(q)
|
87 |
+
all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
|
88 |
+
all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
|
89 |
+
|
90 |
+
scale = q.shape[-1] ** -0.5
|
91 |
+
|
92 |
+
if not exists(mask):
|
93 |
+
mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
|
94 |
+
else:
|
95 |
+
mask = rearrange(mask, "b n -> b 1 1 n")
|
96 |
+
mask = mask.split(q_bucket_size, dim=-1)
|
97 |
+
|
98 |
+
row_splits = zip(
|
99 |
+
q.split(q_bucket_size, dim=-2),
|
100 |
+
o.split(q_bucket_size, dim=-2),
|
101 |
+
mask,
|
102 |
+
all_row_sums.split(q_bucket_size, dim=-2),
|
103 |
+
all_row_maxes.split(q_bucket_size, dim=-2),
|
104 |
+
)
|
105 |
+
|
106 |
+
for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
|
107 |
+
q_start_index = ind * q_bucket_size - qk_len_diff
|
108 |
+
|
109 |
+
col_splits = zip(
|
110 |
+
k.split(k_bucket_size, dim=-2),
|
111 |
+
v.split(k_bucket_size, dim=-2),
|
112 |
+
)
|
113 |
+
|
114 |
+
for k_ind, (kc, vc) in enumerate(col_splits):
|
115 |
+
k_start_index = k_ind * k_bucket_size
|
116 |
+
|
117 |
+
attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
|
118 |
+
|
119 |
+
if exists(row_mask):
|
120 |
+
attn_weights.masked_fill_(~row_mask, max_neg_value)
|
121 |
+
|
122 |
+
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
123 |
+
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
|
124 |
+
q_start_index - k_start_index + 1
|
125 |
+
)
|
126 |
+
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
127 |
+
|
128 |
+
block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
|
129 |
+
attn_weights -= block_row_maxes
|
130 |
+
exp_weights = torch.exp(attn_weights)
|
131 |
+
|
132 |
+
if exists(row_mask):
|
133 |
+
exp_weights.masked_fill_(~row_mask, 0.0)
|
134 |
+
|
135 |
+
block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
|
136 |
+
|
137 |
+
new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
|
138 |
+
|
139 |
+
exp_values = torch.einsum("... i j, ... j d -> ... i d", exp_weights, vc)
|
140 |
+
|
141 |
+
exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
|
142 |
+
exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
|
143 |
+
|
144 |
+
new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
|
145 |
+
|
146 |
+
oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
|
147 |
+
|
148 |
+
row_maxes.copy_(new_row_maxes)
|
149 |
+
row_sums.copy_(new_row_sums)
|
150 |
+
|
151 |
+
ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
|
152 |
+
ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
|
153 |
+
|
154 |
+
return o
|
155 |
+
|
156 |
+
@staticmethod
|
157 |
+
@torch.no_grad()
|
158 |
+
def backward(ctx, do):
|
159 |
+
"""Algorithm 4 in the paper"""
|
160 |
+
|
161 |
+
causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
|
162 |
+
q, k, v, o, l, m = ctx.saved_tensors
|
163 |
+
|
164 |
+
device = q.device
|
165 |
+
|
166 |
+
max_neg_value = -torch.finfo(q.dtype).max
|
167 |
+
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
168 |
+
|
169 |
+
dq = torch.zeros_like(q)
|
170 |
+
dk = torch.zeros_like(k)
|
171 |
+
dv = torch.zeros_like(v)
|
172 |
+
|
173 |
+
row_splits = zip(
|
174 |
+
q.split(q_bucket_size, dim=-2),
|
175 |
+
o.split(q_bucket_size, dim=-2),
|
176 |
+
do.split(q_bucket_size, dim=-2),
|
177 |
+
mask,
|
178 |
+
l.split(q_bucket_size, dim=-2),
|
179 |
+
m.split(q_bucket_size, dim=-2),
|
180 |
+
dq.split(q_bucket_size, dim=-2),
|
181 |
+
)
|
182 |
+
|
183 |
+
for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
|
184 |
+
q_start_index = ind * q_bucket_size - qk_len_diff
|
185 |
+
|
186 |
+
col_splits = zip(
|
187 |
+
k.split(k_bucket_size, dim=-2),
|
188 |
+
v.split(k_bucket_size, dim=-2),
|
189 |
+
dk.split(k_bucket_size, dim=-2),
|
190 |
+
dv.split(k_bucket_size, dim=-2),
|
191 |
+
)
|
192 |
+
|
193 |
+
for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
|
194 |
+
k_start_index = k_ind * k_bucket_size
|
195 |
+
|
196 |
+
attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
|
197 |
+
|
198 |
+
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
199 |
+
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
|
200 |
+
q_start_index - k_start_index + 1
|
201 |
+
)
|
202 |
+
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
203 |
+
|
204 |
+
exp_attn_weights = torch.exp(attn_weights - mc)
|
205 |
+
|
206 |
+
if exists(row_mask):
|
207 |
+
exp_attn_weights.masked_fill_(~row_mask, 0.0)
|
208 |
+
|
209 |
+
p = exp_attn_weights / lc
|
210 |
+
|
211 |
+
dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
|
212 |
+
dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
|
213 |
+
|
214 |
+
D = (doc * oc).sum(dim=-1, keepdims=True)
|
215 |
+
ds = p * scale * (dp - D)
|
216 |
+
|
217 |
+
dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
|
218 |
+
dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
|
219 |
+
|
220 |
+
dqc.add_(dq_chunk)
|
221 |
+
dkc.add_(dk_chunk)
|
222 |
+
dvc.add_(dv_chunk)
|
223 |
+
|
224 |
+
return dq, dk, dv, None, None, None, None
|
225 |
+
|
226 |
+
|
227 |
+
# endregion
|
228 |
+
|
229 |
+
|
230 |
+
def get_parameter_dtype(parameter: torch.nn.Module):
|
231 |
+
return next(parameter.parameters()).dtype
|
232 |
+
|
233 |
+
|
234 |
+
def get_parameter_device(parameter: torch.nn.Module):
|
235 |
+
return next(parameter.parameters()).device
|
236 |
+
|
237 |
+
|
238 |
+
def get_timestep_embedding(
|
239 |
+
timesteps: torch.Tensor,
|
240 |
+
embedding_dim: int,
|
241 |
+
downscale_freq_shift: float = 1,
|
242 |
+
scale: float = 1,
|
243 |
+
max_period: int = 10000,
|
244 |
+
):
|
245 |
+
"""
|
246 |
+
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
247 |
+
|
248 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
249 |
+
These may be fractional.
|
250 |
+
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
|
251 |
+
embeddings. :return: an [N x dim] Tensor of positional embeddings.
|
252 |
+
"""
|
253 |
+
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
254 |
+
|
255 |
+
half_dim = embedding_dim // 2
|
256 |
+
exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
|
257 |
+
exponent = exponent / (half_dim - downscale_freq_shift)
|
258 |
+
|
259 |
+
emb = torch.exp(exponent)
|
260 |
+
emb = timesteps[:, None].float() * emb[None, :]
|
261 |
+
|
262 |
+
# scale embeddings
|
263 |
+
emb = scale * emb
|
264 |
+
|
265 |
+
# concat sine and cosine embeddings: flipped from Diffusers original ver because always flip_sin_to_cos=True
|
266 |
+
emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1)
|
267 |
+
|
268 |
+
# zero pad
|
269 |
+
if embedding_dim % 2 == 1:
|
270 |
+
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
271 |
+
return emb
|
272 |
+
|
273 |
+
|
274 |
+
# Deep Shrink: We do not common this function, because minimize dependencies.
|
275 |
+
def resize_like(x, target, mode="bicubic", align_corners=False):
|
276 |
+
org_dtype = x.dtype
|
277 |
+
if org_dtype == torch.bfloat16:
|
278 |
+
x = x.to(torch.float32)
|
279 |
+
|
280 |
+
if x.shape[-2:] != target.shape[-2:]:
|
281 |
+
if mode == "nearest":
|
282 |
+
x = F.interpolate(x, size=target.shape[-2:], mode=mode)
|
283 |
+
else:
|
284 |
+
x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners)
|
285 |
+
|
286 |
+
if org_dtype == torch.bfloat16:
|
287 |
+
x = x.to(org_dtype)
|
288 |
+
return x
|
289 |
+
|
290 |
+
|
291 |
+
class GroupNorm32(nn.GroupNorm):
|
292 |
+
def forward(self, x):
|
293 |
+
if self.weight.dtype != torch.float32:
|
294 |
+
return super().forward(x)
|
295 |
+
return super().forward(x.float()).type(x.dtype)
|
296 |
+
|
297 |
+
|
298 |
+
class ResnetBlock2D(nn.Module):
|
299 |
+
def __init__(
|
300 |
+
self,
|
301 |
+
in_channels,
|
302 |
+
out_channels,
|
303 |
+
):
|
304 |
+
super().__init__()
|
305 |
+
self.in_channels = in_channels
|
306 |
+
self.out_channels = out_channels
|
307 |
+
|
308 |
+
self.in_layers = nn.Sequential(
|
309 |
+
GroupNorm32(32, in_channels),
|
310 |
+
nn.SiLU(),
|
311 |
+
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
|
312 |
+
)
|
313 |
+
|
314 |
+
self.emb_layers = nn.Sequential(nn.SiLU(), nn.Linear(TIME_EMBED_DIM, out_channels))
|
315 |
+
|
316 |
+
self.out_layers = nn.Sequential(
|
317 |
+
GroupNorm32(32, out_channels),
|
318 |
+
nn.SiLU(),
|
319 |
+
nn.Identity(), # to make state_dict compatible with original model
|
320 |
+
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
|
321 |
+
)
|
322 |
+
|
323 |
+
if in_channels != out_channels:
|
324 |
+
self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
325 |
+
else:
|
326 |
+
self.skip_connection = nn.Identity()
|
327 |
+
|
328 |
+
self.gradient_checkpointing = False
|
329 |
+
|
330 |
+
def forward_body(self, x, emb):
|
331 |
+
h = self.in_layers(x)
|
332 |
+
emb_out = self.emb_layers(emb).type(h.dtype)
|
333 |
+
h = h + emb_out[:, :, None, None]
|
334 |
+
h = self.out_layers(h)
|
335 |
+
x = self.skip_connection(x)
|
336 |
+
return x + h
|
337 |
+
|
338 |
+
def forward(self, x, emb):
|
339 |
+
if self.training and self.gradient_checkpointing:
|
340 |
+
# logger.info("ResnetBlock2D: gradient_checkpointing")
|
341 |
+
|
342 |
+
def create_custom_forward(func):
|
343 |
+
def custom_forward(*inputs):
|
344 |
+
return func(*inputs)
|
345 |
+
|
346 |
+
return custom_forward
|
347 |
+
|
348 |
+
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x, emb, use_reentrant=USE_REENTRANT)
|
349 |
+
else:
|
350 |
+
x = self.forward_body(x, emb)
|
351 |
+
|
352 |
+
return x
|
353 |
+
|
354 |
+
|
355 |
+
class Downsample2D(nn.Module):
|
356 |
+
def __init__(self, channels, out_channels):
|
357 |
+
super().__init__()
|
358 |
+
|
359 |
+
self.channels = channels
|
360 |
+
self.out_channels = out_channels
|
361 |
+
|
362 |
+
self.op = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=1)
|
363 |
+
|
364 |
+
self.gradient_checkpointing = False
|
365 |
+
|
366 |
+
def forward_body(self, hidden_states):
|
367 |
+
assert hidden_states.shape[1] == self.channels
|
368 |
+
hidden_states = self.op(hidden_states)
|
369 |
+
|
370 |
+
return hidden_states
|
371 |
+
|
372 |
+
def forward(self, hidden_states):
|
373 |
+
if self.training and self.gradient_checkpointing:
|
374 |
+
# logger.info("Downsample2D: gradient_checkpointing")
|
375 |
+
|
376 |
+
def create_custom_forward(func):
|
377 |
+
def custom_forward(*inputs):
|
378 |
+
return func(*inputs)
|
379 |
+
|
380 |
+
return custom_forward
|
381 |
+
|
382 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
383 |
+
create_custom_forward(self.forward_body), hidden_states, use_reentrant=USE_REENTRANT
|
384 |
+
)
|
385 |
+
else:
|
386 |
+
hidden_states = self.forward_body(hidden_states)
|
387 |
+
|
388 |
+
return hidden_states
|
389 |
+
|
390 |
+
|
391 |
+
class CrossAttention(nn.Module):
|
392 |
+
def __init__(
|
393 |
+
self,
|
394 |
+
query_dim: int,
|
395 |
+
cross_attention_dim: Optional[int] = None,
|
396 |
+
heads: int = 8,
|
397 |
+
dim_head: int = 64,
|
398 |
+
upcast_attention: bool = False,
|
399 |
+
):
|
400 |
+
super().__init__()
|
401 |
+
inner_dim = dim_head * heads
|
402 |
+
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
403 |
+
self.upcast_attention = upcast_attention
|
404 |
+
|
405 |
+
self.scale = dim_head**-0.5
|
406 |
+
self.heads = heads
|
407 |
+
|
408 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
409 |
+
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False)
|
410 |
+
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False)
|
411 |
+
|
412 |
+
self.to_out = nn.ModuleList([])
|
413 |
+
self.to_out.append(nn.Linear(inner_dim, query_dim))
|
414 |
+
# no dropout here
|
415 |
+
|
416 |
+
self.use_memory_efficient_attention_xformers = False
|
417 |
+
self.use_memory_efficient_attention_mem_eff = False
|
418 |
+
self.use_sdpa = False
|
419 |
+
|
420 |
+
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
421 |
+
self.use_memory_efficient_attention_xformers = xformers
|
422 |
+
self.use_memory_efficient_attention_mem_eff = mem_eff
|
423 |
+
|
424 |
+
def set_use_sdpa(self, sdpa):
|
425 |
+
self.use_sdpa = sdpa
|
426 |
+
|
427 |
+
def reshape_heads_to_batch_dim(self, tensor):
|
428 |
+
batch_size, seq_len, dim = tensor.shape
|
429 |
+
head_size = self.heads
|
430 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
431 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
|
432 |
+
return tensor
|
433 |
+
|
434 |
+
def reshape_batch_dim_to_heads(self, tensor):
|
435 |
+
batch_size, seq_len, dim = tensor.shape
|
436 |
+
head_size = self.heads
|
437 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
438 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
439 |
+
return tensor
|
440 |
+
|
441 |
+
def forward(self, hidden_states, context=None, mask=None):
|
442 |
+
if self.use_memory_efficient_attention_xformers:
|
443 |
+
return self.forward_memory_efficient_xformers(hidden_states, context, mask)
|
444 |
+
if self.use_memory_efficient_attention_mem_eff:
|
445 |
+
return self.forward_memory_efficient_mem_eff(hidden_states, context, mask)
|
446 |
+
if self.use_sdpa:
|
447 |
+
return self.forward_sdpa(hidden_states, context, mask)
|
448 |
+
|
449 |
+
query = self.to_q(hidden_states)
|
450 |
+
context = context if context is not None else hidden_states
|
451 |
+
key = self.to_k(context)
|
452 |
+
value = self.to_v(context)
|
453 |
+
|
454 |
+
query = self.reshape_heads_to_batch_dim(query)
|
455 |
+
key = self.reshape_heads_to_batch_dim(key)
|
456 |
+
value = self.reshape_heads_to_batch_dim(value)
|
457 |
+
|
458 |
+
hidden_states = self._attention(query, key, value)
|
459 |
+
|
460 |
+
# linear proj
|
461 |
+
hidden_states = self.to_out[0](hidden_states)
|
462 |
+
# hidden_states = self.to_out[1](hidden_states) # no dropout
|
463 |
+
return hidden_states
|
464 |
+
|
465 |
+
def _attention(self, query, key, value):
|
466 |
+
if self.upcast_attention:
|
467 |
+
query = query.float()
|
468 |
+
key = key.float()
|
469 |
+
|
470 |
+
attention_scores = torch.baddbmm(
|
471 |
+
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
|
472 |
+
query,
|
473 |
+
key.transpose(-1, -2),
|
474 |
+
beta=0,
|
475 |
+
alpha=self.scale,
|
476 |
+
)
|
477 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
478 |
+
|
479 |
+
# cast back to the original dtype
|
480 |
+
attention_probs = attention_probs.to(value.dtype)
|
481 |
+
|
482 |
+
# compute attention output
|
483 |
+
hidden_states = torch.bmm(attention_probs, value)
|
484 |
+
|
485 |
+
# reshape hidden_states
|
486 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
487 |
+
return hidden_states
|
488 |
+
|
489 |
+
# TODO support Hypernetworks
|
490 |
+
def forward_memory_efficient_xformers(self, x, context=None, mask=None):
|
491 |
+
import xformers.ops
|
492 |
+
|
493 |
+
h = self.heads
|
494 |
+
q_in = self.to_q(x)
|
495 |
+
context = context if context is not None else x
|
496 |
+
context = context.to(x.dtype)
|
497 |
+
k_in = self.to_k(context)
|
498 |
+
v_in = self.to_v(context)
|
499 |
+
|
500 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
|
501 |
+
del q_in, k_in, v_in
|
502 |
+
|
503 |
+
q = q.contiguous()
|
504 |
+
k = k.contiguous()
|
505 |
+
v = v.contiguous()
|
506 |
+
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
|
507 |
+
del q, k, v
|
508 |
+
|
509 |
+
out = rearrange(out, "b n h d -> b n (h d)", h=h)
|
510 |
+
|
511 |
+
out = self.to_out[0](out)
|
512 |
+
return out
|
513 |
+
|
514 |
+
def forward_memory_efficient_mem_eff(self, x, context=None, mask=None):
|
515 |
+
flash_func = FlashAttentionFunction
|
516 |
+
|
517 |
+
q_bucket_size = 512
|
518 |
+
k_bucket_size = 1024
|
519 |
+
|
520 |
+
h = self.heads
|
521 |
+
q = self.to_q(x)
|
522 |
+
context = context if context is not None else x
|
523 |
+
context = context.to(x.dtype)
|
524 |
+
k = self.to_k(context)
|
525 |
+
v = self.to_v(context)
|
526 |
+
del context, x
|
527 |
+
|
528 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
|
529 |
+
|
530 |
+
out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
|
531 |
+
|
532 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
533 |
+
|
534 |
+
out = self.to_out[0](out)
|
535 |
+
return out
|
536 |
+
|
537 |
+
def forward_sdpa(self, x, context=None, mask=None):
|
538 |
+
h = self.heads
|
539 |
+
q_in = self.to_q(x)
|
540 |
+
context = context if context is not None else x
|
541 |
+
context = context.to(x.dtype)
|
542 |
+
k_in = self.to_k(context)
|
543 |
+
v_in = self.to_v(context)
|
544 |
+
|
545 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in))
|
546 |
+
del q_in, k_in, v_in
|
547 |
+
|
548 |
+
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
549 |
+
|
550 |
+
out = rearrange(out, "b h n d -> b n (h d)", h=h)
|
551 |
+
|
552 |
+
out = self.to_out[0](out)
|
553 |
+
return out
|
554 |
+
|
555 |
+
|
556 |
+
# feedforward
|
557 |
+
class GEGLU(nn.Module):
|
558 |
+
r"""
|
559 |
+
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
|
560 |
+
|
561 |
+
Parameters:
|
562 |
+
dim_in (`int`): The number of channels in the input.
|
563 |
+
dim_out (`int`): The number of channels in the output.
|
564 |
+
"""
|
565 |
+
|
566 |
+
def __init__(self, dim_in: int, dim_out: int):
|
567 |
+
super().__init__()
|
568 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
569 |
+
|
570 |
+
def gelu(self, gate):
|
571 |
+
if gate.device.type != "mps":
|
572 |
+
return F.gelu(gate)
|
573 |
+
# mps: gelu is not implemented for float16
|
574 |
+
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
|
575 |
+
|
576 |
+
def forward(self, hidden_states):
|
577 |
+
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
|
578 |
+
return hidden_states * self.gelu(gate)
|
579 |
+
|
580 |
+
|
581 |
+
class FeedForward(nn.Module):
|
582 |
+
def __init__(
|
583 |
+
self,
|
584 |
+
dim: int,
|
585 |
+
):
|
586 |
+
super().__init__()
|
587 |
+
inner_dim = int(dim * 4) # mult is always 4
|
588 |
+
|
589 |
+
self.net = nn.ModuleList([])
|
590 |
+
# project in
|
591 |
+
self.net.append(GEGLU(dim, inner_dim))
|
592 |
+
# project dropout
|
593 |
+
self.net.append(nn.Identity()) # nn.Dropout(0)) # dummy for dropout with 0
|
594 |
+
# project out
|
595 |
+
self.net.append(nn.Linear(inner_dim, dim))
|
596 |
+
|
597 |
+
def forward(self, hidden_states):
|
598 |
+
for module in self.net:
|
599 |
+
hidden_states = module(hidden_states)
|
600 |
+
return hidden_states
|
601 |
+
|
602 |
+
|
603 |
+
class BasicTransformerBlock(nn.Module):
|
604 |
+
def __init__(
|
605 |
+
self, dim: int, num_attention_heads: int, attention_head_dim: int, cross_attention_dim: int, upcast_attention: bool = False
|
606 |
+
):
|
607 |
+
super().__init__()
|
608 |
+
|
609 |
+
self.gradient_checkpointing = False
|
610 |
+
|
611 |
+
# 1. Self-Attn
|
612 |
+
self.attn1 = CrossAttention(
|
613 |
+
query_dim=dim,
|
614 |
+
cross_attention_dim=None,
|
615 |
+
heads=num_attention_heads,
|
616 |
+
dim_head=attention_head_dim,
|
617 |
+
upcast_attention=upcast_attention,
|
618 |
+
)
|
619 |
+
self.ff = FeedForward(dim)
|
620 |
+
|
621 |
+
# 2. Cross-Attn
|
622 |
+
self.attn2 = CrossAttention(
|
623 |
+
query_dim=dim,
|
624 |
+
cross_attention_dim=cross_attention_dim,
|
625 |
+
heads=num_attention_heads,
|
626 |
+
dim_head=attention_head_dim,
|
627 |
+
upcast_attention=upcast_attention,
|
628 |
+
)
|
629 |
+
|
630 |
+
self.norm1 = nn.LayerNorm(dim)
|
631 |
+
self.norm2 = nn.LayerNorm(dim)
|
632 |
+
|
633 |
+
# 3. Feed-forward
|
634 |
+
self.norm3 = nn.LayerNorm(dim)
|
635 |
+
|
636 |
+
def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool):
|
637 |
+
self.attn1.set_use_memory_efficient_attention(xformers, mem_eff)
|
638 |
+
self.attn2.set_use_memory_efficient_attention(xformers, mem_eff)
|
639 |
+
|
640 |
+
def set_use_sdpa(self, sdpa: bool):
|
641 |
+
self.attn1.set_use_sdpa(sdpa)
|
642 |
+
self.attn2.set_use_sdpa(sdpa)
|
643 |
+
|
644 |
+
def forward_body(self, hidden_states, context=None, timestep=None):
|
645 |
+
# 1. Self-Attention
|
646 |
+
norm_hidden_states = self.norm1(hidden_states)
|
647 |
+
|
648 |
+
hidden_states = self.attn1(norm_hidden_states) + hidden_states
|
649 |
+
|
650 |
+
# 2. Cross-Attention
|
651 |
+
norm_hidden_states = self.norm2(hidden_states)
|
652 |
+
hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
|
653 |
+
|
654 |
+
# 3. Feed-forward
|
655 |
+
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
656 |
+
|
657 |
+
return hidden_states
|
658 |
+
|
659 |
+
def forward(self, hidden_states, context=None, timestep=None):
|
660 |
+
if self.training and self.gradient_checkpointing:
|
661 |
+
# logger.info("BasicTransformerBlock: checkpointing")
|
662 |
+
|
663 |
+
def create_custom_forward(func):
|
664 |
+
def custom_forward(*inputs):
|
665 |
+
return func(*inputs)
|
666 |
+
|
667 |
+
return custom_forward
|
668 |
+
|
669 |
+
output = torch.utils.checkpoint.checkpoint(
|
670 |
+
create_custom_forward(self.forward_body), hidden_states, context, timestep, use_reentrant=USE_REENTRANT
|
671 |
+
)
|
672 |
+
else:
|
673 |
+
output = self.forward_body(hidden_states, context, timestep)
|
674 |
+
|
675 |
+
return output
|
676 |
+
|
677 |
+
|
678 |
+
class Transformer2DModel(nn.Module):
|
679 |
+
def __init__(
|
680 |
+
self,
|
681 |
+
num_attention_heads: int = 16,
|
682 |
+
attention_head_dim: int = 88,
|
683 |
+
in_channels: Optional[int] = None,
|
684 |
+
cross_attention_dim: Optional[int] = None,
|
685 |
+
use_linear_projection: bool = False,
|
686 |
+
upcast_attention: bool = False,
|
687 |
+
num_transformer_layers: int = 1,
|
688 |
+
):
|
689 |
+
super().__init__()
|
690 |
+
self.in_channels = in_channels
|
691 |
+
self.num_attention_heads = num_attention_heads
|
692 |
+
self.attention_head_dim = attention_head_dim
|
693 |
+
inner_dim = num_attention_heads * attention_head_dim
|
694 |
+
self.use_linear_projection = use_linear_projection
|
695 |
+
|
696 |
+
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
697 |
+
# self.norm = GroupNorm32(32, in_channels, eps=1e-6, affine=True)
|
698 |
+
|
699 |
+
if use_linear_projection:
|
700 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
701 |
+
else:
|
702 |
+
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
703 |
+
|
704 |
+
blocks = []
|
705 |
+
for _ in range(num_transformer_layers):
|
706 |
+
blocks.append(
|
707 |
+
BasicTransformerBlock(
|
708 |
+
inner_dim,
|
709 |
+
num_attention_heads,
|
710 |
+
attention_head_dim,
|
711 |
+
cross_attention_dim=cross_attention_dim,
|
712 |
+
upcast_attention=upcast_attention,
|
713 |
+
)
|
714 |
+
)
|
715 |
+
|
716 |
+
self.transformer_blocks = nn.ModuleList(blocks)
|
717 |
+
|
718 |
+
if use_linear_projection:
|
719 |
+
self.proj_out = nn.Linear(in_channels, inner_dim)
|
720 |
+
else:
|
721 |
+
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
722 |
+
|
723 |
+
self.gradient_checkpointing = False
|
724 |
+
|
725 |
+
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
726 |
+
for transformer in self.transformer_blocks:
|
727 |
+
transformer.set_use_memory_efficient_attention(xformers, mem_eff)
|
728 |
+
|
729 |
+
def set_use_sdpa(self, sdpa):
|
730 |
+
for transformer in self.transformer_blocks:
|
731 |
+
transformer.set_use_sdpa(sdpa)
|
732 |
+
|
733 |
+
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None):
|
734 |
+
# 1. Input
|
735 |
+
batch, _, height, weight = hidden_states.shape
|
736 |
+
residual = hidden_states
|
737 |
+
|
738 |
+
hidden_states = self.norm(hidden_states)
|
739 |
+
if not self.use_linear_projection:
|
740 |
+
hidden_states = self.proj_in(hidden_states)
|
741 |
+
inner_dim = hidden_states.shape[1]
|
742 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
743 |
+
else:
|
744 |
+
inner_dim = hidden_states.shape[1]
|
745 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
746 |
+
hidden_states = self.proj_in(hidden_states)
|
747 |
+
|
748 |
+
# 2. Blocks
|
749 |
+
for block in self.transformer_blocks:
|
750 |
+
hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
|
751 |
+
|
752 |
+
# 3. Output
|
753 |
+
if not self.use_linear_projection:
|
754 |
+
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
755 |
+
hidden_states = self.proj_out(hidden_states)
|
756 |
+
else:
|
757 |
+
hidden_states = self.proj_out(hidden_states)
|
758 |
+
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
759 |
+
|
760 |
+
output = hidden_states + residual
|
761 |
+
|
762 |
+
return output
|
763 |
+
|
764 |
+
|
765 |
+
class Upsample2D(nn.Module):
|
766 |
+
def __init__(self, channels, out_channels):
|
767 |
+
super().__init__()
|
768 |
+
self.channels = channels
|
769 |
+
self.out_channels = out_channels
|
770 |
+
self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
|
771 |
+
|
772 |
+
self.gradient_checkpointing = False
|
773 |
+
|
774 |
+
def forward_body(self, hidden_states, output_size=None):
|
775 |
+
assert hidden_states.shape[1] == self.channels
|
776 |
+
|
777 |
+
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
778 |
+
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
|
779 |
+
# https://github.com/pytorch/pytorch/issues/86679
|
780 |
+
dtype = hidden_states.dtype
|
781 |
+
if dtype == torch.bfloat16:
|
782 |
+
hidden_states = hidden_states.to(torch.float32)
|
783 |
+
|
784 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
785 |
+
if hidden_states.shape[0] >= 64:
|
786 |
+
hidden_states = hidden_states.contiguous()
|
787 |
+
|
788 |
+
# if `output_size` is passed we force the interpolation output size and do not make use of `scale_factor=2`
|
789 |
+
if output_size is None:
|
790 |
+
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
791 |
+
else:
|
792 |
+
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
|
793 |
+
|
794 |
+
# If the input is bfloat16, we cast back to bfloat16
|
795 |
+
if dtype == torch.bfloat16:
|
796 |
+
hidden_states = hidden_states.to(dtype)
|
797 |
+
|
798 |
+
hidden_states = self.conv(hidden_states)
|
799 |
+
|
800 |
+
return hidden_states
|
801 |
+
|
802 |
+
def forward(self, hidden_states, output_size=None):
|
803 |
+
if self.training and self.gradient_checkpointing:
|
804 |
+
# logger.info("Upsample2D: gradient_checkpointing")
|
805 |
+
|
806 |
+
def create_custom_forward(func):
|
807 |
+
def custom_forward(*inputs):
|
808 |
+
return func(*inputs)
|
809 |
+
|
810 |
+
return custom_forward
|
811 |
+
|
812 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
813 |
+
create_custom_forward(self.forward_body), hidden_states, output_size, use_reentrant=USE_REENTRANT
|
814 |
+
)
|
815 |
+
else:
|
816 |
+
hidden_states = self.forward_body(hidden_states, output_size)
|
817 |
+
|
818 |
+
return hidden_states
|
819 |
+
|
820 |
+
|
821 |
+
class SdxlUNet2DConditionModel(nn.Module):
|
822 |
+
_supports_gradient_checkpointing = True
|
823 |
+
|
824 |
+
def __init__(
|
825 |
+
self,
|
826 |
+
**kwargs,
|
827 |
+
):
|
828 |
+
super().__init__()
|
829 |
+
|
830 |
+
self.in_channels = IN_CHANNELS
|
831 |
+
self.out_channels = OUT_CHANNELS
|
832 |
+
self.model_channels = MODEL_CHANNELS
|
833 |
+
self.time_embed_dim = TIME_EMBED_DIM
|
834 |
+
self.adm_in_channels = ADM_IN_CHANNELS
|
835 |
+
|
836 |
+
self.gradient_checkpointing = False
|
837 |
+
# self.sample_size = sample_size
|
838 |
+
|
839 |
+
# time embedding
|
840 |
+
self.time_embed = nn.Sequential(
|
841 |
+
nn.Linear(self.model_channels, self.time_embed_dim),
|
842 |
+
nn.SiLU(),
|
843 |
+
nn.Linear(self.time_embed_dim, self.time_embed_dim),
|
844 |
+
)
|
845 |
+
|
846 |
+
# label embedding
|
847 |
+
self.label_emb = nn.Sequential(
|
848 |
+
nn.Sequential(
|
849 |
+
nn.Linear(self.adm_in_channels, self.time_embed_dim),
|
850 |
+
nn.SiLU(),
|
851 |
+
nn.Linear(self.time_embed_dim, self.time_embed_dim),
|
852 |
+
)
|
853 |
+
)
|
854 |
+
|
855 |
+
# input
|
856 |
+
self.input_blocks = nn.ModuleList(
|
857 |
+
[
|
858 |
+
nn.Sequential(
|
859 |
+
nn.Conv2d(self.in_channels, self.model_channels, kernel_size=3, padding=(1, 1)),
|
860 |
+
)
|
861 |
+
]
|
862 |
+
)
|
863 |
+
|
864 |
+
# level 0
|
865 |
+
for i in range(2):
|
866 |
+
layers = [
|
867 |
+
ResnetBlock2D(
|
868 |
+
in_channels=1 * self.model_channels,
|
869 |
+
out_channels=1 * self.model_channels,
|
870 |
+
),
|
871 |
+
]
|
872 |
+
self.input_blocks.append(nn.ModuleList(layers))
|
873 |
+
|
874 |
+
self.input_blocks.append(
|
875 |
+
nn.Sequential(
|
876 |
+
Downsample2D(
|
877 |
+
channels=1 * self.model_channels,
|
878 |
+
out_channels=1 * self.model_channels,
|
879 |
+
),
|
880 |
+
)
|
881 |
+
)
|
882 |
+
|
883 |
+
# level 1
|
884 |
+
for i in range(2):
|
885 |
+
layers = [
|
886 |
+
ResnetBlock2D(
|
887 |
+
in_channels=(1 if i == 0 else 2) * self.model_channels,
|
888 |
+
out_channels=2 * self.model_channels,
|
889 |
+
),
|
890 |
+
Transformer2DModel(
|
891 |
+
num_attention_heads=2 * self.model_channels // 64,
|
892 |
+
attention_head_dim=64,
|
893 |
+
in_channels=2 * self.model_channels,
|
894 |
+
num_transformer_layers=2,
|
895 |
+
use_linear_projection=True,
|
896 |
+
cross_attention_dim=2048,
|
897 |
+
),
|
898 |
+
]
|
899 |
+
self.input_blocks.append(nn.ModuleList(layers))
|
900 |
+
|
901 |
+
self.input_blocks.append(
|
902 |
+
nn.Sequential(
|
903 |
+
Downsample2D(
|
904 |
+
channels=2 * self.model_channels,
|
905 |
+
out_channels=2 * self.model_channels,
|
906 |
+
),
|
907 |
+
)
|
908 |
+
)
|
909 |
+
|
910 |
+
# level 2
|
911 |
+
for i in range(2):
|
912 |
+
layers = [
|
913 |
+
ResnetBlock2D(
|
914 |
+
in_channels=(2 if i == 0 else 4) * self.model_channels,
|
915 |
+
out_channels=4 * self.model_channels,
|
916 |
+
),
|
917 |
+
Transformer2DModel(
|
918 |
+
num_attention_heads=4 * self.model_channels // 64,
|
919 |
+
attention_head_dim=64,
|
920 |
+
in_channels=4 * self.model_channels,
|
921 |
+
num_transformer_layers=10,
|
922 |
+
use_linear_projection=True,
|
923 |
+
cross_attention_dim=2048,
|
924 |
+
),
|
925 |
+
]
|
926 |
+
self.input_blocks.append(nn.ModuleList(layers))
|
927 |
+
|
928 |
+
# mid
|
929 |
+
self.middle_block = nn.ModuleList(
|
930 |
+
[
|
931 |
+
ResnetBlock2D(
|
932 |
+
in_channels=4 * self.model_channels,
|
933 |
+
out_channels=4 * self.model_channels,
|
934 |
+
),
|
935 |
+
Transformer2DModel(
|
936 |
+
num_attention_heads=4 * self.model_channels // 64,
|
937 |
+
attention_head_dim=64,
|
938 |
+
in_channels=4 * self.model_channels,
|
939 |
+
num_transformer_layers=10,
|
940 |
+
use_linear_projection=True,
|
941 |
+
cross_attention_dim=2048,
|
942 |
+
),
|
943 |
+
ResnetBlock2D(
|
944 |
+
in_channels=4 * self.model_channels,
|
945 |
+
out_channels=4 * self.model_channels,
|
946 |
+
),
|
947 |
+
]
|
948 |
+
)
|
949 |
+
|
950 |
+
# output
|
951 |
+
self.output_blocks = nn.ModuleList([])
|
952 |
+
|
953 |
+
# level 2
|
954 |
+
for i in range(3):
|
955 |
+
layers = [
|
956 |
+
ResnetBlock2D(
|
957 |
+
in_channels=4 * self.model_channels + (4 if i <= 1 else 2) * self.model_channels,
|
958 |
+
out_channels=4 * self.model_channels,
|
959 |
+
),
|
960 |
+
Transformer2DModel(
|
961 |
+
num_attention_heads=4 * self.model_channels // 64,
|
962 |
+
attention_head_dim=64,
|
963 |
+
in_channels=4 * self.model_channels,
|
964 |
+
num_transformer_layers=10,
|
965 |
+
use_linear_projection=True,
|
966 |
+
cross_attention_dim=2048,
|
967 |
+
),
|
968 |
+
]
|
969 |
+
if i == 2:
|
970 |
+
layers.append(
|
971 |
+
Upsample2D(
|
972 |
+
channels=4 * self.model_channels,
|
973 |
+
out_channels=4 * self.model_channels,
|
974 |
+
)
|
975 |
+
)
|
976 |
+
|
977 |
+
self.output_blocks.append(nn.ModuleList(layers))
|
978 |
+
|
979 |
+
# level 1
|
980 |
+
for i in range(3):
|
981 |
+
layers = [
|
982 |
+
ResnetBlock2D(
|
983 |
+
in_channels=2 * self.model_channels + (4 if i == 0 else (2 if i == 1 else 1)) * self.model_channels,
|
984 |
+
out_channels=2 * self.model_channels,
|
985 |
+
),
|
986 |
+
Transformer2DModel(
|
987 |
+
num_attention_heads=2 * self.model_channels // 64,
|
988 |
+
attention_head_dim=64,
|
989 |
+
in_channels=2 * self.model_channels,
|
990 |
+
num_transformer_layers=2,
|
991 |
+
use_linear_projection=True,
|
992 |
+
cross_attention_dim=2048,
|
993 |
+
),
|
994 |
+
]
|
995 |
+
if i == 2:
|
996 |
+
layers.append(
|
997 |
+
Upsample2D(
|
998 |
+
channels=2 * self.model_channels,
|
999 |
+
out_channels=2 * self.model_channels,
|
1000 |
+
)
|
1001 |
+
)
|
1002 |
+
|
1003 |
+
self.output_blocks.append(nn.ModuleList(layers))
|
1004 |
+
|
1005 |
+
# level 0
|
1006 |
+
for i in range(3):
|
1007 |
+
layers = [
|
1008 |
+
ResnetBlock2D(
|
1009 |
+
in_channels=1 * self.model_channels + (2 if i == 0 else 1) * self.model_channels,
|
1010 |
+
out_channels=1 * self.model_channels,
|
1011 |
+
),
|
1012 |
+
]
|
1013 |
+
|
1014 |
+
self.output_blocks.append(nn.ModuleList(layers))
|
1015 |
+
|
1016 |
+
# output
|
1017 |
+
self.out = nn.ModuleList(
|
1018 |
+
[GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)]
|
1019 |
+
)
|
1020 |
+
|
1021 |
+
# region diffusers compatibility
|
1022 |
+
def prepare_config(self):
|
1023 |
+
self.config = SimpleNamespace()
|
1024 |
+
|
1025 |
+
@property
|
1026 |
+
def dtype(self) -> torch.dtype:
|
1027 |
+
# `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
1028 |
+
return get_parameter_dtype(self)
|
1029 |
+
|
1030 |
+
@property
|
1031 |
+
def device(self) -> torch.device:
|
1032 |
+
# `torch.device`: The device on which the module is (assuming that all the module parameters are on the same device).
|
1033 |
+
return get_parameter_device(self)
|
1034 |
+
|
1035 |
+
def set_attention_slice(self, slice_size):
|
1036 |
+
raise NotImplementedError("Attention slicing is not supported for this model.")
|
1037 |
+
|
1038 |
+
def is_gradient_checkpointing(self) -> bool:
|
1039 |
+
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
|
1040 |
+
|
1041 |
+
def enable_gradient_checkpointing(self):
|
1042 |
+
self.gradient_checkpointing = True
|
1043 |
+
self.set_gradient_checkpointing(value=True)
|
1044 |
+
|
1045 |
+
def disable_gradient_checkpointing(self):
|
1046 |
+
self.gradient_checkpointing = False
|
1047 |
+
self.set_gradient_checkpointing(value=False)
|
1048 |
+
|
1049 |
+
def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None:
|
1050 |
+
blocks = self.input_blocks + [self.middle_block] + self.output_blocks
|
1051 |
+
for block in blocks:
|
1052 |
+
for module in block:
|
1053 |
+
if hasattr(module, "set_use_memory_efficient_attention"):
|
1054 |
+
# logger.info(module.__class__.__name__)
|
1055 |
+
module.set_use_memory_efficient_attention(xformers, mem_eff)
|
1056 |
+
|
1057 |
+
def set_use_sdpa(self, sdpa: bool) -> None:
|
1058 |
+
blocks = self.input_blocks + [self.middle_block] + self.output_blocks
|
1059 |
+
for block in blocks:
|
1060 |
+
for module in block:
|
1061 |
+
if hasattr(module, "set_use_sdpa"):
|
1062 |
+
module.set_use_sdpa(sdpa)
|
1063 |
+
|
1064 |
+
def set_gradient_checkpointing(self, value=False):
|
1065 |
+
blocks = self.input_blocks + [self.middle_block] + self.output_blocks
|
1066 |
+
for block in blocks:
|
1067 |
+
for module in block.modules():
|
1068 |
+
if hasattr(module, "gradient_checkpointing"):
|
1069 |
+
# logger.info(f{module.__class__.__name__} {module.gradient_checkpointing} -> {value}")
|
1070 |
+
module.gradient_checkpointing = value
|
1071 |
+
|
1072 |
+
# endregion
|
1073 |
+
|
1074 |
+
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
|
1075 |
+
# broadcast timesteps to batch dimension
|
1076 |
+
timesteps = timesteps.expand(x.shape[0])
|
1077 |
+
|
1078 |
+
hs = []
|
1079 |
+
t_emb = get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) # , repeat_only=False)
|
1080 |
+
t_emb = t_emb.to(x.dtype)
|
1081 |
+
emb = self.time_embed(t_emb)
|
1082 |
+
|
1083 |
+
assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
|
1084 |
+
assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
|
1085 |
+
# assert x.dtype == self.dtype
|
1086 |
+
emb = emb + self.label_emb(y)
|
1087 |
+
|
1088 |
+
def call_module(module, h, emb, context):
|
1089 |
+
x = h
|
1090 |
+
for layer in module:
|
1091 |
+
# logger.info(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
|
1092 |
+
if isinstance(layer, ResnetBlock2D):
|
1093 |
+
x = layer(x, emb)
|
1094 |
+
elif isinstance(layer, Transformer2DModel):
|
1095 |
+
x = layer(x, context)
|
1096 |
+
else:
|
1097 |
+
x = layer(x)
|
1098 |
+
return x
|
1099 |
+
|
1100 |
+
# h = x.type(self.dtype)
|
1101 |
+
h = x
|
1102 |
+
|
1103 |
+
for module in self.input_blocks:
|
1104 |
+
h = call_module(module, h, emb, context)
|
1105 |
+
hs.append(h)
|
1106 |
+
|
1107 |
+
h = call_module(self.middle_block, h, emb, context)
|
1108 |
+
|
1109 |
+
for module in self.output_blocks:
|
1110 |
+
h = torch.cat([h, hs.pop()], dim=1)
|
1111 |
+
h = call_module(module, h, emb, context)
|
1112 |
+
|
1113 |
+
h = h.type(x.dtype)
|
1114 |
+
h = call_module(self.out, h, emb, context)
|
1115 |
+
|
1116 |
+
return h
|
1117 |
+
|
1118 |
+
|
1119 |
+
class InferSdxlUNet2DConditionModel:
|
1120 |
+
def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs):
|
1121 |
+
self.delegate = original_unet
|
1122 |
+
|
1123 |
+
# override original model's forward method: because forward is not called by `__call__`
|
1124 |
+
# overriding `__call__` is not enough, because nn.Module.forward has a special handling
|
1125 |
+
self.delegate.forward = self.forward
|
1126 |
+
|
1127 |
+
# Deep Shrink
|
1128 |
+
self.ds_depth_1 = None
|
1129 |
+
self.ds_depth_2 = None
|
1130 |
+
self.ds_timesteps_1 = None
|
1131 |
+
self.ds_timesteps_2 = None
|
1132 |
+
self.ds_ratio = None
|
1133 |
+
|
1134 |
+
# call original model's methods
|
1135 |
+
def __getattr__(self, name):
|
1136 |
+
return getattr(self.delegate, name)
|
1137 |
+
|
1138 |
+
def __call__(self, *args, **kwargs):
|
1139 |
+
return self.delegate(*args, **kwargs)
|
1140 |
+
|
1141 |
+
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
|
1142 |
+
if ds_depth_1 is None:
|
1143 |
+
logger.info("Deep Shrink is disabled.")
|
1144 |
+
self.ds_depth_1 = None
|
1145 |
+
self.ds_timesteps_1 = None
|
1146 |
+
self.ds_depth_2 = None
|
1147 |
+
self.ds_timesteps_2 = None
|
1148 |
+
self.ds_ratio = None
|
1149 |
+
else:
|
1150 |
+
logger.info(
|
1151 |
+
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
|
1152 |
+
)
|
1153 |
+
self.ds_depth_1 = ds_depth_1
|
1154 |
+
self.ds_timesteps_1 = ds_timesteps_1
|
1155 |
+
self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
|
1156 |
+
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
|
1157 |
+
self.ds_ratio = ds_ratio
|
1158 |
+
|
1159 |
+
def forward(self, x, timesteps=None, context=None, y=None, input_resi_add=None, mid_add=None, **kwargs):
|
1160 |
+
r"""
|
1161 |
+
current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink and ControlNet.
|
1162 |
+
"""
|
1163 |
+
_self = self.delegate
|
1164 |
+
|
1165 |
+
# broadcast timesteps to batch dimension
|
1166 |
+
timesteps = timesteps.expand(x.shape[0])
|
1167 |
+
|
1168 |
+
hs = []
|
1169 |
+
t_emb = get_timestep_embedding(timesteps, _self.model_channels, downscale_freq_shift=0) # , repeat_only=False)
|
1170 |
+
t_emb = t_emb.to(x.dtype)
|
1171 |
+
emb = _self.time_embed(t_emb)
|
1172 |
+
|
1173 |
+
assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
|
1174 |
+
assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
|
1175 |
+
# assert x.dtype == _self.dtype
|
1176 |
+
emb = emb + _self.label_emb(y)
|
1177 |
+
|
1178 |
+
def call_module(module, h, emb, context):
|
1179 |
+
x = h
|
1180 |
+
for layer in module:
|
1181 |
+
# print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
|
1182 |
+
if isinstance(layer, ResnetBlock2D):
|
1183 |
+
x = layer(x, emb)
|
1184 |
+
elif isinstance(layer, Transformer2DModel):
|
1185 |
+
x = layer(x, context)
|
1186 |
+
else:
|
1187 |
+
x = layer(x)
|
1188 |
+
return x
|
1189 |
+
|
1190 |
+
# h = x.type(self.dtype)
|
1191 |
+
h = x
|
1192 |
+
|
1193 |
+
for depth, module in enumerate(_self.input_blocks):
|
1194 |
+
# Deep Shrink
|
1195 |
+
if self.ds_depth_1 is not None:
|
1196 |
+
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
|
1197 |
+
self.ds_depth_2 is not None
|
1198 |
+
and depth == self.ds_depth_2
|
1199 |
+
and timesteps[0] < self.ds_timesteps_1
|
1200 |
+
and timesteps[0] >= self.ds_timesteps_2
|
1201 |
+
):
|
1202 |
+
# print("downsample", h.shape, self.ds_ratio)
|
1203 |
+
org_dtype = h.dtype
|
1204 |
+
if org_dtype == torch.bfloat16:
|
1205 |
+
h = h.to(torch.float32)
|
1206 |
+
h = F.interpolate(h, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
|
1207 |
+
|
1208 |
+
h = call_module(module, h, emb, context)
|
1209 |
+
hs.append(h)
|
1210 |
+
|
1211 |
+
h = call_module(_self.middle_block, h, emb, context)
|
1212 |
+
if mid_add is not None:
|
1213 |
+
h = h + mid_add
|
1214 |
+
|
1215 |
+
for module in _self.output_blocks:
|
1216 |
+
# Deep Shrink
|
1217 |
+
if self.ds_depth_1 is not None:
|
1218 |
+
if hs[-1].shape[-2:] != h.shape[-2:]:
|
1219 |
+
# print("upsample", h.shape, hs[-1].shape)
|
1220 |
+
h = resize_like(h, hs[-1])
|
1221 |
+
|
1222 |
+
resi = hs.pop()
|
1223 |
+
if input_resi_add is not None:
|
1224 |
+
resi = resi + input_resi_add.pop()
|
1225 |
+
|
1226 |
+
h = torch.cat([h, resi], dim=1)
|
1227 |
+
h = call_module(module, h, emb, context)
|
1228 |
+
|
1229 |
+
# Deep Shrink: in case of depth 0
|
1230 |
+
if self.ds_depth_1 == 0 and h.shape[-2:] != x.shape[-2:]:
|
1231 |
+
# print("upsample", h.shape, x.shape)
|
1232 |
+
h = resize_like(h, x)
|
1233 |
+
|
1234 |
+
h = h.type(x.dtype)
|
1235 |
+
h = call_module(_self.out, h, emb, context)
|
1236 |
+
|
1237 |
+
return h
|
1238 |
+
|
1239 |
+
|
1240 |
+
if __name__ == "__main__":
|
1241 |
+
import time
|
1242 |
+
|
1243 |
+
logger.info("create unet")
|
1244 |
+
unet = SdxlUNet2DConditionModel()
|
1245 |
+
|
1246 |
+
unet.to("cuda")
|
1247 |
+
unet.set_use_memory_efficient_attention(True, False)
|
1248 |
+
unet.set_gradient_checkpointing(True)
|
1249 |
+
unet.train()
|
1250 |
+
|
1251 |
+
# 使用メモリ量確認用の疑似学習ループ
|
1252 |
+
logger.info("preparing optimizer")
|
1253 |
+
|
1254 |
+
# optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working
|
1255 |
+
|
1256 |
+
# import bitsandbytes
|
1257 |
+
# optimizer = bitsandbytes.adam.Adam8bit(unet.parameters(), lr=1e-3) # not working
|
1258 |
+
# optimizer = bitsandbytes.optim.RMSprop8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2
|
1259 |
+
# optimizer=bitsandbytes.optim.Adagrad8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2
|
1260 |
+
|
1261 |
+
import transformers
|
1262 |
+
|
1263 |
+
optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True) # working at 22.2GB with torch2
|
1264 |
+
|
1265 |
+
scaler = torch.cuda.amp.GradScaler(enabled=True)
|
1266 |
+
|
1267 |
+
logger.info("start training")
|
1268 |
+
steps = 10
|
1269 |
+
batch_size = 1
|
1270 |
+
|
1271 |
+
for step in range(steps):
|
1272 |
+
logger.info(f"step {step}")
|
1273 |
+
if step == 1:
|
1274 |
+
time_start = time.perf_counter()
|
1275 |
+
|
1276 |
+
x = torch.randn(batch_size, 4, 128, 128).cuda() # 1024x1024
|
1277 |
+
t = torch.randint(low=0, high=10, size=(batch_size,), device="cuda")
|
1278 |
+
ctx = torch.randn(batch_size, 77, 2048).cuda()
|
1279 |
+
y = torch.randn(batch_size, ADM_IN_CHANNELS).cuda()
|
1280 |
+
|
1281 |
+
with torch.cuda.amp.autocast(enabled=True):
|
1282 |
+
output = unet(x, t, ctx, y)
|
1283 |
+
target = torch.randn_like(output)
|
1284 |
+
loss = torch.nn.functional.mse_loss(output, target)
|
1285 |
+
|
1286 |
+
scaler.scale(loss).backward()
|
1287 |
+
scaler.step(optimizer)
|
1288 |
+
scaler.update()
|
1289 |
+
optimizer.zero_grad(set_to_none=True)
|
1290 |
+
|
1291 |
+
time_end = time.perf_counter()
|
1292 |
+
logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")
|
library/sdxl_train_util.py
ADDED
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import math
|
3 |
+
import os
|
4 |
+
from typing import Optional
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from library.device_utils import init_ipex, clean_memory_on_device
|
8 |
+
|
9 |
+
init_ipex()
|
10 |
+
|
11 |
+
from accelerate import init_empty_weights
|
12 |
+
from tqdm import tqdm
|
13 |
+
from transformers import CLIPTokenizer
|
14 |
+
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
|
15 |
+
from .utils import setup_logging
|
16 |
+
|
17 |
+
setup_logging()
|
18 |
+
import logging
|
19 |
+
|
20 |
+
logger = logging.getLogger(__name__)
|
21 |
+
|
22 |
+
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
|
23 |
+
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
24 |
+
|
25 |
+
# DEFAULT_NOISE_OFFSET = 0.0357
|
26 |
+
|
27 |
+
|
28 |
+
def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
29 |
+
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16
|
30 |
+
for pi in range(accelerator.state.num_processes):
|
31 |
+
if pi == accelerator.state.local_process_index:
|
32 |
+
logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
33 |
+
|
34 |
+
(
|
35 |
+
load_stable_diffusion_format,
|
36 |
+
text_encoder1,
|
37 |
+
text_encoder2,
|
38 |
+
vae,
|
39 |
+
unet,
|
40 |
+
logit_scale,
|
41 |
+
ckpt_info,
|
42 |
+
) = _load_target_model(
|
43 |
+
args.pretrained_model_name_or_path,
|
44 |
+
args.vae,
|
45 |
+
model_version,
|
46 |
+
weight_dtype,
|
47 |
+
accelerator.device if args.lowram else "cpu",
|
48 |
+
model_dtype,
|
49 |
+
args.disable_mmap_load_safetensors,
|
50 |
+
)
|
51 |
+
|
52 |
+
# work on low-ram device
|
53 |
+
if args.lowram:
|
54 |
+
text_encoder1.to(accelerator.device)
|
55 |
+
text_encoder2.to(accelerator.device)
|
56 |
+
unet.to(accelerator.device)
|
57 |
+
vae.to(accelerator.device)
|
58 |
+
|
59 |
+
clean_memory_on_device(accelerator.device)
|
60 |
+
accelerator.wait_for_everyone()
|
61 |
+
|
62 |
+
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
|
63 |
+
|
64 |
+
|
65 |
+
def _load_target_model(
|
66 |
+
name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None, disable_mmap=False
|
67 |
+
):
|
68 |
+
# model_dtype only work with full fp16/bf16
|
69 |
+
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
70 |
+
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
71 |
+
|
72 |
+
if load_stable_diffusion_format:
|
73 |
+
logger.info(f"load StableDiffusion checkpoint: {name_or_path}")
|
74 |
+
(
|
75 |
+
text_encoder1,
|
76 |
+
text_encoder2,
|
77 |
+
vae,
|
78 |
+
unet,
|
79 |
+
logit_scale,
|
80 |
+
ckpt_info,
|
81 |
+
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype, disable_mmap)
|
82 |
+
else:
|
83 |
+
# Diffusers model is loaded to CPU
|
84 |
+
from diffusers import StableDiffusionXLPipeline
|
85 |
+
|
86 |
+
variant = "fp16" if weight_dtype == torch.float16 else None
|
87 |
+
logger.info(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
|
88 |
+
try:
|
89 |
+
try:
|
90 |
+
pipe = StableDiffusionXLPipeline.from_pretrained(
|
91 |
+
name_or_path, torch_dtype=model_dtype, variant=variant, tokenizer=None
|
92 |
+
)
|
93 |
+
except EnvironmentError as ex:
|
94 |
+
if variant is not None:
|
95 |
+
logger.info("try to load fp32 model")
|
96 |
+
pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=None, tokenizer=None)
|
97 |
+
else:
|
98 |
+
raise ex
|
99 |
+
except EnvironmentError as ex:
|
100 |
+
logger.error(
|
101 |
+
f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}"
|
102 |
+
)
|
103 |
+
raise ex
|
104 |
+
|
105 |
+
text_encoder1 = pipe.text_encoder
|
106 |
+
text_encoder2 = pipe.text_encoder_2
|
107 |
+
|
108 |
+
# convert to fp32 for cache text_encoders outputs
|
109 |
+
if text_encoder1.dtype != torch.float32:
|
110 |
+
text_encoder1 = text_encoder1.to(dtype=torch.float32)
|
111 |
+
if text_encoder2.dtype != torch.float32:
|
112 |
+
text_encoder2 = text_encoder2.to(dtype=torch.float32)
|
113 |
+
|
114 |
+
vae = pipe.vae
|
115 |
+
unet = pipe.unet
|
116 |
+
del pipe
|
117 |
+
|
118 |
+
# Diffusers U-Net to original U-Net
|
119 |
+
state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
|
120 |
+
with init_empty_weights():
|
121 |
+
unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet
|
122 |
+
sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype)
|
123 |
+
logger.info("U-Net converted to original U-Net")
|
124 |
+
|
125 |
+
logit_scale = None
|
126 |
+
ckpt_info = None
|
127 |
+
|
128 |
+
# VAEを読み込む
|
129 |
+
if vae_path is not None:
|
130 |
+
vae = model_util.load_vae(vae_path, weight_dtype)
|
131 |
+
logger.info("additional VAE loaded")
|
132 |
+
|
133 |
+
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
|
134 |
+
|
135 |
+
|
136 |
+
def load_tokenizers(args: argparse.Namespace):
|
137 |
+
logger.info("prepare tokenizers")
|
138 |
+
|
139 |
+
original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH]
|
140 |
+
tokeniers = []
|
141 |
+
for i, original_path in enumerate(original_paths):
|
142 |
+
tokenizer: CLIPTokenizer = None
|
143 |
+
if args.tokenizer_cache_dir:
|
144 |
+
local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
|
145 |
+
if os.path.exists(local_tokenizer_path):
|
146 |
+
logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
|
147 |
+
tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)
|
148 |
+
|
149 |
+
if tokenizer is None:
|
150 |
+
tokenizer = CLIPTokenizer.from_pretrained(original_path)
|
151 |
+
|
152 |
+
if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
|
153 |
+
logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
|
154 |
+
tokenizer.save_pretrained(local_tokenizer_path)
|
155 |
+
|
156 |
+
if i == 1:
|
157 |
+
tokenizer.pad_token_id = 0 # fix pad token id to make same as open clip tokenizer
|
158 |
+
|
159 |
+
tokeniers.append(tokenizer)
|
160 |
+
|
161 |
+
if hasattr(args, "max_token_length") and args.max_token_length is not None:
|
162 |
+
logger.info(f"update token length: {args.max_token_length}")
|
163 |
+
|
164 |
+
return tokeniers
|
165 |
+
|
166 |
+
|
167 |
+
def match_mixed_precision(args, weight_dtype):
|
168 |
+
if args.full_fp16:
|
169 |
+
assert (
|
170 |
+
weight_dtype == torch.float16
|
171 |
+
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
172 |
+
return weight_dtype
|
173 |
+
elif args.full_bf16:
|
174 |
+
assert (
|
175 |
+
weight_dtype == torch.bfloat16
|
176 |
+
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
|
177 |
+
return weight_dtype
|
178 |
+
else:
|
179 |
+
return None
|
180 |
+
|
181 |
+
|
182 |
+
def timestep_embedding(timesteps, dim, max_period=10000):
|
183 |
+
"""
|
184 |
+
Create sinusoidal timestep embeddings.
|
185 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
186 |
+
These may be fractional.
|
187 |
+
:param dim: the dimension of the output.
|
188 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
189 |
+
:return: an [N x dim] Tensor of positional embeddings.
|
190 |
+
"""
|
191 |
+
half = dim // 2
|
192 |
+
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
|
193 |
+
device=timesteps.device
|
194 |
+
)
|
195 |
+
args = timesteps[:, None].float() * freqs[None]
|
196 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
197 |
+
if dim % 2:
|
198 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
199 |
+
return embedding
|
200 |
+
|
201 |
+
|
202 |
+
def get_timestep_embedding(x, outdim):
|
203 |
+
assert len(x.shape) == 2
|
204 |
+
b, dims = x.shape[0], x.shape[1]
|
205 |
+
x = torch.flatten(x)
|
206 |
+
emb = timestep_embedding(x, outdim)
|
207 |
+
emb = torch.reshape(emb, (b, dims * outdim))
|
208 |
+
return emb
|
209 |
+
|
210 |
+
|
211 |
+
def get_size_embeddings(orig_size, crop_size, target_size, device):
|
212 |
+
emb1 = get_timestep_embedding(orig_size, 256)
|
213 |
+
emb2 = get_timestep_embedding(crop_size, 256)
|
214 |
+
emb3 = get_timestep_embedding(target_size, 256)
|
215 |
+
vector = torch.cat([emb1, emb2, emb3], dim=1).to(device)
|
216 |
+
return vector
|
217 |
+
|
218 |
+
|
219 |
+
def save_sd_model_on_train_end(
|
220 |
+
args: argparse.Namespace,
|
221 |
+
src_path: str,
|
222 |
+
save_stable_diffusion_format: bool,
|
223 |
+
use_safetensors: bool,
|
224 |
+
save_dtype: torch.dtype,
|
225 |
+
epoch: int,
|
226 |
+
global_step: int,
|
227 |
+
text_encoder1,
|
228 |
+
text_encoder2,
|
229 |
+
unet,
|
230 |
+
vae,
|
231 |
+
logit_scale,
|
232 |
+
ckpt_info,
|
233 |
+
):
|
234 |
+
def sd_saver(ckpt_file, epoch_no, global_step):
|
235 |
+
sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True)
|
236 |
+
sdxl_model_util.save_stable_diffusion_checkpoint(
|
237 |
+
ckpt_file,
|
238 |
+
text_encoder1,
|
239 |
+
text_encoder2,
|
240 |
+
unet,
|
241 |
+
epoch_no,
|
242 |
+
global_step,
|
243 |
+
ckpt_info,
|
244 |
+
vae,
|
245 |
+
logit_scale,
|
246 |
+
sai_metadata,
|
247 |
+
save_dtype,
|
248 |
+
)
|
249 |
+
|
250 |
+
def diffusers_saver(out_dir):
|
251 |
+
sdxl_model_util.save_diffusers_checkpoint(
|
252 |
+
out_dir,
|
253 |
+
text_encoder1,
|
254 |
+
text_encoder2,
|
255 |
+
unet,
|
256 |
+
src_path,
|
257 |
+
vae,
|
258 |
+
use_safetensors=use_safetensors,
|
259 |
+
save_dtype=save_dtype,
|
260 |
+
)
|
261 |
+
|
262 |
+
train_util.save_sd_model_on_train_end_common(
|
263 |
+
args, save_stable_diffusion_format, use_safetensors, epoch, global_step, sd_saver, diffusers_saver
|
264 |
+
)
|
265 |
+
|
266 |
+
|
267 |
+
# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している
|
268 |
+
# on_epoch_end: Trueならepoch終了時、Falseならstep経過時
|
269 |
+
def save_sd_model_on_epoch_end_or_stepwise(
|
270 |
+
args: argparse.Namespace,
|
271 |
+
on_epoch_end: bool,
|
272 |
+
accelerator,
|
273 |
+
src_path,
|
274 |
+
save_stable_diffusion_format: bool,
|
275 |
+
use_safetensors: bool,
|
276 |
+
save_dtype: torch.dtype,
|
277 |
+
epoch: int,
|
278 |
+
num_train_epochs: int,
|
279 |
+
global_step: int,
|
280 |
+
text_encoder1,
|
281 |
+
text_encoder2,
|
282 |
+
unet,
|
283 |
+
vae,
|
284 |
+
logit_scale,
|
285 |
+
ckpt_info,
|
286 |
+
):
|
287 |
+
def sd_saver(ckpt_file, epoch_no, global_step):
|
288 |
+
sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True)
|
289 |
+
sdxl_model_util.save_stable_diffusion_checkpoint(
|
290 |
+
ckpt_file,
|
291 |
+
text_encoder1,
|
292 |
+
text_encoder2,
|
293 |
+
unet,
|
294 |
+
epoch_no,
|
295 |
+
global_step,
|
296 |
+
ckpt_info,
|
297 |
+
vae,
|
298 |
+
logit_scale,
|
299 |
+
sai_metadata,
|
300 |
+
save_dtype,
|
301 |
+
)
|
302 |
+
|
303 |
+
def diffusers_saver(out_dir):
|
304 |
+
sdxl_model_util.save_diffusers_checkpoint(
|
305 |
+
out_dir,
|
306 |
+
text_encoder1,
|
307 |
+
text_encoder2,
|
308 |
+
unet,
|
309 |
+
src_path,
|
310 |
+
vae,
|
311 |
+
use_safetensors=use_safetensors,
|
312 |
+
save_dtype=save_dtype,
|
313 |
+
)
|
314 |
+
|
315 |
+
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
|
316 |
+
args,
|
317 |
+
on_epoch_end,
|
318 |
+
accelerator,
|
319 |
+
save_stable_diffusion_format,
|
320 |
+
use_safetensors,
|
321 |
+
epoch,
|
322 |
+
num_train_epochs,
|
323 |
+
global_step,
|
324 |
+
sd_saver,
|
325 |
+
diffusers_saver,
|
326 |
+
)
|
327 |
+
|
328 |
+
|
329 |
+
def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_encoder_caching: bool = True):
|
330 |
+
parser.add_argument(
|
331 |
+
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
|
332 |
+
)
|
333 |
+
parser.add_argument(
|
334 |
+
"--cache_text_encoder_outputs_to_disk",
|
335 |
+
action="store_true",
|
336 |
+
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
|
337 |
+
)
|
338 |
+
parser.add_argument(
|
339 |
+
"--disable_mmap_load_safetensors",
|
340 |
+
action="store_true",
|
341 |
+
help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる",
|
342 |
+
)
|
343 |
+
|
344 |
+
|
345 |
+
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
|
346 |
+
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
|
347 |
+
if args.v_parameterization:
|
348 |
+
logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
|
349 |
+
|
350 |
+
if args.clip_skip is not None:
|
351 |
+
logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
|
352 |
+
|
353 |
+
# if args.multires_noise_iterations:
|
354 |
+
# logger.info(
|
355 |
+
# f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります"
|
356 |
+
# )
|
357 |
+
# else:
|
358 |
+
# if args.noise_offset is None:
|
359 |
+
# args.noise_offset = DEFAULT_NOISE_OFFSET
|
360 |
+
# elif args.noise_offset != DEFAULT_NOISE_OFFSET:
|
361 |
+
# logger.info(
|
362 |
+
# f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
|
363 |
+
# )
|
364 |
+
# logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
|
365 |
+
|
366 |
+
# assert (
|
367 |
+
# not hasattr(args, "weighted_captions") or not args.weighted_captions
|
368 |
+
# ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
|
369 |
+
|
370 |
+
if supportTextEncoderCaching:
|
371 |
+
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
372 |
+
args.cache_text_encoder_outputs = True
|
373 |
+
logger.warning(
|
374 |
+
"cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
|
375 |
+
+ "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
|
376 |
+
)
|
377 |
+
|
378 |
+
|
379 |
+
def sample_images(*args, **kwargs):
|
380 |
+
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
|
381 |
+
|
382 |
+
return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
|
library/slicing_vae.py
ADDED
@@ -0,0 +1,682 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from Diffusers to reduce VRAM usage
|
2 |
+
|
3 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
from dataclasses import dataclass
|
17 |
+
from typing import Optional, Tuple, Union
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
import torch.nn as nn
|
22 |
+
|
23 |
+
|
24 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
25 |
+
from diffusers.models.modeling_utils import ModelMixin
|
26 |
+
from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
|
27 |
+
from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution
|
28 |
+
from diffusers.models.autoencoder_kl import AutoencoderKLOutput
|
29 |
+
from .utils import setup_logging
|
30 |
+
setup_logging()
|
31 |
+
import logging
|
32 |
+
logger = logging.getLogger(__name__)
|
33 |
+
|
34 |
+
def slice_h(x, num_slices):
|
35 |
+
# slice with pad 1 both sides: to eliminate side effect of padding of conv2d
|
36 |
+
# Conv2dのpaddingの副作用を排除するために、両側にpad 1しながらHをスライスする
|
37 |
+
# NCHWでもNHWCでもどちらでも動く
|
38 |
+
size = (x.shape[2] + num_slices - 1) // num_slices
|
39 |
+
sliced = []
|
40 |
+
for i in range(num_slices):
|
41 |
+
if i == 0:
|
42 |
+
sliced.append(x[:, :, : size + 1, :])
|
43 |
+
else:
|
44 |
+
end = size * (i + 1) + 1
|
45 |
+
if x.shape[2] - end < 3: # if the last slice is too small, use the rest of the tensor 最後が細すぎるとconv2dできないので全部使う
|
46 |
+
end = x.shape[2]
|
47 |
+
sliced.append(x[:, :, size * i - 1 : end, :])
|
48 |
+
if end >= x.shape[2]:
|
49 |
+
break
|
50 |
+
return sliced
|
51 |
+
|
52 |
+
|
53 |
+
def cat_h(sliced):
|
54 |
+
# padding分を除いて結合する
|
55 |
+
cat = []
|
56 |
+
for i, x in enumerate(sliced):
|
57 |
+
if i == 0:
|
58 |
+
cat.append(x[:, :, :-1, :])
|
59 |
+
elif i == len(sliced) - 1:
|
60 |
+
cat.append(x[:, :, 1:, :])
|
61 |
+
else:
|
62 |
+
cat.append(x[:, :, 1:-1, :])
|
63 |
+
del x
|
64 |
+
x = torch.cat(cat, dim=2)
|
65 |
+
return x
|
66 |
+
|
67 |
+
|
68 |
+
def resblock_forward(_self, num_slices, input_tensor, temb, **kwargs):
|
69 |
+
assert _self.upsample is None and _self.downsample is None
|
70 |
+
assert _self.norm1.num_groups == _self.norm2.num_groups
|
71 |
+
assert temb is None
|
72 |
+
|
73 |
+
# make sure norms are on cpu
|
74 |
+
org_device = input_tensor.device
|
75 |
+
cpu_device = torch.device("cpu")
|
76 |
+
_self.norm1.to(cpu_device)
|
77 |
+
_self.norm2.to(cpu_device)
|
78 |
+
|
79 |
+
# GroupNormがCPUでfp16で動かない対策
|
80 |
+
org_dtype = input_tensor.dtype
|
81 |
+
if org_dtype == torch.float16:
|
82 |
+
_self.norm1.to(torch.float32)
|
83 |
+
_self.norm2.to(torch.float32)
|
84 |
+
|
85 |
+
# すべてのテンソルをCPUに移動する
|
86 |
+
input_tensor = input_tensor.to(cpu_device)
|
87 |
+
hidden_states = input_tensor
|
88 |
+
|
89 |
+
# どうもこれは結果が異なるようだ……
|
90 |
+
# def sliced_norm1(norm, x):
|
91 |
+
# num_div = 4 if up_block_idx <= 2 else x.shape[1] // norm.num_groups
|
92 |
+
# sliced_tensor = torch.chunk(x, num_div, dim=1)
|
93 |
+
# sliced_weight = torch.chunk(norm.weight, num_div, dim=0)
|
94 |
+
# sliced_bias = torch.chunk(norm.bias, num_div, dim=0)
|
95 |
+
# logger.info(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape)
|
96 |
+
# normed_tensor = []
|
97 |
+
# for i in range(num_div):
|
98 |
+
# n = torch.group_norm(sliced_tensor[i], norm.num_groups, sliced_weight[i], sliced_bias[i], norm.eps)
|
99 |
+
# normed_tensor.append(n)
|
100 |
+
# del n
|
101 |
+
# x = torch.cat(normed_tensor, dim=1)
|
102 |
+
# return num_div, x
|
103 |
+
|
104 |
+
# normを分割すると結果が変わるので、ここだけは分割しない。GPUで計算するとVRAMが足りなくなるので、CPUで計算する。幸いCPUでもそこまで遅くない
|
105 |
+
if org_dtype == torch.float16:
|
106 |
+
hidden_states = hidden_states.to(torch.float32)
|
107 |
+
hidden_states = _self.norm1(hidden_states) # run on cpu
|
108 |
+
if org_dtype == torch.float16:
|
109 |
+
hidden_states = hidden_states.to(torch.float16)
|
110 |
+
|
111 |
+
sliced = slice_h(hidden_states, num_slices)
|
112 |
+
del hidden_states
|
113 |
+
|
114 |
+
for i in range(len(sliced)):
|
115 |
+
x = sliced[i]
|
116 |
+
sliced[i] = None
|
117 |
+
|
118 |
+
# 計算する部分だけGPUに移動する、以下同様
|
119 |
+
x = x.to(org_device)
|
120 |
+
x = _self.nonlinearity(x)
|
121 |
+
x = _self.conv1(x)
|
122 |
+
x = x.to(cpu_device)
|
123 |
+
sliced[i] = x
|
124 |
+
del x
|
125 |
+
|
126 |
+
hidden_states = cat_h(sliced)
|
127 |
+
del sliced
|
128 |
+
|
129 |
+
if org_dtype == torch.float16:
|
130 |
+
hidden_states = hidden_states.to(torch.float32)
|
131 |
+
hidden_states = _self.norm2(hidden_states) # run on cpu
|
132 |
+
if org_dtype == torch.float16:
|
133 |
+
hidden_states = hidden_states.to(torch.float16)
|
134 |
+
|
135 |
+
sliced = slice_h(hidden_states, num_slices)
|
136 |
+
del hidden_states
|
137 |
+
|
138 |
+
for i in range(len(sliced)):
|
139 |
+
x = sliced[i]
|
140 |
+
sliced[i] = None
|
141 |
+
|
142 |
+
x = x.to(org_device)
|
143 |
+
x = _self.nonlinearity(x)
|
144 |
+
x = _self.dropout(x)
|
145 |
+
x = _self.conv2(x)
|
146 |
+
x = x.to(cpu_device)
|
147 |
+
sliced[i] = x
|
148 |
+
del x
|
149 |
+
|
150 |
+
hidden_states = cat_h(sliced)
|
151 |
+
del sliced
|
152 |
+
|
153 |
+
# make shortcut
|
154 |
+
if _self.conv_shortcut is not None:
|
155 |
+
sliced = list(torch.chunk(input_tensor, num_slices, dim=2)) # no padding in conv_shortcut パディングがないので普通にスライスする
|
156 |
+
del input_tensor
|
157 |
+
|
158 |
+
for i in range(len(sliced)):
|
159 |
+
x = sliced[i]
|
160 |
+
sliced[i] = None
|
161 |
+
|
162 |
+
x = x.to(org_device)
|
163 |
+
x = _self.conv_shortcut(x)
|
164 |
+
x = x.to(cpu_device)
|
165 |
+
sliced[i] = x
|
166 |
+
del x
|
167 |
+
|
168 |
+
input_tensor = torch.cat(sliced, dim=2)
|
169 |
+
del sliced
|
170 |
+
|
171 |
+
output_tensor = (input_tensor + hidden_states) / _self.output_scale_factor
|
172 |
+
|
173 |
+
output_tensor = output_tensor.to(org_device) # 次のレイヤーがGPUで計算する
|
174 |
+
return output_tensor
|
175 |
+
|
176 |
+
|
177 |
+
class SlicingEncoder(nn.Module):
|
178 |
+
def __init__(
|
179 |
+
self,
|
180 |
+
in_channels=3,
|
181 |
+
out_channels=3,
|
182 |
+
down_block_types=("DownEncoderBlock2D",),
|
183 |
+
block_out_channels=(64,),
|
184 |
+
layers_per_block=2,
|
185 |
+
norm_num_groups=32,
|
186 |
+
act_fn="silu",
|
187 |
+
double_z=True,
|
188 |
+
num_slices=2,
|
189 |
+
):
|
190 |
+
super().__init__()
|
191 |
+
self.layers_per_block = layers_per_block
|
192 |
+
|
193 |
+
self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
|
194 |
+
|
195 |
+
self.mid_block = None
|
196 |
+
self.down_blocks = nn.ModuleList([])
|
197 |
+
|
198 |
+
# down
|
199 |
+
output_channel = block_out_channels[0]
|
200 |
+
for i, down_block_type in enumerate(down_block_types):
|
201 |
+
input_channel = output_channel
|
202 |
+
output_channel = block_out_channels[i]
|
203 |
+
is_final_block = i == len(block_out_channels) - 1
|
204 |
+
|
205 |
+
down_block = get_down_block(
|
206 |
+
down_block_type,
|
207 |
+
num_layers=self.layers_per_block,
|
208 |
+
in_channels=input_channel,
|
209 |
+
out_channels=output_channel,
|
210 |
+
add_downsample=not is_final_block,
|
211 |
+
resnet_eps=1e-6,
|
212 |
+
downsample_padding=0,
|
213 |
+
resnet_act_fn=act_fn,
|
214 |
+
resnet_groups=norm_num_groups,
|
215 |
+
attention_head_dim=output_channel,
|
216 |
+
temb_channels=None,
|
217 |
+
)
|
218 |
+
self.down_blocks.append(down_block)
|
219 |
+
|
220 |
+
# mid
|
221 |
+
self.mid_block = UNetMidBlock2D(
|
222 |
+
in_channels=block_out_channels[-1],
|
223 |
+
resnet_eps=1e-6,
|
224 |
+
resnet_act_fn=act_fn,
|
225 |
+
output_scale_factor=1,
|
226 |
+
resnet_time_scale_shift="default",
|
227 |
+
attention_head_dim=block_out_channels[-1],
|
228 |
+
resnet_groups=norm_num_groups,
|
229 |
+
temb_channels=None,
|
230 |
+
)
|
231 |
+
self.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) # とりあえずDiffusersのxformersを使う
|
232 |
+
|
233 |
+
# out
|
234 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
|
235 |
+
self.conv_act = nn.SiLU()
|
236 |
+
|
237 |
+
conv_out_channels = 2 * out_channels if double_z else out_channels
|
238 |
+
self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
|
239 |
+
|
240 |
+
# replace forward of ResBlocks
|
241 |
+
def wrapper(func, module, num_slices):
|
242 |
+
def forward(*args, **kwargs):
|
243 |
+
return func(module, num_slices, *args, **kwargs)
|
244 |
+
|
245 |
+
return forward
|
246 |
+
|
247 |
+
self.num_slices = num_slices
|
248 |
+
div = num_slices / (2 ** (len(self.down_blocks) - 1)) # 深い層はそこまで分割しなくていいので適宜減らす
|
249 |
+
# logger.info(f"initial divisor: {div}")
|
250 |
+
if div >= 2:
|
251 |
+
div = int(div)
|
252 |
+
for resnet in self.mid_block.resnets:
|
253 |
+
resnet.forward = wrapper(resblock_forward, resnet, div)
|
254 |
+
# midblock doesn't have downsample
|
255 |
+
|
256 |
+
for i, down_block in enumerate(self.down_blocks[::-1]):
|
257 |
+
if div >= 2:
|
258 |
+
div = int(div)
|
259 |
+
# logger.info(f"down block: {i} divisor: {div}")
|
260 |
+
for resnet in down_block.resnets:
|
261 |
+
resnet.forward = wrapper(resblock_forward, resnet, div)
|
262 |
+
if down_block.downsamplers is not None:
|
263 |
+
# logger.info("has downsample")
|
264 |
+
for downsample in down_block.downsamplers:
|
265 |
+
downsample.forward = wrapper(self.downsample_forward, downsample, div * 2)
|
266 |
+
div *= 2
|
267 |
+
|
268 |
+
def forward(self, x):
|
269 |
+
sample = x
|
270 |
+
del x
|
271 |
+
|
272 |
+
org_device = sample.device
|
273 |
+
cpu_device = torch.device("cpu")
|
274 |
+
|
275 |
+
# sample = self.conv_in(sample)
|
276 |
+
sample = sample.to(cpu_device)
|
277 |
+
sliced = slice_h(sample, self.num_slices)
|
278 |
+
del sample
|
279 |
+
|
280 |
+
for i in range(len(sliced)):
|
281 |
+
x = sliced[i]
|
282 |
+
sliced[i] = None
|
283 |
+
|
284 |
+
x = x.to(org_device)
|
285 |
+
x = self.conv_in(x)
|
286 |
+
x = x.to(cpu_device)
|
287 |
+
sliced[i] = x
|
288 |
+
del x
|
289 |
+
|
290 |
+
sample = cat_h(sliced)
|
291 |
+
del sliced
|
292 |
+
|
293 |
+
sample = sample.to(org_device)
|
294 |
+
|
295 |
+
# down
|
296 |
+
for down_block in self.down_blocks:
|
297 |
+
sample = down_block(sample)
|
298 |
+
|
299 |
+
# middle
|
300 |
+
sample = self.mid_block(sample)
|
301 |
+
|
302 |
+
# post-process
|
303 |
+
# ここも省メモリ化したいが、恐らくそこまでメモリを食わないので省略
|
304 |
+
sample = self.conv_norm_out(sample)
|
305 |
+
sample = self.conv_act(sample)
|
306 |
+
sample = self.conv_out(sample)
|
307 |
+
|
308 |
+
return sample
|
309 |
+
|
310 |
+
def downsample_forward(self, _self, num_slices, hidden_states):
|
311 |
+
assert hidden_states.shape[1] == _self.channels
|
312 |
+
assert _self.use_conv and _self.padding == 0
|
313 |
+
logger.info(f"downsample forward {num_slices} {hidden_states.shape}")
|
314 |
+
|
315 |
+
org_device = hidden_states.device
|
316 |
+
cpu_device = torch.device("cpu")
|
317 |
+
|
318 |
+
hidden_states = hidden_states.to(cpu_device)
|
319 |
+
pad = (0, 1, 0, 1)
|
320 |
+
hidden_states = torch.nn.functional.pad(hidden_states, pad, mode="constant", value=0)
|
321 |
+
|
322 |
+
# slice with even number because of stride 2
|
323 |
+
# strideが2なので偶数でスライスする
|
324 |
+
# slice with pad 1 both sides: to eliminate side effect of padding of conv2d
|
325 |
+
size = (hidden_states.shape[2] + num_slices - 1) // num_slices
|
326 |
+
size = size + 1 if size % 2 == 1 else size
|
327 |
+
|
328 |
+
sliced = []
|
329 |
+
for i in range(num_slices):
|
330 |
+
if i == 0:
|
331 |
+
sliced.append(hidden_states[:, :, : size + 1, :])
|
332 |
+
else:
|
333 |
+
end = size * (i + 1) + 1
|
334 |
+
if hidden_states.shape[2] - end < 4: # if the last slice is too small, use the rest of the tensor
|
335 |
+
end = hidden_states.shape[2]
|
336 |
+
sliced.append(hidden_states[:, :, size * i - 1 : end, :])
|
337 |
+
if end >= hidden_states.shape[2]:
|
338 |
+
break
|
339 |
+
del hidden_states
|
340 |
+
|
341 |
+
for i in range(len(sliced)):
|
342 |
+
x = sliced[i]
|
343 |
+
sliced[i] = None
|
344 |
+
|
345 |
+
x = x.to(org_device)
|
346 |
+
x = _self.conv(x)
|
347 |
+
x = x.to(cpu_device)
|
348 |
+
|
349 |
+
# ここだけ雰囲気が違うのはCopilotのせい
|
350 |
+
if i == 0:
|
351 |
+
hidden_states = x
|
352 |
+
else:
|
353 |
+
hidden_states = torch.cat([hidden_states, x], dim=2)
|
354 |
+
|
355 |
+
hidden_states = hidden_states.to(org_device)
|
356 |
+
# logger.info(f"downsample forward done {hidden_states.shape}")
|
357 |
+
return hidden_states
|
358 |
+
|
359 |
+
|
360 |
+
class SlicingDecoder(nn.Module):
|
361 |
+
def __init__(
|
362 |
+
self,
|
363 |
+
in_channels=3,
|
364 |
+
out_channels=3,
|
365 |
+
up_block_types=("UpDecoderBlock2D",),
|
366 |
+
block_out_channels=(64,),
|
367 |
+
layers_per_block=2,
|
368 |
+
norm_num_groups=32,
|
369 |
+
act_fn="silu",
|
370 |
+
num_slices=2,
|
371 |
+
):
|
372 |
+
super().__init__()
|
373 |
+
self.layers_per_block = layers_per_block
|
374 |
+
|
375 |
+
self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
|
376 |
+
|
377 |
+
self.mid_block = None
|
378 |
+
self.up_blocks = nn.ModuleList([])
|
379 |
+
|
380 |
+
# mid
|
381 |
+
self.mid_block = UNetMidBlock2D(
|
382 |
+
in_channels=block_out_channels[-1],
|
383 |
+
resnet_eps=1e-6,
|
384 |
+
resnet_act_fn=act_fn,
|
385 |
+
output_scale_factor=1,
|
386 |
+
resnet_time_scale_shift="default",
|
387 |
+
attention_head_dim=block_out_channels[-1],
|
388 |
+
resnet_groups=norm_num_groups,
|
389 |
+
temb_channels=None,
|
390 |
+
)
|
391 |
+
self.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) # とりあえずDiffusersのxformersを使う
|
392 |
+
|
393 |
+
# up
|
394 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
395 |
+
output_channel = reversed_block_out_channels[0]
|
396 |
+
for i, up_block_type in enumerate(up_block_types):
|
397 |
+
prev_output_channel = output_channel
|
398 |
+
output_channel = reversed_block_out_channels[i]
|
399 |
+
|
400 |
+
is_final_block = i == len(block_out_channels) - 1
|
401 |
+
|
402 |
+
up_block = get_up_block(
|
403 |
+
up_block_type,
|
404 |
+
num_layers=self.layers_per_block + 1,
|
405 |
+
in_channels=prev_output_channel,
|
406 |
+
out_channels=output_channel,
|
407 |
+
prev_output_channel=None,
|
408 |
+
add_upsample=not is_final_block,
|
409 |
+
resnet_eps=1e-6,
|
410 |
+
resnet_act_fn=act_fn,
|
411 |
+
resnet_groups=norm_num_groups,
|
412 |
+
attention_head_dim=output_channel,
|
413 |
+
temb_channels=None,
|
414 |
+
)
|
415 |
+
self.up_blocks.append(up_block)
|
416 |
+
prev_output_channel = output_channel
|
417 |
+
|
418 |
+
# out
|
419 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
|
420 |
+
self.conv_act = nn.SiLU()
|
421 |
+
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
|
422 |
+
|
423 |
+
# replace forward of ResBlocks
|
424 |
+
def wrapper(func, module, num_slices):
|
425 |
+
def forward(*args, **kwargs):
|
426 |
+
return func(module, num_slices, *args, **kwargs)
|
427 |
+
|
428 |
+
return forward
|
429 |
+
|
430 |
+
self.num_slices = num_slices
|
431 |
+
div = num_slices / (2 ** (len(self.up_blocks) - 1))
|
432 |
+
logger.info(f"initial divisor: {div}")
|
433 |
+
if div >= 2:
|
434 |
+
div = int(div)
|
435 |
+
for resnet in self.mid_block.resnets:
|
436 |
+
resnet.forward = wrapper(resblock_forward, resnet, div)
|
437 |
+
# midblock doesn't have upsample
|
438 |
+
|
439 |
+
for i, up_block in enumerate(self.up_blocks):
|
440 |
+
if div >= 2:
|
441 |
+
div = int(div)
|
442 |
+
# logger.info(f"up block: {i} divisor: {div}")
|
443 |
+
for resnet in up_block.resnets:
|
444 |
+
resnet.forward = wrapper(resblock_forward, resnet, div)
|
445 |
+
if up_block.upsamplers is not None:
|
446 |
+
# logger.info("has upsample")
|
447 |
+
for upsample in up_block.upsamplers:
|
448 |
+
upsample.forward = wrapper(self.upsample_forward, upsample, div * 2)
|
449 |
+
div *= 2
|
450 |
+
|
451 |
+
def forward(self, z):
|
452 |
+
sample = z
|
453 |
+
del z
|
454 |
+
sample = self.conv_in(sample)
|
455 |
+
|
456 |
+
# middle
|
457 |
+
sample = self.mid_block(sample)
|
458 |
+
|
459 |
+
# up
|
460 |
+
for i, up_block in enumerate(self.up_blocks):
|
461 |
+
sample = up_block(sample)
|
462 |
+
|
463 |
+
# post-process
|
464 |
+
sample = self.conv_norm_out(sample)
|
465 |
+
sample = self.conv_act(sample)
|
466 |
+
|
467 |
+
# conv_out with slicing because of VRAM usage
|
468 |
+
# conv_outはとてもVRAM使うのでスライスして対応
|
469 |
+
org_device = sample.device
|
470 |
+
cpu_device = torch.device("cpu")
|
471 |
+
sample = sample.to(cpu_device)
|
472 |
+
|
473 |
+
sliced = slice_h(sample, self.num_slices)
|
474 |
+
del sample
|
475 |
+
for i in range(len(sliced)):
|
476 |
+
x = sliced[i]
|
477 |
+
sliced[i] = None
|
478 |
+
|
479 |
+
x = x.to(org_device)
|
480 |
+
x = self.conv_out(x)
|
481 |
+
x = x.to(cpu_device)
|
482 |
+
sliced[i] = x
|
483 |
+
sample = cat_h(sliced)
|
484 |
+
del sliced
|
485 |
+
|
486 |
+
sample = sample.to(org_device)
|
487 |
+
return sample
|
488 |
+
|
489 |
+
def upsample_forward(self, _self, num_slices, hidden_states, output_size=None):
|
490 |
+
assert hidden_states.shape[1] == _self.channels
|
491 |
+
assert _self.use_conv_transpose == False and _self.use_conv
|
492 |
+
|
493 |
+
org_dtype = hidden_states.dtype
|
494 |
+
org_device = hidden_states.device
|
495 |
+
cpu_device = torch.device("cpu")
|
496 |
+
|
497 |
+
hidden_states = hidden_states.to(cpu_device)
|
498 |
+
sliced = slice_h(hidden_states, num_slices)
|
499 |
+
del hidden_states
|
500 |
+
|
501 |
+
for i in range(len(sliced)):
|
502 |
+
x = sliced[i]
|
503 |
+
sliced[i] = None
|
504 |
+
|
505 |
+
x = x.to(org_device)
|
506 |
+
|
507 |
+
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
508 |
+
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
|
509 |
+
# https://github.com/pytorch/pytorch/issues/86679
|
510 |
+
# PyTorch 2で直らないかね……
|
511 |
+
if org_dtype == torch.bfloat16:
|
512 |
+
x = x.to(torch.float32)
|
513 |
+
|
514 |
+
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
515 |
+
|
516 |
+
if org_dtype == torch.bfloat16:
|
517 |
+
x = x.to(org_dtype)
|
518 |
+
|
519 |
+
x = _self.conv(x)
|
520 |
+
|
521 |
+
# upsampleされてるのでpadは2になる
|
522 |
+
if i == 0:
|
523 |
+
x = x[:, :, :-2, :]
|
524 |
+
elif i == num_slices - 1:
|
525 |
+
x = x[:, :, 2:, :]
|
526 |
+
else:
|
527 |
+
x = x[:, :, 2:-2, :]
|
528 |
+
|
529 |
+
x = x.to(cpu_device)
|
530 |
+
sliced[i] = x
|
531 |
+
del x
|
532 |
+
|
533 |
+
hidden_states = torch.cat(sliced, dim=2)
|
534 |
+
# logger.info(f"us hidden_states {hidden_states.shape}")
|
535 |
+
del sliced
|
536 |
+
|
537 |
+
hidden_states = hidden_states.to(org_device)
|
538 |
+
return hidden_states
|
539 |
+
|
540 |
+
|
541 |
+
class SlicingAutoencoderKL(ModelMixin, ConfigMixin):
|
542 |
+
r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma
|
543 |
+
and Max Welling.
|
544 |
+
|
545 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
|
546 |
+
implements for all the model (such as downloading or saving, etc.)
|
547 |
+
|
548 |
+
Parameters:
|
549 |
+
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
550 |
+
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
551 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to :
|
552 |
+
obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
|
553 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to :
|
554 |
+
obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
|
555 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to :
|
556 |
+
obj:`(64,)`): Tuple of block output channels.
|
557 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
558 |
+
latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
|
559 |
+
sample_size (`int`, *optional*, defaults to `32`): TODO
|
560 |
+
"""
|
561 |
+
|
562 |
+
@register_to_config
|
563 |
+
def __init__(
|
564 |
+
self,
|
565 |
+
in_channels: int = 3,
|
566 |
+
out_channels: int = 3,
|
567 |
+
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
|
568 |
+
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
|
569 |
+
block_out_channels: Tuple[int] = (64,),
|
570 |
+
layers_per_block: int = 1,
|
571 |
+
act_fn: str = "silu",
|
572 |
+
latent_channels: int = 4,
|
573 |
+
norm_num_groups: int = 32,
|
574 |
+
sample_size: int = 32,
|
575 |
+
num_slices: int = 16,
|
576 |
+
):
|
577 |
+
super().__init__()
|
578 |
+
|
579 |
+
# pass init params to Encoder
|
580 |
+
self.encoder = SlicingEncoder(
|
581 |
+
in_channels=in_channels,
|
582 |
+
out_channels=latent_channels,
|
583 |
+
down_block_types=down_block_types,
|
584 |
+
block_out_channels=block_out_channels,
|
585 |
+
layers_per_block=layers_per_block,
|
586 |
+
act_fn=act_fn,
|
587 |
+
norm_num_groups=norm_num_groups,
|
588 |
+
double_z=True,
|
589 |
+
num_slices=num_slices,
|
590 |
+
)
|
591 |
+
|
592 |
+
# pass init params to Decoder
|
593 |
+
self.decoder = SlicingDecoder(
|
594 |
+
in_channels=latent_channels,
|
595 |
+
out_channels=out_channels,
|
596 |
+
up_block_types=up_block_types,
|
597 |
+
block_out_channels=block_out_channels,
|
598 |
+
layers_per_block=layers_per_block,
|
599 |
+
norm_num_groups=norm_num_groups,
|
600 |
+
act_fn=act_fn,
|
601 |
+
num_slices=num_slices,
|
602 |
+
)
|
603 |
+
|
604 |
+
self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
|
605 |
+
self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
|
606 |
+
self.use_slicing = False
|
607 |
+
|
608 |
+
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
|
609 |
+
h = self.encoder(x)
|
610 |
+
moments = self.quant_conv(h)
|
611 |
+
posterior = DiagonalGaussianDistribution(moments)
|
612 |
+
|
613 |
+
if not return_dict:
|
614 |
+
return (posterior,)
|
615 |
+
|
616 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
617 |
+
|
618 |
+
def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
619 |
+
z = self.post_quant_conv(z)
|
620 |
+
dec = self.decoder(z)
|
621 |
+
|
622 |
+
if not return_dict:
|
623 |
+
return (dec,)
|
624 |
+
|
625 |
+
return DecoderOutput(sample=dec)
|
626 |
+
|
627 |
+
# これはバッチ方向のスライシング 紛らわしい
|
628 |
+
def enable_slicing(self):
|
629 |
+
r"""
|
630 |
+
Enable sliced VAE decoding.
|
631 |
+
|
632 |
+
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
|
633 |
+
steps. This is useful to save some memory and allow larger batch sizes.
|
634 |
+
"""
|
635 |
+
self.use_slicing = True
|
636 |
+
|
637 |
+
def disable_slicing(self):
|
638 |
+
r"""
|
639 |
+
Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
|
640 |
+
decoding in one step.
|
641 |
+
"""
|
642 |
+
self.use_slicing = False
|
643 |
+
|
644 |
+
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
645 |
+
if self.use_slicing and z.shape[0] > 1:
|
646 |
+
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
647 |
+
decoded = torch.cat(decoded_slices)
|
648 |
+
else:
|
649 |
+
decoded = self._decode(z).sample
|
650 |
+
|
651 |
+
if not return_dict:
|
652 |
+
return (decoded,)
|
653 |
+
|
654 |
+
return DecoderOutput(sample=decoded)
|
655 |
+
|
656 |
+
def forward(
|
657 |
+
self,
|
658 |
+
sample: torch.FloatTensor,
|
659 |
+
sample_posterior: bool = False,
|
660 |
+
return_dict: bool = True,
|
661 |
+
generator: Optional[torch.Generator] = None,
|
662 |
+
) -> Union[DecoderOutput, torch.FloatTensor]:
|
663 |
+
r"""
|
664 |
+
Args:
|
665 |
+
sample (`torch.FloatTensor`): Input sample.
|
666 |
+
sample_posterior (`bool`, *optional*, defaults to `False`):
|
667 |
+
Whether to sample from the posterior.
|
668 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
669 |
+
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
670 |
+
"""
|
671 |
+
x = sample
|
672 |
+
posterior = self.encode(x).latent_dist
|
673 |
+
if sample_posterior:
|
674 |
+
z = posterior.sample(generator=generator)
|
675 |
+
else:
|
676 |
+
z = posterior.mode()
|
677 |
+
dec = self.decode(z).sample
|
678 |
+
|
679 |
+
if not return_dict:
|
680 |
+
return (dec,)
|
681 |
+
|
682 |
+
return DecoderOutput(sample=dec)
|
library/strategy_base.py
ADDED
@@ -0,0 +1,570 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# base class for platform strategies. this file defines the interface for strategies
|
2 |
+
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
from typing import Any, List, Optional, Tuple, Union
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
|
10 |
+
|
11 |
+
|
12 |
+
# TODO remove circular import by moving ImageInfo to a separate file
|
13 |
+
# from library.train_util import ImageInfo
|
14 |
+
|
15 |
+
from library.utils import setup_logging
|
16 |
+
|
17 |
+
setup_logging()
|
18 |
+
import logging
|
19 |
+
|
20 |
+
logger = logging.getLogger(__name__)
|
21 |
+
|
22 |
+
|
23 |
+
class TokenizeStrategy:
|
24 |
+
_strategy = None # strategy instance: actual strategy class
|
25 |
+
|
26 |
+
_re_attention = re.compile(
|
27 |
+
r"""\\\(|
|
28 |
+
\\\)|
|
29 |
+
\\\[|
|
30 |
+
\\]|
|
31 |
+
\\\\|
|
32 |
+
\\|
|
33 |
+
\(|
|
34 |
+
\[|
|
35 |
+
:([+-]?[.\d]+)\)|
|
36 |
+
\)|
|
37 |
+
]|
|
38 |
+
[^\\()\[\]:]+|
|
39 |
+
:
|
40 |
+
""",
|
41 |
+
re.X,
|
42 |
+
)
|
43 |
+
|
44 |
+
@classmethod
|
45 |
+
def set_strategy(cls, strategy):
|
46 |
+
if cls._strategy is not None:
|
47 |
+
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
|
48 |
+
cls._strategy = strategy
|
49 |
+
|
50 |
+
@classmethod
|
51 |
+
def get_strategy(cls) -> Optional["TokenizeStrategy"]:
|
52 |
+
return cls._strategy
|
53 |
+
|
54 |
+
def _load_tokenizer(
|
55 |
+
self, model_class: Any, model_id: str, subfolder: Optional[str] = None, tokenizer_cache_dir: Optional[str] = None
|
56 |
+
) -> Any:
|
57 |
+
tokenizer = None
|
58 |
+
if tokenizer_cache_dir:
|
59 |
+
local_tokenizer_path = os.path.join(tokenizer_cache_dir, model_id.replace("/", "_"))
|
60 |
+
if os.path.exists(local_tokenizer_path):
|
61 |
+
logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
|
62 |
+
tokenizer = model_class.from_pretrained(local_tokenizer_path) # same for v1 and v2
|
63 |
+
|
64 |
+
if tokenizer is None:
|
65 |
+
tokenizer = model_class.from_pretrained(model_id, subfolder=subfolder)
|
66 |
+
|
67 |
+
if tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
|
68 |
+
logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
|
69 |
+
tokenizer.save_pretrained(local_tokenizer_path)
|
70 |
+
|
71 |
+
return tokenizer
|
72 |
+
|
73 |
+
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
|
74 |
+
raise NotImplementedError
|
75 |
+
|
76 |
+
def tokenize_with_weights(self, text: Union[str, List[str]]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
77 |
+
"""
|
78 |
+
returns: [tokens1, tokens2, ...], [weights1, weights2, ...]
|
79 |
+
"""
|
80 |
+
raise NotImplementedError
|
81 |
+
|
82 |
+
def _get_weighted_input_ids(
|
83 |
+
self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None
|
84 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
85 |
+
"""
|
86 |
+
max_length includes starting and ending tokens.
|
87 |
+
"""
|
88 |
+
|
89 |
+
def parse_prompt_attention(text):
|
90 |
+
"""
|
91 |
+
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
|
92 |
+
Accepted tokens are:
|
93 |
+
(abc) - increases attention to abc by a multiplier of 1.1
|
94 |
+
(abc:3.12) - increases attention to abc by a multiplier of 3.12
|
95 |
+
[abc] - decreases attention to abc by a multiplier of 1.1
|
96 |
+
\( - literal character '('
|
97 |
+
\[ - literal character '['
|
98 |
+
\) - literal character ')'
|
99 |
+
\] - literal character ']'
|
100 |
+
\\ - literal character '\'
|
101 |
+
anything else - just text
|
102 |
+
>>> parse_prompt_attention('normal text')
|
103 |
+
[['normal text', 1.0]]
|
104 |
+
>>> parse_prompt_attention('an (important) word')
|
105 |
+
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
|
106 |
+
>>> parse_prompt_attention('(unbalanced')
|
107 |
+
[['unbalanced', 1.1]]
|
108 |
+
>>> parse_prompt_attention('\(literal\]')
|
109 |
+
[['(literal]', 1.0]]
|
110 |
+
>>> parse_prompt_attention('(unnecessary)(parens)')
|
111 |
+
[['unnecessaryparens', 1.1]]
|
112 |
+
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
|
113 |
+
[['a ', 1.0],
|
114 |
+
['house', 1.5730000000000004],
|
115 |
+
[' ', 1.1],
|
116 |
+
['on', 1.0],
|
117 |
+
[' a ', 1.1],
|
118 |
+
['hill', 0.55],
|
119 |
+
[', sun, ', 1.1],
|
120 |
+
['sky', 1.4641000000000006],
|
121 |
+
['.', 1.1]]
|
122 |
+
"""
|
123 |
+
|
124 |
+
res = []
|
125 |
+
round_brackets = []
|
126 |
+
square_brackets = []
|
127 |
+
|
128 |
+
round_bracket_multiplier = 1.1
|
129 |
+
square_bracket_multiplier = 1 / 1.1
|
130 |
+
|
131 |
+
def multiply_range(start_position, multiplier):
|
132 |
+
for p in range(start_position, len(res)):
|
133 |
+
res[p][1] *= multiplier
|
134 |
+
|
135 |
+
for m in TokenizeStrategy._re_attention.finditer(text):
|
136 |
+
text = m.group(0)
|
137 |
+
weight = m.group(1)
|
138 |
+
|
139 |
+
if text.startswith("\\"):
|
140 |
+
res.append([text[1:], 1.0])
|
141 |
+
elif text == "(":
|
142 |
+
round_brackets.append(len(res))
|
143 |
+
elif text == "[":
|
144 |
+
square_brackets.append(len(res))
|
145 |
+
elif weight is not None and len(round_brackets) > 0:
|
146 |
+
multiply_range(round_brackets.pop(), float(weight))
|
147 |
+
elif text == ")" and len(round_brackets) > 0:
|
148 |
+
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
149 |
+
elif text == "]" and len(square_brackets) > 0:
|
150 |
+
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
151 |
+
else:
|
152 |
+
res.append([text, 1.0])
|
153 |
+
|
154 |
+
for pos in round_brackets:
|
155 |
+
multiply_range(pos, round_bracket_multiplier)
|
156 |
+
|
157 |
+
for pos in square_brackets:
|
158 |
+
multiply_range(pos, square_bracket_multiplier)
|
159 |
+
|
160 |
+
if len(res) == 0:
|
161 |
+
res = [["", 1.0]]
|
162 |
+
|
163 |
+
# merge runs of identical weights
|
164 |
+
i = 0
|
165 |
+
while i + 1 < len(res):
|
166 |
+
if res[i][1] == res[i + 1][1]:
|
167 |
+
res[i][0] += res[i + 1][0]
|
168 |
+
res.pop(i + 1)
|
169 |
+
else:
|
170 |
+
i += 1
|
171 |
+
|
172 |
+
return res
|
173 |
+
|
174 |
+
def get_prompts_with_weights(text: str, max_length: int):
|
175 |
+
r"""
|
176 |
+
Tokenize a list of prompts and return its tokens with weights of each token. max_length does not include starting and ending token.
|
177 |
+
|
178 |
+
No padding, starting or ending token is included.
|
179 |
+
"""
|
180 |
+
truncated = False
|
181 |
+
|
182 |
+
texts_and_weights = parse_prompt_attention(text)
|
183 |
+
tokens = []
|
184 |
+
weights = []
|
185 |
+
for word, weight in texts_and_weights:
|
186 |
+
# tokenize and discard the starting and the ending token
|
187 |
+
token = tokenizer(word).input_ids[1:-1]
|
188 |
+
tokens += token
|
189 |
+
# copy the weight by length of token
|
190 |
+
weights += [weight] * len(token)
|
191 |
+
# stop if the text is too long (longer than truncation limit)
|
192 |
+
if len(tokens) > max_length:
|
193 |
+
truncated = True
|
194 |
+
break
|
195 |
+
# truncate
|
196 |
+
if len(tokens) > max_length:
|
197 |
+
truncated = True
|
198 |
+
tokens = tokens[:max_length]
|
199 |
+
weights = weights[:max_length]
|
200 |
+
if truncated:
|
201 |
+
logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
|
202 |
+
return tokens, weights
|
203 |
+
|
204 |
+
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad):
|
205 |
+
r"""
|
206 |
+
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
|
207 |
+
"""
|
208 |
+
tokens = [bos] + tokens + [eos] + [pad] * (max_length - 2 - len(tokens))
|
209 |
+
weights = [1.0] + weights + [1.0] * (max_length - 1 - len(weights))
|
210 |
+
return tokens, weights
|
211 |
+
|
212 |
+
if max_length is None:
|
213 |
+
max_length = tokenizer.model_max_length
|
214 |
+
|
215 |
+
tokens, weights = get_prompts_with_weights(text, max_length - 2)
|
216 |
+
tokens, weights = pad_tokens_and_weights(
|
217 |
+
tokens, weights, max_length, tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id
|
218 |
+
)
|
219 |
+
return torch.tensor(tokens).unsqueeze(0), torch.tensor(weights).unsqueeze(0)
|
220 |
+
|
221 |
+
def _get_input_ids(
|
222 |
+
self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None, weighted: bool = False
|
223 |
+
) -> torch.Tensor:
|
224 |
+
"""
|
225 |
+
for SD1.5/2.0/SDXL
|
226 |
+
TODO support batch input
|
227 |
+
"""
|
228 |
+
if max_length is None:
|
229 |
+
max_length = tokenizer.model_max_length - 2
|
230 |
+
|
231 |
+
if weighted:
|
232 |
+
input_ids, weights = self._get_weighted_input_ids(tokenizer, text, max_length)
|
233 |
+
else:
|
234 |
+
input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids
|
235 |
+
|
236 |
+
if max_length > tokenizer.model_max_length:
|
237 |
+
input_ids = input_ids.squeeze(0)
|
238 |
+
iids_list = []
|
239 |
+
if tokenizer.pad_token_id == tokenizer.eos_token_id:
|
240 |
+
# v1
|
241 |
+
# 77以上の時は "<BOS> .... <EOS> <EOS> <EOS>" でトータル227とかになっているので、"<BOS>...<EOS>"の三連に変換する
|
242 |
+
# 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に
|
243 |
+
for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): # (1, 152, 75)
|
244 |
+
ids_chunk = (
|
245 |
+
input_ids[0].unsqueeze(0),
|
246 |
+
input_ids[i : i + tokenizer.model_max_length - 2],
|
247 |
+
input_ids[-1].unsqueeze(0),
|
248 |
+
)
|
249 |
+
ids_chunk = torch.cat(ids_chunk)
|
250 |
+
iids_list.append(ids_chunk)
|
251 |
+
else:
|
252 |
+
# v2 or SDXL
|
253 |
+
# 77以上の時は "<BOS> .... <EOS> <PAD> <PAD>..." でトータル227とかになっているので、"<BOS>...<EOS> <PAD> <PAD> ..."の三連に変換する
|
254 |
+
for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2):
|
255 |
+
ids_chunk = (
|
256 |
+
input_ids[0].unsqueeze(0), # BOS
|
257 |
+
input_ids[i : i + tokenizer.model_max_length - 2],
|
258 |
+
input_ids[-1].unsqueeze(0),
|
259 |
+
) # PAD or EOS
|
260 |
+
ids_chunk = torch.cat(ids_chunk)
|
261 |
+
|
262 |
+
# 末尾が <EOS> <PAD> または <PAD> <PAD> の場合は、何もしなくてよい
|
263 |
+
# 末尾が x <PAD/EOS> の場合は末尾を <EOS> に変える(x <EOS> なら結果的に変化なし)
|
264 |
+
if ids_chunk[-2] != tokenizer.eos_token_id and ids_chunk[-2] != tokenizer.pad_token_id:
|
265 |
+
ids_chunk[-1] = tokenizer.eos_token_id
|
266 |
+
# 先頭が <BOS> <PAD> ... の場合は <BOS> <EOS> <PAD> ... に変える
|
267 |
+
if ids_chunk[1] == tokenizer.pad_token_id:
|
268 |
+
ids_chunk[1] = tokenizer.eos_token_id
|
269 |
+
|
270 |
+
iids_list.append(ids_chunk)
|
271 |
+
|
272 |
+
input_ids = torch.stack(iids_list) # 3,77
|
273 |
+
|
274 |
+
if weighted:
|
275 |
+
weights = weights.squeeze(0)
|
276 |
+
new_weights = torch.ones(input_ids.shape)
|
277 |
+
for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2):
|
278 |
+
b = i // (tokenizer.model_max_length - 2)
|
279 |
+
new_weights[b, 1 : 1 + tokenizer.model_max_length - 2] = weights[i : i + tokenizer.model_max_length - 2]
|
280 |
+
weights = new_weights
|
281 |
+
|
282 |
+
if weighted:
|
283 |
+
return input_ids, weights
|
284 |
+
return input_ids
|
285 |
+
|
286 |
+
|
287 |
+
class TextEncodingStrategy:
|
288 |
+
_strategy = None # strategy instance: actual strategy class
|
289 |
+
|
290 |
+
@classmethod
|
291 |
+
def set_strategy(cls, strategy):
|
292 |
+
if cls._strategy is not None:
|
293 |
+
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
|
294 |
+
cls._strategy = strategy
|
295 |
+
|
296 |
+
@classmethod
|
297 |
+
def get_strategy(cls) -> Optional["TextEncodingStrategy"]:
|
298 |
+
return cls._strategy
|
299 |
+
|
300 |
+
def encode_tokens(
|
301 |
+
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
|
302 |
+
) -> List[torch.Tensor]:
|
303 |
+
"""
|
304 |
+
Encode tokens into embeddings and outputs.
|
305 |
+
:param tokens: list of token tensors for each TextModel
|
306 |
+
:return: list of output embeddings for each architecture
|
307 |
+
"""
|
308 |
+
raise NotImplementedError
|
309 |
+
|
310 |
+
def encode_tokens_with_weights(
|
311 |
+
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor]
|
312 |
+
) -> List[torch.Tensor]:
|
313 |
+
"""
|
314 |
+
Encode tokens into embeddings and outputs.
|
315 |
+
:param tokens: list of token tensors for each TextModel
|
316 |
+
:param weights: list of weight tensors for each TextModel
|
317 |
+
:return: list of output embeddings for each architecture
|
318 |
+
"""
|
319 |
+
raise NotImplementedError
|
320 |
+
|
321 |
+
|
322 |
+
class TextEncoderOutputsCachingStrategy:
|
323 |
+
_strategy = None # strategy instance: actual strategy class
|
324 |
+
|
325 |
+
def __init__(
|
326 |
+
self,
|
327 |
+
cache_to_disk: bool,
|
328 |
+
batch_size: Optional[int],
|
329 |
+
skip_disk_cache_validity_check: bool,
|
330 |
+
is_partial: bool = False,
|
331 |
+
is_weighted: bool = False,
|
332 |
+
) -> None:
|
333 |
+
self._cache_to_disk = cache_to_disk
|
334 |
+
self._batch_size = batch_size
|
335 |
+
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
|
336 |
+
self._is_partial = is_partial
|
337 |
+
self._is_weighted = is_weighted
|
338 |
+
|
339 |
+
@classmethod
|
340 |
+
def set_strategy(cls, strategy):
|
341 |
+
if cls._strategy is not None:
|
342 |
+
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
|
343 |
+
cls._strategy = strategy
|
344 |
+
|
345 |
+
@classmethod
|
346 |
+
def get_strategy(cls) -> Optional["TextEncoderOutputsCachingStrategy"]:
|
347 |
+
return cls._strategy
|
348 |
+
|
349 |
+
@property
|
350 |
+
def cache_to_disk(self):
|
351 |
+
return self._cache_to_disk
|
352 |
+
|
353 |
+
@property
|
354 |
+
def batch_size(self):
|
355 |
+
return self._batch_size
|
356 |
+
|
357 |
+
@property
|
358 |
+
def is_partial(self):
|
359 |
+
return self._is_partial
|
360 |
+
|
361 |
+
@property
|
362 |
+
def is_weighted(self):
|
363 |
+
return self._is_weighted
|
364 |
+
|
365 |
+
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
366 |
+
raise NotImplementedError
|
367 |
+
|
368 |
+
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
369 |
+
raise NotImplementedError
|
370 |
+
|
371 |
+
def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
|
372 |
+
raise NotImplementedError
|
373 |
+
|
374 |
+
def cache_batch_outputs(
|
375 |
+
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, batch: List
|
376 |
+
):
|
377 |
+
raise NotImplementedError
|
378 |
+
|
379 |
+
|
380 |
+
class LatentsCachingStrategy:
|
381 |
+
# TODO commonize utillity functions to this class, such as npz handling etc.
|
382 |
+
|
383 |
+
_strategy = None # strategy instance: actual strategy class
|
384 |
+
|
385 |
+
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
386 |
+
self._cache_to_disk = cache_to_disk
|
387 |
+
self._batch_size = batch_size
|
388 |
+
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
|
389 |
+
|
390 |
+
@classmethod
|
391 |
+
def set_strategy(cls, strategy):
|
392 |
+
if cls._strategy is not None:
|
393 |
+
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
|
394 |
+
cls._strategy = strategy
|
395 |
+
|
396 |
+
@classmethod
|
397 |
+
def get_strategy(cls) -> Optional["LatentsCachingStrategy"]:
|
398 |
+
return cls._strategy
|
399 |
+
|
400 |
+
@property
|
401 |
+
def cache_to_disk(self):
|
402 |
+
return self._cache_to_disk
|
403 |
+
|
404 |
+
@property
|
405 |
+
def batch_size(self):
|
406 |
+
return self._batch_size
|
407 |
+
|
408 |
+
@property
|
409 |
+
def cache_suffix(self):
|
410 |
+
raise NotImplementedError
|
411 |
+
|
412 |
+
def get_image_size_from_disk_cache_path(self, absolute_path: str, npz_path: str) -> Tuple[Optional[int], Optional[int]]:
|
413 |
+
w, h = os.path.splitext(npz_path)[0].split("_")[-2].split("x")
|
414 |
+
return int(w), int(h)
|
415 |
+
|
416 |
+
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
|
417 |
+
raise NotImplementedError
|
418 |
+
|
419 |
+
def is_disk_cached_latents_expected(
|
420 |
+
self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
|
421 |
+
) -> bool:
|
422 |
+
raise NotImplementedError
|
423 |
+
|
424 |
+
def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
425 |
+
raise NotImplementedError
|
426 |
+
|
427 |
+
def _default_is_disk_cached_latents_expected(
|
428 |
+
self,
|
429 |
+
latents_stride: int,
|
430 |
+
bucket_reso: Tuple[int, int],
|
431 |
+
npz_path: str,
|
432 |
+
flip_aug: bool,
|
433 |
+
alpha_mask: bool,
|
434 |
+
multi_resolution: bool = False,
|
435 |
+
):
|
436 |
+
if not self.cache_to_disk:
|
437 |
+
return False
|
438 |
+
if not os.path.exists(npz_path):
|
439 |
+
return False
|
440 |
+
if self.skip_disk_cache_validity_check:
|
441 |
+
return True
|
442 |
+
|
443 |
+
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
|
444 |
+
|
445 |
+
# e.g. "_32x64", HxW
|
446 |
+
key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else ""
|
447 |
+
|
448 |
+
try:
|
449 |
+
npz = np.load(npz_path)
|
450 |
+
if "latents" + key_reso_suffix not in npz:
|
451 |
+
return False
|
452 |
+
if flip_aug and "latents_flipped" + key_reso_suffix not in npz:
|
453 |
+
return False
|
454 |
+
if alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
|
455 |
+
return False
|
456 |
+
except Exception as e:
|
457 |
+
logger.error(f"Error loading file: {npz_path}")
|
458 |
+
raise e
|
459 |
+
|
460 |
+
return True
|
461 |
+
|
462 |
+
# TODO remove circular dependency for ImageInfo
|
463 |
+
def _default_cache_batch_latents(
|
464 |
+
self,
|
465 |
+
encode_by_vae,
|
466 |
+
vae_device,
|
467 |
+
vae_dtype,
|
468 |
+
image_infos: List,
|
469 |
+
flip_aug: bool,
|
470 |
+
alpha_mask: bool,
|
471 |
+
random_crop: bool,
|
472 |
+
multi_resolution: bool = False,
|
473 |
+
):
|
474 |
+
"""
|
475 |
+
Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common.
|
476 |
+
"""
|
477 |
+
from library import train_util # import here to avoid circular import
|
478 |
+
|
479 |
+
img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching(
|
480 |
+
image_infos, alpha_mask, random_crop
|
481 |
+
)
|
482 |
+
img_tensor = img_tensor.to(device=vae_device, dtype=vae_dtype)
|
483 |
+
|
484 |
+
with torch.no_grad():
|
485 |
+
latents_tensors = encode_by_vae(img_tensor).to("cpu")
|
486 |
+
if flip_aug:
|
487 |
+
img_tensor = torch.flip(img_tensor, dims=[3])
|
488 |
+
with torch.no_grad():
|
489 |
+
flipped_latents = encode_by_vae(img_tensor).to("cpu")
|
490 |
+
else:
|
491 |
+
flipped_latents = [None] * len(latents_tensors)
|
492 |
+
|
493 |
+
# for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks):
|
494 |
+
for i in range(len(image_infos)):
|
495 |
+
info = image_infos[i]
|
496 |
+
latents = latents_tensors[i]
|
497 |
+
flipped_latent = flipped_latents[i]
|
498 |
+
alpha_mask = alpha_masks[i]
|
499 |
+
original_size = original_sizes[i]
|
500 |
+
crop_ltrb = crop_ltrbs[i]
|
501 |
+
|
502 |
+
latents_size = latents.shape[1:3] # H, W
|
503 |
+
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" if multi_resolution else "" # e.g. "_32x64", HxW
|
504 |
+
|
505 |
+
if self.cache_to_disk:
|
506 |
+
self.save_latents_to_disk(
|
507 |
+
info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask, key_reso_suffix
|
508 |
+
)
|
509 |
+
else:
|
510 |
+
info.latents_original_size = original_size
|
511 |
+
info.latents_crop_ltrb = crop_ltrb
|
512 |
+
info.latents = latents
|
513 |
+
if flip_aug:
|
514 |
+
info.latents_flipped = flipped_latent
|
515 |
+
info.alpha_mask = alpha_mask
|
516 |
+
|
517 |
+
def load_latents_from_disk(
|
518 |
+
self, npz_path: str, bucket_reso: Tuple[int, int]
|
519 |
+
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
520 |
+
"""
|
521 |
+
for SD/SDXL
|
522 |
+
"""
|
523 |
+
return self._default_load_latents_from_disk(None, npz_path, bucket_reso)
|
524 |
+
|
525 |
+
def _default_load_latents_from_disk(
|
526 |
+
self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int]
|
527 |
+
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
528 |
+
if latents_stride is None:
|
529 |
+
key_reso_suffix = ""
|
530 |
+
else:
|
531 |
+
latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
|
532 |
+
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" # e.g. "_32x64", HxW
|
533 |
+
|
534 |
+
npz = np.load(npz_path)
|
535 |
+
if "latents" + key_reso_suffix not in npz:
|
536 |
+
raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}")
|
537 |
+
|
538 |
+
latents = npz["latents" + key_reso_suffix]
|
539 |
+
original_size = npz["original_size" + key_reso_suffix].tolist()
|
540 |
+
crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist()
|
541 |
+
flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None
|
542 |
+
alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None
|
543 |
+
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
|
544 |
+
|
545 |
+
def save_latents_to_disk(
|
546 |
+
self,
|
547 |
+
npz_path,
|
548 |
+
latents_tensor,
|
549 |
+
original_size,
|
550 |
+
crop_ltrb,
|
551 |
+
flipped_latents_tensor=None,
|
552 |
+
alpha_mask=None,
|
553 |
+
key_reso_suffix="",
|
554 |
+
):
|
555 |
+
kwargs = {}
|
556 |
+
|
557 |
+
if os.path.exists(npz_path):
|
558 |
+
# load existing npz and update it
|
559 |
+
npz = np.load(npz_path)
|
560 |
+
for key in npz.files:
|
561 |
+
kwargs[key] = npz[key]
|
562 |
+
|
563 |
+
kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy()
|
564 |
+
kwargs["original_size" + key_reso_suffix] = np.array(original_size)
|
565 |
+
kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb)
|
566 |
+
if flipped_latents_tensor is not None:
|
567 |
+
kwargs["latents_flipped" + key_reso_suffix] = flipped_latents_tensor.float().cpu().numpy()
|
568 |
+
if alpha_mask is not None:
|
569 |
+
kwargs["alpha_mask" + key_reso_suffix] = alpha_mask.float().cpu().numpy()
|
570 |
+
np.savez(npz_path, **kwargs)
|
library/strategy_flux.py
ADDED
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import glob
|
3 |
+
from typing import Any, List, Optional, Tuple, Union
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
from transformers import CLIPTokenizer, T5TokenizerFast
|
7 |
+
|
8 |
+
from library import flux_utils, train_util
|
9 |
+
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
|
10 |
+
|
11 |
+
from library.utils import setup_logging
|
12 |
+
|
13 |
+
setup_logging()
|
14 |
+
import logging
|
15 |
+
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
|
19 |
+
CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
|
20 |
+
T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"
|
21 |
+
|
22 |
+
|
23 |
+
class FluxTokenizeStrategy(TokenizeStrategy):
|
24 |
+
def __init__(self, t5xxl_max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None:
|
25 |
+
self.t5xxl_max_length = t5xxl_max_length
|
26 |
+
self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
|
27 |
+
self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
|
28 |
+
|
29 |
+
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
|
30 |
+
text = [text] if isinstance(text, str) else text
|
31 |
+
|
32 |
+
l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
|
33 |
+
t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt")
|
34 |
+
|
35 |
+
t5_attn_mask = t5_tokens["attention_mask"]
|
36 |
+
l_tokens = l_tokens["input_ids"]
|
37 |
+
t5_tokens = t5_tokens["input_ids"]
|
38 |
+
|
39 |
+
return [l_tokens, t5_tokens, t5_attn_mask]
|
40 |
+
|
41 |
+
|
42 |
+
class FluxTextEncodingStrategy(TextEncodingStrategy):
|
43 |
+
def __init__(self, apply_t5_attn_mask: Optional[bool] = None) -> None:
|
44 |
+
"""
|
45 |
+
Args:
|
46 |
+
apply_t5_attn_mask: Default value for apply_t5_attn_mask.
|
47 |
+
"""
|
48 |
+
self.apply_t5_attn_mask = apply_t5_attn_mask
|
49 |
+
|
50 |
+
def encode_tokens(
|
51 |
+
self,
|
52 |
+
tokenize_strategy: TokenizeStrategy,
|
53 |
+
models: List[Any],
|
54 |
+
tokens: List[torch.Tensor],
|
55 |
+
apply_t5_attn_mask: Optional[bool] = None,
|
56 |
+
) -> List[torch.Tensor]:
|
57 |
+
# supports single model inference
|
58 |
+
|
59 |
+
if apply_t5_attn_mask is None:
|
60 |
+
apply_t5_attn_mask = self.apply_t5_attn_mask
|
61 |
+
|
62 |
+
clip_l, t5xxl = models if len(models) == 2 else (models[0], None)
|
63 |
+
l_tokens, t5_tokens = tokens[:2]
|
64 |
+
t5_attn_mask = tokens[2] if len(tokens) > 2 else None
|
65 |
+
|
66 |
+
# clip_l is None when using T5 only
|
67 |
+
if clip_l is not None and l_tokens is not None:
|
68 |
+
l_pooled = clip_l(l_tokens.to(clip_l.device))["pooler_output"]
|
69 |
+
else:
|
70 |
+
l_pooled = None
|
71 |
+
|
72 |
+
# t5xxl is None when using CLIP only
|
73 |
+
if t5xxl is not None and t5_tokens is not None:
|
74 |
+
# t5_out is [b, max length, 4096]
|
75 |
+
attention_mask = None if not apply_t5_attn_mask else t5_attn_mask.to(t5xxl.device)
|
76 |
+
t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), attention_mask, return_dict=False, output_hidden_states=True)
|
77 |
+
# if zero_pad_t5_output:
|
78 |
+
# t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1)
|
79 |
+
txt_ids = torch.zeros(t5_out.shape[0], t5_out.shape[1], 3, device=t5_out.device)
|
80 |
+
else:
|
81 |
+
t5_out = None
|
82 |
+
txt_ids = None
|
83 |
+
t5_attn_mask = None # caption may be dropped/shuffled, so t5_attn_mask should not be used to make sure the mask is same as the cached one
|
84 |
+
|
85 |
+
return [l_pooled, t5_out, txt_ids, t5_attn_mask] # returns t5_attn_mask for attention mask in transformer
|
86 |
+
|
87 |
+
|
88 |
+
class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
89 |
+
FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_flux_te.npz"
|
90 |
+
|
91 |
+
def __init__(
|
92 |
+
self,
|
93 |
+
cache_to_disk: bool,
|
94 |
+
batch_size: int,
|
95 |
+
skip_disk_cache_validity_check: bool,
|
96 |
+
is_partial: bool = False,
|
97 |
+
apply_t5_attn_mask: bool = False,
|
98 |
+
) -> None:
|
99 |
+
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
|
100 |
+
self.apply_t5_attn_mask = apply_t5_attn_mask
|
101 |
+
|
102 |
+
self.warn_fp8_weights = False
|
103 |
+
|
104 |
+
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
105 |
+
return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
|
106 |
+
|
107 |
+
def is_disk_cached_outputs_expected(self, npz_path: str):
|
108 |
+
if not self.cache_to_disk:
|
109 |
+
return False
|
110 |
+
if not os.path.exists(npz_path):
|
111 |
+
return False
|
112 |
+
if self.skip_disk_cache_validity_check:
|
113 |
+
return True
|
114 |
+
|
115 |
+
try:
|
116 |
+
npz = np.load(npz_path)
|
117 |
+
if "l_pooled" not in npz:
|
118 |
+
return False
|
119 |
+
if "t5_out" not in npz:
|
120 |
+
return False
|
121 |
+
if "txt_ids" not in npz:
|
122 |
+
return False
|
123 |
+
if "t5_attn_mask" not in npz:
|
124 |
+
return False
|
125 |
+
if "apply_t5_attn_mask" not in npz:
|
126 |
+
return False
|
127 |
+
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
|
128 |
+
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
|
129 |
+
return False
|
130 |
+
except Exception as e:
|
131 |
+
logger.error(f"Error loading file: {npz_path}")
|
132 |
+
raise e
|
133 |
+
|
134 |
+
return True
|
135 |
+
|
136 |
+
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
137 |
+
data = np.load(npz_path)
|
138 |
+
l_pooled = data["l_pooled"]
|
139 |
+
t5_out = data["t5_out"]
|
140 |
+
txt_ids = data["txt_ids"]
|
141 |
+
t5_attn_mask = data["t5_attn_mask"]
|
142 |
+
# apply_t5_attn_mask should be same as self.apply_t5_attn_mask
|
143 |
+
return [l_pooled, t5_out, txt_ids, t5_attn_mask]
|
144 |
+
|
145 |
+
def cache_batch_outputs(
|
146 |
+
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
|
147 |
+
):
|
148 |
+
if not self.warn_fp8_weights:
|
149 |
+
if flux_utils.get_t5xxl_actual_dtype(models[1]) == torch.float8_e4m3fn:
|
150 |
+
logger.warning(
|
151 |
+
"T5 model is using fp8 weights for caching. This may affect the quality of the cached outputs."
|
152 |
+
" / T5モデルはfp8の重みを使用しています。これはキャッシュの品質に影響を与える可能性があります。"
|
153 |
+
)
|
154 |
+
self.warn_fp8_weights = True
|
155 |
+
|
156 |
+
flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy
|
157 |
+
captions = [info.caption for info in infos]
|
158 |
+
|
159 |
+
tokens_and_masks = tokenize_strategy.tokenize(captions)
|
160 |
+
with torch.no_grad():
|
161 |
+
# attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True
|
162 |
+
l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks)
|
163 |
+
|
164 |
+
if l_pooled.dtype == torch.bfloat16:
|
165 |
+
l_pooled = l_pooled.float()
|
166 |
+
if t5_out.dtype == torch.bfloat16:
|
167 |
+
t5_out = t5_out.float()
|
168 |
+
if txt_ids.dtype == torch.bfloat16:
|
169 |
+
txt_ids = txt_ids.float()
|
170 |
+
|
171 |
+
l_pooled = l_pooled.cpu().numpy()
|
172 |
+
t5_out = t5_out.cpu().numpy()
|
173 |
+
txt_ids = txt_ids.cpu().numpy()
|
174 |
+
t5_attn_mask = tokens_and_masks[2].cpu().numpy()
|
175 |
+
|
176 |
+
for i, info in enumerate(infos):
|
177 |
+
l_pooled_i = l_pooled[i]
|
178 |
+
t5_out_i = t5_out[i]
|
179 |
+
txt_ids_i = txt_ids[i]
|
180 |
+
t5_attn_mask_i = t5_attn_mask[i]
|
181 |
+
apply_t5_attn_mask_i = self.apply_t5_attn_mask
|
182 |
+
|
183 |
+
if self.cache_to_disk:
|
184 |
+
np.savez(
|
185 |
+
info.text_encoder_outputs_npz,
|
186 |
+
l_pooled=l_pooled_i,
|
187 |
+
t5_out=t5_out_i,
|
188 |
+
txt_ids=txt_ids_i,
|
189 |
+
t5_attn_mask=t5_attn_mask_i,
|
190 |
+
apply_t5_attn_mask=apply_t5_attn_mask_i,
|
191 |
+
)
|
192 |
+
else:
|
193 |
+
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
|
194 |
+
info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)
|
195 |
+
|
196 |
+
|
197 |
+
class FluxLatentsCachingStrategy(LatentsCachingStrategy):
|
198 |
+
FLUX_LATENTS_NPZ_SUFFIX = "_flux.npz"
|
199 |
+
|
200 |
+
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
201 |
+
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
|
202 |
+
|
203 |
+
@property
|
204 |
+
def cache_suffix(self) -> str:
|
205 |
+
return FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
|
206 |
+
|
207 |
+
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
|
208 |
+
return (
|
209 |
+
os.path.splitext(absolute_path)[0]
|
210 |
+
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
|
211 |
+
+ FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
|
212 |
+
)
|
213 |
+
|
214 |
+
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
215 |
+
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
|
216 |
+
|
217 |
+
def load_latents_from_disk(
|
218 |
+
self, npz_path: str, bucket_reso: Tuple[int, int]
|
219 |
+
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
220 |
+
return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution
|
221 |
+
|
222 |
+
# TODO remove circular dependency for ImageInfo
|
223 |
+
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
224 |
+
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
|
225 |
+
vae_device = vae.device
|
226 |
+
vae_dtype = vae.dtype
|
227 |
+
|
228 |
+
self._default_cache_batch_latents(
|
229 |
+
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
|
230 |
+
)
|
231 |
+
|
232 |
+
if not train_util.HIGH_VRAM:
|
233 |
+
train_util.clean_memory_on_device(vae.device)
|
234 |
+
|
235 |
+
|
236 |
+
if __name__ == "__main__":
|
237 |
+
# test code for FluxTokenizeStrategy
|
238 |
+
# tokenizer = sd3_models.SD3Tokenizer()
|
239 |
+
strategy = FluxTokenizeStrategy(256)
|
240 |
+
text = "hello world"
|
241 |
+
|
242 |
+
l_tokens, g_tokens, t5_tokens = strategy.tokenize(text)
|
243 |
+
# print(l_tokens.shape)
|
244 |
+
print(l_tokens)
|
245 |
+
print(g_tokens)
|
246 |
+
print(t5_tokens)
|
247 |
+
|
248 |
+
texts = ["hello world", "the quick brown fox jumps over the lazy dog"]
|
249 |
+
l_tokens_2 = strategy.clip_l(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
|
250 |
+
g_tokens_2 = strategy.clip_g(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
|
251 |
+
t5_tokens_2 = strategy.t5xxl(
|
252 |
+
texts, max_length=strategy.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt"
|
253 |
+
)
|
254 |
+
print(l_tokens_2)
|
255 |
+
print(g_tokens_2)
|
256 |
+
print(t5_tokens_2)
|
257 |
+
|
258 |
+
# compare
|
259 |
+
print(torch.allclose(l_tokens, l_tokens_2["input_ids"][0]))
|
260 |
+
print(torch.allclose(g_tokens, g_tokens_2["input_ids"][0]))
|
261 |
+
print(torch.allclose(t5_tokens, t5_tokens_2["input_ids"][0]))
|
262 |
+
|
263 |
+
text = ",".join(["hello world! this is long text"] * 50)
|
264 |
+
l_tokens, g_tokens, t5_tokens = strategy.tokenize(text)
|
265 |
+
print(l_tokens)
|
266 |
+
print(g_tokens)
|
267 |
+
print(t5_tokens)
|
268 |
+
|
269 |
+
print(f"model max length l: {strategy.clip_l.model_max_length}")
|
270 |
+
print(f"model max length g: {strategy.clip_g.model_max_length}")
|
271 |
+
print(f"model max length t5: {strategy.t5xxl.model_max_length}")
|
library/strategy_sd.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import os
|
3 |
+
from typing import Any, List, Optional, Tuple, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from transformers import CLIPTokenizer
|
7 |
+
from library import train_util
|
8 |
+
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy
|
9 |
+
from library.utils import setup_logging
|
10 |
+
|
11 |
+
setup_logging()
|
12 |
+
import logging
|
13 |
+
|
14 |
+
logger = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
|
17 |
+
TOKENIZER_ID = "openai/clip-vit-large-patch14"
|
18 |
+
V2_STABLE_DIFFUSION_ID = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ
|
19 |
+
|
20 |
+
|
21 |
+
class SdTokenizeStrategy(TokenizeStrategy):
|
22 |
+
def __init__(self, v2: bool, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None:
|
23 |
+
"""
|
24 |
+
max_length does not include <BOS> and <EOS> (None, 75, 150, 225)
|
25 |
+
"""
|
26 |
+
logger.info(f"Using {'v2' if v2 else 'v1'} tokenizer")
|
27 |
+
if v2:
|
28 |
+
self.tokenizer = self._load_tokenizer(
|
29 |
+
CLIPTokenizer, V2_STABLE_DIFFUSION_ID, subfolder="tokenizer", tokenizer_cache_dir=tokenizer_cache_dir
|
30 |
+
)
|
31 |
+
else:
|
32 |
+
self.tokenizer = self._load_tokenizer(CLIPTokenizer, TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
|
33 |
+
|
34 |
+
if max_length is None:
|
35 |
+
self.max_length = self.tokenizer.model_max_length
|
36 |
+
else:
|
37 |
+
self.max_length = max_length + 2
|
38 |
+
|
39 |
+
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
|
40 |
+
text = [text] if isinstance(text, str) else text
|
41 |
+
return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)]
|
42 |
+
|
43 |
+
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]:
|
44 |
+
text = [text] if isinstance(text, str) else text
|
45 |
+
tokens_list = []
|
46 |
+
weights_list = []
|
47 |
+
for t in text:
|
48 |
+
tokens, weights = self._get_input_ids(self.tokenizer, t, self.max_length, weighted=True)
|
49 |
+
tokens_list.append(tokens)
|
50 |
+
weights_list.append(weights)
|
51 |
+
return [torch.stack(tokens_list, dim=0)], [torch.stack(weights_list, dim=0)]
|
52 |
+
|
53 |
+
|
54 |
+
class SdTextEncodingStrategy(TextEncodingStrategy):
|
55 |
+
def __init__(self, clip_skip: Optional[int] = None) -> None:
|
56 |
+
self.clip_skip = clip_skip
|
57 |
+
|
58 |
+
def encode_tokens(
|
59 |
+
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
|
60 |
+
) -> List[torch.Tensor]:
|
61 |
+
text_encoder = models[0]
|
62 |
+
tokens = tokens[0]
|
63 |
+
sd_tokenize_strategy = tokenize_strategy # type: SdTokenizeStrategy
|
64 |
+
|
65 |
+
# tokens: b,n,77
|
66 |
+
b_size = tokens.size()[0]
|
67 |
+
max_token_length = tokens.size()[1] * tokens.size()[2]
|
68 |
+
model_max_length = sd_tokenize_strategy.tokenizer.model_max_length
|
69 |
+
tokens = tokens.reshape((-1, model_max_length)) # batch_size*3, 77
|
70 |
+
|
71 |
+
tokens = tokens.to(text_encoder.device)
|
72 |
+
|
73 |
+
if self.clip_skip is None:
|
74 |
+
encoder_hidden_states = text_encoder(tokens)[0]
|
75 |
+
else:
|
76 |
+
enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True)
|
77 |
+
encoder_hidden_states = enc_out["hidden_states"][-self.clip_skip]
|
78 |
+
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
|
79 |
+
|
80 |
+
# bs*3, 77, 768 or 1024
|
81 |
+
encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
|
82 |
+
|
83 |
+
if max_token_length != model_max_length:
|
84 |
+
v1 = sd_tokenize_strategy.tokenizer.pad_token_id == sd_tokenize_strategy.tokenizer.eos_token_id
|
85 |
+
if not v1:
|
86 |
+
# v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す 正直この実装でいいのかわからん
|
87 |
+
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
|
88 |
+
for i in range(1, max_token_length, model_max_length):
|
89 |
+
chunk = encoder_hidden_states[:, i : i + model_max_length - 2] # <BOS> の後から 最後の前まで
|
90 |
+
if i > 0:
|
91 |
+
for j in range(len(chunk)):
|
92 |
+
if tokens[j, 1] == sd_tokenize_strategy.tokenizer.eos_token:
|
93 |
+
# 空、つまり <BOS> <EOS> <PAD> ...のパターン
|
94 |
+
chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
|
95 |
+
states_list.append(chunk) # <BOS> の後から <EOS> の前まで
|
96 |
+
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
|
97 |
+
encoder_hidden_states = torch.cat(states_list, dim=1)
|
98 |
+
else:
|
99 |
+
# v1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
|
100 |
+
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
|
101 |
+
for i in range(1, max_token_length, model_max_length):
|
102 |
+
states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2]) # <BOS> の後から <EOS> の前まで
|
103 |
+
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
|
104 |
+
encoder_hidden_states = torch.cat(states_list, dim=1)
|
105 |
+
|
106 |
+
return [encoder_hidden_states]
|
107 |
+
|
108 |
+
def encode_tokens_with_weights(
|
109 |
+
self,
|
110 |
+
tokenize_strategy: TokenizeStrategy,
|
111 |
+
models: List[Any],
|
112 |
+
tokens_list: List[torch.Tensor],
|
113 |
+
weights_list: List[torch.Tensor],
|
114 |
+
) -> List[torch.Tensor]:
|
115 |
+
encoder_hidden_states = self.encode_tokens(tokenize_strategy, models, tokens_list)[0]
|
116 |
+
|
117 |
+
weights = weights_list[0].to(encoder_hidden_states.device)
|
118 |
+
|
119 |
+
# apply weights
|
120 |
+
if weights.shape[1] == 1: # no max_token_length
|
121 |
+
# weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768)
|
122 |
+
encoder_hidden_states = encoder_hidden_states * weights.squeeze(1).unsqueeze(2)
|
123 |
+
else:
|
124 |
+
# weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768)
|
125 |
+
for i in range(weights.shape[1]):
|
126 |
+
encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] = encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] * weights[
|
127 |
+
:, i, 1:-1
|
128 |
+
].unsqueeze(-1)
|
129 |
+
|
130 |
+
return [encoder_hidden_states]
|
131 |
+
|
132 |
+
|
133 |
+
class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
|
134 |
+
# sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix.
|
135 |
+
# and we keep the old npz for the backward compatibility.
|
136 |
+
|
137 |
+
SD_OLD_LATENTS_NPZ_SUFFIX = ".npz"
|
138 |
+
SD_LATENTS_NPZ_SUFFIX = "_sd.npz"
|
139 |
+
SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz"
|
140 |
+
|
141 |
+
def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
142 |
+
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
|
143 |
+
self.sd = sd
|
144 |
+
self.suffix = (
|
145 |
+
SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX
|
146 |
+
)
|
147 |
+
|
148 |
+
@property
|
149 |
+
def cache_suffix(self) -> str:
|
150 |
+
return self.suffix
|
151 |
+
|
152 |
+
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
|
153 |
+
# support old .npz
|
154 |
+
old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX
|
155 |
+
if os.path.exists(old_npz_file):
|
156 |
+
return old_npz_file
|
157 |
+
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix
|
158 |
+
|
159 |
+
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
160 |
+
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)
|
161 |
+
|
162 |
+
# TODO remove circular dependency for ImageInfo
|
163 |
+
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
164 |
+
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).latent_dist.sample()
|
165 |
+
vae_device = vae.device
|
166 |
+
vae_dtype = vae.dtype
|
167 |
+
|
168 |
+
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
|
169 |
+
|
170 |
+
if not train_util.HIGH_VRAM:
|
171 |
+
train_util.clean_memory_on_device(vae.device)
|
library/strategy_sd3.py
ADDED
@@ -0,0 +1,420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import glob
|
3 |
+
import random
|
4 |
+
from typing import Any, List, Optional, Tuple, Union
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
from transformers import CLIPTokenizer, T5TokenizerFast, CLIPTextModel, CLIPTextModelWithProjection, T5EncoderModel
|
8 |
+
|
9 |
+
from library import sd3_utils, train_util
|
10 |
+
from library import sd3_models
|
11 |
+
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
|
12 |
+
|
13 |
+
from library.utils import setup_logging
|
14 |
+
|
15 |
+
setup_logging()
|
16 |
+
import logging
|
17 |
+
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
|
21 |
+
CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
|
22 |
+
CLIP_G_TOKENIZER_ID = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
23 |
+
T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"
|
24 |
+
|
25 |
+
|
26 |
+
class Sd3TokenizeStrategy(TokenizeStrategy):
|
27 |
+
def __init__(self, t5xxl_max_length: int = 256, tokenizer_cache_dir: Optional[str] = None) -> None:
|
28 |
+
self.t5xxl_max_length = t5xxl_max_length
|
29 |
+
self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
|
30 |
+
self.clip_g = self._load_tokenizer(CLIPTokenizer, CLIP_G_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
|
31 |
+
self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
|
32 |
+
self.clip_g.pad_token_id = 0 # use 0 as pad token for clip_g
|
33 |
+
|
34 |
+
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
|
35 |
+
text = [text] if isinstance(text, str) else text
|
36 |
+
|
37 |
+
l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
|
38 |
+
g_tokens = self.clip_g(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
|
39 |
+
t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt")
|
40 |
+
|
41 |
+
l_attn_mask = l_tokens["attention_mask"]
|
42 |
+
g_attn_mask = g_tokens["attention_mask"]
|
43 |
+
t5_attn_mask = t5_tokens["attention_mask"]
|
44 |
+
l_tokens = l_tokens["input_ids"]
|
45 |
+
g_tokens = g_tokens["input_ids"]
|
46 |
+
t5_tokens = t5_tokens["input_ids"]
|
47 |
+
|
48 |
+
return [l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask]
|
49 |
+
|
50 |
+
|
51 |
+
class Sd3TextEncodingStrategy(TextEncodingStrategy):
|
52 |
+
def __init__(
|
53 |
+
self,
|
54 |
+
apply_lg_attn_mask: Optional[bool] = None,
|
55 |
+
apply_t5_attn_mask: Optional[bool] = None,
|
56 |
+
l_dropout_rate: float = 0.0,
|
57 |
+
g_dropout_rate: float = 0.0,
|
58 |
+
t5_dropout_rate: float = 0.0,
|
59 |
+
) -> None:
|
60 |
+
"""
|
61 |
+
Args:
|
62 |
+
apply_t5_attn_mask: Default value for apply_t5_attn_mask.
|
63 |
+
"""
|
64 |
+
self.apply_lg_attn_mask = apply_lg_attn_mask
|
65 |
+
self.apply_t5_attn_mask = apply_t5_attn_mask
|
66 |
+
self.l_dropout_rate = l_dropout_rate
|
67 |
+
self.g_dropout_rate = g_dropout_rate
|
68 |
+
self.t5_dropout_rate = t5_dropout_rate
|
69 |
+
|
70 |
+
def encode_tokens(
|
71 |
+
self,
|
72 |
+
tokenize_strategy: TokenizeStrategy,
|
73 |
+
models: List[Any],
|
74 |
+
tokens: List[torch.Tensor],
|
75 |
+
apply_lg_attn_mask: Optional[bool] = False,
|
76 |
+
apply_t5_attn_mask: Optional[bool] = False,
|
77 |
+
enable_dropout: bool = True,
|
78 |
+
) -> List[torch.Tensor]:
|
79 |
+
"""
|
80 |
+
returned embeddings are not masked
|
81 |
+
"""
|
82 |
+
clip_l, clip_g, t5xxl = models
|
83 |
+
clip_l: Optional[CLIPTextModel]
|
84 |
+
clip_g: Optional[CLIPTextModelWithProjection]
|
85 |
+
t5xxl: Optional[T5EncoderModel]
|
86 |
+
|
87 |
+
if apply_lg_attn_mask is None:
|
88 |
+
apply_lg_attn_mask = self.apply_lg_attn_mask
|
89 |
+
if apply_t5_attn_mask is None:
|
90 |
+
apply_t5_attn_mask = self.apply_t5_attn_mask
|
91 |
+
|
92 |
+
l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask = tokens
|
93 |
+
|
94 |
+
# dropout: if enable_dropout is False, dropout is not applied. dropout means zeroing out embeddings
|
95 |
+
|
96 |
+
if l_tokens is None or clip_l is None:
|
97 |
+
assert g_tokens is None, "g_tokens must be None if l_tokens is None"
|
98 |
+
lg_out = None
|
99 |
+
lg_pooled = None
|
100 |
+
l_attn_mask = None
|
101 |
+
g_attn_mask = None
|
102 |
+
else:
|
103 |
+
assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
|
104 |
+
|
105 |
+
# drop some members of the batch: we do not call clip_l and clip_g for dropped members
|
106 |
+
batch_size, l_seq_len = l_tokens.shape
|
107 |
+
g_seq_len = g_tokens.shape[1]
|
108 |
+
|
109 |
+
non_drop_l_indices = []
|
110 |
+
non_drop_g_indices = []
|
111 |
+
for i in range(l_tokens.shape[0]):
|
112 |
+
drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate)
|
113 |
+
drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate)
|
114 |
+
if not drop_l:
|
115 |
+
non_drop_l_indices.append(i)
|
116 |
+
if not drop_g:
|
117 |
+
non_drop_g_indices.append(i)
|
118 |
+
|
119 |
+
# filter out dropped members
|
120 |
+
if len(non_drop_l_indices) > 0 and len(non_drop_l_indices) < batch_size:
|
121 |
+
l_tokens = l_tokens[non_drop_l_indices]
|
122 |
+
l_attn_mask = l_attn_mask[non_drop_l_indices]
|
123 |
+
if len(non_drop_g_indices) > 0 and len(non_drop_g_indices) < batch_size:
|
124 |
+
g_tokens = g_tokens[non_drop_g_indices]
|
125 |
+
g_attn_mask = g_attn_mask[non_drop_g_indices]
|
126 |
+
|
127 |
+
# call clip_l for non-dropped members
|
128 |
+
if len(non_drop_l_indices) > 0:
|
129 |
+
nd_l_attn_mask = l_attn_mask.to(clip_l.device)
|
130 |
+
prompt_embeds = clip_l(
|
131 |
+
l_tokens.to(clip_l.device), nd_l_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True
|
132 |
+
)
|
133 |
+
nd_l_pooled = prompt_embeds[0]
|
134 |
+
nd_l_out = prompt_embeds.hidden_states[-2]
|
135 |
+
if len(non_drop_g_indices) > 0:
|
136 |
+
nd_g_attn_mask = g_attn_mask.to(clip_g.device)
|
137 |
+
prompt_embeds = clip_g(
|
138 |
+
g_tokens.to(clip_g.device), nd_g_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True
|
139 |
+
)
|
140 |
+
nd_g_pooled = prompt_embeds[0]
|
141 |
+
nd_g_out = prompt_embeds.hidden_states[-2]
|
142 |
+
|
143 |
+
# fill in the dropped members
|
144 |
+
if len(non_drop_l_indices) == batch_size:
|
145 |
+
l_pooled = nd_l_pooled
|
146 |
+
l_out = nd_l_out
|
147 |
+
else:
|
148 |
+
# model output is always float32 because of the models are wrapped with Accelerator
|
149 |
+
l_pooled = torch.zeros((batch_size, 768), device=clip_l.device, dtype=torch.float32)
|
150 |
+
l_out = torch.zeros((batch_size, l_seq_len, 768), device=clip_l.device, dtype=torch.float32)
|
151 |
+
l_attn_mask = torch.zeros((batch_size, l_seq_len), device=clip_l.device, dtype=l_attn_mask.dtype)
|
152 |
+
if len(non_drop_l_indices) > 0:
|
153 |
+
l_pooled[non_drop_l_indices] = nd_l_pooled
|
154 |
+
l_out[non_drop_l_indices] = nd_l_out
|
155 |
+
l_attn_mask[non_drop_l_indices] = nd_l_attn_mask
|
156 |
+
|
157 |
+
if len(non_drop_g_indices) == batch_size:
|
158 |
+
g_pooled = nd_g_pooled
|
159 |
+
g_out = nd_g_out
|
160 |
+
else:
|
161 |
+
g_pooled = torch.zeros((batch_size, 1280), device=clip_g.device, dtype=torch.float32)
|
162 |
+
g_out = torch.zeros((batch_size, g_seq_len, 1280), device=clip_g.device, dtype=torch.float32)
|
163 |
+
g_attn_mask = torch.zeros((batch_size, g_seq_len), device=clip_g.device, dtype=g_attn_mask.dtype)
|
164 |
+
if len(non_drop_g_indices) > 0:
|
165 |
+
g_pooled[non_drop_g_indices] = nd_g_pooled
|
166 |
+
g_out[non_drop_g_indices] = nd_g_out
|
167 |
+
g_attn_mask[non_drop_g_indices] = nd_g_attn_mask
|
168 |
+
|
169 |
+
lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1)
|
170 |
+
lg_out = torch.cat([l_out, g_out], dim=-1)
|
171 |
+
|
172 |
+
if t5xxl is None or t5_tokens is None:
|
173 |
+
t5_out = None
|
174 |
+
t5_attn_mask = None
|
175 |
+
else:
|
176 |
+
# drop some members of the batch: we do not call t5xxl for dropped members
|
177 |
+
batch_size, t5_seq_len = t5_tokens.shape
|
178 |
+
non_drop_t5_indices = []
|
179 |
+
for i in range(t5_tokens.shape[0]):
|
180 |
+
drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate)
|
181 |
+
if not drop_t5:
|
182 |
+
non_drop_t5_indices.append(i)
|
183 |
+
|
184 |
+
# filter out dropped members
|
185 |
+
if len(non_drop_t5_indices) > 0 and len(non_drop_t5_indices) < batch_size:
|
186 |
+
t5_tokens = t5_tokens[non_drop_t5_indices]
|
187 |
+
t5_attn_mask = t5_attn_mask[non_drop_t5_indices]
|
188 |
+
|
189 |
+
# call t5xxl for non-dropped members
|
190 |
+
if len(non_drop_t5_indices) > 0:
|
191 |
+
nd_t5_attn_mask = t5_attn_mask.to(t5xxl.device)
|
192 |
+
nd_t5_out, _ = t5xxl(
|
193 |
+
t5_tokens.to(t5xxl.device),
|
194 |
+
nd_t5_attn_mask if apply_t5_attn_mask else None,
|
195 |
+
return_dict=False,
|
196 |
+
output_hidden_states=True,
|
197 |
+
)
|
198 |
+
|
199 |
+
# fill in the dropped members
|
200 |
+
if len(non_drop_t5_indices) == batch_size:
|
201 |
+
t5_out = nd_t5_out
|
202 |
+
else:
|
203 |
+
t5_out = torch.zeros((batch_size, t5_seq_len, 4096), device=t5xxl.device, dtype=torch.float32)
|
204 |
+
t5_attn_mask = torch.zeros((batch_size, t5_seq_len), device=t5xxl.device, dtype=t5_attn_mask.dtype)
|
205 |
+
if len(non_drop_t5_indices) > 0:
|
206 |
+
t5_out[non_drop_t5_indices] = nd_t5_out
|
207 |
+
t5_attn_mask[non_drop_t5_indices] = nd_t5_attn_mask
|
208 |
+
|
209 |
+
# masks are used for attention masking in transformer
|
210 |
+
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
|
211 |
+
|
212 |
+
def drop_cached_text_encoder_outputs(
|
213 |
+
self,
|
214 |
+
lg_out: torch.Tensor,
|
215 |
+
t5_out: torch.Tensor,
|
216 |
+
lg_pooled: torch.Tensor,
|
217 |
+
l_attn_mask: torch.Tensor,
|
218 |
+
g_attn_mask: torch.Tensor,
|
219 |
+
t5_attn_mask: torch.Tensor,
|
220 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
221 |
+
# dropout: if enable_dropout is True, dropout is not applied. dropout means zeroing out embeddings
|
222 |
+
if lg_out is not None:
|
223 |
+
for i in range(lg_out.shape[0]):
|
224 |
+
drop_l = self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate
|
225 |
+
if drop_l:
|
226 |
+
lg_out[i, :, :768] = torch.zeros_like(lg_out[i, :, :768])
|
227 |
+
lg_pooled[i, :768] = torch.zeros_like(lg_pooled[i, :768])
|
228 |
+
if l_attn_mask is not None:
|
229 |
+
l_attn_mask[i] = torch.zeros_like(l_attn_mask[i])
|
230 |
+
drop_g = self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate
|
231 |
+
if drop_g:
|
232 |
+
lg_out[i, :, 768:] = torch.zeros_like(lg_out[i, :, 768:])
|
233 |
+
lg_pooled[i, 768:] = torch.zeros_like(lg_pooled[i, 768:])
|
234 |
+
if g_attn_mask is not None:
|
235 |
+
g_attn_mask[i] = torch.zeros_like(g_attn_mask[i])
|
236 |
+
|
237 |
+
if t5_out is not None:
|
238 |
+
for i in range(t5_out.shape[0]):
|
239 |
+
drop_t5 = self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate
|
240 |
+
if drop_t5:
|
241 |
+
t5_out[i] = torch.zeros_like(t5_out[i])
|
242 |
+
if t5_attn_mask is not None:
|
243 |
+
t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])
|
244 |
+
|
245 |
+
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
|
246 |
+
|
247 |
+
def concat_encodings(
|
248 |
+
self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor
|
249 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
250 |
+
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
|
251 |
+
if t5_out is None:
|
252 |
+
t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device, dtype=lg_out.dtype)
|
253 |
+
return torch.cat([lg_out, t5_out], dim=-2), lg_pooled
|
254 |
+
|
255 |
+
|
256 |
+
class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
257 |
+
SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz"
|
258 |
+
|
259 |
+
def __init__(
|
260 |
+
self,
|
261 |
+
cache_to_disk: bool,
|
262 |
+
batch_size: int,
|
263 |
+
skip_disk_cache_validity_check: bool,
|
264 |
+
is_partial: bool = False,
|
265 |
+
apply_lg_attn_mask: bool = False,
|
266 |
+
apply_t5_attn_mask: bool = False,
|
267 |
+
) -> None:
|
268 |
+
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
|
269 |
+
self.apply_lg_attn_mask = apply_lg_attn_mask
|
270 |
+
self.apply_t5_attn_mask = apply_t5_attn_mask
|
271 |
+
|
272 |
+
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
273 |
+
return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
|
274 |
+
|
275 |
+
def is_disk_cached_outputs_expected(self, npz_path: str):
|
276 |
+
if not self.cache_to_disk:
|
277 |
+
return False
|
278 |
+
if not os.path.exists(npz_path):
|
279 |
+
return False
|
280 |
+
if self.skip_disk_cache_validity_check:
|
281 |
+
return True
|
282 |
+
|
283 |
+
try:
|
284 |
+
npz = np.load(npz_path)
|
285 |
+
if "lg_out" not in npz:
|
286 |
+
return False
|
287 |
+
if "lg_pooled" not in npz:
|
288 |
+
return False
|
289 |
+
if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used
|
290 |
+
return False
|
291 |
+
if "apply_lg_attn_mask" not in npz:
|
292 |
+
return False
|
293 |
+
if "t5_out" not in npz:
|
294 |
+
return False
|
295 |
+
if "t5_attn_mask" not in npz:
|
296 |
+
return False
|
297 |
+
npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"]
|
298 |
+
if npz_apply_lg_attn_mask != self.apply_lg_attn_mask:
|
299 |
+
return False
|
300 |
+
if "apply_t5_attn_mask" not in npz:
|
301 |
+
return False
|
302 |
+
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
|
303 |
+
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
|
304 |
+
return False
|
305 |
+
except Exception as e:
|
306 |
+
logger.error(f"Error loading file: {npz_path}")
|
307 |
+
raise e
|
308 |
+
|
309 |
+
return True
|
310 |
+
|
311 |
+
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
312 |
+
data = np.load(npz_path)
|
313 |
+
lg_out = data["lg_out"]
|
314 |
+
lg_pooled = data["lg_pooled"]
|
315 |
+
t5_out = data["t5_out"]
|
316 |
+
|
317 |
+
l_attn_mask = data["clip_l_attn_mask"]
|
318 |
+
g_attn_mask = data["clip_g_attn_mask"]
|
319 |
+
t5_attn_mask = data["t5_attn_mask"]
|
320 |
+
|
321 |
+
# apply_t5_attn_mask and apply_lg_attn_mask are same as self.apply_t5_attn_mask and self.apply_lg_attn_mask
|
322 |
+
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
|
323 |
+
|
324 |
+
def cache_batch_outputs(
|
325 |
+
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
|
326 |
+
):
|
327 |
+
sd3_text_encoding_strategy: Sd3TextEncodingStrategy = text_encoding_strategy
|
328 |
+
captions = [info.caption for info in infos]
|
329 |
+
|
330 |
+
tokens_and_masks = tokenize_strategy.tokenize(captions)
|
331 |
+
with torch.no_grad():
|
332 |
+
# always disable dropout during caching
|
333 |
+
lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = sd3_text_encoding_strategy.encode_tokens(
|
334 |
+
tokenize_strategy,
|
335 |
+
models,
|
336 |
+
tokens_and_masks,
|
337 |
+
apply_lg_attn_mask=self.apply_lg_attn_mask,
|
338 |
+
apply_t5_attn_mask=self.apply_t5_attn_mask,
|
339 |
+
enable_dropout=False,
|
340 |
+
)
|
341 |
+
|
342 |
+
if lg_out.dtype == torch.bfloat16:
|
343 |
+
lg_out = lg_out.float()
|
344 |
+
if lg_pooled.dtype == torch.bfloat16:
|
345 |
+
lg_pooled = lg_pooled.float()
|
346 |
+
if t5_out.dtype == torch.bfloat16:
|
347 |
+
t5_out = t5_out.float()
|
348 |
+
|
349 |
+
lg_out = lg_out.cpu().numpy()
|
350 |
+
lg_pooled = lg_pooled.cpu().numpy()
|
351 |
+
t5_out = t5_out.cpu().numpy()
|
352 |
+
|
353 |
+
l_attn_mask = tokens_and_masks[3].cpu().numpy()
|
354 |
+
g_attn_mask = tokens_and_masks[4].cpu().numpy()
|
355 |
+
t5_attn_mask = tokens_and_masks[5].cpu().numpy()
|
356 |
+
|
357 |
+
for i, info in enumerate(infos):
|
358 |
+
lg_out_i = lg_out[i]
|
359 |
+
t5_out_i = t5_out[i]
|
360 |
+
lg_pooled_i = lg_pooled[i]
|
361 |
+
l_attn_mask_i = l_attn_mask[i]
|
362 |
+
g_attn_mask_i = g_attn_mask[i]
|
363 |
+
t5_attn_mask_i = t5_attn_mask[i]
|
364 |
+
apply_lg_attn_mask = self.apply_lg_attn_mask
|
365 |
+
apply_t5_attn_mask = self.apply_t5_attn_mask
|
366 |
+
|
367 |
+
if self.cache_to_disk:
|
368 |
+
np.savez(
|
369 |
+
info.text_encoder_outputs_npz,
|
370 |
+
lg_out=lg_out_i,
|
371 |
+
lg_pooled=lg_pooled_i,
|
372 |
+
t5_out=t5_out_i,
|
373 |
+
clip_l_attn_mask=l_attn_mask_i,
|
374 |
+
clip_g_attn_mask=g_attn_mask_i,
|
375 |
+
t5_attn_mask=t5_attn_mask_i,
|
376 |
+
apply_lg_attn_mask=apply_lg_attn_mask,
|
377 |
+
apply_t5_attn_mask=apply_t5_attn_mask,
|
378 |
+
)
|
379 |
+
else:
|
380 |
+
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
|
381 |
+
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i)
|
382 |
+
|
383 |
+
|
384 |
+
class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
|
385 |
+
SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"
|
386 |
+
|
387 |
+
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
388 |
+
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
|
389 |
+
|
390 |
+
@property
|
391 |
+
def cache_suffix(self) -> str:
|
392 |
+
return Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
|
393 |
+
|
394 |
+
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
|
395 |
+
return (
|
396 |
+
os.path.splitext(absolute_path)[0]
|
397 |
+
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
|
398 |
+
+ Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
|
399 |
+
)
|
400 |
+
|
401 |
+
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
402 |
+
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
|
403 |
+
|
404 |
+
def load_latents_from_disk(
|
405 |
+
self, npz_path: str, bucket_reso: Tuple[int, int]
|
406 |
+
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
407 |
+
return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution
|
408 |
+
|
409 |
+
# TODO remove circular dependency for ImageInfo
|
410 |
+
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
411 |
+
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
|
412 |
+
vae_device = vae.device
|
413 |
+
vae_dtype = vae.dtype
|
414 |
+
|
415 |
+
self._default_cache_batch_latents(
|
416 |
+
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
|
417 |
+
)
|
418 |
+
|
419 |
+
if not train_util.HIGH_VRAM:
|
420 |
+
train_util.clean_memory_on_device(vae.device)
|
library/strategy_sdxl.py
ADDED
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import Any, List, Optional, Tuple, Union
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
|
7 |
+
from library.strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy
|
8 |
+
|
9 |
+
|
10 |
+
from library.utils import setup_logging
|
11 |
+
|
12 |
+
setup_logging()
|
13 |
+
import logging
|
14 |
+
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
|
17 |
+
|
18 |
+
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
|
19 |
+
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
20 |
+
|
21 |
+
|
22 |
+
class SdxlTokenizeStrategy(TokenizeStrategy):
|
23 |
+
def __init__(self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None:
|
24 |
+
self.tokenizer1 = self._load_tokenizer(CLIPTokenizer, TOKENIZER1_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
|
25 |
+
self.tokenizer2 = self._load_tokenizer(CLIPTokenizer, TOKENIZER2_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
|
26 |
+
self.tokenizer2.pad_token_id = 0 # use 0 as pad token for tokenizer2
|
27 |
+
|
28 |
+
if max_length is None:
|
29 |
+
self.max_length = self.tokenizer1.model_max_length
|
30 |
+
else:
|
31 |
+
self.max_length = max_length + 2
|
32 |
+
|
33 |
+
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
|
34 |
+
text = [text] if isinstance(text, str) else text
|
35 |
+
return (
|
36 |
+
torch.stack([self._get_input_ids(self.tokenizer1, t, self.max_length) for t in text], dim=0),
|
37 |
+
torch.stack([self._get_input_ids(self.tokenizer2, t, self.max_length) for t in text], dim=0),
|
38 |
+
)
|
39 |
+
|
40 |
+
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]:
|
41 |
+
text = [text] if isinstance(text, str) else text
|
42 |
+
tokens1_list, tokens2_list = [], []
|
43 |
+
weights1_list, weights2_list = [], []
|
44 |
+
for t in text:
|
45 |
+
tokens1, weights1 = self._get_input_ids(self.tokenizer1, t, self.max_length, weighted=True)
|
46 |
+
tokens2, weights2 = self._get_input_ids(self.tokenizer2, t, self.max_length, weighted=True)
|
47 |
+
tokens1_list.append(tokens1)
|
48 |
+
tokens2_list.append(tokens2)
|
49 |
+
weights1_list.append(weights1)
|
50 |
+
weights2_list.append(weights2)
|
51 |
+
return [torch.stack(tokens1_list, dim=0), torch.stack(tokens2_list, dim=0)], [
|
52 |
+
torch.stack(weights1_list, dim=0),
|
53 |
+
torch.stack(weights2_list, dim=0),
|
54 |
+
]
|
55 |
+
|
56 |
+
|
57 |
+
class SdxlTextEncodingStrategy(TextEncodingStrategy):
|
58 |
+
def __init__(self) -> None:
|
59 |
+
pass
|
60 |
+
|
61 |
+
def _pool_workaround(
|
62 |
+
self, text_encoder: CLIPTextModelWithProjection, last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int
|
63 |
+
):
|
64 |
+
r"""
|
65 |
+
workaround for CLIP's pooling bug: it returns the hidden states for the max token id as the pooled output
|
66 |
+
instead of the hidden states for the EOS token
|
67 |
+
If we use Textual Inversion, we need to use the hidden states for the EOS token as the pooled output
|
68 |
+
|
69 |
+
Original code from CLIP's pooling function:
|
70 |
+
|
71 |
+
\# text_embeds.shape = [batch_size, sequence_length, transformer.width]
|
72 |
+
\# take features from the eot embedding (eot_token is the highest number in each sequence)
|
73 |
+
\# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
74 |
+
pooled_output = last_hidden_state[
|
75 |
+
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
|
76 |
+
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
|
77 |
+
]
|
78 |
+
"""
|
79 |
+
|
80 |
+
# input_ids: b*n,77
|
81 |
+
# find index for EOS token
|
82 |
+
|
83 |
+
# Following code is not working if one of the input_ids has multiple EOS tokens (very odd case)
|
84 |
+
# eos_token_index = torch.where(input_ids == eos_token_id)[1]
|
85 |
+
# eos_token_index = eos_token_index.to(device=last_hidden_state.device)
|
86 |
+
|
87 |
+
# Create a mask where the EOS tokens are
|
88 |
+
eos_token_mask = (input_ids == eos_token_id).int()
|
89 |
+
|
90 |
+
# Use argmax to find the last index of the EOS token for each element in the batch
|
91 |
+
eos_token_index = torch.argmax(eos_token_mask, dim=1) # this will be 0 if there is no EOS token, it's fine
|
92 |
+
eos_token_index = eos_token_index.to(device=last_hidden_state.device)
|
93 |
+
|
94 |
+
# get hidden states for EOS token
|
95 |
+
pooled_output = last_hidden_state[
|
96 |
+
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index
|
97 |
+
]
|
98 |
+
|
99 |
+
# apply projection: projection may be of different dtype than last_hidden_state
|
100 |
+
pooled_output = text_encoder.text_projection(pooled_output.to(text_encoder.text_projection.weight.dtype))
|
101 |
+
pooled_output = pooled_output.to(last_hidden_state.dtype)
|
102 |
+
|
103 |
+
return pooled_output
|
104 |
+
|
105 |
+
def _get_hidden_states_sdxl(
|
106 |
+
self,
|
107 |
+
input_ids1: torch.Tensor,
|
108 |
+
input_ids2: torch.Tensor,
|
109 |
+
tokenizer1: CLIPTokenizer,
|
110 |
+
tokenizer2: CLIPTokenizer,
|
111 |
+
text_encoder1: Union[CLIPTextModel, torch.nn.Module],
|
112 |
+
text_encoder2: Union[CLIPTextModelWithProjection, torch.nn.Module],
|
113 |
+
unwrapped_text_encoder2: Optional[CLIPTextModelWithProjection] = None,
|
114 |
+
):
|
115 |
+
# input_ids: b,n,77 -> b*n, 77
|
116 |
+
b_size = input_ids1.size()[0]
|
117 |
+
if input_ids1.size()[1] == 1:
|
118 |
+
max_token_length = None
|
119 |
+
else:
|
120 |
+
max_token_length = input_ids1.size()[1] * input_ids1.size()[2]
|
121 |
+
input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77
|
122 |
+
input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77
|
123 |
+
input_ids1 = input_ids1.to(text_encoder1.device)
|
124 |
+
input_ids2 = input_ids2.to(text_encoder2.device)
|
125 |
+
|
126 |
+
# text_encoder1
|
127 |
+
enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True)
|
128 |
+
hidden_states1 = enc_out["hidden_states"][11]
|
129 |
+
|
130 |
+
# text_encoder2
|
131 |
+
enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
|
132 |
+
hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer
|
133 |
+
|
134 |
+
# pool2 = enc_out["text_embeds"]
|
135 |
+
unwrapped_text_encoder2 = unwrapped_text_encoder2 or text_encoder2
|
136 |
+
pool2 = self._pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)
|
137 |
+
|
138 |
+
# b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
|
139 |
+
n_size = 1 if max_token_length is None else max_token_length // 75
|
140 |
+
hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1]))
|
141 |
+
hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1]))
|
142 |
+
|
143 |
+
if max_token_length is not None:
|
144 |
+
# bs*3, 77, 768 or 1024
|
145 |
+
# encoder1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
|
146 |
+
states_list = [hidden_states1[:, 0].unsqueeze(1)] # <BOS>
|
147 |
+
for i in range(1, max_token_length, tokenizer1.model_max_length):
|
148 |
+
states_list.append(hidden_states1[:, i : i + tokenizer1.model_max_length - 2]) # <BOS> の後から <EOS> の前まで
|
149 |
+
states_list.append(hidden_states1[:, -1].unsqueeze(1)) # <EOS>
|
150 |
+
hidden_states1 = torch.cat(states_list, dim=1)
|
151 |
+
|
152 |
+
# v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す 正直この実装でいいのかわからん
|
153 |
+
states_list = [hidden_states2[:, 0].unsqueeze(1)] # <BOS>
|
154 |
+
for i in range(1, max_token_length, tokenizer2.model_max_length):
|
155 |
+
chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2] # <BOS> の後から 最後の前まで
|
156 |
+
# this causes an error:
|
157 |
+
# RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
|
158 |
+
# if i > 1:
|
159 |
+
# for j in range(len(chunk)): # batch_size
|
160 |
+
# if input_ids2[n_index + j * n_size, 1] == tokenizer2.eos_token_id: # 空、つまり <BOS> <EOS> <PAD> ...のパターン
|
161 |
+
# chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
|
162 |
+
states_list.append(chunk) # <BOS> の後から <EOS> の前まで
|
163 |
+
states_list.append(hidden_states2[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
|
164 |
+
hidden_states2 = torch.cat(states_list, dim=1)
|
165 |
+
|
166 |
+
# pool はnの最初のものを使う
|
167 |
+
pool2 = pool2[::n_size]
|
168 |
+
|
169 |
+
return hidden_states1, hidden_states2, pool2
|
170 |
+
|
171 |
+
def encode_tokens(
|
172 |
+
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
|
173 |
+
) -> List[torch.Tensor]:
|
174 |
+
"""
|
175 |
+
Args:
|
176 |
+
tokenize_strategy: TokenizeStrategy
|
177 |
+
models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)].
|
178 |
+
If text_encoder2 is wrapped by accelerate, unwrapped_text_encoder2 is required
|
179 |
+
tokens: List of tokens, for text_encoder1 and text_encoder2
|
180 |
+
"""
|
181 |
+
if len(models) == 2:
|
182 |
+
text_encoder1, text_encoder2 = models
|
183 |
+
unwrapped_text_encoder2 = None
|
184 |
+
else:
|
185 |
+
text_encoder1, text_encoder2, unwrapped_text_encoder2 = models
|
186 |
+
tokens1, tokens2 = tokens
|
187 |
+
sdxl_tokenize_strategy = tokenize_strategy # type: SdxlTokenizeStrategy
|
188 |
+
tokenizer1, tokenizer2 = sdxl_tokenize_strategy.tokenizer1, sdxl_tokenize_strategy.tokenizer2
|
189 |
+
|
190 |
+
hidden_states1, hidden_states2, pool2 = self._get_hidden_states_sdxl(
|
191 |
+
tokens1, tokens2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, unwrapped_text_encoder2
|
192 |
+
)
|
193 |
+
return [hidden_states1, hidden_states2, pool2]
|
194 |
+
|
195 |
+
def encode_tokens_with_weights(
|
196 |
+
self,
|
197 |
+
tokenize_strategy: TokenizeStrategy,
|
198 |
+
models: List[Any],
|
199 |
+
tokens_list: List[torch.Tensor],
|
200 |
+
weights_list: List[torch.Tensor],
|
201 |
+
) -> List[torch.Tensor]:
|
202 |
+
hidden_states1, hidden_states2, pool2 = self.encode_tokens(tokenize_strategy, models, tokens_list)
|
203 |
+
|
204 |
+
weights_list = [weights.to(hidden_states1.device) for weights in weights_list]
|
205 |
+
|
206 |
+
# apply weights
|
207 |
+
if weights_list[0].shape[1] == 1: # no max_token_length
|
208 |
+
# weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768)
|
209 |
+
hidden_states1 = hidden_states1 * weights_list[0].squeeze(1).unsqueeze(2)
|
210 |
+
hidden_states2 = hidden_states2 * weights_list[1].squeeze(1).unsqueeze(2)
|
211 |
+
else:
|
212 |
+
# weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768)
|
213 |
+
for weight, hidden_states in zip(weights_list, [hidden_states1, hidden_states2]):
|
214 |
+
for i in range(weight.shape[1]):
|
215 |
+
hidden_states[:, i * 75 + 1 : i * 75 + 76] = hidden_states[:, i * 75 + 1 : i * 75 + 76] * weight[
|
216 |
+
:, i, 1:-1
|
217 |
+
].unsqueeze(-1)
|
218 |
+
|
219 |
+
return [hidden_states1, hidden_states2, pool2]
|
220 |
+
|
221 |
+
|
222 |
+
class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
223 |
+
SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz"
|
224 |
+
|
225 |
+
def __init__(
|
226 |
+
self,
|
227 |
+
cache_to_disk: bool,
|
228 |
+
batch_size: int,
|
229 |
+
skip_disk_cache_validity_check: bool,
|
230 |
+
is_partial: bool = False,
|
231 |
+
is_weighted: bool = False,
|
232 |
+
) -> None:
|
233 |
+
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial, is_weighted)
|
234 |
+
|
235 |
+
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
236 |
+
return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
|
237 |
+
|
238 |
+
def is_disk_cached_outputs_expected(self, npz_path: str):
|
239 |
+
if not self.cache_to_disk:
|
240 |
+
return False
|
241 |
+
if not os.path.exists(npz_path):
|
242 |
+
return False
|
243 |
+
if self.skip_disk_cache_validity_check:
|
244 |
+
return True
|
245 |
+
|
246 |
+
try:
|
247 |
+
npz = np.load(npz_path)
|
248 |
+
if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz:
|
249 |
+
return False
|
250 |
+
except Exception as e:
|
251 |
+
logger.error(f"Error loading file: {npz_path}")
|
252 |
+
raise e
|
253 |
+
|
254 |
+
return True
|
255 |
+
|
256 |
+
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
257 |
+
data = np.load(npz_path)
|
258 |
+
hidden_state1 = data["hidden_state1"]
|
259 |
+
hidden_state2 = data["hidden_state2"]
|
260 |
+
pool2 = data["pool2"]
|
261 |
+
return [hidden_state1, hidden_state2, pool2]
|
262 |
+
|
263 |
+
def cache_batch_outputs(
|
264 |
+
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
|
265 |
+
):
|
266 |
+
sdxl_text_encoding_strategy = text_encoding_strategy # type: SdxlTextEncodingStrategy
|
267 |
+
captions = [info.caption for info in infos]
|
268 |
+
|
269 |
+
if self.is_weighted:
|
270 |
+
tokens_list, weights_list = tokenize_strategy.tokenize_with_weights(captions)
|
271 |
+
with torch.no_grad():
|
272 |
+
hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens_with_weights(
|
273 |
+
tokenize_strategy, models, tokens_list, weights_list
|
274 |
+
)
|
275 |
+
else:
|
276 |
+
tokens1, tokens2 = tokenize_strategy.tokenize(captions)
|
277 |
+
with torch.no_grad():
|
278 |
+
hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens(
|
279 |
+
tokenize_strategy, models, [tokens1, tokens2]
|
280 |
+
)
|
281 |
+
|
282 |
+
if hidden_state1.dtype == torch.bfloat16:
|
283 |
+
hidden_state1 = hidden_state1.float()
|
284 |
+
if hidden_state2.dtype == torch.bfloat16:
|
285 |
+
hidden_state2 = hidden_state2.float()
|
286 |
+
if pool2.dtype == torch.bfloat16:
|
287 |
+
pool2 = pool2.float()
|
288 |
+
|
289 |
+
hidden_state1 = hidden_state1.cpu().numpy()
|
290 |
+
hidden_state2 = hidden_state2.cpu().numpy()
|
291 |
+
pool2 = pool2.cpu().numpy()
|
292 |
+
|
293 |
+
for i, info in enumerate(infos):
|
294 |
+
hidden_state1_i = hidden_state1[i]
|
295 |
+
hidden_state2_i = hidden_state2[i]
|
296 |
+
pool2_i = pool2[i]
|
297 |
+
|
298 |
+
if self.cache_to_disk:
|
299 |
+
np.savez(
|
300 |
+
info.text_encoder_outputs_npz,
|
301 |
+
hidden_state1=hidden_state1_i,
|
302 |
+
hidden_state2=hidden_state2_i,
|
303 |
+
pool2=pool2_i,
|
304 |
+
)
|
305 |
+
else:
|
306 |
+
info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i]
|
library/train_util.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
library/utils.py
ADDED
@@ -0,0 +1,582 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import sys
|
3 |
+
import threading
|
4 |
+
from typing import *
|
5 |
+
import json
|
6 |
+
import struct
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from torchvision import transforms
|
11 |
+
from diffusers import EulerAncestralDiscreteScheduler
|
12 |
+
import diffusers.schedulers.scheduling_euler_ancestral_discrete
|
13 |
+
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput
|
14 |
+
import cv2
|
15 |
+
from PIL import Image
|
16 |
+
import numpy as np
|
17 |
+
from safetensors.torch import load_file
|
18 |
+
|
19 |
+
|
20 |
+
def fire_in_thread(f, *args, **kwargs):
|
21 |
+
threading.Thread(target=f, args=args, kwargs=kwargs).start()
|
22 |
+
|
23 |
+
|
24 |
+
# region Logging
|
25 |
+
|
26 |
+
|
27 |
+
def add_logging_arguments(parser):
|
28 |
+
parser.add_argument(
|
29 |
+
"--console_log_level",
|
30 |
+
type=str,
|
31 |
+
default=None,
|
32 |
+
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
33 |
+
help="Set the logging level, default is INFO / ログレベルを設定する。デフォルトはINFO",
|
34 |
+
)
|
35 |
+
parser.add_argument(
|
36 |
+
"--console_log_file",
|
37 |
+
type=str,
|
38 |
+
default=None,
|
39 |
+
help="Log to a file instead of stderr / 標準エラー出力ではなくファイルにログを出力する",
|
40 |
+
)
|
41 |
+
parser.add_argument("--console_log_simple", action="store_true", help="Simple log output / シンプルなログ出力")
|
42 |
+
|
43 |
+
|
44 |
+
def setup_logging(args=None, log_level=None, reset=False):
|
45 |
+
if logging.root.handlers:
|
46 |
+
if reset:
|
47 |
+
# remove all handlers
|
48 |
+
for handler in logging.root.handlers[:]:
|
49 |
+
logging.root.removeHandler(handler)
|
50 |
+
else:
|
51 |
+
return
|
52 |
+
|
53 |
+
# log_level can be set by the caller or by the args, the caller has priority. If not set, use INFO
|
54 |
+
if log_level is None and args is not None:
|
55 |
+
log_level = args.console_log_level
|
56 |
+
if log_level is None:
|
57 |
+
log_level = "INFO"
|
58 |
+
log_level = getattr(logging, log_level)
|
59 |
+
|
60 |
+
msg_init = None
|
61 |
+
if args is not None and args.console_log_file:
|
62 |
+
handler = logging.FileHandler(args.console_log_file, mode="w")
|
63 |
+
else:
|
64 |
+
handler = None
|
65 |
+
if not args or not args.console_log_simple:
|
66 |
+
try:
|
67 |
+
from rich.logging import RichHandler
|
68 |
+
from rich.console import Console
|
69 |
+
from rich.logging import RichHandler
|
70 |
+
|
71 |
+
handler = RichHandler(console=Console(stderr=True))
|
72 |
+
except ImportError:
|
73 |
+
# print("rich is not installed, using basic logging")
|
74 |
+
msg_init = "rich is not installed, using basic logging"
|
75 |
+
|
76 |
+
if handler is None:
|
77 |
+
handler = logging.StreamHandler(sys.stdout) # same as print
|
78 |
+
handler.propagate = False
|
79 |
+
|
80 |
+
formatter = logging.Formatter(
|
81 |
+
fmt="%(message)s",
|
82 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
83 |
+
)
|
84 |
+
handler.setFormatter(formatter)
|
85 |
+
logging.root.setLevel(log_level)
|
86 |
+
logging.root.addHandler(handler)
|
87 |
+
|
88 |
+
if msg_init is not None:
|
89 |
+
logger = logging.getLogger(__name__)
|
90 |
+
logger.info(msg_init)
|
91 |
+
|
92 |
+
|
93 |
+
# endregion
|
94 |
+
|
95 |
+
# region PyTorch utils
|
96 |
+
|
97 |
+
|
98 |
+
def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
|
99 |
+
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
|
100 |
+
|
101 |
+
weight_swap_jobs = []
|
102 |
+
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
|
103 |
+
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
|
104 |
+
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
|
105 |
+
|
106 |
+
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
|
107 |
+
|
108 |
+
stream = torch.cuda.Stream()
|
109 |
+
with torch.cuda.stream(stream):
|
110 |
+
# cuda to cpu
|
111 |
+
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
112 |
+
cuda_data_view.record_stream(stream)
|
113 |
+
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
|
114 |
+
|
115 |
+
stream.synchronize()
|
116 |
+
|
117 |
+
# cpu to cuda
|
118 |
+
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
119 |
+
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
|
120 |
+
module_to_cuda.weight.data = cuda_data_view
|
121 |
+
|
122 |
+
stream.synchronize()
|
123 |
+
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
|
124 |
+
|
125 |
+
|
126 |
+
def weighs_to_device(layer: nn.Module, device: torch.device):
|
127 |
+
for module in layer.modules():
|
128 |
+
if hasattr(module, "weight") and module.weight is not None:
|
129 |
+
module.weight.data = module.weight.data.to(device, non_blocking=True)
|
130 |
+
|
131 |
+
|
132 |
+
def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype:
|
133 |
+
"""
|
134 |
+
Convert a string to a torch.dtype
|
135 |
+
|
136 |
+
Args:
|
137 |
+
s: string representation of the dtype
|
138 |
+
default_dtype: default dtype to return if s is None
|
139 |
+
|
140 |
+
Returns:
|
141 |
+
torch.dtype: the corresponding torch.dtype
|
142 |
+
|
143 |
+
Raises:
|
144 |
+
ValueError: if the dtype is not supported
|
145 |
+
|
146 |
+
Examples:
|
147 |
+
>>> str_to_dtype("float32")
|
148 |
+
torch.float32
|
149 |
+
>>> str_to_dtype("fp32")
|
150 |
+
torch.float32
|
151 |
+
>>> str_to_dtype("float16")
|
152 |
+
torch.float16
|
153 |
+
>>> str_to_dtype("fp16")
|
154 |
+
torch.float16
|
155 |
+
>>> str_to_dtype("bfloat16")
|
156 |
+
torch.bfloat16
|
157 |
+
>>> str_to_dtype("bf16")
|
158 |
+
torch.bfloat16
|
159 |
+
>>> str_to_dtype("fp8")
|
160 |
+
torch.float8_e4m3fn
|
161 |
+
>>> str_to_dtype("fp8_e4m3fn")
|
162 |
+
torch.float8_e4m3fn
|
163 |
+
>>> str_to_dtype("fp8_e4m3fnuz")
|
164 |
+
torch.float8_e4m3fnuz
|
165 |
+
>>> str_to_dtype("fp8_e5m2")
|
166 |
+
torch.float8_e5m2
|
167 |
+
>>> str_to_dtype("fp8_e5m2fnuz")
|
168 |
+
torch.float8_e5m2fnuz
|
169 |
+
"""
|
170 |
+
if s is None:
|
171 |
+
return default_dtype
|
172 |
+
if s in ["bf16", "bfloat16"]:
|
173 |
+
return torch.bfloat16
|
174 |
+
elif s in ["fp16", "float16"]:
|
175 |
+
return torch.float16
|
176 |
+
elif s in ["fp32", "float32", "float"]:
|
177 |
+
return torch.float32
|
178 |
+
elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]:
|
179 |
+
return torch.float8_e4m3fn
|
180 |
+
elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]:
|
181 |
+
return torch.float8_e4m3fnuz
|
182 |
+
elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]:
|
183 |
+
return torch.float8_e5m2
|
184 |
+
elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]:
|
185 |
+
return torch.float8_e5m2fnuz
|
186 |
+
elif s in ["fp8", "float8"]:
|
187 |
+
return torch.float8_e4m3fn # default fp8
|
188 |
+
else:
|
189 |
+
raise ValueError(f"Unsupported dtype: {s}")
|
190 |
+
|
191 |
+
|
192 |
+
def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
|
193 |
+
"""
|
194 |
+
memory efficient save file
|
195 |
+
"""
|
196 |
+
|
197 |
+
_TYPES = {
|
198 |
+
torch.float64: "F64",
|
199 |
+
torch.float32: "F32",
|
200 |
+
torch.float16: "F16",
|
201 |
+
torch.bfloat16: "BF16",
|
202 |
+
torch.int64: "I64",
|
203 |
+
torch.int32: "I32",
|
204 |
+
torch.int16: "I16",
|
205 |
+
torch.int8: "I8",
|
206 |
+
torch.uint8: "U8",
|
207 |
+
torch.bool: "BOOL",
|
208 |
+
getattr(torch, "float8_e5m2", None): "F8_E5M2",
|
209 |
+
getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
|
210 |
+
}
|
211 |
+
_ALIGN = 256
|
212 |
+
|
213 |
+
def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
|
214 |
+
validated = {}
|
215 |
+
for key, value in metadata.items():
|
216 |
+
if not isinstance(key, str):
|
217 |
+
raise ValueError(f"Metadata key must be a string, got {type(key)}")
|
218 |
+
if not isinstance(value, str):
|
219 |
+
print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
|
220 |
+
validated[key] = str(value)
|
221 |
+
else:
|
222 |
+
validated[key] = value
|
223 |
+
return validated
|
224 |
+
|
225 |
+
print(f"Using memory efficient save file: {filename}")
|
226 |
+
|
227 |
+
header = {}
|
228 |
+
offset = 0
|
229 |
+
if metadata:
|
230 |
+
header["__metadata__"] = validate_metadata(metadata)
|
231 |
+
for k, v in tensors.items():
|
232 |
+
if v.numel() == 0: # empty tensor
|
233 |
+
header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
|
234 |
+
else:
|
235 |
+
size = v.numel() * v.element_size()
|
236 |
+
header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
|
237 |
+
offset += size
|
238 |
+
|
239 |
+
hjson = json.dumps(header).encode("utf-8")
|
240 |
+
hjson += b" " * (-(len(hjson) + 8) % _ALIGN)
|
241 |
+
|
242 |
+
with open(filename, "wb") as f:
|
243 |
+
f.write(struct.pack("<Q", len(hjson)))
|
244 |
+
f.write(hjson)
|
245 |
+
|
246 |
+
for k, v in tensors.items():
|
247 |
+
if v.numel() == 0:
|
248 |
+
continue
|
249 |
+
if v.is_cuda:
|
250 |
+
# Direct GPU to disk save
|
251 |
+
with torch.cuda.device(v.device):
|
252 |
+
if v.dim() == 0: # if scalar, need to add a dimension to work with view
|
253 |
+
v = v.unsqueeze(0)
|
254 |
+
tensor_bytes = v.contiguous().view(torch.uint8)
|
255 |
+
tensor_bytes.cpu().numpy().tofile(f)
|
256 |
+
else:
|
257 |
+
# CPU tensor save
|
258 |
+
if v.dim() == 0: # if scalar, need to add a dimension to work with view
|
259 |
+
v = v.unsqueeze(0)
|
260 |
+
v.contiguous().view(torch.uint8).numpy().tofile(f)
|
261 |
+
|
262 |
+
|
263 |
+
class MemoryEfficientSafeOpen:
|
264 |
+
# does not support metadata loading
|
265 |
+
def __init__(self, filename):
|
266 |
+
self.filename = filename
|
267 |
+
self.header, self.header_size = self._read_header()
|
268 |
+
self.file = open(filename, "rb")
|
269 |
+
|
270 |
+
def __enter__(self):
|
271 |
+
return self
|
272 |
+
|
273 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
274 |
+
self.file.close()
|
275 |
+
|
276 |
+
def keys(self):
|
277 |
+
return [k for k in self.header.keys() if k != "__metadata__"]
|
278 |
+
|
279 |
+
def get_tensor(self, key):
|
280 |
+
if key not in self.header:
|
281 |
+
raise KeyError(f"Tensor '{key}' not found in the file")
|
282 |
+
|
283 |
+
metadata = self.header[key]
|
284 |
+
offset_start, offset_end = metadata["data_offsets"]
|
285 |
+
|
286 |
+
if offset_start == offset_end:
|
287 |
+
tensor_bytes = None
|
288 |
+
else:
|
289 |
+
# adjust offset by header size
|
290 |
+
self.file.seek(self.header_size + 8 + offset_start)
|
291 |
+
tensor_bytes = self.file.read(offset_end - offset_start)
|
292 |
+
|
293 |
+
return self._deserialize_tensor(tensor_bytes, metadata)
|
294 |
+
|
295 |
+
def _read_header(self):
|
296 |
+
with open(self.filename, "rb") as f:
|
297 |
+
header_size = struct.unpack("<Q", f.read(8))[0]
|
298 |
+
header_json = f.read(header_size).decode("utf-8")
|
299 |
+
return json.loads(header_json), header_size
|
300 |
+
|
301 |
+
def _deserialize_tensor(self, tensor_bytes, metadata):
|
302 |
+
dtype = self._get_torch_dtype(metadata["dtype"])
|
303 |
+
shape = metadata["shape"]
|
304 |
+
|
305 |
+
if tensor_bytes is None:
|
306 |
+
byte_tensor = torch.empty(0, dtype=torch.uint8)
|
307 |
+
else:
|
308 |
+
tensor_bytes = bytearray(tensor_bytes) # make it writable
|
309 |
+
byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)
|
310 |
+
|
311 |
+
# process float8 types
|
312 |
+
if metadata["dtype"] in ["F8_E5M2", "F8_E4M3"]:
|
313 |
+
return self._convert_float8(byte_tensor, metadata["dtype"], shape)
|
314 |
+
|
315 |
+
# convert to the target dtype and reshape
|
316 |
+
return byte_tensor.view(dtype).reshape(shape)
|
317 |
+
|
318 |
+
@staticmethod
|
319 |
+
def _get_torch_dtype(dtype_str):
|
320 |
+
dtype_map = {
|
321 |
+
"F64": torch.float64,
|
322 |
+
"F32": torch.float32,
|
323 |
+
"F16": torch.float16,
|
324 |
+
"BF16": torch.bfloat16,
|
325 |
+
"I64": torch.int64,
|
326 |
+
"I32": torch.int32,
|
327 |
+
"I16": torch.int16,
|
328 |
+
"I8": torch.int8,
|
329 |
+
"U8": torch.uint8,
|
330 |
+
"BOOL": torch.bool,
|
331 |
+
}
|
332 |
+
# add float8 types if available
|
333 |
+
if hasattr(torch, "float8_e5m2"):
|
334 |
+
dtype_map["F8_E5M2"] = torch.float8_e5m2
|
335 |
+
if hasattr(torch, "float8_e4m3fn"):
|
336 |
+
dtype_map["F8_E4M3"] = torch.float8_e4m3fn
|
337 |
+
return dtype_map.get(dtype_str)
|
338 |
+
|
339 |
+
@staticmethod
|
340 |
+
def _convert_float8(byte_tensor, dtype_str, shape):
|
341 |
+
if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
|
342 |
+
return byte_tensor.view(torch.float8_e5m2).reshape(shape)
|
343 |
+
elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
|
344 |
+
return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
|
345 |
+
else:
|
346 |
+
# # convert to float16 if float8 is not supported
|
347 |
+
# print(f"Warning: {dtype_str} is not supported in this PyTorch version. Converting to float16.")
|
348 |
+
# return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
|
349 |
+
raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
|
350 |
+
|
351 |
+
|
352 |
+
def load_safetensors(
|
353 |
+
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
|
354 |
+
) -> dict[str, torch.Tensor]:
|
355 |
+
if disable_mmap:
|
356 |
+
# return safetensors.torch.load(open(path, "rb").read())
|
357 |
+
# use experimental loader
|
358 |
+
# logger.info(f"Loading without mmap (experimental)")
|
359 |
+
state_dict = {}
|
360 |
+
with MemoryEfficientSafeOpen(path) as f:
|
361 |
+
for key in f.keys():
|
362 |
+
state_dict[key] = f.get_tensor(key).to(device, dtype=dtype)
|
363 |
+
return state_dict
|
364 |
+
else:
|
365 |
+
try:
|
366 |
+
state_dict = load_file(path, device=device)
|
367 |
+
except:
|
368 |
+
state_dict = load_file(path) # prevent device invalid Error
|
369 |
+
if dtype is not None:
|
370 |
+
for key in state_dict.keys():
|
371 |
+
state_dict[key] = state_dict[key].to(dtype=dtype)
|
372 |
+
return state_dict
|
373 |
+
|
374 |
+
|
375 |
+
# endregion
|
376 |
+
|
377 |
+
# region Image utils
|
378 |
+
|
379 |
+
|
380 |
+
def pil_resize(image, size, interpolation=Image.LANCZOS):
|
381 |
+
has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False
|
382 |
+
|
383 |
+
if has_alpha:
|
384 |
+
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA))
|
385 |
+
else:
|
386 |
+
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
387 |
+
|
388 |
+
resized_pil = pil_image.resize(size, interpolation)
|
389 |
+
|
390 |
+
# Convert back to cv2 format
|
391 |
+
if has_alpha:
|
392 |
+
resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGBA2BGRA)
|
393 |
+
else:
|
394 |
+
resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR)
|
395 |
+
|
396 |
+
return resized_cv2
|
397 |
+
|
398 |
+
|
399 |
+
# endregion
|
400 |
+
|
401 |
+
# TODO make inf_utils.py
|
402 |
+
# region Gradual Latent hires fix
|
403 |
+
|
404 |
+
|
405 |
+
class GradualLatent:
|
406 |
+
def __init__(
|
407 |
+
self,
|
408 |
+
ratio,
|
409 |
+
start_timesteps,
|
410 |
+
every_n_steps,
|
411 |
+
ratio_step,
|
412 |
+
s_noise=1.0,
|
413 |
+
gaussian_blur_ksize=None,
|
414 |
+
gaussian_blur_sigma=0.5,
|
415 |
+
gaussian_blur_strength=0.5,
|
416 |
+
unsharp_target_x=True,
|
417 |
+
):
|
418 |
+
self.ratio = ratio
|
419 |
+
self.start_timesteps = start_timesteps
|
420 |
+
self.every_n_steps = every_n_steps
|
421 |
+
self.ratio_step = ratio_step
|
422 |
+
self.s_noise = s_noise
|
423 |
+
self.gaussian_blur_ksize = gaussian_blur_ksize
|
424 |
+
self.gaussian_blur_sigma = gaussian_blur_sigma
|
425 |
+
self.gaussian_blur_strength = gaussian_blur_strength
|
426 |
+
self.unsharp_target_x = unsharp_target_x
|
427 |
+
|
428 |
+
def __str__(self) -> str:
|
429 |
+
return (
|
430 |
+
f"GradualLatent(ratio={self.ratio}, start_timesteps={self.start_timesteps}, "
|
431 |
+
+ f"every_n_steps={self.every_n_steps}, ratio_step={self.ratio_step}, s_noise={self.s_noise}, "
|
432 |
+
+ f"gaussian_blur_ksize={self.gaussian_blur_ksize}, gaussian_blur_sigma={self.gaussian_blur_sigma}, gaussian_blur_strength={self.gaussian_blur_strength}, "
|
433 |
+
+ f"unsharp_target_x={self.unsharp_target_x})"
|
434 |
+
)
|
435 |
+
|
436 |
+
def apply_unshark_mask(self, x: torch.Tensor):
|
437 |
+
if self.gaussian_blur_ksize is None:
|
438 |
+
return x
|
439 |
+
blurred = transforms.functional.gaussian_blur(x, self.gaussian_blur_ksize, self.gaussian_blur_sigma)
|
440 |
+
# mask = torch.sigmoid((x - blurred) * self.gaussian_blur_strength)
|
441 |
+
mask = (x - blurred) * self.gaussian_blur_strength
|
442 |
+
sharpened = x + mask
|
443 |
+
return sharpened
|
444 |
+
|
445 |
+
def interpolate(self, x: torch.Tensor, resized_size, unsharp=True):
|
446 |
+
org_dtype = x.dtype
|
447 |
+
if org_dtype == torch.bfloat16:
|
448 |
+
x = x.float()
|
449 |
+
|
450 |
+
x = torch.nn.functional.interpolate(x, size=resized_size, mode="bicubic", align_corners=False).to(dtype=org_dtype)
|
451 |
+
|
452 |
+
# apply unsharp mask / アンシャープマスクを適用する
|
453 |
+
if unsharp and self.gaussian_blur_ksize:
|
454 |
+
x = self.apply_unshark_mask(x)
|
455 |
+
|
456 |
+
return x
|
457 |
+
|
458 |
+
|
459 |
+
class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler):
|
460 |
+
def __init__(self, *args, **kwargs):
|
461 |
+
super().__init__(*args, **kwargs)
|
462 |
+
self.resized_size = None
|
463 |
+
self.gradual_latent = None
|
464 |
+
|
465 |
+
def set_gradual_latent_params(self, size, gradual_latent: GradualLatent):
|
466 |
+
self.resized_size = size
|
467 |
+
self.gradual_latent = gradual_latent
|
468 |
+
|
469 |
+
def step(
|
470 |
+
self,
|
471 |
+
model_output: torch.FloatTensor,
|
472 |
+
timestep: Union[float, torch.FloatTensor],
|
473 |
+
sample: torch.FloatTensor,
|
474 |
+
generator: Optional[torch.Generator] = None,
|
475 |
+
return_dict: bool = True,
|
476 |
+
) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
|
477 |
+
"""
|
478 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
479 |
+
process from the learned model outputs (most often the predicted noise).
|
480 |
+
|
481 |
+
Args:
|
482 |
+
model_output (`torch.FloatTensor`):
|
483 |
+
The direct output from learned diffusion model.
|
484 |
+
timestep (`float`):
|
485 |
+
The current discrete timestep in the diffusion chain.
|
486 |
+
sample (`torch.FloatTensor`):
|
487 |
+
A current instance of a sample created by the diffusion process.
|
488 |
+
generator (`torch.Generator`, *optional*):
|
489 |
+
A random number generator.
|
490 |
+
return_dict (`bool`):
|
491 |
+
Whether or not to return a
|
492 |
+
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
|
493 |
+
|
494 |
+
Returns:
|
495 |
+
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
|
496 |
+
If return_dict is `True`,
|
497 |
+
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
|
498 |
+
otherwise a tuple is returned where the first element is the sample tensor.
|
499 |
+
|
500 |
+
"""
|
501 |
+
|
502 |
+
if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor):
|
503 |
+
raise ValueError(
|
504 |
+
(
|
505 |
+
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
506 |
+
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
507 |
+
" one of the `scheduler.timesteps` as a timestep."
|
508 |
+
),
|
509 |
+
)
|
510 |
+
|
511 |
+
if not self.is_scale_input_called:
|
512 |
+
# logger.warning(
|
513 |
+
print(
|
514 |
+
"The `scale_model_input` function should be called before `step` to ensure correct denoising. "
|
515 |
+
"See `StableDiffusionPipeline` for a usage example."
|
516 |
+
)
|
517 |
+
|
518 |
+
if self.step_index is None:
|
519 |
+
self._init_step_index(timestep)
|
520 |
+
|
521 |
+
sigma = self.sigmas[self.step_index]
|
522 |
+
|
523 |
+
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
524 |
+
if self.config.prediction_type == "epsilon":
|
525 |
+
pred_original_sample = sample - sigma * model_output
|
526 |
+
elif self.config.prediction_type == "v_prediction":
|
527 |
+
# * c_out + input * c_skip
|
528 |
+
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
|
529 |
+
elif self.config.prediction_type == "sample":
|
530 |
+
raise NotImplementedError("prediction_type not implemented yet: sample")
|
531 |
+
else:
|
532 |
+
raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`")
|
533 |
+
|
534 |
+
sigma_from = self.sigmas[self.step_index]
|
535 |
+
sigma_to = self.sigmas[self.step_index + 1]
|
536 |
+
sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
|
537 |
+
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
|
538 |
+
|
539 |
+
# 2. Convert to an ODE derivative
|
540 |
+
derivative = (sample - pred_original_sample) / sigma
|
541 |
+
|
542 |
+
dt = sigma_down - sigma
|
543 |
+
|
544 |
+
device = model_output.device
|
545 |
+
if self.resized_size is None:
|
546 |
+
prev_sample = sample + derivative * dt
|
547 |
+
|
548 |
+
noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
|
549 |
+
model_output.shape, dtype=model_output.dtype, device=device, generator=generator
|
550 |
+
)
|
551 |
+
s_noise = 1.0
|
552 |
+
else:
|
553 |
+
print("resized_size", self.resized_size, "model_output.shape", model_output.shape, "sample.shape", sample.shape)
|
554 |
+
s_noise = self.gradual_latent.s_noise
|
555 |
+
|
556 |
+
if self.gradual_latent.unsharp_target_x:
|
557 |
+
prev_sample = sample + derivative * dt
|
558 |
+
prev_sample = self.gradual_latent.interpolate(prev_sample, self.resized_size)
|
559 |
+
else:
|
560 |
+
sample = self.gradual_latent.interpolate(sample, self.resized_size)
|
561 |
+
derivative = self.gradual_latent.interpolate(derivative, self.resized_size, unsharp=False)
|
562 |
+
prev_sample = sample + derivative * dt
|
563 |
+
|
564 |
+
noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
|
565 |
+
(model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]),
|
566 |
+
dtype=model_output.dtype,
|
567 |
+
device=device,
|
568 |
+
generator=generator,
|
569 |
+
)
|
570 |
+
|
571 |
+
prev_sample = prev_sample + noise * sigma_up * s_noise
|
572 |
+
|
573 |
+
# upon completion increase step index by one
|
574 |
+
self._step_index += 1
|
575 |
+
|
576 |
+
if not return_dict:
|
577 |
+
return (prev_sample,)
|
578 |
+
|
579 |
+
return EulerAncestralDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
|
580 |
+
|
581 |
+
|
582 |
+
# endregion
|