Skip to content

Commit

Permalink
refactor DeepseekV3TopkRouter
Browse files (browse the repository at this point in the history)
  • Loading branch information
bzantium committed Feb 1, 2025
1 parent 8e994dd commit 7405a95
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 30 deletions.
35 changes: 20 additions & 15 deletions src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,37 +142,42 @@ def __init__(self, config):
self.routed_scaling_factor = config.routed_scaling_factor
self.n_group = config.n_group
self.topk_group = config.topk_group
self.norm_topk_prob = config.norm_topk_prob

self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts)))

def forward(self, hidden_states):
batch_size, seq_length = hidden_states.shape[:-1]
hidden_states = hidden_states.view(-1, self.config.hidden_size)
router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))

scores = router_logits.sigmoid()
topk_indices = self.get_topk_indices(scores)
topk_weights = scores.gather(1, topk_indices)
if self.norm_topk_prob:
denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
topk_weights /= denominator
topk_weights = topk_weights * self.routed_scaling_factor
return topk_indices, topk_weights, router_logits

@torch.no_grad()
def get_topk_indices(self, scores):
scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
group_scores = (
scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
.topk(2, dim=-1)[0]
.sum(dim=-1)
) # [n, n_group]
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
)
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores)
group_mask.scatter_(1, group_idx, 1)
score_mask = (
group_mask.unsqueeze(-1)
.expand(batch_size * seq_length, self.n_group, self.n_routed_experts // self.n_group)
.expand(-1, self.n_group, self.n_routed_experts // self.n_group)
.reshape(-1, self.n_routed_experts)
) # [n, e]
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e]
_, topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)
topk_weights = scores.gather(1, topk_indices)
denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
topk_weights /= denominator
topk_weights = topk_weights * self.routed_scaling_factor # must multiply the scaling factor
return topk_indices, topk_weights, router_logits
)
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
return topk_indices


class DeepseekV3MoE(nn.Module):
Expand Down
35 changes: 20 additions & 15 deletions src/transformers/models/deepseek_v3/modular_deepseek_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,37 +68,42 @@ def __init__(self, config):
self.routed_scaling_factor = config.routed_scaling_factor
self.n_group = config.n_group
self.topk_group = config.topk_group
self.norm_topk_prob = config.norm_topk_prob

self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts)))

def forward(self, hidden_states):
batch_size, seq_length = hidden_states.shape[:-1]
hidden_states = hidden_states.view(-1, self.config.hidden_size)
router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))

scores = router_logits.sigmoid()
topk_indices = self.get_topk_indices(scores)
topk_weights = scores.gather(1, topk_indices)
if self.norm_topk_prob:
denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
topk_weights /= denominator
topk_weights = topk_weights * self.routed_scaling_factor
return topk_indices, topk_weights, router_logits

@torch.no_grad()
def get_topk_indices(self, scores):
scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
group_scores = (
scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
.topk(2, dim=-1)[0]
.sum(dim=-1)
) # [n, n_group]
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
)
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores)
group_mask.scatter_(1, group_idx, 1)
score_mask = (
group_mask.unsqueeze(-1)
.expand(batch_size * seq_length, self.n_group, self.n_routed_experts // self.n_group)
.expand(-1, self.n_group, self.n_routed_experts // self.n_group)
.reshape(-1, self.n_routed_experts)
) # [n, e]
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e]
_, topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)
topk_weights = scores.gather(1, topk_indices)
denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
topk_weights /= denominator
topk_weights = topk_weights * self.routed_scaling_factor # must multiply the scaling factor
return topk_indices, topk_weights, router_logits
)
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
return topk_indices


class DeepseekV3MoE(nn.Module):
Expand Down

0 comments on commit 7405a95

Please sign in to comment.