diff --git a/Project.toml b/Project.toml index 6b8ab0e6..a944a413 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ClimaTimeSteppers" uuid = "595c0a79-7f3d-439a-bc5a-b232dc3bde79" authors = ["Climate Modeling Alliance"] -version = "0.7.40" +version = "0.8.0" [deps] ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" diff --git a/docs/src/api/ode_solvers.md b/docs/src/api/ode_solvers.md index c19659ff..05d18c27 100644 --- a/docs/src/api/ode_solvers.md +++ b/docs/src/api/ode_solvers.md @@ -7,6 +7,7 @@ CurrentModule = ClimaTimeSteppers ## Interface ```@docs +ClimaODEFunction AbstractAlgorithmConstraint Unconstrained SSP diff --git a/ext/ClimaTimeSteppersBenchmarkToolsExt.jl b/ext/ClimaTimeSteppersBenchmarkToolsExt.jl index 9a08994d..0acebceb 100644 --- a/ext/ClimaTimeSteppersBenchmarkToolsExt.jl +++ b/ext/ClimaTimeSteppersBenchmarkToolsExt.jl @@ -35,8 +35,8 @@ n_calls_per_step(::CTS.ARS343, max_newton_iters) = Dict( "T_exp_T_lim!" => 4, "lim!" => 4, "dss!" => 4, - "post_explicit!" => 3, - "post_implicit!" => 4, + "cache!" => 3, + "cache_imp!" => 4, "step!" => 1, ) function n_calls_per_step(alg::CTS.RosenbrockAlgorithm) @@ -47,8 +47,8 @@ function n_calls_per_step(alg::CTS.RosenbrockAlgorithm) "T_exp_T_lim!" => CTS.n_stages(alg.tableau), "lim!" => 0, "dss!" => CTS.n_stages(alg.tableau), - "post_explicit!" => 0, - "post_implicit!" => CTS.n_stages(alg.tableau), + "cache!" => 0, + "cache_imp!" => CTS.n_stages(alg.tableau), "step!" => 1, ) end @@ -59,8 +59,7 @@ function maybe_push!(trials₀, name, f!, args, kwargs, only) end end -const allowed_names = - ["Wfact", "ldiv!", "T_imp!", "T_exp_T_lim!", "lim!", "dss!", "post_explicit!", "post_implicit!", "step!"] +const allowed_names = ["Wfact", "ldiv!", "T_imp!", "T_exp_T_lim!", "lim!", "dss!", "cache!", "cache_imp!", "step!"] """ benchmark_step( @@ -89,8 +88,8 @@ Benchmark a DistributedODEIntegrator given: - "T_exp_T_lim!" - "lim!" - "dss!" - - "post_explicit!" - - "post_implicit!" + - "cache!" + - "cache_imp!" - "step!" """ function CTS.benchmark_step( @@ -123,8 +122,8 @@ function CTS.benchmark_step( maybe_push!(trials₀, "T_exp_T_lim!", remaining_fun(integrator), remaining_args(integrator), kwargs, only) maybe_push!(trials₀, "lim!", f.lim!, (Xlim, p, t, u), kwargs, only) maybe_push!(trials₀, "dss!", f.dss!, (u, p, t), kwargs, only) - maybe_push!(trials₀, "post_explicit!", f.post_explicit!, (u, p, t), kwargs, only) - maybe_push!(trials₀, "post_implicit!", f.post_implicit!, (u, p, t), kwargs, only) + maybe_push!(trials₀, "cache!", f.cache!, (u, p, t), kwargs, only) + maybe_push!(trials₀, "cache_imp!", f.cache_imp!, (u, p, t), kwargs, only) maybe_push!(trials₀, "step!", SciMLBase.step!, (integrator, ), kwargs, only) #! format: on diff --git a/src/functions.jl b/src/functions.jl index 04c92f4d..3a32947e 100644 --- a/src/functions.jl +++ b/src/functions.jl @@ -4,26 +4,51 @@ export ClimaODEFunction, ForwardEulerODEFunction abstract type AbstractClimaODEFunction <: DiffEqBase.AbstractODEFunction{true} end -struct ClimaODEFunction{TEL, TL, TE, TI, L, D, PE, PI} <: AbstractClimaODEFunction +""" + ClimaODEFunction(; T_imp!, [dss!], [cache!], [cache_imp!]) + ClimaODEFunction(; T_exp!, T_lim!, [T_imp!], [lim!], [dss!], [cache!], [cache_imp!]) + ClimaODEFunction(; T_exp_lim!, [T_imp!], [lim!], [dss!], [cache!], [cache_imp!]) + +Container for all functions used to advance through a timestep: + - `T_imp!(T_imp, u, p, t)`: sets the implicit tendency + - `T_exp!(T_exp, u, p, t)`: sets the component of the explicit tendency that + is not passed through the limiter + - `T_lim!(T_lim, u, p, t)`: sets the component of the explicit tendency that + is passed through the limiter + - `T_exp_lim!(T_exp, T_lim, u, p, t)`: fused alternative to the separate + functions `T_exp!` and `T_lim!` + - `lim!(u, p, t, u_ref)`: applies the limiter to every state `u` that has + been incremented from `u_ref` by the explicit tendency component `T_lim!` + - `dss!(u, p, t)`: applies direct stiffness summation to every state `u`, + except for intermediate states generated within the implicit solver + - `cache!(u, p, t)`: updates the cache `p` to reflect the state `u` before + the first timestep and on every subsequent timestepping stage + - `cache_imp!(u, p, t)`: updates the components of the cache `p` that are + required to evaluate `T_imp!` and its Jacobian within the implicit solver +By default, `lim!`, `dss!`, and `cache!` all do nothing, and `cache_imp!` is +identical to `cache!`. Any of the tendency functions can be set to `nothing` in +order to avoid corresponding allocations in the integrator. +""" +struct ClimaODEFunction{TEL, TL, TE, TI, L, D, C, CI} <: AbstractClimaODEFunction T_exp_T_lim!::TEL T_lim!::TL T_exp!::TE T_imp!::TI lim!::L dss!::D - post_explicit!::PE - post_implicit!::PI + cache!::C + cache_imp!::CI function ClimaODEFunction(; - T_exp_T_lim! = nothing, # nothing or (uₜ_exp, uₜ_lim, u, p, t) -> ... - T_lim! = nothing, # nothing or (uₜ, u, p, t) -> ... - T_exp! = nothing, # nothing or (uₜ, u, p, t) -> ... - T_imp! = nothing, # nothing or (uₜ, u, p, t) -> ... + T_exp_T_lim! = nothing, + T_lim! = nothing, + T_exp! = nothing, + T_imp! = nothing, lim! = (u, p, t, u_ref) -> nothing, dss! = (u, p, t) -> nothing, - post_explicit! = (u, p, t) -> nothing, - post_implicit! = (u, p, t) -> nothing, + cache! = (u, p, t) -> nothing, + cache_imp! = cache!, ) - args = (T_exp_T_lim!, T_lim!, T_exp!, T_imp!, lim!, dss!, post_explicit!, post_implicit!) + args = (T_exp_T_lim!, T_lim!, T_exp!, T_imp!, lim!, dss!, cache!, cache_imp!) if !isnothing(T_exp_T_lim!) @assert isnothing(T_exp!) "`T_exp_T_lim!` was passed, `T_exp!` must be `nothing`" diff --git a/src/integrators.jl b/src/integrators.jl index 09345304..f36cfd56 100644 --- a/src/integrators.jl +++ b/src/integrators.jl @@ -147,8 +147,8 @@ function DiffEqBase.__init( tdir, ) if prob.f isa ClimaODEFunction - (; post_explicit!) = prob.f - isnothing(post_explicit!) || post_explicit!(u0, p, t0) + (; cache!) = prob.f + isnothing(cache!) || cache!(u0, p, t0) end DiffEqBase.initialize!(callback, u0, t0, integrator) return integrator diff --git a/src/nl_solvers/newtons_method.jl b/src/nl_solvers/newtons_method.jl index bd48cc47..d9033f4d 100644 --- a/src/nl_solvers/newtons_method.jl +++ b/src/nl_solvers/newtons_method.jl @@ -130,7 +130,7 @@ struct ForwardDiffStepSize3 <: ForwardDiffStepSize end Computes the Jacobian-vector product `j(x[n]) * Δx[n]` for a Newton-Krylov method without directly using the Jacobian `j(x[n])`, and instead only using `x[n]`, `f(x[n])`, and other function evaluations `f(x′)`. This is done by -calling `jvp!(::JacobianFreeJVP, cache, jΔx, Δx, x, f!, f, post_implicit!)`. +calling `jvp!(::JacobianFreeJVP, cache, jΔx, Δx, x, f!, f, prepare_for_f!)`. The `jΔx` passed to a Jacobian-free JVP is modified in-place. The `cache` can be obtained with `allocate_cache(::JacobianFreeJVP, x_prototype)`, where `x_prototype` is `similar` to `x` (and also to `Δx` and `f`). @@ -151,13 +151,13 @@ end allocate_cache(::ForwardDiffJVP, x_prototype) = (; x2 = zero(x_prototype), f2 = zero(x_prototype)) -function jvp!(alg::ForwardDiffJVP, cache, jΔx, Δx, x, f!, f, post_implicit!) +function jvp!(alg::ForwardDiffJVP, cache, jΔx, Δx, x, f!, f, prepare_for_f!) (; default_step, step_adjustment) = alg (; x2, f2) = cache FT = eltype(x) ε = FT(step_adjustment) * default_step(Δx, x) @. x2 = x + ε * Δx - isnothing(post_implicit!) || post_implicit!(x2) + isnothing(prepare_for_f!) || prepare_for_f!(x2) f!(f2, x2) @. jΔx = (f2 - f) / ε end @@ -343,7 +343,7 @@ end Finds an approximation `Δx[n] ≈ j(x[n]) \\ f(x[n])` for Newton's method such that `‖f(x[n]) - j(x[n]) * Δx[n]‖ ≤ rtol[n] * ‖f(x[n])‖`, where `rtol[n]` is the value of the forcing term on iteration `n`. This is done by calling -`solve_krylov!(::KrylovMethod, cache, Δx, x, f!, f, n, post_implicit!, j = nothing)`, +`solve_krylov!(::KrylovMethod, cache, Δx, x, f!, f, n, prepare_for_f!, j = nothing)`, where `f` is `f(x[n])` and, if it is specified, `j` is either `j(x[n])` or an approximation of `j(x[n])`. The `Δx` passed to a Krylov method is modified in-place. The `cache` can be obtained with `allocate_cache(::KrylovMethod, x_prototype)`, @@ -428,14 +428,14 @@ function allocate_cache(alg::KrylovMethod, x_prototype) ) end -NVTX.@annotate function solve_krylov!(alg::KrylovMethod, cache, Δx, x, f!, f, n, post_implicit!, j = nothing) +NVTX.@annotate function solve_krylov!(alg::KrylovMethod, cache, Δx, x, f!, f, n, prepare_for_f!, j = nothing) (; jacobian_free_jvp, forcing_term, solve_kwargs) = alg (; disable_preconditioner, debugger) = alg type = solver_type(alg) (; jacobian_free_jvp_cache, forcing_term_cache, solver, debugger_cache) = cache jΔx!(jΔx, Δx) = isnothing(jacobian_free_jvp) ? mul!(jΔx, j, Δx) : - jvp!(jacobian_free_jvp, jacobian_free_jvp_cache, jΔx, Δx, x, f!, f, post_implicit!) + jvp!(jacobian_free_jvp, jacobian_free_jvp_cache, jΔx, Δx, x, f!, f, prepare_for_f!) opj = LinearOperator(eltype(x), length(x), length(x), false, false, jΔx!) M = disable_preconditioner || isnothing(j) || isnothing(jacobian_free_jvp) ? I : j print_debug!(debugger, debugger_cache, opj, M) @@ -567,25 +567,9 @@ function allocate_cache(alg::NewtonsMethod, x_prototype, j_prototype = nothing) ) end -solve_newton!( - alg::NewtonsMethod, - cache::Nothing, - x, - f!, - j! = nothing, - post_implicit! = nothing, - post_implicit_last! = nothing, -) = nothing - -NVTX.@annotate function solve_newton!( - alg::NewtonsMethod, - cache, - x, - f!, - j! = nothing, - post_implicit! = nothing, - post_implicit_last! = nothing, -) +solve_newton!(alg::NewtonsMethod, cache::Nothing, x, f!, j! = nothing, prepare_for_f! = nothing) = nothing + +NVTX.@annotate function solve_newton!(alg::NewtonsMethod, cache, x, f!, j! = nothing, prepare_for_f! = nothing) (; max_iters, update_j, krylov_method, convergence_checker, verbose) = alg (; krylov_method_cache, convergence_checker_cache) = cache (; Δx, f, j) = cache @@ -605,7 +589,7 @@ NVTX.@annotate function solve_newton!( ldiv!(Δx, j, f) end else - solve_krylov!(krylov_method, krylov_method_cache, Δx, x, f!, f, n, post_implicit!, j) + solve_krylov!(krylov_method, krylov_method_cache, Δx, x, f!, f, n, prepare_for_f!, j) end is_verbose(verbose) && @info "Newton iteration $n: ‖x‖ = $(norm(x)), ‖Δx‖ = $(norm(Δx))" @@ -613,14 +597,11 @@ NVTX.@annotate function solve_newton!( # Update x[n] with Δx[n - 1], and exit the loop if Δx[n] is not needed. # Check for convergence if necessary. if is_converged!(convergence_checker, convergence_checker_cache, x, Δx, n) - isnothing(post_implicit_last!) || post_implicit_last!(x) break - elseif n == max_iters - isnothing(post_implicit_last!) || post_implicit_last!(x) + elseif n < max_iters + isnothing(prepare_for_f!) || prepare_for_f!(x) else - isnothing(post_implicit!) || post_implicit!(x) - end - if is_verbose(verbose) && n == max_iters + is_verbose(verbose) @warn "Newton's method did not converge within $n iterations: ‖x‖ = $(norm(x)), ‖Δx‖ = $(norm(Δx))" end end diff --git a/src/solvers/hard_coded_ars343.jl b/src/solvers/hard_coded_ars343.jl index d6ee5b3b..bb440129 100644 --- a/src/solvers/hard_coded_ars343.jl +++ b/src/solvers/hard_coded_ars343.jl @@ -4,7 +4,7 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343) (; u, p, t, dt, sol, alg) = integrator (; f) = sol.prob (; T_imp!, lim!, dss!) = f - (; post_explicit!, post_implicit!) = f + (; cache!, cache_imp!) = f (; tableau, newtons_method) = alg (; a_exp, b_exp, a_imp, b_imp, c_exp, c_imp) = tableau (; U, T_lim, T_exp, T_imp, temp, γ, newtons_method_cache) = cache @@ -35,7 +35,7 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343) @. temp = U # used in closures let i = i t_imp = t + dt * c_imp[i] - post_implicit!(U, p, t_imp) + cache_imp!(U, p, t_imp) implicit_equation_residual! = (residual, Ui) -> begin T_imp!(residual, Ui, p, t_imp) @. residual = temp + dt * a_imp[i, i] * residual - Ui @@ -43,19 +43,19 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343) implicit_equation_jacobian! = (jacobian, Ui) -> begin T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp) end - call_post_implicit! = Ui -> post_implicit!(Ui, p, t_imp) + call_cache_imp! = Ui -> cache_imp!(Ui, p, t_imp) solve_newton!( newtons_method, newtons_method_cache, U, implicit_equation_residual!, implicit_equation_jacobian!, - call_post_implicit!, + call_cache_imp!, nothing, ) @. T_imp[i] = (U - temp) / (dt * a_imp[i, i]) dss!(U, p, t_imp) - post_explicit!(U, p, t_imp) + cache!(U, p, t_imp) end T_lim!(T_lim[i], U, p, t_exp) T_exp!(T_exp[i], U, p, t_exp) @@ -69,7 +69,7 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343) @. temp = U # used in closures let i = i t_imp = t + dt * c_imp[i] - post_implicit!(U, p, t_imp) + cache_imp!(U, p, t_imp) implicit_equation_residual! = (residual, Ui) -> begin T_imp!(residual, Ui, p, t_imp) @. residual = temp + dt * a_imp[i, i] * residual - Ui @@ -77,19 +77,19 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343) implicit_equation_jacobian! = (jacobian, Ui) -> begin T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp) end - call_post_implicit! = Ui -> post_implicit!(Ui, p, t_imp) + call_cache_imp! = Ui -> cache_imp!(Ui, p, t_imp) solve_newton!( newtons_method, newtons_method_cache, U, implicit_equation_residual!, implicit_equation_jacobian!, - call_post_implicit!, + call_cache_imp!, nothing, ) @. T_imp[i] = (U - temp) / (dt * a_imp[i, i]) dss!(U, p, t_imp) - post_explicit!(U, p, t_imp) + cache!(U, p, t_imp) end T_lim!(T_lim[i], U, p, t_exp) T_exp!(T_exp[i], U, p, t_exp) @@ -108,7 +108,7 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343) @. temp = U # used in closures let i = i t_imp = t + dt * c_imp[i] - post_implicit!(U, p, t_imp) + cache_imp!(U, p, t_imp) implicit_equation_residual! = (residual, Ui) -> begin T_imp!(residual, Ui, p, t_imp) @. residual = temp + dt * a_imp[i, i] * residual - Ui @@ -116,19 +116,19 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343) implicit_equation_jacobian! = (jacobian, Ui) -> begin T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp) end - call_post_implicit! = Ui -> post_implicit!(Ui, p, t_imp) + call_cache_imp! = Ui -> cache_imp!(Ui, p, t_imp) solve_newton!( newtons_method, newtons_method_cache, U, implicit_equation_residual!, implicit_equation_jacobian!, - call_post_implicit!, + call_cache_imp!, nothing, ) @. T_imp[i] = (U - temp) / (dt * a_imp[i, i]) dss!(U, p, t_imp) - post_explicit!(U, p, t_imp) + cache!(U, p, t_imp) end T_lim!(T_lim[i], U, p, t_exp) T_exp!(T_exp[i], U, p, t_exp) @@ -145,6 +145,6 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343) dt * b_imp[3] * T_imp[3] + dt * b_imp[4] * T_imp[4] dss!(u, p, t_final) - post_explicit!(u, p, t_final) + cache!(u, p, t_final) return u end diff --git a/src/solvers/imex_ark.jl b/src/solvers/imex_ark.jl index 995bb147..510ce2d6 100644 --- a/src/solvers/imex_ark.jl +++ b/src/solvers/imex_ark.jl @@ -49,7 +49,7 @@ end function step_u!(integrator, cache::IMEXARKCache) (; u, p, t, dt, alg) = integrator (; f) = integrator.sol.prob - (; post_explicit!, post_implicit!) = f + (; cache!, cache_imp!) = f (; T_lim!, T_exp!, T_imp!, lim!, dss!) = f (; tableau, newtons_method) = alg (; a_exp, b_exp, a_imp, b_imp, c_exp, c_imp) = tableau @@ -83,7 +83,7 @@ function step_u!(integrator, cache::IMEXARKCache) isnothing(T_imp!) || fused_increment!(u, dt, b_imp, T_imp, Val(s)) dss!(u, p, t_final) - post_explicit!(u, p, t_final) + cache!(u, p, t_final) return u end @@ -98,7 +98,7 @@ end @inline function update_stage!(integrator, cache::IMEXARKCache, i::Int) (; u, p, t, dt, alg) = integrator (; f) = integrator.sol.prob - (; post_explicit!, post_implicit!) = f + (; cache!, cache_imp!) = f (; T_exp_T_lim!, T_lim!, T_exp!, T_imp!, lim!, dss!) = f (; tableau, newtons_method) = alg (; a_exp, b_exp, a_imp, b_imp, c_exp, c_imp) = tableau @@ -122,16 +122,15 @@ end i ≠ 1 && dss!(U, p, t_exp) if isnothing(T_imp!) || iszero(a_imp[i, i]) - i ≠ 1 && post_explicit!(U, p, t_exp) + i ≠ 1 && cache!(U, p, t_exp) if !all(iszero, a_imp[:, i]) || !iszero(b_imp[i]) # If its coefficient is 0, T_imp[i] is being treated explicitly. isnothing(T_imp!) || T_imp!(T_imp[i], U, p, t_imp) end else # Implicit solve @assert !isnothing(newtons_method) - i ≠ 1 && post_implicit!(U, p, t_imp) + i ≠ 1 && cache_imp!(U, p, t_imp) @. temp = U - # TODO: can/should we remove these closures? implicit_equation_residual! = (residual, Ui) -> begin T_imp!(residual, Ui, p, t_imp) @. residual = temp + dt * a_imp[i, i] * residual - Ui @@ -139,15 +138,14 @@ end implicit_equation_jacobian! = (jacobian, Ui) -> begin T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp) end - call_post_implicit! = Ui -> post_implicit!(Ui, p, t_imp) + implicit_equation_cache! = Ui -> cache_imp!(Ui, p, t_imp) solve_newton!( newtons_method, newtons_method_cache, U, implicit_equation_residual!, implicit_equation_jacobian!, - call_post_implicit!, - nothing, + implicit_equation_cache!, ) if !all(iszero, a_imp[:, i]) || !iszero(b_imp[i]) # If T_imp[i] is being treated implicitly, ensure that it @@ -155,7 +153,7 @@ end @. T_imp[i] = (U - temp) / (dt * a_imp[i, i]) end dss!(U, p, t_imp) - post_explicit!(U, p, t_imp) + cache!(U, p, t_imp) end if !all(iszero, a_exp[:, i]) || !iszero(b_exp[i]) diff --git a/src/solvers/imex_ssprk.jl b/src/solvers/imex_ssprk.jl index 750ebd63..9af75d84 100644 --- a/src/solvers/imex_ssprk.jl +++ b/src/solvers/imex_ssprk.jl @@ -55,7 +55,7 @@ end function step_u!(integrator, cache::IMEXSSPRKCache) (; u, p, t, dt, alg) = integrator (; f) = integrator.sol.prob - (; post_explicit!, post_implicit!) = f + (; cache!, cache_imp!) = f (; T_exp_T_lim!, T_lim!, T_exp!, T_imp!, lim!, dss!) = f (; tableau, newtons_method) = alg (; a_imp, b_imp, c_exp, c_imp) = tableau @@ -105,16 +105,15 @@ function step_u!(integrator, cache::IMEXSSPRKCache) i ≠ 1 && dss!(U, p, t_exp) if isnothing(T_imp!) || iszero(a_imp[i, i]) - i ≠ 1 && post_explicit!(U, p, t_exp) + i ≠ 1 && cache!(U, p, t_exp) if !all(iszero, a_imp[:, i]) || !iszero(b_imp[i]) # If its coefficient is 0, T_imp[i] is being treated explicitly. isnothing(T_imp!) || T_imp!(T_imp[i], U, p, t_imp) end else # Implicit solve @assert !isnothing(newtons_method) - i ≠ 1 && post_implicit!(U, p, t_imp) + i ≠ 1 && cache_imp!(U, p, t_imp) @. temp = U - # TODO: can/should we remove these closures? implicit_equation_residual! = (residual, Ui) -> begin T_imp!(residual, Ui, p, t_imp) @. residual = temp + dt * a_imp[i, i] * residual - Ui @@ -122,15 +121,14 @@ function step_u!(integrator, cache::IMEXSSPRKCache) implicit_equation_jacobian! = (jacobian, Ui) -> begin T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp) end - call_post_implicit! = Ui -> post_implicit!(Ui, p, t_imp) + implicit_equation_cache! = Ui -> cache_imp!(Ui, p, t_imp) solve_newton!( newtons_method, newtons_method_cache, U, implicit_equation_residual!, implicit_equation_jacobian!, - call_post_implicit!, - nothing, + implicit_equation_cache!, ) if !all(iszero, a_imp[:, i]) || !iszero(b_imp[i]) # If T_imp[i] is being treated implicitly, ensure that it @@ -138,7 +136,7 @@ function step_u!(integrator, cache::IMEXSSPRKCache) @. T_imp[i] = (U - temp) / (dt * a_imp[i, i]) end dss!(U, p, t_imp) - post_explicit!(U, p, t_imp) + cache!(U, p, t_imp) end if !iszero(β[i]) @@ -173,7 +171,7 @@ function step_u!(integrator, cache::IMEXSSPRKCache) end dss!(u, p, t_final) - post_explicit!(u, p, t_final) + cache!(u, p, t_final) return u end diff --git a/src/solvers/rosenbrock.jl b/src/solvers/rosenbrock.jl index e3b02407..37b9384d 100644 --- a/src/solvers/rosenbrock.jl +++ b/src/solvers/rosenbrock.jl @@ -123,7 +123,7 @@ function step_u!(int, cache::RosenbrockCache{Nstages}) where {Nstages} T_exp_lim! = int.sol.prob.f.T_exp_T_lim! tgrad! = isnothing(T_imp!) ? nothing : T_imp!.tgrad - (; post_explicit!, dss!) = int.sol.prob.f + (; cache!, dss!) = int.sol.prob.f # TODO: This is only valid when Γ[i, i] is constant, otherwise we have to # move this in the for loop @@ -150,16 +150,13 @@ function step_u!(int, cache::RosenbrockCache{Nstages}) where {Nstages} U .+= A[i, j] .* k[j] end - # NOTE: post_explicit! is a misnomer; should be post_stage! - if !isnothing(post_explicit!) - # We apply DSS and update p on every stage but the first, and at the - # end of each timestep. Since the first stage is unchanged from the - # end of the previous timestep, this order of operations ensures - # that the state is always continuous and that p is consistent with - # the state, including between timesteps. - (i != 1) && dss!(U, p, t + αi * dt) - (i != 1) && post_explicit!(U, p, t + αi * dt) - end + # We apply DSS and update the cache on every stage but the first, + # and at the end of each timestep. Since the first stage is + # unchanged from the end of the previous timestep, this order of + # operations ensures that the state is always continuous and + # consistent with the cache, including between timesteps. + (i != 1) && dss!(U, p, t + αi * dt) + (i != 1) && cache!(U, p, t + αi * dt) if !isnothing(T_imp!) T_imp!(fU_imp, U, p, t + αi * dt) @@ -203,7 +200,7 @@ function step_u!(int, cache::RosenbrockCache{Nstages}) where {Nstages} end dss!(u, p, t + dt) - post_explicit!(u, p, t + dt) + cache!(u, p, t + dt) return nothing end diff --git a/test/problems.jl b/test/problems.jl index ebe9a9f1..9be6fbfd 100644 --- a/test/problems.jl +++ b/test/problems.jl @@ -493,8 +493,8 @@ function climacore_2Dheat_test_cts(::Type{FT}) where {FT} # we add implicit pieces here for inference analysis T_lim! = (Yₜ, u, _, t) -> nothing - post_implicit! = (u, _, t) -> nothing - post_explicit! = (u, _, t) -> nothing + cache_imp! = (u, _, t) -> nothing + cache! = (u, _, t) -> nothing jacobian = ClimaCore.MatrixFields.FieldMatrix((@name(u), @name(u)) => FT(-1) * LinearAlgebra.I) @@ -505,7 +505,7 @@ function climacore_2Dheat_test_cts(::Type{FT}) where {FT} tgrad = (∂Y∂t, Y, p, t) -> (∂Y∂t .= 0), ) - tendency_func = ClimaODEFunction(; T_exp!, T_imp!, dss!, post_implicit!, post_explicit!) + tendency_func = ClimaODEFunction(; T_exp!, T_imp!, dss!, cache_imp!, cache!) split_tendency_func = tendency_func make_prob(func) = ODEProblem(func, init_state, (FT(0), t_end), nothing) IntegratorTestCase(