Skip to content

Commit

Permalink
Merge pull request #317 from CliMA/glw/update
Browse files Browse the repository at this point in the history
Update Oceananigans
  • Loading branch information
glwagner authored May 30, 2023
2 parents 47f84bd + 4afe922 commit a9dbf70
Show file tree
Hide file tree
Showing 14 changed files with 270 additions and 273 deletions.
425 changes: 217 additions & 208 deletions Manifest.toml

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"

[compat]
BlockDiagonals = "0.1.36"
CUDA = "3"
CUDA = "4"
DataDeps = "0.7"
Distributions = "0.25"
DocStringExtensions = "0.8, 0.9"
Expand All @@ -37,7 +37,7 @@ GaussianProcesses = "0.12"
JLD2 = "0.4"
LineSearches = "7"
MPI = "0.20"
Oceananigans = "0.79"
Oceananigans = "0.81"
OffsetArrays = "1"
OrderedCollections = "1"
ProgressBars = "1"
Expand Down
2 changes: 1 addition & 1 deletion examples/intro_to_observations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ fig = Figure()
ax_b = Axis(fig[1, 1], xlabel = "Buoyancy [m s⁻²]", ylabel = "z [m]")
ax_u = Axis(fig[1, 2], xlabel = "Velocities [m s⁻¹]", ylabel = "z [m]")

z = znodes(Center, observations.grid)
z = znodes(observations.grid, Center())

colorcycle = [:black, :red, :blue, :orange, :pink]

Expand Down
2 changes: 1 addition & 1 deletion examples/lesbrary_catke_calibration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ function make_figure_axes()
end

function plot_fields!(axs, b, u, v, e, label, color)
z = znodes(Center, b.grid)
z = znodes(b.grid, Center())
## Note unit conversions below, eg m s⁻² -> cm s⁻²:
lines!(axs[1], 1e2 * interior(b, 1, 1, :), z; color, label)
lines!(axs[2], 1e2 * interior(u, 1, 1, :), z; color, label)
Expand Down
2 changes: 1 addition & 1 deletion examples/multi_case_lesbrary_ri_based_calibration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ function make_figure_axes(n=1)
end

