diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..ccca00fe6938ae24dd37c2a8b58bcc17d9059051 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/arch.png filter=lfs diff=lfs merge=lfs -text
+assets/emerging_curves.png filter=lfs diff=lfs merge=lfs -text
+assets/teaser.webp filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..3abe8b44b1cde09d2ab9915cf506d6928439fa33
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,10 @@
+wandb
+__pycache__
+.vscode
+notebooks
+results
+*.ipynb_checkpoints
+eval_results
+tests
+.DS_Store
+gradio.sh
\ No newline at end of file
diff --git a/EVAL.md b/EVAL.md
new file mode 100644
index 0000000000000000000000000000000000000000..a2c268b0702881c2d3bd2158ec506ca0f1581106
--- /dev/null
+++ b/EVAL.md
@@ -0,0 +1,78 @@
+# VLM
+We follow [InternVL2](https://internvl.readthedocs.io/en/latest/internvl2.0/evaluation.html) to evaluate performance on MME, MMBench, MMMU, MMVet, MathVista, and MMVP.
+
+## Data preparation
+Please follow the [InternVL2](https://internvl.readthedocs.io/en/latest/get_started/eval_data_preparation.html) guide to prepare the corresponding data, then link it under `vlm`.
+
+The final directory structure is:
+```shell
+data
+├── MathVista
+├── mmbench
+├── mme
+├── MMMU
+├── mm-vet
+└── MMVP
+```
+
+## Evaluation
+
+Directly run `scripts/eval/run_eval_vlm.sh` to evaluate different benchmarks. The output will be saved in `$output_path`.
+- Set `$model_path` and `$output_path` to the checkpoint and log paths.
+- Increase `GPUS` to run the evaluation faster.
+- For MMBench, please use the official [evaluation server](https://mmbench.opencompass.org.cn/mmbench-submission).
+- For MMVet, please use the official [evaluation server](https://huggingface.co/spaces/whyu/MM-Vet_Evaluator).
+- For MathVista, please set `$openai_api_key` in `scripts/eval/run_eval_vlm.sh` and `your_api_url` in `eval/vlm/eval/mathvista/utilities.py`. The default GPT version is `gpt-4o-2024-11-20`.
+- For MMMU, we use CoT in the report, which improves accuracy by about 2%. For evaluating open-ended answers, we use GPT-4o as the judge.
+
+
+# GenEval
+We modify the code in [GenEval](https://github.com/djghosh13/geneval/tree/main) for faster evaluation.
+
+## Setup
+Install the following dependencies:
+```shell
+pip install open-clip-torch
+pip install clip-benchmark
+pip install --upgrade setuptools
+
+sudo pip install -U openmim
+sudo mim install mmengine mmcv-full==1.7.2
+
+git clone https://github.com/open-mmlab/mmdetection.git
+cd mmdetection; git checkout 2.x
+pip install -v -e .
+```
+
+Download Detector:
+```shell
+cd ./eval/gen/geneval
+mkdir model
+
+bash ./evaluation/download_models.sh ./model
+```
+
+## Evaluation
+Directly run `scripts/eval/run_geneval.sh` to evaluate GenEval. The output will be saved in `$output_path`.
+- Set `$model_path` and `$output_path` to the checkpoint and log paths.
+- Set `metadata_file` to `./eval/gen/geneval/prompts/evaluation_metadata.jsonl` for original GenEval prompts.
+
+
+# WISE
+We modify the code in [WISE](https://github.com/PKU-YuanGroup/WISE/tree/main) for faster evaluation.
+
+
+## Evaluation
+Directly run `scripts/eval/run_wise.sh` to evaluate WISE. The output will be saved in `$output_path`.
+- Set `$model_path` and `$output_path` to the checkpoint and log paths.
+- Set `$openai_api_key` in `scripts/eval/run_wise.sh` and `your_api_url` in `eval/gen/wise/gpt_eval_mp.py`. The default GPT version is `gpt-4o-2024-11-20`.
+- Use `think` for thinking mode.
+
+
+
+# GEdit-Bench
+Please follow [GEdit-Bench](https://github.com/stepfun-ai/Step1X-Edit/blob/main/GEdit-Bench/EVAL.md) for evaluation.
+
+
+# IntelligentBench
+TBD
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..f49a4e16e68b128803cc2dcea614603632b04eac
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
\ No newline at end of file
diff --git a/TRAIN.md b/TRAIN.md
new file mode 100644
index 0000000000000000000000000000000000000000..145a16685344a4f6e0feaf924ec5bd01f5b4b4e3
--- /dev/null
+++ b/TRAIN.md
@@ -0,0 +1,133 @@
+# Data preparation
+
+We provide data examples for **T2I**, **Editing**, and **VLM** tasks. The T2I dataset is generated using [FLUX.1‑dev](https://huggingface.co/black-forest-labs/FLUX.1-dev); the editing examples are randomly sampled from [SEED‑Data‑Edit‑Part3](https://huggingface.co/datasets/AILab-CVC/SEED-Data-Edit-Part2-3); and the VLM set is sourced from [LLaVA‑OneVision‑Data](https://huggingface.co/datasets/lmms-lab/LLaVA-OneVision-Data).
+
+We offer examples in both raw-image folder and parquet shard formats. For other data formats, you can use our dataset code as a template and extend it as needed.
+
+
+1. **Download the sample dataset**
+
+ ```bash
+ wget -O bagel_example.zip \
+ https://lf3-static.bytednsdoc.com/obj/eden-cn/nuhojubrps/bagel_example.zip
+ unzip bagel_example.zip -d /data
+ ```
+2. **Expected hierarchy**
+
+ ```text
+ bagel_example
+ ├── t2i/ # text-to-image (parquet)
+ ├── editing/ # image editing (parquet)
+ │ ├── seedxedit_multi/
+ │ └── parquet_info/
+ └── vlm/
+ ├── images/ # JPEG / PNG frames
+ └── llava_ov_si.jsonl # vision‑language SFT conversations
+ ```
+3. Edit every `your_data_path` placeholder in **`data/dataset_info.py`**.
+4. *(Optional)* Extend `DATASET_INFO` with your own parquet shards or JSONL files to mix in extra data; a sketch of the entry format is shown below.
+
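+For reference, here is a minimal, hypothetical sketch of what `DATASET_INFO` entries could look like, based on the keys that `data/dataset_base.py` reads (`data_dir`, plus optional `parquet_info_path`, `json_dir`, or `jsonl_path`) and the example hierarchy above; check `data/dataset_info.py` for the actual schema.
+
+```python
+# Hypothetical sketch only -- the authoritative schema lives in data/dataset_info.py.
+DATASET_INFO = {
+    "t2i_pretrain": {
+        "t2i": {
+            "data_dir": "/data/bagel_example/t2i",  # parquet shards
+        },
+    },
+    "unified_edit": {
+        "seedxedit_multi": {
+            "data_dir": "/data/bagel_example/editing/seedxedit_multi",
+            # The file name below is hypothetical; point it at the JSON under editing/parquet_info/.
+            "parquet_info_path": "/data/bagel_example/editing/parquet_info/seedxedit_multi.json",
+        },
+    },
+    "vlm_sft": {
+        "llava_ov": {
+            "data_dir": "/data/bagel_example/vlm/images",
+            "jsonl_path": "/data/bagel_example/vlm/llava_ov_si.jsonl",
+        },
+    },
+}
+```
+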
+---
+
+# Training
+
+The baseline full‑feature recipe looks like this (replace environment variables with real paths or values):
+
+```shell
+torchrun \
+ --nnodes=$num_nodes \
+ --node_rank=$node_rank \
+ --nproc_per_node=8 \
+ --master_addr=$master_addr \
+ --master_port=$master_port \
+ train/pretrain_unified_navit.py \
+ --dataset_config_file ./data/configs/example.yaml \
+ --llm_path $llm_path \
+ --vae_path $vae_path \
+ --vit_path $vit_path \
+ --use_flex True \
+ --resume_from $resume_from \
+ --results_dir $output_path \
+ --checkpoint_dir $ckpt_path \
+ --max_latent_size 64 # 32 for low-resolution pre-training
+```
+
+- **When fine-tuning BAGEL, please set `max_latent_size=64` to ensure the correct pretrained weights are loaded.**
+- The sum of `num_used_data` should be larger than `NUM_GPUS x NUM_WORKERS`.
+- For T2I-only fine-tuning, set `visual_und=False`; for VLM-only, set `visual_gen=False`.
+
+ You are encouraged to adjust any of these hyperparameters to fit your GPU budget and the scale of your dataset. If you encounter any issues, please open an issue for assistance. 🎉
+
+
+## Model config
+
+
+| Argument | Default | Description |
+| ---------------------------- | ------------------------------------------- | --------------------------------------------------------------- |
+| `llm_path` | `hf/Qwen2.5-0.5B-Instruct` | Language‑model backbone (HuggingFace repo or local folder). |
+| `vae_path` | `flux/vae/ae.safetensors` | Pre‑trained VAE checkpoint for latent diffusion. |
+| `vit_path` | `hf/siglip-so400m-14-980-flash-attn2-navit` | SigLIP ViT used for image understanding. |
+| `max_latent_size` | `32` | Maximum latent grid side; defines highest generable resolution. |
+| `latent_patch_size` | `2` | VAE pixels represented by one latent patch. |
+| `vit_max_num_patch_per_side` | `70` | Max ViT patches per image side after resizing. |
+| `text_cond_dropout_prob` | `0.1` | Probability to drop text conditioning while training. |
+| `vae_cond_dropout_prob` | `0.3` | Dropout on VAE latent inputs. |
+| `vit_cond_dropout_prob` | `0.3` | Dropout on visual features. |
+
+*(See `ModelArguments` for many more options.)*
+
+
+## Data config
+
+
+| Argument | Default | Description |
+| --------------------------- | --------------------------- | --------------------------------------------------------- |
+| `dataset_config_file` | `data/configs/example.yaml` | YAML that groups datasets and assigns sampling weights. |
+| `num_workers` | `4` | Background workers per rank for the PyTorch `DataLoader`. |
+| `prefetch_factor` | `2` | Batches pre‑fetched by each worker. |
+| `max_num_tokens_per_sample` | `16384` | Skip raw samples longer than this. |
+| `max_num_tokens` | `36864` | Hard cap for a packed batch (prevents OOM). |
+| `max_buffer_size` | `50` | Overflow buffer length for oversized samples. |
+| `data_seed` | `42` | Seed for reproducible shuffling and sampling. |
+
+
+## Training config
+
+| Argument | Default | Description |
+| -------------------------------------- | ---------------------- | ------------------------------------------------------ |
+| `total_steps` | `500_000` | Optimiser steps to run. |
+| `lr` | `1e-4` | Peak learning rate after warm‑up. |
+| `lr_scheduler` | `constant` | Learning‑rate schedule (`constant` or `cosine`). |
+| `warmup_steps` | `2000` | Linear warm‑up duration. |
+| `ema` | `0.9999` | Exponential moving‑average decay for model weights. |
+| `max_grad_norm` | `1.0` | Gradient‑clipping threshold. |
+| `save_every` | `2000` | Checkpoint frequency (steps). |
+| `visual_gen / visual_und` | `True` | Enable image generation / understanding branches. |
+| `freeze_llm / freeze_vit / freeze_vae` | `False / False / True` | Freeze selected modules to save VRAM or for ablations. |
+| `use_flex` | `True` (in example) | Enable FLEX packing for higher GPU utilisation. |
+| `sharding_strategy` | `HYBRID_SHARD` | FSDP sharding mode. |
+| `num_shard` | `8` | Parameter shards per rank in HYBRID mode. |
+
+**Distributed‑launch environment variables**
+
+| Var | Meaning |
+| ----------------------------- | --------------------------------- |
+| `num_nodes` / `node_rank` | Multi‑node orchestration indices. |
+| `nproc_per_node` | Number of GPUs per node. |
+| `master_addr` / `master_port` | NCCL rendezvous endpoint. |
+
+
+## Logging config
+
+
+| Argument | Default | Description |
+| ---------------- | --------------------- | ---------------------------------------------------- |
+| `results_dir` | `results` | Root directory for logs and metrics. |
+| `checkpoint_dir` | `results/checkpoints` | Checkpoints are saved here. |
+| `log_every` | `10` | Steps between console / W\&B logs. |
+| `wandb_project` | `bagel` | Weights & Biases project name. |
+| `wandb_name` | `run` | Run name inside the project. |
+| `wandb_offline` | `False` | Switch to offline mode (logs locally, sync later). |
+| `wandb_resume` | `allow` | Resumption policy if an existing run ID is detected. |
+
+> **Tip** Export `WANDB_API_KEY` before launching if you want online dashboards.
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd95c0517c810b6ab3648d1daf8d4d68b5c3dfde
--- /dev/null
+++ b/app.py
@@ -0,0 +1,505 @@
+import gradio as gr
+import numpy as np
+import os
+import torch
+import random
+
+from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
+from PIL import Image
+
+from data.data_utils import add_special_tokens, pil_img2rgb
+from data.transforms import ImageTransform
+from inferencer import InterleaveInferencer
+from modeling.autoencoder import load_ae
+from modeling.bagel.qwen2_navit import NaiveCache
+from modeling.bagel import (
+ BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM,
+ SiglipVisionConfig, SiglipVisionModel
+)
+from modeling.qwen2 import Qwen2Tokenizer
+
+
+# Model Initialization
+model_path = "/path/to/BAGEL-7B-MoT/weights"  # Download from https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT
+
+llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
+llm_config.qk_norm = True
+llm_config.tie_word_embeddings = False
+llm_config.layer_module = "Qwen2MoTDecoderLayer"
+
+vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json"))
+vit_config.rope = False
+vit_config.num_hidden_layers -= 1
+
+vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors"))
+
+config = BagelConfig(
+ visual_gen=True,
+ visual_und=True,
+ llm_config=llm_config,
+ vit_config=vit_config,
+ vae_config=vae_config,
+ vit_max_num_patch_per_side=70,
+ connector_act='gelu_pytorch_tanh',
+ latent_patch_size=2,
+ max_latent_size=64,
+)
+
+with init_empty_weights():
+ language_model = Qwen2ForCausalLM(llm_config)
+ vit_model = SiglipVisionModel(vit_config)
+ model = Bagel(language_model, vit_model, config)
+ model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)
+
+tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
+tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)
+
+vae_transform = ImageTransform(1024, 512, 16)
+vit_transform = ImageTransform(980, 224, 14)
+
+# Model Loading and Multi-GPU Inference Preparation
+device_map = infer_auto_device_map(
+ model,
+ max_memory={i: "80GiB" for i in range(torch.cuda.device_count())},
+ no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
+)
+
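+# Pin these small modules to the same device as the token embeddings so that their
+# inputs and outputs never have to cross devices during generation.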
+same_device_modules = [
+ 'language_model.model.embed_tokens',
+ 'time_embedder',
+ 'latent_pos_embed',
+ 'vae2llm',
+ 'llm2vae',
+ 'connector',
+ 'vit_pos_embed'
+]
+
+if torch.cuda.device_count() == 1:
+ first_device = device_map.get(same_device_modules[0], "cuda:0")
+ for k in same_device_modules:
+ if k in device_map:
+ device_map[k] = first_device
+ else:
+ device_map[k] = "cuda:0"
+else:
+ first_device = device_map.get(same_device_modules[0])
+ for k in same_device_modules:
+ if k in device_map:
+ device_map[k] = first_device
+
+model = load_checkpoint_and_dispatch(
+ model,
+ checkpoint=os.path.join(model_path, "ema.safetensors"),
+ device_map=device_map,
+ offload_buffers=True,
+ dtype=torch.bfloat16,
+ force_hooks=True,
+).eval()
+
+
+# Inferencer Preparation
+inferencer = InterleaveInferencer(
+ model=model,
+ vae_model=vae_model,
+ tokenizer=tokenizer,
+ vae_transform=vae_transform,
+ vit_transform=vit_transform,
+ new_token_ids=new_token_ids,
+)
+
+def set_seed(seed):
+ """Set random seeds for reproducibility"""
+ if seed > 0:
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ return seed
+
+# Text to Image function with thinking option and hyperparameters
+def text_to_image(prompt, show_thinking=False, cfg_text_scale=4.0, cfg_interval=0.4,
+ timestep_shift=3.0, num_timesteps=50,
+ cfg_renorm_min=1.0, cfg_renorm_type="global",
+ max_think_token_n=1024, do_sample=False, text_temperature=0.3,
+ seed=0, image_ratio="1:1"):
+ # Set seed for reproducibility
+ set_seed(seed)
+
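+    # image_shapes is (height, width); the longer side is fixed to 1024 (see the Image Ratio dropdown).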
+ if image_ratio == "1:1":
+ image_shapes = (1024, 1024)
+ elif image_ratio == "4:3":
+ image_shapes = (768, 1024)
+ elif image_ratio == "3:4":
+ image_shapes = (1024, 768)
+ elif image_ratio == "16:9":
+ image_shapes = (576, 1024)
+ elif image_ratio == "9:16":
+ image_shapes = (1024, 576)
+
+ # Set hyperparameters
+ inference_hyper = dict(
+ max_think_token_n=max_think_token_n if show_thinking else 1024,
+ do_sample=do_sample if show_thinking else False,
+ text_temperature=text_temperature if show_thinking else 0.3,
+ cfg_text_scale=cfg_text_scale,
+ cfg_interval=[cfg_interval, 1.0], # End fixed at 1.0
+ timestep_shift=timestep_shift,
+ num_timesteps=num_timesteps,
+ cfg_renorm_min=cfg_renorm_min,
+ cfg_renorm_type=cfg_renorm_type,
+ image_shapes=image_shapes,
+ )
+
+ # Call inferencer with or without think parameter based on user choice
+ result = inferencer(text=prompt, think=show_thinking, **inference_hyper)
+ return result["image"], result.get("text", None)
+
+
+# Image Understanding function with thinking option and hyperparameters
+def image_understanding(image: Image.Image, prompt: str, show_thinking=False,
+ do_sample=False, text_temperature=0.3, max_new_tokens=512):
+ if image is None:
+ return "Please upload an image."
+
+ if isinstance(image, np.ndarray):
+ image = Image.fromarray(image)
+
+ image = pil_img2rgb(image)
+
+ # Set hyperparameters
+ inference_hyper = dict(
+ do_sample=do_sample,
+ text_temperature=text_temperature,
+ max_think_token_n=max_new_tokens, # Set max_length
+ )
+
+ # Use show_thinking parameter to control thinking process
+ result = inferencer(image=image, text=prompt, think=show_thinking,
+ understanding_output=True, **inference_hyper)
+ return result["text"]
+
+
+# Image Editing function with thinking option and hyperparameters
+def edit_image(image: Image.Image, prompt: str, show_thinking=False, cfg_text_scale=4.0,
+ cfg_img_scale=2.0, cfg_interval=0.0,
+ timestep_shift=3.0, num_timesteps=50, cfg_renorm_min=1.0,
+ cfg_renorm_type="text_channel", max_think_token_n=1024,
+ do_sample=False, text_temperature=0.3, seed=0):
+ # Set seed for reproducibility
+ set_seed(seed)
+
+ if image is None:
+ return "Please upload an image.", ""
+
+ if isinstance(image, np.ndarray):
+ image = Image.fromarray(image)
+
+ image = pil_img2rgb(image)
+
+ # Set hyperparameters
+ inference_hyper = dict(
+ max_think_token_n=max_think_token_n if show_thinking else 1024,
+ do_sample=do_sample if show_thinking else False,
+ text_temperature=text_temperature if show_thinking else 0.3,
+ cfg_text_scale=cfg_text_scale,
+ cfg_img_scale=cfg_img_scale,
+ cfg_interval=[cfg_interval, 1.0], # End fixed at 1.0
+ timestep_shift=timestep_shift,
+ num_timesteps=num_timesteps,
+ cfg_renorm_min=cfg_renorm_min,
+ cfg_renorm_type=cfg_renorm_type,
+ )
+
+ # Include thinking parameter based on user choice
+ result = inferencer(image=image, text=prompt, think=show_thinking, **inference_hyper)
+ return result["image"], result.get("text", "")
+
+
+# Helper function to load example images
+def load_example_image(image_path):
+ try:
+ return Image.open(image_path)
+ except Exception as e:
+ print(f"Error loading example image: {e}")
+ return None
+
+
+# Gradio UI
+with gr.Blocks() as demo:
+ gr.Markdown("""
+
+

+
+""")
+
+ with gr.Tab("📝 Text to Image"):
+ txt_input = gr.Textbox(
+ label="Prompt",
+ value="A female cosplayer portraying an ethereal fairy or elf, wearing a flowing dress made of delicate fabrics in soft, mystical colors like emerald green and silver. She has pointed ears, a gentle, enchanting expression, and her outfit is adorned with sparkling jewels and intricate patterns. The background is a magical forest with glowing plants, mystical creatures, and a serene atmosphere."
+ )
+
+ with gr.Row():
+ show_thinking = gr.Checkbox(label="Thinking", value=False)
+
+ # Add hyperparameter controls in an accordion
+ with gr.Accordion("Inference Hyperparameters", open=False):
+            # Lay out the hyperparameters two per row
+ with gr.Group():
+ with gr.Row():
+ seed = gr.Slider(minimum=0, maximum=1000000, value=0, step=1,
+ label="Seed", info="0 for random seed, positive for reproducible results")
+ image_ratio = gr.Dropdown(choices=["1:1", "4:3", "3:4", "16:9", "9:16"],
+ value="1:1", label="Image Ratio",
+ info="The longer size is fixed to 1024")
+
+ with gr.Row():
+ cfg_text_scale = gr.Slider(minimum=1.0, maximum=8.0, value=4.0, step=0.1, interactive=True,
+ label="CFG Text Scale", info="Controls how strongly the model follows the text prompt (4.0-8.0)")
+ cfg_interval = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, step=0.1,
+ label="CFG Interval", info="Start of CFG application interval (end is fixed at 1.0)")
+
+ with gr.Row():
+ cfg_renorm_type = gr.Dropdown(choices=["global", "local", "text_channel"],
+ value="global", label="CFG Renorm Type",
+                                                  info="If the generated image is blurry, use 'global'")
+ cfg_renorm_min = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
+ label="CFG Renorm Min", info="1.0 disables CFG-Renorm")
+
+ with gr.Row():
+ num_timesteps = gr.Slider(minimum=10, maximum=100, value=50, step=5, interactive=True,
+ label="Timesteps", info="Total denoising steps")
+ timestep_shift = gr.Slider(minimum=1.0, maximum=5.0, value=3.0, step=0.5, interactive=True,
+ label="Timestep Shift", info="Higher values for layout, lower for details")
+
+ # Thinking parameters in a single row
+ thinking_params = gr.Group(visible=False)
+ with thinking_params:
+ with gr.Row():
+ do_sample = gr.Checkbox(label="Sampling", value=False, info="Enable sampling for text generation")
+ max_think_token_n = gr.Slider(minimum=64, maximum=4006, value=1024, step=64, interactive=True,
+ label="Max Think Tokens", info="Maximum number of tokens for thinking")
+ text_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.3, step=0.1, interactive=True,
+ label="Temperature", info="Controls randomness in text generation")
+
+ thinking_output = gr.Textbox(label="Thinking Process", visible=False)
+ img_output = gr.Image(label="Generated Image")
+ gen_btn = gr.Button("Generate")
+
+ # Dynamically show/hide thinking process box and parameters
+ def update_thinking_visibility(show):
+ return gr.update(visible=show), gr.update(visible=show)
+
+ show_thinking.change(
+ fn=update_thinking_visibility,
+ inputs=[show_thinking],
+ outputs=[thinking_output, thinking_params]
+ )
+
+ # Process function based on thinking option and hyperparameters
+ def process_text_to_image(prompt, show_thinking, cfg_text_scale,
+ cfg_interval, timestep_shift,
+ num_timesteps, cfg_renorm_min, cfg_renorm_type,
+ max_think_token_n, do_sample, text_temperature, seed, image_ratio):
+ image, thinking = text_to_image(
+ prompt, show_thinking, cfg_text_scale, cfg_interval,
+ timestep_shift, num_timesteps,
+ cfg_renorm_min, cfg_renorm_type,
+ max_think_token_n, do_sample, text_temperature, seed, image_ratio
+ )
+ return image, thinking if thinking else ""
+
+ gen_btn.click(
+ fn=process_text_to_image,
+ inputs=[
+ txt_input, show_thinking, cfg_text_scale,
+ cfg_interval, timestep_shift,
+ num_timesteps, cfg_renorm_min, cfg_renorm_type,
+ max_think_token_n, do_sample, text_temperature, seed, image_ratio
+ ],
+ outputs=[img_output, thinking_output]
+ )
+
+ with gr.Tab("🖌️ Image Edit"):
+ with gr.Row():
+ with gr.Column(scale=1):
+ edit_image_input = gr.Image(label="Input Image", value=load_example_image('test_images/women.jpg'))
+ edit_prompt = gr.Textbox(
+ label="Prompt",
+ value="She boards a modern subway, quietly reading a folded newspaper, wearing the same clothes."
+ )
+
+ with gr.Column(scale=1):
+ edit_image_output = gr.Image(label="Result")
+ edit_thinking_output = gr.Textbox(label="Thinking Process", visible=False)
+
+ with gr.Row():
+ edit_show_thinking = gr.Checkbox(label="Thinking", value=False)
+
+ # Add hyperparameter controls in an accordion
+ with gr.Accordion("Inference Hyperparameters", open=False):
+ with gr.Group():
+ with gr.Row():
+ edit_seed = gr.Slider(minimum=0, maximum=1000000, value=0, step=1, interactive=True,
+ label="Seed", info="0 for random seed, positive for reproducible results")
+ edit_cfg_text_scale = gr.Slider(minimum=1.0, maximum=8.0, value=4.0, step=0.1, interactive=True,
+ label="CFG Text Scale", info="Controls how strongly the model follows the text prompt")
+
+ with gr.Row():
+ edit_cfg_img_scale = gr.Slider(minimum=1.0, maximum=4.0, value=2.0, step=0.1, interactive=True,
+ label="CFG Image Scale", info="Controls how much the model preserves input image details")
+ edit_cfg_interval = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
+ label="CFG Interval", info="Start of CFG application interval (end is fixed at 1.0)")
+
+ with gr.Row():
+ edit_cfg_renorm_type = gr.Dropdown(choices=["global", "local", "text_channel"],
+ value="text_channel", label="CFG Renorm Type",
+                                                       info="If the generated image is blurry, use 'global'")
+ edit_cfg_renorm_min = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
+ label="CFG Renorm Min", info="1.0 disables CFG-Renorm")
+
+ with gr.Row():
+ edit_num_timesteps = gr.Slider(minimum=10, maximum=100, value=50, step=5, interactive=True,
+ label="Timesteps", info="Total denoising steps")
+ edit_timestep_shift = gr.Slider(minimum=1.0, maximum=10.0, value=3.0, step=0.5, interactive=True,
+ label="Timestep Shift", info="Higher values for layout, lower for details")
+
+
+ # Thinking parameters in a single row
+ edit_thinking_params = gr.Group(visible=False)
+ with edit_thinking_params:
+ with gr.Row():
+ edit_do_sample = gr.Checkbox(label="Sampling", value=False, info="Enable sampling for text generation")
+ edit_max_think_token_n = gr.Slider(minimum=64, maximum=4006, value=1024, step=64, interactive=True,
+ label="Max Think Tokens", info="Maximum number of tokens for thinking")
+ edit_text_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.3, step=0.1, interactive=True,
+ label="Temperature", info="Controls randomness in text generation")
+
+ edit_btn = gr.Button("Submit")
+
+ # Dynamically show/hide thinking process box for editing
+ def update_edit_thinking_visibility(show):
+ return gr.update(visible=show), gr.update(visible=show)
+
+ edit_show_thinking.change(
+ fn=update_edit_thinking_visibility,
+ inputs=[edit_show_thinking],
+ outputs=[edit_thinking_output, edit_thinking_params]
+ )
+
+ # Process editing with thinking option and hyperparameters
+ def process_edit_image(image, prompt, show_thinking, cfg_text_scale,
+ cfg_img_scale, cfg_interval,
+ timestep_shift, num_timesteps, cfg_renorm_min,
+ cfg_renorm_type, max_think_token_n, do_sample,
+ text_temperature, seed):
+ edited_image, thinking = edit_image(
+ image, prompt, show_thinking, cfg_text_scale, cfg_img_scale,
+ cfg_interval, timestep_shift,
+ num_timesteps, cfg_renorm_min, cfg_renorm_type,
+ max_think_token_n, do_sample, text_temperature, seed
+ )
+
+ return edited_image, thinking if thinking else ""
+
+ edit_btn.click(
+ fn=process_edit_image,
+ inputs=[
+ edit_image_input, edit_prompt, edit_show_thinking,
+ edit_cfg_text_scale, edit_cfg_img_scale, edit_cfg_interval,
+ edit_timestep_shift, edit_num_timesteps,
+ edit_cfg_renorm_min, edit_cfg_renorm_type,
+ edit_max_think_token_n, edit_do_sample, edit_text_temperature, edit_seed
+ ],
+ outputs=[edit_image_output, edit_thinking_output]
+ )
+
+ with gr.Tab("🖼️ Image Understanding"):
+ with gr.Row():
+ with gr.Column(scale=1):
+ img_input = gr.Image(label="Input Image", value=load_example_image('test_images/meme.jpg'))
+ understand_prompt = gr.Textbox(
+ label="Prompt",
+ value="Can someone explain what's funny about this meme??"
+ )
+
+ with gr.Column(scale=1):
+ txt_output = gr.Textbox(label="Result", lines=20)
+
+ with gr.Row():
+ understand_show_thinking = gr.Checkbox(label="Thinking", value=False)
+
+ # Add hyperparameter controls in an accordion
+ with gr.Accordion("Inference Hyperparameters", open=False):
+ with gr.Row():
+ understand_do_sample = gr.Checkbox(label="Sampling", value=False, info="Enable sampling for text generation")
+ understand_text_temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.05, interactive=True,
+ label="Temperature", info="Controls randomness in text generation (0=deterministic, 1=creative)")
+ understand_max_new_tokens = gr.Slider(minimum=64, maximum=4096, value=512, step=64, interactive=True,
+ label="Max New Tokens", info="Maximum length of generated text, including potential thinking")
+
+ img_understand_btn = gr.Button("Submit")
+
+ # Process understanding with thinking option and hyperparameters
+ def process_understanding(image, prompt, show_thinking, do_sample,
+ text_temperature, max_new_tokens):
+ result = image_understanding(
+ image, prompt, show_thinking, do_sample,
+ text_temperature, max_new_tokens
+ )
+ return result
+
+ img_understand_btn.click(
+ fn=process_understanding,
+ inputs=[
+ img_input, understand_prompt, understand_show_thinking,
+ understand_do_sample, understand_text_temperature, understand_max_new_tokens
+ ],
+ outputs=txt_output
+ )
+
+ gr.Markdown("""
+
+""")
+
+demo.launch(share=True)
\ No newline at end of file
diff --git a/assets/arch.png b/assets/arch.png
new file mode 100644
index 0000000000000000000000000000000000000000..b79cfbf9174926d4cdb20c89ab8944e49f051ed7
--- /dev/null
+++ b/assets/arch.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:28affbbfede911a75884bae4e8e1d5b897b8b450fa4c7d9b68818d05492b0967
+size 168377
diff --git a/assets/emerging_curves.png b/assets/emerging_curves.png
new file mode 100644
index 0000000000000000000000000000000000000000..eccc45b992ce0afde1923790f5179c95c4fa8f42
--- /dev/null
+++ b/assets/emerging_curves.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0c1ddd355742cddb52045ee59098305cc5de8174cb09afa019bb9afefd868733
+size 372807
diff --git a/assets/teaser.webp b/assets/teaser.webp
new file mode 100644
index 0000000000000000000000000000000000000000..e72964e965eed00f0e979d7e47c2dcdb742e65b9
--- /dev/null
+++ b/assets/teaser.webp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d679e69a1fbdb7f9abceb59d9bc3d29ab65b7e871ba48b59aec0a7f35defa558
+size 1104072
diff --git a/data/__init__.py b/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..943b7d3a7a676d9a500bf048d4307f4f803221f5
--- /dev/null
+++ b/data/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
\ No newline at end of file
diff --git a/data/configs/example.yaml b/data/configs/example.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a2e2d79565185f3115810b4e4d69492506b10c3a
--- /dev/null
+++ b/data/configs/example.yaml
@@ -0,0 +1,45 @@
+t2i_pretrain:
+ dataset_names:
+ - t2i
+ image_transform_args:
+ image_stride: 16
+ max_image_size: 1024
+ min_image_size: 512
+ is_mandatory: true
+  num_used_data: # The sum should be larger than NUM_GPUS x NUM_WORKERS
+ - 10
+ weight: 1
+
+unified_edit:
+ dataset_names:
+ - seedxedit_multi
+ image_transform_args:
+ image_stride: 16
+ max_image_size: 1024
+ min_image_size: 512
+ vit_image_transform_args:
+ image_stride: 14
+ max_image_size: 518
+ min_image_size: 224
+ is_mandatory: false
+ num_used_data:
+ - 10
+ weight: 1
+
+vlm_sft:
+ dataset_names:
+ - llava_ov
+ image_transform_args:
+ image_stride: 14
+ max_image_size: 980
+ min_image_size: 378
+ max_pixels: 2_007_040
+ frame_sampler_args:
+ max_num_frames: 12
+ min_num_frames: 8
+ is_mandatory: true
+ shuffle_lines: True
+ shuffle_seed: 0
+ num_used_data:
+ - 1000
+ weight: 1
diff --git a/data/data_utils.py b/data/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bab578bd69eaa1095ecdfd7d300e10f3c550e59c
--- /dev/null
+++ b/data/data_utils.py
@@ -0,0 +1,177 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+
+import math
+import random
+from PIL import Image
+
+import torch
+from torch.nn.attention.flex_attention import or_masks, and_masks
+
+
+def create_sparse_mask(document_lens, split_lens, attn_modes, device):
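+    # Builds a mask function (mask_mod) for flex attention: within each document (packed
+    # sample), tokens attend causally; 'full' and 'noise' splits additionally attend
+    # bidirectionally within their own split; and tokens of a 'noise' split are visible
+    # only to queries in that same split.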
+ def causal_mask(b, h, q_idx, kv_idx):
+ return q_idx >= kv_idx
+
+ def full_and_noise_mask(b, h, q_idx, kv_idx):
+ return (full_and_noise_seq_id[q_idx] == full_and_noise_seq_id[kv_idx]) & (full_and_noise_seq_id[q_idx] >= 0)
+
+ def remove_noise_mask(b, h, q_idx, kv_idx):
+ return (~((noise_seq_id[kv_idx] >= 0) & (noise_seq_id[q_idx] != noise_seq_id[kv_idx])))
+
+ def sample_mask(b, h, q_idx, kv_idx):
+ return document_id[q_idx] == document_id[kv_idx]
+
+ full_and_noise_tmp = []
+ noise_tmp = []
+
+ for i, (length, model) in enumerate(zip(split_lens, attn_modes)):
+ value = i if model in ['full', 'noise'] else -1
+ full_and_noise_tmp.extend([value] * length)
+ value_noise = i if model == 'noise' else -1
+ noise_tmp.extend([value_noise] * length)
+
+ full_and_noise_seq_id = torch.Tensor(full_and_noise_tmp).to(device)
+ noise_seq_id = torch.Tensor(noise_tmp).to(device)
+
+ document_id = torch.cat([torch.full((l,), i) for i, l in enumerate(document_lens, start=1)]).to(device)
+
+ return and_masks(or_masks(causal_mask, full_and_noise_mask), remove_noise_mask, sample_mask)
+
+
+def patchify(image, patch_size):
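+    # Split a (C, H, W) image into non-overlapping p x p patches and return a
+    # (num_patches, p*p*C) tensor, ordered row-major over the patch grid.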
+ p = patch_size
+ c, h, w = image.shape
+ assert h % p == 0 and w % p == 0
+ image = image.reshape(c, h // p, p, w // p, p)
+ image = torch.einsum("chpwq->hwpqc", image)
+ image = image.reshape(-1, p**2 * c)
+ return image
+
+
+def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side):
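+    # Position id = row * max_num_patches_per_side + col on the absolute patch grid,
+    # so larger images extrapolate to larger position ids.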
+ num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size
+ coords_h = torch.arange(0, num_patches_h)
+ coords_w = torch.arange(0, num_patches_w)
+ pos_ids = (coords_h[:, None] * max_num_patches_per_side + coords_w).flatten()
+ return pos_ids
+
+
+def get_flattened_position_ids_interpolate(img_h, img_w, patch_size, max_num_patches_per_side):
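+    # Bucket fractional patch coordinates onto the max_num_patches_per_side grid, so
+    # position ids always span the full grid regardless of image size.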
+ num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size
+ boundaries = torch.arange(1 / max_num_patches_per_side, 1.0, 1 / max_num_patches_per_side)
+ fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / num_patches_h)
+ fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / num_patches_w)
+ bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
+ bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
+ pos_ids = (bucket_coords_h[:, None] * max_num_patches_per_side + bucket_coords_w).flatten()
+ return pos_ids
+
+
+def prepare_attention_mask_per_sample(split_lens, attn_modes, device="cpu"):
+ """
+ nested_split_lens: A list of N lists of ints. Each int indicates the length of a split within
+ a sample, where each sample contains multiple splits with different attn modes.
+ nested_attn_modes: whether to use full attn in each split.
+ """
+ sample_len = sum(split_lens)
+ attention_mask = torch.zeros((sample_len, sample_len), dtype=torch.bool, device=device)
+
+ csum = 0
+ for s, attn_mode in zip(split_lens, attn_modes):
+ assert attn_mode in ['causal', 'full', 'noise']
+ if attn_mode == "causal":
+ attention_mask[csum:csum + s, csum:csum + s] = torch.ones((s, s), device=device).tril()
+ attention_mask[csum:csum + s, :csum] = 1
+ else:
+ attention_mask[csum:csum + s, csum:csum + s] = torch.ones((s, s))
+ attention_mask[csum:csum + s, :csum] = 1
+ csum += s
+
+ csum = 0
+ for s, attn_mode in zip(split_lens, attn_modes):
+ if attn_mode == "noise":
+ attention_mask[:, csum : csum + s] = torch.zeros((sample_len, s))
+ attention_mask[csum : csum + s, csum : csum + s] = torch.ones((s, s))
+ csum += s
+
+ attention_mask = torch.zeros_like(attention_mask, dtype=torch.float).masked_fill_(
+ ~attention_mask, float("-inf")
+ )
+
+ return attention_mask
+
+
+def split_integer_exp_decay(S, ng_sample_decay=1.0):
+ if ng_sample_decay == 1.0:
+ N = random.randint(1, S)
+ else:
+ base = (1 - ng_sample_decay) / (1 - math.pow(ng_sample_decay, S))
+ p = [base * math.pow(ng_sample_decay, i) for i in range(S)]
+ N = random.choices(list(range(1, S + 1)), p, k=1)[0]
+ cumsum = [0] + sorted(random.sample(range(1, S), N - 1)) + [S]
+ result = [cumsum[i+1] - cumsum[i] for i in range(len(cumsum) - 1)]
+ return result, cumsum
+
+
+def pil_img2rgb(image):
+ if image.mode == "RGBA" or image.info.get("transparency", None) is not None:
+ image = image.convert("RGBA")
+ white = Image.new(mode="RGB", size=image.size, color=(255, 255, 255))
+ white.paste(image, mask=image.split()[3])
+ image = white
+ else:
+ image = image.convert("RGB")
+
+ return image
+
+
+def add_special_tokens(tokenizer):
+ all_special_tokens = []
+ for k, v in tokenizer.special_tokens_map.items():
+ if isinstance(v, str):
+ all_special_tokens.append(v)
+ elif isinstance(v, list):
+ all_special_tokens += v
+
+ new_tokens = []
+
+ if '<|im_start|>' not in all_special_tokens:
+ new_tokens.append('<|im_start|>')
+
+ if '<|im_end|>' not in all_special_tokens:
+ new_tokens.append('<|im_end|>')
+
+ if '<|vision_start|>' not in all_special_tokens:
+ new_tokens.append('<|vision_start|>')
+
+ if '<|vision_end|>' not in all_special_tokens:
+ new_tokens.append('<|vision_end|>')
+
+ num_new_tokens = tokenizer.add_tokens(new_tokens)
+ bos_token_id = tokenizer.convert_tokens_to_ids('<|im_start|>')
+ eos_token_id = tokenizer.convert_tokens_to_ids('<|im_end|>')
+ start_of_image = tokenizer.convert_tokens_to_ids('<|vision_start|>')
+ end_of_image = tokenizer.convert_tokens_to_ids('<|vision_end|>')
+
+ new_token_ids = dict(
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ start_of_image=start_of_image,
+ end_of_image=end_of_image,
+ )
+
+ return tokenizer, new_token_ids, num_new_tokens
+
+
+def len2weight(x, loss_reduction='square'):
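+    # Per-token cross-entropy weight for a text span of x tokens, under the chosen loss reduction.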
+ if x == 0:
+ return x
+ if loss_reduction == 'token':
+ return 1
+ if loss_reduction == 'sample':
+ return 1 / x
+ if loss_reduction == 'square':
+ return 1 / (x ** 0.5)
+ raise NotImplementedError(loss_reduction)
diff --git a/data/dataset_base.py b/data/dataset_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..263eece94ec4afcd39b7baba28c40332dc79aadd
--- /dev/null
+++ b/data/dataset_base.py
@@ -0,0 +1,620 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+
+import random
+import json
+
+import numpy as np
+import torch
+
+from .data_utils import (
+ get_flattened_position_ids_interpolate,
+ get_flattened_position_ids_extrapolate,
+ len2weight,
+ patchify,
+ prepare_attention_mask_per_sample,
+)
+from .dataset_info import DATASET_INFO, DATASET_REGISTRY
+from .transforms import ImageTransform
+from .video_utils import FrameSampler
+
+
+class DataConfig:
+ def __init__(
+ self,
+ grouped_datasets,
+ text_cond_dropout_prob=0.1,
+ vit_cond_dropout_prob=0.4,
+ vae_cond_dropout_prob=0.1,
+ vae_image_downsample=16,
+ max_latent_size=32,
+ vit_patch_size=14,
+ max_num_patch_per_side=70,
+ ):
+ self.grouped_datasets = grouped_datasets
+ self.text_cond_dropout_prob = text_cond_dropout_prob
+ self.vit_cond_dropout_prob = vit_cond_dropout_prob
+ self.vit_patch_size = vit_patch_size
+ self.max_num_patch_per_side = max_num_patch_per_side
+ self.vae_cond_dropout_prob = vae_cond_dropout_prob
+ self.vae_image_downsample = vae_image_downsample
+ self.max_latent_size = max_latent_size
+
+
+class PackedDataset(torch.utils.data.IterableDataset):
+ def __init__(
+ self,
+ data_config,
+ tokenizer,
+ special_tokens,
+ local_rank,
+ world_size,
+ num_workers,
+ expected_num_tokens=32768,
+ max_num_tokens_per_sample=16384,
+ max_num_tokens=36864,
+ prefer_buffer_before=16384,
+ max_buffer_size=50,
+ interpolate_pos=False,
+ use_flex=False,
+ data_status=None,
+ ):
+ super().__init__()
+ self.expected_num_tokens = expected_num_tokens
+ self.max_num_tokens_per_sample = max_num_tokens_per_sample
+ self.prefer_buffer_before = prefer_buffer_before
+ self.max_num_tokens = max_num_tokens
+ self.max_buffer_size = max_buffer_size
+ self.tokenizer = tokenizer
+ self.local_rank = local_rank
+ self.world_size = world_size
+ self.num_workers = num_workers
+ self.use_flex = use_flex
+ for k, v in special_tokens.items():
+ setattr(self, k, v)
+
+ grouped_datasets, is_mandatory, grouped_weights = self.build_datasets(
+ data_config.grouped_datasets, data_status
+ )
+ self.grouped_datasets = grouped_datasets
+ self.dataset_iters = [iter(dataset) for dataset in grouped_datasets]
+ self.is_mandatory = is_mandatory
+ self.grouped_weights = grouped_weights
+ self.data_config = data_config
+ self.interpolate_pos = interpolate_pos
+ if self.interpolate_pos:
+ self.get_flattened_position_ids = get_flattened_position_ids_interpolate
+ else:
+ self.get_flattened_position_ids = get_flattened_position_ids_extrapolate
+
+ def build_datasets(self, datasets_metainfo, data_status):
+ datasets = []
+ is_mandatory = []
+ grouped_weights = []
+ for grouped_dataset_name, dataset_args in datasets_metainfo.items():
+ is_mandatory.append(dataset_args.pop('is_mandatory', False))
+ grouped_weights.append(dataset_args.pop('weight', 0.0))
+
+ if 'frame_sampler_args' in dataset_args.keys():
+ frame_sampler = FrameSampler(**dataset_args.pop('frame_sampler_args'))
+ dataset_args['frame_sampler'] = frame_sampler
+ if 'image_transform_args' in dataset_args.keys():
+ transform = ImageTransform(**dataset_args.pop('image_transform_args'))
+ dataset_args['transform'] = transform
+ if 'vit_image_transform_args' in dataset_args.keys():
+ vit_transform = ImageTransform(**dataset_args.pop('vit_image_transform_args'))
+ dataset_args['vit_transform'] = vit_transform
+
+ assert 'dataset_names' in dataset_args.keys()
+ dataset_names = dataset_args.pop('dataset_names')
+ dataset_args['data_dir_list'] = []
+ for item in dataset_names:
+ if self.local_rank == 0:
+ print(f'Preparing Dataset {grouped_dataset_name}/{item}')
+ meta_info = DATASET_INFO[grouped_dataset_name][item]
+ dataset_args['data_dir_list'].append(meta_info['data_dir'])
+
+ if "parquet_info_path" in meta_info.keys():
+ if 'parquet_info' not in dataset_args.keys():
+ dataset_args['parquet_info'] = {}
+ with open(meta_info['parquet_info_path'], 'r') as f:
+ parquet_info = json.load(f)
+ dataset_args['parquet_info'].update(parquet_info)
+
+ if 'json_dir' in meta_info.keys():
+ # parquet/tar with json
+ if 'json_dir_list' not in dataset_args.keys():
+ dataset_args['json_dir_list'] = [meta_info['json_dir']]
+ else:
+ dataset_args['json_dir_list'].append(meta_info['json_dir'])
+
+ if 'jsonl_path' in meta_info.keys():
+ # jsonl with jpeg
+ if 'jsonl_path_list' not in dataset_args.keys():
+ dataset_args['jsonl_path_list'] = [meta_info['jsonl_path']]
+ else:
+ dataset_args['jsonl_path_list'].append(meta_info['jsonl_path'])
+
+ resume_data_status = dataset_args.pop('resume_data_status', True)
+ if data_status is not None and grouped_dataset_name in data_status.keys() and resume_data_status:
+ data_status_per_group = data_status[grouped_dataset_name]
+ else:
+ data_status_per_group = None
+ dataset = DATASET_REGISTRY[grouped_dataset_name](
+ dataset_name=grouped_dataset_name,
+ tokenizer=self.tokenizer,
+ local_rank=self.local_rank,
+ world_size=self.world_size,
+ num_workers=self.num_workers,
+ data_status=data_status_per_group,
+ **dataset_args
+ )
+ datasets.append(dataset)
+
+ return datasets, is_mandatory, grouped_weights
+
+ def set_epoch(self, seed):
+ for dataset in self.grouped_datasets:
+ dataset.set_epoch(seed)
+
+ def set_sequence_status(self):
+ sequence_status = dict(
+ curr = 0,
+ sample_lens = list(),
+ packed_position_ids = list(),
+ nested_attention_masks = list(),
+ split_lens = list(),
+ attn_modes = list(),
+ packed_text_ids = list(),
+ packed_text_indexes = list(),
+ packed_label_ids = list(),
+ ce_loss_indexes = list(),
+ ce_loss_weights = list(),
+ vae_image_tensors = list(),
+ packed_latent_position_ids = list(),
+ vae_latent_shapes = list(),
+ packed_vae_token_indexes = list(),
+ packed_timesteps = list(),
+ mse_loss_indexes = list(),
+ packed_vit_tokens = list(),
+ vit_token_seqlens = list(),
+ packed_vit_position_ids = list(),
+ packed_vit_token_indexes = list(),
+ )
+ return sequence_status
+
+ def to_tensor(self, sequence_status):
+ data = dict(
+ sequence_length=sum(sequence_status['sample_lens']),
+ sample_lens=sequence_status['sample_lens'],
+ packed_text_ids=torch.tensor(sequence_status['packed_text_ids']),
+ packed_text_indexes=torch.tensor(sequence_status['packed_text_indexes']),
+ packed_position_ids=torch.tensor(sequence_status['packed_position_ids']),
+ )
+ if not self.use_flex:
+ data['nested_attention_masks'] = sequence_status['nested_attention_masks']
+ else:
+ sequence_len = data['sequence_length']
+ pad_len = self.max_num_tokens - sequence_len
+ data['split_lens'] = sequence_status['split_lens'] + [pad_len]
+ data['attn_modes'] = sequence_status['attn_modes'] + ['causal']
+ data['sample_lens'] += [pad_len]
+
+ # if the model has a convnet vae (e.g., as visual tokenizer)
+ if len(sequence_status['vae_image_tensors']) > 0:
+ image_tensors = sequence_status.pop('vae_image_tensors')
+ image_sizes = [item.shape for item in image_tensors]
+ max_image_size = [max(item) for item in list(zip(*image_sizes))]
+ padded_images = torch.zeros(size=(len(image_tensors), *max_image_size))
+ for i, image_tensor in enumerate(image_tensors):
+ padded_images[i, :, :image_tensor.shape[1], :image_tensor.shape[2]] = image_tensor
+
+ data['padded_images'] = padded_images
+ data['patchified_vae_latent_shapes'] = sequence_status['vae_latent_shapes']
+ data['packed_latent_position_ids'] = torch.cat(sequence_status['packed_latent_position_ids'], dim=0)
+ data['packed_vae_token_indexes'] = torch.tensor(sequence_status['packed_vae_token_indexes'])
+
+ # if the model has a vit (e.g., as visual tokenizer)
+ if len(sequence_status['packed_vit_tokens']) > 0:
+ data['packed_vit_tokens'] = torch.cat(sequence_status['packed_vit_tokens'], dim=0)
+ data['packed_vit_position_ids'] = torch.cat(sequence_status['packed_vit_position_ids'], dim=0)
+ data['packed_vit_token_indexes'] = torch.tensor(sequence_status['packed_vit_token_indexes'])
+ data['vit_token_seqlens'] = torch.tensor(sequence_status['vit_token_seqlens'])
+
+ # if the model is required to perform visual generation
+ if len(sequence_status['packed_timesteps']) > 0:
+ data['packed_timesteps'] = torch.tensor(sequence_status['packed_timesteps'])
+ data['mse_loss_indexes'] = torch.tensor(sequence_status['mse_loss_indexes'])
+
+ # if the model is required to perform text generation
+ if len(sequence_status['packed_label_ids']) > 0:
+ data['packed_label_ids'] = torch.tensor(sequence_status['packed_label_ids'])
+ data['ce_loss_indexes'] = torch.tensor(sequence_status['ce_loss_indexes'])
+ data['ce_loss_weights'] = torch.tensor(sequence_status['ce_loss_weights'])
+
+ return data
+
+ def __iter__(self):
+ total_weights = sum(self.grouped_weights)
+ assert total_weights > 0.0
+ group_cumprobs = [sum(self.grouped_weights[:i + 1]) / total_weights
+ for i in range(len(self.grouped_weights))]
+ sequence_status = self.set_sequence_status()
+ batch_data_indexes = []
+
+ buffer = []
+ while True:
+ # Ensure at least one sample from each group
+ if sequence_status['curr'] == 0:
+ for group_index, group_iter in enumerate(self.dataset_iters):
+ if self.is_mandatory[group_index]:
+ while True:
+ sample = next(group_iter)
+ # if a sample is too long, skip it
+ num_tokens = sample['num_tokens'] + 2 * len(sample['sequence_plan'])
+ if num_tokens < self.max_num_tokens_per_sample:
+ sequence_status = self.pack_sequence(sample, sequence_status)
+ batch_data_indexes.append(sample['data_indexes'])
+ break
+ else:
+ print(f"skip a sample with length {num_tokens}")
+ continue
+
+ if sequence_status['curr'] < self.prefer_buffer_before and len(buffer) > 0:
+ sample = buffer.pop(0)
+ sample_from_buffer = True
+ else:
+ # sample normally across all groups
+ n = random.random()
+ group_index = 0
+ for i, cumprob in enumerate(group_cumprobs):
+ if n < cumprob:
+ group_index = i
+ break
+ sample = next(self.dataset_iters[group_index])
+ sample_from_buffer = False
+
+ # if a sample is too long, skip it
+ num_tokens = sample['num_tokens'] + 2 * len(sample['sequence_plan'])
+ if num_tokens > self.max_num_tokens_per_sample:
+ print(f"skip a sample with length {num_tokens}")
+ continue
+
+ if sequence_status['curr'] + num_tokens > self.max_num_tokens:
+ if len(buffer) < self.max_buffer_size and not sample_from_buffer:
+ buffer.append(sample)
+ else:
+ print(f"Yielding data with length {sum(sequence_status['sample_lens'])}")
+ data = self.to_tensor(sequence_status)
+ data['batch_data_indexes'] = batch_data_indexes
+ yield data
+ sequence_status = self.set_sequence_status()
+ batch_data_indexes = []
+ continue
+
+ sequence_status = self.pack_sequence(sample, sequence_status)
+ batch_data_indexes.append(sample['data_indexes'])
+
+ if sequence_status['curr'] >= self.expected_num_tokens:
+ data = self.to_tensor(sequence_status)
+ data['batch_data_indexes'] = batch_data_indexes
+ yield data
+ sequence_status = self.set_sequence_status()
+ batch_data_indexes = []
+
+ def pack_sequence(self, sample, sequence_status):
+ image_tensor_list = sample['image_tensor_list']
+ text_ids_list = sample['text_ids_list']
+ sequence_plan = sample['sequence_plan']
+
+ split_lens, attn_modes = list(), list()
+ curr = sequence_status['curr']
+ curr_rope_id = 0
+ sample_lens = 0
+
+ for item in sequence_plan:
+ split_start = item.get('split_start', True)
+ if split_start:
+ curr_split_len = 0
+
+ if item['type'] == 'text':
+ text_ids = text_ids_list.pop(0)
+ if item['enable_cfg'] == 1 and random.random() < self.data_config.text_cond_dropout_prob:
+ continue
+
+ shifted_text_ids = [self.bos_token_id] + text_ids
+ sequence_status['packed_text_ids'].extend(shifted_text_ids)
+ sequence_status['packed_text_indexes'].extend(range(curr, curr + len(shifted_text_ids)))
+ if item['loss'] == 1:
+ sequence_status['ce_loss_indexes'].extend(range(curr, curr + len(shifted_text_ids)))
+ sequence_status['ce_loss_weights'].extend(
+ [len2weight(len(shifted_text_ids))] * len(shifted_text_ids)
+ )
+ sequence_status['packed_label_ids'].extend(text_ids + [self.eos_token_id])
+ curr += len(shifted_text_ids)
+ curr_split_len += len(shifted_text_ids)
+
+ # add a <|im_end|> token
+ sequence_status['packed_text_ids'].append(self.eos_token_id)
+ sequence_status['packed_text_indexes'].append(curr)
+ if item['special_token_loss'] == 1: # <|im_end|> may have loss
+ sequence_status['ce_loss_indexes'].append(curr)
+ sequence_status['ce_loss_weights'].append(1.0)
+ sequence_status['packed_label_ids'].append(item['special_token_label'])
+ curr += 1
+ curr_split_len += 1
+
+ # update sequence status
+ attn_modes.append("causal")
+ sequence_status['packed_position_ids'].extend(range(curr_rope_id, curr_rope_id + curr_split_len))
+ curr_rope_id += curr_split_len
+
+ elif item['type'] == 'vit_image':
+ image_tensor = image_tensor_list.pop(0)
+ if item['enable_cfg'] == 1 and random.random() < self.data_config.vit_cond_dropout_prob:
+ curr_rope_id += 1
+ continue
+
+ # add a <|startofimage|> token
+ sequence_status['packed_text_ids'].append(self.start_of_image)
+ sequence_status['packed_text_indexes'].append(curr)
+ curr += 1
+ curr_split_len += 1
+
+ # preprocess image
+ vit_tokens = patchify(image_tensor, self.data_config.vit_patch_size)
+ num_img_tokens = vit_tokens.shape[0]
+ sequence_status['packed_vit_token_indexes'].extend(range(curr, curr + num_img_tokens))
+ curr += num_img_tokens
+ curr_split_len += num_img_tokens
+
+ sequence_status['packed_vit_tokens'].append(vit_tokens)
+ sequence_status['vit_token_seqlens'].append(num_img_tokens)
+ sequence_status['packed_vit_position_ids'].append(
+ self.get_flattened_position_ids(
+ image_tensor.size(1), image_tensor.size(2),
+ self.data_config.vit_patch_size,
+ max_num_patches_per_side=self.data_config.max_num_patch_per_side
+ )
+ )
+
+ # add a <|endofimage|> token
+ sequence_status['packed_text_ids'].append(self.end_of_image)
+ sequence_status['packed_text_indexes'].append(curr)
+ if item['special_token_loss'] == 1: # <|endofimage|> may have loss
+ sequence_status['ce_loss_indexes'].append(curr)
+ sequence_status['ce_loss_weights'].append(1.0)
+ sequence_status['packed_label_ids'].append(item['special_token_label'])
+ curr += 1
+ curr_split_len += 1
+
+ # update sequence status
+ attn_modes.append("full")
+ sequence_status['packed_position_ids'].extend([curr_rope_id] * curr_split_len)
+ curr_rope_id += 1
+
+ elif item['type'] == 'vae_image':
+ image_tensor = image_tensor_list.pop(0)
+ if item['enable_cfg'] == 1 and random.random() < self.data_config.vae_cond_dropout_prob:
+ # FIXME fix vae dropout in video2video setting.
+ curr_rope_id += 1
+ continue
+
+ # add a <|startofimage|> token
+ sequence_status['packed_text_ids'].append(self.start_of_image)
+ sequence_status['packed_text_indexes'].append(curr)
+ curr += 1
+ curr_split_len += 1
+
+ # preprocess image
+ sequence_status['vae_image_tensors'].append(image_tensor)
+ sequence_status['packed_latent_position_ids'].append(
+ self.get_flattened_position_ids(
+ image_tensor.size(1), image_tensor.size(2),
+ self.data_config.vae_image_downsample,
+ max_num_patches_per_side=self.data_config.max_latent_size
+ )
+ )
+ H, W = image_tensor.shape[1:]
+ h = H // self.data_config.vae_image_downsample
+ w = W // self.data_config.vae_image_downsample
+ sequence_status['vae_latent_shapes'].append((h, w))
+
+ num_img_tokens = w * h
+ sequence_status['packed_vae_token_indexes'].extend(range(curr, curr + num_img_tokens))
+ if item['loss'] == 1:
+ sequence_status['mse_loss_indexes'].extend(range(curr, curr + num_img_tokens))
+ if split_start:
+ timestep = np.random.randn()
+ else:
+ timestep = float('-inf')
+
+ sequence_status['packed_timesteps'].extend([timestep] * num_img_tokens)
+ curr += num_img_tokens
+ curr_split_len += num_img_tokens
+
+ # add a <|endofimage|> token
+ sequence_status['packed_text_ids'].append(self.end_of_image)
+ sequence_status['packed_text_indexes'].append(curr)
+ # <|endofimage|> may have loss
+ if item['special_token_loss'] == 1:
+ sequence_status['ce_loss_indexes'].append(curr)
+ sequence_status['ce_loss_weights'].append(1.0)
+ sequence_status['packed_label_ids'].append(item['special_token_label'])
+ curr += 1
+ curr_split_len += 1
+
+ # update sequence status
+ if split_start:
+ if item['loss'] == 1 and 'frame_delta' not in item.keys():
+ attn_modes.append("noise")
+ else:
+ attn_modes.append("full")
+ sequence_status['packed_position_ids'].extend([curr_rope_id] * (num_img_tokens + 2))
+ if 'frame_delta' in item.keys():
+ curr_rope_id += item['frame_delta']
+ elif item['loss'] == 0:
+ curr_rope_id += 1
+
+ if item.get('split_end', True):
+ split_lens.append(curr_split_len)
+ sample_lens += curr_split_len
+
+ sequence_status['curr'] = curr
+ sequence_status['sample_lens'].append(sample_lens)
+ # prepare attention mask
+ if not self.use_flex:
+ sequence_status['nested_attention_masks'].append(
+ prepare_attention_mask_per_sample(split_lens, attn_modes)
+ )
+ else:
+ sequence_status['split_lens'].extend(split_lens)
+ sequence_status['attn_modes'].extend(attn_modes)
+
+ return sequence_status
+
+
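+# Illustrative walk-through of pack_sequence above (not executed anywhere): for a text-to-image
+# sample whose sequence_plan is [text (loss=0), vae_image (loss=1)] with 3 caption tokens and a
+# 2x2 latent grid, packing into an empty sequence_status yields roughly:
+#   packed_text_ids          : [bos, t0, t1, t2, eos, <|startofimage|>, <|endofimage|>]
+#   packed_text_indexes      : [0, 1, 2, 3, 4, 5, 10]
+#   packed_vae_token_indexes : [6, 7, 8, 9]          (equal to mse_loss_indexes, since loss=1)
+#   packed_position_ids      : [0, 1, 2, 3, 4, 5, 5, 5, 5, 5, 5]
+#   split_lens / attn_modes  : [5, 6] / ["causal", "noise"]
+# i.e. text tokens get increasing RoPE ids while all tokens of one image share a single RoPE id.
+
+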
+class SimpleCustomBatch:
+ def __init__(self, batch):
+ data = batch[0]
+ self.batch_data_indexes = data['batch_data_indexes']
+ self.sequence_length = data["sequence_length"]
+ self.sample_lens = data["sample_lens"]
+ self.packed_text_ids = data["packed_text_ids"]
+ self.packed_text_indexes = data["packed_text_indexes"]
+ self.packed_position_ids = data["packed_position_ids"]
+
+ self.use_flex = "nested_attention_masks" not in data.keys()
+
+ if self.use_flex:
+ self.split_lens = data["split_lens"]
+ self.attn_modes = data["attn_modes"]
+ else:
+ self.nested_attention_masks = data["nested_attention_masks"]
+
+ if "padded_images" in data.keys():
+ self.padded_images = data["padded_images"]
+ self.patchified_vae_latent_shapes = data["patchified_vae_latent_shapes"]
+ self.packed_latent_position_ids = data["packed_latent_position_ids"]
+ self.packed_vae_token_indexes = data["packed_vae_token_indexes"]
+
+ if "packed_vit_tokens" in data.keys():
+ self.packed_vit_tokens = data["packed_vit_tokens"]
+ self.packed_vit_position_ids = data["packed_vit_position_ids"]
+ self.packed_vit_token_indexes = data["packed_vit_token_indexes"]
+ self.vit_token_seqlens = data["vit_token_seqlens"]
+
+ if "packed_timesteps" in data.keys():
+ self.packed_timesteps = data["packed_timesteps"]
+ self.mse_loss_indexes = data["mse_loss_indexes"]
+
+ if "packed_label_ids" in data.keys():
+ self.packed_label_ids = data["packed_label_ids"]
+ self.ce_loss_indexes = data["ce_loss_indexes"]
+ self.ce_loss_weights = data["ce_loss_weights"]
+
+ def pin_memory(self):
+ self.packed_text_ids = self.packed_text_ids.pin_memory()
+ self.packed_text_indexes = self.packed_text_indexes.pin_memory()
+ self.packed_position_ids = self.packed_position_ids.pin_memory()
+
+ if not self.use_flex:
+ self.nested_attention_masks = [item.pin_memory() for item in self.nested_attention_masks]
+
+ if hasattr(self, 'padded_images'):
+ self.padded_images = self.padded_images.pin_memory()
+ self.packed_vae_token_indexes = self.packed_vae_token_indexes.pin_memory()
+ self.packed_latent_position_ids = self.packed_latent_position_ids.pin_memory()
+
+ if hasattr(self, 'packed_timesteps'):
+ self.packed_timesteps = self.packed_timesteps.pin_memory()
+ self.mse_loss_indexes = self.mse_loss_indexes.pin_memory()
+
+ if hasattr(self, 'packed_vit_tokens'):
+ self.packed_vit_tokens = self.packed_vit_tokens.pin_memory()
+ self.packed_vit_position_ids = self.packed_vit_position_ids.pin_memory()
+ self.packed_vit_token_indexes = self.packed_vit_token_indexes.pin_memory()
+ self.vit_token_seqlens = self.vit_token_seqlens.pin_memory()
+
+ if hasattr(self, 'packed_label_ids'):
+ self.packed_label_ids = self.packed_label_ids.pin_memory()
+ self.ce_loss_indexes = self.ce_loss_indexes.pin_memory()
+ self.ce_loss_weights = self.ce_loss_weights.pin_memory()
+
+ return self
+
+ def cuda(self, device):
+ self.packed_text_ids = self.packed_text_ids.to(device)
+ self.packed_text_indexes = self.packed_text_indexes.to(device)
+ self.packed_position_ids = self.packed_position_ids.to(device)
+
+ if not self.use_flex:
+ self.nested_attention_masks = [item.to(device) for item in self.nested_attention_masks]
+
+ if hasattr(self, 'padded_images'):
+ self.padded_images = self.padded_images.to(device)
+ self.packed_vae_token_indexes = self.packed_vae_token_indexes.to(device)
+ self.packed_latent_position_ids = self.packed_latent_position_ids.to(device)
+
+ if hasattr(self, 'packed_timesteps'):
+ self.packed_timesteps = self.packed_timesteps.to(device)
+ self.mse_loss_indexes = self.mse_loss_indexes.to(device)
+
+ if hasattr(self, 'packed_vit_tokens'):
+ self.packed_vit_tokens = self.packed_vit_tokens.to(device)
+ self.packed_vit_position_ids = self.packed_vit_position_ids.to(device)
+ self.packed_vit_token_indexes = self.packed_vit_token_indexes.to(device)
+ self.vit_token_seqlens = self.vit_token_seqlens.to(device)
+
+ if hasattr(self, 'packed_label_ids'):
+ self.packed_label_ids = self.packed_label_ids.to(device)
+ self.ce_loss_indexes = self.ce_loss_indexes.to(device)
+ self.ce_loss_weights = self.ce_loss_weights.to(device)
+
+ return self
+
+ def to_dict(self):
+ data = dict(
+ sequence_length = self.sequence_length,
+ sample_lens = self.sample_lens,
+ packed_text_ids = self.packed_text_ids,
+ packed_text_indexes = self.packed_text_indexes,
+ packed_position_ids = self.packed_position_ids,
+ batch_data_indexes = self.batch_data_indexes,
+ )
+
+ if not self.use_flex:
+ data['nested_attention_masks'] = self.nested_attention_masks
+ else:
+ data['split_lens'] = self.split_lens
+ data['attn_modes'] = self.attn_modes
+
+ if hasattr(self, 'padded_images'):
+ data['padded_images'] = self.padded_images
+ data['patchified_vae_latent_shapes'] = self.patchified_vae_latent_shapes
+ data['packed_latent_position_ids'] = self.packed_latent_position_ids
+ data['packed_vae_token_indexes'] = self.packed_vae_token_indexes
+
+ if hasattr(self, 'packed_vit_tokens'):
+ data['packed_vit_tokens'] = self.packed_vit_tokens
+ data['packed_vit_position_ids'] = self.packed_vit_position_ids
+ data['packed_vit_token_indexes'] = self.packed_vit_token_indexes
+ data['vit_token_seqlens'] = self.vit_token_seqlens
+
+ if hasattr(self, 'packed_timesteps'):
+ data['packed_timesteps'] = self.packed_timesteps
+ data['mse_loss_indexes'] = self.mse_loss_indexes
+
+ if hasattr(self, 'packed_label_ids'):
+ data['packed_label_ids'] = self.packed_label_ids
+ data['ce_loss_indexes'] = self.ce_loss_indexes
+ data['ce_loss_weights'] = self.ce_loss_weights
+
+ return data
+
+
+def collate_wrapper():
+ def collate_fn(batch):
+ return SimpleCustomBatch(batch)
+ return collate_fn
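+
+
+# Minimal usage sketch (illustrative; the packed dataset class defined above is referred to as
+# `PackedDataset` here -- adjust the name and constructor arguments to the actual implementation):
+#
+#     train_dataset = PackedDataset(...)      # yields one fully packed batch per iteration
+#     loader = torch.utils.data.DataLoader(
+#         train_dataset,
+#         batch_size=1,                       # each yielded item is already a complete batch
+#         collate_fn=collate_wrapper(),       # wraps it into a SimpleCustomBatch
+#         pin_memory=True,                    # uses SimpleCustomBatch.pin_memory()
+#     )
+#     for batch in loader:
+#         batch = batch.cuda(device)          # or keep it on CPU and call batch.to_dict()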
diff --git a/data/dataset_info.py b/data/dataset_info.py
new file mode 100644
index 0000000000000000000000000000000000000000..055e301279c1a57d5b037cfc809f72f6a0011d81
--- /dev/null
+++ b/data/dataset_info.py
@@ -0,0 +1,39 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+from .interleave_datasets import UnifiedEditIterableDataset
+from .t2i_dataset import T2IIterableDataset
+from .vlm_dataset import SftJSONLIterableDataset
+
+
+DATASET_REGISTRY = {
+ 't2i_pretrain': T2IIterableDataset,
+ 'vlm_sft': SftJSONLIterableDataset,
+ 'unified_edit': UnifiedEditIterableDataset,
+}
+
+
+DATASET_INFO = {
+ 't2i_pretrain': {
+ 't2i': {
+ 'data_dir': 'your_data_path/bagel_example/t2i', # path of the parquet files
+ 'num_files': 10, # number of data units to be sharded across all ranks and workers
+ 'num_total_samples': 1000, # number of total samples in the dataset
+ },
+ },
+ 'unified_edit':{
+ 'seedxedit_multi': {
+ 'data_dir': 'your_data_path/bagel_example/editing/seedxedit_multi',
+ 'num_files': 10,
+ 'num_total_samples': 1000,
+            "parquet_info_path": 'your_data_path/bagel_example/editing/parquet_info/seedxedit_multi_nas.json', # metadata about the parquet files (e.g., number of row groups per file)
+ },
+ },
+ 'vlm_sft': {
+ 'llava_ov': {
+ 'data_dir': 'your_data_path/bagel_example/vlm/images',
+ 'jsonl_path': 'your_data_path/bagel_example/vlm/llava_ov_si.jsonl',
+ 'num_total_samples': 1000
+ },
+ },
+}
\ No newline at end of file
diff --git a/data/distributed_iterable_dataset.py b/data/distributed_iterable_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..65668a4b97b7824caee2a6f02ce1e0999238b39c
--- /dev/null
+++ b/data/distributed_iterable_dataset.py
@@ -0,0 +1,58 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+import random
+import torch
+
+
+class DistributedIterableDataset(torch.utils.data.IterableDataset):
+ def __init__(self, dataset_name, local_rank=0, world_size=1, num_workers=8):
+ self.dataset_name = dataset_name
+ self.local_rank = local_rank
+ self.world_size = world_size
+ self.num_workers = num_workers
+ self.rng = random.Random()
+ self.data_paths = None
+
+ def get_data_paths(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def set_epoch(self, seed=42):
+ if self.data_paths is None:
+ return
+
+ if isinstance(self.data_paths[0], tuple):
+ data_paths = sorted(self.data_paths, key=lambda x: (x[0], x[1]))
+ elif isinstance(self.data_paths[0], str):
+ data_paths = sorted(self.data_paths)
+ else:
+ raise ValueError(f"Unknown data_paths type: {type(self.data_paths[0])}")
+
+ self.rng.seed(seed)
+ self.rng.shuffle(data_paths)
+
+ num_files_per_rank = len(data_paths) // self.world_size
+ local_start = self.local_rank * num_files_per_rank
+ local_end = (self.local_rank + 1) * num_files_per_rank
+ self.num_files_per_rank = num_files_per_rank
+ self.data_paths_per_rank = data_paths[local_start:local_end]
+
+ def get_data_paths_per_worker(self):
+ if self.data_paths is None:
+ return None
+
+ info = torch.utils.data.get_worker_info()
+ if info is None:
+ # Single worker: Use all files assigned to the rank
+ return self.data_paths_per_rank, 0
+
+ worker_id = info.id
+ num_files_per_worker = self.num_files_per_rank // info.num_workers
+ start = num_files_per_worker * worker_id
+ end = num_files_per_worker * (worker_id + 1)
+ data_paths_per_worker = self.data_paths_per_rank[start:end]
+
+ return data_paths_per_worker[::-1], worker_id
+
+ def __iter__(self):
+ raise NotImplementedError
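+
+
+# Illustrative sharding example: with 16 shuffled data paths, world_size=2 and num_workers=4,
+# set_epoch() assigns each rank 16 // 2 = 8 paths, and get_data_paths_per_worker() then hands each
+# dataloader worker 8 // 4 = 2 of those paths (returned in reverse order), so every (rank, worker)
+# pair reads a disjoint subset of the files.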
diff --git a/data/interleave_datasets/__init__.py b/data/interleave_datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..528e72ab18a4d86179409d5834d19f34466f8c66
--- /dev/null
+++ b/data/interleave_datasets/__init__.py
@@ -0,0 +1,5 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+from .edit_dataset import UnifiedEditIterableDataset
+
diff --git a/data/interleave_datasets/edit_dataset.py b/data/interleave_datasets/edit_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec72e2e4490855b3b20a099fb340c3b107358c91
--- /dev/null
+++ b/data/interleave_datasets/edit_dataset.py
@@ -0,0 +1,72 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+import io
+import random
+from PIL import Image, ImageFile, PngImagePlugin
+
+from .interleave_t2i_dataset import InterleavedBaseIterableDataset, ParquetStandardIterableDataset
+from ..data_utils import pil_img2rgb
+
+
+Image.MAX_IMAGE_PIXELS = 200000000
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+MaximumDecompressedSize = 1024
+MegaByte = 2 ** 20
+PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
+
+
+class UnifiedEditIterableDataset(InterleavedBaseIterableDataset, ParquetStandardIterableDataset):
+
+ def parse_row(self, row):
+ image_num = len(row["image_list"])
+ # randomly choose start and end, return [0, 1] when only two images
+ start_idx = random.choice(range(image_num - 1))
+ max_end = min(start_idx + 3, image_num)
+ end_idx = random.choice(range(start_idx + 1, max_end))
+
+ data = self._init_data()
+ data = self._add_image(
+ data,
+ pil_img2rgb(Image.open(io.BytesIO(row["image_list"][start_idx]))),
+ need_loss=False,
+ need_vae=True,
+ need_vit=True,
+ )
+
+        if end_idx - start_idx > 1 and random.random() < 0.5:  # concatenate multiple instructions
+ if end_idx == image_num - 1:
+ end_idx -= 1
+
+ instruction = ""
+ for idx in range(start_idx + 1, end_idx + 1):
+ instruction += random.choice(row["instruction_list"][idx-1]) + ". "
+ data = self._add_text(data, instruction.rstrip(), need_loss=False)
+ data = self._add_image(
+ data,
+ pil_img2rgb(Image.open(io.BytesIO(row["image_list"][end_idx]))),
+ need_loss=True,
+ need_vae=False,
+ need_vit=False,
+ )
+ else:
+ for idx in range(start_idx + 1, end_idx + 1):
+ instruction = random.choice(row["instruction_list"][idx-1])
+ data = self._add_text(data, instruction, need_loss=False)
+ if idx != end_idx:
+ data = self._add_image(
+ data,
+ pil_img2rgb(Image.open(io.BytesIO(row["image_list"][idx]))),
+ need_loss=True,
+ need_vae=True,
+ need_vit=True,
+ )
+ else:
+ data = self._add_image(
+ data,
+ pil_img2rgb(Image.open(io.BytesIO(row["image_list"][idx]))),
+ need_loss=True,
+ need_vae=False,
+ need_vit=False,
+ )
+ return data
diff --git a/data/interleave_datasets/interleave_t2i_dataset.py b/data/interleave_datasets/interleave_t2i_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..98c18436a2d41c0e84ad864b0b999cb12530695c
--- /dev/null
+++ b/data/interleave_datasets/interleave_t2i_dataset.py
@@ -0,0 +1,212 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+import pyarrow.parquet as pq
+
+from ..distributed_iterable_dataset import DistributedIterableDataset
+from ..parquet_utils import get_parquet_data_paths, init_arrow_pf_fs
+
+
+class InterleavedBaseIterableDataset(DistributedIterableDataset):
+
+ def _init_data(self):
+ data = {
+ 'sequence_plan': [],
+ 'text_ids_list': [],
+ 'image_tensor_list': [],
+ 'num_tokens': 0,
+ }
+ return data
+
+ def _add_text(self, data, text, need_loss, enable_cfg=True):
+ text_ids = self.tokenizer.encode(text)
+ data['num_tokens'] += len(text_ids)
+ data['text_ids_list'].append(text_ids)
+ data['sequence_plan'].append(
+ {
+ 'type': 'text',
+ 'enable_cfg': int(enable_cfg),
+ 'loss': int(need_loss),
+ 'special_token_loss': 0,
+ 'special_token_label': None,
+ }
+ )
+ return data
+
+ def _add_image(self, data, image, need_loss, need_vae, need_vit, enable_cfg=True):
+ assert need_loss or need_vae or need_vit
+
+ if need_loss:
+ data['sequence_plan'].append(
+ {
+ 'type': 'vae_image',
+ 'enable_cfg': 0,
+ 'loss': 1,
+ 'special_token_loss': 0,
+ 'special_token_label': None,
+ }
+ )
+
+ image_tensor = self.transform(image)
+ height, width = image_tensor.shape[1:]
+ data['num_tokens'] += width * height // self.transform.stride ** 2
+ data['image_tensor_list'].append(image_tensor)
+
+ if need_vae:
+ data['sequence_plan'].append(
+ {
+ 'type': 'vae_image',
+ 'enable_cfg': int(enable_cfg),
+ 'loss': 0,
+ 'special_token_loss': 0,
+ 'special_token_label': None,
+ }
+ )
+
+ image_tensor = self.transform(image)
+ height, width = image_tensor.shape[1:]
+ data['num_tokens'] += width * height // self.transform.stride ** 2
+ data['image_tensor_list'].append(image_tensor.clone())
+
+ if need_vit:
+ data['sequence_plan'].append(
+ {
+ 'type': 'vit_image',
+ 'enable_cfg': int(enable_cfg),
+ 'loss': 0,
+ 'special_token_loss': 0,
+ 'special_token_label': None,
+ },
+ )
+ vit_image_tensor = self.vit_transform(image)
+ height, width = vit_image_tensor.shape[1:]
+ data['num_tokens'] += width * height // self.vit_transform.stride ** 2
+ data['image_tensor_list'].append(vit_image_tensor)
+
+ return data
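+
+    # Illustrative note on _add_image: a single source image can be registered up to three times.
+    # For example, _add_image(data, img, need_loss=False, need_vae=True, need_vit=True) appends two
+    # entries to data['sequence_plan'] -- a 'vae_image' entry with loss=0 (the VAE-encoded condition)
+    # and a 'vit_image' entry (the ViT-encoded condition) -- plus the matching tensors in
+    # data['image_tensor_list'], while need_loss=True adds one more 'vae_image' entry with loss=1
+    # that serves as the generation target.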
+
+ def _add_video(self, data, frames, frame_indexes, need_loss, need_vae, enable_cfg=True):
+ assert int(need_loss) + int(need_vae) == 1
+
+ if need_loss:
+ for idx, (image, frame_idx) in enumerate(zip(frames, frame_indexes)):
+ current_sequence_plan = {
+ 'type': 'vae_image',
+ 'enable_cfg': 0,
+ 'loss': 1,
+ 'special_token_loss': 0,
+ 'special_token_label': None,
+ 'split_start': idx == 0,
+ 'split_end': idx == len(frames) - 1,
+ }
+ if idx < len(frame_indexes) - 1:
+ current_sequence_plan['frame_delta'] = frame_indexes[idx + 1] - frame_idx
+ data['sequence_plan'].append(current_sequence_plan)
+ image_tensor = self.transform(image)
+ height, width = image_tensor.shape[1:]
+ data['image_tensor_list'].append(image_tensor)
+ data['num_tokens'] += width * height // self.transform.stride ** 2
+
+ elif need_vae:
+ for idx, (image, frame_idx) in enumerate(zip(frames, frame_indexes)):
+ current_sequence_plan = {
+ 'type': 'vae_image',
+ 'enable_cfg': int(enable_cfg),
+ 'loss': 0,
+ 'special_token_loss': 0,
+ 'special_token_label': None,
+ 'split_start': idx == 0,
+ 'split_end': idx == len(frames) - 1,
+ }
+ if idx < len(frame_indexes) - 1:
+ current_sequence_plan['frame_delta'] = frame_indexes[idx + 1] - frame_idx
+ data['sequence_plan'].append(current_sequence_plan)
+ image_tensor = self.transform(image)
+ height, width = image_tensor.shape[1:]
+ data['image_tensor_list'].append(image_tensor)
+ data['num_tokens'] += width * height // self.transform.stride ** 2
+
+ return data
+
+
+class ParquetStandardIterableDataset(DistributedIterableDataset):
+
+ def __init__(
+ self, dataset_name, transform, tokenizer, vit_transform,
+ data_dir_list, num_used_data, parquet_info,
+ local_rank=0, world_size=1, num_workers=8, data_status=None,
+ ):
+ """
+        data_dir_list: list of data directories containing parquet files
+        num_used_data: number of data paths to sample from each data directory (one entry per directory)
+        vit_transform: input transform for the ViT model.
+ """
+ super().__init__(dataset_name, local_rank, world_size, num_workers)
+ self.transform = transform
+ self.vit_transform = vit_transform
+ self.tokenizer = tokenizer
+ self.data_status = data_status
+ self.data_paths = self.get_data_paths(data_dir_list, num_used_data, parquet_info)
+ self.set_epoch()
+
+ def get_data_paths(self, data_dir_list, num_used_data, parquet_info):
+ row_groups = []
+ for data_dir, num_data_path in zip(data_dir_list, num_used_data):
+ data_paths = get_parquet_data_paths([data_dir], [num_data_path])
+ for data_path in data_paths:
+ if data_path in parquet_info.keys():
+ num_row_groups = parquet_info[data_path]['num_row_groups']
+ for rg_idx in range(num_row_groups):
+ row_groups.append((data_path, rg_idx))
+ return row_groups
+
+ def parse_row(self, row):
+ raise NotImplementedError
+
+ def __iter__(self):
+ file_paths_per_worker, worker_id = self.get_data_paths_per_worker()
+ if self.data_status is not None:
+ global_row_group_start_id = self.data_status[worker_id][0]
+ row_start_id = self.data_status[worker_id][1] + 1
+ else:
+ global_row_group_start_id = 0
+ row_start_id = 0
+
+ print(
+ f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
+ f"resuming data at global_rg#{global_row_group_start_id}, row#{row_start_id}"
+ )
+
+ while True:
+ file_paths_per_worker_ = file_paths_per_worker[global_row_group_start_id:]
+ for global_row_group_idx, (parquet_file_path, row_group_id) in enumerate(
+ file_paths_per_worker_, start=global_row_group_start_id
+ ):
+ fs = init_arrow_pf_fs(parquet_file_path)
+ with fs.open_input_file(parquet_file_path) as f:
+ try:
+ fr = pq.ParquetFile(f)
+ df = fr.read_row_group(row_group_id).to_pandas()
+ df = df.iloc[row_start_id:]
+ except Exception as e:
+ print(f'Error {e} in rg#{row_group_id}, {parquet_file_path}')
+ continue
+
+ for row_idx, row in df.iterrows():
+ try:
+ data = self.parse_row(row)
+ if len(data) == 0:
+ continue
+ data['data_indexes'] = {
+ "data_indexes": [global_row_group_idx, row_idx],
+ "worker_id": worker_id,
+ "dataset_name": self.dataset_name,
+ }
+ except Exception as e:
+ print(f'Error {e} in rg#{row_group_id}, {parquet_file_path}')
+ continue
+ yield data
+
+ row_start_id = 0
+ global_row_group_start_id = 0
+ print(f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}")
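+
+
+# Note on `parquet_info` (structure inferred from get_data_paths above): it is expected to map each
+# parquet file path to metadata containing at least the row-group count, e.g.
+#     {"your_data_path/editing/seedxedit_multi/part-00000.parquet": {"num_row_groups": 4}, ...}
+# so that every (file, row_group) pair can be scheduled as an independent shard.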
diff --git a/data/parquet_utils.py b/data/parquet_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..12f23019c1821506f11c473122c30a0f05fed797
--- /dev/null
+++ b/data/parquet_utils.py
@@ -0,0 +1,90 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+
+import os
+import xml.etree.ElementTree as ET
+import subprocess
+import logging
+
+import pyarrow.fs as pf
+import torch.distributed as dist
+
+logger = logging.getLogger(__name__)
+
+
+def get_parquet_data_paths(data_dir_list, num_sampled_data_paths, rank=0, world_size=1):
+ num_data_dirs = len(data_dir_list)
+ if world_size > 1:
+ chunk_size = (num_data_dirs + world_size - 1) // world_size
+ start_idx = rank * chunk_size
+ end_idx = min(start_idx + chunk_size, num_data_dirs)
+ local_data_dir_list = data_dir_list[start_idx:end_idx]
+ local_num_sampled_data_paths = num_sampled_data_paths[start_idx:end_idx]
+ else:
+ local_data_dir_list = data_dir_list
+ local_num_sampled_data_paths = num_sampled_data_paths
+
+ local_data_paths = []
+ for data_dir, num_data_path in zip(local_data_dir_list, local_num_sampled_data_paths):
+ if data_dir.startswith("hdfs://"):
+ files = hdfs_ls_cmd(data_dir)
+ data_paths_per_dir = [
+ file for file in files if file.endswith(".parquet")
+ ]
+ else:
+ files = os.listdir(data_dir)
+ data_paths_per_dir = [
+ os.path.join(data_dir, name)
+ for name in files
+ if name.endswith(".parquet")
+ ]
+ repeat = num_data_path // len(data_paths_per_dir)
+ data_paths_per_dir = data_paths_per_dir * (repeat + 1)
+ local_data_paths.extend(data_paths_per_dir[:num_data_path])
+
+ if world_size > 1:
+ gather_list = [None] * world_size
+ dist.all_gather_object(gather_list, local_data_paths)
+
+ combined_chunks = []
+ for chunk_list in gather_list:
+ if chunk_list is not None:
+ combined_chunks.extend(chunk_list)
+ else:
+ combined_chunks = local_data_paths
+
+ return combined_chunks
+
+
+# NOTE: customize this function for your cluster
+def get_hdfs_host():
+ return "hdfs://xxx"
+
+
+# NOTE: customize this function for your cluster
+def get_hdfs_block_size():
+ return 134217728
+
+
+# NOTE: customize this function for your cluster
+def get_hdfs_extra_conf():
+ return None
+
+
+def init_arrow_pf_fs(parquet_file_path):
+ if parquet_file_path.startswith("hdfs://"):
+ fs = pf.HadoopFileSystem(
+ host=get_hdfs_host(),
+ port=0,
+ buffer_size=get_hdfs_block_size(),
+ extra_conf=get_hdfs_extra_conf(),
+ )
+ else:
+ fs = pf.LocalFileSystem()
+ return fs
+
+
+def hdfs_ls_cmd(dir):
+    result = subprocess.run(["hdfs", "dfs", "-ls", dir], capture_output=True, text=True).stdout
+ return ['hdfs://' + i.split('hdfs://')[-1].strip() for i in result.split('\n') if 'hdfs://' in i]
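+
+
+# Minimal usage sketch (illustrative): the returned filesystem object is used the same way for
+# local and hdfs:// paths, e.g.
+#
+#     import pyarrow.parquet as pq
+#     path = "your_data_path/part-00000.parquet"
+#     fs = init_arrow_pf_fs(path)
+#     with fs.open_input_file(path) as f:
+#         num_row_groups = pq.ParquetFile(f).num_row_groups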
diff --git a/data/t2i_dataset.py b/data/t2i_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..81e1398400c8493eb6af74c38ae087e443f1060c
--- /dev/null
+++ b/data/t2i_dataset.py
@@ -0,0 +1,128 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+import io
+import json
+import pyarrow.parquet as pq
+import random
+from PIL import Image
+
+from .data_utils import pil_img2rgb
+from .distributed_iterable_dataset import DistributedIterableDataset
+from .parquet_utils import get_parquet_data_paths, init_arrow_pf_fs
+
+Image.MAX_IMAGE_PIXELS = 20_000_000
+
+
+class T2IIterableDataset(DistributedIterableDataset):
+ def __init__(
+ self, dataset_name, transform, tokenizer, data_dir_list, num_used_data,
+ local_rank=0, world_size=1, num_workers=8, data_status=None,
+ ):
+ """
+        data_dir_list: list of data directories containing parquet files
+        num_used_data: number of data paths to sample from each data directory (one entry per directory)
+ """
+ super().__init__(dataset_name, local_rank, world_size, num_workers)
+ self.transform = transform
+ self.tokenizer = tokenizer
+ self.data_status = data_status
+ self.data_paths = self.get_data_paths(data_dir_list, num_used_data)
+ self.set_epoch()
+
+ def get_data_paths(self, data_dir_list, num_used_data):
+ return get_parquet_data_paths(data_dir_list, num_used_data)
+
+ def __iter__(self):
+ data_paths_per_worker, worker_id = self.get_data_paths_per_worker()
+ if self.data_status is not None:
+ parquet_start_id = self.data_status[worker_id][0]
+ row_group_start_id = self.data_status[worker_id][1]
+ row_start_id = self.data_status[worker_id][2] + 1
+ else:
+ parquet_start_id = 0
+ row_group_start_id = 0
+ row_start_id = 0
+ transform_stride = self.transform.stride
+
+ print(
+ f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
+ f"resuming data at parquet#{parquet_start_id}, rg#{row_group_start_id}, row#{row_start_id}"
+ )
+
+ while True:
+ data_paths_per_worker_ = data_paths_per_worker[parquet_start_id:]
+ for parquet_idx, parquet_file_path in enumerate(data_paths_per_worker_, start=parquet_start_id):
+ fs = init_arrow_pf_fs(parquet_file_path)
+ with fs.open_input_file(parquet_file_path) as f:
+ fr = pq.ParquetFile(f)
+ row_group_ids = list(range(fr.num_row_groups))
+ row_group_ids_ = row_group_ids[row_group_start_id:]
+
+ for row_group_id in row_group_ids_:
+ df = fr.read_row_group(row_group_id).to_pandas()
+ df = df.iloc[row_start_id:]
+
+ for row_idx, row in df.iterrows():
+ num_tokens = 0
+ try:
+ image_byte = row['image']
+ image = pil_img2rgb(Image.open(io.BytesIO(image_byte)))
+ except Exception as e:
+ print(f'Error: {e} in rg#{row_group_id}, {parquet_file_path}')
+ continue
+ image_tensor = self.transform(image)
+ height, width = image_tensor.shape[1:]
+ num_tokens += width * height // transform_stride ** 2
+
+ try:
+ caption_dict = row['captions']
+ caption_dict = json.loads(caption_dict)
+ except Exception as e:
+ print(f'Error: {e} in rg#{row_group_id}, {parquet_file_path}')
+ continue
+
+ caps_token = [self.tokenizer.encode(v) for _, v in caption_dict.items()]
+ if len(caps_token) == 0:
+ print(f'no caption in rg#{row_group_id}, {parquet_file_path}')
+ caption_token = self.tokenizer.encode(' ')
+ else:
+ caption_token = random.choice(caps_token)
+
+ sequence_plan, text_ids_list = [], []
+ text_ids = caption_token
+ num_tokens += len(caption_token)
+ text_ids_list.append(text_ids)
+ sequence_plan.append({
+ 'type': 'text',
+ 'enable_cfg': 1,
+ 'loss': 0,
+ 'special_token_loss': 0,
+ 'special_token_label': None,
+ })
+
+ sequence_plan.append({
+ 'type': 'vae_image',
+ 'enable_cfg': 0,
+ 'loss': 1,
+ 'special_token_loss': 0,
+ 'special_token_label': None,
+ })
+
+ sample = dict(
+ image_tensor_list=[image_tensor],
+ text_ids_list=text_ids_list,
+ num_tokens=num_tokens,
+ sequence_plan=sequence_plan,
+ data_indexes={
+ "data_indexes": [parquet_idx, row_group_id, row_idx],
+ "worker_id": worker_id,
+ "dataset_name": self.dataset_name,
+ }
+ )
+ yield sample
+
+ row_start_id = 0
+ row_group_start_id = 0
+ parquet_start_id = 0
+ print(f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}")
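+
+
+# Expected row format (inferred from the parsing logic above): each parquet row provides an 'image'
+# column with encoded image bytes and a 'captions' column holding a JSON-encoded dict of caption
+# variants, e.g. '{"short": "a cat", "detailed": "a small grey cat sitting on a sofa"}'; one caption
+# is sampled at random for each row.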
diff --git a/data/transforms.py b/data/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b0d415a019ac027a5deeeddf8647439c9a6830f
--- /dev/null
+++ b/data/transforms.py
@@ -0,0 +1,287 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+import random
+from PIL import Image
+
+import cv2
+import numpy as np
+import torch
+from torchvision import transforms
+from torchvision.transforms import functional as F
+from torchvision.transforms import InterpolationMode
+
+
+class MaxLongEdgeMinShortEdgeResize(torch.nn.Module):
+ """Resize the input image so that its longest side and shortest side are within a specified range,
+ ensuring that both sides are divisible by a specified stride.
+
+ Args:
+ max_size (int): Maximum size for the longest edge of the image.
+ min_size (int): Minimum size for the shortest edge of the image.
+ stride (int): Value by which the height and width of the image must be divisible.
+ max_pixels (int): Maximum pixels for the full image.
+ interpolation (InterpolationMode): Desired interpolation enum defined by
+            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BICUBIC``.
+ If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
+ ``InterpolationMode.BILINEAR``, and ``InterpolationMode.BICUBIC`` are supported.
+ The corresponding Pillow integer constants, e.g., ``PIL.Image.BILINEAR`` are also accepted.
+ antialias (bool, optional): Whether to apply antialiasing (default is True).
+ """
+
+ def __init__(
+ self,
+ max_size: int,
+ min_size: int,
+ stride: int,
+ max_pixels: int,
+ interpolation=InterpolationMode.BICUBIC,
+ antialias=True
+ ):
+ super().__init__()
+ self.max_size = max_size
+ self.min_size = min_size
+ self.stride = stride
+ self.max_pixels = max_pixels
+ self.interpolation = interpolation
+ self.antialias = antialias
+
+ def _make_divisible(self, value, stride):
+ """Ensure the value is divisible by the stride."""
+ return max(stride, int(round(value / stride) * stride))
+
+ def _apply_scale(self, width, height, scale):
+ new_width = round(width * scale)
+ new_height = round(height * scale)
+ new_width = self._make_divisible(new_width, self.stride)
+ new_height = self._make_divisible(new_height, self.stride)
+ return new_width, new_height
+
+ def forward(self, img, img_num=1):
+ """
+ Args:
+ img (PIL Image): Image to be resized.
+            img_num (int): Number of images packed together; the per-image pixel budget becomes max_pixels / img_num.
+ Returns:
+ PIL Image or Tensor: Rescaled image with divisible dimensions.
+ """
+ if isinstance(img, torch.Tensor):
+ height, width = img.shape[-2:]
+ else:
+ width, height = img.size
+
+ scale = min(self.max_size / max(width, height), 1.0)
+ scale = max(scale, self.min_size / min(width, height))
+ new_width, new_height = self._apply_scale(width, height, scale)
+
+ # Ensure the number of pixels does not exceed max_pixels
+ if new_width * new_height > self.max_pixels / img_num:
+ scale = self.max_pixels / img_num / (new_width * new_height)
+ new_width, new_height = self._apply_scale(new_width, new_height, scale)
+
+ # Ensure longest edge does not exceed max_size
+ if max(new_width, new_height) > self.max_size:
+ scale = self.max_size / max(new_width, new_height)
+ new_width, new_height = self._apply_scale(new_width, new_height, scale)
+
+ return F.resize(img, (new_height, new_width), self.interpolation, antialias=self.antialias)
+
+
+class ImageTransform:
+ def __init__(
+ self,
+ max_image_size,
+ min_image_size,
+ image_stride,
+ max_pixels=14*14*9*1024,
+ image_mean=[0.5, 0.5, 0.5],
+ image_std=[0.5, 0.5, 0.5]
+ ):
+ self.stride = image_stride
+
+ self.resize_transform = MaxLongEdgeMinShortEdgeResize(
+ max_size=max_image_size,
+ min_size=min_image_size,
+ stride=image_stride,
+ max_pixels=max_pixels,
+ )
+ self.to_tensor_transform = transforms.ToTensor()
+ self.normalize_transform = transforms.Normalize(mean=image_mean, std=image_std, inplace=True)
+
+ def __call__(self, img, img_num=1):
+ img = self.resize_transform(img, img_num=img_num)
+ img = self.to_tensor_transform(img)
+ img = self.normalize_transform(img)
+ return img
+
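+# Minimal usage sketch (illustrative; the size arguments below are assumptions, not the training
+# configuration):
+#
+#     transform = ImageTransform(max_image_size=980, min_image_size=378, image_stride=14)
+#     tensor = transform(Image.open("example.jpg").convert("RGB"))
+#     # -> float tensor of shape (3, H, W), H and W divisible by 14, values normalized to [-1, 1]
+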
+
+def decolorization(image):
+ gray_image = image.convert('L')
+    # merging three copies of the gray channel only makes sense for 3-band RGB images
+    return Image.merge('RGB', [gray_image] * 3) if image.mode == 'RGB' else gray_image
+
+
+def downscale(image, scale_factor):
+ new_width = int(round(image.width * scale_factor))
+ new_height = int(round(image.height * scale_factor))
+ new_width = max(1, new_width)
+ new_height = max(1, new_height)
+ return image.resize((new_width, new_height), resample=Image.BICUBIC)
+
+
+def crop(image, crop_factors):
+ target_h, target_w = crop_factors
+ img_w, img_h = image.size
+
+ if target_h > img_h or target_w > img_w:
+ raise ValueError("Crop size exceeds image dimensions")
+
+ x = random.randint(0, img_w - target_w)
+ y = random.randint(0, img_h - target_h)
+
+ return image.crop((x, y, x + target_w, y + target_h)), [[x, y], [x + target_w, y + target_h]]
+
+
+def motion_blur_opencv(image, kernel_size=15, angle=0):
+    # build a horizontal line kernel (ones along the middle row)
+ kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
+ kernel[kernel_size // 2, :] = np.ones(kernel_size, dtype=np.float32)
+
+    # rotate the kernel to the requested angle
+ center = (kernel_size / 2 - 0.5, kernel_size / 2 - 0.5)
+ M = cv2.getRotationMatrix2D(center, angle, 1)
+ rotated_kernel = cv2.warpAffine(kernel, M, (kernel_size, kernel_size))
+
+    # normalize the kernel so it sums to 1
+ rotated_kernel /= rotated_kernel.sum() if rotated_kernel.sum() != 0 else 1
+
+ img = np.array(image)
+ if img.ndim == 2:
+ blurred = cv2.filter2D(img, -1, rotated_kernel, borderType=cv2.BORDER_REFLECT)
+ else:
+        # for color images, filter each channel independently
+ blurred = np.zeros_like(img)
+ for c in range(img.shape[2]):
+ blurred[..., c] = cv2.filter2D(img[..., c], -1, rotated_kernel, borderType=cv2.BORDER_REFLECT)
+
+ return Image.fromarray(blurred.astype(np.uint8))
+
+
+def shuffle_patch(image, num_splits, gap_size=2):
+    """Split the image into patches (uneven sizes allowed), shuffle them, and re-assemble the image with gaps between patches."""
+ h_splits, w_splits = num_splits
+ img_w, img_h = image.size
+
+ base_patch_h = img_h // h_splits
+ patch_heights = [base_patch_h] * (h_splits - 1)
+ patch_heights.append(img_h - sum(patch_heights))
+
+ base_patch_w = img_w // w_splits
+ patch_widths = [base_patch_w] * (w_splits - 1)
+ patch_widths.append(img_w - sum(patch_widths))
+
+ patches = []
+ current_y = 0
+ for i in range(h_splits):
+ current_x = 0
+ patch_h = patch_heights[i]
+ for j in range(w_splits):
+ patch_w = patch_widths[j]
+ patch = image.crop((current_x, current_y, current_x + patch_w, current_y + patch_h))
+ patches.append(patch)
+ current_x += patch_w
+ current_y += patch_h
+
+ random.shuffle(patches)
+
+ total_width = sum(patch_widths) + (w_splits - 1) * gap_size
+ total_height = sum(patch_heights) + (h_splits - 1) * gap_size
+ new_image = Image.new(image.mode, (total_width, total_height), color=(255, 255, 255))
+
+    current_y = 0  # starting Y coordinate of the current row
+    patch_idx = 0  # index of the next shuffled patch to place
+    for i in range(h_splits):
+        current_x = 0  # starting X coordinate of the current column
+        patch_h = patch_heights[i]  # height of the patches in this row
+        for j in range(w_splits):
+            # take the next shuffled patch
+            patch = patches[patch_idx]
+            patch_w = patch_widths[j]  # width of the patches in this column
+            # paste the patch with its top-left corner at (current_x, current_y)
+            new_image.paste(patch, (current_x, current_y))
+            # advance X: next patch starts after the current patch width plus the gap
+            current_x += patch_w + gap_size
+            patch_idx += 1
+        # advance Y: next row starts after the current row height plus the gap
+        current_y += patch_h + gap_size
+
+ return new_image
+
+
+def inpainting(image, num_splits, blank_ratio=0.3, blank_color=(255, 255, 255)):
+ """
+    Split the image into patches, blank out a random subset of them, and re-assemble the image
+    (used for inpainting-style corruption).
+
+    Args:
+        image: PIL.Image, input image in RGB mode
+        num_splits: (h_splits, w_splits), number of row and column splits
+        blank_ratio: float, fraction of patches to blank out (0~1)
+        blank_color: tuple, RGB color of the blanked patches, e.g. white (255, 255, 255)
+
+    Returns:
+        PIL.Image: the re-assembled image with the selected patches blanked
+ """
+ h_splits, w_splits = num_splits
+ img_w, img_h = image.size
+
+ base_patch_h = img_h // h_splits
+ patch_heights = [base_patch_h] * (h_splits - 1)
+ patch_heights.append(img_h - sum(patch_heights))
+
+ base_patch_w = img_w // w_splits
+ patch_widths = [base_patch_w] * (w_splits - 1)
+ patch_widths.append(img_w - sum(patch_widths))
+
+ patches = []
+ current_y = 0
+ for i in range(h_splits):
+ current_x = 0
+ patch_h = patch_heights[i]
+ for j in range(w_splits):
+ patch_w = patch_widths[j]
+ patch = image.crop((current_x, current_y, current_x + patch_w, current_y + patch_h))
+ patches.append(patch)
+ current_x += patch_w
+ current_y += patch_h
+
+ total_patches = h_splits * w_splits
+ num_blank = int(total_patches * blank_ratio)
+ num_blank = max(0, min(num_blank, total_patches))
+ blank_indices = random.sample(range(total_patches), num_blank)
+
+ processed_patches = []
+ for idx, patch in enumerate(patches):
+ if idx in blank_indices:
+ blank_patch = Image.new("RGB", patch.size, color=blank_color)
+ processed_patches.append(blank_patch)
+ else:
+ processed_patches.append(patch)
+
+    # create the output image (same size as the input)
+ result_image = Image.new("RGB", (img_w, img_h))
+ current_y = 0
+ patch_idx = 0
+ for i in range(h_splits):
+ current_x = 0
+ patch_h = patch_heights[i]
+ for j in range(w_splits):
+            # take the processed patch
+            patch = processed_patches[patch_idx]
+            patch_w = patch_widths[j]
+            # paste it back at its original position
+ result_image.paste(patch, (current_x, current_y))
+ current_x += patch_w
+ patch_idx += 1
+ current_y += patch_h
+
+ return result_image
diff --git a/data/video_utils.py b/data/video_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b72fc5096207e69333a81388c82609c24e7f48a
--- /dev/null
+++ b/data/video_utils.py
@@ -0,0 +1,165 @@
+# Copyright (c) 2023 OpenGVLab
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
+# SPDX-License-Identifier: MIT
+#
+# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
+#
+# Original file was released under MIT, with the full license text
+# available at https://github.com/OpenGVLab/InternVL/blob/main/LICENSE.
+#
+# This modified file is released under the same license.
+
+
+import io
+import os
+import random
+import re
+
+import numpy as np
+import decord
+from PIL import Image
+
+
+def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
+ if sample in ['rand', 'middle']: # uniform sampling
+ acc_samples = min(num_frames, vlen)
+ # split the video into `acc_samples` intervals, and sample from each interval.
+ intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
+ ranges = []
+ for idx, interv in enumerate(intervals[:-1]):
+ ranges.append((interv, intervals[idx + 1] - 1))
+ if sample == 'rand':
+ try:
+ frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
+            except IndexError:  # random.choice fails when an interval is empty (width <= 1)
+ frame_indices = np.random.permutation(vlen)[:acc_samples]
+ frame_indices.sort()
+ frame_indices = list(frame_indices)
+ elif fix_start is not None:
+ frame_indices = [x[0] + fix_start for x in ranges]
+ elif sample == 'middle':
+ frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
+ else:
+ raise NotImplementedError
+
+ if len(frame_indices) < num_frames: # padded with last frame
+ padded_frame_indices = [frame_indices[-1]] * num_frames
+ padded_frame_indices[:len(frame_indices)] = frame_indices
+ frame_indices = padded_frame_indices
+ elif 'fps' in sample: # fps0.5, sequentially sample frames at 0.5 fps
+ output_fps = float(sample[3:])
+ duration = float(vlen) / input_fps
+ delta = 1 / output_fps # gap between frames, this is also the clip length each frame represents
+ frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
+ frame_indices = np.around(frame_seconds * input_fps).astype(int)
+ frame_indices = [e for e in frame_indices if e < vlen]
+ if max_num_frames > 0 and len(frame_indices) > max_num_frames:
+ frame_indices = frame_indices[:max_num_frames]
+ else:
+ raise ValueError
+ return frame_indices
+
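+# Worked example (illustrative): get_frame_indices(4, vlen=100, sample='middle') splits the video
+# into the intervals [0, 24], [25, 49], [50, 74], [75, 99] and returns their midpoints
+# [12, 37, 62, 87]; with sample='rand', one index is drawn from each interval instead.
+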
+
+def read_frames_decord(video_path, num_frames, sample='rand', fix_start=None, clip=None, min_num_frames=4):
+ video_reader = decord.VideoReader(video_path, num_threads=1)
+ vlen = len(video_reader)
+ fps = video_reader.get_avg_fps()
+ duration = vlen / float(fps)
+ if clip:
+ start, end = clip
+ duration = end - start
+ vlen = int(duration * fps)
+ start_index = int(start * fps)
+
+ t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
+
+ frame_indices = get_frame_indices(
+ t_num_frames, vlen, sample=sample, fix_start=fix_start,
+ input_fps=fps
+ )
+ if clip:
+ frame_indices = [f + start_index for f in frame_indices]
+ frames = video_reader.get_batch(frame_indices).asnumpy() # (T, H, W, C), np.uint8
+ frames = [Image.fromarray(frames[i]) for i in range(frames.shape[0])]
+ return frames
+
+
+def extract_frame_number(filename):
+ # Extract the numeric part from the filename using regular expressions
+    match = re.search(r'_(\d+)\.jpg$', filename)
+ return int(match.group(1)) if match else -1
+
+
+def sort_frames(frame_paths):
+ # Extract filenames from each path and sort by their numeric part
+ return sorted(frame_paths, key=lambda x: extract_frame_number(os.path.basename(x)))
+
+
+def read_frames_folder(video_path, num_frames, sample='rand', fix_start=None, min_num_frames=4):
+ image_list = sort_frames(list(os.listdir(video_path)))
+ frames = []
+ for image in image_list:
+ fp = os.path.join(video_path, image)
+ frame = Image.open(fp).convert('RGB')
+ frames.append(frame)
+ vlen = len(frames)
+
+ t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
+
+ if vlen > t_num_frames:
+ frame_indices = get_frame_indices(
+ t_num_frames, vlen, sample=sample, fix_start=fix_start
+ )
+ frames = [frames[i] for i in frame_indices]
+ return frames
+
+
+class FrameSampler:
+ def __init__(self, max_num_frames=-1, min_num_frames=8, sample='rand'):
+ self.max_num_frames = max_num_frames
+ self.min_num_frames = min_num_frames
+ self.sample = sample
+
+ def __call__(self, file_name):
+ fn = read_frames_folder if file_name.endswith('/') else read_frames_decord
+ frames = fn(file_name, num_frames=self.max_num_frames, min_num_frames=self.min_num_frames, sample=self.sample)
+ return frames
+
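+# Minimal usage sketch (illustrative; paths are hypothetical): a trailing slash selects the
+# frame-folder reader, otherwise the video file is decoded with decord.
+#
+#     sampler = FrameSampler(max_num_frames=16, min_num_frames=8, sample='rand')
+#     frames = sampler("your_data_path/clip_0001.mp4")   # list of PIL.Image frames
+#     frames = sampler("your_data_path/clip_0001/")      # same, but reads extracted .jpg frames
+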
+
+def decode_video_byte(video_bytes):
+ video_stream = io.BytesIO(video_bytes)
+ vr = decord.VideoReader(video_stream)
+ return vr
+
+
+def sample_mp4_frames(mp4_p, n_frames=None, fps=None, return_frame_indices=False, random_sample=False):
+ if isinstance(mp4_p, str):
+ vr = decord.VideoReader(mp4_p, num_threads=1)
+ elif isinstance(mp4_p, decord.video_reader.VideoReader):
+ vr = mp4_p
+    video_fps = vr.get_avg_fps()  # average frame rate of the video
+ video_duration = len(vr) / video_fps
+ if n_frames is not None:
+ if random_sample:
+ frame_indices = sorted(random.sample(range(len(vr)), n_frames))
+ else:
+ frame_indices = np.linspace(0, len(vr)-1, n_frames, dtype=int).tolist()
+ else:
+ frame_indices = [int(i) for i in np.arange(0, len(vr)-1, video_fps/fps)]
+    frames = vr.get_batch(frame_indices).asnumpy()  # decode the selected frames into a numpy array
+ frames = [Image.fromarray(frame).convert("RGB") for frame in frames]
+ if not return_frame_indices:
+ return frames, video_duration
+ else:
+ return frames, video_duration, frame_indices
+
+
+def sample_mp4_frames_by_indices(mp4_p, frame_indices: list):
+ if isinstance(mp4_p, str):
+ vr = decord.VideoReader(mp4_p, num_threads=1)
+ elif isinstance(mp4_p, decord.video_reader.VideoReader):
+ vr = mp4_p
+ # sample the frames in frame_indices
+    frames = vr.get_batch(frame_indices).asnumpy()  # decode the selected frames into a numpy array
+ frames = [Image.fromarray(frame).convert("RGB") for frame in frames]
+ return frames
\ No newline at end of file
diff --git a/data/vlm_dataset.py b/data/vlm_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b5e10377494253750100af720ad9112f7771d3c
--- /dev/null
+++ b/data/vlm_dataset.py
@@ -0,0 +1,195 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+import os
+import traceback
+from PIL import Image, ImageFile, PngImagePlugin
+
+from .data_utils import pil_img2rgb
+from .distributed_iterable_dataset import DistributedIterableDataset
+
+
+Image.MAX_IMAGE_PIXELS = 200000000
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+MaximumDecompressedSize = 1024
+MegaByte = 2 ** 20
+PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
+
+
+class SftJSONLIterableDataset(DistributedIterableDataset):
+ def __init__(
+ self, dataset_name, transform, tokenizer, frame_sampler,
+ jsonl_path_list, data_dir_list, num_used_data,
+ local_rank=0, world_size=1, num_workers=8, data_status=None,
+ shuffle_lines=False, shuffle_seed=0,
+ ):
+ """
+ jsonl_path_list: list of jsonl file paths
+ data_dir_list: list of image directories containing the images of each jsonl file
+        num_used_data: number of data points to sample from each jsonl file (one entry per file)
+ """
+ super().__init__(dataset_name, local_rank, world_size, num_workers)
+ self.transform = transform
+ self.tokenizer = tokenizer
+ self.frame_sampler = frame_sampler
+ self.data_status = data_status
+ self.data_paths = self.get_data_paths(
+ jsonl_path_list,
+ data_dir_list,
+ num_used_data,
+ shuffle_lines,
+ shuffle_seed,
+ )
+ self.set_epoch()
+
+ def get_data_paths(
+ self,
+ jsonl_path_list,
+ data_dir_list,
+ num_used_data,
+ shuffle_lines,
+ shuffle_seed,
+ ):
+ data_paths = []
+ for jsonl_path, image_dir, num_data_point in zip(
+ jsonl_path_list, data_dir_list, num_used_data
+ ):
+ with open(jsonl_path, 'r') as f:
+ raw_data = f.readlines()
+ if shuffle_lines:
+ self.rng.seed(shuffle_seed)
+ self.rng.shuffle(raw_data)
+ raw_data = raw_data[:num_data_point]
+ data_paths.extend([(json_data, image_dir) for json_data in raw_data])
+ return data_paths
+
+ def change_format(self, data, num_images):
+ elements = []
+ for conversation in data['conversations']:
+ if conversation['from'] == 'human':
+                if '<image>' not in conversation['value']:
+ elements.append({
+ 'type': 'text',
+ 'has_loss': 0,
+ 'text': conversation['value'],
+ })
+ else:
+                    text_list = conversation['value'].split('<image>')
+ for idx, text in enumerate(text_list):
+ if text.strip() != '':
+ elements.append({
+ 'type': 'text',
+ 'has_loss': 0,
+ 'text': text.strip(),
+ })
+ if (idx != len(text_list) - 1) and (idx < num_images):
+ elements.append({'type': 'image',})
+ elif conversation['from'] == 'gpt':
+ elements.append({
+ 'type': 'text',
+ 'has_loss': 1,
+ 'text': conversation['value'],
+ })
+ return elements
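+
+    # Illustrative example of change_format with num_images=1: a conversation such as
+    #   [{'from': 'human', 'value': '<image>\nWhat is shown here?'},
+    #    {'from': 'gpt', 'value': 'A red bicycle.'}]
+    # becomes
+    #   [{'type': 'image'},
+    #    {'type': 'text', 'has_loss': 0, 'text': 'What is shown here?'},
+    #    {'type': 'text', 'has_loss': 1, 'text': 'A red bicycle.'}]
+    # so only the assistant turns contribute to the CE loss.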
+
+ def __iter__(self):
+ data_paths_per_worker, worker_id = self.get_data_paths_per_worker()
+ if self.data_status is not None:
+ row_start_id = self.data_status[worker_id] + 1
+ else:
+ row_start_id = 0
+ transform_stride = self.transform.stride
+
+ print(
+ f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
+ f"resuming data at row#{row_start_id}"
+ )
+
+ while True:
+ data_paths_per_worker_ = data_paths_per_worker[row_start_id:]
+ for row_idx, (data, image_dir) in enumerate(data_paths_per_worker_, start=row_start_id):
+ num_tokens = 0
+ image_tensor_list = []
+ text_ids_list = []
+ sequence_plan = []
+
+ try:
+ data_item = json.loads(data)
+ raw_images = None
+ if 'image' in data_item:
+ if type(data_item['image']) == list:
+ raw_images = [
+ pil_img2rgb(Image.open(os.path.join(image_dir, image)))
+ for image in data_item['image']
+ ]
+ else:
+ raw_images = [
+ pil_img2rgb(Image.open(os.path.join(image_dir, data_item['image'])))
+ ]
+ elif 'video' in data_item:
+ raw_images = self.frame_sampler(os.path.join(image_dir, data_item['video']))
+                    special_tokens = '<image>' * len(raw_images)
+ for item in data_item['conversations']:
+ if '