xizaoqu committed
Commit c7542a3 · 1 Parent(s): a7ea928
Files changed (2):
  1. algorithms/worldmem/df_video.py +2 -5
  2. app.py +142 -59
algorithms/worldmem/df_video.py CHANGED
@@ -615,8 +615,6 @@ class WorldMemMinecraft(DiffusionForcingBase):
         for _ in range(condition_similar_length):
             overlap_ratio = ((in_fov1.bool() & in_fov_list).sum(1)) / in_fov1.sum()
 
-            # if curr_frame == 54:
-            #     import pdb;pdb.set_trace()
             confidence = overlap_ratio + (curr_frame - frame_idx[:curr_frame]) / curr_frame * (-0.2)
 
             if len(random_idx) > 0:
@@ -624,10 +622,11 @@ class WorldMemMinecraft(DiffusionForcingBase):
             _, r_idx = torch.topk(confidence, k=1, dim=0)
             random_idx.append(r_idx[0])
 
+            # choice 1: directly remove overlapping region
             occupied_mask = in_fov_list[r_idx[0, range(in_fov1.shape[-1])], :, range(in_fov1.shape[-1])].permute(1,0)
-
             in_fov1 = in_fov1 & ~occupied_mask
 
+            # choice 2: apply similarity filter
             # cos_sim = F.cosine_similarity(xs_pred.to(r_idx.device)[r_idx[:, range(in_fov1.shape[1])],
             #                               range(in_fov1.shape[1])], xs_pred.to(r_idx.device)[:curr_frame], dim=2)
             # cos_sim = cos_sim.mean((-2,-1))
@@ -637,8 +636,6 @@ class WorldMemMinecraft(DiffusionForcingBase):
 
         random_idx = torch.stack(random_idx).cpu()
 
-        print(random_idx)
-
         return random_idx
 
     def _prepare_conditions(self,
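For orientation: the retrieval loop above greedily picks `condition_similar_length` memory frames by field-of-view overlap with a small age penalty, and "choice 1" masks out whatever the chosen frame already covers so the next pick favors complementary views. A minimal, self-contained sketch of that selection logic, assuming simplified shapes (`in_fov1`: `[P]` bool for the not-yet-covered target FOV, `in_fov_list`: `[T, P]` bool with one row per stored frame, `T == curr_frame`) rather than the exact batched tensors in `df_video.py`:

```python
import torch

def select_memory_frames(in_fov1, in_fov_list, frame_idx, curr_frame, n_select):
    # in_fov1:     [P]    bool - target-FOV points not yet covered by a pick
    # in_fov_list: [T, P] bool - FOV coverage of each stored frame (T == curr_frame)
    # frame_idx:   [T]    long - original frame index of each stored frame
    selected = []
    for _ in range(n_select):
        # fraction of the still-uncovered target FOV each stored frame sees
        overlap_ratio = (in_fov1 & in_fov_list).sum(1) / in_fov1.sum()
        # age penalty, matching the -0.2 recency term in the commit
        confidence = overlap_ratio - 0.2 * (curr_frame - frame_idx[:curr_frame]) / curr_frame
        r_idx = torch.topk(confidence, k=1, dim=0).indices[0]
        selected.append(r_idx)
        # choice 1: drop the region the chosen frame already covers
        in_fov1 = in_fov1 & ~in_fov_list[r_idx]
    return torch.stack(selected).cpu()
```

"Choice 2", left commented out in the commit, would instead re-rank candidates by cosine similarity between predicted frames, i.e. an appearance-based filter in place of the hard geometric mask.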
app.py CHANGED
@@ -70,6 +70,13 @@ KEY_TO_ACTION = {
     "1": ("hotbar.1", 1),
 }
 
+example_images = [
+    ["1", "assets/ice_plains.png", "turn right+go backward+look up+turn left+look down+turn right+go forward+turn left", 20, 3, 8],
+    ["2", "assets/place.png", "put item+go backward+put item+go backward+go around", 20, 3, 8],
+    ["3", "assets/rain_sunflower_plains.png", "turn right+look up+turn right+look down+turn left+go backward+turn left", 20, 3, 8],
+    ["4", "assets/desert.png", "turn 360 degree+turn right+go forward+turn left", 20, 3, 8],
+]
+
 def load_custom_checkpoint(algo, checkpoint_path):
     hf_ckpt = str(checkpoint_path).split('/')
     repo_id = '/'.join(hf_ckpt[:2])
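Each `example_images` row lines up with the `inputs` list wired to `gr.Examples` later in this diff: a case id, a preview image, an action summary shown in the log box, and values for the denoising-steps, context-length, and memory-length sliders.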
@@ -156,7 +163,6 @@ def enable_amp(model, precision="16-mixed"):
     return model
 
 memory_frames = []
-memory_curr_frame = 0
 input_history = ""
 ICE_PLAINS_IMAGE = "assets/ice_plains.png"
 DESERT_IMAGE = "assets/desert.png"
@@ -166,7 +172,6 @@ PLACE_IMAGE = "assets/place.png"
 SUNFLOWERS_IMAGE = "assets/sunflower_plains.png"
 SUNFLOWERS_RAIN_IMAGE = "assets/rain_sunflower_plains.png"
 
-DEFAULT_IMAGE = ICE_PLAINS_IMAGE
 device = torch.device('cuda')
 
 def save_video(frames, path="output.mp4", fps=10):
@@ -193,13 +198,6 @@ worldmem = enable_amp(worldmem, precision="16-mixed")
 actions = np.zeros((1, 25), dtype=np.float32)
 poses = np.zeros((1, 5), dtype=np.float32)
 
-memory_frames = load_image_as_tensor(DEFAULT_IMAGE)[None].numpy()
-
-self_frames = None
-self_actions = None
-self_poses = None
-self_memory_c2w = None
-self_frame_idx = None
 
 
 def get_duration_single_image_to_long_video(first_frame, action, first_pose, device, self_frames, self_actions,
@@ -240,17 +238,8 @@ def set_memory_length(memory_length, sampling_memory_length_state):
     print("set memory length to", worldmem.condition_similar_length)
     return sampling_memory_length_state
 
-def generate(keys):
-    # print("algo frame:", len(worldmem.frames))
+def generate(keys, input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx):
     input_actions = parse_input_to_tensor(keys)
-    global input_history
-    global memory_frames
-    global memory_curr_frame
-    global self_frames
-    global self_actions
-    global self_poses
-    global self_memory_c2w
-    global self_frame_idx
 
     if self_frames is None:
         new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(memory_frames[0],
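The core refactor in app.py replaces module-level globals with per-session `gr.State`: every handler now takes the session state as extra inputs and returns the updated values as extra outputs. A stripped-down sketch of the pattern with a hypothetical `step` handler (not the app's actual code), assuming current Gradio Blocks semantics:

```python
import gradio as gr

with gr.Blocks() as demo:
    counter = gr.State(0)          # per-session value, not shared across users
    btn = gr.Button("Step")
    out = gr.Textbox(label="Count")

    def step(count):
        # state arrives as an argument and leaves as a return value,
        # so concurrent sessions can no longer clobber each other
        count += 1
        return count, f"count = {count}"

    # list the State in both inputs and outputs to read and write it
    btn.click(step, inputs=[counter], outputs=[counter, out])

demo.launch()
```

This is why `generate`'s signature and the `submit_button.click(...)` wiring at the end of the diff grow so many parameters: each formerly global variable is now threaded through explicitly.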
@@ -282,25 +271,34 @@ def generate(keys):
     temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
     save_video(out_video, temporal_video_path)
 
+
+    now = datetime.now()
+    folder_name = now.strftime("%Y-%m-%d_%H-%M-%S")
+    folder_path = os.path.join("/mnt/xiaozeqi/worldmem/output_material", folder_name)
+    os.makedirs(folder_path, exist_ok=True)
     input_history += keys
-    return out_video[-1], temporal_video_path, input_history
-
-def reset():
-    global memory_curr_frame
-    global input_history
-    global memory_frames
-    global self_frames
-    global self_actions
-    global self_poses
-    global self_memory_c2w
-    global self_frame_idx
 
+    data_dict = {
+        "input_history": input_history,
+        "memory_frames": memory_frames,
+        "self_frames": self_frames,
+        "self_actions": self_actions,
+        "self_poses": self_poses,
+        "self_memory_c2w": self_memory_c2w,
+        "self_frame_idx": self_frame_idx,
+    }
+
+    np.savez(os.path.join(folder_path, "data_bundle.npz"), **data_dict)
+
+    return out_video[-1], temporal_video_path, input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
+
+def reset(selected_image):
     self_frames = None
     self_poses = None
     self_actions = None
     self_memory_c2w = None
     self_frame_idx = None
-    memory_frames = load_image_as_tensor(DEFAULT_IMAGE).numpy()[None]
+    memory_frames = load_image_as_tensor(selected_image).numpy()[None]
     input_history = ""
 
     new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(memory_frames[0],
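`np.savez` stores every keyword argument as an array, so strings round-trip as 0-d arrays; that is why `set_memory` below unwraps `input_history` with `.item()`. A small round-trip sketch under that assumption (hypothetical values and filename):

```python
import numpy as np

# save: each keyword becomes an entry in the .npz archive
np.savez("data_bundle.npz",
         input_history="WWSSAA",                 # str -> 0-d unicode array
         memory_frames=np.zeros((2, 3, 64, 64), dtype=np.float32))

bundle = np.load("data_bundle.npz")
history = bundle["input_history"].item()         # unwrap back to str
frames = bundle["memory_frames"]                 # ndarrays come back unchanged
assert history == "WWSSAA" and frames.shape == (2, 3, 64, 64)
```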
@@ -313,14 +311,58 @@ def reset():
                                           self_memory_c2w=self_memory_c2w,
                                           self_frame_idx=self_frame_idx)
 
-    return input_history, DEFAULT_IMAGE
+    return input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
+
+def on_image_click(selected_image):
+    input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = reset(selected_image)
+    return input_history, selected_image, selected_image, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
+
+def set_memory(examples_case, image_display, log_output, slider_denoising_step, slider_context_length, slider_memory_length):
+    if examples_case == '1':
+        data_bundle = np.load("/mnt/xiaozeqi/worldmem/output_material/2025-04-11_16-01-49/data_bundle.npz")
+        input_history = data_bundle['input_history'].item()
+        memory_frames = data_bundle['memory_frames']
+        self_frames = data_bundle['self_frames']
+        self_actions = data_bundle['self_actions']
+        self_poses = data_bundle['self_poses']
+        self_memory_c2w = data_bundle['self_memory_c2w']
+        self_frame_idx = data_bundle['self_frame_idx']
+    elif examples_case == '2':
+        data_bundle = np.load("/mnt/xiaozeqi/worldmem/output_material/2025-04-12_10-42-04/data_bundle.npz")
+        input_history = data_bundle['input_history'].item()
+        memory_frames = data_bundle['memory_frames']
+        self_frames = data_bundle['self_frames']
+        self_actions = data_bundle['self_actions']
+        self_poses = data_bundle['self_poses']
+        self_memory_c2w = data_bundle['self_memory_c2w']
+        self_frame_idx = data_bundle['self_frame_idx']
+    elif examples_case == '3':
+        data_bundle = np.load("/mnt/xiaozeqi/worldmem/output_material/2025-04-12_10-56-57/data_bundle.npz")
+        input_history = data_bundle['input_history'].item()
+        memory_frames = data_bundle['memory_frames']
+        self_frames = data_bundle['self_frames']
+        self_actions = data_bundle['self_actions']
+        self_poses = data_bundle['self_poses']
+        self_memory_c2w = data_bundle['self_memory_c2w']
+        self_frame_idx = data_bundle['self_frame_idx']
+    elif examples_case == '4':
+        data_bundle = np.load("/mnt/xiaozeqi/worldmem/output_material/2025-04-11_16-07-19/data_bundle.npz")
+        input_history = data_bundle['input_history'].item()
+        memory_frames = data_bundle['memory_frames']
+        self_frames = data_bundle['self_frames']
+        self_actions = data_bundle['self_actions']
+        self_poses = data_bundle['self_poses']
+        self_memory_c2w = data_bundle['self_memory_c2w']
+        self_frame_idx = data_bundle['self_frame_idx']
 
-def on_image_click(SELECTED_IMAGE):
-    global DEFAULT_IMAGE
-    DEFAULT_IMAGE = SELECTED_IMAGE
-    reset()
-    return SELECTED_IMAGE
+    out_video = memory_frames.transpose(0,2,3,1)
+    out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
+    out_video = (out_video * 255).astype(np.uint8)
 
+    temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
+    save_video(out_video, temporal_video_path)
+
+    return input_history, out_video[-1], temporal_video_path, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
 
 css = """
 h1 {
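The four `set_memory` branches differ only in the bundle path, so a table-driven version would behave identically with far less repetition. A hypothetical refactor sketch (not part of this commit), reusing the paths and keys from the diff above:

```python
import numpy as np

BUNDLE_PATHS = {
    "1": "/mnt/xiaozeqi/worldmem/output_material/2025-04-11_16-01-49/data_bundle.npz",
    "2": "/mnt/xiaozeqi/worldmem/output_material/2025-04-12_10-42-04/data_bundle.npz",
    "3": "/mnt/xiaozeqi/worldmem/output_material/2025-04-12_10-56-57/data_bundle.npz",
    "4": "/mnt/xiaozeqi/worldmem/output_material/2025-04-11_16-07-19/data_bundle.npz",
}

STATE_KEYS = ["memory_frames", "self_frames", "self_actions",
              "self_poses", "self_memory_c2w", "self_frame_idx"]

def load_bundle(examples_case):
    # one load path for every case; .item() unwraps the 0-d string array
    data_bundle = np.load(BUNDLE_PATHS[examples_case])
    return (data_bundle["input_history"].item(),
            *(data_bundle[key] for key in STATE_KEYS))
```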
@@ -329,6 +371,10 @@ h1 {
 }
 """
 
+def on_select(evt: gr.SelectData):
+    selected_index = evt.index
+    return examples[selected_index]
+
 with gr.Blocks(css=css) as demo:
     gr.Markdown(
         """
@@ -358,13 +404,18 @@ with gr.Blocks(css=css) as demo:
     # </a>
     # </div>
 
-    example_actions = ["AAAAAAAAAAAADDDDDDDDDDDD", "AAAAAAAAAAAAAAAAAAAAAAAA", "DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW", "DDDDDDDDDDDDQQQQQQQQQQQQQQQDDDDDDDDDDDD",
-                       "DDDDWWWDDDDDDDDDDDDDDDDDDDDSSS", "SSUNNWWEEEEEEEEEAAASSUNNWWEEEEEEEEE"]
+    example_actions = {"turn left + turn right": "AAAAAAAAAAAADDDDDDDDDDDD",
+                       "turn 360 degree": "AAAAAAAAAAAAAAAAAAAAAAAA",
+                       "turn right+go backward+look up+turn left+look down": "DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW",
+                       "turn right+go forward+turn left": "DDDDDDDDDDDDQQQQQQQQQQQQQQQDDDDDDDDDDDD",
+                       "turn right+look up+turn right+look down": "DDDDWWWDDDDDDDDDDDDDDDDDDDDSSS",
+                       "put item+go backward+put item+go backward": "SSUNNWWEEEEEEEEEAAASSUNNWWEEEEEEEEE"}
 
+    selected_image = gr.State(ICE_PLAINS_IMAGE)
 
     with gr.Row(variant="panel"):
         video_display = gr.Video(autoplay=True, loop=True)
-        image_display = gr.Image(value=DEFAULT_IMAGE, interactive=False, label="Last Frame")
+        image_display = gr.Image(value=selected_image.value, interactive=False, label="Current Frame")
 
 
     with gr.Row(variant="panel"):
@@ -374,17 +425,17 @@ with gr.Blocks(css=css) as demo:
             gr.Markdown("### Action sequence examples.")
             with gr.Row():
                 buttons = []
-                for action in example_actions[:2]:
-                    with gr.Column(scale=len(action)):
-                        buttons.append(gr.Button(action))
+                for action_key in list(example_actions.keys())[:2]:
+                    with gr.Column(scale=len(action_key)):
+                        buttons.append(gr.Button(action_key))
             with gr.Row():
-                for action in example_actions[2:4]:
-                    with gr.Column(scale=len(action)):
-                        buttons.append(gr.Button(action))
+                for action_key in list(example_actions.keys())[2:4]:
+                    with gr.Column(scale=len(action_key)):
+                        buttons.append(gr.Button(action_key))
             with gr.Row():
-                for action in example_actions[4:6]:
-                    with gr.Column(scale=len(action)):
-                        buttons.append(gr.Button(action))
+                for action_key in list(example_actions.keys())[4:6]:
+                    with gr.Column(scale=len(action_key)):
+                        buttons.append(gr.Button(action_key))
 
         with gr.Column(scale=1):
             slider_denoising_step = gr.Slider(minimum=10, maximum=50, value=worldmem.sampling_timesteps, step=1, label="Denoising Steps")
@@ -397,6 +448,12 @@ with gr.Blocks(css=css) as demo:
     sampling_context_length_state = gr.State(worldmem.n_tokens)
     sampling_memory_length_state = gr.State(worldmem.condition_similar_length)
 
+    memory_frames = gr.State(load_image_as_tensor(selected_image.value)[None].numpy())
+    self_frames = gr.State()
+    self_actions = gr.State()
+    self_poses = gr.State()
+    self_memory_c2w = gr.State()
+    self_frame_idx = gr.State()
 
     def set_action(action):
         return action
@@ -404,8 +461,8 @@ with gr.Blocks(css=css) as demo:
     # gr.Markdown("### Action sequence examples.")
 
 
-    for button, action in zip(buttons, example_actions):
-        button.click(set_action, inputs=[gr.State(value=action)], outputs=input_box)
+    for button, action_key in zip(buttons, list(example_actions.keys())):
+        button.click(set_action, inputs=[gr.State(value=example_actions[action_key])], outputs=input_box)
 
 
     gr.Markdown("### Click on the images below to reset the sequence and generate from the new image.")
@@ -418,6 +475,32 @@ with gr.Blocks(css=css) as demo:
         image_display_5 = gr.Image(value=SUNFLOWERS_RAIN_IMAGE, interactive=False, label="Rainy Sunflower Plains")
         image_display_6 = gr.Image(value=PLACE_IMAGE, interactive=False, label="Place")
 
+    gr.Markdown("### Click the examples below for a quick review, and continue generating based on them.")
+
+    example_case = gr.Textbox(label="Case", visible=False)
+    image_output = gr.Image(visible=False)
+
+    # gr.Examples(examples=example_images,
+    #             inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length, slider_memory_length],
+    #             fn=set_memory,
+    #             outputs=[log_output, image_display, video_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx],
+    #             cache_examples=True
+    #             )
+
+    examples = gr.Examples(
+        examples=example_images,
+        inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length, slider_memory_length],
+        cache_examples=False
+    )
+
+    example_case.change(
+        fn=set_memory,
+        inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length, slider_memory_length],
+        outputs=[log_output, image_display, video_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx]
+    )
+
+
+
     gr.Markdown(
         """
         ## Instructions & Notes:
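The indirection around `gr.Examples` is deliberate: with `cache_examples=False`, clicking an example row only fills the listed `inputs` components, so the case id lands in the hidden `example_case` textbox and its `.change` listener then calls `set_memory` to restore the saved session. A minimal sketch of the same trigger pattern with hypothetical components, assuming standard Gradio behavior (programmatic fills fire `.change`):

```python
import gradio as gr

def apply_case(case_id):
    # stand-in for set_memory: runs once the hidden textbox is filled
    return f"loaded preset {case_id}"

with gr.Blocks() as demo:
    case = gr.Textbox(visible=False)   # hidden carrier for the clicked row's id
    status = gr.Textbox(label="Status")
    gr.Examples(examples=[["1"], ["2"]], inputs=[case], cache_examples=False)
    # clicking an example row sets `case`, which fires this listener
    case.change(fn=apply_case, inputs=[case], outputs=[status])

demo.launch()
```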
@@ -441,14 +524,14 @@ with gr.Blocks(css=css) as demo:
         """
     )
     # input_box.submit(update_image_and_log, inputs=[input_box], outputs=[image_display, video_display, log_output])
-    submit_button.click(generate, inputs=[input_box], outputs=[image_display, video_display, log_output])
-    reset_btn.click(reset, outputs=[log_output, image_display])
-    image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=image_display)
-    image_display_2.select(lambda: on_image_click(DESERT_IMAGE), outputs=image_display)
-    image_display_3.select(lambda: on_image_click(SAVANNA_IMAGE), outputs=image_display)
-    image_display_4.select(lambda: on_image_click(ICE_PLAINS_IMAGE), outputs=image_display)
-    image_display_5.select(lambda: on_image_click(SUNFLOWERS_RAIN_IMAGE), outputs=image_display)
-    image_display_6.select(lambda: on_image_click(PLACE_IMAGE), outputs=image_display)
+    submit_button.click(generate, inputs=[input_box, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx], outputs=[image_display, video_display, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
+    reset_btn.click(reset, inputs=[selected_image], outputs=[log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
+    image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
+    image_display_2.select(lambda: on_image_click(DESERT_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
+    image_display_3.select(lambda: on_image_click(SAVANNA_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
+    image_display_4.select(lambda: on_image_click(ICE_PLAINS_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
+    image_display_5.select(lambda: on_image_click(SUNFLOWERS_RAIN_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
+    image_display_6.select(lambda: on_image_click(PLACE_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
 
     slider_denoising_step.change(fn=set_denoising_steps, inputs=[slider_denoising_step, sampling_timesteps_state], outputs=sampling_timesteps_state)
     slider_context_length.change(fn=set_context_length, inputs=[slider_context_length, sampling_context_length_state], outputs=sampling_context_length_state)
 