Skshackster committed (verified)
Commit 16e8831 · Parent(s): 5fe8909

Update README.md

Files changed (1): README.md (+128, -1)

README.md CHANGED

tags:
  - health
  - finetune
  - gemma
---

# Model Card: Gemma 3 1B Mental Health Fine-Tuned

## Model Overview

- **Model Name**: `Skshackster/gemma3-1b-mental-health-fine-tuned`
- **Base Model**: `google/gemma-3-1b-it`
- **Model Type**: Transformer-based causal language model
- **License**: [Gemma License](https://ai.google.dev/gemma/terms) (check the base model license for specifics)
- **Developed by**: Saurav Kumar Srivastava
- **Hosted on**: Hugging Face Hub
- **Intended Use**: Conversational mental health support; research on mental health dialogue systems

This model is a fine-tuned version of Google's Gemma 3 1B Instruct model, adapted for mental health therapy conversations. It was trained on a dataset of 500 carefully curated mental health dialogues, totaling approximately 10 million tokens, to provide empathetic, safe, and supportive responses in the role of a "helpful and joyous mental therapy assistant".

## Model Description

- **Architecture**: Gemma 3 1B is a transformer-based causal language model with 1 billion parameters, optimized for instruction-following and conversational tasks.
- **Fine-Tuning Objective**: Enhance the model's ability to engage in empathetic, supportive, and safe mental health conversations, adhering to strict ethical guidelines.
- **Training Data**: 500 conversational examples in JSONL format, each containing a sequence of messages with roles (`system`, `user`, `assistant`). The dataset focuses on mental health topics such as stress, sadness, relationship challenges, and emotional well-being.
- **Token Count**: Approximately 10 million tokens, derived from the conversational dataset.
- **Training Framework**: Hugging Face Transformers, PyTorch, and Datasets libraries.
- **Training Duration**: 3 epochs, with checkpoints saved periodically.

## Dataset

The fine-tuning dataset consists of 500 mental health therapy conversations, formatted as JSONL. Each conversation includes:

- **System Prompt**: Defines the assistant's role as a "helpful and joyous mental therapy assistant," emphasizing safety, positivity, and ethical responses.
- **User Messages**: Describe mental health challenges, such as work-related stress, sadness, or relationship issues.
- **Assistant Responses**: Provide empathetic, supportive, and actionable advice, adhering to the system prompt's guidelines.

### Dataset Statistics
- **Total Examples**: 500
- **Estimated Tokens**: ~10 million (measured with the Gemma 3 tokenizer)
- **Format**: JSONL, with each line containing a `"messages"` list of role-content pairs (an illustrative record is shown after this list).
- **Roles**: `system`, `user`, `assistant`
- **Average Conversation Length**: Variable, with some conversations exceeding 10 turns.
- **Content Focus**: Mental health support, covering topics like stress management, emotional resilience, and interpersonal relationships.
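
For illustration, a record has the shape shown below. The wording is invented for this card (only the "helpful and joyous mental therapy assistant" role comes from the actual system prompt), and the record is pretty-printed here; in the JSONL file each record occupies a single line.

```json
{"messages": [
  {"role": "system", "content": "You are a helpful and joyous mental therapy assistant. Always answer as safely, positively, and supportively as possible."},
  {"role": "user", "content": "I've been feeling overwhelmed at work lately and I can't seem to switch off in the evenings."},
  {"role": "assistant", "content": "That sounds exhausting, and it makes sense that you feel this way. Could we start by looking at what a typical day looks like for you, so we can find one small place to create some breathing room?"}
]}
```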

### Data Preparation
- **Loading**: The dataset is loaded with the Hugging Face `datasets` library.
- **Splitting**: The dataset was split into training and validation sets; an explicit, held-out split is recommended to avoid data leakage (a loading/splitting sketch follows this list).
- **Tokenization**: Conversations are serialized into a string format (`<|role|>content<|eos|>`) and tokenized with the Gemma 3 tokenizer at a maximum length of 1024 tokens, padded to a uniform sequence length.
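
A minimal sketch of this loading and splitting step is shown below. The file name and the 90/10 split ratio are illustrative assumptions, not the exact values used for this model.

```python
from datasets import load_dataset

# Load the JSONL conversations; each line holds a {"messages": [...]} record.
# "mental_health_conversations.jsonl" is a placeholder file name.
raw = load_dataset("json", data_files="mental_health_conversations.jsonl", split="train")

# Hold out an explicit validation split to avoid train/validation leakage.
splits = raw.train_test_split(test_size=0.1, seed=42)
train_ds, eval_ds = splits["train"], splits["test"]
```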

## Technical Details

### Tokenization
- **Tokenizer**: `AutoTokenizer` from `google/gemma-3-1b-it`, using the fast tokenizer implementation.
- **Serialization**: Messages are concatenated with role markers (e.g., `<|system|>`, `<|user|>`, `<|assistant|>`), and each turn is terminated with the EOS token.
- **Padding**: Fixed-length padding to 1024 tokens per sequence.
- **Labels**: Input IDs are copied as labels for causal language modeling (see the sketch after this list).
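
The sketch below illustrates this serialization and tokenization scheme as described above. The helper names are illustrative, and it continues from the `train_ds`/`eval_ds` split in the previous sketch.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it", use_fast=True)
MAX_LEN = 1024  # fixed sequence length used during fine-tuning

def serialize(example):
    # Concatenate every turn as <|role|>content followed by the EOS token.
    text = "".join(
        f"<|{m['role']}|>{m['content']}{tokenizer.eos_token}"
        for m in example["messages"]
    )
    return {"text": text}

def tokenize(example):
    enc = tokenizer(
        example["text"],
        max_length=MAX_LEN,
        truncation=True,
        padding="max_length",  # fixed-length padding to 1024 tokens
    )
    enc["labels"] = enc["input_ids"].copy()  # causal LM: labels mirror the input IDs
    return enc

train_tok = train_ds.map(serialize).map(tokenize, remove_columns=["messages", "text"])
eval_tok = eval_ds.map(serialize).map(tokenize, remove_columns=["messages", "text"])
```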

### Model Configuration
- **Base Model**: `google/gemma-3-1b-it`
- **Precision**: `bfloat16` for efficient training.
- **Attention Implementation**: `eager` (Note: Consider `flash_attention_2` for improved performance if available).
- **Device Mapping**: `device_map="auto"` for automatic sharding across available devices.
- **Memory Optimization**:
  - Gradient checkpointing enabled to reduce memory usage.
  - KV cache disabled during training to avoid conflicts with checkpointing.
  - `low_cpu_mem_usage=True` for faster model initialization.
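
A loading sketch under this configuration is shown below; it uses the standard Transformers loading API, and the exact arguments of the original run may differ.

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-3-1b-it",
    torch_dtype=torch.bfloat16,    # bfloat16 precision
    attn_implementation="eager",   # flash_attention_2 may be faster where supported
    device_map="auto",             # automatic placement across available devices
    low_cpu_mem_usage=True,        # leaner, faster initialization
)

model.gradient_checkpointing_enable()  # trade extra compute for lower memory
model.config.use_cache = False         # KV cache conflicts with gradient checkpointing
```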

### Training Hyperparameters
- **Framework**: Hugging Face `Trainer` API
- **Batch Size**:
  - Per-device batch size: 4 (training and evaluation)
  - Gradient accumulation steps: 16
  - Effective batch size: 64
- **Epochs**: 3
- **Learning Rate**: 1e-4
- **Warmup Steps**: 200
- **Optimizer**: AdamW with the Hugging Face `Trainer` defaults
- **Precision**: `bf16=True` for mixed-precision training
- **Evaluation**: Performed periodically (Note: More frequent evaluation, e.g., every 100 steps, is recommended for small datasets)
- **Checkpointing**: Saved periodically, with a maximum of 3 checkpoints retained
- **Hub Integration**: Model checkpoints pushed to `Skshackster/gemma3-1b-mental-health-fine-tuned` on the Hugging Face Hub
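
A `Trainer` sketch with these hyperparameters follows. The evaluation, saving, and logging step counts are assumptions, since the card only states that evaluation and checkpointing happened periodically.

```python
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="gemma3-1b-mental-health-fine-tuned",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=16,  # effective batch size: 4 x 16 = 64
    num_train_epochs=3,
    learning_rate=1e-4,
    warmup_steps=200,
    bf16=True,
    eval_strategy="steps",           # named evaluation_strategy in older transformers releases
    eval_steps=100,                  # assumption; the card only says "periodically"
    save_strategy="steps",
    save_steps=100,                  # assumption
    save_total_limit=3,              # keep at most 3 checkpoints
    logging_steps=10,
    report_to="tensorboard",
    push_to_hub=True,
    hub_model_id="Skshackster/gemma3-1b-mental-health-fine-tuned",
)

trainer = Trainer(
    model=model,            # from the configuration sketch above
    args=args,
    train_dataset=train_tok,
    eval_dataset=eval_tok,
)
trainer.train()
```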

## Usage

### Prerequisites
- Python 3.8+
- Required libraries: `transformers`, `datasets`, `torch`, `huggingface_hub`
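
### Quick Start

Install the prerequisites (`pip install transformers datasets torch huggingface_hub`), then run an inference sketch like the one below. It assumes the prompt is serialized with the same role markers used during training (see Tokenization above); the generation settings are illustrative defaults, not tuned values.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "Skshackster/gemma3-1b-mental-health-fine-tuned"
tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModelForCausalLM.from_pretrained(
    repo, torch_dtype=torch.bfloat16, device_map="auto"
)

system = "You are a helpful and joyous mental therapy assistant."  # role described in this card
user = "I've been feeling stressed about work and can't relax in the evenings."

# Mirror the training-time serialization: <|role|>content + EOS per turn.
prompt = (
    f"<|system|>{system}{tokenizer.eos_token}"
    f"<|user|>{user}{tokenizer.eos_token}"
    f"<|assistant|>"
)

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.7)

# Decode only the newly generated assistant turn.
print(tokenizer.decode(output[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True))
```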

### Notes
- Ensure the input format matches the training data (role markers and EOS tokens).
- For optimal performance, use a GPU with `bfloat16` support.
- The model is fine-tuned for mental health support and may not generalize to other domains without further training.

## Ethical Considerations

- **Intended Use**: This model is designed for supportive mental health conversations, not as a replacement for professional therapy. Users should consult licensed mental health professionals for clinical needs.
- **Safety**: The system prompt enforces safe, positive, and unbiased responses, but users should monitor outputs for unintended behavior.
- **Bias**: The dataset is curated to avoid harmful content, but biases in the training data may persist. Users are encouraged to report any problematic outputs.
- **Privacy**: The model does not store or process personal data beyond the training dataset, which should be anonymized to protect user privacy.
- **Limitations**: The model may not handle complex mental health scenarios accurately and should be used as a supplementary tool.

## Evaluation

- **Metrics**: Training metrics are available in TensorBoard-compatible format. Evaluation was performed periodically, but more frequent evaluation is recommended for small datasets.
- **Performance**: The model is expected to generate empathetic and contextually appropriate responses for mental health queries, but quantitative metrics (e.g., perplexity) are not provided.
- **Validation**: Ensure the validation set is distinct from the training set to obtain reliable performance metrics.
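
If quantitative numbers are needed, one simple option is to derive validation perplexity from the `Trainer` evaluation loss; the snippet below is a sketch that assumes the `trainer` object from the training sketch above.

```python
import math

# Evaluate on the held-out split and convert the loss to perplexity.
metrics = trainer.evaluate()
perplexity = math.exp(metrics["eval_loss"])
print(f"eval_loss={metrics['eval_loss']:.4f}  perplexity={perplexity:.2f}")
```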

## Limitations

- **Dataset Size**: With only 500 examples, the model may not capture the full diversity of mental health scenarios.
- **Data Leakage**: Using the same data for training and validation risks overfitting. Explicit splitting is recommended.
- **Truncation**: Conversations longer than 1024 tokens are truncated, potentially losing context.
- **Domain Specificity**: The model is optimized for mental health dialogues and may underperform in other domains.
- **Compute Requirements**: Fine-tuning and inference require significant computational resources.

## Future Improvements

- **Dataset Expansion**: Include more diverse mental health conversations to improve robustness.
- **Dynamic Padding**: Replace fixed-length padding with dynamic per-batch padding to optimize memory usage (see the collator sketch after this list).
- **Flash Attention**: Use `flash_attention_2` for faster training if supported.
- **Frequent Evaluation**: Evaluate more frequently for better monitoring on small datasets.
- **Bias Mitigation**: Conduct bias audits and include adversarial testing to ensure fairness.
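
For the dynamic-padding item, one common approach is a collator that pads each batch to its longest sequence rather than to a fixed 1024 tokens. The sketch below is a suggestion, not part of the original training setup, and it assumes the dataset is tokenized without fixed-length padding.

```python
from transformers import DataCollatorForLanguageModeling

# Pads each batch to its longest sequence and builds causal-LM labels,
# so fixed-length padding and manual label copying can be dropped.
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Passed to the Trainer in place of the default collator:
# trainer = Trainer(model=model, args=args, train_dataset=train_tok,
#                   eval_dataset=eval_tok, data_collator=collator)
```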

## Contact

For questions, issues, or contributions, please contact the model developer via the Hugging Face Hub or open an issue in the model repository.

## Acknowledgments

- Built on the `google/gemma-3-1b-it` model by Google.
- Powered by Hugging Face Transformers, Datasets, and PyTorch.