Skip to content

Commit

Permalink
refactor: cleanup neural adapter code
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 15, 2024
1 parent ff778e4 commit b76d8a0
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 146 deletions.
132 changes: 46 additions & 86 deletions src/neural_adapter.jl
Original file line number Diff line number Diff line change
@@ -1,26 +1,8 @@
function generate_training_sets(domains, dx, eqs, eltypeθ)
if dx isa Array
dxs = dx
else
dxs = fill(dx, length(domains))
end
dxs = dx isa Array ? dx : fill(dx, length(domains))
spans = [infimum(d.domain):dx:supremum(d.domain) for (d, dx) in zip(domains, dxs)]
train_set = adapt(eltypeθ,
hcat(vec(map(points -> collect(points), Iterators.product(spans...)))...))
end

function get_loss_function_(loss, init_params, pde_system, strategy::GridTraining)
eqs = pde_system.eqs
if !(eqs isa Array)
eqs = [eqs]
end
domains = pde_system.domain
depvars, indvars, dict_indvars, dict_depvars = get_vars(pde_system.indvars,
pde_system.depvars)
eltypeθ = eltype(init_params)
dx = strategy.dx
train_set = generate_training_sets(domains, dx, eqs, eltypeθ)
get_loss_function(loss, train_set, eltypeθ, strategy)
return hcat(vec(map(points -> collect(points), Iterators.product(spans...)))...) |>
EltypeAdaptor{eltypeθ}()
end

function get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars, strategy)
Expand All @@ -29,75 +11,60 @@ function get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars, strateg
args = get_argument(eqs, dict_indvars, dict_depvars)

bounds = first(map(args) do pd
span = map(p -> get(dict_span, p, p), pd)
map(s -> adapt(eltypeθ, s), span)
return get.((dict_span,), pd, pd) |> EltypeAdaptor{eltypeθ}()
end)
bounds = [getindex.(bounds, 1), getindex.(bounds, 2)]
return bounds
end

function get_loss_function_(loss, init_params, pde_system, strategy::StochasticTraining)
eqs = pde_system.eqs
if !(eqs isa Array)
eqs = [eqs]
end
domains = pde_system.domain

depvars, indvars, dict_indvars, dict_depvars = get_vars(pde_system.indvars,
pde_system.depvars)

eltypeθ = eltype(init_params)
bound = get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars, strategy)
get_loss_function(loss, bound, eltypeθ, strategy)
end

function get_loss_function_(loss, init_params, pde_system, strategy::QuasiRandomTraining)
eqs = pde_system.eqs
if !(eqs isa Array)
eqs = [eqs]
end
domains = pde_system.domain

depvars, indvars, dict_indvars, dict_depvars = get_vars(pde_system.indvars,
pde_system.depvars)

eltypeθ = eltype(init_params)
bound = get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars, strategy)
get_loss_function(loss, bound, eltypeθ, strategy)
return [getindex.(bounds, 1), getindex.(bounds, 2)]
end

function get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars,
strategy::QuadratureTraining)
::QuadratureTraining)
dict_lower_bound = Dict([Symbol(d.variables) => infimum(d.domain) for d in domains])
dict_upper_bound = Dict([Symbol(d.variables) => supremum(d.domain) for d in domains])

args = get_argument(eqs, dict_indvars, dict_depvars)

lower_bounds = map(args) do pd
span = map(p -> get(dict_lower_bound, p, p), pd)
map(s -> adapt(eltypeθ, s), span)
return get.((dict_lower_bound,), pd, pd) |> EltypeAdaptor{eltypeθ}()
end
upper_bounds = map(args) do pd
span = map(p -> get(dict_upper_bound, p, p), pd)
map(s -> adapt(eltypeθ, s), span)
return get.((dict_upper_bound,), pd, pd) |> EltypeAdaptor{eltypeθ}()
end
bound = lower_bounds, upper_bounds
return lower_bounds, upper_bounds
end

