diff --git a/benchmarks/conv.dx b/benchmarks/conv.dx index 51d8ffd6a..33e528805 100644 --- a/benchmarks/conv.dx +++ b/benchmarks/conv.dx @@ -8,11 +8,13 @@ This computation is interesting because it occurs in the inner loop of computing the Neural Tangent Kernel of a convolutional layer. -def unsafe_from_integer {n} [Ix n] (i:Int) : n = - unsafe_from_ordinal _ $ unsafe_i_to_n i +def unsafe_from_integer(i:Int) -> n given (n|Ix) = + unsafe_from_ordinal $ unsafe_i_to_n i -def conv_1d {d1 d2} (kernel: (Fin d1)=>(Fin d2)=>Float) - (size: Nat) : (Fin d1)=>(Fin d2)=>Float = +def conv_1d( + kernel: (Fin d1)=>(Fin d2)=>Float, + size: Nat) + -> (Fin d1)=>(Fin d2)=>Float given (d1, d2) = half_kernel_size = (f_to_i $ (n_to_f size) / 2.0) for i j. sum for k: (Fin size). i' = n_to_i $ ordinal i @@ -22,14 +24,18 @@ def conv_1d {d1 d2} (kernel: (Fin d1)=>(Fin d2)=>Float) j'' = j' + k' - half_kernel_size if i'' < 0 || i'' >= (n_to_i d1) || j'' < 0 || j'' >= (n_to_i d2) then 0 - else kernel.(unsafe_from_integer i'').(unsafe_from_integer j'') - -def conv {n c h w} (kernel: (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float) - (size: Int) : (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float = - for n' c'. conv_1d kernel.n'.c' (unsafe_i_to_n size) - -def conv_spec {n c h w} (kernel: (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float) - (size: Int) : (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float = + else kernel[unsafe_from_integer i'', unsafe_from_integer j''] + +def conv( + kernel: (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float, + size: Int) + -> (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float given (n, c, h, w) = + for n' c'. conv_1d(kernel[n', c'], unsafe_i_to_n(size)) + +def conv_spec( + kernel: (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float, + size: Int) + -> (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float given (n, c, h, w) = if size == 3 then conv kernel 3 else conv kernel size diff --git a/lib/prelude.dx b/lib/prelude.dx index 88225fae6..fc13b7a95 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -449,6 +449,15 @@ instance Ix((a, b, c)) given (a|Ix, b|Ix, c|Ix) (i, (j, k)) = unsafe_from_ordinal o (i, j, k) +instance Ix((a, b, c, d)) given (a|Ix, b|Ix, c|Ix, d|Ix) + def size'() = size a * size b * size c * size d + def ordinal(tup) = + (i, j, k, m) = tup + ordinal((i,(j,(k,m)))) + def unsafe_from_ordinal(o) = + (i, (j, (k, m))) = unsafe_from_ordinal o + (i, j, k, m) + '## Vector spaces interface VSpace(a|Add|Sub)