Direct convolution #292
We have a time-domain implementation as part of …
Yeah, I 100% agree. My main motivation for adding overlap-save to conv was to pave the way to combine …
I'd really love to have a …
So here is a potential fallback implementation:

```julia
julia> function directconv(A::AbstractArray{<:Any,M}, B::AbstractArray{<:Any,N}) where {M,N}
           axes_A = ntuple(i -> axes(A, i), Val(max(M,N)))
           axes_B = ntuple(i -> axes(B, i), Val(max(M,N)))
           krange(i) = CartesianIndices((:).(max.(first.(axes_B), Tuple(i) .- last.(axes_A)), min.(last.(axes_B), Tuple(i) .- first.(axes_A))))
           return [sum(A[i-k]*B[k] for k in krange(i)) for i in CartesianIndices((:).(first.(axes_A).+first.(axes_B), last.(axes_A).+last.(axes_B)))]
       end
directconv (generic function with 1 method)

julia> A = rand(3,4); B = rand(5,6,7);

julia> @btime conv($A,$B);
  112.307 μs (178 allocations: 72.63 KiB)

julia> @btime directconv($A,$B);
  20.673 μs (885 allocations: 52.00 KiB)
```

From a brief glance, it's type-stable, and results agree with …
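The index arithmetic in `krange(i)` above is the crux: for each output index, only the kernel offsets for which both array accesses stay in bounds contribute. Here is a pure-Python sketch of the same idea for the 1-D, 1-based-free case (a hypothetical helper for illustration, not part of DSP.jl):

```python
def direct_conv(u, v):
    """Direct 'full' convolution: out[n] = sum_k u[n-k] * v[k],
    with k restricted so both u[n-k] and v[k] are in range --
    the role played by krange(i) in the Julia snippet above."""
    n_out = len(u) + len(v) - 1          # "full" convolution length
    out = [0] * n_out
    for n in range(n_out):
        # k must satisfy 0 <= k < len(v) and 0 <= n - k < len(u)
        k_lo = max(0, n - len(u) + 1)
        k_hi = min(len(v) - 1, n)
        out[n] = sum(u[n - k] * v[k] for k in range(k_lo, k_hi + 1))
    return out

print(direct_conv([1, 2, 3], [0, 1, 0.5]))  # → [0, 1, 2.5, 4, 1.5]
```

This gather formulation computes each output element independently, which is why the Julia version can be written as a single array comprehension.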
I figured out a fast way to do direct convolution that uses SIMD. It doesn't support weird indexing at the moment, or arrays with different numbers of dimensions, but it's fast.

```julia
julia> function _conv_kern_direct!(out, u, v)
           fill!(out, 0)
           u_region = CartesianIndices(u)
           v_region = CartesianIndices(v)
           one_index = oneunit(first(u_region))
           for vindex in v_region
               @simd for uindex in u_region
                   @inbounds out[uindex + vindex - one_index] += u[uindex] * v[vindex]
               end
           end
           out
       end

julia> function _conv_kern_direct(
           u::AbstractArray{T, N}, v::AbstractArray{S, N}, su, sv) where {T, S, N}
           sout = su .+ sv .- 1
           out = similar(u, promote_type(T, S), sout)
           _conv_kern_direct!(out, u, v)
       end
```
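The kernel above inverts the loop structure compared with `directconv`: instead of gathering contributions for each output element, it scatters each product `u[uindex] * v[vindex]` into `out[uindex + vindex - one_index]`, so the inner loop needs no bounds logic and can vectorize. A pure-Python sketch of that accumulation formulation in 1-D (illustrative only; the inner loop is what `@simd` vectorizes in the Julia version):

```python
def conv_accumulate(u, v):
    """Scatter-style direct convolution: out[i + j] += u[i] * v[j].
    No per-element bounds checks are needed, which is what makes the
    inner loop SIMD-friendly in the Julia kernel above."""
    out = [0] * (len(u) + len(v) - 1)
    for j, vj in enumerate(v):           # outer loop over the kernel
        for i, ui in enumerate(u):       # inner, vectorizable loop
            out[i + j] += ui * vj
    return out

print(conv_accumulate([1, 2, 3], [0, 1, 0.5]))  # → [0, 1, 2.5, 4, 1.5]
```

Both formulations compute the same "full" convolution; they differ only in the traversal order and therefore in how the hardware can parallelize them.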
```julia
julia> sa = (2,3,4); sb = (5,6,7); a = rand(sa...); b = rand(sb...);

julia> @benchmark conv($b, $a)
BenchmarkTools.Trial:
  memory estimate:  27.78 KiB
  allocs estimate:  164
  --------------
  minimum time:     86.637 μs (0.00% GC)
  median time:      99.368 μs (0.00% GC)
  mean time:        143.282 μs (7.74% GC)
  maximum time:     67.902 ms (95.95% GC)
  --------------
  samples:          10000
  evals/sample:     1

julia> @benchmark directconv($b, $a)
BenchmarkTools.Trial:
  memory estimate:  56.52 KiB
  allocs estimate:  963
  --------------
  minimum time:     25.700 μs (0.00% GC)
  median time:      27.732 μs (0.00% GC)
  mean time:        39.263 μs (22.48% GC)
  maximum time:     5.374 ms (99.37% GC)
  --------------
  samples:          10000
  evals/sample:     1

julia> @benchmark _conv_kern_direct($b, $a, $sb, $sa)
BenchmarkTools.Trial:
  memory estimate:  3.88 KiB
  allocs estimate:  1
  --------------
  minimum time:     5.357 μs (0.00% GC)
  median time:      5.909 μs (0.00% GC)
  mean time:        8.284 μs (18.72% GC)
  maximum time:     11.008 ms (99.89% GC)
  --------------
  samples:          10000
  evals/sample:     6
```

It's also competitive with … for larger arrays:

```julia
julia> sa = (1000, 1000, 3); sb = (3, 3, 3); a = rand(sa...); b = rand(sb...);

julia> @btime conv($a, $b);
  135.281 ms (165 allocations: 38.32 MiB)

julia> @btime directconv($a, $b);
  498.466 ms (10040044 allocations: 574.50 MiB)

julia> @btime _conv_kern_direct($a, $b, $sa, $sb);
  146.988 ms (2 allocations: 38.30 MiB)
```

This is also with the overlap-save version of ….

```julia
julia> using DSP: _conv_kern_fft!, nextfastfft

julia> sout = sa .+ sb .- 1; out = zeros(sout); nffts = nextfastfft(sout);

julia> @btime _conv_kern_fft!($out, $a, $b, $sa, $sb, $sout, $nffts);
  348.628 ms (172 allocations: 232.88 MiB)
```

Because it's using SIMD, you get a ~2x speedup for single-precision floats, and a ~4x speedup for Int16 etc. Also with the same test data:

```julia
julia> using ImageFiltering

julia> @btime imfilter($a, ($b,));
  171.826 ms (31 allocations: 68.96 MiB)
```
It would be great if we had a direct convolution kernel, which would probably be faster for small convolutions.
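The intuition for "faster for small convolutions" is a simple operation count: direct convolution of a length-N signal with a length-K kernel costs about N·K multiply-adds, while an FFT-based convolution costs on the order of C·(N+K)·log2(N+K) for some implementation-dependent constant C. A rough, illustrative comparison (the constant `c=3` below is an assumption, not a measurement):

```python
import math

def direct_ops(n, k):
    """Approximate multiply-add count for direct convolution."""
    return n * k

def fft_ops(n, k, c=3):
    """Approximate cost of FFT-based convolution; c is an assumed
    per-implementation constant covering the forward/inverse FFTs."""
    m = n + k - 1                        # "full" convolution length
    return c * m * math.log2(m)

# For a small kernel, direct is cheaper; for a large one, FFT wins.
print(direct_ops(1000, 3) < fft_ops(1000, 3))      # → True
print(direct_ops(1000, 500) > fft_ops(1000, 500))  # → True
```

This is why a direct kernel makes sense as a dispatch target for small kernel sizes, with the FFT (or overlap-save) path kept for large ones.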