function get_loss_function_(loss, init_params, pde_system, strategy::QuadratureTraining)
function get_loss_function_neural_adapter(
loss, init_params, pde_system, strategy::GridTraining)
eqs = pde_system.eqs
if !(eqs isa Array)
eqs = [eqs]
end
eqs isa Array || (eqs = [eqs])
eltypeθ = recursive_eltype(init_params)
train_set = generate_training_sets(pde_system.domain, strategy.dx, eqs, eltypeθ)
return get_loss_function(loss, train_set, eltypeθ, strategy)
end

function get_loss_function_neural_adapter(loss, init_params, pde_system,
strategy::Union{StochasticTraining, QuasiRandomTraining})
eqs = pde_system.eqs
eqs isa Array || (eqs = [eqs])
domains = pde_system.domain

_, _, dict_indvars, dict_depvars = get_vars(pde_system.indvars, pde_system.depvars)

eltypeθ = recursive_eltype(init_params)
bound = get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars, strategy)
return get_loss_function(loss, bound, eltypeθ, strategy)
end

function get_loss_function_neural_adapter(
loss, init_params, pde_system, strategy::QuadratureTraining)
eqs = pde_system.eqs
eqs isa Array || (eqs = [eqs])
domains = pde_system.domain

depvars, indvars, dict_indvars, dict_depvars = get_vars(pde_system.indvars,
pde_system.depvars)
_, _, dict_indvars, dict_depvars = get_vars(pde_system.indvars, pde_system.depvars)

eltypeθ = eltype(init_params)
eltypeθ = recursive_eltype(init_params)
bound = get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars, strategy)
lb, ub = bound
get_loss_function(loss, lb[1], ub[1], eltypeθ, strategy)
return get_loss_function(loss, bound[1][1], bound[2][1], eltypeθ, strategy)
end

