'# Diagonal convolution

'This computes a diagonally-indexed summation:
```
result.i.j = input.(i-1).(j-1) + input.i.j + input.(i+1).(j+1)
```
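As a cross-check of the index pattern, here is a minimal pure-Python sketch of
the same sum. It assumes out-of-range entries count as zero and that the window
is centred by truncating `size / 2`; `diagonal_sum` and `example` are
hypothetical names used only for illustration and are not part of this file.
```python
# Hypothetical reference helper: sums `size` entries along the main diagonal
# direction around (i, j), skipping entries that fall outside the array.
def diagonal_sum(inp, size=3):
    h, w = len(inp), len(inp[0])
    half = size // 2  # offset that centres the window on (i, j)
    out = [[0.0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            for k in range(size):
                ii, jj = i + k - half, j + k - half
                if 0 <= ii < h and 0 <= jj < w:
                    out[i][j] += inp[ii][jj]
    return out
# Small worked input: only (0, 0) and (1, 1) have an in-range diagonal neighbour.
example = [[1.0, 2.0], [3.0, 4.0]]
result = diagonal_sum(example)  # [[5.0, 2.0], [3.0, 5.0]]
```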
This computation is interesting because it occurs in the inner
loop of computing the Neural Tangent Kernel of a convolutional
layer.

-- Converts an Int to an index of type `n` without any bounds check.
def unsafe_from_integer(i:Int) -> n given (n|Ix) =
  unsafe_from_ordinal $ unsafe_i_to_n i

-- Diagonal sum over one 2D slice. `kernel` is the 2D input array and `size`
-- is the number of diagonal entries summed. Out-of-range entries count as 0.
def conv_1d(
    kernel: (Fin d1)=>(Fin d2)=>Float,
    size: Nat)
    -> (Fin d1)=>(Fin d2)=>Float given (d1, d2) =
  half_kernel_size = (f_to_i $ (n_to_f size) / 2.0)
  for i j. sum for k: (Fin size).
    i' = n_to_i $ ordinal i
    j' = n_to_i $ ordinal j
    k' = n_to_i $ ordinal k
    i'' = i' + k' - half_kernel_size
    j'' = j' + k' - half_kernel_size
    if i'' < 0 || i'' >= (n_to_i d1) || j'' < 0 || j'' >= (n_to_i d2)
      then 0
      else kernel[unsafe_from_integer i'', unsafe_from_integer j'']

-- Applies conv_1d independently to each (n, c) slice of a 4D array.
def conv(
    kernel: (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float,
    size: Int)
    -> (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float given (n, c, h, w) =
  for n' c'. conv_1d(kernel[n', c'], unsafe_i_to_n(size))

-- Same as conv, but passes the literal 3 when size == 3.
def conv_spec(
    kernel: (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float,
    size: Int)
    -> (Fin n)=>(Fin c)=>(Fin h)=>(Fin w)=>Float given (n, c, h, w) =
  if size == 3
    then conv kernel 3
    else conv kernel size

'We benchmark it on a roughly representative input.

width = 3
side = 32
n = 100

x1 = for i:(Fin n) m:(Fin width) j:(Fin side) k:(Fin side).
  randn (ixkey (new_key 0) (i, m, j, k))

:t x1

filter_size = +3

%bench "Diagonal convolution"
res = conv x1 filter_size

:t res