xizaoqu committed
Commit · 100414d
1 Parent(s): f07d258
rm
- app.py +1 -25
- configurations/README.md +0 -7
- configurations/algorithm/base_algo.yaml +0 -3
- configurations/algorithm/base_pytorch_algo.yaml +0 -4
- configurations/algorithm/df_base.yaml +0 -42
- configurations/algorithm/df_video_worldmemminecraft.yaml +0 -42
- configurations/algorithm/pose_prediction.yaml +0 -19
- configurations/config.yaml +0 -16
- configurations/dataset/base_dataset.yaml +0 -3
- configurations/dataset/base_video.yaml +0 -14
- configurations/dataset/video_minecraft.yaml +0 -14
- configurations/dataset/video_minecraft_pose.yaml +0 -14
- configurations/experiment/base_experiment.yaml +0 -2
- configurations/experiment/base_pytorch.yaml +0 -50
- configurations/experiment/exp_pose.yaml +0 -31
- configurations/experiment/exp_video.yaml +0 -31
- datasets/README.md +0 -11
- datasets/__init__.py +0 -1
- datasets/video/__init__.py +0 -2
- datasets/video/base_video_dataset.py +0 -158
- datasets/video/minecraft_video_dataset.py +0 -262
- datasets/video/minecraft_video_dataset_oasis_filter.py +0 -99
- datasets/video/minecraft_video_dataset_pose.py +0 -421
- experiments/README.md +0 -19
- experiments/__init__.py +0 -35
- experiments/exp_base.py +0 -473
- experiments/exp_pose.py +0 -310
- experiments/exp_video.py +0 -25
- main.py +0 -219
- scripts/README.md +0 -10
- scripts/dummy_script.sh +0 -1
- split_checkpoint.py +0 -9
app.py
CHANGED
@@ -10,13 +10,8 @@ import hydra
 from omegaconf import DictConfig, OmegaConf
 from omegaconf.omegaconf import open_dict
 
-from utils.print_utils import cyan
-from utils.ckpt_utils import download_latest_checkpoint, is_run_id
-from utils.cluster_utils import submit_slurm_job
-from utils.distributed_utils import is_rank_zero
 import numpy as np
 import torch
-from datasets.video.minecraft_video_dataset import *
 import torchvision.transforms as transforms
 import cv2
 import subprocess
@@ -351,18 +346,7 @@ def set_memory(examples_case, image_display, log_output, slider_denoising_step,
 
     return input_history, out_video[-1], temporal_video_path, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
 
-
-h1 {
-    text-align: center;
-    display:block;
-}
-"""
-
-def on_select(evt: gr.SelectData):
-    selected_index = evt.index
-    return examples[selected_index]
-
-with gr.Blocks(css=css) as demo:
+with gr.Blocks() as demo:
     gr.Markdown(
         """
         # WORLDMEM: Long-term Consistent World Generation with Memory
@@ -515,13 +499,6 @@ with gr.Blocks(css=css) as demo:
     example_case = gr.Textbox(label="Case", visible=False)
     image_output = gr.Image(visible=False)
 
-    # gr.Examples(examples=example_images,
-    #             inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length, slider_memory_length],
-    #             fn=set_memory,
-    #             outputs=[log_output, image_display, video_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx],
-    #             cache_examples=True
-    # )
-
     examples = gr.Examples(
         examples=example_images,
         inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length, slider_memory_length],
@@ -534,7 +511,6 @@ with gr.Blocks(css=css) as demo:
         outputs=[log_output, image_display, video_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx]
     )
 
-
     submit_button.click(generate, inputs=[input_box, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx], outputs=[image_display, video_display, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
     reset_btn.click(reset, inputs=[selected_image], outputs=[log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
     image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
configurations/README.md
DELETED
@@ -1,7 +0,0 @@
# configurations

We use [Hydra](https://hydra.cc/docs/intro/) to manage configurations. Change/Add the yaml files in this folder
to change the default configurations. You can also override the default configurations by
passing command line arguments.

All configurations are automatically saved in wandb run.
configurations/algorithm/base_algo.yaml
DELETED
@@ -1,3 +0,0 @@
# This will be passed as the cfg to Algo.__init__(cfg) of your algorithm class

debug: ${debug} # inherited from configurations/config.yaml
configurations/algorithm/base_pytorch_algo.yaml
DELETED
@@ -1,4 +0,0 @@
defaults:
  - base_algo # inherits from configurations/algorithm/base_algo.yaml

lr: ${experiment.training.lr}
configurations/algorithm/df_base.yaml
DELETED
@@ -1,42 +0,0 @@
defaults:
  - base_pytorch_algo

# dataset-dependent configurations
x_shape: ${dataset.observation_shape}
frame_stack: 1
frame_skip: 1
data_mean: ${dataset.data_mean}
data_std: ${dataset.data_std}
external_cond_dim: 0 #${dataset.action_dim}
context_frames: ${dataset.context_length}
# training hyperparameters
weight_decay: 1e-4
warmup_steps: 10000
optimizer_beta: [0.9, 0.999]
# diffusion-related
uncertainty_scale: 1
guidance_scale: 0.0
chunk_size: 1 # -1 for full trajectory diffusion, number to specify diffusion chunk size
scheduling_matrix: autoregressive
noise_level: random_all
causal: True

diffusion:
  # training
  objective: pred_x0
  beta_schedule: cosine
  schedule_fn_kwargs: {}
  clip_noise: 20.0
  use_snr: False
  use_cum_snr: False
  use_fused_snr: False
  snr_clip: 5.0
  cum_snr_decay: 0.98
  timesteps: 1000
  # sampling
  sampling_timesteps: 50 # fixme, numer of diffusion steps, should be increased
  ddim_sampling_eta: 1.0
  stabilization_level: 10
  # architecture
  architecture:
    network_size: 64
configurations/algorithm/df_video_worldmemminecraft.yaml
DELETED
@@ -1,42 +0,0 @@
defaults:
  - df_base

n_frames: ${dataset.n_frames}
frame_skip: ${dataset.frame_skip}
metadata: ${dataset.metadata}

# training hyperparameters
weight_decay: 2e-3
warmup_steps: 10000
optimizer_beta: [0.9, 0.99]
action_cond_dim: 25

diffusion:
  # training
  beta_schedule: sigmoid
  objective: pred_v
  use_fused_snr: True
  cum_snr_decay: 0.96
  clip_noise: 20.
  # sampling
  sampling_timesteps: 20
  ddim_sampling_eta: 0.0
  stabilization_level: 15
  # architecture
  architecture:
    network_size: 64
    attn_heads: 4
    attn_dim_head: 64
    dim_mults: [1, 2, 4, 8]
    resolution: ${dataset.resolution}
    attn_resolutions: [16, 32, 64, 128]
    use_init_temporal_attn: True
    use_linear_attn: True
    time_emb_type: rotary

metrics:
  # - fvd
  # - fid
  # - lpips

_name: df_video_worldmemminecraft
configurations/algorithm/pose_prediction.yaml
DELETED
@@ -1,19 +0,0 @@
defaults:
  - df_base

n_frames: ${dataset.n_frames}
frame_skip: ${dataset.frame_skip}
metadata: ${dataset.metadata}

# training hyperparameters
weight_decay: 2e-3
warmup_steps: 10000
optimizer_beta: [0.9, 0.99]


metrics:
  # - fvd
  # - fid
  # - lpips

_name: pose_prediction
configurations/config.yaml
DELETED
@@ -1,16 +0,0 @@
# configuration parsing starts here
defaults:
  - experiment: exp_video # experiment yaml file name in configurations/experiments folder [fixme]
  - dataset: video_minecraft_oasis # dataset yaml file name in configurations/dataset folder [fixme]
  - algorithm: df_video # algorithm yaml file name in configurations/algorithm folder [fixme]
  - cluster: null # optional, cluster yaml file name in configurations/cluster folder. Leave null for local compute

debug: false # global debug flag will be passed into configuration of experiment, dataset and algorithm

wandb:
  entity: xizaoqu # wandb account name / organization name [fixme]
  project: diffusion-forcing # wandb project name; if not provided, defaults to root folder name [fixme]
  mode: online # set wandb logging to online, offline or dryrun

resume: null # wandb run id to resume logging and loading checkpoint from
load: null # wandb run id containing checkpoint or a path to a checkpoint file
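
For reference, Hydra composes this config.yaml with the experiment/dataset/algorithm groups listed above. A minimal sketch of reproducing that composition from Python, assuming the configurations folder is still on disk; the override values are illustrative, and the removed main.py was presumably the entry point that accepted the same overrides on the command line:

from hydra import compose, initialize

# Compose configurations/config.yaml and override two config groups, mirroring a
# command-line override such as `dataset=video_minecraft algorithm=df_video_worldmemminecraft`.
with initialize(config_path="configurations", version_base=None):
    cfg = compose(
        config_name="config",
        overrides=["dataset=video_minecraft", "algorithm=df_video_worldmemminecraft"],
    )
print(cfg.dataset._name, cfg.algorithm._name)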
configurations/dataset/base_dataset.yaml
DELETED
@@ -1,3 +0,0 @@
# This will be passed as the cfg to Dataset.__init__(cfg) of your dataset class

debug: ${debug} # inherited from configurations/config.yaml
configurations/dataset/base_video.yaml
DELETED
@@ -1,14 +0,0 @@
defaults:
  - base_dataset

metadata: "data/${dataset.name}/metadata.json"
data_mean: "data/${dataset.name}/data_mean.npy"
data_std: "data/${dataset.name}/data_std.npy"
save_dir: ???
n_frames: 32
context_length: 4
resolution: 128
observation_shape: [3, "${dataset.resolution}", "${dataset.resolution}"]
external_cond_dim: 0
validation_multiplier: 1
frame_skip: 1
configurations/dataset/video_minecraft.yaml
DELETED
@@ -1,14 +0,0 @@
defaults:
  - base_video

save_dir: data/minecraft_simple_backforward
n_frames: 16 # TODO: increase later
resolution: 128
data_mean: 0.5
data_std: 0.5
action_cond_dim: 25
context_length: 1
frame_skip: 1
validation_multiplier: 1

_name: video_minecraft_oasis
configurations/dataset/video_minecraft_pose.yaml
DELETED
@@ -1,14 +0,0 @@
defaults:
  - base_video

save_dir: data/minecraft_simple_backforward
n_frames: 16 # TODO: increase later
resolution: 128
data_mean: 0.5
data_std: 0.5
external_cond_dim: 25
context_length: 1
frame_skip: 1
validation_multiplier: 1

_name: video_minecraft_pose
configurations/experiment/base_experiment.yaml
DELETED
@@ -1,2 +0,0 @@
debug: ${debug} # inherited from configurations/config.yaml
tasks: [main] # tasks to run sequantially, such as [training, test], useful when your project has multiple stages and you want to run only a subset of them.
configurations/experiment/base_pytorch.yaml
DELETED
@@ -1,50 +0,0 @@
# inherites from base_experiment.yaml
# most of the options have docs at https://lightning.ai/docs/pytorch/stable/common/trainer.html

defaults:
  - base_experiment

tasks: [training] # tasks to run sequantially, change when your project has multiple stages and you want to run only a subset of them.
num_nodes: 1 # number of gpu servers used in large scale distributed training

training:
  precision: 16-mixed # set float precision, 16-mixed is faster while 32 is more stable
  compile: False # whether to compile the model with torch.compile
  lr: 0.001 # learning rate
  batch_size: 16 # training batch size; effective batch size is this number * gpu * nodes iff using distributed training
  max_epochs: 1000 # set to -1 to train forever
  max_steps: -1 # set to -1 to train forever, will override max_epochs
  max_time: null # set to something like "00:12:00:00" to enable
  data:
    num_workers: 4 # number of CPU threads for data preprocessing.
    shuffle: True # whether training data will be shuffled
  optim:
    accumulate_grad_batches: 1 # accumulate gradients for n batches before backprop
    gradient_clip_val: 0 # clip gradients with norm above this value, set to 0 to disable
  checkpointing:
    # these are arguments to pytorch lightning's callback, `ModelCheckpoint` class
    every_n_train_steps: 5000 # save a checkpoint every n train steps
    every_n_epochs: null # mutually exclusive with ``every_n_train_steps`` and ``train_time_interval``
    train_time_interval: null # in format of "00:12:00:00", mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``.
    enable_version_counter: False # If this is ``False``, later checkpoint will be overwrite previous ones.

validation:
  precision: 16-mixed
  compile: False # whether to compile the model with torch.compile
  batch_size: 16 # validation batch size per GPU; effective batch size is this number * gpu * nodes iff using distributed training
  val_every_n_step: 2000 # controls how frequent do we run validation, can be float (fraction of epoches) or int (steps) or null (if val_every_n_epoch is set)
  val_every_n_epoch: null # if you want to do validation every n epoches, requires val_every_n_step to be null.
  limit_batch: null # if null, run through validation set. Otherwise limit the number of batches to use for validation.
  inference_mode: True # whether to run validation in inference mode (enable_grad won't work!)
  data:
    num_workers: 4 # number of CPU threads for data preprocessing, for validation.
    shuffle: False # whether validation data will be shuffled

test:
  precision: 16-mixed
  compile: False # whether to compile the model with torch.compile
  batch_size: 4 # test batch size per GPU; effective batch size is this number * gpu * nodes iff using distributed training
  limit_batch: null # if null, run through test set. Otherwise limit the number of batches to use for test.
  data:
    num_workers: 4 # number of CPU threads for data preprocessing, for test.
    shuffle: False # whether test data will be shuffled
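
The checkpointing block above documents itself as arguments to PyTorch Lightning's ModelCheckpoint callback. A minimal sketch of how those fields would be forwarded to the callback; the save directory is illustrative, the exact wiring lived in the removed experiments/exp_base.py, and enable_version_counter requires a recent Lightning version:

from pytorch_lightning.callbacks import ModelCheckpoint

# Map the `checkpointing` options from base_pytorch.yaml onto the Lightning callback.
checkpoint_callback = ModelCheckpoint(
    dirpath="outputs/checkpoints",   # illustrative path, not from the config
    every_n_train_steps=5000,        # checkpointing.every_n_train_steps
    every_n_epochs=None,             # checkpointing.every_n_epochs
    train_time_interval=None,        # checkpointing.train_time_interval
    enable_version_counter=False,    # overwrite checkpoints instead of versioning them
)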
configurations/experiment/exp_pose.yaml
DELETED
@@ -1,31 +0,0 @@
defaults:
  - base_pytorch

tasks: [training]

training:
  lr: 8e-5
  precision: 16-mixed
  batch_size: 4
  max_epochs: -1
  max_steps: 2000005
  checkpointing:
    every_n_train_steps: 2500
  optim:
    gradient_clip_val: 1.0

validation:
  val_every_n_step: 300
  val_every_n_epoch: null
  batch_size: 4
  limit_batch: 1

test:
  limit_batch: 1
  batch_size: 1

logging:
  metrics:
    # - fvd
    # - fid
    # - lpips
configurations/experiment/exp_video.yaml
DELETED
@@ -1,31 +0,0 @@
defaults:
  - base_pytorch

tasks: [training]

training:
  lr: 8e-5
  precision: 16-mixed
  batch_size: 4
  max_epochs: -1
  max_steps: 2000005
  checkpointing:
    every_n_train_steps: 2500
  optim:
    gradient_clip_val: 1.0

validation:
  val_every_n_step: 300
  val_every_n_epoch: null
  batch_size: 4
  limit_batch: 1

test:
  limit_batch: 1
  batch_size: 1

logging:
  metrics:
    # - fvd
    # - fid
    # - lpips
datasets/README.md
DELETED
@@ -1,11 +0,0 @@
The `datasets` folder is used to contain dataset code or environment code.
Don't store actual data like images here! For those, please use the `data` folder instead of `datasets`.

Create a folder to create your own pytorch dataset definition. Then, update the `__init__.py`
at every level to register all datasets.

Each dataset class takes in a DictConfig file `cfg` in its `__init__`, which allows you to pass in arguments via configuration file in `configurations/dataset` or [command line override](https://hydra.cc/docs/tutorials/basic/your_first_app/simple_cli/).

---

This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research template [repo](https://github.com/buoyancy99/research-template). By its MIT license, you must keep the above sentence in `README.md` and the `LICENSE` file to credit the author.
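
A minimal sketch of the pattern this README describes — a dataset class that receives its options as a Hydra DictConfig and is then registered in the package `__init__.py`; all names here are illustrative, not from this repo:

import torch
from omegaconf import DictConfig

class MyVideoDataset(torch.utils.data.Dataset):
    def __init__(self, cfg: DictConfig, split: str = "training"):
        # every option comes from configurations/dataset/*.yaml or a command-line override
        self.resolution = cfg.resolution
        self.n_frames = cfg.n_frames
        self.samples = []

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

# datasets/__init__.py would then expose it, e.g.:
# from .my_video import MyVideoDataset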
datasets/__init__.py
DELETED
@@ -1 +0,0 @@
from .video import MinecraftVideoDataset
datasets/video/__init__.py
DELETED
@@ -1,2 +0,0 @@
from .minecraft_video_dataset import MinecraftVideoDataset
from .minecraft_video_dataset_pose import MinecraftVideoPoseDataset
datasets/video/base_video_dataset.py
DELETED
@@ -1,158 +0,0 @@
from typing import Sequence
import torch
import random
import os
import numpy as np
import cv2
from omegaconf import DictConfig
from torchvision import transforms
from pathlib import Path
from abc import abstractmethod, ABC
import json


class BaseVideoDataset(torch.utils.data.Dataset, ABC):
    """
    Base class for video datasets. Videos may be of variable length.

    Folder structure of each dataset:
    - [save_dir] (specified in config, e.g., data/phys101)
        - /[split] (one per split)
            - /data_folder_name (e.g., videos)
        metadata.json
    """

    def __init__(self, cfg: DictConfig, split: str = "training"):
        super().__init__()
        self.cfg = cfg
        self.split = split
        self.resolution = cfg.resolution
        self.external_cond_dim = cfg.external_cond_dim
        self.n_frames = (
            cfg.n_frames * cfg.frame_skip
            if split == "training"
            else cfg.n_frames * cfg.frame_skip * cfg.validation_multiplier
        )
        self.frame_skip = cfg.frame_skip
        self.save_dir = Path(cfg.save_dir)
        self.save_dir.mkdir(exist_ok=True, parents=True)
        self.split_dir = self.save_dir / f"{split}"

        self.metadata_path = self.save_dir / "metadata.json"

        self.data_paths = self.get_data_paths(self.split)

        if self.split == 'training':
            self.metadata = [1200] * len(self.data_paths) # total 1500 f
        else:
            self.metadata = [1] * len(self.data_paths) # total 1500 f
        # self.clips_per_video = np.clip(np.array(self.metadata[split]) - self.n_frames + 1, a_min=1, a_max=None).astype(
        #     np.int32
        # )
        self.clips_per_video = np.clip(np.array(self.metadata) - self.n_frames + 1, a_min=1, a_max=None).astype(
            np.int32
        )
        self.cum_clips_per_video = np.cumsum(self.clips_per_video)
        self.transform = transforms.Resize((self.resolution, self.resolution), antialias=True)

        # shuffle but keep the same order for each epoch, so validation sample is diverse yet deterministic
        random.seed(0)
        self.idx_remap = list(range(self.__len__()))
        random.shuffle(self.idx_remap)

    @abstractmethod
    def download_dataset(self) -> Sequence[int]:
        """
        Download dataset from the internet and build it in save_dir

        Returns a list of video lengths
        """
        raise NotImplementedError

    @abstractmethod
    def get_data_paths(self, split):
        """Return a list of data paths (e.g. xxx.mp4) for a given split"""
        raise NotImplementedError

    def get_data_lengths(self, split):
        """Return a list of num_frames for each data path (e.g. xxx.mp4) for a given split"""
        lengths = []
        for path in self.get_data_paths(split):
            length = cv2.VideoCapture(str(path)).get(cv2.CAP_PROP_FRAME_COUNT)
            lengths.append(length)
        return lengths

    def split_idx(self, idx):
        video_idx = np.argmax(self.cum_clips_per_video > idx)
        frame_idx = idx - np.pad(self.cum_clips_per_video, (1, 0))[video_idx]
        return video_idx, frame_idx

    @staticmethod
    def load_video(path: Path):
        """
        Load video from a path
        :param filename: path to the video
        :return: video as a numpy array
        """

        cap = cv2.VideoCapture(str(path))

        frames = []
        while cap.isOpened():
            ret, frame = cap.read()
            if ret:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(frame)
            else:
                break

        cap.release()
        frames = np.stack(frames, dtype=np.uint8)
        return np.transpose(frames, (0, 3, 1, 2)) # (T, C, H, W)

    @staticmethod
    def load_image(filename: Path):
        """
        Load image from a path
        :param filename: path to the image
        :return: image as a numpy array
        """
        image = cv2.imread(str(filename))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return np.transpose(image, (2, 0, 1))

    def __len__(self):
        return self.clips_per_video.sum()

    def __getitem__(self, idx):
        idx = self.idx_remap[idx]
        video_idx, frame_idx = self.split_idx(idx)
        video_path = self.data_paths[video_idx]
        video = self.load_video(video_path)[frame_idx : frame_idx + self.n_frames]

        pad_len = self.n_frames - len(video)

        nonterminal = np.ones(self.n_frames)
        if len(video) < self.n_frames:
            video = np.pad(video, ((0, pad_len), (0, 0), (0, 0), (0, 0)))
            nonterminal[-pad_len:] = 0

        video = torch.from_numpy(video / 256.0).float()
        video = self.transform(video)

        if self.external_cond_dim:
            external_cond = np.load(
                # pylint: disable=no-member
                self.condition_dir
                / f"{video_path.name.replace('.mp4', '.npy')}"
            )
            if len(external_cond) < self.n_frames:
                external_cond = np.pad(external_cond, ((0, pad_len),))
            external_cond = torch.from_numpy(external_cond).float()
            return (
                video[:: self.frame_skip],
                external_cond[:: self.frame_skip],
                nonterminal[:: self.frame_skip],
            )
        else:
            return video[:: self.frame_skip], nonterminal[:: self.frame_skip]
datasets/video/minecraft_video_dataset.py
DELETED
@@ -1,262 +0,0 @@
import os
import io
import tarfile
import numpy as np
import torch
from typing import Sequence, Mapping
from omegaconf import DictConfig
from pytorchvideo.data.encoded_video import EncodedVideo
import random

from .base_video_dataset import BaseVideoDataset


ACTION_KEYS = [
    "inventory",
    "ESC",
    "hotbar.1",
    "hotbar.2",
    "hotbar.3",
    "hotbar.4",
    "hotbar.5",
    "hotbar.6",
    "hotbar.7",
    "hotbar.8",
    "hotbar.9",
    "forward",
    "back",
    "left",
    "right",
    "cameraY",
    "cameraX",
    "jump",
    "sneak",
    "sprint",
    "swapHands",
    "attack",
    "use",
    "pickItem",
    "drop",
]

def convert_action_space(actions):
    vec_25 = torch.zeros(len(actions), len(ACTION_KEYS))
    vec_25[actions[:,0]==1, 11] = 1
    vec_25[actions[:,0]==2, 12] = 1
    vec_25[actions[:,4]==11, 16] = -1
    vec_25[actions[:,4]==13, 16] = 1
    vec_25[actions[:,3]==11, 15] = -1
    vec_25[actions[:,3]==13, 15] = 1
    vec_25[actions[:,5]==6, 24] = 1
    vec_25[actions[:,5]==1, 24] = 1
    vec_25[actions[:,1]==1, 13] = 1
    vec_25[actions[:,1]==2, 14] = 1
    vec_25[actions[:,7]==1, 2] = 1
    return vec_25

# Dataset class
class MinecraftVideoDataset(BaseVideoDataset):
    """
    Minecraft video dataset for training and validation.

    Args:
        cfg (DictConfig): Configuration object.
        split (str): Dataset split ("training" or "validation").
    """
    def __init__(self, cfg: DictConfig, split: str = "training"):
        if split == "test":
            split = "validation"
        super().__init__(cfg, split)
        self.n_frames = cfg.n_frames_valid if split == "validation" and hasattr(cfg, "n_frames_valid") else cfg.n_frames
        self.use_plucker = cfg.use_plucker
        self.condition_similar_length = cfg.condition_similar_length
        self.customized_validation = cfg.customized_validation
        self.angle_range = cfg.angle_range
        self.pos_range = cfg.pos_range
        self.add_frame_timestep_embedder = cfg.add_frame_timestep_embedder
        self.training_dropout = 0.1
        self.sample_more_place = getattr(cfg, "sample_more_place", False)
        self.within_context = getattr(cfg, "within_context", False)
        self.sample_more_event = getattr(cfg, "sample_more_event", False)
        self.causal_frame = getattr(cfg, "causal_frame", False)

    def get_data_paths(self, split: str):
        """
        Retrieve all video file paths for the given split.

        Args:
            split (str): Dataset split ("training" or "validation").

        Returns:
            List[Path]: List of video file paths.
        """
        data_dir = self.save_dir / split
        paths = sorted(list(data_dir.glob("**/*.mp4")), key=lambda x: x.name)
        if not paths:
            sub_dirs = os.listdir(data_dir)
            for sub_dir in sub_dirs:
                sub_path = data_dir / sub_dir
                paths += sorted(list(sub_path.glob("**/*.mp4")), key=lambda x: x.name)
        return paths

    def download_dataset(self):
        pass

    def __getitem__(self, idx: int):
        """
        Retrieve a single data sample by index.

        Args:
            idx (int): Index of the data sample.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, np.ndarray, np.ndarray]: Video, actions, poses, and timesteps.
        """
        max_retries = 1000
        for _ in range(max_retries):
            try:
                return self.load_data(idx)
            except Exception as e:
                print(f"Retrying due to error: {e}")
                idx = (idx + 1) % len(self)

    def load_data(self, idx):
        idx = self.idx_remap[idx]
        file_idx, frame_idx = self.split_idx(idx)
        action_path = self.data_paths[file_idx]
        video_path = self.data_paths[file_idx]

        action_path = video_path.with_suffix(".npz")
        actions_pool = np.load(action_path)['actions']
        poses_pool = np.load(action_path)['poses']

        poses_pool[0,1] = poses_pool[1,1] # wrong first in place

        assert poses_pool[:,1].max() - poses_pool[:,1].min() < 2, f"wrong~~~~{poses_pool[:,1].max() - poses_pool[:,1].min()}-{video_path}"

        if len(poses_pool) < len(actions_pool):
            poses_pool = np.pad(poses_pool, ((1, 0), (0, 0)))

        actions_pool = convert_action_space(actions_pool)
        video_raw = EncodedVideo.from_path(video_path, decode_audio=False)

        frame_idx = frame_idx + 100 # avoid first frames # first frame is useless

        if self.split == "validation":
            frame_idx = 240

        if self.sample_more_place and self.split == "training":
            if random.uniform(0, 1) > 0.5:
                place_mask = (actions_pool[:,24]==1)
                place_mask[:100] = 0
                valid_indices = np.where(place_mask)[0]
                random_index = np.random.choice(valid_indices)
                frame_idx = random_index - random.randint(1, self.n_frames-1)

        total_frame = video_raw.duration.numerator
        fps = 10 # video_raw.duration.denominator
        total_frame = total_frame * fps / video_raw.duration.denominator
        video = video_raw.get_clip(start_sec=frame_idx/fps, end_sec=(frame_idx+self.n_frames)/fps)["video"]
        video = video.permute(1, 2, 3, 0).numpy()

        if self.split != "validation" and 'degrees' in np.load(action_path).keys():
            degrees = np.load(action_path)['degrees']
            actions_pool[:,16] *= degrees

        actions = np.copy(actions_pool[frame_idx : frame_idx + self.n_frames])

        poses = np.copy(poses_pool[frame_idx : frame_idx + self.n_frames])
        pad_len = self.n_frames - len(video)
        poses_pool[:,:3] -= poses[:1,:3]
        poses_pool[:,-1] = -poses_pool[:,-1]
        poses_pool[:,3:] %= 360

        poses[:,:3] -= poses[:1,:3] # do not normalize angle
        poses[:,-1] = -poses[:,-1]
        poses[:,3:] %= 360

        assert len(video) >= self.n_frames, f"{video_path}"

        if self.split == "training" and self.condition_similar_length>0:
            if random.uniform(0, 1) > self.training_dropout:
                refer_frame_dis = poses[:,None] - poses_pool[None,:]
                refer_frame_dis = np.abs(refer_frame_dis)
                refer_frame_dis[...,3:][refer_frame_dis[...,3:] > 180] = 360 - refer_frame_dis[...,3:][refer_frame_dis[...,3:] > 180]
                valid_index = ((((refer_frame_dis[..., :3] <= self.pos_range).sum(-1))>=3) & (((refer_frame_dis[..., 3:] <= self.angle_range).sum(-1))>=2) & \
                    ((((refer_frame_dis[..., :3] > 0).sum(-1))>=1) | (((refer_frame_dis[..., 3:] > 0).sum(-1))>=1))
                ).sum(0)
                valid_index[:100] = 0 # mute bad initial scene

                if self.add_frame_timestep_embedder and self.causal_frame and (actions_pool[:frame_idx,24]==1).sum() > 0:
                    valid_index[frame_idx:] = 0

                mask = valid_index >= 1
                mask[0] = False
                candidate_indices = np.argwhere(mask)

                mask2 = valid_index >= 0
                mask2[0] = False

                count = min(self.condition_similar_length, candidate_indices.shape[0])
                selected_indices = candidate_indices[np.random.choice(candidate_indices.shape[0], count, replace=True)][:,0]

                if count < self.condition_similar_length:
                    candidate_indices2 = np.argwhere(mask2)
                    selected_indices2 = candidate_indices2[np.random.choice(candidate_indices2.shape[0], self.condition_similar_length-count, replace=True)][:,0]
                    selected_indices = np.concatenate([selected_indices, selected_indices2])

                if self.sample_more_event:
                    if random.uniform(0, 1) > 0.3:
                        valid_idx = torch.nonzero(actions_pool[:frame_idx,24]==1)[:,0]
                        if len(valid_idx) > self.condition_similar_length //2:
                            valid_idx = valid_idx[-self.condition_similar_length //2:]

                        if len(valid_idx) > 0:
                            selected_indices[-len(valid_idx):] = valid_idx + 4

            else:
                selected_indices = np.array(list(range(self.condition_similar_length))) * 0 + random.randint(0, frame_idx)

            video_pool = []
            for si in selected_indices:
                video_pool.append(video_raw.get_clip(start_sec=si/fps, end_sec=(si+1)/fps)["video"][:,0].permute(1,2,0))

            video_pool = np.stack(video_pool)
            video = np.concatenate([video, video_pool])
            actions = np.concatenate([actions, actions_pool[selected_indices]])
            poses = np.concatenate([poses, poses_pool[selected_indices]])

            timestep = np.concatenate([np.array(list(range(frame_idx, frame_idx + self.n_frames))), selected_indices])

        else:
            timestep = np.array(list(range(self.n_frames)))

        video = torch.from_numpy(video / 255.0).float().permute(0, 3, 1, 2).contiguous()

        if self.split == "validation" and not self.customized_validation:
            num_frame = actions.shape[0]

            actions[:] = 0
            actions[:,16] = 1
            poses[:] = 0
            for ff in range(1, num_frame):
                poses[ff,4] = poses[ff-1,4] + actions[ff,16] * -15

            if self.within_context:
                actions[:] = 0
                actions[:self.n_frames//2+1,16] = 1
                actions[self.n_frames//2+1:,16] = -1
                poses[:] = 0
                for ff in range(1, num_frame):
                    poses[ff,4] = poses[ff-1,4] + actions[ff,16] * -15

        return (
            video[:: self.frame_skip],
            actions[:: self.frame_skip],
            poses[:: self.frame_skip],
            timestep
        )
datasets/video/minecraft_video_dataset_oasis_filter.py
DELETED
@@ -1,99 +0,0 @@
import torch
from typing import Sequence
import numpy as np
import io
from omegaconf import DictConfig
from tqdm import tqdm

from typing import Mapping, Sequence
import os
import math
from packaging import version as pver
from PIL import Image
import random
import shutil
import os
from pathlib import Path
import traceback

class OASISMinecraftVideoFilterDataset(torch.utils.data.Dataset):
    """
    Minecraft dataset
    """

    def __init__(self, source_dir, target_dir, split):
        self.source_dir = Path(source_dir)
        self.split_dir = self.source_dir / f"{split}"
        self.data_paths = self.get_data_paths(split)
        self.target_dir = Path(target_dir) / f"{split}"
        self.target_dir.mkdir(exist_ok=True, parents=True)

    def get_data_paths(self, split):
        data_dir = self.source_dir / split
        paths = sorted(list(data_dir.glob("**/*.mp4")), key=lambda x: x.name)

        if len(paths) == 0:
            sub_path = os.listdir(data_dir)
            for sp in sub_path:
                data_dir = self.source_dir / split / sp
                paths = paths+sorted(list(data_dir.glob("**/*.mp4")), key=lambda x: x.name)
        return paths

    def __len__(self):
        return len(self.data_paths)

    def __getitem__(self, idx):

        return self.sub_get(idx)
        # try:
        #     return self.sub_get(idx)
        # except Exception as e:
        #     traceback.print_exc()
        #     # return self.sub_get(0)


    def sub_get(self, idx):
        action_path = self.data_paths[idx]
        video_path = self.data_paths[idx]

        action_path = video_path.with_suffix(".npz")
        actions_pool = np.load(action_path)['actions']
        poses_pool = np.load(action_path)['poses']

        poses_pool[0,1] = poses_pool[1,1] # wrong first in place

        print(poses_pool.shape)

        if poses_pool[:,1].max() - poses_pool[:,1].min() < 2:
            target_action_path = self.target_dir / action_path.parent.name / action_path.name
            target_video_path = self.target_dir / video_path.parent.name / video_path.name
            target_action_path.parent.mkdir(exist_ok=True, parents=True)
            target_video_path.parent.mkdir(exist_ok=True, parents=True)

            try:
                shutil.copy2(action_path, target_action_path)
                shutil.copy2(video_path, target_video_path)
            except:
                import pdb;pdb.set_trace()

        return poses_pool[:10]



if __name__ == "__main__":
    import torch
    from unittest.mock import MagicMock
    import tqdm

    cfg = MagicMock()
    cfg.resolution = 64
    cfg.external_cond_dim = 0
    cfg.n_frames = 64
    cfg.save_dir = "data/minecraft"
    cfg.validation_multiplier = 1

    dataset = MinecraftVideoDataset(cfg, "training")
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=16)

    for batch in tqdm.tqdm(dataloader):
        pass
datasets/video/minecraft_video_dataset_pose.py
DELETED
@@ -1,421 +0,0 @@
import torch
from typing import Sequence
import numpy as np
import io
import tarfile
from pytorchvideo.data.encoded_video import EncodedVideo
from omegaconf import DictConfig
from tqdm import tqdm

from .base_video_dataset import BaseVideoDataset
from typing import Mapping, Sequence
import os
import math
from packaging import version as pver
from PIL import Image
import random

def euler_to_rotation_matrix(pitch, yaw):
    """
    Convert euler angles (pitch, yaw) to a 3x3 rotation matrix.
    pitch: rotation around x-axis (in radians)
    yaw: rotation around y-axis (in radians)
    """
    # Rotation matrix around x-axis (pitch)
    R_x = np.array([
        [1, 0, 0],
        [0, math.cos(pitch), -math.sin(pitch)],
        [0, math.sin(pitch), math.cos(pitch)]
    ])

    # Rotation matrix around y-axis (yaw)
    R_y = np.array([
        [math.cos(yaw), 0, math.sin(yaw)],
        [0, 1, 0],
        [-math.sin(yaw), 0, math.cos(yaw)]
    ])

    # Combined rotation matrix
    R = np.dot(R_y, R_x)
    return R

def custom_meshgrid(*args):
    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
    if pver.parse(torch.__version__) < pver.parse('1.10'):
        return torch.meshgrid(*args)
    else:
        return torch.meshgrid(*args, indexing='ij')

def camera_to_world_to_world_to_camera(camera_to_world):
    """
    Convert Camera-to-World matrix to World-to-Camera matrix by inverting the transformation.
    """
    # Extract rotation (R) and translation (T)
    R = camera_to_world[:3, :3]
    T = camera_to_world[:3, 3]

    # Calculate World-to-Camera (inverse) matrix
    world_to_camera = np.eye(4)

    # The rotation part of World-to-Camera is the transpose of Camera-to-World's rotation
    world_to_camera[:3, :3] = R.T

    # The translation part is the negative of the rotated translation
    world_to_camera[:3, 3] = -np.dot(R.T, T)

    return world_to_camera

def euler_to_camera_to_world_matrix(pose):

    x, y, z, pitch, yaw = pose
    # Convert pitch and yaw to radians
    pitch = math.radians(pitch)
    yaw = math.radians(yaw)

    # Get the rotation matrix from Euler angles
    R = euler_to_rotation_matrix(pitch, yaw)

    # Create the 4x4 transformation matrix (rotation + translation)
    camera_to_world = np.eye(4)

    # Set the rotation part (upper 3x3)
    camera_to_world[:3, :3] = R

    # Set the translation part (last column)
    camera_to_world[:3, 3] = [x, y, z]

    return camera_to_world

def tensor_to_gif(tensor, output_path, fps=10):
    """
    Converts a PyTorch tensor of shape (F, 3, H, W) to a GIF.

    Args:
        tensor (torch.Tensor): Input tensor of shape (F, 3, H, W) with values in range [0, 1] or [0, 255].
        output_path (str): Path to save the output GIF.
        fps (int): Frames per second for the GIF.
    """
    # Ensure the tensor is in [0, 255] range
    if tensor.max() <= 1.0:
        tensor = (tensor * 255).byte()
    else:
        tensor = tensor.byte()

    # Convert tensor to numpy array and rearrange to (F, H, W, 3)
    frames = tensor.permute(0, 2, 3, 1).cpu().numpy()

    # Convert frames to PIL Images
    pil_frames = [Image.fromarray(frame) for frame in frames]

    # Save as GIF
    pil_frames[0].save(
        output_path,
        save_all=True,
        append_images=pil_frames[1:],
        duration=int(1000 / fps),
        loop=0
    )

def get_relative_pose(cam_params, zero_first_frame_scale):
    abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
    abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
    source_cam_c2w = abs_c2ws[0]
    if zero_first_frame_scale:
        cam_to_origin = 0
    else:
        cam_to_origin = np.linalg.norm(source_cam_c2w[:3, 3])
    target_cam_c2w = np.array([
        [1, 0, 0, 0],
        [0, 1, 0, -cam_to_origin],
        [0, 0, 1, 0],
        [0, 0, 0, 1]
    ])
    abs2rel = target_cam_c2w @ abs_w2cs[0]
    ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
    ret_poses = np.array(ret_poses, dtype=np.float32)
    return ret_poses

def ray_condition(K, c2w, H, W, device):
    # c2w: B, V, 4, 4
    # K: B, V, 4

    B = K.shape[0]

    j, i = custom_meshgrid(
        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
    )
    i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
    j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]

    fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1

    zs = torch.ones_like(i) # [B, HxW]
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    zs = zs.expand_as(ys)

    directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
    directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3

    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
    rays_o = c2w[..., :3, 3] # B, V, 3
    rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
    # c2w @ dirctions
    rays_dxo = torch.linalg.cross(rays_o, rays_d)
    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6

    return plucker

class Camera(object):
    def __init__(self, entry, focal_length=0.35):
        self.fx = focal_length # 0.35 correspond to 110 fov
        self.fy = focal_length*640/360
        self.cx = 0.5
        self.cy = 0.5
        self.c2w_mat = euler_to_camera_to_world_matrix(entry)
        self.w2c_mat = camera_to_world_to_world_to_camera(np.copy(self.c2w_mat))


ACTION_KEYS = [
    "inventory",
    "ESC",
    "hotbar.1",
    "hotbar.2",
    "hotbar.3",
    "hotbar.4",
    "hotbar.5",
    "hotbar.6",
    "hotbar.7",
    "hotbar.8",
    "hotbar.9",
    "forward",
    "back",
    "left",
    "right",
    "cameraY",
    "cameraX",
    "jump",
    "sneak",
    "sprint",
    "swapHands",
    "attack",
    "use",
    "pickItem",
    "drop",
]

def one_hot_actions(actions: Sequence[Mapping[str, int]]) -> torch.Tensor:
    actions_one_hot = torch.zeros(len(actions), len(ACTION_KEYS))
    for i, current_actions in enumerate(actions):
        for j, action_key in enumerate(ACTION_KEYS):
            if action_key.startswith("camera"):
                if action_key == "cameraX":
                    value = current_actions["camera"][0]
                elif action_key == "cameraY":
                    value = current_actions["camera"][1]
                else:
                    raise ValueError(f"Unknown camera action key: {action_key}")
                max_val = 20
                bin_size = 0.5
                num_buckets = int(max_val / bin_size)
                value = (value - num_buckets) / num_buckets
                assert -1 - 1e-3 <= value <= 1 + 1e-3, f"Camera action value must be in [-1, 1], got {value}"
            else:
                value = current_actions[action_key]
                assert 0 <= value <= 1, f"Action value must be in [0, 1] got {value}"
            actions_one_hot[i, j] = value

    return actions_one_hot

def simpletomulti(actions):
    vec_25 = torch.zeros(len(actions), len(ACTION_KEYS))
    vec_25[actions==1, 11] = 1
    vec_25[actions==2, 16] = -1
    vec_25[actions==3, 16] = 1
    vec_25[actions==4, 15] = -1
    vec_25[actions==5, 15] = 1
    return vec_25

def simpletomulti2(actions):
    vec_25 = torch.zeros(len(actions), len(ACTION_KEYS))
    vec_25[actions[:,0]==1, 11] = 1
    vec_25[actions[:,0]==2, 12] = 1
    vec_25[actions[:,4]==11, 16] = -1
    vec_25[actions[:,4]==13, 16] = 1
    vec_25[actions[:,3]==11, 15] = -1
    vec_25[actions[:,3]==13, 15] = 1
    vec_25[actions[:,5]==6, 24] = 1
    vec_25[actions[:,5]==1, 24] = 1
    vec_25[actions[:,1]==1, 13] = 1
    vec_25[actions[:,1]==2, 14] = 1
    vec_25[actions[:,7]==1, 2] = 1
    return vec_25

class MinecraftVideoPoseDataset(BaseVideoDataset):
    """
    Minecraft dataset
    """

    def __init__(self, cfg: DictConfig, split: str = "training"):
        if split == "test":
            split = "validation"
        super().__init__(cfg, split)

        if hasattr(cfg, "n_frames_valid") and split == "validation":
            self.n_frames = cfg.n_frames_valid

    def get_data_paths(self, split):
        data_dir = self.save_dir / split
        paths = sorted(list(data_dir.glob("**/*.mp4")), key=lambda x: x.name)

        if len(paths) == 0:
            sub_path = os.listdir(data_dir)
            for sp in sub_path:
                data_dir = self.save_dir / split / sp
                paths = paths+sorted(list(data_dir.glob("**/*.mp4")), key=lambda x: x.name)
        return paths

    def get_data_lengths(self, split):
        lengths = [300] * len(self.get_data_paths(split))
        return lengths

    def download_dataset(self) -> Sequence[int]:
        from internetarchive import download

        part_suffixes = [
            "aa",
            "ab",
            "ac",
            "ad",
            "ae",
            "af",
            "ag",
            "ah",
            "ai",
            "aj",
            "ak",
        ]
        for part_suffix in part_suffixes:
            identifier = f"minecraft_marsh_dataset_{part_suffix}"
            file_name = f"minecraft.tar.part{part_suffix}"
            download(identifier, file_name, destdir=self.save_dir, verbose=True)

        combined_bytes = io.BytesIO()
        for part_suffix in part_suffixes:
            identifier = f"minecraft_marsh_dataset_{part_suffix}"
            file_name = f"minecraft.tar.part{part_suffix}"
            part_file = self.save_dir / identifier / file_name
            with open(part_file, "rb") as part:
                combined_bytes.write(part.read())
        combined_bytes.seek(0)
        with tarfile.open(fileobj=combined_bytes, mode="r") as combined_archive:
            combined_archive.extractall(self.save_dir)
        (self.save_dir / "minecraft/test").rename(self.save_dir / "validation")
        (self.save_dir / "minecraft/train").rename(self.save_dir / "training")
        (self.save_dir / "minecraft").rmdir()
        for part_suffix in part_suffixes:
            identifier = f"minecraft_marsh_dataset_{part_suffix}"
            file_name = f"minecraft.tar.part{part_suffix}"
            part_file = self.save_dir / identifier / file_name
            part_file.rmdir()

    def __getitem__(self, idx):
        # return self.load_data(idx)

        max_retries = 1000
        for mr in range(max_retries):
            try:
                return self.load_data(idx)
            except Exception as e:
                print(f"{mr} Error: {e}")
                # idx = self.idx_remap[idx]
                # file_idx, frame_idx = self.split_idx(idx)
                # video_path = self.data_paths[file_idx]
                # os.remove(video_path)
                idx = (idx + 1) % self.__len__()

    def load_data(self, idx):
        idx = self.idx_remap[idx]
        file_idx, frame_idx = self.split_idx(idx)
        action_path = self.data_paths[file_idx]
        video_path = self.data_paths[file_idx]

        action_path = video_path.with_suffix(".npz")
        actions_pool = np.load(action_path)['actions']
        poses_pool = np.load(action_path)['poses']

        poses_pool[0,1] = poses_pool[1,1] # wrong first in place

        assert poses_pool[:,1].max() - poses_pool[:,1].min() < 2, f"wrong~~~~{poses_pool[:,1].max() - poses_pool[:,1].min()}-{video_path}"

        if len(poses_pool) < len(actions_pool):
            poses_pool = np.pad(poses_pool, ((1, 0), (0, 0)))

        actions_pool = simpletomulti2(actions_pool)
        video_raw = EncodedVideo.from_path(video_path, decode_audio=False)

        frame_idx = frame_idx + 100 # avoid first frames # first frame is useless

        if self.split == "validation":
            frame_idx = 240

        total_frame = video_raw.duration.numerator
        fps = 10 # video_raw.duration.denominator
        total_frame = total_frame * fps / video_raw.duration.denominator
        video = video_raw.get_clip(start_sec=frame_idx/fps, end_sec=(frame_idx+self.n_frames)/fps)["video"]

        video = video.permute(1, 2, 3, 0).numpy()

        if self.split != "validation" and 'degrees' in np.load(action_path).keys():
            degrees = np.load(action_path)['degrees']
            actions_pool[:,16] *= degrees

        actions = np.copy(actions_pool[frame_idx : frame_idx + self.n_frames]) # (t, )

        poses = np.copy(poses_pool[frame_idx : frame_idx + self.n_frames])
        pad_len = self.n_frames - len(video)
        poses_pool[:,:3] -= poses[:1,:3]
        # poses_pool[:,3:] = -poses_pool[:,3:]
        poses_pool[:,-1] = -poses_pool[:,-1]
        poses_pool[:,3:] %= 360

        poses[:,:3] -= poses[:1,:3] # do not normalize angle
        # poses[:,3:] = -poses[:,3:]
        poses[:,-1] = -poses[:,-1]
        poses[:,3:] %= 360

        nonterminal = np.ones(self.n_frames)
        if len(video) < self.n_frames:
            video = np.pad(video, ((0, pad_len), (0, 0), (0, 0), (0, 0)))
            actions = np.pad(actions, ((0, pad_len),))
            poses = np.pad(actions, ((0, pad_len),))
            nonterminal[-pad_len:] = 0

        video = torch.from_numpy(video / 255.0).float().permute(0, 3, 1, 2).contiguous()

        return (
            video[:: self.frame_skip],
            actions[:: self.frame_skip],
            poses[:: self.frame_skip]
        )


if __name__ == "__main__":
    import torch
    from unittest.mock import MagicMock
    import tqdm

    cfg = MagicMock()
    cfg.resolution = 64
    cfg.external_cond_dim = 0
    cfg.n_frames = 64
    cfg.save_dir = "data/minecraft"
    cfg.validation_multiplier = 1

    dataset = MinecraftVideoDataset(cfg, "training")
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=16)

    for batch in tqdm.tqdm(dataloader):
        pass
experiments/README.md
DELETED
@@ -1,19 +0,0 @@
# experiments

The `experiments` folder contains the code of experiments. Each file in this folder represents a certain type of
benchmark specific to a project. Such an experiment can be instantiated with a certain dataset and a certain algorithm.

To add your own, create a new `.py` file for your experiment,
inherit from any suitable base class in `experiments/exp_base.py`,
and then register the new experiment in `experiments/__init__.py` (see the sketch below).

You run an experiment with `python -m main [options]` in the root directory of the
project. Do not log any data in this folder; store it under `outputs` in the project root
directory.

This folder is only intended to contain formal experiments. For debug code and unit tests, put them under the `debug` folder.
For scripts that are not meant to be an experiment, please use the `scripts` folder.

---

This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research template [repo](https://github.com/buoyancy99/research-template). By its MIT license, you must keep the above sentence in `README.md` and the `LICENSE` file to credit the author.
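For reference, a minimal sketch of the workflow this README describes is shown below. The file name `experiments/exp_my_task.py`, the class `MyTaskExperiment`, and the registry key `exp_my_task` are hypothetical names used only for illustration; they are not part of this repository or this commit.

# experiments/exp_my_task.py (hypothetical illustration)
from datasets.video import MinecraftVideoDataset
from algorithms.worldmem import WorldMemMinecraft
from .exp_base import BaseLightningExperiment


class MyTaskExperiment(BaseLightningExperiment):
    # keys must match yaml file names under configurations/algorithm and configurations/dataset
    compatible_algorithms = dict(df_video_worldmemminecraft=WorldMemMinecraft)
    compatible_datasets = dict(video_minecraft=MinecraftVideoDataset)


# experiments/__init__.py: register the new class under the name of a matching
# configurations/experiment/exp_my_task.yaml, e.g.
#   exp_registry = dict(exp_video=VideoPredictionExperiment, exp_pose=PoseExperiment,
#                       exp_my_task=MyTaskExperiment)

It could then be launched with `python -m main experiment=exp_my_task [options]` from the project root.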
experiments/__init__.py
DELETED
@@ -1,35 +0,0 @@
from typing import Optional, Union
from omegaconf import DictConfig
import pathlib
from lightning.pytorch.loggers.wandb import WandbLogger

from .exp_base import BaseExperiment
from .exp_video import VideoPredictionExperiment
from .exp_pose import PoseExperiment

# each key has to be a yaml file under '[project_root]/configurations/experiment' without .yaml suffix
exp_registry = dict(
    exp_video=VideoPredictionExperiment,
    exp_pose=PoseExperiment,
)


def build_experiment(
    cfg: DictConfig,
    logger: Optional[WandbLogger] = None,
    ckpt_path: Optional[Union[str, pathlib.Path]] = None,
) -> BaseExperiment:
    """
    Build an experiment instance based on registry
    :param cfg: configuration file
    :param logger: optional logger for the experiment
    :param ckpt_path: optional checkpoint path for saving and loading
    :return:
    """
    if cfg.experiment._name not in exp_registry:
        raise ValueError(
            f"Experiment {cfg.experiment._name} not found in registry {list(exp_registry.keys())}. "
            "Make sure you register it correctly in 'experiments/__init__.py' under the same name as yaml file."
        )

    return exp_registry[cfg.experiment._name](cfg, logger, ckpt_path)
experiments/exp_base.py
DELETED
@@ -1,473 +0,0 @@
"""
This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research
template [repo](https://github.com/buoyancy99/research-template).
By its MIT license, you must keep the above sentence in `README.md`
and the `LICENSE` file to credit the author.
"""

from abc import ABC, abstractmethod
from typing import Optional, Union, Literal, List, Dict
import pathlib
import os

import hydra
import torch
from lightning.pytorch.strategies.ddp import DDPStrategy

import lightning.pytorch as pl
from lightning.pytorch.loggers.wandb import WandbLogger
from lightning.pytorch.utilities.types import TRAIN_DATALOADERS
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_info

from omegaconf import DictConfig

from utils.print_utils import cyan
from utils.distributed_utils import is_rank_zero
from safetensors.torch import load_model
from pathlib import Path
from huggingface_hub import hf_hub_download

torch.set_float32_matmul_precision("high")


def load_custom_checkpoint(algo, optimizer, checkpoint_path):
    if not checkpoint_path:
        rank_zero_info("No checkpoint path provided, skipping checkpoint loading.")
        return None

    if not isinstance(checkpoint_path, Path):
        checkpoint_path = Path(checkpoint_path)

    if "yslan" in str(checkpoint_path):
        hf_ckpt = str(checkpoint_path).split('/')
        repo_id = '/'.join(hf_ckpt[:2])
        file_name = '/'.join(hf_ckpt[2:])
        model_path = hf_hub_download(repo_id=repo_id,
                                     filename=file_name)
        ckpt = torch.load(model_path, map_location=torch.device('cpu'))
        algo.load_state_dict(ckpt['state_dict'], strict=False)

    elif checkpoint_path.suffix == ".pt":
        ckpt = torch.load(checkpoint_path, weights_only=True)
        algo.load_state_dict(ckpt, strict=False)
    elif checkpoint_path.suffix == ".ckpt":
        ckpt = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        algo.load_state_dict(ckpt['state_dict'], strict=False)
    elif checkpoint_path.suffix == ".safetensors":
        load_model(algo, checkpoint_path, strict=False)
    elif os.path.isdir(checkpoint_path):
        ckpt_files = [f for f in os.listdir(checkpoint_path) if f.endswith('.ckpt')]
        if not ckpt_files:
            raise FileNotFoundError("No .ckpt file found in the specified folder!")
        selected_ckpt = max(ckpt_files)
        selected_ckpt_path = os.path.join(checkpoint_path, selected_ckpt)
        print(f"Loading checkpoint file: {selected_ckpt_path}")

        ckpt = torch.load(selected_ckpt_path, map_location=torch.device('cpu'))
        algo.load_state_dict(ckpt['state_dict'], strict=False)

    rank_zero_info("Model weights loaded.")


class BaseExperiment(ABC):
    """
    Abstract class for an experiment. This generalizes the pytorch lightning Trainer & lightning Module to more
    flexible experiments that don't fit in the typical ML loop, e.g. multi-stage reinforcement learning benchmarks.
    """

    # each key has to be a yaml file under '[project_root]/configurations/algorithm' without .yaml suffix
    compatible_algorithms: Dict = NotImplementedError

    def __init__(
        self,
        root_cfg: DictConfig,
        logger: Optional[WandbLogger] = None,
        ckpt_path: Optional[Union[str, pathlib.Path]] = None,
    ) -> None:
        """
        Constructor

        Args:
            cfg: configuration file that contains everything about the experiment
            logger: a pytorch-lightning WandbLogger instance
            ckpt_path: an optional path to saved checkpoint
        """
        super().__init__()
        self.root_cfg = root_cfg
        self.cfg = root_cfg.experiment
        self.debug = root_cfg.debug
        self.logger = logger
        self.ckpt_path = ckpt_path
        self.algo = None
        self.customized_load = self.cfg.customized_load
        self.load_vae = self.cfg.load_vae
        self.load_t_to_r = self.cfg.load_t_to_r
        self.zero_init_gate = self.cfg.zero_init_gate
        self.only_tune_refer = self.cfg.only_tune_refer
        self.diffusion_path = self.cfg.diffusion_path
        self.vae_path = self.cfg.vae_path  # "/mnt/xiaozeqi/.cache/huggingface/hub/models--Etched--oasis-500m/snapshots/4ca7d2d811f4f0c6fd1d5719bf83f14af3446c0c/vit-l-20.safetensors"
        self.pose_predictor_path = self.cfg.pose_predictor_path  # "/mnt/xiaozeqi/diffusionforcing/outputs/2025-03-28/16-45-11/checkpoints/epoch0step595000.ckpt"

    def _build_algo(self):
        """
        Build the lightning module
        :return: a pytorch-lightning module to be launched
        """
        algo_name = self.root_cfg.algorithm._name
        if algo_name not in self.compatible_algorithms:
            raise ValueError(
                f"Algorithm {algo_name} not found in compatible_algorithms for this Experiment class. "
                "Make sure you define compatible_algorithms correctly and make sure that each key has "
                "same name as yaml file under '[project_root]/configurations/algorithm' without .yaml suffix"
            )
        return self.compatible_algorithms[algo_name](self.root_cfg.algorithm)

    def exec_task(self, task: str) -> None:
        """
        Executing a certain task specified by string. Each task should be a stage of experiment.
        In most computer vision / nlp applications, tasks should be just train and test.
        In reinforcement learning, you might have more stages such as collecting dataset etc

        Args:
            task: a string specifying a task implemented for this experiment
        """
        if hasattr(self, task) and callable(getattr(self, task)):
            if is_rank_zero:
                print(cyan("Executing task:"), f"{task} out of {self.cfg.tasks}")
            getattr(self, task)()
        else:
            raise ValueError(
                f"Specified task '{task}' not defined for class {self.__class__.__name__} or is not callable."
            )

    def exec_interactive(self, task: str) -> None:
        """
        Executing a certain task specified by string and returning its result. Each task should be a stage of experiment.
        In most computer vision / nlp applications, tasks should be just train and test.
        In reinforcement learning, you might have more stages such as collecting dataset etc

        Args:
            task: a string specifying a task implemented for this experiment
        """
        if hasattr(self, task) and callable(getattr(self, task)):
            if is_rank_zero:
                print(cyan("Executing task:"), f"{task} out of {self.cfg.tasks}")
            return getattr(self, task)()
        else:
            raise ValueError(
                f"Specified task '{task}' not defined for class {self.__class__.__name__} or is not callable."
            )


class BaseLightningExperiment(BaseExperiment):
    """
    Abstract class for pytorch lightning experiments. Useful for computer vision & nlp where main components are
    simply models, datasets and train loop.
    """

    # each key has to be a yaml file under '[project_root]/configurations/algorithm' without .yaml suffix
    compatible_algorithms: Dict = NotImplementedError

    # each key has to be a yaml file under '[project_root]/configurations/dataset' without .yaml suffix
    compatible_datasets: Dict = NotImplementedError

    def _build_trainer_callbacks(self):
        callbacks = []
        if self.logger:
            callbacks.append(LearningRateMonitor("step", True))

    def _build_training_loader(self) -> Optional[Union[TRAIN_DATALOADERS, pl.LightningDataModule]]:
        train_dataset = self._build_dataset("training")
        shuffle = (
            False if isinstance(train_dataset, torch.utils.data.IterableDataset) else self.cfg.training.data.shuffle
        )
        if train_dataset:
            return torch.utils.data.DataLoader(
                train_dataset,
                batch_size=self.cfg.training.batch_size,
                num_workers=min(os.cpu_count(), self.cfg.training.data.num_workers),
                shuffle=shuffle,
                persistent_workers=True,
            )
        else:
            return None

    def _build_validation_loader(self) -> Optional[Union[TRAIN_DATALOADERS, pl.LightningDataModule]]:
        validation_dataset = self._build_dataset("validation")
        shuffle = (
            False
            if isinstance(validation_dataset, torch.utils.data.IterableDataset)
            else self.cfg.validation.data.shuffle
        )
        if validation_dataset:
            return torch.utils.data.DataLoader(
                validation_dataset,
                batch_size=self.cfg.validation.batch_size,
                num_workers=min(os.cpu_count(), self.cfg.validation.data.num_workers),
                shuffle=shuffle,
                persistent_workers=True,
            )
        else:
            return None

    def _build_test_loader(self) -> Optional[Union[TRAIN_DATALOADERS, pl.LightningDataModule]]:
        test_dataset = self._build_dataset("test")
        shuffle = False if isinstance(test_dataset, torch.utils.data.IterableDataset) else self.cfg.test.data.shuffle
        if test_dataset:
            return torch.utils.data.DataLoader(
                test_dataset,
                batch_size=self.cfg.test.batch_size,
                num_workers=min(os.cpu_count(), self.cfg.test.data.num_workers),
                shuffle=shuffle,
                persistent_workers=True,
            )
        else:
            return None

    def training(self) -> None:
        """
        All training happens here
        """
        if not self.algo:
            self.algo = self._build_algo()
        if self.cfg.training.compile:
            self.algo = torch.compile(self.algo)

        callbacks = []
        if self.logger:
            callbacks.append(LearningRateMonitor("step", True))
        if "checkpointing" in self.cfg.training:
            callbacks.append(
                ModelCheckpoint(
                    pathlib.Path(hydra.core.hydra_config.HydraConfig.get()["runtime"]["output_dir"]) / "checkpoints",
                    **self.cfg.training.checkpointing,
                )
            )

        # TODO do not upload checkpoint to wandb

        # trainer = pl.Trainer(
        #     accelerator="auto",
        #     logger=self.logger if self.logger else False,
        #     devices=torch.cuda.device_count(),
        #     num_nodes=self.cfg.num_nodes,
        #     strategy=DDPStrategy(find_unused_parameters=True) if torch.cuda.device_count() > 1 else "auto",
        #     callbacks=callbacks,
        #     gradient_clip_val=self.cfg.training.optim.gradient_clip_val,
        #     val_check_interval=self.cfg.validation.val_every_n_step,
        #     limit_val_batches=self.cfg.validation.limit_batch,
        #     check_val_every_n_epoch=self.cfg.validation.val_every_n_epoch,
        #     accumulate_grad_batches=self.cfg.training.optim.accumulate_grad_batches,
        #     precision=self.cfg.training.precision,
        #     detect_anomaly=False,  # self.cfg.debug,
        #     num_sanity_val_steps=int(self.cfg.debug),
        #     max_epochs=self.cfg.training.max_epochs,
        #     max_steps=self.cfg.training.max_steps,
        #     max_time=self.cfg.training.max_time,
        # )

        trainer = pl.Trainer(
            accelerator="auto",
            devices="auto",  # select devices automatically
            strategy=DDPStrategy(find_unused_parameters=True) if torch.cuda.device_count() > 1 else "auto",
            logger=self.logger or False,  # simplified form
            callbacks=callbacks,
            gradient_clip_val=self.cfg.training.optim.gradient_clip_val or 0.0,  # ensure a default value
            val_check_interval=self.cfg.validation.val_every_n_step if self.cfg.validation.val_every_n_step else None,
            limit_val_batches=self.cfg.validation.limit_batch,
            check_val_every_n_epoch=self.cfg.validation.val_every_n_epoch if not self.cfg.validation.val_every_n_step else None,
            accumulate_grad_batches=self.cfg.training.optim.accumulate_grad_batches or 1,  # default: no accumulation
            precision=self.cfg.training.precision or 32,  # default: 32-bit precision
            detect_anomaly=False,  # anomaly detection off by default
            num_sanity_val_steps=int(self.cfg.debug) if self.cfg.debug else 0,
            max_epochs=self.cfg.training.max_epochs,
            max_steps=self.cfg.training.max_steps,
            max_time=self.cfg.training.max_time
        )

        if self.customized_load:
            if self.load_vae:
                load_custom_checkpoint(algo=self.algo.diffusion_model.model, optimizer=None, checkpoint_path=self.ckpt_path)
                load_custom_checkpoint(algo=self.algo.vae, optimizer=None, checkpoint_path=self.vae_path)
            else:
                load_custom_checkpoint(algo=self.algo, optimizer=None, checkpoint_path=self.ckpt_path)

            if self.load_t_to_r:
                param_list = []
                for name, para in self.algo.diffusion_model.named_parameters():
                    if 't_' in name and 't_embedder' not in name:
                        print(name)
                        param_list.append(para)

                it = 0
                for name, para in self.algo.diffusion_model.named_parameters():
                    if 'r_' in name:
                        para.requires_grad_(False)
                        try:
                            para.copy_(param_list[it].detach().cpu())
                        except:
                            import pdb; pdb.set_trace()
                        para.requires_grad_(True)
                        it += 1

            if self.zero_init_gate:
                for name, para in self.algo.diffusion_model.named_parameters():
                    if 'r_adaLN_modulation' in name:
                        para.requires_grad_(False)
                        para[2*1024:3*1024] = 0
                        para[5*1024:6*1024] = 0
                        para.requires_grad_(True)

            if self.only_tune_refer:
                for name, para in self.algo.diffusion_model.named_parameters():
                    para.requires_grad_(False)
                    if 'r_' in name or 'pose_embedder' in name or 'pose_cond_mlp' in name or 'lora_' in name:
                        para.requires_grad_(True)

            trainer.fit(
                self.algo,
                train_dataloaders=self._build_training_loader(),
                val_dataloaders=self._build_validation_loader(),
                ckpt_path=None,
            )
        else:

            if self.only_tune_refer:
                for name, para in self.algo.diffusion_model.named_parameters():
                    para.requires_grad_(False)
                    if 'r_' in name or 'pose_embedder' in name or 'pose_cond_mlp' in name or 'lora_' in name:
                        para.requires_grad_(True)

            trainer.fit(
                self.algo,
                train_dataloaders=self._build_training_loader(),
                val_dataloaders=self._build_validation_loader(),
                ckpt_path=self.ckpt_path,
            )

    def validation(self) -> None:
        """
        All validation happens here
        """
        if not self.algo:
            self.algo = self._build_algo()
        if self.cfg.validation.compile:
            self.algo = torch.compile(self.algo)

        callbacks = []

        trainer = pl.Trainer(
            accelerator="auto",
            logger=self.logger,
            devices="auto",
            num_nodes=self.cfg.num_nodes,
            strategy=DDPStrategy(find_unused_parameters=False) if torch.cuda.device_count() > 1 else "auto",
            callbacks=callbacks,
            limit_val_batches=self.cfg.validation.limit_batch,
            precision=self.cfg.validation.precision,
            detect_anomaly=False,  # self.cfg.debug,
            inference_mode=self.cfg.validation.inference_mode,
        )

        if self.customized_load:

            if self.load_vae:
                load_custom_checkpoint(algo=self.algo.diffusion_model.model, optimizer=None, checkpoint_path=self.ckpt_path)
                load_custom_checkpoint(algo=self.algo.vae, optimizer=None, checkpoint_path=self.vae_path)
            else:
                load_custom_checkpoint(algo=self.algo, optimizer=None, checkpoint_path=self.ckpt_path)

            if self.load_t_to_r:
                param_list = []
                for name, para in self.algo.diffusion_model.named_parameters():
                    if 't_' in name and 't_embedder' not in name:
                        print(name)
                        param_list.append(para)

                it = 0
                for name, para in self.algo.diffusion_model.named_parameters():
                    if 'r_' in name:
                        para.requires_grad_(False)
                        try:
                            para.copy_(param_list[it].detach().cpu())
                        except:
                            import pdb; pdb.set_trace()
                        para.requires_grad_(True)
                        it += 1

            if self.zero_init_gate:
                for name, para in self.algo.diffusion_model.named_parameters():
                    if 'r_adaLN_modulation' in name:
                        para.requires_grad_(False)
                        para[2*1024:3*1024] = 0
                        para[5*1024:6*1024] = 0
                        para.requires_grad_(True)

            trainer.validate(
                self.algo,
                dataloaders=self._build_validation_loader(),
                ckpt_path=None,
            )
        else:
            trainer.validate(
                self.algo,
                dataloaders=self._build_validation_loader(),
                ckpt_path=self.ckpt_path,
            )

    def test(self) -> None:
        """
        All testing happens here
        """
        if not self.algo:
            self.algo = self._build_algo()
        if self.cfg.test.compile:
            self.algo = torch.compile(self.algo)

        callbacks = []

        trainer = pl.Trainer(
            accelerator="auto",
            logger=self.logger,
            devices="auto",
            num_nodes=self.cfg.num_nodes,
            strategy=DDPStrategy(find_unused_parameters=False) if torch.cuda.device_count() > 1 else "auto",
            callbacks=callbacks,
            limit_test_batches=self.cfg.test.limit_batch,
            precision=self.cfg.test.precision,
            detect_anomaly=False,  # self.cfg.debug,
        )

        # Only load the checkpoint if only testing. Otherwise, it will have been loaded
        # and further trained during train.
        trainer.test(
            self.algo,
            dataloaders=self._build_test_loader(),
            ckpt_path=self.ckpt_path,
        )
        if not self.algo:
            self.algo = self._build_algo()
        if self.cfg.validation.compile:
            self.algo = torch.compile(self.algo)


    def interactive(self):

        if not self.algo:
            self.algo = self._build_algo()
        if self.cfg.validation.compile:
            self.algo = torch.compile(self.algo)

        if self.customized_load:
            load_custom_checkpoint(algo=self.algo.diffusion_model, optimizer=None, checkpoint_path=self.diffusion_path)
            load_custom_checkpoint(algo=self.algo.vae, optimizer=None, checkpoint_path=self.vae_path)
            load_custom_checkpoint(algo=self.algo.pose_prediction_model, optimizer=None, checkpoint_path=self.pose_predictor_path)
            return self.algo
        else:
            raise NotImplementedError

    def _build_dataset(self, split: str) -> Optional[torch.utils.data.Dataset]:
        if split in ["training", "test", "validation"]:
            return self.compatible_datasets[self.root_cfg.dataset._name](self.root_cfg.dataset, split=split)
        else:
            raise NotImplementedError(f"split '{split}' is not implemented")
experiments/exp_pose.py
DELETED
@@ -1,310 +0,0 @@
"""
This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research
template [repo](https://github.com/buoyancy99/research-template).
By its MIT license, you must keep the above sentence in `README.md`
and the `LICENSE` file to credit the author.
"""

from abc import ABC, abstractmethod
from typing import Optional, Union, Literal, List, Dict
import pathlib
import os

import hydra
import torch
from lightning.pytorch.strategies.ddp import DDPStrategy

import lightning.pytorch as pl
from lightning.pytorch.loggers.wandb import WandbLogger
from lightning.pytorch.utilities.types import TRAIN_DATALOADERS
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_info

from omegaconf import DictConfig

from utils.print_utils import cyan
from utils.distributed_utils import is_rank_zero
from safetensors.torch import load_model
from pathlib import Path
from algorithms.worldmem import PosePrediction
from datasets.video import MinecraftVideoPoseDataset


torch.set_float32_matmul_precision("high")


def load_custom_checkpoint(algo, optimizer, checkpoint_path):
    if not checkpoint_path:
        rank_zero_info("No checkpoint path provided, skipping checkpoint loading.")
        return None

    if not isinstance(checkpoint_path, Path):
        checkpoint_path = Path(checkpoint_path)

    if checkpoint_path.suffix == ".pt":
        ckpt = torch.load(checkpoint_path, weights_only=True)
        algo.load_state_dict(ckpt, strict=False)
    elif checkpoint_path.suffix == ".ckpt":
        ckpt = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        algo.load_state_dict(ckpt['state_dict'], strict=False)
    elif checkpoint_path.suffix == ".safetensors":
        load_model(algo, checkpoint_path, strict=False)
    elif os.path.isdir(checkpoint_path):
        ckpt_files = [f for f in os.listdir(checkpoint_path) if f.endswith('.ckpt')]
        if not ckpt_files:
            raise FileNotFoundError("No .ckpt file found in the specified folder!")
        selected_ckpt = max(ckpt_files)
        selected_ckpt_path = os.path.join(checkpoint_path, selected_ckpt)
        print(f"Loading checkpoint file: {selected_ckpt_path}")

        ckpt = torch.load(selected_ckpt_path, map_location=torch.device('cpu'))
        algo.load_state_dict(ckpt['state_dict'], strict=False)

    rank_zero_info("Model weights loaded.")


class PoseExperiment(ABC):
    """
    Abstract class for an experiment. This generalizes the pytorch lightning Trainer & lightning Module to more
    flexible experiments that don't fit in the typical ML loop, e.g. multi-stage reinforcement learning benchmarks.
    """

    # each key has to be a yaml file under '[project_root]/configurations/algorithm' without .yaml suffix
    compatible_algorithms = dict(
        pose_prediction=PosePrediction
    )

    compatible_datasets = dict(
        video_minecraft_pose=MinecraftVideoPoseDataset
    )

    def __init__(
        self,
        root_cfg: DictConfig,
        logger: Optional[WandbLogger] = None,
        ckpt_path: Optional[Union[str, pathlib.Path]] = None,
    ) -> None:
        """
        Constructor

        Args:
            cfg: configuration file that contains everything about the experiment
            logger: a pytorch-lightning WandbLogger instance
            ckpt_path: an optional path to saved checkpoint
        """
        super().__init__()
        self.root_cfg = root_cfg
        self.cfg = root_cfg.experiment
        self.debug = root_cfg.debug
        self.logger = logger
        self.ckpt_path = ckpt_path
        self.algo = None
        self.vae_path = "/cpfs01/user/xiaozeqi/.cache/huggingface/hub/models--Etched--oasis-500m/snapshots/4ca7d2d811f4f0c6fd1d5719bf83f14af3446c0c/vit-l-20.safetensors"

    def _build_algo(self):
        """
        Build the lightning module
        :return: a pytorch-lightning module to be launched
        """
        algo_name = self.root_cfg.algorithm._name
        if algo_name not in self.compatible_algorithms:
            raise ValueError(
                f"Algorithm {algo_name} not found in compatible_algorithms for this Experiment class. "
                "Make sure you define compatible_algorithms correctly and make sure that each key has "
                "same name as yaml file under '[project_root]/configurations/algorithm' without .yaml suffix"
            )
        return self.compatible_algorithms[algo_name](self.root_cfg.algorithm)

    def exec_task(self, task: str) -> None:
        """
        Executing a certain task specified by string. Each task should be a stage of experiment.
        In most computer vision / nlp applications, tasks should be just train and test.
        In reinforcement learning, you might have more stages such as collecting dataset etc

        Args:
            task: a string specifying a task implemented for this experiment
        """
        if hasattr(self, task) and callable(getattr(self, task)):
            if is_rank_zero:
                print(cyan("Executing task:"), f"{task} out of {self.cfg.tasks}")
            getattr(self, task)()
        else:
            raise ValueError(
                f"Specified task '{task}' not defined for class {self.__class__.__name__} or is not callable."
            )


    def _build_trainer_callbacks(self):
        callbacks = []
        if self.logger:
            callbacks.append(LearningRateMonitor("step", True))

    def _build_training_loader(self) -> Optional[Union[TRAIN_DATALOADERS, pl.LightningDataModule]]:
        train_dataset = self._build_dataset("training")
        shuffle = (
            False if isinstance(train_dataset, torch.utils.data.IterableDataset) else self.cfg.training.data.shuffle
        )
        if train_dataset:
            return torch.utils.data.DataLoader(
                train_dataset,
                batch_size=self.cfg.training.batch_size,
                num_workers=min(os.cpu_count(), self.cfg.training.data.num_workers),
                shuffle=shuffle,
                persistent_workers=True,
            )
        else:
            return None

    def _build_validation_loader(self) -> Optional[Union[TRAIN_DATALOADERS, pl.LightningDataModule]]:
        validation_dataset = self._build_dataset("validation")
        shuffle = (
            False
            if isinstance(validation_dataset, torch.utils.data.IterableDataset)
            else self.cfg.validation.data.shuffle
        )
        if validation_dataset:
            return torch.utils.data.DataLoader(
                validation_dataset,
                batch_size=self.cfg.validation.batch_size,
                num_workers=min(os.cpu_count(), self.cfg.validation.data.num_workers),
                shuffle=shuffle,
                persistent_workers=True,
            )
        else:
            return None

    def _build_test_loader(self) -> Optional[Union[TRAIN_DATALOADERS, pl.LightningDataModule]]:
        test_dataset = self._build_dataset("test")
        shuffle = False if isinstance(test_dataset, torch.utils.data.IterableDataset) else self.cfg.test.data.shuffle
        if test_dataset:
            return torch.utils.data.DataLoader(
                test_dataset,
                batch_size=self.cfg.test.batch_size,
                num_workers=min(os.cpu_count(), self.cfg.test.data.num_workers),
                shuffle=shuffle,
                persistent_workers=True,
            )
        else:
            return None

    def training(self) -> None:
        """
        All training happens here
        """
        if not self.algo:
            self.algo = self._build_algo()
        if self.cfg.training.compile:
            self.algo = torch.compile(self.algo)

        callbacks = []
        if self.logger:
            callbacks.append(LearningRateMonitor("step", True))
        if "checkpointing" in self.cfg.training:
            callbacks.append(
                ModelCheckpoint(
                    pathlib.Path(hydra.core.hydra_config.HydraConfig.get()["runtime"]["output_dir"]) / "checkpoints",
                    **self.cfg.training.checkpointing,
                )
            )

        trainer = pl.Trainer(
            accelerator="auto",
            devices="auto",  # select devices automatically
            strategy=DDPStrategy(find_unused_parameters=True) if torch.cuda.device_count() > 1 else "auto",
            logger=self.logger or False,  # simplified form
            callbacks=callbacks,
            gradient_clip_val=self.cfg.training.optim.gradient_clip_val or 0.0,  # ensure a default value
            val_check_interval=self.cfg.validation.val_every_n_step if self.cfg.validation.val_every_n_step else None,
            limit_val_batches=self.cfg.validation.limit_batch,
            check_val_every_n_epoch=self.cfg.validation.val_every_n_epoch if not self.cfg.validation.val_every_n_step else None,
            accumulate_grad_batches=self.cfg.training.optim.accumulate_grad_batches or 1,  # default: no accumulation
            precision=self.cfg.training.precision or 32,  # default: 32-bit precision
            detect_anomaly=False,  # anomaly detection off by default
            num_sanity_val_steps=int(self.cfg.debug) if self.cfg.debug else 0,
            max_epochs=self.cfg.training.max_epochs,
            max_steps=self.cfg.training.max_steps,
            max_time=self.cfg.training.max_time
        )

        load_custom_checkpoint(algo=self.algo.vae, optimizer=None, checkpoint_path=self.vae_path)

        trainer.fit(
            self.algo,
            train_dataloaders=self._build_training_loader(),
            val_dataloaders=self._build_validation_loader(),
            ckpt_path=self.ckpt_path,
        )

    def validation(self) -> None:
        """
        All validation happens here
        """
        if not self.algo:
            self.algo = self._build_algo()
        if self.cfg.validation.compile:
            self.algo = torch.compile(self.algo)

        callbacks = []

        trainer = pl.Trainer(
            accelerator="auto",
            logger=self.logger,
            devices="auto",
            num_nodes=self.cfg.num_nodes,
            strategy=DDPStrategy(find_unused_parameters=False) if torch.cuda.device_count() > 1 else "auto",
            callbacks=callbacks,
            limit_val_batches=self.cfg.validation.limit_batch,
            precision=self.cfg.validation.precision,
            detect_anomaly=False,  # self.cfg.debug,
            inference_mode=self.cfg.validation.inference_mode,
        )

        load_custom_checkpoint(algo=self.algo.vae, optimizer=None, checkpoint_path=self.vae_path)

        trainer.validate(
            self.algo,
            dataloaders=self._build_validation_loader(),
            ckpt_path=self.ckpt_path,
        )

    def test(self) -> None:
        """
        All testing happens here
        """
        if not self.algo:
            self.algo = self._build_algo()
        if self.cfg.test.compile:
            self.algo = torch.compile(self.algo)

        callbacks = []

        trainer = pl.Trainer(
            accelerator="auto",
            logger=self.logger,
            devices="auto",
            num_nodes=self.cfg.num_nodes,
            strategy=DDPStrategy(find_unused_parameters=False) if torch.cuda.device_count() > 1 else "auto",
            callbacks=callbacks,
            limit_test_batches=self.cfg.test.limit_batch,
            precision=self.cfg.test.precision,
            detect_anomaly=False,  # self.cfg.debug,
        )

        # Only load the checkpoint if only testing. Otherwise, it will have been loaded
        # and further trained during train.
        trainer.test(
            self.algo,
            dataloaders=self._build_test_loader(),
            ckpt_path=self.ckpt_path,
        )
        if not self.algo:
            self.algo = self._build_algo()
        if self.cfg.validation.compile:
            self.algo = torch.compile(self.algo)

    def _build_dataset(self, split: str) -> Optional[torch.utils.data.Dataset]:
        if split in ["training", "test", "validation"]:
            return self.compatible_datasets[self.root_cfg.dataset._name](self.root_cfg.dataset, split=split)
        else:
            raise NotImplementedError(f"split '{split}' is not implemented")
experiments/exp_video.py
DELETED
@@ -1,25 +0,0 @@
from datasets.video import (
    MinecraftVideoDataset,
    MinecraftVideoPoseDataset
)

from algorithms.worldmem import WorldMemMinecraft
from algorithms.worldmem import PosePrediction
from .exp_base import BaseLightningExperiment


class VideoPredictionExperiment(BaseLightningExperiment):
    """
    A video prediction experiment
    """

    compatible_algorithms = dict(
        df_video_worldmemminecraft=WorldMemMinecraft,
        pose_prediction=PosePrediction
    )

    compatible_datasets = dict(
        # video datasets
        video_minecraft=MinecraftVideoDataset,
        video_minecraft_pose=MinecraftVideoPoseDataset
    )
main.py
DELETED
@@ -1,219 +0,0 @@
"""
This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research
template [repo](https://github.com/buoyancy99/research-template).
By its MIT license, you must keep the above sentence in `README.md`
and the `LICENSE` file to credit the author.

Main file for the project. This will create and run new experiments and load checkpoints from wandb.
Borrowed part of the code from David Charatan and wandb.
"""

import sys
import subprocess
import time
from pathlib import Path

import hydra
from omegaconf import DictConfig, OmegaConf
from omegaconf.omegaconf import open_dict

from utils.print_utils import cyan
from utils.ckpt_utils import download_latest_checkpoint, is_run_id
from utils.cluster_utils import submit_slurm_job
from utils.distributed_utils import is_rank_zero

def get_latest_checkpoint(checkpoint_folder: Path, pattern: str = '*.ckpt'):
    # collect all files in the folder that match the pattern
    checkpoint_files = list(checkpoint_folder.glob(pattern))
    if not checkpoint_files:
        return None  # return None if no checkpoint file is found
    # pick the most recent file by modification time (st_mtime)
    latest_checkpoint = max(checkpoint_files, key=lambda f: f.stat().st_mtime)
    return latest_checkpoint

def run_local(cfg: DictConfig):
    # delay some imports in case they are not needed in non-local envs for submission
    from experiments import build_experiment
    from utils.wandb_utils import OfflineWandbLogger, SpaceEfficientWandbLogger

    # Get yaml names
    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
    cfg_choice = OmegaConf.to_container(hydra_cfg.runtime.choices)

    with open_dict(cfg):
        if cfg_choice["experiment"] is not None:
            cfg.experiment._name = cfg_choice["experiment"]
        if cfg_choice["dataset"] is not None:
            cfg.dataset._name = cfg_choice["dataset"]
        if cfg_choice["algorithm"] is not None:
            cfg.algorithm._name = cfg_choice["algorithm"]

    # import pdb;pdb.set_trace()
    # Set up the output directory.
    output_dir = getattr(cfg, "output_dir", None)
    if output_dir is not None:
        OmegaConf.set_readonly(hydra_cfg, False)
        hydra_cfg.runtime.output_dir = output_dir
        OmegaConf.set_readonly(hydra_cfg, True)

    output_dir = Path(hydra_cfg.runtime.output_dir)

    if is_rank_zero:
        print(cyan(f"Outputs will be saved to:"), output_dir)
        (output_dir.parents[1] / "latest-run").unlink(missing_ok=True)
        (output_dir.parents[1] / "latest-run").symlink_to(output_dir, target_is_directory=True)

    # Set up logging with wandb.
    if cfg.wandb.mode != "disabled":
        # If resuming, merge into the existing run on wandb.
        resume = cfg.get("resume", None)
        name = f"{cfg.name} ({output_dir.parent.name}/{output_dir.name})" if resume is None else None

        if "_on_compute_node" in cfg and cfg.cluster.is_compute_node_offline:
            logger_cls = OfflineWandbLogger
        else:
            logger_cls = SpaceEfficientWandbLogger

        offline = cfg.wandb.mode != "online"
        logger = logger_cls(
            name=name,
            save_dir=str(output_dir),
            offline=offline,
            entity=cfg.wandb.entity,
            project=cfg.wandb.project,
            log_model=False,
            config=OmegaConf.to_container(cfg),
            id=resume,
            resume="auto"
        )

    else:
        logger = None

    # Load ckpt
    resume = cfg.get("resume", None)
    load = cfg.get("load", None)
    checkpoint_path = None
    load_id = None
    if load and not is_run_id(load):
        checkpoint_path = load
    if resume:
        load_id = resume
    elif load and is_run_id(load):
        load_id = load
    else:
        load_id = None

    if load_id:
        run_path = f"{cfg.wandb.entity}/{cfg.wandb.project}/{load_id}"
        checkpoint_path = Path("outputs/downloaded") / run_path / "model.ckpt"
        checkpoint_path = output_dir / get_latest_checkpoint(output_dir / "checkpoints")

    if checkpoint_path and is_rank_zero:
        print(f"Will load checkpoint from {checkpoint_path}")

    # launch experiment
    experiment = build_experiment(cfg, logger, checkpoint_path)
    for task in cfg.experiment.tasks:
        experiment.exec_task(task)


def run_slurm(cfg: DictConfig):
    python_args = " ".join(sys.argv[1:]) + " +_on_compute_node=True"
    project_root = Path.cwd()
    while not (project_root / ".git").exists():
        project_root = project_root.parent
        if project_root == Path("/"):
            raise Exception("Could not find repo directory!")

    slurm_log_dir = submit_slurm_job(
        cfg,
        python_args,
        project_root,
    )

    if "cluster" in cfg and cfg.cluster.is_compute_node_offline and cfg.wandb.mode == "online":
        print("Job submitted to a compute node without internet. This requires manual syncing on login node.")
        osh_command_dir = project_root / ".wandb_osh_command_dir"

        osh_proc = None
        # if click.confirm("Do you want us to run the sync loop for you?", default=True):
        osh_proc = subprocess.Popen(["wandb-osh", "--command-dir", osh_command_dir])
        print(f"Running wandb-osh in background... PID: {osh_proc.pid}")
        print(f"To kill the sync process, run 'kill {osh_proc.pid}' in the terminal.")
        print(
            f"You can manually start a sync loop later by running the following:",
            cyan(f"wandb-osh --command-dir {osh_command_dir}"),
        )

    print(
        "Once the job gets allocated and starts running, we will print a command below "
        "for you to trace the errors and outputs: (Ctrl + C to exit without waiting)"
    )
    msg = f"tail -f {slurm_log_dir}/* \n"
    try:
        while not list(slurm_log_dir.glob("*.out")) and not list(slurm_log_dir.glob("*.err")):
            time.sleep(1)
        print(cyan("To trace the outputs and errors, run the following command:"), msg)
    except KeyboardInterrupt:
        print("Keyboard interrupt detected. Exiting...")
        print(
            cyan("To trace the outputs and errors, manually wait for the job to start and run the following command:"),
            msg,
        )


@hydra.main(
    version_base=None,
    config_path="configurations",
    config_name="config",
)
def run(cfg: DictConfig):
    if "_on_compute_node" in cfg and cfg.cluster.is_compute_node_offline:
        with open_dict(cfg):
            if cfg.cluster.is_compute_node_offline and cfg.wandb.mode == "online":
                cfg.wandb.mode = "offline"

    if "name" not in cfg:
        raise ValueError("must specify a name for the run with command line argument '+name=[name]'")

    if not cfg.wandb.get("entity", None):
        raise ValueError(
            "must specify wandb entity in 'configurations/config.yaml' or with command line"
            " argument 'wandb.entity=[entity]' \n An entity is your wandb user name or group"
            " name. This is used for logging. If you don't have a wandb account, please sign up at https://wandb.ai/"
        )

    if cfg.wandb.project is None:
        cfg.wandb.project = str(Path(__file__).parent.name)

    # If resuming or loading a wandb ckpt and not on a compute node, download the checkpoint.
    resume = cfg.get("resume", None)
    load = cfg.get("load", None)

    if resume and load:
        raise ValueError(
            "When resuming a wandb run with `resume=[wandb id]`, checkpoint will be loaded from the cloud"
            " and `load` should not be specified."
        )

    if resume:
        load_id = resume
    elif load and is_run_id(load):
        load_id = load
    else:
        load_id = None

    # if load_id and "_on_compute_node" not in cfg:
    #     run_path = f"{cfg.wandb.entity}/{cfg.wandb.project}/{load_id}"
    #     download_latest_checkpoint(run_path, Path("outputs/downloaded"))

    if "cluster" in cfg and not "_on_compute_node" in cfg:
        print(cyan("Slurm detected, submitting to compute node instead of running locally..."))
        run_slurm(cfg)
    else:
        run_local(cfg)


if __name__ == "__main__":
    run()  # pylint: disable=no-value-for-parameter
scripts/README.md
DELETED
@@ -1,10 +0,0 @@
# scripts

The `scripts` folder contains bash scripts for you to scale up your project on the cloud.
Don't put your jupyter notebooks here! They belong to the `debug` folder.

General scripts that are useful for all projects can be put in the `scripts` folder directly.

---

This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research template [repo](https://github.com/buoyancy99/research-template). By its MIT license, you must keep the above sentence in `README.md` and the `LICENSE` file to credit the author.
scripts/dummy_script.sh
DELETED
@@ -1 +0,0 @@
echo 'hello world'
split_checkpoint.py
DELETED
@@ -1,9 +0,0 @@
import torch

ckpt_path = "/mnt/xiaozeqi/diffusionforcing/outputs/2025-03-28/16-45-11/checkpoints/epoch0step595000.ckpt"
checkpoint = torch.load(ckpt_path, map_location="cpu")  # change map_location as needed

state_dict = checkpoint['state_dict']
# keep only the pose prediction submodule's weights, stripping their prefix
pose_prediction_model_dict = {k.replace('pose_prediction_model.', ''): v for k, v in state_dict.items() if k.startswith('pose_prediction_model.')}

torch.save({'state_dict': pose_prediction_model_dict}, "pose_prediction_model_only.ckpt")