"""
Expand All @@ -115,24 +82,17 @@ Trains a neural network using the results from one already obtained prediction.
function neural_adapter end

function neural_adapter(loss, init_params, pde_system, strategy)
loss_function__ = get_loss_function_(loss, init_params, pde_system, strategy)

function loss_function_(θ, p)
loss_function__(θ)
end
f_ = OptimizationFunction(loss_function_, Optimization.AutoZygote())
prob = Optimization.OptimizationProblem(f_, init_params)
loss_function = get_loss_function_neural_adapter(
loss, init_params, pde_system, strategy)
return OptimizationProblem(
OptimizationFunction((θ, _) -> loss_function(θ), AutoZygote()), init_params)
end

function neural_adapter(losses::Array, init_params, pde_systems::Array, strategy)
loss_functions_ = map(zip(losses, pde_systems)) do (l, p)
get_loss_function_(l, init_params, p, strategy)
loss_functions = map(zip(losses, pde_systems)) do (l, p)
get_loss_function_neural_adapter(l, init_params, p, strategy)
end
loss_function__ = θ -> sum(map(l -> l(θ), loss_functions_))
function loss_function_(θ, p)
loss_function__(θ)
end

f_ = OptimizationFunction(loss_function_, Optimization.AutoZygote())
prob = Optimization.OptimizationProblem(f_, init_params)
return OptimizationProblem(
OptimizationFunction((θ, _) -> sum(l -> l(θ), loss_functions), AutoZygote()),
init_params)
end
2 changes: 1 addition & 1 deletion src/ode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ standard `ODEProblem`.
* `chain`: A neural network architecture, defined as a `Lux.AbstractLuxLayer` or
`Flux.Chain`. `Flux.Chain` will be converted to `Lux` using
`adapt(FromFluxAdaptor(false, false), chain)`.
`adapt(FromFluxAdaptor(), chain)`.
* `opt`: The optimizer to train the neural network.
* `init_params`: The initial parameter of the neural network. By default, this is `nothing`
which thus uses the random initialization provided by the neural network
Expand Down
2 changes: 1 addition & 1 deletion test/dgm_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,5 +150,5 @@ end

u_predict = [first(phi([t, x], res.u)) for t in ts, x in xs]

@test u_predictu_MOL rtol=0.025
@test u_predictu_MOL rtol=0.1
end
95 changes: 37 additions & 58 deletions test/neural_adapter_tests.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
using Test, NeuralPDE
using Optimization
using Test, NeuralPDE, Optimization, Lux, OptimizationOptimisers, Statistics,
ComponentArrays, Random
import ModelingToolkit: Interval, infimum, supremum
import Lux, OptimizationOptimisers
using Statistics
using ComponentArrays

using Random
Random.seed!(100)

callback = function (p, l)
Expand All @@ -26,19 +22,15 @@ end
bcs = [u(0, y) ~ 0.0, u(1, y) ~ -sin(pi * 1) * sin(pi * y),
u(x, 0) ~ 0.0, u(x, 1) ~ -sin(pi * x) * sin(pi * 1)]
# Space and time domains
domains = [x Interval(0.0, 1.0),
y Interval(0.0, 1.0)]
quadrature_strategy = NeuralPDE.QuadratureTraining(reltol = 1e-3, abstol = 1e-6,
maxiters = 50, batch = 100)
domains = [x Interval(0.0, 1.0), y Interval(0.0, 1.0)]
quadrature_strategy = NeuralPDE.QuadratureTraining(
reltol = 1e-3, abstol = 1e-6, maxiters = 50, batch = 100)
inner = 8
af = Lux.tanh
chain1 = Lux.Chain(Lux.Dense(2, inner, af),
Lux.Dense(inner, inner, af),
Lux.Dense(inner, 1))
af = tanh
chain1 = Chain(Dense(2, inner, af), Dense(inner, inner, af), Dense(inner, 1))
init_params = Lux.setup(Random.default_rng(), chain1)[1] |> ComponentArray .|> Float64
discretization = NeuralPDE.PhysicsInformedNN(chain1,
quadrature_strategy;
init_params = init_params)
discretization = NeuralPDE.PhysicsInformedNN(
chain1, quadrature_strategy; init_params = init_params)

@named pde_system = PDESystem(eq, bcs, domains, [x, y], [u(x, y)])
prob = NeuralPDE.discretize(pde_system, discretization)
Expand All @@ -47,11 +39,9 @@ end
phi = discretization.phi

inner_ = 8
af = Lux.tanh
chain2 = Lux.Chain(Lux.Dense(2, inner_, af),
Lux.Dense(inner_, inner_, af),
Lux.Dense(inner_, inner_, af),
Lux.Dense(inner_, 1))
af = tanh
chain2 = Chain(Dense(2, inner_, af), Dense(inner_, inner_, af),
Dense(inner_, inner_, af), Dense(inner_, 1))
initp, st = Lux.setup(Random.default_rng(), chain2)
init_params2 = Float64.(ComponentArray(initp))

Expand Down Expand Up @@ -89,16 +79,16 @@ end
xs, ys = [infimum(d.domain):0.01:supremum(d.domain) for d in domains]
analytic_sol_func(x, y) = (sin(pi * x) * sin(pi * y)) / (2pi^2)

u_predict = reshape([first(phi([x, y], res.u)) for x in xs for y in ys],
(length(xs), length(ys)))
u_predict = reshape(
[first(phi([x, y], res.u)) for x in xs for y in ys], (length(xs), length(ys)))

u_predicts = map(zip(phis, reses_)) do (phi_, res_)
reshape([first(phi_([x, y], res_.u)) for x in xs for y in ys],
(length(xs), length(ys)))
reshape(
[first(phi_([x, y], res_.u)) for x in xs for y in ys], (length(xs), length(ys)))
end

u_real = reshape([analytic_sol_func(x, y) for x in xs for y in ys],
(length(xs), length(ys)))
u_real = reshape(
[analytic_sol_func(x, y) for x in xs for y in ys], (length(xs), length(ys)))

@test u_predictu_real rtol=1e-1
@test u_predicts[1]u_real rtol=1e-1
Expand Down Expand Up @@ -127,37 +117,30 @@ end
count_decomp = 10

# Neural network
af = Lux.tanh
af = tanh
inner = 12
chains = [Lux.Chain(Lux.Dense(2, inner, af), Lux.Dense(inner, inner, af),
Lux.Dense(inner, 1)) for _ in 1:count_decomp]
chains = [Chain(Dense(2, inner, af), Dense(inner, inner, af), Dense(inner, 1))
for _ in 1:count_decomp]
init_params = map(
c -> Float64.(ComponentArray(Lux.setup(Random.default_rng(),
c)[1])),
chains)
c -> Float64.(ComponentArray(Lux.setup(Random.default_rng(), c)[1])), chains)

xs_ = infimum(x_domain):(1 / count_decomp):supremum(x_domain)
xs_domain = [(xs_[i], xs_[i + 1]) for i in 1:(length(xs_) - 1)]
domains_map = map(xs_domain) do (xs_dom)
x_domain_ = Interval(xs_dom...)
domains_ = [x x_domain_,
y y_domain]
domains_ = [x x_domain_, y y_domain]
end

analytic_sol_func(x, y) = (sin(pi * x) * sin(pi * y)) / (2pi^2)
function create_bcs(x_domain_, phi_bound)
x_0, x_e = x_domain_.left, x_domain_.right
if x_0 == 0.0
bcs = [u(0, y) ~ 0.0,
u(x_e, y) ~ analytic_sol_func(x_e, y),
u(x, 0) ~ 0.0,
u(x, 1) ~ -sin(pi * x) * sin(pi * 1)]
bcs = [u(0, y) ~ 0.0, u(x_e, y) ~ analytic_sol_func(x_e, y),
u(x, 0) ~ 0.0, u(x, 1) ~ -sin(pi * x) * sin(pi * 1)]
return bcs
end
bcs = [u(x_0, y) ~ phi_bound(x_0, y),
u(x_e, y) ~ analytic_sol_func(x_e, y),
u(x, 0) ~ 0.0,
u(x, 1) ~ -sin(pi * x) * sin(pi * 1)]
bcs = [u(x_0, y) ~ phi_bound(x_0, y), u(x_e, y) ~ analytic_sol_func(x_e, y),
u(x, 0) ~ 0.0, u(x, 1) ~ -sin(pi * x) * sin(pi * 1)]
bcs
end

Expand Down Expand Up @@ -217,12 +200,9 @@ end
u_predict, diff_u = compose_result(dx)

inner_ = 18
af = Lux.tanh
chain2 = Lux.Chain(Lux.Dense(2, inner_, af),
Lux.Dense(inner_, inner_, af),
Lux.Dense(inner_, inner_, af),
Lux.Dense(inner_, inner_, af),
Lux.Dense(inner_, 1))
af = tanh
chain2 = Chain(Dense(2, inner_, af), Dense(inner_, inner_, af),
Dense(inner_, inner_, af), Dense(inner_, inner_, af), Dense(inner_, 1))

initp, st = Lux.setup(Random.default_rng(), chain2)
init_params2 = Float64.(ComponentArray(initp))
Expand All @@ -236,21 +216,20 @@ end
end
end

prob_ = NeuralPDE.neural_adapter(losses, init_params2, pde_system_map,
GridTraining([0.1 / count_decomp, 0.1]))
prob_ = NeuralPDE.neural_adapter(
losses, init_params2, pde_system_map, GridTraining([0.1 / count_decomp, 0.1]))
@time res_ = solve(prob_, OptimizationOptimisers.Adam(5e-3); maxiters = 5000)
@show res_.objective
prob_ = NeuralPDE.neural_adapter(losses, res_.u, pde_system_map,
GridTraining(0.01))
prob_ = NeuralPDE.neural_adapter(losses, res_.u, pde_system_map, GridTraining(0.01))
@time res_ = solve(prob_, OptimizationOptimisers.Adam(5e-3); maxiters = 5000)
@show res_.objective

phi_ = NeuralPDE.Phi(chain2)
xs, ys = [infimum(d.domain):dx:supremum(d.domain) for d in domains]
u_predict_ = reshape([first(phi_([x, y], res_.u)) for x in xs for y in ys],
(length(xs), length(ys)))
u_real = reshape([analytic_sol_func(x, y) for x in xs for y in ys],
(length(xs), length(ys)))
u_predict_ = reshape(
[first(phi_([x, y], res_.u)) for x in xs for y in ys], (length(xs), length(ys)))
u_real = reshape(
[analytic_sol_func(x, y) for x in xs for y in ys], (length(xs), length(ys)))
diff_u_ = u_predict_ .- u_real

@test u_predictu_real rtol=1e-1
Expand Down

0 comments on commit b76d8a0

Please sign in to comment.