Skip to content

Commit

Permalink
Rename post_explicit! to cache! and post_implicit! to cache_imp!
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisYatunin committed Jan 24, 2025
1 parent 96b7067 commit be13e01
Show file tree
Hide file tree
Showing 11 changed files with 102 additions and 103 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ClimaTimeSteppers"
uuid = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
authors = ["Climate Modeling Alliance"]
version = "0.7.40"
version = "0.8.0"

[deps]
ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
Expand Down
1 change: 1 addition & 0 deletions docs/src/api/ode_solvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ CurrentModule = ClimaTimeSteppers
## Interface

```@docs
ClimaODEFunction
AbstractAlgorithmConstraint
Unconstrained
SSP
Expand Down
19 changes: 9 additions & 10 deletions ext/ClimaTimeSteppersBenchmarkToolsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ n_calls_per_step(::CTS.ARS343, max_newton_iters) = Dict(
"T_exp_T_lim!" => 4,
"lim!" => 4,
"dss!" => 4,
"post_explicit!" => 3,
"post_implicit!" => 4,
"cache!" => 3,
"cache_imp!" => 4,
"step!" => 1,
)
function n_calls_per_step(alg::CTS.RosenbrockAlgorithm)
Expand All @@ -47,8 +47,8 @@ function n_calls_per_step(alg::CTS.RosenbrockAlgorithm)
"T_exp_T_lim!" => CTS.n_stages(alg.tableau),
"lim!" => 0,
"dss!" => CTS.n_stages(alg.tableau),
"post_explicit!" => 0,
"post_implicit!" => CTS.n_stages(alg.tableau),
"cache!" => 0,
"cache_imp!" => CTS.n_stages(alg.tableau),
"step!" => 1,
)
end
Expand All @@ -59,8 +59,7 @@ function maybe_push!(trials₀, name, f!, args, kwargs, only)
end
end

const allowed_names =
["Wfact", "ldiv!", "T_imp!", "T_exp_T_lim!", "lim!", "dss!", "post_explicit!", "post_implicit!", "step!"]
const allowed_names = ["Wfact", "ldiv!", "T_imp!", "T_exp_T_lim!", "lim!", "dss!", "cache!", "cache_imp!", "step!"]

"""
benchmark_step(
Expand Down Expand Up @@ -89,8 +88,8 @@ Benchmark a DistributedODEIntegrator given:
- "T_exp_T_lim!"
- "lim!"
- "dss!"
- "post_explicit!"
- "post_implicit!"
- "cache!"
- "cache_imp!"
- "step!"
"""
function CTS.benchmark_step(
Expand Down Expand Up @@ -123,8 +122,8 @@ function CTS.benchmark_step(
maybe_push!(trials₀, "T_exp_T_lim!", remaining_fun(integrator), remaining_args(integrator), kwargs, only)
maybe_push!(trials₀, "lim!", f.lim!, (Xlim, p, t, u), kwargs, only)
maybe_push!(trials₀, "dss!", f.dss!, (u, p, t), kwargs, only)
maybe_push!(trials₀, "post_explicit!", f.post_explicit!, (u, p, t), kwargs, only)
maybe_push!(trials₀, "post_implicit!", f.post_implicit!, (u, p, t), kwargs, only)
maybe_push!(trials₀, "cache!", f.cache!, (u, p, t), kwargs, only)
maybe_push!(trials₀, "cache_imp!", f.cache_imp!, (u, p, t), kwargs, only)
maybe_push!(trials₀, "step!", SciMLBase.step!, (integrator, ), kwargs, only)
#! format: on

Expand Down
45 changes: 35 additions & 10 deletions src/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,51 @@ export ClimaODEFunction, ForwardEulerODEFunction

abstract type AbstractClimaODEFunction <: DiffEqBase.AbstractODEFunction{true} end

struct ClimaODEFunction{TEL, TL, TE, TI, L, D, PE, PI} <: AbstractClimaODEFunction
"""
ClimaODEFunction(; T_imp!, [dss!], [cache!], [cache_imp!])
ClimaODEFunction(; T_exp!, T_lim!, [T_imp!], [lim!], [dss!], [cache!], [cache_imp!])
ClimaODEFunction(; T_exp_lim!, [T_imp!], [lim!], [dss!], [cache!], [cache_imp!])
Container for all functions used to advance through a timestep:
- `T_imp!(T_imp, u, p, t)`: sets the implicit tendency
- `T_exp!(T_exp, u, p, t)`: sets the component of the explicit tendency that
is not passed through the limiter
- `T_lim!(T_lim, u, p, t)`: sets the component of the explicit tendency that
is passed through the limiter
- `T_exp_lim!(T_exp, T_lim, u, p, t)`: fused alternative to the separate
functions `T_exp!` and `T_lim!`
- `lim!(u, p, t, u_ref)`: applies the limiter to every state `u` that has
been incremented from `u_ref` by the explicit tendency component `T_lim!`
- `dss!(u, p, t)`: applies direct stiffness summation to every state `u`,
except for intermediate states generated within the implicit solver
- `cache!(u, p, t)`: updates the cache `p` to reflect the state `u` before
the first timestep and on every subsequent timestepping stage
- `cache_imp!(u, p, t)`: updates the components of the cache `p` that are
required to evaluate `T_imp!` and its Jacobian within the implicit solver
By default, `lim!`, `dss!`, and `cache!` all do nothing, and `cache_imp!` is
identical to `cache!`. Any of the tendency functions can be set to `nothing` in
order to avoid corresponding allocations in the integrator.
"""
struct ClimaODEFunction{TEL, TL, TE, TI, L, D, C, CI} <: AbstractClimaODEFunction
T_exp_T_lim!::TEL
T_lim!::TL
T_exp!::TE
T_imp!::TI
lim!::L
dss!::D
post_explicit!::PE
post_implicit!::PI
cache!::C
cache_imp!::CI
function ClimaODEFunction(;
T_exp_T_lim! = nothing, # nothing or (uₜ_exp, uₜ_lim, u, p, t) -> ...
T_lim! = nothing, # nothing or (uₜ, u, p, t) -> ...
T_exp! = nothing, # nothing or (uₜ, u, p, t) -> ...
T_imp! = nothing, # nothing or (uₜ, u, p, t) -> ...
T_exp_T_lim! = nothing,
T_lim! = nothing,
T_exp! = nothing,
T_imp! = nothing,
lim! = (u, p, t, u_ref) -> nothing,
dss! = (u, p, t) -> nothing,
post_explicit! = (u, p, t) -> nothing,
post_implicit! = (u, p, t) -> nothing,
cache! = (u, p, t) -> nothing,
cache_imp! = cache!,
)
args = (T_exp_T_lim!, T_lim!, T_exp!, T_imp!, lim!, dss!, post_explicit!, post_implicit!)
args = (T_exp_T_lim!, T_lim!, T_exp!, T_imp!, lim!, dss!, cache!, cache_imp!)

if !isnothing(T_exp_T_lim!)
@assert isnothing(T_exp!) "`T_exp_T_lim!` was passed, `T_exp!` must be `nothing`"
Expand Down
4 changes: 2 additions & 2 deletions src/integrators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ function DiffEqBase.__init(
tdir,
)
if prob.f isa ClimaODEFunction
(; post_explicit!) = prob.f
isnothing(post_explicit!) || post_explicit!(u0, p, t0)
(; cache!) = prob.f
isnothing(cache!) || cache!(u0, p, t0)
end
DiffEqBase.initialize!(callback, u0, t0, integrator)
return integrator
Expand Down
45 changes: 13 additions & 32 deletions src/nl_solvers/newtons_method.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ struct ForwardDiffStepSize3 <: ForwardDiffStepSize end
Computes the Jacobian-vector product `j(x[n]) * Δx[n]` for a Newton-Krylov
method without directly using the Jacobian `j(x[n])`, and instead only using
`x[n]`, `f(x[n])`, and other function evaluations `f(x′)`. This is done by
calling `jvp!(::JacobianFreeJVP, cache, jΔx, Δx, x, f!, f, post_implicit!)`.
calling `jvp!(::JacobianFreeJVP, cache, jΔx, Δx, x, f!, f, prepare_for_f!)`.
The `jΔx` passed to a Jacobian-free JVP is modified in-place. The `cache` can
be obtained with `allocate_cache(::JacobianFreeJVP, x_prototype)`, where
`x_prototype` is `similar` to `x` (and also to `Δx` and `f`).
Expand All @@ -151,13 +151,13 @@ end

allocate_cache(::ForwardDiffJVP, x_prototype) = (; x2 = zero(x_prototype), f2 = zero(x_prototype))

function jvp!(alg::ForwardDiffJVP, cache, jΔx, Δx, x, f!, f, post_implicit!)
function jvp!(alg::ForwardDiffJVP, cache, jΔx, Δx, x, f!, f, prepare_for_f!)
(; default_step, step_adjustment) = alg
(; x2, f2) = cache
FT = eltype(x)
ε = FT(step_adjustment) * default_step(Δx, x)
@. x2 = x + ε * Δx
isnothing(post_implicit!) || post_implicit!(x2)
isnothing(prepare_for_f!) || prepare_for_f!(x2)
f!(f2, x2)
@. jΔx = (f2 - f) / ε
end
Expand Down Expand Up @@ -343,7 +343,7 @@ end
Finds an approximation `Δx[n] ≈ j(x[n]) \\ f(x[n])` for Newton's method such
that `‖f(x[n]) - j(x[n]) * Δx[n]‖ ≤ rtol[n] * ‖f(x[n])‖`, where `rtol[n]` is the
value of the forcing term on iteration `n`. This is done by calling
`solve_krylov!(::KrylovMethod, cache, Δx, x, f!, f, n, post_implicit!, j = nothing)`,
`solve_krylov!(::KrylovMethod, cache, Δx, x, f!, f, n, prepare_for_f!, j = nothing)`,
where `f` is `f(x[n])` and, if it is specified, `j` is either `j(x[n])` or an
approximation of `j(x[n])`. The `Δx` passed to a Krylov method is modified in-place.
The `cache` can be obtained with `allocate_cache(::KrylovMethod, x_prototype)`,
Expand Down Expand Up @@ -428,14 +428,14 @@ function allocate_cache(alg::KrylovMethod, x_prototype)
)
end

NVTX.@annotate function solve_krylov!(alg::KrylovMethod, cache, Δx, x, f!, f, n, post_implicit!, j = nothing)
NVTX.@annotate function solve_krylov!(alg::KrylovMethod, cache, Δx, x, f!, f, n, prepare_for_f!, j = nothing)
(; jacobian_free_jvp, forcing_term, solve_kwargs) = alg
(; disable_preconditioner, debugger) = alg
type = solver_type(alg)
(; jacobian_free_jvp_cache, forcing_term_cache, solver, debugger_cache) = cache
jΔx!(jΔx, Δx) =
isnothing(jacobian_free_jvp) ? mul!(jΔx, j, Δx) :
jvp!(jacobian_free_jvp, jacobian_free_jvp_cache, jΔx, Δx, x, f!, f, post_implicit!)
jvp!(jacobian_free_jvp, jacobian_free_jvp_cache, jΔx, Δx, x, f!, f, prepare_for_f!)
opj = LinearOperator(eltype(x), length(x), length(x), false, false, jΔx!)
M = disable_preconditioner || isnothing(j) || isnothing(jacobian_free_jvp) ? I : j
print_debug!(debugger, debugger_cache, opj, M)
Expand Down Expand Up @@ -567,25 +567,9 @@ function allocate_cache(alg::NewtonsMethod, x_prototype, j_prototype = nothing)
)
end

solve_newton!(
alg::NewtonsMethod,
cache::Nothing,
x,
f!,
j! = nothing,
post_implicit! = nothing,
post_implicit_last! = nothing,
) = nothing

NVTX.@annotate function solve_newton!(
alg::NewtonsMethod,
cache,
x,
f!,
j! = nothing,
post_implicit! = nothing,
post_implicit_last! = nothing,
)
solve_newton!(alg::NewtonsMethod, cache::Nothing, x, f!, j! = nothing, prepare_for_f! = nothing) = nothing

NVTX.@annotate function solve_newton!(alg::NewtonsMethod, cache, x, f!, j! = nothing, prepare_for_f! = nothing)
(; max_iters, update_j, krylov_method, convergence_checker, verbose) = alg
(; krylov_method_cache, convergence_checker_cache) = cache
(; Δx, f, j) = cache
Expand All @@ -605,22 +589,19 @@ NVTX.@annotate function solve_newton!(
ldiv!(Δx, j, f)
end
else
solve_krylov!(krylov_method, krylov_method_cache, Δx, x, f!, f, n, post_implicit!, j)
solve_krylov!(krylov_method, krylov_method_cache, Δx, x, f!, f, n, prepare_for_f!, j)
end
is_verbose(verbose) && @info "Newton iteration $n: ‖x‖ = $(norm(x)), ‖Δx‖ = $(norm(Δx))"

x .-= Δx
# Update x[n] with Δx[n - 1], and exit the loop if Δx[n] is not needed.
# Check for convergence if necessary.
if is_converged!(convergence_checker, convergence_checker_cache, x, Δx, n)
isnothing(post_implicit_last!) || post_implicit_last!(x)
break
elseif n == max_iters
isnothing(post_implicit_last!) || post_implicit_last!(x)
elseif n < max_iters
isnothing(prepare_for_f!) || prepare_for_f!(x)
else
isnothing(post_implicit!) || post_implicit!(x)
end
if is_verbose(verbose) && n == max_iters
is_verbose(verbose)
@warn "Newton's method did not converge within $n iterations: ‖x‖ = $(norm(x)), ‖Δx‖ = $(norm(Δx))"
end
end
Expand Down
28 changes: 14 additions & 14 deletions src/solvers/hard_coded_ars343.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343)
(; u, p, t, dt, sol, alg) = integrator
(; f) = sol.prob
(; T_imp!, lim!, dss!) = f
(; post_explicit!, post_implicit!) = f
(; cache!, cache_imp!) = f
(; tableau, newtons_method) = alg
(; a_exp, b_exp, a_imp, b_imp, c_exp, c_imp) = tableau
(; U, T_lim, T_exp, T_imp, temp, γ, newtons_method_cache) = cache
Expand Down Expand Up @@ -35,27 +35,27 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343)
@. temp = U # used in closures
let i = i
t_imp = t + dt * c_imp[i]
post_implicit!(U, p, t_imp)
cache_imp!(U, p, t_imp)
implicit_equation_residual! = (residual, Ui) -> begin
T_imp!(residual, Ui, p, t_imp)
@. residual = temp + dt * a_imp[i, i] * residual - Ui
end
implicit_equation_jacobian! = (jacobian, Ui) -> begin
T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp)
end
call_post_implicit! = Ui -> post_implicit!(Ui, p, t_imp)
call_cache_imp! = Ui -> cache_imp!(Ui, p, t_imp)
solve_newton!(
newtons_method,
newtons_method_cache,
U,
implicit_equation_residual!,
implicit_equation_jacobian!,
call_post_implicit!,
call_cache_imp!,
nothing,
)
@. T_imp[i] = (U - temp) / (dt * a_imp[i, i])
dss!(U, p, t_imp)
post_explicit!(U, p, t_imp)
cache!(U, p, t_imp)
end
T_lim!(T_lim[i], U, p, t_exp)
T_exp!(T_exp[i], U, p, t_exp)
Expand All @@ -69,27 +69,27 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343)
@. temp = U # used in closures
let i = i
t_imp = t + dt * c_imp[i]
post_implicit!(U, p, t_imp)
cache_imp!(U, p, t_imp)
implicit_equation_residual! = (residual, Ui) -> begin
T_imp!(residual, Ui, p, t_imp)
@. residual = temp + dt * a_imp[i, i] * residual - Ui
end
implicit_equation_jacobian! = (jacobian, Ui) -> begin
T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp)
end
call_post_implicit! = Ui -> post_implicit!(Ui, p, t_imp)
call_cache_imp! = Ui -> cache_imp!(Ui, p, t_imp)
solve_newton!(
newtons_method,
newtons_method_cache,
U,
implicit_equation_residual!,
implicit_equation_jacobian!,
call_post_implicit!,
call_cache_imp!,
nothing,
)
@. T_imp[i] = (U - temp) / (dt * a_imp[i, i])
dss!(U, p, t_imp)
post_explicit!(U, p, t_imp)
cache!(U, p, t_imp)
end
T_lim!(T_lim[i], U, p, t_exp)
T_exp!(T_exp[i], U, p, t_exp)
Expand All @@ -108,27 +108,27 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343)
@. temp = U # used in closures
let i = i
t_imp = t + dt * c_imp[i]
post_implicit!(U, p, t_imp)
cache_imp!(U, p, t_imp)
implicit_equation_residual! = (residual, Ui) -> begin
T_imp!(residual, Ui, p, t_imp)
@. residual = temp + dt * a_imp[i, i] * residual - Ui
end
implicit_equation_jacobian! = (jacobian, Ui) -> begin
T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp)
end
call_post_implicit! = Ui -> post_implicit!(Ui, p, t_imp)
call_cache_imp! = Ui -> cache_imp!(Ui, p, t_imp)
solve_newton!(
newtons_method,
newtons_method_cache,
U,
implicit_equation_residual!,
implicit_equation_jacobian!,
call_post_implicit!,
call_cache_imp!,
nothing,
)
@. T_imp[i] = (U - temp) / (dt * a_imp[i, i])
dss!(U, p, t_imp)
post_explicit!(U, p, t_imp)
cache!(U, p, t_imp)
end
T_lim!(T_lim[i], U, p, t_exp)
T_exp!(T_exp[i], U, p, t_exp)
Expand All @@ -145,6 +145,6 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343)
dt * b_imp[3] * T_imp[3] +
dt * b_imp[4] * T_imp[4]
dss!(u, p, t_final)
post_explicit!(u, p, t_final)
cache!(u, p, t_final)
return u
end
Loading

0 comments on commit be13e01

Please sign in to comment.