From e875cd13ff2ffd4b6fadafb21190f27c495d6d12 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 30 Nov 2021 21:59:28 +0100 Subject: [PATCH 1/7] Simplify convolution of `MvNormal` --- src/convolution.jl | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/src/convolution.jl b/src/convolution.jl index a5c81737cf..538a564197 100644 --- a/src/convolution.jl +++ b/src/convolution.jl @@ -61,27 +61,9 @@ function convolve(d1::Gamma, d2::Gamma) end # continuous multivariate -# The first two methods exist for performance reasons to avoid unnecessarily converting -# PDMats to a Matrix -function convolve( - d1::Union{IsoNormal, ZeroMeanIsoNormal, DiagNormal, ZeroMeanDiagNormal}, - d2::Union{IsoNormal, ZeroMeanIsoNormal, DiagNormal, ZeroMeanDiagNormal}, - ) - _check_convolution_shape(d1, d2) - return MvNormal(d1.μ .+ d2.μ, d1.Σ + d2.Σ) -end - -function convolve( - d1::Union{FullNormal, ZeroMeanFullNormal}, - d2::Union{FullNormal, ZeroMeanFullNormal}, - ) - _check_convolution_shape(d1, d2) - return MvNormal(d1.μ .+ d2.μ, d1.Σ.mat + d2.Σ.mat) -end - function convolve(d1::MvNormal, d2::MvNormal) _check_convolution_shape(d1, d2) - return MvNormal(d1.μ .+ d2.μ, Matrix(d1.Σ) + Matrix(d2.Σ)) + return MvNormal(d1.μ + d2.μ, d1.Σ + d2.Σ) end From ddfe7a48c369761a07e26f32eb03823404079986 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 30 Nov 2021 22:00:38 +0100 Subject: [PATCH 2/7] Add Unicode operator as alias for `convolve` --- src/Distributions.jl | 1 + src/convolution.jl | 49 ++++++++++++++++++++++++++------------------ 2 files changed, 30 insertions(+), 20 deletions(-) diff --git a/src/Distributions.jl b/src/Distributions.jl index 3cf0a2812a..375afc4f21 100644 --- a/src/Distributions.jl +++ b/src/Distributions.jl @@ -190,6 +190,7 @@ export componentwise_logpdf, # component-wise logpdf for mixture models concentration, # the concentration parameter convolve, # convolve distributions of the same type + ⊕, # unicode alias of convolve dim, # sample dimension of multivariate distribution dof, # get the degree of freedom entropy, # entropy of distribution in nats diff --git a/src/convolution.jl b/src/convolution.jl index 538a564197..8814b51eee 100644 --- a/src/convolution.jl +++ b/src/convolution.jl @@ -1,25 +1,34 @@ """ - convolve(d1::T, d2::T) where T<:Distribution -> Distribution - -Convolve two distributions of the same type to yield the distribution corresponding to the -sum of independent random variables drawn from the underlying distributions. - -The function is only defined in the cases where the convolution has a closed form as -defined here https://en.wikipedia.org/wiki/List_of_convolutions_of_probability_distributions - -* `Bernoulli` -* `Binomial` -* `NegativeBinomial` -* `Geometric` -* `Poisson` -* `Normal` -* `Cauchy` -* `Chisq` -* `Exponential` -* `Gamma` -* `MultivariateNormal` + convolve(d1::Distribution, d2::Distribution) + d1 ⊕ d2 + +Convolve two distributions and return the distribution corresponding to the sum of +independent random variables drawn from the underlying distributions. + +The Unicode operator `⊕` can be typed by `\\oplus`. + +Currently, the function is only defined in cases where the convolution has a closed form. +More precisely, the function is defined if the distributions of `d1` and `d2` are the same +and one of +* [`Bernoulli`](@ref) +* [`Binomial`](@ref) +* [`NegativeBinomial`](@ref) +* [`Geometric`](@ref) +* [`Poisson`](@ref) +* [`Normal`](@ref) +* [`Cauchy`](@ref) +* [`Chisq`](@ref) +* [`Exponential`](@ref) +* [`Gamma`](@ref) +* [`MvNormal`](@ref) + +External links: [List of convolutions of probability distributions on Wikipedia](https://en.wikipedia.org/wiki/List_of_convolutions_of_probability_distributions) """ -function convolve end +convolve(::Distribution, ::Distribution) + +# define Unicode alias and add docstring +⊕(d1::Distribution, d2::Distribution) = convolve(d1, d2) +@doc (@doc convolve) :⊕ # discrete univariate function convolve(d1::Bernoulli, d2::Bernoulli) From 084fd11672af19dd82a6db1923b2d3ad32ea1d00 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 30 Nov 2021 22:00:54 +0100 Subject: [PATCH 3/7] Update tests --- test/convolution.jl | 195 ++++++++++++++++++-------------------------- 1 file changed, 80 insertions(+), 115 deletions(-) diff --git a/test/convolution.jl b/test/convolution.jl index bcb5292d85..4289ea5a8c 100644 --- a/test/convolution.jl +++ b/test/convolution.jl @@ -8,131 +8,134 @@ using Test @testset "Bernoulli" begin d1 = Bernoulli(0.1) - d2 = convolve(d1, d1) - - @test isa(d2, Binomial) - @test d2.n == 2 - @test d2.p == 0.1 - - # cannot convolve a Binomial with a Bernoulli - @test_throws MethodError convolve(d1, d2) + for d2 in (@inferred(convolve(d1, d1)), @inferred(d1 ⊕ d1)) + @test d2 isa Binomial + @test d2.n == 2 + @test d2.p == 0.1 + + # cannot convolve a Binomial with a Bernoulli + @test_throws MethodError convolve(d1, d2) + @test_throws MethodError d1 ⊕ d2 + end # only works if p1 ≈ p2 d3 = Bernoulli(0.2) @test_throws ArgumentError convolve(d1, d3) - + @test_throws ArgumentError d1 ⊕ d3 end @testset "Binomial" begin d1 = Binomial(2, 0.1) d2 = Binomial(5, 0.1) - d3 = convolve(d1, d2) - - @test isa(d3, Binomial) - @test d3.n == 7 - @test d3.p == 0.1 + for d3 in (@inferred(convolve(d1, d2)), @inferred(d1 ⊕ d2)) + @test d3 isa Binomial + @test d3.n == 7 + @test d3.p == 0.1 + end # only works if p1 ≈ p2 d4 = Binomial(2, 0.2) @test_throws ArgumentError convolve(d1, d4) - + @test_throws ArgumentError d1 ⊕ d4 end @testset "NegativeBinomial" begin d1 = NegativeBinomial(4, 0.1) d2 = NegativeBinomial(1, 0.1) - d3 = convolve(d1, d2) - - isa(d3, NegativeBinomial) - @test d3.r == 5 - @test d3.p == 0.1 + for d3 in (@inferred(convolve(d1, d2)), @inferred(d1 ⊕ d2)) + @test d3 isa NegativeBinomial + @test d3.r == 5 + @test d3.p == 0.1 + end d4 = NegativeBinomial(1, 0.2) @test_throws ArgumentError convolve(d1, d4) + @test_throws ArgumentError d1 ⊕ d4 end - @testset "Geometric" begin d1 = Geometric(0.2) - d2 = convolve(d1, d1) - - @test isa(d2, NegativeBinomial) - @test d2.p == 0.2 + for d2 in (@inferred(convolve(d1, d1)), @inferred(d1 ⊕ d1)) + @test d2 isa NegativeBinomial + @test d2.p == 0.2 - # cannot convolve a Geometric with a NegativeBinomial - @test_throws MethodError convolve(d1, d2) + # cannot convolve a Geometric with a NegativeBinomial + @test_throws MethodError convolve(d1, d2) + @test_throws MethodError d1 ⊕ d2 + end # only works if p1 ≈ p2 d3 = Geometric(0.5) @test_throws ArgumentError convolve(d1, d3) + @test_throws ArgumentError d1 ⊕ d3 end @testset "Poisson" begin d1 = Poisson(0.1) d2 = Poisson(0.4) - d3 = convolve(d1, d2) - - @test isa(d3, Poisson) - @test d3.λ == 0.5 + for d3 in (@inferred(convolve(d1, d2)), @inferred(d1 ⊕ d2)) + @test d3 isa Poisson + @test d3.λ == 0.5 + end end - end @testset "continuous univariate" begin - @testset "Gaussian" begin d1 = Normal(0.1, 0.2) d2 = Normal(0.25, 1.7) - d3 = convolve(d1, d2) - - @test isa(d3, Normal) - @test d3.μ == 0.35 - @test d3.σ == hypot(0.2, 1.7) + for d3 in (@inferred(convolve(d1, d2)), @inferred(d1 ⊕ d2)) + @test d3 isa Normal + @test d3.μ == 0.35 + @test d3.σ == hypot(0.2, 1.7) + end end @testset "Cauchy" begin d1 = Cauchy(0.2, 0.7) d2 = Cauchy(1.9, 0.8) - d3 = convolve(d1, d2) - - @test isa(d3, Cauchy) - @test d3.μ == 2.1 - @test d3.σ == 1.5 + for d3 in (@inferred(convolve(d1, d2)), @inferred(d1 ⊕ d2)) + @test d3 isa Cauchy + @test d3.μ == 2.1 + @test d3.σ == 1.5 + end end @testset "Chisq" begin d1 = Chisq(0.1) d2 = Chisq(0.3) - d3 = convolve(d1, d2) - - @test isa(d3, Chisq) - @test d3.ν == 0.4 + for d3 in (@inferred(convolve(d1, d2)), @inferred(d1 ⊕ d2)) + @test d3 isa Chisq + @test d3.ν == 0.4 + end end @testset "Exponential" begin d1 = Exponential(0.7) - d2 = convolve(d1, d1) - - @test isa(d2, Gamma) - @test d2.α == 2 - @test d2.θ == 0.7 - - # cannot convolve an Exponential with a Gamma - @test_throws MethodError convolve(d1, d2) + for d2 in (@inferred(convolve(d1, d1)), @inferred(d1 ⊕ d1)) + @test d2 isa Gamma + @test d2.α == 2 + @test d2.θ == 0.7 + + # cannot convolve an Exponential with a Gamma + @test_throws MethodError convolve(d1, d2) + @test_throws MethodError d1 ⊕ d2 + end # only works if θ1 ≈ θ2 d3 = Exponential(0.2) @test_throws ArgumentError convolve(d1, d3) + @test_throws ArgumentError d1 ⊕ d3 end @testset "Gamma" begin d1 = Gamma(0.1, 1.7) d2 = Gamma(0.5, 1.7) - d3 = convolve(d1, d2) - - @test isa(d3, Gamma) - @test d3.α == 0.6 - @test d3.θ == 1.7 + for d3 in (@inferred(convolve(d1, d2)), @inferred(d1 ⊕ d2)) + @test d3 isa Gamma + @test d3.α == 0.6 + @test d3.θ == 1.7 + end # only works if θ1 ≈ θ4 d4 = Gamma(1.2, 0.4) @@ -142,9 +145,7 @@ end end @testset "continuous multivariate" begin - @testset "iso-/diag-normal" begin - in1 = MvNormal([1.2, 0.3], 2 * I) in2 = MvNormal([-2.0, 6.9], 0.5 * I) @@ -155,74 +156,38 @@ end dn2 = MvNormal([-3.4, 1.2], Diagonal([3.2, 0.2])) zmdn1 = MvNormal(Diagonal([1.2, 0.3])) - zmdn2 = MvNormal(Diagonal([-0.8, 1.0])) - - dist_list = (in1, in2, zmin1, zmin2, dn1, dn2, zmdn1, zmdn2) - - for (d1, d2) in Iterators.product(dist_list, dist_list) - d3 = convolve(d1, d2) - @test d3 isa Union{IsoNormal,DiagNormal,ZeroMeanIsoNormal,ZeroMeanDiagNormal} - @test d3.μ == d1.μ .+ d2.μ - @test Matrix(d3.Σ) == Matrix(d1.Σ + d2.Σ) # isequal not defined for PDMats - end - - # erroring - in3 = MvNormal([1, 2, 3], 0.2 * I) - @test_throws ArgumentError convolve(in1, in3) - end - - - @testset "full-normal" begin + zmdn2 = MvNormal(Diagonal([0.8, 1.0])) m1 = Symmetric(rand(2,2)) - m1sq = m1^2 - fn1 = MvNormal(ones(2), m1sq.data) + fn1 = MvNormal(ones(2), m1^2) m2 = Symmetric(rand(2,2)) - m2sq = m2^2 - fn2 = MvNormal([2.1, 0.4], m2sq.data) + fn2 = MvNormal([2.1, 0.4], m2^2) m3 = Symmetric(rand(2,2)) - m3sq = m3^2 - zm1 = MvNormal(m3sq.data) + zm1 = MvNormal(m3^2) m4 = Symmetric(rand(2,2)) - m4sq = m4^2 - zm2 = MvNormal(m4sq.data) + zm2 = MvNormal(m4^2) - dist_list = (fn1, fn2, zm1, zm2) + dist_list = (in1, in2, zmin1, zmin2, dn1, dn2, zmdn1, zmdn2, fn1, fn2, zm1, zm2) for (d1, d2) in Iterators.product(dist_list, dist_list) - d3 = convolve(d1, d2) - @test d3 isa Union{FullNormal,ZeroMeanFullNormal} - @test d3.μ == d1.μ .+ d2.μ - @test d3.Σ.mat == d1.Σ.mat + d2.Σ.mat # isequal not defined for PDMats + for d3 in (@inferred(convolve(d1, d2)), @inferred(d1 ⊕ d2)) + @test d3 isa MvNormal + @test d3.μ == d1.μ .+ d2.μ + @test Matrix(d3.Σ) == Matrix(d1.Σ + d2.Σ) # isequal not defined for PDMats + end end # erroring + in3 = MvNormal([1, 2, 3], 0.2 * I) + @test_throws ArgumentError convolve(in1, in3) + @test_throws ArgumentError in1 ⊕ in3 + m5 = Symmetric(rand(3, 3)) - m5sq = m5^2 - fn3 = MvNormal(zeros(3), m5sq.data) + fn3 = MvNormal(zeros(3), m5^2) @test_throws ArgumentError convolve(fn1, fn3) - end - - @testset "mixed" begin - - in1 = MvNormal([1.2, 0.3], 2 * I) - zmin1 = MvNormal(Zeros(2), 1.9 * I) - dn1 = MvNormal([0.0, 4.7], Diagonal([0.1, 1.8])) - zmdn1 = MvNormal(Diagonal([1.2, 0.3])) - m1 = Symmetric(rand(2, 2)) - m1sq = m1^2 - full = MvNormal(ones(2), m1sq.data) - - dist_list = (in1, zmin1, dn1, zmdn1) - - for (d1, d2) in Iterators.product((full, ), dist_list) - d3 = convolve(d1, d2) - @test isa(d3, MvNormal) - @test d3.μ == d1.μ .+ d2.μ - @test Matrix(d3.Σ) == Matrix(d1.Σ + d2.Σ) # isequal not defined for PDMats - end + @test_throws ArgumentError fn1 ⊕ fn3 end end From e19399a7ce3dfa5b689efc6862cea16816dc6a71 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 30 Nov 2021 22:01:05 +0100 Subject: [PATCH 4/7] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 95fa6c9057..06510bb0b4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Distributions" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" authors = ["JuliaStats"] -version = "0.25.34" +version = "0.25.35" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From a1f0dc81847be06766df748d37132f81e1d66a28 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 30 Nov 2021 22:13:18 +0100 Subject: [PATCH 5/7] Add documentation --- docs/make.jl | 1 + docs/src/convolution.md | 12 ++++++++++++ 2 files changed, 13 insertions(+) create mode 100644 docs/src/convolution.md diff --git a/docs/make.jl b/docs/make.jl index b451eecd2b..6b33de4ae3 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -16,6 +16,7 @@ makedocs( "reshape.md", "cholesky.md", "mixture.md", + "convolution.md", "fit.md", "extends.md", "density_interface.md", diff --git a/docs/src/convolution.md b/docs/src/convolution.md new file mode 100644 index 0000000000..07bf7a2812 --- /dev/null +++ b/docs/src/convolution.md @@ -0,0 +1,12 @@ +# Convolutions + +A [convolution of two probability distributions](https://en.wikipedia.org/wiki/List_of_convolutions_of_probability_distributions) +is the probability distribution of the sum of two independent random variables that are +distributed according to these distributions. + +The convolution of two distributions can be constructed with [`convolve`](@ref) or its +Unicode alias [`⊕`](@ref). + +```@docs +convolve +``` From 3649fc2cb32a847b60bb5eb09f51b78b0230074d Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 1 Dec 2021 14:42:41 +0100 Subject: [PATCH 6/7] Update convolution.jl --- src/convolution.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/convolution.jl b/src/convolution.jl index 8814b51eee..c2ee7608c9 100644 --- a/src/convolution.jl +++ b/src/convolution.jl @@ -26,9 +26,8 @@ External links: [List of convolutions of probability distributions on Wikipedia] """ convolve(::Distribution, ::Distribution) -# define Unicode alias and add docstring -⊕(d1::Distribution, d2::Distribution) = convolve(d1, d2) -@doc (@doc convolve) :⊕ +# define Unicode alias +const ⊕ = convolve # discrete univariate function convolve(d1::Bernoulli, d2::Bernoulli) From c93190367a059a83cf160578df176e5953ac7d15 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 1 Dec 2021 15:20:09 +0100 Subject: [PATCH 7/7] Update convolution.jl --- src/convolution.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/convolution.jl b/src/convolution.jl index c2ee7608c9..fc52a1c370 100644 --- a/src/convolution.jl +++ b/src/convolution.jl @@ -26,9 +26,6 @@ External links: [List of convolutions of probability distributions on Wikipedia] """ convolve(::Distribution, ::Distribution) -# define Unicode alias -const ⊕ = convolve - # discrete univariate function convolve(d1::Bernoulli, d2::Bernoulli) _check_convolution_args(d1.p, d2.p) @@ -74,6 +71,8 @@ function convolve(d1::MvNormal, d2::MvNormal) return MvNormal(d1.μ + d2.μ, d1.Σ + d2.Σ) end +# define Unicode alias +const ⊕ = convolve function _check_convolution_args(p1, p2) p1 ≈ p2 || throw(ArgumentError(