Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added eltype for multivariate distributions #882

Merged
merged 11 commits into from
Jul 11, 2019
1 change: 1 addition & 0 deletions src/multivariate/dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ end

length(d::DirichletCanon) = length(d.alpha)

eltype(::Dirichlet{T}) where {T} = T
#### Conversions
convert(::Type{Dirichlet{Float64}}, cf::DirichletCanon) = Dirichlet(cf.alpha)
convert(::Type{Dirichlet{T}}, alpha::Vector{S}) where {T<:Real, S<:Real} =
Expand Down
3 changes: 1 addition & 2 deletions src/multivariate/dirichletmultinomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ ncategories(d::DirichletMultinomial) = length(d.α)
length(d::DirichletMultinomial) = ncategories(d)
ntrials(d::DirichletMultinomial) = d.n
params(d::DirichletMultinomial) = (d.n, d.α)
@inline partype(d::DirichletMultinomial{T}) where {T<:Real} = T

@inline partype(d::DirichletMultinomial{T}) where {T} = T

# Statistics
mean(d::DirichletMultinomial) = d.α .* (d.n / d.α0)
Expand Down
2 changes: 2 additions & 0 deletions src/multivariate/mvlognormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ MvLogNormal(Σ::AbstractMatrix) = MvLogNormal(MvNormal(Σ))
MvLogNormal(σ::AbstractVector) = MvLogNormal(MvNormal(σ))
MvLogNormal(d::Int,s::Real) = MvLogNormal(MvNormal(d,s))


eltype(::MvLogNormal{T}) where {T} = T
### Conversion
function convert(::Type{MvLogNormal{T}}, d::MvLogNormal) where T<:Real
MvLogNormal(convert(MvNormal{T}, d.normal))
Expand Down
4 changes: 3 additions & 1 deletion src/multivariate/mvnormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,8 @@ MvNormal(Σ::Matrix{<:Real}) = MvNormal(PDMat(Σ))
MvNormal(σ::Vector{<:Real}) = MvNormal(PDiagMat(abs2.(σ)))
MvNormal(d::Int, σ::Real) = MvNormal(ScalMat(d, abs2(σ)))


eltype(::MvNormal{T}) where {T} = T
### Conversion
function convert(::Type{MvNormal{T}}, d::MvNormal) where T<:Real
MvNormal(convert(AbstractArray{T}, d.μ), convert(AbstractArray{T}, d.Σ))
Expand Down Expand Up @@ -270,7 +272,7 @@ _rand!(rng::AbstractRNG, d::MvNormal, x::VecOrMat) =
# Workaround: randn! only works for Array, but not generally for AbstractArray
function _rand!(rng::AbstractRNG, d::MvNormal, x::AbstractVector)
for i in eachindex(x)
@inbounds x[i] = randn(rng)
@inbounds x[i] = randn(rng,eltype(d))
end
add!(unwhiten!(d.Σ, x), d.μ)
end
Expand Down
3 changes: 2 additions & 1 deletion src/multivariate/mvnormalcanon.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ function MvNormalCanon(μ::AbstractVector{T}, h::AbstractVector{T}, J::AbstractP
if typeof(μ) == typeof(h)
return MvNormalCanon{T,typeof(J),typeof(μ)}(μ, h, J)
else
return MvNormalCanon{T,typeof(J),Vector{T}}(collect(μ), collect(h), J)
return MvNormalCanon{T,typeof(J),Vector{T}}(collect(μ), collect(h), J)
end
end

Expand Down Expand Up @@ -130,6 +130,7 @@ distrname(d::ZeroMeanIsoNormalCanon) = "ZeroMeanIsoNormalCanon"
distrname(d::ZeroMeanDiagNormalCanon) = "ZeroMeanDiagormalCanon"
distrname(d::ZeroMeanFullNormalCanon) = "ZeroMeanFullNormalCanon"

eltype(::MvNormalCanon{T}) where {T} = T
### Conversion
function convert(::Type{MvNormalCanon{T}}, d::MvNormalCanon) where {T<:Real}
MvNormalCanon(convert(AbstractArray{T}, d.μ), convert(AbstractArray{T}, d.h), convert(AbstractArray{T}, d.J))
Expand Down
1 change: 1 addition & 0 deletions src/multivariate/mvtdist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ GenericMvTDist(df::Real, μ::Vector{S}, Σ::Cov) where {Cov<:AbstractPDMat, S<:R

GenericMvTDist(df::T, Σ::Cov) where {Cov<:AbstractPDMat, T<:Real} = GenericMvTDist(df, zeros(dim(Σ)), Σ, true)

eltype(::GenericMvTDist{T}) where {T} = T
### Conversion
function convert(::Type{GenericMvTDist{T}}, d::GenericMvTDist) where T<:Real
S = convert(AbstractArray{T}, d.Σ)
Expand Down