diff --git a/docs/src/extends.md b/docs/src/extends.md index a997f0fc66..a2c6a13d01 100644 --- a/docs/src/extends.md +++ b/docs/src/extends.md @@ -139,6 +139,7 @@ Following methods need to be implemented for each multivariate distribution type - [`length(d::MultivariateDistribution)`](@ref) - [`sampler(d::Distribution)`](@ref) +- [`eltype(d::Distribution)`](@ref) - [`Distributions._rand!(::AbstractRNG, d::MultivariateDistribution, x::AbstractArray)`](@ref) - [`Distributions._logpdf(d::MultivariateDistribution, x::AbstractArray)`](@ref) diff --git a/docs/src/multivariate.md b/docs/src/multivariate.md index 9832e29053..5b994b6995 100644 --- a/docs/src/multivariate.md +++ b/docs/src/multivariate.md @@ -18,6 +18,7 @@ The methods listed as below are implemented for each multivariate distribution, ```@docs length(::MultivariateDistribution) size(::MultivariateDistribution) +eltype(d::MultivariateDistribution) mean(::MultivariateDistribution) var(::MultivariateDistribution) cov(::MultivariateDistribution) diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl index c5f055b341..9052d04e1f 100644 --- a/src/multivariate/dirichlet.jl +++ b/src/multivariate/dirichlet.jl @@ -57,6 +57,7 @@ end length(d::DirichletCanon) = length(d.alpha) +eltype(::Dirichlet{T}) where {T} = T #### Conversions convert(::Type{Dirichlet{Float64}}, cf::DirichletCanon) = Dirichlet(cf.alpha) convert(::Type{Dirichlet{T}}, alpha::Vector{S}) where {T<:Real, S<:Real} = diff --git a/src/multivariate/dirichletmultinomial.jl b/src/multivariate/dirichletmultinomial.jl index 1b0419538e..caaadd80af 100644 --- a/src/multivariate/dirichletmultinomial.jl +++ b/src/multivariate/dirichletmultinomial.jl @@ -22,8 +22,7 @@ ncategories(d::DirichletMultinomial) = length(d.α) length(d::DirichletMultinomial) = ncategories(d) ntrials(d::DirichletMultinomial) = d.n params(d::DirichletMultinomial) = (d.n, d.α) -@inline partype(d::DirichletMultinomial{T}) where {T<:Real} = T - +@inline partype(d::DirichletMultinomial{T}) where {T} = T # Statistics mean(d::DirichletMultinomial) = d.α .* (d.n / d.α0) diff --git a/src/multivariate/mvlognormal.jl b/src/multivariate/mvlognormal.jl index af1b61853e..bfb4b1a7d5 100644 --- a/src/multivariate/mvlognormal.jl +++ b/src/multivariate/mvlognormal.jl @@ -175,6 +175,8 @@ MvLogNormal(Σ::AbstractMatrix) = MvLogNormal(MvNormal(Σ)) MvLogNormal(σ::AbstractVector) = MvLogNormal(MvNormal(σ)) MvLogNormal(d::Int,s::Real) = MvLogNormal(MvNormal(d,s)) + +eltype(::MvLogNormal{T}) where {T} = T ### Conversion function convert(::Type{MvLogNormal{T}}, d::MvLogNormal) where T<:Real MvLogNormal(convert(MvNormal{T}, d.normal)) diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index 77d9397996..6effd498f8 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -219,6 +219,8 @@ MvNormal(Σ::Matrix{<:Real}) = MvNormal(PDMat(Σ)) MvNormal(σ::Vector{<:Real}) = MvNormal(PDiagMat(abs2.(σ))) MvNormal(d::Int, σ::Real) = MvNormal(ScalMat(d, abs2(σ))) + +eltype(::MvNormal{T}) where {T} = T ### Conversion function convert(::Type{MvNormal{T}}, d::MvNormal) where T<:Real MvNormal(convert(AbstractArray{T}, d.μ), convert(AbstractArray{T}, d.Σ)) @@ -270,7 +272,7 @@ _rand!(rng::AbstractRNG, d::MvNormal, x::VecOrMat) = # Workaround: randn! only works for Array, but not generally for AbstractArray function _rand!(rng::AbstractRNG, d::MvNormal, x::AbstractVector) for i in eachindex(x) - @inbounds x[i] = randn(rng) + @inbounds x[i] = randn(rng,eltype(d)) end add!(unwhiten!(d.Σ, x), d.μ) end diff --git a/src/multivariate/mvnormalcanon.jl b/src/multivariate/mvnormalcanon.jl index aa8d0e5295..23b3b00d93 100644 --- a/src/multivariate/mvnormalcanon.jl +++ b/src/multivariate/mvnormalcanon.jl @@ -85,7 +85,7 @@ function MvNormalCanon(μ::AbstractVector{T}, h::AbstractVector{T}, J::AbstractP if typeof(μ) == typeof(h) return MvNormalCanon{T,typeof(J),typeof(μ)}(μ, h, J) else - return MvNormalCanon{T,typeof(J),Vector{T}}(collect(μ), collect(h), J) + return MvNormalCanon{T,typeof(J),Vector{T}}(collect(μ), collect(h), J) end end @@ -156,6 +156,7 @@ length(d::MvNormalCanon) = length(d.μ) mean(d::MvNormalCanon) = convert(Vector{eltype(d.μ)}, d.μ) params(d::MvNormalCanon) = (d.μ, d.h, d.J) @inline partype(d::MvNormalCanon{T}) where {T<:Real} = T +eltype(::MvNormalCanon{T}) where {T} = T var(d::MvNormalCanon) = diag(inv(d.J)) cov(d::MvNormalCanon) = Matrix(inv(d.J)) diff --git a/src/multivariate/mvtdist.jl b/src/multivariate/mvtdist.jl index 67fb37d06e..93ff6e1472 100644 --- a/src/multivariate/mvtdist.jl +++ b/src/multivariate/mvtdist.jl @@ -32,7 +32,10 @@ end GenericMvTDist(df::Real, μ::Vector{S}, Σ::Cov) where {Cov<:AbstractPDMat, S<:Real} = GenericMvTDist(df, μ, Σ, allzeros(μ)) -GenericMvTDist(df::T, Σ::Cov) where {Cov<:AbstractPDMat, T<:Real} = GenericMvTDist(df, zeros(dim(Σ)), Σ, true) +function GenericMvTDist(df::T, Σ::Cov) where {Cov<:AbstractPDMat, T<:Real} + R = Base.promote_eltype(T, Σ) + GenericMvTDist(df, zeros(R,dim(Σ)), Σ, true) +end ### Conversion function convert(::Type{GenericMvTDist{T}}, d::GenericMvTDist) where T<:Real @@ -50,19 +53,19 @@ const IsoTDist = GenericMvTDist{Float64, ScalMat{Float64}} const DiagTDist = GenericMvTDist{Float64, PDiagMat{Float64,Vector{Float64}}} const MvTDist = GenericMvTDist{Float64, PDMat{Float64,Matrix{Float64}}} -MvTDist(df::Real, μ::Vector{Float64}, C::PDMat) = GenericMvTDist(df, μ, C) +MvTDist(df::Real, μ::Vector{<:Real}, C::PDMat) = GenericMvTDist(df, μ, C) MvTDist(df::Real, C::PDMat) = GenericMvTDist(df, C) -MvTDist(df::Real, μ::Vector{Float64}, Σ::Matrix{Float64}) = GenericMvTDist(df, μ, PDMat(Σ)) -MvTDist(df::Float64, Σ::Matrix{Float64}) = GenericMvTDist(df, PDMat(Σ)) +MvTDist(df::Real, μ::Vector{<:Real}, Σ::Matrix{<:Real}) = GenericMvTDist(df, μ, PDMat(Σ)) +MvTDist(df::Real, Σ::Matrix{<:Real}) = GenericMvTDist(df, PDMat(Σ)) -DiagTDist(df::Float64, μ::Vector{Float64}, C::PDiagMat) = GenericMvTDist(df, μ, C) -DiagTDist(df::Float64, C::PDiagMat) = GenericMvTDist(df, C) -DiagTDist(df::Float64, μ::Vector{Float64}, σ::Vector{Float64}) = GenericMvTDist(df, μ, PDiagMat(abs2.(σ))) +DiagTDist(df::Real, μ::Vector{<:Real}, C::PDiagMat) = GenericMvTDist(df, μ, C) +DiagTDist(df::Real, C::PDiagMat) = GenericMvTDist(df, C) +DiagTDist(df::Real, μ::Vector{<:Real}, σ::Vector{<:Real}) = GenericMvTDist(df, μ, PDiagMat(abs2.(σ))) -IsoTDist(df::Float64, μ::Vector{Float64}, C::ScalMat) = GenericMvTDist(df, μ, C) -IsoTDist(df::Float64, C::ScalMat) = GenericMvTDist(df, C) -IsoTDist(df::Float64, μ::Vector{Float64}, σ::Real) = GenericMvTDist(df, μ, ScalMat(length(μ), abs2(Float64(σ)))) -IsoTDist(df::Float64, d::Int, σ::Real) = GenericMvTDist(df, ScalMat(d, abs2(Float64(σ)))) +IsoTDist(df::Real, μ::Vector{<:Real}, C::ScalMat) = GenericMvTDist(df, μ, C) +IsoTDist(df::Real, C::ScalMat) = GenericMvTDist(df, C) +IsoTDist(df::Real, μ::Vector{<:Real}, σ::Real) = GenericMvTDist(df, μ, ScalMat(length(μ), abs2(σ))) +IsoTDist(df::Real, d::Int, σ::Real) = GenericMvTDist(df, ScalMat(d, abs2(σ))) ## convenient function to construct distributions of proper type based on arguments @@ -75,11 +78,11 @@ mvtdist(df::Real, μ::Vector, σ::Vector) = GenericMvTDist(df, μ, PDiagMat(abs2 mvtdist(df::Real, μ::Vector, Σ::Matrix) = GenericMvTDist(df, μ, PDMat(Σ)) mvtdist(df::Real, Σ::Matrix) = GenericMvTDist(df, PDMat(Σ)) -mvtdist(df::Float64, μ::Vector{Float64}, σ::Real) = IsoTDist(df, μ, Float64(σ)) -mvtdist(df::Float64, d::Int, σ::Float64) = IsoTDist(d, σ) -mvtdist(df::Float64, μ::Vector{Float64}, σ::Vector{Float64}) = DiagTDist(df, μ, σ) -mvtdist(df::Float64, μ::Vector{Float64}, Σ::Matrix{Float64}) = MvTDist(df, μ, Σ) -mvtdist(df::Float64, Σ::Matrix{Float64}) = MvTDist(df, Σ) +# mvtdist(df::Real, μ::Vector{<:Real}, σ::Real) = IsoTDist(df, μ, σ) +# mvtdist(df::Real, d::Int, σ::Real) = IsoTDist(d, σ) +mvtdist(df::Real, μ::Vector{<:Real}, σ::Vector{<:Real}) = DiagTDist(df, μ, σ) +mvtdist(df::Real, μ::Vector{<:Real}, Σ::Matrix{<:Real}) = MvTDist(df, μ, Σ) +mvtdist(df::Real, Σ::Matrix{<:Real}) = MvTDist(df, Σ) # Basic statistics @@ -97,7 +100,8 @@ invcov(d::GenericMvTDist) = d.df>2 ? ((d.df-2)/d.df)*Matrix(inv(d.Σ)) : NaN*one logdet_cov(d::GenericMvTDist) = d.df>2 ? logdet((d.df/(d.df-2))*d.Σ) : NaN params(d::GenericMvTDist) = (d.df, d.μ, d.Σ) -@inline partype(d::GenericMvTDist{T}) where {T<:Real} = T +@inline partype(d::GenericMvTDist{T}) where {T} = T +eltype(::GenericMvTDist{T}) where {T} = T # For entropy calculations see "Multivariate t Distributions and their Applications", S. Kotz & S. Nadarajah function entropy(d::GenericMvTDist) diff --git a/src/multivariates.jl b/src/multivariates.jl index a09ec22ae7..d2bd273440 100644 --- a/src/multivariates.jl +++ b/src/multivariates.jl @@ -14,6 +14,13 @@ Return the sample size of distribution `d`, *i.e* `(length(d),)`. """ size(d::MultivariateDistribution) +""" + eltype(d::MultivariateDistribution) +Return the sample type of distribution `d` +""" +eltype(d::MultivariateDistribution) + + ## sampling """ diff --git a/src/samplers/gamma.jl b/src/samplers/gamma.jl index c63db263b4..a80eb9a532 100644 --- a/src/samplers/gamma.jl +++ b/src/samplers/gamma.jl @@ -6,24 +6,23 @@ # suitable for shape >= 1.0 -struct GammaGDSampler <: Sampleable{Univariate,Continuous} - a::Float64 - s2::Float64 - s::Float64 - i2s::Float64 - d::Float64 - q0::Float64 - b::Float64 - σ::Float64 - c::Float64 - scale::Float64 +struct GammaGDSampler{T<:Real} <: Sampleable{Univariate,Continuous} + a::T + s2::T + s::T + i2s::T + d::T + q0::T + b::T + σ::T + c::T + scale::T end -function GammaGDSampler(g::Gamma) +function GammaGDSampler(g::Gamma{T}) where {T} a = shape(g) - # Step 1 - s2 = a-0.5 + s2 = a - 0.5 s = sqrt(s2) i2s = 0.5/s d = 5.656854249492381 - 12.0s # 4*sqrt(2) - 12s @@ -55,7 +54,7 @@ function GammaGDSampler(g::Gamma) c = 0.1515/s end - GammaGDSampler(a,s2,s,i2s,d,q0,b,σ,c,scale(g)) + GammaGDSampler(a,s2,s,i2s,d,q0,T(b),T(σ),T(c),scale(g)) end function calc_q(s::GammaGDSampler, t) @@ -195,9 +194,9 @@ end # Inverse Power sampler # uses the x*u^(1/a) trick from Marsaglia and Tsang (2000) for when shape < 1 -struct GammaIPSampler{S<:Sampleable{Univariate,Continuous}} <: Sampleable{Univariate,Continuous} +struct GammaIPSampler{S<:Sampleable{Univariate,Continuous},T<:Real} <: Sampleable{Univariate,Continuous} s::S #sampler for Gamma(1+shape,scale) - nia::Float64 #-1/scale + nia::T #-1/scale end function GammaIPSampler(d::Gamma,::Type{S}) where S<:Sampleable diff --git a/test/types.jl b/test/types.jl index 44eb4a0bd6..93452f8895 100644 --- a/test/types.jl +++ b/test/types.jl @@ -1,6 +1,7 @@ # Test type relations using Distributions +using ForwardDiff: Dual @assert UnivariateDistribution <: Distribution @assert MultivariateDistribution <: Distribution @@ -21,3 +22,18 @@ using Distributions @assert DiscreteMatrixDistribution <: MatrixDistribution @assert ContinuousMatrixDistribution <: ContinuousDistribution @assert ContinuousMatrixDistribution <: MatrixDistribution + +@testset "Test Sample Type" begin + for T in (Float64,Float32,Dual{Nothing,Float64,0}) + @testset "Type $T" begin + for d in (MvNormal,MvLogNormal,MvNormalCanon,Dirichlet) + dist = d(map(T,ones(2))) + @test eltype(dist) == T + @test eltype(rand(dist)) == eltype(dist) + end + dist = Distributions.mvtdist(map(T,1.0),map(T,[1.0 0.0; 0.0 1.0])) + @test eltype(dist) == T + @test eltype(rand(dist)) == eltype(dist) + end + end +end