# Hugging Face Space (Running on Zero) - file size: 10,547 bytes, revision e6af450.
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import functools
import os
import torch
import torch.distributed as dist
import torch.distributed.fsdp._traversal_utils as traversal_utils
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import (
CPUOffload,
FullyShardedDataParallel as FSDP,
MixedPrecision,
BackwardPrefetch,
ShardingStrategy,
FullStateDictConfig,
StateDictType,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from safetensors.torch import load_file, save_file
from modeling.bagel.modeling_utils import MLPconnector, TimestepEmbedder, PositionEmbedding
from modeling.bagel.qwen2_navit import (
Qwen2DecoderLayer,
Qwen2MoEDecoderLayer,
Qwen2MoTDecoderLayer,
)
from modeling.bagel.siglip_navit import SiglipEncoderLayer, SiglipVisionTransformer
class FSDPConfig:
    """Plain container for the FSDP settings consumed by ``fsdp_wrapper``.

    Attributes:
        sharding_strategy: name of a ``ShardingStrategy`` member, e.g.
            ``"FULL_SHARD"`` or ``"HYBRID_SHARD"``.
        backward_prefetch: name of a ``BackwardPrefetch`` member.
        cpu_offload: passed as ``CPUOffload(offload_params=...)``.
        num_replicate: size of the "replicate" mesh dimension (HYBRID_SHARD).
        num_shard: size of the "shard" mesh dimension (HYBRID_SHARD); 8 by default.
    """

    def __init__(
        self,
        sharding_strategy,
        backward_prefetch,
        cpu_offload,
        num_replicate,
        num_shard=8,
    ):
        # Assigned left-to-right, in the same order as the signature.
        (
            self.sharding_strategy,
            self.backward_prefetch,
            self.cpu_offload,
            self.num_replicate,
            self.num_shard,
        ) = (sharding_strategy, backward_prefetch, cpu_offload, num_replicate, num_shard)
def fsdp_wrapper(original_model, fsdp_config, ignored_modules=None):
    """Wrap ``original_model`` in FSDP with the project's standard settings.

    Every instance of the listed decoder/encoder/connector classes becomes its
    own FSDP unit (``transformer_auto_wrap_policy``). Parameters, gradient
    reduction and buffers all use bfloat16.

    Args:
        original_model: the module to shard.
        fsdp_config: an ``FSDPConfig``; ``sharding_strategy`` and
            ``backward_prefetch`` are enum member *names* (looked up with
            ``ShardingStrategy[...]`` / ``BackwardPrefetch[...]``).
        ignored_modules: modules FSDP should leave unmanaged. ``None`` (the
            default) means none. Previously this defaulted to a shared mutable
            ``[]``.

    Returns:
        The FSDP-wrapped module.
    """
    if fsdp_config.sharding_strategy == "HYBRID_SHARD":
        # 2-D mesh: parameters are sharded along "shard" (num_shard ranks) and
        # replicated along "replicate" (num_replicate groups).
        device_mesh = init_device_mesh(
            "cuda",
            mesh_shape=(fsdp_config.num_replicate, fsdp_config.num_shard),
            mesh_dim_names=("replicate", "shard"),
        )
    else:
        device_mesh = None
    return FSDP(
        original_model,
        auto_wrap_policy=functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={
                Qwen2DecoderLayer,
                Qwen2MoEDecoderLayer,
                Qwen2MoTDecoderLayer,
                SiglipEncoderLayer,
                SiglipVisionTransformer,
                MLPconnector,
                TimestepEmbedder,
                PositionEmbedding,
            },
        ),
        ignored_modules=ignored_modules,
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        ),
        # Map the global rank onto a local CUDA device index
        # (assumes one process per visible GPU - confirm with the launcher).
        device_id=dist.get_rank() % torch.cuda.device_count(),
        sharding_strategy=ShardingStrategy[fsdp_config.sharding_strategy],
        backward_prefetch=BackwardPrefetch[fsdp_config.backward_prefetch],
        cpu_offload=CPUOffload(offload_params=fsdp_config.cpu_offload),
        device_mesh=device_mesh,
    )
