diff --git a/Project.toml b/Project.toml index f7be0257d..046813958 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.29" +version = "0.29.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -30,6 +30,7 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [extensions] @@ -38,6 +39,7 @@ DynamicPPLEnzymeCoreExt = ["EnzymeCore"] DynamicPPLForwardDiffExt = ["ForwardDiff"] DynamicPPLMCMCChainsExt = ["MCMCChains"] DynamicPPLReverseDiffExt = ["ReverseDiff"] +DynamicPPLTapirExt = ["Tapir"] DynamicPPLZygoteRulesExt = ["ZygoteRules"] [compat] @@ -63,6 +65,7 @@ OrderedCollections = "1" Random = "1.6" Requires = "1" ReverseDiff = "1" +Tapir = "0.2.44" Test = "1.6" ZygoteRules = "0.2" julia = "1.6" @@ -70,7 +73,8 @@ julia = "1.6" [extras] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" -MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" -ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" +ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" diff --git a/ext/DynamicPPLTapirExt.jl b/ext/DynamicPPLTapirExt.jl new file mode 100644 index 000000000..7f264d268 --- /dev/null +++ b/ext/DynamicPPLTapirExt.jl @@ -0,0 +1,19 @@ +module DynamicPPLTapirExt + +if isdefined(Base, :get_extension) + using DynamicPPL: DynamicPPL + using Tapir: Tapir +else + using ..DynamicPPL: DynamicPPL + using ..Tapir: Tapir +end + +using Tapir: DefaultCtx, CoDual, NoPullback, primal, zero_fcodual + +# This is purely an optimisation. +Tapir.@is_primitive DefaultCtx Tuple{typeof(DynamicPPL.istrans), Vararg} +function Tapir.rrule!!(f::CoDual{typeof(DynamicPPL.istrans)}, x::Vararg{CoDual, N}) where {N} + return zero_fcodual(DynamicPPL.istrans(map(primal, x)...)), NoPullback(f, x...) +end + +end # module diff --git a/test/Project.toml b/test/Project.toml index 13267ee1d..4fb9f02a5 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -21,15 +21,16 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -Accessors = "0.1" ADTypes = "0.2, 1" AbstractMCMC = "5" AbstractPPL = "0.8.2" +Accessors = "0.1" Bijectors = "0.13" Compat = "4.3.0" Distributions = "0.25" @@ -43,6 +44,7 @@ MCMCChains = "6.0.4" MacroTools = "0.5.5" ReverseDiff = "1" StableRNGs = "1" +Tapir = "0.2.44" Tracker = "0.2.23" Zygote = "0.6" julia = "1.6" diff --git a/test/ext/DynamicPPLTapirExt.jl b/test/ext/DynamicPPLTapirExt.jl new file mode 100644 index 000000000..60f7d33c0 --- /dev/null +++ b/test/ext/DynamicPPLTapirExt.jl @@ -0,0 +1,9 @@ +@testset "DynamicPPLTapirExt" begin + Tapir.TestUtils.test_rule( + Xoshiro(123), istrans, VarInfo(); + perf_flag=:none, + interface_only=true, + is_primitive=true, + interp=Tapir.TapirInterpreter(), + ) +end diff --git a/test/runtests.jl b/test/runtests.jl index aa0883708..47617bd72 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,6 +11,7 @@ using ForwardDiff using LogDensityProblems, LogDensityProblemsAD using MacroTools using MCMCChains +using Tapir using Tracker using ReverseDiff using Zygote @@ -68,6 +69,7 @@ include("test_util.jl") @testset "ad" begin include("ext/DynamicPPLForwardDiffExt.jl") + include("ext/DynamicPPLTapirExt.jl") include("ad.jl") end