KingNish committed
Commit 9481949 · verified · 1 Parent(s): 38fb4d8

Upload 14 files

data/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
data/configs/example.yaml ADDED
@@ -0,0 +1,45 @@
1
+ t2i_pretrain:
2
+ dataset_names:
3
+ - t2i
4
+ image_transform_args:
5
+ image_stride: 16
6
+ max_image_size: 1024
7
+ min_image_size: 512
8
+ is_mandatory: true
9
+ num_used_data: # The sum should be larger than NUM_GPUS x NUM_WORKERS
10
+ - 10
11
+ weight: 1
12
+
13
+ unified_edit:
14
+ dataset_names:
15
+ - seedxedit_multi
16
+ image_transform_args:
17
+ image_stride: 16
18
+ max_image_size: 1024
19
+ min_image_size: 512
20
+ vit_image_transform_args:
21
+ image_stride: 14
22
+ max_image_size: 518
23
+ min_image_size: 224
24
+ is_mandatory: false
25
+ num_used_data:
26
+ - 10
27
+ weight: 1
28
+
29
+ vlm_sft:
30
+ dataset_names:
31
+ - llava_ov
32
+ image_transform_args:
33
+ image_stride: 14
34
+ max_image_size: 980
35
+ min_image_size: 378
36
+ max_pixels: 2_007_040
37
+ frame_sampler_args:
38
+ max_num_frames: 12
39
+ min_num_frames: 8
40
+ is_mandatory: true
41
+ shuffle_lines: true
42
+ shuffle_seed: 0
43
+ num_used_data:
44
+ - 1000
45
+ weight: 1
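
For reference, a minimal sketch of how a grouping config like the one above can be parsed. The inline YAML fragment and group name come from this file; the note about which fields get consumed is based on PackedDataset.build_datasets below, and everything else is illustrative rather than part of the commit.

# Illustrative sketch only; assumes PyYAML is installed.
import yaml

example = """
t2i_pretrain:
  dataset_names:
    - t2i
  image_transform_args:
    image_stride: 16
    max_image_size: 1024
    min_image_size: 512
  is_mandatory: true
  num_used_data:
    - 10
  weight: 1
"""

cfg = yaml.safe_load(example)
group = cfg["t2i_pretrain"]
# Each top-level key is a dataset group; is_mandatory and weight are popped by
# PackedDataset.build_datasets, and *_transform_args blocks become ImageTransform objects.
print(group["dataset_names"], group["is_mandatory"], group["weight"])
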
data/data_utils.py ADDED
@@ -0,0 +1,177 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+
5
+ import math
6
+ import random
7
+ from PIL import Image
8
+
9
+ import torch
10
+ from torch.nn.attention.flex_attention import or_masks, and_masks
11
+
12
+
13
+ def create_sparse_mask(document_lens, split_lens, attn_modes, device):
14
+ def causal_mask(b, h, q_idx, kv_idx):
15
+ return q_idx >= kv_idx
16
+
17
+ def full_and_noise_mask(b, h, q_idx, kv_idx):
18
+ return (full_and_noise_seq_id[q_idx] == full_and_noise_seq_id[kv_idx]) & (full_and_noise_seq_id[q_idx] >= 0)
19
+
20
+ def remove_noise_mask(b, h, q_idx, kv_idx):
21
+ return (~((noise_seq_id[kv_idx] >= 0) & (noise_seq_id[q_idx] != noise_seq_id[kv_idx])))
22
+
23
+ def sample_mask(b, h, q_idx, kv_idx):
24
+ return document_id[q_idx] == document_id[kv_idx]
25
+
26
+ full_and_noise_tmp = []
27
+ noise_tmp = []
28
+
29
+ for i, (length, model) in enumerate(zip(split_lens, attn_modes)):
30
+ value = i if model in ['full', 'noise'] else -1
31
+ full_and_noise_tmp.extend([value] * length)
32
+ value_noise = i if model == 'noise' else -1
33
+ noise_tmp.extend([value_noise] * length)
34
+
35
+ full_and_noise_seq_id = torch.Tensor(full_and_noise_tmp).to(device)
36
+ noise_seq_id = torch.Tensor(noise_tmp).to(device)
37
+
38
+ document_id = torch.cat([torch.full((l,), i) for i, l in enumerate(document_lens, start=1)]).to(device)
39
+
40
+ return and_masks(or_masks(causal_mask, full_and_noise_mask), remove_noise_mask, sample_mask)
41
+
42
+
43
+ def patchify(image, patch_size):
44
+ p = patch_size
45
+ c, h, w = image.shape
46
+ assert h % p == 0 and w % p == 0
47
+ image = image.reshape(c, h // p, p, w // p, p)
48
+ image = torch.einsum("chpwq->hwpqc", image)
49
+ image = image.reshape(-1, p**2 * c)
50
+ return image
51
+
52
+
53
+ def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side):
54
+ num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size
55
+ coords_h = torch.arange(0, num_patches_h)
56
+ coords_w = torch.arange(0, num_patches_w)
57
+ pos_ids = (coords_h[:, None] * max_num_patches_per_side + coords_w).flatten()
58
+ return pos_ids
59
+
60
+
61
+ def get_flattened_position_ids_interpolate(img_h, img_w, patch_size, max_num_patches_per_side):
62
+ num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size
63
+ boundaries = torch.arange(1 / max_num_patches_per_side, 1.0, 1 / max_num_patches_per_side)
64
+ fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / num_patches_h)
65
+ fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / num_patches_w)
66
+ bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
67
+ bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
68
+ pos_ids = (bucket_coords_h[:, None] * max_num_patches_per_side + bucket_coords_w).flatten()
69
+ return pos_ids
70
+
71
+
72
+ def prepare_attention_mask_per_sample(split_lens, attn_modes, device="cpu"):
73
+ """
74
+ split_lens: A list of ints. Each int indicates the length of a split within
75
+ a sample, where each sample contains multiple splits with different attn modes.
76
+ attn_modes: the attention mode ('causal', 'full', or 'noise') used in each split.
77
+ """
78
+ sample_len = sum(split_lens)
79
+ attention_mask = torch.zeros((sample_len, sample_len), dtype=torch.bool, device=device)
80
+
81
+ csum = 0
82
+ for s, attn_mode in zip(split_lens, attn_modes):
83
+ assert attn_mode in ['causal', 'full', 'noise']
84
+ if attn_mode == "causal":
85
+ attention_mask[csum:csum + s, csum:csum + s] = torch.ones((s, s), device=device).tril()
86
+ attention_mask[csum:csum + s, :csum] = 1
87
+ else:
88
+ attention_mask[csum:csum + s, csum:csum + s] = torch.ones((s, s))
89
+ attention_mask[csum:csum + s, :csum] = 1
90
+ csum += s
91
+
92
+ csum = 0
93
+ for s, attn_mode in zip(split_lens, attn_modes):
94
+ if attn_mode == "noise":
95
+ attention_mask[:, csum : csum + s] = torch.zeros((sample_len, s))
96
+ attention_mask[csum : csum + s, csum : csum + s] = torch.ones((s, s))
97
+ csum += s
98
+
99
+ attention_mask = torch.zeros_like(attention_mask, dtype=torch.float).masked_fill_(
100
+ ~attention_mask, float("-inf")
101
+ )
102
+
103
+ return attention_mask
104
+
105
+
106
+ def split_integer_exp_decay(S, ng_sample_decay=1.0):
107
+ if ng_sample_decay == 1.0:
108
+ N = random.randint(1, S)
109
+ else:
110
+ base = (1 - ng_sample_decay) / (1 - math.pow(ng_sample_decay, S))
111
+ p = [base * math.pow(ng_sample_decay, i) for i in range(S)]
112
+ N = random.choices(list(range(1, S + 1)), p, k=1)[0]
113
+ cumsum = [0] + sorted(random.sample(range(1, S), N - 1)) + [S]
114
+ result = [cumsum[i+1] - cumsum[i] for i in range(len(cumsum) - 1)]
115
+ return result, cumsum
116
+
117
+
118
+ def pil_img2rgb(image):
119
+ if image.mode == "RGBA" or image.info.get("transparency", None) is not None:
120
+ image = image.convert("RGBA")
121
+ white = Image.new(mode="RGB", size=image.size, color=(255, 255, 255))
122
+ white.paste(image, mask=image.split()[3])
123
+ image = white
124
+ else:
125
+ image = image.convert("RGB")
126
+
127
+ return image
128
+
129
+
130
+ def add_special_tokens(tokenizer):
131
+ all_special_tokens = []
132
+ for k, v in tokenizer.special_tokens_map.items():
133
+ if isinstance(v, str):
134
+ all_special_tokens.append(v)
135
+ elif isinstance(v, list):
136
+ all_special_tokens += v
137
+
138
+ new_tokens = []
139
+
140
+ if '<|im_start|>' not in all_special_tokens:
141
+ new_tokens.append('<|im_start|>')
142
+
143
+ if '<|im_end|>' not in all_special_tokens:
144
+ new_tokens.append('<|im_end|>')
145
+
146
+ if '<|vision_start|>' not in all_special_tokens:
147
+ new_tokens.append('<|vision_start|>')
148
+
149
+ if '<|vision_end|>' not in all_special_tokens:
150
+ new_tokens.append('<|vision_end|>')
151
+
152
+ num_new_tokens = tokenizer.add_tokens(new_tokens)
153
+ bos_token_id = tokenizer.convert_tokens_to_ids('<|im_start|>')
154
+ eos_token_id = tokenizer.convert_tokens_to_ids('<|im_end|>')
155
+ start_of_image = tokenizer.convert_tokens_to_ids('<|vision_start|>')
156
+ end_of_image = tokenizer.convert_tokens_to_ids('<|vision_end|>')
157
+
158
+ new_token_ids = dict(
159
+ bos_token_id=bos_token_id,
160
+ eos_token_id=eos_token_id,
161
+ start_of_image=start_of_image,
162
+ end_of_image=end_of_image,
163
+ )
164
+
165
+ return tokenizer, new_token_ids, num_new_tokens
166
+
167
+
168
+ def len2weight(x, loss_reduction='square'):
169
+ if x == 0:
170
+ return x
171
+ if loss_reduction == 'token':
172
+ return 1
173
+ if loss_reduction == 'sample':
174
+ return 1 / x
175
+ if loss_reduction == 'square':
176
+ return 1 / (x ** 0.5)
177
+ raise NotImplementedError(loss_reduction)
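
A small usage sketch of the helpers above on dummy tensors. This is an illustration, not part of the commit: it assumes this file is importable as data.data_utils and that the installed PyTorch provides torch.nn.attention.flex_attention.

# Sketch: exercise the patchify / position-id / attention-mask helpers on dummy data.
import torch
from data.data_utils import (
    patchify,
    get_flattened_position_ids_extrapolate,
    prepare_attention_mask_per_sample,
)

img = torch.randn(3, 28, 28)                 # C x H x W, sides divisible by patch_size
patches = patchify(img, patch_size=14)       # -> (4, 14 * 14 * 3)
pos = get_flattened_position_ids_extrapolate(28, 28, 14, max_num_patches_per_side=70)
mask = prepare_attention_mask_per_sample([3, 4], ['causal', 'full'])
print(patches.shape, pos.shape, mask.shape)  # (4, 588), (4,), (7, 7)
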
data/dataset_base.py ADDED
@@ -0,0 +1,620 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+
5
+ import random
6
+ import json
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+ from .data_utils import (
12
+ get_flattened_position_ids_interpolate,
13
+ get_flattened_position_ids_extrapolate,
14
+ len2weight,
15
+ patchify,
16
+ prepare_attention_mask_per_sample,
17
+ )
18
+ from .dataset_info import DATASET_INFO, DATASET_REGISTRY
19
+ from .transforms import ImageTransform
20
+ from .video_utils import FrameSampler
21
+
22
+
23
+ class DataConfig:
24
+ def __init__(
25
+ self,
26
+ grouped_datasets,
27
+ text_cond_dropout_prob=0.1,
28
+ vit_cond_dropout_prob=0.4,
29
+ vae_cond_dropout_prob=0.1,
30
+ vae_image_downsample=16,
31
+ max_latent_size=32,
32
+ vit_patch_size=14,
33
+ max_num_patch_per_side=70,
34
+ ):
35
+ self.grouped_datasets = grouped_datasets
36
+ self.text_cond_dropout_prob = text_cond_dropout_prob
37
+ self.vit_cond_dropout_prob = vit_cond_dropout_prob
38
+ self.vit_patch_size = vit_patch_size
39
+ self.max_num_patch_per_side = max_num_patch_per_side
40
+ self.vae_cond_dropout_prob = vae_cond_dropout_prob
41
+ self.vae_image_downsample = vae_image_downsample
42
+ self.max_latent_size = max_latent_size
43
+
44
+
45
+ class PackedDataset(torch.utils.data.IterableDataset):
46
+ def __init__(
47
+ self,
48
+ data_config,
49
+ tokenizer,
50
+ special_tokens,
51
+ local_rank,
52
+ world_size,
53
+ num_workers,
54
+ expected_num_tokens=32768,
55
+ max_num_tokens_per_sample=16384,
56
+ max_num_tokens=36864,
57
+ prefer_buffer_before=16384,
58
+ max_buffer_size=50,
59
+ interpolate_pos=False,
60
+ use_flex=False,
61
+ data_status=None,
62
+ ):
63
+ super().__init__()
64
+ self.expected_num_tokens = expected_num_tokens
65
+ self.max_num_tokens_per_sample = max_num_tokens_per_sample
66
+ self.prefer_buffer_before = prefer_buffer_before
67
+ self.max_num_tokens = max_num_tokens
68
+ self.max_buffer_size = max_buffer_size
69
+ self.tokenizer = tokenizer
70
+ self.local_rank = local_rank
71
+ self.world_size = world_size
72
+ self.num_workers = num_workers
73
+ self.use_flex = use_flex
74
+ for k, v in special_tokens.items():
75
+ setattr(self, k, v)
76
+
77
+ grouped_datasets, is_mandatory, grouped_weights = self.build_datasets(
78
+ data_config.grouped_datasets, data_status
79
+ )
80
+ self.grouped_datasets = grouped_datasets
81
+ self.dataset_iters = [iter(dataset) for dataset in grouped_datasets]
82
+ self.is_mandatory = is_mandatory
83
+ self.grouped_weights = grouped_weights
84
+ self.data_config = data_config
85
+ self.interpolate_pos = interpolate_pos
86
+ if self.interpolate_pos:
87
+ self.get_flattened_position_ids = get_flattened_position_ids_interpolate
88
+ else:
89
+ self.get_flattened_position_ids = get_flattened_position_ids_extrapolate
90
+
91
+ def build_datasets(self, datasets_metainfo, data_status):
92
+ datasets = []
93
+ is_mandatory = []
94
+ grouped_weights = []
95
+ for grouped_dataset_name, dataset_args in datasets_metainfo.items():
96
+ is_mandatory.append(dataset_args.pop('is_mandatory', False))
97
+ grouped_weights.append(dataset_args.pop('weight', 0.0))
98
+
99
+ if 'frame_sampler_args' in dataset_args.keys():
100
+ frame_sampler = FrameSampler(**dataset_args.pop('frame_sampler_args'))
101
+ dataset_args['frame_sampler'] = frame_sampler
102
+ if 'image_transform_args' in dataset_args.keys():
103
+ transform = ImageTransform(**dataset_args.pop('image_transform_args'))
104
+ dataset_args['transform'] = transform
105
+ if 'vit_image_transform_args' in dataset_args.keys():
106
+ vit_transform = ImageTransform(**dataset_args.pop('vit_image_transform_args'))
107
+ dataset_args['vit_transform'] = vit_transform
108
+
109
+ assert 'dataset_names' in dataset_args.keys()
110
+ dataset_names = dataset_args.pop('dataset_names')
111
+ dataset_args['data_dir_list'] = []
112
+ for item in dataset_names:
113
+ if self.local_rank == 0:
114
+ print(f'Preparing Dataset {grouped_dataset_name}/{item}')
115
+ meta_info = DATASET_INFO[grouped_dataset_name][item]
116
+ dataset_args['data_dir_list'].append(meta_info['data_dir'])
117
+
118
+ if "parquet_info_path" in meta_info.keys():
119
+ if 'parquet_info' not in dataset_args.keys():
120
+ dataset_args['parquet_info'] = {}
121
+ with open(meta_info['parquet_info_path'], 'r') as f:
122
+ parquet_info = json.load(f)
123
+ dataset_args['parquet_info'].update(parquet_info)
124
+
125
+ if 'json_dir' in meta_info.keys():
126
+ # parquet/tar with json
127
+ if 'json_dir_list' not in dataset_args.keys():
128
+ dataset_args['json_dir_list'] = [meta_info['json_dir']]
129
+ else:
130
+ dataset_args['json_dir_list'].append(meta_info['json_dir'])
131
+
132
+ if 'jsonl_path' in meta_info.keys():
133
+ # jsonl with jpeg
134
+ if 'jsonl_path_list' not in dataset_args.keys():
135
+ dataset_args['jsonl_path_list'] = [meta_info['jsonl_path']]
136
+ else:
137
+ dataset_args['jsonl_path_list'].append(meta_info['jsonl_path'])
138
+
139
+ resume_data_status = dataset_args.pop('resume_data_status', True)
140
+ if data_status is not None and grouped_dataset_name in data_status.keys() and resume_data_status:
141
+ data_status_per_group = data_status[grouped_dataset_name]
142
+ else:
143
+ data_status_per_group = None
144
+ dataset = DATASET_REGISTRY[grouped_dataset_name](
145
+ dataset_name=grouped_dataset_name,
146
+ tokenizer=self.tokenizer,
147
+ local_rank=self.local_rank,
148
+ world_size=self.world_size,
149
+ num_workers=self.num_workers,
150
+ data_status=data_status_per_group,
151
+ **dataset_args
152
+ )
153
+ datasets.append(dataset)
154
+
155
+ return datasets, is_mandatory, grouped_weights
156
+
157
+ def set_epoch(self, seed):
158
+ for dataset in self.grouped_datasets:
159
+ dataset.set_epoch(seed)
160
+
161
+ def set_sequence_status(self):
162
+ sequence_status = dict(
163
+ curr = 0,
164
+ sample_lens = list(),
165
+ packed_position_ids = list(),
166
+ nested_attention_masks = list(),
167
+ split_lens = list(),
168
+ attn_modes = list(),
169
+ packed_text_ids = list(),
170
+ packed_text_indexes = list(),
171
+ packed_label_ids = list(),
172
+ ce_loss_indexes = list(),
173
+ ce_loss_weights = list(),
174
+ vae_image_tensors = list(),
175
+ packed_latent_position_ids = list(),
176
+ vae_latent_shapes = list(),
177
+ packed_vae_token_indexes = list(),
178
+ packed_timesteps = list(),
179
+ mse_loss_indexes = list(),
180
+ packed_vit_tokens = list(),
181
+ vit_token_seqlens = list(),
182
+ packed_vit_position_ids = list(),
183
+ packed_vit_token_indexes = list(),
184
+ )
185
+ return sequence_status
186
+
187
+ def to_tensor(self, sequence_status):
188
+ data = dict(
189
+ sequence_length=sum(sequence_status['sample_lens']),
190
+ sample_lens=sequence_status['sample_lens'],
191
+ packed_text_ids=torch.tensor(sequence_status['packed_text_ids']),
192
+ packed_text_indexes=torch.tensor(sequence_status['packed_text_indexes']),
193
+ packed_position_ids=torch.tensor(sequence_status['packed_position_ids']),
194
+ )
195
+ if not self.use_flex:
196
+ data['nested_attention_masks'] = sequence_status['nested_attention_masks']
197
+ else:
198
+ sequence_len = data['sequence_length']
199
+ pad_len = self.max_num_tokens - sequence_len
200
+ data['split_lens'] = sequence_status['split_lens'] + [pad_len]
201
+ data['attn_modes'] = sequence_status['attn_modes'] + ['causal']
202
+ data['sample_lens'] += [pad_len]
203
+
204
+ # if the model has a convnet vae (e.g., as visual tokenizer)
205
+ if len(sequence_status['vae_image_tensors']) > 0:
206
+ image_tensors = sequence_status.pop('vae_image_tensors')
207
+ image_sizes = [item.shape for item in image_tensors]
208
+ max_image_size = [max(item) for item in list(zip(*image_sizes))]
209
+ padded_images = torch.zeros(size=(len(image_tensors), *max_image_size))
210
+ for i, image_tensor in enumerate(image_tensors):
211
+ padded_images[i, :, :image_tensor.shape[1], :image_tensor.shape[2]] = image_tensor
212
+
213
+ data['padded_images'] = padded_images
214
+ data['patchified_vae_latent_shapes'] = sequence_status['vae_latent_shapes']
215
+ data['packed_latent_position_ids'] = torch.cat(sequence_status['packed_latent_position_ids'], dim=0)
216
+ data['packed_vae_token_indexes'] = torch.tensor(sequence_status['packed_vae_token_indexes'])
217
+
218
+ # if the model has a vit (e.g., as visual tokenizer)
219
+ if len(sequence_status['packed_vit_tokens']) > 0:
220
+ data['packed_vit_tokens'] = torch.cat(sequence_status['packed_vit_tokens'], dim=0)
221
+ data['packed_vit_position_ids'] = torch.cat(sequence_status['packed_vit_position_ids'], dim=0)
222
+ data['packed_vit_token_indexes'] = torch.tensor(sequence_status['packed_vit_token_indexes'])
223
+ data['vit_token_seqlens'] = torch.tensor(sequence_status['vit_token_seqlens'])
224
+
225
+ # if the model is required to perform visual generation
226
+ if len(sequence_status['packed_timesteps']) > 0:
227
+ data['packed_timesteps'] = torch.tensor(sequence_status['packed_timesteps'])
228
+ data['mse_loss_indexes'] = torch.tensor(sequence_status['mse_loss_indexes'])
229
+
230
+ # if the model is required to perform text generation
231
+ if len(sequence_status['packed_label_ids']) > 0:
232
+ data['packed_label_ids'] = torch.tensor(sequence_status['packed_label_ids'])
233
+ data['ce_loss_indexes'] = torch.tensor(sequence_status['ce_loss_indexes'])
234
+ data['ce_loss_weights'] = torch.tensor(sequence_status['ce_loss_weights'])
235
+
236
+ return data
237
+
238
+ def __iter__(self):
239
+ total_weights = sum(self.grouped_weights)
240
+ assert total_weights > 0.0
241
+ group_cumprobs = [sum(self.grouped_weights[:i + 1]) / total_weights
242
+ for i in range(len(self.grouped_weights))]
243
+ sequence_status = self.set_sequence_status()
244
+ batch_data_indexes = []
245
+
246
+ buffer = []
247
+ while True:
248
+ # Ensure at least one sample from each group
249
+ if sequence_status['curr'] == 0:
250
+ for group_index, group_iter in enumerate(self.dataset_iters):
251
+ if self.is_mandatory[group_index]:
252
+ while True:
253
+ sample = next(group_iter)
254
+ # if a sample is too long, skip it
255
+ num_tokens = sample['num_tokens'] + 2 * len(sample['sequence_plan'])
256
+ if num_tokens < self.max_num_tokens_per_sample:
257
+ sequence_status = self.pack_sequence(sample, sequence_status)
258
+ batch_data_indexes.append(sample['data_indexes'])
259
+ break
260
+ else:
261
+ print(f"skip a sample with length {num_tokens}")
262
+ continue
263
+
264
+ if sequence_status['curr'] < self.prefer_buffer_before and len(buffer) > 0:
265
+ sample = buffer.pop(0)
266
+ sample_from_buffer = True
267
+ else:
268
+ # sample normally across all groups
269
+ n = random.random()
270
+ group_index = 0
271
+ for i, cumprob in enumerate(group_cumprobs):
272
+ if n < cumprob:
273
+ group_index = i
274
+ break
275
+ sample = next(self.dataset_iters[group_index])
276
+ sample_from_buffer = False
277
+
278
+ # if a sample is too long, skip it
279
+ num_tokens = sample['num_tokens'] + 2 * len(sample['sequence_plan'])
280
+ if num_tokens > self.max_num_tokens_per_sample:
281
+ print(f"skip a sample with length {num_tokens}")
282
+ continue
283
+
284
+ if sequence_status['curr'] + num_tokens > self.max_num_tokens:
285
+ if len(buffer) < self.max_buffer_size and not sample_from_buffer:
286
+ buffer.append(sample)
287
+ else:
288
+ print(f"Yielding data with length {sum(sequence_status['sample_lens'])}")
289
+ data = self.to_tensor(sequence_status)
290
+ data['batch_data_indexes'] = batch_data_indexes
291
+ yield data
292
+ sequence_status = self.set_sequence_status()
293
+ batch_data_indexes = []
294
+ continue
295
+
296
+ sequence_status = self.pack_sequence(sample, sequence_status)
297
+ batch_data_indexes.append(sample['data_indexes'])
298
+
299
+ if sequence_status['curr'] >= self.expected_num_tokens:
300
+ data = self.to_tensor(sequence_status)
301
+ data['batch_data_indexes'] = batch_data_indexes
302
+ yield data
303
+ sequence_status = self.set_sequence_status()
304
+ batch_data_indexes = []
305
+
306
+ def pack_sequence(self, sample, sequence_status):
307
+ image_tensor_list = sample['image_tensor_list']
308
+ text_ids_list = sample['text_ids_list']
309
+ sequence_plan = sample['sequence_plan']
310
+
311
+ split_lens, attn_modes = list(), list()
312
+ curr = sequence_status['curr']
313
+ curr_rope_id = 0
314
+ sample_lens = 0
315
+
316
+ for item in sequence_plan:
317
+ split_start = item.get('split_start', True)
318
+ if split_start:
319
+ curr_split_len = 0
320
+
321
+ if item['type'] == 'text':
322
+ text_ids = text_ids_list.pop(0)
323
+ if item['enable_cfg'] == 1 and random.random() < self.data_config.text_cond_dropout_prob:
324
+ continue
325
+
326
+ shifted_text_ids = [self.bos_token_id] + text_ids
327
+ sequence_status['packed_text_ids'].extend(shifted_text_ids)
328
+ sequence_status['packed_text_indexes'].extend(range(curr, curr + len(shifted_text_ids)))
329
+ if item['loss'] == 1:
330
+ sequence_status['ce_loss_indexes'].extend(range(curr, curr + len(shifted_text_ids)))
331
+ sequence_status['ce_loss_weights'].extend(
332
+ [len2weight(len(shifted_text_ids))] * len(shifted_text_ids)
333
+ )
334
+ sequence_status['packed_label_ids'].extend(text_ids + [self.eos_token_id])
335
+ curr += len(shifted_text_ids)
336
+ curr_split_len += len(shifted_text_ids)
337
+
338
+ # add a <|im_end|> token
339
+ sequence_status['packed_text_ids'].append(self.eos_token_id)
340
+ sequence_status['packed_text_indexes'].append(curr)
341
+ if item['special_token_loss'] == 1: # <|im_end|> may have loss
342
+ sequence_status['ce_loss_indexes'].append(curr)
343
+ sequence_status['ce_loss_weights'].append(1.0)
344
+ sequence_status['packed_label_ids'].append(item['special_token_label'])
345
+ curr += 1
346
+ curr_split_len += 1
347
+
348
+ # update sequence status
349
+ attn_modes.append("causal")
350
+ sequence_status['packed_position_ids'].extend(range(curr_rope_id, curr_rope_id + curr_split_len))
351
+ curr_rope_id += curr_split_len
352
+
353
+ elif item['type'] == 'vit_image':
354
+ image_tensor = image_tensor_list.pop(0)
355
+ if item['enable_cfg'] == 1 and random.random() < self.data_config.vit_cond_dropout_prob:
356
+ curr_rope_id += 1
357
+ continue
358
+
359
+ # add a <|vision_start|> token
360
+ sequence_status['packed_text_ids'].append(self.start_of_image)
361
+ sequence_status['packed_text_indexes'].append(curr)
362
+ curr += 1
363
+ curr_split_len += 1
364
+
365
+ # preprocess image
366
+ vit_tokens = patchify(image_tensor, self.data_config.vit_patch_size)
367
+ num_img_tokens = vit_tokens.shape[0]
368
+ sequence_status['packed_vit_token_indexes'].extend(range(curr, curr + num_img_tokens))
369
+ curr += num_img_tokens
370
+ curr_split_len += num_img_tokens
371
+
372
+ sequence_status['packed_vit_tokens'].append(vit_tokens)
373
+ sequence_status['vit_token_seqlens'].append(num_img_tokens)
374
+ sequence_status['packed_vit_position_ids'].append(
375
+ self.get_flattened_position_ids(
376
+ image_tensor.size(1), image_tensor.size(2),
377
+ self.data_config.vit_patch_size,
378
+ max_num_patches_per_side=self.data_config.max_num_patch_per_side
379
+ )
380
+ )
381
+
382
+ # add a <|vision_end|> token
383
+ sequence_status['packed_text_ids'].append(self.end_of_image)
384
+ sequence_status['packed_text_indexes'].append(curr)
385
+ if item['special_token_loss'] == 1: # <|endofimage|> may have loss
386
+ sequence_status['ce_loss_indexes'].append(curr)
387
+ sequence_status['ce_loss_weights'].append(1.0)
388
+ sequence_status['packed_label_ids'].append(item['special_token_label'])
389
+ curr += 1
390
+ curr_split_len += 1
391
+
392
+ # update sequence status
393
+ attn_modes.append("full")
394
+ sequence_status['packed_position_ids'].extend([curr_rope_id] * curr_split_len)
395
+ curr_rope_id += 1
396
+
397
+ elif item['type'] == 'vae_image':
398
+ image_tensor = image_tensor_list.pop(0)
399
+ if item['enable_cfg'] == 1 and random.random() < self.data_config.vae_cond_dropout_prob:
400
+ # FIXME fix vae dropout in video2video setting.
401
+ curr_rope_id += 1
402
+ continue
403
+
404
+ # add a <|vision_start|> token
405
+ sequence_status['packed_text_ids'].append(self.start_of_image)
406
+ sequence_status['packed_text_indexes'].append(curr)
407
+ curr += 1
408
+ curr_split_len += 1
409
+
410
+ # preprocess image
411
+ sequence_status['vae_image_tensors'].append(image_tensor)
412
+ sequence_status['packed_latent_position_ids'].append(
413
+ self.get_flattened_position_ids(
414
+ image_tensor.size(1), image_tensor.size(2),
415
+ self.data_config.vae_image_downsample,
416
+ max_num_patches_per_side=self.data_config.max_latent_size
417
+ )
418
+ )
419
+ H, W = image_tensor.shape[1:]
420
+ h = H // self.data_config.vae_image_downsample
421
+ w = W // self.data_config.vae_image_downsample
422
+ sequence_status['vae_latent_shapes'].append((h, w))
423
+
424
+ num_img_tokens = w * h
425
+ sequence_status['packed_vae_token_indexes'].extend(range(curr, curr + num_img_tokens))
426
+ if item['loss'] == 1:
427
+ sequence_status['mse_loss_indexes'].extend(range(curr, curr + num_img_tokens))
428
+ if split_start:
429
+ timestep = np.random.randn()
430
+ else:
431
+ timestep = float('-inf')
432
+
433
+ sequence_status['packed_timesteps'].extend([timestep] * num_img_tokens)
434
+ curr += num_img_tokens
435
+ curr_split_len += num_img_tokens
436
+
437
+ # add a <|vision_end|> token
438
+ sequence_status['packed_text_ids'].append(self.end_of_image)
439
+ sequence_status['packed_text_indexes'].append(curr)
440
+ # <|endofimage|> may have loss
441
+ if item['special_token_loss'] == 1:
442
+ sequence_status['ce_loss_indexes'].append(curr)
443
+ sequence_status['ce_loss_weights'].append(1.0)
444
+ sequence_status['packed_label_ids'].append(item['special_token_label'])
445
+ curr += 1
446
+ curr_split_len += 1
447
+
448
+ # update sequence status
449
+ if split_start:
450
+ if item['loss'] == 1 and 'frame_delta' not in item.keys():
451
+ attn_modes.append("noise")
452
+ else:
453
+ attn_modes.append("full")
454
+ sequence_status['packed_position_ids'].extend([curr_rope_id] * (num_img_tokens + 2))
455
+ if 'frame_delta' in item.keys():
456
+ curr_rope_id += item['frame_delta']
457
+ elif item['loss'] == 0:
458
+ curr_rope_id += 1
459
+
460
+ if item.get('split_end', True):
461
+ split_lens.append(curr_split_len)
462
+ sample_lens += curr_split_len
463
+
464
+ sequence_status['curr'] = curr
465
+ sequence_status['sample_lens'].append(sample_lens)
466
+ # prepare attention mask
467
+ if not self.use_flex:
468
+ sequence_status['nested_attention_masks'].append(
469
+ prepare_attention_mask_per_sample(split_lens, attn_modes)
470
+ )
471
+ else:
472
+ sequence_status['split_lens'].extend(split_lens)
473
+ sequence_status['attn_modes'].extend(attn_modes)
474
+
475
+ return sequence_status
476
+
477
+
478
+ class SimpleCustomBatch:
479
+ def __init__(self, batch):
480
+ data = batch[0]
481
+ self.batch_data_indexes = data['batch_data_indexes']
482
+ self.sequence_length = data["sequence_length"]
483
+ self.sample_lens = data["sample_lens"]
484
+ self.packed_text_ids = data["packed_text_ids"]
485
+ self.packed_text_indexes = data["packed_text_indexes"]
486
+ self.packed_position_ids = data["packed_position_ids"]
487
+
488
+ self.use_flex = "nested_attention_masks" not in data.keys()
489
+
490
+ if self.use_flex:
491
+ self.split_lens = data["split_lens"]
492
+ self.attn_modes = data["attn_modes"]
493
+ else:
494
+ self.nested_attention_masks = data["nested_attention_masks"]
495
+
496
+ if "padded_images" in data.keys():
497
+ self.padded_images = data["padded_images"]
498
+ self.patchified_vae_latent_shapes = data["patchified_vae_latent_shapes"]
499
+ self.packed_latent_position_ids = data["packed_latent_position_ids"]
500
+ self.packed_vae_token_indexes = data["packed_vae_token_indexes"]
501
+
502
+ if "packed_vit_tokens" in data.keys():
503
+ self.packed_vit_tokens = data["packed_vit_tokens"]
504
+ self.packed_vit_position_ids = data["packed_vit_position_ids"]
505
+ self.packed_vit_token_indexes = data["packed_vit_token_indexes"]
506
+ self.vit_token_seqlens = data["vit_token_seqlens"]
507
+
508
+ if "packed_timesteps" in data.keys():
509
+ self.packed_timesteps = data["packed_timesteps"]
510
+ self.mse_loss_indexes = data["mse_loss_indexes"]
511
+
512
+ if "packed_label_ids" in data.keys():
513
+ self.packed_label_ids = data["packed_label_ids"]
514
+ self.ce_loss_indexes = data["ce_loss_indexes"]
515
+ self.ce_loss_weights = data["ce_loss_weights"]
516
+
517
+ def pin_memory(self):
518
+ self.packed_text_ids = self.packed_text_ids.pin_memory()
519
+ self.packed_text_indexes = self.packed_text_indexes.pin_memory()
520
+ self.packed_position_ids = self.packed_position_ids.pin_memory()
521
+
522
+ if not self.use_flex:
523
+ self.nested_attention_masks = [item.pin_memory() for item in self.nested_attention_masks]
524
+
525
+ if hasattr(self, 'padded_images'):
526
+ self.padded_images = self.padded_images.pin_memory()
527
+ self.packed_vae_token_indexes = self.packed_vae_token_indexes.pin_memory()
528
+ self.packed_latent_position_ids = self.packed_latent_position_ids.pin_memory()
529
+
530
+ if hasattr(self, 'packed_timesteps'):
531
+ self.packed_timesteps = self.packed_timesteps.pin_memory()
532
+ self.mse_loss_indexes = self.mse_loss_indexes.pin_memory()
533
+
534
+ if hasattr(self, 'packed_vit_tokens'):
535
+ self.packed_vit_tokens = self.packed_vit_tokens.pin_memory()
536
+ self.packed_vit_position_ids = self.packed_vit_position_ids.pin_memory()
537
+ self.packed_vit_token_indexes = self.packed_vit_token_indexes.pin_memory()
538
+ self.vit_token_seqlens = self.vit_token_seqlens.pin_memory()
539
+
540
+ if hasattr(self, 'packed_label_ids'):
541
+ self.packed_label_ids = self.packed_label_ids.pin_memory()
542
+ self.ce_loss_indexes = self.ce_loss_indexes.pin_memory()
543
+ self.ce_loss_weights = self.ce_loss_weights.pin_memory()
544
+
545
+ return self
546
+
547
+ def cuda(self, device):
548
+ self.packed_text_ids = self.packed_text_ids.to(device)
549
+ self.packed_text_indexes = self.packed_text_indexes.to(device)
550
+ self.packed_position_ids = self.packed_position_ids.to(device)
551
+
552
+ if not self.use_flex:
553
+ self.nested_attention_masks = [item.to(device) for item in self.nested_attention_masks]
554
+
555
+ if hasattr(self, 'padded_images'):
556
+ self.padded_images = self.padded_images.to(device)
557
+ self.packed_vae_token_indexes = self.packed_vae_token_indexes.to(device)
558
+ self.packed_latent_position_ids = self.packed_latent_position_ids.to(device)
559
+
560
+ if hasattr(self, 'packed_timesteps'):
561
+ self.packed_timesteps = self.packed_timesteps.to(device)
562
+ self.mse_loss_indexes = self.mse_loss_indexes.to(device)
563
+
564
+ if hasattr(self, 'packed_vit_tokens'):
565
+ self.packed_vit_tokens = self.packed_vit_tokens.to(device)
566
+ self.packed_vit_position_ids = self.packed_vit_position_ids.to(device)
567
+ self.packed_vit_token_indexes = self.packed_vit_token_indexes.to(device)
568
+ self.vit_token_seqlens = self.vit_token_seqlens.to(device)
569
+
570
+ if hasattr(self, 'packed_label_ids'):
571
+ self.packed_label_ids = self.packed_label_ids.to(device)
572
+ self.ce_loss_indexes = self.ce_loss_indexes.to(device)
573
+ self.ce_loss_weights = self.ce_loss_weights.to(device)
574
+
575
+ return self
576
+
577
+ def to_dict(self):
578
+ data = dict(
579
+ sequence_length = self.sequence_length,
580
+ sample_lens = self.sample_lens,
581
+ packed_text_ids = self.packed_text_ids,
582
+ packed_text_indexes = self.packed_text_indexes,
583
+ packed_position_ids = self.packed_position_ids,
584
+ batch_data_indexes = self.batch_data_indexes,
585
+ )
586
+
587
+ if not self.use_flex:
588
+ data['nested_attention_masks'] = self.nested_attention_masks
589
+ else:
590
+ data['split_lens'] = self.split_lens
591
+ data['attn_modes'] = self.attn_modes
592
+
593
+ if hasattr(self, 'padded_images'):
594
+ data['padded_images'] = self.padded_images
595
+ data['patchified_vae_latent_shapes'] = self.patchified_vae_latent_shapes
596
+ data['packed_latent_position_ids'] = self.packed_latent_position_ids
597
+ data['packed_vae_token_indexes'] = self.packed_vae_token_indexes
598
+
599
+ if hasattr(self, 'packed_vit_tokens'):
600
+ data['packed_vit_tokens'] = self.packed_vit_tokens
601
+ data['packed_vit_position_ids'] = self.packed_vit_position_ids
602
+ data['packed_vit_token_indexes'] = self.packed_vit_token_indexes
603
+ data['vit_token_seqlens'] = self.vit_token_seqlens
604
+
605
+ if hasattr(self, 'packed_timesteps'):
606
+ data['packed_timesteps'] = self.packed_timesteps
607
+ data['mse_loss_indexes'] = self.mse_loss_indexes
608
+
609
+ if hasattr(self, 'packed_label_ids'):
610
+ data['packed_label_ids'] = self.packed_label_ids
611
+ data['ce_loss_indexes'] = self.ce_loss_indexes
612
+ data['ce_loss_weights'] = self.ce_loss_weights
613
+
614
+ return data
615
+
616
+
617
+ def collate_wrapper():
618
+ def collate_fn(batch):
619
+ return SimpleCustomBatch(batch)
620
+ return collate_fn
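
A hedged sketch of how PackedDataset is typically wired into a DataLoader. The tokenizer path is a placeholder, and the sketch assumes the parquet/jsonl paths configured in DATASET_INFO actually exist; treat it as an outline rather than the repository's official entry point.

# Sketch only: building a packed-sequence loader from the example config.
import yaml
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from data.data_utils import add_special_tokens
from data.dataset_base import DataConfig, PackedDataset, collate_wrapper

tokenizer = AutoTokenizer.from_pretrained("path/to/llm_tokenizer")  # placeholder path
tokenizer, special_tokens, _ = add_special_tokens(tokenizer)

with open("data/configs/example.yaml") as f:
    grouped_datasets = yaml.safe_load(f)

dataset = PackedDataset(
    DataConfig(grouped_datasets=grouped_datasets),
    tokenizer=tokenizer,
    special_tokens=special_tokens,
    local_rank=0, world_size=1, num_workers=1,
)
loader = DataLoader(dataset, batch_size=1, num_workers=1,
                    collate_fn=collate_wrapper(), pin_memory=True)
batch = next(iter(loader))  # a SimpleCustomBatch holding one packed sequence
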
data/dataset_info.py ADDED
@@ -0,0 +1,39 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from .interleave_datasets import UnifiedEditIterableDataset
5
+ from .t2i_dataset import T2IIterableDataset
6
+ from .vlm_dataset import SftJSONLIterableDataset
7
+
8
+
9
+ DATASET_REGISTRY = {
10
+ 't2i_pretrain': T2IIterableDataset,
11
+ 'vlm_sft': SftJSONLIterableDataset,
12
+ 'unified_edit': UnifiedEditIterableDataset,
13
+ }
14
+
15
+
16
+ DATASET_INFO = {
17
+ 't2i_pretrain': {
18
+ 't2i': {
19
+ 'data_dir': 'your_data_path/bagel_example/t2i', # path of the parquet files
20
+ 'num_files': 10, # number of data units to be sharded across all ranks and workers
21
+ 'num_total_samples': 1000, # number of total samples in the dataset
22
+ },
23
+ },
24
+ 'unified_edit':{
25
+ 'seedxedit_multi': {
26
+ 'data_dir': 'your_data_path/bagel_example/editing/seedxedit_multi',
27
+ 'num_files': 10,
28
+ 'num_total_samples': 1000,
29
+ "parquet_info_path": 'your_data_path/bagel_example/editing/parquet_info/seedxedit_multi_nas.json', # information of the parquet files
30
+ },
31
+ },
32
+ 'vlm_sft': {
33
+ 'llava_ov': {
34
+ 'data_dir': 'your_data_path/bagel_example/vlm/images',
35
+ 'jsonl_path': 'your_data_path/bagel_example/vlm/llava_ov_si.jsonl',
36
+ 'num_total_samples': 1000
37
+ },
38
+ },
39
+ }
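
A short, hypothetical sketch of how an extra data source would be registered: the group key must already exist in DATASET_REGISTRY, and the new entry name must match one of the dataset_names listed for that group in the YAML config. 'my_new_t2i' and its path are made-up placeholders.

# Sketch (hypothetical names/paths): registering an additional t2i parquet source.
from data.dataset_info import DATASET_INFO

DATASET_INFO['t2i_pretrain']['my_new_t2i'] = {
    'data_dir': 'your_data_path/my_new_t2i',  # directory of parquet files (placeholder)
    'num_files': 4,                           # data units sharded across ranks/workers
    'num_total_samples': 400,
}
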
data/distributed_iterable_dataset.py ADDED
@@ -0,0 +1,58 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import random
5
+ import torch
6
+
7
+
8
+ class DistributedIterableDataset(torch.utils.data.IterableDataset):
9
+ def __init__(self, dataset_name, local_rank=0, world_size=1, num_workers=8):
10
+ self.dataset_name = dataset_name
11
+ self.local_rank = local_rank
12
+ self.world_size = world_size
13
+ self.num_workers = num_workers
14
+ self.rng = random.Random()
15
+ self.data_paths = None
16
+
17
+ def get_data_paths(self, *args, **kwargs):
18
+ raise NotImplementedError
19
+
20
+ def set_epoch(self, seed=42):
21
+ if self.data_paths is None:
22
+ return
23
+
24
+ if isinstance(self.data_paths[0], tuple):
25
+ data_paths = sorted(self.data_paths, key=lambda x: (x[0], x[1]))
26
+ elif isinstance(self.data_paths[0], str):
27
+ data_paths = sorted(self.data_paths)
28
+ else:
29
+ raise ValueError(f"Unknown data_paths type: {type(self.data_paths[0])}")
30
+
31
+ self.rng.seed(seed)
32
+ self.rng.shuffle(data_paths)
33
+
34
+ num_files_per_rank = len(data_paths) // self.world_size
35
+ local_start = self.local_rank * num_files_per_rank
36
+ local_end = (self.local_rank + 1) * num_files_per_rank
37
+ self.num_files_per_rank = num_files_per_rank
38
+ self.data_paths_per_rank = data_paths[local_start:local_end]
39
+
40
+ def get_data_paths_per_worker(self):
41
+ if self.data_paths is None:
42
+ return None
43
+
44
+ info = torch.utils.data.get_worker_info()
45
+ if info is None:
46
+ # Single worker: Use all files assigned to the rank
47
+ return self.data_paths_per_rank, 0
48
+
49
+ worker_id = info.id
50
+ num_files_per_worker = self.num_files_per_rank // info.num_workers
51
+ start = num_files_per_worker * worker_id
52
+ end = num_files_per_worker * (worker_id + 1)
53
+ data_paths_per_worker = self.data_paths_per_rank[start:end]
54
+
55
+ return data_paths_per_worker[::-1], worker_id
56
+
57
+ def __iter__(self):
58
+ raise NotImplementedError
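
A runnable sketch (assuming only that this package and PyTorch are importable) showing how set_epoch shuffles the path list deterministically and slices a contiguous shard per rank; the file names are dummies.

# Sketch: deterministic per-rank sharding of data paths.
from data.distributed_iterable_dataset import DistributedIterableDataset

class ToyDataset(DistributedIterableDataset):
    def __init__(self, paths, local_rank, world_size):
        super().__init__("toy", local_rank=local_rank, world_size=world_size, num_workers=1)
        self.data_paths = paths

paths = [f"shard_{i}.parquet" for i in range(8)]
for rank in range(2):
    ds = ToyDataset(paths, local_rank=rank, world_size=2)
    ds.set_epoch(seed=42)                 # same seed => identical shuffle on every rank
    print(rank, ds.data_paths_per_rank)   # 4 disjoint files per rank
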
data/interleave_datasets/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from .edit_dataset import UnifiedEditIterableDataset
5
+
data/interleave_datasets/edit_dataset.py ADDED
@@ -0,0 +1,72 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import io
5
+ import random
6
+ from PIL import Image, ImageFile, PngImagePlugin
7
+
8
+ from .interleave_t2i_dataset import InterleavedBaseIterableDataset, ParquetStandardIterableDataset
9
+ from ..data_utils import pil_img2rgb
10
+
11
+
12
+ Image.MAX_IMAGE_PIXELS = 200000000
13
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
14
+ MaximumDecompressedSize = 1024
15
+ MegaByte = 2 ** 20
16
+ PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
17
+
18
+
19
+ class UnifiedEditIterableDataset(InterleavedBaseIterableDataset, ParquetStandardIterableDataset):
20
+
21
+ def parse_row(self, row):
22
+ image_num = len(row["image_list"])
23
+ # randomly choose start and end, return [0, 1] when only two images
24
+ start_idx = random.choice(range(image_num - 1))
25
+ max_end = min(start_idx + 3, image_num)
26
+ end_idx = random.choice(range(start_idx + 1, max_end))
27
+
28
+ data = self._init_data()
29
+ data = self._add_image(
30
+ data,
31
+ pil_img2rgb(Image.open(io.BytesIO(row["image_list"][start_idx]))),
32
+ need_loss=False,
33
+ need_vae=True,
34
+ need_vit=True,
35
+ )
36
+
37
+ if end_idx - start_idx > 1 and random.random() < 0.5: # concatenate multiple instructions
38
+ if end_idx == image_num - 1:
39
+ end_idx -= 1
40
+
41
+ instruction = ""
42
+ for idx in range(start_idx + 1, end_idx + 1):
43
+ instruction += random.choice(row["instruction_list"][idx-1]) + ". "
44
+ data = self._add_text(data, instruction.rstrip(), need_loss=False)
45
+ data = self._add_image(
46
+ data,
47
+ pil_img2rgb(Image.open(io.BytesIO(row["image_list"][end_idx]))),
48
+ need_loss=True,
49
+ need_vae=False,
50
+ need_vit=False,
51
+ )
52
+ else:
53
+ for idx in range(start_idx + 1, end_idx + 1):
54
+ instruction = random.choice(row["instruction_list"][idx-1])
55
+ data = self._add_text(data, instruction, need_loss=False)
56
+ if idx != end_idx:
57
+ data = self._add_image(
58
+ data,
59
+ pil_img2rgb(Image.open(io.BytesIO(row["image_list"][idx]))),
60
+ need_loss=True,
61
+ need_vae=True,
62
+ need_vit=True,
63
+ )
64
+ else:
65
+ data = self._add_image(
66
+ data,
67
+ pil_img2rgb(Image.open(io.BytesIO(row["image_list"][idx]))),
68
+ need_loss=True,
69
+ need_vae=False,
70
+ need_vit=False,
71
+ )
72
+ return data
data/interleave_datasets/interleave_t2i_dataset.py ADDED
@@ -0,0 +1,212 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import pyarrow.parquet as pq
5
+
6
+ from ..distributed_iterable_dataset import DistributedIterableDataset
7
+ from ..parquet_utils import get_parquet_data_paths, init_arrow_pf_fs
8
+
9
+
10
+ class InterleavedBaseIterableDataset(DistributedIterableDataset):
11
+
12
+ def _init_data(self):
13
+ data = {
14
+ 'sequence_plan': [],
15
+ 'text_ids_list': [],
16
+ 'image_tensor_list': [],
17
+ 'num_tokens': 0,
18
+ }
19
+ return data
20
+
21
+ def _add_text(self, data, text, need_loss, enable_cfg=True):
22
+ text_ids = self.tokenizer.encode(text)
23
+ data['num_tokens'] += len(text_ids)
24
+ data['text_ids_list'].append(text_ids)
25
+ data['sequence_plan'].append(
26
+ {
27
+ 'type': 'text',
28
+ 'enable_cfg': int(enable_cfg),
29
+ 'loss': int(need_loss),
30
+ 'special_token_loss': 0,
31
+ 'special_token_label': None,
32
+ }
33
+ )
34
+ return data
35
+
36
+ def _add_image(self, data, image, need_loss, need_vae, need_vit, enable_cfg=True):
37
+ assert need_loss or need_vae or need_vit
38
+
39
+ if need_loss:
40
+ data['sequence_plan'].append(
41
+ {
42
+ 'type': 'vae_image',
43
+ 'enable_cfg': 0,
44
+ 'loss': 1,
45
+ 'special_token_loss': 0,
46
+ 'special_token_label': None,
47
+ }
48
+ )
49
+
50
+ image_tensor = self.transform(image)
51
+ height, width = image_tensor.shape[1:]
52
+ data['num_tokens'] += width * height // self.transform.stride ** 2
53
+ data['image_tensor_list'].append(image_tensor)
54
+
55
+ if need_vae:
56
+ data['sequence_plan'].append(
57
+ {
58
+ 'type': 'vae_image',
59
+ 'enable_cfg': int(enable_cfg),
60
+ 'loss': 0,
61
+ 'special_token_loss': 0,
62
+ 'special_token_label': None,
63
+ }
64
+ )
65
+
66
+ image_tensor = self.transform(image)
67
+ height, width = image_tensor.shape[1:]
68
+ data['num_tokens'] += width * height // self.transform.stride ** 2
69
+ data['image_tensor_list'].append(image_tensor.clone())
70
+
71
+ if need_vit:
72
+ data['sequence_plan'].append(
73
+ {
74
+ 'type': 'vit_image',
75
+ 'enable_cfg': int(enable_cfg),
76
+ 'loss': 0,
77
+ 'special_token_loss': 0,
78
+ 'special_token_label': None,
79
+ },
80
+ )
81
+ vit_image_tensor = self.vit_transform(image)
82
+ height, width = vit_image_tensor.shape[1:]
83
+ data['num_tokens'] += width * height // self.vit_transform.stride ** 2
84
+ data['image_tensor_list'].append(vit_image_tensor)
85
+
86
+ return data
87
+
88
+ def _add_video(self, data, frames, frame_indexes, need_loss, need_vae, enable_cfg=True):
89
+ assert int(need_loss) + int(need_vae) == 1
90
+
91
+ if need_loss:
92
+ for idx, (image, frame_idx) in enumerate(zip(frames, frame_indexes)):
93
+ current_sequence_plan = {
94
+ 'type': 'vae_image',
95
+ 'enable_cfg': 0,
96
+ 'loss': 1,
97
+ 'special_token_loss': 0,
98
+ 'special_token_label': None,
99
+ 'split_start': idx == 0,
100
+ 'split_end': idx == len(frames) - 1,
101
+ }
102
+ if idx < len(frame_indexes) - 1:
103
+ current_sequence_plan['frame_delta'] = frame_indexes[idx + 1] - frame_idx
104
+ data['sequence_plan'].append(current_sequence_plan)
105
+ image_tensor = self.transform(image)
106
+ height, width = image_tensor.shape[1:]
107
+ data['image_tensor_list'].append(image_tensor)
108
+ data['num_tokens'] += width * height // self.transform.stride ** 2
109
+
110
+ elif need_vae:
111
+ for idx, (image, frame_idx) in enumerate(zip(frames, frame_indexes)):
112
+ current_sequence_plan = {
113
+ 'type': 'vae_image',
114
+ 'enable_cfg': int(enable_cfg),
115
+ 'loss': 0,
116
+ 'special_token_loss': 0,
117
+ 'special_token_label': None,
118
+ 'split_start': idx == 0,
119
+ 'split_end': idx == len(frames) - 1,
120
+ }
121
+ if idx < len(frame_indexes) - 1:
122
+ current_sequence_plan['frame_delta'] = frame_indexes[idx + 1] - frame_idx
123
+ data['sequence_plan'].append(current_sequence_plan)
124
+ image_tensor = self.transform(image)
125
+ height, width = image_tensor.shape[1:]
126
+ data['image_tensor_list'].append(image_tensor)
127
+ data['num_tokens'] += width * height // self.transform.stride ** 2
128
+
129
+ return data
130
+
131
+
132
+ class ParquetStandardIterableDataset(DistributedIterableDataset):
133
+
134
+ def __init__(
135
+ self, dataset_name, transform, tokenizer, vit_transform,
136
+ data_dir_list, num_used_data, parquet_info,
137
+ local_rank=0, world_size=1, num_workers=8, data_status=None,
138
+ ):
139
+ """
140
+ data_dir_list: list of data directories containing parquet files
141
+ num_used_data: list of number of sampled data paths for each data directory
142
+ vit_transform: input image transform for the ViT model.
143
+ """
144
+ super().__init__(dataset_name, local_rank, world_size, num_workers)
145
+ self.transform = transform
146
+ self.vit_transform = vit_transform
147
+ self.tokenizer = tokenizer
148
+ self.data_status = data_status
149
+ self.data_paths = self.get_data_paths(data_dir_list, num_used_data, parquet_info)
150
+ self.set_epoch()
151
+
152
+ def get_data_paths(self, data_dir_list, num_used_data, parquet_info):
153
+ row_groups = []
154
+ for data_dir, num_data_path in zip(data_dir_list, num_used_data):
155
+ data_paths = get_parquet_data_paths([data_dir], [num_data_path])
156
+ for data_path in data_paths:
157
+ if data_path in parquet_info.keys():
158
+ num_row_groups = parquet_info[data_path]['num_row_groups']
159
+ for rg_idx in range(num_row_groups):
160
+ row_groups.append((data_path, rg_idx))
161
+ return row_groups
162
+
163
+ def parse_row(self, row):
164
+ raise NotImplementedError
165
+
166
+ def __iter__(self):
167
+ file_paths_per_worker, worker_id = self.get_data_paths_per_worker()
168
+ if self.data_status is not None:
169
+ global_row_group_start_id = self.data_status[worker_id][0]
170
+ row_start_id = self.data_status[worker_id][1] + 1
171
+ else:
172
+ global_row_group_start_id = 0
173
+ row_start_id = 0
174
+
175
+ print(
176
+ f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
177
+ f"resuming data at global_rg#{global_row_group_start_id}, row#{row_start_id}"
178
+ )
179
+
180
+ while True:
181
+ file_paths_per_worker_ = file_paths_per_worker[global_row_group_start_id:]
182
+ for global_row_group_idx, (parquet_file_path, row_group_id) in enumerate(
183
+ file_paths_per_worker_, start=global_row_group_start_id
184
+ ):
185
+ fs = init_arrow_pf_fs(parquet_file_path)
186
+ with fs.open_input_file(parquet_file_path) as f:
187
+ try:
188
+ fr = pq.ParquetFile(f)
189
+ df = fr.read_row_group(row_group_id).to_pandas()
190
+ df = df.iloc[row_start_id:]
191
+ except Exception as e:
192
+ print(f'Error {e} in rg#{row_group_id}, {parquet_file_path}')
193
+ continue
194
+
195
+ for row_idx, row in df.iterrows():
196
+ try:
197
+ data = self.parse_row(row)
198
+ if len(data) == 0:
199
+ continue
200
+ data['data_indexes'] = {
201
+ "data_indexes": [global_row_group_idx, row_idx],
202
+ "worker_id": worker_id,
203
+ "dataset_name": self.dataset_name,
204
+ }
205
+ except Exception as e:
206
+ print(f'Error {e} in rg#{row_group_id}, {parquet_file_path}')
207
+ continue
208
+ yield data
209
+
210
+ row_start_id = 0
211
+ global_row_group_start_id = 0
212
+ print(f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}")
data/parquet_utils.py ADDED
@@ -0,0 +1,90 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+
5
+ import os
6
+ import xml.etree.ElementTree as ET
7
+ import subprocess
8
+ import logging
9
+
10
+ import pyarrow.fs as pf
11
+ import torch.distributed as dist
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ def get_parquet_data_paths(data_dir_list, num_sampled_data_paths, rank=0, world_size=1):
17
+ num_data_dirs = len(data_dir_list)
18
+ if world_size > 1:
19
+ chunk_size = (num_data_dirs + world_size - 1) // world_size
20
+ start_idx = rank * chunk_size
21
+ end_idx = min(start_idx + chunk_size, num_data_dirs)
22
+ local_data_dir_list = data_dir_list[start_idx:end_idx]
23
+ local_num_sampled_data_paths = num_sampled_data_paths[start_idx:end_idx]
24
+ else:
25
+ local_data_dir_list = data_dir_list
26
+ local_num_sampled_data_paths = num_sampled_data_paths
27
+
28
+ local_data_paths = []
29
+ for data_dir, num_data_path in zip(local_data_dir_list, local_num_sampled_data_paths):
30
+ if data_dir.startswith("hdfs://"):
31
+ files = hdfs_ls_cmd(data_dir)
32
+ data_paths_per_dir = [
33
+ file for file in files if file.endswith(".parquet")
34
+ ]
35
+ else:
36
+ files = os.listdir(data_dir)
37
+ data_paths_per_dir = [
38
+ os.path.join(data_dir, name)
39
+ for name in files
40
+ if name.endswith(".parquet")
41
+ ]
42
+ repeat = num_data_path // len(data_paths_per_dir)
43
+ data_paths_per_dir = data_paths_per_dir * (repeat + 1)
44
+ local_data_paths.extend(data_paths_per_dir[:num_data_path])
45
+
46
+ if world_size > 1:
47
+ gather_list = [None] * world_size
48
+ dist.all_gather_object(gather_list, local_data_paths)
49
+
50
+ combined_chunks = []
51
+ for chunk_list in gather_list:
52
+ if chunk_list is not None:
53
+ combined_chunks.extend(chunk_list)
54
+ else:
55
+ combined_chunks = local_data_paths
56
+
57
+ return combined_chunks
58
+
59
+
60
+ # NOTE: customize this function for your cluster
61
+ def get_hdfs_host():
62
+ return "hdfs://xxx"
63
+
64
+
65
+ # NOTE: customize this function for your cluster
66
+ def get_hdfs_block_size():
67
+ return 134217728
68
+
69
+
70
+ # NOTE: customize this function for your cluster
71
+ def get_hdfs_extra_conf():
72
+ return None
73
+
74
+
75
+ def init_arrow_pf_fs(parquet_file_path):
76
+ if parquet_file_path.startswith("hdfs://"):
77
+ fs = pf.HadoopFileSystem(
78
+ host=get_hdfs_host(),
79
+ port=0,
80
+ buffer_size=get_hdfs_block_size(),
81
+ extra_conf=get_hdfs_extra_conf(),
82
+ )
83
+ else:
84
+ fs = pf.LocalFileSystem()
85
+ return fs
86
+
87
+
88
+ def hdfs_ls_cmd(dir):
89
+ result = subprocess.run(["hdfs", "dfs", "-ls", dir], capture_output=True, text=True).stdout
90
+ return ['hdfs://' + i.split('hdfs://')[-1].strip() for i in result.split('\n') if 'hdfs://' in i]
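
A self-contained sketch of the local (non-HDFS) path through these helpers: it writes a tiny parquet file to a temporary directory, resolves it with get_parquet_data_paths, and reads it back through init_arrow_pf_fs. The column name and contents are arbitrary.

# Sketch: round-trip a small parquet file through the helpers above.
import os
import tempfile

import pyarrow as pa
import pyarrow.parquet as pq

from data.parquet_utils import get_parquet_data_paths, init_arrow_pf_fs

tmp_dir = tempfile.mkdtemp()
pq.write_table(pa.table({"caption": ["a cat", "a dog"]}),
               os.path.join(tmp_dir, "part-0.parquet"))

paths = get_parquet_data_paths([tmp_dir], [1])  # sample one path from this directory
fs = init_arrow_pf_fs(paths[0])                 # LocalFileSystem for non-hdfs paths
with fs.open_input_file(paths[0]) as f:
    table = pq.ParquetFile(f).read_row_group(0)
print(table.to_pandas())
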
data/t2i_dataset.py ADDED
@@ -0,0 +1,128 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import io
5
+ import json
6
+ import pyarrow.parquet as pq
7
+ import random
8
+ from PIL import Image
9
+
10
+ from .data_utils import pil_img2rgb
11
+ from .distributed_iterable_dataset import DistributedIterableDataset
12
+ from .parquet_utils import get_parquet_data_paths, init_arrow_pf_fs
13
+
14
+ Image.MAX_IMAGE_PIXELS = 20_000_000
15
+
16
+
17
+ class T2IIterableDataset(DistributedIterableDataset):
18
+ def __init__(
19
+ self, dataset_name, transform, tokenizer, data_dir_list, num_used_data,
20
+ local_rank=0, world_size=1, num_workers=8, data_status=None,
21
+ ):
22
+ """
23
+ data_dir_list: list of data directories containing parquet files
24
+ num_used_data: list of number of sampled data paths for each data directory
25
+ """
26
+ super().__init__(dataset_name, local_rank, world_size, num_workers)
27
+ self.transform = transform
28
+ self.tokenizer = tokenizer
29
+ self.data_status = data_status
30
+ self.data_paths = self.get_data_paths(data_dir_list, num_used_data)
31
+ self.set_epoch()
32
+
33
+ def get_data_paths(self, data_dir_list, num_used_data):
34
+ return get_parquet_data_paths(data_dir_list, num_used_data)
35
+
36
+ def __iter__(self):
37
+ data_paths_per_worker, worker_id = self.get_data_paths_per_worker()
38
+ if self.data_status is not None:
39
+ parquet_start_id = self.data_status[worker_id][0]
40
+ row_group_start_id = self.data_status[worker_id][1]
41
+ row_start_id = self.data_status[worker_id][2] + 1
42
+ else:
43
+ parquet_start_id = 0
44
+ row_group_start_id = 0
45
+ row_start_id = 0
46
+ transform_stride = self.transform.stride
47
+
48
+ print(
49
+ f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
50
+ f"resuming data at parquet#{parquet_start_id}, rg#{row_group_start_id}, row#{row_start_id}"
51
+ )
52
+
53
+ while True:
54
+ data_paths_per_worker_ = data_paths_per_worker[parquet_start_id:]
55
+ for parquet_idx, parquet_file_path in enumerate(data_paths_per_worker_, start=parquet_start_id):
56
+ fs = init_arrow_pf_fs(parquet_file_path)
57
+ with fs.open_input_file(parquet_file_path) as f:
58
+ fr = pq.ParquetFile(f)
59
+ row_group_ids = list(range(fr.num_row_groups))
60
+ row_group_ids_ = row_group_ids[row_group_start_id:]
61
+
62
+ for row_group_id in row_group_ids_:
63
+ df = fr.read_row_group(row_group_id).to_pandas()
64
+ df = df.iloc[row_start_id:]
65
+
66
+ for row_idx, row in df.iterrows():
67
+ num_tokens = 0
68
+ try:
69
+ image_byte = row['image']
70
+ image = pil_img2rgb(Image.open(io.BytesIO(image_byte)))
71
+ except Exception as e:
72
+ print(f'Error: {e} in rg#{row_group_id}, {parquet_file_path}')
73
+ continue
74
+ image_tensor = self.transform(image)
75
+ height, width = image_tensor.shape[1:]
76
+ num_tokens += width * height // transform_stride ** 2
77
+
78
+ try:
79
+ caption_dict = row['captions']
80
+ caption_dict = json.loads(caption_dict)
81
+ except Exception as e:
82
+ print(f'Error: {e} in rg#{row_group_id}, {parquet_file_path}')
83
+ continue
84
+
85
+ caps_token = [self.tokenizer.encode(v) for _, v in caption_dict.items()]
86
+ if len(caps_token) == 0:
87
+ print(f'no caption in rg#{row_group_id}, {parquet_file_path}')
88
+ caption_token = self.tokenizer.encode(' ')
89
+ else:
90
+ caption_token = random.choice(caps_token)
91
+
92
+ sequence_plan, text_ids_list = [], []
93
+ text_ids = caption_token
94
+ num_tokens += len(caption_token)
95
+ text_ids_list.append(text_ids)
96
+ sequence_plan.append({
97
+ 'type': 'text',
98
+ 'enable_cfg': 1,
99
+ 'loss': 0,
100
+ 'special_token_loss': 0,
101
+ 'special_token_label': None,
102
+ })
103
+
104
+ sequence_plan.append({
105
+ 'type': 'vae_image',
106
+ 'enable_cfg': 0,
107
+ 'loss': 1,
108
+ 'special_token_loss': 0,
109
+ 'special_token_label': None,
110
+ })
111
+
112
+ sample = dict(
113
+ image_tensor_list=[image_tensor],
114
+ text_ids_list=text_ids_list,
115
+ num_tokens=num_tokens,
116
+ sequence_plan=sequence_plan,
117
+ data_indexes={
118
+ "data_indexes": [parquet_idx, row_group_id, row_idx],
119
+ "worker_id": worker_id,
120
+ "dataset_name": self.dataset_name,
121
+ }
122
+ )
123
+ yield sample
124
+
125
+ row_start_id = 0
126
+ row_group_start_id = 0
127
+ parquet_start_id = 0
128
+ print(f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}")
data/transforms.py ADDED
@@ -0,0 +1,287 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import random
5
+ from PIL import Image
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import torch
10
+ from torchvision import transforms
11
+ from torchvision.transforms import functional as F
12
+ from torchvision.transforms import InterpolationMode
13
+
14
+
15
+ class MaxLongEdgeMinShortEdgeResize(torch.nn.Module):
16
+ """Resize the input image so that its longest side and shortest side are within a specified range,
17
+ ensuring that both sides are divisible by a specified stride.
18
+
19
+ Args:
20
+ max_size (int): Maximum size for the longest edge of the image.
21
+ min_size (int): Minimum size for the shortest edge of the image.
22
+ stride (int): Value by which the height and width of the image must be divisible.
23
+ max_pixels (int): Maximum pixels for the full image.
24
+ interpolation (InterpolationMode): Desired interpolation enum defined by
25
+ :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
26
+ If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
27
+ ``InterpolationMode.BILINEAR``, and ``InterpolationMode.BICUBIC`` are supported.
28
+ The corresponding Pillow integer constants, e.g., ``PIL.Image.BILINEAR`` are also accepted.
29
+ antialias (bool, optional): Whether to apply antialiasing (default is True).
30
+ """
31
+
32
+ def __init__(
33
+ self,
34
+ max_size: int,
35
+ min_size: int,
36
+ stride: int,
37
+ max_pixels: int,
38
+ interpolation=InterpolationMode.BICUBIC,
39
+ antialias=True
40
+ ):
41
+ super().__init__()
42
+ self.max_size = max_size
43
+ self.min_size = min_size
44
+ self.stride = stride
45
+ self.max_pixels = max_pixels
46
+ self.interpolation = interpolation
47
+ self.antialias = antialias
48
+
49
+ def _make_divisible(self, value, stride):
50
+ """Ensure the value is divisible by the stride."""
51
+ return max(stride, int(round(value / stride) * stride))
52
+
53
+ def _apply_scale(self, width, height, scale):
54
+ new_width = round(width * scale)
55
+ new_height = round(height * scale)
56
+ new_width = self._make_divisible(new_width, self.stride)
57
+ new_height = self._make_divisible(new_height, self.stride)
58
+ return new_width, new_height
59
+
60
+ def forward(self, img, img_num=1):
61
+ """
62
+ Args:
63
+ img (PIL Image): Image to be resized.
64
+ img_num (int): Number of images in the sample; the max_pixels budget is divided by this value.
65
+ Returns:
66
+ PIL Image or Tensor: Rescaled image with divisible dimensions.
67
+ """
68
+ if isinstance(img, torch.Tensor):
69
+ height, width = img.shape[-2:]
70
+ else:
71
+ width, height = img.size
72
+
73
+ scale = min(self.max_size / max(width, height), 1.0)
74
+ scale = max(scale, self.min_size / min(width, height))
75
+ new_width, new_height = self._apply_scale(width, height, scale)
76
+
77
+ # Ensure the number of pixels does not exceed max_pixels
78
+ if new_width * new_height > self.max_pixels / img_num:
79
+ scale = self.max_pixels / img_num / (new_width * new_height)
80
+ new_width, new_height = self._apply_scale(new_width, new_height, scale)
81
+
82
+ # Ensure longest edge does not exceed max_size
83
+ if max(new_width, new_height) > self.max_size:
84
+ scale = self.max_size / max(new_width, new_height)
85
+ new_width, new_height = self._apply_scale(new_width, new_height, scale)
86
+
87
+ return F.resize(img, (new_height, new_width), self.interpolation, antialias=self.antialias)
88
+
89
+
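A worked example of the resize rules above, assuming the t2i settings from data/configs/example.yaml (max_size=1024, min_size=512, stride=16) and the default max_pixels used by ImageTransform below; the 2000x800 input is hypothetical.

from PIL import Image

resize = MaxLongEdgeMinShortEdgeResize(max_size=1024, min_size=512, stride=16, max_pixels=14 * 14 * 9 * 1024)
out = resize(Image.new("RGB", (2000, 800)))
# 2000x800: the long-edge scale 0.512 is overridden by the short-edge floor 512/800 = 0.64,
# giving 1280x512; the long edge is then re-capped at 1024, giving 1024x416 (both divisible by 16).
print(out.size)  # (1024, 416)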
90
+ class ImageTransform:
91
+ def __init__(
92
+ self,
93
+ max_image_size,
94
+ min_image_size,
95
+ image_stride,
96
+ max_pixels=14*14*9*1024,
97
+ image_mean=[0.5, 0.5, 0.5],
98
+ image_std=[0.5, 0.5, 0.5]
99
+ ):
100
+ self.stride = image_stride
101
+
102
+ self.resize_transform = MaxLongEdgeMinShortEdgeResize(
103
+ max_size=max_image_size,
104
+ min_size=min_image_size,
105
+ stride=image_stride,
106
+ max_pixels=max_pixels,
107
+ )
108
+ self.to_tensor_transform = transforms.ToTensor()
109
+ self.normalize_transform = transforms.Normalize(mean=image_mean, std=image_std, inplace=True)
110
+
111
+ def __call__(self, img, img_num=1):
112
+ img = self.resize_transform(img, img_num=img_num)
113
+ img = self.to_tensor_transform(img)
114
+ img = self.normalize_transform(img)
115
+ return img
116
+
117
+
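A hedged usage sketch of ImageTransform with the same hypothetical 2000x800 input, counting VAE patch tokens the way the parquet dataset above does.

from PIL import Image

transform = ImageTransform(max_image_size=1024, min_image_size=512, image_stride=16)
image_tensor = transform(Image.new("RGB", (2000, 800)))   # shape (3, 416, 1024), normalized to [-1, 1]
height, width = image_tensor.shape[1:]
num_tokens = width * height // transform.stride ** 2      # 1024 * 416 / 16**2 = 1664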
118
+ def decolorization(image):
119
+ gray_image = image.convert('L')
120
+ return Image.merge(image.mode, [gray_image] * 3) if image.mode == 'RGB' else gray_image
121
+
122
+
123
+ def downscale(image, scale_factor):
124
+ new_width = int(round(image.width * scale_factor))
125
+ new_height = int(round(image.height * scale_factor))
126
+ new_width = max(1, new_width)
127
+ new_height = max(1, new_height)
128
+ return image.resize((new_width, new_height), resample=Image.BICUBIC)
129
+
130
+
131
+ def crop(image, crop_factors):
132
+ target_h, target_w = crop_factors
133
+ img_w, img_h = image.size
134
+
135
+ if target_h > img_h or target_w > img_w:
136
+ raise ValueError("Crop size exceeds image dimensions")
137
+
138
+ x = random.randint(0, img_w - target_w)
139
+ y = random.randint(0, img_h - target_h)
140
+
141
+ return image.crop((x, y, x + target_w, y + target_h)), [[x, y], [x + target_w, y + target_h]]
142
+
143
+
144
+ def motion_blur_opencv(image, kernel_size=15, angle=0):
145
+ # Build a linear (horizontal) motion kernel
146
+ kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
147
+ kernel[kernel_size // 2, :] = np.ones(kernel_size, dtype=np.float32)
148
+
149
+ # Rotate the kernel to the requested blur angle
150
+ center = (kernel_size / 2 - 0.5, kernel_size / 2 - 0.5)
151
+ M = cv2.getRotationMatrix2D(center, angle, 1)
152
+ rotated_kernel = cv2.warpAffine(kernel, M, (kernel_size, kernel_size))
153
+
154
+ # Normalize the kernel (guard against an all-zero sum)
155
+ rotated_kernel /= rotated_kernel.sum() if rotated_kernel.sum() != 0 else 1
156
+
157
+ img = np.array(image)
158
+ if img.ndim == 2:
159
+ blurred = cv2.filter2D(img, -1, rotated_kernel, borderType=cv2.BORDER_REFLECT)
160
+ else:
161
+ # For color images, filter each channel independently
162
+ blurred = np.zeros_like(img)
163
+ for c in range(img.shape[2]):
164
+ blurred[..., c] = cv2.filter2D(img[..., c], -1, rotated_kernel, borderType=cv2.BORDER_REFLECT)
165
+
166
+ return Image.fromarray(blurred.astype(np.uint8))
167
+
168
+
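The helpers above (and those that follow) read as image-corruption utilities; a small sketch of chaining them on a synthetic image, with hypothetical parameter values:

from PIL import Image

src = Image.new("RGB", (512, 512), color=(120, 180, 90))
degraded = decolorization(src)
degraded = downscale(degraded, scale_factor=0.5)                 # 256x256
degraded = motion_blur_opencv(degraded, kernel_size=9, angle=30)
patch, box = crop(degraded, crop_factors=(128, 128))             # box = [[x, y], [x + 128, y + 128]]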
169
+ def shuffle_patch(image, num_splits, gap_size=2):
170
+ """将图像分割为块(允许尺寸不整除),随机打乱后拼接,块间保留间隙"""
171
+ h_splits, w_splits = num_splits
172
+ img_w, img_h = image.size
173
+
174
+ base_patch_h = img_h // h_splits
175
+ patch_heights = [base_patch_h] * (h_splits - 1)
176
+ patch_heights.append(img_h - sum(patch_heights))
177
+
178
+ base_patch_w = img_w // w_splits
179
+ patch_widths = [base_patch_w] * (w_splits - 1)
180
+ patch_widths.append(img_w - sum(patch_widths))
181
+
182
+ patches = []
183
+ current_y = 0
184
+ for i in range(h_splits):
185
+ current_x = 0
186
+ patch_h = patch_heights[i]
187
+ for j in range(w_splits):
188
+ patch_w = patch_widths[j]
189
+ patch = image.crop((current_x, current_y, current_x + patch_w, current_y + patch_h))
190
+ patches.append(patch)
191
+ current_x += patch_w
192
+ current_y += patch_h
193
+
194
+ random.shuffle(patches)
195
+
196
+ total_width = sum(patch_widths) + (w_splits - 1) * gap_size
197
+ total_height = sum(patch_heights) + (h_splits - 1) * gap_size
198
+ new_image = Image.new(image.mode, (total_width, total_height), color=(255, 255, 255))
199
+
200
+ current_y = 0 # starting Y coordinate of the current row
201
+ patch_idx = 0 # index of the patch currently being placed
202
+ for i in range(h_splits):
203
+ current_x = 0 # starting X coordinate of the current column
204
+ patch_h = patch_heights[i] # height of the patches in this row
205
+ for j in range(w_splits):
206
+ # take the next shuffled patch
207
+ patch = patches[patch_idx]
208
+ patch_w = patch_widths[j] # width of the patches in this column
209
+ # paste the patch with its top-left corner at (current_x, current_y)
210
+ new_image.paste(patch, (current_x, current_y))
211
+ # advance X: the next patch starts after this patch's width plus the gap
212
+ current_x += patch_w + gap_size
213
+ patch_idx += 1
214
+ # advance Y: the next row starts after this row's height plus the gap
215
+ current_y += patch_h + gap_size
216
+
217
+ return new_image
218
+
219
+
220
+ def inpainting(image, num_splits, blank_ratio=0.3, blank_color=(255, 255, 255)):
221
+ """
222
+ Split the image into patches, blank out a random subset, and re-assemble; used for inpainting-style corruption.
223
+
224
+ Args:
225
+ image: PIL.Image input image (RGB mode)
226
+ h_splits: int number of row splits (vertical direction); first element of num_splits
227
+ w_splits: int number of column splits (horizontal direction); second element of num_splits
228
+ blank_ratio: float fraction of patches to blank out (0~1)
229
+ blank_color: tuple color of the blank regions (RGB, e.g. white (255, 255, 255))
230
+
231
+ Returns:
232
+ PIL.Image the re-assembled image
233
+ """
234
+ h_splits, w_splits = num_splits
235
+ img_w, img_h = image.size
236
+
237
+ base_patch_h = img_h // h_splits
238
+ patch_heights = [base_patch_h] * (h_splits - 1)
239
+ patch_heights.append(img_h - sum(patch_heights))
240
+
241
+ base_patch_w = img_w // w_splits
242
+ patch_widths = [base_patch_w] * (w_splits - 1)
243
+ patch_widths.append(img_w - sum(patch_widths))
244
+
245
+ patches = []
246
+ current_y = 0
247
+ for i in range(h_splits):
248
+ current_x = 0
249
+ patch_h = patch_heights[i]
250
+ for j in range(w_splits):
251
+ patch_w = patch_widths[j]
252
+ patch = image.crop((current_x, current_y, current_x + patch_w, current_y + patch_h))
253
+ patches.append(patch)
254
+ current_x += patch_w
255
+ current_y += patch_h
256
+
257
+ total_patches = h_splits * w_splits
258
+ num_blank = int(total_patches * blank_ratio)
259
+ num_blank = max(0, min(num_blank, total_patches))
260
+ blank_indices = random.sample(range(total_patches), num_blank)
261
+
262
+ processed_patches = []
263
+ for idx, patch in enumerate(patches):
264
+ if idx in blank_indices:
265
+ blank_patch = Image.new("RGB", patch.size, color=blank_color)
266
+ processed_patches.append(blank_patch)
267
+ else:
268
+ processed_patches.append(patch)
269
+
270
+ # Create the output image (same size as the input)
271
+ result_image = Image.new("RGB", (img_w, img_h))
272
+ current_y = 0
273
+ patch_idx = 0
274
+ for i in range(h_splits):
275
+ current_x = 0
276
+ patch_h = patch_heights[i]
277
+ for j in range(w_splits):
278
+ # take the processed patch
279
+ patch = processed_patches[patch_idx]
280
+ patch_w = patch_widths[j]
281
+ # paste it back at its original position
282
+ result_image.paste(patch, (current_x, current_y))
283
+ current_x += patch_w
284
+ patch_idx += 1
285
+ current_y += patch_h
286
+
287
+ return result_image
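A short sketch of the two patch-level corruptions. num_splits is (rows, cols); shuffle_patch returns a slightly larger canvas because of the gaps, while inpainting preserves the input size. The image size below is hypothetical.

from PIL import Image

src = Image.new("RGB", (511, 383))   # patch sizes need not divide the image evenly
shuffled = shuffle_patch(src, num_splits=(3, 4), gap_size=2)
masked = inpainting(src, num_splits=(3, 4), blank_ratio=0.3)
print(shuffled.size, masked.size)    # (517, 387) (511, 383)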
data/video_utils.py ADDED
@@ -0,0 +1,165 @@
1
+ # Copyright (c) 2023 OpenGVLab
2
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
3
+ # SPDX-License-Identifier: MIT
4
+ #
5
+ # This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
6
+ #
7
+ # Original file was released under MIT, with the full license text
8
+ # available at https://github.com/OpenGVLab/InternVL/blob/main/LICENSE.
9
+ #
10
+ # This modified file is released under the same license.
11
+
12
+
13
+ import io
14
+ import os
15
+ import random
16
+ import re
17
+
18
+ import numpy as np
19
+ import decord
20
+ from PIL import Image
21
+
22
+
23
+ def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
24
+ if sample in ['rand', 'middle']: # uniform sampling
25
+ acc_samples = min(num_frames, vlen)
26
+ # split the video into `acc_samples` intervals, and sample from each interval.
27
+ intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
28
+ ranges = []
29
+ for idx, interv in enumerate(intervals[:-1]):
30
+ ranges.append((interv, intervals[idx + 1] - 1))
31
+ if sample == 'rand':
32
+ try:
33
+ frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
34
+ except IndexError: # an interval can be empty when vlen is small
35
+ frame_indices = np.random.permutation(vlen)[:acc_samples]
36
+ frame_indices.sort()
37
+ frame_indices = list(frame_indices)
38
+ elif fix_start is not None:
39
+ frame_indices = [x[0] + fix_start for x in ranges]
40
+ elif sample == 'middle':
41
+ frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
42
+ else:
43
+ raise NotImplementedError
44
+
45
+ if len(frame_indices) < num_frames: # padded with last frame
46
+ padded_frame_indices = [frame_indices[-1]] * num_frames
47
+ padded_frame_indices[:len(frame_indices)] = frame_indices
48
+ frame_indices = padded_frame_indices
49
+ elif 'fps' in sample: # fps0.5, sequentially sample frames at 0.5 fps
50
+ output_fps = float(sample[3:])
51
+ duration = float(vlen) / input_fps
52
+ delta = 1 / output_fps # gap between frames, this is also the clip length each frame represents
53
+ frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
54
+ frame_indices = np.around(frame_seconds * input_fps).astype(int)
55
+ frame_indices = [e for e in frame_indices if e < vlen]
56
+ if max_num_frames > 0 and len(frame_indices) > max_num_frames:
57
+ frame_indices = frame_indices[:max_num_frames]
58
+ else:
59
+ raise ValueError
60
+ return frame_indices
61
+
62
+
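A worked example of the 'middle' strategy above: the vlen frames are split into num_frames equal intervals and each interval's midpoint is taken; 'rand' instead draws one random frame per interval.

# intervals = [0, 25, 50, 75, 100] -> midpoints of [0, 24], [25, 49], [50, 74], [75, 99]
print(get_frame_indices(num_frames=4, vlen=100, sample='middle'))  # [12, 37, 62, 87]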
63
+ def read_frames_decord(video_path, num_frames, sample='rand', fix_start=None, clip=None, min_num_frames=4):
64
+ video_reader = decord.VideoReader(video_path, num_threads=1)
65
+ vlen = len(video_reader)
66
+ fps = video_reader.get_avg_fps()
67
+ duration = vlen / float(fps)
68
+ if clip:
69
+ start, end = clip
70
+ duration = end - start
71
+ vlen = int(duration * fps)
72
+ start_index = int(start * fps)
73
+
74
+ t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
75
+
76
+ frame_indices = get_frame_indices(
77
+ t_num_frames, vlen, sample=sample, fix_start=fix_start,
78
+ input_fps=fps
79
+ )
80
+ if clip:
81
+ frame_indices = [f + start_index for f in frame_indices]
82
+ frames = video_reader.get_batch(frame_indices).asnumpy() # (T, H, W, C), np.uint8
83
+ frames = [Image.fromarray(frames[i]) for i in range(frames.shape[0])]
84
+ return frames
85
+
86
+
87
+ def extract_frame_number(filename):
88
+ # Extract the numeric part from the filename using regular expressions
89
+ match = re.search(r'_(\d+)\.jpg$', filename)
90
+ return int(match.group(1)) if match else -1
91
+
92
+
93
+ def sort_frames(frame_paths):
94
+ # Extract filenames from each path and sort by their numeric part
95
+ return sorted(frame_paths, key=lambda x: extract_frame_number(os.path.basename(x)))
96
+
97
+
98
+ def read_frames_folder(video_path, num_frames, sample='rand', fix_start=None, min_num_frames=4):
99
+ image_list = sort_frames(list(os.listdir(video_path)))
100
+ frames = []
101
+ for image in image_list:
102
+ fp = os.path.join(video_path, image)
103
+ frame = Image.open(fp).convert('RGB')
104
+ frames.append(frame)
105
+ vlen = len(frames)
106
+
107
+ t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
108
+
109
+ if vlen > t_num_frames:
110
+ frame_indices = get_frame_indices(
111
+ t_num_frames, vlen, sample=sample, fix_start=fix_start
112
+ )
113
+ frames = [frames[i] for i in frame_indices]
114
+ return frames
115
+
116
+
117
+ class FrameSampler:
118
+ def __init__(self, max_num_frames=-1, min_num_frames=8, sample='rand'):
119
+ self.max_num_frames = max_num_frames
120
+ self.min_num_frames = min_num_frames
121
+ self.sample = sample
122
+
123
+ def __call__(self, file_name):
124
+ fn = read_frames_folder if file_name.endswith('/') else read_frames_decord
125
+ frames = fn(file_name, num_frames=self.max_num_frames, min_num_frames=self.min_num_frames, sample=self.sample)
126
+ return frames
127
+
128
+
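A hedged usage sketch of FrameSampler, mirroring the vlm_sft frame_sampler_args in data/configs/example.yaml; the paths are hypothetical.

sampler = FrameSampler(max_num_frames=12, min_num_frames=8, sample='rand')
frames = sampler('/path/to/clip.mp4')      # 8-12 PIL frames decoded with decord
frames = sampler('/path/to/frame_dir/')    # trailing slash -> read_frames_folder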
129
+ def decode_video_byte(video_bytes):
130
+ video_stream = io.BytesIO(video_bytes)
131
+ vr = decord.VideoReader(video_stream)
132
+ return vr
133
+
134
+
135
+ def sample_mp4_frames(mp4_p, n_frames=None, fps=None, return_frame_indices=False, random_sample=False):
136
+ if isinstance(mp4_p, str):
137
+ vr = decord.VideoReader(mp4_p, num_threads=1)
138
+ elif isinstance(mp4_p, decord.video_reader.VideoReader):
139
+ vr = mp4_p
140
+ video_fps = vr.get_avg_fps() # average frame rate of the video
141
+ video_duration = len(vr) / video_fps
142
+ if n_frames is not None:
143
+ if random_sample:
144
+ frame_indices = sorted(random.sample(range(len(vr)), n_frames))
145
+ else:
146
+ frame_indices = np.linspace(0, len(vr)-1, n_frames, dtype=int).tolist()
147
+ else:
148
+ frame_indices = [int(i) for i in np.arange(0, len(vr)-1, video_fps/fps)]
149
+ frames = vr.get_batch(frame_indices).asnumpy() # decode the sampled frames to a numpy array
150
+ frames = [Image.fromarray(frame).convert("RGB") for frame in frames]
151
+ if not return_frame_indices:
152
+ return frames, video_duration
153
+ else:
154
+ return frames, video_duration, frame_indices
155
+
156
+
157
+ def sample_mp4_frames_by_indices(mp4_p, frame_indices: list):
158
+ if isinstance(mp4_p, str):
159
+ vr = decord.VideoReader(mp4_p, num_threads=1)
160
+ elif isinstance(mp4_p, decord.video_reader.VideoReader):
161
+ vr = mp4_p
162
+ # sample the frames in frame_indices
163
+ frames = vr.get_batch(frame_indices).asnumpy() # decode the selected frames to a numpy array
164
+ frames = [Image.fromarray(frame).convert("RGB") for frame in frames]
165
+ return frames
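A hedged sketch of the remaining helpers; the mp4 path is hypothetical.

frames, duration = sample_mp4_frames('/path/to/clip.mp4', fps=1)   # roughly one frame per second
frames, duration, idx = sample_mp4_frames('/path/to/clip.mp4', n_frames=8, return_frame_indices=True)

with open('/path/to/clip.mp4', 'rb') as f:
    vr = decode_video_byte(f.read())                 # decord reader over in-memory bytes
frames = sample_mp4_frames_by_indices(vr, [0, 10, 20])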
data/vlm_dataset.py ADDED
@@ -0,0 +1,195 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import json
5
+ import os
6
+ import traceback
7
+ from PIL import Image, ImageFile, PngImagePlugin
8
+
9
+ from .data_utils import pil_img2rgb
10
+ from .distributed_iterable_dataset import DistributedIterableDataset
11
+
12
+
13
+ Image.MAX_IMAGE_PIXELS = 200000000
14
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
15
+ MaximumDecompressedSize = 1024
16
+ MegaByte = 2 ** 20
17
+ PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
18
+
19
+
20
+ class SftJSONLIterableDataset(DistributedIterableDataset):
21
+ def __init__(
22
+ self, dataset_name, transform, tokenizer, frame_sampler,
23
+ jsonl_path_list, data_dir_list, num_used_data,
24
+ local_rank=0, world_size=1, num_workers=8, data_status=None,
25
+ shuffle_lines=False, shuffle_seed=0,
26
+ ):
27
+ """
28
+ jsonl_path_list: list of jsonl file paths
29
+ data_dir_list: list of image directories containing the images of each jsonl file
30
+ num_used_data: list of number of sampled data points for each jsonl
31
+ """
32
+ super().__init__(dataset_name, local_rank, world_size, num_workers)
33
+ self.transform = transform
34
+ self.tokenizer = tokenizer
35
+ self.frame_sampler = frame_sampler
36
+ self.data_status = data_status
37
+ self.data_paths = self.get_data_paths(
38
+ jsonl_path_list,
39
+ data_dir_list,
40
+ num_used_data,
41
+ shuffle_lines,
42
+ shuffle_seed,
43
+ )
44
+ self.set_epoch()
45
+
46
+ def get_data_paths(
47
+ self,
48
+ jsonl_path_list,
49
+ data_dir_list,
50
+ num_used_data,
51
+ shuffle_lines,
52
+ shuffle_seed,
53
+ ):
54
+ data_paths = []
55
+ for jsonl_path, image_dir, num_data_point in zip(
56
+ jsonl_path_list, data_dir_list, num_used_data
57
+ ):
58
+ with open(jsonl_path, 'r') as f:
59
+ raw_data = f.readlines()
60
+ if shuffle_lines:
61
+ self.rng.seed(shuffle_seed)
62
+ self.rng.shuffle(raw_data)
63
+ raw_data = raw_data[:num_data_point]
64
+ data_paths.extend([(json_data, image_dir) for json_data in raw_data])
65
+ return data_paths
66
+
67
+ def change_format(self, data, num_images):
68
+ elements = []
69
+ for conversation in data['conversations']:
70
+ if conversation['from'] == 'human':
71
+ if '<image>' not in conversation['value']:
72
+ elements.append({
73
+ 'type': 'text',
74
+ 'has_loss': 0,
75
+ 'text': conversation['value'],
76
+ })
77
+ else:
78
+ text_list = conversation['value'].split('<image>')
79
+ for idx, text in enumerate(text_list):
80
+ if text.strip() != '':
81
+ elements.append({
82
+ 'type': 'text',
83
+ 'has_loss': 0,
84
+ 'text': text.strip(),
85
+ })
86
+ if (idx != len(text_list) - 1) and (idx < num_images):
87
+ elements.append({'type': 'image',})
88
+ elif conversation['from'] == 'gpt':
89
+ elements.append({
90
+ 'type': 'text',
91
+ 'has_loss': 1,
92
+ 'text': conversation['value'],
93
+ })
94
+ return elements
95
+
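A worked example of change_format: '<image>' placeholders become image slots and only the assistant ('gpt') turns carry loss. change_format does not use self, so it is called unbound here; the conversation is hypothetical.

data = {'conversations': [
    {'from': 'human', 'value': '<image>\nWhat is in the picture?'},
    {'from': 'gpt', 'value': 'A cat sitting on a sofa.'},
]}
elements = SftJSONLIterableDataset.change_format(None, data, num_images=1)
# [{'type': 'image'},
#  {'type': 'text', 'has_loss': 0, 'text': 'What is in the picture?'},
#  {'type': 'text', 'has_loss': 1, 'text': 'A cat sitting on a sofa.'}]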
96
+ def __iter__(self):
97
+ data_paths_per_worker, worker_id = self.get_data_paths_per_worker()
98
+ if self.data_status is not None:
99
+ row_start_id = self.data_status[worker_id] + 1
100
+ else:
101
+ row_start_id = 0
102
+ transform_stride = self.transform.stride
103
+
104
+ print(
105
+ f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
106
+ f"resuming data at row#{row_start_id}"
107
+ )
108
+
109
+ while True:
110
+ data_paths_per_worker_ = data_paths_per_worker[row_start_id:]
111
+ for row_idx, (data, image_dir) in enumerate(data_paths_per_worker_, start=row_start_id):
112
+ num_tokens = 0
113
+ image_tensor_list = []
114
+ text_ids_list = []
115
+ sequence_plan = []
116
+
117
+ try:
118
+ data_item = json.loads(data)
119
+ raw_images = None
120
+ if 'image' in data_item:
121
+ if type(data_item['image']) == list:
122
+ raw_images = [
123
+ pil_img2rgb(Image.open(os.path.join(image_dir, image)))
124
+ for image in data_item['image']
125
+ ]
126
+ else:
127
+ raw_images = [
128
+ pil_img2rgb(Image.open(os.path.join(image_dir, data_item['image'])))
129
+ ]
130
+ elif 'video' in data_item:
131
+ raw_images = self.frame_sampler(os.path.join(image_dir, data_item['video']))
132
+ special_tokens = '<image>' * len(raw_images)
133
+ for item in data_item['conversations']:
134
+ if '<video>' in item['value']:
135
+ item['value'] = item['value'].replace('<video>', special_tokens)
136
+ break
137
+ else:
138
+ raise ValueError("Cannot find <video> in the conversation!")
139
+ except:
140
+ traceback.print_exc()
141
+ continue
142
+
143
+ if raw_images:
144
+ for raw_image in raw_images:
145
+ image_tensor = self.transform(raw_image, img_num=len(raw_images))
146
+ image_tensor_list.append(image_tensor)
147
+ height, width = image_tensor.shape[1:]
148
+ num_tokens += width * height // transform_stride ** 2
149
+
150
+ elements = self.change_format(data_item, len(image_tensor_list))
151
+
152
+ for item in elements:
153
+ if item['type'] == 'text':
154
+ text_data = item['text']
155
+ text_ids = self.tokenizer.encode(text_data)
156
+ if len(text_ids) > 0:
157
+ text_ids_list.append(text_ids)
158
+ num_tokens += len(text_ids)
159
+ current_plan = {
160
+ 'type': 'text',
161
+ 'enable_cfg': 0,
162
+ 'loss': item['has_loss'],
163
+ 'special_token_loss': 0,
164
+ 'special_token_label': None,
165
+ }
166
+ sequence_plan.append(current_plan)
167
+ elif item['type'] == 'image':
168
+ current_plan = {
169
+ 'type': 'vit_image',
170
+ 'enable_cfg': 0,
171
+ 'loss': 0,
172
+ 'special_token_loss': 0,
173
+ 'special_token_label': None,
174
+ }
175
+ sequence_plan.append(current_plan)
176
+
177
+ has_loss = [item['loss'] for item in sequence_plan]
178
+ if sum(has_loss) == 0:
179
+ print(f'No loss defined, skipped.')
180
+ continue
181
+
182
+ yield dict(
183
+ image_tensor_list=image_tensor_list,
184
+ text_ids_list=text_ids_list,
185
+ sequence_plan=sequence_plan,
186
+ num_tokens=num_tokens,
187
+ data_indexes={
188
+ "data_indexes": row_idx,
189
+ "worker_id": worker_id,
190
+ "dataset_name": self.dataset_name,
191
+ }
192
+ )
193
+
194
+ row_start_id = 0
195
+ print(f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}")