tags:
- health
- finetune
- gemma
---

# Model Card: Gemma 3 1B Mental Health Fine-Tuned

## Model Overview

**Model Name**: `Skshackster/gemma3-1b-mental-health-fine-tuned`
**Base Model**: `google/gemma-3-1b-it`
**Model Type**: Transformer-based causal language model
**License**: [Gemma License](https://ai.google.dev/gemma/terms) (check the base model license for specifics)
**Developed by**: Saurav Kumar Srivastava
**Hosted on**: Hugging Face Hub
**Intended Use**: Conversational mental health support, research in mental health dialogue systems

This model is a fine-tuned version of Google's Gemma 3 1B Instruct model, specifically adapted for mental health therapy conversations. It has been trained on a dataset of 500 carefully curated mental health dialogues, totaling approximately 10 million tokens, to provide empathetic, safe, and supportive responses aligned with the role of a joyous mental therapy assistant.

## Model Description

- **Architecture**: Gemma 3 1B is a transformer-based causal language model with 1 billion parameters, optimized for instruction-following and conversational tasks.
- **Fine-Tuning Objective**: Enhance the model's ability to engage in empathetic, supportive, and safe mental health conversations, adhering to strict ethical guidelines.
- **Training Data**: 500 conversational examples in JSONL format, each containing a sequence of messages with roles (`system`, `user`, `assistant`). The dataset focuses on mental health topics such as stress, sadness, relationship challenges, and emotional well-being.
- **Token Count**: Approximately 10 million tokens, derived from the conversational dataset.
- **Training Framework**: Hugging Face Transformers, Datasets, and PyTorch.
- **Training Duration**: 3 epochs, with checkpoints saved periodically.

## Dataset

The fine-tuning dataset consists of 500 mental health therapy conversations, formatted as JSONL. Each conversation includes the components below; an illustrative record follows the list.

- **System Prompt**: Defines the assistant’s role as a “helpful and joyous mental therapy assistant,” emphasizing safety, positivity, and ethical responses.
- **User Messages**: Describe mental health challenges, such as work-related stress, sadness, or relationship issues.
- **Assistant Responses**: Provide empathetic, supportive, and actionable advice, adhering to the system prompt’s guidelines.
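
For illustration, a record in this schema might look like the following. This is a hypothetical example written for this card, not a line taken from the dataset, and it is pretty-printed here; in the actual JSONL file each record occupies a single line.

```json
{
  "messages": [
    {"role": "system", "content": "You are a helpful and joyous mental therapy assistant. Always answer as safely, positively, and ethically as possible."},
    {"role": "user", "content": "Work has been overwhelming lately and I can't seem to switch off in the evenings."},
    {"role": "assistant", "content": "It sounds like work has been weighing on you heavily. Setting a clear end-of-day routine and noting one small win each evening can help create distance from work stress. Would you like to talk through what a calmer week could look like for you?"}
  ]
}
```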

### Dataset Statistics

- **Total Examples**: 500
- **Estimated Tokens**: ~10 million (based on tokenizer processing)
- **Format**: JSONL, with each line containing a `"messages"` list of role-content pairs.
- **Roles**: `system`, `user`, `assistant`
- **Average Conversation Length**: Variable, with some conversations exceeding 10 turns.
- **Content Focus**: Mental health support, covering topics like stress management, emotional resilience, and interpersonal relationships.

### Data Preparation

- **Loading**: Dataset loaded using the Hugging Face `datasets` library (a loading and splitting sketch follows this list).
- **Splitting**: The dataset was split into training and validation sets. Explicit splitting is recommended to avoid data leakage.
- **Tokenization**: Conversations serialized into a string format (`<|role|>content<|eos|>`) and tokenized using the Gemma 3 tokenizer with a maximum length of 1024 tokens, padded to ensure uniform sequence lengths.
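
A minimal sketch of the loading and splitting steps, assuming the conversations are stored in a local `mental_health_conversations.jsonl` file (the filename and the 90/10 split ratio are illustrative, not the exact values used):

```python
from datasets import load_dataset

# Each line of the JSONL file holds one {"messages": [...]} conversation.
dataset = load_dataset("json", data_files="mental_health_conversations.jsonl", split="train")

# Hold out an explicit validation split to avoid evaluating on training data.
splits = dataset.train_test_split(test_size=0.1, seed=42)
train_ds, eval_ds = splits["train"], splits["test"]

print(f"train examples: {len(train_ds)}, validation examples: {len(eval_ds)}")
```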

## Technical Details

### Tokenization

- **Tokenizer**: `AutoTokenizer` from `google/gemma-3-1b-it`, using the fast tokenizer implementation.
- **Serialization**: Messages are concatenated with role markers (e.g., `<|system|>`, `<|user|>`, `<|assistant|>`) and terminated with the EOS token.
- **Padding**: Fixed-length padding to 1024 tokens per sequence.
- **Labels**: Input IDs copied as labels for causal language modeling (see the sketch after this list).
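
A minimal sketch of this serialization and tokenization, assuming the `<|role|>content<|eos|>` convention described above (the helper names are illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it", use_fast=True)
MAX_LENGTH = 1024

def serialize(example):
    # Concatenate every turn as <|role|>content, terminating each with the EOS token.
    return "".join(
        f"<|{msg['role']}|>{msg['content']}{tokenizer.eos_token}"
        for msg in example["messages"]
    )

def tokenize(example):
    encoded = tokenizer(
        serialize(example),
        max_length=MAX_LENGTH,
        truncation=True,
        padding="max_length",  # fixed-length padding, as used during fine-tuning
    )
    # For causal language modeling, labels are a copy of the input IDs.
    encoded["labels"] = encoded["input_ids"].copy()
    return encoded

# Applied over the prepared splits, e.g.:
# train_tokenized = train_ds.map(tokenize, remove_columns=train_ds.column_names)
```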

### Model Configuration

- **Base Model**: `google/gemma-3-1b-it`
- **Precision**: `bfloat16` for efficient training.
- **Attention Implementation**: `eager` (consider `flash_attention_2` for improved performance if available).
- **Device Mapping**: `device_map="auto"` for automatic sharding across available devices.
- **Memory Optimization** (see the loading sketch after this list):
  - Gradient checkpointing enabled to reduce memory usage.
  - KV cache disabled during training to avoid conflicts with checkpointing.
  - `low_cpu_mem_usage=True` for faster model initialization.
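
A loading sketch matching this configuration (a minimal example; adjust to your hardware):

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-3-1b-it",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",  # flash_attention_2 may be faster where supported
    device_map="auto",
    low_cpu_mem_usage=True,
)

# Memory optimizations used during fine-tuning.
model.gradient_checkpointing_enable()
model.config.use_cache = False  # the KV cache conflicts with gradient checkpointing
```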

### Training Hyperparameters

- **Framework**: Hugging Face `Trainer` API (a configuration sketch follows this list)
- **Batch Size**:
  - Per-device batch size: 4 (training and evaluation)
  - Gradient accumulation steps: 16
  - Effective batch size: 64
- **Epochs**: 3
- **Learning Rate**: 1e-4
- **Warmup Steps**: 200
- **Optimizer**: Default (AdamW with Hugging Face defaults)
- **Precision**: `bf16=True` for mixed-precision training
- **Evaluation**: Performed periodically (more frequent evaluation, e.g., every 100 steps, is recommended for small datasets)
- **Checkpointing**: Saved periodically, with a maximum of 3 checkpoints retained
- **Hub Integration**: Model checkpoints pushed to `Skshackster/gemma3-1b-mental-health-fine-tuned` on the Hugging Face Hub
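
Taken together, the reported hyperparameters correspond roughly to a `Trainer` setup like the following sketch. The output directory, evaluation/save intervals, and dataset variable names are illustrative, and the argument names follow recent `transformers` releases:

```python
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="gemma3-1b-mental-health-fine-tuned",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=16,  # effective batch size of 64
    learning_rate=1e-4,
    warmup_steps=200,
    bf16=True,
    eval_strategy="steps",  # evaluate periodically; ~100 steps suits small datasets
    eval_steps=100,
    save_strategy="steps",
    save_total_limit=3,     # keep at most 3 checkpoints
    push_to_hub=True,
    hub_model_id="Skshackster/gemma3-1b-mental-health-fine-tuned",
)

trainer = Trainer(
    model=model,                    # model loaded as in the configuration sketch above
    args=args,
    train_dataset=train_tokenized,  # tokenized splits from the data preparation steps
    eval_dataset=eval_tokenized,
)
trainer.train()
```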

## Usage

### Prerequisites

- Python 3.8+
- Required libraries: `transformers`, `datasets`, `torch`, `huggingface_hub` (an inference example follows this list)
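
A minimal inference sketch; the prompt format mirrors the role markers described under Technical Details, and the system prompt wording and sampling settings are illustrative:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Skshackster/gemma3-1b-mental-health-fine-tuned"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

# Build a prompt that matches the training-time serialization.
prompt = (
    "<|system|>You are a helpful and joyous mental therapy assistant. "
    "Always answer as safely and positively as possible." + tokenizer.eos_token
    + "<|user|>I've been feeling really stressed about work lately." + tokenizer.eos_token
    + "<|assistant|>"
)

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.7)

# Decode only the newly generated tokens.
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True))
```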

### Notes

- Ensure the input format matches the training data (role markers and EOS tokens).
- For optimal performance, use a GPU with `bfloat16` support.
- The model is fine-tuned for mental health support and may not generalize to other domains without further training.

## Ethical Considerations

- **Intended Use**: This model is designed for supportive mental health conversations, not as a replacement for professional therapy. Users should consult licensed mental health professionals for clinical needs.
- **Safety**: The system prompt enforces safe, positive, and unbiased responses, but users should monitor outputs for unintended behavior.
- **Bias**: The dataset is curated to avoid harmful content, but biases in the training data may persist. Users are encouraged to report any problematic outputs.
- **Privacy**: The model does not store or process personal data beyond the training dataset, which should be anonymized to protect user privacy.
- **Limitations**: The model may not handle complex mental health scenarios accurately and should be used as a supplementary tool.

## Evaluation

- **Metrics**: Training metrics are available in TensorBoard-compatible format. Evaluation was performed periodically, but more frequent evaluation is recommended for small datasets.
- **Performance**: The model is expected to generate empathetic and contextually appropriate responses for mental health queries, but quantitative metrics (e.g., perplexity) are not reported (a sketch for deriving perplexity follows this list).
- **Validation**: Ensure the validation set is distinct from the training set to obtain reliable performance metrics.
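
If a quantitative figure is needed, validation perplexity can be derived from the evaluation loss, for example (a sketch reusing the `trainer` object from the training sketch above):

```python
import math

eval_metrics = trainer.evaluate()
print(f"validation perplexity: {math.exp(eval_metrics['eval_loss']):.2f}")
```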

## Limitations

- **Dataset Size**: With only 500 examples, the model may not capture the full diversity of mental health scenarios.
- **Data Leakage**: Using the same data for training and validation risks overfitting. Explicit splitting is recommended.
- **Truncation**: Conversations longer than 1024 tokens are truncated, potentially losing context.
- **Domain Specificity**: The model is optimized for mental health dialogues and may underperform in other domains.
- **Compute Requirements**: Fine-tuning and inference require significant computational resources.

## Future Improvements

- **Dataset Expansion**: Include more diverse mental health conversations to improve robustness.
- **Dynamic Padding**: Replace fixed-length padding with dynamic batch padding to optimize memory usage (see the sketch after this list).
- **Flash Attention**: Use `flash_attention_2` for faster training if supported.
- **Frequent Evaluation**: Evaluate more frequently for better monitoring on small datasets.
- **Bias Mitigation**: Conduct bias audits and include adversarial testing to ensure fairness.
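
For the dynamic-padding improvement, a data collator that pads each batch only to its longest sequence could replace the fixed 1024-token padding, for example (a sketch; tokenization would then use `padding=False`):

```python
from transformers import DataCollatorForLanguageModeling

# Pads each batch to its longest example and derives causal-LM labels (mlm=False),
# replacing the per-example fixed-length padding step.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Passed to the Trainer via `data_collator=data_collator`.
```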

## Contact

For questions, issues, or contributions, please contact the model developer via the Hugging Face Hub or open an issue in the model repository.

## Acknowledgments

- Built on the `google/gemma-3-1b-it` model by Google.
- Powered by Hugging Face Transformers, Datasets, and PyTorch.