From ccf98fd021232ca3101c7429c50afe66edd1e33c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 6 Aug 2024 22:42:59 +0100 Subject: [PATCH 1/7] Added default impl of `_logabdetjac_dist` so we can support non-batch by default --- src/Bijectors.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 79daccbb..e2c9ade0 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -135,6 +135,11 @@ invlink(d::Distribution, y) = inverse(bijector(d))(y) # To still allow `logpdf_with_trans` to work with "batches" in a similar way # as `logpdf` can. + +# Default catch-all so we can work with distributions by default and batch-support can be +# added when needed. +_logabsdetjac_dist(d, x) = logabsdetjac(bijector(d), x) + _logabsdetjac_dist(d::UnivariateDistribution, x::Real) = logabsdetjac(bijector(d), x) function _logabsdetjac_dist(d::UnivariateDistribution, x::AbstractArray) return logabsdetjac.((bijector(d),), x) From 2750b034100ca0c80b55d31291aedf389a95d6c5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 6 Aug 2024 22:51:31 +0100 Subject: [PATCH 2/7] Added test for product of `Dirichlet` --- test/interface.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/interface.jl b/test/interface.jl index c3221307..44a73878 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -223,6 +223,15 @@ end end end +@testset "ProductDistribution" begin + d = product_distribution(fill(Dirichlet(ones(4)), 2, 3)) + x = rand(d) + b = bijector(d) + + @test logpdf_with_trans(d, x, false) == logpdf(d, x) + @test logpdf_with_trans(d, x, true) == logpdf(d, x) - logabsdetjac(b, x) +end + @testset "DistributionsAD" begin @testset "$dist" for dist in [ filldist(Normal(), 2), From c35b855f3199aa482722d618d76679d8803c2135 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 6 Aug 2024 22:53:46 +0100 Subject: [PATCH 3/7] Bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 55c3459d..d13289bf 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.13.17" +version = "0.13.18" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" From a03824900a6edc85c193ba457944f171a20fe0d7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 12 Aug 2024 01:11:55 +0100 Subject: [PATCH 4/7] Update src/Bijectors.jl Co-authored-by: David Widmann --- src/Bijectors.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index e2c9ade0..df7814fb 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -138,7 +138,7 @@ invlink(d::Distribution, y) = inverse(bijector(d))(y) # Default catch-all so we can work with distributions by default and batch-support can be # added when needed. -_logabsdetjac_dist(d, x) = logabsdetjac(bijector(d), x) +_logabsdetjac_dist(d::Distribution, x) = logabsdetjac(bijector(d), x) _logabsdetjac_dist(d::UnivariateDistribution, x::Real) = logabsdetjac(bijector(d), x) function _logabsdetjac_dist(d::UnivariateDistribution, x::AbstractArray) From 328a7ea366e1893182fa1b5d27fa9b7d3c9a8e39 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 13 Aug 2024 23:26:53 +0200 Subject: [PATCH 5/7] Work around eachslice limitation on Julia M - throw( - DimensionMismatch( - "Number of bijectors needs to be smaller than or equal to the number of dimensions", - ), - ) + msg = """ + Number of bijectors needs to be smaller than or equal to the number of dimensions + """ + throw(DimensionMismatch(msg)) end end @@ -23,7 +22,18 @@ function _product_bijector_slices( # If N < M, then the bijectors expect an input vector of dimension `M - N`. # To achieve this, we need to slice along the last `N` dimensions. - return eachslice(x; dims=ntuple(i -> i + (M - N), N)) + slice_indices = ntuple(i -> i + (M - N), N) + if VERSION >= v"1.9" + return eachslice(x; dims=slice_indices) + else + # Earlier Julia versions can't eachslice over multiple dimensions, so reshape the + # slice dimensions into a single one. + other_dims = tuple((size(x, i) for i in 1:(M - N))...) + slice_dims = tuple((size(x, i) for i in (1 + M - N):M)...) + x_reshaped = reshape(x, other_dims..., prod(slice_dims)) + slices = eachslice(x_reshaped; dims=M - N + 1) + return reshape(collect(slices), slice_dims) + end end # Specialization for case where we're just applying elementwise. diff --git a/test/ad/chainrules.jl b/test/ad/chainrules.jl index e21f644c..10829697 100644 --- a/test/ad/chainrules.jl +++ b/test/ad/chainrules.jl @@ -29,14 +29,17 @@ end if @isdefined Tapir rng = Xoshiro(123456) - Tapir.TestUtils.test_rrule!!( - rng, Bijectors.find_alpha, x, y, z; is_primitive=true, perf_flag=:none + Tapir.TestUtils.test_rule( + rng, Bijectors.find_alpha, x, y, z; + is_primitive=true, perf_flag=:none, interp=Tapir.TapirInterpreter() ) - Tapir.TestUtils.test_rrule!!( - rng, Bijectors.find_alpha, x, y, 3; is_primitive=true, perf_flag=:none + Tapir.TestUtils.test_rule( + rng, Bijectors.find_alpha, x, y, 3; + is_primitive=true, perf_flag=:none, interp=Tapir.TapirInterpreter() ) - Tapir.TestUtils.test_rrule!!( - rng, Bijectors.find_alpha, x, y, UInt32(3); is_primitive=true, perf_flag=:none + Tapir.TestUtils.test_rule( + rng, Bijectors.find_alpha, x, y, UInt32(3); + is_primitive=true, perf_flag=:none, interp=Tapir.TapirInterpreter() ) end From 4fee7f1fc7ac4fb880e08257468304441815e0c6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 13 Aug 2024 22:29:01 +0100 Subject: [PATCH 6/7] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/ad/chainrules.jl | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/test/ad/chainrules.jl b/test/ad/chainrules.jl index 10829697..44792ee1 100644 --- a/test/ad/chainrules.jl +++ b/test/ad/chainrules.jl @@ -34,12 +34,24 @@ end is_primitive=true, perf_flag=:none, interp=Tapir.TapirInterpreter() ) Tapir.TestUtils.test_rule( - rng, Bijectors.find_alpha, x, y, 3; - is_primitive=true, perf_flag=:none, interp=Tapir.TapirInterpreter() + rng, + Bijectors.find_alpha, + x, + y, + 3; + is_primitive=true, + perf_flag=:none, + interp=Tapir.TapirInterpreter(), ) Tapir.TestUtils.test_rule( - rng, Bijectors.find_alpha, x, y, UInt32(3); - is_primitive=true, perf_flag=:none, interp=Tapir.TapirInterpreter() + rng, + Bijectors.find_alpha, + x, + y, + UInt32(3); + is_primitive=true, + perf_flag=:none, + interp=Tapir.TapirInterpreter(), ) end From c66f5096da6b3dcff3265479584212b6b28a3c3b Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Tue, 13 Aug 2024 23:28:46 +0100 Subject: [PATCH 7/7] Update test/ad/chainrules.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/ad/chainrules.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/test/ad/chainrules.jl b/test/ad/chainrules.jl index 44792ee1..bcdb9523 100644 --- a/test/ad/chainrules.jl +++ b/test/ad/chainrules.jl @@ -30,8 +30,14 @@ end if @isdefined Tapir rng = Xoshiro(123456) Tapir.TestUtils.test_rule( - rng, Bijectors.find_alpha, x, y, z; - is_primitive=true, perf_flag=:none, interp=Tapir.TapirInterpreter() + rng, + Bijectors.find_alpha, + x, + y, + z; + is_primitive=true, + perf_flag=:none, + interp=Tapir.TapirInterpreter(), ) Tapir.TestUtils.test_rule( rng,