function plot_fields!(axs, b, u, v, label, color, grid=first(batch).grid)
z = znodes(Center, grid)
z = znodes(grid, Center())
## Note unit conversions below, eg m s⁻² -> cm s⁻²:
lines!(axs[1], 1e2 * b, z; color, label)
lines!(axs[2], 1e2 * u, z; color, label)
Expand Down
2 changes: 1 addition & 1 deletion examples/perfect_catke_calibration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ ax_b = Axis(fig[1, 1], xlabel = "Buoyancy\n[10⁻⁴ m s⁻²]", ylabel = "z [m]
ax_u = Axis(fig[1, 2], xlabel = "Velocities\n[cm s⁻¹]")
ax_e = Axis(fig[1, 3], xlabel = "Turbulent kinetic energy\n[cm² s⁻²]")

z = znodes(Center, observations.grid)
z = znodes(observations.grid, Center())

colorcycle = [:black, :red, :blue, :orange, :pink]

Expand Down
2 changes: 1 addition & 1 deletion examples/single_case_lesbrary_ri_based_calibration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ function make_figure_axes()
end

function plot_fields!(axs, b, u, v, label, color)
z = znodes(Center, b.grid)
z = znodes(b.grid, Center())
## Note unit conversions below, eg m s⁻² -> cm s⁻²:
lines!(axs[1], 1e2 * interior(b, 1, 1, :), z; color, label)
lines!(axs[2], 1e2 * interior(u, 1, 1, :), z; color, label)
Expand Down
15 changes: 9 additions & 6 deletions src/EnsembleKalmanInversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,6 @@ Keyword Arguments
- `tikhonov`: Whether to incorporate prior information in the EKI objective via Tikhonov regularization.
See Chada et al. "Tikhonov Regularization Within Ensemble Kalman Inversion." SIAM J. Numer. Anal. 2020.
"""
function EnsembleKalmanInversion(inverse_problem;
noise_covariance = 1,
Expand Down Expand Up @@ -523,10 +521,6 @@ function step_parameters(eki::EnsembleKalmanInversion, pseudo_stepping;
successes = findall(.!particle_failure)
some_failures = length(failures) > 0

some_failures && @warn string(length(failures), " particles failed. ",
"Performing ensemble update with statistics from ",
length(successes), " successful particles.")

successful_Gⁿ = Gⁿ[:, successes]
successful_Xⁿ = Xⁿ[:, successes]

Expand All @@ -546,6 +540,15 @@ function step_parameters(eki::EnsembleKalmanInversion, pseudo_stepping;
sampled_Xⁿ⁺¹ = rand(new_X_distribution, length(failures))
Xⁿ⁺¹[:, failures] .= sampled_Xⁿ⁺¹
end

msg = @sprintf("Particles stepped adaptively. Iteration: %d, pseudotime: %.3e, pseudostep: %.3e",
eki.iteration, eki.pseudotime, Δt)

if some_failures
msg *= string(" (", length(failures), " failed, ", length(successes), " successful particles)")
end

@info msg

return Xⁿ⁺¹, Δt
end
Expand Down
21 changes: 11 additions & 10 deletions src/InverseProblems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ closure_with_parameters(grid, closure, parameter_ensemble) = closure_with_parame
Initialize `ip.simulation` with `parameter_ensemble` and run it forward. Output is stored
in `ip.time_series_collector`. `forward_run` can also be called with one parameter set.
"""
function forward_run!(ip::InverseProblem, maybe_parameter_ensemble)
function forward_run!(ip::InverseProblem, maybe_parameter_ensemble=nothing)
# Ensure there are enough parameters for ensemble members in the simulation
parameter_ensemble = expand_parameter_ensemble(ip, maybe_parameter_ensemble)
_forward_run!(ip, parameter_ensemble, ip.simulation, ip.time_series_collector)
Expand Down Expand Up @@ -354,7 +354,7 @@ Base.length(batch::BatchedInverseProblem) = length(batch.batch)
Nensemble(batched_ip::BatchedInverseProblem) = Nensemble(first(batched_ip.batch))

function collect_forward_maps_asynchronously!(outputs, batched_ip, parameters; kw...)
asyncmap(1:length(batched_ip), ntasks=10) do n
for n = 1:length(batched_ip)
ip = batched_ip[n]
forward_map_output = forward_map(ip, parameters; kw...)
outputs[n] = batched_ip.weights[n] * forward_map_output
Expand All @@ -363,6 +363,13 @@ function collect_forward_maps_asynchronously!(outputs, batched_ip, parameters; k
return outputs
end

function forward_map(batched_ip::BatchedInverseProblem, parameters; kw...)
outputs = Dict()
collect_forward_maps_asynchronously!(outputs, batched_ip, parameters; kw...)
vectorized_outputs = [outputs[n] for n = 1:length(batched_ip)]
return vcat(vectorized_outputs...)
end

function forward_run_asynchronously!(batched_ip::BatchedInverseProblem, parameters; kw...)
asyncmap(1:length(batched_ip), ntasks=10) do n
ip = batched_ip[n]
Expand All @@ -371,16 +378,9 @@ function forward_run_asynchronously!(batched_ip::BatchedInverseProblem, paramete
return nothing
end

forward_run!(batched_ip::BatchedInverseProblem, parameters; kw...) =
forward_run!(batched_ip::BatchedInverseProblem, parameters=nothing; kw...) =
forward_run_asynchronously!(batched_ip, parameters; kw...)

function forward_map(batched_ip::BatchedInverseProblem, parameters; kw...)
outputs = Dict()
collect_forward_maps_asynchronously!(outputs, batched_ip, parameters; kw...)
vectorized_outputs = [outputs[n] for n = 1:length(batched_ip)]
return vcat(vectorized_outputs...)
end

function observation_map(batched_ip::BatchedInverseProblem)
maps = []

Expand Down Expand Up @@ -438,6 +438,7 @@ expand_parameter_ensemble(ip, θ::NamedTuple) = [θ]

# Convert matrix to vector of vectors
expand_parameter_ensemble(ip, θ::Matrix) = expand_parameter_ensemble(ip, [θ[:, k] for k = 1:size(θ, 2)])
expand_parameter_ensemble(ip, ::Nothing) = nothing

"""
observation_map(ip::InverseProblem)
Expand Down
10 changes: 5 additions & 5 deletions src/Observations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using Oceananigans
using Oceananigans.Fields
using Oceananigans.Architectures

using Oceananigans: fields
using Oceananigans: prognostic_fields
using Oceananigans.Grids: AbstractGrid
using Oceananigans.Grids: cpu_face_constructor_x, cpu_face_constructor_y, cpu_face_constructor_z
using Oceananigans.Grids: pop_flat_elements, topology, halo_size, on_architecture
Expand Down Expand Up @@ -306,8 +306,8 @@ function column_ensemble_interior(batch::BatchedSyntheticObservations,
end

function set!(model, obs::SyntheticObservations, time_index=1)
for field_name in keys(fields(model))
model_field = fields(model)[field_name]
for field_name in keys(prognostic_fields(model))
model_field = prognostic_fields(model)[field_name]

if field_name keys(obs.field_time_serieses)
obs_field = obs.field_time_serieses[field_name][time_index]
Expand All @@ -323,8 +323,8 @@ function set!(model, obs::SyntheticObservations, time_index=1)
end

function set!(model, observations::BatchedSyntheticObservations, time_index=1)
for field_name in keys(fields(model))
model_field = fields(model)[field_name]
for field_name in keys(prognostic_fields(model))
model_field = prognostic_fields(model)[field_name]
model_field_size = size(model_field)
Nensemble = model.grid.Nx

Expand Down
3 changes: 3 additions & 0 deletions src/Parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,7 @@ Closure(ClosureSubModel(12, 2), 3)
```
"""
closure_with_parameters(closure, parameters) = construct_object(dict_properties(closure), parameters)
closure_with_parameters(closure, ::Nothing) = nothing

closure_with_parameters(closure::AbstractTurbulenceClosure{ExplicitTimeDiscretization}, parameters) =
construct_object(dict_properties(closure), parameters, type_parameters=nothing)
Expand Down Expand Up @@ -601,6 +602,8 @@ end
new_closure_ensemble(closures::Tuple, parameter_ensemble, arch) =
Tuple(new_closure_ensemble(closure, parameter_ensemble, arch) for closure in closures)

# Don't change closure if parameters=nothing
new_closure_ensemble(closure::Union{Tuple, AbstractArray}, ::Nothing, arch) = closure
new_closure_ensemble(closure, parameter_ensemble, arch) = closure

end # module
27 changes: 3 additions & 24 deletions src/PseudoSteppingSchemes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ Implement an EKI update with a fixed time step given by `pseudo_scheme.step_size
function eki_update(pseudo_scheme::ConstantPseudoTimeStep, Xₙ, Gₙ, eki)
Δtₙ = pseudo_scheme.step_size
Xₙ₊₁ = iglesias_2013_update(Xₙ, Gₙ, eki; Δtₙ)
@info "Particles stepped with time step $Δtₙ"
return Xₙ₊₁, Δtₙ
end

Expand All @@ -200,12 +199,7 @@ function eki_update(pseudo_scheme::Kovachki2018, Xₙ, Gₙ, eki)
initial_step_size = pseudo_scheme.initial_step_size
Xₙ₊₁, Δtₙ = kovachki_2018_update(Xₙ, Gₙ, eki; Δt₀=initial_step_size)

intro = "Particles stepped adaptively with the Kovachki2018 pseudo-stepping scheme."
info1 = @sprintf(" ├─ iteration: %d", eki.iteration)
info2 = @sprintf(" ├─ pseudo time: %.3e", eki.pseudotime)
info3 = @sprintf(" └─ pseudo step: %.3e", Δtₙ)
@info string(intro, '\n', info1, '\n', info2, '\n', info3)


return Xₙ₊₁, Δtₙ
end

Expand Down Expand Up @@ -272,8 +266,6 @@ function eki_update(pseudo_scheme::Kovachki2018InitialConvergenceRatio, Xₙ, G

pseudo_scheme.initial_step_size = Δt₀

@info "Particles stepped adaptively with time step $Δtₙ and convergence ratio $r (target $target)."

return Xₙ₊₁, Δtₙ

else
Expand All @@ -294,8 +286,6 @@ function eki_update(pseudo_scheme::Chada2021, Xₙ, Gₙ, eki)
Δtₙ = ((n+1) ^ pseudo_scheme.β) * initial_step_size
Xₙ₊₁ = iglesias_2013_update(Xₙ, Gₙ, eki; Δtₙ)

@info "Particles stepped adaptively with time step $Δtₙ"

return Xₙ₊₁, Δtₙ
end

Expand Down Expand Up @@ -341,8 +331,6 @@ function eki_update(pseudo_scheme::ThresholdedConvergenceRatio, Xₙ, Gₙ, eki;

Xₙ₊₁ = iglesias_2013_update(Xₙ, Gₙ, eki; Δtₙ)

report && @info "Particles stepped adaptively with time step $Δtₙ"

return Xₙ₊₁, Δtₙ
end

Expand Down Expand Up @@ -393,9 +381,8 @@ function trained_gp_predict_function(X, y; standardize_X=true, zscore_limit=noth
if n_pruned > 0
percent_pruned = round((100n_pruned / length(y)); sigdigits=3)


@info "Pruned $n_pruned GP training points ($percent_pruned%) corresponding to outputs
outside $zscore_limit standard deviations from the mean."
outside $zscore_limit standard deviations from the mean."
end
end

Expand Down Expand Up @@ -497,13 +484,7 @@ function eki_update(pseudo_scheme::ConstantConvergence, Xₙ, Gₙ, eki)
iter += 1
end

# A nice message
intro_str = "Pseudo time step found for ConstantConvergence pseudo-stepping."
convergence_str = @sprintf(" ├─ convergence ratio: %.6f (target: %.2f)", r, conv_rate)
iteration_str = @sprintf(" ├─ iteration: %d", eki.iteration)
time_str = @sprintf(" ├─ pseudo time: %.3e", eki.pseudotime)
time_step_str = @sprintf(" └─ pseudo step: %.3e", Δtₙ)
@info string(intro_str, '\n', convergence_str, '\n', iteration_str, '\n', time_str, '\n', time_step_str)
@info @sprintf("ConstantConvergence pseudo stepping: convergence ratio: %.6f (target: %.2f)", r, conv_rate)

return Xₙ₊₁, Δtₙ
end
Expand All @@ -529,8 +510,6 @@ function eki_update(pseudo_scheme::Iglesias2021, Xₙ, Gₙ, eki)
Δtₙ = minimum([qₙ, 1-tₙ])
Xₙ₊₁ = iglesias_2013_update(Xₙ, Gₙ, eki; Δtₙ)

@info "Pseudo time step $Δtₙ found for Iglesias2021 pseudo-stepping."

return Xₙ₊₁, Δtₙ
end

Expand Down
25 changes: 13 additions & 12 deletions src/resampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,17 @@ function resample!(resampler::Resampler, X, G, eki)

priors = eki.inverse_problem.free_parameters.priors
θ = transform_to_constrained(priors, X)
failed_parameters_message = string(" ", param_str.(keys(priors))..., '\n',
(failed_particle_str(θ, k) for k in failures)...)

@warn("""
The forward map for $Nfailures $particles ($(100failed_fraction)%) failed.
The failed particles are:
$failed_parameters_message
""")

if resampler.verbose
failed_parameters_message = string(" ", param_str.(keys(priors))..., '\n',
(failed_particle_str(θ, k) for k in failures)...)

@info("""
The forward map for $Nfailures $particles ($(100failed_fraction)%) failed.
The failed particles are:
$failed_parameters_message
""")
end
end

if failed_fraction > resampler.acceptable_failure_fraction
Expand All @@ -87,9 +90,6 @@ function resample!(resampler::Resampler, X, G, eki)
# We are resampling!

if resampler.only_failed_particles
#Nsample = Nfailures
#replace_columns = failures

Nneed = ceil(Int, (1 - resampler.resample_failure_fraction) * Nens)
Nsuccesses = Nens - Nfailures
Nsample = Nneed - Nsuccesses
Expand All @@ -100,9 +100,10 @@ function resample!(resampler::Resampler, X, G, eki)
replace_columns = Colon()
end

@info "Searching for $Nsample successful particles..."
found_X, found_G = find_successful_particles(eki, X, G, Nsample)

@info "Replacing columns $replace_columns (failed fraction: $failed_fraction)..."
@info "Replacing columns $replace_columns (failed fraction: $failed_fraction)."
view(X, :, replace_columns) .= found_X
view(G, :, replace_columns) .= found_G

Expand Down
3 changes: 2 additions & 1 deletion test/test_forward_map.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using Oceananigans.Units
using Oceananigans.Grids: halo_size
using Oceananigans.Models.HydrostaticFreeSurfaceModels: ColumnEnsembleSize
using Oceananigans.TurbulenceClosures: ConvectiveAdjustmentVerticalDiffusivity
using Oceananigans: prognostic_fields

using ParameterEstimocean.Observations: FieldTimeSeriesCollector, initialize_forward_run!, observation_times
using ParameterEstimocean.InverseProblems: transpose_model_output, forward_run!, drop_y_dimension
Expand Down Expand Up @@ -93,7 +94,7 @@ end
@info " Testing initialize_forward_run!..."
random_initial_condition(x, y, z) = rand()

for field in fields(test_simulation.model)
for field in prognostic_fields(test_simulation.model)
set!(field, random_initial_condition)
end

Expand Down

0 comments on commit a9dbf70

Please sign in to comment.