Skip to content

Commit

Permalink
Update benchmarks to new syntax, except jvp_matmul segfaults now?
Browse files Browse the repository at this point in the history
  • Loading branch information
axch committed Mar 25, 2023
1 parent 0d845be commit 9b5eeb0
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 15 deletions.
10 changes: 5 additions & 5 deletions benchmarks/jvp_matmul.dx
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
n = if dex_test_mode() then 10 else 500

m1 = rand_mat n n randn (new_key 0)
m2 = rand_mat n n randn (new_key 1)
m1 = rand_mat(n, n, randn, new_key 0)
m2 = rand_mat(n, n, randn, new_key 1)

def mmp' {l m n} (m1:l=>m=>Float) (m2:m=>n=>Float) : l=>n=>Float =
jvp ((**) m1) m2 m2
def mmp'(m1:l=>m=>Float, m2:m=>n=>Float) -> l=>n=>Float given (l|Ix, m|Ix, n|Ix) =
jvp (\m. m1 ** m) m2 m2

%bench "jvp_matmul"
res = mmp' m1 m2
res = mmp'(m1, m2)
>
> jvp_matmul
> Compile time: 82.255 ms
Expand Down
6 changes: 3 additions & 3 deletions benchmarks/matmul_small.dx
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
n = 10
width = 1000

m1 = for i:(Fin width). rand_mat n n randn (new_key 0)
m2 = for i:(Fin width). rand_mat n n randn (new_key 1)
m1 = for i:(Fin width). rand_mat(n, n, randn, new_key 0)
m2 = for i:(Fin width). rand_mat(n, n, randn, new_key 1)

%bench "matmul_small"
res = for i. (m1.i ** m2.i)
res = for i. (m1[i] ** m2[i])
>
> matmul_small
> Compile time: 33.241 ms
Expand Down
6 changes: 3 additions & 3 deletions benchmarks/matvec_small.dx
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
n = 10
width = 10000

ms = for i:(Fin width). rand_mat n n randn (new_key 0)
vs = for i:(Fin width). rand_vec n randn (new_key 1)
ms = for i:(Fin width). rand_mat(n, n, randn, new_key 0)
vs = for i:(Fin width). rand_vec(n, randn, new_key 1)

%bench "matvec_small"
res = for i. ms.i **. vs.i
res = for i. ms[i] **. vs[i]
>
> matvec_small
> Compile time: 29.506 ms
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/poly.dx
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ n = if dex_test_mode() then 1000 else 100000
a = for i:(Fin n). n_to_f $ ordinal i

%bench "poly"
res = for i. evalpoly [0.0, 1.0, 2.0, 3.0, 4.0] a.i
res = for i. evalpoly [0.0, 1.0, 2.0, 3.0, 4.0] a[i]
>
> poly
> Compile time: 44.950 ms
Expand Down
6 changes: 3 additions & 3 deletions benchmarks/vjp_matmul.dx
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ n = if dex_test_mode() then 10 else 500
m1 = rand_mat n n randn (new_key 0)
m2 = rand_mat n n randn (new_key 1)

def mmp' {l m n} (m1:l=>m=>Float) (m2:m=>n=>Float) : l=>n=>Float =
snd (vjp ((**) (transpose m1)) (for _ _. 0.0)) m2
def mmp'(m1:l=>m=>Float, m2:m=>n=>Float) -> l=>n=>Float given (l|Ix, m|Ix, n|Ix) =
snd(vjp(\m. transpose(m1) ** m, for _ _. 0.0))(m2)

%bench "vjp_matmul"
res = mmp' m1 m2
res = mmp'(m1, m2)
>
> vjp_matmul
> Compile time: 130.231 ms
Expand Down

0 comments on commit 9b5eeb0

Please sign in to comment.