zhiyuan8 committed · Commit a6e35c0 · verified · 1 Parent(s): 3359596

Upload wkv.py

Files changed (1): wkv.py (+163, -122)
wkv.py CHANGED
@@ -1,16 +1,16 @@
+import os
 import torch
 from einops import rearrange

-from .hybrid_cache import TimeMixState, BlockState
 import math
 import torch.nn as nn
 from torch.nn import functional as F
 from .configuration_rwkv_hybrid import RwkvHybridConfig
-from typing import TYPE_CHECKING, Optional
-from transformers.cache_utils import Cache
+from typing import Optional
+from .hybrid_cache import HybridCache, AttnState, BlockState

 try:
-    import triton
+    import triton  # pylint: disable=F401
     from rwkvfla.ops.rwkv7 import (
         fused_recurrent_rwkv7,
         chunk_rwkv7,
@@ -33,17 +33,26 @@ except ImportError:
     fused_recurrent_rwkv6 = native_recurrent_rwkv6
     fused_addcmul_rwkv7 = torch_addcmul_rwkv7

+from rwkvfla.utils import check_pytorch_version
+
+if check_pytorch_version("2.6"):
+    compile_decorator = torch.compile
+    torch._dynamo.config.cache_size_limit = 512
+else:
+    def compile_decorator(func):
+        return func
+
+wkv_mode = os.environ.get("WKV_MODE", "fused")
+wkv_mode = wkv_mode.lower()
+assert wkv_mode in ['fused', 'chunk', 'pytorch']

 class Rwkv_Tmix_x070(nn.Module):
-    def __init__(self, args: RwkvHybridConfig, layer_id, update_v_first, get_v_first):
+    def __init__(self, args: RwkvHybridConfig, layer_id, **kwargs):
         super().__init__()
         self.args = args
         self.layer_id = layer_id
         self.hidden_size = args.hidden_size

-        self.update_v_first = update_v_first
-        self.get_v_first = get_v_first
-
         self.head_size = args.head_size
         self.n_head = args.num_wkv_heads
         assert args.hidden_size % self.n_head == 0
@@ -55,7 +64,7 @@ class Rwkv_Tmix_x070(nn.Module):
         self.x_k = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
         self.x_v = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
         self.x_a = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
-
+
         D_DECAY_LORA = 64
         D_AAA_LORA = 64
         D_MV_LORA = 32
@@ -122,7 +131,6 @@ class Rwkv_Tmix_x070(nn.Module):
             )
             nn.init.constant_(
                 self.x_a, 1.0 - torch.pow(ddd, 0.9 * ratio_1_to_almost0))
-

         def ortho_init(x, scale):
             shape = x.shape
@@ -181,7 +189,7 @@ class Rwkv_Tmix_x070(nn.Module):
                     D_GATE_LORA, self.args.hidden_size), 0.1)
             )
             nn.init.constant_(
-                self.x_g, 1.0 - torch.pow(ddd, 0.2 * ratio_1_to_almost0))
+                self.x_g, 1.0 - torch.pow(ddd, 0.2 * ratio_1_to_almost0))

             nn.init.constant_(self.k_k, 0.85)
             nn.init.constant_(self.k_a, 1.0)
@@ -196,27 +204,27 @@ class Rwkv_Tmix_x070(nn.Module):
         nn.init.ones_(self.ln_x.weight)
         nn.init.zeros_(self.ln_x.bias)

-    def apply_wkv7_state(self, r, k, v, w, a, b, s,
-                         output_final_state,
-                         cu_seqlens,
-                         head_first
-                         ):
-
-        if r.device.type == "cpu":
-            r, w, k, v, a, b = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.n_head), (r, w, k, v, a, b))
+    def apply_wkv7_state(
+        self, r, k, v, w, a, b, s,
+        output_final_state,
+        cu_seqlens
+    ):
+        if wkv_mode == 'pytorch':
+            r, w, k, v, a, b = map(lambda x: rearrange(
+                x, 'b l (h d) -> b h l d', h=self.n_head), (r, w, k, v, a, b))
             o, state = native_recurrent_rwkv7(
                 r=r, k=k, v=v, w=w,
                 a=a, b=b,
                 scale=1.0,
-                initial_state=s.transpose(-1, -2),
+                initial_state=s,
                 output_final_state=True,
                 head_first=True,
             )
-            state = state.transpose(-1, -2)
             x = rearrange(o, "b h l d -> b l (h d)")
         else:
-            r, w, k, v, a, b = map(lambda x: rearrange(x, 'b l (h d) -> b l h d', h=self.n_head), (r, w, k, v, a, b))
-            wkv7_func = chunk_rwkv7 if self.training else fused_recurrent_rwkv7
+            r, w, k, v, a, b = map(lambda x: rearrange(
+                x, 'b l (h d) -> b l h d', h=self.n_head), (r, w, k, v, a, b))
+            wkv7_func = chunk_rwkv7 if wkv_mode == 'chunk' else fused_recurrent_rwkv7
             o, state = wkv7_func(
                 r=r, k=k, v=v, w=w,
                 a=a, b=b,
@@ -224,32 +232,27 @@ class Rwkv_Tmix_x070(nn.Module):
                 initial_state=s,
                 output_final_state=output_final_state,
                 cu_seqlens=cu_seqlens,
-                head_first=head_first,
+                head_first=False,
             )
             x = rearrange(o, "b l h d -> b l (h d)")
         return x, state

+    @compile_decorator
     def forward(
         self,
         hidden_states,
-        last_state: TimeMixState,
-        sequence_mask: Optional[torch.Tensor] = None,
+        last_state: AttnState,
         use_cache: Optional[bool] = False,
         cu_seqlens: Optional[torch.Tensor] = None,
+        v_first: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         **kwargs
     ):
-        if sequence_mask is not None:
-            hidden_states = hidden_states.mul(
-                sequence_mask[:, -hidden_states.shape[-2]:, None])
-
         shift_state = last_state.shift_state
         B, T, C = hidden_states.size()

-        if shift_state is not None:
-            xx = torch.concat((shift_state.unsqueeze(
-                1), hidden_states[:, :-1]), dim=1) - hidden_states
-        else:
-            xx = self.time_shift(hidden_states) - hidden_states
+        xx = torch.concat((shift_state.unsqueeze(
+            1), hidden_states[:, :-1]), dim=1) - hidden_states

         lx = hidden_states[:, -1]

@@ -257,7 +260,8 @@ class Rwkv_Tmix_x070(nn.Module):
             xr, xw, xk, xv, xa, xg = fused_addcmul_rwkv7(
                 hidden_states, xx, self.x_r, self.x_w, self.x_k, self.x_v, self.x_a, self.x_g)
         else:
-            xr, xw, xk, xv, xa, _ = fused_addcmul_rwkv7(hidden_states, xx, self.x_r, self.x_w, self.x_k, self.x_v, self.x_a)
+            xr, xw, xk, xv, xa, _ = fused_addcmul_rwkv7(
+                hidden_states, xx, self.x_r, self.x_w, self.x_k, self.x_v, self.x_a)

         r = self.receptance(xr)
         w = (
@@ -266,21 +270,23 @@ class Rwkv_Tmix_x070(nn.Module):
         k = self.key(xk)
         v = self.value(xv)
         if self.layer_id == 0:
-            self.update_v_first(v)
+            v_first = v
         else:
-            # Original implementation
-            v = v + (self.get_v_first().to(v.device) - v) * torch.sigmoid(
+            v = torch.lerp(v, v_first, torch.sigmoid(
                 self.v0 + (xv @ self.v1) @ self.v2
-            )  # add value residual
+            ))  # add value residual

+        if attention_mask is not None:
+            v = v.mul(attention_mask[:, -v.shape[-2]:, None])
         a = torch.sigmoid(
             self.a0 + (xa @ self.a1) @ self.a2
         )  # a is "in-context learning rate"
         if self.args.wkv_has_gate:
-            g = torch.sigmoid(xg @ self.g1) @ self.g2
+            g = torch.sigmoid(xg @ self.g1) @ self.g2 + 1.0
         kk = k * self.k_k
-        kk = F.normalize(kk.view(B, T, self.n_head, -1), dim=-1, p=2.0).view(B, T, C)
-        k = k * (1 + (a - 1) * self.k_a)
+        kk = F.normalize(kk.view(B, T, self.n_head, -1),
+                         p=2.0, dim=-1, eps=1e-4 if kk.dtype == torch.float16 else 1e-12).view(B, T, C)
+        k = torch.lerp(k, k * a, self.k_a)

         wkv_state = last_state.wkv_state
         hidden_states, wkv_state = self.apply_wkv7_state(
@@ -292,66 +298,68 @@ class Rwkv_Tmix_x070(nn.Module):
             (kk * a),
             s=wkv_state,
             output_final_state=use_cache,
-            cu_seqlens=cu_seqlens,
-            head_first=False
+            cu_seqlens=cu_seqlens
         )
         if self.args.wkv_has_group_norm:
             hidden_states = self.ln_x(
                 hidden_states.view(B * T, C)).view(B, T, C)
-        hidden_states = hidden_states + (
-            (r.view(B, T, self.n_head, -1) * k.view(B, T, self.n_head, -1) * self.r_k).sum(
-                dim=-1, keepdim=True
-            )
-            * v.view(B, T, self.n_head, -1)
-        ).view(B, T, C)
+
+        # original code:
+        # weighted_sum_rk = (r.view(B, T, self.n_head, -1) * k.view(B, T, self.n_head, -1) * self.r_k).sum(
+        #     dim=-1, keepdim=True
+        # )
+        weighted_sum_rk = torch.einsum('btij,btij,ij->btij', r.view(B, T, self.n_head, -1),
+                                       k.view(B, T, self.n_head, -1), self.r_k).sum(dim=-1, keepdim=True)
+        hidden_states = hidden_states + \
+            (weighted_sum_rk * v.view(B, T, self.n_head, -1)).view(B, T, C)
         hidden_states = self.output(
             hidden_states * g) if self.args.wkv_has_gate else self.output(hidden_states)
-        return hidden_states, TimeMixState(lx, wkv_state)
+        return hidden_states, AttnState(lx, wkv_state), v_first


 class Rwkv7Attention(nn.Module):
-    def __init__(self, args: RwkvHybridConfig, layer_id, update_v_first, get_v_first):
+    def __init__(self, args: RwkvHybridConfig, layer_id):
         super().__init__()
         self.args = args
         self.layer_idx = layer_id
-        self.time_mixer = Rwkv_Tmix_x070(
-            args, layer_id, update_v_first, get_v_first)
+        self.time_mixer = Rwkv_Tmix_x070(args, layer_id)

     def forward(
         self,
         hidden_states: torch.Tensor,
-        sequence_mask: Optional[torch.Tensor] = None,
-        past_key_value: Optional[Cache] = None,
-        use_cache: Optional[bool] = False,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        past_key_value: Optional[HybridCache] = None,
         output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[torch.Tensor] = None,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        v_first: Optional[torch.Tensor] = None,
         **kwargs
     ):
-        if sequence_mask is not None:
-            assert len(sequence_mask.shape) == 2, (
-                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
-                "for padding purposes (0 indicating padding). "
-                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
-            )
+
         batch_size, token_length, _ = hidden_states.shape

-        if past_key_value is not None and len(past_key_value) > self.layer_idx:
+        if use_cache and len(past_key_value) > self.layer_idx:
             last_state = past_key_value[self.layer_idx][0]
         else:
             last_state = self.init_state(
                 batch_size, hidden_states.device, hidden_states.dtype
             )

-        attn_output, states = self.time_mixer(hidden_states=hidden_states,
-                                              last_state=last_state.time_mix_state,
-                                              sequence_mask=sequence_mask,
-                                              use_cache=use_cache,
-                                              **kwargs)
-        last_state.time_mix_state = states
+        attn_output, states, v_first = self.time_mixer(hidden_states=hidden_states,
+                                                       last_state=last_state.attn_state,
+                                                       use_cache=use_cache,
+                                                       cu_seqlens=cu_seqlens,
+                                                       v_first=v_first,
+                                                       **kwargs)

-        if past_key_value is not None:
+        if use_cache:
+            last_state.attn_state = states
             past_key_value.update(token_length, last_state, self.layer_idx)

-        return attn_output, None
+        return attn_output, None, v_first

     def init_state(self, batch_size, device, dtype) -> BlockState:
         wkv_states = torch.zeros(
@@ -364,10 +372,10 @@ class Rwkv7Attention(nn.Module):
             device=device,
             dtype=torch.float32,
         )
-        token_shift = torch.zeros(
+        shift_states = torch.zeros(
             (batch_size, self.args.hidden_size), device=device, dtype=dtype
         )
-        return BlockState(TimeMixState(token_shift, wkv_states), None)
+        return BlockState(AttnState(shift_states, wkv_states), None)


 class Rwkv_Tmix_x060(nn.Module):
@@ -380,8 +388,6 @@ class Rwkv_Tmix_x060(nn.Module):
         self.head_size = args.head_size
         self.n_head = args.num_wkv_heads
         assert args.hidden_size % self.n_head == 0
-        H = self.n_head
-        N = self.head_size

         with torch.no_grad():
             ratio_0_to_1 = layer_id / (args.n_layer - 1)  # 0 to 1
@@ -445,7 +451,6 @@ class Rwkv_Tmix_x060(nn.Module):

             self.time_faaaa = nn.Parameter(
                 tmp.reshape(self.n_head, self.head_size))
-            # self.time_state = nn.Parameter(torch.zeros(self.n_head, self.head_size, self.head_size))

         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
         self.receptance = nn.Linear(
@@ -465,27 +470,36 @@ class Rwkv_Tmix_x060(nn.Module):
     def post_init(self):
         pass

-    def forward(self, x, last_state: TimeMixState):
+    @compile_decorator
+    def forward(
+        self,
+        hidden_states,
+        last_state: AttnState,
+        use_cache: Optional[bool] = False,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        v_first: Optional[torch.Tensor] = None,
+        **kwargs
+    ):
         shift_state = last_state.shift_state
-        B, T, C = x.size()
+        B, T, C = hidden_states.size()
         H = self.n_head
-        if shift_state is not None:
-            xx = torch.concat((shift_state.unsqueeze(1), x[:, :-1]), dim=1) - x
-        else:
-            xx = self.time_shift(x) - x
-        lx = x[:, -1]

-        xxx = x + xx * self.time_maa_x
+        xx = torch.concat((shift_state.unsqueeze(
+            1), hidden_states[:, :-1]), dim=1) - hidden_states
+
+        lx = hidden_states[:, -1]
+
+        xxx = hidden_states + xx * self.time_maa_x
         xxx = torch.tanh(xxx @ self.time_maa_w1).view(B *
                                                       T, 5, -1).transpose(0, 1)
         xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1)
         mw, mk, mv, mr, mg = xxx.unbind(dim=0)

-        xw = x + xx * (self.time_maa_w + mw)
-        xk = x + xx * (self.time_maa_k + mk)
-        xv = x + xx * (self.time_maa_v + mv)
-        xr = x + xx * (self.time_maa_r + mr)
-        xg = x + xx * (self.time_maa_g + mg)
+        xw = hidden_states + xx * (self.time_maa_w + mw)
+        xk = hidden_states + xx * (self.time_maa_k + mk)
+        xv = hidden_states + xx * (self.time_maa_v + mv)
+        xr = hidden_states + xx * (self.time_maa_r + mr)
+        xg = hidden_states + xx * (self.time_maa_g + mg)

         r = self.receptance(xr)
         k = self.key(xk)
@@ -496,16 +510,18 @@ class Rwkv_Tmix_x060(nn.Module):
         w = self.time_decay + ww

         wkv_state = last_state.wkv_state
-        x, wkv_state = self.apply_wkv6_state(
+        hidden_states, wkv_state = self.apply_wkv6_state(
             B, T, C, H, r, k, v, w, u=self.time_faaaa, s=wkv_state
         )
         if self.args.wkv_has_group_norm:
-            x = self.ln_x(x.view(B * T, C)).view(B, T, C)
-        x = self.output(x * g)
-        return x, TimeMixState(lx, wkv_state)
+            hidden_states = self.ln_x(
+                hidden_states.view(B * T, C)).view(B, T, C)
+        hidden_states = self.output(hidden_states * g)
+        return hidden_states, AttnState(lx, wkv_state), None

     def apply_wkv6_state(self, B, T, C, H, r, k, v, w, u, s):
-        r, w, k, v = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.n_head), (r, w, k, v))
+        r, w, k, v = map(lambda x: rearrange(
+            x, 'b l (h d) -> b h l d', h=self.n_head), (r, w, k, v))

         if r.device.type == "cpu":
             wkv6_func = native_recurrent_rwkv6
@@ -535,31 +551,56 @@ class Rwkv6Attention(nn.Module):
         self.layer_idx = layer_id
         self.time_mixer = Rwkv_Tmix_x060(args, layer_id, **kwargs)

-    def forward(self, hidden_states, past_key_value, **kwargs):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        past_key_value: Optional[HybridCache] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[torch.Tensor] = None,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        v_first: Optional[torch.Tensor] = None,
+        **kwargs
+    ):
         attn_output = hidden_states
-        B, T, C = attn_output.size()
-        if past_key_value is not None:
-            if len(past_key_value) <= self.layer_idx:
-                last_state = None
-            else:
-                last_state = past_key_value[self.layer_idx][0]
-        if last_state is None:
-            wkv_states = torch.zeros(
-                (B, self.args.num_wkv_heads,
-                 self.args.head_size, self.args.head_size),
-                device=attn_output.device,
-                dtype=torch.float32,
-            )
-            token_shift = torch.zeros(
-                (B, C), device=attn_output.device, dtype=attn_output.dtype
+
+        batch_size, token_length, _ = hidden_states.shape
+
+        if use_cache and len(past_key_value) > self.layer_idx:
+            last_state = past_key_value[self.layer_idx][0]
+        else:
+            last_state = self.init_state(
+                batch_size, hidden_states.device, hidden_states.dtype
             )
-            time_state = TimeMixState(token_shift, wkv_states)
-            channel_state = None
-            last_state = BlockState(time_state, channel_state)
-        attn_output, states = self.time_mixer(
-            attn_output, last_state.time_mix_state)
-        last_state.time_mix_state = states
-
-        if past_key_value is not None:
-            past_key_value.update(T, last_state, self.layer_idx)
-        return attn_output, None
+
+        attn_output, states, v_first = self.time_mixer(hidden_states=hidden_states,
+                                                       last_state=last_state.attn_state,
+                                                       use_cache=use_cache,
+                                                       cu_seqlens=cu_seqlens,
+                                                       v_first=v_first,
+                                                       **kwargs)
+
+        if use_cache:
+            last_state.attn_state = states
+            past_key_value.update(token_length, last_state, self.layer_idx)
+
+        return attn_output, None, v_first
+
+    def init_state(self, batch_size, device, dtype) -> BlockState:
+        wkv_states = torch.zeros(
+            (
+                batch_size,
+                self.args.num_wkv_heads,
+                self.args.head_size,
+                self.args.head_size,
+            ),
+            device=device,
+            dtype=torch.float32,
+        )
+        shift_states = torch.zeros(
+            (batch_size, self.args.hidden_size), device=device, dtype=dtype
+        )
+        return BlockState(AttnState(shift_states, wkv_states), None)
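
A note on the new module-level configuration: the kernel path is no longer chosen by self.training but by a WKV_MODE environment variable ("fused" by default, with "chunk" and "pytorch" as alternatives), and the forward methods are wrapped in a compile_decorator that is a no-op before PyTorch 2.6. A minimal sketch of how a caller might select the kernel, assuming only that the variable is set before wkv.py is first imported (the value is read once into the module-level wkv_mode):

import os

# wkv.py reads WKV_MODE a single time at import and asserts it is one of
# "fused" (default), "chunk", or "pytorch"; set it before any import that
# pulls the module in, e.g. before loading the model with trust_remote_code.
os.environ.setdefault("WKV_MODE", "chunk")

Changing the variable after the module has been imported has no effect, because the choice is already captured in the wkv_mode global.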
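
Note on the changed call contract: Rwkv7Attention.forward and Rwkv6Attention.forward now return (attn_output, None, v_first) and accept v_first as an argument, replacing the removed update_v_first/get_v_first callbacks, so whatever iterates over the layers has to carry v_first from layer 0 to the deeper layers. A hedged sketch of that loop, assuming layers is a list of these attention modules (run_layers and the surrounding names are illustrative, not part of this file):

import torch

def run_layers(layers, hidden: torch.Tensor) -> torch.Tensor:
    # layer 0 produces v_first; later layers mix it back in as the value residual
    v_first = None
    for layer in layers:
        hidden, _, v_first = layer(
            hidden_states=hidden,
            v_first=v_first,
            use_cache=False,
        )
    return hidden

With use_cache=True the same pattern applies, except past_key_value must then be a HybridCache so each layer can read back and update its recurrent state.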