# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import functools
import os

import wandb
import yaml
from copy import deepcopy
from dataclasses import dataclass, field
from time import time

import torch
import torch.distributed as dist
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torch.utils.data import DataLoader
from transformers import HfArgumentParser, set_seed
from transformers.optimization import (
    get_constant_schedule_with_warmup,
    get_cosine_with_min_lr_schedule_with_warmup,
)

from data.dataset_base import DataConfig, PackedDataset, collate_wrapper
from data.data_utils import add_special_tokens
from modeling.autoencoder import load_ae
from modeling.bagel import (
    BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM, SiglipVisionConfig, SiglipVisionModel
)
from modeling.qwen2 import Qwen2Tokenizer
from train.train_utils import create_logger, get_latest_ckpt
from train.fsdp_utils import (
    FSDPCheckpoint, FSDPConfig, grad_checkpoint_check_fn, fsdp_wrapper,
    fsdp_ema_setup, fsdp_ema_update,
)


@dataclass
class ModelArguments:
    llm_path: str = field(
        default="hf/Qwen2.5-0.5B-Instruct/",
        metadata={"help": "Path or HuggingFace repo ID of the pretrained Qwen2-style language model."},
    )
    llm_qk_norm: bool = field(
        default=True,
        metadata={"help": "Enable QK LayerNorm (qk_norm) inside the attention blocks."},
    )
    tie_word_embeddings: bool = field(
        default=False,
        metadata={"help": "Share input and output word embeddings (tied embeddings)."},
    )
    layer_module: str = field(
        default="Qwen2DecoderLayer",
        metadata={"help": "Python class name of the decoder layer to instantiate."},
    )
    vae_path: str = field(
        default="flux/vae/ae.safetensors",
        metadata={"help": "Path to the pretrained VAE checkpoint for latent-space image generation."},
    )
    vit_path: str = field(
        default="hf/siglip-so400m-14-980-flash-attn2-navit/",
        metadata={"help": "Path or repo ID of the SigLIP Vision Transformer used for image understanding."},
    )
    max_latent_size: int = field(
        default=32,
        metadata={"help": "Maximum latent grid size (patches per side) for the VAE latent tensor."},
    )
    latent_patch_size: int = field(
        default=2,
        metadata={"help": "Spatial size (in VAE pixels) covered by each latent patch."},
    )
    vit_patch_size: int = field(
        default=14,
        metadata={"help": "Patch size (pixels) for the Vision Transformer encoder."},
    )
    vit_max_num_patch_per_side: int = field(
        default=70,
        metadata={"help": "Maximum number of ViT patches along one image side after cropping / resize."},
    )
    connector_act: str = field(
        default="gelu_pytorch_tanh",
        metadata={"help": "Activation function used in the latent-to-text connector MLP."},
    )
    interpolate_pos: bool = field(
        default=False,
        metadata={"help": "Interpolate positional embeddings when image resolution differs from pre-training."},
    )
    vit_select_layer: int = field(
        default=-2,
        metadata={"help": "Which hidden layer of the ViT to take as the visual feature (negative = from the end)."},
    )
    vit_rope: bool = field(
        default=False,
        metadata={"help": "Replace ViT positional encodings with RoPE."},
    )
    text_cond_dropout_prob: float = field(
        default=0.1,
        metadata={"help": "Probability of dropping text embeddings during training."},
    )
    vae_cond_dropout_prob: float = field(
        default=0.3,
        metadata={"help": "Probability of dropping VAE latent inputs during training."},
    )
    vit_cond_dropout_prob: float = field(
        default=0.3,
        metadata={"help": "Probability of dropping ViT visual features during training."},
    )


@dataclass
class DataArguments:
    dataset_config_file: str = field(
        default="data/configs/example.yaml",
        metadata={"help": "YAML file specifying dataset groups, weights, and preprocessing rules."},
    )
    prefetch_factor: int = field(
        default=2,
        metadata={"help": "How many batches each DataLoader worker pre-loads in advance."},
    )
    num_workers: int = field(
        default=4,
        metadata={"help": "Number of background workers for the PyTorch DataLoader."},
    )
    max_num_tokens_per_sample: int = field(
        default=16384,
        metadata={"help": "Maximum tokens allowed in one raw sample; longer samples are skipped."},
    )
    max_num_tokens: int = field(
        default=36864,
        metadata={"help": "Hard limit on tokens in a packed batch; flush if adding a sample would exceed it."},
    )
    prefer_buffer_before: int = field(
        default=16384,
        metadata={"help": "While batch length is below this, pop from the overflow buffer before new sampling."},
    )
    max_buffer_size: int = field(
        default=50,
        metadata={"help": "Maximum number of oversized samples kept in the overflow buffer."},
    )
    data_seed: int = field(
        default=42,
        metadata={"help": "Seed used when shuffling / sampling data shards to ensure reproducibility."},
    )


@dataclass
class TrainingArguments:
    # --- modality switches ---
    visual_gen: bool = field(
        default=True,
        metadata={"help": "Train image generation branch."},
    )
    visual_und: bool = field(
        default=True,
        metadata={"help": "Train image understanding branch."},
    )

    # --- bookkeeping & logging ---
    results_dir: str = field(
        default="results",
        metadata={"help": "Root directory for logs."},
    )
    checkpoint_dir: str = field(
        default="results/checkpoints",
        metadata={"help": "Root directory for model checkpoints."},
    )
    wandb_project: str = field(
        default="bagel",
        metadata={"help": "Weights & Biases project name."},
    )
    wandb_name: str = field(
        default="run",
        metadata={"help": "Name shown in the Weights & Biases UI for this run."},
    )
    wandb_runid: str = field(
        default="0",
        metadata={"help": "Unique identifier to resume a previous W&B run, if desired."},
    )
    wandb_resume: str = field(
        default="allow",
        metadata={"help": "W&B resume mode: 'allow', 'must', or 'never'."},
    )
    wandb_offline: bool = field(
        default=False,
        metadata={"help": "Run W&B in offline mode (logs locally, sync later)."},
    )

    # --- reproducibility & resume ---
    global_seed: int = field(
        default=4396,
        metadata={"help": "Base random seed; actual seed is offset by rank for DDP."},
    )
    auto_resume: bool = field(
        default=False,
        metadata={"help": "Automatically pick up the latest checkpoint found in checkpoint_dir."},
    )
    resume_from: str = field(
        default=None,
        metadata={"help": "Explicit checkpoint path to resume from (overrides auto_resume)."},
    )
    resume_model_only: bool = field(
        default=False,
        metadata={"help": "Load only model weights, ignoring optimizer/scheduler states."},
    )
    finetune_from_ema: bool = field(
        default=False,
        metadata={"help": "When resume_model_only=True, load the EMA (exponential moving average) weights instead of raw weights."},
    )

    # --- reporting frequency ---
    log_every: int = field(
        default=10,
        metadata={"help": "Print / log every N training steps."},
    )
    save_every: int = field(
        default=2000,
        metadata={"help": "Save a checkpoint every N training steps."},
    )
    total_steps: int = field(
        default=500_000,
        metadata={"help": "Total number of optimizer steps to train for."},
    )

    # --- optimization & scheduler ---
    warmup_steps: int = field(
        default=2000,
        metadata={"help": "Linear warm-up steps before applying the main LR schedule."},
    )
    lr_scheduler: str = field(
        default="constant",
        metadata={"help": "Type of LR schedule: 'constant' or 'cosine'."},
    )
    lr: float = field(
        default=1e-4,
        metadata={"help": "Peak learning rate after warm-up."},
    )
    min_lr: float = field(
        default=1e-7,
        metadata={"help": "Minimum learning rate for cosine schedule (ignored for constant)."},
    )
    beta1: float = field(
        default=0.9,
        metadata={"help": "AdamW β₁ coefficient."},
    )
    beta2: float = field(
        default=0.95,
        metadata={"help": "AdamW β₂ coefficient."},
    )
    eps: float = field(
        default=1e-15,
        metadata={"help": "AdamW ε for numerical stability."},
    )
    ema: float = field(
        default=0.9999,
        metadata={"help": "Decay rate for the exponential moving average of model weights."},
    )
    max_grad_norm: float = field(
        default=1.0,
        metadata={"help": "Gradient clipping threshold (L2 norm)."},
    )
    timestep_shift: float = field(
        default=1.0,
        metadata={"help": "Shift applied to diffusion timestep indices (for latent prediction)."},
    )
    mse_weight: float = field(
        default=1.0,
        metadata={"help": "Scaling factor for the image-reconstruction MSE loss term."},
    )
    ce_weight: float = field(
        default=1.0,
        metadata={"help": "Scaling factor for the language cross-entropy loss term."},
    )
    ce_loss_reweighting: bool = field(
        default=False,
        metadata={"help": "Reweight CE loss by token importance (provided via ce_loss_weights)."},
    )
    expected_num_tokens: int = field(
        default=32768,
        metadata={"help": "Soft target token count; yield the batch once it reaches or exceeds this size."},
    )

    # --- distributed training / FSDP ---
    num_replicate: int = field(
        default=1,
        metadata={"help": "Number of model replicas (replication degree across shard groups) for FSDP HYBRID_SHARD."},
    )
    num_shard: int = field(
        default=8,
        metadata={"help": "Number of parameter shards when using FSDP HYBRID_SHARD."},
    )
    sharding_strategy: str = field(
        default="HYBRID_SHARD",
        metadata={"help": "FSDP sharding strategy: FULL_SHARD, SHARD_GRAD_OP, HYBRID_SHARD, etc."},
    )
    backward_prefetch: str = field(
        default="BACKWARD_PRE",
        metadata={"help": "FSDP backward prefetch strategy (BACKWARD_PRE or NO_PREFETCH)."},
    )
    cpu_offload: bool = field(
        default=False,
        metadata={"help": "Enable FSDP parameter offload to CPU."},
    )

    # --- module freezing ---
    freeze_llm: bool = field(
        default=False,
        metadata={"help": "Keep language-model weights fixed (no gradient updates)."},
    )
    freeze_vit: bool = field(
        default=False,
        metadata={"help": "Keep ViT weights fixed during training."},
    )
    freeze_vae: bool = field(
        default=True,
        metadata={"help": "Keep VAE weights fixed; only predict latents, don’t fine-tune encoder/decoder."},
    )
    freeze_und: bool = field(
        default=False,
        metadata={"help": "Freeze the visual understanding connector layers."},
    )
    copy_init_moe: bool = field(
        default=True,
        metadata={"help": "Duplicate initial MoE experts so each has identical initialisation."},
    )
    use_flex: bool = field(
        default=False,
        metadata={"help": "Enable FLEX (flash-ext friendly) packing algorithm for sequence data."},
    )


def main():
    assert torch.cuda.is_available()
    dist.init_process_group("nccl")
    device = dist.get_rank() % torch.cuda.device_count()
    torch.cuda.set_device(device)

    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Setup logging:
    if dist.get_rank() == 0:
        os.makedirs(training_args.results_dir, exist_ok=True)
        os.makedirs(training_args.checkpoint_dir, exist_ok=True)
        logger = create_logger(training_args.results_dir, dist.get_rank())
        wandb.init(
            project=training_args.wandb_project,
            id=f"{training_args.wandb_name}-run{training_args.wandb_runid}",
            name=training_args.wandb_name,
            resume=training_args.wandb_resume,
            mode="offline" if training_args.wandb_offline else "online",
        )
        wandb.config.update(training_args)
        wandb.config.update(model_args)
        wandb.config.update(data_args)
    else:
        logger = create_logger(None, dist.get_rank())
    dist.barrier()

    logger.info(f'Training arguments {training_args}')
    logger.info(f'Model arguments {model_args}')
    logger.info(f'Data arguments {data_args}')

    # prepare auto resume logic:
    if training_args.auto_resume:
        resume_from = get_latest_ckpt(training_args.checkpoint_dir)
        if resume_from is None:
            resume_from = training_args.resume_from
            resume_model_only = training_args.resume_model_only
            if resume_model_only:
                finetune_from_ema = training_args.finetune_from_ema
            else:
                finetune_from_ema = False
        else:
            resume_model_only = False
            finetune_from_ema = False
    else:
        resume_from = training_args.resume_from
        resume_model_only = training_args.resume_model_only
        if resume_model_only:
            finetune_from_ema = training_args.finetune_from_ema
        else:
            finetune_from_ema = False

    # Set seed:
    seed = training_args.global_seed * dist.get_world_size() + dist.get_rank()
    set_seed(seed)

    # Setup model:
    llm_config = Qwen2Config.from_pretrained(model_args.llm_path)
    llm_config.layer_module = model_args.layer_module
    llm_config.qk_norm = model_args.llm_qk_norm
    llm_config.tie_word_embeddings = model_args.tie_word_embeddings
    llm_config.freeze_und = training_args.freeze_und
    language_model = Qwen2ForCausalLM.from_pretrained(model_args.llm_path, config=llm_config)
    if training_args.copy_init_moe:
        language_model.init_moe()

    if training_args.visual_und:
        vit_config = SiglipVisionConfig.from_pretrained(model_args.vit_path)
        vit_config.num_hidden_layers = vit_config.num_hidden_layers + 1 + model_args.vit_select_layer
        vit_config.rope = model_args.vit_rope
        vit_model = SiglipVisionModel.from_pretrained(model_args.vit_path, config=vit_config)

    if training_args.visual_gen:
        vae_model, vae_config = load_ae(local_path=model_args.vae_path)

    config = BagelConfig(
        visual_gen=training_args.visual_gen,
        visual_und=training_args.visual_und,
        llm_config=llm_config,
        vit_config=vit_config if training_args.visual_und else None,
        vae_config=vae_config if training_args.visual_gen else None,
        latent_patch_size=model_args.latent_patch_size,
        max_latent_size=model_args.max_latent_size,
        vit_max_num_patch_per_side=model_args.vit_max_num_patch_per_side,
        connector_act=model_args.connector_act,
        interpolate_pos=model_args.interpolate_pos,
        timestep_shift=training_args.timestep_shift,
    )
    model = Bagel(
        language_model,
        vit_model if training_args.visual_und else None,
        config,
    )
    if training_args.visual_und:
        model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config)

    # Setup tokenizer for model:
    tokenizer = Qwen2Tokenizer.from_pretrained(model_args.llm_path)
    tokenizer, new_token_ids, num_new_tokens = add_special_tokens(tokenizer)
    if num_new_tokens > 0:
        model.language_model.resize_token_embeddings(len(tokenizer))
        model.config.llm_config.vocab_size = len(tokenizer)
        model.language_model.config.vocab_size = len(tokenizer)

    # maybe freeze something:
    if training_args.freeze_vae and training_args.visual_gen:
        for param in vae_model.parameters():
            param.requires_grad = False
    if training_args.freeze_llm:
        model.language_model.eval()
        for param in model.language_model.parameters():
            param.requires_grad = False
    if training_args.freeze_vit and training_args.visual_und:
        model.vit_model.eval()
        for param in model.vit_model.parameters():
            param.requires_grad = False

    # Setup FSDP and load pretrained model:
    fsdp_config = FSDPConfig(
        sharding_strategy=training_args.sharding_strategy,
        backward_prefetch=training_args.backward_prefetch,
        cpu_offload=training_args.cpu_offload,
        num_replicate=training_args.num_replicate,
        num_shard=training_args.num_shard,
    )
    ema_model = deepcopy(model)
    model, ema_model = FSDPCheckpoint.try_load_ckpt(
        resume_from, logger, model, ema_model, resume_from_ema=finetune_from_ema
    )
    ema_model = fsdp_ema_setup(ema_model, fsdp_config)
    fsdp_model = fsdp_wrapper(model, fsdp_config)
    apply_activation_checkpointing(
        fsdp_model,
        checkpoint_wrapper_fn=functools.partial(
            checkpoint_wrapper,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        ),
        check_fn=grad_checkpoint_check_fn,
    )

    if dist.get_rank() == 0:
        print(fsdp_model)
        for name, param in model.named_parameters():
            print(name, param.requires_grad)

    # Setup optimizer and scheduler
    optimizer = torch.optim.AdamW(
        fsdp_model.parameters(),
        lr=training_args.lr,
        betas=(training_args.beta1, training_args.beta2),
        eps=training_args.eps,
        weight_decay=0,
    )
    if training_args.lr_scheduler == 'cosine':
        scheduler = get_cosine_with_min_lr_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=training_args.warmup_steps,
            num_training_steps=training_args.total_steps,
            min_lr=training_args.min_lr,
        )
    elif training_args.lr_scheduler == 'constant':
        scheduler = get_constant_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=training_args.warmup_steps,
        )
    else:
        raise ValueError(f"Unsupported lr_scheduler: {training_args.lr_scheduler}")

    # maybe resume optimizer, scheduler, and train_steps
    if resume_model_only:
        train_step = 0
        data_status = None
    else:
        optimizer, scheduler, train_step, data_status = FSDPCheckpoint.try_load_train_state(
            resume_from, optimizer, scheduler, fsdp_config,
        )

    # Setup packed dataloader
    with open(data_args.dataset_config_file, "r") as stream:
        dataset_meta = yaml.safe_load(stream)
    dataset_config = DataConfig(grouped_datasets=dataset_meta)
    if training_args.visual_und:
        dataset_config.vit_patch_size = model_args.vit_patch_size
        dataset_config.max_num_patch_per_side = model_args.vit_max_num_patch_per_side
    if training_args.visual_gen:
        vae_image_downsample = model_args.latent_patch_size * vae_config.downsample
        dataset_config.vae_image_downsample = vae_image_downsample
        dataset_config.max_latent_size = model_args.max_latent_size
        dataset_config.text_cond_dropout_prob = model_args.text_cond_dropout_prob
        dataset_config.vae_cond_dropout_prob = model_args.vae_cond_dropout_prob
        dataset_config.vit_cond_dropout_prob = model_args.vit_cond_dropout_prob

    train_dataset = PackedDataset(
        dataset_config,
        tokenizer=tokenizer,
        special_tokens=new_token_ids,
        local_rank=dist.get_rank(),
        world_size=dist.get_world_size(),
        num_workers=data_args.num_workers,
        expected_num_tokens=training_args.expected_num_tokens,
        max_num_tokens_per_sample=data_args.max_num_tokens_per_sample,
        max_num_tokens=data_args.max_num_tokens,
        max_buffer_size=data_args.max_buffer_size,
        prefer_buffer_before=data_args.prefer_buffer_before,
        interpolate_pos=model_args.interpolate_pos,
        use_flex=training_args.use_flex,
        data_status=data_status,
    )
    train_dataset.set_epoch(data_args.data_seed)
    train_loader = DataLoader(
        train_dataset,
        batch_size=1,  # batch size is 1 because the dataset yields pre-packed batches
        num_workers=data_args.num_workers,
        pin_memory=True,
        collate_fn=collate_wrapper(),
        drop_last=True,
        prefetch_factor=data_args.prefetch_factor,
    )

    # Prepare models for training:
    if training_args.visual_gen:
        vae_model.to(device).eval()
    fsdp_model.train()
    ema_model.eval()

    # train loop
    start_time = time()
    logger.info(f"Training for {training_args.total_steps} steps, starting at {train_step}...")
    for curr_step, data in enumerate(train_loader, start=train_step):
        data = data.cuda(device).to_dict()
        data_indexes = data.pop('batch_data_indexes', None)
        ce_loss_weights = data.pop('ce_loss_weights', None)
        with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
            if training_args.visual_gen:
                with torch.no_grad():
                    data['padded_latent'] = vae_model.encode(data.pop('padded_images'))
            loss_dict = fsdp_model(**data)

        loss = 0
        ce = loss_dict["ce"]
        if ce is not None:
            # Normalize CE by the global number of CE tokens (or total CE weight); the
            # world_size factor cancels the mean over ranks taken during gradient sync.
            total_ce_tokens = torch.tensor(len(data['ce_loss_indexes']), device=device)
            dist.all_reduce(total_ce_tokens, op=dist.ReduceOp.SUM)
            if training_args.ce_loss_reweighting:
                ce = ce * ce_loss_weights
                total_ce_loss_weights = ce_loss_weights.sum()
                dist.all_reduce(total_ce_loss_weights, op=dist.ReduceOp.SUM)
                ce = ce.sum() * dist.get_world_size() / total_ce_loss_weights
            else:
                ce = ce.sum() * dist.get_world_size() / total_ce_tokens
            loss_dict["ce"] = ce.detach()
            loss = loss + ce * training_args.ce_weight
        else:
            assert not training_args.visual_und
            loss_dict["ce"] = torch.tensor(0, device=device)
            total_ce_tokens = torch.tensor(0, device=device)

        if training_args.visual_gen:
            # Same global-token normalization for the diffusion MSE loss.
            mse = loss_dict["mse"]
            total_mse_tokens = torch.tensor(len(data['mse_loss_indexes']), device=device)
            dist.all_reduce(total_mse_tokens, op=dist.ReduceOp.SUM)
            mse = mse.mean(dim=-1).sum() * dist.get_world_size() / total_mse_tokens
            loss_dict["mse"] = mse.detach()
            loss = loss + mse * training_args.mse_weight
        else:
            assert not training_args.visual_gen
            loss_dict["mse"] = torch.tensor(0, device=device)
            total_mse_tokens = torch.tensor(0, device=device)

        optimizer.zero_grad()
        loss.backward()
        total_norm = fsdp_model.clip_grad_norm_(training_args.max_grad_norm)
        optimizer.step()
        scheduler.step()
        fsdp_ema_update(ema_model, fsdp_model, decay=training_args.ema)

        # Log loss values:
        if curr_step % training_args.log_every == 0:
            total_samples = torch.tensor(len(data['sample_lens']), device=device)
            dist.all_reduce(total_samples, op=dist.ReduceOp.SUM)

            # Measure training speed:
            torch.cuda.synchronize()
            end_time = time()
            steps_per_sec = training_args.log_every / (end_time - start_time)
            message = f"(step={curr_step:07d}) "
            wandb_log = {}
            for key, value in loss_dict.items():
                # Reduce loss history over all processes:
                avg_loss = torch.tensor(value.item(), device=device)
                dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
                avg_loss = avg_loss.item() / dist.get_world_size()
                message += f"Train Loss {key}: {avg_loss:.4f}, "
                wandb_log[key] = avg_loss
            message += f"Train Steps/Sec: {steps_per_sec:.2f}, "
            logger.info(message)

            wandb_log['lr'] = optimizer.param_groups[0]['lr']
            wandb_log['total_mse_tokens'] = total_mse_tokens.item()
            wandb_log['total_ce_tokens'] = total_ce_tokens.item()
            wandb_log['total_norm'] = total_norm.item()
            wandb_log['total_samples'] = total_samples.item()

            mem_allocated = torch.tensor(torch.cuda.max_memory_allocated() / 1024**2, device=device)
            dist.all_reduce(mem_allocated, op=dist.ReduceOp.MAX)
            wandb_log['mem_allocated'] = mem_allocated
            mem_cache = torch.tensor(torch.cuda.max_memory_reserved() / 1024**2, device=device)
            dist.all_reduce(mem_cache, op=dist.ReduceOp.MAX)
            wandb_log['mem_cache'] = mem_cache

            if dist.get_rank() == 0:
                wandb.log(wandb_log, step=curr_step)
            start_time = time()

        # Track the latest consumed data indexes per dataset/worker so data loading
        # can resume from the same position after a checkpoint restart.
        if data_status is None:
            data_status = {}
        for item in data_indexes:
            if item['dataset_name'] not in data_status.keys():
                data_status[item['dataset_name']] = {}
            data_status[item['dataset_name']][item['worker_id']] = item['data_indexes']

        if curr_step > 0 and curr_step % training_args.save_every == 0:
            if dist.get_rank() == 0:
                gather_list = [None] * dist.get_world_size()
            else:
                gather_list = None
            dist.gather_object(data_status, gather_list, dst=0)

            FSDPCheckpoint.fsdp_save_ckpt(
                ckpt_dir=training_args.checkpoint_dir,
                train_steps=curr_step,
                model=fsdp_model,
                ema_model=ema_model,
                optimizer=optimizer,
                scheduler=scheduler,
                logger=logger,
                fsdp_config=fsdp_config,
                data_status=gather_list,
            )

    logger.info("Done!")
    if dist.get_rank() == 0:
        wandb.finish()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
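

# Example launch, included as a minimal sketch only: the script filename ("train.py"),
# node/GPU counts, and the chosen flags are assumptions, not the canonical command;
# every flag below maps to a dataclass field defined above and can be adjusted or omitted.
#
#   torchrun --nnodes=1 --nproc_per_node=8 train.py \
#       --dataset_config_file data/configs/example.yaml \
#       --llm_path hf/Qwen2.5-0.5B-Instruct/ \
#       --vae_path flux/vae/ae.safetensors \
#       --vit_path hf/siglip-so400m-14-980-flash-attn2-navit/ \
#       --results_dir results \
#       --checkpoint_dir results/checkpoints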