zeju-0727 committed
Commit 5cad8b7 · verified · 1 parent: 9c19867

Upload grpo_train.py with huggingface_hub

Files changed (1):
  grpo_train.py  +492  -0
grpo_train.py ADDED
@@ -0,0 +1,492 @@
import os
import argparse
import torch
import json
import glob
import numpy as np
import re
import logging
import random
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any, Tuple
from functools import partial

from datasets import Dataset as HFDataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    HfArgumentParser,
    set_seed,
    TrainerCallback,
    DataCollatorForLanguageModeling
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import bitsandbytes as bnb
from trl import GRPOConfig, GRPOTrainer
from accelerate import Accelerator

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def extract_answer(solution_text: str):
    """Extract the answer from the model's response using regex patterns."""
    boxed_pattern = r'\\boxed\{([^}]*)\}'
    matches = re.findall(boxed_pattern, solution_text)
    if matches:
        return matches[-1].strip()

    # Try to find a numeric answer if no boxed answer is found
    if "index of -1" in solution_text.lower() or "index: -1" in solution_text.lower():
        return "-1"

    # Look for paragraph indices
    paragraph_pattern = r'paragraph[\s_]*(\d+)'
    paragraph_matches = re.findall(paragraph_pattern, solution_text.lower())
    if paragraph_matches:
        return paragraph_matches[0]

    # Check for direct indices
    index_pattern = r'index[\s:]*(is|of)?[\s:]*(-?\d+)'
    index_matches = re.findall(index_pattern, solution_text.lower())
    if index_matches:
        return index_matches[0][1]

    return None
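
# Illustrative behaviour of the extraction heuristics above (examples added for clarity,
# not part of the original upload):
#   extract_answer("The first error is in \\boxed{2}.")        -> "2"
#   extract_answer("There is no mistake, so the index is -1")   -> "-1"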

def load_mistake_data(file_path):
    """Load data from a JSONL file."""
    data = []
    with open(file_path, 'r') as f:
        for line in f:
            try:
                item = json.loads(line)
                # Convert None to -1 for consistency
                if item.get('mistake_index') is None:
                    item['mistake_index'] = -1
                data.append(item)
            except json.JSONDecodeError:
                logger.warning(f"Skipping malformed JSON in {file_path}")
                continue
    return data

def prepare_input_mistake(template, input_d):
    """Prepare input for the mistake detection task."""
    problem = input_d['input']
    steps = input_d['steps']

    # Format the steps with tags for paragraph identification
    tagged_steps = ''
    for sdx, step in enumerate(steps):
        tagged_steps += f'<paragraph_{sdx}>\n{step}\n</paragraph_{sdx}>\n\n'
    tagged_steps = tagged_steps.strip()

    # Create the formatted prompt using the template
    prompt = template.format(problem=problem, tagged_response=tagged_steps)
    return prompt

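# Illustrative shape of the tagged response built by the loop above (derived from the
# f-string; the template's {problem} and {tagged_response} placeholders are filled with
# the problem text and this block):
#   <paragraph_0>
#   First reasoning step...
#   </paragraph_0>
#
#   <paragraph_1>
#   Second reasoning step...
#   </paragraph_1>
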
def compute_reward(prediction, target):
    """
    Compute the reward for a prediction compared to the target.

    Returns:
        - 1.0 for an exact match
        - 0.5 for a partial match (e.g., correctly identifying the presence of a mistake but the wrong index)
        - 0.0 for a complete mismatch
    """
    if prediction is None:
        return 0.0

    try:
        pred = int(prediction)
        targ = int(target)

        if pred == targ:
            return 1.0
        # Partial credit for correctly identifying whether there's a mistake at all
        elif (pred == -1 and targ == -1) or (pred != -1 and targ != -1):
            return 0.5
        else:
            return 0.0
    except (ValueError, TypeError):
        return 0.0

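# Illustrative reward values (worked examples derived from the function above, not part
# of the original upload):
#   compute_reward("3", 3)    -> 1.0   # exact match
#   compute_reward("2", 3)    -> 0.5   # both indicate a mistake, but the wrong index
#   compute_reward("-1", 3)   -> 0.0   # predicts "no mistake" when there is one
#   compute_reward(None, 3)   -> 0.0   # nothing could be extracted
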
def preprocess_function(examples, tokenizer, template, max_length=2048):
    """Process examples for model training."""
    # List to store processed inputs
    # input_ids_list = []
    # attention_mask_list = []
    # labels_list = []

    prompt_list = []
    groundtruth_list = []

    for example in examples["data"]:
        # Prepare the prompt
        prompt = prepare_input_mistake(template, example)
        messages = [{"role": "user", "content": prompt}]

        # Format using the chat template
        prompt_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        prompt_list.append(prompt_text)
        groundtruth_list.append(example["mistake_index"])

        # # Tokenize
        # encoded = tokenizer(
        #     prompt_text,
        #     max_length=max_length,
        #     padding="max_length",
        #     truncation=True,
        #     return_tensors="pt"
        # )

        # input_ids_list.append(encoded["input_ids"][0])
        # attention_mask_list.append(encoded["attention_mask"][0])
        # labels_list.append(encoded["input_ids"][0].clone())

    # Create processed features
    result = {
        "prompt": prompt_list,
        "ground_truth": groundtruth_list,
        "original_example": examples["data"]
    }

    return result

class SaveBestModelCallback(TrainerCallback):
    """Callback to save the best model based on average reward."""
    def __init__(self):
        self.best_reward = -float('inf')

    def on_evaluate(self, args, state, control, metrics, **kwargs):
        current_reward = metrics.get("eval_reward", 0)
        if current_reward > self.best_reward:
            self.best_reward = current_reward
            # Save the best model
            output_dir = os.path.join(args.output_dir, "best_model")
            os.makedirs(output_dir, exist_ok=True)

            # Get the trainer from kwargs
            trainer = kwargs.get("trainer")
            if trainer:
                trainer.save_model(output_dir)
                logger.info(f"Saved best model with reward {current_reward}")

def reward_func(completions, ground_truth, **kwargs):
    """
    Compute rewards by comparing model completions to the ground truth.

    Args:
        completions: List of model completion strings
        ground_truth: List of ground truth values
        **kwargs: Additional dataset columns passed through by the trainer

    Returns:
        List[float]: One reward per completion (GRPOTrainer expects a list of floats)
    """
    rewards = []

    for completion, target in zip(completions, ground_truth):
        # Extract the model's prediction from the completion
        prediction = extract_answer(completion)

        # Convert the target if it's a tensor
        if isinstance(target, torch.Tensor):
            target = target.item()

        # Compute the reward
        reward = compute_reward(prediction, target)
        rewards.append(float(reward))

    return rewards

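# How the trainer is expected to call this function (my reading of TRL's GRPO
# reward-function convention, not something stated in this file): extra dataset columns
# such as "ground_truth" are forwarded as keyword arguments alongside the completions, e.g.
#   reward_func(completions=["... so the answer is \\boxed{3}"], ground_truth=[3])  # -> [1.0]
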
@dataclass
class ScriptArguments:
    """Arguments for the GRPO training script."""
    model_name_or_path: str = field(
        default="deepseek-ai/deepseek-math-7b-instruct",
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    train_data_dir: str = field(
        default="BIG-Bench-Mistake-Train",
        metadata={"help": "Directory containing training data files"}
    )
    val_data_dir: str = field(
        default="BIG-Bench-Mistake-Test",
        metadata={"help": "Directory containing validation data files"}
    )
    template_path: str = field(
        default="templates/critique_template.txt",
        metadata={"help": "Path to prompt template file"}
    )
    output_dir: str = field(
        default="results/grpo_finetune",
        metadata={"help": "Output directory for model checkpoints"}
    )
    seed: int = field(
        default=42,
        metadata={"help": "Random seed for initialization"}
    )
    max_length: int = field(
        default=2048,
        metadata={"help": "Maximum sequence length for the tokenizer"}
    )
    per_device_train_batch_size: int = field(
        default=1,
        metadata={"help": "Batch size per GPU for training"}
    )
    per_device_eval_batch_size: int = field(
        default=1,
        metadata={"help": "Batch size per GPU for evaluation"}
    )
    gradient_accumulation_steps: int = field(
        default=8,
        metadata={"help": "Number of update steps to accumulate before a backward pass"}
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={"help": "Max number of training samples to use (for debugging)"}
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={"help": "Max number of evaluation samples to use (for debugging)"}
    )
    learning_rate: float = field(
        default=5e-5,
        metadata={"help": "Learning rate for training"}
    )
    num_train_epochs: int = field(
        default=3,
        metadata={"help": "Number of training epochs"}
    )
    logging_steps: int = field(
        default=10,
        metadata={"help": "Log every X update steps"}
    )
    eval_steps: int = field(
        default=100,
        metadata={"help": "Run evaluation every X steps"}
    )
    save_steps: int = field(
        default=500,
        metadata={"help": "Save a checkpoint every X steps"}
    )
    warmup_steps: int = field(
        default=100,
        metadata={"help": "Linear warmup over this many steps"}
    )
    use_lora: bool = field(
        default=True,
        metadata={"help": "Whether to use LoRA for parameter-efficient fine-tuning"}
    )
    lora_r: int = field(
        default=16,
        metadata={"help": "LoRA attention dimension"}
    )
    lora_alpha: int = field(
        default=32,
        metadata={"help": "LoRA alpha parameter"}
    )
    lora_dropout: float = field(
        default=0.05,
        metadata={"help": "LoRA dropout probability"}
    )
    load_in_8bit: bool = field(
        default=False,
        metadata={"help": "Whether to load the model in 8-bit precision"}
    )
    load_in_4bit: bool = field(
        default=True,
        metadata={"help": "Whether to load the model in 4-bit precision"}
    )
    use_group_rewards: bool = field(
        default=True,
        metadata={"help": "Whether to use group rewards in GRPO"}
    )
    gumbel_samples: int = field(
        default=10,
        metadata={"help": "Number of Gumbel samples for GRPO"}
    )
    critic_multiple: float = field(
        default=0.5,
        metadata={"help": "Critic loss multiplier"}
    )
    deepspeed: Optional[str] = field(
        default=None,
        metadata={"help": "Path to a DeepSpeed config file"}
    )

def main():
    parser = HfArgumentParser(ScriptArguments)
    args = parser.parse_args_into_dataclasses()[0]

    # Set the random seed
    set_seed(args.seed)

    # Create the output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Load the model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    logger.info(f"Loading model {args.model_name_or_path}...")

    # Prepare the model with quantization if needed
    if args.load_in_8bit:
        quantization_config = {"load_in_8bit": True}
    elif args.load_in_4bit:
        quantization_config = {"load_in_4bit": True,
                               "bnb_4bit_compute_dtype": torch.float16,
                               "bnb_4bit_quant_type": "nf4"}
    else:
        quantization_config = None

    # For DeepSpeed compatibility, use torch_dtype=None so DeepSpeed handles fp16/bf16
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        torch_dtype=None,   # Let DeepSpeed handle the precision
        device_map=None,    # Don't use device_map with DeepSpeed
        quantization_config=quantization_config
    )

    # Apply LoRA if specified
    if args.use_lora:
        logger.info("Applying LoRA...")
        if args.load_in_8bit or args.load_in_4bit:
            model = prepare_model_for_kbit_training(model)

        peft_config = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]
        )
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

    # Load the template
    with open(args.template_path, 'r') as f:
        template = f.read().strip()

    # Load training and validation data
    train_files = glob.glob(os.path.join(args.train_data_dir, '*.jsonl'))
    val_files = glob.glob(os.path.join(args.val_data_dir, '*.jsonl'))

    # Use combined files if available
    if os.path.exists(os.path.join(args.train_data_dir, 'combined_train.jsonl')):
        train_files = [os.path.join(args.train_data_dir, 'combined_train.jsonl')]

    if os.path.exists(os.path.join(args.val_data_dir, 'combined_test.jsonl')):
        val_files = [os.path.join(args.val_data_dir, 'combined_test.jsonl')]

    logger.info(f"Loading training data from {len(train_files)} files...")
    train_data = []
    for file in train_files:
        train_data.extend(load_mistake_data(file))

    logger.info(f"Loading validation data from {len(val_files)} files...")
    val_data = []
    for file in val_files:
        val_data.extend(load_mistake_data(file))

    # Limit the number of samples if specified
    if args.max_train_samples and len(train_data) > args.max_train_samples:
        train_data = random.sample(train_data, args.max_train_samples)

    if args.max_eval_samples and len(val_data) > args.max_eval_samples:
        val_data = random.sample(val_data, args.max_eval_samples)

    logger.info(f"Loaded {len(train_data)} training examples and {len(val_data)} validation examples")

    # Create HF datasets
    train_hf_dataset = HFDataset.from_dict({"data": train_data})
    val_hf_dataset = HFDataset.from_dict({"data": val_data})

    # Apply the preprocessing function
    train_tokenize_func = partial(preprocess_function, tokenizer=tokenizer, template=template, max_length=args.max_length)
    val_tokenize_func = partial(preprocess_function, tokenizer=tokenizer, template=template, max_length=args.max_length)

    # Process the datasets
    train_dataset = train_hf_dataset.map(
        train_tokenize_func,
        batched=True,
        remove_columns=["data"],
        desc="Processing training dataset"
    )

    val_dataset = val_hf_dataset.map(
        val_tokenize_func,
        batched=True,
        remove_columns=["data"],
        desc="Processing validation dataset"
    )

    # Get the reward function
    reward_fn = reward_func

    # Create training arguments with DeepSpeed compatibility
    training_args = GRPOConfig(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        num_train_epochs=args.num_train_epochs,
        logging_steps=args.logging_steps,
        evaluation_strategy="no",      # No evaluation during training
        save_strategy="steps",
        save_steps=args.save_steps,
        warmup_steps=args.warmup_steps,
        save_total_limit=3,
        load_best_model_at_end=False,  # Don't load a best model since we're not evaluating
        weight_decay=0.01,
        # Let DeepSpeed handle mixed precision (set via its config file)
        bf16=True,
        report_to="none",
        max_grad_norm=1.0,
        remove_unused_columns=False,
        use_vllm=True,
        # Generation config
        temperature=0.6,
        top_p=0.95,
        num_generations=14,
        # Data processing
        max_prompt_length=1024,
        max_completion_length=1024,
        log_completions=True,
        do_eval=False,                 # Disable evaluation
    )

    # Create the GRPO trainer (no eval dataset or best-model callback, since evaluation is disabled)
    trainer = GRPOTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        reward_funcs=reward_fn,
    )

    # Train the model
    logger.info("Starting training with DeepSpeed...")
    trainer.train()

    # Save the final model - ensure this runs regardless of the accelerator setup
    trainer.save_model(os.path.join(args.output_dir, "final_model"))
    logger.info(f"Training completed. Final model saved to {os.path.join(args.output_dir, 'final_model')}")


if __name__ == "__main__":
    main()
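
# A plausible way to launch this script (illustrative, not part of the original upload;
# multi-GPU or DeepSpeed runs would typically go through `accelerate launch` or the
# `deepspeed` launcher with a matching config). All flags map to ScriptArguments fields:
#   python grpo_train.py \
#     --model_name_or_path deepseek-ai/deepseek-math-7b-instruct \
#     --train_data_dir BIG-Bench-Mistake-Train \
#     --val_data_dir BIG-Bench-Mistake-Test \
#     --template_path templates/critique_template.txt \
#     --output_dir results/grpo_finetune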