Skip to content

Commit

Permalink
refactor: cleanup NNDAE
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 15, 2024
1 parent 7e0e580 commit c98c2b9
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 116 deletions.
135 changes: 47 additions & 88 deletions src/dae_solve.jl
Original file line number Diff line number Diff line change
@@ -1,82 +1,72 @@
"""
NNDAE(chain,
OptimizationOptimisers.Adam(0.1),
init_params = nothing;
autodiff = false,
kwargs...)
NNDAE(chain, opt, init_params = nothing; autodiff = false, kwargs...)
Algorithm for solving differential algebraic equationsusing a neural network. This is a specialization
of the physics-informed neural network which is used as a solver for a standard `DAEProblem`.
Algorithm for solving differential algebraic equationsusing a neural network. This is a
specialization of the physics-informed neural network which is used as a solver for a
standard `DAEProblem`.
!!! warn
!!! warning
Note that NNDAE only supports DAEs which are written in the out-of-place form, i.e.
`du = f(du,u,p,t)`, and not `f(out,du,u,p,t)`. If not declared out-of-place, then the NNDAE
will exit with an error.
`du = f(du,u,p,t)`, and not `f(out,du,u,p,t)`. If not declared out-of-place, then the
NNDAE will exit with an error.
## Positional Arguments
* `chain`: A neural network architecture, defined as either a `Flux.Chain` or a `Lux.AbstractLuxLayer`.
* `chain`: A neural network architecture, defined as either a `Flux.Chain` or a
`Lux.AbstractLuxLayer`.
* `opt`: The optimizer to train the neural network.
* `init_params`: The initial parameter of the neural network. By default, this is `nothing`
which thus uses the random initialization provided by the neural network library.
## Keyword Arguments
* `autodiff`: The switch between automatic(not supported yet) and numerical differentiation for
the PDE operators. The reverse mode of the loss function is always
* `autodiff`: The switch between automatic (not supported yet) and numerical differentiation
for the PDE operators. The reverse mode of the loss function is always
automatic differentiation (via Zygote), this is only for the derivative
in the loss function (the derivative with respect to time).
* `strategy`: The training strategy used to choose the points for the evaluations.
By default, `GridTraining` is used with `dt` if given.
"""
struct NNDAE{C, O, P, K, S <: Union{Nothing, AbstractTrainingStrategy}
} <: SciMLBase.AbstractDAEAlgorithm
chain::C
opt::O
init_params::P
@concrete struct NNDAE <: SciMLBase.AbstractDAEAlgorithm
chain <: AbstractLuxLayer
opt
init_params
autodiff::Bool
strategy::S
kwargs::K
strategy <: Union{Nothing, AbstractTrainingStrategy}
kwargs
end

function NNDAE(chain, opt, init_params = nothing; strategy = nothing, autodiff = false,
kwargs...)
!(chain isa Lux.AbstractLuxLayer) &&
(chain = adapt(FromFluxAdaptor(false, false), chain))
NNDAE(chain, opt, init_params, autodiff, strategy, kwargs)
chain isa Lux.AbstractLuxLayer || (chain = FromFluxAdaptor()(chain))
return NNDAE(chain, opt, init_params, autodiff, strategy, kwargs)
end

function dfdx(phi::ODEPhi, t::AbstractVector, θ, autodiff::Bool,
differential_vars::AbstractVector)
if autodiff
autodiff && throw(ArgumentError("autodiff not supported for DAE problem."))
else
dphi = (phi(t .+ sqrt(eps(eltype(t))), θ) - phi(t, θ)) ./ sqrt(eps(eltype(t)))
batch_size = size(t)[1]
reduce(vcat,
[dv ? dphi[[i], :] : zeros(1, batch_size)
for (i, dv) in enumerate(differential_vars)])
end
autodiff && throw(ArgumentError("autodiff not supported for DAE problem."))
ϵ = sqrt(eps(eltype(t)))
= (phi(t .+ ϵ, θ) .- phi(t, θ)) ./ ϵ
return reduce(vcat,
[dv ? dϕ[i:i, :] : zeros(eltype(dϕ), 1, size(dϕ, 2))
for (i, dv) in enumerate(differential_vars)])
end

function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector, θ,
p, differential_vars::AbstractVector) where {C, T, U}
out = Array(phi(t, θ))
dphi = Array(dfdx(phi, t, θ, autodiff, differential_vars))
arrt = Array(t)
loss = reduce(hcat, [f(dphi[:, i], out[:, i], p, arrt[i]) for i in 1:size(out, 2)])
sum(abs2, loss) / length(t)
function inner_loss(phi::ODEPhi, f, autodiff::Bool, t::AbstractVector,
θ, p, differential_vars::AbstractVector)
out = phi(t, θ)
dphi = dfdx(phi, t, θ, autodiff, differential_vars)
return mapreduce(+, enumerate(t)) do (i, tᵢ)
sum(abs2, f(dphi[:, i], out[:, i], p, tᵢ))
end / length(t)
end

