Skip to content

Commit

Permalink
Add call_at_end and save_positions to callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
ph-kev committed Jan 3, 2025
1 parent 95e6109 commit c130e6f
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 7 deletions.
38 changes: 32 additions & 6 deletions src/Callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,13 @@ Trigger `f!(integrator)` every `Δt` simulation time.
If `atinit=true`, then `f!` will additionally be triggered at initialization. Otherwise
the first trigger will be after `Δt` simulation time.
If `call_at_end==true`, then `f!` will be triggered at the end of the time span. Otherwise
there is no call to `f!` at the end of the time span.
The tuple `save_positions` determines whether to save before or after `f!`.
"""
function EveryXSimulationTime(f!, Δt; atinit = false)
function EveryXSimulationTime(f!, Δt; atinit = false, call_at_end = false, save_positions = (true, true))
t_next = zero(Δt)

function _initialize(c, u, t, integrator)
Expand All @@ -111,14 +116,22 @@ function EveryXSimulationTime(f!, Δt; atinit = false)
t_next += Δt
end
return true
elseif (call_at_end && t == integrator.sol.prob.tspan[2])
return true
else
return false
end
end
if isdefined(DiffEqBase, :finalize!)
SciMLBase.DiscreteCallback(condition, f!; initialize = _initialize, finalize = _finalize)
SciMLBase.DiscreteCallback(
condition,
f!;
initialize = _initialize,
finalize = _finalize,
save_positions = save_positions,
)
else
SciMLBase.DiscreteCallback(condition, f!; initialize = _initialize)
SciMLBase.DiscreteCallback(condition, f!; initialize = _initialize, save_positions = save_positions)
end
end

Expand All @@ -131,8 +144,13 @@ Trigger `f!(integrator)` every `Δsteps` simulation steps.
If `atinit==true`, then `f!` will additionally be triggered at initialization. Otherwise
the first trigger will be after `Δsteps`.
If `call_at_end==true`, then `f!` will be triggered at the end of the time span. Otherwise
there is no call to `f!` at the end of the time span.
The tuple `save_positions` determines whether to save before or after `f!`.
"""
function EveryXSimulationSteps(f!, Δsteps; atinit = false)
function EveryXSimulationSteps(f!, Δsteps; atinit = false, call_at_end = false, save_positions = (true, true))
steps = 0
steps_next = 0

Expand All @@ -154,15 +172,23 @@ function EveryXSimulationSteps(f!, Δsteps; atinit = false)
if steps >= steps_next
steps_next += Δsteps
return true
elseif (call_at_end && t == integrator.sol.prob.tspan[2])
return true
else
return false
end
end

if isdefined(DiffEqBase, :finalize!)
SciMLBase.DiscreteCallback(condition, f!; initialize = _initialize, finalize = _finalize)
SciMLBase.DiscreteCallback(
condition,
f!;
initialize = _initialize,
finalize = _finalize,
save_positions = save_positions,
)
else
SciMLBase.DiscreteCallback(condition, f!; initialize = _initialize)
SciMLBase.DiscreteCallback(condition, f!; initialize = _initialize, save_positions = save_positions)
end
end

Expand Down
17 changes: 16 additions & 1 deletion test/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ mutable struct MyCallback
initialized::Bool
calls::Int
finalized::Bool
last_t::Real
end
MyCallback() = MyCallback(false, 0, false)
MyCallback() = MyCallback(false, 0, false, -1.0)

function Callbacks.initialize!(cb::MyCallback, integrator)
cb.initialized = true
Expand All @@ -25,13 +26,18 @@ function Callbacks.finalize!(cb::MyCallback, integrator)
end
function (cb::MyCallback)(integrator)
cb.calls += 1
cb.last_t = integrator.t
end

cb1 = MyCallback()
cb2 = MyCallback()
cb3 = MyCallback()
cb4 = MyCallback()
cb5 = MyCallback()
cb6 = MyCallback()
cb7 = MyCallback()
cb8 = MyCallback()
cb9 = MyCallback()

cbs = CallbackSet(
EveryXSimulationTime(cb1, 1 / 4),
Expand All @@ -40,6 +46,10 @@ cbs = CallbackSet(
EveryXSimulationSteps(cb4, 4, atinit = true),
EveryXSimulationSteps(_ -> sleep(1 / 32), 1),
EveryXWallTimeSeconds(cb5, 0.49, comm_ctx),
EveryXSimulationTime(cb6, 0.49, call_at_end = true),
EveryXSimulationSteps(cb7, 3, call_at_end = true),
EveryXSimulationTime(cb8, 0.3, call_at_end = false),
EveryXSimulationSteps(cb9, 3, call_at_end = false),
)

const_prob_inc = ODEProblem(
Expand All @@ -63,6 +73,11 @@ solve(const_prob_inc, LSRKEulerMethod(), dt = 1 / 32, callback = cbs)
@test cb4.calls == 9
@test cb5.calls >= 2

@test cb6.last_t == 1.0
@test cb7.last_t == 1.0
@test cb8.last_t == (1 / 32) * 29
@test cb9.last_t == (1 / 32) * 30

if isdefined(DiffEqBase, :finalize!)

@test cb1.finalized
Expand Down

0 comments on commit c130e6f

Please sign in to comment.