aleafy committed
Commit 35a91e2 · 1 Parent(s): 78ca493
Files changed (5)
  1. .gitignore +0 -2
  2. app1_a.py +0 -386
  3. app1_bf.py +0 -388
  4. app1_bf2.py +0 -388
  5. app_bf.py +0 -391
.gitignore CHANGED
@@ -1,5 +1,3 @@
- app1.py
- app2.py
  demo_utils1.py
  tmp
  models
 
 
 
app1_a.py DELETED
@@ -1,386 +0,0 @@
- import os
- import gradio as gr
- import numpy as np
- from enum import Enum
- import db_examples
- import cv2
-
-
- from demo_utils1 import *
-
- from misc_utils.train_utils import unit_test_create_model
- from misc_utils.image_utils import save_tensor_to_gif, save_tensor_to_images
- import os
- from PIL import Image
- import torch
- import torchvision
- from torchvision import transforms
- from einops import rearrange
- import imageio
- import time
-
- from torchvision.transforms import functional as F
- from torch.hub import download_url_to_file
-
- import os
-
- # Inference setup
- from pl_trainer.inference.inference import InferenceIP2PVideo
- from tqdm import tqdm
-
-
- # if not os.path.exists(filename):
- #     original_path = os.getcwd()
- #     base_path = './models'
- #     os.makedirs(base_path, exist_ok=True)
-
- #     # Token written directly in the code (beware of the security risk)
- #     GIT_TOKEN = "955b8ea91095840b76fe38b90a088c200d4c813c"
- #     repo_url = f"https://YeFang:{GIT_TOKEN}@code.openxlab.org.cn/YeFang/RIV_models.git"
-
- #     try:
- #         if os.system(f'git clone {repo_url} {base_path}') != 0:
- #             raise RuntimeError("Git clone failed")
- #         os.chdir(base_path)
- #         if os.system('git lfs pull') != 0:
- #             raise RuntimeError("Git LFS pull failed")
- #     finally:
- #         os.chdir(original_path)
-
- def tensor_to_pil_image(x):
-     """
-     Convert a 4D PyTorch tensor to a PIL image.
-     """
-     x = x.float()  # make sure the tensor is float
-     grid_img = torchvision.utils.make_grid(x, nrow=4).permute(1, 2, 0).detach().cpu().numpy()
-     grid_img = (grid_img * 255).clip(0, 255).astype("uint8")  # map the [0, 1] range to [0, 255]
-     return Image.fromarray(grid_img)
-
- def frame_to_batch(x):
-     """
-     Fold the frame dimension into the batch dimension.
-     """
-     return rearrange(x, 'b f c h w -> (b f) c h w')
-
- def clip_image(x, min=0., max=1.):
-     """
-     Clamp an image tensor to the given minimum and maximum values.
-     """
-     return torch.clamp(x, min=min, max=max)
-
- def unnormalize(x):
-     """
-     Rescale a tensor from [-1, 1] to [0, 1].
-     """
-     return (x + 1) / 2
-
-
- # Read image files
- def read_images_from_directory(directory, num_frames=16):
-     images = []
-     for i in range(num_frames):
-         img_path = os.path.join(directory, f'{i:04d}.png')
-         img = imageio.imread(img_path)
-         images.append(torch.tensor(img).permute(2, 0, 1))  # convert to tensor (C, H, W)
-     return images
-
- def load_and_process_images(folder_path):
-     """
-     Read all images in a folder, convert them to tensors in [-1, 1], and return a 4D tensor.
-     """
-     processed_images = []
-     transform = transforms.Compose([
-         transforms.ToTensor(),
-         transforms.Lambda(lambda x: x * 2 - 1)  # map [0, 1] to [-1, 1]
-     ])
-     for filename in sorted(os.listdir(folder_path)):
-         if filename.endswith(".png"):
-             img_path = os.path.join(folder_path, filename)
-             image = Image.open(img_path).convert("RGB")
-             processed_image = transform(image)
-             processed_images.append(processed_image)
-     return torch.stack(processed_images)  # return a 4D tensor
-
- def load_and_process_video(video_path, num_frames=16, crop_size=512):
-     """
-     Read the first num_frames frames of a video file, convert each frame to a tensor
-     in [-1, 1], center-crop it to crop_size x crop_size, and return a 4D tensor.
-     """
-     processed_frames = []
-     transform = transforms.Compose([
-         transforms.CenterCrop(crop_size),  # center crop
-         transforms.ToTensor(),
-         transforms.Lambda(lambda x: x * 2 - 1)  # map [0, 1] to [-1, 1]
-     ])
-
-     # Read the video with OpenCV
-     cap = cv2.VideoCapture(video_path)
-
-     if not cap.isOpened():
-         raise ValueError(f"Cannot open video file: {video_path}")
-
-     frame_count = 0
-
-     while frame_count < num_frames:
-         ret, frame = cap.read()
-         if not ret:
-             break  # video exhausted or has too few frames
-
-         # Convert to RGB
-         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-         image = Image.fromarray(frame)
-
-         # Apply the transform
-         processed_frame = transform(image)
-         processed_frames.append(processed_frame)
-
-         frame_count += 1
-
-     cap.release()  # release the video handle
-
-     if len(processed_frames) < num_frames:
-         raise ValueError(f"Video has fewer than {num_frames} frames; only found {len(processed_frames)}.")
-
-     return torch.stack(processed_frames)  # return a 4D tensor (frames, channels, height, width)
-
-
- def clear_cache(output_path):
-     if os.path.exists(output_path):
-         os.remove(output_path)
-     return None
-
-
- #! Load the model
- # Config path; create and load the model
- config_path = 'configs/instruct_v2v_ic_gradio.yaml'
- diffusion_model = unit_test_create_model(config_path)
- diffusion_model = diffusion_model.to('cuda')
-
- # Load the model checkpoint
- # ckpt_path = 'models/relvid_mm_sd15_fbc_unet.pth' #! change
- # ckpt_path = 'tmp/pytorch_model.bin'
- # Download the weights
-
- os.makedirs('models', exist_ok=True)
- model_path = "models/relvid_mm_sd15_fbc_unet.pth"
-
- if not os.path.exists(model_path):
-     download_url_to_file(url='https://huggingface.co/aleafy/RelightVid/resolve/main/relvid_mm_sd15_fbc_unet.pth', dst=model_path)
-
-
- ckpt = torch.load(model_path, map_location='cpu')
- diffusion_model.load_state_dict(ckpt, strict=False)
-
-
- # import pdb; pdb.set_trace()
-
- # Change the global temp directory
- new_tmp_dir = "./demo/gradio_bg"
- os.makedirs(new_tmp_dir, exist_ok=True)
-
- # import pdb; pdb.set_trace()
-
- def save_video_from_frames(image_pred, save_pth, fps=8):
-     """
-     Save the frames in image_pred as a video file.
-
-     Args:
-     - image_pred: Tensor of shape (1, 16, 3, 512, 512)
-     - save_pth: path of the output video, e.g. "output_video.mp4"
-     - fps: frame rate of the video
-     """
-     # Video parameters
-     num_frames = image_pred.shape[1]
-     frame_height, frame_width = 512, 512  # target size
-     fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # mp4 codec
-
-     # Create the VideoWriter
-     out = cv2.VideoWriter(save_pth, fourcc, fps, (frame_width, frame_height))
-
-     for i in range(num_frames):
-         # Unnormalize and map to the 0-255 range
-         pred_frame = clip_image(unnormalize(image_pred[0][i].unsqueeze(0))) * 255
-         pred_frame_resized = pred_frame.squeeze(0).detach().cpu()  # (3, 512, 512)
-         pred_frame_resized = pred_frame_resized.permute(1, 2, 0).numpy().astype("uint8")  # (512, 512, 3)
-
-         # Resize to the target size (512x512)
-         pred_frame_resized = cv2.resize(pred_frame_resized, (frame_width, frame_height))
-
-         # Convert RGB to BGR (OpenCV uses BGR)
-         pred_frame_bgr = cv2.cvtColor(pred_frame_resized, cv2.COLOR_RGB2BGR)
-
-         # Write the frame to the video
-         out.write(pred_frame_bgr)
-
-     # Release the VideoWriter
-     out.release()
-     print(f"Video saved to {save_pth}")
-
-
- inf_pipe = InferenceIP2PVideo(
-     diffusion_model.unet,
-     scheduler='ddpm',
-     num_ddim_steps=20
- )
-
- # Dummy placeholder function (generates a blank video)
- def dummy_process(input_fg, input_bg):
-     # import pdb; pdb.set_trace()
-
-     diffusion_model.to(torch.float16)
-     fg_tensor = load_and_process_video(input_fg).cuda().unsqueeze(0).to(dtype=torch.float16)
-     bg_tensor = load_and_process_video(input_bg).cuda().unsqueeze(0).to(dtype=torch.float16)  # (1, 16, 4, 64, 64)
-
-     cond_fg_tensor = diffusion_model.encode_image_to_latent(fg_tensor)  # (1, 16, 4, 64, 64)
-     cond_bg_tensor = diffusion_model.encode_image_to_latent(bg_tensor)
-     cond_tensor = torch.cat((cond_fg_tensor, cond_bg_tensor), dim=2)
-
-     # Initialize the latent
-     init_latent = torch.randn_like(cond_fg_tensor)
-
-     EDIT_PROMPT = 'change the background'
-     VIDEO_CFG = 1.2
-     TEXT_CFG = 7.5
-     text_cond = diffusion_model.encode_text([EDIT_PROMPT])  # (1, 77, 768)
-     text_uncond = diffusion_model.encode_text([''])
-     # to float16
-     print('------------to float 16----------------')
-     init_latent, text_cond, text_uncond, cond_tensor = (
-         init_latent.to(dtype=torch.float16),
-         text_cond.to(dtype=torch.float16),
-         text_uncond.to(dtype=torch.float16),
-         cond_tensor.to(dtype=torch.float16)
-     )
-     inf_pipe.unet.to(torch.float16)
-     latent_pred = inf_pipe(
-         latent=init_latent,
-         text_cond=text_cond,
-         text_uncond=text_uncond,
-         img_cond=cond_tensor,
-         text_cfg=TEXT_CFG,
-         img_cfg=VIDEO_CFG,
-     )['latent']
-
-
-     image_pred = diffusion_model.decode_latent_to_image(latent_pred)  # (1,16,3,512,512)
-     output_path = os.path.join(new_tmp_dir, f"output_{int(time.time())}.mp4")
-     # clear_cache(output_path)
-
-     save_video_from_frames(image_pred, output_path)
-     # import pdb; pdb.set_trace()
-     # fps = 8
-     # frames = []
-     # for i in range(16):
-     #     pred_frame = clip_image(unnormalize(image_pred[0][i].unsqueeze(0))) * 255
-     #     pred_frame_resized = pred_frame.squeeze(0).detach().cpu()  # (3,512,512)
-     #     pred_frame_resized = pred_frame_resized.permute(1, 2, 0).detach().cpu().numpy().astype("uint8")  # (512,512,3) np
-     #     Image.fromarray(pred_frame_resized).save(save_pth)
-
-     # # Generate a simple black video as an example
-     # output_path = os.path.join(new_tmp_dir, "output.mp4")
-     # fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-     # out = cv2.VideoWriter(output_path, fourcc, 20.0, (512, 512))
-
-     # for _ in range(60):  # 3 seconds of video at 20 fps
-     #     frame = np.zeros((512, 512, 3), dtype=np.uint8)
-     #     out.write(frame)
-     # out.release()
-     torch.cuda.empty_cache()
-
-     return output_path
-
- # Enum for background selection
- class BGSource(Enum):
-     UPLOAD = "Use Background Video"
-     UPLOAD_FLIP = "Use Flipped Background Video"
-     UPLOAD_REVERSE = "Use Reversed Background Video"
-
-
- # Quick prompt examples
- quick_prompts = [
-     'beautiful woman',
-     'handsome man',
-     'beautiful woman, cinematic lighting',
-     'handsome man, cinematic lighting',
-     'beautiful woman, natural lighting',
-     'handsome man, natural lighting',
-     'beautiful woman, neo punk lighting, cyberpunk',
-     'handsome man, neo punk lighting, cyberpunk',
- ]
- quick_prompts = [[x] for x in quick_prompts]
-
- # Gradio UI layout
- block = gr.Blocks().queue()
- with block:
-     with gr.Row():
-         gr.Markdown("## IC-Light (Relighting with Foreground and Background Video Condition)")
-
-     with gr.Row():
-         with gr.Column():
-             with gr.Row():
-                 input_fg = gr.Video(label="Foreground Video", height=370, width=370, visible=True)
-                 input_bg = gr.Video(label="Background Video", height=370, width=370, visible=True)
-
-             prompt = gr.Textbox(label="Prompt")
-             bg_source = gr.Radio(choices=[e.value for e in BGSource],
-                                  value=BGSource.UPLOAD.value,
-                                  label="Background Source", type='value')
-
-             example_prompts = gr.Dataset(samples=quick_prompts, label='Prompt Quick List', components=[prompt])
-             bg_gallery = gr.Gallery(height=450, object_fit='contain', label='Background Quick List', value=db_examples.bg_samples, columns=5, allow_preview=False)
-             relight_button = gr.Button(value="Relight")
-
-             with gr.Group():
-                 with gr.Row():
-                     num_samples = gr.Slider(label="Videos", minimum=1, maximum=12, value=1, step=1)
-                     seed = gr.Number(label="Seed", value=12345, precision=0)
-                 with gr.Row():
-                     video_width = gr.Slider(label="Video Width", minimum=256, maximum=1024, value=512, step=64)
-                     video_height = gr.Slider(label="Video Height", minimum=256, maximum=1024, value=640, step=64)
-
-             with gr.Accordion("Advanced options", open=False):
-                 steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
-                 cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=7.0, step=0.01)
-                 highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=3.0, value=1.5, step=0.01)
-                 highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=0.9, value=0.5, step=0.01)
-                 a_prompt = gr.Textbox(label="Added Prompt", value='best quality')
-                 n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality')
-                 normal_button = gr.Button(value="Compute Normal (4x Slower)")
-
-         with gr.Column():
-             result_video = gr.Video(label='Output Video', height=600, width=600, visible=True)
-             fg_gallery = gr.Gallery(width=600, object_fit='contain', label='Foreground Quick List', value=db_examples.bg_samples, columns=4, allow_preview=False)
-
-     # Input list
-     # ips = [input_fg, input_bg, prompt, video_width, video_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source]
-     ips = [input_fg, input_bg]
-
-     # Bind the handlers to the buttons
-     # relight_button.click(fn=lambda: None, inputs=[], outputs=[result_video])
-
-     relight_button.click(fn=dummy_process, inputs=ips, outputs=[result_video])
-
-     normal_button.click(fn=dummy_process, inputs=ips, outputs=[result_video])
-
-     # Background gallery selection
-     def bg_gallery_selected(gal, evt: gr.SelectData):
-         # import pdb; pdb.set_trace()
-         # img_path = gal[evt.index][0]
-         img_path = db_examples.bg_samples[evt.index]
-         video_path = img_path.replace('frames/0000.png', 'cropped_video.mp4')
-         return video_path
-
-     bg_gallery.select(bg_gallery_selected, inputs=bg_gallery, outputs=input_bg)
-
-     # Examples
-     # dummy_video_for_outputs = gr.Video(visible=False, label='Result')
-     gr.Examples(
-         fn=lambda *args: args[-1],
-         examples=db_examples.background_conditioned_examples,
-         inputs=[input_fg, input_bg, prompt, bg_source, video_width, video_height, seed, result_video],
-         outputs=[result_video],
-         run_on_click=True, examples_per_page=1024
-     )
-
- # Launch the Gradio app
- block.launch(server_name='0.0.0.0', server_port=10003, share=True)

