import os
import argparse
import torch
import json
import glob
import numpy as np
import re
import logging
import random
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any, Tuple
from functools import partial
from datasets import Dataset as HFDataset
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
Trainer,
TrainingArguments,
HfArgumentParser,
set_seed,
TrainerCallback,
DataCollatorForLanguageModeling
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import bitsandbytes as bnb
from trl import GRPOConfig, GRPOTrainer
from accelerate import Accelerator
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def extract_answer(solution_text: str):
"""Extract the answer from the model's response using regex patterns."""
boxed_pattern = r'\\boxed\{([^}]*)\}'
matches = re.findall(boxed_pattern, solution_text)
if matches:
return matches[-1].strip()
# Try to find a numeric answer if no boxed answer is found
if "index of -1" in solution_text.lower() or "index: -1" in solution_text.lower():
return "-1"
# Look for paragraph indices
paragraph_pattern = r'paragraph[\s_]*(\d+)'
paragraph_matches = re.findall(paragraph_pattern, solution_text.lower())
if paragraph_matches:
return paragraph_matches[0]
# Check for direct indices
index_pattern = r'index[\s:]*(is|of)?[\s:]*(-?\d+)'
index_matches = re.findall(index_pattern, solution_text.lower())
    if index_matches:
        # The second capture group holds the numeric index
        return index_matches[0][1]
return None
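# Illustrative examples of the extraction rules above (derived from the regex patterns):
#   extract_answer(r"... the first error is in \boxed{3}.")          -> "3"
#   extract_answer("There is no mistake, so the index is -1.")       -> "-1"
#   extract_answer("The error occurs in paragraph_2 of the proof.")  -> "2"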
def load_mistake_data(file_path):
"""Load data from a JSONL file."""
data = []
with open(file_path, 'r') as f:
for line in f:
try:
item = json.loads(line)
# Convert None to -1 for consistency
if item.get('mistake_index') is None:
item['mistake_index'] = -1
data.append(item)
except json.JSONDecodeError:
logger.warning(f"Skipping malformed JSON in {file_path}")
continue
return data
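# Each JSONL line is expected to carry at least the fields used below, roughly:
#   {"input": "<problem statement>", "steps": ["step 0", "step 1", ...], "mistake_index": 2}
# where "mistake_index" may be null (normalized to -1 above) when no step is wrong.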
def prepare_input_mistake(template, input_d):
"""Prepare input for the mistake detection task."""
problem = input_d['input']
steps = input_d['steps']
# Format the steps with tags for paragraph identification
tagged_steps = ''
for sdx, step in enumerate(steps):
tagged_steps += f'<paragraph_{sdx}>\n{step}\n</paragraph_{sdx}>\n\n'
tagged_steps = tagged_steps.strip()
# Create the formatted prompt using the template
prompt = template.format(problem=problem, tagged_response=tagged_steps)
return prompt
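# The template file is expected to expose "{problem}" and "{tagged_response}" placeholders;
# each step is wrapped in <paragraph_i> ... </paragraph_i> tags so the model can name the faulty index.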
def compute_reward(prediction, target):
"""
Compute the reward for a prediction compared to the target.
Returns:
- 1.0 for exact match
    - 0.5 for a partial match (e.g., correctly identifying that a mistake exists but giving the wrong index)
- 0.0 for complete mismatch
"""
if prediction is None:
return 0.0
try:
pred = int(prediction)
targ = int(target)
if pred == targ:
return 1.0
# Partial credit for correctly identifying whether there's a mistake at all
elif (pred == -1 and targ == -1) or (pred != -1 and targ != -1):
return 0.5
else:
return 0.0
except (ValueError, TypeError):
return 0.0
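# Illustrative reward values: compute_reward("3", 3) -> 1.0 (exact match);
# compute_reward("2", 3) -> 0.5 (a mistake is detected but at the wrong index);
# compute_reward("-1", 3) -> 0.0 (the mistake is missed entirely).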
def preprocess_function(examples, tokenizer, template, max_length=2048):
"""Process examples for model training."""
# List to store processed inputs
# input_ids_list = []
# attention_mask_list = []
# labels_list = []
prompt_list = []
groundtruth_list = []
for example in examples["data"]:
# Prepare the prompt
prompt = prepare_input_mistake(template, example)
messages = [{"role": "user", "content": prompt}]
# Format using chat template
prompt_text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
prompt_list.append(prompt_text)
groundtruth_list.append(example["mistake_index"])
# # Tokenize
# encoded = tokenizer(
# prompt_text,
# max_length=max_length,
# padding="max_length",
# truncation=True,
# return_tensors="pt"
# )
# input_ids_list.append(encoded["input_ids"][0])
# attention_mask_list.append(encoded["attention_mask"][0])
# labels_list.append(encoded["input_ids"][0].clone())
# Create processed features
result = {
"prompt": prompt_list,
"ground_truth": groundtruth_list,
"original_example": examples["data"]
}
return result
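# Note: GRPOTrainer generates from the "prompt" column and forwards the remaining columns
# (e.g. "ground_truth") to the reward function as keyword arguments; see reward_func below.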
class SaveBestModelCallback(TrainerCallback):
"""Callback to save best model based on average reward."""
def __init__(self):
self.best_reward = -float('inf')
def on_evaluate(self, args, state, control, metrics, **kwargs):
current_reward = metrics.get("eval_reward", 0)
if current_reward > self.best_reward:
self.best_reward = current_reward
# Save the best model
output_dir = os.path.join(args.output_dir, "best_model")
os.makedirs(output_dir, exist_ok=True)
            # The callback receives the model itself (not the trainer) via kwargs
            model = kwargs.get("model")
            if model is not None:
                model.save_pretrained(output_dir)
                logger.info(f"Saved best model with reward {current_reward}")
def reward_func(completions, ground_truth, **kwargs):
"""
Compute rewards by comparing model completions to ground truth.
Args:
completions: List of model completion strings
ground_truth: List of ground truth values
**kwargs: Additional arguments
Returns:
        list[float]: Per-completion reward values
"""
rewards = []
for completion, target in zip(completions, ground_truth):
# Extract model's prediction from the completion
prediction = extract_answer(completion)
# Convert target if it's a tensor
if isinstance(target, torch.Tensor):
target = target.item()
        # Compute reward; TRL's GRPOTrainer expects reward functions to return a list of floats
        reward = compute_reward(prediction, target)
        rewards.append(float(reward))
    return rewards
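# Illustrative call (assuming \boxed{...} answers in the completions):
#   reward_func(["... \boxed{2}", "... \boxed{-1}"], [2, 3]) yields rewards of 1.0 and 0.0.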
@dataclass
class ScriptArguments:
"""Arguments for the GRPO training script."""
model_name_or_path: str = field(
default="deepseek-ai/deepseek-math-7b-instruct",
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
train_data_dir: str = field(
default="BIG-Bench-Mistake-Train",
metadata={"help": "Directory containing training data files"}
)
val_data_dir: str = field(
default="BIG-Bench-Mistake-Test",
metadata={"help": "Directory containing validation data files"}
)
template_path: str = field(
default="templates/critique_template.txt",
metadata={"help": "Path to prompt template file"}
)
output_dir: str = field(
default="results/grpo_finetune",
metadata={"help": "Output directory for model checkpoints"}
)
seed: int = field(
default=42,
metadata={"help": "Random seed for initialization"}
)
max_length: int = field(
default=2048,
metadata={"help": "Maximum sequence length for tokenizer"}
)
per_device_train_batch_size: int = field(
default=1,
metadata={"help": "Batch size per GPU for training"}
)
per_device_eval_batch_size: int = field(
default=1,
metadata={"help": "Batch size per GPU for evaluation"}
)
gradient_accumulation_steps: int = field(
default=8,
metadata={"help": "Number of updates steps to accumulate before backward pass"}
)
max_train_samples: Optional[int] = field(
default=None,
metadata={"help": "Max number of training samples to use (for debugging)"}
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={"help": "Max number of evaluation samples to use (for debugging)"}
)
learning_rate: float = field(
default=5e-5,
metadata={"help": "Learning rate for training"}
)
num_train_epochs: int = field(
default=3,
metadata={"help": "Number of training epochs"}
)
logging_steps: int = field(
default=10,
metadata={"help": "Log every X updates steps"}
)
eval_steps: int = field(
default=100,
metadata={"help": "Run evaluation every X steps"}
)
save_steps: int = field(
default=500,
metadata={"help": "Save checkpoint every X steps"}
)
warmup_steps: int = field(
default=100,
metadata={"help": "Linear warmup over this many steps"}
)
use_lora: bool = field(
default=True,
metadata={"help": "Whether to use LoRA for parameter-efficient fine-tuning"}
)
lora_r: int = field(
default=16,
metadata={"help": "LoRA attention dimension"}
)
lora_alpha: int = field(
default=32,
metadata={"help": "LoRA alpha parameter"}
)
lora_dropout: float = field(
default=0.05,
metadata={"help": "LoRA dropout probability"}
)
load_in_8bit: bool = field(
default=False,
metadata={"help": "Whether to load model in 8-bit precision"}
)
load_in_4bit: bool = field(
default=True,
metadata={"help": "Whether to load model in 4-bit precision"}
)
use_group_rewards: bool = field(
default=True,
metadata={"help": "Whether to use group rewards in GRPO"}
)
gumbel_samples: int = field(
default=10,
metadata={"help": "Number of Gumbel samples for GRPO"}
)
critic_multiple: float = field(
default=0.5,
metadata={"help": "Critic loss multiplier"}
)
deepspeed: Optional[str] = field(
default=None,
metadata={"help": "Path to deepspeed config file for using deepspeed"}
)
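# Note: use_group_rewards, gumbel_samples and critic_multiple are parsed but not referenced
# anywhere in main() below; they are currently unused.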
def main():
parser = HfArgumentParser(ScriptArguments)
args = parser.parse_args_into_dataclasses()[0]
# Set random seed
set_seed(args.seed)
# Create output directory
os.makedirs(args.output_dir, exist_ok=True)
# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
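    # Left padding keeps prompts flush against the generated continuation, which is what
    # decoder-only models need for batched generation.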
logger.info(f"Loading model {args.model_name_or_path}...")
# Prepare model with quantization if needed
if args.load_in_8bit:
quantization_config = {"load_in_8bit": True}
elif args.load_in_4bit:
quantization_config = {"load_in_4bit": True,
"bnb_4bit_compute_dtype": torch.float16,
"bnb_4bit_quant_type": "nf4"}
else:
quantization_config = None
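    # Note: recent transformers versions accept a plain dict here; if yours does not, wrap the
    # same options in transformers.BitsAndBytesConfig instead.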
    # For DeepSpeed compatibility, leave torch_dtype as None so DeepSpeed controls fp16/bf16 precision
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
torch_dtype=None, # Let DeepSpeed handle the precision
device_map=None, # Don't use device_map with DeepSpeed
quantization_config=quantization_config
)
# Apply LoRA if specified
if args.use_lora:
logger.info("Applying LoRA...")
if args.load_in_8bit or args.load_in_4bit:
model = prepare_model_for_kbit_training(model)
peft_config = LoraConfig(
r=args.lora_r,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
bias="none",
task_type="CAUSAL_LM",
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]
)
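        # The projection names above assume a LLaMA-style attention layout (as used by the
        # default deepseek-math model); other architectures may need different target_modules.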
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
# Load template
with open(args.template_path, 'r') as f:
template = f.read().strip()
# Load training and validation data
train_files = glob.glob(os.path.join(args.train_data_dir, '*.jsonl'))
val_files = glob.glob(os.path.join(args.val_data_dir, '*.jsonl'))
# Use combined files if available
if os.path.exists(os.path.join(args.train_data_dir, 'combined_train.jsonl')):
train_files = [os.path.join(args.train_data_dir, 'combined_train.jsonl')]
if os.path.exists(os.path.join(args.val_data_dir, 'combined_test.jsonl')):
val_files = [os.path.join(args.val_data_dir, 'combined_test.jsonl')]
logger.info(f"Loading training data from {len(train_files)} files...")
train_data = []
for file in train_files:
train_data.extend(load_mistake_data(file))
logger.info(f"Loading validation data from {len(val_files)} files...")
val_data = []
for file in val_files:
val_data.extend(load_mistake_data(file))
# Limit number of samples if specified
if args.max_train_samples and len(train_data) > args.max_train_samples:
train_data = random.sample(train_data, args.max_train_samples)
if args.max_eval_samples and len(val_data) > args.max_eval_samples:
val_data = random.sample(val_data, args.max_eval_samples)
logger.info(f"Loaded {len(train_data)} training examples and {len(val_data)} validation examples")
# Create HF datasets
train_hf_dataset = HFDataset.from_dict({"data": train_data})
val_hf_dataset = HFDataset.from_dict({"data": val_data})
# Apply preprocessing function
train_tokenize_func = partial(preprocess_function, tokenizer=tokenizer, template=template, max_length=args.max_length)
val_tokenize_func = partial(preprocess_function, tokenizer=tokenizer, template=template, max_length=args.max_length)
# Process the datasets
train_dataset = train_hf_dataset.map(
train_tokenize_func,
batched=True,
remove_columns=["data"],
desc="Processing training dataset"
)
val_dataset = val_hf_dataset.map(
val_tokenize_func,
batched=True,
remove_columns=["data"],
desc="Processing validation dataset"
)
# Get reward function
reward_fn = reward_func
# Create training arguments with DeepSpeed compatibility
training_args = GRPOConfig(
output_dir=args.output_dir,
per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
learning_rate=args.learning_rate,
num_train_epochs=args.num_train_epochs,
logging_steps=args.logging_steps,
evaluation_strategy="no", # No evaluation during training
save_strategy="steps",
save_steps=args.save_steps,
warmup_steps=args.warmup_steps,
save_total_limit=3,
load_best_model_at_end=False, # Don't load best model as we're not evaluating
        weight_decay=0.01,
        deepspeed=args.deepspeed,  # optional path to a DeepSpeed config file; None leaves DeepSpeed off
        # When a DeepSpeed config is used, keep bf16 consistent with the precision it specifies
        bf16=True,
report_to="none",
max_grad_norm=1.0,
remove_unused_columns=False,
use_vllm=True,
# Generation config
temperature=0.6,
top_p=0.95,
num_generations=14,
        # Data processing limits
max_prompt_length=1024,
max_completion_length=1024,
log_completions=True,
do_eval=False, # Disable evaluation
)
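    # Note: depending on the TRL version, GRPO requires the global batch size
    # (num_processes * per_device_train_batch_size) to be divisible by num_generations,
    # so num_generations=14 with a per-device batch size of 1 typically implies a multi-GPU launch.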
# Create GRPO trainer without evaluation dataset and callback
trainer = GRPOTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
        # No eval_dataset: evaluation is disabled for this run
reward_funcs=reward_fn,
        # SaveBestModelCallback is intentionally not registered (evaluation is disabled)
)
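    # GRPOTrainer's reward_funcs accepts either a single callable or a list of callables.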
# Train the model
logger.info("Starting training with DeepSpeed...")
trainer.train()
# Save the final model - ensure this runs regardless of accelerator
trainer.save_model(os.path.join(args.output_dir, "final_model"))
logger.info(f"Training completed. Final model saved to {os.path.join(args.output_dir, 'final_model')}")
if __name__ == "__main__":
main()