---
base_model: google/gemma-2-9b-it
library_name: peft
license: gemma
tags:
- trl
- sft
- generated_from_trainer
model-index:
- name: gemma-2-9b-it-lora-yt-titles
  results: []
datasets:
- AdamLucek/youtube-titles
pipeline_tag: text-generation
---

# LoRA Adapters for Gemma-2-9B-IT on YouTube Titles

These are LoRA adapters for [google/gemma-2-9b-it](https://huggingface.co/google/gemma-2-9b-it) trained on [AdamLucek/youtube-titles](https://huggingface.co/datasets/AdamLucek/youtube-titles).

The intended task is to tune Gemma 2 9B to generate YouTube titles closer in style to those of popular YouTubers. The data was prepped in the instruction-tuned (chat) token format, sketched below.
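
For reference, here is a minimal sketch of what one formatted training example might look like, assuming the dataset's `gemma2_9b_it_format` field follows Gemma's standard chat template (the example title is made up for illustration):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")

# A hypothetical prompt/completion pair rendered with Gemma's chat template
example = tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "Create a YouTube title about huggingface AI models"},
        {"role": "assistant", "content": "I Tested EVERY Hugging Face Model So You Don't Have To"},
    ],
    tokenize=False,
)
print(example)
# Roughly:
# <bos><start_of_turn>user
# Create a YouTube title about huggingface AI models<end_of_turn>
# <start_of_turn>model
# I Tested EVERY Hugging Face Model So You Don't Have To<end_of_turn>
```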

## Intended uses & limitations

See the original model page, [google/gemma-2-9b-it intended usage](https://huggingface.co/google/gemma-2-9b-it#intended-usage), for details about Gemma 2 9B usage, limitations, and ethical considerations.

## Usage

The code below shows how to load the LoRA adapters on top of the base model and run inference.

**Loading the Model & LoRA Adapters**

```python
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

# Load the pre-trained base model in 8-bit
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b-it",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
).eval()

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")

# Attach the LoRA adapters to the base model
model = PeftModel.from_pretrained(model, "AdamLucek/gemma-2-9b-it-lora-yt-titles", adapter_name="youtube_titles")
```
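
Because the adapters are registered under the name `youtube_titles`, you can switch them on and off with PEFT's adapter utilities. A minimal sketch, assuming the `model` object from the loading snippet above:

```python
# The adapter attached above is active by default. To compare against the
# un-adapted base model, temporarily disable it:
with model.disable_adapter():
    # generate() calls inside this block use the original Gemma-2-9B-IT weights
    pass

# If multiple adapters are loaded, re-select this one by the name used above
model.set_adapter("youtube_titles")
```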

**Inference**

```python
topic = "huggingface AI models"
messages = [
    {"role": "user", "content": f"Create a YouTube title about {topic}"}
]

# Apply chat template and prepare inputs
text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
# add_special_tokens=False because the chat template already prepends <bos>
inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False)
inputs = {k: v.to("cuda") for k, v in inputs.items()}

# Generate outputs
outputs = model.generate(
    **inputs,
    max_new_tokens=256,
    do_sample=True,
    top_p=0.95,
    temperature=0.1,
    repetition_penalty=1.2,
    eos_token_id=tokenizer.eos_token_id,
)

# Decode outputs
decoded = tokenizer.decode(outputs[0])
print(decoded)
```
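
`decoded` contains the full prompt plus the model's reply with Gemma's turn markers. If you only want the generated title text, one option (a small sketch, not part of the original card) is to decode just the newly generated tokens:

```python
# Keep only the tokens generated after the prompt and drop special tokens
prompt_length = inputs["input_ids"].shape[-1]
title = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
print(title)
```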

## Training procedure

### Training hyperparameters

Trained on a single A6000 using the following TRL SFT example script:

```bash
python \
    examples/scripts/sft.py \
    --model_name_or_path="google/gemma-2-9b-it" \
    --dataset_name="AdamLucek/youtube-titles" \
    --dataset_text_field="gemma2_9b_it_format" \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --gradient_accumulation_steps=4 \
    --max_grad_norm=1.0 \
    --learning_rate=5e-5 \
    --weight_decay=0.01 \
    --lr_scheduler_type="cosine" \
    --warmup_ratio=0.1 \
    --report_to="wandb" \
    --bf16 \
    --max_seq_length=2048 \
    --lora_r=16 \
    --lora_alpha=32 \
    --lora_target_modules q_proj k_proj v_proj o_proj \
    --load_in_8bit \
    --use_peft \
    --attn_implementation="eager" \
    --logging_steps=1 \
    --eval_strategy="steps" \
    --eval_steps=200 \
    --save_strategy="steps" \
    --save_steps=250 \
    --output_dir="models/gemma2" \
    --hub_model_id="gemma-2-9b-it-lora-yt-titles" \
    --push_to_hub \
    --num_train_epochs=3
```
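
For readers who prefer configuring LoRA in Python, the flags above map roughly onto the following PEFT config (a sketch derived from the CLI arguments, not copied from the training code). Note the effective batch size is 4 per device × 4 accumulation steps = 16.

```python
from peft import LoraConfig

# LoRA hyperparameters implied by --lora_r, --lora_alpha, and --lora_target_modules
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
```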

### Training results

[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="200" height="32"/>](https://wandb.ai/adam-lucek/huggingface/runs/vhp5k2tx)

| Training Loss | Epoch  | Step | Validation Loss |
|:-------------:|:------:|:----:|:---------------:|
| 2.2556        | 0.7619 | 200  | 2.0945          |
| 2.1866        | 1.5238 | 400  | 2.0988          |
| 2.3421        | 2.2857 | 600  | 2.2142          |

It achieves the following results on the evaluation set:
- Loss: 2.2142

### Framework versions

- PEFT 0.11.1
- Transformers 4.42.3
- Pytorch 2.0.1
- Datasets 2.20.0
- Tokenizers 0.19.1