app1_bf.py DELETED
@@ -1,388 +0,0 @@
- import os
- import gradio as gr
- import numpy as np
- from enum import Enum
- import db_examples
- import cv2
-
-
- from demo_utils1 import *
-
- from misc_utils.train_utils import unit_test_create_model
- from misc_utils.image_utils import save_tensor_to_gif, save_tensor_to_images
- import os
- from PIL import Image
- import torch
- import torchvision
- from torchvision import transforms
- from einops import rearrange
- import imageio
- import time
-
- from torchvision.transforms import functional as F
- from torch.hub import download_url_to_file
-
- import os
-
- # Inference setup
- from pl_trainer.inference.inference import InferenceIP2PVideo
- from tqdm import tqdm
-
-
- # if not os.path.exists(filename):
- #     original_path = os.getcwd()
- #     base_path = './models'
- #     os.makedirs(base_path, exist_ok=True)
-
- #     # Token written directly in the code (beware of the security risk)
- #     GIT_TOKEN = "955b8ea91095840b76fe38b90a088c200d4c813c"
- #     repo_url = f"https://YeFang:{GIT_TOKEN}@code.openxlab.org.cn/YeFang/RIV_models.git"
-
- #     try:
- #         if os.system(f'git clone {repo_url} {base_path}') != 0:
- #             raise RuntimeError("Git clone failed")
- #         os.chdir(base_path)
- #         if os.system('git lfs pull') != 0:
- #             raise RuntimeError("Git LFS pull failed")
- #     finally:
- #         os.chdir(original_path)
-
- def tensor_to_pil_image(x):
-     """
-     Convert a 4D PyTorch tensor to a PIL image.
-     """
-     x = x.float()  # make sure the tensor is float
-     grid_img = torchvision.utils.make_grid(x, nrow=4).permute(1, 2, 0).detach().cpu().numpy()
-     grid_img = (grid_img * 255).clip(0, 255).astype("uint8")  # map the [0, 1] range to [0, 255]
-     return Image.fromarray(grid_img)
-
- def frame_to_batch(x):
-     """
-     Fold the frame dimension into the batch dimension.
-     """
-     return rearrange(x, 'b f c h w -> (b f) c h w')
-
- def clip_image(x, min=0., max=1.):
-     """
-     Clamp an image tensor to the given minimum and maximum values.
-     """
-     return torch.clamp(x, min=min, max=max)
-
- def unnormalize(x):
-     """
-     Rescale a tensor from [-1, 1] to [0, 1].
-     """
-     return (x + 1) / 2
-
-
- # Read image files
- def read_images_from_directory(directory, num_frames=16):
-     images = []
-     for i in range(num_frames):
-         img_path = os.path.join(directory, f'{i:04d}.png')
-         img = imageio.imread(img_path)
-         images.append(torch.tensor(img).permute(2, 0, 1))  # convert to tensor (C, H, W)
-     return images
-
- def load_and_process_images(folder_path):
-     """
-     Read all images in a folder, convert them to tensors in [-1, 1], and return a 4D tensor.
-     """
-     processed_images = []
-     transform = transforms.Compose([
-         transforms.ToTensor(),
-         transforms.Lambda(lambda x: x * 2 - 1)  # map [0, 1] to [-1, 1]
-     ])
-     for filename in sorted(os.listdir(folder_path)):
-         if filename.endswith(".png"):
-             img_path = os.path.join(folder_path, filename)
-             image = Image.open(img_path).convert("RGB")
-             processed_image = transform(image)
-             processed_images.append(processed_image)
-     return torch.stack(processed_images)  # return a 4D tensor
-
- def load_and_process_video(video_path, num_frames=16, crop_size=512):
-     """
-     Read the first num_frames frames of a video file, convert each frame to a tensor
-     in [-1, 1], center-crop it to crop_size x crop_size, and return a 4D tensor.
-     """
-     processed_frames = []
-     transform = transforms.Compose([
-         transforms.CenterCrop(crop_size),  # center crop
-         transforms.ToTensor(),
-         transforms.Lambda(lambda x: x * 2 - 1)  # map [0, 1] to [-1, 1]
-     ])
-
-     # Read the video with OpenCV
-     cap = cv2.VideoCapture(video_path)
-
-     if not cap.isOpened():
-         raise ValueError(f"Cannot open video file: {video_path}")
-
-     frame_count = 0
-
-     while frame_count < num_frames:
-         ret, frame = cap.read()
-         if not ret:
-             break  # video exhausted or has too few frames
-
-         # Convert to RGB
-         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-         image = Image.fromarray(frame)
-
-         # Apply the transform
-         processed_frame = transform(image)
-         processed_frames.append(processed_frame)
-
-         frame_count += 1
-
-     cap.release()  # release the video handle
-
-     if len(processed_frames) < num_frames:
-         raise ValueError(f"Video has fewer than {num_frames} frames; only found {len(processed_frames)}.")
-
-     return torch.stack(processed_frames)  # return a 4D tensor (frames, channels, height, width)
-
-
- def clear_cache(output_path):
-     if os.path.exists(output_path):
-         os.remove(output_path)
-     return None
-
-
- #! Load the model
- # Config path; create and load the model
- config_path = 'configs/instruct_v2v_ic_gradio.yaml'
- diffusion_model = unit_test_create_model(config_path)
- diffusion_model = diffusion_model.to('cuda')
-
- # Load the model checkpoint
- # ckpt_path = 'models/relvid_mm_sd15_fbc_unet.pth' #! change
- # ckpt_path = 'tmp/pytorch_model.bin'
- # Download the weights
-
- os.makedirs('models', exist_ok=True)
- model_path = "models/relvid_mm_sd15_fbc_unet.pth"
-
- if not os.path.exists(model_path):
-     download_url_to_file(url='https://huggingface.co/aleafy/RelightVid/resolve/main/relvid_mm_sd15_fbc_unet.pth', dst=model_path)
-
-
- ckpt = torch.load(model_path, map_location='cpu')
- diffusion_model.load_state_dict(ckpt, strict=False)
-
-
- # import pdb; pdb.set_trace()
-
- # Change the global temp directory
- new_tmp_dir = "./demo/gradio_bg"
- os.makedirs(new_tmp_dir, exist_ok=True)
-
- # import pdb; pdb.set_trace()
-
- def save_video_from_frames(image_pred, save_pth, fps=8):
-     """
-     Save the frames in image_pred as a video file.
-
-     Args:
-     - image_pred: Tensor of shape (1, 16, 3, 512, 512)
-     - save_pth: path of the output video, e.g. "output_video.mp4"
-     - fps: frame rate of the video
-     """
-     # Video parameters
-     num_frames = image_pred.shape[1]
-     frame_height, frame_width = 512, 512  # target size
-     fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # mp4 codec
-
-     # Create the VideoWriter
-     out = cv2.VideoWriter(save_pth, fourcc, fps, (frame_width, frame_height))
-
-     for i in range(num_frames):
-         # Unnormalize and map to the 0-255 range
-         pred_frame = clip_image(unnormalize(image_pred[0][i].unsqueeze(0))) * 255
-         pred_frame_resized = pred_frame.squeeze(0).detach().cpu()  # (3, 512, 512)
-         pred_frame_resized = pred_frame_resized.permute(1, 2, 0).numpy().astype("uint8")  # (512, 512, 3)
-
-         # Resize to the target size (512x512)
-         pred_frame_resized = cv2.resize(pred_frame_resized, (frame_width, frame_height))
-
-         # Convert RGB to BGR (OpenCV uses BGR)
-         pred_frame_bgr = cv2.cvtColor(pred_frame_resized, cv2.COLOR_RGB2BGR)
-
-         # Write the frame to the video
-         out.write(pred_frame_bgr)
-
-     # Release the VideoWriter
-     out.release()
-     print(f"Video saved to {save_pth}")
-
-
- inf_pipe = InferenceIP2PVideo(
-     diffusion_model.unet,
-     scheduler='ddpm',
-     num_ddim_steps=20
- )
-
- # Dummy placeholder function (generates a blank video)
- def dummy_process(input_fg, input_bg):
-     # import pdb; pdb.set_trace()
-
-     diffusion_model.to(torch.float16)
-     fg_tensor = load_and_process_video(input_fg).cuda().unsqueeze(0).to(dtype=torch.float16)
-     bg_tensor = load_and_process_video(input_bg).cuda().unsqueeze(0).to(dtype=torch.float16)  # (1, 16, 4, 64, 64)
-
-     cond_fg_tensor = diffusion_model.encode_image_to_latent(fg_tensor)  # (1, 16, 4, 64, 64)
-     cond_bg_tensor = diffusion_model.encode_image_to_latent(bg_tensor)
-     cond_tensor = torch.cat((cond_fg_tensor, cond_bg_tensor), dim=2)
-
-     # Initialize the latent
-     init_latent = torch.randn_like(cond_fg_tensor)
-
-     EDIT_PROMPT = 'change the background'
-     VIDEO_CFG = 1.2
-     TEXT_CFG = 7.5
-     text_cond = diffusion_model.encode_text([EDIT_PROMPT])  # (1, 77, 768)
-     text_uncond = diffusion_model.encode_text([''])
-     # to float16
-     print('------------to float 16----------------')
-     init_latent, text_cond, text_uncond, cond_tensor = (
-         init_latent.to(dtype=torch.float16),
-         text_cond.to(dtype=torch.float16),
-         text_uncond.to(dtype=torch.float16),
-         cond_tensor.to(dtype=torch.float16)
-     )
-     inf_pipe.unet.to(torch.float16)
-     latent_pred = inf_pipe(
-         latent=init_latent,
-         text_cond=text_cond,
-         text_uncond=text_uncond,
-         img_cond=cond_tensor,
-         text_cfg=TEXT_CFG,
-         img_cfg=VIDEO_CFG,
-     )['latent']
-
-
-     image_pred = diffusion_model.decode_latent_to_image(latent_pred)  # (1,16,3,512,512)
-     output_path = os.path.join(new_tmp_dir, f"output_{int(time.time())}.mp4")
-     # clear_cache(output_path)
-
-     save_video_from_frames(image_pred, output_path)
-     # import pdb; pdb.set_trace()
-     # fps = 8
-     # frames = []
-     # for i in range(16):
-     #     pred_frame = clip_image(unnormalize(image_pred[0][i].unsqueeze(0))) * 255
-     #     pred_frame_resized = pred_frame.squeeze(0).detach().cpu()  # (3,512,512)
-     #     pred_frame_resized = pred_frame_resized.permute(1, 2, 0).detach().cpu().numpy().astype("uint8")  # (512,512,3) np
-     #     Image.fromarray(pred_frame_resized).save(save_pth)
-
-     # # Generate a simple black video as an example
-     # output_path = os.path.join(new_tmp_dir, "output.mp4")
-     # fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-     # out = cv2.VideoWriter(output_path, fourcc, 20.0, (512, 512))
-
-     # for _ in range(60):  # 3 seconds of video at 20 fps
-     #     frame = np.zeros((512, 512, 3), dtype=np.uint8)
-     #     out.write(frame)
-     # out.release()
-     torch.cuda.empty_cache()
-
-     return output_path
-
- # Enum for background selection
- class BGSource(Enum):
-     UPLOAD = "Use Background Video"
-     UPLOAD_FLIP = "Use Flipped Background Video"
-     LEFT = "Left Light"
-     RIGHT = "Right Light"
-     TOP = "Top Light"
-     BOTTOM = "Bottom Light"
-     GREY = "Ambient"
-
- # Quick prompt examples
- quick_prompts = [
-     'beautiful woman',
-     'handsome man',
-     'beautiful woman, cinematic lighting',
-     'handsome man, cinematic lighting',
-     'beautiful woman, natural lighting',
-     'handsome man, natural lighting',
-     'beautiful woman, neo punk lighting, cyberpunk',
-     'handsome man, neo punk lighting, cyberpunk',
- ]
- quick_prompts = [[x] for x in quick_prompts]
-
- # Gradio UI layout
- block = gr.Blocks().queue()
- with block:
-     with gr.Row():
-         gr.Markdown("## IC-Light (Relighting with Foreground and Background Video Condition)")
-
-     with gr.Row():
-         with gr.Column():
-             with gr.Row():
-                 input_fg = gr.Video(label="Foreground Video", height=370, width=370, visible=True)
-                 input_bg = gr.Video(label="Background Video", height=370, width=370, visible=True)
-
-             prompt = gr.Textbox(label="Prompt")
-             bg_source = gr.Radio(choices=[e.value for e in BGSource],
-                                  value=BGSource.UPLOAD.value,
-                                  label="Background Source", type='value')
-
-             example_prompts = gr.Dataset(samples=quick_prompts, label='Prompt Quick List', components=[prompt])
-             bg_gallery = gr.Gallery(height=450, object_fit='contain', label='Background Quick List', value=db_examples.bg_samples, columns=5, allow_preview=False)
-             relight_button = gr.Button(value="Relight")
-
-             with gr.Group():
-                 with gr.Row():
-                     num_samples = gr.Slider(label="Videos", minimum=1, maximum=12, value=1, step=1)
-                     seed = gr.Number(label="Seed", value=12345, precision=0)
-                 with gr.Row():
-                     video_width = gr.Slider(label="Video Width", minimum=256, maximum=1024, value=512, step=64)
-                     video_height = gr.Slider(label="Video Height", minimum=256, maximum=1024, value=640, step=64)
-
-             with gr.Accordion("Advanced options", open=False):
-                 steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
-                 cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=7.0, step=0.01)
-                 highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=3.0, value=1.5, step=0.01)
-                 highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=0.9, value=0.5, step=0.01)
-                 a_prompt = gr.Textbox(label="Added Prompt", value='best quality')
-                 n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality')
-                 normal_button = gr.Button(value="Compute Normal (4x Slower)")
-
-         with gr.Column():
-             result_video = gr.Video(label='Output Video', height=600, width=600, visible=True)
-
-     # Input list
-     # ips = [input_fg, input_bg, prompt, video_width, video_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source]
-     ips = [input_fg, input_bg]
-
-     # Bind the handlers to the buttons
-     # relight_button.click(fn=lambda: None, inputs=[], outputs=[result_video])
-
-     relight_button.click(fn=dummy_process, inputs=ips, outputs=[result_video])
-
-     normal_button.click(fn=dummy_process, inputs=ips, outputs=[result_video])
-
-     # Background gallery selection
-     def bg_gallery_selected(gal, evt: gr.SelectData):
-         # import pdb; pdb.set_trace()
-         # img_path = gal[evt.index][0]
-         img_path = db_examples.bg_samples[evt.index]
-         video_path = img_path.replace('frames/0000.png', 'cropped_video.mp4')
-         return video_path
-
-     bg_gallery.select(bg_gallery_selected, inputs=bg_gallery, outputs=input_bg)
-
-     # Examples
-     # dummy_video_for_outputs = gr.Video(visible=False, label='Result')
-     gr.Examples(
-         fn=lambda *args: args[-1],
-         examples=db_examples.background_conditioned_examples,
-         inputs=[input_fg, input_bg, prompt, bg_source, video_width, video_height, seed, result_video],
-         outputs=[result_video],
-         run_on_click=True, examples_per_page=1024
-     )
-
- # Launch the Gradio app
- block.launch(server_name='0.0.0.0', server_port=10002, share=True)

