|
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers.trainer_utils import get_last_checkpoint
from trl import GRPOConfig, GRPOTrainer
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import torch
import os
from collections import defaultdict
|
|
|
|
|
os.environ["WANDB_PROJECT"] = "phi2-grpo-finetuning" |
|
os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
|
|
|
|
raw_data = load_dataset("OpenAssistant/oasst1", split="train") |
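
# Quick sanity check (illustrative, safe to remove): confirm the oasst1 columns this
# script relies on (message_tree_id, message_id, parent_id, role, text, rank) are present.
print(raw_data.column_names)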
|
|
|
|
|
|
|
# Group messages by conversation tree so each root prompt can be matched with its replies.
conversations = defaultdict(list)
for item in raw_data:
    conversations[item["message_tree_id"]].append(item)
|
|
|
|
|
# Build (prompt, chosen, rejected) preference pairs from each conversation tree.
pairs = []
for tree_id, msgs in conversations.items():
    # The root user prompt has role "prompter" and no parent message.
    prompt = next((m for m in msgs if m["role"] == "prompter" and m["parent_id"] is None), None)
    if not prompt:
        continue

    # Direct assistant replies to the root prompt.
    replies = [m for m in msgs if m["parent_id"] == prompt["message_id"]]
    if len(replies) < 2:
        continue

    # oasst1 stores human preference order in the "rank" field (0 = best). Compare
    # against None explicitly so rank 0 is not treated as missing; when fewer than
    # two replies are ranked, fall back to the original reply order.
    ranked_replies = [r for r in replies if r.get("rank") is not None]
    if len(ranked_replies) >= 2:
        ranked = sorted(ranked_replies, key=lambda x: x["rank"])
    else:
        ranked = replies[:2]

    pairs.append({
        "prompt": prompt["text"],
        "chosen": ranked[0]["text"],
        "rejected": ranked[-1]["text"],
    })
|
|
|
|
|
preference_dataset = Dataset.from_list(pairs)

# Cap the dataset at 1,000 pairs to keep this run short.
if len(preference_dataset) > 1000:
    preference_dataset = preference_dataset.select(range(1000))

print(f"Created {len(preference_dataset)} preference pairs for GRPO")
|
|
|
|
|
if len(preference_dataset) > 0:
    print("\nSample preference pair:")
    print(f"Prompt: {preference_dataset[0]['prompt'][:100]}...")
    print(f"Chosen: {preference_dataset[0]['chosen'][:100]}...")
    print(f"Rejected: {preference_dataset[0]['rejected'][:100]}...")
else:
    print("WARNING: No preference pairs were created. Check the dataset structure.")
|
|
|
|
|
# 4-bit NF4 quantization (QLoRA-style) so phi-2 fits on a single consumer GPU.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)
|
|
|
|
|
model_name = "microsoft/phi-2" |
|
tokenizer = AutoTokenizer.from_pretrained(model_name) |
|
model = AutoModelForCausalLM.from_pretrained( |
|
model_name, |
|
quantization_config=quantization_config, |
|
device_map="auto" |
|
) |
|
|
|
|
|
# LoRA adapters on phi-2's attention and MLP projections. Note that phi-2 uses
# q_proj/k_proj/v_proj/dense for attention and fc1/fc2 for the MLP; it has no
# o_proj or gate_proj/up_proj/down_proj (those belong to Llama-style models).
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "dense", "fc1", "fc2"],
)
|
|
|
|
|
# Prepare the quantized model for training (needed for QLoRA with gradient checkpointing),
# then attach the LoRA adapters.
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
|
|
|
|
|
# phi-2 ships without a pad token; reuse EOS and pad on the left for generation.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
|
|
|
|
|
# Placeholder reward: score each sampled completion by its word count. Because the
# prompts here are plain strings, the completions arrive as plain strings as well.
# Swap this out for a task-specific reward (or a reward model) for real training.
def reward_func(completions, **kwargs):
    return [len(c.split()) for c in completions]
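
# Illustrative check (not needed for training): GRPOTrainer expects each reward function
# to return one numeric score per completion, so this should print a list of two numbers.
print(reward_func(completions=["a short reply", "a somewhat longer example reply"]))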
|
|
|
|
|
training_args = GRPOConfig(
    output_dir="phi2-grpo-qlora",
    num_train_epochs=1,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    learning_rate=5e-6,
    logging_steps=10,
    save_steps=10,
    save_total_limit=1,
    fp16=True,
    remove_unused_columns=False,  # keep "chosen"/"rejected" so reward functions can see them
    report_to="none",  # set to "wandb" to log to the WANDB_PROJECT configured above
    optim="adamw_torch",
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    num_generations=2,  # completions sampled per prompt for the group-relative advantage
)
|
|
|
|
|
trainer = GRPOTrainer(
    model=model,
    args=training_args,
    train_dataset=preference_dataset,
    reward_funcs=reward_func,
    processing_class=tokenizer,  # pass the tokenizer at construction time rather than patching it on afterwards
)
|
|
|
|
|
# Resume only if a checkpoint actually exists; resume_from_checkpoint=True raises on a fresh run.
last_checkpoint = get_last_checkpoint(training_args.output_dir) if os.path.isdir(training_args.output_dir) else None
trainer.train(resume_from_checkpoint=last_checkpoint)
|
|
|
|
|
trainer.save_model("phi2-grpo-qlora-final") |
|
|