diff --git a/src/tangent_arithmetic.jl b/src/tangent_arithmetic.jl index 18ae7b3ad..f79957311 100644 --- a/src/tangent_arithmetic.jl +++ b/src/tangent_arithmetic.jl @@ -146,6 +146,7 @@ Base.:+(a::Dict, d::Tangent{P}) where {P} = merge(+, a, backing(d)) Base.:+(a::StructuralTangent{P}, b::P) where {P} = b + a Base.:-(tangent::StructuralTangent{P}) where {P} = map(-, tangent) +Base.:-(a::StructuralTangent{P}, b::StructuralTangent{P}) where {P} = a + (-b) # We intentionally do not define, `Base.*(::Tangent, ::Tangent)` as that is not meaningful # In general one doesn't have to represent multiplications of 2 tangents diff --git a/test/tangent_types/structural_tangent.jl b/test/tangent_types/structural_tangent.jl index c177b05f4..d93487703 100644 --- a/test/tangent_types/structural_tangent.jl +++ b/test/tangent_types/structural_tangent.jl @@ -358,6 +358,24 @@ end @test -1.0 * t == -t end + @testset "subtraction" begin + a = Tangent{Foo}(; x=2.0, y=-2.0) + b = Tangent{Foo}(; x=1.0, y=2.0) + @test (a - b) == Tangent{Foo}(; x=1.0, y=-4.0) + + a = Tangent{Foo}(; x=2.0, y=-2.0) + b = Tangent{Foo}(; x=1.0) + @test (a - b) == Tangent{Foo}(; x=1.0, y=-2.0) + + a = Tangent{Tuple{Float64,Float64}}(2.0, 3.0) + b = Tangent{Tuple{Float64,Float64}}(1.0, 1.0) + @test (a - b) == Tangent{Tuple{Float64,Float64}}(1.0, 2.0) + + a = MutableTangent{MFoo}(; x=1.5, y=1.5) + b = MutableTangent{MFoo}(; x=0.5, y=0.5) + @test (a - b) == MutableTangent{MFoo}(; x=1.0, y=1.0) + end + @testset "scaling" begin @test ( 2 * Tangent{Foo}(; y=1.5, x=2.5) ==