app1_bf2.py DELETED
@@ -1,388 +0,0 @@
- import os
- import gradio as gr
- import numpy as np
- from enum import Enum
- import db_examples
- import cv2
-
-
- from demo_utils1 import *
-
- from misc_utils.train_utils import unit_test_create_model
- from misc_utils.image_utils import save_tensor_to_gif, save_tensor_to_images
- import os
- from PIL import Image
- import torch
- import torchvision
- from torchvision import transforms
- from einops import rearrange
- import imageio
- import time
-
- from torchvision.transforms import functional as F
- from torch.hub import download_url_to_file
-
- import os
-
- # Inference setup
- from pl_trainer.inference.inference import InferenceIP2PVideo
- from tqdm import tqdm
-
-
- # if not os.path.exists(filename):
- #     original_path = os.getcwd()
- #     base_path = './models'
- #     os.makedirs(base_path, exist_ok=True)
-
- #     # Token written directly in the code (beware of the security risk)
- #     GIT_TOKEN = "955b8ea91095840b76fe38b90a088c200d4c813c"
- #     repo_url = f"https://YeFang:{GIT_TOKEN}@code.openxlab.org.cn/YeFang/RIV_models.git"
-
- #     try:
- #         if os.system(f'git clone {repo_url} {base_path}') != 0:
- #             raise RuntimeError("Git clone failed")
- #         os.chdir(base_path)
- #         if os.system('git lfs pull') != 0:
- #             raise RuntimeError("Git LFS pull failed")
- #     finally:
- #         os.chdir(original_path)
-
- def tensor_to_pil_image(x):
-     """
-     Convert a 4D PyTorch tensor to a PIL image.
-     """
-     x = x.float()  # make sure the tensor is float
-     grid_img = torchvision.utils.make_grid(x, nrow=4).permute(1, 2, 0).detach().cpu().numpy()
-     grid_img = (grid_img * 255).clip(0, 255).astype("uint8")  # map the [0, 1] range to [0, 255]
-     return Image.fromarray(grid_img)
-
- def frame_to_batch(x):
-     """
-     Fold the frame dimension into the batch dimension.
-     """
-     return rearrange(x, 'b f c h w -> (b f) c h w')
-
- def clip_image(x, min=0., max=1.):
-     """
-     Clamp an image tensor to the given minimum and maximum values.
-     """
-     return torch.clamp(x, min=min, max=max)
-
- def unnormalize(x):
-     """
-     Rescale a tensor from [-1, 1] to [0, 1].
-     """
-     return (x + 1) / 2
-
-
- # Read image files
- def read_images_from_directory(directory, num_frames=16):
-     images = []
-     for i in range(num_frames):
-         img_path = os.path.join(directory, f'{i:04d}.png')
-         img = imageio.imread(img_path)
-         images.append(torch.tensor(img).permute(2, 0, 1))  # convert to tensor (C, H, W)
-     return images
-
- def load_and_process_images(folder_path):
-     """
-     Read all images in a folder, convert them to tensors in [-1, 1], and return a 4D tensor.
-     """
-     processed_images = []
-     transform = transforms.Compose([
-         transforms.ToTensor(),
-         transforms.Lambda(lambda x: x * 2 - 1)  # map [0, 1] to [-1, 1]
-     ])
-     for filename in sorted(os.listdir(folder_path)):
-         if filename.endswith(".png"):
-             img_path = os.path.join(folder_path, filename)
-             image = Image.open(img_path).convert("RGB")
-             processed_image = transform(image)
-             processed_images.append(processed_image)
-     return torch.stack(processed_images)  # return a 4D tensor
-
- def load_and_process_video(video_path, num_frames=16, crop_size=512):
-     """
-     Read the first num_frames frames of a video file, convert each frame to a tensor
-     in [-1, 1], center-crop it to crop_size x crop_size, and return a 4D tensor.
-     """
-     processed_frames = []
-     transform = transforms.Compose([
-         transforms.CenterCrop(crop_size),  # center crop
-         transforms.ToTensor(),
-         transforms.Lambda(lambda x: x * 2 - 1)  # map [0, 1] to [-1, 1]
-     ])
-
-     # Read the video with OpenCV
-     cap = cv2.VideoCapture(video_path)
-
-     if not cap.isOpened():
-         raise ValueError(f"Cannot open video file: {video_path}")
-
-     frame_count = 0
-
-     while frame_count < num_frames:
-         ret, frame = cap.read()
-         if not ret:
-             break  # video exhausted or has too few frames
-
-         # Convert to RGB
-         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-         image = Image.fromarray(frame)
-
-         # Apply the transform
-         processed_frame = transform(image)
-         processed_frames.append(processed_frame)
-
-         frame_count += 1
-
-     cap.release()  # release the video handle
-
-     if len(processed_frames) < num_frames:
-         raise ValueError(f"Video has fewer than {num_frames} frames; only found {len(processed_frames)}.")
-
-     return torch.stack(processed_frames)  # return a 4D tensor (frames, channels, height, width)
-
-
- def clear_cache(output_path):
-     if os.path.exists(output_path):
-         os.remove(output_path)
-     return None
-
-
- #! Load the model
- # Config path; create and load the model
- config_path = 'configs/instruct_v2v_ic_gradio.yaml'
- diffusion_model = unit_test_create_model(config_path)
- diffusion_model = diffusion_model.to('cuda')
-
- # Load the model checkpoint
- # ckpt_path = 'models/relvid_mm_sd15_fbc_unet.pth' #! change
- # ckpt_path = 'tmp/pytorch_model.bin'
- # Download the weights
-
- os.makedirs('models', exist_ok=True)
- model_path = "models/relvid_mm_sd15_fbc_unet.pth"
-
- if not os.path.exists(model_path):
-     download_url_to_file(url='https://huggingface.co/aleafy/RelightVid/resolve/main/relvid_mm_sd15_fbc_unet.pth', dst=model_path)
-
-
- ckpt = torch.load(model_path, map_location='cpu')
- diffusion_model.load_state_dict(ckpt, strict=False)
-
-
- # import pdb; pdb.set_trace()
-
- # Change the global temp directory
- new_tmp_dir = "./demo/gradio_bg"
- os.makedirs(new_tmp_dir, exist_ok=True)
-
- # import pdb; pdb.set_trace()
-
- def save_video_from_frames(image_pred, save_pth, fps=8):
-     """
-     Save the frames in image_pred as a video file.
-
-     Args:
-     - image_pred: Tensor of shape (1, 16, 3, 512, 512)
-     - save_pth: path of the output video, e.g. "output_video.mp4"
-     - fps: frame rate of the video
-     """
-     # Video parameters
-     num_frames = image_pred.shape[1]
-     frame_height, frame_width = 512, 512  # target size
-     fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # mp4 codec
-
-     # Create the VideoWriter
-     out = cv2.VideoWriter(save_pth, fourcc, fps, (frame_width, frame_height))
-
-     for i in range(num_frames):
-         # Unnormalize and map to the 0-255 range
-         pred_frame = clip_image(unnormalize(image_pred[0][i].unsqueeze(0))) * 255
-         pred_frame_resized = pred_frame.squeeze(0).detach().cpu()  # (3, 512, 512)
-         pred_frame_resized = pred_frame_resized.permute(1, 2, 0).numpy().astype("uint8")  # (512, 512, 3)
-
-         # Resize to the target size (512x512)
-         pred_frame_resized = cv2.resize(pred_frame_resized, (frame_width, frame_height))
-
-         # Convert RGB to BGR (OpenCV uses BGR)
-         pred_frame_bgr = cv2.cvtColor(pred_frame_resized, cv2.COLOR_RGB2BGR)
-
-         # Write the frame to the video
-         out.write(pred_frame_bgr)
-
-     # Release the VideoWriter
-     out.release()
-     print(f"Video saved to {save_pth}")
-
-
- inf_pipe = InferenceIP2PVideo(
-     diffusion_model.unet,
-     scheduler='ddpm',
-     num_ddim_steps=20
- )
-
- # Dummy placeholder function (generates a blank video)
- def dummy_process(input_fg, input_bg):
-     # import pdb; pdb.set_trace()
-
-     diffusion_model.to(torch.float16)
-     fg_tensor = load_and_process_video(input_fg).cuda().unsqueeze(0).to(dtype=torch.float16)
-     bg_tensor = load_and_process_video(input_bg).cuda().unsqueeze(0).to(dtype=torch.float16)  # (1, 16, 4, 64, 64)
-
-     cond_fg_tensor = diffusion_model.encode_image_to_latent(fg_tensor)  # (1, 16, 4, 64, 64)
-     cond_bg_tensor = diffusion_model.encode_image_to_latent(bg_tensor)
-     cond_tensor = torch.cat((cond_fg_tensor, cond_bg_tensor), dim=2)
-
-     # Initialize the latent
-     init_latent = torch.randn_like(cond_fg_tensor)
-
-     EDIT_PROMPT = 'change the background'
-     VIDEO_CFG = 1.2
-     TEXT_CFG = 7.5
-     text_cond = diffusion_model.encode_text([EDIT_PROMPT])  # (1, 77, 768)
-     text_uncond = diffusion_model.encode_text([''])
-     # to float16
-     print('------------to float 16----------------')
-     init_latent, text_cond, text_uncond, cond_tensor = (
-         init_latent.to(dtype=torch.float16),
-         text_cond.to(dtype=torch.float16),
-         text_uncond.to(dtype=torch.float16),
-         cond_tensor.to(dtype=torch.float16)
-     )
-     inf_pipe.unet.to(torch.float16)
-     latent_pred = inf_pipe(
-         latent=init_latent,
-         text_cond=text_cond,
-         text_uncond=text_uncond,
-         img_cond=cond_tensor,
-         text_cfg=TEXT_CFG,
-         img_cfg=VIDEO_CFG,
-     )['latent']
-
-
-     image_pred = diffusion_model.decode_latent_to_image(latent_pred)  # (1,16,3,512,512)
-     output_path = os.path.join(new_tmp_dir, f"output_{int(time.time())}.mp4")
-     # clear_cache(output_path)
-
-     save_video_from_frames(image_pred, output_path)
-     # import pdb; pdb.set_trace()
-     # fps = 8
-     # frames = []
-     # for i in range(16):
-     #     pred_frame = clip_image(unnormalize(image_pred[0][i].unsqueeze(0))) * 255
-     #     pred_frame_resized = pred_frame.squeeze(0).detach().cpu()  # (3,512,512)
-     #     pred_frame_resized = pred_frame_resized.permute(1, 2, 0).detach().cpu().numpy().astype("uint8")  # (512,512,3) np
-     #     Image.fromarray(pred_frame_resized).save(save_pth)
-
-     # # Generate a simple black video as an example
-     # output_path = os.path.join(new_tmp_dir, "output.mp4")
-     # fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-     # out = cv2.VideoWriter(output_path, fourcc, 20.0, (512, 512))
-
-     # for _ in range(60):  # 3 seconds of video at 20 fps
-     #     frame = np.zeros((512, 512, 3), dtype=np.uint8)
-     #     out.write(frame)
-     # out.release()
-     torch.cuda.empty_cache()
-
-     return output_path
-
- # Enum for background selection
- class BGSource(Enum):
-     UPLOAD = "Use Background Video"
-     UPLOAD_FLIP = "Use Flipped Background Video"
-     UPLOAD_REVERSE = "Use Reversed Background Video"
-
- # Quick prompt examples
- quick_prompts = [
-     'beautiful woman',
-     'handsome man',
-     'beautiful woman, cinematic lighting',
-     'handsome man, cinematic lighting',
-     'beautiful woman, natural lighting',
-     'handsome man, natural lighting',
-     'beautiful woman, neo punk lighting, cyberpunk',
-     'handsome man, neo punk lighting, cyberpunk',
- ]
- quick_prompts = [[x] for x in quick_prompts]
-
- # Gradio UI layout
- block = gr.Blocks().queue()
- with block:
-     with gr.Row():
-         gr.Markdown("## IC-Light (Relighting with Foreground and Background Video Condition)")
-
-     with gr.Row():
-         with gr.Column():
-             input_fg = gr.Video(label="Foreground Video", height=450, visible=True)
-         with gr.Column():
-             input_bg = gr.Video(label="Background Video", height=450, visible=True)
-         with gr.Column():
-             result_video = gr.Video(label='Output Video', height=450, visible=True)
-
-     with gr.Row():
-         with gr.Column():
-             prompt = gr.Textbox(label="Prompt")
-             bg_source = gr.Radio(choices=[e.value for e in BGSource],
-                                  value=BGSource.UPLOAD.value,
-                                  label="Background Source", type='value')
-
-             example_prompts = gr.Dataset(samples=quick_prompts, label='Prompt Quick List', components=[prompt])
-             bg_gallery = gr.Gallery(height=450, object_fit='contain', label='Background Quick List', value=db_examples.bg_samples, columns=5, allow_preview=False)
-             relight_button = gr.Button(value="Relight")
-
-             with gr.Group():
-                 with gr.Row():
-                     num_samples = gr.Slider(label="Videos", minimum=1, maximum=12, value=1, step=1)
-                     seed = gr.Number(label="Seed", value=12345, precision=0)
-                 with gr.Row():
-                     video_width = gr.Slider(label="Video Width", minimum=256, maximum=1024, value=512, step=64)
-                     video_height = gr.Slider(label="Video Height", minimum=256, maximum=1024, value=640, step=64)
-
-
-         with gr.Column():
-             with gr.Accordion("Advanced options", open=False):
-                 steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
-                 cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=7.0, step=0.01)
-                 highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=3.0, value=1.5, step=0.01)
-                 highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=0.9, value=0.5, step=0.01)
-                 a_prompt = gr.Textbox(label="Added Prompt", value='best quality')
-                 n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality')
-                 normal_button = gr.Button(value="Compute Normal (4x Slower)")
-
-
-     # Input list
-     # ips = [input_fg, input_bg, prompt, video_width, video_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source]
-     ips = [input_fg, input_bg]
-
-     # Bind the handlers to the buttons
-     # relight_button.click(fn=lambda: None, inputs=[], outputs=[result_video])
-
-     relight_button.click(fn=dummy_process, inputs=ips, outputs=[result_video])
-
-     normal_button.click(fn=dummy_process, inputs=ips, outputs=[result_video])
-
-     # Background gallery selection
-     def bg_gallery_selected(gal, evt: gr.SelectData):
-         # import pdb; pdb.set_trace()
-         # img_path = gal[evt.index][0]
-         img_path = db_examples.bg_samples[evt.index]
-         video_path = img_path.replace('frames/0000.png', 'cropped_video.mp4')
-         return video_path
-
-     bg_gallery.select(bg_gallery_selected, inputs=bg_gallery, outputs=input_bg)
-
-     # Examples
-     # dummy_video_for_outputs = gr.Video(visible=False, label='Result')
-     gr.Examples(
-         fn=lambda *args: args[-1],
-         examples=db_examples.background_conditioned_examples,
-         inputs=[input_fg, input_bg, prompt, bg_source, video_width, video_height, seed, result_video],
-         outputs=[result_video],
-         run_on_click=True, examples_per_page=1024
-     )
-
- # Launch the Gradio app
- block.launch(server_name='0.0.0.0', server_port=10002, share=True)

