diff --git a/Project.toml b/Project.toml index 55c3459d..d13289bf 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.13.17" +version = "0.13.18" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 79daccbb..df7814fb 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -135,6 +135,11 @@ invlink(d::Distribution, y) = inverse(bijector(d))(y) # To still allow `logpdf_with_trans` to work with "batches" in a similar way # as `logpdf` can. + +# Default catch-all so we can work with distributions by default and batch-support can be +# added when needed. +_logabsdetjac_dist(d::Distribution, x) = logabsdetjac(bijector(d), x) + _logabsdetjac_dist(d::UnivariateDistribution, x::Real) = logabsdetjac(bijector(d), x) function _logabsdetjac_dist(d::UnivariateDistribution, x::AbstractArray) return logabsdetjac.((bijector(d),), x) diff --git a/src/bijectors/product_bijector.jl b/src/bijectors/product_bijector.jl index 94cd1fb1..c24ef394 100644 --- a/src/bijectors/product_bijector.jl +++ b/src/bijectors/product_bijector.jl @@ -8,11 +8,10 @@ inverse(b::ProductBijector) = ProductBijector(map(inverse, b.bs)) function _product_bijector_check_dim(::Val{N}, ::Val{M}) where {N,M} if N > M - throw( - DimensionMismatch( - "Number of bijectors needs to be smaller than or equal to the number of dimensions", - ), - ) + msg = """ + Number of bijectors needs to be smaller than or equal to the number of dimensions + """ + throw(DimensionMismatch(msg)) end end @@ -23,7 +22,18 @@ function _product_bijector_slices( # If N < M, then the bijectors expect an input vector of dimension `M - N`. # To achieve this, we need to slice along the last `N` dimensions. - return eachslice(x; dims=ntuple(i -> i + (M - N), N)) + slice_indices = ntuple(i -> i + (M - N), N) + if VERSION >= v"1.9" + return eachslice(x; dims=slice_indices) + else + # Earlier Julia versions can't eachslice over multiple dimensions, so reshape the + # slice dimensions into a single one. + other_dims = tuple((size(x, i) for i in 1:(M - N))...) + slice_dims = tuple((size(x, i) for i in (1 + M - N):M)...) + x_reshaped = reshape(x, other_dims..., prod(slice_dims)) + slices = eachslice(x_reshaped; dims=M - N + 1) + return reshape(collect(slices), slice_dims) + end end # Specialization for case where we're just applying elementwise. diff --git a/test/ad/chainrules.jl b/test/ad/chainrules.jl index e21f644c..bcdb9523 100644 --- a/test/ad/chainrules.jl +++ b/test/ad/chainrules.jl @@ -29,14 +29,35 @@ end if @isdefined Tapir rng = Xoshiro(123456) - Tapir.TestUtils.test_rrule!!( - rng, Bijectors.find_alpha, x, y, z; is_primitive=true, perf_flag=:none + Tapir.TestUtils.test_rule( + rng, + Bijectors.find_alpha, + x, + y, + z; + is_primitive=true, + perf_flag=:none, + interp=Tapir.TapirInterpreter(), ) - Tapir.TestUtils.test_rrule!!( - rng, Bijectors.find_alpha, x, y, 3; is_primitive=true, perf_flag=:none + Tapir.TestUtils.test_rule( + rng, + Bijectors.find_alpha, + x, + y, + 3; + is_primitive=true, + perf_flag=:none, + interp=Tapir.TapirInterpreter(), ) - Tapir.TestUtils.test_rrule!!( - rng, Bijectors.find_alpha, x, y, UInt32(3); is_primitive=true, perf_flag=:none + Tapir.TestUtils.test_rule( + rng, + Bijectors.find_alpha, + x, + y, + UInt32(3); + is_primitive=true, + perf_flag=:none, + interp=Tapir.TapirInterpreter(), ) end diff --git a/test/interface.jl b/test/interface.jl index c3221307..44a73878 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -223,6 +223,15 @@ end end end +@testset "ProductDistribution" begin + d = product_distribution(fill(Dirichlet(ones(4)), 2, 3)) + x = rand(d) + b = bijector(d) + + @test logpdf_with_trans(d, x, false) == logpdf(d, x) + @test logpdf_with_trans(d, x, true) == logpdf(d, x) - logabsdetjac(b, x) +end + @testset "DistributionsAD" begin @testset "$dist" for dist in [ filldist(Normal(), 2),