svjack committed
Commit f7cfd9d · verified · 1 Parent: c872c1e

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. .gitignore +5 -0
  2. .ipynb_checkpoints/README-checkpoint.md +173 -0
  3. .python-version +1 -0
  4. README.md +173 -0
  5. cache_latents.py +281 -0
  6. cache_text_encoder_outputs.py +214 -0
  7. convert_lora.py +135 -0
  8. dataset/__init__.py +0 -0
  9. dataset/config_utils.py +372 -0
  10. dataset/dataset_config.md +387 -0
  11. dataset/image_video_dataset.py +1400 -0
  12. docs/advanced_config.md +151 -0
  13. docs/sampling_during_training.md +108 -0
  14. docs/wan.md +241 -0
  15. hunyuan_model/__init__.py +0 -0
  16. hunyuan_model/activation_layers.py +23 -0
  17. hunyuan_model/attention.py +295 -0
  18. hunyuan_model/autoencoder_kl_causal_3d.py +609 -0
  19. hunyuan_model/embed_layers.py +132 -0
  20. hunyuan_model/helpers.py +40 -0
  21. hunyuan_model/mlp_layers.py +118 -0
  22. hunyuan_model/models.py +1044 -0
  23. hunyuan_model/modulate_layers.py +76 -0
  24. hunyuan_model/norm_layers.py +79 -0
  25. hunyuan_model/pipeline_hunyuan_video.py +1100 -0
  26. hunyuan_model/posemb_layers.py +310 -0
  27. hunyuan_model/text_encoder.py +710 -0
  28. hunyuan_model/token_refiner.py +245 -0
  29. hunyuan_model/vae.py +446 -0
  30. hv_generate_video.py +911 -0
  31. hv_train.py +1721 -0
  32. hv_train_network.py +0 -0
  33. merge_lora.py +63 -0
  34. modules/__init__.py +0 -0
  35. modules/custom_offloading_utils.py +266 -0
  36. modules/scheduling_flow_match_discrete.py +257 -0
  37. modules/unet_causal_3d_blocks.py +818 -0
  38. networks/__init__.py +0 -0
  39. networks/lora.py +914 -0
  40. networks/lora_wan.py +65 -0
  41. pixel_outputs/pixel_w1_3_lora-000001.safetensors +3 -0
  42. pixel_outputs/pixel_w1_3_lora-000002.safetensors +3 -0
  43. pixel_outputs/pixel_w1_3_lora-000003.safetensors +3 -0
  44. pixel_outputs/pixel_w1_3_lora-000004.safetensors +3 -0
  45. pixel_outputs/pixel_w1_3_lora-000005.safetensors +3 -0
  46. pixel_outputs/pixel_w1_3_lora-000006.safetensors +3 -0
  47. pixel_outputs/pixel_w1_3_lora-000007.safetensors +3 -0
  48. pixel_outputs/pixel_w1_3_lora-000008.safetensors +3 -0
  49. pixel_outputs/pixel_w1_3_lora-000009.safetensors +3 -0
  50. pixel_outputs/pixel_w1_3_lora-000010.safetensors +3 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
+ __pycache__/
+ .venv
+ venv/
+ logs/
+ uv.lock
.ipynb_checkpoints/README-checkpoint.md ADDED
.python-version ADDED
@@ -0,0 +1 @@
+ 3.10
README.md ADDED
@@ -0,0 +1,173 @@
+ # Pixel Text-to-Video Generation
+
+ This repository contains the steps and scripts needed to generate videos with the Pixel text-to-video LoRA. The LoRA (Low-Rank Adaptation) weights are applied on top of the pre-trained Wan2.1 text-to-video model to produce high-quality pixel-art style videos from textual prompts.
+
+ ## Prerequisites
+
+ Before proceeding, ensure that you have the following installed on your system:
+
+ • **Ubuntu** (or a compatible Linux distribution)
+ • **Python 3.x**
+ • **pip** (Python package manager)
+ • **Git**
+ • **Git LFS** (Git Large File Storage)
+ • **FFmpeg**
+
+ ## Installation
+
+ 1. **Update and Install Dependencies**
+
+ ```bash
+ sudo apt-get update && sudo apt-get install cbm git-lfs ffmpeg
+ ```
+
+ 2. **Clone the Repository**
+
+ ```bash
+ git clone https://huggingface.co/svjack/Pixel_wan_2_1_1_3_B_text2video_lora
+ cd Pixel_wan_2_1_1_3_B_text2video_lora
+ ```
+
+ 3. **Install Python Dependencies**
+
+ ```bash
+ pip install torch torchvision
+ pip install -r requirements.txt
+ pip install ascii-magic matplotlib tensorboard huggingface_hub datasets
+ pip install moviepy==1.0.3
+ pip install sageattention==1.0.6
+ ```
+
+ 4. **Download Model Weights**
+
+ The examples below use the 1.3B DiT checkpoint; the 14B checkpoint is optional.
+
+ ```bash
+ wget https://huggingface.co/Wan-AI/Wan2.1-T2V-14B/resolve/main/models_t5_umt5-xxl-enc-bf16.pth
+ wget https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
+ wget https://huggingface.co/Wan-AI/Wan2.1-T2V-14B/resolve/main/Wan2.1_VAE.pth
+ wget https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors
+ wget https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/diffusion_models/wan2.1_t2v_14B_bf16.safetensors
+ ```
+
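If you prefer to stay in Python, the same checkpoints can be fetched with `huggingface_hub` (installed in step 3). This is only a sketch: the repo IDs and filenames mirror the URLs in step 4, and the returned cache paths can be passed to `--t5`, `--vae`, and `--dit` in place of the local filenames.

```python
# Optional sketch: download the checkpoints from step 4 via huggingface_hub instead of wget.
from huggingface_hub import hf_hub_download

files = [
    ("Wan-AI/Wan2.1-T2V-14B", "models_t5_umt5-xxl-enc-bf16.pth"),
    ("DeepBeepMeep/Wan2.1", "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
    ("Wan-AI/Wan2.1-T2V-14B", "Wan2.1_VAE.pth"),
    # add the 14B DiT here as well if you need it
    ("Comfy-Org/Wan_2.1_ComfyUI_repackaged", "split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors"),
]

for repo_id, filename in files:
    local_path = hf_hub_download(repo_id=repo_id, filename=filename)
    print(f"{filename} -> {local_path}")  # pass these paths to --t5 / --vae / --dit
```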
+ ## Usage
+
+ To generate a video, use the `wan_generate_video.py` script with the appropriate parameters. Below are examples of how to generate videos with the Pixel LoRA.
+
+ #### Woods
+
+ ```bash
+ python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 768 1024 --video_length 81 --infer_steps 20 \
+ --save_path save --output_type both \
+ --dit wan2.1_t2v_1.3B_bf16.safetensors --vae Wan2.1_VAE.pth \
+ --t5 models_t5_umt5-xxl-enc-bf16.pth \
+ --attn_mode torch \
+ --lora_weight pixel_outputs/pixel_w1_3_lora-000010.safetensors \
+ --lora_multiplier 1.0 \
+ --prompt "The video showcases a pixel art scene from a video game. Golden light filters through the canopy, illuminating soft moss and fallen leaves. Wildflowers bloom nearby, and glowing fireflies hover in the air. A gentle stream flows in the background, its murmur blending with birdsong. The scene radiates tranquility and natural charm."
+ ```
+
+ #### Castle
+
+ ```bash
+ python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 768 1024 --video_length 81 --infer_steps 20 \
+ --save_path save --output_type both \
+ --dit wan2.1_t2v_1.3B_bf16.safetensors --vae Wan2.1_VAE.pth \
+ --t5 models_t5_umt5-xxl-enc-bf16.pth \
+ --attn_mode torch \
+ --lora_weight pixel_outputs/pixel_w1_3_lora-000010.safetensors \
+ --lora_multiplier 1.0 \
+ --prompt "The video showcases a pixel art scene from a video game. The video shifts to a majestic castle under a starry sky. Silvery moonlight bathes the ancient stone walls, casting soft shadows on the surrounding landscape. Towering spires rise into the night, their peaks adorned with glowing orbs that mimic the stars above. A tranquil moat reflects the shimmering heavens, its surface rippling gently in the cool breeze. Fireflies dance around the castle’s ivy-covered arches, adding a touch of magic to the scene. In the distance, a faint aurora paints the horizon with hues of green and purple, blending seamlessly with the celestial tapestry. The scene exudes an aura of timeless wonder and serene beauty."
+ ```
+
+ #### City
+
+ ```bash
+ python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 768 1024 --video_length 81 --infer_steps 20 \
+ --save_path save --output_type both \
+ --dit wan2.1_t2v_1.3B_bf16.safetensors --vae Wan2.1_VAE.pth \
+ --t5 models_t5_umt5-xxl-enc-bf16.pth \
+ --attn_mode torch \
+ --lora_weight pixel_outputs/pixel_w1_3_lora-000010.safetensors \
+ --lora_multiplier 1.0 \
+ --prompt "The video showcases a pixel art scene from a video game. The video showcases a vibrant urban landscape. The city skyline is dominated by towering skyscrapers, their glass facades reflecting the sunlight. The streets are bustling with activity, filled with cars, buses, and pedestrians. Parks and green spaces are scattered throughout, offering a refreshing contrast to the concrete jungle. The architecture is a mix of modern and historic buildings, each telling a story of the city's evolution. The overall scene is alive with energy, capturing the essence of urban life."
+ ```
+
+ #### Girl
+
+ ```bash
+ python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 768 1024 --video_length 81 --infer_steps 20 \
+ --save_path save --output_type both \
+ --dit wan2.1_t2v_1.3B_bf16.safetensors --vae Wan2.1_VAE.pth \
+ --t5 models_t5_umt5-xxl-enc-bf16.pth \
+ --attn_mode torch \
+ --lora_weight pixel_outputs/pixel_w1_3_lora-000010.safetensors \
+ --lora_multiplier 1.0 \
+ --prompt "The video showcases a pixel art scene from a video game. The video showcases an animation featuring a charming anime-style scene with a pink-haired girl with angel wings. She's seated at a desk, enjoying a donut while working on a laptop. The setting is a cozy, pastel-colored room with a pink chair, a milk carton, and a coffee cup. The girl's expression is one of delight as she savors her treat."
+ ```
+
+ #### Snow
+
+ ```bash
+ python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 768 1024 --video_length 81 --infer_steps 20 \
+ --save_path save --output_type both \
+ --dit wan2.1_t2v_1.3B_bf16.safetensors --vae Wan2.1_VAE.pth \
+ --t5 models_t5_umt5-xxl-enc-bf16.pth \
+ --attn_mode torch \
+ --lora_weight pixel_outputs/pixel_w1_3_lora-000010.safetensors \
+ --lora_multiplier 1.0 \
+ --prompt "The video showcases a pixel art scene from a video game. The video showcases an animation featuring a serene and majestic snowy mountain landscape. The scene is dominated by towering peaks covered in pristine white snow, with a soft gradient of blue and purple hues in the sky. A small cabin with a smoking chimney sits at the base of the mountain, surrounded by pine trees dusted with snow. A winding path leads up the mountain, with footprints visible in the snow. The atmosphere is calm and peaceful, evoking a sense of solitude and wonder."
+ ```
+
+ ## Parameters
+
+ * `--fp8`: Enable FP8 precision (optional).
+ * `--task`: Specify the task (e.g., `t2v-1.3B`).
+ * `--video_size`: Resolution of the generated video (e.g., `768 1024`, as in the examples above).
+ * `--video_length`: Length of the video in frames (e.g., `81`).
+ * `--infer_steps`: Number of inference steps.
+ * `--save_path`: Directory to save the generated video.
+ * `--output_type`: Output type (e.g., `both` for video and frames).
+ * `--dit`: Path to the diffusion model (DiT) weights.
+ * `--vae`: Path to the VAE model weights.
+ * `--t5`: Path to the T5 text-encoder weights.
+ * `--attn_mode`: Attention implementation (e.g., `torch`).
+ * `--lora_weight`: Path to the LoRA weights.
+ * `--lora_multiplier`: Strength multiplier for the LoRA weights (e.g., `1.0`).
+ * `--prompt`: Textual prompt for video generation.
+
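The repository ships ten LoRA checkpoints (`pixel_w1_3_lora-000001` through `-000010`), and a practical way to pick an epoch is to render the same prompt with each one. The sketch below is a hypothetical helper that simply shells out to `wan_generate_video.py` with the flags documented above; the prompt is a placeholder you can swap for any of the examples.

```python
# Hypothetical epoch sweep: render one prompt with every saved LoRA checkpoint.
import os
import subprocess

PROMPT = (
    "The video showcases a pixel art scene from a video game. "
    "A quiet forest clearing at dawn."  # placeholder; use any prompt from the examples above
)

for epoch in range(1, 11):
    lora = f"pixel_outputs/pixel_w1_3_lora-{epoch:06d}.safetensors"
    save_dir = f"save/epoch_{epoch:02d}"
    os.makedirs(save_dir, exist_ok=True)
    subprocess.run(
        [
            "python", "wan_generate_video.py",
            "--fp8", "--task", "t2v-1.3B",
            "--video_size", "768", "1024",
            "--video_length", "81",
            "--infer_steps", "20",
            "--save_path", save_dir,
            "--output_type", "both",
            "--dit", "wan2.1_t2v_1.3B_bf16.safetensors",
            "--vae", "Wan2.1_VAE.pth",
            "--t5", "models_t5_umt5-xxl-enc-bf16.pth",
            "--attn_mode", "torch",
            "--lora_weight", lora,
            "--lora_multiplier", "1.0",
            "--prompt", PROMPT,
        ],
        check=True,  # stop the sweep if a run fails
    )
```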
+ ## Output
+
+ The generated video and frames will be saved in the specified `save_path` directory.
+
+ ## Troubleshooting
+
+ • Ensure all dependencies are correctly installed.
+ • Verify that the model weights are downloaded and placed in the correct locations.
+ • Check for any missing Python packages and install them using `pip`.
+
+ ## License
+
+ This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.
+
+ ## Acknowledgments
+
+ • **Hugging Face** for hosting the model weights.
+ • **Wan-AI** for providing the pre-trained models.
+ • **DeepBeepMeep** for contributing to the model weights.
+
+ ## Contact
+
+ For any questions or issues, please open an issue on the repository or contact the maintainer.
+
+ ---
+
cache_latents.py ADDED
@@ -0,0 +1,281 @@
1
+ import argparse
2
+ import os
3
+ import glob
4
+ from typing import Optional, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from tqdm import tqdm
9
+
10
+ from dataset import config_utils
11
+ from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
12
+ from PIL import Image
13
+
14
+ import logging
15
+
16
+ from dataset.image_video_dataset import BaseDataset, ItemInfo, save_latent_cache, ARCHITECTURE_HUNYUAN_VIDEO
17
+ from hunyuan_model.vae import load_vae
18
+ from hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
19
+ from utils.model_utils import str_to_dtype
20
+
21
+ logger = logging.getLogger(__name__)
22
+ logging.basicConfig(level=logging.INFO)
23
+
24
+
25
+ def show_image(image: Union[list[Union[Image.Image, np.ndarray], Union[Image.Image, np.ndarray]]]) -> int:
26
+ import cv2
27
+
28
+ imgs = (
29
+ [image]
30
+ if (isinstance(image, np.ndarray) and len(image.shape) == 3) or isinstance(image, Image.Image)
31
+ else [image[0], image[-1]]
32
+ )
33
+ if len(imgs) > 1:
34
+ print(f"Number of images: {len(image)}")
35
+ for i, img in enumerate(imgs):
36
+ if len(imgs) > 1:
37
+ print(f"{'First' if i == 0 else 'Last'} image: {img.shape}")
38
+ else:
39
+ print(f"Image: {img.shape}")
40
+ cv2_img = np.array(img) if isinstance(img, Image.Image) else img
41
+ cv2_img = cv2.cvtColor(cv2_img, cv2.COLOR_RGB2BGR)
42
+ cv2.imshow("image", cv2_img)
43
+ k = cv2.waitKey(0)
44
+ cv2.destroyAllWindows()
45
+ if k == ord("q") or k == ord("d"):
46
+ return k
47
+ return k
48
+
49
+
50
+ def show_console(
51
+ image: Union[list[Union[Image.Image, np.ndarray], Union[Image.Image, np.ndarray]]],
52
+ width: int,
53
+ back: str,
54
+ interactive: bool = False,
55
+ ) -> int:
56
+ from ascii_magic import from_pillow_image, Back
57
+
58
+ back = None
59
+ if back is not None:
60
+ back = getattr(Back, back.upper())
61
+
62
+ k = None
63
+ imgs = (
64
+ [image]
65
+ if (isinstance(image, np.ndarray) and len(image.shape) == 3) or isinstance(image, Image.Image)
66
+ else [image[0], image[-1]]
67
+ )
68
+ if len(imgs) > 1:
69
+ print(f"Number of images: {len(image)}")
70
+ for i, img in enumerate(imgs):
71
+ if len(imgs) > 1:
72
+ print(f"{'First' if i == 0 else 'Last'} image: {img.shape}")
73
+ else:
74
+ print(f"Image: {img.shape}")
75
+ pil_img = img if isinstance(img, Image.Image) else Image.fromarray(img)
76
+ ascii_img = from_pillow_image(pil_img)
77
+ ascii_img.to_terminal(columns=width, back=back)
78
+
79
+ if interactive:
80
+ k = input("Press q to quit, d to next dataset, other key to next: ")
81
+ if k == "q" or k == "d":
82
+ return ord(k)
83
+
84
+ if not interactive:
85
+ return ord(" ")
86
+ return ord(k) if k else ord(" ")
87
+
88
+
89
+ def show_datasets(
90
+ datasets: list[BaseDataset], debug_mode: str, console_width: int, console_back: str, console_num_images: Optional[int]
91
+ ):
92
+ print(f"d: next dataset, q: quit")
93
+
94
+ num_workers = max(1, os.cpu_count() - 1)
95
+ for i, dataset in enumerate(datasets):
96
+ print(f"Dataset [{i}]")
97
+ batch_index = 0
98
+ num_images_to_show = console_num_images
99
+ k = None
100
+ for key, batch in dataset.retrieve_latent_cache_batches(num_workers):
101
+ print(f"bucket resolution: {key}, count: {len(batch)}")
102
+ for j, item_info in enumerate(batch):
103
+ item_info: ItemInfo
104
+ print(f"{batch_index}-{j}: {item_info}")
105
+ if debug_mode == "image":
106
+ k = show_image(item_info.content)
107
+ elif debug_mode == "console":
108
+ k = show_console(item_info.content, console_width, console_back, console_num_images is None)
109
+ if num_images_to_show is not None:
110
+ num_images_to_show -= 1
111
+ if num_images_to_show == 0:
112
+ k = ord("d") # next dataset
113
+
114
+ if k == ord("q"):
115
+ return
116
+ elif k == ord("d"):
117
+ break
118
+ if k == ord("d"):
119
+ break
120
+ batch_index += 1
121
+
122
+
123
+ def encode_and_save_batch(vae: AutoencoderKLCausal3D, batch: list[ItemInfo]):
124
+ contents = torch.stack([torch.from_numpy(item.content) for item in batch])
125
+ if len(contents.shape) == 4:
126
+ contents = contents.unsqueeze(1) # B, H, W, C -> B, F, H, W, C
127
+
128
+ contents = contents.permute(0, 4, 1, 2, 3).contiguous() # B, C, F, H, W
129
+ contents = contents.to(vae.device, dtype=vae.dtype)
130
+ contents = contents / 127.5 - 1.0 # normalize to [-1, 1]
131
+
132
+ h, w = contents.shape[3], contents.shape[4]
133
+ if h < 8 or w < 8:
134
+ item = batch[0] # other items should have the same size
135
+ raise ValueError(f"Image or video size too small: {item.item_key} and {len(batch) - 1} more, size: {item.original_size}")
136
+
137
+ # print(f"encode batch: {contents.shape}")
138
+ with torch.no_grad():
139
+ latent = vae.encode(contents).latent_dist.sample()
140
+ # latent = latent * vae.config.scaling_factor
141
+
142
+ # # debug: decode and save
143
+ # with torch.no_grad():
144
+ # latent_to_decode = latent / vae.config.scaling_factor
145
+ # images = vae.decode(latent_to_decode, return_dict=False)[0]
146
+ # images = (images / 2 + 0.5).clamp(0, 1)
147
+ # images = images.cpu().float().numpy()
148
+ # images = (images * 255).astype(np.uint8)
149
+ # images = images.transpose(0, 2, 3, 4, 1) # B, C, F, H, W -> B, F, H, W, C
150
+ # for b in range(images.shape[0]):
151
+ # for f in range(images.shape[1]):
152
+ # fln = os.path.splitext(os.path.basename(batch[b].item_key))[0]
153
+ # img = Image.fromarray(images[b, f])
154
+ # img.save(f"./logs/decode_{fln}_{b}_{f:03d}.jpg")
155
+
156
+ for item, l in zip(batch, latent):
157
+ # print(f"save latent cache: {item.latent_cache_path}, latent shape: {l.shape}")
158
+ save_latent_cache(item, l)
159
+
160
+
161
+ def encode_datasets(datasets: list[BaseDataset], encode: callable, args: argparse.Namespace):
162
+ num_workers = args.num_workers if args.num_workers is not None else max(1, os.cpu_count() - 1)
163
+ for i, dataset in enumerate(datasets):
164
+ logger.info(f"Encoding dataset [{i}]")
165
+ all_latent_cache_paths = []
166
+ for _, batch in tqdm(dataset.retrieve_latent_cache_batches(num_workers)):
167
+ all_latent_cache_paths.extend([item.latent_cache_path for item in batch])
168
+
169
+ if args.skip_existing:
170
+ filtered_batch = [item for item in batch if not os.path.exists(item.latent_cache_path)]
171
+ if len(filtered_batch) == 0:
172
+ continue
173
+ batch = filtered_batch
174
+
175
+ bs = args.batch_size if args.batch_size is not None else len(batch)
176
+ for i in range(0, len(batch), bs):
177
+ encode(batch[i : i + bs])
178
+
179
+ # normalize paths
180
+ all_latent_cache_paths = [os.path.normpath(p) for p in all_latent_cache_paths]
181
+ all_latent_cache_paths = set(all_latent_cache_paths)
182
+
183
+ # remove old cache files not in the dataset
184
+ all_cache_files = dataset.get_all_latent_cache_files()
185
+ for cache_file in all_cache_files:
186
+ if os.path.normpath(cache_file) not in all_latent_cache_paths:
187
+ if args.keep_cache:
188
+ logger.info(f"Keep cache file not in the dataset: {cache_file}")
189
+ else:
190
+ os.remove(cache_file)
191
+ logger.info(f"Removed old cache file: {cache_file}")
192
+
193
+
194
+ def main(args):
195
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
196
+ device = torch.device(device)
197
+
198
+ # Load dataset config
199
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer())
200
+ logger.info(f"Load dataset config from {args.dataset_config}")
201
+ user_config = config_utils.load_user_config(args.dataset_config)
202
+ blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_HUNYUAN_VIDEO)
203
+ train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
204
+
205
+ datasets = train_dataset_group.datasets
206
+
207
+ if args.debug_mode is not None:
208
+ show_datasets(datasets, args.debug_mode, args.console_width, args.console_back, args.console_num_images)
209
+ return
210
+
211
+ assert args.vae is not None, "vae checkpoint is required"
212
+
213
+ # Load VAE model: HunyuanVideo VAE model is float16
214
+ vae_dtype = torch.float16 if args.vae_dtype is None else str_to_dtype(args.vae_dtype)
215
+ vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device=device, vae_path=args.vae)
216
+ vae.eval()
217
+ logger.info(f"Loaded VAE: {vae.config}, dtype: {vae.dtype}")
218
+
219
+ if args.vae_chunk_size is not None:
220
+ vae.set_chunk_size_for_causal_conv_3d(args.vae_chunk_size)
221
+ logger.info(f"Set chunk_size to {args.vae_chunk_size} for CausalConv3d in VAE")
222
+ if args.vae_spatial_tile_sample_min_size is not None:
223
+ vae.enable_spatial_tiling(True)
224
+ vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size
225
+ vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8
226
+ elif args.vae_tiling:
227
+ vae.enable_spatial_tiling(True)
228
+
229
+ # Encode images
230
+ def encode(one_batch: list[ItemInfo]):
231
+ encode_and_save_batch(vae, one_batch)
232
+
233
+ encode_datasets(datasets, encode, args)
234
+
235
+
236
+ def setup_parser_common() -> argparse.ArgumentParser:
237
+ parser = argparse.ArgumentParser()
238
+
239
+ parser.add_argument("--dataset_config", type=str, required=True, help="path to dataset config .toml file")
240
+ parser.add_argument("--vae", type=str, required=False, default=None, help="path to vae checkpoint")
241
+ parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16")
242
+ parser.add_argument("--device", type=str, default=None, help="device to use, default is cuda if available")
243
+ parser.add_argument(
244
+ "--batch_size", type=int, default=None, help="batch size, override dataset config if dataset batch size > this"
245
+ )
246
+ parser.add_argument("--num_workers", type=int, default=None, help="number of workers for dataset. default is cpu count-1")
247
+ parser.add_argument("--skip_existing", action="store_true", help="skip existing cache files")
248
+ parser.add_argument("--keep_cache", action="store_true", help="keep cache files not in dataset")
249
+ parser.add_argument("--debug_mode", type=str, default=None, choices=["image", "console"], help="debug mode")
250
+ parser.add_argument("--console_width", type=int, default=80, help="debug mode: console width")
251
+ parser.add_argument(
252
+ "--console_back", type=str, default=None, help="debug mode: console background color, one of ascii_magic.Back"
253
+ )
254
+ parser.add_argument(
255
+ "--console_num_images",
256
+ type=int,
257
+ default=None,
258
+ help="debug mode: not interactive, number of images to show for each dataset",
259
+ )
260
+ return parser
261
+
262
+
263
+ def hv_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
264
+ parser.add_argument(
265
+ "--vae_tiling",
266
+ action="store_true",
267
+ help="enable spatial tiling for VAE, default is False. If vae_spatial_tile_sample_min_size is set, this is automatically enabled",
268
+ )
269
+ parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
270
+ parser.add_argument(
271
+ "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
272
+ )
273
+ return parser
274
+
275
+
276
+ if __name__ == "__main__":
277
+ parser = setup_parser_common()
278
+ parser = hv_setup_parser(parser)
279
+
280
+ args = parser.parse_args()
281
+ main(args)
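`cache_latents.py` is meant to be run as a CLI, but since it exposes `setup_parser_common`, `hv_setup_parser`, and `main`, an equivalent programmatic call looks like the sketch below. The dataset config and VAE paths are placeholders; only `--dataset_config` and `--vae` are needed here, the remaining flags are optional.

```python
# Hypothetical programmatic invocation of cache_latents.py; all paths are placeholders.
from cache_latents import hv_setup_parser, main, setup_parser_common

parser = hv_setup_parser(setup_parser_common())
args = parser.parse_args(
    [
        "--dataset_config", "dataset/pixel.toml",       # placeholder dataset .toml
        "--vae", "hunyuan_video_vae_fp16.safetensors",  # placeholder HunyuanVideo VAE checkpoint
        "--vae_chunk_size", "32",                       # optional: chunk CausalConv3d to reduce VRAM
        "--vae_spatial_tile_sample_min_size", "256",    # optional: enables spatial tiling
        "--skip_existing",                              # reuse latents that were already cached
    ]
)
main(args)
```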
cache_text_encoder_outputs.py ADDED
@@ -0,0 +1,214 @@
1
+ import argparse
2
+ import os
3
+ from typing import Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from tqdm import tqdm
8
+
9
+ from dataset import config_utils
10
+ from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
11
+ import accelerate
12
+
13
+ from dataset.image_video_dataset import ARCHITECTURE_HUNYUAN_VIDEO, BaseDataset, ItemInfo, save_text_encoder_output_cache
14
+ from hunyuan_model import text_encoder as text_encoder_module
15
+ from hunyuan_model.text_encoder import TextEncoder
16
+
17
+ import logging
18
+
19
+ from utils.model_utils import str_to_dtype
20
+
21
+ logger = logging.getLogger(__name__)
22
+ logging.basicConfig(level=logging.INFO)
23
+
24
+
25
+ def encode_prompt(text_encoder: TextEncoder, prompt: Union[str, list[str]]):
26
+ data_type = "video" # video only, image is not supported
27
+ text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)
28
+
29
+ with torch.no_grad():
30
+ prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type)
31
+
32
+ return prompt_outputs.hidden_state, prompt_outputs.attention_mask
33
+
34
+
35
+ def encode_and_save_batch(
36
+ text_encoder: TextEncoder, batch: list[ItemInfo], is_llm: bool, accelerator: Optional[accelerate.Accelerator]
37
+ ):
38
+ prompts = [item.caption for item in batch]
39
+ # print(prompts)
40
+
41
+ # encode prompt
42
+ if accelerator is not None:
43
+ with accelerator.autocast():
44
+ prompt_embeds, prompt_mask = encode_prompt(text_encoder, prompts)
45
+ else:
46
+ prompt_embeds, prompt_mask = encode_prompt(text_encoder, prompts)
47
+
48
+ # # convert to fp16 if needed
49
+ # if prompt_embeds.dtype == torch.float32 and text_encoder.dtype != torch.float32:
50
+ # prompt_embeds = prompt_embeds.to(text_encoder.dtype)
51
+
52
+ # save prompt cache
53
+ for item, embed, mask in zip(batch, prompt_embeds, prompt_mask):
54
+ save_text_encoder_output_cache(item, embed, mask, is_llm)
55
+
56
+
57
+ def prepare_cache_files_and_paths(datasets: list[BaseDataset]):
58
+ all_cache_files_for_dataset = [] # exisiting cache files
59
+ all_cache_paths_for_dataset = [] # all cache paths in the dataset
60
+ for dataset in datasets:
61
+ all_cache_files = [os.path.normpath(file) for file in dataset.get_all_text_encoder_output_cache_files()]
62
+ all_cache_files = set(all_cache_files)
63
+ all_cache_files_for_dataset.append(all_cache_files)
64
+
65
+ all_cache_paths_for_dataset.append(set())
66
+ return all_cache_files_for_dataset, all_cache_paths_for_dataset
67
+
68
+
69
+ def process_text_encoder_batches(
70
+ num_workers: Optional[int],
71
+ skip_existing: bool,
72
+ batch_size: int,
73
+ datasets: list[BaseDataset],
74
+ all_cache_files_for_dataset: list[set],
75
+ all_cache_paths_for_dataset: list[set],
76
+ encode: callable,
77
+ ):
78
+ num_workers = num_workers if num_workers is not None else max(1, os.cpu_count() - 1)
79
+ for i, dataset in enumerate(datasets):
80
+ logger.info(f"Encoding dataset [{i}]")
81
+ all_cache_files = all_cache_files_for_dataset[i]
82
+ all_cache_paths = all_cache_paths_for_dataset[i]
83
+ for batch in tqdm(dataset.retrieve_text_encoder_output_cache_batches(num_workers)):
84
+ # update cache files (it's ok if we update it multiple times)
85
+ all_cache_paths.update([os.path.normpath(item.text_encoder_output_cache_path) for item in batch])
86
+
87
+ # skip existing cache files
88
+ if skip_existing:
89
+ filtered_batch = [
90
+ item for item in batch if not os.path.normpath(item.text_encoder_output_cache_path) in all_cache_files
91
+ ]
92
+ # print(f"Filtered {len(batch) - len(filtered_batch)} existing cache files")
93
+ if len(filtered_batch) == 0:
94
+ continue
95
+ batch = filtered_batch
96
+
97
+ bs = batch_size if batch_size is not None else len(batch)
98
+ for i in range(0, len(batch), bs):
99
+ encode(batch[i : i + bs])
100
+
101
+
102
+ def post_process_cache_files(
103
+ datasets: list[BaseDataset], all_cache_files_for_dataset: list[set], all_cache_paths_for_dataset: list[set]
104
+ ):
105
+ for i, dataset in enumerate(datasets):
106
+ all_cache_files = all_cache_files_for_dataset[i]
107
+ all_cache_paths = all_cache_paths_for_dataset[i]
108
+ for cache_file in all_cache_files:
109
+ if cache_file not in all_cache_paths:
110
+ if args.keep_cache:
111
+ logger.info(f"Keep cache file not in the dataset: {cache_file}")
112
+ else:
113
+ os.remove(cache_file)
114
+ logger.info(f"Removed old cache file: {cache_file}")
115
+
116
+
117
+ def main(args):
118
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
119
+ device = torch.device(device)
120
+
121
+ # Load dataset config
122
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer())
123
+ logger.info(f"Load dataset config from {args.dataset_config}")
124
+ user_config = config_utils.load_user_config(args.dataset_config)
125
+ blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_HUNYUAN_VIDEO)
126
+ train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
127
+
128
+ datasets = train_dataset_group.datasets
129
+
130
+ # define accelerator for fp8 inference
131
+ accelerator = None
132
+ if args.fp8_llm:
133
+ accelerator = accelerate.Accelerator(mixed_precision="fp16")
134
+
135
+ # prepare cache files and paths: all_cache_files_for_dataset = exisiting cache files, all_cache_paths_for_dataset = all cache paths in the dataset
136
+ all_cache_files_for_dataset, all_cache_paths_for_dataset = prepare_cache_files_and_paths(datasets)
137
+
138
+ # Load Text Encoder 1
139
+ text_encoder_dtype = torch.float16 if args.text_encoder_dtype is None else str_to_dtype(args.text_encoder_dtype)
140
+ logger.info(f"loading text encoder 1: {args.text_encoder1}")
141
+ text_encoder_1 = text_encoder_module.load_text_encoder_1(args.text_encoder1, device, args.fp8_llm, text_encoder_dtype)
142
+ text_encoder_1.to(device=device)
143
+
144
+ # Encode with Text Encoder 1 (LLM)
145
+ logger.info("Encoding with Text Encoder 1")
146
+
147
+ def encode_for_text_encoder_1(batch: list[ItemInfo]):
148
+ encode_and_save_batch(text_encoder_1, batch, is_llm=True, accelerator=accelerator)
149
+
150
+ process_text_encoder_batches(
151
+ args.num_workers,
152
+ args.skip_existing,
153
+ args.batch_size,
154
+ datasets,
155
+ all_cache_files_for_dataset,
156
+ all_cache_paths_for_dataset,
157
+ encode_for_text_encoder_1,
158
+ )
159
+ del text_encoder_1
160
+
161
+ # Load Text Encoder 2
162
+ logger.info(f"loading text encoder 2: {args.text_encoder2}")
163
+ text_encoder_2 = text_encoder_module.load_text_encoder_2(args.text_encoder2, device, text_encoder_dtype)
164
+ text_encoder_2.to(device=device)
165
+
166
+ # Encode with Text Encoder 2
167
+ logger.info("Encoding with Text Encoder 2")
168
+
169
+ def encode_for_text_encoder_2(batch: list[ItemInfo]):
170
+ encode_and_save_batch(text_encoder_2, batch, is_llm=False, accelerator=None)
171
+
172
+ process_text_encoder_batches(
173
+ args.num_workers,
174
+ args.skip_existing,
175
+ args.batch_size,
176
+ datasets,
177
+ all_cache_files_for_dataset,
178
+ all_cache_paths_for_dataset,
179
+ encode_for_text_encoder_2,
180
+ )
181
+ del text_encoder_2
182
+
183
+ # remove cache files not in dataset
184
+ post_process_cache_files(datasets, all_cache_files_for_dataset, all_cache_paths_for_dataset)
185
+
186
+
187
+ def setup_parser_common():
188
+ parser = argparse.ArgumentParser()
189
+
190
+ parser.add_argument("--dataset_config", type=str, required=True, help="path to dataset config .toml file")
191
+ parser.add_argument("--device", type=str, default=None, help="device to use, default is cuda if available")
192
+ parser.add_argument(
193
+ "--batch_size", type=int, default=None, help="batch size, override dataset config if dataset batch size > this"
194
+ )
195
+ parser.add_argument("--num_workers", type=int, default=None, help="number of workers for dataset. default is cpu count-1")
196
+ parser.add_argument("--skip_existing", action="store_true", help="skip existing cache files")
197
+ parser.add_argument("--keep_cache", action="store_true", help="keep cache files not in dataset")
198
+ return parser
199
+
200
+
201
+ def hv_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
202
+ parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory")
203
+ parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory")
204
+ parser.add_argument("--text_encoder_dtype", type=str, default=None, help="data type for Text Encoder, default is float16")
205
+ parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
206
+ return parser
207
+
208
+
209
+ if __name__ == "__main__":
210
+ parser = setup_parser_common()
211
+ parser = hv_setup_parser(parser)
212
+
213
+ args = parser.parse_args()
214
+ main(args)
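The text-encoder caching script can be driven the same way. One quirk of the code above: `post_process_cache_files` reads a module-level `args` rather than taking it as a parameter, so a programmatic caller should also set that attribute, as in this sketch (paths are placeholders).

```python
# Hypothetical programmatic invocation of cache_text_encoder_outputs.py; paths are placeholders.
import cache_text_encoder_outputs as ct

parser = ct.hv_setup_parser(ct.setup_parser_common())
args = parser.parse_args(
    [
        "--dataset_config", "dataset/pixel.toml",  # placeholder dataset .toml
        "--text_encoder1", "text_encoder",         # placeholder: LLM text encoder directory
        "--text_encoder2", "text_encoder_2",       # placeholder: CLIP text encoder directory
        "--batch_size", "16",
        "--skip_existing",
    ]
)
ct.args = args  # post_process_cache_files looks up `args` in the module namespace
ct.main(args)
```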
convert_lora.py ADDED
@@ -0,0 +1,135 @@
1
+ import argparse
2
+
3
+ import torch
4
+ from safetensors.torch import load_file, save_file
5
+ from safetensors import safe_open
6
+ from utils import model_utils
7
+
8
+ import logging
9
+
10
+
11
+ logger = logging.getLogger(__name__)
12
+ logging.basicConfig(level=logging.INFO)
13
+
14
+
15
+ def convert_from_diffusers(prefix, weights_sd):
16
+ # convert from diffusers(?) to default LoRA
17
+ # Diffusers format: {"diffusion_model.module.name.lora_A.weight": weight, "diffusion_model.module.name.lora_B.weight": weight, ...}
18
+ # default LoRA format: {"prefix_module_name.lora_down.weight": weight, "prefix_module_name.lora_up.weight": weight, ...}
19
+
20
+ # note: Diffusers has no alpha, so alpha is set to rank
21
+ new_weights_sd = {}
22
+ lora_dims = {}
23
+ for key, weight in weights_sd.items():
24
+ diffusers_prefix, key_body = key.split(".", 1)
25
+ if diffusers_prefix != "diffusion_model" and diffusers_prefix != "transformer":
26
+ logger.warning(f"unexpected key: {key} in diffusers format")
27
+ continue
28
+
29
+ new_key = f"{prefix}{key_body}".replace(".", "_").replace("_lora_A_", ".lora_down.").replace("_lora_B_", ".lora_up.")
30
+ new_weights_sd[new_key] = weight
31
+
32
+ lora_name = new_key.split(".")[0] # before first dot
33
+ if lora_name not in lora_dims and "lora_down" in new_key:
34
+ lora_dims[lora_name] = weight.shape[0]
35
+
36
+ # add alpha with rank
37
+ for lora_name, dim in lora_dims.items():
38
+ new_weights_sd[f"{lora_name}.alpha"] = torch.tensor(dim)
39
+
40
+ return new_weights_sd
41
+
42
+
43
+ def convert_to_diffusers(prefix, weights_sd):
44
+ # convert from default LoRA to diffusers
45
+
46
+ # get alphas
47
+ lora_alphas = {}
48
+ for key, weight in weights_sd.items():
49
+ if key.startswith(prefix):
50
+ lora_name = key.split(".", 1)[0] # before first dot
51
+ if lora_name not in lora_alphas and "alpha" in key:
52
+ lora_alphas[lora_name] = weight
53
+
54
+ new_weights_sd = {}
55
+ for key, weight in weights_sd.items():
56
+ if key.startswith(prefix):
57
+ if "alpha" in key:
58
+ continue
59
+
60
+ lora_name = key.split(".", 1)[0] # before first dot
61
+
62
+ module_name = lora_name[len(prefix) :] # remove "lora_unet_"
63
+ module_name = module_name.replace("_", ".") # replace "_" with "."
64
+ if ".cross.attn." in module_name or ".self.attn." in module_name:
65
+ # Wan2.1 lora name to module name: ugly but works
66
+ module_name = module_name.replace("cross.attn", "cross_attn") # fix cross attn
67
+ module_name = module_name.replace("self.attn", "self_attn") # fix self attn
68
+ else:
69
+ # HunyuanVideo lora name to module name: ugly but works
70
+ module_name = module_name.replace("double.blocks.", "double_blocks.") # fix double blocks
71
+ module_name = module_name.replace("single.blocks.", "single_blocks.") # fix single blocks
72
+ module_name = module_name.replace("img.", "img_") # fix img
73
+ module_name = module_name.replace("txt.", "txt_") # fix txt
74
+ module_name = module_name.replace("attn.", "attn_") # fix attn
75
+
76
+ diffusers_prefix = "diffusion_model"
77
+ if "lora_down" in key:
78
+ new_key = f"{diffusers_prefix}.{module_name}.lora_A.weight"
79
+ dim = weight.shape[0]
80
+ elif "lora_up" in key:
81
+ new_key = f"{diffusers_prefix}.{module_name}.lora_B.weight"
82
+ dim = weight.shape[1]
83
+ else:
84
+ logger.warning(f"unexpected key: {key} in default LoRA format")
85
+ continue
86
+
87
+ # scale weight by alpha
88
+ if lora_name in lora_alphas:
89
+ # we scale both down and up, so scale is sqrt
90
+ scale = lora_alphas[lora_name] / dim
91
+ scale = scale.sqrt()
92
+ weight = weight * scale
93
+ else:
94
+ logger.warning(f"missing alpha for {lora_name}")
95
+
96
+ new_weights_sd[new_key] = weight
97
+
98
+ return new_weights_sd
99
+
100
+
101
+ def convert(input_file, output_file, target_format):
102
+ logger.info(f"loading {input_file}")
103
+ weights_sd = load_file(input_file)
104
+ with safe_open(input_file, framework="pt") as f:
105
+ metadata = f.metadata()
106
+
107
+ logger.info(f"converting to {target_format}")
108
+ prefix = "lora_unet_"
109
+ if target_format == "default":
110
+ new_weights_sd = convert_from_diffusers(prefix, weights_sd)
111
+ metadata = metadata or {}
112
+ model_utils.precalculate_safetensors_hashes(new_weights_sd, metadata)
113
+ elif target_format == "other":
114
+ new_weights_sd = convert_to_diffusers(prefix, weights_sd)
115
+ else:
116
+ raise ValueError(f"unknown target format: {target_format}")
117
+
118
+ logger.info(f"saving to {output_file}")
119
+ save_file(new_weights_sd, output_file, metadata=metadata)
120
+
121
+ logger.info("done")
122
+
123
+
124
+ def parse_args():
125
+ parser = argparse.ArgumentParser(description="Convert LoRA weights between default and other formats")
126
+ parser.add_argument("--input", type=str, required=True, help="input model file")
127
+ parser.add_argument("--output", type=str, required=True, help="output model file")
128
+ parser.add_argument("--target", type=str, required=True, choices=["other", "default"], help="target format")
129
+ args = parser.parse_args()
130
+ return args
131
+
132
+
133
+ if __name__ == "__main__":
134
+ args = parse_args()
135
+ convert(args.input, args.output, args.target)
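For reference, the converter can also be called directly from Python; the output filename below is a placeholder, and `target_format` follows the `--target` choices above (`"other"` produces the Diffusers-style `diffusion_model.*.lora_A/lora_B` layout, `"default"` produces the trainer's `lora_unet_*.lora_down/lora_up` layout).

```python
# Minimal sketch: convert one of the bundled LoRA checkpoints to the Diffusers-style layout.
from convert_lora import convert

convert(
    input_file="pixel_outputs/pixel_w1_3_lora-000010.safetensors",
    output_file="pixel_w1_3_lora-000010_diffusers.safetensors",  # placeholder output name
    target_format="other",
)
```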
dataset/__init__.py ADDED
File without changes
dataset/config_utils.py ADDED
@@ -0,0 +1,372 @@
1
+ import argparse
2
+ from dataclasses import (
3
+ asdict,
4
+ dataclass,
5
+ )
6
+ import functools
7
+ import random
8
+ from textwrap import dedent, indent
9
+ import json
10
+ from pathlib import Path
11
+
12
+ # from toolz import curry
13
+ from typing import Dict, List, Optional, Sequence, Tuple, Union
14
+
15
+ import toml
16
+ import voluptuous
17
+ from voluptuous import Any, ExactSequence, MultipleInvalid, Object, Schema
18
+
19
+ from .image_video_dataset import DatasetGroup, ImageDataset, VideoDataset
20
+
21
+ import logging
22
+
23
+ logger = logging.getLogger(__name__)
24
+ logging.basicConfig(level=logging.INFO)
25
+
26
+
27
+ @dataclass
28
+ class BaseDatasetParams:
29
+ resolution: Tuple[int, int] = (960, 544)
30
+ enable_bucket: bool = False
31
+ bucket_no_upscale: bool = False
32
+ caption_extension: Optional[str] = None
33
+ batch_size: int = 1
34
+ num_repeats: int = 1
35
+ cache_directory: Optional[str] = None
36
+ debug_dataset: bool = False
37
+ architecture: str = "no_default" # short style like "hv" or "wan"
38
+
39
+
40
+ @dataclass
41
+ class ImageDatasetParams(BaseDatasetParams):
42
+ image_directory: Optional[str] = None
43
+ image_jsonl_file: Optional[str] = None
44
+
45
+
46
+ @dataclass
47
+ class VideoDatasetParams(BaseDatasetParams):
48
+ video_directory: Optional[str] = None
49
+ video_jsonl_file: Optional[str] = None
50
+ target_frames: Sequence[int] = (1,)
51
+ frame_extraction: Optional[str] = "head"
52
+ frame_stride: Optional[int] = 1
53
+ frame_sample: Optional[int] = 1
54
+
55
+
56
+ @dataclass
57
+ class DatasetBlueprint:
58
+ is_image_dataset: bool
59
+ params: Union[ImageDatasetParams, VideoDatasetParams]
60
+
61
+
62
+ @dataclass
63
+ class DatasetGroupBlueprint:
64
+ datasets: Sequence[DatasetBlueprint]
65
+
66
+
67
+ @dataclass
68
+ class Blueprint:
69
+ dataset_group: DatasetGroupBlueprint
70
+
71
+
72
+ class ConfigSanitizer:
73
+ # @curry
74
+ @staticmethod
75
+ def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
76
+ Schema(ExactSequence([klass, klass]))(value)
77
+ return tuple(value)
78
+
79
+ # @curry
80
+ @staticmethod
81
+ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
82
+ Schema(Any(klass, ExactSequence([klass, klass])))(value)
83
+ try:
84
+ Schema(klass)(value)
85
+ return (value, value)
86
+ except:
87
+ return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
88
+
89
+ # datasets schema
90
+ DATASET_ASCENDABLE_SCHEMA = {
91
+ "caption_extension": str,
92
+ "batch_size": int,
93
+ "num_repeats": int,
94
+ "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
95
+ "enable_bucket": bool,
96
+ "bucket_no_upscale": bool,
97
+ }
98
+ IMAGE_DATASET_DISTINCT_SCHEMA = {
99
+ "image_directory": str,
100
+ "image_jsonl_file": str,
101
+ "cache_directory": str,
102
+ }
103
+ VIDEO_DATASET_DISTINCT_SCHEMA = {
104
+ "video_directory": str,
105
+ "video_jsonl_file": str,
106
+ "target_frames": [int],
107
+ "frame_extraction": str,
108
+ "frame_stride": int,
109
+ "frame_sample": int,
110
+ "cache_directory": str,
111
+ }
112
+
113
+ # options handled by argparse but not handled by user config
114
+ ARGPARSE_SPECIFIC_SCHEMA = {
115
+ "debug_dataset": bool,
116
+ }
117
+
118
+ def __init__(self) -> None:
119
+ self.image_dataset_schema = self.__merge_dict(
120
+ self.DATASET_ASCENDABLE_SCHEMA,
121
+ self.IMAGE_DATASET_DISTINCT_SCHEMA,
122
+ )
123
+ self.video_dataset_schema = self.__merge_dict(
124
+ self.DATASET_ASCENDABLE_SCHEMA,
125
+ self.VIDEO_DATASET_DISTINCT_SCHEMA,
126
+ )
127
+
128
+ def validate_flex_dataset(dataset_config: dict):
129
+ if "target_frames" in dataset_config:
130
+ return Schema(self.video_dataset_schema)(dataset_config)
131
+ else:
132
+ return Schema(self.image_dataset_schema)(dataset_config)
133
+
134
+ self.dataset_schema = validate_flex_dataset
135
+
136
+ self.general_schema = self.__merge_dict(
137
+ self.DATASET_ASCENDABLE_SCHEMA,
138
+ )
139
+ self.user_config_validator = Schema(
140
+ {
141
+ "general": self.general_schema,
142
+ "datasets": [self.dataset_schema],
143
+ }
144
+ )
145
+ self.argparse_schema = self.__merge_dict(
146
+ self.ARGPARSE_SPECIFIC_SCHEMA,
147
+ )
148
+ self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)
149
+
150
+ def sanitize_user_config(self, user_config: dict) -> dict:
151
+ try:
152
+ return self.user_config_validator(user_config)
153
+ except MultipleInvalid:
154
+ # TODO: clarify the error message
155
+ logger.error("Invalid user config / ユーザ設定の形式が正しくないようです")
156
+ raise
157
+
158
+ # NOTE: In nature, argument parser result is not needed to be sanitize
159
+ # However this will help us to detect program bug
160
+ def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
161
+ try:
162
+ return self.argparse_config_validator(argparse_namespace)
163
+ except MultipleInvalid:
164
+ # XXX: this should be a bug
165
+ logger.error(
166
+ "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
167
+ )
168
+ raise
169
+
170
+ # NOTE: value would be overwritten by latter dict if there is already the same key
171
+ @staticmethod
172
+ def __merge_dict(*dict_list: dict) -> dict:
173
+ merged = {}
174
+ for schema in dict_list:
175
+ # merged |= schema
176
+ for k, v in schema.items():
177
+ merged[k] = v
178
+ return merged
179
+
180
+
181
+ class BlueprintGenerator:
182
+ BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {}
183
+
184
+ def __init__(self, sanitizer: ConfigSanitizer):
185
+ self.sanitizer = sanitizer
186
+
187
+ # runtime_params is for parameters which is only configurable on runtime, such as tokenizer
188
+ def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
189
+ sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
190
+ sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)
191
+
192
+ argparse_config = {k: v for k, v in vars(sanitized_argparse_namespace).items() if v is not None}
193
+ general_config = sanitized_user_config.get("general", {})
194
+
195
+ dataset_blueprints = []
196
+ for dataset_config in sanitized_user_config.get("datasets", []):
197
+ is_image_dataset = "target_frames" not in dataset_config
198
+ if is_image_dataset:
199
+ dataset_params_klass = ImageDatasetParams
200
+ else:
201
+ dataset_params_klass = VideoDatasetParams
202
+
203
+ params = self.generate_params_by_fallbacks(
204
+ dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]
205
+ )
206
+ dataset_blueprints.append(DatasetBlueprint(is_image_dataset, params))
207
+
208
+ dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
209
+
210
+ return Blueprint(dataset_group_blueprint)
211
+
212
+ @staticmethod
213
+ def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
214
+ name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
215
+ search_value = BlueprintGenerator.search_value
216
+ default_params = asdict(param_klass())
217
+ param_names = default_params.keys()
218
+
219
+ params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names}
220
+
221
+ return param_klass(**params)
222
+
223
+ @staticmethod
224
+ def search_value(key: str, fallbacks: Sequence[dict], default_value=None):
225
+ for cand in fallbacks:
226
+ value = cand.get(key)
227
+ if value is not None:
228
+ return value
229
+
230
+ return default_value
231
+
232
+
233
+ # if training is True, it will return a dataset group for training, otherwise for caching
234
+ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint, training: bool = False) -> DatasetGroup:
235
+ datasets: List[Union[ImageDataset, VideoDataset]] = []
236
+
237
+ for dataset_blueprint in dataset_group_blueprint.datasets:
238
+ if dataset_blueprint.is_image_dataset:
239
+ dataset_klass = ImageDataset
240
+ else:
241
+ dataset_klass = VideoDataset
242
+
243
+ dataset = dataset_klass(**asdict(dataset_blueprint.params))
244
+ datasets.append(dataset)
245
+
246
+ # assertion
247
+ cache_directories = [dataset.cache_directory for dataset in datasets]
248
+ num_of_unique_cache_directories = len(set(cache_directories))
249
+ if num_of_unique_cache_directories != len(cache_directories):
250
+ raise ValueError(
251
+ "cache directory should be unique for each dataset (note that cache directory is image/video directory if not specified)"
252
+ + " / cache directory は各データセットごとに異なる必要があります(指定されていない場合はimage/video directoryが使われるので注意)"
253
+ )
254
+
255
+ # print info
256
+ info = ""
257
+ for i, dataset in enumerate(datasets):
258
+ is_image_dataset = isinstance(dataset, ImageDataset)
259
+ info += dedent(
260
+ f"""\
261
+ [Dataset {i}]
262
+ is_image_dataset: {is_image_dataset}
263
+ resolution: {dataset.resolution}
264
+ batch_size: {dataset.batch_size}
265
+ num_repeats: {dataset.num_repeats}
266
+ caption_extension: "{dataset.caption_extension}"
267
+ enable_bucket: {dataset.enable_bucket}
268
+ bucket_no_upscale: {dataset.bucket_no_upscale}
269
+ cache_directory: "{dataset.cache_directory}"
270
+ debug_dataset: {dataset.debug_dataset}
271
+ """
272
+ )
273
+
274
+ if is_image_dataset:
275
+ info += indent(
276
+ dedent(
277
+ f"""\
278
+ image_directory: "{dataset.image_directory}"
279
+ image_jsonl_file: "{dataset.image_jsonl_file}"
280
+ \n"""
281
+ ),
282
+ " ",
283
+ )
284
+ else:
285
+ info += indent(
286
+ dedent(
287
+ f"""\
288
+ video_directory: "{dataset.video_directory}"
289
+ video_jsonl_file: "{dataset.video_jsonl_file}"
290
+ target_frames: {dataset.target_frames}
291
+ frame_extraction: {dataset.frame_extraction}
292
+ frame_stride: {dataset.frame_stride}
293
+ frame_sample: {dataset.frame_sample}
294
+ \n"""
295
+ ),
296
+ " ",
297
+ )
298
+ logger.info(f"{info}")
299
+
300
+ # make buckets first because it determines the length of dataset
301
+ # and set the same seed for all datasets
302
+ seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
303
+ for i, dataset in enumerate(datasets):
304
+ # logger.info(f"[Dataset {i}]")
305
+ dataset.set_seed(seed)
306
+ if training:
307
+ dataset.prepare_for_training()
308
+
309
+ return DatasetGroup(datasets)
310
+
311
+
312
+ def load_user_config(file: str) -> dict:
313
+ file: Path = Path(file)
314
+ if not file.is_file():
315
+ raise ValueError(f"file not found / ファイルが見つかりません: {file}")
316
+
317
+ if file.name.lower().endswith(".json"):
318
+ try:
319
+ with open(file, "r", encoding="utf-8") as f:
320
+ config = json.load(f)
321
+ except Exception:
322
+ logger.error(
323
+ f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
324
+ )
325
+ raise
326
+ elif file.name.lower().endswith(".toml"):
327
+ try:
328
+ config = toml.load(file)
329
+ except Exception:
330
+ logger.error(
331
+ f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
332
+ )
333
+ raise
334
+ else:
335
+ raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
336
+
337
+ return config
338
+
339
+
340
+ # for config test
341
+ if __name__ == "__main__":
342
+ parser = argparse.ArgumentParser()
343
+ parser.add_argument("dataset_config")
344
+ config_args, remain = parser.parse_known_args()
345
+
346
+ parser = argparse.ArgumentParser()
347
+ parser.add_argument("--debug_dataset", action="store_true")
348
+ argparse_namespace = parser.parse_args(remain)
349
+
350
+ logger.info("[argparse_namespace]")
351
+ logger.info(f"{vars(argparse_namespace)}")
352
+
353
+ user_config = load_user_config(config_args.dataset_config)
354
+
355
+ logger.info("")
356
+ logger.info("[user_config]")
357
+ logger.info(f"{user_config}")
358
+
359
+ sanitizer = ConfigSanitizer()
360
+ sanitized_user_config = sanitizer.sanitize_user_config(user_config)
361
+
362
+ logger.info("")
363
+ logger.info("[sanitized_user_config]")
364
+ logger.info(f"{sanitized_user_config}")
365
+
366
+ blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
367
+
368
+ logger.info("")
369
+ logger.info("[blueprint]")
370
+ logger.info(f"{blueprint}")
371
+
372
+ dataset_group = generate_dataset_group_by_blueprint(blueprint.dataset_group)
dataset/dataset_config.md ADDED
@@ -0,0 +1,387 @@
1
+ > 📝 Click on the language section to expand / 言語をクリックして展開
2
+
3
+ ## Dataset Configuration
4
+
5
+ <details>
6
+ <summary>English</summary>
7
+
8
+ Please create a TOML file for dataset configuration.
9
+
10
+ Image and video datasets are supported. The configuration file can include multiple datasets, either image or video datasets, with caption text files or metadata JSONL files.
11
+
12
+ The cache directory must be different for each dataset.
13
+ </details>
14
+
15
+ <details>
16
+ <summary>日本語</summary>
17
+
18
+ データセットの設定を行うためのTOMLファイルを作成してください。
19
+
20
+ 画像データセットと動画データセットがサポートされています。設定ファイルには、画像または動画データセットを複数含めることができます。キャプションテキストファイルまたはメタデータJSONLファイルを使用できます。
21
+
22
+ キャッシュディレクトリは、各データセットごとに異なるディレクトリである必要があります。
23
+ </details>
24
+
25
+ ### Sample for Image Dataset with Caption Text Files
26
+
27
+ ```toml
28
+ # resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale should be set in either general or datasets
29
+ # otherwise, the default values will be used for each item
30
+
31
+ # general configurations
32
+ [general]
33
+ resolution = [960, 544]
34
+ caption_extension = ".txt"
35
+ batch_size = 1
36
+ enable_bucket = true
37
+ bucket_no_upscale = false
38
+
39
+ [[datasets]]
40
+ image_directory = "/path/to/image_dir"
41
+ cache_directory = "/path/to/cache_directory"
42
+ num_repeats = 1 # optional, default is 1. Number of times to repeat the dataset. Useful for balancing multiple datasets of different sizes.
43
+
44
+ # other datasets can be added here. each dataset can have different configurations
45
+ ```
46
+
47
+ <details>
48
+ <summary>English</summary>
49
+
50
+ `cache_directory` is optional; the default is None, which uses the same directory as the image directory. However, we recommend setting an explicit cache directory to avoid accidentally sharing cache files between different datasets.
51
+
52
+ `num_repeats` is also available. It is optional and defaults to 1 (no repeat). It simply repeats each image (or video) that many times to expand the dataset. For example, with `num_repeats = 2` and 20 images in the dataset, each image appears twice (with the same caption), for a total of 40 images. This is useful for balancing multiple datasets of different sizes; a minimal sketch of the behavior is shown after the Japanese section below.
53
+
54
+ </details>
55
+
56
+ <details>
57
+ <summary>日本語</summary>
58
+
59
+ `cache_directory` はオプションです。デフォルトは画像ディレクトリと同じディレクトリに設定されます。ただし、異なるデータセット間でキャッシュファイルが共有されるのを防ぐために、明示的に別のキャッシュディレクトリを設定することをお勧めします。
60
+
61
+ `num_repeats` はオプションで、デフォルトは 1 です(繰り返しなし)。画像(や動画)を、その回数だけ単純に繰り返してデータセットを拡張します。たとえば`num_repeats = 2`としたとき、画像20枚のデータセットなら、各画像が2枚ずつ(同一のキャプションで)計40枚存在した場合と同じになります。異なるデータ数のデータセット間でバランスを取るために使用可能です。
62
+
63
+ resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale は general または datasets のどちらかに設定してください。省略時は各項目のデフォルト値が使用されます。
64
+
65
+ `[[datasets]]`以下を追加することで、他のデータセットを追加できます。各データセットには異なる設定を持てます。
66
+ </details>
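+
+ As a rough sketch of how `num_repeats` expands a dataset (illustration only, not the trainer's actual implementation; the `ImageEntry` class is made up for this example):
+
+ ```python
+ from dataclasses import dataclass
+
+ @dataclass
+ class ImageEntry:  # hypothetical container, for illustration only
+     path: str
+     caption: str
+
+ def expand_with_repeats(entries: list[ImageEntry], num_repeats: int) -> list[ImageEntry]:
+     # each entry simply appears num_repeats times, caption included
+     return [entry for entry in entries for _ in range(num_repeats)]
+
+ entries = [ImageEntry(f"/path/to/image{i}.jpg", f"caption {i}") for i in range(20)]
+ print(len(expand_with_repeats(entries, 2)))  # -> 40
+ ```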
67
+
68
+ ### Sample for Image Dataset with Metadata JSONL File
69
+
70
+ ```toml
71
+ # resolution, batch_size, num_repeats, enable_bucket, bucket_no_upscale should be set in either general or datasets
72
+ # caption_extension is not required for metadata jsonl file
73
+ # cache_directory is required for each dataset with metadata jsonl file
74
+
75
+ # general configurations
76
+ [general]
77
+ resolution = [960, 544]
78
+ batch_size = 1
79
+ enable_bucket = true
80
+ bucket_no_upscale = false
81
+
82
+ [[datasets]]
83
+ image_jsonl_file = "/path/to/metadata.jsonl"
84
+ cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
85
+ num_repeats = 1 # optional, default is 1. Same as above.
86
+
87
+ # other datasets can be added here. each dataset can have different configurations
88
+ ```
89
+
90
+ JSONL file format for metadata:
91
+
92
+ ```json
93
+ {"image_path": "/path/to/image1.jpg", "caption": "A caption for image1"}
94
+ {"image_path": "/path/to/image2.jpg", "caption": "A caption for image2"}
95
+ ```
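+
+ If your captions live in `.txt` files next to the images, a small script along these lines can produce such a JSONL file (a sketch under the assumption of one `.txt` caption file per image; all paths and extensions here are examples):
+
+ ```python
+ import glob
+ import json
+ import os
+
+ image_dir = "/path/to/image_dir"  # example path
+ with open("/path/to/metadata.jsonl", "w", encoding="utf-8") as out:
+     for ext in (".jpg", ".jpeg", ".png", ".webp"):
+         for image_path in sorted(glob.glob(os.path.join(image_dir, f"*{ext}"))):
+             caption_path = os.path.splitext(image_path)[0] + ".txt"
+             with open(caption_path, "r", encoding="utf-8") as f:
+                 caption = f.read().strip()
+             out.write(json.dumps({"image_path": image_path, "caption": caption}, ensure_ascii=False) + "\n")
+ ```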
96
+
97
+ <details>
98
+ <summary>日本語</summary>
99
+
100
+ resolution, batch_size, num_repeats, enable_bucket, bucket_no_upscale は general または datasets のどちらかに設定してください。省略時は各項目のデフォルト値が使用されます。
101
+
102
+ metadata jsonl ファイルを使用する場合、caption_extension は必要ありません。また、cache_directory は必須です。
103
+
104
+ キャプションによるデータセットと同様に、複数のデータセットを追加できます。各データセットには異なる設定を持てます。
105
+ </details>
106
+
107
+
108
+ ### Sample for Video Dataset with Caption Text Files
109
+
110
+ ```toml
111
+ # resolution, caption_extension, target_frames, frame_extraction, frame_stride, frame_sample,
112
+ # batch_size, num_repeats, enable_bucket, bucket_no_upscale should be set in either general or datasets
113
+ # num_repeats is also available for video dataset, example is not shown here
114
+
115
+ # general configurations
116
+ [general]
117
+ resolution = [960, 544]
118
+ caption_extension = ".txt"
119
+ batch_size = 1
120
+ enable_bucket = true
121
+ bucket_no_upscale = false
122
+
123
+ [[datasets]]
124
+ video_directory = "/path/to/video_dir"
125
+ cache_directory = "/path/to/cache_directory" # recommended to set cache directory
126
+ target_frames = [1, 25, 45]
127
+ frame_extraction = "head"
128
+
129
+ # other datasets can be added here. each dataset can have different configurations
130
+ ```
131
+
132
+ <details>
133
+ <summary>日本語</summary>
134
+
135
+ resolution, caption_extension, target_frames, frame_extraction, frame_stride, frame_sample, batch_size, num_repeats, enable_bucket, bucket_no_upscale は general または datasets のどちらかに設定してください。
136
+
137
+ 他の注意事項は画像データセットと同様です。
138
+ </details>
139
+
140
+ ### Sample for Video Dataset with Metadata JSONL File
141
+
142
+ ```toml
143
+ # resolution, target_frames, frame_extraction, frame_stride, frame_sample,
144
+ # batch_size, num_repeats, enable_bucket, bucket_no_upscale should be set in either general or datasets
145
+ # caption_extension is not required for metadata jsonl file
146
+ # cache_directory is required for each dataset with metadata jsonl file
147
+
148
+ # general configurations
149
+ [general]
150
+ resolution = [960, 544]
151
+ batch_size = 1
152
+ enable_bucket = true
153
+ bucket_no_upscale = false
154
+
155
+ [[datasets]]
156
+ video_jsonl_file = "/path/to/metadata.jsonl"
157
+ target_frames = [1, 25, 45]
158
+ frame_extraction = "head"
159
+ cache_directory = "/path/to/cache_directory_head"
160
+
161
+ # same metadata jsonl file can be used for multiple datasets
162
+ [[datasets]]
163
+ video_jsonl_file = "/path/to/metadata.jsonl"
164
+ target_frames = [1]
165
+ frame_stride = 10
166
+ cache_directory = "/path/to/cache_directory_stride"
167
+
168
+ # other datasets can be added here. each dataset can have different configurations
169
+ ```
170
+
171
+ JSONL file format for metadata:
172
+
173
+ ```json
174
+ {"video_path": "/path/to/video1.mp4", "caption": "A caption for video1"}
175
+ {"video_path": "/path/to/video2.mp4", "caption": "A caption for video2"}
176
+ ```
177
+
178
+ <details>
179
+ <summary>日本語</summary>
180
+
181
+ resolution, target_frames, frame_extraction, frame_stride, frame_sample, batch_size, num_repeats, enable_bucket, bucket_no_upscale は general または datasets のどちらかに設定してください。
182
+
183
+ metadata jsonl ファイルを使用する場合、caption_extension は必要ありません。また、cache_directory は必須です。
184
+
185
+ 他の注意事項は今までのデータセットと同様です。
186
+ </details>
187
+
188
+ ### frame_extraction Options
189
+
190
+ <details>
191
+ <summary>English</summary>
192
+
193
+ - `head`: Extract the first N frames from the video.
194
+ - `chunk`: Extract frames by splitting the video into chunks of N frames.
195
+ - `slide`: Extract frames from the video with a stride of `frame_stride`.
196
+ - `uniform`: Extract `frame_sample` samples uniformly from the video.
197
+
198
+ For example, consider a video with 40 frames. The following diagrams illustrate each extraction:
199
+ </details>
200
+
201
+ <details>
202
+ <summary>日本語</summary>
203
+
204
+ - `head`: 動画から最初のNフレームを抽出します。
205
+ - `chunk`: 動画をNフレームずつに分割してフレームを抽出します。
206
+ - `slide`: `frame_stride`に指定したフレームごとに動画からNフレームを抽出します。
207
+ - `uniform`: 動画から一定間隔で、`frame_sample`個のNフレームを抽出します。
208
+
209
+ 例えば、40フレームの動画を例とした抽出について、以下の図で説明します。
210
+ </details>
211
+
212
+ ```
213
+ Original Video, 40 frames: x = frame, o = no frame
214
+ oooooooooooooooooooooooooooooooooooooooo
215
+
216
+ head, target_frames = [1, 13, 25] -> extract head frames:
217
+ xooooooooooooooooooooooooooooooooooooooo
218
+ xxxxxxxxxxxxxooooooooooooooooooooooooooo
219
+ xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
220
+
221
+ chunk, target_frames = [13, 25] -> extract frames by splitting the video into chunks of 13 and 25 frames:
222
+ xxxxxxxxxxxxxooooooooooooooooooooooooooo
223
+ oooooooooooooxxxxxxxxxxxxxoooooooooooooo
224
+ ooooooooooooooooooooooooooxxxxxxxxxxxxxo
225
+ xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
226
+
227
+ NOTE: Please do not include 1 in target_frames if you are using frame_extraction "chunk". It would cause every single frame to be extracted.
228
+ 注: frame_extraction "chunk" を使用する場合、target_frames に 1 を含めないでください。全てのフレームが抽出されてしまいます。
229
+
230
+ slide, target_frames = [1, 13, 25], frame_stride = 10 -> extract N frames with a stride of 10:
231
+ xooooooooooooooooooooooooooooooooooooooo
232
+ ooooooooooxooooooooooooooooooooooooooooo
233
+ ooooooooooooooooooooxooooooooooooooooooo
234
+ ooooooooooooooooooooooooooooooxooooooooo
235
+ xxxxxxxxxxxxxooooooooooooooooooooooooooo
236
+ ooooooooooxxxxxxxxxxxxxooooooooooooooooo
237
+ ooooooooooooooooooooxxxxxxxxxxxxxooooooo
238
+ xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
239
+ ooooooooooxxxxxxxxxxxxxxxxxxxxxxxxxooooo
240
+
241
+ uniform, target_frames = [1, 13, 25], frame_sample = 4 -> extract `frame_sample` samples uniformly, N frames each:
242
+ xooooooooooooooooooooooooooooooooooooooo
243
+ oooooooooooooxoooooooooooooooooooooooooo
244
+ oooooooooooooooooooooooooxoooooooooooooo
245
+ ooooooooooooooooooooooooooooooooooooooox
246
+ xxxxxxxxxxxxxooooooooooooooooooooooooooo
247
+ oooooooooxxxxxxxxxxxxxoooooooooooooooooo
248
+ ooooooooooooooooooxxxxxxxxxxxxxooooooooo
249
+ oooooooooooooooooooooooooooxxxxxxxxxxxxx
250
+ xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
251
+ oooooxxxxxxxxxxxxxxxxxxxxxxxxxoooooooooo
252
+ ooooooooooxxxxxxxxxxxxxxxxxxxxxxxxxooooo
253
+ oooooooooooooooxxxxxxxxxxxxxxxxxxxxxxxxx
254
+ ```
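+
+ The sketch below reproduces the frame windows illustrated above as (start, end) index pairs. It is an approximation of the behavior described in this document, not the repository's exact implementation:
+
+ ```python
+ def extract_windows(video_len: int, target_frames: list[int], frame_extraction: str = "head",
+                     frame_stride: int = 1, frame_sample: int = 1) -> list[tuple[int, int]]:
+     """Return (start, end) frame windows, end exclusive."""
+     windows = []
+     for n in target_frames:
+         if frame_extraction == "head":
+             windows.append((0, n))
+         elif frame_extraction == "chunk":
+             # note: n = 1 would yield a window for every single frame
+             windows.extend((s, s + n) for s in range(0, video_len - n + 1, n))
+         elif frame_extraction == "slide":
+             windows.extend((s, s + n) for s in range(0, video_len - n + 1, frame_stride))
+         elif frame_extraction == "uniform":
+             for i in range(frame_sample):
+                 start = round(i * (video_len - n) / max(frame_sample - 1, 1))
+                 windows.append((start, start + n))
+     return windows
+
+ print(extract_windows(40, [1, 13, 25], "slide", frame_stride=10))
+ ```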
255
+
256
+ ## Specifications
257
+
258
+ ```toml
259
+ # general configurations
260
+ [general]
261
+ resolution = [960, 544] # optional, [W, H], default is None. This is the default resolution for all datasets
262
+ caption_extension = ".txt" # optional, default is None. This is the default caption extension for all datasets
263
+ batch_size = 1 # optional, default is 1. This is the default batch size for all datasets
264
+ num_repeats = 1 # optional, default is 1. Number of times to repeat the dataset. Useful for balancing multiple datasets of different sizes.
265
+ enable_bucket = true # optional, default is false. Enable bucketing for datasets
266
+ bucket_no_upscale = false # optional, default is false. Disable upscaling for bucketing. Ignored if enable_bucket is false
267
+
268
+ ### Image Dataset
269
+
270
+ # sample image dataset with caption text files
271
+ [[datasets]]
272
+ image_directory = "/path/to/image_dir"
273
+ caption_extension = ".txt" # required for caption text files, if general caption extension is not set
274
+ resolution = [960, 544] # required if general resolution is not set
275
+ batch_size = 4 # optional, overwrite the default batch size
276
+ num_repeats = 1 # optional, overwrite the default num_repeats
277
+ enable_bucket = false # optional, overwrite the default bucketing setting
278
+ bucket_no_upscale = true # optional, overwrite the default bucketing setting
279
+ cache_directory = "/path/to/cache_directory" # optional, default is None to use the same directory as the image directory. NOTE: caching is always enabled
280
+
281
+ # sample image dataset with metadata **jsonl** file
282
+ [[datasets]]
283
+ image_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of image files and captions
284
+ resolution = [960, 544] # required if general resolution is not set
285
+ cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
286
+ # caption_extension is not required for metadata jsonl file
287
+ # batch_size, num_repeats, enable_bucket, bucket_no_upscale are also available for metadata jsonl file
288
+
289
+ ### Video Dataset
290
+
291
+ # sample video dataset with caption text files
292
+ [[datasets]]
293
+ video_directory = "/path/to/video_dir"
294
+ caption_extension = ".txt" # required for caption text files, if general caption extension is not set
295
+ resolution = [960, 544] # required if general resolution is not set
296
+
297
+ target_frames = [1, 25, 79] # required for video dataset. list of video lengths to extract frames. each element must be N*4+1 (N=0,1,2,...)
298
+
299
+ # NOTE: Please do not include 1 in target_frames if you are using frame_extraction "chunk". It would cause every single frame to be extracted.
300
+
301
+ frame_extraction = "head" # optional, one of "head", "chunk", "slide", "uniform". Default is "head"
302
+ frame_stride = 1 # optional, default is 1, available for "slide" frame extraction
303
+ frame_sample = 4 # optional, default is 1 (same as "head"), available for "uniform" frame extraction
304
+ # batch_size, num_repeats, enable_bucket, bucket_no_upscale, cache_directory are also available for video dataset
305
+
306
+ # sample video dataset with metadata jsonl file
307
+ [[datasets]]
308
+ video_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of video files and captions
309
+
310
+ target_frames = [1, 79]
311
+
312
+ cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
313
+ # frame_extraction, frame_stride, frame_sample are also available for metadata jsonl file
314
+ ```
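+
+ To sanity-check a configuration like the one above before caching or training, you can load it and see how per-dataset values fall back to `[general]`. This is a minimal sketch using the `toml` package; it mirrors, but is not identical to, the merging done in `dataset/config_utils.py`:
+
+ ```python
+ import toml
+
+ config = toml.load("/path/to/dataset_config.toml")  # example path
+ general = config.get("general", {})
+
+ for i, ds in enumerate(config.get("datasets", [])):
+     # per-dataset keys win; otherwise fall back to [general]
+     resolution = ds.get("resolution", general.get("resolution"))
+     batch_size = ds.get("batch_size", general.get("batch_size", 1))
+     is_video = "target_frames" in ds
+     print(f"[dataset {i}] video={is_video} resolution={resolution} batch_size={batch_size}")
+ ```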
315
+
316
+ <!--
317
+ # sample image dataset with lance
318
+ [[datasets]]
319
+ image_lance_dataset = "/path/to/lance_dataset"
320
+ resolution = [960, 544] # required if general resolution is not set
321
+ # batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for lance dataset
322
+ -->
323
+
324
+ Metadata in .json format will be supported in the near future.
325
+
326
+
327
+
328
+ <!--
329
+
330
+ ```toml
331
+ # general configurations
332
+ [general]
333
+ resolution = [960, 544] # optional, [W, H], default is None. This is the default resolution for all datasets
334
+ caption_extension = ".txt" # optional, default is None. This is the default caption extension for all datasets
335
+ batch_size = 1 # optional, default is 1. This is the default batch size for all datasets
336
+ enable_bucket = true # optional, default is false. Enable bucketing for datasets
337
+ bucket_no_upscale = false # optional, default is false. Disable upscaling for bucketing. Ignored if enable_bucket is false
338
+
339
+ # sample image dataset with caption text files
340
+ [[datasets]]
341
+ image_directory = "/path/to/image_dir"
342
+ caption_extension = ".txt" # required for caption text files, if general caption extension is not set
343
+ resolution = [960, 544] # required if general resolution is not set
344
+ batch_size = 4 # optional, overwrite the default batch size
345
+ enable_bucket = false # optional, overwrite the default bucketing setting
346
+ bucket_no_upscale = true # optional, overwrite the default bucketing setting
347
+ cache_directory = "/path/to/cache_directory" # optional, default is None to use the same directory as the image directory. NOTE: caching is always enabled
348
+
349
+ # sample image dataset with metadata **jsonl** file
350
+ [[datasets]]
351
+ image_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of image files and captions
352
+ resolution = [960, 544] # required if general resolution is not set
353
+ cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
354
+ # caption_extension is not required for metadata jsonl file
355
+ # batch_size, enable_bucket, bucket_no_upscale are also available for metadata jsonl file
356
+
357
+ # sample video dataset with caption text files
358
+ [[datasets]]
359
+ video_directory = "/path/to/video_dir"
360
+ caption_extension = ".txt" # required for caption text files, if general caption extension is not set
361
+ resolution = [960, 544] # required if general resolution is not set
362
+ target_frames = [1, 25, 79] # required for video dataset. list of video lengths to extract frames. each element must be N*4+1 (N=0,1,2,...)
363
+ frame_extraction = "head" # optional, "head" or "chunk", "slide", "uniform". Default is "head"
364
+ frame_stride = 1 # optional, default is 1, available for "slide" frame extraction
365
+ frame_sample = 4 # optional, default is 1 (same as "head"), available for "uniform" frame extraction
366
+ # batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for video dataset
367
+
368
+ # sample video dataset with metadata jsonl file
369
+ [[datasets]]
370
+ video_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of video files and captions
371
+ target_frames = [1, 79]
372
+ cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
373
+ # frame_extraction, frame_stride, frame_sample are also available for metadata jsonl file
374
+ ```
375
+
376
+ # sample image dataset with lance
377
+ [[datasets]]
378
+ image_lance_dataset = "/path/to/lance_dataset"
379
+ resolution = [960, 544] # required if general resolution is not set
380
+ # batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for lance dataset
381
+
382
+ The metadata with .json file will be supported in the near future.
383
+
384
+
385
+
386
+
387
+ -->
dataset/image_video_dataset.py ADDED
@@ -0,0 +1,1400 @@
1
+ from concurrent.futures import ThreadPoolExecutor
2
+ import glob
3
+ import json
4
+ import math
5
+ import os
6
+ import random
7
+ import time
8
+ from typing import Optional, Sequence, Tuple, Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ from safetensors.torch import save_file, load_file
13
+ from safetensors import safe_open
14
+ from PIL import Image
15
+ import cv2
16
+ import av
17
+
18
+ from utils import safetensors_utils
19
+ from utils.model_utils import dtype_to_str
20
+
21
+ import logging
22
+
23
+ logger = logging.getLogger(__name__)
24
+ logging.basicConfig(level=logging.INFO)
25
+
26
+
27
+ IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]
28
+
29
+ try:
30
+ import pillow_avif
31
+
32
+ IMAGE_EXTENSIONS.extend([".avif", ".AVIF"])
33
+ except:
34
+ pass
35
+
36
+ # JPEG-XL on Linux
37
+ try:
38
+ from jxlpy import JXLImagePlugin
39
+
40
+ IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
41
+ except:
42
+ pass
43
+
44
+ # JPEG-XL on Windows
45
+ try:
46
+ import pillow_jxl
47
+
48
+ IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
49
+ except:
50
+ pass
51
+
52
+ VIDEO_EXTENSIONS = [
53
+ ".mp4",
54
+ ".webm",
55
+ ".avi",
56
+ ".mkv",
57
+ ".mov",
58
+ ".flv",
59
+ ".wmv",
60
+ ".m4v",
61
+ ".mpg",
62
+ ".mpeg",
63
+ ".MP4",
64
+ ".WEBM",
65
+ ".AVI",
66
+ ".MKV",
67
+ ".MOV",
68
+ ".FLV",
69
+ ".WMV",
70
+ ".M4V",
71
+ ".MPG",
72
+ ".MPEG",
73
+ ] # some of them are not tested
74
+
75
+ ARCHITECTURE_HUNYUAN_VIDEO = "hv"
76
+ ARCHITECTURE_HUNYUAN_VIDEO_FULL = "hunyuan_video"
77
+ ARCHITECTURE_WAN = "wan"
78
+ ARCHITECTURE_WAN_FULL = "wan"
79
+
80
+
81
+ def glob_images(directory, base="*"):
82
+ img_paths = []
83
+ for ext in IMAGE_EXTENSIONS:
84
+ if base == "*":
85
+ img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
86
+ else:
87
+ img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
88
+ img_paths = list(set(img_paths)) # remove duplicates
89
+ img_paths.sort()
90
+ return img_paths
91
+
92
+
93
+ def glob_videos(directory, base="*"):
94
+ video_paths = []
95
+ for ext in VIDEO_EXTENSIONS:
96
+ if base == "*":
97
+ video_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
98
+ else:
99
+ video_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
100
+ video_paths = list(set(video_paths)) # remove duplicates
101
+ video_paths.sort()
102
+ return video_paths
103
+
104
+
105
+ def divisible_by(num: int, divisor: int) -> int:
106
+ return num - num % divisor
107
+
108
+
109
+ def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: tuple[int, int]) -> np.ndarray:
110
+ """
111
+ Resize the image to the bucket resolution.
112
+ """
113
+ is_pil_image = isinstance(image, Image.Image)
114
+ if is_pil_image:
115
+ image_width, image_height = image.size
116
+ else:
117
+ image_height, image_width = image.shape[:2]
118
+
119
+ if bucket_reso == (image_width, image_height):
120
+ return np.array(image) if is_pil_image else image
121
+
122
+ bucket_width, bucket_height = bucket_reso
123
+ if bucket_width == image_width or bucket_height == image_height:
124
+ image = np.array(image) if is_pil_image else image
125
+ else:
126
+ # resize the image to the bucket resolution to match the short side
127
+ scale_width = bucket_width / image_width
128
+ scale_height = bucket_height / image_height
129
+ scale = max(scale_width, scale_height)
130
+ image_width = int(image_width * scale + 0.5)
131
+ image_height = int(image_height * scale + 0.5)
132
+
133
+ if scale > 1:
134
+ image = Image.fromarray(image) if not is_pil_image else image
135
+ image = image.resize((image_width, image_height), Image.LANCZOS)
136
+ image = np.array(image)
137
+ else:
138
+ image = np.array(image) if is_pil_image else image
139
+ image = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_AREA)
140
+
141
+ # crop the image to the bucket resolution
142
+ crop_left = (image_width - bucket_width) // 2
143
+ crop_top = (image_height - bucket_height) // 2
144
+ image = image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width]
145
+ return image
146
+
147
+
148
+ class ItemInfo:
149
+ def __init__(
150
+ self,
151
+ item_key: str,
152
+ caption: str,
153
+ original_size: tuple[int, int],
154
+ bucket_size: Optional[Union[tuple[int, int], tuple[int, int, int]]] = None,
155
+ frame_count: Optional[int] = None,
156
+ content: Optional[np.ndarray] = None,
157
+ latent_cache_path: Optional[str] = None,
158
+ ) -> None:
159
+ self.item_key = item_key
160
+ self.caption = caption
161
+ self.original_size = original_size
162
+ self.bucket_size = bucket_size
163
+ self.frame_count = frame_count
164
+ self.content = content
165
+ self.latent_cache_path = latent_cache_path
166
+ self.text_encoder_output_cache_path: Optional[str] = None
167
+
168
+ def __str__(self) -> str:
169
+ return (
170
+ f"ItemInfo(item_key={self.item_key}, caption={self.caption}, "
171
+ + f"original_size={self.original_size}, bucket_size={self.bucket_size}, "
172
+ + f"frame_count={self.frame_count}, latent_cache_path={self.latent_cache_path})"
173
+ )
174
+
175
+
176
+ # We use a simple if-else approach to support multiple architectures.
177
+ # Maybe we can use a plugin system in the future.
178
+
179
+ # the keys of the dict are `<content_type>_FxHxW_<dtype>` for latents
180
+ # and `<content_type>_<dtype|mask>` for other tensors
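+ # e.g. a latent with F=25, H=68, W=120 is stored under a key like "latents_25x68x120_bf16"
+ # (the dtype suffix comes from dtype_to_str), and an LLM attention mask under "llm_mask"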
181
+
182
+
183
+ def save_latent_cache(item_info: ItemInfo, latent: torch.Tensor):
184
+ """HunyuanVideo architecture only"""
185
+ assert latent.dim() == 4, "latent should be 4D tensor (channel, frame, height, width)"
186
+
187
+ _, F, H, W = latent.shape
188
+ dtype_str = dtype_to_str(latent.dtype)
189
+ sd = {f"latents_{F}x{H}x{W}_{dtype_str}": latent.detach().cpu()}
190
+
191
+ save_latent_cache_common(item_info, sd, ARCHITECTURE_HUNYUAN_VIDEO_FULL)
192
+
193
+
194
+ def save_latent_cache_wan(
195
+ item_info: ItemInfo, latent: torch.Tensor, clip_embed: Optional[torch.Tensor], image_latent: Optional[torch.Tensor]
196
+ ):
197
+ """Wan architecture only"""
198
+ assert latent.dim() == 4, "latent should be 4D tensor (channel, frame, height, width)"
199
+
200
+ _, F, H, W = latent.shape
201
+ dtype_str = dtype_to_str(latent.dtype)
202
+ sd = {f"latents_{F}x{H}x{W}_{dtype_str}": latent.detach().cpu()}
203
+
204
+ if clip_embed is not None:
205
+ sd[f"clip_{dtype_str}"] = clip_embed.detach().cpu()
206
+
207
+ if image_latent is not None:
208
+ sd[f"latents_image_{F}x{H}x{W}_{dtype_str}"] = image_latent.detach().cpu()
209
+
210
+ save_latent_cache_common(item_info, sd, ARCHITECTURE_WAN_FULL)
211
+
212
+
213
+ def save_latent_cache_common(item_info: ItemInfo, sd: dict[str, torch.Tensor], arch_fullname: str):
214
+ metadata = {
215
+ "architecture": arch_fullname,
216
+ "width": f"{item_info.original_size[0]}",
217
+ "height": f"{item_info.original_size[1]}",
218
+ "format_version": "1.0.1",
219
+ }
220
+ if item_info.frame_count is not None:
221
+ metadata["frame_count"] = f"{item_info.frame_count}"
222
+
223
+ for key, value in sd.items():
224
+ # NaN check and show warning, replace NaN with 0
225
+ if torch.isnan(value).any():
226
+ logger.warning(f"{key} tensor has NaN: {item_info.item_key}, replace NaN with 0")
227
+ value[torch.isnan(value)] = 0
228
+
229
+ latent_dir = os.path.dirname(item_info.latent_cache_path)
230
+ os.makedirs(latent_dir, exist_ok=True)
231
+
232
+ save_file(sd, item_info.latent_cache_path, metadata=metadata)
233
+
234
+
235
+ def save_text_encoder_output_cache(item_info: ItemInfo, embed: torch.Tensor, mask: Optional[torch.Tensor], is_llm: bool):
236
+ """HunyuanVideo architecture only"""
237
+ assert (
238
+ embed.dim() == 1 or embed.dim() == 2
239
+ ), f"embed should be a 2D tensor (feature, hidden_size) or a 1D tensor (hidden_size,), got {embed.shape}"
240
+ assert mask is None or mask.dim() == 1, f"mask should be 1D tensor (feature), got {mask.shape}"
241
+
242
+ sd = {}
243
+ dtype_str = dtype_to_str(embed.dtype)
244
+ text_encoder_type = "llm" if is_llm else "clipL"
245
+ sd[f"{text_encoder_type}_{dtype_str}"] = embed.detach().cpu()
246
+ if mask is not None:
247
+ sd[f"{text_encoder_type}_mask"] = mask.detach().cpu()
248
+
249
+ save_text_encoder_output_cache_common(item_info, sd, ARCHITECTURE_HUNYUAN_VIDEO_FULL)
250
+
251
+
252
+ def save_text_encoder_output_cache_wan(item_info: ItemInfo, embed: torch.Tensor):
253
+ """Wan architecture only. Wan2.1 only has a single text encoder"""
254
+
255
+ sd = {}
256
+ dtype_str = dtype_to_str(embed.dtype)
257
+ text_encoder_type = "t5"
258
+ sd[f"varlen_{text_encoder_type}_{dtype_str}"] = embed.detach().cpu()
259
+
260
+ save_text_encoder_output_cache_common(item_info, sd, ARCHITECTURE_WAN_FULL)
261
+
262
+
263
+ def save_text_encoder_output_cache_common(item_info: ItemInfo, sd: dict[str, torch.Tensor], arch_fullname: str):
264
+ for key, value in sd.items():
265
+ # NaN check and show warning, replace NaN with 0
266
+ if torch.isnan(value).any():
267
+ logger.warning(f"{key} tensor has NaN: {item_info.item_key}, replace NaN with 0")
268
+ value[torch.isnan(value)] = 0
269
+
270
+ metadata = {
271
+ "architecture": arch_fullname,
272
+ "caption1": item_info.caption,
273
+ "format_version": "1.0.1",
274
+ }
275
+
276
+ if os.path.exists(item_info.text_encoder_output_cache_path):
277
+ # load existing cache and update metadata
278
+ with safetensors_utils.MemoryEfficientSafeOpen(item_info.text_encoder_output_cache_path) as f:
279
+ existing_metadata = f.metadata()
280
+ for key in f.keys():
281
+ if key not in sd: # avoid overwriting by existing cache, we keep the new one
282
+ sd[key] = f.get_tensor(key)
283
+
284
+ assert existing_metadata["architecture"] == metadata["architecture"], "architecture mismatch"
285
+ if existing_metadata["caption1"] != metadata["caption1"]:
286
+ logger.warning(f"caption mismatch: existing={existing_metadata['caption1']}, new={metadata['caption1']}, overwrite")
287
+ # TODO verify format_version
288
+
289
+ existing_metadata.pop("caption1", None)
290
+ existing_metadata.pop("format_version", None)
291
+ metadata.update(existing_metadata) # copy existing metadata except caption and format_version
292
+ else:
293
+ text_encoder_output_dir = os.path.dirname(item_info.text_encoder_output_cache_path)
294
+ os.makedirs(text_encoder_output_dir, exist_ok=True)
295
+
296
+ safetensors_utils.mem_eff_save_file(sd, item_info.text_encoder_output_cache_path, metadata=metadata)
297
+
298
+
299
+ class BucketSelector:
300
+ RESOLUTION_STEPS_HUNYUAN = 16
301
+ RESOLUTION_STEPS_WAN = 16
302
+
303
+ def __init__(
304
+ self, resolution: Tuple[int, int], enable_bucket: bool = True, no_upscale: bool = False, architecture: str = "no_default"
305
+ ):
306
+ self.resolution = resolution
307
+ self.bucket_area = resolution[0] * resolution[1]
308
+ self.architecture = architecture
309
+
310
+ if self.architecture == ARCHITECTURE_HUNYUAN_VIDEO:
311
+ self.reso_steps = BucketSelector.RESOLUTION_STEPS_HUNYUAN
312
+ elif self.architecture == ARCHITECTURE_WAN:
313
+ self.reso_steps = BucketSelector.RESOLUTION_STEPS_WAN
314
+ else:
315
+ raise ValueError(f"Invalid architecture: {self.architecture}")
316
+
317
+ if not enable_bucket:
318
+ # only define one bucket
319
+ self.bucket_resolutions = [resolution]
320
+ self.no_upscale = False
321
+ else:
322
+ # prepare bucket resolution
323
+ self.no_upscale = no_upscale
324
+ sqrt_size = int(math.sqrt(self.bucket_area))
325
+ min_size = divisible_by(sqrt_size // 2, self.reso_steps)
326
+ self.bucket_resolutions = []
327
+ for w in range(min_size, sqrt_size + self.reso_steps, self.reso_steps):
328
+ h = divisible_by(self.bucket_area // w, self.reso_steps)
329
+ self.bucket_resolutions.append((w, h))
330
+ self.bucket_resolutions.append((h, w))
331
+
332
+ self.bucket_resolutions = list(set(self.bucket_resolutions))
333
+ self.bucket_resolutions.sort()
334
+
335
+ # calculate aspect ratio to find the nearest resolution
336
+ self.aspect_ratios = np.array([w / h for w, h in self.bucket_resolutions])
337
+
338
+ def get_bucket_resolution(self, image_size: tuple[int, int]) -> tuple[int, int]:
339
+ """
340
+ return the bucket resolution for the given image size, (width, height)
341
+ """
342
+ area = image_size[0] * image_size[1]
343
+ if self.no_upscale and area <= self.bucket_area:
344
+ w, h = image_size
345
+ w = divisible_by(w, self.reso_steps)
346
+ h = divisible_by(h, self.reso_steps)
347
+ return w, h
348
+
349
+ aspect_ratio = image_size[0] / image_size[1]
350
+ ar_errors = self.aspect_ratios - aspect_ratio
351
+ bucket_id = np.abs(ar_errors).argmin()
352
+ return self.bucket_resolutions[bucket_id]
353
+
354
+
355
+ def load_video(
356
+ video_path: str,
357
+ start_frame: Optional[int] = None,
358
+ end_frame: Optional[int] = None,
359
+ bucket_selector: Optional[BucketSelector] = None,
360
+ bucket_reso: Optional[tuple[int, int]] = None,
361
+ ) -> list[np.ndarray]:
362
+ """
363
+ bucket_reso: if given, resize the video to the bucket resolution, (width, height)
364
+ """
365
+ container = av.open(video_path)
366
+ video = []
367
+ for i, frame in enumerate(container.decode(video=0)):
368
+ if start_frame is not None and i < start_frame:
369
+ continue
370
+ if end_frame is not None and i >= end_frame:
371
+ break
372
+ frame = frame.to_image()
373
+
374
+ if bucket_selector is not None and bucket_reso is None:
375
+ bucket_reso = bucket_selector.get_bucket_resolution(frame.size)
376
+
377
+ if bucket_reso is not None:
378
+ frame = resize_image_to_bucket(frame, bucket_reso)
379
+ else:
380
+ frame = np.array(frame)
381
+
382
+ video.append(frame)
383
+ container.close()
384
+ return video
385
+
386
+
387
+ class BucketBatchManager:
388
+
389
+ def __init__(self, bucketed_item_info: dict[tuple[int, int], list[ItemInfo]], batch_size: int):
390
+ self.batch_size = batch_size
391
+ self.buckets = bucketed_item_info
392
+ self.bucket_resos = list(self.buckets.keys())
393
+ self.bucket_resos.sort()
394
+
395
+ self.bucket_batch_indices = []
396
+ for bucket_reso in self.bucket_resos:
397
+ bucket = self.buckets[bucket_reso]
398
+ num_batches = math.ceil(len(bucket) / self.batch_size)
399
+ for i in range(num_batches):
400
+ self.bucket_batch_indices.append((bucket_reso, i))
401
+
402
+ self.shuffle()
403
+
404
+ def show_bucket_info(self):
405
+ for bucket_reso in self.bucket_resos:
406
+ bucket = self.buckets[bucket_reso]
407
+ logger.info(f"bucket: {bucket_reso}, count: {len(bucket)}")
408
+
409
+ logger.info(f"total batches: {len(self)}")
410
+
411
+ def shuffle(self):
412
+ for bucket in self.buckets.values():
413
+ random.shuffle(bucket)
414
+ random.shuffle(self.bucket_batch_indices)
415
+
416
+ def __len__(self):
417
+ return len(self.bucket_batch_indices)
418
+
419
+ def __getitem__(self, idx):
420
+ bucket_reso, batch_idx = self.bucket_batch_indices[idx]
421
+ bucket = self.buckets[bucket_reso]
422
+ start = batch_idx * self.batch_size
423
+ end = min(start + self.batch_size, len(bucket))
424
+
425
+ batch_tensor_data = {}
426
+ varlen_keys = set()
427
+ for item_info in bucket[start:end]:
428
+ sd_latent = load_file(item_info.latent_cache_path)
429
+ sd_te = load_file(item_info.text_encoder_output_cache_path)
430
+ sd = {**sd_latent, **sd_te}
431
+
432
+ # TODO refactor this
433
+ for key in sd.keys():
434
+ is_varlen_key = key.startswith("varlen_") # varlen keys are not stacked
435
+ content_key = key
436
+
437
+ if is_varlen_key:
438
+ content_key = content_key.replace("varlen_", "")
439
+
440
+ if content_key.endswith("_mask"):
441
+ pass
442
+ else:
443
+ content_key = content_key.rsplit("_", 1)[0] # remove dtype
444
+ if content_key.startswith("latents_"):
445
+ content_key = content_key.rsplit("_", 1)[0] # remove FxHxW
446
+
447
+ if content_key not in batch_tensor_data:
448
+ batch_tensor_data[content_key] = []
449
+ batch_tensor_data[content_key].append(sd[key])
450
+
451
+ if is_varlen_key:
452
+ varlen_keys.add(content_key)
453
+
454
+ for key in batch_tensor_data.keys():
455
+ if key not in varlen_keys:
456
+ batch_tensor_data[key] = torch.stack(batch_tensor_data[key])
457
+
458
+ return batch_tensor_data
459
+
460
+
461
+ class ContentDatasource:
462
+ def __init__(self):
463
+ self.caption_only = False
464
+
465
+ def set_caption_only(self, caption_only: bool):
466
+ self.caption_only = caption_only
467
+
468
+ def is_indexable(self):
469
+ return False
470
+
471
+ def get_caption(self, idx: int) -> tuple[str, str]:
472
+ """
473
+ Returns caption. May not be called if is_indexable() returns False.
474
+ """
475
+ raise NotImplementedError
476
+
477
+ def __len__(self):
478
+ raise NotImplementedError
479
+
480
+ def __iter__(self):
481
+ raise NotImplementedError
482
+
483
+ def __next__(self):
484
+ raise NotImplementedError
485
+
486
+
487
+ class ImageDatasource(ContentDatasource):
488
+ def __init__(self):
489
+ super().__init__()
490
+
491
+ def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
492
+ """
493
+ Returns image data as a tuple of image path, image, and caption for the given index.
494
+ Key must be unique and valid as a file name.
495
+ May not be called if is_indexable() returns False.
496
+ """
497
+ raise NotImplementedError
498
+
499
+
500
+ class ImageDirectoryDatasource(ImageDatasource):
501
+ def __init__(self, image_directory: str, caption_extension: Optional[str] = None):
502
+ super().__init__()
503
+ self.image_directory = image_directory
504
+ self.caption_extension = caption_extension
505
+ self.current_idx = 0
506
+
507
+ # glob images
508
+ logger.info(f"glob images in {self.image_directory}")
509
+ self.image_paths = glob_images(self.image_directory)
510
+ logger.info(f"found {len(self.image_paths)} images")
511
+
512
+ def is_indexable(self):
513
+ return True
514
+
515
+ def __len__(self):
516
+ return len(self.image_paths)
517
+
518
+ def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
519
+ image_path = self.image_paths[idx]
520
+ image = Image.open(image_path).convert("RGB")
521
+
522
+ _, caption = self.get_caption(idx)
523
+
524
+ return image_path, image, caption
525
+
526
+ def get_caption(self, idx: int) -> tuple[str, str]:
527
+ image_path = self.image_paths[idx]
528
+ caption_path = os.path.splitext(image_path)[0] + self.caption_extension if self.caption_extension else ""
529
+ with open(caption_path, "r", encoding="utf-8") as f:
530
+ caption = f.read().strip()
531
+ return image_path, caption
532
+
533
+ def __iter__(self):
534
+ self.current_idx = 0
535
+ return self
536
+
537
+ def __next__(self) -> callable:
538
+ """
539
+ Returns a fetcher function that returns image data.
540
+ """
541
+ if self.current_idx >= len(self.image_paths):
542
+ raise StopIteration
543
+
544
+ if self.caption_only:
545
+
546
+ def create_caption_fetcher(index):
547
+ return lambda: self.get_caption(index)
548
+
549
+ fetcher = create_caption_fetcher(self.current_idx)
550
+ else:
551
+
552
+ def create_image_fetcher(index):
553
+ return lambda: self.get_image_data(index)
554
+
555
+ fetcher = create_image_fetcher(self.current_idx)
556
+
557
+ self.current_idx += 1
558
+ return fetcher
559
+
560
+
561
+ class ImageJsonlDatasource(ImageDatasource):
562
+ def __init__(self, image_jsonl_file: str):
563
+ super().__init__()
564
+ self.image_jsonl_file = image_jsonl_file
565
+ self.current_idx = 0
566
+
567
+ # load jsonl
568
+ logger.info(f"load image jsonl from {self.image_jsonl_file}")
569
+ self.data = []
570
+ with open(self.image_jsonl_file, "r", encoding="utf-8") as f:
571
+ for line in f:
572
+ try:
573
+ data = json.loads(line)
574
+ except json.JSONDecodeError:
575
+ logger.error(f"failed to load json: {line} @ {self.image_jsonl_file}")
576
+ raise
577
+ self.data.append(data)
578
+ logger.info(f"loaded {len(self.data)} images")
579
+
580
+ def is_indexable(self):
581
+ return True
582
+
583
+ def __len__(self):
584
+ return len(self.data)
585
+
586
+ def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
587
+ data = self.data[idx]
588
+ image_path = data["image_path"]
589
+ image = Image.open(image_path).convert("RGB")
590
+
591
+ caption = data["caption"]
592
+
593
+ return image_path, image, caption
594
+
595
+ def get_caption(self, idx: int) -> tuple[str, str]:
596
+ data = self.data[idx]
597
+ image_path = data["image_path"]
598
+ caption = data["caption"]
599
+ return image_path, caption
600
+
601
+ def __iter__(self):
602
+ self.current_idx = 0
603
+ return self
604
+
605
+ def __next__(self) -> callable:
606
+ if self.current_idx >= len(self.data):
607
+ raise StopIteration
608
+
609
+ if self.caption_only:
610
+
611
+ def create_caption_fetcher(index):
612
+ return lambda: self.get_caption(index)
613
+
614
+ fetcher = create_caption_fetcher(self.current_idx)
615
+
616
+ else:
617
+
618
+ def create_fetcher(index):
619
+ return lambda: self.get_image_data(index)
620
+
621
+ fetcher = create_fetcher(self.current_idx)
622
+
623
+ self.current_idx += 1
624
+ return fetcher
625
+
626
+
627
+ class VideoDatasource(ContentDatasource):
628
+ def __init__(self):
629
+ super().__init__()
630
+
631
+ # None means all frames
632
+ self.start_frame = None
633
+ self.end_frame = None
634
+
635
+ self.bucket_selector = None
636
+
637
+ def __len__(self):
638
+ raise NotImplementedError
639
+
640
+ def get_video_data_from_path(
641
+ self,
642
+ video_path: str,
643
+ start_frame: Optional[int] = None,
644
+ end_frame: Optional[int] = None,
645
+ bucket_selector: Optional[BucketSelector] = None,
646
+ ) -> tuple[str, list[Image.Image], str]:
647
+ # this method can resize the video if bucket_selector is given to reduce the memory usage
648
+
649
+ start_frame = start_frame if start_frame is not None else self.start_frame
650
+ end_frame = end_frame if end_frame is not None else self.end_frame
651
+ bucket_selector = bucket_selector if bucket_selector is not None else self.bucket_selector
652
+
653
+ video = load_video(video_path, start_frame, end_frame, bucket_selector)
654
+ return video
655
+
656
+ def set_start_and_end_frame(self, start_frame: Optional[int], end_frame: Optional[int]):
657
+ self.start_frame = start_frame
658
+ self.end_frame = end_frame
659
+
660
+ def set_bucket_selector(self, bucket_selector: BucketSelector):
661
+ self.bucket_selector = bucket_selector
662
+
663
+ def __iter__(self):
664
+ raise NotImplementedError
665
+
666
+ def __next__(self):
667
+ raise NotImplementedError
668
+
669
+
670
+ class VideoDirectoryDatasource(VideoDatasource):
671
+ def __init__(self, video_directory: str, caption_extension: Optional[str] = None):
672
+ super().__init__()
673
+ self.video_directory = video_directory
674
+ self.caption_extension = caption_extension
675
+ self.current_idx = 0
676
+
677
+ # glob videos
678
+ logger.info(f"glob videos in {self.video_directory}")
679
+ self.video_paths = glob_videos(self.video_directory)
680
+ logger.info(f"found {len(self.video_paths)} videos")
681
+
682
+ def is_indexable(self):
683
+ return True
684
+
685
+ def __len__(self):
686
+ return len(self.video_paths)
687
+
688
+ def get_video_data(
689
+ self,
690
+ idx: int,
691
+ start_frame: Optional[int] = None,
692
+ end_frame: Optional[int] = None,
693
+ bucket_selector: Optional[BucketSelector] = None,
694
+ ) -> tuple[str, list[Image.Image], str]:
695
+ video_path = self.video_paths[idx]
696
+ video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)
697
+
698
+ _, caption = self.get_caption(idx)
699
+
700
+ return video_path, video, caption
701
+
702
+ def get_caption(self, idx: int) -> tuple[str, str]:
703
+ video_path = self.video_paths[idx]
704
+ caption_path = os.path.splitext(video_path)[0] + self.caption_extension if self.caption_extension else ""
705
+ with open(caption_path, "r", encoding="utf-8") as f:
706
+ caption = f.read().strip()
707
+ return video_path, caption
708
+
709
+ def __iter__(self):
710
+ self.current_idx = 0
711
+ return self
712
+
713
+ def __next__(self):
714
+ if self.current_idx >= len(self.video_paths):
715
+ raise StopIteration
716
+
717
+ if self.caption_only:
718
+
719
+ def create_caption_fetcher(index):
720
+ return lambda: self.get_caption(index)
721
+
722
+ fetcher = create_caption_fetcher(self.current_idx)
723
+
724
+ else:
725
+
726
+ def create_fetcher(index):
727
+ return lambda: self.get_video_data(index)
728
+
729
+ fetcher = create_fetcher(self.current_idx)
730
+
731
+ self.current_idx += 1
732
+ return fetcher
733
+
734
+
735
+ class VideoJsonlDatasource(VideoDatasource):
736
+ def __init__(self, video_jsonl_file: str):
737
+ super().__init__()
738
+ self.video_jsonl_file = video_jsonl_file
739
+ self.current_idx = 0
740
+
741
+ # load jsonl
742
+ logger.info(f"load video jsonl from {self.video_jsonl_file}")
743
+ self.data = []
744
+ with open(self.video_jsonl_file, "r", encoding="utf-8") as f:
745
+ for line in f:
746
+ data = json.loads(line)
747
+ self.data.append(data)
748
+ logger.info(f"loaded {len(self.data)} videos")
749
+
750
+ def is_indexable(self):
751
+ return True
752
+
753
+ def __len__(self):
754
+ return len(self.data)
755
+
756
+ def get_video_data(
757
+ self,
758
+ idx: int,
759
+ start_frame: Optional[int] = None,
760
+ end_frame: Optional[int] = None,
761
+ bucket_selector: Optional[BucketSelector] = None,
762
+ ) -> tuple[str, list[Image.Image], str]:
763
+ data = self.data[idx]
764
+ video_path = data["video_path"]
765
+ video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)
766
+
767
+ caption = data["caption"]
768
+
769
+ return video_path, video, caption
770
+
771
+ def get_caption(self, idx: int) -> tuple[str, str]:
772
+ data = self.data[idx]
773
+ video_path = data["video_path"]
774
+ caption = data["caption"]
775
+ return video_path, caption
776
+
777
+ def __iter__(self):
778
+ self.current_idx = 0
779
+ return self
780
+
781
+ def __next__(self):
782
+ if self.current_idx >= len(self.data):
783
+ raise StopIteration
784
+
785
+ if self.caption_only:
786
+
787
+ def create_caption_fetcher(index):
788
+ return lambda: self.get_caption(index)
789
+
790
+ fetcher = create_caption_fetcher(self.current_idx)
791
+
792
+ else:
793
+
794
+ def create_fetcher(index):
795
+ return lambda: self.get_video_data(index)
796
+
797
+ fetcher = create_fetcher(self.current_idx)
798
+
799
+ self.current_idx += 1
800
+ return fetcher
801
+
802
+
803
+ class BaseDataset(torch.utils.data.Dataset):
804
+ def __init__(
805
+ self,
806
+ resolution: Tuple[int, int] = (960, 544),
807
+ caption_extension: Optional[str] = None,
808
+ batch_size: int = 1,
809
+ num_repeats: int = 1,
810
+ enable_bucket: bool = False,
811
+ bucket_no_upscale: bool = False,
812
+ cache_directory: Optional[str] = None,
813
+ debug_dataset: bool = False,
814
+ architecture: str = "no_default",
815
+ ):
816
+ self.resolution = resolution
817
+ self.caption_extension = caption_extension
818
+ self.batch_size = batch_size
819
+ self.num_repeats = num_repeats
820
+ self.enable_bucket = enable_bucket
821
+ self.bucket_no_upscale = bucket_no_upscale
822
+ self.cache_directory = cache_directory
823
+ self.debug_dataset = debug_dataset
824
+ self.architecture = architecture
825
+ self.seed = None
826
+ self.current_epoch = 0
827
+
828
+ if not self.enable_bucket:
829
+ self.bucket_no_upscale = False
830
+
831
+ def get_metadata(self) -> dict:
832
+ metadata = {
833
+ "resolution": self.resolution,
834
+ "caption_extension": self.caption_extension,
835
+ "batch_size_per_device": self.batch_size,
836
+ "num_repeats": self.num_repeats,
837
+ "enable_bucket": bool(self.enable_bucket),
838
+ "bucket_no_upscale": bool(self.bucket_no_upscale),
839
+ }
840
+ return metadata
841
+
842
+ def get_all_latent_cache_files(self):
843
+ return glob.glob(os.path.join(self.cache_directory, f"*_{self.architecture}.safetensors"))
844
+
845
+ def get_all_text_encoder_output_cache_files(self):
846
+ return glob.glob(os.path.join(self.cache_directory, f"*_{self.architecture}_te.safetensors"))
847
+
848
+ def get_latent_cache_path(self, item_info: ItemInfo) -> str:
849
+ """
850
+ Returns the cache path for the latent tensor.
851
+
852
+ item_info: ItemInfo object
853
+
854
+ Returns:
855
+ str: cache path
856
+
857
+ cache_path is based on the item_key and the resolution.
858
+ """
859
+ w, h = item_info.original_size
860
+ basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
861
+ assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です"
862
+ return os.path.join(self.cache_directory, f"{basename}_{w:04d}x{h:04d}_{self.architecture}.safetensors")
863
+
864
+ def get_text_encoder_output_cache_path(self, item_info: ItemInfo) -> str:
865
+ basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
866
+ assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です"
867
+ return os.path.join(self.cache_directory, f"{basename}_{self.architecture}_te.safetensors")
868
+
869
+ def retrieve_latent_cache_batches(self, num_workers: int):
870
+ raise NotImplementedError
871
+
872
+ def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
873
+ raise NotImplementedError
874
+
875
+ def prepare_for_training(self):
876
+ pass
877
+
878
+ def set_seed(self, seed: int):
879
+ self.seed = seed
880
+
881
+ def set_current_epoch(self, epoch):
882
+ if not self.current_epoch == epoch: # shuffle buckets when epoch is incremented
883
+ if epoch > self.current_epoch:
884
+ logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
885
+ num_epochs = epoch - self.current_epoch
886
+ for _ in range(num_epochs):
887
+ self.current_epoch += 1
888
+ self.shuffle_buckets()
889
+ # self.current_epoch seems to be set to 0 again in the next epoch; it may be caused by skipped_dataloader?
890
+ else:
891
+ logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
892
+ self.current_epoch = epoch
893
+
894
+ def set_current_step(self, step):
895
+ self.current_step = step
896
+
897
+ def set_max_train_steps(self, max_train_steps):
898
+ self.max_train_steps = max_train_steps
899
+
900
+ def shuffle_buckets(self):
901
+ raise NotImplementedError
902
+
903
+ def __len__(self):
904
+ raise NotImplementedError
905
+
906
+ def __getitem__(self, idx):
907
+ raise NotImplementedError
908
+
909
+ def _default_retrieve_text_encoder_output_cache_batches(self, datasource: ContentDatasource, batch_size: int, num_workers: int):
910
+ datasource.set_caption_only(True)
911
+ executor = ThreadPoolExecutor(max_workers=num_workers)
912
+
913
+ data: list[ItemInfo] = []
914
+ futures = []
915
+
916
+ def aggregate_future(consume_all: bool = False):
917
+ while len(futures) >= num_workers or (consume_all and len(futures) > 0):
918
+ completed_futures = [future for future in futures if future.done()]
919
+ if len(completed_futures) == 0:
920
+ if len(futures) >= num_workers or consume_all: # to avoid adding too many futures
921
+ time.sleep(0.1)
922
+ continue
923
+ else:
924
+ break # submit batch if possible
925
+
926
+ for future in completed_futures:
927
+ item_key, caption = future.result()
928
+ item_info = ItemInfo(item_key, caption, (0, 0), (0, 0))
929
+ item_info.text_encoder_output_cache_path = self.get_text_encoder_output_cache_path(item_info)
930
+ data.append(item_info)
931
+
932
+ futures.remove(future)
933
+
934
+ def submit_batch(flush: bool = False):
935
+ nonlocal data
936
+ if len(data) >= batch_size or (len(data) > 0 and flush):
937
+ batch = data[0:batch_size]
938
+ if len(data) > batch_size:
939
+ data = data[batch_size:]
940
+ else:
941
+ data = []
942
+ return batch
943
+ return None
944
+
945
+ for fetch_op in datasource:
946
+ future = executor.submit(fetch_op)
947
+ futures.append(future)
948
+ aggregate_future()
949
+ while True:
950
+ batch = submit_batch()
951
+ if batch is None:
952
+ break
953
+ yield batch
954
+
955
+ aggregate_future(consume_all=True)
956
+ while True:
957
+ batch = submit_batch(flush=True)
958
+ if batch is None:
959
+ break
960
+ yield batch
961
+
962
+ executor.shutdown()
963
+
964
+
965
+ class ImageDataset(BaseDataset):
966
+ def __init__(
967
+ self,
968
+ resolution: Tuple[int, int],
969
+ caption_extension: Optional[str],
970
+ batch_size: int,
971
+ num_repeats: int,
972
+ enable_bucket: bool,
973
+ bucket_no_upscale: bool,
974
+ image_directory: Optional[str] = None,
975
+ image_jsonl_file: Optional[str] = None,
976
+ cache_directory: Optional[str] = None,
977
+ debug_dataset: bool = False,
978
+ architecture: str = "no_default",
979
+ ):
980
+ super(ImageDataset, self).__init__(
981
+ resolution,
982
+ caption_extension,
983
+ batch_size,
984
+ num_repeats,
985
+ enable_bucket,
986
+ bucket_no_upscale,
987
+ cache_directory,
988
+ debug_dataset,
989
+ architecture,
990
+ )
991
+ self.image_directory = image_directory
992
+ self.image_jsonl_file = image_jsonl_file
993
+ if image_directory is not None:
994
+ self.datasource = ImageDirectoryDatasource(image_directory, caption_extension)
995
+ elif image_jsonl_file is not None:
996
+ self.datasource = ImageJsonlDatasource(image_jsonl_file)
997
+ else:
998
+ raise ValueError("image_directory or image_jsonl_file must be specified")
999
+
1000
+ if self.cache_directory is None:
1001
+ self.cache_directory = self.image_directory
1002
+
1003
+ self.batch_manager = None
1004
+ self.num_train_items = 0
1005
+
1006
+ def get_metadata(self):
1007
+ metadata = super().get_metadata()
1008
+ if self.image_directory is not None:
1009
+ metadata["image_directory"] = os.path.basename(self.image_directory)
1010
+ if self.image_jsonl_file is not None:
1011
+ metadata["image_jsonl_file"] = os.path.basename(self.image_jsonl_file)
1012
+ return metadata
1013
+
1014
+ def get_total_image_count(self):
1015
+ return len(self.datasource) if self.datasource.is_indexable() else None
1016
+
1017
+ def retrieve_latent_cache_batches(self, num_workers: int):
1018
+ bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale, self.architecture)
1019
+ executor = ThreadPoolExecutor(max_workers=num_workers)
1020
+
1021
+ batches: dict[tuple[int, int], list[ItemInfo]] = {} # (width, height) -> [ItemInfo]
1022
+ futures = []
1023
+
1024
+ # aggregate futures and sort by bucket resolution
1025
+ def aggregate_future(consume_all: bool = False):
1026
+ while len(futures) >= num_workers or (consume_all and len(futures) > 0):
1027
+ completed_futures = [future for future in futures if future.done()]
1028
+ if len(completed_futures) == 0:
1029
+ if len(futures) >= num_workers or consume_all: # to avoid adding too many futures
1030
+ time.sleep(0.1)
1031
+ continue
1032
+ else:
1033
+ break # submit batch if possible
1034
+
1035
+ for future in completed_futures:
1036
+ original_size, item_key, image, caption = future.result()
1037
+ bucket_height, bucket_width = image.shape[:2]
1038
+ bucket_reso = (bucket_width, bucket_height)
1039
+
1040
+ item_info = ItemInfo(item_key, caption, original_size, bucket_reso, content=image)
1041
+ item_info.latent_cache_path = self.get_latent_cache_path(item_info)
1042
+
1043
+ if bucket_reso not in batches:
1044
+ batches[bucket_reso] = []
1045
+ batches[bucket_reso].append(item_info)
1046
+
1047
+ futures.remove(future)
1048
+
1049
+ # submit batch if some bucket has enough items
1050
+ def submit_batch(flush: bool = False):
1051
+ for key in batches:
1052
+ if len(batches[key]) >= self.batch_size or flush:
1053
+ batch = batches[key][0 : self.batch_size]
1054
+ if len(batches[key]) > self.batch_size:
1055
+ batches[key] = batches[key][self.batch_size :]
1056
+ else:
1057
+ del batches[key]
1058
+ return key, batch
1059
+ return None, None
1060
+
1061
+ for fetch_op in self.datasource:
1062
+
1063
+ # fetch and resize image in a separate thread
1064
+ def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, Image.Image, str]:
1065
+ image_key, image, caption = op()
1066
+ image: Image.Image
1067
+ image_size = image.size
1068
+
1069
+ bucket_reso = bucket_selector.get_bucket_resolution(image_size)
1070
+ image = resize_image_to_bucket(image, bucket_reso)
1071
+ return image_size, image_key, image, caption
1072
+
1073
+ future = executor.submit(fetch_and_resize, fetch_op)
1074
+ futures.append(future)
1075
+ aggregate_future()
1076
+ while True:
1077
+ key, batch = submit_batch()
1078
+ if key is None:
1079
+ break
1080
+ yield key, batch
1081
+
1082
+ aggregate_future(consume_all=True)
1083
+ while True:
1084
+ key, batch = submit_batch(flush=True)
1085
+ if key is None:
1086
+ break
1087
+ yield key, batch
1088
+
1089
+ executor.shutdown()
1090
+
1091
+ def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
1092
+ return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)
1093
+
1094
+ def prepare_for_training(self):
1095
+ bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale, self.architecture)
1096
+
1097
+ # glob cache files
1098
+ latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{self.architecture}.safetensors"))
1099
+
1100
+ # assign cache files to item info
1101
+ bucketed_item_info: dict[tuple[int, int], list[ItemInfo]] = {} # (width, height) -> [ItemInfo]
1102
+ for cache_file in latent_cache_files:
1103
+ tokens = os.path.basename(cache_file).split("_")
1104
+
1105
+ image_size = tokens[-2] # 0000x0000
1106
+ image_width, image_height = map(int, image_size.split("x"))
1107
+ image_size = (image_width, image_height)
1108
+
1109
+ item_key = "_".join(tokens[:-2])
1110
+ text_encoder_output_cache_file = os.path.join(self.cache_directory, f"{item_key}_{self.architecture}_te.safetensors")
1111
+ if not os.path.exists(text_encoder_output_cache_file):
1112
+ logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
1113
+ continue
1114
+
1115
+ bucket_reso = bucket_selector.get_bucket_resolution(image_size)
1116
+ item_info = ItemInfo(item_key, "", image_size, bucket_reso, latent_cache_path=cache_file)
1117
+ item_info.text_encoder_output_cache_path = text_encoder_output_cache_file
1118
+
1119
+ bucket = bucketed_item_info.get(bucket_reso, [])
1120
+ for _ in range(self.num_repeats):
1121
+ bucket.append(item_info)
1122
+ bucketed_item_info[bucket_reso] = bucket
1123
+
1124
+ # prepare batch manager
1125
+ self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
1126
+ self.batch_manager.show_bucket_info()
1127
+
1128
+ self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])
1129
+
1130
+ def shuffle_buckets(self):
1131
+ # set random seed for this epoch
1132
+ random.seed(self.seed + self.current_epoch)
1133
+ self.batch_manager.shuffle()
1134
+
1135
+ def __len__(self):
1136
+ if self.batch_manager is None:
1137
+ return 100 # dummy value
1138
+ return len(self.batch_manager)
1139
+
1140
+ def __getitem__(self, idx):
1141
+ return self.batch_manager[idx]
1142
+
1143
+
1144
+ class VideoDataset(BaseDataset):
1145
+ def __init__(
1146
+ self,
1147
+ resolution: Tuple[int, int],
1148
+ caption_extension: Optional[str],
1149
+ batch_size: int,
1150
+ num_repeats: int,
1151
+ enable_bucket: bool,
1152
+ bucket_no_upscale: bool,
1153
+ frame_extraction: Optional[str] = "head",
1154
+ frame_stride: Optional[int] = 1,
1155
+ frame_sample: Optional[int] = 1,
1156
+ target_frames: Optional[list[int]] = None,
1157
+ video_directory: Optional[str] = None,
1158
+ video_jsonl_file: Optional[str] = None,
1159
+ cache_directory: Optional[str] = None,
1160
+ debug_dataset: bool = False,
1161
+ architecture: str = "no_default",
1162
+ ):
1163
+ super(VideoDataset, self).__init__(
1164
+ resolution,
1165
+ caption_extension,
1166
+ batch_size,
1167
+ num_repeats,
1168
+ enable_bucket,
1169
+ bucket_no_upscale,
1170
+ cache_directory,
1171
+ debug_dataset,
1172
+ architecture,
1173
+ )
1174
+ self.video_directory = video_directory
1175
+ self.video_jsonl_file = video_jsonl_file
1176
+ self.target_frames = target_frames
1177
+ self.frame_extraction = frame_extraction
1178
+ self.frame_stride = frame_stride
1179
+ self.frame_sample = frame_sample
1180
+
1181
+ if video_directory is not None:
1182
+ self.datasource = VideoDirectoryDatasource(video_directory, caption_extension)
1183
+ elif video_jsonl_file is not None:
1184
+ self.datasource = VideoJsonlDatasource(video_jsonl_file)
1185
+
1186
+ if self.frame_extraction == "uniform" and self.frame_sample == 1:
1187
+ self.frame_extraction = "head"
1188
+ logger.warning("frame_sample is set to 1 for frame_extraction=uniform. frame_extraction is changed to head.")
1189
+ if self.frame_extraction == "head":
1190
+ # head extraction. we can limit the number of frames to be extracted
1191
+ self.datasource.set_start_and_end_frame(0, max(self.target_frames))
1192
+
1193
+ if self.cache_directory is None:
1194
+ self.cache_directory = self.video_directory
1195
+
1196
+ self.batch_manager = None
1197
+ self.num_train_items = 0
1198
+
1199
+ def get_metadata(self):
1200
+ metadata = super().get_metadata()
1201
+ if self.video_directory is not None:
1202
+ metadata["video_directory"] = os.path.basename(self.video_directory)
1203
+ if self.video_jsonl_file is not None:
1204
+ metadata["video_jsonl_file"] = os.path.basename(self.video_jsonl_file)
1205
+ metadata["frame_extraction"] = self.frame_extraction
1206
+ metadata["frame_stride"] = self.frame_stride
1207
+ metadata["frame_sample"] = self.frame_sample
1208
+ metadata["target_frames"] = self.target_frames
1209
+ return metadata
1210
+
1211
+ def retrieve_latent_cache_batches(self, num_workers: int):
1212
+ bucket_selector = BucketSelector(self.resolution, architecture=self.architecture)
1213
+ self.datasource.set_bucket_selector(bucket_selector)
1214
+
1215
+ executor = ThreadPoolExecutor(max_workers=num_workers)
1216
+
1217
+ # key: (width, height, frame_count), value: [ItemInfo]
1218
+ batches: dict[tuple[int, int, int], list[ItemInfo]] = {}
1219
+ futures = []
1220
+
1221
+ def aggregate_future(consume_all: bool = False):
1222
+ while len(futures) >= num_workers or (consume_all and len(futures) > 0):
1223
+ completed_futures = [future for future in futures if future.done()]
1224
+ if len(completed_futures) == 0:
1225
+ if len(futures) >= num_workers or consume_all: # to avoid adding too many futures
1226
+ time.sleep(0.1)
1227
+ continue
1228
+ else:
1229
+ break # submit batch if possible
1230
+
1231
+ for future in completed_futures:
1232
+ original_frame_size, video_key, video, caption = future.result()
1233
+
1234
+ frame_count = len(video)
1235
+ video = np.stack(video, axis=0)
1236
+ height, width = video.shape[1:3]
1237
+ bucket_reso = (width, height) # already resized
1238
+
1239
+ crop_pos_and_frames = []
1240
+ if self.frame_extraction == "head":
1241
+ for target_frame in self.target_frames:
1242
+ if frame_count >= target_frame:
1243
+ crop_pos_and_frames.append((0, target_frame))
1244
+ elif self.frame_extraction == "chunk":
1245
+ # split by target_frames
1246
+ for target_frame in self.target_frames:
1247
+ for i in range(0, frame_count, target_frame):
1248
+ if i + target_frame <= frame_count:
1249
+ crop_pos_and_frames.append((i, target_frame))
1250
+ elif self.frame_extraction == "slide":
1251
+ # slide window
1252
+ for target_frame in self.target_frames:
1253
+ if frame_count >= target_frame:
1254
+ for i in range(0, frame_count - target_frame + 1, self.frame_stride):
1255
+ crop_pos_and_frames.append((i, target_frame))
1256
+ elif self.frame_extraction == "uniform":
1257
+ # select N frames uniformly
1258
+ for target_frame in self.target_frames:
1259
+ if frame_count >= target_frame:
1260
+ frame_indices = np.linspace(0, frame_count - target_frame, self.frame_sample, dtype=int)
1261
+ for i in frame_indices:
1262
+ crop_pos_and_frames.append((i, target_frame))
1263
+ else:
1264
+ raise ValueError(f"frame_extraction {self.frame_extraction} is not supported")
1265
+
1266
+ for crop_pos, target_frame in crop_pos_and_frames:
1267
+ cropped_video = video[crop_pos : crop_pos + target_frame]
1268
+ body, ext = os.path.splitext(video_key)
1269
+ item_key = f"{body}_{crop_pos:05d}-{target_frame:03d}{ext}"
1270
+ batch_key = (*bucket_reso, target_frame) # bucket_reso with frame_count
1271
+
1272
+ item_info = ItemInfo(
1273
+ item_key, caption, original_frame_size, batch_key, frame_count=target_frame, content=cropped_video
1274
+ )
1275
+ item_info.latent_cache_path = self.get_latent_cache_path(item_info)
1276
+
1277
+ batch = batches.get(batch_key, [])
1278
+ batch.append(item_info)
1279
+ batches[batch_key] = batch
1280
+
1281
+ futures.remove(future)
1282
+
1283
+ def submit_batch(flush: bool = False):
1284
+ for key in batches:
1285
+ if len(batches[key]) >= self.batch_size or flush:
1286
+ batch = batches[key][0 : self.batch_size]
1287
+ if len(batches[key]) > self.batch_size:
1288
+ batches[key] = batches[key][self.batch_size :]
1289
+ else:
1290
+ del batches[key]
1291
+ return key, batch
1292
+ return None, None
1293
+
1294
+ for operator in self.datasource:
1295
+
1296
+ def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, list[np.ndarray], str]:
1297
+ video_key, video, caption = op()
1298
+ video: list[np.ndarray]
1299
+ frame_size = (video[0].shape[1], video[0].shape[0])
1300
+
1301
+ # resize if necessary
1302
+ bucket_reso = bucket_selector.get_bucket_resolution(frame_size)
1303
+ video = [resize_image_to_bucket(frame, bucket_reso) for frame in video]
1304
+
1305
+ return frame_size, video_key, video, caption
1306
+
1307
+ future = executor.submit(fetch_and_resize, operator)
1308
+ futures.append(future)
1309
+ aggregate_future()
1310
+ while True:
1311
+ key, batch = submit_batch()
1312
+ if key is None:
1313
+ break
1314
+ yield key, batch
1315
+
1316
+ aggregate_future(consume_all=True)
1317
+ while True:
1318
+ key, batch = submit_batch(flush=True)
1319
+ if key is None:
1320
+ break
1321
+ yield key, batch
1322
+
1323
+ executor.shutdown()
1324
+
1325
+ def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
1326
+ return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)
1327
+
1328
+ def prepare_for_training(self):
1329
+ bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale, self.architecture)
1330
+
1331
+ # glob cache files
1332
+ latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{self.architecture}.safetensors"))
1333
+
1334
+ # assign cache files to item info
1335
+ bucketed_item_info: dict[tuple[int, int, int], list[ItemInfo]] = {} # (width, height, frame_count) -> [ItemInfo]
1336
+ for cache_file in latent_cache_files:
1337
+ tokens = os.path.basename(cache_file).split("_")
1338
+
1339
+ image_size = tokens[-2] # 0000x0000
1340
+ image_width, image_height = map(int, image_size.split("x"))
1341
+ image_size = (image_width, image_height)
1342
+
1343
+ frame_pos, frame_count = tokens[-3].split("-")
1344
+ frame_pos, frame_count = int(frame_pos), int(frame_count)
1345
+
1346
+ item_key = "_".join(tokens[:-3])
1347
+ text_encoder_output_cache_file = os.path.join(self.cache_directory, f"{item_key}_{self.architecture}_te.safetensors")
1348
+ if not os.path.exists(text_encoder_output_cache_file):
1349
+ logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
1350
+ continue
1351
+
1352
+ bucket_reso = bucket_selector.get_bucket_resolution(image_size)
1353
+ bucket_reso = (*bucket_reso, frame_count)
1354
+ item_info = ItemInfo(item_key, "", image_size, bucket_reso, frame_count=frame_count, latent_cache_path=cache_file)
1355
+ item_info.text_encoder_output_cache_path = text_encoder_output_cache_file
1356
+
1357
+ bucket = bucketed_item_info.get(bucket_reso, [])
1358
+ for _ in range(self.num_repeats):
1359
+ bucket.append(item_info)
1360
+ bucketed_item_info[bucket_reso] = bucket
1361
+
1362
+ # prepare batch manager
1363
+ self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
1364
+ self.batch_manager.show_bucket_info()
1365
+
1366
+ self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])
1367
+
1368
+ def shuffle_buckets(self):
1369
+ # set random seed for this epoch
1370
+ random.seed(self.seed + self.current_epoch)
1371
+ self.batch_manager.shuffle()
1372
+
1373
+ def __len__(self):
1374
+ if self.batch_manager is None:
1375
+ return 100 # dummy value
1376
+ return len(self.batch_manager)
1377
+
1378
+ def __getitem__(self, idx):
1379
+ return self.batch_manager[idx]
1380
+
1381
+
1382
+ class DatasetGroup(torch.utils.data.ConcatDataset):
1383
+ def __init__(self, datasets: Sequence[Union[ImageDataset, VideoDataset]]):
1384
+ super().__init__(datasets)
1385
+ self.datasets: list[Union[ImageDataset, VideoDataset]] = datasets
1386
+ self.num_train_items = 0
1387
+ for dataset in self.datasets:
1388
+ self.num_train_items += dataset.num_train_items
1389
+
1390
+ def set_current_epoch(self, epoch):
1391
+ for dataset in self.datasets:
1392
+ dataset.set_current_epoch(epoch)
1393
+
1394
+ def set_current_step(self, step):
1395
+ for dataset in self.datasets:
1396
+ dataset.set_current_step(step)
1397
+
1398
+ def set_max_train_steps(self, max_train_steps):
1399
+ for dataset in self.datasets:
1400
+ dataset.set_max_train_steps(max_train_steps)
docs/advanced_config.md ADDED
@@ -0,0 +1,151 @@
1
+ > 📝 Click on the language section to expand / 言語をクリックして展開
2
+
3
+ # Advanced configuration / 高度な設定
4
+
5
+ ## How to specify `network_args` / `network_args`の指定方法
6
+
7
+ The `--network_args` option specifies detailed arguments for LoRA. Specify the arguments in the form `key=value`.
8
+
9
+ <details>
10
+ <summary>日本語</summary>
11
+ `--network_args`オプションは、LoRAへの詳細な引数を指定するためのオプションです。`--network_args`には、`key=value`の形式で引数を指定します。
12
+ </details>
13
+
14
+ ### Example / 記述例
15
+
16
+ If you specify it on the command line, write as follows. / コマンドラインで指定する場合は以下のように記述します。
17
+
18
+ ```bash
19
+ accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 hv_train_network.py --dit ...
20
+ --network_module networks.lora --network_dim 32
21
+ --network_args "key1=value1" "key2=value2" ...
22
+ ```
23
+
24
+ If you specify it in the configuration file, write as follows. / 設定ファイルで指定する場合は以下のように記述します。
25
+
26
+ ```toml
27
+ network_args = ["key1=value1", "key2=value2", ...]
28
+ ```
29
+
30
+ If you specify `"verbose=True"`, detailed information of LoRA will be displayed. / `"verbose=True"`を指定するとLoRAの詳細な情報が表示されます。
31
+
32
+ ```bash
33
+ --network_args "verbose=True" "key1=value1" "key2=value2" ...
34
+ ```
35
+
36
+ ## LoRA+
37
+
38
+ LoRA+ is a method that improves training speed by raising the learning rate of the UP side (LoRA-B) of LoRA. Specify the multiplier for the learning rate. The original paper recommends 16, but adjust as needed; starting from around 4 seems to work well. For details, please refer to the [related PR of sd-scripts](https://github.com/kohya-ss/sd-scripts/pull/1233).
39
+
40
+ Specify `loraplus_lr_ratio` with `--network_args`.
41
+
42
+ <details>
43
+ <summary>日本語</summary>
44
+
45
+ LoRA+は、LoRAのUP側(LoRA-B)の学習率を上げることで学習速度を向上させる手法です。学習率に対する倍率を指定します。元論文では16を推奨していますが、必要に応じて調整してください。4程度から始めるとよいようです。詳細は[sd-scriptsの関連PR](https://github.com/kohya-ss/sd-scripts/pull/1233)を参照してください。
46
+
47
+ `--network_args`で`loraplus_lr_ratio`を指定します。
48
+ </details>
49
+
50
+ ### Example / 記述例
51
+
52
+ ```bash
53
+ accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 hv_train_network.py --dit ...
54
+ --network_module networks.lora --network_dim 32 --network_args "loraplus_lr_ratio=4" ...
55
+ ```
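+ 
+ For intuition only, the effect of `loraplus_lr_ratio` can be pictured as building optimizer parameter groups in which the LoRA-B ("up") weights get a higher learning rate. The sketch below is not this repository's implementation; the `lora_up` naming used to identify LoRA-B parameters is an assumption.
+ 
+ ```python
+ # Illustrative LoRA+ parameter grouping (not the actual network code of this repository).
+ import torch
+ 
+ def build_loraplus_param_groups(network: torch.nn.Module, base_lr: float, loraplus_lr_ratio: float):
+     up_params, other_params = [], []
+     for name, param in network.named_parameters():
+         if not param.requires_grad:
+             continue
+         # Assumption: LoRA-B ("up") weights contain "lora_up" in their parameter name.
+         (up_params if "lora_up" in name else other_params).append(param)
+     return [
+         {"params": other_params, "lr": base_lr},
+         {"params": up_params, "lr": base_lr * loraplus_lr_ratio},  # e.g. 4x higher LR for LoRA-B
+     ]
+ 
+ # optimizer = torch.optim.AdamW(build_loraplus_param_groups(network, 1e-4, 4.0))
+ ```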
56
+
57
+ ## Select the target modules of LoRA / LoRAの対象モジュールを選択する
58
+
59
+ *This feature is highly experimental and the specification may change. / この機能は特に実験的なもので、仕様は変更される可能性があります。*
60
+
61
+ By specifying `exclude_patterns` and `include_patterns` with `--network_args`, you can select the target modules of LoRA.
62
+
63
+ `exclude_patterns` excludes modules that match the specified pattern. `include_patterns` targets only modules that match the specified pattern.
64
+
65
+ Specify the values as a list. For example, `"exclude_patterns=[r'.*single_blocks.*', r'.*double_blocks\.[0-9]\..*']"`.
66
+
67
+ The pattern is a regular expression applied to the module name. Module names have the form `double_blocks.0.img_mod.linear` or `single_blocks.39.modulation.linear`. The regular expression must match the entire module name (complete match), not a substring.
68
+
69
+ The patterns are applied in the order of `exclude_patterns`→`include_patterns`. By default, the Linear layers of `img_mod`, `txt_mod`, and `modulation` of double blocks and single blocks are excluded.
70
+
71
+ (`.*(img_mod|txt_mod|modulation).*` is specified.)
72
+
73
+ <details>
74
+ <summary>日本語</summary>
75
+
76
+ `--network_args`で`exclude_patterns`と`include_patterns`を指定することで、LoRAの対象モジュールを選択することができます。
77
+
78
+ `exclude_patterns`は、指定したパターンに一致するモジュールを除外します。`include_patterns`は、指定したパターンに一致するモジュールのみを対象とします。
79
+
80
+ 値は、リストで指定します。`"exclude_patterns=[r'.*single_blocks.*', r'.*double_blocks\.[0-9]\..*']"`のようになります。
81
+
82
+ パターンは、モジュール名に対する正規表現です。モジュール名は、たとえば`double_blocks.0.img_mod.linear`や`single_blocks.39.modulation.linear`のような形式です。正規表現は部分一致ではなく完全一致です。
83
+
84
+ パターンは、`exclude_patterns`→`include_patterns`の順で適用されます。デフォルトは、double blocksとsingle blocksのLinear層のうち、`img_mod`、`txt_mod`、`modulation`が除外されています。
85
+
86
+ (`.*(img_mod|txt_mod|modulation).*`が指定されています。)
87
+ </details>
88
+
89
+ ### Example / 記述例
90
+
91
+ Only the modules of double blocks / double blocksのモジュールのみを対象とする場合:
92
+
93
+ ```bash
94
+ --network_args "exclude_patterns=[r'.*single_blocks.*']"
95
+ ```
96
+
97
+ Only the Linear modules of single blocks from the 10th onward / single blocksの10番目以降のLinearモジュールのみを対象とする場合:
98
+
99
+ ```bash
100
+ --network_args "exclude_patterns=[r'.*']" "include_patterns=[r'.*single_blocks\.\d{2}\.linear.*']"
101
+ ```
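+ 
+ As a rough, standalone illustration of how complete-match patterns behave (this is not the repository's selection code; the module names below are just samples):
+ 
+ ```python
+ # Standalone illustration of exclude/include full-match filtering on module names.
+ import re
+ 
+ module_names = [
+     "double_blocks.0.img_mod.linear",
+     "double_blocks.0.img_attn_qkv",
+     "single_blocks.39.modulation.linear",
+     "single_blocks.12.linear1",
+ ]
+ exclude_patterns = [r".*single_blocks.*"]
+ include_patterns = [r".*single_blocks\.\d{2}\.linear.*"]
+ 
+ selected = []
+ for name in module_names:
+     excluded = any(re.fullmatch(p, name) for p in exclude_patterns)
+     included = any(re.fullmatch(p, name) for p in include_patterns)
+     # exclude_patterns is applied first, then include_patterns re-adds matching modules
+     if not excluded or included:
+         selected.append(name)
+ 
+ print(selected)  # the double blocks are kept, and "single_blocks.12.linear1" is re-added by include
+ ```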
102
+
103
+ ## Save and view logs in TensorBoard format / TensorBoard形式のログの保存と参照
104
+
105
+ Specify the folder to save the logs with the `--logging_dir` option. Logs in TensorBoard format will be saved.
106
+
107
+ For example, if you specify `--logging_dir=logs`, a `logs` folder will be created in the working folder, and logs will be saved in the date folder inside it.
108
+
109
+ Also, if you specify the `--log_prefix` option, the specified string will be added before the date. For example, use `--logging_dir=logs --log_prefix=lora_setting1_` for identification.
110
+
111
+ To view logs in TensorBoard, open another command prompt and activate the virtual environment. Then enter the following in the working folder.
112
+
113
+ ```powershell
114
+ tensorboard --logdir=logs
115
+ ```
116
+
117
+ (tensorboard installation is required.)
118
+
119
+ Then open a browser and access http://localhost:6006/ to display it.
120
+
121
+ <details>
122
+ <summary>日本語</summary>
123
+ `--logging_dir`オプションにログ保存先フォルダを指定してください。TensorBoard形式のログが保存されます。
124
+
125
+ たとえば`--logging_dir=logs`と指定すると、作業フォルダにlogsフォルダが作成され、その中の日時フォルダにログが保存されます。
126
+
127
+ また`--log_prefix`オプションを指定すると、日時の前に指定した文字列が追加されます。`--logging_dir=logs --log_prefix=lora_setting1_`などとして識別用にお使いください。
128
+
129
+ TensorBoardでログを確認するには、別のコマンドプロンプトを開き、仮想環境を有効にしてから、作業フォルダで以下のように入力します。
130
+
131
+ ```powershell
132
+ tensorboard --logdir=logs
133
+ ```
134
+
135
+ (tensorboardのインストールが必要です。)
136
+
137
+ その後ブラウザを開き、http://localhost:6006/ へアクセスすると表示されます。
138
+ </details>
139
+
140
+ ## Save and view logs in wandb / wandbでログの保存と参照
141
+
142
+ The `--log_with wandb` option saves logs in wandb format. `tensorboard` and `all` are also available. The default is `tensorboard`.
143
+
144
+ Specify the project name with `--log_tracker_name` when using wandb.
145
+
146
+ <details>
147
+ <summary>日本語</summary>
148
+ `--log_with wandb`オプションを指定するとwandb形式でログを保存することができます。`tensorboard`や`all`も指定可能です。デフォルトは`tensorboard`です。
149
+
150
+ wandbを使用する場合は、`--log_tracker_name`でプロジェクト名を指定してください。
151
+ </details>
docs/sampling_during_training.md ADDED
@@ -0,0 +1,108 @@
1
+ > 📝 Click on the language section to expand / 言語をクリックして展開
2
+
3
+ # Sampling during training / 学習中のサンプル画像生成
4
+
5
+ By preparing a prompt file, you can generate sample images during training.
6
+
7
+ Please be aware that it consumes a considerable amount of VRAM, so be careful when generating sample images for videos with a large number of frames. Also, since it takes time to generate, adjust the frequency of sample image generation as needed.
8
+
9
+ <details>
10
+ <summary>日本語</summary>
11
+
12
+ プロンプトファイルを用意することで、学習中にサンプル画像を生成することができます。
13
+
14
+ VRAMをそれなりに消費しますので、特にフレーム数が多い動画を生成する場合は注意してください。また生成には時間がかかりますので、サンプル画像生成の頻度は適宜調整してください。
15
+ </details>
16
+
17
+ ## How to use / 使い方
18
+
19
+ ### Command line options for training with sampling / サンプル画像生成に関連する学習時のコマンドラインオプション
20
+
21
+ Example of command line options for training with sampling / 記述例:
22
+
23
+ ```bash
24
+ --vae path/to/ckpts/hunyuan-video-t2v-720p/vae/pytorch_model.pt
25
+ --vae_chunk_size 32 --vae_spatial_tile_sample_min_size 128
26
+ --text_encoder1 path/to/ckpts/text_encoder
27
+ --text_encoder2 path/to/ckpts/text_encoder_2
28
+ --sample_prompts /path/to/prompt_file.txt
29
+ --sample_every_n_epochs 1 --sample_every_n_steps 1000 --sample_at_first
30
+ ```
31
+
32
+ `--vae`, `--vae_chunk_size`, `--vae_spatial_tile_sample_min_size`, `--text_encoder1`, `--text_encoder2` are the same as when generating images, so please refer to [here](/README.md#inference) for details. `--fp8_llm` can also be specified.
33
+
34
+ `--sample_prompts` specifies the path to the prompt file used for sample image generation. Details are described below.
35
+
36
+ `--sample_every_n_epochs` specifies how often to generate sample images in epochs, and `--sample_every_n_steps` specifies how often to generate sample images in steps.
37
+
38
+ `--sample_at_first` is specified when generating sample images at the beginning of training.
39
+
40
+ Sample images and videos are saved in the `sample` directory in the directory specified by `--output_dir`. They are saved as `.png` for still images and `.mp4` for videos.
41
+
42
+ <details>
43
+ <summary>日本語</summary>
44
+
45
+ `--vae`、`--vae_chunk_size`、`--vae_spatial_tile_sample_min_size`、`--text_encoder1`、`--text_encoder2`は、画像生成時と同様ですので、詳細は[こちら](/README.ja.md#推論)を参照してください。`--fp8_llm`も指定可能です。
46
+
47
+ `--sample_prompts`は、サンプル画像生成に使用するプロンプトファイルのパスを指定します。詳細は後述します。
48
+
49
+ `--sample_every_n_epochs`は、何エポックごとにサンプル画像を生成するかを、`--sample_every_n_steps`は、何ステップごとにサンプル画像を生成するかを指定します。
50
+
51
+ `--sample_at_first`は、学習開始時にサンプル画像を生成する場合に指定します。
52
+
53
+ サンプル画像、動画は、`--output_dir`で指定したディレクトリ内の、`sample`ディレクトリに保存されます。静止画の場合は`.png`、動画の場合は`.mp4`で保存されます。
54
+ </details>
55
+
56
+ ### Prompt file / プロンプトファイル
57
+
58
+ The prompt file is a text file that contains the prompts for generating sample images. The example is as follows. / プロンプトファイルは、サンプル画像生成のためのプロンプトを記述したテキストファイルです。例は以下の通りです。
59
+
60
+ ```
61
+ # prompt 1: for generating a cat video
62
+ A cat walks on the grass, realistic style. --w 640 --h 480 --f 25 --d 1 --s 20
63
+
64
+ # prompt 2: for generating a dog image
65
+ A dog runs on the beach, realistic style. --w 960 --h 544 --f 1 --d 2 --s 20
66
+ ```
67
+
68
+ A line starting with `#` is a comment.
69
+
70
+ * `--w` specifies the width of the generated image or video. The default is 256.
71
+ * `--h` specifies the height. The default is 256.
72
+ * `--f` specifies the number of frames. The default is 1, which generates a still image.
73
+ * `--d` specifies the seed. The default is random.
74
+ * `--s` specifies the number of steps in generation. The default is 20.
75
+ * `--g` specifies the guidance scale. The default is 6.0, the same as HunyuanVideo's inference default. Specify 1.0 for SkyReels V1 models. This option is ignored for Wan2.1 models.
76
+ * `--fs` specifies the discrete flow shift. The default is 14.5, which corresponds to 20 steps. In the HunyuanVideo paper, 7.0 is recommended for 50 steps, and 17.0 for fewer than 20 steps (e.g. 10).
77
+
78
+ If you train I2V models, you can use the additional options below.
79
+
80
+ * `--i path/to/image.png`: the image path for image2video inference.
81
+
82
+ If you train the model with classifier free guidance, you can use the additional options below.
83
+
84
+ * `--n negative prompt...`: the negative prompt for the classifier free guidance.
85
+ * `--l 6.0`: the classifier free guidance scale. Should be set to 6.0 for SkyReels V1 models. 5.0 is the default value for Wan2.1 (if omitted).
86
+
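+ For reference, a prompt line like the ones above can be split into the prompt text and its options with a small amount of parsing. The sketch below is a simplified stand-in, not the parser used by the training scripts, and it only handles the flags listed here (values are kept as strings).
+ 
+ ```python
+ # Simplified prompt-line parser for illustration only (not the actual parser of the scripts).
+ import re
+ from typing import Optional
+ 
+ DEFAULTS = {"w": "256", "h": "256", "f": "1", "d": None, "s": "20"}
+ 
+ def parse_prompt_line(line: str) -> Optional[dict]:
+     line = line.strip()
+     if not line or line.startswith("#"):  # skip empty lines and comments
+         return None
+     # split off "--key value" pairs; the text before the first flag is the prompt
+     parts = re.split(r"\s--(\w+)\s+", " " + line)
+     result = dict(DEFAULTS, prompt=parts[0].strip())
+     for key, value in zip(parts[1::2], parts[2::2]):
+         result[key] = value.strip()
+     return result
+ 
+ print(parse_prompt_line("A cat walks on the grass, realistic style. --w 640 --h 480 --f 25 --d 1 --s 20"))
+ ```
+ 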
87
+ <details>
88
+ <summary>日本語</summary>
89
+
90
+ `#` で始まる行はコメントです。
91
+
92
+ * `--w` 生成画像、動画の幅を指定します。省略時は256です。
93
+ * `--h` 高さを指定します。省略時は256です。
94
+ * `--f` フレーム数を指定します。省略時は1で、静止画を生成します。
95
+ * `--d` シードを指定します。省略時はランダムです。
96
+ * `--s` 生成におけるステップ数を指定します。省略時は20です。
97
+ * `--g` guidance scaleを指定します。省略時は6.0で、HunyuanVideoの推論時のデフォルト値です。
98
+ * `--fs` discrete flow shiftを指定します。省略時は14.5で、ステップ数20の場合に対応した値です。HunyuanVideoの論文では、ステップ数50の場合は7.0、ステップ数20未満(10など)で17.0が推奨されています。
99
+
100
+ I2Vモデルを学習する場合、以下の追加オプションを使用できます。
101
+
102
+ * `--i path/to/image.png`: image2video推論用の画像パス。
103
+
104
+ classifier free guidance(ネガティブプロンプト)を必要とするモデルを学習する場合、以下の追加オプションを使用できます。
105
+
106
+ * `--n negative prompt...`: classifier free guidance用のネガティブプロンプト。
107
+ * `--l 6.0`: classifier free guidance scale。SkyReels V1モデルの場合は6.0に設定してください。Wan2.1の場合はデフォルト値が5.0です(省略時)。
108
+ </details>
docs/wan.md ADDED
@@ -0,0 +1,241 @@
 
1
+ > 📝 Click on the language section to expand / 言語をクリックして展開
2
+
3
+ # Wan 2.1
4
+
5
+ ## Overview / 概要
6
+
7
+ This is an unofficial training and inference script for [Wan2.1](https://github.com/Wan-Video/Wan2.1). The features are as follows.
8
+
9
+ - fp8 support and memory reduction by block swap: inference of a 720x1280x81-frame video with 24GB VRAM, and training on 720x1280 images with 24GB VRAM
10
+ - Inference without installing Flash attention (using PyTorch's scaled dot product attention)
11
+ - Supports xformers and Sage attention
12
+
13
+ This feature is experimental.
14
+
15
+ <details>
16
+ <summary>日本語</summary>
17
+ [Wan2.1](https://github.com/Wan-Video/Wan2.1) の非公式の学習および推論スクリプトです。
18
+
19
+ 以下の特徴があります。
20
+
21
+ - fp8対応およびblock swapによる省メモリ化:720x1280x81framesの動画を24GB VRAMで推論可能、720x1280の画像での学習が24GB VRAMで可能
22
+ - Flash attentionのインストールなしでの実行(PyTorchのscaled dot product attentionを使用)
23
+ - xformersおよびSage attention対応
24
+
25
+ この機能は実験的なものです。
26
+ </details>
27
+
28
+ ## Download the model / モデルのダウンロード
29
+
30
+ Download the T5 `models_t5_umt5-xxl-enc-bf16.pth` and CLIP `models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` from the following page: https://huggingface.co/Wan-AI/Wan2.1-T2V-14B/tree/main
31
+
32
+ Download the VAE `Wan2.1_VAE.pth` from the above page, or download `split_files/vae/wan_2.1_vae.safetensors` from the following page: https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/vae
33
+
34
+ Download the DiT weights from the following page: https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/diffusion_models
35
+
36
+ Please select the appropriate weights according to T2V, I2V, resolution, model size, etc. fp8 models can be used if `--fp8` is specified.
37
+
38
+ (Thanks to Comfy-Org for providing the repackaged weights.)
39
+ <details>
40
+ <summary>日本語</summary>
41
+ T5 `models_t5_umt5-xxl-enc-bf16.pth` およびCLIP `models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` を、次のページからダウンロードしてください:https://huggingface.co/Wan-AI/Wan2.1-T2V-14B/tree/main
42
+
43
+ VAEは上のページから `Wan2.1_VAE.pth` をダウンロードするか、次のページから `split_files/vae/wan_2.1_vae.safetensors` をダウンロードしてください:https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/vae
44
+
45
+ DiTの重みを次のページからダウンロードしてください:https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/diffusion_models
46
+
47
+ T2VやI2V、解像度、モデルサイズなどにより適切な重みを選択してください。`--fp8`指定時はfp8モデルも使用できます。
48
+
49
+ (repackaged版の重みを提供してくださっているComfy-Orgに感謝いたします。)
50
+ </details>
51
+
52
+ ## Pre-caching / 事前キャッシュ
53
+
54
+ ### Latent Pre-caching
55
+
56
+ Latent pre-caching is almost the same as in HunyuanVideo. Create the cache using the following command:
57
+
58
+ ```bash
59
+ python wan_cache_latents.py --dataset_config path/to/toml --vae path/to/wan_2.1_vae.safetensors
60
+ ```
61
+
62
+ If you train I2V models, add `--clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` to specify the CLIP model. If not specified, the training will raise an error.
63
+
64
+ If you're running low on VRAM, specify `--vae_cache_cpu` to use the CPU for the VAE internal cache, which will reduce VRAM usage somewhat.
65
+
66
+ <details>
67
+ <summary>日本語</summary>
68
+ latentの事前キャッシングはHunyuanVideoとほぼ同じです。上のコマンド例を使用してキャッシュを作成してください。
69
+
70
+ I2Vモデルを学習する場合は、`--clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` を追加してCLIPモデルを指定してください。指定しないと学習時にエラーが発生します。
71
+
72
+ VRAMが不足している場合は、`--vae_cache_cpu` を指定するとVAEの内部キャッシュにCPUを使うことで、使用VRAMを多少削減できます。
73
+ </details>
74
+
75
+ ### Text Encoder Output Pre-caching
76
+
77
+ Text encoder output pre-caching is also almost the same as in HunyuanVideo. Create the cache using the following command:
78
+
79
+ ```bash
80
+ python wan_cache_text_encoder_outputs.py --dataset_config path/to/toml --t5 path/to/models_t5_umt5-xxl-enc-bf16.pth --batch_size 16
81
+ ```
82
+
83
+ Adjust `--batch_size` according to your available VRAM.
84
+
85
+ For systems with limited VRAM (less than ~16GB), use `--fp8_t5` to run the T5 in fp8 mode.
86
+
87
+ <details>
88
+ <summary>日本語</summary>
89
+ テキストエンコーダ出力の事前キャッシングもHunyuanVideoとほぼ同じです。上のコマンド例を使用してキャッシュを作成してください。
90
+
91
+ 使用可能なVRAMに合わせて `--batch_size` を調整してください。
92
+
93
+ VRAMが限られているシステム(約16GB未満)の場合は、T5をfp8モードで実行するために `--fp8_t5` を使用してください。
94
+ </details>
95
+
96
+ ## Training / 学習
97
+
98
+ ### Training
99
+
100
+ Start training using the following command (input as a single line):
101
+
102
+ ```bash
103
+ accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 wan_train_network.py
104
+ --task t2v-1.3B
105
+ --dit path/to/wan2.1_xxx_bf16.safetensors
106
+ --dataset_config path/to/toml --sdpa --mixed_precision bf16 --fp8_base
107
+ --optimizer_type adamw8bit --learning_rate 2e-4 --gradient_checkpointing
108
+ --max_data_loader_n_workers 2 --persistent_data_loader_workers
109
+ --network_module networks.lora_wan --network_dim 32
110
+ --timestep_sampling shift --discrete_flow_shift 3.0
111
+ --max_train_epochs 16 --save_every_n_epochs 1 --seed 42
112
+ --output_dir path/to/output_dir --output_name name-of-lora
113
+ ```
114
+ The above is an example. The appropriate values for `timestep_sampling` and `discrete_flow_shift` need to be determined by experimentation.
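+ 
+ As a rough intuition for `discrete_flow_shift`: a commonly used formulation of "shift" timestep sampling (an assumption here; it may not match this script exactly) remaps a sampled timestep `t` in (0, 1) to `shift * t / (1 + (shift - 1) * t)`, which pushes more training samples toward the noisier end as `shift` grows.
+ 
+ ```python
+ # Minimal sketch of "shift"-style timestep sampling (assumed formulation, may differ from the script).
+ import torch
+ 
+ def sample_shifted_timesteps(batch_size: int, shift: float = 3.0) -> torch.Tensor:
+     t = torch.sigmoid(torch.randn(batch_size))      # logit-normal sample in (0, 1)
+     return (shift * t) / (1.0 + (shift - 1.0) * t)  # larger shift -> more mass near t = 1
+ 
+ print(sample_shifted_timesteps(4, shift=3.0))
+ ```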
115
+
116
+ For additional options, use `python wan_train_network.py --help` (note that many options are unverified).
117
+
118
+ `--task` is one of `t2v-1.3B`, `t2v-14B`, `i2v-14B` and `t2i-14B`. Specify the DiT weights for the task with `--dit`.
119
+
120
+ Don't forget to specify `--network_module networks.lora_wan`.
121
+
122
+ Other options are mostly the same as `hv_train_network.py`.
123
+
124
+ Use `convert_lora.py` for converting the LoRA weights after training, as in HunyuanVideo.
125
+
126
+ <details>
127
+ <summary>日本語</summary>
128
+ `timestep_sampling`や`discrete_flow_shift`は一例です。どのような値が適切かは実験が必要です。
129
+
130
+ その他のオプションについては `python wan_train_network.py --help` を使用してください(多くのオプションは未検証です)。
131
+
132
+ `--task` には `t2v-1.3B`, `t2v-14B`, `i2v-14B`, `t2i-14B` のいずれかを指定します。`--dit`に、taskに応じたDiTの重みを指定してください。
133
+
134
+ `--network_module` に `networks.lora_wan` を指定することを忘れないでください。
135
+
136
+ その他のオプションは、ほぼ`hv_train_network.py`と同様です。
137
+
138
+ 学習後のLoRAの重みの変換は、HunyuanVideoと同様に`convert_lora.py`を使用してください。
139
+ </details>
140
+
141
+ ### Command line options for training with sampling / サンプル画像生成に関連する学習時のコマンドラインオプション
142
+
143
+ Example of command line options for training with sampling / 記述例:
144
+
145
+ ```bash
146
+ --vae path/to/wan_2.1_vae.safetensors
147
+ --t5 path/to/models_t5_umt5-xxl-enc-bf16.pth
148
+ --sample_prompts /path/to/prompt_file.txt
149
+ --sample_every_n_epochs 1 --sample_every_n_steps 1000 --sample_at_first
150
+ ```
151
+ Each option is the same as when generating images or as HunyuanVideo. Please refer to [here](/docs/sampling_during_training.md) for details.
152
+
153
+ If you train I2V models, add `--clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` to specify the CLIP model.
154
+
155
+ You can specify the initial image and negative prompts in the prompt file. Please refer to [here](/docs/sampling_during_training.md#prompt-file--プロンプトファイル).
156
+
157
+ <details>
158
+ <summary>日本語</summary>
159
+ 各オプションは推論時、およびHunyuanVideoの場合と同様です。[こちら](/docs/sampling_during_training.md)を参照してください。
160
+
161
+ I2Vモデルを学習する場合は、`--clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` を追加してCLIPモデルを指定してください。
162
+
163
+ プロンプトファイルで、初期画像やネガティブプロンプト等を指定できます。[こちら](/docs/sampling_during_training.md#prompt-file--プロンプトファイル)を参照してください。
164
+ </details>
165
+
166
+
167
+ ## Inference / 推論
168
+
169
+ ### T2V Inference / T2V推論
170
+
171
+ The following is an example of T2V inference (input as a single line):
172
+
173
+ ```bash
174
+ python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 832 480 --video_length 81 --infer_steps 20
175
+ --prompt "prompt for the video" --save_path path/to/save.mp4 --output_type both
176
+ --dit path/to/wan2.1_t2v_1.3B_bf16_etc.safetensors --vae path/to/wan_2.1_vae.safetensors
177
+ --t5 path/to/models_t5_umt5-xxl-enc-bf16.pth
178
+ --attn_mode torch
179
+ ```
180
+
181
+ `--task` is one of `t2v-1.3B`, `t2v-14B`, `i2v-14B` and `t2i-14B`.
182
+
183
+ `--attn_mode` is `torch`, `sdpa` (same as `torch`), `xformers`, `sageattn`, `flash2`, `flash` (same as `flash2`) or `flash3`. `torch` is the default. Other options require the corresponding library to be installed. `flash3` (Flash attention 3) is not tested.
184
+
185
+ `--fp8_t5` can be used to specify the T5 model in fp8 format. This option reduces memory usage for the T5 model.
186
+
187
+ `--negative_prompt` can be used to specify a negative prompt. If omitted, the default negative prompt is used.
188
+
189
+ `--flow_shift` can be used to specify the flow shift (default 3.0 for I2V with 480p, 5.0 for others).
190
+
191
+ `--guidance_scale` can be used to specify the guidance scale for classifier free guidance (default 5.0).
192
+
193
+ `--blocks_to_swap` is the number of blocks to swap during inference. The default value is None (no block swap). The maximum value is 39 for the 14B model and 29 for the 1.3B model.
194
+
195
+ `--vae_cache_cpu` enables VAE cache in main memory. This reduces VRAM usage slightly but processing is slower.
196
+
197
+ Other options are the same as for `hv_generate_video.py` (some options are not supported, please check the help).
198
+
199
+ <details>
200
+ <summary>日本語</summary>
201
+ `--task` には `t2v-1.3B`, `t2v-14B`, `i2v-14B`, `t2i-14B` のいずれかを指定します。
202
+
203
+ `--attn_mode` には `torch`, `sdpa`(`torch`と同じ)、`xformers`, `sageattn`, `flash2`, `flash`(`flash2`と同じ), `flash3` のいずれかを指定します。デフォルトは `torch` です。その他のオプションを使用する場合は、対応するライブラリをインストールする必要があります。`flash3`(Flash attention 3)は未テストです。
204
+
205
+ `--fp8_t5` を指定するとT5モデルをfp8形式で実行します。T5モデル呼び出し時のメモリ使用量を削減します。
206
+
207
+ `--negative_prompt` でネガティブプロンプトを指定できます。省略した場合はデフォルトのネガティブプロンプトが使用されます。
208
+
209
+ `--flow_shift` でflow shiftを指定できます(480pのI2Vの場合はデフォルト3.0、それ以外は5.0)。
210
+
211
+ `--guidance_scale` でclassifier free guidanceのガイダンススケールを指定できます(デフォルト5.0)。
212
+
213
+ `--blocks_to_swap` は推論時のblock swapの数です。デフォルト値はNone(block swapなし)です。最大値は14Bモデルの場合39、1.3Bモデルの場合29です。
214
+
215
+ `--vae_cache_cpu` を有効にすると、VAEのキャッシュをメインメモリに保持します。VRAM使用量が多少減りますが、処理は遅くなります。
216
+
217
+ その他のオプションは `hv_generate_video.py` と同じです(一部のオプションはサポートされていないため、ヘルプを確認してください)。
218
+ </details>
219
+
220
+ ### I2V Inference / I2V推論
221
+
222
+ The following is an example of I2V inference (input as a single line):
223
+
224
+ ```bash
225
+ python wan_generate_video.py --fp8 --task i2v-14B --video_size 832 480 --video_length 81 --infer_steps 20
226
+ --prompt "prompt for the video" --save_path path/to/save.mp4 --output_type both
227
+ --dit path/to/wan2.1_i2v_480p_14B_bf16_etc.safetensors --vae path/to/wan_2.1_vae.safetensors
228
+ --t5 path/to/models_t5_umt5-xxl-enc-bf16.pth --clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
229
+ --attn_mode torch --image_path path/to/image.jpg
230
+ ```
231
+
232
+ Add `--clip` to specify the CLIP model. `--image_path` is the path to the image to be used as the initial frame.
233
+
234
+ Other options are the same as for T2V inference.
235
+
236
+ <details>
237
+ <summary>日本語</summary>
238
+ `--clip` を追加してCLIPモデルを指定します。`--image_path` は初期フレームとして使用する画像のパスです。
239
+
240
+ その他のオプションはT2V推論と同じです。
241
+ </details>
hunyuan_model/__init__.py ADDED
File without changes
hunyuan_model/activation_layers.py ADDED
@@ -0,0 +1,23 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ def get_activation_layer(act_type):
5
+ """get activation layer
6
+
7
+ Args:
8
+ act_type (str): the activation type
9
+
10
+ Returns:
11
+ torch.nn.functional: the activation layer
12
+ """
13
+ if act_type == "gelu":
14
+ return lambda: nn.GELU()
15
+ elif act_type == "gelu_tanh":
16
+ # Approximate `tanh` requires torch >= 1.13
17
+ return lambda: nn.GELU(approximate="tanh")
18
+ elif act_type == "relu":
19
+ return nn.ReLU
20
+ elif act_type == "silu":
21
+ return nn.SiLU
22
+ else:
23
+ raise ValueError(f"Unknown activation type: {act_type}")
hunyuan_model/attention.py ADDED
@@ -0,0 +1,295 @@
1
+ import importlib.metadata
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ try:
9
+ import flash_attn
10
+ from flash_attn.flash_attn_interface import _flash_attn_forward
11
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
12
+ from flash_attn.flash_attn_interface import flash_attn_func
13
+ except ImportError:
14
+ flash_attn = None
15
+ flash_attn_varlen_func = None
16
+ _flash_attn_forward = None
17
+ flash_attn_func = None
18
+
19
+ try:
20
+ print(f"Trying to import sageattention")
21
+ from sageattention import sageattn_varlen, sageattn
22
+
23
+ print("Successfully imported sageattention")
24
+ except ImportError:
25
+ print(f"Failed to import sageattention")
26
+ sageattn_varlen = None
27
+ sageattn = None
28
+
29
+ try:
30
+ import xformers.ops as xops
31
+ except ImportError:
32
+ xops = None
33
+
34
+ MEMORY_LAYOUT = {
35
+ "flash": (
36
+ lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
37
+ lambda x: x,
38
+ ),
39
+ "flash_fixlen": (
40
+ lambda x: x,
41
+ lambda x: x,
42
+ ),
43
+ "sageattn": (
44
+ lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
45
+ lambda x: x,
46
+ ),
47
+ "sageattn_fixlen": (
48
+ lambda x: x.transpose(1, 2),
49
+ lambda x: x.transpose(1, 2),
50
+ ),
51
+ "torch": (
52
+ lambda x: x.transpose(1, 2),
53
+ lambda x: x.transpose(1, 2),
54
+ ),
55
+ "xformers": (
56
+ lambda x: x,
57
+ lambda x: x,
58
+ ),
59
+ "vanilla": (
60
+ lambda x: x.transpose(1, 2),
61
+ lambda x: x.transpose(1, 2),
62
+ ),
63
+ }
64
+
65
+
66
+ def get_cu_seqlens(text_mask, img_len):
67
+ """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len
68
+
69
+ Args:
70
+ text_mask (torch.Tensor): the mask of text
71
+ img_len (int): the length of image
72
+
73
+ Returns:
74
+ torch.Tensor: the calculated cu_seqlens for flash attention
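+         e.g. (illustrative) batch_size=1, img_len=100, text_mask of shape (1, 50) with 10 valid tokens -> [0, 110, 150]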
75
+ """
76
+ batch_size = text_mask.shape[0]
77
+ text_len = text_mask.sum(dim=1)
78
+ max_len = text_mask.shape[1] + img_len
79
+
80
+ cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
81
+
82
+ for i in range(batch_size):
83
+ s = text_len[i] + img_len
84
+ s1 = i * max_len + s
85
+ s2 = (i + 1) * max_len
86
+ cu_seqlens[2 * i + 1] = s1
87
+ cu_seqlens[2 * i + 2] = s2
88
+
89
+ return cu_seqlens
90
+
91
+
92
+ def attention(
93
+ q_or_qkv_list,
94
+ k=None,
95
+ v=None,
96
+ mode="flash",
97
+ drop_rate=0,
98
+ attn_mask=None,
99
+ total_len=None,
100
+ causal=False,
101
+ cu_seqlens_q=None,
102
+ cu_seqlens_kv=None,
103
+ max_seqlen_q=None,
104
+ max_seqlen_kv=None,
105
+ batch_size=1,
106
+ ):
107
+ """
108
+ Perform QKV self attention.
109
+
110
+ Args:
111
+ q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
112
+ k (torch.Tensor): Key tensor with shape [b, s1, a, d]
113
+ v (torch.Tensor): Value tensor with shape [b, s1, a, d]
114
+ mode (str): Attention mode. Choose from 'flash', 'torch', 'xformers', 'sageattn', and 'vanilla'.
115
+ drop_rate (float): Dropout rate in attention map. (default: 0)
116
+ attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
117
+ (default: None)
118
+ causal (bool): Whether to use causal attention. (default: False)
119
+ cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
120
+ used to index into q.
121
+ cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
122
+ used to index into kv.
123
+ max_seqlen_q (int): The maximum sequence length in the batch of q.
124
+ max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
125
+
126
+ Returns:
127
+ torch.Tensor: Output tensor after self attention with shape [b, s, ad]
128
+ """
129
+ q, k, v = q_or_qkv_list if type(q_or_qkv_list) == list else (q_or_qkv_list, k, v)
130
+ if type(q_or_qkv_list) == list:
131
+ q_or_qkv_list.clear()
132
+ split_attn = total_len is not None
133
+ if split_attn and mode == "sageattn":
134
+ mode = "sageattn_fixlen"
135
+ elif split_attn and mode == "flash":
136
+ mode = "flash_fixlen"
137
+ # print(f"Attention mode: {mode}, split_attn: {split_attn}")
138
+ pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
139
+
140
+ # trim the sequence length to the actual length instead of attn_mask
141
+ if split_attn:
142
+ trimmed_len = q.shape[1] - total_len
143
+ q = [q[i : i + 1, : total_len[i]] for i in range(len(q))]
144
+ k = [k[i : i + 1, : total_len[i]] for i in range(len(k))]
145
+ v = [v[i : i + 1, : total_len[i]] for i in range(len(v))]
146
+ q = [pre_attn_layout(q_i) for q_i in q]
147
+ k = [pre_attn_layout(k_i) for k_i in k]
148
+ v = [pre_attn_layout(v_i) for v_i in v]
149
+ # print(
150
+ # f"Trimming the sequence length to {total_len},trimmed_len: {trimmed_len}, q.shape: {[q_i.shape for q_i in q]}, mode: {mode}"
151
+ # )
152
+ else:
153
+ q = pre_attn_layout(q)
154
+ k = pre_attn_layout(k)
155
+ v = pre_attn_layout(v)
156
+
157
+ if mode == "torch":
158
+ if split_attn:
159
+ x = []
160
+ for i in range(len(q)):
161
+ x_i = F.scaled_dot_product_attention(q[i], k[i], v[i], dropout_p=drop_rate, is_causal=causal)
162
+ q[i], k[i], v[i] = None, None, None
163
+ x.append(x_i)
164
+ del q, k, v
165
+ else:
166
+ if attn_mask is not None and attn_mask.dtype != torch.bool:
167
+ attn_mask = attn_mask.to(q.dtype)
168
+ x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
169
+ del q, k, v
170
+ del attn_mask
171
+
172
+ elif mode == "xformers":
173
+ # B, M, H, K: M is the sequence length, H is the number of heads, K is the dimension of the heads -> it is same as input dimension
174
+ # currently only support batch_size = 1
175
+ assert split_attn, "Xformers only supports splitting"
176
+ x = []
177
+ for i in range(len(q)):
178
+ x_i = xops.memory_efficient_attention(q[i], k[i], v[i], p=drop_rate) # , causal=causal)
179
+ q[i], k[i], v[i] = None, None, None
180
+ x.append(x_i)
181
+ del q, k, v
182
+
183
+ elif mode == "flash":
184
+ x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
185
+ del q, k, v
186
+ # x with shape [(bxs), a, d]
187
+ x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
188
+ elif mode == "flash_fixlen":
189
+ x = []
190
+ for i in range(len(q)):
191
+ # q: (batch_size, seqlen, nheads, headdim), k: (batch_size, seqlen, nheads_k, headdim), v: (batch_size, seqlen, nheads_k, headdim)
192
+ x_i = flash_attn_func(q[i], k[i], v[i], dropout_p=drop_rate, causal=causal)
193
+ q[i], k[i], v[i] = None, None, None
194
+ x.append(x_i)
195
+ del q, k, v
196
+ elif mode == "sageattn":
197
+ x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
198
+ del q, k, v
199
+ # x with shape [(bxs), a, d]
200
+ x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
201
+ elif mode == "sageattn_fixlen":
202
+ x = []
203
+ for i in range(len(q)):
204
+ # HND seems to cause an error
205
+ x_i = sageattn(q[i], k[i], v[i]) # (batch_size, seq_len, head_num, head_dim)
206
+ q[i], k[i], v[i] = None, None, None
207
+ x.append(x_i)
208
+ del q, k, v
209
+ elif mode == "vanilla":
210
+ assert not split_attn, "Vanilla attention does not support trimming"
211
+ scale_factor = 1 / math.sqrt(q.size(-1))
212
+
213
+ b, a, s, _ = q.shape
214
+ s1 = k.size(2)
215
+ attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
216
+ if causal:
217
+ # Only applied to self attention
218
+ assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
219
+ temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
220
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
221
+ attn_bias.to(q.dtype)
222
+
223
+ if attn_mask is not None:
224
+ if attn_mask.dtype == torch.bool:
225
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
226
+ else:
227
+ attn_bias += attn_mask
228
+
229
+ # TODO: Maybe force q and k to be float32 to avoid numerical overflow
230
+ attn = (q @ k.transpose(-2, -1)) * scale_factor
231
+ attn += attn_bias
232
+ attn = attn.softmax(dim=-1)
233
+ attn = torch.dropout(attn, p=drop_rate, train=True)
234
+ x = attn @ v
235
+ else:
236
+ raise NotImplementedError(f"Unsupported attention mode: {mode}")
237
+
238
+ if split_attn:
239
+ x = [post_attn_layout(x_i) for x_i in x]
240
+ for i in range(len(x)):
241
+ x[i] = F.pad(x[i], (0, 0, 0, 0, 0, trimmed_len[i]))
242
+ x = torch.cat(x, dim=0)
243
+ else:
244
+ x = post_attn_layout(x)
245
+
246
+ b, s, a, d = x.shape
247
+ out = x.reshape(b, s, -1)
248
+ return out
249
+
250
+
251
+ def parallel_attention(hybrid_seq_parallel_attn, q, k, v, img_q_len, img_kv_len, cu_seqlens_q, cu_seqlens_kv):
252
+ attn1 = hybrid_seq_parallel_attn(
253
+ None,
254
+ q[:, :img_q_len, :, :],
255
+ k[:, :img_kv_len, :, :],
256
+ v[:, :img_kv_len, :, :],
257
+ dropout_p=0.0,
258
+ causal=False,
259
+ joint_tensor_query=q[:, img_q_len : cu_seqlens_q[1]],
260
+ joint_tensor_key=k[:, img_kv_len : cu_seqlens_kv[1]],
261
+ joint_tensor_value=v[:, img_kv_len : cu_seqlens_kv[1]],
262
+ joint_strategy="rear",
263
+ )
264
+ if flash_attn.__version__ >= "2.7.0":
265
+ attn2, *_ = _flash_attn_forward(
266
+ q[:, cu_seqlens_q[1] :],
267
+ k[:, cu_seqlens_kv[1] :],
268
+ v[:, cu_seqlens_kv[1] :],
269
+ dropout_p=0.0,
270
+ softmax_scale=q.shape[-1] ** (-0.5),
271
+ causal=False,
272
+ window_size_left=-1,
273
+ window_size_right=-1,
274
+ softcap=0.0,
275
+ alibi_slopes=None,
276
+ return_softmax=False,
277
+ )
278
+ else:
279
+ attn2, *_ = _flash_attn_forward(
280
+ q[:, cu_seqlens_q[1] :],
281
+ k[:, cu_seqlens_kv[1] :],
282
+ v[:, cu_seqlens_kv[1] :],
283
+ dropout_p=0.0,
284
+ softmax_scale=q.shape[-1] ** (-0.5),
285
+ causal=False,
286
+ window_size=(-1, -1),
287
+ softcap=0.0,
288
+ alibi_slopes=None,
289
+ return_softmax=False,
290
+ )
291
+ attn = torch.cat([attn1, attn2], dim=1)
292
+ b, s, a, d = attn.shape
293
+ attn = attn.reshape(b, s, -1)
294
+
295
+ return attn
hunyuan_model/autoencoder_kl_causal_3d.py ADDED
@@ -0,0 +1,609 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
16
+ # Modified from diffusers==0.29.2
17
+ #
18
+ # ==============================================================================
19
+ from typing import Dict, Optional, Tuple, Union
20
+ from dataclasses import dataclass
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+
25
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
26
+
27
+ # try:
28
+ # # This diffusers is modified and packed in the mirror.
29
+ # from diffusers.loaders import FromOriginalVAEMixin
30
+ # except ImportError:
31
+ # # Use this to be compatible with the original diffusers.
32
+ # from diffusers.loaders.single_file_model import FromOriginalModelMixin as FromOriginalVAEMixin
33
+ from diffusers.utils.accelerate_utils import apply_forward_hook
34
+ from diffusers.models.attention_processor import (
35
+ ADDED_KV_ATTENTION_PROCESSORS,
36
+ CROSS_ATTENTION_PROCESSORS,
37
+ Attention,
38
+ AttentionProcessor,
39
+ AttnAddedKVProcessor,
40
+ AttnProcessor,
41
+ )
42
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
43
+ from diffusers.models.modeling_utils import ModelMixin
44
+ from .vae import DecoderCausal3D, BaseOutput, DecoderOutput, DiagonalGaussianDistribution, EncoderCausal3D
45
+
46
+
47
+ @dataclass
48
+ class DecoderOutput2(BaseOutput):
49
+ sample: torch.FloatTensor
50
+ posterior: Optional[DiagonalGaussianDistribution] = None
51
+
52
+
53
+ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin):
54
+ r"""
55
+ A VAE model with KL loss for encoding images/videos into latents and decoding latent representations into images/videos.
56
+
57
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
58
+ for all models (such as downloading or saving).
59
+ """
60
+
61
+ _supports_gradient_checkpointing = True
62
+
63
+ @register_to_config
64
+ def __init__(
65
+ self,
66
+ in_channels: int = 3,
67
+ out_channels: int = 3,
68
+ down_block_types: Tuple[str] = ("DownEncoderBlockCausal3D",),
69
+ up_block_types: Tuple[str] = ("UpDecoderBlockCausal3D",),
70
+ block_out_channels: Tuple[int] = (64,),
71
+ layers_per_block: int = 1,
72
+ act_fn: str = "silu",
73
+ latent_channels: int = 4,
74
+ norm_num_groups: int = 32,
75
+ sample_size: int = 32,
76
+ sample_tsize: int = 64,
77
+ scaling_factor: float = 0.18215,
78
+ force_upcast: float = True,
79
+ spatial_compression_ratio: int = 8,
80
+ time_compression_ratio: int = 4,
81
+ mid_block_add_attention: bool = True,
82
+ ):
83
+ super().__init__()
84
+
85
+ self.time_compression_ratio = time_compression_ratio
86
+
87
+ self.encoder = EncoderCausal3D(
88
+ in_channels=in_channels,
89
+ out_channels=latent_channels,
90
+ down_block_types=down_block_types,
91
+ block_out_channels=block_out_channels,
92
+ layers_per_block=layers_per_block,
93
+ act_fn=act_fn,
94
+ norm_num_groups=norm_num_groups,
95
+ double_z=True,
96
+ time_compression_ratio=time_compression_ratio,
97
+ spatial_compression_ratio=spatial_compression_ratio,
98
+ mid_block_add_attention=mid_block_add_attention,
99
+ )
100
+
101
+ self.decoder = DecoderCausal3D(
102
+ in_channels=latent_channels,
103
+ out_channels=out_channels,
104
+ up_block_types=up_block_types,
105
+ block_out_channels=block_out_channels,
106
+ layers_per_block=layers_per_block,
107
+ norm_num_groups=norm_num_groups,
108
+ act_fn=act_fn,
109
+ time_compression_ratio=time_compression_ratio,
110
+ spatial_compression_ratio=spatial_compression_ratio,
111
+ mid_block_add_attention=mid_block_add_attention,
112
+ )
113
+
114
+ self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
115
+ self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)
116
+
117
+ self.use_slicing = False
118
+ self.use_spatial_tiling = False
119
+ self.use_temporal_tiling = False
120
+
121
+ # only relevant if vae tiling is enabled
122
+ self.tile_sample_min_tsize = sample_tsize
123
+ self.tile_latent_min_tsize = sample_tsize // time_compression_ratio
124
+
125
+ self.tile_sample_min_size = self.config.sample_size
126
+ sample_size = self.config.sample_size[0] if isinstance(self.config.sample_size, (list, tuple)) else self.config.sample_size
127
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
128
+ self.tile_overlap_factor = 0.25
129
+
130
+ def _set_gradient_checkpointing(self, module, value=False):
131
+ if isinstance(module, (EncoderCausal3D, DecoderCausal3D)):
132
+ module.gradient_checkpointing = value
133
+
134
+ def enable_temporal_tiling(self, use_tiling: bool = True):
135
+ self.use_temporal_tiling = use_tiling
136
+
137
+ def disable_temporal_tiling(self):
138
+ self.enable_temporal_tiling(False)
139
+
140
+ def enable_spatial_tiling(self, use_tiling: bool = True):
141
+ self.use_spatial_tiling = use_tiling
142
+
143
+ def disable_spatial_tiling(self):
144
+ self.enable_spatial_tiling(False)
145
+
146
+ def enable_tiling(self, use_tiling: bool = True):
147
+ r"""
148
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
149
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
150
+ processing larger videos.
151
+ """
152
+ self.enable_spatial_tiling(use_tiling)
153
+ self.enable_temporal_tiling(use_tiling)
154
+
155
+ def disable_tiling(self):
156
+ r"""
157
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
158
+ decoding in one step.
159
+ """
160
+ self.disable_spatial_tiling()
161
+ self.disable_temporal_tiling()
162
+
163
+ def enable_slicing(self):
164
+ r"""
165
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
166
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
167
+ """
168
+ self.use_slicing = True
169
+
170
+ def disable_slicing(self):
171
+ r"""
172
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
173
+ decoding in one step.
174
+ """
175
+ self.use_slicing = False
176
+
177
+ def set_chunk_size_for_causal_conv_3d(self, chunk_size: int):
178
+ # set chunk_size to CausalConv3d recursively
179
+ def set_chunk_size(module):
180
+ if hasattr(module, "chunk_size"):
181
+ module.chunk_size = chunk_size
182
+
183
+ self.apply(set_chunk_size)
184
+
185
+ @property
186
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
187
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
188
+ r"""
189
+ Returns:
190
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
191
+ indexed by its weight name.
192
+ """
193
+ # set recursively
194
+ processors = {}
195
+
196
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
197
+ if hasattr(module, "get_processor"):
198
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
199
+
200
+ for sub_name, child in module.named_children():
201
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
202
+
203
+ return processors
204
+
205
+ for name, module in self.named_children():
206
+ fn_recursive_add_processors(name, module, processors)
207
+
208
+ return processors
209
+
210
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
211
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False):
212
+ r"""
213
+ Sets the attention processor to use to compute attention.
214
+
215
+ Parameters:
216
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
217
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
218
+ for **all** `Attention` layers.
219
+
220
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
221
+ processor. This is strongly recommended when setting trainable attention processors.
222
+
223
+ """
224
+ count = len(self.attn_processors.keys())
225
+
226
+ if isinstance(processor, dict) and len(processor) != count:
227
+ raise ValueError(
228
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
229
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
230
+ )
231
+
232
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
233
+ if hasattr(module, "set_processor"):
234
+ if not isinstance(processor, dict):
235
+ module.set_processor(processor, _remove_lora=_remove_lora)
236
+ else:
237
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
238
+
239
+ for sub_name, child in module.named_children():
240
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
241
+
242
+ for name, module in self.named_children():
243
+ fn_recursive_attn_processor(name, module, processor)
244
+
245
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
246
+ def set_default_attn_processor(self):
247
+ """
248
+ Disables custom attention processors and sets the default attention implementation.
249
+ """
250
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
251
+ processor = AttnAddedKVProcessor()
252
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
253
+ processor = AttnProcessor()
254
+ else:
255
+ raise ValueError(
256
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
257
+ )
258
+
259
+ self.set_attn_processor(processor, _remove_lora=True)
260
+
261
+ @apply_forward_hook
262
+ def encode(
263
+ self, x: torch.FloatTensor, return_dict: bool = True
264
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
265
+ """
266
+ Encode a batch of images/videos into latents.
267
+
268
+ Args:
269
+ x (`torch.FloatTensor`): Input batch of images/videos.
270
+ return_dict (`bool`, *optional*, defaults to `True`):
271
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
272
+
273
+ Returns:
274
+ The latent representations of the encoded images/videos. If `return_dict` is True, a
275
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
276
+ """
277
+ assert len(x.shape) == 5, "The input tensor should have 5 dimensions."
278
+
279
+ if self.use_temporal_tiling and x.shape[2] > self.tile_sample_min_tsize:
280
+ return self.temporal_tiled_encode(x, return_dict=return_dict)
281
+
282
+ if self.use_spatial_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
283
+ return self.spatial_tiled_encode(x, return_dict=return_dict)
284
+
285
+ if self.use_slicing and x.shape[0] > 1:
286
+ encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
287
+ h = torch.cat(encoded_slices)
288
+ else:
289
+ h = self.encoder(x)
290
+
291
+ moments = self.quant_conv(h)
292
+ posterior = DiagonalGaussianDistribution(moments)
293
+
294
+ if not return_dict:
295
+ return (posterior,)
296
+
297
+ return AutoencoderKLOutput(latent_dist=posterior)
298
+
299
+ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
300
+ assert len(z.shape) == 5, "The input tensor should have 5 dimensions."
301
+
302
+ if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize:
303
+ return self.temporal_tiled_decode(z, return_dict=return_dict)
304
+
305
+ if self.use_spatial_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
306
+ return self.spatial_tiled_decode(z, return_dict=return_dict)
307
+
308
+ z = self.post_quant_conv(z)
309
+ dec = self.decoder(z)
310
+
311
+ if not return_dict:
312
+ return (dec,)
313
+
314
+ return DecoderOutput(sample=dec)
315
+
316
+ @apply_forward_hook
317
+ def decode(self, z: torch.FloatTensor, return_dict: bool = True, generator=None) -> Union[DecoderOutput, torch.FloatTensor]:
318
+ """
319
+ Decode a batch of images/videos.
320
+
321
+ Args:
322
+ z (`torch.FloatTensor`): Input batch of latent vectors.
323
+ return_dict (`bool`, *optional*, defaults to `True`):
324
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
325
+
326
+ Returns:
327
+ [`~models.vae.DecoderOutput`] or `tuple`:
328
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
329
+ returned.
330
+
331
+ """
332
+ if self.use_slicing and z.shape[0] > 1:
333
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
334
+ decoded = torch.cat(decoded_slices)
335
+ else:
336
+ decoded = self._decode(z).sample
337
+
338
+ if not return_dict:
339
+ return (decoded,)
340
+
341
+ return DecoderOutput(sample=decoded)
342
+
343
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
344
+ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
345
+ for y in range(blend_extent):
346
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
347
+ return b
348
+
349
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
350
+ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
351
+ for x in range(blend_extent):
352
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
353
+ return b
354
+
355
+ def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
356
+ blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
357
+ for x in range(blend_extent):
358
+ b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (x / blend_extent)
359
+ return b
360
+
361
+ def spatial_tiled_encode(
362
+ self, x: torch.FloatTensor, return_dict: bool = True, return_moments: bool = False
363
+ ) -> AutoencoderKLOutput:
364
+ r"""Encode a batch of images/videos using a tiled encoder.
365
+
366
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
367
+ steps. This is useful to keep memory use constant regardless of image/videos size. The end result of tiled encoding is
368
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
369
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
370
+ output, but they should be much less noticeable.
371
+
372
+ Args:
373
+ x (`torch.FloatTensor`): Input batch of images/videos.
374
+ return_dict (`bool`, *optional*, defaults to `True`):
375
+ Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
376
+
377
+ Returns:
378
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
379
+ If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
380
+ `tuple` is returned.
381
+ """
382
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
383
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
384
+ row_limit = self.tile_latent_min_size - blend_extent
385
+
386
+ # Split video into tiles and encode them separately.
387
+ rows = []
388
+ for i in range(0, x.shape[-2], overlap_size):
389
+ row = []
390
+ for j in range(0, x.shape[-1], overlap_size):
391
+ tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
392
+ tile = self.encoder(tile)
393
+ tile = self.quant_conv(tile)
394
+ row.append(tile)
395
+ rows.append(row)
396
+ result_rows = []
397
+ for i, row in enumerate(rows):
398
+ result_row = []
399
+ for j, tile in enumerate(row):
400
+ # blend the above tile and the left tile
401
+ # to the current tile and add the current tile to the result row
402
+ if i > 0:
403
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
404
+ if j > 0:
405
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
406
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
407
+ result_rows.append(torch.cat(result_row, dim=-1))
408
+
409
+ moments = torch.cat(result_rows, dim=-2)
410
+ if return_moments:
411
+ return moments
412
+
413
+ posterior = DiagonalGaussianDistribution(moments)
414
+ if not return_dict:
415
+ return (posterior,)
416
+
417
+ return AutoencoderKLOutput(latent_dist=posterior)
418
+
419
+ def spatial_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
420
+ r"""
421
+ Decode a batch of images/videos using a tiled decoder.
422
+
423
+ Args:
424
+ z (`torch.FloatTensor`): Input batch of latent vectors.
425
+ return_dict (`bool`, *optional*, defaults to `True`):
426
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
427
+
428
+ Returns:
429
+ [`~models.vae.DecoderOutput`] or `tuple`:
430
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
431
+ returned.
432
+ """
433
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
434
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
435
+ row_limit = self.tile_sample_min_size - blend_extent
436
+
437
+ # Split z into overlapping tiles and decode them separately.
438
+ # The tiles have an overlap to avoid seams between tiles.
439
+ rows = []
440
+ for i in range(0, z.shape[-2], overlap_size):
441
+ row = []
442
+ for j in range(0, z.shape[-1], overlap_size):
443
+ tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
444
+ tile = self.post_quant_conv(tile)
445
+ decoded = self.decoder(tile)
446
+ row.append(decoded)
447
+ rows.append(row)
448
+ result_rows = []
449
+ for i, row in enumerate(rows):
450
+ result_row = []
451
+ for j, tile in enumerate(row):
452
+ # blend the above tile and the left tile
453
+ # to the current tile and add the current tile to the result row
454
+ if i > 0:
455
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
456
+ if j > 0:
457
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
458
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
459
+ result_rows.append(torch.cat(result_row, dim=-1))
460
+
461
+ dec = torch.cat(result_rows, dim=-2)
462
+ if not return_dict:
463
+ return (dec,)
464
+
465
+ return DecoderOutput(sample=dec)
466
+
467
+ def temporal_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
468
+
469
+ B, C, T, H, W = x.shape
470
+ overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor))
471
+ blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor)
472
+ t_limit = self.tile_latent_min_tsize - blend_extent
473
+
474
+ # Split the video into tiles and encode them separately.
475
+ row = []
476
+ for i in range(0, T, overlap_size):
477
+ tile = x[:, :, i : i + self.tile_sample_min_tsize + 1, :, :]
478
+ if self.use_spatial_tiling and (
479
+ tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size
480
+ ):
481
+ tile = self.spatial_tiled_encode(tile, return_moments=True)
482
+ else:
483
+ tile = self.encoder(tile)
484
+ tile = self.quant_conv(tile)
485
+ if i > 0:
486
+ tile = tile[:, :, 1:, :, :]
487
+ row.append(tile)
488
+ result_row = []
489
+ for i, tile in enumerate(row):
490
+ if i > 0:
491
+ tile = self.blend_t(row[i - 1], tile, blend_extent)
492
+ result_row.append(tile[:, :, :t_limit, :, :])
493
+ else:
494
+ result_row.append(tile[:, :, : t_limit + 1, :, :])
495
+
496
+ moments = torch.cat(result_row, dim=2)
497
+ posterior = DiagonalGaussianDistribution(moments)
498
+
499
+ if not return_dict:
500
+ return (posterior,)
501
+
502
+ return AutoencoderKLOutput(latent_dist=posterior)
503
+
504
+ def temporal_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
505
+ # Split z into overlapping tiles and decode them separately.
506
+
507
+ B, C, T, H, W = z.shape
508
+ overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor))
509
+ blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor)
510
+ t_limit = self.tile_sample_min_tsize - blend_extent
511
+
512
+ row = []
513
+ for i in range(0, T, overlap_size):
514
+ tile = z[:, :, i : i + self.tile_latent_min_tsize + 1, :, :]
515
+ if self.use_spatial_tiling and (
516
+ tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size
517
+ ):
518
+ decoded = self.spatial_tiled_decode(tile, return_dict=True).sample
519
+ else:
520
+ tile = self.post_quant_conv(tile)
521
+ decoded = self.decoder(tile)
522
+ if i > 0:
523
+ decoded = decoded[:, :, 1:, :, :]
524
+ row.append(decoded)
525
+ result_row = []
526
+ for i, tile in enumerate(row):
527
+ if i > 0:
528
+ tile = self.blend_t(row[i - 1], tile, blend_extent)
529
+ result_row.append(tile[:, :, :t_limit, :, :])
530
+ else:
531
+ result_row.append(tile[:, :, : t_limit + 1, :, :])
532
+
533
+ dec = torch.cat(result_row, dim=2)
534
+ if not return_dict:
535
+ return (dec,)
536
+
537
+ return DecoderOutput(sample=dec)
538
+
539
+ def forward(
540
+ self,
541
+ sample: torch.FloatTensor,
542
+ sample_posterior: bool = False,
543
+ return_dict: bool = True,
544
+ return_posterior: bool = False,
545
+ generator: Optional[torch.Generator] = None,
546
+ ) -> Union[DecoderOutput2, torch.FloatTensor]:
547
+ r"""
548
+ Args:
549
+ sample (`torch.FloatTensor`): Input sample.
550
+ sample_posterior (`bool`, *optional*, defaults to `False`):
551
+ Whether to sample from the posterior.
552
+ return_dict (`bool`, *optional*, defaults to `True`):
553
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
554
+ """
555
+ x = sample
556
+ posterior = self.encode(x).latent_dist
557
+ if sample_posterior:
558
+ z = posterior.sample(generator=generator)
559
+ else:
560
+ z = posterior.mode()
561
+ dec = self.decode(z).sample
562
+
563
+ if not return_dict:
564
+ if return_posterior:
565
+ return (dec, posterior)
566
+ else:
567
+ return (dec,)
568
+ if return_posterior:
569
+ return DecoderOutput2(sample=dec, posterior=posterior)
570
+ else:
571
+ return DecoderOutput2(sample=dec)
572
+
573
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
574
+ def fuse_qkv_projections(self):
575
+ """
576
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
577
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
578
+
579
+ <Tip warning={true}>
580
+
581
+ This API is 🧪 experimental.
582
+
583
+ </Tip>
584
+ """
585
+ self.original_attn_processors = None
586
+
587
+ for _, attn_processor in self.attn_processors.items():
588
+ if "Added" in str(attn_processor.__class__.__name__):
589
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
590
+
591
+ self.original_attn_processors = self.attn_processors
592
+
593
+ for module in self.modules():
594
+ if isinstance(module, Attention):
595
+ module.fuse_projections(fuse=True)
596
+
597
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
598
+ def unfuse_qkv_projections(self):
599
+ """Disables the fused QKV projection if enabled.
600
+
601
+ <Tip warning={true}>
602
+
603
+ This API is 🧪 experimental.
604
+
605
+ </Tip>
606
+
607
+ """
608
+ if self.original_attn_processors is not None:
609
+ self.set_attn_processor(self.original_attn_processors)
hunyuan_model/embed_layers.py ADDED
@@ -0,0 +1,132 @@
1
+ import collections
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange, repeat
6
+
7
+ from .helpers import to_2tuple
8
+
9
+ class PatchEmbed(nn.Module):
10
+ """2D Image to Patch Embedding
11
+
12
+ Image to Patch Embedding using Conv2d
13
+
14
+ A convolution based approach to patchifying a 2D image w/ embedding projection.
15
+
16
+ Based on the impl in https://github.com/google-research/vision_transformer
17
+
18
+ Hacked together by / Copyright 2020 Ross Wightman
19
+
20
+ Remove the _assert function in forward function to be compatible with multi-resolution images.
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ patch_size=16,
26
+ in_chans=3,
27
+ embed_dim=768,
28
+ norm_layer=None,
29
+ flatten=True,
30
+ bias=True,
31
+ dtype=None,
32
+ device=None,
33
+ ):
34
+ factory_kwargs = {"dtype": dtype, "device": device}
35
+ super().__init__()
36
+ patch_size = to_2tuple(patch_size)
37
+ self.patch_size = patch_size
38
+ self.flatten = flatten
39
+
40
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs)
41
+ nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
42
+ if bias:
43
+ nn.init.zeros_(self.proj.bias)
44
+
45
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
46
+
47
+ def forward(self, x):
48
+ x = self.proj(x)
49
+ if self.flatten:
50
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
51
+ x = self.norm(x)
52
+ return x
53
+
54
+
55
+ class TextProjection(nn.Module):
56
+ """
57
+ Projects text embeddings. Also handles dropout for classifier-free guidance.
58
+
59
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
60
+ """
61
+
62
+ def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
63
+ factory_kwargs = {"dtype": dtype, "device": device}
64
+ super().__init__()
65
+ self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs)
66
+ self.act_1 = act_layer()
67
+ self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs)
68
+
69
+ def forward(self, caption):
70
+ hidden_states = self.linear_1(caption)
71
+ hidden_states = self.act_1(hidden_states)
72
+ hidden_states = self.linear_2(hidden_states)
73
+ return hidden_states
74
+
75
+
76
+ def timestep_embedding(t, dim, max_period=10000):
77
+ """
78
+ Create sinusoidal timestep embeddings.
79
+
80
+ Args:
81
+ t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
82
+ dim (int): the dimension of the output.
83
+ max_period (int): controls the minimum frequency of the embeddings.
84
+
85
+ Returns:
86
+ embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
87
+
88
+ .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
89
+ """
90
+ half = dim // 2
91
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
92
+ args = t[:, None].float() * freqs[None]
93
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
94
+ if dim % 2:
95
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
96
+ return embedding
97
+
98
+
99
+ class TimestepEmbedder(nn.Module):
100
+ """
101
+ Embeds scalar timesteps into vector representations.
102
+ """
103
+
104
+ def __init__(
105
+ self,
106
+ hidden_size,
107
+ act_layer,
108
+ frequency_embedding_size=256,
109
+ max_period=10000,
110
+ out_size=None,
111
+ dtype=None,
112
+ device=None,
113
+ ):
114
+ factory_kwargs = {"dtype": dtype, "device": device}
115
+ super().__init__()
116
+ self.frequency_embedding_size = frequency_embedding_size
117
+ self.max_period = max_period
118
+ if out_size is None:
119
+ out_size = hidden_size
120
+
121
+ self.mlp = nn.Sequential(
122
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
123
+ act_layer(),
124
+ nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
125
+ )
126
+ nn.init.normal_(self.mlp[0].weight, std=0.02)
127
+ nn.init.normal_(self.mlp[2].weight, std=0.02)
128
+
129
+ def forward(self, t):
130
+ t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
131
+ t_emb = self.mlp(t_freq)
132
+ return t_emb
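`timestep_embedding` above is the standard sinusoidal embedding (cos/sin pairs over geometrically spaced frequencies), and `TimestepEmbedder` pushes it through a two-layer MLP. A quick shape check, assuming the repository root is on `PYTHONPATH` (tensor values are arbitrary):

```python
# Sketch: shape check for the sinusoidal timestep embedding defined above.
import torch
from hunyuan_model.embed_layers import TimestepEmbedder, timestep_embedding

t = torch.tensor([0.0, 250.0, 999.0])         # one (possibly fractional) timestep per batch element
print(timestep_embedding(t, dim=256).shape)   # torch.Size([3, 256])

embedder = TimestepEmbedder(hidden_size=512, act_layer=torch.nn.SiLU)
print(embedder(t).shape)                      # torch.Size([3, 512])
```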
hunyuan_model/helpers.py ADDED
@@ -0,0 +1,40 @@
+ import collections.abc
+
+ from itertools import repeat
+
+
+ def _ntuple(n):
+     def parse(x):
+         if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+             x = tuple(x)
+             if len(x) == 1:
+                 x = tuple(repeat(x[0], n))
+             return x
+         return tuple(repeat(x, n))
+     return parse
+
+
+ to_1tuple = _ntuple(1)
+ to_2tuple = _ntuple(2)
+ to_3tuple = _ntuple(3)
+ to_4tuple = _ntuple(4)
+
+
+ def as_tuple(x):
+     if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+         return tuple(x)
+     if x is None or isinstance(x, (int, float, str)):
+         return (x,)
+     else:
+         raise ValueError(f"Unknown type {type(x)}")
+
+
+ def as_list_of_2tuple(x):
+     x = as_tuple(x)
+     if len(x) == 1:
+         x = (x[0], x[0])
+     assert len(x) % 2 == 0, f"Expect even length, got {len(x)}."
+     lst = []
+     for i in range(0, len(x), 2):
+         lst.append((x[i], x[i + 1]))
+     return lst
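The helpers above normalize scalars and iterables into fixed-length tuples (`to_2tuple` is what `PatchEmbed` uses for `patch_size`). A short behaviour sketch, assuming the repository root is importable:

```python
# Sketch: expected behaviour of the tuple helpers defined above.
from hunyuan_model.helpers import as_list_of_2tuple, to_2tuple

print(to_2tuple(16))                     # (16, 16)
print(to_2tuple((1, 2, 2)))              # (1, 2, 2) -- iterables longer than 1 pass through
print(as_list_of_2tuple(3))              # [(3, 3)]
print(as_list_of_2tuple((1, 2, 3, 4)))   # [(1, 2), (3, 4)]
```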
hunyuan_model/mlp_layers.py ADDED
@@ -0,0 +1,118 @@
1
+ # Modified from timm library:
2
+ # https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
3
+
4
+ from functools import partial
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from .modulate_layers import modulate
10
+ from .helpers import to_2tuple
11
+
12
+
13
+ class MLP(nn.Module):
14
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
15
+
16
+ def __init__(
17
+ self,
18
+ in_channels,
19
+ hidden_channels=None,
20
+ out_features=None,
21
+ act_layer=nn.GELU,
22
+ norm_layer=None,
23
+ bias=True,
24
+ drop=0.0,
25
+ use_conv=False,
26
+ device=None,
27
+ dtype=None,
28
+ ):
29
+ factory_kwargs = {"device": device, "dtype": dtype}
30
+ super().__init__()
31
+ out_features = out_features or in_channels
32
+ hidden_channels = hidden_channels or in_channels
33
+ bias = to_2tuple(bias)
34
+ drop_probs = to_2tuple(drop)
35
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
36
+
37
+ self.fc1 = linear_layer(
38
+ in_channels, hidden_channels, bias=bias[0], **factory_kwargs
39
+ )
40
+ self.act = act_layer()
41
+ self.drop1 = nn.Dropout(drop_probs[0])
42
+ self.norm = (
43
+ norm_layer(hidden_channels, **factory_kwargs)
44
+ if norm_layer is not None
45
+ else nn.Identity()
46
+ )
47
+ self.fc2 = linear_layer(
48
+ hidden_channels, out_features, bias=bias[1], **factory_kwargs
49
+ )
50
+ self.drop2 = nn.Dropout(drop_probs[1])
51
+
52
+ def forward(self, x):
53
+ x = self.fc1(x)
54
+ x = self.act(x)
55
+ x = self.drop1(x)
56
+ x = self.norm(x)
57
+ x = self.fc2(x)
58
+ x = self.drop2(x)
59
+ return x
60
+
61
+
62
+ #
63
+ class MLPEmbedder(nn.Module):
64
+ """copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py"""
65
+ def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
66
+ factory_kwargs = {"device": device, "dtype": dtype}
67
+ super().__init__()
68
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
69
+ self.silu = nn.SiLU()
70
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)
71
+
72
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
73
+ return self.out_layer(self.silu(self.in_layer(x)))
74
+
75
+
76
+ class FinalLayer(nn.Module):
77
+ """The final layer of DiT."""
78
+
79
+ def __init__(
80
+ self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None
81
+ ):
82
+ factory_kwargs = {"device": device, "dtype": dtype}
83
+ super().__init__()
84
+
85
+ # Just use LayerNorm for the final layer
86
+ self.norm_final = nn.LayerNorm(
87
+ hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
88
+ )
89
+ if isinstance(patch_size, int):
90
+ self.linear = nn.Linear(
91
+ hidden_size,
92
+ patch_size * patch_size * out_channels,
93
+ bias=True,
94
+ **factory_kwargs
95
+ )
96
+ else:
97
+ self.linear = nn.Linear(
98
+ hidden_size,
99
+ patch_size[0] * patch_size[1] * patch_size[2] * out_channels,
100
+ bias=True,
101
+ )
102
+ nn.init.zeros_(self.linear.weight)
103
+ nn.init.zeros_(self.linear.bias)
104
+
105
+ # Here we don't distinguish between the modulate types. Just use the simple one.
106
+ self.adaLN_modulation = nn.Sequential(
107
+ act_layer(),
108
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
109
+ )
110
+ # Zero-initialize the modulation
111
+ nn.init.zeros_(self.adaLN_modulation[1].weight)
112
+ nn.init.zeros_(self.adaLN_modulation[1].bias)
113
+
114
+ def forward(self, x, c):
115
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
116
+ x = modulate(self.norm_final(x), shift=shift, scale=scale)
117
+ x = self.linear(x)
118
+ return x
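`MLP`, `MLPEmbedder`, and `FinalLayer` above are plain shape-preserving or projecting modules; the only subtlety is that `FinalLayer` expects a per-sample conditioning vector `c` of size `hidden_size` for its adaLN modulation. A shape sketch (assumes the repository root is importable and the standard DiT-style broadcasting in `modulate`):

```python
# Sketch: expected output shapes for the layers defined above.
import torch
import torch.nn as nn
from hunyuan_model.mlp_layers import MLP, MLPEmbedder, FinalLayer

x = torch.randn(2, 16, 128)                                  # (batch, tokens, hidden_size)
print(MLP(in_channels=128, hidden_channels=512)(x).shape)    # torch.Size([2, 16, 128])

vec = MLPEmbedder(in_dim=768, hidden_dim=128)(torch.randn(2, 768))   # (2, 128) conditioning
final = FinalLayer(hidden_size=128, patch_size=(1, 2, 2), out_channels=4, act_layer=nn.SiLU)
print(final(x, vec).shape)                                   # torch.Size([2, 16, 16]) = 1*2*2*4 per token
```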
hunyuan_model/models.py ADDED
@@ -0,0 +1,1044 @@
1
+ import os
2
+ from typing import Any, List, Tuple, Optional, Union, Dict
3
+ import accelerate
4
+ from einops import rearrange
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.utils.checkpoint import checkpoint
9
+
10
+ from .activation_layers import get_activation_layer
11
+ from .norm_layers import get_norm_layer
12
+ from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection
13
+ from .attention import attention, parallel_attention, get_cu_seqlens
14
+ from .posemb_layers import apply_rotary_emb
15
+ from .mlp_layers import MLP, MLPEmbedder, FinalLayer
16
+ from .modulate_layers import ModulateDiT, modulate, apply_gate
17
+ from .token_refiner import SingleTokenRefiner
18
+ from modules.custom_offloading_utils import ModelOffloader, synchronize_device, clean_memory_on_device
19
+ from hunyuan_model.posemb_layers import get_nd_rotary_pos_embed
20
+
21
+ from utils.safetensors_utils import MemoryEfficientSafeOpen
22
+
23
+
24
+ class MMDoubleStreamBlock(nn.Module):
25
+ """
26
+ A multimodal DiT block with separate modulation for
27
+ text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206
28
+ (Flux.1): https://github.com/black-forest-labs/flux
29
+ """
30
+
31
+ def __init__(
32
+ self,
33
+ hidden_size: int,
34
+ heads_num: int,
35
+ mlp_width_ratio: float,
36
+ mlp_act_type: str = "gelu_tanh",
37
+ qk_norm: bool = True,
38
+ qk_norm_type: str = "rms",
39
+ qkv_bias: bool = False,
40
+ dtype: Optional[torch.dtype] = None,
41
+ device: Optional[torch.device] = None,
42
+ attn_mode: str = "flash",
43
+ split_attn: bool = False,
44
+ ):
45
+ factory_kwargs = {"device": device, "dtype": dtype}
46
+ super().__init__()
47
+ self.attn_mode = attn_mode
48
+ self.split_attn = split_attn
49
+
50
+ self.deterministic = False
51
+ self.heads_num = heads_num
52
+ head_dim = hidden_size // heads_num
53
+ mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
54
+
55
+ self.img_mod = ModulateDiT(
56
+ hidden_size,
57
+ factor=6,
58
+ act_layer=get_activation_layer("silu"),
59
+ **factory_kwargs,
60
+ )
61
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
62
+
63
+ self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
64
+ qk_norm_layer = get_norm_layer(qk_norm_type)
65
+ self.img_attn_q_norm = (
66
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
67
+ )
68
+ self.img_attn_k_norm = (
69
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
70
+ )
71
+ self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
72
+
73
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
74
+ self.img_mlp = MLP(
75
+ hidden_size,
76
+ mlp_hidden_dim,
77
+ act_layer=get_activation_layer(mlp_act_type),
78
+ bias=True,
79
+ **factory_kwargs,
80
+ )
81
+
82
+ self.txt_mod = ModulateDiT(
83
+ hidden_size,
84
+ factor=6,
85
+ act_layer=get_activation_layer("silu"),
86
+ **factory_kwargs,
87
+ )
88
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
89
+
90
+ self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
91
+ self.txt_attn_q_norm = (
92
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
93
+ )
94
+ self.txt_attn_k_norm = (
95
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
96
+ )
97
+ self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
98
+
99
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
100
+ self.txt_mlp = MLP(
101
+ hidden_size,
102
+ mlp_hidden_dim,
103
+ act_layer=get_activation_layer(mlp_act_type),
104
+ bias=True,
105
+ **factory_kwargs,
106
+ )
107
+ self.hybrid_seq_parallel_attn = None
108
+
109
+ self.gradient_checkpointing = False
110
+
111
+ def enable_deterministic(self):
112
+ self.deterministic = True
113
+
114
+ def disable_deterministic(self):
115
+ self.deterministic = False
116
+
117
+ def enable_gradient_checkpointing(self):
118
+ self.gradient_checkpointing = True
119
+
120
+ def disable_gradient_checkpointing(self):
121
+ self.gradient_checkpointing = False
122
+
123
+ def _forward(
124
+ self,
125
+ img: torch.Tensor,
126
+ txt: torch.Tensor,
127
+ vec: torch.Tensor,
128
+ attn_mask: Optional[torch.Tensor] = None,
129
+ total_len: Optional[torch.Tensor] = None,
130
+ cu_seqlens_q: Optional[torch.Tensor] = None,
131
+ cu_seqlens_kv: Optional[torch.Tensor] = None,
132
+ max_seqlen_q: Optional[int] = None,
133
+ max_seqlen_kv: Optional[int] = None,
134
+ freqs_cis: tuple = None,
135
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
136
+ (img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate) = self.img_mod(vec).chunk(
137
+ 6, dim=-1
138
+ )
139
+ (txt_mod1_shift, txt_mod1_scale, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate) = self.txt_mod(vec).chunk(
140
+ 6, dim=-1
141
+ )
142
+
143
+ # Prepare image for attention.
144
+ img_modulated = self.img_norm1(img)
145
+ img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)
146
+ img_qkv = self.img_attn_qkv(img_modulated)
147
+ img_modulated = None
148
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
149
+ img_qkv = None
150
+ # Apply QK-Norm if needed
151
+ img_q = self.img_attn_q_norm(img_q).to(img_v)
152
+ img_k = self.img_attn_k_norm(img_k).to(img_v)
153
+
154
+ # Apply RoPE if needed.
155
+ if freqs_cis is not None:
156
+ img_q_shape = img_q.shape
157
+ img_k_shape = img_k.shape
158
+ img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
159
+ assert (
160
+ img_q.shape == img_q_shape and img_k.shape == img_k_shape
161
+ ), f"img_kk: {img_q.shape}, img_q: {img_q_shape}, img_kk: {img_k.shape}, img_k: {img_k_shape}"
162
+ # img_q, img_k = img_qq, img_kk
163
+
164
+ # Prepare txt for attention.
165
+ txt_modulated = self.txt_norm1(txt)
166
+ txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)
167
+ txt_qkv = self.txt_attn_qkv(txt_modulated)
168
+ txt_modulated = None
169
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
170
+ txt_qkv = None
171
+ # Apply QK-Norm if needed.
172
+ txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
173
+ txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
174
+
175
+ # Run actual attention.
176
+ img_q_len = img_q.shape[1]
177
+ img_kv_len = img_k.shape[1]
178
+ batch_size = img_k.shape[0]
179
+ q = torch.cat((img_q, txt_q), dim=1)
180
+ img_q = txt_q = None
181
+ k = torch.cat((img_k, txt_k), dim=1)
182
+ img_k = txt_k = None
183
+ v = torch.cat((img_v, txt_v), dim=1)
184
+ img_v = txt_v = None
185
+
186
+ assert (
187
+ cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
188
+ ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}"
189
+
190
+ # attention computation start
191
+ if not self.hybrid_seq_parallel_attn:
192
+ l = [q, k, v]
193
+ q = k = v = None
194
+ attn = attention(
195
+ l,
196
+ mode=self.attn_mode,
197
+ attn_mask=attn_mask,
198
+ total_len=total_len,
199
+ cu_seqlens_q=cu_seqlens_q,
200
+ cu_seqlens_kv=cu_seqlens_kv,
201
+ max_seqlen_q=max_seqlen_q,
202
+ max_seqlen_kv=max_seqlen_kv,
203
+ batch_size=batch_size,
204
+ )
205
+ else:
206
+ attn = parallel_attention(
207
+ self.hybrid_seq_parallel_attn,
208
+ q,
209
+ k,
210
+ v,
211
+ img_q_len=img_q_len,
212
+ img_kv_len=img_kv_len,
213
+ cu_seqlens_q=cu_seqlens_q,
214
+ cu_seqlens_kv=cu_seqlens_kv,
215
+ )
216
+
217
+ # attention computation end
218
+
219
+ img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
220
+ attn = None
221
+
222
+ # Calculate the img blocks.
223
+ img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
224
+ img_attn = None
225
+ img = img + apply_gate(
226
+ self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)),
227
+ gate=img_mod2_gate,
228
+ )
229
+
230
+ # Calculate the txt blocks.
231
+ txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
232
+ txt_attn = None
233
+ txt = txt + apply_gate(
234
+ self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)),
235
+ gate=txt_mod2_gate,
236
+ )
237
+
238
+ return img, txt
239
+
240
+ # def forward(
241
+ # self,
242
+ # img: torch.Tensor,
243
+ # txt: torch.Tensor,
244
+ # vec: torch.Tensor,
245
+ # attn_mask: Optional[torch.Tensor] = None,
246
+ # cu_seqlens_q: Optional[torch.Tensor] = None,
247
+ # cu_seqlens_kv: Optional[torch.Tensor] = None,
248
+ # max_seqlen_q: Optional[int] = None,
249
+ # max_seqlen_kv: Optional[int] = None,
250
+ # freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
251
+ # ) -> Tuple[torch.Tensor, torch.Tensor]:
252
+ def forward(self, *args, **kwargs):
253
+ if self.training and self.gradient_checkpointing:
254
+ return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
255
+ else:
256
+ return self._forward(*args, **kwargs)
257
+
258
+
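The `forward` wrapper above (and the identical one in `MMSingleStreamBlock` below) only routes through `torch.utils.checkpoint` while training with checkpointing enabled, so inference pays no recompute cost. A standalone illustration of that pattern (not part of the diff; the class name is hypothetical):

```python
# Sketch: the train-only gradient-checkpointing dispatch used by the blocks here.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.gradient_checkpointing = True
        self.body = nn.Linear(8, 8)

    def _forward(self, x):
        return torch.relu(self.body(x))

    def forward(self, *args, **kwargs):
        if self.training and self.gradient_checkpointing:
            # Activations inside _forward are recomputed during backward instead of stored.
            return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
        return self._forward(*args, **kwargs)


print(CheckpointedBlock()(torch.randn(2, 8, requires_grad=True)).shape)  # torch.Size([2, 8])
```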
259
+ class MMSingleStreamBlock(nn.Module):
260
+ """
261
+ A DiT block with parallel linear layers as described in
262
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
263
+ Also refer to (SD3): https://arxiv.org/abs/2403.03206
264
+ (Flux.1): https://github.com/black-forest-labs/flux
265
+ """
266
+
267
+ def __init__(
268
+ self,
269
+ hidden_size: int,
270
+ heads_num: int,
271
+ mlp_width_ratio: float = 4.0,
272
+ mlp_act_type: str = "gelu_tanh",
273
+ qk_norm: bool = True,
274
+ qk_norm_type: str = "rms",
275
+ qk_scale: float = None,
276
+ dtype: Optional[torch.dtype] = None,
277
+ device: Optional[torch.device] = None,
278
+ attn_mode: str = "flash",
279
+ split_attn: bool = False,
280
+ ):
281
+ factory_kwargs = {"device": device, "dtype": dtype}
282
+ super().__init__()
283
+ self.attn_mode = attn_mode
284
+ self.split_attn = split_attn
285
+
286
+ self.deterministic = False
287
+ self.hidden_size = hidden_size
288
+ self.heads_num = heads_num
289
+ head_dim = hidden_size // heads_num
290
+ mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
291
+ self.mlp_hidden_dim = mlp_hidden_dim
292
+ self.scale = qk_scale or head_dim**-0.5
293
+
294
+ # qkv and mlp_in
295
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs)
296
+ # proj and mlp_out
297
+ self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs)
298
+
299
+ qk_norm_layer = get_norm_layer(qk_norm_type)
300
+ self.q_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
301
+ self.k_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
302
+
303
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
304
+
305
+ self.mlp_act = get_activation_layer(mlp_act_type)()
306
+ self.modulation = ModulateDiT(hidden_size, factor=3, act_layer=get_activation_layer("silu"), **factory_kwargs)
307
+ self.hybrid_seq_parallel_attn = None
308
+
309
+ self.gradient_checkpointing = False
310
+
311
+ def enable_deterministic(self):
312
+ self.deterministic = True
313
+
314
+ def disable_deterministic(self):
315
+ self.deterministic = False
316
+
317
+ def enable_gradient_checkpointing(self):
318
+ self.gradient_checkpointing = True
319
+
320
+ def disable_gradient_checkpointing(self):
321
+ self.gradient_checkpointing = False
322
+
323
+ def _forward(
324
+ self,
325
+ x: torch.Tensor,
326
+ vec: torch.Tensor,
327
+ txt_len: int,
328
+ attn_mask: Optional[torch.Tensor] = None,
329
+ total_len: Optional[torch.Tensor] = None,
330
+ cu_seqlens_q: Optional[torch.Tensor] = None,
331
+ cu_seqlens_kv: Optional[torch.Tensor] = None,
332
+ max_seqlen_q: Optional[int] = None,
333
+ max_seqlen_kv: Optional[int] = None,
334
+ freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
335
+ ) -> torch.Tensor:
336
+ mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
337
+ x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
338
+ qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
339
+ x_mod = None
340
+ # mlp = mlp.to("cpu", non_blocking=True)
341
+ # clean_memory_on_device(x.device)
342
+
343
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
344
+ qkv = None
345
+
346
+ # Apply QK-Norm if needed.
347
+ q = self.q_norm(q).to(v)
348
+ k = self.k_norm(k).to(v)
349
+
350
+ # Apply RoPE if needed.
351
+ if freqs_cis is not None:
352
+ img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
353
+ img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
354
+ q = k = None
355
+ img_q_shape = img_q.shape
356
+ img_k_shape = img_k.shape
357
+ img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
358
+ assert (
359
+ img_q.shape == img_q_shape and img_k_shape == img_k.shape
360
+ ), f"img_kk: {img_q.shape}, img_q: {img_q.shape}, img_kk: {img_k.shape}, img_k: {img_k.shape}"
361
+ # img_q, img_k = img_qq, img_kk
362
+ # del img_qq, img_kk
363
+ q = torch.cat((img_q, txt_q), dim=1)
364
+ k = torch.cat((img_k, txt_k), dim=1)
365
+ del img_q, txt_q, img_k, txt_k
366
+
367
+ # Compute attention.
368
+ assert cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1, f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}"
369
+
370
+ # attention computation start
371
+ if not self.hybrid_seq_parallel_attn:
372
+ l = [q, k, v]
373
+ q = k = v = None
374
+ attn = attention(
375
+ l,
376
+ mode=self.attn_mode,
377
+ attn_mask=attn_mask,
378
+ total_len=total_len,
379
+ cu_seqlens_q=cu_seqlens_q,
380
+ cu_seqlens_kv=cu_seqlens_kv,
381
+ max_seqlen_q=max_seqlen_q,
382
+ max_seqlen_kv=max_seqlen_kv,
383
+ batch_size=x.shape[0],
384
+ )
385
+ else:
386
+ attn = parallel_attention(
387
+ self.hybrid_seq_parallel_attn,
388
+ q,
389
+ k,
390
+ v,
391
+ img_q_len=img_q.shape[1],
392
+ img_kv_len=img_k.shape[1],
393
+ cu_seqlens_q=cu_seqlens_q,
394
+ cu_seqlens_kv=cu_seqlens_kv,
395
+ )
396
+ # attention computation end
397
+
398
+ # Compute activation in mlp stream, cat again and run second linear layer.
399
+ # mlp = mlp.to(x.device)
400
+ mlp = self.mlp_act(mlp)
401
+ attn_mlp = torch.cat((attn, mlp), 2)
402
+ attn = None
403
+ mlp = None
404
+ output = self.linear2(attn_mlp)
405
+ attn_mlp = None
406
+ return x + apply_gate(output, gate=mod_gate)
407
+
408
+ # def forward(
409
+ # self,
410
+ # x: torch.Tensor,
411
+ # vec: torch.Tensor,
412
+ # txt_len: int,
413
+ # attn_mask: Optional[torch.Tensor] = None,
414
+ # cu_seqlens_q: Optional[torch.Tensor] = None,
415
+ # cu_seqlens_kv: Optional[torch.Tensor] = None,
416
+ # max_seqlen_q: Optional[int] = None,
417
+ # max_seqlen_kv: Optional[int] = None,
418
+ # freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
419
+ # ) -> torch.Tensor:
420
+ def forward(self, *args, **kwargs):
421
+ if self.training and self.gradient_checkpointing:
422
+ return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
423
+ else:
424
+ return self._forward(*args, **kwargs)
425
+
426
+
427
+ class HYVideoDiffusionTransformer(nn.Module): # ModelMixin, ConfigMixin):
428
+ """
429
+ HunyuanVideo Transformer backbone
430
+
431
+ Inherited from ModelMixin and ConfigMixin for compatibility with diffusers' sampler StableDiffusionPipeline.
432
+
433
+ Reference:
434
+ [1] Flux.1: https://github.com/black-forest-labs/flux
435
+ [2] MMDiT: http://arxiv.org/abs/2403.03206
436
+
437
+ Parameters
438
+ ----------
439
+ args: argparse.Namespace
440
+ The arguments parsed by argparse.
441
+ patch_size: list
442
+ The size of the patch.
443
+ in_channels: int
444
+ The number of input channels.
445
+ out_channels: int
446
+ The number of output channels.
447
+ hidden_size: int
448
+ The hidden size of the transformer backbone.
449
+ heads_num: int
450
+ The number of attention heads.
451
+ mlp_width_ratio: float
452
+ The ratio of the hidden size of the MLP in the transformer block.
453
+ mlp_act_type: str
454
+ The activation function of the MLP in the transformer block.
455
+ depth_double_blocks: int
456
+ The number of transformer blocks in the double blocks.
457
+ depth_single_blocks: int
458
+ The number of transformer blocks in the single blocks.
459
+ rope_dim_list: list
460
+ The dimension of the rotary embedding for t, h, w.
461
+ qkv_bias: bool
462
+ Whether to use bias in the qkv linear layer.
463
+ qk_norm: bool
464
+ Whether to use qk norm.
465
+ qk_norm_type: str
466
+ The type of qk norm.
467
+ guidance_embed: bool
468
+ Whether to use guidance embedding for distillation.
469
+ text_projection: str
470
+ The type of the text projection, default is single_refiner.
471
+ use_attention_mask: bool
472
+ Whether to use attention mask for text encoder.
473
+ dtype: torch.dtype
474
+ The dtype of the model.
475
+ device: torch.device
476
+ The device of the model.
477
+ attn_mode: str
478
+ The mode of the attention, default is flash.
479
+ split_attn: bool
480
+ Whether to use split attention (make attention as batch size 1).
481
+ """
482
+
483
+ # @register_to_config
484
+ def __init__(
485
+ self,
486
+ text_states_dim: int,
487
+ text_states_dim_2: int,
488
+ patch_size: list = [1, 2, 2],
489
+ in_channels: int = 4, # Should be VAE.config.latent_channels.
490
+ out_channels: int = None,
491
+ hidden_size: int = 3072,
492
+ heads_num: int = 24,
493
+ mlp_width_ratio: float = 4.0,
494
+ mlp_act_type: str = "gelu_tanh",
495
+ mm_double_blocks_depth: int = 20,
496
+ mm_single_blocks_depth: int = 40,
497
+ rope_dim_list: List[int] = [16, 56, 56],
498
+ qkv_bias: bool = True,
499
+ qk_norm: bool = True,
500
+ qk_norm_type: str = "rms",
501
+ guidance_embed: bool = False, # For modulation.
502
+ text_projection: str = "single_refiner",
503
+ use_attention_mask: bool = True,
504
+ dtype: Optional[torch.dtype] = None,
505
+ device: Optional[torch.device] = None,
506
+ attn_mode: str = "flash",
507
+ split_attn: bool = False,
508
+ ):
509
+ factory_kwargs = {"device": device, "dtype": dtype}
510
+ super().__init__()
511
+
512
+ self.patch_size = patch_size
513
+ self.in_channels = in_channels
514
+ self.out_channels = in_channels if out_channels is None else out_channels
515
+ self.unpatchify_channels = self.out_channels
516
+ self.guidance_embed = guidance_embed
517
+ self.rope_dim_list = rope_dim_list
518
+
519
+ # Text projection. Default to linear projection.
520
+ # Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
521
+ self.use_attention_mask = use_attention_mask
522
+ self.text_projection = text_projection
523
+
524
+ self.text_states_dim = text_states_dim
525
+ self.text_states_dim_2 = text_states_dim_2
526
+
527
+ if hidden_size % heads_num != 0:
528
+ raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}")
529
+ pe_dim = hidden_size // heads_num
530
+ if sum(rope_dim_list) != pe_dim:
531
+ raise ValueError(f"Got {rope_dim_list} but expected positional dim {pe_dim}")
532
+ self.hidden_size = hidden_size
533
+ self.heads_num = heads_num
534
+
535
+ self.attn_mode = attn_mode
536
+ self.split_attn = split_attn
537
+ print(f"Using {self.attn_mode} attention mode, split_attn: {self.split_attn}")
538
+
539
+ # image projection
540
+ self.img_in = PatchEmbed(self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs)
541
+
542
+ # text projection
543
+ if self.text_projection == "linear":
544
+ self.txt_in = TextProjection(
545
+ self.text_states_dim,
546
+ self.hidden_size,
547
+ get_activation_layer("silu"),
548
+ **factory_kwargs,
549
+ )
550
+ elif self.text_projection == "single_refiner":
551
+ self.txt_in = SingleTokenRefiner(self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs)
552
+ else:
553
+ raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
554
+
555
+ # time modulation
556
+ self.time_in = TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)
557
+
558
+ # text modulation
559
+ self.vector_in = MLPEmbedder(self.text_states_dim_2, self.hidden_size, **factory_kwargs)
560
+
561
+ # guidance modulation
562
+ self.guidance_in = (
563
+ TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs) if guidance_embed else None
564
+ )
565
+
566
+ # double blocks
567
+ self.double_blocks = nn.ModuleList(
568
+ [
569
+ MMDoubleStreamBlock(
570
+ self.hidden_size,
571
+ self.heads_num,
572
+ mlp_width_ratio=mlp_width_ratio,
573
+ mlp_act_type=mlp_act_type,
574
+ qk_norm=qk_norm,
575
+ qk_norm_type=qk_norm_type,
576
+ qkv_bias=qkv_bias,
577
+ attn_mode=attn_mode,
578
+ split_attn=split_attn,
579
+ **factory_kwargs,
580
+ )
581
+ for _ in range(mm_double_blocks_depth)
582
+ ]
583
+ )
584
+
585
+ # single blocks
586
+ self.single_blocks = nn.ModuleList(
587
+ [
588
+ MMSingleStreamBlock(
589
+ self.hidden_size,
590
+ self.heads_num,
591
+ mlp_width_ratio=mlp_width_ratio,
592
+ mlp_act_type=mlp_act_type,
593
+ qk_norm=qk_norm,
594
+ qk_norm_type=qk_norm_type,
595
+ attn_mode=attn_mode,
596
+ split_attn=split_attn,
597
+ **factory_kwargs,
598
+ )
599
+ for _ in range(mm_single_blocks_depth)
600
+ ]
601
+ )
602
+
603
+ self.final_layer = FinalLayer(
604
+ self.hidden_size,
605
+ self.patch_size,
606
+ self.out_channels,
607
+ get_activation_layer("silu"),
608
+ **factory_kwargs,
609
+ )
610
+
611
+ self.gradient_checkpointing = False
612
+ self.blocks_to_swap = None
613
+ self.offloader_double = None
614
+ self.offloader_single = None
615
+ self._enable_img_in_txt_in_offloading = False
616
+
617
+ @property
618
+ def device(self):
619
+ return next(self.parameters()).device
620
+
621
+ @property
622
+ def dtype(self):
623
+ return next(self.parameters()).dtype
624
+
625
+ def enable_gradient_checkpointing(self):
626
+ self.gradient_checkpointing = True
627
+
628
+ self.txt_in.enable_gradient_checkpointing()
629
+
630
+ for block in self.double_blocks + self.single_blocks:
631
+ block.enable_gradient_checkpointing()
632
+
633
+ print(f"HYVideoDiffusionTransformer: Gradient checkpointing enabled.")
634
+
635
+ def disable_gradient_checkpointing(self):
636
+ self.gradient_checkpointing = False
637
+
638
+ self.txt_in.disable_gradient_checkpointing()
639
+
640
+ for block in self.double_blocks + self.single_blocks:
641
+ block.disable_gradient_checkpointing()
642
+
643
+ print(f"HYVideoDiffusionTransformer: Gradient checkpointing disabled.")
644
+
645
+ def enable_img_in_txt_in_offloading(self):
646
+ self._enable_img_in_txt_in_offloading = True
647
+
648
+ def enable_block_swap(self, num_blocks: int, device: torch.device, supports_backward: bool):
649
+ self.blocks_to_swap = num_blocks
650
+ self.num_double_blocks = len(self.double_blocks)
651
+ self.num_single_blocks = len(self.single_blocks)
652
+ double_blocks_to_swap = num_blocks // 2
653
+ single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + 1
654
+
655
+ assert double_blocks_to_swap <= self.num_double_blocks - 1 and single_blocks_to_swap <= self.num_single_blocks - 1, (
656
+ f"Cannot swap more than {self.num_double_blocks - 1} double blocks and {self.num_single_blocks - 1} single blocks. "
657
+ f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks."
658
+ )
659
+
660
+ self.offloader_double = ModelOffloader(
661
+ "double", self.double_blocks, self.num_double_blocks, double_blocks_to_swap, supports_backward, device # , debug=True
662
+ )
663
+ self.offloader_single = ModelOffloader(
664
+ "single", self.single_blocks, self.num_single_blocks, single_blocks_to_swap, supports_backward, device # , debug=True
665
+ )
666
+ print(
667
+ f"HYVideoDiffusionTransformer: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
668
+ )
669
+
670
+ def switch_block_swap_for_inference(self):
671
+ if self.blocks_to_swap:
672
+ self.offloader_double.set_forward_only(True)
673
+ self.offloader_single.set_forward_only(True)
674
+ self.prepare_block_swap_before_forward()
675
+ print(f"HYVideoDiffusionTransformer: Block swap set to forward only.")
676
+
677
+ def switch_block_swap_for_training(self):
678
+ if self.blocks_to_swap:
679
+ self.offloader_double.set_forward_only(False)
680
+ self.offloader_single.set_forward_only(False)
681
+ self.prepare_block_swap_before_forward()
682
+ print(f"HYVideoDiffusionTransformer: Block swap set to forward and backward.")
683
+
684
+ def move_to_device_except_swap_blocks(self, device: torch.device):
685
+ # assume model is on cpu. do not move blocks to device to reduce temporary memory usage
686
+ if self.blocks_to_swap:
687
+ save_double_blocks = self.double_blocks
688
+ save_single_blocks = self.single_blocks
689
+ self.double_blocks = None
690
+ self.single_blocks = None
691
+
692
+ self.to(device)
693
+
694
+ if self.blocks_to_swap:
695
+ self.double_blocks = save_double_blocks
696
+ self.single_blocks = save_single_blocks
697
+
698
+ def prepare_block_swap_before_forward(self):
699
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
700
+ return
701
+ self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
702
+ self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
703
+
704
+ def enable_deterministic(self):
705
+ for block in self.double_blocks:
706
+ block.enable_deterministic()
707
+ for block in self.single_blocks:
708
+ block.enable_deterministic()
709
+
710
+ def disable_deterministic(self):
711
+ for block in self.double_blocks:
712
+ block.disable_deterministic()
713
+ for block in self.single_blocks:
714
+ block.disable_deterministic()
715
+
716
+ def forward(
717
+ self,
718
+ x: torch.Tensor,
719
+ t: torch.Tensor, # Should be in range(0, 1000).
720
+ text_states: torch.Tensor = None,
721
+ text_mask: torch.Tensor = None, # Now we don't use it.
722
+ text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
723
+ freqs_cos: Optional[torch.Tensor] = None,
724
+ freqs_sin: Optional[torch.Tensor] = None,
725
+ guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
726
+ return_dict: bool = True,
727
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
728
+ out = {}
729
+ img = x
730
+ txt = text_states
731
+ _, _, ot, oh, ow = x.shape
732
+ tt, th, tw = (
733
+ ot // self.patch_size[0],
734
+ oh // self.patch_size[1],
735
+ ow // self.patch_size[2],
736
+ )
737
+
738
+ # Prepare modulation vectors.
739
+ vec = self.time_in(t)
740
+
741
+ # text modulation
742
+ vec = vec + self.vector_in(text_states_2)
743
+
744
+ # guidance modulation
745
+ if self.guidance_embed:
746
+ if guidance is None:
747
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
748
+
749
+ # our timestep_embedding is merged into guidance_in(TimestepEmbedder)
750
+ vec = vec + self.guidance_in(guidance)
751
+
752
+ # Embed image and text.
753
+ if self._enable_img_in_txt_in_offloading:
754
+ self.img_in.to(x.device, non_blocking=True)
755
+ self.txt_in.to(x.device, non_blocking=True)
756
+ synchronize_device(x.device)
757
+
758
+ img = self.img_in(img)
759
+ if self.text_projection == "linear":
760
+ txt = self.txt_in(txt)
761
+ elif self.text_projection == "single_refiner":
762
+ txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
763
+ else:
764
+ raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
765
+
766
+ if self._enable_img_in_txt_in_offloading:
767
+ self.img_in.to(torch.device("cpu"), non_blocking=True)
768
+ self.txt_in.to(torch.device("cpu"), non_blocking=True)
769
+ synchronize_device(x.device)
770
+ clean_memory_on_device(x.device)
771
+
772
+ txt_seq_len = txt.shape[1]
773
+ img_seq_len = img.shape[1]
774
+
775
+ # Compute cu_seqlens and max_seqlen for flash attention
776
+ cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
777
+ cu_seqlens_kv = cu_seqlens_q
778
+ max_seqlen_q = img_seq_len + txt_seq_len
779
+ max_seqlen_kv = max_seqlen_q
780
+
781
+ attn_mask = total_len = None
782
+ if self.split_attn or self.attn_mode == "torch":
783
+ # calculate text length and total length
784
+ text_len = text_mask.sum(dim=1) # (bs, )
785
+ total_len = img_seq_len + text_len # (bs, )
786
+ if self.attn_mode == "torch" and not self.split_attn:
787
+ # initialize attention mask: bool tensor for sdpa, (b, 1, n, n)
788
+ bs = img.shape[0]
789
+ attn_mask = torch.zeros((bs, 1, max_seqlen_q, max_seqlen_q), dtype=torch.bool, device=text_mask.device)
790
+
791
+ # set attention mask with total_len
792
+ for i in range(bs):
793
+ attn_mask[i, :, : total_len[i], : total_len[i]] = True
794
+ total_len = None # means we don't use split_attn
795
+
796
+ freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
797
+ # --------------------- Pass through DiT blocks ------------------------
798
+ for block_idx, block in enumerate(self.double_blocks):
799
+ double_block_args = [
800
+ img,
801
+ txt,
802
+ vec,
803
+ attn_mask,
804
+ total_len,
805
+ cu_seqlens_q,
806
+ cu_seqlens_kv,
807
+ max_seqlen_q,
808
+ max_seqlen_kv,
809
+ freqs_cis,
810
+ ]
811
+
812
+ if self.blocks_to_swap:
813
+ self.offloader_double.wait_for_block(block_idx)
814
+
815
+ img, txt = block(*double_block_args)
816
+
817
+ if self.blocks_to_swap:
818
+ self.offloader_double.submit_move_blocks_forward(self.double_blocks, block_idx)
819
+
820
+ # Merge txt and img to pass through single stream blocks.
821
+ x = torch.cat((img, txt), 1)
822
+ if self.blocks_to_swap:
823
+ # delete img, txt to reduce memory usage
824
+ del img, txt
825
+ clean_memory_on_device(x.device)
826
+
827
+ if len(self.single_blocks) > 0:
828
+ for block_idx, block in enumerate(self.single_blocks):
829
+ single_block_args = [
830
+ x,
831
+ vec,
832
+ txt_seq_len,
833
+ attn_mask,
834
+ total_len,
835
+ cu_seqlens_q,
836
+ cu_seqlens_kv,
837
+ max_seqlen_q,
838
+ max_seqlen_kv,
839
+ freqs_cis,
840
+ ]
841
+ if self.blocks_to_swap:
842
+ self.offloader_single.wait_for_block(block_idx)
843
+
844
+ x = block(*single_block_args)
845
+
846
+ if self.blocks_to_swap:
847
+ self.offloader_single.submit_move_blocks_forward(self.single_blocks, block_idx)
848
+
849
+ img = x[:, :img_seq_len, ...]
850
+ x = None
851
+
852
+ # ---------------------------- Final layer ------------------------------
853
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
854
+
855
+ img = self.unpatchify(img, tt, th, tw)
856
+ if return_dict:
857
+ out["x"] = img
858
+ return out
859
+ return img
860
+
861
+ def unpatchify(self, x, t, h, w):
862
+ """
863
+ x: (N, T, patch_size**2 * C)
864
+ imgs: (N, H, W, C)
865
+ """
866
+ c = self.unpatchify_channels
867
+ pt, ph, pw = self.patch_size
868
+ assert t * h * w == x.shape[1]
869
+
870
+ x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
871
+ x = torch.einsum("nthwcopq->nctohpwq", x)
872
+ imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
873
+
874
+ return imgs
875
+
876
+ def params_count(self):
877
+ counts = {
878
+ "double": sum(
879
+ [
880
+ sum(p.numel() for p in block.img_attn_qkv.parameters())
881
+ + sum(p.numel() for p in block.img_attn_proj.parameters())
882
+ + sum(p.numel() for p in block.img_mlp.parameters())
883
+ + sum(p.numel() for p in block.txt_attn_qkv.parameters())
884
+ + sum(p.numel() for p in block.txt_attn_proj.parameters())
885
+ + sum(p.numel() for p in block.txt_mlp.parameters())
886
+ for block in self.double_blocks
887
+ ]
888
+ ),
889
+ "single": sum(
890
+ [
891
+ sum(p.numel() for p in block.linear1.parameters()) + sum(p.numel() for p in block.linear2.parameters())
892
+ for block in self.single_blocks
893
+ ]
894
+ ),
895
+ "total": sum(p.numel() for p in self.parameters()),
896
+ }
897
+ counts["attn+mlp"] = counts["double"] + counts["single"]
898
+ return counts
899
+
900
+
901
+ #################################################################################
902
+ # HunyuanVideo Configs #
903
+ #################################################################################
904
+
905
+ HUNYUAN_VIDEO_CONFIG = {
906
+ "HYVideo-T/2": {
907
+ "mm_double_blocks_depth": 20,
908
+ "mm_single_blocks_depth": 40,
909
+ "rope_dim_list": [16, 56, 56],
910
+ "hidden_size": 3072,
911
+ "heads_num": 24,
912
+ "mlp_width_ratio": 4,
913
+ },
914
+ "HYVideo-T/2-cfgdistill": {
915
+ "mm_double_blocks_depth": 20,
916
+ "mm_single_blocks_depth": 40,
917
+ "rope_dim_list": [16, 56, 56],
918
+ "hidden_size": 3072,
919
+ "heads_num": 24,
920
+ "mlp_width_ratio": 4,
921
+ "guidance_embed": True,
922
+ },
923
+ }
924
+
925
+
926
+ def load_dit_model(text_states_dim, text_states_dim_2, in_channels, out_channels, factor_kwargs):
927
+ """load hunyuan video model
928
+
929
+ NOTE: Only support HYVideo-T/2-cfgdistill now.
930
+
931
+ Args:
932
+ text_states_dim (int): text state dimension
933
+ text_states_dim_2 (int): text state dimension 2
934
+ in_channels (int): input channels number
935
+ out_channels (int): output channels number
936
+ factor_kwargs (dict): factor kwargs
937
+
938
+ Returns:
939
+ model (nn.Module): The hunyuan video model
940
+ """
941
+ # if args.model in HUNYUAN_VIDEO_CONFIG.keys():
942
+ model = HYVideoDiffusionTransformer(
943
+ text_states_dim=text_states_dim,
944
+ text_states_dim_2=text_states_dim_2,
945
+ in_channels=in_channels,
946
+ out_channels=out_channels,
947
+ **HUNYUAN_VIDEO_CONFIG["HYVideo-T/2-cfgdistill"],
948
+ **factor_kwargs,
949
+ )
950
+ return model
951
+ # else:
952
+ # raise NotImplementedError()
953
+
954
+
955
+ def load_state_dict(model, model_path):
956
+ state_dict = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=True)
957
+
958
+ load_key = "module"
959
+ if load_key in state_dict:
960
+ state_dict = state_dict[load_key]
961
+ else:
962
+ raise KeyError(
963
+ f"Missing key: `{load_key}` in the checkpoint: {model_path}. The keys in the checkpoint "
964
+ f"are: {list(state_dict.keys())}."
965
+ )
966
+ model.load_state_dict(state_dict, strict=True, assign=True)
967
+ return model
968
+
969
+
970
+ def load_transformer(dit_path, attn_mode, split_attn, device, dtype, in_channels=16) -> HYVideoDiffusionTransformer:
971
+ # =========================== Build main model ===========================
972
+ factor_kwargs = {"device": device, "dtype": dtype, "attn_mode": attn_mode, "split_attn": split_attn}
973
+ latent_channels = 16
974
+ out_channels = latent_channels
975
+
976
+ with accelerate.init_empty_weights():
977
+ transformer = load_dit_model(
978
+ text_states_dim=4096,
979
+ text_states_dim_2=768,
980
+ in_channels=in_channels,
981
+ out_channels=out_channels,
982
+ factor_kwargs=factor_kwargs,
983
+ )
984
+
985
+ if os.path.splitext(dit_path)[-1] == ".safetensors":
986
+ # loading safetensors: may be already fp8
987
+ with MemoryEfficientSafeOpen(dit_path) as f:
988
+ state_dict = {}
989
+ for k in f.keys():
990
+ tensor = f.get_tensor(k)
991
+ tensor = tensor.to(device=device, dtype=dtype)
992
+ # TODO support comfy model
993
+ # if k.startswith("model.model."):
994
+ # k = convert_comfy_model_key(k)
995
+ state_dict[k] = tensor
996
+ transformer.load_state_dict(state_dict, strict=True, assign=True)
997
+ else:
998
+ transformer = load_state_dict(transformer, dit_path)
999
+
1000
+ return transformer
1001
+
1002
+
1003
+ def get_rotary_pos_embed_by_shape(model, latents_size):
1004
+ target_ndim = 3
1005
+ ndim = 5 - 2
1006
+
1007
+ if isinstance(model.patch_size, int):
1008
+ assert all(s % model.patch_size == 0 for s in latents_size), (
1009
+ f"Latent size(last {ndim} dimensions) should be divisible by patch size({model.patch_size}), "
1010
+ f"but got {latents_size}."
1011
+ )
1012
+ rope_sizes = [s // model.patch_size for s in latents_size]
1013
+ elif isinstance(model.patch_size, list):
1014
+ assert all(s % model.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), (
1015
+ f"Latent size(last {ndim} dimensions) should be divisible by patch size({model.patch_size}), "
1016
+ f"but got {latents_size}."
1017
+ )
1018
+ rope_sizes = [s // model.patch_size[idx] for idx, s in enumerate(latents_size)]
1019
+
1020
+ if len(rope_sizes) != target_ndim:
1021
+ rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
1022
+ head_dim = model.hidden_size // model.heads_num
1023
+ rope_dim_list = model.rope_dim_list
1024
+ if rope_dim_list is None:
1025
+ rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
1026
+ assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
1027
+
1028
+ rope_theta = 256
1029
+ freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
1030
+ rope_dim_list, rope_sizes, theta=rope_theta, use_real=True, theta_rescale_factor=1
1031
+ )
1032
+ return freqs_cos, freqs_sin
1033
+
1034
+
1035
+ def get_rotary_pos_embed(vae_name, model, video_length, height, width):
1036
+ # 884
1037
+ if "884" in vae_name:
1038
+ latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
1039
+ elif "888" in vae_name:
1040
+ latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
1041
+ else:
1042
+ latents_size = [video_length, height // 8, width // 8]
1043
+
1044
+ return get_rotary_pos_embed_by_shape(model, latents_size)
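
The RoPE helpers at the end of `hunyuan_model/models.py` derive per-axis rotary sizes from the latent shape: the "884" VAE compresses time by 4x and space by 8x, the transformer patchifies each axis once more, and `rope_dim_list` must sum to the per-head dimension. A minimal arithmetic sketch of that bookkeeping follows; the patch size `(1, 2, 2)` and the example resolution are illustrative assumptions, not values taken from this diff.

```python
# Standalone sketch of the RoPE size bookkeeping (assumed patch_size=(1, 2, 2), "884" VAE).
def latent_rope_sizes(video_length, height, width, patch_size=(1, 2, 2)):
    # "884" VAE: (video_length - 1) // 4 + 1 latent frames, height/8 x width/8 latent pixels
    latents = [(video_length - 1) // 4 + 1, height // 8, width // 8]
    # each latent axis is divided once more by the transformer patch size
    return [s // p for s, p in zip(latents, patch_size)]

# HYVideo-T/2 config values from HUNYUAN_VIDEO_CONFIG above
hidden_size, heads_num, rope_dim_list = 3072, 24, [16, 56, 56]
head_dim = hidden_size // heads_num    # 128
assert sum(rope_dim_list) == head_dim  # 16 + 56 + 56 == 128, mirroring the __init__ check

print(latent_rope_sizes(video_length=25, height=544, width=960))  # [7, 34, 60]
```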
hunyuan_model/modulate_layers.py ADDED
@@ -0,0 +1,76 @@
1
+ from typing import Callable
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class ModulateDiT(nn.Module):
8
+ """Modulation layer for DiT."""
9
+ def __init__(
10
+ self,
11
+ hidden_size: int,
12
+ factor: int,
13
+ act_layer: Callable,
14
+ dtype=None,
15
+ device=None,
16
+ ):
17
+ factory_kwargs = {"dtype": dtype, "device": device}
18
+ super().__init__()
19
+ self.act = act_layer()
20
+ self.linear = nn.Linear(
21
+ hidden_size, factor * hidden_size, bias=True, **factory_kwargs
22
+ )
23
+ # Zero-initialize the modulation
24
+ nn.init.zeros_(self.linear.weight)
25
+ nn.init.zeros_(self.linear.bias)
26
+
27
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
28
+ return self.linear(self.act(x))
29
+
30
+
31
+ def modulate(x, shift=None, scale=None):
32
+ """modulate by shift and scale
33
+
34
+ Args:
35
+ x (torch.Tensor): input tensor.
36
+ shift (torch.Tensor, optional): shift tensor. Defaults to None.
37
+ scale (torch.Tensor, optional): scale tensor. Defaults to None.
38
+
39
+ Returns:
40
+ torch.Tensor: the output tensor after modulation.
41
+ """
42
+ if scale is None and shift is None:
43
+ return x
44
+ elif shift is None:
45
+ return x * (1 + scale.unsqueeze(1))
46
+ elif scale is None:
47
+ return x + shift.unsqueeze(1)
48
+ else:
49
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
50
+
51
+
52
+ def apply_gate(x, gate=None, tanh=False):
53
+ """AI is creating summary for apply_gate
54
+
55
+ Args:
56
+ x (torch.Tensor): input tensor.
57
+ gate (torch.Tensor, optional): gate tensor. Defaults to None.
58
+ tanh (bool, optional): whether to use tanh function. Defaults to False.
59
+
60
+ Returns:
61
+ torch.Tensor: the output tensor after applying the gate.
62
+ """
63
+ if gate is None:
64
+ return x
65
+ if tanh:
66
+ return x * gate.unsqueeze(1).tanh()
67
+ else:
68
+ return x * gate.unsqueeze(1)
69
+
70
+
71
+ def ckpt_wrapper(module):
72
+ def ckpt_forward(*inputs):
73
+ outputs = module(*inputs)
74
+ return outputs
75
+
76
+ return ckpt_forward
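
`modulate` and `apply_gate` above implement the adaLN-style shift/scale/gate broadcasting used by the transformer blocks. The sketch below shows the typical call pattern; the batch/sequence/hidden sizes are illustrative assumptions, and it assumes the repo root is on `PYTHONPATH`.

```python
# Usage sketch for the modulation helpers (illustrative shapes, not part of the diff).
import torch
from hunyuan_model.modulate_layers import modulate, apply_gate

x = torch.randn(2, 4, 8)                             # (batch, seq_len, hidden)
shift, scale, gate = torch.randn(3, 2, 8).unbind(0)  # one modulation vector per sample

y = modulate(x, shift=shift, scale=scale)  # x * (1 + scale) + shift, broadcast over seq_len
out = x + apply_gate(y, gate=gate)         # gated residual, roughly how the DiT blocks combine them
print(out.shape)                           # torch.Size([2, 4, 8])
```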
hunyuan_model/norm_layers.py ADDED
@@ -0,0 +1,79 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class RMSNorm(nn.Module):
6
+ def __init__(
7
+ self,
8
+ dim: int,
9
+ elementwise_affine=True,
10
+ eps: float = 1e-6,
11
+ device=None,
12
+ dtype=None,
13
+ ):
14
+ """
15
+ Initialize the RMSNorm normalization layer.
16
+
17
+ Args:
18
+ dim (int): The dimension of the input tensor.
19
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
20
+
21
+ Attributes:
22
+ eps (float): A small value added to the denominator for numerical stability.
23
+ weight (nn.Parameter): Learnable scaling parameter.
24
+
25
+ """
26
+ factory_kwargs = {"device": device, "dtype": dtype}
27
+ super().__init__()
28
+ self.eps = eps
29
+ if elementwise_affine:
30
+ self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
31
+
32
+ def _norm(self, x):
33
+ """
34
+ Apply the RMSNorm normalization to the input tensor.
35
+
36
+ Args:
37
+ x (torch.Tensor): The input tensor.
38
+
39
+ Returns:
40
+ torch.Tensor: The normalized tensor.
41
+
42
+ """
43
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
44
+
45
+ def forward(self, x):
46
+ """
47
+ Forward pass through the RMSNorm layer.
48
+
49
+ Args:
50
+ x (torch.Tensor): The input tensor.
51
+
52
+ Returns:
53
+ torch.Tensor: The output tensor after applying RMSNorm.
54
+
55
+ """
56
+ output = self._norm(x.float()).type_as(x)
57
+ if hasattr(self, "weight"):
58
+ # output = output * self.weight
59
+ # support fp8
60
+ output = output * self.weight.to(output.dtype)
61
+ return output
62
+
63
+
64
+ def get_norm_layer(norm_layer):
65
+ """
66
+ Get the normalization layer.
67
+
68
+ Args:
69
+ norm_layer (str): The type of normalization layer.
70
+
71
+ Returns:
72
+ norm_layer (nn.Module): The normalization layer.
73
+ """
74
+ if norm_layer == "layer":
75
+ return nn.LayerNorm
76
+ elif norm_layer == "rms":
77
+ return RMSNorm
78
+ else:
79
+ raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
hunyuan_model/pipeline_hunyuan_video.py ADDED
@@ -0,0 +1,1100 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
16
+ # Modified from diffusers==0.29.2
17
+ #
18
+ # ==============================================================================
19
+ import inspect
20
+ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
21
+ import torch
22
+ import torch.distributed as dist
23
+ import numpy as np
24
+ from dataclasses import dataclass
25
+ from packaging import version
26
+
27
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
28
+ from diffusers.configuration_utils import FrozenDict
29
+ from diffusers.image_processor import VaeImageProcessor
30
+ from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
31
+ from diffusers.models import AutoencoderKL
32
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
33
+ from diffusers.schedulers import KarrasDiffusionSchedulers
34
+ from diffusers.utils import (
35
+ USE_PEFT_BACKEND,
36
+ deprecate,
37
+ logging,
38
+ replace_example_docstring,
39
+ scale_lora_layers,
40
+ unscale_lora_layers,
41
+ )
42
+ from diffusers.utils.torch_utils import randn_tensor
43
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
44
+ from diffusers.utils import BaseOutput
45
+
46
+ from ...constants import PRECISION_TO_TYPE
47
+ from ...vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
48
+ from ...text_encoder import TextEncoder
49
+ from ...modules import HYVideoDiffusionTransformer
50
+
51
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
52
+
53
+ EXAMPLE_DOC_STRING = """"""
54
+
55
+
56
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
57
+ """
58
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
59
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
60
+ """
61
+ std_text = noise_pred_text.std(
62
+ dim=list(range(1, noise_pred_text.ndim)), keepdim=True
63
+ )
64
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
65
+ # rescale the results from guidance (fixes overexposure)
66
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
67
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
68
+ noise_cfg = (
69
+ guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
70
+ )
71
+ return noise_cfg
72
+
73
+
74
+ def retrieve_timesteps(
75
+ scheduler,
76
+ num_inference_steps: Optional[int] = None,
77
+ device: Optional[Union[str, torch.device]] = None,
78
+ timesteps: Optional[List[int]] = None,
79
+ sigmas: Optional[List[float]] = None,
80
+ **kwargs,
81
+ ):
82
+ """
83
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
84
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
85
+
86
+ Args:
87
+ scheduler (`SchedulerMixin`):
88
+ The scheduler to get timesteps from.
89
+ num_inference_steps (`int`):
90
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
91
+ must be `None`.
92
+ device (`str` or `torch.device`, *optional*):
93
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
94
+ timesteps (`List[int]`, *optional*):
95
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
96
+ `num_inference_steps` and `sigmas` must be `None`.
97
+ sigmas (`List[float]`, *optional*):
98
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
99
+ `num_inference_steps` and `timesteps` must be `None`.
100
+
101
+ Returns:
102
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
103
+ second element is the number of inference steps.
104
+ """
105
+ if timesteps is not None and sigmas is not None:
106
+ raise ValueError(
107
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
108
+ )
109
+ if timesteps is not None:
110
+ accepts_timesteps = "timesteps" in set(
111
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
112
+ )
113
+ if not accepts_timesteps:
114
+ raise ValueError(
115
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
116
+ f" timestep schedules. Please check whether you are using the correct scheduler."
117
+ )
118
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
119
+ timesteps = scheduler.timesteps
120
+ num_inference_steps = len(timesteps)
121
+ elif sigmas is not None:
122
+ accept_sigmas = "sigmas" in set(
123
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
124
+ )
125
+ if not accept_sigmas:
126
+ raise ValueError(
127
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
128
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
129
+ )
130
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
131
+ timesteps = scheduler.timesteps
132
+ num_inference_steps = len(timesteps)
133
+ else:
134
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
135
+ timesteps = scheduler.timesteps
136
+ return timesteps, num_inference_steps
137
+
138
+
139
+ @dataclass
140
+ class HunyuanVideoPipelineOutput(BaseOutput):
141
+ videos: Union[torch.Tensor, np.ndarray]
142
+
143
+
144
+ class HunyuanVideoPipeline(DiffusionPipeline):
145
+ r"""
146
+ Pipeline for text-to-video generation using HunyuanVideo.
147
+
148
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
149
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
150
+
151
+ Args:
152
+ vae ([`AutoencoderKL`]):
153
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
154
+ text_encoder ([`TextEncoder`]):
155
+ Frozen text-encoder.
156
+ text_encoder_2 ([`TextEncoder`]):
157
+ Frozen text-encoder_2.
158
+ transformer ([`HYVideoDiffusionTransformer`]):
159
+ A `HYVideoDiffusionTransformer` to denoise the encoded video latents.
160
+ scheduler ([`SchedulerMixin`]):
161
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
162
+ """
163
+
164
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
165
+ _optional_components = ["text_encoder_2"]
166
+ _exclude_from_cpu_offload = ["transformer"]
167
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
168
+
169
+ def __init__(
170
+ self,
171
+ vae: AutoencoderKL,
172
+ text_encoder: TextEncoder,
173
+ transformer: HYVideoDiffusionTransformer,
174
+ scheduler: KarrasDiffusionSchedulers,
175
+ text_encoder_2: Optional[TextEncoder] = None,
176
+ progress_bar_config: Dict[str, Any] = None,
177
+ args=None,
178
+ ):
179
+ super().__init__()
180
+
181
+ # ==========================================================================================
182
+ if progress_bar_config is None:
183
+ progress_bar_config = {}
184
+ if not hasattr(self, "_progress_bar_config"):
185
+ self._progress_bar_config = {}
186
+ self._progress_bar_config.update(progress_bar_config)
187
+
188
+ self.args = args
189
+ # ==========================================================================================
190
+
191
+ if (
192
+ hasattr(scheduler.config, "steps_offset")
193
+ and scheduler.config.steps_offset != 1
194
+ ):
195
+ deprecation_message = (
196
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
197
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
198
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
199
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
200
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
201
+ " file"
202
+ )
203
+ deprecate(
204
+ "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
205
+ )
206
+ new_config = dict(scheduler.config)
207
+ new_config["steps_offset"] = 1
208
+ scheduler._internal_dict = FrozenDict(new_config)
209
+
210
+ if (
211
+ hasattr(scheduler.config, "clip_sample")
212
+ and scheduler.config.clip_sample is True
213
+ ):
214
+ deprecation_message = (
215
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
216
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
217
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
218
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
219
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
220
+ )
221
+ deprecate(
222
+ "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
223
+ )
224
+ new_config = dict(scheduler.config)
225
+ new_config["clip_sample"] = False
226
+ scheduler._internal_dict = FrozenDict(new_config)
227
+
228
+ self.register_modules(
229
+ vae=vae,
230
+ text_encoder=text_encoder,
231
+ transformer=transformer,
232
+ scheduler=scheduler,
233
+ text_encoder_2=text_encoder_2,
234
+ )
235
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
236
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
237
+
238
+ def encode_prompt(
239
+ self,
240
+ prompt,
241
+ device,
242
+ num_videos_per_prompt,
243
+ do_classifier_free_guidance,
244
+ negative_prompt=None,
245
+ prompt_embeds: Optional[torch.Tensor] = None,
246
+ attention_mask: Optional[torch.Tensor] = None,
247
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
248
+ negative_attention_mask: Optional[torch.Tensor] = None,
249
+ lora_scale: Optional[float] = None,
250
+ clip_skip: Optional[int] = None,
251
+ text_encoder: Optional[TextEncoder] = None,
252
+ data_type: Optional[str] = "image",
253
+ ):
254
+ r"""
255
+ Encodes the prompt into text encoder hidden states.
256
+
257
+ Args:
258
+ prompt (`str` or `List[str]`, *optional*):
259
+ prompt to be encoded
260
+ device: (`torch.device`):
261
+ torch device
262
+ num_videos_per_prompt (`int`):
263
+ number of videos that should be generated per prompt
264
+ do_classifier_free_guidance (`bool`):
265
+ whether to use classifier free guidance or not
266
+ negative_prompt (`str` or `List[str]`, *optional*):
267
+ The prompt or prompts not to guide the video generation. If not defined, one has to pass
268
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
269
+ less than `1`).
270
+ prompt_embeds (`torch.Tensor`, *optional*):
271
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
272
+ provided, text embeddings will be generated from `prompt` input argument.
273
+ attention_mask (`torch.Tensor`, *optional*):
274
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
275
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
276
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
277
+ argument.
278
+ negative_attention_mask (`torch.Tensor`, *optional*):
279
+ lora_scale (`float`, *optional*):
280
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
281
+ clip_skip (`int`, *optional*):
282
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
283
+ the output of the pre-final layer will be used for computing the prompt embeddings.
284
+ text_encoder (TextEncoder, *optional*):
285
+ data_type (`str`, *optional*):
286
+ """
287
+ if text_encoder is None:
288
+ text_encoder = self.text_encoder
289
+
290
+ # set lora scale so that monkey patched LoRA
291
+ # function of text encoder can correctly access it
292
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
293
+ self._lora_scale = lora_scale
294
+
295
+ # dynamically adjust the LoRA scale
296
+ if not USE_PEFT_BACKEND:
297
+ adjust_lora_scale_text_encoder(text_encoder.model, lora_scale)
298
+ else:
299
+ scale_lora_layers(text_encoder.model, lora_scale)
300
+
301
+ if prompt is not None and isinstance(prompt, str):
302
+ batch_size = 1
303
+ elif prompt is not None and isinstance(prompt, list):
304
+ batch_size = len(prompt)
305
+ else:
306
+ batch_size = prompt_embeds.shape[0]
307
+
308
+ if prompt_embeds is None:
309
+ # textual inversion: process multi-vector tokens if necessary
310
+ if isinstance(self, TextualInversionLoaderMixin):
311
+ prompt = self.maybe_convert_prompt(prompt, text_encoder.tokenizer)
312
+
313
+ text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)
314
+
315
+ if clip_skip is None:
316
+ prompt_outputs = text_encoder.encode(
317
+ text_inputs, data_type=data_type, device=device
318
+ )
319
+ prompt_embeds = prompt_outputs.hidden_state
320
+ else:
321
+ prompt_outputs = text_encoder.encode(
322
+ text_inputs,
323
+ output_hidden_states=True,
324
+ data_type=data_type,
325
+ device=device,
326
+ )
327
+ # Access the `hidden_states` first, that contains a tuple of
328
+ # all the hidden states from the encoder layers. Then index into
329
+ # the tuple to access the hidden states from the desired layer.
330
+ prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)]
331
+ # We also need to apply the final LayerNorm here to not mess with the
332
+ # representations. The `last_hidden_states` that we typically use for
333
+ # obtaining the final prompt representations passes through the LayerNorm
334
+ # layer.
335
+ prompt_embeds = text_encoder.model.text_model.final_layer_norm(
336
+ prompt_embeds
337
+ )
338
+
339
+ attention_mask = prompt_outputs.attention_mask
340
+ if attention_mask is not None:
341
+ attention_mask = attention_mask.to(device)
342
+ bs_embed, seq_len = attention_mask.shape
343
+ attention_mask = attention_mask.repeat(1, num_videos_per_prompt)
344
+ attention_mask = attention_mask.view(
345
+ bs_embed * num_videos_per_prompt, seq_len
346
+ )
347
+
348
+ if text_encoder is not None:
349
+ prompt_embeds_dtype = text_encoder.dtype
350
+ elif self.transformer is not None:
351
+ prompt_embeds_dtype = self.transformer.dtype
352
+ else:
353
+ prompt_embeds_dtype = prompt_embeds.dtype
354
+
355
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
356
+
357
+ if prompt_embeds.ndim == 2:
358
+ bs_embed, _ = prompt_embeds.shape
359
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
360
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
361
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1)
362
+ else:
363
+ bs_embed, seq_len, _ = prompt_embeds.shape
364
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
365
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
366
+ prompt_embeds = prompt_embeds.view(
367
+ bs_embed * num_videos_per_prompt, seq_len, -1
368
+ )
369
+
370
+ # get unconditional embeddings for classifier free guidance
371
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
372
+ uncond_tokens: List[str]
373
+ if negative_prompt is None:
374
+ uncond_tokens = [""] * batch_size
375
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
376
+ raise TypeError(
377
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
378
+ f" {type(prompt)}."
379
+ )
380
+ elif isinstance(negative_prompt, str):
381
+ uncond_tokens = [negative_prompt]
382
+ elif batch_size != len(negative_prompt):
383
+ raise ValueError(
384
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
385
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
386
+ " the batch size of `prompt`."
387
+ )
388
+ else:
389
+ uncond_tokens = negative_prompt
390
+
391
+ # textual inversion: process multi-vector tokens if necessary
392
+ if isinstance(self, TextualInversionLoaderMixin):
393
+ uncond_tokens = self.maybe_convert_prompt(
394
+ uncond_tokens, text_encoder.tokenizer
395
+ )
396
+
397
+ # max_length = prompt_embeds.shape[1]
398
+ uncond_input = text_encoder.text2tokens(uncond_tokens, data_type=data_type)
399
+
400
+ negative_prompt_outputs = text_encoder.encode(
401
+ uncond_input, data_type=data_type, device=device
402
+ )
403
+ negative_prompt_embeds = negative_prompt_outputs.hidden_state
404
+
405
+ negative_attention_mask = negative_prompt_outputs.attention_mask
406
+ if negative_attention_mask is not None:
407
+ negative_attention_mask = negative_attention_mask.to(device)
408
+ _, seq_len = negative_attention_mask.shape
409
+ negative_attention_mask = negative_attention_mask.repeat(
410
+ 1, num_videos_per_prompt
411
+ )
412
+ negative_attention_mask = negative_attention_mask.view(
413
+ batch_size * num_videos_per_prompt, seq_len
414
+ )
415
+
416
+ if do_classifier_free_guidance:
417
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
418
+ seq_len = negative_prompt_embeds.shape[1]
419
+
420
+ negative_prompt_embeds = negative_prompt_embeds.to(
421
+ dtype=prompt_embeds_dtype, device=device
422
+ )
423
+
424
+ if negative_prompt_embeds.ndim == 2:
425
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
426
+ 1, num_videos_per_prompt
427
+ )
428
+ negative_prompt_embeds = negative_prompt_embeds.view(
429
+ batch_size * num_videos_per_prompt, -1
430
+ )
431
+ else:
432
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
433
+ 1, num_videos_per_prompt, 1
434
+ )
435
+ negative_prompt_embeds = negative_prompt_embeds.view(
436
+ batch_size * num_videos_per_prompt, seq_len, -1
437
+ )
438
+
439
+ if text_encoder is not None:
440
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
441
+ # Retrieve the original scale by scaling back the LoRA layers
442
+ unscale_lora_layers(text_encoder.model, lora_scale)
443
+
444
+ return (
445
+ prompt_embeds,
446
+ negative_prompt_embeds,
447
+ attention_mask,
448
+ negative_attention_mask,
449
+ )
450
+
451
+ def decode_latents(self, latents, enable_tiling=True):
452
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
453
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
454
+
455
+ latents = 1 / self.vae.config.scaling_factor * latents
456
+ if enable_tiling:
457
+ self.vae.enable_tiling()
458
+ image = self.vae.decode(latents, return_dict=False)[0]
459
+ else:
460
+ image = self.vae.decode(latents, return_dict=False)[0]
461
+ image = (image / 2 + 0.5).clamp(0, 1)
462
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
463
+ if image.ndim == 4:
464
+ image = image.cpu().permute(0, 2, 3, 1).float()
465
+ else:
466
+ image = image.cpu().float()
467
+ return image
468
+
469
+ def prepare_extra_func_kwargs(self, func, kwargs):
470
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
471
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
472
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
473
+ # and should be between [0, 1]
474
+ extra_step_kwargs = {}
475
+
476
+ for k, v in kwargs.items():
477
+ accepts = k in set(inspect.signature(func).parameters.keys())
478
+ if accepts:
479
+ extra_step_kwargs[k] = v
480
+ return extra_step_kwargs
481
+
482
+ def check_inputs(
483
+ self,
484
+ prompt,
485
+ height,
486
+ width,
487
+ video_length,
488
+ callback_steps,
489
+ negative_prompt=None,
490
+ prompt_embeds=None,
491
+ negative_prompt_embeds=None,
492
+ callback_on_step_end_tensor_inputs=None,
493
+ vae_ver="88-4c-sd",
494
+ ):
495
+ if height % 8 != 0 or width % 8 != 0:
496
+ raise ValueError(
497
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
498
+ )
499
+
500
+ if video_length is not None:
501
+ if "884" in vae_ver:
502
+ if video_length != 1 and (video_length - 1) % 4 != 0:
503
+ raise ValueError(
504
+ f"`video_length` has to be 1 or a multiple of 4 but is {video_length}."
505
+ )
506
+ elif "888" in vae_ver:
507
+ if video_length != 1 and (video_length - 1) % 8 != 0:
508
+ raise ValueError(
509
+ f"`video_length` has to be 1 or a multiple of 8 but is {video_length}."
510
+ )
511
+
512
+ if callback_steps is not None and (
513
+ not isinstance(callback_steps, int) or callback_steps <= 0
514
+ ):
515
+ raise ValueError(
516
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
517
+ f" {type(callback_steps)}."
518
+ )
519
+ if callback_on_step_end_tensor_inputs is not None and not all(
520
+ k in self._callback_tensor_inputs
521
+ for k in callback_on_step_end_tensor_inputs
522
+ ):
523
+ raise ValueError(
524
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
525
+ )
526
+
527
+ if prompt is not None and prompt_embeds is not None:
528
+ raise ValueError(
529
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
530
+ " only forward one of the two."
531
+ )
532
+ elif prompt is None and prompt_embeds is None:
533
+ raise ValueError(
534
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
535
+ )
536
+ elif prompt is not None and (
537
+ not isinstance(prompt, str) and not isinstance(prompt, list)
538
+ ):
539
+ raise ValueError(
540
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
541
+ )
542
+
543
+ if negative_prompt is not None and negative_prompt_embeds is not None:
544
+ raise ValueError(
545
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
546
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
547
+ )
548
+
549
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
550
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
551
+ raise ValueError(
552
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
553
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
554
+ f" {negative_prompt_embeds.shape}."
555
+ )
556
+
557
+
558
+ def prepare_latents(
559
+ self,
560
+ batch_size,
561
+ num_channels_latents,
562
+ height,
563
+ width,
564
+ video_length,
565
+ dtype,
566
+ device,
567
+ generator,
568
+ latents=None,
569
+ ):
570
+ shape = (
571
+ batch_size,
572
+ num_channels_latents,
573
+ video_length,
574
+ int(height) // self.vae_scale_factor,
575
+ int(width) // self.vae_scale_factor,
576
+ )
577
+ if isinstance(generator, list) and len(generator) != batch_size:
578
+ raise ValueError(
579
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
580
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
581
+ )
582
+
583
+ if latents is None:
584
+ latents = randn_tensor(
585
+ shape, generator=generator, device=device, dtype=dtype
586
+ )
587
+ else:
588
+ latents = latents.to(device)
589
+
590
+ # Check existence to make it compatible with FlowMatchEulerDiscreteScheduler
591
+ if hasattr(self.scheduler, "init_noise_sigma"):
592
+ # scale the initial noise by the standard deviation required by the scheduler
593
+ latents = latents * self.scheduler.init_noise_sigma
594
+ return latents
595
+
596
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
597
+ def get_guidance_scale_embedding(
598
+ self,
599
+ w: torch.Tensor,
600
+ embedding_dim: int = 512,
601
+ dtype: torch.dtype = torch.float32,
602
+ ) -> torch.Tensor:
603
+ """
604
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
605
+
606
+ Args:
607
+ w (`torch.Tensor`):
608
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
609
+ embedding_dim (`int`, *optional*, defaults to 512):
610
+ Dimension of the embeddings to generate.
611
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
612
+ Data type of the generated embeddings.
613
+
614
+ Returns:
615
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
616
+ """
617
+ assert len(w.shape) == 1
618
+ w = w * 1000.0
619
+
620
+ half_dim = embedding_dim // 2
621
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
622
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
623
+ emb = w.to(dtype)[:, None] * emb[None, :]
624
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
625
+ if embedding_dim % 2 == 1: # zero pad
626
+ emb = torch.nn.functional.pad(emb, (0, 1))
627
+ assert emb.shape == (w.shape[0], embedding_dim)
628
+ return emb
629
+
630
+ @property
631
+ def guidance_scale(self):
632
+ return self._guidance_scale
633
+
634
+ @property
635
+ def guidance_rescale(self):
636
+ return self._guidance_rescale
637
+
638
+ @property
639
+ def clip_skip(self):
640
+ return self._clip_skip
641
+
642
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
643
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
644
+ # corresponds to doing no classifier free guidance.
645
+ @property
646
+ def do_classifier_free_guidance(self):
647
+ # return self._guidance_scale > 1 and self.transformer.config.time_cond_proj_dim is None
648
+ return self._guidance_scale > 1
649
+
650
+ @property
651
+ def cross_attention_kwargs(self):
652
+ return self._cross_attention_kwargs
653
+
654
+ @property
655
+ def num_timesteps(self):
656
+ return self._num_timesteps
657
+
658
+ @property
659
+ def interrupt(self):
660
+ return self._interrupt
661
+
662
+ @torch.no_grad()
663
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
664
+ def __call__(
665
+ self,
666
+ prompt: Union[str, List[str]],
667
+ height: int,
668
+ width: int,
669
+ video_length: int,
670
+ data_type: str = "video",
671
+ num_inference_steps: int = 50,
672
+ timesteps: List[int] = None,
673
+ sigmas: List[float] = None,
674
+ guidance_scale: float = 7.5,
675
+ negative_prompt: Optional[Union[str, List[str]]] = None,
676
+ num_videos_per_prompt: Optional[int] = 1,
677
+ eta: float = 0.0,
678
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
679
+ latents: Optional[torch.Tensor] = None,
680
+ prompt_embeds: Optional[torch.Tensor] = None,
681
+ attention_mask: Optional[torch.Tensor] = None,
682
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
683
+ negative_attention_mask: Optional[torch.Tensor] = None,
684
+ output_type: Optional[str] = "pil",
685
+ return_dict: bool = True,
686
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
687
+ guidance_rescale: float = 0.0,
688
+ clip_skip: Optional[int] = None,
689
+ callback_on_step_end: Optional[
690
+ Union[
691
+ Callable[[int, int, Dict], None],
692
+ PipelineCallback,
693
+ MultiPipelineCallbacks,
694
+ ]
695
+ ] = None,
696
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
697
+ freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
698
+ vae_ver: str = "88-4c-sd",
699
+ enable_tiling: bool = False,
700
+ n_tokens: Optional[int] = None,
701
+ embedded_guidance_scale: Optional[float] = None,
702
+ **kwargs,
703
+ ):
704
+ r"""
705
+ The call function to the pipeline for generation.
706
+
707
+ Args:
708
+ prompt (`str` or `List[str]`):
709
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
710
+ height (`int`):
711
+ The height in pixels of the generated image.
712
+ width (`int`):
713
+ The width in pixels of the generated image.
714
+ video_length (`int`):
715
+ The number of frames in the generated video.
716
+ num_inference_steps (`int`, *optional*, defaults to 50):
717
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
718
+ expense of slower inference.
719
+ timesteps (`List[int]`, *optional*):
720
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
721
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
722
+ passed will be used. Must be in descending order.
723
+ sigmas (`List[float]`, *optional*):
724
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
725
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
726
+ will be used.
727
+ guidance_scale (`float`, *optional*, defaults to 7.5):
728
+ A higher guidance scale value encourages the model to generate images closely linked to the text
729
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
730
+ negative_prompt (`str` or `List[str]`, *optional*):
731
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
732
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
733
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
734
+ The number of images to generate per prompt.
735
+ eta (`float`, *optional*, defaults to 0.0):
736
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
737
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
738
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
739
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
740
+ generation deterministic.
741
+ latents (`torch.Tensor`, *optional*):
742
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
743
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
744
+ tensor is generated by sampling using the supplied random `generator`.
745
+ prompt_embeds (`torch.Tensor`, *optional*):
746
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
747
+ provided, text embeddings are generated from the `prompt` input argument.
748
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
749
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
750
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
751
+
752
+ output_type (`str`, *optional*, defaults to `"pil"`):
753
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
754
+ return_dict (`bool`, *optional*, defaults to `True`):
755
+ Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a
756
+ plain tuple.
757
+ cross_attention_kwargs (`dict`, *optional*):
758
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
759
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
760
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
761
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
762
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
763
+ using zero terminal SNR.
764
+ clip_skip (`int`, *optional*):
765
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
766
+ the output of the pre-final layer will be used for computing the prompt embeddings.
767
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
768
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
769
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
770
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
771
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
772
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
773
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
774
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
775
+ `._callback_tensor_inputs` attribute of your pipeline class.
776
+
777
+ Examples:
778
+
779
+ Returns:
780
+ [`~HunyuanVideoPipelineOutput`] or `tuple`:
781
+ If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned,
782
+ otherwise the generated video tensor (or latents, when `output_type` is `"latent"`) is
783
+ returned directly.
785
+ """
786
+ callback = kwargs.pop("callback", None)
787
+ callback_steps = kwargs.pop("callback_steps", None)
788
+
789
+ if callback is not None:
790
+ deprecate(
791
+ "callback",
792
+ "1.0.0",
793
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
794
+ )
795
+ if callback_steps is not None:
796
+ deprecate(
797
+ "callback_steps",
798
+ "1.0.0",
799
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
800
+ )
801
+
802
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
803
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
804
+
805
+ # 0. Default height and width to unet
806
+ # height = height or self.transformer.config.sample_size * self.vae_scale_factor
807
+ # width = width or self.transformer.config.sample_size * self.vae_scale_factor
808
+ # to deal with lora scaling and other possible forward hooks
809
+
810
+ # 1. Check inputs. Raise error if not correct
811
+ self.check_inputs(
812
+ prompt,
813
+ height,
814
+ width,
815
+ video_length,
816
+ callback_steps,
817
+ negative_prompt,
818
+ prompt_embeds,
819
+ negative_prompt_embeds,
820
+ callback_on_step_end_tensor_inputs,
821
+ vae_ver=vae_ver,
822
+ )
823
+
824
+ self._guidance_scale = guidance_scale
825
+ self._guidance_rescale = guidance_rescale
826
+ self._clip_skip = clip_skip
827
+ self._cross_attention_kwargs = cross_attention_kwargs
828
+ self._interrupt = False
829
+
830
+ # 2. Define call parameters
831
+ if prompt is not None and isinstance(prompt, str):
832
+ batch_size = 1
833
+ elif prompt is not None and isinstance(prompt, list):
834
+ batch_size = len(prompt)
835
+ else:
836
+ batch_size = prompt_embeds.shape[0]
837
+
838
+ device = torch.device(f"cuda:{dist.get_rank()}") if dist.is_initialized() else self._execution_device
839
+
840
+ # 3. Encode input prompt
841
+ lora_scale = (
842
+ self.cross_attention_kwargs.get("scale", None)
843
+ if self.cross_attention_kwargs is not None
844
+ else None
845
+ )
846
+
847
+ (
848
+ prompt_embeds,
849
+ negative_prompt_embeds,
850
+ prompt_mask,
851
+ negative_prompt_mask,
852
+ ) = self.encode_prompt(
853
+ prompt,
854
+ device,
855
+ num_videos_per_prompt,
856
+ self.do_classifier_free_guidance,
857
+ negative_prompt,
858
+ prompt_embeds=prompt_embeds,
859
+ attention_mask=attention_mask,
860
+ negative_prompt_embeds=negative_prompt_embeds,
861
+ negative_attention_mask=negative_attention_mask,
862
+ lora_scale=lora_scale,
863
+ clip_skip=self.clip_skip,
864
+ data_type=data_type,
865
+ )
866
+ if self.text_encoder_2 is not None:
867
+ (
868
+ prompt_embeds_2,
869
+ negative_prompt_embeds_2,
870
+ prompt_mask_2,
871
+ negative_prompt_mask_2,
872
+ ) = self.encode_prompt(
873
+ prompt,
874
+ device,
875
+ num_videos_per_prompt,
876
+ self.do_classifier_free_guidance,
877
+ negative_prompt,
878
+ prompt_embeds=None,
879
+ attention_mask=None,
880
+ negative_prompt_embeds=None,
881
+ negative_attention_mask=None,
882
+ lora_scale=lora_scale,
883
+ clip_skip=self.clip_skip,
884
+ text_encoder=self.text_encoder_2,
885
+ data_type=data_type,
886
+ )
887
+ else:
888
+ prompt_embeds_2 = None
889
+ negative_prompt_embeds_2 = None
890
+ prompt_mask_2 = None
891
+ negative_prompt_mask_2 = None
892
+
893
+ # For classifier free guidance, we need to do two forward passes.
894
+ # Here we concatenate the unconditional and text embeddings into a single batch
895
+ # to avoid doing two forward passes
896
+ if self.do_classifier_free_guidance:
897
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
898
+ if prompt_mask is not None:
899
+ prompt_mask = torch.cat([negative_prompt_mask, prompt_mask])
900
+ if prompt_embeds_2 is not None:
901
+ prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
902
+ if prompt_mask_2 is not None:
903
+ prompt_mask_2 = torch.cat([negative_prompt_mask_2, prompt_mask_2])
904
+
905
+
906
+ # 4. Prepare timesteps
907
+ extra_set_timesteps_kwargs = self.prepare_extra_func_kwargs(
908
+ self.scheduler.set_timesteps, {"n_tokens": n_tokens}
909
+ )
910
+ timesteps, num_inference_steps = retrieve_timesteps(
911
+ self.scheduler,
912
+ num_inference_steps,
913
+ device,
914
+ timesteps,
915
+ sigmas,
916
+ **extra_set_timesteps_kwargs,
917
+ )
918
+
919
+ if "884" in vae_ver:
920
+ video_length = (video_length - 1) // 4 + 1
921
+ elif "888" in vae_ver:
922
+ video_length = (video_length - 1) // 8 + 1
923
+ else:
924
+ video_length = video_length
925
+
926
+ # 5. Prepare latent variables
927
+ num_channels_latents = self.transformer.config.in_channels
928
+ latents = self.prepare_latents(
929
+ batch_size * num_videos_per_prompt,
930
+ num_channels_latents,
931
+ height,
932
+ width,
933
+ video_length,
934
+ prompt_embeds.dtype,
935
+ device,
936
+ generator,
937
+ latents,
938
+ )
939
+
940
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
941
+ extra_step_kwargs = self.prepare_extra_func_kwargs(
942
+ self.scheduler.step,
943
+ {"generator": generator, "eta": eta},
944
+ )
945
+
946
+ target_dtype = PRECISION_TO_TYPE[self.args.precision]
947
+ autocast_enabled = (
948
+ target_dtype != torch.float32
949
+ ) and not self.args.disable_autocast
950
+ vae_dtype = PRECISION_TO_TYPE[self.args.vae_precision]
951
+ vae_autocast_enabled = (
952
+ vae_dtype != torch.float32
953
+ ) and not self.args.disable_autocast
954
+
955
+ # 7. Denoising loop
956
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
957
+ self._num_timesteps = len(timesteps)
958
+
959
+ # if is_progress_bar:
960
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
961
+ for i, t in enumerate(timesteps):
962
+ if self.interrupt:
963
+ continue
964
+
965
+ # expand the latents if we are doing classifier free guidance
966
+ latent_model_input = (
967
+ torch.cat([latents] * 2)
968
+ if self.do_classifier_free_guidance
969
+ else latents
970
+ )
971
+ latent_model_input = self.scheduler.scale_model_input(
972
+ latent_model_input, t
973
+ )
974
+
975
+ t_expand = t.repeat(latent_model_input.shape[0])
976
+ guidance_expand = (
977
+ torch.tensor(
978
+ [embedded_guidance_scale] * latent_model_input.shape[0],
979
+ dtype=torch.float32,
980
+ device=device,
981
+ ).to(target_dtype)
982
+ * 1000.0
983
+ if embedded_guidance_scale is not None
984
+ else None
985
+ )
986
+
987
+ # predict the noise residual
988
+ with torch.autocast(
989
+ device_type="cuda", dtype=target_dtype, enabled=autocast_enabled
990
+ ):
991
+ noise_pred = self.transformer( # For an input image (129, 192, 336) (1, 256, 256)
992
+ latent_model_input, # [2, 16, 33, 24, 42]
993
+ t_expand, # [2]
994
+ text_states=prompt_embeds, # [2, 256, 4096]
995
+ text_mask=prompt_mask, # [2, 256]
996
+ text_states_2=prompt_embeds_2, # [2, 768]
997
+ freqs_cos=freqs_cis[0], # [seqlen, head_dim]
998
+ freqs_sin=freqs_cis[1], # [seqlen, head_dim]
999
+ guidance=guidance_expand,
1000
+ return_dict=True,
1001
+ )[
1002
+ "x"
1003
+ ]
1004
+
1005
+ # perform guidance
1006
+ if self.do_classifier_free_guidance:
1007
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1008
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
1009
+ noise_pred_text - noise_pred_uncond
1010
+ )
1011
+
1012
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1013
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1014
+ noise_pred = rescale_noise_cfg(
1015
+ noise_pred,
1016
+ noise_pred_text,
1017
+ guidance_rescale=self.guidance_rescale,
1018
+ )
1019
+
1020
+ # compute the previous noisy sample x_t -> x_t-1
1021
+ latents = self.scheduler.step(
1022
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
1023
+ )[0]
1024
+
1025
+ if callback_on_step_end is not None:
1026
+ callback_kwargs = {}
1027
+ for k in callback_on_step_end_tensor_inputs:
1028
+ callback_kwargs[k] = locals()[k]
1029
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1030
+
1031
+ latents = callback_outputs.pop("latents", latents)
1032
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1033
+ negative_prompt_embeds = callback_outputs.pop(
1034
+ "negative_prompt_embeds", negative_prompt_embeds
1035
+ )
1036
+
1037
+ # call the callback, if provided
1038
+ if i == len(timesteps) - 1 or (
1039
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
1040
+ ):
1041
+ if progress_bar is not None:
1042
+ progress_bar.update()
1043
+ if callback is not None and i % callback_steps == 0:
1044
+ step_idx = i // getattr(self.scheduler, "order", 1)
1045
+ callback(step_idx, t, latents)
1046
+
1047
+ if not output_type == "latent":
1048
+ expand_temporal_dim = False
1049
+ if len(latents.shape) == 4:
1050
+ if isinstance(self.vae, AutoencoderKLCausal3D):
1051
+ latents = latents.unsqueeze(2)
1052
+ expand_temporal_dim = True
1053
+ elif len(latents.shape) == 5:
1054
+ pass
1055
+ else:
1056
+ raise ValueError(
1057
+ f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}."
1058
+ )
1059
+
1060
+ if (
1061
+ hasattr(self.vae.config, "shift_factor")
1062
+ and self.vae.config.shift_factor
1063
+ ):
1064
+ latents = (
1065
+ latents / self.vae.config.scaling_factor
1066
+ + self.vae.config.shift_factor
1067
+ )
1068
+ else:
1069
+ latents = latents / self.vae.config.scaling_factor
1070
+
1071
+ with torch.autocast(
1072
+ device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled
1073
+ ):
1074
+ if enable_tiling:
1075
+ self.vae.enable_tiling()
1076
+ image = self.vae.decode(
1077
+ latents, return_dict=False, generator=generator
1078
+ )[0]
1079
+ else:
1080
+ image = self.vae.decode(
1081
+ latents, return_dict=False, generator=generator
1082
+ )[0]
1083
+
1084
+ if expand_temporal_dim or image.shape[2] == 1:
1085
+ image = image.squeeze(2)
1086
+
1087
+ else:
1088
+ image = latents
1089
+
1090
+ image = (image / 2 + 0.5).clamp(0, 1)
1091
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
1092
+ image = image.cpu().float()
1093
+
1094
+ # Offload all models
1095
+ self.maybe_free_model_hooks()
1096
+
1097
+ if not return_dict:
1098
+ return image
1099
+
1100
+ return HunyuanVideoPipelineOutput(videos=image)
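
The denoising loop above combines the unconditional and text-conditioned predictions via classifier-free guidance and, when `guidance_rescale > 0`, rescales the result following §3.4 of arXiv:2305.08891. The sketch below is a minimal, hedged reimplementation of that combination; the `rescale_noise_cfg` helper shown here is an assumption modeled on the paper, not code taken from this repository, and the guidance values are purely illustrative.

```python
import torch

def rescale_noise_cfg(noise_cfg: torch.Tensor, noise_pred_text: torch.Tensor, guidance_rescale: float = 0.0) -> torch.Tensor:
    # Match the per-sample std of the guided prediction to that of the text-conditioned
    # prediction, which counteracts overexposure under zero-terminal-SNR training.
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    rescaled = noise_cfg * (std_text / std_cfg)
    # Interpolate between the rescaled and the original guided prediction.
    return guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg

# Toy tensors standing in for the transformer output on a doubled (uncond + text) batch.
noise_pred = torch.randn(2, 16, 33, 24, 42)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + 6.0 * (noise_pred_text - noise_pred_uncond)  # guidance_scale = 6.0
guided = rescale_noise_cfg(guided, noise_pred_text, guidance_rescale=0.7)
print(guided.shape)  # torch.Size([1, 16, 33, 24, 42])
```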
hunyuan_model/posemb_layers.py ADDED
@@ -0,0 +1,310 @@
1
+ import torch
2
+ from typing import Union, Tuple, List
3
+
4
+
5
+ def _to_tuple(x, dim=2):
6
+ if isinstance(x, int):
7
+ return (x,) * dim
8
+ elif len(x) == dim:
9
+ return x
10
+ else:
11
+ raise ValueError(f"Expected length {dim} or int, but got {x}")
12
+
13
+
14
+ def get_meshgrid_nd(start, *args, dim=2):
15
+ """
16
+ Get n-D meshgrid with start, stop and num.
17
+
18
+ Args:
19
+ start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
20
+ step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
21
+ should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
22
+ n-tuples.
23
+ *args: See above.
24
+ dim (int): Dimension of the meshgrid. Defaults to 2.
25
+
26
+ Returns:
27
+ grid (np.ndarray): [dim, ...]
28
+ """
29
+ if len(args) == 0:
30
+ # start is grid_size
31
+ num = _to_tuple(start, dim=dim)
32
+ start = (0,) * dim
33
+ stop = num
34
+ elif len(args) == 1:
35
+ # start is start, args[0] is stop, step is 1
36
+ start = _to_tuple(start, dim=dim)
37
+ stop = _to_tuple(args[0], dim=dim)
38
+ num = [stop[i] - start[i] for i in range(dim)]
39
+ elif len(args) == 2:
40
+ # start is start, args[0] is stop, args[1] is num
41
+ start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
42
+ stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
43
+ num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
44
+ else:
45
+ raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
46
+
47
+ # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
48
+ axis_grid = []
49
+ for i in range(dim):
50
+ a, b, n = start[i], stop[i], num[i]
51
+ g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
52
+ axis_grid.append(g)
53
+ grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
54
+ grid = torch.stack(grid, dim=0) # [dim, W, H, D]
55
+
56
+ return grid
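+ # Example (illustrative assumption, not from the repository): get_meshgrid_nd((8, 16, 16), dim=3)
+ # returns a tensor of shape [3, 8, 16, 16] holding the frame/height/width index grids; with a
+ # single size argument the endpoint-free linspace yields integer positions 0..n-1 per axis.
+ # These per-axis indices are flattened and passed to get_1d_rotary_pos_embed below.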
57
+
58
+
59
+ #################################################################################
60
+ # Rotary Positional Embedding Functions #
61
+ #################################################################################
62
+ # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
63
+
64
+
65
+ def reshape_for_broadcast(
66
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
67
+ x: torch.Tensor,
68
+ head_first=False,
69
+ ):
70
+ """
71
+ Reshape frequency tensor for broadcasting it with another tensor.
72
+
73
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
74
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
75
+
76
+ Notes:
77
+ When using FlashMHAModified, head_first should be False.
78
+ When using Attention, head_first should be True.
79
+
80
+ Args:
81
+ freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
82
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
83
+ head_first (bool): head dimension first (except batch dim) or not.
84
+
85
+ Returns:
86
+ torch.Tensor: Reshaped frequency tensor.
87
+
88
+ Raises:
89
+ AssertionError: If the frequency tensor doesn't match the expected shape.
90
+ AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
91
+ """
92
+ ndim = x.ndim
93
+ assert 0 <= 1 < ndim
94
+
95
+ if isinstance(freqs_cis, tuple):
96
+ # freqs_cis: (cos, sin) in real space
97
+ if head_first:
98
+ assert freqs_cis[0].shape == (
99
+ x.shape[-2],
100
+ x.shape[-1],
101
+ ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
102
+ shape = [
103
+ d if i == ndim - 2 or i == ndim - 1 else 1
104
+ for i, d in enumerate(x.shape)
105
+ ]
106
+ else:
107
+ assert freqs_cis[0].shape == (
108
+ x.shape[1],
109
+ x.shape[-1],
110
+ ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
111
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
112
+ return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
113
+ else:
114
+ # freqs_cis: values in complex space
115
+ if head_first:
116
+ assert freqs_cis.shape == (
117
+ x.shape[-2],
118
+ x.shape[-1],
119
+ ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
120
+ shape = [
121
+ d if i == ndim - 2 or i == ndim - 1 else 1
122
+ for i, d in enumerate(x.shape)
123
+ ]
124
+ else:
125
+ assert freqs_cis.shape == (
126
+ x.shape[1],
127
+ x.shape[-1],
128
+ ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
129
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
130
+ return freqs_cis.view(*shape)
131
+
132
+
133
+ def rotate_half(x):
134
+ x_real, x_imag = (
135
+ x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
136
+ ) # [B, S, H, D//2]
137
+ return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
138
+
139
+
140
+ def apply_rotary_emb(
141
+ xq: torch.Tensor,
142
+ xk: torch.Tensor,
143
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
144
+ head_first: bool = False,
145
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
146
+ """
147
+ Apply rotary embeddings to input tensors using the given frequency tensor.
148
+
149
+ This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
150
+ frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
151
+ is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
152
+ returned as real tensors.
153
+
154
+ Args:
155
+ xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
156
+ xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
157
+ freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
158
+ head_first (bool): head dimension first (except batch dim) or not.
159
+
160
+ Returns:
161
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
162
+
163
+ """
164
+ xk_out = None
165
+ if isinstance(freqs_cis, tuple):
166
+ cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
167
+ cos, sin = cos.to(xq.device), sin.to(xq.device)
168
+ # real * cos - imag * sin
169
+ # imag * cos + real * sin
170
+ xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
171
+ xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
172
+ else:
173
+ # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
174
+ xq_ = torch.view_as_complex(
175
+ xq.float().reshape(*xq.shape[:-1], -1, 2)
176
+ ) # [B, S, H, D//2]
177
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
178
+ xq.device
179
+ ) # [S, D//2] --> [1, S, 1, D//2]
180
+ # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
181
+ # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
182
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
183
+ xk_ = torch.view_as_complex(
184
+ xk.float().reshape(*xk.shape[:-1], -1, 2)
185
+ ) # [B, S, H, D//2]
186
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
187
+
188
+ return xq_out, xk_out
189
+
190
+
191
+ def get_nd_rotary_pos_embed(
192
+ rope_dim_list,
193
+ start,
194
+ *args,
195
+ theta=10000.0,
196
+ use_real=False,
197
+ theta_rescale_factor: Union[float, List[float]] = 1.0,
198
+ interpolation_factor: Union[float, List[float]] = 1.0,
199
+ ):
200
+ """
201
+ This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
202
+
203
+ Args:
204
+ rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal n.
205
+ sum(rope_dim_list) should equal the head_dim of the attention layer.
206
+ start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
207
+ args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
208
+ *args: See above.
209
+ theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
210
+ use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
211
+ Some libraries such as TensorRT do not support the complex64 data type, so it is useful to provide a real
212
+ part and an imaginary part separately.
213
+ theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
214
+
215
+ Returns:
216
+ pos_embed (torch.Tensor): [HW, D/2]
217
+ """
218
+
219
+ grid = get_meshgrid_nd(
220
+ start, *args, dim=len(rope_dim_list)
221
+ ) # [3, W, H, D] / [2, W, H]
222
+
223
+ if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
224
+ theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
225
+ elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
226
+ theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
227
+ assert len(theta_rescale_factor) == len(
228
+ rope_dim_list
229
+ ), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
230
+
231
+ if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
232
+ interpolation_factor = [interpolation_factor] * len(rope_dim_list)
233
+ elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
234
+ interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
235
+ assert len(interpolation_factor) == len(
236
+ rope_dim_list
237
+ ), "len(interpolation_factor) should equal to len(rope_dim_list)"
238
+
239
+ # use 1/ndim of dimensions to encode grid_axis
240
+ embs = []
241
+ for i in range(len(rope_dim_list)):
242
+ emb = get_1d_rotary_pos_embed(
243
+ rope_dim_list[i],
244
+ grid[i].reshape(-1),
245
+ theta,
246
+ use_real=use_real,
247
+ theta_rescale_factor=theta_rescale_factor[i],
248
+ interpolation_factor=interpolation_factor[i],
249
+ ) # 2 x [WHD, rope_dim_list[i]]
250
+ embs.append(emb)
251
+
252
+ if use_real:
253
+ cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
254
+ sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
255
+ return cos, sin
256
+ else:
257
+ emb = torch.cat(embs, dim=1) # (WHD, D/2)
258
+ return emb
259
+
260
+
261
+ def get_1d_rotary_pos_embed(
262
+ dim: int,
263
+ pos: Union[torch.FloatTensor, int],
264
+ theta: float = 10000.0,
265
+ use_real: bool = False,
266
+ theta_rescale_factor: float = 1.0,
267
+ interpolation_factor: float = 1.0,
268
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
269
+ """
270
+ Precompute the frequency tensor for complex exponential (cis) with given dimensions.
271
+ (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
272
+
273
+ This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
274
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
275
+ The returned tensor contains complex values in complex64 data type.
276
+
277
+ Args:
278
+ dim (int): Dimension of the frequency tensor.
279
+ pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
280
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
281
+ use_real (bool, optional): If True, return real part and imaginary part separately.
282
+ Otherwise, return complex numbers.
283
+ theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
284
+
285
+ Returns:
286
+ freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
287
+ freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
288
+ """
289
+ if isinstance(pos, int):
290
+ pos = torch.arange(pos).float()
291
+
292
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
293
+ # has some connection to NTK literature
294
+ if theta_rescale_factor != 1.0:
295
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
296
+
297
+ freqs = 1.0 / (
298
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
299
+ ) # [D/2]
300
+ # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
301
+ freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
302
+ if use_real:
303
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
304
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
305
+ return freqs_cos, freqs_sin
306
+ else:
307
+ freqs_cis = torch.polar(
308
+ torch.ones_like(freqs), freqs
309
+ ) # complex64 # [S, D/2]
310
+ return freqs_cis
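
A short usage sketch of the RoPE helpers above. The axis split of the head dimension (16/56/56 over time/height/width), the latent grid size, and the head count are assumptions chosen to be internally consistent, not values read from the model config.

```python
import torch

rope_dim_list = [16, 56, 56]                    # assumed split; sums to head_dim = 128
latent_grid = (9, 20, 20)                       # assumed latent frames x height x width
seq_len = latent_grid[0] * latent_grid[1] * latent_grid[2]

# use_real=True returns (cos, sin), each of shape [seq_len, head_dim].
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(rope_dim_list, latent_grid, use_real=True)
assert freqs_cos.shape == (seq_len, sum(rope_dim_list))

# Rotate query/key tensors laid out as [batch, seq, heads, head_dim] (head_first=False).
q = torch.randn(1, seq_len, 24, sum(rope_dim_list))
k = torch.randn(1, seq_len, 24, sum(rope_dim_list))
q_rot, k_rot = apply_rotary_emb(q, k, (freqs_cos, freqs_sin), head_first=False)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
```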
hunyuan_model/text_encoder.py ADDED
@@ -0,0 +1,710 @@
1
+ from dataclasses import dataclass
2
+ import json
3
+ import os
4
+ from typing import Optional, Tuple, Union
5
+ from copy import deepcopy
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from transformers import (
10
+ CLIPTextModel,
11
+ CLIPTokenizer,
12
+ AutoTokenizer,
13
+ AutoModel,
14
+ CLIPConfig,
15
+ LlamaForCausalLM,
16
+ LlamaConfig,
17
+ )
18
+ from transformers.utils import ModelOutput
19
+ from transformers.models.llama import LlamaModel
20
+ from safetensors.torch import load_file
21
+ from accelerate import init_empty_weights
22
+
23
+ import logging
24
+
25
+ logger = logging.getLogger(__name__)
26
+ logging.basicConfig(level=logging.INFO)
27
+
28
+
29
+ CLIP_L_HUGGINGFACE_MODEL_ID = "openai/clip-vit-large-patch14"
30
+ LLAVA_HUGGINGFACE_MODEL_ID = "xtuner/llava-llama-3-8b-v1_1-transformers"
31
+
32
+ CLIP_CONFIG = {
33
+ "_name_or_path": "clip-vit-large-patch14/",
34
+ "architectures": ["CLIPModel"],
35
+ "initializer_factor": 1.0,
36
+ "logit_scale_init_value": 2.6592,
37
+ "model_type": "clip",
38
+ "projection_dim": 768,
39
+ # "text_config": {
40
+ "_name_or_path": "",
41
+ "add_cross_attention": False,
42
+ "architectures": None,
43
+ "attention_dropout": 0.0,
44
+ "bad_words_ids": None,
45
+ "bos_token_id": 0,
46
+ "chunk_size_feed_forward": 0,
47
+ "cross_attention_hidden_size": None,
48
+ "decoder_start_token_id": None,
49
+ "diversity_penalty": 0.0,
50
+ "do_sample": False,
51
+ "dropout": 0.0,
52
+ "early_stopping": False,
53
+ "encoder_no_repeat_ngram_size": 0,
54
+ "eos_token_id": 2,
55
+ "finetuning_task": None,
56
+ "forced_bos_token_id": None,
57
+ "forced_eos_token_id": None,
58
+ "hidden_act": "quick_gelu",
59
+ "hidden_size": 768,
60
+ "id2label": {"0": "LABEL_0", "1": "LABEL_1"},
61
+ "initializer_factor": 1.0,
62
+ "initializer_range": 0.02,
63
+ "intermediate_size": 3072,
64
+ "is_decoder": False,
65
+ "is_encoder_decoder": False,
66
+ "label2id": {"LABEL_0": 0, "LABEL_1": 1},
67
+ "layer_norm_eps": 1e-05,
68
+ "length_penalty": 1.0,
69
+ "max_length": 20,
70
+ "max_position_embeddings": 77,
71
+ "min_length": 0,
72
+ "model_type": "clip_text_model",
73
+ "no_repeat_ngram_size": 0,
74
+ "num_attention_heads": 12,
75
+ "num_beam_groups": 1,
76
+ "num_beams": 1,
77
+ "num_hidden_layers": 12,
78
+ "num_return_sequences": 1,
79
+ "output_attentions": False,
80
+ "output_hidden_states": False,
81
+ "output_scores": False,
82
+ "pad_token_id": 1,
83
+ "prefix": None,
84
+ "problem_type": None,
85
+ "projection_dim": 768,
86
+ "pruned_heads": {},
87
+ "remove_invalid_values": False,
88
+ "repetition_penalty": 1.0,
89
+ "return_dict": True,
90
+ "return_dict_in_generate": False,
91
+ "sep_token_id": None,
92
+ "task_specific_params": None,
93
+ "temperature": 1.0,
94
+ "tie_encoder_decoder": False,
95
+ "tie_word_embeddings": True,
96
+ "tokenizer_class": None,
97
+ "top_k": 50,
98
+ "top_p": 1.0,
99
+ "torch_dtype": None,
100
+ "torchscript": False,
101
+ "transformers_version": "4.16.0.dev0",
102
+ "use_bfloat16": False,
103
+ "vocab_size": 49408,
104
+ # },
105
+ # "text_config_dict": {
106
+ "hidden_size": 768,
107
+ "intermediate_size": 3072,
108
+ "num_attention_heads": 12,
109
+ "num_hidden_layers": 12,
110
+ "projection_dim": 768,
111
+ # },
112
+ # "torch_dtype": "float32",
113
+ # "transformers_version": null
114
+ }
115
+
116
+ LLAMA_CONFIG = {
117
+ "architectures": ["LlamaForCausalLM"],
118
+ "attention_bias": False,
119
+ "attention_dropout": 0.0,
120
+ "bos_token_id": 128000,
121
+ "eos_token_id": 128001,
122
+ "head_dim": 128,
123
+ "hidden_act": "silu",
124
+ "hidden_size": 4096,
125
+ "initializer_range": 0.02,
126
+ "intermediate_size": 14336,
127
+ "max_position_embeddings": 8192,
128
+ "mlp_bias": False,
129
+ "model_type": "llama",
130
+ "num_attention_heads": 32,
131
+ "num_hidden_layers": 32,
132
+ "num_key_value_heads": 8,
133
+ "pretraining_tp": 1,
134
+ "rms_norm_eps": 1e-05,
135
+ "rope_scaling": None,
136
+ "rope_theta": 500000.0,
137
+ "tie_word_embeddings": False,
138
+ "torch_dtype": "float16",
139
+ "transformers_version": "4.46.3",
140
+ "use_cache": True,
141
+ "vocab_size": 128320,
142
+ }
143
+
144
+ # When using decoder-only models, we must provide a prompt template to instruct the text encoder
145
+ # on how to generate the text.
146
+ # --------------------------------------------------------------------
147
+ PROMPT_TEMPLATE_ENCODE = (
148
+ "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
149
+ "quantity, text, spatial relationships of the objects and background:<|eot_id|>"
150
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
151
+ )
152
+ PROMPT_TEMPLATE_ENCODE_VIDEO = (
153
+ "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
154
+ "1. The main content and theme of the video."
155
+ "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
156
+ "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
157
+ "4. background environment, light, style and atmosphere."
158
+ "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
159
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
160
+ )
161
+
162
+ NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
163
+
164
+ PROMPT_TEMPLATE = {
165
+ "dit-llm-encode": {
166
+ "template": PROMPT_TEMPLATE_ENCODE,
167
+ "crop_start": 36,
168
+ },
169
+ "dit-llm-encode-video": {
170
+ "template": PROMPT_TEMPLATE_ENCODE_VIDEO,
171
+ "crop_start": 95,
172
+ },
173
+ }
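+ # Note: `crop_start` is the token length of the system/template portion of each prompt template.
+ # `TextEncoder.encode` below slices off that many leading hidden states so only the user
+ # prompt's embeddings remain, and `load_text_encoder_1` sizes `max_length` as
+ # text_len + crop_start to leave room for the template tokens.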
174
+
175
+
176
+ def use_default(value, default):
177
+ return value if value is not None else default
178
+
179
+
180
+ def load_clip_l(text_encoder_path: str, dtype: Optional[Union[str, torch.dtype]] = None):
181
+ if os.path.isdir(text_encoder_path):
182
+ # load from directory, configs are in the directory
183
+ text_encoder = CLIPTextModel.from_pretrained(text_encoder_path, torch_dtype=dtype)
184
+ else:
185
+ # load from file, we create the model with the appropriate config
186
+ config = CLIPConfig(**CLIP_CONFIG)
187
+ with init_empty_weights():
188
+ text_encoder = CLIPTextModel._from_config(config, torch_dtype=dtype)
189
+
190
+ state_dict = load_file(text_encoder_path)
191
+
192
+ text_encoder.load_state_dict(state_dict, strict=True, assign=True)
193
+ # if dtype is not None:
194
+ # text_encoder.to(dtype=dtype)
195
+
196
+ return text_encoder
197
+
198
+
199
+ def load_clip_l_tokenizer(tokenizer_path: str):
200
+ if os.path.isdir(tokenizer_path):
201
+ tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path, max_length=77)
202
+ else:
203
+ # load from Hugging Face
204
+ logger.info(f"Loading tokenizer from Hugging Face: {CLIP_L_HUGGINGFACE_MODEL_ID}")
205
+ tokenizer = CLIPTokenizer.from_pretrained(CLIP_L_HUGGINGFACE_MODEL_ID, max_length=77)
206
+
207
+ return tokenizer
208
+
209
+
210
+ def load_llm(text_encoder_path: str, dtype: Optional[Union[str, torch.dtype]] = None):
211
+ if os.path.isdir(text_encoder_path):
212
+ # load from directory, configs are in the directory
213
+ text_encoder = AutoModel.from_pretrained(text_encoder_path, low_cpu_mem_usage=True, torch_dtype=dtype)
214
+ else:
215
+ # load from file, we create the model with the appropriate config
216
+ config = LlamaConfig(**LLAMA_CONFIG)
217
+ with init_empty_weights():
218
+ text_encoder = LlamaForCausalLM._from_config(config, torch_dtype=dtype)
219
+
220
+ state_dict = load_file(text_encoder_path)
221
+
222
+ # support weights from ComfyUI
223
+ if "tokenizer" in state_dict:
224
+ state_dict.pop("tokenizer")
225
+
226
+ text_encoder.load_state_dict(state_dict, strict=True, assign=True)
227
+
228
+ return text_encoder
229
+
230
+
231
+ def load_llm_tokenizer(tokenizer_path: str, padding_side="right"):
232
+ if os.path.isdir(tokenizer_path):
233
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
234
+ else:
235
+ # load from Hugging Face
236
+ logger.info(f"Loading tokenizer from Hugging Face: {LLAVA_HUGGINGFACE_MODEL_ID}")
237
+ tokenizer = AutoTokenizer.from_pretrained(LLAVA_HUGGINGFACE_MODEL_ID, padding_side=padding_side)
238
+
239
+ return tokenizer
240
+
241
+
242
+ def load_text_encoder(
243
+ text_encoder_type: str,
244
+ text_encoder_path: str,
245
+ text_encoder_dtype: Optional[Union[str, torch.dtype]] = None,
246
+ ):
247
+ logger.info(f"Loading text encoder model ({text_encoder_type}) from: {text_encoder_path}")
248
+
249
+ # reduce peak memory usage by specifying the dtype of the model
250
+ dtype = text_encoder_dtype
251
+ if text_encoder_type == "clipL":
252
+ text_encoder = load_clip_l(text_encoder_path, dtype=dtype)
253
+ text_encoder.final_layer_norm = text_encoder.text_model.final_layer_norm
254
+ elif text_encoder_type == "llm":
255
+ text_encoder = load_llm(text_encoder_path, dtype=dtype)
256
+ if hasattr(text_encoder, "norm"):
257
+ text_encoder.final_layer_norm = text_encoder.norm # by from_pretrained
258
+ else:
259
+ text_encoder.final_layer_norm = text_encoder.model.norm # by _from_config
260
+ else:
261
+ raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
262
+ # from_pretrained will ensure that the model is in eval mode.
263
+
264
+ if dtype is not None:
265
+ text_encoder = text_encoder.to(dtype=dtype)
266
+
267
+ text_encoder.requires_grad_(False)
268
+
269
+ logger.info(f"Text encoder to dtype: {text_encoder.dtype}")
270
+ return text_encoder, text_encoder_path
271
+
272
+
273
+ def load_tokenizer(tokenizer_type, tokenizer_path=None, padding_side="right"):
274
+ logger.info(f"Loading tokenizer ({tokenizer_type}) from: {tokenizer_path}")
275
+
276
+ if tokenizer_type == "clipL":
277
+ tokenizer = load_clip_l_tokenizer(tokenizer_path)
278
+ elif tokenizer_type == "llm":
279
+ tokenizer = load_llm_tokenizer(tokenizer_path, padding_side=padding_side)
280
+ else:
281
+ raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}")
282
+
283
+ return tokenizer, tokenizer_path
284
+
285
+
286
+ @dataclass
287
+ class TextEncoderModelOutput(ModelOutput):
288
+ """
289
+ Base class for model's outputs that also contains a pooling of the last hidden states.
290
+
291
+ Args:
292
+ hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
293
+ Sequence of hidden-states at the output of the last layer of the model.
294
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
295
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
296
+ hidden_states_list (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed):
297
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
298
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
299
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
300
+ text_outputs (`list`, *optional*, returned when `return_texts=True` is passed):
301
+ List of decoded texts.
302
+ """
303
+
304
+ hidden_state: torch.FloatTensor = None
305
+ attention_mask: Optional[torch.LongTensor] = None
306
+ hidden_states_list: Optional[Tuple[torch.FloatTensor, ...]] = None
307
+ text_outputs: Optional[list] = None
308
+
309
+
310
+ class TextEncoder(nn.Module):
311
+ def __init__(
312
+ self,
313
+ text_encoder_type: str,
314
+ max_length: int,
315
+ text_encoder_dtype: Optional[Union[str, torch.dtype]] = None,
316
+ text_encoder_path: Optional[str] = None,
317
+ tokenizer_type: Optional[str] = None,
318
+ tokenizer_path: Optional[str] = None,
319
+ output_key: Optional[str] = None,
320
+ use_attention_mask: bool = True,
321
+ input_max_length: Optional[int] = None,
322
+ prompt_template: Optional[dict] = None,
323
+ prompt_template_video: Optional[dict] = None,
324
+ hidden_state_skip_layer: Optional[int] = None,
325
+ apply_final_norm: bool = False,
326
+ reproduce: bool = False,
327
+ ):
328
+ super().__init__()
329
+ self.text_encoder_type = text_encoder_type
330
+ self.max_length = max_length
331
+ # self.precision = text_encoder_precision
332
+ self.model_path = text_encoder_path
333
+ self.tokenizer_type = tokenizer_type if tokenizer_type is not None else text_encoder_type
334
+ self.tokenizer_path = tokenizer_path if tokenizer_path is not None else text_encoder_path
335
+ self.use_attention_mask = use_attention_mask
336
+ if prompt_template_video is not None:
337
+ assert use_attention_mask is True, "use_attention_mask must be True when training on videos."
338
+ self.input_max_length = input_max_length if input_max_length is not None else max_length
339
+ self.prompt_template = prompt_template
340
+ self.prompt_template_video = prompt_template_video
341
+ self.hidden_state_skip_layer = hidden_state_skip_layer
342
+ self.apply_final_norm = apply_final_norm
343
+ self.reproduce = reproduce
344
+
345
+ self.use_template = self.prompt_template is not None
346
+ if self.use_template:
347
+ assert (
348
+ isinstance(self.prompt_template, dict) and "template" in self.prompt_template
349
+ ), f"`prompt_template` must be a dictionary with a key 'template', got {self.prompt_template}"
350
+ assert "{}" in str(self.prompt_template["template"]), (
351
+ "`prompt_template['template']` must contain a placeholder `{}` for the input text, "
352
+ f"got {self.prompt_template['template']}"
353
+ )
354
+
355
+ self.use_video_template = self.prompt_template_video is not None
356
+ if self.use_video_template:
357
+ if self.prompt_template_video is not None:
358
+ assert (
359
+ isinstance(self.prompt_template_video, dict) and "template" in self.prompt_template_video
360
+ ), f"`prompt_template_video` must be a dictionary with a key 'template', got {self.prompt_template_video}"
361
+ assert "{}" in str(self.prompt_template_video["template"]), (
362
+ "`prompt_template_video['template']` must contain a placeholder `{}` for the input text, "
363
+ f"got {self.prompt_template_video['template']}"
364
+ )
365
+
366
+ if "t5" in text_encoder_type:
367
+ self.output_key = output_key or "last_hidden_state"
368
+ elif "clip" in text_encoder_type:
369
+ self.output_key = output_key or "pooler_output"
370
+ elif "llm" in text_encoder_type or "glm" in text_encoder_type:
371
+ self.output_key = output_key or "last_hidden_state"
372
+ else:
373
+ raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
374
+
375
+ self.model, self.model_path = load_text_encoder(
376
+ text_encoder_type=self.text_encoder_type, text_encoder_path=self.model_path, text_encoder_dtype=text_encoder_dtype
377
+ )
378
+ self.dtype = self.model.dtype
379
+
380
+ self.tokenizer, self.tokenizer_path = load_tokenizer(
381
+ tokenizer_type=self.tokenizer_type, tokenizer_path=self.tokenizer_path, padding_side="right"
382
+ )
383
+
384
+ def __repr__(self):
385
+ return f"{self.text_encoder_type} ({self.dtype} - {self.model_path})"
386
+
387
+ @property
388
+ def device(self):
389
+ return self.model.device
390
+
391
+ @staticmethod
392
+ def apply_text_to_template(text, template, prevent_empty_text=True):
393
+ """
394
+ Apply text to template.
395
+
396
+ Args:
397
+ text (str): Input text.
398
+ template (str or list): Template string or list of chat conversation.
399
+ prevent_empty_text (bool): If True, we will prevent the user text from being empty
400
+ by adding a space. Defaults to True.
401
+ """
402
+ if isinstance(template, str):
403
+ # Will send string to tokenizer. Used for llm
404
+ return template.format(text)
405
+ else:
406
+ raise TypeError(f"Unsupported template type: {type(template)}")
407
+
408
+ def text2tokens(self, text, data_type="image"):
409
+ """
410
+ Tokenize the input text.
411
+
412
+ Args:
413
+ text (str or list): Input text.
414
+ """
415
+ tokenize_input_type = "str"
416
+ if self.use_template:
417
+ if data_type == "image":
418
+ prompt_template = self.prompt_template["template"]
419
+ elif data_type == "video":
420
+ prompt_template = self.prompt_template_video["template"]
421
+ else:
422
+ raise ValueError(f"Unsupported data type: {data_type}")
423
+ if isinstance(text, (list, tuple)):
424
+ text = [self.apply_text_to_template(one_text, prompt_template) for one_text in text]
425
+ if isinstance(text[0], list):
426
+ tokenize_input_type = "list"
427
+ elif isinstance(text, str):
428
+ text = self.apply_text_to_template(text, prompt_template)
429
+ if isinstance(text, list):
430
+ tokenize_input_type = "list"
431
+ else:
432
+ raise TypeError(f"Unsupported text type: {type(text)}")
433
+
434
+ kwargs = dict(
435
+ truncation=True,
436
+ max_length=self.max_length,
437
+ padding="max_length",
438
+ return_tensors="pt",
439
+ )
440
+ if tokenize_input_type == "str":
441
+ return self.tokenizer(
442
+ text,
443
+ return_length=False,
444
+ return_overflowing_tokens=False,
445
+ return_attention_mask=True,
446
+ **kwargs,
447
+ )
448
+ elif tokenize_input_type == "list":
449
+ return self.tokenizer.apply_chat_template(
450
+ text,
451
+ add_generation_prompt=True,
452
+ tokenize=True,
453
+ return_dict=True,
454
+ **kwargs,
455
+ )
456
+ else:
457
+ raise ValueError(f"Unsupported tokenize_input_type: {tokenize_input_type}")
458
+
459
+ def encode(
460
+ self,
461
+ batch_encoding,
462
+ use_attention_mask=None,
463
+ output_hidden_states=False,
464
+ do_sample=None,
465
+ hidden_state_skip_layer=None,
466
+ return_texts=False,
467
+ data_type="image",
468
+ device=None,
469
+ ):
470
+ """
471
+ Args:
472
+ batch_encoding (dict): Batch encoding from tokenizer.
473
+ use_attention_mask (bool): Whether to use attention mask. If None, use self.use_attention_mask.
474
+ Defaults to None.
475
+ output_hidden_states (bool): Whether to output hidden states. If False, return the value of
476
+ self.output_key. If True, return the entire output. If set self.hidden_state_skip_layer,
477
+ output_hidden_states will be set True. Defaults to False.
478
+ do_sample (bool): Whether to sample from the model. Used for Decoder-Only LLMs. Defaults to None.
479
+ When self.reproduce is False, do_sample is set to True by default.
480
+ hidden_state_skip_layer (int): Number of final hidden layers to skip; 0 means the last layer.
481
+ If None, self.hidden_state_skip_layer will be used. Defaults to None.
482
+ return_texts (bool): Whether to return the decoded texts. Defaults to False.
483
+ """
484
+ device = self.model.device if device is None else device
485
+ use_attention_mask = use_default(use_attention_mask, self.use_attention_mask)
486
+ hidden_state_skip_layer = use_default(hidden_state_skip_layer, self.hidden_state_skip_layer)
487
+ do_sample = use_default(do_sample, not self.reproduce)
488
+ attention_mask = batch_encoding["attention_mask"].to(device) if use_attention_mask else None
489
+ outputs = self.model(
490
+ input_ids=batch_encoding["input_ids"].to(device),
491
+ attention_mask=attention_mask,
492
+ output_hidden_states=output_hidden_states or hidden_state_skip_layer is not None,
493
+ )
494
+ if hidden_state_skip_layer is not None:
495
+ last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
496
+ # Real last hidden state already has layer norm applied. So here we only apply it
497
+ # for intermediate layers.
498
+ if hidden_state_skip_layer > 0 and self.apply_final_norm:
499
+ last_hidden_state = self.model.final_layer_norm(last_hidden_state)
500
+ else:
501
+ last_hidden_state = outputs[self.output_key]
502
+
503
+ # Remove hidden states of instruction tokens, only keep prompt tokens.
504
+ if self.use_template:
505
+ if data_type == "image":
506
+ crop_start = self.prompt_template.get("crop_start", -1)
507
+ elif data_type == "video":
508
+ crop_start = self.prompt_template_video.get("crop_start", -1)
509
+ else:
510
+ raise ValueError(f"Unsupported data type: {data_type}")
511
+ if crop_start > 0:
512
+ last_hidden_state = last_hidden_state[:, crop_start:]
513
+ attention_mask = attention_mask[:, crop_start:] if use_attention_mask else None
514
+
515
+ if output_hidden_states:
516
+ return TextEncoderModelOutput(last_hidden_state, attention_mask, outputs.hidden_states)
517
+ return TextEncoderModelOutput(last_hidden_state, attention_mask)
518
+
519
+ def forward(
520
+ self,
521
+ text,
522
+ use_attention_mask=None,
523
+ output_hidden_states=False,
524
+ do_sample=False,
525
+ hidden_state_skip_layer=None,
526
+ return_texts=False,
527
+ ):
528
+ batch_encoding = self.text2tokens(text)
529
+ return self.encode(
530
+ batch_encoding,
531
+ use_attention_mask=use_attention_mask,
532
+ output_hidden_states=output_hidden_states,
533
+ do_sample=do_sample,
534
+ hidden_state_skip_layer=hidden_state_skip_layer,
535
+ return_texts=return_texts,
536
+ )
537
+
538
+
539
+ # region HunyuanVideo architecture
540
+
541
+
542
+ def load_text_encoder_1(
543
+ text_encoder_dir: str, device: torch.device, fp8_llm: bool, dtype: Optional[Union[str, torch.dtype]] = None
544
+ ) -> TextEncoder:
545
+ text_encoder_dtype = dtype or torch.float16
546
+ text_encoder_type = "llm"
547
+ text_len = 256
548
+ hidden_state_skip_layer = 2
549
+ apply_final_norm = False
550
+ reproduce = False
551
+
552
+ prompt_template = "dit-llm-encode"
553
+ prompt_template = PROMPT_TEMPLATE[prompt_template]
554
+ prompt_template_video = "dit-llm-encode-video"
555
+ prompt_template_video = PROMPT_TEMPLATE[prompt_template_video]
556
+
557
+ crop_start = prompt_template_video["crop_start"] # .get("crop_start", 0)
558
+ max_length = text_len + crop_start
559
+
560
+ text_encoder_1 = TextEncoder(
561
+ text_encoder_type=text_encoder_type,
562
+ max_length=max_length,
563
+ text_encoder_dtype=text_encoder_dtype,
564
+ text_encoder_path=text_encoder_dir,
565
+ tokenizer_type=text_encoder_type,
566
+ prompt_template=prompt_template,
567
+ prompt_template_video=prompt_template_video,
568
+ hidden_state_skip_layer=hidden_state_skip_layer,
569
+ apply_final_norm=apply_final_norm,
570
+ reproduce=reproduce,
571
+ )
572
+ text_encoder_1.eval()
573
+
574
+ if fp8_llm:
575
+ org_dtype = text_encoder_1.dtype
576
+ logger.info(f"Moving and casting text encoder to {device} and torch.float8_e4m3fn")
577
+ text_encoder_1.to(device=device, dtype=torch.float8_e4m3fn)
578
+
579
+ # prepare LLM for fp8
580
+ def prepare_fp8(llama_model: LlamaModel, target_dtype):
581
+ def forward_hook(module):
582
+ def forward(hidden_states):
583
+ input_dtype = hidden_states.dtype
584
+ hidden_states = hidden_states.to(torch.float32)
585
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
586
+ hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon)
587
+ return module.weight.to(input_dtype) * hidden_states.to(input_dtype)
588
+
589
+ return forward
590
+
591
+ for module in llama_model.modules():
592
+ if module.__class__.__name__ in ["Embedding"]:
593
+ # print("set", module.__class__.__name__, "to", target_dtype)
594
+ module.to(target_dtype)
595
+ if module.__class__.__name__ in ["LlamaRMSNorm"]:
596
+ # print("set", module.__class__.__name__, "hooks")
597
+ module.forward = forward_hook(module)
598
+
599
+ prepare_fp8(text_encoder_1.model, org_dtype)
600
+ else:
601
+ text_encoder_1.to(device=device)
602
+
603
+ return text_encoder_1
604
+
605
+
606
+ def load_text_encoder_2(
607
+ text_encoder_dir: str, device: torch.device, dtype: Optional[Union[str, torch.dtype]] = None
608
+ ) -> TextEncoder:
609
+ text_encoder_dtype = dtype or torch.float16
610
+ reproduce = False
611
+
612
+ text_encoder_2_type = "clipL"
613
+ text_len_2 = 77
614
+
615
+ text_encoder_2 = TextEncoder(
616
+ text_encoder_type=text_encoder_2_type,
617
+ max_length=text_len_2,
618
+ text_encoder_dtype=text_encoder_dtype,
619
+ text_encoder_path=text_encoder_dir,
620
+ tokenizer_type=text_encoder_2_type,
621
+ reproduce=reproduce,
622
+ )
623
+ text_encoder_2.eval()
624
+
625
+ text_encoder_2.to(device=device)
626
+
627
+ return text_encoder_2
628
+
629
+
630
+ # endregion
631
+
632
+
633
+ if __name__ == "__main__":
634
+ import argparse
635
+ from utils.model_utils import str_to_dtype
636
+
637
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
638
+
639
+ parser = argparse.ArgumentParser()
640
+ parser.add_argument("type", type=str, help="Text Encoder type")
641
+ parser.add_argument("path1", type=str, help="Text Encoder directory or file 1")
642
+ parser.add_argument("path2", type=str, help="Text Encoder directory or file 2")
643
+ parser.add_argument("--dtype", type=str, default=None, help="Data type for Text Encoder")
644
+ args = parser.parse_args()
645
+
646
+ dtype = str_to_dtype(args.dtype) if args.dtype is not None else torch.float16
647
+
648
+ """
649
+ if args.type == "clipL":
650
+ text_encoder_1st = load_clip_l(args.path1, dtype=dtype)
651
+ tokenizer_1st = load_clip_l_tokenizer(args.path1)
652
+ text_encoder_2nd = load_clip_l(args.path2, dtype=dtype)
653
+ tokenizer_2nd = load_clip_l_tokenizer(args.path2)
654
+ elif args.type == "llm":
655
+ text_encoder_1st = load_llm(args.path1, dtype=dtype)
656
+ tokenizer_1st = load_llm_tokenizer(args.path1)
657
+ text_encoder_2nd = load_llm(args.path2, dtype=dtype)
658
+ tokenizer_2nd = load_llm_tokenizer(args.path2)
659
+
660
+ print(f"1st Text Encoder dtype: {text_encoder_1st.dtype}")
661
+ print(f"2nd Text Encoder dtype: {text_encoder_2nd.dtype}")
662
+
663
+ text_encoder_1st.to(device=device)
664
+ text_encoder_2nd.to(device=device)
665
+
666
+ test_text = "A cat sitting on a table"
667
+ token_ids_1st = tokenizer_1st(test_text, return_tensors="pt")["input_ids"]
668
+ token_ids_2nd = tokenizer_2nd(test_text, return_tensors="pt")["input_ids"]
669
+ assert torch.allclose(token_ids_1st, token_ids_2nd)
670
+ print(f"Token IDs are the same: {token_ids_1st}")
671
+
672
+ with torch.no_grad():
673
+ text_encoder_1st_output = text_encoder_1st(token_ids_1st.to(device), output_hidden_states=True)
674
+ text_encoder_2nd_output = text_encoder_2nd(token_ids_2nd.to(device), output_hidden_states=True)
675
+ print(f"1st Text Encoder output keys: {text_encoder_1st_output.keys()}")
676
+ print(f"2nd Text Encoder output keys: {text_encoder_2nd_output.keys()}")
677
+ for key in text_encoder_1st_output:
678
+ print(f"Checking output: {key}")
679
+ assert key in text_encoder_2nd_output, f"Key {key} not in 2nd Text Encoder output"
680
+ assert torch.allclose(text_encoder_1st_output[key], text_encoder_2nd_output[key])
681
+ print(f"Outputs are the same: {key}")
682
+ print("All outputs are the same.")
683
+ """
684
+
685
+ if args.type == "clipL":
686
+ text_encoder_1st = load_text_encoder_2(args.path1, device, dtype)
687
+ text_encoder_2nd = load_text_encoder_2(args.path2, device, dtype)
688
+ elif args.type == "llm":
689
+ text_encoder_1st = load_text_encoder_1(args.path1, device, False, dtype)
690
+ text_encoder_2nd = load_text_encoder_1(args.path2, device, False, dtype)
691
+ print(f"1st Text Encoder dtype: {text_encoder_1st.dtype}")
692
+ print(f"2nd Text Encoder dtype: {text_encoder_2nd.dtype}")
693
+
694
+ prompt = "A cat sitting on a table"
695
+ data_type = "video" # video only, image is not supported
696
+ text_inputs_1st = text_encoder_1st.text2tokens(prompt, data_type=data_type)
697
+ text_inputs_2nd = text_encoder_2nd.text2tokens(prompt, data_type=data_type)
698
+ print(text_inputs_1st)
699
+ assert torch.allclose(text_inputs_1st["input_ids"], text_inputs_2nd["input_ids"])
700
+
701
+ with torch.no_grad():
702
+ prompt_outputs_1st = text_encoder_1st.encode(text_inputs_1st, data_type=data_type)
703
+ prompt_outputs_2nd = text_encoder_2nd.encode(text_inputs_1st, data_type=data_type)
704
+
705
+ # prompt_outputs.hidden_state, prompt_outputs.attention_mask
706
+ assert torch.allclose(prompt_outputs_1st.hidden_state, prompt_outputs_2nd.hidden_state)
707
+ print("Hidden states are the same.")
708
+ assert torch.allclose(prompt_outputs_1st.attention_mask, prompt_outputs_2nd.attention_mask)
709
+ print("Attention masks are the same.")
710
+ print("All outputs are the same.")
hunyuan_model/token_refiner.py ADDED
@@ -0,0 +1,245 @@
1
+ from typing import Optional
2
+
3
+ from einops import rearrange
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.utils.checkpoint import checkpoint
7
+
8
+ from .activation_layers import get_activation_layer
9
+ from .attention import attention
10
+ from .norm_layers import get_norm_layer
11
+ from .embed_layers import TimestepEmbedder, TextProjection
12
+ from .mlp_layers import MLP
13
+ from .modulate_layers import modulate, apply_gate
14
+
15
+
16
+ class IndividualTokenRefinerBlock(nn.Module):
17
+ def __init__(
18
+ self,
19
+ hidden_size,
20
+ heads_num,
21
+ mlp_width_ratio: float = 4.0,
22
+ mlp_drop_rate: float = 0.0,
23
+ act_type: str = "silu",
24
+ qk_norm: bool = False,
25
+ qk_norm_type: str = "layer",
26
+ qkv_bias: bool = True,
27
+ dtype: Optional[torch.dtype] = None,
28
+ device: Optional[torch.device] = None,
29
+ ):
30
+ factory_kwargs = {"device": device, "dtype": dtype}
31
+ super().__init__()
32
+ self.heads_num = heads_num
33
+ head_dim = hidden_size // heads_num
34
+ mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
35
+
36
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs)
37
+ self.self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
38
+ qk_norm_layer = get_norm_layer(qk_norm_type)
39
+ self.self_attn_q_norm = (
40
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
41
+ )
42
+ self.self_attn_k_norm = (
43
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
44
+ )
45
+ self.self_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
46
+
47
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs)
48
+ act_layer = get_activation_layer(act_type)
49
+ self.mlp = MLP(
50
+ in_channels=hidden_size,
51
+ hidden_channels=mlp_hidden_dim,
52
+ act_layer=act_layer,
53
+ drop=mlp_drop_rate,
54
+ **factory_kwargs,
55
+ )
56
+
57
+ self.adaLN_modulation = nn.Sequential(
58
+ act_layer(),
59
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
60
+ )
61
+ # Zero-initialize the modulation
62
+ nn.init.zeros_(self.adaLN_modulation[1].weight)
63
+ nn.init.zeros_(self.adaLN_modulation[1].bias)
64
+
65
+ self.gradient_checkpointing = False
66
+
67
+ def enable_gradient_checkpointing(self):
68
+ self.gradient_checkpointing = True
69
+
70
+ def disable_gradient_checkpointing(self):
71
+ self.gradient_checkpointing = False
72
+
73
+ def _forward(
74
+ self,
75
+ x: torch.Tensor,
76
+ c: torch.Tensor, # timestep_aware_representations + context_aware_representations
77
+ attn_mask: torch.Tensor = None,
78
+ ):
79
+ gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
80
+
81
+ norm_x = self.norm1(x)
82
+ qkv = self.self_attn_qkv(norm_x)
83
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
84
+ # Apply QK-Norm if needed
85
+ q = self.self_attn_q_norm(q).to(v)
86
+ k = self.self_attn_k_norm(k).to(v)
87
+
88
+ # Self-Attention
89
+ attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
90
+
91
+ x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
92
+
93
+ # FFN Layer
94
+ x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
95
+
96
+ return x
97
+
98
+ def forward(self, *args, **kwargs):
99
+ if self.training and self.gradient_checkpointing:
100
+ return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
101
+ else:
102
+ return self._forward(*args, **kwargs)
103
+
104
+
105
+ class IndividualTokenRefiner(nn.Module):
106
+ def __init__(
107
+ self,
108
+ hidden_size,
109
+ heads_num,
110
+ depth,
111
+ mlp_width_ratio: float = 4.0,
112
+ mlp_drop_rate: float = 0.0,
113
+ act_type: str = "silu",
114
+ qk_norm: bool = False,
115
+ qk_norm_type: str = "layer",
116
+ qkv_bias: bool = True,
117
+ dtype: Optional[torch.dtype] = None,
118
+ device: Optional[torch.device] = None,
119
+ ):
120
+ factory_kwargs = {"device": device, "dtype": dtype}
121
+ super().__init__()
122
+ self.blocks = nn.ModuleList(
123
+ [
124
+ IndividualTokenRefinerBlock(
125
+ hidden_size=hidden_size,
126
+ heads_num=heads_num,
127
+ mlp_width_ratio=mlp_width_ratio,
128
+ mlp_drop_rate=mlp_drop_rate,
129
+ act_type=act_type,
130
+ qk_norm=qk_norm,
131
+ qk_norm_type=qk_norm_type,
132
+ qkv_bias=qkv_bias,
133
+ **factory_kwargs,
134
+ )
135
+ for _ in range(depth)
136
+ ]
137
+ )
138
+
139
+ def enable_gradient_checkpointing(self):
140
+ for block in self.blocks:
141
+ block.enable_gradient_checkpointing()
142
+
143
+ def disable_gradient_checkpointing(self):
144
+ for block in self.blocks:
145
+ block.disable_gradient_checkpointing()
146
+
147
+ def forward(
148
+ self,
149
+ x: torch.Tensor,
150
+ c: torch.LongTensor,
151
+ mask: Optional[torch.Tensor] = None,
152
+ ):
153
+ self_attn_mask = None
154
+ if mask is not None:
155
+ batch_size = mask.shape[0]
156
+ seq_len = mask.shape[1]
157
+ mask = mask.to(x.device)
158
+ # batch_size x 1 x seq_len x seq_len
159
+ self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
160
+ # batch_size x 1 x seq_len x seq_len
161
+ self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
162
+ # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
163
+ self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
164
+ # avoids self-attention weight being NaN for padding tokens
165
+ self_attn_mask[:, :, :, 0] = True
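+ # Illustrative note: for a padding mask like [1, 1, 0], the outer AND above is True only
+ # where both the query and the key are real tokens, so padded positions can neither attend
+ # nor be attended to. Forcing column 0 to True keeps all-padding query rows from being
+ # fully masked, which would otherwise make the attention softmax produce NaN.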
166
+
167
+ for block in self.blocks:
168
+ x = block(x, c, self_attn_mask)
169
+ return x
170
+
171
+
172
+ class SingleTokenRefiner(nn.Module):
173
+ """
174
+ A single token refiner block for refining LLM text embeddings.
175
+ """
176
+
177
+ def __init__(
178
+ self,
179
+ in_channels,
180
+ hidden_size,
181
+ heads_num,
182
+ depth,
183
+ mlp_width_ratio: float = 4.0,
184
+ mlp_drop_rate: float = 0.0,
185
+ act_type: str = "silu",
186
+ qk_norm: bool = False,
187
+ qk_norm_type: str = "layer",
188
+ qkv_bias: bool = True,
189
+ attn_mode: str = "torch",
190
+ dtype: Optional[torch.dtype] = None,
191
+ device: Optional[torch.device] = None,
192
+ ):
193
+ factory_kwargs = {"device": device, "dtype": dtype}
194
+ super().__init__()
195
+ self.attn_mode = attn_mode
196
+ assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner."
197
+
198
+ self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True, **factory_kwargs)
199
+
200
+ act_layer = get_activation_layer(act_type)
201
+ # Build timestep embedding layer
202
+ self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
203
+ # Build context embedding layer
204
+ self.c_embedder = TextProjection(in_channels, hidden_size, act_layer, **factory_kwargs)
205
+
206
+ self.individual_token_refiner = IndividualTokenRefiner(
207
+ hidden_size=hidden_size,
208
+ heads_num=heads_num,
209
+ depth=depth,
210
+ mlp_width_ratio=mlp_width_ratio,
211
+ mlp_drop_rate=mlp_drop_rate,
212
+ act_type=act_type,
213
+ qk_norm=qk_norm,
214
+ qk_norm_type=qk_norm_type,
215
+ qkv_bias=qkv_bias,
216
+ **factory_kwargs,
217
+ )
218
+
219
+ def enable_gradient_checkpointing(self):
220
+ self.individual_token_refiner.enable_gradient_checkpointing()
221
+
222
+ def disable_gradient_checkpointing(self):
223
+ self.individual_token_refiner.disable_gradient_checkpointing()
224
+
225
+ def forward(
226
+ self,
227
+ x: torch.Tensor,
228
+ t: torch.LongTensor,
229
+ mask: Optional[torch.LongTensor] = None,
230
+ ):
231
+ timestep_aware_representations = self.t_embedder(t)
232
+
233
+ if mask is None:
234
+ context_aware_representations = x.mean(dim=1)
235
+ else:
236
+ mask_float = mask.float().unsqueeze(-1) # [b, s1, 1]
237
+ context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)
238
+ context_aware_representations = self.c_embedder(context_aware_representations)
239
+ c = timestep_aware_representations + context_aware_representations
240
+
241
+ x = self.input_embedder(x)
242
+
243
+ x = self.individual_token_refiner(x, c, mask)
244
+
245
+ return x
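A minimal usage sketch of `SingleTokenRefiner`, assuming illustrative dimensions (`in_channels=4096`, `hidden_size=3072`, `heads_num=24`, `depth=2`); the real values come from the HunyuanVideo model config, and the tensors below are random stand-ins for LLM hidden states.

```python
import torch

from hunyuan_model.token_refiner import SingleTokenRefiner

# Illustrative dimensions only; the actual values are taken from the model config.
refiner = SingleTokenRefiner(in_channels=4096, hidden_size=3072, heads_num=24, depth=2)

llm_hidden = torch.randn(1, 256, 4096)            # [batch, seq_len, llm_hidden_dim]
timestep = torch.tensor([999], dtype=torch.long)  # diffusion timestep per sample
mask = torch.ones(1, 256, dtype=torch.long)       # 1 = real token, 0 = padding

refined = refiner(llm_hidden, timestep, mask)     # -> [1, 256, 3072]
print(refined.shape)
```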
hunyuan_model/vae.py ADDED
@@ -0,0 +1,446 @@
1
+ from dataclasses import dataclass
2
+ import json
3
+ from typing import Optional, Tuple, Union
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from diffusers.utils import BaseOutput, is_torch_version
11
+ from diffusers.utils.torch_utils import randn_tensor
12
+ from diffusers.models.attention_processor import SpatialNorm
13
+ from modules.unet_causal_3d_blocks import CausalConv3d, UNetMidBlockCausal3D, get_down_block3d, get_up_block3d
14
+
15
+ import logging
16
+
17
+ logger = logging.getLogger(__name__)
18
+ logging.basicConfig(level=logging.INFO)
19
+
20
+
21
+ SCALING_FACTOR = 0.476986
22
+ VAE_VER = "884-16c-hy" # We don't support other versions currently
23
+
24
+
25
+ def load_vae(
26
+ vae_type: str = "884-16c-hy",
27
+ vae_dtype: Optional[Union[str, torch.dtype]] = None,
28
+ sample_size: tuple = None,
29
+ vae_path: str = None,
30
+ device=None,
31
+ ):
32
+ """The function to load the 3D VAE model.
33
+
34
+ Args:
35
+ vae_type (str): the type of the 3D VAE model. Defaults to "884-16c-hy".
36
+ vae_dtype (str or torch.dtype, optional): the dtype to cast the VAE to. Defaults to None.
37
+ sample_size (tuple, optional): the tiling size. Defaults to None.
38
+ vae_path (str, optional): the path to vae. Defaults to None.
39
40
+ device (_type_, optional): device to load vae. Defaults to None.
41
+ """
42
+ if vae_path is None:
43
+ vae_path = VAE_PATH[vae_type]
44
+
45
+ logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}")
46
+
47
+ # use fixed config for Hunyuan's VAE
48
+ CONFIG_JSON = """{
49
+ "_class_name": "AutoencoderKLCausal3D",
50
+ "_diffusers_version": "0.4.2",
51
+ "act_fn": "silu",
52
+ "block_out_channels": [
53
+ 128,
54
+ 256,
55
+ 512,
56
+ 512
57
+ ],
58
+ "down_block_types": [
59
+ "DownEncoderBlockCausal3D",
60
+ "DownEncoderBlockCausal3D",
61
+ "DownEncoderBlockCausal3D",
62
+ "DownEncoderBlockCausal3D"
63
+ ],
64
+ "in_channels": 3,
65
+ "latent_channels": 16,
66
+ "layers_per_block": 2,
67
+ "norm_num_groups": 32,
68
+ "out_channels": 3,
69
+ "sample_size": 256,
70
+ "sample_tsize": 64,
71
+ "up_block_types": [
72
+ "UpDecoderBlockCausal3D",
73
+ "UpDecoderBlockCausal3D",
74
+ "UpDecoderBlockCausal3D",
75
+ "UpDecoderBlockCausal3D"
76
+ ],
77
+ "scaling_factor": 0.476986,
78
+ "time_compression_ratio": 4,
79
+ "mid_block_add_attention": true
80
+ }"""
81
+
82
+ # config = AutoencoderKLCausal3D.load_config(vae_path)
83
+ config = json.loads(CONFIG_JSON)
84
+
85
+ # import here to avoid circular import
86
+ from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D
87
+
88
+ if sample_size:
89
+ vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size)
90
+ else:
91
+ vae = AutoencoderKLCausal3D.from_config(config)
92
+
93
+ # vae_ckpt = Path(vae_path) / "pytorch_model.pt"
94
+ # assert vae_ckpt.exists(), f"VAE checkpoint not found: {vae_ckpt}"
95
+
96
+ if vae_path.endswith(".safetensors"):
97
+ from safetensors.torch import load_file
98
+ ckpt = load_file(vae_path)
99
+ else:
100
+ ckpt = torch.load(vae_path, map_location=vae.device, weights_only=True)
101
+ if "state_dict" in ckpt:
102
+ ckpt = ckpt["state_dict"]
103
+ if any(k.startswith("vae.") for k in ckpt.keys()):
104
+ ckpt = {k.replace("vae.", ""): v for k, v in ckpt.items() if k.startswith("vae.")}
105
+ vae.load_state_dict(ckpt)
106
+
107
+ spatial_compression_ratio = vae.config.spatial_compression_ratio
108
+ time_compression_ratio = vae.config.time_compression_ratio
109
+
110
+ if vae_dtype is not None:
111
+ vae = vae.to(vae_dtype)
112
+
113
+ vae.requires_grad_(False)
114
+
115
+ logger.info(f"VAE to dtype: {vae.dtype}")
116
+
117
+ if device is not None:
118
+ vae = vae.to(device)
119
+
120
+ vae.eval()
121
+
122
+ return vae, vae_path, spatial_compression_ratio, time_compression_ratio
123
+
124
+
125
+ @dataclass
126
+ class DecoderOutput(BaseOutput):
127
+ r"""
128
+ Output of decoding method.
129
+
130
+ Args:
131
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
132
+ The decoded output sample from the last layer of the model.
133
+ """
134
+
135
+ sample: torch.FloatTensor
136
+
137
+
138
+ class EncoderCausal3D(nn.Module):
139
+ r"""
140
+ The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation.
141
+ """
142
+
143
+ def __init__(
144
+ self,
145
+ in_channels: int = 3,
146
+ out_channels: int = 3,
147
+ down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
148
+ block_out_channels: Tuple[int, ...] = (64,),
149
+ layers_per_block: int = 2,
150
+ norm_num_groups: int = 32,
151
+ act_fn: str = "silu",
152
+ double_z: bool = True,
153
+ mid_block_add_attention=True,
154
+ time_compression_ratio: int = 4,
155
+ spatial_compression_ratio: int = 8,
156
+ ):
157
+ super().__init__()
158
+ self.layers_per_block = layers_per_block
159
+
160
+ self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
161
+ self.mid_block = None
162
+ self.down_blocks = nn.ModuleList([])
163
+
164
+ # down
165
+ output_channel = block_out_channels[0]
166
+ for i, down_block_type in enumerate(down_block_types):
167
+ input_channel = output_channel
168
+ output_channel = block_out_channels[i]
169
+ is_final_block = i == len(block_out_channels) - 1
170
+ num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
171
+ num_time_downsample_layers = int(np.log2(time_compression_ratio))
172
+
173
+ if time_compression_ratio == 4:
174
+ add_spatial_downsample = bool(i < num_spatial_downsample_layers)
175
+ add_time_downsample = bool(i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block)
176
+ else:
177
+ raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
178
+
179
+ downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
180
+ downsample_stride_T = (2,) if add_time_downsample else (1,)
181
+ downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
182
+ down_block = get_down_block3d(
183
+ down_block_type,
184
+ num_layers=self.layers_per_block,
185
+ in_channels=input_channel,
186
+ out_channels=output_channel,
187
+ add_downsample=bool(add_spatial_downsample or add_time_downsample),
188
+ downsample_stride=downsample_stride,
189
+ resnet_eps=1e-6,
190
+ downsample_padding=0,
191
+ resnet_act_fn=act_fn,
192
+ resnet_groups=norm_num_groups,
193
+ attention_head_dim=output_channel,
194
+ temb_channels=None,
195
+ )
196
+ self.down_blocks.append(down_block)
197
+
198
+ # mid
199
+ self.mid_block = UNetMidBlockCausal3D(
200
+ in_channels=block_out_channels[-1],
201
+ resnet_eps=1e-6,
202
+ resnet_act_fn=act_fn,
203
+ output_scale_factor=1,
204
+ resnet_time_scale_shift="default",
205
+ attention_head_dim=block_out_channels[-1],
206
+ resnet_groups=norm_num_groups,
207
+ temb_channels=None,
208
+ add_attention=mid_block_add_attention,
209
+ )
210
+
211
+ # out
212
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
213
+ self.conv_act = nn.SiLU()
214
+
215
+ conv_out_channels = 2 * out_channels if double_z else out_channels
216
+ self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)
217
+
218
+ def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
219
+ r"""The forward method of the `EncoderCausal3D` class."""
220
+ assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"
221
+
222
+ sample = self.conv_in(sample)
223
+
224
+ # down
225
+ for down_block in self.down_blocks:
226
+ sample = down_block(sample)
227
+
228
+ # middle
229
+ sample = self.mid_block(sample)
230
+
231
+ # post-process
232
+ sample = self.conv_norm_out(sample)
233
+ sample = self.conv_act(sample)
234
+ sample = self.conv_out(sample)
235
+
236
+ return sample
237
+
238
+
239
+ class DecoderCausal3D(nn.Module):
240
+ r"""
241
+ The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output sample.
242
+ """
243
+
244
+ def __init__(
245
+ self,
246
+ in_channels: int = 3,
247
+ out_channels: int = 3,
248
+ up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
249
+ block_out_channels: Tuple[int, ...] = (64,),
250
+ layers_per_block: int = 2,
251
+ norm_num_groups: int = 32,
252
+ act_fn: str = "silu",
253
+ norm_type: str = "group", # group, spatial
254
+ mid_block_add_attention=True,
255
+ time_compression_ratio: int = 4,
256
+ spatial_compression_ratio: int = 8,
257
+ ):
258
+ super().__init__()
259
+ self.layers_per_block = layers_per_block
260
+
261
+ self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
262
+ self.mid_block = None
263
+ self.up_blocks = nn.ModuleList([])
264
+
265
+ temb_channels = in_channels if norm_type == "spatial" else None
266
+
267
+ # mid
268
+ self.mid_block = UNetMidBlockCausal3D(
269
+ in_channels=block_out_channels[-1],
270
+ resnet_eps=1e-6,
271
+ resnet_act_fn=act_fn,
272
+ output_scale_factor=1,
273
+ resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
274
+ attention_head_dim=block_out_channels[-1],
275
+ resnet_groups=norm_num_groups,
276
+ temb_channels=temb_channels,
277
+ add_attention=mid_block_add_attention,
278
+ )
279
+
280
+ # up
281
+ reversed_block_out_channels = list(reversed(block_out_channels))
282
+ output_channel = reversed_block_out_channels[0]
283
+ for i, up_block_type in enumerate(up_block_types):
284
+ prev_output_channel = output_channel
285
+ output_channel = reversed_block_out_channels[i]
286
+ is_final_block = i == len(block_out_channels) - 1
287
+ num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
288
+ num_time_upsample_layers = int(np.log2(time_compression_ratio))
289
+
290
+ if time_compression_ratio == 4:
291
+ add_spatial_upsample = bool(i < num_spatial_upsample_layers)
292
+ add_time_upsample = bool(i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block)
293
+ else:
294
+ raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
295
+
296
+ upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
297
+ upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
298
+ upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
299
+ up_block = get_up_block3d(
300
+ up_block_type,
301
+ num_layers=self.layers_per_block + 1,
302
+ in_channels=prev_output_channel,
303
+ out_channels=output_channel,
304
+ prev_output_channel=None,
305
+ add_upsample=bool(add_spatial_upsample or add_time_upsample),
306
+ upsample_scale_factor=upsample_scale_factor,
307
+ resnet_eps=1e-6,
308
+ resnet_act_fn=act_fn,
309
+ resnet_groups=norm_num_groups,
310
+ attention_head_dim=output_channel,
311
+ temb_channels=temb_channels,
312
+ resnet_time_scale_shift=norm_type,
313
+ )
314
+ self.up_blocks.append(up_block)
315
+ prev_output_channel = output_channel
316
+
317
+ # out
318
+ if norm_type == "spatial":
319
+ self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
320
+ else:
321
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
322
+ self.conv_act = nn.SiLU()
323
+ self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
324
+
325
+ self.gradient_checkpointing = False
326
+
327
+ def forward(
328
+ self,
329
+ sample: torch.FloatTensor,
330
+ latent_embeds: Optional[torch.FloatTensor] = None,
331
+ ) -> torch.FloatTensor:
332
+ r"""The forward method of the `DecoderCausal3D` class."""
333
+ assert len(sample.shape) == 5, "The input tensor should have 5 dimensions."
334
+
335
+ sample = self.conv_in(sample)
336
+
337
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
338
+ if self.training and self.gradient_checkpointing:
339
+
340
+ def create_custom_forward(module):
341
+ def custom_forward(*inputs):
342
+ return module(*inputs)
343
+
344
+ return custom_forward
345
+
346
+ if is_torch_version(">=", "1.11.0"):
347
+ # middle
348
+ sample = torch.utils.checkpoint.checkpoint(
349
+ create_custom_forward(self.mid_block),
350
+ sample,
351
+ latent_embeds,
352
+ use_reentrant=False,
353
+ )
354
+ sample = sample.to(upscale_dtype)
355
+
356
+ # up
357
+ for up_block in self.up_blocks:
358
+ sample = torch.utils.checkpoint.checkpoint(
359
+ create_custom_forward(up_block),
360
+ sample,
361
+ latent_embeds,
362
+ use_reentrant=False,
363
+ )
364
+ else:
365
+ # middle
366
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample, latent_embeds)
367
+ sample = sample.to(upscale_dtype)
368
+
369
+ # up
370
+ for up_block in self.up_blocks:
371
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
372
+ else:
373
+ # middle
374
+ sample = self.mid_block(sample, latent_embeds)
375
+ sample = sample.to(upscale_dtype)
376
+
377
+ # up
378
+ for up_block in self.up_blocks:
379
+ sample = up_block(sample, latent_embeds)
380
+
381
+ # post-process
382
+ if latent_embeds is None:
383
+ sample = self.conv_norm_out(sample)
384
+ else:
385
+ sample = self.conv_norm_out(sample, latent_embeds)
386
+ sample = self.conv_act(sample)
387
+ sample = self.conv_out(sample)
388
+
389
+ return sample
390
+
391
+
392
+ class DiagonalGaussianDistribution(object):
393
+ def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
394
+ if parameters.ndim == 3:
395
+ dim = 2 # (B, L, C)
396
+ elif parameters.ndim == 5 or parameters.ndim == 4:
397
+ dim = 1 # (B, C, T, H ,W) / (B, C, H, W)
398
+ else:
399
+ raise NotImplementedError
400
+ self.parameters = parameters
401
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
402
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
403
+ self.deterministic = deterministic
404
+ self.std = torch.exp(0.5 * self.logvar)
405
+ self.var = torch.exp(self.logvar)
406
+ if self.deterministic:
407
+ self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device, dtype=self.parameters.dtype)
408
+
409
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
410
+ # make sure sample is on the same device as the parameters and has same dtype
411
+ sample = randn_tensor(
412
+ self.mean.shape,
413
+ generator=generator,
414
+ device=self.parameters.device,
415
+ dtype=self.parameters.dtype,
416
+ )
417
+ x = self.mean + self.std * sample
418
+ return x
419
+
420
+ def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
421
+ if self.deterministic:
422
+ return torch.Tensor([0.0])
423
+ else:
424
+ reduce_dim = list(range(1, self.mean.ndim))
425
+ if other is None:
426
+ return 0.5 * torch.sum(
427
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
428
+ dim=reduce_dim,
429
+ )
430
+ else:
431
+ return 0.5 * torch.sum(
432
+ torch.pow(self.mean - other.mean, 2) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar,
433
+ dim=reduce_dim,
434
+ )
435
+
436
+ def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
437
+ if self.deterministic:
438
+ return torch.Tensor([0.0])
439
+ logtwopi = np.log(2.0 * np.pi)
440
+ return 0.5 * torch.sum(
441
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
442
+ dim=dims,
443
+ )
444
+
445
+ def mode(self) -> torch.Tensor:
446
+ return self.mean
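A minimal sketch of how `load_vae` above is consumed, assuming a CUDA device and a placeholder checkpoint path: encode pixels to latents, scale by `SCALING_FACTOR`, and invert both steps on decode. Shapes are for the "884" VAE (4x temporal, 8x spatial compression).

```python
import torch

from hunyuan_model.vae import SCALING_FACTOR, load_vae

# Placeholder path; point this at the actual HunyuanVideo VAE checkpoint.
vae, _, s_ratio, t_ratio = load_vae(
    vae_dtype=torch.float16, device="cuda", vae_path="path/to/vae/pytorch_model.pt"
)  # s_ratio == 8, t_ratio == 4 for the 884 VAE

# 1 video, 3 channels, 9 frames, 256x256 pixels, values in [-1, 1]
video = torch.randn(1, 3, 9, 256, 256, dtype=torch.float16, device="cuda")

with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample() * SCALING_FACTOR   # [1, 16, 3, 32, 32]
    recon = vae.decode(latents / SCALING_FACTOR, return_dict=False)[0]  # [1, 3, 9, 256, 256]
```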
hv_generate_video.py ADDED
@@ -0,0 +1,911 @@
1
+ import argparse
2
+ from datetime import datetime
3
+ from pathlib import Path
4
+ import random
5
+ import sys
6
+ import os
7
+ import time
8
+ from typing import Optional, Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torchvision
13
+ import accelerate
14
+ from diffusers.utils.torch_utils import randn_tensor
15
+ from transformers.models.llama import LlamaModel
16
+ from tqdm import tqdm
17
+ import av
18
+ from einops import rearrange
19
+ from safetensors.torch import load_file, save_file
20
+ from safetensors import safe_open
21
+ from PIL import Image
22
+
23
+ from hunyuan_model import vae
24
+ from hunyuan_model.text_encoder import TextEncoder
25
+ from hunyuan_model.text_encoder import PROMPT_TEMPLATE
26
+ from hunyuan_model.vae import load_vae
27
+ from hunyuan_model.models import load_transformer, get_rotary_pos_embed
28
+ from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
29
+ from networks import lora
30
+
31
+ try:
32
+ from lycoris.kohya import create_network_from_weights
33
+ except ImportError:
34
+ pass
35
+
36
+ from utils.model_utils import str_to_dtype
37
+ from utils.safetensors_utils import mem_eff_save_file
38
+ from dataset.image_video_dataset import load_video, glob_images, resize_image_to_bucket
39
+
40
+ import logging
41
+
42
+ logger = logging.getLogger(__name__)
43
+ logging.basicConfig(level=logging.INFO)
44
+
45
+
46
+ def clean_memory_on_device(device):
47
+ if device.type == "cuda":
48
+ torch.cuda.empty_cache()
49
+ elif device.type == "cpu":
50
+ pass
51
+ elif device.type == "mps": # not tested
52
+ torch.mps.empty_cache()
53
+
54
+
55
+ def synchronize_device(device: torch.device):
56
+ if device.type == "cuda":
57
+ torch.cuda.synchronize()
58
+ elif device.type == "xpu":
59
+ torch.xpu.synchronize()
60
+ elif device.type == "mps":
61
+ torch.mps.synchronize()
62
+
63
+
64
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=1, fps=24):
65
+ """save videos by video tensor
66
+ """Save videos from a video tensor.
67
+
68
+ Args:
69
+ videos (torch.Tensor): video tensor predicted by the model
70
+ path (str): path to save video
71
+ rescale (bool, optional): rescale the video tensor from [-1, 1] to [0, 1]. Defaults to False.
72
+ n_rows (int, optional): Defaults to 1.
73
+ fps (int, optional): video save fps. Defaults to 24.
74
+ """
75
+ videos = rearrange(videos, "b c t h w -> t b c h w")
76
+ outputs = []
77
+ for x in videos:
78
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
79
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
80
+ if rescale:
81
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
82
+ x = torch.clamp(x, 0, 1)
83
+ x = (x * 255).numpy().astype(np.uint8)
84
+ outputs.append(x)
85
+
86
+ os.makedirs(os.path.dirname(path), exist_ok=True)
87
+
88
+ # # save video with av
89
+ # container = av.open(path, "w")
90
+ # stream = container.add_stream("libx264", rate=fps)
91
+ # for x in outputs:
92
+ # frame = av.VideoFrame.from_ndarray(x, format="rgb24")
93
+ # packet = stream.encode(frame)
94
+ # container.mux(packet)
95
+ # packet = stream.encode(None)
96
+ # container.mux(packet)
97
+ # container.close()
98
+
99
+ height, width, _ = outputs[0].shape
100
+
101
+ # create output container
102
+ container = av.open(path, mode="w")
103
+
104
+ # create video stream
105
+ codec = "libx264"
106
+ pixel_format = "yuv420p"
107
+ stream = container.add_stream(codec, rate=fps)
108
+ stream.width = width
109
+ stream.height = height
110
+ stream.pix_fmt = pixel_format
111
+ stream.bit_rate = 4000000 # 4Mbit/s
112
+
113
+ for frame_array in outputs:
114
+ frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
115
+ packets = stream.encode(frame)
116
+ for packet in packets:
117
+ container.mux(packet)
118
+
119
+ for packet in stream.encode():
120
+ container.mux(packet)
121
+
122
+ container.close()
123
+
124
+
125
+ def save_images_grid(
126
+ videos: torch.Tensor, parent_dir: str, image_name: str, rescale: bool = False, n_rows: int = 1, create_subdir=True
127
+ ):
128
+ videos = rearrange(videos, "b c t h w -> t b c h w")
129
+ outputs = []
130
+ for x in videos:
131
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
132
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
133
+ if rescale:
134
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
135
+ x = torch.clamp(x, 0, 1)
136
+ x = (x * 255).numpy().astype(np.uint8)
137
+ outputs.append(x)
138
+
139
+ if create_subdir:
140
+ output_dir = os.path.join(parent_dir, image_name)
141
+ else:
142
+ output_dir = parent_dir
143
+
144
+ os.makedirs(output_dir, exist_ok=True)
145
+ for i, x in enumerate(outputs):
146
+ image_path = os.path.join(output_dir, f"{image_name}_{i:03d}.png")
147
+ image = Image.fromarray(x)
148
+ image.save(image_path)
149
+
150
+
151
+ # region Encoding prompt
152
+
153
+
154
+ def encode_prompt(prompt: Union[str, list[str]], device: torch.device, num_videos_per_prompt: int, text_encoder: TextEncoder):
155
+ r"""
156
+ Encodes the prompt into text encoder hidden states.
157
+
158
+ Args:
159
+ prompt (`str` or `List[str]`):
160
+ prompt to be encoded
161
+ device: (`torch.device`):
162
+ torch device
163
+ num_videos_per_prompt (`int`):
164
+ number of videos that should be generated per prompt
165
+ text_encoder (TextEncoder):
166
+ text encoder to be used for encoding the prompt
167
+ """
168
+ # LoRA and Textual Inversion are not supported in this script
169
+ # negative prompt and prompt embedding are not supported in this script
170
+ # clip_skip is not supported in this script because it is not used in the original script
171
+ data_type = "video" # video only, image is not supported
172
+
173
+ text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)
174
+
175
+ with torch.no_grad():
176
+ prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type, device=device)
177
+ prompt_embeds = prompt_outputs.hidden_state
178
+
179
+ attention_mask = prompt_outputs.attention_mask
180
+ if attention_mask is not None:
181
+ attention_mask = attention_mask.to(device)
182
+ bs_embed, seq_len = attention_mask.shape
183
+ attention_mask = attention_mask.repeat(1, num_videos_per_prompt)
184
+ attention_mask = attention_mask.view(bs_embed * num_videos_per_prompt, seq_len)
185
+
186
+ prompt_embeds_dtype = text_encoder.dtype
187
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
188
+
189
+ if prompt_embeds.ndim == 2:
190
+ bs_embed, _ = prompt_embeds.shape
191
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
192
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
193
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1)
194
+ else:
195
+ bs_embed, seq_len, _ = prompt_embeds.shape
196
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
197
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
198
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
199
+
200
+ return prompt_embeds, attention_mask
201
+
202
+
203
+ def encode_input_prompt(prompt: Union[str, list[str]], args, device, fp8_llm=False, accelerator=None):
204
+ # constants
205
+ prompt_template_video = "dit-llm-encode-video"
206
+ prompt_template = "dit-llm-encode"
207
+ text_encoder_dtype = torch.float16
208
+ text_encoder_type = "llm"
209
+ text_len = 256
210
+ hidden_state_skip_layer = 2
211
+ apply_final_norm = False
212
+ reproduce = False
213
+
214
+ text_encoder_2_type = "clipL"
215
+ text_len_2 = 77
216
+
217
+ num_videos = 1
218
+
219
+ # if args.prompt_template_video is not None:
220
+ # crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get("crop_start", 0)
221
+ # elif args.prompt_template is not None:
222
+ # crop_start = PROMPT_TEMPLATE[args.prompt_template].get("crop_start", 0)
223
+ # else:
224
+ # crop_start = 0
225
+ crop_start = PROMPT_TEMPLATE[prompt_template_video].get("crop_start", 0)
226
+ max_length = text_len + crop_start
227
+
228
+ # prompt_template
229
+ prompt_template = PROMPT_TEMPLATE[prompt_template]
230
+
231
+ # prompt_template_video
232
+ prompt_template_video = PROMPT_TEMPLATE[prompt_template_video] # if args.prompt_template_video is not None else None
233
+
234
+ # load text encoders
235
+ logger.info(f"loading text encoder: {args.text_encoder1}")
236
+ text_encoder = TextEncoder(
237
+ text_encoder_type=text_encoder_type,
238
+ max_length=max_length,
239
+ text_encoder_dtype=text_encoder_dtype,
240
+ text_encoder_path=args.text_encoder1,
241
+ tokenizer_type=text_encoder_type,
242
+ prompt_template=prompt_template,
243
+ prompt_template_video=prompt_template_video,
244
+ hidden_state_skip_layer=hidden_state_skip_layer,
245
+ apply_final_norm=apply_final_norm,
246
+ reproduce=reproduce,
247
+ )
248
+ text_encoder.eval()
249
+ if fp8_llm:
250
+ org_dtype = text_encoder.dtype
251
+ logger.info(f"Moving and casting text encoder to {device} and torch.float8_e4m3fn")
252
+ text_encoder.to(device=device, dtype=torch.float8_e4m3fn)
253
+
254
+ # prepare LLM for fp8
255
+ def prepare_fp8(llama_model: LlamaModel, target_dtype):
256
+ def forward_hook(module):
257
+ def forward(hidden_states):
258
+ input_dtype = hidden_states.dtype
259
+ hidden_states = hidden_states.to(torch.float32)
260
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
261
+ hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon)
262
+ return module.weight.to(input_dtype) * hidden_states.to(input_dtype)
263
+
264
+ return forward
265
+
266
+ for module in llama_model.modules():
267
+ if module.__class__.__name__ in ["Embedding"]:
268
+ # print("set", module.__class__.__name__, "to", target_dtype)
269
+ module.to(target_dtype)
270
+ if module.__class__.__name__ in ["LlamaRMSNorm"]:
271
+ # print("set", module.__class__.__name__, "hooks")
272
+ module.forward = forward_hook(module)
273
+
274
+ prepare_fp8(text_encoder.model, org_dtype)
275
+
276
+ logger.info(f"loading text encoder 2: {args.text_encoder2}")
277
+ text_encoder_2 = TextEncoder(
278
+ text_encoder_type=text_encoder_2_type,
279
+ max_length=text_len_2,
280
+ text_encoder_dtype=text_encoder_dtype,
281
+ text_encoder_path=args.text_encoder2,
282
+ tokenizer_type=text_encoder_2_type,
283
+ reproduce=reproduce,
284
+ )
285
+ text_encoder_2.eval()
286
+
287
+ # encode prompt
288
+ logger.info(f"Encoding prompt with text encoder 1")
289
+ text_encoder.to(device=device)
290
+ if fp8_llm:
291
+ with accelerator.autocast():
292
+ prompt_embeds, prompt_mask = encode_prompt(prompt, device, num_videos, text_encoder)
293
+ else:
294
+ prompt_embeds, prompt_mask = encode_prompt(prompt, device, num_videos, text_encoder)
295
+ text_encoder = None
296
+ clean_memory_on_device(device)
297
+
298
+ logger.info(f"Encoding prompt with text encoder 2")
299
+ text_encoder_2.to(device=device)
300
+ prompt_embeds_2, prompt_mask_2 = encode_prompt(prompt, device, num_videos, text_encoder_2)
301
+
302
+ prompt_embeds = prompt_embeds.to("cpu")
303
+ prompt_mask = prompt_mask.to("cpu")
304
+ prompt_embeds_2 = prompt_embeds_2.to("cpu")
305
+ prompt_mask_2 = prompt_mask_2.to("cpu")
306
+
307
+ text_encoder_2 = None
308
+ clean_memory_on_device(device)
309
+
310
+ return prompt_embeds, prompt_mask, prompt_embeds_2, prompt_mask_2
311
+
312
+
313
+ # endregion
314
+
315
+
316
+ def load_images(image_dir, video_length, bucket_reso):
317
+ image_files = glob_images(image_dir)
318
+ if len(image_files) == 0:
319
+ raise ValueError(f"No image files found in {image_dir}")
320
+ if len(image_files) < video_length:
321
+ raise ValueError(f"Number of images in {image_dir} is less than {video_length}")
322
+
323
+ image_files.sort()
324
+ images = []
325
+ for image_file in image_files[:video_length]:
326
+ image = Image.open(image_file)
327
+ image = resize_image_to_bucket(image, bucket_reso) # returns a numpy array
328
+ images.append(image)
329
+
330
+ return images
331
+
332
+
333
+ def prepare_vae(args, device):
334
+ vae_dtype = torch.float16 if args.vae_dtype is None else str_to_dtype(args.vae_dtype)
335
+ vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device=device, vae_path=args.vae)
336
+ vae.eval()
337
+ # vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}
338
+
339
+ # set chunk_size to CausalConv3d recursively
340
+ chunk_size = args.vae_chunk_size
341
+ if chunk_size is not None:
342
+ vae.set_chunk_size_for_causal_conv_3d(chunk_size)
343
+ logger.info(f"Set chunk_size to {chunk_size} for CausalConv3d")
344
+
345
+ if args.vae_spatial_tile_sample_min_size is not None:
346
+ vae.enable_spatial_tiling(True)
347
+ vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size
348
+ vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8
349
+ # elif args.vae_tiling:
350
+ else:
351
+ vae.enable_spatial_tiling(True)
352
+
353
+ return vae, vae_dtype
354
+
355
+
356
+ def encode_to_latents(args, video, device):
357
+ vae, vae_dtype = prepare_vae(args, device)
358
+
359
+ video = video.to(device=device, dtype=vae_dtype)
360
+ video = video * 2 - 1 # 0, 1 -> -1, 1
361
+ with torch.no_grad():
362
+ latents = vae.encode(video).latent_dist.sample()
363
+
364
+ if hasattr(vae.config, "shift_factor") and vae.config.shift_factor:
365
+ latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor
366
+ else:
367
+ latents = latents * vae.config.scaling_factor
368
+
369
+ return latents
370
+
371
+
372
+ def decode_latents(args, latents, device):
373
+ vae, vae_dtype = prepare_vae(args, device)
374
+
375
+ expand_temporal_dim = False
376
+ if len(latents.shape) == 4:
377
+ latents = latents.unsqueeze(2)
378
+ expand_temporal_dim = True
379
+ elif len(latents.shape) == 5:
380
+ pass
381
+ else:
382
+ raise ValueError(f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}.")
383
+
384
+ if hasattr(vae.config, "shift_factor") and vae.config.shift_factor:
385
+ latents = latents / vae.config.scaling_factor + vae.config.shift_factor
386
+ else:
387
+ latents = latents / vae.config.scaling_factor
388
+
389
+ latents = latents.to(device=device, dtype=vae_dtype)
390
+ with torch.no_grad():
391
+ image = vae.decode(latents, return_dict=False)[0]
392
+
393
+ if expand_temporal_dim:
394
+ image = image.squeeze(2)
395
+
396
+ image = (image / 2 + 0.5).clamp(0, 1)
397
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
398
+ image = image.cpu().float()
399
+
400
+ return image
401
+
402
+
403
+ def parse_args():
404
+ parser = argparse.ArgumentParser(description="HunyuanVideo inference script")
405
+
406
+ parser.add_argument("--dit", type=str, required=True, help="DiT checkpoint path or directory")
407
+ parser.add_argument(
408
+ "--dit_in_channels",
409
+ type=int,
410
+ default=None,
411
+ help="input channels for DiT, default is None (automatically detect). 32 for SkyReels-I2V, 16 for others",
412
+ )
413
+ parser.add_argument("--vae", type=str, required=True, help="VAE checkpoint path or directory")
414
+ parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16")
415
+ parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory")
416
+ parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory")
417
+
418
+ # LoRA
419
+ parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path")
420
+ parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier")
421
+ parser.add_argument(
422
+ "--save_merged_model",
423
+ type=str,
424
+ default=None,
425
+ help="Save merged model to path. If specified, no inference will be performed.",
426
+ )
427
+ parser.add_argument("--exclude_single_blocks", action="store_true", help="Exclude single blocks when loading LoRA weights")
428
+
429
+ # inference
430
+ parser.add_argument("--prompt", type=str, required=True, help="prompt for generation")
431
+ parser.add_argument("--negative_prompt", type=str, default=None, help="negative prompt for generation")
432
+ parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size")
433
+ parser.add_argument("--video_length", type=int, default=129, help="video length")
434
+ parser.add_argument("--fps", type=int, default=24, help="video fps")
435
+ parser.add_argument("--infer_steps", type=int, default=50, help="number of inference steps")
436
+ parser.add_argument("--save_path", type=str, required=True, help="path to save generated video")
437
+ parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
438
+ parser.add_argument(
439
+ "--guidance_scale",
440
+ type=float,
441
+ default=1.0,
442
+ help="Guidance scale for classifier free guidance. Default is 1.0 (means no guidance)",
443
+ )
444
+ parser.add_argument("--embedded_cfg_scale", type=float, default=6.0, help="Embedded classifier-free guidance scale.")
445
+ parser.add_argument("--video_path", type=str, default=None, help="path to video for video2video inference")
446
+ parser.add_argument(
447
+ "--image_path", type=str, default=None, help="path to image for image2video inference, only works for SkyReels-I2V model"
448
+ )
449
+ parser.add_argument(
450
+ "--split_uncond",
451
+ action="store_true",
452
+ help="split unconditional call for classifier free guidance, slower but less memory usage",
453
+ )
454
+ parser.add_argument("--strength", type=float, default=0.8, help="strength for video2video inference")
455
+
456
+ # Flow Matching
457
+ parser.add_argument("--flow_shift", type=float, default=7.0, help="Shift factor for flow matching schedulers.")
458
+
459
+ parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model")
460
+ parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
461
+ parser.add_argument(
462
+ "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
463
+ )
464
+ parser.add_argument(
465
+ "--attn_mode", type=str, default="torch", choices=["flash", "torch", "sageattn", "xformers", "sdpa"], help="attention mode"
466
+ )
467
+ parser.add_argument(
468
+ "--split_attn", action="store_true", help="use split attention, default is False. if True, --split_uncond becomes True"
469
+ )
470
+ parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
471
+ parser.add_argument(
472
+ "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
473
+ )
474
+ parser.add_argument("--blocks_to_swap", type=int, default=None, help="number of blocks to swap in the model")
475
+ parser.add_argument("--img_in_txt_in_offloading", action="store_true", help="offload img_in and txt_in to cpu")
476
+ parser.add_argument(
477
+ "--output_type", type=str, default="video", choices=["video", "images", "latent", "both"], help="output type"
478
+ )
479
+ parser.add_argument("--no_metadata", action="store_true", help="do not save metadata")
480
+ parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. no inference")
481
+ parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference")
482
+
483
+ args = parser.parse_args()
484
+
485
+ assert (args.latent_path is None or len(args.latent_path) == 0) or (
486
+ args.output_type == "images" or args.output_type == "video"
487
+ ), "latent_path is only supported for images or video output"
488
+
489
+ # update dit_weight based on model_base if not exists
490
+
491
+ return args
492
+
493
+
494
+ def check_inputs(args):
495
+ height = args.video_size[0]
496
+ width = args.video_size[1]
497
+ video_length = args.video_length
498
+
499
+ if height % 8 != 0 or width % 8 != 0:
500
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
501
+ return height, width, video_length
502
+
503
+
504
+ def main():
505
+ args = parse_args()
506
+
507
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
508
+ device = torch.device(device)
509
+ dit_dtype = torch.bfloat16
510
+ dit_weight_dtype = torch.float8_e4m3fn if args.fp8 else dit_dtype
511
+ logger.info(f"Using device: {device}, DiT precision: {dit_dtype}, weight precision: {dit_weight_dtype}")
512
+
513
+ original_base_names = None
514
+ if args.latent_path is not None and len(args.latent_path) > 0:
515
+ original_base_names = []
516
+ latents_list = []
517
+ seeds = []
518
+ for latent_path in args.latent_path:
519
+ original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0])
520
+ seed = 0
521
+
522
+ if os.path.splitext(latent_path)[1] != ".safetensors":
523
+ latents = torch.load(latent_path, map_location="cpu")
524
+ else:
525
+ latents = load_file(latent_path)["latent"]
526
+ with safe_open(latent_path, framework="pt") as f:
527
+ metadata = f.metadata()
528
+ if metadata is None:
529
+ metadata = {}
530
+ logger.info(f"Loaded metadata: {metadata}")
531
+
532
+ if "seeds" in metadata:
533
+ seed = int(metadata["seeds"])
534
+
535
+ seeds.append(seed)
536
+ latents_list.append(latents)
537
+
538
+ logger.info(f"Loaded latent from {latent_path}. Shape: {latents.shape}")
539
+ latents = torch.stack(latents_list, dim=0)
540
+ else:
541
+ # prepare accelerator
542
+ mixed_precision = "bf16" if dit_dtype == torch.bfloat16 else "fp16"
543
+ accelerator = accelerate.Accelerator(mixed_precision=mixed_precision)
544
+
545
+ # load prompt
546
+ prompt = args.prompt # TODO load prompts from file
547
+ assert prompt is not None, "prompt is required"
548
+
549
+ # check inputs: height, width, video_length, etc. may be changed for each generation in the future
550
+ height, width, video_length = check_inputs(args)
551
+
552
+ # encode prompt with LLM and Text Encoder
553
+ logger.info(f"Encoding prompt: {prompt}")
554
+
555
+ do_classifier_free_guidance = args.guidance_scale != 1.0
556
+ if do_classifier_free_guidance:
557
+ negative_prompt = args.negative_prompt
558
+ if negative_prompt is None:
559
+ logger.info("Negative prompt is not provided, using empty prompt")
560
+ negative_prompt = ""
561
+ logger.info(f"Encoding negative prompt: {negative_prompt}")
562
+ prompt = [negative_prompt, prompt]
563
+ else:
564
+ if args.negative_prompt is not None:
565
+ logger.warning("Negative prompt is provided but guidance_scale is 1.0, negative prompt will be ignored.")
566
+
567
+ prompt_embeds, prompt_mask, prompt_embeds_2, prompt_mask_2 = encode_input_prompt(
568
+ prompt, args, device, args.fp8_llm, accelerator
569
+ )
570
+
571
+ # encode latents for video2video inference
572
+ video_latents = None
573
+ if args.video_path is not None:
574
+ # v2v inference
575
+ logger.info(f"Video2Video inference: {args.video_path}")
576
+
577
+ if os.path.isfile(args.video_path):
578
+ video = load_video(args.video_path, 0, video_length, bucket_reso=(width, height)) # list of frames
579
+ else:
580
+ video = load_images(args.video_path, video_length, bucket_reso=(width, height)) # list of frames
581
+
582
+ if len(video) < video_length:
583
+ raise ValueError(f"Video length is less than {video_length}")
584
+ video = np.stack(video, axis=0) # F, H, W, C
585
+ video = torch.from_numpy(video).permute(3, 0, 1, 2).unsqueeze(0).float() # 1, C, F, H, W
586
+ video = video / 255.0
587
+
588
+ logger.info(f"Encoding video to latents")
589
+ video_latents = encode_to_latents(args, video, device)
590
+ video_latents = video_latents.to(device=device, dtype=dit_dtype)
591
+
592
+ clean_memory_on_device(device)
593
+
594
+ # encode latents for image2video inference
595
+ image_latents = None
596
+ if args.image_path is not None:
597
+ # i2v inference
598
+ logger.info(f"Image2Video inference: {args.image_path}")
599
+
600
+ image = Image.open(args.image_path)
601
+ image = resize_image_to_bucket(image, (width, height)) # returns a numpy array
602
+ image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).unsqueeze(2).float() # 1, C, 1, H, W
603
+ image = image / 255.0
604
+
605
+ logger.info(f"Encoding image to latents")
606
+ image_latents = encode_to_latents(args, image, device) # 1, C, 1, H, W
607
+ image_latents = image_latents.to(device=device, dtype=dit_dtype)
608
+
609
+ clean_memory_on_device(device)
610
+
611
+ # load DiT model
612
+ blocks_to_swap = args.blocks_to_swap if args.blocks_to_swap else 0
613
+ loading_device = "cpu" # if blocks_to_swap > 0 else device
614
+
615
+ logger.info(f"Loading DiT model from {args.dit}")
616
+ if args.attn_mode == "sdpa":
617
+ args.attn_mode = "torch"
618
+
619
+ # if image_latents is given, the model should be I2V model, so the in_channels should be 32
620
+ dit_in_channels = args.dit_in_channels if args.dit_in_channels is not None else (32 if image_latents is not None else 16)
621
+
622
+ # if we use LoRA, weights should be bf16 instead of fp8, because merging should be done in bf16
623
+ # the model is too large, so we load the model to cpu. in addition, the .pt file is loaded to cpu anyway
624
+ # on-the-fly merging would be a solution for this issue for .safetensors files (not implemented yet)
625
+ transformer = load_transformer(
626
+ args.dit, args.attn_mode, args.split_attn, loading_device, dit_dtype, in_channels=dit_in_channels
627
+ )
628
+ transformer.eval()
629
+
630
+ # load LoRA weights
631
+ if args.lora_weight is not None and len(args.lora_weight) > 0:
632
+ for i, lora_weight in enumerate(args.lora_weight):
633
+ if args.lora_multiplier is not None and len(args.lora_multiplier) > i:
634
+ lora_multiplier = args.lora_multiplier[i]
635
+ else:
636
+ lora_multiplier = 1.0
637
+
638
+ logger.info(f"Loading LoRA weights from {lora_weight} with multiplier {lora_multiplier}")
639
+ weights_sd = load_file(lora_weight)
640
+
641
+ # Filter to exclude keys that are part of single_blocks
642
+ if args.exclude_single_blocks:
643
+ filtered_weights = {k: v for k, v in weights_sd.items() if "single_blocks" not in k}
644
+ weights_sd = filtered_weights
645
+
646
+ if args.lycoris:
647
+ lycoris_net, _ = create_network_from_weights(
648
+ multiplier=lora_multiplier,
649
+ file=None,
650
+ weights_sd=weights_sd,
651
+ unet=transformer,
652
+ text_encoder=None,
653
+ vae=None,
654
+ for_inference=True,
655
+ )
656
+ else:
657
+ network = lora.create_arch_network_from_weights(
658
+ lora_multiplier, weights_sd, unet=transformer, for_inference=True
659
+ )
660
+ logger.info("Merging LoRA weights to DiT model")
661
+
662
+ # try:
663
+ # network.apply_to(None, transformer, apply_text_encoder=False, apply_unet=True)
664
+ # info = network.load_state_dict(weights_sd, strict=True)
665
+ # logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
666
+ # network.eval()
667
+ # network.to(device)
668
+ # except Exception as e:
669
+ if args.lycoris:
670
+ lycoris_net.merge_to(None, transformer, weights_sd, dtype=None, device=device)
671
+ else:
672
+ network.merge_to(None, transformer, weights_sd, device=device, non_blocking=True)
673
+
674
+ synchronize_device(device)
675
+
676
+ logger.info("LoRA weights loaded")
677
+
678
+ # save model here before casting to dit_weight_dtype
679
+ if args.save_merged_model:
680
+ logger.info(f"Saving merged model to {args.save_merged_model}")
681
+ mem_eff_save_file(transformer.state_dict(), args.save_merged_model) # save_file needs a lot of memory
682
+ logger.info("Merged model saved")
683
+ return
684
+
685
+ if blocks_to_swap > 0:
686
+ logger.info(f"Casting model to {dit_weight_dtype}")
687
+ transformer.to(dtype=dit_weight_dtype)
688
+ logger.info(f"Enable swap {blocks_to_swap} blocks to CPU from device: {device}")
689
+ transformer.enable_block_swap(blocks_to_swap, device, supports_backward=False)
690
+ transformer.move_to_device_except_swap_blocks(device)
691
+ transformer.prepare_block_swap_before_forward()
692
+ else:
693
+ logger.info(f"Moving and casting model to {device} and {dit_weight_dtype}")
694
+ transformer.to(device=device, dtype=dit_weight_dtype)
695
+ if args.img_in_txt_in_offloading:
696
+ logger.info("Enable offloading img_in and txt_in to CPU")
697
+ transformer.enable_img_in_txt_in_offloading()
698
+
699
+ # load scheduler
700
+ logger.info(f"Loading scheduler")
701
+ scheduler = FlowMatchDiscreteScheduler(shift=args.flow_shift, reverse=True, solver="euler")
702
+
703
+ # Prepare timesteps
704
+ num_inference_steps = args.infer_steps
705
+ scheduler.set_timesteps(num_inference_steps, device=device) # n_tokens is not used in FlowMatchDiscreteScheduler
706
+ timesteps = scheduler.timesteps
707
+
708
+ # Prepare generator
709
+ num_videos_per_prompt = 1 # args.num_videos # currently only support 1 video per prompt, this is a batch size
710
+ seed = args.seed
711
+ if seed is None:
712
+ seeds = [random.randint(0, 2**32 - 1) for _ in range(num_videos_per_prompt)]
713
+ elif isinstance(seed, int):
714
+ seeds = [seed + i for i in range(num_videos_per_prompt)]
715
+ else:
716
+ raise ValueError(f"Seed must be an integer or None, got {seed}.")
717
+ generator = [torch.Generator(device).manual_seed(seed) for seed in seeds]
718
+
719
+ # Prepare noisy latents
720
+ num_channels_latents = 16 # transformer.config.in_channels
721
+ vae_scale_factor = 2 ** (4 - 1) # len(self.vae.config.block_out_channels) == 4
722
+
723
+ vae_ver = vae.VAE_VER
724
+ if "884" in vae_ver:
725
+ latent_video_length = (video_length - 1) // 4 + 1
726
+ elif "888" in vae_ver:
727
+ latent_video_length = (video_length - 1) // 8 + 1
728
+ else:
729
+ latent_video_length = video_length
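+ # Worked example: with the default video_length of 129 and the "884" VAE
+ # (4x temporal, 8x spatial compression), latent_video_length = (129 - 1) // 4 + 1 = 33,
+ # and a 256x256 video maps to 32x32 latents (vae_scale_factor = 8).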
730
+
731
+ # shape = (
732
+ # num_videos_per_prompt,
733
+ # num_channels_latents,
734
+ # latent_video_length,
735
+ # height // vae_scale_factor,
736
+ # width // vae_scale_factor,
737
+ # )
738
+ # latents = randn_tensor(shape, generator=generator, device=device, dtype=dit_dtype)
739
+
740
+ # generate noise frame by frame so that the first N frames are identical when the same seed is given
741
+ shape_of_frame = (num_videos_per_prompt, num_channels_latents, 1, height // vae_scale_factor, width // vae_scale_factor)
742
+ latents = []
743
+ for i in range(latent_video_length):
744
+ latents.append(randn_tensor(shape_of_frame, generator=generator, device=device, dtype=dit_dtype))
745
+ latents = torch.cat(latents, dim=2)
746
+
747
+ # pad image_latents to match the length of video_latents
748
+ if image_latents is not None:
749
+ zero_latents = torch.zeros_like(latents)
750
+ zero_latents[:, :, :1, :, :] = image_latents
751
+ image_latents = zero_latents
752
+
753
+ if args.video_path is not None:
754
+ # v2v inference
755
+ noise = latents
756
+ assert noise.shape == video_latents.shape, f"noise shape {noise.shape} != video_latents shape {video_latents.shape}"
757
+
758
+ num_inference_steps = int(num_inference_steps * args.strength)
759
+ timestep_start = scheduler.timesteps[-num_inference_steps] # larger strength -> more inference steps and an earlier (noisier) start timestep
760
+ t = timestep_start / 1000.0
761
+ latents = noise * t + video_latents * (1 - t)
762
+
763
+ timesteps = timesteps[-num_inference_steps:]
764
+
765
+ logger.info(f"strength: {args.strength}, num_inference_steps: {num_inference_steps}, timestep_start: {timestep_start}")
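+ # Worked example: with --infer_steps 50 and --strength 0.8, only the last 40 timesteps are
+ # denoised, and the starting latent blends fresh noise (weight t) with the encoded input
+ # video (weight 1 - t) at that starting timestep.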
766
+
767
+ # FlowMatchDiscreteScheduler does not have init_noise_sigma
768
+
769
+ # Denoising loop
770
+ embedded_guidance_scale = args.embedded_cfg_scale
771
+ if embedded_guidance_scale is not None:
772
+ guidance_expand = torch.tensor([embedded_guidance_scale * 1000.0] * latents.shape[0], dtype=torch.float32, device="cpu")
773
+ guidance_expand = guidance_expand.to(device=device, dtype=dit_dtype)
774
+ if do_classifier_free_guidance:
775
+ guidance_expand = torch.cat([guidance_expand, guidance_expand], dim=0)
776
+ else:
777
+ guidance_expand = None
778
+ freqs_cos, freqs_sin = get_rotary_pos_embed(vae_ver, transformer, video_length, height, width)
779
+ # n_tokens = freqs_cos.shape[0]
780
+
781
+ # move and cast all inputs to the correct device and dtype
782
+ prompt_embeds = prompt_embeds.to(device=device, dtype=dit_dtype)
783
+ prompt_mask = prompt_mask.to(device=device)
784
+ prompt_embeds_2 = prompt_embeds_2.to(device=device, dtype=dit_dtype)
785
+ prompt_mask_2 = prompt_mask_2.to(device=device)
786
+
787
+ freqs_cos = freqs_cos.to(device=device, dtype=dit_dtype)
788
+ freqs_sin = freqs_sin.to(device=device, dtype=dit_dtype)
789
+
790
+ num_warmup_steps = len(timesteps) - num_inference_steps * scheduler.order # this should be 0 in v2v inference
791
+
792
+ # assert split_uncond and split_attn
793
+ if args.split_attn and do_classifier_free_guidance and not args.split_uncond:
794
+ logger.warning("split_attn is enabled, split_uncond will be enabled as well.")
795
+ args.split_uncond = True
796
+
797
+ # with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]) as p:
798
+ with tqdm(total=num_inference_steps) as progress_bar:
799
+ for i, t in enumerate(timesteps):
800
+ latents = scheduler.scale_model_input(latents, t)
801
+
802
+ # predict the noise residual
803
+ with torch.no_grad(), accelerator.autocast():
804
+ latents_input = latents if not do_classifier_free_guidance else torch.cat([latents, latents], dim=0)
805
+ if image_latents is not None:
806
+ latents_image_input = (
807
+ image_latents if not do_classifier_free_guidance else torch.cat([image_latents, image_latents], dim=0)
808
+ )
809
+ latents_input = torch.cat([latents_input, latents_image_input], dim=1) # 1 or 2, C*2, F, H, W
810
+
811
+ batch_size = 1 if args.split_uncond else latents_input.shape[0]
812
+
813
+ noise_pred_list = []
814
+ for j in range(0, latents_input.shape[0], batch_size):
815
+ noise_pred = transformer( # For an input image (129, 192, 336) (1, 256, 256)
816
+ latents_input[j : j + batch_size], # [1, 16, 33, 24, 42]
817
+ t.repeat(batch_size).to(device=device, dtype=dit_dtype), # [1]
818
+ text_states=prompt_embeds[j : j + batch_size], # [1, 256, 4096]
819
+ text_mask=prompt_mask[j : j + batch_size], # [1, 256]
820
+ text_states_2=prompt_embeds_2[j : j + batch_size], # [1, 768]
821
+ freqs_cos=freqs_cos, # [seqlen, head_dim]
822
+ freqs_sin=freqs_sin, # [seqlen, head_dim]
823
+ guidance=guidance_expand[j : j + batch_size], # [1]
824
+ return_dict=True,
825
+ )["x"]
826
+ noise_pred_list.append(noise_pred)
827
+ noise_pred = torch.cat(noise_pred_list, dim=0)
828
+
829
+ # perform classifier free guidance
830
+ if do_classifier_free_guidance:
831
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
832
+ noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond)
833
+
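Written out, this is the standard classifier-free guidance combination, where the two halves of the batch are the unconditional and conditional velocity predictions and s is `args.guidance_scale` (s = 1 reduces to the conditional prediction):

```latex
\hat{v} = v_{\text{uncond}} + s\,\bigl(v_{\text{cond}} - v_{\text{uncond}}\bigr)
```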
834
+ # # SkyReels' rescale noise config is omitted for now
835
+ # if guidance_rescale > 0.0:
836
+ # # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
837
+ # noise_pred = rescale_noise_cfg(
838
+ # noise_pred,
839
+ # noise_pred_cond,
840
+ # guidance_rescale=self.guidance_rescale,
841
+ # )
842
+
843
+ # compute the previous noisy sample x_t -> x_t-1
844
+ latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
845
+
846
+ # update progress bar
847
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
848
+ if progress_bar is not None:
849
+ progress_bar.update()
850
+
851
+ # print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=-1))
852
+ # print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
853
+
854
+ latents = latents.detach().cpu()
855
+ transformer = None
856
+ clean_memory_on_device(device)
857
+
858
+ # Save samples
859
+ output_type = args.output_type
860
+ save_path = args.save_path # if args.save_path_suffix == "" else f"{args.save_path}_{args.save_path_suffix}"
861
+ os.makedirs(save_path, exist_ok=True)
862
+ time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
863
+
864
+ if output_type == "latent" or output_type == "both":
865
+ # save latent
866
+ for i, latent in enumerate(latents):
867
+ latent_path = f"{save_path}/{time_flag}_{i}_{seeds[i]}_latent.safetensors"
868
+
869
+ if args.no_metadata:
870
+ metadata = None
871
+ else:
872
+ metadata = {
873
+ "seeds": f"{seeds[i]}",
874
+ "prompt": f"{args.prompt}",
875
+ "height": f"{height}",
876
+ "width": f"{width}",
877
+ "video_length": f"{video_length}",
878
+ "infer_steps": f"{num_inference_steps}",
879
+ "guidance_scale": f"{args.guidance_scale}",
880
+ "embedded_cfg_scale": f"{args.embedded_cfg_scale}",
881
+ }
882
+ if args.negative_prompt is not None:
883
+ metadata["negative_prompt"] = f"{args.negative_prompt}"
884
+ sd = {"latent": latent}
885
+ save_file(sd, latent_path, metadata=metadata)
886
+
887
+ logger.info(f"Latent saved to: {latent_path}")
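The latent files written here can be loaded back together with their metadata via the `safetensors` API; a minimal sketch (the path below is hypothetical):

```python
from safetensors import safe_open

# hypothetical path produced by the loop above
latent_path = "outputs/20250101-120000_0_42_latent.safetensors"

with safe_open(latent_path, framework="pt", device="cpu") as f:
    metadata = f.metadata()          # dict of the strings stored above (prompt, seeds, ...)
    latent = f.get_tensor("latent")  # the tensor saved under the "latent" key

print(metadata.get("prompt"), latent.shape)
```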
888
+ if output_type == "video" or output_type == "both":
889
+ # save video
890
+ videos = decode_latents(args, latents, device)
891
+ for i, sample in enumerate(videos):
892
+ original_name = "" if original_base_names is None else f"_{original_base_names[i]}"
893
+ sample = sample.unsqueeze(0)
894
+ video_path = f"{save_path}/{time_flag}_{i}_{seeds[i]}{original_name}.mp4"
895
+ save_videos_grid(sample, video_path, fps=args.fps)
896
+ logger.info(f"Sample saved to: {video_path}")
897
+ elif output_type == "images":
898
+ # save images
899
+ videos = decode_latents(args, latents, device)
900
+ for i, sample in enumerate(videos):
901
+ original_name = "" if original_base_names is None else f"_{original_base_names[i]}"
902
+ sample = sample.unsqueeze(0)
903
+ image_name = f"{time_flag}_{i}_{seeds[i]}{original_name}"
904
+ save_images_grid(sample, save_path, image_name)
905
+ logger.info(f"Sample images saved to: {save_path}/{image_name}")
906
+
907
+ logger.info("Done!")
908
+
909
+
910
+ if __name__ == "__main__":
911
+ main()
hv_train.py ADDED
@@ -0,0 +1,1721 @@
1
+ import ast
2
+ import asyncio
3
+ from datetime import timedelta
4
+ import gc
5
+ import importlib
6
+ import argparse
7
+ import math
8
+ import os
9
+ import pathlib
10
+ import re
11
+ import sys
12
+ import random
13
+ import time
14
+ import json
15
+ from multiprocessing import Value
16
+ from typing import Any, Dict, List, Optional
17
+ import accelerate
18
+ import numpy as np
19
+ from packaging.version import Version
20
+
21
+ import huggingface_hub
22
+ import toml
23
+
24
+ import torch
25
+ from tqdm import tqdm
26
+ from accelerate.utils import set_seed
27
+ from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
28
+ from safetensors.torch import load_file, save_file
29
+ import transformers
30
+ from diffusers.optimization import (
31
+ SchedulerType as DiffusersSchedulerType,
32
+ TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION,
33
+ )
34
+ from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
35
+
36
+ from dataset import config_utils
37
+ from hunyuan_model.models import load_transformer, get_rotary_pos_embed_by_shape
38
+ import hunyuan_model.text_encoder as text_encoder_module
39
+ from hunyuan_model.vae import load_vae
40
+ import hunyuan_model.vae as vae_module
41
+ from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
42
+ import networks.lora as lora_module
43
+ from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
44
+ from dataset.image_video_dataset import ARCHITECTURE_HUNYUAN_VIDEO
45
+
46
+ import logging
47
+
48
+ from utils import huggingface_utils, model_utils, train_utils, sai_model_spec
49
+
50
+ logger = logging.getLogger(__name__)
51
+ logging.basicConfig(level=logging.INFO)
52
+
53
+
54
+ BASE_MODEL_VERSION_HUNYUAN_VIDEO = "hunyuan_video"
55
+
56
+ # TODO make separate file for some functions to commonize with other scripts
57
+
58
+
59
+ def clean_memory_on_device(device: torch.device):
60
+ r"""
61
+ Clean memory on the specified device, will be called from training scripts.
62
+ """
63
+ gc.collect()
64
+
65
+ # device may be "cuda" or "cuda:0", so check device.type rather than comparing the full string
66
+ if device.type == "cuda":
67
+ torch.cuda.empty_cache()
68
+ if device.type == "xpu":
69
+ torch.xpu.empty_cache()
70
+ if device.type == "mps":
71
+ torch.mps.empty_cache()
72
+
73
+
74
+ # for collate_fn: epoch and step is multiprocessing.Value
75
+ class collator_class:
76
+ def __init__(self, epoch, step, dataset):
77
+ self.current_epoch = epoch
78
+ self.current_step = step
79
+ self.dataset = dataset # used only when worker_info is None, i.e. when the DataLoader has no worker processes
80
+
81
+ def __call__(self, examples):
82
+ worker_info = torch.utils.data.get_worker_info()
83
+ # worker_info is None in the main process
84
+ if worker_info is not None:
85
+ dataset = worker_info.dataset
86
+ else:
87
+ dataset = self.dataset
88
+
89
+ # set epoch and step
90
+ dataset.set_current_epoch(self.current_epoch.value)
91
+ dataset.set_current_step(self.current_step.value)
92
+ return examples[0]
93
+
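To see how the collator forwards the shared epoch/step counters into the dataset, here is a small, self-contained sketch; `DummyDataset` is a made-up stand-in exposing the two hooks the collator calls, while `collator_class` is the class above:

```python
from multiprocessing import Value
import torch

class DummyDataset(torch.utils.data.Dataset):
    """Stand-in with the two hooks the collator calls."""
    def __init__(self):
        self.epoch, self.step = 0, 0
    def set_current_epoch(self, e): self.epoch = e
    def set_current_step(self, s): self.step = s
    def __len__(self): return 4
    def __getitem__(self, idx): return {"idx": idx}  # the real dataset yields a full, pre-made batch per item

current_epoch, current_step = Value("i", 0), Value("i", 0)
dataset = DummyDataset()
collator = collator_class(current_epoch, current_step, dataset)

loader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=collator, num_workers=0)
current_epoch.value = 3            # updated by the training loop
batch = next(iter(loader))         # collator sets dataset epoch/step, then returns examples[0]
print(dataset.epoch, batch)        # -> 3 {'idx': 0}
```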
94
+
95
+ def prepare_accelerator(args: argparse.Namespace) -> Accelerator:
96
+ """
97
+ DeepSpeed is not supported in this script currently.
98
+ """
99
+ if args.logging_dir is None:
100
+ logging_dir = None
101
+ else:
102
+ log_prefix = "" if args.log_prefix is None else args.log_prefix
103
+ logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime())
104
+
105
+ if args.log_with is None:
106
+ if logging_dir is not None:
107
+ log_with = "tensorboard"
108
+ else:
109
+ log_with = None
110
+ else:
111
+ log_with = args.log_with
112
+ if log_with in ["tensorboard", "all"]:
113
+ if logging_dir is None:
114
+ raise ValueError(
115
+ "logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください"
116
+ )
117
+ if log_with in ["wandb", "all"]:
118
+ try:
119
+ import wandb
120
+ except ImportError:
121
+ raise ImportError("No wandb / wandb がインストールされていないようです")
122
+ if logging_dir is not None:
123
+ os.makedirs(logging_dir, exist_ok=True)
124
+ os.environ["WANDB_DIR"] = logging_dir
125
+ if args.wandb_api_key is not None:
126
+ wandb.login(key=args.wandb_api_key)
127
+
128
+ kwargs_handlers = [
129
+ (
130
+ InitProcessGroupKwargs(
131
+ backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
132
+ init_method=(
133
+ "env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None
134
+ ),
135
+ timeout=timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None,
136
+ )
137
+ if torch.cuda.device_count() > 1
138
+ else None
139
+ ),
140
+ (
141
+ DistributedDataParallelKwargs(
142
+ gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph
143
+ )
144
+ if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
145
+ else None
146
+ ),
147
+ ]
148
+ kwargs_handlers = [i for i in kwargs_handlers if i is not None]
149
+
150
+ accelerator = Accelerator(
151
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
152
+ mixed_precision=args.mixed_precision,
153
+ log_with=log_with,
154
+ project_dir=logging_dir,
155
+ kwargs_handlers=kwargs_handlers,
156
+ )
157
+ print("accelerator device:", accelerator.device)
158
+ return accelerator
159
+
160
+
161
+ def line_to_prompt_dict(line: str) -> dict:
162
+ # subset of gen_img_diffusers
163
+ prompt_args = line.split(" --")
164
+ prompt_dict = {}
165
+ prompt_dict["prompt"] = prompt_args[0]
166
+
167
+ for parg in prompt_args:
168
+ try:
169
+ m = re.match(r"w (\d+)", parg, re.IGNORECASE)
170
+ if m:
171
+ prompt_dict["width"] = int(m.group(1))
172
+ continue
173
+
174
+ m = re.match(r"h (\d+)", parg, re.IGNORECASE)
175
+ if m:
176
+ prompt_dict["height"] = int(m.group(1))
177
+ continue
178
+
179
+ m = re.match(r"f (\d+)", parg, re.IGNORECASE)
180
+ if m:
181
+ prompt_dict["frame_count"] = int(m.group(1))
182
+ continue
183
+
184
+ m = re.match(r"d (\d+)", parg, re.IGNORECASE)
185
+ if m:
186
+ prompt_dict["seed"] = int(m.group(1))
187
+ continue
188
+
189
+ m = re.match(r"s (\d+)", parg, re.IGNORECASE)
190
+ if m: # steps
191
+ prompt_dict["sample_steps"] = max(1, min(1000, int(m.group(1))))
192
+ continue
193
+
194
+ # m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
195
+ # if m: # scale
196
+ # prompt_dict["scale"] = float(m.group(1))
197
+ # continue
198
+ # m = re.match(r"n (.+)", parg, re.IGNORECASE)
199
+ # if m: # negative prompt
200
+ # prompt_dict["negative_prompt"] = m.group(1)
201
+ # continue
202
+
203
+ except ValueError as ex:
204
+ logger.error(f"Exception in parsing / 解析エラー: {parg}")
205
+ logger.error(ex)
206
+
207
+ return prompt_dict
208
+
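For example, a prompt line such as the following hypothetical one parses into a dict of generation settings:

```python
# hypothetical prompt line; flags follow the " --" separator used by the split above
line = "a pixel art cat walking --w 512 --h 320 --f 25 --d 42 --s 20"
print(line_to_prompt_dict(line))
# {'prompt': 'a pixel art cat walking', 'width': 512, 'height': 320,
#  'frame_count': 25, 'seed': 42, 'sample_steps': 20}
```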
209
+
210
+ def load_prompts(prompt_file: str) -> list[Dict]:
211
+ # read prompts
212
+ if prompt_file.endswith(".txt"):
213
+ with open(prompt_file, "r", encoding="utf-8") as f:
214
+ lines = f.readlines()
215
+ prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
216
+ elif prompt_file.endswith(".toml"):
217
+ with open(prompt_file, "r", encoding="utf-8") as f:
218
+ data = toml.load(f)
219
+ prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
220
+ elif prompt_file.endswith(".json"):
221
+ with open(prompt_file, "r", encoding="utf-8") as f:
222
+ prompts = json.load(f)
223
+
224
+ # preprocess prompts
225
+ for i in range(len(prompts)):
226
+ prompt_dict = prompts[i]
227
+ if isinstance(prompt_dict, str):
228
+ prompt_dict = line_to_prompt_dict(prompt_dict)
229
+ prompts[i] = prompt_dict
230
+ assert isinstance(prompt_dict, dict)
231
+
232
+ # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
233
+ prompt_dict["enum"] = i
234
+ prompt_dict.pop("subset", None)
235
+
236
+ return prompts
237
+
238
+
239
+ def compute_density_for_timestep_sampling(
240
+ weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
241
+ ):
242
+ """Compute the density for sampling the timesteps when doing SD3 training.
243
+
244
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
245
+
246
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
247
+ """
248
+ if weighting_scheme == "logit_normal":
249
+ # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
250
+ u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
251
+ u = torch.nn.functional.sigmoid(u)
252
+ elif weighting_scheme == "mode":
253
+ u = torch.rand(size=(batch_size,), device="cpu")
254
+ u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
255
+ else:
256
+ u = torch.rand(size=(batch_size,), device="cpu")
257
+ return u
258
+
259
+
260
+ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
261
+ sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
262
+ schedule_timesteps = noise_scheduler.timesteps.to(device)
263
+ timesteps = timesteps.to(device)
264
+
265
+ # if sum([(schedule_timesteps == t) for t in timesteps]) < len(timesteps):
266
+ if any([(schedule_timesteps == t).sum() == 0 for t in timesteps]):
267
+ # raise ValueError("Some timesteps are not in the schedule / 一部のtimestepsがスケジュールに含まれていません")
268
+ # round to nearest timestep
269
+ logger.warning("Some timesteps are not in the schedule / 一部のtimestepsがスケジュールに含まれていません")
270
+ step_indices = [torch.argmin(torch.abs(schedule_timesteps - t)).item() for t in timesteps]
271
+ else:
272
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
273
+
274
+ sigma = sigmas[step_indices].flatten()
275
+ while len(sigma.shape) < n_dim:
276
+ sigma = sigma.unsqueeze(-1)
277
+ return sigma
278
+
279
+
280
+ def compute_loss_weighting_for_sd3(weighting_scheme: str, noise_scheduler, timesteps, device, dtype):
281
+ """Computes loss weighting scheme for SD3 training.
282
+
283
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
284
+
285
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
286
+ """
287
+ if weighting_scheme == "sigma_sqrt" or weighting_scheme == "cosmap":
288
+ sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=5, dtype=dtype)
289
+ if weighting_scheme == "sigma_sqrt":
290
+ weighting = (sigmas**-2.0).float()
291
+ else:
292
+ bot = 1 - 2 * sigmas + 2 * sigmas**2
293
+ weighting = 2 / (math.pi * bot)
294
+ else:
295
+ weighting = None # torch.ones_like(sigmas)
296
+ return weighting
297
+
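A quick numeric illustration of these helpers, using the `compute_density_for_timestep_sampling` defined above and the `cosmap` formula written out by hand:

```python
import math

# logit-normal timestep density: u = sigmoid(N(logit_mean, logit_std))
u = compute_density_for_timestep_sampling(
    weighting_scheme="logit_normal", batch_size=4, logit_mean=0.0, logit_std=1.0
)
print(u)  # four values in (0, 1), concentrated around 0.5

# "cosmap" loss weighting as computed above, evaluated for a single sigma
sigma = 0.3
weight = 2 / (math.pi * (1 - 2 * sigma + 2 * sigma ** 2))
print(weight)  # ~1.10
```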
298
+
299
+ class FineTuningTrainer:
300
+ def __init__(self):
301
+ pass
302
+
303
+ def process_sample_prompts(
304
+ self,
305
+ args: argparse.Namespace,
306
+ accelerator: Accelerator,
307
+ sample_prompts: str,
308
+ text_encoder1: str,
309
+ text_encoder2: str,
310
+ fp8_llm: bool,
311
+ ):
312
+ logger.info(f"cache Text Encoder outputs for sample prompt: {sample_prompts}")
313
+ prompts = load_prompts(sample_prompts)
314
+
315
+ def encode_for_text_encoder(text_encoder, is_llm=True):
316
+ sample_prompts_te_outputs = {} # (prompt) -> (embeds, mask)
317
+ with accelerator.autocast(), torch.no_grad():
318
+ for prompt_dict in prompts:
319
+ for p in [prompt_dict.get("prompt", "")]:
320
+ if p not in sample_prompts_te_outputs:
321
+ logger.info(f"cache Text Encoder outputs for prompt: {p}")
322
+
323
+ data_type = "video"
324
+ text_inputs = text_encoder.text2tokens(p, data_type=data_type)
325
+
326
+ prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type)
327
+ sample_prompts_te_outputs[p] = (prompt_outputs.hidden_state, prompt_outputs.attention_mask)
328
+
329
+ return sample_prompts_te_outputs
330
+
331
+ # Load Text Encoder 1 and encode
332
+ text_encoder_dtype = torch.float16 if args.text_encoder_dtype is None else model_utils.str_to_dtype(args.text_encoder_dtype)
333
+ logger.info(f"loading text encoder 1: {text_encoder1}")
334
+ text_encoder_1 = text_encoder_module.load_text_encoder_1(text_encoder1, accelerator.device, fp8_llm, text_encoder_dtype)
335
+
336
+ logger.info("encoding with Text Encoder 1")
337
+ te_outputs_1 = encode_for_text_encoder(text_encoder_1)
338
+ del text_encoder_1
339
+
340
+ # Load Text Encoder 2 and encode
341
+ logger.info(f"loading text encoder 2: {text_encoder2}")
342
+ text_encoder_2 = text_encoder_module.load_text_encoder_2(text_encoder2, accelerator.device, text_encoder_dtype)
343
+
344
+ logger.info("encoding with Text Encoder 2")
345
+ te_outputs_2 = encode_for_text_encoder(text_encoder_2, is_llm=False)
346
+ del text_encoder_2
347
+
348
+ # prepare sample parameters
349
+ sample_parameters = []
350
+ for prompt_dict in prompts:
351
+ prompt_dict_copy = prompt_dict.copy()
352
+ p = prompt_dict.get("prompt", "")
353
+ prompt_dict_copy["llm_embeds"] = te_outputs_1[p][0]
354
+ prompt_dict_copy["llm_mask"] = te_outputs_1[p][1]
355
+ prompt_dict_copy["clipL_embeds"] = te_outputs_2[p][0]
356
+ prompt_dict_copy["clipL_mask"] = te_outputs_2[p][1]
357
+ sample_parameters.append(prompt_dict_copy)
358
+
359
+ clean_memory_on_device(accelerator.device)
360
+
361
+ return sample_parameters
362
+
363
+ def get_optimizer(self, args, trainable_params: list[torch.nn.Parameter]) -> tuple[str, str, torch.optim.Optimizer, Any, Any]:
364
+ # adamw, adamw8bit, adafactor
365
+
366
+ optimizer_type = args.optimizer_type.lower()
367
+
368
+ # split optimizer_type and optimizer_args
369
+ optimizer_kwargs = {}
370
+ if args.optimizer_args is not None and len(args.optimizer_args) > 0:
371
+ for arg in args.optimizer_args:
372
+ key, value = arg.split("=")
373
+ value = ast.literal_eval(value)
374
+ optimizer_kwargs[key] = value
375
+
376
+ lr = args.learning_rate
377
+ optimizer = None
378
+ optimizer_class = None
379
+
380
+ if optimizer_type.endswith("8bit".lower()):
381
+ try:
382
+ import bitsandbytes as bnb
383
+ except ImportError:
384
+ raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
385
+
386
+ if optimizer_type == "AdamW8bit".lower():
387
+ logger.info(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
388
+ optimizer_class = bnb.optim.AdamW8bit
389
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
390
+
391
+ elif optimizer_type == "Adafactor".lower():
392
+ # Adafactor: check relative_step and warmup_init
393
+ if "relative_step" not in optimizer_kwargs:
394
+ optimizer_kwargs["relative_step"] = True # default
395
+ if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False):
396
+ logger.info(
397
+ f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします"
398
+ )
399
+ optimizer_kwargs["relative_step"] = True
400
+ logger.info(f"use Adafactor optimizer | {optimizer_kwargs}")
401
+
402
+ if optimizer_kwargs["relative_step"]:
403
+ logger.info(f"relative_step is true / relative_stepがtrueです")
404
+ if lr != 0.0:
405
+ logger.warning(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます")
406
+ args.learning_rate = None
407
+
408
+ if args.lr_scheduler != "adafactor":
409
+ logger.info(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します")
410
+ args.lr_scheduler = f"adafactor:{lr}" # a bit hacky, but works
411
+
412
+ lr = None
413
+ else:
414
+ if args.max_grad_norm != 0.0:
415
+ logger.warning(
416
+ f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません"
417
+ )
418
+ if args.lr_scheduler != "constant_with_warmup":
419
+ logger.warning(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません")
420
+ if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0:
421
+ logger.warning(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません")
422
+
423
+ optimizer_class = transformers.optimization.Adafactor
424
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
425
+
426
+ elif optimizer_type == "AdamW".lower():
427
+ logger.info(f"use AdamW optimizer | {optimizer_kwargs}")
428
+ optimizer_class = torch.optim.AdamW
429
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
430
+
431
+ if optimizer is None:
432
+ # use an arbitrary optimizer specified by its (possibly dotted) class name
433
+ case_sensitive_optimizer_type = args.optimizer_type # not lower
434
+ logger.info(f"use {case_sensitive_optimizer_type} | {optimizer_kwargs}")
435
+
436
+ if "." not in case_sensitive_optimizer_type: # from torch.optim
437
+ optimizer_module = torch.optim
438
+ else: # from other library
439
+ values = case_sensitive_optimizer_type.split(".")
440
+ optimizer_module = importlib.import_module(".".join(values[:-1]))
441
+ case_sensitive_optimizer_type = values[-1]
442
+
443
+ optimizer_class = getattr(optimizer_module, case_sensitive_optimizer_type)
444
+ optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
445
+
446
+ # for logging
447
+ optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
448
+ optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
449
+
450
+ # get train and eval functions
451
+ if hasattr(optimizer, "train") and callable(optimizer.train):
452
+ train_fn = optimizer.train
453
+ eval_fn = optimizer.eval
454
+ else:
455
+ train_fn = lambda: None
456
+ eval_fn = lambda: None
457
+
458
+ return optimizer_name, optimizer_args, optimizer, train_fn, eval_fn
459
+
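Because each `--optimizer_args` entry is split on `=` and passed through `ast.literal_eval`, plain Python literals can be supplied on the command line; for example (hypothetical values):

```python
# e.g. --optimizer_type AdamW8bit --optimizer_args weight_decay=0.01 "betas=(0.9, 0.99)"
import ast

optimizer_args = ["weight_decay=0.01", "betas=(0.9, 0.99)"]
optimizer_kwargs = {}
for arg in optimizer_args:
    key, value = arg.split("=")
    optimizer_kwargs[key] = ast.literal_eval(value)
print(optimizer_kwargs)  # {'weight_decay': 0.01, 'betas': (0.9, 0.99)}
```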
460
+ def is_schedulefree_optimizer(self, optimizer: torch.optim.Optimizer, args: argparse.Namespace) -> bool:
461
+ return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper
462
+
463
+ def get_dummy_scheduler(self, optimizer: torch.optim.Optimizer) -> Any:
464
+ # dummy scheduler for schedulefree optimizers. supports only a no-op step() and get_last_lr().
465
+ # this scheduler is used for logging only.
466
+ # this is not wrapped by accelerator because this class is not a subclass of torch.optim.lr_scheduler._LRScheduler
467
+ class DummyScheduler:
468
+ def __init__(self, optimizer: torch.optim.Optimizer):
469
+ self.optimizer = optimizer
470
+
471
+ def step(self):
472
+ pass
473
+
474
+ def get_last_lr(self):
475
+ return [group["lr"] for group in self.optimizer.param_groups]
476
+
477
+ return DummyScheduler(optimizer)
478
+
479
+ def get_scheduler(self, args, optimizer: torch.optim.Optimizer, num_processes: int):
480
+ """
481
+ Unified API to get any scheduler from its name.
482
+ """
483
+ # if schedulefree optimizer, return dummy scheduler
484
+ if self.is_schedulefree_optimizer(optimizer, args):
485
+ return self.get_dummy_scheduler(optimizer)
486
+
487
+ name = args.lr_scheduler
488
+ num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps
489
+ num_warmup_steps: Optional[int] = (
490
+ int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps
491
+ )
492
+ num_decay_steps: Optional[int] = (
493
+ int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps
494
+ )
495
+ num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
496
+ num_cycles = args.lr_scheduler_num_cycles
497
+ power = args.lr_scheduler_power
498
+ timescale = args.lr_scheduler_timescale
499
+ min_lr_ratio = args.lr_scheduler_min_lr_ratio
500
+
501
+ lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs
502
+ if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
503
+ for arg in args.lr_scheduler_args:
504
+ key, value = arg.split("=")
505
+ value = ast.literal_eval(value)
506
+ lr_scheduler_kwargs[key] = value
507
+
508
+ def wrap_check_needless_num_warmup_steps(return_vals):
509
+ if num_warmup_steps is not None and num_warmup_steps != 0:
510
+ raise ValueError(f"{name} does not require `num_warmup_steps`. Set None or 0.")
511
+ return return_vals
512
+
513
+ # using any lr_scheduler from other library
514
+ if args.lr_scheduler_type:
515
+ lr_scheduler_type = args.lr_scheduler_type
516
+ logger.info(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler")
517
+ if "." not in lr_scheduler_type: # default to use torch.optim
518
+ lr_scheduler_module = torch.optim.lr_scheduler
519
+ else:
520
+ values = lr_scheduler_type.split(".")
521
+ lr_scheduler_module = importlib.import_module(".".join(values[:-1]))
522
+ lr_scheduler_type = values[-1]
523
+ lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type)
524
+ lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
525
+ return lr_scheduler
526
+
527
+ if name.startswith("adafactor"):
528
+ assert (
529
+ type(optimizer) == transformers.optimization.Adafactor
530
+ ), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
531
+ initial_lr = float(name.split(":")[1])
532
+ # logger.info(f"adafactor scheduler init lr {initial_lr}")
533
+ return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))
534
+
535
+ if name == DiffusersSchedulerType.PIECEWISE_CONSTANT.value:
536
+ name = DiffusersSchedulerType(name)
537
+ schedule_func = DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name]
538
+ return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs
539
+
540
+ name = SchedulerType(name)
541
+ schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
542
+
543
+ if name == SchedulerType.CONSTANT:
544
+ return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))
545
+
546
+ # All other schedulers require `num_warmup_steps`
547
+ if num_warmup_steps is None:
548
+ raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
549
+
550
+ if name == SchedulerType.CONSTANT_WITH_WARMUP:
551
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs)
552
+
553
+ if name == SchedulerType.INVERSE_SQRT:
554
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, timescale=timescale, **lr_scheduler_kwargs)
555
+
556
+ # All other schedulers require `num_training_steps`
557
+ if num_training_steps is None:
558
+ raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
559
+
560
+ if name == SchedulerType.COSINE_WITH_RESTARTS:
561
+ return schedule_func(
562
+ optimizer,
563
+ num_warmup_steps=num_warmup_steps,
564
+ num_training_steps=num_training_steps,
565
+ num_cycles=num_cycles,
566
+ **lr_scheduler_kwargs,
567
+ )
568
+
569
+ if name == SchedulerType.POLYNOMIAL:
570
+ return schedule_func(
571
+ optimizer,
572
+ num_warmup_steps=num_warmup_steps,
573
+ num_training_steps=num_training_steps,
574
+ power=power,
575
+ **lr_scheduler_kwargs,
576
+ )
577
+
578
+ if name == SchedulerType.COSINE_WITH_MIN_LR:
579
+ return schedule_func(
580
+ optimizer,
581
+ num_warmup_steps=num_warmup_steps,
582
+ num_training_steps=num_training_steps,
583
+ num_cycles=num_cycles / 2,
584
+ min_lr_rate=min_lr_ratio,
585
+ **lr_scheduler_kwargs,
586
+ )
587
+
588
+ # these schedulers do not require `num_decay_steps`
589
+ if name == SchedulerType.LINEAR or name == SchedulerType.COSINE:
590
+ return schedule_func(
591
+ optimizer,
592
+ num_warmup_steps=num_warmup_steps,
593
+ num_training_steps=num_training_steps,
594
+ **lr_scheduler_kwargs,
595
+ )
596
+
597
+ # All other schedulers require `num_decay_steps`
598
+ if num_decay_steps is None:
599
+ raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.")
600
+ if name == SchedulerType.WARMUP_STABLE_DECAY:
601
+ return schedule_func(
602
+ optimizer,
603
+ num_warmup_steps=num_warmup_steps,
604
+ num_stable_steps=num_stable_steps,
605
+ num_decay_steps=num_decay_steps,
606
+ num_cycles=num_cycles / 2,
607
+ min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0,
608
+ **lr_scheduler_kwargs,
609
+ )
610
+
611
+ return schedule_func(
612
+ optimizer,
613
+ num_warmup_steps=num_warmup_steps,
614
+ num_training_steps=num_training_steps,
615
+ num_decay_steps=num_decay_steps,
616
+ **lr_scheduler_kwargs,
617
+ )
618
+
619
+ def resume_from_local_or_hf_if_specified(self, accelerator: Accelerator, args: argparse.Namespace) -> bool:
620
+ if not args.resume:
621
+ return False
622
+
623
+ if not args.resume_from_huggingface:
624
+ logger.info(f"resume training from local state: {args.resume}")
625
+ accelerator.load_state(args.resume)
626
+ return True
627
+
628
+ logger.info(f"resume training from huggingface state: {args.resume}")
629
+ repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1]
630
+ path_in_repo = "/".join(args.resume.split("/")[2:])
631
+ revision = None
632
+ repo_type = None
633
+ if ":" in path_in_repo:
634
+ divided = path_in_repo.split(":")
635
+ if len(divided) == 2:
636
+ path_in_repo, revision = divided
637
+ repo_type = "model"
638
+ else:
639
+ path_in_repo, revision, repo_type = divided
640
+ logger.info(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}")
641
+
642
+ list_files = huggingface_utils.list_dir(
643
+ repo_id=repo_id,
644
+ subfolder=path_in_repo,
645
+ revision=revision,
646
+ token=args.huggingface_token,
647
+ repo_type=repo_type,
648
+ )
649
+
650
+ async def download(filename) -> str:
651
+ def task():
652
+ return huggingface_hub.hf_hub_download(
653
+ repo_id=repo_id,
654
+ filename=filename,
655
+ revision=revision,
656
+ repo_type=repo_type,
657
+ token=args.huggingface_token,
658
+ )
659
+
660
+ return await asyncio.get_event_loop().run_in_executor(None, task)
661
+
662
+ loop = asyncio.get_event_loop()
663
+ results = loop.run_until_complete(asyncio.gather(*[download(filename=filename.rfilename) for filename in list_files]))
664
+ if len(results) == 0:
665
+ raise ValueError(
666
+ "No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした"
667
+ )
668
+ dirname = os.path.dirname(results[0])
669
+ accelerator.load_state(dirname)
670
+
671
+ return True
672
+
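When `--resume_from_huggingface` is used, `--resume` is interpreted as `owner/repo/path-in-repo[:revision[:repo_type]]`, matching the split logic above. A sketch with a hypothetical repository and path:

```python
# hypothetical value: --resume user/my-training-states/last-state:main:model
resume = "user/my-training-states/last-state:main:model"

repo_id = "/".join(resume.split("/")[:2])        # "user/my-training-states"
path_in_repo = "/".join(resume.split("/")[2:])   # "last-state:main:model"
path_in_repo, revision, repo_type = path_in_repo.split(":")
print(repo_id, path_in_repo, revision, repo_type)
# user/my-training-states last-state main model
```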
673
+ def sample_images(self, accelerator, args, epoch, global_step, device, vae, transformer, sample_parameters):
674
+ pass
675
+
676
+ def get_noisy_model_input_and_timesteps(
677
+ self,
678
+ args: argparse.Namespace,
679
+ noise: torch.Tensor,
680
+ latents: torch.Tensor,
681
+ noise_scheduler: FlowMatchDiscreteScheduler,
682
+ device: torch.device,
683
+ dtype: torch.dtype,
684
+ ):
685
+ batch_size = noise.shape[0]
686
+
687
+ if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid" or args.timestep_sampling == "shift":
688
+ if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
689
+ # Simple random t-based noise sampling
690
+ if args.timestep_sampling == "sigmoid":
691
+ t = torch.sigmoid(args.sigmoid_scale * torch.randn((batch_size,), device=device))
692
+ else:
693
+ t = torch.rand((batch_size,), device=device)
694
+
695
+ elif args.timestep_sampling == "shift":
696
+ shift = args.discrete_flow_shift
697
+ logits_norm = torch.randn(batch_size, device=device)
698
+ logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
699
+ t = logits_norm.sigmoid()
700
+ t = (t * shift) / (1 + (shift - 1) * t)
701
+
702
+ t_min = args.min_timestep if args.min_timestep is not None else 0
703
+ t_max = args.max_timestep if args.max_timestep is not None else 1000.0
704
+ t_min /= 1000.0
705
+ t_max /= 1000.0
706
+ t = t * (t_max - t_min) + t_min # scale to [t_min, t_max], default [0, 1]
707
+
708
+ timesteps = t * 1000.0
709
+ t = t.view(-1, 1, 1, 1, 1)
710
+ noisy_model_input = (1 - t) * latents + t * noise
711
+
712
+ timesteps += 1 # 1 to 1000
713
+ else:
714
+ # Sample a random timestep for each image
715
+ # for weighting schemes where we sample timesteps non-uniformly
716
+ u = compute_density_for_timestep_sampling(
717
+ weighting_scheme=args.weighting_scheme,
718
+ batch_size=batch_size,
719
+ logit_mean=args.logit_mean,
720
+ logit_std=args.logit_std,
721
+ mode_scale=args.mode_scale,
722
+ )
723
+ # indices = (u * noise_scheduler.config.num_train_timesteps).long()
724
+ t_min = args.min_timestep if args.min_timestep is not None else 0
725
+ t_max = args.max_timestep if args.max_timestep is not None else 1000
726
+ indices = (u * (t_max - t_min) + t_min).long()
727
+
728
+ timesteps = noise_scheduler.timesteps[indices].to(device=device) # 1 to 1000
729
+
730
+ # Add noise according to flow matching.
731
+ sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
732
+ noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
733
+
734
+ return noisy_model_input, timesteps
735
+
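For the `shift` branch, the transform `t' = (s * t) / (1 + (s - 1) * t)` skews the sigmoid samples toward t = 1 (pure noise in this script's convention) whenever the shift s is greater than 1; a quick numeric check:

```python
import torch

shift = 7.0                       # e.g. --discrete_flow_shift 7.0 (hypothetical value)
t = torch.tensor([0.1, 0.5, 0.9])
t_shifted = (t * shift) / (1 + (shift - 1) * t)
print(t_shifted)                  # tensor([0.4375, 0.8750, 0.9844]) -> skewed toward t = 1
```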
736
+ def train(self, args):
737
+ if args.seed is None:
738
+ args.seed = random.randint(0, 2**32)
739
+ set_seed(args.seed)
740
+
741
+ # Load dataset config
742
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer())
743
+ logger.info(f"Load dataset config from {args.dataset_config}")
744
+ user_config = config_utils.load_user_config(args.dataset_config)
745
+ blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_HUNYUAN_VIDEO)
746
+ train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group, training=True)
747
+
748
+ current_epoch = Value("i", 0)
749
+ current_step = Value("i", 0)
750
+ ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
751
+ collator = collator_class(current_epoch, current_step, ds_for_collator)
752
+
753
+ # prepare accelerator
754
+ logger.info("preparing accelerator")
755
+ accelerator = prepare_accelerator(args)
756
+ is_main_process = accelerator.is_main_process
757
+
758
+ # prepare dtype
759
+ weight_dtype = torch.float32
760
+ if args.mixed_precision == "fp16":
761
+ weight_dtype = torch.float16
762
+ elif args.mixed_precision == "bf16":
763
+ weight_dtype = torch.bfloat16
764
+
765
+ # HunyuanVideo specific
766
+ vae_dtype = torch.float16 if args.vae_dtype is None else model_utils.str_to_dtype(args.vae_dtype)
767
+
768
+ # get embedding for sampling images
769
+ sample_parameters = vae = None
770
+ if args.sample_prompts:
771
+ sample_parameters = self.process_sample_prompts(
772
+ args, accelerator, args.sample_prompts, args.text_encoder1, args.text_encoder2, args.fp8_llm
773
+ )
774
+
775
+ # Load VAE model for sampling images: VAE is loaded to cpu to save gpu memory
776
+ vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device="cpu", vae_path=args.vae)
777
+ vae.requires_grad_(False)
778
+ vae.eval()
779
+
780
+ if args.vae_chunk_size is not None:
781
+ vae.set_chunk_size_for_causal_conv_3d(args.vae_chunk_size)
782
+ logger.info(f"Set chunk_size to {args.vae_chunk_size} for CausalConv3d in VAE")
783
+ if args.vae_spatial_tile_sample_min_size is not None:
784
+ vae.enable_spatial_tiling(True)
785
+ vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size
786
+ vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8
787
+ elif args.vae_tiling:
788
+ vae.enable_spatial_tiling(True)
789
+
790
+ # load DiT model
791
+ blocks_to_swap = args.blocks_to_swap if args.blocks_to_swap else 0
792
+ loading_device = "cpu" if blocks_to_swap > 0 else accelerator.device
793
+
794
+ logger.info(f"Loading DiT model from {args.dit}")
795
+ if args.sdpa:
796
+ attn_mode = "torch"
797
+ elif args.flash_attn:
798
+ attn_mode = "flash"
799
+ elif args.sage_attn:
800
+ attn_mode = "sageattn"
801
+ elif args.xformers:
802
+ attn_mode = "xformers"
803
+ else:
804
+ raise ValueError(
805
+ f"either --sdpa, --flash-attn, --sage-attn or --xformers must be specified / --sdpa, --flash-attn, --sage-attn, --xformersのいずれかを指定してください"
806
+ )
807
+ transformer = load_transformer(
808
+ args.dit, attn_mode, args.split_attn, loading_device, None, in_channels=args.dit_in_channels
809
+ ) # load as is
810
+
811
+ if blocks_to_swap > 0:
812
+ logger.info(f"enable swap {blocks_to_swap} blocks to CPU from device: {accelerator.device}")
813
+ transformer.enable_block_swap(blocks_to_swap, accelerator.device, supports_backward=True)
814
+ transformer.move_to_device_except_swap_blocks(accelerator.device)
815
+ if args.img_in_txt_in_offloading:
816
+ logger.info("Enable offloading img_in and txt_in to CPU")
817
+ transformer.enable_img_in_txt_in_offloading()
818
+
819
+ if args.gradient_checkpointing:
820
+ transformer.enable_gradient_checkpointing()
821
+
822
+ # prepare optimizer, data loader etc.
823
+ accelerator.print("prepare optimizer, data loader etc.")
824
+
825
+ transformer.requires_grad_(False)
826
+ if accelerator.is_main_process:
827
+ accelerator.print(f"Trainable modules '{args.trainable_modules}'.")
828
+ for name, param in transformer.named_parameters():
829
+ for trainable_module_name in args.trainable_modules:
830
+ if trainable_module_name in name:
831
+ param.requires_grad = True
832
+ break
833
+
834
+ total_params = list(transformer.parameters())
835
+ trainable_params = list(filter(lambda p: p.requires_grad, transformer.parameters()))
836
+ logger.info(
837
+ f"number of trainable parameters: {sum(p.numel() for p in trainable_params) / 1e6} M, total parameters: {sum(p.numel() for p in total_params) / 1e6} M"
838
+ )
839
+ optimizer_name, optimizer_args, optimizer, optimizer_train_fn, optimizer_eval_fn = self.get_optimizer(
840
+ args, trainable_params
841
+ )
842
+
843
+ # prepare dataloader
844
+
845
+ # num workers for data loader: if 0, persistent_workers is not available
846
+ n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
847
+
848
+ train_dataloader = torch.utils.data.DataLoader(
849
+ train_dataset_group,
850
+ batch_size=1,
851
+ shuffle=True,
852
+ collate_fn=collator,
853
+ num_workers=n_workers,
854
+ persistent_workers=args.persistent_data_loader_workers,
855
+ )
856
+
857
+ # calculate max_train_steps
858
+ if args.max_train_epochs is not None:
859
+ args.max_train_steps = args.max_train_epochs * math.ceil(
860
+ len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
861
+ )
862
+ accelerator.print(
863
+ f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
864
+ )
865
+
866
+ # send max_train_steps to train_dataset_group
867
+ train_dataset_group.set_max_train_steps(args.max_train_steps)
868
+
869
+ # prepare lr_scheduler
870
+ lr_scheduler = self.get_scheduler(args, optimizer, accelerator.num_processes)
871
+
872
+ # prepare training model. accelerator does some magic here
873
+
874
+ # experimental feature: train the model with gradients in fp16/bf16
875
+ dit_dtype = torch.float32
876
+ if args.full_fp16:
877
+ assert (
878
+ args.mixed_precision == "fp16"
879
+ ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
880
+ accelerator.print("enable full fp16 training.")
881
+ dit_weight_dtype = torch.float16
882
+ elif args.full_bf16:
883
+ assert (
884
+ args.mixed_precision == "bf16"
885
+ ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
886
+ accelerator.print("enable full bf16 training.")
887
+ dit_weight_dtype = torch.bfloat16
888
+ else:
889
+ dit_weight_dtype = torch.float32
890
+
891
+ # TODO add fused optimizer and stochastic rounding
892
+
893
+ # cast model to dit_weight_dtype
894
+ # if dit_dtype != dit_weight_dtype:
895
+ logger.info(f"casting model to {dit_weight_dtype}")
896
+ transformer.to(dit_weight_dtype)
897
+
898
+ if blocks_to_swap > 0:
899
+ transformer = accelerator.prepare(transformer, device_placement=[not blocks_to_swap > 0])
900
+ accelerator.unwrap_model(transformer).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
901
+ accelerator.unwrap_model(transformer).prepare_block_swap_before_forward()
902
+ else:
903
+ transformer = accelerator.prepare(transformer)
904
+
905
+ optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
906
+
907
+ transformer.train()
908
+
909
+ if args.full_fp16:
910
+ # patch accelerator for fp16 training
911
+ # def patch_accelerator_for_fp16_training(accelerator):
912
+ org_unscale_grads = accelerator.scaler._unscale_grads_
913
+
914
+ def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
915
+ return org_unscale_grads(optimizer, inv_scale, found_inf, True)
916
+
917
+ accelerator.scaler._unscale_grads_ = _unscale_grads_replacer
918
+
919
+ # resume from local or huggingface. accelerator.step is set
920
+ self.resume_from_local_or_hf_if_specified(accelerator, args) # accelerator.load_state(args.resume)
921
+
922
+ # calculate the number of epochs
923
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
924
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
925
+
926
+ # start training
927
+ # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
928
+
929
+ accelerator.print("running training / 学習開始")
930
+ accelerator.print(f" num train items / 学習画像、動画数: {train_dataset_group.num_train_items}")
931
+ accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
932
+ accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
933
+ accelerator.print(
934
+ f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
935
+ )
936
+ # accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
937
+ accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
938
+ accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
939
+
940
+ if accelerator.is_main_process:
941
+ init_kwargs = {}
942
+ if args.wandb_run_name:
943
+ init_kwargs["wandb"] = {"name": args.wandb_run_name}
944
+ if args.log_tracker_config is not None:
945
+ init_kwargs = toml.load(args.log_tracker_config)
946
+ accelerator.init_trackers(
947
+ "hunyuan_video_ft" if args.log_tracker_name is None else args.log_tracker_name,
948
+ config=train_utils.get_sanitized_config_or_none(args),
949
+ init_kwargs=init_kwargs,
950
+ )
951
+
952
+ # TODO skip until initial step
953
+ progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
954
+
955
+ epoch_to_start = 0
956
+ global_step = 0
957
+ noise_scheduler = FlowMatchDiscreteScheduler(shift=args.discrete_flow_shift, reverse=True, solver="euler")
958
+
959
+ loss_recorder = train_utils.LossRecorder()
960
+ del train_dataset_group
961
+
962
+ # function for saving/removing
963
+ def save_model(ckpt_name: str, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
964
+ os.makedirs(args.output_dir, exist_ok=True)
965
+ ckpt_file = os.path.join(args.output_dir, ckpt_name)
966
+
967
+ accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
968
+
969
+ title = args.metadata_title if args.metadata_title is not None else args.output_name
970
+ if args.min_timestep is not None or args.max_timestep is not None:
971
+ min_time_step = args.min_timestep if args.min_timestep is not None else 0
972
+ max_time_step = args.max_timestep if args.max_timestep is not None else 1000
973
+ md_timesteps = (min_time_step, max_time_step)
974
+ else:
975
+ md_timesteps = None
976
+
977
+ sai_metadata = sai_model_spec.build_metadata(
978
+ None,
979
+ ARCHITECTURE_HUNYUAN_VIDEO,
980
+ time.time(),
981
+ title,
982
+ None,
983
+ args.metadata_author,
984
+ args.metadata_description,
985
+ args.metadata_license,
986
+ args.metadata_tags,
987
+ timesteps=md_timesteps,
988
+ is_lora=False,
989
+ )
990
+
991
+ save_file(unwrapped_nw.state_dict(), ckpt_file, sai_metadata)
992
+ if args.huggingface_repo_id is not None:
993
+ huggingface_utils.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
994
+
995
+ def remove_model(old_ckpt_name):
996
+ old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
997
+ if os.path.exists(old_ckpt_file):
998
+ accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
999
+ os.remove(old_ckpt_file)
1000
+
1001
+ # For --sample_at_first
1002
+ optimizer_eval_fn()
1003
+ self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, transformer, sample_parameters)
1004
+ optimizer_train_fn()
1005
+ if len(accelerator.trackers) > 0:
1006
+ # log empty object to commit the sample images to wandb
1007
+ accelerator.log({}, step=0)
1008
+
1009
+ # training loop
1010
+
1011
+ # log device and dtype for each model
1012
+ logger.info(f"DiT dtype: {transformer.dtype}, device: {transformer.device}")
1013
+
1014
+ clean_memory_on_device(accelerator.device)
1015
+
1016
+ pos_embed_cache = {}
1017
+
1018
+ for epoch in range(epoch_to_start, num_train_epochs):
1019
+ accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
1020
+ current_epoch.value = epoch + 1
1021
+
1022
+ for step, batch in enumerate(train_dataloader):
1023
+ latents, llm_embeds, llm_mask, clip_embeds = batch
1024
+ bsz = latents.shape[0]
1025
+ current_step.value = global_step
1026
+
1027
+ with accelerator.accumulate(transformer):
1028
+ latents = latents * vae_module.SCALING_FACTOR
1029
+
1030
+ # Sample noise that we'll add to the latents
1031
+ noise = torch.randn_like(latents)
1032
+
1033
+ # calculate model input and timesteps
1034
+ noisy_model_input, timesteps = self.get_noisy_model_input_and_timesteps(
1035
+ args, noise, latents, noise_scheduler, accelerator.device, dit_dtype
1036
+ )
1037
+
1038
+ weighting = compute_loss_weighting_for_sd3(
1039
+ args.weighting_scheme, noise_scheduler, timesteps, accelerator.device, dit_dtype
1040
+ )
1041
+
1042
+ # ensure guidance_scale in args is float
1043
+ guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) # , dtype=dit_dtype)
1044
+
1045
+ # ensure the hidden state will require grad
1046
+ if args.gradient_checkpointing:
1047
+ noisy_model_input.requires_grad_(True)
1048
+ guidance_vec.requires_grad_(True)
1049
+
1050
+ pos_emb_shape = latents.shape[1:]
1051
+ if pos_emb_shape not in pos_embed_cache:
1052
+ freqs_cos, freqs_sin = get_rotary_pos_embed_by_shape(
1053
+ accelerator.unwrap_model(transformer), latents.shape[2:]
1054
+ )
1055
+ # freqs_cos = freqs_cos.to(device=accelerator.device, dtype=dit_dtype)
1056
+ # freqs_sin = freqs_sin.to(device=accelerator.device, dtype=dit_dtype)
1057
+ pos_embed_cache[pos_emb_shape] = (freqs_cos, freqs_sin)
1058
+ else:
1059
+ freqs_cos, freqs_sin = pos_embed_cache[pos_emb_shape]
1060
+
1061
+ # call DiT
1062
+ latents = latents.to(device=accelerator.device, dtype=dit_dtype)
1063
+ noisy_model_input = noisy_model_input.to(device=accelerator.device, dtype=dit_dtype)
1064
+ # timesteps = timesteps.to(device=accelerator.device, dtype=dit_dtype)
1065
+ # llm_embeds = llm_embeds.to(device=accelerator.device, dtype=dit_dtype)
1066
+ # llm_mask = llm_mask.to(device=accelerator.device)
1067
+ # clip_embeds = clip_embeds.to(device=accelerator.device, dtype=dit_dtype)
1068
+ with accelerator.autocast():
1069
+ model_pred = transformer(
1070
+ noisy_model_input,
1071
+ timesteps,
1072
+ text_states=llm_embeds,
1073
+ text_mask=llm_mask,
1074
+ text_states_2=clip_embeds,
1075
+ freqs_cos=freqs_cos,
1076
+ freqs_sin=freqs_sin,
1077
+ guidance=guidance_vec,
1078
+ return_dict=False,
1079
+ )
1080
+
1081
+ # flow matching loss
1082
+ target = noise - latents
1083
+
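This target follows from the linear interpolation used in `get_noisy_model_input_and_timesteps` above: with x_t = (1 - t) * x_0 + t * noise, the velocity to regress is

```latex
\frac{\mathrm{d}x_t}{\mathrm{d}t} = \varepsilon - x_0 ,
```

which is exactly `noise - latents`.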
1084
+ loss = torch.nn.functional.mse_loss(model_pred.to(dit_dtype), target, reduction="none")
1085
+
1086
+ if weighting is not None:
1087
+ loss = loss * weighting
1088
+ # loss = loss.mean([1, 2, 3])
1089
+ # # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc.
1090
+ # loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
1091
+
1092
+ loss = loss.mean() # already a mean over all elements, so there is no need to divide by batch_size
1093
+
1094
+ accelerator.backward(loss)
1095
+ if accelerator.sync_gradients:
1096
+ # self.all_reduce_network(accelerator, network) # sync DDP grad manually
1097
+ state = accelerate.PartialState()
1098
+ if state.distributed_type != accelerate.DistributedType.NO:
1099
+ for param in transformer.parameters():
1100
+ if param.grad is not None:
1101
+ param.grad = accelerator.reduce(param.grad, reduction="mean")
1102
+
1103
+ if args.max_grad_norm != 0.0:
1104
+ params_to_clip = accelerator.unwrap_model(transformer).parameters()
1105
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1106
+
1107
+ optimizer.step()
1108
+ lr_scheduler.step()
1109
+ optimizer.zero_grad(set_to_none=True)
1110
+
1111
+ # Checks if the accelerator has performed an optimization step behind the scenes
1112
+ if accelerator.sync_gradients:
1113
+ progress_bar.update(1)
1114
+ global_step += 1
1115
+
1116
+ optimizer_eval_fn()
1117
+ self.sample_images(
1118
+ accelerator, args, None, global_step, accelerator.device, vae, transformer, sample_parameters
1119
+ )
1120
+
1121
+ # save the model every specified number of steps
1122
+ if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
1123
+ accelerator.wait_for_everyone()
1124
+ if accelerator.is_main_process:
1125
+ ckpt_name = train_utils.get_step_ckpt_name(args.output_name, global_step)
1126
+ save_model(ckpt_name, accelerator.unwrap_model(transformer), global_step, epoch)
1127
+
1128
+ if args.save_state:
1129
+ train_utils.save_and_remove_state_stepwise(args, accelerator, global_step)
1130
+
1131
+ remove_step_no = train_utils.get_remove_step_no(args, global_step)
1132
+ if remove_step_no is not None:
1133
+ remove_ckpt_name = train_utils.get_step_ckpt_name(args.output_name, remove_step_no)
1134
+ remove_model(remove_ckpt_name)
1135
+ optimizer_train_fn()
1136
+
1137
+ current_loss = loss.detach().item()
1138
+ loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
1139
+ avr_loss: float = loss_recorder.moving_average
1140
+ logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
1141
+ progress_bar.set_postfix(**logs)
1142
+
1143
+ if len(accelerator.trackers) > 0:
1144
+ logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
1145
+ accelerator.log(logs, step=global_step)
1146
+
1147
+ if global_step >= args.max_train_steps:
1148
+ break
1149
+
1150
+ if len(accelerator.trackers) > 0:
1151
+ logs = {"loss/epoch": loss_recorder.moving_average}
1152
+ accelerator.log(logs, step=epoch + 1)
1153
+
1154
+ accelerator.wait_for_everyone()
1155
+
1156
+ # save the model every specified number of epochs
1157
+ optimizer_eval_fn()
1158
+ if args.save_every_n_epochs is not None:
1159
+ saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
1160
+ if is_main_process and saving:
1161
+ ckpt_name = train_utils.get_epoch_ckpt_name(args.output_name, epoch + 1)
1162
+ save_model(ckpt_name, accelerator.unwrap_model(transformer), global_step, epoch + 1)
1163
+
1164
+ remove_epoch_no = train_utils.get_remove_epoch_no(args, epoch + 1)
1165
+ if remove_epoch_no is not None:
1166
+ remove_ckpt_name = train_utils.get_epoch_ckpt_name(args.output_name, remove_epoch_no)
1167
+ remove_model(remove_ckpt_name)
1168
+
1169
+ if args.save_state:
1170
+ train_utils.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
1171
+
1172
+ self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, transformer, sample_parameters)
1173
+ optimizer_train_fn()
1174
+
1175
+ # end of epoch
1176
+
1177
+ if is_main_process:
1178
+ transformer = accelerator.unwrap_model(transformer)
1179
+
1180
+ accelerator.end_training()
1181
+ optimizer_eval_fn()
1182
+
1183
+ if args.save_state or args.save_state_on_train_end:
1184
+ train_utils.save_state_on_train_end(args, accelerator)
1185
+
1186
+ if is_main_process:
1187
+ ckpt_name = train_utils.get_last_ckpt_name(args.output_name)
1188
+ save_model(ckpt_name, transformer, global_step, num_train_epochs, force_sync_upload=True)
1189
+
1190
+ logger.info("model saved.")
1191
+
1192
+
1193
+ def setup_parser() -> argparse.ArgumentParser:
1194
+ def int_or_float(value):
1195
+ if value.endswith("%"):
1196
+ try:
1197
+ return float(value[:-1]) / 100.0
1198
+ except ValueError:
1199
+ raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage")
1200
+ try:
1201
+ float_value = float(value)
1202
+ if float_value >= 1 and float_value.is_integer():
1203
+ return int(value)
1204
+ return float(value)
1205
+ except ValueError:
1206
+ raise argparse.ArgumentTypeError(f"'{value}' is not an int or float")
1207
+
1208
+ parser = argparse.ArgumentParser()
1209
+
1210
+ # general settings
1211
+ parser.add_argument(
1212
+ "--config_file",
1213
+ type=str,
1214
+ default=None,
1215
+ help="using .toml instead of args to pass hyperparameter / ハイパーパラメータを引数ではなく.tomlファイルで渡す",
1216
+ )
1217
+ parser.add_argument(
1218
+ "--dataset_config",
1219
+ type=pathlib.Path,
1220
+ default=None,
1221
+ required=True,
1222
+ help="config file for dataset / データセットの設定ファイル",
1223
+ )
1224
+
1225
+ # training settings
1226
+ parser.add_argument(
1227
+ "--sdpa",
1228
+ action="store_true",
1229
+ help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)",
1230
+ )
1231
+ parser.add_argument(
1232
+ "--flash_attn",
1233
+ action="store_true",
1234
+ help="use FlashAttention for CrossAttention, requires FlashAttention / CrossAttentionにFlashAttentionを使う、FlashAttentionが必要",
1235
+ )
1236
+ parser.add_argument(
1237
+ "--sage_attn",
1238
+ action="store_true",
1239
+ help="use SageAttention. requires SageAttention / SageAttentionを使う。SageAttentionが必要",
1240
+ )
1241
+ parser.add_argument(
1242
+ "--xformers",
1243
+ action="store_true",
1244
+ help="use xformers for CrossAttention, requires xformers / CrossAttentionにxformersを使う、xformersが必要",
1245
+ )
1246
+ parser.add_argument(
1247
+ "--split_attn",
1248
+ action="store_true",
1249
+ help="use split attention for attention calculation (split batch size=1, affects memory usage and speed)"
1250
+ " / attentionを分割して計算する(バッチサイズ=1に分割、メモリ使用量と速度に影響)",
1251
+ )
1252
+
1253
+ parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
1254
+ parser.add_argument(
1255
+ "--max_train_epochs",
1256
+ type=int,
1257
+ default=None,
1258
+ help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)",
1259
+ )
1260
+ parser.add_argument(
1261
+ "--max_data_loader_n_workers",
1262
+ type=int,
1263
+ default=8,
1264
+ help="max num workers for DataLoader (lower values use less main RAM and start epochs faster, but load data more slowly) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)",
1265
+ )
1266
+ parser.add_argument(
1267
+ "--persistent_data_loader_workers",
1268
+ action="store_true",
1269
+ help="persistent DataLoader workers (useful for reducing the time gap between epochs, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)",
1270
+ )
1271
+ parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
1272
+ parser.add_argument(
1273
+ "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする"
1274
+ )
1275
+ parser.add_argument(
1276
+ "--gradient_accumulation_steps",
1277
+ type=int,
1278
+ default=1,
1279
+ help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数",
1280
+ )
1281
+ parser.add_argument(
1282
+ "--mixed_precision",
1283
+ type=str,
1284
+ default="no",
1285
+ choices=["no", "fp16", "bf16"],
1286
+ help="use mixed precision / 混合精度を使う場合、その精度",
1287
+ )
1288
+ parser.add_argument("--trainable_modules", nargs="+", default=".", help="Enter a list of trainable modules")
1289
+
1290
+ parser.add_argument(
1291
+ "--logging_dir",
1292
+ type=str,
1293
+ default=None,
1294
+ help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する",
1295
+ )
1296
+ parser.add_argument(
1297
+ "--log_with",
1298
+ type=str,
1299
+ default=None,
1300
+ choices=["tensorboard", "wandb", "all"],
1301
+ help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)",
1302
+ )
1303
+ parser.add_argument(
1304
+ "--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列"
1305
+ )
1306
+ parser.add_argument(
1307
+ "--log_tracker_name",
1308
+ type=str,
1309
+ default=None,
1310
+ help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名",
1311
+ )
1312
+ parser.add_argument(
1313
+ "--wandb_run_name",
1314
+ type=str,
1315
+ default=None,
1316
+ help="The name of the specific wandb session / wandb ログに表示される特定の実行の名前",
1317
+ )
1318
+ parser.add_argument(
1319
+ "--log_tracker_config",
1320
+ type=str,
1321
+ default=None,
1322
+ help="path to tracker config file to use for logging / ログ出力に使用するtrackerの設定ファイルのパス",
1323
+ )
1324
+ parser.add_argument(
1325
+ "--wandb_api_key",
1326
+ type=str,
1327
+ default=None,
1328
+ help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)",
1329
+ )
1330
+ parser.add_argument("--log_config", action="store_true", help="log training configuration / 学習設定をログに出力する")
1331
+
1332
+ parser.add_argument(
1333
+ "--ddp_timeout",
1334
+ type=int,
1335
+ default=None,
1336
+ help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト(分、Noneでaccelerateのデフォルト)",
1337
+ )
1338
+ parser.add_argument(
1339
+ "--ddp_gradient_as_bucket_view",
1340
+ action="store_true",
1341
+ help="enable gradient_as_bucket_view for DDP / DDPでgradient_as_bucket_viewを有効にする",
1342
+ )
1343
+ parser.add_argument(
1344
+ "--ddp_static_graph",
1345
+ action="store_true",
1346
+ help="enable static_graph for DDP / DDPでstatic_graphを有効にする",
1347
+ )
1348
+
1349
+ parser.add_argument(
1350
+ "--sample_every_n_steps",
1351
+ type=int,
1352
+ default=None,
1353
+ help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する",
1354
+ )
1355
+ parser.add_argument(
1356
+ "--sample_at_first", action="store_true", help="generate sample images before training / 学習前にサンプル出力する"
1357
+ )
1358
+ parser.add_argument(
1359
+ "--sample_every_n_epochs",
1360
+ type=int,
1361
+ default=None,
1362
+ help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)",
1363
+ )
1364
+ parser.add_argument(
1365
+ "--sample_prompts",
1366
+ type=str,
1367
+ default=None,
1368
+ help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル",
1369
+ )
1370
+
1371
+ # optimizer and lr scheduler settings
1372
+ parser.add_argument(
1373
+ "--optimizer_type",
1374
+ type=str,
1375
+ default="",
1376
+ help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, AdaFactor. "
1377
+ "Also, you can use any optimizer by specifying the full path to the class, like 'torch.optim.AdamW', 'bitsandbytes.optim.AdEMAMix8bit' or 'bitsandbytes.optim.PagedAdEMAMix8bit' etc.",
1378
+ )
1379
+ parser.add_argument(
1380
+ "--optimizer_args",
1381
+ type=str,
1382
+ default=None,
1383
+ nargs="*",
1384
+ help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")',
1385
+ )
1386
+ parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
1387
+ parser.add_argument(
1388
+ "--max_grad_norm",
1389
+ default=1.0,
1390
+ type=float,
1391
+ help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない",
1392
+ )
1393
+
1394
+ parser.add_argument(
1395
+ "--lr_scheduler",
1396
+ type=str,
1397
+ default="constant",
1398
+ help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor",
1399
+ )
1400
+ parser.add_argument(
1401
+ "--lr_warmup_steps",
1402
+ type=int_or_float,
1403
+ default=0,
1404
+ help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps"
1405
+ " / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)",
1406
+ )
1407
+ parser.add_argument(
1408
+ "--lr_decay_steps",
1409
+ type=int_or_float,
1410
+ default=0,
1411
+ help="Int number of steps for the decay in the lr scheduler (default is 0) or float (<1) with ratio of train steps"
1412
+ " / 学習率のスケジューラを減衰させるステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)",
1413
+ )
1414
+ parser.add_argument(
1415
+ "--lr_scheduler_num_cycles",
1416
+ type=int,
1417
+ default=1,
1418
+ help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数",
1419
+ )
1420
+ parser.add_argument(
1421
+ "--lr_scheduler_power",
1422
+ type=float,
1423
+ default=1,
1424
+ help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
1425
+ )
1426
+ parser.add_argument(
1427
+ "--lr_scheduler_timescale",
1428
+ type=int,
1429
+ default=None,
1430
+ help="Inverse sqrt timescale for the inverse sqrt scheduler, defaults to `num_warmup_steps`"
1431
+ + " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`",
1432
+ )
1433
+ parser.add_argument(
1434
+ "--lr_scheduler_min_lr_ratio",
1435
+ type=float,
1436
+ default=None,
1437
+ help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler"
1438
+ + " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効",
1439
+ )
1440
+ parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ")
1441
+ parser.add_argument(
1442
+ "--lr_scheduler_args",
1443
+ type=str,
1444
+ default=None,
1445
+ nargs="*",
1446
+ help='additional arguments for scheduler (like "T_max=100") / スケジューラの追加引数(例: "T_max=100")',
1447
+ )
1448
+
1449
+ # model settings
1450
+ parser.add_argument("--dit", type=str, required=True, help="DiT checkpoint path / DiTのチェックポイントのパス")
1451
+ parser.add_argument("--dit_dtype", type=str, default=None, help="data type for DiT, default is bfloat16")
1452
+ parser.add_argument("--dit_in_channels", type=int, default=16, help="input channels for DiT, default is 16, skyreels I2V is 32")
1453
+ parser.add_argument("--vae", type=str, help="VAE checkpoint path / VAEのチェックポイントのパス")
1454
+ parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16")
1455
+ parser.add_argument(
1456
+ "--vae_tiling",
1457
+ action="store_true",
1458
+ help="enable spatial tiling for VAE, default is False. If vae_spatial_tile_sample_min_size is set, this is automatically enabled."
1459
+ " / VAEの空間タイリングを有効にする、デフォルトはFalse。vae_spatial_tile_sample_min_sizeが設定されている場合、自動的に有効になります。",
1460
+ )
1461
+ parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
1462
+ parser.add_argument(
1463
+ "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
1464
+ )
1465
+ parser.add_argument("--text_encoder1", type=str, help="Text Encoder 1 directory / テキストエンコーダ1のディレクトリ")
1466
+ parser.add_argument("--text_encoder2", type=str, help="Text Encoder 2 directory / テキストエンコーダ2のディレクトリ")
1467
+ parser.add_argument("--text_encoder_dtype", type=str, default=None, help="data type for Text Encoder, default is float16")
1468
+ parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for LLM / LLMにfp8を使う")
1469
+ parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
1470
+ parser.add_argument("--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する")
1471
+
1472
+ parser.add_argument(
1473
+ "--blocks_to_swap",
1474
+ type=int,
1475
+ default=None,
1476
+ help="number of blocks to swap in the model, max XXX / モデル内のブロックの数、最大XXX",
1477
+ )
1478
+ parser.add_argument(
1479
+ "--img_in_txt_in_offloading",
1480
+ action="store_true",
1481
+ help="offload img_in and txt_in to cpu / img_inとtxt_inをCPUにオフロードする",
1482
+ )
1483
+
1484
+ # parser.add_argument("--flow_shift", type=float, default=7.0, help="Shift factor for flow matching schedulers")
1485
+ parser.add_argument("--guidance_scale", type=float, default=1.0, help="Embedded classifier-free guidance scale.")
1486
+ parser.add_argument(
1487
+ "--timestep_sampling",
1488
+ choices=["sigma", "uniform", "sigmoid", "shift"],
1489
+ default="sigma",
1490
+ help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid."
1491
+ " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト。",
1492
+ )
1493
+ parser.add_argument(
1494
+ "--discrete_flow_shift",
1495
+ type=float,
1496
+ default=1.0,
1497
+ help="Discrete flow shift for the Euler Discrete Scheduler, default is 1.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは1.0。",
1498
+ )
1499
+ parser.add_argument(
1500
+ "--sigmoid_scale",
1501
+ type=float,
1502
+ default=1.0,
1503
+ help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid" or "shift"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"または"shift"の場合のみ有効)。',
1504
+ )
1505
+ parser.add_argument(
1506
+ "--weighting_scheme",
1507
+ type=str,
1508
+ default="none",
1509
+ choices=["logit_normal", "mode", "cosmap", "sigma_sqrt", "none"],
1510
+ help="weighting scheme for timestep distribution. Default is none"
1511
+ " / タイムステップ分布の重み付けスキーム、デフォルトはnone",
1512
+ )
1513
+ parser.add_argument(
1514
+ "--logit_mean",
1515
+ type=float,
1516
+ default=0.0,
1517
+ help="mean to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合の平均",
1518
+ )
1519
+ parser.add_argument(
1520
+ "--logit_std",
1521
+ type=float,
1522
+ default=1.0,
1523
+ help="std to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合のstd",
1524
+ )
1525
+ parser.add_argument(
1526
+ "--mode_scale",
1527
+ type=float,
1528
+ default=1.29,
1529
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme` / モード重み付けスキームのスケール",
1530
+ )
1531
+ parser.add_argument(
1532
+ "--min_timestep",
1533
+ type=int,
1534
+ default=None,
1535
+ help="set minimum time step for training (0~999, default is 0) / 学習時のtime stepの最小値を設定する(0~999で指定、省略時はデフォルト値(0)) ",
1536
+ )
1537
+ parser.add_argument(
1538
+ "--max_timestep",
1539
+ type=int,
1540
+ default=None,
1541
+ help="set maximum time step for training (1~1000, default is 1000) / 学習時のtime stepの最大値を設定する(1~1000で指定、省略時はデフォルト値(1000))",
1542
+ )
1543
+
1544
+ # save and load settings
1545
+ parser.add_argument(
1546
+ "--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ"
1547
+ )
1548
+ parser.add_argument(
1549
+ "--output_name",
1550
+ type=str,
1551
+ default=None,
1552
+ required=True,
1553
+ help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名",
1554
+ )
1555
+ parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
1556
+
1557
+ parser.add_argument(
1558
+ "--save_every_n_epochs",
1559
+ type=int,
1560
+ default=None,
1561
+ help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する",
1562
+ )
1563
+ parser.add_argument(
1564
+ "--save_every_n_steps",
1565
+ type=int,
1566
+ default=None,
1567
+ help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する",
1568
+ )
1569
+ parser.add_argument(
1570
+ "--save_last_n_epochs",
1571
+ type=int,
1572
+ default=None,
1573
+ help="save last N checkpoints when saving every N epochs (remove older checkpoints) / 指定エポックごとにモデルを保存するとき最大Nエポック保存する(古いチェックポイントは削除する)",
1574
+ )
1575
+ parser.add_argument(
1576
+ "--save_last_n_epochs_state",
1577
+ type=int,
1578
+ default=None,
1579
+ help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きする)",
1580
+ )
1581
+ parser.add_argument(
1582
+ "--save_last_n_steps",
1583
+ type=int,
1584
+ default=None,
1585
+ help="save checkpoints until N steps elapsed (remove older checkpoints if N steps elapsed) / 指定ステップごとにモデルを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する)",
1586
+ )
1587
+ parser.add_argument(
1588
+ "--save_last_n_steps_state",
1589
+ type=int,
1590
+ default=None,
1591
+ help="save states until N steps elapsed (remove older states if N steps elapsed, overrides --save_last_n_steps) / 指定ステップごとにstateを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する。--save_last_n_stepsを上書きする)",
1592
+ )
1593
+ parser.add_argument(
1594
+ "--save_state",
1595
+ action="store_true",
1596
+ help="save training state additionally (including optimizer states etc.) when saving model / optimizerなど学習状態も含めたstateをモデル保存時に追加で保存する",
1597
+ )
1598
+ parser.add_argument(
1599
+ "--save_state_on_train_end",
1600
+ action="store_true",
1601
+ help="save training state (including optimizer states etc.) on train end even if --save_state is not specified"
1602
+ " / --save_stateが未指定時にもoptimizerなど学習状態も含めたstateを学習終了時に保存する",
1603
+ )
1604
+
1605
+ # SAI Model spec
1606
+ parser.add_argument(
1607
+ "--metadata_title",
1608
+ type=str,
1609
+ default=None,
1610
+ help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name",
1611
+ )
1612
+ parser.add_argument(
1613
+ "--metadata_author",
1614
+ type=str,
1615
+ default=None,
1616
+ help="author name for model metadata / メタデータに書き込まれるモデル作者名",
1617
+ )
1618
+ parser.add_argument(
1619
+ "--metadata_description",
1620
+ type=str,
1621
+ default=None,
1622
+ help="description for model metadata / メタデータに書き込まれるモデル説明",
1623
+ )
1624
+ parser.add_argument(
1625
+ "--metadata_license",
1626
+ type=str,
1627
+ default=None,
1628
+ help="license for model metadata / メタデータに書き込まれるモデルライセンス",
1629
+ )
1630
+ parser.add_argument(
1631
+ "--metadata_tags",
1632
+ type=str,
1633
+ default=None,
1634
+ help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
1635
+ )
1636
+
1637
+ # huggingface settings
1638
+ parser.add_argument(
1639
+ "--huggingface_repo_id",
1640
+ type=str,
1641
+ default=None,
1642
+ help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名",
1643
+ )
1644
+ parser.add_argument(
1645
+ "--huggingface_repo_type",
1646
+ type=str,
1647
+ default=None,
1648
+ help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類",
1649
+ )
1650
+ parser.add_argument(
1651
+ "--huggingface_path_in_repo",
1652
+ type=str,
1653
+ default=None,
1654
+ help="huggingface model path to upload files / huggingfaceにアップロードするファイルのパス",
1655
+ )
1656
+ parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token / huggingfaceのトークン")
1657
+ parser.add_argument(
1658
+ "--huggingface_repo_visibility",
1659
+ type=str,
1660
+ default=None,
1661
+ help="huggingface repository visibility ('public' for public, 'private' or None for private) / huggingfaceにアップロードするリポジトリの公開設定('public'で公開、'private'またはNoneで非公開)",
1662
+ )
1663
+ parser.add_argument(
1664
+ "--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する"
1665
+ )
1666
+ parser.add_argument(
1667
+ "--resume_from_huggingface",
1668
+ action="store_true",
1669
+ help="resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type}) / huggingfaceから学習を再開する(例: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})",
1670
+ )
1671
+ parser.add_argument(
1672
+ "--async_upload",
1673
+ action="store_true",
1674
+ help="upload to huggingface asynchronously / huggingfaceに非同期でアップロードする",
1675
+ )
1676
+
1677
+ return parser
1678
+
1679
+
1680
+ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentParser):
1681
+ if not args.config_file:
1682
+ return args
1683
+
1684
+ config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
1685
+
1686
+ if not os.path.exists(config_path):
1687
+ logger.info(f"{config_path} not found.")
1688
+ exit(1)
1689
+
1690
+ logger.info(f"Loading settings from {config_path}...")
1691
+ with open(config_path, "r", encoding="utf-8") as f:
1692
+ config_dict = toml.load(f)
1693
+
1694
+ # combine all sections into one
1695
+ ignore_nesting_dict = {}
1696
+ for section_name, section_dict in config_dict.items():
1697
+ # if value is not dict, save key and value as is
1698
+ if not isinstance(section_dict, dict):
1699
+ ignore_nesting_dict[section_name] = section_dict
1700
+ continue
1701
+
1702
+ # if value is dict, save all key and value into one dict
1703
+ for key, value in section_dict.items():
1704
+ ignore_nesting_dict[key] = value
1705
+
1706
+ config_args = argparse.Namespace(**ignore_nesting_dict)
1707
+ args = parser.parse_args(namespace=config_args)
1708
+ args.config_file = os.path.splitext(args.config_file)[0]
1709
+ logger.info(args.config_file)
1710
+
1711
+ return args
1712
+
1713
+
1714
+ if __name__ == "__main__":
1715
+ parser = setup_parser()
1716
+
1717
+ args = parser.parse_args()
1718
+ args = read_config_from_file(args, parser)
1719
+
1720
+ trainer = FineTuningTrainer()
1721
+ trainer.train(args)
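
Note on `--config_file`: `read_config_from_file` above flattens every TOML `[section]` into a single flat namespace before re-parsing, so section names are arbitrary and keys only need to match the argparse option names; options given explicitly on the command line still take precedence, and required arguments such as `--dataset_config`, `--dit`, and `--output_name` must still be passed on the command line. Percentage values such as `--lr_warmup_steps 5%` are converted by `int_or_float` into a ratio (0.05). A minimal sketch of the flattening, with hypothetical config contents:

```python
# Sketch of the flattening done by read_config_from_file; the TOML content below
# is a hypothetical example, not a config shipped with this repository.
import argparse
import toml

config_text = """
[training]
mixed_precision = "bf16"
max_train_steps = 1600

[optimizer]
optimizer_type = "AdamW8bit"
learning_rate = 2e-6
"""

flat = {}
for section, value in toml.loads(config_text).items():
    if isinstance(value, dict):   # section table: merge its keys into the flat namespace
        flat.update(value)
    else:                         # top-level key outside any section: keep as-is
        flat[section] = value

print(argparse.Namespace(**flat))
# Namespace(learning_rate=2e-06, max_train_steps=1600, mixed_precision='bf16', optimizer_type='AdamW8bit')
```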
hv_train_network.py ADDED
The diff for this file is too large to render. See raw diff
 
merge_lora.py ADDED
@@ -0,0 +1,63 @@
1
+ import argparse
2
+ import logging
3
+ import torch
4
+ from safetensors.torch import load_file
5
+ from networks import lora
6
+ from utils.safetensors_utils import mem_eff_save_file
7
+ from hunyuan_model.models import load_transformer
8
+
9
+ logger = logging.getLogger(__name__)
10
+ logging.basicConfig(level=logging.INFO)
11
+
12
+
13
+ def parse_args():
14
+ parser = argparse.ArgumentParser(description="HunyuanVideo model merger script")
15
+
16
+ parser.add_argument("--dit", type=str, required=True, help="DiT checkpoint path or directory")
17
+ parser.add_argument("--dit_in_channels", type=int, default=16, help="input channels for DiT, default is 16, skyreels I2V is 32")
18
+ parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path")
19
+ parser.add_argument("--lora_multiplier", type=float, nargs="*", default=[1.0], help="LoRA multiplier (can specify multiple values)")
20
+ parser.add_argument("--save_merged_model", type=str, required=True, help="Path to save the merged model")
21
+ parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use for merging")
22
+
23
+ return parser.parse_args()
24
+
25
+
26
+ def main():
27
+ args = parse_args()
28
+
29
+ device = torch.device(args.device)
30
+ logger.info(f"Using device: {device}")
31
+
32
+ # Load DiT model
33
+ logger.info(f"Loading DiT model from {args.dit}")
34
+ transformer = load_transformer(args.dit, "torch", False, "cpu", torch.bfloat16, in_channels=args.dit_in_channels)
35
+ transformer.eval()
36
+
37
+ # Load LoRA weights and merge
38
+ if args.lora_weight is not None and len(args.lora_weight) > 0:
39
+ for i, lora_weight in enumerate(args.lora_weight):
40
+ # Use the corresponding lora_multiplier or default to 1.0
41
+ if args.lora_multiplier is not None and len(args.lora_multiplier) > i:
42
+ lora_multiplier = args.lora_multiplier[i]
43
+ else:
44
+ lora_multiplier = 1.0
45
+
46
+ logger.info(f"Loading LoRA weights from {lora_weight} with multiplier {lora_multiplier}")
47
+ weights_sd = load_file(lora_weight)
48
+ network = lora.create_network_from_weights_hunyuan_video(
49
+ lora_multiplier, weights_sd, unet=transformer, for_inference=True
50
+ )
51
+ logger.info("Merging LoRA weights to DiT model")
52
+ network.merge_to(None, transformer, weights_sd, device=device, non_blocking=True)
53
+
54
+ logger.info("LoRA weights loaded")
55
+
56
+ # Save the merged model
57
+ logger.info(f"Saving merged model to {args.save_merged_model}")
58
+ mem_eff_save_file(transformer.state_dict(), args.save_merged_model)
59
+ logger.info("Merged model saved")
60
+
61
+
62
+ if __name__ == "__main__":
63
+ main()
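
As a usage sketch, the merger is driven entirely by the flags defined in `parse_args` above; multiple `--lora_weight` values are paired positionally with `--lora_multiplier`, and any LoRA without a matching multiplier falls back to 1.0. All paths below are placeholders:

```python
# Hypothetical invocation of merge_lora.py; every path is a placeholder.
import subprocess

subprocess.run(
    [
        "python", "merge_lora.py",
        "--dit", "path/to/hunyuan_video_dit.safetensors",          # base DiT checkpoint
        "--lora_weight", "path/to/lora_a.safetensors", "path/to/lora_b.safetensors",
        "--lora_multiplier", "1.0", "0.8",                         # paired with the LoRA weights above
        "--save_merged_model", "path/to/dit_merged.safetensors",
        "--device", "cuda",
    ],
    check=True,
)
```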
modules/__init__.py ADDED
File without changes
modules/custom_offloading_utils.py ADDED
@@ -0,0 +1,266 @@
1
+ from concurrent.futures import ThreadPoolExecutor
2
+ import gc
3
+ import time
4
+ from typing import Optional
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+
9
+ def clean_memory_on_device(device: torch.device):
10
+ r"""
11
+ Clean memory on the specified device, will be called from training scripts.
12
+ """
13
+ gc.collect()
14
+
15
+ # device may be "cuda" or "cuda:0", so we need to check the device type
16
+ if device.type == "cuda":
17
+ torch.cuda.empty_cache()
18
+ if device.type == "xpu":
19
+ torch.xpu.empty_cache()
20
+ if device.type == "mps":
21
+ torch.mps.empty_cache()
22
+
23
+
24
+ def synchronize_device(device: torch.device):
25
+ if device.type == "cuda":
26
+ torch.cuda.synchronize()
27
+ elif device.type == "xpu":
28
+ torch.xpu.synchronize()
29
+ elif device.type == "mps":
30
+ torch.mps.synchronize()
31
+
32
+
33
+ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
34
+ assert layer_to_cpu.__class__ == layer_to_cuda.__class__
35
+
36
+ weight_swap_jobs = []
37
+
38
+ # This is not working for all cases (e.g. SD3), so we need to find the corresponding modules
39
+ # for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
40
+ # print(module_to_cpu.__class__, module_to_cuda.__class__)
41
+ # if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
42
+ # weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
43
+
44
+ modules_to_cpu = {k: v for k, v in layer_to_cpu.named_modules()}
45
+ for module_to_cuda_name, module_to_cuda in layer_to_cuda.named_modules():
46
+ if hasattr(module_to_cuda, "weight") and module_to_cuda.weight is not None:
47
+ module_to_cpu = modules_to_cpu.get(module_to_cuda_name, None)
48
+ if module_to_cpu is not None and module_to_cpu.weight.shape == module_to_cuda.weight.shape:
49
+ weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
50
+ else:
51
+ if module_to_cuda.weight.data.device.type != device.type:
52
+ # print(
53
+ # f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device"
54
+ # )
55
+ module_to_cuda.weight.data = module_to_cuda.weight.data.to(device)
56
+
57
+ torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
58
+
59
+ stream = torch.cuda.Stream()
60
+ with torch.cuda.stream(stream):
61
+ # cuda to cpu
62
+ for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
63
+ cuda_data_view.record_stream(stream)
64
+ module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
65
+
66
+ stream.synchronize()
67
+
68
+ # cpu to cuda
69
+ for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
70
+ cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
71
+ module_to_cuda.weight.data = cuda_data_view
72
+
73
+ stream.synchronize()
74
+ torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
75
+
76
+
77
+ def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
78
+ """
79
+ not tested
80
+ """
81
+ assert layer_to_cpu.__class__ == layer_to_cuda.__class__
82
+
83
+ weight_swap_jobs = []
84
+ for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
85
+ if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
86
+ weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
87
+
88
+ # device to cpu
89
+ for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
90
+ module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
91
+
92
+ synchronize_device(device)
93
+
94
+ # cpu to device
95
+ for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
96
+ cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
97
+ module_to_cuda.weight.data = cuda_data_view
98
+
99
+ synchronize_device(device)
100
+
101
+
102
+ def weighs_to_device(layer: nn.Module, device: torch.device):
103
+ for module in layer.modules():
104
+ if hasattr(module, "weight") and module.weight is not None:
105
+ module.weight.data = module.weight.data.to(device, non_blocking=True)
106
+
107
+
108
+ class Offloader:
109
+ """
110
+ common offloading class
111
+ """
112
+
113
+ def __init__(self, block_type: str, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
114
+ self.block_type = block_type
115
+ self.num_blocks = num_blocks
116
+ self.blocks_to_swap = blocks_to_swap
117
+ self.device = device
118
+ self.debug = debug
119
+
120
+ self.thread_pool = ThreadPoolExecutor(max_workers=1)
121
+ self.futures = {}
122
+ self.cuda_available = device.type == "cuda"
123
+
124
+ def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module):
125
+ if self.cuda_available:
126
+ swap_weight_devices_cuda(self.device, block_to_cpu, block_to_cuda)
127
+ else:
128
+ swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda)
129
+
130
+ def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda):
131
+ def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
132
+ if self.debug:
133
+ start_time = time.perf_counter()
134
+ print(
135
+ f"[{self.block_type}] Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}"
136
+ )
137
+
138
+ self.swap_weight_devices(block_to_cpu, block_to_cuda)
139
+
140
+ if self.debug:
141
+ print(f"[{self.block_type}] Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
142
+ return bidx_to_cpu, bidx_to_cuda # , event
143
+
144
+ block_to_cpu = blocks[block_idx_to_cpu]
145
+ block_to_cuda = blocks[block_idx_to_cuda]
146
+
147
+ self.futures[block_idx_to_cuda] = self.thread_pool.submit(
148
+ move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda
149
+ )
150
+
151
+ def _wait_blocks_move(self, block_idx):
152
+ if block_idx not in self.futures:
153
+ return
154
+
155
+ if self.debug:
156
+ print(f"[{self.block_type}] Wait for block {block_idx}")
157
+ start_time = time.perf_counter()
158
+
159
+ future = self.futures.pop(block_idx)
160
+ _, bidx_to_cuda = future.result()
161
+
162
+ assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}"
163
+
164
+ if self.debug:
165
+ print(f"[{self.block_type}] Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
166
+
167
+
168
+ class ModelOffloader(Offloader):
169
+ """
170
+ supports forward offloading
171
+ """
172
+
173
+ def __init__(
174
+ self,
175
+ block_type: str,
176
+ blocks: list[nn.Module],
177
+ num_blocks: int,
178
+ blocks_to_swap: int,
179
+ supports_backward: bool,
180
+ device: torch.device,
181
+ debug: bool = False,
182
+ ):
183
+ super().__init__(block_type, num_blocks, blocks_to_swap, device, debug)
184
+
185
+ self.supports_backward = supports_backward
186
+ self.forward_only = not supports_backward # forward only offloading: can be changed to True for inference
187
+
188
+ if self.supports_backward:
189
+ # register backward hooks
190
+ self.remove_handles = []
191
+ for i, block in enumerate(blocks):
192
+ hook = self.create_backward_hook(blocks, i)
193
+ if hook is not None:
194
+ handle = block.register_full_backward_hook(hook)
195
+ self.remove_handles.append(handle)
196
+
197
+ def set_forward_only(self, forward_only: bool):
198
+ self.forward_only = forward_only
199
+
200
+ def __del__(self):
201
+ if self.supports_backward:
202
+ for handle in self.remove_handles:
203
+ handle.remove()
204
+
205
+ def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
206
+ # -1 for 0-based index
207
+ num_blocks_propagated = self.num_blocks - block_index - 1
208
+ swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
209
+ waiting = block_index > 0 and block_index <= self.blocks_to_swap
210
+
211
+ if not swapping and not waiting:
212
+ return None
213
+
214
+ # create hook
215
+ block_idx_to_cpu = self.num_blocks - num_blocks_propagated
216
+ block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
217
+ block_idx_to_wait = block_index - 1
218
+
219
+ def backward_hook(module, grad_input, grad_output):
220
+ if self.debug:
221
+ print(f"Backward hook for block {block_index}")
222
+
223
+ if swapping:
224
+ self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
225
+ if waiting:
226
+ self._wait_blocks_move(block_idx_to_wait)
227
+ return None
228
+
229
+ return backward_hook
230
+
231
+ def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
232
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
233
+ return
234
+
235
+ if self.debug:
236
+ print(f"[{self.block_type}] Prepare block devices before forward")
237
+
238
+ for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
239
+ b.to(self.device)
240
+ weighs_to_device(b, self.device) # make sure weights are on device
241
+
242
+ for b in blocks[self.num_blocks - self.blocks_to_swap :]:
243
+ b.to(self.device) # move block to device first
244
+ weighs_to_device(b, "cpu") # make sure weights are on cpu
245
+
246
+ synchronize_device(self.device)
247
+ clean_memory_on_device(self.device)
248
+
249
+ def wait_for_block(self, block_idx: int):
250
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
251
+ return
252
+ self._wait_blocks_move(block_idx)
253
+
254
+ def submit_move_blocks_forward(self, blocks: list[nn.Module], block_idx: int):
255
+ # check if blocks_to_swap is enabled
256
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
257
+ return
258
+
259
+ # if supports_backward and backward is enabled, we swap blocks more than blocks_to_swap in backward pass
260
+ if not self.forward_only and block_idx >= self.blocks_to_swap:
261
+ return
262
+
263
+ block_idx_to_cpu = block_idx
264
+ block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx
265
+ block_idx_to_cuda = block_idx_to_cuda % self.num_blocks # this works for forward-only offloading
266
+ self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
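
A minimal sketch of how this offloader is meant to be driven from a block-wise forward pass; the `nn.Linear` blocks below are hypothetical stand-ins for real transformer blocks:

```python
# Hypothetical sketch wiring ModelOffloader into a block-wise forward loop.
import torch
import torch.nn as nn
from modules.custom_offloading_utils import ModelOffloader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
blocks = [nn.Linear(64, 64) for _ in range(8)]  # stand-ins for transformer blocks
blocks_to_swap = 4

offloader = ModelOffloader(
    block_type="demo",
    blocks=blocks,
    num_blocks=len(blocks),
    blocks_to_swap=blocks_to_swap,
    supports_backward=False,  # forward-only offloading, e.g. for inference
    device=device,
)
offloader.prepare_block_devices_before_forward(blocks)

x = torch.randn(1, 64, device=device)
for i, block in enumerate(blocks):
    offloader.wait_for_block(i)                      # ensure this block's weights are on the device
    x = block(x)
    offloader.submit_move_blocks_forward(blocks, i)  # asynchronously swap it back out
```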
modules/scheduling_flow_match_discrete.py ADDED
@@ -0,0 +1,257 @@
1
+ # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
16
+ # Modified from diffusers==0.29.2
17
+ #
18
+ # ==============================================================================
19
+
20
+ from dataclasses import dataclass
21
+ from typing import Optional, Tuple, Union
22
+
23
+ import numpy as np
24
+ import torch
25
+
26
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
27
+ from diffusers.utils import BaseOutput, logging
28
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
29
+
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+
34
+ @dataclass
35
+ class FlowMatchDiscreteSchedulerOutput(BaseOutput):
36
+ """
37
+ Output class for the scheduler's `step` function output.
38
+
39
+ Args:
40
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
41
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
42
+ denoising loop.
43
+ """
44
+
45
+ prev_sample: torch.FloatTensor
46
+
47
+
48
+ class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
49
+ """
50
+ Euler scheduler.
51
+
52
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
53
+ methods the library implements for all schedulers such as loading and saving.
54
+
55
+ Args:
56
+ num_train_timesteps (`int`, defaults to 1000):
57
+ The number of diffusion steps to train the model.
58
+ timestep_spacing (`str`, defaults to `"linspace"`):
59
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
60
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
61
+ shift (`float`, defaults to 1.0):
62
+ The shift value for the timestep schedule.
63
+ reverse (`bool`, defaults to `True`):
64
+ Whether to reverse the timestep schedule.
65
+ """
66
+
67
+ _compatibles = []
68
+ order = 1
69
+
70
+ @register_to_config
71
+ def __init__(
72
+ self,
73
+ num_train_timesteps: int = 1000,
74
+ shift: float = 1.0,
75
+ reverse: bool = True,
76
+ solver: str = "euler",
77
+ n_tokens: Optional[int] = None,
78
+ ):
79
+ sigmas = torch.linspace(1, 0, num_train_timesteps + 1)
80
+
81
+ if not reverse:
82
+ sigmas = sigmas.flip(0)
83
+
84
+ self.sigmas = sigmas
85
+ # the value fed to model
86
+ self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32)
87
+
88
+ self._step_index = None
89
+ self._begin_index = None
90
+
91
+ self.supported_solver = ["euler"]
92
+ if solver not in self.supported_solver:
93
+ raise ValueError(
94
+ f"Solver {solver} not supported. Supported solvers: {self.supported_solver}"
95
+ )
96
+
97
+ @property
98
+ def step_index(self):
99
+ """
100
+ The index counter for current timestep. It will increase 1 after each scheduler step.
101
+ """
102
+ return self._step_index
103
+
104
+ @property
105
+ def begin_index(self):
106
+ """
107
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
108
+ """
109
+ return self._begin_index
110
+
111
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
112
+ def set_begin_index(self, begin_index: int = 0):
113
+ """
114
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
115
+
116
+ Args:
117
+ begin_index (`int`):
118
+ The begin index for the scheduler.
119
+ """
120
+ self._begin_index = begin_index
121
+
122
+ def _sigma_to_t(self, sigma):
123
+ return sigma * self.config.num_train_timesteps
124
+
125
+ def set_timesteps(
126
+ self,
127
+ num_inference_steps: int,
128
+ device: Union[str, torch.device] = None,
129
+ n_tokens: int = None,
130
+ ):
131
+ """
132
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
133
+
134
+ Args:
135
+ num_inference_steps (`int`):
136
+ The number of diffusion steps used when generating samples with a pre-trained model.
137
+ device (`str` or `torch.device`, *optional*):
138
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
139
+ n_tokens (`int`, *optional*):
140
+ Number of tokens in the input sequence.
141
+ """
142
+ self.num_inference_steps = num_inference_steps
143
+
144
+ sigmas = torch.linspace(1, 0, num_inference_steps + 1)
145
+ sigmas = self.sd3_time_shift(sigmas)
146
+
147
+ if not self.config.reverse:
148
+ sigmas = 1 - sigmas
149
+
150
+ self.sigmas = sigmas
151
+ self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(
152
+ dtype=torch.float32, device=device
153
+ )
154
+
155
+ # Reset step index
156
+ self._step_index = None
157
+
158
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
159
+ if schedule_timesteps is None:
160
+ schedule_timesteps = self.timesteps
161
+
162
+ indices = (schedule_timesteps == timestep).nonzero()
163
+
164
+ # The sigma index that is taken for the **very** first `step`
165
+ # is always the second index (or the last index if there is only 1)
166
+ # This way we can ensure we don't accidentally skip a sigma in
167
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
168
+ pos = 1 if len(indices) > 1 else 0
169
+
170
+ return indices[pos].item()
171
+
172
+ def _init_step_index(self, timestep):
173
+ if self.begin_index is None:
174
+ if isinstance(timestep, torch.Tensor):
175
+ timestep = timestep.to(self.timesteps.device)
176
+ self._step_index = self.index_for_timestep(timestep)
177
+ else:
178
+ self._step_index = self._begin_index
179
+
180
+ def scale_model_input(
181
+ self, sample: torch.Tensor, timestep: Optional[int] = None
182
+ ) -> torch.Tensor:
183
+ return sample
184
+
185
+ def sd3_time_shift(self, t: torch.Tensor):
186
+ return (self.config.shift * t) / (1 + (self.config.shift - 1) * t)
187
+
188
+ def step(
189
+ self,
190
+ model_output: torch.FloatTensor,
191
+ timestep: Union[float, torch.FloatTensor],
192
+ sample: torch.FloatTensor,
193
+ return_dict: bool = True,
194
+ ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
195
+ """
196
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
197
+ process from the learned model outputs (most often the predicted noise).
198
+
199
+ Args:
200
+ model_output (`torch.FloatTensor`):
201
+ The direct output from learned diffusion model.
202
+ timestep (`float`):
203
+ The current discrete timestep in the diffusion chain.
204
+ sample (`torch.FloatTensor`):
205
+ A current instance of a sample created by the diffusion process.
206
+ generator (`torch.Generator`, *optional*):
207
+ A random number generator.
208
+ n_tokens (`int`, *optional*):
209
+ Number of tokens in the input sequence.
210
+ return_dict (`bool`):
211
+ Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
212
+ tuple.
213
+
214
+ Returns:
215
+ [`FlowMatchDiscreteSchedulerOutput`] or `tuple`:
216
+ If return_dict is `True`, [`FlowMatchDiscreteSchedulerOutput`] is
217
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
218
+ """
219
+
220
+ if (
221
+ isinstance(timestep, int)
222
+ or isinstance(timestep, torch.IntTensor)
223
+ or isinstance(timestep, torch.LongTensor)
224
+ ):
225
+ raise ValueError(
226
+ (
227
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
228
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
229
+ " one of the `scheduler.timesteps` as a timestep."
230
+ ),
231
+ )
232
+
233
+ if self.step_index is None:
234
+ self._init_step_index(timestep)
235
+
236
+ # Upcast to avoid precision issues when computing prev_sample
237
+ sample = sample.to(torch.float32)
238
+
239
+ dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
240
+
241
+ if self.config.solver == "euler":
242
+ prev_sample = sample + model_output.to(torch.float32) * dt
243
+ else:
244
+ raise ValueError(
245
+ f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}"
246
+ )
247
+
248
+ # upon completion increase step index by one
249
+ self._step_index += 1
250
+
251
+ if not return_dict:
252
+ return (prev_sample,)
253
+
254
+ return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)
255
+
256
+ def __len__(self):
257
+ return self.config.num_train_timesteps
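
A minimal sketch of the Euler sampling loop this scheduler implements; `denoise` is a hypothetical stand-in for the DiT velocity prediction and the latent shape is purely illustrative:

```python
# Hypothetical sampling-loop sketch for FlowMatchDiscreteScheduler.
import torch
from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler

def denoise(latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    return torch.zeros_like(latents)  # placeholder for the real transformer forward pass

scheduler = FlowMatchDiscreteScheduler(shift=7.0)  # shift is applied through sd3_time_shift
scheduler.set_timesteps(num_inference_steps=30, device="cpu")

latents = torch.randn(1, 16, 9, 32, 32)  # (B, C, F, H, W), illustrative shape
for t in scheduler.timesteps:
    model_output = denoise(latents, t)
    latents = scheduler.step(model_output, t, latents).prev_sample
```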
modules/unet_causal_3d_blocks.py ADDED
@@ -0,0 +1,818 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
16
+ # Modified from diffusers==0.29.2
17
+ #
18
+ # ==============================================================================
19
+
20
+ from typing import Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from torch import nn
25
+ from einops import rearrange
26
+
27
+ from diffusers.utils import logging
28
+ from diffusers.models.activations import get_activation
29
+ from diffusers.models.attention_processor import SpatialNorm
30
+ from diffusers.models.attention_processor import Attention
31
+ from diffusers.models.normalization import AdaGroupNorm
32
+ from diffusers.models.normalization import RMSNorm
33
+
34
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
+
36
+
37
+ def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None):
38
+ seq_len = n_frame * n_hw
39
+ mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
40
+ for i in range(seq_len):
41
+ i_frame = i // n_hw
42
+ mask[i, : (i_frame + 1) * n_hw] = 0
43
+ if batch_size is not None:
44
+ mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
45
+ return mask
46
+
47
+
48
+ class CausalConv3d(nn.Module):
49
+ """
50
+ Implements a causal 3D convolution layer where each position only depends on previous timesteps and current spatial locations.
51
+ This maintains temporal causality in video generation tasks.
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ chan_in,
57
+ chan_out,
58
+ kernel_size: Union[int, Tuple[int, int, int]],
59
+ stride: Union[int, Tuple[int, int, int]] = 1,
60
+ dilation: Union[int, Tuple[int, int, int]] = 1,
61
+ pad_mode="replicate",
62
+ chunk_size=0,
63
+ **kwargs,
64
+ ):
65
+ super().__init__()
66
+
67
+ self.pad_mode = pad_mode
68
+ padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size - 1, 0) # W, H, T
69
+ self.time_causal_padding = padding
70
+ self.chunk_size = chunk_size
71
+
72
+ self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
73
+
74
+ def original_forward(self, x):
75
+ x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
76
+ return self.conv(x)
77
+
78
+ def forward(self, x):
79
+ if self.chunk_size == 0:
80
+ return self.original_forward(x)
81
+
82
+ # if not large, call original forward
83
+ if x.shape[4] < self.chunk_size * 1.5:
84
+ return self.original_forward(x)
85
+
86
+ # # debug: verify the original forward is the same as chunked forward
87
+ # orig_forwarded_value = None
88
+ # if x.shape[4] < self.chunk_size * 4:
89
+ # orig_forwarded_value = self.original_forward(x)
90
+
91
+ # get the kernel size
92
+ kernel_size = self.conv.kernel_size[0] # assume cubic kernel
93
+ assert kernel_size == self.conv.kernel_size[1] == self.conv.kernel_size[2], "Only cubic kernels are supported"
94
+ padding_size = kernel_size // 2 # 1 for kernel_size=3, 0 for kernel_size=1
95
+
96
+ x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
97
+
98
+ B, C, D, H, W = orig_shape = x.shape
99
+ chunk_size = self.chunk_size
100
+ chunk_size -= chunk_size % self.conv.stride[2] # make sure the chunk size is divisible by stride
101
+ # print(f"chunked forward: {x.shape}, chunk_size: {chunk_size}")
102
+
103
+ # calculate the indices for chunking with overlap and padding by kernel size and stride
104
+ indices = []
105
+ i = 0
106
+ while i < W - padding_size:
107
+ start_idx = i - padding_size
108
+ end_idx = min(i + chunk_size + padding_size, W)
109
+ if i == 0:
110
+ start_idx = 0
111
+ end_idx += padding_size # to make sure the first chunk is divisible by stride
112
+ if W - end_idx < chunk_size // 2: # small chunk at the end
113
+ end_idx = W
114
+ indices.append((start_idx, end_idx))
115
+ i = end_idx - padding_size
116
+ # print(f"chunked forward: {x.shape}, chunked indices: {indices}")
117
+
118
+ chunks = []
119
+ for start_idx, end_idx in indices:
120
+ chunk = x[:, :, :, :, start_idx:end_idx]
121
+ chunk_output = self.conv(chunk)
122
+ # print(chunk.shape, chunk_output.shape)
123
+ chunks.append(chunk_output)
124
+
125
+ # concatenate the chunks
126
+ x = torch.cat(chunks, dim=4)
127
+
128
+ assert (
129
+ x.shape[2] == ((D - padding_size * 2) + self.conv.stride[0] - 1) // self.conv.stride[0]
130
+ ), f"Invalid shape: {x.shape}, {orig_shape}, {padding_size}, {self.conv.stride}"
131
+ assert (
132
+ x.shape[3] == ((H - padding_size * 2) + self.conv.stride[1] - 1) // self.conv.stride[1]
133
+ ), f"Invalid shape: {x.shape}, {orig_shape}, {padding_size}, {self.conv.stride}"
134
+ assert (
135
+ x.shape[4] == ((W - padding_size * 2) + self.conv.stride[2] - 1) // self.conv.stride[2]
136
+ ), f"Invalid shape: {x.shape}, {orig_shape}, {padding_size}, {self.conv.stride}"
137
+
138
+ # # debug: verify the original forward is the same as chunked forward
139
+ # if orig_forwarded_value is not None:
140
+ # assert torch.allclose(
141
+ # orig_forwarded_value, x, rtol=1e-4, atol=1e-2
142
+ # ), f"Chunked forward is different from original forward. {x.shape}, {orig_shape}, {padding_size}, {self.conv.stride}, {self.conv.kernel_size}"
143
+
144
+ return x
145
+
146
+
147
+ class UpsampleCausal3D(nn.Module):
148
+ """
149
+ A 3D upsampling layer with an optional convolution.
150
+ """
151
+
152
+ def __init__(
153
+ self,
154
+ channels: int,
155
+ use_conv: bool = False,
156
+ use_conv_transpose: bool = False,
157
+ out_channels: Optional[int] = None,
158
+ name: str = "conv",
159
+ kernel_size: Optional[int] = None,
160
+ padding=1,
161
+ norm_type=None,
162
+ eps=None,
163
+ elementwise_affine=None,
164
+ bias=True,
165
+ interpolate=True,
166
+ upsample_factor=(2, 2, 2),
167
+ ):
168
+ super().__init__()
169
+ self.channels = channels
170
+ self.out_channels = out_channels or channels
171
+ self.use_conv = use_conv
172
+ self.use_conv_transpose = use_conv_transpose
173
+ self.name = name
174
+ self.interpolate = interpolate
175
+ self.upsample_factor = upsample_factor
176
+
177
+ if norm_type == "ln_norm":
178
+ self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
179
+ elif norm_type == "rms_norm":
180
+ self.norm = RMSNorm(channels, eps, elementwise_affine)
181
+ elif norm_type is None:
182
+ self.norm = None
183
+ else:
184
+ raise ValueError(f"unknown norm_type: {norm_type}")
185
+
186
+ conv = None
187
+ if use_conv_transpose:
188
+ raise NotImplementedError
189
+ elif use_conv:
190
+ if kernel_size is None:
191
+ kernel_size = 3
192
+ conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias)
193
+
194
+ if name == "conv":
195
+ self.conv = conv
196
+ else:
197
+ self.Conv2d_0 = conv
198
+
199
+ def forward(
200
+ self,
201
+ hidden_states: torch.FloatTensor,
202
+ output_size: Optional[int] = None,
203
+ scale: float = 1.0,
204
+ ) -> torch.FloatTensor:
205
+ assert hidden_states.shape[1] == self.channels
206
+
207
+ if self.norm is not None:
208
+ raise NotImplementedError
209
+
210
+ if self.use_conv_transpose:
211
+ return self.conv(hidden_states)
212
+
213
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
214
+ dtype = hidden_states.dtype
215
+ if dtype == torch.bfloat16:
216
+ hidden_states = hidden_states.to(torch.float32)
217
+
218
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
219
+ if hidden_states.shape[0] >= 64:
220
+ hidden_states = hidden_states.contiguous()
221
+
222
+ # if `output_size` is passed we force the interpolation output
223
+ # size and do not make use of `scale_factor=2`
224
+ if self.interpolate:
225
+ B, C, T, H, W = hidden_states.shape
226
+ first_h, other_h = hidden_states.split((1, T - 1), dim=2)
227
+ if output_size is None:
228
+ if T > 1:
229
+ other_h = F.interpolate(other_h, scale_factor=self.upsample_factor, mode="nearest")
230
+
231
+ first_h = first_h.squeeze(2)
232
+ first_h = F.interpolate(first_h, scale_factor=self.upsample_factor[1:], mode="nearest")
233
+ first_h = first_h.unsqueeze(2)
234
+ else:
235
+ raise NotImplementedError
236
+
237
+ if T > 1:
238
+ hidden_states = torch.cat((first_h, other_h), dim=2)
239
+ else:
240
+ hidden_states = first_h
241
+
242
+ # If the input is bfloat16, we cast back to bfloat16
243
+ if dtype == torch.bfloat16:
244
+ hidden_states = hidden_states.to(dtype)
245
+
246
+ if self.use_conv:
247
+ if self.name == "conv":
248
+ hidden_states = self.conv(hidden_states)
249
+ else:
250
+ hidden_states = self.Conv2d_0(hidden_states)
251
+
252
+ return hidden_states
253
+
254
+
255
+ class DownsampleCausal3D(nn.Module):
256
+ """
257
+ A 3D downsampling layer with an optional convolution.
258
+ """
259
+
260
+ def __init__(
261
+ self,
262
+ channels: int,
263
+ use_conv: bool = False,
264
+ out_channels: Optional[int] = None,
265
+ padding: int = 1,
266
+ name: str = "conv",
267
+ kernel_size=3,
268
+ norm_type=None,
269
+ eps=None,
270
+ elementwise_affine=None,
271
+ bias=True,
272
+ stride=2,
273
+ ):
274
+ super().__init__()
275
+ self.channels = channels
276
+ self.out_channels = out_channels or channels
277
+ self.use_conv = use_conv
278
+ self.padding = padding
279
+ stride = stride
280
+ self.name = name
281
+
282
+ if norm_type == "ln_norm":
283
+ self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
284
+ elif norm_type == "rms_norm":
285
+ self.norm = RMSNorm(channels, eps, elementwise_affine)
286
+ elif norm_type is None:
287
+ self.norm = None
288
+ else:
289
+ raise ValueError(f"unknown norm_type: {norm_type}")
290
+
291
+ if use_conv:
292
+ conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias)
293
+ else:
294
+ raise NotImplementedError
295
+
296
+ if name == "conv":
297
+ self.Conv2d_0 = conv
298
+ self.conv = conv
299
+ elif name == "Conv2d_0":
300
+ self.conv = conv
301
+ else:
302
+ self.conv = conv
303
+
304
+ def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
305
+ assert hidden_states.shape[1] == self.channels
306
+
307
+ if self.norm is not None:
308
+ hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
309
+
310
+ assert hidden_states.shape[1] == self.channels
311
+
312
+ hidden_states = self.conv(hidden_states)
313
+
314
+ return hidden_states
315
+
316
+
317
+ class ResnetBlockCausal3D(nn.Module):
318
+ r"""
319
+ A Resnet block.
320
+ """
321
+
322
+ def __init__(
323
+ self,
324
+ *,
325
+ in_channels: int,
326
+ out_channels: Optional[int] = None,
327
+ conv_shortcut: bool = False,
328
+ dropout: float = 0.0,
329
+ temb_channels: int = 512,
330
+ groups: int = 32,
331
+ groups_out: Optional[int] = None,
332
+ pre_norm: bool = True,
333
+ eps: float = 1e-6,
334
+ non_linearity: str = "swish",
335
+ skip_time_act: bool = False,
336
+ # default, scale_shift, ada_group, spatial
337
+ time_embedding_norm: str = "default",
338
+ kernel: Optional[torch.FloatTensor] = None,
339
+ output_scale_factor: float = 1.0,
340
+ use_in_shortcut: Optional[bool] = None,
341
+ up: bool = False,
342
+ down: bool = False,
343
+ conv_shortcut_bias: bool = True,
344
+ conv_3d_out_channels: Optional[int] = None,
345
+ ):
346
+ super().__init__()
347
+ self.pre_norm = pre_norm
348
+ self.pre_norm = True
349
+ self.in_channels = in_channels
350
+ out_channels = in_channels if out_channels is None else out_channels
351
+ self.out_channels = out_channels
352
+ self.use_conv_shortcut = conv_shortcut
353
+ self.up = up
354
+ self.down = down
355
+ self.output_scale_factor = output_scale_factor
356
+ self.time_embedding_norm = time_embedding_norm
357
+ self.skip_time_act = skip_time_act
358
+
359
+ linear_cls = nn.Linear
360
+
361
+ if groups_out is None:
362
+ groups_out = groups
363
+
364
+ if self.time_embedding_norm == "ada_group":
365
+ self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
366
+ elif self.time_embedding_norm == "spatial":
367
+ self.norm1 = SpatialNorm(in_channels, temb_channels)
368
+ else:
369
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
370
+
371
+ self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)
372
+
373
+ if temb_channels is not None:
374
+ if self.time_embedding_norm == "default":
375
+ self.time_emb_proj = linear_cls(temb_channels, out_channels)
376
+ elif self.time_embedding_norm == "scale_shift":
377
+ self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
378
+ elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
379
+ self.time_emb_proj = None
380
+ else:
381
+ raise ValueError(f"Unknown time_embedding_norm : {self.time_embedding_norm} ")
382
+ else:
383
+ self.time_emb_proj = None
384
+
385
+ if self.time_embedding_norm == "ada_group":
386
+ self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
387
+ elif self.time_embedding_norm == "spatial":
388
+ self.norm2 = SpatialNorm(out_channels, temb_channels)
389
+ else:
390
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
391
+
392
+ self.dropout = torch.nn.Dropout(dropout)
393
+ conv_3d_out_channels = conv_3d_out_channels or out_channels
394
+ self.conv2 = CausalConv3d(out_channels, conv_3d_out_channels, kernel_size=3, stride=1)
395
+
396
+ self.nonlinearity = get_activation(non_linearity)
397
+
398
+ self.upsample = self.downsample = None
399
+ if self.up:
400
+ self.upsample = UpsampleCausal3D(in_channels, use_conv=False)
401
+ elif self.down:
402
+ self.downsample = DownsampleCausal3D(in_channels, use_conv=False, name="op")
403
+
404
+ self.use_in_shortcut = self.in_channels != conv_3d_out_channels if use_in_shortcut is None else use_in_shortcut
405
+
406
+ self.conv_shortcut = None
407
+ if self.use_in_shortcut:
408
+ self.conv_shortcut = CausalConv3d(
409
+ in_channels,
410
+ conv_3d_out_channels,
411
+ kernel_size=1,
412
+ stride=1,
413
+ bias=conv_shortcut_bias,
414
+ )
415
+
416
+ def forward(
417
+ self,
418
+ input_tensor: torch.FloatTensor,
419
+ temb: torch.FloatTensor,
420
+ scale: float = 1.0,
421
+ ) -> torch.FloatTensor:
422
+ hidden_states = input_tensor
423
+
424
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
425
+ hidden_states = self.norm1(hidden_states, temb)
426
+ else:
427
+ hidden_states = self.norm1(hidden_states)
428
+
429
+ hidden_states = self.nonlinearity(hidden_states)
430
+
431
+ if self.upsample is not None:
432
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
433
+ if hidden_states.shape[0] >= 64:
434
+ input_tensor = input_tensor.contiguous()
435
+ hidden_states = hidden_states.contiguous()
436
+ input_tensor = self.upsample(input_tensor, scale=scale)
437
+ hidden_states = self.upsample(hidden_states, scale=scale)
438
+ elif self.downsample is not None:
439
+ input_tensor = self.downsample(input_tensor, scale=scale)
440
+ hidden_states = self.downsample(hidden_states, scale=scale)
441
+
442
+ hidden_states = self.conv1(hidden_states)
443
+
444
+ if self.time_emb_proj is not None:
445
+ if not self.skip_time_act:
446
+ temb = self.nonlinearity(temb)
447
+ temb = self.time_emb_proj(temb)[:, :, None, None, None]  # nn.Linear takes no scale argument; broadcast over (T, H, W)
448
+
449
+ if temb is not None and self.time_embedding_norm == "default":
450
+ hidden_states = hidden_states + temb
451
+
452
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
453
+ hidden_states = self.norm2(hidden_states, temb)
454
+ else:
455
+ hidden_states = self.norm2(hidden_states)
456
+
457
+ if temb is not None and self.time_embedding_norm == "scale_shift":
458
+ scale, shift = torch.chunk(temb, 2, dim=1)
459
+ hidden_states = hidden_states * (1 + scale) + shift
460
+
461
+ hidden_states = self.nonlinearity(hidden_states)
462
+
463
+ hidden_states = self.dropout(hidden_states)
464
+ hidden_states = self.conv2(hidden_states)
465
+
466
+ if self.conv_shortcut is not None:
467
+ input_tensor = self.conv_shortcut(input_tensor)
468
+
469
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
470
+
471
+ return output_tensor
472
+
473
+
474
+ def get_down_block3d(
475
+ down_block_type: str,
476
+ num_layers: int,
477
+ in_channels: int,
478
+ out_channels: int,
479
+ temb_channels: int,
480
+ add_downsample: bool,
481
+ downsample_stride: int,
482
+ resnet_eps: float,
483
+ resnet_act_fn: str,
484
+ transformer_layers_per_block: int = 1,
485
+ num_attention_heads: Optional[int] = None,
486
+ resnet_groups: Optional[int] = None,
487
+ cross_attention_dim: Optional[int] = None,
488
+ downsample_padding: Optional[int] = None,
489
+ dual_cross_attention: bool = False,
490
+ use_linear_projection: bool = False,
491
+ only_cross_attention: bool = False,
492
+ upcast_attention: bool = False,
493
+ resnet_time_scale_shift: str = "default",
494
+ attention_type: str = "default",
495
+ resnet_skip_time_act: bool = False,
496
+ resnet_out_scale_factor: float = 1.0,
497
+ cross_attention_norm: Optional[str] = None,
498
+ attention_head_dim: Optional[int] = None,
499
+ downsample_type: Optional[str] = None,
500
+ dropout: float = 0.0,
501
+ ):
502
+ # If attn head dim is not defined, we default it to the number of heads
503
+ if attention_head_dim is None:
504
+ logger.warning(
505
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block3d`. Defaulting `attention_head_dim` to {num_attention_heads}."
506
+ )
507
+ attention_head_dim = num_attention_heads
508
+
509
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
510
+ if down_block_type == "DownEncoderBlockCausal3D":
511
+ return DownEncoderBlockCausal3D(
512
+ num_layers=num_layers,
513
+ in_channels=in_channels,
514
+ out_channels=out_channels,
515
+ dropout=dropout,
516
+ add_downsample=add_downsample,
517
+ downsample_stride=downsample_stride,
518
+ resnet_eps=resnet_eps,
519
+ resnet_act_fn=resnet_act_fn,
520
+ resnet_groups=resnet_groups,
521
+ downsample_padding=downsample_padding,
522
+ resnet_time_scale_shift=resnet_time_scale_shift,
523
+ )
524
+ raise ValueError(f"{down_block_type} does not exist.")
525
+
526
+
527
+ def get_up_block3d(
528
+ up_block_type: str,
529
+ num_layers: int,
530
+ in_channels: int,
531
+ out_channels: int,
532
+ prev_output_channel: int,
533
+ temb_channels: int,
534
+ add_upsample: bool,
535
+ upsample_scale_factor: Tuple,
536
+ resnet_eps: float,
537
+ resnet_act_fn: str,
538
+ resolution_idx: Optional[int] = None,
539
+ transformer_layers_per_block: int = 1,
540
+ num_attention_heads: Optional[int] = None,
541
+ resnet_groups: Optional[int] = None,
542
+ cross_attention_dim: Optional[int] = None,
543
+ dual_cross_attention: bool = False,
544
+ use_linear_projection: bool = False,
545
+ only_cross_attention: bool = False,
546
+ upcast_attention: bool = False,
547
+ resnet_time_scale_shift: str = "default",
548
+ attention_type: str = "default",
549
+ resnet_skip_time_act: bool = False,
550
+ resnet_out_scale_factor: float = 1.0,
551
+ cross_attention_norm: Optional[str] = None,
552
+ attention_head_dim: Optional[int] = None,
553
+ upsample_type: Optional[str] = None,
554
+ dropout: float = 0.0,
555
+ ) -> nn.Module:
556
+ # If attn head dim is not defined, we default it to the number of heads
557
+ if attention_head_dim is None:
558
+ logger.warning(
559
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block3d`. Defaulting `attention_head_dim` to {num_attention_heads}."
560
+ )
561
+ attention_head_dim = num_attention_heads
562
+
563
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
564
+ if up_block_type == "UpDecoderBlockCausal3D":
565
+ return UpDecoderBlockCausal3D(
566
+ num_layers=num_layers,
567
+ in_channels=in_channels,
568
+ out_channels=out_channels,
569
+ resolution_idx=resolution_idx,
570
+ dropout=dropout,
571
+ add_upsample=add_upsample,
572
+ upsample_scale_factor=upsample_scale_factor,
573
+ resnet_eps=resnet_eps,
574
+ resnet_act_fn=resnet_act_fn,
575
+ resnet_groups=resnet_groups,
576
+ resnet_time_scale_shift=resnet_time_scale_shift,
577
+ temb_channels=temb_channels,
578
+ )
579
+ raise ValueError(f"{up_block_type} does not exist.")
580
+
581
+
582
+ class UNetMidBlockCausal3D(nn.Module):
583
+ """
584
+ A 3D UNet mid-block [`UNetMidBlockCausal3D`] with multiple residual blocks and optional attention blocks.
585
+ """
586
+
587
+ def __init__(
588
+ self,
589
+ in_channels: int,
590
+ temb_channels: int,
591
+ dropout: float = 0.0,
592
+ num_layers: int = 1,
593
+ resnet_eps: float = 1e-6,
594
+ resnet_time_scale_shift: str = "default", # default, spatial
595
+ resnet_act_fn: str = "swish",
596
+ resnet_groups: int = 32,
597
+ attn_groups: Optional[int] = None,
598
+ resnet_pre_norm: bool = True,
599
+ add_attention: bool = True,
600
+ attention_head_dim: int = 1,
601
+ output_scale_factor: float = 1.0,
602
+ ):
603
+ super().__init__()
604
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
605
+ self.add_attention = add_attention
606
+
607
+ if attn_groups is None:
608
+ attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
609
+
610
+ # there is always at least one resnet
611
+ resnets = [
612
+ ResnetBlockCausal3D(
613
+ in_channels=in_channels,
614
+ out_channels=in_channels,
615
+ temb_channels=temb_channels,
616
+ eps=resnet_eps,
617
+ groups=resnet_groups,
618
+ dropout=dropout,
619
+ time_embedding_norm=resnet_time_scale_shift,
620
+ non_linearity=resnet_act_fn,
621
+ output_scale_factor=output_scale_factor,
622
+ pre_norm=resnet_pre_norm,
623
+ )
624
+ ]
625
+ attentions = []
626
+
627
+ if attention_head_dim is None:
628
+ logger.warning(
629
+ f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
630
+ )
631
+ attention_head_dim = in_channels
632
+
633
+ for _ in range(num_layers):
634
+ if self.add_attention:
635
+ attentions.append(
636
+ Attention(
637
+ in_channels,
638
+ heads=in_channels // attention_head_dim,
639
+ dim_head=attention_head_dim,
640
+ rescale_output_factor=output_scale_factor,
641
+ eps=resnet_eps,
642
+ norm_num_groups=attn_groups,
643
+ spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
644
+ residual_connection=True,
645
+ bias=True,
646
+ upcast_softmax=True,
647
+ _from_deprecated_attn_block=True,
648
+ )
649
+ )
650
+ else:
651
+ attentions.append(None)
652
+
653
+ resnets.append(
654
+ ResnetBlockCausal3D(
655
+ in_channels=in_channels,
656
+ out_channels=in_channels,
657
+ temb_channels=temb_channels,
658
+ eps=resnet_eps,
659
+ groups=resnet_groups,
660
+ dropout=dropout,
661
+ time_embedding_norm=resnet_time_scale_shift,
662
+ non_linearity=resnet_act_fn,
663
+ output_scale_factor=output_scale_factor,
664
+ pre_norm=resnet_pre_norm,
665
+ )
666
+ )
667
+
668
+ self.attentions = nn.ModuleList(attentions)
669
+ self.resnets = nn.ModuleList(resnets)
670
+
671
+ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
672
+ hidden_states = self.resnets[0](hidden_states, temb)
673
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
674
+ if attn is not None:
675
+ B, C, T, H, W = hidden_states.shape
676
+ hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
677
+ attention_mask = prepare_causal_attention_mask(T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B)
678
+ hidden_states = attn(hidden_states, temb=temb, attention_mask=attention_mask)
679
+ hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
680
+ hidden_states = resnet(hidden_states, temb)
681
+
682
+ return hidden_states
683
+
684
+
685
+ class DownEncoderBlockCausal3D(nn.Module):
686
+ def __init__(
687
+ self,
688
+ in_channels: int,
689
+ out_channels: int,
690
+ dropout: float = 0.0,
691
+ num_layers: int = 1,
692
+ resnet_eps: float = 1e-6,
693
+ resnet_time_scale_shift: str = "default",
694
+ resnet_act_fn: str = "swish",
695
+ resnet_groups: int = 32,
696
+ resnet_pre_norm: bool = True,
697
+ output_scale_factor: float = 1.0,
698
+ add_downsample: bool = True,
699
+ downsample_stride: int = 2,
700
+ downsample_padding: int = 1,
701
+ ):
702
+ super().__init__()
703
+ resnets = []
704
+
705
+ for i in range(num_layers):
706
+ in_channels = in_channels if i == 0 else out_channels
707
+ resnets.append(
708
+ ResnetBlockCausal3D(
709
+ in_channels=in_channels,
710
+ out_channels=out_channels,
711
+ temb_channels=None,
712
+ eps=resnet_eps,
713
+ groups=resnet_groups,
714
+ dropout=dropout,
715
+ time_embedding_norm=resnet_time_scale_shift,
716
+ non_linearity=resnet_act_fn,
717
+ output_scale_factor=output_scale_factor,
718
+ pre_norm=resnet_pre_norm,
719
+ )
720
+ )
721
+
722
+ self.resnets = nn.ModuleList(resnets)
723
+
724
+ if add_downsample:
725
+ self.downsamplers = nn.ModuleList(
726
+ [
727
+ DownsampleCausal3D(
728
+ out_channels,
729
+ use_conv=True,
730
+ out_channels=out_channels,
731
+ padding=downsample_padding,
732
+ name="op",
733
+ stride=downsample_stride,
734
+ )
735
+ ]
736
+ )
737
+ else:
738
+ self.downsamplers = None
739
+
740
+ def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
741
+ for resnet in self.resnets:
742
+ hidden_states = resnet(hidden_states, temb=None, scale=scale)
743
+
744
+ if self.downsamplers is not None:
745
+ for downsampler in self.downsamplers:
746
+ hidden_states = downsampler(hidden_states, scale)
747
+
748
+ return hidden_states
749
+
750
+
751
+ class UpDecoderBlockCausal3D(nn.Module):
752
+ def __init__(
753
+ self,
754
+ in_channels: int,
755
+ out_channels: int,
756
+ resolution_idx: Optional[int] = None,
757
+ dropout: float = 0.0,
758
+ num_layers: int = 1,
759
+ resnet_eps: float = 1e-6,
760
+ resnet_time_scale_shift: str = "default", # default, spatial
761
+ resnet_act_fn: str = "swish",
762
+ resnet_groups: int = 32,
763
+ resnet_pre_norm: bool = True,
764
+ output_scale_factor: float = 1.0,
765
+ add_upsample: bool = True,
766
+ upsample_scale_factor=(2, 2, 2),
767
+ temb_channels: Optional[int] = None,
768
+ ):
769
+ super().__init__()
770
+ resnets = []
771
+
772
+ for i in range(num_layers):
773
+ input_channels = in_channels if i == 0 else out_channels
774
+
775
+ resnets.append(
776
+ ResnetBlockCausal3D(
777
+ in_channels=input_channels,
778
+ out_channels=out_channels,
779
+ temb_channels=temb_channels,
780
+ eps=resnet_eps,
781
+ groups=resnet_groups,
782
+ dropout=dropout,
783
+ time_embedding_norm=resnet_time_scale_shift,
784
+ non_linearity=resnet_act_fn,
785
+ output_scale_factor=output_scale_factor,
786
+ pre_norm=resnet_pre_norm,
787
+ )
788
+ )
789
+
790
+ self.resnets = nn.ModuleList(resnets)
791
+
792
+ if add_upsample:
793
+ self.upsamplers = nn.ModuleList(
794
+ [
795
+ UpsampleCausal3D(
796
+ out_channels,
797
+ use_conv=True,
798
+ out_channels=out_channels,
799
+ upsample_factor=upsample_scale_factor,
800
+ )
801
+ ]
802
+ )
803
+ else:
804
+ self.upsamplers = None
805
+
806
+ self.resolution_idx = resolution_idx
807
+
808
+ def forward(
809
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
810
+ ) -> torch.FloatTensor:
811
+ for resnet in self.resnets:
812
+ hidden_states = resnet(hidden_states, temb=temb, scale=scale)
813
+
814
+ if self.upsamplers is not None:
815
+ for upsampler in self.upsamplers:
816
+ hidden_states = upsampler(hidden_states)
817
+
818
+ return hidden_states
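
The causal 3D blocks above (part of modules/unet_causal_3d_blocks.py in this upload) operate on latents shaped (B, C, T, H, W) and treat the first frame separately so decoding stays causal in time. A minimal, hedged sketch of exercising the up/down sampling blocks in isolation follows; the constructor defaults not visible in this hunk, plus the channel count and latent shape, are assumptions for illustration only.

```python
# Minimal sketch (not part of the diff). Assumes modules/unet_causal_3d_blocks.py is
# importable from the repo root; the channel count and latent shape are arbitrary.
import torch
from modules.unet_causal_3d_blocks import DownsampleCausal3D, UpsampleCausal3D

latent = torch.randn(1, 32, 5, 16, 16)  # (B, C, T, H, W)

down = DownsampleCausal3D(32, use_conv=True, stride=2)
up = UpsampleCausal3D(32, use_conv=True, upsample_factor=(2, 2, 2))

reduced = down(latent)   # strided CausalConv3d shrinks the temporal/spatial dims
enlarged = up(latent)    # first frame is upsampled spatially only; later frames also in time
print(reduced.shape, enlarged.shape)
```
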
networks/__init__.py ADDED
File without changes
networks/lora.py ADDED
@@ -0,0 +1,914 @@
1
+ # LoRA network module: currently conv2d is not fully supported
2
+ # reference:
3
+ # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
4
+ # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
5
+
6
+ import ast
7
+ import math
8
+ import os
9
+ import re
10
+ from typing import Dict, List, Optional, Type, Union
11
+ from diffusers import AutoencoderKL
12
+ from transformers import CLIPTextModel
13
+ import numpy as np
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+ import logging
18
+
19
+ logger = logging.getLogger(__name__)
20
+ logging.basicConfig(level=logging.INFO)
21
+
22
+ HUNYUAN_TARGET_REPLACE_MODULES = ["MMDoubleStreamBlock", "MMSingleStreamBlock"]
23
+
24
+
25
+ class LoRAModule(torch.nn.Module):
26
+ """
27
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ lora_name,
33
+ org_module: torch.nn.Module,
34
+ multiplier=1.0,
35
+ lora_dim=4,
36
+ alpha=1,
37
+ dropout=None,
38
+ rank_dropout=None,
39
+ module_dropout=None,
40
+ split_dims: Optional[List[int]] = None,
41
+ ):
42
+ """
43
+ if alpha == 0 or None, alpha is rank (no scaling).
44
+
45
+ split_dims is used to mimic the split qkv of multi-head attention.
46
+ """
47
+ super().__init__()
48
+ self.lora_name = lora_name
49
+
50
+ if org_module.__class__.__name__ == "Conv2d":
51
+ in_dim = org_module.in_channels
52
+ out_dim = org_module.out_channels
53
+ else:
54
+ in_dim = org_module.in_features
55
+ out_dim = org_module.out_features
56
+
57
+ self.lora_dim = lora_dim
58
+ self.split_dims = split_dims
59
+
60
+ if split_dims is None:
61
+ if org_module.__class__.__name__ == "Conv2d":
62
+ kernel_size = org_module.kernel_size
63
+ stride = org_module.stride
64
+ padding = org_module.padding
65
+ self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
66
+ self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
67
+ else:
68
+ self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
69
+ self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
70
+
71
+ torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
72
+ torch.nn.init.zeros_(self.lora_up.weight)
73
+ else:
74
+ # conv2d not supported
75
+ assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim"
76
+ assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear"
77
+ # print(f"split_dims: {split_dims}")
78
+ self.lora_down = torch.nn.ModuleList(
79
+ [torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))]
80
+ )
81
+ self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims])
82
+ for lora_down in self.lora_down:
83
+ torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5))
84
+ for lora_up in self.lora_up:
85
+ torch.nn.init.zeros_(lora_up.weight)
86
+
87
+ if type(alpha) == torch.Tensor:
88
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
89
+ alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
90
+ self.scale = alpha / self.lora_dim
91
+ self.register_buffer("alpha", torch.tensor(alpha)) # for save/load
92
+
93
+ # same as microsoft's
94
+ self.multiplier = multiplier
95
+ self.org_module = org_module # remove in applying
96
+ self.dropout = dropout
97
+ self.rank_dropout = rank_dropout
98
+ self.module_dropout = module_dropout
99
+
100
+ def apply_to(self):
101
+ self.org_forward = self.org_module.forward
102
+ self.org_module.forward = self.forward
103
+ del self.org_module
104
+
105
+ def forward(self, x):
106
+ org_forwarded = self.org_forward(x)
107
+
108
+ # module dropout
109
+ if self.module_dropout is not None and self.training:
110
+ if torch.rand(1) < self.module_dropout:
111
+ return org_forwarded
112
+
113
+ if self.split_dims is None:
114
+ lx = self.lora_down(x)
115
+
116
+ # normal dropout
117
+ if self.dropout is not None and self.training:
118
+ lx = torch.nn.functional.dropout(lx, p=self.dropout)
119
+
120
+ # rank dropout
121
+ if self.rank_dropout is not None and self.training:
122
+ mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
123
+ if len(lx.size()) == 3:
124
+ mask = mask.unsqueeze(1) # for Text Encoder
125
+ elif len(lx.size()) == 4:
126
+ mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
127
+ lx = lx * mask
128
+
129
+ # scaling for rank dropout: treat as if the rank is changed
130
+ scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
131
+ else:
132
+ scale = self.scale
133
+
134
+ lx = self.lora_up(lx)
135
+
136
+ return org_forwarded + lx * self.multiplier * scale
137
+ else:
138
+ lxs = [lora_down(x) for lora_down in self.lora_down]
139
+
140
+ # normal dropout
141
+ if self.dropout is not None and self.training:
142
+ lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs]
143
+
144
+ # rank dropout
145
+ if self.rank_dropout is not None and self.training:
146
+ masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs]
147
+ for i in range(len(lxs)):
148
+ if len(lxs[i].size()) == 3:
149
+ masks[i] = masks[i].unsqueeze(1)
150
+ elif len(lxs[i].size()) == 4:
151
+ masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1)
152
+ lxs[i] = lxs[i] * masks[i]
153
+
154
+ # scaling for rank dropout: treat as if the rank is changed
155
+ scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
156
+ else:
157
+ scale = self.scale
158
+
159
+ lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)]
160
+
161
+ return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale
162
+
163
+
164
+ class LoRAInfModule(LoRAModule):
165
+ def __init__(
166
+ self,
167
+ lora_name,
168
+ org_module: torch.nn.Module,
169
+ multiplier=1.0,
170
+ lora_dim=4,
171
+ alpha=1,
172
+ **kwargs,
173
+ ):
174
+ # no dropout for inference
175
+ super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
176
+
177
+ self.org_module_ref = [org_module] # for reference
178
+ self.enabled = True
179
+ self.network: LoRANetwork = None
180
+
181
+ def set_network(self, network):
182
+ self.network = network
183
+
184
+ # merge weight to org_module
185
+ # def merge_to(self, sd, dtype, device, non_blocking=False):
186
+ # if torch.cuda.is_available():
187
+ # stream = torch.cuda.Stream(device=device)
188
+ # with torch.cuda.stream(stream):
189
+ # print(f"merge_to {self.lora_name}")
190
+ # self._merge_to(sd, dtype, device, non_blocking)
191
+ # torch.cuda.synchronize(device=device)
192
+ # print(f"merge_to {self.lora_name} done")
193
+ # torch.cuda.empty_cache()
194
+ # else:
195
+ # self._merge_to(sd, dtype, device, non_blocking)
196
+
197
+ def merge_to(self, sd, dtype, device, non_blocking=False):
198
+ # extract weight from org_module
199
+ org_sd = self.org_module.state_dict()
200
+ weight = org_sd["weight"]
201
+ org_dtype = weight.dtype
202
+ org_device = weight.device
203
+ weight = weight.to(device, dtype=torch.float, non_blocking=non_blocking) # for calculation
204
+
205
+ if dtype is None:
206
+ dtype = org_dtype
207
+ if device is None:
208
+ device = org_device
209
+
210
+ if self.split_dims is None:
211
+ # get up/down weight
212
+ down_weight = sd["lora_down.weight"].to(device, dtype=torch.float, non_blocking=non_blocking)
213
+ up_weight = sd["lora_up.weight"].to(device, dtype=torch.float, non_blocking=non_blocking)
214
+
215
+ # merge weight
216
+ if len(weight.size()) == 2:
217
+ # linear
218
+ weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
219
+ elif down_weight.size()[2:4] == (1, 1):
220
+ # conv2d 1x1
221
+ weight = (
222
+ weight
223
+ + self.multiplier
224
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
225
+ * self.scale
226
+ )
227
+ else:
228
+ # conv2d 3x3
229
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
230
+ # logger.info(conved.size(), weight.size(), module.stride, module.padding)
231
+ weight = weight + self.multiplier * conved * self.scale
232
+
233
+ # set weight to org_module
234
+ org_sd["weight"] = weight.to(org_device, dtype=dtype) # back to CPU without non_blocking
235
+ self.org_module.load_state_dict(org_sd)
236
+ else:
237
+ # split_dims
238
+ total_dims = sum(self.split_dims)
239
+ for i in range(len(self.split_dims)):
240
+ # get up/down weight
241
+ down_weight = sd[f"lora_down.{i}.weight"].to(device, torch.float, non_blocking=non_blocking) # (rank, in_dim)
242
+ up_weight = sd[f"lora_up.{i}.weight"].to(device, torch.float, non_blocking=non_blocking) # (split dim, rank)
243
+
244
+ # pad up_weight -> (total_dims, rank)
245
+ padded_up_weight = torch.zeros((total_dims, up_weight.size(0)), device=device, dtype=torch.float)
246
+ padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight
247
+
248
+ # merge weight
249
+ weight = weight + self.multiplier * (padded_up_weight @ down_weight) * self.scale  # use the padded up weight so shapes match the full out_dim
250
+
251
+ # set weight to org_module
252
+ org_sd["weight"] = weight.to(org_device, dtype) # back to CPU without non_blocking
253
+ self.org_module.load_state_dict(org_sd)
254
+
255
+ # return weight for merge
256
+ def get_weight(self, multiplier=None):
257
+ if multiplier is None:
258
+ multiplier = self.multiplier
259
+
260
+ # get up/down weight from module
261
+ up_weight = self.lora_up.weight.to(torch.float)
262
+ down_weight = self.lora_down.weight.to(torch.float)
263
+
264
+ # pre-calculated weight
265
+ if len(down_weight.size()) == 2:
266
+ # linear
267
+ weight = self.multiplier * (up_weight @ down_weight) * self.scale
268
+ elif down_weight.size()[2:4] == (1, 1):
269
+ # conv2d 1x1
270
+ weight = (
271
+ self.multiplier
272
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
273
+ * self.scale
274
+ )
275
+ else:
276
+ # conv2d 3x3
277
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
278
+ weight = self.multiplier * conved * self.scale
279
+
280
+ return weight
281
+
282
+ def default_forward(self, x):
283
+ # logger.info(f"default_forward {self.lora_name} {x.size()}")
284
+ if self.split_dims is None:
285
+ lx = self.lora_down(x)
286
+ lx = self.lora_up(lx)
287
+ return self.org_forward(x) + lx * self.multiplier * self.scale
288
+ else:
289
+ lxs = [lora_down(x) for lora_down in self.lora_down]
290
+ lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)]
291
+ return self.org_forward(x) + torch.cat(lxs, dim=-1) * self.multiplier * self.scale
292
+
293
+ def forward(self, x):
294
+ if not self.enabled:
295
+ return self.org_forward(x)
296
+ return self.default_forward(x)
297
+
298
+
299
+ def create_arch_network(
300
+ multiplier: float,
301
+ network_dim: Optional[int],
302
+ network_alpha: Optional[float],
303
+ vae: nn.Module,
304
+ text_encoders: List[nn.Module],
305
+ unet: nn.Module,
306
+ neuron_dropout: Optional[float] = None,
307
+ **kwargs,
308
+ ):
309
+ # add default exclude patterns
310
+ exclude_patterns = kwargs.get("exclude_patterns", None)
311
+ if exclude_patterns is None:
312
+ exclude_patterns = []
313
+ else:
314
+ exclude_patterns = ast.literal_eval(exclude_patterns)
315
+
316
+ # exclude if 'img_mod', 'txt_mod' or 'modulation' in the name
317
+ exclude_patterns.append(r".*(img_mod|txt_mod|modulation).*")
318
+
319
+ kwargs["exclude_patterns"] = exclude_patterns
320
+
321
+ return create_network(
322
+ HUNYUAN_TARGET_REPLACE_MODULES,
323
+ "lora_unet",
324
+ multiplier,
325
+ network_dim,
326
+ network_alpha,
327
+ vae,
328
+ text_encoders,
329
+ unet,
330
+ neuron_dropout=neuron_dropout,
331
+ **kwargs,
332
+ )
333
+
334
+
335
+ def create_network(
336
+ target_replace_modules: List[str],
337
+ prefix: str,
338
+ multiplier: float,
339
+ network_dim: Optional[int],
340
+ network_alpha: Optional[float],
341
+ vae: nn.Module,
342
+ text_encoders: List[nn.Module],
343
+ unet: nn.Module,
344
+ neuron_dropout: Optional[float] = None,
345
+ **kwargs,
346
+ ):
347
+ """ architecture independent network creation """
348
+ if network_dim is None:
349
+ network_dim = 4 # default
350
+ if network_alpha is None:
351
+ network_alpha = 1.0
352
+
353
+ # extract dim/alpha for conv2d, and block dim
354
+ conv_dim = kwargs.get("conv_dim", None)
355
+ conv_alpha = kwargs.get("conv_alpha", None)
356
+ if conv_dim is not None:
357
+ conv_dim = int(conv_dim)
358
+ if conv_alpha is None:
359
+ conv_alpha = 1.0
360
+ else:
361
+ conv_alpha = float(conv_alpha)
362
+
363
+ # TODO generic rank/dim setting with regular expression
364
+
365
+ # rank/module dropout
366
+ rank_dropout = kwargs.get("rank_dropout", None)
367
+ if rank_dropout is not None:
368
+ rank_dropout = float(rank_dropout)
369
+ module_dropout = kwargs.get("module_dropout", None)
370
+ if module_dropout is not None:
371
+ module_dropout = float(module_dropout)
372
+
373
+ # verbose
374
+ verbose = kwargs.get("verbose", False)
375
+ if verbose is not None:
376
+ verbose = True if verbose == "True" else False
377
+
378
+ # regular expression for module selection: exclude and include
379
+ exclude_patterns = kwargs.get("exclude_patterns", None)
380
+ if exclude_patterns is not None and isinstance(exclude_patterns, str):
381
+ exclude_patterns = ast.literal_eval(exclude_patterns)
382
+ include_patterns = kwargs.get("include_patterns", None)
383
+ if include_patterns is not None and isinstance(include_patterns, str):
384
+ include_patterns = ast.literal_eval(include_patterns)
385
+
386
+ # too many arguments ( ^ω^)・・・
387
+ network = LoRANetwork(
388
+ target_replace_modules,
389
+ prefix,
390
+ text_encoders,
391
+ unet,
392
+ multiplier=multiplier,
393
+ lora_dim=network_dim,
394
+ alpha=network_alpha,
395
+ dropout=neuron_dropout,
396
+ rank_dropout=rank_dropout,
397
+ module_dropout=module_dropout,
398
+ conv_lora_dim=conv_dim,
399
+ conv_alpha=conv_alpha,
400
+ exclude_patterns=exclude_patterns,
401
+ include_patterns=include_patterns,
402
+ verbose=verbose,
403
+ )
404
+
405
+ loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
406
+ # loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
407
+ # loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
408
+ loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
409
+ # loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
410
+ # loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
411
+ if loraplus_lr_ratio is not None: # or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
412
+ network.set_loraplus_lr_ratio(loraplus_lr_ratio) # , loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
413
+
414
+ return network
415
+
416
+
417
+ class LoRANetwork(torch.nn.Module):
418
+ # only supports U-Net (DiT), Text Encoders are not supported
419
+
420
+ def __init__(
421
+ self,
422
+ target_replace_modules: List[str],
423
+ prefix: str,
424
+ text_encoders: Union[List[CLIPTextModel], CLIPTextModel],
425
+ unet: nn.Module,
426
+ multiplier: float = 1.0,
427
+ lora_dim: int = 4,
428
+ alpha: float = 1,
429
+ dropout: Optional[float] = None,
430
+ rank_dropout: Optional[float] = None,
431
+ module_dropout: Optional[float] = None,
432
+ conv_lora_dim: Optional[int] = None,
433
+ conv_alpha: Optional[float] = None,
434
+ module_class: Type[object] = LoRAModule,
435
+ modules_dim: Optional[Dict[str, int]] = None,
436
+ modules_alpha: Optional[Dict[str, int]] = None,
437
+ exclude_patterns: Optional[List[str]] = None,
438
+ include_patterns: Optional[List[str]] = None,
439
+ verbose: Optional[bool] = False,
440
+ ) -> None:
441
+ super().__init__()
442
+ self.multiplier = multiplier
443
+
444
+ self.lora_dim = lora_dim
445
+ self.alpha = alpha
446
+ self.conv_lora_dim = conv_lora_dim
447
+ self.conv_alpha = conv_alpha
448
+ self.dropout = dropout
449
+ self.rank_dropout = rank_dropout
450
+ self.module_dropout = module_dropout
451
+ self.target_replace_modules = target_replace_modules
452
+ self.prefix = prefix
453
+
454
+ self.loraplus_lr_ratio = None
455
+ # self.loraplus_unet_lr_ratio = None
456
+ # self.loraplus_text_encoder_lr_ratio = None
457
+
458
+ if modules_dim is not None:
459
+ logger.info(f"create LoRA network from weights")
460
+ else:
461
+ logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
462
+ logger.info(
463
+ f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
464
+ )
465
+ # if self.conv_lora_dim is not None:
466
+ # logger.info(
467
+ # f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
468
+ # )
469
+ # if train_t5xxl:
470
+ # logger.info(f"train T5XXL as well")
471
+
472
+ # compile regular expression if specified
473
+ exclude_re_patterns = []
474
+ if exclude_patterns is not None:
475
+ for pattern in exclude_patterns:
476
+ try:
477
+ re_pattern = re.compile(pattern)
478
+ except re.error as e:
479
+ logger.error(f"Invalid exclude pattern '{pattern}': {e}")
480
+ continue
481
+ exclude_re_patterns.append(re_pattern)
482
+
483
+ include_re_patterns = []
484
+ if include_patterns is not None:
485
+ for pattern in include_patterns:
486
+ try:
487
+ re_pattern = re.compile(pattern)
488
+ except re.error as e:
489
+ logger.error(f"Invalid include pattern '{pattern}': {e}")
490
+ continue
491
+ include_re_patterns.append(re_pattern)
492
+
493
+ # create module instances
494
+ def create_modules(
495
+ is_unet: bool,
496
+ pfx: str,
497
+ root_module: torch.nn.Module,
498
+ target_replace_mods: Optional[List[str]] = None,
499
+ filter: Optional[str] = None,
500
+ default_dim: Optional[int] = None,
501
+ ) -> List[LoRAModule]:
502
+ loras = []
503
+ skipped = []
504
+ for name, module in root_module.named_modules():
505
+ if target_replace_mods is None or module.__class__.__name__ in target_replace_mods:
506
+ if target_replace_mods is None: # dirty hack for all modules
507
+ module = root_module # search all modules
508
+
509
+ for child_name, child_module in module.named_modules():
510
+ is_linear = child_module.__class__.__name__ == "Linear"
511
+ is_conv2d = child_module.__class__.__name__ == "Conv2d"
512
+ is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
513
+
514
+ if is_linear or is_conv2d:
515
+ original_name = (name + "." if name else "") + child_name
516
+ lora_name = f"{pfx}.{original_name}".replace(".", "_")
517
+
518
+ # exclude/include filter
519
+ excluded = False
520
+ for pattern in exclude_re_patterns:
521
+ if pattern.match(original_name):
522
+ excluded = True
523
+ break
524
+ included = False
525
+ for pattern in include_re_patterns:
526
+ if pattern.match(original_name):
527
+ included = True
528
+ break
529
+ if excluded and not included:
530
+ if verbose:
531
+ logger.info(f"exclude: {original_name}")
532
+ continue
533
+
534
+ # filter by name (not used in the current implementation)
535
+ if filter is not None and not filter in lora_name:
536
+ continue
537
+
538
+ dim = None
539
+ alpha = None
540
+
541
+ if modules_dim is not None:
542
+ # per-module dims/alphas were specified (e.g. when loading from weights)
543
+ if lora_name in modules_dim:
544
+ dim = modules_dim[lora_name]
545
+ alpha = modules_alpha[lora_name]
546
+ else:
547
+ # normal case: target all matching modules
548
+ if is_linear or is_conv2d_1x1:
549
+ dim = default_dim if default_dim is not None else self.lora_dim
550
+ alpha = self.alpha
551
+ elif self.conv_lora_dim is not None:
552
+ dim = self.conv_lora_dim
553
+ alpha = self.conv_alpha
554
+
555
+ if dim is None or dim == 0:
556
+ # record the skipped module for logging
557
+ if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None):
558
+ skipped.append(lora_name)
559
+ continue
560
+
561
+ lora = module_class(
562
+ lora_name,
563
+ child_module,
564
+ self.multiplier,
565
+ dim,
566
+ alpha,
567
+ dropout=dropout,
568
+ rank_dropout=rank_dropout,
569
+ module_dropout=module_dropout,
570
+ )
571
+ loras.append(lora)
572
+
573
+ if target_replace_mods is None:
574
+ break # all modules are searched
575
+ return loras, skipped
576
+
577
+ # # create LoRA for text encoder
578
+ # # it is redundant to create LoRA modules even if they are not used
579
+
580
+ self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = []
581
+ # skipped_te = []
582
+ # for i, text_encoder in enumerate(text_encoders):
583
+ # index = i
584
+ # if not train_t5xxl and index > 0: # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False
585
+ # break
586
+ # logger.info(f"create LoRA for Text Encoder {index+1}:")
587
+ # text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
588
+ # logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.")
589
+ # self.text_encoder_loras.extend(text_encoder_loras)
590
+ # skipped_te += skipped
591
+
592
+ # create LoRA for U-Net
593
+ self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
594
+ self.unet_loras, skipped_un = create_modules(True, prefix, unet, target_replace_modules)
595
+
596
+ logger.info(f"create LoRA for U-Net/DiT: {len(self.unet_loras)} modules.")
597
+ if verbose:
598
+ for lora in self.unet_loras:
599
+ logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}")
600
+
601
+ skipped = skipped_un
602
+ if verbose and len(skipped) > 0:
603
+ logger.warning(
604
+ f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
605
+ )
606
+ for name in skipped:
607
+ logger.info(f"\t{name}")
608
+
609
+ # assertion
610
+ names = set()
611
+ for lora in self.text_encoder_loras + self.unet_loras:
612
+ assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
613
+ names.add(lora.lora_name)
614
+
615
+ def prepare_network(self, args):
616
+ """
617
+ called after the network is created
618
+ """
619
+ pass
620
+
621
+ def set_multiplier(self, multiplier):
622
+ self.multiplier = multiplier
623
+ for lora in self.text_encoder_loras + self.unet_loras:
624
+ lora.multiplier = self.multiplier
625
+
626
+ def set_enabled(self, is_enabled):
627
+ for lora in self.text_encoder_loras + self.unet_loras:
628
+ lora.enabled = is_enabled
629
+
630
+ def load_weights(self, file):
631
+ if os.path.splitext(file)[1] == ".safetensors":
632
+ from safetensors.torch import load_file
633
+
634
+ weights_sd = load_file(file)
635
+ else:
636
+ weights_sd = torch.load(file, map_location="cpu")
637
+
638
+ info = self.load_state_dict(weights_sd, False)
639
+ return info
640
+
641
+ def apply_to(
642
+ self,
643
+ text_encoders: Optional[nn.Module],
644
+ unet: Optional[nn.Module],
645
+ apply_text_encoder: bool = True,
646
+ apply_unet: bool = True,
647
+ ):
648
+ if apply_text_encoder:
649
+ logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules")
650
+ else:
651
+ self.text_encoder_loras = []
652
+
653
+ if apply_unet:
654
+ logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules")
655
+ else:
656
+ self.unet_loras = []
657
+
658
+ for lora in self.text_encoder_loras + self.unet_loras:
659
+ lora.apply_to()
660
+ self.add_module(lora.lora_name, lora)
661
+
662
+ # returns whether the LoRA weights can be merged into the base model
663
+ def is_mergeable(self):
664
+ return True
665
+
666
+ # TODO refactor to common function with apply_to
667
+ def merge_to(self, text_encoders, unet, weights_sd, dtype=None, device=None, non_blocking=False):
668
+ from concurrent.futures import ThreadPoolExecutor
669
+
670
+ with ThreadPoolExecutor(max_workers=2) as executor: # 2 workers is enough
671
+ futures = []
672
+ for lora in self.text_encoder_loras + self.unet_loras:
673
+ sd_for_lora = {}
674
+ for key in weights_sd.keys():
675
+ if key.startswith(lora.lora_name):
676
+ sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
677
+ if len(sd_for_lora) == 0:
678
+ logger.info(f"no weight for {lora.lora_name}")
679
+ continue
680
+
681
+ # lora.merge_to(sd_for_lora, dtype, device)
682
+ futures.append(executor.submit(lora.merge_to, sd_for_lora, dtype, device, non_blocking))
683
+
684
+ for future in futures:
685
+ future.result()
686
+
687
+ logger.info(f"weights are merged")
688
+
689
+ def set_loraplus_lr_ratio(self, loraplus_lr_ratio): # , loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
690
+ self.loraplus_lr_ratio = loraplus_lr_ratio
691
+
692
+ logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_lr_ratio}")
693
+ # logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
694
+
695
+ def prepare_optimizer_params(self, unet_lr: float = 1e-4, **kwargs):
696
+ self.requires_grad_(True)
697
+
698
+ all_params = []
699
+ lr_descriptions = []
700
+
701
+ def assemble_params(loras, lr, loraplus_ratio):
702
+ param_groups = {"lora": {}, "plus": {}}
703
+ for lora in loras:
704
+ for name, param in lora.named_parameters():
705
+ if loraplus_ratio is not None and "lora_up" in name:
706
+ param_groups["plus"][f"{lora.lora_name}.{name}"] = param
707
+ else:
708
+ param_groups["lora"][f"{lora.lora_name}.{name}"] = param
709
+
710
+ params = []
711
+ descriptions = []
712
+ for key in param_groups.keys():
713
+ param_data = {"params": param_groups[key].values()}
714
+
715
+ if len(param_data["params"]) == 0:
716
+ continue
717
+
718
+ if lr is not None:
719
+ if key == "plus":
720
+ param_data["lr"] = lr * loraplus_ratio
721
+ else:
722
+ param_data["lr"] = lr
723
+
724
+ if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
725
+ logger.info("NO LR skipping!")
726
+ continue
727
+
728
+ params.append(param_data)
729
+ descriptions.append("plus" if key == "plus" else "")
730
+
731
+ return params, descriptions
732
+
733
+ if self.unet_loras:
734
+ params, descriptions = assemble_params(self.unet_loras, unet_lr, self.loraplus_lr_ratio)
735
+ all_params.extend(params)
736
+ lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])
737
+
738
+ return all_params, lr_descriptions
739
+
740
+ def enable_gradient_checkpointing(self):
741
+ # not supported
742
+ pass
743
+
744
+ def prepare_grad_etc(self, unet):
745
+ self.requires_grad_(True)
746
+
747
+ def on_epoch_start(self, unet):
748
+ self.train()
749
+
750
+ def on_step_start(self):
751
+ pass
752
+
753
+ def get_trainable_params(self):
754
+ return self.parameters()
755
+
756
+ def save_weights(self, file, dtype, metadata):
757
+ if metadata is not None and len(metadata) == 0:
758
+ metadata = None
759
+
760
+ state_dict = self.state_dict()
761
+
762
+ if dtype is not None:
763
+ for key in list(state_dict.keys()):
764
+ v = state_dict[key]
765
+ v = v.detach().clone().to("cpu").to(dtype)
766
+ state_dict[key] = v
767
+
768
+ if os.path.splitext(file)[1] == ".safetensors":
769
+ from safetensors.torch import save_file
770
+ from utils import model_utils
771
+
772
+ # Precalculate model hashes to save time on indexing
773
+ if metadata is None:
774
+ metadata = {}
775
+ model_hash, legacy_hash = model_utils.precalculate_safetensors_hashes(state_dict, metadata)
776
+ metadata["sshs_model_hash"] = model_hash
777
+ metadata["sshs_legacy_hash"] = legacy_hash
778
+
779
+ save_file(state_dict, file, metadata)
780
+ else:
781
+ torch.save(state_dict, file)
782
+
783
+ def backup_weights(self):
784
+ # back up the original module weights
785
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
786
+ for lora in loras:
787
+ org_module = lora.org_module_ref[0]
788
+ if not hasattr(org_module, "_lora_org_weight"):
789
+ sd = org_module.state_dict()
790
+ org_module._lora_org_weight = sd["weight"].detach().clone()
791
+ org_module._lora_restored = True
792
+
793
+ def restore_weights(self):
794
+ # restore the original module weights from the backup
795
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
796
+ for lora in loras:
797
+ org_module = lora.org_module_ref[0]
798
+ if not org_module._lora_restored:
799
+ sd = org_module.state_dict()
800
+ sd["weight"] = org_module._lora_org_weight
801
+ org_module.load_state_dict(sd)
802
+ org_module._lora_restored = True
803
+
804
+ def pre_calculation(self):
805
+ # pre-compute the merged weights (the LoRA modules are then disabled)
806
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
807
+ for lora in loras:
808
+ org_module = lora.org_module_ref[0]
809
+ sd = org_module.state_dict()
810
+
811
+ org_weight = sd["weight"]
812
+ lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
813
+ sd["weight"] = org_weight + lora_weight
814
+ assert sd["weight"].shape == org_weight.shape
815
+ org_module.load_state_dict(sd)
816
+
817
+ org_module._lora_restored = False
818
+ lora.enabled = False
819
+
820
+ def apply_max_norm_regularization(self, max_norm_value, device):
821
+ downkeys = []
822
+ upkeys = []
823
+ alphakeys = []
824
+ norms = []
825
+ keys_scaled = 0
826
+
827
+ state_dict = self.state_dict()
828
+ for key in state_dict.keys():
829
+ if "lora_down" in key and "weight" in key:
830
+ downkeys.append(key)
831
+ upkeys.append(key.replace("lora_down", "lora_up"))
832
+ alphakeys.append(key.replace("lora_down.weight", "alpha"))
833
+
834
+ for i in range(len(downkeys)):
835
+ down = state_dict[downkeys[i]].to(device)
836
+ up = state_dict[upkeys[i]].to(device)
837
+ alpha = state_dict[alphakeys[i]].to(device)
838
+ dim = down.shape[0]
839
+ scale = alpha / dim
840
+
841
+ if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
842
+ updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
843
+ elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
844
+ updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
845
+ else:
846
+ updown = up @ down
847
+
848
+ updown *= scale
849
+
850
+ norm = updown.norm().clamp(min=max_norm_value / 2)
851
+ desired = torch.clamp(norm, max=max_norm_value)
852
+ ratio = desired.cpu() / norm.cpu()
853
+ sqrt_ratio = ratio**0.5
854
+ if ratio != 1:
855
+ keys_scaled += 1
856
+ state_dict[upkeys[i]] *= sqrt_ratio
857
+ state_dict[downkeys[i]] *= sqrt_ratio
858
+ scalednorm = updown.norm() * ratio
859
+ norms.append(scalednorm.item())
860
+
861
+ return keys_scaled, sum(norms) / len(norms), max(norms)
862
+
863
+
864
+ def create_arch_network_from_weights(
865
+ multiplier: float,
866
+ weights_sd: Dict[str, torch.Tensor],
867
+ text_encoders: Optional[List[nn.Module]] = None,
868
+ unet: Optional[nn.Module] = None,
869
+ for_inference: bool = False,
870
+ **kwargs,
871
+ ) -> LoRANetwork:
872
+ return create_network_from_weights(
873
+ HUNYUAN_TARGET_REPLACE_MODULES, multiplier, weights_sd, text_encoders, unet, for_inference, **kwargs
874
+ )
875
+
876
+
877
+ # Create network from weights for inference, weights are not loaded here (because can be merged)
878
+ def create_network_from_weights(
879
+ target_replace_modules: List[str],
880
+ multiplier: float,
881
+ weights_sd: Dict[str, torch.Tensor],
882
+ text_encoders: Optional[List[nn.Module]] = None,
883
+ unet: Optional[nn.Module] = None,
884
+ for_inference: bool = False,
885
+ **kwargs,
886
+ ) -> LoRANetwork:
887
+ # get dim/alpha mapping
888
+ modules_dim = {}
889
+ modules_alpha = {}
890
+ for key, value in weights_sd.items():
891
+ if "." not in key:
892
+ continue
893
+
894
+ lora_name = key.split(".")[0]
895
+ if "alpha" in key:
896
+ modules_alpha[lora_name] = value
897
+ elif "lora_down" in key:
898
+ dim = value.shape[0]
899
+ modules_dim[lora_name] = dim
900
+ # logger.info(lora_name, value.size(), dim)
901
+
902
+ module_class = LoRAInfModule if for_inference else LoRAModule
903
+
904
+ network = LoRANetwork(
905
+ target_replace_modules,
906
+ "lora_unet",
907
+ text_encoders,
908
+ unet,
909
+ multiplier=multiplier,
910
+ modules_dim=modules_dim,
911
+ modules_alpha=modules_alpha,
912
+ module_class=module_class,
913
+ )
914
+ return network
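
networks/lora.py wraps every Linear (and 1x1 Conv2d) inside the HunyuanVideo MMDoubleStreamBlock / MMSingleStreamBlock targets with a LoRAModule that adds lora_up(lora_down(x)) * multiplier * alpha / rank to the original output. A hedged, self-contained sketch of that flow on a dummy stand-in block is shown below; only the class name has to match the target list, and the rank, alpha and layer sizes are made-up values.

```python
# Hedged sketch (not part of the diff): attaching networks/lora.py to a dummy model.
# MMDoubleStreamBlock here is a stand-in whose class name matches HUNYUAN_TARGET_REPLACE_MODULES.
import torch
import torch.nn as nn
from networks import lora

class MMDoubleStreamBlock(nn.Module):  # dummy block, layer names/sizes are invented
    def __init__(self):
        super().__init__()
        self.img_attn_qkv = nn.Linear(64, 192)
        self.img_attn_proj = nn.Linear(64, 64)

class DummyDiT(nn.Module):
    def __init__(self):
        super().__init__()
        self.double_blocks = nn.ModuleList([MMDoubleStreamBlock() for _ in range(2)])

dit = DummyDiT()
network = lora.create_arch_network(1.0, 32, 16.0, None, None, dit)   # rank 32, alpha 16 (assumed)
network.apply_to(None, dit, apply_text_encoder=False, apply_unet=True)

params, _descriptions = network.prepare_optimizer_params(unet_lr=1e-4)
optimizer = torch.optim.AdamW(params)

out = dit.double_blocks[0].img_attn_qkv(torch.randn(2, 64))  # forward now includes the LoRA branch
print(len(network.unet_loras), "LoRA modules attached;", out.shape)
```
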
networks/lora_wan.py ADDED
@@ -0,0 +1,65 @@
1
+ # LoRA module for Wan2.1
2
+
3
+ import ast
4
+ from typing import Dict, List, Optional
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ import logging
9
+
10
+ logger = logging.getLogger(__name__)
11
+ logging.basicConfig(level=logging.INFO)
12
+
13
+ import networks.lora as lora
14
+
15
+
16
+ WAN_TARGET_REPLACE_MODULES = ["WanAttentionBlock"]
17
+
18
+
19
+ def create_arch_network(
20
+ multiplier: float,
21
+ network_dim: Optional[int],
22
+ network_alpha: Optional[float],
23
+ vae: nn.Module,
24
+ text_encoders: List[nn.Module],
25
+ unet: nn.Module,
26
+ neuron_dropout: Optional[float] = None,
27
+ **kwargs,
28
+ ):
29
+ # add default exclude patterns
30
+ exclude_patterns = kwargs.get("exclude_patterns", None)
31
+ if exclude_patterns is None:
32
+ exclude_patterns = []
33
+ else:
34
+ exclude_patterns = ast.literal_eval(exclude_patterns)
35
+
36
+ # exclude patch/text/time embeddings, time projection, norms and the head from LoRA targets
37
+ exclude_patterns.append(r".*(patch_embedding|text_embedding|time_embedding|time_projection|norm|head).*")
38
+
39
+ kwargs["exclude_patterns"] = exclude_patterns
40
+
41
+ return lora.create_network(
42
+ WAN_TARGET_REPLACE_MODULES,
43
+ "lora_unet",
44
+ multiplier,
45
+ network_dim,
46
+ network_alpha,
47
+ vae,
48
+ text_encoders,
49
+ unet,
50
+ neuron_dropout=neuron_dropout,
51
+ **kwargs,
52
+ )
53
+
54
+
55
+ def create_arch_network_from_weights(
56
+ multiplier: float,
57
+ weights_sd: Dict[str, torch.Tensor],
58
+ text_encoders: Optional[List[nn.Module]] = None,
59
+ unet: Optional[nn.Module] = None,
60
+ for_inference: bool = False,
61
+ **kwargs,
62
+ ) -> lora.LoRANetwork:
63
+ return lora.create_network_from_weights(
64
+ WAN_TARGET_REPLACE_MODULES, multiplier, weights_sd, text_encoders, unet, for_inference, **kwargs
65
+ )
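
networks/lora_wan.py only swaps in the Wan2.1 target list (WanAttentionBlock) and its own default exclude patterns, then defers to networks/lora.py for everything else. A hedged sketch of rebuilding a saved LoRA for inference and merging it into an already-loaded Wan2.1 DiT follows; it assumes the pixel_outputs checkpoints were produced by the Wan2.1 training path, `wan_model` is a placeholder for a model loaded elsewhere (see docs/wan.md in this upload), and the checkpoint choice, dtype and device are assumptions.

```python
# Hedged sketch (not part of the diff). `wan_model` is a placeholder for a loaded
# Wan2.1 DiT containing WanAttentionBlock modules; dtype/device are assumptions.
import torch
from safetensors.torch import load_file
import networks.lora_wan as lora_wan

weights_sd = load_file("pixel_outputs/pixel_w1_3_lora-000001.safetensors")

network = lora_wan.create_arch_network_from_weights(
    multiplier=1.0,
    weights_sd=weights_sd,
    unet=wan_model,  # placeholder, not defined in this sketch
    for_inference=True,
)
network.merge_to(None, wan_model, weights_sd, dtype=torch.bfloat16, device="cuda")
```
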
pixel_outputs/pixel_w1_3_lora-000001.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0861037ca1ce8517c187e11e392195df38423d670a73f1cf8622d9213fb2ad5c
3
+ size 87594616
pixel_outputs/pixel_w1_3_lora-000002.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2da2daa5dea835024a559f06eea97b38a3b809b6724915cd6715fa702c74b059
3
+ size 87594616
pixel_outputs/pixel_w1_3_lora-000003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:47f65b1ed7b77d27cf59274d3bf6f6285ada28ee0b453ac98efdcb86a429af7a
3
+ size 87594616
pixel_outputs/pixel_w1_3_lora-000004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7c09f83bb67f7dd5566ae6ac5fd11f32877f964f5f416c786a5e85382c96df73
3
+ size 87594616
pixel_outputs/pixel_w1_3_lora-000005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6edb55a1a4115fefddfc63957c0c140b21e5aef5a4b0b19924305b59af796824
3
+ size 87594616
pixel_outputs/pixel_w1_3_lora-000006.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:325d67e06add9069b3aa7f5761a778f233566f79a0aa3f51f1ea4a535f4d7211
3
+ size 87594616
pixel_outputs/pixel_w1_3_lora-000007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f01e0a354111250caaf0bc609db677552a8e7fc363472ca680ece252ffbb086
3
+ size 87594616
pixel_outputs/pixel_w1_3_lora-000008.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:75cb5269d989e6467384147285186be202ef5a77192430cb95a74425ad574d3b
3
+ size 87594616
pixel_outputs/pixel_w1_3_lora-000009.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eea26f66e5bde80ad506b6012883e89f5249acf066160a8529b2c2bbfd9271c3
3
+ size 87594616
pixel_outputs/pixel_w1_3_lora-000010.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8386003f664721e14168cf6c870052474a26fdccfd73176244770075305e0ee
3
+ size 87594624
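
The ten pixel_outputs/*.safetensors entries above are Git LFS pointers (about 87.6 MB each once fetched) to numbered LoRA checkpoints, most likely successive saves from training. After `git lfs pull`, their contents should follow the key layout written by LoRANetwork.save_weights earlier in this diff: `lora_unet_*` module names with `lora_down.weight` / `lora_up.weight` / `alpha` tensors, plus `sshs_model_hash` / `sshs_legacy_hash` metadata. A quick inspection sketch (the specific file choice is arbitrary):

```python
# Inspection sketch (assumes `git lfs pull` has replaced the pointer with real weights).
from safetensors import safe_open

path = "pixel_outputs/pixel_w1_3_lora-000010.safetensors"
with safe_open(path, framework="pt", device="cpu") as f:
    metadata = f.metadata() or {}
    down_keys = [k for k in f.keys() if k.endswith("lora_down.weight")]
    print(f"{len(down_keys)} LoRA modules; metadata keys: {sorted(metadata.keys())}")
    if down_keys:
        example = down_keys[0]
        print(example, "rank =", f.get_tensor(example).shape[0])
```
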