Model Card for gemma-3-1b-thinking

This model is a fine-tuned version of google/gemma-3-1b-it. It has been trained using TRL.

Quick start

!pip install -qqq git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3 git+https://github.com/huggingface/trl.git@main bitsandbytes

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from huggingface_hub import login

hf_token = "hf_...."  # Put your Hugging Face token here
login(token=hf_token)

# Base model checkpoint used during training
base_model_ckpt = "google/gemma-3-1b-it"
# The fine-tuned LoRA adapter checkpoint on the Hugging Face Hub
adapter_ckpt = "TanishkB/gemma-3-1b-thinking"

# --- Load Model and Tokenizer ---

print(f"Loading tokenizer for base model: {base_model_ckpt}")
# Load the tokenizer associated with the BASE model
tokenizer = AutoTokenizer.from_pretrained(base_model_ckpt)

# Set a pad token if one is not already defined (Gemma's tokenizer usually defines one)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"Loading base model: {base_model_ckpt}")
# Load the base model
# bfloat16 matches the dtype used during training and keeps memory usage low
# device_map='auto' lets accelerate handle device placement (GPU if available)
# attn_implementation="eager" matches the setting used during training;
# consider "flash_attention_2" if your hardware/environment supports it for better speed
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_ckpt,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager" # Or "flash_attention_2" if supported
)
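
# Optional (untested sketch): bitsandbytes is installed above, so the base model could
# instead be loaded in 4-bit to cut memory usage, e.g.:
#   from transformers import BitsAndBytesConfig
#   bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
#   base_model = AutoModelForCausalLM.from_pretrained(
#       base_model_ckpt, quantization_config=bnb_config, device_map="auto"
#   )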

print(f"Loading LoRA adapter: {adapter_ckpt}")
# Load the PEFT model (adapter) on top of the base model
model = PeftModel.from_pretrained(base_model, adapter_ckpt)
# Merge the adapter into the model if you want a standalone model
# (see the merge-and-save sketch after this script):
# model = model.merge_and_unload()
print("Model loaded successfully.")
model.eval() # Set the model to evaluation mode


system_prompt = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

# user_question = "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?"
user_question = "What is the capital of France?"  # Try another type of question

# Create the chat message list, matching the training data structure
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": user_question},
]

print("\n--- Input Messages ---")
print(messages)

# add_generation_prompt=True adds the tokens needed to signal the model to start generating
# return_tensors='pt' prepares it for PyTorch
input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt"
)

# Move input tensors to the same device as the model
input_ids = input_ids.to(model.device)

print("\n--- Generating Response ---")

# --- Generate Response ---
with torch.no_grad():
    # Generate output tokens
    # max_new_tokens controls the maximum length of the generated response
    # You might adjust parameters like temperature, top_p, do_sample based on desired output style
    outputs = model.generate(
        input_ids=input_ids,
        max_new_tokens=512, # Adjust as needed, enough for reasoning and answer
        do_sample=True,     # Sample for potentially more diverse answers
        temperature=0.7,    # Controls randomness (lower = more deterministic)
        top_p=0.9,          # Nucleus sampling
        pad_token_id=tokenizer.eos_token_id 
    )

# --- Decode and Print Output ---

# Find the length of the input prompt in tokens
prompt_length = input_ids.shape[-1]
# Slice the output tensor to get only the generated tokens
generated_ids = outputs[0, prompt_length:]

# Decode the generated tokens into text
response_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

print("\n--- Model Response ---")
print(response_text)
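
# The system prompt asks for <reasoning>/<answer> tags, so the final answer can be pulled
# out of the decoded text. A minimal sketch (regex-based; falls back to the full response
# if the model did not follow the format):
import re
answer_match = re.search(r"<answer>\s*(.*?)\s*</answer>", response_text, re.DOTALL)
final_answer = answer_match.group(1) if answer_match else response_text
print("\n--- Extracted Answer ---")
print(final_answer)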

print("\n--- Full Output (including prompt formatting) ---")
# Decode the entire output sequence to see the full context if needed
full_output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Note: Gemma's chat template adds role tokens such as <start_of_turn> / <end_of_turn>;
# they would be visible here if skip_special_tokens=False were used, but they are usually
# not needed for the final response text.
print(full_output_text)
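
If you want a standalone model that does not require PEFT at inference time, the LoRA adapter can be merged into the base weights and saved. A minimal sketch, reusing the model and tokenizer loaded above (the output directory name is a placeholder):

merged_model = model.merge_and_unload()  # folds the LoRA weights into the base model
merged_model.save_pretrained("gemma-3-1b-thinking-merged")
tokenizer.save_pretrained("gemma-3-1b-thinking-merged")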

Training procedure

This model was trained with GRPO (Group Relative Policy Optimization), a method introduced in DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models.
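
For reference, a GRPO run with TRL is set up through GRPOConfig and GRPOTrainer. The sketch below is illustrative only: the dataset ("trl-lib/tldr"), the toy length-based reward, and the output directory are placeholders adapted from the TRL documentation, not the actual data or reward functions used to train this model.

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

# Placeholder dataset from the TRL docs, not the dataset used for this model
dataset = load_dataset("trl-lib/tldr", split="train")

# Toy reward function: prefers completions close to 20 characters
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

training_args = GRPOConfig(output_dir="gemma-3-1b-thinking-grpo", logging_steps=10)
trainer = GRPOTrainer(
    model="google/gemma-3-1b-it",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()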

Framework versions

  • TRL: 0.17.0.dev0
  • Transformers: 4.50.0.dev0
  • PyTorch: 2.5.1+cu121
  • Datasets: 3.3.1
  • Tokenizers: 0.21.0

Citations

Cite GRPO as:

@article{zhihong2024deepseekmath,
    title        = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
    author       = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
    year         = 2024,
    eprint       = {arXiv:2402.03300},
}

Cite TRL as:

@misc{vonwerra2022trl,
    title        = {{TRL: Transformer Reinforcement Learning}},
    author       = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec},
    year         = 2020,
    journal      = {GitHub repository},
    publisher    = {GitHub},
    howpublished = {\url{https://github.com/huggingface/trl}}
}