Spaces:
Running
on
Zero
Running
on
Zero
File size: 55,898 Bytes
12a0dd9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 |
# Copyright (c) 2024 The Qwen Team and The HuggingFace Inc. team.
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
#
# Original file was released under Apache-2.0, with the full license text
# available at https://github.com/huggingface/transformers/blob/main/LICENSE.
#
# This modified file is released under the same license.
from dataclasses import dataclass
from functools import partial
from typing import List, Optional, Tuple
import torch
from torch import nn
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.functional import scaled_dot_product_attention
from transformers.utils import ModelOutput
from flash_attn import flash_attn_varlen_func
from modeling.qwen2.modeling_qwen2 import (
Qwen2Attention,
Qwen2MLP,
Qwen2PreTrainedModel,
Qwen2RMSNorm,
Qwen2RotaryEmbedding,
apply_rotary_pos_emb,
)
from modeling.qwen2.configuration_qwen2 import Qwen2Config as _Qwen2Config
# Raise TorchDynamo's recompilation cache limits before compiling below.
# NOTE(review): presumably raised so that varying packed lengths / block masks
# do not exhaust the default recompile budget — confirm the motivation.
torch._dynamo.config.cache_size_limit = 512
torch._dynamo.config.accumulated_cache_size_limit = 4096
# Compile flex_attention once at import time; the block-mask path of
# PackedAttention / PackedAttentionMoT calls this compiled callable.
# (An alternative previously tried: torch.compile(flex_attention, dynamic=True, mode='max-autotune').)
flex_attention = torch.compile(flex_attention)
class Qwen2Config(_Qwen2Config):
    r"""
    This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
    Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of
    Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 151936):
            Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`Qwen2Model`].
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 22016):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 32):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to the value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2.
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2.
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size. If not specified, will default to `4096`.
        max_window_layers (`int`, *optional*, defaults to 28):
            The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use
            full attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        is_causal (`bool`, *optional*, defaults to `True`):
            Forwarded unchanged to the base `Qwen2Config` constructor.
        _attn_implementation (`str`, *optional*, defaults to `"flash_attention_2"`):
            Forwarded unchanged to the base `Qwen2Config` constructor.
        qk_norm (`bool`, *optional*, defaults to `True`):
            If `True`, the packed attention layers apply `Qwen2RMSNorm` over the head dimension of queries and keys;
            otherwise they use `nn.Identity`.
        layer_module (`str`, *optional*, defaults to `"Qwen2DecoderLayer"`):
            Name of the decoder-layer class to use. Only stored here; presumably consumed by the model constructor
            (confirm at the call site).
        freeze_und (`bool`, *optional*, defaults to `False`):
            If `True`, `PackedAttentionMoT.forward_train` detaches the values and normalized queries/keys of the
            "und" tokens from the autograd graph.

    ```python
    >>> from transformers import Qwen2Model, Qwen2Config

    >>> # Initializing a Qwen2 style configuration
    >>> configuration = Qwen2Config()

    >>> # Initializing a model from the Qwen2-7B style configuration
    >>> model = Qwen2Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "qwen2"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        is_causal=True,
        _attn_implementation="flash_attention_2",
        qk_norm=True,
        layer_module="Qwen2DecoderLayer",
        freeze_und=False,
        **kwargs,
    ):
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            use_cache=use_cache,
            tie_word_embeddings=tie_word_embeddings,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            use_sliding_window=use_sliding_window,
            sliding_window=sliding_window,
            max_window_layers=max_window_layers,
            attention_dropout=attention_dropout,
            is_causal=is_causal,
            _attn_implementation=_attn_implementation,
            **kwargs,
        )
        # Options specific to this packed-attention variant; the base config
        # does not receive them.
        self.qk_norm = qk_norm
        self.layer_module = layer_module
        self.freeze_und = freeze_und
class NaiveCache:
    """Minimal per-layer key/value cache for the packed attention layers.

    ``key_cache`` and ``value_cache`` map each layer index in
    ``range(num_layers)`` to that layer's cached tensor, or ``None`` while the
    layer has not been filled yet. Attention layers read and overwrite their
    own slot directly.
    """

    def __init__(self, num_layers):
        # Every slot starts empty (None).
        self.key_cache = dict.fromkeys(range(num_layers))
        self.value_cache = dict.fromkeys(range(num_layers))

    @property
    def num_layers(self):
        """Number of layer slots in the cache."""
        return len(self.key_cache)

    @property
    def seq_lens(self):
        """Size of dim 0 of layer 0's cached keys, or 0 if nothing is cached.

        Also returns 0 for a cache built with ``num_layers == 0`` (which has no
        layer-0 slot at all) instead of raising ``KeyError``.
        """
        first_layer_keys = self.key_cache.get(0)
        if first_layer_keys is None:
            return 0
        return first_layer_keys.shape[0]
@dataclass
class BaseNavitOutputWithPast(ModelOutput):
    """Output container pairing a packed query sequence with its KV cache.

    Both fields default to ``None``. Field order is part of the ``ModelOutput``
    interface (tuple/positional access), so keep it unchanged.
    """

    # Packed sequence tensor for the query tokens (presumably the final hidden
    # states — confirm at the producer).
    packed_query_sequence: Optional[torch.FloatTensor] = None
    # Per-layer key/value cache, if one is returned.
    past_key_values: Optional[NaiveCache] = None
def pad_sequence(tensor, pad_size):
    """Append ``pad_size`` zero rows along dim 1 of a 3-D tensor.

    The input must be 3-D (it is unpacked as ``(H, L, D)``); the result has
    shape ``(H, L + pad_size, D)`` with the same dtype and device, and the
    original values occupy the leading ``L`` positions of dim 1.
    """
    num_heads, _, head_dim = tensor.shape
    zero_rows = tensor.new_zeros((num_heads, pad_size, head_dim))
    return torch.cat((tensor, zero_rows), dim=1)
class PackedAttention(Qwen2Attention):
    """Qwen2 self-attention operating on packed (concatenated) sequences.

    ``forward`` dispatches on ``self.training``:

    * ``forward_train`` attends over a packed batch whose samples are laid out
      back to back. ``attention_mask`` is either a ``list`` of dense per-sample
      masks (SDPA path) or a mask object passed to ``flex_attention`` as
      ``block_mask``.
    * ``forward_inference`` attends with ``flash_attn_varlen_func`` and can
      merge new keys/values with, and write them back into, a ``NaiveCache``.
    """

    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx)
        # Optional RMSNorm over the head dimension of queries and keys.
        if self.config.qk_norm:
            self.q_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.k_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()

    def forward(self, *args, **kwargs):
        """Route to ``forward_train`` in training mode, else ``forward_inference``."""
        if self.training:
            return self.forward_train(*args, **kwargs)
        else:
            return self.forward_inference(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask,
        packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    ):
        """Attention over a packed training batch.

        Args:
            packed_sequence: hidden states of all packed tokens; projected by
                ``q_proj``/``k_proj``/``v_proj`` and viewed as
                ``(num_tokens, heads, head_dim)``.
            sample_lens: per-sample lengths. Used to split the tokens on the
                SDPA path; their sum sets the padded length on the flex path.
            attention_mask: a ``list`` of per-sample dense masks (cast to
                bfloat16 and passed to SDPA as ``attn_mask``), or any other
                object, which is passed to ``flex_attention`` as ``block_mask``.
            packed_position_embeddings: ``(cos, sin)`` rotary embeddings.

        Returns:
            ``o_proj`` of the attention output, one row per packed token.
        """
        packed_query_states = self.q_proj(packed_sequence).view(-1, self.num_heads, self.head_dim)
        packed_key_states = self.k_proj(packed_sequence).view(-1, self.num_key_value_heads, self.head_dim)
        packed_value_states = self.v_proj(packed_sequence).view(-1, self.num_key_value_heads, self.head_dim)

        packed_query_states = self.q_norm(packed_query_states)
        packed_key_states = self.k_norm(packed_key_states)

        packed_cos, packed_sin = packed_position_embeddings
        packed_query_states, packed_key_states = apply_rotary_pos_emb(
            packed_query_states, packed_key_states, packed_cos, packed_sin, unsqueeze_dim=1
        )

        if isinstance(attention_mask, list):
            # SDPA path: repeat each KV head num_key_value_groups times so K/V
            # have as many heads as Q, then attend each sample separately.
            packed_key_states = packed_key_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
            packed_key_states = packed_key_states.reshape(-1, self.num_heads, self.head_dim)
            packed_value_states = packed_value_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
            packed_value_states = packed_value_states.reshape(-1, self.num_heads, self.head_dim)

            # (heads, num_tokens, head_dim) split per sample along the token axis.
            unpacked_query_states = packed_query_states.transpose(0, 1).split(sample_lens, dim=1)
            unpacked_key_states = packed_key_states.transpose(0, 1).split(sample_lens, dim=1)
            unpacked_value_states = packed_value_states.transpose(0, 1).split(sample_lens, dim=1)

            unpacked_attn_outputs = []
            for query_states, key_states, value_states, attention_mask_per_sample in zip(
                unpacked_query_states, unpacked_key_states, unpacked_value_states, attention_mask
            ):
                with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
                    attn_output = scaled_dot_product_attention(
                        query_states.to(torch.bfloat16).unsqueeze(0),
                        key_states.to(torch.bfloat16).unsqueeze(0),
                        value_states.to(torch.bfloat16).unsqueeze(0),
                        attention_mask_per_sample.to(torch.bfloat16).unsqueeze(0),
                    )
                unpacked_attn_outputs.append(attn_output.squeeze(0))
            packed_attn_output = torch.cat(unpacked_attn_outputs, dim=1)
        else:
            # Flex path: zero-pad the token axis up to sum(sample_lens)
            # (presumably the length the block mask was built for — confirm),
            # run one fused call with GQA, then drop the padded positions.
            pad_size = sum(sample_lens) - packed_query_states.shape[0]
            packed_query_states = pad_sequence(packed_query_states.permute(1, 0, 2), pad_size)
            packed_key_states = pad_sequence(packed_key_states.permute(1, 0, 2), pad_size)
            packed_value_states = pad_sequence(packed_value_states.permute(1, 0, 2), pad_size)
            packed_attn_output = flex_attention(
                packed_query_states.unsqueeze(0),
                packed_key_states.unsqueeze(0),
                packed_value_states.unsqueeze(0),
                enable_gqa=True,
                block_mask=attention_mask,
            )
            end_index = packed_attn_output.shape[2] - pad_size
            packed_attn_output = packed_attn_output[0, :, :end_index, :]

        # (heads, num_tokens, head_dim) -> (num_tokens, heads * head_dim), which
        # is o_proj's input width (not necessarily equal to hidden_size).
        packed_attn_output = packed_attn_output.transpose(0, 1).reshape(-1, self.num_heads * self.head_dim)
        packed_attn_output = self.o_proj(packed_attn_output)
        return packed_attn_output

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
    ):
        """Variable-length attention for evaluation, with an optional KV cache.

        Args:
            packed_query_sequence: hidden states of the new (query) tokens.
            query_lens: tensor of per-sequence query lengths.
            packed_query_position_embeddings: ``(cos, sin)`` for the queries.
            packed_query_indexes: positions of the new tokens' keys/values in
                the merged key/value buffer.
            past_key_values: cache whose slot for ``self.layer_idx`` (when not
                ``None``) holds previously computed keys/values.
            key_values_lens: per-sequence lengths of the cached keys/values.
            packed_key_value_indexes: positions of the cached keys/values in
                the merged buffer.
            update_past_key_values: when true and a cache is given, store the
                merged keys/values in this layer's cache slot.
            is_causal: passed to ``flash_attn_varlen_func`` as ``causal``.

        Returns:
            ``(attn_output, past_key_values)``.
        """
        packed_query_states = self.q_proj(packed_query_sequence).view(-1, self.num_heads, self.head_dim)
        packed_key_states = self.k_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
        packed_value_states = self.v_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)

        packed_query_states = self.q_norm(packed_query_states)
        packed_key_states = self.k_norm(packed_key_states)

        packed_cos, packed_sin = packed_query_position_embeddings
        packed_query_states, packed_key_states = apply_rotary_pos_emb(
            packed_query_states, packed_key_states, packed_cos, packed_sin, unsqueeze_dim=1
        )

        packed_query_states = packed_query_states.to(torch.bfloat16)
        packed_key_states = packed_key_states.to(torch.bfloat16)
        packed_value_states = packed_value_states.to(torch.bfloat16)

        if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None:
            # Scatter cached and new keys/values into one buffer at the caller-
            # provided positions so each sequence is contiguous for flash-attn.
            past_key_states = past_key_values.key_cache[self.layer_idx]
            past_value_states = past_key_values.value_cache[self.layer_idx]

            seqlens = sum(query_lens) + sum(key_values_lens)
            merged_key_states = past_key_states.new_zeros((seqlens, self.num_key_value_heads, self.head_dim))
            merged_value_states = past_value_states.new_zeros((seqlens, self.num_key_value_heads, self.head_dim))
            merged_key_states[packed_query_indexes] = packed_key_states
            merged_key_states[packed_key_value_indexes] = past_key_states
            merged_value_states[packed_query_indexes] = packed_value_states
            merged_value_states[packed_key_value_indexes] = past_value_states
            key_values_lens = key_values_lens + query_lens
        else:
            merged_key_states = packed_key_states
            merged_value_states = packed_value_states
            key_values_lens = query_lens

        # Cumulative sequence offsets with a leading 0, as flash-attn expects.
        cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
        cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))

        packed_attn_output = flash_attn_varlen_func(
            q=packed_query_states,
            k=merged_key_states,
            v=merged_value_states,
            cu_seqlens_q=cu_seqlens_q.to(torch.int32),
            cu_seqlens_k=cu_seqlens_k.to(torch.int32),
            max_seqlen_q=query_lens.max().item(),
            max_seqlen_k=key_values_lens.max().item(),
            causal=is_causal,
        )
        # Flatten heads to o_proj's input width (heads * head_dim).
        packed_attn_output = packed_attn_output.reshape(-1, self.num_heads * self.head_dim)
        packed_attn_output = self.o_proj(packed_attn_output)

        # Without a cache object there is nothing to update; previously this
        # raised AttributeError under the default arguments.
        if update_past_key_values and past_key_values is not None:
            past_key_values.key_cache[self.layer_idx] = merged_key_states
            past_key_values.value_cache[self.layer_idx] = merged_value_states

        return packed_attn_output, past_key_values
class PackedAttentionMoT(Qwen2Attention):
    """Packed attention with two sets of projection/norm weights.

    The base Qwen2 ``q/k/v/o_proj`` and ``q/k_norm`` modules handle one group
    of tokens ("und" tokens in training, text tokens in "gen" inference), and
    the ``*_moe_gen`` modules handle the other ("gen" tokens in training, VAE
    tokens in "gen" inference). All tokens attend jointly.
    """

    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx)
        # Optional RMSNorm over the head dimension of queries and keys, one
        # copy per token group.
        if self.config.qk_norm:
            self.q_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.k_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.q_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.k_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
            self.q_norm_moe_gen = nn.Identity()
            self.k_norm_moe_gen = nn.Identity()

        self.q_proj_moe_gen = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        self.k_proj_moe_gen = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj_moe_gen = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj_moe_gen = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def forward(self, *args, **kwargs):
        """Route to ``forward_train`` in training mode, else ``forward_inference``."""
        if self.training:
            return self.forward_train(*args, **kwargs)
        else:
            return self.forward_inference(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask,
        packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        packed_und_token_indexes: torch.LongTensor,
        packed_gen_token_indexes: torch.LongTensor,
    ):
        """Attention over a packed training batch with per-group weights.

        Rows in ``packed_und_token_indexes`` use the base projections/norms and
        rows in ``packed_gen_token_indexes`` use the ``*_moe_gen`` ones. If
        ``config.freeze_und`` is set, the und rows of the values and of the
        normalized queries/keys are detached from the autograd graph.
        ``attention_mask`` is either a ``list`` of dense per-sample masks (SDPA
        path) or a mask object passed to ``flex_attention`` as ``block_mask``.

        Returns:
            Output projections, one row of width ``hidden_size`` per token.
            Rows in neither index stay zero.
        """
        num_tokens = packed_sequence.shape[0]
        packed_query_states = packed_sequence.new_zeros((num_tokens, self.num_heads * self.head_dim))
        packed_key_states = packed_sequence.new_zeros((num_tokens, self.num_key_value_heads * self.head_dim))
        packed_value_states = packed_sequence.new_zeros((num_tokens, self.num_key_value_heads * self.head_dim))

        packed_sequence_und = packed_sequence[packed_und_token_indexes]
        packed_sequence_gen = packed_sequence[packed_gen_token_indexes]

        packed_query_states[packed_und_token_indexes] = self.q_proj(packed_sequence_und)
        packed_query_states[packed_gen_token_indexes] = self.q_proj_moe_gen(packed_sequence_gen)

        packed_key_states[packed_und_token_indexes] = self.k_proj(packed_sequence_und)
        packed_key_states[packed_gen_token_indexes] = self.k_proj_moe_gen(packed_sequence_gen)

        packed_value_states[packed_und_token_indexes] = self.v_proj(packed_sequence_und)
        packed_value_states[packed_gen_token_indexes] = self.v_proj_moe_gen(packed_sequence_gen)

        packed_query_states = packed_query_states.view(-1, self.num_heads, self.head_dim)
        packed_key_states = packed_key_states.view(-1, self.num_key_value_heads, self.head_dim)
        packed_value_states = packed_value_states.view(-1, self.num_key_value_heads, self.head_dim)
        if self.config.freeze_und:
            packed_value_states[packed_und_token_indexes] = packed_value_states[packed_und_token_indexes].detach()

        # Normalized copies; each group goes through its own norm module.
        packed_query_states_ = packed_query_states.new_zeros(packed_query_states.shape)
        packed_key_states_ = packed_key_states.new_zeros(packed_key_states.shape)

        packed_query_states_[packed_und_token_indexes] = self.q_norm(packed_query_states[packed_und_token_indexes])
        if self.config.freeze_und:
            packed_query_states_[packed_und_token_indexes] = packed_query_states_[packed_und_token_indexes].detach()
        packed_query_states_[packed_gen_token_indexes] = self.q_norm_moe_gen(packed_query_states[packed_gen_token_indexes])

        packed_key_states_[packed_und_token_indexes] = self.k_norm(packed_key_states[packed_und_token_indexes])
        if self.config.freeze_und:
            packed_key_states_[packed_und_token_indexes] = packed_key_states_[packed_und_token_indexes].detach()
        packed_key_states_[packed_gen_token_indexes] = self.k_norm_moe_gen(packed_key_states[packed_gen_token_indexes])

        packed_cos, packed_sin = packed_position_embeddings
        packed_query_states_, packed_key_states_ = apply_rotary_pos_emb(
            packed_query_states_, packed_key_states_, packed_cos, packed_sin, unsqueeze_dim=1
        )

        if isinstance(attention_mask, list):
            # SDPA path: repeat each KV head num_key_value_groups times so K/V
            # have as many heads as Q, then attend each sample separately.
            packed_key_states_ = packed_key_states_[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
            packed_key_states_ = packed_key_states_.reshape(-1, self.num_heads, self.head_dim)
            packed_value_states = packed_value_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
            packed_value_states = packed_value_states.reshape(-1, self.num_heads, self.head_dim)

            unpacked_query_states = packed_query_states_.transpose(0, 1).split(sample_lens, dim=1)
            unpacked_key_states = packed_key_states_.transpose(0, 1).split(sample_lens, dim=1)
            unpacked_value_states = packed_value_states.transpose(0, 1).split(sample_lens, dim=1)

            unpacked_attn_outputs = []
            for query_states, key_states, value_states, attention_mask_per_sample in zip(
                unpacked_query_states, unpacked_key_states, unpacked_value_states, attention_mask
            ):
                with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
                    attn_output = scaled_dot_product_attention(
                        query_states.to(torch.bfloat16).unsqueeze(0),
                        key_states.to(torch.bfloat16).unsqueeze(0),
                        value_states.to(torch.bfloat16).unsqueeze(0),
                        attention_mask_per_sample.to(torch.bfloat16).unsqueeze(0),
                    )
                unpacked_attn_outputs.append(attn_output.squeeze(0))
            packed_attn_output = torch.cat(unpacked_attn_outputs, dim=1)
        else:
            # Flex path: zero-pad the token axis up to sum(sample_lens)
            # (presumably the length the block mask was built for — confirm),
            # run one fused call with GQA, then drop the padded positions.
            pad_size = sum(sample_lens) - packed_query_states.shape[0]
            packed_query_states_ = pad_sequence(packed_query_states_.permute(1, 0, 2), pad_size)
            packed_key_states_ = pad_sequence(packed_key_states_.permute(1, 0, 2), pad_size)
            packed_value_states = pad_sequence(packed_value_states.permute(1, 0, 2), pad_size)
            packed_attn_output = flex_attention(
                packed_query_states_.unsqueeze(0),  # 1, num_head, L, head_dim
                packed_key_states_.unsqueeze(0),
                packed_value_states.unsqueeze(0),
                enable_gqa=True,
                block_mask=attention_mask,
            )
            end_index = packed_attn_output.shape[2] - pad_size
            packed_attn_output = packed_attn_output[0, :, :end_index, :]

        packed_attn_output = packed_attn_output.transpose(0, 1).reshape(-1, self.num_heads * self.head_dim)

        # o_proj / o_proj_moe_gen map heads * head_dim -> hidden_size, so the
        # destination buffer is allocated with the projection's output width.
        projected_output = packed_attn_output.new_zeros((packed_attn_output.shape[0], self.hidden_size))
        projected_output[packed_und_token_indexes] = self.o_proj(packed_attn_output[packed_und_token_indexes])
        projected_output[packed_gen_token_indexes] = self.o_proj_moe_gen(packed_attn_output[packed_gen_token_indexes])
        return projected_output

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
        mode="und",
        packed_vae_token_indexes=None,
        packed_text_indexes=None,
    ):
        """Variable-length attention for evaluation, with an optional KV cache.

        ``mode="und"`` runs every query token through the base weights.
        ``mode="gen"`` routes rows in ``packed_text_indexes`` through the base
        weights and rows in ``packed_vae_token_indexes`` through the
        ``*_moe_gen`` weights, applying QK-norm in float32.

        Args:
            packed_query_sequence: hidden states of the new (query) tokens.
            query_lens: tensor of per-sequence query lengths.
            packed_query_position_embeddings: ``(cos, sin)`` for the queries.
            packed_query_indexes: positions of the new tokens' keys/values in
                the merged key/value buffer.
            past_key_values: cache whose slot for ``self.layer_idx`` (when not
                ``None``) holds previously computed keys/values.
            key_values_lens: per-sequence lengths of the cached keys/values.
            packed_key_value_indexes: positions of the cached keys/values in
                the merged buffer.
            update_past_key_values: when true and a cache is given, store the
                merged keys/values in this layer's cache slot.
            is_causal: passed to ``flash_attn_varlen_func`` as ``causal``.
            mode: ``"und"`` or ``"gen"``.
            packed_vae_token_indexes: rows using the ``*_moe_gen`` weights
                (``"gen"`` only).
            packed_text_indexes: rows using the base weights (``"gen"`` only).

        Returns:
            ``(attn_output, past_key_values)``.

        Raises:
            ValueError: if ``mode`` is neither ``"und"`` nor ``"gen"``.
        """
        if mode == "und":
            packed_query_states = self.q_proj(packed_query_sequence).view(-1, self.num_heads, self.head_dim)
            packed_key_states = self.k_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
            packed_value_states = self.v_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
            packed_query_states = self.q_norm(packed_query_states)
            packed_key_states = self.k_norm(packed_key_states)
        elif mode == "gen":
            packed_query_sequence = packed_query_sequence.to(torch.bfloat16)
            num_tokens = packed_query_sequence.shape[0]
            packed_query_states = packed_query_sequence.new_zeros((num_tokens, self.num_heads * self.head_dim))
            packed_key_states = packed_query_sequence.new_zeros((num_tokens, self.num_key_value_heads * self.head_dim))
            packed_value_states = packed_query_sequence.new_zeros((num_tokens, self.num_key_value_heads * self.head_dim))

            packed_text_query_sequence = packed_query_sequence[packed_text_indexes]
            packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes]

            packed_query_states[packed_text_indexes] = self.q_proj(packed_text_query_sequence)
            packed_query_states[packed_vae_token_indexes] = self.q_proj_moe_gen(packed_vae_query_sequence)

            packed_key_states[packed_text_indexes] = self.k_proj(packed_text_query_sequence)
            packed_key_states[packed_vae_token_indexes] = self.k_proj_moe_gen(packed_vae_query_sequence)

            packed_value_states[packed_text_indexes] = self.v_proj(packed_text_query_sequence)
            packed_value_states[packed_vae_token_indexes] = self.v_proj_moe_gen(packed_vae_query_sequence)

            packed_query_states = packed_query_states.view(-1, self.num_heads, self.head_dim)
            packed_key_states = packed_key_states.view(-1, self.num_key_value_heads, self.head_dim)
            packed_value_states = packed_value_states.view(-1, self.num_key_value_heads, self.head_dim)

            # QK-norm in float32 for this mode.
            packed_query_states = packed_query_states.to(torch.float32)
            packed_query_states[packed_text_indexes] = self.q_norm(packed_query_states[packed_text_indexes])
            packed_query_states[packed_vae_token_indexes] = self.q_norm_moe_gen(packed_query_states[packed_vae_token_indexes])

            packed_key_states = packed_key_states.to(torch.float32)
            packed_key_states[packed_text_indexes] = self.k_norm(packed_key_states[packed_text_indexes])
            packed_key_states[packed_vae_token_indexes] = self.k_norm_moe_gen(packed_key_states[packed_vae_token_indexes])
        else:
            # Previously an unknown mode surfaced as UnboundLocalError below.
            raise ValueError(f"mode must be 'und' or 'gen', got {mode!r}")

        packed_cos, packed_sin = packed_query_position_embeddings
        packed_query_states, packed_key_states = apply_rotary_pos_emb(
            packed_query_states, packed_key_states, packed_cos, packed_sin, unsqueeze_dim=1
        )

        packed_query_states = packed_query_states.to(torch.bfloat16)
        packed_key_states = packed_key_states.to(torch.bfloat16)
        packed_value_states = packed_value_states.to(torch.bfloat16)

        if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None:
            # Scatter cached and new keys/values into one buffer at the caller-
            # provided positions so each sequence is contiguous for flash-attn.
            past_key_states = past_key_values.key_cache[self.layer_idx]
            past_value_states = past_key_values.value_cache[self.layer_idx]

            seqlens = sum(query_lens) + sum(key_values_lens)
            merged_key_states = past_key_states.new_zeros(size=[seqlens, self.num_key_value_heads, self.head_dim])
            merged_value_states = past_value_states.new_zeros(size=[seqlens, self.num_key_value_heads, self.head_dim])
            merged_key_states[packed_query_indexes] = packed_key_states
            merged_key_states[packed_key_value_indexes] = past_key_states
            merged_value_states[packed_query_indexes] = packed_value_states
            merged_value_states[packed_key_value_indexes] = past_value_states
            key_values_lens = key_values_lens + query_lens
        else:
            merged_key_states = packed_key_states
            merged_value_states = packed_value_states
            key_values_lens = query_lens

        # Cumulative sequence offsets with a leading 0, as flash-attn expects.
        cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
        cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))

        packed_attn_output = flash_attn_varlen_func(
            q=packed_query_states,
            k=merged_key_states,
            v=merged_value_states,
            cu_seqlens_q=cu_seqlens_q.to(torch.int32),
            cu_seqlens_k=cu_seqlens_k.to(torch.int32),
            max_seqlen_q=query_lens.max().item(),
            max_seqlen_k=key_values_lens.max().item(),
            causal=is_causal,
        )
        # Flatten heads to the o_proj input width (heads * head_dim).
        packed_attn_output = packed_attn_output.reshape(-1, self.num_heads * self.head_dim)

        if mode == "und":
            packed_attn_output = self.o_proj(packed_attn_output)
        else:
            # Write both projections into a fresh hidden_size-wide buffer (as
            # forward_train does) rather than overwriting the attention output
            # in place; rows in neither index stay zero.
            projected_output = packed_attn_output.new_zeros((packed_attn_output.shape[0], self.hidden_size))
            projected_output[packed_text_indexes] = self.o_proj(packed_attn_output[packed_text_indexes])
            projected_output[packed_vae_token_indexes] = self.o_proj_moe_gen(
                packed_attn_output[packed_vae_token_indexes]
            )
            packed_attn_output = projected_output

        # Without a cache object there is nothing to update; previously this
        # raised AttributeError under the default arguments.
        if update_past_key_values and past_key_values is not None:
            past_key_values.key_cache[self.layer_idx] = merged_key_states
            past_key_values.value_cache[self.layer_idx] = merged_value_states

        return packed_attn_output, past_key_values
class Qwen2DecoderLayer(nn.Module):
    """Single-expert Qwen2 transformer block operating on packed sequences.

    Pre-norm layout: ``input_layernorm`` -> ``self_attn`` -> residual add, then
    ``post_attention_layernorm`` -> ``mlp`` -> residual add. ``forward`` routes
    to ``forward_train`` when the module is in training mode and to
    ``forward_inference`` otherwise.
    """

    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__()
        width = config.hidden_size
        eps = config.rms_norm_eps
        self.hidden_size = width
        # Registration order is kept stable so parameter/state-dict order is unchanged.
        self.self_attn = PackedAttention(config, layer_idx)
        self.mlp = Qwen2MLP(config)
        self.input_layernorm = Qwen2RMSNorm(width, eps=eps)
        self.post_attention_layernorm = Qwen2RMSNorm(width, eps=eps)

    def forward(self, *args, **kwargs):
        """Dispatch to the training or inference path based on ``self.training``."""
        impl = self.forward_train if self.training else self.forward_inference
        return impl(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask,
        packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """Apply the block to a packed training batch and return the new hidden states."""
        shortcut = packed_sequence
        attn_out = self.self_attn(
            packed_sequence=self.input_layernorm(packed_sequence),
            sample_lens=sample_lens,
            attention_mask=attention_mask,
            packed_position_embeddings=packed_position_embeddings,
        )
        hidden = shortcut + attn_out

        mlp_out = self.mlp(self.post_attention_layernorm(hidden))
        return hidden + mlp_out

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_embeddings: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
    ) -> BaseNavitOutputWithPast:
        """Run one cached-inference step.

        Returns:
            A ``(hidden_states, past_key_values)`` tuple: the updated packed query
            sequence and the cache object returned by ``self_attn`` (the return
            annotation is kept as in the original, but a tuple is what is returned).
        """
        attn_out, past_key_values = self.self_attn(
            packed_query_sequence=self.input_layernorm(packed_query_sequence),
            query_lens=query_lens,
            packed_query_position_embeddings=packed_query_position_embeddings,
            packed_query_indexes=packed_query_indexes,
            past_key_values=past_key_values,
            key_values_lens=key_values_lens,
            packed_key_value_indexes=packed_key_value_indexes,
            update_past_key_values=update_past_key_values,
            is_causal=is_causal,
        )
        hidden = packed_query_sequence + attn_out

        hidden = hidden + self.mlp(self.post_attention_layernorm(hidden))
        return hidden, past_key_values
class Qwen2MoTDecoderLayer(nn.Module):
    """Mixture-of-Transformers decoder block with understanding and generation experts.

    Understanding-token positions use ``input_layernorm``, ``mlp`` and
    ``post_attention_layernorm``. Generation-token positions use the matching
    ``*_moe_gen`` modules. The attention module (``attn_module``, default
    ``PackedAttentionMoT``) also receives the per-expert token indexes.
    ``forward`` dispatches on ``self.training``.
    """

    def __init__(
        self,
        config,
        layer_idx: Optional[int] = None,
        attn_module: Optional[Qwen2Attention] = PackedAttentionMoT,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        # When set, forward_train detaches the attention/MLP outputs at
        # understanding-token positions.
        self.freeze_und = config.freeze_und
        self.self_attn = attn_module(config, layer_idx)
        self.mlp = Qwen2MLP(config)
        self.mlp_moe_gen = Qwen2MLP(config)
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.input_layernorm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, *args, **kwargs):
        """Dispatch to ``forward_train`` in training mode, else ``forward_inference``."""
        if self.training:
            return self.forward_train(*args, **kwargs)
        else:
            return self.forward_inference(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask,
        packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        packed_und_token_indexes: torch.LongTensor,
        packed_gen_token_indexes: torch.LongTensor,
    ) -> torch.Tensor:
        """Apply the block to a packed training batch with per-token expert routing.

        Rows of ``packed_sequence`` selected by ``packed_und_token_indexes`` use the
        understanding expert. Rows selected by ``packed_gen_token_indexes`` use the
        generation expert. Rows in neither set enter attention and the MLP
        residual as zeros.
        """
        residual = packed_sequence
        packed_sequence_ = packed_sequence.new_zeros(packed_sequence.shape)
        packed_sequence_[packed_und_token_indexes] = self.input_layernorm(packed_sequence[packed_und_token_indexes])
        packed_sequence_[packed_gen_token_indexes] = self.input_layernorm_moe_gen(packed_sequence[packed_gen_token_indexes])

        # Self Attention
        packed_sequence_ = self.self_attn(
            packed_sequence=packed_sequence_,
            sample_lens=sample_lens,
            attention_mask=attention_mask,
            packed_position_embeddings=packed_position_embeddings,
            packed_und_token_indexes=packed_und_token_indexes,
            packed_gen_token_indexes=packed_gen_token_indexes,
        )
        if self.freeze_und:
            packed_sequence_[packed_und_token_indexes] = packed_sequence_[packed_und_token_indexes].detach()
        packed_sequence = residual + packed_sequence_

        # Fully Connected
        residual = packed_sequence
        packed_sequence_ = packed_sequence.new_zeros(packed_sequence.shape)
        packed_sequence_[packed_und_token_indexes] = self.mlp(
            self.post_attention_layernorm(packed_sequence[packed_und_token_indexes])
        )
        if self.freeze_und:
            packed_sequence_[packed_und_token_indexes] = packed_sequence_[packed_und_token_indexes].detach()
        packed_sequence_[packed_gen_token_indexes] = self.mlp_moe_gen(
            self.post_attention_layernorm_moe_gen(packed_sequence[packed_gen_token_indexes])
        )
        packed_sequence = residual + packed_sequence_

        return packed_sequence

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_embeddings: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
        mode="und",
        packed_vae_token_indexes=None,
        packed_text_indexes=None,
    ) -> BaseNavitOutputWithPast:
        """Run one cached-inference step of the block.

        In ``"und"`` mode every query token uses the understanding expert. In
        ``"gen"`` mode, tokens at ``packed_text_indexes`` use the understanding
        expert and tokens at ``packed_vae_token_indexes`` use the generation expert.

        Returns:
            A ``(hidden_states, past_key_values)`` tuple. Despite the annotation,
            a tuple is returned.

        Raises:
            ValueError: if ``mode`` is neither ``"und"`` nor ``"gen"``.
        """
        # An unknown mode would otherwise skip both layer norms and the MLP and
        # silently return twice the post-attention residual.
        if mode not in ("und", "gen"):
            raise ValueError(f"Unsupported mode {mode!r}; expected 'und' or 'gen'.")

        residual = packed_query_sequence
        if mode == "und":
            packed_query_sequence = self.input_layernorm(packed_query_sequence)
        elif mode == "gen":
            packed_query_sequence_ = torch.zeros_like(packed_query_sequence)
            packed_query_sequence_[packed_text_indexes] = self.input_layernorm(packed_query_sequence[packed_text_indexes])
            packed_query_sequence_[packed_vae_token_indexes] = self.input_layernorm_moe_gen(packed_query_sequence[packed_vae_token_indexes])
            packed_query_sequence = packed_query_sequence_

        # Self Attention
        packed_query_sequence, past_key_values = self.self_attn(
            packed_query_sequence=packed_query_sequence,
            query_lens=query_lens,
            packed_query_position_embeddings=packed_query_position_embeddings,
            packed_query_indexes=packed_query_indexes,
            past_key_values=past_key_values,
            key_values_lens=key_values_lens,
            packed_key_value_indexes=packed_key_value_indexes,
            update_past_key_values=update_past_key_values,
            is_causal=is_causal,
            mode=mode,
            packed_vae_token_indexes=packed_vae_token_indexes,
            packed_text_indexes=packed_text_indexes,
        )
        packed_query_sequence = residual + packed_query_sequence

        # Fully Connected
        residual = packed_query_sequence
        if mode == "und":
            packed_query_sequence = self.post_attention_layernorm(packed_query_sequence)
            packed_query_sequence = self.mlp(packed_query_sequence)
        elif mode == "gen":
            packed_text_query_sequence = packed_query_sequence[packed_text_indexes]
            packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes]
            # NOTE(review): norm outputs and the MLP output buffer are forced to
            # bfloat16 regardless of the model's dtype. This presumably assumes
            # the experts run in bf16; confirm before running in another dtype.
            packed_text_query_sequence = self.post_attention_layernorm(packed_text_query_sequence).to(torch.bfloat16)
            packed_vae_query_sequence = self.post_attention_layernorm_moe_gen(packed_vae_query_sequence).to(torch.bfloat16)

            # Allocate directly in bf16 instead of allocating then casting.
            packed_query_sequence_ = torch.zeros_like(packed_query_sequence, dtype=torch.bfloat16)
            packed_query_sequence_[packed_text_indexes] = self.mlp(packed_text_query_sequence)
            packed_query_sequence_[packed_vae_token_indexes] = self.mlp_moe_gen(packed_vae_query_sequence)
            packed_query_sequence = packed_query_sequence_
        packed_query_sequence = residual + packed_query_sequence

        return packed_query_sequence, past_key_values
class Qwen2MoEDecoderLayer(nn.Module):
    """Decoder block with shared attention and norms but two MLP experts.

    Attention and both RMSNorms are shared by all tokens. Only the feed-forward
    sub-layer is routed: ``mlp`` for understanding/text tokens and ``mlp_moe_gen``
    for generation/VAE tokens. ``forward`` dispatches on ``self.training``.
    """

    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = PackedAttention(config, layer_idx)
        self.mlp = Qwen2MLP(config)
        self.mlp_moe_gen = Qwen2MLP(config)
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, *args, **kwargs):
        """Dispatch to ``forward_train`` in training mode, else ``forward_inference``."""
        if self.training:
            return self.forward_train(*args, **kwargs)
        else:
            return self.forward_inference(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask,
        packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        packed_und_token_indexes: torch.LongTensor,
        packed_gen_token_indexes: torch.LongTensor,
    ) -> torch.Tensor:
        """Apply the block to a packed training batch, routing only the MLP per token.

        Rows in neither index set receive a zero MLP contribution.
        """
        residual = packed_sequence
        packed_sequence = self.input_layernorm(packed_sequence)

        # Self Attention
        packed_sequence = self.self_attn(
            packed_sequence=packed_sequence,
            sample_lens=sample_lens,
            attention_mask=attention_mask,
            packed_position_embeddings=packed_position_embeddings,
        )
        packed_sequence = residual + packed_sequence

        # Fully Connected
        residual = packed_sequence
        packed_sequence = self.post_attention_layernorm(packed_sequence)

        packed_sequence_new = packed_sequence.new_zeros(packed_sequence.shape)
        packed_sequence_und = self.mlp(packed_sequence[packed_und_token_indexes])
        packed_sequence_gen = self.mlp_moe_gen(packed_sequence[packed_gen_token_indexes])
        packed_sequence_new[packed_und_token_indexes] = packed_sequence_und
        packed_sequence_new[packed_gen_token_indexes] = packed_sequence_gen

        packed_sequence = residual + packed_sequence_new

        return packed_sequence

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_embeddings: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
        mode="und",
        packed_vae_token_indexes=None,
        packed_text_indexes=None,
    ) -> BaseNavitOutputWithPast:
        """Run one cached-inference step of the block.

        In ``"und"`` mode every query token uses ``mlp``. In ``"gen"`` mode tokens
        at ``packed_text_indexes`` use ``mlp`` and tokens at
        ``packed_vae_token_indexes`` use ``mlp_moe_gen``.

        Returns:
            A ``(hidden_states, past_key_values)`` tuple. Despite the annotation,
            a tuple is returned.

        Raises:
            ValueError: if ``mode`` is neither ``"und"`` nor ``"gen"``.
        """
        # An unknown mode would otherwise skip the MLP entirely and silently add
        # the normalized hidden states to the residual.
        if mode not in ("und", "gen"):
            raise ValueError(f"Unsupported mode {mode!r}; expected 'und' or 'gen'.")

        residual = packed_query_sequence
        packed_query_sequence = self.input_layernorm(packed_query_sequence)

        # Self Attention
        packed_query_sequence, past_key_values = self.self_attn(
            packed_query_sequence=packed_query_sequence,
            query_lens=query_lens,
            packed_query_position_embeddings=packed_query_position_embeddings,
            packed_query_indexes=packed_query_indexes,
            past_key_values=past_key_values,
            key_values_lens=key_values_lens,
            packed_key_value_indexes=packed_key_value_indexes,
            update_past_key_values=update_past_key_values,
            is_causal=is_causal,
        )
        packed_query_sequence = residual + packed_query_sequence

        # Fully Connected
        residual = packed_query_sequence
        packed_query_sequence = self.post_attention_layernorm(packed_query_sequence)
        if mode == "und":
            packed_query_sequence = self.mlp(packed_query_sequence)
        elif mode == "gen":
            # NOTE(review): the expert-output buffer is forced to bfloat16
            # regardless of the model's dtype. Confirm before using another dtype.
            # Allocate directly in bf16 instead of allocating then casting.
            packed_query_sequence_ = torch.zeros_like(packed_query_sequence, dtype=torch.bfloat16)
            packed_query_sequence_[packed_text_indexes] = self.mlp(packed_query_sequence[packed_text_indexes])
            packed_query_sequence_[packed_vae_token_indexes] = self.mlp_moe_gen(packed_query_sequence[packed_vae_token_indexes])
            packed_query_sequence = packed_query_sequence_
        packed_query_sequence = residual + packed_query_sequence

        return packed_query_sequence, past_key_values
# Registry mapping ``config.layer_module`` names to decoder-layer constructors.
# Each entry is called as ``ctor(config, layer_idx)`` by Qwen2Model.__init__.
Decoder_layer_dict = dict(
    Qwen2DecoderLayer=Qwen2DecoderLayer,
    Qwen2MoEDecoderLayer=Qwen2MoEDecoderLayer,
    # attn_module is bound explicitly so the MoT entry does not depend on the
    # constructor's default.
    Qwen2MoTDecoderLayer=partial(Qwen2MoTDecoderLayer, attn_module=PackedAttentionMoT),
)
class Qwen2Model(Qwen2PreTrainedModel):
    """Qwen2 decoder stack over packed (concatenated) sequences.

    The layer type is chosen by ``config.layer_module`` through
    ``Decoder_layer_dict``. When the name contains ``"Mo"`` (the MoE/MoT
    variants), a second final norm ``norm_moe_gen`` is created and used for
    generation tokens. ``forward`` dispatches on ``self.training``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.use_moe = 'Mo' in config.layer_module

        # Submodules are registered in the same order as before so parameter
        # and state-dict ordering is unchanged.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        build_layer = Decoder_layer_dict[config.layer_module]
        self.layers = nn.ModuleList(
            build_layer(config, idx) for idx in range(config.num_hidden_layers)
        )
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        if self.use_moe:
            self.norm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen2RotaryEmbedding(config=config)

        # Initialize weights and apply final processing
        self.post_init()

    def _packed_rotary_embeddings(self, hidden, position_ids):
        """Return ``(cos, sin)`` for packed positions with the leading unit dim squeezed."""
        cos, sin = self.rotary_emb(hidden, position_ids.unsqueeze(0))
        return cos.squeeze(0), sin.squeeze(0)

    def forward(self, *args, **kwargs):
        """Dispatch to the training or inference path based on ``self.training``."""
        impl = self.forward_train if self.training else self.forward_inference
        return impl(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask,
        packed_position_ids: torch.Tensor,
        packed_und_token_indexes: Optional[torch.LongTensor] = None,
        packed_gen_token_indexes: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Run all decoder layers plus the final norm(s) on a packed training batch.

        Note: with ``config.freeze_und`` set, the understanding-token rows of the
        caller's ``packed_sequence`` are overwritten in place with detached copies.
        """
        if self.config.freeze_und:
            packed_sequence[packed_und_token_indexes] = packed_sequence[packed_und_token_indexes].detach()

        # Position embeddings are computed once and shared by every layer.
        position_embeddings = self._packed_rotary_embeddings(packed_sequence, packed_position_ids)

        routing = {}
        if self.use_moe:
            assert packed_und_token_indexes is not None
            if packed_gen_token_indexes is None:
                # An empty index tensor: no generation tokens in this batch.
                packed_gen_token_indexes = packed_und_token_indexes.new_ones(size=[0])
            routing = dict(
                packed_und_token_indexes=packed_und_token_indexes,
                packed_gen_token_indexes=packed_gen_token_indexes,
            )

        hidden = packed_sequence
        for layer in self.layers:
            hidden = layer(
                packed_sequence=hidden,
                sample_lens=sample_lens,
                attention_mask=attention_mask,
                packed_position_embeddings=position_embeddings,
                **routing
            )

        if not self.use_moe:
            return self.norm(hidden)

        normed = torch.zeros_like(hidden)
        normed[packed_und_token_indexes] = self.norm(hidden[packed_und_token_indexes])
        if self.config.freeze_und:
            normed[packed_und_token_indexes] = normed[packed_und_token_indexes].detach()
        normed[packed_gen_token_indexes] = self.norm_moe_gen(hidden[packed_gen_token_indexes])
        return normed

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_ids: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
        mode="und",
        packed_vae_token_indexes=None,
        packed_text_indexes=None,
    ) -> BaseNavitOutputWithPast:
        """Run one cached-inference step through every layer and the final norm(s).

        ``mode``, ``packed_vae_token_indexes`` and ``packed_text_indexes`` are
        forwarded to the layers only when ``use_moe`` is set. The two index
        arguments are required in ``"gen"`` mode.
        """
        # Position embeddings are computed once and shared by every layer.
        position_embeddings = self._packed_rotary_embeddings(packed_query_sequence, packed_query_position_ids)

        routing = {}
        if self.use_moe:
            routing["mode"] = mode
            if mode == 'gen':
                assert packed_vae_token_indexes is not None
                assert packed_text_indexes is not None
                routing["packed_vae_token_indexes"] = packed_vae_token_indexes
                routing["packed_text_indexes"] = packed_text_indexes

        hidden = packed_query_sequence
        for layer in self.layers:
            hidden, past_key_values = layer(
                packed_query_sequence=hidden,
                query_lens=query_lens,
                packed_query_position_embeddings=position_embeddings,
                packed_query_indexes=packed_query_indexes,
                past_key_values=past_key_values,
                key_values_lens=key_values_lens,
                packed_key_value_indexes=packed_key_value_indexes,
                update_past_key_values=update_past_key_values,
                is_causal=is_causal,
                **routing,
            )

        if not self.use_moe or mode == "und":
            hidden = self.norm(hidden)
        elif mode == "gen":
            normed = torch.zeros_like(hidden)
            normed[packed_text_indexes] = self.norm(hidden[packed_text_indexes])
            normed[packed_vae_token_indexes] = self.norm_moe_gen(hidden[packed_vae_token_indexes])
            hidden = normed

        return BaseNavitOutputWithPast(
            packed_query_sequence=hidden,
            past_key_values=past_key_values,
        )
class Qwen2ForCausalLM(Qwen2PreTrainedModel):
    """Wrapper holding a ``Qwen2Model`` and an ``lm_head`` projection.

    ``forward`` dispatches on ``self.training`` and returns ``self.model``'s
    output unchanged. ``lm_head`` is not applied here (presumably callers apply
    it themselves; verify against caller).
    """
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def init_moe(self):
        """Initialise every ``*_moe_gen`` parameter as a copy of its base parameter.

        For each parameter whose name contains ``"moe_gen"``, the parameter named
        with ``"_moe_gen"`` removed (e.g. ``mlp_moe_gen.x`` -> ``mlp.x``) is copied
        into it in place.

        Raises:
            KeyError: if a ``moe_gen`` parameter has no matching base entry.
        """
        # Build the state dict once. It previously was rebuilt for every moe_gen
        # parameter, which is quadratic in model size. The snapshot's tensors
        # share storage with the live parameters, and only *_moe_gen parameters
        # are written below, so one snapshot gives the same result.
        state = self.state_dict()
        for name, param in self.named_parameters():
            if "moe_gen" in name:
                original_name = name.replace("_moe_gen", "")
                param.data.copy_(state[original_name].data)

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(self, *args, **kwargs):
        """Dispatch to ``forward_train`` in training mode, else ``forward_inference``."""
        if self.training:
            return self.forward_train(*args, **kwargs)
        else:
            return self.forward_inference(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask,
        packed_position_ids: torch.Tensor,
        packed_und_token_indexes: Optional[torch.LongTensor] = None,
        packed_gen_token_indexes: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Forward a packed training batch through ``self.model`` and return its output."""
        outputs = self.model(
            packed_sequence=packed_sequence,
            sample_lens=sample_lens,
            packed_position_ids=packed_position_ids,
            attention_mask=attention_mask,
            packed_und_token_indexes=packed_und_token_indexes,
            packed_gen_token_indexes=packed_gen_token_indexes,
        )
        return outputs

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_ids: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
        mode="und",
        packed_vae_token_indexes=None,
        packed_text_indexes=None,
    ) -> BaseNavitOutputWithPast:
        """Forward one cached-inference step through ``self.model`` and return its output."""
        outputs = self.model(
            packed_query_sequence=packed_query_sequence,
            query_lens=query_lens,
            packed_query_position_ids=packed_query_position_ids,
            packed_query_indexes=packed_query_indexes,
            past_key_values=past_key_values,
            key_values_lens=key_values_lens,
            packed_key_value_indexes=packed_key_value_indexes,
            update_past_key_values=update_past_key_values,
            is_causal=is_causal,
            mode=mode,
            packed_vae_token_indexes=packed_vae_token_indexes,
            packed_text_indexes=packed_text_indexes,
        )
        return outputs
|