function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p,
differential_vars::AbstractVector)
ts = tspan[1]:(strategy.dx):tspan[2]
autodiff && throw(ArgumentError("autodiff not supported for GridTraining."))
function loss(θ, _)
sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p, differential_vars))
end
return loss
ts = tspan[1]:(strategy.dx):tspan[2]
return (θ, _) -> sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p, differential_vars))
end

function SciMLBase.__solve(prob::SciMLBase.AbstractDAEProblem,
Expand All @@ -92,31 +82,12 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractDAEProblem,
saveat = nothing,
maxiters = nothing,
tstops = nothing)
u0 = prob.u0
du0 = prob.du0
tspan = prob.tspan
f = prob.f
p = prob.p
(; u0, tspan, f, p, differential_vars) = prob
t0 = tspan[1]
(; chain, opt, autodiff, init_params) = alg

#hidden layer
chain = alg.chain
opt = alg.opt
autodiff = alg.autodiff

#train points generation
init_params = alg.init_params

# A logical array which declares which variables are the differential (non-algebraic) vars
differential_vars = prob.differential_vars

if chain isa Lux.AbstractLuxLayer || chain isa Flux.Chain
phi, init_params = generate_phi_θ(chain, t0, u0, init_params)
init_params = ComponentArray(;
depvar = ComponentArray(init_params))
else
error("Only Lux.AbstractLuxLayer and Flux.Chain neural networks are supported")
end
phi, init_params = generate_phi_θ(chain, t0, u0, init_params)
init_params = ComponentArray(; depvar = init_params)

if isinplace(prob)
throw(error("The NNODE solver only supports out-of-place DAE definitions, i.e. du=f(u,p,t)."))
Expand All @@ -133,29 +104,20 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractDAEProblem,
end

strategy = if alg.strategy === nothing
if dt !== nothing
GridTraining(dt)
else
error("dt is not defined")
end
dt === nothing && error("`dt` is not defined")
GridTraining(dt)
end

inner_f = generate_loss(strategy, phi, f, autodiff, tspan, p, differential_vars)

# Creates OptimizationFunction Object from total_loss
total_loss(θ, _) = inner_f(θ, phi)
optf = OptimizationFunction(total_loss, AutoZygote())

# Optimization Algo for Training Strategies
opt_algo = Optimization.AutoZygote()
# Creates OptimizationFunction Object from total_loss
optf = OptimizationFunction(total_loss, opt_algo)

iteration = 0
callback = function (p, l)
iteration += 1
verbose && println("Current loss is: $l, Iteration: $iteration")
l < abstol
verbose && println("Current loss is: $l, Iteration: $(p.iter)")
return l < abstol
end

optprob = OptimizationProblem(optf, init_params)
res = solve(optprob, opt; callback, maxiters, alg.kwargs...)

Expand All @@ -178,14 +140,11 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractDAEProblem,
u = [phi(t, res.u) for t in ts]
end

