diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py
index 042d06a53..fb6ec0f2f 100644
--- a/MaxText/layers/attentions.py
+++ b/MaxText/layers/attentions.py
@@ -404,11 +404,11 @@ def cudnn_flash_attention(
     sliding_window_size = self.sliding_window_size
     if self.attention_type == AttentionType.LOCAL_SLIDING:
       sliding_window_size = [self.sliding_window_size, 0]
-      mask_type = "causal" # SWA only works with causal masking
+      mask_type = "causal"  # SWA only works with causal masking
       attn_mask = None
     else:
       # generate attn_mask
-      mask_type = "padding_causal" # only padding_causal mask type can take a created mask
+      mask_type = "padding_causal"  # only padding_causal mask type can take a created mask
       attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)

     dpa_layer = DotProductAttention(