Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Wishart and InverseWishart #84

Merged
merged 12 commits into from
Jun 15, 2020
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
Combinatorics = "0.7, 1.0"
Compat = "3.6"
DiffRules = "0.1, 1.0"
Distributions = "0.22, 0.23"
Distributions = "0.23.3"
FillArrays = "0.8"
ForwardDiff = "0.10.6"
MacroTools = "0.5"
Expand Down
83 changes: 47 additions & 36 deletions src/matrixvariate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ using StatsFuns: logtwo, logmvgamma
struct TuringWishart{T<:Real, ST <: Cholesky} <: ContinuousMatrixDistribution
df::T # degree of freedom
chol::ST # the Cholesky of scale matrix
c0::T # the logarithm of normalizing constant in pdf
logc0::T # the logarithm of normalizing constant in pdf
end

#### Constructors

function TuringWishart(d::Wishart)
return TuringWishart(d.df, getchol(d.S), d.c0)
return TuringWishart(d.df, getchol(d.S), d.logc0)
end
# `getchol` for a `PDMat`: hand back the factorization stored in its `chol` field.
function getchol(p::PDMat)
    return p.chol
end
# `getchol` for a `PDiagMat`: a `Diagonal` holding the element-wise square roots
# of the stored diagonal entries.
function getchol(p::PDiagMat)
    root_diag = map(sqrt, p.diag)
    return Diagonal(root_diag)
end
Expand All @@ -40,15 +40,15 @@ function TuringWishart(df::T, S::AbstractMatrix) where {T <: Real}
return TuringWishart(df, C)
end
function TuringWishart(df::T, C::Cholesky) where {T <: Real}
c0 = _wishart_c0(df, C)
R = Base.promote_eltype(T, c0)
return TuringWishart(R(df), C, R(c0))
logc0 = _wishart_logc0(df, C)
R = Base.promote_eltype(T, logc0)
return TuringWishart(R(df), C, R(logc0))
end

function _wishart_c0(df::Real, C::Cholesky)
function _wishart_logc0(df::Real, C::Cholesky)
h_df = df / 2
p = size(C, 1)
h_df * (logdet(C) + p * float(typeof(df))(logtwo)) + logmvgamma(p, h_df)
-h_df * (logdet(C) + p * float(typeof(df))(logtwo)) - logmvgamma(p, h_df)
end

#### Properties
Expand Down Expand Up @@ -87,7 +87,7 @@ end
function Distributions.entropy(d::TuringWishart)
p = Distributions.dim(d)
df = d.df
d.c0 - 0.5 * (df - p - 1) * Distributions.meanlogdet(d) + 0.5 * df * p
return -d.logc0 - 0.5 * (df - p - 1) * Distributions.meanlogdet(d) + 0.5 * df * p
end

# Gupta/Nagar (1999) Theorem 3.3.15.i
Expand All @@ -113,7 +113,7 @@ function Distributions.logpdf(d::TuringWishart, X::AbstractMatrix{<:Real})
df = d.df
p = Distributions.dim(d)
Xcf = cholesky(X)
return 0.5 * ((df - (p + 1)) * logdet(Xcf) - tr(d.chol \ X)) - d.c0
return 0.5 * ((df - (p + 1)) * logdet(Xcf) - tr(d.chol \ X)) + d.logc0
end
function Distributions.logpdf(d::TuringWishart, X::AbstractArray{<:AbstractMatrix{<:Real}})
return map(x -> logpdf(d, x), X)
Expand All @@ -124,45 +124,56 @@ end

#### Sampling
function Distributions._rand!(rng::AbstractRNG, d::TuringWishart, A::AbstractMatrix)
_wishart_genA!(rng, Distributions.dim(d), d.df, A)
Distributions._wishart_genA!(rng, Distributions.dim(d), d.df, A)
unwhiten!(d.chol, A)
A .= A * A'
end

"""
    _wishart_genA!(rng, p, df, A)

Fill `A` in place with the lower-triangular matrix of the Bartlett decomposition.

`A` is zeroed first. Then `A[i, i]` is drawn from `Chi(df - i + 1)` for `i = 1:p`
(the square root of a `Chisq(df - i + 1)` variate), and every entry strictly below
the diagonal (`i > j`) is drawn with `randn(rng)`. Entries above the diagonal stay
zero. Returns `nothing`.
"""
function _wishart_genA!(rng::AbstractRNG, p::Int, df::Real, A::AbstractMatrix)
    fill!(A, zero(eltype(A)))
    # All diagonal draws happen before any off-diagonal draw.
    for k in 1:p
        @inbounds A[k, k] = rand(rng, Chi(df - k + 1.0))
    end
    # Strictly lower triangle, visited column by column.
    for col in 1:(p - 1)
        for row in (col + 1):p
            @inbounds A[row, col] = randn(rng)
        end
    end
    return nothing
end

"""
    unwhiten!(C::Cholesky, x)

Left-multiply `x` in place by the transpose of the upper factor `C.U`, and return
the result of that `lmul!` call.
"""
function unwhiten!(C::Cholesky, x::StridedVecOrMat)
    return lmul!(transpose(C.U), x)
end

## Custom adjoint since Zygote can't differentiate through `@warn`

