# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import json
import os
import traceback

from PIL import Image, ImageFile, PngImagePlugin

from .data_utils import pil_img2rgb
from .distributed_iterable_dataset import DistributedIterableDataset

# Relax PIL safety limits: SFT corpora contain very large, occasionally
# truncated images and PNGs with oversized ancillary text chunks.
Image.MAX_IMAGE_PIXELS = 200000000
ImageFile.LOAD_TRUNCATED_IMAGES = True
MaximumDecompressedSize = 1024
MegaByte = 2 ** 20
PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte

# NOTE(review): in the reviewed source this token had been stripped to the
# empty string everywhere it is used, which cannot be right:
# `'' not in s` is always False and `s.split('')` raises ValueError.
# '<image>' is the conventional placeholder in LLaVA-style "conversations"
# data — TODO confirm the exact token against the tokenizer/config.
IMAGE_PLACEHOLDER = '<image>'


class SftJSONLIterableDataset(DistributedIterableDataset):
    """Iterable SFT dataset backed by JSONL files of conversation data.

    Each JSONL row holds a ``conversations`` list plus an optional ``image``
    field (a path or a list of paths) or ``video`` field (a path), resolved
    against the matching entry of ``data_dir_list``.
    """

    def __init__(
        self,
        dataset_name,
        transform,
        tokenizer,
        frame_sampler,
        jsonl_path_list,
        data_dir_list,
        num_used_data,
        local_rank=0,
        world_size=1,
        num_workers=8,
        data_status=None,
        shuffle_lines=False,
        shuffle_seed=0,
    ):
        """
        jsonl_path_list: list of jsonl file paths
        data_dir_list: list of image directories containing the images of each jsonl file
        num_used_data: list of number of sampled data points for each jsonl
        """
        super().__init__(dataset_name, local_rank, world_size, num_workers)
        self.transform = transform
        self.tokenizer = tokenizer
        self.frame_sampler = frame_sampler
        # Per-worker resume bookkeeping ({worker_id: last consumed row}) or None.
        self.data_status = data_status
        self.data_paths = self.get_data_paths(
            jsonl_path_list,
            data_dir_list,
            num_used_data,
            shuffle_lines,
            shuffle_seed,
        )
        self.set_epoch()

    def get_data_paths(
        self,
        jsonl_path_list,
        data_dir_list,
        num_used_data,
        shuffle_lines,
        shuffle_seed,
    ):
        """Load raw JSONL lines and pair each with its image directory.

        Returns a list of ``(raw_jsonl_line, image_dir)`` tuples, optionally
        shuffling each file's lines with a fixed seed and truncating each file
        to its ``num_used_data`` entry.
        """
        data_paths = []
        for jsonl_path, image_dir, num_data_point in zip(
            jsonl_path_list, data_dir_list, num_used_data
        ):
            with open(jsonl_path, 'r') as f:
                raw_data = f.readlines()
            if shuffle_lines:
                # self.rng is presumably set up by DistributedIterableDataset —
                # TODO confirm; re-seeding here keeps the shuffle reproducible.
                self.rng.seed(shuffle_seed)
                self.rng.shuffle(raw_data)
            raw_data = raw_data[:num_data_point]
            data_paths.extend([(json_data, image_dir) for json_data in raw_data])
        return data_paths

    def change_format(self, data, num_images):
        """Flatten ``data['conversations']`` into a list of element dicts.

        Human turns become loss-free text elements; image placeholder tokens
        inside a human turn are expanded into ``{'type': 'image'}`` elements
        (capped at ``num_images``). GPT turns become text elements with loss.

        NOTE(review): the placeholder literal was lost in the reviewed source
        (it was the empty string, making ``split('')`` raise ValueError); it is
        restored as IMAGE_PLACEHOLDER — confirm the token.
        """
        elements = []
        for conversation in data['conversations']:
            if conversation['from'] == 'human':
                if IMAGE_PLACEHOLDER not in conversation['value']:
                    elements.append({
                        'type': 'text',
                        'has_loss': 0,
                        'text': conversation['value'],
                    })
                else:
                    text_list = conversation['value'].split(IMAGE_PLACEHOLDER)
                    for idx, text in enumerate(text_list):
                        if text.strip() != '':
                            elements.append({
                                'type': 'text',
                                'has_loss': 0,
                                'text': text.strip(),
                            })
                        # A placeholder sits after every chunk except the last;
                        # emit at most num_images image slots.
                        if (idx != len(text_list) - 1) and (idx < num_images):
                            elements.append({'type': 'image', })
            elif conversation['from'] == 'gpt':
                elements.append({
                    'type': 'text',
                    'has_loss': 1,
                    'text': conversation['value'],
                })
        return elements

    def __iter__(self):
        """Yield training samples, resuming from ``data_status`` when given.

        NOTE(review): this method is TRUNCATED in the reviewed source. The
        visible prefix (worker sharding, resume, per-row image/video loading)
        is preserved below; the missing tail (sequence planning, tokenization,
        the actual ``yield``) is guarded by a NotImplementedError rather than
        guessed at.
        """
        data_paths_per_worker, worker_id = self.get_data_paths_per_worker()
        if self.data_status is not None:
            row_start_id = self.data_status[worker_id] + 1
        else:
            row_start_id = 0
        # Kept although unused in the visible prefix; presumably consumed by
        # the missing tail of this method.
        transform_stride = self.transform.stride

        print(
            f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
            f"resuming data at row#{row_start_id}"
        )

        while True:
            data_paths_per_worker_ = data_paths_per_worker[row_start_id:]
            for row_idx, (data, image_dir) in enumerate(data_paths_per_worker_, start=row_start_id):
                # Accumulators for the (missing) sample-building tail.
                num_tokens = 0
                image_tensor_list = []
                text_ids_list = []
                sequence_plan = []

                try:
                    data_item = json.loads(data)
                    raw_images = None
                    if 'image' in data_item:
                        if isinstance(data_item['image'], list):
                            raw_images = [
                                pil_img2rgb(Image.open(os.path.join(image_dir, image)))
                                for image in data_item['image']
                            ]
                        else:
                            raw_images = [
                                pil_img2rgb(Image.open(os.path.join(image_dir, data_item['image'])))
                            ]
                    elif 'video' in data_item:
                        raw_images = self.frame_sampler(os.path.join(image_dir, data_item['video']))
                        special_tokens = IMAGE_PLACEHOLDER * len(raw_images)
                        for item in data_item['conversations']:
                            # NOTE(review): the reviewed source breaks off
                            # mid-statement inside this loop ("if '...");
                            # presumably each turn's video tag is rewritten
                            # into `special_tokens` — not reconstructed.
                            pass

                    # ---- TRUNCATION GUARD ----------------------------------
                    # Everything past this point is missing from the reviewed
                    # source. Fail loudly instead of looping forever.
                    raise NotImplementedError(
                        'SftJSONLIterableDataset.__iter__ is truncated in the '
                        'reviewed source; restore the missing tail before use.'
                    )
                except NotImplementedError:
                    raise