sol = SciMLBase.build_solution(prob, alg, ts, u;
k = res, dense = true,
calculate_error = false,
retcode = ReturnCode.Success,
original = res,
sol = SciMLBase.build_solution(prob, alg, ts, u; k = res, dense = true,
calculate_error = false, retcode = ReturnCode.Success, original = res,
resid = res.objective)
SciMLBase.has_analytic(prob.f) &&
SciMLBase.calculate_solution_errors!(sol; timeseries_errors = true,
dense_errors = false)
sol
return sol
end
13 changes: 5 additions & 8 deletions src/ode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Algorithm for solving ordinary differential equations using a neural network. Th
specialization of the physics-informed neural network which is used as a solver for a
standard `ODEProblem`.
!!! warn
!!! warning
Note that NNODE only supports ODEs which are written in the out-of-place form, i.e.
`du = f(u,p,t)`, and not `f(du,u,p,t)`. If not declared out-of-place, then the NNODE
Expand Down Expand Up @@ -172,7 +172,7 @@ function inner_loss(phi::ODEPhi{<:Number}, f, autodiff::Bool, t::AbstractVector,
p_ = param_estim ? θ.p : p
out = phi(t, θ)
fs = reduce(hcat, [f(out[i], p_, t[i]) for i in axes(out, 2)])
dxdtguess = Array(ode_dfdx(phi, t, θ, autodiff))
dxdtguess = ode_dfdx(phi, t, θ, autodiff)
return sum(abs2, fs .- dxdtguess) / length(t)
end

Expand All @@ -184,10 +184,9 @@ end
function inner_loss(
phi::ODEPhi, f, autodiff::Bool, t::AbstractVector, θ, p, param_estim::Bool)
p_ = param_estim ? θ.p : p
out = Array(phi(t, θ))
arrt = Array(t)
fs = reduce(hcat, [f(out[:, i], p_, arrt[i]) for i in 1:size(out, 2)])
dxdtguess = Array(ode_dfdx(phi, t, θ, autodiff))
out = phi(t, θ)
fs = reduce(hcat, [f(out[:, i], p_, tᵢ) for (i, tᵢ) in enumerate(t)])
dxdtguess = ode_dfdx(phi, t, θ, autodiff)
return sum(abs2, fs .- dxdtguess) / length(t)
end

Expand Down Expand Up @@ -373,9 +372,7 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
return L2_loss
end

# Choice of Optimization Algo for Training Strategies
opt_algo = ifelse(strategy isa QuadratureTraining, AutoForwardDiff(), AutoZygote())
# Creates OptimizationFunction Object from total_loss
optf = OptimizationFunction(total_loss, opt_algo)

callback = function (p, l)
Expand Down
27 changes: 10 additions & 17 deletions test/NNDAE_tests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
using Test, Flux
using Random, NeuralPDE
using OrdinaryDiffEq, Statistics
import Lux, OptimizationOptimisers, OptimizationOptimJL
using Test, Random, NeuralPDE, OrdinaryDiffEq, Statistics, Lux, Optimisers,
OptimizationOptimJL, Optimisers

Random.seed!(100)

Expand All @@ -22,15 +20,12 @@ Random.seed!(100)
ground_sol = solve(prob_mm, Rodas5(), reltol = 1e-8, abstol = 1e-8)

example = (du, u, p, t) -> [cos(2pi * t) - du[1], u[2] + cos(2pi * t) - du[2]]
differential_vars = [true, false]
prob = DAEProblem(example, du₀, u₀, tspan; differential_vars = differential_vars)
chain = Lux.Chain(Lux.Dense(1, 15, cos), Lux.Dense(15, 15, sin), Lux.Dense(15, 2))
opt = OptimizationOptimisers.Adam(0.1)
alg = NeuralPDE.NNDAE(chain, opt; autodiff = false)
prob = DAEProblem(example, du₀, u₀, tspan; differential_vars = [true, false])
chain = Chain(Dense(1, 15, cos), Dense(15, 15, sin), Dense(15, 2))
alg = NNDAE(chain, Optimisers.Adam(0.01); autodiff = false)

sol = solve(prob,
alg, verbose = false, dt = 1 / 100.0f0,
maxiters = 3000, abstol = 1.0f-10)
sol = solve(
prob, alg, verbose = false, dt = 1 / 100.0f0, maxiters = 3000, abstol = 1.0f-10)
@test ground_sol(0:(1 / 100):1)sol atol=0.4
end

Expand All @@ -52,13 +47,11 @@ end
example = (du, u, p, t) -> [u[1] - t - du[1], u[2] - t - du[2]]
differential_vars = [false, true]
prob = DAEProblem(example, du₀, u₀, tspan; differential_vars = differential_vars)
chain = Lux.Chain(Lux.Dense(1, 15, Lux.σ), Lux.Dense(15, 2))
opt = OptimizationOptimisers.Adam(0.1)
alg = NNDAE(chain, OptimizationOptimisers.Adam(0.1); autodiff = false)
chain = Chain(Dense(1, 15, σ), Dense(15, 2))
alg = NNDAE(chain, Optimisers.Adam(0.1); autodiff = false)

sol = solve(prob,
alg, verbose = false, dt = 1 / 100.0f0,
maxiters = 3000, abstol = 1.0f-10)
alg, verbose = false, dt = 1 / 100.0f0, maxiters = 3000, abstol = 1.0f-10)

@test ground_sol(0:(1 / 100):(pi / 2))sol atol=0.4
end
4 changes: 2 additions & 2 deletions test/NNODE_tstops_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ threshold = 0.2
alg = NNODE(chain, opt; autodiff = false, strategy)

@testset "Without added points" begin
sol = solve(prob_oop, alg; verbose = true, maxiters = 1000, saveat)
sol = solve(prob_oop, alg; verbose = false, maxiters = 1000, saveat)
@test abs(mean(sol) - mean(true_sol)) > threshold
end

@testset "With added points" begin
sol = solve(
prob_oop, alg; verbose = true, maxiters = 10000, saveat, tstops = addedPoints)
prob_oop, alg; verbose = false, maxiters = 10000, saveat, tstops = addedPoints)
@test abs(mean(sol) - mean(true_sol)) < threshold
end
end
2 changes: 1 addition & 1 deletion test/dgm_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ import ModelingToolkit: Interval, infimum, supremum
u_real = reshape([analytic_sol_func(x, y) for x in xs for y in ys],
(length(xs), length(ys)))

@test u_realu_predict atol=0.01 norm=Base.Fix2(norm, Inf)
@test u_realu_predict atol=0.1
end

@testset "Black-Scholes PDE: European Call Option" begin
Expand Down

0 comments on commit c98c2b9

Please sign in to comment.