Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added eltype for multivariate distributions #882

Merged
merged 11 commits into from
Jul 11, 2019
1 change: 1 addition & 0 deletions docs/src/extends.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ Following methods need to be implemented for each multivariate distribution type

- [`length(d::MultivariateDistribution)`](@ref)
- [`sampler(d::Distribution)`](@ref)
- [`eltype(d::Distribution)`](@ref)
- [`Distributions._rand!(::AbstractRNG, d::MultivariateDistribution, x::AbstractArray)`](@ref)
- [`Distributions._logpdf(d::MultivariateDistribution, x::AbstractArray)`](@ref)

Expand Down
1 change: 1 addition & 0 deletions docs/src/multivariate.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ The methods listed as below are implemented for each multivariate distribution,
```@docs
length(::MultivariateDistribution)
size(::MultivariateDistribution)
eltype(d::MultivariateDistribution)
mean(::MultivariateDistribution)
var(::MultivariateDistribution)
cov(::MultivariateDistribution)
Expand Down
1 change: 1 addition & 0 deletions src/multivariate/dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ end

length(d::DirichletCanon) = length(d.alpha)

eltype(::Dirichlet{T}) where {T} = T
#### Conversions
convert(::Type{Dirichlet{Float64}}, cf::DirichletCanon) = Dirichlet(cf.alpha)
convert(::Type{Dirichlet{T}}, alpha::Vector{S}) where {T<:Real, S<:Real} =
Expand Down
3 changes: 1 addition & 2 deletions src/multivariate/dirichletmultinomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ ncategories(d::DirichletMultinomial) = length(d.α)
length(d::DirichletMultinomial) = ncategories(d)
ntrials(d::DirichletMultinomial) = d.n
params(d::DirichletMultinomial) = (d.n, d.α)
@inline partype(d::DirichletMultinomial{T}) where {T<:Real} = T

@inline partype(d::DirichletMultinomial{T}) where {T} = T

# Statistics
mean(d::DirichletMultinomial) = d.α .* (d.n / d.α0)
Expand Down
2 changes: 2 additions & 0 deletions src/multivariate/mvlognormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ MvLogNormal(Σ::AbstractMatrix) = MvLogNormal(MvNormal(Σ))
MvLogNormal(σ::AbstractVector) = MvLogNormal(MvNormal(σ))
MvLogNormal(d::Int,s::Real) = MvLogNormal(MvNormal(d,s))


eltype(::MvLogNormal{T}) where {T} = T
### Conversion
function convert(::Type{MvLogNormal{T}}, d::MvLogNormal) where T<:Real
MvLogNormal(convert(MvNormal{T}, d.normal))
Expand Down
4 changes: 3 additions & 1 deletion src/multivariate/mvnormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,8 @@ MvNormal(Σ::Matrix{<:Real}) = MvNormal(PDMat(Σ))
MvNormal(σ::Vector{<:Real}) = MvNormal(PDiagMat(abs2.(σ)))
MvNormal(d::Int, σ::Real) = MvNormal(ScalMat(d, abs2(σ)))


eltype(::MvNormal{T}) where {T} = T
### Conversion
function convert(::Type{MvNormal{T}}, d::MvNormal) where T<:Real
MvNormal(convert(AbstractArray{T}, d.μ), convert(AbstractArray{T}, d.Σ))
Expand Down Expand Up @@ -270,7 +272,7 @@ _rand!(rng::AbstractRNG, d::MvNormal, x::VecOrMat) =
# Workaround: randn! only works for Array, but not generally for AbstractArray
function _rand!(rng::AbstractRNG, d::MvNormal, x::AbstractVector)
for i in eachindex(x)
@inbounds x[i] = randn(rng)
@inbounds x[i] = randn(rng,eltype(d))
end
add!(unwhiten!(d.Σ, x), d.μ)
end
Expand Down
3 changes: 2 additions & 1 deletion src/multivariate/mvnormalcanon.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ function MvNormalCanon(μ::AbstractVector{T}, h::AbstractVector{T}, J::AbstractP
if typeof(μ) == typeof(h)
return MvNormalCanon{T,typeof(J),typeof(μ)}(μ, h, J)
else
return MvNormalCanon{T,typeof(J),Vector{T}}(collect(μ), collect(h), J)
return MvNormalCanon{T,typeof(J),Vector{T}}(collect(μ), collect(h), J)
end
end

Expand Down Expand Up @@ -156,6 +156,7 @@ length(d::MvNormalCanon) = length(d.μ)
mean(d::MvNormalCanon) = convert(Vector{eltype(d.μ)}, d.μ)
params(d::MvNormalCanon) = (d.μ, d.h, d.J)
@inline partype(d::MvNormalCanon{T}) where {T<:Real} = T
eltype(::MvNormalCanon{T}) where {T} = T

var(d::MvNormalCanon) = diag(inv(d.J))
cov(d::MvNormalCanon) = Matrix(inv(d.J))
Expand Down
38 changes: 21 additions & 17 deletions src/multivariate/mvtdist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ end

GenericMvTDist(df::Real, μ::Vector{S}, Σ::Cov) where {Cov<:AbstractPDMat, S<:Real} = GenericMvTDist(df, μ, Σ, allzeros(μ))

GenericMvTDist(df::T, Σ::Cov) where {Cov<:AbstractPDMat, T<:Real} = GenericMvTDist(df, zeros(dim(Σ)), Σ, true)
function GenericMvTDist(df::T, Σ::Cov) where {Cov<:AbstractPDMat, T<:Real}
R = Base.promote_eltype(T, Σ)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure Base.promote_eltype is what's needed here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so, we want a promote_type between T and eltype of Σ and T<:Real

GenericMvTDist(df, zeros(R,dim(Σ)), Σ, true)
end

### Conversion
function convert(::Type{GenericMvTDist{T}}, d::GenericMvTDist) where T<:Real
Expand All @@ -50,19 +53,19 @@ const IsoTDist = GenericMvTDist{Float64, ScalMat{Float64}}
const DiagTDist = GenericMvTDist{Float64, PDiagMat{Float64,Vector{Float64}}}
const MvTDist = GenericMvTDist{Float64, PDMat{Float64,Matrix{Float64}}}

MvTDist(df::Real, μ::Vector{Float64}, C::PDMat) = GenericMvTDist(df, μ, C)
MvTDist(df::Real, μ::Vector{<:Real}, C::PDMat) = GenericMvTDist(df, μ, C)
MvTDist(df::Real, C::PDMat) = GenericMvTDist(df, C)
MvTDist(df::Real, μ::Vector{Float64}, Σ::Matrix{Float64}) = GenericMvTDist(df, μ, PDMat(Σ))
MvTDist(df::Float64, Σ::Matrix{Float64}) = GenericMvTDist(df, PDMat(Σ))
MvTDist(df::Real, μ::Vector{<:Real}, Σ::Matrix{<:Real}) = GenericMvTDist(df, μ, PDMat(Σ))
MvTDist(df::Real, Σ::Matrix{<:Real}) = GenericMvTDist(df, PDMat(Σ))

DiagTDist(df::Float64, μ::Vector{Float64}, C::PDiagMat) = GenericMvTDist(df, μ, C)
DiagTDist(df::Float64, C::PDiagMat) = GenericMvTDist(df, C)
DiagTDist(df::Float64, μ::Vector{Float64}, σ::Vector{Float64}) = GenericMvTDist(df, μ, PDiagMat(abs2.(σ)))
DiagTDist(df::Real, μ::Vector{<:Real}, C::PDiagMat) = GenericMvTDist(df, μ, C)
DiagTDist(df::Real, C::PDiagMat) = GenericMvTDist(df, C)
DiagTDist(df::Real, μ::Vector{<:Real}, σ::Vector{<:Real}) = GenericMvTDist(df, μ, PDiagMat(abs2.(σ)))

IsoTDist(df::Float64, μ::Vector{Float64}, C::ScalMat) = GenericMvTDist(df, μ, C)
IsoTDist(df::Float64, C::ScalMat) = GenericMvTDist(df, C)
IsoTDist(df::Float64, μ::Vector{Float64}, σ::Real) = GenericMvTDist(df, μ, ScalMat(length(μ), abs2(Float64(σ))))
IsoTDist(df::Float64, d::Int, σ::Real) = GenericMvTDist(df, ScalMat(d, abs2(Float64(σ))))
IsoTDist(df::Real, μ::Vector{<:Real}, C::ScalMat) = GenericMvTDist(df, μ, C)
IsoTDist(df::Real, C::ScalMat) = GenericMvTDist(df, C)
IsoTDist(df::Real, μ::Vector{<:Real}, σ::Real) = GenericMvTDist(df, μ, ScalMat(length(μ), abs2(σ)))
IsoTDist(df::Real, d::Int, σ::Real) = GenericMvTDist(df, ScalMat(d, abs2(σ)))

## convenient function to construct distributions of proper type based on arguments

Expand All @@ -75,11 +78,11 @@ mvtdist(df::Real, μ::Vector, σ::Vector) = GenericMvTDist(df, μ, PDiagMat(abs2
mvtdist(df::Real, μ::Vector, Σ::Matrix) = GenericMvTDist(df, μ, PDMat(Σ))
mvtdist(df::Real, Σ::Matrix) = GenericMvTDist(df, PDMat(Σ))

mvtdist(df::Float64, μ::Vector{Float64}, σ::Real) = IsoTDist(df, μ, Float64(σ))
mvtdist(df::Float64, d::Int, σ::Float64) = IsoTDist(d, σ)
mvtdist(df::Float64, μ::Vector{Float64}, σ::Vector{Float64}) = DiagTDist(df, μ, σ)
mvtdist(df::Float64, μ::Vector{Float64}, Σ::Matrix{Float64}) = MvTDist(df, μ, Σ)
mvtdist(df::Float64, Σ::Matrix{Float64}) = MvTDist(df, Σ)
# mvtdist(df::Real, μ::Vector{<:Real}, σ::Real) = IsoTDist(df, μ, σ)
# mvtdist(df::Real, d::Int, σ::Real) = IsoTDist(d, σ)
mvtdist(df::Real, μ::Vector{<:Real}, σ::Vector{<:Real}) = DiagTDist(df, μ, σ)
mvtdist(df::Real, μ::Vector{<:Real}, Σ::Matrix{<:Real}) = MvTDist(df, μ, Σ)
mvtdist(df::Real, Σ::Matrix{<:Real}) = MvTDist(df, Σ)

# Basic statistics

Expand All @@ -97,7 +100,8 @@ invcov(d::GenericMvTDist) = d.df>2 ? ((d.df-2)/d.df)*Matrix(inv(d.Σ)) : NaN*one
logdet_cov(d::GenericMvTDist) = d.df>2 ? logdet((d.df/(d.df-2))*d.Σ) : NaN

params(d::GenericMvTDist) = (d.df, d.μ, d.Σ)
@inline partype(d::GenericMvTDist{T}) where {T<:Real} = T
@inline partype(d::GenericMvTDist{T}) where {T} = T
eltype(::GenericMvTDist{T}) where {T} = T

# For entropy calculations see "Multivariate t Distributions and their Applications", S. Kotz & S. Nadarajah
function entropy(d::GenericMvTDist)
Expand Down
7 changes: 7 additions & 0 deletions src/multivariates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ Return the sample size of distribution `d`, *i.e* `(length(d),)`.
"""
size(d::MultivariateDistribution)

"""
eltype(d::MultivariateDistribution)
theogf marked this conversation as resolved.
Show resolved Hide resolved
Return the sample type of distribution `d`
"""
eltype(d::MultivariateDistribution)


## sampling

"""
Expand Down
33 changes: 16 additions & 17 deletions src/samplers/gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,23 @@

# suitable for shape >= 1.0

struct GammaGDSampler <: Sampleable{Univariate,Continuous}
a::Float64
s2::Float64
s::Float64
i2s::Float64
d::Float64
q0::Float64
b::Float64
σ::Float64
c::Float64
scale::Float64
struct GammaGDSampler{T<:Real} <: Sampleable{Univariate,Continuous}
a::T
s2::T
s::T
i2s::T
d::T
q0::T
b::T
σ::T
c::T
scale::T
end

function GammaGDSampler(g::Gamma)
function GammaGDSampler(g::Gamma{T}) where {T}
a = shape(g)

# Step 1
s2 = a-0.5
s2 = a - 0.5
s = sqrt(s2)
i2s = 0.5/s
d = 5.656854249492381 - 12.0s # 4*sqrt(2) - 12s
Expand Down Expand Up @@ -55,7 +54,7 @@ function GammaGDSampler(g::Gamma)
c = 0.1515/s
end

GammaGDSampler(a,s2,s,i2s,d,q0,b,σ,c,scale(g))
GammaGDSampler(a,s2,s,i2s,d,q0,T(b),T(σ),T(c),scale(g))
end

function calc_q(s::GammaGDSampler, t)
Expand Down Expand Up @@ -195,9 +194,9 @@ end

# Inverse Power sampler
# uses the x*u^(1/a) trick from Marsaglia and Tsang (2000) for when shape < 1
struct GammaIPSampler{S<:Sampleable{Univariate,Continuous}} <: Sampleable{Univariate,Continuous}
struct GammaIPSampler{S<:Sampleable{Univariate,Continuous},T<:Real} <: Sampleable{Univariate,Continuous}
s::S #sampler for Gamma(1+shape,scale)
nia::Float64 #-1/scale
nia::T #-1/scale
end

function GammaIPSampler(d::Gamma,::Type{S}) where S<:Sampleable
Expand Down
16 changes: 16 additions & 0 deletions test/types.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Test type relations

using Distributions
using ForwardDiff: Dual

@assert UnivariateDistribution <: Distribution
@assert MultivariateDistribution <: Distribution
Expand All @@ -21,3 +22,18 @@ using Distributions
@assert DiscreteMatrixDistribution <: MatrixDistribution
@assert ContinuousMatrixDistribution <: ContinuousDistribution
@assert ContinuousMatrixDistribution <: MatrixDistribution

@testset "Test Sample Type" begin
for T in (Float64,Float32,Dual{Nothing,Float64,0})
@testset "Type $T" begin
for d in (MvNormal,MvLogNormal,MvNormalCanon,Dirichlet)
dist = d(map(T,ones(2)))
@test eltype(dist) == T
@test eltype(rand(dist)) == eltype(dist)
end
dist = Distributions.mvtdist(map(T,1.0),map(T,[1.0 0.0; 0.0 1.0]))
@test eltype(dist) == T
@test eltype(rand(dist)) == eltype(dist)
end
end
end