Added dynamic MoE changes
srajabos authored and skavulya committed Feb 11, 2025
1 parent ae5fdb8 commit b2b1715
Showing 1 changed file with 71 additions and 29 deletions.
optimum/habana/transformers/models/deepseek_v3/modeling_deepseek_v3.py (71 additions, 29 deletions)
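
The change replaces the per-expert Python loop at inference time with slice-wise calls to the fused HPU mixture-of-experts kernel: the routed experts are partitioned into slices of at most SLICE_MAX_EXPERT experts, and each slice is dispatched with an (experts_min, experts_max) index range. Below is a minimal standalone sketch of that slicing arithmetic only; the helper name expert_slices and the 256-expert example are illustrative and not part of the commit.

import math

# Default expert count per slice, mirroring the SLICE_MAX_EXPERT constant added in this commit.
SLICE_MAX_EXPERT = 80


def expert_slices(n_routed_experts: int, slice_max_expert: int = SLICE_MAX_EXPERT):
    """Return the (experts_min, experts_max) index range covered by each slice."""
    expert_slice = math.ceil(n_routed_experts / slice_max_expert)  # number of slices
    expert_chunk = n_routed_experts // expert_slice  # experts handled per slice
    return [(expert_chunk * idx, expert_chunk * (idx + 1) - 1) for idx in range(expert_slice)]


# Illustrative example: with 256 routed experts this yields four slices of 64 experts each,
# i.e. [(0, 63), (64, 127), (128, 191), (192, 255)].
print(expert_slices(256))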
@@ -43,6 +43,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache
+from transformers.integrations.deepspeed import is_deepspeed_available
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -68,6 +69,8 @@
 
 _CONFIG_FOR_DOC = "DeepseekV3Config"
 
+# default expert number per slice for dynamic MoE
+SLICE_MAX_EXPERT = 80
 try:
     from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
 
@@ -493,46 +496,85 @@ def __init__(self, config):
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             self.shared_experts = DeepseekV3MLP(config=config, intermediate_size=intermediate_size)
 
+        self.expert_slice = math.ceil(config.n_routed_experts / SLICE_MAX_EXPERT)
+        self.expert_chunk = self.config.n_routed_experts // self.expert_slice
     def forward(self, hidden_states):
         identity = hidden_states
         orig_shape = hidden_states.shape
         topk_idx, topk_weight = self.gate(hidden_states)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-        # Fix style error by commenting out unused flat_topk_idx variable in original code
-        # flat_topk_idx = topk_idx.view(-1)
-        if not self.training:
-            y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
-        if self.config.n_shared_experts is not None:
-            y = y + self.shared_experts(identity)
-        return y
+        # we cast back to the input dtype
+        topk_weight = topk_weight.to(hidden_states.dtype)
+        batch = orig_shape[0]
+        sequence_length = orig_shape[1]
+        hidden_dim = orig_shape[2]
+        if self.training:
+            padded_weights = torch.zeros(
+                (batch * sequence_length, self.config.n_routed_experts),
+                dtype=topk_weight.dtype,
+                device=topk_weight.device,
+            )
+            padded_weights.scatter_(-1, topk_idx, topk_weight)
+            padded_weights = padded_weights.reshape(-1, sequence_length, self.config.n_routed_experts)
+            padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1)
 
-    @torch.no_grad()
-    def moe_infer(self, x, topk_ids, topk_weight):
-        """
-        Rewrite DeepseekV3MoE.moe_infer: https://huggingface.co/deepseek-ai/DeepSeek-R1/resolve/main/modeling_deepseek.py
-        """
-        out = torch.zeros_like(x)
+            final_hidden_states = torch.zeros(
+                (batch, sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+            )
+            for i, expert in enumerate(self.experts):
+                current_hidden_state = expert(hidden_states)
+                current_padded_weight = padded_weights[i]
+                final_hidden_states = (
+                    final_hidden_states
+                    + current_hidden_state.reshape(-1, sequence_length, hidden_dim) * current_padded_weight
+                )
+            final_hidden_states = final_hidden_states.type(hidden_states.dtype)
+            final_hidden_states = final_hidden_states.view(*orig_shape)
+            # final_hidden_states = AddAuxiliaryLoss.apply(final_hidden_states, aux_loss)
+        else:
+            final_hidden_states = torch.zeros(
+                (batch * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+            )
+            for idx in range(self.expert_slice):
+                experts_range = range(self.expert_chunk)
+                gate_proj_list = [
+                    self.experts[idx * self.expert_chunk + i].gate_proj.weight.squeeze() for i in experts_range
+                ]
+                down_proj_list = [
+                    self.experts[idx * self.expert_chunk + i].down_proj.weight.squeeze() for i in experts_range
+                ]
+                up_proj_list = [
+                    self.experts[idx * self.expert_chunk + i].up_proj.weight.squeeze() for i in experts_range
+                ]
 
-        seq_len, hidden_dim = x.shape
-        num_experts = len(self.experts)
+                hidden_states_slice = torch.ops.hpu.mixture_of_experts(
+                    hidden_states=hidden_states,
+                    expert_routing_table=topk_idx,
+                    router_weights=topk_weight,
+                    w1=gate_proj_list,
+                    w2=up_proj_list,
+                    w3=down_proj_list,
+                    permuted_weights=True,
+                    activation="silu",
+                    experts_min=(self.expert_chunk * idx),
+                    experts_max=(self.expert_chunk * (idx + 1) - 1),
+                )
+                final_hidden_states = final_hidden_states + hidden_states_slice
+                htcore.mark_step()
 
-        padded_weights = torch.zeros((seq_len, num_experts), dtype=topk_weight.dtype, device=x.device)
-        padded_weights.scatter_(-1, topk_ids, topk_weight)
-        padded_weights = padded_weights.reshape(seq_len, num_experts)
-        padded_weights = padded_weights.permute(1, 0).unsqueeze(-1)
+            if is_deepspeed_available():
+                from deepspeed import comm as dist
 
-        # Loop over all available experts in the model and perform the computation on each expert
-        for i in range(self.experts_per_rank):
-            expert_idx = i + self.ep_rank * self.experts_per_rank
-            expert = self.experts[expert_idx]
-            padded_weight = padded_weights[expert_idx]
-            x_static = expert(x) * padded_weight
-            out += x_static
+                if dist.is_initialized():
+                    dist.all_reduce(final_hidden_states, op=dist.ReduceOp.SUM)
 
-        if self.ep_size > 1:
-            out = _all_reduce(out)
+            final_hidden_states = final_hidden_states.type(hidden_states.dtype)
+            final_hidden_states = final_hidden_states.reshape(-1, sequence_length, hidden_dim)
 
-        return out
+        if self.config.n_shared_experts is not None:
+            final_hidden_states = final_hidden_states + self.shared_experts(identity)
+
+        return final_hidden_states
 
 
 class Matmul(torch.nn.Module):
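For reference, the sketch below spells out in plain PyTorch what each per-slice torch.ops.hpu.mixture_of_experts call above is expected to compute, assuming every routed expert is the usual SwiGLU MLP (down_proj(silu(gate_proj(x)) * up_proj(x))) and that router weights are applied per selected expert. It is an illustrative reference written for this note under those assumptions, not code from the commit or from the Habana kernel.

import torch
import torch.nn.functional as F


def moe_slice_reference(hidden_states, topk_idx, topk_weight, gate_proj, up_proj, down_proj, experts_min, experts_max):
    """Dense reference for one expert slice.

    hidden_states: (num_tokens, hidden_dim)
    topk_idx, topk_weight: (num_tokens, top_k) routing table and router weights
    gate_proj / up_proj / down_proj: lists of per-expert weights for this slice,
        shaped (intermediate, hidden), (intermediate, hidden) and (hidden, intermediate)
    """
    out = torch.zeros_like(hidden_states)
    for e in range(experts_min, experts_max + 1):
        w_gate = gate_proj[e - experts_min]
        w_up = up_proj[e - experts_min]
        w_down = down_proj[e - experts_min]
        # Tokens whose routing table selected expert e, and the top-k slot it occupies.
        token_ids, k_slot = (topk_idx == e).nonzero(as_tuple=True)
        if token_ids.numel() == 0:
            continue
        x = hidden_states[token_ids]
        y = (F.silu(x @ w_gate.t()) * (x @ w_up.t())) @ w_down.t()
        out.index_add_(0, token_ids, y * topk_weight[token_ids, k_slot].unsqueeze(-1))
    return out

Summing this reference over all slices and then adding the shared-expert branch should match the output of the new inference path up to numerical differences.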