ZygoteRules.@adjoint function Wishart(df::T, S::AbstractPDMat{T}, warn::Bool = true) where T
return ZygoteRules.pullback(_Wishart, df, S, warn)
devmotion marked this conversation as resolved.
Show resolved Hide resolved
end

# Construct a `Wishart` from the degrees of freedom `df` and the scale matrix `S`.
#
# Throws `ArgumentError` when `df` is not positive, or when `df <= dim(S) - 1`
# (the singular case) and `df` is not an integer. In the singular case the rank
# is taken from `df`, and if `warn` is true a warning is emitted by calling
# `_warn` rather than `@warn` directly.
function _Wishart(df::T, S::AbstractPDMat{T}, warn::Bool = true) where T
df > 0 || throw(ArgumentError("df must be positive. got $(df)."))
p = dim(S)
# Full rank unless the singular branch below overrides it.
rnk = p
singular = df <= p - 1
if singular
isinteger(df) || throw(ArgumentError("singular df must be an integer. got $(df)."))
rnk = convert(Integer, df)
warn && _warn("got df <= dim - 1; returning a singular Wishart")
end
logc0 = Distributions.wishart_logc0(df, S, rnk)
# Promote `df` and `logc0` to a common element type `R` before constructing.
R = Base.promote_eltype(T, logc0)
prom_S = convert(AbstractArray{T}, S)
Wishart{R, typeof(prom_S), typeof(rnk)}(R(df), prom_S, R(logc0), rnk, singular)
end

# Plain-function wrapper that logs `msg` at warning level via `@warn`.
function _warn(msg)
    return @warn(msg)
end

# Adjoint for `_warn`: run the warning in the forward pass and return a pullback
# that yields `nothing` for any incoming value.
ZygoteRules.@adjoint function _warn(msg)
    result = _warn(msg)
    back(_) = nothing
    return result, back
end
devmotion marked this conversation as resolved.
Show resolved Hide resolved

## InverseWishart

struct TuringInverseWishart{T<:Real, ST<:AbstractMatrix} <: ContinuousMatrixDistribution
df::T # degree of freedom
S::ST # Scale matrix
c0::T # log of normalizing constant
logc0::T # log of normalizing constant
end

#### Constructors

function TuringInverseWishart(d::InverseWishart)
d = TuringInverseWishart(d.df, getmatrix(d.Ψ), d.c0)
d = TuringInverseWishart(d.df, getmatrix(d.Ψ), d.logc0)
end
# `getmatrix` for a `PDMat`: hand back the matrix stored in its `mat` field.
function getmatrix(p::PDMat)
    return p.mat
end
# `getmatrix` for a `PDiagMat`: wrap the stored diagonal entries in a `Diagonal`.
function getmatrix(p::PDiagMat)
    diag_entries = p.diag
    return Diagonal(diag_entries)
end
Expand All @@ -172,14 +183,14 @@ function TuringInverseWishart(df::T, Ψ::AbstractMatrix) where T<:Real
p = size(Ψ, 1)
df > p - 1 || error("df should be greater than dim - 1.")
C = cholesky(Ψ)
c0 = _invwishart_c0(df, C)
R = Base.promote_eltype(T, c0)
return TuringInverseWishart(R(df), Ψ, R(c0))
logc0 = _invwishart_logc0(df, C)
R = Base.promote_eltype(T, logc0)
return TuringInverseWishart(R(df), Ψ, R(logc0))
end
function _invwishart_c0(df::Real, C::Cholesky)
function _invwishart_logc0(df::Real, C::Cholesky)
h_df = df / 2
p = size(C, 1)
h_df * (p * float(typeof(df))(logtwo) - logdet(C)) + logmvgamma(p, h_df)
-h_df * (p * float(typeof(df))(logtwo) - logdet(C)) - logmvgamma(p, h_df)
end

#### Properties
Expand Down Expand Up @@ -217,7 +228,7 @@ end
function Distributions.var(d::TuringInverseWishart, i::Integer, j::Integer)
p, ν, Ψ = (Distributions.dim(d), d.df, d.S)
ν > p + 3 || throw(ArgumentError("var only defined for df > dim + 3"))
inv((ν - p)*(ν - p - 3)*(ν - p - 1)^2)*(ν - p + 1)*Ψ[i,j]^2 + (ν - p - 1)*Ψ[i,i]*Ψ[j,j]
inv((ν - p)*(ν - p - 3)*(ν - p - 1)^2)*((ν - p + 1)*Ψ[i,j]^2 + (ν - p - 1)*Ψ[i,i]*Ψ[j,j])
end

#### Evaluation
Expand All @@ -234,7 +245,7 @@ function Distributions.logpdf(d::TuringInverseWishart, X::AbstractMatrix{<:Real}
Xcf = cholesky(X)
# we use the fact: tr(Ψ * inv(X)) = tr(inv(X) * Ψ) = tr(X \ Ψ)
Ψ = d.S
-0.5 * ((df + p + 1) * logdet(Xcf) + tr(Xcf \ Ψ)) - d.c0
-0.5 * ((df + p + 1) * logdet(Xcf) + tr(Xcf \ Ψ)) + d.logc0
end
function Distributions.logpdf(d::TuringInverseWishart, X::AbstractArray{<:AbstractMatrix{<:Real}})
return map(x -> logpdf(d, x), X)
Expand Down