Merge pull request #880 from google:lizhiyu/switch_ep_axis
PiperOrigin-RevId: 673469152
maxtext authors committed Sep 11, 2024
2 parents d2c7a2e + b51797f commit c7c3f4e
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions MaxText/layers/linears.py
@@ -492,27 +492,27 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel):
       loss = self.load_balance_loss(top_k_indices, softmax_probs)
       inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed"))
       with jax.named_scope("dispatch"):
-        dispatch = self.get_einsum(rhs_mesh_axes=mask_axes)("BSM,BSEC -> BECM", inputs, dispatch_mask)
-        dispatch = nn.with_logical_constraint(dispatch, ("activation_batch_no_exp", "activation_exp", None, "activation_embed"))
+        dispatch = self.get_einsum(rhs_mesh_axes=mask_axes)("BSM,BSEC -> EBCM", inputs, dispatch_mask)
+        dispatch = nn.with_logical_constraint(dispatch, ("activation_exp", "activation_batch_no_exp", None, "activation_embed"))
       with jax.named_scope("wi_0"):
         w0_kernel_axes = ("exp", None, None)
         w0_kernel = nn.with_logical_constraint(w0_kernel, w0_kernel_axes)
-        layer_w0 = self.get_einsum(rhs_mesh_axes=w0_kernel_axes)("BECM,EMH -> BECH", dispatch, w0_kernel)
-        layer_w0 = nn.with_logical_constraint(layer_w0, ("activation_batch_no_exp", "activation_exp", None, "activation_mlp"))
+        layer_w0 = self.get_einsum(rhs_mesh_axes=w0_kernel_axes)("EBCM,EMH -> EBCH", dispatch, w0_kernel)
+        layer_w0 = nn.with_logical_constraint(layer_w0, ("activation_exp", "activation_batch_no_exp", None, "activation_mlp"))
       with jax.named_scope("wi_1"):
         w1_kernel_axes = ("exp", None, None)
         w1_kernel = nn.with_logical_constraint(w1_kernel, w1_kernel_axes)
-        layer_w1 = self.get_einsum(rhs_mesh_axes=w1_kernel_axes)("BECM,EMH -> BECH", dispatch, w1_kernel)
-        layer_w1 = nn.with_logical_constraint(layer_w1, ("activation_batch_no_exp", "activation_exp", None, "activation_mlp"))
+        layer_w1 = self.get_einsum(rhs_mesh_axes=w1_kernel_axes)("EBCM,EMH -> EBCH", dispatch, w1_kernel)
+        layer_w1 = nn.with_logical_constraint(layer_w1, ("activation_exp", "activation_batch_no_exp", None, "activation_mlp"))
       layer_w0_act = _convert_to_activation_function(self.config.mlp_activations[0])(layer_w0)
       layer_multiply = jnp.multiply(layer_w0_act, layer_w1)
       with jax.named_scope("wo"):
         wo_kernel_axes = ("exp", None, None)
         wo_kernel = nn.with_logical_constraint(wo_kernel, wo_kernel_axes)
-        intermediate_layer = self.get_einsum(rhs_mesh_axes=wo_kernel_axes)("BECH,EHM -> BECM", layer_multiply, wo_kernel)
-        intermediate_layer = nn.with_logical_constraint(intermediate_layer, ("activation_batch_no_exp", "activation_exp", None, "activation_embed"))
+        intermediate_layer = self.get_einsum(rhs_mesh_axes=wo_kernel_axes)("EBCH,EHM -> EBCM", layer_multiply, wo_kernel)
+        intermediate_layer = nn.with_logical_constraint(intermediate_layer, ("activation_exp", "activation_batch_no_exp", None, "activation_embed"))
       with jax.named_scope("combine"):
-        output = self.get_einsum(rhs_mesh_axes=mask_axes)("BECM,BSEC -> BSM", intermediate_layer, combine_mask)
+        output = self.get_einsum(rhs_mesh_axes=mask_axes)("EBCM,BSEC -> BSM", intermediate_layer, combine_mask)
       return output, loss
     else:
       weights = self.reshape_and_update_weights(top_k_weights, top_k_indices)
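Note: the diff only reorders einsum output axes, moving the expert axis (E) from second to first position ("BECM" to "EBCM"), and updates the logical constraints so "activation_exp" now annotates the leading dimension, matching the expert-parallel ("exp") leading axis of the kernels. The sketch below is not from the PR; it only illustrates that the new output spec is a pure axis permutation of the old one, so the computed values are unchanged. The shape names are assumptions: B=batch, S=sequence, E=experts, C=expert capacity, M=embed dimension.

    # Minimal sketch: "EBCM" is "BECM" with the batch and expert axes swapped.
    import jax
    import jax.numpy as jnp

    B, S, E, C, M = 2, 8, 4, 3, 16
    key_x, key_m = jax.random.split(jax.random.PRNGKey(0))
    inputs = jax.random.normal(key_x, (B, S, M))
    # Toy dispatch mask; the real mask is produced by the router, not sampled.
    dispatch_mask = jax.random.bernoulli(key_m, 0.25, (B, S, E, C)).astype(inputs.dtype)

    old = jnp.einsum("BSM,BSEC -> BECM", inputs, dispatch_mask)  # expert axis second
    new = jnp.einsum("BSM,BSEC -> EBCM", inputs, dispatch_mask)  # expert axis first

    # Same values, different layout.
    assert jnp.allclose(new, jnp.transpose(old, (1, 0, 2, 3)))
    print(old.shape, new.shape)  # (2, 4, 3, 16) (4, 2, 3, 16)

With the expert axis leading on the activations as well as the kernels, the per-expert dimension can be sharded consistently across the expert-parallel mesh axis, which appears to be the motivation for the switch.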
