Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +5 -0
- .ipynb_checkpoints/README-checkpoint.md +173 -0
- .python-version +1 -0
- README.md +173 -0
- cache_latents.py +281 -0
- cache_text_encoder_outputs.py +214 -0
- convert_lora.py +135 -0
- dataset/__init__.py +0 -0
- dataset/config_utils.py +372 -0
- dataset/dataset_config.md +387 -0
- dataset/image_video_dataset.py +1400 -0
- docs/advanced_config.md +151 -0
- docs/sampling_during_training.md +108 -0
- docs/wan.md +241 -0
- hunyuan_model/__init__.py +0 -0
- hunyuan_model/activation_layers.py +23 -0
- hunyuan_model/attention.py +295 -0
- hunyuan_model/autoencoder_kl_causal_3d.py +609 -0
- hunyuan_model/embed_layers.py +132 -0
- hunyuan_model/helpers.py +40 -0
- hunyuan_model/mlp_layers.py +118 -0
- hunyuan_model/models.py +1044 -0
- hunyuan_model/modulate_layers.py +76 -0
- hunyuan_model/norm_layers.py +79 -0
- hunyuan_model/pipeline_hunyuan_video.py +1100 -0
- hunyuan_model/posemb_layers.py +310 -0
- hunyuan_model/text_encoder.py +710 -0
- hunyuan_model/token_refiner.py +245 -0
- hunyuan_model/vae.py +446 -0
- hv_generate_video.py +911 -0
- hv_train.py +1721 -0
- hv_train_network.py +0 -0
- merge_lora.py +63 -0
- modules/__init__.py +0 -0
- modules/custom_offloading_utils.py +266 -0
- modules/scheduling_flow_match_discrete.py +257 -0
- modules/unet_causal_3d_blocks.py +818 -0
- networks/__init__.py +0 -0
- networks/lora.py +914 -0
- networks/lora_wan.py +65 -0
- pixel_outputs/pixel_w1_3_lora-000001.safetensors +3 -0
- pixel_outputs/pixel_w1_3_lora-000002.safetensors +3 -0
- pixel_outputs/pixel_w1_3_lora-000003.safetensors +3 -0
- pixel_outputs/pixel_w1_3_lora-000004.safetensors +3 -0
- pixel_outputs/pixel_w1_3_lora-000005.safetensors +3 -0
- pixel_outputs/pixel_w1_3_lora-000006.safetensors +3 -0
- pixel_outputs/pixel_w1_3_lora-000007.safetensors +3 -0
- pixel_outputs/pixel_w1_3_lora-000008.safetensors +3 -0
- pixel_outputs/pixel_w1_3_lora-000009.safetensors +3 -0
- pixel_outputs/pixel_w1_3_lora-000010.safetensors +3 -0
.gitignore
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__/
|
2 |
+
.venv
|
3 |
+
venv/
|
4 |
+
logs/
|
5 |
+
uv.lock
|
.ipynb_checkpoints/README-checkpoint.md
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Pixel Text-to-Video Generation
|
2 |
+
|
3 |
+
This repository contains the necessary steps and scripts to generate videos using the Pixel text-to-video model. The model leverages LoRA (Low-Rank Adaptation) weights and pre-trained components to create high-quality anime-style videos based on textual prompts.
|
4 |
+
|
5 |
+
## Prerequisites
|
6 |
+
|
7 |
+
Before proceeding, ensure that you have the following installed on your system:
|
8 |
+
|
9 |
+
• **Ubuntu** (or a compatible Linux distribution)
|
10 |
+
• **Python 3.x**
|
11 |
+
• **pip** (Python package manager)
|
12 |
+
• **Git**
|
13 |
+
• **Git LFS** (Git Large File Storage)
|
14 |
+
• **FFmpeg**
|
15 |
+
|
16 |
+
## Installation
|
17 |
+
|
18 |
+
1. **Update and Install Dependencies**
|
19 |
+
|
20 |
+
```bash
|
21 |
+
sudo apt-get update && sudo apt-get install cbm git-lfs ffmpeg
|
22 |
+
```
|
23 |
+
|
24 |
+
2. **Clone the Repository**
|
25 |
+
|
26 |
+
```bash
|
27 |
+
git clone https://huggingface.co/svjack/Pixel_wan_2_1_1_3_B_text2video_lora
|
28 |
+
cd Pixel_wan_2_1_1_3_B_text2video_lora
|
29 |
+
```
|
30 |
+
|
31 |
+
3. **Install Python Dependencies**
|
32 |
+
|
33 |
+
```bash
|
34 |
+
pip install torch torchvision
|
35 |
+
pip install -r requirements.txt
|
36 |
+
pip install ascii-magic matplotlib tensorboard huggingface_hub datasets
|
37 |
+
pip install moviepy==1.0.3
|
38 |
+
pip install sageattention==1.0.6
|
39 |
+
```
|
40 |
+
|
41 |
+
4. **Download Model Weights**
|
42 |
+
|
43 |
+
```bash
|
44 |
+
wget https://huggingface.co/Wan-AI/Wan2.1-T2V-14B/resolve/main/models_t5_umt5-xxl-enc-bf16.pth
|
45 |
+
wget https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
|
46 |
+
wget https://huggingface.co/Wan-AI/Wan2.1-T2V-14B/resolve/main/Wan2.1_VAE.pth
|
47 |
+
wget https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors
|
48 |
+
wget https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/diffusion_models/wan2.1_t2v_14B_bf16.safetensors
|
49 |
+
```
|
50 |
+
|
51 |
+
## Usage
|
52 |
+
|
53 |
+
To generate a video, use the `wan_generate_video.py` script with the appropriate parameters. Below are examples of how to generate videos using the Pixel model.
|
54 |
+
|
55 |
+
#### Woods
|
56 |
+
|
57 |
+
```bash
|
58 |
+
python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 768 1024 --video_length 81 --infer_steps 20 \
|
59 |
+
--save_path save --output_type both \
|
60 |
+
--dit wan2.1_t2v_1.3B_bf16.safetensors --vae Wan2.1_VAE.pth \
|
61 |
+
--t5 models_t5_umt5-xxl-enc-bf16.pth \
|
62 |
+
--attn_mode torch \
|
63 |
+
--lora_weight pixel_outputs/pixel_w1_3_lora-000010.safetensors \
|
64 |
+
--lora_multiplier 1.0 \
|
65 |
+
--prompt "The video showcases a pixel art scene from a video game. Golden light filters through the canopy, illuminating soft moss and fallen leaves. Wildflowers bloom nearby, and glowing fireflies hover in the air. A gentle stream flows in the background, its murmur blending with birdsong. The scene radiates tranquility and natural charm."
|
66 |
+
|
67 |
+
```
|
68 |
+
|
69 |
+
|
70 |
+
#### Castle
|
71 |
+
|
72 |
+
```bash
|
73 |
+
python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 768 1024 --video_length 81 --infer_steps 20 \
|
74 |
+
--save_path save --output_type both \
|
75 |
+
--dit wan2.1_t2v_1.3B_bf16.safetensors --vae Wan2.1_VAE.pth \
|
76 |
+
--t5 models_t5_umt5-xxl-enc-bf16.pth \
|
77 |
+
--attn_mode torch \
|
78 |
+
--lora_weight pixel_outputs/pixel_w1_3_lora-000010.safetensors \
|
79 |
+
--lora_multiplier 1.0 \
|
80 |
+
--prompt "The video showcases a pixel art scene from a video game. the video shifts to a majestic castle under a starry sky. Silvery moonlight bathes the ancient stone walls, casting soft shadows on the surrounding landscape. Towering spires rise into the night, their peaks adorned with glowing orbs that mimic the stars above. A tranquil moat reflects the shimmering heavens, its surface rippling gently in the cool breeze. Fireflies dance around the castle’s ivy-covered arches, adding a touch of magic to the scene. In the distance, a faint aurora paints the horizon with hues of green and purple, blending seamlessly with the celestial tapestry. The scene exudes an aura of timeless wonder and serene beauty."
|
81 |
+
|
82 |
+
```
|
83 |
+
|
84 |
+
|
85 |
+
#### City
|
86 |
+
|
87 |
+
```bash
|
88 |
+
python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 768 1024 --video_length 81 --infer_steps 20 \
|
89 |
+
--save_path save --output_type both \
|
90 |
+
--dit wan2.1_t2v_1.3B_bf16.safetensors --vae Wan2.1_VAE.pth \
|
91 |
+
--t5 models_t5_umt5-xxl-enc-bf16.pth \
|
92 |
+
--attn_mode torch \
|
93 |
+
--lora_weight pixel_outputs/pixel_w1_3_lora-000010.safetensors \
|
94 |
+
--lora_multiplier 1.0 \
|
95 |
+
--prompt "The video showcases a pixel art scene from a video game. the video showcases a vibrant urban landscape. The city skyline is dominated by towering skyscrapers, their glass facades reflecting the sunlight. The streets are bustling with activity, filled with cars, buses, and pedestrians. Parks and green spaces are scattered throughout, offering a refreshing contrast to the concrete jungle. The architecture is a mix of modern and historic buildings, each telling a story of the city's evolution. The overall scene is alive with energy, capturing the essence of urban life."
|
96 |
+
|
97 |
+
```
|
98 |
+
|
99 |
+
|
100 |
+
#### Girl
|
101 |
+
|
102 |
+
```bash
|
103 |
+
python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 768 1024 --video_length 81 --infer_steps 20 \
|
104 |
+
--save_path save --output_type both \
|
105 |
+
--dit wan2.1_t2v_1.3B_bf16.safetensors --vae Wan2.1_VAE.pth \
|
106 |
+
--t5 models_t5_umt5-xxl-enc-bf16.pth \
|
107 |
+
--attn_mode torch \
|
108 |
+
--lora_weight pixel_outputs/pixel_w1_3_lora-000010.safetensors \
|
109 |
+
--lora_multiplier 1.0 \
|
110 |
+
--prompt "The video showcases a pixel art scene from a video game. .The video showcases a animation featuring charming anime-style scene featuring a pink-haired girl with angel wings. She's seated at a desk, enjoying a donut while working on a laptop. The setting is a cozy, pastel-colored room with a pink chair, a milk carton, and a coffee cup. The girl's expression is one of delight as she savors her treat."
|
111 |
+
|
112 |
+
```
|
113 |
+
|
114 |
+
#### Snow
|
115 |
+
|
116 |
+
```bash
|
117 |
+
python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 768 1024 --video_length 81 --infer_steps 20 \
|
118 |
+
--save_path save --output_type both \
|
119 |
+
--dit wan2.1_t2v_1.3B_bf16.safetensors --vae Wan2.1_VAE.pth \
|
120 |
+
--t5 models_t5_umt5-xxl-enc-bf16.pth \
|
121 |
+
--attn_mode torch \
|
122 |
+
--lora_weight pixel_outputs/pixel_w1_3_lora-000010.safetensors \
|
123 |
+
--lora_multiplier 1.0 \
|
124 |
+
--prompt "The video showcases a pixel art scene from a video game. .The video showcases a animation featuring a serene and majestic snowy mountain landscape. The scene is dominated by towering peaks covered in pristine white snow, with a soft gradient of blue and purple hues in the sky. A small cabin with a smoking chimney sits at the base of the mountain, surrounded by pine trees dusted with snow. A winding path leads up the mountain, with footprints visible in the snow. The atmosphere is calm and peaceful, evoking a sense of solitude and wonder."
|
125 |
+
|
126 |
+
```
|
127 |
+
|
128 |
+
|
129 |
+
## Parameters
|
130 |
+
|
131 |
+
* `--fp8`: Enable FP8 precision (optional).
|
132 |
+
* `--task`: Specify the task (e.g., `t2v-1.3B`).
|
133 |
+
* `--video_size`: Set the resolution of the generated video (e.g., `1024 1024`).
|
134 |
+
* `--video_length`: Define the length of the video in frames.
|
135 |
+
* `--infer_steps`: Number of inference steps.
|
136 |
+
* `--save_path`: Directory to save the generated video.
|
137 |
+
* `--output_type`: Output type (e.g., `both` for video and frames).
|
138 |
+
* `--dit`: Path to the diffusion model weights.
|
139 |
+
* `--vae`: Path to the VAE model weights.
|
140 |
+
* `--t5`: Path to the T5 model weights.
|
141 |
+
* `--attn_mode`: Attention mode (e.g., `torch`).
|
142 |
+
* `--lora_weight`: Path to the LoRA weights.
|
143 |
+
* `--lora_multiplier`: Multiplier for LoRA weights.
|
144 |
+
* `--prompt`: Textual prompt for video generation.
|
145 |
+
|
146 |
+
|
147 |
+
|
148 |
+
## Output
|
149 |
+
|
150 |
+
The generated video and frames will be saved in the specified `save_path` directory.
|
151 |
+
|
152 |
+
## Troubleshooting
|
153 |
+
|
154 |
+
• Ensure all dependencies are correctly installed.
|
155 |
+
• Verify that the model weights are downloaded and placed in the correct locations.
|
156 |
+
• Check for any missing Python packages and install them using `pip`.
|
157 |
+
|
158 |
+
## License
|
159 |
+
|
160 |
+
This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.
|
161 |
+
|
162 |
+
## Acknowledgments
|
163 |
+
|
164 |
+
• **Hugging Face** for hosting the model weights.
|
165 |
+
• **Wan-AI** for providing the pre-trained models.
|
166 |
+
• **DeepBeepMeep** for contributing to the model weights.
|
167 |
+
|
168 |
+
## Contact
|
169 |
+
|
170 |
+
For any questions or issues, please open an issue on the repository or contact the maintainer.
|
171 |
+
|
172 |
+
---
|
173 |
+
|
.python-version
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
3.10
|
README.md
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Pixel Text-to-Video Generation
|
2 |
+
|
3 |
+
This repository contains the necessary steps and scripts to generate videos using the Pixel text-to-video model. The model leverages LoRA (Low-Rank Adaptation) weights and pre-trained components to create high-quality anime-style videos based on textual prompts.
|
4 |
+
|
5 |
+
## Prerequisites
|
6 |
+
|
7 |
+
Before proceeding, ensure that you have the following installed on your system:
|
8 |
+
|
9 |
+
• **Ubuntu** (or a compatible Linux distribution)
|
10 |
+
• **Python 3.x**
|
11 |
+
• **pip** (Python package manager)
|
12 |
+
• **Git**
|
13 |
+
• **Git LFS** (Git Large File Storage)
|
14 |
+
• **FFmpeg**
|
15 |
+
|
16 |
+
## Installation
|
17 |
+
|
18 |
+
1. **Update and Install Dependencies**
|
19 |
+
|
20 |
+
```bash
|
21 |
+
sudo apt-get update && sudo apt-get install cbm git-lfs ffmpeg
|
22 |
+
```
|
23 |
+
|
24 |
+
2. **Clone the Repository**
|
25 |
+
|
26 |
+
```bash
|
27 |
+
git clone https://huggingface.co/svjack/Pixel_wan_2_1_1_3_B_text2video_lora
|
28 |
+
cd Pixel_wan_2_1_1_3_B_text2video_lora
|
29 |
+
```
|
30 |
+
|
31 |
+
3. **Install Python Dependencies**
|
32 |
+
|
33 |
+
```bash
|
34 |
+
pip install torch torchvision
|
35 |
+
pip install -r requirements.txt
|
36 |
+
pip install ascii-magic matplotlib tensorboard huggingface_hub datasets
|
37 |
+
pip install moviepy==1.0.3
|
38 |
+
pip install sageattention==1.0.6
|
39 |
+
```
|
40 |
+
|
41 |
+
4. **Download Model Weights**
|
42 |
+
|
43 |
+
```bash
|
44 |
+
wget https://huggingface.co/Wan-AI/Wan2.1-T2V-14B/resolve/main/models_t5_umt5-xxl-enc-bf16.pth
|
45 |
+
wget https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
|
46 |
+
wget https://huggingface.co/Wan-AI/Wan2.1-T2V-14B/resolve/main/Wan2.1_VAE.pth
|
47 |
+
wget https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors
|
48 |
+
wget https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/diffusion_models/wan2.1_t2v_14B_bf16.safetensors
|
49 |
+
```
|
50 |
+
|
51 |
+
## Usage
|
52 |
+
|
53 |
+
To generate a video, use the `wan_generate_video.py` script with the appropriate parameters. Below are examples of how to generate videos using the Pixel model.
|
54 |
+
|
55 |
+
#### Woods
|
56 |
+
|
57 |
+
```bash
|
58 |
+
python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 768 1024 --video_length 81 --infer_steps 20 \
|
59 |
+
--save_path save --output_type both \
|
60 |
+
--dit wan2.1_t2v_1.3B_bf16.safetensors --vae Wan2.1_VAE.pth \
|
61 |
+
--t5 models_t5_umt5-xxl-enc-bf16.pth \
|
62 |
+
--attn_mode torch \
|
63 |
+
--lora_weight pixel_outputs/pixel_w1_3_lora-000010.safetensors \
|
64 |
+
--lora_multiplier 1.0 \
|
65 |
+
--prompt "The video showcases a pixel art scene from a video game. Golden light filters through the canopy, illuminating soft moss and fallen leaves. Wildflowers bloom nearby, and glowing fireflies hover in the air. A gentle stream flows in the background, its murmur blending with birdsong. The scene radiates tranquility and natural charm."
|
66 |
+
|
67 |
+
```
|
68 |
+
|
69 |
+
|
70 |
+
#### Castle
|
71 |
+
|
72 |
+
```bash
|
73 |
+
python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 768 1024 --video_length 81 --infer_steps 20 \
|
74 |
+
--save_path save --output_type both \
|
75 |
+
--dit wan2.1_t2v_1.3B_bf16.safetensors --vae Wan2.1_VAE.pth \
|
76 |
+
--t5 models_t5_umt5-xxl-enc-bf16.pth \
|
77 |
+
--attn_mode torch \
|
78 |
+
--lora_weight pixel_outputs/pixel_w1_3_lora-000010.safetensors \
|
79 |
+
--lora_multiplier 1.0 \
|
80 |
+
--prompt "The video showcases a pixel art scene from a video game. the video shifts to a majestic castle under a starry sky. Silvery moonlight bathes the ancient stone walls, casting soft shadows on the surrounding landscape. Towering spires rise into the night, their peaks adorned with glowing orbs that mimic the stars above. A tranquil moat reflects the shimmering heavens, its surface rippling gently in the cool breeze. Fireflies dance around the castle’s ivy-covered arches, adding a touch of magic to the scene. In the distance, a faint aurora paints the horizon with hues of green and purple, blending seamlessly with the celestial tapestry. The scene exudes an aura of timeless wonder and serene beauty."
|
81 |
+
|
82 |
+
```
|
83 |
+
|
84 |
+
|
85 |
+
#### City
|
86 |
+
|
87 |
+
```bash
|
88 |
+
python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 768 1024 --video_length 81 --infer_steps 20 \
|
89 |
+
--save_path save --output_type both \
|
90 |
+
--dit wan2.1_t2v_1.3B_bf16.safetensors --vae Wan2.1_VAE.pth \
|
91 |
+
--t5 models_t5_umt5-xxl-enc-bf16.pth \
|
92 |
+
--attn_mode torch \
|
93 |
+
--lora_weight pixel_outputs/pixel_w1_3_lora-000010.safetensors \
|
94 |
+
--lora_multiplier 1.0 \
|
95 |
+
--prompt "The video showcases a pixel art scene from a video game. the video showcases a vibrant urban landscape. The city skyline is dominated by towering skyscrapers, their glass facades reflecting the sunlight. The streets are bustling with activity, filled with cars, buses, and pedestrians. Parks and green spaces are scattered throughout, offering a refreshing contrast to the concrete jungle. The architecture is a mix of modern and historic buildings, each telling a story of the city's evolution. The overall scene is alive with energy, capturing the essence of urban life."
|
96 |
+
|
97 |
+
```
|
98 |
+
|
99 |
+
|
100 |
+
#### Girl
|
101 |
+
|
102 |
+
```bash
|
103 |
+
python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 768 1024 --video_length 81 --infer_steps 20 \
|
104 |
+
--save_path save --output_type both \
|
105 |
+
--dit wan2.1_t2v_1.3B_bf16.safetensors --vae Wan2.1_VAE.pth \
|
106 |
+
--t5 models_t5_umt5-xxl-enc-bf16.pth \
|
107 |
+
--attn_mode torch \
|
108 |
+
--lora_weight pixel_outputs/pixel_w1_3_lora-000010.safetensors \
|
109 |
+
--lora_multiplier 1.0 \
|
110 |
+
--prompt "The video showcases a pixel art scene from a video game. .The video showcases a animation featuring charming anime-style scene featuring a pink-haired girl with angel wings. She's seated at a desk, enjoying a donut while working on a laptop. The setting is a cozy, pastel-colored room with a pink chair, a milk carton, and a coffee cup. The girl's expression is one of delight as she savors her treat."
|
111 |
+
|
112 |
+
```
|
113 |
+
|
114 |
+
#### Snow
|
115 |
+
|
116 |
+
```bash
|
117 |
+
python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 768 1024 --video_length 81 --infer_steps 20 \
|
118 |
+
--save_path save --output_type both \
|
119 |
+
--dit wan2.1_t2v_1.3B_bf16.safetensors --vae Wan2.1_VAE.pth \
|
120 |
+
--t5 models_t5_umt5-xxl-enc-bf16.pth \
|
121 |
+
--attn_mode torch \
|
122 |
+
--lora_weight pixel_outputs/pixel_w1_3_lora-000010.safetensors \
|
123 |
+
--lora_multiplier 1.0 \
|
124 |
+
--prompt "The video showcases a pixel art scene from a video game. .The video showcases a animation featuring a serene and majestic snowy mountain landscape. The scene is dominated by towering peaks covered in pristine white snow, with a soft gradient of blue and purple hues in the sky. A small cabin with a smoking chimney sits at the base of the mountain, surrounded by pine trees dusted with snow. A winding path leads up the mountain, with footprints visible in the snow. The atmosphere is calm and peaceful, evoking a sense of solitude and wonder."
|
125 |
+
|
126 |
+
```
|
127 |
+
|
128 |
+
|
129 |
+
## Parameters
|
130 |
+
|
131 |
+
* `--fp8`: Enable FP8 precision (optional).
|
132 |
+
* `--task`: Specify the task (e.g., `t2v-1.3B`).
|
133 |
+
* `--video_size`: Set the resolution of the generated video (e.g., `1024 1024`).
|
134 |
+
* `--video_length`: Define the length of the video in frames.
|
135 |
+
* `--infer_steps`: Number of inference steps.
|
136 |
+
* `--save_path`: Directory to save the generated video.
|
137 |
+
* `--output_type`: Output type (e.g., `both` for video and frames).
|
138 |
+
* `--dit`: Path to the diffusion model weights.
|
139 |
+
* `--vae`: Path to the VAE model weights.
|
140 |
+
* `--t5`: Path to the T5 model weights.
|
141 |
+
* `--attn_mode`: Attention mode (e.g., `torch`).
|
142 |
+
* `--lora_weight`: Path to the LoRA weights.
|
143 |
+
* `--lora_multiplier`: Multiplier for LoRA weights.
|
144 |
+
* `--prompt`: Textual prompt for video generation.
|
145 |
+
|
146 |
+
|
147 |
+
|
148 |
+
## Output
|
149 |
+
|
150 |
+
The generated video and frames will be saved in the specified `save_path` directory.
|
151 |
+
|
152 |
+
## Troubleshooting
|
153 |
+
|
154 |
+
• Ensure all dependencies are correctly installed.
|
155 |
+
• Verify that the model weights are downloaded and placed in the correct locations.
|
156 |
+
• Check for any missing Python packages and install them using `pip`.
|
157 |
+
|
158 |
+
## License
|
159 |
+
|
160 |
+
This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.
|
161 |
+
|
162 |
+
## Acknowledgments
|
163 |
+
|
164 |
+
• **Hugging Face** for hosting the model weights.
|
165 |
+
• **Wan-AI** for providing the pre-trained models.
|
166 |
+
• **DeepBeepMeep** for contributing to the model weights.
|
167 |
+
|
168 |
+
## Contact
|
169 |
+
|
170 |
+
For any questions or issues, please open an issue on the repository or contact the maintainer.
|
171 |
+
|
172 |
+
---
|
173 |
+
|
cache_latents.py
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import glob
|
4 |
+
from typing import Optional, Union
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from dataset import config_utils
|
11 |
+
from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
|
12 |
+
from PIL import Image
|
13 |
+
|
14 |
+
import logging
|
15 |
+
|
16 |
+
from dataset.image_video_dataset import BaseDataset, ItemInfo, save_latent_cache, ARCHITECTURE_HUNYUAN_VIDEO
|
17 |
+
from hunyuan_model.vae import load_vae
|
18 |
+
from hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
|
19 |
+
from utils.model_utils import str_to_dtype
|
20 |
+
|
21 |
+
logger = logging.getLogger(__name__)
|
22 |
+
logging.basicConfig(level=logging.INFO)
|
23 |
+
|
24 |
+
|
25 |
+
def show_image(image: Union[list[Union[Image.Image, np.ndarray], Union[Image.Image, np.ndarray]]]) -> int:
|
26 |
+
import cv2
|
27 |
+
|
28 |
+
imgs = (
|
29 |
+
[image]
|
30 |
+
if (isinstance(image, np.ndarray) and len(image.shape) == 3) or isinstance(image, Image.Image)
|
31 |
+
else [image[0], image[-1]]
|
32 |
+
)
|
33 |
+
if len(imgs) > 1:
|
34 |
+
print(f"Number of images: {len(image)}")
|
35 |
+
for i, img in enumerate(imgs):
|
36 |
+
if len(imgs) > 1:
|
37 |
+
print(f"{'First' if i == 0 else 'Last'} image: {img.shape}")
|
38 |
+
else:
|
39 |
+
print(f"Image: {img.shape}")
|
40 |
+
cv2_img = np.array(img) if isinstance(img, Image.Image) else img
|
41 |
+
cv2_img = cv2.cvtColor(cv2_img, cv2.COLOR_RGB2BGR)
|
42 |
+
cv2.imshow("image", cv2_img)
|
43 |
+
k = cv2.waitKey(0)
|
44 |
+
cv2.destroyAllWindows()
|
45 |
+
if k == ord("q") or k == ord("d"):
|
46 |
+
return k
|
47 |
+
return k
|
48 |
+
|
49 |
+
|
50 |
+
def show_console(
|
51 |
+
image: Union[list[Union[Image.Image, np.ndarray], Union[Image.Image, np.ndarray]]],
|
52 |
+
width: int,
|
53 |
+
back: str,
|
54 |
+
interactive: bool = False,
|
55 |
+
) -> int:
|
56 |
+
from ascii_magic import from_pillow_image, Back
|
57 |
+
|
58 |
+
back = None
|
59 |
+
if back is not None:
|
60 |
+
back = getattr(Back, back.upper())
|
61 |
+
|
62 |
+
k = None
|
63 |
+
imgs = (
|
64 |
+
[image]
|
65 |
+
if (isinstance(image, np.ndarray) and len(image.shape) == 3) or isinstance(image, Image.Image)
|
66 |
+
else [image[0], image[-1]]
|
67 |
+
)
|
68 |
+
if len(imgs) > 1:
|
69 |
+
print(f"Number of images: {len(image)}")
|
70 |
+
for i, img in enumerate(imgs):
|
71 |
+
if len(imgs) > 1:
|
72 |
+
print(f"{'First' if i == 0 else 'Last'} image: {img.shape}")
|
73 |
+
else:
|
74 |
+
print(f"Image: {img.shape}")
|
75 |
+
pil_img = img if isinstance(img, Image.Image) else Image.fromarray(img)
|
76 |
+
ascii_img = from_pillow_image(pil_img)
|
77 |
+
ascii_img.to_terminal(columns=width, back=back)
|
78 |
+
|
79 |
+
if interactive:
|
80 |
+
k = input("Press q to quit, d to next dataset, other key to next: ")
|
81 |
+
if k == "q" or k == "d":
|
82 |
+
return ord(k)
|
83 |
+
|
84 |
+
if not interactive:
|
85 |
+
return ord(" ")
|
86 |
+
return ord(k) if k else ord(" ")
|
87 |
+
|
88 |
+
|
89 |
+
def show_datasets(
|
90 |
+
datasets: list[BaseDataset], debug_mode: str, console_width: int, console_back: str, console_num_images: Optional[int]
|
91 |
+
):
|
92 |
+
print(f"d: next dataset, q: quit")
|
93 |
+
|
94 |
+
num_workers = max(1, os.cpu_count() - 1)
|
95 |
+
for i, dataset in enumerate(datasets):
|
96 |
+
print(f"Dataset [{i}]")
|
97 |
+
batch_index = 0
|
98 |
+
num_images_to_show = console_num_images
|
99 |
+
k = None
|
100 |
+
for key, batch in dataset.retrieve_latent_cache_batches(num_workers):
|
101 |
+
print(f"bucket resolution: {key}, count: {len(batch)}")
|
102 |
+
for j, item_info in enumerate(batch):
|
103 |
+
item_info: ItemInfo
|
104 |
+
print(f"{batch_index}-{j}: {item_info}")
|
105 |
+
if debug_mode == "image":
|
106 |
+
k = show_image(item_info.content)
|
107 |
+
elif debug_mode == "console":
|
108 |
+
k = show_console(item_info.content, console_width, console_back, console_num_images is None)
|
109 |
+
if num_images_to_show is not None:
|
110 |
+
num_images_to_show -= 1
|
111 |
+
if num_images_to_show == 0:
|
112 |
+
k = ord("d") # next dataset
|
113 |
+
|
114 |
+
if k == ord("q"):
|
115 |
+
return
|
116 |
+
elif k == ord("d"):
|
117 |
+
break
|
118 |
+
if k == ord("d"):
|
119 |
+
break
|
120 |
+
batch_index += 1
|
121 |
+
|
122 |
+
|
123 |
+
def encode_and_save_batch(vae: AutoencoderKLCausal3D, batch: list[ItemInfo]):
|
124 |
+
contents = torch.stack([torch.from_numpy(item.content) for item in batch])
|
125 |
+
if len(contents.shape) == 4:
|
126 |
+
contents = contents.unsqueeze(1) # B, H, W, C -> B, F, H, W, C
|
127 |
+
|
128 |
+
contents = contents.permute(0, 4, 1, 2, 3).contiguous() # B, C, F, H, W
|
129 |
+
contents = contents.to(vae.device, dtype=vae.dtype)
|
130 |
+
contents = contents / 127.5 - 1.0 # normalize to [-1, 1]
|
131 |
+
|
132 |
+
h, w = contents.shape[3], contents.shape[4]
|
133 |
+
if h < 8 or w < 8:
|
134 |
+
item = batch[0] # other items should have the same size
|
135 |
+
raise ValueError(f"Image or video size too small: {item.item_key} and {len(batch) - 1} more, size: {item.original_size}")
|
136 |
+
|
137 |
+
# print(f"encode batch: {contents.shape}")
|
138 |
+
with torch.no_grad():
|
139 |
+
latent = vae.encode(contents).latent_dist.sample()
|
140 |
+
# latent = latent * vae.config.scaling_factor
|
141 |
+
|
142 |
+
# # debug: decode and save
|
143 |
+
# with torch.no_grad():
|
144 |
+
# latent_to_decode = latent / vae.config.scaling_factor
|
145 |
+
# images = vae.decode(latent_to_decode, return_dict=False)[0]
|
146 |
+
# images = (images / 2 + 0.5).clamp(0, 1)
|
147 |
+
# images = images.cpu().float().numpy()
|
148 |
+
# images = (images * 255).astype(np.uint8)
|
149 |
+
# images = images.transpose(0, 2, 3, 4, 1) # B, C, F, H, W -> B, F, H, W, C
|
150 |
+
# for b in range(images.shape[0]):
|
151 |
+
# for f in range(images.shape[1]):
|
152 |
+
# fln = os.path.splitext(os.path.basename(batch[b].item_key))[0]
|
153 |
+
# img = Image.fromarray(images[b, f])
|
154 |
+
# img.save(f"./logs/decode_{fln}_{b}_{f:03d}.jpg")
|
155 |
+
|
156 |
+
for item, l in zip(batch, latent):
|
157 |
+
# print(f"save latent cache: {item.latent_cache_path}, latent shape: {l.shape}")
|
158 |
+
save_latent_cache(item, l)
|
159 |
+
|
160 |
+
|
161 |
+
def encode_datasets(datasets: list[BaseDataset], encode: callable, args: argparse.Namespace):
|
162 |
+
num_workers = args.num_workers if args.num_workers is not None else max(1, os.cpu_count() - 1)
|
163 |
+
for i, dataset in enumerate(datasets):
|
164 |
+
logger.info(f"Encoding dataset [{i}]")
|
165 |
+
all_latent_cache_paths = []
|
166 |
+
for _, batch in tqdm(dataset.retrieve_latent_cache_batches(num_workers)):
|
167 |
+
all_latent_cache_paths.extend([item.latent_cache_path for item in batch])
|
168 |
+
|
169 |
+
if args.skip_existing:
|
170 |
+
filtered_batch = [item for item in batch if not os.path.exists(item.latent_cache_path)]
|
171 |
+
if len(filtered_batch) == 0:
|
172 |
+
continue
|
173 |
+
batch = filtered_batch
|
174 |
+
|
175 |
+
bs = args.batch_size if args.batch_size is not None else len(batch)
|
176 |
+
for i in range(0, len(batch), bs):
|
177 |
+
encode(batch[i : i + bs])
|
178 |
+
|
179 |
+
# normalize paths
|
180 |
+
all_latent_cache_paths = [os.path.normpath(p) for p in all_latent_cache_paths]
|
181 |
+
all_latent_cache_paths = set(all_latent_cache_paths)
|
182 |
+
|
183 |
+
# remove old cache files not in the dataset
|
184 |
+
all_cache_files = dataset.get_all_latent_cache_files()
|
185 |
+
for cache_file in all_cache_files:
|
186 |
+
if os.path.normpath(cache_file) not in all_latent_cache_paths:
|
187 |
+
if args.keep_cache:
|
188 |
+
logger.info(f"Keep cache file not in the dataset: {cache_file}")
|
189 |
+
else:
|
190 |
+
os.remove(cache_file)
|
191 |
+
logger.info(f"Removed old cache file: {cache_file}")
|
192 |
+
|
193 |
+
|
194 |
+
def main(args):
|
195 |
+
device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
|
196 |
+
device = torch.device(device)
|
197 |
+
|
198 |
+
# Load dataset config
|
199 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer())
|
200 |
+
logger.info(f"Load dataset config from {args.dataset_config}")
|
201 |
+
user_config = config_utils.load_user_config(args.dataset_config)
|
202 |
+
blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_HUNYUAN_VIDEO)
|
203 |
+
train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
204 |
+
|
205 |
+
datasets = train_dataset_group.datasets
|
206 |
+
|
207 |
+
if args.debug_mode is not None:
|
208 |
+
show_datasets(datasets, args.debug_mode, args.console_width, args.console_back, args.console_num_images)
|
209 |
+
return
|
210 |
+
|
211 |
+
assert args.vae is not None, "vae checkpoint is required"
|
212 |
+
|
213 |
+
# Load VAE model: HunyuanVideo VAE model is float16
|
214 |
+
vae_dtype = torch.float16 if args.vae_dtype is None else str_to_dtype(args.vae_dtype)
|
215 |
+
vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device=device, vae_path=args.vae)
|
216 |
+
vae.eval()
|
217 |
+
logger.info(f"Loaded VAE: {vae.config}, dtype: {vae.dtype}")
|
218 |
+
|
219 |
+
if args.vae_chunk_size is not None:
|
220 |
+
vae.set_chunk_size_for_causal_conv_3d(args.vae_chunk_size)
|
221 |
+
logger.info(f"Set chunk_size to {args.vae_chunk_size} for CausalConv3d in VAE")
|
222 |
+
if args.vae_spatial_tile_sample_min_size is not None:
|
223 |
+
vae.enable_spatial_tiling(True)
|
224 |
+
vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size
|
225 |
+
vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8
|
226 |
+
elif args.vae_tiling:
|
227 |
+
vae.enable_spatial_tiling(True)
|
228 |
+
|
229 |
+
# Encode images
|
230 |
+
def encode(one_batch: list[ItemInfo]):
|
231 |
+
encode_and_save_batch(vae, one_batch)
|
232 |
+
|
233 |
+
encode_datasets(datasets, encode, args)
|
234 |
+
|
235 |
+
|
236 |
+
def setup_parser_common() -> argparse.ArgumentParser:
|
237 |
+
parser = argparse.ArgumentParser()
|
238 |
+
|
239 |
+
parser.add_argument("--dataset_config", type=str, required=True, help="path to dataset config .toml file")
|
240 |
+
parser.add_argument("--vae", type=str, required=False, default=None, help="path to vae checkpoint")
|
241 |
+
parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16")
|
242 |
+
parser.add_argument("--device", type=str, default=None, help="device to use, default is cuda if available")
|
243 |
+
parser.add_argument(
|
244 |
+
"--batch_size", type=int, default=None, help="batch size, override dataset config if dataset batch size > this"
|
245 |
+
)
|
246 |
+
parser.add_argument("--num_workers", type=int, default=None, help="number of workers for dataset. default is cpu count-1")
|
247 |
+
parser.add_argument("--skip_existing", action="store_true", help="skip existing cache files")
|
248 |
+
parser.add_argument("--keep_cache", action="store_true", help="keep cache files not in dataset")
|
249 |
+
parser.add_argument("--debug_mode", type=str, default=None, choices=["image", "console"], help="debug mode")
|
250 |
+
parser.add_argument("--console_width", type=int, default=80, help="debug mode: console width")
|
251 |
+
parser.add_argument(
|
252 |
+
"--console_back", type=str, default=None, help="debug mode: console background color, one of ascii_magic.Back"
|
253 |
+
)
|
254 |
+
parser.add_argument(
|
255 |
+
"--console_num_images",
|
256 |
+
type=int,
|
257 |
+
default=None,
|
258 |
+
help="debug mode: not interactive, number of images to show for each dataset",
|
259 |
+
)
|
260 |
+
return parser
|
261 |
+
|
262 |
+
|
263 |
+
def hv_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
264 |
+
parser.add_argument(
|
265 |
+
"--vae_tiling",
|
266 |
+
action="store_true",
|
267 |
+
help="enable spatial tiling for VAE, default is False. If vae_spatial_tile_sample_min_size is set, this is automatically enabled",
|
268 |
+
)
|
269 |
+
parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
|
270 |
+
parser.add_argument(
|
271 |
+
"--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
|
272 |
+
)
|
273 |
+
return parser
|
274 |
+
|
275 |
+
|
276 |
+
if __name__ == "__main__":
|
277 |
+
parser = setup_parser_common()
|
278 |
+
parser = hv_setup_parser(parser)
|
279 |
+
|
280 |
+
args = parser.parse_args()
|
281 |
+
main(args)
|
cache_text_encoder_outputs.py
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
from typing import Optional, Union
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
from dataset import config_utils
|
10 |
+
from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
|
11 |
+
import accelerate
|
12 |
+
|
13 |
+
from dataset.image_video_dataset import ARCHITECTURE_HUNYUAN_VIDEO, BaseDataset, ItemInfo, save_text_encoder_output_cache
|
14 |
+
from hunyuan_model import text_encoder as text_encoder_module
|
15 |
+
from hunyuan_model.text_encoder import TextEncoder
|
16 |
+
|
17 |
+
import logging
|
18 |
+
|
19 |
+
from utils.model_utils import str_to_dtype
|
20 |
+
|
21 |
+
logger = logging.getLogger(__name__)
|
22 |
+
logging.basicConfig(level=logging.INFO)
|
23 |
+
|
24 |
+
|
25 |
+
def encode_prompt(text_encoder: TextEncoder, prompt: Union[str, list[str]]):
|
26 |
+
data_type = "video" # video only, image is not supported
|
27 |
+
text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)
|
28 |
+
|
29 |
+
with torch.no_grad():
|
30 |
+
prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type)
|
31 |
+
|
32 |
+
return prompt_outputs.hidden_state, prompt_outputs.attention_mask
|
33 |
+
|
34 |
+
|
35 |
+
def encode_and_save_batch(
|
36 |
+
text_encoder: TextEncoder, batch: list[ItemInfo], is_llm: bool, accelerator: Optional[accelerate.Accelerator]
|
37 |
+
):
|
38 |
+
prompts = [item.caption for item in batch]
|
39 |
+
# print(prompts)
|
40 |
+
|
41 |
+
# encode prompt
|
42 |
+
if accelerator is not None:
|
43 |
+
with accelerator.autocast():
|
44 |
+
prompt_embeds, prompt_mask = encode_prompt(text_encoder, prompts)
|
45 |
+
else:
|
46 |
+
prompt_embeds, prompt_mask = encode_prompt(text_encoder, prompts)
|
47 |
+
|
48 |
+
# # convert to fp16 if needed
|
49 |
+
# if prompt_embeds.dtype == torch.float32 and text_encoder.dtype != torch.float32:
|
50 |
+
# prompt_embeds = prompt_embeds.to(text_encoder.dtype)
|
51 |
+
|
52 |
+
# save prompt cache
|
53 |
+
for item, embed, mask in zip(batch, prompt_embeds, prompt_mask):
|
54 |
+
save_text_encoder_output_cache(item, embed, mask, is_llm)
|
55 |
+
|
56 |
+
|
57 |
+
def prepare_cache_files_and_paths(datasets: list[BaseDataset]):
|
58 |
+
all_cache_files_for_dataset = [] # exisiting cache files
|
59 |
+
all_cache_paths_for_dataset = [] # all cache paths in the dataset
|
60 |
+
for dataset in datasets:
|
61 |
+
all_cache_files = [os.path.normpath(file) for file in dataset.get_all_text_encoder_output_cache_files()]
|
62 |
+
all_cache_files = set(all_cache_files)
|
63 |
+
all_cache_files_for_dataset.append(all_cache_files)
|
64 |
+
|
65 |
+
all_cache_paths_for_dataset.append(set())
|
66 |
+
return all_cache_files_for_dataset, all_cache_paths_for_dataset
|
67 |
+
|
68 |
+
|
69 |
+
def process_text_encoder_batches(
|
70 |
+
num_workers: Optional[int],
|
71 |
+
skip_existing: bool,
|
72 |
+
batch_size: int,
|
73 |
+
datasets: list[BaseDataset],
|
74 |
+
all_cache_files_for_dataset: list[set],
|
75 |
+
all_cache_paths_for_dataset: list[set],
|
76 |
+
encode: callable,
|
77 |
+
):
|
78 |
+
num_workers = num_workers if num_workers is not None else max(1, os.cpu_count() - 1)
|
79 |
+
for i, dataset in enumerate(datasets):
|
80 |
+
logger.info(f"Encoding dataset [{i}]")
|
81 |
+
all_cache_files = all_cache_files_for_dataset[i]
|
82 |
+
all_cache_paths = all_cache_paths_for_dataset[i]
|
83 |
+
for batch in tqdm(dataset.retrieve_text_encoder_output_cache_batches(num_workers)):
|
84 |
+
# update cache files (it's ok if we update it multiple times)
|
85 |
+
all_cache_paths.update([os.path.normpath(item.text_encoder_output_cache_path) for item in batch])
|
86 |
+
|
87 |
+
# skip existing cache files
|
88 |
+
if skip_existing:
|
89 |
+
filtered_batch = [
|
90 |
+
item for item in batch if not os.path.normpath(item.text_encoder_output_cache_path) in all_cache_files
|
91 |
+
]
|
92 |
+
# print(f"Filtered {len(batch) - len(filtered_batch)} existing cache files")
|
93 |
+
if len(filtered_batch) == 0:
|
94 |
+
continue
|
95 |
+
batch = filtered_batch
|
96 |
+
|
97 |
+
bs = batch_size if batch_size is not None else len(batch)
|
98 |
+
for i in range(0, len(batch), bs):
|
99 |
+
encode(batch[i : i + bs])
|
100 |
+
|
101 |
+
|
102 |
+
def post_process_cache_files(
|
103 |
+
datasets: list[BaseDataset], all_cache_files_for_dataset: list[set], all_cache_paths_for_dataset: list[set]
|
104 |
+
):
|
105 |
+
for i, dataset in enumerate(datasets):
|
106 |
+
all_cache_files = all_cache_files_for_dataset[i]
|
107 |
+
all_cache_paths = all_cache_paths_for_dataset[i]
|
108 |
+
for cache_file in all_cache_files:
|
109 |
+
if cache_file not in all_cache_paths:
|
110 |
+
if args.keep_cache:
|
111 |
+
logger.info(f"Keep cache file not in the dataset: {cache_file}")
|
112 |
+
else:
|
113 |
+
os.remove(cache_file)
|
114 |
+
logger.info(f"Removed old cache file: {cache_file}")
|
115 |
+
|
116 |
+
|
117 |
+
def main(args):
|
118 |
+
device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
|
119 |
+
device = torch.device(device)
|
120 |
+
|
121 |
+
# Load dataset config
|
122 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer())
|
123 |
+
logger.info(f"Load dataset config from {args.dataset_config}")
|
124 |
+
user_config = config_utils.load_user_config(args.dataset_config)
|
125 |
+
blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_HUNYUAN_VIDEO)
|
126 |
+
train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
127 |
+
|
128 |
+
datasets = train_dataset_group.datasets
|
129 |
+
|
130 |
+
# define accelerator for fp8 inference
|
131 |
+
accelerator = None
|
132 |
+
if args.fp8_llm:
|
133 |
+
accelerator = accelerate.Accelerator(mixed_precision="fp16")
|
134 |
+
|
135 |
+
# prepare cache files and paths: all_cache_files_for_dataset = exisiting cache files, all_cache_paths_for_dataset = all cache paths in the dataset
|
136 |
+
all_cache_files_for_dataset, all_cache_paths_for_dataset = prepare_cache_files_and_paths(datasets)
|
137 |
+
|
138 |
+
# Load Text Encoder 1
|
139 |
+
text_encoder_dtype = torch.float16 if args.text_encoder_dtype is None else str_to_dtype(args.text_encoder_dtype)
|
140 |
+
logger.info(f"loading text encoder 1: {args.text_encoder1}")
|
141 |
+
text_encoder_1 = text_encoder_module.load_text_encoder_1(args.text_encoder1, device, args.fp8_llm, text_encoder_dtype)
|
142 |
+
text_encoder_1.to(device=device)
|
143 |
+
|
144 |
+
# Encode with Text Encoder 1 (LLM)
|
145 |
+
logger.info("Encoding with Text Encoder 1")
|
146 |
+
|
147 |
+
def encode_for_text_encoder_1(batch: list[ItemInfo]):
|
148 |
+
encode_and_save_batch(text_encoder_1, batch, is_llm=True, accelerator=accelerator)
|
149 |
+
|
150 |
+
process_text_encoder_batches(
|
151 |
+
args.num_workers,
|
152 |
+
args.skip_existing,
|
153 |
+
args.batch_size,
|
154 |
+
datasets,
|
155 |
+
all_cache_files_for_dataset,
|
156 |
+
all_cache_paths_for_dataset,
|
157 |
+
encode_for_text_encoder_1,
|
158 |
+
)
|
159 |
+
del text_encoder_1
|
160 |
+
|
161 |
+
# Load Text Encoder 2
|
162 |
+
logger.info(f"loading text encoder 2: {args.text_encoder2}")
|
163 |
+
text_encoder_2 = text_encoder_module.load_text_encoder_2(args.text_encoder2, device, text_encoder_dtype)
|
164 |
+
text_encoder_2.to(device=device)
|
165 |
+
|
166 |
+
# Encode with Text Encoder 2
|
167 |
+
logger.info("Encoding with Text Encoder 2")
|
168 |
+
|
169 |
+
def encode_for_text_encoder_2(batch: list[ItemInfo]):
|
170 |
+
encode_and_save_batch(text_encoder_2, batch, is_llm=False, accelerator=None)
|
171 |
+
|
172 |
+
process_text_encoder_batches(
|
173 |
+
args.num_workers,
|
174 |
+
args.skip_existing,
|
175 |
+
args.batch_size,
|
176 |
+
datasets,
|
177 |
+
all_cache_files_for_dataset,
|
178 |
+
all_cache_paths_for_dataset,
|
179 |
+
encode_for_text_encoder_2,
|
180 |
+
)
|
181 |
+
del text_encoder_2
|
182 |
+
|
183 |
+
# remove cache files not in dataset
|
184 |
+
post_process_cache_files(datasets, all_cache_files_for_dataset, all_cache_paths_for_dataset)
|
185 |
+
|
186 |
+
|
187 |
+
def setup_parser_common():
|
188 |
+
parser = argparse.ArgumentParser()
|
189 |
+
|
190 |
+
parser.add_argument("--dataset_config", type=str, required=True, help="path to dataset config .toml file")
|
191 |
+
parser.add_argument("--device", type=str, default=None, help="device to use, default is cuda if available")
|
192 |
+
parser.add_argument(
|
193 |
+
"--batch_size", type=int, default=None, help="batch size, override dataset config if dataset batch size > this"
|
194 |
+
)
|
195 |
+
parser.add_argument("--num_workers", type=int, default=None, help="number of workers for dataset. default is cpu count-1")
|
196 |
+
parser.add_argument("--skip_existing", action="store_true", help="skip existing cache files")
|
197 |
+
parser.add_argument("--keep_cache", action="store_true", help="keep cache files not in dataset")
|
198 |
+
return parser
|
199 |
+
|
200 |
+
|
201 |
+
def hv_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
202 |
+
parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory")
|
203 |
+
parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory")
|
204 |
+
parser.add_argument("--text_encoder_dtype", type=str, default=None, help="data type for Text Encoder, default is float16")
|
205 |
+
parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
|
206 |
+
return parser
|
207 |
+
|
208 |
+
|
209 |
+
if __name__ == "__main__":
|
210 |
+
parser = setup_parser_common()
|
211 |
+
parser = hv_setup_parser(parser)
|
212 |
+
|
213 |
+
args = parser.parse_args()
|
214 |
+
main(args)
|
convert_lora.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from safetensors.torch import load_file, save_file
|
5 |
+
from safetensors import safe_open
|
6 |
+
from utils import model_utils
|
7 |
+
|
8 |
+
import logging
|
9 |
+
|
10 |
+
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
logging.basicConfig(level=logging.INFO)
|
13 |
+
|
14 |
+
|
15 |
+
def convert_from_diffusers(prefix, weights_sd):
|
16 |
+
# convert from diffusers(?) to default LoRA
|
17 |
+
# Diffusers format: {"diffusion_model.module.name.lora_A.weight": weight, "diffusion_model.module.name.lora_B.weight": weight, ...}
|
18 |
+
# default LoRA format: {"prefix_module_name.lora_down.weight": weight, "prefix_module_name.lora_up.weight": weight, ...}
|
19 |
+
|
20 |
+
# note: Diffusers has no alpha, so alpha is set to rank
|
21 |
+
new_weights_sd = {}
|
22 |
+
lora_dims = {}
|
23 |
+
for key, weight in weights_sd.items():
|
24 |
+
diffusers_prefix, key_body = key.split(".", 1)
|
25 |
+
if diffusers_prefix != "diffusion_model" and diffusers_prefix != "transformer":
|
26 |
+
logger.warning(f"unexpected key: {key} in diffusers format")
|
27 |
+
continue
|
28 |
+
|
29 |
+
new_key = f"{prefix}{key_body}".replace(".", "_").replace("_lora_A_", ".lora_down.").replace("_lora_B_", ".lora_up.")
|
30 |
+
new_weights_sd[new_key] = weight
|
31 |
+
|
32 |
+
lora_name = new_key.split(".")[0] # before first dot
|
33 |
+
if lora_name not in lora_dims and "lora_down" in new_key:
|
34 |
+
lora_dims[lora_name] = weight.shape[0]
|
35 |
+
|
36 |
+
# add alpha with rank
|
37 |
+
for lora_name, dim in lora_dims.items():
|
38 |
+
new_weights_sd[f"{lora_name}.alpha"] = torch.tensor(dim)
|
39 |
+
|
40 |
+
return new_weights_sd
|
41 |
+
|
42 |
+
|
43 |
+
def convert_to_diffusers(prefix, weights_sd):
|
44 |
+
# convert from default LoRA to diffusers
|
45 |
+
|
46 |
+
# get alphas
|
47 |
+
lora_alphas = {}
|
48 |
+
for key, weight in weights_sd.items():
|
49 |
+
if key.startswith(prefix):
|
50 |
+
lora_name = key.split(".", 1)[0] # before first dot
|
51 |
+
if lora_name not in lora_alphas and "alpha" in key:
|
52 |
+
lora_alphas[lora_name] = weight
|
53 |
+
|
54 |
+
new_weights_sd = {}
|
55 |
+
for key, weight in weights_sd.items():
|
56 |
+
if key.startswith(prefix):
|
57 |
+
if "alpha" in key:
|
58 |
+
continue
|
59 |
+
|
60 |
+
lora_name = key.split(".", 1)[0] # before first dot
|
61 |
+
|
62 |
+
module_name = lora_name[len(prefix) :] # remove "lora_unet_"
|
63 |
+
module_name = module_name.replace("_", ".") # replace "_" with "."
|
64 |
+
if ".cross.attn." in module_name or ".self.attn." in module_name:
|
65 |
+
# Wan2.1 lora name to module name: ugly but works
|
66 |
+
module_name = module_name.replace("cross.attn", "cross_attn") # fix cross attn
|
67 |
+
module_name = module_name.replace("self.attn", "self_attn") # fix self attn
|
68 |
+
else:
|
69 |
+
# HunyuanVideo lora name to module name: ugly but works
|
70 |
+
module_name = module_name.replace("double.blocks.", "double_blocks.") # fix double blocks
|
71 |
+
module_name = module_name.replace("single.blocks.", "single_blocks.") # fix single blocks
|
72 |
+
module_name = module_name.replace("img.", "img_") # fix img
|
73 |
+
module_name = module_name.replace("txt.", "txt_") # fix txt
|
74 |
+
module_name = module_name.replace("attn.", "attn_") # fix attn
|
75 |
+
|
76 |
+
diffusers_prefix = "diffusion_model"
|
77 |
+
if "lora_down" in key:
|
78 |
+
new_key = f"{diffusers_prefix}.{module_name}.lora_A.weight"
|
79 |
+
dim = weight.shape[0]
|
80 |
+
elif "lora_up" in key:
|
81 |
+
new_key = f"{diffusers_prefix}.{module_name}.lora_B.weight"
|
82 |
+
dim = weight.shape[1]
|
83 |
+
else:
|
84 |
+
logger.warning(f"unexpected key: {key} in default LoRA format")
|
85 |
+
continue
|
86 |
+
|
87 |
+
# scale weight by alpha
|
88 |
+
if lora_name in lora_alphas:
|
89 |
+
# we scale both down and up, so scale is sqrt
|
90 |
+
scale = lora_alphas[lora_name] / dim
|
91 |
+
scale = scale.sqrt()
|
92 |
+
weight = weight * scale
|
93 |
+
else:
|
94 |
+
logger.warning(f"missing alpha for {lora_name}")
|
95 |
+
|
96 |
+
new_weights_sd[new_key] = weight
|
97 |
+
|
98 |
+
return new_weights_sd
|
99 |
+
|
100 |
+
|
101 |
+
def convert(input_file, output_file, target_format):
|
102 |
+
logger.info(f"loading {input_file}")
|
103 |
+
weights_sd = load_file(input_file)
|
104 |
+
with safe_open(input_file, framework="pt") as f:
|
105 |
+
metadata = f.metadata()
|
106 |
+
|
107 |
+
logger.info(f"converting to {target_format}")
|
108 |
+
prefix = "lora_unet_"
|
109 |
+
if target_format == "default":
|
110 |
+
new_weights_sd = convert_from_diffusers(prefix, weights_sd)
|
111 |
+
metadata = metadata or {}
|
112 |
+
model_utils.precalculate_safetensors_hashes(new_weights_sd, metadata)
|
113 |
+
elif target_format == "other":
|
114 |
+
new_weights_sd = convert_to_diffusers(prefix, weights_sd)
|
115 |
+
else:
|
116 |
+
raise ValueError(f"unknown target format: {target_format}")
|
117 |
+
|
118 |
+
logger.info(f"saving to {output_file}")
|
119 |
+
save_file(new_weights_sd, output_file, metadata=metadata)
|
120 |
+
|
121 |
+
logger.info("done")
|
122 |
+
|
123 |
+
|
124 |
+
def parse_args():
|
125 |
+
parser = argparse.ArgumentParser(description="Convert LoRA weights between default and other formats")
|
126 |
+
parser.add_argument("--input", type=str, required=True, help="input model file")
|
127 |
+
parser.add_argument("--output", type=str, required=True, help="output model file")
|
128 |
+
parser.add_argument("--target", type=str, required=True, choices=["other", "default"], help="target format")
|
129 |
+
args = parser.parse_args()
|
130 |
+
return args
|
131 |
+
|
132 |
+
|
133 |
+
if __name__ == "__main__":
|
134 |
+
args = parse_args()
|
135 |
+
convert(args.input, args.output, args.target)
|
dataset/__init__.py
ADDED
File without changes
|
dataset/config_utils.py
ADDED
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from dataclasses import (
|
3 |
+
asdict,
|
4 |
+
dataclass,
|
5 |
+
)
|
6 |
+
import functools
|
7 |
+
import random
|
8 |
+
from textwrap import dedent, indent
|
9 |
+
import json
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
# from toolz import curry
|
13 |
+
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
14 |
+
|
15 |
+
import toml
|
16 |
+
import voluptuous
|
17 |
+
from voluptuous import Any, ExactSequence, MultipleInvalid, Object, Schema
|
18 |
+
|
19 |
+
from .image_video_dataset import DatasetGroup, ImageDataset, VideoDataset
|
20 |
+
|
21 |
+
import logging
|
22 |
+
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
logging.basicConfig(level=logging.INFO)
|
25 |
+
|
26 |
+
|
27 |
+
@dataclass
|
28 |
+
class BaseDatasetParams:
|
29 |
+
resolution: Tuple[int, int] = (960, 544)
|
30 |
+
enable_bucket: bool = False
|
31 |
+
bucket_no_upscale: bool = False
|
32 |
+
caption_extension: Optional[str] = None
|
33 |
+
batch_size: int = 1
|
34 |
+
num_repeats: int = 1
|
35 |
+
cache_directory: Optional[str] = None
|
36 |
+
debug_dataset: bool = False
|
37 |
+
architecture: str = "no_default" # short style like "hv" or "wan"
|
38 |
+
|
39 |
+
|
40 |
+
@dataclass
|
41 |
+
class ImageDatasetParams(BaseDatasetParams):
|
42 |
+
image_directory: Optional[str] = None
|
43 |
+
image_jsonl_file: Optional[str] = None
|
44 |
+
|
45 |
+
|
46 |
+
@dataclass
|
47 |
+
class VideoDatasetParams(BaseDatasetParams):
|
48 |
+
video_directory: Optional[str] = None
|
49 |
+
video_jsonl_file: Optional[str] = None
|
50 |
+
target_frames: Sequence[int] = (1,)
|
51 |
+
frame_extraction: Optional[str] = "head"
|
52 |
+
frame_stride: Optional[int] = 1
|
53 |
+
frame_sample: Optional[int] = 1
|
54 |
+
|
55 |
+
|
56 |
+
@dataclass
|
57 |
+
class DatasetBlueprint:
|
58 |
+
is_image_dataset: bool
|
59 |
+
params: Union[ImageDatasetParams, VideoDatasetParams]
|
60 |
+
|
61 |
+
|
62 |
+
@dataclass
|
63 |
+
class DatasetGroupBlueprint:
|
64 |
+
datasets: Sequence[DatasetBlueprint]
|
65 |
+
|
66 |
+
|
67 |
+
@dataclass
|
68 |
+
class Blueprint:
|
69 |
+
dataset_group: DatasetGroupBlueprint
|
70 |
+
|
71 |
+
|
72 |
+
class ConfigSanitizer:
|
73 |
+
# @curry
|
74 |
+
@staticmethod
|
75 |
+
def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
|
76 |
+
Schema(ExactSequence([klass, klass]))(value)
|
77 |
+
return tuple(value)
|
78 |
+
|
79 |
+
# @curry
|
80 |
+
@staticmethod
|
81 |
+
def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
|
82 |
+
Schema(Any(klass, ExactSequence([klass, klass])))(value)
|
83 |
+
try:
|
84 |
+
Schema(klass)(value)
|
85 |
+
return (value, value)
|
86 |
+
except:
|
87 |
+
return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
|
88 |
+
|
89 |
+
# datasets schema
|
90 |
+
DATASET_ASCENDABLE_SCHEMA = {
|
91 |
+
"caption_extension": str,
|
92 |
+
"batch_size": int,
|
93 |
+
"num_repeats": int,
|
94 |
+
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
95 |
+
"enable_bucket": bool,
|
96 |
+
"bucket_no_upscale": bool,
|
97 |
+
}
|
98 |
+
IMAGE_DATASET_DISTINCT_SCHEMA = {
|
99 |
+
"image_directory": str,
|
100 |
+
"image_jsonl_file": str,
|
101 |
+
"cache_directory": str,
|
102 |
+
}
|
103 |
+
VIDEO_DATASET_DISTINCT_SCHEMA = {
|
104 |
+
"video_directory": str,
|
105 |
+
"video_jsonl_file": str,
|
106 |
+
"target_frames": [int],
|
107 |
+
"frame_extraction": str,
|
108 |
+
"frame_stride": int,
|
109 |
+
"frame_sample": int,
|
110 |
+
"cache_directory": str,
|
111 |
+
}
|
112 |
+
|
113 |
+
# options handled by argparse but not handled by user config
|
114 |
+
ARGPARSE_SPECIFIC_SCHEMA = {
|
115 |
+
"debug_dataset": bool,
|
116 |
+
}
|
117 |
+
|
118 |
+
def __init__(self) -> None:
|
119 |
+
self.image_dataset_schema = self.__merge_dict(
|
120 |
+
self.DATASET_ASCENDABLE_SCHEMA,
|
121 |
+
self.IMAGE_DATASET_DISTINCT_SCHEMA,
|
122 |
+
)
|
123 |
+
self.video_dataset_schema = self.__merge_dict(
|
124 |
+
self.DATASET_ASCENDABLE_SCHEMA,
|
125 |
+
self.VIDEO_DATASET_DISTINCT_SCHEMA,
|
126 |
+
)
|
127 |
+
|
128 |
+
def validate_flex_dataset(dataset_config: dict):
|
129 |
+
if "target_frames" in dataset_config:
|
130 |
+
return Schema(self.video_dataset_schema)(dataset_config)
|
131 |
+
else:
|
132 |
+
return Schema(self.image_dataset_schema)(dataset_config)
|
133 |
+
|
134 |
+
self.dataset_schema = validate_flex_dataset
|
135 |
+
|
136 |
+
self.general_schema = self.__merge_dict(
|
137 |
+
self.DATASET_ASCENDABLE_SCHEMA,
|
138 |
+
)
|
139 |
+
self.user_config_validator = Schema(
|
140 |
+
{
|
141 |
+
"general": self.general_schema,
|
142 |
+
"datasets": [self.dataset_schema],
|
143 |
+
}
|
144 |
+
)
|
145 |
+
self.argparse_schema = self.__merge_dict(
|
146 |
+
self.ARGPARSE_SPECIFIC_SCHEMA,
|
147 |
+
)
|
148 |
+
self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)
|
149 |
+
|
150 |
+
def sanitize_user_config(self, user_config: dict) -> dict:
|
151 |
+
try:
|
152 |
+
return self.user_config_validator(user_config)
|
153 |
+
except MultipleInvalid:
|
154 |
+
# TODO: clarify the error message
|
155 |
+
logger.error("Invalid user config / ユーザ設定の形式が正しくないようです")
|
156 |
+
raise
|
157 |
+
|
158 |
+
# NOTE: In nature, argument parser result is not needed to be sanitize
|
159 |
+
# However this will help us to detect program bug
|
160 |
+
def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
|
161 |
+
try:
|
162 |
+
return self.argparse_config_validator(argparse_namespace)
|
163 |
+
except MultipleInvalid:
|
164 |
+
# XXX: this should be a bug
|
165 |
+
logger.error(
|
166 |
+
"Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
|
167 |
+
)
|
168 |
+
raise
|
169 |
+
|
170 |
+
# NOTE: value would be overwritten by latter dict if there is already the same key
|
171 |
+
@staticmethod
|
172 |
+
def __merge_dict(*dict_list: dict) -> dict:
|
173 |
+
merged = {}
|
174 |
+
for schema in dict_list:
|
175 |
+
# merged |= schema
|
176 |
+
for k, v in schema.items():
|
177 |
+
merged[k] = v
|
178 |
+
return merged
|
179 |
+
|
180 |
+
|
181 |
+
class BlueprintGenerator:
|
182 |
+
BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {}
|
183 |
+
|
184 |
+
def __init__(self, sanitizer: ConfigSanitizer):
|
185 |
+
self.sanitizer = sanitizer
|
186 |
+
|
187 |
+
# runtime_params is for parameters which is only configurable on runtime, such as tokenizer
|
188 |
+
def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
|
189 |
+
sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
|
190 |
+
sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)
|
191 |
+
|
192 |
+
argparse_config = {k: v for k, v in vars(sanitized_argparse_namespace).items() if v is not None}
|
193 |
+
general_config = sanitized_user_config.get("general", {})
|
194 |
+
|
195 |
+
dataset_blueprints = []
|
196 |
+
for dataset_config in sanitized_user_config.get("datasets", []):
|
197 |
+
is_image_dataset = "target_frames" not in dataset_config
|
198 |
+
if is_image_dataset:
|
199 |
+
dataset_params_klass = ImageDatasetParams
|
200 |
+
else:
|
201 |
+
dataset_params_klass = VideoDatasetParams
|
202 |
+
|
203 |
+
params = self.generate_params_by_fallbacks(
|
204 |
+
dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]
|
205 |
+
)
|
206 |
+
dataset_blueprints.append(DatasetBlueprint(is_image_dataset, params))
|
207 |
+
|
208 |
+
dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
|
209 |
+
|
210 |
+
return Blueprint(dataset_group_blueprint)
|
211 |
+
|
212 |
+
@staticmethod
|
213 |
+
def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
|
214 |
+
name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
|
215 |
+
search_value = BlueprintGenerator.search_value
|
216 |
+
default_params = asdict(param_klass())
|
217 |
+
param_names = default_params.keys()
|
218 |
+
|
219 |
+
params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names}
|
220 |
+
|
221 |
+
return param_klass(**params)
|
222 |
+
|
223 |
+
@staticmethod
|
224 |
+
def search_value(key: str, fallbacks: Sequence[dict], default_value=None):
|
225 |
+
for cand in fallbacks:
|
226 |
+
value = cand.get(key)
|
227 |
+
if value is not None:
|
228 |
+
return value
|
229 |
+
|
230 |
+
return default_value
|
231 |
+
|
232 |
+
|
233 |
+
# if training is True, it will return a dataset group for training, otherwise for caching
|
234 |
+
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint, training: bool = False) -> DatasetGroup:
|
235 |
+
datasets: List[Union[ImageDataset, VideoDataset]] = []
|
236 |
+
|
237 |
+
for dataset_blueprint in dataset_group_blueprint.datasets:
|
238 |
+
if dataset_blueprint.is_image_dataset:
|
239 |
+
dataset_klass = ImageDataset
|
240 |
+
else:
|
241 |
+
dataset_klass = VideoDataset
|
242 |
+
|
243 |
+
dataset = dataset_klass(**asdict(dataset_blueprint.params))
|
244 |
+
datasets.append(dataset)
|
245 |
+
|
246 |
+
# assertion
|
247 |
+
cache_directories = [dataset.cache_directory for dataset in datasets]
|
248 |
+
num_of_unique_cache_directories = len(set(cache_directories))
|
249 |
+
if num_of_unique_cache_directories != len(cache_directories):
|
250 |
+
raise ValueError(
|
251 |
+
"cache directory should be unique for each dataset (note that cache directory is image/video directory if not specified)"
|
252 |
+
+ " / cache directory は各データセットごとに異なる必要があります(指定されていない場合はimage/video directoryが使われるので注意)"
|
253 |
+
)
|
254 |
+
|
255 |
+
# print info
|
256 |
+
info = ""
|
257 |
+
for i, dataset in enumerate(datasets):
|
258 |
+
is_image_dataset = isinstance(dataset, ImageDataset)
|
259 |
+
info += dedent(
|
260 |
+
f"""\
|
261 |
+
[Dataset {i}]
|
262 |
+
is_image_dataset: {is_image_dataset}
|
263 |
+
resolution: {dataset.resolution}
|
264 |
+
batch_size: {dataset.batch_size}
|
265 |
+
num_repeats: {dataset.num_repeats}
|
266 |
+
caption_extension: "{dataset.caption_extension}"
|
267 |
+
enable_bucket: {dataset.enable_bucket}
|
268 |
+
bucket_no_upscale: {dataset.bucket_no_upscale}
|
269 |
+
cache_directory: "{dataset.cache_directory}"
|
270 |
+
debug_dataset: {dataset.debug_dataset}
|
271 |
+
"""
|
272 |
+
)
|
273 |
+
|
274 |
+
if is_image_dataset:
|
275 |
+
info += indent(
|
276 |
+
dedent(
|
277 |
+
f"""\
|
278 |
+
image_directory: "{dataset.image_directory}"
|
279 |
+
image_jsonl_file: "{dataset.image_jsonl_file}"
|
280 |
+
\n"""
|
281 |
+
),
|
282 |
+
" ",
|
283 |
+
)
|
284 |
+
else:
|
285 |
+
info += indent(
|
286 |
+
dedent(
|
287 |
+
f"""\
|
288 |
+
video_directory: "{dataset.video_directory}"
|
289 |
+
video_jsonl_file: "{dataset.video_jsonl_file}"
|
290 |
+
target_frames: {dataset.target_frames}
|
291 |
+
frame_extraction: {dataset.frame_extraction}
|
292 |
+
frame_stride: {dataset.frame_stride}
|
293 |
+
frame_sample: {dataset.frame_sample}
|
294 |
+
\n"""
|
295 |
+
),
|
296 |
+
" ",
|
297 |
+
)
|
298 |
+
logger.info(f"{info}")
|
299 |
+
|
300 |
+
# make buckets first because it determines the length of dataset
|
301 |
+
# and set the same seed for all datasets
|
302 |
+
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
|
303 |
+
for i, dataset in enumerate(datasets):
|
304 |
+
# logger.info(f"[Dataset {i}]")
|
305 |
+
dataset.set_seed(seed)
|
306 |
+
if training:
|
307 |
+
dataset.prepare_for_training()
|
308 |
+
|
309 |
+
return DatasetGroup(datasets)
|
310 |
+
|
311 |
+
|
312 |
+
def load_user_config(file: str) -> dict:
|
313 |
+
file: Path = Path(file)
|
314 |
+
if not file.is_file():
|
315 |
+
raise ValueError(f"file not found / ファイルが見つかりません: {file}")
|
316 |
+
|
317 |
+
if file.name.lower().endswith(".json"):
|
318 |
+
try:
|
319 |
+
with open(file, "r", encoding="utf-8") as f:
|
320 |
+
config = json.load(f)
|
321 |
+
except Exception:
|
322 |
+
logger.error(
|
323 |
+
f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
|
324 |
+
)
|
325 |
+
raise
|
326 |
+
elif file.name.lower().endswith(".toml"):
|
327 |
+
try:
|
328 |
+
config = toml.load(file)
|
329 |
+
except Exception:
|
330 |
+
logger.error(
|
331 |
+
f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
|
332 |
+
)
|
333 |
+
raise
|
334 |
+
else:
|
335 |
+
raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
|
336 |
+
|
337 |
+
return config
|
338 |
+
|
339 |
+
|
340 |
+
# for config test
|
341 |
+
if __name__ == "__main__":
|
342 |
+
parser = argparse.ArgumentParser()
|
343 |
+
parser.add_argument("dataset_config")
|
344 |
+
config_args, remain = parser.parse_known_args()
|
345 |
+
|
346 |
+
parser = argparse.ArgumentParser()
|
347 |
+
parser.add_argument("--debug_dataset", action="store_true")
|
348 |
+
argparse_namespace = parser.parse_args(remain)
|
349 |
+
|
350 |
+
logger.info("[argparse_namespace]")
|
351 |
+
logger.info(f"{vars(argparse_namespace)}")
|
352 |
+
|
353 |
+
user_config = load_user_config(config_args.dataset_config)
|
354 |
+
|
355 |
+
logger.info("")
|
356 |
+
logger.info("[user_config]")
|
357 |
+
logger.info(f"{user_config}")
|
358 |
+
|
359 |
+
sanitizer = ConfigSanitizer()
|
360 |
+
sanitized_user_config = sanitizer.sanitize_user_config(user_config)
|
361 |
+
|
362 |
+
logger.info("")
|
363 |
+
logger.info("[sanitized_user_config]")
|
364 |
+
logger.info(f"{sanitized_user_config}")
|
365 |
+
|
366 |
+
blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
|
367 |
+
|
368 |
+
logger.info("")
|
369 |
+
logger.info("[blueprint]")
|
370 |
+
logger.info(f"{blueprint}")
|
371 |
+
|
372 |
+
dataset_group = generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
dataset/dataset_config.md
ADDED
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
> 📝 Click on the language section to expand / 言語をクリックして展開
|
2 |
+
|
3 |
+
## Dataset Configuration
|
4 |
+
|
5 |
+
<details>
|
6 |
+
<summary>English</summary>
|
7 |
+
|
8 |
+
Please create a TOML file for dataset configuration.
|
9 |
+
|
10 |
+
Image and video datasets are supported. The configuration file can include multiple datasets, either image or video datasets, with caption text files or metadata JSONL files.
|
11 |
+
|
12 |
+
The cache directory must be different for each dataset.
|
13 |
+
</details>
|
14 |
+
|
15 |
+
<details>
|
16 |
+
<summary>日本語</summary>
|
17 |
+
|
18 |
+
データセットの設定を行うためのTOMLファイルを作成してください。
|
19 |
+
|
20 |
+
画像データセットと動画データセットがサポートされています。設定ファイルには、画像または動画データセットを複数含めることができます。キャプションテキストファイルまたはメタデータJSONLファイルを使用できます。
|
21 |
+
|
22 |
+
キャッシュディレクトリは、各データセットごとに異なるディレクトリである必要があります。
|
23 |
+
</details>
|
24 |
+
|
25 |
+
### Sample for Image Dataset with Caption Text Files
|
26 |
+
|
27 |
+
```toml
|
28 |
+
# resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale should be set in either general or datasets
|
29 |
+
# otherwise, the default values will be used for each item
|
30 |
+
|
31 |
+
# general configurations
|
32 |
+
[general]
|
33 |
+
resolution = [960, 544]
|
34 |
+
caption_extension = ".txt"
|
35 |
+
batch_size = 1
|
36 |
+
enable_bucket = true
|
37 |
+
bucket_no_upscale = false
|
38 |
+
|
39 |
+
[[datasets]]
|
40 |
+
image_directory = "/path/to/image_dir"
|
41 |
+
cache_directory = "/path/to/cache_directory"
|
42 |
+
num_repeats = 1 # optional, default is 1. Number of times to repeat the dataset. Useful to balance the multiple datasets with different sizes.
|
43 |
+
|
44 |
+
# other datasets can be added here. each dataset can have different configurations
|
45 |
+
```
|
46 |
+
|
47 |
+
<details>
|
48 |
+
<summary>English</summary>
|
49 |
+
|
50 |
+
`cache_directory` is optional, default is None to use the same directory as the image directory. However, we recommend to set the cache directory to avoid accidental sharing of the cache files between different datasets.
|
51 |
+
|
52 |
+
`num_repeats` is also available. It is optional, default is 1 (no repeat). It repeats the images (or videos) that many times to expand the dataset. For example, if `num_repeats = 2` and there are 20 images in the dataset, each image will be duplicated twice (with the same caption) to have a total of 40 images. It is useful to balance the multiple datasets with different sizes.
|
53 |
+
|
54 |
+
</details>
|
55 |
+
|
56 |
+
<details>
|
57 |
+
<summary>日本語</summary>
|
58 |
+
|
59 |
+
`cache_directory` はオプションです。デフォルトは画像ディレクトリと同じディレクトリに設定されます。ただし、異なるデータセット間でキャッシュファイルが共有されるのを防ぐために、明示的に別のキャッシュディレクトリを設定することをお勧めします。
|
60 |
+
|
61 |
+
`num_repeats` はオプションで、デフォルトは 1 です(繰り返しなし)。画像(や動画)を、その回数だけ単純に繰り返してデータセットを拡張します。たとえば`num_repeats = 2`としたとき、画像20枚のデータセットなら、各画像が2枚ずつ(同一のキャプションで)計40枚存在した場合と同じになります。異なるデータ数のデータセット間でバランスを取るために使用可能です。
|
62 |
+
|
63 |
+
resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale は general または datasets のどちらかに設定してください。省略時は各項目のデフォルト値が使用されます。
|
64 |
+
|
65 |
+
`[[datasets]]`以下を追加することで、他のデータセットを追加できます。各データセットには異なる設定を持てます。
|
66 |
+
</details>
|
67 |
+
|
68 |
+
### Sample for Image Dataset with Metadata JSONL File
|
69 |
+
|
70 |
+
```toml
|
71 |
+
# resolution, batch_size, num_repeats, enable_bucket, bucket_no_upscale should be set in either general or datasets
|
72 |
+
# caption_extension is not required for metadata jsonl file
|
73 |
+
# cache_directory is required for each dataset with metadata jsonl file
|
74 |
+
|
75 |
+
# general configurations
|
76 |
+
[general]
|
77 |
+
resolution = [960, 544]
|
78 |
+
batch_size = 1
|
79 |
+
enable_bucket = true
|
80 |
+
bucket_no_upscale = false
|
81 |
+
|
82 |
+
[[datasets]]
|
83 |
+
image_jsonl_file = "/path/to/metadata.jsonl"
|
84 |
+
cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
|
85 |
+
num_repeats = 1 # optional, default is 1. Same as above.
|
86 |
+
|
87 |
+
# other datasets can be added here. each dataset can have different configurations
|
88 |
+
```
|
89 |
+
|
90 |
+
JSONL file format for metadata:
|
91 |
+
|
92 |
+
```json
|
93 |
+
{"image_path": "/path/to/image1.jpg", "caption": "A caption for image1"}
|
94 |
+
{"image_path": "/path/to/image2.jpg", "caption": "A caption for image2"}
|
95 |
+
```
|
96 |
+
|
97 |
+
<details>
|
98 |
+
<summary>日本語</summary>
|
99 |
+
|
100 |
+
resolution, batch_size, num_repeats, enable_bucket, bucket_no_upscale は general または datasets のどちらかに設定してください。省略時は各項目のデフォルト値が使用されます。
|
101 |
+
|
102 |
+
metadata jsonl ファイルを使用する場合、caption_extension は必要ありません。また、cache_directory は必須��す。
|
103 |
+
|
104 |
+
キャプションによるデータセットと同様に、複数のデータセットを追加できます。各データセットには異なる設定を持てます。
|
105 |
+
</details>
|
106 |
+
|
107 |
+
|
108 |
+
### Sample for Video Dataset with Caption Text Files
|
109 |
+
|
110 |
+
```toml
|
111 |
+
# resolution, caption_extension, target_frames, frame_extraction, frame_stride, frame_sample,
|
112 |
+
# batch_size, num_repeats, enable_bucket, bucket_no_upscale should be set in either general or datasets
|
113 |
+
# num_repeats is also available for video dataset, example is not shown here
|
114 |
+
|
115 |
+
# general configurations
|
116 |
+
[general]
|
117 |
+
resolution = [960, 544]
|
118 |
+
caption_extension = ".txt"
|
119 |
+
batch_size = 1
|
120 |
+
enable_bucket = true
|
121 |
+
bucket_no_upscale = false
|
122 |
+
|
123 |
+
[[datasets]]
|
124 |
+
video_directory = "/path/to/video_dir"
|
125 |
+
cache_directory = "/path/to/cache_directory" # recommended to set cache directory
|
126 |
+
target_frames = [1, 25, 45]
|
127 |
+
frame_extraction = "head"
|
128 |
+
|
129 |
+
# other datasets can be added here. each dataset can have different configurations
|
130 |
+
```
|
131 |
+
|
132 |
+
<details>
|
133 |
+
<summary>日本語</summary>
|
134 |
+
|
135 |
+
resolution, caption_extension, target_frames, frame_extraction, frame_stride, frame_sample, batch_size, num_repeats, enable_bucket, bucket_no_upscale は general または datasets のどちらかに設定してください。
|
136 |
+
|
137 |
+
他の注意事項は画像データセットと同様です。
|
138 |
+
</details>
|
139 |
+
|
140 |
+
### Sample for Video Dataset with Metadata JSONL File
|
141 |
+
|
142 |
+
```toml
|
143 |
+
# resolution, target_frames, frame_extraction, frame_stride, frame_sample,
|
144 |
+
# batch_size, num_repeats, enable_bucket, bucket_no_upscale should be set in either general or datasets
|
145 |
+
# caption_extension is not required for metadata jsonl file
|
146 |
+
# cache_directory is required for each dataset with metadata jsonl file
|
147 |
+
|
148 |
+
# general configurations
|
149 |
+
[general]
|
150 |
+
resolution = [960, 544]
|
151 |
+
batch_size = 1
|
152 |
+
enable_bucket = true
|
153 |
+
bucket_no_upscale = false
|
154 |
+
|
155 |
+
[[datasets]]
|
156 |
+
video_jsonl_file = "/path/to/metadata.jsonl"
|
157 |
+
target_frames = [1, 25, 45]
|
158 |
+
frame_extraction = "head"
|
159 |
+
cache_directory = "/path/to/cache_directory_head"
|
160 |
+
|
161 |
+
# same metadata jsonl file can be used for multiple datasets
|
162 |
+
[[datasets]]
|
163 |
+
video_jsonl_file = "/path/to/metadata.jsonl"
|
164 |
+
target_frames = [1]
|
165 |
+
frame_stride = 10
|
166 |
+
cache_directory = "/path/to/cache_directory_stride"
|
167 |
+
|
168 |
+
# other datasets can be added here. each dataset can have different configurations
|
169 |
+
```
|
170 |
+
|
171 |
+
JSONL file format for metadata:
|
172 |
+
|
173 |
+
```json
|
174 |
+
{"video_path": "/path/to/video1.mp4", "caption": "A caption for video1"}
|
175 |
+
{"video_path": "/path/to/video2.mp4", "caption": "A caption for video2"}
|
176 |
+
```
|
177 |
+
|
178 |
+
<details>
|
179 |
+
<summary>日本語</summary>
|
180 |
+
|
181 |
+
resolution, target_frames, frame_extraction, frame_stride, frame_sample, batch_size, num_repeats, enable_bucket, bucket_no_upscale は general または datasets のどちらかに設定してください。
|
182 |
+
|
183 |
+
metadata jsonl ファイルを使用する場合、caption_extension は必要ありません。また、cache_directory は必須です。
|
184 |
+
|
185 |
+
他の注意事項は今までのデータセットと同様です。
|
186 |
+
</details>
|
187 |
+
|
188 |
+
### frame_extraction Options
|
189 |
+
|
190 |
+
<details>
|
191 |
+
<summary>English</summary>
|
192 |
+
|
193 |
+
- `head`: Extract the first N frames from the video.
|
194 |
+
- `chunk`: Extract frames by splitting the video into chunks of N frames.
|
195 |
+
- `slide`: Extract frames from the video with a stride of `frame_stride`.
|
196 |
+
- `uniform`: Extract `frame_sample` samples uniformly from the video.
|
197 |
+
|
198 |
+
For example, consider a video with 40 frames. The following diagrams illustrate each extraction:
|
199 |
+
</details>
|
200 |
+
|
201 |
+
<details>
|
202 |
+
<summary>日本語</summary>
|
203 |
+
|
204 |
+
- `head`: 動画から最初のNフレームを抽出します。
|
205 |
+
- `chunk`: 動画をNフレームずつに分割してフレームを抽出します。
|
206 |
+
- `slide`: `frame_stride`に指定したフレームごとに動画からNフレームを抽出します。
|
207 |
+
- `uniform`: 動画から一定間隔で、`frame_sample`個のNフレームを抽出します。
|
208 |
+
|
209 |
+
例えば、40フレームの動画を例とした抽出について、以下の図で説明します。
|
210 |
+
</details>
|
211 |
+
|
212 |
+
```
|
213 |
+
Original Video, 40 frames: x = frame, o = no frame
|
214 |
+
oooooooooooooooooooooooooooooooooooooooo
|
215 |
+
|
216 |
+
head, target_frames = [1, 13, 25] -> extract head frames:
|
217 |
+
xooooooooooooooooooooooooooooooooooooooo
|
218 |
+
xxxxxxxxxxxxxooooooooooooooooooooooooooo
|
219 |
+
xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
|
220 |
+
|
221 |
+
chunk, target_frames = [13, 25] -> extract frames by splitting into chunks, into 13 and 25 frames:
|
222 |
+
xxxxxxxxxxxxxooooooooooooooooooooooooooo
|
223 |
+
oooooooooooooxxxxxxxxxxxxxoooooooooooooo
|
224 |
+
ooooooooooooooooooooooooooxxxxxxxxxxxxxo
|
225 |
+
xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
|
226 |
+
|
227 |
+
NOTE: Please do not include 1 in target_frames if you are using the frame_extraction "chunk". It will make the all frames to be extracted.
|
228 |
+
注: frame_extraction "chunk" を使用する場合、target_frames に 1 を含めないでください。全てのフレームが抽出されてしまいます。
|
229 |
+
|
230 |
+
slide, target_frames = [1, 13, 25], frame_stride = 10 -> extract N frames with a stride of 10:
|
231 |
+
xooooooooooooooooooooooooooooooooooooooo
|
232 |
+
ooooooooooxooooooooooooooooooooooooooooo
|
233 |
+
ooooooooooooooooooooxooooooooooooooooooo
|
234 |
+
ooooooooooooooooooooooooooooooxooooooooo
|
235 |
+
xxxxxxxxxxxxxooooooooooooooooooooooooooo
|
236 |
+
ooooooooooxxxxxxxxxxxxxooooooooooooooooo
|
237 |
+
ooooooooooooooooooooxxxxxxxxxxxxxooooooo
|
238 |
+
xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
|
239 |
+
ooooooooooxxxxxxxxxxxxxxxxxxxxxxxxxooooo
|
240 |
+
|
241 |
+
uniform, target_frames =[1, 13, 25], frame_sample = 4 -> extract `frame_sample` samples uniformly, N frames each:
|
242 |
+
xooooooooooooooooooooooooooooooooooooooo
|
243 |
+
oooooooooooooxoooooooooooooooooooooooooo
|
244 |
+
oooooooooooooooooooooooooxoooooooooooooo
|
245 |
+
ooooooooooooooooooooooooooooooooooooooox
|
246 |
+
xxxxxxxxxxxxxooooooooooooooooooooooooooo
|
247 |
+
oooooooooxxxxxxxxxxxxxoooooooooooooooooo
|
248 |
+
ooooooooooooooooooxxxxxxxxxxxxxooooooooo
|
249 |
+
oooooooooooooooooooooooooooxxxxxxxxxxxxx
|
250 |
+
xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
|
251 |
+
oooooxxxxxxxxxxxxxxxxxxxxxxxxxoooooooooo
|
252 |
+
ooooooooooxxxxxxxxxxxxxxxxxxxxxxxxxooooo
|
253 |
+
oooooooooooooooxxxxxxxxxxxxxxxxxxxxxxxxx
|
254 |
+
```
|
255 |
+
|
256 |
+
## Specifications
|
257 |
+
|
258 |
+
```toml
|
259 |
+
# general configurations
|
260 |
+
[general]
|
261 |
+
resolution = [960, 544] # optional, [W, H], default is None. This is the default resolution for all datasets
|
262 |
+
caption_extension = ".txt" # optional, default is None. This is the default caption extension for all datasets
|
263 |
+
batch_size = 1 # optional, default is 1. This is the default batch size for all datasets
|
264 |
+
num_repeats = 1 # optional, default is 1. Number of times to repeat the dataset. Useful to balance the multiple datasets with different sizes.
|
265 |
+
enable_bucket = true # optional, default is false. Enable bucketing for datasets
|
266 |
+
bucket_no_upscale = false # optional, default is false. Disable upscaling for bucketing. Ignored if enable_bucket is false
|
267 |
+
|
268 |
+
### Image Dataset
|
269 |
+
|
270 |
+
# sample image dataset with caption text files
|
271 |
+
[[datasets]]
|
272 |
+
image_directory = "/path/to/image_dir"
|
273 |
+
caption_extension = ".txt" # required for caption text files, if general caption extension is not set
|
274 |
+
resolution = [960, 544] # required if general resolution is not set
|
275 |
+
batch_size = 4 # optional, overwrite the default batch size
|
276 |
+
num_repeats = 1 # optional, overwrite the default num_repeats
|
277 |
+
enable_bucket = false # optional, overwrite the default bucketing setting
|
278 |
+
bucket_no_upscale = true # optional, overwrite the default bucketing setting
|
279 |
+
cache_directory = "/path/to/cache_directory" # optional, default is None to use the same directory as the image directory. NOTE: caching is always enabled
|
280 |
+
|
281 |
+
# sample image dataset with metadata **jsonl** file
|
282 |
+
[[datasets]]
|
283 |
+
image_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of image files and captions
|
284 |
+
resolution = [960, 544] # required if general resolution is not set
|
285 |
+
cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
|
286 |
+
# caption_extension is not required for metadata jsonl file
|
287 |
+
# batch_size, num_repeats, enable_bucket, bucket_no_upscale are also available for metadata jsonl file
|
288 |
+
|
289 |
+
### Video Dataset
|
290 |
+
|
291 |
+
# sample video dataset with caption text files
|
292 |
+
[[datasets]]
|
293 |
+
video_directory = "/path/to/video_dir"
|
294 |
+
caption_extension = ".txt" # required for caption text files, if general caption extension is not set
|
295 |
+
resolution = [960, 544] # required if general resolution is not set
|
296 |
+
|
297 |
+
target_frames = [1, 25, 79] # required for video dataset. list of video lengths to extract frames. each element must be N*4+1 (N=0,1,2,...)
|
298 |
+
|
299 |
+
# NOTE: Please do not include 1 in target_frames if you are using the frame_extraction "chunk". It will make the all frames to be extracted.
|
300 |
+
|
301 |
+
frame_extraction = "head" # optional, "head" or "chunk", "slide", "uniform". Default is "head"
|
302 |
+
frame_stride = 1 # optional, default is 1, available for "slide" frame extraction
|
303 |
+
frame_sample = 4 # optional, default is 1 (same as "head"), available for "uniform" frame extraction
|
304 |
+
# batch_size, num_repeats, enable_bucket, bucket_no_upscale, cache_directory are also available for video dataset
|
305 |
+
|
306 |
+
# sample video dataset with metadata jsonl file
|
307 |
+
[[datasets]]
|
308 |
+
video_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of video files and captions
|
309 |
+
|
310 |
+
target_frames = [1, 79]
|
311 |
+
|
312 |
+
cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
|
313 |
+
# frame_extraction, frame_stride, frame_sample are also available for metadata jsonl file
|
314 |
+
```
|
315 |
+
|
316 |
+
<!--
|
317 |
+
# sample image dataset with lance
|
318 |
+
[[datasets]]
|
319 |
+
image_lance_dataset = "/path/to/lance_dataset"
|
320 |
+
resolution = [960, 544] # required if general resolution is not set
|
321 |
+
# batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for lance dataset
|
322 |
+
-->
|
323 |
+
|
324 |
+
The metadata with .json file will be supported in the near future.
|
325 |
+
|
326 |
+
|
327 |
+
|
328 |
+
<!--
|
329 |
+
|
330 |
+
```toml
|
331 |
+
# general configurations
|
332 |
+
[general]
|
333 |
+
resolution = [960, 544] # optional, [W, H], default is None. This is the default resolution for all datasets
|
334 |
+
caption_extension = ".txt" # optional, default is None. This is the default caption extension for all datasets
|
335 |
+
batch_size = 1 # optional, default is 1. This is the default batch size for all datasets
|
336 |
+
enable_bucket = true # optional, default is false. Enable bucketing for datasets
|
337 |
+
bucket_no_upscale = false # optional, default is false. Disable upscaling for bucketing. Ignored if enable_bucket is false
|
338 |
+
|
339 |
+
# sample image dataset with caption text files
|
340 |
+
[[datasets]]
|
341 |
+
image_directory = "/path/to/image_dir"
|
342 |
+
caption_extension = ".txt" # required for caption text files, if general caption extension is not set
|
343 |
+
resolution = [960, 544] # required if general resolution is not set
|
344 |
+
batch_size = 4 # optional, overwrite the default batch size
|
345 |
+
enable_bucket = false # optional, overwrite the default bucketing setting
|
346 |
+
bucket_no_upscale = true # optional, overwrite the default bucketing setting
|
347 |
+
cache_directory = "/path/to/cache_directory" # optional, default is None to use the same directory as the image directory. NOTE: caching is always enabled
|
348 |
+
|
349 |
+
# sample image dataset with metadata **jsonl** file
|
350 |
+
[[datasets]]
|
351 |
+
image_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of image files and captions
|
352 |
+
resolution = [960, 544] # required if general resolution is not set
|
353 |
+
cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
|
354 |
+
# caption_extension is not required for metadata jsonl file
|
355 |
+
# batch_size, enable_bucket, bucket_no_upscale are also available for metadata jsonl file
|
356 |
+
|
357 |
+
# sample video dataset with caption text files
|
358 |
+
[[datasets]]
|
359 |
+
video_directory = "/path/to/video_dir"
|
360 |
+
caption_extension = ".txt" # required for caption text files, if general caption extension is not set
|
361 |
+
resolution = [960, 544] # required if general resolution is not set
|
362 |
+
target_frames = [1, 25, 79] # required for video dataset. list of video lengths to extract frames. each element must be N*4+1 (N=0,1,2,...)
|
363 |
+
frame_extraction = "head" # optional, "head" or "chunk", "slide", "uniform". Default is "head"
|
364 |
+
frame_stride = 1 # optional, default is 1, available for "slide" frame extraction
|
365 |
+
frame_sample = 4 # optional, default is 1 (same as "head"), available for "uniform" frame extraction
|
366 |
+
# batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for video dataset
|
367 |
+
|
368 |
+
# sample video dataset with metadata jsonl file
|
369 |
+
[[datasets]]
|
370 |
+
video_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of video files and captions
|
371 |
+
target_frames = [1, 79]
|
372 |
+
cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
|
373 |
+
# frame_extraction, frame_stride, frame_sample are also available for metadata jsonl file
|
374 |
+
```
|
375 |
+
|
376 |
+
# sample image dataset with lance
|
377 |
+
[[datasets]]
|
378 |
+
image_lance_dataset = "/path/to/lance_dataset"
|
379 |
+
resolution = [960, 544] # required if general resolution is not set
|
380 |
+
# batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for lance dataset
|
381 |
+
|
382 |
+
The metadata with .json file will be supported in the near future.
|
383 |
+
|
384 |
+
|
385 |
+
|
386 |
+
|
387 |
+
-->
|
dataset/image_video_dataset.py
ADDED
@@ -0,0 +1,1400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from concurrent.futures import ThreadPoolExecutor
|
2 |
+
import glob
|
3 |
+
import json
|
4 |
+
import math
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
import time
|
8 |
+
from typing import Optional, Sequence, Tuple, Union
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
from safetensors.torch import save_file, load_file
|
13 |
+
from safetensors import safe_open
|
14 |
+
from PIL import Image
|
15 |
+
import cv2
|
16 |
+
import av
|
17 |
+
|
18 |
+
from utils import safetensors_utils
|
19 |
+
from utils.model_utils import dtype_to_str
|
20 |
+
|
21 |
+
import logging
|
22 |
+
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
logging.basicConfig(level=logging.INFO)
|
25 |
+
|
26 |
+
|
27 |
+
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]
|
28 |
+
|
29 |
+
try:
|
30 |
+
import pillow_avif
|
31 |
+
|
32 |
+
IMAGE_EXTENSIONS.extend([".avif", ".AVIF"])
|
33 |
+
except:
|
34 |
+
pass
|
35 |
+
|
36 |
+
# JPEG-XL on Linux
|
37 |
+
try:
|
38 |
+
from jxlpy import JXLImagePlugin
|
39 |
+
|
40 |
+
IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
|
41 |
+
except:
|
42 |
+
pass
|
43 |
+
|
44 |
+
# JPEG-XL on Windows
|
45 |
+
try:
|
46 |
+
import pillow_jxl
|
47 |
+
|
48 |
+
IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
|
49 |
+
except:
|
50 |
+
pass
|
51 |
+
|
52 |
+
VIDEO_EXTENSIONS = [
|
53 |
+
".mp4",
|
54 |
+
".webm",
|
55 |
+
".avi",
|
56 |
+
".mkv",
|
57 |
+
".mov",
|
58 |
+
".flv",
|
59 |
+
".wmv",
|
60 |
+
".m4v",
|
61 |
+
".mpg",
|
62 |
+
".mpeg",
|
63 |
+
".MP4",
|
64 |
+
".WEBM",
|
65 |
+
".AVI",
|
66 |
+
".MKV",
|
67 |
+
".MOV",
|
68 |
+
".FLV",
|
69 |
+
".WMV",
|
70 |
+
".M4V",
|
71 |
+
".MPG",
|
72 |
+
".MPEG",
|
73 |
+
] # some of them are not tested
|
74 |
+
|
75 |
+
ARCHITECTURE_HUNYUAN_VIDEO = "hv"
|
76 |
+
ARCHITECTURE_HUNYUAN_VIDEO_FULL = "hunyuan_video"
|
77 |
+
ARCHITECTURE_WAN = "wan"
|
78 |
+
ARCHITECTURE_WAN_FULL = "wan"
|
79 |
+
|
80 |
+
|
81 |
+
def glob_images(directory, base="*"):
|
82 |
+
img_paths = []
|
83 |
+
for ext in IMAGE_EXTENSIONS:
|
84 |
+
if base == "*":
|
85 |
+
img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
|
86 |
+
else:
|
87 |
+
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
|
88 |
+
img_paths = list(set(img_paths)) # remove duplicates
|
89 |
+
img_paths.sort()
|
90 |
+
return img_paths
|
91 |
+
|
92 |
+
|
93 |
+
def glob_videos(directory, base="*"):
|
94 |
+
video_paths = []
|
95 |
+
for ext in VIDEO_EXTENSIONS:
|
96 |
+
if base == "*":
|
97 |
+
video_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
|
98 |
+
else:
|
99 |
+
video_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
|
100 |
+
video_paths = list(set(video_paths)) # remove duplicates
|
101 |
+
video_paths.sort()
|
102 |
+
return video_paths
|
103 |
+
|
104 |
+
|
105 |
+
def divisible_by(num: int, divisor: int) -> int:
|
106 |
+
return num - num % divisor
|
107 |
+
|
108 |
+
|
109 |
+
def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: tuple[int, int]) -> np.ndarray:
|
110 |
+
"""
|
111 |
+
Resize the image to the bucket resolution.
|
112 |
+
"""
|
113 |
+
is_pil_image = isinstance(image, Image.Image)
|
114 |
+
if is_pil_image:
|
115 |
+
image_width, image_height = image.size
|
116 |
+
else:
|
117 |
+
image_height, image_width = image.shape[:2]
|
118 |
+
|
119 |
+
if bucket_reso == (image_width, image_height):
|
120 |
+
return np.array(image) if is_pil_image else image
|
121 |
+
|
122 |
+
bucket_width, bucket_height = bucket_reso
|
123 |
+
if bucket_width == image_width or bucket_height == image_height:
|
124 |
+
image = np.array(image) if is_pil_image else image
|
125 |
+
else:
|
126 |
+
# resize the image to the bucket resolution to match the short side
|
127 |
+
scale_width = bucket_width / image_width
|
128 |
+
scale_height = bucket_height / image_height
|
129 |
+
scale = max(scale_width, scale_height)
|
130 |
+
image_width = int(image_width * scale + 0.5)
|
131 |
+
image_height = int(image_height * scale + 0.5)
|
132 |
+
|
133 |
+
if scale > 1:
|
134 |
+
image = Image.fromarray(image) if not is_pil_image else image
|
135 |
+
image = image.resize((image_width, image_height), Image.LANCZOS)
|
136 |
+
image = np.array(image)
|
137 |
+
else:
|
138 |
+
image = np.array(image) if is_pil_image else image
|
139 |
+
image = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_AREA)
|
140 |
+
|
141 |
+
# crop the image to the bucket resolution
|
142 |
+
crop_left = (image_width - bucket_width) // 2
|
143 |
+
crop_top = (image_height - bucket_height) // 2
|
144 |
+
image = image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width]
|
145 |
+
return image
|
146 |
+
|
147 |
+
|
148 |
+
class ItemInfo:
|
149 |
+
def __init__(
|
150 |
+
self,
|
151 |
+
item_key: str,
|
152 |
+
caption: str,
|
153 |
+
original_size: tuple[int, int],
|
154 |
+
bucket_size: Optional[Union[tuple[int, int], tuple[int, int, int]]] = None,
|
155 |
+
frame_count: Optional[int] = None,
|
156 |
+
content: Optional[np.ndarray] = None,
|
157 |
+
latent_cache_path: Optional[str] = None,
|
158 |
+
) -> None:
|
159 |
+
self.item_key = item_key
|
160 |
+
self.caption = caption
|
161 |
+
self.original_size = original_size
|
162 |
+
self.bucket_size = bucket_size
|
163 |
+
self.frame_count = frame_count
|
164 |
+
self.content = content
|
165 |
+
self.latent_cache_path = latent_cache_path
|
166 |
+
self.text_encoder_output_cache_path: Optional[str] = None
|
167 |
+
|
168 |
+
def __str__(self) -> str:
|
169 |
+
return (
|
170 |
+
f"ItemInfo(item_key={self.item_key}, caption={self.caption}, "
|
171 |
+
+ f"original_size={self.original_size}, bucket_size={self.bucket_size}, "
|
172 |
+
+ f"frame_count={self.frame_count}, latent_cache_path={self.latent_cache_path})"
|
173 |
+
)
|
174 |
+
|
175 |
+
|
176 |
+
# We use simple if-else approach to support multiple architectures.
|
177 |
+
# Maybe we can use a plugin system in the future.
|
178 |
+
|
179 |
+
# the keys of the dict are `<content_type>_FxHxW_<dtype>` for latents
|
180 |
+
# and `<content_type>_<dtype|mask>` for other tensors
|
181 |
+
|
182 |
+
|
183 |
+
def save_latent_cache(item_info: ItemInfo, latent: torch.Tensor):
|
184 |
+
"""HunyuanVideo architecture only"""
|
185 |
+
assert latent.dim() == 4, "latent should be 4D tensor (frame, channel, height, width)"
|
186 |
+
|
187 |
+
_, F, H, W = latent.shape
|
188 |
+
dtype_str = dtype_to_str(latent.dtype)
|
189 |
+
sd = {f"latents_{F}x{H}x{W}_{dtype_str}": latent.detach().cpu()}
|
190 |
+
|
191 |
+
save_latent_cache_common(item_info, sd, ARCHITECTURE_HUNYUAN_VIDEO_FULL)
|
192 |
+
|
193 |
+
|
194 |
+
def save_latent_cache_wan(
|
195 |
+
item_info: ItemInfo, latent: torch.Tensor, clip_embed: Optional[torch.Tensor], image_latent: Optional[torch.Tensor]
|
196 |
+
):
|
197 |
+
"""Wan architecture only"""
|
198 |
+
assert latent.dim() == 4, "latent should be 4D tensor (frame, channel, height, width)"
|
199 |
+
|
200 |
+
_, F, H, W = latent.shape
|
201 |
+
dtype_str = dtype_to_str(latent.dtype)
|
202 |
+
sd = {f"latents_{F}x{H}x{W}_{dtype_str}": latent.detach().cpu()}
|
203 |
+
|
204 |
+
if clip_embed is not None:
|
205 |
+
sd[f"clip_{dtype_str}"] = clip_embed.detach().cpu()
|
206 |
+
|
207 |
+
if image_latent is not None:
|
208 |
+
sd[f"latents_image_{F}x{H}x{W}_{dtype_str}"] = image_latent.detach().cpu()
|
209 |
+
|
210 |
+
save_latent_cache_common(item_info, sd, ARCHITECTURE_WAN_FULL)
|
211 |
+
|
212 |
+
|
213 |
+
def save_latent_cache_common(item_info: ItemInfo, sd: dict[str, torch.Tensor], arch_fullname: str):
|
214 |
+
metadata = {
|
215 |
+
"architecture": arch_fullname,
|
216 |
+
"width": f"{item_info.original_size[0]}",
|
217 |
+
"height": f"{item_info.original_size[1]}",
|
218 |
+
"format_version": "1.0.1",
|
219 |
+
}
|
220 |
+
if item_info.frame_count is not None:
|
221 |
+
metadata["frame_count"] = f"{item_info.frame_count}"
|
222 |
+
|
223 |
+
for key, value in sd.items():
|
224 |
+
# NaN check and show warning, replace NaN with 0
|
225 |
+
if torch.isnan(value).any():
|
226 |
+
logger.warning(f"{key} tensor has NaN: {item_info.item_key}, replace NaN with 0")
|
227 |
+
value[torch.isnan(value)] = 0
|
228 |
+
|
229 |
+
latent_dir = os.path.dirname(item_info.latent_cache_path)
|
230 |
+
os.makedirs(latent_dir, exist_ok=True)
|
231 |
+
|
232 |
+
save_file(sd, item_info.latent_cache_path, metadata=metadata)
|
233 |
+
|
234 |
+
|
235 |
+
def save_text_encoder_output_cache(item_info: ItemInfo, embed: torch.Tensor, mask: Optional[torch.Tensor], is_llm: bool):
|
236 |
+
"""HunyuanVideo architecture only"""
|
237 |
+
assert (
|
238 |
+
embed.dim() == 1 or embed.dim() == 2
|
239 |
+
), f"embed should be 2D tensor (feature, hidden_size) or (hidden_size,), got {embed.shape}"
|
240 |
+
assert mask is None or mask.dim() == 1, f"mask should be 1D tensor (feature), got {mask.shape}"
|
241 |
+
|
242 |
+
sd = {}
|
243 |
+
dtype_str = dtype_to_str(embed.dtype)
|
244 |
+
text_encoder_type = "llm" if is_llm else "clipL"
|
245 |
+
sd[f"{text_encoder_type}_{dtype_str}"] = embed.detach().cpu()
|
246 |
+
if mask is not None:
|
247 |
+
sd[f"{text_encoder_type}_mask"] = mask.detach().cpu()
|
248 |
+
|
249 |
+
save_text_encoder_output_cache_common(item_info, sd, ARCHITECTURE_HUNYUAN_VIDEO_FULL)
|
250 |
+
|
251 |
+
|
252 |
+
def save_text_encoder_output_cache_wan(item_info: ItemInfo, embed: torch.Tensor):
|
253 |
+
"""Wan architecture only. Wan2.1 only has a single text encoder"""
|
254 |
+
|
255 |
+
sd = {}
|
256 |
+
dtype_str = dtype_to_str(embed.dtype)
|
257 |
+
text_encoder_type = "t5"
|
258 |
+
sd[f"varlen_{text_encoder_type}_{dtype_str}"] = embed.detach().cpu()
|
259 |
+
|
260 |
+
save_text_encoder_output_cache_common(item_info, sd, ARCHITECTURE_WAN_FULL)
|
261 |
+
|
262 |
+
|
263 |
+
def save_text_encoder_output_cache_common(item_info: ItemInfo, sd: dict[str, torch.Tensor], arch_fullname: str):
|
264 |
+
for key, value in sd.items():
|
265 |
+
# NaN check and show warning, replace NaN with 0
|
266 |
+
if torch.isnan(value).any():
|
267 |
+
logger.warning(f"{key} tensor has NaN: {item_info.item_key}, replace NaN with 0")
|
268 |
+
value[torch.isnan(value)] = 0
|
269 |
+
|
270 |
+
metadata = {
|
271 |
+
"architecture": arch_fullname,
|
272 |
+
"caption1": item_info.caption,
|
273 |
+
"format_version": "1.0.1",
|
274 |
+
}
|
275 |
+
|
276 |
+
if os.path.exists(item_info.text_encoder_output_cache_path):
|
277 |
+
# load existing cache and update metadata
|
278 |
+
with safetensors_utils.MemoryEfficientSafeOpen(item_info.text_encoder_output_cache_path) as f:
|
279 |
+
existing_metadata = f.metadata()
|
280 |
+
for key in f.keys():
|
281 |
+
if key not in sd: # avoid overwriting by existing cache, we keep the new one
|
282 |
+
sd[key] = f.get_tensor(key)
|
283 |
+
|
284 |
+
assert existing_metadata["architecture"] == metadata["architecture"], "architecture mismatch"
|
285 |
+
if existing_metadata["caption1"] != metadata["caption1"]:
|
286 |
+
logger.warning(f"caption mismatch: existing={existing_metadata['caption1']}, new={metadata['caption1']}, overwrite")
|
287 |
+
# TODO verify format_version
|
288 |
+
|
289 |
+
existing_metadata.pop("caption1", None)
|
290 |
+
existing_metadata.pop("format_version", None)
|
291 |
+
metadata.update(existing_metadata) # copy existing metadata except caption and format_version
|
292 |
+
else:
|
293 |
+
text_encoder_output_dir = os.path.dirname(item_info.text_encoder_output_cache_path)
|
294 |
+
os.makedirs(text_encoder_output_dir, exist_ok=True)
|
295 |
+
|
296 |
+
safetensors_utils.mem_eff_save_file(sd, item_info.text_encoder_output_cache_path, metadata=metadata)
|
297 |
+
|
298 |
+
|
299 |
+
class BucketSelector:
|
300 |
+
RESOLUTION_STEPS_HUNYUAN = 16
|
301 |
+
RESOLUTION_STEPS_WAN = 16
|
302 |
+
|
303 |
+
def __init__(
|
304 |
+
self, resolution: Tuple[int, int], enable_bucket: bool = True, no_upscale: bool = False, architecture: str = "no_default"
|
305 |
+
):
|
306 |
+
self.resolution = resolution
|
307 |
+
self.bucket_area = resolution[0] * resolution[1]
|
308 |
+
self.architecture = architecture
|
309 |
+
|
310 |
+
if self.architecture == ARCHITECTURE_HUNYUAN_VIDEO:
|
311 |
+
self.reso_steps = BucketSelector.RESOLUTION_STEPS_HUNYUAN
|
312 |
+
elif self.architecture == ARCHITECTURE_WAN:
|
313 |
+
self.reso_steps = BucketSelector.RESOLUTION_STEPS_WAN
|
314 |
+
else:
|
315 |
+
raise ValueError(f"Invalid architecture: {self.architecture}")
|
316 |
+
|
317 |
+
if not enable_bucket:
|
318 |
+
# only define one bucket
|
319 |
+
self.bucket_resolutions = [resolution]
|
320 |
+
self.no_upscale = False
|
321 |
+
else:
|
322 |
+
# prepare bucket resolution
|
323 |
+
self.no_upscale = no_upscale
|
324 |
+
sqrt_size = int(math.sqrt(self.bucket_area))
|
325 |
+
min_size = divisible_by(sqrt_size // 2, self.reso_steps)
|
326 |
+
self.bucket_resolutions = []
|
327 |
+
for w in range(min_size, sqrt_size + self.reso_steps, self.reso_steps):
|
328 |
+
h = divisible_by(self.bucket_area // w, self.reso_steps)
|
329 |
+
self.bucket_resolutions.append((w, h))
|
330 |
+
self.bucket_resolutions.append((h, w))
|
331 |
+
|
332 |
+
self.bucket_resolutions = list(set(self.bucket_resolutions))
|
333 |
+
self.bucket_resolutions.sort()
|
334 |
+
|
335 |
+
# calculate aspect ratio to find the nearest resolution
|
336 |
+
self.aspect_ratios = np.array([w / h for w, h in self.bucket_resolutions])
|
337 |
+
|
338 |
+
def get_bucket_resolution(self, image_size: tuple[int, int]) -> tuple[int, int]:
|
339 |
+
"""
|
340 |
+
return the bucket resolution for the given image size, (width, height)
|
341 |
+
"""
|
342 |
+
area = image_size[0] * image_size[1]
|
343 |
+
if self.no_upscale and area <= self.bucket_area:
|
344 |
+
w, h = image_size
|
345 |
+
w = divisible_by(w, self.reso_steps)
|
346 |
+
h = divisible_by(h, self.reso_steps)
|
347 |
+
return w, h
|
348 |
+
|
349 |
+
aspect_ratio = image_size[0] / image_size[1]
|
350 |
+
ar_errors = self.aspect_ratios - aspect_ratio
|
351 |
+
bucket_id = np.abs(ar_errors).argmin()
|
352 |
+
return self.bucket_resolutions[bucket_id]
|
353 |
+
|
354 |
+
|
355 |
+
def load_video(
|
356 |
+
video_path: str,
|
357 |
+
start_frame: Optional[int] = None,
|
358 |
+
end_frame: Optional[int] = None,
|
359 |
+
bucket_selector: Optional[BucketSelector] = None,
|
360 |
+
bucket_reso: Optional[tuple[int, int]] = None,
|
361 |
+
) -> list[np.ndarray]:
|
362 |
+
"""
|
363 |
+
bucket_reso: if given, resize the video to the bucket resolution, (width, height)
|
364 |
+
"""
|
365 |
+
container = av.open(video_path)
|
366 |
+
video = []
|
367 |
+
for i, frame in enumerate(container.decode(video=0)):
|
368 |
+
if start_frame is not None and i < start_frame:
|
369 |
+
continue
|
370 |
+
if end_frame is not None and i >= end_frame:
|
371 |
+
break
|
372 |
+
frame = frame.to_image()
|
373 |
+
|
374 |
+
if bucket_selector is not None and bucket_reso is None:
|
375 |
+
bucket_reso = bucket_selector.get_bucket_resolution(frame.size)
|
376 |
+
|
377 |
+
if bucket_reso is not None:
|
378 |
+
frame = resize_image_to_bucket(frame, bucket_reso)
|
379 |
+
else:
|
380 |
+
frame = np.array(frame)
|
381 |
+
|
382 |
+
video.append(frame)
|
383 |
+
container.close()
|
384 |
+
return video
|
385 |
+
|
386 |
+
|
387 |
+
class BucketBatchManager:
|
388 |
+
|
389 |
+
def __init__(self, bucketed_item_info: dict[tuple[int, int], list[ItemInfo]], batch_size: int):
|
390 |
+
self.batch_size = batch_size
|
391 |
+
self.buckets = bucketed_item_info
|
392 |
+
self.bucket_resos = list(self.buckets.keys())
|
393 |
+
self.bucket_resos.sort()
|
394 |
+
|
395 |
+
self.bucket_batch_indices = []
|
396 |
+
for bucket_reso in self.bucket_resos:
|
397 |
+
bucket = self.buckets[bucket_reso]
|
398 |
+
num_batches = math.ceil(len(bucket) / self.batch_size)
|
399 |
+
for i in range(num_batches):
|
400 |
+
self.bucket_batch_indices.append((bucket_reso, i))
|
401 |
+
|
402 |
+
self.shuffle()
|
403 |
+
|
404 |
+
def show_bucket_info(self):
|
405 |
+
for bucket_reso in self.bucket_resos:
|
406 |
+
bucket = self.buckets[bucket_reso]
|
407 |
+
logger.info(f"bucket: {bucket_reso}, count: {len(bucket)}")
|
408 |
+
|
409 |
+
logger.info(f"total batches: {len(self)}")
|
410 |
+
|
411 |
+
def shuffle(self):
|
412 |
+
for bucket in self.buckets.values():
|
413 |
+
random.shuffle(bucket)
|
414 |
+
random.shuffle(self.bucket_batch_indices)
|
415 |
+
|
416 |
+
def __len__(self):
|
417 |
+
return len(self.bucket_batch_indices)
|
418 |
+
|
419 |
+
def __getitem__(self, idx):
|
420 |
+
bucket_reso, batch_idx = self.bucket_batch_indices[idx]
|
421 |
+
bucket = self.buckets[bucket_reso]
|
422 |
+
start = batch_idx * self.batch_size
|
423 |
+
end = min(start + self.batch_size, len(bucket))
|
424 |
+
|
425 |
+
batch_tensor_data = {}
|
426 |
+
varlen_keys = set()
|
427 |
+
for item_info in bucket[start:end]:
|
428 |
+
sd_latent = load_file(item_info.latent_cache_path)
|
429 |
+
sd_te = load_file(item_info.text_encoder_output_cache_path)
|
430 |
+
sd = {**sd_latent, **sd_te}
|
431 |
+
|
432 |
+
# TODO refactor this
|
433 |
+
for key in sd.keys():
|
434 |
+
is_varlen_key = key.startswith("varlen_") # varlen keys are not stacked
|
435 |
+
content_key = key
|
436 |
+
|
437 |
+
if is_varlen_key:
|
438 |
+
content_key = content_key.replace("varlen_", "")
|
439 |
+
|
440 |
+
if content_key.endswith("_mask"):
|
441 |
+
pass
|
442 |
+
else:
|
443 |
+
content_key = content_key.rsplit("_", 1)[0] # remove dtype
|
444 |
+
if content_key.startswith("latents_"):
|
445 |
+
content_key = content_key.rsplit("_", 1)[0] # remove FxHxW
|
446 |
+
|
447 |
+
if content_key not in batch_tensor_data:
|
448 |
+
batch_tensor_data[content_key] = []
|
449 |
+
batch_tensor_data[content_key].append(sd[key])
|
450 |
+
|
451 |
+
if is_varlen_key:
|
452 |
+
varlen_keys.add(content_key)
|
453 |
+
|
454 |
+
for key in batch_tensor_data.keys():
|
455 |
+
if key not in varlen_keys:
|
456 |
+
batch_tensor_data[key] = torch.stack(batch_tensor_data[key])
|
457 |
+
|
458 |
+
return batch_tensor_data
|
459 |
+
|
460 |
+
|
461 |
+
class ContentDatasource:
|
462 |
+
def __init__(self):
|
463 |
+
self.caption_only = False
|
464 |
+
|
465 |
+
def set_caption_only(self, caption_only: bool):
|
466 |
+
self.caption_only = caption_only
|
467 |
+
|
468 |
+
def is_indexable(self):
|
469 |
+
return False
|
470 |
+
|
471 |
+
def get_caption(self, idx: int) -> tuple[str, str]:
|
472 |
+
"""
|
473 |
+
Returns caption. May not be called if is_indexable() returns False.
|
474 |
+
"""
|
475 |
+
raise NotImplementedError
|
476 |
+
|
477 |
+
def __len__(self):
|
478 |
+
raise NotImplementedError
|
479 |
+
|
480 |
+
def __iter__(self):
|
481 |
+
raise NotImplementedError
|
482 |
+
|
483 |
+
def __next__(self):
|
484 |
+
raise NotImplementedError
|
485 |
+
|
486 |
+
|
487 |
+
class ImageDatasource(ContentDatasource):
|
488 |
+
def __init__(self):
|
489 |
+
super().__init__()
|
490 |
+
|
491 |
+
def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
|
492 |
+
"""
|
493 |
+
Returns image data as a tuple of image path, image, and caption for the given index.
|
494 |
+
Key must be unique and valid as a file name.
|
495 |
+
May not be called if is_indexable() returns False.
|
496 |
+
"""
|
497 |
+
raise NotImplementedError
|
498 |
+
|
499 |
+
|
500 |
+
class ImageDirectoryDatasource(ImageDatasource):
|
501 |
+
def __init__(self, image_directory: str, caption_extension: Optional[str] = None):
|
502 |
+
super().__init__()
|
503 |
+
self.image_directory = image_directory
|
504 |
+
self.caption_extension = caption_extension
|
505 |
+
self.current_idx = 0
|
506 |
+
|
507 |
+
# glob images
|
508 |
+
logger.info(f"glob images in {self.image_directory}")
|
509 |
+
self.image_paths = glob_images(self.image_directory)
|
510 |
+
logger.info(f"found {len(self.image_paths)} images")
|
511 |
+
|
512 |
+
def is_indexable(self):
|
513 |
+
return True
|
514 |
+
|
515 |
+
def __len__(self):
|
516 |
+
return len(self.image_paths)
|
517 |
+
|
518 |
+
def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
|
519 |
+
image_path = self.image_paths[idx]
|
520 |
+
image = Image.open(image_path).convert("RGB")
|
521 |
+
|
522 |
+
_, caption = self.get_caption(idx)
|
523 |
+
|
524 |
+
return image_path, image, caption
|
525 |
+
|
526 |
+
def get_caption(self, idx: int) -> tuple[str, str]:
|
527 |
+
image_path = self.image_paths[idx]
|
528 |
+
caption_path = os.path.splitext(image_path)[0] + self.caption_extension if self.caption_extension else ""
|
529 |
+
with open(caption_path, "r", encoding="utf-8") as f:
|
530 |
+
caption = f.read().strip()
|
531 |
+
return image_path, caption
|
532 |
+
|
533 |
+
def __iter__(self):
|
534 |
+
self.current_idx = 0
|
535 |
+
return self
|
536 |
+
|
537 |
+
def __next__(self) -> callable:
|
538 |
+
"""
|
539 |
+
Returns a fetcher function that returns image data.
|
540 |
+
"""
|
541 |
+
if self.current_idx >= len(self.image_paths):
|
542 |
+
raise StopIteration
|
543 |
+
|
544 |
+
if self.caption_only:
|
545 |
+
|
546 |
+
def create_caption_fetcher(index):
|
547 |
+
return lambda: self.get_caption(index)
|
548 |
+
|
549 |
+
fetcher = create_caption_fetcher(self.current_idx)
|
550 |
+
else:
|
551 |
+
|
552 |
+
def create_image_fetcher(index):
|
553 |
+
return lambda: self.get_image_data(index)
|
554 |
+
|
555 |
+
fetcher = create_image_fetcher(self.current_idx)
|
556 |
+
|
557 |
+
self.current_idx += 1
|
558 |
+
return fetcher
|
559 |
+
|
560 |
+
|
561 |
+
class ImageJsonlDatasource(ImageDatasource):
|
562 |
+
def __init__(self, image_jsonl_file: str):
|
563 |
+
super().__init__()
|
564 |
+
self.image_jsonl_file = image_jsonl_file
|
565 |
+
self.current_idx = 0
|
566 |
+
|
567 |
+
# load jsonl
|
568 |
+
logger.info(f"load image jsonl from {self.image_jsonl_file}")
|
569 |
+
self.data = []
|
570 |
+
with open(self.image_jsonl_file, "r", encoding="utf-8") as f:
|
571 |
+
for line in f:
|
572 |
+
try:
|
573 |
+
data = json.loads(line)
|
574 |
+
except json.JSONDecodeError:
|
575 |
+
logger.error(f"failed to load json: {line} @ {self.image_jsonl_file}")
|
576 |
+
raise
|
577 |
+
self.data.append(data)
|
578 |
+
logger.info(f"loaded {len(self.data)} images")
|
579 |
+
|
580 |
+
def is_indexable(self):
|
581 |
+
return True
|
582 |
+
|
583 |
+
def __len__(self):
|
584 |
+
return len(self.data)
|
585 |
+
|
586 |
+
def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
|
587 |
+
data = self.data[idx]
|
588 |
+
image_path = data["image_path"]
|
589 |
+
image = Image.open(image_path).convert("RGB")
|
590 |
+
|
591 |
+
caption = data["caption"]
|
592 |
+
|
593 |
+
return image_path, image, caption
|
594 |
+
|
595 |
+
def get_caption(self, idx: int) -> tuple[str, str]:
|
596 |
+
data = self.data[idx]
|
597 |
+
image_path = data["image_path"]
|
598 |
+
caption = data["caption"]
|
599 |
+
return image_path, caption
|
600 |
+
|
601 |
+
def __iter__(self):
|
602 |
+
self.current_idx = 0
|
603 |
+
return self
|
604 |
+
|
605 |
+
def __next__(self) -> callable:
|
606 |
+
if self.current_idx >= len(self.data):
|
607 |
+
raise StopIteration
|
608 |
+
|
609 |
+
if self.caption_only:
|
610 |
+
|
611 |
+
def create_caption_fetcher(index):
|
612 |
+
return lambda: self.get_caption(index)
|
613 |
+
|
614 |
+
fetcher = create_caption_fetcher(self.current_idx)
|
615 |
+
|
616 |
+
else:
|
617 |
+
|
618 |
+
def create_fetcher(index):
|
619 |
+
return lambda: self.get_image_data(index)
|
620 |
+
|
621 |
+
fetcher = create_fetcher(self.current_idx)
|
622 |
+
|
623 |
+
self.current_idx += 1
|
624 |
+
return fetcher
|
625 |
+
|
626 |
+
|
627 |
+
class VideoDatasource(ContentDatasource):
|
628 |
+
def __init__(self):
|
629 |
+
super().__init__()
|
630 |
+
|
631 |
+
# None means all frames
|
632 |
+
self.start_frame = None
|
633 |
+
self.end_frame = None
|
634 |
+
|
635 |
+
self.bucket_selector = None
|
636 |
+
|
637 |
+
def __len__(self):
|
638 |
+
raise NotImplementedError
|
639 |
+
|
640 |
+
def get_video_data_from_path(
|
641 |
+
self,
|
642 |
+
video_path: str,
|
643 |
+
start_frame: Optional[int] = None,
|
644 |
+
end_frame: Optional[int] = None,
|
645 |
+
bucket_selector: Optional[BucketSelector] = None,
|
646 |
+
) -> tuple[str, list[Image.Image], str]:
|
647 |
+
# this method can resize the video if bucket_selector is given to reduce the memory usage
|
648 |
+
|
649 |
+
start_frame = start_frame if start_frame is not None else self.start_frame
|
650 |
+
end_frame = end_frame if end_frame is not None else self.end_frame
|
651 |
+
bucket_selector = bucket_selector if bucket_selector is not None else self.bucket_selector
|
652 |
+
|
653 |
+
video = load_video(video_path, start_frame, end_frame, bucket_selector)
|
654 |
+
return video
|
655 |
+
|
656 |
+
def set_start_and_end_frame(self, start_frame: Optional[int], end_frame: Optional[int]):
|
657 |
+
self.start_frame = start_frame
|
658 |
+
self.end_frame = end_frame
|
659 |
+
|
660 |
+
def set_bucket_selector(self, bucket_selector: BucketSelector):
|
661 |
+
self.bucket_selector = bucket_selector
|
662 |
+
|
663 |
+
def __iter__(self):
|
664 |
+
raise NotImplementedError
|
665 |
+
|
666 |
+
def __next__(self):
|
667 |
+
raise NotImplementedError
|
668 |
+
|
669 |
+
|
670 |
+
class VideoDirectoryDatasource(VideoDatasource):
|
671 |
+
def __init__(self, video_directory: str, caption_extension: Optional[str] = None):
|
672 |
+
super().__init__()
|
673 |
+
self.video_directory = video_directory
|
674 |
+
self.caption_extension = caption_extension
|
675 |
+
self.current_idx = 0
|
676 |
+
|
677 |
+
# glob images
|
678 |
+
logger.info(f"glob images in {self.video_directory}")
|
679 |
+
self.video_paths = glob_videos(self.video_directory)
|
680 |
+
logger.info(f"found {len(self.video_paths)} videos")
|
681 |
+
|
682 |
+
def is_indexable(self):
|
683 |
+
return True
|
684 |
+
|
685 |
+
def __len__(self):
|
686 |
+
return len(self.video_paths)
|
687 |
+
|
688 |
+
def get_video_data(
|
689 |
+
self,
|
690 |
+
idx: int,
|
691 |
+
start_frame: Optional[int] = None,
|
692 |
+
end_frame: Optional[int] = None,
|
693 |
+
bucket_selector: Optional[BucketSelector] = None,
|
694 |
+
) -> tuple[str, list[Image.Image], str]:
|
695 |
+
video_path = self.video_paths[idx]
|
696 |
+
video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)
|
697 |
+
|
698 |
+
_, caption = self.get_caption(idx)
|
699 |
+
|
700 |
+
return video_path, video, caption
|
701 |
+
|
702 |
+
def get_caption(self, idx: int) -> tuple[str, str]:
|
703 |
+
video_path = self.video_paths[idx]
|
704 |
+
caption_path = os.path.splitext(video_path)[0] + self.caption_extension if self.caption_extension else ""
|
705 |
+
with open(caption_path, "r", encoding="utf-8") as f:
|
706 |
+
caption = f.read().strip()
|
707 |
+
return video_path, caption
|
708 |
+
|
709 |
+
def __iter__(self):
|
710 |
+
self.current_idx = 0
|
711 |
+
return self
|
712 |
+
|
713 |
+
def __next__(self):
|
714 |
+
if self.current_idx >= len(self.video_paths):
|
715 |
+
raise StopIteration
|
716 |
+
|
717 |
+
if self.caption_only:
|
718 |
+
|
719 |
+
def create_caption_fetcher(index):
|
720 |
+
return lambda: self.get_caption(index)
|
721 |
+
|
722 |
+
fetcher = create_caption_fetcher(self.current_idx)
|
723 |
+
|
724 |
+
else:
|
725 |
+
|
726 |
+
def create_fetcher(index):
|
727 |
+
return lambda: self.get_video_data(index)
|
728 |
+
|
729 |
+
fetcher = create_fetcher(self.current_idx)
|
730 |
+
|
731 |
+
self.current_idx += 1
|
732 |
+
return fetcher
|
733 |
+
|
734 |
+
|
735 |
+
class VideoJsonlDatasource(VideoDatasource):
|
736 |
+
def __init__(self, video_jsonl_file: str):
|
737 |
+
super().__init__()
|
738 |
+
self.video_jsonl_file = video_jsonl_file
|
739 |
+
self.current_idx = 0
|
740 |
+
|
741 |
+
# load jsonl
|
742 |
+
logger.info(f"load video jsonl from {self.video_jsonl_file}")
|
743 |
+
self.data = []
|
744 |
+
with open(self.video_jsonl_file, "r", encoding="utf-8") as f:
|
745 |
+
for line in f:
|
746 |
+
data = json.loads(line)
|
747 |
+
self.data.append(data)
|
748 |
+
logger.info(f"loaded {len(self.data)} videos")
|
749 |
+
|
750 |
+
def is_indexable(self):
|
751 |
+
return True
|
752 |
+
|
753 |
+
def __len__(self):
|
754 |
+
return len(self.data)
|
755 |
+
|
756 |
+
def get_video_data(
|
757 |
+
self,
|
758 |
+
idx: int,
|
759 |
+
start_frame: Optional[int] = None,
|
760 |
+
end_frame: Optional[int] = None,
|
761 |
+
bucket_selector: Optional[BucketSelector] = None,
|
762 |
+
) -> tuple[str, list[Image.Image], str]:
|
763 |
+
data = self.data[idx]
|
764 |
+
video_path = data["video_path"]
|
765 |
+
video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)
|
766 |
+
|
767 |
+
caption = data["caption"]
|
768 |
+
|
769 |
+
return video_path, video, caption
|
770 |
+
|
771 |
+
def get_caption(self, idx: int) -> tuple[str, str]:
|
772 |
+
data = self.data[idx]
|
773 |
+
video_path = data["video_path"]
|
774 |
+
caption = data["caption"]
|
775 |
+
return video_path, caption
|
776 |
+
|
777 |
+
def __iter__(self):
|
778 |
+
self.current_idx = 0
|
779 |
+
return self
|
780 |
+
|
781 |
+
def __next__(self):
|
782 |
+
if self.current_idx >= len(self.data):
|
783 |
+
raise StopIteration
|
784 |
+
|
785 |
+
if self.caption_only:
|
786 |
+
|
787 |
+
def create_caption_fetcher(index):
|
788 |
+
return lambda: self.get_caption(index)
|
789 |
+
|
790 |
+
fetcher = create_caption_fetcher(self.current_idx)
|
791 |
+
|
792 |
+
else:
|
793 |
+
|
794 |
+
def create_fetcher(index):
|
795 |
+
return lambda: self.get_video_data(index)
|
796 |
+
|
797 |
+
fetcher = create_fetcher(self.current_idx)
|
798 |
+
|
799 |
+
self.current_idx += 1
|
800 |
+
return fetcher
|
801 |
+
|
802 |
+
|
803 |
+
class BaseDataset(torch.utils.data.Dataset):
|
804 |
+
def __init__(
|
805 |
+
self,
|
806 |
+
resolution: Tuple[int, int] = (960, 544),
|
807 |
+
caption_extension: Optional[str] = None,
|
808 |
+
batch_size: int = 1,
|
809 |
+
num_repeats: int = 1,
|
810 |
+
enable_bucket: bool = False,
|
811 |
+
bucket_no_upscale: bool = False,
|
812 |
+
cache_directory: Optional[str] = None,
|
813 |
+
debug_dataset: bool = False,
|
814 |
+
architecture: str = "no_default",
|
815 |
+
):
|
816 |
+
self.resolution = resolution
|
817 |
+
self.caption_extension = caption_extension
|
818 |
+
self.batch_size = batch_size
|
819 |
+
self.num_repeats = num_repeats
|
820 |
+
self.enable_bucket = enable_bucket
|
821 |
+
self.bucket_no_upscale = bucket_no_upscale
|
822 |
+
self.cache_directory = cache_directory
|
823 |
+
self.debug_dataset = debug_dataset
|
824 |
+
self.architecture = architecture
|
825 |
+
self.seed = None
|
826 |
+
self.current_epoch = 0
|
827 |
+
|
828 |
+
if not self.enable_bucket:
|
829 |
+
self.bucket_no_upscale = False
|
830 |
+
|
831 |
+
def get_metadata(self) -> dict:
|
832 |
+
metadata = {
|
833 |
+
"resolution": self.resolution,
|
834 |
+
"caption_extension": self.caption_extension,
|
835 |
+
"batch_size_per_device": self.batch_size,
|
836 |
+
"num_repeats": self.num_repeats,
|
837 |
+
"enable_bucket": bool(self.enable_bucket),
|
838 |
+
"bucket_no_upscale": bool(self.bucket_no_upscale),
|
839 |
+
}
|
840 |
+
return metadata
|
841 |
+
|
842 |
+
def get_all_latent_cache_files(self):
|
843 |
+
return glob.glob(os.path.join(self.cache_directory, f"*_{self.architecture}.safetensors"))
|
844 |
+
|
845 |
+
def get_all_text_encoder_output_cache_files(self):
|
846 |
+
return glob.glob(os.path.join(self.cache_directory, f"*_{self.architecture}_te.safetensors"))
|
847 |
+
|
848 |
+
def get_latent_cache_path(self, item_info: ItemInfo) -> str:
|
849 |
+
"""
|
850 |
+
Returns the cache path for the latent tensor.
|
851 |
+
|
852 |
+
item_info: ItemInfo object
|
853 |
+
|
854 |
+
Returns:
|
855 |
+
str: cache path
|
856 |
+
|
857 |
+
cache_path is based on the item_key and the resolution.
|
858 |
+
"""
|
859 |
+
w, h = item_info.original_size
|
860 |
+
basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
|
861 |
+
assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です"
|
862 |
+
return os.path.join(self.cache_directory, f"{basename}_{w:04d}x{h:04d}_{self.architecture}.safetensors")
|
863 |
+
|
864 |
+
def get_text_encoder_output_cache_path(self, item_info: ItemInfo) -> str:
|
865 |
+
basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
|
866 |
+
assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です"
|
867 |
+
return os.path.join(self.cache_directory, f"{basename}_{self.architecture}_te.safetensors")
|
868 |
+
|
869 |
+
def retrieve_latent_cache_batches(self, num_workers: int):
|
870 |
+
raise NotImplementedError
|
871 |
+
|
872 |
+
def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
|
873 |
+
raise NotImplementedError
|
874 |
+
|
875 |
+
def prepare_for_training(self):
|
876 |
+
pass
|
877 |
+
|
878 |
+
def set_seed(self, seed: int):
|
879 |
+
self.seed = seed
|
880 |
+
|
881 |
+
def set_current_epoch(self, epoch):
|
882 |
+
if not self.current_epoch == epoch: # shuffle buckets when epoch is incremented
|
883 |
+
if epoch > self.current_epoch:
|
884 |
+
logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
|
885 |
+
num_epochs = epoch - self.current_epoch
|
886 |
+
for _ in range(num_epochs):
|
887 |
+
self.current_epoch += 1
|
888 |
+
self.shuffle_buckets()
|
889 |
+
# self.current_epoch seem to be set to 0 again in the next epoch. it may be caused by skipped_dataloader?
|
890 |
+
else:
|
891 |
+
logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
|
892 |
+
self.current_epoch = epoch
|
893 |
+
|
894 |
+
def set_current_step(self, step):
|
895 |
+
self.current_step = step
|
896 |
+
|
897 |
+
def set_max_train_steps(self, max_train_steps):
|
898 |
+
self.max_train_steps = max_train_steps
|
899 |
+
|
900 |
+
def shuffle_buckets(self):
|
901 |
+
raise NotImplementedError
|
902 |
+
|
903 |
+
def __len__(self):
|
904 |
+
return NotImplementedError
|
905 |
+
|
906 |
+
def __getitem__(self, idx):
|
907 |
+
raise NotImplementedError
|
908 |
+
|
909 |
+
def _default_retrieve_text_encoder_output_cache_batches(self, datasource: ContentDatasource, batch_size: int, num_workers: int):
|
910 |
+
datasource.set_caption_only(True)
|
911 |
+
executor = ThreadPoolExecutor(max_workers=num_workers)
|
912 |
+
|
913 |
+
data: list[ItemInfo] = []
|
914 |
+
futures = []
|
915 |
+
|
916 |
+
def aggregate_future(consume_all: bool = False):
|
917 |
+
while len(futures) >= num_workers or (consume_all and len(futures) > 0):
|
918 |
+
completed_futures = [future for future in futures if future.done()]
|
919 |
+
if len(completed_futures) == 0:
|
920 |
+
if len(futures) >= num_workers or consume_all: # to avoid adding too many futures
|
921 |
+
time.sleep(0.1)
|
922 |
+
continue
|
923 |
+
else:
|
924 |
+
break # submit batch if possible
|
925 |
+
|
926 |
+
for future in completed_futures:
|
927 |
+
item_key, caption = future.result()
|
928 |
+
item_info = ItemInfo(item_key, caption, (0, 0), (0, 0))
|
929 |
+
item_info.text_encoder_output_cache_path = self.get_text_encoder_output_cache_path(item_info)
|
930 |
+
data.append(item_info)
|
931 |
+
|
932 |
+
futures.remove(future)
|
933 |
+
|
934 |
+
def submit_batch(flush: bool = False):
|
935 |
+
nonlocal data
|
936 |
+
if len(data) >= batch_size or (len(data) > 0 and flush):
|
937 |
+
batch = data[0:batch_size]
|
938 |
+
if len(data) > batch_size:
|
939 |
+
data = data[batch_size:]
|
940 |
+
else:
|
941 |
+
data = []
|
942 |
+
return batch
|
943 |
+
return None
|
944 |
+
|
945 |
+
for fetch_op in datasource:
|
946 |
+
future = executor.submit(fetch_op)
|
947 |
+
futures.append(future)
|
948 |
+
aggregate_future()
|
949 |
+
while True:
|
950 |
+
batch = submit_batch()
|
951 |
+
if batch is None:
|
952 |
+
break
|
953 |
+
yield batch
|
954 |
+
|
955 |
+
aggregate_future(consume_all=True)
|
956 |
+
while True:
|
957 |
+
batch = submit_batch(flush=True)
|
958 |
+
if batch is None:
|
959 |
+
break
|
960 |
+
yield batch
|
961 |
+
|
962 |
+
executor.shutdown()
|
963 |
+
|
964 |
+
|
965 |
+
class ImageDataset(BaseDataset):
|
966 |
+
def __init__(
|
967 |
+
self,
|
968 |
+
resolution: Tuple[int, int],
|
969 |
+
caption_extension: Optional[str],
|
970 |
+
batch_size: int,
|
971 |
+
num_repeats: int,
|
972 |
+
enable_bucket: bool,
|
973 |
+
bucket_no_upscale: bool,
|
974 |
+
image_directory: Optional[str] = None,
|
975 |
+
image_jsonl_file: Optional[str] = None,
|
976 |
+
cache_directory: Optional[str] = None,
|
977 |
+
debug_dataset: bool = False,
|
978 |
+
architecture: str = "no_default",
|
979 |
+
):
|
980 |
+
super(ImageDataset, self).__init__(
|
981 |
+
resolution,
|
982 |
+
caption_extension,
|
983 |
+
batch_size,
|
984 |
+
num_repeats,
|
985 |
+
enable_bucket,
|
986 |
+
bucket_no_upscale,
|
987 |
+
cache_directory,
|
988 |
+
debug_dataset,
|
989 |
+
architecture,
|
990 |
+
)
|
991 |
+
self.image_directory = image_directory
|
992 |
+
self.image_jsonl_file = image_jsonl_file
|
993 |
+
if image_directory is not None:
|
994 |
+
self.datasource = ImageDirectoryDatasource(image_directory, caption_extension)
|
995 |
+
elif image_jsonl_file is not None:
|
996 |
+
self.datasource = ImageJsonlDatasource(image_jsonl_file)
|
997 |
+
else:
|
998 |
+
raise ValueError("image_directory or image_jsonl_file must be specified")
|
999 |
+
|
1000 |
+
if self.cache_directory is None:
|
1001 |
+
self.cache_directory = self.image_directory
|
1002 |
+
|
1003 |
+
self.batch_manager = None
|
1004 |
+
self.num_train_items = 0
|
1005 |
+
|
1006 |
+
def get_metadata(self):
|
1007 |
+
metadata = super().get_metadata()
|
1008 |
+
if self.image_directory is not None:
|
1009 |
+
metadata["image_directory"] = os.path.basename(self.image_directory)
|
1010 |
+
if self.image_jsonl_file is not None:
|
1011 |
+
metadata["image_jsonl_file"] = os.path.basename(self.image_jsonl_file)
|
1012 |
+
return metadata
|
1013 |
+
|
1014 |
+
def get_total_image_count(self):
|
1015 |
+
return len(self.datasource) if self.datasource.is_indexable() else None
|
1016 |
+
|
1017 |
+
def retrieve_latent_cache_batches(self, num_workers: int):
|
1018 |
+
buckset_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale, self.architecture)
|
1019 |
+
executor = ThreadPoolExecutor(max_workers=num_workers)
|
1020 |
+
|
1021 |
+
batches: dict[tuple[int, int], list[ItemInfo]] = {} # (width, height) -> [ItemInfo]
|
1022 |
+
futures = []
|
1023 |
+
|
1024 |
+
# aggregate futures and sort by bucket resolution
|
1025 |
+
def aggregate_future(consume_all: bool = False):
|
1026 |
+
while len(futures) >= num_workers or (consume_all and len(futures) > 0):
|
1027 |
+
completed_futures = [future for future in futures if future.done()]
|
1028 |
+
if len(completed_futures) == 0:
|
1029 |
+
if len(futures) >= num_workers or consume_all: # to avoid adding too many futures
|
1030 |
+
time.sleep(0.1)
|
1031 |
+
continue
|
1032 |
+
else:
|
1033 |
+
break # submit batch if possible
|
1034 |
+
|
1035 |
+
for future in completed_futures:
|
1036 |
+
original_size, item_key, image, caption = future.result()
|
1037 |
+
bucket_height, bucket_width = image.shape[:2]
|
1038 |
+
bucket_reso = (bucket_width, bucket_height)
|
1039 |
+
|
1040 |
+
item_info = ItemInfo(item_key, caption, original_size, bucket_reso, content=image)
|
1041 |
+
item_info.latent_cache_path = self.get_latent_cache_path(item_info)
|
1042 |
+
|
1043 |
+
if bucket_reso not in batches:
|
1044 |
+
batches[bucket_reso] = []
|
1045 |
+
batches[bucket_reso].append(item_info)
|
1046 |
+
|
1047 |
+
futures.remove(future)
|
1048 |
+
|
1049 |
+
# submit batch if some bucket has enough items
|
1050 |
+
def submit_batch(flush: bool = False):
|
1051 |
+
for key in batches:
|
1052 |
+
if len(batches[key]) >= self.batch_size or flush:
|
1053 |
+
batch = batches[key][0 : self.batch_size]
|
1054 |
+
if len(batches[key]) > self.batch_size:
|
1055 |
+
batches[key] = batches[key][self.batch_size :]
|
1056 |
+
else:
|
1057 |
+
del batches[key]
|
1058 |
+
return key, batch
|
1059 |
+
return None, None
|
1060 |
+
|
1061 |
+
for fetch_op in self.datasource:
|
1062 |
+
|
1063 |
+
# fetch and resize image in a separate thread
|
1064 |
+
def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, Image.Image, str]:
|
1065 |
+
image_key, image, caption = op()
|
1066 |
+
image: Image.Image
|
1067 |
+
image_size = image.size
|
1068 |
+
|
1069 |
+
bucket_reso = buckset_selector.get_bucket_resolution(image_size)
|
1070 |
+
image = resize_image_to_bucket(image, bucket_reso)
|
1071 |
+
return image_size, image_key, image, caption
|
1072 |
+
|
1073 |
+
future = executor.submit(fetch_and_resize, fetch_op)
|
1074 |
+
futures.append(future)
|
1075 |
+
aggregate_future()
|
1076 |
+
while True:
|
1077 |
+
key, batch = submit_batch()
|
1078 |
+
if key is None:
|
1079 |
+
break
|
1080 |
+
yield key, batch
|
1081 |
+
|
1082 |
+
aggregate_future(consume_all=True)
|
1083 |
+
while True:
|
1084 |
+
key, batch = submit_batch(flush=True)
|
1085 |
+
if key is None:
|
1086 |
+
break
|
1087 |
+
yield key, batch
|
1088 |
+
|
1089 |
+
executor.shutdown()
|
1090 |
+
|
1091 |
+
def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
|
1092 |
+
return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)
|
1093 |
+
|
1094 |
+
def prepare_for_training(self):
|
1095 |
+
bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale, self.architecture)
|
1096 |
+
|
1097 |
+
# glob cache files
|
1098 |
+
latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{self.architecture}.safetensors"))
|
1099 |
+
|
1100 |
+
# assign cache files to item info
|
1101 |
+
bucketed_item_info: dict[tuple[int, int], list[ItemInfo]] = {} # (width, height) -> [ItemInfo]
|
1102 |
+
for cache_file in latent_cache_files:
|
1103 |
+
tokens = os.path.basename(cache_file).split("_")
|
1104 |
+
|
1105 |
+
image_size = tokens[-2] # 0000x0000
|
1106 |
+
image_width, image_height = map(int, image_size.split("x"))
|
1107 |
+
image_size = (image_width, image_height)
|
1108 |
+
|
1109 |
+
item_key = "_".join(tokens[:-2])
|
1110 |
+
text_encoder_output_cache_file = os.path.join(self.cache_directory, f"{item_key}_{self.architecture}_te.safetensors")
|
1111 |
+
if not os.path.exists(text_encoder_output_cache_file):
|
1112 |
+
logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
|
1113 |
+
continue
|
1114 |
+
|
1115 |
+
bucket_reso = bucket_selector.get_bucket_resolution(image_size)
|
1116 |
+
item_info = ItemInfo(item_key, "", image_size, bucket_reso, latent_cache_path=cache_file)
|
1117 |
+
item_info.text_encoder_output_cache_path = text_encoder_output_cache_file
|
1118 |
+
|
1119 |
+
bucket = bucketed_item_info.get(bucket_reso, [])
|
1120 |
+
for _ in range(self.num_repeats):
|
1121 |
+
bucket.append(item_info)
|
1122 |
+
bucketed_item_info[bucket_reso] = bucket
|
1123 |
+
|
1124 |
+
# prepare batch manager
|
1125 |
+
self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
|
1126 |
+
self.batch_manager.show_bucket_info()
|
1127 |
+
|
1128 |
+
self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])
|
1129 |
+
|
1130 |
+
def shuffle_buckets(self):
|
1131 |
+
# set random seed for this epoch
|
1132 |
+
random.seed(self.seed + self.current_epoch)
|
1133 |
+
self.batch_manager.shuffle()
|
1134 |
+
|
1135 |
+
def __len__(self):
|
1136 |
+
if self.batch_manager is None:
|
1137 |
+
return 100 # dummy value
|
1138 |
+
return len(self.batch_manager)
|
1139 |
+
|
1140 |
+
def __getitem__(self, idx):
|
1141 |
+
return self.batch_manager[idx]
|
1142 |
+
|
1143 |
+
|
1144 |
+
class VideoDataset(BaseDataset):
|
1145 |
+
def __init__(
|
1146 |
+
self,
|
1147 |
+
resolution: Tuple[int, int],
|
1148 |
+
caption_extension: Optional[str],
|
1149 |
+
batch_size: int,
|
1150 |
+
num_repeats: int,
|
1151 |
+
enable_bucket: bool,
|
1152 |
+
bucket_no_upscale: bool,
|
1153 |
+
frame_extraction: Optional[str] = "head",
|
1154 |
+
frame_stride: Optional[int] = 1,
|
1155 |
+
frame_sample: Optional[int] = 1,
|
1156 |
+
target_frames: Optional[list[int]] = None,
|
1157 |
+
video_directory: Optional[str] = None,
|
1158 |
+
video_jsonl_file: Optional[str] = None,
|
1159 |
+
cache_directory: Optional[str] = None,
|
1160 |
+
debug_dataset: bool = False,
|
1161 |
+
architecture: str = "no_default",
|
1162 |
+
):
|
1163 |
+
super(VideoDataset, self).__init__(
|
1164 |
+
resolution,
|
1165 |
+
caption_extension,
|
1166 |
+
batch_size,
|
1167 |
+
num_repeats,
|
1168 |
+
enable_bucket,
|
1169 |
+
bucket_no_upscale,
|
1170 |
+
cache_directory,
|
1171 |
+
debug_dataset,
|
1172 |
+
architecture,
|
1173 |
+
)
|
1174 |
+
self.video_directory = video_directory
|
1175 |
+
self.video_jsonl_file = video_jsonl_file
|
1176 |
+
self.target_frames = target_frames
|
1177 |
+
self.frame_extraction = frame_extraction
|
1178 |
+
self.frame_stride = frame_stride
|
1179 |
+
self.frame_sample = frame_sample
|
1180 |
+
|
1181 |
+
if video_directory is not None:
|
1182 |
+
self.datasource = VideoDirectoryDatasource(video_directory, caption_extension)
|
1183 |
+
elif video_jsonl_file is not None:
|
1184 |
+
self.datasource = VideoJsonlDatasource(video_jsonl_file)
|
1185 |
+
|
1186 |
+
if self.frame_extraction == "uniform" and self.frame_sample == 1:
|
1187 |
+
self.frame_extraction = "head"
|
1188 |
+
logger.warning("frame_sample is set to 1 for frame_extraction=uniform. frame_extraction is changed to head.")
|
1189 |
+
if self.frame_extraction == "head":
|
1190 |
+
# head extraction. we can limit the number of frames to be extracted
|
1191 |
+
self.datasource.set_start_and_end_frame(0, max(self.target_frames))
|
1192 |
+
|
1193 |
+
if self.cache_directory is None:
|
1194 |
+
self.cache_directory = self.video_directory
|
1195 |
+
|
1196 |
+
self.batch_manager = None
|
1197 |
+
self.num_train_items = 0
|
1198 |
+
|
1199 |
+
def get_metadata(self):
|
1200 |
+
metadata = super().get_metadata()
|
1201 |
+
if self.video_directory is not None:
|
1202 |
+
metadata["video_directory"] = os.path.basename(self.video_directory)
|
1203 |
+
if self.video_jsonl_file is not None:
|
1204 |
+
metadata["video_jsonl_file"] = os.path.basename(self.video_jsonl_file)
|
1205 |
+
metadata["frame_extraction"] = self.frame_extraction
|
1206 |
+
metadata["frame_stride"] = self.frame_stride
|
1207 |
+
metadata["frame_sample"] = self.frame_sample
|
1208 |
+
metadata["target_frames"] = self.target_frames
|
1209 |
+
return metadata
|
1210 |
+
|
1211 |
+
def retrieve_latent_cache_batches(self, num_workers: int):
|
1212 |
+
buckset_selector = BucketSelector(self.resolution, architecture=self.architecture)
|
1213 |
+
self.datasource.set_bucket_selector(buckset_selector)
|
1214 |
+
|
1215 |
+
executor = ThreadPoolExecutor(max_workers=num_workers)
|
1216 |
+
|
1217 |
+
# key: (width, height, frame_count), value: [ItemInfo]
|
1218 |
+
batches: dict[tuple[int, int, int], list[ItemInfo]] = {}
|
1219 |
+
futures = []
|
1220 |
+
|
1221 |
+
def aggregate_future(consume_all: bool = False):
|
1222 |
+
while len(futures) >= num_workers or (consume_all and len(futures) > 0):
|
1223 |
+
completed_futures = [future for future in futures if future.done()]
|
1224 |
+
if len(completed_futures) == 0:
|
1225 |
+
if len(futures) >= num_workers or consume_all: # to avoid adding too many futures
|
1226 |
+
time.sleep(0.1)
|
1227 |
+
continue
|
1228 |
+
else:
|
1229 |
+
break # submit batch if possible
|
1230 |
+
|
1231 |
+
for future in completed_futures:
|
1232 |
+
original_frame_size, video_key, video, caption = future.result()
|
1233 |
+
|
1234 |
+
frame_count = len(video)
|
1235 |
+
video = np.stack(video, axis=0)
|
1236 |
+
height, width = video.shape[1:3]
|
1237 |
+
bucket_reso = (width, height) # already resized
|
1238 |
+
|
1239 |
+
crop_pos_and_frames = []
|
1240 |
+
if self.frame_extraction == "head":
|
1241 |
+
for target_frame in self.target_frames:
|
1242 |
+
if frame_count >= target_frame:
|
1243 |
+
crop_pos_and_frames.append((0, target_frame))
|
1244 |
+
elif self.frame_extraction == "chunk":
|
1245 |
+
# split by target_frames
|
1246 |
+
for target_frame in self.target_frames:
|
1247 |
+
for i in range(0, frame_count, target_frame):
|
1248 |
+
if i + target_frame <= frame_count:
|
1249 |
+
crop_pos_and_frames.append((i, target_frame))
|
1250 |
+
elif self.frame_extraction == "slide":
|
1251 |
+
# slide window
|
1252 |
+
for target_frame in self.target_frames:
|
1253 |
+
if frame_count >= target_frame:
|
1254 |
+
for i in range(0, frame_count - target_frame + 1, self.frame_stride):
|
1255 |
+
crop_pos_and_frames.append((i, target_frame))
|
1256 |
+
elif self.frame_extraction == "uniform":
|
1257 |
+
# select N frames uniformly
|
1258 |
+
for target_frame in self.target_frames:
|
1259 |
+
if frame_count >= target_frame:
|
1260 |
+
frame_indices = np.linspace(0, frame_count - target_frame, self.frame_sample, dtype=int)
|
1261 |
+
for i in frame_indices:
|
1262 |
+
crop_pos_and_frames.append((i, target_frame))
|
1263 |
+
else:
|
1264 |
+
raise ValueError(f"frame_extraction {self.frame_extraction} is not supported")
|
1265 |
+
|
1266 |
+
for crop_pos, target_frame in crop_pos_and_frames:
|
1267 |
+
cropped_video = video[crop_pos : crop_pos + target_frame]
|
1268 |
+
body, ext = os.path.splitext(video_key)
|
1269 |
+
item_key = f"{body}_{crop_pos:05d}-{target_frame:03d}{ext}"
|
1270 |
+
batch_key = (*bucket_reso, target_frame) # bucket_reso with frame_count
|
1271 |
+
|
1272 |
+
item_info = ItemInfo(
|
1273 |
+
item_key, caption, original_frame_size, batch_key, frame_count=target_frame, content=cropped_video
|
1274 |
+
)
|
1275 |
+
item_info.latent_cache_path = self.get_latent_cache_path(item_info)
|
1276 |
+
|
1277 |
+
batch = batches.get(batch_key, [])
|
1278 |
+
batch.append(item_info)
|
1279 |
+
batches[batch_key] = batch
|
1280 |
+
|
1281 |
+
futures.remove(future)
|
1282 |
+
|
1283 |
+
def submit_batch(flush: bool = False):
|
1284 |
+
for key in batches:
|
1285 |
+
if len(batches[key]) >= self.batch_size or flush:
|
1286 |
+
batch = batches[key][0 : self.batch_size]
|
1287 |
+
if len(batches[key]) > self.batch_size:
|
1288 |
+
batches[key] = batches[key][self.batch_size :]
|
1289 |
+
else:
|
1290 |
+
del batches[key]
|
1291 |
+
return key, batch
|
1292 |
+
return None, None
|
1293 |
+
|
1294 |
+
for operator in self.datasource:
|
1295 |
+
|
1296 |
+
def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, list[np.ndarray], str]:
|
1297 |
+
video_key, video, caption = op()
|
1298 |
+
video: list[np.ndarray]
|
1299 |
+
frame_size = (video[0].shape[1], video[0].shape[0])
|
1300 |
+
|
1301 |
+
# resize if necessary
|
1302 |
+
bucket_reso = buckset_selector.get_bucket_resolution(frame_size)
|
1303 |
+
video = [resize_image_to_bucket(frame, bucket_reso) for frame in video]
|
1304 |
+
|
1305 |
+
return frame_size, video_key, video, caption
|
1306 |
+
|
1307 |
+
future = executor.submit(fetch_and_resize, operator)
|
1308 |
+
futures.append(future)
|
1309 |
+
aggregate_future()
|
1310 |
+
while True:
|
1311 |
+
key, batch = submit_batch()
|
1312 |
+
if key is None:
|
1313 |
+
break
|
1314 |
+
yield key, batch
|
1315 |
+
|
1316 |
+
aggregate_future(consume_all=True)
|
1317 |
+
while True:
|
1318 |
+
key, batch = submit_batch(flush=True)
|
1319 |
+
if key is None:
|
1320 |
+
break
|
1321 |
+
yield key, batch
|
1322 |
+
|
1323 |
+
executor.shutdown()
|
1324 |
+
|
1325 |
+
def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
|
1326 |
+
return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)
|
1327 |
+
|
1328 |
+
def prepare_for_training(self):
|
1329 |
+
bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale, self.architecture)
|
1330 |
+
|
1331 |
+
# glob cache files
|
1332 |
+
latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{self.architecture}.safetensors"))
|
1333 |
+
|
1334 |
+
# assign cache files to item info
|
1335 |
+
bucketed_item_info: dict[tuple[int, int, int], list[ItemInfo]] = {} # (width, height, frame_count) -> [ItemInfo]
|
1336 |
+
for cache_file in latent_cache_files:
|
1337 |
+
tokens = os.path.basename(cache_file).split("_")
|
1338 |
+
|
1339 |
+
image_size = tokens[-2] # 0000x0000
|
1340 |
+
image_width, image_height = map(int, image_size.split("x"))
|
1341 |
+
image_size = (image_width, image_height)
|
1342 |
+
|
1343 |
+
frame_pos, frame_count = tokens[-3].split("-")
|
1344 |
+
frame_pos, frame_count = int(frame_pos), int(frame_count)
|
1345 |
+
|
1346 |
+
item_key = "_".join(tokens[:-3])
|
1347 |
+
text_encoder_output_cache_file = os.path.join(self.cache_directory, f"{item_key}_{self.architecture}_te.safetensors")
|
1348 |
+
if not os.path.exists(text_encoder_output_cache_file):
|
1349 |
+
logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
|
1350 |
+
continue
|
1351 |
+
|
1352 |
+
bucket_reso = bucket_selector.get_bucket_resolution(image_size)
|
1353 |
+
bucket_reso = (*bucket_reso, frame_count)
|
1354 |
+
item_info = ItemInfo(item_key, "", image_size, bucket_reso, frame_count=frame_count, latent_cache_path=cache_file)
|
1355 |
+
item_info.text_encoder_output_cache_path = text_encoder_output_cache_file
|
1356 |
+
|
1357 |
+
bucket = bucketed_item_info.get(bucket_reso, [])
|
1358 |
+
for _ in range(self.num_repeats):
|
1359 |
+
bucket.append(item_info)
|
1360 |
+
bucketed_item_info[bucket_reso] = bucket
|
1361 |
+
|
1362 |
+
# prepare batch manager
|
1363 |
+
self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
|
1364 |
+
self.batch_manager.show_bucket_info()
|
1365 |
+
|
1366 |
+
self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])
|
1367 |
+
|
1368 |
+
def shuffle_buckets(self):
|
1369 |
+
# set random seed for this epoch
|
1370 |
+
random.seed(self.seed + self.current_epoch)
|
1371 |
+
self.batch_manager.shuffle()
|
1372 |
+
|
1373 |
+
def __len__(self):
|
1374 |
+
if self.batch_manager is None:
|
1375 |
+
return 100 # dummy value
|
1376 |
+
return len(self.batch_manager)
|
1377 |
+
|
1378 |
+
def __getitem__(self, idx):
|
1379 |
+
return self.batch_manager[idx]
|
1380 |
+
|
1381 |
+
|
1382 |
+
class DatasetGroup(torch.utils.data.ConcatDataset):
|
1383 |
+
def __init__(self, datasets: Sequence[Union[ImageDataset, VideoDataset]]):
|
1384 |
+
super().__init__(datasets)
|
1385 |
+
self.datasets: list[Union[ImageDataset, VideoDataset]] = datasets
|
1386 |
+
self.num_train_items = 0
|
1387 |
+
for dataset in self.datasets:
|
1388 |
+
self.num_train_items += dataset.num_train_items
|
1389 |
+
|
1390 |
+
def set_current_epoch(self, epoch):
|
1391 |
+
for dataset in self.datasets:
|
1392 |
+
dataset.set_current_epoch(epoch)
|
1393 |
+
|
1394 |
+
def set_current_step(self, step):
|
1395 |
+
for dataset in self.datasets:
|
1396 |
+
dataset.set_current_step(step)
|
1397 |
+
|
1398 |
+
def set_max_train_steps(self, max_train_steps):
|
1399 |
+
for dataset in self.datasets:
|
1400 |
+
dataset.set_max_train_steps(max_train_steps)
|
docs/advanced_config.md
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
> 📝 Click on the language section to expand / 言語をクリックして展開
|
2 |
+
|
3 |
+
# Advanced configuration / 高度な設定
|
4 |
+
|
5 |
+
## How to specify `network_args` / `network_args`の指定方法
|
6 |
+
|
7 |
+
The `--network_args` option is an option for specifying detailed arguments to LoRA. Specify the arguments in the form of `key=value` in `--network_args`.
|
8 |
+
|
9 |
+
<details>
|
10 |
+
<summary>日本語</summary>
|
11 |
+
`--network_args`オプションは、LoRAへの詳細な引数を指定するためのオプションです。`--network_args`には、`key=value`の形式で引数を指定します。
|
12 |
+
</details>
|
13 |
+
|
14 |
+
### Example / 記述例
|
15 |
+
|
16 |
+
If you specify it on the command line, write as follows. / コマンドラインで指定する場合は以下のように記述します。
|
17 |
+
|
18 |
+
```bash
|
19 |
+
accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 hv_train_network.py --dit ...
|
20 |
+
--network_module networks.lora --network_dim 32
|
21 |
+
--network_args "key1=value1" "key2=value2" ...
|
22 |
+
```
|
23 |
+
|
24 |
+
If you specify it in the configuration file, write as follows. / 設定ファイルで指定する場合は以下のように記述します。
|
25 |
+
|
26 |
+
```toml
|
27 |
+
network_args = ["key1=value1", "key2=value2", ...]
|
28 |
+
```
|
29 |
+
|
30 |
+
If you specify `"verbose=True"`, detailed information of LoRA will be displayed. / `"verbose=True"`を指定するとLoRAの詳細な情報が表示されます。
|
31 |
+
|
32 |
+
```bash
|
33 |
+
--network_args "verbose=True" "key1=value1" "key2=value2" ...
|
34 |
+
```
|
35 |
+
|
36 |
+
## LoRA+
|
37 |
+
|
38 |
+
LoRA+ is a method to improve the training speed by increasing the learning rate of the UP side (LoRA-B) of LoRA. Specify the multiplier for the learning rate. The original paper recommends 16, but adjust as needed. It seems to be good to start from around 4. For details, please refer to the [related PR of sd-scripts](https://github.com/kohya-ss/sd-scripts/pull/1233).
|
39 |
+
|
40 |
+
Specify `loraplus_lr_ratio` with `--network_args`.
|
41 |
+
|
42 |
+
<details>
|
43 |
+
<summary>日本語</summary>
|
44 |
+
|
45 |
+
LoRA+は、LoRAのUP側(LoRA-B)の学習率を上げることで学習速度を向上させる手法です。学習率に対する倍率を指定します。元論文では16を推奨していますが、必要に応じて調整してください。4程度から始めるとよいようです。詳細は[sd-scriptsの関連PR]https://github.com/kohya-ss/sd-scripts/pull/1233)を参照してください。
|
46 |
+
|
47 |
+
`--network_args`で`loraplus_lr_ratio`を指定します。
|
48 |
+
</details>
|
49 |
+
|
50 |
+
### Example / 記述例
|
51 |
+
|
52 |
+
```bash
|
53 |
+
accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 hv_train_network.py --dit ...
|
54 |
+
--network_module networks.lora --network_dim 32 --network_args "loraplus_lr_ratio=4" ...
|
55 |
+
```
|
56 |
+
|
57 |
+
## Select the target modules of LoRA / LoRAの対象モジュールを選択する
|
58 |
+
|
59 |
+
*This feature is highly experimental and the specification may change. / この機能は特に実験的なもので、仕様は変更される可能性があります。*
|
60 |
+
|
61 |
+
By specifying `exclude_patterns` and `include_patterns` with `--network_args`, you can select the target modules of LoRA.
|
62 |
+
|
63 |
+
`exclude_patterns` excludes modules that match the specified pattern. `include_patterns` targets only modules that match the specified pattern.
|
64 |
+
|
65 |
+
Specify the values as a list. For example, `"exclude_patterns=[r'.*single_blocks.*', r'.*double_blocks\.[0-9]\..*']"`.
|
66 |
+
|
67 |
+
The pattern is a regular expression for the module name. The module name is in the form of `double_blocks.0.img_mod.linear` or `single_blocks.39.modulation.linear`. The regular expression is not a partial match but a complete match.
|
68 |
+
|
69 |
+
The patterns are applied in the order of `exclude_patterns`→`include_patterns`. By default, the Linear layers of `img_mod`, `txt_mod`, and `modulation` of double blocks and single blocks are excluded.
|
70 |
+
|
71 |
+
(`.*(img_mod|txt_mod|modulation).*` is specified.)
|
72 |
+
|
73 |
+
<details>
|
74 |
+
<summary>日本語</summary>
|
75 |
+
|
76 |
+
`--network_args`で`exclude_patterns`と`include_patterns`を指定することで、LoRAの対象モジュールを選択することができます。
|
77 |
+
|
78 |
+
`exclude_patterns`は、指定したパターンに一致するモジュールを除外します。`include_patterns`は、指定したパターンに一致するモジュールのみを対象とします。
|
79 |
+
|
80 |
+
値は、リストで指定します。`"exclude_patterns=[r'.*single_blocks.*', r'.*double_blocks\.[0-9]\..*']"`のようになります。
|
81 |
+
|
82 |
+
パターンは、モジュール名に対する正規表現です。モジュール名は、たとえば`double_blocks.0.img_mod.linear`や`single_blocks.39.modulation.linear`のような形式です。正規表現は部分一致ではなく完全一致です。
|
83 |
+
|
84 |
+
パターンは、`exclude_patterns`→`include_patterns`の順で適用されます。デフォルトは、double blocksとsingle blocksのLinear層のうち、`img_mod`、`txt_mod`、`modulation`が除外されています。
|
85 |
+
|
86 |
+
(`.*(img_mod|txt_mod|modulation).*`が指定されています。)
|
87 |
+
</details>
|
88 |
+
|
89 |
+
### Example / 記述例
|
90 |
+
|
91 |
+
Only the modules of double blocks / double blocksのモジュールのみを対象とする場合:
|
92 |
+
|
93 |
+
```bash
|
94 |
+
--network_args "exclude_patterns=[r'.*single_blocks.*']"
|
95 |
+
```
|
96 |
+
|
97 |
+
Only the modules of single blocks from the 10th / single blocksの10番目以降のLinearモジュールのみを対象とする場合:
|
98 |
+
|
99 |
+
```bash
|
100 |
+
--network_args "exclude_patterns=[r'.*']" "include_patterns=[r'.*single_blocks\.\d{2}\.linear.*']"
|
101 |
+
```
|
102 |
+
|
103 |
+
## Save and view logs in TensorBoard format / TensorBoard形式のログの保存と参照
|
104 |
+
|
105 |
+
Specify the folder to save the logs with the `--logging_dir` option. Logs in TensorBoard format will be saved.
|
106 |
+
|
107 |
+
For example, if you specify `--logging_dir=logs`, a `logs` folder will be created in the working folder, and logs will be saved in the date folder inside it.
|
108 |
+
|
109 |
+
Also, if you specify the `--log_prefix` option, the specified string will be added before the date. For example, use `--logging_dir=logs --log_prefix=lora_setting1_` for identification.
|
110 |
+
|
111 |
+
To view logs in TensorBoard, open another command prompt and activate the virtual environment. Then enter the following in the working folder.
|
112 |
+
|
113 |
+
```powershell
|
114 |
+
tensorboard --logdir=logs
|
115 |
+
```
|
116 |
+
|
117 |
+
(tensorboard installation is required.)
|
118 |
+
|
119 |
+
Then open a browser and access http://localhost:6006/ to display it.
|
120 |
+
|
121 |
+
<details>
|
122 |
+
<summary>日本語</summary>
|
123 |
+
`--logging_dir`オプションにログ保存先フォルダを指定してください。TensorBoard形式のログが保存されます。
|
124 |
+
|
125 |
+
たとえば`--logging_dir=logs`と指定すると、作業フォルダにlogsフォルダが作成され、その中の日時フォルダにログが保存されます。
|
126 |
+
|
127 |
+
また`--log_prefix`オプションを指定すると、日時の前に指定した文字列が追加されます。`--logging_dir=logs --log_prefix=lora_setting1_`などとして識別用にお使いください。
|
128 |
+
|
129 |
+
TensorBoardでログを確認するには、別のコマンドプロンプトを開き、仮想環境を有効にしてから、作業フォルダで以下のように入力します。
|
130 |
+
|
131 |
+
```powershell
|
132 |
+
tensorboard --logdir=logs
|
133 |
+
```
|
134 |
+
|
135 |
+
(tensorboardのインストールが必要です。)
|
136 |
+
|
137 |
+
その後ブラウザを開き、http://localhost:6006/ へアクセスすると表示されます。
|
138 |
+
</details>
|
139 |
+
|
140 |
+
## Save and view logs in wandb / wandbでログの保存と参照
|
141 |
+
|
142 |
+
`--log_with wandb` option is available to save logs in wandb format. `tensorboard` or `all` is also available. The default is `tensorboard`.
|
143 |
+
|
144 |
+
Specify the project name with `--log_tracker_name` when using wandb.
|
145 |
+
|
146 |
+
<details>
|
147 |
+
<summary>日本語</summary>
|
148 |
+
`--log_with wandb`オプションを指定するとwandb形式でログを保存することができます。`tensorboard`や`all`も指定可能です。デフォルトは`tensorboard`です。
|
149 |
+
|
150 |
+
wandbを使用する場合は、`--log_tracker_name`でプロジェクト名を指定してください。
|
151 |
+
</details>
|
docs/sampling_during_training.md
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
> 📝 Click on the language section to expand / 言語をクリックして展開
|
2 |
+
|
3 |
+
# Sampling during training / 学習中のサンプル画像生成
|
4 |
+
|
5 |
+
By preparing a prompt file, you can generate sample images during training.
|
6 |
+
|
7 |
+
Please be aware that it consumes a considerable amount of VRAM, so be careful when generating sample images for videos with a large number of frames. Also, since it takes time to generate, adjust the frequency of sample image generation as needed.
|
8 |
+
|
9 |
+
<details>
|
10 |
+
<summary>日本語</summary>
|
11 |
+
|
12 |
+
プロンプトファイルを用意することで、学習中にサンプル画像を生成することができます。
|
13 |
+
|
14 |
+
VRAMをそれなりに消費しますので、特にフレーム数が多い動画を生成する場合は注意してください。また生成には時間がかかりますので、サンプル画像生成の頻度は適宜調整してください。
|
15 |
+
</details>
|
16 |
+
|
17 |
+
## How to use / 使い方
|
18 |
+
|
19 |
+
### Command line options for training with sampling / サンプル画像生成に関連する学習時のコマンドラインオプション
|
20 |
+
|
21 |
+
Example of command line options for training with sampling / 記述例:
|
22 |
+
|
23 |
+
```bash
|
24 |
+
--vae path/to/ckpts/hunyuan-video-t2v-720p/vae/pytorch_model.pt
|
25 |
+
--vae_chunk_size 32 --vae_spatial_tile_sample_min_size 128
|
26 |
+
--text_encoder1 path/to/ckpts/text_encoder
|
27 |
+
--text_encoder2 path/to/ckpts/text_encoder_2
|
28 |
+
--sample_prompts /path/to/prompt_file.txt
|
29 |
+
--sample_every_n_epochs 1 --sample_every_n_steps 1000 --sample_at_first
|
30 |
+
```
|
31 |
+
|
32 |
+
`--vae`, `--vae_chunk_size`, `--vae_spatial_tile_sample_min_size`, `--text_encoder1`, `--text_encoder2` are the same as when generating images, so please refer to [here](/README.md#inference) for details. `--fp8_llm` can also be specified.
|
33 |
+
|
34 |
+
`--sample_prompts` specifies the path to the prompt file used for sample image generation. Details are described below.
|
35 |
+
|
36 |
+
`--sample_every_n_epochs` specifies how often to generate sample images in epochs, and `--sample_every_n_steps` specifies how often to generate sample images in steps.
|
37 |
+
|
38 |
+
`--sample_at_first` is specified when generating sample images at the beginning of training.
|
39 |
+
|
40 |
+
Sample images and videos are saved in the `sample` directory in the directory specified by `--output_dir`. They are saved as `.png` for still images and `.mp4` for videos.
|
41 |
+
|
42 |
+
<details>
|
43 |
+
<summary>日本語</summary>
|
44 |
+
|
45 |
+
`--vae`、`--vae_chunk_size`、`--vae_spatial_tile_sample_min_size`、`--text_encoder1`、`--text_encoder2`は、画像生成時と同様ですので、詳細は[こちら](/README.ja.md#推論)を参照してください。`--fp8_llm`も指定可能です。
|
46 |
+
|
47 |
+
`--sample_prompts`は、サンプル画像生成に使用するプロンプトファイルのパスを指定します。詳細は後述します。
|
48 |
+
|
49 |
+
`--sample_every_n_epochs`は、何エポックごとにサンプル画像を生成するかを、`--sample_every_n_steps`は、何ステップごとにサンプル画像を生成するかを指定します。
|
50 |
+
|
51 |
+
`--sample_at_first`は、学習開始時にサンプル画像を生成する場合に指定します。
|
52 |
+
|
53 |
+
サンプル画像、動画は、`--output_dir`で指定したディレクトリ内の、`sample`ディレクトリに保存されます。静止画の場合は`.png`、動画の場合は`.mp4`で保存されます。
|
54 |
+
</details>
|
55 |
+
|
56 |
+
### Prompt file / プロンプトファイル
|
57 |
+
|
58 |
+
The prompt file is a text file that contains the prompts for generating sample images. The example is as follows. / プロンプトファイルは、サンプル画像生成のためのプロンプトを記述したテキストファイルです。例は以下の通りです。
|
59 |
+
|
60 |
+
```
|
61 |
+
# prompt 1: for generating a cat video
|
62 |
+
A cat walks on the grass, realistic style. --w 640 --h 480 --f 25 --d 1 --s 20
|
63 |
+
|
64 |
+
# prompt 2: for generating a dog image
|
65 |
+
A dog runs on the beach, realistic style. --w 960 --h 544 --f 1 --d 2 --s 20
|
66 |
+
```
|
67 |
+
|
68 |
+
A line starting with `#` is a comment.
|
69 |
+
|
70 |
+
* `--w` specifies the width of the generated image or video. The default is 256.
|
71 |
+
* `--h` specifies the height. The default is 256.
|
72 |
+
* `--f` specifies the number of frames. The default is 1, which generates a still image.
|
73 |
+
* `--d` specifies the seed. The default is random.
|
74 |
+
* `--s` specifies the number of steps in generation. The default is 20.
|
75 |
+
* `--g` specifies the guidance scale. The default is 6.0, which is the default value during inference of HunyuanVideo. Specify 1.0 for SkyReels V1 models. Ignore this option for Wan2.1 models.
|
76 |
+
* `--fs` specifies the discrete flow shift. The default is 14.5, which corresponds to the number of steps 20. In the HunyuanVideo paper, 7.0 is recommended for 50 steps, and 17.0 is recommended for less than 20 steps (e.g. 10).
|
77 |
+
|
78 |
+
If you train I2V models, you can use the additional options below.
|
79 |
+
|
80 |
+
* `--i path/to/image.png`: the image path for image2video inference.
|
81 |
+
|
82 |
+
If you train the model with classifier free guidance, you can use the additional options below.
|
83 |
+
|
84 |
+
*`--n negative prompt...`: the negative prompt for the classifier free guidance.
|
85 |
+
*`--l 6.0`: the classifier free guidance scale. Should be set to 6.0 for SkyReels V1 models. 5.0 is the default value for Wan2.1 (if omitted).
|
86 |
+
|
87 |
+
<details>
|
88 |
+
<summary>日本語</summary>
|
89 |
+
|
90 |
+
`#` で始まる行はコメントです。
|
91 |
+
|
92 |
+
* `--w` 生成画像、動画の幅を指定します。省略時は256です。
|
93 |
+
* `--h` 高さを指定します。省略時は256です。
|
94 |
+
* `--f` フレーム数を指定します。省略時は1で、静止画を生成します。
|
95 |
+
* `--d` シードを指定します。省略時はランダムです。
|
96 |
+
* `--s` 生成におけるステップ数を指定します。省略時は20です。
|
97 |
+
* `--g` guidance scaleを指定します。省略時は6.0で、HunyuanVideoの推論時のデフォルト値です。
|
98 |
+
* `--fs` discrete flow shiftを指定します。省略時は14.5で、ステップ数20の場合に対応した値です。HunyuanVideoの論文では、ステップ数50の場合は7.0、ステップ数20未満(10など)で17.0が推奨されています。
|
99 |
+
|
100 |
+
I2Vモデルを学習する場合、以下の追加オプションを使用できます。
|
101 |
+
|
102 |
+
* `--i path/to/image.png`: image2video推論用の画像パス。
|
103 |
+
|
104 |
+
classifier free guidance(ネガティブプロンプト)を必要とするモデルを学習する場合、以下の追加オプションを使用できます。
|
105 |
+
|
106 |
+
*`--n negative prompt...`: classifier free guidance用のネガティブプロンプト。
|
107 |
+
*`--l 6.0`: classifier free guidance scale。SkyReels V1モデルの場合は6.0に設定してください。Wan2.1の場合はデフォルト値が5.0です(省略時)。
|
108 |
+
</details>
|
docs/wan.md
ADDED
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
> 📝 Click on the language section to expand / 言語をクリックして展開
|
2 |
+
|
3 |
+
# Wan 2.1
|
4 |
+
|
5 |
+
## Overview / 概要
|
6 |
+
|
7 |
+
This is an unofficial training and inference script for [Wan2.1](https://github.com/Wan-Video/Wan2.1). The features are as follows.
|
8 |
+
|
9 |
+
- fp8 support and memory reduction by block swap: Inference of a 720x1280x81frames videos with 24GB VRAM, training with 720x1280 images with 24GB VRAM
|
10 |
+
- Inference without installing Flash attention (using PyTorch's scaled dot product attention)
|
11 |
+
- Supports xformers and Sage attention
|
12 |
+
|
13 |
+
This feature is experimental.
|
14 |
+
|
15 |
+
<details>
|
16 |
+
<summary>日本語</summary>
|
17 |
+
[Wan2.1](https://github.com/Wan-Video/Wan2.1) の非公式の学習および推論スクリプトです。
|
18 |
+
|
19 |
+
以下の特徴があります。
|
20 |
+
|
21 |
+
- fp8対応およびblock swapによる省メモリ化:720x1280x81framesの動画を24GB VRAMで推論可能、720x1280の画像での学習が24GB VRAMで可能
|
22 |
+
- Flash attentionのインストールなしでの実行(PyTorchのscaled dot product attentionを使用)
|
23 |
+
- xformersおよびSage attention対応
|
24 |
+
|
25 |
+
この機能は実験的なものです。
|
26 |
+
</details>
|
27 |
+
|
28 |
+
## Download the model / モデルのダウンロード
|
29 |
+
|
30 |
+
Download the T5 `models_t5_umt5-xxl-enc-bf16.pth` and CLIP `models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` from the following page: https://huggingface.co/Wan-AI/Wan2.1-T2V-14B/tree/main
|
31 |
+
|
32 |
+
Download the VAE from the above page `Wan2.1_VAE.pth` or download `split_files/vae/wan_2.1_vae.safetensors` from the following page: https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/vae
|
33 |
+
|
34 |
+
Download the DiT weights from the following page: https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/diffusion_models
|
35 |
+
|
36 |
+
Please select the appropriate weights according to T2V, I2V, resolution, model size, etc. fp8 models can be used if `--fp8` is specified.
|
37 |
+
|
38 |
+
(Thanks to Comfy-Org for providing the repackaged weights.)
|
39 |
+
<details>
|
40 |
+
<summary>日本語</summary>
|
41 |
+
T5 `models_t5_umt5-xxl-enc-bf16.pth` およびCLIP `models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` を、次のページからダウンロードしてください:https://huggingface.co/Wan-AI/Wan2.1-T2V-14B/tree/main
|
42 |
+
|
43 |
+
VAEは上のページから `Wan2.1_VAE.pth` をダウンロードするか、次のページから `split_files/vae/wan_2.1_vae.safetensors` をダウンロードしてください:https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/vae
|
44 |
+
|
45 |
+
DiTの重みを次のページからダウンロードしてください:https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/diffusion_models
|
46 |
+
|
47 |
+
T2VやI2V、解像度、モデルサイズなどにより適切な重みを選択してください。`--fp8`指定時はfp8モデルも使用できます。
|
48 |
+
|
49 |
+
(repackaged版の重みを提供してくださっているComfy-Orgに感謝いたします。)
|
50 |
+
</details>
|
51 |
+
|
52 |
+
## Pre-caching / 事前キャッシュ
|
53 |
+
|
54 |
+
### Latent Pre-caching
|
55 |
+
|
56 |
+
Latent pre-caching is almost the same as in HunyuanVideo. Create the cache using the following command:
|
57 |
+
|
58 |
+
```bash
|
59 |
+
python wan_cache_latents.py --dataset_config path/to/toml --vae path/to/wan_2.1_vae.safetensors
|
60 |
+
```
|
61 |
+
|
62 |
+
If you train I2V models, add `--clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` to specify the CLIP model. If not specified, the training will raise an error.
|
63 |
+
|
64 |
+
If you're running low on VRAM, specify `--vae_cache_cpu` to use the CPU for the VAE internal cache, which will reduce VRAM usage somewhat.
|
65 |
+
|
66 |
+
<details>
|
67 |
+
<summary>日本語</summary>
|
68 |
+
latentの事前キャッシングはHunyuanVideoとほぼ同じです。上のコマンド例を使用してキャッシュを作成してください。
|
69 |
+
|
70 |
+
I2Vモデルを学習する場合は、`--clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` を追加してCLIPモデルを指定してください。指定しないと学習時にエラーが発生します。
|
71 |
+
|
72 |
+
VRAMが不足している場合は、`--vae_cache_cpu` を指定するとVAEの内部キャッシュにCPUを使うことで、使用VRAMを多少削減できます。
|
73 |
+
</details>
|
74 |
+
|
75 |
+
### Text Encoder Output Pre-caching
|
76 |
+
|
77 |
+
Text encoder output pre-caching is also almost the same as in HunyuanVideo. Create the cache using the following command:
|
78 |
+
|
79 |
+
```bash
|
80 |
+
python wan_cache_text_encoder_outputs.py --dataset_config path/to/toml --t5 path/to/models_t5_umt5-xxl-enc-bf16.pth --batch_size 16
|
81 |
+
```
|
82 |
+
|
83 |
+
Adjust `--batch_size` according to your available VRAM.
|
84 |
+
|
85 |
+
For systems with limited VRAM (less than ~16GB), use `--fp8_t5` to run the T5 in fp8 mode.
|
86 |
+
|
87 |
+
<details>
|
88 |
+
<summary>日本語</summary>
|
89 |
+
テキストエンコーダ出力の事前キャッシングもHunyuanVideoとほぼ同じです。上のコマンド例を使用してキャッシュを作成してください。
|
90 |
+
|
91 |
+
使用可能なVRAMに合わせて `--batch_size` を調整してください。
|
92 |
+
|
93 |
+
VRAMが限られているシステム(約16GB未満)の場合は、T5をfp8モードで実行するために `--fp8_t5` を使用してください。
|
94 |
+
</details>
|
95 |
+
|
96 |
+
## Training / 学習
|
97 |
+
|
98 |
+
### Training
|
99 |
+
|
100 |
+
Start training using the following command (input as a single line):
|
101 |
+
|
102 |
+
```bash
|
103 |
+
accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 wan_train_network.py
|
104 |
+
--task t2v-1.3B
|
105 |
+
--dit path/to/wan2.1_xxx_bf16.safetensors
|
106 |
+
--dataset_config path/to/toml --sdpa --mixed_precision bf16 --fp8_base
|
107 |
+
--optimizer_type adamw8bit --learning_rate 2e-4 --gradient_checkpointing
|
108 |
+
--max_data_loader_n_workers 2 --persistent_data_loader_workers
|
109 |
+
--network_module networks.lora_wan --network_dim 32
|
110 |
+
--timestep_sampling shift --discrete_flow_shift 3.0
|
111 |
+
--max_train_epochs 16 --save_every_n_epochs 1 --seed 42
|
112 |
+
--output_dir path/to/output_dir --output_name name-of-lora
|
113 |
+
```
|
114 |
+
The above is an example. The appropriate values for `timestep_sampling` and `discrete_flow_shift` need to be determined by experimentation.
|
115 |
+
|
116 |
+
For additional options, use `python wan_train_network.py --help` (note that many options are unverified).
|
117 |
+
|
118 |
+
`--task` is one of `t2v-1.3B`, `t2v-14B`, `i2v-14B` and `t2i-14B`. Specify the DiT weights for the task with `--dit`.
|
119 |
+
|
120 |
+
Don't forget to specify `--network_module networks.lora_wan`.
|
121 |
+
|
122 |
+
Other options are mostly the same as `hv_train_network.py`.
|
123 |
+
|
124 |
+
Use `convert_lora.py` for converting the LoRA weights after training, as in HunyuanVideo.
|
125 |
+
|
126 |
+
<details>
|
127 |
+
<summary>日本語</summary>
|
128 |
+
`timestep_sampling`や`discrete_flow_shift`は一例です。どのような値が適切かは実験が必要です。
|
129 |
+
|
130 |
+
その他のオプションについては `python wan_train_network.py --help` を使用してください(多くのオプションは未検証です)。
|
131 |
+
|
132 |
+
`--task` には `t2v-1.3B`, `t2v-14B`, `i2v-14B`, `t2i-14B` のいずれかを指定します。`--dit`に、taskに応じたDiTの重みを指定してください。
|
133 |
+
|
134 |
+
`--network_module` に `networks.lora_wan` を指定することを忘れないでください。
|
135 |
+
|
136 |
+
その他のオプションは、ほぼ`hv_train_network.py`と同様です。
|
137 |
+
|
138 |
+
学習後のLoRAの重みの変換は、HunyuanVideoと同様に`convert_lora.py`を使用してください。
|
139 |
+
</details>
|
140 |
+
|
141 |
+
### Command line options for training with sampling / サンプル画像生成に関連する学習時のコマンドラインオプション
|
142 |
+
|
143 |
+
Example of command line options for training with sampling / 記述例:
|
144 |
+
|
145 |
+
```bash
|
146 |
+
--vae path/to/wan_2.1_vae.safetensors
|
147 |
+
--t5 path/to/models_t5_umt5-xxl-enc-bf16.pth
|
148 |
+
--sample_prompts /path/to/prompt_file.txt
|
149 |
+
--sample_every_n_epochs 1 --sample_every_n_steps 1000 -- sample_at_first
|
150 |
+
```
|
151 |
+
Each option is the same as when generating images or as HunyuanVideo. Please refer to [here](/docs/sampling_during_training.md) for details.
|
152 |
+
|
153 |
+
If you train I2V models, add `--clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` to specify the CLIP model.
|
154 |
+
|
155 |
+
You can specify the initial image and negative prompts in the prompt file. Please refer to [here](/docs/sampling_during_training.md#prompt-file--プロンプトファイル).
|
156 |
+
|
157 |
+
<details>
|
158 |
+
<summary>日本語</summary>
|
159 |
+
各オプションは推論時、およびHunyuanVideoの場合と同様です。[こちら](/docs/sampling_during_training.md)を参照してください。
|
160 |
+
|
161 |
+
I2Vモデルを学習する場合は、`--clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` を追加してCLIPモデルを指定してください。
|
162 |
+
|
163 |
+
プロンプトファイルで、初期画像やネガティブプロンプト等を指定できます。[こちら](/docs/sampling_during_training.md#prompt-file--プロンプトファイル)を参照してください。
|
164 |
+
</details>
|
165 |
+
|
166 |
+
|
167 |
+
## Inference / 推論
|
168 |
+
|
169 |
+
### T2V Inference / T2V推論
|
170 |
+
|
171 |
+
The following is an example of T2V inference (input as a single line):
|
172 |
+
|
173 |
+
```bash
|
174 |
+
python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 832 480 --video_length 81 --infer_steps 20
|
175 |
+
--prompt "prompt for the video" --save_path path/to/save.mp4 --output_type both
|
176 |
+
--dit path/to/wan2.1_t2v_1.3B_bf16_etc.safetensors --vae path/to/wan_2.1_vae.safetensors
|
177 |
+
--t5 path/to/models_t5_umt5-xxl-enc-bf16.pth
|
178 |
+
--attn_mode torch
|
179 |
+
```
|
180 |
+
|
181 |
+
`--task` is one of `t2v-1.3B`, `t2v-14B`, `i2v-14B` and `t2i-14B`.
|
182 |
+
|
183 |
+
`--attn_mode` is `torch`, `sdpa` (same as `torch`), `xformers`, `sageattn`,`flash2`, `flash` (same as `flash2`) or `flash3`. `torch` is the default. Other options require the corresponding library to be installed. `flash3` (Flash attention 3) is not tested.
|
184 |
+
|
185 |
+
`--fp8_t5` can be used to specify the T5 model in fp8 format. This option reduces memory usage for the T5 model.
|
186 |
+
|
187 |
+
`--negative_prompt` can be used to specify a negative prompt. If omitted, the default negative prompt is used.
|
188 |
+
|
189 |
+
` --flow_shift` can be used to specify the flow shift (default 3.0 for I2V with 480p, 5.0 for others).
|
190 |
+
|
191 |
+
`--guidance_scale` can be used to specify the guidance scale for classifier free guiance (default 5.0).
|
192 |
+
|
193 |
+
`--blocks_to_swap` is the number of blocks to swap during inference. The default value is None (no block swap). The maximum value is 39 for 14B model and 29 for 1.3B model.
|
194 |
+
|
195 |
+
`--vae_cache_cpu` enables VAE cache in main memory. This reduces VRAM usage slightly but processing is slower.
|
196 |
+
|
197 |
+
Other options are same as `hv_generate_video.py` (some options are not supported, please check the help).
|
198 |
+
|
199 |
+
<details>
|
200 |
+
<summary>日本語</summary>
|
201 |
+
`--task` には `t2v-1.3B`, `t2v-14B`, `i2v-14B`, `t2i-14B` のいずれかを指定します。
|
202 |
+
|
203 |
+
`--attn_mode` には `torch`, `sdpa`(`torch`と同じ)、`xformers`, `sageattn`, `flash2`, `flash`(`flash2`と同じ), `flash3` のいずれかを指定します。デフォルトは `torch` です。その他のオプションを使用する場合は、対応するライブラリをインストールする必要があります。`flash3`(Flash attention 3)は未テストです。
|
204 |
+
|
205 |
+
`--fp8_t5` を指定するとT5モデルをfp8形式で実行します。T5モデル呼び出し時のメモリ使用量を削減します。
|
206 |
+
|
207 |
+
`--negative_prompt` でネガティブプロンプトを指定できます。省略した場合はデフォルトのネガティブプロンプトが使用されます。
|
208 |
+
|
209 |
+
`--flow_shift` でflow shiftを指定できます(480pのI2Vの場合はデフォルト3.0、それ以外は5.0)。
|
210 |
+
|
211 |
+
`--guidance_scale` でclassifier free guianceのガイダンススケールを指定できます(デフォルト5.0)。
|
212 |
+
|
213 |
+
`--blocks_to_swap` は推論時のblock swapの数です。デフォルト値はNone(block swapなし)です。最大値は14Bモデルの場合39、1.3Bモデルの場合29です。
|
214 |
+
|
215 |
+
`--vae_cache_cpu` を有効にすると、VAEのキャッシュをメインメモリに保持します。VRAM使用量が多少減りますが、処理は遅くなります。
|
216 |
+
|
217 |
+
その他のオプションは `hv_generate_video.py` と同じです(一部のオプションはサポートされていないため、ヘルプを確認してください)。
|
218 |
+
</details>
|
219 |
+
|
220 |
+
### I2V Inference / I2V推論
|
221 |
+
|
222 |
+
The following is an example of I2V inference (input as a single line):
|
223 |
+
|
224 |
+
```bash
|
225 |
+
python wan_generate_video.py --fp8 --task i2v-14B --video_size 832 480 --video_length 81 --infer_steps 20
|
226 |
+
--prompt "prompt for the video" --save_path path/to/save.mp4 --output_type both
|
227 |
+
--dit path/to/wan2.1_i2v_480p_14B_bf16_etc.safetensors --vae path/to/wan_2.1_vae.safetensors
|
228 |
+
--t5 path/to/models_t5_umt5-xxl-enc-bf16.pth --clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
|
229 |
+
--attn_mode torch --image_path path/to/image.jpg
|
230 |
+
```
|
231 |
+
|
232 |
+
Add `--clip` to specify the CLIP model. `--image_path` is the path to the image to be used as the initial frame.
|
233 |
+
|
234 |
+
Other options are same as T2V inference.
|
235 |
+
|
236 |
+
<details>
|
237 |
+
<summary>日本語</summary>
|
238 |
+
`--clip` を追加してCLIPモデルを指定します。`--image_path` は初期フレームとして使用する画像のパスです。
|
239 |
+
|
240 |
+
その他のオプションはT2V推論と同じです。
|
241 |
+
</details>
|
hunyuan_model/__init__.py
ADDED
File without changes
|
hunyuan_model/activation_layers.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
|
4 |
+
def get_activation_layer(act_type):
|
5 |
+
"""get activation layer
|
6 |
+
|
7 |
+
Args:
|
8 |
+
act_type (str): the activation type
|
9 |
+
|
10 |
+
Returns:
|
11 |
+
torch.nn.functional: the activation layer
|
12 |
+
"""
|
13 |
+
if act_type == "gelu":
|
14 |
+
return lambda: nn.GELU()
|
15 |
+
elif act_type == "gelu_tanh":
|
16 |
+
# Approximate `tanh` requires torch >= 1.13
|
17 |
+
return lambda: nn.GELU(approximate="tanh")
|
18 |
+
elif act_type == "relu":
|
19 |
+
return nn.ReLU
|
20 |
+
elif act_type == "silu":
|
21 |
+
return nn.SiLU
|
22 |
+
else:
|
23 |
+
raise ValueError(f"Unknown activation type: {act_type}")
|
hunyuan_model/attention.py
ADDED
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib.metadata
|
2 |
+
import math
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
try:
|
9 |
+
import flash_attn
|
10 |
+
from flash_attn.flash_attn_interface import _flash_attn_forward
|
11 |
+
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
12 |
+
from flash_attn.flash_attn_interface import flash_attn_func
|
13 |
+
except ImportError:
|
14 |
+
flash_attn = None
|
15 |
+
flash_attn_varlen_func = None
|
16 |
+
_flash_attn_forward = None
|
17 |
+
flash_attn_func = None
|
18 |
+
|
19 |
+
try:
|
20 |
+
print(f"Trying to import sageattention")
|
21 |
+
from sageattention import sageattn_varlen, sageattn
|
22 |
+
|
23 |
+
print("Successfully imported sageattention")
|
24 |
+
except ImportError:
|
25 |
+
print(f"Failed to import sageattention")
|
26 |
+
sageattn_varlen = None
|
27 |
+
sageattn = None
|
28 |
+
|
29 |
+
try:
|
30 |
+
import xformers.ops as xops
|
31 |
+
except ImportError:
|
32 |
+
xops = None
|
33 |
+
|
34 |
+
MEMORY_LAYOUT = {
|
35 |
+
"flash": (
|
36 |
+
lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
|
37 |
+
lambda x: x,
|
38 |
+
),
|
39 |
+
"flash_fixlen": (
|
40 |
+
lambda x: x,
|
41 |
+
lambda x: x,
|
42 |
+
),
|
43 |
+
"sageattn": (
|
44 |
+
lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
|
45 |
+
lambda x: x,
|
46 |
+
),
|
47 |
+
"sageattn_fixlen": (
|
48 |
+
lambda x: x.transpose(1, 2),
|
49 |
+
lambda x: x.transpose(1, 2),
|
50 |
+
),
|
51 |
+
"torch": (
|
52 |
+
lambda x: x.transpose(1, 2),
|
53 |
+
lambda x: x.transpose(1, 2),
|
54 |
+
),
|
55 |
+
"xformers": (
|
56 |
+
lambda x: x,
|
57 |
+
lambda x: x,
|
58 |
+
),
|
59 |
+
"vanilla": (
|
60 |
+
lambda x: x.transpose(1, 2),
|
61 |
+
lambda x: x.transpose(1, 2),
|
62 |
+
),
|
63 |
+
}
|
64 |
+
|
65 |
+
|
66 |
+
def get_cu_seqlens(text_mask, img_len):
|
67 |
+
"""Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len
|
68 |
+
|
69 |
+
Args:
|
70 |
+
text_mask (torch.Tensor): the mask of text
|
71 |
+
img_len (int): the length of image
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
torch.Tensor: the calculated cu_seqlens for flash attention
|
75 |
+
"""
|
76 |
+
batch_size = text_mask.shape[0]
|
77 |
+
text_len = text_mask.sum(dim=1)
|
78 |
+
max_len = text_mask.shape[1] + img_len
|
79 |
+
|
80 |
+
cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
|
81 |
+
|
82 |
+
for i in range(batch_size):
|
83 |
+
s = text_len[i] + img_len
|
84 |
+
s1 = i * max_len + s
|
85 |
+
s2 = (i + 1) * max_len
|
86 |
+
cu_seqlens[2 * i + 1] = s1
|
87 |
+
cu_seqlens[2 * i + 2] = s2
|
88 |
+
|
89 |
+
return cu_seqlens
|
90 |
+
|
91 |
+
|
92 |
+
def attention(
|
93 |
+
q_or_qkv_list,
|
94 |
+
k=None,
|
95 |
+
v=None,
|
96 |
+
mode="flash",
|
97 |
+
drop_rate=0,
|
98 |
+
attn_mask=None,
|
99 |
+
total_len=None,
|
100 |
+
causal=False,
|
101 |
+
cu_seqlens_q=None,
|
102 |
+
cu_seqlens_kv=None,
|
103 |
+
max_seqlen_q=None,
|
104 |
+
max_seqlen_kv=None,
|
105 |
+
batch_size=1,
|
106 |
+
):
|
107 |
+
"""
|
108 |
+
Perform QKV self attention.
|
109 |
+
|
110 |
+
Args:
|
111 |
+
q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
|
112 |
+
k (torch.Tensor): Key tensor with shape [b, s1, a, d]
|
113 |
+
v (torch.Tensor): Value tensor with shape [b, s1, a, d]
|
114 |
+
mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
|
115 |
+
drop_rate (float): Dropout rate in attention map. (default: 0)
|
116 |
+
attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
|
117 |
+
(default: None)
|
118 |
+
causal (bool): Whether to use causal attention. (default: False)
|
119 |
+
cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
|
120 |
+
used to index into q.
|
121 |
+
cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
|
122 |
+
used to index into kv.
|
123 |
+
max_seqlen_q (int): The maximum sequence length in the batch of q.
|
124 |
+
max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
|
125 |
+
|
126 |
+
Returns:
|
127 |
+
torch.Tensor: Output tensor after self attention with shape [b, s, ad]
|
128 |
+
"""
|
129 |
+
q, k, v = q_or_qkv_list if type(q_or_qkv_list) == list else (q_or_qkv_list, k, v)
|
130 |
+
if type(q_or_qkv_list) == list:
|
131 |
+
q_or_qkv_list.clear()
|
132 |
+
split_attn = total_len is not None
|
133 |
+
if split_attn and mode == "sageattn":
|
134 |
+
mode = "sageattn_fixlen"
|
135 |
+
elif split_attn and mode == "flash":
|
136 |
+
mode = "flash_fixlen"
|
137 |
+
# print(f"Attention mode: {mode}, split_attn: {split_attn}")
|
138 |
+
pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
|
139 |
+
|
140 |
+
# trim the sequence length to the actual length instead of attn_mask
|
141 |
+
if split_attn:
|
142 |
+
trimmed_len = q.shape[1] - total_len
|
143 |
+
q = [q[i : i + 1, : total_len[i]] for i in range(len(q))]
|
144 |
+
k = [k[i : i + 1, : total_len[i]] for i in range(len(k))]
|
145 |
+
v = [v[i : i + 1, : total_len[i]] for i in range(len(v))]
|
146 |
+
q = [pre_attn_layout(q_i) for q_i in q]
|
147 |
+
k = [pre_attn_layout(k_i) for k_i in k]
|
148 |
+
v = [pre_attn_layout(v_i) for v_i in v]
|
149 |
+
# print(
|
150 |
+
# f"Trimming the sequence length to {total_len},trimmed_len: {trimmed_len}, q.shape: {[q_i.shape for q_i in q]}, mode: {mode}"
|
151 |
+
# )
|
152 |
+
else:
|
153 |
+
q = pre_attn_layout(q)
|
154 |
+
k = pre_attn_layout(k)
|
155 |
+
v = pre_attn_layout(v)
|
156 |
+
|
157 |
+
if mode == "torch":
|
158 |
+
if split_attn:
|
159 |
+
x = []
|
160 |
+
for i in range(len(q)):
|
161 |
+
x_i = F.scaled_dot_product_attention(q[i], k[i], v[i], dropout_p=drop_rate, is_causal=causal)
|
162 |
+
q[i], k[i], v[i] = None, None, None
|
163 |
+
x.append(x_i)
|
164 |
+
del q, k, v
|
165 |
+
else:
|
166 |
+
if attn_mask is not None and attn_mask.dtype != torch.bool:
|
167 |
+
attn_mask = attn_mask.to(q.dtype)
|
168 |
+
x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
|
169 |
+
del q, k, v
|
170 |
+
del attn_mask
|
171 |
+
|
172 |
+
elif mode == "xformers":
|
173 |
+
# B, M, H, K: M is the sequence length, H is the number of heads, K is the dimension of the heads -> it is same as input dimension
|
174 |
+
# currently only support batch_size = 1
|
175 |
+
assert split_attn, "Xformers only supports splitting"
|
176 |
+
x = []
|
177 |
+
for i in range(len(q)):
|
178 |
+
x_i = xops.memory_efficient_attention(q[i], k[i], v[i], p=drop_rate) # , causal=causal)
|
179 |
+
q[i], k[i], v[i] = None, None, None
|
180 |
+
x.append(x_i)
|
181 |
+
del q, k, v
|
182 |
+
|
183 |
+
elif mode == "flash":
|
184 |
+
x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
|
185 |
+
del q, k, v
|
186 |
+
# x with shape [(bxs), a, d]
|
187 |
+
x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
|
188 |
+
elif mode == "flash_fixlen":
|
189 |
+
x = []
|
190 |
+
for i in range(len(q)):
|
191 |
+
# q: (batch_size, seqlen, nheads, headdim), k: (batch_size, seqlen, nheads_k, headdim), v: (batch_size, seqlen, nheads_k, headdim)
|
192 |
+
x_i = flash_attn_func(q[i], k[i], v[i], dropout_p=drop_rate, causal=causal)
|
193 |
+
q[i], k[i], v[i] = None, None, None
|
194 |
+
x.append(x_i)
|
195 |
+
del q, k, v
|
196 |
+
elif mode == "sageattn":
|
197 |
+
x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
|
198 |
+
del q, k, v
|
199 |
+
# x with shape [(bxs), a, d]
|
200 |
+
x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
|
201 |
+
elif mode == "sageattn_fixlen":
|
202 |
+
x = []
|
203 |
+
for i in range(len(q)):
|
204 |
+
# HND seems to cause an error
|
205 |
+
x_i = sageattn(q[i], k[i], v[i]) # (batch_size, seq_len, head_num, head_dim)
|
206 |
+
q[i], k[i], v[i] = None, None, None
|
207 |
+
x.append(x_i)
|
208 |
+
del q, k, v
|
209 |
+
elif mode == "vanilla":
|
210 |
+
assert not split_attn, "Vanilla attention does not support trimming"
|
211 |
+
scale_factor = 1 / math.sqrt(q.size(-1))
|
212 |
+
|
213 |
+
b, a, s, _ = q.shape
|
214 |
+
s1 = k.size(2)
|
215 |
+
attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
|
216 |
+
if causal:
|
217 |
+
# Only applied to self attention
|
218 |
+
assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
|
219 |
+
temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
|
220 |
+
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
|
221 |
+
attn_bias.to(q.dtype)
|
222 |
+
|
223 |
+
if attn_mask is not None:
|
224 |
+
if attn_mask.dtype == torch.bool:
|
225 |
+
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
|
226 |
+
else:
|
227 |
+
attn_bias += attn_mask
|
228 |
+
|
229 |
+
# TODO: Maybe force q and k to be float32 to avoid numerical overflow
|
230 |
+
attn = (q @ k.transpose(-2, -1)) * scale_factor
|
231 |
+
attn += attn_bias
|
232 |
+
attn = attn.softmax(dim=-1)
|
233 |
+
attn = torch.dropout(attn, p=drop_rate, train=True)
|
234 |
+
x = attn @ v
|
235 |
+
else:
|
236 |
+
raise NotImplementedError(f"Unsupported attention mode: {mode}")
|
237 |
+
|
238 |
+
if split_attn:
|
239 |
+
x = [post_attn_layout(x_i) for x_i in x]
|
240 |
+
for i in range(len(x)):
|
241 |
+
x[i] = F.pad(x[i], (0, 0, 0, 0, 0, trimmed_len[i]))
|
242 |
+
x = torch.cat(x, dim=0)
|
243 |
+
else:
|
244 |
+
x = post_attn_layout(x)
|
245 |
+
|
246 |
+
b, s, a, d = x.shape
|
247 |
+
out = x.reshape(b, s, -1)
|
248 |
+
return out
|
249 |
+
|
250 |
+
|
251 |
+
def parallel_attention(hybrid_seq_parallel_attn, q, k, v, img_q_len, img_kv_len, cu_seqlens_q, cu_seqlens_kv):
|
252 |
+
attn1 = hybrid_seq_parallel_attn(
|
253 |
+
None,
|
254 |
+
q[:, :img_q_len, :, :],
|
255 |
+
k[:, :img_kv_len, :, :],
|
256 |
+
v[:, :img_kv_len, :, :],
|
257 |
+
dropout_p=0.0,
|
258 |
+
causal=False,
|
259 |
+
joint_tensor_query=q[:, img_q_len : cu_seqlens_q[1]],
|
260 |
+
joint_tensor_key=k[:, img_kv_len : cu_seqlens_kv[1]],
|
261 |
+
joint_tensor_value=v[:, img_kv_len : cu_seqlens_kv[1]],
|
262 |
+
joint_strategy="rear",
|
263 |
+
)
|
264 |
+
if flash_attn.__version__ >= "2.7.0":
|
265 |
+
attn2, *_ = _flash_attn_forward(
|
266 |
+
q[:, cu_seqlens_q[1] :],
|
267 |
+
k[:, cu_seqlens_kv[1] :],
|
268 |
+
v[:, cu_seqlens_kv[1] :],
|
269 |
+
dropout_p=0.0,
|
270 |
+
softmax_scale=q.shape[-1] ** (-0.5),
|
271 |
+
causal=False,
|
272 |
+
window_size_left=-1,
|
273 |
+
window_size_right=-1,
|
274 |
+
softcap=0.0,
|
275 |
+
alibi_slopes=None,
|
276 |
+
return_softmax=False,
|
277 |
+
)
|
278 |
+
else:
|
279 |
+
attn2, *_ = _flash_attn_forward(
|
280 |
+
q[:, cu_seqlens_q[1] :],
|
281 |
+
k[:, cu_seqlens_kv[1] :],
|
282 |
+
v[:, cu_seqlens_kv[1] :],
|
283 |
+
dropout_p=0.0,
|
284 |
+
softmax_scale=q.shape[-1] ** (-0.5),
|
285 |
+
causal=False,
|
286 |
+
window_size=(-1, -1),
|
287 |
+
softcap=0.0,
|
288 |
+
alibi_slopes=None,
|
289 |
+
return_softmax=False,
|
290 |
+
)
|
291 |
+
attn = torch.cat([attn1, attn2], dim=1)
|
292 |
+
b, s, a, d = attn.shape
|
293 |
+
attn = attn.reshape(b, s, -1)
|
294 |
+
|
295 |
+
return attn
|
hunyuan_model/autoencoder_kl_causal_3d.py
ADDED
@@ -0,0 +1,609 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
#
|
16 |
+
# Modified from diffusers==0.29.2
|
17 |
+
#
|
18 |
+
# ==============================================================================
|
19 |
+
from typing import Dict, Optional, Tuple, Union
|
20 |
+
from dataclasses import dataclass
|
21 |
+
|
22 |
+
import torch
|
23 |
+
import torch.nn as nn
|
24 |
+
|
25 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
26 |
+
|
27 |
+
# try:
|
28 |
+
# # This diffusers is modified and packed in the mirror.
|
29 |
+
# from diffusers.loaders import FromOriginalVAEMixin
|
30 |
+
# except ImportError:
|
31 |
+
# # Use this to be compatible with the original diffusers.
|
32 |
+
# from diffusers.loaders.single_file_model import FromOriginalModelMixin as FromOriginalVAEMixin
|
33 |
+
from diffusers.utils.accelerate_utils import apply_forward_hook
|
34 |
+
from diffusers.models.attention_processor import (
|
35 |
+
ADDED_KV_ATTENTION_PROCESSORS,
|
36 |
+
CROSS_ATTENTION_PROCESSORS,
|
37 |
+
Attention,
|
38 |
+
AttentionProcessor,
|
39 |
+
AttnAddedKVProcessor,
|
40 |
+
AttnProcessor,
|
41 |
+
)
|
42 |
+
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
43 |
+
from diffusers.models.modeling_utils import ModelMixin
|
44 |
+
from .vae import DecoderCausal3D, BaseOutput, DecoderOutput, DiagonalGaussianDistribution, EncoderCausal3D
|
45 |
+
|
46 |
+
|
47 |
+
@dataclass
|
48 |
+
class DecoderOutput2(BaseOutput):
|
49 |
+
sample: torch.FloatTensor
|
50 |
+
posterior: Optional[DiagonalGaussianDistribution] = None
|
51 |
+
|
52 |
+
|
53 |
+
class AutoencoderKLCausal3D(ModelMixin, ConfigMixin):
|
54 |
+
r"""
|
55 |
+
A VAE model with KL loss for encoding images/videos into latents and decoding latent representations into images/videos.
|
56 |
+
|
57 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
58 |
+
for all models (such as downloading or saving).
|
59 |
+
"""
|
60 |
+
|
61 |
+
_supports_gradient_checkpointing = True
|
62 |
+
|
63 |
+
@register_to_config
|
64 |
+
def __init__(
|
65 |
+
self,
|
66 |
+
in_channels: int = 3,
|
67 |
+
out_channels: int = 3,
|
68 |
+
down_block_types: Tuple[str] = ("DownEncoderBlockCausal3D",),
|
69 |
+
up_block_types: Tuple[str] = ("UpDecoderBlockCausal3D",),
|
70 |
+
block_out_channels: Tuple[int] = (64,),
|
71 |
+
layers_per_block: int = 1,
|
72 |
+
act_fn: str = "silu",
|
73 |
+
latent_channels: int = 4,
|
74 |
+
norm_num_groups: int = 32,
|
75 |
+
sample_size: int = 32,
|
76 |
+
sample_tsize: int = 64,
|
77 |
+
scaling_factor: float = 0.18215,
|
78 |
+
force_upcast: float = True,
|
79 |
+
spatial_compression_ratio: int = 8,
|
80 |
+
time_compression_ratio: int = 4,
|
81 |
+
mid_block_add_attention: bool = True,
|
82 |
+
):
|
83 |
+
super().__init__()
|
84 |
+
|
85 |
+
self.time_compression_ratio = time_compression_ratio
|
86 |
+
|
87 |
+
self.encoder = EncoderCausal3D(
|
88 |
+
in_channels=in_channels,
|
89 |
+
out_channels=latent_channels,
|
90 |
+
down_block_types=down_block_types,
|
91 |
+
block_out_channels=block_out_channels,
|
92 |
+
layers_per_block=layers_per_block,
|
93 |
+
act_fn=act_fn,
|
94 |
+
norm_num_groups=norm_num_groups,
|
95 |
+
double_z=True,
|
96 |
+
time_compression_ratio=time_compression_ratio,
|
97 |
+
spatial_compression_ratio=spatial_compression_ratio,
|
98 |
+
mid_block_add_attention=mid_block_add_attention,
|
99 |
+
)
|
100 |
+
|
101 |
+
self.decoder = DecoderCausal3D(
|
102 |
+
in_channels=latent_channels,
|
103 |
+
out_channels=out_channels,
|
104 |
+
up_block_types=up_block_types,
|
105 |
+
block_out_channels=block_out_channels,
|
106 |
+
layers_per_block=layers_per_block,
|
107 |
+
norm_num_groups=norm_num_groups,
|
108 |
+
act_fn=act_fn,
|
109 |
+
time_compression_ratio=time_compression_ratio,
|
110 |
+
spatial_compression_ratio=spatial_compression_ratio,
|
111 |
+
mid_block_add_attention=mid_block_add_attention,
|
112 |
+
)
|
113 |
+
|
114 |
+
self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
|
115 |
+
self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)
|
116 |
+
|
117 |
+
self.use_slicing = False
|
118 |
+
self.use_spatial_tiling = False
|
119 |
+
self.use_temporal_tiling = False
|
120 |
+
|
121 |
+
# only relevant if vae tiling is enabled
|
122 |
+
self.tile_sample_min_tsize = sample_tsize
|
123 |
+
self.tile_latent_min_tsize = sample_tsize // time_compression_ratio
|
124 |
+
|
125 |
+
self.tile_sample_min_size = self.config.sample_size
|
126 |
+
sample_size = self.config.sample_size[0] if isinstance(self.config.sample_size, (list, tuple)) else self.config.sample_size
|
127 |
+
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
|
128 |
+
self.tile_overlap_factor = 0.25
|
129 |
+
|
130 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
131 |
+
if isinstance(module, (EncoderCausal3D, DecoderCausal3D)):
|
132 |
+
module.gradient_checkpointing = value
|
133 |
+
|
134 |
+
def enable_temporal_tiling(self, use_tiling: bool = True):
|
135 |
+
self.use_temporal_tiling = use_tiling
|
136 |
+
|
137 |
+
def disable_temporal_tiling(self):
|
138 |
+
self.enable_temporal_tiling(False)
|
139 |
+
|
140 |
+
def enable_spatial_tiling(self, use_tiling: bool = True):
|
141 |
+
self.use_spatial_tiling = use_tiling
|
142 |
+
|
143 |
+
def disable_spatial_tiling(self):
|
144 |
+
self.enable_spatial_tiling(False)
|
145 |
+
|
146 |
+
def enable_tiling(self, use_tiling: bool = True):
|
147 |
+
r"""
|
148 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
149 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
150 |
+
processing larger videos.
|
151 |
+
"""
|
152 |
+
self.enable_spatial_tiling(use_tiling)
|
153 |
+
self.enable_temporal_tiling(use_tiling)
|
154 |
+
|
155 |
+
def disable_tiling(self):
|
156 |
+
r"""
|
157 |
+
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
158 |
+
decoding in one step.
|
159 |
+
"""
|
160 |
+
self.disable_spatial_tiling()
|
161 |
+
self.disable_temporal_tiling()
|
162 |
+
|
163 |
+
def enable_slicing(self):
|
164 |
+
r"""
|
165 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
166 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
167 |
+
"""
|
168 |
+
self.use_slicing = True
|
169 |
+
|
170 |
+
def disable_slicing(self):
|
171 |
+
r"""
|
172 |
+
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
173 |
+
decoding in one step.
|
174 |
+
"""
|
175 |
+
self.use_slicing = False
|
176 |
+
|
177 |
+
def set_chunk_size_for_causal_conv_3d(self, chunk_size: int):
|
178 |
+
# set chunk_size to CausalConv3d recursively
|
179 |
+
def set_chunk_size(module):
|
180 |
+
if hasattr(module, "chunk_size"):
|
181 |
+
module.chunk_size = chunk_size
|
182 |
+
|
183 |
+
self.apply(set_chunk_size)
|
184 |
+
|
185 |
+
@property
|
186 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
187 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
188 |
+
r"""
|
189 |
+
Returns:
|
190 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
191 |
+
indexed by its weight name.
|
192 |
+
"""
|
193 |
+
# set recursively
|
194 |
+
processors = {}
|
195 |
+
|
196 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
197 |
+
if hasattr(module, "get_processor"):
|
198 |
+
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
199 |
+
|
200 |
+
for sub_name, child in module.named_children():
|
201 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
202 |
+
|
203 |
+
return processors
|
204 |
+
|
205 |
+
for name, module in self.named_children():
|
206 |
+
fn_recursive_add_processors(name, module, processors)
|
207 |
+
|
208 |
+
return processors
|
209 |
+
|
210 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
211 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False):
|
212 |
+
r"""
|
213 |
+
Sets the attention processor to use to compute attention.
|
214 |
+
|
215 |
+
Parameters:
|
216 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
217 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
218 |
+
for **all** `Attention` layers.
|
219 |
+
|
220 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
221 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
222 |
+
|
223 |
+
"""
|
224 |
+
count = len(self.attn_processors.keys())
|
225 |
+
|
226 |
+
if isinstance(processor, dict) and len(processor) != count:
|
227 |
+
raise ValueError(
|
228 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
229 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
230 |
+
)
|
231 |
+
|
232 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
233 |
+
if hasattr(module, "set_processor"):
|
234 |
+
if not isinstance(processor, dict):
|
235 |
+
module.set_processor(processor, _remove_lora=_remove_lora)
|
236 |
+
else:
|
237 |
+
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
238 |
+
|
239 |
+
for sub_name, child in module.named_children():
|
240 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
241 |
+
|
242 |
+
for name, module in self.named_children():
|
243 |
+
fn_recursive_attn_processor(name, module, processor)
|
244 |
+
|
245 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
246 |
+
def set_default_attn_processor(self):
|
247 |
+
"""
|
248 |
+
Disables custom attention processors and sets the default attention implementation.
|
249 |
+
"""
|
250 |
+
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
251 |
+
processor = AttnAddedKVProcessor()
|
252 |
+
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
253 |
+
processor = AttnProcessor()
|
254 |
+
else:
|
255 |
+
raise ValueError(
|
256 |
+
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
257 |
+
)
|
258 |
+
|
259 |
+
self.set_attn_processor(processor, _remove_lora=True)
|
260 |
+
|
261 |
+
@apply_forward_hook
|
262 |
+
def encode(
|
263 |
+
self, x: torch.FloatTensor, return_dict: bool = True
|
264 |
+
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
265 |
+
"""
|
266 |
+
Encode a batch of images/videos into latents.
|
267 |
+
|
268 |
+
Args:
|
269 |
+
x (`torch.FloatTensor`): Input batch of images/videos.
|
270 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
271 |
+
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
272 |
+
|
273 |
+
Returns:
|
274 |
+
The latent representations of the encoded images/videos. If `return_dict` is True, a
|
275 |
+
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
276 |
+
"""
|
277 |
+
assert len(x.shape) == 5, "The input tensor should have 5 dimensions."
|
278 |
+
|
279 |
+
if self.use_temporal_tiling and x.shape[2] > self.tile_sample_min_tsize:
|
280 |
+
return self.temporal_tiled_encode(x, return_dict=return_dict)
|
281 |
+
|
282 |
+
if self.use_spatial_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
|
283 |
+
return self.spatial_tiled_encode(x, return_dict=return_dict)
|
284 |
+
|
285 |
+
if self.use_slicing and x.shape[0] > 1:
|
286 |
+
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
|
287 |
+
h = torch.cat(encoded_slices)
|
288 |
+
else:
|
289 |
+
h = self.encoder(x)
|
290 |
+
|
291 |
+
moments = self.quant_conv(h)
|
292 |
+
posterior = DiagonalGaussianDistribution(moments)
|
293 |
+
|
294 |
+
if not return_dict:
|
295 |
+
return (posterior,)
|
296 |
+
|
297 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
298 |
+
|
299 |
+
def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
300 |
+
assert len(z.shape) == 5, "The input tensor should have 5 dimensions."
|
301 |
+
|
302 |
+
if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize:
|
303 |
+
return self.temporal_tiled_decode(z, return_dict=return_dict)
|
304 |
+
|
305 |
+
if self.use_spatial_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
|
306 |
+
return self.spatial_tiled_decode(z, return_dict=return_dict)
|
307 |
+
|
308 |
+
z = self.post_quant_conv(z)
|
309 |
+
dec = self.decoder(z)
|
310 |
+
|
311 |
+
if not return_dict:
|
312 |
+
return (dec,)
|
313 |
+
|
314 |
+
return DecoderOutput(sample=dec)
|
315 |
+
|
316 |
+
@apply_forward_hook
|
317 |
+
def decode(self, z: torch.FloatTensor, return_dict: bool = True, generator=None) -> Union[DecoderOutput, torch.FloatTensor]:
|
318 |
+
"""
|
319 |
+
Decode a batch of images/videos.
|
320 |
+
|
321 |
+
Args:
|
322 |
+
z (`torch.FloatTensor`): Input batch of latent vectors.
|
323 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
324 |
+
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
325 |
+
|
326 |
+
Returns:
|
327 |
+
[`~models.vae.DecoderOutput`] or `tuple`:
|
328 |
+
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
329 |
+
returned.
|
330 |
+
|
331 |
+
"""
|
332 |
+
if self.use_slicing and z.shape[0] > 1:
|
333 |
+
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
334 |
+
decoded = torch.cat(decoded_slices)
|
335 |
+
else:
|
336 |
+
decoded = self._decode(z).sample
|
337 |
+
|
338 |
+
if not return_dict:
|
339 |
+
return (decoded,)
|
340 |
+
|
341 |
+
return DecoderOutput(sample=decoded)
|
342 |
+
|
343 |
+
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
344 |
+
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
|
345 |
+
for y in range(blend_extent):
|
346 |
+
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
|
347 |
+
return b
|
348 |
+
|
349 |
+
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
350 |
+
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
|
351 |
+
for x in range(blend_extent):
|
352 |
+
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
|
353 |
+
return b
|
354 |
+
|
355 |
+
def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
356 |
+
blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
|
357 |
+
for x in range(blend_extent):
|
358 |
+
b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (x / blend_extent)
|
359 |
+
return b
|
360 |
+
|
361 |
+
def spatial_tiled_encode(
|
362 |
+
self, x: torch.FloatTensor, return_dict: bool = True, return_moments: bool = False
|
363 |
+
) -> AutoencoderKLOutput:
|
364 |
+
r"""Encode a batch of images/videos using a tiled encoder.
|
365 |
+
|
366 |
+
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
|
367 |
+
steps. This is useful to keep memory use constant regardless of image/videos size. The end result of tiled encoding is
|
368 |
+
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
|
369 |
+
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
|
370 |
+
output, but they should be much less noticeable.
|
371 |
+
|
372 |
+
Args:
|
373 |
+
x (`torch.FloatTensor`): Input batch of images/videos.
|
374 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
375 |
+
Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
376 |
+
|
377 |
+
Returns:
|
378 |
+
[`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
|
379 |
+
If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
|
380 |
+
`tuple` is returned.
|
381 |
+
"""
|
382 |
+
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
383 |
+
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
384 |
+
row_limit = self.tile_latent_min_size - blend_extent
|
385 |
+
|
386 |
+
# Split video into tiles and encode them separately.
|
387 |
+
rows = []
|
388 |
+
for i in range(0, x.shape[-2], overlap_size):
|
389 |
+
row = []
|
390 |
+
for j in range(0, x.shape[-1], overlap_size):
|
391 |
+
tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
|
392 |
+
tile = self.encoder(tile)
|
393 |
+
tile = self.quant_conv(tile)
|
394 |
+
row.append(tile)
|
395 |
+
rows.append(row)
|
396 |
+
result_rows = []
|
397 |
+
for i, row in enumerate(rows):
|
398 |
+
result_row = []
|
399 |
+
for j, tile in enumerate(row):
|
400 |
+
# blend the above tile and the left tile
|
401 |
+
# to the current tile and add the current tile to the result row
|
402 |
+
if i > 0:
|
403 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
404 |
+
if j > 0:
|
405 |
+
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
406 |
+
result_row.append(tile[:, :, :, :row_limit, :row_limit])
|
407 |
+
result_rows.append(torch.cat(result_row, dim=-1))
|
408 |
+
|
409 |
+
moments = torch.cat(result_rows, dim=-2)
|
410 |
+
if return_moments:
|
411 |
+
return moments
|
412 |
+
|
413 |
+
posterior = DiagonalGaussianDistribution(moments)
|
414 |
+
if not return_dict:
|
415 |
+
return (posterior,)
|
416 |
+
|
417 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
418 |
+
|
419 |
+
def spatial_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
420 |
+
r"""
|
421 |
+
Decode a batch of images/videos using a tiled decoder.
|
422 |
+
|
423 |
+
Args:
|
424 |
+
z (`torch.FloatTensor`): Input batch of latent vectors.
|
425 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
426 |
+
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
427 |
+
|
428 |
+
Returns:
|
429 |
+
[`~models.vae.DecoderOutput`] or `tuple`:
|
430 |
+
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
431 |
+
returned.
|
432 |
+
"""
|
433 |
+
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
|
434 |
+
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
|
435 |
+
row_limit = self.tile_sample_min_size - blend_extent
|
436 |
+
|
437 |
+
# Split z into overlapping tiles and decode them separately.
|
438 |
+
# The tiles have an overlap to avoid seams between tiles.
|
439 |
+
rows = []
|
440 |
+
for i in range(0, z.shape[-2], overlap_size):
|
441 |
+
row = []
|
442 |
+
for j in range(0, z.shape[-1], overlap_size):
|
443 |
+
tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
|
444 |
+
tile = self.post_quant_conv(tile)
|
445 |
+
decoded = self.decoder(tile)
|
446 |
+
row.append(decoded)
|
447 |
+
rows.append(row)
|
448 |
+
result_rows = []
|
449 |
+
for i, row in enumerate(rows):
|
450 |
+
result_row = []
|
451 |
+
for j, tile in enumerate(row):
|
452 |
+
# blend the above tile and the left tile
|
453 |
+
# to the current tile and add the current tile to the result row
|
454 |
+
if i > 0:
|
455 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
456 |
+
if j > 0:
|
457 |
+
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
458 |
+
result_row.append(tile[:, :, :, :row_limit, :row_limit])
|
459 |
+
result_rows.append(torch.cat(result_row, dim=-1))
|
460 |
+
|
461 |
+
dec = torch.cat(result_rows, dim=-2)
|
462 |
+
if not return_dict:
|
463 |
+
return (dec,)
|
464 |
+
|
465 |
+
return DecoderOutput(sample=dec)
|
466 |
+
|
467 |
+
def temporal_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
|
468 |
+
|
469 |
+
B, C, T, H, W = x.shape
|
470 |
+
overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor))
|
471 |
+
blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor)
|
472 |
+
t_limit = self.tile_latent_min_tsize - blend_extent
|
473 |
+
|
474 |
+
# Split the video into tiles and encode them separately.
|
475 |
+
row = []
|
476 |
+
for i in range(0, T, overlap_size):
|
477 |
+
tile = x[:, :, i : i + self.tile_sample_min_tsize + 1, :, :]
|
478 |
+
if self.use_spatial_tiling and (
|
479 |
+
tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size
|
480 |
+
):
|
481 |
+
tile = self.spatial_tiled_encode(tile, return_moments=True)
|
482 |
+
else:
|
483 |
+
tile = self.encoder(tile)
|
484 |
+
tile = self.quant_conv(tile)
|
485 |
+
if i > 0:
|
486 |
+
tile = tile[:, :, 1:, :, :]
|
487 |
+
row.append(tile)
|
488 |
+
result_row = []
|
489 |
+
for i, tile in enumerate(row):
|
490 |
+
if i > 0:
|
491 |
+
tile = self.blend_t(row[i - 1], tile, blend_extent)
|
492 |
+
result_row.append(tile[:, :, :t_limit, :, :])
|
493 |
+
else:
|
494 |
+
result_row.append(tile[:, :, : t_limit + 1, :, :])
|
495 |
+
|
496 |
+
moments = torch.cat(result_row, dim=2)
|
497 |
+
posterior = DiagonalGaussianDistribution(moments)
|
498 |
+
|
499 |
+
if not return_dict:
|
500 |
+
return (posterior,)
|
501 |
+
|
502 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
503 |
+
|
504 |
+
def temporal_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
505 |
+
# Split z into overlapping tiles and decode them separately.
|
506 |
+
|
507 |
+
B, C, T, H, W = z.shape
|
508 |
+
overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor))
|
509 |
+
blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor)
|
510 |
+
t_limit = self.tile_sample_min_tsize - blend_extent
|
511 |
+
|
512 |
+
row = []
|
513 |
+
for i in range(0, T, overlap_size):
|
514 |
+
tile = z[:, :, i : i + self.tile_latent_min_tsize + 1, :, :]
|
515 |
+
if self.use_spatial_tiling and (
|
516 |
+
tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size
|
517 |
+
):
|
518 |
+
decoded = self.spatial_tiled_decode(tile, return_dict=True).sample
|
519 |
+
else:
|
520 |
+
tile = self.post_quant_conv(tile)
|
521 |
+
decoded = self.decoder(tile)
|
522 |
+
if i > 0:
|
523 |
+
decoded = decoded[:, :, 1:, :, :]
|
524 |
+
row.append(decoded)
|
525 |
+
result_row = []
|
526 |
+
for i, tile in enumerate(row):
|
527 |
+
if i > 0:
|
528 |
+
tile = self.blend_t(row[i - 1], tile, blend_extent)
|
529 |
+
result_row.append(tile[:, :, :t_limit, :, :])
|
530 |
+
else:
|
531 |
+
result_row.append(tile[:, :, : t_limit + 1, :, :])
|
532 |
+
|
533 |
+
dec = torch.cat(result_row, dim=2)
|
534 |
+
if not return_dict:
|
535 |
+
return (dec,)
|
536 |
+
|
537 |
+
return DecoderOutput(sample=dec)
|
538 |
+
|
539 |
+
def forward(
|
540 |
+
self,
|
541 |
+
sample: torch.FloatTensor,
|
542 |
+
sample_posterior: bool = False,
|
543 |
+
return_dict: bool = True,
|
544 |
+
return_posterior: bool = False,
|
545 |
+
generator: Optional[torch.Generator] = None,
|
546 |
+
) -> Union[DecoderOutput2, torch.FloatTensor]:
|
547 |
+
r"""
|
548 |
+
Args:
|
549 |
+
sample (`torch.FloatTensor`): Input sample.
|
550 |
+
sample_posterior (`bool`, *optional*, defaults to `False`):
|
551 |
+
Whether to sample from the posterior.
|
552 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
553 |
+
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
554 |
+
"""
|
555 |
+
x = sample
|
556 |
+
posterior = self.encode(x).latent_dist
|
557 |
+
if sample_posterior:
|
558 |
+
z = posterior.sample(generator=generator)
|
559 |
+
else:
|
560 |
+
z = posterior.mode()
|
561 |
+
dec = self.decode(z).sample
|
562 |
+
|
563 |
+
if not return_dict:
|
564 |
+
if return_posterior:
|
565 |
+
return (dec, posterior)
|
566 |
+
else:
|
567 |
+
return (dec,)
|
568 |
+
if return_posterior:
|
569 |
+
return DecoderOutput2(sample=dec, posterior=posterior)
|
570 |
+
else:
|
571 |
+
return DecoderOutput2(sample=dec)
|
572 |
+
|
573 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
|
574 |
+
def fuse_qkv_projections(self):
|
575 |
+
"""
|
576 |
+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
|
577 |
+
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
|
578 |
+
|
579 |
+
<Tip warning={true}>
|
580 |
+
|
581 |
+
This API is 🧪 experimental.
|
582 |
+
|
583 |
+
</Tip>
|
584 |
+
"""
|
585 |
+
self.original_attn_processors = None
|
586 |
+
|
587 |
+
for _, attn_processor in self.attn_processors.items():
|
588 |
+
if "Added" in str(attn_processor.__class__.__name__):
|
589 |
+
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
590 |
+
|
591 |
+
self.original_attn_processors = self.attn_processors
|
592 |
+
|
593 |
+
for module in self.modules():
|
594 |
+
if isinstance(module, Attention):
|
595 |
+
module.fuse_projections(fuse=True)
|
596 |
+
|
597 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
598 |
+
def unfuse_qkv_projections(self):
|
599 |
+
"""Disables the fused QKV projection if enabled.
|
600 |
+
|
601 |
+
<Tip warning={true}>
|
602 |
+
|
603 |
+
This API is 🧪 experimental.
|
604 |
+
|
605 |
+
</Tip>
|
606 |
+
|
607 |
+
"""
|
608 |
+
if self.original_attn_processors is not None:
|
609 |
+
self.set_attn_processor(self.original_attn_processors)
|
hunyuan_model/embed_layers.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import collections
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from einops import rearrange, repeat
|
6 |
+
|
7 |
+
from .helpers import to_2tuple
|
8 |
+
|
9 |
+
class PatchEmbed(nn.Module):
|
10 |
+
"""2D Image to Patch Embedding
|
11 |
+
|
12 |
+
Image to Patch Embedding using Conv2d
|
13 |
+
|
14 |
+
A convolution based approach to patchifying a 2D image w/ embedding projection.
|
15 |
+
|
16 |
+
Based on the impl in https://github.com/google-research/vision_transformer
|
17 |
+
|
18 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
19 |
+
|
20 |
+
Remove the _assert function in forward function to be compatible with multi-resolution images.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
patch_size=16,
|
26 |
+
in_chans=3,
|
27 |
+
embed_dim=768,
|
28 |
+
norm_layer=None,
|
29 |
+
flatten=True,
|
30 |
+
bias=True,
|
31 |
+
dtype=None,
|
32 |
+
device=None,
|
33 |
+
):
|
34 |
+
factory_kwargs = {"dtype": dtype, "device": device}
|
35 |
+
super().__init__()
|
36 |
+
patch_size = to_2tuple(patch_size)
|
37 |
+
self.patch_size = patch_size
|
38 |
+
self.flatten = flatten
|
39 |
+
|
40 |
+
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs)
|
41 |
+
nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
|
42 |
+
if bias:
|
43 |
+
nn.init.zeros_(self.proj.bias)
|
44 |
+
|
45 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
46 |
+
|
47 |
+
def forward(self, x):
|
48 |
+
x = self.proj(x)
|
49 |
+
if self.flatten:
|
50 |
+
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
51 |
+
x = self.norm(x)
|
52 |
+
return x
|
53 |
+
|
54 |
+
|
55 |
+
class TextProjection(nn.Module):
|
56 |
+
"""
|
57 |
+
Projects text embeddings. Also handles dropout for classifier-free guidance.
|
58 |
+
|
59 |
+
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
60 |
+
"""
|
61 |
+
|
62 |
+
def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
|
63 |
+
factory_kwargs = {"dtype": dtype, "device": device}
|
64 |
+
super().__init__()
|
65 |
+
self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs)
|
66 |
+
self.act_1 = act_layer()
|
67 |
+
self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs)
|
68 |
+
|
69 |
+
def forward(self, caption):
|
70 |
+
hidden_states = self.linear_1(caption)
|
71 |
+
hidden_states = self.act_1(hidden_states)
|
72 |
+
hidden_states = self.linear_2(hidden_states)
|
73 |
+
return hidden_states
|
74 |
+
|
75 |
+
|
76 |
+
def timestep_embedding(t, dim, max_period=10000):
|
77 |
+
"""
|
78 |
+
Create sinusoidal timestep embeddings.
|
79 |
+
|
80 |
+
Args:
|
81 |
+
t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
|
82 |
+
dim (int): the dimension of the output.
|
83 |
+
max_period (int): controls the minimum frequency of the embeddings.
|
84 |
+
|
85 |
+
Returns:
|
86 |
+
embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
|
87 |
+
|
88 |
+
.. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
89 |
+
"""
|
90 |
+
half = dim // 2
|
91 |
+
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
|
92 |
+
args = t[:, None].float() * freqs[None]
|
93 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
94 |
+
if dim % 2:
|
95 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
96 |
+
return embedding
|
97 |
+
|
98 |
+
|
99 |
+
class TimestepEmbedder(nn.Module):
|
100 |
+
"""
|
101 |
+
Embeds scalar timesteps into vector representations.
|
102 |
+
"""
|
103 |
+
|
104 |
+
def __init__(
|
105 |
+
self,
|
106 |
+
hidden_size,
|
107 |
+
act_layer,
|
108 |
+
frequency_embedding_size=256,
|
109 |
+
max_period=10000,
|
110 |
+
out_size=None,
|
111 |
+
dtype=None,
|
112 |
+
device=None,
|
113 |
+
):
|
114 |
+
factory_kwargs = {"dtype": dtype, "device": device}
|
115 |
+
super().__init__()
|
116 |
+
self.frequency_embedding_size = frequency_embedding_size
|
117 |
+
self.max_period = max_period
|
118 |
+
if out_size is None:
|
119 |
+
out_size = hidden_size
|
120 |
+
|
121 |
+
self.mlp = nn.Sequential(
|
122 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
|
123 |
+
act_layer(),
|
124 |
+
nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
|
125 |
+
)
|
126 |
+
nn.init.normal_(self.mlp[0].weight, std=0.02)
|
127 |
+
nn.init.normal_(self.mlp[2].weight, std=0.02)
|
128 |
+
|
129 |
+
def forward(self, t):
|
130 |
+
t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
|
131 |
+
t_emb = self.mlp(t_freq)
|
132 |
+
return t_emb
|
hunyuan_model/helpers.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import collections.abc
|
2 |
+
|
3 |
+
from itertools import repeat
|
4 |
+
|
5 |
+
|
6 |
+
def _ntuple(n):
|
7 |
+
def parse(x):
|
8 |
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
9 |
+
x = tuple(x)
|
10 |
+
if len(x) == 1:
|
11 |
+
x = tuple(repeat(x[0], n))
|
12 |
+
return x
|
13 |
+
return tuple(repeat(x, n))
|
14 |
+
return parse
|
15 |
+
|
16 |
+
|
17 |
+
to_1tuple = _ntuple(1)
|
18 |
+
to_2tuple = _ntuple(2)
|
19 |
+
to_3tuple = _ntuple(3)
|
20 |
+
to_4tuple = _ntuple(4)
|
21 |
+
|
22 |
+
|
23 |
+
def as_tuple(x):
|
24 |
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
25 |
+
return tuple(x)
|
26 |
+
if x is None or isinstance(x, (int, float, str)):
|
27 |
+
return (x,)
|
28 |
+
else:
|
29 |
+
raise ValueError(f"Unknown type {type(x)}")
|
30 |
+
|
31 |
+
|
32 |
+
def as_list_of_2tuple(x):
|
33 |
+
x = as_tuple(x)
|
34 |
+
if len(x) == 1:
|
35 |
+
x = (x[0], x[0])
|
36 |
+
assert len(x) % 2 == 0, f"Expect even length, got {len(x)}."
|
37 |
+
lst = []
|
38 |
+
for i in range(0, len(x), 2):
|
39 |
+
lst.append((x[i], x[i + 1]))
|
40 |
+
return lst
|
hunyuan_model/mlp_layers.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from timm library:
|
2 |
+
# https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
|
3 |
+
|
4 |
+
from functools import partial
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
|
9 |
+
from .modulate_layers import modulate
|
10 |
+
from .helpers import to_2tuple
|
11 |
+
|
12 |
+
|
13 |
+
class MLP(nn.Module):
|
14 |
+
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
15 |
+
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
in_channels,
|
19 |
+
hidden_channels=None,
|
20 |
+
out_features=None,
|
21 |
+
act_layer=nn.GELU,
|
22 |
+
norm_layer=None,
|
23 |
+
bias=True,
|
24 |
+
drop=0.0,
|
25 |
+
use_conv=False,
|
26 |
+
device=None,
|
27 |
+
dtype=None,
|
28 |
+
):
|
29 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
30 |
+
super().__init__()
|
31 |
+
out_features = out_features or in_channels
|
32 |
+
hidden_channels = hidden_channels or in_channels
|
33 |
+
bias = to_2tuple(bias)
|
34 |
+
drop_probs = to_2tuple(drop)
|
35 |
+
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
36 |
+
|
37 |
+
self.fc1 = linear_layer(
|
38 |
+
in_channels, hidden_channels, bias=bias[0], **factory_kwargs
|
39 |
+
)
|
40 |
+
self.act = act_layer()
|
41 |
+
self.drop1 = nn.Dropout(drop_probs[0])
|
42 |
+
self.norm = (
|
43 |
+
norm_layer(hidden_channels, **factory_kwargs)
|
44 |
+
if norm_layer is not None
|
45 |
+
else nn.Identity()
|
46 |
+
)
|
47 |
+
self.fc2 = linear_layer(
|
48 |
+
hidden_channels, out_features, bias=bias[1], **factory_kwargs
|
49 |
+
)
|
50 |
+
self.drop2 = nn.Dropout(drop_probs[1])
|
51 |
+
|
52 |
+
def forward(self, x):
|
53 |
+
x = self.fc1(x)
|
54 |
+
x = self.act(x)
|
55 |
+
x = self.drop1(x)
|
56 |
+
x = self.norm(x)
|
57 |
+
x = self.fc2(x)
|
58 |
+
x = self.drop2(x)
|
59 |
+
return x
|
60 |
+
|
61 |
+
|
62 |
+
#
|
63 |
+
class MLPEmbedder(nn.Module):
|
64 |
+
"""copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py"""
|
65 |
+
def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
|
66 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
67 |
+
super().__init__()
|
68 |
+
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
|
69 |
+
self.silu = nn.SiLU()
|
70 |
+
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)
|
71 |
+
|
72 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
73 |
+
return self.out_layer(self.silu(self.in_layer(x)))
|
74 |
+
|
75 |
+
|
76 |
+
class FinalLayer(nn.Module):
|
77 |
+
"""The final layer of DiT."""
|
78 |
+
|
79 |
+
def __init__(
|
80 |
+
self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None
|
81 |
+
):
|
82 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
83 |
+
super().__init__()
|
84 |
+
|
85 |
+
# Just use LayerNorm for the final layer
|
86 |
+
self.norm_final = nn.LayerNorm(
|
87 |
+
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
|
88 |
+
)
|
89 |
+
if isinstance(patch_size, int):
|
90 |
+
self.linear = nn.Linear(
|
91 |
+
hidden_size,
|
92 |
+
patch_size * patch_size * out_channels,
|
93 |
+
bias=True,
|
94 |
+
**factory_kwargs
|
95 |
+
)
|
96 |
+
else:
|
97 |
+
self.linear = nn.Linear(
|
98 |
+
hidden_size,
|
99 |
+
patch_size[0] * patch_size[1] * patch_size[2] * out_channels,
|
100 |
+
bias=True,
|
101 |
+
)
|
102 |
+
nn.init.zeros_(self.linear.weight)
|
103 |
+
nn.init.zeros_(self.linear.bias)
|
104 |
+
|
105 |
+
# Here we don't distinguish between the modulate types. Just use the simple one.
|
106 |
+
self.adaLN_modulation = nn.Sequential(
|
107 |
+
act_layer(),
|
108 |
+
nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
|
109 |
+
)
|
110 |
+
# Zero-initialize the modulation
|
111 |
+
nn.init.zeros_(self.adaLN_modulation[1].weight)
|
112 |
+
nn.init.zeros_(self.adaLN_modulation[1].bias)
|
113 |
+
|
114 |
+
def forward(self, x, c):
|
115 |
+
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
116 |
+
x = modulate(self.norm_final(x), shift=shift, scale=scale)
|
117 |
+
x = self.linear(x)
|
118 |
+
return x
|
hunyuan_model/models.py
ADDED
@@ -0,0 +1,1044 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import Any, List, Tuple, Optional, Union, Dict
|
3 |
+
import accelerate
|
4 |
+
from einops import rearrange
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from torch.utils.checkpoint import checkpoint
|
9 |
+
|
10 |
+
from .activation_layers import get_activation_layer
|
11 |
+
from .norm_layers import get_norm_layer
|
12 |
+
from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection
|
13 |
+
from .attention import attention, parallel_attention, get_cu_seqlens
|
14 |
+
from .posemb_layers import apply_rotary_emb
|
15 |
+
from .mlp_layers import MLP, MLPEmbedder, FinalLayer
|
16 |
+
from .modulate_layers import ModulateDiT, modulate, apply_gate
|
17 |
+
from .token_refiner import SingleTokenRefiner
|
18 |
+
from modules.custom_offloading_utils import ModelOffloader, synchronize_device, clean_memory_on_device
|
19 |
+
from hunyuan_model.posemb_layers import get_nd_rotary_pos_embed
|
20 |
+
|
21 |
+
from utils.safetensors_utils import MemoryEfficientSafeOpen
|
22 |
+
|
23 |
+
|
24 |
+
class MMDoubleStreamBlock(nn.Module):
|
25 |
+
"""
|
26 |
+
A multimodal dit block with seperate modulation for
|
27 |
+
text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206
|
28 |
+
(Flux.1): https://github.com/black-forest-labs/flux
|
29 |
+
"""
|
30 |
+
|
31 |
+
def __init__(
|
32 |
+
self,
|
33 |
+
hidden_size: int,
|
34 |
+
heads_num: int,
|
35 |
+
mlp_width_ratio: float,
|
36 |
+
mlp_act_type: str = "gelu_tanh",
|
37 |
+
qk_norm: bool = True,
|
38 |
+
qk_norm_type: str = "rms",
|
39 |
+
qkv_bias: bool = False,
|
40 |
+
dtype: Optional[torch.dtype] = None,
|
41 |
+
device: Optional[torch.device] = None,
|
42 |
+
attn_mode: str = "flash",
|
43 |
+
split_attn: bool = False,
|
44 |
+
):
|
45 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
46 |
+
super().__init__()
|
47 |
+
self.attn_mode = attn_mode
|
48 |
+
self.split_attn = split_attn
|
49 |
+
|
50 |
+
self.deterministic = False
|
51 |
+
self.heads_num = heads_num
|
52 |
+
head_dim = hidden_size // heads_num
|
53 |
+
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
|
54 |
+
|
55 |
+
self.img_mod = ModulateDiT(
|
56 |
+
hidden_size,
|
57 |
+
factor=6,
|
58 |
+
act_layer=get_activation_layer("silu"),
|
59 |
+
**factory_kwargs,
|
60 |
+
)
|
61 |
+
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
62 |
+
|
63 |
+
self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
|
64 |
+
qk_norm_layer = get_norm_layer(qk_norm_type)
|
65 |
+
self.img_attn_q_norm = (
|
66 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
67 |
+
)
|
68 |
+
self.img_attn_k_norm = (
|
69 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
70 |
+
)
|
71 |
+
self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
|
72 |
+
|
73 |
+
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
74 |
+
self.img_mlp = MLP(
|
75 |
+
hidden_size,
|
76 |
+
mlp_hidden_dim,
|
77 |
+
act_layer=get_activation_layer(mlp_act_type),
|
78 |
+
bias=True,
|
79 |
+
**factory_kwargs,
|
80 |
+
)
|
81 |
+
|
82 |
+
self.txt_mod = ModulateDiT(
|
83 |
+
hidden_size,
|
84 |
+
factor=6,
|
85 |
+
act_layer=get_activation_layer("silu"),
|
86 |
+
**factory_kwargs,
|
87 |
+
)
|
88 |
+
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
89 |
+
|
90 |
+
self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
|
91 |
+
self.txt_attn_q_norm = (
|
92 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
93 |
+
)
|
94 |
+
self.txt_attn_k_norm = (
|
95 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
96 |
+
)
|
97 |
+
self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
|
98 |
+
|
99 |
+
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
100 |
+
self.txt_mlp = MLP(
|
101 |
+
hidden_size,
|
102 |
+
mlp_hidden_dim,
|
103 |
+
act_layer=get_activation_layer(mlp_act_type),
|
104 |
+
bias=True,
|
105 |
+
**factory_kwargs,
|
106 |
+
)
|
107 |
+
self.hybrid_seq_parallel_attn = None
|
108 |
+
|
109 |
+
self.gradient_checkpointing = False
|
110 |
+
|
111 |
+
def enable_deterministic(self):
|
112 |
+
self.deterministic = True
|
113 |
+
|
114 |
+
def disable_deterministic(self):
|
115 |
+
self.deterministic = False
|
116 |
+
|
117 |
+
def enable_gradient_checkpointing(self):
|
118 |
+
self.gradient_checkpointing = True
|
119 |
+
|
120 |
+
def disable_gradient_checkpointing(self):
|
121 |
+
self.gradient_checkpointing = False
|
122 |
+
|
123 |
+
def _forward(
|
124 |
+
self,
|
125 |
+
img: torch.Tensor,
|
126 |
+
txt: torch.Tensor,
|
127 |
+
vec: torch.Tensor,
|
128 |
+
attn_mask: Optional[torch.Tensor] = None,
|
129 |
+
total_len: Optional[torch.Tensor] = None,
|
130 |
+
cu_seqlens_q: Optional[torch.Tensor] = None,
|
131 |
+
cu_seqlens_kv: Optional[torch.Tensor] = None,
|
132 |
+
max_seqlen_q: Optional[int] = None,
|
133 |
+
max_seqlen_kv: Optional[int] = None,
|
134 |
+
freqs_cis: tuple = None,
|
135 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
136 |
+
(img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate) = self.img_mod(vec).chunk(
|
137 |
+
6, dim=-1
|
138 |
+
)
|
139 |
+
(txt_mod1_shift, txt_mod1_scale, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate) = self.txt_mod(vec).chunk(
|
140 |
+
6, dim=-1
|
141 |
+
)
|
142 |
+
|
143 |
+
# Prepare image for attention.
|
144 |
+
img_modulated = self.img_norm1(img)
|
145 |
+
img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)
|
146 |
+
img_qkv = self.img_attn_qkv(img_modulated)
|
147 |
+
img_modulated = None
|
148 |
+
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
149 |
+
img_qkv = None
|
150 |
+
# Apply QK-Norm if needed
|
151 |
+
img_q = self.img_attn_q_norm(img_q).to(img_v)
|
152 |
+
img_k = self.img_attn_k_norm(img_k).to(img_v)
|
153 |
+
|
154 |
+
# Apply RoPE if needed.
|
155 |
+
if freqs_cis is not None:
|
156 |
+
img_q_shape = img_q.shape
|
157 |
+
img_k_shape = img_k.shape
|
158 |
+
img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
|
159 |
+
assert (
|
160 |
+
img_q.shape == img_q_shape and img_k.shape == img_k_shape
|
161 |
+
), f"img_kk: {img_q.shape}, img_q: {img_q_shape}, img_kk: {img_k.shape}, img_k: {img_k_shape}"
|
162 |
+
# img_q, img_k = img_qq, img_kk
|
163 |
+
|
164 |
+
# Prepare txt for attention.
|
165 |
+
txt_modulated = self.txt_norm1(txt)
|
166 |
+
txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)
|
167 |
+
txt_qkv = self.txt_attn_qkv(txt_modulated)
|
168 |
+
txt_modulated = None
|
169 |
+
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
170 |
+
txt_qkv = None
|
171 |
+
# Apply QK-Norm if needed.
|
172 |
+
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
|
173 |
+
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
|
174 |
+
|
175 |
+
# Run actual attention.
|
176 |
+
img_q_len = img_q.shape[1]
|
177 |
+
img_kv_len = img_k.shape[1]
|
178 |
+
batch_size = img_k.shape[0]
|
179 |
+
q = torch.cat((img_q, txt_q), dim=1)
|
180 |
+
img_q = txt_q = None
|
181 |
+
k = torch.cat((img_k, txt_k), dim=1)
|
182 |
+
img_k = txt_k = None
|
183 |
+
v = torch.cat((img_v, txt_v), dim=1)
|
184 |
+
img_v = txt_v = None
|
185 |
+
|
186 |
+
assert (
|
187 |
+
cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
|
188 |
+
), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}"
|
189 |
+
|
190 |
+
# attention computation start
|
191 |
+
if not self.hybrid_seq_parallel_attn:
|
192 |
+
l = [q, k, v]
|
193 |
+
q = k = v = None
|
194 |
+
attn = attention(
|
195 |
+
l,
|
196 |
+
mode=self.attn_mode,
|
197 |
+
attn_mask=attn_mask,
|
198 |
+
total_len=total_len,
|
199 |
+
cu_seqlens_q=cu_seqlens_q,
|
200 |
+
cu_seqlens_kv=cu_seqlens_kv,
|
201 |
+
max_seqlen_q=max_seqlen_q,
|
202 |
+
max_seqlen_kv=max_seqlen_kv,
|
203 |
+
batch_size=batch_size,
|
204 |
+
)
|
205 |
+
else:
|
206 |
+
attn = parallel_attention(
|
207 |
+
self.hybrid_seq_parallel_attn,
|
208 |
+
q,
|
209 |
+
k,
|
210 |
+
v,
|
211 |
+
img_q_len=img_q_len,
|
212 |
+
img_kv_len=img_kv_len,
|
213 |
+
cu_seqlens_q=cu_seqlens_q,
|
214 |
+
cu_seqlens_kv=cu_seqlens_kv,
|
215 |
+
)
|
216 |
+
|
217 |
+
# attention computation end
|
218 |
+
|
219 |
+
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
|
220 |
+
attn = None
|
221 |
+
|
222 |
+
# Calculate the img bloks.
|
223 |
+
img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
|
224 |
+
img_attn = None
|
225 |
+
img = img + apply_gate(
|
226 |
+
self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)),
|
227 |
+
gate=img_mod2_gate,
|
228 |
+
)
|
229 |
+
|
230 |
+
# Calculate the txt bloks.
|
231 |
+
txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
|
232 |
+
txt_attn = None
|
233 |
+
txt = txt + apply_gate(
|
234 |
+
self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)),
|
235 |
+
gate=txt_mod2_gate,
|
236 |
+
)
|
237 |
+
|
238 |
+
return img, txt
|
239 |
+
|
240 |
+
# def forward(
|
241 |
+
# self,
|
242 |
+
# img: torch.Tensor,
|
243 |
+
# txt: torch.Tensor,
|
244 |
+
# vec: torch.Tensor,
|
245 |
+
# attn_mask: Optional[torch.Tensor] = None,
|
246 |
+
# cu_seqlens_q: Optional[torch.Tensor] = None,
|
247 |
+
# cu_seqlens_kv: Optional[torch.Tensor] = None,
|
248 |
+
# max_seqlen_q: Optional[int] = None,
|
249 |
+
# max_seqlen_kv: Optional[int] = None,
|
250 |
+
# freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
|
251 |
+
# ) -> Tuple[torch.Tensor, torch.Tensor]:
|
252 |
+
def forward(self, *args, **kwargs):
|
253 |
+
if self.training and self.gradient_checkpointing:
|
254 |
+
return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
|
255 |
+
else:
|
256 |
+
return self._forward(*args, **kwargs)
|
257 |
+
|
258 |
+
|
259 |
+
class MMSingleStreamBlock(nn.Module):
|
260 |
+
"""
|
261 |
+
A DiT block with parallel linear layers as described in
|
262 |
+
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
263 |
+
Also refer to (SD3): https://arxiv.org/abs/2403.03206
|
264 |
+
(Flux.1): https://github.com/black-forest-labs/flux
|
265 |
+
"""
|
266 |
+
|
267 |
+
def __init__(
|
268 |
+
self,
|
269 |
+
hidden_size: int,
|
270 |
+
heads_num: int,
|
271 |
+
mlp_width_ratio: float = 4.0,
|
272 |
+
mlp_act_type: str = "gelu_tanh",
|
273 |
+
qk_norm: bool = True,
|
274 |
+
qk_norm_type: str = "rms",
|
275 |
+
qk_scale: float = None,
|
276 |
+
dtype: Optional[torch.dtype] = None,
|
277 |
+
device: Optional[torch.device] = None,
|
278 |
+
attn_mode: str = "flash",
|
279 |
+
split_attn: bool = False,
|
280 |
+
):
|
281 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
282 |
+
super().__init__()
|
283 |
+
self.attn_mode = attn_mode
|
284 |
+
self.split_attn = split_attn
|
285 |
+
|
286 |
+
self.deterministic = False
|
287 |
+
self.hidden_size = hidden_size
|
288 |
+
self.heads_num = heads_num
|
289 |
+
head_dim = hidden_size // heads_num
|
290 |
+
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
|
291 |
+
self.mlp_hidden_dim = mlp_hidden_dim
|
292 |
+
self.scale = qk_scale or head_dim**-0.5
|
293 |
+
|
294 |
+
# qkv and mlp_in
|
295 |
+
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs)
|
296 |
+
# proj and mlp_out
|
297 |
+
self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs)
|
298 |
+
|
299 |
+
qk_norm_layer = get_norm_layer(qk_norm_type)
|
300 |
+
self.q_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
301 |
+
self.k_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
302 |
+
|
303 |
+
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
304 |
+
|
305 |
+
self.mlp_act = get_activation_layer(mlp_act_type)()
|
306 |
+
self.modulation = ModulateDiT(hidden_size, factor=3, act_layer=get_activation_layer("silu"), **factory_kwargs)
|
307 |
+
self.hybrid_seq_parallel_attn = None
|
308 |
+
|
309 |
+
self.gradient_checkpointing = False
|
310 |
+
|
311 |
+
def enable_deterministic(self):
|
312 |
+
self.deterministic = True
|
313 |
+
|
314 |
+
def disable_deterministic(self):
|
315 |
+
self.deterministic = False
|
316 |
+
|
317 |
+
def enable_gradient_checkpointing(self):
|
318 |
+
self.gradient_checkpointing = True
|
319 |
+
|
320 |
+
def disable_gradient_checkpointing(self):
|
321 |
+
self.gradient_checkpointing = False
|
322 |
+
|
323 |
+
def _forward(
|
324 |
+
self,
|
325 |
+
x: torch.Tensor,
|
326 |
+
vec: torch.Tensor,
|
327 |
+
txt_len: int,
|
328 |
+
attn_mask: Optional[torch.Tensor] = None,
|
329 |
+
total_len: Optional[torch.Tensor] = None,
|
330 |
+
cu_seqlens_q: Optional[torch.Tensor] = None,
|
331 |
+
cu_seqlens_kv: Optional[torch.Tensor] = None,
|
332 |
+
max_seqlen_q: Optional[int] = None,
|
333 |
+
max_seqlen_kv: Optional[int] = None,
|
334 |
+
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
|
335 |
+
) -> torch.Tensor:
|
336 |
+
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
|
337 |
+
x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
|
338 |
+
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
339 |
+
x_mod = None
|
340 |
+
# mlp = mlp.to("cpu", non_blocking=True)
|
341 |
+
# clean_memory_on_device(x.device)
|
342 |
+
|
343 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
344 |
+
qkv = None
|
345 |
+
|
346 |
+
# Apply QK-Norm if needed.
|
347 |
+
q = self.q_norm(q).to(v)
|
348 |
+
k = self.k_norm(k).to(v)
|
349 |
+
|
350 |
+
# Apply RoPE if needed.
|
351 |
+
if freqs_cis is not None:
|
352 |
+
img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
|
353 |
+
img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
|
354 |
+
q = k = None
|
355 |
+
img_q_shape = img_q.shape
|
356 |
+
img_k_shape = img_k.shape
|
357 |
+
img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
|
358 |
+
assert (
|
359 |
+
img_q.shape == img_q_shape and img_k_shape == img_k.shape
|
360 |
+
), f"img_kk: {img_q.shape}, img_q: {img_q.shape}, img_kk: {img_k.shape}, img_k: {img_k.shape}"
|
361 |
+
# img_q, img_k = img_qq, img_kk
|
362 |
+
# del img_qq, img_kk
|
363 |
+
q = torch.cat((img_q, txt_q), dim=1)
|
364 |
+
k = torch.cat((img_k, txt_k), dim=1)
|
365 |
+
del img_q, txt_q, img_k, txt_k
|
366 |
+
|
367 |
+
# Compute attention.
|
368 |
+
assert cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1, f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}"
|
369 |
+
|
370 |
+
# attention computation start
|
371 |
+
if not self.hybrid_seq_parallel_attn:
|
372 |
+
l = [q, k, v]
|
373 |
+
q = k = v = None
|
374 |
+
attn = attention(
|
375 |
+
l,
|
376 |
+
mode=self.attn_mode,
|
377 |
+
attn_mask=attn_mask,
|
378 |
+
total_len=total_len,
|
379 |
+
cu_seqlens_q=cu_seqlens_q,
|
380 |
+
cu_seqlens_kv=cu_seqlens_kv,
|
381 |
+
max_seqlen_q=max_seqlen_q,
|
382 |
+
max_seqlen_kv=max_seqlen_kv,
|
383 |
+
batch_size=x.shape[0],
|
384 |
+
)
|
385 |
+
else:
|
386 |
+
attn = parallel_attention(
|
387 |
+
self.hybrid_seq_parallel_attn,
|
388 |
+
q,
|
389 |
+
k,
|
390 |
+
v,
|
391 |
+
img_q_len=img_q.shape[1],
|
392 |
+
img_kv_len=img_k.shape[1],
|
393 |
+
cu_seqlens_q=cu_seqlens_q,
|
394 |
+
cu_seqlens_kv=cu_seqlens_kv,
|
395 |
+
)
|
396 |
+
# attention computation end
|
397 |
+
|
398 |
+
# Compute activation in mlp stream, cat again and run second linear layer.
|
399 |
+
# mlp = mlp.to(x.device)
|
400 |
+
mlp = self.mlp_act(mlp)
|
401 |
+
attn_mlp = torch.cat((attn, mlp), 2)
|
402 |
+
attn = None
|
403 |
+
mlp = None
|
404 |
+
output = self.linear2(attn_mlp)
|
405 |
+
attn_mlp = None
|
406 |
+
return x + apply_gate(output, gate=mod_gate)
|
407 |
+
|
408 |
+
# def forward(
|
409 |
+
# self,
|
410 |
+
# x: torch.Tensor,
|
411 |
+
# vec: torch.Tensor,
|
412 |
+
# txt_len: int,
|
413 |
+
# attn_mask: Optional[torch.Tensor] = None,
|
414 |
+
# cu_seqlens_q: Optional[torch.Tensor] = None,
|
415 |
+
# cu_seqlens_kv: Optional[torch.Tensor] = None,
|
416 |
+
# max_seqlen_q: Optional[int] = None,
|
417 |
+
# max_seqlen_kv: Optional[int] = None,
|
418 |
+
# freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
|
419 |
+
# ) -> torch.Tensor:
|
420 |
+
def forward(self, *args, **kwargs):
|
421 |
+
if self.training and self.gradient_checkpointing:
|
422 |
+
return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
|
423 |
+
else:
|
424 |
+
return self._forward(*args, **kwargs)
|
425 |
+
|
426 |
+
|
427 |
+
class HYVideoDiffusionTransformer(nn.Module): # ModelMixin, ConfigMixin):
|
428 |
+
"""
|
429 |
+
HunyuanVideo Transformer backbone
|
430 |
+
|
431 |
+
Inherited from ModelMixin and ConfigMixin for compatibility with diffusers' sampler StableDiffusionPipeline.
|
432 |
+
|
433 |
+
Reference:
|
434 |
+
[1] Flux.1: https://github.com/black-forest-labs/flux
|
435 |
+
[2] MMDiT: http://arxiv.org/abs/2403.03206
|
436 |
+
|
437 |
+
Parameters
|
438 |
+
----------
|
439 |
+
args: argparse.Namespace
|
440 |
+
The arguments parsed by argparse.
|
441 |
+
patch_size: list
|
442 |
+
The size of the patch.
|
443 |
+
in_channels: int
|
444 |
+
The number of input channels.
|
445 |
+
out_channels: int
|
446 |
+
The number of output channels.
|
447 |
+
hidden_size: int
|
448 |
+
The hidden size of the transformer backbone.
|
449 |
+
heads_num: int
|
450 |
+
The number of attention heads.
|
451 |
+
mlp_width_ratio: float
|
452 |
+
The ratio of the hidden size of the MLP in the transformer block.
|
453 |
+
mlp_act_type: str
|
454 |
+
The activation function of the MLP in the transformer block.
|
455 |
+
depth_double_blocks: int
|
456 |
+
The number of transformer blocks in the double blocks.
|
457 |
+
depth_single_blocks: int
|
458 |
+
The number of transformer blocks in the single blocks.
|
459 |
+
rope_dim_list: list
|
460 |
+
The dimension of the rotary embedding for t, h, w.
|
461 |
+
qkv_bias: bool
|
462 |
+
Whether to use bias in the qkv linear layer.
|
463 |
+
qk_norm: bool
|
464 |
+
Whether to use qk norm.
|
465 |
+
qk_norm_type: str
|
466 |
+
The type of qk norm.
|
467 |
+
guidance_embed: bool
|
468 |
+
Whether to use guidance embedding for distillation.
|
469 |
+
text_projection: str
|
470 |
+
The type of the text projection, default is single_refiner.
|
471 |
+
use_attention_mask: bool
|
472 |
+
Whether to use attention mask for text encoder.
|
473 |
+
dtype: torch.dtype
|
474 |
+
The dtype of the model.
|
475 |
+
device: torch.device
|
476 |
+
The device of the model.
|
477 |
+
attn_mode: str
|
478 |
+
The mode of the attention, default is flash.
|
479 |
+
split_attn: bool
|
480 |
+
Whether to use split attention (make attention as batch size 1).
|
481 |
+
"""
|
482 |
+
|
483 |
+
# @register_to_config
|
484 |
+
def __init__(
|
485 |
+
self,
|
486 |
+
text_states_dim: int,
|
487 |
+
text_states_dim_2: int,
|
488 |
+
patch_size: list = [1, 2, 2],
|
489 |
+
in_channels: int = 4, # Should be VAE.config.latent_channels.
|
490 |
+
out_channels: int = None,
|
491 |
+
hidden_size: int = 3072,
|
492 |
+
heads_num: int = 24,
|
493 |
+
mlp_width_ratio: float = 4.0,
|
494 |
+
mlp_act_type: str = "gelu_tanh",
|
495 |
+
mm_double_blocks_depth: int = 20,
|
496 |
+
mm_single_blocks_depth: int = 40,
|
497 |
+
rope_dim_list: List[int] = [16, 56, 56],
|
498 |
+
qkv_bias: bool = True,
|
499 |
+
qk_norm: bool = True,
|
500 |
+
qk_norm_type: str = "rms",
|
501 |
+
guidance_embed: bool = False, # For modulation.
|
502 |
+
text_projection: str = "single_refiner",
|
503 |
+
use_attention_mask: bool = True,
|
504 |
+
dtype: Optional[torch.dtype] = None,
|
505 |
+
device: Optional[torch.device] = None,
|
506 |
+
attn_mode: str = "flash",
|
507 |
+
split_attn: bool = False,
|
508 |
+
):
|
509 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
510 |
+
super().__init__()
|
511 |
+
|
512 |
+
self.patch_size = patch_size
|
513 |
+
self.in_channels = in_channels
|
514 |
+
self.out_channels = in_channels if out_channels is None else out_channels
|
515 |
+
self.unpatchify_channels = self.out_channels
|
516 |
+
self.guidance_embed = guidance_embed
|
517 |
+
self.rope_dim_list = rope_dim_list
|
518 |
+
|
519 |
+
# Text projection. Default to linear projection.
|
520 |
+
# Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
|
521 |
+
self.use_attention_mask = use_attention_mask
|
522 |
+
self.text_projection = text_projection
|
523 |
+
|
524 |
+
self.text_states_dim = text_states_dim
|
525 |
+
self.text_states_dim_2 = text_states_dim_2
|
526 |
+
|
527 |
+
if hidden_size % heads_num != 0:
|
528 |
+
raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}")
|
529 |
+
pe_dim = hidden_size // heads_num
|
530 |
+
if sum(rope_dim_list) != pe_dim:
|
531 |
+
raise ValueError(f"Got {rope_dim_list} but expected positional dim {pe_dim}")
|
532 |
+
self.hidden_size = hidden_size
|
533 |
+
self.heads_num = heads_num
|
534 |
+
|
535 |
+
self.attn_mode = attn_mode
|
536 |
+
self.split_attn = split_attn
|
537 |
+
print(f"Using {self.attn_mode} attention mode, split_attn: {self.split_attn}")
|
538 |
+
|
539 |
+
# image projection
|
540 |
+
self.img_in = PatchEmbed(self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs)
|
541 |
+
|
542 |
+
# text projection
|
543 |
+
if self.text_projection == "linear":
|
544 |
+
self.txt_in = TextProjection(
|
545 |
+
self.text_states_dim,
|
546 |
+
self.hidden_size,
|
547 |
+
get_activation_layer("silu"),
|
548 |
+
**factory_kwargs,
|
549 |
+
)
|
550 |
+
elif self.text_projection == "single_refiner":
|
551 |
+
self.txt_in = SingleTokenRefiner(self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs)
|
552 |
+
else:
|
553 |
+
raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
|
554 |
+
|
555 |
+
# time modulation
|
556 |
+
self.time_in = TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)
|
557 |
+
|
558 |
+
# text modulation
|
559 |
+
self.vector_in = MLPEmbedder(self.text_states_dim_2, self.hidden_size, **factory_kwargs)
|
560 |
+
|
561 |
+
# guidance modulation
|
562 |
+
self.guidance_in = (
|
563 |
+
TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs) if guidance_embed else None
|
564 |
+
)
|
565 |
+
|
566 |
+
# double blocks
|
567 |
+
self.double_blocks = nn.ModuleList(
|
568 |
+
[
|
569 |
+
MMDoubleStreamBlock(
|
570 |
+
self.hidden_size,
|
571 |
+
self.heads_num,
|
572 |
+
mlp_width_ratio=mlp_width_ratio,
|
573 |
+
mlp_act_type=mlp_act_type,
|
574 |
+
qk_norm=qk_norm,
|
575 |
+
qk_norm_type=qk_norm_type,
|
576 |
+
qkv_bias=qkv_bias,
|
577 |
+
attn_mode=attn_mode,
|
578 |
+
split_attn=split_attn,
|
579 |
+
**factory_kwargs,
|
580 |
+
)
|
581 |
+
for _ in range(mm_double_blocks_depth)
|
582 |
+
]
|
583 |
+
)
|
584 |
+
|
585 |
+
# single blocks
|
586 |
+
self.single_blocks = nn.ModuleList(
|
587 |
+
[
|
588 |
+
MMSingleStreamBlock(
|
589 |
+
self.hidden_size,
|
590 |
+
self.heads_num,
|
591 |
+
mlp_width_ratio=mlp_width_ratio,
|
592 |
+
mlp_act_type=mlp_act_type,
|
593 |
+
qk_norm=qk_norm,
|
594 |
+
qk_norm_type=qk_norm_type,
|
595 |
+
attn_mode=attn_mode,
|
596 |
+
split_attn=split_attn,
|
597 |
+
**factory_kwargs,
|
598 |
+
)
|
599 |
+
for _ in range(mm_single_blocks_depth)
|
600 |
+
]
|
601 |
+
)
|
602 |
+
|
603 |
+
self.final_layer = FinalLayer(
|
604 |
+
self.hidden_size,
|
605 |
+
self.patch_size,
|
606 |
+
self.out_channels,
|
607 |
+
get_activation_layer("silu"),
|
608 |
+
**factory_kwargs,
|
609 |
+
)
|
610 |
+
|
611 |
+
self.gradient_checkpointing = False
|
612 |
+
self.blocks_to_swap = None
|
613 |
+
self.offloader_double = None
|
614 |
+
self.offloader_single = None
|
615 |
+
self._enable_img_in_txt_in_offloading = False
|
616 |
+
|
617 |
+
@property
|
618 |
+
def device(self):
|
619 |
+
return next(self.parameters()).device
|
620 |
+
|
621 |
+
@property
|
622 |
+
def dtype(self):
|
623 |
+
return next(self.parameters()).dtype
|
624 |
+
|
625 |
+
def enable_gradient_checkpointing(self):
|
626 |
+
self.gradient_checkpointing = True
|
627 |
+
|
628 |
+
self.txt_in.enable_gradient_checkpointing()
|
629 |
+
|
630 |
+
for block in self.double_blocks + self.single_blocks:
|
631 |
+
block.enable_gradient_checkpointing()
|
632 |
+
|
633 |
+
print(f"HYVideoDiffusionTransformer: Gradient checkpointing enabled.")
|
634 |
+
|
635 |
+
def disable_gradient_checkpointing(self):
|
636 |
+
self.gradient_checkpointing = False
|
637 |
+
|
638 |
+
self.txt_in.disable_gradient_checkpointing()
|
639 |
+
|
640 |
+
for block in self.double_blocks + self.single_blocks:
|
641 |
+
block.disable_gradient_checkpointing()
|
642 |
+
|
643 |
+
print(f"HYVideoDiffusionTransformer: Gradient checkpointing disabled.")
|
644 |
+
|
645 |
+
def enable_img_in_txt_in_offloading(self):
|
646 |
+
self._enable_img_in_txt_in_offloading = True
|
647 |
+
|
648 |
+
def enable_block_swap(self, num_blocks: int, device: torch.device, supports_backward: bool):
|
649 |
+
self.blocks_to_swap = num_blocks
|
650 |
+
self.num_double_blocks = len(self.double_blocks)
|
651 |
+
self.num_single_blocks = len(self.single_blocks)
|
652 |
+
double_blocks_to_swap = num_blocks // 2
|
653 |
+
single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + 1
|
654 |
+
|
655 |
+
assert double_blocks_to_swap <= self.num_double_blocks - 1 and single_blocks_to_swap <= self.num_single_blocks - 1, (
|
656 |
+
f"Cannot swap more than {self.num_double_blocks - 1} double blocks and {self.num_single_blocks - 1} single blocks. "
|
657 |
+
f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks."
|
658 |
+
)
|
659 |
+
|
660 |
+
self.offloader_double = ModelOffloader(
|
661 |
+
"double", self.double_blocks, self.num_double_blocks, double_blocks_to_swap, supports_backward, device # , debug=True
|
662 |
+
)
|
663 |
+
self.offloader_single = ModelOffloader(
|
664 |
+
"single", self.single_blocks, self.num_single_blocks, single_blocks_to_swap, supports_backward, device # , debug=True
|
665 |
+
)
|
666 |
+
print(
|
667 |
+
f"HYVideoDiffusionTransformer: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
|
668 |
+
)
|
669 |
+
|
670 |
+
def switch_block_swap_for_inference(self):
|
671 |
+
if self.blocks_to_swap:
|
672 |
+
self.offloader_double.set_forward_only(True)
|
673 |
+
self.offloader_single.set_forward_only(True)
|
674 |
+
self.prepare_block_swap_before_forward()
|
675 |
+
print(f"HYVideoDiffusionTransformer: Block swap set to forward only.")
|
676 |
+
|
677 |
+
def switch_block_swap_for_training(self):
|
678 |
+
if self.blocks_to_swap:
|
679 |
+
self.offloader_double.set_forward_only(False)
|
680 |
+
self.offloader_single.set_forward_only(False)
|
681 |
+
self.prepare_block_swap_before_forward()
|
682 |
+
print(f"HYVideoDiffusionTransformer: Block swap set to forward and backward.")
|
683 |
+
|
684 |
+
def move_to_device_except_swap_blocks(self, device: torch.device):
|
685 |
+
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
|
686 |
+
if self.blocks_to_swap:
|
687 |
+
save_double_blocks = self.double_blocks
|
688 |
+
save_single_blocks = self.single_blocks
|
689 |
+
self.double_blocks = None
|
690 |
+
self.single_blocks = None
|
691 |
+
|
692 |
+
self.to(device)
|
693 |
+
|
694 |
+
if self.blocks_to_swap:
|
695 |
+
self.double_blocks = save_double_blocks
|
696 |
+
self.single_blocks = save_single_blocks
|
697 |
+
|
698 |
+
def prepare_block_swap_before_forward(self):
|
699 |
+
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
700 |
+
return
|
701 |
+
self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
|
702 |
+
self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
|
703 |
+
|
704 |
+
def enable_deterministic(self):
|
705 |
+
for block in self.double_blocks:
|
706 |
+
block.enable_deterministic()
|
707 |
+
for block in self.single_blocks:
|
708 |
+
block.enable_deterministic()
|
709 |
+
|
710 |
+
def disable_deterministic(self):
|
711 |
+
for block in self.double_blocks:
|
712 |
+
block.disable_deterministic()
|
713 |
+
for block in self.single_blocks:
|
714 |
+
block.disable_deterministic()
|
715 |
+
|
716 |
+
def forward(
|
717 |
+
self,
|
718 |
+
x: torch.Tensor,
|
719 |
+
t: torch.Tensor, # Should be in range(0, 1000).
|
720 |
+
text_states: torch.Tensor = None,
|
721 |
+
text_mask: torch.Tensor = None, # Now we don't use it.
|
722 |
+
text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
|
723 |
+
freqs_cos: Optional[torch.Tensor] = None,
|
724 |
+
freqs_sin: Optional[torch.Tensor] = None,
|
725 |
+
guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
|
726 |
+
return_dict: bool = True,
|
727 |
+
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
728 |
+
out = {}
|
729 |
+
img = x
|
730 |
+
txt = text_states
|
731 |
+
_, _, ot, oh, ow = x.shape
|
732 |
+
tt, th, tw = (
|
733 |
+
ot // self.patch_size[0],
|
734 |
+
oh // self.patch_size[1],
|
735 |
+
ow // self.patch_size[2],
|
736 |
+
)
|
737 |
+
|
738 |
+
# Prepare modulation vectors.
|
739 |
+
vec = self.time_in(t)
|
740 |
+
|
741 |
+
# text modulation
|
742 |
+
vec = vec + self.vector_in(text_states_2)
|
743 |
+
|
744 |
+
# guidance modulation
|
745 |
+
if self.guidance_embed:
|
746 |
+
if guidance is None:
|
747 |
+
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
748 |
+
|
749 |
+
# our timestep_embedding is merged into guidance_in(TimestepEmbedder)
|
750 |
+
vec = vec + self.guidance_in(guidance)
|
751 |
+
|
752 |
+
# Embed image and text.
|
753 |
+
if self._enable_img_in_txt_in_offloading:
|
754 |
+
self.img_in.to(x.device, non_blocking=True)
|
755 |
+
self.txt_in.to(x.device, non_blocking=True)
|
756 |
+
synchronize_device(x.device)
|
757 |
+
|
758 |
+
img = self.img_in(img)
|
759 |
+
if self.text_projection == "linear":
|
760 |
+
txt = self.txt_in(txt)
|
761 |
+
elif self.text_projection == "single_refiner":
|
762 |
+
txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
|
763 |
+
else:
|
764 |
+
raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
|
765 |
+
|
766 |
+
if self._enable_img_in_txt_in_offloading:
|
767 |
+
self.img_in.to(torch.device("cpu"), non_blocking=True)
|
768 |
+
self.txt_in.to(torch.device("cpu"), non_blocking=True)
|
769 |
+
synchronize_device(x.device)
|
770 |
+
clean_memory_on_device(x.device)
|
771 |
+
|
772 |
+
txt_seq_len = txt.shape[1]
|
773 |
+
img_seq_len = img.shape[1]
|
774 |
+
|
775 |
+
# Compute cu_squlens and max_seqlen for flash attention
|
776 |
+
cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
|
777 |
+
cu_seqlens_kv = cu_seqlens_q
|
778 |
+
max_seqlen_q = img_seq_len + txt_seq_len
|
779 |
+
max_seqlen_kv = max_seqlen_q
|
780 |
+
|
781 |
+
attn_mask = total_len = None
|
782 |
+
if self.split_attn or self.attn_mode == "torch":
|
783 |
+
# calculate text length and total length
|
784 |
+
text_len = text_mask.sum(dim=1) # (bs, )
|
785 |
+
total_len = img_seq_len + text_len # (bs, )
|
786 |
+
if self.attn_mode == "torch" and not self.split_attn:
|
787 |
+
# initialize attention mask: bool tensor for sdpa, (b, 1, n, n)
|
788 |
+
bs = img.shape[0]
|
789 |
+
attn_mask = torch.zeros((bs, 1, max_seqlen_q, max_seqlen_q), dtype=torch.bool, device=text_mask.device)
|
790 |
+
|
791 |
+
# set attention mask with total_len
|
792 |
+
for i in range(bs):
|
793 |
+
attn_mask[i, :, : total_len[i], : total_len[i]] = True
|
794 |
+
total_len = None # means we don't use split_attn
|
795 |
+
|
796 |
+
freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
|
797 |
+
# --------------------- Pass through DiT blocks ------------------------
|
798 |
+
for block_idx, block in enumerate(self.double_blocks):
|
799 |
+
double_block_args = [
|
800 |
+
img,
|
801 |
+
txt,
|
802 |
+
vec,
|
803 |
+
attn_mask,
|
804 |
+
total_len,
|
805 |
+
cu_seqlens_q,
|
806 |
+
cu_seqlens_kv,
|
807 |
+
max_seqlen_q,
|
808 |
+
max_seqlen_kv,
|
809 |
+
freqs_cis,
|
810 |
+
]
|
811 |
+
|
812 |
+
if self.blocks_to_swap:
|
813 |
+
self.offloader_double.wait_for_block(block_idx)
|
814 |
+
|
815 |
+
img, txt = block(*double_block_args)
|
816 |
+
|
817 |
+
if self.blocks_to_swap:
|
818 |
+
self.offloader_double.submit_move_blocks_forward(self.double_blocks, block_idx)
|
819 |
+
|
820 |
+
# Merge txt and img to pass through single stream blocks.
|
821 |
+
x = torch.cat((img, txt), 1)
|
822 |
+
if self.blocks_to_swap:
|
823 |
+
# delete img, txt to reduce memory usage
|
824 |
+
del img, txt
|
825 |
+
clean_memory_on_device(x.device)
|
826 |
+
|
827 |
+
if len(self.single_blocks) > 0:
|
828 |
+
for block_idx, block in enumerate(self.single_blocks):
|
829 |
+
single_block_args = [
|
830 |
+
x,
|
831 |
+
vec,
|
832 |
+
txt_seq_len,
|
833 |
+
attn_mask,
|
834 |
+
total_len,
|
835 |
+
cu_seqlens_q,
|
836 |
+
cu_seqlens_kv,
|
837 |
+
max_seqlen_q,
|
838 |
+
max_seqlen_kv,
|
839 |
+
freqs_cis,
|
840 |
+
]
|
841 |
+
if self.blocks_to_swap:
|
842 |
+
self.offloader_single.wait_for_block(block_idx)
|
843 |
+
|
844 |
+
x = block(*single_block_args)
|
845 |
+
|
846 |
+
if self.blocks_to_swap:
|
847 |
+
self.offloader_single.submit_move_blocks_forward(self.single_blocks, block_idx)
|
848 |
+
|
849 |
+
img = x[:, :img_seq_len, ...]
|
850 |
+
x = None
|
851 |
+
|
852 |
+
# ---------------------------- Final layer ------------------------------
|
853 |
+
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
854 |
+
|
855 |
+
img = self.unpatchify(img, tt, th, tw)
|
856 |
+
if return_dict:
|
857 |
+
out["x"] = img
|
858 |
+
return out
|
859 |
+
return img
|
860 |
+
|
861 |
+
def unpatchify(self, x, t, h, w):
|
862 |
+
"""
|
863 |
+
x: (N, T, patch_size**2 * C)
|
864 |
+
imgs: (N, H, W, C)
|
865 |
+
"""
|
866 |
+
c = self.unpatchify_channels
|
867 |
+
pt, ph, pw = self.patch_size
|
868 |
+
assert t * h * w == x.shape[1]
|
869 |
+
|
870 |
+
x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
|
871 |
+
x = torch.einsum("nthwcopq->nctohpwq", x)
|
872 |
+
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
|
873 |
+
|
874 |
+
return imgs
|
875 |
+
|
876 |
+
def params_count(self):
|
877 |
+
counts = {
|
878 |
+
"double": sum(
|
879 |
+
[
|
880 |
+
sum(p.numel() for p in block.img_attn_qkv.parameters())
|
881 |
+
+ sum(p.numel() for p in block.img_attn_proj.parameters())
|
882 |
+
+ sum(p.numel() for p in block.img_mlp.parameters())
|
883 |
+
+ sum(p.numel() for p in block.txt_attn_qkv.parameters())
|
884 |
+
+ sum(p.numel() for p in block.txt_attn_proj.parameters())
|
885 |
+
+ sum(p.numel() for p in block.txt_mlp.parameters())
|
886 |
+
for block in self.double_blocks
|
887 |
+
]
|
888 |
+
),
|
889 |
+
"single": sum(
|
890 |
+
[
|
891 |
+
sum(p.numel() for p in block.linear1.parameters()) + sum(p.numel() for p in block.linear2.parameters())
|
892 |
+
for block in self.single_blocks
|
893 |
+
]
|
894 |
+
),
|
895 |
+
"total": sum(p.numel() for p in self.parameters()),
|
896 |
+
}
|
897 |
+
counts["attn+mlp"] = counts["double"] + counts["single"]
|
898 |
+
return counts
|
899 |
+
|
900 |
+
|
901 |
+
#################################################################################
|
902 |
+
# HunyuanVideo Configs #
|
903 |
+
#################################################################################
|
904 |
+
|
905 |
+
HUNYUAN_VIDEO_CONFIG = {
|
906 |
+
"HYVideo-T/2": {
|
907 |
+
"mm_double_blocks_depth": 20,
|
908 |
+
"mm_single_blocks_depth": 40,
|
909 |
+
"rope_dim_list": [16, 56, 56],
|
910 |
+
"hidden_size": 3072,
|
911 |
+
"heads_num": 24,
|
912 |
+
"mlp_width_ratio": 4,
|
913 |
+
},
|
914 |
+
"HYVideo-T/2-cfgdistill": {
|
915 |
+
"mm_double_blocks_depth": 20,
|
916 |
+
"mm_single_blocks_depth": 40,
|
917 |
+
"rope_dim_list": [16, 56, 56],
|
918 |
+
"hidden_size": 3072,
|
919 |
+
"heads_num": 24,
|
920 |
+
"mlp_width_ratio": 4,
|
921 |
+
"guidance_embed": True,
|
922 |
+
},
|
923 |
+
}
|
924 |
+
|
925 |
+
|
926 |
+
def load_dit_model(text_states_dim, text_states_dim_2, in_channels, out_channels, factor_kwargs):
|
927 |
+
"""load hunyuan video model
|
928 |
+
|
929 |
+
NOTE: Only support HYVideo-T/2-cfgdistill now.
|
930 |
+
|
931 |
+
Args:
|
932 |
+
text_state_dim (int): text state dimension
|
933 |
+
text_state_dim_2 (int): text state dimension 2
|
934 |
+
in_channels (int): input channels number
|
935 |
+
out_channels (int): output channels number
|
936 |
+
factor_kwargs (dict): factor kwargs
|
937 |
+
|
938 |
+
Returns:
|
939 |
+
model (nn.Module): The hunyuan video model
|
940 |
+
"""
|
941 |
+
# if args.model in HUNYUAN_VIDEO_CONFIG.keys():
|
942 |
+
model = HYVideoDiffusionTransformer(
|
943 |
+
text_states_dim=text_states_dim,
|
944 |
+
text_states_dim_2=text_states_dim_2,
|
945 |
+
in_channels=in_channels,
|
946 |
+
out_channels=out_channels,
|
947 |
+
**HUNYUAN_VIDEO_CONFIG["HYVideo-T/2-cfgdistill"],
|
948 |
+
**factor_kwargs,
|
949 |
+
)
|
950 |
+
return model
|
951 |
+
# else:
|
952 |
+
# raise NotImplementedError()
|
953 |
+
|
954 |
+
|
955 |
+
def load_state_dict(model, model_path):
|
956 |
+
state_dict = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=True)
|
957 |
+
|
958 |
+
load_key = "module"
|
959 |
+
if load_key in state_dict:
|
960 |
+
state_dict = state_dict[load_key]
|
961 |
+
else:
|
962 |
+
raise KeyError(
|
963 |
+
f"Missing key: `{load_key}` in the checkpoint: {model_path}. The keys in the checkpoint "
|
964 |
+
f"are: {list(state_dict.keys())}."
|
965 |
+
)
|
966 |
+
model.load_state_dict(state_dict, strict=True, assign=True)
|
967 |
+
return model
|
968 |
+
|
969 |
+
|
970 |
+
def load_transformer(dit_path, attn_mode, split_attn, device, dtype, in_channels=16) -> HYVideoDiffusionTransformer:
|
971 |
+
# =========================== Build main model ===========================
|
972 |
+
factor_kwargs = {"device": device, "dtype": dtype, "attn_mode": attn_mode, "split_attn": split_attn}
|
973 |
+
latent_channels = 16
|
974 |
+
out_channels = latent_channels
|
975 |
+
|
976 |
+
with accelerate.init_empty_weights():
|
977 |
+
transformer = load_dit_model(
|
978 |
+
text_states_dim=4096,
|
979 |
+
text_states_dim_2=768,
|
980 |
+
in_channels=in_channels,
|
981 |
+
out_channels=out_channels,
|
982 |
+
factor_kwargs=factor_kwargs,
|
983 |
+
)
|
984 |
+
|
985 |
+
if os.path.splitext(dit_path)[-1] == ".safetensors":
|
986 |
+
# loading safetensors: may be already fp8
|
987 |
+
with MemoryEfficientSafeOpen(dit_path) as f:
|
988 |
+
state_dict = {}
|
989 |
+
for k in f.keys():
|
990 |
+
tensor = f.get_tensor(k)
|
991 |
+
tensor = tensor.to(device=device, dtype=dtype)
|
992 |
+
# TODO support comfy model
|
993 |
+
# if k.startswith("model.model."):
|
994 |
+
# k = convert_comfy_model_key(k)
|
995 |
+
state_dict[k] = tensor
|
996 |
+
transformer.load_state_dict(state_dict, strict=True, assign=True)
|
997 |
+
else:
|
998 |
+
transformer = load_state_dict(transformer, dit_path)
|
999 |
+
|
1000 |
+
return transformer
|
1001 |
+
|
1002 |
+
|
1003 |
+
def get_rotary_pos_embed_by_shape(model, latents_size):
|
1004 |
+
target_ndim = 3
|
1005 |
+
ndim = 5 - 2
|
1006 |
+
|
1007 |
+
if isinstance(model.patch_size, int):
|
1008 |
+
assert all(s % model.patch_size == 0 for s in latents_size), (
|
1009 |
+
f"Latent size(last {ndim} dimensions) should be divisible by patch size({model.patch_size}), "
|
1010 |
+
f"but got {latents_size}."
|
1011 |
+
)
|
1012 |
+
rope_sizes = [s // model.patch_size for s in latents_size]
|
1013 |
+
elif isinstance(model.patch_size, list):
|
1014 |
+
assert all(s % model.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), (
|
1015 |
+
f"Latent size(last {ndim} dimensions) should be divisible by patch size({model.patch_size}), "
|
1016 |
+
f"but got {latents_size}."
|
1017 |
+
)
|
1018 |
+
rope_sizes = [s // model.patch_size[idx] for idx, s in enumerate(latents_size)]
|
1019 |
+
|
1020 |
+
if len(rope_sizes) != target_ndim:
|
1021 |
+
rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
|
1022 |
+
head_dim = model.hidden_size // model.heads_num
|
1023 |
+
rope_dim_list = model.rope_dim_list
|
1024 |
+
if rope_dim_list is None:
|
1025 |
+
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
|
1026 |
+
assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
|
1027 |
+
|
1028 |
+
rope_theta = 256
|
1029 |
+
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
|
1030 |
+
rope_dim_list, rope_sizes, theta=rope_theta, use_real=True, theta_rescale_factor=1
|
1031 |
+
)
|
1032 |
+
return freqs_cos, freqs_sin
|
1033 |
+
|
1034 |
+
|
1035 |
+
def get_rotary_pos_embed(vae_name, model, video_length, height, width):
|
1036 |
+
# 884
|
1037 |
+
if "884" in vae_name:
|
1038 |
+
latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
|
1039 |
+
elif "888" in vae_name:
|
1040 |
+
latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
|
1041 |
+
else:
|
1042 |
+
latents_size = [video_length, height // 8, width // 8]
|
1043 |
+
|
1044 |
+
return get_rotary_pos_embed_by_shape(model, latents_size)
|
hunyuan_model/modulate_layers.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
|
7 |
+
class ModulateDiT(nn.Module):
|
8 |
+
"""Modulation layer for DiT."""
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
hidden_size: int,
|
12 |
+
factor: int,
|
13 |
+
act_layer: Callable,
|
14 |
+
dtype=None,
|
15 |
+
device=None,
|
16 |
+
):
|
17 |
+
factory_kwargs = {"dtype": dtype, "device": device}
|
18 |
+
super().__init__()
|
19 |
+
self.act = act_layer()
|
20 |
+
self.linear = nn.Linear(
|
21 |
+
hidden_size, factor * hidden_size, bias=True, **factory_kwargs
|
22 |
+
)
|
23 |
+
# Zero-initialize the modulation
|
24 |
+
nn.init.zeros_(self.linear.weight)
|
25 |
+
nn.init.zeros_(self.linear.bias)
|
26 |
+
|
27 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
28 |
+
return self.linear(self.act(x))
|
29 |
+
|
30 |
+
|
31 |
+
def modulate(x, shift=None, scale=None):
|
32 |
+
"""modulate by shift and scale
|
33 |
+
|
34 |
+
Args:
|
35 |
+
x (torch.Tensor): input tensor.
|
36 |
+
shift (torch.Tensor, optional): shift tensor. Defaults to None.
|
37 |
+
scale (torch.Tensor, optional): scale tensor. Defaults to None.
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
torch.Tensor: the output tensor after modulate.
|
41 |
+
"""
|
42 |
+
if scale is None and shift is None:
|
43 |
+
return x
|
44 |
+
elif shift is None:
|
45 |
+
return x * (1 + scale.unsqueeze(1))
|
46 |
+
elif scale is None:
|
47 |
+
return x + shift.unsqueeze(1)
|
48 |
+
else:
|
49 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
50 |
+
|
51 |
+
|
52 |
+
def apply_gate(x, gate=None, tanh=False):
|
53 |
+
"""AI is creating summary for apply_gate
|
54 |
+
|
55 |
+
Args:
|
56 |
+
x (torch.Tensor): input tensor.
|
57 |
+
gate (torch.Tensor, optional): gate tensor. Defaults to None.
|
58 |
+
tanh (bool, optional): whether to use tanh function. Defaults to False.
|
59 |
+
|
60 |
+
Returns:
|
61 |
+
torch.Tensor: the output tensor after apply gate.
|
62 |
+
"""
|
63 |
+
if gate is None:
|
64 |
+
return x
|
65 |
+
if tanh:
|
66 |
+
return x * gate.unsqueeze(1).tanh()
|
67 |
+
else:
|
68 |
+
return x * gate.unsqueeze(1)
|
69 |
+
|
70 |
+
|
71 |
+
def ckpt_wrapper(module):
|
72 |
+
def ckpt_forward(*inputs):
|
73 |
+
outputs = module(*inputs)
|
74 |
+
return outputs
|
75 |
+
|
76 |
+
return ckpt_forward
|
hunyuan_model/norm_layers.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
class RMSNorm(nn.Module):
|
6 |
+
def __init__(
|
7 |
+
self,
|
8 |
+
dim: int,
|
9 |
+
elementwise_affine=True,
|
10 |
+
eps: float = 1e-6,
|
11 |
+
device=None,
|
12 |
+
dtype=None,
|
13 |
+
):
|
14 |
+
"""
|
15 |
+
Initialize the RMSNorm normalization layer.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
dim (int): The dimension of the input tensor.
|
19 |
+
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
|
20 |
+
|
21 |
+
Attributes:
|
22 |
+
eps (float): A small value added to the denominator for numerical stability.
|
23 |
+
weight (nn.Parameter): Learnable scaling parameter.
|
24 |
+
|
25 |
+
"""
|
26 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
27 |
+
super().__init__()
|
28 |
+
self.eps = eps
|
29 |
+
if elementwise_affine:
|
30 |
+
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
|
31 |
+
|
32 |
+
def _norm(self, x):
|
33 |
+
"""
|
34 |
+
Apply the RMSNorm normalization to the input tensor.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
x (torch.Tensor): The input tensor.
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
torch.Tensor: The normalized tensor.
|
41 |
+
|
42 |
+
"""
|
43 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
44 |
+
|
45 |
+
def forward(self, x):
|
46 |
+
"""
|
47 |
+
Forward pass through the RMSNorm layer.
|
48 |
+
|
49 |
+
Args:
|
50 |
+
x (torch.Tensor): The input tensor.
|
51 |
+
|
52 |
+
Returns:
|
53 |
+
torch.Tensor: The output tensor after applying RMSNorm.
|
54 |
+
|
55 |
+
"""
|
56 |
+
output = self._norm(x.float()).type_as(x)
|
57 |
+
if hasattr(self, "weight"):
|
58 |
+
# output = output * self.weight
|
59 |
+
# support fp8
|
60 |
+
output = output * self.weight.to(output.dtype)
|
61 |
+
return output
|
62 |
+
|
63 |
+
|
64 |
+
def get_norm_layer(norm_layer):
|
65 |
+
"""
|
66 |
+
Get the normalization layer.
|
67 |
+
|
68 |
+
Args:
|
69 |
+
norm_layer (str): The type of normalization layer.
|
70 |
+
|
71 |
+
Returns:
|
72 |
+
norm_layer (nn.Module): The normalization layer.
|
73 |
+
"""
|
74 |
+
if norm_layer == "layer":
|
75 |
+
return nn.LayerNorm
|
76 |
+
elif norm_layer == "rms":
|
77 |
+
return RMSNorm
|
78 |
+
else:
|
79 |
+
raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
|
hunyuan_model/pipeline_hunyuan_video.py
ADDED
@@ -0,0 +1,1100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
#
|
16 |
+
# Modified from diffusers==0.29.2
|
17 |
+
#
|
18 |
+
# ==============================================================================
|
19 |
+
import inspect
|
20 |
+
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
|
21 |
+
import torch
|
22 |
+
import torch.distributed as dist
|
23 |
+
import numpy as np
|
24 |
+
from dataclasses import dataclass
|
25 |
+
from packaging import version
|
26 |
+
|
27 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
28 |
+
from diffusers.configuration_utils import FrozenDict
|
29 |
+
from diffusers.image_processor import VaeImageProcessor
|
30 |
+
from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
31 |
+
from diffusers.models import AutoencoderKL
|
32 |
+
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
33 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
34 |
+
from diffusers.utils import (
|
35 |
+
USE_PEFT_BACKEND,
|
36 |
+
deprecate,
|
37 |
+
logging,
|
38 |
+
replace_example_docstring,
|
39 |
+
scale_lora_layers,
|
40 |
+
unscale_lora_layers,
|
41 |
+
)
|
42 |
+
from diffusers.utils.torch_utils import randn_tensor
|
43 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
44 |
+
from diffusers.utils import BaseOutput
|
45 |
+
|
46 |
+
from ...constants import PRECISION_TO_TYPE
|
47 |
+
from ...vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
|
48 |
+
from ...text_encoder import TextEncoder
|
49 |
+
from ...modules import HYVideoDiffusionTransformer
|
50 |
+
|
51 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
52 |
+
|
53 |
+
EXAMPLE_DOC_STRING = """"""
|
54 |
+
|
55 |
+
|
56 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
57 |
+
"""
|
58 |
+
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
59 |
+
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
60 |
+
"""
|
61 |
+
std_text = noise_pred_text.std(
|
62 |
+
dim=list(range(1, noise_pred_text.ndim)), keepdim=True
|
63 |
+
)
|
64 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
65 |
+
# rescale the results from guidance (fixes overexposure)
|
66 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
67 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
68 |
+
noise_cfg = (
|
69 |
+
guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
70 |
+
)
|
71 |
+
return noise_cfg
|
72 |
+
|
73 |
+
|
74 |
+
def retrieve_timesteps(
|
75 |
+
scheduler,
|
76 |
+
num_inference_steps: Optional[int] = None,
|
77 |
+
device: Optional[Union[str, torch.device]] = None,
|
78 |
+
timesteps: Optional[List[int]] = None,
|
79 |
+
sigmas: Optional[List[float]] = None,
|
80 |
+
**kwargs,
|
81 |
+
):
|
82 |
+
"""
|
83 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
84 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
85 |
+
|
86 |
+
Args:
|
87 |
+
scheduler (`SchedulerMixin`):
|
88 |
+
The scheduler to get timesteps from.
|
89 |
+
num_inference_steps (`int`):
|
90 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
91 |
+
must be `None`.
|
92 |
+
device (`str` or `torch.device`, *optional*):
|
93 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
94 |
+
timesteps (`List[int]`, *optional*):
|
95 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
96 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
97 |
+
sigmas (`List[float]`, *optional*):
|
98 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
99 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
103 |
+
second element is the number of inference steps.
|
104 |
+
"""
|
105 |
+
if timesteps is not None and sigmas is not None:
|
106 |
+
raise ValueError(
|
107 |
+
"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
|
108 |
+
)
|
109 |
+
if timesteps is not None:
|
110 |
+
accepts_timesteps = "timesteps" in set(
|
111 |
+
inspect.signature(scheduler.set_timesteps).parameters.keys()
|
112 |
+
)
|
113 |
+
if not accepts_timesteps:
|
114 |
+
raise ValueError(
|
115 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
116 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
117 |
+
)
|
118 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
119 |
+
timesteps = scheduler.timesteps
|
120 |
+
num_inference_steps = len(timesteps)
|
121 |
+
elif sigmas is not None:
|
122 |
+
accept_sigmas = "sigmas" in set(
|
123 |
+
inspect.signature(scheduler.set_timesteps).parameters.keys()
|
124 |
+
)
|
125 |
+
if not accept_sigmas:
|
126 |
+
raise ValueError(
|
127 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
128 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
129 |
+
)
|
130 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
131 |
+
timesteps = scheduler.timesteps
|
132 |
+
num_inference_steps = len(timesteps)
|
133 |
+
else:
|
134 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
135 |
+
timesteps = scheduler.timesteps
|
136 |
+
return timesteps, num_inference_steps
|
137 |
+
|
138 |
+
|
139 |
+
@dataclass
|
140 |
+
class HunyuanVideoPipelineOutput(BaseOutput):
|
141 |
+
videos: Union[torch.Tensor, np.ndarray]
|
142 |
+
|
143 |
+
|
144 |
+
class HunyuanVideoPipeline(DiffusionPipeline):
|
145 |
+
r"""
|
146 |
+
Pipeline for text-to-video generation using HunyuanVideo.
|
147 |
+
|
148 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
149 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
150 |
+
|
151 |
+
Args:
|
152 |
+
vae ([`AutoencoderKL`]):
|
153 |
+
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
154 |
+
text_encoder ([`TextEncoder`]):
|
155 |
+
Frozen text-encoder.
|
156 |
+
text_encoder_2 ([`TextEncoder`]):
|
157 |
+
Frozen text-encoder_2.
|
158 |
+
transformer ([`HYVideoDiffusionTransformer`]):
|
159 |
+
A `HYVideoDiffusionTransformer` to denoise the encoded video latents.
|
160 |
+
scheduler ([`SchedulerMixin`]):
|
161 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents.
|
162 |
+
"""
|
163 |
+
|
164 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
165 |
+
_optional_components = ["text_encoder_2"]
|
166 |
+
_exclude_from_cpu_offload = ["transformer"]
|
167 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
168 |
+
|
169 |
+
def __init__(
|
170 |
+
self,
|
171 |
+
vae: AutoencoderKL,
|
172 |
+
text_encoder: TextEncoder,
|
173 |
+
transformer: HYVideoDiffusionTransformer,
|
174 |
+
scheduler: KarrasDiffusionSchedulers,
|
175 |
+
text_encoder_2: Optional[TextEncoder] = None,
|
176 |
+
progress_bar_config: Dict[str, Any] = None,
|
177 |
+
args=None,
|
178 |
+
):
|
179 |
+
super().__init__()
|
180 |
+
|
181 |
+
# ==========================================================================================
|
182 |
+
if progress_bar_config is None:
|
183 |
+
progress_bar_config = {}
|
184 |
+
if not hasattr(self, "_progress_bar_config"):
|
185 |
+
self._progress_bar_config = {}
|
186 |
+
self._progress_bar_config.update(progress_bar_config)
|
187 |
+
|
188 |
+
self.args = args
|
189 |
+
# ==========================================================================================
|
190 |
+
|
191 |
+
if (
|
192 |
+
hasattr(scheduler.config, "steps_offset")
|
193 |
+
and scheduler.config.steps_offset != 1
|
194 |
+
):
|
195 |
+
deprecation_message = (
|
196 |
+
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
197 |
+
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
198 |
+
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
199 |
+
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
200 |
+
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
201 |
+
" file"
|
202 |
+
)
|
203 |
+
deprecate(
|
204 |
+
"steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
|
205 |
+
)
|
206 |
+
new_config = dict(scheduler.config)
|
207 |
+
new_config["steps_offset"] = 1
|
208 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
209 |
+
|
210 |
+
if (
|
211 |
+
hasattr(scheduler.config, "clip_sample")
|
212 |
+
and scheduler.config.clip_sample is True
|
213 |
+
):
|
214 |
+
deprecation_message = (
|
215 |
+
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
216 |
+
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
217 |
+
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
218 |
+
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
219 |
+
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
220 |
+
)
|
221 |
+
deprecate(
|
222 |
+
"clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
|
223 |
+
)
|
224 |
+
new_config = dict(scheduler.config)
|
225 |
+
new_config["clip_sample"] = False
|
226 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
227 |
+
|
228 |
+
self.register_modules(
|
229 |
+
vae=vae,
|
230 |
+
text_encoder=text_encoder,
|
231 |
+
transformer=transformer,
|
232 |
+
scheduler=scheduler,
|
233 |
+
text_encoder_2=text_encoder_2,
|
234 |
+
)
|
235 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
236 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
237 |
+
|
238 |
+
def encode_prompt(
|
239 |
+
self,
|
240 |
+
prompt,
|
241 |
+
device,
|
242 |
+
num_videos_per_prompt,
|
243 |
+
do_classifier_free_guidance,
|
244 |
+
negative_prompt=None,
|
245 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
246 |
+
attention_mask: Optional[torch.Tensor] = None,
|
247 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
248 |
+
negative_attention_mask: Optional[torch.Tensor] = None,
|
249 |
+
lora_scale: Optional[float] = None,
|
250 |
+
clip_skip: Optional[int] = None,
|
251 |
+
text_encoder: Optional[TextEncoder] = None,
|
252 |
+
data_type: Optional[str] = "image",
|
253 |
+
):
|
254 |
+
r"""
|
255 |
+
Encodes the prompt into text encoder hidden states.
|
256 |
+
|
257 |
+
Args:
|
258 |
+
prompt (`str` or `List[str]`, *optional*):
|
259 |
+
prompt to be encoded
|
260 |
+
device: (`torch.device`):
|
261 |
+
torch device
|
262 |
+
num_videos_per_prompt (`int`):
|
263 |
+
number of videos that should be generated per prompt
|
264 |
+
do_classifier_free_guidance (`bool`):
|
265 |
+
whether to use classifier free guidance or not
|
266 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
267 |
+
The prompt or prompts not to guide the video generation. If not defined, one has to pass
|
268 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
269 |
+
less than `1`).
|
270 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
271 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
272 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
273 |
+
attention_mask (`torch.Tensor`, *optional*):
|
274 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
275 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
276 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
277 |
+
argument.
|
278 |
+
negative_attention_mask (`torch.Tensor`, *optional*):
|
279 |
+
lora_scale (`float`, *optional*):
|
280 |
+
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
281 |
+
clip_skip (`int`, *optional*):
|
282 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
283 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
284 |
+
text_encoder (TextEncoder, *optional*):
|
285 |
+
data_type (`str`, *optional*):
|
286 |
+
"""
|
287 |
+
if text_encoder is None:
|
288 |
+
text_encoder = self.text_encoder
|
289 |
+
|
290 |
+
# set lora scale so that monkey patched LoRA
|
291 |
+
# function of text encoder can correctly access it
|
292 |
+
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
293 |
+
self._lora_scale = lora_scale
|
294 |
+
|
295 |
+
# dynamically adjust the LoRA scale
|
296 |
+
if not USE_PEFT_BACKEND:
|
297 |
+
adjust_lora_scale_text_encoder(text_encoder.model, lora_scale)
|
298 |
+
else:
|
299 |
+
scale_lora_layers(text_encoder.model, lora_scale)
|
300 |
+
|
301 |
+
if prompt is not None and isinstance(prompt, str):
|
302 |
+
batch_size = 1
|
303 |
+
elif prompt is not None and isinstance(prompt, list):
|
304 |
+
batch_size = len(prompt)
|
305 |
+
else:
|
306 |
+
batch_size = prompt_embeds.shape[0]
|
307 |
+
|
308 |
+
if prompt_embeds is None:
|
309 |
+
# textual inversion: process multi-vector tokens if necessary
|
310 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
311 |
+
prompt = self.maybe_convert_prompt(prompt, text_encoder.tokenizer)
|
312 |
+
|
313 |
+
text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)
|
314 |
+
|
315 |
+
if clip_skip is None:
|
316 |
+
prompt_outputs = text_encoder.encode(
|
317 |
+
text_inputs, data_type=data_type, device=device
|
318 |
+
)
|
319 |
+
prompt_embeds = prompt_outputs.hidden_state
|
320 |
+
else:
|
321 |
+
prompt_outputs = text_encoder.encode(
|
322 |
+
text_inputs,
|
323 |
+
output_hidden_states=True,
|
324 |
+
data_type=data_type,
|
325 |
+
device=device,
|
326 |
+
)
|
327 |
+
# Access the `hidden_states` first, that contains a tuple of
|
328 |
+
# all the hidden states from the encoder layers. Then index into
|
329 |
+
# the tuple to access the hidden states from the desired layer.
|
330 |
+
prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)]
|
331 |
+
# We also need to apply the final LayerNorm here to not mess with the
|
332 |
+
# representations. The `last_hidden_states` that we typically use for
|
333 |
+
# obtaining the final prompt representations passes through the LayerNorm
|
334 |
+
# layer.
|
335 |
+
prompt_embeds = text_encoder.model.text_model.final_layer_norm(
|
336 |
+
prompt_embeds
|
337 |
+
)
|
338 |
+
|
339 |
+
attention_mask = prompt_outputs.attention_mask
|
340 |
+
if attention_mask is not None:
|
341 |
+
attention_mask = attention_mask.to(device)
|
342 |
+
bs_embed, seq_len = attention_mask.shape
|
343 |
+
attention_mask = attention_mask.repeat(1, num_videos_per_prompt)
|
344 |
+
attention_mask = attention_mask.view(
|
345 |
+
bs_embed * num_videos_per_prompt, seq_len
|
346 |
+
)
|
347 |
+
|
348 |
+
if text_encoder is not None:
|
349 |
+
prompt_embeds_dtype = text_encoder.dtype
|
350 |
+
elif self.transformer is not None:
|
351 |
+
prompt_embeds_dtype = self.transformer.dtype
|
352 |
+
else:
|
353 |
+
prompt_embeds_dtype = prompt_embeds.dtype
|
354 |
+
|
355 |
+
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
356 |
+
|
357 |
+
if prompt_embeds.ndim == 2:
|
358 |
+
bs_embed, _ = prompt_embeds.shape
|
359 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
360 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
|
361 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1)
|
362 |
+
else:
|
363 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
364 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
365 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
366 |
+
prompt_embeds = prompt_embeds.view(
|
367 |
+
bs_embed * num_videos_per_prompt, seq_len, -1
|
368 |
+
)
|
369 |
+
|
370 |
+
# get unconditional embeddings for classifier free guidance
|
371 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
372 |
+
uncond_tokens: List[str]
|
373 |
+
if negative_prompt is None:
|
374 |
+
uncond_tokens = [""] * batch_size
|
375 |
+
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
376 |
+
raise TypeError(
|
377 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
378 |
+
f" {type(prompt)}."
|
379 |
+
)
|
380 |
+
elif isinstance(negative_prompt, str):
|
381 |
+
uncond_tokens = [negative_prompt]
|
382 |
+
elif batch_size != len(negative_prompt):
|
383 |
+
raise ValueError(
|
384 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
385 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
386 |
+
" the batch size of `prompt`."
|
387 |
+
)
|
388 |
+
else:
|
389 |
+
uncond_tokens = negative_prompt
|
390 |
+
|
391 |
+
# textual inversion: process multi-vector tokens if necessary
|
392 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
393 |
+
uncond_tokens = self.maybe_convert_prompt(
|
394 |
+
uncond_tokens, text_encoder.tokenizer
|
395 |
+
)
|
396 |
+
|
397 |
+
# max_length = prompt_embeds.shape[1]
|
398 |
+
uncond_input = text_encoder.text2tokens(uncond_tokens, data_type=data_type)
|
399 |
+
|
400 |
+
negative_prompt_outputs = text_encoder.encode(
|
401 |
+
uncond_input, data_type=data_type, device=device
|
402 |
+
)
|
403 |
+
negative_prompt_embeds = negative_prompt_outputs.hidden_state
|
404 |
+
|
405 |
+
negative_attention_mask = negative_prompt_outputs.attention_mask
|
406 |
+
if negative_attention_mask is not None:
|
407 |
+
negative_attention_mask = negative_attention_mask.to(device)
|
408 |
+
_, seq_len = negative_attention_mask.shape
|
409 |
+
negative_attention_mask = negative_attention_mask.repeat(
|
410 |
+
1, num_videos_per_prompt
|
411 |
+
)
|
412 |
+
negative_attention_mask = negative_attention_mask.view(
|
413 |
+
batch_size * num_videos_per_prompt, seq_len
|
414 |
+
)
|
415 |
+
|
416 |
+
if do_classifier_free_guidance:
|
417 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
418 |
+
seq_len = negative_prompt_embeds.shape[1]
|
419 |
+
|
420 |
+
negative_prompt_embeds = negative_prompt_embeds.to(
|
421 |
+
dtype=prompt_embeds_dtype, device=device
|
422 |
+
)
|
423 |
+
|
424 |
+
if negative_prompt_embeds.ndim == 2:
|
425 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(
|
426 |
+
1, num_videos_per_prompt
|
427 |
+
)
|
428 |
+
negative_prompt_embeds = negative_prompt_embeds.view(
|
429 |
+
batch_size * num_videos_per_prompt, -1
|
430 |
+
)
|
431 |
+
else:
|
432 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(
|
433 |
+
1, num_videos_per_prompt, 1
|
434 |
+
)
|
435 |
+
negative_prompt_embeds = negative_prompt_embeds.view(
|
436 |
+
batch_size * num_videos_per_prompt, seq_len, -1
|
437 |
+
)
|
438 |
+
|
439 |
+
if text_encoder is not None:
|
440 |
+
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
441 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
442 |
+
unscale_lora_layers(text_encoder.model, lora_scale)
|
443 |
+
|
444 |
+
return (
|
445 |
+
prompt_embeds,
|
446 |
+
negative_prompt_embeds,
|
447 |
+
attention_mask,
|
448 |
+
negative_attention_mask,
|
449 |
+
)
|
450 |
+
|
451 |
+
def decode_latents(self, latents, enable_tiling=True):
|
452 |
+
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
|
453 |
+
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
|
454 |
+
|
455 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
456 |
+
if enable_tiling:
|
457 |
+
self.vae.enable_tiling()
|
458 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
459 |
+
else:
|
460 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
461 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
462 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
463 |
+
if image.ndim == 4:
|
464 |
+
image = image.cpu().permute(0, 2, 3, 1).float()
|
465 |
+
else:
|
466 |
+
image = image.cpu().float()
|
467 |
+
return image
|
468 |
+
|
469 |
+
def prepare_extra_func_kwargs(self, func, kwargs):
|
470 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
471 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
472 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
473 |
+
# and should be between [0, 1]
|
474 |
+
extra_step_kwargs = {}
|
475 |
+
|
476 |
+
for k, v in kwargs.items():
|
477 |
+
accepts = k in set(inspect.signature(func).parameters.keys())
|
478 |
+
if accepts:
|
479 |
+
extra_step_kwargs[k] = v
|
480 |
+
return extra_step_kwargs
|
481 |
+
|
482 |
+
def check_inputs(
|
483 |
+
self,
|
484 |
+
prompt,
|
485 |
+
height,
|
486 |
+
width,
|
487 |
+
video_length,
|
488 |
+
callback_steps,
|
489 |
+
negative_prompt=None,
|
490 |
+
prompt_embeds=None,
|
491 |
+
negative_prompt_embeds=None,
|
492 |
+
callback_on_step_end_tensor_inputs=None,
|
493 |
+
vae_ver="88-4c-sd",
|
494 |
+
):
|
495 |
+
if height % 8 != 0 or width % 8 != 0:
|
496 |
+
raise ValueError(
|
497 |
+
f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
|
498 |
+
)
|
499 |
+
|
500 |
+
if video_length is not None:
|
501 |
+
if "884" in vae_ver:
|
502 |
+
if video_length != 1 and (video_length - 1) % 4 != 0:
|
503 |
+
raise ValueError(
|
504 |
+
f"`video_length` has to be 1 or a multiple of 4 but is {video_length}."
|
505 |
+
)
|
506 |
+
elif "888" in vae_ver:
|
507 |
+
if video_length != 1 and (video_length - 1) % 8 != 0:
|
508 |
+
raise ValueError(
|
509 |
+
f"`video_length` has to be 1 or a multiple of 8 but is {video_length}."
|
510 |
+
)
|
511 |
+
|
512 |
+
if callback_steps is not None and (
|
513 |
+
not isinstance(callback_steps, int) or callback_steps <= 0
|
514 |
+
):
|
515 |
+
raise ValueError(
|
516 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
517 |
+
f" {type(callback_steps)}."
|
518 |
+
)
|
519 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
520 |
+
k in self._callback_tensor_inputs
|
521 |
+
for k in callback_on_step_end_tensor_inputs
|
522 |
+
):
|
523 |
+
raise ValueError(
|
524 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
525 |
+
)
|
526 |
+
|
527 |
+
if prompt is not None and prompt_embeds is not None:
|
528 |
+
raise ValueError(
|
529 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
530 |
+
" only forward one of the two."
|
531 |
+
)
|
532 |
+
elif prompt is None and prompt_embeds is None:
|
533 |
+
raise ValueError(
|
534 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
535 |
+
)
|
536 |
+
elif prompt is not None and (
|
537 |
+
not isinstance(prompt, str) and not isinstance(prompt, list)
|
538 |
+
):
|
539 |
+
raise ValueError(
|
540 |
+
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
|
541 |
+
)
|
542 |
+
|
543 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
544 |
+
raise ValueError(
|
545 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
546 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
547 |
+
)
|
548 |
+
|
549 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
550 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
551 |
+
raise ValueError(
|
552 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
553 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
554 |
+
f" {negative_prompt_embeds.shape}."
|
555 |
+
)
|
556 |
+
|
557 |
+
|
558 |
+
def prepare_latents(
|
559 |
+
self,
|
560 |
+
batch_size,
|
561 |
+
num_channels_latents,
|
562 |
+
height,
|
563 |
+
width,
|
564 |
+
video_length,
|
565 |
+
dtype,
|
566 |
+
device,
|
567 |
+
generator,
|
568 |
+
latents=None,
|
569 |
+
):
|
570 |
+
shape = (
|
571 |
+
batch_size,
|
572 |
+
num_channels_latents,
|
573 |
+
video_length,
|
574 |
+
int(height) // self.vae_scale_factor,
|
575 |
+
int(width) // self.vae_scale_factor,
|
576 |
+
)
|
577 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
578 |
+
raise ValueError(
|
579 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
580 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
581 |
+
)
|
582 |
+
|
583 |
+
if latents is None:
|
584 |
+
latents = randn_tensor(
|
585 |
+
shape, generator=generator, device=device, dtype=dtype
|
586 |
+
)
|
587 |
+
else:
|
588 |
+
latents = latents.to(device)
|
589 |
+
|
590 |
+
# Check existence to make it compatible with FlowMatchEulerDiscreteScheduler
|
591 |
+
if hasattr(self.scheduler, "init_noise_sigma"):
|
592 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
593 |
+
latents = latents * self.scheduler.init_noise_sigma
|
594 |
+
return latents
|
595 |
+
|
596 |
+
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
597 |
+
def get_guidance_scale_embedding(
|
598 |
+
self,
|
599 |
+
w: torch.Tensor,
|
600 |
+
embedding_dim: int = 512,
|
601 |
+
dtype: torch.dtype = torch.float32,
|
602 |
+
) -> torch.Tensor:
|
603 |
+
"""
|
604 |
+
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
|
605 |
+
|
606 |
+
Args:
|
607 |
+
w (`torch.Tensor`):
|
608 |
+
Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
|
609 |
+
embedding_dim (`int`, *optional*, defaults to 512):
|
610 |
+
Dimension of the embeddings to generate.
|
611 |
+
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
|
612 |
+
Data type of the generated embeddings.
|
613 |
+
|
614 |
+
Returns:
|
615 |
+
`torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
|
616 |
+
"""
|
617 |
+
assert len(w.shape) == 1
|
618 |
+
w = w * 1000.0
|
619 |
+
|
620 |
+
half_dim = embedding_dim // 2
|
621 |
+
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
|
622 |
+
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
|
623 |
+
emb = w.to(dtype)[:, None] * emb[None, :]
|
624 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
625 |
+
if embedding_dim % 2 == 1: # zero pad
|
626 |
+
emb = torch.nn.functional.pad(emb, (0, 1))
|
627 |
+
assert emb.shape == (w.shape[0], embedding_dim)
|
628 |
+
return emb
|
629 |
+
|
630 |
+
@property
|
631 |
+
def guidance_scale(self):
|
632 |
+
return self._guidance_scale
|
633 |
+
|
634 |
+
@property
|
635 |
+
def guidance_rescale(self):
|
636 |
+
return self._guidance_rescale
|
637 |
+
|
638 |
+
@property
|
639 |
+
def clip_skip(self):
|
640 |
+
return self._clip_skip
|
641 |
+
|
642 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
643 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
644 |
+
# corresponds to doing no classifier free guidance.
|
645 |
+
@property
|
646 |
+
def do_classifier_free_guidance(self):
|
647 |
+
# return self._guidance_scale > 1 and self.transformer.config.time_cond_proj_dim is None
|
648 |
+
return self._guidance_scale > 1
|
649 |
+
|
650 |
+
@property
|
651 |
+
def cross_attention_kwargs(self):
|
652 |
+
return self._cross_attention_kwargs
|
653 |
+
|
654 |
+
@property
|
655 |
+
def num_timesteps(self):
|
656 |
+
return self._num_timesteps
|
657 |
+
|
658 |
+
@property
|
659 |
+
def interrupt(self):
|
660 |
+
return self._interrupt
|
661 |
+
|
662 |
+
@torch.no_grad()
|
663 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
664 |
+
def __call__(
|
665 |
+
self,
|
666 |
+
prompt: Union[str, List[str]],
|
667 |
+
height: int,
|
668 |
+
width: int,
|
669 |
+
video_length: int,
|
670 |
+
data_type: str = "video",
|
671 |
+
num_inference_steps: int = 50,
|
672 |
+
timesteps: List[int] = None,
|
673 |
+
sigmas: List[float] = None,
|
674 |
+
guidance_scale: float = 7.5,
|
675 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
676 |
+
num_videos_per_prompt: Optional[int] = 1,
|
677 |
+
eta: float = 0.0,
|
678 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
679 |
+
latents: Optional[torch.Tensor] = None,
|
680 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
681 |
+
attention_mask: Optional[torch.Tensor] = None,
|
682 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
683 |
+
negative_attention_mask: Optional[torch.Tensor] = None,
|
684 |
+
output_type: Optional[str] = "pil",
|
685 |
+
return_dict: bool = True,
|
686 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
687 |
+
guidance_rescale: float = 0.0,
|
688 |
+
clip_skip: Optional[int] = None,
|
689 |
+
callback_on_step_end: Optional[
|
690 |
+
Union[
|
691 |
+
Callable[[int, int, Dict], None],
|
692 |
+
PipelineCallback,
|
693 |
+
MultiPipelineCallbacks,
|
694 |
+
]
|
695 |
+
] = None,
|
696 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
697 |
+
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
|
698 |
+
vae_ver: str = "88-4c-sd",
|
699 |
+
enable_tiling: bool = False,
|
700 |
+
n_tokens: Optional[int] = None,
|
701 |
+
embedded_guidance_scale: Optional[float] = None,
|
702 |
+
**kwargs,
|
703 |
+
):
|
704 |
+
r"""
|
705 |
+
The call function to the pipeline for generation.
|
706 |
+
|
707 |
+
Args:
|
708 |
+
prompt (`str` or `List[str]`):
|
709 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
710 |
+
height (`int`):
|
711 |
+
The height in pixels of the generated image.
|
712 |
+
width (`int`):
|
713 |
+
The width in pixels of the generated image.
|
714 |
+
video_length (`int`):
|
715 |
+
The number of frames in the generated video.
|
716 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
717 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
718 |
+
expense of slower inference.
|
719 |
+
timesteps (`List[int]`, *optional*):
|
720 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
721 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
722 |
+
passed will be used. Must be in descending order.
|
723 |
+
sigmas (`List[float]`, *optional*):
|
724 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
725 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
726 |
+
will be used.
|
727 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
728 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
729 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
730 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
731 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
732 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
733 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
734 |
+
The number of images to generate per prompt.
|
735 |
+
eta (`float`, *optional*, defaults to 0.0):
|
736 |
+
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
737 |
+
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
738 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
739 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
740 |
+
generation deterministic.
|
741 |
+
latents (`torch.Tensor`, *optional*):
|
742 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
743 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
744 |
+
tensor is generated by sampling using the supplied random `generator`.
|
745 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
746 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
747 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
748 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
749 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
750 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
751 |
+
|
752 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
753 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
754 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
755 |
+
Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a
|
756 |
+
plain tuple.
|
757 |
+
cross_attention_kwargs (`dict`, *optional*):
|
758 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
759 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
760 |
+
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
761 |
+
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
|
762 |
+
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
|
763 |
+
using zero terminal SNR.
|
764 |
+
clip_skip (`int`, *optional*):
|
765 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
766 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
767 |
+
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
768 |
+
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
769 |
+
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
770 |
+
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
771 |
+
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
772 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
773 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
774 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
775 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
776 |
+
|
777 |
+
Examples:
|
778 |
+
|
779 |
+
Returns:
|
780 |
+
[`~HunyuanVideoPipelineOutput`] or `tuple`:
|
781 |
+
If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned,
|
782 |
+
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
783 |
+
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
784 |
+
"not-safe-for-work" (nsfw) content.
|
785 |
+
"""
|
786 |
+
callback = kwargs.pop("callback", None)
|
787 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
788 |
+
|
789 |
+
if callback is not None:
|
790 |
+
deprecate(
|
791 |
+
"callback",
|
792 |
+
"1.0.0",
|
793 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
794 |
+
)
|
795 |
+
if callback_steps is not None:
|
796 |
+
deprecate(
|
797 |
+
"callback_steps",
|
798 |
+
"1.0.0",
|
799 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
800 |
+
)
|
801 |
+
|
802 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
803 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
804 |
+
|
805 |
+
# 0. Default height and width to unet
|
806 |
+
# height = height or self.transformer.config.sample_size * self.vae_scale_factor
|
807 |
+
# width = width or self.transformer.config.sample_size * self.vae_scale_factor
|
808 |
+
# to deal with lora scaling and other possible forward hooks
|
809 |
+
|
810 |
+
# 1. Check inputs. Raise error if not correct
|
811 |
+
self.check_inputs(
|
812 |
+
prompt,
|
813 |
+
height,
|
814 |
+
width,
|
815 |
+
video_length,
|
816 |
+
callback_steps,
|
817 |
+
negative_prompt,
|
818 |
+
prompt_embeds,
|
819 |
+
negative_prompt_embeds,
|
820 |
+
callback_on_step_end_tensor_inputs,
|
821 |
+
vae_ver=vae_ver,
|
822 |
+
)
|
823 |
+
|
824 |
+
self._guidance_scale = guidance_scale
|
825 |
+
self._guidance_rescale = guidance_rescale
|
826 |
+
self._clip_skip = clip_skip
|
827 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
828 |
+
self._interrupt = False
|
829 |
+
|
830 |
+
# 2. Define call parameters
|
831 |
+
if prompt is not None and isinstance(prompt, str):
|
832 |
+
batch_size = 1
|
833 |
+
elif prompt is not None and isinstance(prompt, list):
|
834 |
+
batch_size = len(prompt)
|
835 |
+
else:
|
836 |
+
batch_size = prompt_embeds.shape[0]
|
837 |
+
|
838 |
+
device = torch.device(f"cuda:{dist.get_rank()}") if dist.is_initialized() else self._execution_device
|
839 |
+
|
840 |
+
# 3. Encode input prompt
|
841 |
+
lora_scale = (
|
842 |
+
self.cross_attention_kwargs.get("scale", None)
|
843 |
+
if self.cross_attention_kwargs is not None
|
844 |
+
else None
|
845 |
+
)
|
846 |
+
|
847 |
+
(
|
848 |
+
prompt_embeds,
|
849 |
+
negative_prompt_embeds,
|
850 |
+
prompt_mask,
|
851 |
+
negative_prompt_mask,
|
852 |
+
) = self.encode_prompt(
|
853 |
+
prompt,
|
854 |
+
device,
|
855 |
+
num_videos_per_prompt,
|
856 |
+
self.do_classifier_free_guidance,
|
857 |
+
negative_prompt,
|
858 |
+
prompt_embeds=prompt_embeds,
|
859 |
+
attention_mask=attention_mask,
|
860 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
861 |
+
negative_attention_mask=negative_attention_mask,
|
862 |
+
lora_scale=lora_scale,
|
863 |
+
clip_skip=self.clip_skip,
|
864 |
+
data_type=data_type,
|
865 |
+
)
|
866 |
+
if self.text_encoder_2 is not None:
|
867 |
+
(
|
868 |
+
prompt_embeds_2,
|
869 |
+
negative_prompt_embeds_2,
|
870 |
+
prompt_mask_2,
|
871 |
+
negative_prompt_mask_2,
|
872 |
+
) = self.encode_prompt(
|
873 |
+
prompt,
|
874 |
+
device,
|
875 |
+
num_videos_per_prompt,
|
876 |
+
self.do_classifier_free_guidance,
|
877 |
+
negative_prompt,
|
878 |
+
prompt_embeds=None,
|
879 |
+
attention_mask=None,
|
880 |
+
negative_prompt_embeds=None,
|
881 |
+
negative_attention_mask=None,
|
882 |
+
lora_scale=lora_scale,
|
883 |
+
clip_skip=self.clip_skip,
|
884 |
+
text_encoder=self.text_encoder_2,
|
885 |
+
data_type=data_type,
|
886 |
+
)
|
887 |
+
else:
|
888 |
+
prompt_embeds_2 = None
|
889 |
+
negative_prompt_embeds_2 = None
|
890 |
+
prompt_mask_2 = None
|
891 |
+
negative_prompt_mask_2 = None
|
892 |
+
|
893 |
+
# For classifier free guidance, we need to do two forward passes.
|
894 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
895 |
+
# to avoid doing two forward passes
|
896 |
+
if self.do_classifier_free_guidance:
|
897 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
898 |
+
if prompt_mask is not None:
|
899 |
+
prompt_mask = torch.cat([negative_prompt_mask, prompt_mask])
|
900 |
+
if prompt_embeds_2 is not None:
|
901 |
+
prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
|
902 |
+
if prompt_mask_2 is not None:
|
903 |
+
prompt_mask_2 = torch.cat([negative_prompt_mask_2, prompt_mask_2])
|
904 |
+
|
905 |
+
|
906 |
+
# 4. Prepare timesteps
|
907 |
+
extra_set_timesteps_kwargs = self.prepare_extra_func_kwargs(
|
908 |
+
self.scheduler.set_timesteps, {"n_tokens": n_tokens}
|
909 |
+
)
|
910 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
911 |
+
self.scheduler,
|
912 |
+
num_inference_steps,
|
913 |
+
device,
|
914 |
+
timesteps,
|
915 |
+
sigmas,
|
916 |
+
**extra_set_timesteps_kwargs,
|
917 |
+
)
|
918 |
+
|
919 |
+
if "884" in vae_ver:
|
920 |
+
video_length = (video_length - 1) // 4 + 1
|
921 |
+
elif "888" in vae_ver:
|
922 |
+
video_length = (video_length - 1) // 8 + 1
|
923 |
+
else:
|
924 |
+
video_length = video_length
|
925 |
+
|
926 |
+
# 5. Prepare latent variables
|
927 |
+
num_channels_latents = self.transformer.config.in_channels
|
928 |
+
latents = self.prepare_latents(
|
929 |
+
batch_size * num_videos_per_prompt,
|
930 |
+
num_channels_latents,
|
931 |
+
height,
|
932 |
+
width,
|
933 |
+
video_length,
|
934 |
+
prompt_embeds.dtype,
|
935 |
+
device,
|
936 |
+
generator,
|
937 |
+
latents,
|
938 |
+
)
|
939 |
+
|
940 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
941 |
+
extra_step_kwargs = self.prepare_extra_func_kwargs(
|
942 |
+
self.scheduler.step,
|
943 |
+
{"generator": generator, "eta": eta},
|
944 |
+
)
|
945 |
+
|
946 |
+
target_dtype = PRECISION_TO_TYPE[self.args.precision]
|
947 |
+
autocast_enabled = (
|
948 |
+
target_dtype != torch.float32
|
949 |
+
) and not self.args.disable_autocast
|
950 |
+
vae_dtype = PRECISION_TO_TYPE[self.args.vae_precision]
|
951 |
+
vae_autocast_enabled = (
|
952 |
+
vae_dtype != torch.float32
|
953 |
+
) and not self.args.disable_autocast
|
954 |
+
|
955 |
+
# 7. Denoising loop
|
956 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
957 |
+
self._num_timesteps = len(timesteps)
|
958 |
+
|
959 |
+
# if is_progress_bar:
|
960 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
961 |
+
for i, t in enumerate(timesteps):
|
962 |
+
if self.interrupt:
|
963 |
+
continue
|
964 |
+
|
965 |
+
# expand the latents if we are doing classifier free guidance
|
966 |
+
latent_model_input = (
|
967 |
+
torch.cat([latents] * 2)
|
968 |
+
if self.do_classifier_free_guidance
|
969 |
+
else latents
|
970 |
+
)
|
971 |
+
latent_model_input = self.scheduler.scale_model_input(
|
972 |
+
latent_model_input, t
|
973 |
+
)
|
974 |
+
|
975 |
+
t_expand = t.repeat(latent_model_input.shape[0])
|
976 |
+
guidance_expand = (
|
977 |
+
torch.tensor(
|
978 |
+
[embedded_guidance_scale] * latent_model_input.shape[0],
|
979 |
+
dtype=torch.float32,
|
980 |
+
device=device,
|
981 |
+
).to(target_dtype)
|
982 |
+
* 1000.0
|
983 |
+
if embedded_guidance_scale is not None
|
984 |
+
else None
|
985 |
+
)
|
986 |
+
|
987 |
+
# predict the noise residual
|
988 |
+
with torch.autocast(
|
989 |
+
device_type="cuda", dtype=target_dtype, enabled=autocast_enabled
|
990 |
+
):
|
991 |
+
noise_pred = self.transformer( # For an input image (129, 192, 336) (1, 256, 256)
|
992 |
+
latent_model_input, # [2, 16, 33, 24, 42]
|
993 |
+
t_expand, # [2]
|
994 |
+
text_states=prompt_embeds, # [2, 256, 4096]
|
995 |
+
text_mask=prompt_mask, # [2, 256]
|
996 |
+
text_states_2=prompt_embeds_2, # [2, 768]
|
997 |
+
freqs_cos=freqs_cis[0], # [seqlen, head_dim]
|
998 |
+
freqs_sin=freqs_cis[1], # [seqlen, head_dim]
|
999 |
+
guidance=guidance_expand,
|
1000 |
+
return_dict=True,
|
1001 |
+
)[
|
1002 |
+
"x"
|
1003 |
+
]
|
1004 |
+
|
1005 |
+
# perform guidance
|
1006 |
+
if self.do_classifier_free_guidance:
|
1007 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
1008 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (
|
1009 |
+
noise_pred_text - noise_pred_uncond
|
1010 |
+
)
|
1011 |
+
|
1012 |
+
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
|
1013 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
1014 |
+
noise_pred = rescale_noise_cfg(
|
1015 |
+
noise_pred,
|
1016 |
+
noise_pred_text,
|
1017 |
+
guidance_rescale=self.guidance_rescale,
|
1018 |
+
)
|
1019 |
+
|
1020 |
+
# compute the previous noisy sample x_t -> x_t-1
|
1021 |
+
latents = self.scheduler.step(
|
1022 |
+
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
|
1023 |
+
)[0]
|
1024 |
+
|
1025 |
+
if callback_on_step_end is not None:
|
1026 |
+
callback_kwargs = {}
|
1027 |
+
for k in callback_on_step_end_tensor_inputs:
|
1028 |
+
callback_kwargs[k] = locals()[k]
|
1029 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
1030 |
+
|
1031 |
+
latents = callback_outputs.pop("latents", latents)
|
1032 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
1033 |
+
negative_prompt_embeds = callback_outputs.pop(
|
1034 |
+
"negative_prompt_embeds", negative_prompt_embeds
|
1035 |
+
)
|
1036 |
+
|
1037 |
+
# call the callback, if provided
|
1038 |
+
if i == len(timesteps) - 1 or (
|
1039 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
1040 |
+
):
|
1041 |
+
if progress_bar is not None:
|
1042 |
+
progress_bar.update()
|
1043 |
+
if callback is not None and i % callback_steps == 0:
|
1044 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
1045 |
+
callback(step_idx, t, latents)
|
1046 |
+
|
1047 |
+
if not output_type == "latent":
|
1048 |
+
expand_temporal_dim = False
|
1049 |
+
if len(latents.shape) == 4:
|
1050 |
+
if isinstance(self.vae, AutoencoderKLCausal3D):
|
1051 |
+
latents = latents.unsqueeze(2)
|
1052 |
+
expand_temporal_dim = True
|
1053 |
+
elif len(latents.shape) == 5:
|
1054 |
+
pass
|
1055 |
+
else:
|
1056 |
+
raise ValueError(
|
1057 |
+
f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}."
|
1058 |
+
)
|
1059 |
+
|
1060 |
+
if (
|
1061 |
+
hasattr(self.vae.config, "shift_factor")
|
1062 |
+
and self.vae.config.shift_factor
|
1063 |
+
):
|
1064 |
+
latents = (
|
1065 |
+
latents / self.vae.config.scaling_factor
|
1066 |
+
+ self.vae.config.shift_factor
|
1067 |
+
)
|
1068 |
+
else:
|
1069 |
+
latents = latents / self.vae.config.scaling_factor
|
1070 |
+
|
1071 |
+
with torch.autocast(
|
1072 |
+
device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled
|
1073 |
+
):
|
1074 |
+
if enable_tiling:
|
1075 |
+
self.vae.enable_tiling()
|
1076 |
+
image = self.vae.decode(
|
1077 |
+
latents, return_dict=False, generator=generator
|
1078 |
+
)[0]
|
1079 |
+
else:
|
1080 |
+
image = self.vae.decode(
|
1081 |
+
latents, return_dict=False, generator=generator
|
1082 |
+
)[0]
|
1083 |
+
|
1084 |
+
if expand_temporal_dim or image.shape[2] == 1:
|
1085 |
+
image = image.squeeze(2)
|
1086 |
+
|
1087 |
+
else:
|
1088 |
+
image = latents
|
1089 |
+
|
1090 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
1091 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
1092 |
+
image = image.cpu().float()
|
1093 |
+
|
1094 |
+
# Offload all models
|
1095 |
+
self.maybe_free_model_hooks()
|
1096 |
+
|
1097 |
+
if not return_dict:
|
1098 |
+
return image
|
1099 |
+
|
1100 |
+
return HunyuanVideoPipelineOutput(videos=image)
|
hunyuan_model/posemb_layers.py
ADDED
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from typing import Union, Tuple, List
|
3 |
+
|
4 |
+
|
5 |
+
def _to_tuple(x, dim=2):
|
6 |
+
if isinstance(x, int):
|
7 |
+
return (x,) * dim
|
8 |
+
elif len(x) == dim:
|
9 |
+
return x
|
10 |
+
else:
|
11 |
+
raise ValueError(f"Expected length {dim} or int, but got {x}")
|
12 |
+
|
13 |
+
|
14 |
+
def get_meshgrid_nd(start, *args, dim=2):
|
15 |
+
"""
|
16 |
+
Get n-D meshgrid with start, stop and num.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
|
20 |
+
step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
|
21 |
+
should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
|
22 |
+
n-tuples.
|
23 |
+
*args: See above.
|
24 |
+
dim (int): Dimension of the meshgrid. Defaults to 2.
|
25 |
+
|
26 |
+
Returns:
|
27 |
+
grid (np.ndarray): [dim, ...]
|
28 |
+
"""
|
29 |
+
if len(args) == 0:
|
30 |
+
# start is grid_size
|
31 |
+
num = _to_tuple(start, dim=dim)
|
32 |
+
start = (0,) * dim
|
33 |
+
stop = num
|
34 |
+
elif len(args) == 1:
|
35 |
+
# start is start, args[0] is stop, step is 1
|
36 |
+
start = _to_tuple(start, dim=dim)
|
37 |
+
stop = _to_tuple(args[0], dim=dim)
|
38 |
+
num = [stop[i] - start[i] for i in range(dim)]
|
39 |
+
elif len(args) == 2:
|
40 |
+
# start is start, args[0] is stop, args[1] is num
|
41 |
+
start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
|
42 |
+
stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
|
43 |
+
num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
|
44 |
+
else:
|
45 |
+
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
|
46 |
+
|
47 |
+
# PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
|
48 |
+
axis_grid = []
|
49 |
+
for i in range(dim):
|
50 |
+
a, b, n = start[i], stop[i], num[i]
|
51 |
+
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
|
52 |
+
axis_grid.append(g)
|
53 |
+
grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
|
54 |
+
grid = torch.stack(grid, dim=0) # [dim, W, H, D]
|
55 |
+
|
56 |
+
return grid
|
57 |
+
|
58 |
+
|
59 |
+
#################################################################################
|
60 |
+
# Rotary Positional Embedding Functions #
|
61 |
+
#################################################################################
|
62 |
+
# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
|
63 |
+
|
64 |
+
|
65 |
+
def reshape_for_broadcast(
|
66 |
+
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
67 |
+
x: torch.Tensor,
|
68 |
+
head_first=False,
|
69 |
+
):
|
70 |
+
"""
|
71 |
+
Reshape frequency tensor for broadcasting it with another tensor.
|
72 |
+
|
73 |
+
This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
|
74 |
+
for the purpose of broadcasting the frequency tensor during element-wise operations.
|
75 |
+
|
76 |
+
Notes:
|
77 |
+
When using FlashMHAModified, head_first should be False.
|
78 |
+
When using Attention, head_first should be True.
|
79 |
+
|
80 |
+
Args:
|
81 |
+
freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
|
82 |
+
x (torch.Tensor): Target tensor for broadcasting compatibility.
|
83 |
+
head_first (bool): head dimension first (except batch dim) or not.
|
84 |
+
|
85 |
+
Returns:
|
86 |
+
torch.Tensor: Reshaped frequency tensor.
|
87 |
+
|
88 |
+
Raises:
|
89 |
+
AssertionError: If the frequency tensor doesn't match the expected shape.
|
90 |
+
AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
|
91 |
+
"""
|
92 |
+
ndim = x.ndim
|
93 |
+
assert 0 <= 1 < ndim
|
94 |
+
|
95 |
+
if isinstance(freqs_cis, tuple):
|
96 |
+
# freqs_cis: (cos, sin) in real space
|
97 |
+
if head_first:
|
98 |
+
assert freqs_cis[0].shape == (
|
99 |
+
x.shape[-2],
|
100 |
+
x.shape[-1],
|
101 |
+
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
|
102 |
+
shape = [
|
103 |
+
d if i == ndim - 2 or i == ndim - 1 else 1
|
104 |
+
for i, d in enumerate(x.shape)
|
105 |
+
]
|
106 |
+
else:
|
107 |
+
assert freqs_cis[0].shape == (
|
108 |
+
x.shape[1],
|
109 |
+
x.shape[-1],
|
110 |
+
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
|
111 |
+
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
112 |
+
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
|
113 |
+
else:
|
114 |
+
# freqs_cis: values in complex space
|
115 |
+
if head_first:
|
116 |
+
assert freqs_cis.shape == (
|
117 |
+
x.shape[-2],
|
118 |
+
x.shape[-1],
|
119 |
+
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
|
120 |
+
shape = [
|
121 |
+
d if i == ndim - 2 or i == ndim - 1 else 1
|
122 |
+
for i, d in enumerate(x.shape)
|
123 |
+
]
|
124 |
+
else:
|
125 |
+
assert freqs_cis.shape == (
|
126 |
+
x.shape[1],
|
127 |
+
x.shape[-1],
|
128 |
+
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
|
129 |
+
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
130 |
+
return freqs_cis.view(*shape)
|
131 |
+
|
132 |
+
|
133 |
+
def rotate_half(x):
|
134 |
+
x_real, x_imag = (
|
135 |
+
x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
|
136 |
+
) # [B, S, H, D//2]
|
137 |
+
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
138 |
+
|
139 |
+
|
140 |
+
def apply_rotary_emb(
|
141 |
+
xq: torch.Tensor,
|
142 |
+
xk: torch.Tensor,
|
143 |
+
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
|
144 |
+
head_first: bool = False,
|
145 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
146 |
+
"""
|
147 |
+
Apply rotary embeddings to input tensors using the given frequency tensor.
|
148 |
+
|
149 |
+
This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
|
150 |
+
frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
|
151 |
+
is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
|
152 |
+
returned as real tensors.
|
153 |
+
|
154 |
+
Args:
|
155 |
+
xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
|
156 |
+
xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
|
157 |
+
freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
|
158 |
+
head_first (bool): head dimension first (except batch dim) or not.
|
159 |
+
|
160 |
+
Returns:
|
161 |
+
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
162 |
+
|
163 |
+
"""
|
164 |
+
xk_out = None
|
165 |
+
if isinstance(freqs_cis, tuple):
|
166 |
+
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
|
167 |
+
cos, sin = cos.to(xq.device), sin.to(xq.device)
|
168 |
+
# real * cos - imag * sin
|
169 |
+
# imag * cos + real * sin
|
170 |
+
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
|
171 |
+
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
|
172 |
+
else:
|
173 |
+
# view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
|
174 |
+
xq_ = torch.view_as_complex(
|
175 |
+
xq.float().reshape(*xq.shape[:-1], -1, 2)
|
176 |
+
) # [B, S, H, D//2]
|
177 |
+
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
|
178 |
+
xq.device
|
179 |
+
) # [S, D//2] --> [1, S, 1, D//2]
|
180 |
+
# (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
|
181 |
+
# view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
|
182 |
+
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
|
183 |
+
xk_ = torch.view_as_complex(
|
184 |
+
xk.float().reshape(*xk.shape[:-1], -1, 2)
|
185 |
+
) # [B, S, H, D//2]
|
186 |
+
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
|
187 |
+
|
188 |
+
return xq_out, xk_out
|
189 |
+
|
190 |
+
|
191 |
+
def get_nd_rotary_pos_embed(
|
192 |
+
rope_dim_list,
|
193 |
+
start,
|
194 |
+
*args,
|
195 |
+
theta=10000.0,
|
196 |
+
use_real=False,
|
197 |
+
theta_rescale_factor: Union[float, List[float]] = 1.0,
|
198 |
+
interpolation_factor: Union[float, List[float]] = 1.0,
|
199 |
+
):
|
200 |
+
"""
|
201 |
+
This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
|
202 |
+
|
203 |
+
Args:
|
204 |
+
rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
|
205 |
+
sum(rope_dim_list) should equal to head_dim of attention layer.
|
206 |
+
start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
|
207 |
+
args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
|
208 |
+
*args: See above.
|
209 |
+
theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
|
210 |
+
use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
211 |
+
Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real
|
212 |
+
part and an imaginary part separately.
|
213 |
+
theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
|
214 |
+
|
215 |
+
Returns:
|
216 |
+
pos_embed (torch.Tensor): [HW, D/2]
|
217 |
+
"""
|
218 |
+
|
219 |
+
grid = get_meshgrid_nd(
|
220 |
+
start, *args, dim=len(rope_dim_list)
|
221 |
+
) # [3, W, H, D] / [2, W, H]
|
222 |
+
|
223 |
+
if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
|
224 |
+
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
|
225 |
+
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
|
226 |
+
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
|
227 |
+
assert len(theta_rescale_factor) == len(
|
228 |
+
rope_dim_list
|
229 |
+
), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
|
230 |
+
|
231 |
+
if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
|
232 |
+
interpolation_factor = [interpolation_factor] * len(rope_dim_list)
|
233 |
+
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
|
234 |
+
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
|
235 |
+
assert len(interpolation_factor) == len(
|
236 |
+
rope_dim_list
|
237 |
+
), "len(interpolation_factor) should equal to len(rope_dim_list)"
|
238 |
+
|
239 |
+
# use 1/ndim of dimensions to encode grid_axis
|
240 |
+
embs = []
|
241 |
+
for i in range(len(rope_dim_list)):
|
242 |
+
emb = get_1d_rotary_pos_embed(
|
243 |
+
rope_dim_list[i],
|
244 |
+
grid[i].reshape(-1),
|
245 |
+
theta,
|
246 |
+
use_real=use_real,
|
247 |
+
theta_rescale_factor=theta_rescale_factor[i],
|
248 |
+
interpolation_factor=interpolation_factor[i],
|
249 |
+
) # 2 x [WHD, rope_dim_list[i]]
|
250 |
+
embs.append(emb)
|
251 |
+
|
252 |
+
if use_real:
|
253 |
+
cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
|
254 |
+
sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
|
255 |
+
return cos, sin
|
256 |
+
else:
|
257 |
+
emb = torch.cat(embs, dim=1) # (WHD, D/2)
|
258 |
+
return emb
|
259 |
+
|
260 |
+
|
261 |
+
def get_1d_rotary_pos_embed(
|
262 |
+
dim: int,
|
263 |
+
pos: Union[torch.FloatTensor, int],
|
264 |
+
theta: float = 10000.0,
|
265 |
+
use_real: bool = False,
|
266 |
+
theta_rescale_factor: float = 1.0,
|
267 |
+
interpolation_factor: float = 1.0,
|
268 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
269 |
+
"""
|
270 |
+
Precompute the frequency tensor for complex exponential (cis) with given dimensions.
|
271 |
+
(Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
|
272 |
+
|
273 |
+
This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
|
274 |
+
and the end index 'end'. The 'theta' parameter scales the frequencies.
|
275 |
+
The returned tensor contains complex values in complex64 data type.
|
276 |
+
|
277 |
+
Args:
|
278 |
+
dim (int): Dimension of the frequency tensor.
|
279 |
+
pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
|
280 |
+
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
|
281 |
+
use_real (bool, optional): If True, return real part and imaginary part separately.
|
282 |
+
Otherwise, return complex numbers.
|
283 |
+
theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
|
284 |
+
|
285 |
+
Returns:
|
286 |
+
freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
|
287 |
+
freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
|
288 |
+
"""
|
289 |
+
if isinstance(pos, int):
|
290 |
+
pos = torch.arange(pos).float()
|
291 |
+
|
292 |
+
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
293 |
+
# has some connection to NTK literature
|
294 |
+
if theta_rescale_factor != 1.0:
|
295 |
+
theta *= theta_rescale_factor ** (dim / (dim - 2))
|
296 |
+
|
297 |
+
freqs = 1.0 / (
|
298 |
+
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
|
299 |
+
) # [D/2]
|
300 |
+
# assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
|
301 |
+
freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
|
302 |
+
if use_real:
|
303 |
+
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
|
304 |
+
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
|
305 |
+
return freqs_cos, freqs_sin
|
306 |
+
else:
|
307 |
+
freqs_cis = torch.polar(
|
308 |
+
torch.ones_like(freqs), freqs
|
309 |
+
) # complex64 # [S, D/2]
|
310 |
+
return freqs_cis
|
hunyuan_model/text_encoder.py
ADDED
@@ -0,0 +1,710 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
from typing import Optional, Tuple, Union
|
5 |
+
from copy import deepcopy
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from transformers import (
|
10 |
+
CLIPTextModel,
|
11 |
+
CLIPTokenizer,
|
12 |
+
AutoTokenizer,
|
13 |
+
AutoModel,
|
14 |
+
CLIPConfig,
|
15 |
+
LlamaForCausalLM,
|
16 |
+
LlamaConfig,
|
17 |
+
)
|
18 |
+
from transformers.utils import ModelOutput
|
19 |
+
from transformers.models.llama import LlamaModel
|
20 |
+
from safetensors.torch import load_file
|
21 |
+
from accelerate import init_empty_weights
|
22 |
+
|
23 |
+
import logging
|
24 |
+
|
25 |
+
logger = logging.getLogger(__name__)
|
26 |
+
logging.basicConfig(level=logging.INFO)
|
27 |
+
|
28 |
+
|
29 |
+
CLIP_L_HUGGINGFACE_MODEL_ID = "openai/clip-vit-large-patch14"
|
30 |
+
LLAVA_HUGGINGFACE_MODEL_ID = "xtuner/llava-llama-3-8b-v1_1-transformers"
|
31 |
+
|
32 |
+
CLIP_CONFIG = {
|
33 |
+
"_name_or_path": "clip-vit-large-patch14/",
|
34 |
+
"architectures": ["CLIPModel"],
|
35 |
+
"initializer_factor": 1.0,
|
36 |
+
"logit_scale_init_value": 2.6592,
|
37 |
+
"model_type": "clip",
|
38 |
+
"projection_dim": 768,
|
39 |
+
# "text_config": {
|
40 |
+
"_name_or_path": "",
|
41 |
+
"add_cross_attention": False,
|
42 |
+
"architectures": None,
|
43 |
+
"attention_dropout": 0.0,
|
44 |
+
"bad_words_ids": None,
|
45 |
+
"bos_token_id": 0,
|
46 |
+
"chunk_size_feed_forward": 0,
|
47 |
+
"cross_attention_hidden_size": None,
|
48 |
+
"decoder_start_token_id": None,
|
49 |
+
"diversity_penalty": 0.0,
|
50 |
+
"do_sample": False,
|
51 |
+
"dropout": 0.0,
|
52 |
+
"early_stopping": False,
|
53 |
+
"encoder_no_repeat_ngram_size": 0,
|
54 |
+
"eos_token_id": 2,
|
55 |
+
"finetuning_task": None,
|
56 |
+
"forced_bos_token_id": None,
|
57 |
+
"forced_eos_token_id": None,
|
58 |
+
"hidden_act": "quick_gelu",
|
59 |
+
"hidden_size": 768,
|
60 |
+
"id2label": {"0": "LABEL_0", "1": "LABEL_1"},
|
61 |
+
"initializer_factor": 1.0,
|
62 |
+
"initializer_range": 0.02,
|
63 |
+
"intermediate_size": 3072,
|
64 |
+
"is_decoder": False,
|
65 |
+
"is_encoder_decoder": False,
|
66 |
+
"label2id": {"LABEL_0": 0, "LABEL_1": 1},
|
67 |
+
"layer_norm_eps": 1e-05,
|
68 |
+
"length_penalty": 1.0,
|
69 |
+
"max_length": 20,
|
70 |
+
"max_position_embeddings": 77,
|
71 |
+
"min_length": 0,
|
72 |
+
"model_type": "clip_text_model",
|
73 |
+
"no_repeat_ngram_size": 0,
|
74 |
+
"num_attention_heads": 12,
|
75 |
+
"num_beam_groups": 1,
|
76 |
+
"num_beams": 1,
|
77 |
+
"num_hidden_layers": 12,
|
78 |
+
"num_return_sequences": 1,
|
79 |
+
"output_attentions": False,
|
80 |
+
"output_hidden_states": False,
|
81 |
+
"output_scores": False,
|
82 |
+
"pad_token_id": 1,
|
83 |
+
"prefix": None,
|
84 |
+
"problem_type": None,
|
85 |
+
"projection_dim": 768,
|
86 |
+
"pruned_heads": {},
|
87 |
+
"remove_invalid_values": False,
|
88 |
+
"repetition_penalty": 1.0,
|
89 |
+
"return_dict": True,
|
90 |
+
"return_dict_in_generate": False,
|
91 |
+
"sep_token_id": None,
|
92 |
+
"task_specific_params": None,
|
93 |
+
"temperature": 1.0,
|
94 |
+
"tie_encoder_decoder": False,
|
95 |
+
"tie_word_embeddings": True,
|
96 |
+
"tokenizer_class": None,
|
97 |
+
"top_k": 50,
|
98 |
+
"top_p": 1.0,
|
99 |
+
"torch_dtype": None,
|
100 |
+
"torchscript": False,
|
101 |
+
"transformers_version": "4.16.0.dev0",
|
102 |
+
"use_bfloat16": False,
|
103 |
+
"vocab_size": 49408,
|
104 |
+
# },
|
105 |
+
# "text_config_dict": {
|
106 |
+
"hidden_size": 768,
|
107 |
+
"intermediate_size": 3072,
|
108 |
+
"num_attention_heads": 12,
|
109 |
+
"num_hidden_layers": 12,
|
110 |
+
"projection_dim": 768,
|
111 |
+
# },
|
112 |
+
# "torch_dtype": "float32",
|
113 |
+
# "transformers_version": null
|
114 |
+
}
|
115 |
+
|
116 |
+
LLAMA_CONFIG = {
|
117 |
+
"architectures": ["LlamaForCausalLM"],
|
118 |
+
"attention_bias": False,
|
119 |
+
"attention_dropout": 0.0,
|
120 |
+
"bos_token_id": 128000,
|
121 |
+
"eos_token_id": 128001,
|
122 |
+
"head_dim": 128,
|
123 |
+
"hidden_act": "silu",
|
124 |
+
"hidden_size": 4096,
|
125 |
+
"initializer_range": 0.02,
|
126 |
+
"intermediate_size": 14336,
|
127 |
+
"max_position_embeddings": 8192,
|
128 |
+
"mlp_bias": False,
|
129 |
+
"model_type": "llama",
|
130 |
+
"num_attention_heads": 32,
|
131 |
+
"num_hidden_layers": 32,
|
132 |
+
"num_key_value_heads": 8,
|
133 |
+
"pretraining_tp": 1,
|
134 |
+
"rms_norm_eps": 1e-05,
|
135 |
+
"rope_scaling": None,
|
136 |
+
"rope_theta": 500000.0,
|
137 |
+
"tie_word_embeddings": False,
|
138 |
+
"torch_dtype": "float16",
|
139 |
+
"transformers_version": "4.46.3",
|
140 |
+
"use_cache": True,
|
141 |
+
"vocab_size": 128320,
|
142 |
+
}
|
143 |
+
|
144 |
+
# When using decoder-only models, we must provide a prompt template to instruct the text encoder
|
145 |
+
# on how to generate the text.
|
146 |
+
# --------------------------------------------------------------------
|
147 |
+
PROMPT_TEMPLATE_ENCODE = (
|
148 |
+
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
|
149 |
+
"quantity, text, spatial relationships of the objects and background:<|eot_id|>"
|
150 |
+
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
|
151 |
+
)
|
152 |
+
PROMPT_TEMPLATE_ENCODE_VIDEO = (
|
153 |
+
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
|
154 |
+
"1. The main content and theme of the video."
|
155 |
+
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
|
156 |
+
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
|
157 |
+
"4. background environment, light, style and atmosphere."
|
158 |
+
"5. camera angles, movements, and transitions used in the video:<|eot_id|>"
|
159 |
+
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
|
160 |
+
)
|
161 |
+
|
162 |
+
NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
|
163 |
+
|
164 |
+
PROMPT_TEMPLATE = {
|
165 |
+
"dit-llm-encode": {
|
166 |
+
"template": PROMPT_TEMPLATE_ENCODE,
|
167 |
+
"crop_start": 36,
|
168 |
+
},
|
169 |
+
"dit-llm-encode-video": {
|
170 |
+
"template": PROMPT_TEMPLATE_ENCODE_VIDEO,
|
171 |
+
"crop_start": 95,
|
172 |
+
},
|
173 |
+
}
|
174 |
+
|
175 |
+
|
176 |
+
def use_default(value, default):
|
177 |
+
return value if value is not None else default
|
178 |
+
|
179 |
+
|
180 |
+
def load_clip_l(text_encoder_path: str, dtype: Optional[Union[str, torch.dtype]] = None):
|
181 |
+
if os.path.isdir(text_encoder_path):
|
182 |
+
# load from directory, configs are in the directory
|
183 |
+
text_encoder = CLIPTextModel.from_pretrained(text_encoder_path, torch_dtype=dtype)
|
184 |
+
else:
|
185 |
+
# load from file, we create the model with the appropriate config
|
186 |
+
config = CLIPConfig(**CLIP_CONFIG)
|
187 |
+
with init_empty_weights():
|
188 |
+
text_encoder = CLIPTextModel._from_config(config, torch_dtype=dtype)
|
189 |
+
|
190 |
+
state_dict = load_file(text_encoder_path)
|
191 |
+
|
192 |
+
text_encoder.load_state_dict(state_dict, strict=True, assign=True)
|
193 |
+
# if dtype is not None:
|
194 |
+
# text_encoder.to(dtype=dtype)
|
195 |
+
|
196 |
+
return text_encoder
|
197 |
+
|
198 |
+
|
199 |
+
def load_clip_l_tokenizer(tokenizer_path: str):
|
200 |
+
if os.path.isdir(tokenizer_path):
|
201 |
+
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path, max_length=77)
|
202 |
+
else:
|
203 |
+
# load from Hugging Face
|
204 |
+
logger.info(f"Loading tokenizer from Hugging Face: {CLIP_L_HUGGINGFACE_MODEL_ID}")
|
205 |
+
tokenizer = CLIPTokenizer.from_pretrained(CLIP_L_HUGGINGFACE_MODEL_ID, max_length=77)
|
206 |
+
|
207 |
+
return tokenizer
|
208 |
+
|
209 |
+
|
210 |
+
def load_llm(text_encoder_path: str, dtype: Optional[Union[str, torch.dtype]] = None):
|
211 |
+
if os.path.isdir(text_encoder_path):
|
212 |
+
# load from directory, configs are in the directory
|
213 |
+
text_encoder = AutoModel.from_pretrained(text_encoder_path, low_cpu_mem_usage=True, torch_dtype=dtype)
|
214 |
+
else:
|
215 |
+
# load from file, we create the model with the appropriate config
|
216 |
+
config = LlamaConfig(**LLAMA_CONFIG)
|
217 |
+
with init_empty_weights():
|
218 |
+
text_encoder = LlamaForCausalLM._from_config(config, torch_dtype=dtype)
|
219 |
+
|
220 |
+
state_dict = load_file(text_encoder_path)
|
221 |
+
|
222 |
+
# support weights from ComfyUI
|
223 |
+
if "tokenizer" in state_dict:
|
224 |
+
state_dict.pop("tokenizer")
|
225 |
+
|
226 |
+
text_encoder.load_state_dict(state_dict, strict=True, assign=True)
|
227 |
+
|
228 |
+
return text_encoder
|
229 |
+
|
230 |
+
|
231 |
+
def load_llm_tokenizer(tokenizer_path: str, padding_side="right"):
|
232 |
+
if os.path.isdir(tokenizer_path):
|
233 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
234 |
+
else:
|
235 |
+
# load from Hugging Face
|
236 |
+
logger.info(f"Loading tokenizer from Hugging Face: {LLAVA_HUGGINGFACE_MODEL_ID}")
|
237 |
+
tokenizer = AutoTokenizer.from_pretrained(LLAVA_HUGGINGFACE_MODEL_ID, padding_side=padding_side)
|
238 |
+
|
239 |
+
return tokenizer
|
240 |
+
|
241 |
+
|
242 |
+
def load_text_encoder(
|
243 |
+
text_encoder_type: str,
|
244 |
+
text_encoder_path: str,
|
245 |
+
text_encoder_dtype: Optional[Union[str, torch.dtype]] = None,
|
246 |
+
):
|
247 |
+
logger.info(f"Loading text encoder model ({text_encoder_type}) from: {text_encoder_path}")
|
248 |
+
|
249 |
+
# reduce peak memory usage by specifying the dtype of the model
|
250 |
+
dtype = text_encoder_dtype
|
251 |
+
if text_encoder_type == "clipL":
|
252 |
+
text_encoder = load_clip_l(text_encoder_path, dtype=dtype)
|
253 |
+
text_encoder.final_layer_norm = text_encoder.text_model.final_layer_norm
|
254 |
+
elif text_encoder_type == "llm":
|
255 |
+
text_encoder = load_llm(text_encoder_path, dtype=dtype)
|
256 |
+
if hasattr(text_encoder, "norm"):
|
257 |
+
text_encoder.final_layer_norm = text_encoder.norm # by from_pretrained
|
258 |
+
else:
|
259 |
+
text_encoder.final_layer_norm = text_encoder.model.norm # by _from_config
|
260 |
+
else:
|
261 |
+
raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
|
262 |
+
# from_pretrained will ensure that the model is in eval mode.
|
263 |
+
|
264 |
+
if dtype is not None:
|
265 |
+
text_encoder = text_encoder.to(dtype=dtype)
|
266 |
+
|
267 |
+
text_encoder.requires_grad_(False)
|
268 |
+
|
269 |
+
logger.info(f"Text encoder to dtype: {text_encoder.dtype}")
|
270 |
+
return text_encoder, text_encoder_path
|
271 |
+
|
272 |
+
|
273 |
+
def load_tokenizer(tokenizer_type, tokenizer_path=None, padding_side="right"):
|
274 |
+
logger.info(f"Loading tokenizer ({tokenizer_type}) from: {tokenizer_path}")
|
275 |
+
|
276 |
+
if tokenizer_type == "clipL":
|
277 |
+
tokenizer = load_clip_l_tokenizer(tokenizer_path)
|
278 |
+
elif tokenizer_type == "llm":
|
279 |
+
tokenizer = load_llm_tokenizer(tokenizer_path, padding_side=padding_side)
|
280 |
+
else:
|
281 |
+
raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}")
|
282 |
+
|
283 |
+
return tokenizer, tokenizer_path
|
284 |
+
|
285 |
+
|
286 |
+
@dataclass
|
287 |
+
class TextEncoderModelOutput(ModelOutput):
|
288 |
+
"""
|
289 |
+
Base class for model's outputs that also contains a pooling of the last hidden states.
|
290 |
+
|
291 |
+
Args:
|
292 |
+
hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
293 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
294 |
+
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
295 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
296 |
+
hidden_states_list (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed):
|
297 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
298 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
299 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
300 |
+
text_outputs (`list`, *optional*, returned when `return_texts=True` is passed):
|
301 |
+
List of decoded texts.
|
302 |
+
"""
|
303 |
+
|
304 |
+
hidden_state: torch.FloatTensor = None
|
305 |
+
attention_mask: Optional[torch.LongTensor] = None
|
306 |
+
hidden_states_list: Optional[Tuple[torch.FloatTensor, ...]] = None
|
307 |
+
text_outputs: Optional[list] = None
|
308 |
+
|
309 |
+
|
310 |
+
class TextEncoder(nn.Module):
|
311 |
+
def __init__(
|
312 |
+
self,
|
313 |
+
text_encoder_type: str,
|
314 |
+
max_length: int,
|
315 |
+
text_encoder_dtype: Optional[Union[str, torch.dtype]] = None,
|
316 |
+
text_encoder_path: Optional[str] = None,
|
317 |
+
tokenizer_type: Optional[str] = None,
|
318 |
+
tokenizer_path: Optional[str] = None,
|
319 |
+
output_key: Optional[str] = None,
|
320 |
+
use_attention_mask: bool = True,
|
321 |
+
input_max_length: Optional[int] = None,
|
322 |
+
prompt_template: Optional[dict] = None,
|
323 |
+
prompt_template_video: Optional[dict] = None,
|
324 |
+
hidden_state_skip_layer: Optional[int] = None,
|
325 |
+
apply_final_norm: bool = False,
|
326 |
+
reproduce: bool = False,
|
327 |
+
):
|
328 |
+
super().__init__()
|
329 |
+
self.text_encoder_type = text_encoder_type
|
330 |
+
self.max_length = max_length
|
331 |
+
# self.precision = text_encoder_precision
|
332 |
+
self.model_path = text_encoder_path
|
333 |
+
self.tokenizer_type = tokenizer_type if tokenizer_type is not None else text_encoder_type
|
334 |
+
self.tokenizer_path = tokenizer_path if tokenizer_path is not None else text_encoder_path
|
335 |
+
self.use_attention_mask = use_attention_mask
|
336 |
+
if prompt_template_video is not None:
|
337 |
+
assert use_attention_mask is True, "Attention mask is True required when training videos."
|
338 |
+
self.input_max_length = input_max_length if input_max_length is not None else max_length
|
339 |
+
self.prompt_template = prompt_template
|
340 |
+
self.prompt_template_video = prompt_template_video
|
341 |
+
self.hidden_state_skip_layer = hidden_state_skip_layer
|
342 |
+
self.apply_final_norm = apply_final_norm
|
343 |
+
self.reproduce = reproduce
|
344 |
+
|
345 |
+
self.use_template = self.prompt_template is not None
|
346 |
+
if self.use_template:
|
347 |
+
assert (
|
348 |
+
isinstance(self.prompt_template, dict) and "template" in self.prompt_template
|
349 |
+
), f"`prompt_template` must be a dictionary with a key 'template', got {self.prompt_template}"
|
350 |
+
assert "{}" in str(self.prompt_template["template"]), (
|
351 |
+
"`prompt_template['template']` must contain a placeholder `{}` for the input text, "
|
352 |
+
f"got {self.prompt_template['template']}"
|
353 |
+
)
|
354 |
+
|
355 |
+
self.use_video_template = self.prompt_template_video is not None
|
356 |
+
if self.use_video_template:
|
357 |
+
if self.prompt_template_video is not None:
|
358 |
+
assert (
|
359 |
+
isinstance(self.prompt_template_video, dict) and "template" in self.prompt_template_video
|
360 |
+
), f"`prompt_template_video` must be a dictionary with a key 'template', got {self.prompt_template_video}"
|
361 |
+
assert "{}" in str(self.prompt_template_video["template"]), (
|
362 |
+
"`prompt_template_video['template']` must contain a placeholder `{}` for the input text, "
|
363 |
+
f"got {self.prompt_template_video['template']}"
|
364 |
+
)
|
365 |
+
|
366 |
+
if "t5" in text_encoder_type:
|
367 |
+
self.output_key = output_key or "last_hidden_state"
|
368 |
+
elif "clip" in text_encoder_type:
|
369 |
+
self.output_key = output_key or "pooler_output"
|
370 |
+
elif "llm" in text_encoder_type or "glm" in text_encoder_type:
|
371 |
+
self.output_key = output_key or "last_hidden_state"
|
372 |
+
else:
|
373 |
+
raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
|
374 |
+
|
375 |
+
self.model, self.model_path = load_text_encoder(
|
376 |
+
text_encoder_type=self.text_encoder_type, text_encoder_path=self.model_path, text_encoder_dtype=text_encoder_dtype
|
377 |
+
)
|
378 |
+
self.dtype = self.model.dtype
|
379 |
+
|
380 |
+
self.tokenizer, self.tokenizer_path = load_tokenizer(
|
381 |
+
tokenizer_type=self.tokenizer_type, tokenizer_path=self.tokenizer_path, padding_side="right"
|
382 |
+
)
|
383 |
+
|
384 |
+
def __repr__(self):
|
385 |
+
return f"{self.text_encoder_type} ({self.precision} - {self.model_path})"
|
386 |
+
|
387 |
+
@property
|
388 |
+
def device(self):
|
389 |
+
return self.model.device
|
390 |
+
|
391 |
+
@staticmethod
|
392 |
+
def apply_text_to_template(text, template, prevent_empty_text=True):
|
393 |
+
"""
|
394 |
+
Apply text to template.
|
395 |
+
|
396 |
+
Args:
|
397 |
+
text (str): Input text.
|
398 |
+
template (str or list): Template string or list of chat conversation.
|
399 |
+
prevent_empty_text (bool): If Ture, we will prevent the user text from being empty
|
400 |
+
by adding a space. Defaults to True.
|
401 |
+
"""
|
402 |
+
if isinstance(template, str):
|
403 |
+
# Will send string to tokenizer. Used for llm
|
404 |
+
return template.format(text)
|
405 |
+
else:
|
406 |
+
raise TypeError(f"Unsupported template type: {type(template)}")
|
407 |
+
|
408 |
+
def text2tokens(self, text, data_type="image"):
|
409 |
+
"""
|
410 |
+
Tokenize the input text.
|
411 |
+
|
412 |
+
Args:
|
413 |
+
text (str or list): Input text.
|
414 |
+
"""
|
415 |
+
tokenize_input_type = "str"
|
416 |
+
if self.use_template:
|
417 |
+
if data_type == "image":
|
418 |
+
prompt_template = self.prompt_template["template"]
|
419 |
+
elif data_type == "video":
|
420 |
+
prompt_template = self.prompt_template_video["template"]
|
421 |
+
else:
|
422 |
+
raise ValueError(f"Unsupported data type: {data_type}")
|
423 |
+
if isinstance(text, (list, tuple)):
|
424 |
+
text = [self.apply_text_to_template(one_text, prompt_template) for one_text in text]
|
425 |
+
if isinstance(text[0], list):
|
426 |
+
tokenize_input_type = "list"
|
427 |
+
elif isinstance(text, str):
|
428 |
+
text = self.apply_text_to_template(text, prompt_template)
|
429 |
+
if isinstance(text, list):
|
430 |
+
tokenize_input_type = "list"
|
431 |
+
else:
|
432 |
+
raise TypeError(f"Unsupported text type: {type(text)}")
|
433 |
+
|
434 |
+
kwargs = dict(
|
435 |
+
truncation=True,
|
436 |
+
max_length=self.max_length,
|
437 |
+
padding="max_length",
|
438 |
+
return_tensors="pt",
|
439 |
+
)
|
440 |
+
if tokenize_input_type == "str":
|
441 |
+
return self.tokenizer(
|
442 |
+
text,
|
443 |
+
return_length=False,
|
444 |
+
return_overflowing_tokens=False,
|
445 |
+
return_attention_mask=True,
|
446 |
+
**kwargs,
|
447 |
+
)
|
448 |
+
elif tokenize_input_type == "list":
|
449 |
+
return self.tokenizer.apply_chat_template(
|
450 |
+
text,
|
451 |
+
add_generation_prompt=True,
|
452 |
+
tokenize=True,
|
453 |
+
return_dict=True,
|
454 |
+
**kwargs,
|
455 |
+
)
|
456 |
+
else:
|
457 |
+
raise ValueError(f"Unsupported tokenize_input_type: {tokenize_input_type}")
|
458 |
+
|
459 |
+
def encode(
|
460 |
+
self,
|
461 |
+
batch_encoding,
|
462 |
+
use_attention_mask=None,
|
463 |
+
output_hidden_states=False,
|
464 |
+
do_sample=None,
|
465 |
+
hidden_state_skip_layer=None,
|
466 |
+
return_texts=False,
|
467 |
+
data_type="image",
|
468 |
+
device=None,
|
469 |
+
):
|
470 |
+
"""
|
471 |
+
Args:
|
472 |
+
batch_encoding (dict): Batch encoding from tokenizer.
|
473 |
+
use_attention_mask (bool): Whether to use attention mask. If None, use self.use_attention_mask.
|
474 |
+
Defaults to None.
|
475 |
+
output_hidden_states (bool): Whether to output hidden states. If False, return the value of
|
476 |
+
self.output_key. If True, return the entire output. If set self.hidden_state_skip_layer,
|
477 |
+
output_hidden_states will be set True. Defaults to False.
|
478 |
+
do_sample (bool): Whether to sample from the model. Used for Decoder-Only LLMs. Defaults to None.
|
479 |
+
When self.produce is False, do_sample is set to True by default.
|
480 |
+
hidden_state_skip_layer (int): Number of hidden states to hidden_state_skip_layer. 0 means the last layer.
|
481 |
+
If None, self.output_key will be used. Defaults to None.
|
482 |
+
return_texts (bool): Whether to return the decoded texts. Defaults to False.
|
483 |
+
"""
|
484 |
+
device = self.model.device if device is None else device
|
485 |
+
use_attention_mask = use_default(use_attention_mask, self.use_attention_mask)
|
486 |
+
hidden_state_skip_layer = use_default(hidden_state_skip_layer, self.hidden_state_skip_layer)
|
487 |
+
do_sample = use_default(do_sample, not self.reproduce)
|
488 |
+
attention_mask = batch_encoding["attention_mask"].to(device) if use_attention_mask else None
|
489 |
+
outputs = self.model(
|
490 |
+
input_ids=batch_encoding["input_ids"].to(device),
|
491 |
+
attention_mask=attention_mask,
|
492 |
+
output_hidden_states=output_hidden_states or hidden_state_skip_layer is not None,
|
493 |
+
)
|
494 |
+
if hidden_state_skip_layer is not None:
|
495 |
+
last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
|
496 |
+
# Real last hidden state already has layer norm applied. So here we only apply it
|
497 |
+
# for intermediate layers.
|
498 |
+
if hidden_state_skip_layer > 0 and self.apply_final_norm:
|
499 |
+
last_hidden_state = self.model.final_layer_norm(last_hidden_state)
|
500 |
+
else:
|
501 |
+
last_hidden_state = outputs[self.output_key]
|
502 |
+
|
503 |
+
# Remove hidden states of instruction tokens, only keep prompt tokens.
|
504 |
+
if self.use_template:
|
505 |
+
if data_type == "image":
|
506 |
+
crop_start = self.prompt_template.get("crop_start", -1)
|
507 |
+
elif data_type == "video":
|
508 |
+
crop_start = self.prompt_template_video.get("crop_start", -1)
|
509 |
+
else:
|
510 |
+
raise ValueError(f"Unsupported data type: {data_type}")
|
511 |
+
if crop_start > 0:
|
512 |
+
last_hidden_state = last_hidden_state[:, crop_start:]
|
513 |
+
attention_mask = attention_mask[:, crop_start:] if use_attention_mask else None
|
514 |
+
|
515 |
+
if output_hidden_states:
|
516 |
+
return TextEncoderModelOutput(last_hidden_state, attention_mask, outputs.hidden_states)
|
517 |
+
return TextEncoderModelOutput(last_hidden_state, attention_mask)
|
518 |
+
|
519 |
+
def forward(
|
520 |
+
self,
|
521 |
+
text,
|
522 |
+
use_attention_mask=None,
|
523 |
+
output_hidden_states=False,
|
524 |
+
do_sample=False,
|
525 |
+
hidden_state_skip_layer=None,
|
526 |
+
return_texts=False,
|
527 |
+
):
|
528 |
+
batch_encoding = self.text2tokens(text)
|
529 |
+
return self.encode(
|
530 |
+
batch_encoding,
|
531 |
+
use_attention_mask=use_attention_mask,
|
532 |
+
output_hidden_states=output_hidden_states,
|
533 |
+
do_sample=do_sample,
|
534 |
+
hidden_state_skip_layer=hidden_state_skip_layer,
|
535 |
+
return_texts=return_texts,
|
536 |
+
)
|
537 |
+
|
538 |
+
|
539 |
+
# region HunyanVideo architecture
|
540 |
+
|
541 |
+
|
542 |
+
def load_text_encoder_1(
|
543 |
+
text_encoder_dir: str, device: torch.device, fp8_llm: bool, dtype: Optional[Union[str, torch.dtype]] = None
|
544 |
+
) -> TextEncoder:
|
545 |
+
text_encoder_dtype = dtype or torch.float16
|
546 |
+
text_encoder_type = "llm"
|
547 |
+
text_len = 256
|
548 |
+
hidden_state_skip_layer = 2
|
549 |
+
apply_final_norm = False
|
550 |
+
reproduce = False
|
551 |
+
|
552 |
+
prompt_template = "dit-llm-encode"
|
553 |
+
prompt_template = PROMPT_TEMPLATE[prompt_template]
|
554 |
+
prompt_template_video = "dit-llm-encode-video"
|
555 |
+
prompt_template_video = PROMPT_TEMPLATE[prompt_template_video]
|
556 |
+
|
557 |
+
crop_start = prompt_template_video["crop_start"] # .get("crop_start", 0)
|
558 |
+
max_length = text_len + crop_start
|
559 |
+
|
560 |
+
text_encoder_1 = TextEncoder(
|
561 |
+
text_encoder_type=text_encoder_type,
|
562 |
+
max_length=max_length,
|
563 |
+
text_encoder_dtype=text_encoder_dtype,
|
564 |
+
text_encoder_path=text_encoder_dir,
|
565 |
+
tokenizer_type=text_encoder_type,
|
566 |
+
prompt_template=prompt_template,
|
567 |
+
prompt_template_video=prompt_template_video,
|
568 |
+
hidden_state_skip_layer=hidden_state_skip_layer,
|
569 |
+
apply_final_norm=apply_final_norm,
|
570 |
+
reproduce=reproduce,
|
571 |
+
)
|
572 |
+
text_encoder_1.eval()
|
573 |
+
|
574 |
+
if fp8_llm:
|
575 |
+
org_dtype = text_encoder_1.dtype
|
576 |
+
logger.info(f"Moving and casting text encoder to {device} and torch.float8_e4m3fn")
|
577 |
+
text_encoder_1.to(device=device, dtype=torch.float8_e4m3fn)
|
578 |
+
|
579 |
+
# prepare LLM for fp8
|
580 |
+
def prepare_fp8(llama_model: LlamaModel, target_dtype):
|
581 |
+
def forward_hook(module):
|
582 |
+
def forward(hidden_states):
|
583 |
+
input_dtype = hidden_states.dtype
|
584 |
+
hidden_states = hidden_states.to(torch.float32)
|
585 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
586 |
+
hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon)
|
587 |
+
return module.weight.to(input_dtype) * hidden_states.to(input_dtype)
|
588 |
+
|
589 |
+
return forward
|
590 |
+
|
591 |
+
for module in llama_model.modules():
|
592 |
+
if module.__class__.__name__ in ["Embedding"]:
|
593 |
+
# print("set", module.__class__.__name__, "to", target_dtype)
|
594 |
+
module.to(target_dtype)
|
595 |
+
if module.__class__.__name__ in ["LlamaRMSNorm"]:
|
596 |
+
# print("set", module.__class__.__name__, "hooks")
|
597 |
+
module.forward = forward_hook(module)
|
598 |
+
|
599 |
+
prepare_fp8(text_encoder_1.model, org_dtype)
|
600 |
+
else:
|
601 |
+
text_encoder_1.to(device=device)
|
602 |
+
|
603 |
+
return text_encoder_1
|
604 |
+
|
605 |
+
|
606 |
+
def load_text_encoder_2(
|
607 |
+
text_encoder_dir: str, device: torch.device, dtype: Optional[Union[str, torch.dtype]] = None
|
608 |
+
) -> TextEncoder:
|
609 |
+
text_encoder_dtype = dtype or torch.float16
|
610 |
+
reproduce = False
|
611 |
+
|
612 |
+
text_encoder_2_type = "clipL"
|
613 |
+
text_len_2 = 77
|
614 |
+
|
615 |
+
text_encoder_2 = TextEncoder(
|
616 |
+
text_encoder_type=text_encoder_2_type,
|
617 |
+
max_length=text_len_2,
|
618 |
+
text_encoder_dtype=text_encoder_dtype,
|
619 |
+
text_encoder_path=text_encoder_dir,
|
620 |
+
tokenizer_type=text_encoder_2_type,
|
621 |
+
reproduce=reproduce,
|
622 |
+
)
|
623 |
+
text_encoder_2.eval()
|
624 |
+
|
625 |
+
text_encoder_2.to(device=device)
|
626 |
+
|
627 |
+
return text_encoder_2
|
628 |
+
|
629 |
+
|
630 |
+
# endregion
|
631 |
+
|
632 |
+
|
633 |
+
if __name__ == "__main__":
|
634 |
+
import argparse
|
635 |
+
from utils.model_utils import str_to_dtype
|
636 |
+
|
637 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
638 |
+
|
639 |
+
parser = argparse.ArgumentParser()
|
640 |
+
parser.add_argument("type", type=str, help="Text Encoder type")
|
641 |
+
parser.add_argument("path1", type=str, help="Text Encoder directory or file 1")
|
642 |
+
parser.add_argument("path2", type=str, help="Text Encoder directory or file 2")
|
643 |
+
parser.add_argument("--dtype", type=str, default=None, help="Data type for Text Encoder")
|
644 |
+
args = parser.parse_args()
|
645 |
+
|
646 |
+
dtype = str_to_dtype(args.dtype) if args.dtype is not None else torch.float16
|
647 |
+
|
648 |
+
"""
|
649 |
+
if args.type == "clipL":
|
650 |
+
text_encoder_1st = load_clip_l(args.path1, dtype=dtype)
|
651 |
+
tokenizer_1st = load_clip_l_tokenizer(args.path1)
|
652 |
+
text_encoder_2nd = load_clip_l(args.path2, dtype=dtype)
|
653 |
+
tokenizer_2nd = load_clip_l_tokenizer(args.path2)
|
654 |
+
elif args.type == "llm":
|
655 |
+
text_encoder_1st = load_llm(args.path1, dtype=dtype)
|
656 |
+
tokenizer_1st = load_llm_tokenizer(args.path1)
|
657 |
+
text_encoder_2nd = load_llm(args.path2, dtype=dtype)
|
658 |
+
tokenizer_2nd = load_llm_tokenizer(args.path2)
|
659 |
+
|
660 |
+
print(f"1st Text Encoder dtype: {text_encoder_1st.dtype}")
|
661 |
+
print(f"2nd Text Encoder dtype: {text_encoder_2nd.dtype}")
|
662 |
+
|
663 |
+
text_encoder_1st.to(device=device)
|
664 |
+
text_encoder_2nd.to(device=device)
|
665 |
+
|
666 |
+
test_text = "A cat sitting on a table"
|
667 |
+
token_ids_1st = tokenizer_1st(test_text, return_tensors="pt")["input_ids"]
|
668 |
+
token_ids_2nd = tokenizer_2nd(test_text, return_tensors="pt")["input_ids"]
|
669 |
+
assert torch.allclose(token_ids_1st, token_ids_2nd)
|
670 |
+
print(f"Token IDs are the same: {token_ids_1st}")
|
671 |
+
|
672 |
+
with torch.no_grad():
|
673 |
+
text_encoder_1st_output = text_encoder_1st(token_ids_1st.to(device), output_hidden_states=True)
|
674 |
+
text_encoder_2nd_output = text_encoder_2nd(token_ids_2nd.to(device), output_hidden_states=True)
|
675 |
+
print(f"1st Text Encoder output keys: {text_encoder_1st_output.keys()}")
|
676 |
+
print(f"2nd Text Encoder output keys: {text_encoder_2nd_output.keys()}")
|
677 |
+
for key in text_encoder_1st_output:
|
678 |
+
print(f"Checking output: {key}")
|
679 |
+
assert key in text_encoder_2nd_output, f"Key {key} not in 2nd Text Encoder output"
|
680 |
+
assert torch.allclose(text_encoder_1st_output[key], text_encoder_2nd_output[key])
|
681 |
+
print(f"Outputs are the same: {key}")
|
682 |
+
print("All outputs are the same.")
|
683 |
+
"""
|
684 |
+
|
685 |
+
if args.type == "clipL":
|
686 |
+
text_encoder_1st = load_text_encoder_2(args.path1, device, dtype)
|
687 |
+
text_encoder_2nd = load_text_encoder_2(args.path2, device, dtype)
|
688 |
+
elif args.type == "llm":
|
689 |
+
text_encoder_1st = load_text_encoder_1(args.path1, device, False, dtype)
|
690 |
+
text_encoder_2nd = load_text_encoder_1(args.path2, device, False, dtype)
|
691 |
+
print(f"1st Text Encoder dtype: {text_encoder_1st.dtype}")
|
692 |
+
print(f"2nd Text Encoder dtype: {text_encoder_2nd.dtype}")
|
693 |
+
|
694 |
+
prompt = "A cat sitting on a table"
|
695 |
+
data_type = "video" # video only, image is not supported
|
696 |
+
text_inputs_1st = text_encoder_1st.text2tokens(prompt, data_type=data_type)
|
697 |
+
text_inputs_2nd = text_encoder_2nd.text2tokens(prompt, data_type=data_type)
|
698 |
+
print(text_inputs_1st)
|
699 |
+
assert torch.allclose(text_inputs_1st["input_ids"], text_inputs_2nd["input_ids"])
|
700 |
+
|
701 |
+
with torch.no_grad():
|
702 |
+
prompt_outputs_1st = text_encoder_1st.encode(text_inputs_1st, data_type=data_type)
|
703 |
+
prompt_outputs_2nd = text_encoder_2nd.encode(text_inputs_1st, data_type=data_type)
|
704 |
+
|
705 |
+
# prompt_outputs.hidden_state, prompt_outputs.attention_mask
|
706 |
+
assert torch.allclose(prompt_outputs_1st.hidden_state, prompt_outputs_2nd.hidden_state)
|
707 |
+
print("Hidden states are the same.")
|
708 |
+
assert torch.allclose(prompt_outputs_1st.attention_mask, prompt_outputs_2nd.attention_mask)
|
709 |
+
print("Attention masks are the same.")
|
710 |
+
print("All outputs are the same.")
|
hunyuan_model/token_refiner.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
from einops import rearrange
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torch.utils.checkpoint import checkpoint
|
7 |
+
|
8 |
+
from .activation_layers import get_activation_layer
|
9 |
+
from .attention import attention
|
10 |
+
from .norm_layers import get_norm_layer
|
11 |
+
from .embed_layers import TimestepEmbedder, TextProjection
|
12 |
+
from .mlp_layers import MLP
|
13 |
+
from .modulate_layers import modulate, apply_gate
|
14 |
+
|
15 |
+
|
16 |
+
class IndividualTokenRefinerBlock(nn.Module):
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
hidden_size,
|
20 |
+
heads_num,
|
21 |
+
mlp_width_ratio: str = 4.0,
|
22 |
+
mlp_drop_rate: float = 0.0,
|
23 |
+
act_type: str = "silu",
|
24 |
+
qk_norm: bool = False,
|
25 |
+
qk_norm_type: str = "layer",
|
26 |
+
qkv_bias: bool = True,
|
27 |
+
dtype: Optional[torch.dtype] = None,
|
28 |
+
device: Optional[torch.device] = None,
|
29 |
+
):
|
30 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
31 |
+
super().__init__()
|
32 |
+
self.heads_num = heads_num
|
33 |
+
head_dim = hidden_size // heads_num
|
34 |
+
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
|
35 |
+
|
36 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
37 |
+
self.self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
|
38 |
+
qk_norm_layer = get_norm_layer(qk_norm_type)
|
39 |
+
self.self_attn_q_norm = (
|
40 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
41 |
+
)
|
42 |
+
self.self_attn_k_norm = (
|
43 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
44 |
+
)
|
45 |
+
self.self_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
|
46 |
+
|
47 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
48 |
+
act_layer = get_activation_layer(act_type)
|
49 |
+
self.mlp = MLP(
|
50 |
+
in_channels=hidden_size,
|
51 |
+
hidden_channels=mlp_hidden_dim,
|
52 |
+
act_layer=act_layer,
|
53 |
+
drop=mlp_drop_rate,
|
54 |
+
**factory_kwargs,
|
55 |
+
)
|
56 |
+
|
57 |
+
self.adaLN_modulation = nn.Sequential(
|
58 |
+
act_layer(),
|
59 |
+
nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
|
60 |
+
)
|
61 |
+
# Zero-initialize the modulation
|
62 |
+
nn.init.zeros_(self.adaLN_modulation[1].weight)
|
63 |
+
nn.init.zeros_(self.adaLN_modulation[1].bias)
|
64 |
+
|
65 |
+
self.gradient_checkpointing = False
|
66 |
+
|
67 |
+
def enable_gradient_checkpointing(self):
|
68 |
+
self.gradient_checkpointing = True
|
69 |
+
|
70 |
+
def disable_gradient_checkpointing(self):
|
71 |
+
self.gradient_checkpointing = False
|
72 |
+
|
73 |
+
def _forward(
|
74 |
+
self,
|
75 |
+
x: torch.Tensor,
|
76 |
+
c: torch.Tensor, # timestep_aware_representations + context_aware_representations
|
77 |
+
attn_mask: torch.Tensor = None,
|
78 |
+
):
|
79 |
+
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
|
80 |
+
|
81 |
+
norm_x = self.norm1(x)
|
82 |
+
qkv = self.self_attn_qkv(norm_x)
|
83 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
84 |
+
# Apply QK-Norm if needed
|
85 |
+
q = self.self_attn_q_norm(q).to(v)
|
86 |
+
k = self.self_attn_k_norm(k).to(v)
|
87 |
+
|
88 |
+
# Self-Attention
|
89 |
+
attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
|
90 |
+
|
91 |
+
x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
|
92 |
+
|
93 |
+
# FFN Layer
|
94 |
+
x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
|
95 |
+
|
96 |
+
return x
|
97 |
+
|
98 |
+
def forward(self, *args, **kwargs):
|
99 |
+
if self.training and self.gradient_checkpointing:
|
100 |
+
return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
|
101 |
+
else:
|
102 |
+
return self._forward(*args, **kwargs)
|
103 |
+
|
104 |
+
|
105 |
+
class IndividualTokenRefiner(nn.Module):
|
106 |
+
def __init__(
|
107 |
+
self,
|
108 |
+
hidden_size,
|
109 |
+
heads_num,
|
110 |
+
depth,
|
111 |
+
mlp_width_ratio: float = 4.0,
|
112 |
+
mlp_drop_rate: float = 0.0,
|
113 |
+
act_type: str = "silu",
|
114 |
+
qk_norm: bool = False,
|
115 |
+
qk_norm_type: str = "layer",
|
116 |
+
qkv_bias: bool = True,
|
117 |
+
dtype: Optional[torch.dtype] = None,
|
118 |
+
device: Optional[torch.device] = None,
|
119 |
+
):
|
120 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
121 |
+
super().__init__()
|
122 |
+
self.blocks = nn.ModuleList(
|
123 |
+
[
|
124 |
+
IndividualTokenRefinerBlock(
|
125 |
+
hidden_size=hidden_size,
|
126 |
+
heads_num=heads_num,
|
127 |
+
mlp_width_ratio=mlp_width_ratio,
|
128 |
+
mlp_drop_rate=mlp_drop_rate,
|
129 |
+
act_type=act_type,
|
130 |
+
qk_norm=qk_norm,
|
131 |
+
qk_norm_type=qk_norm_type,
|
132 |
+
qkv_bias=qkv_bias,
|
133 |
+
**factory_kwargs,
|
134 |
+
)
|
135 |
+
for _ in range(depth)
|
136 |
+
]
|
137 |
+
)
|
138 |
+
|
139 |
+
def enable_gradient_checkpointing(self):
|
140 |
+
for block in self.blocks:
|
141 |
+
block.enable_gradient_checkpointing()
|
142 |
+
|
143 |
+
def disable_gradient_checkpointing(self):
|
144 |
+
for block in self.blocks:
|
145 |
+
block.disable_gradient_checkpointing()
|
146 |
+
|
147 |
+
def forward(
|
148 |
+
self,
|
149 |
+
x: torch.Tensor,
|
150 |
+
c: torch.LongTensor,
|
151 |
+
mask: Optional[torch.Tensor] = None,
|
152 |
+
):
|
153 |
+
self_attn_mask = None
|
154 |
+
if mask is not None:
|
155 |
+
batch_size = mask.shape[0]
|
156 |
+
seq_len = mask.shape[1]
|
157 |
+
mask = mask.to(x.device)
|
158 |
+
# batch_size x 1 x seq_len x seq_len
|
159 |
+
self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
|
160 |
+
# batch_size x 1 x seq_len x seq_len
|
161 |
+
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
|
162 |
+
# batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
|
163 |
+
self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
|
164 |
+
# avoids self-attention weight being NaN for padding tokens
|
165 |
+
self_attn_mask[:, :, :, 0] = True
|
166 |
+
|
167 |
+
for block in self.blocks:
|
168 |
+
x = block(x, c, self_attn_mask)
|
169 |
+
return x
|
170 |
+
|
171 |
+
|
172 |
+
class SingleTokenRefiner(nn.Module):
|
173 |
+
"""
|
174 |
+
A single token refiner block for llm text embedding refine.
|
175 |
+
"""
|
176 |
+
|
177 |
+
def __init__(
|
178 |
+
self,
|
179 |
+
in_channels,
|
180 |
+
hidden_size,
|
181 |
+
heads_num,
|
182 |
+
depth,
|
183 |
+
mlp_width_ratio: float = 4.0,
|
184 |
+
mlp_drop_rate: float = 0.0,
|
185 |
+
act_type: str = "silu",
|
186 |
+
qk_norm: bool = False,
|
187 |
+
qk_norm_type: str = "layer",
|
188 |
+
qkv_bias: bool = True,
|
189 |
+
attn_mode: str = "torch",
|
190 |
+
dtype: Optional[torch.dtype] = None,
|
191 |
+
device: Optional[torch.device] = None,
|
192 |
+
):
|
193 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
194 |
+
super().__init__()
|
195 |
+
self.attn_mode = attn_mode
|
196 |
+
assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner."
|
197 |
+
|
198 |
+
self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True, **factory_kwargs)
|
199 |
+
|
200 |
+
act_layer = get_activation_layer(act_type)
|
201 |
+
# Build timestep embedding layer
|
202 |
+
self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
|
203 |
+
# Build context embedding layer
|
204 |
+
self.c_embedder = TextProjection(in_channels, hidden_size, act_layer, **factory_kwargs)
|
205 |
+
|
206 |
+
self.individual_token_refiner = IndividualTokenRefiner(
|
207 |
+
hidden_size=hidden_size,
|
208 |
+
heads_num=heads_num,
|
209 |
+
depth=depth,
|
210 |
+
mlp_width_ratio=mlp_width_ratio,
|
211 |
+
mlp_drop_rate=mlp_drop_rate,
|
212 |
+
act_type=act_type,
|
213 |
+
qk_norm=qk_norm,
|
214 |
+
qk_norm_type=qk_norm_type,
|
215 |
+
qkv_bias=qkv_bias,
|
216 |
+
**factory_kwargs,
|
217 |
+
)
|
218 |
+
|
219 |
+
def enable_gradient_checkpointing(self):
|
220 |
+
self.individual_token_refiner.enable_gradient_checkpointing()
|
221 |
+
|
222 |
+
def disable_gradient_checkpointing(self):
|
223 |
+
self.individual_token_refiner.disable_gradient_checkpointing()
|
224 |
+
|
225 |
+
def forward(
|
226 |
+
self,
|
227 |
+
x: torch.Tensor,
|
228 |
+
t: torch.LongTensor,
|
229 |
+
mask: Optional[torch.LongTensor] = None,
|
230 |
+
):
|
231 |
+
timestep_aware_representations = self.t_embedder(t)
|
232 |
+
|
233 |
+
if mask is None:
|
234 |
+
context_aware_representations = x.mean(dim=1)
|
235 |
+
else:
|
236 |
+
mask_float = mask.float().unsqueeze(-1) # [b, s1, 1]
|
237 |
+
context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)
|
238 |
+
context_aware_representations = self.c_embedder(context_aware_representations)
|
239 |
+
c = timestep_aware_representations + context_aware_representations
|
240 |
+
|
241 |
+
x = self.input_embedder(x)
|
242 |
+
|
243 |
+
x = self.individual_token_refiner(x, c, mask)
|
244 |
+
|
245 |
+
return x
|
hunyuan_model/vae.py
ADDED
@@ -0,0 +1,446 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
import json
|
3 |
+
from typing import Optional, Tuple, Union
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
|
10 |
+
from diffusers.utils import BaseOutput, is_torch_version
|
11 |
+
from diffusers.utils.torch_utils import randn_tensor
|
12 |
+
from diffusers.models.attention_processor import SpatialNorm
|
13 |
+
from modules.unet_causal_3d_blocks import CausalConv3d, UNetMidBlockCausal3D, get_down_block3d, get_up_block3d
|
14 |
+
|
15 |
+
import logging
|
16 |
+
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
logging.basicConfig(level=logging.INFO)
|
19 |
+
|
20 |
+
|
21 |
+
SCALING_FACTOR = 0.476986
|
22 |
+
VAE_VER = "884-16c-hy" # We don't support other versions currently
|
23 |
+
|
24 |
+
|
25 |
+
def load_vae(
|
26 |
+
vae_type: str = "884-16c-hy",
|
27 |
+
vae_dtype: Optional[Union[str, torch.dtype]] = None,
|
28 |
+
sample_size: tuple = None,
|
29 |
+
vae_path: str = None,
|
30 |
+
device=None,
|
31 |
+
):
|
32 |
+
"""the fucntion to load the 3D VAE model
|
33 |
+
|
34 |
+
Args:
|
35 |
+
vae_type (str): the type of the 3D VAE model. Defaults to "884-16c-hy".
|
36 |
+
vae_precision (str, optional): the precision to load vae. Defaults to None.
|
37 |
+
sample_size (tuple, optional): the tiling size. Defaults to None.
|
38 |
+
vae_path (str, optional): the path to vae. Defaults to None.
|
39 |
+
logger (_type_, optional): logger. Defaults to None.
|
40 |
+
device (_type_, optional): device to load vae. Defaults to None.
|
41 |
+
"""
|
42 |
+
if vae_path is None:
|
43 |
+
vae_path = VAE_PATH[vae_type]
|
44 |
+
|
45 |
+
logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}")
|
46 |
+
|
47 |
+
# use fixed config for Hunyuan's VAE
|
48 |
+
CONFIG_JSON = """{
|
49 |
+
"_class_name": "AutoencoderKLCausal3D",
|
50 |
+
"_diffusers_version": "0.4.2",
|
51 |
+
"act_fn": "silu",
|
52 |
+
"block_out_channels": [
|
53 |
+
128,
|
54 |
+
256,
|
55 |
+
512,
|
56 |
+
512
|
57 |
+
],
|
58 |
+
"down_block_types": [
|
59 |
+
"DownEncoderBlockCausal3D",
|
60 |
+
"DownEncoderBlockCausal3D",
|
61 |
+
"DownEncoderBlockCausal3D",
|
62 |
+
"DownEncoderBlockCausal3D"
|
63 |
+
],
|
64 |
+
"in_channels": 3,
|
65 |
+
"latent_channels": 16,
|
66 |
+
"layers_per_block": 2,
|
67 |
+
"norm_num_groups": 32,
|
68 |
+
"out_channels": 3,
|
69 |
+
"sample_size": 256,
|
70 |
+
"sample_tsize": 64,
|
71 |
+
"up_block_types": [
|
72 |
+
"UpDecoderBlockCausal3D",
|
73 |
+
"UpDecoderBlockCausal3D",
|
74 |
+
"UpDecoderBlockCausal3D",
|
75 |
+
"UpDecoderBlockCausal3D"
|
76 |
+
],
|
77 |
+
"scaling_factor": 0.476986,
|
78 |
+
"time_compression_ratio": 4,
|
79 |
+
"mid_block_add_attention": true
|
80 |
+
}"""
|
81 |
+
|
82 |
+
# config = AutoencoderKLCausal3D.load_config(vae_path)
|
83 |
+
config = json.loads(CONFIG_JSON)
|
84 |
+
|
85 |
+
# import here to avoid circular import
|
86 |
+
from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D
|
87 |
+
|
88 |
+
if sample_size:
|
89 |
+
vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size)
|
90 |
+
else:
|
91 |
+
vae = AutoencoderKLCausal3D.from_config(config)
|
92 |
+
|
93 |
+
# vae_ckpt = Path(vae_path) / "pytorch_model.pt"
|
94 |
+
# assert vae_ckpt.exists(), f"VAE checkpoint not found: {vae_ckpt}"
|
95 |
+
|
96 |
+
if vae_path.endswith(".safetensors"):
|
97 |
+
from safetensors.torch import load_file
|
98 |
+
ckpt = load_file(vae_path)
|
99 |
+
else:
|
100 |
+
ckpt = torch.load(vae_path, map_location=vae.device, weights_only=True)
|
101 |
+
if "state_dict" in ckpt:
|
102 |
+
ckpt = ckpt["state_dict"]
|
103 |
+
if any(k.startswith("vae.") for k in ckpt.keys()):
|
104 |
+
ckpt = {k.replace("vae.", ""): v for k, v in ckpt.items() if k.startswith("vae.")}
|
105 |
+
vae.load_state_dict(ckpt)
|
106 |
+
|
107 |
+
spatial_compression_ratio = vae.config.spatial_compression_ratio
|
108 |
+
time_compression_ratio = vae.config.time_compression_ratio
|
109 |
+
|
110 |
+
if vae_dtype is not None:
|
111 |
+
vae = vae.to(vae_dtype)
|
112 |
+
|
113 |
+
vae.requires_grad_(False)
|
114 |
+
|
115 |
+
logger.info(f"VAE to dtype: {vae.dtype}")
|
116 |
+
|
117 |
+
if device is not None:
|
118 |
+
vae = vae.to(device)
|
119 |
+
|
120 |
+
vae.eval()
|
121 |
+
|
122 |
+
return vae, vae_path, spatial_compression_ratio, time_compression_ratio
|
123 |
+
|
124 |
+
|
125 |
+
@dataclass
|
126 |
+
class DecoderOutput(BaseOutput):
|
127 |
+
r"""
|
128 |
+
Output of decoding method.
|
129 |
+
|
130 |
+
Args:
|
131 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
132 |
+
The decoded output sample from the last layer of the model.
|
133 |
+
"""
|
134 |
+
|
135 |
+
sample: torch.FloatTensor
|
136 |
+
|
137 |
+
|
138 |
+
class EncoderCausal3D(nn.Module):
|
139 |
+
r"""
|
140 |
+
The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation.
|
141 |
+
"""
|
142 |
+
|
143 |
+
def __init__(
|
144 |
+
self,
|
145 |
+
in_channels: int = 3,
|
146 |
+
out_channels: int = 3,
|
147 |
+
down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
|
148 |
+
block_out_channels: Tuple[int, ...] = (64,),
|
149 |
+
layers_per_block: int = 2,
|
150 |
+
norm_num_groups: int = 32,
|
151 |
+
act_fn: str = "silu",
|
152 |
+
double_z: bool = True,
|
153 |
+
mid_block_add_attention=True,
|
154 |
+
time_compression_ratio: int = 4,
|
155 |
+
spatial_compression_ratio: int = 8,
|
156 |
+
):
|
157 |
+
super().__init__()
|
158 |
+
self.layers_per_block = layers_per_block
|
159 |
+
|
160 |
+
self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
|
161 |
+
self.mid_block = None
|
162 |
+
self.down_blocks = nn.ModuleList([])
|
163 |
+
|
164 |
+
# down
|
165 |
+
output_channel = block_out_channels[0]
|
166 |
+
for i, down_block_type in enumerate(down_block_types):
|
167 |
+
input_channel = output_channel
|
168 |
+
output_channel = block_out_channels[i]
|
169 |
+
is_final_block = i == len(block_out_channels) - 1
|
170 |
+
num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
|
171 |
+
num_time_downsample_layers = int(np.log2(time_compression_ratio))
|
172 |
+
|
173 |
+
if time_compression_ratio == 4:
|
174 |
+
add_spatial_downsample = bool(i < num_spatial_downsample_layers)
|
175 |
+
add_time_downsample = bool(i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block)
|
176 |
+
else:
|
177 |
+
raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
|
178 |
+
|
179 |
+
downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
|
180 |
+
downsample_stride_T = (2,) if add_time_downsample else (1,)
|
181 |
+
downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
|
182 |
+
down_block = get_down_block3d(
|
183 |
+
down_block_type,
|
184 |
+
num_layers=self.layers_per_block,
|
185 |
+
in_channels=input_channel,
|
186 |
+
out_channels=output_channel,
|
187 |
+
add_downsample=bool(add_spatial_downsample or add_time_downsample),
|
188 |
+
downsample_stride=downsample_stride,
|
189 |
+
resnet_eps=1e-6,
|
190 |
+
downsample_padding=0,
|
191 |
+
resnet_act_fn=act_fn,
|
192 |
+
resnet_groups=norm_num_groups,
|
193 |
+
attention_head_dim=output_channel,
|
194 |
+
temb_channels=None,
|
195 |
+
)
|
196 |
+
self.down_blocks.append(down_block)
|
197 |
+
|
198 |
+
# mid
|
199 |
+
self.mid_block = UNetMidBlockCausal3D(
|
200 |
+
in_channels=block_out_channels[-1],
|
201 |
+
resnet_eps=1e-6,
|
202 |
+
resnet_act_fn=act_fn,
|
203 |
+
output_scale_factor=1,
|
204 |
+
resnet_time_scale_shift="default",
|
205 |
+
attention_head_dim=block_out_channels[-1],
|
206 |
+
resnet_groups=norm_num_groups,
|
207 |
+
temb_channels=None,
|
208 |
+
add_attention=mid_block_add_attention,
|
209 |
+
)
|
210 |
+
|
211 |
+
# out
|
212 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
|
213 |
+
self.conv_act = nn.SiLU()
|
214 |
+
|
215 |
+
conv_out_channels = 2 * out_channels if double_z else out_channels
|
216 |
+
self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)
|
217 |
+
|
218 |
+
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
219 |
+
r"""The forward method of the `EncoderCausal3D` class."""
|
220 |
+
assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"
|
221 |
+
|
222 |
+
sample = self.conv_in(sample)
|
223 |
+
|
224 |
+
# down
|
225 |
+
for down_block in self.down_blocks:
|
226 |
+
sample = down_block(sample)
|
227 |
+
|
228 |
+
# middle
|
229 |
+
sample = self.mid_block(sample)
|
230 |
+
|
231 |
+
# post-process
|
232 |
+
sample = self.conv_norm_out(sample)
|
233 |
+
sample = self.conv_act(sample)
|
234 |
+
sample = self.conv_out(sample)
|
235 |
+
|
236 |
+
return sample
|
237 |
+
|
238 |
+
|
239 |
+
class DecoderCausal3D(nn.Module):
|
240 |
+
r"""
|
241 |
+
The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output sample.
|
242 |
+
"""
|
243 |
+
|
244 |
+
def __init__(
|
245 |
+
self,
|
246 |
+
in_channels: int = 3,
|
247 |
+
out_channels: int = 3,
|
248 |
+
up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
|
249 |
+
block_out_channels: Tuple[int, ...] = (64,),
|
250 |
+
layers_per_block: int = 2,
|
251 |
+
norm_num_groups: int = 32,
|
252 |
+
act_fn: str = "silu",
|
253 |
+
norm_type: str = "group", # group, spatial
|
254 |
+
mid_block_add_attention=True,
|
255 |
+
time_compression_ratio: int = 4,
|
256 |
+
spatial_compression_ratio: int = 8,
|
257 |
+
):
|
258 |
+
super().__init__()
|
259 |
+
self.layers_per_block = layers_per_block
|
260 |
+
|
261 |
+
self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
|
262 |
+
self.mid_block = None
|
263 |
+
self.up_blocks = nn.ModuleList([])
|
264 |
+
|
265 |
+
temb_channels = in_channels if norm_type == "spatial" else None
|
266 |
+
|
267 |
+
# mid
|
268 |
+
self.mid_block = UNetMidBlockCausal3D(
|
269 |
+
in_channels=block_out_channels[-1],
|
270 |
+
resnet_eps=1e-6,
|
271 |
+
resnet_act_fn=act_fn,
|
272 |
+
output_scale_factor=1,
|
273 |
+
resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
|
274 |
+
attention_head_dim=block_out_channels[-1],
|
275 |
+
resnet_groups=norm_num_groups,
|
276 |
+
temb_channels=temb_channels,
|
277 |
+
add_attention=mid_block_add_attention,
|
278 |
+
)
|
279 |
+
|
280 |
+
# up
|
281 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
282 |
+
output_channel = reversed_block_out_channels[0]
|
283 |
+
for i, up_block_type in enumerate(up_block_types):
|
284 |
+
prev_output_channel = output_channel
|
285 |
+
output_channel = reversed_block_out_channels[i]
|
286 |
+
is_final_block = i == len(block_out_channels) - 1
|
287 |
+
num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
|
288 |
+
num_time_upsample_layers = int(np.log2(time_compression_ratio))
|
289 |
+
|
290 |
+
if time_compression_ratio == 4:
|
291 |
+
add_spatial_upsample = bool(i < num_spatial_upsample_layers)
|
292 |
+
add_time_upsample = bool(i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block)
|
293 |
+
else:
|
294 |
+
raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
|
295 |
+
|
296 |
+
upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
|
297 |
+
upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
|
298 |
+
upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
|
299 |
+
up_block = get_up_block3d(
|
300 |
+
up_block_type,
|
301 |
+
num_layers=self.layers_per_block + 1,
|
302 |
+
in_channels=prev_output_channel,
|
303 |
+
out_channels=output_channel,
|
304 |
+
prev_output_channel=None,
|
305 |
+
add_upsample=bool(add_spatial_upsample or add_time_upsample),
|
306 |
+
upsample_scale_factor=upsample_scale_factor,
|
307 |
+
resnet_eps=1e-6,
|
308 |
+
resnet_act_fn=act_fn,
|
309 |
+
resnet_groups=norm_num_groups,
|
310 |
+
attention_head_dim=output_channel,
|
311 |
+
temb_channels=temb_channels,
|
312 |
+
resnet_time_scale_shift=norm_type,
|
313 |
+
)
|
314 |
+
self.up_blocks.append(up_block)
|
315 |
+
prev_output_channel = output_channel
|
316 |
+
|
317 |
+
# out
|
318 |
+
if norm_type == "spatial":
|
319 |
+
self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
|
320 |
+
else:
|
321 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
|
322 |
+
self.conv_act = nn.SiLU()
|
323 |
+
self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
|
324 |
+
|
325 |
+
self.gradient_checkpointing = False
|
326 |
+
|
327 |
+
def forward(
|
328 |
+
self,
|
329 |
+
sample: torch.FloatTensor,
|
330 |
+
latent_embeds: Optional[torch.FloatTensor] = None,
|
331 |
+
) -> torch.FloatTensor:
|
332 |
+
r"""The forward method of the `DecoderCausal3D` class."""
|
333 |
+
assert len(sample.shape) == 5, "The input tensor should have 5 dimensions."
|
334 |
+
|
335 |
+
sample = self.conv_in(sample)
|
336 |
+
|
337 |
+
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
338 |
+
if self.training and self.gradient_checkpointing:
|
339 |
+
|
340 |
+
def create_custom_forward(module):
|
341 |
+
def custom_forward(*inputs):
|
342 |
+
return module(*inputs)
|
343 |
+
|
344 |
+
return custom_forward
|
345 |
+
|
346 |
+
if is_torch_version(">=", "1.11.0"):
|
347 |
+
# middle
|
348 |
+
sample = torch.utils.checkpoint.checkpoint(
|
349 |
+
create_custom_forward(self.mid_block),
|
350 |
+
sample,
|
351 |
+
latent_embeds,
|
352 |
+
use_reentrant=False,
|
353 |
+
)
|
354 |
+
sample = sample.to(upscale_dtype)
|
355 |
+
|
356 |
+
# up
|
357 |
+
for up_block in self.up_blocks:
|
358 |
+
sample = torch.utils.checkpoint.checkpoint(
|
359 |
+
create_custom_forward(up_block),
|
360 |
+
sample,
|
361 |
+
latent_embeds,
|
362 |
+
use_reentrant=False,
|
363 |
+
)
|
364 |
+
else:
|
365 |
+
# middle
|
366 |
+
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample, latent_embeds)
|
367 |
+
sample = sample.to(upscale_dtype)
|
368 |
+
|
369 |
+
# up
|
370 |
+
for up_block in self.up_blocks:
|
371 |
+
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
|
372 |
+
else:
|
373 |
+
# middle
|
374 |
+
sample = self.mid_block(sample, latent_embeds)
|
375 |
+
sample = sample.to(upscale_dtype)
|
376 |
+
|
377 |
+
# up
|
378 |
+
for up_block in self.up_blocks:
|
379 |
+
sample = up_block(sample, latent_embeds)
|
380 |
+
|
381 |
+
# post-process
|
382 |
+
if latent_embeds is None:
|
383 |
+
sample = self.conv_norm_out(sample)
|
384 |
+
else:
|
385 |
+
sample = self.conv_norm_out(sample, latent_embeds)
|
386 |
+
sample = self.conv_act(sample)
|
387 |
+
sample = self.conv_out(sample)
|
388 |
+
|
389 |
+
return sample
|
390 |
+
|
391 |
+
|
392 |
+
class DiagonalGaussianDistribution(object):
|
393 |
+
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
|
394 |
+
if parameters.ndim == 3:
|
395 |
+
dim = 2 # (B, L, C)
|
396 |
+
elif parameters.ndim == 5 or parameters.ndim == 4:
|
397 |
+
dim = 1 # (B, C, T, H ,W) / (B, C, H, W)
|
398 |
+
else:
|
399 |
+
raise NotImplementedError
|
400 |
+
self.parameters = parameters
|
401 |
+
self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
|
402 |
+
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
403 |
+
self.deterministic = deterministic
|
404 |
+
self.std = torch.exp(0.5 * self.logvar)
|
405 |
+
self.var = torch.exp(self.logvar)
|
406 |
+
if self.deterministic:
|
407 |
+
self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device, dtype=self.parameters.dtype)
|
408 |
+
|
409 |
+
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
|
410 |
+
# make sure sample is on the same device as the parameters and has same dtype
|
411 |
+
sample = randn_tensor(
|
412 |
+
self.mean.shape,
|
413 |
+
generator=generator,
|
414 |
+
device=self.parameters.device,
|
415 |
+
dtype=self.parameters.dtype,
|
416 |
+
)
|
417 |
+
x = self.mean + self.std * sample
|
418 |
+
return x
|
419 |
+
|
420 |
+
def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
|
421 |
+
if self.deterministic:
|
422 |
+
return torch.Tensor([0.0])
|
423 |
+
else:
|
424 |
+
reduce_dim = list(range(1, self.mean.ndim))
|
425 |
+
if other is None:
|
426 |
+
return 0.5 * torch.sum(
|
427 |
+
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
|
428 |
+
dim=reduce_dim,
|
429 |
+
)
|
430 |
+
else:
|
431 |
+
return 0.5 * torch.sum(
|
432 |
+
torch.pow(self.mean - other.mean, 2) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar,
|
433 |
+
dim=reduce_dim,
|
434 |
+
)
|
435 |
+
|
436 |
+
def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
|
437 |
+
if self.deterministic:
|
438 |
+
return torch.Tensor([0.0])
|
439 |
+
logtwopi = np.log(2.0 * np.pi)
|
440 |
+
return 0.5 * torch.sum(
|
441 |
+
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
442 |
+
dim=dims,
|
443 |
+
)
|
444 |
+
|
445 |
+
def mode(self) -> torch.Tensor:
|
446 |
+
return self.mean
|
hv_generate_video.py
ADDED
@@ -0,0 +1,911 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from datetime import datetime
|
3 |
+
from pathlib import Path
|
4 |
+
import random
|
5 |
+
import sys
|
6 |
+
import os
|
7 |
+
import time
|
8 |
+
from typing import Optional, Union
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
import torchvision
|
13 |
+
import accelerate
|
14 |
+
from diffusers.utils.torch_utils import randn_tensor
|
15 |
+
from transformers.models.llama import LlamaModel
|
16 |
+
from tqdm import tqdm
|
17 |
+
import av
|
18 |
+
from einops import rearrange
|
19 |
+
from safetensors.torch import load_file, save_file
|
20 |
+
from safetensors import safe_open
|
21 |
+
from PIL import Image
|
22 |
+
|
23 |
+
from hunyuan_model import vae
|
24 |
+
from hunyuan_model.text_encoder import TextEncoder
|
25 |
+
from hunyuan_model.text_encoder import PROMPT_TEMPLATE
|
26 |
+
from hunyuan_model.vae import load_vae
|
27 |
+
from hunyuan_model.models import load_transformer, get_rotary_pos_embed
|
28 |
+
from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
|
29 |
+
from networks import lora
|
30 |
+
|
31 |
+
try:
|
32 |
+
from lycoris.kohya import create_network_from_weights
|
33 |
+
except:
|
34 |
+
pass
|
35 |
+
|
36 |
+
from utils.model_utils import str_to_dtype
|
37 |
+
from utils.safetensors_utils import mem_eff_save_file
|
38 |
+
from dataset.image_video_dataset import load_video, glob_images, resize_image_to_bucket
|
39 |
+
|
40 |
+
import logging
|
41 |
+
|
42 |
+
logger = logging.getLogger(__name__)
|
43 |
+
logging.basicConfig(level=logging.INFO)
|
44 |
+
|
45 |
+
|
46 |
+
def clean_memory_on_device(device):
|
47 |
+
if device.type == "cuda":
|
48 |
+
torch.cuda.empty_cache()
|
49 |
+
elif device.type == "cpu":
|
50 |
+
pass
|
51 |
+
elif device.type == "mps": # not tested
|
52 |
+
torch.mps.empty_cache()
|
53 |
+
|
54 |
+
|
55 |
+
def synchronize_device(device: torch.device):
|
56 |
+
if device.type == "cuda":
|
57 |
+
torch.cuda.synchronize()
|
58 |
+
elif device.type == "xpu":
|
59 |
+
torch.xpu.synchronize()
|
60 |
+
elif device.type == "mps":
|
61 |
+
torch.mps.synchronize()
|
62 |
+
|
63 |
+
|
64 |
+
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=1, fps=24):
|
65 |
+
"""save videos by video tensor
|
66 |
+
copy from https://github.com/guoyww/AnimateDiff/blob/e92bd5671ba62c0d774a32951453e328018b7c5b/animatediff/utils/util.py#L61
|
67 |
+
|
68 |
+
Args:
|
69 |
+
videos (torch.Tensor): video tensor predicted by the model
|
70 |
+
path (str): path to save video
|
71 |
+
rescale (bool, optional): rescale the video tensor from [-1, 1] to . Defaults to False.
|
72 |
+
n_rows (int, optional): Defaults to 1.
|
73 |
+
fps (int, optional): video save fps. Defaults to 8.
|
74 |
+
"""
|
75 |
+
videos = rearrange(videos, "b c t h w -> t b c h w")
|
76 |
+
outputs = []
|
77 |
+
for x in videos:
|
78 |
+
x = torchvision.utils.make_grid(x, nrow=n_rows)
|
79 |
+
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
|
80 |
+
if rescale:
|
81 |
+
x = (x + 1.0) / 2.0 # -1,1 -> 0,1
|
82 |
+
x = torch.clamp(x, 0, 1)
|
83 |
+
x = (x * 255).numpy().astype(np.uint8)
|
84 |
+
outputs.append(x)
|
85 |
+
|
86 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
87 |
+
|
88 |
+
# # save video with av
|
89 |
+
# container = av.open(path, "w")
|
90 |
+
# stream = container.add_stream("libx264", rate=fps)
|
91 |
+
# for x in outputs:
|
92 |
+
# frame = av.VideoFrame.from_ndarray(x, format="rgb24")
|
93 |
+
# packet = stream.encode(frame)
|
94 |
+
# container.mux(packet)
|
95 |
+
# packet = stream.encode(None)
|
96 |
+
# container.mux(packet)
|
97 |
+
# container.close()
|
98 |
+
|
99 |
+
height, width, _ = outputs[0].shape
|
100 |
+
|
101 |
+
# create output container
|
102 |
+
container = av.open(path, mode="w")
|
103 |
+
|
104 |
+
# create video stream
|
105 |
+
codec = "libx264"
|
106 |
+
pixel_format = "yuv420p"
|
107 |
+
stream = container.add_stream(codec, rate=fps)
|
108 |
+
stream.width = width
|
109 |
+
stream.height = height
|
110 |
+
stream.pix_fmt = pixel_format
|
111 |
+
stream.bit_rate = 4000000 # 4Mbit/s
|
112 |
+
|
113 |
+
for frame_array in outputs:
|
114 |
+
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
|
115 |
+
packets = stream.encode(frame)
|
116 |
+
for packet in packets:
|
117 |
+
container.mux(packet)
|
118 |
+
|
119 |
+
for packet in stream.encode():
|
120 |
+
container.mux(packet)
|
121 |
+
|
122 |
+
container.close()
|
123 |
+
|
124 |
+
|
125 |
+
def save_images_grid(
|
126 |
+
videos: torch.Tensor, parent_dir: str, image_name: str, rescale: bool = False, n_rows: int = 1, create_subdir=True
|
127 |
+
):
|
128 |
+
videos = rearrange(videos, "b c t h w -> t b c h w")
|
129 |
+
outputs = []
|
130 |
+
for x in videos:
|
131 |
+
x = torchvision.utils.make_grid(x, nrow=n_rows)
|
132 |
+
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
|
133 |
+
if rescale:
|
134 |
+
x = (x + 1.0) / 2.0 # -1,1 -> 0,1
|
135 |
+
x = torch.clamp(x, 0, 1)
|
136 |
+
x = (x * 255).numpy().astype(np.uint8)
|
137 |
+
outputs.append(x)
|
138 |
+
|
139 |
+
if create_subdir:
|
140 |
+
output_dir = os.path.join(parent_dir, image_name)
|
141 |
+
else:
|
142 |
+
output_dir = parent_dir
|
143 |
+
|
144 |
+
os.makedirs(output_dir, exist_ok=True)
|
145 |
+
for i, x in enumerate(outputs):
|
146 |
+
image_path = os.path.join(output_dir, f"{image_name}_{i:03d}.png")
|
147 |
+
image = Image.fromarray(x)
|
148 |
+
image.save(image_path)
|
149 |
+
|
150 |
+
|
151 |
+
# region Encoding prompt
|
152 |
+
|
153 |
+
|
154 |
+
def encode_prompt(prompt: Union[str, list[str]], device: torch.device, num_videos_per_prompt: int, text_encoder: TextEncoder):
|
155 |
+
r"""
|
156 |
+
Encodes the prompt into text encoder hidden states.
|
157 |
+
|
158 |
+
Args:
|
159 |
+
prompt (`str` or `List[str]`):
|
160 |
+
prompt to be encoded
|
161 |
+
device: (`torch.device`):
|
162 |
+
torch device
|
163 |
+
num_videos_per_prompt (`int`):
|
164 |
+
number of videos that should be generated per prompt
|
165 |
+
text_encoder (TextEncoder):
|
166 |
+
text encoder to be used for encoding the prompt
|
167 |
+
"""
|
168 |
+
# LoRA and Textual Inversion are not supported in this script
|
169 |
+
# negative prompt and prompt embedding are not supported in this script
|
170 |
+
# clip_skip is not supported in this script because it is not used in the original script
|
171 |
+
data_type = "video" # video only, image is not supported
|
172 |
+
|
173 |
+
text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)
|
174 |
+
|
175 |
+
with torch.no_grad():
|
176 |
+
prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type, device=device)
|
177 |
+
prompt_embeds = prompt_outputs.hidden_state
|
178 |
+
|
179 |
+
attention_mask = prompt_outputs.attention_mask
|
180 |
+
if attention_mask is not None:
|
181 |
+
attention_mask = attention_mask.to(device)
|
182 |
+
bs_embed, seq_len = attention_mask.shape
|
183 |
+
attention_mask = attention_mask.repeat(1, num_videos_per_prompt)
|
184 |
+
attention_mask = attention_mask.view(bs_embed * num_videos_per_prompt, seq_len)
|
185 |
+
|
186 |
+
prompt_embeds_dtype = text_encoder.dtype
|
187 |
+
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
188 |
+
|
189 |
+
if prompt_embeds.ndim == 2:
|
190 |
+
bs_embed, _ = prompt_embeds.shape
|
191 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
192 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
|
193 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1)
|
194 |
+
else:
|
195 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
196 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
197 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
198 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
|
199 |
+
|
200 |
+
return prompt_embeds, attention_mask
|
201 |
+
|
202 |
+
|
203 |
+
def encode_input_prompt(prompt: Union[str, list[str]], args, device, fp8_llm=False, accelerator=None):
|
204 |
+
# constants
|
205 |
+
prompt_template_video = "dit-llm-encode-video"
|
206 |
+
prompt_template = "dit-llm-encode"
|
207 |
+
text_encoder_dtype = torch.float16
|
208 |
+
text_encoder_type = "llm"
|
209 |
+
text_len = 256
|
210 |
+
hidden_state_skip_layer = 2
|
211 |
+
apply_final_norm = False
|
212 |
+
reproduce = False
|
213 |
+
|
214 |
+
text_encoder_2_type = "clipL"
|
215 |
+
text_len_2 = 77
|
216 |
+
|
217 |
+
num_videos = 1
|
218 |
+
|
219 |
+
# if args.prompt_template_video is not None:
|
220 |
+
# crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get("crop_start", 0)
|
221 |
+
# elif args.prompt_template is not None:
|
222 |
+
# crop_start = PROMPT_TEMPLATE[args.prompt_template].get("crop_start", 0)
|
223 |
+
# else:
|
224 |
+
# crop_start = 0
|
225 |
+
crop_start = PROMPT_TEMPLATE[prompt_template_video].get("crop_start", 0)
|
226 |
+
max_length = text_len + crop_start
|
227 |
+
|
228 |
+
# prompt_template
|
229 |
+
prompt_template = PROMPT_TEMPLATE[prompt_template]
|
230 |
+
|
231 |
+
# prompt_template_video
|
232 |
+
prompt_template_video = PROMPT_TEMPLATE[prompt_template_video] # if args.prompt_template_video is not None else None
|
233 |
+
|
234 |
+
# load text encoders
|
235 |
+
logger.info(f"loading text encoder: {args.text_encoder1}")
|
236 |
+
text_encoder = TextEncoder(
|
237 |
+
text_encoder_type=text_encoder_type,
|
238 |
+
max_length=max_length,
|
239 |
+
text_encoder_dtype=text_encoder_dtype,
|
240 |
+
text_encoder_path=args.text_encoder1,
|
241 |
+
tokenizer_type=text_encoder_type,
|
242 |
+
prompt_template=prompt_template,
|
243 |
+
prompt_template_video=prompt_template_video,
|
244 |
+
hidden_state_skip_layer=hidden_state_skip_layer,
|
245 |
+
apply_final_norm=apply_final_norm,
|
246 |
+
reproduce=reproduce,
|
247 |
+
)
|
248 |
+
text_encoder.eval()
|
249 |
+
if fp8_llm:
|
250 |
+
org_dtype = text_encoder.dtype
|
251 |
+
logger.info(f"Moving and casting text encoder to {device} and torch.float8_e4m3fn")
|
252 |
+
text_encoder.to(device=device, dtype=torch.float8_e4m3fn)
|
253 |
+
|
254 |
+
# prepare LLM for fp8
|
255 |
+
def prepare_fp8(llama_model: LlamaModel, target_dtype):
|
256 |
+
def forward_hook(module):
|
257 |
+
def forward(hidden_states):
|
258 |
+
input_dtype = hidden_states.dtype
|
259 |
+
hidden_states = hidden_states.to(torch.float32)
|
260 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
261 |
+
hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon)
|
262 |
+
return module.weight.to(input_dtype) * hidden_states.to(input_dtype)
|
263 |
+
|
264 |
+
return forward
|
265 |
+
|
266 |
+
for module in llama_model.modules():
|
267 |
+
if module.__class__.__name__ in ["Embedding"]:
|
268 |
+
# print("set", module.__class__.__name__, "to", target_dtype)
|
269 |
+
module.to(target_dtype)
|
270 |
+
if module.__class__.__name__ in ["LlamaRMSNorm"]:
|
271 |
+
# print("set", module.__class__.__name__, "hooks")
|
272 |
+
module.forward = forward_hook(module)
|
273 |
+
|
274 |
+
prepare_fp8(text_encoder.model, org_dtype)
|
275 |
+
|
276 |
+
logger.info(f"loading text encoder 2: {args.text_encoder2}")
|
277 |
+
text_encoder_2 = TextEncoder(
|
278 |
+
text_encoder_type=text_encoder_2_type,
|
279 |
+
max_length=text_len_2,
|
280 |
+
text_encoder_dtype=text_encoder_dtype,
|
281 |
+
text_encoder_path=args.text_encoder2,
|
282 |
+
tokenizer_type=text_encoder_2_type,
|
283 |
+
reproduce=reproduce,
|
284 |
+
)
|
285 |
+
text_encoder_2.eval()
|
286 |
+
|
287 |
+
# encode prompt
|
288 |
+
logger.info(f"Encoding prompt with text encoder 1")
|
289 |
+
text_encoder.to(device=device)
|
290 |
+
if fp8_llm:
|
291 |
+
with accelerator.autocast():
|
292 |
+
prompt_embeds, prompt_mask = encode_prompt(prompt, device, num_videos, text_encoder)
|
293 |
+
else:
|
294 |
+
prompt_embeds, prompt_mask = encode_prompt(prompt, device, num_videos, text_encoder)
|
295 |
+
text_encoder = None
|
296 |
+
clean_memory_on_device(device)
|
297 |
+
|
298 |
+
logger.info(f"Encoding prompt with text encoder 2")
|
299 |
+
text_encoder_2.to(device=device)
|
300 |
+
prompt_embeds_2, prompt_mask_2 = encode_prompt(prompt, device, num_videos, text_encoder_2)
|
301 |
+
|
302 |
+
prompt_embeds = prompt_embeds.to("cpu")
|
303 |
+
prompt_mask = prompt_mask.to("cpu")
|
304 |
+
prompt_embeds_2 = prompt_embeds_2.to("cpu")
|
305 |
+
prompt_mask_2 = prompt_mask_2.to("cpu")
|
306 |
+
|
307 |
+
text_encoder_2 = None
|
308 |
+
clean_memory_on_device(device)
|
309 |
+
|
310 |
+
return prompt_embeds, prompt_mask, prompt_embeds_2, prompt_mask_2
|
311 |
+
|
312 |
+
|
313 |
+
# endregion
|
314 |
+
|
315 |
+
|
316 |
+
def load_images(image_dir, video_length, bucket_reso):
|
317 |
+
image_files = glob_images(image_dir)
|
318 |
+
if len(image_files) == 0:
|
319 |
+
raise ValueError(f"No image files found in {image_dir}")
|
320 |
+
if len(image_files) < video_length:
|
321 |
+
raise ValueError(f"Number of images in {image_dir} is less than {video_length}")
|
322 |
+
|
323 |
+
image_files.sort()
|
324 |
+
images = []
|
325 |
+
for image_file in image_files[:video_length]:
|
326 |
+
image = Image.open(image_file)
|
327 |
+
image = resize_image_to_bucket(image, bucket_reso) # returns a numpy array
|
328 |
+
images.append(image)
|
329 |
+
|
330 |
+
return images
|
331 |
+
|
332 |
+
|
333 |
+
def prepare_vae(args, device):
|
334 |
+
vae_dtype = torch.float16 if args.vae_dtype is None else str_to_dtype(args.vae_dtype)
|
335 |
+
vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device=device, vae_path=args.vae)
|
336 |
+
vae.eval()
|
337 |
+
# vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}
|
338 |
+
|
339 |
+
# set chunk_size to CausalConv3d recursively
|
340 |
+
chunk_size = args.vae_chunk_size
|
341 |
+
if chunk_size is not None:
|
342 |
+
vae.set_chunk_size_for_causal_conv_3d(chunk_size)
|
343 |
+
logger.info(f"Set chunk_size to {chunk_size} for CausalConv3d")
|
344 |
+
|
345 |
+
if args.vae_spatial_tile_sample_min_size is not None:
|
346 |
+
vae.enable_spatial_tiling(True)
|
347 |
+
vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size
|
348 |
+
vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8
|
349 |
+
# elif args.vae_tiling:
|
350 |
+
else:
|
351 |
+
vae.enable_spatial_tiling(True)
|
352 |
+
|
353 |
+
return vae, vae_dtype
|
354 |
+
|
355 |
+
|
356 |
+
def encode_to_latents(args, video, device):
|
357 |
+
vae, vae_dtype = prepare_vae(args, device)
|
358 |
+
|
359 |
+
video = video.to(device=device, dtype=vae_dtype)
|
360 |
+
video = video * 2 - 1 # 0, 1 -> -1, 1
|
361 |
+
with torch.no_grad():
|
362 |
+
latents = vae.encode(video).latent_dist.sample()
|
363 |
+
|
364 |
+
if hasattr(vae.config, "shift_factor") and vae.config.shift_factor:
|
365 |
+
latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor
|
366 |
+
else:
|
367 |
+
latents = latents * vae.config.scaling_factor
|
368 |
+
|
369 |
+
return latents
|
370 |
+
|
371 |
+
|
372 |
+
def decode_latents(args, latents, device):
|
373 |
+
vae, vae_dtype = prepare_vae(args, device)
|
374 |
+
|
375 |
+
expand_temporal_dim = False
|
376 |
+
if len(latents.shape) == 4:
|
377 |
+
latents = latents.unsqueeze(2)
|
378 |
+
expand_temporal_dim = True
|
379 |
+
elif len(latents.shape) == 5:
|
380 |
+
pass
|
381 |
+
else:
|
382 |
+
raise ValueError(f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}.")
|
383 |
+
|
384 |
+
if hasattr(vae.config, "shift_factor") and vae.config.shift_factor:
|
385 |
+
latents = latents / vae.config.scaling_factor + vae.config.shift_factor
|
386 |
+
else:
|
387 |
+
latents = latents / vae.config.scaling_factor
|
388 |
+
|
389 |
+
latents = latents.to(device=device, dtype=vae_dtype)
|
390 |
+
with torch.no_grad():
|
391 |
+
image = vae.decode(latents, return_dict=False)[0]
|
392 |
+
|
393 |
+
if expand_temporal_dim:
|
394 |
+
image = image.squeeze(2)
|
395 |
+
|
396 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
397 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
398 |
+
image = image.cpu().float()
|
399 |
+
|
400 |
+
return image
|
401 |
+
|
402 |
+
|
403 |
+
def parse_args():
|
404 |
+
parser = argparse.ArgumentParser(description="HunyuanVideo inference script")
|
405 |
+
|
406 |
+
parser.add_argument("--dit", type=str, required=True, help="DiT checkpoint path or directory")
|
407 |
+
parser.add_argument(
|
408 |
+
"--dit_in_channels",
|
409 |
+
type=int,
|
410 |
+
default=None,
|
411 |
+
help="input channels for DiT, default is None (automatically detect). 32 for SkyReels-I2V, 16 for others",
|
412 |
+
)
|
413 |
+
parser.add_argument("--vae", type=str, required=True, help="VAE checkpoint path or directory")
|
414 |
+
parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16")
|
415 |
+
parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory")
|
416 |
+
parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory")
|
417 |
+
|
418 |
+
# LoRA
|
419 |
+
parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path")
|
420 |
+
parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier")
|
421 |
+
parser.add_argument(
|
422 |
+
"--save_merged_model",
|
423 |
+
type=str,
|
424 |
+
default=None,
|
425 |
+
help="Save merged model to path. If specified, no inference will be performed.",
|
426 |
+
)
|
427 |
+
parser.add_argument("--exclude_single_blocks", action="store_true", help="Exclude single blocks when loading LoRA weights")
|
428 |
+
|
429 |
+
# inference
|
430 |
+
parser.add_argument("--prompt", type=str, required=True, help="prompt for generation")
|
431 |
+
parser.add_argument("--negative_prompt", type=str, default=None, help="negative prompt for generation")
|
432 |
+
parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size")
|
433 |
+
parser.add_argument("--video_length", type=int, default=129, help="video length")
|
434 |
+
parser.add_argument("--fps", type=int, default=24, help="video fps")
|
435 |
+
parser.add_argument("--infer_steps", type=int, default=50, help="number of inference steps")
|
436 |
+
parser.add_argument("--save_path", type=str, required=True, help="path to save generated video")
|
437 |
+
parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
|
438 |
+
parser.add_argument(
|
439 |
+
"--guidance_scale",
|
440 |
+
type=float,
|
441 |
+
default=1.0,
|
442 |
+
help="Guidance scale for classifier free guidance. Default is 1.0 (means no guidance)",
|
443 |
+
)
|
444 |
+
parser.add_argument("--embedded_cfg_scale", type=float, default=6.0, help="Embeded classifier free guidance scale.")
|
445 |
+
parser.add_argument("--video_path", type=str, default=None, help="path to video for video2video inference")
|
446 |
+
parser.add_argument(
|
447 |
+
"--image_path", type=str, default=None, help="path to image for image2video inference, only works for SkyReels-I2V model"
|
448 |
+
)
|
449 |
+
parser.add_argument(
|
450 |
+
"--split_uncond",
|
451 |
+
action="store_true",
|
452 |
+
help="split unconditional call for classifier free guidance, slower but less memory usage",
|
453 |
+
)
|
454 |
+
parser.add_argument("--strength", type=float, default=0.8, help="strength for video2video inference")
|
455 |
+
|
456 |
+
# Flow Matching
|
457 |
+
parser.add_argument("--flow_shift", type=float, default=7.0, help="Shift factor for flow matching schedulers.")
|
458 |
+
|
459 |
+
parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model")
|
460 |
+
parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
|
461 |
+
parser.add_argument(
|
462 |
+
"--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
|
463 |
+
)
|
464 |
+
parser.add_argument(
|
465 |
+
"--attn_mode", type=str, default="torch", choices=["flash", "torch", "sageattn", "xformers", "sdpa"], help="attention mode"
|
466 |
+
)
|
467 |
+
parser.add_argument(
|
468 |
+
"--split_attn", action="store_true", help="use split attention, default is False. if True, --split_uncond becomes True"
|
469 |
+
)
|
470 |
+
parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
|
471 |
+
parser.add_argument(
|
472 |
+
"--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
|
473 |
+
)
|
474 |
+
parser.add_argument("--blocks_to_swap", type=int, default=None, help="number of blocks to swap in the model")
|
475 |
+
parser.add_argument("--img_in_txt_in_offloading", action="store_true", help="offload img_in and txt_in to cpu")
|
476 |
+
parser.add_argument(
|
477 |
+
"--output_type", type=str, default="video", choices=["video", "images", "latent", "both"], help="output type"
|
478 |
+
)
|
479 |
+
parser.add_argument("--no_metadata", action="store_true", help="do not save metadata")
|
480 |
+
parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. no inference")
|
481 |
+
parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference")
|
482 |
+
|
483 |
+
args = parser.parse_args()
|
484 |
+
|
485 |
+
assert (args.latent_path is None or len(args.latent_path) == 0) or (
|
486 |
+
args.output_type == "images" or args.output_type == "video"
|
487 |
+
), "latent_path is only supported for images or video output"
|
488 |
+
|
489 |
+
# update dit_weight based on model_base if not exists
|
490 |
+
|
491 |
+
return args
|
492 |
+
|
493 |
+
|
494 |
+
def check_inputs(args):
|
495 |
+
height = args.video_size[0]
|
496 |
+
width = args.video_size[1]
|
497 |
+
video_length = args.video_length
|
498 |
+
|
499 |
+
if height % 8 != 0 or width % 8 != 0:
|
500 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
501 |
+
return height, width, video_length
|
502 |
+
|
503 |
+
|
504 |
+
def main():
|
505 |
+
args = parse_args()
|
506 |
+
|
507 |
+
device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
|
508 |
+
device = torch.device(device)
|
509 |
+
dit_dtype = torch.bfloat16
|
510 |
+
dit_weight_dtype = torch.float8_e4m3fn if args.fp8 else dit_dtype
|
511 |
+
logger.info(f"Using device: {device}, DiT precision: {dit_dtype}, weight precision: {dit_weight_dtype}")
|
512 |
+
|
513 |
+
original_base_names = None
|
514 |
+
if args.latent_path is not None and len(args.latent_path) > 0:
|
515 |
+
original_base_names = []
|
516 |
+
latents_list = []
|
517 |
+
seeds = []
|
518 |
+
for latent_path in args.latent_path:
|
519 |
+
original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0])
|
520 |
+
seed = 0
|
521 |
+
|
522 |
+
if os.path.splitext(latent_path)[1] != ".safetensors":
|
523 |
+
latents = torch.load(latent_path, map_location="cpu")
|
524 |
+
else:
|
525 |
+
latents = load_file(latent_path)["latent"]
|
526 |
+
with safe_open(latent_path, framework="pt") as f:
|
527 |
+
metadata = f.metadata()
|
528 |
+
if metadata is None:
|
529 |
+
metadata = {}
|
530 |
+
logger.info(f"Loaded metadata: {metadata}")
|
531 |
+
|
532 |
+
if "seeds" in metadata:
|
533 |
+
seed = int(metadata["seeds"])
|
534 |
+
|
535 |
+
seeds.append(seed)
|
536 |
+
latents_list.append(latents)
|
537 |
+
|
538 |
+
logger.info(f"Loaded latent from {latent_path}. Shape: {latents.shape}")
|
539 |
+
latents = torch.stack(latents_list, dim=0)
|
540 |
+
else:
|
541 |
+
# prepare accelerator
|
542 |
+
mixed_precision = "bf16" if dit_dtype == torch.bfloat16 else "fp16"
|
543 |
+
accelerator = accelerate.Accelerator(mixed_precision=mixed_precision)
|
544 |
+
|
545 |
+
# load prompt
|
546 |
+
prompt = args.prompt # TODO load prompts from file
|
547 |
+
assert prompt is not None, "prompt is required"
|
548 |
+
|
549 |
+
# check inputs: may be height, width, video_length etc will be changed for each generation in future
|
550 |
+
height, width, video_length = check_inputs(args)
|
551 |
+
|
552 |
+
# encode prompt with LLM and Text Encoder
|
553 |
+
logger.info(f"Encoding prompt: {prompt}")
|
554 |
+
|
555 |
+
do_classifier_free_guidance = args.guidance_scale != 1.0
|
556 |
+
if do_classifier_free_guidance:
|
557 |
+
negative_prompt = args.negative_prompt
|
558 |
+
if negative_prompt is None:
|
559 |
+
logger.info("Negative prompt is not provided, using empty prompt")
|
560 |
+
negative_prompt = ""
|
561 |
+
logger.info(f"Encoding negative prompt: {negative_prompt}")
|
562 |
+
prompt = [negative_prompt, prompt]
|
563 |
+
else:
|
564 |
+
if args.negative_prompt is not None:
|
565 |
+
logger.warning("Negative prompt is provided but guidance_scale is 1.0, negative prompt will be ignored.")
|
566 |
+
|
567 |
+
prompt_embeds, prompt_mask, prompt_embeds_2, prompt_mask_2 = encode_input_prompt(
|
568 |
+
prompt, args, device, args.fp8_llm, accelerator
|
569 |
+
)
|
570 |
+
|
571 |
+
# encode latents for video2video inference
|
572 |
+
video_latents = None
|
573 |
+
if args.video_path is not None:
|
574 |
+
# v2v inference
|
575 |
+
logger.info(f"Video2Video inference: {args.video_path}")
|
576 |
+
|
577 |
+
if os.path.isfile(args.video_path):
|
578 |
+
video = load_video(args.video_path, 0, video_length, bucket_reso=(width, height)) # list of frames
|
579 |
+
else:
|
580 |
+
video = load_images(args.video_path, video_length, bucket_reso=(width, height)) # list of frames
|
581 |
+
|
582 |
+
if len(video) < video_length:
|
583 |
+
raise ValueError(f"Video length is less than {video_length}")
|
584 |
+
video = np.stack(video, axis=0) # F, H, W, C
|
585 |
+
video = torch.from_numpy(video).permute(3, 0, 1, 2).unsqueeze(0).float() # 1, C, F, H, W
|
586 |
+
video = video / 255.0
|
587 |
+
|
588 |
+
logger.info(f"Encoding video to latents")
|
589 |
+
video_latents = encode_to_latents(args, video, device)
|
590 |
+
video_latents = video_latents.to(device=device, dtype=dit_dtype)
|
591 |
+
|
592 |
+
clean_memory_on_device(device)
|
593 |
+
|
594 |
+
# encode latents for image2video inference
|
595 |
+
image_latents = None
|
596 |
+
if args.image_path is not None:
|
597 |
+
# i2v inference
|
598 |
+
logger.info(f"Image2Video inference: {args.image_path}")
|
599 |
+
|
600 |
+
image = Image.open(args.image_path)
|
601 |
+
image = resize_image_to_bucket(image, (width, height)) # returns a numpy array
|
602 |
+
image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).unsqueeze(2).float() # 1, C, 1, H, W
|
603 |
+
image = image / 255.0
|
604 |
+
|
605 |
+
logger.info(f"Encoding image to latents")
|
606 |
+
image_latents = encode_to_latents(args, image, device) # 1, C, 1, H, W
|
607 |
+
image_latents = image_latents.to(device=device, dtype=dit_dtype)
|
608 |
+
|
609 |
+
clean_memory_on_device(device)
|
610 |
+
|
611 |
+
# load DiT model
|
612 |
+
blocks_to_swap = args.blocks_to_swap if args.blocks_to_swap else 0
|
613 |
+
loading_device = "cpu" # if blocks_to_swap > 0 else device
|
614 |
+
|
615 |
+
logger.info(f"Loading DiT model from {args.dit}")
|
616 |
+
if args.attn_mode == "sdpa":
|
617 |
+
args.attn_mode = "torch"
|
618 |
+
|
619 |
+
# if image_latents is given, the model should be I2V model, so the in_channels should be 32
|
620 |
+
dit_in_channels = args.dit_in_channels if args.dit_in_channels is not None else (32 if image_latents is not None else 16)
|
621 |
+
|
622 |
+
# if we use LoRA, weigths should be bf16 instead of fp8, because merging should be done in bf16
|
623 |
+
# the model is too large, so we load the model to cpu. in addition, the .pt file is loaded to cpu anyway
|
624 |
+
# on the fly merging will be a solution for this issue for .safetenors files (not implemented yet)
|
625 |
+
transformer = load_transformer(
|
626 |
+
args.dit, args.attn_mode, args.split_attn, loading_device, dit_dtype, in_channels=dit_in_channels
|
627 |
+
)
|
628 |
+
transformer.eval()
|
629 |
+
|
630 |
+
# load LoRA weights
|
631 |
+
if args.lora_weight is not None and len(args.lora_weight) > 0:
|
632 |
+
for i, lora_weight in enumerate(args.lora_weight):
|
633 |
+
if args.lora_multiplier is not None and len(args.lora_multiplier) > i:
|
634 |
+
lora_multiplier = args.lora_multiplier[i]
|
635 |
+
else:
|
636 |
+
lora_multiplier = 1.0
|
637 |
+
|
638 |
+
logger.info(f"Loading LoRA weights from {lora_weight} with multiplier {lora_multiplier}")
|
639 |
+
weights_sd = load_file(lora_weight)
|
640 |
+
|
641 |
+
# Filter to exclude keys that are part of single_blocks
|
642 |
+
if args.exclude_single_blocks:
|
643 |
+
filtered_weights = {k: v for k, v in weights_sd.items() if "single_blocks" not in k}
|
644 |
+
weights_sd = filtered_weights
|
645 |
+
|
646 |
+
if args.lycoris:
|
647 |
+
lycoris_net, _ = create_network_from_weights(
|
648 |
+
multiplier=lora_multiplier,
|
649 |
+
file=None,
|
650 |
+
weights_sd=weights_sd,
|
651 |
+
unet=transformer,
|
652 |
+
text_encoder=None,
|
653 |
+
vae=None,
|
654 |
+
for_inference=True,
|
655 |
+
)
|
656 |
+
else:
|
657 |
+
network = lora.create_arch_network_from_weights(
|
658 |
+
lora_multiplier, weights_sd, unet=transformer, for_inference=True
|
659 |
+
)
|
660 |
+
logger.info("Merging LoRA weights to DiT model")
|
661 |
+
|
662 |
+
# try:
|
663 |
+
# network.apply_to(None, transformer, apply_text_encoder=False, apply_unet=True)
|
664 |
+
# info = network.load_state_dict(weights_sd, strict=True)
|
665 |
+
# logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
|
666 |
+
# network.eval()
|
667 |
+
# network.to(device)
|
668 |
+
# except Exception as e:
|
669 |
+
if args.lycoris:
|
670 |
+
lycoris_net.merge_to(None, transformer, weights_sd, dtype=None, device=device)
|
671 |
+
else:
|
672 |
+
network.merge_to(None, transformer, weights_sd, device=device, non_blocking=True)
|
673 |
+
|
674 |
+
synchronize_device(device)
|
675 |
+
|
676 |
+
logger.info("LoRA weights loaded")
|
677 |
+
|
678 |
+
# save model here before casting to dit_weight_dtype
|
679 |
+
if args.save_merged_model:
|
680 |
+
logger.info(f"Saving merged model to {args.save_merged_model}")
|
681 |
+
mem_eff_save_file(transformer.state_dict(), args.save_merged_model) # save_file needs a lot of memory
|
682 |
+
logger.info("Merged model saved")
|
683 |
+
return
|
684 |
+
|
685 |
+
if blocks_to_swap > 0:
|
686 |
+
logger.info(f"Casting model to {dit_weight_dtype}")
|
687 |
+
transformer.to(dtype=dit_weight_dtype)
|
688 |
+
logger.info(f"Enable swap {blocks_to_swap} blocks to CPU from device: {device}")
|
689 |
+
transformer.enable_block_swap(blocks_to_swap, device, supports_backward=False)
|
690 |
+
transformer.move_to_device_except_swap_blocks(device)
|
691 |
+
transformer.prepare_block_swap_before_forward()
|
692 |
+
else:
|
693 |
+
logger.info(f"Moving and casting model to {device} and {dit_weight_dtype}")
|
694 |
+
transformer.to(device=device, dtype=dit_weight_dtype)
|
695 |
+
if args.img_in_txt_in_offloading:
|
696 |
+
logger.info("Enable offloading img_in and txt_in to CPU")
|
697 |
+
transformer.enable_img_in_txt_in_offloading()
|
698 |
+
|
699 |
+
# load scheduler
|
700 |
+
logger.info(f"Loading scheduler")
|
701 |
+
scheduler = FlowMatchDiscreteScheduler(shift=args.flow_shift, reverse=True, solver="euler")
|
702 |
+
|
703 |
+
# Prepare timesteps
|
704 |
+
num_inference_steps = args.infer_steps
|
705 |
+
scheduler.set_timesteps(num_inference_steps, device=device) # n_tokens is not used in FlowMatchDiscreteScheduler
|
706 |
+
timesteps = scheduler.timesteps
|
707 |
+
|
708 |
+
# Prepare generator
|
709 |
+
num_videos_per_prompt = 1 # args.num_videos # currently only support 1 video per prompt, this is a batch size
|
710 |
+
seed = args.seed
|
711 |
+
if seed is None:
|
712 |
+
seeds = [random.randint(0, 2**32 - 1) for _ in range(num_videos_per_prompt)]
|
713 |
+
elif isinstance(seed, int):
|
714 |
+
seeds = [seed + i for i in range(num_videos_per_prompt)]
|
715 |
+
else:
|
716 |
+
raise ValueError(f"Seed must be an integer or None, got {seed}.")
|
717 |
+
generator = [torch.Generator(device).manual_seed(seed) for seed in seeds]
|
718 |
+
|
719 |
+
# Prepare noisy latents
|
720 |
+
num_channels_latents = 16 # transformer.config.in_channels
|
721 |
+
vae_scale_factor = 2 ** (4 - 1) # len(self.vae.config.block_out_channels) == 4
|
722 |
+
|
723 |
+
vae_ver = vae.VAE_VER
|
724 |
+
if "884" in vae_ver:
|
725 |
+
latent_video_length = (video_length - 1) // 4 + 1
|
726 |
+
elif "888" in vae_ver:
|
727 |
+
latent_video_length = (video_length - 1) // 8 + 1
|
728 |
+
else:
|
729 |
+
latent_video_length = video_length
|
730 |
+
|
731 |
+
# shape = (
|
732 |
+
# num_videos_per_prompt,
|
733 |
+
# num_channels_latents,
|
734 |
+
# latent_video_length,
|
735 |
+
# height // vae_scale_factor,
|
736 |
+
# width // vae_scale_factor,
|
737 |
+
# )
|
738 |
+
# latents = randn_tensor(shape, generator=generator, device=device, dtype=dit_dtype)
|
739 |
+
|
740 |
+
# make first N frames to be the same if the given seed is same
|
741 |
+
shape_of_frame = (num_videos_per_prompt, num_channels_latents, 1, height // vae_scale_factor, width // vae_scale_factor)
|
742 |
+
latents = []
|
743 |
+
for i in range(latent_video_length):
|
744 |
+
latents.append(randn_tensor(shape_of_frame, generator=generator, device=device, dtype=dit_dtype))
|
745 |
+
latents = torch.cat(latents, dim=2)
|
746 |
+
|
747 |
+
# pad image_latents to match the length of video_latents
|
748 |
+
if image_latents is not None:
|
749 |
+
zero_latents = torch.zeros_like(latents)
|
750 |
+
zero_latents[:, :, :1, :, :] = image_latents
|
751 |
+
image_latents = zero_latents
|
752 |
+
|
753 |
+
if args.video_path is not None:
|
754 |
+
# v2v inference
|
755 |
+
noise = latents
|
756 |
+
assert noise.shape == video_latents.shape, f"noise shape {noise.shape} != video_latents shape {video_latents.shape}"
|
757 |
+
|
758 |
+
num_inference_steps = int(num_inference_steps * args.strength)
|
759 |
+
timestep_start = scheduler.timesteps[-num_inference_steps] # larger strength, less inference steps and more start time
|
760 |
+
t = timestep_start / 1000.0
|
761 |
+
latents = noise * t + video_latents * (1 - t)
|
762 |
+
|
763 |
+
timesteps = timesteps[-num_inference_steps:]
|
764 |
+
|
765 |
+
logger.info(f"strength: {args.strength}, num_inference_steps: {num_inference_steps}, timestep_start: {timestep_start}")
|
766 |
+
|
767 |
+
# FlowMatchDiscreteScheduler does not have init_noise_sigma
|
768 |
+
|
769 |
+
# Denoising loop
|
770 |
+
embedded_guidance_scale = args.embedded_cfg_scale
|
771 |
+
if embedded_guidance_scale is not None:
|
772 |
+
guidance_expand = torch.tensor([embedded_guidance_scale * 1000.0] * latents.shape[0], dtype=torch.float32, device="cpu")
|
773 |
+
guidance_expand = guidance_expand.to(device=device, dtype=dit_dtype)
|
774 |
+
if do_classifier_free_guidance:
|
775 |
+
guidance_expand = torch.cat([guidance_expand, guidance_expand], dim=0)
|
776 |
+
else:
|
777 |
+
guidance_expand = None
|
778 |
+
freqs_cos, freqs_sin = get_rotary_pos_embed(vae_ver, transformer, video_length, height, width)
|
779 |
+
# n_tokens = freqs_cos.shape[0]
|
780 |
+
|
781 |
+
# move and cast all inputs to the correct device and dtype
|
782 |
+
prompt_embeds = prompt_embeds.to(device=device, dtype=dit_dtype)
|
783 |
+
prompt_mask = prompt_mask.to(device=device)
|
784 |
+
prompt_embeds_2 = prompt_embeds_2.to(device=device, dtype=dit_dtype)
|
785 |
+
prompt_mask_2 = prompt_mask_2.to(device=device)
|
786 |
+
|
787 |
+
freqs_cos = freqs_cos.to(device=device, dtype=dit_dtype)
|
788 |
+
freqs_sin = freqs_sin.to(device=device, dtype=dit_dtype)
|
789 |
+
|
790 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * scheduler.order # this should be 0 in v2v inference
|
791 |
+
|
792 |
+
# assert split_uncond and split_attn
|
793 |
+
if args.split_attn and do_classifier_free_guidance and not args.split_uncond:
|
794 |
+
logger.warning("split_attn is enabled, split_uncond will be enabled as well.")
|
795 |
+
args.split_uncond = True
|
796 |
+
|
797 |
+
# with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]) as p:
|
798 |
+
with tqdm(total=num_inference_steps) as progress_bar:
|
799 |
+
for i, t in enumerate(timesteps):
|
800 |
+
latents = scheduler.scale_model_input(latents, t)
|
801 |
+
|
802 |
+
# predict the noise residual
|
803 |
+
with torch.no_grad(), accelerator.autocast():
|
804 |
+
latents_input = latents if not do_classifier_free_guidance else torch.cat([latents, latents], dim=0)
|
805 |
+
if image_latents is not None:
|
806 |
+
latents_image_input = (
|
807 |
+
image_latents if not do_classifier_free_guidance else torch.cat([image_latents, image_latents], dim=0)
|
808 |
+
)
|
809 |
+
latents_input = torch.cat([latents_input, latents_image_input], dim=1) # 1 or 2, C*2, F, H, W
|
810 |
+
|
811 |
+
batch_size = 1 if args.split_uncond else latents_input.shape[0]
|
812 |
+
|
813 |
+
noise_pred_list = []
|
814 |
+
for j in range(0, latents_input.shape[0], batch_size):
|
815 |
+
noise_pred = transformer( # For an input image (129, 192, 336) (1, 256, 256)
|
816 |
+
latents_input[j : j + batch_size], # [1, 16, 33, 24, 42]
|
817 |
+
t.repeat(batch_size).to(device=device, dtype=dit_dtype), # [1]
|
818 |
+
text_states=prompt_embeds[j : j + batch_size], # [1, 256, 4096]
|
819 |
+
text_mask=prompt_mask[j : j + batch_size], # [1, 256]
|
820 |
+
text_states_2=prompt_embeds_2[j : j + batch_size], # [1, 768]
|
821 |
+
freqs_cos=freqs_cos, # [seqlen, head_dim]
|
822 |
+
freqs_sin=freqs_sin, # [seqlen, head_dim]
|
823 |
+
guidance=guidance_expand[j : j + batch_size], # [1]
|
824 |
+
return_dict=True,
|
825 |
+
)["x"]
|
826 |
+
noise_pred_list.append(noise_pred)
|
827 |
+
noise_pred = torch.cat(noise_pred_list, dim=0)
|
828 |
+
|
829 |
+
# perform classifier free guidance
|
830 |
+
if do_classifier_free_guidance:
|
831 |
+
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
832 |
+
noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
833 |
+
|
834 |
+
# # SkyReels' rescale noise config is omitted for now
|
835 |
+
# if guidance_rescale > 0.0:
|
836 |
+
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
837 |
+
# noise_pred = rescale_noise_cfg(
|
838 |
+
# noise_pred,
|
839 |
+
# noise_pred_cond,
|
840 |
+
# guidance_rescale=self.guidance_rescale,
|
841 |
+
# )
|
842 |
+
|
843 |
+
# compute the previous noisy sample x_t -> x_t-1
|
844 |
+
latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
845 |
+
|
846 |
+
# update progress bar
|
847 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
848 |
+
if progress_bar is not None:
|
849 |
+
progress_bar.update()
|
850 |
+
|
851 |
+
# print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=-1))
|
852 |
+
# print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
|
853 |
+
|
854 |
+
latents = latents.detach().cpu()
|
855 |
+
transformer = None
|
856 |
+
clean_memory_on_device(device)
|
857 |
+
|
858 |
+
# Save samples
|
859 |
+
output_type = args.output_type
|
860 |
+
save_path = args.save_path # if args.save_path_suffix == "" else f"{args.save_path}_{args.save_path_suffix}"
|
861 |
+
os.makedirs(save_path, exist_ok=True)
|
862 |
+
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
|
863 |
+
|
864 |
+
if output_type == "latent" or output_type == "both":
|
865 |
+
# save latent
|
866 |
+
for i, latent in enumerate(latents):
|
867 |
+
latent_path = f"{save_path}/{time_flag}_{i}_{seeds[i]}_latent.safetensors"
|
868 |
+
|
869 |
+
if args.no_metadata:
|
870 |
+
metadata = None
|
871 |
+
else:
|
872 |
+
metadata = {
|
873 |
+
"seeds": f"{seeds[i]}",
|
874 |
+
"prompt": f"{args.prompt}",
|
875 |
+
"height": f"{height}",
|
876 |
+
"width": f"{width}",
|
877 |
+
"video_length": f"{video_length}",
|
878 |
+
"infer_steps": f"{num_inference_steps}",
|
879 |
+
"guidance_scale": f"{args.guidance_scale}",
|
880 |
+
"embedded_cfg_scale": f"{args.embedded_cfg_scale}",
|
881 |
+
}
|
882 |
+
if args.negative_prompt is not None:
|
883 |
+
metadata["negative_prompt"] = f"{args.negative_prompt}"
|
884 |
+
sd = {"latent": latent}
|
885 |
+
save_file(sd, latent_path, metadata=metadata)
|
886 |
+
|
887 |
+
logger.info(f"Latent save to: {latent_path}")
|
888 |
+
if output_type == "video" or output_type == "both":
|
889 |
+
# save video
|
890 |
+
videos = decode_latents(args, latents, device)
|
891 |
+
for i, sample in enumerate(videos):
|
892 |
+
original_name = "" if original_base_names is None else f"_{original_base_names[i]}"
|
893 |
+
sample = sample.unsqueeze(0)
|
894 |
+
video_path = f"{save_path}/{time_flag}_{i}_{seeds[i]}{original_name}.mp4"
|
895 |
+
save_videos_grid(sample, video_path, fps=args.fps)
|
896 |
+
logger.info(f"Sample save to: {video_path}")
|
897 |
+
elif output_type == "images":
|
898 |
+
# save images
|
899 |
+
videos = decode_latents(args, latents, device)
|
900 |
+
for i, sample in enumerate(videos):
|
901 |
+
original_name = "" if original_base_names is None else f"_{original_base_names[i]}"
|
902 |
+
sample = sample.unsqueeze(0)
|
903 |
+
image_name = f"{time_flag}_{i}_{seeds[i]}{original_name}"
|
904 |
+
save_images_grid(sample, save_path, image_name)
|
905 |
+
logger.info(f"Sample images save to: {save_path}/{image_name}")
|
906 |
+
|
907 |
+
logger.info("Done!")
|
908 |
+
|
909 |
+
|
910 |
+
if __name__ == "__main__":
|
911 |
+
main()
|
hv_train.py
ADDED
@@ -0,0 +1,1721 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ast
|
2 |
+
import asyncio
|
3 |
+
from datetime import timedelta
|
4 |
+
import gc
|
5 |
+
import importlib
|
6 |
+
import argparse
|
7 |
+
import math
|
8 |
+
import os
|
9 |
+
import pathlib
|
10 |
+
import re
|
11 |
+
import sys
|
12 |
+
import random
|
13 |
+
import time
|
14 |
+
import json
|
15 |
+
from multiprocessing import Value
|
16 |
+
from typing import Any, Dict, List, Optional
|
17 |
+
import accelerate
|
18 |
+
import numpy as np
|
19 |
+
from packaging.version import Version
|
20 |
+
|
21 |
+
import huggingface_hub
|
22 |
+
import toml
|
23 |
+
|
24 |
+
import torch
|
25 |
+
from tqdm import tqdm
|
26 |
+
from accelerate.utils import set_seed
|
27 |
+
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
|
28 |
+
from safetensors.torch import load_file, save_file
|
29 |
+
import transformers
|
30 |
+
from diffusers.optimization import (
|
31 |
+
SchedulerType as DiffusersSchedulerType,
|
32 |
+
TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION,
|
33 |
+
)
|
34 |
+
from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
35 |
+
|
36 |
+
from dataset import config_utils
|
37 |
+
from hunyuan_model.models import load_transformer, get_rotary_pos_embed_by_shape
|
38 |
+
import hunyuan_model.text_encoder as text_encoder_module
|
39 |
+
from hunyuan_model.vae import load_vae
|
40 |
+
import hunyuan_model.vae as vae_module
|
41 |
+
from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
|
42 |
+
import networks.lora as lora_module
|
43 |
+
from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
|
44 |
+
from dataset.image_video_dataset import ARCHITECTURE_HUNYUAN_VIDEO
|
45 |
+
|
46 |
+
import logging
|
47 |
+
|
48 |
+
from utils import huggingface_utils, model_utils, train_utils, sai_model_spec
|
49 |
+
|
50 |
+
logger = logging.getLogger(__name__)
|
51 |
+
logging.basicConfig(level=logging.INFO)
|
52 |
+
|
53 |
+
|
54 |
+
BASE_MODEL_VERSION_HUNYUAN_VIDEO = "hunyuan_video"
|
55 |
+
|
56 |
+
# TODO make separate file for some functions to commonize with other scripts
|
57 |
+
|
58 |
+
|
59 |
+
def clean_memory_on_device(device: torch.device):
|
60 |
+
r"""
|
61 |
+
Clean memory on the specified device, will be called from training scripts.
|
62 |
+
"""
|
63 |
+
gc.collect()
|
64 |
+
|
65 |
+
# device may "cuda" or "cuda:0", so we need to check the type of device
|
66 |
+
if device.type == "cuda":
|
67 |
+
torch.cuda.empty_cache()
|
68 |
+
if device.type == "xpu":
|
69 |
+
torch.xpu.empty_cache()
|
70 |
+
if device.type == "mps":
|
71 |
+
torch.mps.empty_cache()
|
72 |
+
|
73 |
+
|
74 |
+
# for collate_fn: epoch and step is multiprocessing.Value
|
75 |
+
class collator_class:
|
76 |
+
def __init__(self, epoch, step, dataset):
|
77 |
+
self.current_epoch = epoch
|
78 |
+
self.current_step = step
|
79 |
+
self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing
|
80 |
+
|
81 |
+
def __call__(self, examples):
|
82 |
+
worker_info = torch.utils.data.get_worker_info()
|
83 |
+
# worker_info is None in the main process
|
84 |
+
if worker_info is not None:
|
85 |
+
dataset = worker_info.dataset
|
86 |
+
else:
|
87 |
+
dataset = self.dataset
|
88 |
+
|
89 |
+
# set epoch and step
|
90 |
+
dataset.set_current_epoch(self.current_epoch.value)
|
91 |
+
dataset.set_current_step(self.current_step.value)
|
92 |
+
return examples[0]
|
93 |
+
|
94 |
+
|
95 |
+
def prepare_accelerator(args: argparse.Namespace) -> Accelerator:
|
96 |
+
"""
|
97 |
+
DeepSpeed is not supported in this script currently.
|
98 |
+
"""
|
99 |
+
if args.logging_dir is None:
|
100 |
+
logging_dir = None
|
101 |
+
else:
|
102 |
+
log_prefix = "" if args.log_prefix is None else args.log_prefix
|
103 |
+
logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime())
|
104 |
+
|
105 |
+
if args.log_with is None:
|
106 |
+
if logging_dir is not None:
|
107 |
+
log_with = "tensorboard"
|
108 |
+
else:
|
109 |
+
log_with = None
|
110 |
+
else:
|
111 |
+
log_with = args.log_with
|
112 |
+
if log_with in ["tensorboard", "all"]:
|
113 |
+
if logging_dir is None:
|
114 |
+
raise ValueError(
|
115 |
+
"logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください"
|
116 |
+
)
|
117 |
+
if log_with in ["wandb", "all"]:
|
118 |
+
try:
|
119 |
+
import wandb
|
120 |
+
except ImportError:
|
121 |
+
raise ImportError("No wandb / wandb がインストールされていないようです")
|
122 |
+
if logging_dir is not None:
|
123 |
+
os.makedirs(logging_dir, exist_ok=True)
|
124 |
+
os.environ["WANDB_DIR"] = logging_dir
|
125 |
+
if args.wandb_api_key is not None:
|
126 |
+
wandb.login(key=args.wandb_api_key)
|
127 |
+
|
128 |
+
kwargs_handlers = [
|
129 |
+
(
|
130 |
+
InitProcessGroupKwargs(
|
131 |
+
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
|
132 |
+
init_method=(
|
133 |
+
"env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None
|
134 |
+
),
|
135 |
+
timeout=timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None,
|
136 |
+
)
|
137 |
+
if torch.cuda.device_count() > 1
|
138 |
+
else None
|
139 |
+
),
|
140 |
+
(
|
141 |
+
DistributedDataParallelKwargs(
|
142 |
+
gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph
|
143 |
+
)
|
144 |
+
if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
|
145 |
+
else None
|
146 |
+
),
|
147 |
+
]
|
148 |
+
kwargs_handlers = [i for i in kwargs_handlers if i is not None]
|
149 |
+
|
150 |
+
accelerator = Accelerator(
|
151 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
152 |
+
mixed_precision=args.mixed_precision,
|
153 |
+
log_with=log_with,
|
154 |
+
project_dir=logging_dir,
|
155 |
+
kwargs_handlers=kwargs_handlers,
|
156 |
+
)
|
157 |
+
print("accelerator device:", accelerator.device)
|
158 |
+
return accelerator
|
159 |
+
|
160 |
+
|
161 |
+
def line_to_prompt_dict(line: str) -> dict:
|
162 |
+
# subset of gen_img_diffusers
|
163 |
+
prompt_args = line.split(" --")
|
164 |
+
prompt_dict = {}
|
165 |
+
prompt_dict["prompt"] = prompt_args[0]
|
166 |
+
|
167 |
+
for parg in prompt_args:
|
168 |
+
try:
|
169 |
+
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
170 |
+
if m:
|
171 |
+
prompt_dict["width"] = int(m.group(1))
|
172 |
+
continue
|
173 |
+
|
174 |
+
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
175 |
+
if m:
|
176 |
+
prompt_dict["height"] = int(m.group(1))
|
177 |
+
continue
|
178 |
+
|
179 |
+
m = re.match(r"f (\d+)", parg, re.IGNORECASE)
|
180 |
+
if m:
|
181 |
+
prompt_dict["frame_count"] = int(m.group(1))
|
182 |
+
continue
|
183 |
+
|
184 |
+
m = re.match(r"d (\d+)", parg, re.IGNORECASE)
|
185 |
+
if m:
|
186 |
+
prompt_dict["seed"] = int(m.group(1))
|
187 |
+
continue
|
188 |
+
|
189 |
+
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
190 |
+
if m: # steps
|
191 |
+
prompt_dict["sample_steps"] = max(1, min(1000, int(m.group(1))))
|
192 |
+
continue
|
193 |
+
|
194 |
+
# m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
195 |
+
# if m: # scale
|
196 |
+
# prompt_dict["scale"] = float(m.group(1))
|
197 |
+
# continue
|
198 |
+
# m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
199 |
+
# if m: # negative prompt
|
200 |
+
# prompt_dict["negative_prompt"] = m.group(1)
|
201 |
+
# continue
|
202 |
+
|
203 |
+
except ValueError as ex:
|
204 |
+
logger.error(f"Exception in parsing / 解析エラー: {parg}")
|
205 |
+
logger.error(ex)
|
206 |
+
|
207 |
+
return prompt_dict
|
208 |
+
|
209 |
+
|
210 |
+
def load_prompts(prompt_file: str) -> list[Dict]:
|
211 |
+
# read prompts
|
212 |
+
if prompt_file.endswith(".txt"):
|
213 |
+
with open(prompt_file, "r", encoding="utf-8") as f:
|
214 |
+
lines = f.readlines()
|
215 |
+
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
|
216 |
+
elif prompt_file.endswith(".toml"):
|
217 |
+
with open(prompt_file, "r", encoding="utf-8") as f:
|
218 |
+
data = toml.load(f)
|
219 |
+
prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
|
220 |
+
elif prompt_file.endswith(".json"):
|
221 |
+
with open(prompt_file, "r", encoding="utf-8") as f:
|
222 |
+
prompts = json.load(f)
|
223 |
+
|
224 |
+
# preprocess prompts
|
225 |
+
for i in range(len(prompts)):
|
226 |
+
prompt_dict = prompts[i]
|
227 |
+
if isinstance(prompt_dict, str):
|
228 |
+
prompt_dict = line_to_prompt_dict(prompt_dict)
|
229 |
+
prompts[i] = prompt_dict
|
230 |
+
assert isinstance(prompt_dict, dict)
|
231 |
+
|
232 |
+
# Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
|
233 |
+
prompt_dict["enum"] = i
|
234 |
+
prompt_dict.pop("subset", None)
|
235 |
+
|
236 |
+
return prompts
|
237 |
+
|
238 |
+
|
239 |
+
def compute_density_for_timestep_sampling(
|
240 |
+
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
|
241 |
+
):
|
242 |
+
"""Compute the density for sampling the timesteps when doing SD3 training.
|
243 |
+
|
244 |
+
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
245 |
+
|
246 |
+
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
247 |
+
"""
|
248 |
+
if weighting_scheme == "logit_normal":
|
249 |
+
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
|
250 |
+
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
|
251 |
+
u = torch.nn.functional.sigmoid(u)
|
252 |
+
elif weighting_scheme == "mode":
|
253 |
+
u = torch.rand(size=(batch_size,), device="cpu")
|
254 |
+
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
|
255 |
+
else:
|
256 |
+
u = torch.rand(size=(batch_size,), device="cpu")
|
257 |
+
return u
|
258 |
+
|
259 |
+
|
260 |
+
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
|
261 |
+
sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
|
262 |
+
schedule_timesteps = noise_scheduler.timesteps.to(device)
|
263 |
+
timesteps = timesteps.to(device)
|
264 |
+
|
265 |
+
# if sum([(schedule_timesteps == t) for t in timesteps]) < len(timesteps):
|
266 |
+
if any([(schedule_timesteps == t).sum() == 0 for t in timesteps]):
|
267 |
+
# raise ValueError("Some timesteps are not in the schedule / 一部のtimestepsがスケジュールに含まれていません")
|
268 |
+
# round to nearest timestep
|
269 |
+
logger.warning("Some timesteps are not in the schedule / 一部のtimestepsがスケジュールに含まれていません")
|
270 |
+
step_indices = [torch.argmin(torch.abs(schedule_timesteps - t)).item() for t in timesteps]
|
271 |
+
else:
|
272 |
+
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
273 |
+
|
274 |
+
sigma = sigmas[step_indices].flatten()
|
275 |
+
while len(sigma.shape) < n_dim:
|
276 |
+
sigma = sigma.unsqueeze(-1)
|
277 |
+
return sigma
|
278 |
+
|
279 |
+
|
280 |
+
def compute_loss_weighting_for_sd3(weighting_scheme: str, noise_scheduler, timesteps, device, dtype):
|
281 |
+
"""Computes loss weighting scheme for SD3 training.
|
282 |
+
|
283 |
+
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
284 |
+
|
285 |
+
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
286 |
+
"""
|
287 |
+
if weighting_scheme == "sigma_sqrt" or weighting_scheme == "cosmap":
|
288 |
+
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=5, dtype=dtype)
|
289 |
+
if weighting_scheme == "sigma_sqrt":
|
290 |
+
weighting = (sigmas**-2.0).float()
|
291 |
+
else:
|
292 |
+
bot = 1 - 2 * sigmas + 2 * sigmas**2
|
293 |
+
weighting = 2 / (math.pi * bot)
|
294 |
+
else:
|
295 |
+
weighting = None # torch.ones_like(sigmas)
|
296 |
+
return weighting
|
297 |
+
|
298 |
+
|
299 |
+
class FineTuningTrainer:
|
300 |
+
def __init__(self):
|
301 |
+
pass
|
302 |
+
|
303 |
+
def process_sample_prompts(
|
304 |
+
self,
|
305 |
+
args: argparse.Namespace,
|
306 |
+
accelerator: Accelerator,
|
307 |
+
sample_prompts: str,
|
308 |
+
text_encoder1: str,
|
309 |
+
text_encoder2: str,
|
310 |
+
fp8_llm: bool,
|
311 |
+
):
|
312 |
+
logger.info(f"cache Text Encoder outputs for sample prompt: {sample_prompts}")
|
313 |
+
prompts = load_prompts(sample_prompts)
|
314 |
+
|
315 |
+
def encode_for_text_encoder(text_encoder, is_llm=True):
|
316 |
+
sample_prompts_te_outputs = {} # (prompt) -> (embeds, mask)
|
317 |
+
with accelerator.autocast(), torch.no_grad():
|
318 |
+
for prompt_dict in prompts:
|
319 |
+
for p in [prompt_dict.get("prompt", "")]:
|
320 |
+
if p not in sample_prompts_te_outputs:
|
321 |
+
logger.info(f"cache Text Encoder outputs for prompt: {p}")
|
322 |
+
|
323 |
+
data_type = "video"
|
324 |
+
text_inputs = text_encoder.text2tokens(p, data_type=data_type)
|
325 |
+
|
326 |
+
prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type)
|
327 |
+
sample_prompts_te_outputs[p] = (prompt_outputs.hidden_state, prompt_outputs.attention_mask)
|
328 |
+
|
329 |
+
return sample_prompts_te_outputs
|
330 |
+
|
331 |
+
# Load Text Encoder 1 and encode
|
332 |
+
text_encoder_dtype = torch.float16 if args.text_encoder_dtype is None else model_utils.str_to_dtype(args.text_encoder_dtype)
|
333 |
+
logger.info(f"loading text encoder 1: {text_encoder1}")
|
334 |
+
text_encoder_1 = text_encoder_module.load_text_encoder_1(text_encoder1, accelerator.device, fp8_llm, text_encoder_dtype)
|
335 |
+
|
336 |
+
logger.info("encoding with Text Encoder 1")
|
337 |
+
te_outputs_1 = encode_for_text_encoder(text_encoder_1)
|
338 |
+
del text_encoder_1
|
339 |
+
|
340 |
+
# Load Text Encoder 2 and encode
|
341 |
+
logger.info(f"loading text encoder 2: {text_encoder2}")
|
342 |
+
text_encoder_2 = text_encoder_module.load_text_encoder_2(text_encoder2, accelerator.device, text_encoder_dtype)
|
343 |
+
|
344 |
+
logger.info("encoding with Text Encoder 2")
|
345 |
+
te_outputs_2 = encode_for_text_encoder(text_encoder_2, is_llm=False)
|
346 |
+
del text_encoder_2
|
347 |
+
|
348 |
+
# prepare sample parameters
|
349 |
+
sample_parameters = []
|
350 |
+
for prompt_dict in prompts:
|
351 |
+
prompt_dict_copy = prompt_dict.copy()
|
352 |
+
p = prompt_dict.get("prompt", "")
|
353 |
+
prompt_dict_copy["llm_embeds"] = te_outputs_1[p][0]
|
354 |
+
prompt_dict_copy["llm_mask"] = te_outputs_1[p][1]
|
355 |
+
prompt_dict_copy["clipL_embeds"] = te_outputs_2[p][0]
|
356 |
+
prompt_dict_copy["clipL_mask"] = te_outputs_2[p][1]
|
357 |
+
sample_parameters.append(prompt_dict_copy)
|
358 |
+
|
359 |
+
clean_memory_on_device(accelerator.device)
|
360 |
+
|
361 |
+
return sample_parameters
|
362 |
+
|
363 |
+
def get_optimizer(self, args, trainable_params: list[torch.nn.Parameter]) -> tuple[str, str, torch.optim.Optimizer]:
|
364 |
+
# adamw, adamw8bit, adafactor
|
365 |
+
|
366 |
+
optimizer_type = args.optimizer_type.lower()
|
367 |
+
|
368 |
+
# split optimizer_type and optimizer_args
|
369 |
+
optimizer_kwargs = {}
|
370 |
+
if args.optimizer_args is not None and len(args.optimizer_args) > 0:
|
371 |
+
for arg in args.optimizer_args:
|
372 |
+
key, value = arg.split("=")
|
373 |
+
value = ast.literal_eval(value)
|
374 |
+
optimizer_kwargs[key] = value
|
375 |
+
|
376 |
+
lr = args.learning_rate
|
377 |
+
optimizer = None
|
378 |
+
optimizer_class = None
|
379 |
+
|
380 |
+
if optimizer_type.endswith("8bit".lower()):
|
381 |
+
try:
|
382 |
+
import bitsandbytes as bnb
|
383 |
+
except ImportError:
|
384 |
+
raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
|
385 |
+
|
386 |
+
if optimizer_type == "AdamW8bit".lower():
|
387 |
+
logger.info(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
|
388 |
+
optimizer_class = bnb.optim.AdamW8bit
|
389 |
+
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
390 |
+
|
391 |
+
elif optimizer_type == "Adafactor".lower():
|
392 |
+
# Adafactor: check relative_step and warmup_init
|
393 |
+
if "relative_step" not in optimizer_kwargs:
|
394 |
+
optimizer_kwargs["relative_step"] = True # default
|
395 |
+
if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False):
|
396 |
+
logger.info(
|
397 |
+
f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします"
|
398 |
+
)
|
399 |
+
optimizer_kwargs["relative_step"] = True
|
400 |
+
logger.info(f"use Adafactor optimizer | {optimizer_kwargs}")
|
401 |
+
|
402 |
+
if optimizer_kwargs["relative_step"]:
|
403 |
+
logger.info(f"relative_step is true / relative_stepがtrueです")
|
404 |
+
if lr != 0.0:
|
405 |
+
logger.warning(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます")
|
406 |
+
args.learning_rate = None
|
407 |
+
|
408 |
+
if args.lr_scheduler != "adafactor":
|
409 |
+
logger.info(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します")
|
410 |
+
args.lr_scheduler = f"adafactor:{lr}" # ちょっと微妙だけど
|
411 |
+
|
412 |
+
lr = None
|
413 |
+
else:
|
414 |
+
if args.max_grad_norm != 0.0:
|
415 |
+
logger.warning(
|
416 |
+
f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません"
|
417 |
+
)
|
418 |
+
if args.lr_scheduler != "constant_with_warmup":
|
419 |
+
logger.warning(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません")
|
420 |
+
if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0:
|
421 |
+
logger.warning(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません")
|
422 |
+
|
423 |
+
optimizer_class = transformers.optimization.Adafactor
|
424 |
+
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
425 |
+
|
426 |
+
elif optimizer_type == "AdamW".lower():
|
427 |
+
logger.info(f"use AdamW optimizer | {optimizer_kwargs}")
|
428 |
+
optimizer_class = torch.optim.AdamW
|
429 |
+
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
430 |
+
|
431 |
+
if optimizer is None:
|
432 |
+
# 任意のoptimizerを使う
|
433 |
+
case_sensitive_optimizer_type = args.optimizer_type # not lower
|
434 |
+
logger.info(f"use {case_sensitive_optimizer_type} | {optimizer_kwargs}")
|
435 |
+
|
436 |
+
if "." not in case_sensitive_optimizer_type: # from torch.optim
|
437 |
+
optimizer_module = torch.optim
|
438 |
+
else: # from other library
|
439 |
+
values = case_sensitive_optimizer_type.split(".")
|
440 |
+
optimizer_module = importlib.import_module(".".join(values[:-1]))
|
441 |
+
case_sensitive_optimizer_type = values[-1]
|
442 |
+
|
443 |
+
optimizer_class = getattr(optimizer_module, case_sensitive_optimizer_type)
|
444 |
+
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
445 |
+
|
446 |
+
# for logging
|
447 |
+
optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
|
448 |
+
optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
|
449 |
+
|
450 |
+
# get train and eval functions
|
451 |
+
if hasattr(optimizer, "train") and callable(optimizer.train):
|
452 |
+
train_fn = optimizer.train
|
453 |
+
eval_fn = optimizer.eval
|
454 |
+
else:
|
455 |
+
train_fn = lambda: None
|
456 |
+
eval_fn = lambda: None
|
457 |
+
|
458 |
+
return optimizer_name, optimizer_args, optimizer, train_fn, eval_fn
|
459 |
+
|
460 |
+
def is_schedulefree_optimizer(self, optimizer: torch.optim.Optimizer, args: argparse.Namespace) -> bool:
|
461 |
+
return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper
|
462 |
+
|
463 |
+
def get_dummy_scheduler(optimizer: torch.optim.Optimizer) -> Any:
|
464 |
+
# dummy scheduler for schedulefree optimizer. supports only empty step(), get_last_lr() and optimizers.
|
465 |
+
# this scheduler is used for logging only.
|
466 |
+
# this isn't be wrapped by accelerator because of this class is not a subclass of torch.optim.lr_scheduler._LRScheduler
|
467 |
+
class DummyScheduler:
|
468 |
+
def __init__(self, optimizer: torch.optim.Optimizer):
|
469 |
+
self.optimizer = optimizer
|
470 |
+
|
471 |
+
def step(self):
|
472 |
+
pass
|
473 |
+
|
474 |
+
def get_last_lr(self):
|
475 |
+
return [group["lr"] for group in self.optimizer.param_groups]
|
476 |
+
|
477 |
+
return DummyScheduler(optimizer)
|
478 |
+
|
479 |
+
def get_scheduler(self, args, optimizer: torch.optim.Optimizer, num_processes: int):
|
480 |
+
"""
|
481 |
+
Unified API to get any scheduler from its name.
|
482 |
+
"""
|
483 |
+
# if schedulefree optimizer, return dummy scheduler
|
484 |
+
if self.is_schedulefree_optimizer(optimizer, args):
|
485 |
+
return self.get_dummy_scheduler(optimizer)
|
486 |
+
|
487 |
+
name = args.lr_scheduler
|
488 |
+
num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps
|
489 |
+
num_warmup_steps: Optional[int] = (
|
490 |
+
int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps
|
491 |
+
)
|
492 |
+
num_decay_steps: Optional[int] = (
|
493 |
+
int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps
|
494 |
+
)
|
495 |
+
num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
|
496 |
+
num_cycles = args.lr_scheduler_num_cycles
|
497 |
+
power = args.lr_scheduler_power
|
498 |
+
timescale = args.lr_scheduler_timescale
|
499 |
+
min_lr_ratio = args.lr_scheduler_min_lr_ratio
|
500 |
+
|
501 |
+
lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs
|
502 |
+
if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
|
503 |
+
for arg in args.lr_scheduler_args:
|
504 |
+
key, value = arg.split("=")
|
505 |
+
value = ast.literal_eval(value)
|
506 |
+
lr_scheduler_kwargs[key] = value
|
507 |
+
|
508 |
+
def wrap_check_needless_num_warmup_steps(return_vals):
|
509 |
+
if num_warmup_steps is not None and num_warmup_steps != 0:
|
510 |
+
raise ValueError(f"{name} does not require `num_warmup_steps`. Set None or 0.")
|
511 |
+
return return_vals
|
512 |
+
|
513 |
+
# using any lr_scheduler from other library
|
514 |
+
if args.lr_scheduler_type:
|
515 |
+
lr_scheduler_type = args.lr_scheduler_type
|
516 |
+
logger.info(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler")
|
517 |
+
if "." not in lr_scheduler_type: # default to use torch.optim
|
518 |
+
lr_scheduler_module = torch.optim.lr_scheduler
|
519 |
+
else:
|
520 |
+
values = lr_scheduler_type.split(".")
|
521 |
+
lr_scheduler_module = importlib.import_module(".".join(values[:-1]))
|
522 |
+
lr_scheduler_type = values[-1]
|
523 |
+
lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type)
|
524 |
+
lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
|
525 |
+
return lr_scheduler
|
526 |
+
|
527 |
+
if name.startswith("adafactor"):
|
528 |
+
assert (
|
529 |
+
type(optimizer) == transformers.optimization.Adafactor
|
530 |
+
), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
|
531 |
+
initial_lr = float(name.split(":")[1])
|
532 |
+
# logger.info(f"adafactor scheduler init lr {initial_lr}")
|
533 |
+
return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))
|
534 |
+
|
535 |
+
if name == DiffusersSchedulerType.PIECEWISE_CONSTANT.value:
|
536 |
+
name = DiffusersSchedulerType(name)
|
537 |
+
schedule_func = DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name]
|
538 |
+
return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs
|
539 |
+
|
540 |
+
name = SchedulerType(name)
|
541 |
+
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
542 |
+
|
543 |
+
if name == SchedulerType.CONSTANT:
|
544 |
+
return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))
|
545 |
+
|
546 |
+
# All other schedulers require `num_warmup_steps`
|
547 |
+
if num_warmup_steps is None:
|
548 |
+
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
549 |
+
|
550 |
+
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
551 |
+
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs)
|
552 |
+
|
553 |
+
if name == SchedulerType.INVERSE_SQRT:
|
554 |
+
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, timescale=timescale, **lr_scheduler_kwargs)
|
555 |
+
|
556 |
+
# All other schedulers require `num_training_steps`
|
557 |
+
if num_training_steps is None:
|
558 |
+
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
559 |
+
|
560 |
+
if name == SchedulerType.COSINE_WITH_RESTARTS:
|
561 |
+
return schedule_func(
|
562 |
+
optimizer,
|
563 |
+
num_warmup_steps=num_warmup_steps,
|
564 |
+
num_training_steps=num_training_steps,
|
565 |
+
num_cycles=num_cycles,
|
566 |
+
**lr_scheduler_kwargs,
|
567 |
+
)
|
568 |
+
|
569 |
+
if name == SchedulerType.POLYNOMIAL:
|
570 |
+
return schedule_func(
|
571 |
+
optimizer,
|
572 |
+
num_warmup_steps=num_warmup_steps,
|
573 |
+
num_training_steps=num_training_steps,
|
574 |
+
power=power,
|
575 |
+
**lr_scheduler_kwargs,
|
576 |
+
)
|
577 |
+
|
578 |
+
if name == SchedulerType.COSINE_WITH_MIN_LR:
|
579 |
+
return schedule_func(
|
580 |
+
optimizer,
|
581 |
+
num_warmup_steps=num_warmup_steps,
|
582 |
+
num_training_steps=num_training_steps,
|
583 |
+
num_cycles=num_cycles / 2,
|
584 |
+
min_lr_rate=min_lr_ratio,
|
585 |
+
**lr_scheduler_kwargs,
|
586 |
+
)
|
587 |
+
|
588 |
+
# these schedulers do not require `num_decay_steps`
|
589 |
+
if name == SchedulerType.LINEAR or name == SchedulerType.COSINE:
|
590 |
+
return schedule_func(
|
591 |
+
optimizer,
|
592 |
+
num_warmup_steps=num_warmup_steps,
|
593 |
+
num_training_steps=num_training_steps,
|
594 |
+
**lr_scheduler_kwargs,
|
595 |
+
)
|
596 |
+
|
597 |
+
# All other schedulers require `num_decay_steps`
|
598 |
+
if num_decay_steps is None:
|
599 |
+
raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.")
|
600 |
+
if name == SchedulerType.WARMUP_STABLE_DECAY:
|
601 |
+
return schedule_func(
|
602 |
+
optimizer,
|
603 |
+
num_warmup_steps=num_warmup_steps,
|
604 |
+
num_stable_steps=num_stable_steps,
|
605 |
+
num_decay_steps=num_decay_steps,
|
606 |
+
num_cycles=num_cycles / 2,
|
607 |
+
min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0,
|
608 |
+
**lr_scheduler_kwargs,
|
609 |
+
)
|
610 |
+
|
611 |
+
return schedule_func(
|
612 |
+
optimizer,
|
613 |
+
num_warmup_steps=num_warmup_steps,
|
614 |
+
num_training_steps=num_training_steps,
|
615 |
+
num_decay_steps=num_decay_steps,
|
616 |
+
**lr_scheduler_kwargs,
|
617 |
+
)
|
618 |
+
|
619 |
+
def resume_from_local_or_hf_if_specified(self, accelerator: Accelerator, args: argparse.Namespace) -> bool:
|
620 |
+
if not args.resume:
|
621 |
+
return False
|
622 |
+
|
623 |
+
if not args.resume_from_huggingface:
|
624 |
+
logger.info(f"resume training from local state: {args.resume}")
|
625 |
+
accelerator.load_state(args.resume)
|
626 |
+
return True
|
627 |
+
|
628 |
+
logger.info(f"resume training from huggingface state: {args.resume}")
|
629 |
+
repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1]
|
630 |
+
path_in_repo = "/".join(args.resume.split("/")[2:])
|
631 |
+
revision = None
|
632 |
+
repo_type = None
|
633 |
+
if ":" in path_in_repo:
|
634 |
+
divided = path_in_repo.split(":")
|
635 |
+
if len(divided) == 2:
|
636 |
+
path_in_repo, revision = divided
|
637 |
+
repo_type = "model"
|
638 |
+
else:
|
639 |
+
path_in_repo, revision, repo_type = divided
|
640 |
+
logger.info(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}")
|
641 |
+
|
642 |
+
list_files = huggingface_utils.list_dir(
|
643 |
+
repo_id=repo_id,
|
644 |
+
subfolder=path_in_repo,
|
645 |
+
revision=revision,
|
646 |
+
token=args.huggingface_token,
|
647 |
+
repo_type=repo_type,
|
648 |
+
)
|
649 |
+
|
650 |
+
async def download(filename) -> str:
|
651 |
+
def task():
|
652 |
+
return huggingface_hub.hf_hub_download(
|
653 |
+
repo_id=repo_id,
|
654 |
+
filename=filename,
|
655 |
+
revision=revision,
|
656 |
+
repo_type=repo_type,
|
657 |
+
token=args.huggingface_token,
|
658 |
+
)
|
659 |
+
|
660 |
+
return await asyncio.get_event_loop().run_in_executor(None, task)
|
661 |
+
|
662 |
+
loop = asyncio.get_event_loop()
|
663 |
+
results = loop.run_until_complete(asyncio.gather(*[download(filename=filename.rfilename) for filename in list_files]))
|
664 |
+
if len(results) == 0:
|
665 |
+
raise ValueError(
|
666 |
+
"No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした"
|
667 |
+
)
|
668 |
+
dirname = os.path.dirname(results[0])
|
669 |
+
accelerator.load_state(dirname)
|
670 |
+
|
671 |
+
return True
|
672 |
+
|
673 |
+
def sample_images(self, accelerator, args, epoch, global_step, device, vae, transformer, sample_parameters):
|
674 |
+
pass
|
675 |
+
|
676 |
+
def get_noisy_model_input_and_timesteps(
|
677 |
+
self,
|
678 |
+
args: argparse.Namespace,
|
679 |
+
noise: torch.Tensor,
|
680 |
+
latents: torch.Tensor,
|
681 |
+
noise_scheduler: FlowMatchDiscreteScheduler,
|
682 |
+
device: torch.device,
|
683 |
+
dtype: torch.dtype,
|
684 |
+
):
|
685 |
+
batch_size = noise.shape[0]
|
686 |
+
|
687 |
+
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid" or args.timestep_sampling == "shift":
|
688 |
+
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
|
689 |
+
# Simple random t-based noise sampling
|
690 |
+
if args.timestep_sampling == "sigmoid":
|
691 |
+
t = torch.sigmoid(args.sigmoid_scale * torch.randn((batch_size,), device=device))
|
692 |
+
else:
|
693 |
+
t = torch.rand((batch_size,), device=device)
|
694 |
+
|
695 |
+
elif args.timestep_sampling == "shift":
|
696 |
+
shift = args.discrete_flow_shift
|
697 |
+
logits_norm = torch.randn(batch_size, device=device)
|
698 |
+
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
|
699 |
+
t = logits_norm.sigmoid()
|
700 |
+
t = (t * shift) / (1 + (shift - 1) * t)
|
701 |
+
|
702 |
+
t_min = args.min_timestep if args.min_timestep is not None else 0
|
703 |
+
t_max = args.max_timestep if args.max_timestep is not None else 1000.0
|
704 |
+
t_min /= 1000.0
|
705 |
+
t_max /= 1000.0
|
706 |
+
t = t * (t_max - t_min) + t_min # scale to [t_min, t_max], default [0, 1]
|
707 |
+
|
708 |
+
timesteps = t * 1000.0
|
709 |
+
t = t.view(-1, 1, 1, 1, 1)
|
710 |
+
noisy_model_input = (1 - t) * latents + t * noise
|
711 |
+
|
712 |
+
timesteps += 1 # 1 to 1000
|
713 |
+
else:
|
714 |
+
# Sample a random timestep for each image
|
715 |
+
# for weighting schemes where we sample timesteps non-uniformly
|
716 |
+
u = compute_density_for_timestep_sampling(
|
717 |
+
weighting_scheme=args.weighting_scheme,
|
718 |
+
batch_size=batch_size,
|
719 |
+
logit_mean=args.logit_mean,
|
720 |
+
logit_std=args.logit_std,
|
721 |
+
mode_scale=args.mode_scale,
|
722 |
+
)
|
723 |
+
# indices = (u * noise_scheduler.config.num_train_timesteps).long()
|
724 |
+
t_min = args.min_timestep if args.min_timestep is not None else 0
|
725 |
+
t_max = args.max_timestep if args.max_timestep is not None else 1000
|
726 |
+
indices = (u * (t_max - t_min) + t_min).long()
|
727 |
+
|
728 |
+
timesteps = noise_scheduler.timesteps[indices].to(device=device) # 1 to 1000
|
729 |
+
|
730 |
+
# Add noise according to flow matching.
|
731 |
+
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
|
732 |
+
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
|
733 |
+
|
734 |
+
return noisy_model_input, timesteps
|
735 |
+
|
736 |
+
def train(self, args):
|
737 |
+
if args.seed is None:
|
738 |
+
args.seed = random.randint(0, 2**32)
|
739 |
+
set_seed(args.seed)
|
740 |
+
|
741 |
+
# Load dataset config
|
742 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer())
|
743 |
+
logger.info(f"Load dataset config from {args.dataset_config}")
|
744 |
+
user_config = config_utils.load_user_config(args.dataset_config)
|
745 |
+
blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_HUNYUAN_VIDEO)
|
746 |
+
train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group, training=True)
|
747 |
+
|
748 |
+
current_epoch = Value("i", 0)
|
749 |
+
current_step = Value("i", 0)
|
750 |
+
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
751 |
+
collator = collator_class(current_epoch, current_step, ds_for_collator)
|
752 |
+
|
753 |
+
# prepare accelerator
|
754 |
+
logger.info("preparing accelerator")
|
755 |
+
accelerator = prepare_accelerator(args)
|
756 |
+
is_main_process = accelerator.is_main_process
|
757 |
+
|
758 |
+
# prepare dtype
|
759 |
+
weight_dtype = torch.float32
|
760 |
+
if args.mixed_precision == "fp16":
|
761 |
+
weight_dtype = torch.float16
|
762 |
+
elif args.mixed_precision == "bf16":
|
763 |
+
weight_dtype = torch.bfloat16
|
764 |
+
|
765 |
+
# HunyuanVideo specific
|
766 |
+
vae_dtype = torch.float16 if args.vae_dtype is None else model_utils.str_to_dtype(args.vae_dtype)
|
767 |
+
|
768 |
+
# get embedding for sampling images
|
769 |
+
sample_parameters = vae = None
|
770 |
+
if args.sample_prompts:
|
771 |
+
sample_parameters = self.process_sample_prompts(
|
772 |
+
args, accelerator, args.sample_prompts, args.text_encoder1, args.text_encoder2, args.fp8_llm
|
773 |
+
)
|
774 |
+
|
775 |
+
# Load VAE model for sampling images: VAE is loaded to cpu to save gpu memory
|
776 |
+
vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device="cpu", vae_path=args.vae)
|
777 |
+
vae.requires_grad_(False)
|
778 |
+
vae.eval()
|
779 |
+
|
780 |
+
if args.vae_chunk_size is not None:
|
781 |
+
vae.set_chunk_size_for_causal_conv_3d(args.vae_chunk_size)
|
782 |
+
logger.info(f"Set chunk_size to {args.vae_chunk_size} for CausalConv3d in VAE")
|
783 |
+
if args.vae_spatial_tile_sample_min_size is not None:
|
784 |
+
vae.enable_spatial_tiling(True)
|
785 |
+
vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size
|
786 |
+
vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8
|
787 |
+
elif args.vae_tiling:
|
788 |
+
vae.enable_spatial_tiling(True)
|
789 |
+
|
790 |
+
# load DiT model
|
791 |
+
blocks_to_swap = args.blocks_to_swap if args.blocks_to_swap else 0
|
792 |
+
loading_device = "cpu" if blocks_to_swap > 0 else accelerator.device
|
793 |
+
|
794 |
+
logger.info(f"Loading DiT model from {args.dit}")
|
795 |
+
if args.sdpa:
|
796 |
+
attn_mode = "torch"
|
797 |
+
elif args.flash_attn:
|
798 |
+
attn_mode = "flash"
|
799 |
+
elif args.sage_attn:
|
800 |
+
attn_mode = "sageattn"
|
801 |
+
elif args.xformers:
|
802 |
+
attn_mode = "xformers"
|
803 |
+
else:
|
804 |
+
raise ValueError(
|
805 |
+
f"either --sdpa, --flash-attn, --sage-attn or --xformers must be specified / --sdpa, --flash-attn, --sage-attn, --xformersのいずれかを指定してください"
|
806 |
+
)
|
807 |
+
transformer = load_transformer(
|
808 |
+
args.dit, attn_mode, args.split_attn, loading_device, None, in_channels=args.dit_in_channels
|
809 |
+
) # load as is
|
810 |
+
|
811 |
+
if blocks_to_swap > 0:
|
812 |
+
logger.info(f"enable swap {blocks_to_swap} blocks to CPU from device: {accelerator.device}")
|
813 |
+
transformer.enable_block_swap(blocks_to_swap, accelerator.device, supports_backward=True)
|
814 |
+
transformer.move_to_device_except_swap_blocks(accelerator.device)
|
815 |
+
if args.img_in_txt_in_offloading:
|
816 |
+
logger.info("Enable offloading img_in and txt_in to CPU")
|
817 |
+
transformer.enable_img_in_txt_in_offloading()
|
818 |
+
|
819 |
+
if args.gradient_checkpointing:
|
820 |
+
transformer.enable_gradient_checkpointing()
|
821 |
+
|
822 |
+
# prepare optimizer, data loader etc.
|
823 |
+
accelerator.print("prepare optimizer, data loader etc.")
|
824 |
+
|
825 |
+
transformer.requires_grad_(False)
|
826 |
+
if accelerator.is_main_process:
|
827 |
+
accelerator.print(f"Trainable modules '{args.trainable_modules}'.")
|
828 |
+
for name, param in transformer.named_parameters():
|
829 |
+
for trainable_module_name in args.trainable_modules:
|
830 |
+
if trainable_module_name in name:
|
831 |
+
param.requires_grad = True
|
832 |
+
break
|
833 |
+
|
834 |
+
total_params = list(transformer.parameters())
|
835 |
+
trainable_params = list(filter(lambda p: p.requires_grad, transformer.parameters()))
|
836 |
+
logger.info(
|
837 |
+
f"number of trainable parameters: {sum(p.numel() for p in trainable_params) / 1e6} M, total paramters: {sum(p.numel() for p in total_params) / 1e6} M"
|
838 |
+
)
|
839 |
+
optimizer_name, optimizer_args, optimizer, optimizer_train_fn, optimizer_eval_fn = self.get_optimizer(
|
840 |
+
args, trainable_params
|
841 |
+
)
|
842 |
+
|
843 |
+
# prepare dataloader
|
844 |
+
|
845 |
+
# num workers for data loader: if 0, persistent_workers is not available
|
846 |
+
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
847 |
+
|
848 |
+
train_dataloader = torch.utils.data.DataLoader(
|
849 |
+
train_dataset_group,
|
850 |
+
batch_size=1,
|
851 |
+
shuffle=True,
|
852 |
+
collate_fn=collator,
|
853 |
+
num_workers=n_workers,
|
854 |
+
persistent_workers=args.persistent_data_loader_workers,
|
855 |
+
)
|
856 |
+
|
857 |
+
# calculate max_train_steps
|
858 |
+
if args.max_train_epochs is not None:
|
859 |
+
args.max_train_steps = args.max_train_epochs * math.ceil(
|
860 |
+
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
861 |
+
)
|
862 |
+
accelerator.print(
|
863 |
+
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
864 |
+
)
|
865 |
+
|
866 |
+
# send max_train_steps to train_dataset_group
|
867 |
+
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
868 |
+
|
869 |
+
# prepare lr_scheduler
|
870 |
+
lr_scheduler = self.get_scheduler(args, optimizer, accelerator.num_processes)
|
871 |
+
|
872 |
+
# prepare training model. accelerator does some magic here
|
873 |
+
|
874 |
+
# experimental feature: train the model with gradients in fp16/bf16
|
875 |
+
dit_dtype = torch.float32
|
876 |
+
if args.full_fp16:
|
877 |
+
assert (
|
878 |
+
args.mixed_precision == "fp16"
|
879 |
+
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
880 |
+
accelerator.print("enable full fp16 training.")
|
881 |
+
dit_weight_dtype = torch.float16
|
882 |
+
elif args.full_bf16:
|
883 |
+
assert (
|
884 |
+
args.mixed_precision == "bf16"
|
885 |
+
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
|
886 |
+
accelerator.print("enable full bf16 training.")
|
887 |
+
dit_weight_dtype = torch.bfloat16
|
888 |
+
else:
|
889 |
+
dit_weight_dtype = torch.float32
|
890 |
+
|
891 |
+
# TODO add fused optimizer and stochastic rounding
|
892 |
+
|
893 |
+
# cast model to dit_weight_dtype
|
894 |
+
# if dit_dtype != dit_weight_dtype:
|
895 |
+
logger.info(f"casting model to {dit_weight_dtype}")
|
896 |
+
transformer.to(dit_weight_dtype)
|
897 |
+
|
898 |
+
if blocks_to_swap > 0:
|
899 |
+
transformer = accelerator.prepare(transformer, device_placement=[not blocks_to_swap > 0])
|
900 |
+
accelerator.unwrap_model(transformer).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
|
901 |
+
accelerator.unwrap_model(transformer).prepare_block_swap_before_forward()
|
902 |
+
else:
|
903 |
+
transformer = accelerator.prepare(transformer)
|
904 |
+
|
905 |
+
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
|
906 |
+
|
907 |
+
transformer.train()
|
908 |
+
|
909 |
+
if args.full_fp16:
|
910 |
+
# patch accelerator for fp16 training
|
911 |
+
# def patch_accelerator_for_fp16_training(accelerator):
|
912 |
+
org_unscale_grads = accelerator.scaler._unscale_grads_
|
913 |
+
|
914 |
+
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
|
915 |
+
return org_unscale_grads(optimizer, inv_scale, found_inf, True)
|
916 |
+
|
917 |
+
accelerator.scaler._unscale_grads_ = _unscale_grads_replacer
|
918 |
+
|
919 |
+
# resume from local or huggingface. accelerator.step is set
|
920 |
+
self.resume_from_local_or_hf_if_specified(accelerator, args) # accelerator.load_state(args.resume)
|
921 |
+
|
922 |
+
# epoch数を計算する
|
923 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
924 |
+
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
925 |
+
|
926 |
+
# 学習���る
|
927 |
+
# total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
928 |
+
|
929 |
+
accelerator.print("running training / 学習開始")
|
930 |
+
accelerator.print(f" num train items / 学習画像、動画数: {train_dataset_group.num_train_items}")
|
931 |
+
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
932 |
+
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
933 |
+
accelerator.print(
|
934 |
+
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
|
935 |
+
)
|
936 |
+
# accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
937 |
+
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
938 |
+
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
939 |
+
|
940 |
+
if accelerator.is_main_process:
|
941 |
+
init_kwargs = {}
|
942 |
+
if args.wandb_run_name:
|
943 |
+
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
944 |
+
if args.log_tracker_config is not None:
|
945 |
+
init_kwargs = toml.load(args.log_tracker_config)
|
946 |
+
accelerator.init_trackers(
|
947 |
+
"hunyuan_video_ft" if args.log_tracker_name is None else args.log_tracker_name,
|
948 |
+
config=train_utils.get_sanitized_config_or_none(args),
|
949 |
+
init_kwargs=init_kwargs,
|
950 |
+
)
|
951 |
+
|
952 |
+
# TODO skip until initial step
|
953 |
+
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
954 |
+
|
955 |
+
epoch_to_start = 0
|
956 |
+
global_step = 0
|
957 |
+
noise_scheduler = FlowMatchDiscreteScheduler(shift=args.discrete_flow_shift, reverse=True, solver="euler")
|
958 |
+
|
959 |
+
loss_recorder = train_utils.LossRecorder()
|
960 |
+
del train_dataset_group
|
961 |
+
|
962 |
+
# function for saving/removing
|
963 |
+
def save_model(ckpt_name: str, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
|
964 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
965 |
+
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
966 |
+
|
967 |
+
accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
|
968 |
+
|
969 |
+
title = args.metadata_title if args.metadata_title is not None else args.output_name
|
970 |
+
if args.min_timestep is not None or args.max_timestep is not None:
|
971 |
+
min_time_step = args.min_timestep if args.min_timestep is not None else 0
|
972 |
+
max_time_step = args.max_timestep if args.max_timestep is not None else 1000
|
973 |
+
md_timesteps = (min_time_step, max_time_step)
|
974 |
+
else:
|
975 |
+
md_timesteps = None
|
976 |
+
|
977 |
+
sai_metadata = sai_model_spec.build_metadata(
|
978 |
+
None,
|
979 |
+
ARCHITECTURE_HUNYUAN_VIDEO,
|
980 |
+
time.time(),
|
981 |
+
title,
|
982 |
+
None,
|
983 |
+
args.metadata_author,
|
984 |
+
args.metadata_description,
|
985 |
+
args.metadata_license,
|
986 |
+
args.metadata_tags,
|
987 |
+
timesteps=md_timesteps,
|
988 |
+
is_lora=False,
|
989 |
+
)
|
990 |
+
|
991 |
+
save_file(unwrapped_nw.state_dict(), ckpt_file, sai_metadata)
|
992 |
+
if args.huggingface_repo_id is not None:
|
993 |
+
huggingface_utils.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
|
994 |
+
|
995 |
+
def remove_model(old_ckpt_name):
|
996 |
+
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
997 |
+
if os.path.exists(old_ckpt_file):
|
998 |
+
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
|
999 |
+
os.remove(old_ckpt_file)
|
1000 |
+
|
1001 |
+
# For --sample_at_first
|
1002 |
+
optimizer_eval_fn()
|
1003 |
+
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, transformer, sample_parameters)
|
1004 |
+
optimizer_train_fn()
|
1005 |
+
if len(accelerator.trackers) > 0:
|
1006 |
+
# log empty object to commit the sample images to wandb
|
1007 |
+
accelerator.log({}, step=0)
|
1008 |
+
|
1009 |
+
# training loop
|
1010 |
+
|
1011 |
+
# log device and dtype for each model
|
1012 |
+
logger.info(f"DiT dtype: {transformer.dtype}, device: {transformer.device}")
|
1013 |
+
|
1014 |
+
clean_memory_on_device(accelerator.device)
|
1015 |
+
|
1016 |
+
pos_embed_cache = {}
|
1017 |
+
|
1018 |
+
for epoch in range(epoch_to_start, num_train_epochs):
|
1019 |
+
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
1020 |
+
current_epoch.value = epoch + 1
|
1021 |
+
|
1022 |
+
for step, batch in enumerate(train_dataloader):
|
1023 |
+
latents, llm_embeds, llm_mask, clip_embeds = batch
|
1024 |
+
bsz = latents.shape[0]
|
1025 |
+
current_step.value = global_step
|
1026 |
+
|
1027 |
+
with accelerator.accumulate(transformer):
|
1028 |
+
latents = latents * vae_module.SCALING_FACTOR
|
1029 |
+
|
1030 |
+
# Sample noise that we'll add to the latents
|
1031 |
+
noise = torch.randn_like(latents)
|
1032 |
+
|
1033 |
+
# calculate model input and timesteps
|
1034 |
+
noisy_model_input, timesteps = self.get_noisy_model_input_and_timesteps(
|
1035 |
+
args, noise, latents, noise_scheduler, accelerator.device, dit_dtype
|
1036 |
+
)
|
1037 |
+
|
1038 |
+
weighting = compute_loss_weighting_for_sd3(
|
1039 |
+
args.weighting_scheme, noise_scheduler, timesteps, accelerator.device, dit_dtype
|
1040 |
+
)
|
1041 |
+
|
1042 |
+
# ensure guidance_scale in args is float
|
1043 |
+
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) # , dtype=dit_dtype)
|
1044 |
+
|
1045 |
+
# ensure the hidden state will require grad
|
1046 |
+
if args.gradient_checkpointing:
|
1047 |
+
noisy_model_input.requires_grad_(True)
|
1048 |
+
guidance_vec.requires_grad_(True)
|
1049 |
+
|
1050 |
+
pos_emb_shape = latents.shape[1:]
|
1051 |
+
if pos_emb_shape not in pos_embed_cache:
|
1052 |
+
freqs_cos, freqs_sin = get_rotary_pos_embed_by_shape(
|
1053 |
+
accelerator.unwrap_model(transformer), latents.shape[2:]
|
1054 |
+
)
|
1055 |
+
# freqs_cos = freqs_cos.to(device=accelerator.device, dtype=dit_dtype)
|
1056 |
+
# freqs_sin = freqs_sin.to(device=accelerator.device, dtype=dit_dtype)
|
1057 |
+
pos_embed_cache[pos_emb_shape] = (freqs_cos, freqs_sin)
|
1058 |
+
else:
|
1059 |
+
freqs_cos, freqs_sin = pos_embed_cache[pos_emb_shape]
|
1060 |
+
|
1061 |
+
# call DiT
|
1062 |
+
latents = latents.to(device=accelerator.device, dtype=dit_dtype)
|
1063 |
+
noisy_model_input = noisy_model_input.to(device=accelerator.device, dtype=dit_dtype)
|
1064 |
+
# timesteps = timesteps.to(device=accelerator.device, dtype=dit_dtype)
|
1065 |
+
# llm_embeds = llm_embeds.to(device=accelerator.device, dtype=dit_dtype)
|
1066 |
+
# llm_mask = llm_mask.to(device=accelerator.device)
|
1067 |
+
# clip_embeds = clip_embeds.to(device=accelerator.device, dtype=dit_dtype)
|
1068 |
+
with accelerator.autocast():
|
1069 |
+
model_pred = transformer(
|
1070 |
+
noisy_model_input,
|
1071 |
+
timesteps,
|
1072 |
+
text_states=llm_embeds,
|
1073 |
+
text_mask=llm_mask,
|
1074 |
+
text_states_2=clip_embeds,
|
1075 |
+
freqs_cos=freqs_cos,
|
1076 |
+
freqs_sin=freqs_sin,
|
1077 |
+
guidance=guidance_vec,
|
1078 |
+
return_dict=False,
|
1079 |
+
)
|
1080 |
+
|
1081 |
+
# flow matching loss
|
1082 |
+
target = noise - latents
|
1083 |
+
|
1084 |
+
loss = torch.nn.functional.mse_loss(model_pred.to(dit_dtype), target, reduction="none")
|
1085 |
+
|
1086 |
+
if weighting is not None:
|
1087 |
+
loss = loss * weighting
|
1088 |
+
# loss = loss.mean([1, 2, 3])
|
1089 |
+
# # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc.
|
1090 |
+
# loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
|
1091 |
+
|
1092 |
+
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
1093 |
+
|
1094 |
+
accelerator.backward(loss)
|
1095 |
+
if accelerator.sync_gradients:
|
1096 |
+
# self.all_reduce_network(accelerator, network) # sync DDP grad manually
|
1097 |
+
state = accelerate.PartialState()
|
1098 |
+
if state.distributed_type != accelerate.DistributedType.NO:
|
1099 |
+
for param in transformer.parameters():
|
1100 |
+
if param.grad is not None:
|
1101 |
+
param.grad = accelerator.reduce(param.grad, reduction="mean")
|
1102 |
+
|
1103 |
+
if args.max_grad_norm != 0.0:
|
1104 |
+
params_to_clip = accelerator.unwrap_model(transformer).parameters()
|
1105 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
1106 |
+
|
1107 |
+
optimizer.step()
|
1108 |
+
lr_scheduler.step()
|
1109 |
+
optimizer.zero_grad(set_to_none=True)
|
1110 |
+
|
1111 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
1112 |
+
if accelerator.sync_gradients:
|
1113 |
+
progress_bar.update(1)
|
1114 |
+
global_step += 1
|
1115 |
+
|
1116 |
+
optimizer_eval_fn()
|
1117 |
+
self.sample_images(
|
1118 |
+
accelerator, args, None, global_step, accelerator.device, vae, transformer, sample_parameters
|
1119 |
+
)
|
1120 |
+
|
1121 |
+
# 指定ステップごとにモデルを保存
|
1122 |
+
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
1123 |
+
accelerator.wait_for_everyone()
|
1124 |
+
if accelerator.is_main_process:
|
1125 |
+
ckpt_name = train_utils.get_step_ckpt_name(args.output_name, global_step)
|
1126 |
+
save_model(ckpt_name, accelerator.unwrap_model(transformer), global_step, epoch)
|
1127 |
+
|
1128 |
+
if args.save_state:
|
1129 |
+
train_utils.save_and_remove_state_stepwise(args, accelerator, global_step)
|
1130 |
+
|
1131 |
+
remove_step_no = train_utils.get_remove_step_no(args, global_step)
|
1132 |
+
if remove_step_no is not None:
|
1133 |
+
remove_ckpt_name = train_utils.get_step_ckpt_name(args.output_name, remove_step_no)
|
1134 |
+
remove_model(remove_ckpt_name)
|
1135 |
+
optimizer_train_fn()
|
1136 |
+
|
1137 |
+
current_loss = loss.detach().item()
|
1138 |
+
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
1139 |
+
avr_loss: float = loss_recorder.moving_average
|
1140 |
+
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
1141 |
+
progress_bar.set_postfix(**logs)
|
1142 |
+
|
1143 |
+
if len(accelerator.trackers) > 0:
|
1144 |
+
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
|
1145 |
+
accelerator.log(logs, step=global_step)
|
1146 |
+
|
1147 |
+
if global_step >= args.max_train_steps:
|
1148 |
+
break
|
1149 |
+
|
1150 |
+
if len(accelerator.trackers) > 0:
|
1151 |
+
logs = {"loss/epoch": loss_recorder.moving_average}
|
1152 |
+
accelerator.log(logs, step=epoch + 1)
|
1153 |
+
|
1154 |
+
accelerator.wait_for_everyone()
|
1155 |
+
|
1156 |
+
# 指定エポックごとにモデルを保存
|
1157 |
+
optimizer_eval_fn()
|
1158 |
+
if args.save_every_n_epochs is not None:
|
1159 |
+
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
1160 |
+
if is_main_process and saving:
|
1161 |
+
ckpt_name = train_utils.get_epoch_ckpt_name(args.output_name, epoch + 1)
|
1162 |
+
save_model(ckpt_name, accelerator.unwrap_model(transformer), global_step, epoch + 1)
|
1163 |
+
|
1164 |
+
remove_epoch_no = train_utils.get_remove_epoch_no(args, epoch + 1)
|
1165 |
+
if remove_epoch_no is not None:
|
1166 |
+
remove_ckpt_name = train_utils.get_epoch_ckpt_name(args.output_name, remove_epoch_no)
|
1167 |
+
remove_model(remove_ckpt_name)
|
1168 |
+
|
1169 |
+
if args.save_state:
|
1170 |
+
train_utils.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
1171 |
+
|
1172 |
+
self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, transformer, sample_parameters)
|
1173 |
+
optimizer_train_fn()
|
1174 |
+
|
1175 |
+
# end of epoch
|
1176 |
+
|
1177 |
+
if is_main_process:
|
1178 |
+
transformer = accelerator.unwrap_model(transformer)
|
1179 |
+
|
1180 |
+
accelerator.end_training()
|
1181 |
+
optimizer_eval_fn()
|
1182 |
+
|
1183 |
+
if args.save_state or args.save_state_on_train_end:
|
1184 |
+
train_utils.save_state_on_train_end(args, accelerator)
|
1185 |
+
|
1186 |
+
if is_main_process:
|
1187 |
+
ckpt_name = train_utils.get_last_ckpt_name(args.output_name)
|
1188 |
+
save_model(ckpt_name, transformer, global_step, num_train_epochs, force_sync_upload=True)
|
1189 |
+
|
1190 |
+
logger.info("model saved.")
|
1191 |
+
|
1192 |
+
|
1193 |
+
def setup_parser() -> argparse.ArgumentParser:
|
1194 |
+
def int_or_float(value):
|
1195 |
+
if value.endswith("%"):
|
1196 |
+
try:
|
1197 |
+
return float(value[:-1]) / 100.0
|
1198 |
+
except ValueError:
|
1199 |
+
raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage")
|
1200 |
+
try:
|
1201 |
+
float_value = float(value)
|
1202 |
+
if float_value >= 1 and float_value.is_integer():
|
1203 |
+
return int(value)
|
1204 |
+
return float(value)
|
1205 |
+
except ValueError:
|
1206 |
+
raise argparse.ArgumentTypeError(f"'{value}' is not an int or float")
|
1207 |
+
|
1208 |
+
parser = argparse.ArgumentParser()
|
1209 |
+
|
1210 |
+
# general settings
|
1211 |
+
parser.add_argument(
|
1212 |
+
"--config_file",
|
1213 |
+
type=str,
|
1214 |
+
default=None,
|
1215 |
+
help="using .toml instead of args to pass hyperparameter / ハイパーパラメータを引数ではなく.tomlファイルで渡す",
|
1216 |
+
)
|
1217 |
+
parser.add_argument(
|
1218 |
+
"--dataset_config",
|
1219 |
+
type=pathlib.Path,
|
1220 |
+
default=None,
|
1221 |
+
required=True,
|
1222 |
+
help="config file for dataset / データセットの設定ファイル",
|
1223 |
+
)
|
1224 |
+
|
1225 |
+
# training settings
|
1226 |
+
parser.add_argument(
|
1227 |
+
"--sdpa",
|
1228 |
+
action="store_true",
|
1229 |
+
help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)",
|
1230 |
+
)
|
1231 |
+
parser.add_argument(
|
1232 |
+
"--flash_attn",
|
1233 |
+
action="store_true",
|
1234 |
+
help="use FlashAttention for CrossAttention, requires FlashAttention / CrossAttentionにFlashAttentionを使う、FlashAttentionが必要",
|
1235 |
+
)
|
1236 |
+
parser.add_argument(
|
1237 |
+
"--sage_attn",
|
1238 |
+
action="store_true",
|
1239 |
+
help="use SageAttention. requires SageAttention / SageAttentionを使う。SageAttentionが必要",
|
1240 |
+
)
|
1241 |
+
parser.add_argument(
|
1242 |
+
"--xformers",
|
1243 |
+
action="store_true",
|
1244 |
+
help="use xformers for CrossAttention, requires xformers / CrossAttentionにxformersを使う、xformersが必要",
|
1245 |
+
)
|
1246 |
+
parser.add_argument(
|
1247 |
+
"--split_attn",
|
1248 |
+
action="store_true",
|
1249 |
+
help="use split attention for attention calculation (split batch size=1, affects memory usage and speed)"
|
1250 |
+
" / attentionを分割して計算する(バッチサイズ=1に分割、メモリ使用量と速度に影響)",
|
1251 |
+
)
|
1252 |
+
|
1253 |
+
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
1254 |
+
parser.add_argument(
|
1255 |
+
"--max_train_epochs",
|
1256 |
+
type=int,
|
1257 |
+
default=None,
|
1258 |
+
help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)",
|
1259 |
+
)
|
1260 |
+
parser.add_argument(
|
1261 |
+
"--max_data_loader_n_workers",
|
1262 |
+
type=int,
|
1263 |
+
default=8,
|
1264 |
+
help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)",
|
1265 |
+
)
|
1266 |
+
parser.add_argument(
|
1267 |
+
"--persistent_data_loader_workers",
|
1268 |
+
action="store_true",
|
1269 |
+
help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)",
|
1270 |
+
)
|
1271 |
+
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
|
1272 |
+
parser.add_argument(
|
1273 |
+
"--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする"
|
1274 |
+
)
|
1275 |
+
parser.add_argument(
|
1276 |
+
"--gradient_accumulation_steps",
|
1277 |
+
type=int,
|
1278 |
+
default=1,
|
1279 |
+
help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数",
|
1280 |
+
)
|
1281 |
+
parser.add_argument(
|
1282 |
+
"--mixed_precision",
|
1283 |
+
type=str,
|
1284 |
+
default="no",
|
1285 |
+
choices=["no", "fp16", "bf16"],
|
1286 |
+
help="use mixed precision / 混合精度を使う場合、その精度",
|
1287 |
+
)
|
1288 |
+
parser.add_argument("--trainable_modules", nargs="+", default=".", help="Enter a list of trainable modules")
|
1289 |
+
|
1290 |
+
parser.add_argument(
|
1291 |
+
"--logging_dir",
|
1292 |
+
type=str,
|
1293 |
+
default=None,
|
1294 |
+
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する",
|
1295 |
+
)
|
1296 |
+
parser.add_argument(
|
1297 |
+
"--log_with",
|
1298 |
+
type=str,
|
1299 |
+
default=None,
|
1300 |
+
choices=["tensorboard", "wandb", "all"],
|
1301 |
+
help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)",
|
1302 |
+
)
|
1303 |
+
parser.add_argument(
|
1304 |
+
"--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列"
|
1305 |
+
)
|
1306 |
+
parser.add_argument(
|
1307 |
+
"--log_tracker_name",
|
1308 |
+
type=str,
|
1309 |
+
default=None,
|
1310 |
+
help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名",
|
1311 |
+
)
|
1312 |
+
parser.add_argument(
|
1313 |
+
"--wandb_run_name",
|
1314 |
+
type=str,
|
1315 |
+
default=None,
|
1316 |
+
help="The name of the specific wandb session / wandb ログに表示される特定の実行の名前",
|
1317 |
+
)
|
1318 |
+
parser.add_argument(
|
1319 |
+
"--log_tracker_config",
|
1320 |
+
type=str,
|
1321 |
+
default=None,
|
1322 |
+
help="path to tracker config file to use for logging / ログ出力に使用するtrackerの設定ファイルのパス",
|
1323 |
+
)
|
1324 |
+
parser.add_argument(
|
1325 |
+
"--wandb_api_key",
|
1326 |
+
type=str,
|
1327 |
+
default=None,
|
1328 |
+
help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)",
|
1329 |
+
)
|
1330 |
+
parser.add_argument("--log_config", action="store_true", help="log training configuration / 学習設定をログに出力する")
|
1331 |
+
|
1332 |
+
parser.add_argument(
|
1333 |
+
"--ddp_timeout",
|
1334 |
+
type=int,
|
1335 |
+
default=None,
|
1336 |
+
help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト(分、Noneでaccelerateのデフォルト)",
|
1337 |
+
)
|
1338 |
+
parser.add_argument(
|
1339 |
+
"--ddp_gradient_as_bucket_view",
|
1340 |
+
action="store_true",
|
1341 |
+
help="enable gradient_as_bucket_view for DDP / DDPでgradient_as_bucket_viewを有効にする",
|
1342 |
+
)
|
1343 |
+
parser.add_argument(
|
1344 |
+
"--ddp_static_graph",
|
1345 |
+
action="store_true",
|
1346 |
+
help="enable static_graph for DDP / DDPでstatic_graphを有効にする",
|
1347 |
+
)
|
1348 |
+
|
1349 |
+
parser.add_argument(
|
1350 |
+
"--sample_every_n_steps",
|
1351 |
+
type=int,
|
1352 |
+
default=None,
|
1353 |
+
help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する",
|
1354 |
+
)
|
1355 |
+
parser.add_argument(
|
1356 |
+
"--sample_at_first", action="store_true", help="generate sample images before training / 学習前にサンプル出力する"
|
1357 |
+
)
|
1358 |
+
parser.add_argument(
|
1359 |
+
"--sample_every_n_epochs",
|
1360 |
+
type=int,
|
1361 |
+
default=None,
|
1362 |
+
help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)",
|
1363 |
+
)
|
1364 |
+
parser.add_argument(
|
1365 |
+
"--sample_prompts",
|
1366 |
+
type=str,
|
1367 |
+
default=None,
|
1368 |
+
help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル",
|
1369 |
+
)
|
1370 |
+
|
1371 |
+
# optimizer and lr scheduler settings
|
1372 |
+
parser.add_argument(
|
1373 |
+
"--optimizer_type",
|
1374 |
+
type=str,
|
1375 |
+
default="",
|
1376 |
+
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, AdaFactor. "
|
1377 |
+
"Also, you can use any optimizer by specifying the full path to the class, like 'torch.optim.AdamW', 'bitsandbytes.optim.AdEMAMix8bit' or 'bitsandbytes.optim.PagedAdEMAMix8bit' etc. / ",
|
1378 |
+
)
|
1379 |
+
parser.add_argument(
|
1380 |
+
"--optimizer_args",
|
1381 |
+
type=str,
|
1382 |
+
default=None,
|
1383 |
+
nargs="*",
|
1384 |
+
help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")',
|
1385 |
+
)
|
1386 |
+
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
1387 |
+
parser.add_argument(
|
1388 |
+
"--max_grad_norm",
|
1389 |
+
default=1.0,
|
1390 |
+
type=float,
|
1391 |
+
help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない",
|
1392 |
+
)
|
1393 |
+
|
1394 |
+
parser.add_argument(
|
1395 |
+
"--lr_scheduler",
|
1396 |
+
type=str,
|
1397 |
+
default="constant",
|
1398 |
+
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor",
|
1399 |
+
)
|
1400 |
+
parser.add_argument(
|
1401 |
+
"--lr_warmup_steps",
|
1402 |
+
type=int_or_float,
|
1403 |
+
default=0,
|
1404 |
+
help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps"
|
1405 |
+
" / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)",
|
1406 |
+
)
|
1407 |
+
parser.add_argument(
|
1408 |
+
"--lr_decay_steps",
|
1409 |
+
type=int_or_float,
|
1410 |
+
default=0,
|
1411 |
+
help="Int number of steps for the decay in the lr scheduler (default is 0) or float (<1) with ratio of train steps"
|
1412 |
+
" / 学習率のスケジューラを減衰させるステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)",
|
1413 |
+
)
|
1414 |
+
parser.add_argument(
|
1415 |
+
"--lr_scheduler_num_cycles",
|
1416 |
+
type=int,
|
1417 |
+
default=1,
|
1418 |
+
help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数",
|
1419 |
+
)
|
1420 |
+
parser.add_argument(
|
1421 |
+
"--lr_scheduler_power",
|
1422 |
+
type=float,
|
1423 |
+
default=1,
|
1424 |
+
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
|
1425 |
+
)
|
1426 |
+
parser.add_argument(
|
1427 |
+
"--lr_scheduler_timescale",
|
1428 |
+
type=int,
|
1429 |
+
default=None,
|
1430 |
+
help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`"
|
1431 |
+
+ " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`",
|
1432 |
+
)
|
1433 |
+
parser.add_argument(
|
1434 |
+
"--lr_scheduler_min_lr_ratio",
|
1435 |
+
type=float,
|
1436 |
+
default=None,
|
1437 |
+
help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler"
|
1438 |
+
+ " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効",
|
1439 |
+
)
|
1440 |
+
parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ")
|
1441 |
+
parser.add_argument(
|
1442 |
+
"--lr_scheduler_args",
|
1443 |
+
type=str,
|
1444 |
+
default=None,
|
1445 |
+
nargs="*",
|
1446 |
+
help='additional arguments for scheduler (like "T_max=100") / スケジューラの追加引数(例: "T_max100")',
|
1447 |
+
)
|
1448 |
+
|
1449 |
+
# model settings
|
1450 |
+
parser.add_argument("--dit", type=str, required=True, help="DiT checkpoint path / DiTのチェックポイントのパス")
|
1451 |
+
parser.add_argument("--dit_dtype", type=str, default=None, help="data type for DiT, default is bfloat16")
|
1452 |
+
parser.add_argument("--dit_in_channels", type=int, default=16, help="input channels for DiT, default is 16, skyreels I2V is 32")
|
1453 |
+
parser.add_argument("--vae", type=str, help="VAE checkpoint path / VAEのチェックポイントのパス")
|
1454 |
+
parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16")
|
1455 |
+
parser.add_argument(
|
1456 |
+
"--vae_tiling",
|
1457 |
+
action="store_true",
|
1458 |
+
help="enable spatial tiling for VAE, default is False. If vae_spatial_tile_sample_min_size is set, this is automatically enabled."
|
1459 |
+
" / VAEの空間タイリングを有効にする、デフォルトはFalse。vae_spatial_tile_sample_min_sizeが設定されている場合、自動的に有効になります。",
|
1460 |
+
)
|
1461 |
+
parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
|
1462 |
+
parser.add_argument(
|
1463 |
+
"--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
|
1464 |
+
)
|
1465 |
+
parser.add_argument("--text_encoder1", type=str, help="Text Encoder 1 directory / テキストエンコーダ1のディレクトリ")
|
1466 |
+
parser.add_argument("--text_encoder2", type=str, help="Text Encoder 2 directory / テキストエンコーダ2のディレクトリ")
|
1467 |
+
parser.add_argument("--text_encoder_dtype", type=str, default=None, help="data type for Text Encoder, default is float16")
|
1468 |
+
parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for LLM / LLMにfp8を使う")
|
1469 |
+
parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
|
1470 |
+
parser.add_argument("--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する")
|
1471 |
+
|
1472 |
+
parser.add_argument(
|
1473 |
+
"--blocks_to_swap",
|
1474 |
+
type=int,
|
1475 |
+
default=None,
|
1476 |
+
help="number of blocks to swap in the model, max XXX / モデル内のブロックの数、最大XXX",
|
1477 |
+
)
|
1478 |
+
parser.add_argument(
|
1479 |
+
"--img_in_txt_in_offloading",
|
1480 |
+
action="store_true",
|
1481 |
+
help="offload img_in and txt_in to cpu / img_inとtxt_inをCPUにオフロードする",
|
1482 |
+
)
|
1483 |
+
|
1484 |
+
# parser.add_argument("--flow_shift", type=float, default=7.0, help="Shift factor for flow matching schedulers")
|
1485 |
+
parser.add_argument("--guidance_scale", type=float, default=1.0, help="Embeded classifier free guidance scale.")
|
1486 |
+
parser.add_argument(
|
1487 |
+
"--timestep_sampling",
|
1488 |
+
choices=["sigma", "uniform", "sigmoid", "shift"],
|
1489 |
+
default="sigma",
|
1490 |
+
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid."
|
1491 |
+
" / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト。",
|
1492 |
+
)
|
1493 |
+
parser.add_argument(
|
1494 |
+
"--discrete_flow_shift",
|
1495 |
+
type=float,
|
1496 |
+
default=1.0,
|
1497 |
+
help="Discrete flow shift for the Euler Discrete Scheduler, default is 1.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは1.0。",
|
1498 |
+
)
|
1499 |
+
parser.add_argument(
|
1500 |
+
"--sigmoid_scale",
|
1501 |
+
type=float,
|
1502 |
+
default=1.0,
|
1503 |
+
help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid" or "shift"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"または"shift"の場合のみ有効)。',
|
1504 |
+
)
|
1505 |
+
parser.add_argument(
|
1506 |
+
"--weighting_scheme",
|
1507 |
+
type=str,
|
1508 |
+
default="none",
|
1509 |
+
choices=["logit_normal", "mode", "cosmap", "sigma_sqrt", "none"],
|
1510 |
+
help="weighting scheme for timestep distribution. Default is none"
|
1511 |
+
" / タイムステップ分布の重み付けスキーム、デフォルトはnone",
|
1512 |
+
)
|
1513 |
+
parser.add_argument(
|
1514 |
+
"--logit_mean",
|
1515 |
+
type=float,
|
1516 |
+
default=0.0,
|
1517 |
+
help="mean to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合の平均",
|
1518 |
+
)
|
1519 |
+
parser.add_argument(
|
1520 |
+
"--logit_std",
|
1521 |
+
type=float,
|
1522 |
+
default=1.0,
|
1523 |
+
help="std to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合のstd",
|
1524 |
+
)
|
1525 |
+
parser.add_argument(
|
1526 |
+
"--mode_scale",
|
1527 |
+
type=float,
|
1528 |
+
default=1.29,
|
1529 |
+
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme` / モード重み付けスキームのスケール",
|
1530 |
+
)
|
1531 |
+
parser.add_argument(
|
1532 |
+
"--min_timestep",
|
1533 |
+
type=int,
|
1534 |
+
default=None,
|
1535 |
+
help="set minimum time step for training (0~999, default is 0) / 学習時のtime stepの最小値を設定する(0~999で指定、省略時はデフォルト値(0)) ",
|
1536 |
+
)
|
1537 |
+
parser.add_argument(
|
1538 |
+
"--max_timestep",
|
1539 |
+
type=int,
|
1540 |
+
default=None,
|
1541 |
+
help="set maximum time step for training (1~1000, default is 1000) / 学習時のtime stepの最大値を設定する(1~1000で指定、省略時はデフォルト値(1000))",
|
1542 |
+
)
|
1543 |
+
|
1544 |
+
# save and load settings
|
1545 |
+
parser.add_argument(
|
1546 |
+
"--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ"
|
1547 |
+
)
|
1548 |
+
parser.add_argument(
|
1549 |
+
"--output_name",
|
1550 |
+
type=str,
|
1551 |
+
default=None,
|
1552 |
+
required=True,
|
1553 |
+
help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名",
|
1554 |
+
)
|
1555 |
+
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
|
1556 |
+
|
1557 |
+
parser.add_argument(
|
1558 |
+
"--save_every_n_epochs",
|
1559 |
+
type=int,
|
1560 |
+
default=None,
|
1561 |
+
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する",
|
1562 |
+
)
|
1563 |
+
parser.add_argument(
|
1564 |
+
"--save_every_n_steps",
|
1565 |
+
type=int,
|
1566 |
+
default=None,
|
1567 |
+
help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する",
|
1568 |
+
)
|
1569 |
+
parser.add_argument(
|
1570 |
+
"--save_last_n_epochs",
|
1571 |
+
type=int,
|
1572 |
+
default=None,
|
1573 |
+
help="save last N checkpoints when saving every N epochs (remove older checkpoints) / 指定エポックごとにモデルを保存するとき最大Nエポック保存する(古いチェックポイントは削除する)",
|
1574 |
+
)
|
1575 |
+
parser.add_argument(
|
1576 |
+
"--save_last_n_epochs_state",
|
1577 |
+
type=int,
|
1578 |
+
default=None,
|
1579 |
+
help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きする)",
|
1580 |
+
)
|
1581 |
+
parser.add_argument(
|
1582 |
+
"--save_last_n_steps",
|
1583 |
+
type=int,
|
1584 |
+
default=None,
|
1585 |
+
help="save checkpoints until N steps elapsed (remove older checkpoints if N steps elapsed) / 指定ステップごとにモデルを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する)",
|
1586 |
+
)
|
1587 |
+
parser.add_argument(
|
1588 |
+
"--save_last_n_steps_state",
|
1589 |
+
type=int,
|
1590 |
+
default=None,
|
1591 |
+
help="save states until N steps elapsed (remove older states if N steps elapsed, overrides --save_last_n_steps) / 指定ステップごとにstateを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する。--save_last_n_stepsを上書きする)",
|
1592 |
+
)
|
1593 |
+
parser.add_argument(
|
1594 |
+
"--save_state",
|
1595 |
+
action="store_true",
|
1596 |
+
help="save training state additionally (including optimizer states etc.) when saving model / optimizerなど学習状態も含めたstateをモデル保存時に追加で保存する",
|
1597 |
+
)
|
1598 |
+
parser.add_argument(
|
1599 |
+
"--save_state_on_train_end",
|
1600 |
+
action="store_true",
|
1601 |
+
help="save training state (including optimizer states etc.) on train end even if --save_state is not specified"
|
1602 |
+
" / --save_stateが未指定時にもoptimizerなど学習状態も含めたstateを学習終了時に保存する",
|
1603 |
+
)
|
1604 |
+
|
1605 |
+
# SAI Model spec
|
1606 |
+
parser.add_argument(
|
1607 |
+
"--metadata_title",
|
1608 |
+
type=str,
|
1609 |
+
default=None,
|
1610 |
+
help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name",
|
1611 |
+
)
|
1612 |
+
parser.add_argument(
|
1613 |
+
"--metadata_author",
|
1614 |
+
type=str,
|
1615 |
+
default=None,
|
1616 |
+
help="author name for model metadata / メタデータに書き込まれるモデル作者名",
|
1617 |
+
)
|
1618 |
+
parser.add_argument(
|
1619 |
+
"--metadata_description",
|
1620 |
+
type=str,
|
1621 |
+
default=None,
|
1622 |
+
help="description for model metadata / メタデータに書き込まれるモデル説明",
|
1623 |
+
)
|
1624 |
+
parser.add_argument(
|
1625 |
+
"--metadata_license",
|
1626 |
+
type=str,
|
1627 |
+
default=None,
|
1628 |
+
help="license for model metadata / メタデータに書き込まれるモデルライセンス",
|
1629 |
+
)
|
1630 |
+
parser.add_argument(
|
1631 |
+
"--metadata_tags",
|
1632 |
+
type=str,
|
1633 |
+
default=None,
|
1634 |
+
help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
|
1635 |
+
)
|
1636 |
+
|
1637 |
+
# huggingface settings
|
1638 |
+
parser.add_argument(
|
1639 |
+
"--huggingface_repo_id",
|
1640 |
+
type=str,
|
1641 |
+
default=None,
|
1642 |
+
help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名",
|
1643 |
+
)
|
1644 |
+
parser.add_argument(
|
1645 |
+
"--huggingface_repo_type",
|
1646 |
+
type=str,
|
1647 |
+
default=None,
|
1648 |
+
help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類",
|
1649 |
+
)
|
1650 |
+
parser.add_argument(
|
1651 |
+
"--huggingface_path_in_repo",
|
1652 |
+
type=str,
|
1653 |
+
default=None,
|
1654 |
+
help="huggingface model path to upload files / huggingfaceにアップロードするファイルのパス",
|
1655 |
+
)
|
1656 |
+
parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token / huggingfaceのトークン")
|
1657 |
+
parser.add_argument(
|
1658 |
+
"--huggingface_repo_visibility",
|
1659 |
+
type=str,
|
1660 |
+
default=None,
|
1661 |
+
help="huggingface repository visibility ('public' for public, 'private' or None for private) / huggingfaceにアップロードするリポジトリの公開設定('public'で公開、'private'またはNoneで非公開)",
|
1662 |
+
)
|
1663 |
+
parser.add_argument(
|
1664 |
+
"--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する"
|
1665 |
+
)
|
1666 |
+
parser.add_argument(
|
1667 |
+
"--resume_from_huggingface",
|
1668 |
+
action="store_true",
|
1669 |
+
help="resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type}) / huggingfaceから学習を再開する(例: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})",
|
1670 |
+
)
|
1671 |
+
parser.add_argument(
|
1672 |
+
"--async_upload",
|
1673 |
+
action="store_true",
|
1674 |
+
help="upload to huggingface asynchronously / huggingfaceに非同期でアップロードする",
|
1675 |
+
)
|
1676 |
+
|
1677 |
+
return parser
|
1678 |
+
|
1679 |
+
|
1680 |
+
def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentParser):
|
1681 |
+
if not args.config_file:
|
1682 |
+
return args
|
1683 |
+
|
1684 |
+
config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
|
1685 |
+
|
1686 |
+
if not os.path.exists(config_path):
|
1687 |
+
logger.info(f"{config_path} not found.")
|
1688 |
+
exit(1)
|
1689 |
+
|
1690 |
+
logger.info(f"Loading settings from {config_path}...")
|
1691 |
+
with open(config_path, "r", encoding="utf-8") as f:
|
1692 |
+
config_dict = toml.load(f)
|
1693 |
+
|
1694 |
+
# combine all sections into one
|
1695 |
+
ignore_nesting_dict = {}
|
1696 |
+
for section_name, section_dict in config_dict.items():
|
1697 |
+
# if value is not dict, save key and value as is
|
1698 |
+
if not isinstance(section_dict, dict):
|
1699 |
+
ignore_nesting_dict[section_name] = section_dict
|
1700 |
+
continue
|
1701 |
+
|
1702 |
+
# if value is dict, save all key and value into one dict
|
1703 |
+
for key, value in section_dict.items():
|
1704 |
+
ignore_nesting_dict[key] = value
|
1705 |
+
|
1706 |
+
config_args = argparse.Namespace(**ignore_nesting_dict)
|
1707 |
+
args = parser.parse_args(namespace=config_args)
|
1708 |
+
args.config_file = os.path.splitext(args.config_file)[0]
|
1709 |
+
logger.info(args.config_file)
|
1710 |
+
|
1711 |
+
return args
|
1712 |
+
|
1713 |
+
|
1714 |
+
if __name__ == "__main__":
|
1715 |
+
parser = setup_parser()
|
1716 |
+
|
1717 |
+
args = parser.parse_args()
|
1718 |
+
args = read_config_from_file(args, parser)
|
1719 |
+
|
1720 |
+
trainer = FineTuningTrainer()
|
1721 |
+
trainer.train(args)
|
hv_train_network.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
merge_lora.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import torch
|
4 |
+
from safetensors.torch import load_file
|
5 |
+
from networks import lora
|
6 |
+
from utils.safetensors_utils import mem_eff_save_file
|
7 |
+
from hunyuan_model.models import load_transformer
|
8 |
+
|
9 |
+
logger = logging.getLogger(__name__)
|
10 |
+
logging.basicConfig(level=logging.INFO)
|
11 |
+
|
12 |
+
|
13 |
+
def parse_args():
|
14 |
+
parser = argparse.ArgumentParser(description="HunyuanVideo model merger script")
|
15 |
+
|
16 |
+
parser.add_argument("--dit", type=str, required=True, help="DiT checkpoint path or directory")
|
17 |
+
parser.add_argument("--dit_in_channels", type=int, default=16, help="input channels for DiT, default is 16, skyreels I2V is 32")
|
18 |
+
parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path")
|
19 |
+
parser.add_argument("--lora_multiplier", type=float, nargs="*", default=[1.0], help="LoRA multiplier (can specify multiple values)")
|
20 |
+
parser.add_argument("--save_merged_model", type=str, required=True, help="Path to save the merged model")
|
21 |
+
parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use for merging")
|
22 |
+
|
23 |
+
return parser.parse_args()
|
24 |
+
|
25 |
+
|
26 |
+
def main():
|
27 |
+
args = parse_args()
|
28 |
+
|
29 |
+
device = torch.device(args.device)
|
30 |
+
logger.info(f"Using device: {device}")
|
31 |
+
|
32 |
+
# Load DiT model
|
33 |
+
logger.info(f"Loading DiT model from {args.dit}")
|
34 |
+
transformer = load_transformer(args.dit, "torch", False, "cpu", torch.bfloat16, in_channels=args.dit_in_channels)
|
35 |
+
transformer.eval()
|
36 |
+
|
37 |
+
# Load LoRA weights and merge
|
38 |
+
if args.lora_weight is not None and len(args.lora_weight) > 0:
|
39 |
+
for i, lora_weight in enumerate(args.lora_weight):
|
40 |
+
# Use the corresponding lora_multiplier or default to 1.0
|
41 |
+
if args.lora_multiplier is not None and len(args.lora_multiplier) > i:
|
42 |
+
lora_multiplier = args.lora_multiplier[i]
|
43 |
+
else:
|
44 |
+
lora_multiplier = 1.0
|
45 |
+
|
46 |
+
logger.info(f"Loading LoRA weights from {lora_weight} with multiplier {lora_multiplier}")
|
47 |
+
weights_sd = load_file(lora_weight)
|
48 |
+
network = lora.create_network_from_weights_hunyuan_video(
|
49 |
+
lora_multiplier, weights_sd, unet=transformer, for_inference=True
|
50 |
+
)
|
51 |
+
logger.info("Merging LoRA weights to DiT model")
|
52 |
+
network.merge_to(None, transformer, weights_sd, device=device, non_blocking=True)
|
53 |
+
|
54 |
+
logger.info("LoRA weights loaded")
|
55 |
+
|
56 |
+
# Save the merged model
|
57 |
+
logger.info(f"Saving merged model to {args.save_merged_model}")
|
58 |
+
mem_eff_save_file(transformer.state_dict(), args.save_merged_model)
|
59 |
+
logger.info("Merged model saved")
|
60 |
+
|
61 |
+
|
62 |
+
if __name__ == "__main__":
|
63 |
+
main()
|
modules/__init__.py
ADDED
File without changes
|
modules/custom_offloading_utils.py
ADDED
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from concurrent.futures import ThreadPoolExecutor
|
2 |
+
import gc
|
3 |
+
import time
|
4 |
+
from typing import Optional
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
|
9 |
+
def clean_memory_on_device(device: torch.device):
|
10 |
+
r"""
|
11 |
+
Clean memory on the specified device, will be called from training scripts.
|
12 |
+
"""
|
13 |
+
gc.collect()
|
14 |
+
|
15 |
+
# device may "cuda" or "cuda:0", so we need to check the type of device
|
16 |
+
if device.type == "cuda":
|
17 |
+
torch.cuda.empty_cache()
|
18 |
+
if device.type == "xpu":
|
19 |
+
torch.xpu.empty_cache()
|
20 |
+
if device.type == "mps":
|
21 |
+
torch.mps.empty_cache()
|
22 |
+
|
23 |
+
|
24 |
+
def synchronize_device(device: torch.device):
|
25 |
+
if device.type == "cuda":
|
26 |
+
torch.cuda.synchronize()
|
27 |
+
elif device.type == "xpu":
|
28 |
+
torch.xpu.synchronize()
|
29 |
+
elif device.type == "mps":
|
30 |
+
torch.mps.synchronize()
|
31 |
+
|
32 |
+
|
33 |
+
def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
|
34 |
+
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
|
35 |
+
|
36 |
+
weight_swap_jobs = []
|
37 |
+
|
38 |
+
# This is not working for all cases (e.g. SD3), so we need to find the corresponding modules
|
39 |
+
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
|
40 |
+
# print(module_to_cpu.__class__, module_to_cuda.__class__)
|
41 |
+
# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
|
42 |
+
# weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
|
43 |
+
|
44 |
+
modules_to_cpu = {k: v for k, v in layer_to_cpu.named_modules()}
|
45 |
+
for module_to_cuda_name, module_to_cuda in layer_to_cuda.named_modules():
|
46 |
+
if hasattr(module_to_cuda, "weight") and module_to_cuda.weight is not None:
|
47 |
+
module_to_cpu = modules_to_cpu.get(module_to_cuda_name, None)
|
48 |
+
if module_to_cpu is not None and module_to_cpu.weight.shape == module_to_cuda.weight.shape:
|
49 |
+
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
|
50 |
+
else:
|
51 |
+
if module_to_cuda.weight.data.device.type != device.type:
|
52 |
+
# print(
|
53 |
+
# f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device"
|
54 |
+
# )
|
55 |
+
module_to_cuda.weight.data = module_to_cuda.weight.data.to(device)
|
56 |
+
|
57 |
+
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
|
58 |
+
|
59 |
+
stream = torch.cuda.Stream()
|
60 |
+
with torch.cuda.stream(stream):
|
61 |
+
# cuda to cpu
|
62 |
+
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
63 |
+
cuda_data_view.record_stream(stream)
|
64 |
+
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
|
65 |
+
|
66 |
+
stream.synchronize()
|
67 |
+
|
68 |
+
# cpu to cuda
|
69 |
+
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
70 |
+
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
|
71 |
+
module_to_cuda.weight.data = cuda_data_view
|
72 |
+
|
73 |
+
stream.synchronize()
|
74 |
+
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
|
75 |
+
|
76 |
+
|
77 |
+
def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
|
78 |
+
"""
|
79 |
+
not tested
|
80 |
+
"""
|
81 |
+
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
|
82 |
+
|
83 |
+
weight_swap_jobs = []
|
84 |
+
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
|
85 |
+
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
|
86 |
+
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
|
87 |
+
|
88 |
+
# device to cpu
|
89 |
+
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
90 |
+
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
|
91 |
+
|
92 |
+
synchronize_device()
|
93 |
+
|
94 |
+
# cpu to device
|
95 |
+
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
96 |
+
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
|
97 |
+
module_to_cuda.weight.data = cuda_data_view
|
98 |
+
|
99 |
+
synchronize_device()
|
100 |
+
|
101 |
+
|
102 |
+
def weighs_to_device(layer: nn.Module, device: torch.device):
|
103 |
+
for module in layer.modules():
|
104 |
+
if hasattr(module, "weight") and module.weight is not None:
|
105 |
+
module.weight.data = module.weight.data.to(device, non_blocking=True)
|
106 |
+
|
107 |
+
|
108 |
+
class Offloader:
|
109 |
+
"""
|
110 |
+
common offloading class
|
111 |
+
"""
|
112 |
+
|
113 |
+
def __init__(self, block_type: str, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
|
114 |
+
self.block_type = block_type
|
115 |
+
self.num_blocks = num_blocks
|
116 |
+
self.blocks_to_swap = blocks_to_swap
|
117 |
+
self.device = device
|
118 |
+
self.debug = debug
|
119 |
+
|
120 |
+
self.thread_pool = ThreadPoolExecutor(max_workers=1)
|
121 |
+
self.futures = {}
|
122 |
+
self.cuda_available = device.type == "cuda"
|
123 |
+
|
124 |
+
def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module):
|
125 |
+
if self.cuda_available:
|
126 |
+
swap_weight_devices_cuda(self.device, block_to_cpu, block_to_cuda)
|
127 |
+
else:
|
128 |
+
swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda)
|
129 |
+
|
130 |
+
def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda):
|
131 |
+
def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
|
132 |
+
if self.debug:
|
133 |
+
start_time = time.perf_counter()
|
134 |
+
print(
|
135 |
+
f"[{self.block_type}] Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}"
|
136 |
+
)
|
137 |
+
|
138 |
+
self.swap_weight_devices(block_to_cpu, block_to_cuda)
|
139 |
+
|
140 |
+
if self.debug:
|
141 |
+
print(f"[{self.block_type}] Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
|
142 |
+
return bidx_to_cpu, bidx_to_cuda # , event
|
143 |
+
|
144 |
+
block_to_cpu = blocks[block_idx_to_cpu]
|
145 |
+
block_to_cuda = blocks[block_idx_to_cuda]
|
146 |
+
|
147 |
+
self.futures[block_idx_to_cuda] = self.thread_pool.submit(
|
148 |
+
move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda
|
149 |
+
)
|
150 |
+
|
151 |
+
def _wait_blocks_move(self, block_idx):
|
152 |
+
if block_idx not in self.futures:
|
153 |
+
return
|
154 |
+
|
155 |
+
if self.debug:
|
156 |
+
print(f"[{self.block_type}] Wait for block {block_idx}")
|
157 |
+
start_time = time.perf_counter()
|
158 |
+
|
159 |
+
future = self.futures.pop(block_idx)
|
160 |
+
_, bidx_to_cuda = future.result()
|
161 |
+
|
162 |
+
assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}"
|
163 |
+
|
164 |
+
if self.debug:
|
165 |
+
print(f"[{self.block_type}] Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
|
166 |
+
|
167 |
+
|
168 |
+
class ModelOffloader(Offloader):
|
169 |
+
"""
|
170 |
+
supports forward offloading
|
171 |
+
"""
|
172 |
+
|
173 |
+
def __init__(
|
174 |
+
self,
|
175 |
+
block_type: str,
|
176 |
+
blocks: list[nn.Module],
|
177 |
+
num_blocks: int,
|
178 |
+
blocks_to_swap: int,
|
179 |
+
supports_backward: bool,
|
180 |
+
device: torch.device,
|
181 |
+
debug: bool = False,
|
182 |
+
):
|
183 |
+
super().__init__(block_type, num_blocks, blocks_to_swap, device, debug)
|
184 |
+
|
185 |
+
self.supports_backward = supports_backward
|
186 |
+
self.forward_only = not supports_backward # forward only offloading: can be changed to True for inference
|
187 |
+
|
188 |
+
if self.supports_backward:
|
189 |
+
# register backward hooks
|
190 |
+
self.remove_handles = []
|
191 |
+
for i, block in enumerate(blocks):
|
192 |
+
hook = self.create_backward_hook(blocks, i)
|
193 |
+
if hook is not None:
|
194 |
+
handle = block.register_full_backward_hook(hook)
|
195 |
+
self.remove_handles.append(handle)
|
196 |
+
|
197 |
+
def set_forward_only(self, forward_only: bool):
|
198 |
+
self.forward_only = forward_only
|
199 |
+
|
200 |
+
def __del__(self):
|
201 |
+
if self.supports_backward:
|
202 |
+
for handle in self.remove_handles:
|
203 |
+
handle.remove()
|
204 |
+
|
205 |
+
def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
|
206 |
+
# -1 for 0-based index
|
207 |
+
num_blocks_propagated = self.num_blocks - block_index - 1
|
208 |
+
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
|
209 |
+
waiting = block_index > 0 and block_index <= self.blocks_to_swap
|
210 |
+
|
211 |
+
if not swapping and not waiting:
|
212 |
+
return None
|
213 |
+
|
214 |
+
# create hook
|
215 |
+
block_idx_to_cpu = self.num_blocks - num_blocks_propagated
|
216 |
+
block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
|
217 |
+
block_idx_to_wait = block_index - 1
|
218 |
+
|
219 |
+
def backward_hook(module, grad_input, grad_output):
|
220 |
+
if self.debug:
|
221 |
+
print(f"Backward hook for block {block_index}")
|
222 |
+
|
223 |
+
if swapping:
|
224 |
+
self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
|
225 |
+
if waiting:
|
226 |
+
self._wait_blocks_move(block_idx_to_wait)
|
227 |
+
return None
|
228 |
+
|
229 |
+
return backward_hook
|
230 |
+
|
231 |
+
def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
|
232 |
+
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
233 |
+
return
|
234 |
+
|
235 |
+
if self.debug:
|
236 |
+
print(f"[{self.block_type}] Prepare block devices before forward")
|
237 |
+
|
238 |
+
for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
|
239 |
+
b.to(self.device)
|
240 |
+
weighs_to_device(b, self.device) # make sure weights are on device
|
241 |
+
|
242 |
+
for b in blocks[self.num_blocks - self.blocks_to_swap :]:
|
243 |
+
b.to(self.device) # move block to device first
|
244 |
+
weighs_to_device(b, "cpu") # make sure weights are on cpu
|
245 |
+
|
246 |
+
synchronize_device(self.device)
|
247 |
+
clean_memory_on_device(self.device)
|
248 |
+
|
249 |
+
def wait_for_block(self, block_idx: int):
|
250 |
+
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
251 |
+
return
|
252 |
+
self._wait_blocks_move(block_idx)
|
253 |
+
|
254 |
+
def submit_move_blocks_forward(self, blocks: list[nn.Module], block_idx: int):
|
255 |
+
# check if blocks_to_swap is enabled
|
256 |
+
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
257 |
+
return
|
258 |
+
|
259 |
+
# if supports_backward and backward is enabled, we swap blocks more than blocks_to_swap in backward pass
|
260 |
+
if not self.forward_only and block_idx >= self.blocks_to_swap:
|
261 |
+
return
|
262 |
+
|
263 |
+
block_idx_to_cpu = block_idx
|
264 |
+
block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx
|
265 |
+
block_idx_to_cuda = block_idx_to_cuda % self.num_blocks # this works for forward-only offloading
|
266 |
+
self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
|
modules/scheduling_flow_match_discrete.py
ADDED
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
#
|
16 |
+
# Modified from diffusers==0.29.2
|
17 |
+
#
|
18 |
+
# ==============================================================================
|
19 |
+
|
20 |
+
from dataclasses import dataclass
|
21 |
+
from typing import Optional, Tuple, Union
|
22 |
+
|
23 |
+
import numpy as np
|
24 |
+
import torch
|
25 |
+
|
26 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
27 |
+
from diffusers.utils import BaseOutput, logging
|
28 |
+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
29 |
+
|
30 |
+
|
31 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
32 |
+
|
33 |
+
|
34 |
+
@dataclass
|
35 |
+
class FlowMatchDiscreteSchedulerOutput(BaseOutput):
|
36 |
+
"""
|
37 |
+
Output class for the scheduler's `step` function output.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
41 |
+
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
42 |
+
denoising loop.
|
43 |
+
"""
|
44 |
+
|
45 |
+
prev_sample: torch.FloatTensor
|
46 |
+
|
47 |
+
|
48 |
+
class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
49 |
+
"""
|
50 |
+
Euler scheduler.
|
51 |
+
|
52 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
53 |
+
methods the library implements for all schedulers such as loading and saving.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
num_train_timesteps (`int`, defaults to 1000):
|
57 |
+
The number of diffusion steps to train the model.
|
58 |
+
timestep_spacing (`str`, defaults to `"linspace"`):
|
59 |
+
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
60 |
+
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
61 |
+
shift (`float`, defaults to 1.0):
|
62 |
+
The shift value for the timestep schedule.
|
63 |
+
reverse (`bool`, defaults to `True`):
|
64 |
+
Whether to reverse the timestep schedule.
|
65 |
+
"""
|
66 |
+
|
67 |
+
_compatibles = []
|
68 |
+
order = 1
|
69 |
+
|
70 |
+
@register_to_config
|
71 |
+
def __init__(
|
72 |
+
self,
|
73 |
+
num_train_timesteps: int = 1000,
|
74 |
+
shift: float = 1.0,
|
75 |
+
reverse: bool = True,
|
76 |
+
solver: str = "euler",
|
77 |
+
n_tokens: Optional[int] = None,
|
78 |
+
):
|
79 |
+
sigmas = torch.linspace(1, 0, num_train_timesteps + 1)
|
80 |
+
|
81 |
+
if not reverse:
|
82 |
+
sigmas = sigmas.flip(0)
|
83 |
+
|
84 |
+
self.sigmas = sigmas
|
85 |
+
# the value fed to model
|
86 |
+
self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32)
|
87 |
+
|
88 |
+
self._step_index = None
|
89 |
+
self._begin_index = None
|
90 |
+
|
91 |
+
self.supported_solver = ["euler"]
|
92 |
+
if solver not in self.supported_solver:
|
93 |
+
raise ValueError(
|
94 |
+
f"Solver {solver} not supported. Supported solvers: {self.supported_solver}"
|
95 |
+
)
|
96 |
+
|
97 |
+
@property
|
98 |
+
def step_index(self):
|
99 |
+
"""
|
100 |
+
The index counter for current timestep. It will increase 1 after each scheduler step.
|
101 |
+
"""
|
102 |
+
return self._step_index
|
103 |
+
|
104 |
+
@property
|
105 |
+
def begin_index(self):
|
106 |
+
"""
|
107 |
+
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
108 |
+
"""
|
109 |
+
return self._begin_index
|
110 |
+
|
111 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
112 |
+
def set_begin_index(self, begin_index: int = 0):
|
113 |
+
"""
|
114 |
+
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
115 |
+
|
116 |
+
Args:
|
117 |
+
begin_index (`int`):
|
118 |
+
The begin index for the scheduler.
|
119 |
+
"""
|
120 |
+
self._begin_index = begin_index
|
121 |
+
|
122 |
+
def _sigma_to_t(self, sigma):
|
123 |
+
return sigma * self.config.num_train_timesteps
|
124 |
+
|
125 |
+
def set_timesteps(
|
126 |
+
self,
|
127 |
+
num_inference_steps: int,
|
128 |
+
device: Union[str, torch.device] = None,
|
129 |
+
n_tokens: int = None,
|
130 |
+
):
|
131 |
+
"""
|
132 |
+
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
133 |
+
|
134 |
+
Args:
|
135 |
+
num_inference_steps (`int`):
|
136 |
+
The number of diffusion steps used when generating samples with a pre-trained model.
|
137 |
+
device (`str` or `torch.device`, *optional*):
|
138 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
139 |
+
n_tokens (`int`, *optional*):
|
140 |
+
Number of tokens in the input sequence.
|
141 |
+
"""
|
142 |
+
self.num_inference_steps = num_inference_steps
|
143 |
+
|
144 |
+
sigmas = torch.linspace(1, 0, num_inference_steps + 1)
|
145 |
+
sigmas = self.sd3_time_shift(sigmas)
|
146 |
+
|
147 |
+
if not self.config.reverse:
|
148 |
+
sigmas = 1 - sigmas
|
149 |
+
|
150 |
+
self.sigmas = sigmas
|
151 |
+
self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(
|
152 |
+
dtype=torch.float32, device=device
|
153 |
+
)
|
154 |
+
|
155 |
+
# Reset step index
|
156 |
+
self._step_index = None
|
157 |
+
|
158 |
+
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
159 |
+
if schedule_timesteps is None:
|
160 |
+
schedule_timesteps = self.timesteps
|
161 |
+
|
162 |
+
indices = (schedule_timesteps == timestep).nonzero()
|
163 |
+
|
164 |
+
# The sigma index that is taken for the **very** first `step`
|
165 |
+
# is always the second index (or the last index if there is only 1)
|
166 |
+
# This way we can ensure we don't accidentally skip a sigma in
|
167 |
+
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
168 |
+
pos = 1 if len(indices) > 1 else 0
|
169 |
+
|
170 |
+
return indices[pos].item()
|
171 |
+
|
172 |
+
def _init_step_index(self, timestep):
|
173 |
+
if self.begin_index is None:
|
174 |
+
if isinstance(timestep, torch.Tensor):
|
175 |
+
timestep = timestep.to(self.timesteps.device)
|
176 |
+
self._step_index = self.index_for_timestep(timestep)
|
177 |
+
else:
|
178 |
+
self._step_index = self._begin_index
|
179 |
+
|
180 |
+
def scale_model_input(
|
181 |
+
self, sample: torch.Tensor, timestep: Optional[int] = None
|
182 |
+
) -> torch.Tensor:
|
183 |
+
return sample
|
184 |
+
|
185 |
+
def sd3_time_shift(self, t: torch.Tensor):
|
186 |
+
return (self.config.shift * t) / (1 + (self.config.shift - 1) * t)
|
187 |
+
|
188 |
+
def step(
|
189 |
+
self,
|
190 |
+
model_output: torch.FloatTensor,
|
191 |
+
timestep: Union[float, torch.FloatTensor],
|
192 |
+
sample: torch.FloatTensor,
|
193 |
+
return_dict: bool = True,
|
194 |
+
) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
|
195 |
+
"""
|
196 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
197 |
+
process from the learned model outputs (most often the predicted noise).
|
198 |
+
|
199 |
+
Args:
|
200 |
+
model_output (`torch.FloatTensor`):
|
201 |
+
The direct output from learned diffusion model.
|
202 |
+
timestep (`float`):
|
203 |
+
The current discrete timestep in the diffusion chain.
|
204 |
+
sample (`torch.FloatTensor`):
|
205 |
+
A current instance of a sample created by the diffusion process.
|
206 |
+
generator (`torch.Generator`, *optional*):
|
207 |
+
A random number generator.
|
208 |
+
n_tokens (`int`, *optional*):
|
209 |
+
Number of tokens in the input sequence.
|
210 |
+
return_dict (`bool`):
|
211 |
+
Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
|
212 |
+
tuple.
|
213 |
+
|
214 |
+
Returns:
|
215 |
+
[`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
|
216 |
+
If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
|
217 |
+
returned, otherwise a tuple is returned where the first element is the sample tensor.
|
218 |
+
"""
|
219 |
+
|
220 |
+
if (
|
221 |
+
isinstance(timestep, int)
|
222 |
+
or isinstance(timestep, torch.IntTensor)
|
223 |
+
or isinstance(timestep, torch.LongTensor)
|
224 |
+
):
|
225 |
+
raise ValueError(
|
226 |
+
(
|
227 |
+
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
228 |
+
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
229 |
+
" one of the `scheduler.timesteps` as a timestep."
|
230 |
+
),
|
231 |
+
)
|
232 |
+
|
233 |
+
if self.step_index is None:
|
234 |
+
self._init_step_index(timestep)
|
235 |
+
|
236 |
+
# Upcast to avoid precision issues when computing prev_sample
|
237 |
+
sample = sample.to(torch.float32)
|
238 |
+
|
239 |
+
dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
|
240 |
+
|
241 |
+
if self.config.solver == "euler":
|
242 |
+
prev_sample = sample + model_output.to(torch.float32) * dt
|
243 |
+
else:
|
244 |
+
raise ValueError(
|
245 |
+
f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}"
|
246 |
+
)
|
247 |
+
|
248 |
+
# upon completion increase step index by one
|
249 |
+
self._step_index += 1
|
250 |
+
|
251 |
+
if not return_dict:
|
252 |
+
return (prev_sample,)
|
253 |
+
|
254 |
+
return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)
|
255 |
+
|
256 |
+
def __len__(self):
|
257 |
+
return self.config.num_train_timesteps
|
modules/unet_causal_3d_blocks.py
ADDED
@@ -0,0 +1,818 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
#
|
16 |
+
# Modified from diffusers==0.29.2
|
17 |
+
#
|
18 |
+
# ==============================================================================
|
19 |
+
|
20 |
+
from typing import Optional, Tuple, Union
|
21 |
+
|
22 |
+
import torch
|
23 |
+
import torch.nn.functional as F
|
24 |
+
from torch import nn
|
25 |
+
from einops import rearrange
|
26 |
+
|
27 |
+
from diffusers.utils import logging
|
28 |
+
from diffusers.models.activations import get_activation
|
29 |
+
from diffusers.models.attention_processor import SpatialNorm
|
30 |
+
from diffusers.models.attention_processor import Attention
|
31 |
+
from diffusers.models.normalization import AdaGroupNorm
|
32 |
+
from diffusers.models.normalization import RMSNorm
|
33 |
+
|
34 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
35 |
+
|
36 |
+
|
37 |
+
def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None):
|
38 |
+
seq_len = n_frame * n_hw
|
39 |
+
mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
|
40 |
+
for i in range(seq_len):
|
41 |
+
i_frame = i // n_hw
|
42 |
+
mask[i, : (i_frame + 1) * n_hw] = 0
|
43 |
+
if batch_size is not None:
|
44 |
+
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
|
45 |
+
return mask
|
46 |
+
|
47 |
+
|
48 |
+
class CausalConv3d(nn.Module):
|
49 |
+
"""
|
50 |
+
Implements a causal 3D convolution layer where each position only depends on previous timesteps and current spatial locations.
|
51 |
+
This maintains temporal causality in video generation tasks.
|
52 |
+
"""
|
53 |
+
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
chan_in,
|
57 |
+
chan_out,
|
58 |
+
kernel_size: Union[int, Tuple[int, int, int]],
|
59 |
+
stride: Union[int, Tuple[int, int, int]] = 1,
|
60 |
+
dilation: Union[int, Tuple[int, int, int]] = 1,
|
61 |
+
pad_mode="replicate",
|
62 |
+
chunk_size=0,
|
63 |
+
**kwargs,
|
64 |
+
):
|
65 |
+
super().__init__()
|
66 |
+
|
67 |
+
self.pad_mode = pad_mode
|
68 |
+
padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size - 1, 0) # W, H, T
|
69 |
+
self.time_causal_padding = padding
|
70 |
+
self.chunk_size = chunk_size
|
71 |
+
|
72 |
+
self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
73 |
+
|
74 |
+
def original_forward(self, x):
|
75 |
+
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
|
76 |
+
return self.conv(x)
|
77 |
+
|
78 |
+
def forward(self, x):
|
79 |
+
if self.chunk_size == 0:
|
80 |
+
return self.original_forward(x)
|
81 |
+
|
82 |
+
# if not large, call original forward
|
83 |
+
if x.shape[4] < self.chunk_size * 1.5:
|
84 |
+
return self.original_forward(x)
|
85 |
+
|
86 |
+
# # debug: verify the original forward is the same as chunked forward
|
87 |
+
# orig_forwarded_value = None
|
88 |
+
# if x.shape[4] < self.chunk_size * 4:
|
89 |
+
# orig_forwarded_value = self.original_forward(x)
|
90 |
+
|
91 |
+
# get the kernel size
|
92 |
+
kernel_size = self.conv.kernel_size[0] # assume cubic kernel
|
93 |
+
assert kernel_size == self.conv.kernel_size[1] == self.conv.kernel_size[2], "Only cubic kernels are supported"
|
94 |
+
padding_size = kernel_size // 2 # 1 for kernel_size=3, 0 for kernel_size=1
|
95 |
+
|
96 |
+
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
|
97 |
+
|
98 |
+
B, C, D, H, W = orig_shape = x.shape
|
99 |
+
chunk_size = self.chunk_size
|
100 |
+
chunk_size -= chunk_size % self.conv.stride[2] # make sure the chunk size is divisible by stride
|
101 |
+
# print(f"chunked forward: {x.shape}, chunk_size: {chunk_size}")
|
102 |
+
|
103 |
+
# calculate the indices for chunking with overlap and padding by kernel size and stride
|
104 |
+
indices = []
|
105 |
+
i = 0
|
106 |
+
while i < W - padding_size:
|
107 |
+
start_idx = i - padding_size
|
108 |
+
end_idx = min(i + chunk_size + padding_size, W)
|
109 |
+
if i == 0:
|
110 |
+
start_idx = 0
|
111 |
+
end_idx += padding_size # to make sure the first chunk is divisible by stride
|
112 |
+
if W - end_idx < chunk_size // 2: # small chunk at the end
|
113 |
+
end_idx = W
|
114 |
+
indices.append((start_idx, end_idx))
|
115 |
+
i = end_idx - padding_size
|
116 |
+
# print(f"chunked forward: {x.shape}, chunked indices: {indices}")
|
117 |
+
|
118 |
+
chunks = []
|
119 |
+
for start_idx, end_idx in indices:
|
120 |
+
chunk = x[:, :, :, :, start_idx:end_idx]
|
121 |
+
chunk_output = self.conv(chunk)
|
122 |
+
# print(chunk.shape, chunk_output.shape)
|
123 |
+
chunks.append(chunk_output)
|
124 |
+
|
125 |
+
# concatenate the chunks
|
126 |
+
x = torch.cat(chunks, dim=4)
|
127 |
+
|
128 |
+
assert (
|
129 |
+
x.shape[2] == ((D - padding_size * 2) + self.conv.stride[0] - 1) // self.conv.stride[0]
|
130 |
+
), f"Invalid shape: {x.shape}, {orig_shape}, {padding_size}, {self.conv.stride}"
|
131 |
+
assert (
|
132 |
+
x.shape[3] == ((H - padding_size * 2) + self.conv.stride[1] - 1) // self.conv.stride[1]
|
133 |
+
), f"Invalid shape: {x.shape}, {orig_shape}, {padding_size}, {self.conv.stride}"
|
134 |
+
assert (
|
135 |
+
x.shape[4] == ((W - padding_size * 2) + self.conv.stride[2] - 1) // self.conv.stride[2]
|
136 |
+
), f"Invalid shape: {x.shape}, {orig_shape}, {padding_size}, {self.conv.stride}"
|
137 |
+
|
138 |
+
# # debug: verify the original forward is the same as chunked forward
|
139 |
+
# if orig_forwarded_value is not None:
|
140 |
+
# assert torch.allclose(
|
141 |
+
# orig_forwarded_value, x, rtol=1e-4, atol=1e-2
|
142 |
+
# ), f"Chunked forward is different from original forward. {x.shape}, {orig_shape}, {padding_size}, {self.conv.stride}, {self.conv.kernel_size}"
|
143 |
+
|
144 |
+
return x
|
145 |
+
|
146 |
+
|
147 |
+
class UpsampleCausal3D(nn.Module):
|
148 |
+
"""
|
149 |
+
A 3D upsampling layer with an optional convolution.
|
150 |
+
"""
|
151 |
+
|
152 |
+
def __init__(
|
153 |
+
self,
|
154 |
+
channels: int,
|
155 |
+
use_conv: bool = False,
|
156 |
+
use_conv_transpose: bool = False,
|
157 |
+
out_channels: Optional[int] = None,
|
158 |
+
name: str = "conv",
|
159 |
+
kernel_size: Optional[int] = None,
|
160 |
+
padding=1,
|
161 |
+
norm_type=None,
|
162 |
+
eps=None,
|
163 |
+
elementwise_affine=None,
|
164 |
+
bias=True,
|
165 |
+
interpolate=True,
|
166 |
+
upsample_factor=(2, 2, 2),
|
167 |
+
):
|
168 |
+
super().__init__()
|
169 |
+
self.channels = channels
|
170 |
+
self.out_channels = out_channels or channels
|
171 |
+
self.use_conv = use_conv
|
172 |
+
self.use_conv_transpose = use_conv_transpose
|
173 |
+
self.name = name
|
174 |
+
self.interpolate = interpolate
|
175 |
+
self.upsample_factor = upsample_factor
|
176 |
+
|
177 |
+
if norm_type == "ln_norm":
|
178 |
+
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
|
179 |
+
elif norm_type == "rms_norm":
|
180 |
+
self.norm = RMSNorm(channels, eps, elementwise_affine)
|
181 |
+
elif norm_type is None:
|
182 |
+
self.norm = None
|
183 |
+
else:
|
184 |
+
raise ValueError(f"unknown norm_type: {norm_type}")
|
185 |
+
|
186 |
+
conv = None
|
187 |
+
if use_conv_transpose:
|
188 |
+
raise NotImplementedError
|
189 |
+
elif use_conv:
|
190 |
+
if kernel_size is None:
|
191 |
+
kernel_size = 3
|
192 |
+
conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias)
|
193 |
+
|
194 |
+
if name == "conv":
|
195 |
+
self.conv = conv
|
196 |
+
else:
|
197 |
+
self.Conv2d_0 = conv
|
198 |
+
|
199 |
+
def forward(
|
200 |
+
self,
|
201 |
+
hidden_states: torch.FloatTensor,
|
202 |
+
output_size: Optional[int] = None,
|
203 |
+
scale: float = 1.0,
|
204 |
+
) -> torch.FloatTensor:
|
205 |
+
assert hidden_states.shape[1] == self.channels
|
206 |
+
|
207 |
+
if self.norm is not None:
|
208 |
+
raise NotImplementedError
|
209 |
+
|
210 |
+
if self.use_conv_transpose:
|
211 |
+
return self.conv(hidden_states)
|
212 |
+
|
213 |
+
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
214 |
+
dtype = hidden_states.dtype
|
215 |
+
if dtype == torch.bfloat16:
|
216 |
+
hidden_states = hidden_states.to(torch.float32)
|
217 |
+
|
218 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
219 |
+
if hidden_states.shape[0] >= 64:
|
220 |
+
hidden_states = hidden_states.contiguous()
|
221 |
+
|
222 |
+
# if `output_size` is passed we force the interpolation output
|
223 |
+
# size and do not make use of `scale_factor=2`
|
224 |
+
if self.interpolate:
|
225 |
+
B, C, T, H, W = hidden_states.shape
|
226 |
+
first_h, other_h = hidden_states.split((1, T - 1), dim=2)
|
227 |
+
if output_size is None:
|
228 |
+
if T > 1:
|
229 |
+
other_h = F.interpolate(other_h, scale_factor=self.upsample_factor, mode="nearest")
|
230 |
+
|
231 |
+
first_h = first_h.squeeze(2)
|
232 |
+
first_h = F.interpolate(first_h, scale_factor=self.upsample_factor[1:], mode="nearest")
|
233 |
+
first_h = first_h.unsqueeze(2)
|
234 |
+
else:
|
235 |
+
raise NotImplementedError
|
236 |
+
|
237 |
+
if T > 1:
|
238 |
+
hidden_states = torch.cat((first_h, other_h), dim=2)
|
239 |
+
else:
|
240 |
+
hidden_states = first_h
|
241 |
+
|
242 |
+
# If the input is bfloat16, we cast back to bfloat16
|
243 |
+
if dtype == torch.bfloat16:
|
244 |
+
hidden_states = hidden_states.to(dtype)
|
245 |
+
|
246 |
+
if self.use_conv:
|
247 |
+
if self.name == "conv":
|
248 |
+
hidden_states = self.conv(hidden_states)
|
249 |
+
else:
|
250 |
+
hidden_states = self.Conv2d_0(hidden_states)
|
251 |
+
|
252 |
+
return hidden_states
|
253 |
+
|
254 |
+
|
255 |
+
class DownsampleCausal3D(nn.Module):
|
256 |
+
"""
|
257 |
+
A 3D downsampling layer with an optional convolution.
|
258 |
+
"""
|
259 |
+
|
260 |
+
def __init__(
|
261 |
+
self,
|
262 |
+
channels: int,
|
263 |
+
use_conv: bool = False,
|
264 |
+
out_channels: Optional[int] = None,
|
265 |
+
padding: int = 1,
|
266 |
+
name: str = "conv",
|
267 |
+
kernel_size=3,
|
268 |
+
norm_type=None,
|
269 |
+
eps=None,
|
270 |
+
elementwise_affine=None,
|
271 |
+
bias=True,
|
272 |
+
stride=2,
|
273 |
+
):
|
274 |
+
super().__init__()
|
275 |
+
self.channels = channels
|
276 |
+
self.out_channels = out_channels or channels
|
277 |
+
self.use_conv = use_conv
|
278 |
+
self.padding = padding
|
279 |
+
stride = stride
|
280 |
+
self.name = name
|
281 |
+
|
282 |
+
if norm_type == "ln_norm":
|
283 |
+
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
|
284 |
+
elif norm_type == "rms_norm":
|
285 |
+
self.norm = RMSNorm(channels, eps, elementwise_affine)
|
286 |
+
elif norm_type is None:
|
287 |
+
self.norm = None
|
288 |
+
else:
|
289 |
+
raise ValueError(f"unknown norm_type: {norm_type}")
|
290 |
+
|
291 |
+
if use_conv:
|
292 |
+
conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias)
|
293 |
+
else:
|
294 |
+
raise NotImplementedError
|
295 |
+
|
296 |
+
if name == "conv":
|
297 |
+
self.Conv2d_0 = conv
|
298 |
+
self.conv = conv
|
299 |
+
elif name == "Conv2d_0":
|
300 |
+
self.conv = conv
|
301 |
+
else:
|
302 |
+
self.conv = conv
|
303 |
+
|
304 |
+
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
|
305 |
+
assert hidden_states.shape[1] == self.channels
|
306 |
+
|
307 |
+
if self.norm is not None:
|
308 |
+
hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
309 |
+
|
310 |
+
assert hidden_states.shape[1] == self.channels
|
311 |
+
|
312 |
+
hidden_states = self.conv(hidden_states)
|
313 |
+
|
314 |
+
return hidden_states
|
315 |
+
|
316 |
+
|
317 |
+
class ResnetBlockCausal3D(nn.Module):
|
318 |
+
r"""
|
319 |
+
A Resnet block.
|
320 |
+
"""
|
321 |
+
|
322 |
+
def __init__(
|
323 |
+
self,
|
324 |
+
*,
|
325 |
+
in_channels: int,
|
326 |
+
out_channels: Optional[int] = None,
|
327 |
+
conv_shortcut: bool = False,
|
328 |
+
dropout: float = 0.0,
|
329 |
+
temb_channels: int = 512,
|
330 |
+
groups: int = 32,
|
331 |
+
groups_out: Optional[int] = None,
|
332 |
+
pre_norm: bool = True,
|
333 |
+
eps: float = 1e-6,
|
334 |
+
non_linearity: str = "swish",
|
335 |
+
skip_time_act: bool = False,
|
336 |
+
# default, scale_shift, ada_group, spatial
|
337 |
+
time_embedding_norm: str = "default",
|
338 |
+
kernel: Optional[torch.FloatTensor] = None,
|
339 |
+
output_scale_factor: float = 1.0,
|
340 |
+
use_in_shortcut: Optional[bool] = None,
|
341 |
+
up: bool = False,
|
342 |
+
down: bool = False,
|
343 |
+
conv_shortcut_bias: bool = True,
|
344 |
+
conv_3d_out_channels: Optional[int] = None,
|
345 |
+
):
|
346 |
+
super().__init__()
|
347 |
+
self.pre_norm = pre_norm
|
348 |
+
self.pre_norm = True
|
349 |
+
self.in_channels = in_channels
|
350 |
+
out_channels = in_channels if out_channels is None else out_channels
|
351 |
+
self.out_channels = out_channels
|
352 |
+
self.use_conv_shortcut = conv_shortcut
|
353 |
+
self.up = up
|
354 |
+
self.down = down
|
355 |
+
self.output_scale_factor = output_scale_factor
|
356 |
+
self.time_embedding_norm = time_embedding_norm
|
357 |
+
self.skip_time_act = skip_time_act
|
358 |
+
|
359 |
+
linear_cls = nn.Linear
|
360 |
+
|
361 |
+
if groups_out is None:
|
362 |
+
groups_out = groups
|
363 |
+
|
364 |
+
if self.time_embedding_norm == "ada_group":
|
365 |
+
self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
|
366 |
+
elif self.time_embedding_norm == "spatial":
|
367 |
+
self.norm1 = SpatialNorm(in_channels, temb_channels)
|
368 |
+
else:
|
369 |
+
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
370 |
+
|
371 |
+
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)
|
372 |
+
|
373 |
+
if temb_channels is not None:
|
374 |
+
if self.time_embedding_norm == "default":
|
375 |
+
self.time_emb_proj = linear_cls(temb_channels, out_channels)
|
376 |
+
elif self.time_embedding_norm == "scale_shift":
|
377 |
+
self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
|
378 |
+
elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
|
379 |
+
self.time_emb_proj = None
|
380 |
+
else:
|
381 |
+
raise ValueError(f"Unknown time_embedding_norm : {self.time_embedding_norm} ")
|
382 |
+
else:
|
383 |
+
self.time_emb_proj = None
|
384 |
+
|
385 |
+
if self.time_embedding_norm == "ada_group":
|
386 |
+
self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
|
387 |
+
elif self.time_embedding_norm == "spatial":
|
388 |
+
self.norm2 = SpatialNorm(out_channels, temb_channels)
|
389 |
+
else:
|
390 |
+
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
391 |
+
|
392 |
+
self.dropout = torch.nn.Dropout(dropout)
|
393 |
+
conv_3d_out_channels = conv_3d_out_channels or out_channels
|
394 |
+
self.conv2 = CausalConv3d(out_channels, conv_3d_out_channels, kernel_size=3, stride=1)
|
395 |
+
|
396 |
+
self.nonlinearity = get_activation(non_linearity)
|
397 |
+
|
398 |
+
self.upsample = self.downsample = None
|
399 |
+
if self.up:
|
400 |
+
self.upsample = UpsampleCausal3D(in_channels, use_conv=False)
|
401 |
+
elif self.down:
|
402 |
+
self.downsample = DownsampleCausal3D(in_channels, use_conv=False, name="op")
|
403 |
+
|
404 |
+
self.use_in_shortcut = self.in_channels != conv_3d_out_channels if use_in_shortcut is None else use_in_shortcut
|
405 |
+
|
406 |
+
self.conv_shortcut = None
|
407 |
+
if self.use_in_shortcut:
|
408 |
+
self.conv_shortcut = CausalConv3d(
|
409 |
+
in_channels,
|
410 |
+
conv_3d_out_channels,
|
411 |
+
kernel_size=1,
|
412 |
+
stride=1,
|
413 |
+
bias=conv_shortcut_bias,
|
414 |
+
)
|
415 |
+
|
416 |
+
def forward(
|
417 |
+
self,
|
418 |
+
input_tensor: torch.FloatTensor,
|
419 |
+
temb: torch.FloatTensor,
|
420 |
+
scale: float = 1.0,
|
421 |
+
) -> torch.FloatTensor:
|
422 |
+
hidden_states = input_tensor
|
423 |
+
|
424 |
+
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
|
425 |
+
hidden_states = self.norm1(hidden_states, temb)
|
426 |
+
else:
|
427 |
+
hidden_states = self.norm1(hidden_states)
|
428 |
+
|
429 |
+
hidden_states = self.nonlinearity(hidden_states)
|
430 |
+
|
431 |
+
if self.upsample is not None:
|
432 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
433 |
+
if hidden_states.shape[0] >= 64:
|
434 |
+
input_tensor = input_tensor.contiguous()
|
435 |
+
hidden_states = hidden_states.contiguous()
|
436 |
+
input_tensor = self.upsample(input_tensor, scale=scale)
|
437 |
+
hidden_states = self.upsample(hidden_states, scale=scale)
|
438 |
+
elif self.downsample is not None:
|
439 |
+
input_tensor = self.downsample(input_tensor, scale=scale)
|
440 |
+
hidden_states = self.downsample(hidden_states, scale=scale)
|
441 |
+
|
442 |
+
hidden_states = self.conv1(hidden_states)
|
443 |
+
|
444 |
+
if self.time_emb_proj is not None:
|
445 |
+
if not self.skip_time_act:
|
446 |
+
temb = self.nonlinearity(temb)
|
447 |
+
temb = self.time_emb_proj(temb, scale)[:, :, None, None]
|
448 |
+
|
449 |
+
if temb is not None and self.time_embedding_norm == "default":
|
450 |
+
hidden_states = hidden_states + temb
|
451 |
+
|
452 |
+
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
|
453 |
+
hidden_states = self.norm2(hidden_states, temb)
|
454 |
+
else:
|
455 |
+
hidden_states = self.norm2(hidden_states)
|
456 |
+
|
457 |
+
if temb is not None and self.time_embedding_norm == "scale_shift":
|
458 |
+
scale, shift = torch.chunk(temb, 2, dim=1)
|
459 |
+
hidden_states = hidden_states * (1 + scale) + shift
|
460 |
+
|
461 |
+
hidden_states = self.nonlinearity(hidden_states)
|
462 |
+
|
463 |
+
hidden_states = self.dropout(hidden_states)
|
464 |
+
hidden_states = self.conv2(hidden_states)
|
465 |
+
|
466 |
+
if self.conv_shortcut is not None:
|
467 |
+
input_tensor = self.conv_shortcut(input_tensor)
|
468 |
+
|
469 |
+
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
470 |
+
|
471 |
+
return output_tensor
|
472 |
+
|
473 |
+
|
474 |
+
def get_down_block3d(
|
475 |
+
down_block_type: str,
|
476 |
+
num_layers: int,
|
477 |
+
in_channels: int,
|
478 |
+
out_channels: int,
|
479 |
+
temb_channels: int,
|
480 |
+
add_downsample: bool,
|
481 |
+
downsample_stride: int,
|
482 |
+
resnet_eps: float,
|
483 |
+
resnet_act_fn: str,
|
484 |
+
transformer_layers_per_block: int = 1,
|
485 |
+
num_attention_heads: Optional[int] = None,
|
486 |
+
resnet_groups: Optional[int] = None,
|
487 |
+
cross_attention_dim: Optional[int] = None,
|
488 |
+
downsample_padding: Optional[int] = None,
|
489 |
+
dual_cross_attention: bool = False,
|
490 |
+
use_linear_projection: bool = False,
|
491 |
+
only_cross_attention: bool = False,
|
492 |
+
upcast_attention: bool = False,
|
493 |
+
resnet_time_scale_shift: str = "default",
|
494 |
+
attention_type: str = "default",
|
495 |
+
resnet_skip_time_act: bool = False,
|
496 |
+
resnet_out_scale_factor: float = 1.0,
|
497 |
+
cross_attention_norm: Optional[str] = None,
|
498 |
+
attention_head_dim: Optional[int] = None,
|
499 |
+
downsample_type: Optional[str] = None,
|
500 |
+
dropout: float = 0.0,
|
501 |
+
):
|
502 |
+
# If attn head dim is not defined, we default it to the number of heads
|
503 |
+
if attention_head_dim is None:
|
504 |
+
logger.warn(
|
505 |
+
f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
|
506 |
+
)
|
507 |
+
attention_head_dim = num_attention_heads
|
508 |
+
|
509 |
+
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
|
510 |
+
if down_block_type == "DownEncoderBlockCausal3D":
|
511 |
+
return DownEncoderBlockCausal3D(
|
512 |
+
num_layers=num_layers,
|
513 |
+
in_channels=in_channels,
|
514 |
+
out_channels=out_channels,
|
515 |
+
dropout=dropout,
|
516 |
+
add_downsample=add_downsample,
|
517 |
+
downsample_stride=downsample_stride,
|
518 |
+
resnet_eps=resnet_eps,
|
519 |
+
resnet_act_fn=resnet_act_fn,
|
520 |
+
resnet_groups=resnet_groups,
|
521 |
+
downsample_padding=downsample_padding,
|
522 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
523 |
+
)
|
524 |
+
raise ValueError(f"{down_block_type} does not exist.")
|
525 |
+
|
526 |
+
|
527 |
+
def get_up_block3d(
|
528 |
+
up_block_type: str,
|
529 |
+
num_layers: int,
|
530 |
+
in_channels: int,
|
531 |
+
out_channels: int,
|
532 |
+
prev_output_channel: int,
|
533 |
+
temb_channels: int,
|
534 |
+
add_upsample: bool,
|
535 |
+
upsample_scale_factor: Tuple,
|
536 |
+
resnet_eps: float,
|
537 |
+
resnet_act_fn: str,
|
538 |
+
resolution_idx: Optional[int] = None,
|
539 |
+
transformer_layers_per_block: int = 1,
|
540 |
+
num_attention_heads: Optional[int] = None,
|
541 |
+
resnet_groups: Optional[int] = None,
|
542 |
+
cross_attention_dim: Optional[int] = None,
|
543 |
+
dual_cross_attention: bool = False,
|
544 |
+
use_linear_projection: bool = False,
|
545 |
+
only_cross_attention: bool = False,
|
546 |
+
upcast_attention: bool = False,
|
547 |
+
resnet_time_scale_shift: str = "default",
|
548 |
+
attention_type: str = "default",
|
549 |
+
resnet_skip_time_act: bool = False,
|
550 |
+
resnet_out_scale_factor: float = 1.0,
|
551 |
+
cross_attention_norm: Optional[str] = None,
|
552 |
+
attention_head_dim: Optional[int] = None,
|
553 |
+
upsample_type: Optional[str] = None,
|
554 |
+
dropout: float = 0.0,
|
555 |
+
) -> nn.Module:
|
556 |
+
# If attn head dim is not defined, we default it to the number of heads
|
557 |
+
if attention_head_dim is None:
|
558 |
+
logger.warn(
|
559 |
+
f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
|
560 |
+
)
|
561 |
+
attention_head_dim = num_attention_heads
|
562 |
+
|
563 |
+
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
564 |
+
if up_block_type == "UpDecoderBlockCausal3D":
|
565 |
+
return UpDecoderBlockCausal3D(
|
566 |
+
num_layers=num_layers,
|
567 |
+
in_channels=in_channels,
|
568 |
+
out_channels=out_channels,
|
569 |
+
resolution_idx=resolution_idx,
|
570 |
+
dropout=dropout,
|
571 |
+
add_upsample=add_upsample,
|
572 |
+
upsample_scale_factor=upsample_scale_factor,
|
573 |
+
resnet_eps=resnet_eps,
|
574 |
+
resnet_act_fn=resnet_act_fn,
|
575 |
+
resnet_groups=resnet_groups,
|
576 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
577 |
+
temb_channels=temb_channels,
|
578 |
+
)
|
579 |
+
raise ValueError(f"{up_block_type} does not exist.")
|
580 |
+
|
581 |
+
|
582 |
+
class UNetMidBlockCausal3D(nn.Module):
|
583 |
+
"""
|
584 |
+
A 3D UNet mid-block [`UNetMidBlockCausal3D`] with multiple residual blocks and optional attention blocks.
|
585 |
+
"""
|
586 |
+
|
587 |
+
def __init__(
|
588 |
+
self,
|
589 |
+
in_channels: int,
|
590 |
+
temb_channels: int,
|
591 |
+
dropout: float = 0.0,
|
592 |
+
num_layers: int = 1,
|
593 |
+
resnet_eps: float = 1e-6,
|
594 |
+
resnet_time_scale_shift: str = "default", # default, spatial
|
595 |
+
resnet_act_fn: str = "swish",
|
596 |
+
resnet_groups: int = 32,
|
597 |
+
attn_groups: Optional[int] = None,
|
598 |
+
resnet_pre_norm: bool = True,
|
599 |
+
add_attention: bool = True,
|
600 |
+
attention_head_dim: int = 1,
|
601 |
+
output_scale_factor: float = 1.0,
|
602 |
+
):
|
603 |
+
super().__init__()
|
604 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
605 |
+
self.add_attention = add_attention
|
606 |
+
|
607 |
+
if attn_groups is None:
|
608 |
+
attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
|
609 |
+
|
610 |
+
# there is always at least one resnet
|
611 |
+
resnets = [
|
612 |
+
ResnetBlockCausal3D(
|
613 |
+
in_channels=in_channels,
|
614 |
+
out_channels=in_channels,
|
615 |
+
temb_channels=temb_channels,
|
616 |
+
eps=resnet_eps,
|
617 |
+
groups=resnet_groups,
|
618 |
+
dropout=dropout,
|
619 |
+
time_embedding_norm=resnet_time_scale_shift,
|
620 |
+
non_linearity=resnet_act_fn,
|
621 |
+
output_scale_factor=output_scale_factor,
|
622 |
+
pre_norm=resnet_pre_norm,
|
623 |
+
)
|
624 |
+
]
|
625 |
+
attentions = []
|
626 |
+
|
627 |
+
if attention_head_dim is None:
|
628 |
+
logger.warn(
|
629 |
+
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
|
630 |
+
)
|
631 |
+
attention_head_dim = in_channels
|
632 |
+
|
633 |
+
for _ in range(num_layers):
|
634 |
+
if self.add_attention:
|
635 |
+
attentions.append(
|
636 |
+
Attention(
|
637 |
+
in_channels,
|
638 |
+
heads=in_channels // attention_head_dim,
|
639 |
+
dim_head=attention_head_dim,
|
640 |
+
rescale_output_factor=output_scale_factor,
|
641 |
+
eps=resnet_eps,
|
642 |
+
norm_num_groups=attn_groups,
|
643 |
+
spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
|
644 |
+
residual_connection=True,
|
645 |
+
bias=True,
|
646 |
+
upcast_softmax=True,
|
647 |
+
_from_deprecated_attn_block=True,
|
648 |
+
)
|
649 |
+
)
|
650 |
+
else:
|
651 |
+
attentions.append(None)
|
652 |
+
|
653 |
+
resnets.append(
|
654 |
+
ResnetBlockCausal3D(
|
655 |
+
in_channels=in_channels,
|
656 |
+
out_channels=in_channels,
|
657 |
+
temb_channels=temb_channels,
|
658 |
+
eps=resnet_eps,
|
659 |
+
groups=resnet_groups,
|
660 |
+
dropout=dropout,
|
661 |
+
time_embedding_norm=resnet_time_scale_shift,
|
662 |
+
non_linearity=resnet_act_fn,
|
663 |
+
output_scale_factor=output_scale_factor,
|
664 |
+
pre_norm=resnet_pre_norm,
|
665 |
+
)
|
666 |
+
)
|
667 |
+
|
668 |
+
self.attentions = nn.ModuleList(attentions)
|
669 |
+
self.resnets = nn.ModuleList(resnets)
|
670 |
+
|
671 |
+
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
672 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
673 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
674 |
+
if attn is not None:
|
675 |
+
B, C, T, H, W = hidden_states.shape
|
676 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
|
677 |
+
attention_mask = prepare_causal_attention_mask(T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B)
|
678 |
+
hidden_states = attn(hidden_states, temb=temb, attention_mask=attention_mask)
|
679 |
+
hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
|
680 |
+
hidden_states = resnet(hidden_states, temb)
|
681 |
+
|
682 |
+
return hidden_states
|
683 |
+
|
684 |
+
|
685 |
+
class DownEncoderBlockCausal3D(nn.Module):
|
686 |
+
def __init__(
|
687 |
+
self,
|
688 |
+
in_channels: int,
|
689 |
+
out_channels: int,
|
690 |
+
dropout: float = 0.0,
|
691 |
+
num_layers: int = 1,
|
692 |
+
resnet_eps: float = 1e-6,
|
693 |
+
resnet_time_scale_shift: str = "default",
|
694 |
+
resnet_act_fn: str = "swish",
|
695 |
+
resnet_groups: int = 32,
|
696 |
+
resnet_pre_norm: bool = True,
|
697 |
+
output_scale_factor: float = 1.0,
|
698 |
+
add_downsample: bool = True,
|
699 |
+
downsample_stride: int = 2,
|
700 |
+
downsample_padding: int = 1,
|
701 |
+
):
|
702 |
+
super().__init__()
|
703 |
+
resnets = []
|
704 |
+
|
705 |
+
for i in range(num_layers):
|
706 |
+
in_channels = in_channels if i == 0 else out_channels
|
707 |
+
resnets.append(
|
708 |
+
ResnetBlockCausal3D(
|
709 |
+
in_channels=in_channels,
|
710 |
+
out_channels=out_channels,
|
711 |
+
temb_channels=None,
|
712 |
+
eps=resnet_eps,
|
713 |
+
groups=resnet_groups,
|
714 |
+
dropout=dropout,
|
715 |
+
time_embedding_norm=resnet_time_scale_shift,
|
716 |
+
non_linearity=resnet_act_fn,
|
717 |
+
output_scale_factor=output_scale_factor,
|
718 |
+
pre_norm=resnet_pre_norm,
|
719 |
+
)
|
720 |
+
)
|
721 |
+
|
722 |
+
self.resnets = nn.ModuleList(resnets)
|
723 |
+
|
724 |
+
if add_downsample:
|
725 |
+
self.downsamplers = nn.ModuleList(
|
726 |
+
[
|
727 |
+
DownsampleCausal3D(
|
728 |
+
out_channels,
|
729 |
+
use_conv=True,
|
730 |
+
out_channels=out_channels,
|
731 |
+
padding=downsample_padding,
|
732 |
+
name="op",
|
733 |
+
stride=downsample_stride,
|
734 |
+
)
|
735 |
+
]
|
736 |
+
)
|
737 |
+
else:
|
738 |
+
self.downsamplers = None
|
739 |
+
|
740 |
+
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
|
741 |
+
for resnet in self.resnets:
|
742 |
+
hidden_states = resnet(hidden_states, temb=None, scale=scale)
|
743 |
+
|
744 |
+
if self.downsamplers is not None:
|
745 |
+
for downsampler in self.downsamplers:
|
746 |
+
hidden_states = downsampler(hidden_states, scale)
|
747 |
+
|
748 |
+
return hidden_states
|
749 |
+
|
750 |
+
|
751 |
+
class UpDecoderBlockCausal3D(nn.Module):
|
752 |
+
def __init__(
|
753 |
+
self,
|
754 |
+
in_channels: int,
|
755 |
+
out_channels: int,
|
756 |
+
resolution_idx: Optional[int] = None,
|
757 |
+
dropout: float = 0.0,
|
758 |
+
num_layers: int = 1,
|
759 |
+
resnet_eps: float = 1e-6,
|
760 |
+
resnet_time_scale_shift: str = "default", # default, spatial
|
761 |
+
resnet_act_fn: str = "swish",
|
762 |
+
resnet_groups: int = 32,
|
763 |
+
resnet_pre_norm: bool = True,
|
764 |
+
output_scale_factor: float = 1.0,
|
765 |
+
add_upsample: bool = True,
|
766 |
+
upsample_scale_factor=(2, 2, 2),
|
767 |
+
temb_channels: Optional[int] = None,
|
768 |
+
):
|
769 |
+
super().__init__()
|
770 |
+
resnets = []
|
771 |
+
|
772 |
+
for i in range(num_layers):
|
773 |
+
input_channels = in_channels if i == 0 else out_channels
|
774 |
+
|
775 |
+
resnets.append(
|
776 |
+
ResnetBlockCausal3D(
|
777 |
+
in_channels=input_channels,
|
778 |
+
out_channels=out_channels,
|
779 |
+
temb_channels=temb_channels,
|
780 |
+
eps=resnet_eps,
|
781 |
+
groups=resnet_groups,
|
782 |
+
dropout=dropout,
|
783 |
+
time_embedding_norm=resnet_time_scale_shift,
|
784 |
+
non_linearity=resnet_act_fn,
|
785 |
+
output_scale_factor=output_scale_factor,
|
786 |
+
pre_norm=resnet_pre_norm,
|
787 |
+
)
|
788 |
+
)
|
789 |
+
|
790 |
+
self.resnets = nn.ModuleList(resnets)
|
791 |
+
|
792 |
+
if add_upsample:
|
793 |
+
self.upsamplers = nn.ModuleList(
|
794 |
+
[
|
795 |
+
UpsampleCausal3D(
|
796 |
+
out_channels,
|
797 |
+
use_conv=True,
|
798 |
+
out_channels=out_channels,
|
799 |
+
upsample_factor=upsample_scale_factor,
|
800 |
+
)
|
801 |
+
]
|
802 |
+
)
|
803 |
+
else:
|
804 |
+
self.upsamplers = None
|
805 |
+
|
806 |
+
self.resolution_idx = resolution_idx
|
807 |
+
|
808 |
+
def forward(
|
809 |
+
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
|
810 |
+
) -> torch.FloatTensor:
|
811 |
+
for resnet in self.resnets:
|
812 |
+
hidden_states = resnet(hidden_states, temb=temb, scale=scale)
|
813 |
+
|
814 |
+
if self.upsamplers is not None:
|
815 |
+
for upsampler in self.upsamplers:
|
816 |
+
hidden_states = upsampler(hidden_states)
|
817 |
+
|
818 |
+
return hidden_states
|
networks/__init__.py
ADDED
File without changes
|
networks/lora.py
ADDED
@@ -0,0 +1,914 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# LoRA network module: currently conv2d is not fully supported
|
2 |
+
# reference:
|
3 |
+
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
|
4 |
+
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
|
5 |
+
|
6 |
+
import ast
|
7 |
+
import math
|
8 |
+
import os
|
9 |
+
import re
|
10 |
+
from typing import Dict, List, Optional, Type, Union
|
11 |
+
from diffusers import AutoencoderKL
|
12 |
+
from transformers import CLIPTextModel
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
|
17 |
+
import logging
|
18 |
+
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
logging.basicConfig(level=logging.INFO)
|
21 |
+
|
22 |
+
HUNYUAN_TARGET_REPLACE_MODULES = ["MMDoubleStreamBlock", "MMSingleStreamBlock"]
|
23 |
+
|
24 |
+
|
25 |
+
class LoRAModule(torch.nn.Module):
|
26 |
+
"""
|
27 |
+
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
28 |
+
"""
|
29 |
+
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
lora_name,
|
33 |
+
org_module: torch.nn.Module,
|
34 |
+
multiplier=1.0,
|
35 |
+
lora_dim=4,
|
36 |
+
alpha=1,
|
37 |
+
dropout=None,
|
38 |
+
rank_dropout=None,
|
39 |
+
module_dropout=None,
|
40 |
+
split_dims: Optional[List[int]] = None,
|
41 |
+
):
|
42 |
+
"""
|
43 |
+
if alpha == 0 or None, alpha is rank (no scaling).
|
44 |
+
|
45 |
+
split_dims is used to mimic the split qkv of multi-head attention.
|
46 |
+
"""
|
47 |
+
super().__init__()
|
48 |
+
self.lora_name = lora_name
|
49 |
+
|
50 |
+
if org_module.__class__.__name__ == "Conv2d":
|
51 |
+
in_dim = org_module.in_channels
|
52 |
+
out_dim = org_module.out_channels
|
53 |
+
else:
|
54 |
+
in_dim = org_module.in_features
|
55 |
+
out_dim = org_module.out_features
|
56 |
+
|
57 |
+
self.lora_dim = lora_dim
|
58 |
+
self.split_dims = split_dims
|
59 |
+
|
60 |
+
if split_dims is None:
|
61 |
+
if org_module.__class__.__name__ == "Conv2d":
|
62 |
+
kernel_size = org_module.kernel_size
|
63 |
+
stride = org_module.stride
|
64 |
+
padding = org_module.padding
|
65 |
+
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
|
66 |
+
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
|
67 |
+
else:
|
68 |
+
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
|
69 |
+
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
|
70 |
+
|
71 |
+
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
72 |
+
torch.nn.init.zeros_(self.lora_up.weight)
|
73 |
+
else:
|
74 |
+
# conv2d not supported
|
75 |
+
assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim"
|
76 |
+
assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear"
|
77 |
+
# print(f"split_dims: {split_dims}")
|
78 |
+
self.lora_down = torch.nn.ModuleList(
|
79 |
+
[torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))]
|
80 |
+
)
|
81 |
+
self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims])
|
82 |
+
for lora_down in self.lora_down:
|
83 |
+
torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5))
|
84 |
+
for lora_up in self.lora_up:
|
85 |
+
torch.nn.init.zeros_(lora_up.weight)
|
86 |
+
|
87 |
+
if type(alpha) == torch.Tensor:
|
88 |
+
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
89 |
+
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
|
90 |
+
self.scale = alpha / self.lora_dim
|
91 |
+
self.register_buffer("alpha", torch.tensor(alpha)) # for save/load
|
92 |
+
|
93 |
+
# same as microsoft's
|
94 |
+
self.multiplier = multiplier
|
95 |
+
self.org_module = org_module # remove in applying
|
96 |
+
self.dropout = dropout
|
97 |
+
self.rank_dropout = rank_dropout
|
98 |
+
self.module_dropout = module_dropout
|
99 |
+
|
100 |
+
def apply_to(self):
|
101 |
+
self.org_forward = self.org_module.forward
|
102 |
+
self.org_module.forward = self.forward
|
103 |
+
del self.org_module
|
104 |
+
|
105 |
+
def forward(self, x):
|
106 |
+
org_forwarded = self.org_forward(x)
|
107 |
+
|
108 |
+
# module dropout
|
109 |
+
if self.module_dropout is not None and self.training:
|
110 |
+
if torch.rand(1) < self.module_dropout:
|
111 |
+
return org_forwarded
|
112 |
+
|
113 |
+
if self.split_dims is None:
|
114 |
+
lx = self.lora_down(x)
|
115 |
+
|
116 |
+
# normal dropout
|
117 |
+
if self.dropout is not None and self.training:
|
118 |
+
lx = torch.nn.functional.dropout(lx, p=self.dropout)
|
119 |
+
|
120 |
+
# rank dropout
|
121 |
+
if self.rank_dropout is not None and self.training:
|
122 |
+
mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
|
123 |
+
if len(lx.size()) == 3:
|
124 |
+
mask = mask.unsqueeze(1) # for Text Encoder
|
125 |
+
elif len(lx.size()) == 4:
|
126 |
+
mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
|
127 |
+
lx = lx * mask
|
128 |
+
|
129 |
+
# scaling for rank dropout: treat as if the rank is changed
|
130 |
+
scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
|
131 |
+
else:
|
132 |
+
scale = self.scale
|
133 |
+
|
134 |
+
lx = self.lora_up(lx)
|
135 |
+
|
136 |
+
return org_forwarded + lx * self.multiplier * scale
|
137 |
+
else:
|
138 |
+
lxs = [lora_down(x) for lora_down in self.lora_down]
|
139 |
+
|
140 |
+
# normal dropout
|
141 |
+
if self.dropout is not None and self.training:
|
142 |
+
lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs]
|
143 |
+
|
144 |
+
# rank dropout
|
145 |
+
if self.rank_dropout is not None and self.training:
|
146 |
+
masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs]
|
147 |
+
for i in range(len(lxs)):
|
148 |
+
if len(lx.size()) == 3:
|
149 |
+
masks[i] = masks[i].unsqueeze(1)
|
150 |
+
elif len(lx.size()) == 4:
|
151 |
+
masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1)
|
152 |
+
lxs[i] = lxs[i] * masks[i]
|
153 |
+
|
154 |
+
# scaling for rank dropout: treat as if the rank is changed
|
155 |
+
scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
|
156 |
+
else:
|
157 |
+
scale = self.scale
|
158 |
+
|
159 |
+
lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)]
|
160 |
+
|
161 |
+
return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale
|
162 |
+
|
163 |
+
|
164 |
+
class LoRAInfModule(LoRAModule):
|
165 |
+
def __init__(
|
166 |
+
self,
|
167 |
+
lora_name,
|
168 |
+
org_module: torch.nn.Module,
|
169 |
+
multiplier=1.0,
|
170 |
+
lora_dim=4,
|
171 |
+
alpha=1,
|
172 |
+
**kwargs,
|
173 |
+
):
|
174 |
+
# no dropout for inference
|
175 |
+
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
|
176 |
+
|
177 |
+
self.org_module_ref = [org_module] # for reference
|
178 |
+
self.enabled = True
|
179 |
+
self.network: LoRANetwork = None
|
180 |
+
|
181 |
+
def set_network(self, network):
|
182 |
+
self.network = network
|
183 |
+
|
184 |
+
# merge weight to org_module
|
185 |
+
# def merge_to(self, sd, dtype, device, non_blocking=False):
|
186 |
+
# if torch.cuda.is_available():
|
187 |
+
# stream = torch.cuda.Stream(device=device)
|
188 |
+
# with torch.cuda.stream(stream):
|
189 |
+
# print(f"merge_to {self.lora_name}")
|
190 |
+
# self._merge_to(sd, dtype, device, non_blocking)
|
191 |
+
# torch.cuda.synchronize(device=device)
|
192 |
+
# print(f"merge_to {self.lora_name} done")
|
193 |
+
# torch.cuda.empty_cache()
|
194 |
+
# else:
|
195 |
+
# self._merge_to(sd, dtype, device, non_blocking)
|
196 |
+
|
197 |
+
def merge_to(self, sd, dtype, device, non_blocking=False):
|
198 |
+
# extract weight from org_module
|
199 |
+
org_sd = self.org_module.state_dict()
|
200 |
+
weight = org_sd["weight"]
|
201 |
+
org_dtype = weight.dtype
|
202 |
+
org_device = weight.device
|
203 |
+
weight = weight.to(device, dtype=torch.float, non_blocking=non_blocking) # for calculation
|
204 |
+
|
205 |
+
if dtype is None:
|
206 |
+
dtype = org_dtype
|
207 |
+
if device is None:
|
208 |
+
device = org_device
|
209 |
+
|
210 |
+
if self.split_dims is None:
|
211 |
+
# get up/down weight
|
212 |
+
down_weight = sd["lora_down.weight"].to(device, dtype=torch.float, non_blocking=non_blocking)
|
213 |
+
up_weight = sd["lora_up.weight"].to(device, dtype=torch.float, non_blocking=non_blocking)
|
214 |
+
|
215 |
+
# merge weight
|
216 |
+
if len(weight.size()) == 2:
|
217 |
+
# linear
|
218 |
+
weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
|
219 |
+
elif down_weight.size()[2:4] == (1, 1):
|
220 |
+
# conv2d 1x1
|
221 |
+
weight = (
|
222 |
+
weight
|
223 |
+
+ self.multiplier
|
224 |
+
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
225 |
+
* self.scale
|
226 |
+
)
|
227 |
+
else:
|
228 |
+
# conv2d 3x3
|
229 |
+
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
230 |
+
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
|
231 |
+
weight = weight + self.multiplier * conved * self.scale
|
232 |
+
|
233 |
+
# set weight to org_module
|
234 |
+
org_sd["weight"] = weight.to(org_device, dtype=dtype) # back to CPU without non_blocking
|
235 |
+
self.org_module.load_state_dict(org_sd)
|
236 |
+
else:
|
237 |
+
# split_dims
|
238 |
+
total_dims = sum(self.split_dims)
|
239 |
+
for i in range(len(self.split_dims)):
|
240 |
+
# get up/down weight
|
241 |
+
down_weight = sd[f"lora_down.{i}.weight"].to(device, torch.float, non_blocking=non_blocking) # (rank, in_dim)
|
242 |
+
up_weight = sd[f"lora_up.{i}.weight"].to(device, torch.float, non_blocking=non_blocking) # (split dim, rank)
|
243 |
+
|
244 |
+
# pad up_weight -> (total_dims, rank)
|
245 |
+
padded_up_weight = torch.zeros((total_dims, up_weight.size(0)), device=device, dtype=torch.float)
|
246 |
+
padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight
|
247 |
+
|
248 |
+
# merge weight
|
249 |
+
weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
|
250 |
+
|
251 |
+
# set weight to org_module
|
252 |
+
org_sd["weight"] = weight.to(org_device, dtype) # back to CPU without non_blocking
|
253 |
+
self.org_module.load_state_dict(org_sd)
|
254 |
+
|
255 |
+
# return weight for merge
|
256 |
+
def get_weight(self, multiplier=None):
|
257 |
+
if multiplier is None:
|
258 |
+
multiplier = self.multiplier
|
259 |
+
|
260 |
+
# get up/down weight from module
|
261 |
+
up_weight = self.lora_up.weight.to(torch.float)
|
262 |
+
down_weight = self.lora_down.weight.to(torch.float)
|
263 |
+
|
264 |
+
# pre-calculated weight
|
265 |
+
if len(down_weight.size()) == 2:
|
266 |
+
# linear
|
267 |
+
weight = self.multiplier * (up_weight @ down_weight) * self.scale
|
268 |
+
elif down_weight.size()[2:4] == (1, 1):
|
269 |
+
# conv2d 1x1
|
270 |
+
weight = (
|
271 |
+
self.multiplier
|
272 |
+
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
273 |
+
* self.scale
|
274 |
+
)
|
275 |
+
else:
|
276 |
+
# conv2d 3x3
|
277 |
+
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
278 |
+
weight = self.multiplier * conved * self.scale
|
279 |
+
|
280 |
+
return weight
|
281 |
+
|
282 |
+
def default_forward(self, x):
|
283 |
+
# logger.info(f"default_forward {self.lora_name} {x.size()}")
|
284 |
+
if self.split_dims is None:
|
285 |
+
lx = self.lora_down(x)
|
286 |
+
lx = self.lora_up(lx)
|
287 |
+
return self.org_forward(x) + lx * self.multiplier * self.scale
|
288 |
+
else:
|
289 |
+
lxs = [lora_down(x) for lora_down in self.lora_down]
|
290 |
+
lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)]
|
291 |
+
return self.org_forward(x) + torch.cat(lxs, dim=-1) * self.multiplier * self.scale
|
292 |
+
|
293 |
+
def forward(self, x):
|
294 |
+
if not self.enabled:
|
295 |
+
return self.org_forward(x)
|
296 |
+
return self.default_forward(x)
|
297 |
+
|
298 |
+
|
299 |
+
def create_arch_network(
|
300 |
+
multiplier: float,
|
301 |
+
network_dim: Optional[int],
|
302 |
+
network_alpha: Optional[float],
|
303 |
+
vae: nn.Module,
|
304 |
+
text_encoders: List[nn.Module],
|
305 |
+
unet: nn.Module,
|
306 |
+
neuron_dropout: Optional[float] = None,
|
307 |
+
**kwargs,
|
308 |
+
):
|
309 |
+
# add default exclude patterns
|
310 |
+
exclude_patterns = kwargs.get("exclude_patterns", None)
|
311 |
+
if exclude_patterns is None:
|
312 |
+
exclude_patterns = []
|
313 |
+
else:
|
314 |
+
exclude_patterns = ast.literal_eval(exclude_patterns)
|
315 |
+
|
316 |
+
# exclude if 'img_mod', 'txt_mod' or 'modulation' in the name
|
317 |
+
exclude_patterns.append(r".*(img_mod|txt_mod|modulation).*")
|
318 |
+
|
319 |
+
kwargs["exclude_patterns"] = exclude_patterns
|
320 |
+
|
321 |
+
return create_network(
|
322 |
+
HUNYUAN_TARGET_REPLACE_MODULES,
|
323 |
+
"lora_unet",
|
324 |
+
multiplier,
|
325 |
+
network_dim,
|
326 |
+
network_alpha,
|
327 |
+
vae,
|
328 |
+
text_encoders,
|
329 |
+
unet,
|
330 |
+
neuron_dropout=neuron_dropout,
|
331 |
+
**kwargs,
|
332 |
+
)
|
333 |
+
|
334 |
+
|
335 |
+
def create_network(
|
336 |
+
target_replace_modules: List[str],
|
337 |
+
prefix: str,
|
338 |
+
multiplier: float,
|
339 |
+
network_dim: Optional[int],
|
340 |
+
network_alpha: Optional[float],
|
341 |
+
vae: nn.Module,
|
342 |
+
text_encoders: List[nn.Module],
|
343 |
+
unet: nn.Module,
|
344 |
+
neuron_dropout: Optional[float] = None,
|
345 |
+
**kwargs,
|
346 |
+
):
|
347 |
+
""" architecture independent network creation """
|
348 |
+
if network_dim is None:
|
349 |
+
network_dim = 4 # default
|
350 |
+
if network_alpha is None:
|
351 |
+
network_alpha = 1.0
|
352 |
+
|
353 |
+
# extract dim/alpha for conv2d, and block dim
|
354 |
+
conv_dim = kwargs.get("conv_dim", None)
|
355 |
+
conv_alpha = kwargs.get("conv_alpha", None)
|
356 |
+
if conv_dim is not None:
|
357 |
+
conv_dim = int(conv_dim)
|
358 |
+
if conv_alpha is None:
|
359 |
+
conv_alpha = 1.0
|
360 |
+
else:
|
361 |
+
conv_alpha = float(conv_alpha)
|
362 |
+
|
363 |
+
# TODO generic rank/dim setting with regular expression
|
364 |
+
|
365 |
+
# rank/module dropout
|
366 |
+
rank_dropout = kwargs.get("rank_dropout", None)
|
367 |
+
if rank_dropout is not None:
|
368 |
+
rank_dropout = float(rank_dropout)
|
369 |
+
module_dropout = kwargs.get("module_dropout", None)
|
370 |
+
if module_dropout is not None:
|
371 |
+
module_dropout = float(module_dropout)
|
372 |
+
|
373 |
+
# verbose
|
374 |
+
verbose = kwargs.get("verbose", False)
|
375 |
+
if verbose is not None:
|
376 |
+
verbose = True if verbose == "True" else False
|
377 |
+
|
378 |
+
# regular expression for module selection: exclude and include
|
379 |
+
exclude_patterns = kwargs.get("exclude_patterns", None)
|
380 |
+
if exclude_patterns is not None and isinstance(exclude_patterns, str):
|
381 |
+
exclude_patterns = ast.literal_eval(exclude_patterns)
|
382 |
+
include_patterns = kwargs.get("include_patterns", None)
|
383 |
+
if include_patterns is not None and isinstance(include_patterns, str):
|
384 |
+
include_patterns = ast.literal_eval(include_patterns)
|
385 |
+
|
386 |
+
# too many arguments ( ^ω^)・・・
|
387 |
+
network = LoRANetwork(
|
388 |
+
target_replace_modules,
|
389 |
+
prefix,
|
390 |
+
text_encoders,
|
391 |
+
unet,
|
392 |
+
multiplier=multiplier,
|
393 |
+
lora_dim=network_dim,
|
394 |
+
alpha=network_alpha,
|
395 |
+
dropout=neuron_dropout,
|
396 |
+
rank_dropout=rank_dropout,
|
397 |
+
module_dropout=module_dropout,
|
398 |
+
conv_lora_dim=conv_dim,
|
399 |
+
conv_alpha=conv_alpha,
|
400 |
+
exclude_patterns=exclude_patterns,
|
401 |
+
include_patterns=include_patterns,
|
402 |
+
verbose=verbose,
|
403 |
+
)
|
404 |
+
|
405 |
+
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
|
406 |
+
# loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
|
407 |
+
# loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
|
408 |
+
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
|
409 |
+
# loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
|
410 |
+
# loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
|
411 |
+
if loraplus_lr_ratio is not None: # or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
|
412 |
+
network.set_loraplus_lr_ratio(loraplus_lr_ratio) # , loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
|
413 |
+
|
414 |
+
return network
|
415 |
+
|
416 |
+
|
417 |
+
class LoRANetwork(torch.nn.Module):
|
418 |
+
# only supports U-Net (DiT), Text Encoders are not supported
|
419 |
+
|
420 |
+
def __init__(
|
421 |
+
self,
|
422 |
+
target_replace_modules: List[str],
|
423 |
+
prefix: str,
|
424 |
+
text_encoders: Union[List[CLIPTextModel], CLIPTextModel],
|
425 |
+
unet: nn.Module,
|
426 |
+
multiplier: float = 1.0,
|
427 |
+
lora_dim: int = 4,
|
428 |
+
alpha: float = 1,
|
429 |
+
dropout: Optional[float] = None,
|
430 |
+
rank_dropout: Optional[float] = None,
|
431 |
+
module_dropout: Optional[float] = None,
|
432 |
+
conv_lora_dim: Optional[int] = None,
|
433 |
+
conv_alpha: Optional[float] = None,
|
434 |
+
module_class: Type[object] = LoRAModule,
|
435 |
+
modules_dim: Optional[Dict[str, int]] = None,
|
436 |
+
modules_alpha: Optional[Dict[str, int]] = None,
|
437 |
+
exclude_patterns: Optional[List[str]] = None,
|
438 |
+
include_patterns: Optional[List[str]] = None,
|
439 |
+
verbose: Optional[bool] = False,
|
440 |
+
) -> None:
|
441 |
+
super().__init__()
|
442 |
+
self.multiplier = multiplier
|
443 |
+
|
444 |
+
self.lora_dim = lora_dim
|
445 |
+
self.alpha = alpha
|
446 |
+
self.conv_lora_dim = conv_lora_dim
|
447 |
+
self.conv_alpha = conv_alpha
|
448 |
+
self.dropout = dropout
|
449 |
+
self.rank_dropout = rank_dropout
|
450 |
+
self.module_dropout = module_dropout
|
451 |
+
self.target_replace_modules = target_replace_modules
|
452 |
+
self.prefix = prefix
|
453 |
+
|
454 |
+
self.loraplus_lr_ratio = None
|
455 |
+
# self.loraplus_unet_lr_ratio = None
|
456 |
+
# self.loraplus_text_encoder_lr_ratio = None
|
457 |
+
|
458 |
+
if modules_dim is not None:
|
459 |
+
logger.info(f"create LoRA network from weights")
|
460 |
+
else:
|
461 |
+
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
462 |
+
logger.info(
|
463 |
+
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
|
464 |
+
)
|
465 |
+
# if self.conv_lora_dim is not None:
|
466 |
+
# logger.info(
|
467 |
+
# f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
|
468 |
+
# )
|
469 |
+
# if train_t5xxl:
|
470 |
+
# logger.info(f"train T5XXL as well")
|
471 |
+
|
472 |
+
# compile regular expression if specified
|
473 |
+
exclude_re_patterns = []
|
474 |
+
if exclude_patterns is not None:
|
475 |
+
for pattern in exclude_patterns:
|
476 |
+
try:
|
477 |
+
re_pattern = re.compile(pattern)
|
478 |
+
except re.error as e:
|
479 |
+
logger.error(f"Invalid exclude pattern '{pattern}': {e}")
|
480 |
+
continue
|
481 |
+
exclude_re_patterns.append(re_pattern)
|
482 |
+
|
483 |
+
include_re_patterns = []
|
484 |
+
if include_patterns is not None:
|
485 |
+
for pattern in include_patterns:
|
486 |
+
try:
|
487 |
+
re_pattern = re.compile(pattern)
|
488 |
+
except re.error as e:
|
489 |
+
logger.error(f"Invalid include pattern '{pattern}': {e}")
|
490 |
+
continue
|
491 |
+
include_re_patterns.append(re_pattern)
|
492 |
+
|
493 |
+
# create module instances
|
494 |
+
def create_modules(
|
495 |
+
is_unet: bool,
|
496 |
+
pfx: str,
|
497 |
+
root_module: torch.nn.Module,
|
498 |
+
target_replace_mods: Optional[List[str]] = None,
|
499 |
+
filter: Optional[str] = None,
|
500 |
+
default_dim: Optional[int] = None,
|
501 |
+
) -> List[LoRAModule]:
|
502 |
+
loras = []
|
503 |
+
skipped = []
|
504 |
+
for name, module in root_module.named_modules():
|
505 |
+
if target_replace_mods is None or module.__class__.__name__ in target_replace_mods:
|
506 |
+
if target_replace_mods is None: # dirty hack for all modules
|
507 |
+
module = root_module # search all modules
|
508 |
+
|
509 |
+
for child_name, child_module in module.named_modules():
|
510 |
+
is_linear = child_module.__class__.__name__ == "Linear"
|
511 |
+
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
512 |
+
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
513 |
+
|
514 |
+
if is_linear or is_conv2d:
|
515 |
+
original_name = (name + "." if name else "") + child_name
|
516 |
+
lora_name = f"{pfx}.{original_name}".replace(".", "_")
|
517 |
+
|
518 |
+
# exclude/include filter
|
519 |
+
excluded = False
|
520 |
+
for pattern in exclude_re_patterns:
|
521 |
+
if pattern.match(original_name):
|
522 |
+
excluded = True
|
523 |
+
break
|
524 |
+
included = False
|
525 |
+
for pattern in include_re_patterns:
|
526 |
+
if pattern.match(original_name):
|
527 |
+
included = True
|
528 |
+
break
|
529 |
+
if excluded and not included:
|
530 |
+
if verbose:
|
531 |
+
logger.info(f"exclude: {original_name}")
|
532 |
+
continue
|
533 |
+
|
534 |
+
# filter by name (not used in the current implementation)
|
535 |
+
if filter is not None and not filter in lora_name:
|
536 |
+
continue
|
537 |
+
|
538 |
+
dim = None
|
539 |
+
alpha = None
|
540 |
+
|
541 |
+
if modules_dim is not None:
|
542 |
+
# モジュール指定あり
|
543 |
+
if lora_name in modules_dim:
|
544 |
+
dim = modules_dim[lora_name]
|
545 |
+
alpha = modules_alpha[lora_name]
|
546 |
+
else:
|
547 |
+
# 通常、すべて対象とする
|
548 |
+
if is_linear or is_conv2d_1x1:
|
549 |
+
dim = default_dim if default_dim is not None else self.lora_dim
|
550 |
+
alpha = self.alpha
|
551 |
+
elif self.conv_lora_dim is not None:
|
552 |
+
dim = self.conv_lora_dim
|
553 |
+
alpha = self.conv_alpha
|
554 |
+
|
555 |
+
if dim is None or dim == 0:
|
556 |
+
# skipした情報を出力
|
557 |
+
if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None):
|
558 |
+
skipped.append(lora_name)
|
559 |
+
continue
|
560 |
+
|
561 |
+
lora = module_class(
|
562 |
+
lora_name,
|
563 |
+
child_module,
|
564 |
+
self.multiplier,
|
565 |
+
dim,
|
566 |
+
alpha,
|
567 |
+
dropout=dropout,
|
568 |
+
rank_dropout=rank_dropout,
|
569 |
+
module_dropout=module_dropout,
|
570 |
+
)
|
571 |
+
loras.append(lora)
|
572 |
+
|
573 |
+
if target_replace_mods is None:
|
574 |
+
break # all modules are searched
|
575 |
+
return loras, skipped
|
576 |
+
|
577 |
+
# # create LoRA for text encoder
|
578 |
+
# # it is redundant to create LoRA modules even if they are not used
|
579 |
+
|
580 |
+
self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = []
|
581 |
+
# skipped_te = []
|
582 |
+
# for i, text_encoder in enumerate(text_encoders):
|
583 |
+
# index = i
|
584 |
+
# if not train_t5xxl and index > 0: # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False
|
585 |
+
# break
|
586 |
+
# logger.info(f"create LoRA for Text Encoder {index+1}:")
|
587 |
+
# text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
588 |
+
# logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.")
|
589 |
+
# self.text_encoder_loras.extend(text_encoder_loras)
|
590 |
+
# skipped_te += skipped
|
591 |
+
|
592 |
+
# create LoRA for U-Net
|
593 |
+
self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
|
594 |
+
self.unet_loras, skipped_un = create_modules(True, prefix, unet, target_replace_modules)
|
595 |
+
|
596 |
+
logger.info(f"create LoRA for U-Net/DiT: {len(self.unet_loras)} modules.")
|
597 |
+
if verbose:
|
598 |
+
for lora in self.unet_loras:
|
599 |
+
logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}")
|
600 |
+
|
601 |
+
skipped = skipped_un
|
602 |
+
if verbose and len(skipped) > 0:
|
603 |
+
logger.warning(
|
604 |
+
f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
|
605 |
+
)
|
606 |
+
for name in skipped:
|
607 |
+
logger.info(f"\t{name}")
|
608 |
+
|
609 |
+
# assertion
|
610 |
+
names = set()
|
611 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
612 |
+
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
613 |
+
names.add(lora.lora_name)
|
614 |
+
|
615 |
+
def prepare_network(self, args):
|
616 |
+
"""
|
617 |
+
called after the network is created
|
618 |
+
"""
|
619 |
+
pass
|
620 |
+
|
621 |
+
def set_multiplier(self, multiplier):
|
622 |
+
self.multiplier = multiplier
|
623 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
624 |
+
lora.multiplier = self.multiplier
|
625 |
+
|
626 |
+
def set_enabled(self, is_enabled):
|
627 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
628 |
+
lora.enabled = is_enabled
|
629 |
+
|
630 |
+
def load_weights(self, file):
|
631 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
632 |
+
from safetensors.torch import load_file
|
633 |
+
|
634 |
+
weights_sd = load_file(file)
|
635 |
+
else:
|
636 |
+
weights_sd = torch.load(file, map_location="cpu")
|
637 |
+
|
638 |
+
info = self.load_state_dict(weights_sd, False)
|
639 |
+
return info
|
640 |
+
|
641 |
+
def apply_to(
|
642 |
+
self,
|
643 |
+
text_encoders: Optional[nn.Module],
|
644 |
+
unet: Optional[nn.Module],
|
645 |
+
apply_text_encoder: bool = True,
|
646 |
+
apply_unet: bool = True,
|
647 |
+
):
|
648 |
+
if apply_text_encoder:
|
649 |
+
logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules")
|
650 |
+
else:
|
651 |
+
self.text_encoder_loras = []
|
652 |
+
|
653 |
+
if apply_unet:
|
654 |
+
logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules")
|
655 |
+
else:
|
656 |
+
self.unet_loras = []
|
657 |
+
|
658 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
659 |
+
lora.apply_to()
|
660 |
+
self.add_module(lora.lora_name, lora)
|
661 |
+
|
662 |
+
# マージできるかどうかを返す
|
663 |
+
def is_mergeable(self):
|
664 |
+
return True
|
665 |
+
|
666 |
+
# TODO refactor to common function with apply_to
|
667 |
+
def merge_to(self, text_encoders, unet, weights_sd, dtype=None, device=None, non_blocking=False):
|
668 |
+
from concurrent.futures import ThreadPoolExecutor
|
669 |
+
|
670 |
+
with ThreadPoolExecutor(max_workers=2) as executor: # 2 workers is enough
|
671 |
+
futures = []
|
672 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
673 |
+
sd_for_lora = {}
|
674 |
+
for key in weights_sd.keys():
|
675 |
+
if key.startswith(lora.lora_name):
|
676 |
+
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
|
677 |
+
if len(sd_for_lora) == 0:
|
678 |
+
logger.info(f"no weight for {lora.lora_name}")
|
679 |
+
continue
|
680 |
+
|
681 |
+
# lora.merge_to(sd_for_lora, dtype, device)
|
682 |
+
futures.append(executor.submit(lora.merge_to, sd_for_lora, dtype, device, non_blocking))
|
683 |
+
|
684 |
+
for future in futures:
|
685 |
+
future.result()
|
686 |
+
|
687 |
+
logger.info(f"weights are merged")
|
688 |
+
|
689 |
+
def set_loraplus_lr_ratio(self, loraplus_lr_ratio): # , loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
|
690 |
+
self.loraplus_lr_ratio = loraplus_lr_ratio
|
691 |
+
|
692 |
+
logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_lr_ratio}")
|
693 |
+
# logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
|
694 |
+
|
695 |
+
def prepare_optimizer_params(self, unet_lr: float = 1e-4, **kwargs):
|
696 |
+
self.requires_grad_(True)
|
697 |
+
|
698 |
+
all_params = []
|
699 |
+
lr_descriptions = []
|
700 |
+
|
701 |
+
def assemble_params(loras, lr, loraplus_ratio):
|
702 |
+
param_groups = {"lora": {}, "plus": {}}
|
703 |
+
for lora in loras:
|
704 |
+
for name, param in lora.named_parameters():
|
705 |
+
if loraplus_ratio is not None and "lora_up" in name:
|
706 |
+
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
|
707 |
+
else:
|
708 |
+
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
|
709 |
+
|
710 |
+
params = []
|
711 |
+
descriptions = []
|
712 |
+
for key in param_groups.keys():
|
713 |
+
param_data = {"params": param_groups[key].values()}
|
714 |
+
|
715 |
+
if len(param_data["params"]) == 0:
|
716 |
+
continue
|
717 |
+
|
718 |
+
if lr is not None:
|
719 |
+
if key == "plus":
|
720 |
+
param_data["lr"] = lr * loraplus_ratio
|
721 |
+
else:
|
722 |
+
param_data["lr"] = lr
|
723 |
+
|
724 |
+
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
|
725 |
+
logger.info("NO LR skipping!")
|
726 |
+
continue
|
727 |
+
|
728 |
+
params.append(param_data)
|
729 |
+
descriptions.append("plus" if key == "plus" else "")
|
730 |
+
|
731 |
+
return params, descriptions
|
732 |
+
|
733 |
+
if self.unet_loras:
|
734 |
+
params, descriptions = assemble_params(self.unet_loras, unet_lr, self.loraplus_lr_ratio)
|
735 |
+
all_params.extend(params)
|
736 |
+
lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])
|
737 |
+
|
738 |
+
return all_params, lr_descriptions
|
739 |
+
|
740 |
+
def enable_gradient_checkpointing(self):
|
741 |
+
# not supported
|
742 |
+
pass
|
743 |
+
|
744 |
+
def prepare_grad_etc(self, unet):
|
745 |
+
self.requires_grad_(True)
|
746 |
+
|
747 |
+
def on_epoch_start(self, unet):
|
748 |
+
self.train()
|
749 |
+
|
750 |
+
def on_step_start(self):
|
751 |
+
pass
|
752 |
+
|
753 |
+
def get_trainable_params(self):
|
754 |
+
return self.parameters()
|
755 |
+
|
756 |
+
def save_weights(self, file, dtype, metadata):
|
757 |
+
if metadata is not None and len(metadata) == 0:
|
758 |
+
metadata = None
|
759 |
+
|
760 |
+
state_dict = self.state_dict()
|
761 |
+
|
762 |
+
if dtype is not None:
|
763 |
+
for key in list(state_dict.keys()):
|
764 |
+
v = state_dict[key]
|
765 |
+
v = v.detach().clone().to("cpu").to(dtype)
|
766 |
+
state_dict[key] = v
|
767 |
+
|
768 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
769 |
+
from safetensors.torch import save_file
|
770 |
+
from utils import model_utils
|
771 |
+
|
772 |
+
# Precalculate model hashes to save time on indexing
|
773 |
+
if metadata is None:
|
774 |
+
metadata = {}
|
775 |
+
model_hash, legacy_hash = model_utils.precalculate_safetensors_hashes(state_dict, metadata)
|
776 |
+
metadata["sshs_model_hash"] = model_hash
|
777 |
+
metadata["sshs_legacy_hash"] = legacy_hash
|
778 |
+
|
779 |
+
save_file(state_dict, file, metadata)
|
780 |
+
else:
|
781 |
+
torch.save(state_dict, file)
|
782 |
+
|
783 |
+
def backup_weights(self):
|
784 |
+
# 重みのバックアップを行う
|
785 |
+
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
|
786 |
+
for lora in loras:
|
787 |
+
org_module = lora.org_module_ref[0]
|
788 |
+
if not hasattr(org_module, "_lora_org_weight"):
|
789 |
+
sd = org_module.state_dict()
|
790 |
+
org_module._lora_org_weight = sd["weight"].detach().clone()
|
791 |
+
org_module._lora_restored = True
|
792 |
+
|
793 |
+
def restore_weights(self):
|
794 |
+
# 重みのリストアを行う
|
795 |
+
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
|
796 |
+
for lora in loras:
|
797 |
+
org_module = lora.org_module_ref[0]
|
798 |
+
if not org_module._lora_restored:
|
799 |
+
sd = org_module.state_dict()
|
800 |
+
sd["weight"] = org_module._lora_org_weight
|
801 |
+
org_module.load_state_dict(sd)
|
802 |
+
org_module._lora_restored = True
|
803 |
+
|
804 |
+
def pre_calculation(self):
|
805 |
+
# 事前計算を行う
|
806 |
+
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
|
807 |
+
for lora in loras:
|
808 |
+
org_module = lora.org_module_ref[0]
|
809 |
+
sd = org_module.state_dict()
|
810 |
+
|
811 |
+
org_weight = sd["weight"]
|
812 |
+
lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
|
813 |
+
sd["weight"] = org_weight + lora_weight
|
814 |
+
assert sd["weight"].shape == org_weight.shape
|
815 |
+
org_module.load_state_dict(sd)
|
816 |
+
|
817 |
+
org_module._lora_restored = False
|
818 |
+
lora.enabled = False
|
819 |
+
|
820 |
+
def apply_max_norm_regularization(self, max_norm_value, device):
|
821 |
+
downkeys = []
|
822 |
+
upkeys = []
|
823 |
+
alphakeys = []
|
824 |
+
norms = []
|
825 |
+
keys_scaled = 0
|
826 |
+
|
827 |
+
state_dict = self.state_dict()
|
828 |
+
for key in state_dict.keys():
|
829 |
+
if "lora_down" in key and "weight" in key:
|
830 |
+
downkeys.append(key)
|
831 |
+
upkeys.append(key.replace("lora_down", "lora_up"))
|
832 |
+
alphakeys.append(key.replace("lora_down.weight", "alpha"))
|
833 |
+
|
834 |
+
for i in range(len(downkeys)):
|
835 |
+
down = state_dict[downkeys[i]].to(device)
|
836 |
+
up = state_dict[upkeys[i]].to(device)
|
837 |
+
alpha = state_dict[alphakeys[i]].to(device)
|
838 |
+
dim = down.shape[0]
|
839 |
+
scale = alpha / dim
|
840 |
+
|
841 |
+
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
|
842 |
+
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
843 |
+
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
|
844 |
+
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
|
845 |
+
else:
|
846 |
+
updown = up @ down
|
847 |
+
|
848 |
+
updown *= scale
|
849 |
+
|
850 |
+
norm = updown.norm().clamp(min=max_norm_value / 2)
|
851 |
+
desired = torch.clamp(norm, max=max_norm_value)
|
852 |
+
ratio = desired.cpu() / norm.cpu()
|
853 |
+
sqrt_ratio = ratio**0.5
|
854 |
+
if ratio != 1:
|
855 |
+
keys_scaled += 1
|
856 |
+
state_dict[upkeys[i]] *= sqrt_ratio
|
857 |
+
state_dict[downkeys[i]] *= sqrt_ratio
|
858 |
+
scalednorm = updown.norm() * ratio
|
859 |
+
norms.append(scalednorm.item())
|
860 |
+
|
861 |
+
return keys_scaled, sum(norms) / len(norms), max(norms)
|
862 |
+
|
863 |
+
|
864 |
+
def create_arch_network_from_weights(
|
865 |
+
multiplier: float,
|
866 |
+
weights_sd: Dict[str, torch.Tensor],
|
867 |
+
text_encoders: Optional[List[nn.Module]] = None,
|
868 |
+
unet: Optional[nn.Module] = None,
|
869 |
+
for_inference: bool = False,
|
870 |
+
**kwargs,
|
871 |
+
) -> LoRANetwork:
|
872 |
+
return create_network_from_weights(
|
873 |
+
HUNYUAN_TARGET_REPLACE_MODULES, multiplier, weights_sd, text_encoders, unet, for_inference, **kwargs
|
874 |
+
)
|
875 |
+
|
876 |
+
|
877 |
+
# Create network from weights for inference, weights are not loaded here (because can be merged)
|
878 |
+
def create_network_from_weights(
|
879 |
+
target_replace_modules: List[str],
|
880 |
+
multiplier: float,
|
881 |
+
weights_sd: Dict[str, torch.Tensor],
|
882 |
+
text_encoders: Optional[List[nn.Module]] = None,
|
883 |
+
unet: Optional[nn.Module] = None,
|
884 |
+
for_inference: bool = False,
|
885 |
+
**kwargs,
|
886 |
+
) -> LoRANetwork:
|
887 |
+
# get dim/alpha mapping
|
888 |
+
modules_dim = {}
|
889 |
+
modules_alpha = {}
|
890 |
+
for key, value in weights_sd.items():
|
891 |
+
if "." not in key:
|
892 |
+
continue
|
893 |
+
|
894 |
+
lora_name = key.split(".")[0]
|
895 |
+
if "alpha" in key:
|
896 |
+
modules_alpha[lora_name] = value
|
897 |
+
elif "lora_down" in key:
|
898 |
+
dim = value.shape[0]
|
899 |
+
modules_dim[lora_name] = dim
|
900 |
+
# logger.info(lora_name, value.size(), dim)
|
901 |
+
|
902 |
+
module_class = LoRAInfModule if for_inference else LoRAModule
|
903 |
+
|
904 |
+
network = LoRANetwork(
|
905 |
+
target_replace_modules,
|
906 |
+
"lora_unet",
|
907 |
+
text_encoders,
|
908 |
+
unet,
|
909 |
+
multiplier=multiplier,
|
910 |
+
modules_dim=modules_dim,
|
911 |
+
modules_alpha=modules_alpha,
|
912 |
+
module_class=module_class,
|
913 |
+
)
|
914 |
+
return network
|
networks/lora_wan.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# LoRA module for Wan2.1
|
2 |
+
|
3 |
+
import ast
|
4 |
+
from typing import Dict, List, Optional
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
import logging
|
9 |
+
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
logging.basicConfig(level=logging.INFO)
|
12 |
+
|
13 |
+
import networks.lora as lora
|
14 |
+
|
15 |
+
|
16 |
+
WAN_TARGET_REPLACE_MODULES = ["WanAttentionBlock"]
|
17 |
+
|
18 |
+
|
19 |
+
def create_arch_network(
|
20 |
+
multiplier: float,
|
21 |
+
network_dim: Optional[int],
|
22 |
+
network_alpha: Optional[float],
|
23 |
+
vae: nn.Module,
|
24 |
+
text_encoders: List[nn.Module],
|
25 |
+
unet: nn.Module,
|
26 |
+
neuron_dropout: Optional[float] = None,
|
27 |
+
**kwargs,
|
28 |
+
):
|
29 |
+
# add default exclude patterns
|
30 |
+
exclude_patterns = kwargs.get("exclude_patterns", None)
|
31 |
+
if exclude_patterns is None:
|
32 |
+
exclude_patterns = []
|
33 |
+
else:
|
34 |
+
exclude_patterns = ast.literal_eval(exclude_patterns)
|
35 |
+
|
36 |
+
# exclude if 'img_mod', 'txt_mod' or 'modulation' in the name
|
37 |
+
exclude_patterns.append(r".*(patch_embedding|text_embedding|time_embedding|time_projection|norm|head).*")
|
38 |
+
|
39 |
+
kwargs["exclude_patterns"] = exclude_patterns
|
40 |
+
|
41 |
+
return lora.create_network(
|
42 |
+
WAN_TARGET_REPLACE_MODULES,
|
43 |
+
"lora_unet",
|
44 |
+
multiplier,
|
45 |
+
network_dim,
|
46 |
+
network_alpha,
|
47 |
+
vae,
|
48 |
+
text_encoders,
|
49 |
+
unet,
|
50 |
+
neuron_dropout=neuron_dropout,
|
51 |
+
**kwargs,
|
52 |
+
)
|
53 |
+
|
54 |
+
|
55 |
+
def create_arch_network_from_weights(
|
56 |
+
multiplier: float,
|
57 |
+
weights_sd: Dict[str, torch.Tensor],
|
58 |
+
text_encoders: Optional[List[nn.Module]] = None,
|
59 |
+
unet: Optional[nn.Module] = None,
|
60 |
+
for_inference: bool = False,
|
61 |
+
**kwargs,
|
62 |
+
) -> lora.LoRANetwork:
|
63 |
+
return lora.create_network_from_weights(
|
64 |
+
WAN_TARGET_REPLACE_MODULES, multiplier, weights_sd, text_encoders, unet, for_inference, **kwargs
|
65 |
+
)
|
pixel_outputs/pixel_w1_3_lora-000001.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0861037ca1ce8517c187e11e392195df38423d670a73f1cf8622d9213fb2ad5c
|
3 |
+
size 87594616
|
pixel_outputs/pixel_w1_3_lora-000002.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2da2daa5dea835024a559f06eea97b38a3b809b6724915cd6715fa702c74b059
|
3 |
+
size 87594616
|
pixel_outputs/pixel_w1_3_lora-000003.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:47f65b1ed7b77d27cf59274d3bf6f6285ada28ee0b453ac98efdcb86a429af7a
|
3 |
+
size 87594616
|
pixel_outputs/pixel_w1_3_lora-000004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7c09f83bb67f7dd5566ae6ac5fd11f32877f964f5f416c786a5e85382c96df73
|
3 |
+
size 87594616
|
pixel_outputs/pixel_w1_3_lora-000005.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6edb55a1a4115fefddfc63957c0c140b21e5aef5a4b0b19924305b59af796824
|
3 |
+
size 87594616
|
pixel_outputs/pixel_w1_3_lora-000006.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:325d67e06add9069b3aa7f5761a778f233566f79a0aa3f51f1ea4a535f4d7211
|
3 |
+
size 87594616
|
pixel_outputs/pixel_w1_3_lora-000007.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1f01e0a354111250caaf0bc609db677552a8e7fc363472ca680ece252ffbb086
|
3 |
+
size 87594616
|
pixel_outputs/pixel_w1_3_lora-000008.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:75cb5269d989e6467384147285186be202ef5a77192430cb95a74425ad574d3b
|
3 |
+
size 87594616
|
pixel_outputs/pixel_w1_3_lora-000009.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:eea26f66e5bde80ad506b6012883e89f5249acf066160a8529b2c2bbfd9271c3
|
3 |
+
size 87594616
|
pixel_outputs/pixel_w1_3_lora-000010.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f8386003f664721e14168cf6c870052474a26fdccfd73176244770075305e0ee
|
3 |
+
size 87594624
|