From 9b5eeb0ba7ab62649d5264e8f25d467970cbecbd Mon Sep 17 00:00:00 2001 From: Alexey Radul Date: Sat, 25 Mar 2023 17:17:31 -0400 Subject: [PATCH] Update benchmarks to new syntax, except jvp_matmul segfaults now? --- benchmarks/jvp_matmul.dx | 10 +++++----- benchmarks/matmul_small.dx | 6 +++--- benchmarks/matvec_small.dx | 6 +++--- benchmarks/poly.dx | 2 +- benchmarks/vjp_matmul.dx | 6 +++--- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/benchmarks/jvp_matmul.dx b/benchmarks/jvp_matmul.dx index acdb007ea..ef12a194d 100644 --- a/benchmarks/jvp_matmul.dx +++ b/benchmarks/jvp_matmul.dx @@ -1,13 +1,13 @@ n = if dex_test_mode() then 10 else 500 -m1 = rand_mat n n randn (new_key 0) -m2 = rand_mat n n randn (new_key 1) +m1 = rand_mat(n, n, randn, new_key 0) +m2 = rand_mat(n, n, randn, new_key 1) -def mmp' {l m n} (m1:l=>m=>Float) (m2:m=>n=>Float) : l=>n=>Float = - jvp ((**) m1) m2 m2 +def mmp'(m1:l=>m=>Float, m2:m=>n=>Float) -> l=>n=>Float given (l|Ix, m|Ix, n|Ix) = + jvp (\m. m1 ** m) m2 m2 %bench "jvp_matmul" -res = mmp' m1 m2 +res = mmp'(m1, m2) > > jvp_matmul > Compile time: 82.255 ms diff --git a/benchmarks/matmul_small.dx b/benchmarks/matmul_small.dx index 1477ca86d..a9418be93 100644 --- a/benchmarks/matmul_small.dx +++ b/benchmarks/matmul_small.dx @@ -1,11 +1,11 @@ n = 10 width = 1000 -m1 = for i:(Fin width). rand_mat n n randn (new_key 0) -m2 = for i:(Fin width). rand_mat n n randn (new_key 1) +m1 = for i:(Fin width). rand_mat(n, n, randn, new_key 0) +m2 = for i:(Fin width). rand_mat(n, n, randn, new_key 1) %bench "matmul_small" -res = for i. (m1.i ** m2.i) +res = for i. (m1[i] ** m2[i]) > > matmul_small > Compile time: 33.241 ms diff --git a/benchmarks/matvec_small.dx b/benchmarks/matvec_small.dx index b7a6d76e0..be7ac4a6e 100644 --- a/benchmarks/matvec_small.dx +++ b/benchmarks/matvec_small.dx @@ -1,11 +1,11 @@ n = 10 width = 10000 -ms = for i:(Fin width). rand_mat n n randn (new_key 0) -vs = for i:(Fin width). rand_vec n randn (new_key 1) +ms = for i:(Fin width). rand_mat(n, n, randn, new_key 0) +vs = for i:(Fin width). rand_vec(n, randn, new_key 1) %bench "matvec_small" -res = for i. ms.i **. vs.i +res = for i. ms[i] **. vs[i] > > matvec_small > Compile time: 29.506 ms diff --git a/benchmarks/poly.dx b/benchmarks/poly.dx index 68bc573cb..2c2ba8cc3 100644 --- a/benchmarks/poly.dx +++ b/benchmarks/poly.dx @@ -3,7 +3,7 @@ n = if dex_test_mode() then 1000 else 100000 a = for i:(Fin n). n_to_f $ ordinal i %bench "poly" -res = for i. evalpoly [0.0, 1.0, 2.0, 3.0, 4.0] a.i +res = for i. evalpoly [0.0, 1.0, 2.0, 3.0, 4.0] a[i] > > poly > Compile time: 44.950 ms diff --git a/benchmarks/vjp_matmul.dx b/benchmarks/vjp_matmul.dx index f56d620f1..43d606c17 100644 --- a/benchmarks/vjp_matmul.dx +++ b/benchmarks/vjp_matmul.dx @@ -3,11 +3,11 @@ n = if dex_test_mode() then 10 else 500 m1 = rand_mat n n randn (new_key 0) m2 = rand_mat n n randn (new_key 1) -def mmp' {l m n} (m1:l=>m=>Float) (m2:m=>n=>Float) : l=>n=>Float = - snd (vjp ((**) (transpose m1)) (for _ _. 0.0)) m2 +def mmp'(m1:l=>m=>Float, m2:m=>n=>Float) -> l=>n=>Float given (l|Ix, m|Ix, n|Ix) = + snd(vjp(\m. transpose(m1) ** m, for _ _. 0.0))(m2) %bench "vjp_matmul" -res = mmp' m1 m2 +res = mmp'(m1, m2) > > vjp_matmul > Compile time: 130.231 ms