Skip to content

Commit

Permalink
Merge pull request #63 from PALEOtoolkit/solver_recompile_fixes
Browse files Browse the repository at this point in the history
Fixes to avoid spurious recompilation when using steady-state solvers
  • Loading branch information
sjdaines authored Aug 6, 2023
2 parents 1495c04 + 06cee06 commit 0a12dfa
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 27 deletions.
10 changes: 10 additions & 0 deletions docs/src/PALEOmodelSolvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,16 @@ steadystate_ptc
steadystate_ptc_splitdae
```

Function objects to project Newton steps into valid regions:

```@meta
CurrentModule = PALEOmodel.SolverFunctions
```
```@docs
ClampAll!
ClampAll
```

## Steady-state solvers (Sundials Kinsol based):
```@meta
CurrentModule = PALEOmodel.SteadyStateKinsol
Expand Down
4 changes: 2 additions & 2 deletions src/NonLinearNewton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ function solve(
maxiters::Integer=100,
verbose::Integer=0,
jac_constant::Bool=false,
project_region = x->x,
project_region = identity,
) where {F, J}

u = copy(u0)
residual = func(u0)
residual = func(u)
Lnorm_2 = LinearAlgebra.norm(residual, 2)
Lnorm_inf = LinearAlgebra.norm(residual, Inf)
iters = 0
Expand Down
28 changes: 28 additions & 0 deletions src/SolverFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,34 @@ import SparseDiffTools

# import Infiltrator # Julia debugger

"""
ca! = ClampAll!(minvalue, maxvalue)
ca!(v)
Function object to clamp all values in Vector `v` to specified range using
`clamp!(v, minvalue, maxvalue)` (in-place, mutating version)
"""
struct ClampAll!
minvalue::Float64
maxvalue::Float64
end

(ca::ClampAll!)(v) = clamp!(v, ca.minvalue, ca.maxvalue)

"""
ca = ClampAll(minvalue, maxvalue)
ca(v) -> v
Function object to clamp all values in Vector `v` to specified range using
`clamp.(v, minvalue, maxvalue)` (out-of-place version)
"""
struct ClampAll
minvalue::Float64
maxvalue::Float64
end

(ca::ClampAll)(v) = clamp.(v, ca.minvalue, ca.maxvalue)

"""
ModelODE(
modeldata;
Expand Down
35 changes: 22 additions & 13 deletions src/SplitDAE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ import ..SparseUtils
operatorID_inner=0,
transfer_inner_vars=["tmid", "volume", "ntotal", "Abox"], # additional Variables needed by 'inner' Reactions
inner_jac_ad=:ForwardDiff: # form of automatic differentiation to use for Jacobian for inner solver (options `:ForwardDiff`, `:ForwardDiffSparse`)
inner_start_initial=true, # true to use initial value of inner variables as start value, false to use last solution as initial value
inner_kwargs=(verbose=0, miniters=2, reltol=1e-12, jac_constant=true, project_region=x->x),
inner_start=:initial, # :initial to use initial value of inner variables as start value, :current to use last solution as initial value, :zero to use 0.0
inner_kwargs=(verbose=0, miniters=2, reltol=1e-12, jac_constant=true, project_region=identity),
generated_dispatch=true,
) -> (ms::ModelSplitDAE, initial_state_outer, jac_outer_prototype)
Expand Down Expand Up @@ -62,8 +62,8 @@ function create_split_dae(
operatorID_inner=0,
transfer_inner_vars=["tmid", "volume", "ntotal", "Abox"], # additional Variables needed by 'inner' Reactions
inner_jac_ad=:ForwardDiff,
inner_start_initial=true,
inner_kwargs=(verbose=0, miniters=2, reltol=1e-12, jac_constant=true, project_region=x->x),
inner_start=:initial,
@nospecialize(inner_kwargs=(verbose=0, miniters=2, reltol=1e-12, jac_constant=true, project_region=identity)),
generated_dispatch=true,
)

Expand Down Expand Up @@ -328,7 +328,7 @@ function create_split_dae(
[c for c in cellconstraintsidxfull], # narrow_type
[c for c in cellinitialstates], # narrow_type
[c for c in celldGdoutercols], # narrow type
inner_start_initial,
inner_start,
inner_kwargs,
similar(initial_state_outer),
similar(initial_state),
Expand All @@ -353,7 +353,7 @@ Provides functions for outer derivative and Jacobian:
(ms::ModelSplitDAE)(du_outer::AbstractVector, J_outer::SparseArrays.AbstractSparseMatrixCSC, u_outer::AbstractVector, p, t)
"""
struct ModelSplitDAE{T, SVA, DLA, DLR, JF, VA1, VA2, VA3, CD, CJ, IK, LU}
struct ModelSplitDAE{T, SVA, DLA, DLR, JF, VA1, VA2, VA3, CD, CJ, LU}
modeldata::PB.ModelData
solver_view_all::SVA
dispatchlists_all::DLA
Expand All @@ -370,8 +370,8 @@ struct ModelSplitDAE{T, SVA, DLA, DLR, JF, VA1, VA2, VA3, CD, CJ, IK, LU}
cellconstraintsidxfull::Vector{Vector{Int64}}
cellinitialstates::Vector{Vector{Float64}}
celldGdoutercols::Vector{Vector{Int64}}
inner_start_initial::Bool
inner_kwargs::IK
inner_start::Symbol
inner_kwargs # no specialization to avoid recompilation
outer_worksp::Vector{T}
full_worksp::Vector{T}
dG_dcellinner_lu::LU
Expand All @@ -390,6 +390,8 @@ end
#
function (ms::ModelSplitDAE)(du_outer::AbstractVector, u_outer::AbstractVector, p, t)

