Skip to content

Commit

Permalink
OOP tests
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Nov 30, 2024
1 parent eaaba53 commit 9935a2e
Showing 1 changed file with 44 additions and 11 deletions.
55 changes: 44 additions & 11 deletions test/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,14 +218,11 @@ _, easy_res = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
_, easy_res2 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = QuadratureAdjoint(abstol = 1e-14,
reltol = 1e-14))
sensealg = QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14))
_, easy_res22 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = QuadratureAdjoint(autojacvec = false,
abstol = 1e-14,
reltol = 1e-14))
sensealg = QuadratureAdjoint(autojacvec = false, abstol = 1e-14, reltol = 1e-14))
_, easy_res2 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
Expand All @@ -239,17 +236,15 @@ _, easy_res3 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
@test easy_res32 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = InterpolatingAdjoint(autojacvec = false))[1] isa
AbstractArray
sensealg = InterpolatingAdjoint(autojacvec = false))[1] isa AbstractArray
_, easy_res4 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = BacksolveAdjoint())
@test easy_res42 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = BacksolveAdjoint(autojacvec = false))[1] isa
AbstractArray
sensealg = BacksolveAdjoint(autojacvec = false))[1] isa AbstractArray
_, easy_res5 = adjoint_sensitivities(soloop,
Kvaerno5(nlsolve = NLAnderson(), smooth_est = false),
t = t, dgdu_discrete = dg, abstol = 1e-12,
Expand All @@ -263,8 +258,7 @@ _, easy_res6 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, dgdu_discre
_, easy_res62 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t,
dgdu_discrete = dg, abstol = 1e-14,
reltol = 1e-14,
sensealg = InterpolatingAdjoint(checkpointing = true,
autojacvec = false),
sensealg = InterpolatingAdjoint(checkpointing = true, autojacvec = false),
checkpoints = soloop_nodense.t[1:5:end])

_, easy_res8 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, dgdu_discrete = dg,
Expand Down Expand Up @@ -304,6 +298,39 @@ _, easy_res123 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, dgdu_disc
reltol = 1e-14,
sensealg = GaussAdjoint(checkpointing = true),
checkpoints = soloop_nodense.t[1:5:end])

_, easy_res2_mc_quad = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = QuadratureAdjoint(
abstol = 1e-14, reltol = 1e-14, autojacvec = SciMLSensitivity.MooncakeVJP()))
_, easy_res2_mc_interp = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.MooncakeVJP()))
_, easy_res2_mc_back = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = BacksolveAdjoint(autojacvec = SciMLSensitivity.MooncakeVJP()))
_, easy_res6_mc_quad = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t,
dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = QuadratureAdjoint(
abstol = 1e-14, reltol = 1e-14, autojacvec = SciMLSensitivity.MooncakeVJP()))
_, easy_res6_mc_interp = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t,
dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = InterpolatingAdjoint(checkpointing = true,
autojacvec = SciMLSensitivity.MooncakeVJP()),
checkpoints = soloop_nodense.t[1:5:end])
_, easy_res6_mc_back = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t,
dgdu_discrete = dg,
abstol = 1e-14,
reltol = 1e-14,
sensealg = BacksolveAdjoint(autojacvec = SciMLSensitivity.MooncakeVJP()))

@test isapprox(res, easy_res, rtol = 1e-10)
@test isapprox(res, easy_res2, rtol = 1e-10)
@test isapprox(res, easy_res22, rtol = 1e-10)
Expand All @@ -324,6 +351,12 @@ _, easy_res123 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, dgdu_disc
@test isapprox(res, easy_res12, rtol = 1e-9)
@test isapprox(res, easy_res122, rtol = 1e-9)
@test isapprox(res, easy_res123, rtol = 1e-4)
@test isapprox(res, easy_res2_mc_quad, rtol=1e-9)
@test isapprox(res, easy_res2_mc_interp, rtol=1e-9)
@test isapprox(res, easy_res2_mc_back, rtol=1e-9)
@test isapprox(res, easy_res6_mc_quad, rtol=1e-4)
@test isapprox(res, easy_res6_mc_interp, rtol=1e-9)
@test isapprox(res, easy_res6_mc_back, rtol=1e-9)

println("Calculate adjoint sensitivities ")

Expand Down

0 comments on commit 9935a2e

Please sign in to comment.