diff --git a/ext/OffsetArraysExt.jl b/ext/OffsetArraysExt.jl index c7d472436..66b1dee67 100644 --- a/ext/OffsetArraysExt.jl +++ b/ext/OffsetArraysExt.jl @@ -2,6 +2,6 @@ module OffsetArraysExt import DSP import OffsetArrays -DSP.conv_with_offset(::OffsetArrays.IdOffsetRange) = true +DSP.conv_axis_with_offset(::OffsetArrays.IdOffsetRange) = true end diff --git a/src/dspbase.jl b/src/dspbase.jl index 51858bbe8..741fae4a5 100644 --- a/src/dspbase.jl +++ b/src/dspbase.jl @@ -660,8 +660,16 @@ function _conv_td!(out, output_indices, u::AbstractArray{<:Number, N}, v::Abstra end # whether the given axis are to be considered to carry an offset for `conv!` and `conv` -conv_with_offset(::Base.OneTo) = false -conv_with_offset(a::Any) = throw(ArgumentError("unsupported axis type $(typeof(a))")) +conv_axis_with_offset(::Base.OneTo) = false +conv_axis_with_offset(a::Any) = throw(ArgumentError("unsupported axis type $(typeof(a))")) + +function conv_axes_with_offset(as::Tuple...) + with_offset = ((map(a -> map(conv_axis_with_offset, a), as)...)...,) + if !allequal(with_offset) + throw(ArgumentError("cannot mix offset and non-offset axes")) + end + return !isempty(with_offset) && first(with_offset) +end const FFTTypes = Union{Float32, Float64, ComplexF32, ComplexF64} @@ -677,7 +685,7 @@ offsets. If none of them has offset axes, `size(out,d) ≥ size(u,d) + size(v,d) - 1` must hold. If both input and output have offset axes, `firstindex(out,d) ≤ firstindex(u,d) + firstindex(v,d)` and `lastindex(out,d) ≥ lastindex(u,d) + lastindex(v,d)` must hold (for d = 1,...,N). -A mix of offset and non-offset axes between input and output is not permitted. +A mix of offset and non-offset axes is not permitted. The `algorithm` keyword allows choosing the algorithm to use: * `:direct`: Evaluates the convolution sum in time domain. @@ -704,12 +712,8 @@ function conv!( v::AbstractArray{<:Number, N}; algorithm=:auto ) where {T<:Number, N} + offset = conv_axes_with_offset(axes(out), axes(u), axes(v)) ? 0 : 1 output_indices = CartesianIndices(map(axes(out), axes(u), axes(v)) do ao, au, av - input_has_offset = conv_with_offset(au) || conv_with_offset(av) - if input_has_offset !== conv_with_offset(ao) - throw(ArgumentError("output must have offset axes if and only if the input has")) - end - offset = input_has_offset ? 0 : 1 return (first(au)+first(av) : last(au)+last(av)) .- offset end) @@ -752,9 +756,13 @@ function conv!( end end -conv_output_axis(au, av) = - conv_with_offset(au) || conv_with_offset(av) ? - (first(au)+first(av):last(au)+last(av)) : Base.OneTo(last(au) + last(av) - 1) +function conv_output_axes(au::Tuple, av::Tuple) + if conv_axes_with_offset(au, av) + return map((au, av) -> first(au)+first(av):last(au)+last(av), au, av) + else + return map((au, av) -> Base.OneTo(last(au) + last(av) - 1), au, av) + end +end """ conv(u, v; algorithm) @@ -768,7 +776,7 @@ function conv( u::AbstractArray{Tu, N}, v::AbstractArray{Tv, N}; kwargs... ) where {Tu<:Number, Tv<:Number, N} T = promote_type(Tu, Tv) - out_axes = map(conv_output_axis, axes(u), axes(v)) + out_axes = conv_output_axes(axes(u), axes(v)) out = similar(u, T, out_axes) return conv!(out, u, v; kwargs...) end @@ -792,7 +800,7 @@ Uses 2-D FFT algorithm. """ function conv(u::AbstractVector{T}, v::Transpose{T,<:AbstractVector}, A::AbstractMatrix{T}) where T # Arbitrary indexing offsets not implemented - if any(conv_with_offset, (axes(u)..., axes(v)..., axes(A)...)) + if any(conv_axis_with_offset, (axes(u)..., axes(v)..., axes(A)...)) throw(ArgumentError("offset axes not supported")) end m = length(u)+size(A,1)-1 diff --git a/test/dsp.jl b/test/dsp.jl index cf6163e6c..610b44bc4 100644 --- a/test/dsp.jl +++ b/test/dsp.jl @@ -72,13 +72,19 @@ end offset_arr = OffsetVector{Int}(undef, -1:2) offset_arr[:] = a - @test conv(offset_arr, 1:3) == OffsetVector(expectation, 0:5) + @test_throws ArgumentError conv(offset_arr, 1:3) + @test conv(offset_arr, OffsetArray(1:3)) == OffsetVector(expectation, 0:5) offset_arr_f = OffsetVector{Float64}(undef, -1:2) offset_arr_f[:] = fa - @test conv(offset_arr_f, 1:3) ≈ OffsetVector(fexp, 0:5) + @test_throws ArgumentError conv(offset_arr_f, 1:3) + @test conv(offset_arr_f, OffsetArray(1:3)) ≈ OffsetVector(fexp, 0:5) @test_throws ArgumentError conv!(zeros(6), offset_arr, 1:3) # output needs to be OA, too @test_throws ArgumentError conv!(OffsetVector{Int}(undef, 1:6), 1:4, 1:3) # output mustn't be OA + @test conv(fa, fill(true)) == conv(fill(true), fa) == fa + @test_broken conv(offset_arr_f, fill(true)) == conv(fill(true), offset_arr_f) == offset_arr_f + @test conv(fill(true), fill(true)) == fill(true) + for M in [10, 200], N in [10, 200], T in [Float64, ComplexF64] u = rand(T, M) v = rand(T, N) @@ -156,7 +162,8 @@ end offset_arr = OffsetMatrix{Int}(undef, -1:1, -1:1) offset_arr[:] = a - @test conv(offset_arr, b) == OffsetArray(expectation, 0:3, 0:3) + @test_throws ArgumentError conv(offset_arr, b) + @test conv(offset_arr, OffsetArray(b)) == OffsetArray(expectation, 0:3, 0:3) for (M1, M2) in [(10, 20), (190, 200)], (N1, N2) in [(20, 10), (210, 200)], T in [Float64, ComplexF64] u = rand(T, M1, M2)