rahul7star committed
Commit 357c94c · verified · 1 Parent(s): 1a1e7bf

Upload 99 files

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +15 -0
  2. assets/audio/2.WAV +3 -0
  3. assets/audio/3.WAV +3 -0
  4. assets/audio/4.WAV +3 -0
  5. assets/image/1.png +3 -0
  6. assets/image/2.png +3 -0
  7. assets/image/3.png +3 -0
  8. assets/image/4.png +3 -0
  9. assets/image/src1.png +3 -0
  10. assets/image/src2.png +3 -0
  11. assets/image/src3.png +3 -0
  12. assets/image/src4.png +3 -0
  13. assets/material/demo.png +3 -0
  14. assets/material/logo.png +3 -0
  15. assets/material/method.png +3 -0
  16. assets/material/teaser.png +3 -0
  17. assets/test.csv +25 -0
  18. hymm_gradio/flask_audio.py +268 -0
  19. hymm_gradio/gradio_audio.py +122 -0
  20. hymm_gradio/tool_for_end2end.py +325 -0
  21. hymm_sp/__init__.py +0 -0
  22. hymm_sp/__pycache__/__init__.cpython-310.pyc +0 -0
  23. hymm_sp/__pycache__/config.cpython-310.pyc +0 -0
  24. hymm_sp/__pycache__/constants.cpython-310.pyc +0 -0
  25. hymm_sp/__pycache__/helpers.cpython-310.pyc +0 -0
  26. hymm_sp/__pycache__/inference.cpython-310.pyc +0 -0
  27. hymm_sp/__pycache__/sample_inference_audio.cpython-310.pyc +0 -0
  28. hymm_sp/config.py +142 -0
  29. hymm_sp/constants.py +59 -0
  30. hymm_sp/data_kits/__pycache__/audio_dataset.cpython-310.pyc +0 -0
  31. hymm_sp/data_kits/__pycache__/audio_preprocessor.cpython-310.pyc +0 -0
  32. hymm_sp/data_kits/__pycache__/data_tools.cpython-310.pyc +0 -0
  33. hymm_sp/data_kits/__pycache__/ffmpeg_utils.cpython-310.pyc +0 -0
  34. hymm_sp/data_kits/audio_dataset.py +170 -0
  35. hymm_sp/data_kits/audio_preprocessor.py +72 -0
  36. hymm_sp/data_kits/data_tools.py +41 -0
  37. hymm_sp/data_kits/face_align/__init__.py +1 -0
  38. hymm_sp/data_kits/face_align/__pycache__/__init__.cpython-310.pyc +0 -0
  39. hymm_sp/data_kits/face_align/__pycache__/align.cpython-310.pyc +0 -0
  40. hymm_sp/data_kits/face_align/__pycache__/detface.cpython-310.pyc +0 -0
  41. hymm_sp/data_kits/face_align/align.py +34 -0
  42. hymm_sp/data_kits/face_align/detface.py +283 -0
  43. hymm_sp/data_kits/ffmpeg_utils.py +184 -0
  44. hymm_sp/diffusion/__init__.py +30 -0
  45. hymm_sp/diffusion/__pycache__/__init__.cpython-310.pyc +0 -0
  46. hymm_sp/diffusion/pipelines/__init__.py +1 -0
  47. hymm_sp/diffusion/pipelines/__pycache__/__init__.cpython-310.pyc +0 -0
  48. hymm_sp/diffusion/pipelines/__pycache__/pipeline_hunyuan_video_audio.cpython-310.pyc +0 -0
  49. hymm_sp/diffusion/pipelines/pipeline_hunyuan_video_audio.py +1363 -0
  50. hymm_sp/diffusion/schedulers/__init__.py +1 -0
.gitattributes CHANGED
@@ -33,3 +33,18 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/audio/2.WAV filter=lfs diff=lfs merge=lfs -text
+ assets/audio/3.WAV filter=lfs diff=lfs merge=lfs -text
+ assets/audio/4.WAV filter=lfs diff=lfs merge=lfs -text
+ assets/image/1.png filter=lfs diff=lfs merge=lfs -text
+ assets/image/2.png filter=lfs diff=lfs merge=lfs -text
+ assets/image/3.png filter=lfs diff=lfs merge=lfs -text
+ assets/image/4.png filter=lfs diff=lfs merge=lfs -text
+ assets/image/src1.png filter=lfs diff=lfs merge=lfs -text
+ assets/image/src2.png filter=lfs diff=lfs merge=lfs -text
+ assets/image/src3.png filter=lfs diff=lfs merge=lfs -text
+ assets/image/src4.png filter=lfs diff=lfs merge=lfs -text
+ assets/material/demo.png filter=lfs diff=lfs merge=lfs -text
+ assets/material/logo.png filter=lfs diff=lfs merge=lfs -text
+ assets/material/method.png filter=lfs diff=lfs merge=lfs -text
+ assets/material/teaser.png filter=lfs diff=lfs merge=lfs -text
assets/audio/2.WAV ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:388d768354bb20f3aaa8327ef3391737c8150d3351fdaa04aa98b57caddc5dfb
+ size 3862572
assets/audio/3.WAV ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:854f2df163f4fc7bd69f09e8ca31758dedf8a56b191083d5fd4a5b25259b5fe2
+ size 1921068
assets/audio/4.WAV ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:81d992d96a96829f27aef2d99fa63614b692cc9ed2cbac61e23e9a1dcc402b3a
+ size 3809324
assets/image/1.png ADDED

Git LFS Details

  • SHA256: d9bb2f06d84af5ca9347d684554f9bd8edb118749457b37bff6026a1b429843d
  • Pointer size: 132 Bytes
  • Size of remote file: 7.36 MB
assets/image/2.png ADDED

Git LFS Details

  • SHA256: 97fdb135022e2d82366b5891b69d27b14fb183941c52cc6903acb5b6e74e89eb
  • Pointer size: 132 Bytes
  • Size of remote file: 5.93 MB
assets/image/3.png ADDED

Git LFS Details

  • SHA256: 6616f4dde7435a504f11af1b8919124ae48a592858fdc1ab0ccacaa265f2271d
  • Pointer size: 132 Bytes
  • Size of remote file: 7.34 MB
assets/image/4.png ADDED

Git LFS Details

  • SHA256: 931921def1c1c9336d1330554c1fa39b6090431dd720ccaa279e684b6253a800
  • Pointer size: 132 Bytes
  • Size of remote file: 5.28 MB
assets/image/src1.png ADDED

Git LFS Details

  • SHA256: 2c1ff64938aaa0950e33132e43cdc1e3e3a1cedc1b6ecdb33a067cb1971c4189
  • Pointer size: 133 Bytes
  • Size of remote file: 10.2 MB
assets/image/src2.png ADDED

Git LFS Details

  • SHA256: a4b4bb9edc4d01cec463d85d97cd1c9319e34d6f5d10e8d11d6c33a89f418b66
  • Pointer size: 132 Bytes
  • Size of remote file: 9.55 MB
assets/image/src3.png ADDED

Git LFS Details

  • SHA256: b51bed13f3fe69f5b98b95faa1e570a3600999bcc2479ef6080ac80f53255855
  • Pointer size: 132 Bytes
  • Size of remote file: 9.99 MB
assets/image/src4.png ADDED

Git LFS Details

  • SHA256: 25b8bbf44d24bd25cb37495816cf1aa9cfcc9e32935dad074823bfdc590d0420
  • Pointer size: 132 Bytes
  • Size of remote file: 9.6 MB
assets/material/demo.png ADDED

Git LFS Details

  • SHA256: 32a0b751c12f40babc96f208cdfa3932aca74044c7689770b849e461e0b6aa1a
  • Pointer size: 132 Bytes
  • Size of remote file: 7.68 MB
assets/material/logo.png ADDED

Git LFS Details

  • SHA256: 2318f6ce82b18f91586050672e45cc637db61a33885de672b1b0d26e31b5de94
  • Pointer size: 131 Bytes
  • Size of remote file: 162 kB
assets/material/method.png ADDED

Git LFS Details

  • SHA256: 40ebe93eec368255ffaf58ad88f7aa5cdb605fd538f9272a3257803f9e86c338
  • Pointer size: 132 Bytes
  • Size of remote file: 4.94 MB
assets/material/teaser.png ADDED

Git LFS Details

  • SHA256: 8bbd88f126ad3a80ed613d90bb93e7a6bda346245df4aecd75b884ae00607ac4
  • Pointer size: 132 Bytes
  • Size of remote file: 5.43 MB
assets/test.csv ADDED
@@ -0,0 +1,25 @@
+ videoid,image,audio,prompt,fps
+ 8,assets/image/1.png,assets/audio/2.WAV,A person sits cross-legged by a campfire in a forested area.,25
+ 9,assets/image/2.png,assets/audio/2.WAV,"A person with long blonde hair wearing a green jacket, standing in a forested area during twilight.",25
+ 10,assets/image/3.png,assets/audio/2.WAV,A person playing guitar by a campfire in a forest.,25
+ 11,assets/image/4.png,assets/audio/2.WAV,"A person wearing a green jacket stands in a forested area, with sunlight filtering through the trees.",25
+ 12,assets/image/src1.png,assets/audio/2.WAV,A person sits cross-legged by a campfire in a forest at dusk.,25
+ 13,assets/image/src2.png,assets/audio/2.WAV,A person in a green jacket stands in a forest at dusk.,25
+ 14,assets/image/src3.png,assets/audio/2.WAV,A person playing guitar by a campfire in a forest.,25
+ 15,assets/image/src4.png,assets/audio/2.WAV,"A person in a green jacket stands in a forest, backlit by sunlight.",25
+ 16,assets/image/1.png,assets/audio/3.WAV,A person sits cross-legged by a campfire in a forested area.,25
+ 17,assets/image/2.png,assets/audio/3.WAV,"A person with long blonde hair wearing a green jacket, standing in a forested area during twilight.",25
+ 18,assets/image/3.png,assets/audio/3.WAV,A person playing guitar by a campfire in a forest.,25
+ 19,assets/image/4.png,assets/audio/3.WAV,"A person wearing a green jacket stands in a forested area, with sunlight filtering through the trees.",25
+ 20,assets/image/src1.png,assets/audio/3.WAV,A person sits cross-legged by a campfire in a forest at dusk.,25
+ 21,assets/image/src2.png,assets/audio/3.WAV,A person in a green jacket stands in a forest at dusk.,25
+ 22,assets/image/src3.png,assets/audio/3.WAV,A person playing guitar by a campfire in a forest.,25
+ 23,assets/image/src4.png,assets/audio/3.WAV,"A person in a green jacket stands in a forest, backlit by sunlight.",25
+ 24,assets/image/1.png,assets/audio/4.WAV,A person sits cross-legged by a campfire in a forested area.,25
+ 25,assets/image/2.png,assets/audio/4.WAV,"A person with long blonde hair wearing a green jacket, standing in a forested area during twilight.",25
+ 26,assets/image/3.png,assets/audio/4.WAV,A person playing guitar by a campfire in a forest.,25
+ 27,assets/image/4.png,assets/audio/4.WAV,"A person wearing a green jacket stands in a forested area, with sunlight filtering through the trees.",25
+ 28,assets/image/src1.png,assets/audio/4.WAV,A person sits cross-legged by a campfire in a forest at dusk.,25
+ 29,assets/image/src2.png,assets/audio/4.WAV,A person in a green jacket stands in a forest at dusk.,25
+ 30,assets/image/src3.png,assets/audio/4.WAV,A person playing guitar by a campfire in a forest.,25
+ 31,assets/image/src4.png,assets/audio/4.WAV,"A person in a green jacket stands in a forest, backlit by sunlight.",25
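The manifest above is what the batch-inference path consumes: each row pairs a reference image with a driving audio clip, a text prompt, and a target fps. A minimal sketch of reading it the same way hymm_sp/data_kits/audio_dataset.py does (the pandas indexing mirrors that file; the snippet itself is illustrative and not part of the upload):

```python
import pandas as pd

# Load the evaluation manifest; column names match the header row of assets/test.csv.
csv_data = pd.read_csv("assets/test.csv")
samples = [
    {
        "videoid": str(csv_data["videoid"][idx]),
        "image_path": str(csv_data["image"][idx]),
        "audio_path": str(csv_data["audio"][idx]),
        "prompt": str(csv_data["prompt"][idx]),
        "fps": float(csv_data["fps"][idx]),
    }
    for idx in range(len(csv_data))
]
print(samples[0]["image_path"], samples[0]["audio_path"])  # assets/image/1.png assets/audio/2.WAV
```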
hymm_gradio/flask_audio.py ADDED
@@ -0,0 +1,268 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import warnings
5
+ import threading
6
+ import traceback
7
+ import uvicorn
8
+ from fastapi import FastAPI, Body
9
+ from pathlib import Path
10
+ from datetime import datetime
11
+ import torch.distributed as dist
12
+ from hymm_gradio.tool_for_end2end import *
13
+ from hymm_sp.config import parse_args
14
+ from hymm_sp.sample_inference_audio import HunyuanVideoSampler
15
+
16
+ from hymm_sp.modules.parallel_states import (
17
+ initialize_distributed,
18
+ nccl_info,
19
+ )
20
+
21
+ from transformers import WhisperModel
22
+ from transformers import AutoFeatureExtractor
23
+ from hymm_sp.data_kits.face_align import AlignImage
24
+
25
+
26
+ warnings.filterwarnings("ignore")
27
+ MODEL_OUTPUT_PATH = os.environ.get('MODEL_BASE')
28
+ app = FastAPI()
29
+ rlock = threading.RLock()
30
+
31
+
32
+
33
+ @app.api_route('/predict2', methods=['GET', 'POST'])
34
+ def predict(data=Body(...)):
35
+ is_acquire = False
36
+ error_info = ""
37
+ try:
38
+ is_acquire = rlock.acquire(blocking=False)
39
+ if is_acquire:
40
+ res = predict_wrap(data)
41
+ return res
42
+ except Exception as e:
43
+ error_info = traceback.format_exc()
44
+ print(error_info)
45
+ finally:
46
+ if is_acquire:
47
+ rlock.release()
48
+ return {"errCode": -1, "info": "broken"}
49
+
50
+ def predict_wrap(input_dict={}):
51
+ if nccl_info.sp_size > 1:
52
+ device = torch.device(f"cuda:{torch.distributed.get_rank()}")
53
+ rank = local_rank = torch.distributed.get_rank()
54
+ print(f"sp_size={nccl_info.sp_size}, rank {rank} local_rank {local_rank}")
55
+ try:
56
+ print(f"----- rank = {rank}")
57
+ if rank == 0:
58
+ input_dict = process_input_dict(input_dict)
59
+
60
+ print('------- start to predict -------')
61
+ # Parse input arguments
62
+ image_path = input_dict["image_path"]
63
+ driving_audio_path = input_dict["audio_path"]
64
+
65
+ prompt = input_dict["prompt"]
66
+
67
+ save_fps = input_dict.get("save_fps", 25)
68
+
69
+
70
+ ret_dict = None
71
+ if image_path is None or driving_audio_path is None:
72
+ ret_dict = {
73
+ "errCode": -3,
74
+ "content": [
75
+ {
76
+ "buffer": None
77
+ },
78
+ ],
79
+ "info": "input content is not valid",
80
+ }
81
+
82
+ print(f"errCode: -3, input content is not valid!")
83
+ return ret_dict
84
+
85
+ # Preprocess input batch
86
+ torch.cuda.synchronize()
87
+
88
+ a = datetime.now()
89
+
90
+ try:
91
+ model_kwargs_tmp = data_preprocess_server(
92
+ args, image_path, driving_audio_path, prompt, feature_extractor
93
+ )
94
+ except:
95
+ ret_dict = {
96
+ "errCode": -2,
97
+ "content": [
98
+ {
99
+ "buffer": None
100
+ },
101
+ ],
102
+ "info": "failed to preprocess input data"
103
+ }
104
+ print(f"errCode: -2, preprocess failed!")
105
+ return ret_dict
106
+
107
+ text_prompt = model_kwargs_tmp["text_prompt"]
108
+ audio_path = model_kwargs_tmp["audio_path"]
109
+ image_path = model_kwargs_tmp["image_path"]
110
+ fps = model_kwargs_tmp["fps"]
111
+ audio_prompts = model_kwargs_tmp["audio_prompts"]
112
+ audio_len = model_kwargs_tmp["audio_len"]
113
+ motion_bucket_id_exps = model_kwargs_tmp["motion_bucket_id_exps"]
114
+ motion_bucket_id_heads = model_kwargs_tmp["motion_bucket_id_heads"]
115
+ pixel_value_ref = model_kwargs_tmp["pixel_value_ref"]
116
+ pixel_value_ref_llava = model_kwargs_tmp["pixel_value_ref_llava"]
117
+
118
+
119
+
120
+ torch.cuda.synchronize()
121
+ b = datetime.now()
122
+ preprocess_time = (b - a).total_seconds()
123
+ print("="*100)
124
+ print("preprocess time :", preprocess_time)
125
+ print("="*100)
126
+
127
+ else:
128
+ text_prompt = None
129
+ audio_path = None
130
+ image_path = None
131
+ fps = None
132
+ audio_prompts = None
133
+ audio_len = None
134
+ motion_bucket_id_exps = None
135
+ motion_bucket_id_heads = None
136
+ pixel_value_ref = None
137
+ pixel_value_ref_llava = None
138
+
139
+ except:
140
+ traceback.print_exc()
141
+ if rank == 0:
142
+ ret_dict = {
143
+ "errCode": -1, # Failed to generate video
144
+ "content":[
145
+ {
146
+ "buffer": None
147
+ }
148
+ ],
149
+ "info": "failed to preprocess",
150
+ }
151
+ return ret_dict
152
+
153
+ try:
154
+ broadcast_params = [
155
+ text_prompt,
156
+ audio_path,
157
+ image_path,
158
+ fps,
159
+ audio_prompts,
160
+ audio_len,
161
+ motion_bucket_id_exps,
162
+ motion_bucket_id_heads,
163
+ pixel_value_ref,
164
+ pixel_value_ref_llava,
165
+ ]
166
+ dist.broadcast_object_list(broadcast_params, src=0)
167
+ outputs = generate_image_parallel(*broadcast_params)
168
+
169
+ if rank == 0:
170
+ samples = outputs["samples"]
171
+ sample = samples[0].unsqueeze(0)
172
+
173
+ sample = sample[:, :, :audio_len[0]]
174
+
175
+ video = sample[0].permute(1, 2, 3, 0).clamp(0, 1).numpy()
176
+ video = (video * 255.).astype(np.uint8)
177
+
178
+ output_dict = {
179
+ "err_code": 0,
180
+ "err_msg": "succeed",
181
+ "video": video,
182
+ "audio": input_dict.get("audio_path", None),
183
+ "save_fps": save_fps,
184
+ }
185
+
186
+ ret_dict = process_output_dict(output_dict)
187
+ return ret_dict
188
+
189
+ except:
190
+ traceback.print_exc()
191
+ if rank == 0:
192
+ ret_dict = {
193
+ "errCode": -1, # Failed to generate video
194
+ "content":[
195
+ {
196
+ "buffer": None
197
+ }
198
+ ],
199
+ "info": "failed to generate video",
200
+ }
201
+ return ret_dict
202
+
203
+ return None
204
+
205
+ def generate_image_parallel(text_prompt,
206
+ audio_path,
207
+ image_path,
208
+ fps,
209
+ audio_prompts,
210
+ audio_len,
211
+ motion_bucket_id_exps,
212
+ motion_bucket_id_heads,
213
+ pixel_value_ref,
214
+ pixel_value_ref_llava
215
+ ):
216
+ if nccl_info.sp_size > 1:
217
+ device = torch.device(f"cuda:{torch.distributed.get_rank()}")
218
+
219
+ batch = {
220
+ "text_prompt": text_prompt,
221
+ "audio_path": audio_path,
222
+ "image_path": image_path,
223
+ "fps": fps,
224
+ "audio_prompts": audio_prompts,
225
+ "audio_len": audio_len,
226
+ "motion_bucket_id_exps": motion_bucket_id_exps,
227
+ "motion_bucket_id_heads": motion_bucket_id_heads,
228
+ "pixel_value_ref": pixel_value_ref,
229
+ "pixel_value_ref_llava": pixel_value_ref_llava
230
+ }
231
+
232
+ samples = hunyuan_sampler.predict(args, batch, wav2vec, feature_extractor, align_instance)
233
+ return samples
234
+
235
+ def worker_loop():
236
+ while True:
237
+ predict_wrap()
238
+
239
+
240
+ if __name__ == "__main__":
241
+ audio_args = parse_args()
242
+ initialize_distributed(audio_args.seed)
243
+ hunyuan_sampler = HunyuanVideoSampler.from_pretrained(
244
+ audio_args.ckpt, args=audio_args)
245
+ args = hunyuan_sampler.args
246
+
247
+ rank = local_rank = 0
248
+ device = torch.device("cuda")
249
+ if nccl_info.sp_size > 1:
250
+ device = torch.device(f"cuda:{torch.distributed.get_rank()}")
251
+ rank = local_rank = torch.distributed.get_rank()
252
+
253
+ feature_extractor = AutoFeatureExtractor.from_pretrained(f"{MODEL_OUTPUT_PATH}/ckpts/whisper-tiny/")
254
+ wav2vec = WhisperModel.from_pretrained(f"{MODEL_OUTPUT_PATH}/ckpts/whisper-tiny/").to(device=device, dtype=torch.float32)
255
+ wav2vec.requires_grad_(False)
256
+
257
+
258
+ BASE_DIR = f'{MODEL_OUTPUT_PATH}/ckpts/det_align/'
259
+ det_path = os.path.join(BASE_DIR, 'detface.pt')
260
+ align_instance = AlignImage("cuda", det_path=det_path)
261
+
262
+
263
+
264
+ if rank == 0:
265
+ uvicorn.run(app, host="0.0.0.0", port=80)
266
+ else:
267
+ worker_loop()
268
+
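The server above exposes a single /predict2 route behind a non-blocking lock; rank 0 serves HTTP while the remaining ranks sit in worker_loop waiting for broadcast work. A client sketch that mirrors the payload hymm_gradio/gradio_audio.py sends (asset paths and the output filename are placeholders):

```python
import base64
import json
import requests

def encode_file(path):
    # Read a local file and base64-encode it, matching the buffers the server expects.
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

payload = {
    "image_buffer": encode_file("assets/image/1.png"),   # reference image
    "audio_buffer": encode_file("assets/audio/2.WAV"),    # driving audio
    "text": "A person sits cross-legged by a campfire in a forested area.",
    "save_fps": 25,
}
# gradio_audio.py sends the dict as a JSON string in the request body; the route accepts GET or POST.
resp = requests.post("http://127.0.0.1:80/predict2", data=json.dumps(payload))
ret = resp.json()
print(ret["errCode"], ret["info"])
if ret["errCode"] == 0:
    # On success the generated (audio-muxed) video comes back base64-encoded.
    with open("generated.mp4", "wb") as f:
        f.write(base64.b64decode(ret["content"][0]["buffer"]))
```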
hymm_gradio/gradio_audio.py ADDED
@@ -0,0 +1,122 @@
1
+ import os
2
+ import cv2
3
+ import glob
4
+ import json
5
+ import datetime
6
+ import requests
7
+ import gradio as gr
8
+ from tool_for_end2end import *
9
+
10
+ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
11
+ DATADIR = './temp'
12
+ _HEADER_ = '''
13
+ <div style="text-align: center; max-width: 650px; margin: 0 auto;">
14
+ <h1 style="font-size: 2.5rem; font-weight: 700; margin-bottom: 1rem; display: contents;">Tencent HunyuanVideo-Avatar Demo</h1>
15
+ </div>
16
+
17
+ '''
18
+ # flask url
19
+ URL = "http://127.0.0.1:80/predict2"
20
+
21
+ def post_and_get(audio_input, id_image, prompt):
22
+ now = datetime.datetime.now().isoformat()
23
+ imgdir = os.path.join(DATADIR, 'reference')
24
+ videodir = os.path.join(DATADIR, 'video')
25
+ imgfile = os.path.join(imgdir, now + '.png')
26
+ output_video_path = os.path.join(videodir, now + '.mp4')
27
+
28
+
29
+ os.makedirs(imgdir, exist_ok=True)
30
+ os.makedirs(videodir, exist_ok=True)
31
+ cv2.imwrite(imgfile, id_image[:,:,::-1])
32
+
33
+ proxies = {
34
+ "http": None,
35
+ "https": None,
36
+ }
37
+
38
+ files = {
39
+ "image_buffer": encode_image_to_base64(imgfile),
40
+ "audio_buffer": encode_wav_to_base64(audio_input),
41
+ "text": prompt,
42
+ "save_fps": 25,
43
+ }
44
+ r = requests.get(URL, data = json.dumps(files), proxies=proxies)
45
+ ret_dict = json.loads(r.text)
46
+ print(ret_dict["info"])
47
+ save_video_base64_to_local(
48
+ video_path=None,
49
+ base64_buffer=ret_dict["content"][0]["buffer"],
50
+ output_video_path=output_video_path)
51
+
52
+
53
+ return output_video_path
54
+
55
+ def create_demo():
56
+
57
+ with gr.Blocks() as demo:
58
+ gr.Markdown(_HEADER_)
59
+ with gr.Tab('Audio-Driven Digital Human'):
60
+ with gr.Row():
61
+ with gr.Column(scale=1):
62
+ with gr.Group():
63
+ prompt = gr.Textbox(label="Prompt", value="a man is speaking.")
64
+
65
+ audio_input = gr.Audio(sources=["upload"],
66
+ type="filepath",
67
+ label="Upload Audio",
68
+ elem_classes="media-upload",
69
+ scale=1)
70
+ id_image = gr.Image(label="Input reference image", height=480)
71
+
72
+ with gr.Column(scale=2):
73
+ with gr.Group():
74
+ output_image = gr.Video(label="Generated Video")
75
+
76
+
77
+ with gr.Column(scale=1):
78
+ generate_btn = gr.Button("Generate")
79
+
80
+ generate_btn.click(fn=post_and_get,
81
+ inputs=[audio_input, id_image, prompt],
82
+ outputs=[output_image],
83
+ )
84
+
85
+ # quick_prompts = [[x] for x in glob.glob('./assets/images/*.png')]
86
+ # example_quick_prompts = gr.Dataset(samples=quick_prompts, label='Other object', samples_per_page=1000, components=[id_image])
87
+ # example_quick_prompts.click(lambda x: x[0], inputs=example_quick_prompts, outputs=id_image, show_progress=False, queue=False)
88
+ # with gr.Row(), gr.Column():
89
+ # gr.Markdown("## Examples")
90
+ # example_inps = [
91
+ # [
92
+ # 'A woman is drinking coffee at a café.',
93
+ # './assets/images/seg_woman_01.png',
94
+ # 1280, 720, 30, 129, 7.5, 13, 1024,
95
+ # "assets/videos/seg_woman_01.mp4"
96
+ # ],
97
+ # [
98
+ # 'In a cubicle of an office building, a woman focuses intently on the computer screen, typing rapidly on the keyboard, surrounded by piles of documents.',
99
+ # './assets/images/seg_woman_03.png',
100
+ # 1280, 720, 30, 129, 7.5, 13, 1025,
101
+ # "./assets/videos/seg_woman_03.mp4"
102
+ # ],
103
+ # [
104
+ # 'A man walks across an ancient stone bridge holding an umbrella, raindrops tapping against it.',
105
+ # './assets/images/seg_man_01.png',
106
+ # 1280, 720, 30, 129, 7.5, 13, 1025,
107
+ # "./assets/videos/seg_man_01.mp4"
108
+ # ],
109
+ # [
110
+ # 'During a train journey, a man admires the changing scenery through the window.',
111
+ # './assets/images/seg_man_02.png',
112
+ # 1280, 720, 30, 129, 7.5, 13, 1026,
113
+ # "./assets/videos/seg_man_02.mp4"
114
+ # ]
115
+ # ]
116
+ # gr.Examples(examples=example_inps, inputs=[prompt, id_image, width, height, num_steps, num_frames, guidance, flow_shift, seed, output_image],)
117
+ return demo
118
+
119
+ if __name__ == "__main__":
120
+ allowed_paths = ['/']
121
+ demo = create_demo()
122
+ demo.launch(server_name='0.0.0.0', server_port=8080, share=True, allowed_paths=allowed_paths)
hymm_gradio/tool_for_end2end.py ADDED
@@ -0,0 +1,325 @@
1
+ import os
2
+ import io
3
+ import math
4
+ import uuid
5
+ import base64
6
+ import imageio
7
+ import torch
8
+ import torchvision
9
+ from PIL import Image
10
+ import numpy as np
11
+ from copy import deepcopy
12
+ from einops import rearrange
13
+ import torchvision.transforms as transforms
14
+ from torchvision.transforms import ToPILImage
15
+ from hymm_sp.data_kits.audio_dataset import get_audio_feature
16
+ from hymm_sp.data_kits.ffmpeg_utils import save_video
17
+
18
+ TEMP_DIR = "./temp"
19
+ if not os.path.exists(TEMP_DIR):
20
+ os.makedirs(TEMP_DIR, exist_ok=True)
21
+
22
+
23
+ def data_preprocess_server(args, image_path, audio_path, prompts, feature_extractor):
24
+ llava_transform = transforms.Compose(
25
+ [
26
+ transforms.Resize((336, 336), interpolation=transforms.InterpolationMode.BILINEAR),
27
+ transforms.ToTensor(),
28
+ transforms.Normalize((0.48145466, 0.4578275, 0.4082107), (0.26862954, 0.26130258, 0.27577711)),
29
+ ]
30
+ )
31
+
32
+ """ Build the prompt """
33
+ if prompts is None:
34
+ prompts = "Authentic, Realistic, Natural, High-quality, Lens-Fixed."
35
+ else:
36
+ prompts = "Authentic, Realistic, Natural, High-quality, Lens-Fixed, " + prompts
37
+
38
+ fps = 25
39
+
40
+ img_size = args.image_size
41
+ ref_image = Image.open(image_path).convert('RGB')
42
+
43
+ # Resize reference image
44
+ w, h = ref_image.size
45
+ scale = img_size / min(w, h)
46
+ new_w = round(w * scale / 64) * 64
47
+ new_h = round(h * scale / 64) * 64
48
+
49
+ if img_size == 704:
50
+ img_size_long = 1216
51
+ if new_w * new_h > img_size * img_size_long:
52
+ scale = math.sqrt(img_size * img_size_long / w / h)
53
+ new_w = round(w * scale / 64) * 64
54
+ new_h = round(h * scale / 64) * 64
55
+
56
+ ref_image = ref_image.resize((new_w, new_h), Image.LANCZOS)
57
+
58
+ ref_image = np.array(ref_image)
59
+ ref_image = torch.from_numpy(ref_image)
60
+
61
+ audio_input, audio_len = get_audio_feature(feature_extractor, audio_path)
62
+ audio_prompts = audio_input[0]
63
+
64
+ motion_bucket_id_heads = np.array([25] * 4)
65
+ motion_bucket_id_exps = np.array([30] * 4)
66
+ motion_bucket_id_heads = torch.from_numpy(motion_bucket_id_heads)
67
+ motion_bucket_id_exps = torch.from_numpy(motion_bucket_id_exps)
68
+ fps = torch.from_numpy(np.array(fps))
69
+
70
+ to_pil = ToPILImage()
71
+ pixel_value_ref = rearrange(ref_image.clone().unsqueeze(0), "b h w c -> b c h w") # (b c h w)
72
+
73
+ pixel_value_ref_llava = [llava_transform(to_pil(image)) for image in pixel_value_ref]
74
+ pixel_value_ref_llava = torch.stack(pixel_value_ref_llava, dim=0)
75
+
76
+ batch = {
77
+ "text_prompt": [prompts],
78
+ "audio_path": [audio_path],
79
+ "image_path": [image_path],
80
+ "fps": fps.unsqueeze(0).to(dtype=torch.float16),
81
+ "audio_prompts": audio_prompts.unsqueeze(0).to(dtype=torch.float16),
82
+ "audio_len": [audio_len],
83
+ "motion_bucket_id_exps": motion_bucket_id_exps.unsqueeze(0),
84
+ "motion_bucket_id_heads": motion_bucket_id_heads.unsqueeze(0),
85
+ "pixel_value_ref": pixel_value_ref.unsqueeze(0).to(dtype=torch.float16),
86
+ "pixel_value_ref_llava": pixel_value_ref_llava.unsqueeze(0).to(dtype=torch.float16)
87
+ }
88
+
89
+ return batch
90
+
91
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8, quality=8):
92
+ videos = rearrange(videos, "b c t h w -> t b c h w")
93
+ outputs = []
94
+ for x in videos:
95
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
96
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
97
+ if rescale:
98
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
99
+ x = torch.clamp(x,0,1)
100
+ x = (x * 255).numpy().astype(np.uint8)
101
+ outputs.append(x)
102
+
103
+ os.makedirs(os.path.dirname(path), exist_ok=True)
104
+ imageio.mimsave(path, outputs, fps=fps, quality=quality)
105
+
106
+ def encode_image_to_base64(image_path):
107
+ try:
108
+ with open(image_path, 'rb') as image_file:
109
+ image_data = image_file.read()
110
+ encoded_data = base64.b64encode(image_data).decode('utf-8')
111
+ print(f"Image file '{image_path}' has been successfully encoded to Base64.")
112
+ return encoded_data
113
+
114
+ except Exception as e:
115
+ print(f"Error encoding image: {e}")
116
+ return None
117
+
118
+ def encode_video_to_base64(video_path):
119
+ try:
120
+ with open(video_path, 'rb') as video_file:
121
+ video_data = video_file.read()
122
+ encoded_data = base64.b64encode(video_data).decode('utf-8')
123
+ print(f"Video file '{video_path}' has been successfully encoded to Base64.")
124
+ return encoded_data
125
+
126
+ except Exception as e:
127
+ print(f"Error encoding video: {e}")
128
+ return None
129
+
130
+ def encode_wav_to_base64(wav_path):
131
+ try:
132
+ with open(wav_path, 'rb') as audio_file:
133
+ audio_data = audio_file.read()
134
+ encoded_data = base64.b64encode(audio_data).decode('utf-8')
135
+ print(f"Audio file '{wav_path}' has been successfully encoded to Base64.")
136
+ return encoded_data
137
+
138
+ except Exception as e:
139
+ print(f"Error encoding audio: {e}")
140
+ return None
141
+
142
+ def encode_pkl_to_base64(pkl_path):
143
+ try:
144
+ with open(pkl_path, 'rb') as pkl_file:
145
+ pkl_data = pkl_file.read()
146
+
147
+ encoded_data = base64.b64encode(pkl_data).decode('utf-8')
148
+
149
+ print(f"Pickle file '{pkl_path}' has been successfully encoded to Base64.")
150
+ return encoded_data
151
+
152
+ except Exception as e:
153
+ print(f"Error encoding pickle: {e}")
154
+ return None
155
+
156
+ def decode_base64_to_image(base64_buffer_str):
157
+ try:
158
+ image_data = base64.b64decode(base64_buffer_str)
159
+ image = Image.open(io.BytesIO(image_data))
160
+ image_array = np.array(image)
161
+ print(f"Image Base64 string has been successfully decoded to an image.")
162
+ return image_array
163
+ except Exception as e:
164
+ print(f"Error decoding image: {e}")
165
+ return None
166
+
167
+ def decode_base64_to_video(base64_buffer_str):
168
+ try:
169
+ video_data = base64.b64decode(base64_buffer_str)
170
+ video_bytes = io.BytesIO(video_data)
171
+ video_bytes.seek(0)
172
+ video_reader = imageio.get_reader(video_bytes, 'ffmpeg')
173
+ video_frames = [frame for frame in video_reader]
174
+ return video_frames
175
+ except Exception as e:
176
+ print(f"Error decoding video: {e}")
177
+ return None
178
+
179
+
180
+ def save_video_base64_to_local(video_path=None, base64_buffer=None, output_video_path=None):
181
+ if video_path is not None and base64_buffer is None:
182
+ video_buffer_base64 = encode_video_to_base64(video_path)
183
+ elif video_path is None and base64_buffer is not None:
184
+ video_buffer_base64 = deepcopy(base64_buffer)
185
+ else:
186
+ print("Please pass either 'video_path' or 'base64_buffer'")
187
+ return None
188
+
189
+ if video_buffer_base64 is not None:
190
+ video_data = base64.b64decode(video_buffer_base64)
191
+ if output_video_path is None:
192
+ uuid_string = str(uuid.uuid4())
193
+ temp_video_path = f'{TEMP_DIR}/{uuid_string}.mp4'
194
+ else:
195
+ temp_video_path = output_video_path
196
+ with open(temp_video_path, 'wb') as video_file:
197
+ video_file.write(video_data)
198
+ return temp_video_path
199
+ else:
200
+ return None
201
+
202
+ def save_audio_base64_to_local(audio_path=None, base64_buffer=None):
203
+ if audio_path is not None and base64_buffer is None:
204
+ audio_buffer_base64 = encode_wav_to_base64(audio_path)
205
+ elif audio_path is None and base64_buffer is not None:
206
+ audio_buffer_base64 = deepcopy(base64_buffer)
207
+ else:
208
+ print("Please pass either 'audio_path' or 'base64_buffer'")
209
+ return None
210
+
211
+ if audio_buffer_base64 is not None:
212
+ audio_data = base64.b64decode(audio_buffer_base64)
213
+ uuid_string = str(uuid.uuid4())
214
+ temp_audio_path = f'{TEMP_DIR}/{uuid_string}.wav'
215
+ with open(temp_audio_path, 'wb') as audio_file:
216
+ audio_file.write(audio_data)
217
+ return temp_audio_path
218
+ else:
219
+ return None
220
+
221
+ def save_pkl_base64_to_local(pkl_path=None, base64_buffer=None):
222
+ if pkl_path is not None and base64_buffer is None:
223
+ pkl_buffer_base64 = encode_pkl_to_base64(pkl_path)
224
+ elif pkl_path is None and base64_buffer is not None:
225
+ pkl_buffer_base64 = deepcopy(base64_buffer)
226
+ else:
227
+ print("Please pass either 'pkl_path' or 'base64_buffer'")
228
+ return None
229
+
230
+ if pkl_buffer_base64 is not None:
231
+ pkl_data = base64.b64decode(pkl_buffer_base64)
232
+ uuid_string = str(uuid.uuid4())
233
+ temp_pkl_path = f'{TEMP_DIR}/{uuid_string}.pkl'
234
+ with open(temp_pkl_path, 'wb') as pkl_file:
235
+ pkl_file.write(pkl_data)
236
+ return temp_pkl_path
237
+ else:
238
+ return None
239
+
240
+ def remove_temp_fles(input_dict):
241
+ for key, val in input_dict.items():
242
+ if "_path" in key and val is not None and os.path.exists(val):
243
+ os.remove(val)
244
+ print(f"Remove temporary {key} from {val}")
245
+
246
+ def process_output_dict(output_dict):
247
+
248
+ uuid_string = str(uuid.uuid4())
249
+ temp_video_path = f'{TEMP_DIR}/{uuid_string}.mp4'
250
+ save_video(output_dict["video"], temp_video_path, fps=output_dict.get("save_fps", 25))
251
+
252
+ # Add audio
253
+ if output_dict["audio"] is not None and os.path.exists(output_dict["audio"]):
254
+ output_path = temp_video_path
255
+ audio_path = output_dict["audio"]
256
+ save_path = temp_video_path.replace(".mp4", "_audio.mp4")
257
+ print('='*100)
258
+ print(f"output_path = {output_path}\n audio_path = {audio_path}\n save_path = {save_path}")
259
+ os.system(f"ffmpeg -i '{output_path}' -i '{audio_path}' -shortest '{save_path}' -y -loglevel quiet; rm '{output_path}'")
260
+ else:
261
+ save_path = temp_video_path
262
+
263
+ video_base64_buffer = encode_video_to_base64(save_path)
264
+
265
+ encoded_output_dict = {
266
+ "errCode": output_dict["err_code"],
267
+ "content": [
268
+ {
269
+ "buffer": video_base64_buffer
270
+ },
271
+ ],
272
+ "info":output_dict["err_msg"],
273
+ }
274
+
275
+
276
+
277
+ return encoded_output_dict
278
+
279
+
280
+ def save_image_base64_to_local(image_path=None, base64_buffer=None):
281
+ # Encode image to base64 buffer
282
+ if image_path is not None and base64_buffer is None:
283
+ image_buffer_base64 = encode_image_to_base64(image_path)
284
+ elif image_path is None and base64_buffer is not None:
285
+ image_buffer_base64 = deepcopy(base64_buffer)
286
+ else:
287
+ print("Please pass either 'image_path' or 'base64_buffer'")
288
+ return None
289
+
290
+ # Decode base64 buffer and save to local disk
291
+ if image_buffer_base64 is not None:
292
+ image_data = base64.b64decode(image_buffer_base64)
293
+ uuid_string = str(uuid.uuid4())
294
+ temp_image_path = f'{TEMP_DIR}/{uuid_string}.png'
295
+ with open(temp_image_path, 'wb') as image_file:
296
+ image_file.write(image_data)
297
+ return temp_image_path
298
+ else:
299
+ return None
300
+
301
+ def process_input_dict(input_dict):
302
+
303
+ decoded_input_dict = {}
304
+
305
+ decoded_input_dict["save_fps"] = input_dict.get("save_fps", 25)
306
+
307
+ image_base64_buffer = input_dict.get("image_buffer", None)
308
+ if image_base64_buffer is not None:
309
+ decoded_input_dict["image_path"] = save_image_base64_to_local(
310
+ image_path=None,
311
+ base64_buffer=image_base64_buffer)
312
+ else:
313
+ decoded_input_dict["image_path"] = None
314
+
315
+ audio_base64_buffer = input_dict.get("audio_buffer", None)
316
+ if audio_base64_buffer is not None:
317
+ decoded_input_dict["audio_path"] = save_audio_base64_to_local(
318
+ audio_path=None,
319
+ base64_buffer=audio_base64_buffer)
320
+ else:
321
+ decoded_input_dict["audio_path"] = None
322
+
323
+ decoded_input_dict["prompt"] = input_dict.get("text", None)
324
+
325
+ return decoded_input_dict
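The helpers in this module are symmetric: encode_*_to_base64 reads a file into a base64 string, and save_*_base64_to_local decodes such a string into a uniquely named file under ./temp. A quick round-trip sketch (the sample asset path is only an illustration):

```python
from hymm_gradio.tool_for_end2end import encode_wav_to_base64, save_audio_base64_to_local

# File -> base64 string (as sent over HTTP) -> decoded copy under ./temp/<uuid>.wav
audio_b64 = encode_wav_to_base64("assets/audio/2.WAV")
temp_wav = save_audio_base64_to_local(audio_path=None, base64_buffer=audio_b64)
print(temp_wav)
```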
hymm_sp/__init__.py ADDED
File without changes
hymm_sp/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (221 Bytes).
 
hymm_sp/__pycache__/config.cpython-310.pyc ADDED
Binary file (7.46 kB).
 
hymm_sp/__pycache__/constants.cpython-310.pyc ADDED
Binary file (1.43 kB).
 
hymm_sp/__pycache__/helpers.cpython-310.pyc ADDED
Binary file (4.11 kB).
 
hymm_sp/__pycache__/inference.cpython-310.pyc ADDED
Binary file (5.03 kB).
 
hymm_sp/__pycache__/sample_inference_audio.cpython-310.pyc ADDED
Binary file (8.71 kB).
 
hymm_sp/config.py ADDED
@@ -0,0 +1,142 @@
1
+ import argparse
2
+ from hymm_sp.constants import *
3
+ import re
4
+ import collections.abc
5
+
6
+ def as_tuple(x):
7
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
8
+ return tuple(x)
9
+ if x is None or isinstance(x, (int, float, str)):
10
+ return (x,)
11
+ else:
12
+ raise ValueError(f"Unknown type {type(x)}")
13
+
14
+ def parse_args(namespace=None):
15
+ parser = argparse.ArgumentParser(description="Hunyuan Multimodal training/inference script")
16
+ parser = add_extra_args(parser)
17
+ args = parser.parse_args(namespace=namespace)
18
+ args = sanity_check_args(args)
19
+ return args
20
+
21
+ def add_extra_args(parser: argparse.ArgumentParser):
22
+ parser = add_network_args(parser)
23
+ parser = add_extra_models_args(parser)
24
+ parser = add_denoise_schedule_args(parser)
25
+ parser = add_evaluation_args(parser)
26
+ return parser
27
+
28
+ def add_network_args(parser: argparse.ArgumentParser):
29
+ group = parser.add_argument_group(title="Network")
30
+ group.add_argument("--model", type=str, default="HYVideo-T/2",
31
+ help="Model architecture to use. It is also used to determine the experiment directory.")
32
+ group.add_argument("--latent-channels", type=str, default=None,
33
+ help="Number of latent channels of DiT. If None, it will be determined by `vae`. If provided, "
34
+ "it still needs to match the latent channels of the VAE model.")
35
+ group.add_argument("--rope-theta", type=int, default=256, help="Theta used in RoPE.")
36
+ return parser
37
+
38
+ def add_extra_models_args(parser: argparse.ArgumentParser):
39
+ group = parser.add_argument_group(title="Extra Models (VAE, Text Encoder, Tokenizer)")
40
+
41
+ # VAE
42
+ group.add_argument("--vae", type=str, default="884-16c-hy0801", help="Name of the VAE model.")
43
+ group.add_argument("--vae-precision", type=str, default="fp16",
44
+ help="Precision mode for the VAE model.")
45
+ group.add_argument("--vae-tiling", action="store_true", default=True, help="Enable tiling for the VAE model.")
46
+ group.add_argument("--text-encoder", type=str, default="llava-llama-3-8b", choices=list(TEXT_ENCODER_PATH),
47
+ help="Name of the text encoder model.")
48
+ group.add_argument("--text-encoder-precision", type=str, default="fp16", choices=PRECISIONS,
49
+ help="Precision mode for the text encoder model.")
50
+ group.add_argument("--text-states-dim", type=int, default=4096, help="Dimension of the text encoder hidden states.")
51
+ group.add_argument("--text-len", type=int, default=256, help="Maximum length of the text input.")
52
+ group.add_argument("--tokenizer", type=str, default="llava-llama-3-8b", choices=list(TOKENIZER_PATH),
53
+ help="Name of the tokenizer model.")
54
+ group.add_argument("--text-encoder-infer-mode", type=str, default="encoder", choices=["encoder", "decoder"],
55
+ help="Inference mode for the text encoder model. It should match the text encoder type. T5 and "
56
+ "CLIP can only work in 'encoder' mode, while Llava/GLM can work in both modes.")
57
+ group.add_argument("--prompt-template-video", type=str, default='li-dit-encode-video', choices=PROMPT_TEMPLATE,
58
+ help="Video prompt template for the decoder-only text encoder model.")
59
+ group.add_argument("--hidden-state-skip-layer", type=int, default=2,
60
+ help="Skip layer for hidden states.")
61
+ group.add_argument("--apply-final-norm", action="store_true",
62
+ help="Apply final normalization to the used text encoder hidden states.")
63
+
64
+ # - CLIP
65
+ group.add_argument("--text-encoder-2", type=str, default='clipL', choices=list(TEXT_ENCODER_PATH),
66
+ help="Name of the second text encoder model.")
67
+ group.add_argument("--text-encoder-precision-2", type=str, default="fp16", choices=PRECISIONS,
68
+ help="Precision mode for the second text encoder model.")
69
+ group.add_argument("--text-states-dim-2", type=int, default=768,
70
+ help="Dimension of the second text encoder hidden states.")
71
+ group.add_argument("--tokenizer-2", type=str, default='clipL', choices=list(TOKENIZER_PATH),
72
+ help="Name of the second tokenizer model.")
73
+ group.add_argument("--text-len-2", type=int, default=77, help="Maximum length of the second text input.")
74
+ group.set_defaults(use_attention_mask=True)
75
+ group.add_argument("--text-projection", type=str, default="single_refiner", choices=TEXT_PROJECTION,
76
+ help="A projection layer for bridging the text encoder hidden states and the diffusion model "
77
+ "conditions.")
78
+ return parser
79
+
80
+
81
+ def add_denoise_schedule_args(parser: argparse.ArgumentParser):
82
+ group = parser.add_argument_group(title="Denoise schedule")
83
+ group.add_argument("--flow-shift-eval-video", type=float, default=None, help="Shift factor for flow matching schedulers when using video data.")
84
+ group.add_argument("--flow-reverse", action="store_true", default=True, help="If reverse, learning/sampling from t=1 -> t=0.")
85
+ group.add_argument("--flow-solver", type=str, default="euler", help="Solver for flow matching.")
86
+ group.add_argument("--use-linear-quadratic-schedule", action="store_true", help="Use linear quadratic schedule for flow matching."
87
+ " Follow MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)")
88
+ group.add_argument("--linear-schedule-end", type=int, default=25, help="End step for linear quadratic schedule for flow matching.")
89
+ return parser
90
+
91
+ def add_evaluation_args(parser: argparse.ArgumentParser):
92
+ group = parser.add_argument_group(title="Validation Loss Evaluation")
93
+ parser.add_argument("--precision", type=str, default="bf16", choices=PRECISIONS,
94
+ help="Precision mode. Options: fp32, fp16, bf16. Applied to the backbone model and optimizer.")
95
+ parser.add_argument("--reproduce", action="store_true",
96
+ help="Enable reproducibility by setting random seeds and deterministic algorithms.")
97
+ parser.add_argument("--ckpt", type=str, help="Path to the checkpoint to evaluate.")
98
+ parser.add_argument("--load-key", type=str, default="module", choices=["module", "ema"],
99
+ help="Key to load the model states. 'module' for the main model, 'ema' for the EMA model.")
100
+ parser.add_argument("--cpu-offload", action="store_true", help="Use CPU offload for the model load.")
101
+ parser.add_argument("--infer-min", action="store_true", help="infer 5s.")
102
+ group.add_argument( "--use-fp8", action="store_true", help="Enable use fp8 for inference acceleration.")
103
+ group.add_argument("--video-size", type=int, nargs='+', default=512,
104
+ help="Video size for training. If a single value is provided, it will be used for both width "
105
+ "and height. If two values are provided, they will be used for width and height "
106
+ "respectively.")
107
+ group.add_argument("--sample-n-frames", type=int, default=1,
108
+ help="How many frames to sample from a video. If using a 3D VAE, the number should be 4n+1.")
109
+ group.add_argument("--infer-steps", type=int, default=100, help="Number of denoising steps for inference.")
110
+ group.add_argument("--val-disable-autocast", action="store_true",
111
+ help="Disable autocast for denoising loop and vae decoding in pipeline sampling.")
112
+ group.add_argument("--num-images", type=int, default=1, help="Number of images to generate for each prompt.")
113
+ group.add_argument("--seed", type=int, default=1024, help="Seed for evaluation.")
114
+ group.add_argument("--save-path-suffix", type=str, default="", help="Suffix for the directory of saved samples.")
115
+ group.add_argument("--pos-prompt", type=str, default='', help="Prompt for sampling during evaluation.")
116
+ group.add_argument("--neg-prompt", type=str, default='', help="Negative prompt for sampling during evaluation.")
117
+ group.add_argument("--image-size", type=int, default=704)
118
+ group.add_argument("--pad-face-size", type=float, default=0.7, help="Pad bbox for face align.")
119
+ group.add_argument("--image-path", type=str, default="", help="")
120
+ group.add_argument("--save-path", type=str, default=None, help="Path to save the generated samples.")
121
+ group.add_argument("--input", type=str, default=None, help="test data.")
122
+ group.add_argument("--item-name", type=str, default=None, help="")
123
+ group.add_argument("--cfg-scale", type=float, default=7.5, help="Classifier free guidance scale.")
124
+ group.add_argument("--ip-cfg-scale", type=float, default=0, help="Classifier free guidance scale.")
125
+ group.add_argument("--use-deepcache", type=int, default=1)
126
+ return parser
127
+
128
+ def sanity_check_args(args):
129
+ # VAE channels
130
+ vae_pattern = r"\d{2,3}-\d{1,2}c-\w+"
131
+ if not re.match(vae_pattern, args.vae):
132
+ raise ValueError(
133
+ f"Invalid VAE model: {args.vae}. Must be in the format of '{vae_pattern}'."
134
+ )
135
+ vae_channels = int(args.vae.split("-")[1][:-1])
136
+ if args.latent_channels is None:
137
+ args.latent_channels = vae_channels
138
+ if vae_channels != args.latent_channels:
139
+ raise ValueError(
140
+ f"Latent channels ({args.latent_channels}) must match the VAE channels ({vae_channels})."
141
+ )
142
+ return args
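parse_args assembles all of the groups above and then runs sanity_check_args, which derives latent_channels from the VAE name and insists the two agree. A sketch of driving it programmatically (the checkpoint path is a placeholder; every flag shown is defined above):

```python
import sys
from hymm_sp.config import parse_args

# Simulate a command line; each flag corresponds to an argument defined above.
sys.argv = [
    "sample.py",
    "--ckpt", "ckpts/hunyuan-video-avatar.pt",  # placeholder checkpoint path
    "--input", "assets/test.csv",
    "--image-size", "704",
    "--infer-steps", "50",
    "--cfg-scale", "7.5",
    "--seed", "1024",
]
args = parse_args()
print(args.vae, args.latent_channels)  # '884-16c-hy0801' 16 (filled in from the VAE name)
```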
hymm_sp/constants.py ADDED
@@ -0,0 +1,59 @@
1
+ import os
2
+ import torch
3
+
4
+ __all__ = [
5
+ "PROMPT_TEMPLATE", "MODEL_BASE", "PRECISION_TO_TYPE",
6
+ "PRECISIONS", "VAE_PATH", "TEXT_ENCODER_PATH", "TOKENIZER_PATH",
7
+ "TEXT_PROJECTION",
8
+ ]
9
+
10
+ # =================== Constant Values =====================
11
+
12
+ PRECISION_TO_TYPE = {
13
+ 'fp32': torch.float32,
14
+ 'fp16': torch.float16,
15
+ 'bf16': torch.bfloat16,
16
+ }
17
+
18
+ PROMPT_TEMPLATE_ENCODE_VIDEO = (
19
+ "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
20
+ "1. The main content and theme of the video."
21
+ "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
22
+ "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
23
+ "4. background environment, light, style and atmosphere."
24
+ "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
25
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
26
+ )
27
+
28
+ PROMPT_TEMPLATE = {
29
+ "li-dit-encode-video": {"template": PROMPT_TEMPLATE_ENCODE_VIDEO, "crop_start": 95},
30
+ }
31
+
32
+ # ======================= Model ======================
33
+ PRECISIONS = {"fp32", "fp16", "bf16"}
34
+
35
+ # =================== Model Path =====================
36
+ MODEL_BASE = os.getenv("MODEL_BASE")
37
+ MODEL_BASE=f"{MODEL_BASE}/ckpts"
38
+
39
+ # 3D VAE
40
+ VAE_PATH = {
41
+ "884-16c-hy0801": f"{MODEL_BASE}/hunyuan-video-t2v-720p/vae",
42
+ }
43
+
44
+ # Text Encoder
45
+ TEXT_ENCODER_PATH = {
46
+ "clipL": f"{MODEL_BASE}/text_encoder_2",
47
+ "llava-llama-3-8b": f"{MODEL_BASE}/llava_llama_image",
48
+ }
49
+
50
+ # Tokenizer
51
+ TOKENIZER_PATH = {
52
+ "clipL": f"{MODEL_BASE}/text_encoder_2",
53
+ "llava-llama-3-8b":f"{MODEL_BASE}/llava_llama_image",
54
+ }
55
+
56
+ TEXT_PROJECTION = {
57
+ "linear", # Default, an nn.Linear() layer
58
+ "single_refiner", # Single TokenRefiner. Refer to LI-DiT
59
+ }
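The precision strings accepted by the CLI map straight onto torch dtypes through PRECISION_TO_TYPE, for example:

```python
import torch
from hymm_sp.constants import PRECISIONS, PRECISION_TO_TYPE

assert "bf16" in PRECISIONS
print(PRECISION_TO_TYPE["bf16"] is torch.bfloat16)  # True; used when casting the backbone
```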
hymm_sp/data_kits/__pycache__/audio_dataset.cpython-310.pyc ADDED
Binary file (5.05 kB).
 
hymm_sp/data_kits/__pycache__/audio_preprocessor.cpython-310.pyc ADDED
Binary file (2.32 kB).
 
hymm_sp/data_kits/__pycache__/data_tools.cpython-310.pyc ADDED
Binary file (1.62 kB).
 
hymm_sp/data_kits/__pycache__/ffmpeg_utils.cpython-310.pyc ADDED
Binary file (4.02 kB).
 
hymm_sp/data_kits/audio_dataset.py ADDED
@@ -0,0 +1,170 @@
1
+ import os
2
+ import cv2
3
+ import math
4
+ import json
5
+ import torch
6
+ import random
7
+ import librosa
8
+ import traceback
9
+ import torchvision
10
+ import numpy as np
11
+ import pandas as pd
12
+ from PIL import Image
13
+ from einops import rearrange
14
+ from torch.utils.data import Dataset
15
+ from decord import VideoReader, cpu
16
+ from transformers import CLIPImageProcessor
17
+ import torchvision.transforms as transforms
18
+ from torchvision.transforms import ToPILImage
19
+
20
+
21
+
22
+ def get_audio_feature(feature_extractor, audio_path):
23
+ audio_input, sampling_rate = librosa.load(audio_path, sr=16000)
24
+ assert sampling_rate == 16000
25
+
26
+ audio_features = []
27
+ window = 750*640
28
+ for i in range(0, len(audio_input), window):
29
+ audio_feature = feature_extractor(audio_input[i:i+window],
30
+ sampling_rate=sampling_rate,
31
+ return_tensors="pt",
32
+ ).input_features
33
+ audio_features.append(audio_feature)
34
+
35
+ audio_features = torch.cat(audio_features, dim=-1)
36
+ return audio_features, len(audio_input) // 640
37
+
38
+
39
+ class VideoAudioTextLoaderVal(Dataset):
40
+ def __init__(
41
+ self,
42
+ image_size: int,
43
+ meta_file: str,
44
+ **kwargs,
45
+ ):
46
+ super().__init__()
47
+ self.meta_file = meta_file
48
+ self.image_size = image_size
49
+ self.text_encoder = kwargs.get("text_encoder", None) # llava_text_encoder
50
+ self.text_encoder_2 = kwargs.get("text_encoder_2", None) # clipL_text_encoder
51
+ self.feature_extractor = kwargs.get("feature_extractor", None)
52
+ self.meta_files = []
53
+
54
+ csv_data = pd.read_csv(meta_file)
55
+ for idx in range(len(csv_data)):
56
+ self.meta_files.append(
57
+ {
58
+ "videoid": str(csv_data["videoid"][idx]),
59
+ "image_path": str(csv_data["image"][idx]),
60
+ "audio_path": str(csv_data["audio"][idx]),
61
+ "prompt": str(csv_data["prompt"][idx]),
62
+ "fps": float(csv_data["fps"][idx])
63
+ }
64
+ )
65
+
66
+ self.llava_transform = transforms.Compose(
67
+ [
68
+ transforms.Resize((336, 336), interpolation=transforms.InterpolationMode.BILINEAR),
69
+ transforms.ToTensor(),
70
+ transforms.Normalize((0.48145466, 0.4578275, 0.4082107), (0.26862954, 0.26130258, 0.27577711)),
71
+ ]
72
+ )
73
+ self.clip_image_processor = CLIPImageProcessor()
74
+
75
+ self.device = torch.device("cuda")
76
+ self.weight_dtype = torch.float16
77
+
78
+
79
+ def __len__(self):
80
+ return len(self.meta_files)
81
+
82
+ @staticmethod
83
+ def get_text_tokens(text_encoder, description, dtype_encode="video"):
84
+ text_inputs = text_encoder.text2tokens(description, data_type=dtype_encode)
85
+ text_ids = text_inputs["input_ids"].squeeze(0)
86
+ text_mask = text_inputs["attention_mask"].squeeze(0)
87
+ return text_ids, text_mask
88
+
89
+ def get_batch_data(self, idx):
90
+ meta_file = self.meta_files[idx]
91
+ videoid = meta_file["videoid"]
92
+ image_path = meta_file["image_path"]
93
+ audio_path = meta_file["audio_path"]
94
+ prompt = "Authentic, Realistic, Natural, High-quality, Lens-Fixed, " + meta_file["prompt"]
95
+ fps = meta_file["fps"]
96
+
97
+ img_size = self.image_size
98
+ ref_image = Image.open(image_path).convert('RGB')
99
+
100
+ # Resize reference image
101
+ w, h = ref_image.size
102
+ scale = img_size / min(w, h)
103
+ new_w = round(w * scale / 64) * 64
104
+ new_h = round(h * scale / 64) * 64
105
+
106
+ if img_size == 704:
107
+ img_size_long = 1216
108
+ if new_w * new_h > img_size * img_size_long:
109
+ import math
110
+ scale = math.sqrt(img_size * img_size_long / w / h)
111
+ new_w = round(w * scale / 64) * 64
112
+ new_h = round(h * scale / 64) * 64
113
+
114
+ ref_image = ref_image.resize((new_w, new_h), Image.LANCZOS)
115
+
116
+ ref_image = np.array(ref_image)
117
+ ref_image = torch.from_numpy(ref_image)
118
+
119
+ audio_input, audio_len = get_audio_feature(self.feature_extractor, audio_path)
120
+ audio_prompts = audio_input[0]
121
+
122
+ motion_bucket_id_heads = np.array([25] * 4)
123
+ motion_bucket_id_exps = np.array([30] * 4)
124
+ motion_bucket_id_heads = torch.from_numpy(motion_bucket_id_heads)
125
+ motion_bucket_id_exps = torch.from_numpy(motion_bucket_id_exps)
126
+ fps = torch.from_numpy(np.array(fps))
127
+
128
+ to_pil = ToPILImage()
129
+ pixel_value_ref = rearrange(ref_image.clone().unsqueeze(0), "b h w c -> b c h w") # (b c h w)
130
+
131
+ pixel_value_ref_llava = [self.llava_transform(to_pil(image)) for image in pixel_value_ref]
132
+ pixel_value_ref_llava = torch.stack(pixel_value_ref_llava, dim=0)
133
+ pixel_value_ref_clip = self.clip_image_processor(
134
+ images=Image.fromarray((pixel_value_ref[0].permute(1,2,0)).data.cpu().numpy().astype(np.uint8)),
135
+ return_tensors="pt"
136
+ ).pixel_values[0]
137
+ pixel_value_ref_clip = pixel_value_ref_clip.unsqueeze(0)
138
+
139
+ # Encode text prompts
140
+
141
+ text_ids, text_mask = self.get_text_tokens(self.text_encoder, prompt)
142
+ text_ids_2, text_mask_2 = self.get_text_tokens(self.text_encoder_2, prompt)
143
+
144
+ # Output batch
145
+ batch = {
146
+ "text_prompt": prompt, #
147
+ "videoid": videoid,
148
+ "pixel_value_ref": pixel_value_ref.to(dtype=torch.float16), # Reference image for VAE feature extraction, (1, 3, h, w), values in [0, 255]
149
+ "pixel_value_ref_llava": pixel_value_ref_llava.to(dtype=torch.float16), # Reference image for LLaVA feature extraction, (1, 3, 336, 336), CLIP value range
150
+ "pixel_value_ref_clip": pixel_value_ref_clip.to(dtype=torch.float16), # Reference image for clip_image_encoder feature extraction, (1, 3, 224, 224), CLIP value range
151
+ "audio_prompts": audio_prompts.to(dtype=torch.float16),
152
+ "motion_bucket_id_heads": motion_bucket_id_heads.to(dtype=text_ids.dtype),
153
+ "motion_bucket_id_exps": motion_bucket_id_exps.to(dtype=text_ids.dtype),
154
+ "fps": fps.to(dtype=torch.float16),
155
+ "text_ids": text_ids.clone(), # corresponds to llava_text_encoder
156
+ "text_mask": text_mask.clone(), # corresponds to llava_text_encoder
157
+ "text_ids_2": text_ids_2.clone(), # corresponds to clip_text_encoder
158
+ "text_mask_2": text_mask_2.clone(), # corresponds to clip_text_encoder
159
+ "audio_len": audio_len,
160
+ "image_path": image_path,
161
+ "audio_path": audio_path,
162
+ }
163
+ return batch
164
+
165
+ def __getitem__(self, idx):
166
+ return self.get_batch_data(idx)
167
+
168
+
169
+
170
+
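VideoAudioTextLoaderVal is a thin evaluation dataset: it reads the CSV manifest, resizes the reference image to multiples of 64, extracts Whisper features for the whole clip, and tokenizes the prompt with both text encoders. A usage sketch under the assumption that the LLaVA and CLIP text-encoder wrappers (which must expose text2tokens()) are built elsewhere, e.g. by the sampler:

```python
import os
from torch.utils.data import DataLoader
from transformers import AutoFeatureExtractor
from hymm_sp.data_kits.audio_dataset import VideoAudioTextLoaderVal

def build_val_loader(text_encoder, text_encoder_2):
    # Feature-extractor path mirrors hymm_gradio/flask_audio.py (whisper-tiny under $MODEL_BASE/ckpts).
    model_base = os.environ.get("MODEL_BASE", ".")
    feature_extractor = AutoFeatureExtractor.from_pretrained(f"{model_base}/ckpts/whisper-tiny/")
    dataset = VideoAudioTextLoaderVal(
        image_size=704,
        meta_file="assets/test.csv",
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        feature_extractor=feature_extractor,
    )
    return DataLoader(dataset, batch_size=1, shuffle=False)

# batch = next(iter(build_val_loader(text_encoder, text_encoder_2)))
# batch["pixel_value_ref"]: (1, 1, 3, H, W); batch["audio_prompts"]: Whisper features for the clip
```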
hymm_sp/data_kits/audio_preprocessor.py ADDED
@@ -0,0 +1,72 @@
1
+
2
+ import os
3
+ import cv2
4
+ import json
5
+ import time
6
+ import decord
7
+ import einops
8
+ import librosa
9
+ import torch
10
+ import random
11
+ import argparse
12
+ import traceback
13
+ import numpy as np
14
+ from tqdm import tqdm
15
+ from PIL import Image
16
+ from einops import rearrange
17
+
18
+
19
+
20
+ def get_facemask(ref_image, align_instance, area=1.25):
21
+ # ref_image: (b f c h w)
22
+ bsz, f, c, h, w = ref_image.shape
23
+ images = rearrange(ref_image, "b f c h w -> (b f) h w c").data.cpu().numpy().astype(np.uint8)
24
+ face_masks = []
25
+ for image in images:
26
+ image_pil = Image.fromarray(image).convert("RGB")
27
+ _, _, bboxes_list = align_instance(np.array(image_pil)[:,:,[2,1,0]], maxface=True)
28
+ try:
29
+ bboxSrc = bboxes_list[0]
30
+ except:
31
+ bboxSrc = [0, 0, w, h]
32
+ x1, y1, ww, hh = bboxSrc
33
+ x2, y2 = x1 + ww, y1 + hh
34
+ ww, hh = (x2-x1) * area, (y2-y1) * area
35
+ center = [(x2+x1)//2, (y2+y1)//2]
36
+ x1 = max(center[0] - ww//2, 0)
37
+ y1 = max(center[1] - hh//2, 0)
38
+ x2 = min(center[0] + ww//2, w)
39
+ y2 = min(center[1] + hh//2, h)
40
+
41
+ face_mask = np.zeros_like(np.array(image_pil))
42
+ face_mask[int(y1):int(y2), int(x1):int(x2)] = 1.0
43
+ face_masks.append(torch.from_numpy(face_mask[...,:1]))
44
+ face_masks = torch.stack(face_masks, dim=0) # (b*f, h, w, c)
45
+ face_masks = rearrange(face_masks, "(b f) h w c -> b c f h w", b=bsz, f=f)
46
+ face_masks = face_masks.to(device=ref_image.device, dtype=ref_image.dtype)
47
+ return face_masks
48
+
49
+
50
+ def encode_audio(wav2vec, audio_feats, fps, num_frames=129):
51
+ if fps == 25:
52
+ start_ts = [0]
53
+ step_ts = [1]
54
+ elif fps == 12.5:
55
+ start_ts = [0]
56
+ step_ts = [2]
57
+ num_frames = min(num_frames, 400)
58
+ audio_feats = wav2vec.encoder(audio_feats.unsqueeze(0)[:, :, :3000], output_hidden_states=True).hidden_states
59
+ audio_feats = torch.stack(audio_feats, dim=2)
60
+ audio_feats = torch.cat([torch.zeros_like(audio_feats[:,:4]), audio_feats], 1)
61
+
62
+ audio_prompts = []
63
+ for bb in range(1):
64
+ audio_feats_list = []
65
+ for f in range(num_frames):
66
+ cur_t = (start_ts[bb] + f * step_ts[bb]) * 2
67
+ audio_clip = audio_feats[bb:bb+1, cur_t: cur_t+10]
68
+ audio_feats_list.append(audio_clip)
69
+ audio_feats_list = torch.stack(audio_feats_list, 1)
70
+ audio_prompts.append(audio_feats_list)
71
+ audio_prompts = torch.cat(audio_prompts)
72
+ return audio_prompts
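get_facemask runs the face detector on each reference frame and returns a binary box mask in (b, c, f, h, w) layout, while encode_audio slices the stacked Whisper hidden states into one short window per output frame. A sketch of the face-mask path (the detector checkpoint location follows the ckpts/det_align layout used by hymm_gradio/flask_audio.py and is an assumption here):

```python
import os
import torch
from hymm_sp.data_kits.face_align import AlignImage
from hymm_sp.data_kits.audio_preprocessor import get_facemask

model_base = os.environ.get("MODEL_BASE", ".")
align_instance = AlignImage("cuda", det_path=f"{model_base}/ckpts/det_align/detface.pt")

# One reference frame in (b, f, c, h, w) layout with 0-255 values, as get_facemask expects.
ref_image = torch.randint(0, 256, (1, 1, 3, 704, 704)).float().cuda()
face_mask = get_facemask(ref_image, align_instance, area=1.25)
print(face_mask.shape)  # torch.Size([1, 1, 1, 704, 704]); 1s inside the padded face box
```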
hymm_sp/data_kits/data_tools.py ADDED
@@ -0,0 +1,41 @@
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ import imageio
6
+ import torchvision
7
+ from einops import rearrange
8
+
9
+
10
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8, quality=8):
11
+ videos = rearrange(videos, "b c t h w -> t b c h w")
12
+ outputs = []
13
+ for x in videos:
14
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
15
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
16
+ if rescale:
17
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
18
+ x = torch.clamp(x,0,1)
19
+ x = (x * 255).numpy().astype(np.uint8)
20
+ outputs.append(x)
21
+
22
+ os.makedirs(os.path.dirname(path), exist_ok=True)
23
+ imageio.mimsave(path, outputs, fps=fps, quality=quality)
24
+
25
+ def pad_image(crop_img, size, color=(255, 255, 255), resize_ratio=1):
26
+ crop_h, crop_w = crop_img.shape[:2]
27
+ target_w, target_h = size
28
+ scale_h, scale_w = target_h / crop_h, target_w / crop_w
29
+ if scale_w > scale_h:
30
+ resize_h = int(target_h*resize_ratio)
31
+ resize_w = int(crop_w / crop_h * resize_h)
32
+ else:
33
+ resize_w = int(target_w*resize_ratio)
34
+ resize_h = int(crop_h / crop_w * resize_w)
35
+ crop_img = cv2.resize(crop_img, (resize_w, resize_h))
36
+ pad_left = (target_w - resize_w) // 2
37
+ pad_top = (target_h - resize_h) // 2
38
+ pad_right = target_w - resize_w - pad_left
39
+ pad_bottom = target_h - resize_h - pad_top
40
+ crop_img = cv2.copyMakeBorder(crop_img, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, value=color)
41
+ return crop_img
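pad_image letterboxes a crop into a fixed canvas: it resizes while preserving aspect ratio, then pads the remaining borders with a constant color. A hedged usage sketch; the input path points at one of the sample images shipped under assets/ and is otherwise arbitrary:

import cv2
from hymm_sp.data_kits.data_tools import pad_image

img = cv2.imread("assets/image/1.png")        # BGR image of any size
canvas = pad_image(img, size=(704, 704))      # aspect-preserving resize + white padding
print(canvas.shape)                           # (704, 704, 3)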
hymm_sp/data_kits/face_align/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .align import AlignImage
hymm_sp/data_kits/face_align/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (282 Bytes).
 
hymm_sp/data_kits/face_align/__pycache__/align.cpython-310.pyc ADDED
Binary file (1.37 kB).
 
hymm_sp/data_kits/face_align/__pycache__/detface.cpython-310.pyc ADDED
Binary file (7.98 kB).
 
hymm_sp/data_kits/face_align/align.py ADDED
@@ -0,0 +1,34 @@
1
+ import os
2
+ import sys
3
+ import torch
4
+ from .detface import DetFace
5
+
6
+ class AlignImage(object):
7
+ def __init__(self, device='cuda', det_path=''):
8
+ self.facedet = DetFace(pt_path=det_path, confThreshold=0.5, nmsThreshold=0.45, device=device)
9
+
10
+ @torch.no_grad()
11
+ def __call__(self, im, maxface=False):
12
+ bboxes, kpss, scores = self.facedet.detect(im)
13
+ face_num = bboxes.shape[0]
14
+
15
+ five_pts_list = []
16
+ scores_list = []
17
+ bboxes_list = []
18
+ for i in range(face_num):
19
+ five_pts_list.append(kpss[i].reshape(5,2))
20
+ scores_list.append(scores[i])
21
+ bboxes_list.append(bboxes[i])
22
+
23
+ if maxface and face_num>1:
24
+ max_idx = 0
25
+ max_area = (bboxes[0, 2])*(bboxes[0, 3])
26
+ for i in range(1, face_num):
27
+ area = (bboxes[i,2])*(bboxes[i,3])
28
+ if area>max_area:
29
+ max_idx = i
30
+ five_pts_list = [five_pts_list[max_idx]]
31
+ scores_list = [scores_list[max_idx]]
32
+ bboxes_list = [bboxes_list[max_idx]]
33
+
34
+ return five_pts_list, scores_list, bboxes_list
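AlignImage wraps the TorchScript face detector and, with maxface=True, keeps only the largest detection. A hedged usage sketch; the checkpoint path is an assumption and depends on where the detection weights were downloaded:

import cv2
from hymm_sp.data_kits.face_align import AlignImage

align = AlignImage(device="cuda", det_path="weights/det_align/detface.pt")  # path assumed
img_bgr = cv2.imread("assets/image/1.png")
pts, scores, boxes = align(img_bgr, maxface=True)
if boxes:
    x, y, w, h = boxes[0]   # boxes come back as (x, y, width, height)
    print("largest face:", x, y, w, h, "score:", float(scores[0]))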
hymm_sp/data_kits/face_align/detface.py ADDED
@@ -0,0 +1,283 @@
1
+ # -*- coding: UTF-8 -*-
2
+ import os
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ import torchvision
7
+
8
+
9
+ def xyxy2xywh(x):
10
+ # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
11
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
12
+ y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
13
+ y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
14
+ y[:, 2] = x[:, 2] - x[:, 0] # width
15
+ y[:, 3] = x[:, 3] - x[:, 1] # height
16
+ return y
17
+
18
+
19
+ def xywh2xyxy(x):
20
+ # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
21
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
22
+ y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
23
+ y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
24
+ y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
25
+ y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
26
+ return y
27
+
28
+
29
+ def box_iou(box1, box2):
30
+ # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
31
+ """
32
+ Return intersection-over-union (Jaccard index) of boxes.
33
+ Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
34
+ Arguments:
35
+ box1 (Tensor[N, 4])
36
+ box2 (Tensor[M, 4])
37
+ Returns:
38
+ iou (Tensor[N, M]): the NxM matrix containing the pairwise
39
+ IoU values for every element in boxes1 and boxes2
40
+ """
41
+
42
+ def box_area(box):
43
+ # box = 4xn
44
+ return (box[2] - box[0]) * (box[3] - box[1])
45
+
46
+ area1 = box_area(box1.T)
47
+ area2 = box_area(box2.T)
48
+
49
+ # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
50
+ inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) -
51
+ torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
52
+ # iou = inter / (area1 + area2 - inter)
53
+ return inter / (area1[:, None] + area2 - inter)
54
+
55
+
56
+ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
57
+ # Rescale coords (xyxy) from img1_shape to img0_shape
58
+ if ratio_pad is None: # calculate from img0_shape
59
+ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
60
+ pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
61
+ else:
62
+ gain = ratio_pad[0][0]
63
+ pad = ratio_pad[1]
64
+
65
+ coords[:, [0, 2]] -= pad[0] # x padding
66
+ coords[:, [1, 3]] -= pad[1] # y padding
67
+ coords[:, :4] /= gain
68
+ clip_coords(coords, img0_shape)
69
+ return coords
70
+
71
+
72
+ def clip_coords(boxes, img_shape):
73
+ # Clip bounding xyxy bounding boxes to image shape (height, width)
74
+ boxes[:, 0].clamp_(0, img_shape[1]) # x1
75
+ boxes[:, 1].clamp_(0, img_shape[0]) # y1
76
+ boxes[:, 2].clamp_(0, img_shape[1]) # x2
77
+ boxes[:, 3].clamp_(0, img_shape[0]) # y2
78
+
79
+
80
+ def scale_coords_landmarks(img1_shape, coords, img0_shape, ratio_pad=None):
81
+ # Rescale coords (xyxy) from img1_shape to img0_shape
82
+ if ratio_pad is None: # calculate from img0_shape
83
+ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
84
+ pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
85
+ else:
86
+ gain = ratio_pad[0][0]
87
+ pad = ratio_pad[1]
88
+
89
+ coords[:, [0, 2, 4, 6, 8]] -= pad[0] # x padding
90
+ coords[:, [1, 3, 5, 7, 9]] -= pad[1] # y padding
91
+ coords[:, :10] /= gain
92
+ #clip_coords(coords, img0_shape)
93
+ coords[:, 0].clamp_(0, img0_shape[1]) # x1
94
+ coords[:, 1].clamp_(0, img0_shape[0]) # y1
95
+ coords[:, 2].clamp_(0, img0_shape[1]) # x2
96
+ coords[:, 3].clamp_(0, img0_shape[0]) # y2
97
+ coords[:, 4].clamp_(0, img0_shape[1]) # x3
98
+ coords[:, 5].clamp_(0, img0_shape[0]) # y3
99
+ coords[:, 6].clamp_(0, img0_shape[1]) # x4
100
+ coords[:, 7].clamp_(0, img0_shape[0]) # y4
101
+ coords[:, 8].clamp_(0, img0_shape[1]) # x5
102
+ coords[:, 9].clamp_(0, img0_shape[0]) # y5
103
+ return coords
104
+
105
+
106
+ def show_results(img, xywh, conf, landmarks, class_num):
107
+ h,w,c = img.shape
108
+ tl = 1 or round(0.002 * (h + w) / 2) + 1 # line/font thickness
109
+ x1 = int(xywh[0] * w - 0.5 * xywh[2] * w)
110
+ y1 = int(xywh[1] * h - 0.5 * xywh[3] * h)
111
+ x2 = int(xywh[0] * w + 0.5 * xywh[2] * w)
112
+ y2 = int(xywh[1] * h + 0.5 * xywh[3] * h)
113
+ cv2.rectangle(img, (x1,y1), (x2, y2), (0,255,0), thickness=tl, lineType=cv2.LINE_AA)
114
+
115
+ clors = [(255,0,0),(0,255,0),(0,0,255),(255,255,0),(0,255,255)]
116
+
117
+ for i in range(5):
118
+ point_x = int(landmarks[2 * i] * w)
119
+ point_y = int(landmarks[2 * i + 1] * h)
120
+ cv2.circle(img, (point_x, point_y), tl+1, clors[i], -1)
121
+
122
+ tf = max(tl - 1, 1) # font thickness
123
+ label = str(conf)[:5]
124
+ cv2.putText(img, label, (x1, y1 - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
125
+ return img
126
+
127
+
128
+ def make_divisible(x, divisor):
129
+ # Returns x evenly divisible by divisor
130
+ return (x // divisor) * divisor
131
+
132
+
133
+ def non_max_suppression_face(prediction, conf_thres=0.5, iou_thres=0.45, classes=None, agnostic=False, labels=()):
134
+ """Performs Non-Maximum Suppression (NMS) on inference results
135
+ Returns:
136
+ detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
137
+ """
138
+
139
+ nc = prediction.shape[2] - 15 # number of classes
140
+ xc = prediction[..., 4] > conf_thres # candidates
141
+
142
+ # Settings
143
+ min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
144
+ # time_limit = 10.0 # seconds to quit after
145
+ redundant = True # require redundant detections
146
+ multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
147
+ merge = False # use merge-NMS
148
+
149
+ # t = time.time()
150
+ output = [torch.zeros((0, 16), device=prediction.device)] * prediction.shape[0]
151
+ for xi, x in enumerate(prediction): # image index, image inference
152
+ # Apply constraints
153
+ # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
154
+ x = x[xc[xi]] # confidence
155
+
156
+ # Cat apriori labels if autolabelling
157
+ if labels and len(labels[xi]):
158
+ l = labels[xi]
159
+ v = torch.zeros((len(l), nc + 15), device=x.device)
160
+ v[:, :4] = l[:, 1:5] # box
161
+ v[:, 4] = 1.0 # conf
162
+ v[range(len(l)), l[:, 0].long() + 15] = 1.0 # cls
163
+ x = torch.cat((x, v), 0)
164
+
165
+ # If none remain process next image
166
+ if not x.shape[0]:
167
+ continue
168
+
169
+ # Compute conf
170
+ x[:, 15:] *= x[:, 4:5] # conf = obj_conf * cls_conf
171
+
172
+ # Box (center x, center y, width, height) to (x1, y1, x2, y2)
173
+ box = xywh2xyxy(x[:, :4])
174
+
175
+ # Detections matrix nx6 (xyxy, conf, landmarks, cls)
176
+ if multi_label:
177
+ i, j = (x[:, 15:] > conf_thres).nonzero(as_tuple=False).T
178
+ x = torch.cat((box[i], x[i, j + 15, None], x[i, 5:15] ,j[:, None].float()), 1)
179
+ else: # best class only
180
+ conf, j = x[:, 15:].max(1, keepdim=True)
181
+ x = torch.cat((box, conf, x[:, 5:15], j.float()), 1)[conf.view(-1) > conf_thres]
182
+
183
+ # Filter by class
184
+ if classes is not None:
185
+ x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
186
+
187
+ # If none remain process next image
188
+ n = x.shape[0] # number of boxes
189
+ if not n:
190
+ continue
191
+
192
+ # Batched NMS
193
+ c = x[:, 15:16] * (0 if agnostic else max_wh) # classes
194
+ boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
195
+ i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
196
+ #if i.shape[0] > max_det: # limit detections
197
+ # i = i[:max_det]
198
+ if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
199
+ # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
200
+ iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
201
+ weights = iou * scores[None] # box weights
202
+ x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
203
+ if redundant:
204
+ i = i[iou.sum(1) > 1] # require redundancy
205
+
206
+ output[xi] = x[i]
207
+ # if (time.time() - t) > time_limit:
208
+ # break # time limit exceeded
209
+
210
+ return output
211
+
212
+
213
+ class DetFace():
214
+ def __init__(self, pt_path, confThreshold=0.5, nmsThreshold=0.45, device='cuda'):
215
+ assert os.path.exists(pt_path)
216
+
217
+ self.inpSize = 416
218
+ self.conf_thres = confThreshold
219
+ self.iou_thres = nmsThreshold
220
+ self.test_device = torch.device(device if torch.cuda.is_available() else "cpu")
221
+ self.model = torch.jit.load(pt_path).to(self.test_device)
222
+ self.last_w = 416
223
+ self.last_h = 416
224
+ self.grids = None
225
+
226
+ @torch.no_grad()
227
+ def detect(self, srcimg):
228
+ # t0=time.time()
229
+
230
+ h0, w0 = srcimg.shape[:2] # orig hw
231
+ r = self.inpSize / min(h0, w0) # resize image to img_size
232
+ h1 = int(h0*r+31)//32*32
233
+ w1 = int(w0*r+31)//32*32
234
+
235
+ img = cv2.resize(srcimg, (w1,h1), interpolation=cv2.INTER_LINEAR)
236
+
237
+ # Convert
238
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # BGR to RGB
239
+
240
+ # Run inference
241
+ img = torch.from_numpy(img).to(self.test_device).permute(2,0,1)
242
+ img = img.float()/255 # uint8 to fp16/32 0-1
243
+ if img.ndimension() == 3:
244
+ img = img.unsqueeze(0)
245
+
246
+ # Inference
247
+ if h1 != self.last_h or w1 != self.last_w or self.grids is None:
248
+ grids = []
249
+ for scale in [8,16,32]:
250
+ ny = h1//scale
251
+ nx = w1//scale
252
+ yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
253
+ grid = torch.stack((xv, yv), 2).view((1,1,ny, nx, 2)).float()
254
+ grids.append(grid.to(self.test_device))
255
+ self.grids = grids
256
+ self.last_w = w1
257
+ self.last_h = h1
258
+
259
+ pred = self.model(img, self.grids).cpu()
260
+
261
+ # Apply NMS
262
+ det = non_max_suppression_face(pred, self.conf_thres, self.iou_thres)[0]
263
+ # Process detections
264
+ # det = pred[0]
265
+ bboxes = np.zeros((det.shape[0], 4))
266
+ kpss = np.zeros((det.shape[0], 5, 2))
267
+ scores = np.zeros((det.shape[0]))
268
+ # gn = torch.tensor([w0, h0, w0, h0]).to(pred) # normalization gain whwh
269
+ # gn_lks = torch.tensor([w0, h0, w0, h0, w0, h0, w0, h0, w0, h0]).to(pred) # normalization gain landmarks
270
+ det = det.cpu().numpy()
271
+
272
+ for j in range(det.shape[0]):
273
+ # xywh = (xyxy2xywh(det[j, :4].view(1, 4)) / gn).view(4).cpu().numpy()
274
+ bboxes[j, 0] = det[j, 0] * w0/w1
275
+ bboxes[j, 1] = det[j, 1] * h0/h1
276
+ bboxes[j, 2] = det[j, 2] * w0/w1 - bboxes[j, 0]
277
+ bboxes[j, 3] = det[j, 3] * h0/h1 - bboxes[j, 1]
278
+ scores[j] = det[j, 4]
279
+ # landmarks = (det[j, 5:15].view(1, 10) / gn_lks).view(5,2).cpu().numpy()
280
+ kpss[j, :, :] = det[j, 5:15].reshape(5, 2) * np.array([[w0/w1,h0/h1]])
281
+ # class_num = det[j, 15].cpu().numpy()
282
+ # orgimg = show_results(orgimg, xywh, conf, landmarks, class_num)
283
+ return bboxes, kpss, scores
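Each row kept by non_max_suppression_face packs 16 values: the box as (x1, y1, x2, y2), the confidence, five (x, y) landmarks, and a class index; detect then rescales boxes and landmarks from the padded inference resolution back to the source image. A minimal sketch of that row layout on dummy data:

import numpy as np

row = np.arange(16, dtype=np.float32)   # stand-in for one post-NMS detection
x1, y1, x2, y2 = row[:4]
conf = row[4]
landmarks = row[5:15].reshape(5, 2)     # five (x, y) keypoints
cls_id = int(row[15])
print(conf, landmarks.shape, cls_id)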
hymm_sp/data_kits/ffmpeg_utils.py ADDED
@@ -0,0 +1,184 @@
1
+ import skvideo
2
+ # assert skvideo.__version__ >= "1.1.11"
3
+ import os
4
+
5
+ import skvideo.io
6
+ import cv2
7
+
8
+ # install the following packages: #
9
+ # conda install -c conda-forge scikit-video ffmpeg #
10
+ import os
11
+ import torch
12
+ import torchvision
13
+ from PIL import Image
14
+ import numpy as np
15
+ from einops import rearrange
16
+
17
+
18
+
19
+ class VideoUtils(object):
20
+ def __init__(self, video_path=None, output_video_path=None, bit_rate='origin', fps=25):
21
+ if video_path is not None:
22
+ meta_data = skvideo.io.ffprobe(video_path)
23
+ # avg_frame_rate = meta_data['video']['@r_frame_rate']
24
+ # a, b = avg_frame_rate.split('/')
25
+ # fps = float(a) / float(b)
26
+ # fps = 25
27
+ codec_name = 'libx264'
28
+ # codec_name = meta_data['video'].get('@codec_name')
29
+ # if codec_name=='hevc':
30
+ # codec_name='h264'
31
+ # profile = meta_data['video'].get('@profile')
32
+ color_space = meta_data['video'].get('@color_space')
33
+ color_transfer = meta_data['video'].get('@color_transfer')
34
+ color_primaries = meta_data['video'].get('@color_primaries')
35
+ color_range = meta_data['video'].get('@color_range')
36
+ pix_fmt = meta_data['video'].get('@pix_fmt')
37
+ if bit_rate=='origin':
38
+ bit_rate = meta_data['video'].get('@bit_rate')
39
+ else:
40
+ bit_rate=None
41
+ if pix_fmt is None:
42
+ pix_fmt = 'yuv420p'
43
+
44
+ reader_output_dict = {'-r': str(fps)}
45
+ writer_input_dict = {'-r': str(fps)}
46
+ writer_output_dict = {'-pix_fmt': pix_fmt, '-r': str(fps), '-vcodec':str(codec_name)}
47
+ # if bit_rate is not None:
48
+ # writer_output_dict['-b:v'] = bit_rate
49
+ writer_output_dict['-crf'] = '17'
50
+
51
+ # if video has alpha channel, convert to bgra, uint16 to process
52
+ if pix_fmt.startswith('yuva'):
53
+ writer_input_dict['-pix_fmt'] = 'bgra64le'
54
+ reader_output_dict['-pix_fmt'] = 'bgra64le'
55
+ elif pix_fmt.endswith('le'):
56
+ writer_input_dict['-pix_fmt'] = 'bgr48le'
57
+ reader_output_dict['-pix_fmt'] = 'bgr48le'
58
+ else:
59
+ writer_input_dict['-pix_fmt'] = 'bgr24'
60
+ reader_output_dict['-pix_fmt'] = 'bgr24'
61
+
62
+ if color_range is not None:
63
+ writer_output_dict['-color_range'] = color_range
64
+ writer_input_dict['-color_range'] = color_range
65
+ if color_space is not None:
66
+ writer_output_dict['-colorspace'] = color_space
67
+ writer_input_dict['-colorspace'] = color_space
68
+ if color_primaries is not None:
69
+ writer_output_dict['-color_primaries'] = color_primaries
70
+ writer_input_dict['-color_primaries'] = color_primaries
71
+ if color_transfer is not None:
72
+ writer_output_dict['-color_trc'] = color_transfer
73
+ writer_input_dict['-color_trc'] = color_transfer
74
+
75
+ writer_output_dict['-sws_flags'] = 'full_chroma_int+bitexact+accurate_rnd'
76
+ reader_output_dict['-sws_flags'] = 'full_chroma_int+bitexact+accurate_rnd'
77
+ # writer_input_dict['-pix_fmt'] = 'bgr48le'
78
+ # reader_output_dict = {'-pix_fmt': 'bgr48le'}
79
+
80
+ # -s 1920x1080
81
+ # writer_input_dict['-s'] = '1920x1080'
82
+ # writer_output_dict['-s'] = '1920x1080'
83
+ # writer_input_dict['-s'] = '1080x1920'
84
+ # writer_output_dict['-s'] = '1080x1920'
85
+
86
+ print(writer_input_dict)
87
+ print(writer_output_dict)
88
+
89
+ self.reader = skvideo.io.FFmpegReader(video_path, outputdict=reader_output_dict)
90
+ else:
91
+
92
+ # fps = 25
93
+ codec_name = 'libx264'
94
+ bit_rate=None
95
+ pix_fmt = 'yuv420p'
96
+
97
+ reader_output_dict = {'-r': str(fps)}
98
+ writer_input_dict = {'-r': str(fps)}
99
+ writer_output_dict = {'-pix_fmt': pix_fmt, '-r': str(fps), '-vcodec':str(codec_name)}
100
+ # if bit_rate is not None:
101
+ # writer_output_dict['-b:v'] = bit_rate
102
+ writer_output_dict['-crf'] = '17'
103
+
104
+ # if video has alpha channel, convert to bgra, uint16 to process
105
+ if pix_fmt.startswith('yuva'):
106
+ writer_input_dict['-pix_fmt'] = 'bgra64le'
107
+ reader_output_dict['-pix_fmt'] = 'bgra64le'
108
+ elif pix_fmt.endswith('le'):
109
+ writer_input_dict['-pix_fmt'] = 'bgr48le'
110
+ reader_output_dict['-pix_fmt'] = 'bgr48le'
111
+ else:
112
+ writer_input_dict['-pix_fmt'] = 'bgr24'
113
+ reader_output_dict['-pix_fmt'] = 'bgr24'
114
+
115
+ writer_output_dict['-sws_flags'] = 'full_chroma_int+bitexact+accurate_rnd'
116
+ print(writer_input_dict)
117
+ print(writer_output_dict)
118
+
119
+ if output_video_path is not None:
120
+ self.writer = skvideo.io.FFmpegWriter(output_video_path, inputdict=writer_input_dict, outputdict=writer_output_dict, verbosity=1)
121
+
122
+ def getframes(self):
123
+ return self.reader.nextFrame()
124
+
125
+ def writeframe(self, frame):
126
+ if frame is None:
127
+ self.writer.close()
128
+ else:
129
+ self.writer.writeFrame(frame)
130
+
131
+
132
+ def save_videos_from_pil(pil_images, path, fps=8):
133
+ save_fmt = ".mp4"
134
+ os.makedirs(os.path.dirname(path), exist_ok=True)
135
+ width, height = pil_images[0].size
136
+
137
+ if save_fmt == ".mp4":
138
+ video_cap = VideoUtils(output_video_path=path, fps=fps)
139
+ for pil_image in pil_images:
140
+ image_cv2 = np.array(pil_image)[:,:,[2,1,0]]
141
+ video_cap.writeframe(image_cv2)
142
+ video_cap.writeframe(None)
143
+
144
+ elif save_fmt == ".gif":
145
+ pil_images[0].save(
146
+ fp=path,
147
+ format="GIF",
148
+ append_images=pil_images[1:],
149
+ save_all=True,
150
+ duration=(1 / fps * 1000),
151
+ loop=0,
152
+ optimize=False,
153
+ lossless=True
154
+ )
155
+ else:
156
+ raise ValueError("Unsupported file type. Use .mp4 or .gif.")
157
+
158
+
159
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
160
+ videos = rearrange(videos, "b c t h w -> t b c h w")
161
+ height, width = videos.shape[-2:]
162
+ outputs = []
163
+
164
+ for x in videos:
165
+ x = torchvision.utils.make_grid(x, nrow=n_rows) # (c h w)
166
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) # (h w c)
167
+ if rescale:
168
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
169
+ x = (x * 255).numpy().astype(np.uint8)
170
+ x = Image.fromarray(x)
171
+
172
+ outputs.append(x)
173
+
174
+ os.makedirs(os.path.dirname(path), exist_ok=True)
175
+
176
+ save_videos_from_pil(outputs, path, fps)
177
+
178
+ def save_video(video, path: str, rescale=False, n_rows=6, fps=8):
179
+ outputs = []
180
+ for x in video:
181
+ x = Image.fromarray(x)
182
+ outputs.append(x)
183
+
184
+ save_videos_from_pil(outputs, path, fps)
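save_videos_from_pil routes a list of RGB PIL frames through VideoUtils, which writes an H.264 MP4 via scikit-video/ffmpeg. A hedged usage sketch; the output path is arbitrary, and scikit-video plus ffmpeg must be installed as the comment at the top of this file notes:

from PIL import Image
from hymm_sp.data_kits.ffmpeg_utils import save_videos_from_pil

# 25 synthetic frames, written at 25 fps
frames = [Image.new("RGB", (512, 512), (i, i, i)) for i in range(0, 250, 10)]
save_videos_from_pil(frames, "results/demo.mp4", fps=25)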
hymm_sp/diffusion/__init__.py ADDED
@@ -0,0 +1,30 @@
1
+ from .pipelines import HunyuanVideoAudioPipeline
2
+ from .schedulers import FlowMatchDiscreteScheduler
3
+
4
+
5
+ def load_diffusion_pipeline(args, rank, vae, text_encoder, text_encoder_2, model, scheduler=None,
6
+ device=None, progress_bar_config=None):
7
+ """ Load the denoising scheduler for inference. """
8
+ if scheduler is None:
9
+ scheduler = FlowMatchDiscreteScheduler(shift=args.flow_shift_eval_video, reverse=args.flow_reverse, solver=args.flow_solver, )
10
+
11
+ # Only enable progress bar for rank 0
12
+ progress_bar_config = progress_bar_config or {'leave': True, 'disable': rank != 0}
13
+
14
+ pipeline = HunyuanVideoAudioPipeline(vae=vae,
15
+ text_encoder=text_encoder,
16
+ text_encoder_2=text_encoder_2,
17
+ transformer=model,
18
+ scheduler=scheduler,
19
+ # safety_checker=None,
20
+ # feature_extractor=None,
21
+ # requires_safety_checker=False,
22
+ progress_bar_config=progress_bar_config,
23
+ args=args,
24
+ )
25
+ if args.cpu_offload: # avoid oom
26
+ pass
27
+ else:
28
+ pipeline = pipeline.to(device)
29
+
30
+ return pipeline
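When no scheduler is passed in, load_diffusion_pipeline builds a FlowMatchDiscreteScheduler from the CLI arguments and enables the progress bar only on rank 0. A hedged sketch of just the scheduler construction; the concrete values stand in for args.flow_shift_eval_video, args.flow_reverse and args.flow_solver and are assumptions, not the project defaults:

from hymm_sp.diffusion import FlowMatchDiscreteScheduler

scheduler = FlowMatchDiscreteScheduler(
    shift=7.0,        # args.flow_shift_eval_video (value assumed)
    reverse=True,     # args.flow_reverse (value assumed)
    solver="euler",   # args.flow_solver (value assumed)
)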
hymm_sp/diffusion/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (921 Bytes).
 
hymm_sp/diffusion/pipelines/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .pipeline_hunyuan_video_audio import HunyuanVideoAudioPipeline
hymm_sp/diffusion/pipelines/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (319 Bytes).
 
hymm_sp/diffusion/pipelines/__pycache__/pipeline_hunyuan_video_audio.cpython-310.pyc ADDED
Binary file (38 kB).
 
hymm_sp/diffusion/pipelines/pipeline_hunyuan_video_audio.py ADDED
@@ -0,0 +1,1363 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
16
+ # Modified from diffusers==0.29.2
17
+ #
18
+ # ==============================================================================
19
+ import inspect
20
+ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
21
+ import numpy as np
22
+ import torch
23
+ from packaging import version
24
+ from diffusers.utils import BaseOutput
25
+ from dataclasses import dataclass
26
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
27
+ from diffusers.configuration_utils import FrozenDict
28
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
29
+ from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
30
+ from diffusers.models import AutoencoderKL, ImageProjection
31
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
32
+ from diffusers.schedulers import KarrasDiffusionSchedulers
33
+ from diffusers.utils import (
34
+ USE_PEFT_BACKEND,
35
+ deprecate,
36
+ logging,
37
+ replace_example_docstring,
38
+ scale_lora_layers,
39
+ unscale_lora_layers,
40
+ )
41
+ from diffusers.utils.torch_utils import randn_tensor
42
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
43
+
44
+ from hymm_sp.constants import PRECISION_TO_TYPE
45
+ from hymm_sp.vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
46
+ from hymm_sp.text_encoder import TextEncoder
47
+ from einops import rearrange
48
+ from ...modules import HYVideoDiffusionTransformer
49
+
50
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
51
+
52
+ EXAMPLE_DOC_STRING = """"""
53
+
54
+
55
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
56
+ """
57
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
58
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
59
+ """
60
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
61
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
62
+ # rescale the results from guidance (fixes overexposure)
63
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
64
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
65
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
66
+ return noise_cfg
67
+
68
+
69
+ def retrieve_timesteps(
70
+ scheduler,
71
+ num_inference_steps: Optional[int] = None,
72
+ device: Optional[Union[str, torch.device]] = None,
73
+ timesteps: Optional[List[int]] = None,
74
+ sigmas: Optional[List[float]] = None,
75
+ **kwargs,
76
+ ):
77
+ """
78
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
79
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
80
+
81
+ Args:
82
+ scheduler (`SchedulerMixin`):
83
+ The scheduler to get timesteps from.
84
+ num_inference_steps (`int`):
85
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
86
+ must be `None`.
87
+ device (`str` or `torch.device`, *optional*):
88
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
89
+ timesteps (`List[int]`, *optional*):
90
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
91
+ `num_inference_steps` and `sigmas` must be `None`.
92
+ sigmas (`List[float]`, *optional*):
93
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
94
+ `num_inference_steps` and `timesteps` must be `None`.
95
+
96
+ Returns:
97
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
98
+ second element is the number of inference steps.
99
+ """
100
+ if timesteps is not None and sigmas is not None:
101
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
102
+ if timesteps is not None:
103
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
104
+ if not accepts_timesteps:
105
+ raise ValueError(
106
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
107
+ f" timestep schedules. Please check whether you are using the correct scheduler."
108
+ )
109
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
110
+ timesteps = scheduler.timesteps
111
+ num_inference_steps = len(timesteps)
112
+ elif sigmas is not None:
113
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
114
+ if not accept_sigmas:
115
+ raise ValueError(
116
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
117
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
118
+ )
119
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
120
+ timesteps = scheduler.timesteps
121
+ num_inference_steps = len(timesteps)
122
+ else:
123
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
124
+ timesteps = scheduler.timesteps
125
+ return timesteps, num_inference_steps
126
+
127
+ @dataclass
128
+ class HunyuanVideoPipelineOutput(BaseOutput):
129
+ videos: Union[torch.Tensor, np.ndarray]
130
+
131
+
132
+ class HunyuanVideoAudioPipeline(DiffusionPipeline):
133
+ r"""
134
+ Pipeline for text-to-video generation using HunyuanVideo.
135
+
136
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
137
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
138
+
139
+ Args:
140
+ vae ([`AutoencoderKL`]):
141
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
142
+ text_encoder ([`TextEncoder`]):
143
+ Frozen text-encoder.
144
+ text_encoder_2 ([`TextEncoder`]):
145
+ Frozen text-encoder_2.
146
+ transformer ([`HYVideoDiffusionTransformer`]):
147
+ A `HYVideoDiffusionTransformer` to denoise the encoded video latents.
148
+ scheduler ([`SchedulerMixin`]):
149
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
150
+ """
151
+
152
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
153
+ _optional_components = ["text_encoder_2"]
154
+ _exclude_from_cpu_offload = ["transformer"]
155
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
156
+
157
+ def __init__(
158
+ self,
159
+ vae: AutoencoderKL,
160
+ text_encoder: TextEncoder,
161
+ transformer: HYVideoDiffusionTransformer,
162
+ scheduler: KarrasDiffusionSchedulers,
163
+ text_encoder_2: Optional[TextEncoder] = None,
164
+ progress_bar_config: Dict[str, Any] = None,
165
+ args=None,
166
+ ):
167
+ super().__init__()
168
+
169
+ # ==========================================================================================
170
+ if progress_bar_config is None:
171
+ progress_bar_config = {}
172
+ if not hasattr(self, '_progress_bar_config'):
173
+ self._progress_bar_config = {}
174
+ self._progress_bar_config.update(progress_bar_config)
175
+
176
+ self.args = args
177
+ # ==========================================================================================
178
+
179
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
180
+ deprecation_message = (
181
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
182
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
183
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
184
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
185
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
186
+ " file"
187
+ )
188
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
189
+ new_config = dict(scheduler.config)
190
+ new_config["steps_offset"] = 1
191
+ scheduler._internal_dict = FrozenDict(new_config)
192
+
193
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
194
+ deprecation_message = (
195
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
196
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
197
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
198
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
199
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
200
+ )
201
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
202
+ new_config = dict(scheduler.config)
203
+ new_config["clip_sample"] = False
204
+ scheduler._internal_dict = FrozenDict(new_config)
205
+
206
+ self.register_modules(
207
+ vae=vae,
208
+ text_encoder=text_encoder,
209
+ transformer=transformer,
210
+ scheduler=scheduler,
211
+ text_encoder_2=text_encoder_2
212
+ )
213
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
214
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
215
+
216
+ def encode_prompt(
217
+ self,
218
+ prompt,
219
+ name,
220
+ device,
221
+ num_videos_per_prompt,
222
+ do_classifier_free_guidance,
223
+ negative_prompt=None,
224
+ pixel_value_llava: Optional[torch.Tensor] = None,
225
+ uncond_pixel_value_llava: Optional[torch.Tensor] = None,
226
+ prompt_embeds: Optional[torch.Tensor] = None,
227
+ attention_mask: Optional[torch.Tensor] = None,
228
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
229
+ negative_attention_mask: Optional[torch.Tensor] = None,
230
+ lora_scale: Optional[float] = None,
231
+ clip_skip: Optional[int] = None,
232
+ text_encoder: Optional[TextEncoder] = None,
233
+ data_type: Optional[str] = "image",
234
+ ):
235
+ r"""
236
+ Encodes the prompt into text encoder hidden states.
237
+
238
+ Args:
239
+ prompt (`str` or `List[str]`, *optional*):
240
+ prompt to be encoded
241
+ device: (`torch.device`):
242
+ torch device
243
+ num_videos_per_prompt (`int`):
244
+ number of images that should be generated per prompt
245
+ do_classifier_free_guidance (`bool`):
246
+ whether to use classifier free guidance or not
247
+ negative_prompt (`str` or `List[str]`, *optional*):
248
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
249
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
250
+ less than `1`).
251
+ pixel_value_llava (`torch.Tensor`, *optional*):
252
+ The image tensor for llava.
253
+ uncond_pixel_value_llava (`torch.Tensor`, *optional*):
254
+ The image tensor for llava. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
255
+ less than `1`).
256
+ prompt_embeds (`torch.Tensor`, *optional*):
257
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
258
+ provided, text embeddings will be generated from `prompt` input argument.
259
+ attention_mask (`torch.Tensor`, *optional*):
260
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
261
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
262
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
263
+ argument.
264
+ negative_attention_mask (`torch.Tensor`, *optional*):
265
+ lora_scale (`float`, *optional*):
266
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
267
+ clip_skip (`int`, *optional*):
268
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
269
+ the output of the pre-final layer will be used for computing the prompt embeddings.
270
+ text_encoder (TextEncoder, *optional*):
271
+ """
272
+ if text_encoder is None:
273
+ text_encoder = self.text_encoder
274
+
275
+ # set lora scale so that monkey patched LoRA
276
+ # function of text encoder can correctly access it
277
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
278
+ self._lora_scale = lora_scale
279
+
280
+ # dynamically adjust the LoRA scale
281
+ if not USE_PEFT_BACKEND:
282
+ adjust_lora_scale_text_encoder(text_encoder.model, lora_scale)
283
+ else:
284
+ scale_lora_layers(text_encoder.model, lora_scale)
285
+
286
+ if prompt is not None and isinstance(prompt, str):
287
+ batch_size = 1
288
+ elif prompt is not None and isinstance(prompt, list):
289
+ batch_size = len(prompt)
290
+ else:
291
+ batch_size = prompt_embeds.shape[0]
292
+
293
+ if prompt_embeds is None:
294
+ # textual inversion: process multi-vector tokens if necessary
295
+ if isinstance(self, TextualInversionLoaderMixin):
296
+ prompt = self.maybe_convert_prompt(prompt, text_encoder.tokenizer)
297
+ text_inputs = text_encoder.text2tokens(prompt, data_type=data_type, name=name)
298
+
299
+ if pixel_value_llava is not None:
300
+ text_inputs['pixel_value_llava'] = pixel_value_llava
301
+ text_inputs['attention_mask'] = torch.cat([text_inputs['attention_mask'], torch.ones((1, 575 * len(pixel_value_llava))).to(text_inputs['attention_mask'])], dim=1)
302
+
303
+ if clip_skip is None:
304
+ prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type)
305
+ prompt_embeds = prompt_outputs.hidden_state
306
+ else:
307
+ prompt_outputs = text_encoder.encode(text_inputs, output_hidden_states=True, data_type=data_type)
308
+ # Access the `hidden_states` first, that contains a tuple of
309
+ # all the hidden states from the encoder layers. Then index into
310
+ # the tuple to access the hidden states from the desired layer.
311
+ prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)]
312
+ # We also need to apply the final LayerNorm here to not mess with the
313
+ # representations. The `last_hidden_states` that we typically use for
314
+ # obtaining the final prompt representations passes through the LayerNorm
315
+ # layer.
316
+ prompt_embeds = text_encoder.model.text_model.final_layer_norm(prompt_embeds)
317
+
318
+ attention_mask = prompt_outputs.attention_mask
319
+ if attention_mask is not None:
320
+ attention_mask = attention_mask.to(device)
321
+ bs_embed, seq_len = attention_mask.shape
322
+ attention_mask = attention_mask.repeat(1, num_videos_per_prompt)
323
+ attention_mask = attention_mask.view(bs_embed * num_videos_per_prompt, seq_len)
324
+
325
+ if text_encoder is not None:
326
+ prompt_embeds_dtype = text_encoder.dtype
327
+ elif self.transformer is not None:
328
+ prompt_embeds_dtype = self.transformer.dtype
329
+ else:
330
+ prompt_embeds_dtype = prompt_embeds.dtype
331
+
332
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
333
+
334
+ if prompt_embeds.ndim == 2:
335
+ bs_embed, _ = prompt_embeds.shape
336
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
337
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
338
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1)
339
+ else:
340
+ bs_embed, seq_len, _ = prompt_embeds.shape
341
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
342
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
343
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
344
+
345
+ # get unconditional embeddings for classifier free guidance
346
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
347
+ uncond_tokens: List[str]
348
+ if negative_prompt is None:
349
+ uncond_tokens = [""] * batch_size
350
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
351
+ raise TypeError(
352
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
353
+ f" {type(prompt)}."
354
+ )
355
+ elif isinstance(negative_prompt, str):
356
+ uncond_tokens = [negative_prompt]
357
+ elif batch_size != len(negative_prompt):
358
+ raise ValueError(
359
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
360
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
361
+ " the batch size of `prompt`."
362
+ )
363
+ else:
364
+ uncond_tokens = negative_prompt
365
+
366
+ # textual inversion: process multi-vector tokens if necessary
367
+ if isinstance(self, TextualInversionLoaderMixin):
368
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, text_encoder.tokenizer)
369
+ uncond_input = text_encoder.text2tokens(uncond_tokens, data_type=data_type)
370
+ if uncond_pixel_value_llava is not None:
371
+ uncond_input['pixel_value_llava'] = uncond_pixel_value_llava
372
+ uncond_input['attention_mask'] = torch.cat([uncond_input['attention_mask'], torch.ones((1, 575 * len(uncond_pixel_value_llava))).to(uncond_input['attention_mask'])], dim=1)
373
+
374
+ negative_prompt_outputs = text_encoder.encode(uncond_input, data_type=data_type)
375
+ negative_prompt_embeds = negative_prompt_outputs.hidden_state
376
+
377
+ negative_attention_mask = negative_prompt_outputs.attention_mask
378
+ if negative_attention_mask is not None:
379
+ negative_attention_mask = negative_attention_mask.to(device)
380
+ _, seq_len = negative_attention_mask.shape
381
+ negative_attention_mask = negative_attention_mask.repeat(1, num_videos_per_prompt)
382
+ negative_attention_mask = negative_attention_mask.view(batch_size * num_videos_per_prompt, seq_len)
383
+
384
+ if do_classifier_free_guidance:
385
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
386
+ seq_len = negative_prompt_embeds.shape[1]
387
+
388
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
389
+
390
+ if negative_prompt_embeds.ndim == 2:
391
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt)
392
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, -1)
393
+ else:
394
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
395
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
396
+
397
+ if text_encoder is not None:
398
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
399
+ # Retrieve the original scale by scaling back the LoRA layers
400
+ unscale_lora_layers(text_encoder.model, lora_scale)
401
+
402
+ return prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask
403
+
404
+ def encode_prompt_audio_text_base(
405
+ self,
406
+ prompt,
407
+ uncond_prompt,
408
+ pixel_value_llava,
409
+ uncond_pixel_value_llava,
410
+ device,
411
+ num_images_per_prompt,
412
+ do_classifier_free_guidance,
413
+ negative_prompt=None,
414
+ prompt_embeds: Optional[torch.Tensor] = None,
415
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
416
+ lora_scale: Optional[float] = None,
417
+ clip_skip: Optional[int] = None,
418
+ text_encoder: Optional[TextEncoder] = None,
419
+ data_type: Optional[str] = "image",
420
+ ):
421
+ if text_encoder is None:
422
+ text_encoder = self.text_encoder
423
+
424
+ # set lora scale so that monkey patched LoRA
425
+ # function of text encoder can correctly access it
426
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
427
+ self._lora_scale = lora_scale
428
+
429
+ # dynamically adjust the LoRA scale
430
+ if not USE_PEFT_BACKEND:
431
+ adjust_lora_scale_text_encoder(text_encoder.model, lora_scale)
432
+ else:
433
+ scale_lora_layers(text_encoder.model, lora_scale)
434
+
435
+ if prompt is not None and isinstance(prompt, str):
436
+ batch_size = 1
437
+ elif prompt is not None and isinstance(prompt, list):
438
+ batch_size = len(prompt)
439
+ else:
440
+ batch_size = prompt_embeds.shape[0]
441
+
442
+ prompt_embeds = None
443
+
444
+ if prompt_embeds is None:
445
+ # textual inversion: process multi-vector tokens if necessary
446
+ if isinstance(self, TextualInversionLoaderMixin):
447
+ prompt = self.maybe_convert_prompt(prompt, text_encoder.tokenizer)
448
+ text_inputs = text_encoder.text2tokens(prompt, data_type=data_type) # data_type: video, text_inputs: {'input_ids', 'attention_mask'}
449
+
450
+ text_keys = ['input_ids', 'attention_mask']
451
+
452
+ if pixel_value_llava is not None:
453
+ text_inputs['pixel_value_llava'] = pixel_value_llava
454
+ text_inputs['attention_mask'] = torch.cat([text_inputs['attention_mask'], torch.ones((1, 575)).to(text_inputs['attention_mask'])], dim=1)
455
+
456
+
457
+ if clip_skip is None:
458
+ prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type)
459
+ prompt_embeds = prompt_outputs.hidden_state
460
+ else:
461
+ prompt_outputs = text_encoder.encode(text_inputs, output_hidden_states=True, data_type=data_type)
462
+ # Access the `hidden_states` first, that contains a tuple of
463
+ # all the hidden states from the encoder layers. Then index into
464
+ # the tuple to access the hidden states from the desired layer.
465
+ prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)]
466
+ # We also need to apply the final LayerNorm here to not mess with the
467
+ # representations. The `last_hidden_states` that we typically use for
468
+ # obtaining the final prompt representations passes through the LayerNorm
469
+ # layer.
470
+ prompt_embeds = text_encoder.model.text_model.final_layer_norm(prompt_embeds)
471
+
472
+ attention_mask = prompt_outputs.attention_mask
473
+ if attention_mask is not None:
474
+ attention_mask = attention_mask.to(device)
475
+ bs_embed, seq_len = attention_mask.shape
476
+ attention_mask = attention_mask.repeat(1, num_images_per_prompt)
477
+ attention_mask = attention_mask.view(bs_embed * num_images_per_prompt, seq_len)
478
+
479
+ if text_encoder is not None:
480
+ prompt_embeds_dtype = text_encoder.dtype
481
+ elif self.unet is not None:
482
+ prompt_embeds_dtype = self.unet.dtype
483
+ else:
484
+ prompt_embeds_dtype = prompt_embeds.dtype
485
+
486
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
487
+
488
+ if prompt_embeds.ndim == 2:
489
+ bs_embed, _ = prompt_embeds.shape
490
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
491
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
492
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, -1)
493
+ else:
494
+ bs_embed, seq_len, _ = prompt_embeds.shape
495
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
496
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
497
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
498
+
499
+ # get unconditional embeddings for classifier free guidance
500
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
501
+ uncond_tokens: List[str]
502
+ if negative_prompt is None:
503
+ uncond_tokens = [""] * batch_size
504
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
505
+ raise TypeError(
506
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
507
+ f" {type(prompt)}."
508
+ )
509
+ elif isinstance(negative_prompt, str):
510
+ uncond_tokens = [negative_prompt]
511
+ elif batch_size != len(negative_prompt):
512
+ raise ValueError(
513
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
514
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
515
+ " the batch size of `prompt`."
516
+ )
517
+ else:
518
+ uncond_tokens = negative_prompt
519
+
520
+ # textual inversion: process multi-vector tokens if necessary
521
+ if isinstance(self, TextualInversionLoaderMixin):
522
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, text_encoder.tokenizer)
523
+ # max_length = prompt_embeds.shape[1]
524
+ uncond_input = text_encoder.text2tokens(uncond_tokens, data_type=data_type)
525
+
526
+ # if hasattr(text_encoder.model.config, "use_attention_mask") and text_encoder.model.config.use_attention_mask:
527
+ # attention_mask = uncond_input.attention_mask.to(device)
528
+ # else:
529
+ # attention_mask = None
530
+ if uncond_pixel_value_llava is not None:
531
+ uncond_input['pixel_value_llava'] = uncond_pixel_value_llava
532
+ uncond_input['attention_mask'] = torch.cat([uncond_input['attention_mask'], torch.ones((1, 575)).to(uncond_input['attention_mask'])], dim=1)
533
+
534
+ negative_prompt_outputs = text_encoder.encode(uncond_input, data_type=data_type)
535
+ negative_prompt_embeds = negative_prompt_outputs.hidden_state
536
+
537
+ negative_attention_mask = negative_prompt_outputs.attention_mask
538
+ if negative_attention_mask is not None:
539
+ negative_attention_mask = negative_attention_mask.to(device)
540
+ _, seq_len = negative_attention_mask.shape
541
+ negative_attention_mask = negative_attention_mask.repeat(1, num_images_per_prompt)
542
+ negative_attention_mask = negative_attention_mask.view(batch_size * num_images_per_prompt, seq_len)
543
+
544
+ if do_classifier_free_guidance:
545
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
546
+ seq_len = negative_prompt_embeds.shape[1]
547
+
548
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
549
+
550
+ if negative_prompt_embeds.ndim == 2:
551
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
552
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
553
+ else:
554
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
555
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
556
+
557
+ if text_encoder is not None:
558
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
559
+ # Retrieve the original scale by scaling back the LoRA layers
560
+ unscale_lora_layers(text_encoder.model, lora_scale)
561
+
562
+ return prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask
563
+
564
+ def decode_latents(self, latents, enable_tiling=True):
565
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
566
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
567
+
568
+ latents = 1 / self.vae.config.scaling_factor * latents
569
+ if enable_tiling:
570
+ self.vae.enable_tiling()
571
+ image = self.vae.decode(latents, return_dict=False)[0]
572
+ self.vae.disable_tiling()
573
+ else:
574
+ image = self.vae.decode(latents, return_dict=False)[0]
575
+ image = (image / 2 + 0.5).clamp(0, 1)
576
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
577
+ if image.ndim==4: image = image.cpu().permute(0, 2, 3, 1).float()
578
+ else: image = image.cpu().float()
579
+ return image
580
+
581
+ def prepare_extra_func_kwargs(self, func, kwargs):
582
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
583
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
584
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
585
+ # and should be between [0, 1]
586
+ extra_step_kwargs = {}
587
+
588
+ for k, v in kwargs.items():
589
+ accepts = k in set(inspect.signature(func).parameters.keys())
590
+ if accepts:
591
+ extra_step_kwargs[k] = v
592
+ return extra_step_kwargs
593
+
594
+ def check_inputs(
595
+ self,
596
+ prompt,
597
+ height,
598
+ width,
599
+ frame,
600
+ callback_steps,
601
+ pixel_value_llava=None,
602
+ uncond_pixel_value_llava=None,
603
+ negative_prompt=None,
604
+ prompt_embeds=None,
605
+ negative_prompt_embeds=None,
606
+ callback_on_step_end_tensor_inputs=None,
607
+ vae_ver='88-4c-sd'
608
+ ):
609
+ if height % 8 != 0 or width % 8 != 0:
610
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
611
+
612
+ if frame is not None:
613
+ if '884' in vae_ver:
614
+ if frame!=1 and (frame-1)%4!=0:
615
+ raise ValueError(f'`frame` has to be 1 or a multiple of 4 but is {frame}.')
616
+ elif '888' in vae_ver:
617
+ if frame!=1 and (frame-1)%8!=0:
618
+ raise ValueError(f'`frame` has to be 1 or a multiple of 8 but is {frame}.')
619
+
620
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
621
+ raise ValueError(
622
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
623
+ f" {type(callback_steps)}."
624
+ )
625
+ if callback_on_step_end_tensor_inputs is not None and not all(
626
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
627
+ ):
628
+ raise ValueError(
629
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
630
+ )
631
+
632
+ if prompt is not None and prompt_embeds is not None:
633
+ raise ValueError(
634
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
635
+ " only forward one of the two."
636
+ )
637
+ elif prompt is None and prompt_embeds is None:
638
+ raise ValueError(
639
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
640
+ )
641
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
642
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
643
+
644
+ if negative_prompt is not None and negative_prompt_embeds is not None:
645
+ raise ValueError(
646
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
647
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
648
+ )
649
+
650
+ if pixel_value_llava is not None and uncond_pixel_value_llava is not None:
651
+ if len(pixel_value_llava) != len(uncond_pixel_value_llava):
652
+ raise ValueError(
653
+ "`pixel_value_llava` and `uncond_pixel_value_llava` must have the same length when passed directly, but"
654
+ f" got: `pixel_value_llava` {len(pixel_value_llava)} != `uncond_pixel_value_llava`"
655
+ f" {len(uncond_pixel_value_llava)}."
656
+ )
657
+
658
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
659
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
660
+ raise ValueError(
661
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
662
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
663
+ f" {negative_prompt_embeds.shape}."
664
+ )
665
+
666
+ def get_timesteps(self, num_inference_steps, strength, device):
667
+ # get the original timestep using init_timestep
668
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
669
+
670
+ t_start = max(num_inference_steps - init_timestep, 0)
671
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
672
+ if hasattr(self.scheduler, "set_begin_index"):
673
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
674
+
675
+ return timesteps.to(device), num_inference_steps - t_start
676
+
677
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, frame, dtype, device, generator, latents=None, ref_latents=None, timestep=None):
678
+ shape = (
679
+ batch_size,
680
+ num_channels_latents,
681
+ frame,
682
+ int(height) // self.vae_scale_factor,
683
+ int(width) // self.vae_scale_factor,
684
+ )
685
+ if isinstance(generator, list) and len(generator) != batch_size:
686
+ raise ValueError(
687
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
688
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
689
+ )
690
+
691
+ if latents is None:
692
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
693
+ else:
694
+ latents = latents.to(device)
695
+
696
+
697
+ if timestep is not None:
698
+ init_latents = ref_latents.clone().repeat(1,1,frame,1,1).to(device).to(dtype)
699
+ latents = latents
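+ # Note: `init_latents` above is built from the reference latents but is not used below;
+ # the sampled noise (or the user-provided `latents`) is kept unchanged.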
700
+
701
+ # Check existence to make it compatible with FlowMatchEulerDiscreteScheduler
702
+ if hasattr(self.scheduler, "init_noise_sigma"):
703
+ latents = latents * self.scheduler.init_noise_sigma
704
+
705
+ return latents
706
+
707
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
708
+ def get_guidance_scale_embedding(
709
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
710
+ ) -> torch.Tensor:
711
+ """
712
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
713
+
714
+ Args:
715
+ w (`torch.Tensor`):
716
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
717
+ embedding_dim (`int`, *optional*, defaults to 512):
718
+ Dimension of the embeddings to generate.
719
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
720
+ Data type of the generated embeddings.
721
+
722
+ Returns:
723
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
724
+ """
725
+ assert len(w.shape) == 1
726
+ w = w * 1000.0
727
+
728
+ half_dim = embedding_dim // 2
729
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
730
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
731
+ emb = w.to(dtype)[:, None] * emb[None, :]
732
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
733
+ if embedding_dim % 2 == 1: # zero pad
734
+ emb = torch.nn.functional.pad(emb, (0, 1))
735
+ assert emb.shape == (w.shape[0], embedding_dim)
736
+ return emb
737
+
738
+ @property
739
+ def guidance_scale(self):
740
+ return self._guidance_scale
741
+
742
+ @property
743
+ def guidance_rescale(self):
744
+ return self._guidance_rescale
745
+
746
+ @property
747
+ def clip_skip(self):
748
+ return self._clip_skip
749
+
750
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
751
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
752
+ # corresponds to doing no classifier free guidance.
753
+ @property
754
+ def do_classifier_free_guidance(self):
755
+ # return self._guidance_scale > 1 and self.transformer.config.time_cond_proj_dim is None
756
+ return self._guidance_scale > 1
757
+
758
+ @property
759
+ def cross_attention_kwargs(self):
760
+ return self._cross_attention_kwargs
761
+
762
+ @property
763
+ def num_timesteps(self):
764
+ return self._num_timesteps
765
+
766
+ @property
767
+ def interrupt(self):
768
+ return self._interrupt
769
+
770
+ @torch.no_grad()
771
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
772
+ def __call__(
773
+ self,
774
+ prompt: Union[str, List[str]],
775
+
776
+ ref_latents: torch.Tensor, # [1, 16, 1, h//8, w//8]
777
+ uncond_ref_latents: torch.Tensor,
778
+ pixel_value_llava: torch.Tensor, # [1, 3, 336, 336]
779
+ uncond_pixel_value_llava: torch.Tensor,
780
+ face_masks: torch.Tensor, # [b f h w]
781
+ audio_prompts: torch.Tensor,
782
+ uncond_audio_prompts: torch.Tensor,
783
+ motion_exp: torch.Tensor,
784
+ motion_pose: torch.Tensor,
785
+ fps: torch.Tensor,
786
+
787
+ height: int,
788
+ width: int,
789
+ frame: int,
790
+ data_type: str = "video",
791
+ num_inference_steps: int = 50,
792
+ timesteps: List[int] = None,
793
+ sigmas: List[float] = None,
794
+ guidance_scale: float = 7.5,
795
+ negative_prompt: Optional[Union[str, List[str]]] = None,
796
+ num_videos_per_prompt: Optional[int] = 1,
797
+ eta: float = 0.0,
798
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
799
+ latents: Optional[torch.Tensor] = None,
800
+ prompt_embeds: Optional[torch.Tensor] = None,
801
+ attention_mask: Optional[torch.Tensor] = None,
802
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
803
+ negative_attention_mask: Optional[torch.Tensor] = None,
804
+ output_type: Optional[str] = "pil",
805
+ return_dict: bool = True,
806
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
807
+ guidance_rescale: float = 0.0,
808
+ clip_skip: Optional[int] = None,
809
+ callback_on_step_end: Optional[
810
+ Union[
811
+ Callable[[int, int, Dict], None],
812
+ PipelineCallback,
813
+ MultiPipelineCallbacks,
814
+ ]
815
+ ] = None,
816
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
817
+ freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
818
+ vae_ver: str = "88-4c-sd",
819
+ enable_tiling: bool = False,
820
+ n_tokens: Optional[int] = None,
821
+ embedded_guidance_scale: Optional[float] = None,
822
+ **kwargs,
823
+ ):
824
+ r"""
825
+ The call function to the pipeline for generation.
826
+
827
+ Args:
828
+ prompt (`str` or `List[str]`):
829
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
830
+ height (`int`):
831
+ The height in pixels of the generated image.
832
+ width (`int`):
833
+ The width in pixels of the generated image.
834
+ frame (`int`):
835
+ The number of frames in the generated video.
836
+ num_inference_steps (`int`, *optional*, defaults to 50):
837
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
838
+ expense of slower inference.
839
+ timesteps (`List[int]`, *optional*):
840
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
841
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
842
+ passed will be used. Must be in descending order.
843
+ sigmas (`List[float]`, *optional*):
844
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
845
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
846
+ will be used.
847
+ guidance_scale (`float`, *optional*, defaults to 7.5):
848
+ A higher guidance scale value encourages the model to generate images closely linked to the text
849
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
850
+ negative_prompt (`str` or `List[str]`, *optional*):
851
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
852
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
853
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
854
+ The number of images to generate per prompt.
855
+ eta (`float`, *optional*, defaults to 0.0):
856
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
857
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
858
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
859
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
860
+ generation deterministic.
861
+ latents (`torch.Tensor`, *optional*):
862
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
863
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
864
+ tensor is generated by sampling using the supplied random `generator`.
865
+ prompt_embeds (`torch.Tensor`, *optional*):
866
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
867
+ provided, text embeddings are generated from the `prompt` input argument.
868
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
869
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
870
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
871
+
872
+ output_type (`str`, *optional*, defaults to `"pil"`):
873
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
874
+ return_dict (`bool`, *optional*, defaults to `True`):
875
+ Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a
876
+ plain tuple.
877
+ cross_attention_kwargs (`dict`, *optional*):
878
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
879
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
880
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
881
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
882
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
883
+ using zero terminal SNR.
884
+ clip_skip (`int`, *optional*):
885
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
886
+ the output of the pre-final layer will be used for computing the prompt embeddings.
887
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
888
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
889
+ each denoising step during inference, with the following arguments: `callback_on_step_end(self:
890
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
891
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
892
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
893
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
894
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
895
+ `._callback_tensor_inputs` attribute of your pipeline class.
896
+
897
+ Examples:
898
+
899
+ Returns:
900
+ [`~HunyuanVideoPipelineOutput`] or `torch.Tensor`:
901
+ If `return_dict` is `True`, a [`HunyuanVideoPipelineOutput`] is returned; otherwise the
902
+ decoded video frames are returned directly as a float tensor in `[0, 1]`, typically with
903
+ shape `(batch, channels, frames, height, width)`. This pipeline does not produce
904
+ "not-safe-for-work" (nsfw) flags.
905
+ """
906
+ callback = kwargs.pop("callback", None)
907
+ callback_steps = kwargs.pop("callback_steps", None)
908
+ if callback is not None:
909
+ deprecate(
910
+ "callback",
911
+ "1.0.0",
912
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
913
+ )
914
+ if callback_steps is not None:
915
+ deprecate(
916
+ "callback_steps",
917
+ "1.0.0",
918
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
919
+ )
920
+
921
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
922
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
923
+
924
+ cpu_offload = kwargs.get("cpu_offload", 0)
925
+
926
+ # 0. Default height and width to transformer
927
+ # height = height or self.transformer.config.sample_size * self.vae_scale_factor
928
+ # width = width or self.transformer.config.sample_size * self.vae_scale_factor
929
+ # to deal with lora scaling and other possible forward hooks
930
+
931
+ # 1. Check inputs. Raise error if not correct
932
+ self.check_inputs(
933
+ prompt,
934
+ height,
935
+ width,
936
+ frame,
937
+ callback_steps,
938
+ pixel_value_llava,
939
+ uncond_pixel_value_llava,
940
+ negative_prompt,
941
+ prompt_embeds,
942
+ negative_prompt_embeds,
943
+ callback_on_step_end_tensor_inputs,
944
+ vae_ver=vae_ver
945
+ )
946
+
947
+ self._guidance_scale = guidance_scale
948
+ self.start_cfg_scale = guidance_scale
949
+ self._guidance_rescale = guidance_rescale
950
+ self._clip_skip = clip_skip
951
+ self._cross_attention_kwargs = cross_attention_kwargs
952
+ self._interrupt = False
953
+
954
+ # 2. Define call parameters
955
+ if prompt is not None and isinstance(prompt, str):
956
+ batch_size = 1
957
+ elif prompt is not None and isinstance(prompt, list):
958
+ batch_size = len(prompt)
959
+ else:
960
+ batch_size = prompt_embeds.shape[0]
961
+
962
+ device = self._execution_device
963
+
964
+ # 3. Encode input prompt
965
+ lora_scale = (
966
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
967
+ )
968
+
969
+
970
+ # ========== Encode text prompt (image prompt) ==========
971
+ prompt_embeds, negative_prompt_embeds, prompt_mask, negative_prompt_mask = \
972
+ self.encode_prompt_audio_text_base(
973
+ prompt=prompt,
974
+ uncond_prompt=negative_prompt,
975
+ pixel_value_llava=pixel_value_llava,
976
+ uncond_pixel_value_llava=uncond_pixel_value_llava,
977
+ device=device,
978
+ num_images_per_prompt=num_videos_per_prompt,
979
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
980
+ negative_prompt=negative_prompt,
981
+ prompt_embeds=prompt_embeds,
982
+ negative_prompt_embeds=negative_prompt_embeds,
983
+ lora_scale=lora_scale,
984
+ clip_skip=self.clip_skip,
985
+ text_encoder=self.text_encoder,
986
+ data_type=data_type,
987
+ # **kwargs
988
+ )
989
+ if self.text_encoder_2 is not None:
990
+ prompt_embeds_2, negative_prompt_embeds_2, prompt_mask_2, negative_prompt_mask_2 = \
991
+ self.encode_prompt_audio_text_base(
992
+ prompt=prompt,
993
+ uncond_prompt=negative_prompt,
994
+ pixel_value_llava=None,
995
+ uncond_pixel_value_llava=None,
996
+ device=device,
997
+ num_images_per_prompt=num_videos_per_prompt,
998
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
999
+ negative_prompt=negative_prompt,
1000
+ prompt_embeds=None,
1001
+ negative_prompt_embeds=None,
1002
+ lora_scale=lora_scale,
1003
+ clip_skip=self.clip_skip,
1004
+ text_encoder=self.text_encoder_2,
1005
+ # **kwargs
1006
+ )
1007
+ else:
1008
+ prompt_embeds_2 = None
1009
+ negative_prompt_embeds_2 = None
1010
+ prompt_mask_2 = None
1011
+ negative_prompt_mask_2 = None
1012
+
1013
+
1014
+ # For classifier free guidance, we need to do two forward passes.
1015
+ # Here we concatenate the unconditional and text embeddings into a single batch
1016
+ # to avoid doing two forward passes
1017
+ if self.do_classifier_free_guidance:
1018
+ prompt_embeds_input = torch.cat([negative_prompt_embeds, prompt_embeds])
1019
+ if prompt_mask is not None:
1020
+ prompt_mask_input = torch.cat([negative_prompt_mask, prompt_mask])
1021
+ if prompt_embeds_2 is not None:
1022
+ prompt_embeds_2_input = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
1023
+ if prompt_mask_2 is not None:
1024
+ prompt_mask_2_input = torch.cat([negative_prompt_mask_2, prompt_mask_2])
1025
+
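+ # The reference latents are identical for the unconditional and conditional branches,
+ # so they are simply duplicated along the batch dimension for classifier-free guidance.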
1026
+ if self.do_classifier_free_guidance:
1027
+ ref_latents = torch.cat([ref_latents, ref_latents], dim=0)
1028
+
1029
+
1030
+ # 4. Prepare timesteps
1031
+ extra_set_timesteps_kwargs = self.prepare_extra_func_kwargs(
1032
+ self.scheduler.set_timesteps, {"n_tokens": n_tokens}
1033
+ )
1034
+ timesteps, num_inference_steps = retrieve_timesteps(
1035
+ self.scheduler, num_inference_steps, device, timesteps, sigmas, **extra_set_timesteps_kwargs,
1036
+ )
1037
+
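+ # Derive the output frame count from the audio features (rounded to the form 4k+1), then
+ # convert it to latent frames according to the VAE's temporal compression
+ # ("884": 4 video frames per latent frame, "888": 8).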
1038
+ video_length = audio_prompts.shape[1] // 4 * 4 + 1
1039
+ if "884" in vae_ver:
1040
+ video_length = (video_length - 1) // 4 + 1
1041
+ elif "888" in vae_ver:
1042
+ video_length = (video_length - 1) // 8 + 1
1043
+ else:
1044
+ video_length = video_length
1045
+
1046
+
1047
+ # 5. Prepare latent variables
1048
+ num_channels_latents = self.transformer.config.in_channels
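+ # Denoising operates on a padded latent length (a multiple of 32, plus 1) so the 33-frame
+ # sliding window below (with wrap-around indexing) covers it; the result is cropped back to
+ # `video_length` after the loop.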
1049
+ infer_length = (audio_prompts.shape[1] // 128 + 1) * 32 + 1
1050
+ latents = self.prepare_latents(
1051
+ batch_size * num_videos_per_prompt,
1052
+ num_channels_latents,
1053
+ height,
1054
+ width,
1055
+ infer_length,
1056
+ prompt_embeds.dtype,
1057
+ device,
1058
+ generator,
1059
+ latents,
1060
+ ref_latents[-1:],
1061
+ timesteps[:1]
1062
+ )
1063
+
1064
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1065
+ extra_step_kwargs = self.prepare_extra_func_kwargs(
1066
+ self.scheduler.step, {"generator": generator, "eta": eta},
1067
+ )
1068
+
1069
+ target_dtype = PRECISION_TO_TYPE[self.args.precision]
1070
+ autocast_enabled = (target_dtype != torch.float32) and not self.args.val_disable_autocast
1071
+ vae_dtype = PRECISION_TO_TYPE[self.args.vae_precision]
1072
+ vae_autocast_enabled = (vae_dtype != torch.float32) and not self.args.val_disable_autocast
1073
+
1074
+ # 7. Denoising loop
1075
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1076
+ self._num_timesteps = len(timesteps)
1077
+
1078
+ latents_all = latents.clone()
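+ # Zero-pad the audio features to 4x the padded latent length so every sliding window can
+ # slice a matching audio segment (about 4 audio feature steps per latent frame).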
1079
+ pad_audio_length = (audio_prompts.shape[1] // 128 + 1) * 128 + 4 - audio_prompts.shape[1]
1080
+ audio_prompts_all = torch.cat([audio_prompts, torch.zeros_like(audio_prompts[:, :pad_audio_length])], dim=1)
1081
+
1082
+
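+ # Long clips are denoised with a sliding window of 33 latent frames. The window start is
+ # shifted by `shift_offset` after every denoising step (with wrap-around indexing) so the
+ # window boundaries move over time and seams get averaged out.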
1083
+ shift = 0
1084
+ shift_offset = 10
1085
+ frames_per_batch = 33
1086
+ self.cache_tensor = None
1087
+
1088
+ """ If the total length is shorter than 129, shift is not required """
1089
+ if video_length == 33 or infer_length == 33:
1090
+ infer_length = 33
1091
+ shift_offset = 0
1092
+ latents_all = latents_all[:, :, :33]
1093
+ audio_prompts_all = audio_prompts_all[:, :132]
1094
+
1095
+ if cpu_offload: torch.cuda.empty_cache()
1096
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1097
+ for i, t in enumerate(timesteps):
1098
+ if self.interrupt:
1099
+ continue
1100
+
1101
+ # init
1102
+ pred_latents = torch.zeros_like(
1103
+ latents_all,
1104
+ dtype=latents_all.dtype,
1105
+ )
1106
+ counter = torch.zeros(
1107
+ (latents_all.shape[0], latents_all.shape[1], infer_length, 1, 1),
1108
+ dtype=latents_all.dtype,
1109
+ ).to(device=latents_all.device)
1110
+
1111
+ for index_start in range(0, infer_length, frames_per_batch):
1112
+ self.scheduler._step_index = None
1113
+
1114
+ index_start = index_start - shift
1115
+
1116
+ idx_list = [ii % latents_all.shape[2] for ii in range(index_start, index_start + frames_per_batch)]
1117
+ latents = latents_all[:, :, idx_list].clone()
1118
+
1119
+ idx_list_audio = [ii % audio_prompts_all.shape[1] for ii in range(index_start * 4, (index_start + frames_per_batch) * 4 - 3)]
1120
+ audio_prompts = audio_prompts_all[:, idx_list_audio].clone()
1121
+
1122
+ # expand the latents if we are doing classifier free guidance
1123
+ if self.do_classifier_free_guidance:
1124
+ latent_model_input = torch.cat([latents] * 2)
1125
+ else:
1126
+ latent_model_input = latents
1127
+
1128
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1129
+
1130
+ if self.do_classifier_free_guidance:
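+ # Dynamic CFG: during the first 10 steps the scale decays from its initial value towards 2
+ # and the face mask is attenuated (x0.6); later steps use the schedule defined below.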
1131
+ if i < 10:
1132
+ self._guidance_scale = (1 - i / len(timesteps)) * (self.start_cfg_scale - 2) + 2
1133
+ audio_prompts_input = torch.cat([uncond_audio_prompts, audio_prompts], dim=0)
1134
+ face_masks_input = torch.cat([face_masks * 0.6] * 2, dim=0)
1135
+ else:
1136
+ # CFG schedule for the remaining steps (i >= 10)
1137
+ self._guidance_scale = (1 - i / len(timesteps)) * (6.5 - 3.5) + 3.5 # linearly anneals from ~6.5 down to 3.5
1138
+
1139
+ prompt_embeds_input = torch.cat([prompt_embeds, prompt_embeds])
1140
+ if prompt_mask is not None:
1141
+ prompt_mask_input = torch.cat([prompt_mask, prompt_mask])
1142
+ if prompt_embeds_2 is not None:
1143
+ prompt_embeds_2_input = torch.cat([prompt_embeds_2, prompt_embeds_2])
1144
+ if prompt_mask_2 is not None:
1145
+ prompt_mask_2_input = torch.cat([prompt_mask_2, prompt_mask_2])
1146
+ audio_prompts_input = torch.cat([uncond_audio_prompts, audio_prompts], dim=0)
1147
+ face_masks_input = torch.cat([face_masks] * 2, dim=0)
1148
+
1149
+ motion_exp_input = torch.cat([motion_exp] * 2, dim=0)
1150
+ motion_pose_input = torch.cat([motion_pose] * 2, dim=0)
1151
+ fps_input = torch.cat([fps] * 2, dim=0)
1152
+
1153
+ else:
1154
+ audio_prompts_input = audio_prompts
1155
+ face_masks_input = face_masks
1156
+ motion_exp_input = motion_exp
1157
+ motion_pose_input = motion_pose
1158
+ fps_input = fps
1159
+
1160
+ t_expand = t.repeat(latent_model_input.shape[0])
1161
+ guidance_expand = None
1162
+
1163
+ with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled):
1164
+
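+ # Activation caching: on the steps in `no_cache_steps` the transformer runs in full and its
+ # intermediate activations (`cache_out`) are stored per window; on all other steps the cached
+ # reference/image/text activations are re-injected (`is_cache=True`) to cheapen the step.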
1165
+ no_cache_steps = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14] + list(range(15, 42, 5)) + [41, 42, 43, 44, 45, 46, 47, 48, 49]
1166
+ img_len = (latent_model_input.shape[-1] // 2) * (latent_model_input.shape[-2] // 2) * latent_model_input.shape[-3]
1167
+ img_ref_len = (latent_model_input.shape[-1] // 2) * (latent_model_input.shape[-2] // 2) * (latent_model_input.shape[-3]+1)
1168
+ if i in no_cache_steps:
1169
+ is_cache = False
1170
+
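+ # With CPU offload at large resolutions, the unconditional and conditional branches are run
+ # as two separate forward passes instead of one batched pass to halve peak GPU memory.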
1171
+ if latent_model_input.shape[-1]*latent_model_input.shape[-2]>64*112 and cpu_offload:
1172
+ if i==0:
1173
+ print(f'cpu_offload={cpu_offload} and resolution {latent_model_input.shape[-2:]} is large; running the noise prediction in two passes (uncond / cond) to save memory')
1174
+
1175
+ additional_kwargs = {
1176
+ "motion_exp": motion_exp_input[:1],
1177
+ "motion_pose": motion_pose_input[:1],
1178
+ "fps": fps_input[:1],
1179
+ "audio_prompts": audio_prompts_input[:1],
1180
+ "face_mask": face_masks_input[:1]
1181
+ }
1182
+ noise_pred_uncond = self.transformer(latent_model_input[:1], t_expand[:1], ref_latents=ref_latents[:1], text_states=prompt_embeds_input[:1], text_mask=prompt_mask_input[:1], text_states_2=prompt_embeds_2_input[:1], freqs_cos=freqs_cis[0], freqs_sin=freqs_cis[1], guidance=guidance_expand, return_dict=True, is_cache=is_cache, **additional_kwargs,)['x']
1183
+ uncond_cache_tensor = self.transformer.cache_out
1184
+ torch.cuda.empty_cache()
1185
+
1186
+ additional_kwargs = {
1187
+ "motion_exp": motion_exp_input[1:],
1188
+ "motion_pose": motion_pose_input[1:],
1189
+ "fps": fps_input[1:],
1190
+ "audio_prompts": audio_prompts_input[1:],
1191
+ "face_mask": face_masks_input[1:]
1192
+ }
1193
+ noise_pred_text = self.transformer(latent_model_input[1:], t_expand[1:], ref_latents=ref_latents[1:], text_states=prompt_embeds_input[1:], text_mask=prompt_mask_input[1:], text_states_2=prompt_embeds_2_input[1:], freqs_cos=freqs_cis[0], freqs_sin=freqs_cis[1], guidance=guidance_expand, return_dict=True, is_cache=is_cache, **additional_kwargs,)['x']
1194
+ self.transformer.cache_out = torch.cat([uncond_cache_tensor, self.transformer.cache_out], dim=0)
1195
+
1196
+ noise_pred = torch.cat([noise_pred_uncond, noise_pred_text], dim=0)
1197
+ torch.cuda.empty_cache()
1198
+ else:
1199
+ additional_kwargs = {
1200
+ "motion_exp": motion_exp_input,
1201
+ "motion_pose": motion_pose_input,
1202
+ "fps": fps_input,
1203
+ "audio_prompts": audio_prompts_input,
1204
+ "face_mask": face_masks_input
1205
+ }
1206
+ noise_pred = self.transformer(latent_model_input, t_expand, ref_latents=ref_latents, text_states=prompt_embeds_input, text_mask=prompt_mask_input, text_states_2=prompt_embeds_2_input, freqs_cos=freqs_cis[0], freqs_sin=freqs_cis[1], guidance=guidance_expand, return_dict=True, is_cache=is_cache, **additional_kwargs,)['x']
1207
+ torch.cuda.empty_cache()
1208
+
1209
+ if self.cache_tensor is None:
1210
+ self.cache_tensor = {
1211
+ "ref": torch.zeros([latent_model_input.shape[0], latents_all.shape[-3], (latent_model_input.shape[-1] // 2) * (latent_model_input.shape[-2] // 2), 3072]).to(self.transformer.cache_out.dtype).to(latent_model_input.device).clone(),
1212
+ "img": torch.zeros([latent_model_input.shape[0], latents_all.shape[-3], (latent_model_input.shape[-1] // 2) * (latent_model_input.shape[-2] // 2), 3072]).to(self.transformer.cache_out.dtype).to(latent_model_input.device).clone(),
1213
+ "txt": torch.zeros([latent_model_input.shape[0], latents_all.shape[-3], prompt_embeds_input.shape[1], 3072]).to(self.transformer.cache_out.dtype).to(latent_model_input.device).clone(),
1214
+ }
1215
+
1216
+ self.cache_tensor["ref"][:, idx_list] = self.transformer.cache_out[:, :img_ref_len-img_len].reshape(latent_model_input.shape[0], 1, -1, 3072).repeat(1, len(idx_list), 1, 1)
1217
+ self.cache_tensor["img"][:, idx_list] = self.transformer.cache_out[:, img_ref_len-img_len:img_ref_len].reshape(latent_model_input.shape[0], len(idx_list), -1, 3072)
1218
+ self.cache_tensor["txt"][:, idx_list] = self.transformer.cache_out[:, img_ref_len:].unsqueeze(1).repeat(1, len(idx_list), 1, 1)
1219
+
1220
+ else:
1221
+ is_cache = True
1222
+ # self.transformer.cache_out[:, :img_ref_len-img_len] = self.cache_tensor["ref"][:, idx_list].mean(1)
1223
+ self.transformer.cache_out[:, :img_ref_len-img_len] = self.cache_tensor["ref"][:, idx_list][:, 0].clone()
1224
+ self.transformer.cache_out[:, img_ref_len-img_len:img_ref_len] = self.cache_tensor["img"][:, idx_list].reshape(-1, img_len, 3072).clone()
1225
+ self.transformer.cache_out[:, img_ref_len:] = self.cache_tensor["txt"][:, idx_list][:, 0].clone()
1226
+
1227
+ if latent_model_input.shape[-1]*latent_model_input.shape[-2]>64*112 and cpu_offload:
1228
+ if i==0:
1229
+ print(f'cpu_offload={cpu_offload} and resolution {latent_model_input.shape[-2:]} is large; running the noise prediction in two passes (uncond / cond) to save memory')
1230
+
1231
+ additional_kwargs = {
1232
+ "motion_exp": motion_exp_input[:1],
1233
+ "motion_pose": motion_pose_input[:1],
1234
+ "fps": fps_input[:1],
1235
+ "audio_prompts": audio_prompts_input[:1],
1236
+ "face_mask": face_masks_input[:1]
1237
+ }
1238
+ tmp = self.transformer.cache_out.clone()
1239
+ self.transformer.cache_out = tmp[:1]
1240
+ noise_pred_uncond = self.transformer(latent_model_input[:1], t_expand[:1], ref_latents=ref_latents[:1], text_states=prompt_embeds_input[:1], text_mask=prompt_mask_input[:1], text_states_2=prompt_embeds_2_input[:1], freqs_cos=freqs_cis[0], freqs_sin=freqs_cis[1], guidance=guidance_expand, return_dict=True, is_cache=is_cache, **additional_kwargs,)['x']
1241
+
1242
+
1243
+ torch.cuda.empty_cache()
1244
+
1245
+ additional_kwargs = {
1246
+ "motion_exp": motion_exp_input[1:],
1247
+ "motion_pose": motion_pose_input[1:],
1248
+ "fps": fps_input[1:],
1249
+ "audio_prompts": audio_prompts_input[1:],
1250
+ "face_mask": face_masks_input[1:]
1251
+ }
1252
+ self.transformer.cache_out = tmp[1:]
1253
+ noise_pred_text = self.transformer(latent_model_input[1:], t_expand[1:], ref_latents=ref_latents[1:], text_states=prompt_embeds_input[1:], text_mask=prompt_mask_input[1:], text_states_2=prompt_embeds_2_input[1:], freqs_cos=freqs_cis[0], freqs_sin=freqs_cis[1], guidance=guidance_expand, return_dict=True, is_cache=is_cache, **additional_kwargs,)['x']
1254
+ noise_pred = torch.cat([noise_pred_uncond, noise_pred_text], dim=0)
1255
+
1256
+ self.transformer.cache_out = tmp
1257
+ torch.cuda.empty_cache()
1258
+ else:
1259
+ additional_kwargs = {
1260
+ "motion_exp": motion_exp_input,
1261
+ "motion_pose": motion_pose_input,
1262
+ "fps": fps_input,
1263
+ "audio_prompts": audio_prompts_input,
1264
+ "face_mask": face_masks_input
1265
+ }
1266
+ noise_pred = self.transformer(latent_model_input, t_expand, ref_latents=ref_latents, text_states=prompt_embeds_input, text_mask=prompt_mask_input, text_states_2=prompt_embeds_2_input, freqs_cos=freqs_cis[0], freqs_sin=freqs_cis[1], guidance=guidance_expand, return_dict=True, is_cache=is_cache, **additional_kwargs,)['x']
1267
+ torch.cuda.empty_cache()
1268
+ # perform guidance
1269
+ if self.do_classifier_free_guidance:
1270
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1271
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1272
+
1273
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1274
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1275
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
1276
+
1277
+ # compute the previous noisy sample x_t -> x_t-1
1278
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1279
+
1280
+ if callback_on_step_end is not None:
1281
+ callback_kwargs = {}
1282
+ for k in callback_on_step_end_tensor_inputs:
1283
+ callback_kwargs[k] = locals()[k]
1284
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1285
+
1286
+ latents = callback_outputs.pop("latents", latents)
1287
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1288
+ negative_prompt_embeds = callback_outputs.pop(
1289
+ "negative_prompt_embeds", negative_prompt_embeds
1290
+ )
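+ # Scatter-add this window's denoised latents into the full-length buffer; overlapping
+ # windows are averaged via `counter` after all windows of this timestep are processed.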
1291
+ latents = latents.to(torch.bfloat16)
1292
+ for iii in range(frames_per_batch):
1293
+ p = (index_start + iii) % pred_latents.shape[2]
1294
+ pred_latents[:, :, p] += latents[:, :, iii]
1295
+ counter[:, :, p] += 1
1296
+
1297
+ shift += shift_offset
1298
+ shift = shift % frames_per_batch
1299
+ pred_latents = pred_latents / counter
1300
+ latents_all = pred_latents
1301
+
1302
+ # call the callback, if provided
1303
+ if i == len(timesteps) - 1 or (
1304
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
1305
+ ):
1306
+ if progress_bar is not None:
1307
+ progress_bar.update()
1308
+ if callback is not None and i % callback_steps == 0:
1309
+ step_idx = i // getattr(self.scheduler, "order", 1)
1310
+ callback(step_idx, t, latents)
1311
+
1312
+ latents = latents_all.float()[:, :, :video_length]
1313
+ if cpu_offload: torch.cuda.empty_cache()
1314
+
1315
+ if not output_type == "latent":
1316
+ expand_temporal_dim = False
1317
+ if len(latents.shape) == 4:
1318
+ if isinstance(self.vae, AutoencoderKLCausal3D):
1319
+ latents = latents.unsqueeze(2)
1320
+ expand_temporal_dim = True
1321
+ elif len(latents.shape) == 5:
1322
+ pass
1323
+ else:
1324
+ raise ValueError(
1325
+ f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}.")
1326
+
1327
+ if hasattr(self.vae.config, 'shift_factor') and self.vae.config.shift_factor:
1328
+ latents = latents / self.vae.config.scaling_factor + self.vae.config.shift_factor
1329
+ else:
1330
+ latents = latents / self.vae.config.scaling_factor
1331
+
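+ # Decode with the VAE under its own precision/autocast settings; with CPU offload the
+ # decoder is moved to the GPU only for decoding (optionally tiled) and moved back afterwards.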
1332
+ with torch.autocast(device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled):
1333
+ if enable_tiling:
1334
+ self.vae.enable_tiling()
1335
+ if cpu_offload:
1336
+ self.vae.post_quant_conv.to('cuda')
1337
+ self.vae.decoder.to('cuda')
1338
+ image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
1339
+ self.vae.disable_tiling()
1340
+ if cpu_offload:
1341
+ self.vae.post_quant_conv.to('cpu')
1342
+ self.vae.decoder.to('cpu')
1343
+ torch.cuda.empty_cache()
1344
+ else:
1345
+ image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
1346
+ if image is None:
1347
+ return (None, )
1348
+
1349
+ if expand_temporal_dim or image.shape[2] == 1:
1350
+ image = image.squeeze(2)
1351
+
1352
+ image = (image / 2 + 0.5).clamp(0, 1)
1353
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
1354
+ image = image.cpu().float()
1355
+
1356
+ # Offload all models
1357
+ self.maybe_free_model_hooks()
1358
+
1359
+ if cpu_offload: torch.cuda.empty_cache()
1360
+ if not return_dict:
1361
+ return image
1362
+
1363
+ return HunyuanVideoPipelineOutput(videos=image)
hymm_sp/diffusion/schedulers/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler