diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py index 639800b44..79ad54097 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py @@ -70,7 +70,7 @@ class Encoder1DBlock(nn.Module): def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: if not self.use_post_layer_norm: y = nn.LayerNorm(name='LayerNorm_0')(x) - y = nn.SelfAttention( + y = nn.MultiHeadDotProductAttention( num_heads=self.num_heads, kernel_init=nn.initializers.xavier_uniform(), deterministic=train, @@ -89,7 +89,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: x = x + y else: y = x - y = nn.SelfAttention( + y = nn.MultiHeadDotProductAttention( num_heads=self.num_heads, kernel_init=nn.initializers.xavier_uniform(), deterministic=train, diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py index cb6287c5e..85a8d1bb7 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py @@ -396,10 +396,9 @@ def __call__(self, inputs, paddings, train): mask_paddings > 0, mask_paddings > 0, dtype=jnp.float32) inputs = LayerNorm(dim=config.encoder_dim)(inputs) - attention_fn = functools.partial( dot_product_attention, temperature=config.attention_temperature) - result = nn.SelfAttention( + result = nn.MultiHeadDotProductAttention( num_heads=config.num_attention_heads, qkv_features=config.encoder_dim, decode=False, @@ -410,7 +409,8 @@ def __call__(self, inputs, paddings, train): broadcast_dropout=False, attention_fn=attention_fn, dropout_rate=config.attention_dropout_rate, - deterministic=not train)(inputs, attention_mask) + deterministic=not train)( + inputs_q=inputs, mask=attention_mask) if config.attention_residual_dropout_rate is None: attention_residual_dropout_rate = 0.1 diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index 05faf1135..f546ef785 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -227,11 +227,12 @@ def ctc_loss(self, labels: spec.Tensor, label_paddings: spec.Tensor, blank_id: int = 0) -> spec.Tensor: - return optax.ctc_loss(logits, - logit_paddings, - labels, - label_paddings, - blank_id) + return optax.ctc_loss( + logits=logits, + logit_paddings=logit_paddings, + labels=labels, + label_paddings=label_paddings, + blank_id=blank_id) # Adapted from lingvo's greedy decoding logic here: # https://github.com/tensorflow/lingvo/blob/2ee26814c57b7dcead3f0382170f2f3da006f810/lingvo/jax/layers/ctc_objectives.py#L138.