Skip to content


refactor: cleanup NNDAE
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 15, 2024
1 parent 7e0e580 commit c98c2b9
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 116 deletions.
135 changes: 47 additions & 88 deletions src/dae_solve.jl
Original file line number Diff line number Diff line change
@@ -1,82 +1,72 @@
init_params = nothing;
autodiff = false,
NNDAE(chain, opt, init_params = nothing; autodiff = false, kwargs...)
Algorithm for solving differential algebraic equations using a neural network. This is a specialization
of the physics-informed neural network which is used as a solver for a standard `DAEProblem`.
Algorithm for solving differential algebraic equations using a neural network. This is a
specialization of the physics-informed neural network which is used as a solver for a
standard `DAEProblem`.
!!! warn
!!! warning
Note that NNDAE only supports DAEs which are written in the out-of-place form, i.e.
`du = f(du,u,p,t)`, and not `f(out,du,u,p,t)`. If not declared out-of-place, then the NNDAE
will exit with an error.
`du = f(du,u,p,t)`, and not `f(out,du,u,p,t)`. If not declared out-of-place, then the
NNDAE will exit with an error.
## Positional Arguments
* `chain`: A neural network architecture, defined as either a `Flux.Chain` or a `Lux.AbstractLuxLayer`.
* `chain`: A neural network architecture, defined as either a `Flux.Chain` or a
`Lux.AbstractLuxLayer`.
* `opt`: The optimizer to train the neural network.
* `init_params`: The initial parameter of the neural network. By default, this is `nothing`
which thus uses the random initialization provided by the neural network library.
## Keyword Arguments
* `autodiff`: The switch between automatic(not supported yet) and numerical differentiation for
the PDE operators. The reverse mode of the loss function is always
* `autodiff`: The switch between automatic (not supported yet) and numerical differentiation
for the PDE operators. The reverse mode of the loss function is always
automatic differentiation (via Zygote), this is only for the derivative
in the loss function (the derivative with respect to time).
* `strategy`: The training strategy used to choose the points for the evaluations.
By default, `GridTraining` is used with `dt` if given.
struct NNDAE{C, O, P, K, S <: Union{Nothing, AbstractTrainingStrategy}
} <: SciMLBase.AbstractDAEAlgorithm
@concrete struct NNDAE <: SciMLBase.AbstractDAEAlgorithm
chain <: AbstractLuxLayer
strategy <: Union{Nothing, AbstractTrainingStrategy}

function NNDAE(chain, opt, init_params = nothing; strategy = nothing, autodiff = false,
!(chain isa Lux.AbstractLuxLayer) &&
(chain = adapt(FromFluxAdaptor(false, false), chain))
NNDAE(chain, opt, init_params, autodiff, strategy, kwargs)
chain isa Lux.AbstractLuxLayer || (chain = FromFluxAdaptor()(chain))
return NNDAE(chain, opt, init_params, autodiff, strategy, kwargs)

function dfdx(phi::ODEPhi, t::AbstractVector, θ, autodiff::Bool,
if autodiff
autodiff && throw(ArgumentError("autodiff not supported for DAE problem."))
dphi = (phi(t .+ sqrt(eps(eltype(t))), θ) - phi(t, θ)) ./ sqrt(eps(eltype(t)))
batch_size = size(t)[1]
[dv ? dphi[[i], :] : zeros(1, batch_size)
for (i, dv) in enumerate(differential_vars)])
autodiff && throw(ArgumentError("autodiff not supported for DAE problem."))
ϵ = sqrt(eps(eltype(t)))
dϕ = (phi(t .+ ϵ, θ) .- phi(t, θ)) ./ ϵ
return reduce(vcat,
[dv ? dϕ[i:i, :] : zeros(eltype(dϕ), 1, size(dϕ, 2))
for (i, dv) in enumerate(differential_vars)])

function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector, θ,
p, differential_vars::AbstractVector) where {C, T, U}
out = Array(phi(t, θ))
dphi = Array(dfdx(phi, t, θ, autodiff, differential_vars))
arrt = Array(t)
loss = reduce(hcat, [f(dphi[:, i], out[:, i], p, arrt[i]) for i in 1:size(out, 2)])
sum(abs2, loss) / length(t)
function inner_loss(phi::ODEPhi, f, autodiff::Bool, t::AbstractVector,
θ, p, differential_vars::AbstractVector)
out = phi(t, θ)
dphi = dfdx(phi, t, θ, autodiff, differential_vars)
return mapreduce(+, enumerate(t)) do (i, tᵢ)
sum(abs2, f(dphi[:, i], out[:, i], p, tᵢ))
end / length(t)

