diff --git a/Project.toml b/Project.toml index 2bbf1f5b..c17fa11b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.15.3" +version = "0.15.4" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/ext/BijectorsZygoteExt.jl b/ext/BijectorsZygoteExt.jl index 79195b88..0db59007 100644 --- a/ext/BijectorsZygoteExt.jl +++ b/ext/BijectorsZygoteExt.jl @@ -22,7 +22,6 @@ using Bijectors: find_alpha, pd_logpdf_with_trans, istraining, - mapvcat, eachcolmaphcat, sumeachcol, pd_link, @@ -36,10 +35,6 @@ using Bijectors.Distributions: LocationScale @adjoint istraining() = true, _ -> nothing -@adjoint function mapvcat(f, args...) - g(f, args...) = map(f, args...) - return pullback(g, f, args...) -end @adjoint function eachcolmaphcat(f, x1, x2) function g(f, x1, x2) init = reshape(f(view(x1, :, 1), x2[1]), :, 1) diff --git a/test/ad/stacked.jl b/test/ad/stacked.jl index 63855410..89f5f4be 100644 --- a/test/ad/stacked.jl +++ b/test/ad/stacked.jl @@ -24,4 +24,15 @@ test_ad(y) do y sum(transform(binv, y)) end + + bvec = Stacked([b1, b2], [1:4, 5:5]) + bvec_inv = inverse(bvec) + + test_ad(y) do x + sum(transform(bvec, binv(x))) + end + + test_ad(y) do y + sum(transform(bvec_inv, y)) + end end