Upload modeling_rope_utils.py
Browse files- modeling_rope_utils.py +558 -0
modeling_rope_utils.py
ADDED
@@ -0,0 +1,558 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import math
|
16 |
+
from typing import Optional, Tuple
|
17 |
+
|
18 |
+
from transformers.configuration_utils import PretrainedConfig
|
19 |
+
from transformers.utils import is_torch_available, logging
|
20 |
+
|
21 |
+
logger = logging.get_logger(__name__)
|
22 |
+
|
23 |
+
if is_torch_available():
|
24 |
+
import torch
|
25 |
+
|
26 |
+
|
27 |
+
def _compute_default_rope_parameters(
|
28 |
+
config: Optional[PretrainedConfig] = None,
|
29 |
+
device: Optional["torch.device"] = None,
|
30 |
+
seq_len: Optional[int] = None,
|
31 |
+
**rope_kwargs,
|
32 |
+
) -> Tuple["torch.Tensor", float]:
|
33 |
+
"""
|
34 |
+
Computes the inverse frequencies according to the original RoPE implementation
|
35 |
+
Args:
|
36 |
+
config ([`~transformers.PretrainedConfig`]):
|
37 |
+
The model configuration.
|
38 |
+
device (`torch.device`):
|
39 |
+
The device to use for initialization of the inverse frequencies.
|
40 |
+
seq_len (`int`, *optional*):
|
41 |
+
The current sequence length. Unused for this type of RoPE.
|
42 |
+
rope_kwargs (`Dict`, *optional*):
|
43 |
+
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
44 |
+
Returns:
|
45 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
46 |
+
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
|
47 |
+
"""
|
48 |
+
if config is not None and len(rope_kwargs) > 0:
|
49 |
+
raise ValueError(
|
50 |
+
"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
|
51 |
+
f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
|
52 |
+
)
|
53 |
+
if len(rope_kwargs) > 0:
|
54 |
+
base = rope_kwargs["base"]
|
55 |
+
dim = rope_kwargs["dim"]
|
56 |
+
elif config is not None:
|
57 |
+
base = config.rope_theta
|
58 |
+
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
59 |
+
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
60 |
+
dim = int(head_dim * partial_rotary_factor)
|
61 |
+
|
62 |
+
attention_factor = 1.0 # Unused in this type of RoPE
|
63 |
+
|
64 |
+
# Compute the inverse frequencies
|
65 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
|
66 |
+
return inv_freq, attention_factor
|
67 |
+
|
68 |
+
|
69 |
+
def _compute_linear_scaling_rope_parameters(
|
70 |
+
config: Optional[PretrainedConfig] = None,
|
71 |
+
device: Optional["torch.device"] = None,
|
72 |
+
seq_len: Optional[int] = None,
|
73 |
+
**rope_kwargs,
|
74 |
+
) -> Tuple["torch.Tensor", float]:
|
75 |
+
"""
|
76 |
+
Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev
|
77 |
+
Args:
|
78 |
+
config ([`~transformers.PretrainedConfig`]):
|
79 |
+
The model configuration.
|
80 |
+
device (`torch.device`):
|
81 |
+
The device to use for initialization of the inverse frequencies.
|
82 |
+
seq_len (`int`, *optional*):
|
83 |
+
The current sequence length. Unused for this type of RoPE.
|
84 |
+
rope_kwargs (`Dict`, *optional*):
|
85 |
+
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
86 |
+
Returns:
|
87 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
88 |
+
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
|
89 |
+
"""
|
90 |
+
if config is not None and len(rope_kwargs) > 0:
|
91 |
+
raise ValueError(
|
92 |
+
"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
|
93 |
+
f"`_compute_linear_scaling_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
|
94 |
+
)
|
95 |
+
if len(rope_kwargs) > 0:
|
96 |
+
factor = rope_kwargs["factor"]
|
97 |
+
elif config is not None:
|
98 |
+
factor = config.rope_scaling["factor"]
|
99 |
+
|
100 |
+
# Gets the default RoPE parameters
|
101 |
+
inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
|
102 |
+
|
103 |
+
# Then applies linear scaling to the frequencies.
|
104 |
+
# NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so
|
105 |
+
# applying scaling to the inverse frequencies is equivalent.
|
106 |
+
inv_freq /= factor
|
107 |
+
return inv_freq, attention_factor
|
108 |
+
|
109 |
+
|
110 |
+
def _compute_dynamic_ntk_parameters(
|
111 |
+
config: Optional[PretrainedConfig] = None,
|
112 |
+
device: Optional["torch.device"] = None,
|
113 |
+
seq_len: Optional[int] = None,
|
114 |
+
**rope_kwargs,
|
115 |
+
) -> Tuple["torch.Tensor", float]:
|
116 |
+
"""
|
117 |
+
Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
|
118 |
+
Args:
|
119 |
+
config ([`~transformers.PretrainedConfig`]):
|
120 |
+
The model configuration.
|
121 |
+
device (`torch.device`):
|
122 |
+
The device to use for initialization of the inverse frequencies.
|
123 |
+
seq_len (`int`, *optional*):
|
124 |
+
The current sequence length, used to update the dynamic RoPE at inference time.
|
125 |
+
rope_kwargs (`Dict`, *optional*):
|
126 |
+
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
127 |
+
Returns:
|
128 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
129 |
+
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
|
130 |
+
"""
|
131 |
+
# TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
|
132 |
+
if config is not None and len(rope_kwargs) > 0:
|
133 |
+
raise ValueError(
|
134 |
+
"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
|
135 |
+
f"`_compute_dynamic_ntk_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
|
136 |
+
)
|
137 |
+
if len(rope_kwargs) > 0:
|
138 |
+
base = rope_kwargs["base"]
|
139 |
+
dim = rope_kwargs["dim"]
|
140 |
+
max_position_embeddings = rope_kwargs["max_position_embeddings"]
|
141 |
+
factor = rope_kwargs["factor"]
|
142 |
+
elif config is not None:
|
143 |
+
base = config.rope_theta
|
144 |
+
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
145 |
+
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
146 |
+
dim = int(head_dim * partial_rotary_factor)
|
147 |
+
max_position_embeddings = config.max_position_embeddings
|
148 |
+
factor = config.rope_scaling["factor"]
|
149 |
+
|
150 |
+
attention_factor = 1.0 # Unused in this type of RoPE
|
151 |
+
|
152 |
+
# seq_len: default to max_position_embeddings, e.g. at init time
|
153 |
+
seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings
|
154 |
+
|
155 |
+
# Compute the inverse frequencies
|
156 |
+
base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
|
157 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
|
158 |
+
return inv_freq, attention_factor
|
159 |
+
|
160 |
+
|
161 |
+
def _compute_yarn_parameters(
|
162 |
+
config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
|
163 |
+
) -> Tuple["torch.Tensor", float]:
|
164 |
+
"""
|
165 |
+
Computes the inverse frequencies with NTK scaling. Please refer to the
|
166 |
+
[original paper](https://arxiv.org/abs/2309.00071)
|
167 |
+
Args:
|
168 |
+
config ([`~transformers.PretrainedConfig`]):
|
169 |
+
The model configuration.
|
170 |
+
device (`torch.device`):
|
171 |
+
The device to use for initialization of the inverse frequencies.
|
172 |
+
seq_len (`int`, *optional*):
|
173 |
+
The current sequence length. Unused for this type of RoPE.
|
174 |
+
rope_kwargs (`Dict`, *optional*):
|
175 |
+
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
176 |
+
Returns:
|
177 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
178 |
+
post-processing scaling factor applied to the computed cos/sin.
|
179 |
+
"""
|
180 |
+
# No need to keep BC with yarn, unreleased when this new pattern was created.
|
181 |
+
if len(rope_kwargs) > 0:
|
182 |
+
raise ValueError(
|
183 |
+
f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}"
|
184 |
+
)
|
185 |
+
|
186 |
+
base = config.rope_theta
|
187 |
+
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
188 |
+
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
189 |
+
dim = int(head_dim * partial_rotary_factor)
|
190 |
+
max_position_embeddings = config.max_position_embeddings
|
191 |
+
factor = config.rope_scaling["factor"]
|
192 |
+
|
193 |
+
# Sets the attention factor as suggested in the paper
|
194 |
+
attention_factor = config.rope_scaling.get("attention_factor")
|
195 |
+
if attention_factor is None:
|
196 |
+
attention_factor = 0.1 * math.log(factor) + 1.0
|
197 |
+
|
198 |
+
# Optional config options
|
199 |
+
# beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
|
200 |
+
beta_fast = config.rope_scaling.get("beta_fast") or 32
|
201 |
+
beta_slow = config.rope_scaling.get("beta_slow") or 1
|
202 |
+
|
203 |
+
# Compute the inverse frequencies
|
204 |
+
def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
|
205 |
+
"""Inverse dimension formula to find the dimension based on the number of rotations"""
|
206 |
+
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
|
207 |
+
|
208 |
+
def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
|
209 |
+
"""Find dimension range bounds based on rotations"""
|
210 |
+
low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
|
211 |
+
high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
|
212 |
+
return max(low, 0), min(high, dim - 1)
|
213 |
+
|
214 |
+
def linear_ramp_factor(min, max, dim):
|
215 |
+
if min == max:
|
216 |
+
max += 0.001 # Prevent singularity
|
217 |
+
|
218 |
+
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
|
219 |
+
ramp_func = torch.clamp(linear_func, 0, 1)
|
220 |
+
return ramp_func
|
221 |
+
|
222 |
+
# Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
|
223 |
+
# to expand the possible context length. In other words, interpolation = apply scaling factor.
|
224 |
+
pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
|
225 |
+
inv_freq_extrapolation = 1.0 / pos_freqs
|
226 |
+
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
|
227 |
+
|
228 |
+
low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
|
229 |
+
|
230 |
+
# Get n-dimensional rotational scaling corrected for extrapolation
|
231 |
+
inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
|
232 |
+
inv_freq = (
|
233 |
+
inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
|
234 |
+
+ inv_freq_extrapolation * inv_freq_extrapolation_factor
|
235 |
+
)
|
236 |
+
|
237 |
+
return inv_freq, attention_factor
|
238 |
+
|
239 |
+
|
240 |
+
def _compute_longrope_parameters(
|
241 |
+
config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
|
242 |
+
) -> Tuple["torch.Tensor", float]:
|
243 |
+
"""
|
244 |
+
Computes the inverse frequencies with LongRoPE scaling. Please refer to the
|
245 |
+
[original implementation](https://github.com/microsoft/LongRoPE)
|
246 |
+
Args:
|
247 |
+
config ([`~transformers.PretrainedConfig`]):
|
248 |
+
The model configuration.
|
249 |
+
device (`torch.device`):
|
250 |
+
The device to use for initialization of the inverse frequencies.
|
251 |
+
seq_len (`int`, *optional*):
|
252 |
+
The current sequence length. Unused for this type of RoPE.
|
253 |
+
rope_kwargs (`Dict`, *optional*):
|
254 |
+
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
255 |
+
Returns:
|
256 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
257 |
+
post-processing scaling factor applied to the computed cos/sin.
|
258 |
+
"""
|
259 |
+
# TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
|
260 |
+
# No need to keep BC with longrope, unreleased when this new pattern was created.
|
261 |
+
if len(rope_kwargs) > 0:
|
262 |
+
raise ValueError(
|
263 |
+
"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_longrope_parameters`, got "
|
264 |
+
f"{rope_kwargs}"
|
265 |
+
)
|
266 |
+
|
267 |
+
base = config.rope_theta
|
268 |
+
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
269 |
+
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
270 |
+
dim = int(head_dim * partial_rotary_factor)
|
271 |
+
long_factor = config.rope_scaling["long_factor"]
|
272 |
+
short_factor = config.rope_scaling["short_factor"]
|
273 |
+
factor = config.rope_scaling.get("factor")
|
274 |
+
attention_factor = config.rope_scaling.get("attention_factor")
|
275 |
+
|
276 |
+
# NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
|
277 |
+
# `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
|
278 |
+
# values to compute the default attention scaling factor, instead of using `factor`.
|
279 |
+
if hasattr(config, "original_max_position_embeddings"):
|
280 |
+
max_position_embeddings = config.original_max_position_embeddings
|
281 |
+
expanded_max_position_embeddings = config.max_position_embeddings
|
282 |
+
factor = expanded_max_position_embeddings / max_position_embeddings
|
283 |
+
else:
|
284 |
+
max_position_embeddings = config.max_position_embeddings
|
285 |
+
expanded_max_position_embeddings = max_position_embeddings * factor
|
286 |
+
|
287 |
+
# Sets the attention factor as suggested in the paper
|
288 |
+
if attention_factor is None:
|
289 |
+
if factor <= 1.0:
|
290 |
+
attention_factor = 1.0
|
291 |
+
else:
|
292 |
+
attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))
|
293 |
+
|
294 |
+
# Compute the inverse frequencies -- scaled based on the target sequence length
|
295 |
+
if expanded_max_position_embeddings > max_position_embeddings:
|
296 |
+
ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
|
297 |
+
else:
|
298 |
+
ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
|
299 |
+
inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim
|
300 |
+
inv_freq = 1.0 / (ext_factors * base ** inv_freq_shape)
|
301 |
+
|
302 |
+
return inv_freq, attention_factor
|
303 |
+
|
304 |
+
|
305 |
+
def _compute_llama3_parameters(
|
306 |
+
config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
|
307 |
+
) -> Tuple["torch.Tensor", float]:
|
308 |
+
"""
|
309 |
+
Computes the inverse frequencies for llama 3.1.
|
310 |
+
|
311 |
+
Args:
|
312 |
+
config ([`~transformers.PretrainedConfig`]):
|
313 |
+
The model configuration.
|
314 |
+
device (`torch.device`):
|
315 |
+
The device to use for initialization of the inverse frequencies.
|
316 |
+
seq_len (`int`, *optional*):
|
317 |
+
The current sequence length. Unused for this type of RoPE.
|
318 |
+
rope_kwargs (`Dict`, *optional*):
|
319 |
+
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
320 |
+
Returns:
|
321 |
+
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
322 |
+
post-processing scaling factor applied to the computed cos/sin.
|
323 |
+
"""
|
324 |
+
# Gets the default RoPE parameters
|
325 |
+
inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
|
326 |
+
|
327 |
+
factor = config.rope_scaling["factor"] # `8` in the original implementation
|
328 |
+
low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation
|
329 |
+
high_freq_factor = config.rope_scaling["high_freq_factor"] # `4` in the original implementation
|
330 |
+
old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation
|
331 |
+
|
332 |
+
low_freq_wavelen = old_context_len / low_freq_factor
|
333 |
+
high_freq_wavelen = old_context_len / high_freq_factor
|
334 |
+
|
335 |
+
wavelen = 2 * math.pi / inv_freq
|
336 |
+
# wavelen < high_freq_wavelen: do nothing
|
337 |
+
# wavelen > low_freq_wavelen: divide by factor
|
338 |
+
inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
|
339 |
+
# otherwise: interpolate between the two, using a smooth factor
|
340 |
+
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
341 |
+
smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
|
342 |
+
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
|
343 |
+
inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
|
344 |
+
|
345 |
+
return inv_freq_llama, attention_factor
|
346 |
+
|
347 |
+
|
348 |
+
# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
|
349 |
+
# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
|
350 |
+
# parameterizations, as long as the callable has the same signature.
|
351 |
+
ROPE_INIT_FUNCTIONS = {
|
352 |
+
"default": _compute_default_rope_parameters,
|
353 |
+
"linear": _compute_linear_scaling_rope_parameters,
|
354 |
+
"dynamic": _compute_dynamic_ntk_parameters,
|
355 |
+
"yarn": _compute_yarn_parameters,
|
356 |
+
"longrope": _compute_longrope_parameters,
|
357 |
+
"llama3": _compute_llama3_parameters,
|
358 |
+
}
|
359 |
+
|
360 |
+
|
361 |
+
def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None):
|
362 |
+
"""Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
|
363 |
+
# BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
|
364 |
+
if "type" in received_keys:
|
365 |
+
received_keys -= {"type"}
|
366 |
+
required_keys.add("rope_type")
|
367 |
+
|
368 |
+
missing_keys = required_keys - received_keys
|
369 |
+
if missing_keys:
|
370 |
+
raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}")
|
371 |
+
|
372 |
+
if optional_keys is not None:
|
373 |
+
unused_keys = received_keys - required_keys - optional_keys
|
374 |
+
else:
|
375 |
+
unused_keys = received_keys - required_keys
|
376 |
+
if unused_keys:
|
377 |
+
logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
|
378 |
+
|
379 |
+
|
380 |
+
def _validate_default_rope_parameters(config: PretrainedConfig):
|
381 |
+
rope_scaling = config.rope_scaling
|
382 |
+
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
383 |
+
required_keys = {"rope_type"}
|
384 |
+
received_keys = set(rope_scaling.keys())
|
385 |
+
_check_received_keys(rope_type, received_keys, required_keys)
|
386 |
+
|
387 |
+
|
388 |
+
def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
|
389 |
+
rope_scaling = config.rope_scaling
|
390 |
+
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
391 |
+
required_keys = {"rope_type", "factor"}
|
392 |
+
received_keys = set(rope_scaling.keys())
|
393 |
+
_check_received_keys(rope_type, received_keys, required_keys)
|
394 |
+
|
395 |
+
factor = rope_scaling["factor"]
|
396 |
+
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
397 |
+
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
398 |
+
|
399 |
+
|
400 |
+
def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
|
401 |
+
rope_scaling = config.rope_scaling
|
402 |
+
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
403 |
+
required_keys = {"rope_type", "factor"}
|
404 |
+
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
|
405 |
+
optional_keys = {"original_max_position_embeddings"}
|
406 |
+
received_keys = set(rope_scaling.keys())
|
407 |
+
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
|
408 |
+
|
409 |
+
factor = rope_scaling["factor"]
|
410 |
+
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
411 |
+
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
412 |
+
|
413 |
+
|
414 |
+
def _validate_yarn_parameters(config: PretrainedConfig):
|
415 |
+
rope_scaling = config.rope_scaling
|
416 |
+
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
417 |
+
required_keys = {"rope_type", "factor"}
|
418 |
+
optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
|
419 |
+
received_keys = set(rope_scaling.keys())
|
420 |
+
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
|
421 |
+
|
422 |
+
factor = rope_scaling["factor"]
|
423 |
+
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
424 |
+
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
425 |
+
|
426 |
+
attention_factor = rope_scaling.get("attention_factor")
|
427 |
+
if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
|
428 |
+
logger.warning(
|
429 |
+
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
|
430 |
+
)
|
431 |
+
beta_fast = rope_scaling.get("beta_fast")
|
432 |
+
if beta_fast is not None and not isinstance(beta_fast, float):
|
433 |
+
logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
|
434 |
+
beta_slow = rope_scaling.get("beta_slow")
|
435 |
+
if beta_slow is not None and not isinstance(beta_slow, float):
|
436 |
+
logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
|
437 |
+
|
438 |
+
if (beta_fast or 32) < (beta_slow or 1):
|
439 |
+
logger.warning(
|
440 |
+
f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} "
|
441 |
+
f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
|
442 |
+
)
|
443 |
+
|
444 |
+
|
445 |
+
def _validate_longrope_parameters(config: PretrainedConfig):
|
446 |
+
rope_scaling = config.rope_scaling
|
447 |
+
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
448 |
+
required_keys = {"rope_type", "short_factor", "long_factor"}
|
449 |
+
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
|
450 |
+
optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
|
451 |
+
received_keys = set(rope_scaling.keys())
|
452 |
+
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
|
453 |
+
|
454 |
+
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
455 |
+
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
456 |
+
dim = int(head_dim * partial_rotary_factor)
|
457 |
+
|
458 |
+
short_factor = rope_scaling.get("short_factor")
|
459 |
+
if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor):
|
460 |
+
logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
|
461 |
+
if not len(short_factor) == dim // 2:
|
462 |
+
logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")
|
463 |
+
|
464 |
+
long_factor = rope_scaling.get("long_factor")
|
465 |
+
if not isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor):
|
466 |
+
logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
|
467 |
+
if not len(long_factor) == dim // 2:
|
468 |
+
logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")
|
469 |
+
|
470 |
+
# Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
|
471 |
+
# `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is
|
472 |
+
# unique to longrope (= undesirable)
|
473 |
+
if hasattr(config, "original_max_position_embeddings"):
|
474 |
+
logger.warning_once(
|
475 |
+
"This model has set a `original_max_position_embeddings` field, to be used together with "
|
476 |
+
"`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_scaling`"
|
477 |
+
"with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, "
|
478 |
+
"as it is compatible with most model architectures."
|
479 |
+
)
|
480 |
+
else:
|
481 |
+
factor = rope_scaling.get("factor")
|
482 |
+
if factor is None:
|
483 |
+
logger.warning("Missing required keys in `rope_scaling`: 'factor'")
|
484 |
+
elif not isinstance(factor, float) or factor < 1.0:
|
485 |
+
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
486 |
+
|
487 |
+
attention_factor = rope_scaling.get("attention_factor")
|
488 |
+
if attention_factor is not None:
|
489 |
+
if not isinstance(attention_factor, float) or attention_factor < 0.0:
|
490 |
+
logger.warning(
|
491 |
+
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
|
492 |
+
)
|
493 |
+
|
494 |
+
|
495 |
+
def _validate_llama3_parameters(config: PretrainedConfig):
|
496 |
+
rope_scaling = config.rope_scaling
|
497 |
+
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
498 |
+
required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
|
499 |
+
received_keys = set(rope_scaling.keys())
|
500 |
+
_check_received_keys(rope_type, received_keys, required_keys)
|
501 |
+
|
502 |
+
factor = rope_scaling["factor"]
|
503 |
+
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
504 |
+
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
505 |
+
|
506 |
+
low_freq_factor = rope_scaling["low_freq_factor"]
|
507 |
+
high_freq_factor = rope_scaling["high_freq_factor"]
|
508 |
+
if low_freq_factor is None or not isinstance(low_freq_factor, float):
|
509 |
+
logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
|
510 |
+
if high_freq_factor is None or not isinstance(high_freq_factor, float):
|
511 |
+
logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
|
512 |
+
if high_freq_factor <= low_freq_factor:
|
513 |
+
logger.warning(
|
514 |
+
"`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
|
515 |
+
f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
|
516 |
+
)
|
517 |
+
|
518 |
+
original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
|
519 |
+
if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):
|
520 |
+
logger.warning(
|
521 |
+
"`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
|
522 |
+
f"{original_max_position_embeddings}"
|
523 |
+
)
|
524 |
+
if original_max_position_embeddings >= config.max_position_embeddings:
|
525 |
+
logger.warning(
|
526 |
+
"`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
|
527 |
+
f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
|
528 |
+
)
|
529 |
+
|
530 |
+
|
531 |
+
# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types.
|
532 |
+
ROPE_VALIDATION_FUNCTIONS = {
|
533 |
+
"default": _validate_default_rope_parameters,
|
534 |
+
"linear": _validate_linear_scaling_rope_parameters,
|
535 |
+
"dynamic": _validate_dynamic_scaling_rope_parameters,
|
536 |
+
"yarn": _validate_yarn_parameters,
|
537 |
+
"longrope": _validate_longrope_parameters,
|
538 |
+
"llama3": _validate_llama3_parameters,
|
539 |
+
}
|
540 |
+
|
541 |
+
|
542 |
+
def rope_config_validation(config: PretrainedConfig):
|
543 |
+
"""
|
544 |
+
Validate the RoPE config arguments, given a `PretrainedConfig` object
|
545 |
+
"""
|
546 |
+
rope_scaling = getattr(config, "rope_scaling", None) # not a default parameter in `PretrainedConfig`
|
547 |
+
if rope_scaling is None:
|
548 |
+
return
|
549 |
+
|
550 |
+
# BC: "rope_type" was originally "type"
|
551 |
+
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
|
552 |
+
validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
|
553 |
+
if validation_fn is not None:
|
554 |
+
validation_fn(config)
|
555 |
+
else:
|
556 |
+
logger.warning(
|
557 |
+
f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
|
558 |
+
)
|