function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p,
ts = tspan[1]:(strategy.dx):tspan[2]
autodiff && throw(ArgumentError("autodiff not supported for GridTraining."))
function loss(θ, _)
sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p, differential_vars))
return loss
ts = tspan[1]:(strategy.dx):tspan[2]
return (θ, _) -> sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p, differential_vars))

function SciMLBase.__solve(prob::SciMLBase.AbstractDAEProblem,
Expand All @@ -92,31 +82,12 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractDAEProblem,
saveat = nothing,
maxiters = nothing,
tstops = nothing)
u0 = prob.u0
du0 = prob.du0
tspan = prob.tspan
f = prob.f
p = prob.p
(; u0, tspan, f, p, differential_vars) = prob
t0 = tspan[1]
(; chain, opt, autodiff, init_params) = alg

#hidden layer
chain = alg.chain
opt = alg.opt
autodiff = alg.autodiff

#train points generation
init_params = alg.init_params

# A logical array which declares which variables are the differential (non-algebraic) vars
differential_vars = prob.differential_vars

if chain isa Lux.AbstractLuxLayer || chain isa Flux.Chain
phi, init_params = generate_phi_θ(chain, t0, u0, init_params)
init_params = ComponentArray(;
depvar = ComponentArray(init_params))
error("Only Lux.AbstractLuxLayer and Flux.Chain neural networks are supported")
phi, init_params = generate_phi_θ(chain, t0, u0, init_params)
init_params = ComponentArray(; depvar = init_params)

if isinplace(prob)
throw(error("The NNODE solver only supports out-of-place DAE definitions, i.e. du=f(u,p,t)."))
Expand All @@ -133,29 +104,20 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractDAEProblem,

strategy = if alg.strategy === nothing
if dt !== nothing
error("dt is not defined")
dt === nothing && error("`dt` is not defined")

inner_f = generate_loss(strategy, phi, f, autodiff, tspan, p, differential_vars)

# Creates OptimizationFunction Object from total_loss
total_loss(θ, _) = inner_f(θ, phi)
optf = OptimizationFunction(total_loss, AutoZygote())

# Optimization Algo for Training Strategies
opt_algo = Optimization.AutoZygote()
# Creates OptimizationFunction Object from total_loss
optf = OptimizationFunction(total_loss, opt_algo)

iteration = 0
callback = function (p, l)
iteration += 1
verbose && println("Current loss is: $l, Iteration: $iteration")
l < abstol
verbose && println("Current loss is: $l, Iteration: $(p.iter)")
return l < abstol

optprob = OptimizationProblem(optf, init_params)
res = solve(optprob, opt; callback, maxiters, alg.kwargs...)

Expand All @@ -178,14 +140,11 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractDAEProblem,
u = [phi(t, res.u) for t in ts]

