Upload 14 files
- data/__init__.py +2 -0
- data/configs/example.yaml +45 -0
- data/data_utils.py +177 -0
- data/dataset_base.py +620 -0
- data/dataset_info.py +39 -0
- data/distributed_iterable_dataset.py +58 -0
- data/interleave_datasets/__init__.py +5 -0
- data/interleave_datasets/edit_dataset.py +72 -0
- data/interleave_datasets/interleave_t2i_dataset.py +212 -0
- data/parquet_utils.py +90 -0
- data/t2i_dataset.py +128 -0
- data/transforms.py +287 -0
- data/video_utils.py +165 -0
- data/vlm_dataset.py +195 -0
data/__init__.py
ADDED
@@ -0,0 +1,2 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
data/configs/example.yaml
ADDED
@@ -0,0 +1,45 @@
+t2i_pretrain:
+  dataset_names:
+    - t2i
+  image_transform_args:
+    image_stride: 16
+    max_image_size: 1024
+    min_image_size: 512
+  is_mandatory: true
+  num_used_data: # The sum should be larger than NUM_GPUS x NUM_WORKERS
+    - 10
+  weight: 1
+
+unified_edit:
+  dataset_names:
+    - seedxedit_multi
+  image_transform_args:
+    image_stride: 16
+    max_image_size: 1024
+    min_image_size: 512
+  vit_image_transform_args:
+    image_stride: 14
+    max_image_size: 518
+    min_image_size: 224
+  is_mandatory: false
+  num_used_data:
+    - 10
+  weight: 1
+
+vlm_sft:
+  dataset_names:
+    - llava_ov
+  image_transform_args:
+    image_stride: 14
+    max_image_size: 980
+    min_image_size: 378
+    max_pixels: 2_007_040
+  frame_sampler_args:
+    max_num_frames: 12
+    min_num_frames: 8
+  is_mandatory: true
+  shuffle_lines: True
+  shuffle_seed: 0
+  num_used_data:
+    - 1000
+  weight: 1
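For reference, a minimal sketch (not part of this upload) of how the grouped-dataset config above could be loaded; the training entry point is assumed, and `DataConfig`/`PackedDataset` in `data/dataset_base.py` below accept a dict of exactly this shape as `grouped_datasets`.

# Hedged sketch: load the YAML into the nested dict consumed by DataConfig/PackedDataset below.
# Only the file path comes from this upload; the rest is an assumed usage pattern.
import yaml

with open("data/configs/example.yaml", "r") as f:
    grouped_datasets = yaml.safe_load(f)

print(sorted(grouped_datasets.keys()))  # ['t2i_pretrain', 'unified_edit', 'vlm_sft']
print(grouped_datasets["t2i_pretrain"]["image_transform_args"]["max_image_size"])  # 1024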
data/data_utils.py
ADDED
@@ -0,0 +1,177 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+
+import math
+import random
+from PIL import Image
+
+import torch
+from torch.nn.attention.flex_attention import or_masks, and_masks
+
+
+def create_sparse_mask(document_lens, split_lens, attn_modes, device):
+    def causal_mask(b, h, q_idx, kv_idx):
+        return q_idx >= kv_idx
+
+    def full_and_noise_mask(b, h, q_idx, kv_idx):
+        return (full_and_noise_seq_id[q_idx] == full_and_noise_seq_id[kv_idx]) & (full_and_noise_seq_id[q_idx] >= 0)
+
+    def remove_noise_mask(b, h, q_idx, kv_idx):
+        return (~((noise_seq_id[kv_idx] >= 0) & (noise_seq_id[q_idx] != noise_seq_id[kv_idx])))
+
+    def sample_mask(b, h, q_idx, kv_idx):
+        return document_id[q_idx] == document_id[kv_idx]
+
+    full_and_noise_tmp = []
+    noise_tmp = []
+
+    for i, (length, model) in enumerate(zip(split_lens, attn_modes)):
+        value = i if model in ['full', 'noise'] else -1
+        full_and_noise_tmp.extend([value] * length)
+        value_noise = i if model == 'noise' else -1
+        noise_tmp.extend([value_noise] * length)
+
+    full_and_noise_seq_id = torch.Tensor(full_and_noise_tmp).to(device)
+    noise_seq_id = torch.Tensor(noise_tmp).to(device)
+
+    document_id = torch.cat([torch.full((l,), i) for i, l in enumerate(document_lens, start=1)]).to(device)
+
+    return and_masks(or_masks(causal_mask, full_and_noise_mask), remove_noise_mask, sample_mask)
+
+
+def patchify(image, patch_size):
+    p = patch_size
+    c, h, w = image.shape
+    assert h % p == 0 and w % p == 0
+    image = image.reshape(c, h // p, p, w // p, p)
+    image = torch.einsum("chpwq->hwpqc", image)
+    image = image.reshape(-1, p**2 * c)
+    return image
+
+
+def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side):
+    num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size
+    coords_h = torch.arange(0, num_patches_h)
+    coords_w = torch.arange(0, num_patches_w)
+    pos_ids = (coords_h[:, None] * max_num_patches_per_side + coords_w).flatten()
+    return pos_ids
+
+
+def get_flattened_position_ids_interpolate(img_h, img_w, patch_size, max_num_patches_per_side):
+    num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size
+    boundaries = torch.arange(1 / max_num_patches_per_side, 1.0, 1 / max_num_patches_per_side)
+    fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / num_patches_h)
+    fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / num_patches_w)
+    bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
+    bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
+    pos_ids = (bucket_coords_h[:, None] * max_num_patches_per_side + bucket_coords_w).flatten()
+    return pos_ids
+
+
+def prepare_attention_mask_per_sample(split_lens, attn_modes, device="cpu"):
+    """
+    split_lens: a list of ints, each indicating the length of one split within the sample,
+        where a sample may contain multiple splits with different attention modes.
+    attn_modes: the attention mode ('causal', 'full', or 'noise') used in each split.
+    """
+    sample_len = sum(split_lens)
+    attention_mask = torch.zeros((sample_len, sample_len), dtype=torch.bool, device=device)
+
+    csum = 0
+    for s, attn_mode in zip(split_lens, attn_modes):
+        assert attn_mode in ['causal', 'full', 'noise']
+        if attn_mode == "causal":
+            attention_mask[csum:csum + s, csum:csum + s] = torch.ones((s, s), device=device).tril()
+            attention_mask[csum:csum + s, :csum] = 1
+        else:
+            attention_mask[csum:csum + s, csum:csum + s] = torch.ones((s, s))
+            attention_mask[csum:csum + s, :csum] = 1
+        csum += s
+
+    csum = 0
+    for s, attn_mode in zip(split_lens, attn_modes):
+        if attn_mode == "noise":
+            attention_mask[:, csum : csum + s] = torch.zeros((sample_len, s))
+            attention_mask[csum : csum + s, csum : csum + s] = torch.ones((s, s))
+        csum += s
+
+    attention_mask = torch.zeros_like(attention_mask, dtype=torch.float).masked_fill_(
+        ~attention_mask, float("-inf")
+    )
+
+    return attention_mask
+
+
+def split_integer_exp_decay(S, ng_sample_decay=1.0):
+    if ng_sample_decay == 1.0:
+        N = random.randint(1, S)
+    else:
+        base = (1 - ng_sample_decay) / (1 - math.pow(ng_sample_decay, S))
+        p = [base * math.pow(ng_sample_decay, i) for i in range(S)]
+        N = random.choices(list(range(1, S + 1)), p, k=1)[0]
+    cumsum = [0] + sorted(random.sample(range(1, S), N - 1)) + [S]
+    result = [cumsum[i + 1] - cumsum[i] for i in range(len(cumsum) - 1)]
+    return result, cumsum
+
+
+def pil_img2rgb(image):
+    if image.mode == "RGBA" or image.info.get("transparency", None) is not None:
+        image = image.convert("RGBA")
+        white = Image.new(mode="RGB", size=image.size, color=(255, 255, 255))
+        white.paste(image, mask=image.split()[3])
+        image = white
+    else:
+        image = image.convert("RGB")
+
+    return image
+
+
+def add_special_tokens(tokenizer):
+    all_special_tokens = []
+    for k, v in tokenizer.special_tokens_map.items():
+        if isinstance(v, str):
+            all_special_tokens.append(v)
+        elif isinstance(v, list):
+            all_special_tokens += v
+
+    new_tokens = []
+
+    if '<|im_start|>' not in all_special_tokens:
+        new_tokens.append('<|im_start|>')
+
+    if '<|im_end|>' not in all_special_tokens:
+        new_tokens.append('<|im_end|>')
+
+    if '<|vision_start|>' not in all_special_tokens:
+        new_tokens.append('<|vision_start|>')
+
+    if '<|vision_end|>' not in all_special_tokens:
+        new_tokens.append('<|vision_end|>')
+
+    num_new_tokens = tokenizer.add_tokens(new_tokens)
+    bos_token_id = tokenizer.convert_tokens_to_ids('<|im_start|>')
+    eos_token_id = tokenizer.convert_tokens_to_ids('<|im_end|>')
+    start_of_image = tokenizer.convert_tokens_to_ids('<|vision_start|>')
+    end_of_image = tokenizer.convert_tokens_to_ids('<|vision_end|>')
+
+    new_token_ids = dict(
+        bos_token_id=bos_token_id,
+        eos_token_id=eos_token_id,
+        start_of_image=start_of_image,
+        end_of_image=end_of_image,
+    )
+
+    return tokenizer, new_token_ids, num_new_tokens
+
+
+def len2weight(x, loss_reduction='square'):
+    if x == 0:
+        return x
+    if loss_reduction == 'token':
+        return 1
+    if loss_reduction == 'sample':
+        return 1 / x
+    if loss_reduction == 'square':
+        return 1 / (x ** 0.5)
+    raise NotImplementedError(loss_reduction)
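A small usage sketch for the patch utilities above; the image shape and patch size are illustrative assumptions, not part of the upload, and the repository root is assumed to be on PYTHONPATH.

# Hedged sketch: patchify a ViT-sized image and build matching flattened position ids.
import torch

from data.data_utils import patchify, get_flattened_position_ids_extrapolate

image = torch.randn(3, 224, 196)                    # C x H x W, both divisible by patch_size
vit_tokens = patchify(image, patch_size=14)         # -> (16 * 14, 14 * 14 * 3) == (224, 588)
pos_ids = get_flattened_position_ids_extrapolate(
    224, 196, patch_size=14, max_num_patches_per_side=70
)
assert vit_tokens.shape[0] == pos_ids.shape[0] == 224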
data/dataset_base.py
ADDED
@@ -0,0 +1,620 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+
+import random
+import json
+
+import numpy as np
+import torch
+
+from .data_utils import (
+    get_flattened_position_ids_interpolate,
+    get_flattened_position_ids_extrapolate,
+    len2weight,
+    patchify,
+    prepare_attention_mask_per_sample,
+)
+from .dataset_info import DATASET_INFO, DATASET_REGISTRY
+from .transforms import ImageTransform
+from .video_utils import FrameSampler
+
+
+class DataConfig:
+    def __init__(
+        self,
+        grouped_datasets,
+        text_cond_dropout_prob=0.1,
+        vit_cond_dropout_prob=0.4,
+        vae_cond_dropout_prob=0.1,
+        vae_image_downsample=16,
+        max_latent_size=32,
+        vit_patch_size=14,
+        max_num_patch_per_side=70,
+    ):
+        self.grouped_datasets = grouped_datasets
+        self.text_cond_dropout_prob = text_cond_dropout_prob
+        self.vit_cond_dropout_prob = vit_cond_dropout_prob
+        self.vit_patch_size = vit_patch_size
+        self.max_num_patch_per_side = max_num_patch_per_side
+        self.vae_cond_dropout_prob = vae_cond_dropout_prob
+        self.vae_image_downsample = vae_image_downsample
+        self.max_latent_size = max_latent_size
+
+
+class PackedDataset(torch.utils.data.IterableDataset):
+    def __init__(
+        self,
+        data_config,
+        tokenizer,
+        special_tokens,
+        local_rank,
+        world_size,
+        num_workers,
+        expected_num_tokens=32768,
+        max_num_tokens_per_sample=16384,
+        max_num_tokens=36864,
+        prefer_buffer_before=16384,
+        max_buffer_size=50,
+        interpolate_pos=False,
+        use_flex=False,
+        data_status=None,
+    ):
+        super().__init__()
+        self.expected_num_tokens = expected_num_tokens
+        self.max_num_tokens_per_sample = max_num_tokens_per_sample
+        self.prefer_buffer_before = prefer_buffer_before
+        self.max_num_tokens = max_num_tokens
+        self.max_buffer_size = max_buffer_size
+        self.tokenizer = tokenizer
+        self.local_rank = local_rank
+        self.world_size = world_size
+        self.num_workers = num_workers
+        self.use_flex = use_flex
+        for k, v in special_tokens.items():
+            setattr(self, k, v)
+
+        grouped_datasets, is_mandatory, grouped_weights = self.build_datasets(
+            data_config.grouped_datasets, data_status
+        )
+        self.grouped_datasets = grouped_datasets
+        self.dataset_iters = [iter(dataset) for dataset in grouped_datasets]
+        self.is_mandatory = is_mandatory
+        self.grouped_weights = grouped_weights
+        self.data_config = data_config
+        self.interpolate_pos = interpolate_pos
+        if self.interpolate_pos:
+            self.get_flattened_position_ids = get_flattened_position_ids_interpolate
+        else:
+            self.get_flattened_position_ids = get_flattened_position_ids_extrapolate
+
+    def build_datasets(self, datasets_metainfo, data_status):
+        datasets = []
+        is_mandatory = []
+        grouped_weights = []
+        for grouped_dataset_name, dataset_args in datasets_metainfo.items():
+            is_mandatory.append(dataset_args.pop('is_mandatory', False))
+            grouped_weights.append(dataset_args.pop('weight', 0.0))
+
+            if 'frame_sampler_args' in dataset_args.keys():
+                frame_sampler = FrameSampler(**dataset_args.pop('frame_sampler_args'))
+                dataset_args['frame_sampler'] = frame_sampler
+            if 'image_transform_args' in dataset_args.keys():
+                transform = ImageTransform(**dataset_args.pop('image_transform_args'))
+                dataset_args['transform'] = transform
+            if 'vit_image_transform_args' in dataset_args.keys():
+                vit_transform = ImageTransform(**dataset_args.pop('vit_image_transform_args'))
+                dataset_args['vit_transform'] = vit_transform
+
+            assert 'dataset_names' in dataset_args.keys()
+            dataset_names = dataset_args.pop('dataset_names')
+            dataset_args['data_dir_list'] = []
+            for item in dataset_names:
+                if self.local_rank == 0:
+                    print(f'Preparing Dataset {grouped_dataset_name}/{item}')
+                meta_info = DATASET_INFO[grouped_dataset_name][item]
+                dataset_args['data_dir_list'].append(meta_info['data_dir'])
+
+                if "parquet_info_path" in meta_info.keys():
+                    if 'parquet_info' not in dataset_args.keys():
+                        dataset_args['parquet_info'] = {}
+                    with open(meta_info['parquet_info_path'], 'r') as f:
+                        parquet_info = json.load(f)
+                    dataset_args['parquet_info'].update(parquet_info)
+
+                if 'json_dir' in meta_info.keys():
+                    # parquet/tar with json
+                    if 'json_dir_list' not in dataset_args.keys():
+                        dataset_args['json_dir_list'] = [meta_info['json_dir']]
+                    else:
+                        dataset_args['json_dir_list'].append(meta_info['json_dir'])
+
+                if 'jsonl_path' in meta_info.keys():
+                    # jsonl with jpeg
+                    if 'jsonl_path_list' not in dataset_args.keys():
+                        dataset_args['jsonl_path_list'] = [meta_info['jsonl_path']]
+                    else:
+                        dataset_args['jsonl_path_list'].append(meta_info['jsonl_path'])
+
+            resume_data_status = dataset_args.pop('resume_data_status', True)
+            if data_status is not None and grouped_dataset_name in data_status.keys() and resume_data_status:
+                data_status_per_group = data_status[grouped_dataset_name]
+            else:
+                data_status_per_group = None
+            dataset = DATASET_REGISTRY[grouped_dataset_name](
+                dataset_name=grouped_dataset_name,
+                tokenizer=self.tokenizer,
+                local_rank=self.local_rank,
+                world_size=self.world_size,
+                num_workers=self.num_workers,
+                data_status=data_status_per_group,
+                **dataset_args
+            )
+            datasets.append(dataset)
+
+        return datasets, is_mandatory, grouped_weights
+
+    def set_epoch(self, seed):
+        for dataset in self.grouped_datasets:
+            dataset.set_epoch(seed)
+
+    def set_sequence_status(self):
+        sequence_status = dict(
+            curr = 0,
+            sample_lens = list(),
+            packed_position_ids = list(),
+            nested_attention_masks = list(),
+            split_lens = list(),
+            attn_modes = list(),
+            packed_text_ids = list(),
+            packed_text_indexes = list(),
+            packed_label_ids = list(),
+            ce_loss_indexes = list(),
+            ce_loss_weights = list(),
+            vae_image_tensors = list(),
+            packed_latent_position_ids = list(),
+            vae_latent_shapes = list(),
+            packed_vae_token_indexes = list(),
+            packed_timesteps = list(),
+            mse_loss_indexes = list(),
+            packed_vit_tokens = list(),
+            vit_token_seqlens = list(),
+            packed_vit_position_ids = list(),
+            packed_vit_token_indexes = list(),
+        )
+        return sequence_status
+
+    def to_tensor(self, sequence_status):
+        data = dict(
+            sequence_length=sum(sequence_status['sample_lens']),
+            sample_lens=sequence_status['sample_lens'],
+            packed_text_ids=torch.tensor(sequence_status['packed_text_ids']),
+            packed_text_indexes=torch.tensor(sequence_status['packed_text_indexes']),
+            packed_position_ids=torch.tensor(sequence_status['packed_position_ids']),
+        )
+        if not self.use_flex:
+            data['nested_attention_masks'] = sequence_status['nested_attention_masks']
+        else:
+            sequence_len = data['sequence_length']
+            pad_len = self.max_num_tokens - sequence_len
+            data['split_lens'] = sequence_status['split_lens'] + [pad_len]
+            data['attn_modes'] = sequence_status['attn_modes'] + ['causal']
+            data['sample_lens'] += [pad_len]
+
+        # if the model has a convnet vae (e.g., as visual tokenizer)
+        if len(sequence_status['vae_image_tensors']) > 0:
+            image_tensors = sequence_status.pop('vae_image_tensors')
+            image_sizes = [item.shape for item in image_tensors]
+            max_image_size = [max(item) for item in list(zip(*image_sizes))]
+            padded_images = torch.zeros(size=(len(image_tensors), *max_image_size))
+            for i, image_tensor in enumerate(image_tensors):
+                padded_images[i, :, :image_tensor.shape[1], :image_tensor.shape[2]] = image_tensor
+
+            data['padded_images'] = padded_images
+            data['patchified_vae_latent_shapes'] = sequence_status['vae_latent_shapes']
+            data['packed_latent_position_ids'] = torch.cat(sequence_status['packed_latent_position_ids'], dim=0)
+            data['packed_vae_token_indexes'] = torch.tensor(sequence_status['packed_vae_token_indexes'])
+
+        # if the model has a vit (e.g., as visual tokenizer)
+        if len(sequence_status['packed_vit_tokens']) > 0:
+            data['packed_vit_tokens'] = torch.cat(sequence_status['packed_vit_tokens'], dim=0)
+            data['packed_vit_position_ids'] = torch.cat(sequence_status['packed_vit_position_ids'], dim=0)
+            data['packed_vit_token_indexes'] = torch.tensor(sequence_status['packed_vit_token_indexes'])
+            data['vit_token_seqlens'] = torch.tensor(sequence_status['vit_token_seqlens'])
+
+        # if the model is required to perform visual generation
+        if len(sequence_status['packed_timesteps']) > 0:
+            data['packed_timesteps'] = torch.tensor(sequence_status['packed_timesteps'])
+            data['mse_loss_indexes'] = torch.tensor(sequence_status['mse_loss_indexes'])
+
+        # if the model is required to perform text generation
+        if len(sequence_status['packed_label_ids']) > 0:
+            data['packed_label_ids'] = torch.tensor(sequence_status['packed_label_ids'])
+            data['ce_loss_indexes'] = torch.tensor(sequence_status['ce_loss_indexes'])
+            data['ce_loss_weights'] = torch.tensor(sequence_status['ce_loss_weights'])
+
+        return data
+
+    def __iter__(self):
+        total_weights = sum(self.grouped_weights)
+        assert total_weights > 0.0
+        group_cumprobs = [sum(self.grouped_weights[:i + 1]) / total_weights
+                          for i in range(len(self.grouped_weights))]
+        sequence_status = self.set_sequence_status()
+        batch_data_indexes = []
+
+        buffer = []
+        while True:
+            # Ensure at least one sample from each group
+            if sequence_status['curr'] == 0:
+                for group_index, group_iter in enumerate(self.dataset_iters):
+                    if self.is_mandatory[group_index]:
+                        while True:
+                            sample = next(group_iter)
+                            # if a sample is too long, skip it
+                            num_tokens = sample['num_tokens'] + 2 * len(sample['sequence_plan'])
+                            if num_tokens < self.max_num_tokens_per_sample:
+                                sequence_status = self.pack_sequence(sample, sequence_status)
+                                batch_data_indexes.append(sample['data_indexes'])
+                                break
+                            else:
+                                print(f"skip a sample with length {num_tokens}")
+                                continue
+
+            if sequence_status['curr'] < self.prefer_buffer_before and len(buffer) > 0:
+                sample = buffer.pop(0)
+                sample_from_buffer = True
+            else:
+                # sample normally across all groups
+                n = random.random()
+                group_index = 0
+                for i, cumprob in enumerate(group_cumprobs):
+                    if n < cumprob:
+                        group_index = i
+                        break
+                sample = next(self.dataset_iters[group_index])
+                sample_from_buffer = False
+
+            # if a sample is too long, skip it
+            num_tokens = sample['num_tokens'] + 2 * len(sample['sequence_plan'])
+            if num_tokens > self.max_num_tokens_per_sample:
+                print(f"skip a sample with length {num_tokens}")
+                continue
+
+            if sequence_status['curr'] + num_tokens > self.max_num_tokens:
+                if len(buffer) < self.max_buffer_size and not sample_from_buffer:
+                    buffer.append(sample)
+                else:
+                    print(f"Yielding data with length {sum(sequence_status['sample_lens'])}")
+                    data = self.to_tensor(sequence_status)
+                    data['batch_data_indexes'] = batch_data_indexes
+                    yield data
+                    sequence_status = self.set_sequence_status()
+                    batch_data_indexes = []
+                continue
+
+            sequence_status = self.pack_sequence(sample, sequence_status)
+            batch_data_indexes.append(sample['data_indexes'])
+
+            if sequence_status['curr'] >= self.expected_num_tokens:
+                data = self.to_tensor(sequence_status)
+                data['batch_data_indexes'] = batch_data_indexes
+                yield data
+                sequence_status = self.set_sequence_status()
+                batch_data_indexes = []
+
+    def pack_sequence(self, sample, sequence_status):
+        image_tensor_list = sample['image_tensor_list']
+        text_ids_list = sample['text_ids_list']
+        sequence_plan = sample['sequence_plan']
+
+        split_lens, attn_modes = list(), list()
+        curr = sequence_status['curr']
+        curr_rope_id = 0
+        sample_lens = 0
+
+        for item in sequence_plan:
+            split_start = item.get('split_start', True)
+            if split_start:
+                curr_split_len = 0
+
+            if item['type'] == 'text':
+                text_ids = text_ids_list.pop(0)
+                if item['enable_cfg'] == 1 and random.random() < self.data_config.text_cond_dropout_prob:
+                    continue
+
+                shifted_text_ids = [self.bos_token_id] + text_ids
+                sequence_status['packed_text_ids'].extend(shifted_text_ids)
+                sequence_status['packed_text_indexes'].extend(range(curr, curr + len(shifted_text_ids)))
+                if item['loss'] == 1:
+                    sequence_status['ce_loss_indexes'].extend(range(curr, curr + len(shifted_text_ids)))
+                    sequence_status['ce_loss_weights'].extend(
+                        [len2weight(len(shifted_text_ids))] * len(shifted_text_ids)
+                    )
+                    sequence_status['packed_label_ids'].extend(text_ids + [self.eos_token_id])
+                curr += len(shifted_text_ids)
+                curr_split_len += len(shifted_text_ids)
+
+                # add a <|im_end|> token
+                sequence_status['packed_text_ids'].append(self.eos_token_id)
+                sequence_status['packed_text_indexes'].append(curr)
+                if item['special_token_loss'] == 1:  # <|im_end|> may have loss
+                    sequence_status['ce_loss_indexes'].append(curr)
+                    sequence_status['ce_loss_weights'].append(1.0)
+                    sequence_status['packed_label_ids'].append(item['special_token_label'])
+                curr += 1
+                curr_split_len += 1
+
+                # update sequence status
+                attn_modes.append("causal")
+                sequence_status['packed_position_ids'].extend(range(curr_rope_id, curr_rope_id + curr_split_len))
+                curr_rope_id += curr_split_len
+
+            elif item['type'] == 'vit_image':
+                image_tensor = image_tensor_list.pop(0)
+                if item['enable_cfg'] == 1 and random.random() < self.data_config.vit_cond_dropout_prob:
+                    curr_rope_id += 1
+                    continue
+
+                # add a <|startofimage|> token
+                sequence_status['packed_text_ids'].append(self.start_of_image)
+                sequence_status['packed_text_indexes'].append(curr)
+                curr += 1
+                curr_split_len += 1
+
+                # preprocess image
+                vit_tokens = patchify(image_tensor, self.data_config.vit_patch_size)
+                num_img_tokens = vit_tokens.shape[0]
+                sequence_status['packed_vit_token_indexes'].extend(range(curr, curr + num_img_tokens))
+                curr += num_img_tokens
+                curr_split_len += num_img_tokens
+
+                sequence_status['packed_vit_tokens'].append(vit_tokens)
+                sequence_status['vit_token_seqlens'].append(num_img_tokens)
+                sequence_status['packed_vit_position_ids'].append(
+                    self.get_flattened_position_ids(
+                        image_tensor.size(1), image_tensor.size(2),
+                        self.data_config.vit_patch_size,
+                        max_num_patches_per_side=self.data_config.max_num_patch_per_side
+                    )
+                )
+
+                # add a <|endofimage|> token
+                sequence_status['packed_text_ids'].append(self.end_of_image)
+                sequence_status['packed_text_indexes'].append(curr)
+                if item['special_token_loss'] == 1:  # <|endofimage|> may have loss
+                    sequence_status['ce_loss_indexes'].append(curr)
+                    sequence_status['ce_loss_weights'].append(1.0)
+                    sequence_status['packed_label_ids'].append(item['special_token_label'])
+                curr += 1
+                curr_split_len += 1
+
+                # update sequence status
+                attn_modes.append("full")
+                sequence_status['packed_position_ids'].extend([curr_rope_id] * curr_split_len)
+                curr_rope_id += 1
+
+            elif item['type'] == 'vae_image':
+                image_tensor = image_tensor_list.pop(0)
+                if item['enable_cfg'] == 1 and random.random() < self.data_config.vae_cond_dropout_prob:
+                    # FIXME fix vae dropout in video2video setting.
+                    curr_rope_id += 1
+                    continue
+
+                # add a <|startofimage|> token
+                sequence_status['packed_text_ids'].append(self.start_of_image)
+                sequence_status['packed_text_indexes'].append(curr)
+                curr += 1
+                curr_split_len += 1
+
+                # preprocess image
+                sequence_status['vae_image_tensors'].append(image_tensor)
+                sequence_status['packed_latent_position_ids'].append(
+                    self.get_flattened_position_ids(
+                        image_tensor.size(1), image_tensor.size(2),
+                        self.data_config.vae_image_downsample,
+                        max_num_patches_per_side=self.data_config.max_latent_size
+                    )
+                )
+                H, W = image_tensor.shape[1:]
+                h = H // self.data_config.vae_image_downsample
+                w = W // self.data_config.vae_image_downsample
+                sequence_status['vae_latent_shapes'].append((h, w))
+
+                num_img_tokens = w * h
+                sequence_status['packed_vae_token_indexes'].extend(range(curr, curr + num_img_tokens))
+                if item['loss'] == 1:
+                    sequence_status['mse_loss_indexes'].extend(range(curr, curr + num_img_tokens))
+                    if split_start:
+                        timestep = np.random.randn()
+                else:
+                    timestep = float('-inf')
+
+                sequence_status['packed_timesteps'].extend([timestep] * num_img_tokens)
+                curr += num_img_tokens
+                curr_split_len += num_img_tokens
+
+                # add a <|endofimage|> token
+                sequence_status['packed_text_ids'].append(self.end_of_image)
+                sequence_status['packed_text_indexes'].append(curr)
+                # <|endofimage|> may have loss
+                if item['special_token_loss'] == 1:
+                    sequence_status['ce_loss_indexes'].append(curr)
+                    sequence_status['ce_loss_weights'].append(1.0)
+                    sequence_status['packed_label_ids'].append(item['special_token_label'])
+                curr += 1
+                curr_split_len += 1
+
+                # update sequence status
+                if split_start:
+                    if item['loss'] == 1 and 'frame_delta' not in item.keys():
+                        attn_modes.append("noise")
+                    else:
+                        attn_modes.append("full")
+                sequence_status['packed_position_ids'].extend([curr_rope_id] * (num_img_tokens + 2))
+                if 'frame_delta' in item.keys():
+                    curr_rope_id += item['frame_delta']
+                elif item['loss'] == 0:
+                    curr_rope_id += 1
+
+            if item.get('split_end', True):
+                split_lens.append(curr_split_len)
+                sample_lens += curr_split_len
+
+        sequence_status['curr'] = curr
+        sequence_status['sample_lens'].append(sample_lens)
+        # prepare attention mask
+        if not self.use_flex:
+            sequence_status['nested_attention_masks'].append(
+                prepare_attention_mask_per_sample(split_lens, attn_modes)
+            )
+        else:
+            sequence_status['split_lens'].extend(split_lens)
+            sequence_status['attn_modes'].extend(attn_modes)
+
+        return sequence_status
+
+
+class SimpleCustomBatch:
+    def __init__(self, batch):
+        data = batch[0]
+        self.batch_data_indexes = data['batch_data_indexes']
+        self.sequence_length = data["sequence_length"]
+        self.sample_lens = data["sample_lens"]
+        self.packed_text_ids = data["packed_text_ids"]
+        self.packed_text_indexes = data["packed_text_indexes"]
+        self.packed_position_ids = data["packed_position_ids"]
+
+        self.use_flex = "nested_attention_masks" not in data.keys()
+
+        if self.use_flex:
+            self.split_lens = data["split_lens"]
+            self.attn_modes = data["attn_modes"]
+        else:
+            self.nested_attention_masks = data["nested_attention_masks"]
+
+        if "padded_images" in data.keys():
+            self.padded_images = data["padded_images"]
+            self.patchified_vae_latent_shapes = data["patchified_vae_latent_shapes"]
+            self.packed_latent_position_ids = data["packed_latent_position_ids"]
+            self.packed_vae_token_indexes = data["packed_vae_token_indexes"]
+
+        if "packed_vit_tokens" in data.keys():
+            self.packed_vit_tokens = data["packed_vit_tokens"]
+            self.packed_vit_position_ids = data["packed_vit_position_ids"]
+            self.packed_vit_token_indexes = data["packed_vit_token_indexes"]
+            self.vit_token_seqlens = data["vit_token_seqlens"]
+
+        if "packed_timesteps" in data.keys():
+            self.packed_timesteps = data["packed_timesteps"]
+            self.mse_loss_indexes = data["mse_loss_indexes"]
+
+        if "packed_label_ids" in data.keys():
+            self.packed_label_ids = data["packed_label_ids"]
+            self.ce_loss_indexes = data["ce_loss_indexes"]
+            self.ce_loss_weights = data["ce_loss_weights"]
+
+    def pin_memory(self):
+        self.packed_text_ids = self.packed_text_ids.pin_memory()
+        self.packed_text_indexes = self.packed_text_indexes.pin_memory()
+        self.packed_position_ids = self.packed_position_ids.pin_memory()
+
+        if not self.use_flex:
+            self.nested_attention_masks = [item.pin_memory() for item in self.nested_attention_masks]
+
+        if hasattr(self, 'padded_images'):
+            self.padded_images = self.padded_images.pin_memory()
+            self.packed_vae_token_indexes = self.packed_vae_token_indexes.pin_memory()
+            self.packed_latent_position_ids = self.packed_latent_position_ids.pin_memory()
+
+        if hasattr(self, 'packed_timesteps'):
+            self.packed_timesteps = self.packed_timesteps.pin_memory()
+            self.mse_loss_indexes = self.mse_loss_indexes.pin_memory()
+
+        if hasattr(self, 'packed_vit_tokens'):
+            self.packed_vit_tokens = self.packed_vit_tokens.pin_memory()
+            self.packed_vit_position_ids = self.packed_vit_position_ids.pin_memory()
+            self.packed_vit_token_indexes = self.packed_vit_token_indexes.pin_memory()
+            self.vit_token_seqlens = self.vit_token_seqlens.pin_memory()
+
+        if hasattr(self, 'packed_label_ids'):
+            self.packed_label_ids = self.packed_label_ids.pin_memory()
+            self.ce_loss_indexes = self.ce_loss_indexes.pin_memory()
+            self.ce_loss_weights = self.ce_loss_weights.pin_memory()
+
+        return self
+
+    def cuda(self, device):
+        self.packed_text_ids = self.packed_text_ids.to(device)
+        self.packed_text_indexes = self.packed_text_indexes.to(device)
+        self.packed_position_ids = self.packed_position_ids.to(device)
+
+        if not self.use_flex:
+            self.nested_attention_masks = [item.to(device) for item in self.nested_attention_masks]
+
+        if hasattr(self, 'padded_images'):
+            self.padded_images = self.padded_images.to(device)
+            self.packed_vae_token_indexes = self.packed_vae_token_indexes.to(device)
+            self.packed_latent_position_ids = self.packed_latent_position_ids.to(device)
+
+        if hasattr(self, 'packed_timesteps'):
+            self.packed_timesteps = self.packed_timesteps.to(device)
+            self.mse_loss_indexes = self.mse_loss_indexes.to(device)
+
+        if hasattr(self, 'packed_vit_tokens'):
+            self.packed_vit_tokens = self.packed_vit_tokens.to(device)
+            self.packed_vit_position_ids = self.packed_vit_position_ids.to(device)
+            self.packed_vit_token_indexes = self.packed_vit_token_indexes.to(device)
+            self.vit_token_seqlens = self.vit_token_seqlens.to(device)
+
+        if hasattr(self, 'packed_label_ids'):
+            self.packed_label_ids = self.packed_label_ids.to(device)
+            self.ce_loss_indexes = self.ce_loss_indexes.to(device)
+            self.ce_loss_weights = self.ce_loss_weights.to(device)
+
+        return self
+
+    def to_dict(self):
+        data = dict(
+            sequence_length = self.sequence_length,
+            sample_lens = self.sample_lens,
+            packed_text_ids = self.packed_text_ids,
+            packed_text_indexes = self.packed_text_indexes,
+            packed_position_ids = self.packed_position_ids,
+            batch_data_indexes = self.batch_data_indexes,
+        )
+
+        if not self.use_flex:
+            data['nested_attention_masks'] = self.nested_attention_masks
+        else:
+            data['split_lens'] = self.split_lens
+            data['attn_modes'] = self.attn_modes
+
+        if hasattr(self, 'padded_images'):
+            data['padded_images'] = self.padded_images
+            data['patchified_vae_latent_shapes'] = self.patchified_vae_latent_shapes
+            data['packed_latent_position_ids'] = self.packed_latent_position_ids
+            data['packed_vae_token_indexes'] = self.packed_vae_token_indexes
+
+        if hasattr(self, 'packed_vit_tokens'):
+            data['packed_vit_tokens'] = self.packed_vit_tokens
+            data['packed_vit_position_ids'] = self.packed_vit_position_ids
+            data['packed_vit_token_indexes'] = self.packed_vit_token_indexes
+            data['vit_token_seqlens'] = self.vit_token_seqlens
+
+        if hasattr(self, 'packed_timesteps'):
+            data['packed_timesteps'] = self.packed_timesteps
+            data['mse_loss_indexes'] = self.mse_loss_indexes
+
+        if hasattr(self, 'packed_label_ids'):
+            data['packed_label_ids'] = self.packed_label_ids
+            data['ce_loss_indexes'] = self.ce_loss_indexes
+            data['ce_loss_weights'] = self.ce_loss_weights
+
+        return data
+
+
+def collate_wrapper():
+    def collate_fn(batch):
+        return SimpleCustomBatch(batch)
+    return collate_fn
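A hedged wiring sketch for the packing pipeline above. The base tokenizer name, the single-process settings, and the assumption that the example data exists under `your_data_path/...` are all placeholders, not part of this upload; the repository root is assumed to be on PYTHONPATH.

# Hedged sketch: build a PackedDataset from the example config and pull one packed batch.
import torch
import yaml
from transformers import AutoTokenizer

from data.data_utils import add_special_tokens
from data.dataset_base import DataConfig, PackedDataset, collate_wrapper

with open("data/configs/example.yaml", "r") as f:
    grouped_datasets = yaml.safe_load(f)

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # assumed base tokenizer
tokenizer, special_tokens, _ = add_special_tokens(tokenizer)

dataset = PackedDataset(
    DataConfig(grouped_datasets=grouped_datasets),
    tokenizer=tokenizer,
    special_tokens=special_tokens,
    local_rank=0, world_size=1, num_workers=1,
)
dataset.set_epoch(seed=0)

loader = torch.utils.data.DataLoader(
    dataset, batch_size=1, num_workers=1,
    collate_fn=collate_wrapper(), pin_memory=True,
)
batch = next(iter(loader))        # a SimpleCustomBatch holding one packed sequence
print(batch.sequence_length)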
data/dataset_info.py
ADDED
@@ -0,0 +1,39 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+from .interleave_datasets import UnifiedEditIterableDataset
+from .t2i_dataset import T2IIterableDataset
+from .vlm_dataset import SftJSONLIterableDataset
+
+
+DATASET_REGISTRY = {
+    't2i_pretrain': T2IIterableDataset,
+    'vlm_sft': SftJSONLIterableDataset,
+    'unified_edit': UnifiedEditIterableDataset,
+}
+
+
+DATASET_INFO = {
+    't2i_pretrain': {
+        't2i': {
+            'data_dir': 'your_data_path/bagel_example/t2i',  # path of the parquet files
+            'num_files': 10,  # number of data units to be sharded across all ranks and workers
+            'num_total_samples': 1000,  # number of total samples in the dataset
+        },
+    },
+    'unified_edit': {
+        'seedxedit_multi': {
+            'data_dir': 'your_data_path/bagel_example/editing/seedxedit_multi',
+            'num_files': 10,
+            'num_total_samples': 1000,
+            "parquet_info_path": 'your_data_path/bagel_example/editing/parquet_info/seedxedit_multi_nas.json',  # information of the parquet files
+        },
+    },
+    'vlm_sft': {
+        'llava_ov': {
+            'data_dir': 'your_data_path/bagel_example/vlm/images',
+            'jsonl_path': 'your_data_path/bagel_example/vlm/llava_ov_si.jsonl',
+            'num_total_samples': 1000
+        },
+    },
+}
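If another shard set needs to be wired in, a hypothetical entry (names and paths are placeholders, not part of the upload) would follow the same pattern and then be listed under `dataset_names` in the YAML config:

# Hypothetical example only: register one more t2i parquet directory, then reference
# 'my_t2i_subset' under t2i_pretrain.dataset_names in data/configs/example.yaml.
from data.dataset_info import DATASET_INFO

DATASET_INFO['t2i_pretrain']['my_t2i_subset'] = {
    'data_dir': 'your_data_path/my_t2i_subset',   # placeholder path
    'num_files': 4,                               # number of parquet shards
    'num_total_samples': 400,
}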
data/distributed_iterable_dataset.py
ADDED
@@ -0,0 +1,58 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+import random
+import torch
+
+
+class DistributedIterableDataset(torch.utils.data.IterableDataset):
+    def __init__(self, dataset_name, local_rank=0, world_size=1, num_workers=8):
+        self.dataset_name = dataset_name
+        self.local_rank = local_rank
+        self.world_size = world_size
+        self.num_workers = num_workers
+        self.rng = random.Random()
+        self.data_paths = None
+
+    def get_data_paths(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def set_epoch(self, seed=42):
+        if self.data_paths is None:
+            return
+
+        if isinstance(self.data_paths[0], tuple):
+            data_paths = sorted(self.data_paths, key=lambda x: (x[0], x[1]))
+        elif isinstance(self.data_paths[0], str):
+            data_paths = sorted(self.data_paths)
+        else:
+            raise ValueError(f"Unknown data_paths type: {type(self.data_paths[0])}")
+
+        self.rng.seed(seed)
+        self.rng.shuffle(data_paths)
+
+        num_files_per_rank = len(data_paths) // self.world_size
+        local_start = self.local_rank * num_files_per_rank
+        local_end = (self.local_rank + 1) * num_files_per_rank
+        self.num_files_per_rank = num_files_per_rank
+        self.data_paths_per_rank = data_paths[local_start:local_end]
+
+    def get_data_paths_per_worker(self):
+        if self.data_paths is None:
+            return None
+
+        info = torch.utils.data.get_worker_info()
+        if info is None:
+            # Single worker: use all files assigned to the rank
+            return self.data_paths_per_rank, 0
+
+        worker_id = info.id
+        num_files_per_worker = self.num_files_per_rank // info.num_workers
+        start = num_files_per_worker * worker_id
+        end = num_files_per_worker * (worker_id + 1)
+        data_paths_per_worker = self.data_paths_per_rank[start:end]
+
+        return data_paths_per_worker[::-1], worker_id
+
+    def __iter__(self):
+        raise NotImplementedError
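The contract for subclasses is to populate `self.data_paths`, call `set_epoch()`, and build `__iter__` on top of `get_data_paths_per_worker()`. A minimal hypothetical subclass (file listing and row format are assumptions, not part of the upload) could look like this:

# Hypothetical minimal subclass: shard a directory of JSONL files across ranks and workers.
import glob
import json

from data.distributed_iterable_dataset import DistributedIterableDataset


class JsonlLinesDataset(DistributedIterableDataset):
    def __init__(self, dataset_name, data_dir, local_rank=0, world_size=1, num_workers=8):
        super().__init__(dataset_name, local_rank, world_size, num_workers)
        self.data_paths = sorted(glob.glob(f"{data_dir}/*.jsonl"))
        self.set_epoch()

    def __iter__(self):
        data_paths_per_worker, worker_id = self.get_data_paths_per_worker()
        while True:  # loop forever, like the iterable datasets in this upload
            for path in data_paths_per_worker:
                with open(path) as f:
                    for line in f:
                        yield json.loads(line)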
data/interleave_datasets/__init__.py
ADDED
@@ -0,0 +1,5 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+from .edit_dataset import UnifiedEditIterableDataset
+
data/interleave_datasets/edit_dataset.py
ADDED
@@ -0,0 +1,72 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+import io
+import random
+from PIL import Image, ImageFile, PngImagePlugin
+
+from .interleave_t2i_dataset import InterleavedBaseIterableDataset, ParquetStandardIterableDataset
+from ..data_utils import pil_img2rgb
+
+
+Image.MAX_IMAGE_PIXELS = 200000000
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+MaximumDecompressedSize = 1024
+MegaByte = 2 ** 20
+PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
+
+
+class UnifiedEditIterableDataset(InterleavedBaseIterableDataset, ParquetStandardIterableDataset):
+
+    def parse_row(self, row):
+        image_num = len(row["image_list"])
+        # randomly choose start and end, return [0, 1] when only two images
+        start_idx = random.choice(range(image_num - 1))
+        max_end = min(start_idx + 3, image_num)
+        end_idx = random.choice(range(start_idx + 1, max_end))
+
+        data = self._init_data()
+        data = self._add_image(
+            data,
+            pil_img2rgb(Image.open(io.BytesIO(row["image_list"][start_idx]))),
+            need_loss=False,
+            need_vae=True,
+            need_vit=True,
+        )
+
+        if end_idx - start_idx > 1 and random.random() < 0.5:  # concat multiple instructions
+            if end_idx == image_num - 1:
+                end_idx -= 1
+
+            instruction = ""
+            for idx in range(start_idx + 1, end_idx + 1):
+                instruction += random.choice(row["instruction_list"][idx - 1]) + ". "
+            data = self._add_text(data, instruction.rstrip(), need_loss=False)
+            data = self._add_image(
+                data,
+                pil_img2rgb(Image.open(io.BytesIO(row["image_list"][end_idx]))),
+                need_loss=True,
+                need_vae=False,
+                need_vit=False,
+            )
+        else:
+            for idx in range(start_idx + 1, end_idx + 1):
+                instruction = random.choice(row["instruction_list"][idx - 1])
+                data = self._add_text(data, instruction, need_loss=False)
+                if idx != end_idx:
+                    data = self._add_image(
+                        data,
+                        pil_img2rgb(Image.open(io.BytesIO(row["image_list"][idx]))),
+                        need_loss=True,
+                        need_vae=True,
+                        need_vit=True,
+                    )
+                else:
+                    data = self._add_image(
+                        data,
+                        pil_img2rgb(Image.open(io.BytesIO(row["image_list"][idx]))),
+                        need_loss=True,
+                        need_vae=False,
+                        need_vit=False,
+                    )
+        return data
data/interleave_datasets/interleave_t2i_dataset.py
ADDED
@@ -0,0 +1,212 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+import pyarrow.parquet as pq
+
+from ..distributed_iterable_dataset import DistributedIterableDataset
+from ..parquet_utils import get_parquet_data_paths, init_arrow_pf_fs
+
+
+class InterleavedBaseIterableDataset(DistributedIterableDataset):
+
+    def _init_data(self):
+        data = {
+            'sequence_plan': [],
+            'text_ids_list': [],
+            'image_tensor_list': [],
+            'num_tokens': 0,
+        }
+        return data
+
+    def _add_text(self, data, text, need_loss, enable_cfg=True):
+        text_ids = self.tokenizer.encode(text)
+        data['num_tokens'] += len(text_ids)
+        data['text_ids_list'].append(text_ids)
+        data['sequence_plan'].append(
+            {
+                'type': 'text',
+                'enable_cfg': int(enable_cfg),
+                'loss': int(need_loss),
+                'special_token_loss': 0,
+                'special_token_label': None,
+            }
+        )
+        return data
+
+    def _add_image(self, data, image, need_loss, need_vae, need_vit, enable_cfg=True):
+        assert need_loss or need_vae or need_vit
+
+        if need_loss:
+            data['sequence_plan'].append(
+                {
+                    'type': 'vae_image',
+                    'enable_cfg': 0,
+                    'loss': 1,
+                    'special_token_loss': 0,
+                    'special_token_label': None,
+                }
+            )
+
+            image_tensor = self.transform(image)
+            height, width = image_tensor.shape[1:]
+            data['num_tokens'] += width * height // self.transform.stride ** 2
+            data['image_tensor_list'].append(image_tensor)
+
+        if need_vae:
+            data['sequence_plan'].append(
+                {
+                    'type': 'vae_image',
+                    'enable_cfg': int(enable_cfg),
+                    'loss': 0,
+                    'special_token_loss': 0,
+                    'special_token_label': None,
+                }
+            )
+
+            image_tensor = self.transform(image)
+            height, width = image_tensor.shape[1:]
+            data['num_tokens'] += width * height // self.transform.stride ** 2
+            data['image_tensor_list'].append(image_tensor.clone())
+
+        if need_vit:
+            data['sequence_plan'].append(
+                {
+                    'type': 'vit_image',
+                    'enable_cfg': int(enable_cfg),
+                    'loss': 0,
+                    'special_token_loss': 0,
+                    'special_token_label': None,
+                },
+            )
+            vit_image_tensor = self.vit_transform(image)
+            height, width = vit_image_tensor.shape[1:]
+            data['num_tokens'] += width * height // self.vit_transform.stride ** 2
+            data['image_tensor_list'].append(vit_image_tensor)
+
+        return data
+
+    def _add_video(self, data, frames, frame_indexes, need_loss, need_vae, enable_cfg=True):
+        assert int(need_loss) + int(need_vae) == 1
+
+        if need_loss:
+            for idx, (image, frame_idx) in enumerate(zip(frames, frame_indexes)):
+                current_sequence_plan = {
+                    'type': 'vae_image',
+                    'enable_cfg': 0,
+                    'loss': 1,
+                    'special_token_loss': 0,
+                    'special_token_label': None,
+                    'split_start': idx == 0,
+                    'split_end': idx == len(frames) - 1,
+                }
+                if idx < len(frame_indexes) - 1:
+                    current_sequence_plan['frame_delta'] = frame_indexes[idx + 1] - frame_idx
+                data['sequence_plan'].append(current_sequence_plan)
+                image_tensor = self.transform(image)
+                height, width = image_tensor.shape[1:]
+                data['image_tensor_list'].append(image_tensor)
+                data['num_tokens'] += width * height // self.transform.stride ** 2
+
+        elif need_vae:
+            for idx, (image, frame_idx) in enumerate(zip(frames, frame_indexes)):
+                current_sequence_plan = {
+                    'type': 'vae_image',
+                    'enable_cfg': int(enable_cfg),
+                    'loss': 0,
+                    'special_token_loss': 0,
+                    'special_token_label': None,
+                    'split_start': idx == 0,
+                    'split_end': idx == len(frames) - 1,
+                }
+                if idx < len(frame_indexes) - 1:
+                    current_sequence_plan['frame_delta'] = frame_indexes[idx + 1] - frame_idx
+                data['sequence_plan'].append(current_sequence_plan)
+                image_tensor = self.transform(image)
+                height, width = image_tensor.shape[1:]
+                data['image_tensor_list'].append(image_tensor)
+                data['num_tokens'] += width * height // self.transform.stride ** 2
+
+        return data
+
+
+class ParquetStandardIterableDataset(DistributedIterableDataset):
+
+    def __init__(
+        self, dataset_name, transform, tokenizer, vit_transform,
+        data_dir_list, num_used_data, parquet_info,
+        local_rank=0, world_size=1, num_workers=8, data_status=None,
+    ):
+        """
+        data_dir_list: list of data directories containing parquet files
+        num_used_data: list of numbers of sampled data paths for each data directory
+        vit_transform: input transform for the vit model.
+        """
+        super().__init__(dataset_name, local_rank, world_size, num_workers)
+        self.transform = transform
+        self.vit_transform = vit_transform
+        self.tokenizer = tokenizer
+        self.data_status = data_status
+        self.data_paths = self.get_data_paths(data_dir_list, num_used_data, parquet_info)
+        self.set_epoch()
+
+    def get_data_paths(self, data_dir_list, num_used_data, parquet_info):
+        row_groups = []
+        for data_dir, num_data_path in zip(data_dir_list, num_used_data):
+            data_paths = get_parquet_data_paths([data_dir], [num_data_path])
+            for data_path in data_paths:
+                if data_path in parquet_info.keys():
+                    num_row_groups = parquet_info[data_path]['num_row_groups']
+                    for rg_idx in range(num_row_groups):
+                        row_groups.append((data_path, rg_idx))
+        return row_groups
+
+    def parse_row(self, row):
+        raise NotImplementedError
+
+    def __iter__(self):
+        file_paths_per_worker, worker_id = self.get_data_paths_per_worker()
+        if self.data_status is not None:
+            global_row_group_start_id = self.data_status[worker_id][0]
+            row_start_id = self.data_status[worker_id][1] + 1
+        else:
+            global_row_group_start_id = 0
+            row_start_id = 0
+
+        print(
+            f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
+            f"resuming data at global_rg#{global_row_group_start_id}, row#{row_start_id}"
+        )
+
+        while True:
+            file_paths_per_worker_ = file_paths_per_worker[global_row_group_start_id:]
+            for global_row_group_idx, (parquet_file_path, row_group_id) in enumerate(
+                file_paths_per_worker_, start=global_row_group_start_id
+            ):
+                fs = init_arrow_pf_fs(parquet_file_path)
+                with fs.open_input_file(parquet_file_path) as f:
+                    try:
+                        fr = pq.ParquetFile(f)
+                        df = fr.read_row_group(row_group_id).to_pandas()
+                        df = df.iloc[row_start_id:]
+                    except Exception as e:
+                        print(f'Error {e} in rg#{row_group_id}, {parquet_file_path}')
+                        continue
+
+                    for row_idx, row in df.iterrows():
+                        try:
+                            data = self.parse_row(row)
+                            if len(data) == 0:
+                                continue
+                            data['data_indexes'] = {
+                                "data_indexes": [global_row_group_idx, row_idx],
+                                "worker_id": worker_id,
+                                "dataset_name": self.dataset_name,
+                            }
+                        except Exception as e:
+                            print(f'Error {e} in rg#{row_group_id}, {parquet_file_path}')
+                            continue
+                        yield data
+
+            row_start_id = 0
+            global_row_group_start_id = 0
+            print(f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}")
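`ParquetStandardIterableDataset` resumes at row-group granularity, so every parquet path must appear in `parquet_info` with its `num_row_groups` (this is what `parquet_info_path` in `data/dataset_info.py` points to). A hedged sketch of generating that JSON for local files; the glob path and output filename are placeholders, and only the `num_row_groups` field is read by the loader above.

# Hedged sketch: build the {path: {"num_row_groups": N}} mapping expected under parquet_info.
import glob
import json

import pyarrow.parquet as pq

parquet_info = {}
for path in glob.glob("your_data_path/bagel_example/editing/seedxedit_multi/*.parquet"):
    parquet_info[path] = {"num_row_groups": pq.ParquetFile(path).num_row_groups}

with open("seedxedit_multi_nas.json", "w") as f:
    json.dump(parquet_info, f, indent=2)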
data/parquet_utils.py
ADDED
@@ -0,0 +1,90 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0


import os
import xml.etree.ElementTree as ET
import subprocess
import logging

import pyarrow.fs as pf
import torch.distributed as dist

logger = logging.getLogger(__name__)


def get_parquet_data_paths(data_dir_list, num_sampled_data_paths, rank=0, world_size=1):
    num_data_dirs = len(data_dir_list)
    if world_size > 1:
        chunk_size = (num_data_dirs + world_size - 1) // world_size
        start_idx = rank * chunk_size
        end_idx = min(start_idx + chunk_size, num_data_dirs)
        local_data_dir_list = data_dir_list[start_idx:end_idx]
        local_num_sampled_data_paths = num_sampled_data_paths[start_idx:end_idx]
    else:
        local_data_dir_list = data_dir_list
        local_num_sampled_data_paths = num_sampled_data_paths

    local_data_paths = []
    for data_dir, num_data_path in zip(local_data_dir_list, local_num_sampled_data_paths):
        if data_dir.startswith("hdfs://"):
            files = hdfs_ls_cmd(data_dir)
            data_paths_per_dir = [
                file for file in files if file.endswith(".parquet")
            ]
        else:
            files = os.listdir(data_dir)
            data_paths_per_dir = [
                os.path.join(data_dir, name)
                for name in files
                if name.endswith(".parquet")
            ]
        repeat = num_data_path // len(data_paths_per_dir)
        data_paths_per_dir = data_paths_per_dir * (repeat + 1)
        local_data_paths.extend(data_paths_per_dir[:num_data_path])

    if world_size > 1:
        gather_list = [None] * world_size
        dist.all_gather_object(gather_list, local_data_paths)

        combined_chunks = []
        for chunk_list in gather_list:
            if chunk_list is not None:
                combined_chunks.extend(chunk_list)
    else:
        combined_chunks = local_data_paths

    return combined_chunks


# NOTE: customize this function for your cluster
def get_hdfs_host():
    return "hdfs://xxx"


# NOTE: customize this function for your cluster
def get_hdfs_block_size():
    return 134217728  # 128 MB


# NOTE: customize this function for your cluster
def get_hdfs_extra_conf():
    return None


def init_arrow_pf_fs(parquet_file_path):
    if parquet_file_path.startswith("hdfs://"):
        fs = pf.HadoopFileSystem(
            host=get_hdfs_host(),
            port=0,
            buffer_size=get_hdfs_block_size(),
            extra_conf=get_hdfs_extra_conf(),
        )
    else:
        fs = pf.LocalFileSystem()
    return fs


def hdfs_ls_cmd(dir):
    result = subprocess.run(["hdfs", "dfs", "-ls", dir], capture_output=True, text=True).stdout
    return ['hdfs://' + i.split('hdfs://')[-1].strip() for i in result.split('\n') if 'hdfs://' in i]
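For local (non-HDFS) directories, the helpers above can be exercised directly. A small sketch, assuming a hypothetical ./parquet_shards directory holding *.parquet files:

import pyarrow.parquet as pq

# Sample 10 paths from a local directory; files are repeated if fewer than 10 exist.
paths = get_parquet_data_paths(["./parquet_shards"], [10])
fs = init_arrow_pf_fs(paths[0])               # LocalFileSystem, since the path is not hdfs://
with fs.open_input_file(paths[0]) as f:
    print(pq.ParquetFile(f).metadata.num_row_groups)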
data/t2i_dataset.py
ADDED
@@ -0,0 +1,128 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0

import io
import json
import pyarrow.parquet as pq
import random
from PIL import Image

from .data_utils import pil_img2rgb
from .distributed_iterable_dataset import DistributedIterableDataset
from .parquet_utils import get_parquet_data_paths, init_arrow_pf_fs

Image.MAX_IMAGE_PIXELS = 20_000_000


class T2IIterableDataset(DistributedIterableDataset):
    def __init__(
        self, dataset_name, transform, tokenizer, data_dir_list, num_used_data,
        local_rank=0, world_size=1, num_workers=8, data_status=None,
    ):
        """
        data_dir_list: list of data directories containing parquet files
        num_used_data: list of numbers of sampled data paths for each data directory
        """
        super().__init__(dataset_name, local_rank, world_size, num_workers)
        self.transform = transform
        self.tokenizer = tokenizer
        self.data_status = data_status
        self.data_paths = self.get_data_paths(data_dir_list, num_used_data)
        self.set_epoch()

    def get_data_paths(self, data_dir_list, num_used_data):
        return get_parquet_data_paths(data_dir_list, num_used_data)

    def __iter__(self):
        data_paths_per_worker, worker_id = self.get_data_paths_per_worker()
        if self.data_status is not None:
            parquet_start_id = self.data_status[worker_id][0]
            row_group_start_id = self.data_status[worker_id][1]
            row_start_id = self.data_status[worker_id][2] + 1
        else:
            parquet_start_id = 0
            row_group_start_id = 0
            row_start_id = 0
        transform_stride = self.transform.stride

        print(
            f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
            f"resuming data at parquet#{parquet_start_id}, rg#{row_group_start_id}, row#{row_start_id}"
        )

        while True:
            data_paths_per_worker_ = data_paths_per_worker[parquet_start_id:]
            for parquet_idx, parquet_file_path in enumerate(data_paths_per_worker_, start=parquet_start_id):
                fs = init_arrow_pf_fs(parquet_file_path)
                with fs.open_input_file(parquet_file_path) as f:
                    fr = pq.ParquetFile(f)
                    row_group_ids = list(range(fr.num_row_groups))
                    row_group_ids_ = row_group_ids[row_group_start_id:]

                    for row_group_id in row_group_ids_:
                        df = fr.read_row_group(row_group_id).to_pandas()
                        df = df.iloc[row_start_id:]

                        for row_idx, row in df.iterrows():
                            num_tokens = 0
                            try:
                                image_byte = row['image']
                                image = pil_img2rgb(Image.open(io.BytesIO(image_byte)))
                            except Exception as e:
                                print(f'Error: {e} in rg#{row_group_id}, {parquet_file_path}')
                                continue
                            image_tensor = self.transform(image)
                            height, width = image_tensor.shape[1:]
                            num_tokens += width * height // transform_stride ** 2

                            try:
                                caption_dict = row['captions']
                                caption_dict = json.loads(caption_dict)
                            except Exception as e:
                                print(f'Error: {e} in rg#{row_group_id}, {parquet_file_path}')
                                continue

                            caps_token = [self.tokenizer.encode(v) for _, v in caption_dict.items()]
                            if len(caps_token) == 0:
                                print(f'no caption in rg#{row_group_id}, {parquet_file_path}')
                                caption_token = self.tokenizer.encode(' ')
                            else:
                                caption_token = random.choice(caps_token)

                            sequence_plan, text_ids_list = [], []
                            text_ids = caption_token
                            num_tokens += len(caption_token)
                            text_ids_list.append(text_ids)
                            sequence_plan.append({
                                'type': 'text',
                                'enable_cfg': 1,
                                'loss': 0,
                                'special_token_loss': 0,
                                'special_token_label': None,
                            })

                            sequence_plan.append({
                                'type': 'vae_image',
                                'enable_cfg': 0,
                                'loss': 1,
                                'special_token_loss': 0,
                                'special_token_label': None,
                            })

                            sample = dict(
                                image_tensor_list=[image_tensor],
                                text_ids_list=text_ids_list,
                                num_tokens=num_tokens,
                                sequence_plan=sequence_plan,
                                data_indexes={
                                    "data_indexes": [parquet_idx, row_group_id, row_idx],
                                    "worker_id": worker_id,
                                    "dataset_name": self.dataset_name,
                                }
                            )
                            yield sample

                        row_start_id = 0
                    row_group_start_id = 0
            parquet_start_id = 0
            print(f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}")
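A sketch of wiring T2IIterableDataset together; the transform comes from data/transforms.py, while the tokenizer, directory, and size settings below are placeholders rather than values prescribed by this upload:

from data.transforms import ImageTransform
from data.t2i_dataset import T2IIterableDataset

transform = ImageTransform(max_image_size=1024, min_image_size=512, image_stride=16)
dataset = T2IIterableDataset(
    dataset_name="t2i",
    transform=transform,
    tokenizer=my_tokenizer,           # placeholder: any tokenizer exposing .encode(str) -> list[int]
    data_dir_list=["./t2i_parquet"],  # placeholder directory of parquet shards
    num_used_data=[10],
)
sample = next(iter(dataset))          # dict with image_tensor_list, text_ids_list, sequence_plan, num_tokens, data_indexes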
data/transforms.py
ADDED
@@ -0,0 +1,287 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0

import random
from PIL import Image

import cv2
import numpy as np
import torch
from torchvision import transforms
from torchvision.transforms import functional as F
from torchvision.transforms import InterpolationMode


class MaxLongEdgeMinShortEdgeResize(torch.nn.Module):
    """Resize the input image so that its longest side and shortest side are within a specified range,
    ensuring that both sides are divisible by a specified stride.

    Args:
        max_size (int): Maximum size for the longest edge of the image.
        min_size (int): Minimum size for the shortest edge of the image.
        stride (int): Value by which the height and width of the image must be divisible.
        max_pixels (int): Maximum number of pixels for the full image.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BICUBIC``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
            ``InterpolationMode.BILINEAR``, and ``InterpolationMode.BICUBIC`` are supported.
            The corresponding Pillow integer constants, e.g., ``PIL.Image.BILINEAR``, are also accepted.
        antialias (bool, optional): Whether to apply antialiasing (default is True).
    """

    def __init__(
        self,
        max_size: int,
        min_size: int,
        stride: int,
        max_pixels: int,
        interpolation=InterpolationMode.BICUBIC,
        antialias=True
    ):
        super().__init__()
        self.max_size = max_size
        self.min_size = min_size
        self.stride = stride
        self.max_pixels = max_pixels
        self.interpolation = interpolation
        self.antialias = antialias

    def _make_divisible(self, value, stride):
        """Ensure the value is divisible by the stride."""
        return max(stride, int(round(value / stride) * stride))

    def _apply_scale(self, width, height, scale):
        new_width = round(width * scale)
        new_height = round(height * scale)
        new_width = self._make_divisible(new_width, self.stride)
        new_height = self._make_divisible(new_height, self.stride)
        return new_width, new_height

    def forward(self, img, img_num=1):
        """
        Args:
            img (PIL Image): Image to be resized.
            img_num (int): Number of images; the max_pixels budget is divided by this value.
        Returns:
            PIL Image or Tensor: Rescaled image with divisible dimensions.
        """
        if isinstance(img, torch.Tensor):
            height, width = img.shape[-2:]
        else:
            width, height = img.size

        scale = min(self.max_size / max(width, height), 1.0)
        scale = max(scale, self.min_size / min(width, height))
        new_width, new_height = self._apply_scale(width, height, scale)

        # Ensure the number of pixels does not exceed max_pixels
        if new_width * new_height > self.max_pixels / img_num:
            scale = self.max_pixels / img_num / (new_width * new_height)
            new_width, new_height = self._apply_scale(new_width, new_height, scale)

        # Ensure the longest edge does not exceed max_size
        if max(new_width, new_height) > self.max_size:
            scale = self.max_size / max(new_width, new_height)
            new_width, new_height = self._apply_scale(new_width, new_height, scale)

        return F.resize(img, (new_height, new_width), self.interpolation, antialias=self.antialias)


class ImageTransform:
    def __init__(
        self,
        max_image_size,
        min_image_size,
        image_stride,
        max_pixels=14*14*9*1024,
        image_mean=[0.5, 0.5, 0.5],
        image_std=[0.5, 0.5, 0.5]
    ):
        self.stride = image_stride

        self.resize_transform = MaxLongEdgeMinShortEdgeResize(
            max_size=max_image_size,
            min_size=min_image_size,
            stride=image_stride,
            max_pixels=max_pixels,
        )
        self.to_tensor_transform = transforms.ToTensor()
        self.normalize_transform = transforms.Normalize(mean=image_mean, std=image_std, inplace=True)

    def __call__(self, img, img_num=1):
        img = self.resize_transform(img, img_num=img_num)
        img = self.to_tensor_transform(img)
        img = self.normalize_transform(img)
        return img


def decolorization(image):
    gray_image = image.convert('L')
    return Image.merge(image.mode, [gray_image] * 3) if image.mode in ('RGB', 'L') else gray_image


def downscale(image, scale_factor):
    new_width = int(round(image.width * scale_factor))
    new_height = int(round(image.height * scale_factor))
    new_width = max(1, new_width)
    new_height = max(1, new_height)
    return image.resize((new_width, new_height), resample=Image.BICUBIC)


def crop(image, crop_factors):
    target_h, target_w = crop_factors
    img_w, img_h = image.size

    if target_h > img_h or target_w > img_w:
        raise ValueError("Crop size exceeds image dimensions")

    x = random.randint(0, img_w - target_w)
    y = random.randint(0, img_h - target_h)

    return image.crop((x, y, x + target_w, y + target_h)), [[x, y], [x + target_w, y + target_h]]


def motion_blur_opencv(image, kernel_size=15, angle=0):
    # Linear kernel
    kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
    kernel[kernel_size // 2, :] = np.ones(kernel_size, dtype=np.float32)

    # Rotate the kernel
    center = (kernel_size / 2 - 0.5, kernel_size / 2 - 0.5)
    M = cv2.getRotationMatrix2D(center, angle, 1)
    rotated_kernel = cv2.warpAffine(kernel, M, (kernel_size, kernel_size))

    # Normalize the kernel
    rotated_kernel /= rotated_kernel.sum() if rotated_kernel.sum() != 0 else 1

    img = np.array(image)
    if img.ndim == 2:
        blurred = cv2.filter2D(img, -1, rotated_kernel, borderType=cv2.BORDER_REFLECT)
    else:
        # For color images, filter each channel independently
        blurred = np.zeros_like(img)
        for c in range(img.shape[2]):
            blurred[..., c] = cv2.filter2D(img[..., c], -1, rotated_kernel, borderType=cv2.BORDER_REFLECT)

    return Image.fromarray(blurred.astype(np.uint8))


def shuffle_patch(image, num_splits, gap_size=2):
    """Split the image into patches (sizes need not divide evenly), shuffle them,
    and reassemble them with gaps between patches."""
    h_splits, w_splits = num_splits
    img_w, img_h = image.size

    base_patch_h = img_h // h_splits
    patch_heights = [base_patch_h] * (h_splits - 1)
    patch_heights.append(img_h - sum(patch_heights))

    base_patch_w = img_w // w_splits
    patch_widths = [base_patch_w] * (w_splits - 1)
    patch_widths.append(img_w - sum(patch_widths))

    patches = []
    current_y = 0
    for i in range(h_splits):
        current_x = 0
        patch_h = patch_heights[i]
        for j in range(w_splits):
            patch_w = patch_widths[j]
            patch = image.crop((current_x, current_y, current_x + patch_w, current_y + patch_h))
            patches.append(patch)
            current_x += patch_w
        current_y += patch_h

    random.shuffle(patches)

    total_width = sum(patch_widths) + (w_splits - 1) * gap_size
    total_height = sum(patch_heights) + (h_splits - 1) * gap_size
    new_image = Image.new(image.mode, (total_width, total_height), color=(255, 255, 255))

    current_y = 0  # y coordinate where the current row starts
    patch_idx = 0  # index of the patch being placed
    for i in range(h_splits):
        current_x = 0  # x coordinate where the current column starts
        patch_h = patch_heights[i]  # height of the patches in the current row
        for j in range(w_splits):
            # Take the next shuffled patch
            patch = patches[patch_idx]
            patch_w = patch_widths[j]  # width of the patch in the current column
            # Paste the patch with its top-left corner at (current_x, current_y)
            new_image.paste(patch, (current_x, current_y))
            # Advance x: the next patch starts after the current patch width plus the gap
            current_x += patch_w + gap_size
            patch_idx += 1
        # Advance y: the next row starts after the current row height plus the gap
        current_y += patch_h + gap_size

    return new_image


def inpainting(image, num_splits, blank_ratio=0.3, blank_color=(255, 255, 255)):
    """
    Split the image into patches and blank out a random subset, for inpainting-style tasks.

    Args:
        image: PIL.Image input image (RGB mode)
        num_splits: (h_splits, w_splits), number of splits along the vertical and horizontal directions
        blank_ratio: float, fraction of patches to blank out (0-1)
        blank_color: tuple, RGB color of the blanked regions (e.g., white (255, 255, 255))

    Returns:
        PIL.Image, the reassembled image after processing
    """
    h_splits, w_splits = num_splits
    img_w, img_h = image.size

    base_patch_h = img_h // h_splits
    patch_heights = [base_patch_h] * (h_splits - 1)
    patch_heights.append(img_h - sum(patch_heights))

    base_patch_w = img_w // w_splits
    patch_widths = [base_patch_w] * (w_splits - 1)
    patch_widths.append(img_w - sum(patch_widths))

    patches = []
    current_y = 0
    for i in range(h_splits):
        current_x = 0
        patch_h = patch_heights[i]
        for j in range(w_splits):
            patch_w = patch_widths[j]
            patch = image.crop((current_x, current_y, current_x + patch_w, current_y + patch_h))
            patches.append(patch)
            current_x += patch_w
        current_y += patch_h

    total_patches = h_splits * w_splits
    num_blank = int(total_patches * blank_ratio)
    num_blank = max(0, min(num_blank, total_patches))
    blank_indices = random.sample(range(total_patches), num_blank)

    processed_patches = []
    for idx, patch in enumerate(patches):
        if idx in blank_indices:
            blank_patch = Image.new("RGB", patch.size, color=blank_color)
            processed_patches.append(blank_patch)
        else:
            processed_patches.append(patch)

    # Build the result image (same size as the original)
    result_image = Image.new("RGB", (img_w, img_h))
    current_y = 0
    patch_idx = 0
    for i in range(h_splits):
        current_x = 0
        patch_h = patch_heights[i]
        for j in range(w_splits):
            # Take the processed patch
            patch = processed_patches[patch_idx]
            patch_w = patch_widths[j]
            # Paste it back at its original position
            result_image.paste(patch, (current_x, current_y))
            current_x += patch_w
            patch_idx += 1
        current_y += patch_h

    return result_image
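A short sketch of the resize behaviour defined above, using an arbitrary input size; both output sides come out divisible by the stride and capped by max_image_size and the max_pixels budget:

from PIL import Image

transform = ImageTransform(max_image_size=1024, min_image_size=512, image_stride=16)
img = Image.new("RGB", (1333, 777))   # arbitrary illustrative size (width, height)
tensor = transform(img)               # resized, converted to a tensor, normalized with mean/std 0.5
print(tensor.shape)                   # torch.Size([3, 592, 1024]): both sides divisible by 16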
data/video_utils.py
ADDED
@@ -0,0 +1,165 @@
# Copyright (c) 2023 OpenGVLab
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: MIT
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
#
# Original file was released under MIT, with the full license text
# available at https://github.com/OpenGVLab/InternVL/blob/main/LICENSE.
#
# This modified file is released under the same license.


import io
import os
import random
import re

import numpy as np
import decord
from PIL import Image


def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
    if sample in ['rand', 'middle']:  # uniform sampling
        acc_samples = min(num_frames, vlen)
        # split the video into `acc_samples` intervals, and sample from each interval.
        intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
        ranges = []
        for idx, interv in enumerate(intervals[:-1]):
            ranges.append((interv, intervals[idx + 1] - 1))
        if sample == 'rand':
            try:
                frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
            except:
                frame_indices = np.random.permutation(vlen)[:acc_samples]
                frame_indices.sort()
                frame_indices = list(frame_indices)
        elif fix_start is not None:
            frame_indices = [x[0] + fix_start for x in ranges]
        elif sample == 'middle':
            frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
        else:
            raise NotImplementedError

        if len(frame_indices) < num_frames:  # padded with last frame
            padded_frame_indices = [frame_indices[-1]] * num_frames
            padded_frame_indices[:len(frame_indices)] = frame_indices
            frame_indices = padded_frame_indices
    elif 'fps' in sample:  # fps0.5, sequentially sample frames at 0.5 fps
        output_fps = float(sample[3:])
        duration = float(vlen) / input_fps
        delta = 1 / output_fps  # gap between frames, this is also the clip length each frame represents
        frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
        frame_indices = np.around(frame_seconds * input_fps).astype(int)
        frame_indices = [e for e in frame_indices if e < vlen]
        if max_num_frames > 0 and len(frame_indices) > max_num_frames:
            frame_indices = frame_indices[:max_num_frames]
    else:
        raise ValueError
    return frame_indices


def read_frames_decord(video_path, num_frames, sample='rand', fix_start=None, clip=None, min_num_frames=4):
    video_reader = decord.VideoReader(video_path, num_threads=1)
    vlen = len(video_reader)
    fps = video_reader.get_avg_fps()
    duration = vlen / float(fps)
    if clip:
        start, end = clip
        duration = end - start
        vlen = int(duration * fps)
        start_index = int(start * fps)

    t_num_frames = np.random.randint(min_num_frames, num_frames + 1)

    frame_indices = get_frame_indices(
        t_num_frames, vlen, sample=sample, fix_start=fix_start,
        input_fps=fps
    )
    if clip:
        frame_indices = [f + start_index for f in frame_indices]
    frames = video_reader.get_batch(frame_indices).asnumpy()  # (T, H, W, C), np.uint8
    frames = [Image.fromarray(frames[i]) for i in range(frames.shape[0])]
    return frames


def extract_frame_number(filename):
    # Extract the numeric part from the filename using regular expressions
    match = re.search(r'_(\d+).jpg$', filename)
    return int(match.group(1)) if match else -1


def sort_frames(frame_paths):
    # Extract filenames from each path and sort by their numeric part
    return sorted(frame_paths, key=lambda x: extract_frame_number(os.path.basename(x)))


def read_frames_folder(video_path, num_frames, sample='rand', fix_start=None, min_num_frames=4):
    image_list = sort_frames(list(os.listdir(video_path)))
    frames = []
    for image in image_list:
        fp = os.path.join(video_path, image)
        frame = Image.open(fp).convert('RGB')
        frames.append(frame)
    vlen = len(frames)

    t_num_frames = np.random.randint(min_num_frames, num_frames + 1)

    if vlen > t_num_frames:
        frame_indices = get_frame_indices(
            t_num_frames, vlen, sample=sample, fix_start=fix_start
        )
        frames = [frames[i] for i in frame_indices]
    return frames


class FrameSampler:
    def __init__(self, max_num_frames=-1, min_num_frames=8, sample='rand'):
        self.max_num_frames = max_num_frames
        self.min_num_frames = min_num_frames
        self.sample = sample

    def __call__(self, file_name):
        fn = read_frames_folder if file_name.endswith('/') else read_frames_decord
        frames = fn(file_name, num_frames=self.max_num_frames, min_num_frames=self.min_num_frames, sample=self.sample)
        return frames


def decode_video_byte(video_bytes):
    video_stream = io.BytesIO(video_bytes)
    vr = decord.VideoReader(video_stream)
    return vr


def sample_mp4_frames(mp4_p, n_frames=None, fps=None, return_frame_indices=False, random_sample=False):
    if isinstance(mp4_p, str):
        vr = decord.VideoReader(mp4_p, num_threads=1)
    elif isinstance(mp4_p, decord.video_reader.VideoReader):
        vr = mp4_p
    video_fps = vr.get_avg_fps()  # frame rate of the video
    video_duration = len(vr) / video_fps
    if n_frames is not None:
        if random_sample:
            frame_indices = sorted(random.sample(range(len(vr)), n_frames))
        else:
            frame_indices = np.linspace(0, len(vr)-1, n_frames, dtype=int).tolist()
    else:
        frame_indices = [int(i) for i in np.arange(0, len(vr)-1, video_fps/fps)]
    frames = vr.get_batch(frame_indices).asnumpy()  # convert to a numpy array
    frames = [Image.fromarray(frame).convert("RGB") for frame in frames]
    if not return_frame_indices:
        return frames, video_duration
    else:
        return frames, video_duration, frame_indices


def sample_mp4_frames_by_indices(mp4_p, frame_indices: list):
    if isinstance(mp4_p, str):
        vr = decord.VideoReader(mp4_p, num_threads=1)
    elif isinstance(mp4_p, decord.video_reader.VideoReader):
        vr = mp4_p
    # sample the frames in frame_indices
    frames = vr.get_batch(frame_indices).asnumpy()  # convert to a numpy array
    frames = [Image.fromarray(frame).convert("RGB") for frame in frames]
    return frames
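A small sketch of the sampler's dispatch rule: a path ending in '/' is read as a directory of pre-extracted frames, anything else is decoded with decord; both paths below are hypothetical:

sampler = FrameSampler(max_num_frames=12, min_num_frames=8, sample='rand')
clip_frames = sampler("./videos/clip0001.mp4")   # decord path: 8-12 randomly sampled PIL frames
folder_frames = sampler("./frames/clip0001/")    # folder path: loads frames named like xxx_0001.jpg, then subsamples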
data/vlm_dataset.py
ADDED
@@ -0,0 +1,195 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0

import json
import os
import traceback
from PIL import Image, ImageFile, PngImagePlugin

from .data_utils import pil_img2rgb
from .distributed_iterable_dataset import DistributedIterableDataset


Image.MAX_IMAGE_PIXELS = 200000000
ImageFile.LOAD_TRUNCATED_IMAGES = True
MaximumDecompressedSize = 1024
MegaByte = 2 ** 20
PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte


class SftJSONLIterableDataset(DistributedIterableDataset):
    def __init__(
        self, dataset_name, transform, tokenizer, frame_sampler,
        jsonl_path_list, data_dir_list, num_used_data,
        local_rank=0, world_size=1, num_workers=8, data_status=None,
        shuffle_lines=False, shuffle_seed=0,
    ):
        """
        jsonl_path_list: list of jsonl file paths
        data_dir_list: list of image directories containing the images of each jsonl file
        num_used_data: list of numbers of sampled data points for each jsonl
        """
        super().__init__(dataset_name, local_rank, world_size, num_workers)
        self.transform = transform
        self.tokenizer = tokenizer
        self.frame_sampler = frame_sampler
        self.data_status = data_status
        self.data_paths = self.get_data_paths(
            jsonl_path_list,
            data_dir_list,
            num_used_data,
            shuffle_lines,
            shuffle_seed,
        )
        self.set_epoch()

    def get_data_paths(
        self,
        jsonl_path_list,
        data_dir_list,
        num_used_data,
        shuffle_lines,
        shuffle_seed,
    ):
        data_paths = []
        for jsonl_path, image_dir, num_data_point in zip(
            jsonl_path_list, data_dir_list, num_used_data
        ):
            with open(jsonl_path, 'r') as f:
                raw_data = f.readlines()
            if shuffle_lines:
                self.rng.seed(shuffle_seed)
                self.rng.shuffle(raw_data)
            raw_data = raw_data[:num_data_point]
            data_paths.extend([(json_data, image_dir) for json_data in raw_data])
        return data_paths

    def change_format(self, data, num_images):
        elements = []
        for conversation in data['conversations']:
            if conversation['from'] == 'human':
                if '<image>' not in conversation['value']:
                    elements.append({
                        'type': 'text',
                        'has_loss': 0,
                        'text': conversation['value'],
                    })
                else:
                    text_list = conversation['value'].split('<image>')
                    for idx, text in enumerate(text_list):
                        if text.strip() != '':
                            elements.append({
                                'type': 'text',
                                'has_loss': 0,
                                'text': text.strip(),
                            })
                        if (idx != len(text_list) - 1) and (idx < num_images):
                            elements.append({'type': 'image'})
            elif conversation['from'] == 'gpt':
                elements.append({
                    'type': 'text',
                    'has_loss': 1,
                    'text': conversation['value'],
                })
        return elements

    def __iter__(self):
        data_paths_per_worker, worker_id = self.get_data_paths_per_worker()
        if self.data_status is not None:
            row_start_id = self.data_status[worker_id] + 1
        else:
            row_start_id = 0
        transform_stride = self.transform.stride

        print(
            f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
            f"resuming data at row#{row_start_id}"
        )

        while True:
            data_paths_per_worker_ = data_paths_per_worker[row_start_id:]
            for row_idx, (data, image_dir) in enumerate(data_paths_per_worker_, start=row_start_id):
                num_tokens = 0
                image_tensor_list = []
                text_ids_list = []
                sequence_plan = []

                try:
                    data_item = json.loads(data)
                    raw_images = None
                    if 'image' in data_item:
                        if type(data_item['image']) == list:
                            raw_images = [
                                pil_img2rgb(Image.open(os.path.join(image_dir, image)))
                                for image in data_item['image']
                            ]
                        else:
                            raw_images = [
                                pil_img2rgb(Image.open(os.path.join(image_dir, data_item['image'])))
                            ]
                    elif 'video' in data_item:
                        raw_images = self.frame_sampler(os.path.join(image_dir, data_item['video']))
                        special_tokens = '<image>' * len(raw_images)
                        for item in data_item['conversations']:
                            if '<video>' in item['value']:
                                item['value'] = item['value'].replace('<video>', special_tokens)
                                break
                        else:
                            raise ValueError("Cannot find <video> in the conversation!")
                except:
                    traceback.print_exc()
                    continue

                if raw_images:
                    for raw_image in raw_images:
                        image_tensor = self.transform(raw_image, img_num=len(raw_images))
                        image_tensor_list.append(image_tensor)
                        height, width = image_tensor.shape[1:]
                        num_tokens += width * height // transform_stride ** 2

                elements = self.change_format(data_item, len(image_tensor_list))

                for item in elements:
                    if item['type'] == 'text':
                        text_data = item['text']
                        text_ids = self.tokenizer.encode(text_data)
                        if len(text_ids) > 0:
                            text_ids_list.append(text_ids)
                            num_tokens += len(text_ids)
                            current_plan = {
                                'type': 'text',
                                'enable_cfg': 0,
                                'loss': item['has_loss'],
                                'special_token_loss': 0,
                                'special_token_label': None,
                            }
                            sequence_plan.append(current_plan)
                    elif item['type'] == 'image':
                        current_plan = {
                            'type': 'vit_image',
                            'enable_cfg': 0,
                            'loss': 0,
                            'special_token_loss': 0,
                            'special_token_label': None,
                        }
                        sequence_plan.append(current_plan)

                has_loss = [item['loss'] for item in sequence_plan]
                if sum(has_loss) == 0:
                    print('No loss defined, skipped.')
                    continue

                yield dict(
                    image_tensor_list=image_tensor_list,
                    text_ids_list=text_ids_list,
                    sequence_plan=sequence_plan,
                    num_tokens=num_tokens,
                    data_indexes={
                        "data_indexes": row_idx,
                        "worker_id": worker_id,
                        "dataset_name": self.dataset_name,
                    }
                )

            row_start_id = 0
            print(f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}")
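The JSONL layout SftJSONLIterableDataset expects can be read off change_format above: one record per line, human/gpt turns under 'conversations', '<image>' markers in human turns, and an 'image' (or 'video') field resolved against the matching data_dir entry. A hedged illustration with made-up file names and text:

import json

record = {
    "image": ["cat_001.jpg"],          # resolved as os.path.join(image_dir, "cat_001.jpg")
    "conversations": [
        {"from": "human", "value": "<image>\nWhat animal is shown?"},
        {"from": "gpt", "value": "A cat sitting on a windowsill."},
    ],
}
jsonl_line = json.dumps(record)        # each line of the jsonl file is one such record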