Upload BD3LM
- config.json +1 -1
- configuration_bd3lm.py +1 -1
- modeling_bd3lm.py +99 -54
config.json
CHANGED
@@ -3,7 +3,7 @@
   "architectures": [
     "BD3LM"
   ],
-  "attn_backend": "
+  "attn_backend": "flex",
   "auto_map": {
     "AutoConfig": "configuration_bd3lm.BD3LMConfig",
    "AutoModelForMaskedLM": "modeling_bd3lm.BD3LM"
configuration_bd3lm.py
CHANGED
@@ -15,7 +15,7 @@ class BD3LMConfig(transformers.PretrainedConfig):
     vocab_size: int = 50258,
     model_length: int = 1024,
     cross_attn: bool = True,
-    attn_backend: str = '
+    attn_backend: str = 'flex',
     hidden_dim: int = 768,
     cond_dim: int = 129,
     n_blocks: int = 12,
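
Both changes flip the default attention backend to 'flex'. Because attn_backend is an ordinary config field, a different backend can still be selected at load time without editing the checkpoint. A minimal sketch, assuming a placeholder repo id and that the checkpoint ships this custom code via auto_map (hence trust_remote_code=True):

import transformers

# Placeholder repo id; substitute the actual BD3LM checkpoint.
repo_id = "kuleshov-group/bd3lm-owt"

config = transformers.AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
config.attn_backend = "sdpa"  # e.g. fall back from the new "flex" default

model = transformers.AutoModelForMaskedLM.from_pretrained(
    repo_id, config=config, trust_remote_code=True)

As the modeling changes below show, 'flex' only works when torch.nn.attention.flex_attention can be imported (FLEX_ATTN_AVAILABLE); otherwise gen_mask raises ValueError for that backend, so 'sdpa' is the conservative override.
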
modeling_bd3lm.py
CHANGED
@@ -5,13 +5,17 @@ import math
 import typing
 
 import einops
-import flash_attn
-import flash_attn.layers.rotary
+from functools import partial
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import transformers
 from transformers import modeling_outputs
+try:
+  from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+  FLEX_ATTN_AVAILABLE = True
+except:
+  FLEX_ATTN_AVAILABLE = False
 
 from .configuration_bd3lm import BD3LMConfig
 

@@ -21,21 +25,55 @@ torch._C._jit_set_profiling_executor(False)
 torch._C._jit_override_can_fuse_on_cpu(True)
 torch._C._jit_override_can_fuse_on_gpu(True)
 
-def …  [14 further deleted lines are illegible in the page capture]
+def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
+  """
+  Constructs the specialized block diffusion attention mask for training
+  composed of three masks:
+  - **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks
+  - **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context
+  - **Block Causal Mask (M_BC)**: Attention to update x0
+
+  Args:
+    b, h: Batch and head indices (ignored for mask logic).
+    q_idx, kv_idx: Query and Key indices.
+    n: Total sequence length.
+    block_size: Defines the block structure.
+
+  Returns:
+    A boolean attention mask.
+  """
+
+  # Indicate whether token belongs to xt or x0
+  x0_flag_q = (q_idx >= n)
+  x0_flag_kv = (kv_idx >= n)
+
+  # Compute block indices
+  block_q = torch.where(x0_flag_q == 1,
+                        (q_idx - n) // block_size,
+                        q_idx // block_size)
+  block_kv = torch.where(x0_flag_kv == 1,
+                         (kv_idx - n) // block_size,
+                         kv_idx // block_size)
+
+  # 1. Block Diagonal Mask (M_BD)
+  block_diagonal = (block_q == block_kv) & (x0_flag_q == x0_flag_kv)
+
+  # 2. Offset Block-Causal Mask (M_OBC)
+  offset_block_causal = (
+    (block_q > block_kv)
+    & (x0_flag_kv == 1)
+    & (x0_flag_q == 0)
+  )
+
+  # 3. Block-Causal Mask (M_BC)
+  block_causal = (block_q >= block_kv) & (x0_flag_kv == 1) & (x0_flag_q == 1)
+
+  # 4. Combine Masks
+  return block_diagonal | offset_block_causal | block_causal
+
+@torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs")
+def fused_flex_attention(q, k, v, mask=None):
+  return flex_attention(q, k, v, block_mask=mask)
 
 def bias_dropout_add_scale(
   x: torch.Tensor,

@@ -132,12 +170,6 @@ def rotate_half(x):
 def apply_rotary_pos_emb_torchscript(qkv, cos, sin):
   return (qkv * cos) + (rotate_half(qkv) * sin)
 
-def apply_rotary_pos_emb(qkv, cos, sin):
-  cos = cos[0,:,0,0,:cos.shape[-1]//2]
-  sin = sin[0,:,0,0,:sin.shape[-1]//2]
-  return flash_attn.layers.rotary.apply_rotary_emb_qkv_(qkv, cos, sin)
-
-
 # function overload
 def modulate(x, shift, scale):
   return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

@@ -268,8 +300,9 @@ def regular_attention_multi_headed(qkv):
 
 class DDiTBlock(nn.Module):
   def __init__(self, n, block_size, dim, n_heads, cond_dim, mlp_ratio=4,
-               dropout=0.1, attn_backend='flash_attn'):
+               dropout=0.1, max_seqlen=1024, attn_backend='flash_attn'):
     super().__init__()
+    self.max_seqlen = max_seqlen
     self.n = n
     self.block_size = block_size
     self.n_heads = n_heads

@@ -317,32 +350,33 @@ class DDiTBlock(nn.Module):
       h=self.n_heads)
     with torch.cuda.amp.autocast(enabled=False):
       cos, sin = rotary_cos_sin
-      if self.attn_backend == 'flash_attn':
-        qkv = apply_rotary_pos_emb(
-          qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))
-      else:
-        qkv = apply_rotary_pos_emb_torchscript(
-          qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))
+      qkv = apply_rotary_pos_emb_torchscript(
+        qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))
     return qkv
 
-  def cross_attn(self, x, qkv, …
+  def cross_attn(self, x, qkv, mask=None):
     scale = qkv.shape[-1]
     qkv = qkv.transpose(1, 3)
-    …
-    cross_attn_mask = cross_attn_mask.bool() if cross_attn_mask is not None else None
+    mask = mask.bool() if mask is not None else None
     x = F.scaled_dot_product_attention(
       query=qkv[:, :, 0],
       key=qkv[:, :, 1],
       value=qkv[:, :, 2],
-      attn_mask=cross_attn_mask,
-      dropout_p=attn_dropout,
+      attn_mask=mask,
       is_causal=False,
       scale=1 / math.sqrt(scale))
     x = x.transpose(1, 2)
     x = einops.rearrange(x, 'b s h d -> b s (h d)')
     return x
-
-  def forward(self, x, rotary_cos_sin, c, …
+
+  def cross_attn_flex(self, qkv, mask=None):
+    qkv = einops.rearrange(qkv, 'b s three h d -> b h three s d', h=self.n_heads)
+    x = fused_flex_attention(
+      qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], mask=mask)
+    x = einops.rearrange(x, 'b h s d -> b s (h d)')
+    return x
+
+  def forward(self, x, rotary_cos_sin, c, mask=None,
               sample_mode=False, store_kv=False):
     bias_dropout_scale_fn = self._get_bias_dropout_scale()
 

@@ -354,17 +388,21 @@ class DDiTBlock(nn.Module):
     x = modulate_fused(self.norm1(x), shift_msa, scale_msa)
 
     # get qkvs
-    if …
+    if mask is not None and not sample_mode:
       qkv_x = self.get_qkv(x[:,:self.n], rotary_cos_sin)
       qkv_x0 = self.get_qkv(x[:,self.n:], rotary_cos_sin)
       qkv = torch.cat((qkv_x, qkv_x0), dim=1)
     else:
       qkv = self.get_qkv(x, rotary_cos_sin, store_kv=store_kv)
 
-    if …
+    if mask is None and self.attn_backend == 'flash_attn':
       x = regular_attention_multi_headed(qkv)
+    elif self.attn_backend == 'flex' and FLEX_ATTN_AVAILABLE:
+      x = self.cross_attn_flex(qkv, mask=mask)
+    elif self.attn_backend == 'sdpa':
+      x = self.cross_attn(x, qkv, mask=mask)
     else:
-      …
+      raise ValueError('Unknown attention backend')
 
     x = bias_dropout_scale_fn(self.attn_out(x),
                               None,

@@ -448,7 +486,7 @@ class DITBackbone(nn.Module):
       config.vocab_size,
       config.cond_dim)
     if self.cross_attn:
-      self.gen_mask(config.model_length, self.block_size)
+      self.gen_mask(config.model_length, self.block_size, attn_backend=config.attn_backend)
     self.precision = torch.float32
 
   def _get_bias_dropout_scale(self):

@@ -457,15 +495,18 @@ class DITBackbone(nn.Module):
     else:
       return bias_dropout_add_scale_fused_inference
 
-  def gen_mask(self, seqlen, block_size):
-    …  [method body (8 lines) illegible in the page capture]
+  def gen_mask(self, seqlen, block_size, attn_backend='sdpa'):
+    """Generates attention mask"""
+    if attn_backend == 'flex' and FLEX_ATTN_AVAILABLE:
+      self.mask = create_block_mask(
+        partial(block_diff_mask, block_size=block_size, n=seqlen),
+        B=None, H=None, Q_LEN=seqlen*2, KV_LEN=seqlen*2)
+    elif attn_backend == 'sdpa':
+      self.mask = block_diff_mask(
+        b=None, h=None, q_idx=torch.arange(seqlen*2)[:, None],
+        kv_idx=torch.arange(seqlen*2)[None, :], block_size=block_size, n=seqlen)
+    else:
+      raise ValueError('Unknown attention backend')
 
   def forward(self, indices, sigma, sample_mode=False,
               store_kv=False, output_hidden_states=False):

@@ -478,13 +519,13 @@ class DITBackbone(nn.Module):
     c = F.silu(self.sigma_map(sigma))
     if self.cross_attn:
       rotary_cos_sin = self.rotary_emb(x[:, :self.n])
-      …
+      mask = self.mask.to(x.device)
       # use block-causal mask only during sampling
       if sample_mode:
-        …
+        mask = mask[
           self.n:self.n+x.shape[1], self.n:self.n+x.shape[1]]
     else:
-      …
+      mask = None
       rotary_cos_sin = self.rotary_emb(x)
 
     with torch.cuda.amp.autocast(dtype=self.precision):

@@ -492,7 +533,7 @@ class DITBackbone(nn.Module):
         x = self.blocks[i](x,
                            rotary_cos_sin,
                            c,
-                           …
+                           mask=mask,
                            sample_mode=sample_mode,
                            store_kv=store_kv)
       if output_hidden_states:

@@ -512,6 +553,7 @@ class BD3LM(transformers.PreTrainedModel):
     self,
     config: BD3LMConfig):
     super().__init__(config)
+    self.config = config
     self.backbone = DITBackbone(config)
     if config.var_min:
       self.register_buffer(

@@ -537,6 +579,9 @@ class BD3LM(transformers.PreTrainedModel):
       torch.Tensor, typing.Tuple,
       modeling_outputs.MaskedLMOutput]:
     """HF-compatible forward method."""
+    if sample_mode:
+      assert self.config.attn_backend == 'sdpa', 'Sampling only supported with SDPA'
+
     output_hidden_states = (
       output_hidden_states
       if output_hidden_states is not None
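
For readers tracing the new mask logic, here is a minimal sketch (not part of the commit) that evaluates block_diff_mask on a toy setup, mirroring the sdpa branch of gen_mask; it assumes the updated modeling_bd3lm.py is importable so block_diff_mask is in scope:

import torch
from modeling_bd3lm import block_diff_mask  # assumes the updated module is on the path

# Toy setup: n=4 tokens per stream, block_size=2, so the concatenated
# xt||x0 sequence has 8 positions: 0-3 are noised (xt), 4-7 are clean (x0).
n, block_size = 4, 2
q_idx = torch.arange(2 * n)[:, None]
kv_idx = torch.arange(2 * n)[None, :]

mask = block_diff_mask(b=None, h=None, q_idx=q_idx, kv_idx=kv_idx,
                       block_size=block_size, n=n)
print(mask.int())
# Row 0 (xt, block 0): attends to kv 0-1, its own noised block (M_BD).
# Row 2 (xt, block 1): attends to kv 2-3 (M_BD) plus kv 4-5, the preceding clean block (M_OBC).
# Row 6 (x0, block 1): attends to kv 4-7, all clean blocks up to its own (M_BC).

The flex backend wraps this same predicate with create_block_mask and runs it through the compiled fused_flex_attention, while the sdpa path materializes the full boolean matrix; the new assert in BD3LM.forward means sampling currently goes through the sdpa path.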
|