class FSDPCheckpoint:
    """Save and restore FSDP training checkpoints.

    A checkpoint is the directory ``<ckpt_dir>/<train_steps:07d>`` containing:

    * ``model.safetensors`` and, if an EMA model exists, ``ema.safetensors``:
      full (unsharded) weights, written by rank 0 only;
    * ``optimizer.<shard:05d>-of-<total:05d>.pt``: one optimizer-state file per shard;
    * ``scheduler.pt`` and ``data_status.pt``: optional, written by rank 0.
    """

    # Position embeddings are fixed sinusoidal embeddings, so they are dropped
    # on load, which makes it easier to adapt to different resolutions.
    _FIXED_POS_EMBED_KEYS = ("latent_pos_embed.pos_embed", "vit_pos_embed.pos_embed")

    @staticmethod
    def _optimizer_shard_info(fsdp_config, rank, world_size):
        """Return ``(shard_index, total_shards)`` naming this rank's optimizer file.

        FULL_SHARD: each rank owns a distinct shard, so the index is the global rank.
        HYBRID_SHARD: the index is the rank's position along the "shard" mesh
        dimension, ``rank % num_shard``, out of ``num_shard`` shards.

        Raises:
            NotImplementedError: for any other sharding strategy.
        """
        if fsdp_config.sharding_strategy == "FULL_SHARD":
            return rank, world_size
        if fsdp_config.sharding_strategy == "HYBRID_SHARD":
            return rank % fsdp_config.num_shard, fsdp_config.num_shard
        raise NotImplementedError

    @staticmethod
    def _optimizer_shard_path(directory, shard_index, total_shards):
        """Return the optimizer-state file path for one shard inside ``directory``."""
        return os.path.join(
            directory, f"optimizer.{shard_index:05d}-of-{total_shards:05d}.pt"
        )

    @staticmethod
    def _save_full_state_dict(fsdp_module, path, rank):
        """Gather ``fsdp_module``'s full state dict and write it from rank 0.

        Must be called on every rank. With ``rank0_only=True`` and
        ``offload_to_cpu=True`` only rank 0 receives the (CPU) tensors and
        saves them. Scoping the dict to this helper frees it before the next
        gather.
        """
        with FSDP.state_dict_type(
            fsdp_module,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
        ):
            state_dict = fsdp_module.state_dict()
        if rank == 0:
            save_file(state_dict, path)

    @staticmethod
    def _drop_fixed_pos_embeds(state_dict):
        """Remove the fixed position-embedding entries, if present, in place.

        Missing keys are tolerated (the original ``pop(key)`` raised KeyError).
        Returns the same dict for convenience.
        """
        for key in FSDPCheckpoint._FIXED_POS_EMBED_KEYS:
            state_dict.pop(key, None)
        return state_dict

    @staticmethod
    def _load_weights(module, path, logger):
        """Load a safetensors file into ``module`` non-strictly and log the result."""
        state_dict = FSDPCheckpoint._drop_fixed_pos_embeds(load_file(path, device="cpu"))
        msg = module.load_state_dict(state_dict, strict=False)
        logger.info(msg)

    @staticmethod
    def fsdp_save_ckpt(
        ckpt_dir,
        train_steps,
        model,
        ema_model,
        optimizer,
        scheduler,
        data_status,
        logger,
        fsdp_config,
    ):
        """Write a checkpoint for ``train_steps`` under ``ckpt_dir``.

        Must be called on every rank: full state dicts are gathered from all
        ranks and the function ends with ``dist.barrier()``.

        Raises:
            NotImplementedError: if ``fsdp_config.sharding_strategy`` is neither
                ``"FULL_SHARD"`` nor ``"HYBRID_SHARD"``.
        """
        save_path = os.path.join(ckpt_dir, f"{train_steps:07d}")
        os.makedirs(save_path, exist_ok=True)
        logger.info(f"Saving checkpoint to {save_path}.")
        rank = dist.get_rank()

        if ema_model is not None:
            FSDPCheckpoint._save_full_state_dict(
                ema_model, os.path.join(save_path, "ema.safetensors"), rank
            )
        FSDPCheckpoint._save_full_state_dict(
            model, os.path.join(save_path, "model.safetensors"), rank
        )

        with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
            shard_index, total_shards = FSDPCheckpoint._optimizer_shard_info(
                fsdp_config, rank, dist.get_world_size()
            )
            # Under HYBRID_SHARD, ranks with the same ``rank % num_shard`` map to
            # the same file, so only the first group (ranks < num_shard) writes.
            if fsdp_config.sharding_strategy == "FULL_SHARD" or rank < fsdp_config.num_shard:
                torch.save(
                    optimizer.state_dict(),
                    FSDPCheckpoint._optimizer_shard_path(save_path, shard_index, total_shards),
                )

        if rank == 0 and scheduler is not None:
            torch.save(scheduler.state_dict(), os.path.join(save_path, "scheduler.pt"))
        if rank == 0 and data_status is not None:
            torch.save(data_status, os.path.join(save_path, "data_status.pt"))
        dist.barrier()
        return

    @staticmethod
    def try_load_ckpt(resume_from, logger, model, ema_model=None, resume_from_ema=False):
        """Load model (and optionally EMA) weights from ``resume_from`` if it exists.

        Args:
            resume_from: checkpoint directory, or ``None``.
            logger: object with an ``info`` method.
            model: module to load into (non-strict).
            ema_model: optional EMA module. It loads ``ema.safetensors`` and
                falls back to the weights file used for ``model``.
            resume_from_ema: initialise ``model`` from ``ema.safetensors``
                instead of ``model.safetensors``.

        Returns:
            ``(model, ema_model)``, loaded in place.
        """
        if resume_from is None or not os.path.exists(resume_from):
            logger.info("Training from scratch.")
            return model, ema_model

        logger.info(f"Loading checkpoint from {resume_from}.")
        model_file = "ema.safetensors" if resume_from_ema else "model.safetensors"
        model_state_dict_path = os.path.join(resume_from, model_file)
        FSDPCheckpoint._load_weights(model, model_state_dict_path, logger)

        if ema_model is not None:
            ema_state_dict_path = os.path.join(resume_from, "ema.safetensors")
            if not os.path.exists(ema_state_dict_path):
                logger.info(f"replicating ema model from {model_state_dict_path}.")
                ema_state_dict_path = model_state_dict_path
            FSDPCheckpoint._load_weights(ema_model, ema_state_dict_path, logger)
        return model, ema_model

    @staticmethod
    def try_load_train_state(resume_from, optimizer, scheduler, fsdp_config):
        """Restore optimizer/scheduler/data-loader state from ``resume_from``.

        Returns:
            ``(optimizer, scheduler, train_steps, data_status)``. ``train_steps``
            is the checkpoint directory's numeric name + 1, or 0 when not
            resuming. ``data_status`` is this rank's entry of ``data_status.pt``,
            or ``None`` when unavailable.

        Raises:
            NotImplementedError: for sharding strategies other than FULL_SHARD
                and HYBRID_SHARD.
        """
        if resume_from is None or not os.path.exists(resume_from):
            return optimizer, scheduler, 0, None

        rank = dist.get_rank()
        shard_index, total_shards = FSDPCheckpoint._optimizer_shard_info(
            fsdp_config, rank, dist.get_world_size()
        )
        optimizer_state_dict = torch.load(
            FSDPCheckpoint._optimizer_shard_path(resume_from, shard_index, total_shards),
            map_location="cpu",
            weights_only=True,
        )
        optimizer.load_state_dict(optimizer_state_dict)
        del optimizer_state_dict

        # fsdp_save_ckpt only writes scheduler.pt for a non-None scheduler.
        if scheduler is not None:
            scheduler_state_dict = torch.load(
                os.path.join(resume_from, "scheduler.pt"), weights_only=True, map_location="cpu"
            )
            scheduler.load_state_dict(scheduler_state_dict)
            del scheduler_state_dict

        train_steps = int(os.path.basename(os.path.normpath(resume_from))) + 1

        # Expected layout, indexed by global rank:
        # data_status = [
        #     {dataset_name: {worker_id: [parquet_idx, row_group_id, row_idx]}},
        #     ...
        # ]
        data_status = None
        data_status_path = os.path.join(resume_from, "data_status.pt")
        if os.path.exists(data_status_path):
            all_status = torch.load(data_status_path, weights_only=True, map_location="cpu")
            if rank < len(all_status):
                data_status = all_status[rank]
        return optimizer, scheduler, train_steps, data_status
