# train/pretrain_unified_navit.py
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import functools
import os
import wandb
import yaml
from copy import deepcopy
from dataclasses import dataclass, field
from time import time
import torch
import torch.distributed as dist
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointImpl,
apply_activation_checkpointing,
checkpoint_wrapper,
)
from torch.utils.data import DataLoader
from transformers import HfArgumentParser, set_seed
from transformers.optimization import (
get_constant_schedule_with_warmup,
get_cosine_with_min_lr_schedule_with_warmup,
)
from data.dataset_base import DataConfig, PackedDataset, collate_wrapper
from data.data_utils import add_special_tokens
from modeling.autoencoder import load_ae
from modeling.bagel import (
BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM, SiglipVisionConfig, SiglipVisionModel
)
from modeling.qwen2 import Qwen2Tokenizer
from train.train_utils import create_logger, get_latest_ckpt
from train.fsdp_utils import (
FSDPCheckpoint, FSDPConfig, grad_checkpoint_check_fn, fsdp_wrapper,
fsdp_ema_setup, fsdp_ema_update,
)
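# Configuration is split across three dataclasses, all exposed as CLI flags by
# HfArgumentParser below: ModelArguments (architecture and pretrained paths),
# DataArguments (packed-dataset and DataLoader settings), and TrainingArguments
# (optimisation, logging, FSDP, and freezing options).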
@dataclass
class ModelArguments:
llm_path: str = field(
default="hf/Qwen2.5-0.5B-Instruct/",
metadata={"help": "Path or HuggingFace repo ID of the pretrained Qwen2-style language model."}
)
llm_qk_norm: bool = field(
default=True,
metadata={"help": "Enable QK LayerNorm (qk_norm) inside the attention blocks."}
)
tie_word_embeddings: bool = field(
default=False,
metadata={"help": "Share input and output word embeddings (tied embeddings)."}
)
layer_module: str = field(
default="Qwen2DecoderLayer",
metadata={"help": "Python class name of the decoder layer to instantiate."}
)
vae_path: str = field(
default="flux/vae/ae.safetensors",
metadata={"help": "Path to the pretrained VAE checkpoint for latent-space image generation."}
)
vit_path: str = field(
default="hf/siglip-so400m-14-980-flash-attn2-navit/",
metadata={"help": "Path or repo ID of the SigLIP Vision Transformer used for image understanding."}
)
max_latent_size: int = field(
default=32,
metadata={"help": "Maximum latent grid size (patches per side) for the VAE latent tensor."}
)
latent_patch_size: int = field(
default=2,
metadata={"help": "Spatial size (in VAE pixels) covered by each latent patch."}
)
vit_patch_size: int = field(
default=14,
metadata={"help": "Patch size (pixels) for the Vision Transformer encoder."}
)
vit_max_num_patch_per_side: int = field(
default=70,
metadata={"help": "Maximum number of ViT patches along one image side after cropping / resize."}
)
connector_act: str = field(
default="gelu_pytorch_tanh",
metadata={"help": "Activation function used in the latent-to-text connector MLP."}
)
interpolate_pos: bool = field(
default=False,
metadata={"help": "Interpolate positional embeddings when image resolution differs from pre-training."}
)
vit_select_layer: int = field(
default=-2,
metadata={"help": "Which hidden layer of the ViT to take as the visual feature (negative = from the end)."}
)
vit_rope: bool = field(
default=False,
metadata={"help": "Replace ViT positional encodings with RoPE."}
)
text_cond_dropout_prob: float = field(
default=0.1,
metadata={"help": "Probability of dropping text embeddings during training."}
)
vae_cond_dropout_prob: float = field(
default=0.3,
metadata={"help": "Probability of dropping VAE latent inputs during training."}
)
vit_cond_dropout_prob: float = field(
default=0.3,
metadata={"help": "Probability of dropping ViT visual features during training."}
)
@dataclass
class DataArguments:
dataset_config_file: str = field(
default="data/configs/example.yaml",
metadata={"help": "YAML file specifying dataset groups, weights, and preprocessing rules."}
)
prefetch_factor: int = field(
default=2,
metadata={"help": "How many batches each DataLoader worker pre-loads in advance."}
)
num_workers: int = field(
default=4,
metadata={"help": "Number of background workers for the PyTorch DataLoader."}
)
max_num_tokens_per_sample: int = field(
default=16384,
metadata={"help": "Maximum tokens allowed in one raw sample; longer samples are skipped."}
)
max_num_tokens: int = field(
default=36864,
metadata={"help": "Hard limit on tokens in a packed batch; flush if adding a sample would exceed it."}
)
prefer_buffer_before: int = field(
default=16384,
metadata={"help": "While batch length is below this, pop from the overflow buffer before new sampling."}
)
max_buffer_size: int = field(
default=50,
metadata={"help": "Maximum number of oversized samples kept in the overflow buffer."}
)
data_seed: int = field(
default=42,
metadata={"help": "Seed used when shuffling / sampling data shards to ensure reproducibility."}
)
@dataclass
class TrainingArguments:
# --- modality switches ---
visual_gen: bool = field(
default=True,
metadata={"help": "Train image generation branch."}
)
visual_und: bool = field(
default=True,
metadata={"help": "Train image understanding branch."}
)
# --- bookkeeping & logging ---
results_dir: str = field(
default="results",
metadata={"help": "Root directory for logs."}
)
checkpoint_dir: str = field(
default="results/checkpoints",
metadata={"help": "Root directory for model checkpoints."}
)
wandb_project: str = field(
default="bagel",
metadata={"help": "Weights & Biases project name."}
)
wandb_name: str = field(
default="run",
metadata={"help": "Name shown in the Weights & Biases UI for this run."}
)
wandb_runid: str = field(
default="0",
metadata={"help": "Unique identifier to resume a previous W&B run, if desired."}
)
wandb_resume: str = field(
default="allow",
metadata={"help": "W&B resume mode: 'allow', 'must', or 'never'."}
)
wandb_offline: bool = field(
default=False,
metadata={"help": "Run W&B in offline mode (logs locally, sync later)."}
)
# --- reproducibility & resume ---
global_seed: int = field(
default=4396,
metadata={"help": "Base random seed; actual seed is offset by rank for DDP."}
)
auto_resume: bool = field(
default=False,
metadata={"help": "Automatically pick up the latest checkpoint found in checkpoint_dir."}
)
resume_from: str = field(
default=None,
metadata={"help": "Explicit checkpoint path to resume from (overrides auto_resume)." }
)
resume_model_only: bool = field(
default=False,
metadata={"help": "Load only model weights, ignoring optimizer/scheduler states."}
)
finetune_from_ema: bool = field(
default=False,
metadata={"help": "When resume_model_only=True, load the EMA (exponential moving average) weights instead of raw weights."}
)
# --- reporting frequency ---
log_every: int = field(
default=10,
metadata={"help": "Print / log every N training steps."}
)
save_every: int = field(
default=2000,
metadata={"help": "Save a checkpoint every N training steps."}
)
total_steps: int = field(
default=500_000,
metadata={"help": "Total number of optimizer steps to train for."}
)
# --- optimization & scheduler ---
warmup_steps: int = field(
default=2000,
metadata={"help": "Linear warm-up steps before applying the main LR schedule."}
)
lr_scheduler: str = field(
default="constant",
metadata={"help": "Type of LR schedule: 'constant' or 'cosine'."}
)
lr: float = field(
default=1e-4,
metadata={"help": "Peak learning rate after warm-up."}
)
min_lr: float = field(
default=1e-7,
metadata={"help": "Minimum learning rate for cosine schedule (ignored for constant)."}
)
beta1: float = field(
default=0.9,
metadata={"help": "AdamW β₁ coefficient."}
)
beta2: float = field(
default=0.95,
metadata={"help": "AdamW β₂ coefficient."}
)
eps: float = field(
default=1e-15,
metadata={"help": "AdamW ε for numerical stability."}
)
ema: float = field(
default=0.9999,
metadata={"help": "Decay rate for the exponential moving average of model weights."}
)
    max_grad_norm: float = field(
default=1.0,
metadata={"help": "Gradient clipping threshold (L2 norm)."}
)
timestep_shift: float = field(
default=1.0,
metadata={"help": "Shift applied to diffusion timestep indices (for latent prediction)."}
)
mse_weight: float = field(
default=1.0,
metadata={"help": "Scaling factor for the image-reconstruction MSE loss term."}
)
ce_weight: float = field(
default=1.0,
metadata={"help": "Scaling factor for the language cross-entropy loss term."}
)
ce_loss_reweighting: bool = field(
default=False,
metadata={"help": "Reweight CE loss by token importance (provided via ce_loss_weights)."}
)
expected_num_tokens: int = field(
default=32768,
metadata={"help": "Soft target token count; yield the batch once it reaches or exceeds this size."}
)
# --- distributed training / FSDP ---
num_replicate: int = field(
default=1,
metadata={"help": "Number of model replicas per GPU rank for tensor parallelism."}
)
num_shard: int = field(
default=8,
metadata={"help": "Number of parameter shards when using FSDP HYBRID_SHARD."}
)
sharding_strategy: str = field(
default="HYBRID_SHARD",
metadata={"help": "FSDP sharding strategy: FULL_SHARD, SHARD_GRAD_OP, HYBRID_SHARD, etc."}
)
backward_prefetch: str = field(
default="BACKWARD_PRE",
metadata={"help": "FSDP backward prefetch strategy (BACKWARD_PRE or NO_PREFETCH)."}
)
cpu_offload: bool = field(
default=False,
metadata={"help": "Enable FSDP parameter offload to CPU."}
)
# --- module freezing ---
freeze_llm: bool = field(
default=False,
metadata={"help": "Keep language-model weights fixed (no gradient updates)."}
)
freeze_vit: bool = field(
default=False,
metadata={"help": "Keep ViT weights fixed during training."}
)
freeze_vae: bool = field(
default=True,
metadata={"help": "Keep VAE weights fixed; only predict latents, don’t fine-tune encoder/decoder."}
)
freeze_und: bool = field(
default=False,
metadata={"help": "Freeze the visual understanding connector layers."}
)
copy_init_moe: bool = field(
default=True,
metadata={"help": "Duplicate initial MoE experts so each has identical initialisation."}
)
use_flex: bool = field(
default=False,
metadata={"help": "Enable FLEX (flash-ext friendly) packing algorithm for sequence data."}
)
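# main() wires the pieces together: it initialises NCCL, assembles the Bagel
# model from a Qwen2 LLM, an optional SigLIP ViT (understanding), and an
# optional VAE (generation), wraps it in FSDP with activation checkpointing and
# an EMA copy, builds the token-packed dataloader, and runs the training loop.
#
# A minimal, illustrative launch (paths, GPU count, and overrides are
# assumptions, not part of this file):
#   torchrun --nproc_per_node=8 train/pretrain_unified_navit.py \
#       --dataset_config_file data/configs/example.yaml \
#       --llm_path hf/Qwen2.5-0.5B-Instruct/ \
#       --results_dir results --checkpoint_dir results/checkpoints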
def main():
assert torch.cuda.is_available()
dist.init_process_group("nccl")
device = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(device)
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# Setup logging:
if dist.get_rank() == 0:
os.makedirs(training_args.results_dir, exist_ok=True)
os.makedirs(training_args.checkpoint_dir, exist_ok=True)
logger = create_logger(training_args.results_dir, dist.get_rank())
wandb.init(
project=training_args.wandb_project,
id=f"{training_args.wandb_name}-run{training_args.wandb_runid}",
name=training_args.wandb_name,
resume=training_args.wandb_resume,
mode="offline" if training_args.wandb_offline else "online"
)
wandb.config.update(training_args)
wandb.config.update(model_args)
wandb.config.update(data_args)
else:
logger = create_logger(None, dist.get_rank())
dist.barrier()
logger.info(f'Training arguments {training_args}')
logger.info(f'Model arguments {model_args}')
logger.info(f'Data arguments {data_args}')
# prepare auto resume logic:
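    # Precedence: with auto_resume, the newest checkpoint in checkpoint_dir is
    # preferred and is always resumed in full; only when none is found (or when
    # auto_resume is off) do resume_from / resume_model_only apply, and EMA
    # weights are loaded only when resuming model weights alone.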
if training_args.auto_resume:
resume_from = get_latest_ckpt(training_args.checkpoint_dir)
if resume_from is None:
resume_from = training_args.resume_from
resume_model_only = training_args.resume_model_only
if resume_model_only:
finetune_from_ema = training_args.finetune_from_ema
else:
finetune_from_ema = False
else:
resume_model_only = False
finetune_from_ema = False
else:
resume_from = training_args.resume_from
resume_model_only = training_args.resume_model_only
if resume_model_only:
finetune_from_ema = training_args.finetune_from_ema
else:
finetune_from_ema = False
# Set seed:
seed = training_args.global_seed * dist.get_world_size() + dist.get_rank()
set_seed(seed)
# Setup model:
llm_config = Qwen2Config.from_pretrained(model_args.llm_path)
llm_config.layer_module = model_args.layer_module
llm_config.qk_norm = model_args.llm_qk_norm
llm_config.tie_word_embeddings = model_args.tie_word_embeddings
llm_config.freeze_und = training_args.freeze_und
language_model = Qwen2ForCausalLM.from_pretrained(model_args.llm_path, config=llm_config)
if training_args.copy_init_moe:
language_model.init_moe()
if training_args.visual_und:
vit_config = SiglipVisionConfig.from_pretrained(model_args.vit_path)
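        # Truncate the ViT so its final hidden state is the selected feature layer:
        # with vit_select_layer=-2 this drops the last encoder block, avoiding the
        # need to return intermediate hidden states.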
vit_config.num_hidden_layers = vit_config.num_hidden_layers + 1 + model_args.vit_select_layer
vit_config.rope = model_args.vit_rope
vit_model = SiglipVisionModel.from_pretrained(model_args.vit_path, config=vit_config)
if training_args.visual_gen:
vae_model, vae_config = load_ae(local_path=model_args.vae_path)
config = BagelConfig(
visual_gen=training_args.visual_gen,
visual_und=training_args.visual_und,
llm_config=llm_config,
vit_config=vit_config if training_args.visual_und else None,
vae_config=vae_config if training_args.visual_gen else None,
latent_patch_size=model_args.latent_patch_size,
max_latent_size=model_args.max_latent_size,
vit_max_num_patch_per_side=model_args.vit_max_num_patch_per_side,
connector_act=model_args.connector_act,
interpolate_pos=model_args.interpolate_pos,
timestep_shift=training_args.timestep_shift,
)
model = Bagel(
language_model,
vit_model if training_args.visual_und else None,
config
)
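    # NaViT-style patching: replace SigLIP's conv2d patch embedding with an
    # equivalent linear layer so variable-resolution patch sequences can be
    # embedded directly.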
if training_args.visual_und:
model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config)
# Setup tokenizer for model:
tokenizer = Qwen2Tokenizer.from_pretrained(model_args.llm_path)
tokenizer, new_token_ids, num_new_tokens = add_special_tokens(tokenizer)
if num_new_tokens > 0:
model.language_model.resize_token_embeddings(len(tokenizer))
model.config.llm_config.vocab_size = len(tokenizer)
model.language_model.config.vocab_size = len(tokenizer)
# maybe freeze something:
if training_args.freeze_vae and training_args.visual_gen:
for param in vae_model.parameters():
param.requires_grad = False
if training_args.freeze_llm:
model.language_model.eval()
for param in model.language_model.parameters():
param.requires_grad = False
if training_args.freeze_vit and training_args.visual_und:
model.vit_model.eval()
for param in model.vit_model.parameters():
param.requires_grad = False
# Setup FSDP and load pretrained model:
fsdp_config = FSDPConfig(
sharding_strategy=training_args.sharding_strategy,
backward_prefetch=training_args.backward_prefetch,
cpu_offload=training_args.cpu_offload,
num_replicate=training_args.num_replicate,
num_shard=training_args.num_shard,
)
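    # The EMA copy is created before FSDP wrapping so both models share the same
    # parameter layout; checkpoints (if any) are loaded into the unwrapped models.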
ema_model = deepcopy(model)
model, ema_model = FSDPCheckpoint.try_load_ckpt(
resume_from, logger, model, ema_model, resume_from_ema=finetune_from_ema
)
ema_model = fsdp_ema_setup(ema_model, fsdp_config)
fsdp_model = fsdp_wrapper(model, fsdp_config)
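    # Non-reentrant activation checkpointing is applied to the submodules selected
    # by grad_checkpoint_check_fn, trading recomputation for activation memory.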
apply_activation_checkpointing(
fsdp_model,
checkpoint_wrapper_fn=functools.partial(
checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
),
check_fn=grad_checkpoint_check_fn
)
if dist.get_rank() == 0:
print(fsdp_model)
for name, param in model.named_parameters():
print(name, param.requires_grad)
# Setup optimizer and scheduler
optimizer = torch.optim.AdamW(
fsdp_model.parameters(),
lr=training_args.lr,
betas=(training_args.beta1, training_args.beta2),
eps=training_args.eps,
weight_decay=0
)
if training_args.lr_scheduler == 'cosine':
scheduler = get_cosine_with_min_lr_schedule_with_warmup(
optimizer=optimizer,
num_warmup_steps=training_args.warmup_steps,
num_training_steps=training_args.total_steps,
min_lr=training_args.min_lr,
)
elif training_args.lr_scheduler == 'constant':
scheduler = get_constant_schedule_with_warmup(
optimizer=optimizer, num_warmup_steps=training_args.warmup_steps
)
else:
        raise ValueError(f"Unsupported lr_scheduler: {training_args.lr_scheduler!r}")
# maybe resume optimizer, scheduler, and train_steps
if resume_model_only:
train_step = 0
data_status = None
else:
optimizer, scheduler, train_step, data_status = FSDPCheckpoint.try_load_train_state(
resume_from, optimizer, scheduler, fsdp_config,
)
# Setup packed dataloader
with open(data_args.dataset_config_file, "r") as stream:
dataset_meta = yaml.safe_load(stream)
dataset_config = DataConfig(grouped_datasets=dataset_meta)
if training_args.visual_und:
dataset_config.vit_patch_size = model_args.vit_patch_size
dataset_config.max_num_patch_per_side = model_args.vit_max_num_patch_per_side
if training_args.visual_gen:
vae_image_downsample = model_args.latent_patch_size * vae_config.downsample
dataset_config.vae_image_downsample = vae_image_downsample
dataset_config.max_latent_size = model_args.max_latent_size
dataset_config.text_cond_dropout_prob = model_args.text_cond_dropout_prob
dataset_config.vae_cond_dropout_prob = model_args.vae_cond_dropout_prob
dataset_config.vit_cond_dropout_prob = model_args.vit_cond_dropout_prob
train_dataset = PackedDataset(
dataset_config,
tokenizer=tokenizer,
special_tokens=new_token_ids,
local_rank=dist.get_rank(),
world_size=dist.get_world_size(),
num_workers=data_args.num_workers,
expected_num_tokens=training_args.expected_num_tokens,
max_num_tokens_per_sample=data_args.max_num_tokens_per_sample,
max_num_tokens=data_args.max_num_tokens,
max_buffer_size=data_args.max_buffer_size,
prefer_buffer_before=data_args.prefer_buffer_before,
interpolate_pos=model_args.interpolate_pos,
use_flex=training_args.use_flex,
data_status=data_status,
)
train_dataset.set_epoch(data_args.data_seed)
train_loader = DataLoader(
train_dataset,
        batch_size=1,  # samples are already token-packed by PackedDataset, so each batch holds exactly one pack
num_workers=data_args.num_workers,
pin_memory=True,
collate_fn=collate_wrapper(),
drop_last=True,
prefetch_factor=data_args.prefetch_factor,
)
# Prepare models for training:
if training_args.visual_gen:
vae_model.to(device).eval()
fsdp_model.train()
ema_model.eval()
# train loop
start_time = time()
logger.info(f"Training for {training_args.total_steps} steps, starting at {train_step}...")
for curr_step, data in enumerate(train_loader, start=train_step):
data = data.cuda(device).to_dict()
data_indexes = data.pop('batch_data_indexes', None)
ce_loss_weights = data.pop('ce_loss_weights', None)
with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
if training_args.visual_gen:
with torch.no_grad():
data['padded_latent'] = vae_model.encode(data.pop('padded_images'))
loss_dict = fsdp_model(**data)
loss = 0
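            # CE is normalised by the *global* token count: losses are summed per
            # rank, scaled by world_size, and divided by the all-reduced token
            # total, so the per-token mean matches what a single process would
            # compute on the combined batch once gradients are averaged.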
ce = loss_dict["ce"]
if ce is not None:
total_ce_tokens = torch.tensor(len(data['ce_loss_indexes']), device=device)
dist.all_reduce(total_ce_tokens, op=dist.ReduceOp.SUM)
if training_args.ce_loss_reweighting:
ce = ce * ce_loss_weights
total_ce_loss_weights = ce_loss_weights.sum()
dist.all_reduce(total_ce_loss_weights, op=dist.ReduceOp.SUM)
ce = ce.sum() * dist.get_world_size() / total_ce_loss_weights
else:
ce = ce.sum() * dist.get_world_size() / total_ce_tokens
loss_dict["ce"] = ce.detach()
loss = loss + ce * training_args.ce_weight
else:
assert not training_args.visual_und
loss_dict["ce"] = torch.tensor(0, device=device)
total_ce_tokens = torch.tensor(0, device=device)
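            # The diffusion MSE term gets the same global normalisation, but over
            # the number of MSE (latent) tokens instead of text tokens.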
if training_args.visual_gen:
mse = loss_dict["mse"]
total_mse_tokens = torch.tensor(len(data['mse_loss_indexes']), device=device)
dist.all_reduce(total_mse_tokens, op=dist.ReduceOp.SUM)
mse = mse.mean(dim=-1).sum() * dist.get_world_size() / total_mse_tokens
loss_dict["mse"] = mse.detach()
loss = loss + mse * training_args.mse_weight
else:
assert not training_args.visual_gen
loss_dict["mse"] = torch.tensor(0, device=device)
total_mse_tokens = torch.tensor(0, device=device)
optimizer.zero_grad()
loss.backward()
total_norm = fsdp_model.clip_grad_norm_(training_args.max_grad_norm)
optimizer.step()
scheduler.step()
fsdp_ema_update(ema_model, fsdp_model, decay=training_args.ema)
# Log loss values:
if curr_step % training_args.log_every == 0:
total_samples = torch.tensor(len(data['sample_lens']), device=device)
dist.all_reduce(total_samples, op=dist.ReduceOp.SUM)
# Measure training speed:
torch.cuda.synchronize()
end_time = time()
steps_per_sec = training_args.log_every / (end_time - start_time)
message = f"(step={curr_step:07d}) "
wandb_log = {}
for key, value in loss_dict.items():
# Reduce loss history over all processes:
avg_loss = torch.tensor(value.item(), device=device)
dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
avg_loss = avg_loss.item() / dist.get_world_size()
message += f"Train Loss {key}: {avg_loss:.4f}, "
wandb_log[key] = avg_loss
message += f"Train Steps/Sec: {steps_per_sec:.2f}, "
logger.info(message)
wandb_log['lr'] = optimizer.param_groups[0]['lr']
wandb_log['total_mse_tokens'] = total_mse_tokens.item()
wandb_log['total_ce_tokens'] = total_ce_tokens.item()
wandb_log['total_norm'] = total_norm.item()
wandb_log['total_samples'] = total_samples.item()
mem_allocated = torch.tensor(torch.cuda.max_memory_allocated() / 1024**2, device=device)
dist.all_reduce(mem_allocated, op=dist.ReduceOp.MAX)
wandb_log['mem_allocated'] = mem_allocated
mem_cache = torch.tensor(torch.cuda.max_memory_reserved() / 1024**2, device=device)
dist.all_reduce(mem_cache, op=dist.ReduceOp.MAX)
wandb_log['mem_cache'] = mem_cache
if dist.get_rank() == 0:
wandb.log(wandb_log, step=curr_step)
start_time = time()
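        # Track dataloader progress per dataset and per worker so a resumed run
        # can restore each worker's position in the data stream.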
if data_status is None:
data_status = {}
for item in data_indexes:
if item['dataset_name'] not in data_status.keys():
data_status[item['dataset_name']] = {}
data_status[item['dataset_name']][item['worker_id']] = item['data_indexes']
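        # Checkpointing: rank 0 gathers every rank's data_status so the saved
        # checkpoint can restore each rank's dataloader position on resume.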
if curr_step > 0 and curr_step % training_args.save_every == 0:
if dist.get_rank() == 0:
gather_list = [None] * dist.get_world_size()
else:
gather_list = None
dist.gather_object(data_status, gather_list, dst=0)
FSDPCheckpoint.fsdp_save_ckpt(
ckpt_dir=training_args.checkpoint_dir,
train_steps=curr_step,
model=fsdp_model,
ema_model=ema_model,
optimizer=optimizer,
scheduler=scheduler,
logger=logger,
fsdp_config=fsdp_config,
data_status=gather_list
)
logger.info("Done!")
if dist.get_rank() == 0:
wandb.finish()
dist.destroy_process_group()
if __name__ == "__main__":
main()