# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import functools
import os
import wandb
import yaml
from copy import deepcopy
from dataclasses import dataclass, field
from time import time

import torch
import torch.distributed as dist
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torch.utils.data import DataLoader
from transformers import HfArgumentParser, set_seed
from transformers.optimization import (
    get_constant_schedule_with_warmup,
    get_cosine_with_min_lr_schedule_with_warmup,
)

from data.dataset_base import DataConfig, PackedDataset, collate_wrapper
from data.data_utils import add_special_tokens
from modeling.autoencoder import load_ae
from modeling.bagel import (
    BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM, SiglipVisionConfig, SiglipVisionModel
)
from modeling.qwen2 import Qwen2Tokenizer
from train.train_utils import create_logger, get_latest_ckpt
from train.fsdp_utils import (
    FSDPCheckpoint, FSDPConfig, grad_checkpoint_check_fn, fsdp_wrapper,
    fsdp_ema_setup, fsdp_ema_update,
)
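

# The argument groups below are plain dataclasses; HfArgumentParser exposes their fields
# as command-line flags and parses them together in main().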
@dataclass
class ModelArguments:
    llm_path: str = field(
        default="hf/Qwen2.5-0.5B-Instruct/",
        metadata={"help": "Path or HuggingFace repo ID of the pretrained Qwen2-style language model."}
    )
    llm_qk_norm: bool = field(
        default=True,
        metadata={"help": "Enable QK LayerNorm (qk_norm) inside the attention blocks."}
    )
    tie_word_embeddings: bool = field(
        default=False,
        metadata={"help": "Share input and output word embeddings (tied embeddings)."}
    )
    layer_module: str = field(
        default="Qwen2DecoderLayer",
        metadata={"help": "Python class name of the decoder layer to instantiate."}
    )
    vae_path: str = field(
        default="flux/vae/ae.safetensors",
        metadata={"help": "Path to the pretrained VAE checkpoint for latent-space image generation."}
    )
    vit_path: str = field(
        default="hf/siglip-so400m-14-980-flash-attn2-navit/",
        metadata={"help": "Path or repo ID of the SigLIP Vision Transformer used for image understanding."}
    )
    max_latent_size: int = field(
        default=32,
        metadata={"help": "Maximum latent grid size (patches per side) for the VAE latent tensor."}
    )
    latent_patch_size: int = field(
        default=2,
        metadata={"help": "Spatial size (in VAE latent cells) covered by each latent patch."}
    )
    vit_patch_size: int = field(
        default=14,
        metadata={"help": "Patch size (pixels) for the Vision Transformer encoder."}
    )
    vit_max_num_patch_per_side: int = field(
        default=70,
        metadata={"help": "Maximum number of ViT patches along one image side after cropping / resize."}
    )
    connector_act: str = field(
        default="gelu_pytorch_tanh",
        metadata={"help": "Activation function used in the latent-to-text connector MLP."}
    )
    interpolate_pos: bool = field(
        default=False,
        metadata={"help": "Interpolate positional embeddings when image resolution differs from pre-training."}
    )
    vit_select_layer: int = field(
        default=-2,
        metadata={"help": "Which hidden layer of the ViT to take as the visual feature (negative = from the end)."}
    )
    vit_rope: bool = field(
        default=False,
        metadata={"help": "Replace ViT positional encodings with RoPE."}
    )
    text_cond_dropout_prob: float = field(
        default=0.1,
        metadata={"help": "Probability of dropping text conditioning during training."}
    )
    vae_cond_dropout_prob: float = field(
        default=0.3,
        metadata={"help": "Probability of dropping VAE latent conditioning during training."}
    )
    vit_cond_dropout_prob: float = field(
        default=0.3,
        metadata={"help": "Probability of dropping ViT visual-feature conditioning during training."}
    )


@dataclass
class DataArguments:
    dataset_config_file: str = field(
        default="data/configs/example.yaml",
        metadata={"help": "YAML file specifying dataset groups, weights, and preprocessing rules."}
    )
    prefetch_factor: int = field(
        default=2,
        metadata={"help": "How many batches each DataLoader worker pre-loads in advance."}
    )
    num_workers: int = field(
        default=4,
        metadata={"help": "Number of background workers for the PyTorch DataLoader."}
    )
    max_num_tokens_per_sample: int = field(
        default=16384,
        metadata={"help": "Maximum tokens allowed in one raw sample; longer samples are skipped."}
    )
    max_num_tokens: int = field(
        default=36864,
        metadata={"help": "Hard limit on tokens in a packed batch; flush if adding a sample would exceed it."}
    )
    prefer_buffer_before: int = field(
        default=16384,
        metadata={"help": "While batch length is below this, pop from the overflow buffer before new sampling."}
    )
    max_buffer_size: int = field(
        default=50,
        metadata={"help": "Maximum number of oversized samples kept in the overflow buffer."}
    )
    data_seed: int = field(
        default=42,
        metadata={"help": "Seed used when shuffling / sampling data shards to ensure reproducibility."}
    )


@dataclass
class TrainingArguments:
    # --- modality switches ---
    visual_gen: bool = field(
        default=True,
        metadata={"help": "Train image generation branch."}
    )
    visual_und: bool = field(
        default=True,
        metadata={"help": "Train image understanding branch."}
    )
    # --- bookkeeping & logging ---
    results_dir: str = field(
        default="results",
        metadata={"help": "Root directory for logs."}
    )
    checkpoint_dir: str = field(
        default="results/checkpoints",
        metadata={"help": "Root directory for model checkpoints."}
    )
    wandb_project: str = field(
        default="bagel",
        metadata={"help": "Weights & Biases project name."}
    )
    wandb_name: str = field(
        default="run",
        metadata={"help": "Name shown in the Weights & Biases UI for this run."}
    )
    wandb_runid: str = field(
        default="0",
        metadata={"help": "Unique identifier to resume a previous W&B run, if desired."}
    )
    wandb_resume: str = field(
        default="allow",
        metadata={"help": "W&B resume mode: 'allow', 'must', or 'never'."}
    )
    wandb_offline: bool = field(
        default=False,
        metadata={"help": "Run W&B in offline mode (logs locally, sync later)."}
    )
    # --- reproducibility & resume ---
    global_seed: int = field(
        default=4396,
        metadata={"help": "Base random seed; actual seed is offset by rank for DDP."}
    )
    auto_resume: bool = field(
        default=False,
        metadata={"help": "Automatically pick up the latest checkpoint found in checkpoint_dir."}
    )
    resume_from: str = field(
        default=None,
        metadata={"help": "Explicit checkpoint path to resume from (overrides auto_resume)."}
    )
    resume_model_only: bool = field(
        default=False,
        metadata={"help": "Load only model weights, ignoring optimizer/scheduler states."}
    )
    finetune_from_ema: bool = field(
        default=False,
        metadata={"help": "When resume_model_only=True, load the EMA (exponential moving average) weights instead of raw weights."}
    )
    # --- reporting frequency ---
    log_every: int = field(
        default=10,
        metadata={"help": "Print / log every N training steps."}
    )
    save_every: int = field(
        default=2000,
        metadata={"help": "Save a checkpoint every N training steps."}
    )
    total_steps: int = field(
        default=500_000,
        metadata={"help": "Total number of optimizer steps to train for."}
    )
    # --- optimization & scheduler ---
    warmup_steps: int = field(
        default=2000,
        metadata={"help": "Linear warm-up steps before applying the main LR schedule."}
    )
    lr_scheduler: str = field(
        default="constant",
        metadata={"help": "Type of LR schedule: 'constant' or 'cosine'."}
    )
    lr: float = field(
        default=1e-4,
        metadata={"help": "Peak learning rate after warm-up."}
    )
    min_lr: float = field(
        default=1e-7,
        metadata={"help": "Minimum learning rate for cosine schedule (ignored for constant)."}
    )
    beta1: float = field(
        default=0.9,
        metadata={"help": "AdamW β₁ coefficient."}
    )
    beta2: float = field(
        default=0.95,
        metadata={"help": "AdamW β₂ coefficient."}
    )
    eps: float = field(
        default=1e-15,
        metadata={"help": "AdamW ε for numerical stability."}
    )
    ema: float = field(
        default=0.9999,
        metadata={"help": "Decay rate for the exponential moving average of model weights."}
    )
    max_grad_norm: float = field(
        default=1.0,
        metadata={"help": "Gradient clipping threshold (L2 norm)."}
    )
    timestep_shift: float = field(
        default=1.0,
        metadata={"help": "Shift applied to diffusion timestep indices (for latent prediction)."}
    )
    mse_weight: float = field(
        default=1.0,
        metadata={"help": "Scaling factor for the image-reconstruction MSE loss term."}
    )
    ce_weight: float = field(
        default=1.0,
        metadata={"help": "Scaling factor for the language cross-entropy loss term."}
    )
    ce_loss_reweighting: bool = field(
        default=False,
        metadata={"help": "Reweight CE loss by token importance (provided via ce_loss_weights)."}
    )
    expected_num_tokens: int = field(
        default=32768,
        metadata={"help": "Soft target token count; yield the batch once it reaches or exceeds this size."}
    )
    # --- distributed training / FSDP ---
    num_replicate: int = field(
        default=1,
        metadata={"help": "Number of data-parallel replica groups when using FSDP HYBRID_SHARD."}
    )
    num_shard: int = field(
        default=8,
        metadata={"help": "Number of parameter shards when using FSDP HYBRID_SHARD."}
    )
    sharding_strategy: str = field(
        default="HYBRID_SHARD",
        metadata={"help": "FSDP sharding strategy: FULL_SHARD, SHARD_GRAD_OP, HYBRID_SHARD, etc."}
    )
    backward_prefetch: str = field(
        default="BACKWARD_PRE",
        metadata={"help": "FSDP backward prefetch strategy (BACKWARD_PRE or NO_PREFETCH)."}
    )
    cpu_offload: bool = field(
        default=False,
        metadata={"help": "Enable FSDP parameter offload to CPU."}
    )
    # --- module freezing ---
    freeze_llm: bool = field(
        default=False,
        metadata={"help": "Keep language-model weights fixed (no gradient updates)."}
    )
    freeze_vit: bool = field(
        default=False,
        metadata={"help": "Keep ViT weights fixed during training."}
    )
    freeze_vae: bool = field(
        default=True,
        metadata={"help": "Keep VAE weights fixed; only predict latents, don’t fine-tune encoder/decoder."}
    )
    freeze_und: bool = field(
        default=False,
        metadata={"help": "Freeze the visual understanding connector layers."}
    )
    copy_init_moe: bool = field(
        default=True,
        metadata={"help": "Duplicate initial MoE experts so each has identical initialisation."}
    )
    use_flex: bool = field(
        default=False,
        metadata={"help": "Enable FLEX (flash-ext friendly) packing algorithm for sequence data."}
    )
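

# Illustrative single-node launch (assumes 8 GPUs and that this file is the entry point;
# the exact script path and flags depend on your setup):
#   torchrun --nproc_per_node=8 <this_script>.py \
#       --dataset_config_file data/configs/example.yaml \
#       --results_dir results --checkpoint_dir results/checkpoints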
def main():
    assert torch.cuda.is_available()
    dist.init_process_group("nccl")
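    # One process per GPU: the local device index is the global rank modulo the number of
    # visible GPUs on this node.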
    device = dist.get_rank() % torch.cuda.device_count()
    torch.cuda.set_device(device)

    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Setup logging:
    if dist.get_rank() == 0:
        os.makedirs(training_args.results_dir, exist_ok=True)
        os.makedirs(training_args.checkpoint_dir, exist_ok=True)
        logger = create_logger(training_args.results_dir, dist.get_rank())
        wandb.init(
            project=training_args.wandb_project,
            id=f"{training_args.wandb_name}-run{training_args.wandb_runid}",
            name=training_args.wandb_name,
            resume=training_args.wandb_resume,
            mode="offline" if training_args.wandb_offline else "online",
        )
        wandb.config.update(training_args)
        wandb.config.update(model_args)
        wandb.config.update(data_args)
    else:
        logger = create_logger(None, dist.get_rank())
    dist.barrier()

    logger.info(f'Training arguments {training_args}')
    logger.info(f'Model arguments {model_args}')
    logger.info(f'Data arguments {data_args}')

    # Prepare auto-resume logic:
    if training_args.auto_resume:
        resume_from = get_latest_ckpt(training_args.checkpoint_dir)
        if resume_from is None:
            resume_from = training_args.resume_from
            resume_model_only = training_args.resume_model_only
            if resume_model_only:
                finetune_from_ema = training_args.finetune_from_ema
            else:
                finetune_from_ema = False
        else:
            resume_model_only = False
            finetune_from_ema = False
    else:
        resume_from = training_args.resume_from
        resume_model_only = training_args.resume_model_only
        if resume_model_only:
            finetune_from_ema = training_args.finetune_from_ema
        else:
            finetune_from_ema = False

    # Set seed:
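    # Every rank derives a distinct seed from the base seed so that dropout and data
    # sampling differ across processes while staying reproducible.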
    seed = training_args.global_seed * dist.get_world_size() + dist.get_rank()
    set_seed(seed)

    # Setup model:
    llm_config = Qwen2Config.from_pretrained(model_args.llm_path)
    llm_config.layer_module = model_args.layer_module
    llm_config.qk_norm = model_args.llm_qk_norm
    llm_config.tie_word_embeddings = model_args.tie_word_embeddings
    llm_config.freeze_und = training_args.freeze_und
    language_model = Qwen2ForCausalLM.from_pretrained(model_args.llm_path, config=llm_config)
    if training_args.copy_init_moe:
        language_model.init_moe()

    if training_args.visual_und:
        vit_config = SiglipVisionConfig.from_pretrained(model_args.vit_path)
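        # Truncate the ViT so its last kept layer is the selected feature layer
        # (vit_select_layer is negative, counted from the end).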
        vit_config.num_hidden_layers = vit_config.num_hidden_layers + 1 + model_args.vit_select_layer
        vit_config.rope = model_args.vit_rope
        vit_model = SiglipVisionModel.from_pretrained(model_args.vit_path, config=vit_config)

    if training_args.visual_gen:
        vae_model, vae_config = load_ae(local_path=model_args.vae_path)

    config = BagelConfig(
        visual_gen=training_args.visual_gen,
        visual_und=training_args.visual_und,
        llm_config=llm_config,
        vit_config=vit_config if training_args.visual_und else None,
        vae_config=vae_config if training_args.visual_gen else None,
        latent_patch_size=model_args.latent_patch_size,
        max_latent_size=model_args.max_latent_size,
        vit_max_num_patch_per_side=model_args.vit_max_num_patch_per_side,
        connector_act=model_args.connector_act,
        interpolate_pos=model_args.interpolate_pos,
        timestep_shift=training_args.timestep_shift,
    )
    model = Bagel(
        language_model,
        vit_model if training_args.visual_und else None,
        config
    )
    if training_args.visual_und:
        model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config)

    # Setup tokenizer for model:
    tokenizer = Qwen2Tokenizer.from_pretrained(model_args.llm_path)
    tokenizer, new_token_ids, num_new_tokens = add_special_tokens(tokenizer)
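    # Newly added special tokens require enlarging the embedding table and keeping the
    # config vocab sizes in sync.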
    if num_new_tokens > 0:
        model.language_model.resize_token_embeddings(len(tokenizer))
        model.config.llm_config.vocab_size = len(tokenizer)
        model.language_model.config.vocab_size = len(tokenizer)

    # Maybe freeze something:
    if training_args.freeze_vae and training_args.visual_gen:
        for param in vae_model.parameters():
            param.requires_grad = False
    if training_args.freeze_llm:
        model.language_model.eval()
        for param in model.language_model.parameters():
            param.requires_grad = False
    if training_args.freeze_vit and training_args.visual_und:
        model.vit_model.eval()
        for param in model.vit_model.parameters():
            param.requires_grad = False

    # Setup FSDP and load pretrained model:
    fsdp_config = FSDPConfig(
        sharding_strategy=training_args.sharding_strategy,
        backward_prefetch=training_args.backward_prefetch,
        cpu_offload=training_args.cpu_offload,
        num_replicate=training_args.num_replicate,
        num_shard=training_args.num_shard,
    )
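    # The EMA copy is created (and any checkpoint loaded) before FSDP wrapping, so the online
    # model and its EMA shadow end up sharded consistently.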
    ema_model = deepcopy(model)
    model, ema_model = FSDPCheckpoint.try_load_ckpt(
        resume_from, logger, model, ema_model, resume_from_ema=finetune_from_ema
    )
    ema_model = fsdp_ema_setup(ema_model, fsdp_config)
    fsdp_model = fsdp_wrapper(model, fsdp_config)
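    # Non-reentrant activation checkpointing is applied to the submodules selected by
    # grad_checkpoint_check_fn, trading extra backward-pass compute for memory.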
    apply_activation_checkpointing(
        fsdp_model,
        checkpoint_wrapper_fn=functools.partial(
            checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
        ),
        check_fn=grad_checkpoint_check_fn
    )

    if dist.get_rank() == 0:
        print(fsdp_model)
        for name, param in model.named_parameters():
            print(name, param.requires_grad)

    # Setup optimizer and scheduler:
    optimizer = torch.optim.AdamW(
        fsdp_model.parameters(),
        lr=training_args.lr,
        betas=(training_args.beta1, training_args.beta2),
        eps=training_args.eps,
        weight_decay=0
    )
    if training_args.lr_scheduler == 'cosine':
        scheduler = get_cosine_with_min_lr_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=training_args.warmup_steps,
            num_training_steps=training_args.total_steps,
            min_lr=training_args.min_lr,
        )
    elif training_args.lr_scheduler == 'constant':
        scheduler = get_constant_schedule_with_warmup(
            optimizer=optimizer, num_warmup_steps=training_args.warmup_steps
        )
    else:
        raise ValueError(f"Unsupported lr_scheduler: {training_args.lr_scheduler}")

    # Maybe resume optimizer, scheduler, and train_steps:
    if resume_model_only:
        train_step = 0
        data_status = None
    else:
        optimizer, scheduler, train_step, data_status = FSDPCheckpoint.try_load_train_state(
            resume_from, optimizer, scheduler, fsdp_config,
        )

    # Setup packed dataloader:
    with open(data_args.dataset_config_file, "r") as stream:
        dataset_meta = yaml.safe_load(stream)
    dataset_config = DataConfig(grouped_datasets=dataset_meta)
    if training_args.visual_und:
        dataset_config.vit_patch_size = model_args.vit_patch_size
        dataset_config.max_num_patch_per_side = model_args.vit_max_num_patch_per_side
    if training_args.visual_gen:
        vae_image_downsample = model_args.latent_patch_size * vae_config.downsample
        dataset_config.vae_image_downsample = vae_image_downsample
        dataset_config.max_latent_size = model_args.max_latent_size
        dataset_config.text_cond_dropout_prob = model_args.text_cond_dropout_prob
        dataset_config.vae_cond_dropout_prob = model_args.vae_cond_dropout_prob
        dataset_config.vit_cond_dropout_prob = model_args.vit_cond_dropout_prob
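
    # PackedDataset concatenates variable-length samples into token-budgeted batches, so the
    # DataLoader below uses batch_size=1 (one packed batch per iteration).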
    train_dataset = PackedDataset(
        dataset_config,
        tokenizer=tokenizer,
        special_tokens=new_token_ids,
        local_rank=dist.get_rank(),
        world_size=dist.get_world_size(),
        num_workers=data_args.num_workers,
        expected_num_tokens=training_args.expected_num_tokens,
        max_num_tokens_per_sample=data_args.max_num_tokens_per_sample,
        max_num_tokens=data_args.max_num_tokens,
        max_buffer_size=data_args.max_buffer_size,
        prefer_buffer_before=data_args.prefer_buffer_before,
        interpolate_pos=model_args.interpolate_pos,
        use_flex=training_args.use_flex,
        data_status=data_status,
    )
    train_dataset.set_epoch(data_args.data_seed)
    train_loader = DataLoader(
        train_dataset,
        batch_size=1,  # batch size is 1 because samples are already packed
        num_workers=data_args.num_workers,
        pin_memory=True,
        collate_fn=collate_wrapper(),
        drop_last=True,
        prefetch_factor=data_args.prefetch_factor,
    )

    # Prepare models for training:
    if training_args.visual_gen:
        vae_model.to(device).eval()
    fsdp_model.train()
    ema_model.eval()

    # Train loop:
    start_time = time()
    logger.info(f"Training for {training_args.total_steps} steps, starting at {train_step}...")
    for curr_step, data in enumerate(train_loader, start=train_step):
        data = data.cuda(device).to_dict()
        data_indexes = data.pop('batch_data_indexes', None)
        ce_loss_weights = data.pop('ce_loss_weights', None)
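        # The forward pass runs under bfloat16 autocast; image latents are produced by the
        # VAE encoder under no_grad, so no gradients flow into it.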
        with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
            if training_args.visual_gen:
                with torch.no_grad():
                    data['padded_latent'] = vae_model.encode(data.pop('padded_images'))
            loss_dict = fsdp_model(**data)

        loss = 0
        ce = loss_dict["ce"]
        if ce is not None:
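            # Sum token losses locally, then normalize by the global token (or weight) count;
            # the world_size factor cancels gradient averaging across ranks, so the effective
            # loss is a true per-token mean over all processes.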
            total_ce_tokens = torch.tensor(len(data['ce_loss_indexes']), device=device)
            dist.all_reduce(total_ce_tokens, op=dist.ReduceOp.SUM)
            if training_args.ce_loss_reweighting:
                ce = ce * ce_loss_weights
                total_ce_loss_weights = ce_loss_weights.sum()
                dist.all_reduce(total_ce_loss_weights, op=dist.ReduceOp.SUM)
                ce = ce.sum() * dist.get_world_size() / total_ce_loss_weights
            else:
                ce = ce.sum() * dist.get_world_size() / total_ce_tokens
            loss_dict["ce"] = ce.detach()
            loss = loss + ce * training_args.ce_weight
        else:
            assert not training_args.visual_und
            loss_dict["ce"] = torch.tensor(0, device=device)
            total_ce_tokens = torch.tensor(0, device=device)

        if training_args.visual_gen:
            mse = loss_dict["mse"]
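            # Same normalization for the diffusion MSE: mean over latent channels, then a
            # per-token average over all MSE tokens across ranks.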
            total_mse_tokens = torch.tensor(len(data['mse_loss_indexes']), device=device)
            dist.all_reduce(total_mse_tokens, op=dist.ReduceOp.SUM)
            mse = mse.mean(dim=-1).sum() * dist.get_world_size() / total_mse_tokens
            loss_dict["mse"] = mse.detach()
            loss = loss + mse * training_args.mse_weight
        else:
            assert not training_args.visual_gen
            loss_dict["mse"] = torch.tensor(0, device=device)
            total_mse_tokens = torch.tensor(0, device=device)

        optimizer.zero_grad()
        loss.backward()
        total_norm = fsdp_model.clip_grad_norm_(training_args.max_grad_norm)
        optimizer.step()
        scheduler.step()
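        # Update the EMA shadow weights after each optimizer step.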
        fsdp_ema_update(ema_model, fsdp_model, decay=training_args.ema)

        # Log loss values:
        if curr_step % training_args.log_every == 0:
            total_samples = torch.tensor(len(data['sample_lens']), device=device)
            dist.all_reduce(total_samples, op=dist.ReduceOp.SUM)

            # Measure training speed:
            torch.cuda.synchronize()
            end_time = time()
            steps_per_sec = training_args.log_every / (end_time - start_time)

            message = f"(step={curr_step:07d}) "
            wandb_log = {}
            for key, value in loss_dict.items():
                # Reduce loss history over all processes:
                avg_loss = torch.tensor(value.item(), device=device)
                dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
                avg_loss = avg_loss.item() / dist.get_world_size()
                message += f"Train Loss {key}: {avg_loss:.4f}, "
                wandb_log[key] = avg_loss
            message += f"Train Steps/Sec: {steps_per_sec:.2f}, "
            logger.info(message)

            wandb_log['lr'] = optimizer.param_groups[0]['lr']
            wandb_log['total_mse_tokens'] = total_mse_tokens.item()
            wandb_log['total_ce_tokens'] = total_ce_tokens.item()
            wandb_log['total_norm'] = total_norm.item()
            wandb_log['total_samples'] = total_samples.item()

            mem_allocated = torch.tensor(torch.cuda.max_memory_allocated() / 1024**2, device=device)
            dist.all_reduce(mem_allocated, op=dist.ReduceOp.MAX)
            wandb_log['mem_allocated'] = mem_allocated
            mem_cache = torch.tensor(torch.cuda.max_memory_reserved() / 1024**2, device=device)
            dist.all_reduce(mem_cache, op=dist.ReduceOp.MAX)
            wandb_log['mem_cache'] = mem_cache

            if dist.get_rank() == 0:
                wandb.log(wandb_log, step=curr_step)
            start_time = time()

        if data_status is None:
            data_status = {}
        for item in data_indexes:
            if item['dataset_name'] not in data_status.keys():
                data_status[item['dataset_name']] = {}
            data_status[item['dataset_name']][item['worker_id']] = item['data_indexes']

        if curr_step > 0 and curr_step % training_args.save_every == 0:
            if dist.get_rank() == 0:
                gather_list = [None] * dist.get_world_size()
            else:
                gather_list = None
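            # Every rank contributes its dataloader progress; rank 0 receives the gathered list
            # and stores it in the checkpoint so data loading can resume where it left off.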
            dist.gather_object(data_status, gather_list, dst=0)

            FSDPCheckpoint.fsdp_save_ckpt(
                ckpt_dir=training_args.checkpoint_dir,
                train_steps=curr_step,
                model=fsdp_model,
                ema_model=ema_model,
                optimizer=optimizer,
                scheduler=scheduler,
                logger=logger,
                fsdp_config=fsdp_config,
                data_status=gather_list
            )

    logger.info("Done!")
    if dist.get_rank() == 0:
        wandb.finish()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()