Model Card: Gemma 3 1B Mental Health Fine-Tuned
Model Overview
Model Name: Skshackster/gemma3-1b-mental-health-fine-tuned
Base Model: google/gemma-3-1b-it
Model Type: Transformer-based Causal Language Model
License: Gemma License (Check base model license for specifics)
Developed by: Skshackster
Hosted on: Hugging Face Hub
Intended Use: Conversational mental health support, research in mental health dialogue systems
This model is a fine-tuned version of Google's Gemma 3 1B Instruct model, specifically adapted for mental health therapy conversations. It has been trained on a dataset of 500 carefully curated mental health dialogues, totaling approximately 10 million tokens, to provide empathetic, safe, and supportive responses aligned with the role of a joyous mental therapy assistant.
Model Description
- Architecture: Gemma 3 1B is a transformer-based causal language model with 1 billion parameters, optimized for instruction-following and conversational tasks.
- Fine-Tuning Objective: Enhance the model's ability to engage in empathetic, supportive, and safe mental health conversations, adhering to strict ethical guidelines.
- Training Data: 500 conversational examples in JSONL format, each containing a sequence of messages with roles (`system`, `user`, `assistant`). The dataset focuses on mental health topics such as stress, sadness, relationship challenges, and emotional well-being.
- Token Count: Approximately 10 million tokens, derived from the conversational dataset.
- Training Framework: Hugging Face Transformers, PyTorch, and Datasets libraries.
- Training Duration: 3 epochs, with checkpoints saved periodically.
Dataset
The fine-tuning dataset consists of 500 mental health therapy conversations, formatted as JSONL. Each conversation includes:
- System Prompt: Defines the assistant's role as a "helpful and joyous mental therapy assistant," emphasizing safety, positivity, and ethical responses.
- User Messages: Describe mental health challenges, such as work-related stress, sadness, or relationship issues.
- Assistant Responses: Provide empathetic, supportive, and actionable advice, adhering to the system prompt's guidelines.
Dataset Statistics
- Total Examples: 500
- Estimated Tokens: ~10 million (based on tokenizer processing)
- Format: JSONL, with each line containing a `"messages"` list of role-content pairs.
- Roles: `system`, `user`, `assistant`
- Average Conversation Length: Variable, with some conversations exceeding 10 turns.
- Content Focus: Mental health support, covering topics like stress management, emotional resilience, and interpersonal relationships.
Data Preparation
- Loading: Dataset loaded using the Hugging Face `datasets` library.
- Splitting: The dataset was split into training and validation sets. Explicit splitting is recommended to avoid data leakage.
- Tokenization: Conversations serialized into a string format (`<|role|>content<|eos|>`) and tokenized using the Gemma 3 tokenizer with a maximum length of 1024 tokens, padded to ensure uniform sequence lengths.
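The snippet below is a minimal sketch of the loading and splitting steps described above; the file name, 10% validation fraction, and seed are illustrative assumptions, not the exact values used for this model.

```python
from datasets import load_dataset

# Each JSONL line is expected to hold a "messages" list of role/content pairs, e.g.
# {"messages": [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}, ...]}
raw = load_dataset("json", data_files="mental_health_conversations.jsonl", split="train")

# Hold out a distinct validation split so training and evaluation data do not overlap.
splits = raw.train_test_split(test_size=0.1, seed=42)
train_ds, eval_ds = splits["train"], splits["test"]
```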
Technical Details
Tokenization
- Tokenizer: `AutoTokenizer` from `google/gemma-3-1b-it`, using the fast tokenizer implementation.
- Serialization: Messages are concatenated with role markers (e.g., `<|system|>`, `<|user|>`, `<|assistant|>`) and terminated with the EOS token.
- Padding: Fixed-length padding to 1024 tokens per sequence.
- Labels: Input IDs copied as labels for causal language modeling.
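As an illustration of the serialization and tokenization described above, the sketch below (continuing from the loading snippet) builds the `<|role|>content<|eos|>` string and applies fixed-length padding; the helper names are hypothetical and the exact training code may differ.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it", use_fast=True)

def serialize(example):
    # "<|role|>content<|eos|>" for each message, per the format described above.
    return "".join(
        f"<|{m['role']}|>{m['content']}{tokenizer.eos_token}" for m in example["messages"]
    )

def tokenize(example):
    enc = tokenizer(
        serialize(example),
        max_length=1024,
        truncation=True,       # conversations longer than 1024 tokens are cut off
        padding="max_length",  # fixed-length padding to 1024 tokens
    )
    enc["labels"] = list(enc["input_ids"])  # input IDs copied as labels for causal LM
    return enc

train_tok = train_ds.map(tokenize, remove_columns=train_ds.column_names)
eval_tok = eval_ds.map(tokenize, remove_columns=eval_ds.column_names)
```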
Model Configuration
- Base Model: `google/gemma-3-1b-it`
- Precision: `bfloat16` for efficient training.
- Attention Implementation: `eager` (Note: Consider `flash_attention_2` for improved performance if available).
- Device Mapping: `device_map="auto"` for automatic sharding across available devices.
- Memory Optimization:
  - Gradient checkpointing enabled to reduce memory usage.
  - KV cache disabled during training to avoid conflicts with checkpointing.
  - `low_cpu_mem_usage=True` for faster model initialization.
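A sketch of how this configuration maps onto `from_pretrained`, under the settings listed above (not necessarily the exact training script):

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-3-1b-it",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",   # consider "flash_attention_2" if your stack supports it
    device_map="auto",
    low_cpu_mem_usage=True,
)
model.gradient_checkpointing_enable()  # trade extra compute for lower activation memory
model.config.use_cache = False         # KV cache conflicts with gradient checkpointing
```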
Training Hyperparameters
- Framework: Hugging Face `Trainer` API
- Batch Size:
  - Per-device batch size: 4 (training and evaluation)
  - Gradient accumulation steps: 16
  - Effective batch size: 64
- Epochs: 3
- Learning Rate: 1e-4
- Warmup Steps: 200
- Optimizer: Default (AdamW with Hugging Face defaults)
- Precision: `bf16=True` for mixed-precision training
- Evaluation: Performed periodically (Note: More frequent evaluation, e.g., every 100 steps, is recommended for small datasets)
- Checkpointing: Saved periodically, with a maximum of 3 checkpoints retained
- Hub Integration: Model checkpoints pushed to `Skshackster/gemma3-1b-mental-health-fine-tuned` on the Hugging Face Hub
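A hedged reconstruction of this setup with the standard `TrainingArguments`/`Trainer` API is shown below; the output directory, evaluation/save step counts, and logging backend are assumptions where this card does not state them.

```python
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="gemma3-1b-mental-health-fine-tuned",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=16,   # 4 x 16 = effective batch size of 64
    num_train_epochs=3,
    learning_rate=1e-4,
    warmup_steps=200,
    bf16=True,
    eval_strategy="steps",            # "evaluation_strategy" on older transformers releases
    eval_steps=100,                   # frequent evaluation, as recommended for small datasets
    save_strategy="steps",
    save_steps=100,
    save_total_limit=3,               # keep at most 3 checkpoints
    push_to_hub=True,
    hub_model_id="Skshackster/gemma3-1b-mental-health-fine-tuned",
    report_to="tensorboard",
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_tok,
    eval_dataset=eval_tok,
)
trainer.train()
```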
Usage
Prerequisites
- Python 3.8+
- Required libraries: `transformers`, `datasets`, `torch`, `huggingface_hub`
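A minimal inference sketch follows; it is illustrative only: the user message is invented, the system message paraphrases the persona described in the dataset section, and the prompt layout simply mirrors the role markers and EOS termination described for training.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "Skshackster/gemma3-1b-mental-health-fine-tuned"
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(
    repo_id, torch_dtype=torch.bfloat16, device_map="auto"
)

# Build the prompt with the same role markers and EOS termination used during fine-tuning.
prompt = (
    f"<|system|>You are a helpful and joyous mental therapy assistant.{tokenizer.eos_token}"
    f"<|user|>I've been feeling overwhelmed at work lately.{tokenizer.eos_token}"
    "<|assistant|>"
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.7)
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True))
```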
Notes
- Ensure the input format matches the training data (role markers and EOS tokens).
- For optimal performance, use a GPU with `bfloat16` support.
- The model is fine-tuned for mental health support and may not generalize to other domains without further training.
Ethical Considerations
- Intended Use: This model is designed for supportive mental health conversations, not as a replacement for professional therapy. Users should consult licensed mental health professionals for clinical needs.
- Safety: The system prompt enforces safe, positive, and unbiased responses, but users should monitor outputs for unintended behavior.
- Bias: The dataset is curated to avoid harmful content, but biases in the training data may persist. Users are encouraged to report any problematic outputs.
- Privacy: The model does not store or process personal data beyond the training dataset, which should be anonymized to protect user privacy.
- Limitations: The model may not handle complex mental health scenarios accurately and should be used as a supplementary tool.
Evaluation
- Metrics: Training metrics are available in TensorBoard-compatible format. Evaluation was performed periodically, but more frequent evaluation is recommended for small datasets.
- Performance: The model is expected to generate empathetic and contextually appropriate responses for mental health queries, but quantitative metrics (e.g., perplexity) are not provided.
- Validation: Ensure the validation set is distinct from the training set to obtain reliable performance metrics.
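If a rough quantitative signal is wanted, perplexity can be derived from the evaluation loss. The snippet below continues from the training sketch earlier in this card (it assumes the `trainer` and `eval_tok` defined there); because padded positions are not masked in that sketch, the number should be treated as approximate.

```python
import math

metrics = trainer.evaluate(eval_dataset=eval_tok)   # returns a dict including "eval_loss"
print(f"eval_loss:  {metrics['eval_loss']:.3f}")
print(f"perplexity: {math.exp(metrics['eval_loss']):.2f}")
```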
Limitations
- Dataset Size: With only 500 examples, the model may not capture the full diversity of mental health scenarios.
- Data Leakage: Using the same data for training and validation risks overfitting. Explicit splitting is recommended.
- Truncation: Conversations longer than 1024 tokens are truncated, potentially losing context.
- Domain Specificity: The model is optimized for mental health dialogues and may underperform in other domains.
- Compute Requirements: Fine-tuning and inference require significant computational resources.
Future Improvements
- Dataset Expansion: Include more diverse mental health conversations to improve robustness.
- Dynamic Padding: Replace fixed-length padding with dynamic batch padding to optimize memory usage (see the sketch after this list).
- Flash Attention: Use `flash_attention_2` for faster training if supported.
- Frequent Evaluation: Evaluate more frequently for better monitoring on small datasets.
- Bias Mitigation: Conduct bias audits and include adversarial testing to ensure fairness.
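One way to realize the dynamic-padding suggestion is sketched below, reusing the `tokenizer` and `serialize` helpers from the tokenization sketch (both assumptions of this card's examples): skip padding at tokenization time and let a collator pad each batch to its longest sequence.

```python
from transformers import DataCollatorForLanguageModeling

def tokenize_dynamic(example):
    # No padding here; the collator pads each batch to its longest sequence.
    return tokenizer(serialize(example), max_length=1024, truncation=True)

collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
# Then pass data_collator=collator to Trainer(...) along with the dynamically tokenized datasets.
```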
Contact
For questions, issues, or contributions, please contact the model developer via the Hugging Face Hub or open an issue in the model repository.
Acknowledgments
- Built on the `google/gemma-3-1b-it` model by Google.
- Powered by Hugging Face Transformers, Datasets, and PyTorch.