diff --git a/src/dae_solve.jl b/src/dae_solve.jl
index 39e410d4bb..94057068b3 100644
--- a/src/dae_solve.jl
+++ b/src/dae_solve.jl
@@ -1,82 +1,72 @@
 """
-    NNDAE(chain,
-        OptimizationOptimisers.Adam(0.1),
-        init_params = nothing;
-        autodiff = false,
-        kwargs...)
+    NNDAE(chain, opt, init_params = nothing; autodiff = false, kwargs...)
 
-Algorithm for solving differential algebraic equationsusing a neural network. This is a specialization
-of the physics-informed neural network which is used as a solver for a standard `DAEProblem`.
+Algorithm for solving differential algebraic equations using a neural network. This is a
+specialization of the physics-informed neural network which is used as a solver for a
+standard `DAEProblem`.
 
-!!! warn
+!!! warning
 
     Note that NNDAE only supports DAEs which are written in the out-of-place form, i.e.
-    `du = f(du,u,p,t)`, and not `f(out,du,u,p,t)`. If not declared out-of-place, then the NNDAE
-    will exit with an error.
+    `0 = f(du,u,p,t)`, and not `f(out,du,u,p,t)`. If not declared out-of-place, then the
+    NNDAE will exit with an error.
 
 ## Positional Arguments
 
-* `chain`: A neural network architecture, defined as either a `Flux.Chain` or a `Lux.AbstractLuxLayer`.
+* `chain`: A neural network architecture, defined as either a `Flux.Chain` or a
+  `Lux.AbstractLuxLayer`.
 * `opt`: The optimizer to train the neural network.
 * `init_params`: The initial parameter of the neural network. By default, this is `nothing`
   which thus uses the random initialization provided by the neural network library.
 
 ## Keyword Arguments
 
-* `autodiff`: The switch between automatic(not supported yet) and numerical differentiation for
-  the PDE operators. The reverse mode of the loss function is always
+* `autodiff`: The switch between automatic (not supported yet) and numerical differentiation
+  for the PDE operators. The reverse mode of the loss function is always
   automatic differentiation (via Zygote), this is only for the derivative in the loss
   function (the derivative with respect to time).
 * `strategy`: The training strategy used to choose the points for the evaluations. By
   default, `GridTraining` is used with `dt` if given.
 """
-struct NNDAE{C, O, P, K, S <: Union{Nothing, AbstractTrainingStrategy}
-} <: SciMLBase.AbstractDAEAlgorithm
-    chain::C
-    opt::O
-    init_params::P
+@concrete struct NNDAE <: SciMLBase.AbstractDAEAlgorithm
+    chain <: AbstractLuxLayer
+    opt
+    init_params
     autodiff::Bool
-    strategy::S
-    kwargs::K
+    strategy <: Union{Nothing, AbstractTrainingStrategy}
+    kwargs
 end
 
 function NNDAE(chain, opt, init_params = nothing; strategy = nothing, autodiff = false,
         kwargs...)
-    !(chain isa Lux.AbstractLuxLayer) &&
-        (chain = adapt(FromFluxAdaptor(false, false), chain))
-    NNDAE(chain, opt, init_params, autodiff, strategy, kwargs)
+    chain isa Lux.AbstractLuxLayer || (chain = FromFluxAdaptor()(chain))
+    return NNDAE(chain, opt, init_params, autodiff, strategy, kwargs)
end
 
 function dfdx(phi::ODEPhi, t::AbstractVector, θ, autodiff::Bool,
         differential_vars::AbstractVector)
-    if autodiff
-        autodiff && throw(ArgumentError("autodiff not supported for DAE problem."))
-    else
-        dphi = (phi(t .+ sqrt(eps(eltype(t))), θ) - phi(t, θ)) ./ sqrt(eps(eltype(t)))
-        batch_size = size(t)[1]
-        reduce(vcat,
-            [dv ? dphi[[i], :] : zeros(1, batch_size)
-             for (i, dv) in enumerate(differential_vars)])
-    end
+    autodiff && throw(ArgumentError("autodiff not supported for DAE problem."))
+    ϵ = sqrt(eps(eltype(t)))
+    dϕ = (phi(t .+ ϵ, θ) .- phi(t, θ)) ./ ϵ
+    return reduce(vcat,
+        [dv ? dϕ[i:i, :] : zeros(eltype(dϕ), 1, size(dϕ, 2))
+         for (i, dv) in enumerate(differential_vars)])
 end
 
-function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector, θ,
-        p, differential_vars::AbstractVector) where {C, T, U}
-    out = Array(phi(t, θ))
-    dphi = Array(dfdx(phi, t, θ, autodiff, differential_vars))
-    arrt = Array(t)
-    loss = reduce(hcat, [f(dphi[:, i], out[:, i], p, arrt[i]) for i in 1:size(out, 2)])
-    sum(abs2, loss) / length(t)
+function inner_loss(phi::ODEPhi, f, autodiff::Bool, t::AbstractVector,
+        θ, p, differential_vars::AbstractVector)
+    out = phi(t, θ)
+    dphi = dfdx(phi, t, θ, autodiff, differential_vars)
+    return mapreduce(+, enumerate(t)) do (i, tᵢ)
+        sum(abs2, f(dphi[:, i], out[:, i], p, tᵢ))
+    end / length(t)
 end
 
 function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p,
         differential_vars::AbstractVector)
-    ts = tspan[1]:(strategy.dx):tspan[2]
     autodiff && throw(ArgumentError("autodiff not supported for GridTraining."))
-    function loss(θ, _)
-        sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p, differential_vars))
-    end
-    return loss
+    ts = tspan[1]:(strategy.dx):tspan[2]
+    return (θ, _) -> sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p, differential_vars))
 end
 
 function SciMLBase.__solve(prob::SciMLBase.AbstractDAEProblem,
@@ -92,31 +82,12 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractDAEProblem,
         saveat = nothing,
         maxiters = nothing,
         tstops = nothing)
-    u0 = prob.u0
-    du0 = prob.du0
-    tspan = prob.tspan
-    f = prob.f
-    p = prob.p
+    (; u0, tspan, f, p, differential_vars) = prob
     t0 = tspan[1]
+    (; chain, opt, autodiff, init_params) = alg
 
-    #hidden layer
-    chain = alg.chain
-    opt = alg.opt
-    autodiff = alg.autodiff
-
-    #train points generation
-    init_params = alg.init_params
-
-    # A logical array which declares which variables are the differential (non-algebraic) vars
-    differential_vars = prob.differential_vars
-
-    if chain isa Lux.AbstractLuxLayer || chain isa Flux.Chain
-        phi, init_params = generate_phi_θ(chain, t0, u0, init_params)
-        init_params = ComponentArray(;
-            depvar = ComponentArray(init_params))
-    else
-        error("Only Lux.AbstractLuxLayer and Flux.Chain neural networks are supported")
-    end
+    phi, init_params = generate_phi_θ(chain, t0, u0, init_params)
+    init_params = ComponentArray(; depvar = init_params)
 
     if isinplace(prob)
         throw(error("The NNODE solver only supports out-of-place DAE definitions, i.e. du=f(u,p,t)."))
     end
@@ -133,29 +104,20 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractDAEProblem,
     end
 
     strategy = if alg.strategy === nothing
-        if dt !== nothing
-            GridTraining(dt)
-        else
-            error("dt is not defined")
-        end
+        dt === nothing && error("`dt` is not defined")
+        GridTraining(dt)
     end
 
     inner_f = generate_loss(strategy, phi, f, autodiff, tspan, p, differential_vars)
-    # Creates OptimizationFunction Object from total_loss
     total_loss(θ, _) = inner_f(θ, phi)
+    optf = OptimizationFunction(total_loss, AutoZygote())
 
-    # Optimization Algo for Training Strategies
-    opt_algo = Optimization.AutoZygote()
-    # Creates OptimizationFunction Object from total_loss
-    optf = OptimizationFunction(total_loss, opt_algo)
-
-    iteration = 0
     callback = function (p, l)
-        iteration += 1
-        verbose && println("Current loss is: $l, Iteration: $iteration")
-        l < abstol
+        verbose && println("Current loss is: $l, Iteration: $(p.iter)")
+        return l < abstol
     end
+
     optprob = OptimizationProblem(optf, init_params)
     res = solve(optprob, opt; callback, maxiters, alg.kwargs...)
@@ -178,14 +140,11 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractDAEProblem,
         u = [phi(t, res.u) for t in ts]
     end
 
-    sol = SciMLBase.build_solution(prob, alg, ts, u;
-        k = res, dense = true,
-        calculate_error = false,
-        retcode = ReturnCode.Success,
-        original = res,
+    sol = SciMLBase.build_solution(prob, alg, ts, u; k = res, dense = true,
+        calculate_error = false, retcode = ReturnCode.Success, original = res,
         resid = res.objective)
 
     SciMLBase.has_analytic(prob.f) &&
         SciMLBase.calculate_solution_errors!(sol; timeseries_errors = true,
            dense_errors = false)
 
-    sol
+    return sol
 end
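For reference, the masked forward difference introduced in `dfdx` can be exercised on its
own. A minimal sketch, assuming a toy trial function in place of the real `ODEPhi` (the
names `toy_phi` and `masked_dfdx` below are illustrative, not package API):

# Toy trial solution: 2 states × length(t) matrix, standing in for `ODEPhi`.
toy_phi(t, θ) = [sin.(θ[1] .* t'); cos.(θ[2] .* t')]

function masked_dfdx(phi, t::AbstractVector, θ, differential_vars::AbstractVector)
    ϵ = sqrt(eps(eltype(t)))
    dϕ = (phi(t .+ ϵ, θ) .- phi(t, θ)) ./ ϵ
    # Rows belonging to algebraic variables (differential_vars[i] == false) are zeroed,
    # so they contribute no du term to the DAE residual f(du, u, p, t).
    return reduce(vcat,
        [dv ? dϕ[i:i, :] : zeros(eltype(dϕ), 1, size(dϕ, 2))
         for (i, dv) in enumerate(differential_vars)])
end

masked_dfdx(toy_phi, collect(0.0:0.25:1.0), [1.0, 2.0], [true, false])  # second row is zero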
diff --git a/src/ode_solve.jl b/src/ode_solve.jl
index 0febce8ea9..3bfb8c6741 100644
--- a/src/ode_solve.jl
+++ b/src/ode_solve.jl
@@ -8,7 +8,7 @@ Algorithm for solving ordinary differential equations using a neural network. This is a
 specialization of the physics-informed neural network which is used as a solver for a
 standard `ODEProblem`.
 
-!!! warn
+!!! warning
 
     Note that NNODE only supports ODEs which are written in the out-of-place form, i.e.
     `du = f(u,p,t)`, and not `f(du,u,p,t)`. If not declared out-of-place, then the NNODE
@@ -172,7 +172,7 @@ function inner_loss(phi::ODEPhi{<:Number}, f, autodiff::Bool, t::AbstractVector,
     p_ = param_estim ? θ.p : p
     out = phi(t, θ)
     fs = reduce(hcat, [f(out[i], p_, t[i]) for i in axes(out, 2)])
-    dxdtguess = Array(ode_dfdx(phi, t, θ, autodiff))
+    dxdtguess = ode_dfdx(phi, t, θ, autodiff)
     return sum(abs2, fs .- dxdtguess) / length(t)
 end
 
@@ -184,10 +184,9 @@ end
 function inner_loss(
         phi::ODEPhi, f, autodiff::Bool, t::AbstractVector, θ, p, param_estim::Bool)
     p_ = param_estim ? θ.p : p
-    out = Array(phi(t, θ))
-    arrt = Array(t)
-    fs = reduce(hcat, [f(out[:, i], p_, arrt[i]) for i in 1:size(out, 2)])
-    dxdtguess = Array(ode_dfdx(phi, t, θ, autodiff))
+    out = phi(t, θ)
+    fs = reduce(hcat, [f(out[:, i], p_, tᵢ) for (i, tᵢ) in enumerate(t)])
+    dxdtguess = ode_dfdx(phi, t, θ, autodiff)
     return sum(abs2, fs .- dxdtguess) / length(t)
 end
 
@@ -373,9 +372,7 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
         return L2_loss
     end
 
-    # Choice of Optimization Algo for Training Strategies
     opt_algo = ifelse(strategy isa QuadratureTraining, AutoForwardDiff(), AutoZygote())
-    # Creates OptimizationFunction Object from total_loss
     optf = OptimizationFunction(total_loss, opt_algo)
 
     callback = function (p, l)
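The refactored vector-valued `inner_loss` in ode_solve.jl minimizes the mean squared
mismatch between `f(u, p, t)` and the trial solution's time derivative. A self-contained
sketch of that residual, assuming hypothetical stand-ins `trial_u` and `fd_dudt` for
`ODEPhi` and `ode_dfdx`:

# Toy 1-state trial solution u(t) = θ₁·exp(θ₂·t), returned as a 1 × length(t) matrix.
trial_u(t, θ) = θ[1] .* exp.(θ[2] .* t')
fd_dudt(phi, t, θ) = (phi(t .+ sqrt(eps(eltype(t))), θ) .- phi(t, θ)) ./ sqrt(eps(eltype(t)))

function residual_loss(f, phi, t::AbstractVector, θ, p)
    out = phi(t, θ)
    # Residual fs .- du/dt, averaged over the training grid, as in the patch.
    fs = reduce(hcat, [f(out[:, i], p, tᵢ) for (i, tᵢ) in enumerate(t)])
    return sum(abs2, fs .- fd_dudt(phi, t, θ)) / length(t)
end

f(u, p, t) = p .* u                                               # du/dt = p * u
residual_loss(f, trial_u, collect(0.0:0.1:1.0), [1.0, 0.5], 0.5)  # ≈ 0 at the exact solution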
diff --git a/test/NNDAE_tests.jl b/test/NNDAE_tests.jl
index bbcf12dd6d..cc36fd09e8 100644
--- a/test/NNDAE_tests.jl
+++ b/test/NNDAE_tests.jl
@@ -1,7 +1,5 @@
-using Test, Flux
-using Random, NeuralPDE
-using OrdinaryDiffEq, Statistics
-import Lux, OptimizationOptimisers, OptimizationOptimJL
+using Test, Random, NeuralPDE, OrdinaryDiffEq, Statistics, Lux, Optimisers,
+      OptimizationOptimJL
 
 Random.seed!(100)
 
@@ -22,15 +20,12 @@ Random.seed!(100)
     ground_sol = solve(prob_mm, Rodas5(), reltol = 1e-8, abstol = 1e-8)
 
     example = (du, u, p, t) -> [cos(2pi * t) - du[1], u[2] + cos(2pi * t) - du[2]]
-    differential_vars = [true, false]
-    prob = DAEProblem(example, du₀, u₀, tspan; differential_vars = differential_vars)
-    chain = Lux.Chain(Lux.Dense(1, 15, cos), Lux.Dense(15, 15, sin), Lux.Dense(15, 2))
-    opt = OptimizationOptimisers.Adam(0.1)
-    alg = NeuralPDE.NNDAE(chain, opt; autodiff = false)
+    prob = DAEProblem(example, du₀, u₀, tspan; differential_vars = [true, false])
+    chain = Chain(Dense(1, 15, cos), Dense(15, 15, sin), Dense(15, 2))
+    alg = NNDAE(chain, Optimisers.Adam(0.01); autodiff = false)
 
-    sol = solve(prob,
-        alg, verbose = false, dt = 1 / 100.0f0,
-        maxiters = 3000, abstol = 1.0f-10)
+    sol = solve(
+        prob, alg, verbose = false, dt = 1 / 100.0f0, maxiters = 3000, abstol = 1.0f-10)
 
     @test ground_sol(0:(1 / 100):1)≈sol atol=0.4
 end
 
@@ -52,13 +47,11 @@ end
     example = (du, u, p, t) -> [u[1] - t - du[1], u[2] - t - du[2]]
     differential_vars = [false, true]
     prob = DAEProblem(example, du₀, u₀, tspan; differential_vars = differential_vars)
-    chain = Lux.Chain(Lux.Dense(1, 15, Lux.σ), Lux.Dense(15, 2))
-    opt = OptimizationOptimisers.Adam(0.1)
-    alg = NNDAE(chain, OptimizationOptimisers.Adam(0.1); autodiff = false)
+    chain = Chain(Dense(1, 15, σ), Dense(15, 2))
+    alg = NNDAE(chain, Optimisers.Adam(0.1); autodiff = false)
 
     sol = solve(prob,
-        alg, verbose = false, dt = 1 / 100.0f0,
-        maxiters = 3000, abstol = 1.0f-10)
+        alg, verbose = false, dt = 1 / 100.0f0, maxiters = 3000, abstol = 1.0f-10)
 
     @test ground_sol(0:(1 / 100):(pi / 2))≈sol atol=0.4
 end
diff --git a/test/NNODE_tstops_test.jl b/test/NNODE_tstops_test.jl
index d7de703b25..82f0278a5d 100644
--- a/test/NNODE_tstops_test.jl
+++ b/test/NNODE_tstops_test.jl
@@ -31,13 +31,13 @@ threshold = 0.2
     alg = NNODE(chain, opt; autodiff = false, strategy)
 
     @testset "Without added points" begin
-        sol = solve(prob_oop, alg; verbose = true, maxiters = 1000, saveat)
+        sol = solve(prob_oop, alg; verbose = false, maxiters = 1000, saveat)
         @test abs(mean(sol) - mean(true_sol)) > threshold
     end
 
     @testset "With added points" begin
         sol = solve(
-            prob_oop, alg; verbose = true, maxiters = 10000, saveat, tstops = addedPoints)
+            prob_oop, alg; verbose = false, maxiters = 10000, saveat, tstops = addedPoints)
         @test abs(mean(sol) - mean(true_sol)) < threshold
     end
 end
diff --git a/test/dgm_test.jl b/test/dgm_test.jl
index b4885058a0..7637ea921d 100644
--- a/test/dgm_test.jl
+++ b/test/dgm_test.jl
@@ -45,7 +45,7 @@ import ModelingToolkit: Interval, infimum, supremum
     u_real = reshape([analytic_sol_func(x, y) for x in xs for y in ys],
         (length(xs), length(ys)))
 
-    @test u_real≈u_predict atol=0.01 norm=Base.Fix2(norm, Inf)
+    @test u_real≈u_predict atol=0.1
 end
 
 @testset "Black-Scholes PDE: European Call Option" begin
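Putting the pieces together, end-to-end usage mirrors the first updated NNDAE test. A
sketch only: `u₀` and `du₀` are defined outside the hunks shown here, so the values below
are assumptions chosen to be consistent with `differential_vars = [true, false]`:

using NeuralPDE, OrdinaryDiffEq, Lux, Optimisers, Random
Random.seed!(100)

# Residual form 0 = f(du, u, p, t), with u[1] differential and u[2] algebraic.
example = (du, u, p, t) -> [cos(2pi * t) - du[1], u[2] + cos(2pi * t) - du[2]]
u₀, du₀, tspan = [1.0, -1.0], [1.0, 0.0], (0.0, 1.0)  # assumed initial conditions
prob = DAEProblem(example, du₀, u₀, tspan; differential_vars = [true, false])

chain = Chain(Dense(1, 15, cos), Dense(15, 15, sin), Dense(15, 2))
alg = NNDAE(chain, Optimisers.Adam(0.01); autodiff = false)
sol = solve(prob, alg; verbose = false, dt = 1 / 100.0, maxiters = 3000, abstol = 1e-10)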