zhouzaida committed on
Commit
14b9a5a
·
1 Parent(s): 4dea07c

support training

Browse files
Files changed (1) hide show
  1. modeling_kimi_vl.py +68 -3
modeling_kimi_vl.py CHANGED
@@ -970,6 +970,10 @@ class MoEGate(nn.Module):
970
  ) # [n, e]
971
  _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
972
  topk_weight = scores.gather(1, topk_idx)
 
 
 
 
973
  else:
974
  raise NotImplementedError(
975
  f"insupportable TopK function for MoE gating: {self.topk_method}"
@@ -983,7 +987,57 @@ class MoEGate(nn.Module):
983
  topk_weight * self.routed_scaling_factor
984
  ) # must multiply the scaling factor
985
 
986
- return topk_idx, topk_weight
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
987
 
988
 
989
  class DeepseekV3MoE(nn.Module):
@@ -1036,9 +1090,20 @@ class DeepseekV3MoE(nn.Module):
1036
  def forward(self, hidden_states):
1037
  identity = hidden_states
1038
  orig_shape = hidden_states.shape
1039
- topk_idx, topk_weight = self.gate(hidden_states)
1040
  hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
1041
- if not self.training:
 
 
 
 
 
 
 
 
 
 
 
1042
  y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
1043
  if self.config.n_shared_experts is not None:
1044
  y = y + self.shared_experts(identity)
 
970
  ) # [n, e]
971
  _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
972
  topk_weight = scores.gather(1, topk_idx)
973
+ elif self.topk_method == "greedy":
974
+ topk_weight, topk_idx = torch.topk(
975
+ scores, k=self.top_k, dim=-1, sorted=False
976
+ )
977
  else:
978
  raise NotImplementedError(
979
  f"insupportable TopK function for MoE gating: {self.topk_method}"
 
987
  topk_weight * self.routed_scaling_factor
988
  ) # must multiply the scaling factor
989
 
990
+ if self.training and self.alpha > 0.0:
991
+ scores_for_aux = scores
992
+ aux_topk = self.top_k
993
+ # always compute aux loss based on the naive greedy topk method
994
+ topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
995
+ if self.seq_aux:
996
+ scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
997
+ ce = torch.zeros(
998
+ bsz, self.n_routed_experts, device=hidden_states.device
999
+ )
1000
+ ce.scatter_add_(
1001
+ 1,
1002
+ topk_idx_for_aux_loss,
1003
+ torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device),
1004
+ ).div_(seq_len * aux_topk / self.n_routed_experts)
1005
+ aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(
1006
+ dim=1
1007
+ ).mean() * self.alpha
1008
+ else:
1009
+ mask_ce = F.one_hot(
1010
+ topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts
1011
+ )
1012
+ ce = mask_ce.float().mean(0)
1013
+ Pi = scores_for_aux.mean(0)
1014
+ fi = ce * self.n_routed_experts
1015
+ aux_loss = (Pi * fi).sum() * self.alpha
1016
+ else:
1017
+ aux_loss = None
1018
+
1019
+ return topk_idx, topk_weight, aux_loss
1020
+
1021
+
1022
class AddAuxiliaryLoss(torch.autograd.Function):
    """Pass ``x`` through unchanged while wiring an auxiliary loss into autograd.

    The forward is the identity on ``x``; during backpropagation a unit
    gradient is injected for ``loss`` (when it requires grad), so the
    auxiliary loss participates in parameter updates without changing the
    activations flowing through the network.
    """

    @staticmethod
    def forward(ctx, x, loss):
        # Only a scalar auxiliary loss is supported.
        assert loss.numel() == 1
        # Stash exactly what backward needs: the loss dtype, and whether a
        # gradient for the loss has to be produced at all.
        ctx.dtype = loss.dtype
        ctx.required_aux_loss = loss.requires_grad
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # Identity w.r.t. ``x``; a constant gradient of 1 w.r.t. the aux loss.
        if ctx.required_aux_loss:
            grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
        else:
            grad_loss = None
        return grad_output, grad_loss
1041
 
1042
 
1043
  class DeepseekV3MoE(nn.Module):
 
1090
  def forward(self, hidden_states):
1091
  identity = hidden_states
1092
  orig_shape = hidden_states.shape
1093
+ topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
1094
  hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
1095
+ if self.training:
1096
+ flat_topk_idx = topk_idx.view(-1)
1097
+ hidden_states = hidden_states.repeat_interleave(
1098
+ self.num_experts_per_tok, dim=0
1099
+ )
1100
+ y = torch.empty_like(hidden_states)
1101
+ for i, expert in enumerate(self.experts):
1102
+ y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
1103
+ y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
1104
+ y = y.to(hidden_states.dtype).view(*orig_shape)
1105
+ y = AddAuxiliaryLoss.apply(y, aux_loss)
1106
+ else:
1107
  y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
1108
  if self.config.n_shared_experts is not None:
1109
  y = y + self.shared_experts(identity)