def grad_checkpoint_check_fn(module):
    """Return True when ``module`` is one of the checkpoint-eligible layer types.

    Eligible types: the three Qwen2 decoder-layer variants, the SigLIP encoder
    layer, and the MLP connector. Presumably used as the ``check_fn`` for
    activation checkpointing (confirm at the call site).
    """
    eligible = (
        Qwen2DecoderLayer,
        Qwen2MoEDecoderLayer,
        Qwen2MoTDecoderLayer,
        SiglipEncoderLayer,
        MLPconnector,
    )
    return any(isinstance(module, layer_type) for layer_type in eligible)
def fsdp_ema_setup(ema_model, fsdp_config, ignored_modules=None):
    """Freeze every parameter of ``ema_model``, then wrap it via ``fsdp_wrapper``.

    Freezing happens before wrapping, so the wrapped module sees parameters
    with ``requires_grad=False``.

    Args:
        ema_model: module holding the EMA copy of the weights.
        fsdp_config: an ``FSDPConfig``, forwarded to ``fsdp_wrapper``.
        ignored_modules: forwarded to ``fsdp_wrapper``. ``None`` (the default)
            means none; this replaces a shared mutable ``[]`` default.

    Returns:
        The FSDP-wrapped EMA module.
    """
    ema_model.requires_grad_(False)
    return fsdp_wrapper(ema_model, fsdp_config, ignored_modules=ignored_modules)