sol = SciMLBase.build_solution(prob, alg, ts, u;
k = res, dense = true,
calculate_error = false,
retcode = ReturnCode.Success,
original = res,
sol = SciMLBase.build_solution(prob, alg, ts, u; k = res, dense = true,
calculate_error = false, retcode = ReturnCode.Success, original = res,
resid = res.objective)
SciMLBase.has_analytic(prob.f) &&
SciMLBase.calculate_solution_errors!(sol; timeseries_errors = true,
dense_errors = false)
return sol
13 changes: 5 additions & 8 deletions src/ode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Algorithm for solving ordinary differential equations using a neural network. Th
specialization of the physics-informed neural network which is used as a solver for a
standard `ODEProblem`.
!!! warn
!!! warning
Note that NNODE only supports ODEs which are written in the out-of-place form, i.e.
`du = f(u,p,t)`, and not `f(du,u,p,t)`. If not declared out-of-place, then the NNODE
Expand Down Expand Up @@ -172,7 +172,7 @@ function inner_loss(phi::ODEPhi{<:Number}, f, autodiff::Bool, t::AbstractVector,
p_ = param_estim ? θ.p : p
out = phi(t, θ)
fs = reduce(hcat, [f(out[i], p_, t[i]) for i in axes(out, 2)])
dxdtguess = Array(ode_dfdx(phi, t, θ, autodiff))
dxdtguess = ode_dfdx(phi, t, θ, autodiff)
return sum(abs2, fs .- dxdtguess) / length(t)

Expand All @@ -184,10 +184,9 @@ end
function inner_loss(
phi::ODEPhi, f, autodiff::Bool, t::AbstractVector, θ, p, param_estim::Bool)
p_ = param_estim ? θ.p : p
out = Array(phi(t, θ))
arrt = Array(t)
fs = reduce(hcat, [f(out[:, i], p_, arrt[i]) for i in 1:size(out, 2)])
dxdtguess = Array(ode_dfdx(phi, t, θ, autodiff))
out = phi(t, θ)
fs = reduce(hcat, [f(out[:, i], p_, tᵢ) for (i, tᵢ) in enumerate(t)])
dxdtguess = ode_dfdx(phi, t, θ, autodiff)
return sum(abs2, fs .- dxdtguess) / length(t)

Expand Down Expand Up @@ -373,9 +372,7 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
return L2_loss

# Choice of Optimization Algo for Training Strategies
opt_algo = ifelse(strategy isa QuadratureTraining, AutoForwardDiff(), AutoZygote())
# Creates OptimizationFunction Object from total_loss
optf = OptimizationFunction(total_loss, opt_algo)

callback = function (p, l)
Expand Down
27 changes: 10 additions & 17 deletions test/NNDAE_tests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
using Test, Flux
using Random, NeuralPDE
using OrdinaryDiffEq, Statistics
import Lux, OptimizationOptimisers, OptimizationOptimJL
using Test, Random, NeuralPDE, OrdinaryDiffEq, Statistics, Lux, Optimisers,
OptimizationOptimJL, Optimisers


Expand All @@ -22,15 +20,12 @@ Random.seed!(100)
ground_sol = solve(prob_mm, Rodas5(), reltol = 1e-8, abstol = 1e-8)

example = (du, u, p, t) -> [cos(2pi * t) - du[1], u[2] + cos(2pi * t) - du[2]]
differential_vars = [true, false]
prob = DAEProblem(example, du₀, u₀, tspan; differential_vars = differential_vars)
chain = Lux.Chain(Lux.Dense(1, 15, cos), Lux.Dense(15, 15, sin), Lux.Dense(15, 2))
opt = OptimizationOptimisers.Adam(0.1)
alg = NeuralPDE.NNDAE(chain, opt; autodiff = false)
prob = DAEProblem(example, du₀, u₀, tspan; differential_vars = [true, false])
chain = Chain(Dense(1, 15, cos), Dense(15, 15, sin), Dense(15, 2))
alg = NNDAE(chain, Optimisers.Adam(0.01); autodiff = false)

sol = solve(prob,
alg, verbose = false, dt = 1 / 100.0f0,
maxiters = 3000, abstol = 1.0f-10)
sol = solve(
prob, alg, verbose = false, dt = 1 / 100.0f0, maxiters = 3000, abstol = 1.0f-10)
@test ground_sol(0:(1 / 100):1)≈sol atol=0.4

Expand All @@ -52,13 +47,11 @@ end
example = (du, u, p, t) -> [u[1] - t - du[1], u[2] - t - du[2]]
differential_vars = [false, true]
prob = DAEProblem(example, du₀, u₀, tspan; differential_vars = differential_vars)
chain = Lux.Chain(Lux.Dense(1, 15, Lux.σ), Lux.Dense(15, 2))
opt = OptimizationOptimisers.Adam(0.1)
alg = NNDAE(chain, OptimizationOptimisers.Adam(0.1); autodiff = false)
chain = Chain(Dense(1, 15, σ), Dense(15, 2))
alg = NNDAE(chain, Optimisers.Adam(0.1); autodiff = false)

sol = solve(prob,
alg, verbose = false, dt = 1 / 100.0f0,
maxiters = 3000, abstol = 1.0f-10)
alg, verbose = false, dt = 1 / 100.0f0, maxiters = 3000, abstol = 1.0f-10)

@test ground_sol(0:(1 / 100):(pi / 2))≈sol atol=0.4
4 changes: 2 additions & 2 deletions test/NNODE_tstops_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ threshold = 0.2
alg = NNODE(chain, opt; autodiff = false, strategy)

@testset "Without added points" begin
sol = solve(prob_oop, alg; verbose = true, maxiters = 1000, saveat)
sol = solve(prob_oop, alg; verbose = false, maxiters = 1000, saveat)
@test abs(mean(sol) - mean(true_sol)) > threshold

@testset "With added points" begin
sol = solve(
prob_oop, alg; verbose = true, maxiters = 10000, saveat, tstops = addedPoints)
prob_oop, alg; verbose = false, maxiters = 10000, saveat, tstops = addedPoints)
@test abs(mean(sol) - mean(true_sol)) < threshold
2 changes: 1 addition & 1 deletion test/dgm_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ import ModelingToolkit: Interval, infimum, supremum
u_real = reshape([analytic_sol_func(x, y) for x in xs for y in ys],
(length(xs), length(ys)))

@test u_real≈u_predict atol=0.01 norm=Base.Fix2(norm, Inf)
@test u_real≈u_predict atol=0.1

@testset "Black-Scholes PDE: European Call Option" begin
Expand Down

0 comments on commit c98c2b9

Please sign in to comment.