Skip to content

Commit

Permalink
Update convolution benchmark to new syntax.
Browse files Browse the repository at this point in the history
  • Loading branch information
axch committed Apr 11, 2023
1 parent bbec9ad commit c0fa128
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 12 deletions.
30 changes: 18 additions & 12 deletions benchmarks/conv.dx
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ This computation is interesting because it occurs in the inner
loop of computing the Neural Tangent Kernel of a convolutional
layer.

def unsafe_from_integer {n} [Ix n] (i:Int) : n =
unsafe_from_ordinal _ $ unsafe_i_to_n i
def unsafe_from_integer(i:Int) -> n given (n|Ix) =
unsafe_from_ordinal $ unsafe_i_to_n i

def conv_1d {d1 d2} (kernel: (Fin d1)=>(Fin d2)=>Float)
(size: Nat) : (Fin d1)=>(Fin d2)=>Float =
def conv_1d(
kernel: (Fin d1)=>(Fin d2)=>Float,
size: Nat)
-> (Fin d1)=>(Fin d2)=>Float given (d1, d2) =
half_kernel_size = (f_to_i $ (n_to_f size) / 2.0)
for i j. sum for k: (Fin size).
i' = n_to_i $ ordinal i
Expand All @@ -22,14 +24,18 @@ def conv_1d {d1 d2} (kernel: (Fin d1)=>(Fin d2)=>Float)
j'' = j' + k' - half_kernel_size
if i'' < 0 || i'' >= (n_to_i d1) || j'' < 0 || j'' >= (n_to_i d2)
then 0
else kernel.(unsafe_from_integer i'').(unsafe_from_integer j'')

def conv {n c h w} (kernel: (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float)
(size: Int) : (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float =
for n' c'. conv_1d kernel.n'.c' (unsafe_i_to_n size)

def conv_spec {n c h w} (kernel: (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float)
(size: Int) : (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float =
else kernel[unsafe_from_integer i'', unsafe_from_integer j'']

def conv(
kernel: (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float,
size: Int)
-> (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float given (n, c, h, w) =
for n' c'. conv_1d(kernel[n', c'], unsafe_i_to_n(size))

def conv_spec(
kernel: (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float,
size: Int)
-> (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float given (n, c, h, w) =
if size == 3
then conv kernel 3
else conv kernel size
Expand Down
9 changes: 9 additions & 0 deletions lib/prelude.dx
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,15 @@ instance Ix((a, b, c)) given (a|Ix, b|Ix, c|Ix)
(i, (j, k)) = unsafe_from_ordinal o
(i, j, k)

instance Ix((a, b, c, d)) given (a|Ix, b|Ix, c|Ix, d|Ix)
def size'() = size a * size b * size c * size d
def ordinal(tup) =
(i, j, k, m) = tup
ordinal((i,(j,(k,m))))
def unsafe_from_ordinal(o) =
(i, (j, (k, m))) = unsafe_from_ordinal o
(i, j, k, m)

'## Vector spaces

interface VSpace(a|Add|Sub)
Expand Down

0 comments on commit c0fa128

Please sign in to comment.