# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0

import functools
import os

import torch
import torch.distributed as dist
import torch.distributed.fsdp._traversal_utils as traversal_utils
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import (
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
    FullStateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from safetensors.torch import load_file, save_file

from modeling.bagel.modeling_utils import MLPconnector, TimestepEmbedder, PositionEmbedding
from modeling.bagel.qwen2_navit import (
    Qwen2DecoderLayer,
    Qwen2MoEDecoderLayer,
    Qwen2MoTDecoderLayer,
)
from modeling.bagel.siglip_navit import SiglipEncoderLayer, SiglipVisionTransformer


class FSDPConfig:
    def __init__(
        self,
        sharding_strategy,
        backward_prefetch,
        cpu_offload,
        num_replicate,
        num_shard=8,
    ):
        self.sharding_strategy = sharding_strategy
        self.backward_prefetch = backward_prefetch
        self.cpu_offload = cpu_offload
        self.num_replicate = num_replicate
        self.num_shard = num_shard


def fsdp_wrapper(original_model, fsdp_config, ignored_modules=[]):
    if fsdp_config.sharding_strategy == 'HYBRID_SHARD':
        # Hybrid sharding: replicate across `num_replicate` groups, shard within each group.
        device_mesh = init_device_mesh(
            "cuda",
            mesh_shape=(fsdp_config.num_replicate, fsdp_config.num_shard),
            mesh_dim_names=("replicate", "shard"),
        )
    else:
        device_mesh = None

    return FSDP(
        original_model,
        auto_wrap_policy=functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={
                Qwen2DecoderLayer,
                Qwen2MoEDecoderLayer,
                Qwen2MoTDecoderLayer,
                SiglipEncoderLayer,
                SiglipVisionTransformer,
                MLPconnector,
                TimestepEmbedder,
                PositionEmbedding,
            },
        ),
        ignored_modules=ignored_modules,
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        ),
        device_id=dist.get_rank() % torch.cuda.device_count(),
        sharding_strategy=ShardingStrategy[fsdp_config.sharding_strategy],
        backward_prefetch=BackwardPrefetch[fsdp_config.backward_prefetch],
        cpu_offload=CPUOffload(offload_params=fsdp_config.cpu_offload),
        device_mesh=device_mesh,
    )
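

# Usage sketch (not part of this module): how FSDPConfig and fsdp_wrapper are typically
# combined when setting up training. The concrete strategy names and group sizes below
# are illustrative assumptions, not this project's launch configuration.
#
#     fsdp_config = FSDPConfig(
#         sharding_strategy="HYBRID_SHARD",   # or "FULL_SHARD"
#         backward_prefetch="BACKWARD_PRE",   # name of a BackwardPrefetch member
#         cpu_offload=False,
#         num_replicate=dist.get_world_size() // 8,
#         num_shard=8,
#     )
#     model = fsdp_wrapper(model, fsdp_config)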


class FSDPCheckpoint:
    @staticmethod
    def fsdp_save_ckpt(
        ckpt_dir,
        train_steps,
        model,
        ema_model,
        optimizer,
        scheduler,
        data_status,
        logger,
        fsdp_config,
    ):
        save_path = os.path.join(ckpt_dir, f"{train_steps:07d}")
        os.makedirs(save_path, exist_ok=True)
        logger.info(f"Saving checkpoint to {save_path}.")

        if ema_model is not None:
            with FSDP.state_dict_type(
                ema_model,
                StateDictType.FULL_STATE_DICT,
                FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
            ):
                ema_state_dict = ema_model.state_dict()
                if dist.get_rank() == 0:
                    save_file(ema_state_dict, os.path.join(save_path, "ema.safetensors"))

        with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
        ):
            model_state_dict = model.state_dict()
            if dist.get_rank() == 0:
                save_file(model_state_dict, os.path.join(save_path, "model.safetensors"))

        with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
            if fsdp_config.sharding_strategy == "FULL_SHARD":
                shard_index = dist.get_rank()
                total_shards = dist.get_world_size()
            elif fsdp_config.sharding_strategy == "HYBRID_SHARD":
                shard_index = dist.get_rank() % fsdp_config.num_shard
                total_shards = fsdp_config.num_shard
            else:
                raise NotImplementedError

            optimizer_save_path = os.path.join(
                save_path, f"optimizer.{shard_index:05d}-of-{total_shards:05d}.pt"
            )
            if fsdp_config.sharding_strategy == "FULL_SHARD":
                torch.save(optimizer.state_dict(), optimizer_save_path)
            elif fsdp_config.sharding_strategy == "HYBRID_SHARD":
                # Optimizer shards are identical across replica groups, so only the
                # first replica group (ranks < num_shard) writes them.
                if dist.get_rank() < fsdp_config.num_shard:
                    torch.save(optimizer.state_dict(), optimizer_save_path)
            else:
                raise NotImplementedError

        if dist.get_rank() == 0 and scheduler is not None:
            torch.save(scheduler.state_dict(), os.path.join(save_path, "scheduler.pt"))

        if dist.get_rank() == 0 and data_status is not None:
            torch.save(data_status, os.path.join(save_path, "data_status.pt"))

        dist.barrier()
        return

    @staticmethod
    def try_load_ckpt(resume_from, logger, model, ema_model=None, resume_from_ema=False):
        if resume_from is not None and os.path.exists(resume_from):
            logger.info(f"Loading checkpoint from {resume_from}.")
            if resume_from_ema:
                model_state_dict_path = os.path.join(resume_from, "ema.safetensors")
            else:
                model_state_dict_path = os.path.join(resume_from, "model.safetensors")
            model_state_dict = load_file(model_state_dict_path, device="cpu")
            # NOTE: position embeds are fixed sinusoidal embeddings, so we can just pop them off,
            # which makes it easier to adapt to different resolutions.
            model_state_dict.pop('latent_pos_embed.pos_embed')
            model_state_dict.pop('vit_pos_embed.pos_embed')
            msg = model.load_state_dict(model_state_dict, strict=False)
            logger.info(msg)
            del model_state_dict

            if ema_model is not None:
                ema_state_dict_path = os.path.join(resume_from, "ema.safetensors")
                if not os.path.exists(ema_state_dict_path):
                    logger.info(f"Replicating EMA model from {model_state_dict_path}.")
                    ema_state_dict_path = model_state_dict_path
                ema_state_dict = load_file(ema_state_dict_path, device="cpu")
                # NOTE: position embeds are fixed sinusoidal embeddings, so we can just pop them off,
                # which makes it easier to adapt to different resolutions.
                ema_state_dict.pop('latent_pos_embed.pos_embed')
                ema_state_dict.pop('vit_pos_embed.pos_embed')
                msg = ema_model.load_state_dict(ema_state_dict, strict=False)
                logger.info(msg)
                del ema_state_dict
        else:
            logger.info("Training from scratch.")
        return model, ema_model

    @staticmethod
    def try_load_train_state(resume_from, optimizer, scheduler, fsdp_config):
        if resume_from is not None and os.path.exists(resume_from):
            if fsdp_config.sharding_strategy == "FULL_SHARD":
                shard_index = dist.get_rank()
                total_shards = dist.get_world_size()
            elif fsdp_config.sharding_strategy == "HYBRID_SHARD":
                shard_index = dist.get_rank() % fsdp_config.num_shard
                total_shards = fsdp_config.num_shard
            else:
                raise NotImplementedError

            optimizer_state_dict_path = os.path.join(
                resume_from, f"optimizer.{shard_index:05d}-of-{total_shards:05d}.pt"
            )
            optimizer_state_dict = torch.load(
                optimizer_state_dict_path, map_location="cpu", weights_only=True
            )
            optimizer.load_state_dict(optimizer_state_dict)
            del optimizer_state_dict

            scheduler_state_dict_path = os.path.join(resume_from, "scheduler.pt")
            scheduler_state_dict = torch.load(
                scheduler_state_dict_path, weights_only=True, map_location="cpu"
            )
            scheduler.load_state_dict(scheduler_state_dict)
            del scheduler_state_dict

            # The checkpoint directory is named after the step it was saved at, e.g. "0001000".
            train_steps = int(os.path.basename(os.path.normpath(resume_from))) + 1

            """
            data_status = [
                {
                    dataset_name: {
                        worker_id: [parquet_idx, row_group_id, row_idx],
                    },
                },
            ]
            """
            data_status_path = os.path.join(resume_from, "data_status.pt")
            if os.path.exists(data_status_path):
                data_status = torch.load(data_status_path, weights_only=True, map_location="cpu")
                local_rank = dist.get_rank()
                if local_rank < len(data_status):
                    data_status = data_status[local_rank]
                else:
                    data_status = None
            else:
                data_status = None
        else:
            train_steps = 0
            data_status = None
        return optimizer, scheduler, train_steps, data_status
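

# Usage sketch (not part of this module): a plausible save/resume flow built from the
# FSDPCheckpoint helpers above. Argument names such as `args.resume_from`, `args.ckpt_dir`,
# and `args.ckpt_interval` are illustrative assumptions, not a guaranteed interface.
#
#     model, ema_model = FSDPCheckpoint.try_load_ckpt(args.resume_from, logger, model, ema_model)
#     model = fsdp_wrapper(model, fsdp_config)
#     ema_model = fsdp_ema_setup(ema_model, fsdp_config)
#     optimizer, scheduler, train_steps, data_status = FSDPCheckpoint.try_load_train_state(
#         args.resume_from, optimizer, scheduler, fsdp_config
#     )
#     ...
#     if train_steps % args.ckpt_interval == 0:
#         FSDPCheckpoint.fsdp_save_ckpt(
#             args.ckpt_dir, train_steps, model, ema_model, optimizer,
#             scheduler, data_status, logger, fsdp_config,
#         )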


def grad_checkpoint_check_fn(module):
    # Modules that should be wrapped with activation (gradient) checkpointing.
    module_options = (
        Qwen2DecoderLayer,
        SiglipEncoderLayer,
        MLPconnector,
        Qwen2MoEDecoderLayer,
        Qwen2MoTDecoderLayer,
    )
    return isinstance(module, module_options)
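

# Usage sketch (not part of this module): grad_checkpoint_check_fn is intended as the
# check_fn for torch's apply_activation_checkpointing utility; the non-reentrant wrapper
# below is an assumption about a typical configuration, not necessarily this project's setup.
#
#     from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
#         CheckpointImpl, apply_activation_checkpointing, checkpoint_wrapper,
#     )
#     apply_activation_checkpointing(
#         fsdp_model,
#         checkpoint_wrapper_fn=functools.partial(
#             checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
#         ),
#         check_fn=grad_checkpoint_check_fn,
#     )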


def fsdp_ema_setup(ema_model, fsdp_config, ignored_modules=[]):
    # The EMA copy is never trained directly, so freeze its parameters before wrapping.
    for param in ema_model.parameters():
        param.requires_grad = False
    ema_model = fsdp_wrapper(ema_model, fsdp_config, ignored_modules=ignored_modules)
    return ema_model


def fsdp_ema_update(ema_model, model, decay=0.9999):
    # Update the EMA weights in-place on the flattened FSDP shards:
    # ema = decay * ema + (1 - decay) * model
    ema_handles = traversal_utils._get_fsdp_handles(ema_model)
    new_handles = traversal_utils._get_fsdp_handles(model)

    assert len(ema_handles) == len(new_handles)

    ema_params = []
    new_params = []
    for ema_handle, new_handle in zip(ema_handles, new_handles):
        if ema_handle.flat_param is not None and new_handle.flat_param.requires_grad:
            ema_params.append(ema_handle.flat_param.data)
            new_params.append(new_handle.flat_param.data.to(dtype=ema_handle.flat_param.dtype))

    torch._foreach_mul_(ema_params, decay)
    torch._foreach_add_(ema_params, new_params, alpha=1 - decay)
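

# Usage sketch (not part of this module): in the training loop, the EMA shards are
# refreshed right after each optimizer step. The loop variables and the HF-style
# `.loss` output below are illustrative assumptions.
#
#     for step, batch in enumerate(loader):
#         loss = model(**batch).loss
#         loss.backward()
#         optimizer.step()
#         optimizer.zero_grad()
#         fsdp_ema_update(ema_model, model, decay=0.9999)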