KingNish commited on
Commit
6ceb788
·
verified ·
1 Parent(s): 0d897bc

Delete data

Browse files
data/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
- # SPDX-License-Identifier: Apache-2.0
 
 
 
data/configs/example.yaml DELETED
@@ -1,45 +0,0 @@
1
- t2i_pretrain:
2
- dataset_names:
3
- - t2i
4
- image_transform_args:
5
- image_stride: 16
6
- max_image_size: 1024
7
- min_image_size: 512
8
- is_mandatory: true
9
- num_used_data: # The sum should be larger that NUM_GPUS x NUM_WORKERS
10
- - 10
11
- weight: 1
12
-
13
- unified_edit:
14
- dataset_names:
15
- - seedxedit_multi
16
- image_transform_args:
17
- image_stride: 16
18
- max_image_size: 1024
19
- min_image_size: 512
20
- vit_image_transform_args:
21
- image_stride: 14
22
- max_image_size: 518
23
- min_image_size: 224
24
- is_mandatory: false
25
- num_used_data:
26
- - 10
27
- weight: 1
28
-
29
- vlm_sft:
30
- dataset_names:
31
- - llava_ov
32
- image_transform_args:
33
- image_stride: 14
34
- max_image_size: 980
35
- min_image_size: 378
36
- max_pixels: 2_007_040
37
- frame_sampler_args:
38
- max_num_frames: 12
39
- min_num_frames: 8
40
- is_mandatory: true
41
- shuffle_lines: True
42
- shuffle_seed: 0
43
- num_used_data:
44
- - 1000
45
- weight: 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
data/data_utils.py DELETED
@@ -1,177 +0,0 @@
1
- # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
-
5
- import math
6
- import random
7
- from PIL import Image
8
-
9
- import torch
10
- from torch.nn.attention.flex_attention import or_masks, and_masks
11
-
12
-
13
- def create_sparse_mask(document_lens, split_lens, attn_modes, device):
14
- def causal_mask(b, h, q_idx, kv_idx):
15
- return q_idx >= kv_idx
16
-
17
- def full_and_noise_mask(b, h, q_idx, kv_idx):
18
- return (full_and_noise_seq_id[q_idx] == full_and_noise_seq_id[kv_idx]) & (full_and_noise_seq_id[q_idx] >= 0)
19
-
20
- def remove_noise_mask(b, h, q_idx, kv_idx):
21
- return (~((noise_seq_id[kv_idx] >= 0) & (noise_seq_id[q_idx] != noise_seq_id[kv_idx])))
22
-
23
- def sample_mask(b, h, q_idx, kv_idx):
24
- return document_id[q_idx] == document_id[kv_idx]
25
-
26
- full_and_noise_tmp = []
27
- noise_tmp = []
28
-
29
- for i, (length, model) in enumerate(zip(split_lens, attn_modes)):
30
- value = i if model in ['full', 'noise'] else -1
31
- full_and_noise_tmp.extend([value] * length)
32
- value_noise = i if model == 'noise' else -1
33
- noise_tmp.extend([value_noise] * length)
34
-
35
- full_and_noise_seq_id = torch.Tensor(full_and_noise_tmp).to(device)
36
- noise_seq_id = torch.Tensor(noise_tmp).to(device)
37
-
38
- document_id = torch.cat([torch.full((l,), i) for i, l in enumerate(document_lens, start=1)]).to(device)
39
-
40
- return and_masks(or_masks(causal_mask, full_and_noise_mask), remove_noise_mask, sample_mask)
41
-
42
-
43
- def patchify(image, patch_size):
44
- p = patch_size
45
- c, h, w = image.shape
46
- assert h % p == 0 and w % p == 0
47
- image = image.reshape(c, h // p, p, w // p, p)
48
- image = torch.einsum("chpwq->hwpqc", image)
49
- image = image.reshape(-1, p**2 * c)
50
- return image
51
-
52
-
53
- def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side):
54
- num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size
55
- coords_h = torch.arange(0, num_patches_h)
56
- coords_w = torch.arange(0, num_patches_w)
57
- pos_ids = (coords_h[:, None] * max_num_patches_per_side + coords_w).flatten()
58
- return pos_ids
59
-
60
-
61
- def get_flattened_position_ids_interpolate(img_h, img_w, patch_size, max_num_patches_per_side):
62
- num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size
63
- boundaries = torch.arange(1 / max_num_patches_per_side, 1.0, 1 / max_num_patches_per_side)
64
- fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / num_patches_h)
65
- fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / num_patches_w)
66
- bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
67
- bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
68
- pos_ids = (bucket_coords_h[:, None] * max_num_patches_per_side + bucket_coords_w).flatten()
69
- return pos_ids
70
-
71
-
72
- def prepare_attention_mask_per_sample(split_lens, attn_modes, device="cpu"):
73
- """
74
- nested_split_lens: A list of N lists of ints. Each int indicates the length of a split within
75
- a sample, where each sample contains multiple splits with different attn modes.
76
- nested_attn_modes: whether to use full attn in each split.
77
- """
78
- sample_len = sum(split_lens)
79
- attention_mask = torch.zeros((sample_len, sample_len), dtype=torch.bool, device=device)
80
-
81
- csum = 0
82
- for s, attn_mode in zip(split_lens, attn_modes):
83
- assert attn_mode in ['causal', 'full', 'noise']
84
- if attn_mode == "causal":
85
- attention_mask[csum:csum + s, csum:csum + s] = torch.ones((s, s), device=device).tril()
86
- attention_mask[csum:csum + s, :csum] = 1
87
- else:
88
- attention_mask[csum:csum + s, csum:csum + s] = torch.ones((s, s))
89
- attention_mask[csum:csum + s, :csum] = 1
90
- csum += s
91
-
92
- csum = 0
93
- for s, attn_mode in zip(split_lens, attn_modes):
94
- if attn_mode == "noise":
95
- attention_mask[:, csum : csum + s] = torch.zeros((sample_len, s))
96
- attention_mask[csum : csum + s, csum : csum + s] = torch.ones((s, s))
97
- csum += s
98
-
99
- attention_mask = torch.zeros_like(attention_mask, dtype=torch.float).masked_fill_(
100
- ~attention_mask, float("-inf")
101
- )
102
-
103
- return attention_mask
104
-
105
-
106
- def split_integer_exp_decay(S, ng_sample_decay=1.0):
107
- if ng_sample_decay == 1.0:
108
- N = random.randint(1, S)
109
- else:
110
- base = (1 - ng_sample_decay) / (1 - math.pow(ng_sample_decay, S))
111
- p = [base * math.pow(ng_sample_decay, i) for i in range(S)]
112
- N = random.choices(list(range(1, S + 1)), p, k=1)[0]
113
- cumsum = [0] + sorted(random.sample(range(1, S), N - 1)) + [S]
114
- result = [cumsum[i+1] - cumsum[i] for i in range(len(cumsum) - 1)]
115
- return result, cumsum
116
-
117
-
118
- def pil_img2rgb(image):
119
- if image.mode == "RGBA" or image.info.get("transparency", None) is not None:
120
- image = image.convert("RGBA")
121
- white = Image.new(mode="RGB", size=image.size, color=(255, 255, 255))
122
- white.paste(image, mask=image.split()[3])
123
- image = white
124
- else:
125
- image = image.convert("RGB")
126
-
127
- return image
128
-
129
-
130
- def add_special_tokens(tokenizer):
131
- all_special_tokens = []
132
- for k, v in tokenizer.special_tokens_map.items():
133
- if isinstance(v, str):
134
- all_special_tokens.append(v)
135
- elif isinstance(v, list):
136
- all_special_tokens += v
137
-
138
- new_tokens = []
139
-
140
- if '<|im_start|>' not in all_special_tokens:
141
- new_tokens.append('<|im_start|>')
142
-
143
- if '<|im_end|>' not in all_special_tokens:
144
- new_tokens.append('<|im_end|>')
145
-
146
- if '<|vision_start|>' not in all_special_tokens:
147
- new_tokens.append('<|vision_start|>')
148
-
149
- if '<|vision_end|>' not in all_special_tokens:
150
- new_tokens.append('<|vision_end|>')
151
-
152
- num_new_tokens = tokenizer.add_tokens(new_tokens)
153
- bos_token_id = tokenizer.convert_tokens_to_ids('<|im_start|>')
154
- eos_token_id = tokenizer.convert_tokens_to_ids('<|im_end|>')
155
- start_of_image = tokenizer.convert_tokens_to_ids('<|vision_start|>')
156
- end_of_image = tokenizer.convert_tokens_to_ids('<|vision_end|>')
157
-
158
- new_token_ids = dict(
159
- bos_token_id=bos_token_id,
160
- eos_token_id=eos_token_id,
161
- start_of_image=start_of_image,
162
- end_of_image=end_of_image,
163
- )
164
-
165
- return tokenizer, new_token_ids, num_new_tokens
166
-
167
-
168
- def len2weight(x, loss_reduction='square'):
169
- if x == 0:
170
- return x
171
- if loss_reduction == 'token':
172
- return 1
173
- if loss_reduction == 'sample':
174
- return 1 / x
175
- if loss_reduction == 'square':
176
- return 1 / (x ** 0.5)
177
- raise NotImplementedError(loss_reduction)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
data/dataset_base.py DELETED
@@ -1,620 +0,0 @@
1
- # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
-
5
- import random
6
- import json
7
-
8
- import numpy as np
9
- import torch
10
-
11
- from .data_utils import (
12
- get_flattened_position_ids_interpolate,
13
- get_flattened_position_ids_extrapolate,
14
- len2weight,
15
- patchify,
16
- prepare_attention_mask_per_sample,
17
- )
18
- from .dataset_info import DATASET_INFO, DATASET_REGISTRY
19
- from .transforms import ImageTransform
20
- from .video_utils import FrameSampler
21
-
22
-
23
- class DataConfig:
24
- def __init__(
25
- self,
26
- grouped_datasets,
27
- text_cond_dropout_prob=0.1,
28
- vit_cond_dropout_prob=0.4,
29
- vae_cond_dropout_prob=0.1,
30
- vae_image_downsample=16,
31
- max_latent_size=32,
32
- vit_patch_size=14,
33
- max_num_patch_per_side=70,
34
- ):
35
- self.grouped_datasets = grouped_datasets
36
- self.text_cond_dropout_prob = text_cond_dropout_prob
37
- self.vit_cond_dropout_prob = vit_cond_dropout_prob
38
- self.vit_patch_size = vit_patch_size
39
- self.max_num_patch_per_side = max_num_patch_per_side
40
- self.vae_cond_dropout_prob = vae_cond_dropout_prob
41
- self.vae_image_downsample = vae_image_downsample
42
- self.max_latent_size = max_latent_size
43
-
44
-
45
- class PackedDataset(torch.utils.data.IterableDataset):
46
- def __init__(
47
- self,
48
- data_config,
49
- tokenizer,
50
- special_tokens,
51
- local_rank,
52
- world_size,
53
- num_workers,
54
- expected_num_tokens=32768,
55
- max_num_tokens_per_sample=16384,
56
- max_num_tokens=36864,
57
- prefer_buffer_before=16384,
58
- max_buffer_size=50,
59
- interpolate_pos=False,
60
- use_flex=False,
61
- data_status=None,
62
- ):
63
- super().__init__()
64
- self.expected_num_tokens = expected_num_tokens
65
- self.max_num_tokens_per_sample = max_num_tokens_per_sample
66
- self.prefer_buffer_before = prefer_buffer_before
67
- self.max_num_tokens = max_num_tokens
68
- self.max_buffer_size = max_buffer_size
69
- self.tokenizer = tokenizer
70
- self.local_rank = local_rank
71
- self.world_size = world_size
72
- self.num_workers = num_workers
73
- self.use_flex = use_flex
74
- for k, v in special_tokens.items():
75
- setattr(self, k, v)
76
-
77
- grouped_datasets, is_mandatory, grouped_weights = self.build_datasets(
78
- data_config.grouped_datasets, data_status
79
- )
80
- self.grouped_datasets = grouped_datasets
81
- self.dataset_iters = [iter(dataset) for dataset in grouped_datasets]
82
- self.is_mandatory = is_mandatory
83
- self.grouped_weights = grouped_weights
84
- self.data_config = data_config
85
- self.interpolate_pos = interpolate_pos
86
- if self.interpolate_pos:
87
- self.get_flattened_position_ids = get_flattened_position_ids_interpolate
88
- else:
89
- self.get_flattened_position_ids = get_flattened_position_ids_extrapolate
90
-
91
- def build_datasets(self, datasets_metainfo, data_status):
92
- datasets = []
93
- is_mandatory = []
94
- grouped_weights = []
95
- for grouped_dataset_name, dataset_args in datasets_metainfo.items():
96
- is_mandatory.append(dataset_args.pop('is_mandatory', False))
97
- grouped_weights.append(dataset_args.pop('weight', 0.0))
98
-
99
- if 'frame_sampler_args' in dataset_args.keys():
100
- frame_sampler = FrameSampler(**dataset_args.pop('frame_sampler_args'))
101
- dataset_args['frame_sampler'] = frame_sampler
102
- if 'image_transform_args' in dataset_args.keys():
103
- transform = ImageTransform(**dataset_args.pop('image_transform_args'))
104
- dataset_args['transform'] = transform
105
- if 'vit_image_transform_args' in dataset_args.keys():
106
- vit_transform = ImageTransform(**dataset_args.pop('vit_image_transform_args'))
107
- dataset_args['vit_transform'] = vit_transform
108
-
109
- assert 'dataset_names' in dataset_args.keys()
110
- dataset_names = dataset_args.pop('dataset_names')
111
- dataset_args['data_dir_list'] = []
112
- for item in dataset_names:
113
- if self.local_rank == 0:
114
- print(f'Preparing Dataset {grouped_dataset_name}/{item}')
115
- meta_info = DATASET_INFO[grouped_dataset_name][item]
116
- dataset_args['data_dir_list'].append(meta_info['data_dir'])
117
-
118
- if "parquet_info_path" in meta_info.keys():
119
- if 'parquet_info' not in dataset_args.keys():
120
- dataset_args['parquet_info'] = {}
121
- with open(meta_info['parquet_info_path'], 'r') as f:
122
- parquet_info = json.load(f)
123
- dataset_args['parquet_info'].update(parquet_info)
124
-
125
- if 'json_dir' in meta_info.keys():
126
- # parquet/tar with json
127
- if 'json_dir_list' not in dataset_args.keys():
128
- dataset_args['json_dir_list'] = [meta_info['json_dir']]
129
- else:
130
- dataset_args['json_dir_list'].append(meta_info['json_dir'])
131
-
132
- if 'jsonl_path' in meta_info.keys():
133
- # jsonl with jpeg
134
- if 'jsonl_path_list' not in dataset_args.keys():
135
- dataset_args['jsonl_path_list'] = [meta_info['jsonl_path']]
136
- else:
137
- dataset_args['jsonl_path_list'].append(meta_info['jsonl_path'])
138
-
139
- resume_data_status = dataset_args.pop('resume_data_status', True)
140
- if data_status is not None and grouped_dataset_name in data_status.keys() and resume_data_status:
141
- data_status_per_group = data_status[grouped_dataset_name]
142
- else:
143
- data_status_per_group = None
144
- dataset = DATASET_REGISTRY[grouped_dataset_name](
145
- dataset_name=grouped_dataset_name,
146
- tokenizer=self.tokenizer,
147
- local_rank=self.local_rank,
148
- world_size=self.world_size,
149
- num_workers=self.num_workers,
150
- data_status=data_status_per_group,
151
- **dataset_args
152
- )
153
- datasets.append(dataset)
154
-
155
- return datasets, is_mandatory, grouped_weights
156
-
157
- def set_epoch(self, seed):
158
- for dataset in self.grouped_datasets:
159
- dataset.set_epoch(seed)
160
-
161
- def set_sequence_status(self):
162
- sequence_status = dict(
163
- curr = 0,
164
- sample_lens = list(),
165
- packed_position_ids = list(),
166
- nested_attention_masks = list(),
167
- split_lens = list(),
168
- attn_modes = list(),
169
- packed_text_ids = list(),
170
- packed_text_indexes = list(),
171
- packed_label_ids = list(),
172
- ce_loss_indexes = list(),
173
- ce_loss_weights = list(),
174
- vae_image_tensors = list(),
175
- packed_latent_position_ids = list(),
176
- vae_latent_shapes = list(),
177
- packed_vae_token_indexes = list(),
178
- packed_timesteps = list(),
179
- mse_loss_indexes = list(),
180
- packed_vit_tokens = list(),
181
- vit_token_seqlens = list(),
182
- packed_vit_position_ids = list(),
183
- packed_vit_token_indexes = list(),
184
- )
185
- return sequence_status
186
-
187
- def to_tensor(self, sequence_status):
188
- data = dict(
189
- sequence_length=sum(sequence_status['sample_lens']),
190
- sample_lens=sequence_status['sample_lens'],
191
- packed_text_ids=torch.tensor(sequence_status['packed_text_ids']),
192
- packed_text_indexes=torch.tensor(sequence_status['packed_text_indexes']),
193
- packed_position_ids=torch.tensor(sequence_status['packed_position_ids']),
194
- )
195
- if not self.use_flex:
196
- data['nested_attention_masks'] = sequence_status['nested_attention_masks']
197
- else:
198
- sequence_len = data['sequence_length']
199
- pad_len = self.max_num_tokens - sequence_len
200
- data['split_lens'] = sequence_status['split_lens'] + [pad_len]
201
- data['attn_modes'] = sequence_status['attn_modes'] + ['causal']
202
- data['sample_lens'] += [pad_len]
203
-
204
- # if the model has a convnet vae (e.g., as visual tokenizer)
205
- if len(sequence_status['vae_image_tensors']) > 0:
206
- image_tensors = sequence_status.pop('vae_image_tensors')
207
- image_sizes = [item.shape for item in image_tensors]
208
- max_image_size = [max(item) for item in list(zip(*image_sizes))]
209
- padded_images = torch.zeros(size=(len(image_tensors), *max_image_size))
210
- for i, image_tensor in enumerate(image_tensors):
211
- padded_images[i, :, :image_tensor.shape[1], :image_tensor.shape[2]] = image_tensor
212
-
213
- data['padded_images'] = padded_images
214
- data['patchified_vae_latent_shapes'] = sequence_status['vae_latent_shapes']
215
- data['packed_latent_position_ids'] = torch.cat(sequence_status['packed_latent_position_ids'], dim=0)
216
- data['packed_vae_token_indexes'] = torch.tensor(sequence_status['packed_vae_token_indexes'])
217
-
218
- # if the model has a vit (e.g., as visual tokenizer)
219
- if len(sequence_status['packed_vit_tokens']) > 0:
220
- data['packed_vit_tokens'] = torch.cat(sequence_status['packed_vit_tokens'], dim=0)
221
- data['packed_vit_position_ids'] = torch.cat(sequence_status['packed_vit_position_ids'], dim=0)
222
- data['packed_vit_token_indexes'] = torch.tensor(sequence_status['packed_vit_token_indexes'])
223
- data['vit_token_seqlens'] = torch.tensor(sequence_status['vit_token_seqlens'])
224
-
225
- # if the model is required to perform visual generation
226
- if len(sequence_status['packed_timesteps']) > 0:
227
- data['packed_timesteps'] = torch.tensor(sequence_status['packed_timesteps'])
228
- data['mse_loss_indexes'] = torch.tensor(sequence_status['mse_loss_indexes'])
229
-
230
- # if the model is required to perform text generation
231
- if len(sequence_status['packed_label_ids']) > 0:
232
- data['packed_label_ids'] = torch.tensor(sequence_status['packed_label_ids'])
233
- data['ce_loss_indexes'] = torch.tensor(sequence_status['ce_loss_indexes'])
234
- data['ce_loss_weights'] = torch.tensor(sequence_status['ce_loss_weights'])
235
-
236
- return data
237
-
238
- def __iter__(self):
239
- total_weights = sum(self.grouped_weights)
240
- assert total_weights > 0.0
241
- group_cumprobs = [sum(self.grouped_weights[:i + 1]) / total_weights
242
- for i in range(len(self.grouped_weights))]
243
- sequence_status = self.set_sequence_status()
244
- batch_data_indexes = []
245
-
246
- buffer = []
247
- while True:
248
- # Ensure at least one sample from each group
249
- if sequence_status['curr'] == 0:
250
- for group_index, group_iter in enumerate(self.dataset_iters):
251
- if self.is_mandatory[group_index]:
252
- while True:
253
- sample = next(group_iter)
254
- # if a sample is too long, skip it
255
- num_tokens = sample['num_tokens'] + 2 * len(sample['sequence_plan'])
256
- if num_tokens < self.max_num_tokens_per_sample:
257
- sequence_status = self.pack_sequence(sample, sequence_status)
258
- batch_data_indexes.append(sample['data_indexes'])
259
- break
260
- else:
261
- print(f"skip a sample with length {num_tokens}")
262
- continue
263
-
264
- if sequence_status['curr'] < self.prefer_buffer_before and len(buffer) > 0:
265
- sample = buffer.pop(0)
266
- sample_from_buffer = True
267
- else:
268
- # sample normally across all groups
269
- n = random.random()
270
- group_index = 0
271
- for i, cumprob in enumerate(group_cumprobs):
272
- if n < cumprob:
273
- group_index = i
274
- break
275
- sample = next(self.dataset_iters[group_index])
276
- sample_from_buffer = False
277
-
278
- # if a sample is too long, skip it
279
- num_tokens = sample['num_tokens'] + 2 * len(sample['sequence_plan'])
280
- if num_tokens > self.max_num_tokens_per_sample:
281
- print(f"skip a sample with length {num_tokens}")
282
- continue
283
-
284
- if sequence_status['curr'] + num_tokens > self.max_num_tokens:
285
- if len(buffer) < self.max_buffer_size and not sample_from_buffer:
286
- buffer.append(sample)
287
- else:
288
- print(f"Yielding data with length {sum(sequence_status['sample_lens'])}")
289
- data = self.to_tensor(sequence_status)
290
- data['batch_data_indexes'] = batch_data_indexes
291
- yield data
292
- sequence_status = self.set_sequence_status()
293
- batch_data_indexes = []
294
- continue
295
-
296
- sequence_status = self.pack_sequence(sample, sequence_status)
297
- batch_data_indexes.append(sample['data_indexes'])
298
-
299
- if sequence_status['curr'] >= self.expected_num_tokens:
300
- data = self.to_tensor(sequence_status)
301
- data['batch_data_indexes'] = batch_data_indexes
302
- yield data
303
- sequence_status = self.set_sequence_status()
304
- batch_data_indexes = []
305
-
306
- def pack_sequence(self, sample, sequence_status):
307
- image_tensor_list = sample['image_tensor_list']
308
- text_ids_list = sample['text_ids_list']
309
- sequence_plan = sample['sequence_plan']
310
-
311
- split_lens, attn_modes = list(), list()
312
- curr = sequence_status['curr']
313
- curr_rope_id = 0
314
- sample_lens = 0
315
-
316
- for item in sequence_plan:
317
- split_start = item.get('split_start', True)
318
- if split_start:
319
- curr_split_len = 0
320
-
321
- if item['type'] == 'text':
322
- text_ids = text_ids_list.pop(0)
323
- if item['enable_cfg'] == 1 and random.random() < self.data_config.text_cond_dropout_prob:
324
- continue
325
-
326
- shifted_text_ids = [self.bos_token_id] + text_ids
327
- sequence_status['packed_text_ids'].extend(shifted_text_ids)
328
- sequence_status['packed_text_indexes'].extend(range(curr, curr + len(shifted_text_ids)))
329
- if item['loss'] == 1:
330
- sequence_status['ce_loss_indexes'].extend(range(curr, curr + len(shifted_text_ids)))
331
- sequence_status['ce_loss_weights'].extend(
332
- [len2weight(len(shifted_text_ids))] * len(shifted_text_ids)
333
- )
334
- sequence_status['packed_label_ids'].extend(text_ids + [self.eos_token_id])
335
- curr += len(shifted_text_ids)
336
- curr_split_len += len(shifted_text_ids)
337
-
338
- # add a <|im_end|> token
339
- sequence_status['packed_text_ids'].append(self.eos_token_id)
340
- sequence_status['packed_text_indexes'].append(curr)
341
- if item['special_token_loss'] == 1: # <|im_end|> may have loss
342
- sequence_status['ce_loss_indexes'].append(curr)
343
- sequence_status['ce_loss_weights'].append(1.0)
344
- sequence_status['packed_label_ids'].append(item['special_token_label'])
345
- curr += 1
346
- curr_split_len += 1
347
-
348
- # update sequence status
349
- attn_modes.append("causal")
350
- sequence_status['packed_position_ids'].extend(range(curr_rope_id, curr_rope_id + curr_split_len))
351
- curr_rope_id += curr_split_len
352
-
353
- elif item['type'] == 'vit_image':
354
- image_tensor = image_tensor_list.pop(0)
355
- if item['enable_cfg'] == 1 and random.random() < self.data_config.vit_cond_dropout_prob:
356
- curr_rope_id += 1
357
- continue
358
-
359
- # add a <|startofimage|> token
360
- sequence_status['packed_text_ids'].append(self.start_of_image)
361
- sequence_status['packed_text_indexes'].append(curr)
362
- curr += 1
363
- curr_split_len += 1
364
-
365
- # preprocess image
366
- vit_tokens = patchify(image_tensor, self.data_config.vit_patch_size)
367
- num_img_tokens = vit_tokens.shape[0]
368
- sequence_status['packed_vit_token_indexes'].extend(range(curr, curr + num_img_tokens))
369
- curr += num_img_tokens
370
- curr_split_len += num_img_tokens
371
-
372
- sequence_status['packed_vit_tokens'].append(vit_tokens)
373
- sequence_status['vit_token_seqlens'].append(num_img_tokens)
374
- sequence_status['packed_vit_position_ids'].append(
375
- self.get_flattened_position_ids(
376
- image_tensor.size(1), image_tensor.size(2),
377
- self.data_config.vit_patch_size,
378
- max_num_patches_per_side=self.data_config.max_num_patch_per_side
379
- )
380
- )
381
-
382
- # add a <|endofimage|> token
383
- sequence_status['packed_text_ids'].append(self.end_of_image)
384
- sequence_status['packed_text_indexes'].append(curr)
385
- if item['special_token_loss'] == 1: # <|endofimage|> may have loss
386
- sequence_status['ce_loss_indexes'].append(curr)
387
- sequence_status['ce_loss_weights'].append(1.0)
388
- sequence_status['packed_label_ids'].append(item['special_token_label'])
389
- curr += 1
390
- curr_split_len += 1
391
-
392
- # update sequence status
393
- attn_modes.append("full")
394
- sequence_status['packed_position_ids'].extend([curr_rope_id] * curr_split_len)
395
- curr_rope_id += 1
396
-
397
- elif item['type'] == 'vae_image':
398
- image_tensor = image_tensor_list.pop(0)
399
- if item['enable_cfg'] == 1 and random.random() < self.data_config.vae_cond_dropout_prob:
400
- # FIXME fix vae dropout in video2video setting.
401
- curr_rope_id += 1
402
- continue
403
-
404
- # add a <|startofimage|> token
405
- sequence_status['packed_text_ids'].append(self.start_of_image)
406
- sequence_status['packed_text_indexes'].append(curr)
407
- curr += 1
408
- curr_split_len += 1
409
-
410
- # preprocess image
411
- sequence_status['vae_image_tensors'].append(image_tensor)
412
- sequence_status['packed_latent_position_ids'].append(
413
- self.get_flattened_position_ids(
414
- image_tensor.size(1), image_tensor.size(2),
415
- self.data_config.vae_image_downsample,
416
- max_num_patches_per_side=self.data_config.max_latent_size
417
- )
418
- )
419
- H, W = image_tensor.shape[1:]
420
- h = H // self.data_config.vae_image_downsample
421
- w = W // self.data_config.vae_image_downsample
422
- sequence_status['vae_latent_shapes'].append((h, w))
423
-
424
- num_img_tokens = w * h
425
- sequence_status['packed_vae_token_indexes'].extend(range(curr, curr + num_img_tokens))
426
- if item['loss'] == 1:
427
- sequence_status['mse_loss_indexes'].extend(range(curr, curr + num_img_tokens))
428
- if split_start:
429
- timestep = np.random.randn()
430
- else:
431
- timestep = float('-inf')
432
-
433
- sequence_status['packed_timesteps'].extend([timestep] * num_img_tokens)
434
- curr += num_img_tokens
435
- curr_split_len += num_img_tokens
436
-
437
- # add a <|endofimage|> token
438
- sequence_status['packed_text_ids'].append(self.end_of_image)
439
- sequence_status['packed_text_indexes'].append(curr)
440
- # <|endofimage|> may have loss
441
- if item['special_token_loss'] == 1:
442
- sequence_status['ce_loss_indexes'].append(curr)
443
- sequence_status['ce_loss_weights'].append(1.0)
444
- sequence_status['packed_label_ids'].append(item['special_token_label'])
445
- curr += 1
446
- curr_split_len += 1
447
-
448
- # update sequence status
449
- if split_start:
450
- if item['loss'] == 1 and 'frame_delta' not in item.keys():
451
- attn_modes.append("noise")
452
- else:
453
- attn_modes.append("full")
454
- sequence_status['packed_position_ids'].extend([curr_rope_id] * (num_img_tokens + 2))
455
- if 'frame_delta' in item.keys():
456
- curr_rope_id += item['frame_delta']
457
- elif item['loss'] == 0:
458
- curr_rope_id += 1
459
-
460
- if item.get('split_end', True):
461
- split_lens.append(curr_split_len)
462
- sample_lens += curr_split_len
463
-
464
- sequence_status['curr'] = curr
465
- sequence_status['sample_lens'].append(sample_lens)
466
- # prepare attention mask
467
- if not self.use_flex:
468
- sequence_status['nested_attention_masks'].append(
469
- prepare_attention_mask_per_sample(split_lens, attn_modes)
470
- )
471
- else:
472
- sequence_status['split_lens'].extend(split_lens)
473
- sequence_status['attn_modes'].extend(attn_modes)
474
-
475
- return sequence_status
476
-
477
-
478
- class SimpleCustomBatch:
479
- def __init__(self, batch):
480
- data = batch[0]
481
- self.batch_data_indexes = data['batch_data_indexes']
482
- self.sequence_length = data["sequence_length"]
483
- self.sample_lens = data["sample_lens"]
484
- self.packed_text_ids = data["packed_text_ids"]
485
- self.packed_text_indexes = data["packed_text_indexes"]
486
- self.packed_position_ids = data["packed_position_ids"]
487
-
488
- self.use_flex = "nested_attention_masks" not in data.keys()
489
-
490
- if self.use_flex:
491
- self.split_lens = data["split_lens"]
492
- self.attn_modes = data["attn_modes"]
493
- else:
494
- self.nested_attention_masks = data["nested_attention_masks"]
495
-
496
- if "padded_images" in data.keys():
497
- self.padded_images = data["padded_images"]
498
- self.patchified_vae_latent_shapes = data["patchified_vae_latent_shapes"]
499
- self.packed_latent_position_ids = data["packed_latent_position_ids"]
500
- self.packed_vae_token_indexes = data["packed_vae_token_indexes"]
501
-
502
- if "packed_vit_tokens" in data.keys():
503
- self.packed_vit_tokens = data["packed_vit_tokens"]
504
- self.packed_vit_position_ids = data["packed_vit_position_ids"]
505
- self.packed_vit_token_indexes = data["packed_vit_token_indexes"]
506
- self.vit_token_seqlens = data["vit_token_seqlens"]
507
-
508
- if "packed_timesteps" in data.keys():
509
- self.packed_timesteps = data["packed_timesteps"]
510
- self.mse_loss_indexes = data["mse_loss_indexes"]
511
-
512
- if "packed_label_ids" in data.keys():
513
- self.packed_label_ids = data["packed_label_ids"]
514
- self.ce_loss_indexes = data["ce_loss_indexes"]
515
- self.ce_loss_weights = data["ce_loss_weights"]
516
-
517
- def pin_memory(self):
518
- self.packed_text_ids = self.packed_text_ids.pin_memory()
519
- self.packed_text_indexes = self.packed_text_indexes.pin_memory()
520
- self.packed_position_ids = self.packed_position_ids.pin_memory()
521
-
522
- if not self.use_flex:
523
- self.nested_attention_masks = [item.pin_memory() for item in self.nested_attention_masks]
524
-
525
- if hasattr(self, 'padded_images'):
526
- self.padded_images = self.padded_images.pin_memory()
527
- self.packed_vae_token_indexes = self.packed_vae_token_indexes.pin_memory()
528
- self.packed_latent_position_ids = self.packed_latent_position_ids.pin_memory()
529
-
530
- if hasattr(self, 'packed_timesteps'):
531
- self.packed_timesteps = self.packed_timesteps.pin_memory()
532
- self.mse_loss_indexes = self.mse_loss_indexes.pin_memory()
533
-
534
- if hasattr(self, 'packed_vit_tokens'):
535
- self.packed_vit_tokens = self.packed_vit_tokens.pin_memory()
536
- self.packed_vit_position_ids = self.packed_vit_position_ids.pin_memory()
537
- self.packed_vit_token_indexes = self.packed_vit_token_indexes.pin_memory()
538
- self.vit_token_seqlens = self.vit_token_seqlens.pin_memory()
539
-
540
- if hasattr(self, 'packed_label_ids'):
541
- self.packed_label_ids = self.packed_label_ids.pin_memory()
542
- self.ce_loss_indexes = self.ce_loss_indexes.pin_memory()
543
- self.ce_loss_weights = self.ce_loss_weights.pin_memory()
544
-
545
- return self
546
-
547
- def cuda(self, device):
548
- self.packed_text_ids = self.packed_text_ids.to(device)
549
- self.packed_text_indexes = self.packed_text_indexes.to(device)
550
- self.packed_position_ids = self.packed_position_ids.to(device)
551
-
552
- if not self.use_flex:
553
- self.nested_attention_masks = [item.to(device) for item in self.nested_attention_masks]
554
-
555
- if hasattr(self, 'padded_images'):
556
- self.padded_images = self.padded_images.to(device)
557
- self.packed_vae_token_indexes = self.packed_vae_token_indexes.to(device)
558
- self.packed_latent_position_ids = self.packed_latent_position_ids.to(device)
559
-
560
- if hasattr(self, 'packed_timesteps'):
561
- self.packed_timesteps = self.packed_timesteps.to(device)
562
- self.mse_loss_indexes = self.mse_loss_indexes.to(device)
563
-
564
- if hasattr(self, 'packed_vit_tokens'):
565
- self.packed_vit_tokens = self.packed_vit_tokens.to(device)
566
- self.packed_vit_position_ids = self.packed_vit_position_ids.to(device)
567
- self.packed_vit_token_indexes = self.packed_vit_token_indexes.to(device)
568
- self.vit_token_seqlens = self.vit_token_seqlens.to(device)
569
-
570
- if hasattr(self, 'packed_label_ids'):
571
- self.packed_label_ids = self.packed_label_ids.to(device)
572
- self.ce_loss_indexes = self.ce_loss_indexes.to(device)
573
- self.ce_loss_weights = self.ce_loss_weights.to(device)
574
-
575
- return self
576
-
577
- def to_dict(self):
578
- data = dict(
579
- sequence_length = self.sequence_length,
580
- sample_lens = self.sample_lens,
581
- packed_text_ids = self.packed_text_ids,
582
- packed_text_indexes = self.packed_text_indexes,
583
- packed_position_ids = self.packed_position_ids,
584
- batch_data_indexes = self.batch_data_indexes,
585
- )
586
-
587
- if not self.use_flex:
588
- data['nested_attention_masks'] = self.nested_attention_masks
589
- else:
590
- data['split_lens'] = self.split_lens
591
- data['attn_modes'] = self.attn_modes
592
-
593
- if hasattr(self, 'padded_images'):
594
- data['padded_images'] = self.padded_images
595
- data['patchified_vae_latent_shapes'] = self.patchified_vae_latent_shapes
596
- data['packed_latent_position_ids'] = self.packed_latent_position_ids
597
- data['packed_vae_token_indexes'] = self.packed_vae_token_indexes
598
-
599
- if hasattr(self, 'packed_vit_tokens'):
600
- data['packed_vit_tokens'] = self.packed_vit_tokens
601
- data['packed_vit_position_ids'] = self.packed_vit_position_ids
602
- data['packed_vit_token_indexes'] = self.packed_vit_token_indexes
603
- data['vit_token_seqlens'] = self.vit_token_seqlens
604
-
605
- if hasattr(self, 'packed_timesteps'):
606
- data['packed_timesteps'] = self.packed_timesteps
607
- data['mse_loss_indexes'] = self.mse_loss_indexes
608
-
609
- if hasattr(self, 'packed_label_ids'):
610
- data['packed_label_ids'] = self.packed_label_ids
611
- data['ce_loss_indexes'] = self.ce_loss_indexes
612
- data['ce_loss_weights'] = self.ce_loss_weights
613
-
614
- return data
615
-
616
-
617
- def collate_wrapper():
618
- def collate_fn(batch):
619
- return SimpleCustomBatch(batch)
620
- return collate_fn
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
data/dataset_info.py DELETED
@@ -1,39 +0,0 @@
1
- # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- from .interleave_datasets import UnifiedEditIterableDataset
5
- from .t2i_dataset import T2IIterableDataset
6
- from .vlm_dataset import SftJSONLIterableDataset
7
-
8
-
9
- DATASET_REGISTRY = {
10
- 't2i_pretrain': T2IIterableDataset,
11
- 'vlm_sft': SftJSONLIterableDataset,
12
- 'unified_edit': UnifiedEditIterableDataset,
13
- }
14
-
15
-
16
- DATASET_INFO = {
17
- 't2i_pretrain': {
18
- 't2i': {
19
- 'data_dir': 'your_data_path/bagel_example/t2i', # path of the parquet files
20
- 'num_files': 10, # number of data units to be sharded across all ranks and workers
21
- 'num_total_samples': 1000, # number of total samples in the dataset
22
- },
23
- },
24
- 'unified_edit':{
25
- 'seedxedit_multi': {
26
- 'data_dir': 'your_data_path/bagel_example/editing/seedxedit_multi',
27
- 'num_files': 10,
28
- 'num_total_samples': 1000,
29
- "parquet_info_path": 'your_data_path/bagel_example/editing/parquet_info/seedxedit_multi_nas.json', # information of the parquet files
30
- },
31
- },
32
- 'vlm_sft': {
33
- 'llava_ov': {
34
- 'data_dir': 'your_data_path/bagel_example/vlm/images',
35
- 'jsonl_path': 'your_data_path/bagel_example/vlm/llava_ov_si.jsonl',
36
- 'num_total_samples': 1000
37
- },
38
- },
39
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
data/distributed_iterable_dataset.py DELETED
@@ -1,58 +0,0 @@
1
- # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- import random
5
- import torch
6
-
7
-
8
- class DistributedIterableDataset(torch.utils.data.IterableDataset):
9
- def __init__(self, dataset_name, local_rank=0, world_size=1, num_workers=8):
10
- self.dataset_name = dataset_name
11
- self.local_rank = local_rank
12
- self.world_size = world_size
13
- self.num_workers = num_workers
14
- self.rng = random.Random()
15
- self.data_paths = None
16
-
17
- def get_data_paths(self, *args, **kwargs):
18
- raise NotImplementedError
19
-
20
- def set_epoch(self, seed=42):
21
- if self.data_paths is None:
22
- return
23
-
24
- if isinstance(self.data_paths[0], tuple):
25
- data_paths = sorted(self.data_paths, key=lambda x: (x[0], x[1]))
26
- elif isinstance(self.data_paths[0], str):
27
- data_paths = sorted(self.data_paths)
28
- else:
29
- raise ValueError(f"Unknown data_paths type: {type(self.data_paths[0])}")
30
-
31
- self.rng.seed(seed)
32
- self.rng.shuffle(data_paths)
33
-
34
- num_files_per_rank = len(data_paths) // self.world_size
35
- local_start = self.local_rank * num_files_per_rank
36
- local_end = (self.local_rank + 1) * num_files_per_rank
37
- self.num_files_per_rank = num_files_per_rank
38
- self.data_paths_per_rank = data_paths[local_start:local_end]
39
-
40
- def get_data_paths_per_worker(self):
41
- if self.data_paths is None:
42
- return None
43
-
44
- info = torch.utils.data.get_worker_info()
45
- if info is None:
46
- # Single worker: Use all files assigned to the rank
47
- return self.data_paths_per_rank, 0
48
-
49
- worker_id = info.id
50
- num_files_per_worker = self.num_files_per_rank // info.num_workers
51
- start = num_files_per_worker * worker_id
52
- end = num_files_per_worker * (worker_id + 1)
53
- data_paths_per_worker = self.data_paths_per_rank[start:end]
54
-
55
- return data_paths_per_worker[::-1], worker_id
56
-
57
- def __iter__(self):
58
- raise NotImplementedError
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
data/interleave_datasets/__init__.py DELETED
@@ -1,5 +0,0 @@
1
- # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- from .edit_dataset import UnifiedEditIterableDataset
5
-
 
 
 
 
 
 
data/interleave_datasets/edit_dataset.py DELETED
@@ -1,72 +0,0 @@
1
- # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- import io
5
- import random
6
- from PIL import Image, ImageFile, PngImagePlugin
7
-
8
- from .interleave_t2i_dataset import InterleavedBaseIterableDataset, ParquetStandardIterableDataset
9
- from ..data_utils import pil_img2rgb
10
-
11
-
12
- Image.MAX_IMAGE_PIXELS = 200000000
13
- ImageFile.LOAD_TRUNCATED_IMAGES = True
14
- MaximumDecompressedSize = 1024
15
- MegaByte = 2 ** 20
16
- PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
17
-
18
-
19
- class UnifiedEditIterableDataset(InterleavedBaseIterableDataset, ParquetStandardIterableDataset):
20
-
21
- def parse_row(self, row):
22
- image_num = len(row["image_list"])
23
- # randomly choose start and end, return [0, 1] when only two images
24
- start_idx = random.choice(range(image_num - 1))
25
- max_end = min(start_idx + 3, image_num)
26
- end_idx = random.choice(range(start_idx + 1, max_end))
27
-
28
- data = self._init_data()
29
- data = self._add_image(
30
- data,
31
- pil_img2rgb(Image.open(io.BytesIO(row["image_list"][start_idx]))),
32
- need_loss=False,
33
- need_vae=True,
34
- need_vit=True,
35
- )
36
-
37
- if end_idx - start_idx > 1 and random.random() < 0.5: # concat multiple insturction
38
- if end_idx == image_num - 1:
39
- end_idx -= 1
40
-
41
- instruction = ""
42
- for idx in range(start_idx + 1, end_idx + 1):
43
- instruction += random.choice(row["instruction_list"][idx-1]) + ". "
44
- data = self._add_text(data, instruction.rstrip(), need_loss=False)
45
- data = self._add_image(
46
- data,
47
- pil_img2rgb(Image.open(io.BytesIO(row["image_list"][end_idx]))),
48
- need_loss=True,
49
- need_vae=False,
50
- need_vit=False,
51
- )
52
- else:
53
- for idx in range(start_idx + 1, end_idx + 1):
54
- instruction = random.choice(row["instruction_list"][idx-1])
55
- data = self._add_text(data, instruction, need_loss=False)
56
- if idx != end_idx:
57
- data = self._add_image(
58
- data,
59
- pil_img2rgb(Image.open(io.BytesIO(row["image_list"][idx]))),
60
- need_loss=True,
61
- need_vae=True,
62
- need_vit=True,
63
- )
64
- else:
65
- data = self._add_image(
66
- data,
67
- pil_img2rgb(Image.open(io.BytesIO(row["image_list"][idx]))),
68
- need_loss=True,
69
- need_vae=False,
70
- need_vit=False,
71
- )
72
- return data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
data/interleave_datasets/interleave_t2i_dataset.py DELETED
@@ -1,212 +0,0 @@
1
- # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- import pyarrow.parquet as pq
5
-
6
- from ..distributed_iterable_dataset import DistributedIterableDataset
7
- from ..parquet_utils import get_parquet_data_paths, init_arrow_pf_fs
8
-
9
-
10
- class InterleavedBaseIterableDataset(DistributedIterableDataset):
11
-
12
- def _init_data(self):
13
- data = {
14
- 'sequence_plan': [],
15
- 'text_ids_list': [],
16
- 'image_tensor_list': [],
17
- 'num_tokens': 0,
18
- }
19
- return data
20
-
21
- def _add_text(self, data, text, need_loss, enable_cfg=True):
22
- text_ids = self.tokenizer.encode(text)
23
- data['num_tokens'] += len(text_ids)
24
- data['text_ids_list'].append(text_ids)
25
- data['sequence_plan'].append(
26
- {
27
- 'type': 'text',
28
- 'enable_cfg': int(enable_cfg),
29
- 'loss': int(need_loss),
30
- 'special_token_loss': 0,
31
- 'special_token_label': None,
32
- }
33
- )
34
- return data
35
-
36
- def _add_image(self, data, image, need_loss, need_vae, need_vit, enable_cfg=True):
37
- assert need_loss or need_vae or need_vit
38
-
39
- if need_loss:
40
- data['sequence_plan'].append(
41
- {
42
- 'type': 'vae_image',
43
- 'enable_cfg': 0,
44
- 'loss': 1,
45
- 'special_token_loss': 0,
46
- 'special_token_label': None,
47
- }
48
- )
49
-
50
- image_tensor = self.transform(image)
51
- height, width = image_tensor.shape[1:]
52
- data['num_tokens'] += width * height // self.transform.stride ** 2
53
- data['image_tensor_list'].append(image_tensor)
54
-
55
- if need_vae:
56
- data['sequence_plan'].append(
57
- {
58
- 'type': 'vae_image',
59
- 'enable_cfg': int(enable_cfg),
60
- 'loss': 0,
61
- 'special_token_loss': 0,
62
- 'special_token_label': None,
63
- }
64
- )
65
-
66
- image_tensor = self.transform(image)
67
- height, width = image_tensor.shape[1:]
68
- data['num_tokens'] += width * height // self.transform.stride ** 2
69
- data['image_tensor_list'].append(image_tensor.clone())
70
-
71
- if need_vit:
72
- data['sequence_plan'].append(
73
- {
74
- 'type': 'vit_image',
75
- 'enable_cfg': int(enable_cfg),
76
- 'loss': 0,
77
- 'special_token_loss': 0,
78
- 'special_token_label': None,
79
- },
80
- )
81
- vit_image_tensor = self.vit_transform(image)
82
- height, width = vit_image_tensor.shape[1:]
83
- data['num_tokens'] += width * height // self.vit_transform.stride ** 2
84
- data['image_tensor_list'].append(vit_image_tensor)
85
-
86
- return data
87
-
88
- def _add_video(self, data, frames, frame_indexes, need_loss, need_vae, enable_cfg=True):
89
- assert int(need_loss) + int(need_vae) == 1
90
-
91
- if need_loss:
92
- for idx, (image, frame_idx) in enumerate(zip(frames, frame_indexes)):
93
- current_sequence_plan = {
94
- 'type': 'vae_image',
95
- 'enable_cfg': 0,
96
- 'loss': 1,
97
- 'special_token_loss': 0,
98
- 'special_token_label': None,
99
- 'split_start': idx == 0,
100
- 'split_end': idx == len(frames) - 1,
101
- }
102
- if idx < len(frame_indexes) - 1:
103
- current_sequence_plan['frame_delta'] = frame_indexes[idx + 1] - frame_idx
104
- data['sequence_plan'].append(current_sequence_plan)
105
- image_tensor = self.transform(image)
106
- height, width = image_tensor.shape[1:]
107
- data['image_tensor_list'].append(image_tensor)
108
- data['num_tokens'] += width * height // self.transform.stride ** 2
109
-
110
- elif need_vae:
111
- for idx, (image, frame_idx) in enumerate(zip(frames, frame_indexes)):
112
- current_sequence_plan = {
113
- 'type': 'vae_image',
114
- 'enable_cfg': int(enable_cfg),
115
- 'loss': 0,
116
- 'special_token_loss': 0,
117
- 'special_token_label': None,
118
- 'split_start': idx == 0,
119
- 'split_end': idx == len(frames) - 1,
120
- }
121
- if idx < len(frame_indexes) - 1:
122
- current_sequence_plan['frame_delta'] = frame_indexes[idx + 1] - frame_idx
123
- data['sequence_plan'].append(current_sequence_plan)
124
- image_tensor = self.transform(image)
125
- height, width = image_tensor.shape[1:]
126
- data['image_tensor_list'].append(image_tensor)
127
- data['num_tokens'] += width * height // self.transform.stride ** 2
128
-
129
- return data
130
-
131
-
132
- class ParquetStandardIterableDataset(DistributedIterableDataset):
133
-
134
- def __init__(
135
- self, dataset_name, transform, tokenizer, vit_transform,
136
- data_dir_list, num_used_data, parquet_info,
137
- local_rank=0, world_size=1, num_workers=8, data_status=None,
138
- ):
139
- """
140
- data_dir_list: list of data directories contains parquet files
141
- num_used_data: list of number of sampled data paths for each data directory
142
- vit_transform: input transform for vit model.
143
- """
144
- super().__init__(dataset_name, local_rank, world_size, num_workers)
145
- self.transform = transform
146
- self.vit_transform = vit_transform
147
- self.tokenizer = tokenizer
148
- self.data_status = data_status
149
- self.data_paths = self.get_data_paths(data_dir_list, num_used_data, parquet_info)
150
- self.set_epoch()
151
-
152
- def get_data_paths(self, data_dir_list, num_used_data, parquet_info):
153
- row_groups = []
154
- for data_dir, num_data_path in zip(data_dir_list, num_used_data):
155
- data_paths = get_parquet_data_paths([data_dir], [num_data_path])
156
- for data_path in data_paths:
157
- if data_path in parquet_info.keys():
158
- num_row_groups = parquet_info[data_path]['num_row_groups']
159
- for rg_idx in range(num_row_groups):
160
- row_groups.append((data_path, rg_idx))
161
- return row_groups
162
-
163
- def parse_row(self, row):
164
- raise NotImplementedError
165
-
166
- def __iter__(self):
167
- file_paths_per_worker, worker_id = self.get_data_paths_per_worker()
168
- if self.data_status is not None:
169
- global_row_group_start_id = self.data_status[worker_id][0]
170
- row_start_id = self.data_status[worker_id][1] + 1
171
- else:
172
- global_row_group_start_id = 0
173
- row_start_id = 0
174
-
175
- print(
176
- f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
177
- f"resuming data at global_rg#{global_row_group_start_id}, row#{row_start_id}"
178
- )
179
-
180
- while True:
181
- file_paths_per_worker_ = file_paths_per_worker[global_row_group_start_id:]
182
- for global_row_group_idx, (parquet_file_path, row_group_id) in enumerate(
183
- file_paths_per_worker_, start=global_row_group_start_id
184
- ):
185
- fs = init_arrow_pf_fs(parquet_file_path)
186
- with fs.open_input_file(parquet_file_path) as f:
187
- try:
188
- fr = pq.ParquetFile(f)
189
- df = fr.read_row_group(row_group_id).to_pandas()
190
- df = df.iloc[row_start_id:]
191
- except Exception as e:
192
- print(f'Error {e} in rg#{row_group_id}, {parquet_file_path}')
193
- continue
194
-
195
- for row_idx, row in df.iterrows():
196
- try:
197
- data = self.parse_row(row)
198
- if len(data) == 0:
199
- continue
200
- data['data_indexes'] = {
201
- "data_indexes": [global_row_group_idx, row_idx],
202
- "worker_id": worker_id,
203
- "dataset_name": self.dataset_name,
204
- }
205
- except Exception as e:
206
- print(f'Error {e} in rg#{row_group_id}, {parquet_file_path}')
207
- continue
208
- yield data
209
-
210
- row_start_id = 0
211
- global_row_group_start_id = 0
212
- print(f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
data/parquet_utils.py DELETED
@@ -1,90 +0,0 @@
1
- # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
-
5
- import os
6
- import xml.etree.ElementTree as ET
7
- import subprocess
8
- import logging
9
-
10
- import pyarrow.fs as pf
11
- import torch.distributed as dist
12
-
13
- logger = logging.getLogger(__name__)
14
-
15
-
16
- def get_parquet_data_paths(data_dir_list, num_sampled_data_paths, rank=0, world_size=1):
17
- num_data_dirs = len(data_dir_list)
18
- if world_size > 1:
19
- chunk_size = (num_data_dirs + world_size - 1) // world_size
20
- start_idx = rank * chunk_size
21
- end_idx = min(start_idx + chunk_size, num_data_dirs)
22
- local_data_dir_list = data_dir_list[start_idx:end_idx]
23
- local_num_sampled_data_paths = num_sampled_data_paths[start_idx:end_idx]
24
- else:
25
- local_data_dir_list = data_dir_list
26
- local_num_sampled_data_paths = num_sampled_data_paths
27
-
28
- local_data_paths = []
29
- for data_dir, num_data_path in zip(local_data_dir_list, local_num_sampled_data_paths):
30
- if data_dir.startswith("hdfs://"):
31
- files = hdfs_ls_cmd(data_dir)
32
- data_paths_per_dir = [
33
- file for file in files if file.endswith(".parquet")
34
- ]
35
- else:
36
- files = os.listdir(data_dir)
37
- data_paths_per_dir = [
38
- os.path.join(data_dir, name)
39
- for name in files
40
- if name.endswith(".parquet")
41
- ]
42
- repeat = num_data_path // len(data_paths_per_dir)
43
- data_paths_per_dir = data_paths_per_dir * (repeat + 1)
44
- local_data_paths.extend(data_paths_per_dir[:num_data_path])
45
-
46
- if world_size > 1:
47
- gather_list = [None] * world_size
48
- dist.all_gather_object(gather_list, local_data_paths)
49
-
50
- combined_chunks = []
51
- for chunk_list in gather_list:
52
- if chunk_list is not None:
53
- combined_chunks.extend(chunk_list)
54
- else:
55
- combined_chunks = local_data_paths
56
-
57
- return combined_chunks
58
-
59
-
60
- # NOTE: cumtomize this function for your cluster
61
- def get_hdfs_host():
62
- return "hdfs://xxx"
63
-
64
-
65
- # NOTE: cumtomize this function for your cluster
66
- def get_hdfs_block_size():
67
- return 134217728
68
-
69
-
70
- # NOTE: cumtomize this function for your cluster
71
- def get_hdfs_extra_conf():
72
- return None
73
-
74
-
75
- def init_arrow_pf_fs(parquet_file_path):
76
- if parquet_file_path.startswith("hdfs://"):
77
- fs = pf.HadoopFileSystem(
78
- host=get_hdfs_host(),
79
- port=0,
80
- buffer_size=get_hdfs_block_size(),
81
- extra_conf=get_hdfs_extra_conf(),
82
- )
83
- else:
84
- fs = pf.LocalFileSystem()
85
- return fs
86
-
87
-
88
- def hdfs_ls_cmd(dir):
89
- result = subprocess.run(["hdfs", "dfs", "ls", dir], capture_output=True, text=True).stdout
90
- return ['hdfs://' + i.split('hdfs://')[-1].strip() for i in result.split('\n') if 'hdfs://' in i]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
data/t2i_dataset.py DELETED
@@ -1,128 +0,0 @@
1
- # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- import io
5
- import json
6
- import pyarrow.parquet as pq
7
- import random
8
- from PIL import Image
9
-
10
- from .data_utils import pil_img2rgb
11
- from .distributed_iterable_dataset import DistributedIterableDataset
12
- from .parquet_utils import get_parquet_data_paths, init_arrow_pf_fs
13
-
14
- Image.MAX_IMAGE_PIXELS = 20_000_000
15
-
16
-
17
- class T2IIterableDataset(DistributedIterableDataset):
18
- def __init__(
19
- self, dataset_name, transform, tokenizer, data_dir_list, num_used_data,
20
- local_rank=0, world_size=1, num_workers=8, data_status=None,
21
- ):
22
- """
23
- data_dir_list: list of data directories contains parquet files
24
- num_used_data: list of number of sampled data paths for each data directory
25
- """
26
- super().__init__(dataset_name, local_rank, world_size, num_workers)
27
- self.transform = transform
28
- self.tokenizer = tokenizer
29
- self.data_status = data_status
30
- self.data_paths = self.get_data_paths(data_dir_list, num_used_data)
31
- self.set_epoch()
32
-
33
- def get_data_paths(self, data_dir_list, num_used_data):
34
- return get_parquet_data_paths(data_dir_list, num_used_data)
35
-
36
- def __iter__(self):
37
- data_paths_per_worker, worker_id = self.get_data_paths_per_worker()
38
- if self.data_status is not None:
39
- parquet_start_id = self.data_status[worker_id][0]
40
- row_group_start_id = self.data_status[worker_id][1]
41
- row_start_id = self.data_status[worker_id][2] + 1
42
- else:
43
- parquet_start_id = 0
44
- row_group_start_id = 0
45
- row_start_id = 0
46
- transform_stride = self.transform.stride
47
-
48
- print(
49
- f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
50
- f"resuming data at parquet#{parquet_start_id}, rg#{row_group_start_id}, row#{row_start_id}"
51
- )
52
-
53
- while True:
54
- data_paths_per_worker_ = data_paths_per_worker[parquet_start_id:]
55
- for parquet_idx, parquet_file_path in enumerate(data_paths_per_worker_, start=parquet_start_id):
56
- fs = init_arrow_pf_fs(parquet_file_path)
57
- with fs.open_input_file(parquet_file_path) as f:
58
- fr = pq.ParquetFile(f)
59
- row_group_ids = list(range(fr.num_row_groups))
60
- row_group_ids_ = row_group_ids[row_group_start_id:]
61
-
62
- for row_group_id in row_group_ids_:
63
- df = fr.read_row_group(row_group_id).to_pandas()
64
- df = df.iloc[row_start_id:]
65
-
66
- for row_idx, row in df.iterrows():
67
- num_tokens = 0
68
- try:
69
- image_byte = row['image']
70
- image = pil_img2rgb(Image.open(io.BytesIO(image_byte)))
71
- except Exception as e:
72
- print(f'Error: {e} in rg#{row_group_id}, {parquet_file_path}')
73
- continue
74
- image_tensor = self.transform(image)
75
- height, width = image_tensor.shape[1:]
76
- num_tokens += width * height // transform_stride ** 2
77
-
78
- try:
79
- caption_dict = row['captions']
80
- caption_dict = json.loads(caption_dict)
81
- except Exception as e:
82
- print(f'Error: {e} in rg#{row_group_id}, {parquet_file_path}')
83
- continue
84
-
85
- caps_token = [self.tokenizer.encode(v) for _, v in caption_dict.items()]
86
- if len(caps_token) == 0:
87
- print(f'no caption in rg#{row_group_id}, {parquet_file_path}')
88
- caption_token = self.tokenizer.encode(' ')
89
- else:
90
- caption_token = random.choice(caps_token)
91
-
92
- sequence_plan, text_ids_list = [], []
93
- text_ids = caption_token
94
- num_tokens += len(caption_token)
95
- text_ids_list.append(text_ids)
96
- sequence_plan.append({
97
- 'type': 'text',
98
- 'enable_cfg': 1,
99
- 'loss': 0,
100
- 'special_token_loss': 0,
101
- 'special_token_label': None,
102
- })
103
-
104
- sequence_plan.append({
105
- 'type': 'vae_image',
106
- 'enable_cfg': 0,
107
- 'loss': 1,
108
- 'special_token_loss': 0,
109
- 'special_token_label': None,
110
- })
111
-
112
- sample = dict(
113
- image_tensor_list=[image_tensor],
114
- text_ids_list=text_ids_list,
115
- num_tokens=num_tokens,
116
- sequence_plan=sequence_plan,
117
- data_indexes={
118
- "data_indexes": [parquet_idx, row_group_id, row_idx],
119
- "worker_id": worker_id,
120
- "dataset_name": self.dataset_name,
121
- }
122
- )
123
- yield sample
124
-
125
- row_start_id = 0
126
- row_group_start_id = 0
127
- parquet_start_id = 0
128
- print(f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
data/transforms.py DELETED
@@ -1,287 +0,0 @@
1
- # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- import random
5
- from PIL import Image
6
-
7
- import cv2
8
- import numpy as np
9
- import torch
10
- from torchvision import transforms
11
- from torchvision.transforms import functional as F
12
- from torchvision.transforms import InterpolationMode
13
-
14
-
15
- class MaxLongEdgeMinShortEdgeResize(torch.nn.Module):
16
- """Resize the input image so that its longest side and shortest side are within a specified range,
17
- ensuring that both sides are divisible by a specified stride.
18
-
19
- Args:
20
- max_size (int): Maximum size for the longest edge of the image.
21
- min_size (int): Minimum size for the shortest edge of the image.
22
- stride (int): Value by which the height and width of the image must be divisible.
23
- max_pixels (int): Maximum pixels for the full image.
24
- interpolation (InterpolationMode): Desired interpolation enum defined by
25
- :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BICUBIC``.
26
- If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
27
- ``InterpolationMode.BILINEAR``, and ``InterpolationMode.BICUBIC`` are supported.
28
- The corresponding Pillow integer constants, e.g., ``PIL.Image.BILINEAR`` are also accepted.
29
- antialias (bool, optional): Whether to apply antialiasing (default is True).
30
- """
31
-
32
- def __init__(
33
- self,
34
- max_size: int,
35
- min_size: int,
36
- stride: int,
37
- max_pixels: int,
38
- interpolation=InterpolationMode.BICUBIC,
39
- antialias=True
40
- ):
41
- super().__init__()
42
- self.max_size = max_size
43
- self.min_size = min_size
44
- self.stride = stride
45
- self.max_pixels = max_pixels
46
- self.interpolation = interpolation
47
- self.antialias = antialias
48
-
49
- def _make_divisible(self, value, stride):
50
- """Ensure the value is divisible by the stride."""
51
- return max(stride, int(round(value / stride) * stride))
52
-
53
- def _apply_scale(self, width, height, scale):
54
- new_width = round(width * scale)
55
- new_height = round(height * scale)
56
- new_width = self._make_divisible(new_width, self.stride)
57
- new_height = self._make_divisible(new_height, self.stride)
58
- return new_width, new_height
59
-
60
- def forward(self, img, img_num=1):
61
- """
62
- Args:
63
- img (PIL Image): Image to be resized.
64
- img_num (int): Number of images sharing the pixel budget; used to shrink the effective max_pixels per image.
65
- Returns:
66
- PIL Image or Tensor: Rescaled image with divisible dimensions.
67
- """
68
- if isinstance(img, torch.Tensor):
69
- height, width = img.shape[-2:]
70
- else:
71
- width, height = img.size
72
-
73
- scale = min(self.max_size / max(width, height), 1.0)
74
- scale = max(scale, self.min_size / min(width, height))
75
- new_width, new_height = self._apply_scale(width, height, scale)
76
-
77
- # Ensure the number of pixels does not exceed max_pixels
78
- if new_width * new_height > self.max_pixels / img_num:
79
- scale = self.max_pixels / img_num / (new_width * new_height)
80
- new_width, new_height = self._apply_scale(new_width, new_height, scale)
81
-
82
- # Ensure longest edge does not exceed max_size
83
- if max(new_width, new_height) > self.max_size:
84
- scale = self.max_size / max(new_width, new_height)
85
- new_width, new_height = self._apply_scale(new_width, new_height, scale)
86
-
87
- return F.resize(img, (new_height, new_width), self.interpolation, antialias=self.antialias)
88
-
89
-
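A quick worked example (standalone, illustrative numbers only) of the scale-then-round-to-stride arithmetic implemented by MaxLongEdgeMinShortEdgeResize above:

def make_divisible(value, stride):
    return max(stride, int(round(value / stride) * stride))

# a 1365x768 image with max_size=1024, min_size=512, stride=16
w, h, max_size, min_size, stride = 1365, 768, 1024, 512, 16
scale = min(max_size / max(w, h), 1.0)    # ~0.750: cap the long edge at 1024
scale = max(scale, min_size / min(w, h))  # 512/768 ~0.667 < 0.750, so the long-edge cap wins
new_w = make_divisible(round(w * scale), stride)  # 1024
new_h = make_divisible(round(h * scale), stride)  # 576, both divisible by 16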
90
- class ImageTransform:
91
- def __init__(
92
- self,
93
- max_image_size,
94
- min_image_size,
95
- image_stride,
96
- max_pixels=14*14*9*1024,
97
- image_mean=[0.5, 0.5, 0.5],
98
- image_std=[0.5, 0.5, 0.5]
99
- ):
100
- self.stride = image_stride
101
-
102
- self.resize_transform = MaxLongEdgeMinShortEdgeResize(
103
- max_size=max_image_size,
104
- min_size=min_image_size,
105
- stride=image_stride,
106
- max_pixels=max_pixels,
107
- )
108
- self.to_tensor_transform = transforms.ToTensor()
109
- self.normalize_transform = transforms.Normalize(mean=image_mean, std=image_std, inplace=True)
110
-
111
- def __call__(self, img, img_num=1):
112
- img = self.resize_transform(img, img_num=img_num)
113
- img = self.to_tensor_transform(img)
114
- img = self.normalize_transform(img)
115
- return img
116
-
117
-
118
- def decolorization(image):
119
- gray_image = image.convert('L')
120
- return Image.merge('RGB', [gray_image] * 3) if image.mode == 'RGB' else gray_image
121
-
122
-
123
- def downscale(image, scale_factor):
124
- new_width = int(round(image.width * scale_factor))
125
- new_height = int(round(image.height * scale_factor))
126
- new_width = max(1, new_width)
127
- new_height = max(1, new_height)
128
- return image.resize((new_width, new_height), resample=Image.BICUBIC)
129
-
130
-
131
- def crop(image, crop_factors):
132
- target_h, target_w = crop_factors
133
- img_w, img_h = image.size
134
-
135
- if target_h > img_h or target_w > img_w:
136
- raise ValueError("Crop size exceeds image dimensions")
137
-
138
- x = random.randint(0, img_w - target_w)
139
- y = random.randint(0, img_h - target_h)
140
-
141
- return image.crop((x, y, x + target_w, y + target_h)), [[x, y], [x + target_w, y + target_h]]
142
-
143
-
144
- def motion_blur_opencv(image, kernel_size=15, angle=0):
145
- # linear kernel: a single row of ones through the kernel center
146
- kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
147
- kernel[kernel_size // 2, :] = np.ones(kernel_size, dtype=np.float32)
148
-
149
- # rotate the kernel to the requested angle
150
- center = (kernel_size / 2 - 0.5, kernel_size / 2 - 0.5)
151
- M = cv2.getRotationMatrix2D(center, angle, 1)
152
- rotated_kernel = cv2.warpAffine(kernel, M, (kernel_size, kernel_size))
153
-
154
- # normalize the kernel
155
- rotated_kernel /= rotated_kernel.sum() if rotated_kernel.sum() != 0 else 1
156
-
157
- img = np.array(image)
158
- if img.ndim == 2:
159
- blurred = cv2.filter2D(img, -1, rotated_kernel, borderType=cv2.BORDER_REFLECT)
160
- else:
161
- # for color images, filter each channel independently
162
- blurred = np.zeros_like(img)
163
- for c in range(img.shape[2]):
164
- blurred[..., c] = cv2.filter2D(img[..., c], -1, rotated_kernel, borderType=cv2.BORDER_REFLECT)
165
-
166
- return Image.fromarray(blurred.astype(np.uint8))
167
-
168
-
169
- def shuffle_patch(image, num_splits, gap_size=2):
170
- """将图像分割为块(允许尺寸不整除),随机打乱后拼接,块间保留间隙"""
171
- h_splits, w_splits = num_splits
172
- img_w, img_h = image.size
173
-
174
- base_patch_h = img_h // h_splits
175
- patch_heights = [base_patch_h] * (h_splits - 1)
176
- patch_heights.append(img_h - sum(patch_heights))
177
-
178
- base_patch_w = img_w // w_splits
179
- patch_widths = [base_patch_w] * (w_splits - 1)
180
- patch_widths.append(img_w - sum(patch_widths))
181
-
182
- patches = []
183
- current_y = 0
184
- for i in range(h_splits):
185
- current_x = 0
186
- patch_h = patch_heights[i]
187
- for j in range(w_splits):
188
- patch_w = patch_widths[j]
189
- patch = image.crop((current_x, current_y, current_x + patch_w, current_y + patch_h))
190
- patches.append(patch)
191
- current_x += patch_w
192
- current_y += patch_h
193
-
194
- random.shuffle(patches)
195
-
196
- total_width = sum(patch_widths) + (w_splits - 1) * gap_size
197
- total_height = sum(patch_heights) + (h_splits - 1) * gap_size
198
- new_image = Image.new(image.mode, (total_width, total_height), color=(255, 255, 255))
199
-
200
- current_y = 0 # starting Y coordinate of the current row
201
- patch_idx = 0 # index of the patch currently being placed
202
- for i in range(h_splits):
203
- current_x = 0 # starting X coordinate of the current column
204
- patch_h = patch_heights[i] # height of patches in this row
205
- for j in range(w_splits):
206
- # take the next shuffled patch
207
- patch = patches[patch_idx]
208
- patch_w = patch_widths[j] # width of patches in this column
209
- # paste the patch with its top-left corner at (current_x, current_y)
210
- new_image.paste(patch, (current_x, current_y))
211
- # advance X: the next patch starts after this patch's width plus the gap
212
- current_x += patch_w + gap_size
213
- patch_idx += 1
214
- # advance Y: the next row starts after this row's height plus the gap
215
- current_y += patch_h + gap_size
216
-
217
- return new_image
218
-
219
-
220
- def inpainting(image, num_splits, blank_ratio=0.3, blank_color=(255, 255, 255)):
221
- """
222
- Split the image into patches and blank out a random subset, for inpainting-style tasks.
223
-
224
- Args:
225
- image: PIL.Image, input image (RGB mode)
226
- h_splits: int, number of row splits (vertical direction), first element of num_splits
227
- w_splits: int, number of column splits (horizontal direction), second element of num_splits
228
- blank_ratio: float, fraction of patches to blank out (0~1)
229
- blank_color: tuple, RGB color used for blanked patches (e.g., white (255, 255, 255))
230
-
231
- Returns:
232
- PIL.Image, the reassembled image after processing
233
- """
234
- h_splits, w_splits = num_splits
235
- img_w, img_h = image.size
236
-
237
- base_patch_h = img_h // h_splits
238
- patch_heights = [base_patch_h] * (h_splits - 1)
239
- patch_heights.append(img_h - sum(patch_heights))
240
-
241
- base_patch_w = img_w // w_splits
242
- patch_widths = [base_patch_w] * (w_splits - 1)
243
- patch_widths.append(img_w - sum(patch_widths))
244
-
245
- patches = []
246
- current_y = 0
247
- for i in range(h_splits):
248
- current_x = 0
249
- patch_h = patch_heights[i]
250
- for j in range(w_splits):
251
- patch_w = patch_widths[j]
252
- patch = image.crop((current_x, current_y, current_x + patch_w, current_y + patch_h))
253
- patches.append(patch)
254
- current_x += patch_w
255
- current_y += patch_h
256
-
257
- total_patches = h_splits * w_splits
258
- num_blank = int(total_patches * blank_ratio)
259
- num_blank = max(0, min(num_blank, total_patches))
260
- blank_indices = random.sample(range(total_patches), num_blank)
261
-
262
- processed_patches = []
263
- for idx, patch in enumerate(patches):
264
- if idx in blank_indices:
265
- blank_patch = Image.new("RGB", patch.size, color=blank_color)
266
- processed_patches.append(blank_patch)
267
- else:
268
- processed_patches.append(patch)
269
-
270
- # create the output image (same size as the input)
271
- result_image = Image.new("RGB", (img_w, img_h))
272
- current_y = 0
273
- patch_idx = 0
274
- for i in range(h_splits):
275
- current_x = 0
276
- patch_h = patch_heights[i]
277
- for j in range(w_splits):
278
- # take the processed patch
279
- patch = processed_patches[patch_idx]
280
- patch_w = patch_widths[j]
281
- # paste it back at its original position
282
- result_image.paste(patch, (current_x, current_y))
283
- current_x += patch_w
284
- patch_idx += 1
285
- current_y += patch_h
286
-
287
- return result_image
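A hedged usage sketch of the transforms above, assuming a hypothetical local file example.jpg; the corruption helpers (decolorization, downscale, motion_blur_opencv, shuffle_patch, inpainting) take and return PIL images, so they can be chained before ImageTransform:

from PIL import Image

# sizes mirror the t2i_pretrain config; the image path is hypothetical
transform = ImageTransform(max_image_size=1024, min_image_size=512, image_stride=16)

img = Image.open('example.jpg').convert('RGB')
img = downscale(img, scale_factor=0.5)                   # PIL -> PIL
img = motion_blur_opencv(img, kernel_size=15, angle=30)  # PIL -> PIL
tensor = transform(img)                                  # normalized CHW tensor, sides divisible by 16
print(tensor.shape)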
data/video_utils.py DELETED
@@ -1,165 +0,0 @@
1
- # Copyright (c) 2023 OpenGVLab
2
- # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
3
- # SPDX-License-Identifier: MIT
4
- #
5
- # This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
6
- #
7
- # Original file was released under MIT, with the full license text
8
- # available at https://github.com/OpenGVLab/InternVL/blob/main/LICENSE.
9
- #
10
- # This modified file is released under the same license.
11
-
12
-
13
- import io
14
- import os
15
- import random
16
- import re
17
-
18
- import numpy as np
19
- import decord
20
- from PIL import Image
21
-
22
-
23
- def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
24
- if sample in ['rand', 'middle']: # uniform sampling
25
- acc_samples = min(num_frames, vlen)
26
- # split the video into `acc_samples` intervals, and sample from each interval.
27
- intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
28
- ranges = []
29
- for idx, interv in enumerate(intervals[:-1]):
30
- ranges.append((interv, intervals[idx + 1] - 1))
31
- if sample == 'rand':
32
- try:
33
- frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
34
- except:
35
- frame_indices = np.random.permutation(vlen)[:acc_samples]
36
- frame_indices.sort()
37
- frame_indices = list(frame_indices)
38
- elif fix_start is not None:
39
- frame_indices = [x[0] + fix_start for x in ranges]
40
- elif sample == 'middle':
41
- frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
42
- else:
43
- raise NotImplementedError
44
-
45
- if len(frame_indices) < num_frames: # padded with last frame
46
- padded_frame_indices = [frame_indices[-1]] * num_frames
47
- padded_frame_indices[:len(frame_indices)] = frame_indices
48
- frame_indices = padded_frame_indices
49
- elif 'fps' in sample: # fps0.5, sequentially sample frames at 0.5 fps
50
- output_fps = float(sample[3:])
51
- duration = float(vlen) / input_fps
52
- delta = 1 / output_fps # gap between frames, this is also the clip length each frame represents
53
- frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
54
- frame_indices = np.around(frame_seconds * input_fps).astype(int)
55
- frame_indices = [e for e in frame_indices if e < vlen]
56
- if max_num_frames > 0 and len(frame_indices) > max_num_frames:
57
- frame_indices = frame_indices[:max_num_frames]
58
- else:
59
- raise ValueError
60
- return frame_indices
61
-
62
-
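For intuition, a small worked example of the uniform branch above; with sample='middle' the result is deterministic:

# 8 frames requested from a 100-frame clip
# intervals: [0, 12, 25, 37, 50, 62, 75, 87, 100]
idx = get_frame_indices(num_frames=8, vlen=100, sample='middle')
# -> [5, 18, 30, 43, 55, 68, 80, 93] (midpoint of each interval)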
63
- def read_frames_decord(video_path, num_frames, sample='rand', fix_start=None, clip=None, min_num_frames=4):
64
- video_reader = decord.VideoReader(video_path, num_threads=1)
65
- vlen = len(video_reader)
66
- fps = video_reader.get_avg_fps()
67
- duration = vlen / float(fps)
68
- if clip:
69
- start, end = clip
70
- duration = end - start
71
- vlen = int(duration * fps)
72
- start_index = int(start * fps)
73
-
74
- t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
75
-
76
- frame_indices = get_frame_indices(
77
- t_num_frames, vlen, sample=sample, fix_start=fix_start,
78
- input_fps=fps
79
- )
80
- if clip:
81
- frame_indices = [f + start_index for f in frame_indices]
82
- frames = video_reader.get_batch(frame_indices).asnumpy() # (T, H, W, C), np.uint8
83
- frames = [Image.fromarray(frames[i]) for i in range(frames.shape[0])]
84
- return frames
85
-
86
-
87
- def extract_frame_number(filename):
88
- # Extract the numeric part from the filename using regular expressions
89
- match = re.search(r'_(\d+)\.jpg$', filename)
90
- return int(match.group(1)) if match else -1
91
-
92
-
93
- def sort_frames(frame_paths):
94
- # Extract filenames from each path and sort by their numeric part
95
- return sorted(frame_paths, key=lambda x: extract_frame_number(os.path.basename(x)))
96
-
97
-
98
- def read_frames_folder(video_path, num_frames, sample='rand', fix_start=None, min_num_frames=4):
99
- image_list = sort_frames(list(os.listdir(video_path)))
100
- frames = []
101
- for image in image_list:
102
- fp = os.path.join(video_path, image)
103
- frame = Image.open(fp).convert('RGB')
104
- frames.append(frame)
105
- vlen = len(frames)
106
-
107
- t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
108
-
109
- if vlen > t_num_frames:
110
- frame_indices = get_frame_indices(
111
- t_num_frames, vlen, sample=sample, fix_start=fix_start
112
- )
113
- frames = [frames[i] for i in frame_indices]
114
- return frames
115
-
116
-
117
- class FrameSampler:
118
- def __init__(self, max_num_frames=-1, min_num_frames=8, sample='rand'):
119
- self.max_num_frames = max_num_frames
120
- self.min_num_frames = min_num_frames
121
- self.sample = sample
122
-
123
- def __call__(self, file_name):
124
- fn = read_frames_folder if file_name.endswith('/') else read_frames_decord
125
- frames = fn(file_name, num_frames=self.max_num_frames, min_num_frames=self.min_num_frames, sample=self.sample)
126
- return frames
127
-
128
-
129
- def decode_video_byte(video_bytes):
130
- video_stream = io.BytesIO(video_bytes)
131
- vr = decord.VideoReader(video_stream)
132
- return vr
133
-
134
-
135
- def sample_mp4_frames(mp4_p, n_frames=None, fps=None, return_frame_indices=False, random_sample=False):
136
- if isinstance(mp4_p, str):
137
- vr = decord.VideoReader(mp4_p, num_threads=1)
138
- elif isinstance(mp4_p, decord.video_reader.VideoReader):
139
- vr = mp4_p
140
- video_fps = vr.get_avg_fps() # frame rate of the video
141
- video_duration = len(vr) / video_fps
142
- if n_frames is not None:
143
- if random_sample:
144
- frame_indices = sorted(random.sample(range(len(vr)), n_frames))
145
- else:
146
- frame_indices = np.linspace(0, len(vr)-1, n_frames, dtype=int).tolist()
147
- else:
148
- frame_indices = [int(i) for i in np.arange(0, len(vr)-1, video_fps/fps)]
149
- frames = vr.get_batch(frame_indices).asnumpy() # convert to a numpy array
150
- frames = [Image.fromarray(frame).convert("RGB") for frame in frames]
151
- if not return_frame_indices:
152
- return frames, video_duration
153
- else:
154
- return frames, video_duration, frame_indices
155
-
156
-
157
- def sample_mp4_frames_by_indices(mp4_p, frame_indices: list):
158
- if isinstance(mp4_p, str):
159
- vr = decord.VideoReader(mp4_p, num_threads=1)
160
- elif isinstance(mp4_p, decord.video_reader.VideoReader):
161
- vr = mp4_p
162
- # sample the frames in frame_indices
163
- frames = vr.get_batch(frame_indices).asnumpy() # convert to a numpy array
164
- frames = [Image.fromarray(frame).convert("RGB") for frame in frames]
165
- return frames
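A hedged usage sketch of the samplers above; the paths are hypothetical and decord must be installed. Frame counts mirror the vlm_sft config:

sampler = FrameSampler(max_num_frames=12, min_num_frames=8, sample='rand')

frames = sampler('clip.mp4')      # no trailing slash -> read_frames_decord, returns a list of PIL frames
print(len(frames), frames[0].size)

frames = sampler('frames_dir/')   # trailing slash -> read_frames_folder (frames named like xxx_0001.jpg)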
data/vlm_dataset.py DELETED
@@ -1,195 +0,0 @@
1
- # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- import json
5
- import os
6
- import traceback
7
- from PIL import Image, ImageFile, PngImagePlugin
8
-
9
- from .data_utils import pil_img2rgb
10
- from .distributed_iterable_dataset import DistributedIterableDataset
11
-
12
-
13
- Image.MAX_IMAGE_PIXELS = 200000000
14
- ImageFile.LOAD_TRUNCATED_IMAGES = True
15
- MaximumDecompressedSize = 1024
16
- MegaByte = 2 ** 20
17
- PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
18
-
19
-
20
- class SftJSONLIterableDataset(DistributedIterableDataset):
21
- def __init__(
22
- self, dataset_name, transform, tokenizer, frame_sampler,
23
- jsonl_path_list, data_dir_list, num_used_data,
24
- local_rank=0, world_size=1, num_workers=8, data_status=None,
25
- shuffle_lines=False, shuffle_seed=0,
26
- ):
27
- """
28
- jsonl_path_list: list of jsonl file paths
29
- data_dir_list: list of image directories containing the images of each jsonl file
30
- num_used_data: list of number of sampled data points for each jsonl
31
- """
32
- super().__init__(dataset_name, local_rank, world_size, num_workers)
33
- self.transform = transform
34
- self.tokenizer = tokenizer
35
- self.frame_sampler = frame_sampler
36
- self.data_status = data_status
37
- self.data_paths = self.get_data_paths(
38
- jsonl_path_list,
39
- data_dir_list,
40
- num_used_data,
41
- shuffle_lines,
42
- shuffle_seed,
43
- )
44
- self.set_epoch()
45
-
46
- def get_data_paths(
47
- self,
48
- jsonl_path_list,
49
- data_dir_list,
50
- num_used_data,
51
- shuffle_lines,
52
- shuffle_seed,
53
- ):
54
- data_paths = []
55
- for jsonl_path, image_dir, num_data_point in zip(
56
- jsonl_path_list, data_dir_list, num_used_data
57
- ):
58
- with open(jsonl_path, 'r') as f:
59
- raw_data = f.readlines()
60
- if shuffle_lines:
61
- self.rng.seed(shuffle_seed)
62
- self.rng.shuffle(raw_data)
63
- raw_data = raw_data[:num_data_point]
64
- data_paths.extend([(json_data, image_dir) for json_data in raw_data])
65
- return data_paths
66
-
67
- def change_format(self, data, num_images):
68
- elements = []
69
- for conversation in data['conversations']:
70
- if conversation['from'] == 'human':
71
- if '<image>' not in conversation['value']:
72
- elements.append({
73
- 'type': 'text',
74
- 'has_loss': 0,
75
- 'text': conversation['value'],
76
- })
77
- else:
78
- text_list = conversation['value'].split('<image>')
79
- for idx, text in enumerate(text_list):
80
- if text.strip() != '':
81
- elements.append({
82
- 'type': 'text',
83
- 'has_loss': 0,
84
- 'text': text.strip(),
85
- })
86
- if (idx != len(text_list) - 1) and (idx < num_images):
87
- elements.append({'type': 'image',})
88
- elif conversation['from'] == 'gpt':
89
- elements.append({
90
- 'type': 'text',
91
- 'has_loss': 1,
92
- 'text': conversation['value'],
93
- })
94
- return elements
95
-
96
- def __iter__(self):
97
- data_paths_per_worker, worker_id = self.get_data_paths_per_worker()
98
- if self.data_status is not None:
99
- row_start_id = self.data_status[worker_id] + 1
100
- else:
101
- row_start_id = 0
102
- transform_stride = self.transform.stride
103
-
104
- print(
105
- f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
106
- f"resuming data at row#{row_start_id}"
107
- )
108
-
109
- while True:
110
- data_paths_per_worker_ = data_paths_per_worker[row_start_id:]
111
- for row_idx, (data, image_dir) in enumerate(data_paths_per_worker_, start=row_start_id):
112
- num_tokens = 0
113
- image_tensor_list = []
114
- text_ids_list = []
115
- sequence_plan = []
116
-
117
- try:
118
- data_item = json.loads(data)
119
- raw_images = None
120
- if 'image' in data_item:
121
- if type(data_item['image']) == list:
122
- raw_images = [
123
- pil_img2rgb(Image.open(os.path.join(image_dir, image)))
124
- for image in data_item['image']
125
- ]
126
- else:
127
- raw_images = [
128
- pil_img2rgb(Image.open(os.path.join(image_dir, data_item['image'])))
129
- ]
130
- elif 'video' in data_item:
131
- raw_images = self.frame_sampler(os.path.join(image_dir, data_item['video']))
132
- special_tokens = '<image>' * len(raw_images)
133
- for item in data_item['conversations']:
134
- if '<video>' in item['value']:
135
- item['value'] = item['value'].replace('<video>', special_tokens)
136
- break
137
- else:
138
- raise ValueError("Cannot find <video> in the conversation!")
139
- except:
140
- traceback.print_exc()
141
- continue
142
-
143
- if raw_images:
144
- for raw_image in raw_images:
145
- image_tensor = self.transform(raw_image, img_num=len(raw_images))
146
- image_tensor_list.append(image_tensor)
147
- height, width = image_tensor.shape[1:]
148
- num_tokens += width * height // transform_stride ** 2
149
-
150
- elements = self.change_format(data_item, len(image_tensor_list))
151
-
152
- for item in elements:
153
- if item['type'] == 'text':
154
- text_data = item['text']
155
- text_ids = self.tokenizer.encode(text_data)
156
- if len(text_ids) > 0:
157
- text_ids_list.append(text_ids)
158
- num_tokens += len(text_ids)
159
- current_plan = {
160
- 'type': 'text',
161
- 'enable_cfg': 0,
162
- 'loss': item['has_loss'],
163
- 'special_token_loss': 0,
164
- 'special_token_label': None,
165
- }
166
- sequence_plan.append(current_plan)
167
- elif item['type'] == 'image':
168
- current_plan = {
169
- 'type': 'vit_image',
170
- 'enable_cfg': 0,
171
- 'loss': 0,
172
- 'special_token_loss': 0,
173
- 'special_token_label': None,
174
- }
175
- sequence_plan.append(current_plan)
176
-
177
- has_loss = [item['loss'] for item in sequence_plan]
178
- if sum(has_loss) == 0:
179
- print(f'No loss defined, skipped.')
180
- continue
181
-
182
- yield dict(
183
- image_tensor_list=image_tensor_list,
184
- text_ids_list=text_ids_list,
185
- sequence_plan=sequence_plan,
186
- num_tokens=num_tokens,
187
- data_indexes={
188
- "data_indexes": row_idx,
189
- "worker_id": worker_id,
190
- "dataset_name": self.dataset_name,
191
- }
192
- )
193
-
194
- row_start_id = 0
195
- print(f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}")