Skip to content

Commit

Permalink
Use CTS src from main
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jan 22, 2025
1 parent 2d18ef2 commit 449f5ee
Show file tree
Hide file tree
Showing 17 changed files with 1,174 additions and 529 deletions.
51 changes: 46 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,28 +1,69 @@
name = "ClimaTimeSteppers"
uuid = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
authors = ["Climate Modeling Alliance"]
version = "0.7.0"
version = "0.7.39"

[deps]
ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125"
NVTX = "5da4648a-3479-48b8-97b9-01cb529c0a1f"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[weakdeps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"

[extensions]
ClimaTimeSteppersBenchmarkToolsExt = ["CUDA", "BenchmarkTools", "OrderedCollections", "StatsBase", "PrettyTables"]

[compat]
Aqua = "0.8"
BenchmarkTools = "1"
ClimaComms = "0.4, 0.5, 0.6"
ClimaCore = "0.10, 0.11, 0.12, 0.13, 0.14"
CUDA = "3, 4, 5"
Colors = "0.12, 0.13"
DataStructures = "0.18"
DiffEqBase = "6"
DiffEqCallbacks = "2"
Distributions = "0.25"
KernelAbstractions = "0.7, 0.8, 0.9"
Krylov = "0.8, 0.9"
LinearAlgebra = "1"
LinearOperators = "2"
SciMLBase = "1"
MPI = "0.20"
NVTX = "0.3"
ODEConvergenceTester = "0.2"
OrderedCollections = "1"
PrettyTables = "2"
Random = "1"
SafeTestsets = "0.1"
SciMLBase = "1, 2"
StaticArrays = "1"
julia = "1.8"
StatsBase = "0.33, 0.34"
Test = "1"
julia = "1.9"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ClimaCore = "d414da3d-4745-48bb-8d80-42e94e092884"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
ODEConvergenceTester = "42a5c2e1-f365-4540-8ca5-3684de3ecd95"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua","ClimaCore","Distributions","Krylov", "MPI","ODEConvergenceTester","PrettyTables","Random","SafeTestsets","Test"]
30 changes: 12 additions & 18 deletions docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,20 @@ version = "0.14.19"
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"

[[deps.ClimaTimeSteppers]]
deps = ["ClimaComms", "DataStructures", "DiffEqBase", "DiffEqCallbacks", "KernelAbstractions", "Krylov", "LinearAlgebra", "LinearOperators", "SciMLBase", "StaticArrays"]
deps = ["ClimaComms", "Colors", "DataStructures", "DiffEqBase", "KernelAbstractions", "Krylov", "LinearAlgebra", "LinearOperators", "NVTX", "SciMLBase", "StaticArrays"]
path = ".."
uuid = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
version = "0.7.0"
version = "0.7.39"

[deps.ClimaTimeSteppers.extensions]
ClimaTimeSteppersBenchmarkToolsExt = ["CUDA", "BenchmarkTools", "OrderedCollections", "StatsBase", "PrettyTables"]

[deps.ClimaTimeSteppers.weakdeps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[[deps.CloseOpenIntervals]]
deps = ["Static", "StaticArrayInterface"]
Expand Down Expand Up @@ -415,16 +425,6 @@ version = "6.130.0"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[[deps.DiffEqCallbacks]]
deps = ["DataStructures", "DiffEqBase", "ForwardDiff", "Functors", "LinearAlgebra", "Markdown", "NLsolve", "Parameters", "RecipesBase", "RecursiveArrayTools", "SciMLBase", "StaticArraysCore"]
git-tree-sha1 = "d0b94b3694d55e7eedeee918e7daee9e3b873399"
uuid = "459566f4-90b8-5000-8ac3-15dfb0a30def"
version = "2.35.0"

[deps.DiffEqCallbacks.weakdeps]
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"

[[deps.DiffResults]]
deps = ["StaticArraysCore"]
git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621"
Expand Down Expand Up @@ -651,12 +651,6 @@ git-tree-sha1 = "b104d487b34566608f8b4e1c39fb0b10aa279ff8"
uuid = "77dc65aa-8811-40c2-897b-53d922fa7daf"
version = "0.1.3"

[[deps.Functors]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "64d8e93700c7a3f28f717d265382d52fac9fa1c1"
uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
version = "0.4.12"

[[deps.Future]]
deps = ["Random"]
uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820"
Expand Down
39 changes: 38 additions & 1 deletion docs/src/dev/report_gen.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ClimaTimeSteppers
import ClimaTimeSteppers as CTS
using Test
using InteractiveUtils: subtypes

Expand All @@ -11,7 +12,43 @@ all_subtypes(::Type{T}) where {T} = isabstracttype(T) ? vcat(all_subtypes.(subty

@testset "IMEX Algorithm Convergence" begin
title = "IMEX Algorithms"
algorithm_names = map(T -> T(), all_subtypes(ClimaTimeSteppers.IMEXAlgorithmName))
# algorithm_names = map(T -> T(), all_subtypes(ClimaTimeSteppers.IMEXAlgorithmName))
algorithm_names = [
# CTS.SSP22Heuns(),
# CTS.SSP33ShuOsher(),
# CTS.RK4(),
# CTS.ARK2GKC(),
CTS.ARS111(),
CTS.ARS121(),
CTS.ARS122(),
CTS.ARS222(),
CTS.ARS232(),
CTS.ARS233(),
CTS.ARS343(),
CTS.ARS443(),
CTS.SSP222(),
CTS.SSP322(),
CTS.SSP332(),
CTS.SSP333(),
CTS.SSP433(),
CTS.DBM453(),
CTS.HOMMEM1(),
CTS.IMKG232a(),
CTS.IMKG232b(),
CTS.IMKG242a(),
CTS.IMKG242b(),
CTS.IMKG243a(),
CTS.IMKG252a(),
CTS.IMKG252b(),
CTS.IMKG253a(),
CTS.IMKG253b(),
CTS.IMKG254a(),
CTS.IMKG254b(),
CTS.IMKG254c(),
CTS.IMKG342a(),
CTS.IMKG343a(),
# CTS.SSPKnoth()
]
test_imex_algorithms(title, algorithm_names, ark_analytic_nonlin_test_cts(Float64), 200)
test_imex_algorithms(title, algorithm_names, ark_analytic_sys_test_cts(Float64), 400)
test_imex_algorithms(title, algorithm_names, ark_analytic_test_cts(Float64), 40000; super_convergence = (ARS121(),))
Expand Down
13 changes: 7 additions & 6 deletions src/Callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ A suite of callback functions to be used with the ClimaTimeSteppers.jl ODE solve
module Callbacks

import ClimaComms, DiffEqBase
import SciMLBase

"""
ClimaTimeSteppers.Callbacks.initialize!(f!::F, integrator)
Expand Down Expand Up @@ -72,9 +73,9 @@ function EveryXWallTimeSeconds(f!, Δwt, comm_ctx::ClimaComms.AbstractCommsConte
end

if isdefined(DiffEqBase, :finalize!)
DiffEqBase.DiscreteCallback(condition, f!; initialize = _initialize, finalize = _finalize)
SciMLBase.DiscreteCallback(condition, f!; initialize = _initialize, finalize = _finalize)
else
DiffEqBase.DiscreteCallback(condition, f!; initialize = _initialize)
SciMLBase.DiscreteCallback(condition, f!; initialize = _initialize)
end
end

Expand Down Expand Up @@ -115,9 +116,9 @@ function EveryXSimulationTime(f!, Δt; atinit = false)
end
end
if isdefined(DiffEqBase, :finalize!)
DiffEqBase.DiscreteCallback(condition, f!; initialize = _initialize, finalize = _finalize)
SciMLBase.DiscreteCallback(condition, f!; initialize = _initialize, finalize = _finalize)
else
DiffEqBase.DiscreteCallback(condition, f!; initialize = _initialize)
SciMLBase.DiscreteCallback(condition, f!; initialize = _initialize)
end
end

Expand Down Expand Up @@ -159,9 +160,9 @@ function EveryXSimulationSteps(f!, Δsteps; atinit = false)
end

if isdefined(DiffEqBase, :finalize!)
DiffEqBase.DiscreteCallback(condition, f!; initialize = _initialize, finalize = _finalize)
SciMLBase.DiscreteCallback(condition, f!; initialize = _initialize, finalize = _finalize)
else
DiffEqBase.DiscreteCallback(condition, f!; initialize = _initialize)
SciMLBase.DiscreteCallback(condition, f!; initialize = _initialize)
end
end

Expand Down
35 changes: 22 additions & 13 deletions src/ClimaTimeSteppers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ JuliaDiffEq terminology:
SplitODEProlem(fL, fR)
* `ODEProblem` from OrdinaryDiffEq.jl
* `ODEProblem` from SciMLBase.jl
- use `jac` option to `ODEFunction` for linear + full IMEX (https://docs.sciml.ai/latest/features/performance_overloads/#ode_explicit_jac-1)
* `SplitODEProblem` for linear + remainder IMEX
* `MultirateODEProblem` for true multirate
Expand All @@ -45,22 +45,24 @@ module ClimaTimeSteppers


using KernelAbstractions
using KernelAbstractions.Extras: @unroll
using LinearAlgebra
using LinearOperators
using StaticArrays
import ClimaComms
# using Colors
using NVTX

export AbstractAlgorithmName, AbstractAlgorithmConstraint, Unconstrained, SSPConstrained
export AbstractAlgorithmName, AbstractAlgorithmConstraint, Unconstrained, SSP

array_device(::Union{Array, SArray, MArray}) = CPU()
realview(x::Union{Array, SArray, MArray}) = x
realview(x::Array) = x
array_device(x) = CUDADevice() # assume CUDA

import DiffEqBase, SciMLBase, LinearAlgebra, DiffEqCallbacks, Krylov
import DiffEqBase, SciMLBase, LinearAlgebra, Krylov

include(joinpath("utilities", "sparse_coeffs.jl"))
include(joinpath("utilities", "fused_increment.jl"))
include("sparse_containers.jl")
include("functions.jl")
include("operators.jl")

abstract type DistributedODEAlgorithm <: DiffEqBase.AbstractODEAlgorithm end

Expand All @@ -69,7 +71,7 @@ abstract type AbstractAlgorithmName end
"""
AbstractAlgorithmConstraint
A mechanism for restricting which operations can be performed by an algorithm
A mechanism for constraining which operations can be performed by an algorithm
for solving ODEs.
For example, an unconstrained algorithm might compute a Runge-Kutta stage by
Expand All @@ -87,15 +89,17 @@ Indicates that an algorithm may perform any supported operations.
"""
struct Unconstrained <: AbstractAlgorithmConstraint end

default_constraint(::AbstractAlgorithmName) = Unconstrained()

"""
SSPConstrained
SSP
Indicates that an algorithm must be "strong stability preserving", which makes
it easier to guarantee that the algorithm will preserve monotonicity properties
satisfied by the initial state. For example, this ensures that the algorithm
will be able to use limiters in a mathematically consistent way.
"""
struct SSPConstrained <: AbstractAlgorithmConstraint end
struct SSP <: AbstractAlgorithmConstraint end

SciMLBase.allowscomplex(alg::DistributedODEAlgorithm) = true
include("integrators.jl")
Expand All @@ -110,17 +114,22 @@ n_stages_ntuple(::Type{<:NTuple{Nstages}}) where {Nstages} = Nstages
n_stages_ntuple(::Type{<:SVector{Nstages}}) where {Nstages} = Nstages

# Include concrete implementations
const SPCO = SparseCoeffs

include("solvers/imex_tableaus.jl")
include("solvers/explicit_tableaus.jl")
include("solvers/imex_ark.jl")
include("solvers/imex_ssp.jl")
include("solvers/imex_ssprk.jl")
include("solvers/multirate.jl")
include("solvers/lsrk.jl")
include("solvers/ssprk.jl")
include("solvers/ark.jl")
include("solvers/mis.jl")
include("solvers/wickerskamarock.jl")
include("solvers/rosenbrock.jl")

include("Callbacks.jl")


benchmark_step(integrator, device) =
@warn "Must load CUDA, BenchmarkTools, OrderedCollections, StatsBase, PrettyTables to trigger the ClimaTimeSteppersBenchmarkToolsExt extension"

end
49 changes: 39 additions & 10 deletions src/functions.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,47 @@
import DiffEqBase
export AbstractClimaODEFunction
export ClimaODEFunction, ForwardEulerODEFunction

Base.@kwdef struct ClimaODEFunction{TL, TE, TI, L, D, S} <: DiffEqBase.AbstractODEFunction{true}
T_lim!::TL = nothing # nothing or (uₜ, u, p, t) -> ...
T_exp!::TE = nothing # nothing or (uₜ, u, p, t) -> ...
T_imp!::TI = nothing # nothing or (uₜ, u, p, t) -> ...
lim!::L = (u, p, t, u_ref) -> nothing
dss!::D = (u, p, t) -> nothing
stage_callback!::S = (u, p, t) -> nothing
abstract type AbstractClimaODEFunction <: DiffEqBase.AbstractODEFunction{true} end

struct ClimaODEFunction{TEL, TL, TE, TI, L, D, PE, PI} <: AbstractClimaODEFunction
T_exp_T_lim!::TEL
T_lim!::TL
T_exp!::TE
T_imp!::TI
lim!::L
dss!::D
post_explicit!::PE
post_implicit!::PI
function ClimaODEFunction(;
T_exp_T_lim! = nothing, # nothing or (uₜ_exp, uₜ_lim, u, p, t) -> ...
T_lim! = nothing, # nothing or (uₜ, u, p, t) -> ...
T_exp! = nothing, # nothing or (uₜ, u, p, t) -> ...
T_imp! = nothing, # nothing or (uₜ, u, p, t) -> ...
lim! = (u, p, t, u_ref) -> nothing,
dss! = (u, p, t) -> nothing,
post_explicit! = (u, p, t) -> nothing,
post_implicit! = (u, p, t) -> nothing,
)
args = (T_exp_T_lim!, T_lim!, T_exp!, T_imp!, lim!, dss!, post_explicit!, post_implicit!)

if !isnothing(T_exp_T_lim!)
@assert isnothing(T_exp!) "`T_exp_T_lim!` was passed, `T_exp!` must be `nothing`"
@assert isnothing(T_lim!) "`T_exp_T_lim!` was passed, `T_lim!` must be `nothing`"
end
if !isnothing(T_exp!) && !isnothing(T_lim!)
@warn "Both `T_exp!` and `T_lim!` are not `nothing`, please use `T_exp_T_lim!` instead."
end
return new{typeof.(args)...}(args...)
end
end

# Don't wrap a ClimaODEFunction in an ODEFunction (makes ODEProblem work).
DiffEqBase.ODEFunction{iip}(f::ClimaODEFunction) where {iip} = f
DiffEqBase.ODEFunction(f::ClimaODEFunction) = f
has_T_exp(f::ClimaODEFunction) = !isnothing(f.T_exp!) || !isnothing(f.T_exp_T_lim!)
has_T_lim(f::ClimaODEFunction) = !isnothing(f.lim!) && (!isnothing(f.T_lim!) || !isnothing(f.T_exp_T_lim!))

# Don't wrap a AbstractClimaODEFunction in an ODEFunction (makes ODEProblem work).
DiffEqBase.ODEFunction{iip}(f::AbstractClimaODEFunction) where {iip} = f
DiffEqBase.ODEFunction(f::AbstractClimaODEFunction) = f

"""
ForwardEulerODEFunction(f; jac_prototype, Wfact, tgrad)
Expand Down
Loading

0 comments on commit 449f5ee

Please sign in to comment.