# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import functools
import os

import torch
import torch.distributed as dist
import torch.distributed.fsdp._traversal_utils as traversal_utils
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import (
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
    FullStateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from safetensors.torch import load_file, save_file

from modeling.bagel.modeling_utils import MLPconnector, TimestepEmbedder, PositionEmbedding
from modeling.bagel.qwen2_navit import (
    Qwen2DecoderLayer,
    Qwen2MoEDecoderLayer,
    Qwen2MoTDecoderLayer,
)
from modeling.bagel.siglip_navit import SiglipEncoderLayer, SiglipVisionTransformer


class FSDPConfig:
    def __init__(
        self,
        sharding_strategy,
        backward_prefetch,
        cpu_offload,
        num_replicate,
        num_shard=8,
    ):
        self.sharding_strategy = sharding_strategy
        self.backward_prefetch = backward_prefetch
        self.cpu_offload = cpu_offload
        # For HYBRID_SHARD, the ranks form a (num_replicate, num_shard) device mesh:
        # parameters are sharded within each group of `num_shard` GPUs and replicated
        # across the `num_replicate` groups.
        self.num_replicate = num_replicate
        self.num_shard = num_shard


def fsdp_wrapper(original_model, fsdp_config, ignored_modules=[]):
    if fsdp_config.sharding_strategy == 'HYBRID_SHARD':
        device_mesh = init_device_mesh(
            "cuda",
            mesh_shape=(fsdp_config.num_replicate, fsdp_config.num_shard),
            mesh_dim_names=("replicate", "shard"),
        )
    else:
        device_mesh = None

    return FSDP(
        original_model,
        # Wrap each transformer block (and the small connector/embedding modules)
        # as its own FSDP unit.
        auto_wrap_policy=functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={
                Qwen2DecoderLayer,
                Qwen2MoEDecoderLayer,
                Qwen2MoTDecoderLayer,
                SiglipEncoderLayer,
                SiglipVisionTransformer,
                MLPconnector,
                TimestepEmbedder,
                PositionEmbedding,
            },
        ),
        ignored_modules=ignored_modules,
        # Keep parameters, gradient reduction, and buffers in bf16.
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        ),
        device_id=dist.get_rank() % torch.cuda.device_count(),
        sharding_strategy=ShardingStrategy[fsdp_config.sharding_strategy],
        backward_prefetch=BackwardPrefetch[fsdp_config.backward_prefetch],
        cpu_offload=CPUOffload(offload_params=fsdp_config.cpu_offload),
        device_mesh=device_mesh,
    )
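
# Illustrative wiring (a sketch, not executed on import): wrap a Bagel model under
# HYBRID_SHARD. Assumes torch.distributed is already initialized (e.g. via torchrun),
# that the world size is divisible by `num_shard`, and that `model` exists in the
# training script; the concrete values below are placeholders.
#
#   fsdp_config = FSDPConfig(
#       sharding_strategy="HYBRID_SHARD",
#       backward_prefetch="BACKWARD_PRE",
#       cpu_offload=False,
#       num_replicate=dist.get_world_size() // 8,
#       num_shard=8,
#   )
#   model = fsdp_wrapper(model, fsdp_config)
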

class FSDPCheckpoint:
    @staticmethod
    def fsdp_save_ckpt(
        ckpt_dir,
        train_steps,
        model,
        ema_model,
        optimizer,
        scheduler,
        data_status,
        logger,
        fsdp_config,
    ):
        save_path = os.path.join(ckpt_dir, f"{train_steps:07d}")
        os.makedirs(save_path, exist_ok=True)
        logger.info(f"Saving checkpoint to {save_path}.")

        # Model and EMA weights are gathered to rank 0 as full (unsharded) state dicts
        # and written as safetensors files.
        if ema_model is not None:
            with FSDP.state_dict_type(
                ema_model,
                StateDictType.FULL_STATE_DICT,
                FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
            ):
                ema_state_dict = ema_model.state_dict()
                if dist.get_rank() == 0:
                    save_file(ema_state_dict, os.path.join(save_path, "ema.safetensors"))

        with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
        ):
            model_state_dict = model.state_dict()
            if dist.get_rank() == 0:
                save_file(model_state_dict, os.path.join(save_path, "model.safetensors"))

        # Optimizer state stays sharded: each shard writes its own file.
        with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
            if fsdp_config.sharding_strategy == "FULL_SHARD":
                shard_index = dist.get_rank()
                total_shards = dist.get_world_size()
            elif fsdp_config.sharding_strategy == "HYBRID_SHARD":
                shard_index = dist.get_rank() % fsdp_config.num_shard
                total_shards = fsdp_config.num_shard
            else:
                raise NotImplementedError

            optimizer_save_path = os.path.join(
                save_path, f"optimizer.{shard_index:05d}-of-{total_shards:05d}.pt"
            )
            if fsdp_config.sharding_strategy == "FULL_SHARD":
                torch.save(optimizer.state_dict(), optimizer_save_path)
            elif fsdp_config.sharding_strategy == "HYBRID_SHARD":
                # Shards are identical across replica groups, so only the first
                # replica group needs to write them.
                if dist.get_rank() < fsdp_config.num_shard:
                    torch.save(optimizer.state_dict(), optimizer_save_path)
            else:
                raise NotImplementedError

        if dist.get_rank() == 0 and scheduler is not None:
            torch.save(scheduler.state_dict(), os.path.join(save_path, "scheduler.pt"))

        if dist.get_rank() == 0 and data_status is not None:
            torch.save(data_status, os.path.join(save_path, "data_status.pt"))

        dist.barrier()
        return

    @staticmethod
    def try_load_ckpt(resume_from, logger, model, ema_model=None, resume_from_ema=False):
        if resume_from is not None and os.path.exists(resume_from):
            logger.info(f"Loading checkpoint from {resume_from}.")
            if resume_from_ema:
                model_state_dict_path = os.path.join(resume_from, "ema.safetensors")
            else:
                model_state_dict_path = os.path.join(resume_from, "model.safetensors")
            model_state_dict = load_file(model_state_dict_path, device="cpu")
            # NOTE position embeds are fixed sinusoidal embeddings, so we can just pop them off,
            # which makes it easier to adapt to different resolutions.
            model_state_dict.pop('latent_pos_embed.pos_embed')
            model_state_dict.pop('vit_pos_embed.pos_embed')
            msg = model.load_state_dict(model_state_dict, strict=False)
            logger.info(msg)
            del model_state_dict

            if ema_model is not None:
                ema_state_dict_path = os.path.join(resume_from, "ema.safetensors")
                if not os.path.exists(ema_state_dict_path):
                    logger.info(f"Replicating EMA model from {model_state_dict_path}.")
                    ema_state_dict_path = model_state_dict_path
                ema_state_dict = load_file(ema_state_dict_path, device="cpu")
                # NOTE position embeds are fixed sinusoidal embeddings, so we can just pop them off,
                # which makes it easier to adapt to different resolutions.
                ema_state_dict.pop('latent_pos_embed.pos_embed')
                ema_state_dict.pop('vit_pos_embed.pos_embed')
                msg = ema_model.load_state_dict(ema_state_dict, strict=False)
                logger.info(msg)
                del ema_state_dict
        else:
            logger.info("Training from scratch.")
        return model, ema_model

    @staticmethod
    def try_load_train_state(resume_from, optimizer, scheduler, fsdp_config):
        if resume_from is not None and os.path.exists(resume_from):
            if fsdp_config.sharding_strategy == "FULL_SHARD":
                shard_index = dist.get_rank()
                total_shards = dist.get_world_size()
            elif fsdp_config.sharding_strategy == "HYBRID_SHARD":
                shard_index = dist.get_rank() % fsdp_config.num_shard
                total_shards = fsdp_config.num_shard
            else:
                raise NotImplementedError

            optimizer_state_dict_path = os.path.join(
                resume_from, f"optimizer.{shard_index:05d}-of-{total_shards:05d}.pt"
            )
            optimizer_state_dict = torch.load(
                optimizer_state_dict_path, map_location="cpu", weights_only=True
            )
            optimizer.load_state_dict(optimizer_state_dict)
            del optimizer_state_dict

            scheduler_state_dict_path = os.path.join(resume_from, "scheduler.pt")
            scheduler_state_dict = torch.load(
                scheduler_state_dict_path, weights_only=True, map_location="cpu"
            )
            scheduler.load_state_dict(scheduler_state_dict)
            del scheduler_state_dict

            # Checkpoint directories are named by step (see fsdp_save_ckpt), so resuming
            # continues from the following step.
            train_steps = int(os.path.basename(os.path.normpath(resume_from))) + 1
            """
            data_status = [
                {
                    dataset_name: {
                        worker_id: [parquet_idx, row_group_id, row_idx],
                    },
                },
            ]
            """
            data_status_path = os.path.join(resume_from, "data_status.pt")
            if os.path.exists(data_status_path):
                data_status = torch.load(data_status_path, weights_only=True, map_location="cpu")
                local_rank = dist.get_rank()
                if local_rank < len(data_status):
                    data_status = data_status[local_rank]
                else:
                    data_status = None
            else:
                data_status = None
        else:
            train_steps = 0
            data_status = None

        return optimizer, scheduler, train_steps, data_status
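
# Illustrative resume flow (a sketch, not executed on import): `resume_from` points at
# a step directory written by fsdp_save_ckpt, and `logger`, `model`, `ema_model`,
# `optimizer`, `scheduler`, `fsdp_config` are assumed to exist in the training script.
#
#   model, ema_model = FSDPCheckpoint.try_load_ckpt(resume_from, logger, model, ema_model)
#   optimizer, scheduler, train_steps, data_status = FSDPCheckpoint.try_load_train_state(
#       resume_from, optimizer, scheduler, fsdp_config,
#   )
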

def grad_checkpoint_check_fn(module):
    # Selects the modules to wrap with activation (gradient) checkpointing.
    module_options = (
        Qwen2DecoderLayer,
        SiglipEncoderLayer,
        MLPconnector,
        Qwen2MoEDecoderLayer,
        Qwen2MoTDecoderLayer,
    )
    return isinstance(module, module_options)


def fsdp_ema_setup(ema_model, fsdp_config, ignored_modules=[]):
    # The EMA copy is never trained directly; freeze it before FSDP wrapping.
    for param in ema_model.parameters():
        param.requires_grad = False
    ema_model = fsdp_wrapper(ema_model, fsdp_config, ignored_modules=ignored_modules)
    return ema_model


@torch.no_grad()
def fsdp_ema_update(ema_model, model, decay=0.9999):
    # Update the EMA weights shard-by-shard by pairing up the FSDP flat-parameter
    # handles of the two (identically wrapped) models.
    ema_handles = traversal_utils._get_fsdp_handles(ema_model)
    new_handles = traversal_utils._get_fsdp_handles(model)
    assert len(ema_handles) == len(new_handles)

    ema_params = []
    new_params = []
    for ema_handle, new_handle in zip(ema_handles, new_handles):
        if ema_handle.flat_param is not None and new_handle.flat_param.requires_grad:
            ema_params.append(ema_handle.flat_param.data)
            new_params.append(new_handle.flat_param.data.to(dtype=ema_handle.flat_param.dtype))

    # In-place EMA: ema = decay * ema + (1 - decay) * new.
    torch._foreach_mul_(ema_params, decay)
    torch._foreach_add_(ema_params, new_params, alpha=1 - decay)
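
# Illustrative EMA maintenance inside a training step (a sketch): assumes `model` and
# `ema_model` were wrapped with the same fsdp_wrapper configuration so their
# flat-parameter handles pair up one-to-one, as asserted in fsdp_ema_update.
#
#   ema_model = fsdp_ema_setup(ema_model, fsdp_config)
#   ...
#   loss.backward()
#   optimizer.step()
#   optimizer.zero_grad()
#   fsdp_ema_update(ema_model, model, decay=0.9999)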