Skip to content

Commit

Permalink
Merge pull request #660 from JuliaDiff/ox/subtract
Browse files Browse the repository at this point in the history
Implement Tangent subtraction
  • Loading branch information
oxinabox authored Feb 7, 2024
2 parents d2b4b94 + c7e00c7 commit 1e3d426
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/tangent_arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ Base.:+(a::Dict, d::Tangent{P}) where {P} = merge(+, a, backing(d))
Base.:+(a::StructuralTangent{P}, b::P) where {P} = b + a

Base.:-(tangent::StructuralTangent{P}) where {P} = map(-, tangent)
Base.:-(a::StructuralTangent{P}, b::StructuralTangent{P}) where {P} = a + (-b)

# We intentionally do not define, `Base.*(::Tangent, ::Tangent)` as that is not meaningful
# In general one doesn't have to represent multiplications of 2 tangents
Expand Down
18 changes: 18 additions & 0 deletions test/tangent_types/structural_tangent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,24 @@ end
@test -1.0 * t == -t
end

@testset "subtraction" begin
a = Tangent{Foo}(; x=2.0, y=-2.0)
b = Tangent{Foo}(; x=1.0, y=2.0)
@test (a - b) == Tangent{Foo}(; x=1.0, y=-4.0)

a = Tangent{Foo}(; x=2.0, y=-2.0)
b = Tangent{Foo}(; x=1.0)
@test (a - b) == Tangent{Foo}(; x=1.0, y=-2.0)

a = Tangent{Tuple{Float64,Float64}}(2.0, 3.0)
b = Tangent{Tuple{Float64,Float64}}(1.0, 1.0)
@test (a - b) == Tangent{Tuple{Float64,Float64}}(1.0, 2.0)

a = MutableTangent{MFoo}(; x=1.5, y=1.5)
b = MutableTangent{MFoo}(; x=0.5, y=0.5)
@test (a - b) == MutableTangent{MFoo}(; x=1.0, y=1.0)
end

@testset "scaling" begin
@test (
2 * Tangent{Foo}(; y=1.5, x=2.5) ==
Expand Down

0 comments on commit 1e3d426

Please sign in to comment.