From 8683139f58036165cf6a16c929675876f1eb8db5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 14 Oct 2024 14:04:05 -0400 Subject: [PATCH] refactor: cleanup NNODE --- Project.toml | 6 +- src/BPINN_ode.jl | 6 +- src/NeuralPDE.jl | 3 +- src/discretize.jl | 10 +- src/ode_solve.jl | 335 +++++++++++++++----------------------- src/pinn_types.jl | 5 - src/symbolic_utilities.jl | 7 +- test/NNODE_tests.jl | 138 ++++++---------- test/NNODE_tstops_test.jl | 79 ++------- 9 files changed, 211 insertions(+), 378 deletions(-) diff --git a/Project.toml b/Project.toml index 2bffeadcf2..2644a88506 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Chris Rackauckas "] version = "5.16.0" [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -31,17 +32,18 @@ OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" QuasiMonteCarlo = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" -UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] +ADTypes = "1.9.0" Adapt = "4" AdvancedHMC = "0.6.1" Aqua = "0.8" @@ -78,6 +80,7 @@ OrdinaryDiffEq = "6.87" Pkg = "1.10" QuasiMonteCarlo = "0.3.2" Random = "1" +RecursiveArrayTools = "3.27.0" Reexport = "1.2" RuntimeGeneratedFunctions = "0.5.12" SafeTestsets = "0.1" @@ -86,7 +89,6 @@ Statistics = "1.10" SymbolicUtils = "3.7.2" Symbolics = "6.14" Test = "1.10" -UnPack = "1" WeightInitializers = "1.0.3" Zygote = "0.6.71" julia = "1.10" diff --git a/src/BPINN_ode.jl b/src/BPINN_ode.jl index f3a16ad61c..f33cb7c935 100644 --- a/src/BPINN_ode.jl +++ b/src/BPINN_ode.jl @@ -188,11 +188,7 @@ function SciMLBase.__solve(prob::SciMLBase.ODEProblem, saveat = 1 / 50.0, maxiters = nothing, numensemble = floor(Int, alg.draw_samples / 3)) - @unpack chain, l2std, phystd, param, priorsNNw, Kernel, strategy, - draw_samples, dataset, init_params, - nchains, physdt, Adaptorkwargs, Integratorkwargs, - MCMCkwargs, numensemble, estim_collocate, autodiff, progress, - verbose = alg + (; chain, l2std, phystd, param, priorsNNw, Kernel, strategy, draw_samples, dataset, init_params, nchains, physdt, Adaptorkwargs, Integratorkwargs, MCMCkwargs, numensemble, estim_collocate, autodiff, progress, verbose) = alg # ahmc_bayesian_pinn_ode needs param=[] for easier vcat operation for full vector of parameters param = param === nothing ? 
[] : param diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl index 21a38c6156..f8cc92f245 100644 --- a/src/NeuralPDE.jl +++ b/src/NeuralPDE.jl @@ -27,8 +27,8 @@ import DomainSets using DomainSets: Domain, ClosedInterval, AbstractInterval, leftendpoint, rightendpoint, ProductDomain using SciMLBase: @add_kwonly, parameterless_type -using UnPack: @unpack +using ADTypes: AutoForwardDiff, AutoZygote using ChainRulesCore: ChainRulesCore, @non_differentiable, @ignore_derivatives using ComponentArrays: ComponentArrays, ComponentArray, getdata, getaxes using ConcreteStructs: @concrete @@ -37,6 +37,7 @@ using Lux: Lux, Chain, Dense, SkipConnection, StatefulLuxLayer using Lux: FromFluxAdaptor, recursive_eltype using LuxCore: AbstractLuxLayer, AbstractLuxWrapperLayer, AbstractLuxContainerLayer using Optimisers: Optimisers, Adam +using RecursiveArrayTools: DiffEqArray using QuasiMonteCarlo: QuasiMonteCarlo, LatinHypercubeSample using WeightInitializers: glorot_uniform, zeros32 diff --git a/src/discretize.jl b/src/discretize.jl index 72eec30ccf..476c00451c 100644 --- a/src/discretize.jl +++ b/src/discretize.jl @@ -34,10 +34,7 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eqs; dict_transformation_vars = nothing, transformation_vars = nothing, integrating_depvars = pinnrep.depvars) - @unpack indvars, depvars, dict_indvars, dict_depvars, dict_depvar_input, - phi, derivative, integral, - multioutput, init_params, strategy, eq_params, - param_estim, default_p = pinnrep + (; indvars, depvars, dict_indvars, dict_depvars, dict_depvar_input, phi, derivative, integral, multioutput, init_params, strategy, eq_params, param_estim, default_p) = pinnrep eltypeθ = eltype(pinnrep.flat_init_params) @@ -150,7 +147,7 @@ Returns the body of loss function, which is the executable Julia function, for t equation or boundary condition. """ function build_loss_function(pinnrep::PINNRepresentation, eqs, bc_indvars) - @unpack eq_params, param_estim, default_p, phi, derivative, integral = pinnrep + (; eq_params, param_estim, default_p, phi, derivative, integral) = pinnrep bc_indvars = bc_indvars === nothing ? pinnrep.indvars : bc_indvars @@ -312,8 +309,7 @@ function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, str end function get_numeric_integral(pinnrep::PINNRepresentation) - @unpack strategy, indvars, depvars, multioutput, derivative, - depvars, indvars, dict_indvars, dict_depvars = pinnrep + (; strategy, indvars, depvars, multioutput, derivative, depvars, indvars, dict_indvars, dict_depvars) = pinnrep integral = (u, cord, phi, integrating_var_id, integrand_func, lb, ub, θ; strategy = strategy, indvars = indvars, depvars = depvars, dict_indvars = dict_indvars, dict_depvars = dict_depvars) -> begin function integration_(cord, lb, ub, θ) diff --git a/src/ode_solve.jl b/src/ode_solve.jl index d0f006f83c..351e0fac7e 100644 --- a/src/ode_solve.jl +++ b/src/ode_solve.jl @@ -1,10 +1,12 @@ abstract type NeuralPDEAlgorithm <: SciMLBase.AbstractODEAlgorithm end """ - NNODE(chain, opt, init_params = nothing; autodiff = false, batch = 0, additional_loss = nothing, kwargs...) + NNODE(chain, opt, init_params = nothing; autodiff = false, batch = 0, + additional_loss = nothing, kwargs...) -Algorithm for solving ordinary differential equations using a neural network. This is a specialization -of the physics-informed neural network which is used as a solver for a standard `ODEProblem`. +Algorithm for solving ordinary differential equations using a neural network. 
This is a +specialization of the physics-informed neural network which is used as a solver for a +standard `ODEProblem`. !!! warn @@ -14,24 +16,31 @@ of the physics-informed neural network which is used as a solver for a standard ## Positional Arguments -* `chain`: A neural network architecture, defined as a `Lux.AbstractLuxLayer` or `Flux.Chain`. - `Flux.Chain` will be converted to `Lux` using `adapt(FromFluxAdaptor(false, false), chain)`. +* `chain`: A neural network architecture, defined as a `Lux.AbstractLuxLayer` or + `Flux.Chain`. `Flux.Chain` will be converted to `Lux` using + `adapt(FromFluxAdaptor(false, false), chain)`. * `opt`: The optimizer to train the neural network. * `init_params`: The initial parameter of the neural network. By default, this is `nothing` - which thus uses the random initialization provided by the neural network library. + which thus uses the random initialization provided by the neural network + library. ## Keyword Arguments -* `additional_loss`: A function additional_loss(phi, θ) where phi are the neural network trial solutions, - θ are the weights of the neural network(s). + +* `additional_loss`: A function `additional_loss(phi, θ)`, where `phi` is the neural network + trial solution and `θ` are the weights of the neural network(s). * `autodiff`: The switch between automatic and numerical differentiation for the PDE operators. The reverse mode of the loss function is always automatic differentiation (via Zygote), this is only for the derivative in the loss function (the derivative with respect to time). -* `batch`: The batch size for the loss computation. Defaults to `true`, means the neural network is applied at a row vector of values - `t` simultaneously, i.e. it's the batch size for the neural network evaluations. This requires a neural network compatible with batched data. - `false` means which means the application of the neural network is done at individual time points one at a time. - This is not applicable to `QuadratureTraining` where `batch` is passed in the `strategy` which is the number of points it can parallelly compute the integrand. -* `param_estim`: Boolean to indicate whether parameters of the differential equations are learnt along with parameters of the neural network. +* `batch`: The batch size for the loss computation. Defaults to `true`, which means the + neural network is applied at a row vector of values `t` simultaneously, i.e. it's + the batch size for the neural network evaluations. This requires a neural network + compatible with batched data. `false` means the neural network is applied at + individual time points, one at a time. This is not applicable to + `QuadratureTraining`, where `batch` is passed in the `strategy` and is the number + of points at which it can evaluate the integrand in parallel. +* `param_estim`: Boolean to indicate whether parameters of the differential equations are + learnt along with parameters of the neural network. * `strategy`: The training strategy used to choose the points for the evaluations. Default of `nothing` means that `QuadratureTraining` with QuadGK is used if no `dt` is given, and `GridTraining` is used with `dt` if given. @@ -61,97 +70,66 @@ sol = solve(prob, NNODE(chain, opt), verbose = true, abstol = 1e-10, maxiters = ## Solution Notes -Note that the solution is evaluated at fixed time points according to standard output handlers -such as `saveat` and `dt`.
However, the neural network is a fully continuous solution so `sol(t)` -is an accurate interpolation (up to the neural network training result). In addition, the -`OptimizationSolution` is returned as `sol.k` for further analysis. +Note that the solution is evaluated at fixed time points according to standard output +handlers such as `saveat` and `dt`. However, the neural network is a fully continuous +solution so `sol(t)` is an accurate interpolation (up to the neural network training +result). In addition, the `OptimizationSolution` is returned as `sol.k` for further +analysis. ## References -Lagaris, Isaac E., Aristidis Likas, and Dimitrios I. Fotiadis. "Artificial neural networks for solving -ordinary and partial differential equations." IEEE Transactions on Neural Networks 9, no. 5 (1998): 987-1000. +Lagaris, Isaac E., Aristidis Likas, and Dimitrios I. Fotiadis. "Artificial neural networks +for solving ordinary and partial differential equations." IEEE Transactions on Neural +Networks 9, no. 5 (1998): 987-1000. """ -struct NNODE{C, O, P, B, PE, K, AL <: Union{Nothing, Function}, - S <: Union{Nothing, AbstractTrainingStrategy} -} <: - NeuralPDEAlgorithm - chain::C - opt::O - init_params::P +@concrete struct NNODE + chain <: AbstractLuxLayer + opt + init_params autodiff::Bool - batch::B - strategy::S - param_estim::PE - additional_loss::AL - kwargs::K + batch + strategy <: Union{Nothing, AbstractTrainingStrategy} + param_estim + additional_loss <: Union{Nothing, Function} + kwargs end -function NNODE(chain, opt, init_params = nothing; - strategy = nothing, - autodiff = false, batch = true, param_estim = false, additional_loss = nothing, kwargs...) - !(chain isa Lux.AbstractLuxLayer) && - (chain = adapt(FromFluxAdaptor(false, false), chain)) - NNODE(chain, opt, init_params, autodiff, batch, + +function NNODE(chain, opt, init_params = nothing; strategy = nothing, autodiff = false, + batch = true, param_estim = false, additional_loss = nothing, kwargs...) + chain isa Lux.AbstractLuxLayer || (chain = FromFluxAdaptor()(chain)) + return NNODE(chain, opt, init_params, autodiff, batch, strategy, param_estim, additional_loss, kwargs) end """ ODEPhi(chain::Lux.AbstractLuxLayer, t, u0, st) -Internal struct, used for representing the ODE solution as a neural network in a form that respects boundary conditions, i.e. -`phi(t) = u0 + t*NN(t)`. +Internal struct, used for representing the ODE solution as a neural network in a form that +respects boundary conditions, i.e. `phi(t) = u0 + t*NN(t)`. 
""" -mutable struct ODEPhi{C, T, U, S} - chain::C - t0::T - u0::U - st::S - function ODEPhi(chain::Lux.AbstractLuxLayer, t::Number, u0, st) - new{typeof(chain), typeof(t), typeof(u0), typeof(st)}(chain, t, u0, st) - end +@concrete struct ODEPhi + u0 + t0 + smodel <: StatefulLuxLayer end -function generate_phi_θ(chain::Lux.AbstractLuxLayer, t, u0, init_params) - θ, st = Lux.setup(Random.default_rng(), chain) - isnothing(init_params) && (init_params = θ) - ODEPhi(chain, t, u0, st), init_params +function ODEPhi(model::AbstractLuxLayer, t0::Number, u0, st) + return ODEPhi(u0, t0, StatefulLuxLayer{true}(model, nothing, st)) end -function (f::ODEPhi{C, T, U})(t::Number, - θ) where {C <: Lux.AbstractLuxLayer, T, U <: Number} - eltypeθ, typeθ = eltype(θ.depvar), parameterless_type(ComponentArrays.getdata(θ.depvar)) - t_ = convert.(eltypeθ, adapt(typeθ, [t])) - y, st = f.chain(t_, θ.depvar, f.st) - ChainRulesCore.@ignore_derivatives f.st = st - f.u0 + (t - f.t0) * first(y) +function generate_phi_θ(chain::AbstractLuxLayer, t, u0, init_params) + θ, st = Lux.setup(Random.default_rng(), chain) + init_params === nothing && (init_params = θ) + return ODEPhi(chain, t, u0, st), init_params end -function (f::ODEPhi{C, T, U})(t::AbstractVector, - θ) where {C <: Lux.AbstractLuxLayer, T, U <: Number} - # Batch via data as row vectors - eltypeθ, typeθ = eltype(θ.depvar), parameterless_type(ComponentArrays.getdata(θ.depvar)) - t_ = convert.(eltypeθ, adapt(typeθ, t')) - y, st = f.chain(t_, θ.depvar, f.st) - ChainRulesCore.@ignore_derivatives f.st = st - f.u0 .+ (t' .- f.t0) .* y -end +(f::ODEPhi{<:Number})(t::Number, θ) = f.u0 + (t - f.t0) * first(f.smodel([t], θ.depvar)) -function (f::ODEPhi{C, T, U})(t::Number, θ) where {C <: Lux.AbstractLuxLayer, T, U} - eltypeθ, typeθ = eltype(θ.depvar), parameterless_type(ComponentArrays.getdata(θ.depvar)) - t_ = convert.(eltypeθ, adapt(typeθ, [t])) - y, st = f.chain(t_, θ.depvar, f.st) - ChainRulesCore.@ignore_derivatives f.st = st - f.u0 .+ (t .- f.t0) .* y -end +(f::ODEPhi{<:Number})(t::AbstractVector, θ) = f.u0 .+ (t' .- f.t0) .* f.smodel(t', θ.depvar) -function (f::ODEPhi{C, T, U})(t::AbstractVector, - θ) where {C <: Lux.AbstractLuxLayer, T, U} - # Batch via data as row vectors - eltypeθ, typeθ = eltype(θ.depvar), parameterless_type(ComponentArrays.getdata(θ.depvar)) - t_ = convert.(eltypeθ, adapt(typeθ, t')) - y, st = f.chain(t_, θ.depvar, f.st) - ChainRulesCore.@ignore_derivatives f.st = st - f.u0 .+ (t' .- f.t0) .* y -end +(f::ODEPhi)(t::Number, θ) = f.u0 .+ (t .- f.t0) .* f.smodel([t], θ.depvar) + +(f::ODEPhi)(t::AbstractVector, θ) = f.u0 .+ (t' .- f.t0) .* f.smodel(t', θ.depvar) """ ode_dfdx(phi, t, θ, autodiff) @@ -160,30 +138,16 @@ Computes u' using either forward-mode automatic differentiation or numerical dif """ function ode_dfdx end -function ode_dfdx(phi::ODEPhi{C, T, U}, t::Number, θ, - autodiff::Bool) where {C, T, U <: Number} - if autodiff - ForwardDiff.derivative(t -> phi(t, θ), t) - else - (phi(t + sqrt(eps(typeof(t))), θ) - phi(t, θ)) / sqrt(eps(typeof(t))) - end -end - -function ode_dfdx(phi::ODEPhi{C, T, U}, t::Number, θ, - autodiff::Bool) where {C, T, U <: AbstractVector} - if autodiff - ForwardDiff.jacobian(t -> phi(t, θ), t) - else - (phi(t + sqrt(eps(typeof(t))), θ) - phi(t, θ)) / sqrt(eps(typeof(t))) - end +function ode_dfdx(phi::ODEPhi{<:Number}, t::Number, θ, autodiff::Bool) + autodiff && return ForwardDiff.derivative(Base.Fix2(phi, θ), t) + ϵ = sqrt(eps(typeof(t))) + return (phi(t + ϵ, θ) - phi(t, θ)) / ϵ end -function ode_dfdx(phi::ODEPhi, 
t::AbstractVector, θ, autodiff::Bool) - if autodiff - ForwardDiff.jacobian(t -> phi(t, θ), t) - else - (phi(t .+ sqrt(eps(eltype(t))), θ) - phi(t, θ)) ./ sqrt(eps(eltype(t))) - end +function ode_dfdx(phi::ODEPhi, t, θ, autodiff::Bool) + autodiff && return ForwardDiff.jacobian(Base.Fix2(phi, θ), t) + ϵ = sqrt(eps(eltype(t))) + return (phi(t .+ ϵ, θ) .- phi(t, θ)) ./ ϵ end """ @@ -193,35 +157,34 @@ Simple L2 inner loss at a time `t` with parameters `θ` of the neural network. """ function inner_loss end -function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::Number, θ, - p, param_estim::Bool) where {C, T, U <: Number} +function inner_loss(phi::ODEPhi{<:Number}, f, autodiff::Bool, t::Number, θ, + p, param_estim::Bool) p_ = param_estim ? θ.p : p - sum(abs2, ode_dfdx(phi, t, θ, autodiff) - f(phi(t, θ), p_, t)) + return sum(abs2, ode_dfdx(phi, t, θ, autodiff) .- f(phi(t, θ), p_, t)) end -function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector, θ, - p, param_estim::Bool) where {C, T, U <: Number} +function inner_loss(phi::ODEPhi{<:Number}, f, autodiff::Bool, t::AbstractVector, θ, + p, param_estim::Bool) p_ = param_estim ? θ.p : p out = phi(t, θ) fs = reduce(hcat, [f(out[i], p_, t[i]) for i in axes(out, 2)]) dxdtguess = Array(ode_dfdx(phi, t, θ, autodiff)) - sum(abs2, dxdtguess .- fs) / length(t) + return sum(abs2, fs .- dxdtguess) / length(t) end -function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::Number, θ, - p, param_estim::Bool) where {C, T, U} +function inner_loss(phi::ODEPhi, f, autodiff::Bool, t::Number, θ, p, param_estim::Bool) p_ = param_estim ? θ.p : p - sum(abs2, ode_dfdx(phi, t, θ, autodiff) .- f(phi(t, θ), p_, t)) + return sum(abs2, ode_dfdx(phi, t, θ, autodiff) .- f(phi(t, θ), p_, t)) end -function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector, θ, - p, param_estim::Bool) where {C, T, U} +function inner_loss( + phi::ODEPhi, f, autodiff::Bool, t::AbstractVector, θ, p, param_estim::Bool) p_ = param_estim ? 
θ.p : p out = Array(phi(t, θ)) arrt = Array(t) fs = reduce(hcat, [f(out[:, i], p_, arrt[i]) for i in 1:size(out, 2)]) dxdtguess = Array(ode_dfdx(phi, t, θ, autodiff)) - sum(abs2, dxdtguess .- fs) / length(t) + return sum(abs2, fs .- dxdtguess) / length(t) end """ @@ -234,16 +197,17 @@ function generate_loss(strategy::QuadratureTraining, phi, f, autodiff::Bool, tsp integrand(t::Number, θ) = abs2(inner_loss(phi, f, autodiff, t, θ, p, param_estim)) function integrand(ts, θ) - [abs2(inner_loss(phi, f, autodiff, t, θ, p, param_estim)) for t in ts] + return [abs2(inner_loss(phi, f, autodiff, t, θ, p, param_estim)) for t in ts] end function loss(θ, _) intf = BatchIntegralFunction(integrand, max_batch = strategy.batch) intprob = IntegralProblem(intf, (tspan[1], tspan[2]), θ) - sol = solve(intprob, strategy.quadrature_alg; abstol = strategy.abstol, - reltol = strategy.reltol, maxiters = strategy.maxiters) - sol.u + sol = solve(intprob, strategy.quadrature_alg; strategy.abstol, + strategy.reltol, strategy.maxiters) + return sol.u end + return loss end @@ -251,93 +215,71 @@ function generate_loss( strategy::GridTraining, phi, f, autodiff::Bool, tspan, p, batch, param_estim::Bool) ts = tspan[1]:(strategy.dx):tspan[2] autodiff && throw(ArgumentError("autodiff not supported for GridTraining.")) - function loss(θ, _) - if batch - inner_loss(phi, f, autodiff, ts, θ, p, param_estim) - else - sum([inner_loss(phi, f, autodiff, t, θ, p, param_estim) for t in ts]) - end - end - return loss + batch && return (θ, _) -> inner_loss(phi, f, autodiff, ts, θ, p, param_estim) + return (θ, _) -> sum([inner_loss(phi, f, autodiff, t, θ, p, param_estim) for t in ts]) end function generate_loss(strategy::StochasticTraining, phi, f, autodiff::Bool, tspan, p, batch, param_estim::Bool) autodiff && throw(ArgumentError("autodiff not supported for StochasticTraining.")) - function loss(θ, _) - ts = adapt(parameterless_type(θ), - [(tspan[2] - tspan[1]) * rand() + tspan[1] for i in 1:(strategy.points)]) + return (θ, _) -> begin + T = promote_type(eltype(tspan[1]), eltype(tspan[2])) + ts = (tspan[2] - tspan[1]) .* rand(T, strategy.points) .+ tspan[1] if batch inner_loss(phi, f, autodiff, ts, θ, p, param_estim) else sum([inner_loss(phi, f, autodiff, t, θ, p, param_estim) for t in ts]) end end - return loss end function generate_loss( strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan, p, batch, param_estim::Bool) autodiff && throw(ArgumentError("autodiff not supported for WeightedIntervalTraining.")) - minT = tspan[1] - maxT = tspan[2] - + minT, maxT = tspan weights = strategy.weights ./ sum(strategy.weights) - N = length(weights) - points = strategy.points - difference = (maxT - minT) / N - data = Float64[] + ts = eltype(difference)[] for (index, item) in enumerate(weights) - temp_data = rand(1, trunc(Int, points * item)) .* difference .+ minT .+ + temp_data = rand(1, trunc(Int, strategy.points * item)) .* difference .+ minT .+ ((index - 1) * difference) - data = append!(data, temp_data) + append!(ts, temp_data) end - ts = data - function loss(θ, _) - if batch - inner_loss(phi, f, autodiff, ts, θ, p, param_estim) - else - sum([inner_loss(phi, f, autodiff, t, θ, p, param_estim) for t in ts]) - end - end - return loss + batch && return (θ, _) -> inner_loss(phi, f, autodiff, ts, θ, p, param_estim) + return (θ, _) -> sum([inner_loss(phi, f, autodiff, t, θ, p, param_estim) for t in ts]) end function evaluate_tstops_loss(phi, f, autodiff::Bool, tstops, p, batch, param_estim::Bool) - function loss(θ, _) - if batch - 
inner_loss(phi, f, autodiff, tstops, θ, p, param_estim) - else - sum([inner_loss(phi, f, autodiff, t, θ, p, param_estim) for t in tstops]) - end - end - return loss + batch && return (θ, _) -> inner_loss(phi, f, autodiff, tstops, θ, p, param_estim) + return (θ, _) -> sum([inner_loss(phi, f, autodiff, t, θ, p, param_estim) + for t in tstops]) end -function generate_loss(strategy::QuasiRandomTraining, phi, f, autodiff::Bool, tspan) - error("QuasiRandomTraining is not supported by NNODE since it's for high dimensional spaces only. Use StochasticTraining instead.") +function generate_loss(::QuasiRandomTraining, phi, f, autodiff::Bool, tspan) + error("QuasiRandomTraining is not supported by NNODE since it's for high dimensional \ + spaces only. Use StochasticTraining instead.") end -struct NNODEInterpolation{T <: ODEPhi, T2} - phi::T - θ::T2 +@concrete struct NNODEInterpolation + phi <: ODEPhi + θ end -(f::NNODEInterpolation)(t, idxs::Nothing, ::Type{Val{0}}, p, continuity) = f.phi(t, f.θ) + +(f::NNODEInterpolation)(t, ::Nothing, ::Type{Val{0}}, p, continuity) = f.phi(t, f.θ) (f::NNODEInterpolation)(t, idxs, ::Type{Val{0}}, p, continuity) = f.phi(t, f.θ)[idxs] -function (f::NNODEInterpolation)(t::Vector, idxs::Nothing, ::Type{Val{0}}, p, continuity) +function (f::NNODEInterpolation)(t::Vector, ::Nothing, ::Type{Val{0}}, p, continuity) out = f.phi(t, f.θ) - SciMLBase.RecursiveArrayTools.DiffEqArray([out[:, i] for i in axes(out, 2)], t) + return DiffEqArray([out[:, i] for i in axes(out, 2)], t) end function (f::NNODEInterpolation)(t::Vector, idxs, ::Type{Val{0}}, p, continuity) out = f.phi(t, f.θ) - SciMLBase.RecursiveArrayTools.DiffEqArray([out[idxs, i] for i in axes(out, 2)], t) + return DiffEqArray([out[idxs, i] for i in axes(out, 2)], t) end SciMLBase.interp_summary(::NNODEInterpolation) = "Trained neural network interpolation" @@ -356,34 +298,19 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem, saveat = nothing, maxiters = nothing, tstops = nothing) - u0 = prob.u0 - tspan = prob.tspan - f = prob.f - p = prob.p + (; u0, tspan, f, p) = prob t0 = tspan[1] - param_estim = alg.param_estim + (; param_estim, chain, opt, autodiff, init_params, batch, additional_loss) = alg - #hidden layer - chain = alg.chain - opt = alg.opt - autodiff = alg.autodiff - - #train points generation - init_params = alg.init_params - - !(chain isa Lux.AbstractLuxLayer) && - error("Only Lux.AbstractLuxLayer neural networks are supported") phi, init_params = generate_phi_θ(chain, t0, u0, init_params) - (recursive_eltype(init_params) <: Complex && - alg.strategy isa QuadratureTraining) && + + (recursive_eltype(init_params) <: Complex && alg.strategy isa QuadratureTraining) && error("QuadratureTraining cannot be used with complex parameters. 
Use other strategies.") init_params = if alg.param_estim - ComponentArrays.ComponentArray(; - depvar = ComponentArrays.ComponentArray(init_params), p = prob.p) + ComponentArray(; depvar = init_params, p) else - ComponentArrays.ComponentArray(; - depvar = ComponentArrays.ComponentArray(init_params)) + ComponentArray(; depvar = init_params) end isinplace(prob) && @@ -404,27 +331,25 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem, GridTraining(dt) else QuadratureTraining(; quadrature_alg = QuadGKJL(), - reltol = convert(eltype(u0), reltol), - abstol = convert(eltype(u0), abstol), maxiters = maxiters, - batch = 0) + reltol = convert(eltype(u0), reltol), abstol = convert(eltype(u0), abstol), + maxiters, batch = 0) end else alg.strategy end - batch = alg.batch inner_f = generate_loss(strategy, phi, f, autodiff, tspan, p, batch, param_estim) - additional_loss = alg.additional_loss - (param_estim && isnothing(additional_loss)) && + + (param_estim && additional_loss === nothing) && throw(ArgumentError("Please provide `additional_loss` in `NNODE` for parameter estimation (`param_estim` is true).")) # Creates OptimizationFunction Object from total_loss function total_loss(θ, _) L2_loss = inner_f(θ, phi) - if !(additional_loss isa Nothing) + if additional_loss !== nothing L2_loss = L2_loss + additional_loss(phi, θ) end - if !(tstops isa Nothing) + if tstops !== nothing num_tstops_points = length(tstops) tstops_loss_func = evaluate_tstops_loss( phi, f, autodiff, tstops, p, batch, param_estim) @@ -445,11 +370,7 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem, end # Choice of Optimization Algo for Training Strategies - opt_algo = if strategy isa QuadratureTraining - Optimization.AutoForwardDiff() - else - Optimization.AutoZygote() - end + opt_algo = ifelse(strategy isa QuadratureTraining, AutoForwardDiff(), AutoZygote()) # Creates OptimizationFunction Object from total_loss optf = OptimizationFunction(total_loss, opt_algo) @@ -487,8 +408,10 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem, retcode = ReturnCode.Success, original = res, resid = res.objective) + SciMLBase.has_analytic(prob.f) && - SciMLBase.calculate_solution_errors!(sol; timeseries_errors = true, - dense_errors = false) - sol -end #solve + SciMLBase.calculate_solution_errors!( + sol; timeseries_errors = true, dense_errors = false) + + return sol +end diff --git a/src/pinn_types.jl b/src/pinn_types.jl index 5ca5581ad1..b15d05aef4 100644 --- a/src/pinn_types.jl +++ b/src/pinn_types.jl @@ -512,14 +512,9 @@ get_u() = (cord, θ, phi) -> phi(cord, θ) # the method to calculate the derivative function numeric_derivative(phi, u, x, εs, order, θ) - _type = parameterless_type(ComponentArrays.getdata(θ)) - ε = εs[order] _epsilon = inv(first(ε[ε .!= zero(ε)])) - ε = adapt(_type, ε) - x = adapt(_type, x) - # any(x->x!=εs[1],εs) # εs is the epsilon for each order, if they are all the same then we use a fancy formula # if order 1, this is trivially true diff --git a/src/symbolic_utilities.jl b/src/symbolic_utilities.jl index c78ddeff83..974f348e78 100644 --- a/src/symbolic_utilities.jl +++ b/src/symbolic_utilities.jl @@ -115,11 +115,8 @@ where - θ - weights in neural network. 
""" function _transform_expression(pinnrep::PINNRepresentation, ex; is_integral = false, - dict_transformation_vars = nothing, - transformation_vars = nothing) - @unpack indvars, depvars, dict_indvars, dict_depvars, - dict_depvar_input, multioutput, strategy, phi, - derivative, integral, flat_init_params, init_params = pinnrep + dict_transformation_vars = nothing, transformation_vars = nothing) + (; indvars, depvars, dict_indvars, dict_depvars, dict_depvar_input, multioutput, strategy, phi, derivative, integral, flat_init_params, init_params) = pinnrep eltypeθ = eltype(flat_init_params) _args = ex.args diff --git a/test/NNODE_tests.jl b/test/NNODE_tests.jl index b7b4a697a8..17fa61fb1a 100644 --- a/test/NNODE_tests.jl +++ b/test/NNODE_tests.jl @@ -1,30 +1,23 @@ -using Test -using Random, NeuralPDE -using OrdinaryDiffEq, Statistics -import Lux, OptimizationOptimisers, OptimizationOptimJL -using WeightInitializers -using Flux -using LineSearches +using Test, Random, NeuralPDE, OrdinaryDiffEq, Statistics, Lux, OptimizationOptimisers, + OptimizationOptimJL, WeightInitializers, LineSearches +import Flux rng = Random.default_rng() Random.seed!(100) @testset "Scalar" begin - # Run a solve on scalars - println("Scalar") linear = (u, p, t) -> cos(2pi * t) tspan = (0.0f0, 1.0f0) u0 = 0.0f0 prob = ODEProblem(linear, u0, tspan) - luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1)) + luxchain = Chain(Dense(1, 5, σ), Dense(5, 1)) opt = OptimizationOptimisers.Adam(0.1, (0.9, 0.95)) sol = solve(prob, NNODE(luxchain, opt), dt = 1 / 20.0f0, verbose = false, abstol = 1.0f-10, maxiters = 200) @test_throws ArgumentError solve(prob, NNODE(luxchain, opt; autodiff = true), - dt = 1 / 20.0f0, - verbose = false, abstol = 1.0f-10, maxiters = 200) + dt = 1 / 20.0f0, verbose = false, abstol = 1.0f-10, maxiters = 200) sol = solve(prob, NNODE(luxchain, opt), verbose = false, abstol = 1.0f-6, maxiters = 200) @@ -38,13 +31,11 @@ Random.seed!(100) end @testset "Vector" begin - # Run a solve on vectors - println("Vector") linear = (u, p, t) -> [cos(2pi * t)] tspan = (0.0f0, 1.0f0) u0 = [0.0f0] prob = ODEProblem(linear, u0, tspan) - luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1)) + luxchain = Chain(Dense(1, 5, σ), Dense(5, 1)) opt = OptimizationOptimJL.BFGS() sol = solve(prob, NNODE(luxchain, opt), dt = 1 / 20.0f0, abstol = 1e-10, @@ -63,27 +54,24 @@ end end @testset "Example 1" begin - println("Example 1") linear = (u, p, t) -> @. 
t^3 + 2 * t + (t^2) * ((1 + 3 * (t^2)) / (1 + t + (t^3))) - u * (t + ((1 + 3 * (t^2)) / (1 + t + t^3))) linear_analytic = (u0, p, t) -> [exp(-(t^2) / 2) / (1 + t + t^3) + t^2] prob = ODEProblem( ODEFunction(linear, analytic = linear_analytic), [1.0f0], (0.0f0, 1.0f0)) - luxchain = Lux.Chain(Lux.Dense(1, 128, Lux.σ), Lux.Dense(128, 1)) + luxchain = Chain(Dense(1, 128, σ), Dense(128, 1)) opt = OptimizationOptimisers.Adam(0.01) sol = solve(prob, NNODE(luxchain, opt), verbose = false, maxiters = 400) @test sol.errors[:l2] < 0.5 - sol = solve(prob, - NNODE(luxchain, opt; batch = false, - strategy = StochasticTraining(100)), + sol = solve( + prob, NNODE(luxchain, opt; batch = false, strategy = StochasticTraining(100)), verbose = false, maxiters = 400) @test sol.errors[:l2] < 0.5 - sol = solve(prob, - NNODE(luxchain, opt; batch = true, - strategy = StochasticTraining(100)), + sol = solve( + prob, NNODE(luxchain, opt; batch = true, strategy = StochasticTraining(100)), verbose = false, maxiters = 400) @test sol.errors[:l2] < 0.5 @@ -91,19 +79,17 @@ end maxiters = 400, dt = 1 / 5.0f0) @test sol.errors[:l2] < 0.5 - sol = solve(prob, NNODE(luxchain, opt; batch = true), verbose = false, - maxiters = 400, - dt = 1 / 5.0f0) + sol = solve(prob, NNODE(luxchain, opt; batch = true), + verbose = false, maxiters = 400, dt = 1 / 5.0f0) @test sol.errors[:l2] < 0.5 end @testset "Example 2" begin - println("Example 2") linear = (u, p, t) -> -u / 5 + exp(-t / 5) .* cos(t) linear_analytic = (u0, p, t) -> exp(-t / 5) * (u0 + sin(t)) prob = ODEProblem( ODEFunction(linear, analytic = linear_analytic), 0.0f0, (0.0f0, 1.0f0)) - luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1)) + luxchain = Chain(Dense(1, 5, σ), Dense(5, 1)) opt = OptimizationOptimisers.Adam(0.1) sol = solve(prob, NNODE(luxchain, opt), verbose = false, maxiters = 400, @@ -111,51 +97,42 @@ end @test sol.errors[:l2] < 0.5 sol = solve(prob, - NNODE(luxchain, opt; batch = false, - strategy = StochasticTraining(100)), - verbose = false, maxiters = 400, - abstol = 1.0f-8) + NNODE(luxchain, opt; batch = false, strategy = StochasticTraining(100)), + verbose = false, maxiters = 400, abstol = 1.0f-8) @test sol.errors[:l2] < 0.5 sol = solve(prob, - NNODE(luxchain, opt; batch = true, - strategy = StochasticTraining(100)), - verbose = false, maxiters = 400, - abstol = 1.0f-8) + NNODE(luxchain, opt; batch = true, strategy = StochasticTraining(100)), + verbose = false, maxiters = 400, abstol = 1.0f-8) @test sol.errors[:l2] < 0.5 sol = solve(prob, NNODE(luxchain, opt; batch = false), verbose = false, - maxiters = 400, - abstol = 1.0f-8, dt = 1 / 5.0f0) + maxiters = 400, abstol = 1.0f-8, dt = 1 / 5.0f0) @test sol.errors[:l2] < 0.5 sol = solve(prob, NNODE(luxchain, opt; batch = true), verbose = false, - maxiters = 400, - abstol = 1.0f-8, dt = 1 / 5.0f0) + maxiters = 400, abstol = 1.0f-8, dt = 1 / 5.0f0) @test sol.errors[:l2] < 0.5 end @testset "Example 3" begin - println("Example 3") linear = (u, p, t) -> [cos(2pi * t), sin(2pi * t)] tspan = (0.0f0, 1.0f0) u0 = [0.0f0, -1.0f0 / 2pi] linear_analytic = (u0, p, t) -> [sin(2pi * t) / 2pi, -cos(2pi * t) / 2pi] odefunction = ODEFunction(linear, analytic = linear_analytic) prob = ODEProblem(odefunction, u0, tspan) - luxchain = Lux.Chain(Lux.Dense(1, 10, Lux.σ), Lux.Dense(10, 2)) + luxchain = Chain(Dense(1, 10, σ), Dense(10, 2)) opt = OptimizationOptimisers.Adam(0.1) alg = NNODE(luxchain, opt; autodiff = false) - sol = solve(prob, - alg, verbose = false, dt = 1 / 40.0f0, - maxiters = 2000, abstol = 1.0f-7) + sol = 
solve( + prob, alg, verbose = false, dt = 1 / 40.0f0, maxiters = 2000, abstol = 1.0f-7) @test sol.errors[:l2] < 0.5 end @testset "Training Strategies" begin @testset "WeightedIntervalTraining" begin - println("WeightedIntervalTraining") function f(u, p, t) [p[1] * u[1] - p[2] * u[1] * u[2], -p[3] * u[2] + p[4] * u[1] * u[2]] end @@ -164,21 +141,21 @@ end prob_oop = ODEProblem{false}(f, u0, (0.0, 3.0), p) true_sol = solve(prob_oop, Tsit5(), saveat = 0.01) - N = 32 - chain = Lux.Chain( - Lux.Dense(1, N, tanh), - Lux.Dense(N, N, tanh), - Lux.Dense(N, N, tanh), - Lux.Dense(N, N, tanh), - Lux.Dense(N, length(u0)) + N = 64 + chain = Chain( + Dense(1, N, gelu), + Dense(N, N, gelu), + Dense(N, N, gelu), + Dense(N, N, gelu), + Dense(N, length(u0)) ) - opt = OptimizationOptimisers.Adam(0.1) + opt = OptimizationOptimisers.Adam(0.001) weights = [0.7, 0.2, 0.1] points = 200 alg = NNODE(chain, opt, autodiff = false, - strategy = NeuralPDE.WeightedIntervalTraining(weights, points)) + strategy = WeightedIntervalTraining(weights, points)) sol = solve(prob_oop, alg; verbose = false, maxiters = 5000, saveat = 0.01) - @test_broken abs(mean(sol) - mean(true_sol)) < 0.2 + @test abs(mean(sol) - mean(true_sol)) < 0.2 end linear = (u, p, t) -> cos(2pi * t) @@ -191,46 +168,40 @@ end u_analytical(x) = (1 / (2pi)) .* sin.(2pi .* x) @testset "GridTraining" begin - println("GridTraining") - luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1)) + luxchain = Chain(Dense(1, 5, σ), Dense(5, 1)) (u_, t_) = (u_analytical(ts), ts) function additional_loss(phi, θ) return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_) end - alg1 = NNODE(luxchain, opt, strategy = GridTraining(0.01), - additional_loss = additional_loss) - sol1 = solve(prob, alg1, verbose = false, abstol = 1e-8, maxiters = 500) + alg1 = NNODE(luxchain, opt; strategy = GridTraining(0.01), additional_loss) + sol1 = solve(prob, alg1; verbose = false, abstol = 1e-8, maxiters = 500) @test sol1.errors[:l2] < 0.5 end @testset "QuadratureTraining" begin - println("QuadratureTraining") - luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1)) + luxchain = Chain(Dense(1, 5, σ), Dense(5, 1)) (u_, t_) = (u_analytical(ts), ts) function additional_loss(phi, θ) return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_) end - alg1 = NNODE(luxchain, opt, additional_loss = additional_loss) - sol1 = solve(prob, alg1, verbose = false, abstol = 1e-10, maxiters = 200) + alg1 = NNODE(luxchain, opt; additional_loss) + sol1 = solve(prob, alg1; verbose = false, abstol = 1e-10, maxiters = 200) @test sol1.errors[:l2] < 0.5 end @testset "StochasticTraining" begin - println("StochasticTraining") - luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1)) + luxchain = Chain(Dense(1, 5, σ), Dense(5, 1)) (u_, t_) = (u_analytical(ts), ts) function additional_loss(phi, θ) return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_) end - alg1 = NNODE(luxchain, opt, strategy = StochasticTraining(1000), - additional_loss = additional_loss) - sol1 = solve(prob, alg1, verbose = false, abstol = 1e-8, maxiters = 500) + alg1 = NNODE(luxchain, opt; strategy = StochasticTraining(1000), additional_loss) + sol1 = solve(prob, alg1; verbose = false, abstol = 1e-8, maxiters = 500) @test sol1.errors[:l2] < 0.5 end end @testset "Parameter Estimation" begin - println("Parameter Estimation") function lorenz(u, p, t) return [p[1] * (u[2] - u[1]), u[1] * (p[2] - u[3]) - u[2], @@ -246,15 +217,15 @@ end return sum(abs2, phi(t_, θ) .- u_) / 100 end n = 8 - luxchain = Lux.Chain( - 
Lux.Dense(1, n, Lux.σ), - Lux.Dense(n, n, Lux.σ), - Lux.Dense(n, n, Lux.σ), - Lux.Dense(n, 3) + luxchain = Chain( + Dense(1, n, σ), + Dense(n, n, σ), + Dense(n, n, σ), + Dense(n, 3) ) opt = OptimizationOptimJL.BFGS(linesearch = BackTracking()) alg = NNODE(luxchain, opt, strategy = GridTraining(0.01), - param_estim = true, additional_loss = additional_loss) + param_estim = true, additional_loss) sol = solve(prob, alg, verbose = false, abstol = 1e-8, maxiters = 1000, saveat = t_) @test sol.k.u.p≈true_p atol=1e-2 @test reduce(hcat, sol.u)≈u_ atol=1e-2 @@ -279,11 +250,11 @@ end problem = ODEProblem(bloch_equations, u0, time_span, parameters) - chain = Lux.Chain( - Lux.Dense(1, 16, tanh; - init_weight = (rng, a...) -> Lux.kaiming_normal(rng, ComplexF64, a...)), - Lux.Dense( - 16, 4; init_weight = (rng, a...) -> Lux.kaiming_normal(rng, ComplexF64, a...)) + chain = Chain( + Dense(1, 16, tanh; + init_weight = (rng, a...) -> kaiming_normal(rng, ComplexF64, a...)), + Dense( + 16, 4; init_weight = (rng, a...) -> kaiming_normal(rng, ComplexF64, a...)) ) ps, st = Lux.setup(rng, chain) @@ -304,7 +275,6 @@ end end @testset "Translating from Flux" begin - println("Translating from Flux") linear = (u, p, t) -> cos(2pi * t) linear_analytic = (u, p, t) -> (1 / (2pi)) * sin(2pi * t) tspan = (0.0, 1.0) @@ -315,7 +285,7 @@ end u_analytical(x) = (1 / (2pi)) .* sin.(2pi .* x) fluxchain = Flux.Chain(Flux.Dense(1, 5, Flux.σ), Flux.Dense(5, 1)) alg1 = NNODE(fluxchain, opt) - @test alg1.chain isa Lux.AbstractLuxLayer + @test alg1.chain isa AbstractLuxLayer sol1 = solve(prob, alg1, verbose = false, abstol = 1e-10, maxiters = 200) @test sol1.errors[:l2] < 0.5 end diff --git a/test/NNODE_tstops_test.jl b/test/NNODE_tstops_test.jl index 74a5a7252f..d7de703b25 100644 --- a/test/NNODE_tstops_test.jl +++ b/test/NNODE_tstops_test.jl @@ -1,4 +1,4 @@ -using OrdinaryDiffEq, Lux, OptimizationOptimisers, Test, Statistics, NeuralPDE +using OrdinaryDiffEq, Lux, OptimizationOptimisers, Optimisers, Test, Statistics, NeuralPDE function fu(u, p, t) [p[1] * u[1] - p[2] * u[1] * u[2], -p[3] * u[2] + p[4] * u[1] * u[2]] @@ -13,78 +13,31 @@ points3 = [rand() + 2 for i in 1:40] addedPoints = vcat(points1, points2, points3) saveat = 0.01 -maxiters = 30000 prob_oop = ODEProblem{false}(fu, u0, tspan, p) -true_sol = solve(prob_oop, Tsit5(), saveat = saveat) -func = Lux.σ -N = 12 -chain = Lux.Chain(Lux.Dense(1, N, func), Lux.Dense(N, N, func), Lux.Dense(N, N, func), - Lux.Dense(N, N, func), Lux.Dense(N, length(u0))) +true_sol = solve(prob_oop, Tsit5(); saveat) +N = 16 +chain = Chain( + Dense(1, N, σ), Dense(N, N, σ), Dense(N, N, σ), Dense(N, N, σ), Dense(N, length(u0))) -opt = OptimizationOptimisers.Adam(0.01) +opt = Adam(0.01) threshold = 0.2 -#bad choices for weights, samples and dx so that the algorithm will fail without the added points -weights = [0.3, 0.3, 0.4] -points = 3 -dx = 1.0 +@testset "$(nameof(typeof(strategy)))" for strategy in [ + GridTraining(1.0), + WeightedIntervalTraining([0.3, 0.3, 0.4], 3), + StochasticTraining(3) +] + alg = NNODE(chain, opt; autodiff = false, strategy) -@testset "GridTraining" begin - println("GridTraining") @testset "Without added points" begin - println("Without added points") - # (difference between solutions should be high) - alg = NNODE(chain, opt, autodiff = false, strategy = GridTraining(dx)) - sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat) + sol = solve(prob_oop, alg; verbose = true, maxiters = 1000, saveat) @test abs(mean(sol) - mean(true_sol)) > threshold end - 
@testset "With added points" begin - println("With added points") - # (difference between solutions should be low) - alg = NNODE(chain, opt, autodiff = false, strategy = GridTraining(dx)) - sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, - saveat = saveat, tstops = addedPoints) - @test_broken abs(mean(sol) - mean(true_sol)) < threshold - end -end -@testset "WeightedIntervalTraining" begin - println("WeightedIntervalTraining") - @testset "Without added points" begin - println("Without added points") - # (difference between solutions should be high) - alg = NNODE(chain, opt, autodiff = false, - strategy = WeightedIntervalTraining(weights, points)) - sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat) - @test abs(mean(sol) - mean(true_sol)) > threshold - end - @testset "With added points" begin - println("With added points") - # (difference between solutions should be low) - alg = NNODE(chain, opt, autodiff = false, - strategy = WeightedIntervalTraining(weights, points)) - sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, - saveat = saveat, tstops = addedPoints) - @test_broken abs(mean(sol) - mean(true_sol)) < threshold - end -end - -@testset "StochasticTraining" begin - println("StochasticTraining") - @testset "Without added points" begin - println("Without added points") - # (difference between solutions should be high) - alg = NNODE(chain, opt, autodiff = false, strategy = StochasticTraining(points)) - sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat) - @test abs(mean(sol) - mean(true_sol)) > threshold - end @testset "With added points" begin - println("With added points") - # (difference between solutions should be low) - alg = NNODE(chain, opt, autodiff = false, strategy = StochasticTraining(points)) - sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, - saveat = saveat, tstops = addedPoints) - @test_broken abs(mean(sol) - mean(true_sol)) < threshold + sol = solve( + prob_oop, alg; verbose = true, maxiters = 10000, saveat, tstops = addedPoints) + @test abs(mean(sol) - mean(true_sol)) < threshold end end