app_bf.py DELETED
@@ -1,391 +0,0 @@
- import os
- import gradio as gr
- import numpy as np
- from enum import Enum
- import db_examples
- import cv2
-
- import spaces
-
- from demo_utils1 import *
-
- from misc_utils.train_utils import unit_test_create_model
- from misc_utils.image_utils import save_tensor_to_gif, save_tensor_to_images
- import os
- from PIL import Image
- import torch
- import torchvision
- from torchvision import transforms
- from einops import rearrange
- import imageio
- import time
-
- from torchvision.transforms import functional as F
- from torch.hub import download_url_to_file
-
- import os
-
- # Inference setup
- from pl_trainer.inference.inference import InferenceIP2PVideo
- from tqdm import tqdm
-
-
- # if not os.path.exists(filename):
- #     original_path = os.getcwd()
- #     base_path = './models'
- #     os.makedirs(base_path, exist_ok=True)
-
- #     # Token written directly in the code (beware of the security risk)
- #     GIT_TOKEN = "955b8ea91095840b76fe38b90a088c200d4c813c"
- #     repo_url = f"https://YeFang:{GIT_TOKEN}@code.openxlab.org.cn/YeFang/RIV_models.git"
-
- #     try:
- #         if os.system(f'git clone {repo_url} {base_path}') != 0:
- #             raise RuntimeError("Git clone failed")
- #         os.chdir(base_path)
- #         if os.system('git lfs pull') != 0:
- #             raise RuntimeError("Git LFS pull failed")
- #     finally:
- #         os.chdir(original_path)
-
- def tensor_to_pil_image(x):
-     """
-     Convert a 4D PyTorch tensor to a PIL image.
-     """
-     x = x.float()  # make sure the tensor is float
-     grid_img = torchvision.utils.make_grid(x, nrow=4).permute(1, 2, 0).detach().cpu().numpy()
-     grid_img = (grid_img * 255).clip(0, 255).astype("uint8")  # map the [0, 1] range to [0, 255]
-     return Image.fromarray(grid_img)
-
- def frame_to_batch(x):
-     """
-     Fold the frame dimension into the batch dimension.
-     """
-     return rearrange(x, 'b f c h w -> (b f) c h w')
-
- def clip_image(x, min=0., max=1.):
-     """
-     Clamp an image tensor to the given minimum and maximum values.
-     """
-     return torch.clamp(x, min=min, max=max)
-
- def unnormalize(x):
-     """
-     Rescale a tensor from [-1, 1] to [0, 1].
-     """
-     return (x + 1) / 2
-
-
- # Read image files
- def read_images_from_directory(directory, num_frames=16):
-     images = []
-     for i in range(num_frames):
-         img_path = os.path.join(directory, f'{i:04d}.png')
-         img = imageio.imread(img_path)
-         images.append(torch.tensor(img).permute(2, 0, 1))  # convert to tensor (C, H, W)
-     return images
-
- def load_and_process_images(folder_path):
-     """
-     Read all images in a folder, convert them to tensors in [-1, 1], and return a 4D tensor.
-     """
-     processed_images = []
-     transform = transforms.Compose([
-         transforms.ToTensor(),
-         transforms.Lambda(lambda x: x * 2 - 1)  # map [0, 1] to [-1, 1]
-     ])
-     for filename in sorted(os.listdir(folder_path)):
-         if filename.endswith(".png"):
-             img_path = os.path.join(folder_path, filename)
-             image = Image.open(img_path).convert("RGB")
-             processed_image = transform(image)
-             processed_images.append(processed_image)
-     return torch.stack(processed_images)  # return a 4D tensor
-
- def load_and_process_video(video_path, num_frames=16, crop_size=512):
-     """
-     Read the first num_frames frames of a video file, convert each frame to a tensor
-     in [-1, 1], center-crop it to crop_size x crop_size, and return a 4D tensor.
-     """
-     processed_frames = []
-     transform = transforms.Compose([
-         transforms.CenterCrop(crop_size),  # center crop
-         transforms.ToTensor(),
-         transforms.Lambda(lambda x: x * 2 - 1)  # map [0, 1] to [-1, 1]
-     ])
-
-     # Read the video with OpenCV
-     cap = cv2.VideoCapture(video_path)
-
-     if not cap.isOpened():
-         raise ValueError(f"Cannot open video file: {video_path}")
-
-     frame_count = 0
-
-     while frame_count < num_frames:
-         ret, frame = cap.read()
-         if not ret:
-             break  # video exhausted or has too few frames
-
-         # Convert to RGB
-         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-         image = Image.fromarray(frame)
-
-         # Apply the transform
-         processed_frame = transform(image)
-         processed_frames.append(processed_frame)
-
-         frame_count += 1
-
-     cap.release()  # release the video handle
-
-     if len(processed_frames) < num_frames:
-         raise ValueError(f"Video has fewer than {num_frames} frames; only found {len(processed_frames)}.")
-
-     return torch.stack(processed_frames)  # return a 4D tensor (frames, channels, height, width)
-
-
- def clear_cache(output_path):
-     if os.path.exists(output_path):
-         os.remove(output_path)
-     return None
-
-
- #! Load the model
- # Config path; create and load the model
- config_path = 'configs/instruct_v2v_ic_gradio.yaml'
- diffusion_model = unit_test_create_model(config_path)
- diffusion_model = diffusion_model.to('cuda')
-
- # Load the model checkpoint
- # ckpt_path = 'models/relvid_mm_sd15_fbc_unet.pth' #! change
- # ckpt_path = 'tmp/pytorch_model.bin'
- # Download the weights
-
- os.makedirs('models', exist_ok=True)
- model_path = "models/relvid_mm_sd15_fbc_unet.pth"
-
- if not os.path.exists(model_path):
-     download_url_to_file(url='https://huggingface.co/aleafy/RelightVid/resolve/main/relvid_mm_sd15_fbc_unet.pth', dst=model_path)
-
-
- ckpt = torch.load(model_path, map_location='cpu')
- diffusion_model.load_state_dict(ckpt, strict=False)
-
-
- # import pdb; pdb.set_trace()
-
- # Change the global temp directory
- new_tmp_dir = "./demo/gradio_bg"
- os.makedirs(new_tmp_dir, exist_ok=True)
-
- # import pdb; pdb.set_trace()
-
- def save_video_from_frames(image_pred, save_pth, fps=8):
-     """
-     Save the frames in image_pred as a video file.
-
-     Args:
-     - image_pred: Tensor of shape (1, 16, 3, 512, 512)
-     - save_pth: path of the output video, e.g. "output_video.mp4"
-     - fps: frame rate of the video
-     """
-     # Video parameters
-     num_frames = image_pred.shape[1]
-     frame_height, frame_width = 512, 512  # target size
-     fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # mp4 codec
-
-     # Create the VideoWriter
-     out = cv2.VideoWriter(save_pth, fourcc, fps, (frame_width, frame_height))
-
-     for i in range(num_frames):
-         # Unnormalize and map to the 0-255 range
-         pred_frame = clip_image(unnormalize(image_pred[0][i].unsqueeze(0))) * 255
-         pred_frame_resized = pred_frame.squeeze(0).detach().cpu()  # (3, 512, 512)
-         pred_frame_resized = pred_frame_resized.permute(1, 2, 0).numpy().astype("uint8")  # (512, 512, 3)
-
-         # Resize to the target size (512x512)
-         pred_frame_resized = cv2.resize(pred_frame_resized, (frame_width, frame_height))
-
-         # Convert RGB to BGR (OpenCV uses BGR)
-         pred_frame_bgr = cv2.cvtColor(pred_frame_resized, cv2.COLOR_RGB2BGR)
-
-         # Write the frame to the video
-         out.write(pred_frame_bgr)
-
-     # Release the VideoWriter
-     out.release()
-     print(f"Video saved to {save_pth}")
-
-
- inf_pipe = InferenceIP2PVideo(
-     diffusion_model.unet,
-     scheduler='ddpm',
-     num_ddim_steps=20
- )
-
- # Dummy placeholder function (generates a blank video)
- @spaces.GPU
- def dummy_process(input_fg, input_bg):
-     # import pdb; pdb.set_trace()
-
-     diffusion_model.to(torch.float16)
-     fg_tensor = load_and_process_video(input_fg).cuda().unsqueeze(0).to(dtype=torch.float16)
-     bg_tensor = load_and_process_video(input_bg).cuda().unsqueeze(0).to(dtype=torch.float16)  # (1, 16, 4, 64, 64)
-
-     cond_fg_tensor = diffusion_model.encode_image_to_latent(fg_tensor)  # (1, 16, 4, 64, 64)
-     cond_bg_tensor = diffusion_model.encode_image_to_latent(bg_tensor)
-     cond_tensor = torch.cat((cond_fg_tensor, cond_bg_tensor), dim=2)
-
-     # Initialize the latent
-     init_latent = torch.randn_like(cond_fg_tensor)
-
-     EDIT_PROMPT = 'change the background'
-     VIDEO_CFG = 1.2
-     TEXT_CFG = 7.5
-     text_cond = diffusion_model.encode_text([EDIT_PROMPT])  # (1, 77, 768)
-     text_uncond = diffusion_model.encode_text([''])
-     # to float16
-     print('------------to float 16----------------')
-     init_latent, text_cond, text_uncond, cond_tensor = (
-         init_latent.to(dtype=torch.float16),
-         text_cond.to(dtype=torch.float16),
-         text_uncond.to(dtype=torch.float16),
-         cond_tensor.to(dtype=torch.float16)
-     )
-     inf_pipe.unet.to(torch.float16)
-     latent_pred = inf_pipe(
-         latent=init_latent,
-         text_cond=text_cond,
-         text_uncond=text_uncond,
-         img_cond=cond_tensor,
-         text_cfg=TEXT_CFG,
-         img_cfg=VIDEO_CFG,
-     )['latent']
-
-
-     image_pred = diffusion_model.decode_latent_to_image(latent_pred)  # (1,16,3,512,512)
-     output_path = os.path.join(new_tmp_dir, f"output_{int(time.time())}.mp4")
-     # clear_cache(output_path)
-
-     save_video_from_frames(image_pred, output_path)
-     # import pdb; pdb.set_trace()
-     # fps = 8
-     # frames = []
-     # for i in range(16):
-     #     pred_frame = clip_image(unnormalize(image_pred[0][i].unsqueeze(0))) * 255
-     #     pred_frame_resized = pred_frame.squeeze(0).detach().cpu()  # (3,512,512)
-     #     pred_frame_resized = pred_frame_resized.permute(1, 2, 0).detach().cpu().numpy().astype("uint8")  # (512,512,3) np
-     #     Image.fromarray(pred_frame_resized).save(save_pth)
-
-     # # Generate a simple black video as an example
-     # output_path = os.path.join(new_tmp_dir, "output.mp4")
-     # fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-     # out = cv2.VideoWriter(output_path, fourcc, 20.0, (512, 512))
-
-     # for _ in range(60):  # 3 seconds of video at 20 fps
-     #     frame = np.zeros((512, 512, 3), dtype=np.uint8)
-     #     out.write(frame)
-     # out.release()
-     torch.cuda.empty_cache()
-
-     return output_path
-
- # Enum for background selection
- class BGSource(Enum):
-     UPLOAD = "Use Background Video"
-     UPLOAD_FLIP = "Use Flipped Background Video"
-     LEFT = "Left Light"
-     RIGHT = "Right Light"
-     TOP = "Top Light"
-     BOTTOM = "Bottom Light"
-     GREY = "Ambient"
-
- # Quick prompt examples
- quick_prompts = [
-     'beautiful woman',
-     'handsome man',
-     'beautiful woman, cinematic lighting',
-     'handsome man, cinematic lighting',
-     'beautiful woman, natural lighting',
-     'handsome man, natural lighting',
-     'beautiful woman, neo punk lighting, cyberpunk',
-     'handsome man, neo punk lighting, cyberpunk',
- ]
- quick_prompts = [[x] for x in quick_prompts]
-
- # Gradio UI layout
- block = gr.Blocks().queue()
- with block:
-     with gr.Row():
-         gr.Markdown("## IC-Light (Relighting with Foreground and Background Video Condition)")
-
-     with gr.Row():
-         with gr.Column():
-             with gr.Row():
-                 input_fg = gr.Video(label="Foreground Video", height=370, width=370, visible=True)
-                 input_bg = gr.Video(label="Background Video", height=370, width=370, visible=True)
-
-             prompt = gr.Textbox(label="Prompt")
-             bg_source = gr.Radio(choices=[e.value for e in BGSource],
-                                  value=BGSource.UPLOAD.value,
-                                  label="Background Source", type='value')
-
-             example_prompts = gr.Dataset(samples=quick_prompts, label='Prompt Quick List', components=[prompt])
-             bg_gallery = gr.Gallery(height=450, object_fit='contain', label='Background Quick List', value=db_examples.bg_samples, columns=5, allow_preview=False)
-             relight_button = gr.Button(value="Relight")
-
-             with gr.Group():
-                 with gr.Row():
-                     num_samples = gr.Slider(label="Videos", minimum=1, maximum=12, value=1, step=1)
-                     seed = gr.Number(label="Seed", value=12345, precision=0)
-                 with gr.Row():
-                     video_width = gr.Slider(label="Video Width", minimum=256, maximum=1024, value=512, step=64)
-                     video_height = gr.Slider(label="Video Height", minimum=256, maximum=1024, value=640, step=64)
-
-             with gr.Accordion("Advanced options", open=False):
-                 steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
-                 cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=7.0, step=0.01)
-                 highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=3.0, value=1.5, step=0.01)
-                 highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=0.9, value=0.5, step=0.01)
-                 a_prompt = gr.Textbox(label="Added Prompt", value='best quality')
-                 n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality')
-                 normal_button = gr.Button(value="Compute Normal (4x Slower)")
-
-         with gr.Column():
-             result_video = gr.Video(label='Output Video', height=600, width=600, visible=True)
-
-     # Input list
-     # ips = [input_fg, input_bg, prompt, video_width, video_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source]
-     ips = [input_fg, input_bg]
-
-     # Bind the handlers to the buttons
-     # relight_button.click(fn=lambda: None, inputs=[], outputs=[result_video])
-
-     relight_button.click(fn=dummy_process, inputs=ips, outputs=[result_video])
-
-     normal_button.click(fn=dummy_process, inputs=ips, outputs=[result_video])
-
-     # Background gallery selection
-     def bg_gallery_selected(gal, evt: gr.SelectData):
-         # import pdb; pdb.set_trace()
-         # img_path = gal[evt.index][0]
-         img_path = db_examples.bg_samples[evt.index]
-         video_path = img_path.replace('frames/0000.png', 'cropped_video.mp4')
-         return video_path
-
-     bg_gallery.select(bg_gallery_selected, inputs=bg_gallery, outputs=input_bg)
-
-     # Examples
-     # dummy_video_for_outputs = gr.Video(visible=False, label='Result')
-     gr.Examples(
-         fn=lambda *args: args[-1],
-         examples=db_examples.background_conditioned_examples,
-         inputs=[input_fg, input_bg, prompt, bg_source, video_width, video_height, seed, result_video],
-         outputs=[result_video],
-         run_on_click=True, examples_per_page=1024
-     )
-
- # Launch the Gradio app
- # block.launch(server_name='0.0.0.0', server_port=10002, share=True)
- block.launch()