verbose = ms.inner_kwargs.verbose

# set outer state Variables
# NB: inner state Variables are *not* set, so current values in modeldata arrays will be used
copyto!(ms.va_stateexplicit, u_outer)
Expand All @@ -406,16 +408,21 @@ function (ms::ModelSplitDAE)(du_outer::AbstractVector, u_outer::AbstractVector,
niter_inner_max = -1
for (ci, cd, cj, cs) in PB.IteratorUtils.zipstrict(ms.cellindex, ms.cellderivs, ms.jaccells, ms.cellinitialstates)

if ms.inner_start_initial
if ms.inner_start == :initial
# always start from initial state
copyto!(cd.worksp, cs)
else
elseif ms.inner_start == :current
# use current value (from previous iteration) as starting value
copyto!(cd.worksp, cd.state)
elseif ms.inner_start == :zero
# eg for linear problem
cd.worksp .= 0.0
else
error("ModelSplitDAE invalid inner_start = $inner_start")
end
initial_state = StaticArrays.SVector{ncomps(cd)}(cd.worksp)

ms.inner_kwargs.verbose > 1 && @info "cell index: $ci initial_state: $initial_state"
verbose > 1 && @info "cell index: $ci initial_state: $initial_state"
(_, Lnorm_2_cell, Lnorm_inf_cell, niter) = NonLinearNewton.solve(
cd,
cj,
Expand All @@ -425,7 +432,7 @@ function (ms::ModelSplitDAE)(du_outer::AbstractVector, u_outer::AbstractVector,
niter_inner_tot += niter
niter_inner_max = max(niter, niter_inner_max)
end
ms.inner_kwargs.verbose > 0 && @info " Inner iterations: max $niter_inner_max mean $(niter_inner_tot/length(ms.cellindex))"
verbose > 0 && @info " Inner iterations: max $niter_inner_max mean $(niter_inner_tot/length(ms.cellindex))"

# reevaluate full derivative with updated inner state variables
# use dispatchlists_recalc_deriv if available (an optimization, eg don't need to rerun radiative transfer)
Expand Down Expand Up @@ -558,7 +565,9 @@ end

function (mjfd::ModelJacForwardDiffCell)(x::StaticArrays.SVector)
# TODO ForwardDiff doesn't provide an API to get jacobian without setting Dual number 'tag'
return PALEOmodel.ForwardDiffWorkarounds.vector_mode_jacobian_notag(mjfd.modelderiv, x)
jac = PALEOmodel.ForwardDiffWorkarounds.vector_mode_jacobian_notag(mjfd.modelderiv, x)

return StaticArrays.lu(jac)
end

# calculate lu factorization of sparse Jacobian for a single cell
Expand Down
25 changes: 13 additions & 12 deletions src/SteadyState.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ function steadystate(
run, initial_state, modeldata, tss;
outputwriter=run.output,
initial_time=-1.0,
solvekwargs::NamedTuple=NamedTuple{}(),
@nospecialize(solvekwargs::NamedTuple=NamedTuple{}()),
jac_ad=:NoJacobian,
use_norm::Bool=false,
BLAS_num_threads=1,
Expand Down Expand Up @@ -188,7 +188,7 @@ function steadystate_ptc(
deltat_fac=2.0,
tss_output=Float64[],
outputwriter=run.output,
solvekwargs::NamedTuple=NamedTuple{}(),
@nospecialize(solvekwargs::NamedTuple=NamedTuple{}()),
max_iter=1000,
jac_ad=:NoJacobian,
request_adchunksize=10,
Expand Down Expand Up @@ -294,15 +294,16 @@ As [`steadystate_ptc`](@ref), with an inner Newton solve for per-cell algebraic
- `transfer_inner_vars=["tmid", "volume", "ntotal", "Abox"]`: Variables not calculated by `operatorID_inner` that need to be copied for
inner solve (additional to those with `transfer_jacobian` set).
- `inner_jac_ad::Symbol=:ForwardDiff`: form of automatic differentiation to use for Jacobian for inner `NonlinearNewton.solve` solver (options `:ForwardDiff`, `:ForwardDiffSparse`)
- `inner_kwargs::NamedTuple=(verbose=0, miniters=2, reltol=1e-12, jac_constant=true, project_region=x->x)`: keywords for inner
- `inner_start::Symbol=:current`: start value for inner solve (options `:initial`, `:current`, `:zero`)
- `inner_kwargs::NamedTuple=(verbose=0, miniters=2, reltol=1e-12, jac_constant=true, project_region=identity)`: keywords for inner
`NonlinearNewton.solve` solver.
"""
function steadystate_ptc_splitdae(
run, initial_state, modeldata, tspan, deltat_initial::Float64;
deltat_fac=2.0,
tss_output=Float64[],
outputwriter=run.output,
solvekwargs::NamedTuple=NamedTuple{}(),
@nospecialize(solvekwargs::NamedTuple=NamedTuple{}()),
max_iter=1000,
request_adchunksize=10,
jac_cellranges=modeldata.cellranges_all,
Expand All @@ -312,8 +313,8 @@ function steadystate_ptc_splitdae(
operatorID_inner=3,
transfer_inner_vars=["tmid", "volume", "ntotal", "Abox"],
inner_jac_ad=:ForwardDiff,
inner_start_initial=false,
inner_kwargs::NamedTuple=(verbose=0, miniters=2, reltol=1e-12, jac_constant=true, project_region=x->x),
inner_start=:current,
@nospecialize(inner_kwargs::NamedTuple=(verbose=0, miniters=2, reltol=1e-12, jac_constant=true, project_region=identity)),
BLAS_num_threads=1,
generated_dispatch=true,
)
Expand Down Expand Up @@ -341,7 +342,7 @@ function steadystate_ptc_splitdae(
transfer_inner_vars,
tss_jac_sparsity=tspan[1],
inner_jac_ad,
inner_start_initial,
inner_start,
inner_kwargs,
generated_dispatch,
)
Expand Down Expand Up @@ -376,7 +377,7 @@ function solve_ptc(
tss_output=Float64[],
max_iter=1000,
outputwriter=run.output,
solvekwargs::NamedTuple=NamedTuple{}(),
@nospecialize(solvekwargs::NamedTuple=NamedTuple{}()),
enforce_noneg=false,
verbose=false,
BLAS_num_threads=1
Expand Down Expand Up @@ -557,8 +558,8 @@ from state `previous_u`:
struct FJacPTC
modelode #::M no specialization to minimise recompilation
jacode #::J
t::Ref{Float64}
delta_t::Ref{Float64}
t::Base.RefValue{Float64}
delta_t::Base.RefValue{Float64}
previous_u::Vector{Float64}
du_worksp::Vector{Float64}
end
Expand Down Expand Up @@ -706,8 +707,8 @@ from state `previous_u`:
"""
struct FJacSplitPTC
modeljacode # ::M # no specialization as this seems to cause compiler issues with Julia 1.7.3
t::Ref{Float64}
delta_t::Ref{Float64}
t::Base.RefValue{Float64}
delta_t::Base.RefValue{Float64}
previous_u::Vector{Float64}
du_worksp::Vector{Float64}
end
Expand Down

0 comments on commit 0a12dfa

Please sign in to comment.