aleafy committed on
Commit
5e5f393
·
2 Parent(s): 2be5fe1 8134846
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +0 -2
  2. __pycache__/db_examples.cpython-310.pyc +0 -0
  3. __pycache__/demo_utils1.cpython-310.pyc +0 -0
  4. app.py +1 -1
  5. app1_rg3.py +503 -0
  6. app_bf1.py +498 -0
  7. db_examples.py +80 -28
  8. demo/clean_bg_extracted/0/cropped_video.mp4 +0 -0
  9. demo/clean_bg_extracted/0/frames/0000.png +0 -0
  10. demo/clean_bg_extracted/1/cropped_video.mp4 +0 -0
  11. demo/clean_bg_extracted/1/frames/0000.png +0 -0
  12. demo/clean_bg_extracted/2/cropped_video.mp4 +0 -0
  13. demo/clean_bg_extracted/2/frames/0000.png +0 -0
  14. demo/clean_fg_extracted/1/cropped_video.mp4 +0 -0
  15. demo/clean_fg_extracted/1/frames/0000.png +0 -0
  16. demo/clean_fg_extracted/10/cropped_video.mp4 +0 -0
  17. demo/clean_fg_extracted/10/frames/0000.png +0 -0
  18. demo/clean_fg_extracted/11/cropped_video.mp4 +0 -0
  19. demo/clean_fg_extracted/11/frames/0000.png +0 -0
  20. demo/clean_fg_extracted/12/cropped_video.mp4 +0 -0
  21. demo/clean_fg_extracted/12/frames/0000.png +0 -0
  22. demo/clean_fg_extracted/13/cropped_video.mp4 +0 -0
  23. demo/clean_fg_extracted/13/frames/0000.png +0 -0
  24. demo/clean_fg_extracted/16/cropped_video.mp4 +0 -0
  25. demo/clean_fg_extracted/16/frames/0000.png +0 -0
  26. demo/clean_fg_extracted/17/cropped_video.mp4 +0 -0
  27. demo/clean_fg_extracted/17/frames/0000.png +0 -0
  28. demo/clean_fg_extracted/2/cropped_video.mp4 +0 -0
  29. demo/clean_fg_extracted/2/frames/0000.png +0 -0
  30. demo/clean_fg_extracted/3/cropped_video.mp4 +0 -0
  31. demo/clean_fg_extracted/3/frames/0000.png +0 -0
  32. demo/clean_fg_extracted/4/cropped_video.mp4 +0 -0
  33. demo/clean_fg_extracted/4/frames/0000.png +0 -0
  34. demo/clean_fg_extracted/5/cropped_video.mp4 +0 -0
  35. demo/clean_fg_extracted/5/frames/0000.png +0 -0
  36. demo/clean_fg_extracted/6/3.mp4 +0 -0
  37. demo/clean_fg_extracted/6/frames/0000.png +0 -0
  38. demo/clean_fg_extracted/7/cropped_video.mp4 +0 -0
  39. demo/clean_fg_extracted/7/frames/0000.png +0 -0
  40. demo/clean_fg_extracted/8/cropped_video.mp4 +0 -0
  41. demo/clean_fg_extracted/8/frames/0000.png +0 -0
  42. misc_utils/__pycache__/flow_utils.cpython-310.pyc +0 -0
  43. misc_utils/__pycache__/image_utils.cpython-310.pyc +0 -0
  44. misc_utils/__pycache__/model_utils.cpython-310.pyc +0 -0
  45. misc_utils/__pycache__/train_utils.cpython-310.pyc +0 -0
  46. modules/openclip/__pycache__/modules.cpython-310.pyc +0 -0
  47. modules/video_unet_temporal/__pycache__/attention.cpython-310.pyc +0 -0
  48. modules/video_unet_temporal/__pycache__/motion_module.cpython-310.pyc +0 -0
  49. modules/video_unet_temporal/__pycache__/resnet.cpython-310.pyc +0 -0
  50. modules/video_unet_temporal/__pycache__/unet.cpython-310.pyc +0 -0
.gitignore CHANGED
@@ -1,5 +1,3 @@
1
- app1.py
2
- app2.py
3
  demo_utils1.py
4
  tmp
5
  models
 
 
 
1
  demo_utils1.py
2
  tmp
3
  models
__pycache__/db_examples.cpython-310.pyc DELETED
Binary file (1.94 kB)
 
__pycache__/demo_utils1.cpython-310.pyc DELETED
Binary file (470 Bytes)
 
app.py CHANGED
@@ -495,4 +495,4 @@ with block:
495
 
496
  # 启动 Gradio 应用
497
  # block.launch(server_name='0.0.0.0', server_port=10002, share=True)
498
- block.launch()
 
495
 
496
  # 启动 Gradio 应用
497
  # block.launch(server_name='0.0.0.0', server_port=10002, share=True)
498
+ block.launch()
app1_rg3.py ADDED
@@ -0,0 +1,503 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import numpy as np
4
+ from enum import Enum
5
+ import db_examples
6
+ import cv2
7
+
8
+
9
+ from demo_utils1 import *
10
+
11
+ from misc_utils.train_utils import unit_test_create_model
12
+ from misc_utils.image_utils import save_tensor_to_gif, save_tensor_to_images
13
+ import os
14
+ from PIL import Image
15
+ import torch
16
+ import torchvision
17
+ from torchvision import transforms
18
+ from einops import rearrange
19
+ import imageio
20
+ import time
21
+
22
+ from torchvision.transforms import functional as F
23
+ from torch.hub import download_url_to_file
24
+
25
+ import os
26
+
27
+ # 推理设置
28
+ from pl_trainer.inference.inference import InferenceIP2PVideo
29
+ from tqdm import tqdm
30
+
31
+
32
+ # if not os.path.exists(filename):
33
+ # original_path = os.getcwd()
34
+ # base_path = './models'
35
+ # os.makedirs(base_path, exist_ok=True)
36
+
37
+ # # 直接在代码中写入 Token(注意安全风险)
38
+ # GIT_TOKEN = "955b8ea91095840b76fe38b90a088c200d4c813c"
39
+ # repo_url = f"https://YeFang:{GIT_TOKEN}@code.openxlab.org.cn/YeFang/RIV_models.git"
40
+
41
+ # try:
42
+ # if os.system(f'git clone {repo_url} {base_path}') != 0:
43
+ # raise RuntimeError("Git 克隆失败")
44
+ # os.chdir(base_path)
45
+ # if os.system('git lfs pull') != 0:
46
+ # raise RuntimeError("Git LFS 拉取失败")
47
+ # finally:
48
+ # os.chdir(original_path)
49
+
50
+ def tensor_to_pil_image(x):
51
+ """
52
+ 将 4D PyTorch 张量转换为 PIL 图像。
53
+ """
54
+ x = x.float() # 确保张量类型为 float
55
+ grid_img = torchvision.utils.make_grid(x, nrow=4).permute(1, 2, 0).detach().cpu().numpy()
56
+ grid_img = (grid_img * 255).clip(0, 255).astype("uint8") # 将 [0, 1] 范围转换为 [0, 255]
57
+ return Image.fromarray(grid_img)
58
+
59
+ def frame_to_batch(x):
60
+ """
61
+ 将帧维度转换为批次维度。
62
+ """
63
+ return rearrange(x, 'b f c h w -> (b f) c h w')
64
+
65
+ def clip_image(x, min=0., max=1.):
66
+ """
67
+ 将图像张量裁剪到指定的最小和最大值。
68
+ """
69
+ return torch.clamp(x, min=min, max=max)
70
+
71
+ def unnormalize(x):
72
+ """
73
+ 将张量范围从 [-1, 1] 转换到 [0, 1]。
74
+ """
75
+ return (x + 1) / 2
76
+
77
+
78
+ # 读取图像文件
79
+ def read_images_from_directory(directory, num_frames=16):
80
+ images = []
81
+ for i in range(num_frames):
82
+ img_path = os.path.join(directory, f'{i:04d}.png')
83
+ img = imageio.imread(img_path)
84
+ images.append(torch.tensor(img).permute(2, 0, 1)) # Convert to Tensor (C, H, W)
85
+ return images
86
+
87
+ def load_and_process_images(folder_path):
88
+ """
89
+ 读取文件夹中的所有图片,将它们转换为 [-1, 1] 范围的张量并返回一个 4D 张量。
90
+ """
91
+ processed_images = []
92
+ transform = transforms.Compose([
93
+ transforms.ToTensor(),
94
+ transforms.Lambda(lambda x: x * 2 - 1) # 将 [0, 1] 转换为 [-1, 1]
95
+ ])
96
+ for filename in sorted(os.listdir(folder_path)):
97
+ if filename.endswith(".png"):
98
+ img_path = os.path.join(folder_path, filename)
99
+ image = Image.open(img_path).convert("RGB")
100
+ processed_image = transform(image)
101
+ processed_images.append(processed_image)
102
+ return torch.stack(processed_images) # 返回 4D 张量
103
+
104
+ def load_and_process_video(video_path, num_frames=16, crop_size=512):
105
+ """
106
+ 读取视频文件中的前 num_frames 帧,将每一帧转换为 [-1, 1] 范围的张量,
107
+ 并进行中心裁剪至 crop_size x crop_size,返回一个 4D 张量。
108
+ """
109
+ processed_frames = []
110
+ transform = transforms.Compose([
111
+ transforms.CenterCrop(crop_size), # 中心裁剪
112
+ transforms.ToTensor(),
113
+ transforms.Lambda(lambda x: x * 2 - 1) # 将 [0, 1] 转换为 [-1, 1]
114
+ ])
115
+
116
+ # 使用 OpenCV 读取视频
117
+ cap = cv2.VideoCapture(video_path)
118
+
119
+ if not cap.isOpened():
120
+ raise ValueError(f"无法打开视频文件: {video_path}")
121
+
122
+ frame_count = 0
123
+
124
+ while frame_count < num_frames:
125
+ ret, frame = cap.read()
126
+ if not ret:
127
+ break # 视频帧读取完毕或视频帧不足
128
+
129
+ # 转换为 RGB 格式
130
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
131
+ image = Image.fromarray(frame)
132
+
133
+ # 应用转换
134
+ processed_frame = transform(image)
135
+ processed_frames.append(processed_frame)
136
+
137
+ frame_count += 1
138
+
139
+ cap.release() # 释放视频资源
140
+
141
+ if len(processed_frames) < num_frames:
142
+ raise ValueError(f"视频帧不足 {num_frames} 帧,仅找到 {len(processed_frames)} 帧。")
143
+
144
+ return torch.stack(processed_frames) # 返回 4D 张量 (帧数, 通道数, 高度, 宽度)
145
+
146
+
147
+ def clear_cache(output_path):
148
+ if os.path.exists(output_path):
149
+ os.remove(output_path)
150
+ return None
151
+
152
+
153
+ #! 加载模型
154
+ # 配置路径和加载模型
155
+ config_path = 'configs/instruct_v2v_ic_gradio.yaml'
156
+ diffusion_model = unit_test_create_model(config_path)
157
+ diffusion_model = diffusion_model.to('cuda')
158
+
159
+ # 加载模型检查点
160
+ # ckpt_path = 'models/relvid_mm_sd15_fbc_unet.pth' #! change
161
+ # ckpt_path = 'tmp/pytorch_model.bin'
162
+ # 下载文件
163
+
164
+ os.makedirs('models', exist_ok=True)
165
+ model_path = "models/relvid_mm_sd15_fbc_unet.pth"
166
+
167
+ if not os.path.exists(model_path):
168
+ download_url_to_file(url='https://huggingface.co/aleafy/RelightVid/resolve/main/relvid_mm_sd15_fbc_unet.pth', dst=model_path)
169
+
170
+
171
+ ckpt = torch.load(model_path, map_location='cpu')
172
+ diffusion_model.load_state_dict(ckpt, strict=False)
173
+
174
+
175
+ # import pdb; pdb.set_trace()
176
+
177
+ # 更改全局临时目录
178
+ new_tmp_dir = "./demo/gradio_bg"
179
+ os.makedirs(new_tmp_dir, exist_ok=True)
180
+
181
+ # import pdb; pdb.set_trace()
182
+
183
+ def save_video_from_frames(image_pred, save_pth, fps=8):
184
+ """
185
+ 将 image_pred 中的帧保存为视频文件。
186
+
187
+ 参数:
188
+ - image_pred: Tensor,形状为 (1, 16, 3, 512, 512)
189
+ - save_pth: 保存视频的路径,例如 "output_video.mp4"
190
+ - fps: 视频的帧率
191
+ """
192
+ # 视频参数
193
+ num_frames = image_pred.shape[1]
194
+ frame_height, frame_width = 512, 512 # 目标尺寸
195
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v') # 使用 mp4 编码格式
196
+
197
+ # 创建 VideoWriter 对象
198
+ out = cv2.VideoWriter(save_pth, fourcc, fps, (frame_width, frame_height))
199
+
200
+ for i in range(num_frames):
201
+ # 反归一化 + 转换为 0-255 范围
202
+ pred_frame = clip_image(unnormalize(image_pred[0][i].unsqueeze(0))) * 255
203
+ pred_frame_resized = pred_frame.squeeze(0).detach().cpu() # (3, 512, 512)
204
+ pred_frame_resized = pred_frame_resized.permute(1, 2, 0).numpy().astype("uint8") # (512, 512, 3)
205
+
206
+ # Resize 到 512x512(即 frame_width x frame_height)
207
+ pred_frame_resized = cv2.resize(pred_frame_resized, (frame_width, frame_height))
208
+
209
+ # 将 RGB 转为 BGR(因为 OpenCV 使用 BGR 格式)
210
+ pred_frame_bgr = cv2.cvtColor(pred_frame_resized, cv2.COLOR_RGB2BGR)
211
+
212
+ # 写入帧到视频
213
+ out.write(pred_frame_bgr)
214
+
215
+ # 释放 VideoWriter 资源
216
+ out.release()
217
+ print(f"视频已保存至 {save_pth}")
218
+
219
+
220
+ inf_pipe = InferenceIP2PVideo(
221
+ diffusion_model.unet,
222
+ scheduler='ddpm',
223
+ num_ddim_steps=20
224
+ )
225
+
226
+
227
+ def process_example(*args):
228
+ v_index = args[0]
229
+ select_e = db_examples.background_conditioned_examples[int(v_index)-1]
230
+ input_fg_path = select_e[1]
231
+ input_bg_path = select_e[2]
232
+ result_video_path = select_e[-1]
233
+ # input_fg_img = args[1] # 第 0 个参数
234
+ # input_bg_img = args[2] # 第 1 个参数
235
+ # result_video_img = args[-1] # 最后一个参数
236
+
237
+ input_fg = input_fg_path.replace("frames/0000.png", "cropped_video.mp4")
238
+ input_bg = input_bg_path.replace("frames/0000.png", "cropped_video.mp4")
239
+ result_video = result_video_path.replace(".png", ".mp4")
240
+
241
+ return input_fg, input_bg, result_video
242
+
243
+
244
+
245
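+ # Inference entry point (the name `dummy_process` is historical, it is not a placeholder): encodes the foreground/background videos to latents, runs the diffusion pipeline, saves the decoded frames as an mp4 and returns its path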
246
+ def dummy_process(input_fg, input_bg, prompt):
247
+ # import pdb; pdb.set_trace()
248
+
249
+ diffusion_model.to(torch.float16)
250
+ fg_tensor = load_and_process_video(input_fg).cuda().unsqueeze(0).to(dtype=torch.float16)
251
+ bg_tensor = load_and_process_video(input_bg).cuda().unsqueeze(0).to(dtype=torch.float16) # (1, 16, 3, 512, 512)
252
+
253
+ cond_fg_tensor = diffusion_model.encode_image_to_latent(fg_tensor) # (1, 16, 4, 64, 64)
254
+ cond_bg_tensor = diffusion_model.encode_image_to_latent(bg_tensor)
255
+ cond_tensor = torch.cat((cond_fg_tensor, cond_bg_tensor), dim=2)
256
+
257
+ # 初始化潜变量
258
+ init_latent = torch.randn_like(cond_fg_tensor)
259
+
260
+ # EDIT_PROMPT = 'change the background'
261
+ EDIT_PROMPT = prompt
262
+ VIDEO_CFG = 1.2
263
+ TEXT_CFG = 7.5
264
+ text_cond = diffusion_model.encode_text([EDIT_PROMPT]) # (1, 77, 768)
265
+ text_uncond = diffusion_model.encode_text([''])
266
+ # to float16
267
+ print('------------to float 16----------------')
268
+ init_latent, text_cond, text_uncond, cond_tensor = (
269
+ init_latent.to(dtype=torch.float16),
270
+ text_cond.to(dtype=torch.float16),
271
+ text_uncond.to(dtype=torch.float16),
272
+ cond_tensor.to(dtype=torch.float16)
273
+ )
274
+ inf_pipe.unet.to(torch.float16)
275
+ latent_pred = inf_pipe(
276
+ latent=init_latent,
277
+ text_cond=text_cond,
278
+ text_uncond=text_uncond,
279
+ img_cond=cond_tensor,
280
+ text_cfg=TEXT_CFG,
281
+ img_cfg=VIDEO_CFG,
282
+ )['latent']
283
+
284
+
285
+ image_pred = diffusion_model.decode_latent_to_image(latent_pred) # (1,16,3,512,512)
286
+ output_path = os.path.join(new_tmp_dir, f"output_{int(time.time())}.mp4")
287
+ # clear_cache(output_path)
288
+
289
+ save_video_from_frames(image_pred, output_path)
290
+ # import pdb; pdb.set_trace()
291
+ # fps = 8
292
+ # frames = []
293
+ # for i in range(16):
294
+ # pred_frame = clip_image(unnormalize(image_pred[0][i].unsqueeze(0))) * 255
295
+ # pred_frame_resized = pred_frame.squeeze(0).detach().cpu() #(3,512,512)
296
+ # pred_frame_resized = pred_frame_resized.permute(1, 2, 0).detach().cpu().numpy().astype("uint8") #(512,512,3) np
297
+ # Image.fromarray(pred_frame_resized).save(save_pth)
298
+
299
+ # # 生成一个简单的黑色视频作为示例
300
+ # output_path = os.path.join(new_tmp_dir, "output.mp4")
301
+ # fourcc = cv2.VideoWriter_fourcc(*'mp4v')
302
+ # out = cv2.VideoWriter(output_path, fourcc, 20.0, (512, 512))
303
+
304
+ # for _ in range(60): # 生成 3 秒的视频(20fps)
305
+ # frame = np.zeros((512, 512, 3), dtype=np.uint8)
306
+ # out.write(frame)
307
+ # out.release()
308
+ torch.cuda.empty_cache()
309
+
310
+ return output_path
311
+
312
+ # 枚举类用于背景选择
313
+ class BGSource(Enum):
314
+ UPLOAD = "Use Background Video"
315
+ UPLOAD_FLIP = "Use Flipped Background Video"
316
+ UPLOAD_REVERSE = "Use Reversed Background Video"
317
+
318
+
319
+ # Quick prompts 示例
320
+ # quick_prompts = [
321
+ # 'beautiful woman, fantasy setting',
322
+ # 'beautiful woman, neon dynamic lighting',
323
+ # 'man in suit, tunel lighting',
324
+ # 'animated mouse, aesthetic lighting',
325
+ # 'robot warrior, a sunset background',
326
+ # 'yellow cat, reflective wet beach',
327
+ # 'camera, dock, calm sunset',
328
+ # 'astronaut, dim lighting',
329
+ # 'astronaut, colorful balloons',
330
+ # 'astronaut, desert landscape'
331
+ # ]
332
+
333
+ # quick_prompts = [
334
+ # 'beautiful woman',
335
+ # 'handsome man',
336
+ # 'beautiful woman, cinematic lighting',
337
+ # 'handsome man, cinematic lighting',
338
+ # 'beautiful woman, natural lighting',
339
+ # 'handsome man, natural lighting',
340
+ # 'beautiful woman, neo punk lighting, cyberpunk',
341
+ # 'handsome man, neo punk lighting, cyberpunk',
342
+ # ]
343
+
344
+
345
+ quick_prompts = [
346
+ 'beautiful woman',
347
+ 'handsome man',
348
+ 'beautiful woman, cinematic lighting',
349
+ 'handsome man, cinematic lighting',
350
+ 'beautiful woman, natural lighting',
351
+ 'handsome man, natural lighting',
352
+ 'beautiful woman, warm lighting',
353
+ 'handsome man, soft lighting',
354
+ 'change the background lighting',
355
+ ]
356
+
357
+
358
+ quick_prompts = [[x] for x in quick_prompts]
359
+
360
+ # css = """
361
+ # #foreground-gallery {
362
+ # width: 700 !important; /* 限制最大宽度 */
363
+ # max-width: 700px !important; /* 避免它自动变宽 */
364
+ # flex: none !important; /* 让它不自动扩展 */
365
+ # }
366
+ # """
367
+
368
+ css = """
369
+ #prompt-box, #bg-source, #quick-list, #relight-btn {
370
+ width: 750px !important;
371
+
372
+ }
373
+ """
374
+
375
+ # Gradio UI 结构
376
+ block = gr.Blocks(css=css).queue()
377
+ with block:
378
+ with gr.Row():
379
+ # gr.Markdown("## RelightVid (Relighting with Foreground and Background Video Condition)")
380
+ gr.Markdown("# 💡RelightVid \n### Relighting with Foreground and Background Video Condition")
381
+
382
+ with gr.Row():
383
+ with gr.Column():
384
+ with gr.Row():
385
+ input_fg = gr.Video(label="Foreground Video", height=380, width=420, visible=True)
386
+ input_bg = gr.Video(label="Background Video", height=380, width=420, visible=True)
387
+
388
+ segment_button = gr.Button(value="Video Segmentation")
389
+ with gr.Accordion("Segmentation Options", open=False):
390
+ # 如果用户不使用 point_prompt,而是直接提供坐标,则使用 x, y
391
+ with gr.Row():
392
+ x_coord = gr.Slider(label="X Coordinate (Point Prompt Ratio)", minimum=0.0, maximum=1.0, value=0.5, step=0.01)
393
+ y_coord = gr.Slider(label="Y Coordinate (Point Prompt Ratio)", minimum=0.0, maximum=1.0, value=0.5, step=0.01)
394
+
395
+
396
+ fg_gallery = gr.Gallery(height=150, object_fit='contain', label='Foreground Quick List', value=db_examples.fg_samples, columns=5, allow_preview=False)
397
+ bg_gallery = gr.Gallery(height=450, object_fit='contain', label='Background Quick List', value=db_examples.bg_samples, columns=5, allow_preview=False)
398
+
399
+
400
+ with gr.Group():
401
+ # with gr.Row():
402
+ # num_samples = gr.Slider(label="Videos", minimum=1, maximum=12, value=1, step=1)
403
+ # seed = gr.Number(label="Seed", value=12345, precision=0)
404
+ with gr.Row():
405
+ video_width = gr.Slider(label="Video Width", minimum=256, maximum=1024, value=512, step=64, visible=False)
406
+ video_height = gr.Slider(label="Video Height", minimum=256, maximum=1024, value=512, step=64, visible=False)
407
+
408
+ # with gr.Accordion("Advanced options", open=False):
409
+ # steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
410
+ # cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=7.0, step=0.01)
411
+ # highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=3.0, value=1.5, step=0.01)
412
+ # highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=0.9, value=0.5, step=0.01)
413
+ # a_prompt = gr.Textbox(label="Added Prompt", value='best quality')
414
+ # n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality')
415
+ # normal_button = gr.Button(value="Compute Normal (4x Slower)")
416
+
417
+ with gr.Column():
418
+ result_video = gr.Video(label='Output Video', height=750, width=750, visible=True)
419
+
420
+ prompt = gr.Textbox(label="Prompt", elem_id="prompt-box")
421
+ bg_source = gr.Radio(choices=[e.value for e in BGSource],
422
+ value=BGSource.UPLOAD.value,
423
+ label="Background Source",
424
+ type='value',
425
+ elem_id="bg-source")
426
+
427
+ example_prompts = gr.Dataset(samples=quick_prompts, label='Prompt Quick List', components=[prompt], elem_id="quick-list")
428
+ relight_button = gr.Button(value="Relight", elem_id="relight-btn")
429
+
430
+
431
+ # prompt = gr.Textbox(label="Prompt")
432
+ # bg_source = gr.Radio(choices=[e.value for e in BGSource],
433
+ # value=BGSource.UPLOAD.value,
434
+ # label="Background Source", type='value')
435
+
436
+ # example_prompts = gr.Dataset(samples=quick_prompts, label='Prompt Quick List', components=[prompt])
437
+ # relight_button = gr.Button(value="Relight")
438
+
439
+
440
+ # fg_gallery = gr.Gallery(witdth=400, object_fit='contain', label='Foreground Quick List', value=db_examples.bg_samples, columns=4, allow_preview=False)
441
+ # fg_gallery = gr.Gallery(
442
+ # height=380,
443
+ # object_fit='contain',
444
+ # label='Foreground Quick List',
445
+ # value=db_examples.fg_samples,
446
+ # columns=4,
447
+ # allow_preview=False,
448
+ # elem_id="foreground-gallery" # 👈 添加 elem_id
449
+ # )
450
+
451
+
452
+ # 输入列表
453
+ # ips = [input_fg, input_bg, prompt, video_width, video_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source]
454
+ ips = [input_fg, input_bg, prompt]
455
+
456
+ # 按钮绑定处理函数
457
+ # relight_button.click(fn=lambda: None, inputs=[], outputs=[result_video])
458
+
459
+ relight_button.click(fn=dummy_process, inputs=ips, outputs=[result_video])
460
+
461
+ # normal_button.click(fn=dummy_process, inputs=ips, outputs=[result_video])
462
+
463
+ # 背景库选择
464
+ def bg_gallery_selected(gal, evt: gr.SelectData):
465
+ # import pdb; pdb.set_trace()
466
+ # img_path = gal[evt.index][0]
467
+ img_path = db_examples.bg_samples[evt.index]
468
+ video_path = img_path.replace('frames/0000.png', 'cropped_video.mp4')
469
+ return video_path
470
+
471
+ bg_gallery.select(bg_gallery_selected, inputs=bg_gallery, outputs=input_bg)
472
+
473
+ def fg_gallery_selected(gal, evt: gr.SelectData):
474
+ # import pdb; pdb.set_trace()
475
+ # img_path = gal[evt.index][0]
476
+ img_path = db_examples.fg_samples[evt.index]
477
+ video_path = img_path.replace('frames/0000.png', 'cropped_video.mp4')
478
+ return video_path
479
+
480
+ fg_gallery.select(fg_gallery_selected, inputs=fg_gallery, outputs=input_fg)
481
+
482
+ input_fg_img = gr.Image(label="Foreground Video", visible=False)
483
+ input_bg_img = gr.Image(label="Background Video", visible=False)
484
+ result_video_img = gr.Image(label="Output Video", visible=False)
485
+
486
+ v_index = gr.Textbox(label="ID", visible=False)
487
+ example_prompts.click(lambda x: x[0], inputs=example_prompts, outputs=prompt, show_progress=False, queue=False)
488
+
489
+ # 示例
490
+ # dummy_video_for_outputs = gr.Video(visible=False, label='Result')
491
+ gr.Examples(
492
+ # fn=lambda *args: args[-1],
493
+ fn=process_example,
494
+ examples=db_examples.background_conditioned_examples,
495
+ # inputs=[v_index, input_fg_img, input_bg_img, prompt, bg_source, video_width, video_height, result_video_img],
496
+ inputs=[v_index, input_fg_img, input_bg_img, prompt, bg_source, result_video_img],
497
+ outputs=[input_fg, input_bg, result_video],
498
+ run_on_click=True, examples_per_page=1024
499
+ )
500
+
501
+ # 启动 Gradio 应用
502
+ # block.launch(server_name='0.0.0.0', server_port=10002, share=True)
503
+ block.launch(share=True)
app_bf1.py ADDED
@@ -0,0 +1,498 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import numpy as np
4
+ from enum import Enum
5
+ import db_examples
6
+ import cv2
7
+
8
+ from demo_utils1 import *
9
+
10
+ from misc_utils.train_utils import unit_test_create_model
11
+ from misc_utils.image_utils import save_tensor_to_gif, save_tensor_to_images
12
+ import os
13
+ from PIL import Image
14
+ import torch
15
+ import torchvision
16
+ from torchvision import transforms
17
+ from einops import rearrange
18
+ import imageio
19
+ import time
20
+
21
+ from torchvision.transforms import functional as F
22
+ from torch.hub import download_url_to_file
23
+
24
+ import os
25
+ import spaces
26
+
27
+ # 推理设置
28
+ from pl_trainer.inference.inference import InferenceIP2PVideo
29
+ from tqdm import tqdm
30
+
31
+
32
+ # if not os.path.exists(filename):
33
+ # original_path = os.getcwd()
34
+ # base_path = './models'
35
+ # os.makedirs(base_path, exist_ok=True)
36
+
37
+ # # 直接在代码中写入 Token(注意安全风险)
38
+ # GIT_TOKEN = "955b8ea91095840b76fe38b90a088c200d4c813c"
39
+ # repo_url = f"https://YeFang:{GIT_TOKEN}@code.openxlab.org.cn/YeFang/RIV_models.git"
40
+
41
+ # try:
42
+ # if os.system(f'git clone {repo_url} {base_path}') != 0:
43
+ # raise RuntimeError("Git 克隆失败")
44
+ # os.chdir(base_path)
45
+ # if os.system('git lfs pull') != 0:
46
+ # raise RuntimeError("Git LFS 拉取失败")
47
+ # finally:
48
+ # os.chdir(original_path)
49
+
50
+ def tensor_to_pil_image(x):
51
+ """
52
+ 将 4D PyTorch 张量转换为 PIL 图像。
53
+ """
54
+ x = x.float() # 确保张量类型为 float
55
+ grid_img = torchvision.utils.make_grid(x, nrow=4).permute(1, 2, 0).detach().cpu().numpy()
56
+ grid_img = (grid_img * 255).clip(0, 255).astype("uint8") # 将 [0, 1] 范围转换为 [0, 255]
57
+ return Image.fromarray(grid_img)
58
+
59
+ def frame_to_batch(x):
60
+ """
61
+ 将帧维度转换为批次维度。
62
+ """
63
+ return rearrange(x, 'b f c h w -> (b f) c h w')
64
+
65
+ def clip_image(x, min=0., max=1.):
66
+ """
67
+ 将图像张量裁剪到指定的最小和最大值。
68
+ """
69
+ return torch.clamp(x, min=min, max=max)
70
+
71
+ def unnormalize(x):
72
+ """
73
+ 将张量范围从 [-1, 1] 转换到 [0, 1]。
74
+ """
75
+ return (x + 1) / 2
76
+
77
+
78
+ # 读取图像文件
79
+ def read_images_from_directory(directory, num_frames=16):
80
+ images = []
81
+ for i in range(num_frames):
82
+ img_path = os.path.join(directory, f'{i:04d}.png')
83
+ img = imageio.imread(img_path)
84
+ images.append(torch.tensor(img).permute(2, 0, 1)) # Convert to Tensor (C, H, W)
85
+ return images
86
+
87
+ def load_and_process_images(folder_path):
88
+ """
89
+ 读取文件夹中的所有图片,将它们转换为 [-1, 1] 范围的张量并返回一个 4D 张量。
90
+ """
91
+ processed_images = []
92
+ transform = transforms.Compose([
93
+ transforms.ToTensor(),
94
+ transforms.Lambda(lambda x: x * 2 - 1) # 将 [0, 1] 转换为 [-1, 1]
95
+ ])
96
+ for filename in sorted(os.listdir(folder_path)):
97
+ if filename.endswith(".png"):
98
+ img_path = os.path.join(folder_path, filename)
99
+ image = Image.open(img_path).convert("RGB")
100
+ processed_image = transform(image)
101
+ processed_images.append(processed_image)
102
+ return torch.stack(processed_images) # 返回 4D 张量
103
+
104
+ def load_and_process_video(video_path, num_frames=16, crop_size=512):
105
+ """
106
+ 读取视频文件中的前 num_frames 帧,将每一帧转换为 [-1, 1] 范围的张量,
107
+ 并进行中心裁剪至 crop_size x crop_size,返回一个 4D 张量。
108
+ """
109
+ processed_frames = []
110
+ transform = transforms.Compose([
111
+ transforms.CenterCrop(crop_size), # 中心裁剪
112
+ transforms.ToTensor(),
113
+ transforms.Lambda(lambda x: x * 2 - 1) # 将 [0, 1] 转换为 [-1, 1]
114
+ ])
115
+
116
+ # 使用 OpenCV 读取视频
117
+ cap = cv2.VideoCapture(video_path)
118
+
119
+ if not cap.isOpened():
120
+ raise ValueError(f"无法打开视频文件: {video_path}")
121
+
122
+ frame_count = 0
123
+
124
+ while frame_count < num_frames:
125
+ ret, frame = cap.read()
126
+ if not ret:
127
+ break # 视频帧读取完毕或视频帧不足
128
+
129
+ # 转换为 RGB 格式
130
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
131
+ image = Image.fromarray(frame)
132
+
133
+ # 应用转换
134
+ processed_frame = transform(image)
135
+ processed_frames.append(processed_frame)
136
+
137
+ frame_count += 1
138
+
139
+ cap.release() # 释放视频资源
140
+
141
+ if len(processed_frames) < num_frames:
142
+ raise ValueError(f"视频帧不足 {num_frames} 帧,仅找到 {len(processed_frames)} 帧。")
143
+
144
+ return torch.stack(processed_frames) # 返回 4D 张量 (帧数, 通道数, 高度, 宽度)
145
+
146
+
147
+ def clear_cache(output_path):
148
+ if os.path.exists(output_path):
149
+ os.remove(output_path)
150
+ return None
151
+
152
+
153
+ #! 加载模型
154
+ # 配置路径和加载模型
155
+ config_path = 'configs/instruct_v2v_ic_gradio.yaml'
156
+ diffusion_model = unit_test_create_model(config_path)
157
+ diffusion_model = diffusion_model.to('cuda')
158
+
159
+ # 加载模型检查点
160
+ # ckpt_path = 'models/relvid_mm_sd15_fbc_unet.pth' #! change
161
+ # ckpt_path = 'tmp/pytorch_model.bin'
162
+ # 下载文件
163
+
164
+ os.makedirs('models', exist_ok=True)
165
+ model_path = "models/relvid_mm_sd15_fbc_unet.pth"
166
+
167
+ if not os.path.exists(model_path):
168
+ download_url_to_file(url='https://huggingface.co/aleafy/RelightVid/resolve/main/relvid_mm_sd15_fbc_unet.pth', dst=model_path)
169
+
170
+
171
+ ckpt = torch.load(model_path, map_location='cpu')
172
+ diffusion_model.load_state_dict(ckpt, strict=False)
173
+
174
+
175
+ # import pdb; pdb.set_trace()
176
+
177
+ # 更改全局临时目录
178
+ new_tmp_dir = "./demo/gradio_bg"
179
+ os.makedirs(new_tmp_dir, exist_ok=True)
180
+
181
+ # import pdb; pdb.set_trace()
182
+
183
+ def save_video_from_frames(image_pred, save_pth, fps=8):
184
+ """
185
+ 将 image_pred 中的帧保存为视频文件。
186
+ 参数:
187
+ - image_pred: Tensor,形状为 (1, 16, 3, 512, 512)
188
+ - save_pth: 保存视频的路径,例如 "output_video.mp4"
189
+ - fps: 视频的帧率
190
+ """
191
+ # 视频参数
192
+ num_frames = image_pred.shape[1]
193
+ frame_height, frame_width = 512, 512 # 目标尺寸
194
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v') # 使用 mp4 编码格式
195
+
196
+ # 创建 VideoWriter 对象
197
+ out = cv2.VideoWriter(save_pth, fourcc, fps, (frame_width, frame_height))
198
+
199
+ for i in range(num_frames):
200
+ # 反归一化 + 转换为 0-255 范围
201
+ pred_frame = clip_image(unnormalize(image_pred[0][i].unsqueeze(0))) * 255
202
+ pred_frame_resized = pred_frame.squeeze(0).detach().cpu() # (3, 512, 512)
203
+ pred_frame_resized = pred_frame_resized.permute(1, 2, 0).numpy().astype("uint8") # (512, 512, 3)
204
+
205
+ # Resize 到 512x512(即 frame_width x frame_height)
206
+ pred_frame_resized = cv2.resize(pred_frame_resized, (frame_width, frame_height))
207
+
208
+ # 将 RGB 转为 BGR(因为 OpenCV 使用 BGR 格式)
209
+ pred_frame_bgr = cv2.cvtColor(pred_frame_resized, cv2.COLOR_RGB2BGR)
210
+
211
+ # 写入帧到视频
212
+ out.write(pred_frame_bgr)
213
+
214
+ # 释放 VideoWriter 资源
215
+ out.release()
216
+ print(f"视频已保存至 {save_pth}")
217
+
218
+
219
+ inf_pipe = InferenceIP2PVideo(
220
+ diffusion_model.unet,
221
+ scheduler='ddpm',
222
+ num_ddim_steps=20
223
+ )
224
+
225
+
226
+ def process_example(*args):
227
+ v_index = args[0]
228
+ select_e = db_examples.background_conditioned_examples[int(v_index)-1]
229
+ input_fg_path = select_e[1]
230
+ input_bg_path = select_e[2]
231
+ result_video_path = select_e[-1]
232
+ # input_fg_img = args[1] # 第 0 个参数
233
+ # input_bg_img = args[2] # 第 1 个参数
234
+ # result_video_img = args[-1] # 最后一个参数
235
+
236
+ input_fg = input_fg_path.replace("frames/0000.png", "cropped_video.mp4")
237
+ input_bg = input_bg_path.replace("frames/0000.png", "cropped_video.mp4")
238
+ result_video = result_video_path.replace(".png", ".mp4")
239
+
240
+ return input_fg, input_bg, result_video
241
+
242
+
243
+
244
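+ # Inference entry point (the name `dummy_process` is historical, it is not a placeholder): encodes the foreground/background videos to latents, runs the diffusion pipeline, saves the decoded frames as an mp4 and returns its path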
245
+ @spaces.GPU
246
+ def dummy_process(input_fg, input_bg, prompt):
247
+ # import pdb; pdb.set_trace()
248
+
249
+ diffusion_model.to(torch.float16)
250
+ fg_tensor = load_and_process_video(input_fg).cuda().unsqueeze(0).to(dtype=torch.float16)
251
+ bg_tensor = load_and_process_video(input_bg).cuda().unsqueeze(0).to(dtype=torch.float16) # (1, 16, 3, 512, 512)
252
+
253
+ cond_fg_tensor = diffusion_model.encode_image_to_latent(fg_tensor) # (1, 16, 4, 64, 64)
254
+ cond_bg_tensor = diffusion_model.encode_image_to_latent(bg_tensor)
255
+ cond_tensor = torch.cat((cond_fg_tensor, cond_bg_tensor), dim=2)
256
+
257
+ # 初始化潜变量
258
+ init_latent = torch.randn_like(cond_fg_tensor)
259
+
260
+ # EDIT_PROMPT = 'change the background'
261
+ EDIT_PROMPT = prompt
262
+ VIDEO_CFG = 1.2
263
+ TEXT_CFG = 7.5
264
+ text_cond = diffusion_model.encode_text([EDIT_PROMPT]) # (1, 77, 768)
265
+ text_uncond = diffusion_model.encode_text([''])
266
+ # to float16
267
+ print('------------to float 16----------------')
268
+ init_latent, text_cond, text_uncond, cond_tensor = (
269
+ init_latent.to(dtype=torch.float16),
270
+ text_cond.to(dtype=torch.float16),
271
+ text_uncond.to(dtype=torch.float16),
272
+ cond_tensor.to(dtype=torch.float16)
273
+ )
274
+ inf_pipe.unet.to(torch.float16)
275
+ latent_pred = inf_pipe(
276
+ latent=init_latent,
277
+ text_cond=text_cond,
278
+ text_uncond=text_uncond,
279
+ img_cond=cond_tensor,
280
+ text_cfg=TEXT_CFG,
281
+ img_cfg=VIDEO_CFG,
282
+ )['latent']
283
+
284
+
285
+ image_pred = diffusion_model.decode_latent_to_image(latent_pred) # (1,16,3,512,512)
286
+ output_path = os.path.join(new_tmp_dir, f"output_{int(time.time())}.mp4")
287
+ # clear_cache(output_path)
288
+
289
+ save_video_from_frames(image_pred, output_path)
290
+ # import pdb; pdb.set_trace()
291
+ # fps = 8
292
+ # frames = []
293
+ # for i in range(16):
294
+ # pred_frame = clip_image(unnormalize(image_pred[0][i].unsqueeze(0))) * 255
295
+ # pred_frame_resized = pred_frame.squeeze(0).detach().cpu() #(3,512,512)
296
+ # pred_frame_resized = pred_frame_resized.permute(1, 2, 0).detach().cpu().numpy().astype("uint8") #(512,512,3) np
297
+ # Image.fromarray(pred_frame_resized).save(save_pth)
298
+
299
+ # # 生成一个简单的黑色视频作为示例
300
+ # output_path = os.path.join(new_tmp_dir, "output.mp4")
301
+ # fourcc = cv2.VideoWriter_fourcc(*'mp4v')
302
+ # out = cv2.VideoWriter(output_path, fourcc, 20.0, (512, 512))
303
+
304
+ # for _ in range(60): # 生成 3 秒的视频(20fps)
305
+ # frame = np.zeros((512, 512, 3), dtype=np.uint8)
306
+ # out.write(frame)
307
+ # out.release()
308
+ torch.cuda.empty_cache()
309
+
310
+ return output_path
311
+
312
+ # 枚举类用于背景选择
313
+ class BGSource(Enum):
314
+ UPLOAD = "Use Background Video"
315
+ UPLOAD_FLIP = "Use Flipped Background Video"
316
+ UPLOAD_REVERSE = "Use Reversed Background Video"
317
+
318
+
319
+ # Quick prompts 示例
320
+ # quick_prompts = [
321
+ # 'beautiful woman, fantasy setting',
322
+ # 'beautiful woman, neon dynamic lighting',
323
+ # 'man in suit, tunel lighting',
324
+ # 'animated mouse, aesthetic lighting',
325
+ # 'robot warrior, a sunset background',
326
+ # 'yellow cat, reflective wet beach',
327
+ # 'camera, dock, calm sunset',
328
+ # 'astronaut, dim lighting',
329
+ # 'astronaut, colorful balloons',
330
+ # 'astronaut, desert landscape'
331
+ # ]
332
+
333
+ # quick_prompts = [
334
+ # 'beautiful woman',
335
+ # 'handsome man',
336
+ # 'beautiful woman, cinematic lighting',
337
+ # 'handsome man, cinematic lighting',
338
+ # 'beautiful woman, natural lighting',
339
+ # 'handsome man, natural lighting',
340
+ # 'beautiful woman, neo punk lighting, cyberpunk',
341
+ # 'handsome man, neo punk lighting, cyberpunk',
342
+ # ]
343
+
344
+
345
+ quick_prompts = [
346
+ 'beautiful woman',
347
+ 'handsome man',
348
+ # 'beautiful woman, cinematic lighting',
349
+ 'handsome man, cinematic lighting',
350
+ 'beautiful woman, natural lighting',
351
+ 'handsome man, natural lighting',
352
+ 'beautiful woman, warm lighting',
353
+ 'handsome man, soft lighting',
354
+ 'change the background lighting',
355
+ ]
356
+
357
+
358
+ quick_prompts = [[x] for x in quick_prompts]
359
+
360
+ # css = """
361
+ # #foreground-gallery {
362
+ # width: 700 !important; /* 限制最大宽度 */
363
+ # max-width: 700px !important; /* 避免它自动变宽 */
364
+ # flex: none !important; /* 让它不自动扩展 */
365
+ # }
366
+ # """
367
+
368
+ # css = """
369
+ # #prompt-box, #bg-source, #quick-list, #relight-btn {
370
+ # width: 750px !important;
371
+ # }
372
+ # """
373
+
374
+ # Gradio UI 结构
375
+ block = gr.Blocks().queue()
376
+ with block:
377
+ with gr.Row():
378
+ # gr.Markdown("## RelightVid (Relighting with Foreground and Background Video Condition)")
379
+ gr.Markdown("# 💡RelightVid \n### Relighting with Foreground and Background Video Condition")
380
+
381
+ with gr.Row():
382
+ with gr.Column():
383
+ with gr.Row():
384
+ input_fg = gr.Video(label="Foreground Video", height=380, width=420, visible=True)
385
+ input_bg = gr.Video(label="Background Video", height=380, width=420, visible=True)
386
+
387
+ segment_button = gr.Button(value="Video Segmentation")
388
+ with gr.Accordion("Segmentation Options", open=False):
389
+ # 如果用户不使用 point_prompt,而是直接提供坐标,则使用 x, y
390
+ with gr.Row():
391
+ x_coord = gr.Slider(label="X Coordinate (Point Prompt Ratio)", minimum=0.0, maximum=1.0, value=0.5, step=0.01)
392
+ y_coord = gr.Slider(label="Y Coordinate (Point Prompt Ratio)", minimum=0.0, maximum=1.0, value=0.5, step=0.01)
393
+
394
+
395
+ fg_gallery = gr.Gallery(height=150, object_fit='contain', label='Foreground Quick List', value=db_examples.fg_samples, columns=5, allow_preview=False)
396
+ bg_gallery = gr.Gallery(height=450, object_fit='contain', label='Background Quick List', value=db_examples.bg_samples, columns=5, allow_preview=False)
397
+
398
+
399
+ with gr.Group():
400
+ # with gr.Row():
401
+ # num_samples = gr.Slider(label="Videos", minimum=1, maximum=12, value=1, step=1)
402
+ # seed = gr.Number(label="Seed", value=12345, precision=0)
403
+ with gr.Row():
404
+ video_width = gr.Slider(label="Video Width", minimum=256, maximum=1024, value=512, step=64, visible=False)
405
+ video_height = gr.Slider(label="Video Height", minimum=256, maximum=1024, value=512, step=64, visible=False)
406
+
407
+ # with gr.Accordion("Advanced options", open=False):
408
+ # steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
409
+ # cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=7.0, step=0.01)
410
+ # highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=3.0, value=1.5, step=0.01)
411
+ # highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=0.9, value=0.5, step=0.01)
412
+ # a_prompt = gr.Textbox(label="Added Prompt", value='best quality')
413
+ # n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality')
414
+ # normal_button = gr.Button(value="Compute Normal (4x Slower)")
415
+
416
+ with gr.Column():
417
+ result_video = gr.Video(label='Output Video', height=750, visible=True)
418
+
419
+ prompt = gr.Textbox(label="Prompt")
420
+ bg_source = gr.Radio(choices=[e.value for e in BGSource],
421
+ value=BGSource.UPLOAD.value,
422
+ label="Background Source",
423
+ type='value')
424
+
425
+ example_prompts = gr.Dataset(samples=quick_prompts, label='Prompt Quick List', components=[prompt])
426
+ relight_button = gr.Button(value="Relight")
427
+
428
+ # prompt = gr.Textbox(label="Prompt")
429
+ # bg_source = gr.Radio(choices=[e.value for e in BGSource],
430
+ # value=BGSource.UPLOAD.value,
431
+ # label="Background Source", type='value')
432
+
433
+ # example_prompts = gr.Dataset(samples=quick_prompts, label='Prompt Quick List', components=[prompt])
434
+ # relight_button = gr.Button(value="Relight")
435
+ # fg_gallery = gr.Gallery(witdth=400, object_fit='contain', label='Foreground Quick List', value=db_examples.bg_samples, columns=4, allow_preview=False)
436
+ # fg_gallery = gr.Gallery(
437
+ # height=380,
438
+ # object_fit='contain',
439
+ # label='Foreground Quick List',
440
+ # value=db_examples.fg_samples,
441
+ # columns=4,
442
+ # allow_preview=False,
443
+ # elem_id="foreground-gallery" # 👈 添加 elem_id
444
+ # )
445
+
446
+
447
+ # 输入列表
448
+ # ips = [input_fg, input_bg, prompt, video_width, video_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source]
449
+ ips = [input_fg, input_bg, prompt]
450
+
451
+ # 按钮绑定处理函数
452
+ # relight_button.click(fn=lambda: None, inputs=[], outputs=[result_video])
453
+
454
+ relight_button.click(fn=dummy_process, inputs=ips, outputs=[result_video])
455
+
456
+ # normal_button.click(fn=dummy_process, inputs=ips, outputs=[result_video])
457
+
458
+ # 背景库选择
459
+ def bg_gallery_selected(gal, evt: gr.SelectData):
460
+ # import pdb; pdb.set_trace()
461
+ # img_path = gal[evt.index][0]
462
+ img_path = db_examples.bg_samples[evt.index]
463
+ video_path = img_path.replace('frames/0000.png', 'cropped_video.mp4')
464
+ return video_path
465
+
466
+ bg_gallery.select(bg_gallery_selected, inputs=bg_gallery, outputs=input_bg)
467
+
468
+ def fg_gallery_selected(gal, evt: gr.SelectData):
469
+ # import pdb; pdb.set_trace()
470
+ # img_path = gal[evt.index][0]
471
+ img_path = db_examples.fg_samples[evt.index]
472
+ video_path = img_path.replace('frames/0000.png', 'cropped_video.mp4')
473
+ return video_path
474
+
475
+ fg_gallery.select(fg_gallery_selected, inputs=fg_gallery, outputs=input_fg)
476
+
477
+ input_fg_img = gr.Image(label="Foreground Video", visible=False)
478
+ input_bg_img = gr.Image(label="Background Video", visible=False)
479
+ result_video_img = gr.Image(label="Output Video", visible=False)
480
+
481
+ v_index = gr.Textbox(label="ID", visible=False)
482
+ example_prompts.click(lambda x: x[0], inputs=example_prompts, outputs=prompt, show_progress=False, queue=False)
483
+
484
+ # 示例
485
+ # dummy_video_for_outputs = gr.Video(visible=False, label='Result')
486
+ gr.Examples(
487
+ # fn=lambda *args: args[-1],
488
+ fn=process_example,
489
+ examples=db_examples.background_conditioned_examples,
490
+ # inputs=[v_index, input_fg_img, input_bg_img, prompt, bg_source, video_width, video_height, result_video_img],
491
+ inputs=[v_index, input_fg_img, input_bg_img, prompt, bg_source, result_video_img],
492
+ outputs=[input_fg, input_bg, result_video],
493
+ run_on_click=True, examples_per_page=1024
494
+ )
495
+
496
+ # 启动 Gradio 应用
497
+ # block.launch(server_name='0.0.0.0', server_port=10002, share=True)
498
+ block.launch()
db_examples.py CHANGED
@@ -21,8 +21,25 @@ fg_samples = [
21
  'demo/clean_fg_extracted/14/frames/0000.png',
22
  'demo/clean_fg_extracted/15/frames/0000.png',
23
  'demo/clean_fg_extracted/18/frames/0000.png',
24
- 'demo/clean_fg_extracted/22/frames/0000.png',
25
- 'demo/clean_fg_extracted/9/frames/0000.png',
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  # 'demo/clean_bg_extracted/39/frames/0000.png',
27
  # 'demo/clean_bg_extracted/59/frames/0000.png',
28
  # 'demo/clean_bg_extracted/55/frames/0000.png',
@@ -41,40 +58,40 @@ background_conditioned_examples = [
41
  1,
42
  "demo/clean_fg_extracted/14/frames/0000.png",
43
  "demo/clean_bg_extracted/22/frames/0000.png",
44
- "beautiful woman, cinematic lighting",
45
  "Use Background Video",
46
- 512,
47
- 512,
48
  "static_fg_sync_bg_visualization_fy/14_22_100fps.png",
49
  ],
50
  [
51
  2,
52
  "demo/clean_fg_extracted/14/frames/0000.png",
53
  "demo/clean_bg_extracted/55/frames/0000.png",
54
- "beautiful woman, cinematic lighting",
55
  "Use Background Video",
56
- 512,
57
- 512,
58
  "static_fg_sync_bg_visualization_fy/14_55_100fps.png",
59
  ],
60
  [
61
  3,
62
  "demo/clean_fg_extracted/15/frames/0000.png",
63
  "demo/clean_bg_extracted/27/frames/0000.png",
64
- "beautiful woman, cinematic lighting",
65
  "Use Background Video",
66
- 512,
67
- 512,
68
  "static_fg_sync_bg_visualization_fy/15_27_100fps.png",
69
  ],
70
  [
71
  4,
72
  "demo/clean_fg_extracted/18/frames/0000.png",
73
  "demo/clean_bg_extracted/33/frames/0000.png", # 23->33
74
- "beautiful woman, cinematic lighting",
75
  "Use Background Video",
76
- 512,
77
- 512,
78
  "static_fg_sync_bg_visualization_fy/18_33_100fps.png",
79
  ],
80
  # [
@@ -91,10 +108,10 @@ background_conditioned_examples = [
91
  5,
92
  "demo/clean_fg_extracted/22/frames/0000.png",
93
  "demo/clean_bg_extracted/59/frames/0000.png", # 39 -> 59
94
- "beautiful woman, cinematic lighting",
95
  "Use Background Video",
96
- 512,
97
- 512,
98
  "static_fg_sync_bg_visualization_fy/22_59_100fps.png",
99
  ],
100
  # [
@@ -107,38 +124,73 @@ background_conditioned_examples = [
107
  #
108
  # "static_fg_sync_bg_visualization_fy/22_59_100fps.png",
109
  # ],
 
110
  [
111
  6,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  "demo/clean_fg_extracted/9/frames/0000.png",
113
  "demo/clean_bg_extracted/8/frames/0000.png",
114
- "beautiful woman, cinematic lighting",
115
  "Use Background Video",
116
- 512,
117
- 512,
118
 
119
  "static_fg_sync_bg_visualization_fy/9_8_100fps.png",
120
  ],
121
  [
122
- 7,
123
  "demo/clean_fg_extracted/9/frames/0000.png",
124
  "demo/clean_bg_extracted/9/frames/0000.png",
125
- "beautiful woman, cinematic lighting",
126
  "Use Background Video",
127
- 512,
128
- 512,
129
  "static_fg_sync_bg_visualization_fy/9_9_100fps.png",
130
  ],
131
  [
132
- 8,
133
  "demo/clean_fg_extracted/9/frames/0000.png",
134
  "demo/clean_bg_extracted/10/frames/0000.png",
135
- "beautiful woman, cinematic lighting",
136
  "Use Background Video",
137
- 512,
138
- 512,
139
 
140
  "static_fg_sync_bg_visualization_fy/9_10_100fps.png",
141
  ],
 
 
 
 
 
 
 
 
 
 
 
 
142
  # [
143
  # "demo/clean_fg_extracted/9/frames/0000.png",
144
  # "demo/clean_bg_extracted/14/frames/0000.png",
 
21
  'demo/clean_fg_extracted/14/frames/0000.png',
22
  'demo/clean_fg_extracted/15/frames/0000.png',
23
  'demo/clean_fg_extracted/18/frames/0000.png',
24
+ 'demo/clean_fg_extracted/8/frames/0000.png',
25
+ 'demo/clean_fg_extracted/1/frames/0000.png',
26
+ # 'demo/clean_fg_extracted/22/frames/0000.png',
27
+ # 'demo/clean_fg_extracted/1/frames/0000.png',
28
+ # 'demo/clean_fg_extracted/2/frames/0000.png',
29
+ # 'demo/clean_fg_extracted/3/frames/0000.png',
30
+ # 'demo/clean_fg_extracted/4/frames/0000.png',
31
+ # 'demo/clean_fg_extracted/5/frames/0000.png',
32
+ # 'demo/clean_fg_extracted/6/frames/0000.png',
33
+ # 'demo/clean_fg_extracted/7/frames/0000.png',
34
+ # 'demo/clean_fg_extracted/8/frames/0000.png',
35
+ # 'demo/clean_fg_extracted/9/frames/0000.png',
36
+ # 'demo/clean_fg_extracted/10/frames/0000.png',
37
+ # 'demo/clean_fg_extracted/11/frames/0000.png',
38
+ # 'demo/clean_fg_extracted/12/frames/0000.png',
39
+ # 'demo/clean_fg_extracted/13/frames/0000.png',
40
+ # 'demo/clean_fg_extracted/16/frames/0000.png',
41
+ # 'demo/clean_fg_extracted/17/frames/0000.png',
42
+ # 'demo/clean_fg_extracted/9/frames/0000.png',
43
  # 'demo/clean_bg_extracted/39/frames/0000.png',
44
  # 'demo/clean_bg_extracted/59/frames/0000.png',
45
  # 'demo/clean_bg_extracted/55/frames/0000.png',
 
58
  1,
59
  "demo/clean_fg_extracted/14/frames/0000.png",
60
  "demo/clean_bg_extracted/22/frames/0000.png",
61
+ "beautiful woman, natural lighting",
62
  "Use Background Video",
63
+ # 512,
64
+ # 512,
65
  "static_fg_sync_bg_visualization_fy/14_22_100fps.png",
66
  ],
67
  [
68
  2,
69
  "demo/clean_fg_extracted/14/frames/0000.png",
70
  "demo/clean_bg_extracted/55/frames/0000.png",
71
+ "beautiful woman, neon dynamic lighting",
72
  "Use Background Video",
73
+ # 512,
74
+ # 512,
75
  "static_fg_sync_bg_visualization_fy/14_55_100fps.png",
76
  ],
77
  [
78
  3,
79
  "demo/clean_fg_extracted/15/frames/0000.png",
80
  "demo/clean_bg_extracted/27/frames/0000.png",
81
+ "man in suit, tunel lighting",
82
  "Use Background Video",
83
+ # 512,
84
+ # 512,
85
  "static_fg_sync_bg_visualization_fy/15_27_100fps.png",
86
  ],
87
  [
88
  4,
89
  "demo/clean_fg_extracted/18/frames/0000.png",
90
  "demo/clean_bg_extracted/33/frames/0000.png", # 23->33
91
+ "animated mouse, aesthetic lighting",
92
  "Use Background Video",
93
+ # 512,
94
+ # 512,
95
  "static_fg_sync_bg_visualization_fy/18_33_100fps.png",
96
  ],
97
  # [
 
108
  5,
109
  "demo/clean_fg_extracted/22/frames/0000.png",
110
  "demo/clean_bg_extracted/59/frames/0000.png", # 39 -> 59
111
+ "robot warrior, a sunset background",
112
  "Use Background Video",
113
+ # 512,
114
+ # 512,
115
  "static_fg_sync_bg_visualization_fy/22_59_100fps.png",
116
  ],
117
  # [
 
124
  #
125
  # "static_fg_sync_bg_visualization_fy/22_59_100fps.png",
126
  # ],
127
+
128
  [
129
  6,
130
+ "demo/clean_fg_extracted/17/frames/0000.png",
131
+ "demo/clean_bg_extracted/0/frames/0000.png",
132
+ "yellow cat, reflective wet beach",
133
+ "Use Background Video",
134
+ # 512,
135
+ # 512,
136
+
137
+ "static_fg_sync_bg_visualization_fy/17_0_100fps.png",
138
+ ],
139
+ [
140
+ 7,
141
+ "demo/clean_fg_extracted/16/frames/0000.png",
142
+ "demo/clean_bg_extracted/1/frames/0000.png",
143
+ "camera, dock, calm sunset",
144
+ "Use Background Video",
145
+ # 512,
146
+ # 512,
147
+
148
+ "static_fg_sync_bg_visualization_fy/16_1_100fps.png",
149
+ ],
150
+ [
151
+ 8,
152
  "demo/clean_fg_extracted/9/frames/0000.png",
153
  "demo/clean_bg_extracted/8/frames/0000.png",
154
+ "astronaut, dim lighting",
155
  "Use Background Video",
156
+ # 512,
157
+ # 512,
158
 
159
  "static_fg_sync_bg_visualization_fy/9_8_100fps.png",
160
  ],
161
  [
162
+ 9,
163
  "demo/clean_fg_extracted/9/frames/0000.png",
164
  "demo/clean_bg_extracted/9/frames/0000.png",
165
+ "astronaut, colorful balloons",
166
  "Use Background Video",
167
+ # 512,
168
+ # 512,
169
  "static_fg_sync_bg_visualization_fy/9_9_100fps.png",
170
  ],
171
  [
172
+ 10,
173
  "demo/clean_fg_extracted/9/frames/0000.png",
174
  "demo/clean_bg_extracted/10/frames/0000.png",
175
+ "astronaut, desert landscape",
176
  "Use Background Video",
177
+ # 512,
178
+ # 512,
179
 
180
  "static_fg_sync_bg_visualization_fy/9_10_100fps.png",
181
  ],
182
+
183
+ # [
184
+ # 11,
185
+ # "demo/clean_fg_extracted/7/frames/0000.png",
186
+ # "demo/clean_bg_extracted/2/frames/0000.png",
187
+ # "beautiful woman, cinematic lighting",
188
+ # "Use Background Video",
189
+ # 512,
190
+ # 512,
191
+
192
+ # "static_fg_sync_bg_visualization_fy/16_1_100fps.png",
193
+ # ],
194
  # [
195
  # "demo/clean_fg_extracted/9/frames/0000.png",
196
  # "demo/clean_bg_extracted/14/frames/0000.png",
demo/clean_bg_extracted/0/cropped_video.mp4 ADDED
Binary file (116 kB). View file
 
demo/clean_bg_extracted/0/frames/0000.png ADDED
demo/clean_bg_extracted/1/cropped_video.mp4 ADDED
Binary file (215 kB). View file
 
demo/clean_bg_extracted/1/frames/0000.png ADDED
demo/clean_bg_extracted/2/cropped_video.mp4 ADDED
Binary file (293 kB). View file
 
demo/clean_bg_extracted/2/frames/0000.png ADDED
demo/clean_fg_extracted/1/cropped_video.mp4 ADDED
Binary file (78.7 kB). View file
 
demo/clean_fg_extracted/1/frames/0000.png ADDED
demo/clean_fg_extracted/10/cropped_video.mp4 ADDED
Binary file (29.4 kB). View file
 
demo/clean_fg_extracted/10/frames/0000.png ADDED
demo/clean_fg_extracted/11/cropped_video.mp4 ADDED
Binary file (30.4 kB). View file
 
demo/clean_fg_extracted/11/frames/0000.png ADDED
demo/clean_fg_extracted/12/cropped_video.mp4 ADDED
Binary file (18.2 kB). View file
 
demo/clean_fg_extracted/12/frames/0000.png ADDED
demo/clean_fg_extracted/13/cropped_video.mp4 ADDED
Binary file (109 kB). View file
 
demo/clean_fg_extracted/13/frames/0000.png ADDED
demo/clean_fg_extracted/16/cropped_video.mp4 ADDED
Binary file (35.8 kB). View file
 
demo/clean_fg_extracted/16/frames/0000.png ADDED
demo/clean_fg_extracted/17/cropped_video.mp4 ADDED
Binary file (59.2 kB). View file
 
demo/clean_fg_extracted/17/frames/0000.png ADDED
demo/clean_fg_extracted/2/cropped_video.mp4 ADDED
Binary file (66.1 kB). View file
 
demo/clean_fg_extracted/2/frames/0000.png ADDED
demo/clean_fg_extracted/3/cropped_video.mp4 ADDED
Binary file (105 kB). View file
 
demo/clean_fg_extracted/3/frames/0000.png ADDED
demo/clean_fg_extracted/4/cropped_video.mp4 ADDED
Binary file (26.1 kB). View file
 
demo/clean_fg_extracted/4/frames/0000.png ADDED
demo/clean_fg_extracted/5/cropped_video.mp4 ADDED
Binary file (133 kB). View file
 
demo/clean_fg_extracted/5/frames/0000.png ADDED
demo/clean_fg_extracted/6/3.mp4 ADDED
Binary file (104 kB). View file
 
demo/clean_fg_extracted/6/frames/0000.png ADDED
demo/clean_fg_extracted/7/cropped_video.mp4 ADDED
Binary file (74.9 kB). View file
 
demo/clean_fg_extracted/7/frames/0000.png ADDED
demo/clean_fg_extracted/8/cropped_video.mp4 ADDED
Binary file (139 kB). View file
 
demo/clean_fg_extracted/8/frames/0000.png ADDED
misc_utils/__pycache__/flow_utils.cpython-310.pyc CHANGED
Binary files a/misc_utils/__pycache__/flow_utils.cpython-310.pyc and b/misc_utils/__pycache__/flow_utils.cpython-310.pyc differ
 
misc_utils/__pycache__/image_utils.cpython-310.pyc CHANGED
Binary files a/misc_utils/__pycache__/image_utils.cpython-310.pyc and b/misc_utils/__pycache__/image_utils.cpython-310.pyc differ
 
misc_utils/__pycache__/model_utils.cpython-310.pyc CHANGED
Binary files a/misc_utils/__pycache__/model_utils.cpython-310.pyc and b/misc_utils/__pycache__/model_utils.cpython-310.pyc differ
 
misc_utils/__pycache__/train_utils.cpython-310.pyc DELETED
Binary file (4.93 kB)
 
modules/openclip/__pycache__/modules.cpython-310.pyc DELETED
Binary file (8.62 kB)
 
modules/video_unet_temporal/__pycache__/attention.cpython-310.pyc CHANGED
Binary files a/modules/video_unet_temporal/__pycache__/attention.cpython-310.pyc and b/modules/video_unet_temporal/__pycache__/attention.cpython-310.pyc differ
 
modules/video_unet_temporal/__pycache__/motion_module.cpython-310.pyc CHANGED
Binary files a/modules/video_unet_temporal/__pycache__/motion_module.cpython-310.pyc and b/modules/video_unet_temporal/__pycache__/motion_module.cpython-310.pyc differ
 
modules/video_unet_temporal/__pycache__/resnet.cpython-310.pyc CHANGED
Binary files a/modules/video_unet_temporal/__pycache__/resnet.cpython-310.pyc and b/modules/video_unet_temporal/__pycache__/resnet.cpython-310.pyc differ
 
modules/video_unet_temporal/__pycache__/unet.cpython-310.pyc CHANGED
Binary files a/modules/video_unet_temporal/__pycache__/unet.cpython-310.pyc and b/modules/video_unet_temporal/__pycache__/unet.cpython-310.pyc differ