marriola committed
Commit 93c93e8 · verified · 1 Parent(s): 35ce154

Upload BD3LM

Files changed (3)
  1. config.json +1 -1
  2. configuration_bd3lm.py +1 -1
  3. modeling_bd3lm.py +99 -54
config.json CHANGED
@@ -3,7 +3,7 @@
   "architectures": [
     "BD3LM"
   ],
-  "attn_backend": "sdpa",
+  "attn_backend": "flex",
   "auto_map": {
     "AutoConfig": "configuration_bd3lm.BD3LMConfig",
    "AutoModelForMaskedLM": "modeling_bd3lm.BD3LM"
configuration_bd3lm.py CHANGED
@@ -15,7 +15,7 @@ class BD3LMConfig(transformers.PretrainedConfig):
     vocab_size: int = 50258,
     model_length: int = 1024,
     cross_attn: bool = True,
-    attn_backend: str = 'sdpa',
+    attn_backend: str = 'flex',
     hidden_dim: int = 768,
     cond_dim: int = 129,
     n_blocks: int = 12,
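
The config class default moves from 'sdpa' to 'flex' as well. A small sketch of building a config directly while probing for flex attention the same way modeling_bd3lm.py does below; only kwargs visible in this hunk are used, with the default values shown:

    from configuration_bd3lm import BD3LMConfig  # file from this repo

    try:
        from torch.nn.attention.flex_attention import flex_attention  # noqa: F401
        backend = "flex"
    except ImportError:
        backend = "sdpa"

    config = BD3LMConfig(
        vocab_size=50258,
        model_length=1024,
        cross_attn=True,
        attn_backend=backend)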
modeling_bd3lm.py CHANGED
@@ -5,13 +5,17 @@ import math
 import typing
 
 import einops
-import flash_attn
-import flash_attn.layers.rotary
+from functools import partial
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import transformers
 from transformers import modeling_outputs
+try:
+  from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+  FLEX_ATTN_AVAILABLE = True
+except:
+  FLEX_ATTN_AVAILABLE = False
 
 from .configuration_bd3lm import BD3LMConfig
 
@@ -21,21 +25,55 @@ torch._C._jit_set_profiling_executor(False)
 torch._C._jit_override_can_fuse_on_cpu(True)
 torch._C._jit_override_can_fuse_on_gpu(True)
 
-def block_causal_mask(num_rows, block_size, mode='full', offset=0):
-  mask = block_size * torch.arange(
-    1, num_rows // block_size + 1).unsqueeze(1).tile(block_size).flatten().unsqueeze(1)
-  if mode == 'full':
-    mask = (mask >= mask.T + offset)
-  elif mode == 'diag':
-    mask = (mask + offset == mask.T)
-  elif mode == 'triu_diag':
-    mask = torch.zeros(num_rows, num_rows)
-    rows = torch.arange(0, num_rows)
-    group_indices = rows // (block_size)
-    column_indices = group_indices * (block_size) + block_size + offset
-    valid_rows = column_indices < num_rows
-    mask[rows[valid_rows].unsqueeze(1), column_indices[valid_rows].unsqueeze(1)] = 1
-  return mask.int()
+def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
+  """
+  Constructs the specialized block diffusion attention mask for training
+  composed of three masks:
+  - **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks
+  - **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context
+  - **Block Causal Mask (M_BC)**: Attention to update x0
+
+  Args:
+    b, h: Batch and head indices (ignored for mask logic).
+    q_idx, kv_idx: Query and Key indices.
+    seq_len: Total sequence length.
+    block_size: Defines the block structure.
+
+  Returns:
+    A boolean attention mask.
+  """
+
+  # Indicate whether token belongs to xt or x0
+  x0_flag_q = (q_idx >= n)
+  x0_flag_kv = (kv_idx >= n)
+
+  # Compute block indices
+  block_q = torch.where(x0_flag_q == 1,
+                        (q_idx - n) // block_size,
+                        q_idx // block_size)
+  block_kv = torch.where(x0_flag_kv == 1,
+                         (kv_idx - n) // block_size,
+                         kv_idx // block_size)
+
+  # **1. Block Diagonal Mask (M_BD) **
+  block_diagonal = (block_q == block_kv) & (x0_flag_q == x0_flag_kv)
+
+  # **2. Offset Block-Causal Mask (M_OBC) **
+  offset_block_causal = (
+    (block_q > block_kv)
+    & (x0_flag_kv == 1)
+    & (x0_flag_q == 0)
+  )
+
+  # **3. Block-Causal Mask (M_BC) **
+  block_causal = (block_q >= block_kv) & (x0_flag_kv == 1) & (x0_flag_q == 1)
+
+  # **4. Combine Masks **
+  return block_diagonal | offset_block_causal | block_causal
+
+@torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs")
+def fused_flex_attention(q, k, v, mask=None):
+  return flex_attention(q, k, v, block_mask=mask)
 
 def bias_dropout_add_scale(
   x: torch.Tensor,
@@ -132,12 +170,6 @@ def rotate_half(x):
 def apply_rotary_pos_emb_torchscript(qkv, cos, sin):
   return (qkv * cos) + (rotate_half(qkv) * sin)
 
-def apply_rotary_pos_emb(qkv, cos, sin):
-  cos = cos[0,:,0,0,:cos.shape[-1]//2]
-  sin = sin[0,:,0,0,:sin.shape[-1]//2]
-  return flash_attn.layers.rotary.apply_rotary_emb_qkv_(qkv, cos, sin)
-
-
 # function overload
 def modulate(x, shift, scale):
   return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
@@ -268,8 +300,9 @@ def regular_attention_multi_headed(qkv):
 
 class DDiTBlock(nn.Module):
   def __init__(self, n, block_size, dim, n_heads, cond_dim, mlp_ratio=4,
-               dropout=0.1, attn_backend='flash_attn'):
+               dropout=0.1, max_seqlen=1024, attn_backend='flash_attn'):
     super().__init__()
+    self.max_seqlen = max_seqlen
     self.n = n
     self.block_size = block_size
     self.n_heads = n_heads
@@ -317,32 +350,33 @@ class DDiTBlock(nn.Module):
                            h=self.n_heads)
     with torch.cuda.amp.autocast(enabled=False):
       cos, sin = rotary_cos_sin
-      if self.attn_backend == 'flash_attn':
-        qkv = apply_rotary_pos_emb(
-          qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))
-      else:
-        qkv = apply_rotary_pos_emb_torchscript(
-          qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))
+      qkv = apply_rotary_pos_emb_torchscript(
+        qkv, cos.to(qkv.dtype), sin.to(qkv.dtype))
     return qkv
 
-  def cross_attn(self, x, qkv, cross_attn_mask=None):
+  def cross_attn(self, x, qkv, mask=None):
     scale = qkv.shape[-1]
     qkv = qkv.transpose(1, 3)
-    attn_dropout = self.attn_dropout if self.training else 0.0
-    cross_attn_mask = cross_attn_mask.bool() if cross_attn_mask is not None else None
+    mask = mask.bool() if mask is not None else None
     x = F.scaled_dot_product_attention(
      query=qkv[:, :, 0],
      key=qkv[:, :, 1],
      value=qkv[:, :, 2],
-     attn_mask=cross_attn_mask,
-     dropout_p=attn_dropout,
+     attn_mask=mask,
      is_causal=False,
      scale=1 / math.sqrt(scale))
    x = x.transpose(1, 2)
    x = einops.rearrange(x, 'b s h d -> b s (h d)')
    return x
-
-  def forward(self, x, rotary_cos_sin, c, cross_attn_mask=None,
+
+  def cross_attn_flex(self, qkv, mask=None):
+    qkv = einops.rearrange(qkv, 'b s three h d -> b h three s d', h=self.n_heads)
+    x = fused_flex_attention(
+      qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], mask=mask)
+    x = einops.rearrange(x, 'b h s d -> b s (h d)')
+    return x
+
+  def forward(self, x, rotary_cos_sin, c, mask=None,
               sample_mode=False, store_kv=False):
     bias_dropout_scale_fn = self._get_bias_dropout_scale()
 
@@ -354,17 +388,21 @@ class DDiTBlock(nn.Module):
     x = modulate_fused(self.norm1(x), shift_msa, scale_msa)
 
     # get qkvs
-    if cross_attn_mask is not None and not sample_mode:
+    if mask is not None and not sample_mode:
       qkv_x = self.get_qkv(x[:,:self.n], rotary_cos_sin)
       qkv_x0 = self.get_qkv(x[:,self.n:], rotary_cos_sin)
       qkv = torch.cat((qkv_x, qkv_x0), dim=1)
     else:
       qkv = self.get_qkv(x, rotary_cos_sin, store_kv=store_kv)
 
-    if cross_attn_mask is None and self.attn_backend == 'flash_attn':
+    if mask is None and self.attn_backend == 'flash_attn':
       x = regular_attention_multi_headed(qkv)
+    elif self.attn_backend == 'flex' and FLEX_ATTN_AVAILABLE:
+      x = self.cross_attn_flex(qkv, mask=mask)
+    elif self.attn_backend == 'sdpa':
+      x = self.cross_attn(x, qkv, mask=mask)
     else:
-      x = self.cross_attn(x, qkv, cross_attn_mask=cross_attn_mask)
+      raise ValueError('Unknown attention backend')
 
     x = bias_dropout_scale_fn(self.attn_out(x),
                               None,
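
The hunks above drop the flash-attn helpers in favor of block_diff_mask, which ORs together the block-diagonal (M_BD), offset block-causal (M_OBC), and block-causal (M_BC) masks over the concatenated (xt, x0) sequence, and route DDiTBlock attention through flex attention or SDPA; the remaining hunks below wire the same mask into DITBackbone and BD3LM. A toy sanity check of the combined mask, using the same index-broadcasting trick as the SDPA branch of gen_mask further down (sizes are illustrative):

    import torch
    from modeling_bd3lm import block_diff_mask  # file from this repo

    n, block_size = 4, 2  # 4 xt tokens followed by 4 x0 tokens
    q_idx = torch.arange(2 * n)[:, None]
    kv_idx = torch.arange(2 * n)[None, :]
    mask = block_diff_mask(None, None, q_idx, kv_idx, block_size=block_size, n=n)
    print(mask.int())  # (2n, 2n) matrix equal to M_BD | M_OBC | M_BC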
 
@@ -448,7 +486,7 @@ class DITBackbone(nn.Module):
       config.vocab_size,
       config.cond_dim)
     if self.cross_attn:
-      self.gen_mask(config.model_length, self.block_size)
+      self.gen_mask(config.model_length, self.block_size, attn_backend=config.attn_backend)
     self.precision = torch.float32
 
   def _get_bias_dropout_scale(self):
@@ -457,15 +495,18 @@ class DITBackbone(nn.Module):
     else:
       return bias_dropout_add_scale_fused_inference
 
-  def gen_mask(self, seqlen, block_size):
-    self_attn_mask = block_causal_mask(seqlen, block_size, mode='diag')
-    x0_attn_mask = block_causal_mask(seqlen, block_size, mode='full')
-    cross_attn_mask = x0_attn_mask.clone()
-    cross_attn_mask.masked_fill_(self_attn_mask == 1, 0)
-
-    cross_attn_mask = torch.cat((self_attn_mask, cross_attn_mask), dim=1)
-    x0_attn_mask = torch.cat((torch.zeros_like(self_attn_mask), x0_attn_mask), dim=1)
-    self.cross_attn_mask = torch.cat((cross_attn_mask, x0_attn_mask), dim=0)
+  def gen_mask(self, seqlen, block_size, attn_backend='sdpa'):
+    """Genererates attention mask"""
+    if attn_backend == 'flex' and FLEX_ATTN_AVAILABLE:
+      self.mask = create_block_mask(
+        partial(block_diff_mask, block_size=block_size, n=seqlen),
+        B=None, H=None, Q_LEN=seqlen*2, KV_LEN=seqlen*2)
+    elif attn_backend == 'sdpa':
+      self.mask = block_diff_mask(
+        b=None, h=None, q_idx=torch.arange(seqlen*2)[:, None],
+        kv_idx=torch.arange(seqlen*2)[None, :], block_size=block_size, n=seqlen)
+    else:
+      raise ValueError('Unknown attention backend')
 
   def forward(self, indices, sigma, sample_mode=False,
               store_kv=False, output_hidden_states=False):
@@ -478,13 +519,13 @@ class DITBackbone(nn.Module):
     c = F.silu(self.sigma_map(sigma))
     if self.cross_attn:
       rotary_cos_sin = self.rotary_emb(x[:, :self.n])
-      cross_attn_mask = self.cross_attn_mask.to(x.device)
+      mask = self.mask.to(x.device)
       # use block-causal mask only during sampling
       if sample_mode:
-        cross_attn_mask = cross_attn_mask[
+        mask = mask[
          self.n:self.n+x.shape[1], self.n:self.n+x.shape[1]]
     else:
-      cross_attn_mask = None
+      mask = None
       rotary_cos_sin = self.rotary_emb(x)
 
     with torch.cuda.amp.autocast(dtype=self.precision):
@@ -492,7 +533,7 @@ class DITBackbone(nn.Module):
        x = self.blocks[i](x,
                           rotary_cos_sin,
                           c,
-                          cross_attn_mask=cross_attn_mask,
+                          mask=mask,
                           sample_mode=sample_mode,
                           store_kv=store_kv)
        if output_hidden_states:
@@ -512,6 +553,7 @@ class BD3LM(transformers.PreTrainedModel):
     self,
     config: BD3LMConfig):
     super().__init__(config)
+    self.config = config
     self.backbone = DITBackbone(config)
     if config.var_min:
       self.register_buffer(
@@ -537,6 +579,9 @@ class BD3LM(transformers.PreTrainedModel):
       torch.Tensor, typing.Tuple,
       modeling_outputs.MaskedLMOutput]:
     """HF-compatible forward method."""
+    if sample_mode:
+      assert self.config.attn_backend == 'sdpa', 'Sampling only supported with SDPA'
+
     output_hidden_states = (
       output_hidden_states
       if output_hidden_states is not None
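
gen_mask now materializes the mask once per backbone: as a flex-attention BlockMask when the 'flex' backend is selected and available, or as a dense boolean matrix for 'sdpa'. A sketch of the flex path in isolation (shapes, block size, and the CUDA device are illustrative; requires a PyTorch build that provides torch.nn.attention.flex_attention):

    from functools import partial

    import torch
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention

    from modeling_bd3lm import block_diff_mask  # file from this repo

    seqlen, block_size, n_heads, head_dim = 1024, 16, 12, 64
    block_mask = create_block_mask(
        partial(block_diff_mask, block_size=block_size, n=seqlen),
        B=None, H=None, Q_LEN=2 * seqlen, KV_LEN=2 * seqlen)

    q = torch.randn(1, n_heads, 2 * seqlen, head_dim, device="cuda")
    k, v = torch.randn_like(q), torch.randn_like(q)
    out = flex_attention(q, k, v, block_mask=block_mask)  # (1, n_heads, 2*seqlen, head_dim)

Note that the new assertion in BD3LM.forward limits sample_mode to the SDPA backend, so the flex path applies to the training-style forward pass with the full cross-attention mask.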