xizaoqu committed
Commit ae8fd03 · 1 Parent(s): faeb2a7
algorithms/common/base_algo.py CHANGED
@@ -12,7 +12,6 @@ class BaseAlgo(ABC):
     def __init__(self, cfg: DictConfig):
         super().__init__()
         self.cfg = cfg
-        self.debug = self.cfg.debug
 
     @abstractmethod
     def run(*args: Any, **kwargs: Any) -> Any:
algorithms/common/base_pytorch_algo.py CHANGED
@@ -21,7 +21,6 @@ class BasePytorchAlgo(pl.LightningModule, ABC):
     def __init__(self, cfg: DictConfig):
         super().__init__()
         self.cfg = cfg
-        self.debug = self.cfg.debug
         self._build_model()
 
     @abstractmethod
algorithms/worldmem/df_video.py CHANGED
@@ -379,7 +379,7 @@ class WorldMemMinecraft(DiffusionForcingBase):
             ref_mode=self.ref_mode
         )
 
-        self.register_data_mean_std(self.cfg.data_mean, self.cfg.data_std)
+        # self.register_data_mean_std(self.cfg.data_mean, self.cfg.data_std)
         self.validation_lpips_model = LearnedPerceptualImagePatchSimilarity()
 
         vae = VAE_models["vit-l-20-shallow-encoder"]()
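Disabling the register_data_mean_std call means the flat demo config does not have to provide data_mean / data_std. As a hedged sketch only (the repository's actual helper may differ), such a method typically registers the statistics as module buffers used for input normalization:

    import torch
    import torch.nn as nn

    class NormalizingModule(nn.Module):
        def register_data_mean_std(self, mean, std):
            # Assumed behavior: keep normalization statistics as non-trainable
            # buffers so they follow the module across devices and checkpoints.
            self.register_buffer("data_mean", torch.as_tensor(mean, dtype=torch.float32))
            self.register_buffer("data_std", torch.as_tensor(std, dtype=torch.float32))

        def normalize(self, x):
            return (x - self.data_mean) / self.data_std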
app.py CHANGED
@@ -23,6 +23,8 @@ import subprocess
 from PIL import Image
 from datetime import datetime
 import spaces
+from algorithms.worldmem import WorldMemMinecraft
+from huggingface_hub import hf_hub_download
 
 ACTION_KEYS = [
     "inventory",
@@ -65,6 +67,16 @@ KEY_TO_ACTION = {
     "1": ("hotbar.1", 1),
 }
 
+def load_custom_checkpoint(algo, checkpoint_path):
+    hf_ckpt = str(checkpoint_path).split('/')
+    repo_id = '/'.join(hf_ckpt[:2])
+    file_name = '/'.join(hf_ckpt[2:])
+    model_path = hf_hub_download(repo_id=repo_id,
+                                 filename=file_name)
+    ckpt = torch.load(model_path, map_location=torch.device('cpu'))
+    algo.load_state_dict(ckpt['state_dict'], strict=False)
+
+
 def parse_input_to_tensor(input_str):
     """
     Convert an input string into a (sequence_length, 25) tensor, where each row is a one-hot representation
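The added load_custom_checkpoint helper above interprets a configured checkpoint path as `<user>/<repo>/<file inside repo>` on the Hugging Face Hub. A small illustration of that split, using one of the paths from configurations/huggingface.yaml; the download and state-dict loading steps are only indicated in comments:

    checkpoint_path = "yslan/worldmem_checkpoints/diffusion_only.ckpt"

    parts = str(checkpoint_path).split('/')
    repo_id = '/'.join(parts[:2])    # "yslan/worldmem_checkpoints"
    file_name = '/'.join(parts[2:])  # "diffusion_only.ckpt"

    # hf_hub_download(repo_id=repo_id, filename=file_name) would then fetch the
    # file, and torch.load(..., map_location="cpu") + load_state_dict(strict=False)
    # would load it into the target submodule.
    print(repo_id, file_name)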
@@ -157,265 +169,223 @@ def save_video(frames, path="output.mp4", fps=10):
     subprocess.run(ffmpeg_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
     return path
 
-class InteractiveRunner:
-    def __init__(self, algo):
-        self.algo = algo
 
-    @spaces.GPU()
-    @torch.autocast("cuda")
-    def run(self, first_frame, action, first_pose, curr_frame, device):
-        return self.algo.interactive(first_frame, action, first_pose, curr_frame, device=device)
 
 
-@hydra.main(
-    version_base=None,
-    config_path="configurations",
-    config_name="huggingface",
-)
-def run(cfg: DictConfig):
 
-    algo = run_local(cfg)
-    algo.to(device)
 
-    algodevice = next(algo.parameters()).device
-    print("algo:", algodevice)
 
-    actions = torch.zeros((1, 25))
-    poses = torch.zeros((1, 5))
-
     memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))
 
-    runner = InteractiveRunner(algo)
-
-    algodevice = next(runner.algo.parameters()).device
-    print("runner.algo:", algodevice)
-
-    # @spaces.GPU()
-    # def run_interactive(first_frame, action, first_pose, curr_frame, device):
-    #     global algo
-    #     new_frame = algo.interactive(first_frame,
-    #                                  action,
-    #                                  first_pose,
-    #                                  curr_frame,
-    #                                  device=device)
-    #     return new_frame
-
-    def set_denoising_steps(denoising_steps, sampling_timesteps_state):
-        runner.algo.sampling_timesteps = denoising_steps
-        runner.algo.diffusion_model.sampling_timesteps = denoising_steps
-        sampling_timesteps_state = denoising_steps
-        print("set denoising steps to", runner.algo.sampling_timesteps)
-        return sampling_timesteps_state
-
-    def update_image_and_log(keys):
-        actions = parse_input_to_tensor(keys)
-        global input_history
-        global memory_curr_frame
-
-        print("algo frame:", len(runner.algo.frames))
-
-        for i in range(len(actions)):
-            memory_curr_frame += 1
-
-            # new_frame = run_interactive(memory_frames[0],
-            #                             actions[i],
-            #                             None,
-            #                             memory_curr_frame,
-            #                             device=device)
-
-            new_frame = runner.run(
-                memory_frames[0],
-                actions[i],
-                None,
-                memory_curr_frame,
-                device
-            )
-
-            print("algo frame:", len(runner.algo.frames))
-
-            memory_frames.append(new_frame)
-
-        out_video = torch.stack(memory_frames)
-        out_video = out_video.permute(0,2,3,1).numpy()
-        out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
-        out_video = (out_video * 255).astype(np.uint8)
-
-        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-        os.makedirs("outputs_gradio", exist_ok=True)
-        filename = f"outputs_gradio/{timestamp}.mp4"
-        save_video(out_video, filename)
-
-        input_history += keys
-        return out_video[-1], filename, input_history
-
-    def reset():
-        global memory_curr_frame
-        global input_history
-        global memory_frames
-
-        # runner.algo.to(device)
-        algodevice = next(runner.algo.parameters()).device
-        print(algodevice)
-        runner.algo.reset()
-        memory_frames = []
-        memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))
-        memory_curr_frame = 0
-        input_history = ""
-
-        # _ = run_interactive(memory_frames[0],
-        #                     actions[0],
-        #                     poses[0],
-        #                     memory_curr_frame,
-        #                     device=device)
-        #
-
-        new_frame = runner.run(
-            memory_frames[0],
-            actions[0],
-            poses[0],
-            memory_curr_frame,
-            device
-        )
-
-        return input_history, DEFAULT_IMAGE
-
-    def on_image_click(SELECTED_IMAGE):
         global DEFAULT_IMAGE
         DEFAULT_IMAGE = SELECTED_IMAGE
         reset()
         return SELECTED_IMAGE
 
-    # new_frame = runner.run(
-    #     memory_frames[0],
-    #     actions[0],
-    #     poses[0],
-    #     memory_curr_frame,
-    #     device
-    # )
-
-    # print("first algo frame:", len(algo.frames))
 
-    css = """
-    h1 {
-    text-align: center;
-    display:block;
-    }
-    """
 
-    with gr.Blocks(css=css) as demo:
-        gr.Markdown(
-            """
-            # WORLDMEM: Long-term Consistent World Generation with Memory
-            """
-        )
-
-        # <div style="text-align: center;">
-        #     <!-- Public Website -->
-        #     <a style="display:inline-block" href="https://nirvanalan.github.io/projects/GA/">
-        #         <img src="https://img.shields.io/badge/public_website-8A2BE2">
-        #     </a>
-
-        #     <!-- GitHub Stars -->
-        #     <a style="display:inline-block; margin-left: .5em" href="https://github.com/NIRVANALAN/GaussianAnything">
-        #         <img src="https://img.shields.io/github/stars/NIRVANALAN/GaussianAnything?style=social">
-        #     </a>
-
-        #     <!-- Project Page -->
-        #     <a style="display:inline-block; margin-left: .5em" href="https://nirvanalan.github.io/projects/GA/">
-        #         <img src="https://img.shields.io/badge/project_page-blue">
-        #     </a>
-
-        #     <!-- arXiv Paper -->
-        #     <a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/XXXX.XXXXX">
-        #         <img src="https://img.shields.io/badge/arXiv-paper-red">
-        #     </a>
-        # </div>
-
-        with gr.Row(variant="panel"):
-            video_display = gr.Video(autoplay=True, loop=True)
-            image_display = gr.Image(value=DEFAULT_IMAGE, interactive=False, label="Last Frame")
-
-        with gr.Row(variant="panel"):
-            with gr.Column(scale=2):
-                input_box = gr.Textbox(label="Action Sequence", placeholder="Enter action sequence here...", lines=1, max_lines=1)
-                log_output = gr.Textbox(label="History Log", interactive=False)
-            with gr.Column(scale=1):
-                slider = gr.Slider(minimum=10, maximum=50, value=runner.algo.sampling_timesteps, step=1, label="Denoising Steps")
-                submit_button = gr.Button("Generate")
-                reset_btn = gr.Button("Reset")
-
-        sampling_timesteps_state = gr.State(runner.algo.sampling_timesteps)
-
-        example_actions = ["DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW", "DDDDDDDDDDDDQQQQQQQQQQQQQQQDDDDDDDDDDDD",
-                           "DDDDWWWDDDDDDDDDDDDDDDDDDDDSSSAAAAAAAAAAAAAAAAAAAAAAAA", "SSUNNWWEEEEEEEEEAAA1NNNNNNNNNSSUNNWW"]
-
-        def set_action(action):
-            return action
-
-        gr.Markdown("### Action sequence examples.")
-        with gr.Row():
-            buttons = []
-            for action in example_actions[:2]:
-                with gr.Column(scale=len(action)):
-                    buttons.append(gr.Button(action))
-        with gr.Row():
-            for action in example_actions[2:4]:
-                with gr.Column(scale=len(action)):
-                    buttons.append(gr.Button(action))
-        with gr.Row():
-            for action in example_actions[4:5]:
-                with gr.Column(scale=len(action)):
-                    buttons.append(gr.Button(action))
-
-        for button, action in zip(buttons, example_actions):
-            button.click(set_action, inputs=[gr.State(value=action)], outputs=input_box)
-
-
-        gr.Markdown("### Click on the images below to reset the sequence and generate from the new image.")
-
-        with gr.Row():
-            image_display_1 = gr.Image(value=SUNFLOWERS_IMAGE, interactive=False, label="Sunflower Plains")
-            image_display_2 = gr.Image(value=DESERT_IMAGE, interactive=False, label="Desert")
-            image_display_3 = gr.Image(value=SAVANNA_IMAGE, interactive=False, label="Savanna")
-            image_display_4 = gr.Image(value=ICE_PLAINS_IMAGE, interactive=False, label="Ice Plains")
-            image_display_5 = gr.Image(value=SUNFLOWERS_RAIN_IMAGE, interactive=False, label="Rainy Sunflower Plains")
-            image_display_6 = gr.Image(value=PLACE_IMAGE, interactive=False, label="Place")
-
-        gr.Markdown(
-            """
-            ## Instructions & Notes:
-
-            1. Enter an action sequence in the **"Action Sequence"** text box and click **"Generate"** to begin.
-            2. You can continue generation by clicking **"Generation"** again and again. Previous sequences are logged in the history panel.
-            3. Click **"Reset"** to clear the current sequence and start fresh.
-            4. Action sequences can be composed using the following keys:
-            - W: turn up
-            - S: turn down
-            - A: turn left
-            - D: turn right
-            - Q: move forward
-            - E: move backward
-            - N: no-op (do nothing)
-            - 1: switch to hotbar 1
-            - U: use item
-            5. Higher denoising steps produce more detailed results but take longer. **20 steps** is a good balance between quality and speed.
-            6. If you find this project interesting or useful, please consider giving it a ⭐️ on [GitHub]()!
-            7. For feedback or suggestions, feel free to open a GitHub issue or contact me directly at **[email protected]**.
-            """
         )
-        # input_box.submit(update_image_and_log, inputs=[input_box], outputs=[image_display, video_display, log_output])
-        submit_button.click(update_image_and_log, inputs=[input_box], outputs=[image_display, video_display, log_output])
-        reset_btn.click(reset, outputs=[log_output, image_display])
-        image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=image_display)
-        image_display_2.select(lambda: on_image_click(DESERT_IMAGE), outputs=image_display)
-        image_display_3.select(lambda: on_image_click(SAVANNA_IMAGE), outputs=image_display)
-        image_display_4.select(lambda: on_image_click(ICE_PLAINS_IMAGE), outputs=image_display)
-        image_display_5.select(lambda: on_image_click(SUNFLOWERS_RAIN_IMAGE), outputs=image_display)
-        image_display_6.select(lambda: on_image_click(PLACE_IMAGE), outputs=image_display)
-
-        slider.change(fn=set_denoising_steps, inputs=[slider, sampling_timesteps_state], outputs=sampling_timesteps_state)
-
-    demo.launch()
-
-if __name__ == "__main__":
-    run()  # pylint: disable=no-value-for-parameter
 
+cfg = OmegaConf.load("configurations/huggingface.yaml")
+worldmem = WorldMemMinecraft(cfg)
+load_custom_checkpoint(algo=worldmem.diffusion_model, checkpoint_path=cfg.diffusion_path)
+load_custom_checkpoint(algo=worldmem.vae, checkpoint_path=cfg.vae_path)
+load_custom_checkpoint(algo=worldmem.pose_prediction_model, checkpoint_path=cfg.pose_predictor_path)
+worldmem.to("cuda").eval()
+
+
+actions = torch.zeros((1, 25))
+poses = torch.zeros((1, 5))
+
+memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))
+
+@spaces.GPU()
+def run_interactive(first_frame, action, first_pose, curr_frame, device):
+    new_frame = worldmem.interactive(first_frame,
+                                     action,
+                                     first_pose,
+                                     curr_frame,
+                                     device=device)
+    return new_frame
+
+def set_denoising_steps(denoising_steps, sampling_timesteps_state):
+    worldmem.sampling_timesteps = denoising_steps
+    worldmem.diffusion_model.sampling_timesteps = denoising_steps
+    sampling_timesteps_state = denoising_steps
+    print("set denoising steps to", worldmem.sampling_timesteps)
+    return sampling_timesteps_state
+
+def update_image_and_log(keys):
+    actions = parse_input_to_tensor(keys)
+    global input_history
+    global memory_curr_frame
+
+    print("algo frame:", len(worldmem.frames))
+
+    for i in range(len(actions)):
+        memory_curr_frame += 1
+
+        new_frame = run_interactive(memory_frames[0],
+                                    actions[i],
+                                    None,
+                                    memory_curr_frame,
+                                    device=device)
+
+        # print("algo frame:", len(runner.algo.frames))
 
+        memory_frames.append(new_frame)
 
+    out_video = torch.stack(memory_frames)
+    out_video = out_video.permute(0,2,3,1).numpy()
+    out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
+    out_video = (out_video * 255).astype(np.uint8)
 
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    os.makedirs("outputs_gradio", exist_ok=True)
+    filename = f"outputs_gradio/{timestamp}.mp4"
+    save_video(out_video, filename)
 
+    input_history += keys
+    return out_video[-1], filename, input_history
 
+def reset():
+    global memory_curr_frame
+    global input_history
+    global memory_frames
 
+    worldmem.reset()
+    memory_frames = []
     memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))
+    memory_curr_frame = 0
+    input_history = ""
+
+    _ = run_interactive(memory_frames[0],
+                        actions[0],
+                        poses[0],
+                        memory_curr_frame,
+                        device=device)
+
 
+
+    return input_history, DEFAULT_IMAGE
+
+def on_image_click(SELECTED_IMAGE):
     global DEFAULT_IMAGE
     DEFAULT_IMAGE = SELECTED_IMAGE
     reset()
     return SELECTED_IMAGE
 
+# new_frame = runner.run(
+#     memory_frames[0],
+#     actions[0],
+#     poses[0],
+#     memory_curr_frame,
+#     device
+# )
 
+# print("first algo frame:", len(algo.frames))
 
+css = """
+h1 {
+text-align: center;
+display:block;
+}
+"""
+
+with gr.Blocks(css=css) as demo:
+    gr.Markdown(
+        """
+        # WORLDMEM: Long-term Consistent World Generation with Memory
+        """
     )
+
+    # <div style="text-align: center;">
+    #     <!-- Public Website -->
+    #     <a style="display:inline-block" href="https://nirvanalan.github.io/projects/GA/">
+    #         <img src="https://img.shields.io/badge/public_website-8A2BE2">
+    #     </a>
+
+    #     <!-- GitHub Stars -->
+    #     <a style="display:inline-block; margin-left: .5em" href="https://github.com/NIRVANALAN/GaussianAnything">
+    #         <img src="https://img.shields.io/github/stars/NIRVANALAN/GaussianAnything?style=social">
+    #     </a>
+
+    #     <!-- Project Page -->
+    #     <a style="display:inline-block; margin-left: .5em" href="https://nirvanalan.github.io/projects/GA/">
+    #         <img src="https://img.shields.io/badge/project_page-blue">
+    #     </a>
+
+    #     <!-- arXiv Paper -->
+    #     <a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/XXXX.XXXXX">
+    #         <img src="https://img.shields.io/badge/arXiv-paper-red">
+    #     </a>
+    # </div>
+
+    with gr.Row(variant="panel"):
+        video_display = gr.Video(autoplay=True, loop=True)
+        image_display = gr.Image(value=DEFAULT_IMAGE, interactive=False, label="Last Frame")
+
+    with gr.Row(variant="panel"):
+        with gr.Column(scale=2):
+            input_box = gr.Textbox(label="Action Sequence", placeholder="Enter action sequence here...", lines=1, max_lines=1)
+            log_output = gr.Textbox(label="History Log", interactive=False)
+        with gr.Column(scale=1):
+            slider = gr.Slider(minimum=10, maximum=50, value=worldmem.sampling_timesteps, step=1, label="Denoising Steps")
+            submit_button = gr.Button("Generate")
+            reset_btn = gr.Button("Reset")
+
+    sampling_timesteps_state = gr.State(worldmem.sampling_timesteps)
+
+    example_actions = ["DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW", "DDDDDDDDDDDDQQQQQQQQQQQQQQQDDDDDDDDDDDD",
+                       "DDDDWWWDDDDDDDDDDDDDDDDDDDDSSSAAAAAAAAAAAAAAAAAAAAAAAA", "SSUNNWWEEEEEEEEEAAA1NNNNNNNNNSSUNNWW"]
+
+    def set_action(action):
+        return action
+
+    gr.Markdown("### Action sequence examples.")
+    with gr.Row():
+        buttons = []
+        for action in example_actions[:2]:
+            with gr.Column(scale=len(action)):
+                buttons.append(gr.Button(action))
+    with gr.Row():
+        for action in example_actions[2:4]:
+            with gr.Column(scale=len(action)):
+                buttons.append(gr.Button(action))
+    with gr.Row():
+        for action in example_actions[4:5]:
+            with gr.Column(scale=len(action)):
+                buttons.append(gr.Button(action))
+
+    for button, action in zip(buttons, example_actions):
+        button.click(set_action, inputs=[gr.State(value=action)], outputs=input_box)
+
+
+    gr.Markdown("### Click on the images below to reset the sequence and generate from the new image.")
+
+    with gr.Row():
+        image_display_1 = gr.Image(value=SUNFLOWERS_IMAGE, interactive=False, label="Sunflower Plains")
+        image_display_2 = gr.Image(value=DESERT_IMAGE, interactive=False, label="Desert")
+        image_display_3 = gr.Image(value=SAVANNA_IMAGE, interactive=False, label="Savanna")
+        image_display_4 = gr.Image(value=ICE_PLAINS_IMAGE, interactive=False, label="Ice Plains")
+        image_display_5 = gr.Image(value=SUNFLOWERS_RAIN_IMAGE, interactive=False, label="Rainy Sunflower Plains")
+        image_display_6 = gr.Image(value=PLACE_IMAGE, interactive=False, label="Place")
+
+    gr.Markdown(
+        """
+        ## Instructions & Notes:
+
+        1. Enter an action sequence in the **"Action Sequence"** text box and click **"Generate"** to begin.
+        2. You can continue generation by clicking **"Generate"** again and again. Previous sequences are logged in the history panel.
+        3. Click **"Reset"** to clear the current sequence and start fresh.
+        4. Action sequences can be composed using the following keys:
+        - W: turn up
+        - S: turn down
+        - A: turn left
+        - D: turn right
+        - Q: move forward
+        - E: move backward
+        - N: no-op (do nothing)
+        - 1: switch to hotbar 1
+        - U: use item
+        5. Higher denoising steps produce more detailed results but take longer. **20 steps** is a good balance between quality and speed.
+        6. If you find this project interesting or useful, please consider giving it a ⭐️ on [GitHub]()!
+        7. For feedback or suggestions, feel free to open a GitHub issue or contact me directly at **[email protected]**.
+        """
+    )
+    # input_box.submit(update_image_and_log, inputs=[input_box], outputs=[image_display, video_display, log_output])
+    submit_button.click(update_image_and_log, inputs=[input_box], outputs=[image_display, video_display, log_output])
+    reset_btn.click(reset, outputs=[log_output, image_display])
+    image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=image_display)
+    image_display_2.select(lambda: on_image_click(DESERT_IMAGE), outputs=image_display)
+    image_display_3.select(lambda: on_image_click(SAVANNA_IMAGE), outputs=image_display)
+    image_display_4.select(lambda: on_image_click(ICE_PLAINS_IMAGE), outputs=image_display)
+    image_display_5.select(lambda: on_image_click(SUNFLOWERS_RAIN_IMAGE), outputs=image_display)
+    image_display_6.select(lambda: on_image_click(PLACE_IMAGE), outputs=image_display)
+
+    slider.change(fn=set_denoising_steps, inputs=[slider, sampling_timesteps_state], outputs=sampling_timesteps_state)
+
+demo.launch()
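parse_input_to_tensor (whose docstring appears in an earlier hunk) turns a typed key string into a (sequence_length, 25) one-hot action tensor via the KEY_TO_ACTION and ACTION_KEYS tables. The sketch below shows the general pattern with a toy two-entry table; it is an illustrative assumption, not the app's actual implementation:

    import torch

    # Toy stand-ins for the app's ACTION_KEYS / KEY_TO_ACTION tables (assumptions).
    ACTION_KEYS = ["inventory", "hotbar.1"]
    KEY_TO_ACTION = {"1": ("hotbar.1", 1)}

    def parse_input_to_tensor(input_str: str) -> torch.Tensor:
        # One row per typed key, one column per action channel.
        seq = torch.zeros((len(input_str), len(ACTION_KEYS)))
        for row, key in enumerate(input_str):
            if key in KEY_TO_ACTION:
                action_name, value = KEY_TO_ACTION[key]
                seq[row, ACTION_KEYS.index(action_name)] = value
        return seq

    print(parse_input_to_tensor("1N").shape)  # torch.Size([2, 2]) in this toy setup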
configurations/huggingface.yaml CHANGED
@@ -1,57 +1,58 @@
-defaults:
-  - algorithm: df_video_worldmemminecraft
-  - experiment: exp_video
-  - dataset: video_minecraft
 
-dataset:
-  n_frames_valid: 100
-  validation_multiplier: 1
-  use_plucker: true
-  customized_validation: true
-  condition_similar_length: 8
-  padding_pool: 10
-  focal_length: 0.35
-  save_dir: data/test_pumpkin
-  add_frame_timestep_embedder: true
-  pos_range: 0.5
-  angle_range: 30
-
-experiment:
-  tasks: [interactive]
-  training:
-    data:
-      num_workers: 4
-  validation:
-    batch_size: 1
-    limit_batch: 1
-    data:
-      num_workers: 4
-  load_vae: false
-  load_t_to_r: false
-  zero_init_gate: false
-  only_tune_refer: false
-  diffusion_path: yslan/worldmem_checkpoints/diffusion_only.ckpt
-  vae_path: yslan/worldmem_checkpoints/vae_only.ckpt
-  pose_predictor_path: yslan/worldmem_checkpoints/pose_prediction_model_only.ckpt
-  customized_load: true
-
-algorithm:
-  n_tokens: 8
-  context_frames: 90
-  pose_cond_dim: 5
-  use_plucker: true
-  focal_length: 0.35
-  customized_validation: true
-  condition_similar_length: 8
-  log_video: true
-  relative_embedding: true
-  cond_only_on_qk: true
-  add_pose_embed: false
-  use_domain_adapter: false
-  use_reference_attention: true
-  add_frame_timestep_embedder: true
-  is_interactive: true
-  diffusion:
-    sampling_timesteps: 20
-
-debug: false
+n_tokens: 8
+pose_cond_dim: 5
+use_plucker: true
+focal_length: 0.35
+customized_validation: true
+condition_similar_length: 8
+log_video: true
+relative_embedding: true
+cond_only_on_qk: true
+add_pose_embed: false
+use_domain_adapter: false
+use_reference_attention: true
+add_frame_timestep_embedder: true
+is_interactive: true
+diffusion:
+  sampling_timesteps: 20
+  beta_schedule: sigmoid
+  objective: pred_v
+  use_fused_snr: True
+  cum_snr_decay: 0.96
+  clip_noise: 20.
+  ddim_sampling_eta: 0.0
+  stabilization_level: 15
+  schedule_fn_kwargs: {}
+  use_snr: False
+  use_cum_snr: False
+  snr_clip: 5.0
+  timesteps: 1000
+  # architecture
+  architecture:
+    network_size: 64
+    attn_heads: 4
+    attn_dim_head: 64
+    dim_mults: [1, 2, 4, 8]
+    resolution: ${dataset.resolution}
+    attn_resolutions: [16, 32, 64, 128]
+    use_init_temporal_attn: True
+    use_linear_attn: True
+    time_emb_type: rotary
 
+weight_decay: 2e-3
+warmup_steps: 10000
+optimizer_beta: [0.9, 0.99]
+action_cond_dim: 25
+n_frames: 8
+frame_skip: 1
+frame_stack: 1
+uncertainty_scale: 1
+guidance_scale: 0.0
+chunk_size: 1 # -1 for full trajectory diffusion, number to specify diffusion chunk size
+scheduling_matrix: autoregressive
+noise_level: random_all
+causal: True
+x_shape: [3, 360, 640]
+context_frames: 1
+diffusion_path: yslan/worldmem_checkpoints/diffusion_only.ckpt
+vae_path: yslan/worldmem_checkpoints/vae_only.ckpt
+pose_predictor_path: yslan/worldmem_checkpoints/pose_prediction_model_only.ckpt
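Because the new huggingface.yaml is flat and self-contained (no Hydra defaults list), app.py can load it directly with OmegaConf and read keys by attribute access, as the new module-level code does. A small sketch of that access pattern; the values in the comments are the ones configured above:

    from omegaconf import OmegaConf

    cfg = OmegaConf.load("configurations/huggingface.yaml")

    # Top-level keys are attributes of the loaded DictConfig ...
    print(cfg.n_tokens)                      # 8
    print(cfg.diffusion_path)                # yslan/worldmem_checkpoints/diffusion_only.ckpt
    # ... and nested sections resolve the same way.
    print(cfg.diffusion.sampling_timesteps)  # 20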