Skip to content

Commit

Permalink
Add custom adjoint for Wishart
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Jun 6, 2020
1 parent db300fc commit 8fbeeda
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions src/matrixvariate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,34 @@ function unwhiten!(C::Cholesky, x::StridedVecOrMat)
lmul!(transpose(cf), x)
end

## Custom adjoint since Zygote can't differentiate through `@warn`

ZygoteRules.@adjoint function Wishart(df::T, S::AbstractPDMat{T}, warn::Bool = true) where T
return ZygoteRules.pullback(_Wishart, df, S)
end

function _Wishart(df::T, S::AbstractPDMat{T}, warn::Bool = true) where T
df > 0 || throw(ArgumentError("df must be positive. got $(df)."))
p = dim(S)
rnk = p
singular = df <= p - 1
if singular
isinteger(df) || throw(ArgumentError("singular df must be an integer. got $(df)."))
rnk = convert(Integer, df)
warn && _warn("got df <= dim - 1; returning a singular Wishart")
end
logc0 = Distributions.wishart_logc0(df, S, rnk)
R = Base.promote_eltype(T, logc0)
prom_S = convert(AbstractArray{T}, S)
Wishart{R, typeof(prom_S), typeof(rnk)}(R(df), prom_S, R(logc0), rnk, singular)
end

_warn(msg) = @warn(msg)

ZygoteRules.@adjoint function _warn(msg)
return _warn(msg), _ -> nothing
end

## InverseWishart

struct TuringInverseWishart{T<:Real, ST<:AbstractMatrix} <: ContinuousMatrixDistribution
Expand Down

0 comments on commit 8fbeeda

Please sign in to comment.