diff --git a/docs/src/convolution.md b/docs/src/convolution.md index 352235675..07bf7a281 100644 --- a/docs/src/convolution.md +++ b/docs/src/convolution.md @@ -4,7 +4,8 @@ A [convolution of two probability distributions](https://en.wikipedia.org/wiki/L is the probability distribution of the sum of two independent random variables that are distributed according to these distributions. -The convolution of two distributions can be constructed with [`convolve`](@ref). +The convolution of two distributions can be constructed with [`convolve`](@ref) or its +Unicode alias [`⊕`](@ref). ```@docs convolve diff --git a/src/Distributions.jl b/src/Distributions.jl index 3cf0a2812..375afc4f2 100644 --- a/src/Distributions.jl +++ b/src/Distributions.jl @@ -190,6 +190,7 @@ export componentwise_logpdf, # component-wise logpdf for mixture models concentration, # the concentration parameter convolve, # convolve distributions of the same type + ⊕, # unicode alias of convolve dim, # sample dimension of multivariate distribution dof, # get the degree of freedom entropy, # entropy of distribution in nats diff --git a/src/convolution.jl b/src/convolution.jl index 6c18b70e4..fc52a1c37 100644 --- a/src/convolution.jl +++ b/src/convolution.jl @@ -1,9 +1,12 @@ """ convolve(d1::Distribution, d2::Distribution) + d1 ⊕ d2 Convolve two distributions and return the distribution corresponding to the sum of independent random variables drawn from the underlying distributions. +The Unicode operator `⊕` can be typed by `\\oplus`. + Currently, the function is only defined in cases where the convolution has a closed form. More precisely, the function is defined if the distributions of `d1` and `d2` are the same and one of @@ -68,6 +71,9 @@ function convolve(d1::MvNormal, d2::MvNormal) return MvNormal(d1.μ + d2.μ, d1.Σ + d2.Σ) end +# define Unicode alias +const ⊕ = convolve + function _check_convolution_args(p1, p2) p1 ≈ p2 || throw(ArgumentError( "$(p1) !≈ $(p2): distribution parameters must be approximately equal", diff --git a/test/convolution.jl b/test/convolution.jl index 826e4f82c..97800f5b0 100644 --- a/test/convolution.jl +++ b/test/convolution.jl @@ -7,64 +7,75 @@ using Test @testset "discrete univariate" begin @testset "Bernoulli" begin d1 = Bernoulli(0.1) - d2 = @inferred(convolve(d1, d1)) - @test d2 isa Binomial - @test d2.n == 2 - @test d2.p == 0.1 - - # cannot convolve a Binomial with a Bernoulli - @test_throws MethodError convolve(d1, d2) + for d2 in (@inferred(convolve(d1, d1)), @inferred(d1 ⊕ d1)) + @test d2 isa Binomial + @test d2.n == 2 + @test d2.p == 0.1 + + # cannot convolve a Binomial with a Bernoulli + @test_throws MethodError convolve(d1, d2) + @test_throws MethodError d1 ⊕ d2 + end # only works if p1 ≈ p2 d3 = Bernoulli(0.2) @test_throws ArgumentError convolve(d1, d3) + @test_throws ArgumentError d1 ⊕ d3 end @testset "Binomial" begin d1 = Binomial(2, 0.1) d2 = Binomial(5, 0.1) - d3 = @inferred(convolve(d1, d2)) - @test d3 isa Binomial - @test d3.n == 7 - @test d3.p == 0.1 + for d3 in (@inferred(convolve(d1, d2)), @inferred(d1 ⊕ d2)) + @test d3 isa Binomial + @test d3.n == 7 + @test d3.p == 0.1 + end # only works if p1 ≈ p2 d4 = Binomial(2, 0.2) @test_throws ArgumentError convolve(d1, d4) + @test_throws ArgumentError d1 ⊕ d4 end @testset "NegativeBinomial" begin d1 = NegativeBinomial(4, 0.1) d2 = NegativeBinomial(1, 0.1) - d3 = @inferred(convolve(d1, d2)) - @test d3 isa NegativeBinomial - @test d3.r == 5 - @test d3.p == 0.1 + for d3 in (@inferred(convolve(d1, d2)), @inferred(d1 ⊕ d2)) + @test d3 isa NegativeBinomial + @test d3.r == 5 + @test d3.p == 0.1 + end d4 = NegativeBinomial(1, 0.2) @test_throws ArgumentError convolve(d1, d4) + @test_throws ArgumentError d1 ⊕ d4 end @testset "Geometric" begin d1 = Geometric(0.2) - d2 = convolve(d1, d1) - @test d2 isa NegativeBinomial - @test d2.p == 0.2 + for d2 in (@inferred(convolve(d1, d1)), @inferred(d1 ⊕ d1)) + @test d2 isa NegativeBinomial + @test d2.p == 0.2 - # cannot convolve a Geometric with a NegativeBinomial - @test_throws MethodError convolve(d1, d2) + # cannot convolve a Geometric with a NegativeBinomial + @test_throws MethodError convolve(d1, d2) + @test_throws MethodError d1 ⊕ d2 + end # only works if p1 ≈ p2 d3 = Geometric(0.5) @test_throws ArgumentError convolve(d1, d3) + @test_throws ArgumentError d1 ⊕ d3 end @testset "Poisson" begin d1 = Poisson(0.1) d2 = Poisson(0.4) - d3 = @inferred(convolve(d1, d2)) - @test d3 isa Poisson - @test d3.λ == 0.5 + for d3 in (@inferred(convolve(d1, d2)), @inferred(d1 ⊕ d2)) + @test d3 isa Poisson + @test d3.λ == 0.5 + end end end @@ -72,51 +83,58 @@ end @testset "Gaussian" begin d1 = Normal(0.1, 0.2) d2 = Normal(0.25, 1.7) - d3 = @inferred(convolve(d1, d2)) - @test d3 isa Normal - @test d3.μ == 0.35 - @test d3.σ == hypot(0.2, 1.7) + for d3 in (@inferred(convolve(d1, d2)), @inferred(d1 ⊕ d2)) + @test d3 isa Normal + @test d3.μ == 0.35 + @test d3.σ == hypot(0.2, 1.7) + end end @testset "Cauchy" begin d1 = Cauchy(0.2, 0.7) d2 = Cauchy(1.9, 0.8) - d3 = @inferred(convolve(d1, d2)) - @test d3 isa Cauchy - @test d3.μ == 2.1 - @test d3.σ == 1.5 + for d3 in (@inferred(convolve(d1, d2)), @inferred(d1 ⊕ d2)) + @test d3 isa Cauchy + @test d3.μ == 2.1 + @test d3.σ == 1.5 + end end @testset "Chisq" begin d1 = Chisq(0.1) d2 = Chisq(0.3) - d3 = @inferred(convolve(d1, d2)) - @test d3 isa Chisq - @test d3.ν == 0.4 + for d3 in (@inferred(convolve(d1, d2)), @inferred(d1 ⊕ d2)) + @test d3 isa Chisq + @test d3.ν == 0.4 + end end @testset "Exponential" begin d1 = Exponential(0.7) - d2 = @inferred(convolve(d1, d1)) - @test d2 isa Gamma - @test d2.α == 2 - @test d2.θ == 0.7 - - # cannot convolve an Exponential with a Gamma - @test_throws MethodError convolve(d1, d2) + for d2 in (@inferred(convolve(d1, d1)), @inferred(d1 ⊕ d1)) + @test d2 isa Gamma + @test d2.α == 2 + @test d2.θ == 0.7 + + # cannot convolve an Exponential with a Gamma + @test_throws MethodError convolve(d1, d2) + @test_throws MethodError d1 ⊕ d2 + end # only works if θ1 ≈ θ2 d3 = Exponential(0.2) @test_throws ArgumentError convolve(d1, d3) + @test_throws ArgumentError d1 ⊕ d3 end @testset "Gamma" begin d1 = Gamma(0.1, 1.7) d2 = Gamma(0.5, 1.7) - d3 = @inferred(convolve(d1, d2)) - @test d3 isa Gamma - @test d3.α == 0.6 - @test d3.θ == 1.7 + for d3 in (@inferred(convolve(d1, d2)), @inferred(d1 ⊕ d2)) + @test d3 isa Gamma + @test d3.α == 0.6 + @test d3.θ == 1.7 + end # only works if θ1 ≈ θ4 d4 = Gamma(1.2, 0.4) @@ -153,18 +171,21 @@ end dist_list = (in1, in2, zmin1, zmin2, dn1, dn2, zmdn1, zmdn2, fn1, fn2, zm1, zm2) for (d1, d2) in Iterators.product(dist_list, dist_list) - d3 = @inferred(convolve(d1, d2)) - @test d3 isa MvNormal - @test d3.μ == d1.μ .+ d2.μ - @test Matrix(d3.Σ) == Matrix(d1.Σ + d2.Σ) # isequal not defined for PDMats + for d3 in (@inferred(convolve(d1, d2)), @inferred(d1 ⊕ d2)) + @test d3 isa MvNormal + @test d3.μ == d1.μ .+ d2.μ + @test Matrix(d3.Σ) == Matrix(d1.Σ + d2.Σ) # isequal not defined for PDMats + end end # erroring in3 = MvNormal([1, 2, 3], 0.2 * I) @test_throws ArgumentError convolve(in1, in3) + @test_throws ArgumentError in1 ⊕ in3 m5 = Symmetric(rand(3, 3)) fn3 = MvNormal(zeros(3), m5^2) @test_throws ArgumentError convolve(fn1, fn3) + @test_throws ArgumentError fn1 ⊕ fn3 end end