diff --git a/Project.toml b/Project.toml index b5357221..90a88153 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DistributionsAD" uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c" -version = "0.6.16" +version = "0.6.17" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -28,8 +28,8 @@ ChainRules = "0.7" ChainRulesCore = "0.9.9" Compat = "3.6" DiffRules = "0.1, 1.0" -Distributions = "0.23.3, 0.24" -FillArrays = "0.8, 0.9, 0.10" +Distributions = "0.24.12" +FillArrays = "0.8, 0.9, 0.10, 0.11" ForwardDiff = "0.10.6" NaNMath = "0.3" PDMats = "0.9, 0.10" diff --git a/src/DistributionsAD.jl b/src/DistributionsAD.jl index c9306485..a2e24162 100644 --- a/src/DistributionsAD.jl +++ b/src/DistributionsAD.jl @@ -28,7 +28,6 @@ import StatsFuns: logsumexp, nbetalogpdf import Distributions: MvNormal, MvLogNormal, - poissonbinomial_pdf_fft, logpdf, quantile, PoissonBinomial, @@ -46,9 +45,6 @@ export TuringScalMvNormal, arraydist, filldist -# check if Distributions >= 0.24 by checking if a generic implementation of `pdf` is defined -const DISTRIBUTIONS_HAS_GENERIC_UNIVARIATE_PDF = hasmethod(pdf, Tuple{UnivariateDistribution,Real}) - include("common.jl") include("arraydist.jl") include("filldist.jl") @@ -66,7 +62,7 @@ include("zygote.jl") using .ForwardDiff: @define_binary_dual_op # Needed for `eval`ing diffrules here include("forwarddiff.jl") - # loads adjoint for `poissonbinomial_pdf` and `poissonbinomial_pdf_fft` + # loads adjoint for `poissonbinomial_pdf` include("zygote_forwarddiff.jl") end @@ -99,15 +95,6 @@ include("zygote.jl") return sum(copy(logpdf.(dist.v, x))) end - function Distributions.logpdf( - dist::LazyVectorOfUnivariate, - x::AbstractMatrix{<:Real}, - ) - size(x, 1) == length(dist) || - throw(DimensionMismatch("Inconsistent array dimensions.")) - return vec(sum(copy(logpdf.(dists, x)), dims = 1)) - end - const LazyMatrixOfUnivariate{ S<:ValueSupport, T<:UnivariateDistribution{S}, diff --git a/src/arraydist.jl b/src/arraydist.jl index 6b3a76d9..53c07b7a 100644 --- a/src/arraydist.jl +++ b/src/arraydist.jl @@ -2,16 +2,7 @@ const VectorOfUnivariate = Distributions.Product -function arraydist(dists::AbstractVector{<:UnivariateDistribution}) - return Product(dists) -end - -function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real}) - size(x, 1) == length(dist) || - throw(DimensionMismatch("Inconsistent array dimensions.")) - # `eachcol` breaks Zygote, so we use `view` directly - return map(i -> sum(map(logpdf, dist.v, view(x, :, i))), axes(x, 2)) -end +arraydist(dists::AbstractVector{<:UnivariateDistribution}) = Product(dists) struct MatrixOfUnivariate{ S <: ValueSupport, @@ -29,12 +20,6 @@ function Distributions._logpdf(dist::MatrixOfUnivariate, x::AbstractMatrix{<:Rea # Broadcasting here breaks Tracker for some reason return sum(map(logpdf, dist.dists, x)) end -function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:AbstractMatrix{<:Real}}) - return map(x -> logpdf(dist, x), x) -end -function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:Matrix{<:Real}}) - return map(x -> logpdf(dist, x), x) -end function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixOfUnivariate) return rand.(Ref(rng), dist.dists) @@ -59,12 +44,6 @@ function Distributions._logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:R # `eachcol` breaks Zygote, so we use `view` directly return sum(i -> logpdf(dist.dists[i], view(x, :, i)), axes(x, 2)) end -function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:AbstractMatrix{<:Real}}) - return map(x -> logpdf(dist, x), x) -end -function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:Matrix{<:Real}}) - return map(x -> logpdf(dist, x), x) -end function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate) init = reshape(rand(rng, dist.dists[1]), :, 1) diff --git a/src/forwarddiff.jl b/src/forwarddiff.jl index 4e591be2..1bbff4ac 100644 --- a/src/forwarddiff.jl +++ b/src/forwarddiff.jl @@ -58,19 +58,3 @@ function nbinomlogpdf(r::ForwardDiff.Dual{T}, p::Real, k::Int) where {T} Δ_r = ForwardDiff.partials(r) * _nbinomlogpdf_grad_1(val_r, p, k) return FD(nbinomlogpdf(val_r, p, k), Δ_r) end - -## ForwardDiff broadcasting support ## -# If we use Distributions >= 0.24, then `DISTRIBUTIONS_HAS_GENERIC_UNIVARIATE_PDF` is `true`. -# In Distributions 0.24 `logpdf` is defined for inputs of type `Real` which are then -# converted to the support of the distributions (such as integers) in their concrete implementations. -# Thus it is no needed to have a special function for dual numbers that performs the conversion -# (and actually this method leads to method ambiguity errors since even discrete distributions now -# define logpdf(::MyDistribution, ::Real), see, e.g., -# JuliaStats/Distributions.jl@ae2d6c5/src/univariate/discrete/binomial.jl#L119). -if !DISTRIBUTIONS_HAS_GENERIC_UNIVARIATE_PDF - @eval begin - function Distributions.logpdf(d::DiscreteUnivariateDistribution, k::ForwardDiff.Dual) - return logpdf(d, convert(Integer, ForwardDiff.value(k))) - end - end -end diff --git a/src/matrixvariate.jl b/src/matrixvariate.jl index 42c588d7..716c9a08 100644 --- a/src/matrixvariate.jl +++ b/src/matrixvariate.jl @@ -214,24 +214,3 @@ function Distributions._rand!(rng::AbstractRNG, d::TuringInverseWishart, A::Abst X = Distributions._rand!(rng, TuringWishart(d.df, inv(cholesky(d.S))), A) A .= inv(cholesky!(X)) end - -# Only needed in Distributions < 0.24 -if !DISTRIBUTIONS_HAS_GENERIC_UNIVARIATE_PDF - for T in (:MatrixBeta, :MatrixNormal, :Wishart, :InverseWishart, - :TuringWishart, :TuringInverseWishart, - :VectorOfMultivariate, :MatrixOfUnivariate) - @eval begin - Distributions.loglikelihood(d::$T, X::AbstractMatrix{<:Real}) = logpdf(d, X) - function Distributions.loglikelihood(d::$T, X::AbstractArray{<:Real,3}) - (size(X, 1), size(X, 2)) == size(d) || throw(DimensionMismatch("Inconsistent array dimensions.")) - return sum(i -> _logpdf(d, view(X, :, :, i)), axes(X, 3)) - end - function Distributions.loglikelihood( - d::$T, - X::AbstractArray{<:AbstractMatrix{<:Real}}, - ) - return sum(x -> logpdf(d, x), X) - end - end - end -end diff --git a/src/multivariate.jl b/src/multivariate.jl index c41734f1..d317271c 100644 --- a/src/multivariate.jl +++ b/src/multivariate.jl @@ -1,88 +1,3 @@ -## Dirichlet ## - -struct TuringDirichlet{T, TV <: AbstractVector} <: ContinuousMultivariateDistribution - alpha::TV - alpha0::T - lmnB::T -end -Base.length(d::TuringDirichlet) = length(d.alpha) -function check(alpha) - all(ai -> ai > 0, alpha) || - throw(ArgumentError("Dirichlet: alpha must be a positive vector.")) -end - -function Distributions._rand!(rng::Random.AbstractRNG, - d::TuringDirichlet, - x::AbstractVector{<:Real}) - s = 0.0 - n = length(x) - α = d.alpha - for i in 1:n - @inbounds s += (x[i] = rand(rng, Gamma(α[i]))) - end - Distributions.multiply!(x, inv(s)) # this returns x -end - -function TuringDirichlet(alpha::AbstractVector) - check(alpha) - alpha0 = sum(alpha) - lmnB = sum(loggamma, alpha) - loggamma(alpha0) - T = promote_type(typeof(alpha0), typeof(lmnB)) - TV = typeof(alpha) - TuringDirichlet{T, TV}(alpha, alpha0, lmnB) -end - -function TuringDirichlet(d::Integer, alpha::Real) - alpha0 = alpha * d - _alpha = fill(alpha, d) - lmnB = loggamma(alpha) * d - loggamma(alpha0) - T = promote_type(typeof(alpha0), typeof(lmnB)) - TV = typeof(_alpha) - TuringDirichlet{T, TV}(_alpha, alpha0, lmnB) -end -function TuringDirichlet(alpha::AbstractVector{T}) where {T <: Integer} - TuringDirichlet(float.(alpha)) -end -TuringDirichlet(d::Integer, alpha::Integer) = TuringDirichlet(d, Float64(alpha)) - -Distributions.Dirichlet(alpha::AbstractVector) = TuringDirichlet(alpha) - -function Distributions._logpdf(d::TuringDirichlet, x::AbstractVector{<:Real}) - return simplex_logpdf(d.alpha, d.lmnB, x) -end -function Distributions.logpdf(d::TuringDirichlet, x::AbstractMatrix{<:Real}) - size(x, 1) == length(d) || - throw(DimensionMismatch("Inconsistent array dimensions.")) - return simplex_logpdf(d.alpha, d.lmnB, x) -end - -ZygoteRules.@adjoint function Distributions.Dirichlet(alpha) - return ZygoteRules.pullback(TuringDirichlet, alpha) -end -ZygoteRules.@adjoint function Distributions.Dirichlet(d, alpha) - return ZygoteRules.pullback(TuringDirichlet, d, alpha) -end - -function simplex_logpdf(alpha, lmnB, x::AbstractVector) - sum((alpha .- 1) .* log.(x)) - lmnB -end -function simplex_logpdf(alpha, lmnB, x::AbstractMatrix) - @views init = vcat(sum((alpha .- 1) .* log.(x[:,1])) - lmnB) - mapreduce(vcat, drop(eachcol(x), 1); init = init) do c - sum((alpha .- 1) .* log.(c)) - lmnB - end -end - -ZygoteRules.@adjoint function simplex_logpdf(alpha, lmnB, x::AbstractVector) - simplex_logpdf(alpha, lmnB, x), Δ -> (Δ .* log.(x), -Δ, Δ .* (alpha .- 1) ./ x) -end - -ZygoteRules.@adjoint function simplex_logpdf(alpha, lmnB, x::AbstractMatrix) - simplex_logpdf(alpha, lmnB, x), Δ -> begin - (log.(x) * Δ, -sum(Δ), ((alpha .- 1) ./ x) * Diagonal(Δ)) - end -end - ## MvNormal ## """ diff --git a/src/reversediff.jl b/src/reversediff.jl index 6ced110b..8e72dd89 100644 --- a/src/reversediff.jl +++ b/src/reversediff.jl @@ -33,7 +33,6 @@ import Distributions: logpdf, Gamma, MvNormal, MvLogNormal, - Dirichlet, Wishart, InverseWishart, PoissonBinomial, @@ -44,7 +43,6 @@ using ..DistributionsAD: TuringPoissonBinomial, TuringMvLogNormal, TuringWishart, TuringInverseWishart, - TuringDirichlet, TuringScalMvNormal, TuringDiagMvNormal, TuringDenseMvNormal @@ -240,39 +238,6 @@ end # zero mean,, constant variance MvLogNormal(d::Int, σ::TrackedReal) = TuringMvLogNormal(TuringMvNormal(d, σ)) -Dirichlet(alpha::TrackedVector) = TuringDirichlet(alpha) -Dirichlet(d::Integer, alpha::TrackedReal) = TuringDirichlet(d, alpha) - -for func_header in [ - :(simplex_logpdf(alpha::TrackedVector, lmnB::Real, x::AbstractVector)), - :(simplex_logpdf(alpha::AbstractVector, lmnB::TrackedReal, x::AbstractVector)), - :(simplex_logpdf(alpha::AbstractVector, lmnB::Real, x::TrackedVector)), - :(simplex_logpdf(alpha::TrackedVector, lmnB::TrackedReal, x::AbstractVector)), - :(simplex_logpdf(alpha::AbstractVector, lmnB::TrackedReal, x::TrackedVector)), - :(simplex_logpdf(alpha::TrackedVector, lmnB::Real, x::TrackedVector)), - :(simplex_logpdf(alpha::TrackedVector, lmnB::TrackedReal, x::TrackedVector)), - - :(simplex_logpdf(alpha::TrackedVector, lmnB::Real, x::AbstractMatrix)), - :(simplex_logpdf(alpha::AbstractVector, lmnB::TrackedReal, x::AbstractMatrix)), - :(simplex_logpdf(alpha::AbstractVector, lmnB::Real, x::TrackedMatrix)), - :(simplex_logpdf(alpha::TrackedVector, lmnB::TrackedReal, x::AbstractMatrix)), - :(simplex_logpdf(alpha::AbstractVector, lmnB::TrackedReal, x::TrackedMatrix)), - :(simplex_logpdf(alpha::TrackedVector, lmnB::Real, x::TrackedMatrix)), - :(simplex_logpdf(alpha::TrackedVector, lmnB::TrackedReal, x::TrackedMatrix)), -] - @eval $func_header = track(simplex_logpdf, alpha, lmnB, x) -end -@grad function simplex_logpdf(alpha, lmnB, x::AbstractVector) - simplex_logpdf(value(alpha), value(lmnB), value(x)), Δ -> begin - (Δ .* log.(value(x)), -Δ, Δ .* (value(alpha) .- 1)) - end -end -@grad function simplex_logpdf(alpha, lmnB, x::AbstractMatrix) - simplex_logpdf(value(alpha), value(lmnB), value(x)), Δ -> begin - (log.(value(x)) * Δ, -sum(Δ), repeat(value(alpha) .- 1, 1, size(x, 2)) * Diagonal(Δ)) - end -end - Distributions.Wishart(df::TrackedReal, S::Matrix{<:Real}) = TuringWishart(df, S) Distributions.Wishart(df::TrackedReal, S::AbstractMatrix{<:Real}) = TuringWishart(df, S) Distributions.Wishart(df::Real, S::AbstractMatrix{<:TrackedReal}) = TuringWishart(df, S) diff --git a/src/tracker.jl b/src/tracker.jl index f08ec1ad..6278572f 100644 --- a/src/tracker.jl +++ b/src/tracker.jl @@ -261,28 +261,15 @@ end PoissonBinomial(p::TrackedArray{<:Real}; check_args=true) = TuringPoissonBinomial(p; check_args = check_args) -# TODO: add adjoints without ForwardDiff -poissonbinomial_pdf_fft(x::TrackedArray) = track(poissonbinomial_pdf_fft, x) -@grad function poissonbinomial_pdf_fft(x::TrackedArray) +Distributions.poissonbinomial_pdf(x::TrackedArray) = track(Distributions.poissonbinomial_pdf, x) +@grad function Distributions.poissonbinomial_pdf(x::TrackedArray) x_data = data(x) T = eltype(x_data) - fft = poissonbinomial_pdf_fft(x_data) - return fft, Δ -> begin - ((ForwardDiff.jacobian(poissonbinomial_pdf_fft, x_data)::Matrix{T})' * Δ,) - end -end - -if isdefined(Distributions, :poissonbinomial_pdf) - Distributions.poissonbinomial_pdf(x::TrackedArray) = track(Distributions.poissonbinomial_pdf, x) - @grad function Distributions.poissonbinomial_pdf(x::TrackedArray) - x_data = data(x) - T = eltype(x_data) - value = Distributions.poissonbinomial_pdf(x_data) - function poissonbinomial_pdf_pullback(Δ) - return ((ForwardDiff.jacobian(Distributions.poissonbinomial_pdf, x_data)::Matrix{T})' * Δ,) - end - return value, poissonbinomial_pdf_pullback + value = Distributions.poissonbinomial_pdf(x_data) + function poissonbinomial_pdf_pullback(Δ) + return ((ForwardDiff.jacobian(Distributions.poissonbinomial_pdf, x_data)::Matrix{T})' * Δ,) end + return value, poissonbinomial_pdf_pullback end ## Semicircle ## @@ -339,61 +326,13 @@ end Δ->(Δ * _nbinomlogpdf_grad_1(r, p, k), Tracker._zero(p), nothing) end -## Multinomial - -function Distributions.logpdf( - dist::Multinomial{<:Real,<:TrackedVector}, - X::AbstractMatrix{<:Real} -) - size(X, 1) == length(dist) || - throw(DimensionMismatch("Inconsistent array dimensions.")) - - return map(axes(X, 2)) do i - Distributions._logpdf(dist, view(X, :, i)) - end -end - -## Categorical ## - -function Distributions.DiscreteNonParametric{T,P,Ts,Ps}( - vs::Ts, - ps::Ps; - check_args=true, -) where {T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:TrackedArray{P, 1, <:SubArray{P, 1}}} - cps = ps[:] - return DiscreteNonParametric{T,P,Ts,typeof(cps)}(vs, cps; check_args = check_args) -end - - ## Dirichlet ## -Distributions.Dirichlet(alpha::TrackedVector) = TuringDirichlet(alpha) -Distributions.Dirichlet(d::Integer, alpha::TrackedReal) = TuringDirichlet(d, alpha) - -function Distributions._logpdf(d::Dirichlet, x::TrackedVector{<:Real}) - return Distributions._logpdf(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x) -end -function Distributions.logpdf(d::Dirichlet, x::TrackedMatrix{<:Real}) - return logpdf(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x) -end -function Distributions.loglikelihood(d::Dirichlet, x::TrackedMatrix{<:Real}) - return loglikelihood(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x) -end - -# Fix ambiguities -function Distributions.logpdf(d::TuringDirichlet, x::TrackedMatrix{<:Real}) - size(x, 1) == length(d) || - throw(DimensionMismatch("Inconsistent array dimensions.")) - return simplex_logpdf(d.alpha, d.lmnB, x) -end - -## Product - -# TODO: Remove when modified upstream -function Distributions.loglikelihood(dist::Product, x::TrackedVector{<:Real}) - return Distributions.logpdf(dist, x) +# needed since `eltype(alpha) = TrackedReal` is not covered by the inner +# constructor in Distributions +function Distributions.Dirichlet(alpha::TrackedVector{T}; check_args=true) where {T<:Real} + return Dirichlet{T}(alpha; check_args=check_args) end - ## MvNormal for (f, T) in ( @@ -615,4 +554,3 @@ Distributions.InverseWishart(df::TrackedReal, S::AbstractMatrix{<:Real}) = Turin Distributions.InverseWishart(df::Real, S::TrackedMatrix) = TuringInverseWishart(df, S) Distributions.InverseWishart(df::TrackedReal, S::TrackedMatrix) = TuringInverseWishart(df, S) Distributions.InverseWishart(df::TrackedReal, S::AbstractPDMat{<:TrackedReal}) = TuringInverseWishart(df, S) - diff --git a/src/univariate.jl b/src/univariate.jl index d86b0577..d8685015 100644 --- a/src/univariate.jl +++ b/src/univariate.jl @@ -38,14 +38,11 @@ struct TuringPoissonBinomial{T<:Real, TV1<:AbstractVector{T}, TV2<:AbstractVecto pmf::TV2 end -# if available use the faster `poissonbinomial_pdf` -@eval begin - function TuringPoissonBinomial(p::AbstractArray{<:Real}; check_args = true) - pb = $(isdefined(Distributions, :poissonbinomial_pdf) ? Distributions.poissonbinomial_pdf : Distributions.poissonbinomial_pdf_fft)(p) - ϵ = eps(eltype(pb)) - check_args && @assert all(x -> x >= -ϵ, pb) && isapprox(sum(pb), 1; atol=ϵ) - return TuringPoissonBinomial(p, pb) - end +function TuringPoissonBinomial(p::AbstractArray{<:Real}; check_args = true) + pb = Distributions.poissonbinomial_pdf(p) + ϵ = eps(eltype(pb)) + check_args && @assert all(x -> x >= -ϵ, pb) && isapprox(sum(pb), 1; atol=ϵ) + return TuringPoissonBinomial(p, pb) end function logpdf(d::TuringPoissonBinomial{T}, k::Int) where T<:Real @@ -54,23 +51,3 @@ end quantile(d::TuringPoissonBinomial, x::Float64) = quantile(Categorical(d.pmf), x) - 1 Base.minimum(d::TuringPoissonBinomial) = 0 Base.maximum(d::TuringPoissonBinomial) = length(d.p) - - -## Categorical ## - -function Base.convert( - ::Type{Distributions.DiscreteNonParametric{T,P,Ts,Ps}}, - d::Distributions.DiscreteNonParametric{T,P,Ts,Ps}, -) where {T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:AbstractVector{P}} - DiscreteNonParametric{T,P,Ts,Ps}(support(d), probs(d), check_args=false) -end - -# Fix SubArray support -function Distributions.DiscreteNonParametric{T,P,Ts,Ps}( - vs::Ts, - ps::Ps; - check_args=true, -) where {T<:Real, P<:Real, Ts<:AbstractVector{T}, Ps<:SubArray{P, 1}} - cps = ps[:] - return DiscreteNonParametric{T,P,Ts,typeof(cps)}(vs, cps; check_args = check_args) -end diff --git a/src/zygote.jl b/src/zygote.jl index 7460c039..63512fd1 100644 --- a/src/zygote.jl +++ b/src/zygote.jl @@ -15,27 +15,7 @@ end ## PoissonBinomial ## # Zygote loads ForwardDiff, so this dummy adjoint should never be needed. -# The adjoint that is used for `poissonbinomial_pdf_fft` is defined in `src/zygote_forwarddiff.jl` -# ZygoteRules.@adjoint function poissonbinomial_pdf_fft(x::AbstractArray{T}) where T<:Real -# error("This needs ForwardDiff. `using ForwardDiff` should fix this error.") -# end - -## Product - -# Tests with `Kolmogorov` seem to fail otherwise?! -ZygoteRules.@adjoint function Distributions._logpdf(d::Product, x::AbstractVector{<:Real}) - return ZygoteRules.pullback(d, x) do d, x - sum(map(logpdf, d.v, x)) - end -end -ZygoteRules.@adjoint function Distributions._logpdf( - d::FillVectorOfUnivariate, - x::AbstractVector{<:Real}, -) - return ZygoteRules.pullback(d, x) do d, x - _flat_logpdf(d.v.value, x) - end -end +# The adjoint that is used for `poissonbinomial_pdf` is defined in `src/zygote_forwarddiff.jl` ## Wishart ## @@ -72,36 +52,3 @@ ZygoteRules.@adjoint function Distributions.InverseWishart( ) return ZygoteRules.pullback(TuringInverseWishart, df, S) end - -## General definitions of `logpdf` for arrays - -ZygoteRules.@adjoint function Distributions.logpdf( - dist::MultivariateDistribution, - X::AbstractMatrix{<:Real}, -) - size(X, 1) == length(dist) || - throw(DimensionMismatch("Inconsistent array dimensions.")) - return ZygoteRules.pullback(dist, X) do dist, X - return map(i -> Distributions._logpdf(dist, view(X, :, i)), axes(X, 2)) - end -end - -ZygoteRules.@adjoint function Distributions.logpdf( - dist::MatrixDistribution, - X::AbstractArray{<:Real,3}, -) - (size(X, 1), size(X, 2)) == size(dist) || - throw(DimensionMismatch("Inconsistent array dimensions.")) - return ZygoteRules.pullback(dist, X) do dist, X - return map(i -> Distributions._logpdf(dist, view(X, :, :, i)), axes(X, 3)) - end -end - -ZygoteRules.@adjoint function Distributions.logpdf( - dist::MatrixDistribution, - X::AbstractArray{<:AbstractMatrix{<:Real}}, -) - return ZygoteRules.pullback(dist, X) do dist, X - return map(x -> logpdf(dist, x), X) - end -end diff --git a/src/zygote_forwarddiff.jl b/src/zygote_forwarddiff.jl index 7e157379..1fd57874 100644 --- a/src/zygote_forwarddiff.jl +++ b/src/zygote_forwarddiff.jl @@ -1,20 +1,9 @@ # Zygote loads ForwardDiff, so this adjoint will autmatically be loaded together # with `using Zygote`. - -# TODO: add adjoints without ForwardDiff -@adjoint function poissonbinomial_pdf_fft(x::AbstractArray{T}) where T<:Real - fft = poissonbinomial_pdf_fft(x) - return fft, Δ -> begin - ((ForwardDiff.jacobian(poissonbinomial_pdf_fft, x)::Matrix{T})' * Δ,) - end -end - -if isdefined(Distributions, :poissonbinomial_pdf) - @adjoint function Distributions.poissonbinomial_pdf(x::AbstractArray{T}) where T<:Real - value = Distributions.poissonbinomial_pdf(x) - function poissonbinomial_pdf_pullback(Δ) - return ((ForwardDiff.jacobian(Distributions.poissonbinomial_pdf, x)::Matrix{T})' * Δ,) - end - return value, poissonbinomial_pdf_pullback +@adjoint function Distributions.poissonbinomial_pdf(x::AbstractArray{T}) where T<:Real + value = Distributions.poissonbinomial_pdf(x) + function poissonbinomial_pdf_pullback(Δ) + return ((ForwardDiff.jacobian(Distributions.poissonbinomial_pdf, x)::Matrix{T})' * Δ,) end + return value, poissonbinomial_pdf_pullback end diff --git a/test/Project.toml b/test/Project.toml index 5271e9cf..47625f9e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -18,7 +18,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ChainRulesTestUtils = "0.5.3, 0.6" Combinatorics = "1.0.2" -Distributions = "0.24.3" +Distributions = "0.24.12" FiniteDifferences = "0.11.3, 0.12" ForwardDiff = "0.10.12" NNlib = "0.7.7"