From 43b5b33614a279d696874512004a24bfa80adab2 Mon Sep 17 00:00:00 2001
From: Kyle Daruwalla
Date: Tue, 1 Nov 2022 13:50:14 -0500
Subject: [PATCH 1/7] Add support for explicit mode gradients and optimizers

---
 Project.toml               |  4 +++-
 src/optimise/Optimise.jl   |  9 +++++++++
 src/optimise/gradients.jl  | 23 +++++++++++++++++++++++
 src/optimise/optimisers.jl |  3 ---
 src/optimise/train.jl      | 36 +++++++++++++++++++++++-----------
 5 files changed, 60 insertions(+), 15 deletions(-)
 create mode 100644 src/optimise/gradients.jl

diff --git a/Project.toml b/Project.toml
index a494b960fc..aebae8dab1 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,8 +1,9 @@
 name = "Flux"
 uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-version = "0.13.7"
+version = "0.13.8-DEV"
 
 [deps]
+AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
@@ -26,6 +27,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
+AbstractDifferentiation = "0.4.3"
 Adapt = "3.0"
 ArrayInterface = "3.1, 4, 5, 6"
 CUDA = "3"
diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl
index e691ce0170..98d481f95e 100644
--- a/src/optimise/Optimise.jl
+++ b/src/optimise/Optimise.jl
@@ -1,7 +1,15 @@
 module Optimise
 
+using Flux
+using MacroTools: @forward
+import Zygote
+import Zygote: Params, gradient
+using AbstractDifferentiation
+import Optimisers
+import Optimisers: update, update!
 using LinearAlgebra
 import ArrayInterface
+using ProgressLogging: @progress, @withprogress, @logprogress
 
 export train!, update!,
   Descent, Adam, Momentum, Nesterov, RMSProp,
@@ -10,6 +18,7 @@ export train!, update!,
   ClipValue, ClipNorm
 
 include("optimisers.jl")
+include("gradients.jl")
 include("train.jl")
 
 end
diff --git a/src/optimise/gradients.jl b/src/optimise/gradients.jl
new file mode 100644
index 0000000000..9a3dba8f12
--- /dev/null
+++ b/src/optimise/gradients.jl
@@ -0,0 +1,23 @@
+struct ZygoteImplicitBackend{T} <: AD.AbstractReverseMode
+    core_backend::T
+end
+ZygoteImplicitBackend() = ZygoteImplicitBackend(AD.ZygoteBackend())
+
+AD.@primitive pullback_function(ad::ZygoteImplicitBackend, f, x::Zygote.Params) =
+    AD.pullback_function(ad.core_backend, f, x)
+
+# this is a hack to get around
+# https://github.com/JuliaDiff/AbstractDifferentiation.jl/issues/63#issuecomment-1225959150
+AD.gradient(::ZygoteImplicitBackend, f, x::Zygote.Params) = Zygote.gradient(f, x)
+
+struct ZygoteExplicitBackend{T} <: AD.AbstractReverseMode
+    core_backend::T
+end
+ZygoteExplicitBackend() = ZygoteExplicitBackend(AD.ZygoteBackend())
+
+AD.@primitive pullback_function(ad::ZygoteExplicitBackend, f, xs...) =
+    AD.pullback_function(ad.core_backend, f, xs...)
+
+# this is a hack to get around
+# https://github.com/JuliaDiff/AbstractDifferentiation.jl/issues/63#issuecomment-1225959150
+AD.gradient(::ZygoteExplicitBackend, f, xs...) = Zygote.gradient(f, xs...)
diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl
index ce72a4b0ce..bdb50fd6d5 100644
--- a/src/optimise/optimisers.jl
+++ b/src/optimise/optimisers.jl
@@ -1,6 +1,3 @@
-using Flux
-using MacroTools: @forward
-
 abstract type AbstractOptimiser end
 
 const EPS = 1e-8
diff --git a/src/optimise/train.jl b/src/optimise/train.jl
index a1c3e9a7aa..aa5b0ed0a9 100644
--- a/src/optimise/train.jl
+++ b/src/optimise/train.jl
@@ -1,7 +1,3 @@
-using ProgressLogging: @progress, @withprogress, @logprogress
-import Zygote: Params, gradient, withgradient
-
-
 """
     update!(opt, p, g)
     update!(opt, ps::Params, gs)
@@ -12,18 +8,23 @@ according to optimizer `opt` and the gradients `gs` (the gradient `g`).
 As a result, the parameters are mutated and the optimizer's internal state may change.
 The gradient could be mutated as well.
 """
-function update!(opt::AbstractOptimiser, x, x̄)
+function Optimisers.update!(opt::AbstractOptimiser, x, x̄)
   x̄r = ArrayInterface.restructure(x, x̄) # address some cases where Zygote's
                                         # output are not mutable, see #1510
   x .-= apply!(opt, x, x̄r)
+
+  return opt, x
 end
 
-function update!(opt::AbstractOptimiser, xs::Params, gs)
+function Optimisers.update!(opt::AbstractOptimiser, xs::Params, gs)
   for x in xs
     isnothing(gs[x]) && continue
     update!(opt, x, gs[x])
   end
+
+  return opt, xs
 end
+Optimisers.update(opt::AbstractOptimiser, xs::Params, gs) = update!(opt, xs, gs)
 
 # Callback niceties
 call(f, xs...) = f(xs...)
@@ -82,6 +83,16 @@ end
 batchmemaybe(x) = tuple(x)
 batchmemaybe(x::Tuple) = x
 
+_build_loss(::AD.AbstractBackend, loss, data) = function _loss(m)
+  return loss(m, data...)
+end
+_build_loss(::ZygoteImplicitBackend, loss, data) = function _loss()
+  return loss(data...)
+end
+_gradient_only(x::Zygote.Grads) = x
+_gradient_only(x::NTuple{1}) = x[1]
+_gradient_only(x) = error("Expected gradient w.r.t. single argument (or Zygote.Grads) but got $x")
+
 """
     train!(loss, pars::Params, data, opt::AbstractOptimiser; [cb])
 
@@ -122,19 +133,18 @@ The callback can call [`Flux.stop`](@ref) to interrupt the training loop.
 
 Multiple callbacks can be passed to `cb` as array.
 """
-function train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
+function train!(loss, ad::AD.AbstractBackend, model, data, optstate; cb = () -> ())
   cb = runall(cb)
   itrsz = Base.IteratorSize(typeof(data))
   n = (itrsz == Base.HasLength()) || (itrsz == Base.HasShape{1}()) ? length(data) : 0
   @withprogress for (i, d) in enumerate(data)
     try
-      l, gs = withgradient(ps) do
-        loss(batchmemaybe(d)...)
-      end
+      _loss = _build_loss(ad, loss, batchmemaybe(d))
+      l, gs = AD.valud_and_gradient(ad, _loss, model)
       if !isfinite(l)
         throw(DomainError("Loss is $l on data item $i, stopping training"))
       end
-      update!(opt, ps, gs)
+      optstate, model = update(optstate, model, _gradient_only(gs))
       cb()
     catch ex
       if ex isa StopException
@@ -147,7 +157,11 @@ function train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
     end
     @logprogress iszero(n) ? nothing : i / n
   end
+
+  return optstate, model
 end
+train!(loss, model, data, optstate; kwargs...) =
+  train!(loss, ZygoteImplicitBackend(), model, data, optstate; kwargs...)
 
 """
     @epochs N body
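
After this first patch, Flux.Optimise exposes two AbstractDifferentiation backends: `ZygoteImplicitBackend` differentiates a zero-argument loss with respect to a `Zygote.Params` collection, while `ZygoteExplicitBackend` differentiates a loss that receives the model itself as an argument. A minimal sketch of the distinction (the model, input, and loss below are illustrative only, not taken from the patch):

    using Flux
    using AbstractDifferentiation              # provides the `AD` alias used throughout the patch
    using Flux.Optimise: ZygoteImplicitBackend, ZygoteExplicitBackend

    m = Dense(3 => 2)            # toy model, for illustration
    x = rand(Float32, 3)

    # implicit mode: gradient w.r.t. Zygote.Params, the loss closes over the model
    gs_implicit = AD.gradient(ZygoteImplicitBackend(), () -> sum(m(x)), Flux.params(m))

    # explicit mode: gradient w.r.t. the model, the loss takes it as an argument
    gs_explicit = AD.gradient(ZygoteExplicitBackend(), model -> sum(model(x)), m)

The patch also re-plumbs `train!` so that its second argument is an AD backend and its last argument is optimiser state that is threaded through `update`/`update!` and returned. Note that at this point the loop still calls the misspelled `AD.valud_and_gradient`, which patch 7 corrects.
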
From 1afce6a755813905a63ef7b809d626c19bc6bb8a Mon Sep 17 00:00:00 2001
From: Kyle Daruwalla
Date: Fri, 14 Oct 2022 11:48:59 -0400
Subject: [PATCH 2/7] Add Tracker support

---
 Project.toml              | 1 +
 src/optimise/Optimise.jl  | 1 +
 src/optimise/gradients.jl | 3 +++
 3 files changed, 5 insertions(+)

diff --git a/Project.toml b/Project.toml
index aebae8dab1..abab9d6f9f 100644
--- a/Project.toml
+++ b/Project.toml
@@ -24,6 +24,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl
index 98d481f95e..7fc1c78a20 100644
--- a/src/optimise/Optimise.jl
+++ b/src/optimise/Optimise.jl
@@ -4,6 +4,7 @@ using Flux
 using MacroTools: @forward
 import Zygote
 import Zygote: Params, gradient
+import Tracker
 using AbstractDifferentiation
 import Optimisers
 import Optimisers: update, update!
diff --git a/src/optimise/gradients.jl b/src/optimise/gradients.jl
index 9a3dba8f12..bbbacb2bb0 100644
--- a/src/optimise/gradients.jl
+++ b/src/optimise/gradients.jl
@@ -21,3 +21,6 @@ AD.@primitive pullback_function(ad::ZygoteExplicitBackend, f, xs...) =
 # this is a hack to get around
 # https://github.com/JuliaDiff/AbstractDifferentiation.jl/issues/63#issuecomment-1225959150
 AD.gradient(::ZygoteExplicitBackend, f, xs...) = Zygote.gradient(f, xs...)
+
+# this is to work around AD.TrackerBackend only supporting vectors of params
+AD.gradient(::AD.TrackerBackend, f, xs...) = Tracker.withgradient(f, xs...).grad

From 341144d5f8520028a9d011c90a142b2b99feeb7e Mon Sep 17 00:00:00 2001
From: Kyle Daruwalla
Date: Fri, 14 Oct 2022 14:04:32 -0400
Subject: [PATCH 3/7] Remove Tracker support

---
 Project.toml              | 1 -
 src/optimise/Optimise.jl  | 1 -
 src/optimise/gradients.jl | 3 ---
 3 files changed, 5 deletions(-)

diff --git a/Project.toml b/Project.toml
index abab9d6f9f..aebae8dab1 100644
--- a/Project.toml
+++ b/Project.toml
@@ -24,7 +24,6 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
-Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl
index 7fc1c78a20..98d481f95e 100644
--- a/src/optimise/Optimise.jl
+++ b/src/optimise/Optimise.jl
@@ -4,7 +4,6 @@ using Flux
 using MacroTools: @forward
 import Zygote
 import Zygote: Params, gradient
-import Tracker
 using AbstractDifferentiation
 import Optimisers
 import Optimisers: update, update!
diff --git a/src/optimise/gradients.jl b/src/optimise/gradients.jl
index bbbacb2bb0..9a3dba8f12 100644
--- a/src/optimise/gradients.jl
+++ b/src/optimise/gradients.jl
@@ -21,6 +21,3 @@ AD.@primitive pullback_function(ad::ZygoteExplicitBackend, f, xs...) =
 # this is a hack to get around
 # https://github.com/JuliaDiff/AbstractDifferentiation.jl/issues/63#issuecomment-1225959150
 AD.gradient(::ZygoteExplicitBackend, f, xs...) = Zygote.gradient(f, xs...)
-
-# this is to work around AD.TrackerBackend only supporting vectors of params
-AD.gradient(::AD.TrackerBackend, f, xs...) = Tracker.withgradient(f, xs...).grad
From d0cf342cd78ca12f9ff3e5e2e31c8dfbc7fccaf6 Mon Sep 17 00:00:00 2001
From: Kyle Daruwalla
Date: Tue, 1 Nov 2022 13:50:48 -0500
Subject: [PATCH 4/7] Switch to update! only

---
 src/optimise/Optimise.jl | 2 +-
 src/optimise/train.jl    | 3 +--
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl
index 98d481f95e..53d2792c8f 100644
--- a/src/optimise/Optimise.jl
+++ b/src/optimise/Optimise.jl
@@ -6,7 +6,7 @@ import Zygote
 import Zygote: Params, gradient
 using AbstractDifferentiation
 import Optimisers
-import Optimisers: update, update!
+import Optimisers: update!
 using LinearAlgebra
 import ArrayInterface
 using ProgressLogging: @progress, @withprogress, @logprogress
diff --git a/src/optimise/train.jl b/src/optimise/train.jl
index aa5b0ed0a9..84b43f7a2e 100644
--- a/src/optimise/train.jl
+++ b/src/optimise/train.jl
@@ -24,7 +24,6 @@ function Optimisers.update!(opt::AbstractOptimiser, xs::Params, gs)
 
   return opt, xs
 end
-Optimisers.update(opt::AbstractOptimiser, xs::Params, gs) = update!(opt, xs, gs)
 
 # Callback niceties
 call(f, xs...) = f(xs...)
@@ -144,7 +143,7 @@ function train!(loss, ad::AD.AbstractBackend, model, data, optstate; cb = () ->
       if !isfinite(l)
         throw(DomainError("Loss is $l on data item $i, stopping training"))
       end
-      optstate, model = update(optstate, model, _gradient_only(gs))
+      optstate, model = update!(optstate, model, _gradient_only(gs))
       cb()
     catch ex
       if ex isa StopException

From fbde4776b4b150eca55cf3a322c97386fadec8dc Mon Sep 17 00:00:00 2001
From: Kyle Daruwalla
Date: Sat, 15 Oct 2022 17:12:25 -0400
Subject: [PATCH 5/7] Add AD.value_and_gradient too

---
 src/optimise/gradients.jl | 4 ++++
 src/optimise/train.jl     | 2 +-
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/optimise/gradients.jl b/src/optimise/gradients.jl
index 9a3dba8f12..d92e55dc93 100644
--- a/src/optimise/gradients.jl
+++ b/src/optimise/gradients.jl
@@ -9,6 +9,8 @@ AD.@primitive pullback_function(ad::ZygoteImplicitBackend, f, x::Zygote.Params)
 # this is a hack to get around
 # https://github.com/JuliaDiff/AbstractDifferentiation.jl/issues/63#issuecomment-1225959150
 AD.gradient(::ZygoteImplicitBackend, f, x::Zygote.Params) = Zygote.gradient(f, x)
+AD.value_and_gradient(::ZygoteImplicitBackend, f, x::Zygote.Params) =
+    Zygote.withgradient(f, x)
 
 struct ZygoteExplicitBackend{T} <: AD.AbstractReverseMode
     core_backend::T
@@ -21,3 +23,5 @@ AD.@primitive pullback_function(ad::ZygoteExplicitBackend, f, xs...) =
 # this is a hack to get around
 # https://github.com/JuliaDiff/AbstractDifferentiation.jl/issues/63#issuecomment-1225959150
 AD.gradient(::ZygoteExplicitBackend, f, xs...) = Zygote.gradient(f, xs...)
+AD.value_and_gradient(::ZygoteExplicitBackend, f, xs...) =
+    Zygote.withgradient(f, xs...)
diff --git a/src/optimise/train.jl b/src/optimise/train.jl
index 84b43f7a2e..38edcde9e3 100644
--- a/src/optimise/train.jl
+++ b/src/optimise/train.jl
@@ -94,7 +94,7 @@ _gradient_only(x) = error("Expected gradient w.r.t. single argument (or Zygote.G
 """
     train!(loss, pars::Params, data, opt::AbstractOptimiser; [cb])
-    
+
 Uses a `loss` function and training `data` to improve the model's
 parameters according to a particular optimisation rule `opt`.
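
With patches 4 and 5 applied, `update!` is the single entry point for applying gradients (returning the new state and model, Optimisers.jl style), and both Zygote backends gain `AD.value_and_gradient` overloads built on `Zygote.withgradient`, so the training loop obtains the loss value and the gradients in a single backward pass. A minimal sketch of the call the loop now makes (toy model and input, not part of the patch):

    using Flux
    using AbstractDifferentiation
    using Flux.Optimise: ZygoteExplicitBackend

    m = Dense(3 => 2)            # toy model, for illustration
    x = rand(Float32, 3)

    # one pass returns both the loss value and a tuple of gradients (one per argument)
    val, grads = AD.value_and_gradient(ZygoteExplicitBackend(), model -> sum(model(x)), m)
    # grads is a 1-tuple here; train! unwraps it with _gradient_only before calling update!
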
From 37c97590285552c1b8aab720a014cbed3cf7c10d Mon Sep 17 00:00:00 2001
From: Kyle Daruwalla
Date: Tue, 1 Nov 2022 13:45:14 -0500
Subject: [PATCH 6/7] Add tests for AD backends

---
 Project.toml     |  3 ++-
 test/optimise.jl | 32 +++++++++++++++++++++++++++++++-
 test/runtests.jl |  2 ++
 3 files changed, 35 insertions(+), 2 deletions(-)

diff --git a/Project.toml b/Project.toml
index aebae8dab1..213cdefed7 100644
--- a/Project.toml
+++ b/Project.toml
@@ -53,6 +53,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 
 [targets]
-test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]
+test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "Tracker"]
diff --git a/test/optimise.jl b/test/optimise.jl
index 41de5a4a10..b59f67fd13 100644
--- a/test/optimise.jl
+++ b/test/optimise.jl
@@ -1,5 +1,5 @@
 using Flux.Optimise
-using Flux.Optimise: runall
+using Flux.Optimise: runall, ZygoteImplicitBackend, ZygoteExplicitBackend
 using Flux: Params, gradient
 import FillArrays, ComponentArrays
 using Test
@@ -45,6 +45,36 @@ end
   end
 end
 
+@testset "AD backends" begin
+  # this is hack to make Tracker work
+  AD.gradient(::AD.TrackerBackend, f, xs...) = Tracker.withgradient(f, xs...).grad
+  AD.value_and_gradient(::AD.TrackerBackend, f, xs...) = Tracker.withgradient(f, xs...)
+
+  function _loss_and_model(::ZygoteImplicitBackend, loss, model)
+    return () -> loss(model), Flux.params(model)
+  end
+  _loss_and_model(ad, loss, model) = loss, model
+
+  function _check_gradient(::ZygoteImplicitBackend, model, grad)
+    return grad[model[1].weight] == 2 .* Flux.ones32(5, 10) &&
+           grad[model[2].weight] == 10 .* Flux.ones32(2, 5)
+  end
+  function _check_gradient(ad, model, grad)
+    return grad[1].layers[1].weight == 2 .* Flux.ones32(5, 10) &&
+           grad[1].layers[2].weight == 10 .* Flux.ones32(2, 5)
+  end
+
+  @testset for ad in [ZygoteImplicitBackend(), ZygoteExplicitBackend(), AD.TrackerBackend()]
+    model = Chain(Dense(Flux.ones32(5, 10), false), Dense(Flux.ones32(2, 5), false))
+    x = Flux.ones32(10)
+    _loss, _model = _loss_and_model(ad, m -> sum(m(x)), model)
+    val, grad = AD.value_and_gradient(ad, _loss, _model)
+    @test val == sum(model(x))
+    @test _check_gradient(ad, model, grad)
+    @test _check_gradient(ad, model, AD.gradient(ad, _loss, _model))
+  end
+end
+
 @testset "Training Loop" begin
   i = 0
   l = 1
diff --git a/test/runtests.jl b/test/runtests.jl
index 9027b114fc..ed04582b32 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -5,7 +5,9 @@ using Flux: params
 using Test
 using Random, Statistics, LinearAlgebra
 using IterTools: ncycle
+import Tracker
 using Zygote
+using AbstractDifferentiation
 using CUDA
 
 Random.seed!(0)
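
The expected values in the new "AD backends" testset follow from a short hand calculation (a sketch of the arithmetic, not part of the patch): with all-ones weights and an all-ones input, each of the five first-layer units outputs 10 (a row of the first weight matrix sums ten ones), so the loss sum(model(x)) is 2 × 50 = 100; the gradient w.r.t. the second weight matrix is the first layer's output, 10 everywhere, and the gradient w.r.t. the first weight matrix is 2 everywhere, since every hidden unit feeds both outputs with unit weight and every input is 1.

    using Flux
    model = Chain(Dense(Flux.ones32(5, 10), false), Dense(Flux.ones32(2, 5), false))
    x = Flux.ones32(10)
    model[1](x)    # 10 .* ones32(5): each row of the first weight matrix sums ten ones
    sum(model(x))  # 100.0f0: each of the two outputs sums five tens
    # d(sum)/dW2[i, j] = (model[1](x))[j] = 10;  d(sum)/dW1[i, j] = sum(W2[:, i]) * x[j] = 2
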
From 800b28fe407a6eacaffaf4298cfa46f8f72e9c4b Mon Sep 17 00:00:00 2001
From: Kyle Daruwalla
Date: Tue, 1 Nov 2022 17:22:44 -0500
Subject: [PATCH 7/7] Fix typo after rebase

---
 src/optimise/train.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/optimise/train.jl b/src/optimise/train.jl
index 38edcde9e3..2ea422e035 100644
--- a/src/optimise/train.jl
+++ b/src/optimise/train.jl
@@ -139,7 +139,7 @@ function train!(loss, ad::AD.AbstractBackend, model, data, optstate; cb = () ->
   @withprogress for (i, d) in enumerate(data)
     try
       _loss = _build_loss(ad, loss, batchmemaybe(d))
-      l, gs = AD.valud_and_gradient(ad, _loss, model)
+      l, gs = AD.value_and_gradient(ad, _loss, model)
       if !isfinite(l)
         throw(DomainError("Loss is $l on data item $i, stopping training"))
       end
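
With the full series applied, an explicit-mode training loop looks roughly like the sketch below; the model, data, and hyperparameters are made up for illustration and are not part of the patch.

    using Flux, Optimisers
    using AbstractDifferentiation
    using Flux.Optimise: ZygoteExplicitBackend

    model = Chain(Dense(10 => 5, relu), Dense(5 => 2))                 # illustrative model
    data  = [(rand(Float32, 10), rand(Float32, 2)) for _ in 1:8]       # illustrative data
    loss(m, x, y) = Flux.Losses.mse(m(x), y)   # explicit mode: the model is the first argument

    optstate = Optimisers.setup(Optimisers.Adam(1f-3), model)
    optstate, model = Flux.Optimise.train!(loss, ZygoteExplicitBackend(), model, data, optstate)

Calling `train!(loss, model, data, optstate)` without a backend falls back to `ZygoteImplicitBackend()`, in which case the loss takes only the data, and the "model" argument is expected to be a `Zygote.Params` collection paired with one of the existing `AbstractOptimiser`s.
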