From 51c0e470f8a74e62a626d65874a415c59bf653b4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 6 Dec 2024 16:56:23 +0530 Subject: [PATCH] chore: bump minimum Reactant version (#1125) * chore: bump minimum Reactant version * fix: manually `set_abi` for reactant --- Project.toml | 2 +- docs/Project.toml | 2 +- ext/LuxReactantExt/training.jl | 6 ++++-- test/reactant/layer_tests.jl | 6 +++--- test/runtests.jl | 3 +-- 5 files changed, 10 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index d6a5b4ccd8..2fc315c8a3 100644 --- a/Project.toml +++ b/Project.toml @@ -98,7 +98,7 @@ NNlib = "0.9.24" Optimisers = "0.4.1" Preferences = "1.4.3" Random = "1.10" -Reactant = "0.2.6" +Reactant = "0.2.8" Reexport = "1.2.2" ReverseDiff = "1.15" SIMDTypes = "0.1" diff --git a/docs/Project.toml b/docs/Project.toml index 2294d0bfde..3eb44b24ef 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -56,7 +56,7 @@ Optimisers = "0.4.1" Pkg = "1.10" Printf = "1.10" Random = "1.10" -Reactant = "0.2.6" +Reactant = "0.2.8" StableRNGs = "1" StaticArrays = "1" WeightInitializers = "1" diff --git a/ext/LuxReactantExt/training.jl b/ext/LuxReactantExt/training.jl index 182ca9c86d..605093e1ea 100644 --- a/ext/LuxReactantExt/training.jl +++ b/ext/LuxReactantExt/training.jl @@ -74,7 +74,8 @@ function compute_gradients_internal_and_step(objective_function::F, model, data, st, opt_state) where {F} dps = Enzyme.make_zero(ps) _, (loss, stₙ, stats) = Enzyme.autodiff( - Enzyme.ReverseWithPrimal, Const(objective_function), Active, Const(model), + Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI), + Const(objective_function), Active, Const(model), Duplicated(ps, dps), Const(st), Const(data)) opt_state, ps = Optimisers.update(opt_state, ps, dps) return dps, ps, loss, stats, stₙ, opt_state @@ -84,7 +85,8 @@ function compute_gradients_internal_and_step!(objective_function::F, model, data st, opt_state) where {F} dps = Enzyme.make_zero(ps) _, (loss, stₙ, stats) = Enzyme.autodiff( - Enzyme.ReverseWithPrimal, Const(objective_function), Active, Const(model), + Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI), + Const(objective_function), Active, Const(model), Duplicated(ps, dps), Const(st), Const(data)) # XXX: Inplace updates not actually inplace opt_state, ps = Optimisers.update!(opt_state, ps, dps) diff --git a/test/reactant/layer_tests.jl b/test/reactant/layer_tests.jl index bb84fe59d9..8130691cb7 100644 --- a/test/reactant/layer_tests.jl +++ b/test/reactant/layer_tests.jl @@ -52,13 +52,13 @@ end y_ra, _ = @jit model(x_ra, ps_ra, st_ra) y, _ = model(x, ps, st) - @test y_ra≈y atol=1e-3 rtol=1e-3 + @test y_ra≈y atol=1e-2 rtol=1e-2 @testset "gradient" begin ∂x, ∂ps = ∇sumabs2_zygote(model, x, ps, st) ∂x_ra, ∂ps_ra = @jit ∇sumabs2_enzyme(model, x_ra, ps_ra, st_ra) - @test ∂x_ra≈∂x atol=1e-3 rtol=1e-3 - @test check_approx(∂ps_ra, ∂ps; atol=1e-3, rtol=1e-3) + @test ∂x_ra≈∂x atol=1e-2 rtol=1e-2 + @test check_approx(∂ps_ra, ∂ps; atol=1e-2, rtol=1e-2) end end end diff --git a/test/runtests.jl b/test/runtests.jl index 6837b9ae00..0f96e8b49f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -135,8 +135,7 @@ const RETESTITEMS_NWORKER_THREADS = parse( ReTestItems.runtests(Lux; tags=(tag == "all" ? nothing : [Symbol(tag)]), testitem_timeout=2400, - nworkers, nworker_threads=RETESTITEMS_NWORKER_THREADS, - retries=tag == "reactant" ? 2 : 0 + nworkers, nworker_threads=RETESTITEMS_NWORKER_THREADS ) end end