Skip to content

Commit

Permalink
Merge pull request #1057 from AI-Hypercomputer:raymondzou-splash-dq
Browse files · Browse the repository at this point in the history
PiperOrigin-RevId: 699208949
  • Loading branch information
maxtext authors committed Nov 22, 2024
2 parents 1411510 + 7051d5a commit 968fc84
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions MaxText/layers/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,8 @@ def wrap_flash_attention(query, key, value, decoder_segment_ids):
block_q_dkv=min(global_block_q_dkv, query.shape[2]),
block_kv_dkv=min(global_block_kv_dkv, key.shape[2]),
block_kv_dkv_compute=min(global_block_kv_dkv_compute, query.shape[2]),
block_q_dq=min(global_block_q_dq, query.shape[2]),
block_kv_dq=min(global_block_kv_dq, query.shape[2]),
block_q_dq=None if global_use_fused_bwd_kernel else min(global_block_q_dq, query.shape[2]),
block_kv_dq=None if global_use_fused_bwd_kernel else min(global_block_kv_dq, query.shape[2]),
use_fused_bwd_kernel=global_use_fused_bwd_kernel,
q_layout=splash_attention_kernel.QKVLayout[global_q_layout],
k_layout=splash_attention_kernel.QKVLayout[global_k_layout],
Expand Down

0 comments on commit 968fc84

Please sign in to comment.