import os
import json
import glob
import re
import logging
import random
from dataclasses import dataclass, field
from typing import Optional
from functools import partial

import torch
from datasets import Dataset as HFDataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    HfArgumentParser,
    set_seed,
    TrainerCallback,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import GRPOConfig, GRPOTrainer

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def extract_answer(solution_text: str):
    """Extract the predicted mistake index from the model's response using regex patterns."""
    # Prefer an explicit \boxed{...} answer; use the last one if several appear.
    boxed_pattern = r'\\boxed\{([^}]*)\}'
    matches = re.findall(boxed_pattern, solution_text)
    if matches:
        return matches[-1].strip()

    # Explicit statements that no mistake was found.
    if "index of -1" in solution_text.lower() or "index: -1" in solution_text.lower():
        return "-1"

    # References such as "paragraph_3" or "paragraph 3".
    paragraph_pattern = r'paragraph[\s_]*(\d+)'
    paragraph_matches = re.findall(paragraph_pattern, solution_text.lower())
    if paragraph_matches:
        return paragraph_matches[0]

    # Phrases such as "index is 3" or "index: -1".
    index_pattern = r'index[\s:]*(is|of)?[\s:]*(-?\d+)'
    index_matches = re.findall(index_pattern, solution_text.lower())
    if index_matches:
        return index_matches[0][1]

    return None
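

# Illustrative behaviour of extract_answer, derived from the patterns above (not exhaustive):
#   extract_answer(r"... so the first error is \boxed{3}")       -> "3"
#   extract_answer("The mistake is in paragraph_2 because ...")  -> "2"
#   extract_answer("There is no mistake, index: -1")             -> "-1"
#   extract_answer("I am not sure.")                             -> None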


def load_mistake_data(file_path):
    """Load mistake-detection examples from a JSONL file."""
    data = []
    with open(file_path, 'r') as f:
        for line in f:
            try:
                item = json.loads(line)
                # A missing/None mistake_index means "no mistake"; normalise it to -1.
                if item.get('mistake_index') is None:
                    item['mistake_index'] = -1
                data.append(item)
            except json.JSONDecodeError:
                logger.warning(f"Skipping malformed JSON in {file_path}")
                continue
    return data
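

# The loader only relies on three fields per JSONL record; a minimal, hypothetical line:
#   {"input": "What is 2 + 2 * 3?", "steps": ["2 + 2 = 4", "4 * 3 = 12"], "mistake_index": 0}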


def prepare_input_mistake(template, input_d):
    """Prepare the prompt for the mistake detection task."""
    problem = input_d['input']
    steps = input_d['steps']

    # Wrap each reasoning step in numbered <paragraph_i> tags so the model can
    # refer to a step by its index.
    tagged_steps = ''
    for sdx, step in enumerate(steps):
        tagged_steps += f'<paragraph_{sdx}>\n{step}\n</paragraph_{sdx}>\n\n'
    tagged_steps = tagged_steps.strip()

    prompt = template.format(problem=problem, tagged_response=tagged_steps)
    return prompt
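

# For steps ["2 + 2 = 4", "4 * 3 = 12"] the tagged block inserted into the template is:
#   <paragraph_0>
#   2 + 2 = 4
#   </paragraph_0>
#
#   <paragraph_1>
#   4 * 3 = 12
#   </paragraph_1>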


def compute_reward(prediction, target):
    """
    Compute the reward for a prediction compared to the target.

    Returns:
    - 1.0 for an exact match
    - 0.5 for a partial match (both sides agree a mistake exists, but the index is wrong)
    - 0.0 for a complete mismatch or an unparseable prediction
    """
    if prediction is None:
        return 0.0

    try:
        pred = int(prediction)
        targ = int(target)

        if pred == targ:
            return 1.0
        # Both predict *some* mistake (index != -1) but disagree on where it is.
        elif pred != -1 and targ != -1:
            return 0.5
        else:
            return 0.0
    except (ValueError, TypeError):
        return 0.0
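

# Examples of the reward scale:
#   compute_reward("3", 3)   -> 1.0  (exact match)
#   compute_reward("2", 3)   -> 0.5  (a mistake was found, but at the wrong index)
#   compute_reward("-1", 3)  -> 0.0  (missed the mistake entirely)
#   compute_reward(None, 3)  -> 0.0  (nothing extractable from the completion)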


def preprocess_function(examples, tokenizer, template, max_length=2048):
    """Convert raw examples into rows with a "prompt" and a "ground_truth" column."""
    # NOTE: max_length is kept for signature compatibility but is not used here;
    # prompt truncation is handled by GRPOConfig.max_prompt_length below.
    prompt_list = []
    groundtruth_list = []

    for example in examples["data"]:
        prompt = prepare_input_mistake(template, example)
        messages = [{"role": "user", "content": prompt}]

        # Render the chat template to plain text; the trainer handles tokenization.
        prompt_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        prompt_list.append(prompt_text)
        groundtruth_list.append(example["mistake_index"])

    result = {
        "prompt": prompt_list,
        "ground_truth": groundtruth_list,
        "original_example": examples["data"]
    }
    return result
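

# GRPOTrainer generates completions from the "prompt" column; additional dataset columns
# (here "ground_truth" and "original_example") are forwarded to the reward function as
# keyword arguments, which is how reward_func below receives ground_truth.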


class SaveBestModelCallback(TrainerCallback):
    """Callback that saves the best model based on the average evaluation reward.

    Note: it only takes effect if it is registered with the trainer and an
    evaluation strategy is enabled; the GRPO run below keeps evaluation off.
    """
    def __init__(self):
        self.best_reward = -float('inf')

    def on_evaluate(self, args, state, control, metrics, **kwargs):
        current_reward = metrics.get("eval_reward", 0)
        if current_reward > self.best_reward:
            self.best_reward = current_reward

            output_dir = os.path.join(args.output_dir, "best_model")
            os.makedirs(output_dir, exist_ok=True)

            # The trainer instance is not always passed to callbacks; guard against None.
            trainer = kwargs.get("trainer")
            if trainer:
                trainer.save_model(output_dir)
                logger.info(f"Saved best model with reward {current_reward}")


def reward_func(completions, ground_truth, **kwargs):
    """
    Compute rewards by comparing model completions to the ground-truth mistake index.

    Args:
        completions: List of model completion strings
        ground_truth: List of ground-truth mistake indices (one per completion)
        **kwargs: Additional dataset columns/arguments forwarded by the trainer

    Returns:
        list[float]: One reward per completion, as expected by GRPOTrainer.
    """
    rewards = []
    for completion, target in zip(completions, ground_truth):
        prediction = extract_answer(completion)

        if isinstance(target, torch.Tensor):
            target = target.item()

        rewards.append(compute_reward(prediction, target))
    return rewards
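

# A quick offline sanity check of the reward function (no trainer required):
#   reward_func(["The mistake is in paragraph_2.", r"\boxed{-1}"], [2, 0])  -> [1.0, 0.0]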


@dataclass
class ScriptArguments:
    """Arguments for the GRPO training script."""
    model_name_or_path: str = field(
        default="deepseek-ai/deepseek-math-7b-instruct",
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    train_data_dir: str = field(
        default="BIG-Bench-Mistake-Train",
        metadata={"help": "Directory containing training data files"}
    )
    val_data_dir: str = field(
        default="BIG-Bench-Mistake-Test",
        metadata={"help": "Directory containing validation data files"}
    )
    template_path: str = field(
        default="templates/critique_template.txt",
        metadata={"help": "Path to prompt template file"}
    )
    output_dir: str = field(
        default="results/grpo_finetune",
        metadata={"help": "Output directory for model checkpoints"}
    )
    seed: int = field(
        default=42,
        metadata={"help": "Random seed for initialization"}
    )
    max_length: int = field(
        default=2048,
        metadata={"help": "Maximum sequence length for the tokenizer"}
    )
    per_device_train_batch_size: int = field(
        default=1,
        metadata={"help": "Batch size per GPU for training"}
    )
    per_device_eval_batch_size: int = field(
        default=1,
        metadata={"help": "Batch size per GPU for evaluation"}
    )
    gradient_accumulation_steps: int = field(
        default=8,
        metadata={"help": "Number of update steps to accumulate before a backward pass"}
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={"help": "Max number of training samples to use (for debugging)"}
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={"help": "Max number of evaluation samples to use (for debugging)"}
    )
    learning_rate: float = field(
        default=5e-5,
        metadata={"help": "Learning rate for training"}
    )
    num_train_epochs: int = field(
        default=3,
        metadata={"help": "Number of training epochs"}
    )
    logging_steps: int = field(
        default=10,
        metadata={"help": "Log every X update steps"}
    )
    eval_steps: int = field(
        default=100,
        metadata={"help": "Run evaluation every X steps"}
    )
    save_steps: int = field(
        default=500,
        metadata={"help": "Save a checkpoint every X steps"}
    )
    warmup_steps: int = field(
        default=100,
        metadata={"help": "Linear warmup over this many steps"}
    )
    use_lora: bool = field(
        default=True,
        metadata={"help": "Whether to use LoRA for parameter-efficient fine-tuning"}
    )
    lora_r: int = field(
        default=16,
        metadata={"help": "LoRA attention dimension"}
    )
    lora_alpha: int = field(
        default=32,
        metadata={"help": "LoRA alpha parameter"}
    )
    lora_dropout: float = field(
        default=0.05,
        metadata={"help": "LoRA dropout probability"}
    )
    load_in_8bit: bool = field(
        default=False,
        metadata={"help": "Whether to load the model in 8-bit precision"}
    )
    load_in_4bit: bool = field(
        default=True,
        metadata={"help": "Whether to load the model in 4-bit precision"}
    )
    # The following three fields are accepted for configuration completeness but are
    # not currently consumed by the training loop in main().
    use_group_rewards: bool = field(
        default=True,
        metadata={"help": "Whether to use group rewards in GRPO"}
    )
    gumbel_samples: int = field(
        default=10,
        metadata={"help": "Number of Gumbel samples for GRPO"}
    )
    critic_multiple: float = field(
        default=0.5,
        metadata={"help": "Critic loss multiplier"}
    )
    deepspeed: Optional[str] = field(
        default=None,
        metadata={"help": "Path to a DeepSpeed config file; when set, it is passed through to the trainer"}
    )


def main():
    parser = HfArgumentParser(ScriptArguments)
    args = parser.parse_args_into_dataclasses()[0]

    set_seed(args.seed)

    os.makedirs(args.output_dir, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    logger.info(f"Loading model {args.model_name_or_path}...")

    # Build an optional bitsandbytes quantization config.
    if args.load_in_8bit:
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    elif args.load_in_4bit:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_quant_type="nf4",
        )
    else:
        quantization_config = None

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        quantization_config=quantization_config,
    )

    if args.use_lora:
        logger.info("Applying LoRA...")
        if args.load_in_8bit or args.load_in_4bit:
            model = prepare_model_for_kbit_training(model)

        peft_config = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]
        )
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

    with open(args.template_path, 'r') as f:
        template = f.read().strip()

    train_files = glob.glob(os.path.join(args.train_data_dir, '*.jsonl'))
    val_files = glob.glob(os.path.join(args.val_data_dir, '*.jsonl'))

    # Prefer the pre-combined files when they exist.
    if os.path.exists(os.path.join(args.train_data_dir, 'combined_train.jsonl')):
        train_files = [os.path.join(args.train_data_dir, 'combined_train.jsonl')]

    if os.path.exists(os.path.join(args.val_data_dir, 'combined_test.jsonl')):
        val_files = [os.path.join(args.val_data_dir, 'combined_test.jsonl')]

    logger.info(f"Loading training data from {len(train_files)} files...")
    train_data = []
    for file in train_files:
        train_data.extend(load_mistake_data(file))

    logger.info(f"Loading validation data from {len(val_files)} files...")
    val_data = []
    for file in val_files:
        val_data.extend(load_mistake_data(file))

    # Optionally subsample for quick debugging runs.
    if args.max_train_samples and len(train_data) > args.max_train_samples:
        train_data = random.sample(train_data, args.max_train_samples)

    if args.max_eval_samples and len(val_data) > args.max_eval_samples:
        val_data = random.sample(val_data, args.max_eval_samples)

    logger.info(f"Loaded {len(train_data)} training examples and {len(val_data)} validation examples")

    train_hf_dataset = HFDataset.from_dict({"data": train_data})
    val_hf_dataset = HFDataset.from_dict({"data": val_data})

    tokenize_func = partial(preprocess_function, tokenizer=tokenizer, template=template, max_length=args.max_length)

    train_dataset = train_hf_dataset.map(
        tokenize_func,
        batched=True,
        remove_columns=["data"],
        desc="Processing training dataset"
    )

    # The validation split is passed to the trainer, although evaluation is
    # disabled in the GRPOConfig below (do_eval=False).
    val_dataset = val_hf_dataset.map(
        tokenize_func,
        batched=True,
        remove_columns=["data"],
        desc="Processing validation dataset"
    )

    training_args = GRPOConfig(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        num_train_epochs=args.num_train_epochs,
        logging_steps=args.logging_steps,
        eval_strategy="no",
        save_strategy="steps",
        save_steps=args.save_steps,
        warmup_steps=args.warmup_steps,
        save_total_limit=3,
        load_best_model_at_end=False,
        weight_decay=0.01,
        bf16=True,
        report_to="none",
        max_grad_norm=1.0,
        remove_unused_columns=False,
        deepspeed=args.deepspeed,
        # Generation settings for GRPO rollouts; use_vllm requires a vLLM backend.
        use_vllm=True,
        temperature=0.6,
        top_p=0.95,
        num_generations=14,
        max_prompt_length=1024,
        max_completion_length=1024,
        log_completions=True,
        do_eval=False,
    )

    trainer = GRPOTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        reward_funcs=reward_func,
        processing_class=tokenizer,
    )
logger.info("Starting training with DeepSpeed...") |
|
trainer.train() |
|
|
|
|
|
trainer.save_model(os.path.join(args.output_dir, "final_model")) |
|
logger.info(f"Training completed. Final model saved to {os.path.join(args.output_dir, 'final_model')}") |
|
|
|
if __name__ == "__main__": |
|
main() |
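
# Example launch (hypothetical script name and paths; adjust to your setup):
#   accelerate launch grpo_train.py \
#       --model_name_or_path deepseek-ai/deepseek-math-7b-instruct \
#       --train_data_dir BIG-Bench-Mistake-Train \
#       --output_dir results/grpo_finetune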