From e05886a9a64d2f862332e24023a60c6a8a762615 Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Sun, 9 Jul 2023 01:14:51 +0000 Subject: [PATCH 1/4] Don't use the array muladd rule for ZeroTangent --- src/rulesets/Base/arraymath.jl | 6 +++--- src/rulesets/Base/base.jl | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 7fbf46062..078bb602a 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -351,7 +351,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R function backslash_pullback(ȳ) Ȳ = unthunk(ȳ) - + Ȳf = Ȳ @static if VERSION >= v"1.9" # Need to ensure Ȳ is an array since since https://github.com/JuliaLang/julia/pull/44358 @@ -360,7 +360,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R end end Yf = Y - @static if VERSION >= v"1.9" + @static if VERSION >= v"1.9" # Need to ensure Yf is an array since since https://github.com/JuliaLang/julia/pull/44358 if !isa(Y, AbstractArray) Yf = [Y] @@ -371,7 +371,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R B̄ = A' \ Ȳf Ā = -B̄ * Y' t = (B - A * Y) * B̄' - @static if VERSION >= v"1.9" + @static if VERSION >= v"1.9" # Need to ensure t is an array since since https://github.com/JuliaLang/julia/pull/44358 if !isa(t, AbstractArray) t = [t] diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 9576abd98..28cc11d19 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -94,6 +94,7 @@ end @scalar_rule fma(x, y, z) (y, x, true) @scalar_rule muladd(x, y, z) (y, x, true) +@scalar_rule muladd(x::Union{Number, ZeroTangent}, y::Union{Number, ZeroTangent}, z::Union{Number, ZeroTangent}) (y, x, true) @scalar_rule rem2pi(x, r::RoundingMode) (true, NoTangent()) @scalar_rule( mod(x, y), From 47479f458ec798f824a4dc1806f196c39b2abd2c Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 10 Nov 2023 14:18:51 +0800 Subject: [PATCH 2/4] test muladd mixing numbers and zerotangents --- test/rulesets/Base/base.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 36452da1e..7d06b1421 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -153,6 +153,17 @@ test_rrule(muladd, 10randn(), randn(), randn()) end + @testset "muladd ZeroTangent" begin + test_frule(muladd, 2.0, 3.0, ZeroTangent()) + test_frule(muladd, 2.0, ZeroTangent(), 4.0) + test_frule(muladd, ZeroTangent(), 3.0, 4.0) + + test_rrule(muladd, 2.0, 3.0, ZeroTangent()) + test_rrule(muladd, 2.0, ZeroTangent(), 4.0) + test_rrule(muladd, ZeroTangent(), 3.0, 4.0) + end + + @testset "fma" begin test_frule(fma, 10randn(), randn(), randn()) test_rrule(fma, 10randn(), randn(), randn()) From 81c6e8ca011c680aaefdc1396e15bb10553e2116 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 10 Nov 2023 14:44:52 +0800 Subject: [PATCH 3/4] style Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/rulesets/Base/base.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 7d06b1421..9a5278747 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -156,14 +156,13 @@ @testset "muladd ZeroTangent" begin test_frule(muladd, 2.0, 3.0, ZeroTangent()) test_frule(muladd, 2.0, ZeroTangent(), 4.0) - test_frule(muladd, ZeroTangent(), 3.0, 4.0) + test_frule(muladd, ZeroTangent(), 3.0, 4.0) test_rrule(muladd, 2.0, 3.0, ZeroTangent()) test_rrule(muladd, 2.0, ZeroTangent(), 4.0) - test_rrule(muladd, ZeroTangent(), 3.0, 4.0) + test_rrule(muladd, ZeroTangent(), 3.0, 4.0) end - @testset "fma" begin test_frule(fma, 10randn(), randn(), randn()) test_rrule(fma, 10randn(), randn(), randn()) From 87f49961e57b368e3f0afaa697f64a3f98f638c0 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 10 Nov 2023 16:04:58 +0800 Subject: [PATCH 4/4] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 31d77496a..82f4d6616 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.57.0" +version = "1.58.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"