xizaoqu commited on
Commit
100414d
·
1 Parent(s): f07d258
app.py CHANGED
@@ -10,13 +10,8 @@ import hydra
 from omegaconf import DictConfig, OmegaConf
 from omegaconf.omegaconf import open_dict
 
-from utils.print_utils import cyan
-from utils.ckpt_utils import download_latest_checkpoint, is_run_id
-from utils.cluster_utils import submit_slurm_job
-from utils.distributed_utils import is_rank_zero
 import numpy as np
 import torch
-from datasets.video.minecraft_video_dataset import *
 import torchvision.transforms as transforms
 import cv2
 import subprocess
@@ -351,18 +346,7 @@ def set_memory(examples_case, image_display, log_output, slider_denoising_step,
 
     return input_history, out_video[-1], temporal_video_path, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
 
-css = """
-h1 {
-    text-align: center;
-    display:block;
-}
-"""
-
-def on_select(evt: gr.SelectData):
-    selected_index = evt.index
-    return examples[selected_index]
-
-with gr.Blocks(css=css) as demo:
+with gr.Blocks() as demo:
     gr.Markdown(
         """
         # WORLDMEM: Long-term Consistent World Generation with Memory
@@ -515,13 +499,6 @@ with gr.Blocks(css=css) as demo:
     example_case = gr.Textbox(label="Case", visible=False)
     image_output = gr.Image(visible=False)
 
-    # gr.Examples(examples=example_images,
-    #             inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length, slider_memory_length],
-    #             fn=set_memory,
-    #             outputs=[log_output, image_display, video_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx],
-    #             cache_examples=True
-    #             )
-
     examples = gr.Examples(
         examples=example_images,
         inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length, slider_memory_length],
@@ -534,7 +511,6 @@ with gr.Blocks(css=css) as demo:
         outputs=[log_output, image_display, video_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx]
     )
 
-
     submit_button.click(generate, inputs=[input_box, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx], outputs=[image_display, video_display, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
     reset_btn.click(reset, inputs=[selected_image], outputs=[log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
     image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
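Note on the Gradio wiring above: the commit drops the custom CSS block and the unused `on_select` helper and builds the demo with a plain `gr.Blocks()` context. For readers unfamiliar with the pattern, here is a minimal, hypothetical sketch of the same `Blocks` plus `.click(fn, inputs, outputs)` structure; component and handler names are illustrative, not the app's own.

```python
# Minimal sketch of the gr.Blocks + event-handler pattern used in app.py.
# Names below (echo, input_box, ...) are illustrative only.
import gradio as gr

def echo(message, history):
    # hypothetical handler; app.py wires generate/reset/on_image_click the same way
    return history + "\n" + message

with gr.Blocks() as demo:
    gr.Markdown("# Demo")
    input_box = gr.Textbox(label="Input")
    log_output = gr.Textbox(label="Log", value="")
    submit_button = gr.Button("Submit")
    # .click(fn, inputs, outputs) registers a callback, as in the diff above
    submit_button.click(echo, inputs=[input_box, log_output], outputs=[log_output])

if __name__ == "__main__":
    demo.launch()
```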
configurations/README.md DELETED
@@ -1,7 +0,0 @@
1
- # configurations
2
-
3
- We use [Hydra](https://hydra.cc/docs/intro/) to manage configurations. Change/Add the yaml files in this folder
4
- to change the default configurations. You can also override the default configurations by
5
- passing command line arguments.
6
-
7
- All configurations are automatically saved in the wandb run.
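As a point of reference, a minimal Hydra entry point that composes the YAMLs described above could look like the sketch below; the repo's actual `main.py` is not part of this diff, so treat the details as illustrative.

```python
# Hypothetical entry point showing how Hydra composes configurations/config.yaml
# together with the experiment/dataset/algorithm defaults.
import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(version_base=None, config_path="configurations", config_name="config")
def main(cfg: DictConfig) -> None:
    # Any field can be overridden on the command line, e.g.:
    #   python main.py debug=true experiment.training.lr=1e-4
    print(OmegaConf.to_yaml(cfg))

if __name__ == "__main__":
    main()
```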
 
configurations/algorithm/base_algo.yaml DELETED
@@ -1,3 +0,0 @@
1
- # This will be passed as the cfg to Algo.__init__(cfg) of your algorithm class
2
-
3
- debug: ${debug} # inherited from configurations/config.yaml
 
configurations/algorithm/base_pytorch_algo.yaml DELETED
@@ -1,4 +0,0 @@
1
- defaults:
2
- - base_algo # inherits from configurations/algorithm/base_algo.yaml
3
-
4
- lr: ${experiment.training.lr}
 
configurations/algorithm/df_base.yaml DELETED
@@ -1,42 +0,0 @@
1
- defaults:
2
- - base_pytorch_algo
3
-
4
- # dataset-dependent configurations
5
- x_shape: ${dataset.observation_shape}
6
- frame_stack: 1
7
- frame_skip: 1
8
- data_mean: ${dataset.data_mean}
9
- data_std: ${dataset.data_std}
10
- external_cond_dim: 0 #${dataset.action_dim}
11
- context_frames: ${dataset.context_length}
12
- # training hyperparameters
13
- weight_decay: 1e-4
14
- warmup_steps: 10000
15
- optimizer_beta: [0.9, 0.999]
16
- # diffusion-related
17
- uncertainty_scale: 1
18
- guidance_scale: 0.0
19
- chunk_size: 1 # -1 for full trajectory diffusion, number to specify diffusion chunk size
20
- scheduling_matrix: autoregressive
21
- noise_level: random_all
22
- causal: True
23
-
24
- diffusion:
25
- # training
26
- objective: pred_x0
27
- beta_schedule: cosine
28
- schedule_fn_kwargs: {}
29
- clip_noise: 20.0
30
- use_snr: False
31
- use_cum_snr: False
32
- use_fused_snr: False
33
- snr_clip: 5.0
34
- cum_snr_decay: 0.98
35
- timesteps: 1000
36
- # sampling
37
- sampling_timesteps: 50 # FIXME: number of diffusion steps, should be increased
38
- ddim_sampling_eta: 1.0
39
- stabilization_level: 10
40
- # architecture
41
- architecture:
42
- network_size: 64
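`df_base.yaml` above selects `beta_schedule: cosine` with `timesteps: 1000`. For orientation only, one common definition of a cosine noise schedule (in the style of Nichol & Dhariwal) is sketched below; this is an assumption about the general technique, not necessarily the exact schedule implemented in this repo.

```python
import torch

def cosine_beta_schedule(timesteps: int = 1000, s: float = 0.008) -> torch.Tensor:
    """A common cosine schedule, shown for reference only."""
    steps = timesteps + 1
    t = torch.linspace(0, timesteps, steps) / timesteps
    alphas_cumprod = torch.cos((t + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)

betas = cosine_beta_schedule(1000)  # shape (1000,), increasing toward 0.999
```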
 
configurations/algorithm/df_video_worldmemminecraft.yaml DELETED
@@ -1,42 +0,0 @@
1
- defaults:
2
- - df_base
3
-
4
- n_frames: ${dataset.n_frames}
5
- frame_skip: ${dataset.frame_skip}
6
- metadata: ${dataset.metadata}
7
-
8
- # training hyperparameters
9
- weight_decay: 2e-3
10
- warmup_steps: 10000
11
- optimizer_beta: [0.9, 0.99]
12
- action_cond_dim: 25
13
-
14
- diffusion:
15
- # training
16
- beta_schedule: sigmoid
17
- objective: pred_v
18
- use_fused_snr: True
19
- cum_snr_decay: 0.96
20
- clip_noise: 20.
21
- # sampling
22
- sampling_timesteps: 20
23
- ddim_sampling_eta: 0.0
24
- stabilization_level: 15
25
- # architecture
26
- architecture:
27
- network_size: 64
28
- attn_heads: 4
29
- attn_dim_head: 64
30
- dim_mults: [1, 2, 4, 8]
31
- resolution: ${dataset.resolution}
32
- attn_resolutions: [16, 32, 64, 128]
33
- use_init_temporal_attn: True
34
- use_linear_attn: True
35
- time_emb_type: rotary
36
-
37
- metrics:
38
- # - fvd
39
- # - fid
40
- # - lpips
41
-
42
- _name: df_video_worldmemminecraft
 
configurations/algorithm/pose_prediction.yaml DELETED
@@ -1,19 +0,0 @@
1
- defaults:
2
- - df_base
3
-
4
- n_frames: ${dataset.n_frames}
5
- frame_skip: ${dataset.frame_skip}
6
- metadata: ${dataset.metadata}
7
-
8
- # training hyperparameters
9
- weight_decay: 2e-3
10
- warmup_steps: 10000
11
- optimizer_beta: [0.9, 0.99]
12
-
13
-
14
- metrics:
15
- # - fvd
16
- # - fid
17
- # - lpips
18
-
19
- _name: pose_prediction
 
configurations/config.yaml DELETED
@@ -1,16 +0,0 @@
1
- # configuration parsing starts here
2
- defaults:
3
- - experiment: exp_video # experiment yaml file name in configurations/experiments folder [fixme]
4
- - dataset: video_minecraft_oasis # dataset yaml file name in configurations/dataset folder [fixme]
5
- - algorithm: df_video # algorithm yaml file name in configurations/algorithm folder [fixme]
6
- - cluster: null # optional, cluster yaml file name in configurations/cluster folder. Leave null for local compute
7
-
8
- debug: false # global debug flag will be passed into configuration of experiment, dataset and algorithm
9
-
10
- wandb:
11
- entity: xizaoqu # wandb account name / organization name [fixme]
12
- project: diffusion-forcing # wandb project name; if not provided, defaults to root folder name [fixme]
13
- mode: online # set wandb logging to online, offline or dryrun
14
-
15
- resume: null # wandb run id to resume logging and loading checkpoint from
16
- load: null # wandb run id containing checkpoint or a path to a checkpoint file
 
configurations/dataset/base_dataset.yaml DELETED
@@ -1,3 +0,0 @@
1
- # This will be passed as the cfg to Dataset.__init__(cfg) of your dataset class
2
-
3
- debug: ${debug} # inherited from configurations/config.yaml
 
configurations/dataset/base_video.yaml DELETED
@@ -1,14 +0,0 @@
1
- defaults:
2
- - base_dataset
3
-
4
- metadata: "data/${dataset.name}/metadata.json"
5
- data_mean: "data/${dataset.name}/data_mean.npy"
6
- data_std: "data/${dataset.name}/data_std.npy"
7
- save_dir: ???
8
- n_frames: 32
9
- context_length: 4
10
- resolution: 128
11
- observation_shape: [3, "${dataset.resolution}", "${dataset.resolution}"]
12
- external_cond_dim: 0
13
- validation_multiplier: 1
14
- frame_skip: 1
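`base_video.yaml` above builds paths such as `data/${dataset.name}/metadata.json` through OmegaConf interpolation. A small self-contained sketch, with made-up values, of how those references resolve:

```python
from omegaconf import OmegaConf

# Made-up values; mirrors the ${dataset.*} interpolations in base_video.yaml.
cfg = OmegaConf.create({
    "dataset": {
        "name": "video_minecraft",
        "resolution": 128,
        "metadata": "data/${dataset.name}/metadata.json",
        "observation_shape": [3, "${dataset.resolution}", "${dataset.resolution}"],
    }
})
print(cfg.dataset.metadata)  # data/video_minecraft/metadata.json
print(OmegaConf.to_container(cfg.dataset.observation_shape, resolve=True))  # [3, 128, 128]
```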
 
configurations/dataset/video_minecraft.yaml DELETED
@@ -1,14 +0,0 @@
1
- defaults:
2
- - base_video
3
-
4
- save_dir: data/minecraft_simple_backforward
5
- n_frames: 16 # TODO: increase later
6
- resolution: 128
7
- data_mean: 0.5
8
- data_std: 0.5
9
- action_cond_dim: 25
10
- context_length: 1
11
- frame_skip: 1
12
- validation_multiplier: 1
13
-
14
- _name: video_minecraft_oasis
 
configurations/dataset/video_minecraft_pose.yaml DELETED
@@ -1,14 +0,0 @@
1
- defaults:
2
- - base_video
3
-
4
- save_dir: data/minecraft_simple_backforward
5
- n_frames: 16 # TODO: increase later
6
- resolution: 128
7
- data_mean: 0.5
8
- data_std: 0.5
9
- external_cond_dim: 25
10
- context_length: 1
11
- frame_skip: 1
12
- validation_multiplier: 1
13
-
14
- _name: video_minecraft_pose
 
configurations/experiment/base_experiment.yaml DELETED
@@ -1,2 +0,0 @@
1
- debug: ${debug} # inherited from configurations/config.yaml
2
- tasks: [main] # tasks to run sequentially, such as [training, test]; useful when your project has multiple stages and you want to run only a subset of them.
 
configurations/experiment/base_pytorch.yaml DELETED
@@ -1,50 +0,0 @@
1
- # inherits from base_experiment.yaml
2
- # most of the options have docs at https://lightning.ai/docs/pytorch/stable/common/trainer.html
3
-
4
- defaults:
5
- - base_experiment
6
-
7
- tasks: [training] # tasks to run sequentially; change when your project has multiple stages and you want to run only a subset of them.
8
- num_nodes: 1 # number of gpu servers used in large scale distributed training
9
-
10
- training:
11
- precision: 16-mixed # set float precision, 16-mixed is faster while 32 is more stable
12
- compile: False # whether to compile the model with torch.compile
13
- lr: 0.001 # learning rate
14
- batch_size: 16 # training batch size; effective batch size is this number * gpu * nodes iff using distributed training
15
- max_epochs: 1000 # set to -1 to train forever
16
- max_steps: -1 # set to -1 to train forever, will override max_epochs
17
- max_time: null # set to something like "00:12:00:00" to enable
18
- data:
19
- num_workers: 4 # number of CPU threads for data preprocessing.
20
- shuffle: True # whether training data will be shuffled
21
- optim:
22
- accumulate_grad_batches: 1 # accumulate gradients for n batches before backprop
23
- gradient_clip_val: 0 # clip gradients with norm above this value, set to 0 to disable
24
- checkpointing:
25
- # these are arguments to pytorch lightning's callback, `ModelCheckpoint` class
26
- every_n_train_steps: 5000 # save a checkpoint every n train steps
27
- every_n_epochs: null # mutually exclusive with ``every_n_train_steps`` and ``train_time_interval``
28
- train_time_interval: null # in format of "00:12:00:00", mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``.
29
- enable_version_counter: False # If this is ``False``, later checkpoints will overwrite previous ones.
30
-
31
- validation:
32
- precision: 16-mixed
33
- compile: False # whether to compile the model with torch.compile
34
- batch_size: 16 # validation batch size per GPU; effective batch size is this number * gpu * nodes iff using distributed training
35
- val_every_n_step: 2000 # controls how frequently validation is run; can be a float (fraction of an epoch), an int (steps), or null (if val_every_n_epoch is set)
36
- val_every_n_epoch: null # if you want to run validation every n epochs; requires val_every_n_step to be null.
37
- limit_batch: null # if null, run through validation set. Otherwise limit the number of batches to use for validation.
38
- inference_mode: True # whether to run validation in inference mode (enable_grad won't work!)
39
- data:
40
- num_workers: 4 # number of CPU threads for data preprocessing, for validation.
41
- shuffle: False # whether validation data will be shuffled
42
-
43
- test:
44
- precision: 16-mixed
45
- compile: False # whether to compile the model with torch.compile
46
- batch_size: 4 # test batch size per GPU; effective batch size is this number * gpu * nodes iff using distributed training
47
- limit_batch: null # if null, run through test set. Otherwise limit the number of batches to use for test.
48
- data:
49
- num_workers: 4 # number of CPU threads for data preprocessing, for test.
50
- shuffle: False # whether test data will be shuffled
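The `checkpointing` block above describes itself as arguments to Lightning's `ModelCheckpoint` callback. Below is a hedged sketch of how those fields would map onto that callback; the `dirpath` is hypothetical, and the repo's actual wiring lives in the experiment classes, which are not reproduced in full here.

```python
from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath="outputs/checkpoints",   # hypothetical output directory
    every_n_train_steps=5000,        # checkpointing.every_n_train_steps
    every_n_epochs=None,             # mutually exclusive with the step-based trigger
    train_time_interval=None,        # a datetime.timedelta if time-based saving is used
    enable_version_counter=False,    # overwrite rather than version older checkpoints
)
```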
 
configurations/experiment/exp_pose.yaml DELETED
@@ -1,31 +0,0 @@
1
- defaults:
2
- - base_pytorch
3
-
4
- tasks: [training]
5
-
6
- training:
7
- lr: 8e-5
8
- precision: 16-mixed
9
- batch_size: 4
10
- max_epochs: -1
11
- max_steps: 2000005
12
- checkpointing:
13
- every_n_train_steps: 2500
14
- optim:
15
- gradient_clip_val: 1.0
16
-
17
- validation:
18
- val_every_n_step: 300
19
- val_every_n_epoch: null
20
- batch_size: 4
21
- limit_batch: 1
22
-
23
- test:
24
- limit_batch: 1
25
- batch_size: 1
26
-
27
- logging:
28
- metrics:
29
- # - fvd
30
- # - fid
31
- # - lpips
 
configurations/experiment/exp_video.yaml DELETED
@@ -1,31 +0,0 @@
1
- defaults:
2
- - base_pytorch
3
-
4
- tasks: [training]
5
-
6
- training:
7
- lr: 8e-5
8
- precision: 16-mixed
9
- batch_size: 4
10
- max_epochs: -1
11
- max_steps: 2000005
12
- checkpointing:
13
- every_n_train_steps: 2500
14
- optim:
15
- gradient_clip_val: 1.0
16
-
17
- validation:
18
- val_every_n_step: 300
19
- val_every_n_epoch: null
20
- batch_size: 4
21
- limit_batch: 1
22
-
23
- test:
24
- limit_batch: 1
25
- batch_size: 1
26
-
27
- logging:
28
- metrics:
29
- # - fvd
30
- # - fid
31
- # - lpips
 
datasets/README.md DELETED
@@ -1,11 +0,0 @@
1
- The `datasets` folder is used to contain dataset code or environment code.
2
- Don't store actual data like images here! For those, please use the `data` folder instead of `datasets`.
3
-
4
- Create a subfolder for your own PyTorch dataset definition. Then, update the `__init__.py`
5
- at every level to register all datasets.
6
-
7
- Each dataset class takes a DictConfig `cfg` in its `__init__`, which allows you to pass in arguments via a configuration file in `configurations/dataset` or a [command line override](https://hydra.cc/docs/tutorials/basic/your_first_app/simple_cli/).
8
-
9
- ---
10
-
11
- This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research template [repo](https://github.com/buoyancy99/research-template). By its MIT license, you must keep the above sentence in `README.md` and the `LICENSE` file to credit the author.
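The README above states that each dataset class receives a `DictConfig` in `__init__` and is registered through the package `__init__.py` files. A minimal, hypothetical sketch of that convention (class name and fields are illustrative):

```python
import torch
from omegaconf import DictConfig

class MyTinyVideoDataset(torch.utils.data.Dataset):
    """Hypothetical dataset following the convention described above."""

    def __init__(self, cfg: DictConfig, split: str = "training"):
        self.resolution = cfg.resolution  # fields come from configurations/dataset/*.yaml
        self.n_frames = cfg.n_frames
        self.split = split

    def __len__(self):
        return 0  # placeholder

    def __getitem__(self, idx):
        raise IndexError

# ...then export it (e.g. `from .my_tiny import MyTinyVideoDataset`) in datasets/__init__.py.
```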
 
datasets/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .video import MinecraftVideoDataset
 
datasets/video/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- from .minecraft_video_dataset import MinecraftVideoDataset
2
- from .minecraft_video_dataset_pose import MinecraftVideoPoseDataset
 
datasets/video/base_video_dataset.py DELETED
@@ -1,158 +0,0 @@
1
- from typing import Sequence
2
- import torch
3
- import random
4
- import os
5
- import numpy as np
6
- import cv2
7
- from omegaconf import DictConfig
8
- from torchvision import transforms
9
- from pathlib import Path
10
- from abc import abstractmethod, ABC
11
- import json
12
-
13
-
14
- class BaseVideoDataset(torch.utils.data.Dataset, ABC):
15
- """
16
- Base class for video datasets. Videos may be of variable length.
17
-
18
- Folder structure of each dataset:
19
- - [save_dir] (specified in config, e.g., data/phys101)
20
- - /[split] (one per split)
21
- - /data_folder_name (e.g., videos)
22
- metadata.json
23
- """
24
-
25
- def __init__(self, cfg: DictConfig, split: str = "training"):
26
- super().__init__()
27
- self.cfg = cfg
28
- self.split = split
29
- self.resolution = cfg.resolution
30
- self.external_cond_dim = cfg.external_cond_dim
31
- self.n_frames = (
32
- cfg.n_frames * cfg.frame_skip
33
- if split == "training"
34
- else cfg.n_frames * cfg.frame_skip * cfg.validation_multiplier
35
- )
36
- self.frame_skip = cfg.frame_skip
37
- self.save_dir = Path(cfg.save_dir)
38
- self.save_dir.mkdir(exist_ok=True, parents=True)
39
- self.split_dir = self.save_dir / f"{split}"
40
-
41
- self.metadata_path = self.save_dir / "metadata.json"
42
-
43
- self.data_paths = self.get_data_paths(self.split)
44
-
45
- if self.split == 'training':
46
- self.metadata = [1200] * len(self.data_paths) # total 1500 f
47
- else:
48
- self.metadata = [1] * len(self.data_paths) # total 1500 f
49
- # self.clips_per_video = np.clip(np.array(self.metadata[split]) - self.n_frames + 1, a_min=1, a_max=None).astype(
50
- # np.int32
51
- # )
52
- self.clips_per_video = np.clip(np.array(self.metadata) - self.n_frames + 1, a_min=1, a_max=None).astype(
53
- np.int32
54
- )
55
- self.cum_clips_per_video = np.cumsum(self.clips_per_video)
56
- self.transform = transforms.Resize((self.resolution, self.resolution), antialias=True)
57
-
58
- # shuffle but keep the same order for each epoch, so validation sample is diverse yet deterministic
59
- random.seed(0)
60
- self.idx_remap = list(range(self.__len__()))
61
- random.shuffle(self.idx_remap)
62
-
63
- @abstractmethod
64
- def download_dataset(self) -> Sequence[int]:
65
- """
66
- Download dataset from the internet and build it in save_dir
67
-
68
- Returns a list of video lengths
69
- """
70
- raise NotImplementedError
71
-
72
- @abstractmethod
73
- def get_data_paths(self, split):
74
- """Return a list of data paths (e.g. xxx.mp4) for a given split"""
75
- raise NotImplementedError
76
-
77
- def get_data_lengths(self, split):
78
- """Return a list of num_frames for each data path (e.g. xxx.mp4) for a given split"""
79
- lengths = []
80
- for path in self.get_data_paths(split):
81
- length = cv2.VideoCapture(str(path)).get(cv2.CAP_PROP_FRAME_COUNT)
82
- lengths.append(length)
83
- return lengths
84
-
85
- def split_idx(self, idx):
86
- video_idx = np.argmax(self.cum_clips_per_video > idx)
87
- frame_idx = idx - np.pad(self.cum_clips_per_video, (1, 0))[video_idx]
88
- return video_idx, frame_idx
89
-
90
- @staticmethod
91
- def load_video(path: Path):
92
- """
93
- Load video from a path
94
- :param filename: path to the video
95
- :return: video as a numpy array
96
- """
97
-
98
- cap = cv2.VideoCapture(str(path))
99
-
100
- frames = []
101
- while cap.isOpened():
102
- ret, frame = cap.read()
103
- if ret:
104
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
105
- frames.append(frame)
106
- else:
107
- break
108
-
109
- cap.release()
110
- frames = np.stack(frames, dtype=np.uint8)
111
- return np.transpose(frames, (0, 3, 1, 2)) # (T, C, H, W)
112
-
113
- @staticmethod
114
- def load_image(filename: Path):
115
- """
116
- Load image from a path
117
- :param filename: path to the image
118
- :return: image as a numpy array
119
- """
120
- image = cv2.imread(str(filename))
121
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
122
- return np.transpose(image, (2, 0, 1))
123
-
124
- def __len__(self):
125
- return self.clips_per_video.sum()
126
-
127
- def __getitem__(self, idx):
128
- idx = self.idx_remap[idx]
129
- video_idx, frame_idx = self.split_idx(idx)
130
- video_path = self.data_paths[video_idx]
131
- video = self.load_video(video_path)[frame_idx : frame_idx + self.n_frames]
132
-
133
- pad_len = self.n_frames - len(video)
134
-
135
- nonterminal = np.ones(self.n_frames)
136
- if len(video) < self.n_frames:
137
- video = np.pad(video, ((0, pad_len), (0, 0), (0, 0), (0, 0)))
138
- nonterminal[-pad_len:] = 0
139
-
140
- video = torch.from_numpy(video / 256.0).float()
141
- video = self.transform(video)
142
-
143
- if self.external_cond_dim:
144
- external_cond = np.load(
145
- # pylint: disable=no-member
146
- self.condition_dir
147
- / f"{video_path.name.replace('.mp4', '.npy')}"
148
- )
149
- if len(external_cond) < self.n_frames:
150
- external_cond = np.pad(external_cond, ((0, pad_len),))
151
- external_cond = torch.from_numpy(external_cond).float()
152
- return (
153
- video[:: self.frame_skip],
154
- external_cond[:: self.frame_skip],
155
- nonterminal[:: self.frame_skip],
156
- )
157
- else:
158
- return video[:: self.frame_skip], nonterminal[:: self.frame_skip]
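A note on the indexing in the deleted `base_video_dataset.py` above: all videos are flattened into one clip index, where a video with `L` frames contributes `max(L - n_frames + 1, 1)` clips and `split_idx` maps a flat index back to `(video_idx, frame_idx)`. The same arithmetic, repeated with made-up video lengths:

```python
import numpy as np

metadata = np.array([10, 3, 7])  # made-up frames-per-video
n_frames = 4                     # clip length
clips_per_video = np.clip(metadata - n_frames + 1, a_min=1, a_max=None)  # [7, 1, 4]
cum = np.cumsum(clips_per_video)                                         # [7, 8, 12]

def split_idx(idx: int):
    video_idx = int(np.argmax(cum > idx))
    frame_idx = idx - np.pad(cum, (1, 0))[video_idx]
    return video_idx, frame_idx

print(split_idx(0))  # (0, 0) -> first clip of video 0
print(split_idx(7))  # (1, 0) -> the only clip of video 1
print(split_idx(9))  # (2, 1) -> second clip of video 2
```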
 
datasets/video/minecraft_video_dataset.py DELETED
@@ -1,262 +0,0 @@
1
- import os
2
- import io
3
- import tarfile
4
- import numpy as np
5
- import torch
6
- from typing import Sequence, Mapping
7
- from omegaconf import DictConfig
8
- from pytorchvideo.data.encoded_video import EncodedVideo
9
- import random
10
-
11
- from .base_video_dataset import BaseVideoDataset
12
-
13
-
14
-
15
-
16
- ACTION_KEYS = [
17
- "inventory",
18
- "ESC",
19
- "hotbar.1",
20
- "hotbar.2",
21
- "hotbar.3",
22
- "hotbar.4",
23
- "hotbar.5",
24
- "hotbar.6",
25
- "hotbar.7",
26
- "hotbar.8",
27
- "hotbar.9",
28
- "forward",
29
- "back",
30
- "left",
31
- "right",
32
- "cameraY",
33
- "cameraX",
34
- "jump",
35
- "sneak",
36
- "sprint",
37
- "swapHands",
38
- "attack",
39
- "use",
40
- "pickItem",
41
- "drop",
42
- ]
43
-
44
- def convert_action_space(actions):
45
- vec_25 = torch.zeros(len(actions), len(ACTION_KEYS))
46
- vec_25[actions[:,0]==1, 11] = 1
47
- vec_25[actions[:,0]==2, 12] = 1
48
- vec_25[actions[:,4]==11, 16] = -1
49
- vec_25[actions[:,4]==13, 16] = 1
50
- vec_25[actions[:,3]==11, 15] = -1
51
- vec_25[actions[:,3]==13, 15] = 1
52
- vec_25[actions[:,5]==6, 24] = 1
53
- vec_25[actions[:,5]==1, 24] = 1
54
- vec_25[actions[:,1]==1, 13] = 1
55
- vec_25[actions[:,1]==2, 14] = 1
56
- vec_25[actions[:,7]==1, 2] = 1
57
- return vec_25
58
-
59
- # Dataset class
60
- class MinecraftVideoDataset(BaseVideoDataset):
61
- """
62
- Minecraft video dataset for training and validation.
63
-
64
- Args:
65
- cfg (DictConfig): Configuration object.
66
- split (str): Dataset split ("training" or "validation").
67
- """
68
- def __init__(self, cfg: DictConfig, split: str = "training"):
69
- if split == "test":
70
- split = "validation"
71
- super().__init__(cfg, split)
72
- self.n_frames = cfg.n_frames_valid if split == "validation" and hasattr(cfg, "n_frames_valid") else cfg.n_frames
73
- self.use_plucker = cfg.use_plucker
74
- self.condition_similar_length = cfg.condition_similar_length
75
- self.customized_validation = cfg.customized_validation
76
- self.angle_range = cfg.angle_range
77
- self.pos_range = cfg.pos_range
78
- self.add_frame_timestep_embedder = cfg.add_frame_timestep_embedder
79
- self.training_dropout = 0.1
80
- self.sample_more_place = getattr(cfg, "sample_more_place", False)
81
- self.within_context = getattr(cfg, "within_context", False)
82
- self.sample_more_event = getattr(cfg, "sample_more_event", False)
83
- self.causal_frame = getattr(cfg, "causal_frame", False)
84
-
85
- def get_data_paths(self, split: str):
86
- """
87
- Retrieve all video file paths for the given split.
88
-
89
- Args:
90
- split (str): Dataset split ("training" or "validation").
91
-
92
- Returns:
93
- List[Path]: List of video file paths.
94
- """
95
- data_dir = self.save_dir / split
96
- paths = sorted(list(data_dir.glob("**/*.mp4")), key=lambda x: x.name)
97
- if not paths:
98
- sub_dirs = os.listdir(data_dir)
99
- for sub_dir in sub_dirs:
100
- sub_path = data_dir / sub_dir
101
- paths += sorted(list(sub_path.glob("**/*.mp4")), key=lambda x: x.name)
102
- return paths
103
-
104
- def download_dataset(self):
105
- pass
106
-
107
- def __getitem__(self, idx: int):
108
- """
109
- Retrieve a single data sample by index.
110
-
111
- Args:
112
- idx (int): Index of the data sample.
113
-
114
- Returns:
115
- Tuple[torch.Tensor, torch.Tensor, np.ndarray, np.ndarray]: Video, actions, poses, and timesteps.
116
- """
117
- max_retries = 1000
118
- for _ in range(max_retries):
119
- try:
120
- return self.load_data(idx)
121
- except Exception as e:
122
- print(f"Retrying due to error: {e}")
123
- idx = (idx + 1) % len(self)
124
-
125
- def load_data(self, idx):
126
- idx = self.idx_remap[idx]
127
- file_idx, frame_idx = self.split_idx(idx)
128
- action_path = self.data_paths[file_idx]
129
- video_path = self.data_paths[file_idx]
130
-
131
- action_path = video_path.with_suffix(".npz")
132
- actions_pool = np.load(action_path)['actions']
133
- poses_pool = np.load(action_path)['poses']
134
-
135
-
136
- poses_pool[0,1] = poses_pool[1,1] # the first recorded pose entry is wrong; copy it from the second frame
137
-
138
- assert poses_pool[:,1].max() - poses_pool[:,1].min() < 2, f"wrong~~~~{poses_pool[:,1].max() - poses_pool[:,1].min()}-{video_path}"
139
-
140
-
141
- if len(poses_pool) < len(actions_pool):
142
- poses_pool = np.pad(poses_pool, ((1, 0), (0, 0)))
143
-
144
- actions_pool = convert_action_space(actions_pool)
145
- video_raw = EncodedVideo.from_path(video_path, decode_audio=False)
146
-
147
- frame_idx = frame_idx + 100 # skip the first 100 frames; the earliest frames are not useful
148
-
149
- if self.split == "validation":
150
- frame_idx = 240
151
-
152
- if self.sample_more_place and self.split == "training":
153
- if random.uniform(0, 1) > 0.5:
154
- place_mask = (actions_pool[:,24]==1)
155
- place_mask[:100] = 0
156
- valid_indices = np.where(place_mask)[0]
157
- random_index = np.random.choice(valid_indices)
158
- frame_idx = random_index - random.randint(1, self.n_frames-1)
159
-
160
- total_frame = video_raw.duration.numerator
161
- fps = 10 # video_raw.duration.denominator
162
- total_frame = total_frame * fps / video_raw.duration.denominator
163
- video = video_raw.get_clip(start_sec=frame_idx/fps, end_sec=(frame_idx+self.n_frames)/fps)["video"]
164
- video = video.permute(1, 2, 3, 0).numpy()
165
-
166
- if self.split != "validation" and 'degrees' in np.load(action_path).keys():
167
- degrees = np.load(action_path)['degrees']
168
- actions_pool[:,16] *= degrees
169
-
170
- actions = np.copy(actions_pool[frame_idx : frame_idx + self.n_frames])
171
-
172
- poses = np.copy(poses_pool[frame_idx : frame_idx + self.n_frames])
173
- pad_len = self.n_frames - len(video)
174
- poses_pool[:,:3] -= poses[:1,:3]
175
- poses_pool[:,-1] = -poses_pool[:,-1]
176
- poses_pool[:,3:] %= 360
177
-
178
- poses[:,:3] -= poses[:1,:3] # do not normalize angle
179
- poses[:,-1] = -poses[:,-1]
180
- poses[:,3:] %= 360
181
-
182
- assert len(video) >= self.n_frames, f"{video_path}"
183
-
184
- if self.split == "training" and self.condition_similar_length>0:
185
- if random.uniform(0, 1) > self.training_dropout:
186
- refer_frame_dis = poses[:,None] - poses_pool[None,:]
187
- refer_frame_dis = np.abs(refer_frame_dis)
188
- refer_frame_dis[...,3:][refer_frame_dis[...,3:] > 180] = 360 - refer_frame_dis[...,3:][refer_frame_dis[...,3:] > 180]
189
- valid_index = ((((refer_frame_dis[..., :3] <= self.pos_range).sum(-1))>=3) & (((refer_frame_dis[..., 3:] <= self.angle_range).sum(-1))>=2) & \
190
- ((((refer_frame_dis[..., :3] > 0).sum(-1))>=1) | (((refer_frame_dis[..., 3:] > 0).sum(-1))>=1))
191
- ).sum(0)
192
- valid_index[:100] = 0 # mute bad initial scene
193
-
194
- if self.add_frame_timestep_embedder and self.causal_frame and (actions_pool[:frame_idx,24]==1).sum() > 0:
195
- valid_index[frame_idx:] = 0
196
-
197
- mask = valid_index >= 1
198
- mask[0] = False
199
- candidate_indices = np.argwhere(mask)
200
-
201
- mask2 = valid_index >= 0
202
- mask2[0] = False
203
-
204
- count = min(self.condition_similar_length, candidate_indices.shape[0])
205
- selected_indices = candidate_indices[np.random.choice(candidate_indices.shape[0], count, replace=True)][:,0]
206
-
207
- if count < self.condition_similar_length:
208
- candidate_indices2 = np.argwhere(mask2)
209
- selected_indices2 = candidate_indices2[np.random.choice(candidate_indices2.shape[0], self.condition_similar_length-count, replace=True)][:,0]
210
- selected_indices = np.concatenate([selected_indices, selected_indices2])
211
-
212
- if self.sample_more_event:
213
- if random.uniform(0, 1) > 0.3:
214
- valid_idx = torch.nonzero(actions_pool[:frame_idx,24]==1)[:,0]
215
- if len(valid_idx) > self.condition_similar_length //2:
216
- valid_idx = valid_idx[-self.condition_similar_length //2:]
217
-
218
- if len(valid_idx) > 0:
219
- selected_indices[-len(valid_idx):] = valid_idx + 4
220
-
221
- else:
222
- selected_indices = np.array(list(range(self.condition_similar_length))) * 0 + random.randint(0, frame_idx)
223
-
224
- video_pool = []
225
- for si in selected_indices:
226
- video_pool.append(video_raw.get_clip(start_sec=si/fps, end_sec=(si+1)/fps)["video"][:,0].permute(1,2,0))
227
-
228
- video_pool = np.stack(video_pool)
229
- video = np.concatenate([video, video_pool])
230
- actions = np.concatenate([actions, actions_pool[selected_indices]])
231
- poses = np.concatenate([poses, poses_pool[selected_indices]])
232
-
233
- timestep = np.concatenate([np.array(list(range(frame_idx, frame_idx + self.n_frames))), selected_indices])
234
-
235
- else:
236
- timestep = np.array(list(range(self.n_frames)))
237
-
238
- video = torch.from_numpy(video / 255.0).float().permute(0, 3, 1, 2).contiguous()
239
-
240
- if self.split == "validation" and not self.customized_validation:
241
- num_frame = actions.shape[0]
242
-
243
- actions[:] = 0
244
- actions[:,16] = 1
245
- poses[:] = 0
246
- for ff in range(1, num_frame):
247
- poses[ff,4] = poses[ff-1,4] + actions[ff,16] * -15
248
-
249
- if self.within_context:
250
- actions[:] = 0
251
- actions[:self.n_frames//2+1,16] = 1
252
- actions[self.n_frames//2+1:,16] = -1
253
- poses[:] = 0
254
- for ff in range(1, num_frame):
255
- poses[ff,4] = poses[ff-1,4] + actions[ff,16] * -15
256
-
257
- return (
258
- video[:: self.frame_skip],
259
- actions[:: self.frame_skip],
260
- poses[:: self.frame_skip],
261
- timestep
262
- )
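`convert_action_space` in the deleted `minecraft_video_dataset.py` above expands the compact recorded action array into the 25-dimensional `ACTION_KEYS` vector. The self-contained sketch below restates just three of its rules for illustration; it is a reduced demo, not the full mapping.

```python
import torch

ACTION_DIM = 25  # len(ACTION_KEYS) above

def convert_action_space_demo(actions: torch.Tensor) -> torch.Tensor:
    """Reduced restatement of convert_action_space, kept to three rules."""
    vec = torch.zeros(len(actions), ACTION_DIM)
    vec[actions[:, 0] == 1, 11] = 1    # compact column 0 == 1  -> "forward"
    vec[actions[:, 4] == 13, 16] = 1   # compact column 4 == 13 -> "cameraX" +1
    vec[actions[:, 4] == 11, 16] = -1  # compact column 4 == 11 -> "cameraX" -1
    return vec

compact = torch.zeros(2, 8, dtype=torch.long)
compact[0, 0], compact[0, 4] = 1, 13
print(convert_action_space_demo(compact)[0, [11, 16]])  # tensor([1., 1.])
```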
 
datasets/video/minecraft_video_dataset_oasis_filter.py DELETED
@@ -1,99 +0,0 @@
1
- import torch
2
- from typing import Sequence
3
- import numpy as np
4
- import io
5
- from omegaconf import DictConfig
6
- from tqdm import tqdm
7
-
8
- from typing import Mapping, Sequence
9
- import os
10
- import math
11
- from packaging import version as pver
12
- from PIL import Image
13
- import random
14
- import shutil
15
- import os
16
- from pathlib import Path
17
- import traceback
18
-
19
- class OASISMinecraftVideoFilterDataset(torch.utils.data.Dataset):
20
- """
21
- Minecraft dataset
22
- """
23
-
24
- def __init__(self, source_dir, target_dir, split):
25
- self.source_dir = Path(source_dir)
26
- self.split_dir = self.source_dir / f"{split}"
27
- self.data_paths = self.get_data_paths(split)
28
- self.target_dir = Path(target_dir) / f"{split}"
29
- self.target_dir.mkdir(exist_ok=True, parents=True)
30
-
31
- def get_data_paths(self, split):
32
- data_dir = self.source_dir / split
33
- paths = sorted(list(data_dir.glob("**/*.mp4")), key=lambda x: x.name)
34
-
35
- if len(paths) == 0:
36
- sub_path = os.listdir(data_dir)
37
- for sp in sub_path:
38
- data_dir = self.source_dir / split / sp
39
- paths = paths+sorted(list(data_dir.glob("**/*.mp4")), key=lambda x: x.name)
40
- return paths
41
-
42
- def __len__(self):
43
- return len(self.data_paths)
44
-
45
- def __getitem__(self, idx):
46
-
47
- return self.sub_get(idx)
48
- # try:
49
- # return self.sub_get(idx)
50
- # except Exception as e:
51
- # traceback.print_exc()
52
- # # return self.sub_get(0)
53
-
54
-
55
- def sub_get(self, idx):
56
- action_path = self.data_paths[idx]
57
- video_path = self.data_paths[idx]
58
-
59
- action_path = video_path.with_suffix(".npz")
60
- actions_pool = np.load(action_path)['actions']
61
- poses_pool = np.load(action_path)['poses']
62
-
63
- poses_pool[0,1] = poses_pool[1,1] # the first recorded pose entry is wrong; copy it from the second frame
64
-
65
- print(poses_pool.shape)
66
-
67
- if poses_pool[:,1].max() - poses_pool[:,1].min() < 2:
68
- target_action_path = self.target_dir / action_path.parent.name / action_path.name
69
- target_video_path = self.target_dir / video_path.parent.name / video_path.name
70
- target_action_path.parent.mkdir(exist_ok=True, parents=True)
71
- target_video_path.parent.mkdir(exist_ok=True, parents=True)
72
-
73
- try:
74
- shutil.copy2(action_path, target_action_path)
75
- shutil.copy2(video_path, target_video_path)
76
- except:
77
- import pdb;pdb.set_trace()
78
-
79
- return poses_pool[:10]
80
-
81
-
82
-
83
- if __name__ == "__main__":
84
- import torch
85
- from unittest.mock import MagicMock
86
- import tqdm
87
-
88
- cfg = MagicMock()
89
- cfg.resolution = 64
90
- cfg.external_cond_dim = 0
91
- cfg.n_frames = 64
92
- cfg.save_dir = "data/minecraft"
93
- cfg.validation_multiplier = 1
94
-
95
- dataset = MinecraftVideoDataset(cfg, "training")
96
- dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=16)
97
-
98
- for batch in tqdm.tqdm(dataloader):
99
- pass
 
datasets/video/minecraft_video_dataset_pose.py DELETED
@@ -1,421 +0,0 @@
1
- import torch
2
- from typing import Sequence
3
- import numpy as np
4
- import io
5
- import tarfile
6
- from pytorchvideo.data.encoded_video import EncodedVideo
7
- from omegaconf import DictConfig
8
- from tqdm import tqdm
9
-
10
- from .base_video_dataset import BaseVideoDataset
11
- from typing import Mapping, Sequence
12
- import os
13
- import math
14
- from packaging import version as pver
15
- from PIL import Image
16
- import random
17
-
18
- def euler_to_rotation_matrix(pitch, yaw):
19
- """
20
- Convert euler angles (pitch, yaw) to a 3x3 rotation matrix.
21
- pitch: rotation around x-axis (in radians)
22
- yaw: rotation around y-axis (in radians)
23
- """
24
- # Rotation matrix around x-axis (pitch)
25
- R_x = np.array([
26
- [1, 0, 0],
27
- [0, math.cos(pitch), -math.sin(pitch)],
28
- [0, math.sin(pitch), math.cos(pitch)]
29
- ])
30
-
31
- # Rotation matrix around y-axis (yaw)
32
- R_y = np.array([
33
- [math.cos(yaw), 0, math.sin(yaw)],
34
- [0, 1, 0],
35
- [-math.sin(yaw), 0, math.cos(yaw)]
36
- ])
37
-
38
- # Combined rotation matrix
39
- R = np.dot(R_y, R_x)
40
- return R
41
-
42
- def custom_meshgrid(*args):
43
- # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
44
- if pver.parse(torch.__version__) < pver.parse('1.10'):
45
- return torch.meshgrid(*args)
46
- else:
47
- return torch.meshgrid(*args, indexing='ij')
48
-
49
- def camera_to_world_to_world_to_camera(camera_to_world):
50
- """
51
- Convert Camera-to-World matrix to World-to-Camera matrix by inverting the transformation.
52
- """
53
- # Extract rotation (R) and translation (T)
54
- R = camera_to_world[:3, :3]
55
- T = camera_to_world[:3, 3]
56
-
57
- # Calculate World-to-Camera (inverse) matrix
58
- world_to_camera = np.eye(4)
59
-
60
- # The rotation part of World-to-Camera is the transpose of Camera-to-World's rotation
61
- world_to_camera[:3, :3] = R.T
62
-
63
- # The translation part is the negative of the rotated translation
64
- world_to_camera[:3, 3] = -np.dot(R.T, T)
65
-
66
- return world_to_camera
67
-
68
- def euler_to_camera_to_world_matrix(pose):
69
-
70
- x, y, z, pitch, yaw = pose
71
- # Convert pitch and yaw to radians
72
- pitch = math.radians(pitch)
73
- yaw = math.radians(yaw)
74
-
75
- # Get the rotation matrix from Euler angles
76
- R = euler_to_rotation_matrix(pitch, yaw)
77
-
78
- # Create the 4x4 transformation matrix (rotation + translation)
79
- camera_to_world = np.eye(4)
80
-
81
- # Set the rotation part (upper 3x3)
82
- camera_to_world[:3, :3] = R
83
-
84
- # Set the translation part (last column)
85
- camera_to_world[:3, 3] = [x, y, z]
86
-
87
- return camera_to_world
88
-
89
- def tensor_to_gif(tensor, output_path, fps=10):
90
- """
91
- Converts a PyTorch tensor of shape (F, 3, H, W) to a GIF.
92
-
93
- Args:
94
- tensor (torch.Tensor): Input tensor of shape (F, 3, H, W) with values in range [0, 1] or [0, 255].
95
- output_path (str): Path to save the output GIF.
96
- fps (int): Frames per second for the GIF.
97
- """
98
- # Ensure the tensor is in [0, 255] range
99
- if tensor.max() <= 1.0:
100
- tensor = (tensor * 255).byte()
101
- else:
102
- tensor = tensor.byte()
103
-
104
- # Convert tensor to numpy array and rearrange to (F, H, W, 3)
105
- frames = tensor.permute(0, 2, 3, 1).cpu().numpy()
106
-
107
- # Convert frames to PIL Images
108
- pil_frames = [Image.fromarray(frame) for frame in frames]
109
-
110
- # Save as GIF
111
- pil_frames[0].save(
112
- output_path,
113
- save_all=True,
114
- append_images=pil_frames[1:],
115
- duration=int(1000 / fps),
116
- loop=0
117
- )
118
-
119
- def get_relative_pose(cam_params, zero_first_frame_scale):
120
- abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
121
- abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
122
- source_cam_c2w = abs_c2ws[0]
123
- if zero_first_frame_scale:
124
- cam_to_origin = 0
125
- else:
126
- cam_to_origin = np.linalg.norm(source_cam_c2w[:3, 3])
127
- target_cam_c2w = np.array([
128
- [1, 0, 0, 0],
129
- [0, 1, 0, -cam_to_origin],
130
- [0, 0, 1, 0],
131
- [0, 0, 0, 1]
132
- ])
133
- abs2rel = target_cam_c2w @ abs_w2cs[0]
134
- ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
135
- ret_poses = np.array(ret_poses, dtype=np.float32)
136
- return ret_poses
137
-
138
- def ray_condition(K, c2w, H, W, device):
139
- # c2w: B, V, 4, 4
140
- # K: B, V, 4
141
-
142
- B = K.shape[0]
143
-
144
- j, i = custom_meshgrid(
145
- torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
146
- torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
147
- )
148
- i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
149
- j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
150
-
151
- fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
152
-
153
- zs = torch.ones_like(i) # [B, HxW]
154
- xs = (i - cx) / fx * zs
155
- ys = (j - cy) / fy * zs
156
- zs = zs.expand_as(ys)
157
-
158
- directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
159
- directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
160
-
161
- rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
162
- rays_o = c2w[..., :3, 3] # B, V, 3
163
- rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
164
- # c2w @ dirctions
165
- rays_dxo = torch.linalg.cross(rays_o, rays_d)
166
- plucker = torch.cat([rays_dxo, rays_d], dim=-1)
167
- plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
168
-
169
- return plucker
170
-
171
- class Camera(object):
172
- def __init__(self, entry, focal_length=0.35):
173
- self.fx = focal_length # 0.35 correspond to 110 fov
174
- self.fy = focal_length*640/360
175
- self.cx = 0.5
176
- self.cy = 0.5
177
- self.c2w_mat = euler_to_camera_to_world_matrix(entry)
178
- self.w2c_mat = camera_to_world_to_world_to_camera(np.copy(self.c2w_mat))
179
-
180
-
181
- ACTION_KEYS = [
182
- "inventory",
183
- "ESC",
184
- "hotbar.1",
185
- "hotbar.2",
186
- "hotbar.3",
187
- "hotbar.4",
188
- "hotbar.5",
189
- "hotbar.6",
190
- "hotbar.7",
191
- "hotbar.8",
192
- "hotbar.9",
193
- "forward",
194
- "back",
195
- "left",
196
- "right",
197
- "cameraY",
198
- "cameraX",
199
- "jump",
200
- "sneak",
201
- "sprint",
202
- "swapHands",
203
- "attack",
204
- "use",
205
- "pickItem",
206
- "drop",
207
- ]
208
-
209
- def one_hot_actions(actions: Sequence[Mapping[str, int]]) -> torch.Tensor:
210
- actions_one_hot = torch.zeros(len(actions), len(ACTION_KEYS))
211
- for i, current_actions in enumerate(actions):
212
- for j, action_key in enumerate(ACTION_KEYS):
213
- if action_key.startswith("camera"):
214
- if action_key == "cameraX":
215
- value = current_actions["camera"][0]
216
- elif action_key == "cameraY":
217
- value = current_actions["camera"][1]
218
- else:
219
- raise ValueError(f"Unknown camera action key: {action_key}")
220
- max_val = 20
221
- bin_size = 0.5
222
- num_buckets = int(max_val / bin_size)
223
- value = (value - num_buckets) / num_buckets
224
- assert -1 - 1e-3 <= value <= 1 + 1e-3, f"Camera action value must be in [-1, 1], got {value}"
225
- else:
226
- value = current_actions[action_key]
227
- assert 0 <= value <= 1, f"Action value must be in [0, 1] got {value}"
228
- actions_one_hot[i, j] = value
229
-
230
- return actions_one_hot
231
-
232
- def simpletomulti(actions):
233
- vec_25 = torch.zeros(len(actions), len(ACTION_KEYS))
234
- vec_25[actions==1, 11] = 1
235
- vec_25[actions==2, 16] = -1
236
- vec_25[actions==3, 16] = 1
237
- vec_25[actions==4, 15] = -1
238
- vec_25[actions==5, 15] = 1
239
- return vec_25
240
-
241
- def simpletomulti2(actions):
242
- vec_25 = torch.zeros(len(actions), len(ACTION_KEYS))
243
- vec_25[actions[:,0]==1, 11] = 1
244
- vec_25[actions[:,0]==2, 12] = 1
245
- vec_25[actions[:,4]==11, 16] = -1
246
- vec_25[actions[:,4]==13, 16] = 1
247
- vec_25[actions[:,3]==11, 15] = -1
248
- vec_25[actions[:,3]==13, 15] = 1
249
- vec_25[actions[:,5]==6, 24] = 1
250
- vec_25[actions[:,5]==1, 24] = 1
251
- vec_25[actions[:,1]==1, 13] = 1
252
- vec_25[actions[:,1]==2, 14] = 1
253
- vec_25[actions[:,7]==1, 2] = 1
254
- return vec_25
255
-
256
- class MinecraftVideoPoseDataset(BaseVideoDataset):
257
- """
258
- Minecraft dataset
259
- """
260
-
261
- def __init__(self, cfg: DictConfig, split: str = "training"):
262
- if split == "test":
263
- split = "validation"
264
- super().__init__(cfg, split)
265
-
266
- if hasattr(cfg, "n_frames_valid") and split == "validation":
267
- self.n_frames = cfg.n_frames_valid
268
-
269
- def get_data_paths(self, split):
270
- data_dir = self.save_dir / split
271
- paths = sorted(list(data_dir.glob("**/*.mp4")), key=lambda x: x.name)
272
-
273
- if len(paths) == 0:
274
- sub_path = os.listdir(data_dir)
275
- for sp in sub_path:
276
- data_dir = self.save_dir / split / sp
277
- paths = paths+sorted(list(data_dir.glob("**/*.mp4")), key=lambda x: x.name)
278
- return paths
279
-
280
- def get_data_lengths(self, split):
281
- lengths = [300] * len(self.get_data_paths(split))
282
- return lengths
283
-
284
- def download_dataset(self) -> Sequence[int]:
285
- from internetarchive import download
286
-
287
- part_suffixes = [
288
- "aa",
289
- "ab",
290
- "ac",
291
- "ad",
292
- "ae",
293
- "af",
294
- "ag",
295
- "ah",
296
- "ai",
297
- "aj",
298
- "ak",
299
- ]
300
- for part_suffix in part_suffixes:
301
- identifier = f"minecraft_marsh_dataset_{part_suffix}"
302
- file_name = f"minecraft.tar.part{part_suffix}"
303
- download(identifier, file_name, destdir=self.save_dir, verbose=True)
304
-
305
- combined_bytes = io.BytesIO()
306
- for part_suffix in part_suffixes:
307
- identifier = f"minecraft_marsh_dataset_{part_suffix}"
308
- file_name = f"minecraft.tar.part{part_suffix}"
309
- part_file = self.save_dir / identifier / file_name
310
- with open(part_file, "rb") as part:
311
- combined_bytes.write(part.read())
312
- combined_bytes.seek(0)
313
- with tarfile.open(fileobj=combined_bytes, mode="r") as combined_archive:
314
- combined_archive.extractall(self.save_dir)
315
- (self.save_dir / "minecraft/test").rename(self.save_dir / "validation")
316
- (self.save_dir / "minecraft/train").rename(self.save_dir / "training")
317
- (self.save_dir / "minecraft").rmdir()
318
- for part_suffix in part_suffixes:
319
- identifier = f"minecraft_marsh_dataset_{part_suffix}"
320
- file_name = f"minecraft.tar.part{part_suffix}"
321
- part_file = self.save_dir / identifier / file_name
322
- part_file.rmdir()
323
-
324
- def __getitem__(self, idx):
325
- # return self.load_data(idx)
326
-
327
- max_retries = 1000
328
- for mr in range(max_retries):
329
- try:
330
- return self.load_data(idx)
331
- except Exception as e:
332
- print(f"{mr} Error: {e}")
333
- # idx = self.idx_remap[idx]
334
- # file_idx, frame_idx = self.split_idx(idx)
335
- # video_path = self.data_paths[file_idx]
336
- # os.remove(video_path)
337
- idx = (idx + 1) % self.__len__()
338
-
339
- def load_data(self, idx):
340
- idx = self.idx_remap[idx]
341
- file_idx, frame_idx = self.split_idx(idx)
342
- action_path = self.data_paths[file_idx]
343
- video_path = self.data_paths[file_idx]
344
-
345
- action_path = video_path.with_suffix(".npz")
346
- actions_pool = np.load(action_path)['actions']
347
- poses_pool = np.load(action_path)['poses']
348
-
349
- poses_pool[0,1] = poses_pool[1,1] # the first recorded pose entry is wrong; copy it from the second frame
350
-
351
- assert poses_pool[:,1].max() - poses_pool[:,1].min() < 2, f"wrong~~~~{poses_pool[:,1].max() - poses_pool[:,1].min()}-{video_path}"
352
-
353
- if len(poses_pool) < len(actions_pool):
354
- poses_pool = np.pad(poses_pool, ((1, 0), (0, 0)))
355
-
356
- actions_pool = simpletomulti2(actions_pool)
357
- video_raw = EncodedVideo.from_path(video_path, decode_audio=False)
358
-
359
- frame_idx = frame_idx + 100 # skip the first 100 frames; the earliest frames are not useful
360
-
361
- if self.split == "validation":
362
- frame_idx = 240
363
-
364
- total_frame = video_raw.duration.numerator
365
- fps = 10 # video_raw.duration.denominator
366
- total_frame = total_frame * fps / video_raw.duration.denominator
367
- video = video_raw.get_clip(start_sec=frame_idx/fps, end_sec=(frame_idx+self.n_frames)/fps)["video"]
368
-
369
- video = video.permute(1, 2, 3, 0).numpy()
370
-
371
- if self.split != "validation" and 'degrees' in np.load(action_path).keys():
372
- degrees = np.load(action_path)['degrees']
373
- actions_pool[:,16] *= degrees
374
-
375
- actions = np.copy(actions_pool[frame_idx : frame_idx + self.n_frames]) # (t, )
376
-
377
- poses = np.copy(poses_pool[frame_idx : frame_idx + self.n_frames])
378
- pad_len = self.n_frames - len(video)
379
- poses_pool[:,:3] -= poses[:1,:3]
380
- # poses_pool[:,3:] = -poses_pool[:,3:]
381
- poses_pool[:,-1] = -poses_pool[:,-1]
382
- poses_pool[:,3:] %= 360
383
-
384
- poses[:,:3] -= poses[:1,:3] # do not normalize angle
385
- # poses[:,3:] = -poses[:,3:]
386
- poses[:,-1] = -poses[:,-1]
387
- poses[:,3:] %= 360
388
-
389
- nonterminal = np.ones(self.n_frames)
390
- if len(video) < self.n_frames:
391
- video = np.pad(video, ((0, pad_len), (0, 0), (0, 0), (0, 0)))
392
- actions = np.pad(actions, ((0, pad_len),))
393
- poses = np.pad(actions, ((0, pad_len),))
394
- nonterminal[-pad_len:] = 0
395
-
396
- video = torch.from_numpy(video / 255.0).float().permute(0, 3, 1, 2).contiguous()
397
-
398
- return (
399
- video[:: self.frame_skip],
400
- actions[:: self.frame_skip],
401
- poses[:: self.frame_skip]
402
- )
403
-
404
-
405
- if __name__ == "__main__":
406
- import torch
407
- from unittest.mock import MagicMock
408
- import tqdm
409
-
410
- cfg = MagicMock()
411
- cfg.resolution = 64
412
- cfg.external_cond_dim = 0
413
- cfg.n_frames = 64
414
- cfg.save_dir = "data/minecraft"
415
- cfg.validation_multiplier = 1
416
-
417
- dataset = MinecraftVideoDataset(cfg, "training")
418
- dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=16)
419
-
420
- for batch in tqdm.tqdm(dataloader):
421
- pass
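The deleted `minecraft_video_dataset_pose.py` above converts `(x, y, z, pitch, yaw)` poses into 4x4 camera-to-world matrices (rotation `R = R_y(yaw) @ R_x(pitch)`, translation in the last column) and inverts them for world-to-camera. A small self-contained check of that convention, with made-up pose values:

```python
import math
import numpy as np

def c2w(x, y, z, pitch_deg, yaw_deg):
    """Same convention as euler_to_camera_to_world_matrix above."""
    p, yaw = math.radians(pitch_deg), math.radians(yaw_deg)
    R_x = np.array([[1, 0, 0],
                    [0, math.cos(p), -math.sin(p)],
                    [0, math.sin(p),  math.cos(p)]])
    R_y = np.array([[ math.cos(yaw), 0, math.sin(yaw)],
                    [0, 1, 0],
                    [-math.sin(yaw), 0, math.cos(yaw)]])
    M = np.eye(4)
    M[:3, :3] = R_y @ R_x
    M[:3, 3] = [x, y, z]
    return M

M = c2w(1.0, 2.0, 3.0, 30.0, 45.0)  # made-up pose
w2c = np.eye(4)                     # inverse, as in camera_to_world_to_world_to_camera
w2c[:3, :3] = M[:3, :3].T
w2c[:3, 3] = -M[:3, :3].T @ M[:3, 3]
print(np.allclose(w2c @ M, np.eye(4)))  # True: the two transforms are inverses
```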
 
experiments/README.md DELETED
@@ -1,19 +0,0 @@
1
- # experiments
2
-
3
- The `experiments` folder contains experiment code. Each file in this folder represents a certain type of
4
- benchmark specific to a project. Such an experiment can be instantiated with a particular dataset and algorithm.
5
-
6
- You should create a new `.py` file for your experiment,
7
- inherit from a suitable base class in `experiments/exp_base.py`,
8
- and then register your new experiment in `experiments/__init__.py`.
9
-
10
- You run an experiment by running `python -m main [options]` in the root directory of the
11
- project. You should not log any data in this folder; store it under `outputs` in the root project
12
- directory.
13
-
14
- This folder is only intended to contain formal experiments. For debug code and unit tests, put them under the `debug` folder.
15
- For scripts that aren't meant to be experiments, please use the `scripts` folder.
16
-
17
- ---
18
-
19
- This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research template [repo](https://github.com/buoyancy99/research-template). By its MIT license, you must keep the above sentence in `README.md` and the `LICENSE` file to credit the author.
 
experiments/__init__.py DELETED
@@ -1,35 +0,0 @@
1
- from typing import Optional, Union
2
- from omegaconf import DictConfig
3
- import pathlib
4
- from lightning.pytorch.loggers.wandb import WandbLogger
5
-
6
- from .exp_base import BaseExperiment
7
- from .exp_video import VideoPredictionExperiment
8
- from .exp_pose import PoseExperiment
9
-
10
- # each key has to be a yaml file under '[project_root]/configurations/experiment' without .yaml suffix
11
- exp_registry = dict(
12
- exp_video=VideoPredictionExperiment,
13
- exp_pose=PoseExperiment
14
- )
15
-
16
-
17
- def build_experiment(
18
- cfg: DictConfig,
19
- logger: Optional[WandbLogger] = None,
20
- ckpt_path: Optional[Union[str, pathlib.Path]] = None,
21
- ) -> BaseExperiment:
22
- """
23
- Build an experiment instance based on registry
24
- :param cfg: configuration file
25
- :param logger: optional logger for the experiment
26
- :param ckpt_path: optional checkpoint path for saving and loading
27
- :return:
28
- """
29
- if cfg.experiment._name not in exp_registry:
30
- raise ValueError(
31
- f"Experiment {cfg.experiment._name} not found in registry {list(exp_registry.keys())}. "
32
- "Make sure you register it correctly in 'experiments/__init__.py' under the same name as yaml file."
33
- )
34
-
35
- return exp_registry[cfg.experiment._name](cfg, logger, ckpt_path)
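`build_experiment` above resolves `cfg.experiment._name` against `exp_registry` and instantiates the matching experiment class. Below is a hedged usage sketch, as the modules existed before this commit; the config path is hypothetical, since in the repo Hydra composes `cfg` at launch time.

```python
from omegaconf import OmegaConf
from experiments import build_experiment  # registry shown above

cfg = OmegaConf.load("outputs/some_run/.hydra/config.yaml")  # hypothetical pre-composed config
experiment = build_experiment(cfg)   # looks up cfg.experiment._name, e.g. "exp_video"
for task in cfg.experiment.tasks:    # e.g. ["training"]
    experiment.exec_task(task)
```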
 
experiments/exp_base.py DELETED
@@ -1,473 +0,0 @@
1
- """
2
- This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research
3
- template [repo](https://github.com/buoyancy99/research-template).
4
- By its MIT license, you must keep the above sentence in `README.md`
5
- and the `LICENSE` file to credit the author.
6
- """
7
-
8
- from abc import ABC, abstractmethod
9
- from typing import Optional, Union, Literal, List, Dict
10
- import pathlib
11
- import os
12
-
13
- import hydra
14
- import torch
15
- from lightning.pytorch.strategies.ddp import DDPStrategy
16
-
17
- import lightning.pytorch as pl
18
- from lightning.pytorch.loggers.wandb import WandbLogger
19
- from lightning.pytorch.utilities.types import TRAIN_DATALOADERS
20
- from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
21
- from pytorch_lightning.utilities import rank_zero_info
22
-
23
- from omegaconf import DictConfig
24
-
25
- from utils.print_utils import cyan
26
- from utils.distributed_utils import is_rank_zero
27
- from safetensors.torch import load_model
28
- from pathlib import Path
29
- from huggingface_hub import hf_hub_download
30
-
31
- torch.set_float32_matmul_precision("high")
32
-
33
- def load_custom_checkpoint(algo, optimizer, checkpoint_path):
34
- if not checkpoint_path:
35
- rank_zero_info("No checkpoint path provided, skipping checkpoint loading.")
36
- return None
37
-
38
- if not isinstance(checkpoint_path, Path):
39
- checkpoint_path = Path(checkpoint_path)
40
-
41
- if "yslan" in str(checkpoint_path):
42
- hf_ckpt = str(checkpoint_path).split('/')
43
- repo_id = '/'.join(hf_ckpt[:2])
44
- file_name = '/'.join(hf_ckpt[2:])
45
- model_path = hf_hub_download(repo_id=repo_id,
46
- filename=file_name)
47
- ckpt = torch.load(model_path, map_location=torch.device('cpu'))
48
- algo.load_state_dict(ckpt['state_dict'], strict=False)
49
-
50
- elif checkpoint_path.suffix == ".pt":
51
- ckpt = torch.load(checkpoint_path, weights_only=True)
52
- algo.load_state_dict(ckpt, strict=False)
53
- elif checkpoint_path.suffix == ".ckpt":
54
- ckpt = torch.load(checkpoint_path, map_location=torch.device('cpu'))
55
- algo.load_state_dict(ckpt['state_dict'], strict=False)
56
- elif checkpoint_path.suffix == ".safetensors":
57
- load_model(algo, checkpoint_path, strict=False)
58
- elif os.path.isdir(checkpoint_path):
59
- ckpt_files = [f for f in os.listdir(checkpoint_path) if f.endswith('.ckpt')]
60
- if not ckpt_files:
61
- raise FileNotFoundError("在指定文件夹中未找到任何 .ckpt 文件!")
62
- selected_ckpt = max(ckpt_files)
63
- selected_ckpt_path = os.path.join(checkpoint_path, selected_ckpt)
64
- print(f"加载的 checkpoint 文件为: {selected_ckpt_path}")
65
-
66
- ckpt = torch.load(selected_ckpt_path, map_location=torch.device('cpu'))
67
- algo.load_state_dict(ckpt['state_dict'], strict=False)
68
-
69
- rank_zero_info("Model weights loaded.")
70
-
71
- class BaseExperiment(ABC):
72
- """
73
- Abstract class for an experiment. This generalizes the pytorch lightning Trainer & lightning Module to more
74
- flexible experiments that don't fit in the typical ML loop, e.g. multi-stage reinforcement learning benchmarks.
75
- """
76
-
77
- # each key has to be a yaml file under '[project_root]/configurations/algorithm' without .yaml suffix
78
- compatible_algorithms: Dict = NotImplementedError
79
-
80
- def __init__(
81
- self,
82
- root_cfg: DictConfig,
83
- logger: Optional[WandbLogger] = None,
84
- ckpt_path: Optional[Union[str, pathlib.Path]] = None,
85
- ) -> None:
86
- """
87
- Constructor
88
-
89
- Args:
90
- root_cfg: configuration file that contains everything about the experiment
91
- logger: a pytorch-lightning WandbLogger instance
92
- ckpt_path: an optional path to saved checkpoint
93
- """
94
- super().__init__()
95
- self.root_cfg = root_cfg
96
- self.cfg = root_cfg.experiment
97
- self.debug = root_cfg.debug
98
- self.logger = logger
99
- self.ckpt_path = ckpt_path
100
- self.algo = None
101
- self.customized_load = self.cfg.customized_load
102
- self.load_vae = self.cfg.load_vae
103
- self.load_t_to_r = self.cfg.load_t_to_r
104
- self.zero_init_gate=self.cfg.zero_init_gate
105
- self.only_tune_refer = self.cfg.only_tune_refer
106
- self.diffusion_path = self.cfg.diffusion_path
107
- self.vae_path = self.cfg.vae_path # "/mnt/xiaozeqi/.cache/huggingface/hub/models--Etched--oasis-500m/snapshots/4ca7d2d811f4f0c6fd1d5719bf83f14af3446c0c/vit-l-20.safetensors"
108
- self.pose_predictor_path = self.cfg.pose_predictor_path # "/mnt/xiaozeqi/diffusionforcing/outputs/2025-03-28/16-45-11/checkpoints/epoch0step595000.ckpt"
109
-
110
- def _build_algo(self):
111
- """
112
- Build the lightning module
113
- :return: a pytorch-lightning module to be launched
114
- """
115
- algo_name = self.root_cfg.algorithm._name
116
- if algo_name not in self.compatible_algorithms:
117
- raise ValueError(
118
- f"Algorithm {algo_name} not found in compatible_algorithms for this Experiment class. "
119
- "Make sure you define compatible_algorithms correctly and make sure that each key has "
120
- "same name as yaml file under '[project_root]/configurations/algorithm' without .yaml suffix"
121
- )
122
- return self.compatible_algorithms[algo_name](self.root_cfg.algorithm)
123
-
124
- def exec_task(self, task: str) -> None:
125
- """
126
- Executing a certain task specified by string. Each task should be a stage of experiment.
127
- In most computer vision / nlp applications, tasks should be just train and test.
128
- In reinforcement learning, you might have more stages, such as collecting a dataset, etc.
129
-
130
- Args:
131
- task: a string specifying a task implemented for this experiment
132
- """
133
- if hasattr(self, task) and callable(getattr(self, task)):
134
- if is_rank_zero:
135
- print(cyan("Executing task:"), f"{task} out of {self.cfg.tasks}")
136
- getattr(self, task)()
137
- else:
138
- raise ValueError(
139
- f"Specified task '{task}' not defined for class {self.__class__.__name__} or is not callable."
140
- )
141
-
142
- def exec_interactive(self, task: str) -> None:
143
- """
144
- Executing a certain task specified by string. Each task should be a stage of experiment.
145
- In most computer vision / nlp applications, tasks should be just train and test.
146
- In reinforcement learning, you might have more stages, such as collecting a dataset, etc.
147
-
148
- Args:
149
- task: a string specifying a task implemented for this experiment
150
- """
151
- if hasattr(self, task) and callable(getattr(self, task)):
152
- if is_rank_zero:
153
- print(cyan("Executing task:"), f"{task} out of {self.cfg.tasks}")
154
- return getattr(self, task)()
155
- else:
156
- raise ValueError(
157
- f"Specified task '{task}' not defined for class {self.__class__.__name__} or is not callable."
158
- )
159
-
160
- class BaseLightningExperiment(BaseExperiment):
161
- """
162
- Abstract class for pytorch lightning experiments. Useful for computer vision & nlp where main components are
163
- simply models, datasets and train loop.
164
- """
165
-
166
- # each key has to be a yaml file under '[project_root]/configurations/algorithm' without .yaml suffix
167
- compatible_algorithms: Dict = NotImplementedError
168
-
169
- # each key has to be a yaml file under '[project_root]/configurations/dataset' without .yaml suffix
170
- compatible_datasets: Dict = NotImplementedError
171
-
172
- def _build_trainer_callbacks(self):
173
- callbacks = []
174
- if self.logger:
175
- callbacks.append(LearningRateMonitor("step", True))
176
-
177
- def _build_training_loader(self) -> Optional[Union[TRAIN_DATALOADERS, pl.LightningDataModule]]:
178
- train_dataset = self._build_dataset("training")
179
- shuffle = (
180
- False if isinstance(train_dataset, torch.utils.data.IterableDataset) else self.cfg.training.data.shuffle
181
- )
182
- if train_dataset:
183
- return torch.utils.data.DataLoader(
184
- train_dataset,
185
- batch_size=self.cfg.training.batch_size,
186
- num_workers=min(os.cpu_count(), self.cfg.training.data.num_workers),
187
- shuffle=shuffle,
188
- persistent_workers=True,
189
- )
190
- else:
191
- return None
192
-
193
- def _build_validation_loader(self) -> Optional[Union[TRAIN_DATALOADERS, pl.LightningDataModule]]:
194
- validation_dataset = self._build_dataset("validation")
195
- shuffle = (
196
- False
197
- if isinstance(validation_dataset, torch.utils.data.IterableDataset)
198
- else self.cfg.validation.data.shuffle
199
- )
200
- if validation_dataset:
201
- return torch.utils.data.DataLoader(
202
- validation_dataset,
203
- batch_size=self.cfg.validation.batch_size,
204
- num_workers=min(os.cpu_count(), self.cfg.validation.data.num_workers),
205
- shuffle=shuffle,
206
- persistent_workers=True,
207
- )
208
- else:
209
- return None
210
-
211
- def _build_test_loader(self) -> Optional[Union[TRAIN_DATALOADERS, pl.LightningDataModule]]:
212
- test_dataset = self._build_dataset("test")
213
- shuffle = False if isinstance(test_dataset, torch.utils.data.IterableDataset) else self.cfg.test.data.shuffle
214
- if test_dataset:
215
- return torch.utils.data.DataLoader(
216
- test_dataset,
217
- batch_size=self.cfg.test.batch_size,
218
- num_workers=min(os.cpu_count(), self.cfg.test.data.num_workers),
219
- shuffle=shuffle,
220
- persistent_workers=True,
221
- )
222
- else:
223
- return None
224
-
225
- def training(self) -> None:
226
- """
227
- All training happens here
228
- """
229
- if not self.algo:
230
- self.algo = self._build_algo()
231
- if self.cfg.training.compile:
232
- self.algo = torch.compile(self.algo)
233
-
234
- callbacks = []
235
- if self.logger:
236
- callbacks.append(LearningRateMonitor("step", True))
237
- if "checkpointing" in self.cfg.training:
238
- callbacks.append(
239
- ModelCheckpoint(
240
- pathlib.Path(hydra.core.hydra_config.HydraConfig.get()["runtime"]["output_dir"]) / "checkpoints",
241
- **self.cfg.training.checkpointing,
242
- )
243
- )
244
-
245
- # TODO do not upload checkpoint to wandb
246
-
247
- # trainer = pl.Trainer(
248
- # accelerator="auto",
249
- # logger=self.logger if self.logger else False,
250
- # devices=torch.cuda.device_count(),
251
- # num_nodes=self.cfg.num_nodes,
252
- # strategy=DDPStrategy(find_unused_parameters=True) if torch.cuda.device_count() > 1 else "auto",
253
- # callbacks=callbacks,
254
- # gradient_clip_val=self.cfg.training.optim.gradient_clip_val,
255
- # val_check_interval=self.cfg.validation.val_every_n_step,
256
- # limit_val_batches=self.cfg.validation.limit_batch,
257
- # check_val_every_n_epoch=self.cfg.validation.val_every_n_epoch,
258
- # accumulate_grad_batches=self.cfg.training.optim.accumulate_grad_batches,
259
- # precision=self.cfg.training.precision,
260
- # detect_anomaly=False, # self.cfg.debug,
261
- # num_sanity_val_steps=int(self.cfg.debug),
262
- # max_epochs=self.cfg.training.max_epochs,
263
- # max_steps=self.cfg.training.max_steps,
264
- # max_time=self.cfg.training.max_time,
265
- # )
266
-
267
- trainer = pl.Trainer(
268
- accelerator="auto",
269
- devices="auto", # 自动选择设备
270
- strategy=DDPStrategy(find_unused_parameters=True) if torch.cuda.device_count() > 1 else "auto",
271
- logger=self.logger or False, # shorthand: disable logging when no logger is given
272
- callbacks=callbacks,
273
- gradient_clip_val=self.cfg.training.optim.gradient_clip_val or 0.0, # fall back to 0.0 by default
274
- val_check_interval=self.cfg.validation.val_every_n_step if self.cfg.validation.val_every_n_step else None,
275
- limit_val_batches=self.cfg.validation.limit_batch,
276
- check_val_every_n_epoch=self.cfg.validation.val_every_n_epoch if not self.cfg.validation.val_every_n_step else None,
277
- accumulate_grad_batches=self.cfg.training.optim.accumulate_grad_batches or 1, # default: accumulate over 1 batch
278
- precision=self.cfg.training.precision or 32, # default to 32-bit precision
279
- detect_anomaly=False, # anomaly detection off by default
280
- num_sanity_val_steps=int(self.cfg.debug) if self.cfg.debug else 0,
281
- max_epochs=self.cfg.training.max_epochs,
282
- max_steps=self.cfg.training.max_steps,
283
- max_time=self.cfg.training.max_time
284
- )
285
-
286
-
287
- if self.customized_load:
288
- if self.load_vae:
289
- load_custom_checkpoint(algo=self.algo.diffusion_model.model,optimizer=None,checkpoint_path=self.ckpt_path)
290
- load_custom_checkpoint(algo=self.algo.vae,optimizer=None,checkpoint_path=self.vae_path)
291
- else:
292
- load_custom_checkpoint(algo=self.algo,optimizer=None,checkpoint_path=self.ckpt_path)
293
-
294
- if self.load_t_to_r:
295
- param_list = []
296
- for name, para in self.algo.diffusion_model.named_parameters():
297
- if 't_' in name and 't_embedder' not in name:
298
- print(name)
299
- param_list.append(para)
300
-
301
- it = 0
302
- for name, para in self.algo.diffusion_model.named_parameters():
303
- if 'r_' in name:
304
- para.requires_grad_(False)
305
- try:
306
- para.copy_(param_list[it].detach().cpu())
307
- except:
308
- import pdb;pdb.set_trace()
309
- para.requires_grad_(True)
310
- it += 1
311
-
312
- if self.zero_init_gate:
313
- for name, para in self.algo.diffusion_model.named_parameters():
314
- if 'r_adaLN_modulation' in name:
315
- para.requires_grad_(False)
316
- para[2*1024:3*1024] = 0
317
- para[5*1024:6*1024] = 0
318
- para.requires_grad_(True)
319
-
320
- if self.only_tune_refer:
321
- for name, para in self.algo.diffusion_model.named_parameters():
322
- para.requires_grad_(False)
323
- if 'r_' in name or 'pose_embedder' in name or 'pose_cond_mlp' in name or 'lora_' in name:
324
- para.requires_grad_(True)
325
-
326
- trainer.fit(
327
- self.algo,
328
- train_dataloaders=self._build_training_loader(),
329
- val_dataloaders=self._build_validation_loader(),
330
- ckpt_path=None,
331
- )
332
- else:
333
-
334
- if self.only_tune_refer:
335
- for name, para in self.algo.diffusion_model.named_parameters():
336
- para.requires_grad_(False)
337
- if 'r_' in name or 'pose_embedder' in name or 'pose_cond_mlp' in name or 'lora_' in name:
338
- para.requires_grad_(True)
339
-
340
- trainer.fit(
341
- self.algo,
342
- train_dataloaders=self._build_training_loader(),
343
- val_dataloaders=self._build_validation_loader(),
344
- ckpt_path=self.ckpt_path,
345
- )
346
-
347
- def validation(self) -> None:
348
- """
349
- All validation happens here
350
- """
351
- if not self.algo:
352
- self.algo = self._build_algo()
353
- if self.cfg.validation.compile:
354
- self.algo = torch.compile(self.algo)
355
-
356
- callbacks = []
357
-
358
- trainer = pl.Trainer(
359
- accelerator="auto",
360
- logger=self.logger,
361
- devices="auto",
362
- num_nodes=self.cfg.num_nodes,
363
- strategy=DDPStrategy(find_unused_parameters=False) if torch.cuda.device_count() > 1 else "auto",
364
- callbacks=callbacks,
365
- # limit_val_batches=self.cfg.validation.limit_batch,
366
- limit_val_batches=self.cfg.validation.limit_batch,
367
- precision=self.cfg.validation.precision,
368
- detect_anomaly=False, # self.cfg.debug,
369
- inference_mode=self.cfg.validation.inference_mode,
370
- )
371
-
372
- if self.customized_load:
373
-
374
- if self.load_vae:
375
- load_custom_checkpoint(algo=self.algo.diffusion_model.model,optimizer=None,checkpoint_path=self.ckpt_path)
376
- load_custom_checkpoint(algo=self.algo.vae,optimizer=None,checkpoint_path=self.vae_path)
377
- else:
378
- load_custom_checkpoint(algo=self.algo,optimizer=None,checkpoint_path=self.ckpt_path)
379
-
380
- if self.load_t_to_r:
381
- param_list = []
382
- for name, para in self.algo.diffusion_model.named_parameters():
383
- if 't_' in name and 't_embedder' not in name:
384
- print(name)
385
- param_list.append(para)
386
-
387
- it = 0
388
- for name, para in self.algo.diffusion_model.named_parameters():
389
- if 'r_' in name:
390
- para.requires_grad_(False)
391
- try:
392
- para.copy_(param_list[it].detach().cpu())
393
- except:
394
- import pdb;pdb.set_trace()
395
- para.requires_grad_(True)
396
- it += 1
397
-
398
- if self.zero_init_gate:
399
- for name, para in self.algo.diffusion_model.named_parameters():
400
- if 'r_adaLN_modulation' in name:
401
- para.requires_grad_(False)
402
- para[2*1024:3*1024] = 0
403
- para[5*1024:6*1024] = 0
404
- para.requires_grad_(True)
405
-
406
- trainer.validate(
407
- self.algo,
408
- dataloaders=self._build_validation_loader(),
409
- ckpt_path=None,
410
- )
411
- else:
412
- trainer.validate(
413
- self.algo,
414
- dataloaders=self._build_validation_loader(),
415
- ckpt_path=self.ckpt_path,
416
- )
417
-
418
- def test(self) -> None:
419
- """
420
- All testing happens here
421
- """
422
- if not self.algo:
423
- self.algo = self._build_algo()
424
- if self.cfg.test.compile:
425
- self.algo = torch.compile(self.algo)
426
-
427
- callbacks = []
428
-
429
- trainer = pl.Trainer(
430
- accelerator="auto",
431
- logger=self.logger,
432
- devices="auto",
433
- num_nodes=self.cfg.num_nodes,
434
- strategy=DDPStrategy(find_unused_parameters=False) if torch.cuda.device_count() > 1 else "auto",
435
- callbacks=callbacks,
436
- limit_test_batches=self.cfg.test.limit_batch,
437
- precision=self.cfg.test.precision,
438
- detect_anomaly=False, # self.cfg.debug,
439
- )
440
-
441
- # Only load the checkpoint if only testing. Otherwise, it will have been loaded
442
- # and further trained during train.
443
- trainer.test(
444
- self.algo,
445
- dataloaders=self._build_test_loader(),
446
- ckpt_path=self.ckpt_path,
447
- )
448
- if not self.algo:
449
- self.algo = self._build_algo()
450
- if self.cfg.validation.compile:
451
- self.algo = torch.compile(self.algo)
452
-
453
-
454
- def interactive(self):
455
-
456
- if not self.algo:
457
- self.algo = self._build_algo()
458
- if self.cfg.validation.compile:
459
- self.algo = torch.compile(self.algo)
460
-
461
- if self.customized_load:
462
- load_custom_checkpoint(algo=self.algo.diffusion_model,optimizer=None,checkpoint_path=self.diffusion_path)
463
- load_custom_checkpoint(algo=self.algo.vae,optimizer=None,checkpoint_path=self.vae_path)
464
- load_custom_checkpoint(algo=self.algo.pose_prediction_model,optimizer=None,checkpoint_path=self.pose_predictor_path)
465
- return self.algo
466
- else:
467
- raise NotImplementedError
468
-
469
- def _build_dataset(self, split: str) -> Optional[torch.utils.data.Dataset]:
470
- if split in ["training", "test", "validation"]:
471
- return self.compatible_datasets[self.root_cfg.dataset._name](self.root_cfg.dataset, split=split)
472
- else:
473
- raise NotImplementedError(f"split '{split}' is not implemented")
 
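Note on the deleted `load_custom_checkpoint` helper above: it dispatches purely on the checkpoint path — a Hugging Face repo path containing `yslan`, a `.pt`, `.ckpt`, or `.safetensors` file, or a directory, in which case it takes `max()` over the `.ckpt` filenames. A minimal usage sketch (the `algo` object and both paths are hypothetical; every branch loads with `strict=False`):

    from pathlib import Path

    # ".ckpt" branch: reads ckpt["state_dict"] on CPU and loads it non-strictly.
    load_custom_checkpoint(algo, optimizer=None,
                           checkpoint_path=Path("outputs/checkpoints/epoch0step1000.ckpt"))

    # ".safetensors" branch: delegates to safetensors.torch.load_model (e.g. the VAE weights).
    load_custom_checkpoint(algo.vae, optimizer=None,
                           checkpoint_path=Path("weights/vit-l-20.safetensors"))
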
experiments/exp_pose.py DELETED
@@ -1,310 +0,0 @@
1
- """
2
- This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research
3
- template [repo](https://github.com/buoyancy99/research-template).
4
- By its MIT license, you must keep the above sentence in `README.md`
5
- and the `LICENSE` file to credit the author.
6
- """
7
-
8
- from abc import ABC, abstractmethod
9
- from typing import Optional, Union, Literal, List, Dict
10
- import pathlib
11
- import os
12
-
13
- import hydra
14
- import torch
15
- from lightning.pytorch.strategies.ddp import DDPStrategy
16
-
17
- import lightning.pytorch as pl
18
- from lightning.pytorch.loggers.wandb import WandbLogger
19
- from lightning.pytorch.utilities.types import TRAIN_DATALOADERS
20
- from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
21
- from pytorch_lightning.utilities import rank_zero_info
22
-
23
- from omegaconf import DictConfig
24
-
25
- from utils.print_utils import cyan
26
- from utils.distributed_utils import is_rank_zero
27
- from safetensors.torch import load_model
28
- from pathlib import Path
29
- from algorithms.worldmem import PosePrediction
30
- from datasets.video import MinecraftVideoPoseDataset
31
-
32
-
33
- torch.set_float32_matmul_precision("high")
34
-
35
- def load_custom_checkpoint(algo, optimizer, checkpoint_path):
36
- if not checkpoint_path:
37
- rank_zero_info("No checkpoint path provided, skipping checkpoint loading.")
38
- return None
39
-
40
- if not isinstance(checkpoint_path, Path):
41
- checkpoint_path = Path(checkpoint_path)
42
-
43
- if checkpoint_path.suffix == ".pt":
44
- ckpt = torch.load(checkpoint_path, weights_only=True)
45
- algo.load_state_dict(ckpt, strict=False)
46
- elif checkpoint_path.suffix == ".ckpt":
47
- ckpt = torch.load(checkpoint_path, map_location=torch.device('cpu'))
48
- algo.load_state_dict(ckpt['state_dict'], strict=False)
49
- elif checkpoint_path.suffix == ".safetensors":
50
- load_model(algo, checkpoint_path, strict=False)
51
- elif os.path.isdir(checkpoint_path):
52
- ckpt_files = [f for f in os.listdir(checkpoint_path) if f.endswith('.ckpt')]
53
- if not ckpt_files:
54
- raise FileNotFoundError("No .ckpt file found in the specified folder!")
55
- selected_ckpt = max(ckpt_files)
56
- selected_ckpt_path = os.path.join(checkpoint_path, selected_ckpt)
57
- print(f"加载的 checkpoint 文件为: {selected_ckpt_path}")
58
-
59
- ckpt = torch.load(selected_ckpt_path, map_location=torch.device('cpu'))
60
- algo.load_state_dict(ckpt['state_dict'], strict=False)
61
-
62
- rank_zero_info("Model weights loaded.")
63
-
64
- class PoseExperiment(ABC):
65
- """
66
- Abstract class for an experiment. This generalizes the pytorch lightning Trainer & lightning Module to more
67
- flexible experiments that don't fit in the typical ML loop, e.g. multi-stage reinforcement learning benchmarks.
68
- """
69
-
70
- # each key has to be a yaml file under '[project_root]/configurations/algorithm' without .yaml suffix
71
- compatible_algorithms = dict(
72
- pose_prediction=PosePrediction
73
- )
74
-
75
- compatible_datasets = dict(
76
- video_minecraft_pose=MinecraftVideoPoseDataset
77
- )
78
-
79
- def __init__(
80
- self,
81
- root_cfg: DictConfig,
82
- logger: Optional[WandbLogger] = None,
83
- ckpt_path: Optional[Union[str, pathlib.Path]] = None,
84
- ) -> None:
85
- """
86
- Constructor
87
-
88
- Args:
89
- root_cfg: configuration file that contains everything about the experiment
90
- logger: a pytorch-lightning WandbLogger instance
91
- ckpt_path: an optional path to saved checkpoint
92
- """
93
- super().__init__()
94
- self.root_cfg = root_cfg
95
- self.cfg = root_cfg.experiment
96
- self.debug = root_cfg.debug
97
- self.logger = logger
98
- self.ckpt_path = ckpt_path
99
- self.algo = None
100
- self.vae_path = "/cpfs01/user/xiaozeqi/.cache/huggingface/hub/models--Etched--oasis-500m/snapshots/4ca7d2d811f4f0c6fd1d5719bf83f14af3446c0c/vit-l-20.safetensors"
101
-
102
- def _build_algo(self):
103
- """
104
- Build the lightning module
105
- :return: a pytorch-lightning module to be launched
106
- """
107
- algo_name = self.root_cfg.algorithm._name
108
- if algo_name not in self.compatible_algorithms:
109
- raise ValueError(
110
- f"Algorithm {algo_name} not found in compatible_algorithms for this Experiment class. "
111
- "Make sure you define compatible_algorithms correctly and make sure that each key has "
112
- "same name as yaml file under '[project_root]/configurations/algorithm' without .yaml suffix"
113
- )
114
- return self.compatible_algorithms[algo_name](self.root_cfg.algorithm)
115
-
116
- def exec_task(self, task: str) -> None:
117
- """
118
- Executing a certain task specified by string. Each task should be a stage of experiment.
119
- In most computer vision / nlp applications, tasks should be just train and test.
120
- In reinforcement learning, you might have more stages, such as collecting a dataset, etc.
121
-
122
- Args:
123
- task: a string specifying a task implemented for this experiment
124
- """
125
- if hasattr(self, task) and callable(getattr(self, task)):
126
- if is_rank_zero:
127
- print(cyan("Executing task:"), f"{task} out of {self.cfg.tasks}")
128
- getattr(self, task)()
129
- else:
130
- raise ValueError(
131
- f"Specified task '{task}' not defined for class {self.__class__.__name__} or is not callable."
132
- )
133
-
134
-
135
- def _build_trainer_callbacks(self):
136
- callbacks = []
137
- if self.logger:
138
- callbacks.append(LearningRateMonitor("step", True))
139
-
140
- def _build_training_loader(self) -> Optional[Union[TRAIN_DATALOADERS, pl.LightningDataModule]]:
141
- train_dataset = self._build_dataset("training")
142
- shuffle = (
143
- False if isinstance(train_dataset, torch.utils.data.IterableDataset) else self.cfg.training.data.shuffle
144
- )
145
- if train_dataset:
146
- return torch.utils.data.DataLoader(
147
- train_dataset,
148
- batch_size=self.cfg.training.batch_size,
149
- num_workers=min(os.cpu_count(), self.cfg.training.data.num_workers),
150
- shuffle=shuffle,
151
- persistent_workers=True,
152
- )
153
- else:
154
- return None
155
-
156
- def _build_validation_loader(self) -> Optional[Union[TRAIN_DATALOADERS, pl.LightningDataModule]]:
157
- validation_dataset = self._build_dataset("validation")
158
- shuffle = (
159
- False
160
- if isinstance(validation_dataset, torch.utils.data.IterableDataset)
161
- else self.cfg.validation.data.shuffle
162
- )
163
- if validation_dataset:
164
- return torch.utils.data.DataLoader(
165
- validation_dataset,
166
- batch_size=self.cfg.validation.batch_size,
167
- num_workers=min(os.cpu_count(), self.cfg.validation.data.num_workers),
168
- shuffle=shuffle,
169
- persistent_workers=True,
170
- )
171
- else:
172
- return None
173
-
174
- def _build_test_loader(self) -> Optional[Union[TRAIN_DATALOADERS, pl.LightningDataModule]]:
175
- test_dataset = self._build_dataset("test")
176
- shuffle = False if isinstance(test_dataset, torch.utils.data.IterableDataset) else self.cfg.test.data.shuffle
177
- if test_dataset:
178
- return torch.utils.data.DataLoader(
179
- test_dataset,
180
- batch_size=self.cfg.test.batch_size,
181
- num_workers=min(os.cpu_count(), self.cfg.test.data.num_workers),
182
- shuffle=shuffle,
183
- persistent_workers=True,
184
- )
185
- else:
186
- return None
187
-
188
- def training(self) -> None:
189
- """
190
- All training happens here
191
- """
192
- if not self.algo:
193
- self.algo = self._build_algo()
194
- if self.cfg.training.compile:
195
- self.algo = torch.compile(self.algo)
196
-
197
- callbacks = []
198
- if self.logger:
199
- callbacks.append(LearningRateMonitor("step", True))
200
- if "checkpointing" in self.cfg.training:
201
- callbacks.append(
202
- ModelCheckpoint(
203
- pathlib.Path(hydra.core.hydra_config.HydraConfig.get()["runtime"]["output_dir"]) / "checkpoints",
204
- **self.cfg.training.checkpointing,
205
- )
206
- )
207
-
208
- trainer = pl.Trainer(
209
- accelerator="auto",
210
- devices="auto", # 自动选择设备
211
- strategy=DDPStrategy(find_unused_parameters=True) if torch.cuda.device_count() > 1 else "auto",
212
- logger=self.logger or False, # shorthand: disable logging when no logger is given
213
- callbacks=callbacks,
214
- gradient_clip_val=self.cfg.training.optim.gradient_clip_val or 0.0, # fall back to 0.0 by default
215
- val_check_interval=self.cfg.validation.val_every_n_step if self.cfg.validation.val_every_n_step else None,
216
- limit_val_batches=self.cfg.validation.limit_batch,
217
- check_val_every_n_epoch=self.cfg.validation.val_every_n_epoch if not self.cfg.validation.val_every_n_step else None,
218
- accumulate_grad_batches=self.cfg.training.optim.accumulate_grad_batches or 1, # default: accumulate over 1 batch
219
- precision=self.cfg.training.precision or 32, # default to 32-bit precision
220
- detect_anomaly=False, # anomaly detection off by default
221
- num_sanity_val_steps=int(self.cfg.debug) if self.cfg.debug else 0,
222
- max_epochs=self.cfg.training.max_epochs,
223
- max_steps=self.cfg.training.max_steps,
224
- max_time=self.cfg.training.max_time
225
- )
226
-
227
- load_custom_checkpoint(algo=self.algo.vae,optimizer=None,checkpoint_path=self.vae_path)
228
-
229
- trainer.fit(
230
- self.algo,
231
- train_dataloaders=self._build_training_loader(),
232
- val_dataloaders=self._build_validation_loader(),
233
- ckpt_path=self.ckpt_path,
234
- )
235
-
236
- def validation(self) -> None:
237
- """
238
- All validation happens here
239
- """
240
- if not self.algo:
241
- self.algo = self._build_algo()
242
- if self.cfg.validation.compile:
243
- self.algo = torch.compile(self.algo)
244
-
245
- callbacks = []
246
-
247
- trainer = pl.Trainer(
248
- accelerator="auto",
249
- logger=self.logger,
250
- devices="auto",
251
- num_nodes=self.cfg.num_nodes,
252
- strategy=DDPStrategy(find_unused_parameters=False) if torch.cuda.device_count() > 1 else "auto",
253
- callbacks=callbacks,
254
- # limit_val_batches=self.cfg.validation.limit_batch,
255
- limit_val_batches=self.cfg.validation.limit_batch,
256
- precision=self.cfg.validation.precision,
257
- detect_anomaly=False, # self.cfg.debug,
258
- inference_mode=self.cfg.validation.inference_mode,
259
- )
260
-
261
- load_custom_checkpoint(algo=self.algo.vae,optimizer=None,checkpoint_path=self.vae_path)
262
-
263
- trainer.validate(
264
- self.algo,
265
- dataloaders=self._build_validation_loader(),
266
- ckpt_path=self.ckpt_path,
267
- )
268
-
269
- def test(self) -> None:
270
- """
271
- All testing happens here
272
- """
273
- if not self.algo:
274
- self.algo = self._build_algo()
275
- if self.cfg.test.compile:
276
- self.algo = torch.compile(self.algo)
277
-
278
- callbacks = []
279
-
280
- trainer = pl.Trainer(
281
- accelerator="auto",
282
- logger=self.logger,
283
- devices="auto",
284
- num_nodes=self.cfg.num_nodes,
285
- strategy=DDPStrategy(find_unused_parameters=False) if torch.cuda.device_count() > 1 else "auto",
286
- callbacks=callbacks,
287
- limit_test_batches=self.cfg.test.limit_batch,
288
- precision=self.cfg.test.precision,
289
- detect_anomaly=False, # self.cfg.debug,
290
- )
291
-
292
- # Only load the checkpoint if only testing. Otherwise, it will have been loaded
293
- # and further trained during train.
294
- trainer.test(
295
- self.algo,
296
- dataloaders=self._build_test_loader(),
297
- ckpt_path=self.ckpt_path,
298
- )
299
- if not self.algo:
300
- self.algo = self._build_algo()
301
- if self.cfg.validation.compile:
302
- self.algo = torch.compile(self.algo)
303
-
304
- def _build_dataset(self, split: str) -> Optional[torch.utils.data.Dataset]:
305
- if split in ["training", "test", "validation"]:
306
- return self.compatible_datasets[self.root_cfg.dataset._name](self.root_cfg.dataset, split=split)
307
- else:
308
- raise NotImplementedError(f"split '{split}' is not implemented")
309
-
310
-
 
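Both experiment classes above build their train/validation/test loaders with the same pattern: shuffling is forced off for `IterableDataset`s and the worker count is capped at the machine's CPU count. A standalone sketch of that pattern (batch size and worker count are illustrative):

    import os
    import torch

    def make_loader(dataset, batch_size=8, requested_workers=16, shuffle=True):
        # IterableDatasets define their own iteration order, so shuffling must stay off.
        if isinstance(dataset, torch.utils.data.IterableDataset):
            shuffle = False
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=min(os.cpu_count(), requested_workers),  # never ask for more workers than cores
            shuffle=shuffle,
            persistent_workers=True,
        )
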
experiments/exp_video.py DELETED
@@ -1,25 +0,0 @@
1
- from datasets.video import (
2
- MinecraftVideoDataset,
3
- MinecraftVideoPoseDataset
4
- )
5
-
6
- from algorithms.worldmem import WorldMemMinecraft
7
- from algorithms.worldmem import PosePrediction
8
- from .exp_base import BaseLightningExperiment
9
-
10
-
11
- class VideoPredictionExperiment(BaseLightningExperiment):
12
- """
13
- A video prediction experiment
14
- """
15
-
16
- compatible_algorithms = dict(
17
- df_video_worldmemminecraft=WorldMemMinecraft,
18
- pose_prediction=PosePrediction
19
- )
20
-
21
- compatible_datasets = dict(
22
- # video datasets
23
- video_minecraft=MinecraftVideoDataset,
24
- video_minecraft_pose=MinecraftVideoPoseDataset
25
- )
 
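The two dicts above are the registries consulted by `_build_algo` and `_build_dataset`: each key is expected to match a yaml file under `configurations/algorithm` or `configurations/dataset`, and the `_name` field of the resolved config selects the class. A minimal sketch of that lookup (the config value is illustrative):

    name = "df_video_worldmemminecraft"                               # from cfg.algorithm._name
    algo_cls = VideoPredictionExperiment.compatible_algorithms[name]  # -> WorldMemMinecraft
    # algo = algo_cls(cfg.algorithm)                                  # built with the algorithm sub-config
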
main.py DELETED
@@ -1,219 +0,0 @@
1
- """
2
- This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research
3
- template [repo](https://github.com/buoyancy99/research-template).
4
- By its MIT license, you must keep the above sentence in `README.md`
5
- and the `LICENSE` file to credit the author.
6
-
7
- Main file for the project. This will create and run new experiments and load checkpoints from wandb.
8
- Borrowed part of the code from David Charatan and wandb.
9
- """
10
-
11
- import sys
12
- import subprocess
13
- import time
14
- from pathlib import Path
15
-
16
- import hydra
17
- from omegaconf import DictConfig, OmegaConf
18
- from omegaconf.omegaconf import open_dict
19
-
20
- from utils.print_utils import cyan
21
- from utils.ckpt_utils import download_latest_checkpoint, is_run_id
22
- from utils.cluster_utils import submit_slurm_job
23
- from utils.distributed_utils import is_rank_zero
24
-
25
- def get_latest_checkpoint(checkpoint_folder: Path, pattern: str = '*.ckpt'):
26
- # collect every file in the folder that matches the pattern
27
- checkpoint_files = list(checkpoint_folder.glob(pattern))
28
- if not checkpoint_files:
29
- return None # no checkpoint files were found
30
- # pick the most recently modified file (by st_mtime)
31
- latest_checkpoint = max(checkpoint_files, key=lambda f: f.stat().st_mtime)
32
- return latest_checkpoint
33
-
34
- def run_local(cfg: DictConfig):
35
- # delay some imports in case they are not needed in non-local envs for submission
36
- from experiments import build_experiment
37
- from utils.wandb_utils import OfflineWandbLogger, SpaceEfficientWandbLogger
38
-
39
- # Get yaml names
40
- hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
41
- cfg_choice = OmegaConf.to_container(hydra_cfg.runtime.choices)
42
-
43
- with open_dict(cfg):
44
- if cfg_choice["experiment"] is not None:
45
- cfg.experiment._name = cfg_choice["experiment"]
46
- if cfg_choice["dataset"] is not None:
47
- cfg.dataset._name = cfg_choice["dataset"]
48
- if cfg_choice["algorithm"] is not None:
49
- cfg.algorithm._name = cfg_choice["algorithm"]
50
-
51
- # import pdb;pdb.set_trace()
52
- # Set up the output directory.
53
- output_dir = getattr(cfg, "output_dir", None)
54
- if output_dir is not None:
55
- OmegaConf.set_readonly(hydra_cfg, False)
56
- hydra_cfg.runtime.output_dir = output_dir
57
- OmegaConf.set_readonly(hydra_cfg, True)
58
-
59
- output_dir = Path(hydra_cfg.runtime.output_dir)
60
-
61
- if is_rank_zero:
62
- print(cyan(f"Outputs will be saved to:"), output_dir)
63
- (output_dir.parents[1] / "latest-run").unlink(missing_ok=True)
64
- (output_dir.parents[1] / "latest-run").symlink_to(output_dir, target_is_directory=True)
65
-
66
- # Set up logging with wandb.
67
- if cfg.wandb.mode != "disabled":
68
- # If resuming, merge into the existing run on wandb.
69
- resume = cfg.get("resume", None)
70
- name = f"{cfg.name} ({output_dir.parent.name}/{output_dir.name})" if resume is None else None
71
-
72
- if "_on_compute_node" in cfg and cfg.cluster.is_compute_node_offline:
73
- logger_cls = OfflineWandbLogger
74
- else:
75
- logger_cls = SpaceEfficientWandbLogger
76
-
77
- offline = cfg.wandb.mode != "online"
78
- logger = logger_cls(
79
- name=name,
80
- save_dir=str(output_dir),
81
- offline=offline,
82
- entity=cfg.wandb.entity,
83
- project=cfg.wandb.project,
84
- log_model=False,
85
- config=OmegaConf.to_container(cfg),
86
- id=resume,
87
- resume="auto"
88
- )
89
-
90
- else:
91
- logger = None
92
-
93
- # Load ckpt
94
- resume = cfg.get("resume", None)
95
- load = cfg.get("load", None)
96
- checkpoint_path = None
97
- load_id = None
98
- if load and not is_run_id(load):
99
- checkpoint_path = load
100
- if resume:
101
- load_id = resume
102
- elif load and is_run_id(load):
103
- load_id = load
104
- else:
105
- load_id = None
106
-
107
- if load_id:
108
- run_path = f"{cfg.wandb.entity}/{cfg.wandb.project}/{load_id}"
109
- checkpoint_path = Path("outputs/downloaded") / run_path / "model.ckpt"
110
- checkpoint_path = output_dir / get_latest_checkpoint(output_dir / "checkpoints")
111
-
112
- if checkpoint_path and is_rank_zero:
113
- print(f"Will load checkpoint from {checkpoint_path}")
114
-
115
- # launch experiment
116
- experiment = build_experiment(cfg, logger, checkpoint_path)
117
- for task in cfg.experiment.tasks:
118
- experiment.exec_task(task)
119
-
120
-
121
- def run_slurm(cfg: DictConfig):
122
- python_args = " ".join(sys.argv[1:]) + " +_on_compute_node=True"
123
- project_root = Path.cwd()
124
- while not (project_root / ".git").exists():
125
- project_root = project_root.parent
126
- if project_root == Path("/"):
127
- raise Exception("Could not find repo directory!")
128
-
129
- slurm_log_dir = submit_slurm_job(
130
- cfg,
131
- python_args,
132
- project_root,
133
- )
134
-
135
- if "cluster" in cfg and cfg.cluster.is_compute_node_offline and cfg.wandb.mode == "online":
136
- print("Job submitted to a compute node without internet. This requires manual syncing on login node.")
137
- osh_command_dir = project_root / ".wandb_osh_command_dir"
138
-
139
- osh_proc = None
140
- # if click.confirm("Do you want us to run the sync loop for you?", default=True):
141
- osh_proc = subprocess.Popen(["wandb-osh", "--command-dir", osh_command_dir])
142
- print(f"Running wandb-osh in background... PID: {osh_proc.pid}")
143
- print(f"To kill the sync process, run 'kill {osh_proc.pid}' in the terminal.")
144
- print(
145
- f"You can manually start a sync loop later by running the following:",
146
- cyan(f"wandb-osh --command-dir {osh_command_dir}"),
147
- )
148
-
149
- print(
150
- "Once the job gets allocated and starts running, we will print a command below "
151
- "for you to trace the errors and outputs: (Ctrl + C to exit without waiting)"
152
- )
153
- msg = f"tail -f {slurm_log_dir}/* \n"
154
- try:
155
- while not list(slurm_log_dir.glob("*.out")) and not list(slurm_log_dir.glob("*.err")):
156
- time.sleep(1)
157
- print(cyan("To trace the outputs and errors, run the following command:"), msg)
158
- except KeyboardInterrupt:
159
- print("Keyboard interrupt detected. Exiting...")
160
- print(
161
- cyan("To trace the outputs and errors, manually wait for the job to start and run the following command:"),
162
- msg,
163
- )
164
-
165
-
166
- @hydra.main(
167
- version_base=None,
168
- config_path="configurations",
169
- config_name="config",
170
- )
171
- def run(cfg: DictConfig):
172
- if "_on_compute_node" in cfg and cfg.cluster.is_compute_node_offline:
173
- with open_dict(cfg):
174
- if cfg.cluster.is_compute_node_offline and cfg.wandb.mode == "online":
175
- cfg.wandb.mode = "offline"
176
-
177
- if "name" not in cfg:
178
- raise ValueError("must specify a name for the run with command line argument '+name=[name]'")
179
-
180
- if not cfg.wandb.get("entity", None):
181
- raise ValueError(
182
- "must specify wandb entity in 'configurations/config.yaml' or with command line"
183
- " argument 'wandb.entity=[entity]' \n An entity is your wandb user name or group"
184
- " name. This is used for logging. If you don't have an wandb account, please signup at https://wandb.ai/"
185
- )
186
-
187
- if cfg.wandb.project is None:
188
- cfg.wandb.project = str(Path(__file__).parent.name)
189
-
190
- # If resuming or loading a wandb ckpt and not on a compute node, download the checkpoint.
191
- resume = cfg.get("resume", None)
192
- load = cfg.get("load", None)
193
-
194
- if resume and load:
195
- raise ValueError(
196
- "When resuming a wandb run with `resume=[wandb id]`, checkpoint will be loaded from the cloud"
197
- "and `load` should not be specified."
198
- )
199
-
200
- if resume:
201
- load_id = resume
202
- elif load and is_run_id(load):
203
- load_id = load
204
- else:
205
- load_id = None
206
-
207
- # if load_id and "_on_compute_node" not in cfg:
208
- # run_path = f"{cfg.wandb.entity}/{cfg.wandb.project}/{load_id}"
209
- # download_latest_checkpoint(run_path, Path("outputs/downloaded"))
210
-
211
- if "cluster" in cfg and not "_on_compute_node" in cfg:
212
- print(cyan("Slurm detected, submitting to compute node instead of running locally..."))
213
- run_slurm(cfg)
214
- else:
215
- run_local(cfg)
216
-
217
-
218
- if __name__ == "__main__":
219
- run() # pylint: disable=no-value-for-parameter
 
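The deleted `main.py` above is the hydra entry point: it records which experiment/dataset/algorithm configs were chosen, sets up the output directory and optional wandb logger, resolves a checkpoint from `resume`/`load`, and then either submits a Slurm job or runs locally. A run needs at least the `+name=...` override, plus `wandb.entity=...` if that is not set in `configurations/config.yaml`. The local path ends in the loop sketched below (logger and checkpoint handling omitted; `cfg` is the resolved hydra config):

    from experiments import build_experiment

    experiment = build_experiment(cfg, logger=None, ckpt_path=None)
    for task in cfg.experiment.tasks:   # e.g. ["training"] or ["validation"]
        experiment.exec_task(task)
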
scripts/README.md DELETED
@@ -1,10 +0,0 @@
1
- # scripts
2
-
3
- The `scripts` folder contains bash scripts that help you scale up your project in the cloud.
4
- Don't put your Jupyter notebooks here! They belong in the `debug` folder.
5
-
6
- General scripts that are useful for all projects can be put directly in the `scripts` folder.
7
-
8
- ---
9
-
10
- This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research template [repo](https://github.com/buoyancy99/research-template). By its MIT license, you must keep the above sentence in `README.md` and the `LICENSE` file to credit the author.
 
scripts/dummy_script.sh DELETED
@@ -1 +0,0 @@
1
- echo 'hello world'
 
 
split_checkpoint.py DELETED
@@ -1,9 +0,0 @@
1
- import torch
2
-
3
- ckpt_path = "/mnt/xiaozeqi/diffusionforcing/outputs/2025-03-28/16-45-11/checkpoints/epoch0step595000.ckpt"
4
- checkpoint = torch.load(ckpt_path, map_location="cpu") # change map_location as needed
5
-
6
- state_dict = checkpoint['state_dict']
7
- pose_prediction_model_dict = {k.replace('pose_prediction_model.', ''): v for k, v in state_dict.items() if k.startswith('pose_prediction_model.')}
8
-
9
- torch.save({'state_dict': pose_prediction_model_dict}, "pose_prediction_model_only.ckpt")
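The resulting `pose_prediction_model_only.ckpt` keeps the `{'state_dict': ...}` layout that the `.ckpt` branch of `load_custom_checkpoint` expects, so the stripped pose predictor can be reloaded on its own. A sketch (the `algo` object is hypothetical):

    load_custom_checkpoint(algo=algo.pose_prediction_model, optimizer=None,
                           checkpoint_path="pose_prediction_model_only.ckpt")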