Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
perrutquist committed Nov 10, 2023
2 parents ba89901 + 87f4996 commit ea25c11
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.57.0"
version = "1.58.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
6 changes: 3 additions & 3 deletions src/rulesets/Base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R

function backslash_pullback(ȳ)
= unthunk(ȳ)

Ȳf =
@static if VERSION >= v"1.9"
# Need to ensure Ȳ is an array since since https://github.com/JuliaLang/julia/pull/44358
Expand All @@ -360,7 +360,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R
end
end
Yf = Y
@static if VERSION >= v"1.9"
@static if VERSION >= v"1.9"
# Need to ensure Yf is an array since since https://github.com/JuliaLang/julia/pull/44358
if !isa(Y, AbstractArray)
Yf = [Y]
Expand All @@ -371,7 +371,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R
= A' \ Ȳf
= -* Y'
t = (B - A * Y) *'
@static if VERSION >= v"1.9"
@static if VERSION >= v"1.9"
# Need to ensure t is an array since since https://github.com/JuliaLang/julia/pull/44358
if !isa(t, AbstractArray)
t = [t]
Expand Down
1 change: 1 addition & 0 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ end

@scalar_rule fma(x, y, z) (y, x, true)
@scalar_rule muladd(x, y, z) (y, x, true)
@scalar_rule muladd(x::Union{Number, ZeroTangent}, y::Union{Number, ZeroTangent}, z::Union{Number, ZeroTangent}) (y, x, true)
@scalar_rule rem2pi(x, r::RoundingMode) (true, NoTangent())
@scalar_rule(
mod(x, y),
Expand Down
10 changes: 10 additions & 0 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,16 @@
test_rrule(muladd, 10randn(), randn(), randn())
end

@testset "muladd ZeroTangent" begin
test_frule(muladd, 2.0, 3.0, ZeroTangent())
test_frule(muladd, 2.0, ZeroTangent(), 4.0)
test_frule(muladd, ZeroTangent(), 3.0, 4.0)

test_rrule(muladd, 2.0, 3.0, ZeroTangent())
test_rrule(muladd, 2.0, ZeroTangent(), 4.0)
test_rrule(muladd, ZeroTangent(), 3.0, 4.0)
end

@testset "fma" begin
test_frule(fma, 10randn(), randn(), randn())
test_rrule(fma, 10randn(), randn(), randn())
Expand Down

0 comments on commit ea25c11

Please sign in to comment.