Skip to content

Commit

Permalink
Revert tiled matmul implementation because it triggers a segfault whe…
Browse files Browse the repository at this point in the history
…n running benchmarks.
  • Loading branch information
axch committed Apr 11, 2023
1 parent 1bac834 commit bbec9ad
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
11 changes: 9 additions & 2 deletions lib/prelude.dx
Original file line number Diff line number Diff line change
Expand Up @@ -2570,8 +2570,7 @@ def tile(
body (FullTileIx(n, tile_size, tile_ix'))
body (CodaIx(n, coda_offset, coda_size))

-- matmul. Better symbol to use? `@`?
def (**)(
def tiled_matmul(
x: l=>m=>Float,
y: m=>n=>Float
) -> l=>n=>Float given (l|Ix, m|Ix, n|Ix) =
Expand All @@ -2591,6 +2590,14 @@ def (**)(
m_ix = inject m_offset
result!l_ix!n_ix += x[l_ix,m_ix] * y[m_ix,n_ix]

-- matmul. Better symbol to use? `@`?
def (**)(
x: l=>m=>Float,
y: m=>n=>Float
) -> l=>n=>Float given (l|Ix, m|Ix, n|Ix) =
-- TODO(https://github.com/google-research/dex-lang/issues/1212) Replace with tiled_matmul.
naive_matmul(x, y)

def (**.)(mat: n=>m=>Float, v: m=>Float) -> (n=>Float) given (n|Ix, m|Ix) =
for i. vdot(mat[i], v)
def(.**)(v: n=>Float, mat: n=>m=>Float) -> (m=>Float) given (n|Ix, m|Ix) =
Expand Down
2 changes: 1 addition & 1 deletion tests/linalg-tests.dx
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import linalg
-- Check that the optimized matmul gives the same answers as the naive one
amat = for i:(Fin 100) j:(Fin 100). n_to_f $ ordinal (i, j)

:p amat ** amat ~~ naive_matmul amat amat
:p tiled_matmul(amat, amat) ~~ naive_matmul amat amat
> True

-- Check that the inverse of the inverse is identity.
Expand Down

0 comments on commit bbec9ad

Please sign in to comment.