Commit c7542a3 by xizaoqu
Parent(s): a7ea928

update

Files changed:
- algorithms/worldmem/df_video.py (+2 -5)
- app.py (+142 -59)
algorithms/worldmem/df_video.py

@@ -615,8 +615,6 @@ class WorldMemMinecraft(DiffusionForcingBase):
         for _ in range(condition_similar_length):
             overlap_ratio = ((in_fov1.bool() & in_fov_list).sum(1)) / in_fov1.sum()
 
-            # if curr_frame == 54:
-            # import pdb;pdb.set_trace()
             confidence = overlap_ratio + (curr_frame - frame_idx[:curr_frame]) / curr_frame * (-0.2)
 
             if len(random_idx) > 0:
@@ -624,10 +622,11 @@ class WorldMemMinecraft(DiffusionForcingBase):
             _, r_idx = torch.topk(confidence, k=1, dim=0)
             random_idx.append(r_idx[0])
 
+            # choice 1: directly remove overlapping region
             occupied_mask = in_fov_list[r_idx[0, range(in_fov1.shape[-1])], :, range(in_fov1.shape[-1])].permute(1,0)
-
             in_fov1 = in_fov1 & ~occupied_mask
 
+            # choice 2: apply similarity filter
             # cos_sim = F.cosine_similarity(xs_pred.to(r_idx.device)[r_idx[:, range(in_fov1.shape[1])],
             #                               range(in_fov1.shape[1])], xs_pred.to(r_idx.device)[:curr_frame], dim=2)
             # cos_sim = cos_sim.mean((-2,-1))
@@ -637,8 +636,6 @@ class WorldMemMinecraft(DiffusionForcingBase):
 
         random_idx = torch.stack(random_idx).cpu()
 
-        print(random_idx)
-
         return random_idx
 
     def _prepare_conditions(self,
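Note on the df_video.py hunks: the commit deletes leftover debugging (a pdb breadcrumb and a stray print) and labels the two candidate filters in the memory-retrieval loop. For readers skimming the diff, here is a minimal, self-contained sketch of what that loop computes, simplified to a single batch element; the names, shapes, and the select_memory_frames helper are illustrative, not the module's actual API:

import torch

def select_memory_frames(in_fov1, in_fov_list, frame_idx, curr_frame, k):
    """Greedily pick k memory frames that best cover the current view.

    in_fov1:     (P,) bool   points visible from the current pose
    in_fov_list: (M, P) bool points visible from each of M memory frames
    frame_idx:   (M,) float  time index of each memory frame
    """
    selected = []
    in_fov1 = in_fov1.clone()
    for _ in range(k):
        # Fraction of the still-uncovered current view each memory frame covers.
        overlap_ratio = (in_fov1 & in_fov_list).sum(-1).float() / in_fov1.sum().clamp(min=1)
        # Recency penalty, mirroring the source: older frames lose up to 0.2.
        confidence = overlap_ratio + (curr_frame - frame_idx) / curr_frame * (-0.2)
        if selected:  # avoid re-picking an already selected frame
            confidence[torch.stack(selected)] = float("-inf")
        best = torch.argmax(confidence)
        selected.append(best)
        # "choice 1: directly remove overlapping region": mask out what the
        # chosen frame already covers so each later pick adds new coverage.
        in_fov1 = in_fov1 & ~in_fov_list[best]
    return torch.stack(selected)

# Example: 5 memory frames, 64 FOV sample points, pick 3.
idx = select_memory_frames(torch.rand(64) > 0.5, torch.rand(5, 64) > 0.5,
                           torch.arange(5.0), curr_frame=5, k=3)

The commented-out "choice 2" in the diff would instead rank candidates by cosine similarity between predicted latents; the committed code keeps choice 1.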
app.py

@@ -70,6 +70,13 @@ KEY_TO_ACTION = {
     "1": ("hotbar.1", 1),
 }
 
+example_images = [
+    ["1", "assets/ice_plains.png", "turn right+go backward+look up+turn left+look down+turn right+go forward+turn left", 20, 3, 8],
+    ["2", "assets/place.png", "put item+go backward+put item+go backward+go around", 20, 3, 8],
+    ["3", "assets/rain_sunflower_plains.png", "turn right+look up+turn right+look down+turn left+go backward+turn left", 20, 3, 8],
+    ["4", "assets/desert.png", "turn 360 degree+turn right+go forward+turn left", 20, 3, 8],
+]
+
 def load_custom_checkpoint(algo, checkpoint_path):
     hf_ckpt = str(checkpoint_path).split('/')
     repo_id = '/'.join(hf_ckpt[:2])
@@ -156,7 +163,6 @@ def enable_amp(model, precision="16-mixed"):
     return model
 
 memory_frames = []
-memory_curr_frame = 0
 input_history = ""
 ICE_PLAINS_IMAGE = "assets/ice_plains.png"
 DESERT_IMAGE = "assets/desert.png"
@@ -166,7 +172,6 @@ PLACE_IMAGE = "assets/place.png"
 SUNFLOWERS_IMAGE = "assets/sunflower_plains.png"
 SUNFLOWERS_RAIN_IMAGE = "assets/rain_sunflower_plains.png"
 
-DEFAULT_IMAGE = ICE_PLAINS_IMAGE
 device = torch.device('cuda')
 
 def save_video(frames, path="output.mp4", fps=10):
@@ -193,13 +198,6 @@ worldmem = enable_amp(worldmem, precision="16-mixed")
 actions = np.zeros((1, 25), dtype=np.float32)
 poses = np.zeros((1, 5), dtype=np.float32)
 
-memory_frames = load_image_as_tensor(DEFAULT_IMAGE)[None].numpy()
-
-self_frames = None
-self_actions = None
-self_poses = None
-self_memory_c2w = None
-self_frame_idx = None
 
 
 def get_duration_single_image_to_long_video(first_frame, action, first_pose, device, self_frames, self_actions,
@@ -240,17 +238,8 @@ def set_memory_length(memory_length, sampling_memory_length_state):
     print("set memory length to", worldmem.condition_similar_length)
     return sampling_memory_length_state
 
-def generate(keys):
-    # print("algo frame:", len(worldmem.frames))
+def generate(keys, input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx):
     input_actions = parse_input_to_tensor(keys)
-    global input_history
-    global memory_frames
-    global memory_curr_frame
-    global self_frames
-    global self_actions
-    global self_poses
-    global self_memory_c2w
-    global self_frame_idx
 
     if self_frames is None:
         new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(memory_frames[0],
@@ -282,25 +271,34 @@ def generate(keys):
     temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
     save_video(out_video, temporal_video_path)
 
+
+    now = datetime.now()
+    folder_name = now.strftime("%Y-%m-%d_%H-%M-%S")
+    folder_path = os.path.join("/mnt/xiaozeqi/worldmem/output_material", folder_name)
+    os.makedirs(folder_path, exist_ok=True)
     input_history += keys
-    return out_video[-1], temporal_video_path, input_history
-
-def reset():
-    global memory_curr_frame
-    global input_history
-    global memory_frames
-    global self_frames
-    global self_actions
-    global self_poses
-    global self_memory_c2w
-    global self_frame_idx
 
+    data_dict = {
+        "input_history": input_history,
+        "memory_frames": memory_frames,
+        "self_frames": self_frames,
+        "self_actions": self_actions,
+        "self_poses": self_poses,
+        "self_memory_c2w": self_memory_c2w,
+        "self_frame_idx": self_frame_idx,
+    }
+
+    np.savez(os.path.join(folder_path, "data_bundle.npz"), **data_dict)
+
+    return out_video[-1], temporal_video_path, input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
+
+def reset(selected_image):
     self_frames = None
     self_poses = None
     self_actions = None
     self_memory_c2w = None
     self_frame_idx = None
-    memory_frames = load_image_as_tensor(
+    memory_frames = load_image_as_tensor(selected_image).numpy()[None]
     input_history = ""
 
     new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(memory_frames[0],
@@ -313,14 +311,58 @@ def reset():
         self_memory_c2w=self_memory_c2w,
         self_frame_idx=self_frame_idx)
 
-    return input_history,
+    return input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
+
+def on_image_click(selected_image):
+    input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = reset(selected_image)
+    return input_history, selected_image, selected_image, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
+
+def set_memory(examples_case, image_display, log_output, slider_denoising_step, slider_context_length, slider_memory_length):
+    if examples_case == '1':
+        data_bundle = np.load("/mnt/xiaozeqi/worldmem/output_material/2025-04-11_16-01-49/data_bundle.npz")
+        input_history = data_bundle['input_history'].item()
+        memory_frames = data_bundle['memory_frames']
+        self_frames = data_bundle['self_frames']
+        self_actions = data_bundle['self_actions']
+        self_poses = data_bundle['self_poses']
+        self_memory_c2w = data_bundle['self_memory_c2w']
+        self_frame_idx = data_bundle['self_frame_idx']
+    elif examples_case == '2':
+        data_bundle = np.load("/mnt/xiaozeqi/worldmem/output_material/2025-04-12_10-42-04/data_bundle.npz")
+        input_history = data_bundle['input_history'].item()
+        memory_frames = data_bundle['memory_frames']
+        self_frames = data_bundle['self_frames']
+        self_actions = data_bundle['self_actions']
+        self_poses = data_bundle['self_poses']
+        self_memory_c2w = data_bundle['self_memory_c2w']
+        self_frame_idx = data_bundle['self_frame_idx']
+    elif examples_case == '3':
+        data_bundle = np.load("/mnt/xiaozeqi/worldmem/output_material/2025-04-12_10-56-57/data_bundle.npz")
+        input_history = data_bundle['input_history'].item()
+        memory_frames = data_bundle['memory_frames']
+        self_frames = data_bundle['self_frames']
+        self_actions = data_bundle['self_actions']
+        self_poses = data_bundle['self_poses']
+        self_memory_c2w = data_bundle['self_memory_c2w']
+        self_frame_idx = data_bundle['self_frame_idx']
+    elif examples_case == '4':
+        data_bundle = np.load("/mnt/xiaozeqi/worldmem/output_material/2025-04-11_16-07-19/data_bundle.npz")
+        input_history = data_bundle['input_history'].item()
+        memory_frames = data_bundle['memory_frames']
+        self_frames = data_bundle['self_frames']
+        self_actions = data_bundle['self_actions']
+        self_poses = data_bundle['self_poses']
+        self_memory_c2w = data_bundle['self_memory_c2w']
+        self_frame_idx = data_bundle['self_frame_idx']
 
-
-
-
-    reset()
-    return SELECTED_IMAGE
+    out_video = memory_frames.transpose(0,2,3,1)
+    out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
+    out_video = (out_video * 255).astype(np.uint8)
 
+    temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
+    save_video(out_video, temporal_video_path)
+
+    return input_history, out_video[-1], temporal_video_path, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
 
 css = """
 h1 {
@@ -329,6 +371,10 @@
 }
 """
 
+def on_select(evt: gr.SelectData):
+    selected_index = evt.index
+    return examples[selected_index]
+
 with gr.Blocks(css=css) as demo:
     gr.Markdown(
         """
@@ -358,13 +404,18 @@
     #     </a>
     #     </div>
 
-    example_actions =
-
+    example_actions = {"turn left + turn right": "AAAAAAAAAAAADDDDDDDDDDDD",
+                       "turn 360 degree": "AAAAAAAAAAAAAAAAAAAAAAAA",
+                       "turn right+go backward+look up+turn left+look down": "DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW",
+                       "turn right+go forward+turn left": "DDDDDDDDDDDDQQQQQQQQQQQQQQQDDDDDDDDDDDD",
+                       "turn right+look up+turn right+look down": "DDDDWWWDDDDDDDDDDDDDDDDDDDDSSS",
+                       "put item+go backward+put item+go backward":"SSUNNWWEEEEEEEEEAAASSUNNWWEEEEEEEEE"}
 
+    selected_image = gr.State(ICE_PLAINS_IMAGE)
 
     with gr.Row(variant="panel"):
         video_display = gr.Video(autoplay=True, loop=True)
-        image_display = gr.Image(value=
+        image_display = gr.Image(value=selected_image.value, interactive=False, label="Current Frame")
 
 
     with gr.Row(variant="panel"):
@@ -374,17 +425,17 @@
     gr.Markdown("### Action sequence examples.")
     with gr.Row():
        buttons = []
-        for
-            with gr.Column(scale=len(
-                buttons.append(gr.Button(
+        for action_key in list(example_actions.keys())[:2]:
+            with gr.Column(scale=len(action_key)):
+                buttons.append(gr.Button(action_key))
     with gr.Row():
-        for
-            with gr.Column(scale=len(
-                buttons.append(gr.Button(
+        for action_key in list(example_actions.keys())[2:4]:
+            with gr.Column(scale=len(action_key)):
+                buttons.append(gr.Button(action_key))
    with gr.Row():
-        for
-            with gr.Column(scale=len(
-                buttons.append(gr.Button(
+        for action_key in list(example_actions.keys())[4:6]:
+            with gr.Column(scale=len(action_key)):
+                buttons.append(gr.Button(action_key))
 
     with gr.Column(scale=1):
         slider_denoising_step = gr.Slider(minimum=10, maximum=50, value=worldmem.sampling_timesteps, step=1, label="Denoising Steps")
@@ -397,6 +448,12 @@
     sampling_context_length_state = gr.State(worldmem.n_tokens)
     sampling_memory_length_state = gr.State(worldmem.condition_similar_length)
 
+    memory_frames = gr.State(load_image_as_tensor(selected_image.value)[None].numpy())
+    self_frames = gr.State()
+    self_actions = gr.State()
+    self_poses = gr.State()
+    self_memory_c2w = gr.State()
+    self_frame_idx = gr.State()
 
     def set_action(action):
         return action
@@ -404,8 +461,8 @@
     # gr.Markdown("### Action sequence examples.")
 
 
-    for button,
-        button.click(set_action, inputs=[gr.State(value=
+    for button, action_key in zip(buttons, list(example_actions.keys())):
+        button.click(set_action, inputs=[gr.State(value=example_actions[action_key])], outputs=input_box)
 
 
     gr.Markdown("### Click on the images below to reset the sequence and generate from the new image.")
@@ -418,6 +475,32 @@
     image_display_5 = gr.Image(value=SUNFLOWERS_RAIN_IMAGE, interactive=False, label="Rainy Sunflower Plains")
     image_display_6 = gr.Image(value=PLACE_IMAGE, interactive=False, label="Place")
 
+    gr.Markdown("### Click the examples below for a quick review, and continue generating based on them.")
+
+    example_case = gr.Textbox(label="Case", visible=False)
+    image_output = gr.Image(visible=False)
+
+    # gr.Examples(examples=example_images,
+    #             inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length, slider_memory_length],
+    #             fn=set_memory,
+    #             outputs=[log_output, image_display, video_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx],
+    #             cache_examples=True
+    #             )
+
+    examples = gr.Examples(
+        examples=example_images,
+        inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length, slider_memory_length],
+        cache_examples=False
+    )
+
+    example_case.change(
+        fn=set_memory,
+        inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length, slider_memory_length],
+        outputs=[log_output, image_display, video_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx]
+    )
+
+
+
     gr.Markdown(
         """
         ## Instructions & Notes:
@@ -441,14 +524,14 @@
     """
     )
     # input_box.submit(update_image_and_log, inputs=[input_box], outputs=[image_display, video_display, log_output])
-    submit_button.click(generate, inputs=[input_box], outputs=[image_display, video_display, log_output])
-    reset_btn.click(reset, outputs=[log_output,
-    image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=image_display)
-    image_display_2.select(lambda: on_image_click(DESERT_IMAGE), outputs=image_display)
-    image_display_3.select(lambda: on_image_click(SAVANNA_IMAGE), outputs=image_display)
-    image_display_4.select(lambda: on_image_click(ICE_PLAINS_IMAGE), outputs=image_display)
-    image_display_5.select(lambda: on_image_click(SUNFLOWERS_RAIN_IMAGE), outputs=image_display)
-    image_display_6.select(lambda: on_image_click(PLACE_IMAGE), outputs=image_display)
+    submit_button.click(generate, inputs=[input_box, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx], outputs=[image_display, video_display, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
+    reset_btn.click(reset, inputs=[selected_image], outputs=[log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
+    image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
+    image_display_2.select(lambda: on_image_click(DESERT_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
+    image_display_3.select(lambda: on_image_click(SAVANNA_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
+    image_display_4.select(lambda: on_image_click(ICE_PLAINS_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
+    image_display_5.select(lambda: on_image_click(SUNFLOWERS_RAIN_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
+    image_display_6.select(lambda: on_image_click(PLACE_IMAGE), outputs=[log_output, selected_image,image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
 
     slider_denoising_step.change(fn=set_denoising_steps, inputs=[slider_denoising_step, sampling_timesteps_state], outputs=sampling_timesteps_state)
     slider_context_length.change(fn=set_context_length, inputs=[slider_context_length, sampling_context_length_state], outputs=sampling_context_length_state)
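The bulk of the app.py change is a refactor from module-level globals (with `global` declarations inside generate, reset, and on_image_click) to gr.State values threaded through each event's inputs and outputs. A stripped-down sketch of that pattern, with illustrative component names:

import gradio as gr

def step(keys, history):
    # Pure function of its inputs: returns the updated state and the display text.
    history += keys
    return history, history

with gr.Blocks() as demo:
    history_state = gr.State("")   # replaces `global input_history`
    keys_box = gr.Textbox(label="Keys")
    log_box = gr.Textbox(label="History")
    # The state is passed in and written back out on every event.
    keys_box.submit(step, inputs=[keys_box, history_state],
                    outputs=[history_state, log_box])

# demo.launch()

gr.State is kept per browser session, so unlike globals, concurrent visitors to the demo no longer overwrite each other's frame history.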
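The new generate() also snapshots the whole session into an .npz bundle, which set_memory() later reloads for the four canned example cases. A minimal sketch of that save/load round trip; the folder, keys, and array contents below are placeholders, not the repository's real bundles:

import os
import tempfile
import numpy as np

folder_path = tempfile.mkdtemp()  # the app uses a timestamped output folder
data_dict = {
    "input_history": "DDDDWWW",                             # str is stored as a 0-d array
    "memory_frames": np.zeros((4, 3, 64, 64), np.float32),  # (T, C, H, W) placeholder
}
np.savez(os.path.join(folder_path, "data_bundle.npz"), **data_dict)

bundle = np.load(os.path.join(folder_path, "data_bundle.npz"))
assert bundle["input_history"].item() == "DDDDWWW"  # .item() recovers the string
assert bundle["memory_frames"].shape == (4, 3, 64, 64)

Strings survive the round trip as 0-d arrays, which is why set_memory() calls .item() on input_history but indexes the other entries directly.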
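Finally, the example loading relies on a hidden Textbox as the gr.Examples target: clicking an example row writes its case id into the textbox, and the textbox's .change event then runs the loader. A stripped-down sketch of that trigger pattern, with illustrative names:

import gradio as gr

def load_case(case_id):
    # Look up whatever state the chosen example should restore.
    return f"loaded case {case_id}"

with gr.Blocks() as demo:
    example_case = gr.Textbox(label="Case", visible=False)  # hidden Examples target
    log_box = gr.Textbox(label="Log")
    gr.Examples(examples=[["1"], ["2"]], inputs=[example_case])
    # Selecting an example writes into example_case, which fires .change.
    example_case.change(fn=load_case, inputs=[example_case], outputs=[log_box])

# demo.launch()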