Skip to content

Commit

Permalink
[Fix][KVCache] Fix incorrect tile size calculation (#17595)
Browse files Browse the repository at this point in the history
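This PR fixes the tile size calculation in the TIR attention
kernels, where the computed tile sizes may not divide the total
loop extent. The search loop previously checked only that `tile_y`
divides both `cnt` and `y`; it now also requires that
`tile_x = cnt // tile_y` divides `x`.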
  • Loading branch information
MasterJH5574 authored Jan 19, 2025
1 parent da2e89a commit 077e8eb
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions python/tvm/relax/frontend/nn/llm/kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,7 +855,7 @@ def get_tile_size(x, y, t):
cnt = (x * y) // t
assert (x * y) % t == 0
tile_y = (int)(math.ceil(math.sqrt(cnt)))
- while (cnt % tile_y != 0 or y % tile_y != 0) and tile_y <= cnt:
+ while (cnt % tile_y != 0 or y % tile_y != 0 or x % (cnt // tile_y) != 0) and tile_y <= cnt:
tile_y += 1
assert tile_y <= cnt
tile_x = cnt // tile_y
Expand Down Expand Up @@ -1509,7 +1509,7 @@ def get_tile_size(x, y, t):
cnt = (x * y) // t
assert (x * y) % t == 0
tile_y = (int)(math.ceil(math.sqrt(cnt)))
- while (cnt % tile_y != 0 or y % tile_y != 0) and tile_y <= cnt:
+ while (cnt % tile_y != 0 or y % tile_y != 0 or x % (cnt // tile_y) != 0) and tile_y <= cnt:
tile_y += 1
assert tile_y <= cnt
tile_x = cnt // tile_y
Expand Down Expand Up @@ -1867,7 +1867,7 @@ def get_tile_size(x, y, t):
cnt = (x * y) // t
assert (x * y) % t == 0
tile_y = (int)(math.ceil(math.sqrt(cnt)))
- while (cnt % tile_y != 0 or y % tile_y != 0) and tile_y <= cnt:
+ while (cnt % tile_y != 0 or y % tile_y != 0 or x % (cnt // tile_y) != 0) and tile_y <= cnt:
tile_y += 1
assert tile_y <= cnt
tile_x = cnt // tile_y
Expand Down

0 comments on commit 077e8eb

Please sign in to comment.