hyperclock committed
Commit b08430f · verified · 1 Parent(s): e600784

Update app.py

Files changed (1)
  1. app.py +169 -0

app.py CHANGED
@@ -0,0 +1,169 @@
+ import sys
+ import logging
+
+ import datasets
+ from datasets import load_dataset
+ from peft import LoraConfig
+ import torch
+ import transformers
+ from trl import SFTTrainer
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
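+ # NOTE: BitsAndBytesConfig is imported but never used below; no quantization
+ # config is passed to the model, which is loaded directly in bfloat16.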
+
+ logger = logging.getLogger(__name__)
+
+
+ ###################
+ # Hyper-parameters
+ ###################
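+ # Single-epoch bf16 run; the effective batch size is per_device_train_batch_size (4)
+ # x number of processes, since gradient_accumulation_steps is 1. Non-reentrant
+ # gradient checkpointing trades recompute for lower activation memory.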
+ training_config = {
+     "bf16": True,
+     "do_eval": False,
+     "learning_rate": 5.0e-06,
+     "log_level": "info",
+     "logging_steps": 20,
+     "logging_strategy": "steps",
+     "lr_scheduler_type": "cosine",
+     "num_train_epochs": 1,
+     "max_steps": -1,
+     "output_dir": "./checkpoint_dir",
+     "overwrite_output_dir": True,
+     "per_device_eval_batch_size": 4,
+     "per_device_train_batch_size": 4,
+     "remove_unused_columns": True,
+     "save_steps": 100,
+     "save_total_limit": 1,
+     "seed": 0,
+     "gradient_checkpointing": True,
+     "gradient_checkpointing_kwargs": {"use_reentrant": False},
+     "gradient_accumulation_steps": 1,
+     "warmup_ratio": 0.2,
+ }
+
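+ # LoRA config: rank-16 adapters with alpha 32 on the linear layers selected by
+ # PEFT's "all-linear" shortcut; the base-model weights stay frozen and only the
+ # adapter parameters are trained.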
+ peft_config = {
+     "r": 16,
+     "lora_alpha": 32,
+     "lora_dropout": 0.05,
+     "bias": "none",
+     "task_type": "CAUSAL_LM",
+     "target_modules": "all-linear",
+     "modules_to_save": None,
+ }
+ train_conf = TrainingArguments(**training_config)
+ peft_conf = LoraConfig(**peft_config)
+
+
+ ###############
+ # Setup logging
+ ###############
+ logging.basicConfig(
+     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+     datefmt="%Y-%m-%d %H:%M:%S",
+     handlers=[logging.StreamHandler(sys.stdout)],
+ )
+ log_level = train_conf.get_process_log_level()
+ logger.setLevel(log_level)
+ datasets.utils.logging.set_verbosity(log_level)
+ transformers.utils.logging.set_verbosity(log_level)
+ transformers.utils.logging.enable_default_handler()
+ transformers.utils.logging.enable_explicit_format()
+
+ # Log a brief summary on each process
+ logger.warning(
+     f"Process rank: {train_conf.local_rank}, device: {train_conf.device}, n_gpu: {train_conf.n_gpu}"
+     + f" distributed training: {bool(train_conf.local_rank != -1)}, 16-bit training: {train_conf.fp16}"
+ )
+ logger.info(f"Training/evaluation parameters {train_conf}")
+ logger.info(f"PEFT parameters {peft_conf}")
+
+
+ ################
+ # Model Loading
+ ################
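+ # flash_attention_2 assumes the flash-attn package is installed and a GPU that
+ # supports it (typically Ampere or newer); device_map=None leaves device
+ # placement to the Trainer.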
+ checkpoint_path = "microsoft/Phi-4-mini-instruct"
+ model_kwargs = dict(
+     use_cache=False,
+     trust_remote_code=True,
+     attn_implementation="flash_attention_2",  # loading the model with flash-attention support
+     torch_dtype=torch.bfloat16,
+     device_map=None
+ )
+ model = AutoModelForCausalLM.from_pretrained(checkpoint_path, **model_kwargs)
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
+ tokenizer.model_max_length = 2048
+ tokenizer.pad_token = tokenizer.unk_token  # use unk rather than eos token to prevent endless generation
+ tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
+ tokenizer.padding_side = 'right'
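+ # Note: right padding is the usual choice during training; the evaluation
+ # section below switches to left padding.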
+
+
+ ##################
+ # Data Processing
+ ##################
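+ # Render each example's "messages" into a single prompt string under a new
+ # "text" column, which the trainer consumes via dataset_text_field.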
+ def apply_chat_template(
+     example,
+     tokenizer,
+ ):
+     messages = example["messages"]
+     example["text"] = tokenizer.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=False)
+     return example
+
+
+ train_dataset, test_dataset = load_dataset("HuggingFaceH4/ultrachat_200k", split=["train_sft", "test_sft"])
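+ # The train_sft split holds roughly 200k UltraChat conversations, so a full
+ # epoch is a sizeable run.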
+ column_names = list(train_dataset.features)
+
+ processed_train_dataset = train_dataset.map(
+     apply_chat_template,
+     fn_kwargs={"tokenizer": tokenizer},
+     num_proc=10,
+     remove_columns=column_names,
+     desc="Applying chat template to train_sft",
+ )
+
+ processed_test_dataset = test_dataset.map(
+     apply_chat_template,
+     fn_kwargs={"tokenizer": tokenizer},
+     num_proc=10,
+     remove_columns=column_names,
+     desc="Applying chat template to test_sft",
+ )
+
+
+ ###########
+ # Training
+ ###########
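+ # Note: passing max_seq_length, dataset_text_field, tokenizer and packing
+ # directly to SFTTrainer matches older TRL releases; newer TRL versions move
+ # most of these onto SFTConfig (and take processing_class instead of tokenizer).
+ # packing=True concatenates examples into fixed-length 2048-token blocks.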
+ trainer = SFTTrainer(
+     model=model,
+     args=train_conf,
+     peft_config=peft_conf,
+     train_dataset=processed_train_dataset,
+     eval_dataset=processed_test_dataset,
+     max_seq_length=2048,
+     dataset_text_field="text",
+     tokenizer=tokenizer,
+     packing=True
+ )
+ train_result = trainer.train()
+ metrics = train_result.metrics
+ trainer.log_metrics("train", metrics)
+ trainer.save_metrics("train", metrics)
+ trainer.save_state()
+
+
+ #############
+ # Evaluation
+ #############
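+ # trainer.evaluate() computes loss-based metrics over the processed test split;
+ # the switch to left padding is the convention for generation-style evaluation.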
+ tokenizer.padding_side = 'left'
+ metrics = trainer.evaluate()
+ metrics["eval_samples"] = len(processed_test_dataset)
+ trainer.log_metrics("eval", metrics)
+ trainer.save_metrics("eval", metrics)
+
+
+ #############
+ # Save model
+ #############
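+ # With a peft_config attached, save_model is expected to write only the LoRA
+ # adapter weights and adapter config to output_dir, not a merged full model.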
+ trainer.save_model(train_conf.output_dir)