anonamename committed on
Commit 691c3fc · verified · 1 Parent(s): c63a187

Upload turing-motors/Heron-NVILA-Lite-15B

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ llm/tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,162 @@
1
+ ---
2
+ license: apache-2.0
3
+ license_link: https://huggingface.co/Qwen/Qwen2.5-14B-Instruct/blob/main/LICENSE
4
+ language:
5
+ - ja
6
+ - en
7
+ tags:
8
+ - vila
9
+ - nvila
10
+ - conversational
11
+ - multimodal
12
+ base_model:
13
+ - Qwen/Qwen2.5-14B-Instruct
14
+ - Efficient-Large-Model/paligemma-siglip-so400m-patch14-448
15
+ ---
16
+ # Heron NVILA-Lite 15B
17
+
18
+ Heron NVILA-Lite 15B is a vision-language model trained for Japanese, based on the [NVILA](https://arxiv.org/abs/2412.04468)-Lite architecture.
19
+
20
+ ## Model Overview
21
+
22
+ * **Developer**: [Turing Inc.](https://www.turing-motors.com/)
23
+ * **Vision Encoder**: [paligemma-siglip-so400m-patch14-448](https://huggingface.co/Efficient-Large-Model/paligemma-siglip-so400m-patch14-448)
24
+ * **Projector**: mlp_downsample_3x3_fix (a 3×3 spatial-downsampling MLP; see the sketch below)
25
+ * **LLM**: [Qwen2.5-14B-Instruct](https://huggingface.co/Qwen/Qwen2.5-14B-Instruct)
26
+ * **Supported Languages**: Japanese, English
27
+
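+ For reference, `base_projector.py` in this repository implements the `mlp_downsample_3x3_fix` projector as a 3×3 spatial downsampling step followed by a small MLP that maps SigLIP features into the LLM embedding space. The snippet below is only a shape-level sketch of that module (dimensions taken from `config.json`; the plain `reshape` is a stand-in for the spatially aware `DownSample3x3BlockFix`), not the packaged implementation:
+
+ ```python
+ import torch
+ import torch.nn as nn
+
+ mm_hidden_size, hidden_size = 1152, 5120  # SigLIP feature dim and Qwen2.5-14B hidden size (from config.json)
+
+ # After 3x3 downsampling, each visual token carries 9 concatenated patch features.
+ projector = nn.Sequential(
+     nn.LayerNorm(mm_hidden_size * 9),
+     nn.Linear(mm_hidden_size * 9, mm_hidden_size * 3),
+     nn.GELU(),
+     nn.LayerNorm(mm_hidden_size * 3),
+     nn.Linear(mm_hidden_size * 3, hidden_size),
+     nn.GELU(),
+     nn.Linear(hidden_size, hidden_size),
+ )
+
+ # Toy shape check: 36 visual tokens -> 4 downsampled tokens in the LLM embedding space.
+ tokens = torch.randn(1, 36, mm_hidden_size)
+ grouped = tokens.reshape(1, 4, mm_hidden_size * 9)  # stand-in for DownSample3x3BlockFix (which groups spatially adjacent patches)
+ print(projector(grouped).shape)  # torch.Size([1, 4, 5120])
+ ```
+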
28
+ ## Setup
29
+
30
+ ```bash
31
+ # Transformers 4.46.0 and 4.49.0 have also been confirmed to work; other versions may work but are untested.
32
+ pip install transformers==4.45.0 accelerate opencv-python torchvision einops pillow
33
+ pip install git+https://github.com/bfshi/scaling_on_scales.git
34
+ ```
35
+
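+ To confirm the installed version matches one of the tested ones (a convenience check only, not part of the model code):
+
+ ```python
+ import transformers
+ print(transformers.__version__)  # 4.45.0 here; 4.46.0 and 4.49.0 are also reported to work
+ ```
+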
36
+ ## Usage
37
+
38
+ ```python
39
+ from transformers import AutoConfig, AutoModel
40
+
41
+ model_path = "turing-motors/Heron-NVILA-Lite-15B"
42
+
43
+ # Option 1: instantiate the model from its config
44
+ config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
45
+ model = AutoModel.from_config(config, trust_remote_code=True, device_map="auto")
46
+
47
+ # Option 2: load it directly with from_pretrained
48
+ model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map="auto")
49
+
50
+ # show chat_template
51
+ print(model.tokenizer.chat_template)
52
+
53
+ # example: generate from raw text
54
+ response = model.generate_content(["こんにちは"])
55
+ print(response)
56
+ print("---" * 40)
57
+
58
+ # example: generate with text + image
59
+ from PIL import Image
60
+ import requests
61
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
62
+ image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
63
+ response = model.generate_content([image, "画像を説明してください。"])
64
+ print(response)
65
+ print("---" * 40)
66
+
67
+ # example: generate with a custom generation_config
68
+ from PIL import Image
69
+ import requests
70
+ from transformers import GenerationConfig
71
+ generation_config = {
72
+ "max_new_tokens": 512,
73
+ "temperature": 0.5,
74
+ "do_sample": True,
75
+ }
76
+ generation_config = GenerationConfig(**generation_config)
77
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
78
+ image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
79
+ response = model.generate_content(
80
+ [image, "画像を説明してください。"],
81
+ generation_config=generation_config
82
+ )
83
+ print(response)
84
+ print("---" * 40)
85
+
86
+ # example: generate with interleaved text and images (text + image + text + image + text)
87
+ from PIL import Image
88
+ import requests
89
+ url_list = [
90
+ "https://images.unsplash.com/photo-1694831404826-3400c48c188d?q=80&w=2070&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
91
+ "https://images.unsplash.com/photo-1693240876439-473af88b4ed7?q=80&w=1974&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
92
+ ]
93
+ images = [
94
+ Image.open(requests.get(url, stream=True).raw).convert("RGB") for url in url_list
95
+ ]
96
+ response = model.generate_content([
97
+ images[0],
98
+ "これは日本の横断歩道の画像です",
99
+ images[1],
100
+ "これはオーストリアの信号機の画像です",
101
+ "各画像に写っている歩行者用信号機の色は何色ですか?"])
102
+ print(response)
103
+ print("---" * 40)
104
+ ```
105
+
106
+ ## Training Summary
107
+
108
+ | Stage | Training | Data Sources | Samples |
109
+ |--------|-------------------------------|-------------------------------|-------------|
110
+ | Stage1 | Projector | [Japanese image text pairs](https://gitlab.llm-jp.nii.ac.jp/datasets/llm-jp-japanese-image-text-pairs), [LLaVA-Pretrain](https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain) | 1.1M |
111
+ | Stage2 | Projector, LLM | Filtered MOMIJI 3 snapshots (CC-MAIN-2024-46, CC-MAIN-2024-51, CC-MAIN-2025-05) | 13M |
112
+ | | | [Japanese image text pairs (subset)](https://gitlab.llm-jp.nii.ac.jp/datasets/llm-jp-japanese-image-text-pairs), [Japanese interleaved data (subset)](https://gitlab.llm-jp.nii.ac.jp/datasets/llm-jp-japanese-interleaved-data), [mmc4-core (subset)](https://github.com/allenai/mmc4), [coyo-700m (subset)](https://huggingface.co/datasets/kakaobrain/coyo-700m), [wikipedia_ja](https://huggingface.co/datasets/turing-motors/Wikipedia-Vision-JA), [llava_pretrain_ja](https://huggingface.co/datasets/turing-motors/LLaVA-Pretrain-JA), [stair_captions](http://captions.stair.center/) | 20M |
113
+ | Stage3 | Vision Encoder, Projector, LLM | [llava-instruct-v1_5-en-subset-358k](https://huggingface.co/datasets/llm-jp/llava-instruct-v1_5-en-subset-358k), [llava-instruct-ja](https://huggingface.co/datasets/llm-jp/llava-instruct-ja), [japanese-photos-conv](https://huggingface.co/datasets/llm-jp/japanese-photos-conversation), [ja-vg-vqa](https://huggingface.co/datasets/llm-jp/ja-vg-vqa-conversation), [synthdog-ja (subset)](https://huggingface.co/datasets/naver-clova-ix/synthdog-ja), [ai2d](https://huggingface.co/datasets/lmms-lab/ai2d), [synthdog-en](https://huggingface.co/datasets/naver-clova-ix/synthdog-en), [sherlock](https://github.com/allenai/sherlock) | 1.4M |
114
+
115
+ ## Evaluation
116
+ Evaluation was performed with [llm-jp-eval-mm](https://github.com/llm-jp/llm-jp-eval-mm). Scores for models other than ours are taken from the [llm-jp-eval-mm leaderboard](https://llm-jp.github.io/llm-jp-eval-mm/) and the [Asagi website](https://uehara-mech.github.io/asagi-vlm?v=1).
117
+
118
+ | Model | LLM Size | Heron-Bench overall LLM (%) | JA-VLM-Bench-In-the-Wild LLM (/5.0) | JA-VG-VQA-500 LLM (/5.0) |
119
+ |--------------------------------|----------|------------------------------|-------------------------------------|--------------------------|
120
+ | **Heron NVILA-Lite 2B** | 1.5B | 52.8 | 3.52 | 3.50 |
121
+ | **Heron NVILA-Lite 15B** | 14B | 59.6 | 4.2 | 3.82 |
122
+ | [LLaVA-CALM2-SigLIP](https://huggingface.co/cyberagent/llava-calm2-siglip) | 7B | 43.3 | 3.15 | 3.21 |
123
+ | [Llama-3-EvoVLM-JP-v2](https://huggingface.co/SakanaAI/Llama-3-EvoVLM-JP-v2) | 8B | 39.3 | 2.92 | 2.96 |
124
+ | [VILA-jp](https://huggingface.co/llm-jp/llm-jp-3-vila-14b) | 13B | 57.2 | 3.69 | 3.62 |
125
+ | [Asagi-14B](https://huggingface.co/MIL-UT/Asagi-14B) | 13B | 55.8 | 3.44 | 3.84 |
126
+ | [Qwen2-VL 7B Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) | 7B | 55.5 | 3.61 | 3.6 |
127
+ | GPT-4o | - | 87.6 | 3.85 | 3.58 |
128
+
129
+
130
+ ## Risks and Limitations
131
+
132
+ This model is experimental and has not been thoroughly evaluated for ethical or legal compliance. Exercise caution when using it in sensitive applications.
133
+
134
+ ## License
135
+
136
+ - Model weights are licensed under [Apache License 2.0](https://huggingface.co/Qwen/Qwen2.5-14B-Instruct/blob/main/LICENSE).
137
+ - Users must comply with [OpenAI terms of use](https://openai.com/policies/terms-of-use) due to the inclusion of GPT-4-generated synthetic data.
138
+
139
+ ## How to cite
140
+
141
+ ```bibtex
142
+ @misc{HeronNVILALite15B,
143
+ title = {Heron NVILA-Lite 15B},
144
+ author = {Shingo Yokoi},
145
+ year = {2025},
146
+ url = {https://huggingface.co/turing-motors/Heron-NVILA-Lite-15B},
147
+ }
148
+ ```
149
+
150
+ ## Citations
151
+
152
+ ```bibtex
153
+ @misc{liu2025nvilaefficientfrontiervisual,
154
+ title={NVILA: Efficient Frontier Visual Language Models},
155
+ author={Zhijian Liu and Ligeng Zhu and Baifeng Shi and Zhuoyang Zhang and Yuming Lou and Shang Yang and Haocheng Xi and Shiyi Cao and Yuxian Gu and Dacheng Li and Xiuyu Li and Yunhao Fang and Yukang Chen and Cheng-Yu Hsieh and De-An Huang and An-Chieh Cheng and Vishwesh Nath and Jinyi Hu and Sifei Liu and Ranjay Krishna and Daguang Xu and Xiaolong Wang and Pavlo Molchanov and Jan Kautz and Hongxu Yin and Song Han and Yao Lu},
156
+ year={2025},
157
+ eprint={2412.04468},
158
+ archivePrefix={arXiv},
159
+ primaryClass={cs.CV},
160
+ url={https://arxiv.org/abs/2412.04468},
161
+ }
162
+ ```
__init__.py ADDED
File without changes
auto_processor.py ADDED
@@ -0,0 +1,330 @@
1
+ import copy
2
+ import os
3
+ import os.path as osp
4
+ import warnings
5
+ from collections import defaultdict
6
+ from typing import List, Union
7
+
8
+ import torch
9
+ from transformers import AutoConfig, AutoImageProcessor, AutoModel, AutoProcessor, AutoTokenizer
10
+ from transformers.feature_extraction_utils import BatchFeature
11
+ from transformers.image_utils import ImageInput, VideoInput
12
+ from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
13
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
14
+ from transformers.utils import logging
15
+
16
+ from .constants import DEFAULT_IMAGE_TOKEN, MEDIA_TOKENS
17
+ from .media import Image, Video, extract_media
18
+ from .mm_utils import process_image, process_images
19
+ from .tokenizer_utils import tokenize_conversation
20
+
21
+ def fetch_image_url_or_fpath(url_or_fpath):
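+ # Resolve an image reference: download http(s) URLs to a temporary file, strip a file:// prefix, or return an existing local path as-is.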
22
+ if url_or_fpath.startswith("http") or url_or_fpath.startswith("https"):
23
+ import tempfile
24
+ import requests
25
+
26
+ # Download the image to a temporary file
27
+ temp_dir = tempfile.mkdtemp()
28
+ temp_file = os.path.join(temp_dir, os.path.basename(url_or_fpath))
29
+
30
+ response = requests.get(url_or_fpath, stream=True)
31
+ response.raise_for_status()
32
+
33
+ with open(temp_file, "wb") as f:
34
+ for chunk in response.iter_content(chunk_size=8192):
35
+ f.write(chunk)
36
+
37
+ return temp_file
38
+ elif url_or_fpath.startswith("file://"):
39
+ fpath = url_or_fpath.replace("file://", "")
40
+ assert osp.exists(fpath), f"File {fpath} does not exist"
41
+ return fpath
42
+ elif osp.exists(url_or_fpath):
43
+ assert osp.isfile(url_or_fpath), f"File {url_or_fpath} is not a file"
44
+ return url_or_fpath
45
+ else:
46
+ raise ValueError(f"Unsupported image path: {url_or_fpath}")
47
+
48
+
49
+ def _pad_fn(input_ids_list, padding_value=0, target_len=None, padding_side="left"):
50
+ # tensor shape is (batch_size, seq_len)
51
+ max_len = max([ids.shape[1] for ids in input_ids_list])
52
+ if target_len is not None:
53
+ assert target_len >= max_len, "target_len must be greater than or equal to max_len"
54
+ max_len = target_len
55
+
56
+ new_input_ids_list = []
57
+ for i, input_ids in enumerate(input_ids_list):
58
+ pad_tensor = torch.ones_like(input_ids) * padding_value
59
+ curr_len = input_ids.shape[1]
60
+ pad_tensor = pad_tensor[:, : max_len - curr_len]
61
+ if padding_side == "right":
62
+ input_ids = torch.cat((input_ids, pad_tensor), dim=1)
63
+ else:
64
+ input_ids = torch.cat((pad_tensor, input_ids), dim=1)
65
+ new_input_ids_list.append(input_ids)
66
+ return torch.cat(new_input_ids_list, dim=0)
67
+
68
+
69
+ class VILAProcessorKwargs(ProcessingKwargs, total=False):
70
+ _defaults = {
71
+ "text_kwargs": {
72
+ "padding": False,
73
+ },
74
+ }
75
+
76
+
77
+
78
+
79
+ class VILAProcessor(ProcessorMixin):
80
+ # attributes = ["image_processor", "tokenizer"]
81
+ attributes = []
82
+ # valid_kwargs = ["chat_template"]
83
+ valid_kwargs = []
84
+ # image_processor_class = "VILAImageProcessor"
85
+ # tokenizer_class = ("VILATokenizer", "VILATokenizerFast")
86
+
87
+ def __init__(self, image_processor=None, tokenizer=None, chat_template=None, config=None, **kwargs):
88
+ # self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
89
+ # self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
90
+ self.image_token = MEDIA_TOKENS["image"]
91
+ self.video_token = MEDIA_TOKENS["video"]
92
+ self.config = config
93
+ self.image_processor = image_processor
94
+ self.tokenizer = tokenizer
95
+
96
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
97
+
98
+ @classmethod
99
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
100
+ if os.path.isdir(pretrained_model_name_or_path):
101
+ pretrained_model_name_or_path = pretrained_model_name_or_path
102
+ else:
103
+ print(f"pretrained_model_name_or_path {pretrained_model_name_or_path} is not a directory, downloading")
104
+ from huggingface_hub import snapshot_download
105
+
106
+ pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path)
107
+
108
+ image_processor = AutoImageProcessor.from_pretrained(
109
+ osp.join(pretrained_model_name_or_path, "vision_tower"), trust_remote_code=True
110
+ )
111
+ tokenizer = AutoTokenizer.from_pretrained(
112
+ osp.join(pretrained_model_name_or_path, "llm"), trust_remote_code=True
113
+ )
114
+ config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
115
+ return cls(image_processor=image_processor, tokenizer=tokenizer, config=config)
116
+
117
+ def __repr__(self):
118
+ return (
119
+ f"VILAProcessor(image_processor={self.image_processor}, tokenizer={self.tokenizer}, config={self.config})"
120
+ )
121
+
122
+ def __call__(
123
+ self,
124
+ conversation,
125
+ images: ImageInput = None,
126
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
127
+ videos: VideoInput = None,
128
+ **kwargs: Unpack[VILAProcessorKwargs],
129
+ ) -> BatchFeature:
130
+ if images is not None:
131
+ warnings.warn("images is not supported in __call__")
132
+
133
+ input_ids = []
134
+ media = defaultdict(list)
135
+ media_config = defaultdict(dict)
136
+ for conv in conversation:
137
+ feat = self.__single_call__(conv, images, text, videos, **kwargs)
138
+ input_ids.append(feat.input_ids)
139
+ for name in feat.media:
140
+ media[name] += feat.media[name]
141
+ for name in feat.media_config:
142
+ media_config[name].update(feat.media_config[name])
143
+
144
+ return BatchFeature(
145
+ data={
146
+ # "input_ids": torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.pad_token_id),
147
+ "input_ids": __pad_fn(
148
+ input_ids,
149
+ padding_value=self.tokenizer.pad_token_id,
150
+ padding_side="left",
151
+ ),
152
+ "media": media,
153
+ "media_config": media_config,
154
+ }
155
+ )
156
+
157
+ def __single_call__(
158
+ self,
159
+ conversation,
160
+ images: ImageInput = None,
161
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
162
+ videos: VideoInput = None,
163
+ **kwargs: Unpack[VILAProcessorKwargs],
164
+ ) -> BatchFeature:
165
+ # TODO: should be merged with llava_arch.py/generate_content()
166
+ # TODO (extract and preprocess should be done together, as the preprocess of image and video can be different, i.e. when dynamic res is used)
167
+ conversation = copy.deepcopy(conversation)
168
+ media = extract_media(conversation, self.config)
169
+ # Process media
170
+ media_config = defaultdict(dict)
171
+ for name in media:
172
+ if name == "image":
173
+ if len(media["image"]) == 1 and self.config.image_aspect_ratio in ["dynamic", "dynamic_s2"]:
174
+ self.config.image_processor = self.image_processor
175
+ if self.config.image_aspect_ratio == "dynamic":
176
+ images = process_image(media["image"][0], self.config, None, enable_dynamic_res=True).half()
177
+ conversation[0]["value"] = conversation[0]["value"].replace(
178
+ DEFAULT_IMAGE_TOKEN, f"{DEFAULT_IMAGE_TOKEN}\n" * images.shape[0]
179
+ )
180
+ else:
181
+ if type(self.config.s2_scales) is str:
182
+ self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
183
+ images, block_sizes = process_image(
184
+ media["image"][0], self.config, None, enable_dynamic_s2=True
185
+ )
186
+ images = images.half()
187
+ media_config[name]["block_sizes"] = [block_sizes]
188
+ else:
189
+ images = process_images(media["image"], self.image_processor, self.config).half()
190
+ media[name] = [image for image in images]
191
+ elif name == "video":
192
+ media[name] = [
193
+ process_images(images, self.image_processor, self.config).half()
194
+ for images in media[name]
195
+ ]
196
+ else:
197
+ raise ValueError(f"Unsupported media type: {name}")
198
+ input_ids = tokenize_conversation(conversation, self.tokenizer, add_generation_prompt=True).cuda().unsqueeze(0)
199
+ # Set up the generation config
200
+ return BatchFeature(data={"input_ids": input_ids, "media": media, "media_config": media_config})
201
+
202
+ def batch_decode(self, *args, **kwargs):
203
+ """
204
+ This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
205
+ refer to the docstring of this method for more information.
206
+ """
207
+ return self.tokenizer.batch_decode(*args, **kwargs)
208
+
209
+ def decode(self, *args, **kwargs):
210
+ """
211
+ This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
212
+ the docstring of this method for more information.
213
+ """
214
+ return self.tokenizer.decode(*args, **kwargs)
215
+
216
+ def post_process_image_text_to_text(self, generated_outputs):
217
+ """
218
+ Post-process the output of the model to decode the text.
219
+
220
+ Args:
221
+ generated_outputs (`torch.Tensor` or `np.ndarray`):
222
+ The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
223
+ or `(sequence_length,)`.
224
+
225
+ Returns:
226
+ `List[str]`: The decoded text.
227
+ """
228
+ return self.tokenizer.batch_decode(
229
+ generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
230
+ )
231
+
232
+ @property
233
+ def model_input_names(self):
234
+ tokenizer_input_names = self.tokenizer.model_input_names
235
+ image_processor_input_names = self.image_processor.model_input_names
236
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
237
+
238
+ # inputs = processor(conversation=llavaconv, padding=True, return_tensors="pt")
239
+ def apply_chat_template(self, conversation, add_generation_prompt=True, **kwargs):
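+ # Convert an OpenAI/Qwen-style conversation (role/content dicts) into VILA's "from"/"value" format, fetching any referenced image paths or URLs.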
240
+ vila_conv = []
241
+ for chat in conversation:
242
+ vila_chat = {"from": "", "value": []}
243
+ if chat["role"] == "user":
244
+ # user allows to input image and text
245
+ vila_chat["from"] = "human"
246
+ for content in chat["content"]:
247
+ if content["type"] == "image":
248
+ if "path" in content:
249
+ # VILA style
250
+ vila_chat["value"].append(Image(fetch_image_url_or_fpath(content["path"])))
251
+ elif "image" in content:
252
+ # Qwen style
253
+ vila_chat["value"].append(Image(fetch_image_url_or_fpath(content["image"])))
254
+ else:
255
+ raise ValueError(f"Unsupported content type `image`: {content}, `image` and `path` are required")
256
+ elif content["type"] == "text":
257
+ vila_chat["value"].append(content["text"])
258
+ # NOTE(ligeng): video supports are needed here
259
+ else:
260
+ raise ValueError(f"Unsupported content type: {content['type']}")
261
+ elif chat["role"] == "assistant":
262
+ vila_chat["from"] = "gpt"
263
+ for content in chat["content"]:
264
+ assert content["type"] == "text", f"Unsupported content type: {content['type']}"
265
+ vila_chat["value"].append(content["text"])
266
+ vila_conv.append(vila_chat)
267
+
268
+ return vila_conv
269
+
270
+
271
+ if __name__ == "__main__":
272
+ # gpt style: user, assistant
273
+ # vila style: human, gpt
274
+ gpt_conv = [
275
+ {
276
+ "role": "user",
277
+ "content": [
278
+ {"type": "image", "path": "demo_images/demo_img_1.png"},
279
+ {"type": "text", "text": "Describe this image."},
280
+ ],
281
+ }
282
+ ]
283
+
284
+ llavaconv = [
285
+ {
286
+ "from": "human",
287
+ "value": [
288
+ PIL.Image.open("demo_images/demo_img_1.png"),
289
+ "Describe this image.",
290
+ ],
291
+ }
292
+ ]
293
+
294
+ processor = AutoProcessor.from_pretrained(output_dir, trust_remote_code=True)
295
+ inputs = processor.apply_chat_template(conversation=gpt_conv, padding=True, return_tensors="pt")
296
+ # model = llava.load("Efficient-Large-Model/qwen25_2B_3x3-sft").cuda()
297
+ # print(model)
298
+ model_path = "NVILA-Lite-2B-hf-preview"
299
+ model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map="auto")
300
+ # res = model.generate_content(["how are you today?"])
301
+ # print(model.config)
302
+ # print(model.tokenizer)
303
+ # print(res)
304
+ # exit(0)
305
+
306
+ processor = VILAProcessor(
307
+ config=model.config,
308
+ image_processor=model.vision_tower.image_processor,
309
+ tokenizer=model.tokenizer,
310
+ )
311
+
312
+ # TODO: add padding, return_tensors,
313
+ inputs = processor(conversation=llavaconv, padding=True, return_tensors="pt")
314
+ print(inputs.keys(), inputs.input_ids.shape, [_.shape for _ in inputs.image])
315
+ print("vila conv pass")
316
+
317
+ inputs = processor.apply_chat_template(conversation=gpt_conv, padding=True, return_tensors="pt")
318
+ print(inputs.keys(), inputs.input_ids.shape, [_.shape for _ in inputs.image])
319
+ print("gpt conv pass")
320
+
321
+ output_ids = model.generate(
322
+ input_ids=inputs.input_ids,
323
+ media={
324
+ "image": inputs.image,
325
+ },
326
+ media_config={"image": {}},
327
+ generation_config=model.generation_config,
328
+ max_new_tokens=100,
329
+ )
330
+ print(output_ids)
base_projector.py ADDED
@@ -0,0 +1,228 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import re
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
22
+
23
+
24
+ class IdentityMap(nn.Module):
25
+ def __init__(self):
26
+ super().__init__()
27
+
28
+ def forward(self, x, *args, **kwargs):
29
+ return x
30
+
31
+ @property
32
+ def config(self):
33
+ return {"mm_projector_type": "identity"}
34
+
35
+
36
+ class SimpleResBlock(nn.Module):
37
+ def __init__(self, channels):
38
+ super().__init__()
39
+ self.pre_norm = nn.LayerNorm(channels)
40
+
41
+ self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels))
42
+
43
+ def forward(self, x):
44
+ x = self.pre_norm(x)
45
+ return x + self.proj(x)
46
+
47
+
48
+ class DownSampleBlock(nn.Module):
49
+ def forward(self, x):
50
+ vit_embeds = x
51
+ h = w = int(vit_embeds.shape[1] ** 0.5)
52
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
53
+ vit_embeds = self.flat_square(vit_embeds)
54
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
55
+ return vit_embeds
56
+
57
+ def flat_square(self, x):
58
+ n, w, h, c = x.size()
59
+ if w % 2 == 1:
60
+ x = torch.concat([x, torch.zeros((n, 1, h, c), dtype=x.dtype).to(x.device)], dim=1).contiguous()
61
+ n, w, h, c = x.size()
62
+ if h % 2 == 1:
63
+ x = torch.concat([x, torch.zeros((n, w, 1, c), dtype=x.dtype).to(x.device)], dim=2).contiguous()
64
+ n, w, h, c = x.size()
65
+ x = x.contiguous()
66
+ x = x.view(n, w, int(h / 2), int(c * 2))
67
+ x = x.permute(0, 2, 1, 3).contiguous()
68
+ x = x.view(n, int(h / 2), int(w / 2), int(c * 4))
69
+ x = x.permute(0, 2, 1, 3).contiguous()
70
+ return x
71
+
72
+
73
+ class DownSample2x2BlockFix(nn.Module):
74
+ def forward(self, x):
75
+ vit_embeds = x
76
+ h = w = int(vit_embeds.shape[1] ** 0.5)
77
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
78
+ vit_embeds = flat_square_2x2(vit_embeds)
79
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
80
+ return vit_embeds
81
+
82
+
83
+ def flat_square_2x2(x):
84
+ n, w, h, c = x.size()
85
+ if w % 2 == 1:
86
+ x = torch.concat([x, torch.zeros((n, 1, h, c), dtype=x.dtype).to(x.device)], dim=1).contiguous()
87
+ n, w, h, c = x.size()
88
+ x = x.contiguous()
89
+ if h % 2 == 1:
90
+ x = torch.concat([x, torch.zeros((n, w, 1, c), dtype=x.dtype).to(x.device)], dim=2).contiguous()
91
+ n, w, h, c = x.size()
92
+ x = x.view(n, w, int(h / 2), int(c * 2))
93
+ x = x.permute(0, 2, 1, 3).contiguous()
94
+ x = x.view(n, int(h / 2), int(w / 2), int(c * 4))
95
+ x = x.permute(0, 2, 1, 3).contiguous()
96
+ return x
97
+
98
+
99
+ class DownSample3x3BlockFix(nn.Module):
100
+ def forward(self, x):
101
+ vit_embeds = x
102
+ h = w = int(vit_embeds.shape[1] ** 0.5)
103
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
104
+ vit_embeds = flat_square_3x3(vit_embeds)
105
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
106
+ return vit_embeds
107
+
108
+
109
+ def flat_square_3x3(x):
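+ # Merge each 3x3 neighborhood of patches into the channel dimension (padding H and W up to multiples of 3), reducing the token count 9x while growing channels 9x.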
110
+ n, w, h, c = x.size()
111
+ if w % 3 != 0:
112
+ x = torch.concat([x, torch.zeros((n, 3 - (w % 3), h, c), dtype=x.dtype).to(x.device)], dim=1).contiguous()
113
+ n, w, h, c = x.size()
114
+ x = x.contiguous()
115
+ if h % 3 != 0:
116
+ x = torch.concat([x, torch.zeros((n, w, 3 - (h % 3), c), dtype=x.dtype).to(x.device)], dim=2).contiguous()
117
+ n, w, h, c = x.size()
118
+ x = x.view(n, w, int(h / 3), int(c * 3))
119
+ x = x.permute(0, 2, 1, 3).contiguous()
120
+ x = x.view(n, int(h / 3), int(w / 3), int(c * 9))
121
+ x = x.permute(0, 2, 1, 3).contiguous()
122
+ return x
123
+
124
+
125
+ class MultimodalProjectorConfig(PretrainedConfig):
126
+ model_type = "v2l_projector"
127
+
128
+ def __init__(self, mm_projector_type: str = None, **kwargs):
129
+ super().__init__()
130
+ self.mm_projector_type = mm_projector_type
131
+
132
+
133
+ class MultimodalProjector(PreTrainedModel):
134
+ config_class = MultimodalProjectorConfig
135
+
136
+ def __init__(self, mm_projector_cfg: MultimodalProjectorConfig, config: PretrainedConfig):
137
+ super().__init__(mm_projector_cfg)
138
+ mm_projector_type = mm_projector_cfg.mm_projector_type
139
+ self.downsample_rate = 1
140
+ if mm_projector_type == "identity":
141
+ self.layers = IdentityMap()
142
+ elif mm_projector_type == "linear":
143
+ self.layers = nn.Linear(config.mm_hidden_size, config.hidden_size)
144
+ elif mm_projector_type == "mlp_downsample":
145
+ self.layers = nn.Sequential(
146
+ DownSampleBlock(),
147
+ nn.LayerNorm(config.mm_hidden_size * 4),
148
+ nn.Linear(config.mm_hidden_size * 4, config.hidden_size),
149
+ nn.GELU(),
150
+ nn.Linear(config.hidden_size, config.hidden_size),
151
+ )
152
+ self.downsample_rate = 2
153
+ elif mm_projector_type == "mlp_downsample_2x2_fix":
154
+ self.layers = nn.Sequential(
155
+ DownSample2x2BlockFix(),
156
+ nn.LayerNorm(config.mm_hidden_size * 4),
157
+ nn.Linear(config.mm_hidden_size * 4, config.hidden_size),
158
+ nn.GELU(),
159
+ nn.Linear(config.hidden_size, config.hidden_size),
160
+ )
161
+ self.downsample_rate = 2
162
+ elif mm_projector_type == "mlp_downsample_3x3_fix":
163
+ self.layers = nn.Sequential(
164
+ DownSample3x3BlockFix(),
165
+ nn.LayerNorm(config.mm_hidden_size * 9),
166
+ nn.Linear(config.mm_hidden_size * 9, config.mm_hidden_size * 3),
167
+ nn.GELU(),
168
+ nn.LayerNorm(config.mm_hidden_size * 3),
169
+ nn.Linear(config.mm_hidden_size * 3, config.hidden_size),
170
+ nn.GELU(),
171
+ nn.Linear(config.hidden_size, config.hidden_size),
172
+ )
173
+ self.downsample_rate = 3
174
+ elif mm_projector_type == "mlp_downsample_3x3_s2":
175
+ self.layers = nn.Sequential(
176
+ DownSample3x3BlockFix(),
177
+ nn.LayerNorm(config.mm_hidden_size * 9),
178
+ nn.Linear(config.mm_hidden_size * 9, config.mm_hidden_size * 3),
179
+ nn.GELU(),
180
+ nn.LayerNorm(config.mm_hidden_size * 3),
181
+ nn.Linear(config.mm_hidden_size * 3, config.mm_hidden_size),
182
+ nn.GELU(),
183
+ nn.LayerNorm(config.mm_hidden_size),
184
+ nn.Linear(config.mm_hidden_size, config.mm_hidden_size // 3),
185
+ nn.GELU(),
186
+ nn.LayerNorm(config.mm_hidden_size // 3),
187
+ nn.Linear(config.mm_hidden_size // 3, config.hidden_size),
188
+ nn.GELU(),
189
+ nn.Linear(config.hidden_size, config.hidden_size),
190
+ )
191
+ elif mm_projector_type == "mlp_downsample_3x3_s2_new":
192
+ self.layers = nn.Sequential(
193
+ DownSample3x3BlockFix(),
194
+ nn.LayerNorm(config.mm_hidden_size * 9),
195
+ nn.Linear(config.mm_hidden_size * 9, config.mm_hidden_size * 4),
196
+ nn.GELU(),
197
+ nn.LayerNorm(config.mm_hidden_size * 4),
198
+ nn.Linear(config.mm_hidden_size * 4, config.mm_hidden_size * 2),
199
+ nn.GELU(),
200
+ nn.LayerNorm(config.mm_hidden_size * 2),
201
+ nn.Linear(config.mm_hidden_size * 2, config.mm_hidden_size),
202
+ nn.GELU(),
203
+ nn.LayerNorm(config.mm_hidden_size),
204
+ nn.Linear(config.mm_hidden_size, config.mm_hidden_size // 3),
205
+ nn.GELU(),
206
+ nn.LayerNorm(config.mm_hidden_size // 3),
207
+ nn.Linear(config.mm_hidden_size // 3, config.hidden_size),
208
+ nn.GELU(),
209
+ nn.Linear(config.hidden_size, config.hidden_size),
210
+ )
211
+ else:
212
+ mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", mm_projector_type)
213
+ if mlp_gelu_match:
214
+ mlp_depth = int(mlp_gelu_match.group(1))
215
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
216
+ for _ in range(1, mlp_depth):
217
+ modules.append(nn.GELU())
218
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
219
+ self.layers = nn.Sequential(*modules)
220
+ else:
221
+ raise ValueError(f"Unknown projector type: {mm_projector_type}")
222
+
223
+ def forward(self, x, *args, **kwargs):
224
+ return self.layers(x)
225
+
226
+
227
+ # AutoConfig.register("v2l_projector", MultimodalProjectorConfig)
228
+ # AutoModel.register(MultimodalProjectorConfig, MultimodalProjector)
builder.py ADDED
@@ -0,0 +1,245 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import math
18
+ import os
19
+ import os.path as osp
20
+ import warnings
21
+ from dataclasses import asdict
22
+ from typing import Any, Dict, List, Optional, Sequence, Tuple
23
+
24
+ import torch
25
+ import transformers
26
+ from huggingface_hub import file_exists, repo_exists
27
+ from huggingface_hub.utils import HFValidationError
28
+ from transformers import (
29
+ AutoConfig,
30
+ AutoModelForCausalLM,
31
+ AutoTokenizer,
32
+ PretrainedConfig,
33
+ PreTrainedModel,
34
+ PreTrainedTokenizer,
35
+ )
36
+
37
+ # from .conversation import *
38
+ from .conversation import SeparatorStyle, default_conversation
39
+
40
+ SENTINEL_TOKEN = "<vila/sentinel>"
41
+ MEDIA_TOKENS = {
42
+ "image": "<image>",
43
+ "video": "<vila/video>",
44
+ }
45
+
46
+ # from llava.model.utils import packing
47
+ # from llava.utils.logging import logger
48
+ # from llava.utils.tokenizer import infer_stop_tokens
49
+
50
+ DUMMY_CONVERSATION = [
51
+ {"from": "human", "value": "question"},
52
+ {"from": "gpt", "value": "answer"},
53
+ ] * 10
54
+
55
+
56
+ def tokenizer_image_token(prompt, tokenizer, return_tensors=None):
57
+ return tokenizer(prompt, return_tensors=return_tensors).input_ids[0]
58
+
59
+
60
+ def has_tokenizer(repo_id_or_path: str) -> bool:
61
+ # Check if the tokenizer is in a local directory
62
+ if osp.exists(osp.join(repo_id_or_path, "tokenizer_config.json")):
63
+ return True
64
+
65
+ # Check if the tokenizer is in a Hugging Face Hub repo
66
+ try:
67
+ return repo_exists(repo_id_or_path) and file_exists(repo_id_or_path, "tokenizer_config.json")
68
+ except HFValidationError:
69
+ return False
70
+
71
+
72
+ def _maybe_add_sentinel_token(tokenizer: transformers.PreTrainedTokenizer) -> None:
73
+ if not hasattr(tokenizer, "sentinel_token"):
74
+ tokenizer.add_tokens([SENTINEL_TOKEN], special_tokens=True)
75
+ tokenizer.sentinel_token = SENTINEL_TOKEN
76
+ tokenizer.sentinel_token_id = tokenizer.convert_tokens_to_ids(SENTINEL_TOKEN)
77
+
78
+
79
+ def tokenize_conversation_legacy(
80
+ messages: Sequence[Dict[str, str]],
81
+ tokenizer: transformers.PreTrainedTokenizer,
82
+ add_generation_prompt: bool = False,
83
+ overrides: Optional[Dict[str, str]] = None,
84
+ no_system_prompt: bool = False,
85
+ ) -> torch.Tensor:
86
+ conv = default_conversation.copy()
87
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
88
+
89
+ if no_system_prompt:
90
+ conv.system = ""
91
+
92
+ # Skip the first message if it is not from human
93
+ if messages[0]["from"] != "human":
94
+ messages = messages[1:]
95
+
96
+ # Add a generation prompt if needed
97
+ if add_generation_prompt:
98
+ messages.append({"from": "gpt", "value": None})
99
+
100
+ conv.messages = []
101
+ for turn, message in enumerate(messages):
102
+ role = roles[message["from"]]
103
+ assert role == conv.roles[turn % 2]
104
+ if overrides is not None and message["from"] in overrides:
105
+ conv.append_message(role, overrides[message["from"]])
106
+ else:
107
+ conv.append_message(role, message["value"])
108
+
109
+ return tokenizer_image_token(conv.get_prompt(), tokenizer, return_tensors="pt")
110
+
111
+
112
+ def tokenize_conversation(
113
+ messages: Sequence[Dict[str, str]],
114
+ tokenizer: transformers.PreTrainedTokenizer,
115
+ add_generation_prompt: bool = False,
116
+ overrides: Optional[Dict[str, str]] = None,
117
+ no_system_prompt: bool = False,
118
+ ) -> torch.Tensor:
119
+ # Normalize the conversation before tokenization
120
+ for message in messages:
121
+ message["value"] = message["value"].strip()
122
+
123
+ if default_conversation.sep_style != SeparatorStyle.AUTO:
124
+ return tokenize_conversation_legacy(
125
+ messages,
126
+ tokenizer,
127
+ add_generation_prompt=add_generation_prompt,
128
+ overrides=overrides,
129
+ no_system_prompt=no_system_prompt,
130
+ )
131
+
132
+ conversation = []
133
+ for m in messages:
134
+ message = {}
135
+ if m["from"] == "human":
136
+ message["role"] = "user"
137
+ elif m["from"] == "gpt":
138
+ message["role"] = "assistant"
139
+ else:
140
+ raise ValueError(f"Unexpected sender '{m['from']}' in conversation entry.")
141
+
142
+ message["content"] = m["value"]
143
+ if overrides is not None and m["from"] in overrides:
144
+ message["content"] = overrides[m["from"]]
145
+ conversation.append(message)
146
+
147
+ if no_system_prompt:
148
+ conversation = [{"role": "system", "content": ""}] + conversation
149
+
150
+ text = tokenizer.apply_chat_template(
151
+ conversation,
152
+ add_generation_prompt=add_generation_prompt,
153
+ tokenize=False,
154
+ )
155
+ return tokenizer_image_token(text, tokenizer, return_tensors="pt")
156
+
157
+
158
+ def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]:
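+ # Build a template where every assistant reply is the sentinel token; the token that immediately follows each sentinel (the turn terminator), plus EOS, is collected as a stop token.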
159
+ _maybe_add_sentinel_token(tokenizer)
160
+ template = tokenize_conversation(DUMMY_CONVERSATION, tokenizer, overrides={"gpt": SENTINEL_TOKEN})
161
+
162
+ stop_tokens = {tokenizer.eos_token}
163
+ for k in range(template.size(0) - 1):
164
+ if template[k] == tokenizer.sentinel_token_id:
165
+ stop_token = tokenizer.decode(template[k + 1])
166
+ stop_tokens.add(stop_token)
167
+ return list(stop_tokens)
168
+
169
+
170
+ def context_length_extension(config):
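+ # If model_max_length exceeds the base max_position_embeddings, enable linear RoPE scaling by the rounded-up ratio.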
171
+ orig_ctx_len = getattr(config, "max_position_embeddings", None)
172
+ model_max_length = getattr(config, "model_max_length", None)
173
+ if orig_ctx_len and model_max_length > orig_ctx_len:
174
+ print(f"Scaling RoPE from {orig_ctx_len} to {model_max_length}")
175
+ scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
176
+ config.rope_scaling = {"type": "linear", "factor": scaling_factor}
177
+ return config
178
+
179
+
180
+ def build_llm_and_tokenizer(
181
+ model_name_or_path: str,
182
+ config: PretrainedConfig,
183
+ attn_implementation=None,
184
+ model_max_length=None,
185
+ *args,
186
+ **kwargs,
187
+ ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
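+ # Load the causal LM from the (possibly RoPE-scaled) llm_cfg, locate its tokenizer (the model dir or its ./llm subdirectory), attach the configured chat template if any, and register stop and media tokens on the tokenizer.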
188
+ # print(model_name_or_path)
189
+ llm_cfg = AutoConfig.from_pretrained(model_name_or_path)
190
+ llm_cfg._attn_implementation = attn_implementation
191
+ llm_cfg.model_max_length = model_max_length
192
+ if model_max_length is not None:
193
+ context_length_extension(llm_cfg)
194
+
195
+ # Quantization related
196
+ quantization_restore_from_checkpoint = False
197
+
198
+ if quantization_restore_from_checkpoint:
199
+ fp8_model_name_or_path = kwargs.pop("fp8_llm_cfg", None)
200
+
201
+ llm = AutoModelForCausalLM.from_pretrained(
202
+ fp8_model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
203
+ )
204
+ else:
205
+ llm = AutoModelForCausalLM.from_pretrained(
206
+ model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
207
+ )
208
+ # NOTE(ligeng): not sure whether it affects the training
209
+ # packing.patch(llm)
210
+
211
+ # Locate the tokenizer.
212
+ llm_path = model_name_or_path
213
+ if not has_tokenizer(llm_path):
214
+ llm_path = osp.join(llm_path, "llm")
215
+ if not has_tokenizer(llm_path):
216
+ raise ValueError(f"Cannot find tokenizer in {llm_path}.")
217
+
218
+ tokenizer = AutoTokenizer.from_pretrained(llm_path, padding_side="right", use_fast=True, legacy=False)
219
+ if model_max_length is not None:
220
+ tokenizer.model_max_length = model_max_length
221
+
222
+ # Load chat template if specified.
223
+ if getattr(config, "chat_template", None) is not None:
224
+ print(f"Using chat template: {config.chat_template}")
225
+ fpath = os.path.join(os.path.dirname(__file__), "chat_templates", f"{config.chat_template}.jinja")
226
+ if not os.path.exists(fpath):
227
+ fpath = os.path.join(os.path.dirname(model_name_or_path), f"{config.chat_template}.jinja")
228
+ with open(fpath) as fd:
229
+ chat_template = fd.read()
230
+ tokenizer.chat_template = chat_template.replace(" ", "").replace("\n", "")
231
+
232
+ # Set stop tokens for the tokenizer
233
+ tokenizer.stop_tokens = infer_stop_tokens(tokenizer)
234
+ tokenizer.stop_token_ids = tokenizer.convert_tokens_to_ids(tokenizer.stop_tokens)
235
+
236
+ # Add media tokens to the tokenizer
237
+ tokenizer.media_tokens = MEDIA_TOKENS
238
+ tokenizer.media_token_ids = {}
239
+ for name, token in MEDIA_TOKENS.items():
240
+ tokenizer.add_tokens([token], special_tokens=True)
241
+ tokenizer.media_token_ids[name] = tokenizer.convert_tokens_to_ids(token)
242
+
243
+ # TODO(ligeng): is this necessary for llava?
244
+ config.hidden_size = llm.config.hidden_size
245
+ return llm, tokenizer
config.json ADDED
@@ -0,0 +1,322 @@
1
+ {
2
+ "Ubit": 100,
3
+ "_name_or_path": "runs/train/NVILA-Lite_14b_siglip_aws_env2_obelics_ja/sft_14b_GPT4_v6/model",
4
+ "architectures": [
5
+ "VILAForCasualLM"
6
+ ],
7
+ "babit": "E5M2",
8
+ "bobit": "E5M2",
9
+ "bwbit": "E5M2",
10
+ "chat_template": null,
11
+ "col_blocksize": -1,
12
+ "col_blocksize_optimizer": 128,
13
+ "draw_distribution_backward": false,
14
+ "draw_distribution_forward": false,
15
+ "drop_path_rate": 0.0,
16
+ "dynamic_s2": false,
17
+ "epsilon": 1e-10,
18
+ "epsilon_optimizer": 1e-15,
19
+ "fabit": "E4M3",
20
+ "first_order_bit": null,
21
+ "first_order_quant_type": null,
22
+ "fobit": "E4M3",
23
+ "fps": 0.0,
24
+ "fwbit": "E4M3",
25
+ "group_size": -1,
26
+ "hidden_size": 5120,
27
+ "image_aspect_ratio": "dynamic",
28
+ "image_encoder": {
29
+ "_target_": "llava.model.encoders.BasicImageEncoder"
30
+ },
31
+ "interpolate_mode": "linear",
32
+ "llm_cfg": {
33
+ "_name_or_path": "runs/train/NVILA-Lite_14b_siglip_aws_env2_obelics_ja/sft_14b_GPT4_v6/model/llm",
34
+ "add_cross_attention": false,
35
+ "architectures": [
36
+ "Qwen2ForCausalLM"
37
+ ],
38
+ "attention_dropout": 0.0,
39
+ "bad_words_ids": null,
40
+ "begin_suppress_tokens": null,
41
+ "bos_token_id": 151643,
42
+ "chunk_size_feed_forward": 0,
43
+ "cross_attention_hidden_size": null,
44
+ "decoder_start_token_id": null,
45
+ "diversity_penalty": 0.0,
46
+ "do_sample": false,
47
+ "early_stopping": false,
48
+ "encoder_no_repeat_ngram_size": 0,
49
+ "eos_token_id": 151645,
50
+ "exponential_decay_length_penalty": null,
51
+ "finetuning_task": null,
52
+ "forced_bos_token_id": null,
53
+ "forced_eos_token_id": null,
54
+ "hidden_act": "silu",
55
+ "hidden_size": 5120,
56
+ "id2label": {
57
+ "0": "LABEL_0",
58
+ "1": "LABEL_1"
59
+ },
60
+ "initializer_range": 0.02,
61
+ "intermediate_size": 13824,
62
+ "is_decoder": false,
63
+ "is_encoder_decoder": false,
64
+ "label2id": {
65
+ "LABEL_0": 0,
66
+ "LABEL_1": 1
67
+ },
68
+ "length_penalty": 1.0,
69
+ "max_length": 20,
70
+ "max_position_embeddings": 32768,
71
+ "max_window_layers": 70,
72
+ "min_length": 0,
73
+ "model_max_length": 4096,
74
+ "model_type": "qwen2",
75
+ "no_repeat_ngram_size": 0,
76
+ "num_attention_heads": 40,
77
+ "num_beam_groups": 1,
78
+ "num_beams": 1,
79
+ "num_hidden_layers": 48,
80
+ "num_key_value_heads": 8,
81
+ "num_return_sequences": 1,
82
+ "output_attentions": false,
83
+ "output_hidden_states": false,
84
+ "output_scores": false,
85
+ "pad_token_id": null,
86
+ "prefix": null,
87
+ "problem_type": null,
88
+ "pruned_heads": {},
89
+ "remove_invalid_values": false,
90
+ "repetition_penalty": 1.0,
91
+ "return_dict": true,
92
+ "return_dict_in_generate": false,
93
+ "rms_norm_eps": 1e-06,
94
+ "rope_scaling": null,
95
+ "rope_theta": 1000000.0,
96
+ "sep_token_id": null,
97
+ "sliding_window": null,
98
+ "suppress_tokens": null,
99
+ "task_specific_params": null,
100
+ "temperature": 1.0,
101
+ "tf_legacy_loss": false,
102
+ "tie_encoder_decoder": false,
103
+ "tie_word_embeddings": false,
104
+ "tokenizer_class": null,
105
+ "top_k": 50,
106
+ "top_p": 1.0,
107
+ "torch_dtype": "bfloat16",
108
+ "torchscript": false,
109
+ "typical_p": 1.0,
110
+ "use_bfloat16": false,
111
+ "use_cache": true,
112
+ "use_sliding_window": false,
113
+ "vocab_size": 151670
114
+ },
115
+ "max_tiles": 12,
116
+ "min_blockunit_col": 4,
117
+ "min_blockunit_row": 4,
118
+ "min_tiles": 1,
119
+ "mlp_path": null,
120
+ "mm_hidden_size": 1152,
121
+ "mm_projector": "mlp_downsample_3x3_fix",
122
+ "mm_projector_cfg": {
123
+ "_name_or_path": "runs/train/NVILA-Lite_14b_siglip_aws_env2_obelics_ja/sft_14b_GPT4_v6/model/mm_projector",
124
+ "add_cross_attention": false,
125
+ "architectures": [
126
+ "MultimodalProjector"
127
+ ],
128
+ "bad_words_ids": null,
129
+ "begin_suppress_tokens": null,
130
+ "bos_token_id": null,
131
+ "chunk_size_feed_forward": 0,
132
+ "cross_attention_hidden_size": null,
133
+ "decoder_start_token_id": null,
134
+ "diversity_penalty": 0.0,
135
+ "do_sample": false,
136
+ "early_stopping": false,
137
+ "encoder_no_repeat_ngram_size": 0,
138
+ "eos_token_id": null,
139
+ "exponential_decay_length_penalty": null,
140
+ "finetuning_task": null,
141
+ "forced_bos_token_id": null,
142
+ "forced_eos_token_id": null,
143
+ "id2label": {
144
+ "0": "LABEL_0",
145
+ "1": "LABEL_1"
146
+ },
147
+ "is_decoder": false,
148
+ "is_encoder_decoder": false,
149
+ "label2id": {
150
+ "LABEL_0": 0,
151
+ "LABEL_1": 1
152
+ },
153
+ "length_penalty": 1.0,
154
+ "max_length": 20,
155
+ "min_length": 0,
156
+ "mm_projector_type": "mlp_downsample_3x3_fix",
157
+ "model_type": "v2l_projector",
158
+ "no_repeat_ngram_size": 0,
159
+ "num_beam_groups": 1,
160
+ "num_beams": 1,
161
+ "num_return_sequences": 1,
162
+ "output_attentions": false,
163
+ "output_hidden_states": false,
164
+ "output_scores": false,
165
+ "pad_token_id": null,
166
+ "prefix": null,
167
+ "problem_type": null,
168
+ "pruned_heads": {},
169
+ "remove_invalid_values": false,
170
+ "repetition_penalty": 1.0,
171
+ "return_dict": true,
172
+ "return_dict_in_generate": false,
173
+ "sep_token_id": null,
174
+ "suppress_tokens": null,
175
+ "task_specific_params": null,
176
+ "temperature": 1.0,
177
+ "tf_legacy_loss": false,
178
+ "tie_encoder_decoder": false,
179
+ "tie_word_embeddings": true,
180
+ "tokenizer_class": null,
181
+ "top_k": 50,
182
+ "top_p": 1.0,
183
+ "torch_dtype": "bfloat16",
184
+ "torchscript": false,
185
+ "typical_p": 1.0,
186
+ "use_bfloat16": false
187
+ },
188
+ "mm_projector_lr": null,
189
+ "mm_use_im_patch_token": false,
190
+ "mm_use_im_start_end": false,
191
+ "mm_vision_select_feature": "cls_patch",
192
+ "mm_vision_select_layer": -2,
193
+ "model_dtype": "torch.bfloat16",
194
+ "model_name_or_path": "runs/train/NVILA-Lite_14b_siglip_aws_env2_obelics_ja/pretrain_14b/model",
195
+ "model_type": "vila",
196
+ "num_time_tokens": 0,
197
+ "num_video_frames": 8,
198
+ "pad_block": false,
199
+ "pad_to_multiple_of": 0,
200
+ "qchoice": "none",
201
+ "quantize_model": false,
202
+ "refine_attn_blocksize": false,
203
+ "refine_col_blocksize": 4,
204
+ "refine_ln_blocksize": false,
205
+ "refine_ln_blocksize_but_only_backward": false,
206
+ "refine_ln_blocksize_but_only_forward": false,
207
+ "refine_ln_pertoken": false,
208
+ "refine_mlp_blocksize": false,
209
+ "refine_residual_fp": false,
210
+ "refine_row_blocksize": 4,
211
+ "resume_path": "runs/train/NVILA-Lite_14b_siglip_aws_env2_obelics_ja/sft_14b_GPT4_v6/model",
212
+ "row_blocksize": -1,
213
+ "row_blocksize_optimizer": 1,
214
+ "s2": false,
215
+ "s2_max_split_size": 336,
216
+ "s2_resize_output_to_scale_idx": 0,
217
+ "s2_scales": "336,672,1008",
218
+ "second_order_bit": null,
219
+ "second_order_quant_type": null,
220
+ "soft_ce_std": 1.0,
221
+ "symm": true,
222
+ "time_token_format": "<t{t}>",
223
+ "time_token_ids": [],
224
+ "transformers_version": "4.45.0",
225
+ "tune_language_model": true,
226
+ "tune_mm_projector": true,
227
+ "tune_vision_tower": true,
228
+ "use_quantize_optimizer": false,
229
+ "version": "2.0",
230
+ "video_encoder": {
231
+ "_target_": "llava.model.encoders.BasicVideoEncoder"
232
+ },
233
+ "vision_resolution": -1,
234
+ "vision_tower": "/data/models/Efficient-Large-Model/paligemma-siglip-so400m-patch14-448",
235
+ "vision_tower_cfg": {
236
+ "_name_or_path": "runs/train/NVILA-Lite_14b_siglip_aws_env2_obelics_ja/sft_14b_GPT4_v6/model/vision_tower",
237
+ "add_cross_attention": false,
238
+ "architectures": [
239
+ "SiglipVisionModel"
240
+ ],
241
+ "attention_dropout": 0.0,
242
+ "bad_words_ids": null,
243
+ "begin_suppress_tokens": null,
244
+ "bos_token_id": null,
245
+ "chunk_size_feed_forward": 0,
246
+ "cross_attention_hidden_size": null,
247
+ "decoder_start_token_id": null,
248
+ "diversity_penalty": 0.0,
249
+ "do_sample": false,
250
+ "early_stopping": false,
251
+ "encoder_no_repeat_ngram_size": 0,
252
+ "eos_token_id": null,
253
+ "exponential_decay_length_penalty": null,
254
+ "finetuning_task": null,
255
+ "forced_bos_token_id": null,
256
+ "forced_eos_token_id": null,
257
+ "hidden_act": "gelu_pytorch_tanh",
258
+ "hidden_size": 1152,
259
+ "id2label": {
260
+ "0": "LABEL_0",
261
+ "1": "LABEL_1"
262
+ },
263
+ "image_size": 448,
264
+ "intermediate_size": 4304,
265
+ "is_decoder": false,
266
+ "is_encoder_decoder": false,
267
+ "label2id": {
268
+ "LABEL_0": 0,
269
+ "LABEL_1": 1
270
+ },
271
+ "layer_norm_eps": 1e-06,
272
+ "length_penalty": 1.0,
273
+ "max_length": 20,
274
+ "min_length": 0,
275
+ "model_type": "siglip_vision_model",
276
+ "no_repeat_ngram_size": 0,
277
+ "num_attention_heads": 16,
278
+ "num_beam_groups": 1,
279
+ "num_beams": 1,
280
+ "num_channels": 3,
281
+ "num_hidden_layers": 27,
282
+ "num_image_tokens": 256,
283
+ "num_return_sequences": 1,
284
+ "output_attentions": false,
285
+ "output_hidden_states": false,
286
+ "output_scores": false,
287
+ "pad_token_id": null,
288
+ "patch_size": 14,
289
+ "prefix": null,
290
+ "problem_type": null,
291
+ "projection_dim": 2048,
292
+ "projector_hidden_act": "gelu_fast",
293
+ "pruned_heads": {},
294
+ "remove_invalid_values": false,
295
+ "repetition_penalty": 1.0,
296
+ "return_dict": true,
297
+ "return_dict_in_generate": false,
298
+ "sep_token_id": null,
299
+ "suppress_tokens": null,
300
+ "task_specific_params": null,
301
+ "temperature": 1.0,
302
+ "tf_legacy_loss": false,
303
+ "tie_encoder_decoder": false,
304
+ "tie_word_embeddings": true,
305
+ "tokenizer_class": null,
306
+ "top_k": 50,
307
+ "top_p": 1.0,
308
+ "torch_dtype": "bfloat16",
309
+ "torchscript": false,
310
+ "typical_p": 1.0,
311
+ "use_bfloat16": false,
312
+ "vision_use_head": false
313
+ },
314
+ "vision_tower_lr": null,
315
+ "weight_memory_efficient": true,
316
+ "auto_map": {
317
+ "AutoProcessor": "auto_processor.VILAProcessor",
318
+ "AutoConfig": "modeling_vila.VILAConfig",
319
+ "AutoModel": "modeling_vila.VILAForCasualLM",
320
+ "AutoModelForCausalLM": "modeling_vila.VILAForCasualLM"
321
+ }
322
+ }
configuration_vila.py ADDED
@@ -0,0 +1,93 @@
1
+ import json
2
+ import math
3
+ import os
4
+ import os.path as osp
5
+ from copy import deepcopy
6
+ from threading import Thread
7
+ from typing import List, Optional
8
+
9
+ import torch
10
+ import torchvision
11
+ from PIL import Image
12
+ from transformers import (
13
+ AutoProcessor,
14
+ PretrainedConfig,
15
+ PreTrainedModel,
16
+ Qwen2Config,
17
+ Qwen2ForCausalLM,
18
+ Qwen2PreTrainedModel,
19
+ TextIteratorStreamer,
20
+ )
21
+
22
+
23
+ class VILAConfig(PretrainedConfig):
24
+ model_type = "vila"
25
+ keys_to_ignore_at_inference = ["past_key_values"]
26
+
27
+ def __init__(
28
+ self,
29
+ llm_cfg=None,
30
+ vision_tower_cfg=None,
31
+ mm_projector_cfg=None,
32
+ architectures=None,
33
+ resume_path=None,
34
+ hidden_size=None,
35
+ mm_hidden_size=None,
36
+ image_aspect_ratio=None,
37
+ num_video_frames=None,
38
+ fps=None,
39
+ mm_vision_select_layer=None,
40
+ mm_vision_select_feature=None,
41
+ mm_use_im_start_end=False,
42
+ mm_use_im_patch_token=False,
43
+ mm_projector_lr=None,
44
+ vision_tower_lr=None,
45
+ vision_resolution=None,
46
+ interpolate_mode=None,
47
+ s2=None,
48
+ dynamic_s2=None,
49
+ s2_scales=None,
50
+ s2_max_split_size=None,
51
+ s2_resize_output_to_scale_idx=0,
52
+ min_tiles: Optional[int] = 1,
53
+ max_tiles: Optional[int] = 12,
54
+ num_time_tokens=None,
55
+ time_token_format=None,
56
+ image_encoder: str = '{"_target_": "llava.model.encoders.BasicImageEncoder"}',
57
+ video_encoder: str = '{"_target_": "llava.model.encoders.BasicVideoEncoder"}',
58
+ **kwargs,
59
+ ):
60
+ super().__init__()
61
+ self.architectures = architectures
62
+ self.llm_cfg = llm_cfg
63
+ self.vision_tower_cfg = vision_tower_cfg
64
+ self.mm_projector_cfg = mm_projector_cfg
65
+ self.resume_path = resume_path
66
+
67
+ self.hidden_size = hidden_size
68
+ self.mm_hidden_size = mm_hidden_size
69
+ self.image_aspect_ratio = image_aspect_ratio
70
+ self.num_video_frames = num_video_frames
71
+ self.fps = fps
72
+ self.mm_vision_select_layer = mm_vision_select_layer
73
+ self.mm_vision_select_feature = mm_vision_select_feature
74
+ self.mm_use_im_start_end = mm_use_im_start_end
75
+ self.mm_use_im_patch_token = mm_use_im_patch_token
76
+ self.mm_projector_lr = mm_projector_lr
77
+ self.vision_tower_lr = vision_tower_lr
78
+ self.vision_resolution = vision_resolution
79
+ self.interpolate_mode = interpolate_mode
80
+ self.s2 = s2
81
+ self.dynamic_s2 = dynamic_s2
82
+ self.s2_scales = s2_scales
83
+ self.s2_max_split_size = s2_max_split_size
84
+ self.s2_resize_output_to_scale_idx = s2_resize_output_to_scale_idx
85
+ self.min_tiles = min_tiles
86
+ self.max_tiles = max_tiles
87
+ self.num_time_tokens = num_time_tokens
88
+ self.time_token_format = time_token_format
89
+
90
+ self.image_encoder = image_encoder
91
+ self.video_encoder = video_encoder
92
+
93
+ super().__init__(**kwargs)
constants.py ADDED
@@ -0,0 +1,43 @@
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ # SPDX-License-Identifier: Apache-2.0
+
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
+ WORKER_HEART_BEAT_INTERVAL = 15
+
+ LOGDIR = "."
+
+ # Model Constants
+ IGNORE_INDEX = -100
+ DEFAULT_IMAGE_TOKEN = "<image>"
+
+ SENTINEL_TOKEN = "<vila/sentinel>"
+ MEDIA_TOKENS = {
+     "image": "<image>",
+     "video": "<vila/video>",
+ }
+ # <image> <vila/video> <vila/sentinel>
+ # TODO(ligeng): need to discuss with Zhijian for the following tokens for different models.
+ """
+ 151643: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
+ 151644: AddedToken("<|im_start|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
+ 151645: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
+ 151646: AddedToken("[BOS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
+ 151647: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
+ 151648: AddedToken("<vila/sentinel>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
+ 151649: AddedToken("<image>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
+ 151650: AddedToken("<vila/video>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
+ """
+ NUM_EXTRA_TOKENS = 8
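A minimal sketch of how the media tokens above are typically used to splice images into a text prompt; the splitting logic is illustrative only, not the repository's actual preprocessing code:

```python
# Mirrors the constants defined in constants.py above so the snippet is self-contained.
DEFAULT_IMAGE_TOKEN = "<image>"
MEDIA_TOKENS = {"image": "<image>", "video": "<vila/video>"}

prompt = DEFAULT_IMAGE_TOKEN + "\nDescribe the picture."
# Each occurrence of the image token marks where one image's features get inserted.
text_chunks = prompt.split(MEDIA_TOKENS["image"])
print(len(text_chunks) - 1)  # 1 image placeholder in this prompt
```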
conversation.py ADDED
@@ -0,0 +1,191 @@
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ # SPDX-License-Identifier: Apache-2.0
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
+
+ import dataclasses
+ from enum import Enum, auto
+ from typing import List
+
+ # from llava.utils.logging import logger
+
+
+ class SeparatorStyle(Enum):
+     """Different separator style."""
+
+     AUTO = auto()
+     TWO = auto()
+     MPT = auto()
+     PLAIN = auto()
+     LLAMA_3 = auto()
+
+
+ @dataclasses.dataclass
+ class Conversation:
+     """A class that keeps all conversation history."""
+
+     system: str
+     roles: List[str]
+     messages: List[List[str]]
+     sep_style: SeparatorStyle = SeparatorStyle.AUTO
+     sep: str = "###"
+     sep2: str = None
+     version: str = "Unknown"
+
+     def get_prompt(self):
+         messages = self.messages
+         if len(messages) > 0 and type(messages[0][1]) is tuple:
+             messages = self.messages.copy()
+             init_role, init_msg = messages[0].copy()
+             init_msg = init_msg[0].replace("<image>", "").strip()
+             messages[0] = (init_role, "<image>\n" + init_msg)
+
+         if self.sep_style == SeparatorStyle.TWO:
+             seps = [self.sep, self.sep2]
+             ret = self.system + seps[0]
+             for i, (role, message) in enumerate(messages):
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += role + ": " + message + seps[i % 2]
+                 else:
+                     ret += role + ":"
+         elif self.sep_style == SeparatorStyle.LLAMA_3:
+             ret = self.system + self.sep
+             for rid, (role, message) in enumerate(messages):
+                 if message:
+                     if type(message) is tuple:
+                         message = message[0]
+                     sep = self.sep if rid < len(messages) - 1 else self.sep2
+                     ret += role + message + sep
+                 else:
+                     ret += role
+         elif self.sep_style == SeparatorStyle.MPT:
+             ret = self.system + self.sep
+             for role, message in messages:
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += role + message + self.sep
+                 else:
+                     ret += role
+         elif self.sep_style == SeparatorStyle.PLAIN:
+             seps = [self.sep, self.sep2]
+             ret = self.system
+             for i, (role, message) in enumerate(messages):
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += message + seps[i % 2]
+                 else:
+                     ret += ""
+         else:
+             raise ValueError(f"Invalid style: {self.sep_style}")
+
+         return ret
+
+     def append_message(self, role, message):
+         self.messages.append([role, message])
+
+     def copy(self):
+         return Conversation(
+             system=self.system,
+             roles=self.roles,
+             messages=[[x, y] for x, y in self.messages],
+             sep_style=self.sep_style,
+             sep=self.sep,
+             sep2=self.sep2,
+             version=self.version,
+         )
+
+
+ conv_auto = Conversation(
+     system="",
+     roles=("", ""),
+     messages=(),
+     sep_style=SeparatorStyle.AUTO,
+     sep="\n",
+ )
+
+ conv_vicuna_v1 = Conversation(
+     system="A chat between a curious user and an artificial intelligence assistant. "
+     "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+     roles=("USER", "ASSISTANT"),
+     version="v1",
+     messages=(),
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="</s>",
+ )
+
+ conv_llava_plain = Conversation(
+     system="",
+     roles=("", ""),
+     messages=(),
+     sep_style=SeparatorStyle.PLAIN,
+     sep="\n",
+ )
+
+ hermes_2 = Conversation(
+     system="<|im_start|>system\nAnswer the questions.",
+     roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+     sep_style=SeparatorStyle.MPT,
+     sep="<|im_end|>",
+     messages=(),
+     version="hermes-2",
+ )
+
+ # Template added by Yukang. Note (kentang-mit@): sep is <|eot_id|> for official template.
+ llama_3_chat = Conversation(
+     system="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. "
+     "You are able to understand the visual content that the user provides, "
+     "and assist the user with a variety of tasks using natural language.",
+     roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
+     version="llama_v3",
+     messages=(),
+     sep_style=SeparatorStyle.LLAMA_3,
+     sep="<|eot_id|>",
+     sep2="<|end_of_text|>",
+ )
+
+
+ default_conversation = conv_auto
+ conv_templates = {
+     "auto": conv_auto,
+     "hermes-2": hermes_2,
+     "llama_3": llama_3_chat,
+     "v1": conv_vicuna_v1,
+     "vicuna_v1": conv_vicuna_v1,
+     "plain": conv_llava_plain,
+ }
+
+
+ CONVERSATION_MODE_MAPPING = {
+     "vila1.5-3b": "vicuna_v1",
+     "vila1.5-8b": "llama_3",
+     "vila1.5-13b": "vicuna_v1",
+     "vila1.5-40b": "hermes-2",
+     "llama-3": "llama_3",
+     "llama3": "llama_3",
+ }
+
+
+ def auto_set_conversation_mode(model_name_or_path: str) -> str:
+     global default_conversation
+     for k, v in CONVERSATION_MODE_MAPPING.items():
+         if k in model_name_or_path.lower():
+             print(f"Setting conversation mode to `{v}` based on model name/path `{model_name_or_path}`.")
+             default_conversation = conv_templates[v]
+             return
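A minimal usage sketch of the templates defined above, assuming `conversation.py` is importable from the working directory (the import path is an assumption, not part of the diff):

```python
from conversation import conv_templates  # assumes conversation.py is on the Python path

# Build a hermes-2 style prompt; passing None for the assistant turn leaves the
# prompt open so generation continues from the assistant role tag.
conv = conv_templates["hermes-2"].copy()
conv.append_message(conv.roles[0], "<image>\nDescribe this image.")
conv.append_message(conv.roles[1], None)
print(conv.get_prompt())
# <|im_start|>system
# Answer the questions.<|im_end|><|im_start|>user
# <image>
# Describe this image.<|im_end|><|im_start|>assistant
```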
distributed.py ADDED
@@ -0,0 +1,73 @@
+ import os
+ import warnings
+ from typing import Any, List, Optional
+
+ from torch import distributed as dist
+
+ __all__ = [
+     "init",
+     "is_initialized",
+     "size",
+     "rank",
+     "local_size",
+     "local_rank",
+     "is_main",
+     "barrier",
+     "gather",
+     "all_gather",
+ ]
+
+
+ def init() -> None:
+     if "RANK" not in os.environ:
+         warnings.warn("Environment variable `RANK` is not set. Skipping distributed initialization.")
+         return
+     dist.init_process_group(backend="nccl", init_method="env://")
+
+
+ def is_initialized() -> bool:
+     return dist.is_initialized()
+
+
+ def size() -> int:
+     return int(os.environ.get("WORLD_SIZE", 1))
+
+
+ def rank() -> int:
+     return int(os.environ.get("RANK", 0))
+
+
+ def local_size() -> int:
+     return int(os.environ.get("LOCAL_WORLD_SIZE", 1))
+
+
+ def local_rank() -> int:
+     return int(os.environ.get("LOCAL_RANK", 0))
+
+
+ def is_main() -> bool:
+     return rank() == 0
+
+
+ def barrier() -> None:
+     dist.barrier()
+
+
+ def gather(obj: Any, dst: int = 0) -> Optional[List[Any]]:
+     if not is_initialized():
+         return [obj]
+     if is_main():
+         objs = [None for _ in range(size())]
+         dist.gather_object(obj, objs, dst=dst)
+         return objs
+     else:
+         dist.gather_object(obj, dst=dst)
+         return None
+
+
+ def all_gather(obj: Any) -> List[Any]:
+     if not is_initialized():
+         return [obj]
+     objs = [None for _ in range(size())]
+     dist.all_gather_object(objs, obj)
+     return objs
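The helpers above fall back gracefully when no process group has been initialized, so the same code path works in a single-process run. A minimal sketch, assuming the file is importable as a module named `distributed`:

```python
import distributed as dist_utils  # assumes distributed.py is on the Python path

dist_utils.init()                  # warns and returns early when RANK is unset
print(dist_utils.size())           # 1 when WORLD_SIZE is unset
print(dist_utils.is_main())        # True, since rank defaults to 0
# Without an initialized process group, all_gather simply wraps the local object.
print(dist_utils.all_gather({"rank": dist_utils.rank()}))  # [{'rank': 0}]
```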
llm/added_tokens.json ADDED
@@ -0,0 +1,29 @@
+ {
+   "</tool_call>": 151658,
+   "<image>": 151666,
+   "<tool_call>": 151657,
+   "<vila/sentinel>": 151665,
+   "<vila/video>": 151667,
+   "<|box_end|>": 151649,
+   "<|box_start|>": 151648,
+   "<|endoftext|>": 151643,
+   "<|file_sep|>": 151664,
+   "<|fim_middle|>": 151660,
+   "<|fim_pad|>": 151662,
+   "<|fim_prefix|>": 151659,
+   "<|fim_suffix|>": 151661,
+   "<|im_end|>": 151645,
+   "<|im_start|>": 151644,
+   "<|image_pad|>": 151655,
+   "<|object_ref_end|>": 151647,
+   "<|object_ref_start|>": 151646,
+   "<|quad_end|>": 151651,
+   "<|quad_start|>": 151650,
+   "<|repo_name|>": 151663,
+   "<|video_pad|>": 151656,
+   "<|vision_end|>": 151653,
+   "<|vision_pad|>": 151654,
+   "<|vision_start|>": 151652,
+   "[BOS]": 151668,
+   "[PAD]": 151669
+ }
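The VILA-specific additions above (`<vila/sentinel>` 151665, `<image>` 151666, `<vila/video>` 151667, `[BOS]` 151668, `[PAD]` 151669) sit on top of Qwen2.5's stock special tokens. A quick consistency check (a sketch, assuming the tokenizer in the `llm/` subfolder loads directly with `AutoTokenizer`):

```python
from transformers import AutoTokenizer

# Load the tokenizer shipped in the llm/ subfolder of this repository.
tok = AutoTokenizer.from_pretrained("turing-motors/Heron-NVILA-Lite-15B", subfolder="llm")
print(tok.convert_tokens_to_ids("<image>"))          # expected: 151666
print(tok.convert_tokens_to_ids("<vila/sentinel>"))  # expected: 151665
```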
llm/config.json ADDED
@@ -0,0 +1,30 @@
+ {
+   "_name_or_path": "runs/train/NVILA-Lite_14b_siglip_aws_env2_obelics_ja/sft_14b_GPT4_v6/model/llm",
+   "architectures": [
+     "Qwen2ForCausalLM"
+   ],
+   "attention_dropout": 0.0,
+   "bos_token_id": 151643,
+   "eos_token_id": 151645,
+   "hidden_act": "silu",
+   "hidden_size": 5120,
+   "initializer_range": 0.02,
+   "intermediate_size": 13824,
+   "max_position_embeddings": 32768,
+   "max_window_layers": 70,
+   "model_max_length": 4096,
+   "model_type": "qwen2",
+   "num_attention_heads": 40,
+   "num_hidden_layers": 48,
+   "num_key_value_heads": 8,
+   "rms_norm_eps": 1e-06,
+   "rope_scaling": null,
+   "rope_theta": 1000000.0,
+   "sliding_window": null,
+   "tie_word_embeddings": false,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.45.0",
+   "use_cache": true,
+   "use_sliding_window": false,
+   "vocab_size": 151670
+ }
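The attention geometry follows from the values above: 40 query heads over a hidden size of 5120 give a head dimension of 128, and 8 key/value heads mean each KV head is shared by 5 query heads (grouped-query attention). A small arithmetic check on the values copied from the config:

```python
# Values copied from llm/config.json above.
hidden_size = 5120
num_attention_heads = 40
num_key_value_heads = 8

head_dim = hidden_size // num_attention_heads                      # 128
queries_per_kv_head = num_attention_heads // num_key_value_heads   # 5
print(head_dim, queries_per_kv_head)
```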
llm/generation_config.json ADDED
@@ -0,0 +1,14 @@
+ {
+   "bos_token_id": 151643,
+   "do_sample": true,
+   "eos_token_id": [
+     151645,
+     151643
+   ],
+   "pad_token_id": 151643,
+   "repetition_penalty": 1.05,
+   "temperature": 0.7,
+   "top_k": 20,
+   "top_p": 0.8,
+   "transformers_version": "4.45.0"
+ }
llm/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
llm/model-00001-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ee395da79dd9faebfc541bbdfe73fb1132db12e346e314f9868b409a2c217e26
+ size 4982176720
llm/model-00002-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a80b8643a787f40bb0908614f36194ebfb1d8415718f67372a39844781a45630
+ size 4954847344
llm/model-00003-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:70217c6ce2d4f73e162b81efe72b3498b1a315ff1517d6e97913194752ffaf68
+ size 4954847392
llm/model-00004-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2ecab4ac8e07a9e0eb391f72cdc7b6cc895b5fbaf0809917d04669d3d9f8e2e7
+ size 4954847392
llm/model-00005-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a861359cf100d5879e2c27d4f840444645cb2efe616f449d1ccc43671a5a3365
+ size 4954847392
llm/model-00006-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4cae8fcc5d91e703f05bb4eb6498e49838badc865fc41bba9a5b1650742c049d
+ size 4730498600
llm/model.safetensors.index.json ADDED
@@ -0,0 +1,586 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 29531998208
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "model-00006-of-00006.safetensors",
7
+ "model.embed_tokens.weight": "model-00001-of-00006.safetensors",
8
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00006.safetensors",
9
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00006.safetensors",
10
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00006.safetensors",
11
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00006.safetensors",
12
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00006.safetensors",
13
+ "model.layers.0.self_attn.k_proj.bias": "model-00001-of-00006.safetensors",
14
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00006.safetensors",
15
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00006.safetensors",
16
+ "model.layers.0.self_attn.q_proj.bias": "model-00001-of-00006.safetensors",
17
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00006.safetensors",
18
+ "model.layers.0.self_attn.v_proj.bias": "model-00001-of-00006.safetensors",
19
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00006.safetensors",
20
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00006.safetensors",
21
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00006.safetensors",
22
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00006.safetensors",
23
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00006.safetensors",
24
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00006.safetensors",
25
+ "model.layers.1.self_attn.k_proj.bias": "model-00001-of-00006.safetensors",
26
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00006.safetensors",
27
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00006.safetensors",
28
+ "model.layers.1.self_attn.q_proj.bias": "model-00001-of-00006.safetensors",
29
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00006.safetensors",
30
+ "model.layers.1.self_attn.v_proj.bias": "model-00001-of-00006.safetensors",
31
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00006.safetensors",
32
+ "model.layers.10.input_layernorm.weight": "model-00002-of-00006.safetensors",
33
+ "model.layers.10.mlp.down_proj.weight": "model-00002-of-00006.safetensors",
34
+ "model.layers.10.mlp.gate_proj.weight": "model-00002-of-00006.safetensors",
35
+ "model.layers.10.mlp.up_proj.weight": "model-00002-of-00006.safetensors",
36
+ "model.layers.10.post_attention_layernorm.weight": "model-00002-of-00006.safetensors",
37
+ "model.layers.10.self_attn.k_proj.bias": "model-00002-of-00006.safetensors",
38
+ "model.layers.10.self_attn.k_proj.weight": "model-00002-of-00006.safetensors",
39
+ "model.layers.10.self_attn.o_proj.weight": "model-00002-of-00006.safetensors",
40
+ "model.layers.10.self_attn.q_proj.bias": "model-00002-of-00006.safetensors",
41
+ "model.layers.10.self_attn.q_proj.weight": "model-00002-of-00006.safetensors",
42
+ "model.layers.10.self_attn.v_proj.bias": "model-00002-of-00006.safetensors",
43
+ "model.layers.10.self_attn.v_proj.weight": "model-00002-of-00006.safetensors",
44
+ "model.layers.11.input_layernorm.weight": "model-00002-of-00006.safetensors",
45
+ "model.layers.11.mlp.down_proj.weight": "model-00002-of-00006.safetensors",
46
+ "model.layers.11.mlp.gate_proj.weight": "model-00002-of-00006.safetensors",
47
+ "model.layers.11.mlp.up_proj.weight": "model-00002-of-00006.safetensors",
48
+ "model.layers.11.post_attention_layernorm.weight": "model-00002-of-00006.safetensors",
49
+ "model.layers.11.self_attn.k_proj.bias": "model-00002-of-00006.safetensors",
50
+ "model.layers.11.self_attn.k_proj.weight": "model-00002-of-00006.safetensors",
51
+ "model.layers.11.self_attn.o_proj.weight": "model-00002-of-00006.safetensors",
52
+ "model.layers.11.self_attn.q_proj.bias": "model-00002-of-00006.safetensors",
53
+ "model.layers.11.self_attn.q_proj.weight": "model-00002-of-00006.safetensors",
54
+ "model.layers.11.self_attn.v_proj.bias": "model-00002-of-00006.safetensors",
55
+ "model.layers.11.self_attn.v_proj.weight": "model-00002-of-00006.safetensors",
56
+ "model.layers.12.input_layernorm.weight": "model-00002-of-00006.safetensors",
57
+ "model.layers.12.mlp.down_proj.weight": "model-00002-of-00006.safetensors",
58
+ "model.layers.12.mlp.gate_proj.weight": "model-00002-of-00006.safetensors",
59
+ "model.layers.12.mlp.up_proj.weight": "model-00002-of-00006.safetensors",
60
+ "model.layers.12.post_attention_layernorm.weight": "model-00002-of-00006.safetensors",
61
+ "model.layers.12.self_attn.k_proj.bias": "model-00002-of-00006.safetensors",
62
+ "model.layers.12.self_attn.k_proj.weight": "model-00002-of-00006.safetensors",
63
+ "model.layers.12.self_attn.o_proj.weight": "model-00002-of-00006.safetensors",
64
+ "model.layers.12.self_attn.q_proj.bias": "model-00002-of-00006.safetensors",
65
+ "model.layers.12.self_attn.q_proj.weight": "model-00002-of-00006.safetensors",
66
+ "model.layers.12.self_attn.v_proj.bias": "model-00002-of-00006.safetensors",
67
+ "model.layers.12.self_attn.v_proj.weight": "model-00002-of-00006.safetensors",
68
+ "model.layers.13.input_layernorm.weight": "model-00002-of-00006.safetensors",
69
+ "model.layers.13.mlp.down_proj.weight": "model-00002-of-00006.safetensors",
70
+ "model.layers.13.mlp.gate_proj.weight": "model-00002-of-00006.safetensors",
71
+ "model.layers.13.mlp.up_proj.weight": "model-00002-of-00006.safetensors",
72
+ "model.layers.13.post_attention_layernorm.weight": "model-00002-of-00006.safetensors",
73
+ "model.layers.13.self_attn.k_proj.bias": "model-00002-of-00006.safetensors",
74
+ "model.layers.13.self_attn.k_proj.weight": "model-00002-of-00006.safetensors",
75
+ "model.layers.13.self_attn.o_proj.weight": "model-00002-of-00006.safetensors",
76
+ "model.layers.13.self_attn.q_proj.bias": "model-00002-of-00006.safetensors",
77
+ "model.layers.13.self_attn.q_proj.weight": "model-00002-of-00006.safetensors",
78
+ "model.layers.13.self_attn.v_proj.bias": "model-00002-of-00006.safetensors",
79
+ "model.layers.13.self_attn.v_proj.weight": "model-00002-of-00006.safetensors",
80
+ "model.layers.14.input_layernorm.weight": "model-00002-of-00006.safetensors",
81
+ "model.layers.14.mlp.down_proj.weight": "model-00002-of-00006.safetensors",
82
+ "model.layers.14.mlp.gate_proj.weight": "model-00002-of-00006.safetensors",
83
+ "model.layers.14.mlp.up_proj.weight": "model-00002-of-00006.safetensors",
84
+ "model.layers.14.post_attention_layernorm.weight": "model-00002-of-00006.safetensors",
85
+ "model.layers.14.self_attn.k_proj.bias": "model-00002-of-00006.safetensors",
86
+ "model.layers.14.self_attn.k_proj.weight": "model-00002-of-00006.safetensors",
87
+ "model.layers.14.self_attn.o_proj.weight": "model-00002-of-00006.safetensors",
88
+ "model.layers.14.self_attn.q_proj.bias": "model-00002-of-00006.safetensors",
89
+ "model.layers.14.self_attn.q_proj.weight": "model-00002-of-00006.safetensors",
90
+ "model.layers.14.self_attn.v_proj.bias": "model-00002-of-00006.safetensors",
91
+ "model.layers.14.self_attn.v_proj.weight": "model-00002-of-00006.safetensors",
92
+ "model.layers.15.input_layernorm.weight": "model-00003-of-00006.safetensors",
93
+ "model.layers.15.mlp.down_proj.weight": "model-00003-of-00006.safetensors",
94
+ "model.layers.15.mlp.gate_proj.weight": "model-00003-of-00006.safetensors",
95
+ "model.layers.15.mlp.up_proj.weight": "model-00003-of-00006.safetensors",
96
+ "model.layers.15.post_attention_layernorm.weight": "model-00003-of-00006.safetensors",
97
+ "model.layers.15.self_attn.k_proj.bias": "model-00002-of-00006.safetensors",
98
+ "model.layers.15.self_attn.k_proj.weight": "model-00002-of-00006.safetensors",
99
+ "model.layers.15.self_attn.o_proj.weight": "model-00002-of-00006.safetensors",
100
+ "model.layers.15.self_attn.q_proj.bias": "model-00002-of-00006.safetensors",
101
+ "model.layers.15.self_attn.q_proj.weight": "model-00002-of-00006.safetensors",
102
+ "model.layers.15.self_attn.v_proj.bias": "model-00002-of-00006.safetensors",
103
+ "model.layers.15.self_attn.v_proj.weight": "model-00002-of-00006.safetensors",
104
+ "model.layers.16.input_layernorm.weight": "model-00003-of-00006.safetensors",
105
+ "model.layers.16.mlp.down_proj.weight": "model-00003-of-00006.safetensors",
106
+ "model.layers.16.mlp.gate_proj.weight": "model-00003-of-00006.safetensors",
107
+ "model.layers.16.mlp.up_proj.weight": "model-00003-of-00006.safetensors",
108
+ "model.layers.16.post_attention_layernorm.weight": "model-00003-of-00006.safetensors",
109
+ "model.layers.16.self_attn.k_proj.bias": "model-00003-of-00006.safetensors",
110
+ "model.layers.16.self_attn.k_proj.weight": "model-00003-of-00006.safetensors",
111
+ "model.layers.16.self_attn.o_proj.weight": "model-00003-of-00006.safetensors",
112
+ "model.layers.16.self_attn.q_proj.bias": "model-00003-of-00006.safetensors",
113
+ "model.layers.16.self_attn.q_proj.weight": "model-00003-of-00006.safetensors",
114
+ "model.layers.16.self_attn.v_proj.bias": "model-00003-of-00006.safetensors",
115
+ "model.layers.16.self_attn.v_proj.weight": "model-00003-of-00006.safetensors",
116
+ "model.layers.17.input_layernorm.weight": "model-00003-of-00006.safetensors",
117
+ "model.layers.17.mlp.down_proj.weight": "model-00003-of-00006.safetensors",
118
+ "model.layers.17.mlp.gate_proj.weight": "model-00003-of-00006.safetensors",
119
+ "model.layers.17.mlp.up_proj.weight": "model-00003-of-00006.safetensors",
120
+ "model.layers.17.post_attention_layernorm.weight": "model-00003-of-00006.safetensors",
121
+ "model.layers.17.self_attn.k_proj.bias": "model-00003-of-00006.safetensors",
122
+ "model.layers.17.self_attn.k_proj.weight": "model-00003-of-00006.safetensors",
123
+ "model.layers.17.self_attn.o_proj.weight": "model-00003-of-00006.safetensors",
124
+ "model.layers.17.self_attn.q_proj.bias": "model-00003-of-00006.safetensors",
125
+ "model.layers.17.self_attn.q_proj.weight": "model-00003-of-00006.safetensors",
126
+ "model.layers.17.self_attn.v_proj.bias": "model-00003-of-00006.safetensors",
127
+ "model.layers.17.self_attn.v_proj.weight": "model-00003-of-00006.safetensors",
128
+ "model.layers.18.input_layernorm.weight": "model-00003-of-00006.safetensors",
129
+ "model.layers.18.mlp.down_proj.weight": "model-00003-of-00006.safetensors",
130
+ "model.layers.18.mlp.gate_proj.weight": "model-00003-of-00006.safetensors",
131
+ "model.layers.18.mlp.up_proj.weight": "model-00003-of-00006.safetensors",
132
+ "model.layers.18.post_attention_layernorm.weight": "model-00003-of-00006.safetensors",
133
+ "model.layers.18.self_attn.k_proj.bias": "model-00003-of-00006.safetensors",
134
+ "model.layers.18.self_attn.k_proj.weight": "model-00003-of-00006.safetensors",
135
+ "model.layers.18.self_attn.o_proj.weight": "model-00003-of-00006.safetensors",
136
+ "model.layers.18.self_attn.q_proj.bias": "model-00003-of-00006.safetensors",
137
+ "model.layers.18.self_attn.q_proj.weight": "model-00003-of-00006.safetensors",
138
+ "model.layers.18.self_attn.v_proj.bias": "model-00003-of-00006.safetensors",
139
+ "model.layers.18.self_attn.v_proj.weight": "model-00003-of-00006.safetensors",
140
+ "model.layers.19.input_layernorm.weight": "model-00003-of-00006.safetensors",
141
+ "model.layers.19.mlp.down_proj.weight": "model-00003-of-00006.safetensors",
142
+ "model.layers.19.mlp.gate_proj.weight": "model-00003-of-00006.safetensors",
143
+ "model.layers.19.mlp.up_proj.weight": "model-00003-of-00006.safetensors",
144
+ "model.layers.19.post_attention_layernorm.weight": "model-00003-of-00006.safetensors",
145
+ "model.layers.19.self_attn.k_proj.bias": "model-00003-of-00006.safetensors",
146
+ "model.layers.19.self_attn.k_proj.weight": "model-00003-of-00006.safetensors",
147
+ "model.layers.19.self_attn.o_proj.weight": "model-00003-of-00006.safetensors",
148
+ "model.layers.19.self_attn.q_proj.bias": "model-00003-of-00006.safetensors",
149
+ "model.layers.19.self_attn.q_proj.weight": "model-00003-of-00006.safetensors",
150
+ "model.layers.19.self_attn.v_proj.bias": "model-00003-of-00006.safetensors",
151
+ "model.layers.19.self_attn.v_proj.weight": "model-00003-of-00006.safetensors",
152
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00006.safetensors",
153
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00006.safetensors",
154
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00006.safetensors",
155
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00006.safetensors",
156
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00006.safetensors",
157
+ "model.layers.2.self_attn.k_proj.bias": "model-00001-of-00006.safetensors",
158
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00006.safetensors",
159
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00006.safetensors",
160
+ "model.layers.2.self_attn.q_proj.bias": "model-00001-of-00006.safetensors",
161
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00006.safetensors",
162
+ "model.layers.2.self_attn.v_proj.bias": "model-00001-of-00006.safetensors",
163
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00006.safetensors",
164
+ "model.layers.20.input_layernorm.weight": "model-00003-of-00006.safetensors",
165
+ "model.layers.20.mlp.down_proj.weight": "model-00003-of-00006.safetensors",
166
+ "model.layers.20.mlp.gate_proj.weight": "model-00003-of-00006.safetensors",
167
+ "model.layers.20.mlp.up_proj.weight": "model-00003-of-00006.safetensors",
168
+ "model.layers.20.post_attention_layernorm.weight": "model-00003-of-00006.safetensors",
169
+ "model.layers.20.self_attn.k_proj.bias": "model-00003-of-00006.safetensors",
170
+ "model.layers.20.self_attn.k_proj.weight": "model-00003-of-00006.safetensors",
171
+ "model.layers.20.self_attn.o_proj.weight": "model-00003-of-00006.safetensors",
172
+ "model.layers.20.self_attn.q_proj.bias": "model-00003-of-00006.safetensors",
173
+ "model.layers.20.self_attn.q_proj.weight": "model-00003-of-00006.safetensors",
174
+ "model.layers.20.self_attn.v_proj.bias": "model-00003-of-00006.safetensors",
175
+ "model.layers.20.self_attn.v_proj.weight": "model-00003-of-00006.safetensors",
176
+ "model.layers.21.input_layernorm.weight": "model-00003-of-00006.safetensors",
177
+ "model.layers.21.mlp.down_proj.weight": "model-00003-of-00006.safetensors",
178
+ "model.layers.21.mlp.gate_proj.weight": "model-00003-of-00006.safetensors",
179
+ "model.layers.21.mlp.up_proj.weight": "model-00003-of-00006.safetensors",
180
+ "model.layers.21.post_attention_layernorm.weight": "model-00003-of-00006.safetensors",
181
+ "model.layers.21.self_attn.k_proj.bias": "model-00003-of-00006.safetensors",
182
+ "model.layers.21.self_attn.k_proj.weight": "model-00003-of-00006.safetensors",
183
+ "model.layers.21.self_attn.o_proj.weight": "model-00003-of-00006.safetensors",
184
+ "model.layers.21.self_attn.q_proj.bias": "model-00003-of-00006.safetensors",
185
+ "model.layers.21.self_attn.q_proj.weight": "model-00003-of-00006.safetensors",
186
+ "model.layers.21.self_attn.v_proj.bias": "model-00003-of-00006.safetensors",
187
+ "model.layers.21.self_attn.v_proj.weight": "model-00003-of-00006.safetensors",
188
+ "model.layers.22.input_layernorm.weight": "model-00003-of-00006.safetensors",
189
+ "model.layers.22.mlp.down_proj.weight": "model-00003-of-00006.safetensors",
190
+ "model.layers.22.mlp.gate_proj.weight": "model-00003-of-00006.safetensors",
191
+ "model.layers.22.mlp.up_proj.weight": "model-00003-of-00006.safetensors",
192
+ "model.layers.22.post_attention_layernorm.weight": "model-00003-of-00006.safetensors",
193
+ "model.layers.22.self_attn.k_proj.bias": "model-00003-of-00006.safetensors",
194
+ "model.layers.22.self_attn.k_proj.weight": "model-00003-of-00006.safetensors",
195
+ "model.layers.22.self_attn.o_proj.weight": "model-00003-of-00006.safetensors",
196
+ "model.layers.22.self_attn.q_proj.bias": "model-00003-of-00006.safetensors",
197
+ "model.layers.22.self_attn.q_proj.weight": "model-00003-of-00006.safetensors",
198
+ "model.layers.22.self_attn.v_proj.bias": "model-00003-of-00006.safetensors",
199
+ "model.layers.22.self_attn.v_proj.weight": "model-00003-of-00006.safetensors",
200
+ "model.layers.23.input_layernorm.weight": "model-00003-of-00006.safetensors",
201
+ "model.layers.23.mlp.down_proj.weight": "model-00003-of-00006.safetensors",
202
+ "model.layers.23.mlp.gate_proj.weight": "model-00003-of-00006.safetensors",
203
+ "model.layers.23.mlp.up_proj.weight": "model-00003-of-00006.safetensors",
204
+ "model.layers.23.post_attention_layernorm.weight": "model-00003-of-00006.safetensors",
205
+ "model.layers.23.self_attn.k_proj.bias": "model-00003-of-00006.safetensors",
206
+ "model.layers.23.self_attn.k_proj.weight": "model-00003-of-00006.safetensors",
207
+ "model.layers.23.self_attn.o_proj.weight": "model-00003-of-00006.safetensors",
208
+ "model.layers.23.self_attn.q_proj.bias": "model-00003-of-00006.safetensors",
209
+ "model.layers.23.self_attn.q_proj.weight": "model-00003-of-00006.safetensors",
210
+ "model.layers.23.self_attn.v_proj.bias": "model-00003-of-00006.safetensors",
211
+ "model.layers.23.self_attn.v_proj.weight": "model-00003-of-00006.safetensors",
212
+ "model.layers.24.input_layernorm.weight": "model-00004-of-00006.safetensors",
213
+ "model.layers.24.mlp.down_proj.weight": "model-00004-of-00006.safetensors",
214
+ "model.layers.24.mlp.gate_proj.weight": "model-00004-of-00006.safetensors",
215
+ "model.layers.24.mlp.up_proj.weight": "model-00004-of-00006.safetensors",
216
+ "model.layers.24.post_attention_layernorm.weight": "model-00004-of-00006.safetensors",
217
+ "model.layers.24.self_attn.k_proj.bias": "model-00003-of-00006.safetensors",
218
+ "model.layers.24.self_attn.k_proj.weight": "model-00003-of-00006.safetensors",
219
+ "model.layers.24.self_attn.o_proj.weight": "model-00003-of-00006.safetensors",
220
+ "model.layers.24.self_attn.q_proj.bias": "model-00003-of-00006.safetensors",
221
+ "model.layers.24.self_attn.q_proj.weight": "model-00003-of-00006.safetensors",
222
+ "model.layers.24.self_attn.v_proj.bias": "model-00003-of-00006.safetensors",
223
+ "model.layers.24.self_attn.v_proj.weight": "model-00003-of-00006.safetensors",
224
+ "model.layers.25.input_layernorm.weight": "model-00004-of-00006.safetensors",
225
+ "model.layers.25.mlp.down_proj.weight": "model-00004-of-00006.safetensors",
226
+ "model.layers.25.mlp.gate_proj.weight": "model-00004-of-00006.safetensors",
227
+ "model.layers.25.mlp.up_proj.weight": "model-00004-of-00006.safetensors",
228
+ "model.layers.25.post_attention_layernorm.weight": "model-00004-of-00006.safetensors",
229
+ "model.layers.25.self_attn.k_proj.bias": "model-00004-of-00006.safetensors",
230
+ "model.layers.25.self_attn.k_proj.weight": "model-00004-of-00006.safetensors",
231
+ "model.layers.25.self_attn.o_proj.weight": "model-00004-of-00006.safetensors",
232
+ "model.layers.25.self_attn.q_proj.bias": "model-00004-of-00006.safetensors",
233
+ "model.layers.25.self_attn.q_proj.weight": "model-00004-of-00006.safetensors",
234
+ "model.layers.25.self_attn.v_proj.bias": "model-00004-of-00006.safetensors",
235
+ "model.layers.25.self_attn.v_proj.weight": "model-00004-of-00006.safetensors",
236
+ "model.layers.26.input_layernorm.weight": "model-00004-of-00006.safetensors",
237
+ "model.layers.26.mlp.down_proj.weight": "model-00004-of-00006.safetensors",
238
+ "model.layers.26.mlp.gate_proj.weight": "model-00004-of-00006.safetensors",
239
+ "model.layers.26.mlp.up_proj.weight": "model-00004-of-00006.safetensors",
240
+ "model.layers.26.post_attention_layernorm.weight": "model-00004-of-00006.safetensors",
241
+ "model.layers.26.self_attn.k_proj.bias": "model-00004-of-00006.safetensors",
242
+ "model.layers.26.self_attn.k_proj.weight": "model-00004-of-00006.safetensors",
243
+ "model.layers.26.self_attn.o_proj.weight": "model-00004-of-00006.safetensors",
244
+ "model.layers.26.self_attn.q_proj.bias": "model-00004-of-00006.safetensors",
245
+ "model.layers.26.self_attn.q_proj.weight": "model-00004-of-00006.safetensors",
246
+ "model.layers.26.self_attn.v_proj.bias": "model-00004-of-00006.safetensors",
247
+ "model.layers.26.self_attn.v_proj.weight": "model-00004-of-00006.safetensors",
248
+ "model.layers.27.input_layernorm.weight": "model-00004-of-00006.safetensors",
249
+ "model.layers.27.mlp.down_proj.weight": "model-00004-of-00006.safetensors",
250
+ "model.layers.27.mlp.gate_proj.weight": "model-00004-of-00006.safetensors",
251
+ "model.layers.27.mlp.up_proj.weight": "model-00004-of-00006.safetensors",
252
+ "model.layers.27.post_attention_layernorm.weight": "model-00004-of-00006.safetensors",
253
+ "model.layers.27.self_attn.k_proj.bias": "model-00004-of-00006.safetensors",
254
+ "model.layers.27.self_attn.k_proj.weight": "model-00004-of-00006.safetensors",
255
+ "model.layers.27.self_attn.o_proj.weight": "model-00004-of-00006.safetensors",
256
+ "model.layers.27.self_attn.q_proj.bias": "model-00004-of-00006.safetensors",
257
+ "model.layers.27.self_attn.q_proj.weight": "model-00004-of-00006.safetensors",
258
+ "model.layers.27.self_attn.v_proj.bias": "model-00004-of-00006.safetensors",
259
+ "model.layers.27.self_attn.v_proj.weight": "model-00004-of-00006.safetensors",
260
+ "model.layers.28.input_layernorm.weight": "model-00004-of-00006.safetensors",
261
+ "model.layers.28.mlp.down_proj.weight": "model-00004-of-00006.safetensors",
262
+ "model.layers.28.mlp.gate_proj.weight": "model-00004-of-00006.safetensors",
263
+ "model.layers.28.mlp.up_proj.weight": "model-00004-of-00006.safetensors",
264
+ "model.layers.28.post_attention_layernorm.weight": "model-00004-of-00006.safetensors",
265
+ "model.layers.28.self_attn.k_proj.bias": "model-00004-of-00006.safetensors",
266
+ "model.layers.28.self_attn.k_proj.weight": "model-00004-of-00006.safetensors",
267
+ "model.layers.28.self_attn.o_proj.weight": "model-00004-of-00006.safetensors",
268
+ "model.layers.28.self_attn.q_proj.bias": "model-00004-of-00006.safetensors",
269
+ "model.layers.28.self_attn.q_proj.weight": "model-00004-of-00006.safetensors",
270
+ "model.layers.28.self_attn.v_proj.bias": "model-00004-of-00006.safetensors",
271
+ "model.layers.28.self_attn.v_proj.weight": "model-00004-of-00006.safetensors",
272
+ "model.layers.29.input_layernorm.weight": "model-00004-of-00006.safetensors",
273
+ "model.layers.29.mlp.down_proj.weight": "model-00004-of-00006.safetensors",
274
+ "model.layers.29.mlp.gate_proj.weight": "model-00004-of-00006.safetensors",
275
+ "model.layers.29.mlp.up_proj.weight": "model-00004-of-00006.safetensors",
276
+ "model.layers.29.post_attention_layernorm.weight": "model-00004-of-00006.safetensors",
277
+ "model.layers.29.self_attn.k_proj.bias": "model-00004-of-00006.safetensors",
278
+ "model.layers.29.self_attn.k_proj.weight": "model-00004-of-00006.safetensors",
279
+ "model.layers.29.self_attn.o_proj.weight": "model-00004-of-00006.safetensors",
280
+ "model.layers.29.self_attn.q_proj.bias": "model-00004-of-00006.safetensors",
281
+ "model.layers.29.self_attn.q_proj.weight": "model-00004-of-00006.safetensors",
282
+ "model.layers.29.self_attn.v_proj.bias": "model-00004-of-00006.safetensors",
283
+ "model.layers.29.self_attn.v_proj.weight": "model-00004-of-00006.safetensors",
284
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00006.safetensors",
285
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00006.safetensors",
286
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00006.safetensors",
287
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00006.safetensors",
288
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00006.safetensors",
289
+ "model.layers.3.self_attn.k_proj.bias": "model-00001-of-00006.safetensors",
290
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00006.safetensors",
291
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00006.safetensors",
292
+ "model.layers.3.self_attn.q_proj.bias": "model-00001-of-00006.safetensors",
293
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00006.safetensors",
294
+ "model.layers.3.self_attn.v_proj.bias": "model-00001-of-00006.safetensors",
295
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00006.safetensors",
296
+ "model.layers.30.input_layernorm.weight": "model-00004-of-00006.safetensors",
297
+ "model.layers.30.mlp.down_proj.weight": "model-00004-of-00006.safetensors",
298
+ "model.layers.30.mlp.gate_proj.weight": "model-00004-of-00006.safetensors",
299
+ "model.layers.30.mlp.up_proj.weight": "model-00004-of-00006.safetensors",
300
+ "model.layers.30.post_attention_layernorm.weight": "model-00004-of-00006.safetensors",
301
+ "model.layers.30.self_attn.k_proj.bias": "model-00004-of-00006.safetensors",
302
+ "model.layers.30.self_attn.k_proj.weight": "model-00004-of-00006.safetensors",
303
+ "model.layers.30.self_attn.o_proj.weight": "model-00004-of-00006.safetensors",
304
+ "model.layers.30.self_attn.q_proj.bias": "model-00004-of-00006.safetensors",
305
+ "model.layers.30.self_attn.q_proj.weight": "model-00004-of-00006.safetensors",
306
+ "model.layers.30.self_attn.v_proj.bias": "model-00004-of-00006.safetensors",
307
+ "model.layers.30.self_attn.v_proj.weight": "model-00004-of-00006.safetensors",
308
+ "model.layers.31.input_layernorm.weight": "model-00004-of-00006.safetensors",
309
+ "model.layers.31.mlp.down_proj.weight": "model-00004-of-00006.safetensors",
310
+ "model.layers.31.mlp.gate_proj.weight": "model-00004-of-00006.safetensors",
311
+ "model.layers.31.mlp.up_proj.weight": "model-00004-of-00006.safetensors",
312
+ "model.layers.31.post_attention_layernorm.weight": "model-00004-of-00006.safetensors",
313
+ "model.layers.31.self_attn.k_proj.bias": "model-00004-of-00006.safetensors",
314
+ "model.layers.31.self_attn.k_proj.weight": "model-00004-of-00006.safetensors",
315
+ "model.layers.31.self_attn.o_proj.weight": "model-00004-of-00006.safetensors",
316
+ "model.layers.31.self_attn.q_proj.bias": "model-00004-of-00006.safetensors",
317
+ "model.layers.31.self_attn.q_proj.weight": "model-00004-of-00006.safetensors",
318
+ "model.layers.31.self_attn.v_proj.bias": "model-00004-of-00006.safetensors",
319
+ "model.layers.31.self_attn.v_proj.weight": "model-00004-of-00006.safetensors",
320
+ "model.layers.32.input_layernorm.weight": "model-00004-of-00006.safetensors",
321
+ "model.layers.32.mlp.down_proj.weight": "model-00004-of-00006.safetensors",
322
+ "model.layers.32.mlp.gate_proj.weight": "model-00004-of-00006.safetensors",
323
+ "model.layers.32.mlp.up_proj.weight": "model-00004-of-00006.safetensors",
324
+ "model.layers.32.post_attention_layernorm.weight": "model-00004-of-00006.safetensors",
325
+ "model.layers.32.self_attn.k_proj.bias": "model-00004-of-00006.safetensors",
326
+ "model.layers.32.self_attn.k_proj.weight": "model-00004-of-00006.safetensors",
327
+ "model.layers.32.self_attn.o_proj.weight": "model-00004-of-00006.safetensors",
328
+ "model.layers.32.self_attn.q_proj.bias": "model-00004-of-00006.safetensors",
329
+ "model.layers.32.self_attn.q_proj.weight": "model-00004-of-00006.safetensors",
330
+ "model.layers.32.self_attn.v_proj.bias": "model-00004-of-00006.safetensors",
331
+ "model.layers.32.self_attn.v_proj.weight": "model-00004-of-00006.safetensors",
332
+ "model.layers.33.input_layernorm.weight": "model-00005-of-00006.safetensors",
333
+ "model.layers.33.mlp.down_proj.weight": "model-00005-of-00006.safetensors",
334
+ "model.layers.33.mlp.gate_proj.weight": "model-00005-of-00006.safetensors",
335
+ "model.layers.33.mlp.up_proj.weight": "model-00005-of-00006.safetensors",
336
+ "model.layers.33.post_attention_layernorm.weight": "model-00005-of-00006.safetensors",
337
+ "model.layers.33.self_attn.k_proj.bias": "model-00004-of-00006.safetensors",
338
+ "model.layers.33.self_attn.k_proj.weight": "model-00004-of-00006.safetensors",
339
+ "model.layers.33.self_attn.o_proj.weight": "model-00004-of-00006.safetensors",
340
+ "model.layers.33.self_attn.q_proj.bias": "model-00004-of-00006.safetensors",
341
+ "model.layers.33.self_attn.q_proj.weight": "model-00004-of-00006.safetensors",
342
+ "model.layers.33.self_attn.v_proj.bias": "model-00004-of-00006.safetensors",
343
+ "model.layers.33.self_attn.v_proj.weight": "model-00004-of-00006.safetensors",
344
+ "model.layers.34.input_layernorm.weight": "model-00005-of-00006.safetensors",
345
+ "model.layers.34.mlp.down_proj.weight": "model-00005-of-00006.safetensors",
346
+ "model.layers.34.mlp.gate_proj.weight": "model-00005-of-00006.safetensors",
347
+ "model.layers.34.mlp.up_proj.weight": "model-00005-of-00006.safetensors",
348
+ "model.layers.34.post_attention_layernorm.weight": "model-00005-of-00006.safetensors",
349
+ "model.layers.34.self_attn.k_proj.bias": "model-00005-of-00006.safetensors",
350
+ "model.layers.34.self_attn.k_proj.weight": "model-00005-of-00006.safetensors",
351
+ "model.layers.34.self_attn.o_proj.weight": "model-00005-of-00006.safetensors",
352
+ "model.layers.34.self_attn.q_proj.bias": "model-00005-of-00006.safetensors",
353
+ "model.layers.34.self_attn.q_proj.weight": "model-00005-of-00006.safetensors",
354
+ "model.layers.34.self_attn.v_proj.bias": "model-00005-of-00006.safetensors",
355
+ "model.layers.34.self_attn.v_proj.weight": "model-00005-of-00006.safetensors",
356
+ "model.layers.35.input_layernorm.weight": "model-00005-of-00006.safetensors",
357
+ "model.layers.35.mlp.down_proj.weight": "model-00005-of-00006.safetensors",
358
+ "model.layers.35.mlp.gate_proj.weight": "model-00005-of-00006.safetensors",
359
+ "model.layers.35.mlp.up_proj.weight": "model-00005-of-00006.safetensors",
360
+ "model.layers.35.post_attention_layernorm.weight": "model-00005-of-00006.safetensors",
361
+ "model.layers.35.self_attn.k_proj.bias": "model-00005-of-00006.safetensors",
362
+ "model.layers.35.self_attn.k_proj.weight": "model-00005-of-00006.safetensors",
363
+ "model.layers.35.self_attn.o_proj.weight": "model-00005-of-00006.safetensors",
364
+ "model.layers.35.self_attn.q_proj.bias": "model-00005-of-00006.safetensors",
365
+ "model.layers.35.self_attn.q_proj.weight": "model-00005-of-00006.safetensors",
366
+ "model.layers.35.self_attn.v_proj.bias": "model-00005-of-00006.safetensors",
367
+ "model.layers.35.self_attn.v_proj.weight": "model-00005-of-00006.safetensors",
368
+ "model.layers.36.input_layernorm.weight": "model-00005-of-00006.safetensors",
369
+ "model.layers.36.mlp.down_proj.weight": "model-00005-of-00006.safetensors",
370
+ "model.layers.36.mlp.gate_proj.weight": "model-00005-of-00006.safetensors",
371
+ "model.layers.36.mlp.up_proj.weight": "model-00005-of-00006.safetensors",
372
+ "model.layers.36.post_attention_layernorm.weight": "model-00005-of-00006.safetensors",
373
+ "model.layers.36.self_attn.k_proj.bias": "model-00005-of-00006.safetensors",
374
+ "model.layers.36.self_attn.k_proj.weight": "model-00005-of-00006.safetensors",
375
+ "model.layers.36.self_attn.o_proj.weight": "model-00005-of-00006.safetensors",
376
+ "model.layers.36.self_attn.q_proj.bias": "model-00005-of-00006.safetensors",
377
+ "model.layers.36.self_attn.q_proj.weight": "model-00005-of-00006.safetensors",
378
+ "model.layers.36.self_attn.v_proj.bias": "model-00005-of-00006.safetensors",
379
+ "model.layers.36.self_attn.v_proj.weight": "model-00005-of-00006.safetensors",
380
+ "model.layers.37.input_layernorm.weight": "model-00005-of-00006.safetensors",
381
+ "model.layers.37.mlp.down_proj.weight": "model-00005-of-00006.safetensors",
382
+ "model.layers.37.mlp.gate_proj.weight": "model-00005-of-00006.safetensors",
383
+ "model.layers.37.mlp.up_proj.weight": "model-00005-of-00006.safetensors",
384
+ "model.layers.37.post_attention_layernorm.weight": "model-00005-of-00006.safetensors",
385
+ "model.layers.37.self_attn.k_proj.bias": "model-00005-of-00006.safetensors",
386
+ "model.layers.37.self_attn.k_proj.weight": "model-00005-of-00006.safetensors",
387
+ "model.layers.37.self_attn.o_proj.weight": "model-00005-of-00006.safetensors",
388
+ "model.layers.37.self_attn.q_proj.bias": "model-00005-of-00006.safetensors",
389
+ "model.layers.37.self_attn.q_proj.weight": "model-00005-of-00006.safetensors",
390
+ "model.layers.37.self_attn.v_proj.bias": "model-00005-of-00006.safetensors",
391
+ "model.layers.37.self_attn.v_proj.weight": "model-00005-of-00006.safetensors",
392
+ "model.layers.38.input_layernorm.weight": "model-00005-of-00006.safetensors",
393
+ "model.layers.38.mlp.down_proj.weight": "model-00005-of-00006.safetensors",
394
+ "model.layers.38.mlp.gate_proj.weight": "model-00005-of-00006.safetensors",
395
+ "model.layers.38.mlp.up_proj.weight": "model-00005-of-00006.safetensors",
396
+ "model.layers.38.post_attention_layernorm.weight": "model-00005-of-00006.safetensors",
397
+ "model.layers.38.self_attn.k_proj.bias": "model-00005-of-00006.safetensors",
398
+ "model.layers.38.self_attn.k_proj.weight": "model-00005-of-00006.safetensors",
399
+ "model.layers.38.self_attn.o_proj.weight": "model-00005-of-00006.safetensors",
400
+ "model.layers.38.self_attn.q_proj.bias": "model-00005-of-00006.safetensors",
401
+ "model.layers.38.self_attn.q_proj.weight": "model-00005-of-00006.safetensors",
402
+ "model.layers.38.self_attn.v_proj.bias": "model-00005-of-00006.safetensors",
403
+ "model.layers.38.self_attn.v_proj.weight": "model-00005-of-00006.safetensors",
404
+ "model.layers.39.input_layernorm.weight": "model-00005-of-00006.safetensors",
405
+ "model.layers.39.mlp.down_proj.weight": "model-00005-of-00006.safetensors",
406
+ "model.layers.39.mlp.gate_proj.weight": "model-00005-of-00006.safetensors",
407
+ "model.layers.39.mlp.up_proj.weight": "model-00005-of-00006.safetensors",
408
+ "model.layers.39.post_attention_layernorm.weight": "model-00005-of-00006.safetensors",
409
+ "model.layers.39.self_attn.k_proj.bias": "model-00005-of-00006.safetensors",
410
+ "model.layers.39.self_attn.k_proj.weight": "model-00005-of-00006.safetensors",
411
+ "model.layers.39.self_attn.o_proj.weight": "model-00005-of-00006.safetensors",
412
+ "model.layers.39.self_attn.q_proj.bias": "model-00005-of-00006.safetensors",
413
+ "model.layers.39.self_attn.q_proj.weight": "model-00005-of-00006.safetensors",
414
+ "model.layers.39.self_attn.v_proj.bias": "model-00005-of-00006.safetensors",
415
+ "model.layers.39.self_attn.v_proj.weight": "model-00005-of-00006.safetensors",
416
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00006.safetensors",
417
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00006.safetensors",
418
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00006.safetensors",
419
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00006.safetensors",
420
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00006.safetensors",
421
+ "model.layers.4.self_attn.k_proj.bias": "model-00001-of-00006.safetensors",
422
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00006.safetensors",
423
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00006.safetensors",
424
+ "model.layers.4.self_attn.q_proj.bias": "model-00001-of-00006.safetensors",
425
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00006.safetensors",
426
+ "model.layers.4.self_attn.v_proj.bias": "model-00001-of-00006.safetensors",
427
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00006.safetensors",
428
+ "model.layers.40.input_layernorm.weight": "model-00005-of-00006.safetensors",
429
+ "model.layers.40.mlp.down_proj.weight": "model-00005-of-00006.safetensors",
430
+ "model.layers.40.mlp.gate_proj.weight": "model-00005-of-00006.safetensors",
431
+ "model.layers.40.mlp.up_proj.weight": "model-00005-of-00006.safetensors",
432
+ "model.layers.40.post_attention_layernorm.weight": "model-00005-of-00006.safetensors",
433
+ "model.layers.40.self_attn.k_proj.bias": "model-00005-of-00006.safetensors",
434
+ "model.layers.40.self_attn.k_proj.weight": "model-00005-of-00006.safetensors",
435
+ "model.layers.40.self_attn.o_proj.weight": "model-00005-of-00006.safetensors",
436
+ "model.layers.40.self_attn.q_proj.bias": "model-00005-of-00006.safetensors",
437
+ "model.layers.40.self_attn.q_proj.weight": "model-00005-of-00006.safetensors",
438
+ "model.layers.40.self_attn.v_proj.bias": "model-00005-of-00006.safetensors",
439
+ "model.layers.40.self_attn.v_proj.weight": "model-00005-of-00006.safetensors",
440
+ "model.layers.41.input_layernorm.weight": "model-00005-of-00006.safetensors",
441
+ "model.layers.41.mlp.down_proj.weight": "model-00005-of-00006.safetensors",
442
+ "model.layers.41.mlp.gate_proj.weight": "model-00005-of-00006.safetensors",
443
+ "model.layers.41.mlp.up_proj.weight": "model-00005-of-00006.safetensors",
444
+ "model.layers.41.post_attention_layernorm.weight": "model-00005-of-00006.safetensors",
445
+ "model.layers.41.self_attn.k_proj.bias": "model-00005-of-00006.safetensors",
446
+ "model.layers.41.self_attn.k_proj.weight": "model-00005-of-00006.safetensors",
447
+ "model.layers.41.self_attn.o_proj.weight": "model-00005-of-00006.safetensors",
448
+ "model.layers.41.self_attn.q_proj.bias": "model-00005-of-00006.safetensors",
449
+ "model.layers.41.self_attn.q_proj.weight": "model-00005-of-00006.safetensors",
450
+ "model.layers.41.self_attn.v_proj.bias": "model-00005-of-00006.safetensors",
451
+ "model.layers.41.self_attn.v_proj.weight": "model-00005-of-00006.safetensors",
452
+ "model.layers.42.input_layernorm.weight": "model-00006-of-00006.safetensors",
453
+ "model.layers.42.mlp.down_proj.weight": "model-00006-of-00006.safetensors",
454
+ "model.layers.42.mlp.gate_proj.weight": "model-00006-of-00006.safetensors",
455
+ "model.layers.42.mlp.up_proj.weight": "model-00006-of-00006.safetensors",
456
+ "model.layers.42.post_attention_layernorm.weight": "model-00006-of-00006.safetensors",
457
+ "model.layers.42.self_attn.k_proj.bias": "model-00005-of-00006.safetensors",
458
+ "model.layers.42.self_attn.k_proj.weight": "model-00005-of-00006.safetensors",
459
+ "model.layers.42.self_attn.o_proj.weight": "model-00005-of-00006.safetensors",
460
+ "model.layers.42.self_attn.q_proj.bias": "model-00005-of-00006.safetensors",
461
+ "model.layers.42.self_attn.q_proj.weight": "model-00005-of-00006.safetensors",
462
+ "model.layers.42.self_attn.v_proj.bias": "model-00005-of-00006.safetensors",
463
+ "model.layers.42.self_attn.v_proj.weight": "model-00005-of-00006.safetensors",
464
+ "model.layers.43.input_layernorm.weight": "model-00006-of-00006.safetensors",
465
+ "model.layers.43.mlp.down_proj.weight": "model-00006-of-00006.safetensors",
466
+ "model.layers.43.mlp.gate_proj.weight": "model-00006-of-00006.safetensors",
467
+ "model.layers.43.mlp.up_proj.weight": "model-00006-of-00006.safetensors",
468
+ "model.layers.43.post_attention_layernorm.weight": "model-00006-of-00006.safetensors",
469
+ "model.layers.43.self_attn.k_proj.bias": "model-00006-of-00006.safetensors",
470
+ "model.layers.43.self_attn.k_proj.weight": "model-00006-of-00006.safetensors",
471
+ "model.layers.43.self_attn.o_proj.weight": "model-00006-of-00006.safetensors",
472
+ "model.layers.43.self_attn.q_proj.bias": "model-00006-of-00006.safetensors",
473
+ "model.layers.43.self_attn.q_proj.weight": "model-00006-of-00006.safetensors",
474
+ "model.layers.43.self_attn.v_proj.bias": "model-00006-of-00006.safetensors",
475
+ "model.layers.43.self_attn.v_proj.weight": "model-00006-of-00006.safetensors",
476
+ "model.layers.44.input_layernorm.weight": "model-00006-of-00006.safetensors",
477
+ "model.layers.44.mlp.down_proj.weight": "model-00006-of-00006.safetensors",
478
+ "model.layers.44.mlp.gate_proj.weight": "model-00006-of-00006.safetensors",
479
+ "model.layers.44.mlp.up_proj.weight": "model-00006-of-00006.safetensors",
480
+ "model.layers.44.post_attention_layernorm.weight": "model-00006-of-00006.safetensors",
481
+ "model.layers.44.self_attn.k_proj.bias": "model-00006-of-00006.safetensors",
482
+ "model.layers.44.self_attn.k_proj.weight": "model-00006-of-00006.safetensors",
483
+ "model.layers.44.self_attn.o_proj.weight": "model-00006-of-00006.safetensors",
484
+ "model.layers.44.self_attn.q_proj.bias": "model-00006-of-00006.safetensors",
485
+ "model.layers.44.self_attn.q_proj.weight": "model-00006-of-00006.safetensors",
486
+ "model.layers.44.self_attn.v_proj.bias": "model-00006-of-00006.safetensors",
487
+ "model.layers.44.self_attn.v_proj.weight": "model-00006-of-00006.safetensors",
488
+ "model.layers.45.input_layernorm.weight": "model-00006-of-00006.safetensors",
489
+ "model.layers.45.mlp.down_proj.weight": "model-00006-of-00006.safetensors",
490
+ "model.layers.45.mlp.gate_proj.weight": "model-00006-of-00006.safetensors",
491
+ "model.layers.45.mlp.up_proj.weight": "model-00006-of-00006.safetensors",
492
+ "model.layers.45.post_attention_layernorm.weight": "model-00006-of-00006.safetensors",
493
+ "model.layers.45.self_attn.k_proj.bias": "model-00006-of-00006.safetensors",
494
+ "model.layers.45.self_attn.k_proj.weight": "model-00006-of-00006.safetensors",
495
+ "model.layers.45.self_attn.o_proj.weight": "model-00006-of-00006.safetensors",
496
+ "model.layers.45.self_attn.q_proj.bias": "model-00006-of-00006.safetensors",
497
+ "model.layers.45.self_attn.q_proj.weight": "model-00006-of-00006.safetensors",
498
+ "model.layers.45.self_attn.v_proj.bias": "model-00006-of-00006.safetensors",
499
+ "model.layers.45.self_attn.v_proj.weight": "model-00006-of-00006.safetensors",
500
+ "model.layers.46.input_layernorm.weight": "model-00006-of-00006.safetensors",
501
+ "model.layers.46.mlp.down_proj.weight": "model-00006-of-00006.safetensors",
502
+ "model.layers.46.mlp.gate_proj.weight": "model-00006-of-00006.safetensors",
503
+ "model.layers.46.mlp.up_proj.weight": "model-00006-of-00006.safetensors",
504
+ "model.layers.46.post_attention_layernorm.weight": "model-00006-of-00006.safetensors",
505
+ "model.layers.46.self_attn.k_proj.bias": "model-00006-of-00006.safetensors",
506
+ "model.layers.46.self_attn.k_proj.weight": "model-00006-of-00006.safetensors",
507
+ "model.layers.46.self_attn.o_proj.weight": "model-00006-of-00006.safetensors",
508
+ "model.layers.46.self_attn.q_proj.bias": "model-00006-of-00006.safetensors",
509
+ "model.layers.46.self_attn.q_proj.weight": "model-00006-of-00006.safetensors",
510
+ "model.layers.46.self_attn.v_proj.bias": "model-00006-of-00006.safetensors",
511
+ "model.layers.46.self_attn.v_proj.weight": "model-00006-of-00006.safetensors",
512
+ "model.layers.47.input_layernorm.weight": "model-00006-of-00006.safetensors",
513
+ "model.layers.47.mlp.down_proj.weight": "model-00006-of-00006.safetensors",
514
+ "model.layers.47.mlp.gate_proj.weight": "model-00006-of-00006.safetensors",
515
+ "model.layers.47.mlp.up_proj.weight": "model-00006-of-00006.safetensors",
516
+ "model.layers.47.post_attention_layernorm.weight": "model-00006-of-00006.safetensors",
517
+ "model.layers.47.self_attn.k_proj.bias": "model-00006-of-00006.safetensors",
518
+ "model.layers.47.self_attn.k_proj.weight": "model-00006-of-00006.safetensors",
519
+ "model.layers.47.self_attn.o_proj.weight": "model-00006-of-00006.safetensors",
520
+ "model.layers.47.self_attn.q_proj.bias": "model-00006-of-00006.safetensors",
521
+ "model.layers.47.self_attn.q_proj.weight": "model-00006-of-00006.safetensors",
522
+ "model.layers.47.self_attn.v_proj.bias": "model-00006-of-00006.safetensors",
523
+ "model.layers.47.self_attn.v_proj.weight": "model-00006-of-00006.safetensors",
524
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00006.safetensors",
525
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00006.safetensors",
526
+ "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00006.safetensors",
527
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00006.safetensors",
528
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00006.safetensors",
529
+ "model.layers.5.self_attn.k_proj.bias": "model-00001-of-00006.safetensors",
530
+ "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00006.safetensors",
531
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00006.safetensors",
532
+ "model.layers.5.self_attn.q_proj.bias": "model-00001-of-00006.safetensors",
533
+ "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00006.safetensors",
534
+ "model.layers.5.self_attn.v_proj.bias": "model-00001-of-00006.safetensors",
535
+ "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00006.safetensors",
536
+ "model.layers.6.input_layernorm.weight": "model-00002-of-00006.safetensors",
537
+ "model.layers.6.mlp.down_proj.weight": "model-00002-of-00006.safetensors",
538
+ "model.layers.6.mlp.gate_proj.weight": "model-00002-of-00006.safetensors",
539
+ "model.layers.6.mlp.up_proj.weight": "model-00002-of-00006.safetensors",
540
+ "model.layers.6.post_attention_layernorm.weight": "model-00002-of-00006.safetensors",
541
+ "model.layers.6.self_attn.k_proj.bias": "model-00001-of-00006.safetensors",
542
+ "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00006.safetensors",
543
+ "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00006.safetensors",
544
+ "model.layers.6.self_attn.q_proj.bias": "model-00001-of-00006.safetensors",
545
+ "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00006.safetensors",
546
+ "model.layers.6.self_attn.v_proj.bias": "model-00001-of-00006.safetensors",
547
+ "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00006.safetensors",
548
+ "model.layers.7.input_layernorm.weight": "model-00002-of-00006.safetensors",
549
+ "model.layers.7.mlp.down_proj.weight": "model-00002-of-00006.safetensors",
550
+ "model.layers.7.mlp.gate_proj.weight": "model-00002-of-00006.safetensors",
551
+ "model.layers.7.mlp.up_proj.weight": "model-00002-of-00006.safetensors",
552
+ "model.layers.7.post_attention_layernorm.weight": "model-00002-of-00006.safetensors",
553
+ "model.layers.7.self_attn.k_proj.bias": "model-00002-of-00006.safetensors",
554
+ "model.layers.7.self_attn.k_proj.weight": "model-00002-of-00006.safetensors",
555
+ "model.layers.7.self_attn.o_proj.weight": "model-00002-of-00006.safetensors",
556
+ "model.layers.7.self_attn.q_proj.bias": "model-00002-of-00006.safetensors",
557
+ "model.layers.7.self_attn.q_proj.weight": "model-00002-of-00006.safetensors",
558
+ "model.layers.7.self_attn.v_proj.bias": "model-00002-of-00006.safetensors",
559
+ "model.layers.7.self_attn.v_proj.weight": "model-00002-of-00006.safetensors",
560
+ "model.layers.8.input_layernorm.weight": "model-00002-of-00006.safetensors",
561
+ "model.layers.8.mlp.down_proj.weight": "model-00002-of-00006.safetensors",
562
+ "model.layers.8.mlp.gate_proj.weight": "model-00002-of-00006.safetensors",
563
+ "model.layers.8.mlp.up_proj.weight": "model-00002-of-00006.safetensors",
564
+ "model.layers.8.post_attention_layernorm.weight": "model-00002-of-00006.safetensors",
565
+ "model.layers.8.self_attn.k_proj.bias": "model-00002-of-00006.safetensors",
566
+ "model.layers.8.self_attn.k_proj.weight": "model-00002-of-00006.safetensors",
567
+ "model.layers.8.self_attn.o_proj.weight": "model-00002-of-00006.safetensors",
568
+ "model.layers.8.self_attn.q_proj.bias": "model-00002-of-00006.safetensors",
569
+ "model.layers.8.self_attn.q_proj.weight": "model-00002-of-00006.safetensors",
570
+ "model.layers.8.self_attn.v_proj.bias": "model-00002-of-00006.safetensors",
571
+ "model.layers.8.self_attn.v_proj.weight": "model-00002-of-00006.safetensors",
572
+ "model.layers.9.input_layernorm.weight": "model-00002-of-00006.safetensors",
573
+ "model.layers.9.mlp.down_proj.weight": "model-00002-of-00006.safetensors",
574
+ "model.layers.9.mlp.gate_proj.weight": "model-00002-of-00006.safetensors",
575
+ "model.layers.9.mlp.up_proj.weight": "model-00002-of-00006.safetensors",
576
+ "model.layers.9.post_attention_layernorm.weight": "model-00002-of-00006.safetensors",
577
+ "model.layers.9.self_attn.k_proj.bias": "model-00002-of-00006.safetensors",
578
+ "model.layers.9.self_attn.k_proj.weight": "model-00002-of-00006.safetensors",
579
+ "model.layers.9.self_attn.o_proj.weight": "model-00002-of-00006.safetensors",
580
+ "model.layers.9.self_attn.q_proj.bias": "model-00002-of-00006.safetensors",
581
+ "model.layers.9.self_attn.q_proj.weight": "model-00002-of-00006.safetensors",
582
+ "model.layers.9.self_attn.v_proj.bias": "model-00002-of-00006.safetensors",
583
+ "model.layers.9.self_attn.v_proj.weight": "model-00002-of-00006.safetensors",
584
+ "model.norm.weight": "model-00006-of-00006.safetensors"
585
+ }
586
+ }
llm/special_tokens_map.json ADDED
@@ -0,0 +1,41 @@
+ {
+ "additional_special_tokens": [
+ "<|im_start|>",
+ "<|im_end|>",
+ "<|object_ref_start|>",
+ "<|object_ref_end|>",
+ "<|box_start|>",
+ "<|box_end|>",
+ "<|quad_start|>",
+ "<|quad_end|>",
+ "<|vision_start|>",
+ "<|vision_end|>",
+ "<|vision_pad|>",
+ "<|image_pad|>",
+ "<|video_pad|>",
+ "<vila/sentinel>",
+ "<image>",
+ "<vila/video>"
+ ],
+ "bos_token": {
+ "content": "[BOS]",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "eos_token": {
+ "content": "<|im_end|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "pad_token": {
+ "content": "[PAD]",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ }
+ }
llm/tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2adb5255020285bad13f10e6c896570ffe9c35c1b5c0ea587e6ec9662b84f6ea
+ size 11422819
llm/tokenizer_config.json ADDED
@@ -0,0 +1,252 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<vila/sentinel>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": true
188
+ },
189
+ "151666": {
190
+ "content": "<image>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": true
196
+ },
197
+ "151667": {
198
+ "content": "<vila/video>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": true
204
+ },
205
+ "151668": {
206
+ "content": "[BOS]",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": true
212
+ },
213
+ "151669": {
214
+ "content": "[PAD]",
215
+ "lstrip": false,
216
+ "normalized": false,
217
+ "rstrip": false,
218
+ "single_word": false,
219
+ "special": true
220
+ }
221
+ },
222
+ "additional_special_tokens": [
223
+ "<|im_start|>",
224
+ "<|im_end|>",
225
+ "<|object_ref_start|>",
226
+ "<|object_ref_end|>",
227
+ "<|box_start|>",
228
+ "<|box_end|>",
229
+ "<|quad_start|>",
230
+ "<|quad_end|>",
231
+ "<|vision_start|>",
232
+ "<|vision_end|>",
233
+ "<|vision_pad|>",
234
+ "<|image_pad|>",
235
+ "<|video_pad|>",
236
+ "<vila/sentinel>",
237
+ "<image>",
238
+ "<vila/video>"
239
+ ],
240
+ "bos_token": "[BOS]",
241
+ "chat_template": "{% if messages[0]['role'] != 'system' %}{{ '<|im_start|>system\\n以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。<|im_end|>\\n' }}{% endif %}{% for message in messages if message['content'] is not none %}{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
242
+ "clean_up_tokenization_spaces": false,
243
+ "eos_token": "<|im_end|>",
244
+ "errors": "replace",
245
+ "legacy": false,
246
+ "model_max_length": 4096,
247
+ "pad_token": "[PAD]",
248
+ "padding_side": "right",
249
+ "split_special_tokens": false,
250
+ "tokenizer_class": "Qwen2Tokenizer",
251
+ "unk_token": null
252
+ }
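The `chat_template` above is ChatML-style with a Japanese default system prompt ("以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。"). A rough sketch of how it expands, assuming the tokenizer under `llm/` loads as a standard Qwen2 tokenizer (the `subfolder` argument and the example message are illustrative, not part of this diff):

```python
from transformers import AutoTokenizer

# Hypothetical load path; the tokenizer files shown in this diff live under llm/.
tokenizer = AutoTokenizer.from_pretrained("turing-motors/Heron-NVILA-Lite-15B", subfolder="llm")

messages = [{"role": "user", "content": "こんにちは"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
# Expected form, per the template above:
# <|im_start|>system
# 以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。<|im_end|>
# <|im_start|>user
# こんにちは<|im_end|>
# <|im_start|>assistant
```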
llm/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
loss.py ADDED
@@ -0,0 +1,48 @@
+ from typing import List, Union
+
+ import torch
+ from torch.nn.functional import cross_entropy
+
+ from .constants import IGNORE_INDEX
+
+ __all__ = ["soft_cross_entropy"]
+
+
+ def soft_cross_entropy(
+     outputs: torch.Tensor,
+     targets: torch.Tensor,
+     soft_tokens: Union[torch.Tensor, List[int]],
+     std: float = 1,
+     ignore_index: int = IGNORE_INDEX,
+ ) -> torch.Tensor:
+     # Remove last token from outputs and first token from targets
+     outputs = outputs[..., :-1, :].contiguous()
+     targets = targets[..., 1:].contiguous()
+
+     # Flatten outputs and targets
+     targets = targets.view(-1)
+     outputs = outputs.view(targets.size(0), -1)
+
+     # Remove outputs and targets with ignore_index
+     indices = targets != ignore_index
+     outputs = outputs[indices]
+     targets = targets[indices]
+
+     # Convert soft token IDs to tensor
+     if isinstance(soft_tokens, list):
+         soft_tokens = torch.tensor(soft_tokens).to(targets)
+
+     # Calculate loss for non-soft tokens
+     indices = torch.isin(targets, soft_tokens, invert=True)
+     loss = cross_entropy(outputs[indices], targets[indices], reduction="sum")
+
+     # Calculate loss for soft tokens
+     indices = torch.isin(targets, soft_tokens)
+     targets_indices = torch.zeros_like(outputs[indices])
+     for k, target in enumerate(targets[indices]):
+         dist = torch.exp(-((target - soft_tokens) ** 2) / (2 * std**2))
+         targets_indices[k][soft_tokens] = dist / dist.sum()
+     loss += cross_entropy(outputs[indices], targets_indices, reduction="sum")
+
+     # Return average loss
+     return loss / targets.size(0)
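A toy invocation of `soft_cross_entropy` for orientation (the shapes, token IDs, and the assumption that `IGNORE_INDEX` is `-100` are illustrative only, not taken from this repository):

```python
import torch

vocab_size, seq_len = 32, 6
logits = torch.randn(1, seq_len, vocab_size)        # [batch, seq, vocab]
labels = torch.tensor([[5, 7, 20, 21, -100, 22]])   # -100 marks ignored positions
soft_tokens = list(range(20, 30))                   # hypothetical IDs that receive soft (Gaussian) targets

loss = soft_cross_entropy(logits, labels, soft_tokens, std=1.0, ignore_index=-100)
print(loss)  # scalar; after the internal shift, 7 is a hard target and 20/21/22 get soft targets
```

Tokens listed in `soft_tokens` are trained against a Gaussian distribution over neighbouring soft-token IDs instead of a one-hot target, which helps when adjacent IDs encode adjacent numeric values.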
main.py ADDED
File without changes
media.py ADDED
@@ -0,0 +1,129 @@
1
+ import glob
2
+ import os
3
+ from collections import defaultdict
4
+ from typing import Any, Dict, List, Optional, Union
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import PIL
9
+ import PIL.Image
10
+ import requests
11
+ from transformers import PretrainedConfig
12
+
13
+ # from llava.constants import MEDIA_TOKENS
14
+ # from llava.media import Image, Video
15
+ # from llava.utils import make_list
16
+ # from llava.utils.logging import logger
17
+
18
+ MEDIA_TOKENS = {
19
+ "image": "<image>",
20
+ "video": "<vila/video>",
21
+ }
22
+
23
+
24
+ class Media:
25
+ pass
26
+
27
+
28
+ class File(Media):
29
+ def __init__(self, path: str) -> None:
30
+ self.path = path
31
+
32
+
33
+ class Image(File):
34
+ pass
35
+
36
+
37
+ class Video(File):
38
+ pass
39
+
40
+
41
+ def make_list(obj: Any) -> List:
42
+ return obj if isinstance(obj, list) else [obj]
43
+
44
+
45
+ def _extract_image(image: Union[Image, PIL.Image.Image]) -> PIL.Image.Image:
46
+ if isinstance(image, Image):
47
+ if image.path.startswith("http://") or image.path.startswith("https://"):
48
+ image = PIL.Image.open(requests.get(image.path, stream=True).raw)
49
+ else:
50
+ image = PIL.Image.open(image.path)
51
+ return image
52
+
53
+
54
+ def _load_video(video_path: str, *, num_frames: int) -> List[PIL.Image.Image]:
55
+ # Load video frames from a directory
56
+ if os.path.isdir(video_path):
57
+ frame_paths = sorted(glob.glob(os.path.join(video_path, "*")))
58
+ indices = np.round(np.linspace(0, len(frame_paths) - 1, num_frames)).astype(int)
59
+ return [PIL.Image.open(frame_paths[index]) for index in indices]
60
+
61
+ # Load video frames from a video file
62
+ vidcap = cv2.VideoCapture(video_path)
63
+
64
+ # Find the last frame as frame count might not be accurate
65
+ frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
66
+ while frame_count > 0:
67
+ vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
68
+ if vidcap.grab():
69
+ break
70
+ frame_count -= 1
71
+ else:
72
+ raise ValueError(f"Video '{video_path}' has no frames.")
73
+
74
+ # Extract frames uniformly
75
+ indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int)
76
+ frames = {}
77
+ for index in indices:
78
+ if index in frames:
79
+ continue
80
+ vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
81
+ success, frame = vidcap.read()
82
+ if not success:
83
+ print(f"Failed to read frame {index} from video '{video_path}'. Skipped.")
84
+ continue
85
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
86
+ frames[index] = PIL.Image.fromarray(frame)
87
+ return [frames[index] for index in indices if index in frames]
88
+
89
+
90
+ def _extract_video(video: Video, config: PretrainedConfig) -> List[PIL.Image.Image]:
91
+ num_frames = config.num_video_frames
92
+ if getattr(config, "fps") != 0:
93
+ print("Extracting frames from video with specified FPS is not supported yet. Ignored.")
94
+
95
+ frames = _load_video(video.path, num_frames=num_frames)
96
+ return frames
97
+
98
+
99
+ def extract_media(
100
+ messages: List[Dict[str, Any]],
101
+ config: Optional[PretrainedConfig] = None,
102
+ draft: bool = False,
103
+ ) -> Dict[str, List[Any]]:
104
+ media = defaultdict(list)
105
+ for message in messages:
106
+ text = ""
107
+ for part in make_list(message["value"]):
108
+ if isinstance(part, str):
109
+ for token in MEDIA_TOKENS.values():
110
+ if token in part:
111
+ print(f"Media token '{token}' found in text: '{part}'. Removed.")
112
+ part = part.replace(token, "").strip()
113
+ text += part
114
+ elif isinstance(part, (Image, PIL.Image.Image)):
115
+ if draft:
116
+ media["image"].append(part)
117
+ else:
118
+ media["image"].append(_extract_image(part))
119
+ text += MEDIA_TOKENS["image"]
120
+ elif isinstance(part, Video):
121
+ if draft:
122
+ media["video"].append(part)
123
+ else:
124
+ media["video"].append(_extract_video(part, config))
125
+ text += MEDIA_TOKENS["video"]
126
+ else:
127
+ raise ValueError(f"Unsupported prompt part type: {type(part)}")
128
+ message["value"] = text
129
+ return media
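A rough usage sketch of `extract_media` (the message layout mirrors the `{"value": [...]}` structure the function iterates over; the image URL is just an example):

```python
import PIL.Image
import requests

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
pil_image = PIL.Image.open(requests.get(url, stream=True).raw)

messages = [{"value": [pil_image, "画像を説明してください。"]}]
media = extract_media(messages)

print(media["image"])        # [<PIL.Image.Image ...>]
print(messages[0]["value"])  # "<image>画像を説明してください。" -- text flattened, media token inserted
```

Note that the function mutates `messages` in place, replacing each `value` with a flat string in which every image or video part has been swapped for its media token.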
media_encoder.py ADDED
@@ -0,0 +1,101 @@
1
+ from functools import partial
2
+ from typing import Any, Dict, List, Optional
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+
8
+ class BaseEncoder(nn.Module):
9
+ def __init__(self, parent: nn.Module) -> None:
10
+ super().__init__()
11
+ self._parent = [parent]
12
+
13
+ @property
14
+ def parent(self) -> nn.Module:
15
+ return self._parent[0]
16
+
17
+
18
+ class BasicImageEncoder(BaseEncoder):
19
+ def __init__(
20
+ self,
21
+ parent: torch.nn.Module,
22
+ start_tokens: Optional[str] = None,
23
+ end_tokens: Optional[str] = "\n",
24
+ ) -> None:
25
+ super().__init__(parent)
26
+ self.start_tokens = start_tokens
27
+ self.end_tokens = end_tokens
28
+
29
+ def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
30
+ if tokens is None:
31
+ return None
32
+ token_ids = self.parent.tokenizer(tokens).input_ids
33
+ token_ids = torch.tensor(token_ids, device=self.parent.device)
34
+ return self.parent.llm.model.embed_tokens(token_ids)
35
+
36
+ def _process_features(
37
+ self,
38
+ features: torch.Tensor,
39
+ start_token_embeds: Optional[torch.Tensor],
40
+ end_token_embeds: Optional[torch.Tensor],
41
+ ) -> torch.Tensor:
42
+ if start_token_embeds is not None:
43
+ features = torch.cat([start_token_embeds, features], dim=0)
44
+ if end_token_embeds is not None:
45
+ features = torch.cat([features, end_token_embeds], dim=0)
46
+ return features
47
+
48
+ def forward(self, images: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
49
+ images = torch.stack(images, dim=0)
50
+ features = self.parent.encode_images(images, block_sizes=config.get("block_sizes"))
51
+ process_features = partial(
52
+ self._process_features,
53
+ start_token_embeds=self.embed_tokens(self.start_tokens),
54
+ end_token_embeds=self.embed_tokens(self.end_tokens),
55
+ )
56
+ return [process_features(f) for f in features]
57
+
58
+
59
+ class BasicVideoEncoder(BaseEncoder):
60
+ def __init__(
61
+ self,
62
+ parent: torch.nn.Module,
63
+ start_tokens: Optional[str] = None,
64
+ end_tokens: Optional[str] = "\n",
65
+ ) -> None:
66
+ super().__init__(parent)
67
+ self.start_tokens = start_tokens
68
+ self.end_tokens = end_tokens
69
+
70
+ def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
71
+ if tokens is None:
72
+ return None
73
+ token_ids = self.parent.tokenizer(tokens).input_ids
74
+ token_ids = torch.tensor(token_ids, device=self.parent.device)
75
+ return self.parent.llm.model.embed_tokens(token_ids)
76
+
77
+ def _process_features(
78
+ self,
79
+ features: torch.Tensor,
80
+ start_token_embeds: Optional[torch.Tensor],
81
+ end_token_embeds: Optional[torch.Tensor],
82
+ ) -> torch.Tensor:
83
+ if start_token_embeds is not None:
84
+ start_embeds = torch.stack([start_token_embeds] * features.shape[0], dim=0)
85
+ features = torch.cat([start_embeds, features], dim=1)
86
+ if end_token_embeds is not None:
87
+ end_embeds = torch.stack([end_token_embeds] * features.shape[0], dim=0)
88
+ features = torch.cat([features, end_embeds], dim=1)
89
+ return features.flatten(0, 1)
90
+
91
+ def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
92
+ num_frames = [video.shape[0] for video in videos]
93
+ images = torch.cat(videos, dim=0)
94
+ features = self.parent.encode_images(images)
95
+ features = torch.split(features, num_frames)
96
+ process_features = partial(
97
+ self._process_features,
98
+ start_token_embeds=self.embed_tokens(self.start_tokens),
99
+ end_token_embeds=self.embed_tokens(self.end_tokens),
100
+ )
101
+ return [process_features(f) for f in features]
mm_projector/config.json ADDED
@@ -0,0 +1,10 @@
+ {
+ "_name_or_path": "runs/train/NVILA-Lite_14b_siglip_aws_env2_obelics_ja/sft_14b_GPT4_v6/model/mm_projector",
+ "architectures": [
+ "MultimodalProjector"
+ ],
+ "mm_projector_type": "mlp_downsample_3x3_fix",
+ "model_type": "v2l_projector",
+ "torch_dtype": "bfloat16",
+ "transformers_version": "4.45.0"
+ }
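The `mm_projector_type` of `mlp_downsample_3x3_fix` names an NVILA-Lite style projector that spatially downsamples the vision tokens before projecting them into the LLM embedding space. The snippet below is only a conceptual sketch of that idea, not the `MultimodalProjector` implementation shipped in `base_projector.py`, and the dimensions are placeholders:

```python
import torch
import torch.nn as nn
from einops import rearrange

class Downsample3x3MLPSketch(nn.Module):
    """Concatenate each 3x3 neighbourhood of vision tokens (~9x fewer tokens), then project with an MLP."""

    def __init__(self, vision_dim: int = 1152, llm_dim: int = 5120):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(vision_dim * 9, llm_dim), nn.GELU(), nn.Linear(llm_dim, llm_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: [batch, h*w, vision_dim], h and w divisible by 3
        h = w = int(x.shape[1] ** 0.5)
        x = rearrange(x, "b (h w) d -> b h w d", h=h, w=w)
        x = rearrange(x, "b (h p1) (w p2) d -> b (h w) (p1 p2 d)", p1=3, p2=3)
        return self.mlp(x)
```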
mm_projector/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:117299ba5a53f595969b56b45068c3974b5ce68214bbb12b91518f87448252e1
+ size 159565424
mm_utils.py ADDED
@@ -0,0 +1,572 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ # dynamic_preprocess and find_closest_aspect_ratio are referenced from https://github.com/OpenGVLab/InternVL
18
+
19
+ import base64
20
+ import os
21
+ import tempfile
22
+ from io import BytesIO
23
+
24
+ import numpy as np
25
+ import torch
26
+ from PIL import Image
27
+ from transformers import StoppingCriteria
28
+
29
+ from .constants import DEFAULT_IMAGE_TOKEN
30
+
31
+
32
+ def get_frame_from_vcap(vidcap, num_frames=10, max_fps=0.0, fps=None, frame_count=None, video_file_name=None):
33
+ import cv2
34
+
35
+ if fps == None or frame_count == None:
36
+ # if one of fps or frame_count is None, still recompute
37
+ fps = vidcap.get(cv2.CAP_PROP_FPS)
38
+ frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
39
+ if fps == 0 or frame_count == 0:
40
+ print(f"Video file not found. return empty images. {video_file_name}")
41
+ return [
42
+ Image.new("RGB", (720, 720)),
43
+ ] * num_frames, 0
44
+
45
+ duration = frame_count / fps
46
+ frame_interval = frame_count // num_frames
47
+ if frame_interval == 0 and frame_count <= 1:
48
+ print(f"frame_interval is equal to 0. return empty image. {video_file_name}")
49
+ return [
50
+ Image.new("RGB", (720, 720)),
51
+ ] * num_frames, 0
52
+ # print("duration:", duration, "frames:", frame_count, "intervals:", frame_interval)
53
+
54
+ images = []
55
+ count = 0
56
+ success = True
57
+ frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
58
+ while success:
59
+ # print("frame_count:", frame_count, "count:", count, "num_frames:", num_frames, "frame_interval:", frame_interval)
60
+ if frame_count >= num_frames:
61
+ success, frame = vidcap.read()
62
+ if count in frame_indices:
63
+ try:
64
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
65
+ im_pil = Image.fromarray(img)
66
+ images.append(im_pil)
67
+ except BaseException:
68
+ continue
69
+ if len(images) >= num_frames:
70
+ return images, num_frames
71
+ count += 1
72
+ else:
73
+ # Left padding frames if the video is not long enough
74
+ success, frame = vidcap.read()
75
+ if success:
76
+ try:
77
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
78
+ im_pil = Image.fromarray(img)
79
+ images.append(im_pil)
80
+ except BaseException:
81
+ continue
82
+ count += 1
83
+ else:
84
+ break
85
+ if len(images) == 0:
86
+ raise ValueError("Did not find enough frames in the video. return empty image.")
87
+
88
+ return images, len(images)
89
+
90
+
91
+ def get_frame_from_vcap_with_fps(vidcap, num_frames=10, max_fps=0.0, fps=None, frame_count=None, video_file_name=None):
92
+ """
93
+ num_frames is the max number of frames the model can support.
94
+ frame_count is the number of frames in the input video.
95
+ max_fps is the max FPS of the model can support.
96
+ fps is the fps of the input video.
97
+ """
98
+
99
+ import random
100
+
101
+ import cv2
102
+
103
+ if fps == None or frame_count == None:
104
+ # if one of fps or frame_count is None, still recompute
105
+ fps = vidcap.get(cv2.CAP_PROP_FPS)
106
+ frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
107
+
108
+ if fps == 0 or frame_count == 0:
109
+ print(f"Video file not found. return empty images. {video_file_name}")
110
+ empty_video_frames = int(random.uniform(2, 8 * max_fps))
111
+ return [
112
+ Image.new("RGB", (720, 720)),
113
+ ] * empty_video_frames, 0
114
+
115
+ duration = frame_count / fps
116
+ # print("duration:", duration, "frames:", frame_count, "fps:", fps, "num_frames:", num_frames, "max_fps:", max_fps)
117
+ # If the video is too long (longer than max_fps and num_frames can support),
118
+ # we will use lower fps to sample frames.
119
+ if duration >= num_frames / max_fps:
120
+ frame_interval = frame_count // num_frames
121
+
122
+ # If the video is too short, we will skip the video if there is only one frame.
123
+ if frame_interval == 0 and frame_count <= 1:
124
+ print(f"frame_interval is equal to 0. return empty image. {video_file_name}")
125
+ empty_video_frames = int(random.uniform(2, 8 * max_fps))
126
+ return [
127
+ Image.new("RGB", (720, 720)),
128
+ ] * empty_video_frames, 0
129
+
130
+ images = []
131
+ count = 0
132
+ success = True
133
+ frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
134
+
135
+ while success:
136
+ if frame_count >= num_frames:
137
+ # success, frame = vidcap.read()
138
+ if count in frame_indices:
139
+ success, frame = vidcap.read()
140
+ try:
141
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
142
+ im_pil = Image.fromarray(img)
143
+ images.append(im_pil)
144
+ except:
145
+ # print("Failed to read frame:", count)
146
+ continue
147
+ if len(images) >= num_frames:
148
+ return images, num_frames
149
+ else:
150
+ success = vidcap.grab()
151
+ count += 1
152
+ else:
153
+ # Left padding frames if the video is not long enough
154
+ success, frame = vidcap.read()
155
+ if success:
156
+ try:
157
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
158
+ im_pil = Image.fromarray(img)
159
+ images.append(im_pil)
160
+ except:
161
+ # print("Failed to read frame:", count)
162
+ continue
163
+ count += 1
164
+ else:
165
+ break
166
+ else:
167
+ frames_required = int(duration * max_fps)
168
+ frame_indices = np.linspace(0, frame_count - 1, frames_required, dtype=int)
169
+ if frames_required == 0:
170
+ print(f"frames_required is fewer than 2. Duration {duration}, return empty image.")
171
+ empty_video_frames = int(random.uniform(2, 8 * max_fps))
172
+ return [
173
+ Image.new("RGB", (720, 720)),
174
+ ] * empty_video_frames, 0
175
+ elif frames_required == 1:
176
+ frame_indices = np.linspace(0, frame_count - 1, 2, dtype=int)
177
+ images = []
178
+ count = 0
179
+ looked = 0
180
+ success = True
181
+
182
+ while success:
183
+ success, frame = vidcap.read()
184
+ if success and (looked in frame_indices):
185
+ try:
186
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
187
+ im_pil = Image.fromarray(img)
188
+ images.append(im_pil)
189
+ except:
190
+ continue
191
+ count += 1
192
+ looked += 1
193
+
194
+ if len(images) == 0:
195
+ empty_video_frames = int(random.uniform(2, 8 * max_fps))
196
+ return [
197
+ Image.new("RGB", (720, 720)),
198
+ ] * empty_video_frames, 0
199
+ else:
200
+ return images, len(images)
201
+
202
+
203
+ def opencv_extract_frames(vpath_or_bytesio, frames=6, max_fps=0.0, fps=None, frame_count=None):
204
+ """
205
+ Extract frames from a video using OpenCV.
206
+
207
+ Args:
208
+ vpath_or_bytesio (str or BytesIO): Path to the video file or BytesIO object containing the video.
209
+ frames (int): Number of frames to extract from the video.
210
+ fps (float): Frames per second of the video. If 0.0, the function will extract frames at equal intervals.
211
+
212
+ Returns:
213
+ list: List of PIL Images extracted from the video.
214
+
215
+ Raises:
216
+ NotImplementedError: If the type of `vpath_or_bytesio` is not supported.
217
+ """
218
+ import cv2
219
+
220
+ if isinstance(vpath_or_bytesio, str):
221
+ vidcap = cv2.VideoCapture(vpath_or_bytesio)
222
+ if max_fps > 0.0:
223
+ return get_frame_from_vcap_with_fps(
224
+ vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=vpath_or_bytesio
225
+ )
226
+ return get_frame_from_vcap(
227
+ vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=vpath_or_bytesio
228
+ )
229
+ elif isinstance(vpath_or_bytesio, (BytesIO,)):
230
+ # assuming mp4
231
+ with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video:
232
+ temp_video.write(vpath_or_bytesio.read())
233
+ temp_video_name = temp_video.name
234
+ vidcap = cv2.VideoCapture(temp_video_name)
235
+ if max_fps > 0.0:
236
+ return get_frame_from_vcap_with_fps(
237
+ vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=temp_video_name
238
+ )
239
+ return get_frame_from_vcap(
240
+ vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=temp_video_name
241
+ )
242
+ else:
243
+ raise NotImplementedError(type(vpath_or_bytesio))
244
+
245
+
246
+ def load_image_from_base64(image):
247
+ return Image.open(BytesIO(base64.b64decode(image)))
248
+
249
+
250
+ def expand2square(pil_img, background_color):
251
+ """
252
+ Expand the given PIL image to a square shape by adding padding.
253
+
254
+ Parameters:
255
+ - pil_img: The PIL image to be expanded.
256
+ - background_color: The color of the padding to be added.
257
+
258
+ Returns:
259
+ - The expanded PIL image.
260
+
261
+ If the image is already square, it is returned as is.
262
+ If the image is wider than it is tall, padding is added to the top and bottom.
263
+ If the image is taller than it is wide, padding is added to the left and right.
264
+ """
265
+ width, height = pil_img.size
266
+ if pil_img.mode == "L":
267
+ background_color = background_color[0]
268
+ if width == height:
269
+ return pil_img
270
+ elif width > height:
271
+ result = Image.new(pil_img.mode, (width, width), background_color)
272
+ result.paste(pil_img, (0, (width - height) // 2))
273
+ return result
274
+ else:
275
+ result = Image.new(pil_img.mode, (height, height), background_color)
276
+ result.paste(pil_img, ((height - width) // 2, 0))
277
+ return result
278
+
279
+
280
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
281
+ best_ratio_diff = float("inf")
282
+ best_ratio = (1, 1)
283
+ area = width * height
284
+ for ratio in target_ratios:
285
+ target_aspect_ratio = ratio[0] / ratio[1]
286
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
287
+ if ratio_diff < best_ratio_diff:
288
+ best_ratio_diff = ratio_diff
289
+ best_ratio = ratio
290
+ elif ratio_diff == best_ratio_diff:
291
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
292
+ best_ratio = ratio
293
+ return best_ratio
294
+
295
+
296
+ def dynamic_preprocess(image, min_num=1, max_num=12, image_size=384, use_thumbnail=True):
297
+ orig_width, orig_height = image.size
298
+ aspect_ratio = orig_width / orig_height
299
+
300
+ # calculate the existing image aspect ratio
301
+ target_ratios = {
302
+ (i, j)
303
+ for n in range(min_num, max_num + 1)
304
+ for i in range(1, n + 1)
305
+ for j in range(1, n + 1)
306
+ if i * j <= max_num and i * j >= min_num
307
+ }
308
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
309
+
310
+ # find the closest aspect ratio to the target
311
+ target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)
312
+
313
+ # calculate the target width and height
314
+ target_width = image_size * target_aspect_ratio[0]
315
+ target_height = image_size * target_aspect_ratio[1]
316
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
317
+
318
+ # resize the image
319
+ resized_img = image.resize((target_width, target_height))
320
+ processed_images = []
321
+ for i in range(blocks):
322
+ box = (
323
+ (i % (target_width // image_size)) * image_size,
324
+ (i // (target_width // image_size)) * image_size,
325
+ ((i % (target_width // image_size)) + 1) * image_size,
326
+ ((i // (target_width // image_size)) + 1) * image_size,
327
+ )
328
+ # split the image
329
+ split_img = resized_img.crop(box)
330
+ processed_images.append(split_img)
331
+ assert len(processed_images) == blocks
332
+ if use_thumbnail and len(processed_images) != 1:
333
+ thumbnail_img = image.resize((image_size, image_size))
334
+ processed_images.append(thumbnail_img)
335
+ return processed_images
336
+
337
+
338
+ def dynamic_s2_preprocess(image, s2_scales=[384, 768, 1152], max_num=12, image_size=384):
339
+ orig_width, orig_height = image.size
340
+ aspect_ratio = orig_width / orig_height
341
+ min_num = (s2_scales[-1] // s2_scales[0]) ** 2 # at least use number of tiles as the largest scale
342
+
343
+ processed_images = []
344
+
345
+ ##########################################################################################
346
+ ############# Add tiles for all but the last scale using fixed squre ratio ###############
347
+ ##########################################################################################
348
+
349
+ for scale in s2_scales[:-1]:
350
+ target_width = image_size * (scale // s2_scales[0])
351
+ target_height = image_size * (scale // s2_scales[0])
352
+ blocks = (scale // s2_scales[0]) ** 2
353
+
354
+ # resize the image
355
+ resized_img = image.resize((target_width, target_height))
356
+ for i in range(blocks):
357
+ box = (
358
+ (i % (target_width // image_size)) * image_size,
359
+ (i // (target_width // image_size)) * image_size,
360
+ ((i % (target_width // image_size)) + 1) * image_size,
361
+ ((i // (target_width // image_size)) + 1) * image_size,
362
+ )
363
+ # split the image
364
+ split_img = resized_img.crop(box)
365
+ processed_images.append(split_img)
366
+
367
+ ##########################################################################################
368
+ ################ Add tiles for the last scale using dynamic aspect ratio #################
369
+ ##########################################################################################
370
+
371
+ # calculate the existing image aspect ratio
372
+ target_ratios = {
373
+ (i, j)
374
+ for n in range(min_num, max_num + 1)
375
+ for i in range(1, n + 1)
376
+ for j in range(1, n + 1)
377
+ if i * j <= max_num and i * j >= min_num
378
+ }
379
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
380
+
381
+ # find the closest aspect ratio to the target
382
+ target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)
383
+
384
+ # calculate the target width and height
385
+ target_width = image_size * target_aspect_ratio[0]
386
+ target_height = image_size * target_aspect_ratio[1]
387
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
388
+
389
+ # resize the image
390
+ resized_img = image.resize((target_width, target_height))
391
+ for i in range(blocks):
392
+ box = (
393
+ (i % (target_width // image_size)) * image_size,
394
+ (i // (target_width // image_size)) * image_size,
395
+ ((i % (target_width // image_size)) + 1) * image_size,
396
+ ((i // (target_width // image_size)) + 1) * image_size,
397
+ )
398
+ # split the image
399
+ split_img = resized_img.crop(box)
400
+ processed_images.append(split_img)
401
+
402
+ return processed_images, (target_aspect_ratio[1], target_aspect_ratio[0])
403
+
404
+
405
+ def dynamic_process_images_and_prompt(images, prompt, data_args, image_folder=None, max_tiles=None):
406
+ prompt = prompt.split(DEFAULT_IMAGE_TOKEN)
407
+ idx = 0
408
+ all_images = []
409
+ for img in images:
410
+ processed_images = process_image(img, data_args, image_folder, enable_dynamic_res=True, max_tiles=max_tiles)
411
+ all_images.append(processed_images)
412
+ prompt.insert(idx + 1, f"{DEFAULT_IMAGE_TOKEN}\n" * processed_images.shape[0])
413
+ idx += 2
414
+ prompt = "".join(prompt)
415
+ if all_images:
416
+ all_images = torch.cat(all_images)
417
+ else:
418
+ all_images = None
419
+ prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, "")
420
+ return all_images, prompt
421
+
422
+
423
+ def dynamic_s2_process_images_and_prompt(images, prompt, data_args, image_folder=None):
424
+ idx = 0
425
+ all_images = []
426
+ all_block_size = []
427
+ for img in images:
428
+ processed_images, block_size = process_image(img, data_args, image_folder, enable_dynamic_s2=True)
429
+ all_images.append(processed_images)
430
+ all_block_size.append(block_size)
431
+ idx += 2
432
+ if all_images:
433
+ all_images = torch.cat(all_images)
434
+ else:
435
+ all_images = None
436
+ return all_images, all_block_size
437
+
438
+
439
+ def process_image(
440
+ image_file, data_args, image_folder, enable_dynamic_res=False, enable_dynamic_s2=False, max_tiles=None
441
+ ):
442
+ processor = data_args.image_processor
443
+ if isinstance(image_file, str):
444
+ if image_folder is not None:
445
+ image = Image.open(os.path.join(image_folder, image_file)).convert("RGB")
446
+ else:
447
+ image = Image.open(image_file).convert("RGB")
448
+ else:
449
+ # image is stored in bytearray
450
+ image = image_file
451
+ image = image.convert("RGB")
452
+ if hasattr(data_args.image_processor, "crop_size"):
453
+ # CLIP vision tower
454
+ crop_size = data_args.image_processor.crop_size
455
+ else:
456
+ # SIGLIP vision tower
457
+ assert hasattr(data_args.image_processor, "size")
458
+ crop_size = data_args.image_processor.size
459
+ if "dynamic_s2" in data_args.image_aspect_ratio and enable_dynamic_s2:
460
+ assert crop_size["height"] == crop_size["width"]
461
+ images, block_size = dynamic_s2_preprocess(
462
+ image, s2_scales=data_args.s2_scales, max_num=data_args.max_tiles, image_size=crop_size["height"]
463
+ )
464
+ images = [processor.preprocess(image, return_tensors="pt")["pixel_values"][0] for image in images]
465
+ return torch.stack(images), block_size
466
+ if "dynamic" in data_args.image_aspect_ratio and enable_dynamic_res:
467
+ assert crop_size["height"] == crop_size["width"]
468
+ if max_tiles is not None:
469
+ max_num = max_tiles
470
+ else:
471
+ max_num = data_args.max_tiles
472
+ images = dynamic_preprocess(image, min_num=data_args.min_tiles, max_num=max_num, image_size=crop_size["height"])
473
+ images = [processor.preprocess(image, return_tensors="pt")["pixel_values"][0] for image in images]
474
+ return torch.stack(images)
475
+
476
+ if data_args.image_aspect_ratio == "resize":
477
+ image = image.resize((crop_size["width"], crop_size["height"]))
478
+ if data_args.image_aspect_ratio == "pad":
479
+
480
+ def expand2square(pil_img, background_color):
481
+ width, height = pil_img.size
482
+ if width == height:
483
+ return pil_img
484
+ elif width > height:
485
+ result = Image.new(pil_img.mode, (width, width), background_color)
486
+ result.paste(pil_img, (0, (width - height) // 2))
487
+ return result
488
+ else:
489
+ result = Image.new(pil_img.mode, (height, height), background_color)
490
+ result.paste(pil_img, ((height - width) // 2, 0))
491
+ return result
492
+
493
+ image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
494
+ image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
495
+ else:
496
+ # Using default behavior of the vision encoder
497
+ # For CLIP, default is central crop
498
+ # For Radio, default is central crop
499
+ # For Siglip, default is resize
500
+ # For InternVIT, default is resize
501
+ image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
502
+ return image
503
+
504
+
505
+ def process_images(images, image_processor, model_cfg, enable_dynamic_res=False, max_tiles=None):
506
+ model_cfg.image_processor = image_processor
507
+ new_images = [
508
+ process_image(image, model_cfg, None, enable_dynamic_res=enable_dynamic_res, max_tiles=max_tiles)
509
+ for image in images
510
+ ]
511
+
512
+ if all(x.shape == new_images[0].shape for x in new_images):
513
+ if len(new_images[0].shape) == 4:
514
+ new_images = torch.cat(new_images, dim=0)
515
+ elif len(new_images[0].shape) == 3:
516
+ new_images = torch.stack(new_images, dim=0)
517
+ else:
518
+ raise ValueError(f"new_images rank does not equal to 4, rank: {len(new_images[0].shape)}")
519
+ else:
520
+ raise ValueError("The shape of images in new_images is different!")
521
+ return new_images
522
+
523
+
524
+ def tokenizer_image_token(prompt, tokenizer, return_tensors=None):
525
+ return tokenizer(prompt, return_tensors=return_tensors).input_ids[0]
526
+
527
+
528
+ def is_gemma_tokenizer(tokenizer):
529
+ return "gemma" in tokenizer.__class__.__name__.lower()
530
+
531
+
532
+ def get_model_name_from_path(model_path):
533
+ model_path = model_path.strip("/")
534
+ model_paths = model_path.split("/")
535
+ if model_paths[-1].startswith("checkpoint-"):
536
+ return model_paths[-2] + "_" + model_paths[-1]
537
+ else:
538
+ return model_paths[-1]
539
+
540
+
541
+ class KeywordsStoppingCriteria(StoppingCriteria):
542
+ def __init__(self, keywords, tokenizer, input_ids):
543
+ self.keywords = keywords
544
+ self.keyword_ids = []
545
+ self.max_keyword_len = 0
546
+ for keyword in keywords:
547
+ cur_keyword_ids = tokenizer(keyword).input_ids
548
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
549
+ cur_keyword_ids = cur_keyword_ids[1:]
550
+ if len(cur_keyword_ids) > self.max_keyword_len:
551
+ self.max_keyword_len = len(cur_keyword_ids)
552
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
553
+ self.tokenizer = tokenizer
554
+ self.start_len = input_ids.shape[1]
555
+
556
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
557
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
558
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
559
+ for keyword_id in self.keyword_ids:
560
+ if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all():
561
+ return True
562
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
563
+ for keyword in self.keywords:
564
+ if keyword in outputs:
565
+ return True
566
+ return False
567
+
568
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
569
+ outputs = []
570
+ for i in range(output_ids.shape[0]):
571
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
572
+ return all(outputs)
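For orientation, a self-contained toy use of `KeywordsStoppingCriteria` with `generate()` (the tiny public GPT-2 checkpoint and the `"###"` keyword are purely illustrative; in this repository the criteria would be built from the VILA tokenizer and prompt):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
lm = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

input_ids = tok("Hello", return_tensors="pt").input_ids
criteria = KeywordsStoppingCriteria(["###"], tok, input_ids)

out = lm.generate(
    input_ids,
    max_new_tokens=16,
    stopping_criteria=StoppingCriteriaList([criteria]),
)
print(tok.decode(out[0]))
```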
model_utils_packing.py ADDED
@@ -0,0 +1,35 @@
+ from importlib import import_module
+ from typing import Tuple
+
+ import torch
+ import transformers
+ from torch import nn
+ from torch.nn import functional as F
+
+ __all__ = ["patch"]
+
+
+ def _get_unpad_data(attention_mask: torch.Tensor, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, int]:
+     if hasattr(_get_unpad_data, "seqlens_in_batch"):
+         seqlens_in_batch = _get_unpad_data.seqlens_in_batch
+     else:
+         seqlens_in_batch = torch.sum(attention_mask, dim=1)
+
+     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+     max_seqlen_in_batch = seqlens_in_batch.max().item()
+     cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+     return indices, cu_seqlens, max_seqlen_in_batch
+
+
+ def set_seqlens_in_batch(seqlens_in_batch: torch.Tensor) -> None:
+     _get_unpad_data.seqlens_in_batch = seqlens_in_batch
+
+
+ def patch(model: nn.Module) -> None:
+     if transformers.__version__ < "4.43.0":
+         m = import_module(model.__module__)
+         if not hasattr(m, "_get_unpad_data"):
+             raise ValueError(f"Module {m} does not have function '_get_unpad_data' for packing")
+         m._get_unpad_data = _get_unpad_data
+     else:
+         transformers.modeling_flash_attention_utils._get_unpad_data = _get_unpad_data
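A sketch of how this packing patch is meant to be wired up (the `model` object and the sequence lengths below are placeholders; `patch` reroutes transformers' `_get_unpad_data` so that attention unpadding follows the per-sequence lengths registered via `set_seqlens_in_batch`):

```python
import torch

# `model` is assumed to be an already-loaded, flash-attention-enabled Hugging Face model.
patch(model)

# Register the lengths of the sequences packed into the current batch before the forward pass.
set_seqlens_in_batch(torch.tensor([128, 256, 64], dtype=torch.int32))
```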
modeling_vila.py ADDED
@@ -0,0 +1,1228 @@
1
+ import copy
2
+ import json
3
+ import logging
4
+ import math
5
+ import os
6
+ import os.path
7
+ import os.path as osp
8
+ import shutil
9
+ import warnings
10
+ from abc import ABC
11
+ from collections import OrderedDict, defaultdict, deque
12
+ from copy import deepcopy
13
+ from itertools import chain
14
+ from threading import Thread
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.distributed as dist
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import torchvision
22
+ from einops import rearrange
23
+ from PIL import Image
24
+ from transformers import (
25
+ AutoConfig,
26
+ AutoModel,
27
+ AutoProcessor,
28
+ AutoTokenizer,
29
+ GenerationConfig,
30
+ LogitsProcessor,
31
+ PretrainedConfig,
32
+ PreTrainedModel,
33
+ Qwen2Config,
34
+ Qwen2ForCausalLM,
35
+ Qwen2PreTrainedModel,
36
+ TextIteratorStreamer,
37
+ )
38
+ from transformers.modeling_outputs import CausalLMOutputWithPast
39
+ from transformers.modeling_utils import ContextManagers, no_init_weights
40
+
41
+ from .auto_processor import VILAProcessor
42
+ from .base_projector import MultimodalProjector, MultimodalProjectorConfig
43
+ from .builder import build_llm_and_tokenizer
44
+ from .configuration_vila import VILAConfig
45
+ from .constants import *
46
+ from .conversation import SeparatorStyle, default_conversation
47
+ from .distributed import all_gather as vila_all_gather
48
+ from .loss import soft_cross_entropy
49
+ from .media import extract_media
50
+ from .media_encoder import BasicImageEncoder, BasicVideoEncoder
51
+ from .mm_utils import process_image, process_images
52
+ from .model_utils_packing import set_seqlens_in_batch
53
+ from .siglip_encoder import SiglipVisionTower, SiglipVisionTowerDynamicS2, SiglipVisionTowerS2
54
+ from .tokenizer_utils import tokenize_conversation
55
+ from .utils import get_model_config, load_tokenizer_then_handle_media_tokens_and_chat_template
56
+
57
+ # from llava.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, NUM_EXTRA_TOKENS
58
+
59
+ # ease debugging
60
+ python_input = input
61
+
62
+ # quick hack for remote code
63
+ def get_pg_manager():
64
+ return None
65
+
66
+
67
+ def get_model_weights_dtype(model: nn.Module):
68
+ pass
69
+
70
+
71
+ def build_mm_projector(model_type_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
72
+ if model_type_or_path is None:
73
+ return None
74
+ ## load from pretrained model
75
+ if config.resume_path:
76
+ assert os.path.exists(model_type_or_path), f"Resume mm projector path {model_type_or_path} does not exist!"
77
+ return MultimodalProjector.from_pretrained(model_type_or_path, config)
78
+ ## build from scratch
79
+ else:
80
+ mm_projector_cfg = MultimodalProjectorConfig(model_type_or_path)
81
+ mm_projector = MultimodalProjector(mm_projector_cfg, config)
82
+ return mm_projector
83
+
84
+
85
+ def check_dot_in_model_path(model_path: str):
86
+ """Check if the model path contains dot, which will affect the remote code loading."""
87
+ if osp.isdir(model_path): # local model
88
+ if "." in osp.abspath(model_path):
89
+ return True
90
+ else: # remote model
91
+ if "." in model_path:
92
+ return True
93
+ return False
94
+
95
+
96
+ def get_vila_version(model_path: str) -> str:
97
+ VERSIONS = ["vila1.5", "vila-u", "longvila", "nvila", "vila-m3"]
98
+ for version in VERSIONS:
99
+ if version in model_path.lower():
100
+ return version
101
+ return None
102
+
103
+
104
+ def generate_jinja_template(conv_mode: str) -> str:
105
+ if conv_mode == "vicuna_v1":
106
+ return """{% set system_prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. " %}
107
+ {% set roles = ["user", "assistant"] %}
108
+ {% set sep = " " %}
109
+
110
+ {{ system_prompt }}
111
+
112
+ {% for message in messages %}
113
+ {% if message['role'] == roles[0] %}
114
+ {{ "USER: " }}{{ sep }}{{ message['content'] }}{{ sep }}
115
+ {% else %}
116
+ {{ "ASSISTANT: " }}{{ sep }}{{ message['content'] }}{{ sep }}
117
+ {% endif %}
118
+ {% endfor %}
119
+ {% if messages[-1]['role'] == 'user' %}
120
+ {{ "ASSISTANT:" }}
121
+ {% endif %}
122
+ """
123
+ elif conv_mode == "llama_3":
124
+ return """{% set system_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.<|eot_id|>" %}
125
+ {% set roles = ["<|start_header_id|>user<|end_header_id|>\\n\\n", "<|start_header_id|>assistant<|end_header_id|>\\n\\n"]%}
126
+ {% set sep = "<|eot_id|>" %}
127
+
128
+ {{ system_prompt }}
129
+ {% for message in messages %}
130
+ {% if message['role'] == 'user' %}
131
+ {{ roles[0] }}{{ message['content'] }}{{ sep }}
132
+ {% else %}
133
+ {{ roles[1] }}{{ message['content'] }}{{ sep }}
134
+ {% endif %}
135
+ {% endfor %}
136
+ {% if messages[-1]['role'] == 'user' %}
137
+ {{ roles[1] }}
138
+ {% endif %}
139
+ """
140
+ elif conv_mode == "hermes_2":
141
+ return """{% set system_prompt = "<|im_start|>system\nAnswer the questions." %}
142
+ {% set roles = ["<|im_start|>user\n", "<|im_start|>assistant\n"] %}
143
+ {% set sep = "<|im_end|>" %}
144
+
145
+ {{ system_prompt }}{{ sep }}
146
+
147
+ {% for message in messages %}
148
+ {% if message['role'] == 'user' %}
149
+ {{ roles[0] }}{{ message['content'] }}{{ sep }}
150
+ {% else %}
151
+ {{ roles[1] }}{{ message['content'] }}{{ sep }}
152
+ {% endif %}
153
+ {% endfor %}"""
154
+ else:
155
+ raise NotImplementedError(f"Jinja template generation is not implemented for {conv_mode}.")
156
+
157
+
158
+ def build_vision_tower(model_name_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
159
+ ## skip vision tower instantiation
160
+ if model_name_or_path is None:
161
+ return None
162
+
163
+ vision_tower_arch = None
164
+ if config.resume_path and "radio" not in model_name_or_path:
165
+ assert os.path.exists(model_name_or_path), f"Resume vision tower path {model_name_or_path} does not exist!"
166
+ vision_tower_cfg = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
167
+ vision_tower_arch = vision_tower_cfg.architectures[0].lower()
168
+ vision_tower_name = vision_tower_arch if vision_tower_arch is not None else model_name_or_path
169
+
170
+ use_s2 = getattr(config, "s2", False)
171
+ use_dynamic_s2 = getattr(config, "dynamic_s2", False)
172
+
173
+ if "siglip" in vision_tower_name:
174
+ if use_dynamic_s2:
175
+ vision_tower = SiglipVisionTowerDynamicS2(model_name_or_path, config)
176
+ elif use_s2:
177
+ vision_tower = SiglipVisionTowerS2(model_name_or_path, config)
178
+ else:
179
+ vision_tower = SiglipVisionTower(model_name_or_path, config)
180
+ else:
181
+ raise NotImplementedError(f"Unknown vision tower: {model_name_or_path}")
182
+
183
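+ ## S2 / dynamic-S2 towers concatenate features from several scales, so use the tower's own hidden_size (per-scale size x number of scales) instead of the backbone config value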
+ config.mm_hidden_size = (
184
+ vision_tower.config.hidden_size if not (use_s2 or use_dynamic_s2) else vision_tower.hidden_size
185
+ )
186
+ return vision_tower
187
+
188
+
189
+ class VILAPretrainedModel(PreTrainedModel):
190
+ config_class = VILAConfig
191
+ main_input_name = "input_embeds"
192
+ supports_gradient_checkpointing = True
193
+ _supports_flash_attn_2 = True
194
+
195
+ def __init__(self, config: VILAConfig, *args, **kwargs):
196
+ super().__init__(config)
197
+ self.config = config
198
+ cfgs = get_model_config(config)
199
+ if len(cfgs) == 3:
200
+ llm_cfg, vision_tower_cfg, mm_projector_cfg = cfgs
201
+ else:
202
+ raise ValueError("`llm_cfg` `mm_projector_cfg` `vision_tower_cfg` not found in the config.")
203
+
204
+ # loading on cpu by default
205
+ device_map = kwargs.get("device_map", "cpu")
206
+ self.mm_projector = build_mm_projector(mm_projector_cfg, config)
207
+ self.vision_tower = build_vision_tower(vision_tower_cfg, config)
208
+ if "auto" in device_map or "cuda" in device_map:
209
+ self.mm_projector = self.mm_projector.cuda()
210
+ self.vision_tower = self.vision_tower.cuda()
211
+ # setting device_map="auto" automatically shards the LLM across available devices
212
+ self.llm, self.tokenizer = self.init_llm(llm_cfg, config, device_map=device_map)
213
+
214
+ # NOTE(ligeng): need to add other decoders from config
215
+ self.encoders = {"image": BasicImageEncoder(self), "video": BasicVideoEncoder(self)}
216
+
217
+ self.post_config()
218
+ self.is_loaded = True
219
+
220
+ assert (
221
+ self.llm is not None or self.vision_tower is not None or self.mm_projector is not None
222
+ ), "At least one of the components must be instantiated."
223
+
224
+ @classmethod
225
+ def convert_vila_dev_ckpt_to_remote(
226
+ self,
227
+ model_path: str,
228
+ output_dir: str = None,
229
+ vila_version: str | None = None,
230
+ conv_mode: str | None = None,
231
+ copy: bool = False,
232
+ copy_weights: bool = True,
233
+ copy_code: bool = True,
234
+ *model_args,
235
+ **kwargs,
236
+ ):
237
+ # assert type(self) == VILAForCasualLM, "This method is only available for VILAForCasualLM."
238
+ assert model_path != output_dir, "model_path and output_dir cannot be the same"
239
+ if not os.path.isdir(model_path):
240
+ from huggingface_hub import snapshot_download
243
+
244
+ model_path = snapshot_download(model_path)
245
+ print("downloading HF model to", model_path)
246
+
247
+ if check_dot_in_model_path(model_path) and output_dir is None:
248
+ raise ValueError(
249
+ f"Model path {model_path} contains a dot, which will affect the remote code loading. Please specify the output directory without dot in the path to fix this issue."
250
+ )
251
+ if output_dir is not None and "." in output_dir:
252
+ raise ValueError(
253
+ f"Output directory {output_dir} contains a dot, which will affect the remote code loading. Please specify a valid output directory without dots."
254
+ )
255
+
256
+ if copy:
257
+ print("copy is set to True, copying weights and code to output_dir")
258
+ copy_weights = copy_code = True
259
+ # copy weights and code to output_dir
260
+ self.copy_or_symlink_directory(model_path, output_dir, copy=copy_weights)
261
+ self.copy_remote_py_files(output_dir, copy=copy_code)
262
+
263
+ if vila_version is None:
264
+ vila_version = get_vila_version(output_dir)
265
+
266
+ cfg_path = os.path.join(output_dir, "config.json")
267
+ config = json.load(open(cfg_path))
268
+ config["version"] = "2.0" # nvila tag
269
+ config["architectures"] = ["VILAForCasualLM"]
270
+ config["auto_map"] = {
271
+ "AutoProcessor": "auto_processor.VILAProcessor",
272
+ "AutoConfig": "modeling_vila.VILAConfig",
273
+ "AutoModel": "modeling_vila.VILAForCasualLM",
274
+ "AutoModelForCausalLM": "modeling_vila.VILAForCasualLM",
275
+ }
276
+ # vila1.5 legacy support
277
+ config["model_type"] = "vila"
278
+ if vila_version in ["vila1.5", "vila-m3"]:
279
+ if conv_mode is None:
280
+ raise ValueError(f"Please specify the conversation mode for {output_dir}.")
281
+ config["chat_template"] = conv_mode
282
+ jinja_template = generate_jinja_template(conv_mode)
283
+ jinja_path = os.path.join(output_dir, f"{conv_mode}.jinja")
284
+ with open(jinja_path, "w") as f:
285
+ f.write(jinja_template)
286
+ json.dump(config, open(cfg_path, "w"), indent=2)
287
+
288
+ ##########################################################################################
289
+ config = AutoConfig.from_pretrained(output_dir, trust_remote_code=True)
290
+ tokenizer = load_tokenizer_then_handle_media_tokens_and_chat_template(output_dir, config)
291
+ tokenizer.save_pretrained(osp.join(output_dir, "llm"))
292
+ ##########################################################################################
293
+
294
+ @classmethod
295
+ def copy_or_symlink_directory(cls, model_path, output_dir, copy=True):
296
+ # Create output directory if it doesn't exist
297
+ os.makedirs(output_dir, exist_ok=True)
298
+ # Copy or symlink every file in model_path into output_dir
299
+ for item in os.listdir(model_path):
300
+ src_path = os.path.join(model_path, item)
301
+ dst_path = os.path.join(output_dir, item)
302
+
303
+ # Remove existing file/directory at destination if it exists
304
+ if os.path.exists(dst_path):
305
+ if os.path.islink(dst_path):
306
+ os.unlink(dst_path)
307
+ elif os.path.isdir(dst_path):
308
+ shutil.rmtree(dst_path)
309
+ else:
310
+ os.remove(dst_path)
311
+
312
+ # Copy or symlink depending on the `copy` flag
313
+ if copy:
314
+ if os.path.isdir(src_path):
315
+ shutil.copytree(src_path, dst_path)
316
+ else:
317
+ shutil.copy2(src_path, dst_path)
318
+ print(f"Copied {src_path} to {dst_path}")
319
+ else:
320
+ os.symlink(src_path, dst_path)
321
+ print(f"Created symlink from {src_path} to {dst_path}")
322
+
323
+ @classmethod
324
+ def copy_remote_py_files(cls, output_dir, copy=True):
325
+ ## copy .py files and README so the remote code can be loaded next time
326
+ current_file_path = os.path.abspath(__file__)
327
+ current_folder = os.path.dirname(current_file_path)
328
+ for file_name in os.listdir(current_folder):
329
+ if file_name == "INSTRUCTIONS.md":
330
+ src_fname = os.path.join(current_folder, file_name)
331
+ dst_fname = os.path.join(output_dir, "README.md")
332
+ if os.path.exists(dst_fname):
333
+ old_readme = open(dst_fname).read()
334
+ else:
335
+ old_readme = ""
336
+ with open(src_fname) as src, open(dst_fname, "w") as dst:
337
+ dst.write(src.read())
338
+ dst.write(old_readme)
339
+ print("[HF remote code] README", src_fname, "to", dst_fname)
340
+ if file_name.endswith(".py") or file_name.endswith(".jinja"):
341
+ full_file_name = os.path.join(current_folder, file_name)
342
+ if os.path.isfile(full_file_name):
343
+ if copy:
344
+ shutil.copy(full_file_name, output_dir)
345
+ print("[HF remote code] copying", full_file_name, "to", output_dir)
346
+ else:
347
+ # symlink to ease development
348
+ if os.path.exists(os.path.join(output_dir, file_name)):
349
+ os.remove(os.path.join(output_dir, file_name))
350
+ os.symlink(full_file_name, os.path.join(output_dir, file_name))
351
+ print("[HF remote code] linking", full_file_name, "to", output_dir)
352
+
353
+ def save_pretrained(self, output_dir, state_dict=None, **kwargs):
354
+ if state_dict is None:
355
+ # otherwise fetch from deepspeed
356
+ # state_dict = accelerator.get_state_dict(is_deepspeed_enabled)
357
+ state_dict = self.state_dict()
358
+
359
+ if getattr(self, "tokenizer", None):
360
+ self.tokenizer.save_pretrained(osp.join(output_dir, "llm"))
361
+
362
+ if self.get_llm():
363
+ print(f"saving llm to {osp.join(output_dir, 'llm')}")
364
+ self.llm.config._name_or_path = osp.join(output_dir, "llm")
365
+ llm_state_dict = OrderedDict({k.split("llm.")[-1]: v for k, v in state_dict.items() if "llm" in k})
366
+ self.llm.save_pretrained(os.path.join(output_dir, "llm"), state_dict=llm_state_dict)
367
+ self.config.llm_cfg = self.llm.config
368
+
369
+ if self.get_vision_tower():
370
+ print(f"saving vision_tower to {osp.join(output_dir, 'vision_tower')}")
371
+ self.vision_tower.config._name_or_path = osp.join(output_dir, "vision_tower")
372
+ vision_tower_state_dict = OrderedDict(
373
+ {k.split("vision_tower.vision_tower.")[-1]: v for k, v in state_dict.items() if "vision_tower" in k}
374
+ )
375
+ self.vision_tower.vision_tower.save_pretrained(
376
+ os.path.join(output_dir, "vision_tower"),
377
+ state_dict=vision_tower_state_dict,
378
+ )
379
+ self.vision_tower.image_processor.save_pretrained(os.path.join(output_dir, "vision_tower"))
380
+ self.config.vision_tower_cfg = self.vision_tower.config
381
+ if hasattr(self.config.vision_tower_cfg, "auto_map"):
382
+ if "radio" not in self.get_vision_tower().__class__.__name__.lower():
383
+ delattr(self.config.vision_tower_cfg, "auto_map")
384
+
385
+ if self.get_mm_projector():
386
+ print(f"saving mm_projector to {osp.join(output_dir, 'mm_projector')}")
387
+ self.mm_projector.config._name_or_path = osp.join(output_dir, "mm_projector")
388
+ mm_projector_state_dict = OrderedDict(
389
+ {k.split("mm_projector.")[-1]: v for k, v in state_dict.items() if "mm_projector" in k}
390
+ )
391
+ self.mm_projector.save_pretrained(
392
+ os.path.join(output_dir, "mm_projector"),
393
+ state_dict=mm_projector_state_dict,
394
+ )
395
+ self.config.mm_projector_cfg = self.mm_projector.config
396
+
397
+ ## update and save top-level config
398
+ self.config._name_or_path = output_dir
399
+ self.config.architectures = [self.__class__.__name__]
400
+ self.config.save_pretrained(output_dir)
401
+
402
+ ## copy .py files and README so the remote code can be loaded next time
403
+ self.copy_remote_py_files(output_dir)
404
+
405
+ @classmethod
406
+ def from_pretrained(
407
+ cls,
408
+ pretrained_model_name_or_path: Optional[str] = None,
409
+ *model_args,
410
+ config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
411
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
412
+ ignore_mismatched_sizes: bool = False,
413
+ force_download: bool = False,
414
+ local_files_only: bool = False,
415
+ token: Optional[Union[str, bool]] = None,
416
+ revision: str = "main",
417
+ use_safetensors: Optional[bool] = None,
418
+ weights_only: bool = True,
419
+ **kwargs,
420
+ ):
421
+ config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
422
+ return cls._from_config(config, **kwargs)
423
+
424
+ def init_llm(self, llm_config, config, *args, **kwargs):
425
+ self.llm, self.tokenizer = build_llm_and_tokenizer(llm_config, config, *args, **kwargs)
426
+ # hard coded for NVILA
427
+ # variables for XGrammar
428
+ # print("DEBUG", len(self.tokenizer.added_tokens_encoder.keys()), self.tokenizer.added_tokens_encoder.keys())
429
+ NUM_EXTRA_TOKENS = len(self.tokenizer.added_tokens_encoder.keys())
430
+
431
+ # TODO: SENTINEL_TOKEN is not added, need to check with Zhijian
432
+ self.vocab_size = self.tokenizer.vocab_size + NUM_EXTRA_TOKENS
433
+ # XGrammar tokenizer and grammar compiler
434
+ # lazily initialized only when JSON output is requested at inference time
435
+ self.grammar_compiler = None
436
+ self.llm.resize_token_embeddings(len(self.tokenizer))
437
+ return self.llm, self.tokenizer
438
+
439
+ def post_config(self):
440
+ ######################################################################
441
+ # TODO: need to check dtype with jason
442
+ self.llm = self.llm.to(torch.float16)
443
+ self.mm_projector = self.mm_projector.to(torch.float16)
444
+ self.vision_tower = self.vision_tower.to(torch.float16)
445
+ ######################################################################
446
+ self.training = self.llm.training
447
+ ## configuration
448
+ if getattr(self.config, "llm_cfg", None) is None:
449
+ self.config.llm_cfg = self.llm.config
450
+ if getattr(self.config, "vision_tower_cfg", None) is None:
451
+ self.config.vision_tower_cfg = self.vision_tower.config
452
+ if getattr(self.config, "mm_projector_cfg", None) is None:
453
+ self.config.mm_projector_cfg = self.mm_projector.config
454
+
455
+ def get_llm(self):
456
+ llm = getattr(self, "llm", None)
457
+ if type(llm) is list:
458
+ llm = llm[0]
459
+ return llm
460
+
461
+ def get_lm_head(self):
462
+ lm_head = getattr(self.get_llm(), "lm_head", None)
463
+ return lm_head
464
+
465
+ def get_vision_tower(self):
466
+ vision_tower = getattr(self, "vision_tower", None)
467
+ if type(vision_tower) is list:
468
+ vision_tower = vision_tower[0]
469
+ return vision_tower
470
+
471
+ def get_mm_projector(self):
472
+ mm_projector = getattr(self, "mm_projector", None)
473
+ if type(mm_projector) is list:
474
+ mm_projector = mm_projector[0]
475
+ return mm_projector
476
+
477
+ def freezed_module_patch(self):
478
+ """
479
+ Hugging Face calls model.train() at every training step. To keep the expected behavior of modules such as dropout and batchnorm, we call model.eval() on the frozen modules.
480
+ """
481
+ if self.training:
482
+ if self.get_llm() and not getattr(self.config, "tune_language_model", False):
483
+ pass
484
+ # logging.warning("Caution: Your LLM is currently in training mode, ensuring accurate gradient computation. Please be vigilant, particularly regarding BatchNorm and Dropout operations.")
485
+ if self.get_vision_tower() and not getattr(self.config, "tune_vision_tower", False):
486
+ self.get_vision_tower().eval()
487
+ if self.get_mm_projector() and not getattr(self.config, "tune_mm_projector", False):
488
+ self.get_mm_projector().eval()
489
+
490
+
491
+ class VILAForCasualLM(VILAPretrainedModel):
492
+ def __init__(self, config: VILAConfig, *args, **kwargs):
493
+ super().__init__(config, *args, **kwargs)
494
+
495
+ def merge_features_for_dynamic_s2(self, image_features, block_sizes):
496
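+ # Reassemble per-tile features into one multi-scale map per image: tiles of each scale are merged back into a grid, every scale is resized ("area" interpolation) to a common spatial size, and the scales are concatenated along the channel dimension.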
+ scales = self.get_vision_tower().scales
497
+ resize_output_to_scale_idx = self.get_vision_tower().resize_output_to_scale_idx
498
+
499
+ image_features_each_image = []
500
+ new_block_sizes = []
501
+ block_cnt = 0
502
+ for block_size_each_image in block_sizes:
503
+ if block_size_each_image is None:
504
+ cur_features = image_features[block_cnt : block_cnt + 1]
505
+ cur_features = rearrange(cur_features, "1 (h w) c -> 1 c h w", h=int(cur_features.shape[1] ** 0.5))
506
+ cur_features = cur_features.repeat(1, len(scales), 1, 1)
507
+ image_features_each_image.append(cur_features)
508
+ new_block_sizes.append((1, 1))
509
+ block_cnt += 1
510
+ else:
511
+ cur_features_each_scale = []
512
+ for scale in scales[:-1]:
513
+ num_blocks_this_scale = (scale // scales[0]) ** 2
514
+ cur_features_each_scale.append(
515
+ self.merge_chessboard(
516
+ image_features[block_cnt : block_cnt + num_blocks_this_scale],
517
+ num_split_h=scale // scales[0],
518
+ num_split_w=scale // scales[0],
519
+ )
520
+ ) # 1 * C * H * W
521
+ block_cnt += num_blocks_this_scale
522
+ num_blocks_last_scale = block_size_each_image[0] * block_size_each_image[1]
523
+ cur_features_each_scale.append(
524
+ self.merge_chessboard(
525
+ image_features[block_cnt : block_cnt + num_blocks_last_scale],
526
+ num_split_h=block_size_each_image[0],
527
+ num_split_w=block_size_each_image[1],
528
+ )
529
+ ) # 1 * C * H * W
530
+ block_cnt += num_blocks_last_scale
531
+
532
+ # resize and concat features from different scales
533
+ output_size = cur_features_each_scale[resize_output_to_scale_idx].shape[-2:]
534
+ cur_features = torch.cat(
535
+ [
536
+ F.interpolate(cur_features_each_scale[i].to(torch.float32), size=output_size, mode="area").to(
537
+ cur_features_each_scale[i].dtype
538
+ )
539
+ for i in range(len(cur_features_each_scale))
540
+ ],
541
+ dim=1,
542
+ )
543
+ # cur_features = rearrange(cur_features, "1 c h w -> (h w) c")
544
+
545
+ image_features_each_image.append(cur_features)
546
+
547
+ if resize_output_to_scale_idx == len(scales) - 1 or resize_output_to_scale_idx == -1:
548
+ new_block_sizes.append(block_size_each_image)
549
+ else:
550
+ new_block_sizes.append(
551
+ (
552
+ scales[resize_output_to_scale_idx] // scales[0],
553
+ scales[resize_output_to_scale_idx] // scales[0],
554
+ )
555
+ )
556
+
557
+ assert block_cnt == len(image_features)
558
+
559
+ return image_features_each_image, new_block_sizes
560
+
561
+ def encode_images(self, images, block_sizes: Optional[Optional[Tuple[int, ...]]] = None):
562
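+ # block_sizes gives the (rows, cols) tile layout of each image for dynamic S2; a None entry means the image was encoded as a single tile.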
+ if block_sizes is None:
563
+ block_sizes = [None] * len(images)
564
+ if getattr(self.config, "dynamic_s2", False):
565
+ image_features = self.get_vision_tower()(images)
566
+ image_features, new_block_sizes = self.merge_features_for_dynamic_s2(image_features, block_sizes)
567
+
568
+ image_features = [
569
+ self.split_chessboard(x, block_size[0], block_size[1])
570
+ for x, block_size in zip(image_features, new_block_sizes)
571
+ ] # list of B * C * H * W tensors
572
+ image_features = torch.cat(
573
+ [rearrange(x, "b c h w -> b (h w) c") for x in image_features], dim=0
574
+ ) # B * N * C
575
+ image_features = self.get_mm_projector()(image_features)
576
+ image_features = list(
577
+ image_features.split([block_size[0] * block_size[1] for block_size in new_block_sizes], dim=0)
578
+ )
579
+ image_features = [
580
+ self.merge_chessboard(x, block_size[0], block_size[1])
581
+ for x, block_size in zip(image_features, new_block_sizes)
582
+ ] # list of 1 * C * H * W tensors
583
+ image_features = [rearrange(x, "1 c h w -> (h w) c") for x in image_features] # list of N * C tensors
584
+ if all([feature.shape[0] == image_features[0].shape[0] for feature in image_features]):
585
+ image_features = torch.stack(image_features, dim=0)
586
+ else:
587
+ image_features = self.get_vision_tower()(images)
588
+ image_features = self.get_mm_projector()(image_features)
589
+ return image_features
590
+
591
+ def train(self, mode: bool = True):
592
+ if mode:
593
+ self.tokenizer.padding_side = "right"
594
+ else:
595
+ self.tokenizer.padding_side = "left"
596
+ super().train(mode)
597
+ return self
598
+
599
+ def _embed(
600
+ self,
601
+ input_ids: torch.Tensor,
602
+ media: Dict[str, List[torch.Tensor]],
603
+ media_config: Dict[str, Dict[str, Any]],
604
+ labels: Optional[torch.Tensor],
605
+ attention_mask: Optional[torch.Tensor],
606
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
607
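+ # Fuse text and media embeddings: media placeholder tokens in input_ids are replaced by the corresponding encoder outputs, PAD/EOS positions are dropped, and the result is truncated and re-padded into a batch.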
+ # NOTE(ligeng): deep copy to avoid modifying the original media and media_config
608
+ media = copy.deepcopy(media)
609
+ media_config = copy.deepcopy(media_config)
610
+
611
+ labels = labels if labels is not None else torch.full_like(input_ids, IGNORE_INDEX)
612
+ attention_mask = attention_mask if attention_mask is not None else torch.ones_like(input_ids, dtype=torch.bool)
613
+
614
+ PROCESS_GROUP_MANAGER = get_pg_manager()
615
+ if PROCESS_GROUP_MANAGER is not None:
616
+ for name in media:
617
+ self.encoders[name].end_tokens = None
618
+
619
+ # Extract text and media embeddings
620
+ text_embeds = self.llm.model.embed_tokens(input_ids)
621
+ if media is not None:
622
+ media_embeds = self.__embed_media_tokens(media, media_config)
623
+ else:
624
+ # no media was provided, so fall back to an empty dict
625
+ media_embeds = {}
626
+
627
+ # This is a workaround to make sure the dummy embeddings are consumed
628
+ while media_embeds.get("dummy"):
629
+ dummy_embed = media_embeds["dummy"].popleft()
630
+ text_embeds += torch.sum(dummy_embed) * 0
631
+
632
+ # Remove padding
633
+ batch_size = labels.shape[0]
634
+ text_embeds = [text_embeds[k][attention_mask[k]] for k in range(batch_size)]
635
+ labels = [labels[k][attention_mask[k]] for k in range(batch_size)]
636
+
637
+ # Build inverse mapping from token ID to media name
638
+ media_tokens = {}
639
+ for name, token_id in self.tokenizer.media_token_ids.items():
640
+ media_tokens[token_id] = name
641
+
642
+ # Fuse text and media embeddings
643
+ inputs_m, labels_m = [], []
644
+ for k in range(batch_size):
645
+ inputs_mk, labels_mk = [], []
646
+ pos = 0
647
+ while pos < len(labels[k]):
648
+ if input_ids[k][pos].item() in media_tokens:
649
+ end = pos + 1
650
+ name = media_tokens[input_ids[k][pos].item()]
651
+ input = media_embeds[name].popleft()
652
+ label = torch.full([input.shape[0]], IGNORE_INDEX, device=labels[k].device, dtype=labels[k].dtype)
653
+ # print(f"{self.tokenizer.padding_side} [media] {k=} {pos=}, {self.tokenizer.batch_decode(input_ids[k][pos:pos+1])}"); python_input()
654
+ elif input_ids[k][pos].item() in (self.tokenizer.pad_token_id, self.tokenizer.eos_token_id):
655
+ end = pos + 1
656
+ pos = end
657
+ # print(f"[skip PAD/EOS] {k=} {pos=}, {self.tokenizer.batch_decode(input_ids[k][pos:end])}"); python_input()
658
+ continue
659
+ else:
660
+ end = pos
661
+ while end < len(labels[k]) and input_ids[k][end].item() not in media_tokens:
662
+ end += 1
663
+ input = text_embeds[k][pos:end]
664
+ label = labels[k][pos:end]
665
+ # print(f"[text] {k=} {pos=}, {self.tokenizer.batch_decode(input_ids[k][pos:end])}"); python_input()
666
+
667
+ inputs_mk.append(input)
668
+ labels_mk.append(label)
669
+ pos = end
670
+ inputs_m.append(torch.cat(inputs_mk, dim=0))
671
+ labels_m.append(torch.cat(labels_mk, dim=0))
672
+ inputs, labels = inputs_m, labels_m
673
+
674
+ # Check if all media embeddings are consumed
675
+ for name in media_embeds:
676
+ if media_embeds[name]:
677
+ raise ValueError(f"Not all {name} embeddings are consumed! Still {len(media_embeds[name])} left.")
678
+
679
+ # Truncate sequences to `model_max_length` as media embeddings are inserted
680
+ inputs, labels = self.__truncate_sequence(inputs, labels)
681
+
682
+ # Pad sequences to the longest one in the batch
683
+ return self.__batchify_sequence(inputs, labels)
684
+
685
+ def __embed_media_tokens(
686
+ self,
687
+ media: Dict[str, List[torch.Tensor]],
688
+ media_config: Dict[str, Dict[str, Any]],
689
+ ) -> Dict[str, List[torch.Tensor]]:
690
+ embeds = defaultdict(deque)
691
+ for name in media:
692
+ if self.training:
693
+ # Gather metainfo of media objects from all ranks
694
+ info = [{"shape": tensor.shape, "dtype": tensor.dtype} for tensor in media.get(name, [])]
695
+ infos = list(chain(vila_all_gather(info)))
696
+
697
+ # The entire batch does not contain any media objects of this type.
698
+ if not infos:
699
+ continue
700
+
701
+ # Create a dummy tensor to ensure the encoder is called, otherwise the training will hang.
702
+ if media.get(name) is None or len(media[name]) == 0:
703
+ dummy = torch.zeros(infos[0]["shape"], dtype=infos[0]["dtype"], device=self.device)
704
+ embeds["dummy"].extend(self.encoders[name]([dummy], media_config[name]))
705
+ continue
706
+ embeds[name] = deque(self.encoders[name](media[name], media_config[name]))
707
+ return embeds
708
+
709
+ def __truncate_sequence(
710
+ self, inputs: List[torch.Tensor], labels: List[torch.Tensor]
711
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
712
+ if self.training and any(len(input) > self.tokenizer.model_max_length for input in inputs):
713
+ warnings.warn(f"Truncating sequences to `model_max_length` ({self.tokenizer.model_max_length}).")
714
+ inputs = [input[: self.tokenizer.model_max_length] for input in inputs]
715
+ labels = [label[: self.tokenizer.model_max_length] for label in labels]
716
+ return inputs, labels
717
+
718
+ def __batchify_sequence(
719
+ self, inputs: List[torch.Tensor], labels: List[torch.Tensor]
720
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
721
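+ # Pad the variable-length fused sequences to the longest one in the batch, honoring the tokenizer's padding side, and build the matching attention mask.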
+ batch_size = len(inputs)
722
+ device = inputs[0].device
723
+ hidden_size = inputs[0].shape[1]
724
+ max_length = max(inputs[k].shape[0] for k in range(batch_size))
725
+ attention_mask = torch.ones((batch_size, max_length), dtype=torch.bool, device=device)
726
+
727
+ inputs_p, labels_p = [], []
728
+ for k in range(batch_size):
729
+ size_pk = max_length - inputs[k].shape[0]
730
+ inputs_pk = torch.zeros((size_pk, hidden_size), dtype=inputs[k].dtype, device=device)
731
+ labels_pk = torch.full((size_pk,), IGNORE_INDEX, dtype=labels[k].dtype, device=device)
732
+ if self.tokenizer.padding_side == "right":
733
+ attention_mask[k, inputs[k].shape[0] :] = False
734
+ inputs_pk = torch.cat([inputs[k], inputs_pk], dim=0)
735
+ labels_pk = torch.cat([labels[k], labels_pk], dim=0)
736
+ else:
737
+ attention_mask[k, : -inputs[k].shape[0]] = False
738
+ inputs_pk = torch.cat([inputs_pk, inputs[k]], dim=0)
739
+ labels_pk = torch.cat([labels_pk, labels[k]], dim=0)
740
+ inputs_p.append(inputs_pk)
741
+ labels_p.append(labels_pk)
742
+
743
+ inputs = torch.stack(inputs_p, dim=0)
744
+ labels = torch.stack(labels_p, dim=0)
745
+ return inputs, labels, attention_mask
746
+
747
+ def repack_multimodal_data(self, inputs_embeds, attention_mask, position_ids, labels):
748
+ # Handle sequence parallelism
749
+ PROCESS_GROUP_MANAGER = get_pg_manager()
750
+
751
+ # We do re-sharding instead of packing here to ensure the sequence length is the same across all ranks.
752
+ if PROCESS_GROUP_MANAGER is not None:
753
+ sp_degree = PROCESS_GROUP_MANAGER.sp_degree
754
+ sp_rank = PROCESS_GROUP_MANAGER.sp_rank
755
+ sp_group = PROCESS_GROUP_MANAGER.sp_pg
756
+ ring_degree = PROCESS_GROUP_MANAGER.ring_degree
757
+ ring_rank = PROCESS_GROUP_MANAGER.ring_rank
758
+ ring_type = PROCESS_GROUP_MANAGER.ring_type
759
+ ulysses_degree = PROCESS_GROUP_MANAGER.ulysses_degree
760
+ ulysses_rank = PROCESS_GROUP_MANAGER.ulysses_rank
761
+
762
+ bs, shard_seqlen = position_ids.shape
763
+ sp_seq_len = [torch.zeros(1, dtype=torch.int64, device=position_ids.device) for _ in range(sp_degree)]
764
+ dist.all_gather(sp_seq_len, torch.tensor(shard_seqlen, device=position_ids.device), group=sp_group)
765
+ sp_seq_len_cat = torch.cat(sp_seq_len, dim=0)
766
+
767
+ if sp_rank == 0:
768
+ original_start_id = 0
769
+ else:
770
+ original_start_id = torch.sum(sp_seq_len_cat[:sp_rank]).item()
771
+ original_end_id = torch.sum(sp_seq_len_cat[: sp_rank + 1]).item()
772
+
773
+ # Gather attention_mask, position_ids, labels and input_embeds
774
+ all_inputs_embeds = torch.zeros(
775
+ bs,
776
+ torch.sum(sp_seq_len_cat),
777
+ inputs_embeds.shape[-1],
778
+ dtype=inputs_embeds.dtype,
779
+ device=inputs_embeds.device,
780
+ ).contiguous()
781
+ all_inputs_embeds[:, original_start_id:original_end_id, :] += inputs_embeds
782
+ dist.barrier(group=sp_group)
783
+ dist.all_reduce(all_inputs_embeds, group=sp_group)
784
+ dist.barrier(group=sp_group)
785
+
786
+ attention_mask_list = [
787
+ torch.zeros((bs, sp_seq_len[i]), dtype=attention_mask.dtype, device=attention_mask.device)
788
+ for i in range(sp_degree)
789
+ ]
790
+ position_ids_list = [
791
+ torch.zeros((bs, sp_seq_len[i]), dtype=position_ids.dtype, device=position_ids.device)
792
+ for i in range(sp_degree)
793
+ ]
794
+ labels_list = [
795
+ torch.zeros((bs, sp_seq_len[i]), dtype=labels.dtype, device=labels.device) for i in range(sp_degree)
796
+ ]
797
+
798
+ dist.all_gather(attention_mask_list, attention_mask, group=sp_group)
799
+ dist.all_gather(position_ids_list, position_ids, group=sp_group)
800
+ dist.all_gather(labels_list, labels, group=sp_group)
801
+
802
+ effective_seqlen_list = [attention_mask_list[i].sum(dim=-1) for i in range(sp_degree)]
803
+ effective_seqlen = torch.stack(effective_seqlen_list, dim=-1)
804
+ effective_seqlen_batch_list = torch.unbind(effective_seqlen, dim=0)
805
+
806
+ global_attention_mask_list = []
807
+ global_position_ids_list = []
808
+ global_labels_list = []
809
+ global_inputs_embeds_list = []
810
+ for i in range(bs):
811
+ global_attention_mask_batch_list = []
812
+ global_position_ids_batch_list = []
813
+ global_labels_batch_list = []
814
+ global_inputs_embeds_batch_list = []
815
+ for j in range(sp_degree):
816
+ eff_len = effective_seqlen_batch_list[i][j]
817
+ prev_len = torch.sum(sp_seq_len_cat[:j]).item() if j > 0 else 0
818
+
819
+ global_attention_mask_batch_list.append(attention_mask_list[j][i, :eff_len])
820
+ global_position_ids_batch_list.append(position_ids_list[j][i, :eff_len])
821
+ global_labels_batch_list.append(labels_list[j][i, :eff_len])
822
+ global_inputs_embeds_batch_list.append(all_inputs_embeds[i, prev_len : prev_len + eff_len, :])
823
+ global_attention_mask_list.append(torch.cat(global_attention_mask_batch_list, dim=0))
824
+ global_position_ids_list.append(torch.cat(global_position_ids_batch_list, dim=0))
825
+ global_labels_list.append(torch.cat(global_labels_batch_list, dim=0))
826
+ global_inputs_embeds_list.append(torch.cat(global_inputs_embeds_batch_list, dim=0))
827
+
828
+ global_attention_mask = torch.nn.utils.rnn.pad_sequence(
829
+ global_attention_mask_list, batch_first=True, padding_value=False
830
+ )
831
+ global_position_ids = torch.nn.utils.rnn.pad_sequence(
832
+ global_position_ids_list, batch_first=True, padding_value=-1
833
+ )
834
+ global_labels = torch.nn.utils.rnn.pad_sequence(
835
+ global_labels_list, batch_first=True, padding_value=IGNORE_INDEX
836
+ )
837
+ global_inputs_embeds = torch.nn.utils.rnn.pad_sequence(
838
+ global_inputs_embeds_list, batch_first=True, padding_value=0
839
+ )
840
+
841
+ # Re-shard the inputs
842
+ if ring_degree > 1:
843
+ total_effective_seqlen = torch.sum(effective_seqlen, dim=1)
844
+ new_seqlen_per_rank = total_effective_seqlen // sp_degree
845
+ assert torch.all(
846
+ total_effective_seqlen % sp_degree == 0
847
+ ), "total_effective_seqlen must be divisible by sp_degree"
848
+
849
+ max_new_seqlen = torch.max(new_seqlen_per_rank).item()
850
+
851
+ new_attention_mask = torch.zeros(
852
+ (bs, max_new_seqlen), dtype=global_attention_mask.dtype, device=global_attention_mask.device
853
+ )
854
+ new_position_ids = torch.zeros(
855
+ (bs, max_new_seqlen), dtype=global_position_ids.dtype, device=global_position_ids.device
856
+ )
857
+ new_labels = torch.full(
858
+ (bs, max_new_seqlen), IGNORE_INDEX, dtype=global_labels.dtype, device=global_labels.device
859
+ )
860
+ new_inputs_embeds = torch.zeros(
861
+ (bs, max_new_seqlen, global_inputs_embeds.shape[-1]),
862
+ dtype=global_inputs_embeds.dtype,
863
+ device=global_inputs_embeds.device,
864
+ )
865
+
866
+ if ring_type == "ring_varlen":
867
+ for i in range(bs):
868
+ start_idx = new_seqlen_per_rank[i] * sp_rank
869
+ end_idx = start_idx + new_seqlen_per_rank[i]
870
+ new_attention_mask[i, : new_seqlen_per_rank[i]] = global_attention_mask[i, start_idx:end_idx]
871
+ new_position_ids[i, : new_seqlen_per_rank[i]] = global_position_ids[i, start_idx:end_idx]
872
+ new_labels[i, : new_seqlen_per_rank[i]] = global_labels[i, start_idx:end_idx]
873
+ new_inputs_embeds[i, : new_seqlen_per_rank[i], :] = global_inputs_embeds[
874
+ i, start_idx:end_idx, :
875
+ ]
876
+ elif ring_type == "zigzag_ring_varlen":
877
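+ # Zigzag sharding: each rank takes one chunk from the front half of the sequence and the mirrored chunk from the back half, which roughly balances the causal-attention workload across ranks.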
+ chunk_size = total_effective_seqlen // (2 * sp_degree)
878
+ for i in range(bs):
879
+ # Zigzag pattern indices
880
+ if sp_degree == ring_degree:
881
+ forward_rank_idx = sp_rank
882
+ backward_rank_idx = 2 * sp_degree - sp_rank - 1
883
+ else:
884
+ ulysses_offset = ulysses_rank * ring_degree * 2
885
+ forward_rank_idx = ring_rank + ulysses_offset
886
+ backward_rank_idx = sp_degree - ring_rank - 1 + ulysses_offset
887
+
888
+ # Calculate start and end indices for the forward and backward zigzag
889
+ start_idx_fwd = forward_rank_idx * chunk_size[i]
890
+ end_idx_fwd = start_idx_fwd + chunk_size[i]
891
+
892
+ start_idx_bwd = backward_rank_idx * chunk_size[i]
893
+ end_idx_bwd = start_idx_bwd + chunk_size[i]
894
+
895
+ # Fill new tensors with zigzag data
896
+ new_attention_mask[i, : chunk_size[i]] = global_attention_mask[i, start_idx_fwd:end_idx_fwd]
897
+ new_attention_mask[i, chunk_size[i] : 2 * chunk_size[i]] = global_attention_mask[
898
+ i, start_idx_bwd:end_idx_bwd
899
+ ]
900
+
901
+ new_position_ids[i, : chunk_size[i]] = global_position_ids[i, start_idx_fwd:end_idx_fwd]
902
+ new_position_ids[i, chunk_size[i] : 2 * chunk_size[i]] = global_position_ids[
903
+ i, start_idx_bwd:end_idx_bwd
904
+ ]
905
+
906
+ new_labels[i, : chunk_size[i]] = global_labels[i, start_idx_fwd:end_idx_fwd]
907
+ new_labels[i, chunk_size[i] : 2 * chunk_size[i]] = global_labels[i, start_idx_bwd:end_idx_bwd]
908
+
909
+ new_inputs_embeds[i, : chunk_size[i], :] = global_inputs_embeds[i, start_idx_fwd:end_idx_fwd, :]
910
+ new_inputs_embeds[i, chunk_size[i] : 2 * chunk_size[i], :] = global_inputs_embeds[
911
+ i, start_idx_bwd:end_idx_bwd, :
912
+ ]
913
+ else:
914
+ raise ValueError(f"Invalid ring_type: {ring_type}")
915
+ else:
916
+ global_seq_len = global_attention_mask.shape[-1]
917
+ seq_len_sharded = global_seq_len // sp_degree
918
+ start_idx_reshard = seq_len_sharded * sp_rank
919
+ end_idx_reshard = start_idx_reshard + seq_len_sharded if sp_rank < sp_degree - 1 else global_seq_len
920
+
921
+ new_attention_mask = torch.narrow(
922
+ global_attention_mask, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
923
+ )
924
+ new_position_ids = torch.narrow(
925
+ global_position_ids, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
926
+ )
927
+ new_labels = torch.narrow(global_labels, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard)
928
+ new_inputs_embeds = torch.narrow(
929
+ global_inputs_embeds, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
930
+ )
931
+
932
+ return new_inputs_embeds, new_attention_mask, new_position_ids, new_labels
933
+
934
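+ # No sequence parallelism: pack every sequence in the batch into a single batch-size-1 sequence to avoid padding waste; per-sequence position_ids restart at 0, and the boundaries are passed via set_seqlens_in_batch in forward().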
+ device = inputs_embeds.device
935
+ batch_size = inputs_embeds.shape[0]
936
+ seqlens = [attention_mask[k].sum().item() for k in range(batch_size)]
937
+
938
+ # Pack all sequences together
939
+ inputs_embeds_p = [inputs_embeds[k][attention_mask[k]] for k in range(batch_size)]
940
+ attention_mask_p = [torch.ones(seqlens[k], dtype=torch.int, device=device) for k in range(batch_size)]
941
+ position_ids_p = [torch.arange(seqlens[k], dtype=torch.int, device=device) for k in range(batch_size)]
942
+ labels_p = [labels[k][attention_mask[k]] for k in range(batch_size)]
943
+
944
+ # Add one dummy token at the end of the packed sequence to ensure that `_get_unpacked_data` will be called
945
+ inputs_embeds_p.append(torch.zeros(1, inputs_embeds.shape[-1], dtype=inputs_embeds.dtype, device=device))
946
+ attention_mask_p.append(torch.tensor([0], dtype=torch.int, device=device))
947
+ position_ids_p.append(torch.tensor([0], dtype=torch.int, device=device))
948
+ labels_p.append(torch.tensor([IGNORE_INDEX], dtype=torch.int, device=device))
949
+
950
+ # Mask the first token of each sequence to avoid contamination
951
+ for label in labels_p:
952
+ label[0] = IGNORE_INDEX
953
+
954
+ # Batch the data
955
+ inputs_embeds_p = torch.cat(inputs_embeds_p, dim=0).unsqueeze(0)
956
+ attention_mask_p = torch.cat(attention_mask_p, dim=0).unsqueeze(0)
957
+ position_ids_p = torch.cat(position_ids_p, dim=0).unsqueeze(0)
958
+ labels_p = torch.cat(labels_p, dim=0).unsqueeze(0)
959
+
960
+ if hasattr(
961
+ self, "pad_to_multiple_of"
962
+ ): # related to quantization, please refer to ModelArguments for more information.
963
+ assert len(labels_p.shape) == 2
964
+ batch_size, max_length, cur_length = labels_p.shape[0], labels_p.shape[1], labels_p.shape[1]
965
+ hidden_size = inputs_embeds_p.shape[-1]
966
+
967
+ if max_length % self.pad_to_multiple_of != 0:
968
+ max_length = ((max_length // self.pad_to_multiple_of) + 1) * self.pad_to_multiple_of
969
+ difference = max_length - cur_length
970
+
971
+ inputs_embeds_p = torch.cat(
972
+ (
973
+ inputs_embeds_p,
974
+ torch.full((batch_size, difference, hidden_size), self.llm.pad_token_id).to(inputs_embeds_p),
975
+ ),
976
+ dim=1,
977
+ )
978
+ labels_p = torch.cat((labels_p, torch.full((batch_size, difference), IGNORE_INDEX).to(labels_p)), dim=1)
979
+ attention_mask_p = torch.cat(
980
+ (
981
+ attention_mask_p,
982
+ torch.zeros((batch_size, difference), dtype=torch.bool).to(attention_mask_p),
983
+ ),
984
+ dim=1,
985
+ )
986
+ position_ids_p = torch.cat(
987
+ (position_ids_p, torch.full((batch_size, difference), -1).to(position_ids_p)), dim=1
988
+ )
989
+
990
+ return inputs_embeds_p, attention_mask_p, position_ids_p, labels_p
991
+
992
+ def get_xgr_logits_processor(self, response_format) -> List[LogitsProcessor]:
993
+ raise NotImplementedError("This method is not implemented for VILA model.")
994
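+ # NOTE: the code below is unreachable because of the raise above; it is kept as a reference implementation of XGrammar-constrained decoding.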
+ # Convert response format to logits processor
995
+ import xgrammar as xgr
996
+
997
+ logging.info("[XGrammar] Compiling grammar for contrained output")
998
+
999
+ if self.grammar_compiler is None:
1000
+ # logging.info(f"[XGrammar] {self.tokenizer}, {self.tokenizer.vocab_size}, {self.vocab_size}")
1001
+ self.grammar_compiler = xgr.GrammarCompiler(
1002
+ xgr.TokenizerInfo.from_huggingface(self.tokenizer, vocab_size=self.vocab_size)
1003
+ )
1004
+
1005
+ if response_format.type == "json_schema":
1006
+ compiled_grammar = self.grammar_compiler.compile_json_schema(
1007
+ response_format.json_schema.schema_,
1008
+ indent=2,
1009
+ )
1010
+ else:
1011
+ compiled_grammar = self.grammar_compiler.compile_builtin_json_grammar()
1012
+
1013
+ return [xgr.contrib.hf.LogitsProcessor(compiled_grammar)]
1014
+
1015
+ def forward(
1016
+ self,
1017
+ input_ids: torch.LongTensor = None,
1018
+ media: Optional[Dict[str, List[torch.Tensor]]] = None,
1019
+ images: Optional[torch.FloatTensor] = None,
1020
+ media_config: Optional[List] = None,
1021
+ attention_mask: Optional[torch.Tensor] = None,
1022
+ position_ids: Optional[torch.LongTensor] = None,
1023
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1024
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1025
+ labels: Optional[torch.LongTensor] = None,
1026
+ packing: bool = True,
1027
+ force_packing: bool = False,
1028
+ seqlens_in_batch: Optional[torch.LongTensor] = None,
1029
+ dpo_forward: bool = False,
1030
+ **kwargs,
1031
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1032
+ self.freezed_module_patch()
1033
+
1034
+ if images is not None:
1035
+ if media is not None:
1036
+ raise ValueError("Both 'media' and 'images' are provided. Please provide only one.")
1037
+ print("The 'images' argument is deprecated. Please use 'media' instead.")
1038
+ media = {"image": images}
1039
+
1040
+ if media_config is None:
1041
+ media_config = defaultdict(dict)
1042
+
1043
+ if inputs_embeds is None:
1044
+ inputs_embeds, labels, attention_mask = self._embed(input_ids, media, media_config, labels, attention_mask)
1045
+
1046
+ if force_packing or (packing and self.training and not dpo_forward):
1047
+ if seqlens_in_batch is None:
1048
+ seqlens_in_batch = torch.sum(attention_mask, dim=1)
1049
+ set_seqlens_in_batch(seqlens_in_batch)
1050
+
1051
+ (inputs_embeds, attention_mask, position_ids, labels) = self.repack_multimodal_data(
1052
+ inputs_embeds, attention_mask, position_ids, labels
1053
+ )
1054
+
1055
+ outputs = self.llm(
1056
+ inputs_embeds=inputs_embeds,
1057
+ attention_mask=attention_mask,
1058
+ position_ids=position_ids,
1059
+ past_key_values=past_key_values,
1060
+ labels=labels,
1061
+ **kwargs,
1062
+ )
1063
+
1064
+ if self.training and getattr(self.config, "time_token_ids", []):
1065
+ outputs.loss = soft_cross_entropy(
1066
+ outputs.logits,
1067
+ labels,
1068
+ soft_tokens=self.config.time_token_ids,
1069
+ std=self.config.soft_ce_std,
1070
+ )
1071
+
1072
+ if dpo_forward:
1073
+ return outputs.logits, labels
1074
+
1075
+ return outputs
1076
+
1077
+ @torch.inference_mode()
1078
+ def generate(
1079
+ self,
1080
+ input_ids: Optional[torch.FloatTensor] = None,
1081
+ media: Optional[Dict[str, List[torch.Tensor]]] = None,
1082
+ media_config: Dict[str, Dict[str, Any]] = None,
1083
+ attention_mask: Optional[torch.LongTensor] = None,
1084
+ **generation_kwargs,
1085
+ ):
1086
+ if self.training:
1087
+ warnings.warn(
1088
+ "Model is in training mode, using default padding strategy to right. This is not recommended for generation."
1089
+ )
1090
+ inputs_embeds, _, attention_mask = self._embed(input_ids, media, media_config, None, attention_mask)
1091
+ return self.llm.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_kwargs)
1092
+
1093
+ @torch.inference_mode()
1094
+ def generate_content(
1095
+ self,
1096
+ prompt: Union[str, List],
1097
+ generation_config: Optional[GenerationConfig] = None,
1098
+ response_format=None,
1099
+ ) -> str:
1100
+ # TODO(zhijianl): Support directly taking conversation as input
1101
+ conversation = [{"from": "human", "value": prompt}]
1102
+
1103
+ # Convert response format to logits processor
1104
+ if response_format:
1105
+ xgr_logits_processor = self.get_xgr_logits_processor(response_format)
1106
+ else:
1107
+ xgr_logits_processor = None
1108
+
1109
+ # Extract media from the conversation
1110
+
1111
+ # TODO (extract and preprocess should be done together, as the preprocess of image and video can be different, i.e. when dynamic res is used)
1112
+ media = extract_media(conversation, self.config)
1113
+
1114
+ # Process media
1115
+ media_config = defaultdict(dict)
1116
+ for name in media:
1117
+ if name == "image":
1118
+ if len(media["image"]) == 1 and self.config.image_aspect_ratio in ["dynamic", "dynamic_s2"]:
1119
+ self.config.image_processor = self.vision_tower.image_processor
1120
+ if self.config.image_aspect_ratio == "dynamic":
1121
+ images = process_image(media["image"][0], self.config, None, enable_dynamic_res=True).half()
1122
+ conversation[0]["value"] = conversation[0]["value"].replace(
1123
+ DEFAULT_IMAGE_TOKEN, f"{DEFAULT_IMAGE_TOKEN}\n" * images.shape[0]
1124
+ )
1125
+ else:
1126
+ if type(self.config.s2_scales) is str:
1127
+ self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
1128
+ images, block_sizes = process_image(
1129
+ media["image"][0], self.config, None, enable_dynamic_s2=True
1130
+ )
1131
+ images = images.half()
1132
+ media_config[name]["block_sizes"] = [block_sizes]
1133
+ else:
1134
+ images = process_images(media["image"], self.vision_tower.image_processor, self.config).half()
1135
+ media[name] = [image for image in images]
1136
+ elif name == "video":
1137
+ if self.config.image_aspect_ratio == "dynamic" and self.config.video_max_tiles > 1:
1138
+ media[name] = [
1139
+ process_images(
1140
+ images,
1141
+ self.vision_tower.image_processor,
1142
+ self.config,
1143
+ enable_dynamic_res=True,
1144
+ max_tiles=self.config.video_max_tiles,
1145
+ ).half()
1146
+ for images in media[name]
1147
+ ]
1148
+ elif self.config.image_aspect_ratio == "dynamic_s2" and self.config.video_max_tiles > 1:
1149
+ self.config.image_processor = self.vision_tower.image_processor
1150
+ if type(self.config.s2_scales) is str:
1151
+ self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
1152
+ media[name] = [
1153
+ torch.cat(
1154
+ [
1155
+ process_image(
1156
+ image,
1157
+ self.config,
1158
+ None,
1159
+ enable_dynamic_s2=True,
1160
+ max_tiles=self.config.video_max_tiles,
1161
+ )[0].half()
1162
+ for image in images
1163
+ ]
1164
+ )
1165
+ for images in media[name]
1166
+ ]
1167
+ else:
1168
+ media[name] = [
1169
+ process_images(images, self.vision_tower.image_processor, self.config).half()
1170
+ for images in media[name]
1171
+ ]
1172
+ else:
1173
+ raise ValueError(f"Unsupported media type: {name}")
1174
+
1175
+ # Tokenize the conversation
1176
+ input_ids = tokenize_conversation(conversation, self.tokenizer, add_generation_prompt=True).cuda().unsqueeze(0)
1177
+
1178
+ # Set up the generation config
1179
+ generation_config = generation_config or self.default_generation_config
1180
+
1181
+ # print("input_ids", input_ids.shape)
1182
+ # print(input_ids)
1183
+ # print(self.tokenizer.batch_decode(input_ids))
1184
+ # print("media", {k: len(v) for k, v in media.items()})
1185
+ # print("media_config", media_config)
1186
+ # print("generation_config", generation_config)
1187
+ # input("wait for debug")
1188
+ # Generate the response
1189
+ try:
1190
+ output_ids = self.generate(
1191
+ input_ids=input_ids,
1192
+ media=media,
1193
+ media_config=media_config,
1194
+ generation_config=generation_config,
1195
+ logits_processor=xgr_logits_processor, # structured generation
1196
+ )
1197
+ except ValueError:
1198
+ if not generation_config.do_sample:
1199
+ raise
1200
+ # FIXME(zhijianl): This is a temporary workaround for the sampling issue
1201
+ logging.warning("Generation failed with sampling, retrying with greedy decoding.")
1202
+ generation_config.do_sample = False
1203
+ output_ids = self.generate(
1204
+ input_ids=input_ids,
1205
+ media=media,
1206
+ media_config=media_config,
1207
+ generation_config=generation_config,
1208
+ logits_processor=xgr_logits_processor,
1209
+ )
1210
+
1211
+ # Decode the response
1212
+ response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
1213
+ return response
1214
+
1215
+ @property
1216
+ def default_generation_config(self) -> GenerationConfig:
1217
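+ # Fill in missing generation defaults (max_length and PAD/BOS/EOS token ids) from the tokenizer.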
+ generation_config = copy.deepcopy(self.generation_config or GenerationConfig())
1218
+ if self.tokenizer.eos_token_id is None:
1219
+ raise ValueError("Tokenizer must have an EOS token")
1220
+ if generation_config.max_length == GenerationConfig().max_length:
1221
+ generation_config.max_length = self.tokenizer.model_max_length
1222
+ if generation_config.pad_token_id is None:
1223
+ generation_config.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
1224
+ if generation_config.bos_token_id is None:
1225
+ generation_config.bos_token_id = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
1226
+ if generation_config.eos_token_id is None:
1227
+ generation_config.eos_token_id = self.tokenizer.eos_token_id
1228
+ return generation_config
qwen2_jp.jinja ADDED
@@ -0,0 +1,11 @@
1
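+ {# Qwen2 ChatML-style chat template with a Japanese default system prompt, prepended only when the conversation does not start with a system message. #}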
+ {% if messages[0]['role'] != 'system' %}
2
+ {{ '<|im_start|>system\n以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。<|im_end|>\n' }}
3
+ {% endif %}
4
+
5
+ {% for message in messages if message['content'] is not none %}
6
+ {{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}
7
+ {% endfor %}
8
+
9
+ {% if add_generation_prompt %}
10
+ {{ '<|im_start|>assistant\n' }}
11
+ {% endif %}
siglip_encoder.py ADDED
@@ -0,0 +1,288 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from accelerate.hooks import add_hook_to_module
21
+ from einops import rearrange
22
+ from s2wrapper import forward as multiscale_forward
23
+ from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, SiglipImageProcessor
24
+ from transformers.image_processing_utils import BaseImageProcessor
25
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
26
+ from transformers.models.siglip import SiglipVisionModel
27
+
28
+
29
+ class VisionTower(nn.Module):
30
+ def __init__(self, vision_tower, args, delay_load=False):
31
+ super().__init__()
32
+
33
+ self.is_loaded = False
34
+
35
+ self.vision_tower_name = vision_tower
36
+ self.select_layer = getattr(args, "mm_vision_select_layer", -2)
37
+ self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
38
+
39
+ self.cfg_only = None
40
+
41
+ def feature_select(self, image_forward_outs):
42
+ image_features = image_forward_outs.hidden_states[self.select_layer]
43
+ if self.select_feature == "patch":
44
+ image_features = image_features[:, 1:]
45
+ elif self.select_feature == "cls_patch":
46
+ image_features = image_features
47
+ else:
48
+ raise ValueError(f"Unexpected select feature: {self.select_feature}")
49
+ return image_features
50
+
51
+ def _maybe_resize_pos_embeds(
52
+ self,
53
+ model: PreTrainedModel,
54
+ image_processor: BaseImageProcessor,
55
+ resolution: int = -1,
56
+ interpolate_mode: str = "linear",
57
+ ):
58
+ if resolution in [model.config.image_size, -1]:
59
+ return
60
+ print(
61
+ f"Resizing vision model's position embeddings to support higher vision resolution: from {model.config.image_size} to {resolution} ..."
62
+ )
63
+ embeddings = model.vision_model.embeddings
64
+ patch_size = embeddings.patch_size
65
+ num_new_tokens = int((resolution // patch_size) ** 2)
66
+
67
+ old_embeddings = embeddings.position_embedding
68
+ match interpolate_mode:
69
+ case "linear":
70
+ ## Step 1: Calculate the corresponding patch ID (pid) in the current resolution (M patches) based on the target resolution (N patches). Formula: pid = pid / N * M
71
+ ## Step 2: Obtain new embeddings by interpolating between the embeddings of the two nearest calculated patch IDs. Formula: new_embeds = (pid - floor(pid)) * embeds[ceil(pid)] + (ceil(pid) - pid) * embeds[floor(pid)]
72
+ import torch
73
+ import torch.nn as nn
74
+
75
+ if is_deepspeed_zero3_enabled():
76
+ try:
77
+ import deepspeed
78
+ except ImportError:
79
+ raise ImportError("DeepSpeed is not installed. Please install it with `pip install deepspeed`.")
80
+ with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None):
81
+ old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
82
+ else:
83
+ old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
84
+ new_embeddings = nn.Embedding(
85
+ num_new_tokens,
86
+ old_embedding_dim,
87
+ dtype=old_embeddings.weight.dtype,
88
+ device=old_embeddings.weight.device,
89
+ )
90
+ mapped_indices = (
91
+ torch.arange(num_new_tokens).to(old_embeddings.weight.device)
92
+ / (num_new_tokens - 1)
93
+ * (old_num_tokens - 1)
94
+ )
95
+ floor_indices = torch.clamp(mapped_indices.floor().long(), min=0, max=old_num_tokens - 1)
96
+ ceil_indices = torch.clamp(mapped_indices.ceil().long(), min=0, max=old_num_tokens - 1)
97
+ if is_deepspeed_zero3_enabled():
98
+ params = [old_embeddings.weight, new_embeddings.weight]
99
+ with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
100
+ interpolated_embeds = (mapped_indices - floor_indices)[:, None] * old_embeddings.weight.data[
101
+ ceil_indices, :
102
+ ] + (ceil_indices - mapped_indices)[:, None] * old_embeddings.weight.data[floor_indices, :]
103
+ else:
104
+ interpolated_embeds = (mapped_indices - floor_indices)[:, None] * old_embeddings.weight.data[
105
+ ceil_indices, :
106
+ ] + (ceil_indices - mapped_indices)[:, None] * old_embeddings.weight.data[floor_indices, :]
107
+ new_embeddings.weight.data = interpolated_embeds
108
+ case _:
109
+ raise NotImplementedError
110
+
111
+ if hasattr(old_embeddings, "_hf_hook"):
112
+ hook = old_embeddings._hf_hook
113
+ add_hook_to_module(new_embeddings, hook)
114
+ new_embeddings.requires_grad_(old_embeddings.weight.requires_grad)
115
+ ## update vision encoder's configurations
116
+ model.config.image_size = resolution
117
+ if hasattr(image_processor, "crop_size"):
118
+ # CLIP vision tower
119
+ image_processor.crop_size = resolution
120
+ else:
121
+ # SIGLIP vision tower
122
+ assert hasattr(image_processor, "size")
123
+ image_processor.size = {"height": resolution, "width": resolution}
124
+ ## TODO define a '_reinitialize' method for VisionTower
125
+ embeddings.position_embedding = new_embeddings
126
+ embeddings.image_size = resolution
127
+ embeddings.num_patches = embeddings.num_positions = num_new_tokens
128
+ embeddings.position_ids = (
129
+ torch.arange(embeddings.num_positions).expand((1, -1)).to(old_embeddings.weight.device)
130
+ )
131
+
132
+ def forward(self, images):
133
+ if type(images) is list:
134
+ image_features = []
135
+ for image in images:
136
+ image_forward_out = self.vision_tower(
137
+ image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
138
+ output_hidden_states=True,
139
+ )
140
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
141
+ image_features.append(image_feature)
142
+ else:
143
+ image_forward_outs = self.vision_tower(
144
+ images.to(device=self.device, dtype=self.dtype),
145
+ output_hidden_states=True,
146
+ )
147
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
148
+
149
+ return image_features
150
+
151
+ @property
152
+ def dummy_feature(self):
153
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
154
+
155
+ @property
156
+ def dtype(self):
157
+ return self.vision_tower.dtype
158
+
159
+ @property
160
+ def device(self):
161
+ return self.vision_tower.device
162
+
163
+ @property
164
+ def config(self):
165
+ if self.is_loaded:
166
+ return self.vision_tower.config
167
+ else:
168
+ return self.cfg_only
169
+
170
+ @property
171
+ def hidden_size(self):
172
+ return self.config.hidden_size
173
+
174
+ @property
175
+ def num_patches(self):
176
+ return (self.config.image_size // self.config.patch_size) ** 2
177
+
178
+
179
+ class VisionTowerS2(VisionTower):
180
+ def __init__(self, vision_tower, args, delay_load=False):
181
+ super().__init__(vision_tower, args, delay_load)
182
+
183
+ self.scales = list(map(int, args.s2_scales.split(",")))
184
+ self.scales.sort()
185
+ self.max_split_size = args.s2_max_split_size
186
+ self.resize_output_to_scale_idx = getattr(args, "s2_resize_output_to_scale_idx", 0)
187
+
188
+ def forward_feature(self, images):
189
+ image_forward_outs = self.vision_tower(
190
+ images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
191
+ )
192
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
193
+ return image_features
194
+
195
+ def forward(self, images):
196
+ if type(images) is list:
197
+ image_feature = []
198
+ for image in images:
199
+ image_feature = multiscale_forward(
200
+ self.forward_feature,
201
+ image.unsqueeze(0),
202
+ img_sizes=self.scales,
203
+ max_split_size=self.max_split_size,
204
+ resize_output_to_idx=self.resize_output_to_scale_idx,
205
+ )
206
+ image_features.append(image_feature)
207
+ else:
208
+ image_features = multiscale_forward(
209
+ self.forward_feature,
210
+ images,
211
+ img_sizes=self.scales,
212
+ max_split_size=self.max_split_size,
213
+ resize_output_to_idx=self.resize_output_to_scale_idx,
214
+ )
215
+
216
+ return image_features
217
+
218
+ @property
219
+ def hidden_size(self):
220
+ return self.config.hidden_size * len(self.scales)
221
+
222
+
223
+ class VisionTowerDynamicS2(VisionTower):
224
+ def __init__(self, vision_tower, args, delay_load=False):
225
+ super().__init__(vision_tower, args, delay_load)
226
+
227
+ self.scales = list(map(int, args.s2_scales.split(",")))
228
+ self.scales.sort()
229
+ self.max_split_size = args.s2_max_split_size
230
+ self.resize_output_to_scale_idx = getattr(args, "s2_resize_output_to_scale_idx", 0)
231
+
232
+ def forward_feature(self, images):
233
+ image_forward_outs = self.vision_tower(
234
+ images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
235
+ )
236
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
237
+ return image_features
238
+
239
+ def forward(self, images):
240
+ assert type(images) is not list
241
+ image_features = self.forward_feature(images)
242
+
243
+ return image_features
244
+
245
+ @property
246
+ def hidden_size(self):
247
+ return self.config.hidden_size * len(self.scales)
248
+
249
+
250
+ class SiglipVisionTower(VisionTower):
251
+ def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
252
+ super().__init__(model_name_or_path, config)
253
+ # TODO(ligengl): why does passing config here lead to errors?
254
+ self.vision_tower = SiglipVisionModel.from_pretrained(
255
+ model_name_or_path,
256
+ attn_implementation=config._attn_implementation,
257
+ torch_dtype=eval(config.model_dtype),
258
+ )
259
+ self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
260
+ self.is_loaded = True
261
+
262
+
263
+ class SiglipVisionTowerS2(VisionTowerS2):
264
+ def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
265
+ super().__init__(model_name_or_path, config)
266
+ self.vision_tower = SiglipVisionModel.from_pretrained(
267
+ model_name_or_path,
268
+ attn_implementation=config._attn_implementation,
269
+ torch_dtype=eval(config.model_dtype),
270
+ )
271
+ self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
272
+ # Make sure it crops/resizes the image to the largest scale in self.scales to maintain high-res information
273
+ self.image_processor.size["height"] = self.image_processor.size["width"] = self.scales[-1]
274
+ self.is_loaded = True
275
+
276
+
277
+ class SiglipVisionTowerDynamicS2(VisionTowerDynamicS2):
278
+ def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
279
+ super().__init__(model_name_or_path, config)
280
+ self.vision_tower = SiglipVisionModel.from_pretrained(
281
+ model_name_or_path,
282
+ attn_implementation="flash_attention_2",
283
+ torch_dtype=eval(config.model_dtype),
284
+ )
285
+ self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
286
+ # Resize/crop to the smallest (base) scale in self.scales; the higher scales are handled by dynamic tiling
287
+ self.image_processor.size["height"] = self.image_processor.size["width"] = self.scales[0]
288
+ self.is_loaded = True
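
The S2 towers above report `hidden_size = config.hidden_size * len(self.scales)` because `multiscale_forward` concatenates the features extracted at each scale along the channel dimension. Below is a minimal self-contained sketch of that idea; a toy patch embedder stands in for the SigLIP tower, and `PATCH`, `HIDDEN`, and `toy_multiscale_forward` are illustrative names, not the real `s2wrapper` API.

```python
import torch
import torch.nn.functional as F

PATCH = 14    # assumed patch size (SigLIP-so400m-patch14 style)
HIDDEN = 32   # toy hidden size; the real tower uses config.hidden_size

# Toy "vision tower": a patch embedding that returns (B, num_patches, HIDDEN).
patch_embed = torch.nn.Conv2d(3, HIDDEN, kernel_size=PATCH, stride=PATCH)

def encode(images):
    f = patch_embed(images)              # (B, HIDDEN, H/PATCH, W/PATCH)
    return f.flatten(2).transpose(1, 2)  # (B, N, HIDDEN)

def toy_multiscale_forward(forward_fn, images, img_sizes):
    base_tokens = img_sizes[0] // PATCH  # token-grid side of the smallest scale
    feats = []
    for size in img_sizes:
        x = F.interpolate(images, size=(size, size), mode="bilinear", align_corners=False)
        f = forward_fn(x)                # (B, N_size, HIDDEN)
        side = size // PATCH
        f = f.transpose(1, 2).reshape(f.shape[0], HIDDEN, side, side)
        f = F.adaptive_avg_pool2d(f, base_tokens)   # pool back to the base grid
        feats.append(f.flatten(2).transpose(1, 2))  # (B, N_base, HIDDEN)
    return torch.cat(feats, dim=-1)      # (B, N_base, HIDDEN * len(img_sizes))

x = torch.randn(1, 3, 896, 896)
out = toy_multiscale_forward(encode, x, img_sizes=[448, 896])
print(out.shape)  # torch.Size([1, 1024, 64]): 32 * 2 channels over a 32x32 token grid
```

Pooling every scale back to the base token grid keeps the sequence length fixed, so only the channel width grows with the number of scales.
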
tokenizer_utils.py ADDED
@@ -0,0 +1,182 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ from typing import Any, Dict, List, Optional, Sequence
18
+
19
+ import torch
20
+ import transformers
21
+
22
+ from .constants import IGNORE_INDEX, SENTINEL_TOKEN
23
+ from .conversation import SeparatorStyle, default_conversation
24
+ from .mm_utils import tokenizer_image_token
25
+
26
+ # __all__ = [
27
+ # "tokenize_conversation",
28
+ # "preprocess_conversation",
29
+ # "infer_stop_tokens",
30
+ # ]
31
+
32
+ DUMMY_CONVERSATION = [
33
+ {"from": "human", "value": "question"},
34
+ {"from": "gpt", "value": "answer"},
35
+ ] * 10
36
+
37
+
38
+ def tokenize_conversation_legacy(
39
+ messages: Sequence[Dict[str, str]],
40
+ tokenizer: transformers.PreTrainedTokenizer,
41
+ add_generation_prompt: bool = False,
42
+ overrides: Optional[Dict[str, str]] = None,
43
+ no_system_prompt: bool = False,
44
+ ) -> torch.Tensor:
45
+ conv = default_conversation.copy()
46
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
47
+
48
+ if no_system_prompt:
49
+ conv.system = ""
50
+
51
+ # Skip the first message if it is not from human
52
+ if messages[0]["from"] != "human":
53
+ messages = messages[1:]
54
+
55
+ # Add a generation prompt if needed
56
+ if add_generation_prompt:
57
+ messages.append({"from": "gpt", "value": None})
58
+
59
+ conv.messages = []
60
+ for turn, message in enumerate(messages):
61
+ role = roles[message["from"]]
62
+ assert role == conv.roles[turn % 2]
63
+ if overrides is not None and message["from"] in overrides:
64
+ conv.append_message(role, overrides[message["from"]])
65
+ else:
66
+ conv.append_message(role, message["value"])
67
+
68
+ return tokenizer_image_token(conv.get_prompt(), tokenizer, return_tensors="pt")
69
+
70
+
71
+ def tokenize_conversation(
72
+ messages: Sequence[Dict[str, str]],
73
+ tokenizer: transformers.PreTrainedTokenizer,
74
+ add_generation_prompt: bool = False,
75
+ overrides: Optional[Dict[str, str]] = None,
76
+ no_system_prompt: bool = False,
77
+ ) -> torch.Tensor:
78
+ # Normalize the conversation before tokenization
79
+ for message in messages:
80
+ message["value"] = message["value"].strip()
81
+
82
+ if default_conversation.sep_style != SeparatorStyle.AUTO:
83
+ return tokenize_conversation_legacy(
84
+ messages,
85
+ tokenizer,
86
+ add_generation_prompt=add_generation_prompt,
87
+ overrides=overrides,
88
+ no_system_prompt=no_system_prompt,
89
+ )
90
+
91
+ conversation = []
92
+ for m in messages:
93
+ message = {}
94
+ if m["from"] == "human":
95
+ message["role"] = "user"
96
+ elif m["from"] == "gpt":
97
+ message["role"] = "assistant"
98
+ else:
99
+ raise ValueError(f"Unexpected sender '{m['from']}' in conversation entry.")
100
+
101
+ message["content"] = m["value"]
102
+ if overrides is not None and m["from"] in overrides:
103
+ message["content"] = overrides[m["from"]]
104
+ conversation.append(message)
105
+
106
+ if no_system_prompt:
107
+ conversation = [{"role": "system", "content": ""}] + conversation
108
+
109
+ text = tokenizer.apply_chat_template(
110
+ conversation,
111
+ add_generation_prompt=add_generation_prompt,
112
+ tokenize=False,
113
+ )
114
+ return tokenizer_image_token(text, tokenizer, return_tensors="pt")
115
+
116
+
117
+ def _maybe_add_sentinel_token(tokenizer: transformers.PreTrainedTokenizer) -> None:
118
+ if not hasattr(tokenizer, "sentinel_token"):
119
+ tokenizer.add_tokens([SENTINEL_TOKEN], special_tokens=True)
120
+ tokenizer.sentinel_token = SENTINEL_TOKEN
121
+ tokenizer.sentinel_token_id = tokenizer.convert_tokens_to_ids(SENTINEL_TOKEN)
122
+
123
+
124
+ def preprocess_conversation(
125
+ conversation: Sequence[Dict[str, str]],
126
+ tokenizer: transformers.PreTrainedTokenizer,
127
+ no_system_prompt: bool = False,
128
+ retried: bool = False,
129
+ ) -> Dict[str, Any]:
130
+ inputs = tokenize_conversation(conversation, tokenizer, no_system_prompt=no_system_prompt)
131
+ labels = torch.ones_like(inputs) * IGNORE_INDEX
132
+
133
+ # Generate the template by replacing the assistant's response with a sentinel.
134
+ _maybe_add_sentinel_token(tokenizer)
135
+ template = tokenize_conversation(
136
+ conversation, tokenizer, overrides={"gpt": SENTINEL_TOKEN}, no_system_prompt=no_system_prompt
137
+ )
138
+
139
+ # Remove sentinel tokens from the template.
140
+ mask = torch.ones_like(template, dtype=torch.bool)
141
+ for k in range(template.size(0) - 1):
142
+ if template[k] == tokenizer.sentinel_token_id:
143
+ mask[k : k + 2] = False
144
+ # NOTE(zhijianl): This is to handle the corner case where there is an empty token before the sentinel token.
145
+ if k > 0 and retried:
146
+ mask[k - 1] = False
147
+ template = template[mask]
148
+
149
+ # Match the tokenized conversation with the template (with no assistant's response).
150
+ # Every token that is not matched will be included in the label for training.
151
+ p = 0
152
+ for k in range(inputs.size(0)):
153
+ if p < template.size(0) and inputs[k] == template[p]:
154
+ p += 1
155
+ else:
156
+ labels[k] = inputs[k]
157
+
158
+ # Mask all tokens in the label if the template is not fully matched.
159
+ if p < template.size(0):
160
+ if not retried:
161
+ return preprocess_conversation(
162
+ conversation,
163
+ tokenizer,
164
+ no_system_prompt=no_system_prompt,
165
+ retried=True,
166
+ )
167
+ print(f"Failed to process the conversation: '{conversation}'. All tokens will be masked in the label.")
168
+ labels[:] = IGNORE_INDEX
169
+
170
+ return {"input_ids": inputs, "labels": labels}
171
+
172
+
173
+ def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]:
174
+ _maybe_add_sentinel_token(tokenizer)
175
+ template = tokenize_conversation(DUMMY_CONVERSATION, tokenizer, overrides={"gpt": SENTINEL_TOKEN})
176
+
177
+ stop_tokens = {tokenizer.eos_token}
178
+ for k in range(template.size(0) - 1):
179
+ if template[k] == tokenizer.sentinel_token_id:
180
+ stop_token = tokenizer.decode(template[k + 1])
181
+ stop_tokens.add(stop_token)
182
+ return list(stop_tokens)
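
`preprocess_conversation` above derives the training labels by tokenizing the conversation twice: once verbatim and once with every assistant turn replaced by `SENTINEL_TOKEN`, then removing the sentinel (and the token after it) from the template and matching the two sequences with a single forward scan. A simplified sketch of that scan on plain token-id lists follows; the `IGNORE_INDEX = -100` value is the usual Hugging Face convention and is assumed here (the real value comes from `.constants`).

```python
IGNORE_INDEX = -100  # assumed value; imported from .constants in the module above

def mask_labels(inputs, template):
    """Tokens that also appear, in order, in the sentinel-stripped template
    (system prompt, role markers, user turns) are ignored; everything else
    (the assistant responses and their stop tokens) is kept as a target."""
    labels = [IGNORE_INDEX] * len(inputs)
    p = 0
    for k, tok in enumerate(inputs):
        if p < len(template) and tok == template[p]:
            p += 1            # shared chat scaffolding -> ignored
        else:
            labels[k] = tok   # assistant content -> supervised
    return labels

# inputs:   <user> 11 12 <assistant> 21 22 <eos>
# template: <user> 11 12 <assistant>           (sentinel and the token after it removed)
print(mask_labels([1, 11, 12, 2, 21, 22, 9], [1, 11, 12, 2]))
# [-100, -100, -100, -100, 21, 22, 9]
```

Because the sentinel and its following token were dropped from the template, the assistant's stop token stays supervised together with the response itself.
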
trainer_state.json ADDED
@@ -0,0 +1,3311 @@
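
The `log_history` entries below record the per-step training loss, gradient norm, and learning rate. A minimal sketch for pulling the loss curve out of this file, assuming the standard Hugging Face `trainer_state.json` layout shown here:

```python
import json

# Read trainer_state.json and collect the per-step training loss from log_history.
with open("trainer_state.json") as f:
    state = json.load(f)

steps, losses = [], []
for entry in state["log_history"]:
    if "loss" in entry:   # skip eval/summary entries that carry no training loss
        steps.append(entry["step"])
        losses.append(entry["loss"])

print(f"{len(steps)} logged steps; first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```
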
1
+ {
2
+ "best_metric": null,
3
+ "best_model_checkpoint": null,
4
+ "epoch": 1.0,
5
+ "eval_steps": 500,
6
+ "global_step": 467,
7
+ "is_hyper_param_search": false,
8
+ "is_local_process_zero": true,
9
+ "is_world_process_zero": true,
10
+ "log_history": [
11
+ {
12
+ "epoch": 0.0,
13
+ "grad_norm": 11.039007186889648,
14
+ "learning_rate": 1.3333333333333334e-06,
15
+ "loss": 1.7243,
16
+ "step": 1
17
+ },
18
+ {
19
+ "epoch": 0.0,
20
+ "grad_norm": 11.325847625732422,
21
+ "learning_rate": 2.666666666666667e-06,
22
+ "loss": 1.7232,
23
+ "step": 2
24
+ },
25
+ {
26
+ "epoch": 0.01,
27
+ "grad_norm": 11.024140357971191,
28
+ "learning_rate": 4.000000000000001e-06,
29
+ "loss": 1.7473,
30
+ "step": 3
31
+ },
32
+ {
33
+ "epoch": 0.01,
34
+ "grad_norm": 8.857377052307129,
35
+ "learning_rate": 5.333333333333334e-06,
36
+ "loss": 1.5677,
37
+ "step": 4
38
+ },
39
+ {
40
+ "epoch": 0.01,
41
+ "grad_norm": 5.129051685333252,
42
+ "learning_rate": 6.666666666666667e-06,
43
+ "loss": 1.3132,
44
+ "step": 5
45
+ },
46
+ {
47
+ "epoch": 0.01,
48
+ "grad_norm": 3.457620143890381,
49
+ "learning_rate": 8.000000000000001e-06,
50
+ "loss": 1.2985,
51
+ "step": 6
52
+ },
53
+ {
54
+ "epoch": 0.01,
55
+ "grad_norm": 2.502241373062134,
56
+ "learning_rate": 9.333333333333334e-06,
57
+ "loss": 1.1922,
58
+ "step": 7
59
+ },
60
+ {
61
+ "epoch": 0.02,
62
+ "grad_norm": 2.6525237560272217,
63
+ "learning_rate": 1.0666666666666667e-05,
64
+ "loss": 1.1783,
65
+ "step": 8
66
+ },
67
+ {
68
+ "epoch": 0.02,
69
+ "grad_norm": 2.580990791320801,
70
+ "learning_rate": 1.2e-05,
71
+ "loss": 1.1252,
72
+ "step": 9
73
+ },
74
+ {
75
+ "epoch": 0.02,
76
+ "grad_norm": 2.4445464611053467,
77
+ "learning_rate": 1.3333333333333333e-05,
78
+ "loss": 1.1204,
79
+ "step": 10
80
+ },
81
+ {
82
+ "epoch": 0.02,
83
+ "grad_norm": 2.5538313388824463,
84
+ "learning_rate": 1.4666666666666666e-05,
85
+ "loss": 1.0808,
86
+ "step": 11
87
+ },
88
+ {
89
+ "epoch": 0.03,
90
+ "grad_norm": 2.922621488571167,
91
+ "learning_rate": 1.6000000000000003e-05,
92
+ "loss": 1.0484,
93
+ "step": 12
94
+ },
95
+ {
96
+ "epoch": 0.03,
97
+ "grad_norm": 1.6075185537338257,
98
+ "learning_rate": 1.7333333333333336e-05,
99
+ "loss": 1.0798,
100
+ "step": 13
101
+ },
102
+ {
103
+ "epoch": 0.03,
104
+ "grad_norm": 2.0998339653015137,
105
+ "learning_rate": 1.866666666666667e-05,
106
+ "loss": 1.023,
107
+ "step": 14
108
+ },
109
+ {
110
+ "epoch": 0.03,
111
+ "grad_norm": 1.311397910118103,
112
+ "learning_rate": 2e-05,
113
+ "loss": 1.0424,
114
+ "step": 15
115
+ },
116
+ {
117
+ "epoch": 0.03,
118
+ "grad_norm": 1.4649641513824463,
119
+ "learning_rate": 1.9999758458848847e-05,
120
+ "loss": 0.9873,
121
+ "step": 16
122
+ },
123
+ {
124
+ "epoch": 0.04,
125
+ "grad_norm": 1.5159320831298828,
126
+ "learning_rate": 1.9999033847063813e-05,
127
+ "loss": 1.0423,
128
+ "step": 17
129
+ },
130
+ {
131
+ "epoch": 0.04,
132
+ "grad_norm": 1.6150208711624146,
133
+ "learning_rate": 1.9997826199649607e-05,
134
+ "loss": 0.9522,
135
+ "step": 18
136
+ },
137
+ {
138
+ "epoch": 0.04,
139
+ "grad_norm": 2.5012216567993164,
140
+ "learning_rate": 1.9996135574945543e-05,
141
+ "loss": 0.9858,
142
+ "step": 19
143
+ },
144
+ {
145
+ "epoch": 0.04,
146
+ "grad_norm": 1.7912406921386719,
147
+ "learning_rate": 1.9993962054622703e-05,
148
+ "loss": 0.966,
149
+ "step": 20
150
+ },
151
+ {
152
+ "epoch": 0.04,
153
+ "grad_norm": 1.5078647136688232,
154
+ "learning_rate": 1.9991305743680013e-05,
155
+ "loss": 0.9418,
156
+ "step": 21
157
+ },
158
+ {
159
+ "epoch": 0.05,
160
+ "grad_norm": 1.0531651973724365,
161
+ "learning_rate": 1.9988166770439156e-05,
162
+ "loss": 0.9789,
163
+ "step": 22
164
+ },
165
+ {
166
+ "epoch": 0.05,
167
+ "grad_norm": 1.525269865989685,
168
+ "learning_rate": 1.9984545286538362e-05,
169
+ "loss": 0.9383,
170
+ "step": 23
171
+ },
172
+ {
173
+ "epoch": 0.05,
174
+ "grad_norm": 1.369185447692871,
175
+ "learning_rate": 1.9980441466925118e-05,
176
+ "loss": 0.9662,
177
+ "step": 24
178
+ },
179
+ {
180
+ "epoch": 0.05,
181
+ "grad_norm": 1.1335804462432861,
182
+ "learning_rate": 1.9975855509847688e-05,
183
+ "loss": 0.9393,
184
+ "step": 25
185
+ },
186
+ {
187
+ "epoch": 0.06,
188
+ "grad_norm": 1.4465155601501465,
189
+ "learning_rate": 1.9970787636845536e-05,
190
+ "loss": 0.933,
191
+ "step": 26
192
+ },
193
+ {
194
+ "epoch": 0.06,
195
+ "grad_norm": 1.7765053510665894,
196
+ "learning_rate": 1.9965238092738643e-05,
197
+ "loss": 0.9219,
198
+ "step": 27
199
+ },
200
+ {
201
+ "epoch": 0.06,
202
+ "grad_norm": 0.8634375333786011,
203
+ "learning_rate": 1.9959207145615663e-05,
204
+ "loss": 0.9462,
205
+ "step": 28
206
+ },
207
+ {
208
+ "epoch": 0.06,
209
+ "grad_norm": 1.3061445951461792,
210
+ "learning_rate": 1.9952695086820975e-05,
211
+ "loss": 0.8913,
212
+ "step": 29
213
+ },
214
+ {
215
+ "epoch": 0.06,
216
+ "grad_norm": 1.3201128244400024,
217
+ "learning_rate": 1.9945702230940616e-05,
218
+ "loss": 0.9069,
219
+ "step": 30
220
+ },
221
+ {
222
+ "epoch": 0.07,
223
+ "grad_norm": 1.1161390542984009,
224
+ "learning_rate": 1.993822891578708e-05,
225
+ "loss": 0.914,
226
+ "step": 31
227
+ },
228
+ {
229
+ "epoch": 0.07,
230
+ "grad_norm": 1.1489887237548828,
231
+ "learning_rate": 1.9930275502382993e-05,
232
+ "loss": 0.8876,
233
+ "step": 32
234
+ },
235
+ {
236
+ "epoch": 0.07,
237
+ "grad_norm": 1.072081446647644,
238
+ "learning_rate": 1.9921842374943682e-05,
239
+ "loss": 0.9394,
240
+ "step": 33
241
+ },
242
+ {
243
+ "epoch": 0.07,
244
+ "grad_norm": 1.204382061958313,
245
+ "learning_rate": 1.9912929940858607e-05,
246
+ "loss": 0.8852,
247
+ "step": 34
248
+ },
249
+ {
250
+ "epoch": 0.07,
251
+ "grad_norm": 1.0732938051223755,
252
+ "learning_rate": 1.9903538630671687e-05,
253
+ "loss": 0.9019,
254
+ "step": 35
255
+ },
256
+ {
257
+ "epoch": 0.08,
258
+ "grad_norm": 1.0138473510742188,
259
+ "learning_rate": 1.9893668898060504e-05,
260
+ "loss": 0.8915,
261
+ "step": 36
262
+ },
263
+ {
264
+ "epoch": 0.08,
265
+ "grad_norm": 1.2495840787887573,
266
+ "learning_rate": 1.988332121981436e-05,
267
+ "loss": 0.8955,
268
+ "step": 37
269
+ },
270
+ {
271
+ "epoch": 0.08,
272
+ "grad_norm": 1.1097376346588135,
273
+ "learning_rate": 1.9872496095811287e-05,
274
+ "loss": 0.8872,
275
+ "step": 38
276
+ },
277
+ {
278
+ "epoch": 0.08,
279
+ "grad_norm": 1.0911654233932495,
280
+ "learning_rate": 1.9861194048993865e-05,
281
+ "loss": 0.9061,
282
+ "step": 39
283
+ },
284
+ {
285
+ "epoch": 0.09,
286
+ "grad_norm": 1.078086018562317,
287
+ "learning_rate": 1.9849415625343972e-05,
288
+ "loss": 0.8869,
289
+ "step": 40
290
+ },
291
+ {
292
+ "epoch": 0.09,
293
+ "grad_norm": 1.57882821559906,
294
+ "learning_rate": 1.9837161393856413e-05,
295
+ "loss": 0.8587,
296
+ "step": 41
297
+ },
298
+ {
299
+ "epoch": 0.09,
300
+ "grad_norm": 1.0213719606399536,
301
+ "learning_rate": 1.982443194651142e-05,
302
+ "loss": 0.9093,
303
+ "step": 42
304
+ },
305
+ {
306
+ "epoch": 0.09,
307
+ "grad_norm": 1.8046919107437134,
308
+ "learning_rate": 1.9811227898246072e-05,
309
+ "loss": 0.8551,
310
+ "step": 43
311
+ },
312
+ {
313
+ "epoch": 0.09,
314
+ "grad_norm": 1.0796761512756348,
315
+ "learning_rate": 1.979754988692457e-05,
316
+ "loss": 0.9138,
317
+ "step": 44
318
+ },
319
+ {
320
+ "epoch": 0.1,
321
+ "grad_norm": 1.517764687538147,
322
+ "learning_rate": 1.978339857330743e-05,
323
+ "loss": 0.8252,
324
+ "step": 45
325
+ },
326
+ {
327
+ "epoch": 0.1,
328
+ "grad_norm": 1.3809912204742432,
329
+ "learning_rate": 1.976877464101957e-05,
330
+ "loss": 0.8894,
331
+ "step": 46
332
+ },
333
+ {
334
+ "epoch": 0.1,
335
+ "grad_norm": 1.5511187314987183,
336
+ "learning_rate": 1.975367879651728e-05,
337
+ "loss": 0.8437,
338
+ "step": 47
339
+ },
340
+ {
341
+ "epoch": 0.1,
342
+ "grad_norm": 1.6329996585845947,
343
+ "learning_rate": 1.9738111769054095e-05,
344
+ "loss": 0.9215,
345
+ "step": 48
346
+ },
347
+ {
348
+ "epoch": 0.1,
349
+ "grad_norm": 1.3756284713745117,
350
+ "learning_rate": 1.9722074310645553e-05,
351
+ "loss": 0.8401,
352
+ "step": 49
353
+ },
354
+ {
355
+ "epoch": 0.11,
356
+ "grad_norm": 1.7832353115081787,
357
+ "learning_rate": 1.9705567196032894e-05,
358
+ "loss": 0.8396,
359
+ "step": 50
360
+ },
361
+ {
362
+ "epoch": 0.11,
363
+ "grad_norm": 1.3009949922561646,
364
+ "learning_rate": 1.9688591222645607e-05,
365
+ "loss": 0.8627,
366
+ "step": 51
367
+ },
368
+ {
369
+ "epoch": 0.11,
370
+ "grad_norm": 1.448974847793579,
371
+ "learning_rate": 1.9671147210562925e-05,
372
+ "loss": 0.858,
373
+ "step": 52
374
+ },
375
+ {
376
+ "epoch": 0.11,
377
+ "grad_norm": 1.298194169998169,
378
+ "learning_rate": 1.9653236002474202e-05,
379
+ "loss": 0.8495,
380
+ "step": 53
381
+ },
382
+ {
383
+ "epoch": 0.12,
384
+ "grad_norm": 1.0985174179077148,
385
+ "learning_rate": 1.96348584636382e-05,
386
+ "loss": 0.8706,
387
+ "step": 54
388
+ },
389
+ {
390
+ "epoch": 0.12,
391
+ "grad_norm": 1.7281138896942139,
392
+ "learning_rate": 1.9616015481841293e-05,
393
+ "loss": 0.8665,
394
+ "step": 55
395
+ },
396
+ {
397
+ "epoch": 0.12,
398
+ "grad_norm": 1.2161897420883179,
399
+ "learning_rate": 1.9596707967354588e-05,
400
+ "loss": 0.8657,
401
+ "step": 56
402
+ },
403
+ {
404
+ "epoch": 0.12,
405
+ "grad_norm": 1.1948484182357788,
406
+ "learning_rate": 1.9576936852889937e-05,
407
+ "loss": 0.8545,
408
+ "step": 57
409
+ },
410
+ {
411
+ "epoch": 0.12,
412
+ "grad_norm": 1.8918001651763916,
413
+ "learning_rate": 1.955670309355489e-05,
414
+ "loss": 0.8358,
415
+ "step": 58
416
+ },
417
+ {
418
+ "epoch": 0.13,
419
+ "grad_norm": 1.1286191940307617,
420
+ "learning_rate": 1.9536007666806555e-05,
421
+ "loss": 0.8407,
422
+ "step": 59
423
+ },
424
+ {
425
+ "epoch": 0.13,
426
+ "grad_norm": 1.198012113571167,
427
+ "learning_rate": 1.951485157240437e-05,
428
+ "loss": 0.8662,
429
+ "step": 60
430
+ },
431
+ {
432
+ "epoch": 0.13,
433
+ "grad_norm": 2.0095624923706055,
434
+ "learning_rate": 1.9493235832361812e-05,
435
+ "loss": 0.8681,
436
+ "step": 61
437
+ },
438
+ {
439
+ "epoch": 0.13,
440
+ "grad_norm": 1.1153709888458252,
441
+ "learning_rate": 1.9471161490897027e-05,
442
+ "loss": 0.8658,
443
+ "step": 62
444
+ },
445
+ {
446
+ "epoch": 0.13,
447
+ "grad_norm": 1.3872712850570679,
448
+ "learning_rate": 1.9448629614382394e-05,
449
+ "loss": 0.822,
450
+ "step": 63
451
+ },
452
+ {
453
+ "epoch": 0.14,
454
+ "grad_norm": 1.0818780660629272,
455
+ "learning_rate": 1.942564129129298e-05,
456
+ "loss": 0.9052,
457
+ "step": 64
458
+ },
459
+ {
460
+ "epoch": 0.14,
461
+ "grad_norm": 1.1288385391235352,
462
+ "learning_rate": 1.940219763215399e-05,
463
+ "loss": 0.8246,
464
+ "step": 65
465
+ },
466
+ {
467
+ "epoch": 0.14,
468
+ "grad_norm": 0.9448270797729492,
469
+ "learning_rate": 1.9378299769487116e-05,
470
+ "loss": 0.856,
471
+ "step": 66
472
+ },
473
+ {
474
+ "epoch": 0.14,
475
+ "grad_norm": 0.8516116142272949,
476
+ "learning_rate": 1.93539488577558e-05,
477
+ "loss": 0.8436,
478
+ "step": 67
479
+ },
480
+ {
481
+ "epoch": 0.15,
482
+ "grad_norm": 0.9422905445098877,
483
+ "learning_rate": 1.9329146073309502e-05,
484
+ "loss": 0.8396,
485
+ "step": 68
486
+ },
487
+ {
488
+ "epoch": 0.15,
489
+ "grad_norm": 0.8786196112632751,
490
+ "learning_rate": 1.9303892614326835e-05,
491
+ "loss": 0.8769,
492
+ "step": 69
493
+ },
494
+ {
495
+ "epoch": 0.15,
496
+ "grad_norm": 1.207822322845459,
497
+ "learning_rate": 1.9278189700757717e-05,
498
+ "loss": 0.8053,
499
+ "step": 70
500
+ },
501
+ {
502
+ "epoch": 0.15,
503
+ "grad_norm": 1.005181074142456,
504
+ "learning_rate": 1.9252038574264403e-05,
505
+ "loss": 0.8608,
506
+ "step": 71
507
+ },
508
+ {
509
+ "epoch": 0.15,
510
+ "grad_norm": 1.247426986694336,
511
+ "learning_rate": 1.9225440498161544e-05,
512
+ "loss": 0.8336,
513
+ "step": 72
514
+ },
515
+ {
516
+ "epoch": 0.16,
517
+ "grad_norm": 0.9933120012283325,
518
+ "learning_rate": 1.9198396757355118e-05,
519
+ "loss": 0.8575,
520
+ "step": 73
521
+ },
522
+ {
523
+ "epoch": 0.16,
524
+ "grad_norm": 0.9208722114562988,
525
+ "learning_rate": 1.9170908658280388e-05,
526
+ "loss": 0.8066,
527
+ "step": 74
528
+ },
529
+ {
530
+ "epoch": 0.16,
531
+ "grad_norm": 0.8881359100341797,
532
+ "learning_rate": 1.9142977528838763e-05,
533
+ "loss": 0.8786,
534
+ "step": 75
535
+ },
536
+ {
537
+ "epoch": 0.16,
538
+ "grad_norm": 1.1525728702545166,
539
+ "learning_rate": 1.911460471833368e-05,
540
+ "loss": 0.8305,
541
+ "step": 76
542
+ },
543
+ {
544
+ "epoch": 0.16,
545
+ "grad_norm": 1.4480865001678467,
546
+ "learning_rate": 1.9085791597405404e-05,
547
+ "loss": 0.8406,
548
+ "step": 77
549
+ },
550
+ {
551
+ "epoch": 0.17,
552
+ "grad_norm": 0.8593180179595947,
553
+ "learning_rate": 1.9056539557964814e-05,
554
+ "loss": 0.8806,
555
+ "step": 78
556
+ },
557
+ {
558
+ "epoch": 0.17,
559
+ "grad_norm": 0.9452027082443237,
560
+ "learning_rate": 1.902685001312616e-05,
561
+ "loss": 0.8047,
562
+ "step": 79
563
+ },
564
+ {
565
+ "epoch": 0.17,
566
+ "grad_norm": 1.3369029760360718,
567
+ "learning_rate": 1.8996724397138813e-05,
568
+ "loss": 0.8317,
569
+ "step": 80
570
+ },
571
+ {
572
+ "epoch": 0.17,
573
+ "grad_norm": 0.8937678337097168,
574
+ "learning_rate": 1.8966164165317968e-05,
575
+ "loss": 0.8348,
576
+ "step": 81
577
+ },
578
+ {
579
+ "epoch": 0.18,
580
+ "grad_norm": 1.0756009817123413,
581
+ "learning_rate": 1.8935170793974335e-05,
582
+ "loss": 0.8271,
583
+ "step": 82
584
+ },
585
+ {
586
+ "epoch": 0.18,
587
+ "grad_norm": 0.8728197813034058,
588
+ "learning_rate": 1.8903745780342838e-05,
589
+ "loss": 0.8578,
590
+ "step": 83
591
+ },
592
+ {
593
+ "epoch": 0.18,
594
+ "grad_norm": 1.3119213581085205,
595
+ "learning_rate": 1.887189064251027e-05,
596
+ "loss": 0.7796,
597
+ "step": 84
598
+ },
599
+ {
600
+ "epoch": 0.18,
601
+ "grad_norm": 1.1723086833953857,
602
+ "learning_rate": 1.883960691934196e-05,
603
+ "loss": 0.8497,
604
+ "step": 85
605
+ },
606
+ {
607
+ "epoch": 0.18,
608
+ "grad_norm": 1.2870450019836426,
609
+ "learning_rate": 1.8806896170407437e-05,
610
+ "loss": 0.8096,
611
+ "step": 86
612
+ },
613
+ {
614
+ "epoch": 0.19,
615
+ "grad_norm": 1.0081167221069336,
616
+ "learning_rate": 1.8773759975905098e-05,
617
+ "loss": 0.878,
618
+ "step": 87
619
+ },
620
+ {
621
+ "epoch": 0.19,
622
+ "grad_norm": 1.154690146446228,
623
+ "learning_rate": 1.8740199936585856e-05,
624
+ "loss": 0.7973,
625
+ "step": 88
626
+ },
627
+ {
628
+ "epoch": 0.19,
629
+ "grad_norm": 1.2520458698272705,
630
+ "learning_rate": 1.8706217673675813e-05,
631
+ "loss": 0.8218,
632
+ "step": 89
633
+ },
634
+ {
635
+ "epoch": 0.19,
636
+ "grad_norm": 1.1887520551681519,
637
+ "learning_rate": 1.867181482879795e-05,
638
+ "loss": 0.7935,
639
+ "step": 90
640
+ },
641
+ {
642
+ "epoch": 0.19,
643
+ "grad_norm": 1.1408494710922241,
644
+ "learning_rate": 1.8636993063892822e-05,
645
+ "loss": 0.874,
646
+ "step": 91
647
+ },
648
+ {
649
+ "epoch": 0.2,
650
+ "grad_norm": 0.9687843322753906,
651
+ "learning_rate": 1.8601754061138258e-05,
652
+ "loss": 0.7991,
653
+ "step": 92
654
+ },
655
+ {
656
+ "epoch": 0.2,
657
+ "grad_norm": 1.1839170455932617,
658
+ "learning_rate": 1.8566099522868118e-05,
659
+ "loss": 0.8639,
660
+ "step": 93
661
+ },
662
+ {
663
+ "epoch": 0.2,
664
+ "grad_norm": 1.6939510107040405,
665
+ "learning_rate": 1.8530031171490055e-05,
666
+ "loss": 0.7854,
667
+ "step": 94
668
+ },
669
+ {
670
+ "epoch": 0.2,
671
+ "grad_norm": 1.2965248823165894,
672
+ "learning_rate": 1.8493550749402278e-05,
673
+ "loss": 0.8231,
674
+ "step": 95
675
+ },
676
+ {
677
+ "epoch": 0.21,
678
+ "grad_norm": 1.025974154472351,
679
+ "learning_rate": 1.8456660018909424e-05,
680
+ "loss": 0.8452,
681
+ "step": 96
682
+ },
683
+ {
684
+ "epoch": 0.21,
685
+ "grad_norm": 1.2490646839141846,
686
+ "learning_rate": 1.8419360762137395e-05,
687
+ "loss": 0.7846,
688
+ "step": 97
689
+ },
690
+ {
691
+ "epoch": 0.21,
692
+ "grad_norm": 0.8493202924728394,
693
+ "learning_rate": 1.8381654780947272e-05,
694
+ "loss": 0.8648,
695
+ "step": 98
696
+ },
697
+ {
698
+ "epoch": 0.21,
699
+ "grad_norm": 1.1620343923568726,
700
+ "learning_rate": 1.8343543896848275e-05,
701
+ "loss": 0.8261,
702
+ "step": 99
703
+ },
704
+ {
705
+ "epoch": 0.21,
706
+ "grad_norm": 0.9533255100250244,
707
+ "learning_rate": 1.830502995090977e-05,
708
+ "loss": 0.847,
709
+ "step": 100
710
+ },
711
+ {
712
+ "epoch": 0.22,
713
+ "grad_norm": 1.254692554473877,
714
+ "learning_rate": 1.826611480367232e-05,
715
+ "loss": 0.8101,
716
+ "step": 101
717
+ },
718
+ {
719
+ "epoch": 0.22,
720
+ "grad_norm": 1.0770541429519653,
721
+ "learning_rate": 1.822680033505782e-05,
722
+ "loss": 0.8249,
723
+ "step": 102
724
+ },
725
+ {
726
+ "epoch": 0.22,
727
+ "grad_norm": 0.9607298374176025,
728
+ "learning_rate": 1.8187088444278675e-05,
729
+ "loss": 0.823,
730
+ "step": 103
731
+ },
732
+ {
733
+ "epoch": 0.22,
734
+ "grad_norm": 0.8450298309326172,
735
+ "learning_rate": 1.814698104974604e-05,
736
+ "loss": 0.789,
737
+ "step": 104
738
+ },
739
+ {
740
+ "epoch": 0.22,
741
+ "grad_norm": 1.1690232753753662,
742
+ "learning_rate": 1.8106480088977174e-05,
743
+ "loss": 0.86,
744
+ "step": 105
745
+ },
746
+ {
747
+ "epoch": 0.23,
748
+ "grad_norm": 0.7981148362159729,
749
+ "learning_rate": 1.8065587518501806e-05,
750
+ "loss": 0.8124,
751
+ "step": 106
752
+ },
753
+ {
754
+ "epoch": 0.23,
755
+ "grad_norm": 4.888617992401123,
756
+ "learning_rate": 1.8024305313767648e-05,
757
+ "loss": 0.8107,
758
+ "step": 107
759
+ },
760
+ {
761
+ "epoch": 0.23,
762
+ "grad_norm": 1.9401100873947144,
763
+ "learning_rate": 1.798263546904495e-05,
764
+ "loss": 0.8515,
765
+ "step": 108
766
+ },
767
+ {
768
+ "epoch": 0.23,
769
+ "grad_norm": 1.1880918741226196,
770
+ "learning_rate": 1.7940579997330167e-05,
771
+ "loss": 0.8038,
772
+ "step": 109
773
+ },
774
+ {
775
+ "epoch": 0.24,
776
+ "grad_norm": 1.1192213296890259,
777
+ "learning_rate": 1.7898140930248703e-05,
778
+ "loss": 0.8347,
779
+ "step": 110
780
+ },
781
+ {
782
+ "epoch": 0.24,
783
+ "grad_norm": 1.640434741973877,
784
+ "learning_rate": 1.7855320317956785e-05,
785
+ "loss": 0.8175,
786
+ "step": 111
787
+ },
788
+ {
789
+ "epoch": 0.24,
790
+ "grad_norm": 0.8676750063896179,
791
+ "learning_rate": 1.7812120229042415e-05,
792
+ "loss": 0.844,
793
+ "step": 112
794
+ },
795
+ {
796
+ "epoch": 0.24,
797
+ "grad_norm": 1.189393401145935,
798
+ "learning_rate": 1.7768542750425427e-05,
799
+ "loss": 0.7812,
800
+ "step": 113
801
+ },
802
+ {
803
+ "epoch": 0.24,
804
+ "grad_norm": 0.8515229821205139,
805
+ "learning_rate": 1.7724589987256697e-05,
806
+ "loss": 0.8528,
807
+ "step": 114
808
+ },
809
+ {
810
+ "epoch": 0.25,
811
+ "grad_norm": 1.701436996459961,
812
+ "learning_rate": 1.768026406281642e-05,
813
+ "loss": 0.8155,
814
+ "step": 115
815
+ },
816
+ {
817
+ "epoch": 0.25,
818
+ "grad_norm": 1.211014986038208,
819
+ "learning_rate": 1.7635567118411568e-05,
820
+ "loss": 0.8411,
821
+ "step": 116
822
+ },
823
+ {
824
+ "epoch": 0.25,
825
+ "grad_norm": 1.2689430713653564,
826
+ "learning_rate": 1.7590501313272415e-05,
827
+ "loss": 0.8213,
828
+ "step": 117
829
+ },
830
+ {
831
+ "epoch": 0.25,
832
+ "grad_norm": 1.6434332132339478,
833
+ "learning_rate": 1.7545068824448255e-05,
834
+ "loss": 0.8233,
835
+ "step": 118
836
+ },
837
+ {
838
+ "epoch": 0.25,
839
+ "grad_norm": 1.337048888206482,
840
+ "learning_rate": 1.7499271846702216e-05,
841
+ "loss": 0.8001,
842
+ "step": 119
843
+ },
844
+ {
845
+ "epoch": 0.26,
846
+ "grad_norm": 1.2411539554595947,
847
+ "learning_rate": 1.7453112592405245e-05,
848
+ "loss": 0.8476,
849
+ "step": 120
850
+ },
851
+ {
852
+ "epoch": 0.26,
853
+ "grad_norm": 1.685616374015808,
854
+ "learning_rate": 1.740659329142922e-05,
855
+ "loss": 0.7684,
856
+ "step": 121
857
+ },
858
+ {
859
+ "epoch": 0.26,
860
+ "grad_norm": 1.5003278255462646,
861
+ "learning_rate": 1.7359716191039248e-05,
862
+ "loss": 0.8474,
863
+ "step": 122
864
+ },
865
+ {
866
+ "epoch": 0.26,
867
+ "grad_norm": 1.7008017301559448,
868
+ "learning_rate": 1.7312483555785087e-05,
869
+ "loss": 0.8115,
870
+ "step": 123
871
+ },
872
+ {
873
+ "epoch": 0.27,
874
+ "grad_norm": 1.7307039499282837,
875
+ "learning_rate": 1.7264897667391757e-05,
876
+ "loss": 0.8066,
877
+ "step": 124
878
+ },
879
+ {
880
+ "epoch": 0.27,
881
+ "grad_norm": 0.9101468324661255,
882
+ "learning_rate": 1.7216960824649304e-05,
883
+ "loss": 0.8238,
884
+ "step": 125
885
+ },
886
+ {
887
+ "epoch": 0.27,
888
+ "grad_norm": 1.103602647781372,
889
+ "learning_rate": 1.7168675343301768e-05,
890
+ "loss": 0.8162,
891
+ "step": 126
892
+ },
893
+ {
894
+ "epoch": 0.27,
895
+ "grad_norm": 1.2054574489593506,
896
+ "learning_rate": 1.71200435559353e-05,
897
+ "loss": 0.8254,
898
+ "step": 127
899
+ },
900
+ {
901
+ "epoch": 0.27,
902
+ "grad_norm": 1.234502911567688,
903
+ "learning_rate": 1.7071067811865477e-05,
904
+ "loss": 0.8034,
905
+ "step": 128
906
+ },
907
+ {
908
+ "epoch": 0.28,
909
+ "grad_norm": 1.1067628860473633,
910
+ "learning_rate": 1.7021750477023823e-05,
911
+ "loss": 0.7755,
912
+ "step": 129
913
+ },
914
+ {
915
+ "epoch": 0.28,
916
+ "grad_norm": 0.7321228384971619,
917
+ "learning_rate": 1.69720939338435e-05,
918
+ "loss": 0.8535,
919
+ "step": 130
920
+ },
921
+ {
922
+ "epoch": 0.28,
923
+ "grad_norm": 0.8689684867858887,
924
+ "learning_rate": 1.6922100581144228e-05,
925
+ "loss": 0.7752,
926
+ "step": 131
927
+ },
928
+ {
929
+ "epoch": 0.28,
930
+ "grad_norm": 0.9249204397201538,
931
+ "learning_rate": 1.6871772834016406e-05,
932
+ "loss": 0.8373,
933
+ "step": 132
934
+ },
935
+ {
936
+ "epoch": 0.28,
937
+ "grad_norm": 0.8648712635040283,
938
+ "learning_rate": 1.6821113123704425e-05,
939
+ "loss": 0.7638,
940
+ "step": 133
941
+ },
942
+ {
943
+ "epoch": 0.29,
944
+ "grad_norm": 0.8067061901092529,
945
+ "learning_rate": 1.677012389748923e-05,
946
+ "loss": 0.8038,
947
+ "step": 134
948
+ },
949
+ {
950
+ "epoch": 0.29,
951
+ "grad_norm": 0.8623146414756775,
952
+ "learning_rate": 1.671880761857011e-05,
953
+ "loss": 0.8298,
954
+ "step": 135
955
+ },
956
+ {
957
+ "epoch": 0.29,
958
+ "grad_norm": 0.8998252153396606,
959
+ "learning_rate": 1.666716676594567e-05,
960
+ "loss": 0.7686,
961
+ "step": 136
962
+ },
963
+ {
964
+ "epoch": 0.29,
965
+ "grad_norm": 0.9564909934997559,
966
+ "learning_rate": 1.661520383429412e-05,
967
+ "loss": 0.8418,
968
+ "step": 137
969
+ },
970
+ {
971
+ "epoch": 0.3,
972
+ "grad_norm": 0.7597609758377075,
973
+ "learning_rate": 1.6562921333852714e-05,
974
+ "loss": 0.7976,
975
+ "step": 138
976
+ },
977
+ {
978
+ "epoch": 0.3,
979
+ "grad_norm": 1.064211130142212,
980
+ "learning_rate": 1.6510321790296527e-05,
981
+ "loss": 0.8479,
982
+ "step": 139
983
+ },
984
+ {
985
+ "epoch": 0.3,
986
+ "grad_norm": 1.1456950902938843,
987
+ "learning_rate": 1.6457407744616417e-05,
988
+ "loss": 0.7806,
989
+ "step": 140
990
+ },
991
+ {
992
+ "epoch": 0.3,
993
+ "grad_norm": 0.8875635862350464,
994
+ "learning_rate": 1.6404181752996287e-05,
995
+ "loss": 0.8191,
996
+ "step": 141
997
+ },
998
+ {
999
+ "epoch": 0.3,
1000
+ "grad_norm": 1.0326021909713745,
1001
+ "learning_rate": 1.6350646386689593e-05,
1002
+ "loss": 0.8086,
1003
+ "step": 142
1004
+ },
1005
+ {
1006
+ "epoch": 0.31,
1007
+ "grad_norm": 0.8035858273506165,
1008
+ "learning_rate": 1.629680423189514e-05,
1009
+ "loss": 0.7771,
1010
+ "step": 143
1011
+ },
1012
+ {
1013
+ "epoch": 0.31,
1014
+ "grad_norm": 0.8190425634384155,
1015
+ "learning_rate": 1.6242657889632133e-05,
1016
+ "loss": 0.8167,
1017
+ "step": 144
1018
+ },
1019
+ {
1020
+ "epoch": 0.31,
1021
+ "grad_norm": 0.8339990377426147,
1022
+ "learning_rate": 1.618820997561454e-05,
1023
+ "loss": 0.8068,
1024
+ "step": 145
1025
+ },
1026
+ {
1027
+ "epoch": 0.31,
1028
+ "grad_norm": 0.827274739742279,
1029
+ "learning_rate": 1.613346312012473e-05,
1030
+ "loss": 0.817,
1031
+ "step": 146
1032
+ },
1033
+ {
1034
+ "epoch": 0.31,
1035
+ "grad_norm": 0.7203758955001831,
1036
+ "learning_rate": 1.6078419967886402e-05,
1037
+ "loss": 0.8122,
1038
+ "step": 147
1039
+ },
1040
+ {
1041
+ "epoch": 0.32,
1042
+ "grad_norm": 0.7495682835578918,
1043
+ "learning_rate": 1.6023083177936824e-05,
1044
+ "loss": 0.7676,
1045
+ "step": 148
1046
+ },
1047
+ {
1048
+ "epoch": 0.32,
1049
+ "grad_norm": 0.6958379745483398,
1050
+ "learning_rate": 1.5967455423498387e-05,
1051
+ "loss": 0.8305,
1052
+ "step": 149
1053
+ },
1054
+ {
1055
+ "epoch": 0.32,
1056
+ "grad_norm": 0.99383944272995,
1057
+ "learning_rate": 1.591153939184946e-05,
1058
+ "loss": 0.7984,
1059
+ "step": 150
1060
+ },
1061
+ {
1062
+ "epoch": 0.32,
1063
+ "grad_norm": 0.829394519329071,
1064
+ "learning_rate": 1.5855337784194576e-05,
1065
+ "loss": 0.8008,
1066
+ "step": 151
1067
+ },
1068
+ {
1069
+ "epoch": 0.33,
1070
+ "grad_norm": 0.7945014834403992,
1071
+ "learning_rate": 1.5798853315533932e-05,
1072
+ "loss": 0.7504,
1073
+ "step": 152
1074
+ },
1075
+ {
1076
+ "epoch": 0.33,
1077
+ "grad_norm": 0.7520758509635925,
1078
+ "learning_rate": 1.5742088714532247e-05,
1079
+ "loss": 0.8346,
1080
+ "step": 153
1081
+ },
1082
+ {
1083
+ "epoch": 0.33,
1084
+ "grad_norm": 0.8301789164543152,
1085
+ "learning_rate": 1.568504672338694e-05,
1086
+ "loss": 0.7719,
1087
+ "step": 154
1088
+ },
1089
+ {
1090
+ "epoch": 0.33,
1091
+ "grad_norm": 1.3911187648773193,
1092
+ "learning_rate": 1.562773009769564e-05,
1093
+ "loss": 0.8335,
1094
+ "step": 155
1095
+ },
1096
+ {
1097
+ "epoch": 0.33,
1098
+ "grad_norm": 1.039931297302246,
1099
+ "learning_rate": 1.5570141606323105e-05,
1100
+ "loss": 0.7892,
1101
+ "step": 156
1102
+ },
1103
+ {
1104
+ "epoch": 0.34,
1105
+ "grad_norm": 0.801042377948761,
1106
+ "learning_rate": 1.551228403126744e-05,
1107
+ "loss": 0.8124,
1108
+ "step": 157
1109
+ },
1110
+ {
1111
+ "epoch": 0.34,
1112
+ "grad_norm": 1.0106760263442993,
1113
+ "learning_rate": 1.5454160167525688e-05,
1114
+ "loss": 0.7651,
1115
+ "step": 158
1116
+ },
1117
+ {
1118
+ "epoch": 0.34,
1119
+ "grad_norm": 0.7811651825904846,
1120
+ "learning_rate": 1.5395772822958844e-05,
1121
+ "loss": 0.8168,
1122
+ "step": 159
1123
+ },
1124
+ {
1125
+ "epoch": 0.34,
1126
+ "grad_norm": 0.8879010081291199,
1127
+ "learning_rate": 1.5337124818156203e-05,
1128
+ "loss": 0.7364,
1129
+ "step": 160
1130
+ },
1131
+ {
1132
+ "epoch": 0.34,
1133
+ "grad_norm": 0.6862936019897461,
1134
+ "learning_rate": 1.5278218986299074e-05,
1135
+ "loss": 0.8275,
1136
+ "step": 161
1137
+ },
1138
+ {
1139
+ "epoch": 0.35,
1140
+ "grad_norm": 0.9153168797492981,
1141
+ "learning_rate": 1.5219058173023948e-05,
1142
+ "loss": 0.7984,
1143
+ "step": 162
1144
+ },
1145
+ {
1146
+ "epoch": 0.35,
1147
+ "grad_norm": 0.8116987943649292,
1148
+ "learning_rate": 1.515964523628501e-05,
1149
+ "loss": 0.7689,
1150
+ "step": 163
1151
+ },
1152
+ {
1153
+ "epoch": 0.35,
1154
+ "grad_norm": 0.7810778617858887,
1155
+ "learning_rate": 1.5099983046216089e-05,
1156
+ "loss": 0.7985,
1157
+ "step": 164
1158
+ },
1159
+ {
1160
+ "epoch": 0.35,
1161
+ "grad_norm": 0.6745201945304871,
1162
+ "learning_rate": 1.5040074484992e-05,
1163
+ "loss": 0.8015,
1164
+ "step": 165
1165
+ },
1166
+ {
1167
+ "epoch": 0.36,
1168
+ "grad_norm": 0.9147999286651611,
1169
+ "learning_rate": 1.4979922446689308e-05,
1170
+ "loss": 0.8264,
1171
+ "step": 166
1172
+ },
1173
+ {
1174
+ "epoch": 0.36,
1175
+ "grad_norm": 0.8092418313026428,
1176
+ "learning_rate": 1.4919529837146529e-05,
1177
+ "loss": 0.743,
1178
+ "step": 167
1179
+ },
1180
+ {
1181
+ "epoch": 0.36,
1182
+ "grad_norm": 0.8291578888893127,
1183
+ "learning_rate": 1.4858899573823752e-05,
1184
+ "loss": 0.786,
1185
+ "step": 168
1186
+ },
1187
+ {
1188
+ "epoch": 0.36,
1189
+ "grad_norm": 0.6807591915130615,
1190
+ "learning_rate": 1.4798034585661696e-05,
1191
+ "loss": 0.8155,
1192
+ "step": 169
1193
+ },
1194
+ {
1195
+ "epoch": 0.36,
1196
+ "grad_norm": 0.8842042088508606,
1197
+ "learning_rate": 1.4736937812940217e-05,
1198
+ "loss": 0.7765,
1199
+ "step": 170
1200
+ },
1201
+ {
1202
+ "epoch": 0.37,
1203
+ "grad_norm": 0.8237358927726746,
1204
+ "learning_rate": 1.4675612207136283e-05,
1205
+ "loss": 0.7783,
1206
+ "step": 171
1207
+ },
1208
+ {
1209
+ "epoch": 0.37,
1210
+ "grad_norm": 0.661469578742981,
1211
+ "learning_rate": 1.4614060730781377e-05,
1212
+ "loss": 0.7716,
1213
+ "step": 172
1214
+ },
1215
+ {
1216
+ "epoch": 0.37,
1217
+ "grad_norm": 0.7561662197113037,
1218
+ "learning_rate": 1.455228635731839e-05,
1219
+ "loss": 0.7934,
1220
+ "step": 173
1221
+ },
1222
+ {
1223
+ "epoch": 0.37,
1224
+ "grad_norm": 0.6873330473899841,
1225
+ "learning_rate": 1.4490292070957978e-05,
1226
+ "loss": 0.7654,
1227
+ "step": 174
1228
+ },
1229
+ {
1230
+ "epoch": 0.37,
1231
+ "grad_norm": 0.7876659631729126,
1232
+ "learning_rate": 1.4428080866534397e-05,
1233
+ "loss": 0.7754,
1234
+ "step": 175
1235
+ },
1236
+ {
1237
+ "epoch": 0.38,
1238
+ "grad_norm": 0.7588985562324524,
1239
+ "learning_rate": 1.4365655749360833e-05,
1240
+ "loss": 0.8073,
1241
+ "step": 176
1242
+ },
1243
+ {
1244
+ "epoch": 0.38,
1245
+ "grad_norm": 1.0146478414535522,
1246
+ "learning_rate": 1.4303019735084225e-05,
1247
+ "loss": 0.8115,
1248
+ "step": 177
1249
+ },
1250
+ {
1251
+ "epoch": 0.38,
1252
+ "grad_norm": 1.0474367141723633,
1253
+ "learning_rate": 1.4240175849539566e-05,
1254
+ "loss": 0.7662,
1255
+ "step": 178
1256
+ },
1257
+ {
1258
+ "epoch": 0.38,
1259
+ "grad_norm": 0.8567104935646057,
1260
+ "learning_rate": 1.4177127128603748e-05,
1261
+ "loss": 0.8192,
1262
+ "step": 179
1263
+ },
1264
+ {
1265
+ "epoch": 0.39,
1266
+ "grad_norm": 0.7369076609611511,
1267
+ "learning_rate": 1.4113876618048896e-05,
1268
+ "loss": 0.7796,
1269
+ "step": 180
1270
+ },
1271
+ {
1272
+ "epoch": 0.39,
1273
+ "grad_norm": 1.021986961364746,
1274
+ "learning_rate": 1.4050427373395241e-05,
1275
+ "loss": 0.743,
1276
+ "step": 181
1277
+ },
1278
+ {
1279
+ "epoch": 0.39,
1280
+ "grad_norm": 0.799660325050354,
1281
+ "learning_rate": 1.3986782459763499e-05,
1282
+ "loss": 0.7985,
1283
+ "step": 182
1284
+ },
1285
+ {
1286
+ "epoch": 0.39,
1287
+ "grad_norm": 0.9272130131721497,
1288
+ "learning_rate": 1.3922944951726811e-05,
1289
+ "loss": 0.7779,
1290
+ "step": 183
1291
+ },
1292
+ {
1293
+ "epoch": 0.39,
1294
+ "grad_norm": 0.7575782537460327,
1295
+ "learning_rate": 1.3858917933162212e-05,
1296
+ "loss": 0.8191,
1297
+ "step": 184
1298
+ },
1299
+ {
1300
+ "epoch": 0.4,
1301
+ "grad_norm": 0.7355371117591858,
1302
+ "learning_rate": 1.3794704497101656e-05,
1303
+ "loss": 0.7801,
1304
+ "step": 185
1305
+ },
1306
+ {
1307
+ "epoch": 0.4,
1308
+ "grad_norm": 1.0130892992019653,
1309
+ "learning_rate": 1.3730307745582594e-05,
1310
+ "loss": 0.8038,
1311
+ "step": 186
1312
+ },
1313
+ {
1314
+ "epoch": 0.4,
1315
+ "grad_norm": 0.7086817622184753,
1316
+ "learning_rate": 1.366573078949813e-05,
1317
+ "loss": 0.7514,
1318
+ "step": 187
1319
+ },
1320
+ {
1321
+ "epoch": 0.4,
1322
+ "grad_norm": 0.8990337252616882,
1323
+ "learning_rate": 1.3600976748446722e-05,
1324
+ "loss": 0.8257,
1325
+ "step": 188
1326
+ },
1327
+ {
1328
+ "epoch": 0.4,
1329
+ "grad_norm": 2.2387804985046387,
1330
+ "learning_rate": 1.3536048750581494e-05,
1331
+ "loss": 0.7783,
1332
+ "step": 189
1333
+ },
1334
+ {
1335
+ "epoch": 0.41,
1336
+ "grad_norm": 0.8208438754081726,
1337
+ "learning_rate": 1.3470949932459116e-05,
1338
+ "loss": 0.7705,
1339
+ "step": 190
1340
+ },
1341
+ {
1342
+ "epoch": 0.41,
1343
+ "grad_norm": 0.7612828612327576,
1344
+ "learning_rate": 1.3405683438888281e-05,
1345
+ "loss": 0.7839,
1346
+ "step": 191
1347
+ },
1348
+ {
1349
+ "epoch": 0.41,
1350
+ "grad_norm": 0.6562499403953552,
1351
+ "learning_rate": 1.3340252422777788e-05,
1352
+ "loss": 0.8068,
1353
+ "step": 192
1354
+ },
1355
+ {
1356
+ "epoch": 0.41,
1357
+ "grad_norm": 0.784289538860321,
1358
+ "learning_rate": 1.3274660044984225e-05,
1359
+ "loss": 0.8028,
1360
+ "step": 193
1361
+ },
1362
+ {
1363
+ "epoch": 0.42,
1364
+ "grad_norm": 0.7543610334396362,
1365
+ "learning_rate": 1.3208909474159279e-05,
1366
+ "loss": 0.7688,
1367
+ "step": 194
1368
+ },
1369
+ {
1370
+ "epoch": 0.42,
1371
+ "grad_norm": 0.6484317779541016,
1372
+ "learning_rate": 1.314300388659667e-05,
1373
+ "loss": 0.8161,
1374
+ "step": 195
1375
+ },
1376
+ {
1377
+ "epoch": 0.42,
1378
+ "grad_norm": 0.9675614833831787,
1379
+ "learning_rate": 1.3076946466078691e-05,
1380
+ "loss": 0.7715,
1381
+ "step": 196
1382
+ },
1383
+ {
1384
+ "epoch": 0.42,
1385
+ "grad_norm": 0.7833143472671509,
1386
+ "learning_rate": 1.301074040372242e-05,
1387
+ "loss": 0.7748,
1388
+ "step": 197
1389
+ },
1390
+ {
1391
+ "epoch": 0.42,
1392
+ "grad_norm": 0.9204770922660828,
1393
+ "learning_rate": 1.2944388897825559e-05,
1394
+ "loss": 0.7725,
1395
+ "step": 198
1396
+ },
1397
+ {
1398
+ "epoch": 0.43,
1399
+ "grad_norm": 0.789046049118042,
1400
+ "learning_rate": 1.2877895153711935e-05,
1401
+ "loss": 0.7526,
1402
+ "step": 199
1403
+ },
1404
+ {
1405
+ "epoch": 0.43,
1406
+ "grad_norm": 0.7482736706733704,
1407
+ "learning_rate": 1.2811262383576646e-05,
1408
+ "loss": 0.8268,
1409
+ "step": 200
1410
+ },
1411
+ {
1412
+ "epoch": 0.43,
1413
+ "grad_norm": 0.9972829222679138,
1414
+ "learning_rate": 1.274449380633089e-05,
1415
+ "loss": 0.7505,
1416
+ "step": 201
1417
+ },
1418
+ {
1419
+ "epoch": 0.43,
1420
+ "grad_norm": 0.727154016494751,
1421
+ "learning_rate": 1.2677592647446472e-05,
1422
+ "loss": 0.7953,
1423
+ "step": 202
1424
+ },
1425
+ {
1426
+ "epoch": 0.43,
1427
+ "grad_norm": 0.7113155126571655,
1428
+ "learning_rate": 1.2610562138799977e-05,
1429
+ "loss": 0.7877,
1430
+ "step": 203
1431
+ },
1432
+ {
1433
+ "epoch": 0.44,
1434
+ "grad_norm": 0.7132176756858826,
1435
+ "learning_rate": 1.2543405518516651e-05,
1436
+ "loss": 0.8088,
1437
+ "step": 204
1438
+ },
1439
+ {
1440
+ "epoch": 0.44,
1441
+ "grad_norm": 2.328761339187622,
1442
+ "learning_rate": 1.2476126030813964e-05,
1443
+ "loss": 0.7521,
1444
+ "step": 205
1445
+ },
1446
+ {
1447
+ "epoch": 0.44,
1448
+ "grad_norm": 0.6288961172103882,
1449
+ "learning_rate": 1.24087269258449e-05,
1450
+ "loss": 0.7779,
1451
+ "step": 206
1452
+ },
1453
+ {
1454
+ "epoch": 0.44,
1455
+ "grad_norm": 0.7735608816146851,
1456
+ "learning_rate": 1.234121145954094e-05,
1457
+ "loss": 0.7624,
1458
+ "step": 207
1459
+ },
1460
+ {
1461
+ "epoch": 0.45,
1462
+ "grad_norm": 0.845016598701477,
1463
+ "learning_rate": 1.2273582893454774e-05,
1464
+ "loss": 0.7804,
1465
+ "step": 208
1466
+ },
1467
+ {
1468
+ "epoch": 0.45,
1469
+ "grad_norm": 0.8225258588790894,
1470
+ "learning_rate": 1.2205844494602741e-05,
1471
+ "loss": 0.7665,
1472
+ "step": 209
1473
+ },
1474
+ {
1475
+ "epoch": 0.45,
1476
+ "grad_norm": 0.9022204875946045,
1477
+ "learning_rate": 1.213799953530701e-05,
1478
+ "loss": 0.7671,
1479
+ "step": 210
1480
+ },
1481
+ {
1482
+ "epoch": 0.45,
1483
+ "grad_norm": 0.7139418721199036,
1484
+ "learning_rate": 1.2070051293037493e-05,
1485
+ "loss": 0.8215,
1486
+ "step": 211
1487
+ },
1488
+ {
1489
+ "epoch": 0.45,
1490
+ "grad_norm": 1.054016351699829,
1491
+ "learning_rate": 1.2002003050253524e-05,
1492
+ "loss": 0.7387,
1493
+ "step": 212
1494
+ },
1495
+ {
1496
+ "epoch": 0.46,
1497
+ "grad_norm": 0.7111931443214417,
1498
+ "learning_rate": 1.1933858094245281e-05,
1499
+ "loss": 0.8172,
1500
+ "step": 213
1501
+ },
1502
+ {
1503
+ "epoch": 0.46,
1504
+ "grad_norm": 0.7568846940994263,
1505
+ "learning_rate": 1.1865619716974986e-05,
1506
+ "loss": 0.745,
1507
+ "step": 214
1508
+ },
1509
+ {
1510
+ "epoch": 0.46,
1511
+ "grad_norm": 0.8243083953857422,
1512
+ "learning_rate": 1.1797291214917882e-05,
1513
+ "loss": 0.8177,
1514
+ "step": 215
1515
+ },
1516
+ {
1517
+ "epoch": 0.46,
1518
+ "grad_norm": 0.8872765898704529,
1519
+ "learning_rate": 1.1728875888902975e-05,
1520
+ "loss": 0.7488,
1521
+ "step": 216
1522
+ },
1523
+ {
1524
+ "epoch": 0.46,
1525
+ "grad_norm": 0.9252672791481018,
1526
+ "learning_rate": 1.1660377043953588e-05,
1527
+ "loss": 0.7788,
1528
+ "step": 217
1529
+ },
1530
+ {
1531
+ "epoch": 0.47,
1532
+ "grad_norm": 0.7096117734909058,
1533
+ "learning_rate": 1.1591797989127691e-05,
1534
+ "loss": 0.7839,
1535
+ "step": 218
1536
+ },
1537
+ {
1538
+ "epoch": 0.47,
1539
+ "grad_norm": 1.0665735006332397,
1540
+ "learning_rate": 1.152314203735805e-05,
1541
+ "loss": 0.7926,
1542
+ "step": 219
1543
+ },
1544
+ {
1545
+ "epoch": 0.47,
1546
+ "grad_norm": 0.9210519790649414,
1547
+ "learning_rate": 1.14544125052922e-05,
1548
+ "loss": 0.7637,
1549
+ "step": 220
1550
+ },
1551
+ {
1552
+ "epoch": 0.47,
1553
+ "grad_norm": 0.7430177927017212,
1554
+ "learning_rate": 1.1385612713132191e-05,
1555
+ "loss": 0.7781,
1556
+ "step": 221
1557
+ },
1558
+ {
1559
+ "epoch": 0.48,
1560
+ "grad_norm": 0.6779014468193054,
1561
+ "learning_rate": 1.1316745984474227e-05,
1562
+ "loss": 0.7843,
1563
+ "step": 222
1564
+ },
1565
+ {
1566
+ "epoch": 0.48,
1567
+ "grad_norm": 0.9180762767791748,
1568
+ "learning_rate": 1.1247815646148088e-05,
1569
+ "loss": 0.7957,
1570
+ "step": 223
1571
+ },
1572
+ {
1573
+ "epoch": 0.48,
1574
+ "grad_norm": 0.9458864331245422,
1575
+ "learning_rate": 1.117882502805643e-05,
1576
+ "loss": 0.8011,
1577
+ "step": 224
1578
+ },
1579
+ {
1580
+ "epoch": 0.48,
1581
+ "grad_norm": 0.9582037925720215,
1582
+ "learning_rate": 1.1109777463013915e-05,
1583
+ "loss": 0.743,
1584
+ "step": 225
1585
+ },
1586
+ {
1587
+ "epoch": 0.48,
1588
+ "grad_norm": 0.9602392911911011,
1589
+ "learning_rate": 1.1040676286586212e-05,
1590
+ "loss": 0.7724,
1591
+ "step": 226
1592
+ },
1593
+ {
1594
+ "epoch": 0.49,
1595
+ "grad_norm": 0.6800273060798645,
1596
+ "learning_rate": 1.097152483692886e-05,
1597
+ "loss": 0.8166,
1598
+ "step": 227
1599
+ },
1600
+ {
1601
+ "epoch": 0.49,
1602
+ "grad_norm": 0.7832956314086914,
1603
+ "learning_rate": 1.0902326454626012e-05,
1604
+ "loss": 0.7304,
1605
+ "step": 228
1606
+ },
1607
+ {
1608
+ "epoch": 0.49,
1609
+ "grad_norm": 0.8246415853500366,
1610
+ "learning_rate": 1.0833084482529048e-05,
1611
+ "loss": 0.8128,
1612
+ "step": 229
1613
+ },
1614
+ {
1615
+ "epoch": 0.49,
1616
+ "grad_norm": 0.748362123966217,
1617
+ "learning_rate": 1.0763802265595103e-05,
1618
+ "loss": 0.7449,
1619
+ "step": 230
1620
+ },
1621
+ {
1622
+ "epoch": 0.49,
1623
+ "grad_norm": 0.7535527348518372,
1624
+ "learning_rate": 1.0694483150725458e-05,
1625
+ "loss": 0.8146,
1626
+ "step": 231
1627
+ },
1628
+ {
1629
+ "epoch": 0.5,
1630
+ "grad_norm": 1.175562858581543,
1631
+ "learning_rate": 1.0625130486603879e-05,
1632
+ "loss": 0.7621,
1633
+ "step": 232
1634
+ },
1635
+ {
1636
+ "epoch": 0.5,
1637
+ "grad_norm": 1.3196650743484497,
1638
+ "learning_rate": 1.055574762353483e-05,
1639
+ "loss": 0.7666,
1640
+ "step": 233
1641
+ },
1642
+ {
1643
+ "epoch": 0.5,
1644
+ "grad_norm": 0.9028423428535461,
1645
+ "learning_rate": 1.0486337913281633e-05,
1646
+ "loss": 0.8021,
1647
+ "step": 234
1648
+ },
1649
+ {
1650
+ "epoch": 0.5,
1651
+ "grad_norm": 1.6453968286514282,
1652
+ "learning_rate": 1.041690470890455e-05,
1653
+ "loss": 0.7432,
1654
+ "step": 235
1655
+ },
1656
+ {
1657
+ "epoch": 0.51,
1658
+ "grad_norm": 0.6863554120063782,
1659
+ "learning_rate": 1.0347451364598805e-05,
1660
+ "loss": 0.7589,
1661
+ "step": 236
1662
+ },
1663
+ {
1664
+ "epoch": 0.51,
1665
+ "grad_norm": 0.7064222693443298,
1666
+ "learning_rate": 1.0277981235532541e-05,
1667
+ "loss": 0.7894,
1668
+ "step": 237
1669
+ },
1670
+ {
1671
+ "epoch": 0.51,
1672
+ "grad_norm": 0.941105306148529,
1673
+ "learning_rate": 1.0208497677684755e-05,
1674
+ "loss": 0.7692,
1675
+ "step": 238
1676
+ },
1677
+ {
1678
+ "epoch": 0.51,
1679
+ "grad_norm": 0.8583152294158936,
1680
+ "learning_rate": 1.0139004047683152e-05,
1681
+ "loss": 0.7511,
1682
+ "step": 239
1683
+ },
1684
+ {
1685
+ "epoch": 0.51,
1686
+ "grad_norm": 1.788122296333313,
1687
+ "learning_rate": 1.0069503702642011e-05,
1688
+ "loss": 0.7827,
1689
+ "step": 240
1690
+ },
1691
+ {
1692
+ "epoch": 0.52,
1693
+ "grad_norm": 0.7599253058433533,
1694
+ "learning_rate": 1e-05,
1695
+ "loss": 0.7404,
1696
+ "step": 241
1697
+ },
1698
+ {
1699
+ "epoch": 0.52,
1700
+ "grad_norm": 0.70832759141922,
1701
+ "learning_rate": 9.930496297357994e-06,
1702
+ "loss": 0.816,
1703
+ "step": 242
1704
+ },
1705
+ {
1706
+ "epoch": 0.52,
1707
+ "grad_norm": 0.618488609790802,
1708
+ "learning_rate": 9.860995952316851e-06,
1709
+ "loss": 0.7423,
1710
+ "step": 243
1711
+ },
1712
+ {
1713
+ "epoch": 0.52,
1714
+ "grad_norm": 0.7473293542861938,
1715
+ "learning_rate": 9.791502322315249e-06,
1716
+ "loss": 0.7795,
1717
+ "step": 244
1718
+ },
1719
+ {
1720
+ "epoch": 0.52,
1721
+ "grad_norm": 0.7179411053657532,
1722
+ "learning_rate": 9.72201876446746e-06,
1723
+ "loss": 0.7848,
1724
+ "step": 245
1725
+ },
1726
+ {
1727
+ "epoch": 0.53,
1728
+ "grad_norm": 0.7268110513687134,
1729
+ "learning_rate": 9.6525486354012e-06,
1730
+ "loss": 0.7148,
1731
+ "step": 246
1732
+ },
1733
+ {
1734
+ "epoch": 0.53,
1735
+ "grad_norm": 0.7188340425491333,
1736
+ "learning_rate": 9.583095291095454e-06,
1737
+ "loss": 0.8226,
1738
+ "step": 247
1739
+ },
1740
+ {
1741
+ "epoch": 0.53,
1742
+ "grad_norm": 1.236936092376709,
1743
+ "learning_rate": 9.513662086718372e-06,
1744
+ "loss": 0.7436,
1745
+ "step": 248
1746
+ },
1747
+ {
1748
+ "epoch": 0.53,
1749
+ "grad_norm": 1.1389752626419067,
1750
+ "learning_rate": 9.444252376465171e-06,
1751
+ "loss": 0.7829,
1752
+ "step": 249
1753
+ },
1754
+ {
1755
+ "epoch": 0.54,
1756
+ "grad_norm": 0.7218825221061707,
1757
+ "learning_rate": 9.374869513396123e-06,
1758
+ "loss": 0.7686,
1759
+ "step": 250
1760
+ },
1761
+ {
1762
+ "epoch": 0.54,
1763
+ "grad_norm": 0.9832814931869507,
1764
+ "learning_rate": 9.305516849274542e-06,
1765
+ "loss": 0.7705,
1766
+ "step": 251
1767
+ },
1768
+ {
1769
+ "epoch": 0.54,
1770
+ "grad_norm": 0.7646653652191162,
1771
+ "learning_rate": 9.2361977344049e-06,
1772
+ "loss": 0.7855,
1773
+ "step": 252
1774
+ },
1775
+ {
1776
+ "epoch": 0.54,
1777
+ "grad_norm": 1.1526681184768677,
1778
+ "learning_rate": 9.166915517470953e-06,
1779
+ "loss": 0.7537,
1780
+ "step": 253
1781
+ },
1782
+ {
1783
+ "epoch": 0.54,
1784
+ "grad_norm": 0.7619354128837585,
1785
+ "learning_rate": 9.09767354537399e-06,
1786
+ "loss": 0.807,
1787
+ "step": 254
1788
+ },
1789
+ {
1790
+ "epoch": 0.55,
1791
+ "grad_norm": 0.7558615207672119,
1792
+ "learning_rate": 9.028475163071142e-06,
1793
+ "loss": 0.7571,
1794
+ "step": 255
1795
+ },
1796
+ {
1797
+ "epoch": 0.55,
1798
+ "grad_norm": 0.8925402164459229,
1799
+ "learning_rate": 8.959323713413792e-06,
1800
+ "loss": 0.7655,
1801
+ "step": 256
1802
+ },
1803
+ {
1804
+ "epoch": 0.55,
1805
+ "grad_norm": 1.0979193449020386,
1806
+ "learning_rate": 8.890222536986085e-06,
1807
+ "loss": 0.7738,
1808
+ "step": 257
1809
+ },
1810
+ {
1811
+ "epoch": 0.55,
1812
+ "grad_norm": 0.7939193248748779,
1813
+ "learning_rate": 8.821174971943573e-06,
1814
+ "loss": 0.7993,
1815
+ "step": 258
1816
+ },
1817
+ {
1818
+ "epoch": 0.55,
1819
+ "grad_norm": 0.8744078278541565,
1820
+ "learning_rate": 8.752184353851917e-06,
1821
+ "loss": 0.7523,
1822
+ "step": 259
1823
+ },
1824
+ {
1825
+ "epoch": 0.56,
1826
+ "grad_norm": 0.7564054727554321,
1827
+ "learning_rate": 8.683254015525776e-06,
1828
+ "loss": 0.7687,
1829
+ "step": 260
1830
+ },
1831
+ {
1832
+ "epoch": 0.56,
1833
+ "grad_norm": 0.7039680480957031,
1834
+ "learning_rate": 8.614387286867814e-06,
1835
+ "loss": 0.7861,
1836
+ "step": 261
1837
+ },
1838
+ {
1839
+ "epoch": 0.56,
1840
+ "grad_norm": 0.7174641489982605,
1841
+ "learning_rate": 8.545587494707803e-06,
1842
+ "loss": 0.7807,
1843
+ "step": 262
1844
+ },
1845
+ {
1846
+ "epoch": 0.56,
1847
+ "grad_norm": 1.087963342666626,
1848
+ "learning_rate": 8.476857962641951e-06,
1849
+ "loss": 0.7467,
1850
+ "step": 263
1851
+ },
1852
+ {
1853
+ "epoch": 0.57,
1854
+ "grad_norm": 0.7684459090232849,
1855
+ "learning_rate": 8.408202010872312e-06,
1856
+ "loss": 0.7567,
1857
+ "step": 264
1858
+ },
1859
+ {
1860
+ "epoch": 0.57,
1861
+ "grad_norm": 0.7861683964729309,
1862
+ "learning_rate": 8.339622956046417e-06,
1863
+ "loss": 0.7847,
1864
+ "step": 265
1865
+ },
1866
+ {
1867
+ "epoch": 0.57,
1868
+ "grad_norm": 0.7736767530441284,
1869
+ "learning_rate": 8.271124111097026e-06,
1870
+ "loss": 0.7639,
1871
+ "step": 266
1872
+ },
1873
+ {
1874
+ "epoch": 0.57,
1875
+ "grad_norm": 0.8550255298614502,
1876
+ "learning_rate": 8.202708785082122e-06,
1877
+ "loss": 0.7774,
1878
+ "step": 267
1879
+ },
1880
+ {
1881
+ "epoch": 0.57,
1882
+ "grad_norm": 0.7113362550735474,
1883
+ "learning_rate": 8.134380283025014e-06,
1884
+ "loss": 0.785,
1885
+ "step": 268
1886
+ },
1887
+ {
1888
+ "epoch": 0.58,
1889
+ "grad_norm": 0.7256139516830444,
1890
+ "learning_rate": 8.066141905754724e-06,
1891
+ "loss": 0.7625,
1892
+ "step": 269
1893
+ },
1894
+ {
1895
+ "epoch": 0.58,
1896
+ "grad_norm": 0.7075888514518738,
1897
+ "learning_rate": 7.997996949746478e-06,
1898
+ "loss": 0.7464,
1899
+ "step": 270
1900
+ },
1901
+ {
1902
+ "epoch": 0.58,
1903
+ "grad_norm": 0.796419620513916,
1904
+ "learning_rate": 7.929948706962508e-06,
1905
+ "loss": 0.7859,
1906
+ "step": 271
1907
+ },
1908
+ {
1909
+ "epoch": 0.58,
1910
+ "grad_norm": 0.7683595418930054,
1911
+ "learning_rate": 7.862000464692992e-06,
1912
+ "loss": 0.76,
1913
+ "step": 272
1914
+ },
1915
+ {
1916
+ "epoch": 0.58,
1917
+ "grad_norm": 0.8139968514442444,
1918
+ "learning_rate": 7.79415550539726e-06,
1919
+ "loss": 0.7559,
1920
+ "step": 273
1921
+ },
1922
+ {
1923
+ "epoch": 0.59,
1924
+ "grad_norm": 0.6886945962905884,
1925
+ "learning_rate": 7.726417106545231e-06,
1926
+ "loss": 0.7708,
1927
+ "step": 274
1928
+ },
1929
+ {
1930
+ "epoch": 0.59,
1931
+ "grad_norm": 0.7851764559745789,
1932
+ "learning_rate": 7.658788540459063e-06,
1933
+ "loss": 0.7445,
1934
+ "step": 275
1935
+ },
1936
+ {
1937
+ "epoch": 0.59,
1938
+ "grad_norm": 0.8292492032051086,
1939
+ "learning_rate": 7.5912730741551044e-06,
1940
+ "loss": 0.7757,
1941
+ "step": 276
1942
+ },
1943
+ {
1944
+ "epoch": 0.59,
1945
+ "grad_norm": 0.9825494289398193,
1946
+ "learning_rate": 7.523873969186039e-06,
1947
+ "loss": 0.7556,
1948
+ "step": 277
1949
+ },
1950
+ {
1951
+ "epoch": 0.6,
1952
+ "grad_norm": 0.7132662534713745,
1953
+ "learning_rate": 7.456594481483355e-06,
1954
+ "loss": 0.7614,
1955
+ "step": 278
1956
+ },
1957
+ {
1958
+ "epoch": 0.6,
1959
+ "grad_norm": 0.9558830857276917,
1960
+ "learning_rate": 7.389437861200024e-06,
1961
+ "loss": 0.7734,
1962
+ "step": 279
1963
+ },
1964
+ {
1965
+ "epoch": 0.6,
1966
+ "grad_norm": 0.8095993399620056,
1967
+ "learning_rate": 7.322407352553529e-06,
1968
+ "loss": 0.7252,
1969
+ "step": 280
1970
+ },
1971
+ {
1972
+ "epoch": 0.6,
1973
+ "grad_norm": 0.663575291633606,
1974
+ "learning_rate": 7.2555061936691104e-06,
1975
+ "loss": 0.8247,
1976
+ "step": 281
1977
+ },
1978
+ {
1979
+ "epoch": 0.6,
1980
+ "grad_norm": 0.7267043590545654,
1981
+ "learning_rate": 7.188737616423357e-06,
1982
+ "loss": 0.731,
1983
+ "step": 282
1984
+ },
1985
+ {
1986
+ "epoch": 0.61,
1987
+ "grad_norm": 0.8647053241729736,
1988
+ "learning_rate": 7.122104846288065e-06,
1989
+ "loss": 0.7719,
1990
+ "step": 283
1991
+ },
1992
+ {
1993
+ "epoch": 0.61,
1994
+ "grad_norm": 0.8748030662536621,
1995
+ "learning_rate": 7.055611102174442e-06,
1996
+ "loss": 0.7706,
1997
+ "step": 284
1998
+ },
1999
+ {
2000
+ "epoch": 0.61,
2001
+ "grad_norm": 0.9459949135780334,
2002
+ "learning_rate": 6.9892595962775826e-06,
2003
+ "loss": 0.7097,
2004
+ "step": 285
2005
+ },
2006
+ {
2007
+ "epoch": 0.61,
2008
+ "grad_norm": 1.1018948554992676,
2009
+ "learning_rate": 6.923053533921312e-06,
2010
+ "loss": 0.8045,
2011
+ "step": 286
2012
+ },
2013
+ {
2014
+ "epoch": 0.61,
2015
+ "grad_norm": 0.9634941816329956,
2016
+ "learning_rate": 6.85699611340333e-06,
2017
+ "loss": 0.7283,
2018
+ "step": 287
2019
+ },
2020
+ {
2021
+ "epoch": 0.62,
2022
+ "grad_norm": 0.7713951468467712,
2023
+ "learning_rate": 6.791090525840722e-06,
2024
+ "loss": 0.7872,
2025
+ "step": 288
2026
+ },
2027
+ {
2028
+ "epoch": 0.62,
2029
+ "grad_norm": 0.95869380235672,
2030
+ "learning_rate": 6.725339955015777e-06,
2031
+ "loss": 0.757,
2032
+ "step": 289
2033
+ },
2034
+ {
2035
+ "epoch": 0.62,
2036
+ "grad_norm": 1.0996911525726318,
2037
+ "learning_rate": 6.659747577222215e-06,
2038
+ "loss": 0.7636,
2039
+ "step": 290
2040
+ },
2041
+ {
2042
+ "epoch": 0.62,
2043
+ "grad_norm": 0.7740480899810791,
2044
+ "learning_rate": 6.5943165611117244e-06,
2045
+ "loss": 0.7933,
2046
+ "step": 291
2047
+ },
2048
+ {
2049
+ "epoch": 0.63,
2050
+ "grad_norm": 0.6407970786094666,
2051
+ "learning_rate": 6.529050067540887e-06,
2052
+ "loss": 0.7556,
2053
+ "step": 292
2054
+ },
2055
+ {
2056
+ "epoch": 0.63,
2057
+ "grad_norm": 0.7106875777244568,
2058
+ "learning_rate": 6.4639512494185104e-06,
2059
+ "loss": 0.7393,
2060
+ "step": 293
2061
+ },
2062
+ {
2063
+ "epoch": 0.63,
2064
+ "grad_norm": 0.6763285398483276,
2065
+ "learning_rate": 6.39902325155328e-06,
2066
+ "loss": 0.786,
2067
+ "step": 294
2068
+ },
2069
+ {
2070
+ "epoch": 0.63,
2071
+ "grad_norm": 0.8059327006340027,
2072
+ "learning_rate": 6.334269210501876e-06,
2073
+ "loss": 0.7669,
2074
+ "step": 295
2075
+ },
2076
+ {
2077
+ "epoch": 0.63,
2078
+ "grad_norm": 0.6124692559242249,
2079
+ "learning_rate": 6.269692254417408e-06,
2080
+ "loss": 0.7802,
2081
+ "step": 296
2082
+ },
2083
+ {
2084
+ "epoch": 0.64,
2085
+ "grad_norm": 1.0250524282455444,
2086
+ "learning_rate": 6.205295502898348e-06,
2087
+ "loss": 0.7889,
2088
+ "step": 297
2089
+ },
2090
+ {
2091
+ "epoch": 0.64,
2092
+ "grad_norm": 0.83560711145401,
2093
+ "learning_rate": 6.141082066837791e-06,
2094
+ "loss": 0.7176,
2095
+ "step": 298
2096
+ },
2097
+ {
2098
+ "epoch": 0.64,
2099
+ "grad_norm": 1.0154820680618286,
2100
+ "learning_rate": 6.077055048273193e-06,
2101
+ "loss": 0.7941,
2102
+ "step": 299
2103
+ },
2104
+ {
2105
+ "epoch": 0.64,
2106
+ "grad_norm": 0.7823308706283569,
2107
+ "learning_rate": 6.013217540236503e-06,
2108
+ "loss": 0.7533,
2109
+ "step": 300
2110
+ },
2111
+ {
2112
+ "epoch": 0.64,
2113
+ "grad_norm": 0.7707633376121521,
2114
+ "learning_rate": 5.9495726266047605e-06,
2115
+ "loss": 0.7922,
2116
+ "step": 301
2117
+ },
2118
+ {
2119
+ "epoch": 0.65,
2120
+ "grad_norm": 0.7845512628555298,
2121
+ "learning_rate": 5.886123381951103e-06,
2122
+ "loss": 0.7215,
2123
+ "step": 302
2124
+ },
2125
+ {
2126
+ "epoch": 0.65,
2127
+ "grad_norm": 0.9354090690612793,
2128
+ "learning_rate": 5.822872871396255e-06,
2129
+ "loss": 0.767,
2130
+ "step": 303
2131
+ },
2132
+ {
2133
+ "epoch": 0.65,
2134
+ "grad_norm": 0.9846552610397339,
2135
+ "learning_rate": 5.759824150460436e-06,
2136
+ "loss": 0.7866,
2137
+ "step": 304
2138
+ },
2139
+ {
2140
+ "epoch": 0.65,
2141
+ "grad_norm": 0.82850182056427,
2142
+ "learning_rate": 5.696980264915777e-06,
2143
+ "loss": 0.7449,
2144
+ "step": 305
2145
+ },
2146
+ {
2147
+ "epoch": 0.66,
2148
+ "grad_norm": 0.7820140719413757,
2149
+ "learning_rate": 5.63434425063917e-06,
2150
+ "loss": 0.7662,
2151
+ "step": 306
2152
+ },
2153
+ {
2154
+ "epoch": 0.66,
2155
+ "grad_norm": 0.7161849737167358,
2156
+ "learning_rate": 5.571919133465605e-06,
2157
+ "loss": 0.7683,
2158
+ "step": 307
2159
+ },
2160
+ {
2161
+ "epoch": 0.66,
2162
+ "grad_norm": 0.7008098363876343,
2163
+ "learning_rate": 5.50970792904203e-06,
2164
+ "loss": 0.7755,
2165
+ "step": 308
2166
+ },
2167
+ {
2168
+ "epoch": 0.66,
2169
+ "grad_norm": 0.7139153480529785,
2170
+ "learning_rate": 5.447713642681612e-06,
2171
+ "loss": 0.7443,
2172
+ "step": 309
2173
+ },
2174
+ {
2175
+ "epoch": 0.66,
2176
+ "grad_norm": 1.005012035369873,
2177
+ "learning_rate": 5.3859392692186256e-06,
2178
+ "loss": 0.7852,
2179
+ "step": 310
2180
+ },
2181
+ {
2182
+ "epoch": 0.67,
2183
+ "grad_norm": 0.7501611113548279,
2184
+ "learning_rate": 5.324387792863719e-06,
2185
+ "loss": 0.7567,
2186
+ "step": 311
2187
+ },
2188
+ {
2189
+ "epoch": 0.67,
2190
+ "grad_norm": 0.7607911825180054,
2191
+ "learning_rate": 5.263062187059785e-06,
2192
+ "loss": 0.7597,
2193
+ "step": 312
2194
+ },
2195
+ {
2196
+ "epoch": 0.67,
2197
+ "grad_norm": 0.992053210735321,
2198
+ "learning_rate": 5.201965414338308e-06,
2199
+ "loss": 0.7655,
2200
+ "step": 313
2201
+ },
2202
+ {
2203
+ "epoch": 0.67,
2204
+ "grad_norm": 0.6949525475502014,
2205
+ "learning_rate": 5.14110042617625e-06,
2206
+ "loss": 0.7374,
2207
+ "step": 314
2208
+ },
2209
+ {
2210
+ "epoch": 0.67,
2211
+ "grad_norm": 1.5913808345794678,
2212
+ "learning_rate": 5.080470162853473e-06,
2213
+ "loss": 0.7907,
2214
+ "step": 315
2215
+ },
2216
+ {
2217
+ "epoch": 0.68,
2218
+ "grad_norm": 0.8287177085876465,
2219
+ "learning_rate": 5.020077553310694e-06,
2220
+ "loss": 0.7055,
2221
+ "step": 316
2222
+ },
2223
+ {
2224
+ "epoch": 0.68,
2225
+ "grad_norm": 0.8051833510398865,
2226
+ "learning_rate": 4.959925515008003e-06,
2227
+ "loss": 0.7803,
2228
+ "step": 317
2229
+ },
2230
+ {
2231
+ "epoch": 0.68,
2232
+ "grad_norm": 0.7546500563621521,
2233
+ "learning_rate": 4.9000169537839126e-06,
2234
+ "loss": 0.7545,
2235
+ "step": 318
2236
+ },
2237
+ {
2238
+ "epoch": 0.68,
2239
+ "grad_norm": 0.7356598973274231,
2240
+ "learning_rate": 4.840354763714991e-06,
2241
+ "loss": 0.744,
2242
+ "step": 319
2243
+ },
2244
+ {
2245
+ "epoch": 0.69,
2246
+ "grad_norm": 0.6519478559494019,
2247
+ "learning_rate": 4.780941826976054e-06,
2248
+ "loss": 0.7621,
2249
+ "step": 320
2250
+ },
2251
+ {
2252
+ "epoch": 0.69,
2253
+ "grad_norm": 0.6492840051651001,
2254
+ "learning_rate": 4.721781013700928e-06,
2255
+ "loss": 0.7444,
2256
+ "step": 321
2257
+ },
2258
+ {
2259
+ "epoch": 0.69,
2260
+ "grad_norm": 0.9139990210533142,
2261
+ "learning_rate": 4.662875181843799e-06,
2262
+ "loss": 0.7904,
2263
+ "step": 322
2264
+ },
2265
+ {
2266
+ "epoch": 0.69,
2267
+ "grad_norm": 0.7614346742630005,
2268
+ "learning_rate": 4.604227177041156e-06,
2269
+ "loss": 0.7186,
2270
+ "step": 323
2271
+ },
2272
+ {
2273
+ "epoch": 0.69,
2274
+ "grad_norm": 0.7474337816238403,
2275
+ "learning_rate": 4.545839832474318e-06,
2276
+ "loss": 0.7475,
2277
+ "step": 324
2278
+ },
2279
+ {
2280
+ "epoch": 0.7,
2281
+ "grad_norm": 0.8080284595489502,
2282
+ "learning_rate": 4.487715968732568e-06,
2283
+ "loss": 0.7641,
2284
+ "step": 325
2285
+ },
2286
+ {
2287
+ "epoch": 0.7,
2288
+ "grad_norm": 0.8405642509460449,
2289
+ "learning_rate": 4.429858393676898e-06,
2290
+ "loss": 0.7749,
2291
+ "step": 326
2292
+ },
2293
+ {
2294
+ "epoch": 0.7,
2295
+ "grad_norm": 0.719484806060791,
2296
+ "learning_rate": 4.3722699023043634e-06,
2297
+ "loss": 0.7181,
2298
+ "step": 327
2299
+ },
2300
+ {
2301
+ "epoch": 0.7,
2302
+ "grad_norm": 0.6552863717079163,
2303
+ "learning_rate": 4.314953276613066e-06,
2304
+ "loss": 0.8089,
2305
+ "step": 328
2306
+ },
2307
+ {
2308
+ "epoch": 0.7,
2309
+ "grad_norm": 0.7018444538116455,
2310
+ "learning_rate": 4.257911285467754e-06,
2311
+ "loss": 0.734,
2312
+ "step": 329
2313
+ },
2314
+ {
2315
+ "epoch": 0.71,
2316
+ "grad_norm": 0.7911511063575745,
2317
+ "learning_rate": 4.201146684466065e-06,
2318
+ "loss": 0.7752,
2319
+ "step": 330
2320
+ },
2321
+ {
2322
+ "epoch": 0.71,
2323
+ "grad_norm": 0.8911862373352051,
2324
+ "learning_rate": 4.144662215805426e-06,
2325
+ "loss": 0.7733,
2326
+ "step": 331
2327
+ },
2328
+ {
2329
+ "epoch": 0.71,
2330
+ "grad_norm": 0.8746122121810913,
2331
+ "learning_rate": 4.088460608150537e-06,
2332
+ "loss": 0.7336,
2333
+ "step": 332
2334
+ },
2335
+ {
2336
+ "epoch": 0.71,
2337
+ "grad_norm": 0.6681094169616699,
2338
+ "learning_rate": 4.0325445765016145e-06,
2339
+ "loss": 0.7892,
2340
+ "step": 333
2341
+ },
2342
+ {
2343
+ "epoch": 0.72,
2344
+ "grad_norm": 0.8015274405479431,
2345
+ "learning_rate": 3.9769168220631745e-06,
2346
+ "loss": 0.774,
2347
+ "step": 334
2348
+ },
2349
+ {
2350
+ "epoch": 0.72,
2351
+ "grad_norm": 0.8299674391746521,
2352
+ "learning_rate": 3.921580032113602e-06,
2353
+ "loss": 0.7814,
2354
+ "step": 335
2355
+ },
2356
+ {
2357
+ "epoch": 0.72,
2358
+ "grad_norm": 0.7029886245727539,
2359
+ "learning_rate": 3.866536879875269e-06,
2360
+ "loss": 0.7556,
2361
+ "step": 336
2362
+ },
2363
+ {
2364
+ "epoch": 0.72,
2365
+ "grad_norm": 0.84246826171875,
2366
+ "learning_rate": 3.81179002438546e-06,
2367
+ "loss": 0.7678,
2368
+ "step": 337
2369
+ },
2370
+ {
2371
+ "epoch": 0.72,
2372
+ "grad_norm": 0.8028191328048706,
2373
+ "learning_rate": 3.7573421103678707e-06,
2374
+ "loss": 0.7679,
2375
+ "step": 338
2376
+ },
2377
+ {
2378
+ "epoch": 0.73,
2379
+ "grad_norm": 0.7954402565956116,
2380
+ "learning_rate": 3.7031957681048604e-06,
2381
+ "loss": 0.7265,
2382
+ "step": 339
2383
+ },
2384
+ {
2385
+ "epoch": 0.73,
2386
+ "grad_norm": 0.7864363193511963,
2387
+ "learning_rate": 3.649353613310409e-06,
2388
+ "loss": 0.7926,
2389
+ "step": 340
2390
+ },
2391
+ {
2392
+ "epoch": 0.73,
2393
+ "grad_norm": 0.6914604902267456,
2394
+ "learning_rate": 3.5958182470037127e-06,
2395
+ "loss": 0.749,
2396
+ "step": 341
2397
+ },
2398
+ {
2399
+ "epoch": 0.73,
2400
+ "grad_norm": 0.7519959807395935,
2401
+ "learning_rate": 3.5425922553835866e-06,
2402
+ "loss": 0.7788,
2403
+ "step": 342
2404
+ },
2405
+ {
2406
+ "epoch": 0.73,
2407
+ "grad_norm": 1.0774928331375122,
2408
+ "learning_rate": 3.4896782097034755e-06,
2409
+ "loss": 0.7313,
2410
+ "step": 343
2411
+ },
2412
+ {
2413
+ "epoch": 0.74,
2414
+ "grad_norm": 0.7848466634750366,
2415
+ "learning_rate": 3.4370786661472922e-06,
2416
+ "loss": 0.7901,
2417
+ "step": 344
2418
+ },
2419
+ {
2420
+ "epoch": 0.74,
2421
+ "grad_norm": 0.7957246899604797,
2422
+ "learning_rate": 3.384796165705885e-06,
2423
+ "loss": 0.7606,
2424
+ "step": 345
2425
+ },
2426
+ {
2427
+ "epoch": 0.74,
2428
+ "grad_norm": 0.6149446368217468,
2429
+ "learning_rate": 3.3328332340543314e-06,
2430
+ "loss": 0.7831,
2431
+ "step": 346
2432
+ },
2433
+ {
2434
+ "epoch": 0.74,
2435
+ "grad_norm": 1.1103349924087524,
2436
+ "learning_rate": 3.281192381429894e-06,
2437
+ "loss": 0.7119,
2438
+ "step": 347
2439
+ },
2440
+ {
2441
+ "epoch": 0.75,
2442
+ "grad_norm": 0.7909545302391052,
2443
+ "learning_rate": 3.2298761025107707e-06,
2444
+ "loss": 0.7467,
2445
+ "step": 348
2446
+ },
2447
+ {
2448
+ "epoch": 0.75,
2449
+ "grad_norm": 0.7904770970344543,
2450
+ "learning_rate": 3.178886876295578e-06,
2451
+ "loss": 0.7978,
2452
+ "step": 349
2453
+ },
2454
+ {
2455
+ "epoch": 0.75,
2456
+ "grad_norm": 0.7983621954917908,
2457
+ "learning_rate": 3.128227165983595e-06,
2458
+ "loss": 0.7281,
2459
+ "step": 350
2460
+ },
2461
+ {
2462
+ "epoch": 0.75,
2463
+ "grad_norm": 0.6307840943336487,
2464
+ "learning_rate": 3.0778994188557722e-06,
2465
+ "loss": 0.7959,
2466
+ "step": 351
2467
+ },
2468
+ {
2469
+ "epoch": 0.75,
2470
+ "grad_norm": 0.8515567779541016,
2471
+ "learning_rate": 3.027906066156503e-06,
2472
+ "loss": 0.74,
2473
+ "step": 352
2474
+ },
2475
+ {
2476
+ "epoch": 0.76,
2477
+ "grad_norm": 0.6278306245803833,
2478
+ "learning_rate": 2.978249522976181e-06,
2479
+ "loss": 0.748,
2480
+ "step": 353
2481
+ },
2482
+ {
2483
+ "epoch": 0.76,
2484
+ "grad_norm": 0.8014828562736511,
2485
+ "learning_rate": 2.9289321881345257e-06,
2486
+ "loss": 0.74,
2487
+ "step": 354
2488
+ },
2489
+ {
2490
+ "epoch": 0.76,
2491
+ "grad_norm": 1.187994122505188,
2492
+ "learning_rate": 2.879956444064703e-06,
2493
+ "loss": 0.7533,
2494
+ "step": 355
2495
+ },
2496
+ {
2497
+ "epoch": 0.76,
2498
+ "grad_norm": 0.6586583256721497,
2499
+ "learning_rate": 2.8313246566982342e-06,
2500
+ "loss": 0.7291,
2501
+ "step": 356
2502
+ },
2503
+ {
2504
+ "epoch": 0.76,
2505
+ "grad_norm": 0.7355678677558899,
2506
+ "learning_rate": 2.783039175350699e-06,
2507
+ "loss": 0.7521,
2508
+ "step": 357
2509
+ },
2510
+ {
2511
+ "epoch": 0.77,
2512
+ "grad_norm": 0.636843740940094,
2513
+ "learning_rate": 2.735102332608247e-06,
2514
+ "loss": 0.7392,
2515
+ "step": 358
2516
+ },
2517
+ {
2518
+ "epoch": 0.77,
2519
+ "grad_norm": 0.7756280303001404,
2520
+ "learning_rate": 2.6875164442149147e-06,
2521
+ "loss": 0.7927,
2522
+ "step": 359
2523
+ },
2524
+ {
2525
+ "epoch": 0.77,
2526
+ "grad_norm": 0.8138651251792908,
2527
+ "learning_rate": 2.640283808960754e-06,
2528
+ "loss": 0.778,
2529
+ "step": 360
2530
+ },
2531
+ {
2532
+ "epoch": 0.77,
2533
+ "grad_norm": 0.7197765111923218,
2534
+ "learning_rate": 2.5934067085707835e-06,
2535
+ "loss": 0.7382,
2536
+ "step": 361
2537
+ },
2538
+ {
2539
+ "epoch": 0.78,
2540
+ "grad_norm": 0.804681658744812,
2541
+ "learning_rate": 2.54688740759476e-06,
2542
+ "loss": 0.7456,
2543
+ "step": 362
2544
+ },
2545
+ {
2546
+ "epoch": 0.78,
2547
+ "grad_norm": 0.633647620677948,
2548
+ "learning_rate": 2.500728153297788e-06,
2549
+ "loss": 0.7492,
2550
+ "step": 363
2551
+ },
2552
+ {
2553
+ "epoch": 0.78,
2554
+ "grad_norm": 0.8081827759742737,
2555
+ "learning_rate": 2.454931175551746e-06,
2556
+ "loss": 0.7657,
2557
+ "step": 364
2558
+ },
2559
+ {
2560
+ "epoch": 0.78,
2561
+ "grad_norm": 0.7836536169052124,
2562
+ "learning_rate": 2.409498686727587e-06,
2563
+ "loss": 0.7666,
2564
+ "step": 365
2565
+ },
2566
+ {
2567
+ "epoch": 0.78,
2568
+ "grad_norm": 2.165264844894409,
2569
+ "learning_rate": 2.364432881588431e-06,
2570
+ "loss": 0.7266,
2571
+ "step": 366
2572
+ },
2573
+ {
2574
+ "epoch": 0.79,
2575
+ "grad_norm": 1.187296986579895,
2576
+ "learning_rate": 2.3197359371835802e-06,
2577
+ "loss": 0.8071,
2578
+ "step": 367
2579
+ },
2580
+ {
2581
+ "epoch": 0.79,
2582
+ "grad_norm": 0.6810380220413208,
2583
+ "learning_rate": 2.2754100127433033e-06,
2584
+ "loss": 0.7322,
2585
+ "step": 368
2586
+ },
2587
+ {
2588
+ "epoch": 0.79,
2589
+ "grad_norm": 0.6353223323822021,
2590
+ "learning_rate": 2.2314572495745746e-06,
2591
+ "loss": 0.7805,
2592
+ "step": 369
2593
+ },
2594
+ {
2595
+ "epoch": 0.79,
2596
+ "grad_norm": 0.74691241979599,
2597
+ "learning_rate": 2.187879770957585e-06,
2598
+ "loss": 0.7186,
2599
+ "step": 370
2600
+ },
2601
+ {
2602
+ "epoch": 0.79,
2603
+ "grad_norm": 0.7790375351905823,
2604
+ "learning_rate": 2.144679682043217e-06,
2605
+ "loss": 0.7743,
2606
+ "step": 371
2607
+ },
2608
+ {
2609
+ "epoch": 0.8,
2610
+ "grad_norm": 0.9003098011016846,
2611
+ "learning_rate": 2.1018590697513007e-06,
2612
+ "loss": 0.7577,
2613
+ "step": 372
2614
+ },
2615
+ {
2616
+ "epoch": 0.8,
2617
+ "grad_norm": 0.7259723544120789,
2618
+ "learning_rate": 2.0594200026698363e-06,
2619
+ "loss": 0.7921,
2620
+ "step": 373
2621
+ },
2622
+ {
2623
+ "epoch": 0.8,
2624
+ "grad_norm": 0.7609323859214783,
2625
+ "learning_rate": 2.017364530955055e-06,
2626
+ "loss": 0.7276,
2627
+ "step": 374
2628
+ },
2629
+ {
2630
+ "epoch": 0.8,
2631
+ "grad_norm": 1.556248426437378,
2632
+ "learning_rate": 1.9756946862323534e-06,
2633
+ "loss": 0.782,
2634
+ "step": 375
2635
+ },
2636
+ {
2637
+ "epoch": 0.81,
2638
+ "grad_norm": 0.7257483005523682,
2639
+ "learning_rate": 1.934412481498198e-06,
2640
+ "loss": 0.7655,
2641
+ "step": 376
2642
+ },
2643
+ {
2644
+ "epoch": 0.81,
2645
+ "grad_norm": 0.7583906054496765,
2646
+ "learning_rate": 1.8935199110228274e-06,
2647
+ "loss": 0.7412,
2648
+ "step": 377
2649
+ },
2650
+ {
2651
+ "epoch": 0.81,
2652
+ "grad_norm": 0.8019598722457886,
2653
+ "learning_rate": 1.8530189502539608e-06,
2654
+ "loss": 0.7554,
2655
+ "step": 378
2656
+ },
2657
+ {
2658
+ "epoch": 0.81,
2659
+ "grad_norm": 3.036848783493042,
2660
+ "learning_rate": 1.8129115557213262e-06,
2661
+ "loss": 0.749,
2662
+ "step": 379
2663
+ },
2664
+ {
2665
+ "epoch": 0.81,
2666
+ "grad_norm": 0.9667990803718567,
2667
+ "learning_rate": 1.77319966494218e-06,
2668
+ "loss": 0.8014,
2669
+ "step": 380
2670
+ },
2671
+ {
2672
+ "epoch": 0.82,
2673
+ "grad_norm": 0.6818022131919861,
2674
+ "learning_rate": 1.7338851963276827e-06,
2675
+ "loss": 0.7119,
2676
+ "step": 381
2677
+ },
2678
+ {
2679
+ "epoch": 0.82,
2680
+ "grad_norm": 0.7843496799468994,
2681
+ "learning_rate": 1.6949700490902344e-06,
2682
+ "loss": 0.7811,
2683
+ "step": 382
2684
+ },
2685
+ {
2686
+ "epoch": 0.82,
2687
+ "grad_norm": 0.8651720881462097,
2688
+ "learning_rate": 1.6564561031517278e-06,
2689
+ "loss": 0.7667,
2690
+ "step": 383
2691
+ },
2692
+ {
2693
+ "epoch": 0.82,
2694
+ "grad_norm": 1.0504810810089111,
2695
+ "learning_rate": 1.6183452190527317e-06,
2696
+ "loss": 0.7769,
2697
+ "step": 384
2698
+ },
2699
+ {
2700
+ "epoch": 0.82,
2701
+ "grad_norm": 0.7682644724845886,
2702
+ "learning_rate": 1.5806392378626079e-06,
2703
+ "loss": 0.7382,
2704
+ "step": 385
2705
+ },
2706
+ {
2707
+ "epoch": 0.83,
2708
+ "grad_norm": 0.8382360935211182,
2709
+ "learning_rate": 1.543339981090578e-06,
2710
+ "loss": 0.7557,
2711
+ "step": 386
2712
+ },
2713
+ {
2714
+ "epoch": 0.83,
2715
+ "grad_norm": 0.6545696258544922,
2716
+ "learning_rate": 1.5064492505977234e-06,
2717
+ "loss": 0.7465,
2718
+ "step": 387
2719
+ },
2720
+ {
2721
+ "epoch": 0.83,
2722
+ "grad_norm": 0.8746246099472046,
2723
+ "learning_rate": 1.4699688285099489e-06,
2724
+ "loss": 0.7518,
2725
+ "step": 388
2726
+ },
2727
+ {
2728
+ "epoch": 0.83,
2729
+ "grad_norm": 0.7189457416534424,
2730
+ "learning_rate": 1.433900477131882e-06,
2731
+ "loss": 0.7783,
2732
+ "step": 389
2733
+ },
2734
+ {
2735
+ "epoch": 0.84,
2736
+ "grad_norm": 0.6742172241210938,
2737
+ "learning_rate": 1.3982459388617453e-06,
2738
+ "loss": 0.7703,
2739
+ "step": 390
2740
+ },
2741
+ {
2742
+ "epoch": 0.84,
2743
+ "grad_norm": 0.6717891693115234,
2744
+ "learning_rate": 1.363006936107183e-06,
2745
+ "loss": 0.7566,
2746
+ "step": 391
2747
+ },
2748
+ {
2749
+ "epoch": 0.84,
2750
+ "grad_norm": 0.6821526885032654,
2751
+ "learning_rate": 1.3281851712020522e-06,
2752
+ "loss": 0.712,
2753
+ "step": 392
2754
+ },
2755
+ {
2756
+ "epoch": 0.84,
2757
+ "grad_norm": 0.6083774566650391,
2758
+ "learning_rate": 1.29378232632419e-06,
2759
+ "loss": 0.7865,
2760
+ "step": 393
2761
+ },
2762
+ {
2763
+ "epoch": 0.84,
2764
+ "grad_norm": 1.004950761795044,
2765
+ "learning_rate": 1.259800063414146e-06,
2766
+ "loss": 0.7437,
2767
+ "step": 394
2768
+ },
2769
+ {
2770
+ "epoch": 0.85,
2771
+ "grad_norm": 0.7969427704811096,
2772
+ "learning_rate": 1.2262400240949023e-06,
2773
+ "loss": 0.6931,
2774
+ "step": 395
2775
+ },
2776
+ {
2777
+ "epoch": 0.85,
2778
+ "grad_norm": 0.6904407739639282,
2779
+ "learning_rate": 1.1931038295925646e-06,
2780
+ "loss": 0.7848,
2781
+ "step": 396
2782
+ },
2783
+ {
2784
+ "epoch": 0.85,
2785
+ "grad_norm": 0.7331759333610535,
2786
+ "learning_rate": 1.1603930806580443e-06,
2787
+ "loss": 0.7295,
2788
+ "step": 397
2789
+ },
2790
+ {
2791
+ "epoch": 0.85,
2792
+ "grad_norm": 0.6611264944076538,
2793
+ "learning_rate": 1.128109357489734e-06,
2794
+ "loss": 0.8059,
2795
+ "step": 398
2796
+ },
2797
+ {
2798
+ "epoch": 0.85,
2799
+ "grad_norm": 0.8612853288650513,
2800
+ "learning_rate": 1.0962542196571636e-06,
2801
+ "loss": 0.7421,
2802
+ "step": 399
2803
+ },
2804
+ {
2805
+ "epoch": 0.86,
2806
+ "grad_norm": 0.6944461464881897,
2807
+ "learning_rate": 1.064829206025665e-06,
2808
+ "loss": 0.7537,
2809
+ "step": 400
2810
+ },
2811
+ {
2812
+ "epoch": 0.86,
2813
+ "grad_norm": 0.7044425010681152,
2814
+ "learning_rate": 1.0338358346820355e-06,
2815
+ "loss": 0.7097,
2816
+ "step": 401
2817
+ },
2818
+ {
2819
+ "epoch": 0.86,
2820
+ "grad_norm": 0.6490280032157898,
2821
+ "learning_rate": 1.003275602861188e-06,
2822
+ "loss": 0.7768,
2823
+ "step": 402
2824
+ },
2825
+ {
2826
+ "epoch": 0.86,
2827
+ "grad_norm": 0.6698139309883118,
2828
+ "learning_rate": 9.731499868738448e-07,
2829
+ "loss": 0.7691,
2830
+ "step": 403
2831
+ },
2832
+ {
2833
+ "epoch": 0.87,
2834
+ "grad_norm": 0.7602378129959106,
2835
+ "learning_rate": 9.434604420351912e-07,
2836
+ "loss": 0.7538,
2837
+ "step": 404
2838
+ },
2839
+ {
2840
+ "epoch": 0.87,
2841
+ "grad_norm": 0.7384971976280212,
2842
+ "learning_rate": 9.142084025945986e-07,
2843
+ "loss": 0.7224,
2844
+ "step": 405
2845
+ },
2846
+ {
2847
+ "epoch": 0.87,
2848
+ "grad_norm": 0.6048542261123657,
2849
+ "learning_rate": 8.853952816663214e-07,
2850
+ "loss": 0.8024,
2851
+ "step": 406
2852
+ },
2853
+ {
2854
+ "epoch": 0.87,
2855
+ "grad_norm": 0.6573376655578613,
2856
+ "learning_rate": 8.570224711612385e-07,
2857
+ "loss": 0.7336,
2858
+ "step": 407
2859
+ },
2860
+ {
2861
+ "epoch": 0.87,
2862
+ "grad_norm": 0.6137224435806274,
2863
+ "learning_rate": 8.290913417196178e-07,
2864
+ "loss": 0.79,
2865
+ "step": 408
2866
+ },
2867
+ {
2868
+ "epoch": 0.88,
2869
+ "grad_norm": 0.5987165570259094,
2870
+ "learning_rate": 8.016032426448816e-07,
2871
+ "loss": 0.723,
2872
+ "step": 409
2873
+ },
2874
+ {
2875
+ "epoch": 0.88,
2876
+ "grad_norm": 0.7513317465782166,
2877
+ "learning_rate": 7.745595018384577e-07,
2878
+ "loss": 0.7547,
2879
+ "step": 410
2880
+ },
2881
+ {
2882
+ "epoch": 0.88,
2883
+ "grad_norm": 0.6956148147583008,
2884
+ "learning_rate": 7.479614257355972e-07,
2885
+ "loss": 0.7716,
2886
+ "step": 411
2887
+ },
2888
+ {
2889
+ "epoch": 0.88,
2890
+ "grad_norm": 0.7517545819282532,
2891
+ "learning_rate": 7.218102992422882e-07,
2892
+ "loss": 0.7415,
2893
+ "step": 412
2894
+ },
2895
+ {
2896
+ "epoch": 0.88,
2897
+ "grad_norm": 0.7091813087463379,
2898
+ "learning_rate": 6.961073856731648e-07,
2899
+ "loss": 0.7552,
2900
+ "step": 413
2901
+ },
2902
+ {
2903
+ "epoch": 0.89,
2904
+ "grad_norm": 0.6687448620796204,
2905
+ "learning_rate": 6.708539266905e-07,
2906
+ "loss": 0.7959,
2907
+ "step": 414
2908
+ },
2909
+ {
2910
+ "epoch": 0.89,
2911
+ "grad_norm": 0.7798054218292236,
2912
+ "learning_rate": 6.460511422441984e-07,
2913
+ "loss": 0.6974,
2914
+ "step": 415
2915
+ },
2916
+ {
2917
+ "epoch": 0.89,
2918
+ "grad_norm": 0.7497337460517883,
2919
+ "learning_rate": 6.21700230512885e-07,
2920
+ "loss": 0.7888,
2921
+ "step": 416
2922
+ },
2923
+ {
2924
+ "epoch": 0.89,
2925
+ "grad_norm": 0.8241896629333496,
2926
+ "learning_rate": 5.978023678460099e-07,
2927
+ "loss": 0.7506,
2928
+ "step": 417
2929
+ },
2930
+ {
2931
+ "epoch": 0.9,
2932
+ "grad_norm": 0.614700198173523,
2933
+ "learning_rate": 5.743587087070235e-07,
2934
+ "loss": 0.7527,
2935
+ "step": 418
2936
+ },
2937
+ {
2938
+ "epoch": 0.9,
2939
+ "grad_norm": 0.747754693031311,
2940
+ "learning_rate": 5.513703856176112e-07,
2941
+ "loss": 0.7463,
2942
+ "step": 419
2943
+ },
2944
+ {
2945
+ "epoch": 0.9,
2946
+ "grad_norm": 0.7050479650497437,
2947
+ "learning_rate": 5.288385091029724e-07,
2948
+ "loss": 0.7713,
2949
+ "step": 420
2950
+ },
2951
+ {
2952
+ "epoch": 0.9,
2953
+ "grad_norm": 1.1944342851638794,
2954
+ "learning_rate": 5.067641676381918e-07,
2955
+ "loss": 0.7593,
2956
+ "step": 421
2957
+ },
2958
+ {
2959
+ "epoch": 0.9,
2960
+ "grad_norm": 0.6154138445854187,
2961
+ "learning_rate": 4.851484275956331e-07,
2962
+ "loss": 0.7298,
2963
+ "step": 422
2964
+ },
2965
+ {
2966
+ "epoch": 0.91,
2967
+ "grad_norm": 2.1620090007781982,
2968
+ "learning_rate": 4.6399233319344703e-07,
2969
+ "loss": 0.7723,
2970
+ "step": 423
2971
+ },
2972
+ {
2973
+ "epoch": 0.91,
2974
+ "grad_norm": 0.6451908349990845,
2975
+ "learning_rate": 4.432969064451109e-07,
2976
+ "loss": 0.7605,
2977
+ "step": 424
2978
+ },
2979
+ {
2980
+ "epoch": 0.91,
2981
+ "grad_norm": 1.1842758655548096,
2982
+ "learning_rate": 4.230631471100655e-07,
2983
+ "loss": 0.7478,
2984
+ "step": 425
2985
+ },
2986
+ {
2987
+ "epoch": 0.91,
2988
+ "grad_norm": 0.7377268075942993,
2989
+ "learning_rate": 4.0329203264541594e-07,
2990
+ "loss": 0.7502,
2991
+ "step": 426
2992
+ },
2993
+ {
2994
+ "epoch": 0.91,
2995
+ "grad_norm": 0.7105292677879333,
2996
+ "learning_rate": 3.8398451815870984e-07,
2997
+ "loss": 0.7391,
2998
+ "step": 427
2999
+ },
3000
+ {
3001
+ "epoch": 0.92,
3002
+ "grad_norm": 0.7181400656700134,
3003
+ "learning_rate": 3.6514153636180384e-07,
3004
+ "loss": 0.7818,
3005
+ "step": 428
3006
+ },
3007
+ {
3008
+ "epoch": 0.92,
3009
+ "grad_norm": 0.6796531081199646,
3010
+ "learning_rate": 3.467639975257997e-07,
3011
+ "loss": 0.7778,
3012
+ "step": 429
3013
+ },
3014
+ {
3015
+ "epoch": 0.92,
3016
+ "grad_norm": 0.6979401111602783,
3017
+ "learning_rate": 3.2885278943707524e-07,
3018
+ "loss": 0.7531,
3019
+ "step": 430
3020
+ },
3021
+ {
3022
+ "epoch": 0.92,
3023
+ "grad_norm": 0.9989197850227356,
3024
+ "learning_rate": 3.114087773543939e-07,
3025
+ "loss": 0.7049,
3026
+ "step": 431
3027
+ },
3028
+ {
3029
+ "epoch": 0.93,
3030
+ "grad_norm": 0.7881868481636047,
3031
+ "learning_rate": 2.9443280396710847e-07,
3032
+ "loss": 0.7798,
3033
+ "step": 432
3034
+ },
3035
+ {
3036
+ "epoch": 0.93,
3037
+ "grad_norm": 0.7045859694480896,
3038
+ "learning_rate": 2.7792568935444796e-07,
3039
+ "loss": 0.7405,
3040
+ "step": 433
3041
+ },
3042
+ {
3043
+ "epoch": 0.93,
3044
+ "grad_norm": 0.7407889366149902,
3045
+ "learning_rate": 2.618882309459081e-07,
3046
+ "loss": 0.6954,
3047
+ "step": 434
3048
+ },
3049
+ {
3050
+ "epoch": 0.93,
3051
+ "grad_norm": 0.6761502623558044,
3052
+ "learning_rate": 2.4632120348272e-07,
3053
+ "loss": 0.7933,
3054
+ "step": 435
3055
+ },
3056
+ {
3057
+ "epoch": 0.93,
3058
+ "grad_norm": 0.9298683404922485,
3059
+ "learning_rate": 2.312253589804314e-07,
3060
+ "loss": 0.7192,
3061
+ "step": 436
3062
+ },
3063
+ {
3064
+ "epoch": 0.94,
3065
+ "grad_norm": 0.6314147710800171,
3066
+ "learning_rate": 2.166014266925731e-07,
3067
+ "loss": 0.8084,
3068
+ "step": 437
3069
+ },
3070
+ {
3071
+ "epoch": 0.94,
3072
+ "grad_norm": 0.7935160994529724,
3073
+ "learning_rate": 2.0245011307543416e-07,
3074
+ "loss": 0.748,
3075
+ "step": 438
3076
+ },
3077
+ {
3078
+ "epoch": 0.94,
3079
+ "grad_norm": 0.9580221176147461,
3080
+ "learning_rate": 1.88772101753929e-07,
3081
+ "loss": 0.7057,
3082
+ "step": 439
3083
+ },
3084
+ {
3085
+ "epoch": 0.94,
3086
+ "grad_norm": 0.766875147819519,
3087
+ "learning_rate": 1.7556805348858063e-07,
3088
+ "loss": 0.7403,
3089
+ "step": 440
3090
+ },
3091
+ {
3092
+ "epoch": 0.94,
3093
+ "grad_norm": 0.66485196352005,
3094
+ "learning_rate": 1.6283860614358936e-07,
3095
+ "loss": 0.8058,
3096
+ "step": 441
3097
+ },
3098
+ {
3099
+ "epoch": 0.95,
3100
+ "grad_norm": 0.7386608719825745,
3101
+ "learning_rate": 1.5058437465602982e-07,
3102
+ "loss": 0.698,
3103
+ "step": 442
3104
+ },
3105
+ {
3106
+ "epoch": 0.95,
3107
+ "grad_norm": 0.8943004012107849,
3108
+ "learning_rate": 1.388059510061379e-07,
3109
+ "loss": 0.7899,
3110
+ "step": 443
3111
+ },
3112
+ {
3113
+ "epoch": 0.95,
3114
+ "grad_norm": 0.6558826565742493,
3115
+ "learning_rate": 1.2750390418871605e-07,
3116
+ "loss": 0.7423,
3117
+ "step": 444
3118
+ },
3119
+ {
3120
+ "epoch": 0.95,
3121
+ "grad_norm": 0.6673349738121033,
3122
+ "learning_rate": 1.1667878018564171e-07,
3123
+ "loss": 0.8005,
3124
+ "step": 445
3125
+ },
3126
+ {
3127
+ "epoch": 0.96,
3128
+ "grad_norm": 1.509501576423645,
3129
+ "learning_rate": 1.063311019395008e-07,
3130
+ "loss": 0.7367,
3131
+ "step": 446
3132
+ },
3133
+ {
3134
+ "epoch": 0.96,
3135
+ "grad_norm": 0.948124349117279,
3136
+ "learning_rate": 9.64613693283123e-08,
3137
+ "loss": 0.7318,
3138
+ "step": 447
3139
+ },
3140
+ {
3141
+ "epoch": 0.96,
3142
+ "grad_norm": 0.6254904866218567,
3143
+ "learning_rate": 8.707005914139422e-08,
3144
+ "loss": 0.7596,
3145
+ "step": 448
3146
+ },
3147
+ {
3148
+ "epoch": 0.96,
3149
+ "grad_norm": 0.6569192409515381,
3150
+ "learning_rate": 7.815762505632096e-08,
3151
+ "loss": 0.7495,
3152
+ "step": 449
3153
+ },
3154
+ {
3155
+ "epoch": 0.96,
3156
+ "grad_norm": 0.7474935054779053,
3157
+ "learning_rate": 6.972449761700862e-08,
3158
+ "loss": 0.7723,
3159
+ "step": 450
3160
+ },
3161
+ {
3162
+ "epoch": 0.97,
3163
+ "grad_norm": 0.8758479356765747,
3164
+ "learning_rate": 6.177108421292266e-08,
3165
+ "loss": 0.7392,
3166
+ "step": 451
3167
+ },
3168
+ {
3169
+ "epoch": 0.97,
3170
+ "grad_norm": 1.0904728174209595,
3171
+ "learning_rate": 5.429776905938489e-08,
3172
+ "loss": 0.7561,
3173
+ "step": 452
3174
+ },
3175
+ {
3176
+ "epoch": 0.97,
3177
+ "grad_norm": 0.5738272666931152,
3178
+ "learning_rate": 4.7304913179025967e-08,
3179
+ "loss": 0.7998,
3180
+ "step": 453
3181
+ },
3182
+ {
3183
+ "epoch": 0.97,
3184
+ "grad_norm": 0.7406892776489258,
3185
+ "learning_rate": 4.0792854384338334e-08,
3186
+ "loss": 0.7018,
3187
+ "step": 454
3188
+ },
3189
+ {
3190
+ "epoch": 0.97,
3191
+ "grad_norm": 0.7328673601150513,
3192
+ "learning_rate": 3.4761907261356976e-08,
3193
+ "loss": 0.7957,
3194
+ "step": 455
3195
+ },
3196
+ {
3197
+ "epoch": 0.98,
3198
+ "grad_norm": 0.6121895909309387,
3199
+ "learning_rate": 2.9212363154463853e-08,
3200
+ "loss": 0.7514,
3201
+ "step": 456
3202
+ },
3203
+ {
3204
+ "epoch": 0.98,
3205
+ "grad_norm": 0.8818193674087524,
3206
+ "learning_rate": 2.4144490152313572e-08,
3207
+ "loss": 0.7642,
3208
+ "step": 457
3209
+ },
3210
+ {
3211
+ "epoch": 0.98,
3212
+ "grad_norm": 1.2568868398666382,
3213
+ "learning_rate": 1.9558533074882646e-08,
3214
+ "loss": 0.7433,
3215
+ "step": 458
3216
+ },
3217
+ {
3218
+ "epoch": 0.98,
3219
+ "grad_norm": 0.8221530914306641,
3220
+ "learning_rate": 1.545471346164007e-08,
3221
+ "loss": 0.7665,
3222
+ "step": 459
3223
+ },
3224
+ {
3225
+ "epoch": 0.99,
3226
+ "grad_norm": 0.7928904294967651,
3227
+ "learning_rate": 1.1833229560848092e-08,
3228
+ "loss": 0.7617,
3229
+ "step": 460
3230
+ },
3231
+ {
3232
+ "epoch": 0.99,
3233
+ "grad_norm": 0.7835130095481873,
3234
+ "learning_rate": 8.694256319987659e-09,
3235
+ "loss": 0.7412,
3236
+ "step": 461
3237
+ },
3238
+ {
3239
+ "epoch": 0.99,
3240
+ "grad_norm": 0.7215912938117981,
3241
+ "learning_rate": 6.037945377297405e-09,
3242
+ "loss": 0.7756,
3243
+ "step": 462
3244
+ },
3245
+ {
3246
+ "epoch": 0.99,
3247
+ "grad_norm": 0.8312851786613464,
3248
+ "learning_rate": 3.8644250544594975e-09,
3249
+ "loss": 0.7332,
3250
+ "step": 463
3251
+ },
3252
+ {
3253
+ "epoch": 0.99,
3254
+ "grad_norm": 0.6416303515434265,
3255
+ "learning_rate": 2.173800350394606e-09,
3256
+ "loss": 0.797,
3257
+ "step": 464
3258
+ },
3259
+ {
3260
+ "epoch": 1.0,
3261
+ "grad_norm": 1.003443956375122,
3262
+ "learning_rate": 9.661529361892907e-10,
3263
+ "loss": 0.7039,
3264
+ "step": 465
3265
+ },
3266
+ {
3267
+ "epoch": 1.0,
3268
+ "grad_norm": 0.8252460956573486,
3269
+ "learning_rate": 2.415411511536014e-10,
3270
+ "loss": 0.7942,
3271
+ "step": 466
3272
+ },
3273
+ {
3274
+ "epoch": 1.0,
3275
+ "grad_norm": 0.7430176734924316,
3276
+ "learning_rate": 0.0,
3277
+ "loss": 0.7726,
3278
+ "step": 467
3279
+ },
3280
+ {
3281
+ "epoch": 1.0,
3282
+ "step": 467,
3283
+ "total_flos": 0.0,
3284
+ "train_loss": 0.435926725573407,
3285
+ "train_runtime": 10640.2467,
3286
+ "train_samples_per_second": 102.107,
3287
+ "train_steps_per_second": 0.044
3288
+ }
3289
+ ],
3290
+ "logging_steps": 1.0,
3291
+ "max_steps": 467,
3292
+ "num_input_tokens_seen": 0,
3293
+ "num_train_epochs": 1,
3294
+ "save_steps": 100,
3295
+ "stateful_callbacks": {
3296
+ "TrainerControl": {
3297
+ "args": {
3298
+ "should_epoch_stop": false,
3299
+ "should_evaluate": false,
3300
+ "should_log": false,
3301
+ "should_save": false,
3302
+ "should_training_stop": false
3303
+ },
3304
+ "attributes": {}
3305
+ }
3306
+ },
3307
+ "total_flos": 0.0,
3308
+ "train_batch_size": 2,
3309
+ "trial_name": null,
3310
+ "trial_params": null
3311
+ }
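The block above closes out `trainer_state.json` for the SFT run: 467 optimizer steps over one epoch, per-device train batch size 2, loss logged every step. A minimal sketch for inspecting the curve locally, assuming the standard HF Trainer layout (`log_history` list) and that the file sits next to the script:

```python
# Minimal sketch: summarize the per-step loss from trainer_state.json.
# The file path is an assumption; point it at your local copy.
import json

with open("trainer_state.json") as f:
    state = json.load(f)

steps = [e for e in state["log_history"] if "loss" in e]
print(f"{len(steps)} logged steps")
print(f"first loss: {steps[0]['loss']:.4f}, last loss: {steps[-1]['loss']:.4f}")
print(f"runtime: {state['log_history'][-1].get('train_runtime', 'n/a')} s")
```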
utils.py ADDED
@@ -0,0 +1,212 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
17
+ import os
18
+ import os.path as osp
19
+
20
+ from huggingface_hub import repo_exists, snapshot_download
21
+ from huggingface_hub.utils import HFValidationError, validate_repo_id
22
+ from transformers import AutoConfig, AutoTokenizer, PretrainedConfig
23
+
24
+ from .configuration_vila import VILAConfig
25
+ from .constants import MEDIA_TOKENS
26
+ from .tokenizer_utils import infer_stop_tokens
27
+
28
+
29
+ def load_tokenizer_then_handle_media_tokens_and_chat_template(
30
+ model_name_or_path, config: VILAConfig, model_max_length=None
31
+ ):
32
+ # TODO(ligeng): a lot of copy-paste code, refactor to make a single function
33
+ tokenizer = AutoTokenizer.from_pretrained(
34
+ osp.join(model_name_or_path, "llm"), padding_side="right", use_fast=True, legacy=False
35
+ )
36
+ if model_max_length is not None:
37
+ tokenizer.model_max_length = model_max_length
38
+
39
+ # Load chat template if specified.
40
+ if getattr(config, "chat_template", None) is not None:
41
+ print(f"Using chat template: {config.chat_template}")
42
+ fpath = os.path.join(os.path.dirname(__file__), "chat_templates", f"{config.chat_template}.jinja")
43
+ if not os.path.exists(fpath):
44
+ fpath = os.path.join(model_name_or_path, f"{config.chat_template}.jinja")
45
+ with open(fpath) as fd:
46
+ chat_template = fd.read()
47
+ tokenizer.chat_template = chat_template.replace(" ", "").replace("\n", "")
48
+
49
+ # Set stop tokens for the tokenizer
50
+ tokenizer.stop_tokens = infer_stop_tokens(tokenizer)
51
+ tokenizer.stop_token_ids = tokenizer.convert_tokens_to_ids(tokenizer.stop_tokens)
52
+
53
+ # Add media tokens to the tokenizer
54
+ tokenizer.media_tokens = MEDIA_TOKENS
55
+ tokenizer.media_token_ids = {}
56
+ for name, token in MEDIA_TOKENS.items():
57
+ tokenizer.add_tokens([token], special_tokens=True)
58
+ tokenizer.media_token_ids[name] = tokenizer.convert_tokens_to_ids(token)
59
+
60
+ return tokenizer
61
+
62
+
63
+ def get_model_config(config):
64
+ default_keys = ["llm_cfg", "vision_tower_cfg", "mm_projector_cfg"]
65
+
66
+ if hasattr(config, "_name_or_path") and len(config._name_or_path) >= 2:
67
+ root_path = config._name_or_path
68
+ else:
69
+ root_path = config.resume_path
70
+
71
+ # download from huggingface
72
+ if root_path is not None and not osp.exists(root_path):
73
+ try:
74
+ valid_hf_repo = repo_exists(root_path)
75
+ except HFValidationError as e:
76
+ valid_hf_repo = False
77
+ if valid_hf_repo:
78
+ root_path = snapshot_download(root_path)
79
+
80
+ return_list = []
81
+ for key in default_keys:
82
+ cfg = getattr(config, key, None)
83
+ if isinstance(cfg, dict):
84
+ try:
85
+ return_list.append(os.path.join(root_path, key[:-4]))
86
+ except:
87
+ raise ValueError(f"Cannot find resume path in config for {key}!")
88
+ elif isinstance(cfg, PretrainedConfig):
89
+ return_list.append(os.path.join(root_path, key[:-4]))
90
+ elif isinstance(cfg, str):
91
+ return_list.append(cfg)
92
+
93
+ return return_list
94
+
95
+
96
+ def get_model_config_fp8(config):
97
+ default_keys = ["llm_cfg", "vision_tower_cfg", "mm_projector_cfg"]
98
+
99
+ if hasattr(config, "_name_or_path") and len(config._name_or_path) >= 2:
100
+ root_path = config._name_or_path
101
+ else:
102
+ root_path = config.resume_path
103
+
104
+ # download from huggingface
105
+ if root_path is not None and not osp.exists(root_path):
106
+ try:
107
+ valid_hf_repo = repo_exists(root_path)
108
+ except HFValidationError as e:
109
+ valid_hf_repo = False
110
+ if valid_hf_repo:
111
+ root_path = snapshot_download(root_path)
112
+
113
+ return_list = []
114
+ for key in default_keys:
115
+ cfg = getattr(config, key, None)
116
+ if isinstance(cfg, dict):
117
+ try:
118
+ return_list.append(os.path.join(root_path, key[:-4]))
119
+ except:
120
+ raise ValueError(f"Cannot find resume path in config for {key}!")
121
+ elif isinstance(cfg, PretrainedConfig):
122
+ return_list.append(os.path.join(root_path, key[:-4]))
123
+ elif isinstance(cfg, str):
124
+ return_list.append(cfg)
125
+
126
+ # fp8_llm
127
+ key = "fp8_llm_cfg"
128
+ directory_path = os.path.join(root_path, key[:-4])
129
+ assert os.path.isdir(directory_path) and os.listdir(
130
+ directory_path
131
+ ), "You need to first convert the model weights to FP8 explicitly."
132
+ return_list.append(directory_path)
133
+
134
+ return return_list
135
+
136
+
178
+ def is_mm_model(model_path):
179
+ """
180
+ Check if the model at the given path is a visual language model.
181
+
182
+ Args:
183
+ model_path (str): The path to the model.
184
+
185
+ Returns:
186
+ bool: True if the model is an MM model, False otherwise.
187
+ """
188
+ config = AutoConfig.from_pretrained(model_path)
189
+ architectures = config.architectures
190
+ for architecture in architectures:
191
+ if "llava" in architecture.lower():
192
+ return True
193
+ return False
194
+
195
+
196
+ def auto_upgrade(config):
197
+ cfg = AutoConfig.from_pretrained(config)
198
+ if "llava" in config and "llava" not in cfg.model_type:
199
+ assert cfg.model_type == "llama"
200
+ print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
201
+ print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
202
+ confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
203
+ if confirm.lower() in ["y", "yes"]:
204
+ print("Upgrading checkpoint...")
205
+ assert len(cfg.architectures) == 1
206
+ setattr(cfg.__class__, "model_type", "llava")
207
+ cfg.architectures[0] = "LlavaLlamaForCausalLM"
208
+ cfg.save_pretrained(config)
209
+ print("Checkpoint upgraded.")
210
+ else:
211
+ print("Checkpoint upgrade aborted.")
212
+ exit(1)
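`utils.py` above carries the checkpoint-resolution helpers used by the remote modeling code; `get_model_config` maps `llm_cfg`, `vision_tower_cfg`, and `mm_projector_cfg` to the `llm/`, `vision_tower/`, and `mm_projector/` sub-folders of the checkpoint. Because the module uses package-relative imports, the sketch below only mirrors that path resolution rather than importing it; the repo id comes from the model card, and `allow_patterns` is used so only the small JSON files are fetched:

```python
# Sketch mirroring get_model_config(): resolve the sub-module directories
# of the checkpoint. Downloads only *.json files to keep it lightweight.
import os.path as osp
from huggingface_hub import snapshot_download

root_path = snapshot_download(
    "turing-motors/Heron-NVILA-Lite-15B", allow_patterns=["*.json"]
)
for key in ("llm_cfg", "vision_tower_cfg", "mm_projector_cfg"):
    sub_dir = osp.join(root_path, key[:-4])  # "llm", "vision_tower", "mm_projector"
    print(f"{key} -> {sub_dir} (exists: {osp.isdir(sub_dir)})")
```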
vision_tower/config.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "_name_or_path": "runs/train/NVILA-Lite_14b_siglip_aws_env2_obelics_ja/sft_14b_GPT4_v6/model/vision_tower",
3
+ "architectures": [
4
+ "SiglipVisionModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "hidden_act": "gelu_pytorch_tanh",
8
+ "hidden_size": 1152,
9
+ "image_size": 448,
10
+ "intermediate_size": 4304,
11
+ "layer_norm_eps": 1e-06,
12
+ "model_type": "siglip_vision_model",
13
+ "num_attention_heads": 16,
14
+ "num_channels": 3,
15
+ "num_hidden_layers": 27,
16
+ "num_image_tokens": 256,
17
+ "patch_size": 14,
18
+ "projection_dim": 2048,
19
+ "projector_hidden_act": "gelu_fast",
20
+ "torch_dtype": "bfloat16",
21
+ "transformers_version": "4.45.0",
22
+ "vision_use_head": false
23
+ }
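`vision_tower/config.json` describes the SigLIP-So400M encoder: 448×448 input, patch size 14 (a 32×32 patch grid), hidden size 1152, 27 layers, no pooling head. A minimal sketch that loads just this sub-config; the `subfolder` argument assumes the directory layout added in this commit:

```python
# Minimal sketch: load only the vision tower config from the repo.
from transformers import SiglipVisionConfig

cfg = SiglipVisionConfig.from_pretrained(
    "turing-motors/Heron-NVILA-Lite-15B", subfolder="vision_tower"
)
print(cfg.image_size, cfg.patch_size, cfg.hidden_size)  # 448 14 1152
print((cfg.image_size // cfg.patch_size) ** 2)          # 1024 patches per image
```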
vision_tower/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b252dab753e022135ac0110affc9dfa0cab40680abc935dcaa3f09b449ff1323
3
+ size 826707904
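`vision_tower/model.safetensors` is stored as a Git LFS pointer; the ~827 MB weight file is content-addressed by the SHA-256 above. A quick sketch for verifying a downloaded copy against that oid (the local path is an assumption):

```python
# Sketch: verify a local vision_tower/model.safetensors against the LFS oid.
import hashlib

EXPECTED = "b252dab753e022135ac0110affc9dfa0cab40680abc935dcaa3f09b449ff1323"
h = hashlib.sha256()
with open("vision_tower/model.safetensors", "rb") as f:  # path is an assumption
    for chunk in iter(lambda: f.read(1 << 20), b""):
        h.update(chunk)
print("hash matches:", h.hexdigest() == EXPECTED)
```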
vision_tower/preprocessor_config.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "do_convert_rgb": null,
3
+ "do_normalize": true,
4
+ "do_rescale": true,
5
+ "do_resize": true,
6
+ "image_mean": [
7
+ 0.5,
8
+ 0.5,
9
+ 0.5
10
+ ],
11
+ "image_processor_type": "SiglipImageProcessor",
12
+ "image_std": [
13
+ 0.5,
14
+ 0.5,
15
+ 0.5
16
+ ],
17
+ "processor_class": "SiglipProcessor",
18
+ "resample": 3,
19
+ "rescale_factor": 0.00392156862745098,
20
+ "size": {
21
+ "height": 448,
22
+ "width": 448
23
+ }
24
+ }
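`vision_tower/preprocessor_config.json` configures the SigLIP image processor: resize to 448×448 with bicubic resampling (`resample: 3`), rescale by 1/255, then normalize with mean/std 0.5 per channel. A minimal sketch running it standalone on a placeholder image; as above, the `subfolder` argument assumes this repo layout:

```python
# Minimal sketch: apply the SigLIP preprocessing defined above.
from PIL import Image
from transformers import SiglipImageProcessor

processor = SiglipImageProcessor.from_pretrained(
    "turing-motors/Heron-NVILA-Lite-15B", subfolder="vision_tower"
)
image = Image.new("RGB", (640, 480), color=(128, 128, 128))  # placeholder input
pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]
print(pixel_values.shape)  # expected: torch.Size([1, 3, 448, 448])
```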