From 93a0b16c7986d5a483e5221aa47f1314b75b0151 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 5 Feb 2023 19:26:45 +0000 Subject: [PATCH] Missing impl of `with_logabsdet_jacobian` for `PDBijector` (#245) * added missing implementation for with_logabsdet_jacobian and tests for PDBijector * added more informative error message in the case where with_logabsdet_jacobian has not been implemented and transform and logabsdetjac fail * bump patch version * Apply suggestions from code review Co-authored-by: David Widmann * Apply suggestions from code review Co-authored-by: David Widmann * reverted a change * fixed default impls of transform and logabsdetjac * fixed logabsdetjac_pdbijector_chol * qualified reference to logtwo --------- Co-authored-by: David Widmann --- Project.toml | 2 +- src/Bijectors.jl | 2 +- src/bijectors/pd.jl | 19 +++++++++++++------ src/interface.jl | 18 ++++++++++++++++-- test/bijectors/pd.jl | 13 +++++++++++++ test/bijectors/utils.jl | 20 ++++++++++++++++---- test/runtests.jl | 1 + 7 files changed, 61 insertions(+), 14 deletions(-) create mode 100644 test/bijectors/pd.jl diff --git a/Project.toml b/Project.toml index 93db8e41..8a447f5d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.12.0" +version = "0.12.1" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 9121ea70..afabd3d7 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -37,7 +37,7 @@ using LinearAlgebra: AbstractTriangular using InverseFunctions: InverseFunctions -import ChangesOfVariables: with_logabsdet_jacobian +import ChangesOfVariables: ChangesOfVariables, with_logabsdet_jacobian import InverseFunctions: inverse import ChainRulesCore diff --git a/src/bijectors/pd.jl b/src/bijectors/pd.jl index 5b57f55b..bed6ee9a 100644 --- a/src/bijectors/pd.jl +++ b/src/bijectors/pd.jl @@ -26,12 +26,19 @@ function logabsdetjac(b::PDBijector, X::AbstractMatrix{<:Real}) if !issuccess(Xcf) Xcf = cholesky(X + max(eps(T), eps(T) * norm(X)) * I) end - return logabsdetjac(b, Xcf) + return logabsdetjac_pdbijector_chol(Xcf) end -function logabsdetjac(b::PDBijector, Xcf::Cholesky) - U = Xcf.U - T = eltype(U) - d = size(U, 1) - return - sum((d .- (1:d) .+ 2) .* log.(diag(U))) - d * log(T(2)) +function logabsdetjac_pdbijector_chol(Xcf::Cholesky) + # NOTE: Use `UpperTriangular` here because we only need `diag(U)` + # and `UL` is by default already constructed in `Cholesky`. + UL = Xcf.UL + d = size(UL, 1) + z = sum(((d + 1):(-1):2) .* log.(diag(UL))) + return - (z + d * oftype(z, IrrationalConstants.logtwo)) +end + +# TODO: Implement explicitly. +function with_logabsdet_jacobian(b::PDBijector, X) + return transform(b, X), logabsdetjac(b, X) end diff --git a/src/interface.jl b/src/interface.jl index ea10d2f0..2da09663 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -56,7 +56,14 @@ Broadcast.broadcastable(b::Transform) = Ref(b) Transform `x` using `b`, treating `x` as a single input. """ transform(f::F, x) where {F<:Function} = f(x) -transform(t::Transform, x) = first(with_logabsdet_jacobian(t, x)) +function transform(t::Transform, x) + res = with_logabsdet_jacobian(t, x) + if res isa ChangesOfVariables.NoLogAbsDetJacobian + error("`transform` not implemented for $(typeof(b)); implement `transform` and/or `with_logabsdet_jacobian`.") + end + + return first(res) +end """ transform!(b, x[, y]) @@ -73,7 +80,14 @@ transform!(b, x, y) = copyto!(y, transform(b, x)) Return `log(abs(det(J(b, x))))`, where `J(b, x)` is the jacobian of `b` at `x`. """ -logabsdetjac(b, x) = last(with_logabsdet_jacobian(b, x)) +function logabsdetjac(b, x) + res = with_logabsdet_jacobian(b, x) + if res isa ChangesOfVariables.NoLogAbsDetJacobian + error("`logabsdetjac` not implemented for $(typeof(b)); implement `logabsdetjac` and/or `with_logabsdet_jacobian`.") + end + + return last(res) +end """ logabsdetjac!(b, x[, logjac]) diff --git a/test/bijectors/pd.jl b/test/bijectors/pd.jl new file mode 100644 index 00000000..f429c5c2 --- /dev/null +++ b/test/bijectors/pd.jl @@ -0,0 +1,13 @@ +using Bijectors, DistributionsAD, LinearAlgebra, Test +using Bijectors: PDBijector + +@testset "PDBijector" begin + d = 5 + b = PDBijector() + dist = Wishart(d, Matrix{Float64}(I, d, d)) + x = rand(dist) + # NOTE: `PDBijector` technically isn't bijective, and so the default `getjacobian` + # used in the ChangesOfVariables.jl tests will fail as the jacobian will have determinant 0. + # Hence, we disable those tests. + test_bijector(b, x; test_not_identity=true, changes_of_variables_test=false) +end diff --git a/test/bijectors/utils.jl b/test/bijectors/utils.jl index cf1283dd..dc1d3a55 100644 --- a/test/bijectors/utils.jl +++ b/test/bijectors/utils.jl @@ -12,6 +12,8 @@ function test_bijector( logjac=nothing, test_not_identity=isnothing(y) && isnothing(logjac), test_types=false, + changes_of_variables_test=true, + inverse_functions_test=true, compare=isapprox, kwargs... ) @@ -29,12 +31,22 @@ function test_bijector( end # ChangesOfVariables.jl - ChangesOfVariables.test_with_logabsdet_jacobian(b, x, getjacobian; compare=compare, kwargs...) - ChangesOfVariables.test_with_logabsdet_jacobian(ib, isnothing(y) ? y_test : y, getjacobian; compare=compare, kwargs...) + # For non-bijective transformations, these tests always fail since determinant of + # the Jacobian is zero. Hence we allow the caller to disable them if necessary. + if changes_of_variables_test + ChangesOfVariables.test_with_logabsdet_jacobian(b, x, getjacobian; compare=compare, kwargs...) + ChangesOfVariables.test_with_logabsdet_jacobian( + ib, isnothing(y) ? y_test : y, getjacobian; + compare=compare, + kwargs... + ) + end # InverseFunctions.jl - InverseFunctions.test_inverse(b, x; compare, kwargs...) - InverseFunctions.test_inverse(ib, isnothing(y) ? y_test : y; compare=compare, kwargs...) + if inverse_functions_test + InverseFunctions.test_inverse(b, x; compare, kwargs...) + InverseFunctions.test_inverse(ib, isnothing(y) ? y_test : y; compare=compare, kwargs...) + end # Always want the following to hold @test compare(ires[1], x; kwargs...) diff --git a/test/runtests.jl b/test/runtests.jl index 3185bf4c..0a00ea5f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -35,6 +35,7 @@ if GROUP == "All" || GROUP == "Interface" include("bijectors/leaky_relu.jl") include("bijectors/coupling.jl") include("bijectors/ordered.jl") + include("bijectors/pd.jl") end if GROUP == "All" || GROUP == "AD"