From d6236611d6947ed16acabb19309bd34722ebfa67 Mon Sep 17 00:00:00 2001 From: Joe Greener Date: Mon, 30 Oct 2023 23:46:35 +0000 Subject: [PATCH] Move A \ B rule test --- test/internal_rules.jl | 22 ++++++++++++++++++++++ test/runtests.jl | 22 ---------------------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 6698dbf81b5..6427715c0de 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -29,4 +29,26 @@ using Test @test autodiff(Reverse, f2, Active, Active(2.0))[1][1] == -3 end +@testset "Linear Solve" begin + A = Float64[2 3; 5 7] + dA = zero(A) + b = Float64[11, 13] + db = zero(b) + + forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Duplicated{typeof(b)}) + + tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Duplicated(b, db)) + + dy = Float64[17, 19] + copyto!(shadow, dy) + + pullback(Const(\), Duplicated(A, dA), Duplicated(b, db), tape) + + z = transpose(A) \ dy + + y = A \ b + @test dA ≈ (-z * transpose(y)) + @test db ≈ z +end + end # InternalRules diff --git a/test/runtests.jl b/test/runtests.jl index d4965665387..b31411d4ac2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2490,25 +2490,3 @@ end @test autodiff(Forward, f9, Duplicated(2.0, 1.0))[1] == 1.2 end end - -@testset "Linear Solve" begin - A = Float64[2 3; 5 7] - dA = zero(A) - b = Float64[11, 13] - db = zero(b) - - forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Duplicated{typeof(b)}) - - tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Duplicated(b, db)) - - dy = Float64[17, 19] - copyto!(shadow, dy) - - pullback(Const(\), Duplicated(A, dA), Duplicated(b, db), tape) - - z = transpose(A) \ dy - - y = A \ b - @test dA ≈ (-z * transpose(y)) - @test db ≈ z -end