diff --git a/optimum/habana/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/optimum/habana/transformers/models/deepseek_v3/modeling_deepseek_v3.py
index 78b310c15f..6eaa02c1d8 100644
--- a/optimum/habana/transformers/models/deepseek_v3/modeling_deepseek_v3.py
+++ b/optimum/habana/transformers/models/deepseek_v3/modeling_deepseek_v3.py
@@ -43,6 +43,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache
+from transformers.integrations.deepspeed import is_deepspeed_available
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -68,6 +69,8 @@
 _CONFIG_FOR_DOC = "DeepseekV3Config"
 
+# Default number of experts per slice for dynamic MoE.
+SLICE_MAX_EXPERT = 80
 
 
 try:
     from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
@@ -493,46 +496,85 @@ def __init__(self, config):
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             self.shared_experts = DeepseekV3MLP(config=config, intermediate_size=intermediate_size)
+        # Split the routed experts into slices of at most SLICE_MAX_EXPERT experts each.
+        self.expert_slice = math.ceil(config.n_routed_experts / SLICE_MAX_EXPERT)
+        self.expert_chunk = self.config.n_routed_experts // self.expert_slice
 
     def forward(self, hidden_states):
         identity = hidden_states
         orig_shape = hidden_states.shape
         topk_idx, topk_weight = self.gate(hidden_states)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-        # Fix style error by commenting out unused flat_topk_idx variable in original code
-        # flat_topk_idx = topk_idx.view(-1)
-        if not self.training:
-            y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
-            if self.config.n_shared_experts is not None:
-                y = y + self.shared_experts(identity)
-            return y
+        # Cast the router weights back to the input dtype.
+        topk_weight = topk_weight.to(hidden_states.dtype)
+        batch = orig_shape[0]
+        sequence_length = orig_shape[1]
+        hidden_dim = orig_shape[2]
+        if self.training:
+            # Dense path: scatter the top-k router weights into a [tokens, n_routed_experts]
+            # matrix so every expert can run on all tokens and be masked afterwards.
+            padded_weights = torch.zeros(
+                (batch * sequence_length, self.config.n_routed_experts),
+                dtype=topk_weight.dtype,
+                device=topk_weight.device,
+            )
+            padded_weights.scatter_(-1, topk_idx, topk_weight)
+            padded_weights = padded_weights.reshape(-1, sequence_length, self.config.n_routed_experts)
+            padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1)
 
-    @torch.no_grad()
-    def moe_infer(self, x, topk_ids, topk_weight):
-        """
-        Rewrite DeepseekV3MoE.moe_infer: https://huggingface.co/deepseek-ai/DeepSeek-R1/resolve/main/modeling_deepseek.py
-        """
-        out = torch.zeros_like(x)
+            final_hidden_states = torch.zeros(
+                (batch, sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+            )
+            for i, expert in enumerate(self.experts):
+                current_hidden_state = expert(hidden_states)
+                current_padded_weight = padded_weights[i]
+                final_hidden_states = (
+                    final_hidden_states
+                    + current_hidden_state.reshape(-1, sequence_length, hidden_dim) * current_padded_weight
+                )
+            final_hidden_states = final_hidden_states.type(hidden_states.dtype)
+            final_hidden_states = final_hidden_states.view(*orig_shape)
+            # final_hidden_states = AddAuxiliaryLoss.apply(final_hidden_states, aux_loss)
+        else:
+            final_hidden_states = torch.zeros(
+                (batch * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+            )
+            # Inference path: process the routed experts slice by slice with the fused HPU kernel.
+            for idx in range(self.expert_slice):
+                experts_range = range(self.expert_chunk)
+                gate_proj_list = [
+                    self.experts[idx * self.expert_chunk + i].gate_proj.weight.squeeze() for i in experts_range
+                ]
+                down_proj_list = [
+                    self.experts[idx * self.expert_chunk + i].down_proj.weight.squeeze() for i in experts_range
+                ]
+                up_proj_list = [
+                    self.experts[idx * self.expert_chunk + i].up_proj.weight.squeeze() for i in experts_range
+                ]
 
-        seq_len, hidden_dim = x.shape
-        num_experts = len(self.experts)
+                # Fused MoE kernel covering experts [experts_min, experts_max] of this slice.
+                hidden_states_slice = torch.ops.hpu.mixture_of_experts(
+                    hidden_states=hidden_states,
+                    expert_routing_table=topk_idx,
+                    router_weights=topk_weight,
+                    w1=gate_proj_list,
+                    w2=up_proj_list,
+                    w3=down_proj_list,
+                    permuted_weights=True,
+                    activation="silu",
+                    experts_min=(self.expert_chunk * idx),
+                    experts_max=(self.expert_chunk * (idx + 1) - 1),
+                )
+                final_hidden_states = final_hidden_states + hidden_states_slice
+                # Break the lazy-mode graph between slices.
+                htcore.mark_step()
 
-        padded_weights = torch.zeros((seq_len, num_experts), dtype=topk_weight.dtype, device=x.device)
-        padded_weights.scatter_(-1, topk_ids, topk_weight)
-        padded_weights = padded_weights.reshape(seq_len, num_experts)
-        padded_weights = padded_weights.permute(1, 0).unsqueeze(-1)
+            if is_deepspeed_available():
+                from deepspeed import comm as dist
 
-        # Loop over all available experts in the model and perform the computation on each expert
-        for i in range(self.experts_per_rank):
-            expert_idx = i + self.ep_rank * self.experts_per_rank
-            expert = self.experts[expert_idx]
-            padded_weight = padded_weights[expert_idx]
-            x_static = expert(x) * padded_weight
-            out += x_static
+                # Sum the partial expert outputs across ranks when experts are sharded.
+                if dist.is_initialized():
+                    dist.all_reduce(final_hidden_states, op=dist.ReduceOp.SUM)
 
-        if self.ep_size > 1:
-            out = _all_reduce(out)
+            final_hidden_states = final_hidden_states.type(hidden_states.dtype)
+            final_hidden_states = final_hidden_states.reshape(-1, sequence_length, hidden_dim)
+
+        if self.config.n_shared_experts is not None:
+            final_hidden_states = final_hidden_states + self.shared_experts(identity)
 
-        return out
+        return final_hidden_states
 
 
 class Matmul(torch.nn.Module):
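Note on the slicing arithmetic: it caps how many experts a single torch.ops.hpu.mixture_of_experts call covers. A minimal standalone sketch of how expert_slice and expert_chunk partition the routed experts into the experts_min/experts_max ranges passed to the kernel (the slice_bounds helper is hypothetical, not part of the patch):

import math

SLICE_MAX_EXPERT = 80  # per-slice expert cap, as defined in the patch

def slice_bounds(n_routed_experts):
    # Mirror the patch's arithmetic: number of slices, then experts per slice.
    expert_slice = math.ceil(n_routed_experts / SLICE_MAX_EXPERT)
    expert_chunk = n_routed_experts // expert_slice
    # Inclusive (experts_min, experts_max) range handled by each kernel call.
    return [(idx * expert_chunk, (idx + 1) * expert_chunk - 1) for idx in range(expert_slice)]

# DeepSeek-V3 routes over 256 experts -> 4 slices of 64 experts each.
print(slice_bounds(256))  # [(0, 63), (64, 127), (128, 191), (192, 255)]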
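For the training branch, the scatter_ trick turns the sparse top-k routing into a dense per-expert weight tensor that broadcasts against each expert's output. A toy-sized sketch of the same shape flow (the batch/sequence/expert sizes below are made up for illustration):

import torch

# Toy sizes for illustration only (not DeepSeek-V3's real dimensions).
batch, sequence_length, n_experts, top_k = 2, 3, 8, 2
topk_idx = torch.randint(n_experts, (batch * sequence_length, top_k))
topk_weight = torch.rand(batch * sequence_length, top_k)

# Same scatter as the training branch: dense [tokens, n_experts] weights,
# zero wherever a token was not routed to an expert.
padded_weights = torch.zeros(batch * sequence_length, n_experts)
padded_weights.scatter_(-1, topk_idx, topk_weight)

# Rearranged to [n_experts, batch, sequence_length, 1], padded_weights[i]
# scales expert i's output over every token in one broadcast multiply.
padded_weights = padded_weights.reshape(-1, sequence_length, n_experts).permute(2, 0, 1).unsqueeze(-1)
print(padded_weights.shape)  # torch.Size([8, 2, 3, 1])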