From 1be197f7a6abf59ed5eb58315f92e10974a232a0 Mon Sep 17 00:00:00 2001
From: David Widmann
Date: Sun, 4 Feb 2024 23:46:51 +0100
Subject: [PATCH] Use `Zygote.jacobian` etc.

---
 ext/AbstractDifferentiationZygoteExt.jl | 21 +++++++++++++++++++--
 src/backends.jl                         |  7 ++++---
 test/ruleconfig.jl                      | 24 ++++++++++++++++++++++--
 3 files changed, 45 insertions(+), 7 deletions(-)

diff --git a/ext/AbstractDifferentiationZygoteExt.jl b/ext/AbstractDifferentiationZygoteExt.jl
index 808bf39..1610006 100644
--- a/ext/AbstractDifferentiationZygoteExt.jl
+++ b/ext/AbstractDifferentiationZygoteExt.jl
@@ -8,11 +8,28 @@ else
     using ..Zygote: Zygote
 end
 
-AD.ZygoteBackend() = AD.ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig())
-
 # Context should not persist between different AD calls: fixes #69
 function AD.ruleconfig(::AD.ReverseRuleConfigBackend{<:Zygote.ZygoteRuleConfig})
     return Zygote.ZygoteRuleConfig()
 end
 
+function AD.value_and_pullback_function(::AD.ZygoteBackend, f, args...)
+    return Zygote.pullback(f, args...)
+end
+
+AD.gradient(::AD.ZygoteBackend, f, args...) = Zygote.gradient(f, args...)
+function AD.value_and_gradient(::AD.ZygoteBackend, f, args...)
+    res = Zygote.withgradient(f, args...)
+    return res.val, res.grad
+end
+
+AD.jacobian(::AD.ZygoteBackend, f, args...) = Zygote.jacobian(f, args...)
+function AD.value_and_jacobian(::AD.ZygoteBackend, f, args...)
+    res = Zygote.withjacobian(f, args...)
+    return res.val, res.grad
+end
+
+# `Zygote.hessian` returns a bare matrix; wrap it in a tuple like the other operators
+AD.hessian(::AD.ZygoteBackend, f, arg) = (Zygote.hessian(f, arg),)
+
 end # module
diff --git a/src/backends.jl b/src/backends.jl
index 10ed9bb..397eff5 100644
--- a/src/backends.jl
+++ b/src/backends.jl
@@ -71,13 +71,14 @@ end
 ruleconfig(ba::ReverseRuleConfigBackend) = ba.ruleconfig
 
 """
-    ZygoteBackend()
+    ZygoteBackend
 
 Create an AD backend that uses reverse mode with [Zygote.jl](https://github.com/FluxML/Zygote.jl).
 
-It is a special case of [`ReverseRuleConfigBackend`](@ref).
+Alternatively, you can perform AD with Zygote using a special [`ReverseRuleConfigBackend`](@ref), namely `ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig())`.
+Note, however, that the behaviour of this backend is not equivalent to `ZygoteBackend()` since the former uses a generic implementation of jacobian etc. for ChainRules-compatible AD backends whereas `ZygoteBackend` uses implementations in Zygote.jl.
 
 !!! note
     To be able to use this backend, you have to load Zygote.
 """
-function ZygoteBackend end
+struct ZygoteBackend <: AbstractReverseMode end
diff --git a/test/ruleconfig.jl b/test/ruleconfig.jl
index 0a97b66..655adad 100644
--- a/test/ruleconfig.jl
+++ b/test/ruleconfig.jl
@@ -4,7 +4,10 @@ using Test
 using Zygote
 
 @testset "ReverseRuleConfigBackend(ZygoteRuleConfig())" begin
-    backends = [@inferred(AD.ZygoteBackend())]
+    backends = [
+        @inferred(AD.ZygoteBackend()),
+        @inferred(AD.ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig()))
+    ]
     @testset for backend in backends
         @testset "Derivative" begin
             test_derivatives(backend)
@@ -34,7 +37,7 @@ using Zygote
 
     # issue #69
     @testset "Zygote context" begin
-        ad = AD.ZygoteBackend()
+        ad = AD.ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig())
 
         # example in #69: context is not mutated
         @test ad.ruleconfig.context.cache === nothing
@@ -53,6 +56,13 @@ using Zygote
         end
         @test AD.jacobian(ad, f, [1, 2, 3], 3) ==
             ([6.0 0.0 0.0; 0.0 6.0 0.0; 0.0 0.0 6.0], [2.0, 4.0, 6.0])
+
+        # With `AD.ZygoteBackend`:
+        ad = AD.ZygoteBackend()
+        @test AD.derivative(ad, exp, 1.0) === (exp(1.0),)
+        @test AD.derivative(ad, exp, 1.0) === (exp(1.0),)
+        @test AD.jacobian(ad, f, [1, 2, 3], 3) ==
+            ([6.0 0.0 0.0; 0.0 6.0 0.0; 0.0 0.0 6.0], [2.0, 4.0, 6.0])
     end
 
     # issue #57
@@ -65,5 +75,15 @@ using Zygote
 
         @test_logs Zygote.gradient(myfunc, 1) # nothing is logged
         @test_logs AD.derivative(AD.ZygoteBackend(), myfunc, 1) # nothing is logged
+        @test_logs AD.derivative(AD.ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig()), myfunc, 1) # nothing is logged
+    end
+
+    # issue #54
+    @testset "allocations of jacobian" begin
+        f(x) = x .^ 2
+        x = rand(100)
+        ad = AD.ZygoteBackend()
+        @test AD.jacobian(ad, f, x) == Zygote.jacobian(f, x)
+        @test @allocated(AD.jacobian(ad, f, x)) == @allocated(Zygote.jacobian(f, x))
     end
 end