From e4a00b1d410a78288d4c3a4fd19202a78536458b Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Thu, 13 Apr 2023 00:40:35 +0530 Subject: [PATCH] reduce allocations in dims_howmany (#269) * reduce allocations in dims_howmany * Update src/fft.jl Co-authored-by: Steven G. Johnson * Dont collect size tuple * filter for Int/Tuple regions * use tuple instead of vector region at more places * remove unused methods * test region collections * bump version to v1.7.0 --------- Co-authored-by: Steven G. Johnson --- Project.toml | 2 +- src/fft.jl | 87 ++++++++++++++++++++++++++++++++++-------------- test/runtests.jl | 3 ++ 3 files changed, 66 insertions(+), 26 deletions(-) diff --git a/Project.toml b/Project.toml index 07d6f52..43527d1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FFTW" uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" -version = "1.6.1" +version = "1.7.0" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" diff --git a/src/fft.jl b/src/fft.jl index 2567133..daa5866 100644 --- a/src/fft.jl +++ b/src/fft.jl @@ -566,20 +566,51 @@ unsafe_execute!(plan::r2rFFTWPlan{T}, # re-use the table of trigonometric constants from the first plan. # Compute dims and howmany for FFTW guru planner -function dims_howmany(X::StridedArray, Y::StridedArray, - sz::Vector{Int}, region) - reg = Int[region...]::Vector{Int} - if length(unique(reg)) < length(reg) +_anyrepeated(::Union{Number, AbstractUnitRange}) = false +function _anyrepeated(region) + any(region) do x + count(==(x), region) > 1 + end +end + +# Utility methods to reduce allocations in dims_howmany +@inline _setindex(oreg, v, n) = (oreg[n] = v; oreg) +@inline _setindex(oreg::Tuple, v, n) = Base.setindex(oreg, v, n) +@inline _filtercoll(region::Union{Int, Tuple}, len) = ntuple(zero, len) +@inline _filtercoll(region, len) = Vector{Int}(undef, len) +# Optimized filter(∉(region), 1:ndims(X)) +function _filter_notin_region(region, ::Val{ndimsX}) where {ndimsX} + oreg = _filtercoll(region, ndimsX - length(region)) + n = 1 + for dim in 1:ndimsX + dim in region && continue + oreg = _setindex(oreg, dim, n) + n += 1 + end + oreg +end +function dims_howmany(X::StridedArray, Y::StridedArray, sz, region) + if _anyrepeated(region) throw(ArgumentError("each dimension can be transformed at most once")) end - ist = [strides(X)...] - ost = [strides(Y)...] - dims = Matrix(transpose([sz[reg] ist[reg] ost[reg]])) - oreg = [1:ndims(X);] - oreg[reg] .= 0 - oreg = filter(d -> d > 0, oreg) - howmany = Matrix(transpose([sz[oreg] ist[oreg] ost[oreg]])) - return (dims, howmany) + ist = strides(X) + ost = strides(Y) + dims = Matrix{Int}(undef, 3, length(region)) + for (ind, i) in enumerate(region) + dims[1, ind] = sz[i] + dims[2, ind] = ist[i] + dims[3, ind] = ost[i] + end + + oreg = _filter_notin_region(region, Val(ndims(X))) + howmany = Matrix{Int}(undef, 3, length(oreg)) + for (ind, i) in enumerate(oreg) + howmany[1, ind] = sz[i] + howmany[2, ind] = ist[i] + howmany[3, ind] = ost[i] + end + + return dims, howmany end function fix_kinds(region, kinds) @@ -604,6 +635,10 @@ function fix_kinds(region, kinds) return k end +_circshiftmin1(v) = circshift(collect(Int, v), -1) +_circshiftmin1(t::Tuple) = (t[2:end]..., t[1]) +_circshiftmin1(x::Integer) = x + # low-level FFTWPlan creation (for internal use in FFTW module) for (Tr,Tc,fftw,lib) in ((:Float64,:(Complex{Float64}),"fftw",:libfftw3), (:Float32,:(Complex{Float32}),"fftwf",:libfftw3f)) @@ -613,7 +648,7 @@ for (Tr,Tc,fftw,lib) in ((:Float64,:(Complex{Float64}),"fftw",:libfftw3), direction = K unsafe_set_timelimit($Tr, timelimit) R = isa(region, Tuple) ? region : copy(region) - dims, howmany = dims_howmany(X, Y, [size(X)...], R) + dims, howmany = dims_howmany(X, Y, size(X), R) plan = ccall(($(string(fftw,"_plan_guru64_dft")),$lib[]), PlanPtr, (Int32, Ptr{Int}, Int32, Ptr{Int}, @@ -631,9 +666,9 @@ for (Tr,Tc,fftw,lib) in ((:Float64,:(Complex{Float64}),"fftw",:libfftw3), Y::StridedArray{$Tc,N}, region, flags::Integer, timelimit::Real) where {inplace,N} R = isa(region, Tuple) ? region : copy(region) - region = circshift(Int[region...],-1) # FFTW halves last dim + regionshft = _circshiftmin1(region) # FFTW halves last dim unsafe_set_timelimit($Tr, timelimit) - dims, howmany = dims_howmany(X, Y, [size(X)...], region) + dims, howmany = dims_howmany(X, Y, size(X), regionshft) plan = ccall(($(string(fftw,"_plan_guru64_dft_r2c")),$lib[]), PlanPtr, (Int32, Ptr{Int}, Int32, Ptr{Int}, @@ -651,9 +686,9 @@ for (Tr,Tc,fftw,lib) in ((:Float64,:(Complex{Float64}),"fftw",:libfftw3), Y::StridedArray{$Tr,N}, region, flags::Integer, timelimit::Real) where {inplace,N} R = isa(region, Tuple) ? region : copy(region) - region = circshift(Int[region...],-1) # FFTW halves last dim + regionshft = _circshiftmin1(region) # FFTW halves last dim unsafe_set_timelimit($Tr, timelimit) - dims, howmany = dims_howmany(X, Y, [size(Y)...], region) + dims, howmany = dims_howmany(X, Y, size(Y), regionshft) plan = ccall(($(string(fftw,"_plan_guru64_dft_c2r")),$lib[]), PlanPtr, (Int32, Ptr{Int}, Int32, Ptr{Int}, @@ -675,7 +710,7 @@ for (Tr,Tc,fftw,lib) in ((:Float64,:(Complex{Float64}),"fftw",:libfftw3), R = isa(region, Tuple) ? region : copy(region) knd = fix_kinds(region, kinds) unsafe_set_timelimit($Tr, timelimit) - dims, howmany = dims_howmany(X, Y, [size(X)...], region) + dims, howmany = dims_howmany(X, Y, size(X), region) plan = ccall(($(string(fftw,"_plan_guru64_r2r")),$lib[]), PlanPtr, (Int32, Ptr{Int}, Int32, Ptr{Int}, @@ -698,9 +733,11 @@ for (Tr,Tc,fftw,lib) in ((:Float64,:(Complex{Float64}),"fftw",:libfftw3), R = isa(region, Tuple) ? region : copy(region) knd = fix_kinds(region, kinds) unsafe_set_timelimit($Tr, timelimit) - dims, howmany = dims_howmany(X, Y, [size(X)...], region) - dims[2:3, 1:size(dims,2)] *= 2 - howmany[2:3, 1:size(howmany,2)] *= 2 + dims, howmany = dims_howmany(X, Y, size(X), region) + @views begin + dims[2:3, :] .*= 2 + howmany[2:3, :] .*= 2 + end howmany = [howmany [2,1,1]] # append loop over real/imag parts plan = ccall(($(string(fftw,"_plan_guru64_r2r")),$lib[]), PlanPtr, @@ -759,9 +796,9 @@ for (f,direction) in ((:fft,FORWARD), (:bfft,BACKWARD)) cFFTWPlan{T,$direction,true,N}(X, X, region, flags, timelimit) end $plan_f(X::StridedArray{<:fftwComplex}; kws...) = - $plan_f(X, 1:ndims(X); kws...) + $plan_f(X, ntuple(identity, ndims(X)); kws...) $plan_f!(X::StridedArray{<:fftwComplex}; kws...) = - $plan_f!(X, 1:ndims(X); kws...) + $plan_f!(X, ntuple(identity, ndims(X)); kws...) function plan_inv(p::cFFTWPlan{T,$direction,inplace,N}; num_threads::Union{Nothing, Integer} = nothing) where {T<:fftwComplex,N,inplace} @@ -845,8 +882,8 @@ for (Tr,Tc) in ((:Float32,:(Complex{Float32})),(:Float64,:(Complex{Float64}))) end end - plan_rfft(X::StridedArray{$Tr};kws...)=plan_rfft(X,1:ndims(X);kws...) - plan_brfft(X::StridedArray{$Tr};kws...)=plan_brfft(X,1:ndims(X);kws...) + plan_rfft(X::StridedArray{$Tr};kws...)=plan_rfft(X,ntuple(identity, ndims(X));kws...) + plan_brfft(X::StridedArray{$Tr};kws...)=plan_brfft(X,ntuple(identity, ndims(X));kws...) function plan_inv(p::rFFTWPlan{$Tr,$FORWARD,false,N}, num_threads::Union{Nothing, Integer} = nothing) where N diff --git a/test/runtests.jl b/test/runtests.jl index 257e344..301194d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -181,6 +181,9 @@ true_fftd3_m3d[:,:,2] .= -15 end @testset "rfft/rfftn" begin + # Test regions as int/collection + @test rfft(m4,1) == rfft(m4,1:1) == rfft(m4,(1,)) == rfft(m4, [1]) + rfft_m4 = rfft(m4,1) rfftd2_m4 = rfft(m4,2) rfftn_m4 = rfft(m4)