# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import pyarrow.parquet as pq

from ..distributed_iterable_dataset import DistributedIterableDataset
from ..parquet_utils import get_parquet_data_paths, init_arrow_pf_fs


class InterleavedBaseIterableDataset(DistributedIterableDataset):

    def _init_data(self):
        data = {
            'sequence_plan': [],
            'text_ids_list': [],
            'image_tensor_list': [],
            'num_tokens': 0,
        }
        return data

    def _add_text(self, data, text, need_loss, enable_cfg=True):
        text_ids = self.tokenizer.encode(text)
        data['num_tokens'] += len(text_ids)
        data['text_ids_list'].append(text_ids)
        data['sequence_plan'].append(
            {
                'type': 'text',
                'enable_cfg': int(enable_cfg),
                'loss': int(need_loss),
                'special_token_loss': 0,
                'special_token_label': None,
            }
        )
        return data

    def _add_image(self, data, image, need_loss, need_vae, need_vit, enable_cfg=True):
        assert need_loss or need_vae or need_vit

        if need_loss:
            data['sequence_plan'].append(
                {
                    'type': 'vae_image',
                    'enable_cfg': 0,
                    'loss': 1,
                    'special_token_loss': 0,
                    'special_token_label': None,
                }
            )

            image_tensor = self.transform(image)
            height, width = image_tensor.shape[1:]
            data['num_tokens'] += width * height // self.transform.stride ** 2
            data['image_tensor_list'].append(image_tensor)

        if need_vae:
            data['sequence_plan'].append(
                {
                    'type': 'vae_image',
                    'enable_cfg': int(enable_cfg),
                    'loss': 0,
                    'special_token_loss': 0,
                    'special_token_label': None,
                }
            )

            image_tensor = self.transform(image)
            height, width = image_tensor.shape[1:]
            data['num_tokens'] += width * height // self.transform.stride ** 2
            data['image_tensor_list'].append(image_tensor.clone())

        if need_vit:
            data['sequence_plan'].append(
                {
                    'type': 'vit_image',
                    'enable_cfg': int(enable_cfg),
                    'loss': 0,
                    'special_token_loss': 0,
                    'special_token_label': None,
                }
            )

            vit_image_tensor = self.vit_transform(image)
            height, width = vit_image_tensor.shape[1:]
            data['num_tokens'] += width * height // self.vit_transform.stride ** 2
            data['image_tensor_list'].append(vit_image_tensor)

        return data

    def _add_video(self, data, frames, frame_indexes, need_loss, need_vae, enable_cfg=True):
        assert int(need_loss) + int(need_vae) == 1

        if need_loss:
            for idx, (image, frame_idx) in enumerate(zip(frames, frame_indexes)):
                current_sequence_plan = {
                    'type': 'vae_image',
                    'enable_cfg': 0,
                    'loss': 1,
                    'special_token_loss': 0,
                    'special_token_label': None,
                    'split_start': idx == 0,
                    'split_end': idx == len(frames) - 1,
                }
                if idx < len(frame_indexes) - 1:
                    current_sequence_plan['frame_delta'] = frame_indexes[idx + 1] - frame_idx
                data['sequence_plan'].append(current_sequence_plan)

                image_tensor = self.transform(image)
                height, width = image_tensor.shape[1:]
                data['image_tensor_list'].append(image_tensor)
                data['num_tokens'] += width * height // self.transform.stride ** 2
        elif need_vae:
            for idx, (image, frame_idx) in enumerate(zip(frames, frame_indexes)):
                current_sequence_plan = {
                    'type': 'vae_image',
                    'enable_cfg': int(enable_cfg),
                    'loss': 0,
                    'special_token_loss': 0,
                    'special_token_label': None,
                    'split_start': idx == 0,
                    'split_end': idx == len(frames) - 1,
                }
                if idx < len(frame_indexes) - 1:
                    current_sequence_plan['frame_delta'] = frame_indexes[idx + 1] - frame_idx
                data['sequence_plan'].append(current_sequence_plan)

                image_tensor = self.transform(image)
                height, width = image_tensor.shape[1:]
                data['image_tensor_list'].append(image_tensor)
                data['num_tokens'] += width * height // self.transform.stride ** 2

        return data
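
# Illustrative sketch (not part of the original pipelines): the helpers above are meant to be
# chained from a subclass to lay out one interleaved sample. The helper name below and the
# edit-style layout (source image as ViT+VAE conditioning, instruction text, target image under
# loss) are assumptions for illustration only; `dataset` is any object that provides `tokenizer`,
# `transform`, and `vit_transform`, e.g. a subclass of the classes in this file.
def _example_build_edit_sample(dataset, source_image, instruction, target_image):
    data = dataset._init_data()
    # Source image is pure conditioning: ViT tokens for understanding, VAE tokens for editing,
    # no loss on either.
    data = dataset._add_image(data, source_image, need_loss=False, need_vae=True, need_vit=True)
    # The edit instruction is conditioning text; its tokens carry no loss.
    data = dataset._add_text(data, instruction, need_loss=False, enable_cfg=True)
    # The target image contributes the training loss on its VAE-latent tokens.
    data = dataset._add_image(data, target_image, need_loss=True, need_vae=False, need_vit=False)
    return data
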

class ParquetStandardIterableDataset(DistributedIterableDataset):

    def __init__(
        self, dataset_name, transform, tokenizer, vit_transform, data_dir_list, num_used_data,
        parquet_info, local_rank=0, world_size=1, num_workers=8, data_status=None,
    ):
        """
        data_dir_list: list of data directories containing parquet files.
        num_used_data: list of the number of sampled data paths for each data directory.
        vit_transform: input transform for the ViT model.
        """
        super().__init__(dataset_name, local_rank, world_size, num_workers)
        self.transform = transform
        self.vit_transform = vit_transform
        self.tokenizer = tokenizer
        self.data_status = data_status
        self.data_paths = self.get_data_paths(data_dir_list, num_used_data, parquet_info)
        self.set_epoch()

    def get_data_paths(self, data_dir_list, num_used_data, parquet_info):
        row_groups = []
        for data_dir, num_data_path in zip(data_dir_list, num_used_data):
            data_paths = get_parquet_data_paths([data_dir], [num_data_path])
            for data_path in data_paths:
                if data_path in parquet_info.keys():
                    num_row_groups = parquet_info[data_path]['num_row_groups']
                    for rg_idx in range(num_row_groups):
                        row_groups.append((data_path, rg_idx))
        return row_groups

    def parse_row(self, row):
        raise NotImplementedError

    def __iter__(self):
        file_paths_per_worker, worker_id = self.get_data_paths_per_worker()
        if self.data_status is not None:
            global_row_group_start_id = self.data_status[worker_id][0]
            row_start_id = self.data_status[worker_id][1] + 1
        else:
            global_row_group_start_id = 0
            row_start_id = 0

        print(
            f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
            f"resuming data at global_rg#{global_row_group_start_id}, row#{row_start_id}"
        )

        while True:
            file_paths_per_worker_ = file_paths_per_worker[global_row_group_start_id:]
            for global_row_group_idx, (parquet_file_path, row_group_id) in enumerate(
                file_paths_per_worker_, start=global_row_group_start_id
            ):
                fs = init_arrow_pf_fs(parquet_file_path)
                with fs.open_input_file(parquet_file_path) as f:
                    try:
                        fr = pq.ParquetFile(f)
                        df = fr.read_row_group(row_group_id).to_pandas()
                        df = df.iloc[row_start_id:]
                    except Exception as e:
                        print(f'Error {e} in rg#{row_group_id}, {parquet_file_path}')
                        continue

                    for row_idx, row in df.iterrows():
                        try:
                            data = self.parse_row(row)
                            if len(data) == 0:
                                continue
                            data['data_indexes'] = {
                                "data_indexes": [global_row_group_idx, row_idx],
                                "worker_id": worker_id,
                                "dataset_name": self.dataset_name,
                            }
                        except Exception as e:
                            print(f'Error {e} in rg#{row_group_id}, {parquet_file_path}')
                            continue
                        yield data

                    row_start_id = 0

            global_row_group_start_id = 0
            print(f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}")
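

# Illustrative sketch (not part of the original file): a minimal concrete dataset showing how the
# two classes above are expected to be combined. The class name and the parquet column names
# ('image' bytes, 'caption' text) are hypothetical; real datasets implement their own parse_row
# over their own schema.
class _ExampleT2IIterableDataset(InterleavedBaseIterableDataset, ParquetStandardIterableDataset):

    def parse_row(self, row):
        import io

        from PIL import Image

        image = Image.open(io.BytesIO(row['image'])).convert('RGB')
        data = self._init_data()
        # The caption conditions the image; its tokens carry no loss and may be dropped for CFG.
        data = self._add_text(data, row['caption'], need_loss=False, enable_cfg=True)
        # The image is the prediction target: loss on its VAE-latent tokens only.
        data = self._add_image(data, image, need_loss=True, need_vae=False, need_vit=False)
        return data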