Bagel-7B-Demo / train /train_utils.py
KingNish's picture
Upload 110 files
e6af450 verified
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import logging
import os
def create_logger(logging_dir, rank, filename="log"):
    """
    Create a logger that writes to a log file and stdout.

    Only the process with ``rank == 0`` and a non-``None`` ``logging_dir``
    gets a real logger; every other caller receives a logger with just a
    ``NullHandler`` attached.

    Args:
        logging_dir: Directory in which the log file is written. It is
            created (including parents) if it does not exist yet. ``None``
            selects the dummy logger.
        rank: Rank of the calling process; only rank 0 logs for real.
        filename: Base name of the log file; ``.txt`` is appended.

    Returns:
        The ``logging.Logger`` registered under this module's ``__name__``.

    Note:
        ``logging.basicConfig`` is a no-op when the root logger already has
        handlers, so the stream/file handlers only take effect the first
        time the root logger is configured.
    """
    logger = logging.getLogger(__name__)

    if rank == 0 and logging_dir is not None:  # real logger
        # FileHandler opens the file eagerly and fails if the directory is
        # missing, so make sure it exists first.
        os.makedirs(logging_dir, exist_ok=True)
        log_path = os.path.join(logging_dir, f"{filename}.txt")
        logging.basicConfig(
            level=logging.INFO,
            format='[\033[34m%(asctime)s\033[0m] %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S',
            handlers=[
                logging.StreamHandler(),
                logging.FileHandler(log_path),
            ],
        )
    else:  # dummy logger (does nothing)
        # Attach the NullHandler once; repeated calls would otherwise keep
        # stacking identical handlers on the same module-level logger.
        if not any(isinstance(h, logging.NullHandler) for h in logger.handlers):
            logger.addHandler(logging.NullHandler())

    return logger
def get_latest_ckpt(checkpoint_dir):
    """
    Return the path of the checkpoint sub-directory with the highest step.

    Sub-directories of ``checkpoint_dir`` are expected to be named after the
    integer training step they hold (e.g. ``"0001000"``). Plain files and
    sub-directories whose name is not an integer are ignored.

    Args:
        checkpoint_dir: Directory that contains the per-step checkpoint
            sub-directories.

    Returns:
        ``os.path.join(checkpoint_dir, <name>)`` for the sub-directory with
        the largest integer name, or ``None`` if there is no such
        sub-directory.

    Raises:
        OSError: Propagated from ``os.listdir`` if ``checkpoint_dir`` cannot
            be listed (e.g. it does not exist).
    """
    steps = []
    for name in os.listdir(checkpoint_dir):
        if not os.path.isdir(os.path.join(checkpoint_dir, name)):
            continue
        try:
            step = int(name)
        except ValueError:
            # Not a step directory (e.g. "tmp", "logs"); skip it instead of
            # letting one stray folder break checkpoint discovery.
            continue
        steps.append((step, name))

    if not steps:
        return None

    # Compare numerically so that "10" ranks above "9".
    _, latest_name = max(steps)
    return os.path.join(checkpoint_dir, latest_name)