diff --git a/Project.toml b/Project.toml index 31d77496a..82f4d6616 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.57.0" +version = "1.58.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 7fbf46062..078bb602a 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -351,7 +351,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R function backslash_pullback(ȳ) Ȳ = unthunk(ȳ) - + Ȳf = Ȳ @static if VERSION >= v"1.9" # Need to ensure Ȳ is an array since since https://github.com/JuliaLang/julia/pull/44358 @@ -360,7 +360,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R end end Yf = Y - @static if VERSION >= v"1.9" + @static if VERSION >= v"1.9" # Need to ensure Yf is an array since since https://github.com/JuliaLang/julia/pull/44358 if !isa(Y, AbstractArray) Yf = [Y] @@ -371,7 +371,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R B̄ = A' \ Ȳf Ā = -B̄ * Y' t = (B - A * Y) * B̄' - @static if VERSION >= v"1.9" + @static if VERSION >= v"1.9" # Need to ensure t is an array since since https://github.com/JuliaLang/julia/pull/44358 if !isa(t, AbstractArray) t = [t] diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 9576abd98..28cc11d19 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -94,6 +94,7 @@ end @scalar_rule fma(x, y, z) (y, x, true) @scalar_rule muladd(x, y, z) (y, x, true) +@scalar_rule muladd(x::Union{Number, ZeroTangent}, y::Union{Number, ZeroTangent}, z::Union{Number, ZeroTangent}) (y, x, true) @scalar_rule rem2pi(x, r::RoundingMode) (true, NoTangent()) @scalar_rule( mod(x, y), diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 36452da1e..9a5278747 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -153,6 +153,16 @@ test_rrule(muladd, 10randn(), randn(), randn()) end + @testset "muladd ZeroTangent" begin + test_frule(muladd, 2.0, 3.0, ZeroTangent()) + test_frule(muladd, 2.0, ZeroTangent(), 4.0) + test_frule(muladd, ZeroTangent(), 3.0, 4.0) + + test_rrule(muladd, 2.0, 3.0, ZeroTangent()) + test_rrule(muladd, 2.0, ZeroTangent(), 4.0) + test_rrule(muladd, ZeroTangent(), 3.0, 4.0) + end + @testset "fma" begin test_frule(fma, 10randn(), randn(), randn()) test_rrule(fma, 10randn(), randn(), randn())