From 449f5eeaec78ea8fb2e9a4ace4c5b8ce1e68ba74 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Wed, 22 Jan 2025 17:00:10 -0500 Subject: [PATCH] Use CTS src from main --- Project.toml | 51 +- docs/Manifest.toml | 30 +- docs/src/dev/report_gen.jl | 39 +- src/Callbacks.jl | 13 +- src/ClimaTimeSteppers.jl | 35 +- src/functions.jl | 49 +- src/integrators.jl | 68 ++- src/nl_solvers/newtons_method.jl | 93 ++-- src/solvers/imex_ark.jl | 215 ++++---- src/solvers/imex_tableaus.jl | 720 +++++++++++++++++-------- src/solvers/mis.jl | 50 +- src/solvers/multirate.jl | 20 +- src/solvers/rosenbrock.jl | 247 +++++++-- src/solvers/wickerskamarock.jl | 4 +- src/utilities/convergence_checker.jl | 6 +- src/utilities/convergence_condition.jl | 13 +- src/utilities/update_signal_handler.jl | 50 +- 17 files changed, 1174 insertions(+), 529 deletions(-) diff --git a/Project.toml b/Project.toml index 670c29c4..94db9a02 100644 --- a/Project.toml +++ b/Project.toml @@ -1,28 +1,69 @@ name = "ClimaTimeSteppers" uuid = "595c0a79-7f3d-439a-bc5a-b232dc3bde79" authors = ["Climate Modeling Alliance"] -version = "0.7.0" +version = "0.7.39" [deps] ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" +Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" -DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125" +NVTX = "5da4648a-3479-48b8-97b9-01cb529c0a1f" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +[weakdeps] +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" + +[extensions] +ClimaTimeSteppersBenchmarkToolsExt = ["CUDA", "BenchmarkTools", "OrderedCollections", "StatsBase", "PrettyTables"] + [compat] +Aqua = "0.8" +BenchmarkTools = "1" ClimaComms = "0.4, 0.5, 0.6" +ClimaCore = "0.10, 0.11, 0.12, 0.13, 0.14" +CUDA = "3, 4, 5" +Colors = "0.12, 0.13" DataStructures = "0.18" DiffEqBase = "6" -DiffEqCallbacks = "2" +Distributions = "0.25" KernelAbstractions = "0.7, 0.8, 0.9" Krylov = "0.8, 0.9" +LinearAlgebra = "1" LinearOperators = "2" -SciMLBase = "1" +MPI = "0.20" +NVTX = "0.3" +ODEConvergenceTester = "0.2" +OrderedCollections = "1" +PrettyTables = "2" +Random = "1" +SafeTestsets = "0.1" +SciMLBase = "1, 2" StaticArrays = "1" -julia = "1.8" +StatsBase = "0.33, 0.34" +Test = "1" +julia = "1.9" + +[extras] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ClimaCore = "d414da3d-4745-48bb-8d80-42e94e092884" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" +MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" +ODEConvergenceTester = "42a5c2e1-f365-4540-8ca5-3684de3ecd95" +PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[targets] +test = ["Aqua","ClimaCore","Distributions","Krylov", "MPI","ODEConvergenceTester","PrettyTables","Random","SafeTestsets","Test"] diff --git a/docs/Manifest.toml b/docs/Manifest.toml index e08e3586..8ea46fa6 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -230,10 +230,20 @@ version = "0.14.19" Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" [[deps.ClimaTimeSteppers]] -deps = ["ClimaComms", "DataStructures", "DiffEqBase", "DiffEqCallbacks", "KernelAbstractions", "Krylov", "LinearAlgebra", "LinearOperators", "SciMLBase", "StaticArrays"] +deps = ["ClimaComms", "Colors", "DataStructures", "DiffEqBase", "KernelAbstractions", "Krylov", "LinearAlgebra", "LinearOperators", "NVTX", "SciMLBase", "StaticArrays"] path = ".." uuid = "595c0a79-7f3d-439a-bc5a-b232dc3bde79" -version = "0.7.0" +version = "0.7.39" + + [deps.ClimaTimeSteppers.extensions] + ClimaTimeSteppersBenchmarkToolsExt = ["CUDA", "BenchmarkTools", "OrderedCollections", "StatsBase", "PrettyTables"] + + [deps.ClimaTimeSteppers.weakdeps] + BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" + PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" + StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [[deps.CloseOpenIntervals]] deps = ["Static", "StaticArrayInterface"] @@ -415,16 +425,6 @@ version = "6.130.0" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" -[[deps.DiffEqCallbacks]] -deps = ["DataStructures", "DiffEqBase", "ForwardDiff", "Functors", "LinearAlgebra", "Markdown", "NLsolve", "Parameters", "RecipesBase", "RecursiveArrayTools", "SciMLBase", "StaticArraysCore"] -git-tree-sha1 = "d0b94b3694d55e7eedeee918e7daee9e3b873399" -uuid = "459566f4-90b8-5000-8ac3-15dfb0a30def" -version = "2.35.0" - - [deps.DiffEqCallbacks.weakdeps] - OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" - Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" - [[deps.DiffResults]] deps = ["StaticArraysCore"] git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" @@ -651,12 +651,6 @@ git-tree-sha1 = "b104d487b34566608f8b4e1c39fb0b10aa279ff8" uuid = "77dc65aa-8811-40c2-897b-53d922fa7daf" version = "0.1.3" -[[deps.Functors]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "64d8e93700c7a3f28f717d265382d52fac9fa1c1" -uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.4.12" - [[deps.Future]] deps = ["Random"] uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" diff --git a/docs/src/dev/report_gen.jl b/docs/src/dev/report_gen.jl index ebdecc17..9eb816d0 100644 --- a/docs/src/dev/report_gen.jl +++ b/docs/src/dev/report_gen.jl @@ -1,4 +1,5 @@ using ClimaTimeSteppers +import ClimaTimeSteppers as CTS using Test using InteractiveUtils: subtypes @@ -11,7 +12,43 @@ all_subtypes(::Type{T}) where {T} = isabstracttype(T) ? vcat(all_subtypes.(subty @testset "IMEX Algorithm Convergence" begin title = "IMEX Algorithms" - algorithm_names = map(T -> T(), all_subtypes(ClimaTimeSteppers.IMEXAlgorithmName)) + # algorithm_names = map(T -> T(), all_subtypes(ClimaTimeSteppers.IMEXAlgorithmName)) + algorithm_names = [ + # CTS.SSP22Heuns(), + # CTS.SSP33ShuOsher(), + # CTS.RK4(), + # CTS.ARK2GKC(), + CTS.ARS111(), + CTS.ARS121(), + CTS.ARS122(), + CTS.ARS222(), + CTS.ARS232(), + CTS.ARS233(), + CTS.ARS343(), + CTS.ARS443(), + CTS.SSP222(), + CTS.SSP322(), + CTS.SSP332(), + CTS.SSP333(), + CTS.SSP433(), + CTS.DBM453(), + CTS.HOMMEM1(), + CTS.IMKG232a(), + CTS.IMKG232b(), + CTS.IMKG242a(), + CTS.IMKG242b(), + CTS.IMKG243a(), + CTS.IMKG252a(), + CTS.IMKG252b(), + CTS.IMKG253a(), + CTS.IMKG253b(), + CTS.IMKG254a(), + CTS.IMKG254b(), + CTS.IMKG254c(), + CTS.IMKG342a(), + CTS.IMKG343a(), + # CTS.SSPKnoth() + ] test_imex_algorithms(title, algorithm_names, ark_analytic_nonlin_test_cts(Float64), 200) test_imex_algorithms(title, algorithm_names, ark_analytic_sys_test_cts(Float64), 400) test_imex_algorithms(title, algorithm_names, ark_analytic_test_cts(Float64), 40000; super_convergence = (ARS121(),)) diff --git a/src/Callbacks.jl b/src/Callbacks.jl index 19aaae96..c4fe189d 100644 --- a/src/Callbacks.jl +++ b/src/Callbacks.jl @@ -6,6 +6,7 @@ A suite of callback functions to be used with the ClimaTimeSteppers.jl ODE solve module Callbacks import ClimaComms, DiffEqBase +import SciMLBase """ ClimaTimeSteppers.Callbacks.initialize!(f!::F, integrator) @@ -72,9 +73,9 @@ function EveryXWallTimeSeconds(f!, Δwt, comm_ctx::ClimaComms.AbstractCommsConte end if isdefined(DiffEqBase, :finalize!) - DiffEqBase.DiscreteCallback(condition, f!; initialize = _initialize, finalize = _finalize) + SciMLBase.DiscreteCallback(condition, f!; initialize = _initialize, finalize = _finalize) else - DiffEqBase.DiscreteCallback(condition, f!; initialize = _initialize) + SciMLBase.DiscreteCallback(condition, f!; initialize = _initialize) end end @@ -115,9 +116,9 @@ function EveryXSimulationTime(f!, Δt; atinit = false) end end if isdefined(DiffEqBase, :finalize!) - DiffEqBase.DiscreteCallback(condition, f!; initialize = _initialize, finalize = _finalize) + SciMLBase.DiscreteCallback(condition, f!; initialize = _initialize, finalize = _finalize) else - DiffEqBase.DiscreteCallback(condition, f!; initialize = _initialize) + SciMLBase.DiscreteCallback(condition, f!; initialize = _initialize) end end @@ -159,9 +160,9 @@ function EveryXSimulationSteps(f!, Δsteps; atinit = false) end if isdefined(DiffEqBase, :finalize!) - DiffEqBase.DiscreteCallback(condition, f!; initialize = _initialize, finalize = _finalize) + SciMLBase.DiscreteCallback(condition, f!; initialize = _initialize, finalize = _finalize) else - DiffEqBase.DiscreteCallback(condition, f!; initialize = _initialize) + SciMLBase.DiscreteCallback(condition, f!; initialize = _initialize) end end diff --git a/src/ClimaTimeSteppers.jl b/src/ClimaTimeSteppers.jl index 7f7789a4..6fab99d6 100644 --- a/src/ClimaTimeSteppers.jl +++ b/src/ClimaTimeSteppers.jl @@ -19,7 +19,7 @@ JuliaDiffEq terminology: SplitODEProlem(fL, fR) - * `ODEProblem` from OrdinaryDiffEq.jl + * `ODEProblem` from SciMLBase.jl - use `jac` option to `ODEFunction` for linear + full IMEX (https://docs.sciml.ai/latest/features/performance_overloads/#ode_explicit_jac-1) * `SplitODEProblem` for linear + remainder IMEX * `MultirateODEProblem` for true multirate @@ -45,22 +45,24 @@ module ClimaTimeSteppers using KernelAbstractions -using KernelAbstractions.Extras: @unroll using LinearAlgebra using LinearOperators using StaticArrays +import ClimaComms +# using Colors +using NVTX -export AbstractAlgorithmName, AbstractAlgorithmConstraint, Unconstrained, SSPConstrained +export AbstractAlgorithmName, AbstractAlgorithmConstraint, Unconstrained, SSP array_device(::Union{Array, SArray, MArray}) = CPU() -realview(x::Union{Array, SArray, MArray}) = x -realview(x::Array) = x +array_device(x) = CUDADevice() # assume CUDA -import DiffEqBase, SciMLBase, LinearAlgebra, DiffEqCallbacks, Krylov +import DiffEqBase, SciMLBase, LinearAlgebra, Krylov +include(joinpath("utilities", "sparse_coeffs.jl")) +include(joinpath("utilities", "fused_increment.jl")) include("sparse_containers.jl") include("functions.jl") -include("operators.jl") abstract type DistributedODEAlgorithm <: DiffEqBase.AbstractODEAlgorithm end @@ -69,7 +71,7 @@ abstract type AbstractAlgorithmName end """ AbstractAlgorithmConstraint -A mechanism for restricting which operations can be performed by an algorithm +A mechanism for constraining which operations can be performed by an algorithm for solving ODEs. For example, an unconstrained algorithm might compute a Runge-Kutta stage by @@ -87,15 +89,17 @@ Indicates that an algorithm may perform any supported operations. """ struct Unconstrained <: AbstractAlgorithmConstraint end +default_constraint(::AbstractAlgorithmName) = Unconstrained() + """ - SSPConstrained + SSP Indicates that an algorithm must be "strong stability preserving", which makes it easier to guarantee that the algorithm will preserve monotonicity properties satisfied by the initial state. For example, this ensures that the algorithm will be able to use limiters in a mathematically consistent way. """ -struct SSPConstrained <: AbstractAlgorithmConstraint end +struct SSP <: AbstractAlgorithmConstraint end SciMLBase.allowscomplex(alg::DistributedODEAlgorithm) = true include("integrators.jl") @@ -110,17 +114,22 @@ n_stages_ntuple(::Type{<:NTuple{Nstages}}) where {Nstages} = Nstages n_stages_ntuple(::Type{<:SVector{Nstages}}) where {Nstages} = Nstages # Include concrete implementations +const SPCO = SparseCoeffs + include("solvers/imex_tableaus.jl") +include("solvers/explicit_tableaus.jl") include("solvers/imex_ark.jl") -include("solvers/imex_ssp.jl") +include("solvers/imex_ssprk.jl") include("solvers/multirate.jl") include("solvers/lsrk.jl") -include("solvers/ssprk.jl") -include("solvers/ark.jl") include("solvers/mis.jl") include("solvers/wickerskamarock.jl") include("solvers/rosenbrock.jl") include("Callbacks.jl") + +benchmark_step(integrator, device) = + @warn "Must load CUDA, BenchmarkTools, OrderedCollections, StatsBase, PrettyTables to trigger the ClimaTimeSteppersBenchmarkToolsExt extension" + end diff --git a/src/functions.jl b/src/functions.jl index 3e99e66d..04c92f4d 100644 --- a/src/functions.jl +++ b/src/functions.jl @@ -1,18 +1,47 @@ import DiffEqBase +export AbstractClimaODEFunction export ClimaODEFunction, ForwardEulerODEFunction -Base.@kwdef struct ClimaODEFunction{TL, TE, TI, L, D, S} <: DiffEqBase.AbstractODEFunction{true} - T_lim!::TL = nothing # nothing or (uₜ, u, p, t) -> ... - T_exp!::TE = nothing # nothing or (uₜ, u, p, t) -> ... - T_imp!::TI = nothing # nothing or (uₜ, u, p, t) -> ... - lim!::L = (u, p, t, u_ref) -> nothing - dss!::D = (u, p, t) -> nothing - stage_callback!::S = (u, p, t) -> nothing +abstract type AbstractClimaODEFunction <: DiffEqBase.AbstractODEFunction{true} end + +struct ClimaODEFunction{TEL, TL, TE, TI, L, D, PE, PI} <: AbstractClimaODEFunction + T_exp_T_lim!::TEL + T_lim!::TL + T_exp!::TE + T_imp!::TI + lim!::L + dss!::D + post_explicit!::PE + post_implicit!::PI + function ClimaODEFunction(; + T_exp_T_lim! = nothing, # nothing or (uₜ_exp, uₜ_lim, u, p, t) -> ... + T_lim! = nothing, # nothing or (uₜ, u, p, t) -> ... + T_exp! = nothing, # nothing or (uₜ, u, p, t) -> ... + T_imp! = nothing, # nothing or (uₜ, u, p, t) -> ... + lim! = (u, p, t, u_ref) -> nothing, + dss! = (u, p, t) -> nothing, + post_explicit! = (u, p, t) -> nothing, + post_implicit! = (u, p, t) -> nothing, + ) + args = (T_exp_T_lim!, T_lim!, T_exp!, T_imp!, lim!, dss!, post_explicit!, post_implicit!) + + if !isnothing(T_exp_T_lim!) + @assert isnothing(T_exp!) "`T_exp_T_lim!` was passed, `T_exp!` must be `nothing`" + @assert isnothing(T_lim!) "`T_exp_T_lim!` was passed, `T_lim!` must be `nothing`" + end + if !isnothing(T_exp!) && !isnothing(T_lim!) + @warn "Both `T_exp!` and `T_lim!` are not `nothing`, please use `T_exp_T_lim!` instead." + end + return new{typeof.(args)...}(args...) + end end -# Don't wrap a ClimaODEFunction in an ODEFunction (makes ODEProblem work). -DiffEqBase.ODEFunction{iip}(f::ClimaODEFunction) where {iip} = f -DiffEqBase.ODEFunction(f::ClimaODEFunction) = f +has_T_exp(f::ClimaODEFunction) = !isnothing(f.T_exp!) || !isnothing(f.T_exp_T_lim!) +has_T_lim(f::ClimaODEFunction) = !isnothing(f.lim!) && (!isnothing(f.T_lim!) || !isnothing(f.T_exp_T_lim!)) + +# Don't wrap a AbstractClimaODEFunction in an ODEFunction (makes ODEProblem work). +DiffEqBase.ODEFunction{iip}(f::AbstractClimaODEFunction) where {iip} = f +DiffEqBase.ODEFunction(f::AbstractClimaODEFunction) = f """ ForwardEulerODEFunction(f; jac_prototype, Wfact, tgrad) diff --git a/src/integrators.jl b/src/integrators.jl index 1eb830eb..09345304 100644 --- a/src/integrators.jl +++ b/src/integrators.jl @@ -37,8 +37,34 @@ mutable struct DistributedODEIntegrator{ # DiffEqBase.initialize! and DiffEqBase.finalize! cache::cacheType sol::solType + tdir::tType # see https://docs.sciml.ai/DiffEqCallbacks/stable/output_saving/#DiffEqCallbacks.SavingCallback end +""" + SavedValues{tType<:Real, savevalType} + +A struct used to save values of the time in `t::Vector{tType}` and +additional values in `saveval::Vector{savevalType}`. + +From DiffEqCallbacks. +""" +struct SavedValues{tType, savevalType} + t::Vector{tType} + saveval::Vector{savevalType} +end + +""" + SavedValues(tType::DataType, savevalType::DataType) + +Return `SavedValues{tType, savevalType}` with empty storage vectors. + +From DiffEqCallbacks. +""" +function SavedValues(::Type{tType}, ::Type{savevalType}) where {tType, savevalType} + SavedValues{tType, savevalType}(Vector{tType}(), Vector{savevalType}()) +end + + # helper function for setting up min/max heaps for tstops and saveat function tstops_and_saveat_heaps(t0, tf, tstops, saveat) FT = typeof(tf) @@ -64,6 +90,8 @@ function tstops_and_saveat_heaps(t0, tf, tstops, saveat) return tstops, saveat end +compute_tdir(ts) = ts[1] > ts[end] ? sign(ts[end] - ts[1]) : eltype(ts)(1) + # called by DiffEqBase.init and DiffEqBase.solve function DiffEqBase.__init( prob::DiffEqBase.AbstractODEProblem, @@ -75,9 +103,10 @@ function DiffEqBase.__init( save_everystep = false, callback = nothing, advance_to_tstop = false, - save_func = (u, t) -> copy(u), # custom kwarg - dtchangeable = true, # custom kwarg - stepstop = -1, # custom kwarg + save_func = (u, t) -> copy(u), # custom kwarg + dtchangeable = true, # custom kwarg + stepstop = -1, # custom kwarg + tdir = compute_tdir(prob.tspan), # kwargs..., ) (; u0, p) = prob @@ -92,8 +121,7 @@ function DiffEqBase.__init( tstops, saveat = tstops_and_saveat_heaps(t0, tf, tstops, saveat) sol = DiffEqBase.build_solution(prob, alg, typeof(t0)[], typeof(save_func(u0, t0))[]) - saving_callback = - NonInterpolatingSavingCallback(save_func, DiffEqCallbacks.SavedValues(sol.t, sol.u), save_everystep) + saving_callback = NonInterpolatingSavingCallback(save_func, SavedValues(sol.t, sol.u), save_everystep) callback = DiffEqBase.CallbackSet(callback, saving_callback) isempty(callback.continuous_callbacks) || error("Continuous callbacks are not supported") @@ -116,7 +144,12 @@ function DiffEqBase.__init( false, init_cache(prob, alg; dt, kwargs...), sol, + tdir, ) + if prob.f isa ClimaODEFunction + (; post_explicit!) = prob.f + isnothing(post_explicit!) || post_explicit!(u0, p, t0) + end DiffEqBase.initialize!(callback, u0, t0, integrator) return integrator end @@ -155,7 +188,7 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem, alg::Distribute end # either called directly (after init), or by DiffEqBase.solve (via __solve) -function DiffEqBase.solve!(integrator::DistributedODEIntegrator) +NVTX.@annotate function DiffEqBase.solve!(integrator::DistributedODEIntegrator) while !isempty(integrator.tstops) && integrator.step != integrator.stepstop __step!(integrator) end @@ -192,6 +225,17 @@ is_past_t(integrator, t) = tdir(integrator) * (t - integrator.t) < zero(integrat reached_tstop(integrator, tstop, stop_at_tstop = integrator.dtchangeable) = integrator.t == tstop || (!stop_at_tstop && is_past_t(integrator, tstop)) + +@inline unrolled_foreach(::Tuple{}, integrator) = nothing +@inline unrolled_foreach(callback, integrator) = + callback.condition(integrator.u, integrator.t, integrator) ? callback.affect!(integrator) : nothing +@inline unrolled_foreach(discrete_callbacks::Tuple{Any}, integrator) = + unrolled_foreach(first(discrete_callbacks), integrator) +@inline function unrolled_foreach(discrete_callbacks::Tuple, integrator) + unrolled_foreach(first(discrete_callbacks), integrator) + unrolled_foreach(Base.tail(discrete_callbacks), integrator) +end + function __step!(integrator) (; _dt, dtchangeable, tstops) = integrator @@ -213,11 +257,7 @@ function __step!(integrator) # apply callbacks discrete_callbacks = integrator.callback.discrete_callbacks - for callback in discrete_callbacks - if callback.condition(integrator.u, integrator.t, integrator) - callback.affect!(integrator) - end - end + unrolled_foreach(discrete_callbacks, integrator) # remove tstops that were just reached while !isempty(tstops) && reached_tstop(integrator, first(tstops)) @@ -226,7 +266,9 @@ function __step!(integrator) end # solvers need to define this interface -step_u!(integrator) = step_u!(integrator, integrator.cache) +NVTX.@annotate function step_u!(integrator) + step_u!(integrator, integrator.cache) +end DiffEqBase.get_dt(integrator::DistributedODEIntegrator) = integrator._dt function set_dt!(integrator::DistributedODEIntegrator, dt) @@ -274,5 +316,5 @@ function NonInterpolatingSavingCallback(save_func, saved_values, save_everystep) end initialize(cb, u, t, integrator) = condition(u, t, integrator) && affect!(integrator) finalize(cb, u, t, integrator) = !save_everystep && !isempty(integrator.saveat) && affect!(integrator) - DiffEqBase.DiscreteCallback(condition, affect!; initialize, finalize) + SciMLBase.DiscreteCallback(condition, affect!; initialize, finalize) end diff --git a/src/nl_solvers/newtons_method.jl b/src/nl_solvers/newtons_method.jl index 7198d37c..bd48cc47 100644 --- a/src/nl_solvers/newtons_method.jl +++ b/src/nl_solvers/newtons_method.jl @@ -2,7 +2,6 @@ export NewtonsMethod, KrylovMethod export JacobianFreeJVP, ForwardDiffJVP, ForwardDiffStepSize export ForwardDiffStepSize1, ForwardDiffStepSize2, ForwardDiffStepSize3 export ForcingTerm, ConstantForcing, EisenstatWalkerForcing -export Verbosity # TODO: Implement AutoDiffJVP after ClimaAtmos's cache is moved from f! to x (so # that we only need to define Dual.(x), and not also make_dual(f!)). @@ -131,10 +130,10 @@ struct ForwardDiffStepSize3 <: ForwardDiffStepSize end Computes the Jacobian-vector product `j(x[n]) * Δx[n]` for a Newton-Krylov method without directly using the Jacobian `j(x[n])`, and instead only using `x[n]`, `f(x[n])`, and other function evaluations `f(x′)`. This is done by -calling `jvp!(::JacobianFreeJVP, cache, jΔx, Δx, x, f!, f)`. The `jΔx` passed to -a Jacobian-free JVP is modified in-place. The `cache` can be obtained with -`allocate_cache(::JacobianFreeJVP, x_prototype)`, where `x_prototype` is -`similar` to `x` (and also to `Δx` and `f`). +calling `jvp!(::JacobianFreeJVP, cache, jΔx, Δx, x, f!, f, post_implicit!)`. +The `jΔx` passed to a Jacobian-free JVP is modified in-place. The `cache` can +be obtained with `allocate_cache(::JacobianFreeJVP, x_prototype)`, where +`x_prototype` is `similar` to `x` (and also to `Δx` and `f`). """ abstract type JacobianFreeJVP end @@ -150,14 +149,15 @@ Base.@kwdef struct ForwardDiffJVP{S <: ForwardDiffStepSize, T} <: JacobianFreeJV step_adjustment::T = 1 end -allocate_cache(::ForwardDiffJVP, x_prototype) = (; x2 = similar(x_prototype), f2 = similar(x_prototype)) +allocate_cache(::ForwardDiffJVP, x_prototype) = (; x2 = zero(x_prototype), f2 = zero(x_prototype)) -function jvp!(alg::ForwardDiffJVP, cache, jΔx, Δx, x, f!, f) +function jvp!(alg::ForwardDiffJVP, cache, jΔx, Δx, x, f!, f, post_implicit!) (; default_step, step_adjustment) = alg (; x2, f2) = cache FT = eltype(x) ε = FT(step_adjustment) * default_step(Δx, x) @. x2 = x + ε * Δx + isnothing(post_implicit!) || post_implicit!(x2) f!(f2, x2) @. jΔx = (f2 - f) / ε end @@ -343,10 +343,10 @@ end Finds an approximation `Δx[n] ≈ j(x[n]) \\ f(x[n])` for Newton's method such that `‖f(x[n]) - j(x[n]) * Δx[n]‖ ≤ rtol[n] * ‖f(x[n])‖`, where `rtol[n]` is the value of the forcing term on iteration `n`. This is done by calling -`solve_krylov!(::KrylovMethod, cache, Δx, x, f!, f, n, j = nothing)`, where `f` is -`f(x[n])` and, if it is specified, `j` is either `j(x[n])` or an approximation -of `j(x[n])`. The `Δx` passed to a Krylov method is modified in-place. The -`cache` can be obtained with `allocate_cache(::KrylovMethod, x_prototype)`, +`solve_krylov!(::KrylovMethod, cache, Δx, x, f!, f, n, post_implicit!, j = nothing)`, +where `f` is `f(x[n])` and, if it is specified, `j` is either `j(x[n])` or an +approximation of `j(x[n])`. The `Δx` passed to a Krylov method is modified in-place. +The `cache` can be obtained with `allocate_cache(::KrylovMethod, x_prototype)`, where `x_prototype` is `similar` to `x` (and also to `Δx` and `f`). This is primarily a wrapper for a `Krylov.KrylovSolver` from `Krylov.jl`. In @@ -428,14 +428,14 @@ function allocate_cache(alg::KrylovMethod, x_prototype) ) end -function solve_krylov!(alg::KrylovMethod, cache, Δx, x, f!, f, n, j = nothing) +NVTX.@annotate function solve_krylov!(alg::KrylovMethod, cache, Δx, x, f!, f, n, post_implicit!, j = nothing) (; jacobian_free_jvp, forcing_term, solve_kwargs) = alg (; disable_preconditioner, debugger) = alg type = solver_type(alg) (; jacobian_free_jvp_cache, forcing_term_cache, solver, debugger_cache) = cache jΔx!(jΔx, Δx) = isnothing(jacobian_free_jvp) ? mul!(jΔx, j, Δx) : - jvp!(jacobian_free_jvp, jacobian_free_jvp_cache, jΔx, Δx, x, f!, f) + jvp!(jacobian_free_jvp, jacobian_free_jvp_cache, jΔx, Δx, x, f!, f, post_implicit!) opj = LinearOperator(eltype(x), length(x), length(x), false, false, jΔx!) M = disable_preconditioner || isnothing(j) || isnothing(jacobian_free_jvp) ? I : j print_debug!(debugger, debugger_cache, opj, M) @@ -558,33 +558,43 @@ function allocate_cache(alg::NewtonsMethod, x_prototype, j_prototype = nothing) (; update_j, krylov_method, convergence_checker) = alg @assert !(isnothing(j_prototype) && (isnothing(krylov_method) || isnothing(krylov_method.jacobian_free_jvp))) return (; - update_j_cache = allocate_cache(update_j, eltype(x_prototype)), krylov_method_cache = isnothing(krylov_method) ? nothing : allocate_cache(krylov_method, x_prototype), convergence_checker_cache = isnothing(convergence_checker) ? nothing : allocate_cache(convergence_checker, x_prototype), - Δx = similar(x_prototype), - f = similar(x_prototype), - j = isnothing(j_prototype) ? nothing : similar(j_prototype), + Δx = zero(x_prototype), + f = zero(x_prototype), + j = isnothing(j_prototype) ? nothing : zero(j_prototype), ) end -function solve_newton!(alg::NewtonsMethod, cache, x, f!, j! = nothing) +solve_newton!( + alg::NewtonsMethod, + cache::Nothing, + x, + f!, + j! = nothing, + post_implicit! = nothing, + post_implicit_last! = nothing, +) = nothing + +NVTX.@annotate function solve_newton!( + alg::NewtonsMethod, + cache, + x, + f!, + j! = nothing, + post_implicit! = nothing, + post_implicit_last! = nothing, +) (; max_iters, update_j, krylov_method, convergence_checker, verbose) = alg - (; update_j_cache, krylov_method_cache, convergence_checker_cache) = cache + (; krylov_method_cache, convergence_checker_cache) = cache (; Δx, f, j) = cache - if (!isnothing(j)) && needs_update!(update_j, update_j_cache, NewNewtonSolve()) + if (!isnothing(j)) && needs_update!(update_j, NewNewtonSolve()) j!(j, x) end - for n in 0:max_iters - # Update x[n] with Δx[n - 1], and exit the loop if Δx[n] is not needed. - n > 0 && (x .-= Δx) - if n == max_iters && isnothing(convergence_checker) - is_verbose(verbose) && @info "Newton iteration $n: ‖x‖ = $(norm(x)), ‖Δx‖ = N/A" - break - end - + for n in 1:max_iters # Compute Δx[n]. - if (!isnothing(j)) && needs_update!(update_j, update_j_cache, NewNewtonIteration()) + if (!isnothing(j)) && needs_update!(update_j, NewNewtonIteration()) j!(j, x) end f!(f, x) @@ -595,22 +605,23 @@ function solve_newton!(alg::NewtonsMethod, cache, x, f!, j! = nothing) ldiv!(Δx, j, f) end else - solve_krylov!(krylov_method, krylov_method_cache, Δx, x, f!, f, n, j) + solve_krylov!(krylov_method, krylov_method_cache, Δx, x, f!, f, n, post_implicit!, j) end is_verbose(verbose) && @info "Newton iteration $n: ‖x‖ = $(norm(x)), ‖Δx‖ = $(norm(Δx))" + x .-= Δx + # Update x[n] with Δx[n - 1], and exit the loop if Δx[n] is not needed. # Check for convergence if necessary. - if !isnothing(convergence_checker) - check_convergence!(convergence_checker, convergence_checker_cache, x, Δx, n) && break - n == max_iters && @warn "Newton's method did not converge within $n iterations" + if is_converged!(convergence_checker, convergence_checker_cache, x, Δx, n) + isnothing(post_implicit_last!) || post_implicit_last!(x) + break + elseif n == max_iters + isnothing(post_implicit_last!) || post_implicit_last!(x) + else + isnothing(post_implicit!) || post_implicit!(x) + end + if is_verbose(verbose) && n == max_iters + @warn "Newton's method did not converge within $n iterations: ‖x‖ = $(norm(x)), ‖Δx‖ = $(norm(Δx))" end - end -end - -function update!(alg::NewtonsMethod, cache, signal::UpdateSignal, j!) - (; update_j) = alg - (; update_j_cache, j) = cache - if (!isnothing(j)) && needs_update!(update_j, update_j_cache, signal) - j!(j) end end diff --git a/src/solvers/imex_ark.jl b/src/solvers/imex_ark.jl index c831a737..116e64e8 100644 --- a/src/solvers/imex_ark.jl +++ b/src/solvers/imex_ark.jl @@ -1,3 +1,5 @@ +import NVTX + has_jac(T_imp!) = hasfield(typeof(T_imp!), :Wfact) && hasfield(typeof(T_imp!), :jac_prototype) && @@ -30,140 +32,147 @@ function init_cache(prob::DiffEqBase.AbstractODEProblem, alg::IMEXAlgorithm{Unco inds = ntuple(i -> i, s) inds_T_exp = filter(i -> !all(iszero, a_exp[:, i]) || !iszero(b_exp[i]), inds) inds_T_imp = filter(i -> !all(iszero, a_imp[:, i]) || !iszero(b_imp[i]), inds) - U = SparseContainer(map(i -> similar(u0), collect(1:length(inds))), inds) - T_lim = SparseContainer(map(i -> similar(u0), collect(1:length(inds_T_exp))), inds_T_exp) - T_exp = SparseContainer(map(i -> similar(u0), collect(1:length(inds_T_exp))), inds_T_exp) - T_imp = SparseContainer(map(i -> similar(u0), collect(1:length(inds_T_imp))), inds_T_imp) - temp = similar(u0) + U = zero(u0) + T_lim = SparseContainer(map(i -> zero(u0), collect(1:length(inds_T_exp))), inds_T_exp) + T_exp = SparseContainer(map(i -> zero(u0), collect(1:length(inds_T_exp))), inds_T_exp) + T_imp = SparseContainer(map(i -> zero(u0), collect(1:length(inds_T_imp))), inds_T_imp) + temp = zero(u0) γs = unique(filter(!iszero, diag(a_imp))) γ = length(γs) == 1 ? γs[1] : nothing # TODO: This could just be a constant. jac_prototype = has_jac(T_imp!) ? T_imp!.jac_prototype : nothing - newtons_method_cache = isnothing(T_imp!) ? nothing : allocate_cache(newtons_method, u0, jac_prototype) + newtons_method_cache = + isnothing(T_imp!) || isnothing(newtons_method) ? nothing : allocate_cache(newtons_method, u0, jac_prototype) return IMEXARKCache(U, T_lim, T_exp, T_imp, temp, γ, newtons_method_cache) end +# generic fallback function step_u!(integrator, cache::IMEXARKCache) - (; u, p, t, dt, sol, alg) = integrator - (; f) = sol.prob - (; T_lim!, T_exp!, T_imp!, lim!, dss!, stage_callback!) = f - (; name, tableau, newtons_method) = alg + (; u, p, t, dt, alg) = integrator + (; f) = integrator.sol.prob + (; post_explicit!, post_implicit!) = f + (; T_lim!, T_exp!, T_imp!, lim!, dss!) = f + (; tableau, newtons_method) = alg (; a_exp, b_exp, a_imp, b_imp, c_exp, c_imp) = tableau (; U, T_lim, T_exp, T_imp, temp, γ, newtons_method_cache) = cache s = length(b_exp) - if !isnothing(T_imp!) - update!( - newtons_method, - newtons_method_cache, - NewTimeStep(t), - jacobian -> isnothing(γ) ? sdirk_error(name) : T_imp!.Wfact(jacobian, u, p, dt * γ, t), - ) + if !isnothing(T_imp!) && !isnothing(newtons_method) + (; update_j) = newtons_method + jacobian = newtons_method_cache.j + if (!isnothing(jacobian)) && needs_update!(update_j, NewTimeStep(t)) + if γ isa Nothing + sdirk_error(name) + else + T_imp!.Wfact(jacobian, u, p, dt * γ, t) + end + end end - for i in 1:s - t_exp = t + dt * c_exp[i] - t_imp = t + dt * c_imp[i] + update_stage!(integrator, cache, ntuple(i -> i, Val(s))) - @. U[i] = u + t_final = t + dt - if !isnothing(T_lim!) # Update based on limited tendencies from previous stages - for j in 1:(i - 1) - iszero(a_exp[i, j]) && continue - @. U[i] += dt * a_exp[i, j] * T_lim[j] - end - lim!(U[i], p, t_exp, u) - end + if has_T_lim(f) # Update based on limited tendencies from previous stages + assign_fused_increment!(temp, u, dt, b_exp, T_lim, Val(s)) + lim!(temp, p, t_final, u) + @. u = temp + end - if !isnothing(T_exp!) # Update based on explicit tendencies from previous stages - for j in 1:(i - 1) - iszero(a_exp[i, j]) && continue - @. U[i] += dt * a_exp[i, j] * T_exp[j] - end - end + # Update based on tendencies from previous stages + has_T_exp(f) && fused_increment!(u, dt, b_exp, T_exp, Val(s)) + isnothing(T_imp!) || fused_increment!(u, dt, b_imp, T_imp, Val(s)) - if !isnothing(T_imp!) # Update based on implicit tendencies from previous stages - for j in 1:(i - 1) - iszero(a_imp[i, j]) && continue - @. U[i] += dt * a_imp[i, j] * T_imp[j] - end - end + dss!(u, p, t_final) + post_explicit!(u, p, t_final) - dss!(U[i], p, t_exp) + return u +end - if !isnothing(T_imp!) && !iszero(a_imp[i, i]) # Implicit solve - @. temp = U[i] - # TODO: can/should we remove these closures? - implicit_equation_residual! = (residual, Ui) -> begin - T_imp!(residual, Ui, p, t_imp) - @. residual = temp + dt * a_imp[i, i] * residual - Ui - end - implicit_equation_jacobian! = (jacobian, Ui) -> T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp) - solve_newton!( - newtons_method, - newtons_method_cache, - U[i], - implicit_equation_residual!, - implicit_equation_jacobian!, - ) - end - # We do not need to DSS U[i] again because the implicit solve should - # give the same results for redundant columns (as long as the implicit - # tendency only acts in the vertical direction). - - if !all(iszero, a_imp[:, i]) || !iszero(b_imp[i]) - if !isnothing(T_imp!) - if iszero(a_imp[i, i]) - # If its coefficient is 0, T_imp[i] is effectively being - # treated explicitly. - T_imp!(T_imp[i], U[i], p, t_imp) - else - # If T_imp[i] is being treated implicitly, ensure that it - # exactly satisfies the implicit equation. - @. T_imp[i] = (U[i] - temp) / (dt * a_imp[i, i]) - end - end - end +@inline update_stage!(integrator, cache, ::Tuple{}) = nothing +@inline update_stage!(integrator, cache, is::Tuple{Int}) = update_stage!(integrator, cache, first(is)) +@inline function update_stage!(integrator, cache, is::Tuple) + update_stage!(integrator, cache, first(is)) + update_stage!(integrator, cache, Base.tail(is)) +end +@inline function update_stage!(integrator, cache::IMEXARKCache, i::Int) + (; u, p, t, dt, alg) = integrator + (; f) = integrator.sol.prob + (; post_explicit!, post_implicit!) = f + (; T_exp_T_lim!, T_lim!, T_exp!, T_imp!, lim!, dss!) = f + (; tableau, newtons_method) = alg + (; a_exp, b_exp, a_imp, b_imp, c_exp, c_imp) = tableau + (; U, T_lim, T_exp, T_imp, temp, γ, newtons_method_cache) = cache + s = length(b_exp) - stage_callback!(U[i], p, t_exp) + t_exp = t + dt * c_exp[i] + t_imp = t + dt * c_imp[i] - if !all(iszero, a_exp[:, i]) || !iszero(b_exp[i]) - if !isnothing(T_lim!) - T_lim!(T_lim[i], U[i], p, t_exp) - end - if !isnothing(T_exp!) - T_exp!(T_exp[i], U[i], p, t_exp) - end - end + if has_T_lim(f) # Update based on limited tendencies from previous stages + assign_fused_increment!(U, u, dt, a_exp, T_lim, Val(i)) + i ≠ 1 && lim!(U, p, t_exp, u) + else + @. U = u end - t_final = t + dt - - if !isnothing(T_lim!) # Update based on limited tendencies from previous stages - @. temp = u - for j in 1:s - iszero(b_exp[j]) && continue - @. temp += dt * b_exp[j] * T_lim[j] + # Update based on tendencies from previous stages + has_T_exp(f) && fused_increment!(U, dt, a_exp, T_exp, Val(i)) + isnothing(T_imp!) || fused_increment!(U, dt, a_imp, T_imp, Val(i)) + + if isnothing(T_imp!) || iszero(a_imp[i, i]) + i ≠ 1 && dss!(U, p, t_imp) + i ≠ 1 && post_explicit!(U, p, t_imp) + else # Implicit solve + @assert !isnothing(newtons_method) + @. temp = U + # We do not need to apply DSS yet because the implicit solve does not + # involve any horizontal derivatives. + i ≠ 1 && post_explicit!(U, p, t_imp) + # TODO: can/should we remove these closures? + implicit_equation_residual! = (residual, Ui) -> begin + T_imp!(residual, Ui, p, t_imp) + @. residual = temp + dt * a_imp[i, i] * residual - Ui end - lim!(temp, p, t_final, u) - @. u = temp + implicit_equation_jacobian! = (jacobian, Ui) -> T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp) + call_post_implicit! = Ui -> begin + post_implicit!(Ui, p, t_imp) + end + call_post_implicit_last! = Ui -> begin + dss!(Ui, p, t_imp) + post_implicit!(Ui, p, t_imp) + end + + solve_newton!( + newtons_method, + newtons_method_cache, + U, + implicit_equation_residual!, + implicit_equation_jacobian!, + call_post_implicit!, + call_post_implicit_last!, + ) end - if !isnothing(T_exp!) # Update based on explicit tendencies from previous stages - for j in 1:s - iszero(b_exp[j]) && continue - @. u += dt * b_exp[j] * T_exp[j] + if !all(iszero, a_imp[:, i]) || !iszero(b_imp[i]) + if iszero(a_imp[i, i]) + # If its coefficient is 0, T_imp[i] is effectively being + # treated explicitly. + isnothing(T_imp!) || T_imp!(T_imp[i], U, p, t_imp) + else + # If T_imp[i] is being treated implicitly, ensure that it + # exactly satisfies the implicit equation. + isnothing(T_imp!) || @. T_imp[i] = (U - temp) / (dt * a_imp[i, i]) end end - if !isnothing(T_imp!) # Update based on implicit tendencies from previous stages - for j in 1:s - iszero(b_imp[j]) && continue - @. u += dt * b_imp[j] * T_imp[j] + if !all(iszero, a_exp[:, i]) || !iszero(b_exp[i]) + if !isnothing(T_exp_T_lim!) + T_exp_T_lim!(T_exp[i], T_lim[i], U, p, t_exp) + else + isnothing(T_lim!) || T_lim!(T_lim[i], U, p, t_exp) + isnothing(T_exp!) || T_exp!(T_exp[i], U, p, t_exp) end end - dss!(u, p, t_final) - - return u + return nothing end diff --git a/src/solvers/imex_tableaus.jl b/src/solvers/imex_tableaus.jl index 974a3d14..5c0fc068 100644 --- a/src/solvers/imex_tableaus.jl +++ b/src/solvers/imex_tableaus.jl @@ -2,15 +2,10 @@ export IMEXTableau, IMEXAlgorithm export ARS111, ARS121, ARS122, ARS233, ARS232, ARS222, ARS343, ARS443 export IMKG232a, IMKG232b, IMKG242a, IMKG242b, IMKG243a, IMKG252a, IMKG252b export IMKG253a, IMKG253b, IMKG254a, IMKG254b, IMKG254c, IMKG342a, IMKG343a -export DBM453, HOMMEM1 export SSP222, SSP322, SSP332, SSP333, SSP433 +export DBM453, HOMMEM1, ARK2GKC, ARK437L2SA1, ARK548L2SA2 -using StaticArrays: @SArray, SMatrix, sacollect - -abstract type IMEXAlgorithmName <: AbstractAlgorithmName end -abstract type IMEXSSPRKAlgorithmName <: IMEXAlgorithmName end -default_constraint(::IMEXAlgorithmName) = Unconstrained() -default_constraint(::IMEXSSPRKAlgorithmName) = SSPConstrained() +abstract type IMEXARKAlgorithmName <: AbstractAlgorithmName end """ IMEXTableau(; a_exp, b_exp, c_exp, a_imp, b_imp, c_imp) @@ -24,14 +19,16 @@ default values for `c_exp` and `c_imp` assume that it is internally consistent. The explicit tableau must be strictly lower triangular, and the implicit tableau must be lower triangular (only DIRK algorithms are currently supported). """ -struct IMEXTableau{VS <: StaticArrays.StaticArray, MS <: StaticArrays.StaticArray} - a_exp::MS # matrix of size s×s - b_exp::VS # vector of length s - c_exp::VS # vector of length s - a_imp::MS # matrix of size s×s - b_imp::VS # vector of length s - c_imp::VS # vector of length s +struct IMEXTableau{AE <: SPCO, BE <: SPCO, CE <: SPCO, AI <: SPCO, BI <: SPCO, CI <: SPCO} + a_exp::AE # matrix of size s×s + b_exp::BE # vector of length s + c_exp::CE # vector of length s + a_imp::AI # matrix of size s×s + b_imp::BI # vector of length s + c_imp::CI # vector of length s end +IMEXTableau(args...) = IMEXTableau(map(x -> SparseCoeffs(x), args)...) + function IMEXTableau(; a_exp, b_exp = a_exp[end, :], @@ -52,22 +49,18 @@ end """ IMEXAlgorithm(tableau, newtons_method, [constraint]) IMEXAlgorithm(name, newtons_method, [constraint]) - [Name](newtons_method) Constructs an IMEX algorithm for solving ODEs, with an optional name and constraint. The first constructor accepts any `IMEXTableau` and an optional constraint, leaving the algorithm unnamed. The second constructor automatically -determines the tableau and the default constraint from the algorithm name, which -must be an `IMEXAlgorithmName`. - -The last constructor matches the notation of `OrdinaryDiffEq.jl`; it dispatches -to the second constructor by returning `IMEXAlgorithm(Name(), newtons_method)`. +determines the tableau and the default constraint from the algorithm name, +which must be an `IMEXARKAlgorithmName`. """ struct IMEXAlgorithm{ C <: AbstractAlgorithmConstraint, - N <: Union{Nothing, IMEXAlgorithmName}, + N <: Union{Nothing, AbstractAlgorithmName}, T <: IMEXTableau, - NM <: NewtonsMethod, + NM <: Union{Nothing, NewtonsMethod}, } <: DistributedODEAlgorithm constraint::C name::N @@ -76,39 +69,24 @@ struct IMEXAlgorithm{ end IMEXAlgorithm(tableau::IMEXTableau, newtons_method, constraint = Unconstrained()) = IMEXAlgorithm(constraint, nothing, tableau, newtons_method) -IMEXAlgorithm(name::IMEXAlgorithmName, newtons_method, constraint = default_constraint(name)) = +IMEXAlgorithm(name::IMEXARKAlgorithmName, newtons_method, constraint = default_constraint(name)) = IMEXAlgorithm(constraint, name, IMEXTableau(name), newtons_method) -# If all AbstractAlgorithmNames were singletons, we could make type-based -# functors, but some AbstractAlgorithmNames have parameters (e.g., SSP333) - -# (::Type{Name})(newtons_method::NewtonsMethod) where {Name <: IMEXAlgorithmName} = IMEXAlgorithm(Name(), newtons_method) - -#= Convenience constructor =# -(name::IMEXAlgorithmName)(newtons_method::NewtonsMethod) = IMEXAlgorithm(name, newtons_method) - ################################################################################ # ARS algorithms -# From Section 2 of "Implicit-Explicit Runge-Kutta Methods for Time-Dependent -# Partial Differential Equations" by Ascher et al. - -# Each algorithm is named ARSsσp, where s is the number of implicit stages, σ is -# the number of explicit stages, and p is the order of accuracy - -# This algorithm is equivalent to OrdinaryDiffEq.IMEXEuler. +# The naming convention is ARSsσp, where s is the number of implicit stages, +# σ is the number of explicit stages, and p is the order of accuracy. """ ARS111 -The Forward-Backward (1,1,1) implicit-explicit (IMEX) Runge-Kutta scheme of -[ARS1997](@cite), section 2.1. - -This is equivalent to the `OrdinaryDiffEq.IMEXEuler` algorithm. +An IMEX ARK algorithm from [ARS1997](@cite), section 2, with 1 implicit stage, +1 explicit stage and 1st order accuracy. Also called *IMEX Euler* or +*forward-backward Euler*; equivalent to `OrdinaryDiffEq.IMEXEuler`. """ -struct ARS111 <: IMEXAlgorithmName end - +struct ARS111 <: IMEXARKAlgorithmName end function IMEXTableau(::ARS111) IMEXTableau(; a_exp = @SArray([0 0; 1 0]), a_imp = @SArray([0 0; 0 1])) end @@ -116,28 +94,38 @@ end """ ARS121 -The Forward-Backward (1,2,1) implicit-explicit (IMEX) Runge-Kutta scheme of -[ARS1997](@cite), section 2.2. - -This is equivalent to the `OrdinaryDiffEq.IMEXEulerARK` algorithm. +An IMEX ARK algorithm from [ARS1997](@cite), section 2, with 1 implicit stage, 2 +explicit stages, and 1st order accuracy. Also called *IMEX Euler* or +*forward-backward Euler*; equivalent to `OrdinaryDiffEq.IMEXEulerARK`. """ -struct ARS121 <: IMEXAlgorithmName end - +struct ARS121 <: IMEXARKAlgorithmName end function IMEXTableau(::ARS121) IMEXTableau(; a_exp = @SArray([0 0; 1 0]), b_exp = @SArray([0, 1]), a_imp = @SArray([0 0; 0 1])) end -struct ARS122 <: IMEXAlgorithmName end +""" + ARS122 + +An IMEX ARK algorithm from [ARS1997](@cite), section 2, with 1 implicit stage, 2 +explicit stages, and 2nd order accuracy. Also called *IMEX midpoint*. +""" +struct ARS122 <: IMEXARKAlgorithmName end function IMEXTableau(::ARS122) IMEXTableau(; a_exp = @SArray([0 0; 1/2 0]), b_exp = @SArray([0, 1]), a_imp = @SArray([0 0; 0 1/2]), - b_imp = @SArray([0, 1]), + b_imp = @SArray([0, 1]) ) end -struct ARS233 <: IMEXAlgorithmName end +""" + ARS233 + +An IMEX ARK algorithm from [ARS1997](@cite), section 2, with 2 implicit stages, +3 explicit stages, and 3rd order accuracy. +""" +struct ARS233 <: IMEXARKAlgorithmName end function IMEXTableau(::ARS233) γ = 1 / 2 + √3 / 6 IMEXTableau(; @@ -152,17 +140,17 @@ function IMEXTableau(::ARS233) 0 γ 0 0 (1-2γ) γ ]), - b_imp = @SArray([0, 1 / 2, 1 / 2]), + b_imp = @SArray([0, 1 / 2, 1 / 2]) ) end """ ARS232 -The Forward-Backward (2,3,2) implicit-explicit (IMEX) Runge-Kutta scheme of -[ARS1997](@cite), section 2.5. +An IMEX ARK algorithm from [ARS1997](@cite), section 2, with 2 implicit stages, +3 explicit stages, and 2nd order accuracy. """ -struct ARS232 <: IMEXAlgorithmName end +struct ARS232 <: IMEXARKAlgorithmName end function IMEXTableau(::ARS232) γ = 1 - √2 / 2 δ = -2√2 / 3 @@ -177,11 +165,17 @@ function IMEXTableau(::ARS232) 0 0 0 0 γ 0 0 (1-γ) γ - ]), + ]) ) end -struct ARS222 <: IMEXAlgorithmName end +""" + ARS222 + +An IMEX ARK algorithm from [ARS1997](@cite), section 2, with 2 implicit stages, +2 explicit stages, and 2nd order accuracy. +""" +struct ARS222 <: IMEXARKAlgorithmName end function IMEXTableau(::ARS222) γ = 1 - √2 / 2 δ = 1 - 1 / 2γ @@ -199,10 +193,10 @@ end """ ARS343 -The L-stable, third-order (3,4,3) implicit-explicit (IMEX) Runge-Kutta scheme of -[ARS1997](@cite), section 2.7. +An IMEX ARK algorithm from [ARS1997](@cite), section 2, with 3 implicit stages, +4 explicit stages, and 3rd order accuracy. """ -struct ARS343 <: IMEXAlgorithmName end +struct ARS343 <: IMEXARKAlgorithmName end function IMEXTableau(::ARS343) γ = 0.4358665215084590 a42 = 0.5529291480359398 @@ -228,11 +222,17 @@ function IMEXTableau(::ARS343) 0 γ 0 0 0 (1 - γ)/2 γ 0 0 b1 b2 γ - ]), + ]) ) end -struct ARS443 <: IMEXAlgorithmName end +""" + ARS443 + +An IMEX ARK algorithm from [ARS1997](@cite), section 2, with 4 implicit stages, +4 explicit stages, and 3rd order accuracy. +""" +struct ARS443 <: IMEXARKAlgorithmName end function IMEXTableau(::ARS443) IMEXTableau(; a_exp = @SArray([ @@ -248,7 +248,7 @@ function IMEXTableau(::ARS443) 0 1/6 1/2 0 0 0 -1/2 1/2 1/2 0 0 3/2 -3/2 1/2 1/2 - ]), + ]) ) end @@ -256,18 +256,15 @@ end # IMKG algorithms -# From Tables 3 and 4 of "Efficient IMEX Runge-Kutta Methods for Nonhydrostatic -# Dynamics" by Steyer et al. - -# Each algorithm is named IMKGpfjl, where p is the order of accuracy, f is the +# The naming convention is IMKGpfjl, where p is the order of accuracy, f is the # number of explicit stages, j is the number of implicit stages, and l is an -# identifying letter +# identifying letter. # TODO: Tables 3 and 4 are riddled with typos, but most of these can be easily # identified and corrected by referencing the implementations in HOMME: -# https://github.com/E3SM-Project/E3SM/blob/master/components/homme/src/arkode/arkode_tables.F90 +# https://github.com/E3SM-Project/E3SM/blob/v2.0.0/components/homme/src/arkode/arkode_tables.F90 # Unfortunately, the implementations of IMKG353a and IMKG354a in HOMME also -# appear to be wrong, so they are not included here. Eventually, we should get +# appear to be wrong, so they are left unimplemented. Eventually, we should get # the official implementations from the paper's authors. # a_exp: a_imp: @@ -284,49 +281,96 @@ function IMKGTableau(α, α̂, δ̂, β = ntuple(_ -> 0, length(δ̂))) s = length(α̂) + 1 type = SMatrix{s, s} return IMEXTableau(; - a_exp = sacollect(type, imkg_exp(i, j, α, β) for i in 1:s, j in 1:s), - a_imp = sacollect(type, imkg_imp(i, j, α̂, β, δ̂) for i in 1:s, j in 1:s), + a_exp = StaticArrays.sacollect(type, imkg_exp(i, j, α, β) for i in 1:s, j in 1:s), + a_imp = StaticArrays.sacollect(type, imkg_imp(i, j, α̂, β, δ̂) for i in 1:s, j in 1:s), ) end -struct IMKG232a <: IMEXAlgorithmName end +""" + IMKG232a + +An IMEX ARK algorithm from [SVTG2019](@cite), Table 3, with 2 implicit stages, +3 explicit stages, and 2nd order accuracy. +""" +struct IMKG232a <: IMEXARKAlgorithmName end function IMEXTableau(::IMKG232a) IMKGTableau((1 / 2, 1 / 2, 1), (0, -1 / 2 + √2 / 2, 1), (1 - √2 / 2, 1 - √2 / 2)) end -struct IMKG232b <: IMEXAlgorithmName end +""" + IMKG232b + +An IMEX ARK algorithm from [SVTG2019](@cite), Table 3, with 2 implicit stages, +3 explicit stages, and 2nd order accuracy. +""" +struct IMKG232b <: IMEXARKAlgorithmName end function IMEXTableau(::IMKG232b) IMKGTableau((1 / 2, 1 / 2, 1), (0, -1 / 2 - √2 / 2, 1), (1 + √2 / 2, 1 + √2 / 2)) end -struct IMKG242a <: IMEXAlgorithmName end +""" + IMKG242a + +An IMEX ARK algorithm from [SVTG2019](@cite), Table 3, with 2 implicit stages, +4 explicit stages, and 2nd order accuracy. +""" +struct IMKG242a <: IMEXARKAlgorithmName end function IMEXTableau(::IMKG242a) IMKGTableau((1 / 4, 1 / 3, 1 / 2, 1), (0, 0, -1 / 2 + √2 / 2, 1), (0, 1 - √2 / 2, 1 - √2 / 2)) end -struct IMKG242b <: IMEXAlgorithmName end +""" + IMKG242b + +An IMEX ARK algorithm from [SVTG2019](@cite), Table 3, with 2 implicit stages, +4 explicit stages, and 2nd order accuracy. +""" +struct IMKG242b <: IMEXARKAlgorithmName end function IMEXTableau(::IMKG242b) IMKGTableau((1 / 4, 1 / 3, 1 / 2, 1), (0, 0, -1 / 2 - √2 / 2, 1), (0, 1 + √2 / 2, 1 + √2 / 2)) end -# The paper uses √3/6 for α̂[3], which also seems to work. -struct IMKG243a <: IMEXAlgorithmName end +""" + IMKG243a + +An IMEX ARK algorithm from [SVTG2019](@cite), Table 3, with 3 implicit stages, +4 explicit stages, and 2nd order accuracy. +""" +struct IMKG243a <: IMEXARKAlgorithmName end function IMEXTableau(::IMKG243a) IMKGTableau((1 / 4, 1 / 3, 1 / 2, 1), (0, 1 / 6, -√3 / 6, 1), (1 / 2 + √3 / 6, 1 / 2 + √3 / 6, 1 / 2 + √3 / 6)) end +# The paper uses √3/6 for α̂[3], which also seems to work. -struct IMKG252a <: IMEXAlgorithmName end +""" + IMKG252a + +An IMEX ARK algorithm from [SVTG2019](@cite), Table 3, with 2 implicit stages, +5 explicit stages, and 2nd order accuracy. +""" +struct IMKG252a <: IMEXARKAlgorithmName end function IMEXTableau(::IMKG252a) IMKGTableau((1 / 4, 1 / 6, 3 / 8, 1 / 2, 1), (0, 0, 0, -1 / 2 + √2 / 2, 1), (0, 0, 1 - √2 / 2, 1 - √2 / 2)) end -struct IMKG252b <: IMEXAlgorithmName end +""" + IMKG252b + +An IMEX ARK algorithm from [SVTG2019](@cite), Table 3, with 2 implicit stages, +5 explicit stages, and 2nd order accuracy. +""" +struct IMKG252b <: IMEXARKAlgorithmName end function IMEXTableau(::IMKG252b) IMKGTableau((1 / 4, 1 / 6, 3 / 8, 1 / 2, 1), (0, 0, 0, -1 / 2 - √2 / 2, 1), (0, 0, 1 + √2 / 2, 1 + √2 / 2)) end -# The paper uses 0.08931639747704086 for α̂[3], which also seems to work. -struct IMKG253a <: IMEXAlgorithmName end +""" + IMKG253a + +An IMEX ARK algorithm from [SVTG2019](@cite), Table 3, with 3 implicit stages, +5 explicit stages, and 2nd order accuracy. +""" +struct IMKG253a <: IMEXARKAlgorithmName end function IMEXTableau(::IMKG253a) IMKGTableau( (1 / 4, 1 / 6, 3 / 8, 1 / 2, 1), @@ -334,9 +378,15 @@ function IMEXTableau(::IMKG253a) (0, 1 / 2 - √3 / 6, 1 / 2 - √3 / 6, 1 / 2 - √3 / 6), ) end +# The paper uses 0.08931639747704086 for α̂[3], which also seems to work. -# The paper uses 1.2440169358562922 for α̂[3], which also seems to work. -struct IMKG253b <: IMEXAlgorithmName end +""" + IMKG253b + +An IMEX ARK algorithm from [SVTG2019](@cite), Table 3, with 3 implicit stages, +5 explicit stages, and 2nd order accuracy. +""" +struct IMKG253b <: IMEXARKAlgorithmName end function IMEXTableau(::IMKG253b) IMKGTableau( (1 / 4, 1 / 6, 3 / 8, 1 / 2, 1), @@ -344,32 +394,48 @@ function IMEXTableau(::IMKG253b) (0, 1 / 2 + √3 / 6, 1 / 2 + √3 / 6, 1 / 2 + √3 / 6), ) end +# The paper uses 1.2440169358562922 for α̂[3], which also seems to work. + +""" + IMKG254a -struct IMKG254a <: IMEXAlgorithmName end +An IMEX ARK algorithm from [SVTG2019](@cite), Table 3, with 4 implicit stages, +5 explicit stages, and 2nd order accuracy. +""" +struct IMKG254a <: IMEXARKAlgorithmName end function IMEXTableau(::IMKG254a) IMKGTableau((1 / 4, 1 / 6, 3 / 8, 1 / 2, 1), (0, -3 / 10, 5 / 6, -3 / 2, 1), (-1 / 2, 1, 1, 2)) end -struct IMKG254b <: IMEXAlgorithmName end +""" + IMKG254b + +An IMEX ARK algorithm from [SVTG2019](@cite), Table 3, with 4 implicit stages, +5 explicit stages, and 2nd order accuracy. +""" +struct IMKG254b <: IMEXARKAlgorithmName end function IMEXTableau(::IMKG254b) IMKGTableau((1 / 4, 1 / 6, 3 / 8, 1 / 2, 1), (0, -1 / 20, 5 / 4, -1 / 2, 1), (-1 / 2, 1, 1, 1)) end -struct IMKG254c <: IMEXAlgorithmName end +""" + IMKG254c + +An IMEX ARK algorithm from [SVTG2019](@cite), Table 3, with 4 implicit stages, +5 explicit stages, and 2nd order accuracy. +""" +struct IMKG254c <: IMEXARKAlgorithmName end function IMEXTableau(::IMKG254c) IMKGTableau((1 / 4, 1 / 6, 3 / 8, 1 / 2, 1), (0, 1 / 20, 5 / 36, 1 / 3, 1), (1 / 6, 1 / 6, 1 / 6, 1 / 6)) end -# The paper and HOMME completely disagree on this algorithm. Since the version -# in the paper is not "342" (it appears to be "332"), the version from HOMME is -# used here. -# const IMKG342a = IMKGTableau( -# (0, 1/3, 1/3, 3/4), -# (0, -1/6 - √3/6, -1/6 - √3/6, 3/4), -# (0, 1/2 + √3/6, 1/2 + √3/6), -# (1/3, 1/3, 1/4), -# ) -struct IMKG342a <: IMEXAlgorithmName end +""" + IMKG342a + +An IMEX ARK algorithm from [SVTG2019](@cite), Table 4, with 2 implicit stages, +4 explicit stages, and 3rd order accuracy. +""" +struct IMKG342a <: IMEXARKAlgorithmName end function IMEXTableau(::IMKG342a) IMKGTableau( (1 / 4, 2 / 3, 1 / 3, 3 / 4), @@ -378,121 +444,71 @@ function IMEXTableau(::IMKG342a) (0, 1 / 3, 1 / 4), ) end +# The paper and HOMME completely disagree on IMKG342a. Since the version in the +# paper is not "342" (it appears to be "332"), the version from HOMME is used +# here. The paper's version is +# IMKGTableau( +# (0, 1/3, 1/3, 3/4), +# (0, -1/6 - √3/6, -1/6 - √3/6, 3/4), +# (0, 1/2 + √3/6, 1/2 + √3/6), +# (1/3, 1/3, 1/4), +# ) + +""" + IMKG343a -struct IMKG343a <: IMEXAlgorithmName end +An IMEX ARK algorithm from [SVTG2019](@cite), Table 4, with 3 implicit stages, +4 explicit stages, and 3rd order accuracy. +""" +struct IMKG343a <: IMEXARKAlgorithmName end function IMEXTableau(::IMKG343a) IMKGTableau((1 / 4, 2 / 3, 1 / 3, 3 / 4), (0, -1 / 3, -2 / 3, 3 / 4), (-1 / 3, 1, 1), (0, 1 / 3, 1 / 4)) end -# The paper and HOMME completely disagree on this algorithm, but neither version -# is "353" (they appear to be "343" and "354", respectively). -# struct IMKG353a <: IMEXAlgorithmName end -# function IMEXTableau(::IMKG353a) -# IMKGTableau( -# (1/4, 2/3, 1/3, 3/4), -# (0, -359/600, -559/600, 3/4), -# (-1.1678009811335388, 253/200, 253/200), -# (0, 1/3, 1/4), -# ) -# end -# struct IMKG353a <: IMEXAlgorithmName end -# function IMEXTableau(::IMKG353a) -# IMKGTableau( -# (-0.017391304347826087, -23/25, 5/3, 1/3, 3/4), -# (0.3075640504095504, -1.2990164859879263, 751/600, -49/60, 3/4), -# (-0.2981612530370581, 83/200, 83/200, 23/20), -# (1, -1, 1/3, 1/4), -# ) -# end - -# The version of this algorithm in the paper is not "354" (it appears to be -# "253"), and this algorithm is missing from HOMME (or, more precisely, the -# tableau for IMKG353a is mistakenly used to define IMKG354a, and the tableau -# for IMKG354a is not specified). -# struct IMKG354a <: IMEXAlgorithmName end -# function IMEXTableau(::IMKG354a) -# IMKGTableau( -# (1/5, 1/5, 2/3, 1/3, 3/4), -# (0, 0, 11/30, -2/3, 3/4), -# (0, 2/4, 2/5, 1), -# (0, 0, 1/3, 1/4), -# ) -# end - -################################################################################ - -# DBM algorithm - -# From Appendix A of "Evaluation of Implicit-Explicit Additive Runge-Kutta -# Integrators for the HOMME-NH Dynamical Core" by Vogl et al. - -# The algorithm has 4 implicit stages, 5 overall stages, and 3rd order accuracy. +# The paper and HOMME completely disagree on IMKG353a, but neither version +# is "353" (they appear to be "343" and "354", respectively). The paper's +# version is +# IMKGTableau( +# (1/4, 2/3, 1/3, 3/4), +# (0, -359/600, -559/600, 3/4), +# (-1.1678009811335388, 253/200, 253/200), +# (0, 1/3, 1/4), +# ) +# HOMME's version is +# IMKGTableau( +# (-0.017391304347826087, -23/25, 5/3, 1/3, 3/4), +# (0.3075640504095504, -1.2990164859879263, 751/600, -49/60, 3/4), +# (-0.2981612530370581, 83/200, 83/200, 23/20), +# (1, -1, 1/3, 1/4), +# ) -struct DBM453 <: IMEXAlgorithmName end -function IMEXTableau(::DBM453) - γ = 0.32591194130117247 - IMEXTableau(; - a_exp = @SArray( - [ - 0 0 0 0 0 - 0.10306208811591838 0 0 0 0 - -0.94124866143519894 1.6626399742527356 0 0 0 - -1.3670975201437765 1.3815852911016873 1.2673234025619065 0 0 - -0.81287582068772448 0.81223739060505738 0.90644429603699305 0.094194134045674111 0 - ] - ), - b_exp = @SArray([0.87795339639076675, -0.72692641526151547, 0.7520413715737272, -0.22898029400415088, γ]), - a_imp = @SArray( - [ - 0 0 0 0 0 - -0.2228498531852541 γ 0 0 0 - -0.46801347074080545 0.86349284225716961 γ 0 0 - -0.46509906651927421 0.81063103116959553 0.61036726756832357 γ 0 - 0.87795339639076675 -0.72692641526151547 0.7520413715737272 -0.22898029400415088 γ - ] - ), - ) -end +# The version of IMKG354a in the paper is not "354" (it appears to be "253"), +# and IMKG354a is missing from HOMME (or, more precisely, the tableau for +# IMKG353a is mistakenly used to define IMKG354a, and the tableau for IMKG354a +# is not specified). The paper's version is +# IMKGTableau( +# (1/5, 1/5, 2/3, 1/3, 3/4), +# (0, 0, 11/30, -2/3, 3/4), +# (0, 2/4, 2/5, 1), +# (0, 0, 1/3, 1/4), +# ) ################################################################################ -# HOMMEM1 algorithm - -# From Section 4.1 of "A framework to evaluate IMEX schemes for atmospheric -# models" by Guba et al. +# IMEX SSPRK algorithms -# The algorithm has 5 implicit stages, 6 overall stages, and 2rd order accuracy. - -struct HOMMEM1 <: IMEXAlgorithmName end -function IMEXTableau(::HOMMEM1) - IMEXTableau(; - a_exp = @SArray([ - 0 0 0 0 0 0 - 1/5 0 0 0 0 0 - 0 1/5 0 0 0 0 - 0 0 1/3 0 0 0 - 0 0 0 1/2 0 0 - 0 0 0 0 1 0 - ]), - a_imp = @SArray([ - 0 0 0 0 0 0 - 0 1/5 0 0 0 0 - 0 0 1/5 0 0 0 - 0 0 0 1/3 0 0 - 0 0 0 0 1/2 0 - 5/18 5/18 0 0 0 8/18 - ]), - ) -end +# The naming convention is SSPsσp, where s is the number of implicit stages, +# σ is the number of explicit stages, and p is the order of accuracy. -################################################################################ +abstract type IMEXSSPRKAlgorithmName <: IMEXARKAlgorithmName end -# IMEX SSP algorithms +default_constraint(::IMEXSSPRKAlgorithmName) = SSP() """ SSP222 -https://link.springer.com/content/pdf/10.1007/BF02728986.pdf, Table II +An IMEX SSPRK algorithm from [PR2005](@cite), with 2 implicit stages, 2 explicit +stages, and 2nd order accuracy. Also called *SSP2(222)* in [GGHRUW2018](@cite). """ struct SSP222 <: IMEXSSPRKAlgorithmName end function IMEXTableau(::SSP222) @@ -507,14 +523,15 @@ function IMEXTableau(::SSP222) γ 0 (1-2γ) γ ]), - b_imp = @SArray([1 / 2, 1 / 2]), + b_imp = @SArray([1 / 2, 1 / 2]) ) end """ SSP322 -https://link.springer.com/content/pdf/10.1007/BF02728986.pdf, Table III +An IMEX SSPRK algorithm from [PR2005](@cite), with 3 implicit stages, 2 explicit +stages, and 2nd order accuracy. """ struct SSP322 <: IMEXSSPRKAlgorithmName end function IMEXTableau(::SSP322) @@ -530,14 +547,15 @@ function IMEXTableau(::SSP322) -1/2 1/2 0 0 1/2 1/2 ]), - b_imp = @SArray([0, 1 / 2, 1 / 2]), + b_imp = @SArray([0, 1 / 2, 1 / 2]) ) end """ SSP332 -https://link.springer.com/content/pdf/10.1007/BF02728986.pdf, Table V +An IMEX SSPRK algorithm from [PR2005](@cite), with 3 implicit stages, 3 explicit +stages, and 2nd order accuracy. Also called *SSP2(332)a* in [GGHRUW2018](@cite). """ struct SSP332 <: IMEXSSPRKAlgorithmName end function IMEXTableau(::SSP332) @@ -554,17 +572,17 @@ function IMEXTableau(::SSP332) (1-2γ) γ 0 (1 / 2-γ) 0 γ ]), - b_imp = @SArray([1 / 6, 1 / 6, 2 / 3]), + b_imp = @SArray([1 / 6, 1 / 6, 2 / 3]) ) end """ - SSP333([β]) + SSP333(; β = 1/2 + √3/6) -Family of SSP333 algorithms parametrized by the value β, from Section 3.2 of -https://arxiv.org/pdf/1702.04621.pdf. The default value of β, 1/2 + √3/6, -results in an SDIRK algorithm, which is also called SSP3(333)c in -https://gmd.copernicus.org/articles/11/1497/2018/gmd-11-1497-2018.pdf. +Family of IMEX SSPRK algorithms parametrized by the value β from +[CGGS2017](@cite), Section 3.2, with 3 implicit stages, 3 explicit stages, and +3rd order accuracy. The default value of β results in an SDIRK algorithm, which +is also called *SSP3(333)c* in [GGHRUW2018](@cite). """ Base.@kwdef struct SSP333{FT <: AbstractFloat} <: IMEXSSPRKAlgorithmName β::FT = 1 / 2 + √3 / 6 @@ -584,14 +602,15 @@ function IMEXTableau((; β)::SSP333) (4γ+2β) (1 - 4γ-2β) 0 (1 / 2 - β-γ) γ β ]), - b_imp = @SArray([1 / 6, 1 / 6, 2 / 3]), + b_imp = @SArray([1 / 6, 1 / 6, 2 / 3]) ) end """ SSP433 -https://link.springer.com/content/pdf/10.1007/BF02728986.pdf, Table VI +An IMEX SSPRK algorithm from [PR2005](@cite), with 4 implicit stages, 3 explicit +stages, and 3rd order accuracy. Also called *SSP3(433)* in [GGHRUW2018](@cite). """ struct SSP433 <: IMEXSSPRKAlgorithmName end function IMEXTableau(::SSP433) @@ -612,6 +631,285 @@ function IMEXTableau(::SSP433) 0 (1-α) α 0 β η (1 / 2 - α - β-η) α ]), - b_imp = @SArray([0, 1 / 6, 1 / 6, 2 / 3]), + b_imp = @SArray([0, 1 / 6, 1 / 6, 2 / 3]) + ) +end + +################################################################################ + +# Miscellaneous algorithms + +""" + DBM453 + +An IMEX ARK algorithm from [VSRUW2019](@cite), Appendix A, with 4 implicit +stages, 5 explicit stages, and 3rd order accuracy. +""" +struct DBM453 <: IMEXARKAlgorithmName end +function IMEXTableau(::DBM453) + γ = 0.32591194130117247 + IMEXTableau(; + a_exp = @SArray( + [ + 0 0 0 0 0 + 0.10306208811591838 0 0 0 0 + -0.94124866143519894 1.6626399742527356 0 0 0 + -1.3670975201437765 1.3815852911016873 1.2673234025619065 0 0 + -0.81287582068772448 0.81223739060505738 0.90644429603699305 0.094194134045674111 0 + ] + ), + b_exp = @SArray([0.87795339639076675, -0.72692641526151547, 0.7520413715737272, -0.22898029400415088, γ]), + a_imp = @SArray( + [ + 0 0 0 0 0 + -0.2228498531852541 γ 0 0 0 + -0.46801347074080545 0.86349284225716961 γ 0 0 + -0.46509906651927421 0.81063103116959553 0.61036726756832357 γ 0 + 0.87795339639076675 -0.72692641526151547 0.7520413715737272 -0.22898029400415088 γ + ] + ) + ) +end + +""" + HOMMEM1 + +An IMEX ARK algorithm from [GTBBS2020](@cite), section 4.1, with 5 implicit +stages, 6 explicit stages, and 2nd order accuracy. +""" +struct HOMMEM1 <: IMEXARKAlgorithmName end +function IMEXTableau(::HOMMEM1) + IMEXTableau(; + a_exp = @SArray([ + 0 0 0 0 0 0 + 1/5 0 0 0 0 0 + 0 1/5 0 0 0 0 + 0 0 1/3 0 0 0 + 0 0 0 1/2 0 0 + 0 0 0 0 1 0 + ]), + a_imp = @SArray([ + 0 0 0 0 0 0 + 0 1/5 0 0 0 0 + 0 0 1/5 0 0 0 + 0 0 0 1/3 0 0 + 0 0 0 0 1/2 0 + 5/18 5/18 0 0 0 8/18 + ]) + ) +end + +""" + ARK2GKC(; paper_version = false) + +An IMEX ARK algorithm from [GKC2013](@cite) with 2 implicit stages, 3 explicit +stages, and 2nd order accuracy. If `paper_version = true`, the algorithm uses +coefficients from the paper. Otherwise, it uses coefficients that make it more +stable but less accurate. +""" +Base.@kwdef struct ARK2GKC <: IMEXARKAlgorithmName + paper_version::Bool = false +end +function IMEXTableau((; paper_version)::ARK2GKC) + a32 = paper_version ? 1 / 2 + √2 / 3 : 1 / 2 + IMEXTableau(; + a_exp = @SArray([ + 0 0 0 + (2-√2) 0 0 + (1-a32) a32 0 + ]), + b_exp = @SArray([√2 / 4, √2 / 4, 1 - √2 / 2]), + a_imp = @SArray([ + 0 0 0 + (1-√2 / 2) (1-√2 / 2) 0 + √2/4 √2/4 (1-√2 / 2) + ]) + ) +end + +""" + ARK437L2SA1 + +An IMEX ARK algorithm from [KC2019](@cite), Table 8, with 6 implicit stages, 7 +explicit stages, and 4th order accuracy. Written as *ARK4(3)7L[2]SA₁* in the +paper. +""" +struct ARK437L2SA1 <: IMEXARKAlgorithmName end +function IMEXTableau(::ARK437L2SA1) + a_exp = zeros(Rational{Int64}, 7, 7) + a_imp = zeros(Rational{Int64}, 7, 7) + b = zeros(Rational{Int64}, 7) + c = zeros(Rational{Int64}, 7) + + γ = 1235 // 10000 + for i in 2:7 + a_imp[i, i] = γ + end + + a_imp[3, 2] = 624185399699 // 4186980696204 + a_imp[4, 2] = 1258591069120 // 10082082980243 + a_imp[4, 3] = -322722984531 // 8455138723562 + a_imp[5, 2] = -436103496990 // 5971407786587 + a_imp[5, 3] = -2689175662187 // 11046760208243 + a_imp[5, 4] = 4431412449334 // 12995360898505 + a_imp[6, 2] = -2207373168298 // 14430576638973 + a_imp[6, 3] = 242511121179 // 3358618340039 + a_imp[6, 4] = 3145666661981 // 7780404714551 + a_imp[6, 5] = 5882073923981 // 14490790706663 + a_imp[7, 2] = 0 + a_imp[7, 3] = 9164257142617 // 17756377923965 + a_imp[7, 4] = -10812980402763 // 74029279521829 + a_imp[7, 5] = 1335994250573 // 5691609445217 + a_imp[7, 6] = 2273837961795 // 8368240463276 + + a_exp[3, 1] = 247 // 4000 + a_exp[3, 2] = 2694949928731 // 7487940209513 + a_exp[4, 1] = 464650059369 // 8764239774964 + a_exp[4, 2] = 878889893998 // 2444806327765 + a_exp[4, 3] = -952945855348 // 12294611323341 + a_exp[5, 1] = 476636172619 // 8159180917465 + a_exp[5, 2] = -1271469283451 // 7793814740893 + a_exp[5, 3] = -859560642026 // 4356155882851 + a_exp[5, 4] = 1723805262919 // 4571918432560 + a_exp[6, 1] = 6338158500785 // 11769362343261 + a_exp[6, 2] = -4970555480458 // 10924838743837 + a_exp[6, 3] = 3326578051521 // 2647936831840 + a_exp[6, 4] = -880713585975 // 1841400956686 + a_exp[6, 5] = -1428733748635 // 8843423958496 + a_exp[7, 2] = 760814592956 // 3276306540349 + a_exp[7, 3] = -47223648122716 // 6934462133451 + a_exp[7, 4] = 71187472546993 // 9669769126921 + a_exp[7, 5] = -13330509492149 // 9695768672337 + a_exp[7, 6] = 11565764226357 // 8513123442827 + + b[2] = 0 + b[3] = 9164257142617 // 17756377923965 + b[4] = -10812980402763 // 74029279521829 + b[5] = 1335994250573 // 5691609445217 + b[6] = 2273837961795 // 8368240463276 + b[7] = 247 // 2000 + + c[2] = 247 // 1000 + c[3] = 4276536705230 // 10142255878289 + c[4] = 67 // 200 + c[5] = 3 // 40 + c[6] = 7 // 10 + + for i in 2:7 + a_imp[i, 1] = a_imp[i, 2] + end + b[1] = b[2] + a_exp[2, 1] = c[2] + a_exp[7, 1] = a_exp[7, 2] + c[1] = 0 + c[7] = 1 + + IMEXTableau(; + a_exp = SArray{Tuple{7, 7}}(a_exp), + b_exp = SArray{Tuple{7}}(b), + c_exp = SArray{Tuple{7}}(c), + a_imp = SArray{Tuple{7, 7}}(a_imp), + b_imp = SArray{Tuple{7}}(b), + c_imp = SArray{Tuple{7}}(c), + ) +end + +""" + ARK548L2SA2 + +An IMEX ARK algorithm from [KC2019](@cite), Table 8, with 7 implicit stages, 8 +explicit stages, and 5th order accuracy. Written as *ARK5(4)8L[2]SA₂* in the +paper. +""" +struct ARK548L2SA2 <: IMEXARKAlgorithmName end +function IMEXTableau(::ARK548L2SA2) + a_exp = zeros(Rational{Int64}, 8, 8) + a_imp = zeros(Rational{Int64}, 8, 8) + b = zeros(Rational{Int64}, 8) + c = zeros(Rational{Int64}, 8) + + γ = 2 // 9 + for i in 2:8 + a_imp[i, i] = γ + end + + a_imp[3, 2] = 2366667076620 // 8822750406821 + a_imp[4, 2] = -257962897183 // 4451812247028 + a_imp[4, 3] = 128530224461 // 14379561246022 + a_imp[5, 2] = -486229321650 // 11227943450093 + a_imp[5, 3] = -225633144460 // 6633558740617 + a_imp[5, 4] = 1741320951451 // 6824444397158 + a_imp[6, 2] = 621307788657 // 4714163060173 + a_imp[6, 3] = -125196015625 // 3866852212004 + a_imp[6, 4] = 940440206406 // 7593089888465 + a_imp[6, 5] = 961109811699 // 6734810228204 + a_imp[7, 2] = 2036305566805 // 6583108094622 + a_imp[7, 3] = -3039402635899 // 4450598839912 + a_imp[7, 4] = -1829510709469 // 31102090912115 + a_imp[7, 5] = -286320471013 // 6931253422520 + a_imp[7, 6] = 8651533662697 // 9642993110008 + + a_exp[3, 1] = 1 // 9 + a_exp[3, 2] = 1183333538310 // 1827251437969 + a_exp[4, 1] = 895379019517 // 9750411845327 + a_exp[4, 2] = 477606656805 // 13473228687314 + a_exp[4, 3] = -112564739183 // 9373365219272 + a_exp[5, 1] = -4458043123994 // 13015289567637 + a_exp[5, 2] = -2500665203865 // 9342069639922 + a_exp[5, 3] = 983347055801 // 8893519644487 + a_exp[5, 4] = 2185051477207 // 2551468980502 + a_exp[6, 1] = -167316361917 // 17121522574472 + a_exp[6, 2] = 1605541814917 // 7619724128744 + a_exp[6, 3] = 991021770328 // 13052792161721 + a_exp[6, 4] = 2342280609577 // 11279663441611 + a_exp[6, 5] = 3012424348531 // 12792462456678 + a_exp[7, 1] = 6680998715867 // 14310383562358 + a_exp[7, 2] = 5029118570809 // 3897454228471 + a_exp[7, 3] = 2415062538259 // 6382199904604 + a_exp[7, 4] = -3924368632305 // 6964820224454 + a_exp[7, 5] = -4331110370267 // 15021686902756 + a_exp[7, 6] = -3944303808049 // 11994238218192 + a_exp[8, 1] = 2193717860234 // 3570523412979 + a_exp[8, 2] = a_exp[8, 1] + a_exp[8, 3] = 5952760925747 // 18750164281544 + a_exp[8, 4] = -4412967128996 // 6196664114337 + a_exp[8, 5] = 4151782504231 // 36106512998704 + a_exp[8, 6] = 572599549169 // 6265429158920 + a_exp[8, 7] = -457874356192 // 11306498036315 + + b[2] = 0 + b[3] = 3517720773327 // 20256071687669 + b[4] = 4569610470461 // 17934693873752 + b[5] = 2819471173109 // 11655438449929 + b[6] = 3296210113763 // 10722700128969 + b[7] = -1142099968913 // 5710983926999 + + c[2] = 4 // 9 + c[3] = 6456083330201 // 8509243623797 + c[4] = 1632083962415 // 14158861528103 + c[5] = 6365430648612 // 17842476412687 + c[6] = 18 // 25 + c[7] = 191 // 200 + + for i in 2:8 + a_imp[i, 1] = a_imp[i, 2] + end + b[1] = b[2] + b[8] = γ + for i in 1:8 + a_imp[8, i] = b[i] + end + a_exp[2, 1] = c[2] + a_exp[8, 1] = a_exp[8, 2] + c[1] = 0 + c[8] = 1 + + IMEXTableau(; + a_exp = SArray{Tuple{8, 8}}(a_exp), + b_exp = SArray{Tuple{8}}(b), + c_exp = SArray{Tuple{8}}(c), + a_imp = SArray{Tuple{8, 8}}(a_imp), + b_imp = SArray{Tuple{8}}(b), + c_imp = SArray{Tuple{8}}(c), ) end diff --git a/src/solvers/mis.jl b/src/solvers/mis.jl index 77ce85c4..5f44e1da 100644 --- a/src/solvers/mis.jl +++ b/src/solvers/mis.jl @@ -63,12 +63,12 @@ function init_cache( tab = tableau(alg, eltype(prob.u0)) N = length(tab.d) - ΔU = ntuple(n -> similar(prob.u0), N) + ΔU = ntuple(n -> zero(prob.u0), N) # at time i, contains # ΔU[j] = U[j] - u j < i # ΔU[i] = U[i] # ΔU[N] = offset vec - F = ntuple(n -> similar(prob.u0), N) + F = ntuple(n -> zero(prob.u0), N) return MultirateInfinitesimalStepCache(ΔU, F, tab) end @@ -95,21 +95,37 @@ function update_inner!(innerinteg, outercache::MultirateInfinitesimalStepCache, innerinteg.u = i == N ? u : ΔU[i] groupsize = 256 - event = Event(array_device(u)) - event = mis_update!(array_device(u), groupsize)( - u, - ΔU, - F, - innerinteg.u, - f_offset.x, - outercache.tableau, # TODO: verify correctness - i, - N, - dt; - ndrange = length(u), - dependencies = (event,), - ) - wait(array_device(u), event) + if isdefined(KernelAbstractions, :Event) + event = Event(array_device(u)) + event = mis_update!(array_device(u), groupsize)( + u, + ΔU, + F, + innerinteg.u, + f_offset.x, + outercache.tableau, # TODO: verify correctness + i, + N, + dt; + ndrange = length(u), + dependencies = (event,), + ) + wait(array_device(u), event) + else + mis_update!(array_device(u), groupsize)( + u, + ΔU, + F, + innerinteg.u, + f_offset.x, + outercache.tableau, # TODO: verify correctness + i, + N, + dt; + ndrange = length(u), + ) + KernelAbstractions.synchronize(array_device(u)) + end # KW2014 (9) # evaluate f_fast(z(τ), p, t + c̃[i]*dt + (c[i]-c̃[i])/d[i] * τ) diff --git a/src/solvers/multirate.jl b/src/solvers/multirate.jl index eaef6256..a06d03ec 100644 --- a/src/solvers/multirate.jl +++ b/src/solvers/multirate.jl @@ -24,16 +24,32 @@ struct MultirateCache{OC, II} innerinteg::II end +""" + cts_remake(prob::DiffEqBase.AbstractODEProblem; f::DiffEqBase.AbstractODEFunction) + +Remake an ODE problem with a new function `f`. +""" +function cts_remake(prob::DiffEqBase.AbstractODEProblem; f::DiffEqBase.AbstractODEFunction) + return DiffEqBase.ODEProblem{DiffEqBase.isinplace(prob)}( + f, + prob.u0, + prob.tspan, + prob.p, + prob.problem_type; + prob.kwargs..., + ) +end + function init_cache(prob::DiffEqBase.AbstractODEProblem, alg::Multirate; dt, fast_dt, kwargs...) @assert prob.f isa DiffEqBase.SplitFunction # subproblems - outerprob = DiffEqBase.remake(prob; f = prob.f.f2) + outerprob = cts_remake(prob; f = prob.f.f1) outercache = init_cache(outerprob, alg.slow) innerfun = init_inner(prob, outercache, dt) - innerprob = DiffEqBase.remake(prob; f = innerfun) + innerprob = cts_remake(prob; f = innerfun) innerinteg = DiffEqBase.init(innerprob, alg.fast; dt = fast_dt, kwargs...) return MultirateCache(outercache, innerinteg) end diff --git a/src/solvers/rosenbrock.jl b/src/solvers/rosenbrock.jl index ca871dbb..e99cc5cc 100644 --- a/src/solvers/rosenbrock.jl +++ b/src/solvers/rosenbrock.jl @@ -1,93 +1,236 @@ export SSPKnoth +using StaticArrays +import DiffEqBase +import LinearAlgebra: ldiv!, diagm +import LinearAlgebra -abstract type RosenbrockAlgorithm <: DistributedODEAlgorithm end +abstract type RosenbrockAlgorithmName <: AbstractAlgorithmName end -struct RosenbrockTableau{N, RT, N²} - A::SMatrix{N, N, RT, N²} - C::SMatrix{N, N, RT, N²} - Γ::SMatrix{N, N, RT, N²} - m::SMatrix{N, 1, RT, N} +""" + RosenbrockTableau{N, RT, N²} + +Contains everything that defines a Rosenbrock-type method. + +- N: number of stages. + +Refer to the documentation for the precise meaning of the symbols below. +""" +struct RosenbrockTableau{N, SM <: SMatrix{N, N}, SM1 <: SMatrix{N, 1}} + """A = α Γ⁻¹""" + A::SM + """Tableau used for the time-dependent part""" + α::SM + """Stepping matrix. C = 1/diag(Γ) - Γ⁻¹""" + C::SM + """Substage contribution matrix""" + Γ::SM + """m = b Γ⁻¹, used to compute the increments k""" + m::SM1 +end +n_stages(::RosenbrockTableau{N}) where {N} = N + +function RosenbrockTableau(α::SMatrix{N, N}, Γ::SMatrix{N, N}, b::SMatrix{1, N}) where {N} + A = α / Γ + invΓ = inv(Γ) + diag_invΓ = SMatrix{N, N}(diagm([invΓ[i, i] for i in 1:N])) + # C is diag(γ₁₁⁻¹, γ₂₂⁻¹, ...) - Γ⁻¹ + C = diag_invΓ .- inv(Γ) + m = b / Γ + m′ = convert(SMatrix{N, 1}, m) # Sometimes m is a SMatrix{1, N} matrix. + SM = typeof(A) + SM1 = typeof(m′) + return RosenbrockTableau{N, SM, SM1}(A, α, C, Γ, m′) end -struct RosenbrockCache{Nstages, RT, N², A} - tableau::RosenbrockTableau{Nstages, RT, N²} +""" + RosenbrockAlgorithm(tableau) + +Constructs a Rosenbrock algorithm for solving ODEs. +""" +struct RosenbrockAlgorithm{T <: RosenbrockTableau} <: ClimaTimeSteppers.DistributedODEAlgorithm + tableau::T +end + +""" + RosenbrockCache{N, A, WT} + +Contains everything that is needed to run a Rosenbrock-type method. + +- Nstages: number of stages, +- A: type of the evolved state (e.g., a ClimaCore.FieldVector), +- WT: type of the Jacobian (Wfact) +""" +struct RosenbrockCache{Nstages, A, WT} + """Preallocated space for the state""" U::A + + """Preallocated space for the tendency""" fU::A + + """Preallocated space for the implicit contribution to the tendency""" + fU_imp::A + + """Preallocated space for the explicit contribution to the tendency""" + fU_exp::A + + """Preallocated space for the limited contribution to the tendency""" + fU_lim::A + + """Contributions to the state for each stage""" k::NTuple{Nstages, A} - W::Any - linsolve!::Any + + """Preallocated space for the Wfact, dtγJ - 𝕀, or Wfact_t, 𝕀/dtγ - J, with J the Jacobian of the implicit tendency""" + W::WT + + """Preallocated space for the explicit time derivative of the tendency""" + ∂Y∂t::A end function init_cache(prob::DiffEqBase.AbstractODEProblem, alg::RosenbrockAlgorithm; kwargs...) - - tab = tableau(alg, eltype(prob.u0)) - Nstages = length(tab.m) + Nstages = length(alg.tableau.m) U = zero(prob.u0) fU = zero(prob.u0) - k = ntuple(n -> similar(prob.u0), Nstages) - W = prob.f.jac_prototype - linsolve! = alg.linsolve(Val{:init}, W, prob.u0; kwargs...) - - return RosenbrockCache(tab, U, fU, k, W, linsolve!) + fU_imp = zero(prob.u0) + fU_exp = zero(prob.u0) + fU_lim = zero(prob.u0) + ∂Y∂t = zero(prob.u0) + k = ntuple(n -> zero(prob.u0), Nstages) + if !isnothing(prob.f.T_imp!) + W = prob.f.T_imp!.jac_prototype + else + W = nothing + end + return RosenbrockCache{Nstages, typeof(U), typeof(W)}(U, fU, fU_imp, fU_exp, fU_lim, k, W, ∂Y∂t) end +""" + step_u!(int, cache::RosenbrockCache{Nstages}) + +Take one step with the Rosenbrock-method with the given `cache`. -function step_u!(int, cache::RosenbrockCache{Nstages, RT}) where {Nstages, RT} - (; m, Γ, A, C) = cache.tableau +Some choices are being made here. Most of these are empirically motivated and should be +revisited on different problems. +- We do not update dtγ across stages +- We apply DSS to the sum of the explicit and implicit tendency at all the stages but the last +- We apply DSS to incremented state (ie, after the final stage is applied) +""" +function step_u!(int, cache::RosenbrockCache{Nstages}) where {Nstages} + (; m, Γ, A, α, C) = int.alg.tableau (; u, p, t, dt) = int - (; W, U, fU, k, linsolve!) = cache - f! = int.sol.prob.f - Wfact_t! = int.sol.prob.f.Wfact_t - - # 1) compute jacobian factorization - γ = dt * Γ[1, 1] - Wfact_t!(W, u, p, γ, t) - for i in 1:Nstages + (; W, U, fU, fU_imp, fU_exp, fU_lim, k, ∂Y∂t) = cache + T_imp! = int.sol.prob.f.T_imp! + T_exp! = int.sol.prob.f.T_exp! + T_exp_lim! = int.sol.prob.f.T_exp_T_lim! + tgrad! = isnothing(T_imp!) ? nothing : T_imp!.tgrad + + (; post_explicit!, post_implicit!, dss!) = int.sol.prob.f + + # TODO: This is only valid when Γ[i, i] is constant, otherwise we have to + # move this in the for loop + @inbounds dtγ = dt * Γ[1, 1] + + if !isnothing(T_imp!) + Wfact! = int.sol.prob.f.T_imp!.Wfact + Wfact!(W, u, p, dtγ, t) + end + + if !isnothing(tgrad!) + tgrad!(∂Y∂t, u, p, t) + end + + @inbounds for i in 1:Nstages + # Reset tendency + fill!(fU, 0) + + αi = sum(α[i, 1:(i - 1)]; init = zero(eltype(α))) + γi = sum(Γ[i, 1:i]; init = zero(eltype(Γ))) + U .= u for j in 1:(i - 1) U .+= A[i, j] .* k[j] end - # TODO: there should be a time modification here (t + c * dt) - # if f does depend on time, would need to add tgrad term as well - f!(fU, U, p, t) + + # NOTE: post_implicit! is a misnomer + if !isnothing(post_implicit!) + # We apply DSS and update p on every stage but the first, and at the + # end of each timestep. Since the first stage is unchanged from the + # end of the previous timestep, this order of operations ensures + # that the state is always continuous and that p is consistent with + # the state, including between timesteps. + (i != 1) && dss!(U, p, t + αi * dt) + (i != 1) && post_implicit!(U, p, t + αi * dt) + end + + if !isnothing(T_imp!) + T_imp!(fU_imp, U, p, t + αi * dt) + fU .+= fU_imp + end + + if !isnothing(T_exp!) + T_exp!(fU_exp, U, p, t + αi * dt) + fU .+= fU_exp + end + + if !isnothing(T_exp_lim!) + T_exp_lim!(fU_exp, fU_lim, U, p, t + αi * dt) + fU .+= fU_exp + fU .+= fU_lim + end + + if !isnothing(tgrad!) + fU .+= γi .* dt .* ∂Y∂t + end + for j in 1:(i - 1) fU .+= (C[i, j] / dt) .* k[j] end - linsolve!(k[i], W, fU) + + fU .*= -dtγ + + if !isnothing(T_imp!) + if W isa Matrix + ldiv!(k[i], lu(W), fU) + else + ldiv!(k[i], W, fU) + end + else + k[i] .= .-fU + end end - for i in 1:Nstages + + @inbounds for i in 1:Nstages u .+= m[i] .* k[i] end -end -struct SSPKnoth{L} <: RosenbrockAlgorithm - linsolve::L + dss!(u, p, t + dt) + post_implicit!(u, p, t + dt) + return nothing end -SSPKnoth(; linsolve) = SSPKnoth(linsolve) +""" + SSPKnoth + +`SSPKnoth` is a second-order Rosenbrock method developed by Oswald Knoth. + +The coefficients are the same as in `CGDycore.jl`, except that for C we add the +diagonal elements too. Note, however, that the elements on the diagonal of C do +not really matter because C is only used in its lower triangular part. We add them +mostly to match literature on the subject +""" +struct SSPKnoth <: RosenbrockAlgorithmName end -function tableau(::SSPKnoth, RT) - # ROS.transformed=true; +function tableau(::SSPKnoth) N = 3 - N² = N * N - α = @SMatrix RT[ + α = @SMatrix [ 0 0 0 1 0 0 1/4 1/4 0 ] - # ROS.d=ROS.alpha*ones(ROS.nStage,1); - b = @SMatrix RT[1 / 6 1 / 6 2 / 3] - Γ = @SMatrix RT[ + b = @SMatrix [1 / 6 1 / 6 2 / 3] + Γ = @SMatrix [ 1 0 0 0 1 0 -3/4 -3/4 1 ] - A = α / Γ - C = -inv(Γ) - m = b / Γ - return RosenbrockTableau{N, RT, N²}(A, C, Γ, m) - # ROS.SSP.alpha=[1 0 0 - # 3/4 1/4 0 - # 1/3 0 2/3]; - + return RosenbrockTableau(α, Γ, b) end diff --git a/src/solvers/wickerskamarock.jl b/src/solvers/wickerskamarock.jl index b1b8568a..efc1616b 100644 --- a/src/solvers/wickerskamarock.jl +++ b/src/solvers/wickerskamarock.jl @@ -26,8 +26,8 @@ struct WickerSkamarockRungeKuttaCache{T <: WickerSkamarockRungeKuttaTableau, A} F::A end function init_cache(prob::DiffEqBase.ODEProblem, alg::WickerSkamarockRungeKutta; kwargs...) - U = similar(prob.u0) - F = similar(prob.u0) + U = zero(prob.u0) + F = zero(prob.u0) return WickerSkamarockRungeKuttaCache(tableau(alg, eltype(F)), U, F) end diff --git a/src/utilities/convergence_checker.jl b/src/utilities/convergence_checker.jl index cb49c286..15cee18e 100644 --- a/src/utilities/convergence_checker.jl +++ b/src/utilities/convergence_checker.jl @@ -12,7 +12,7 @@ using LinearAlgebra: norm Checks whether a sequence `val[0], val[1], val[2], ...` has converged to some limit `L`, given the errors `err[iter] = val[iter] .- L`. This is done by -calling `check_convergence!(::ConvergenceChecker, cache, val, err, iter)`, where +calling `is_converged!(::ConvergenceChecker, cache, val, err, iter)`, where `val = val[iter]` and `err = err[iter]`. If the value of `L` is not known, `err` can be an approximation of `err[iter]`. The `cache` for a `ConvergenceChecker` can be obtained with `allocate_cache(::ConvergenceChecker, val_prototype)`, @@ -68,7 +68,9 @@ function has_component_converged(alg, cache, val, err, iter) return all(component_bools) end -function check_convergence!(alg::ConvergenceChecker, cache, val, err, iter) +is_converged!(alg::Nothing, cache, val, err, iter) = false + +function is_converged!(alg::ConvergenceChecker, cache, val, err, iter) (; norm_condition, component_condition, condition_combiner, norm) = alg (; norm_cache, component_cache) = cache if isnothing(norm_condition) diff --git a/src/utilities/convergence_condition.jl b/src/utilities/convergence_condition.jl index 46ed9ff1..c9d6ed6a 100644 --- a/src/utilities/convergence_condition.jl +++ b/src/utilities/convergence_condition.jl @@ -92,13 +92,18 @@ updated_cache((; rate, order)::MinimumRateOfConvergence, cache, val, err, iter) Checks multiple `ConvergenceCondition`s, combining their results with either `all` or `any`. """ -struct MultipleConditions{CC <: Union{typeof(all), typeof(any)}, C <: Tuple{Vararg{<:ConvergenceCondition}}} <: - ConvergenceCondition +struct MultipleConditions{CC, C} <: ConvergenceCondition condition_combiner::CC conditions::C + function MultipleConditions( + condition_combiner::Union{typeof(all), typeof(any)}, + conditions::Vararg{ConvergenceCondition}, + ) + return new{typeof(condition_combiner), typeof(conditions)}(condition_combiner, conditions) + end end -MultipleConditions(condition_combiner::Union{typeof(all), typeof(any)} = all, conditions::ConvergenceCondition...) = - MultipleConditions(condition_combiner, conditions) + +MultipleConditions(conditions::ConvergenceCondition...) = MultipleConditions(all, conditions) cache_type((; conditions)::MultipleConditions, ::Type{FT}) where {FT} = Tuple{map(condition -> cache_type(condition, FT), conditions)...} has_converged(alg::MultipleConditions, caches, val, err, iter) = alg.condition_combiner( diff --git a/src/utilities/update_signal_handler.jl b/src/utilities/update_signal_handler.jl index e1ad4304..63d02388 100644 --- a/src/utilities/update_signal_handler.jl +++ b/src/utilities/update_signal_handler.jl @@ -4,24 +4,21 @@ export UpdateSignalHandler, UpdateEvery, UpdateEveryN, UpdateEveryDt """ UpdateSignal -A signal that gets passed to an `UpdateSignalHandler` whenever a certain -operation is performed. +A signal that gets passed to an `UpdateSignalHandler` +whenever a certain operation is performed. """ abstract type UpdateSignal end """ UpdateSignalHandler -A boolean indicating if updates a value upon receiving an appropriate -`UpdateSignal`. This is done by calling -`needs_update!(::UpdateSignalHandler, cache, ::UpdateSignal)`. - -The `cache` can be obtained with `allocate_cache(::UpdateSignalHandler, FT)`, -where `FT` is the floating-point type of the integrator. +A boolean indicating if updates a value upon receiving +an appropriate `UpdateSignal`. This is done by calling +`needs_update!(::UpdateSignalHandler, ::UpdateSignal)`. """ abstract type UpdateSignalHandler end -needs_update!(::UpdateSignalHandler, cache, ::UpdateSignal) = false +needs_update!(::UpdateSignalHandler, ::UpdateSignal) = false """ NewTimeStep(t) @@ -56,9 +53,7 @@ An `UpdateSignalHandler` that performs the update whenever it is `needs_update!` struct UpdateEvery{U <: UpdateSignal} <: UpdateSignalHandler end UpdateEvery(::Type{U}) where {U} = UpdateEvery{U}() -allocate_cache(::UpdateEvery, _) = nothing - -needs_update!(alg::UpdateEvery{U}, cache, ::U) where {U <: UpdateSignal} = true +needs_update!(alg::UpdateEvery{U}, ::U) where {U <: UpdateSignal} = true """ UpdateEveryN(n, update_signal_type, reset_signal_type = Nothing) @@ -69,16 +64,14 @@ specified, then the counter (which gets incremented from 0 to `n` and then gets reset to 0 when it is time to perform another update) is reset to 0 whenever the signal handler is `needs_update!` with an `UpdateSignal` of type `reset_signal_type`. """ -struct UpdateEveryN{U <: UpdateSignal, R <: Union{Nothing, UpdateSignal}} <: UpdateSignalHandler +struct UpdateEveryN{U <: UpdateSignal, C, R <: Union{Nothing, UpdateSignal}} <: UpdateSignalHandler n::Int + counter::C end -UpdateEveryN(n, ::Type{U}, ::Type{R} = Nothing) where {U, R} = UpdateEveryN{U, R}(n) +UpdateEveryN(n, ::Type{U}, ::Type{R} = Nothing) where {U, R} = UpdateEveryN{U, typeof(Ref(0)), R}(n, Ref(0)) -allocate_cache(::UpdateEveryN, _) = (; counter = Ref(0)) - -function needs_update!(alg::UpdateEveryN{U}, cache, ::U) where {U <: UpdateSignal} - (; n) = alg - (; counter) = cache +function needs_update!(alg::UpdateEveryN{U}, ::U) where {U <: UpdateSignal} + (; n, counter) = alg result = counter[] == 0 counter[] += 1 if counter[] == n @@ -86,14 +79,14 @@ function needs_update!(alg::UpdateEveryN{U}, cache, ::U) where {U <: UpdateSigna end return result end -function needs_update!(alg::UpdateEveryN{U, R}, cache, ::R) where {U, R <: UpdateSignal} - (; counter) = cache +function needs_update!(alg::UpdateEveryN{U, R}, ::R) where {U, R <: UpdateSignal} + (; counter) = alg counter[] = 0 return false end # Account for method ambiguitiy: -needs_update!(::UpdateEveryN{U, U}, cache, ::U) where {U <: UpdateSignal} = +needs_update!(::UpdateEveryN{U, U}, ::U) where {U <: UpdateSignal} = error("Reset and update signal types cannot be the same.") """ @@ -103,16 +96,15 @@ An `UpdateSignalHandler` that performs the update whenever it is `needs_update!` `UpdateSignal` of type `NewTimeStep` and the difference between the current time and the previous update time is no less than `dt`. """ -struct UpdateEveryDt{T} <: UpdateSignalHandler +struct UpdateEveryDt{T, BR, FTR} <: UpdateSignalHandler dt::T + is_first_t::BR + prev_update_t::FTR end +UpdateEveryDt(dt::Type{FT}) where {FT} = UpdateEveryDt(dt, Ref(true), Ref{FT}()) -# TODO: This assumes that typeof(t) == FT, which might not always be correct. -allocate_cache(alg::UpdateEveryDt, ::Type{FT}) where {FT} = (; is_first_t = Ref(true), prev_update_t = Ref{FT}()) - -function needs_update!(alg::UpdateEveryDt, cache, signal::NewTimeStep) - (; dt) = alg - (; is_first_t, prev_update_t) = cache +function needs_update!(alg::UpdateEveryDt, signal::NewTimeStep) + (; dt, is_first_t, prev_update_t) = alg (; t) = signal result = false if is_first_t[] || abs(t - prev_update_t[]) >= dt