import math

import torch
import torch.nn as nn
from torch.nn import functional as F
|
|
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, device='cpu'):
    """Precompute the RoPE rotation table: a (end, dim // 2) complex tensor of
    unit complex numbers, one per (position, frequency-band) pair."""
    # Per-pair frequencies: theta^(-(2i)/dim) for i = 0 .. dim//2 - 1.
    freqs_part = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    freqs = 1.0 / (theta ** freqs_part).to(device)
    # Rotation angle for each (position, frequency) pair.
    t = torch.arange(end, device=device)
    freqs = torch.outer(t, freqs).float()
    # Convert angles to unit complex numbers: cos(angle) + i*sin(angle).
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis
|
|
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Reshape freqs_cis so it broadcasts against x of shape (batch, seq_len, dim // 2)."""
    ndim = x.ndim
    assert 0 <= 1 < ndim
    if freqs_cis.shape != (x.shape[1], x.shape[-1]):
        raise ValueError(
            f"Freqs shape {freqs_cis.shape} does not match x shape {x.shape} at dims 1 and -1"
        )
    # Insert singleton leading dims so the (seq_len, dim // 2) table broadcasts over the batch.
    shape = [1] * (ndim - 2) + list(freqs_cis.shape)
    return freqs_cis.view(*shape)
|
|
def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    """Apply rotary embeddings to query/key tensors of shape (batch, seq_len, dim)."""
    # View the last dimension as dim // 2 complex pairs: (..., dim) -> (..., dim // 2) complex.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # Complex multiplication rotates each pair by its position-dependent angle.
    xq_out_complex = xq_ * freqs_cis
    xk_out_complex = xk_ * freqs_cis
    # Back to real pairs, then flatten the trailing (dim // 2, 2) dims into dim.
    xq_out = torch.view_as_real(xq_out_complex).flatten(start_dim=2)
    xk_out = torch.view_as_real(xk_out_complex).flatten(start_dim=2)
    return xq_out.type_as(xq), xk_out.type_as(xk)
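

# A minimal usage sketch (not part of the original module): it shows how the helpers
# above could be wired together, assuming queries/keys of shape (batch, seq_len, dim)
# as expected by reshape_for_broadcast. The tensor sizes are arbitrary and purely
# illustrative.
if __name__ == "__main__":
    batch, seq_len, dim = 2, 10, 16
    freqs_cis = precompute_freqs_cis(dim, seq_len)  # (seq_len, dim // 2)
    xq = torch.randn(batch, seq_len, dim)
    xk = torch.randn(batch, seq_len, dim)
    xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis)
    # Shapes are unchanged, and rotation is norm-preserving up to float error.
    assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape
    print(torch.allclose(xq.norm(dim=-1), xq_rot.norm(dim=-1), atol=1e-5))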
|