@torch.no_grad()
def fsdp_ema_update(ema_model, model, decay=0.9999):
    """Update EMA weights in place: ``ema = decay * ema + (1 - decay) * model``.

    Works directly on the FSDP flat parameters of the two models, pairing
    handles by position. Both models must therefore be wrapped with the same
    FSDP layout so that paired flat parameters have matching shapes. A unit is
    updated only when both flat parameters exist and the live model's one
    requires grad. Live values are cast to the EMA dtype before blending.

    Args:
        ema_model: FSDP-wrapped EMA model (modified in place).
        model: FSDP-wrapped live model (read only).
        decay: EMA decay factor.
    """
    ema_handles = traversal_utils._get_fsdp_handles(ema_model)
    new_handles = traversal_utils._get_fsdp_handles(model)
    assert len(ema_handles) == len(new_handles), (
        f"FSDP handle count mismatch: ema has {len(ema_handles)}, model has {len(new_handles)}"
    )

    ema_params = []
    new_params = []
    for ema_handle, new_handle in zip(ema_handles, new_handles):
        ema_flat = ema_handle.flat_param
        new_flat = new_handle.flat_param
        # Skip units with no flat parameter on either side, and frozen live units.
        if ema_flat is None or new_flat is None or not new_flat.requires_grad:
            continue
        ema_params.append(ema_flat.data)
        new_params.append(new_flat.data.to(dtype=ema_flat.dtype))

    if not ema_params:
        # torch._foreach_* ops raise on empty tensor lists; nothing to update.
        return
    torch._foreach_mul_(ema_params, decay)
    torch._foreach_add_(ema_params, new_params, alpha=1 - decay)
|