Skip to content

Commit

Permalink
Interpolation
Browse files Browse the repository at this point in the history
  • Loading branch information
KirillZubov committed Oct 29, 2024
1 parent 35d8346 commit c5a456a
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 48 deletions.
59 changes: 50 additions & 9 deletions src/pino_ode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,26 +216,67 @@ function generate_loss(
end
end

"""
PINOODEInterpolation(phi, θ)
Interpolation of the solution of the ODE using a trained neural network.
## Arguments
* `phi`: The neural network
* `θ`: The parameters of the neural network.
```
"""
@concrete struct PINOODEInterpolation{T <: PINOPhi, T2}
phi::T
θ::T2
end

"""
Override interpolation method for PINOODEInterpolation
## Arguments
* `x`: Input data on which the solution is to be interpolated.
## Example
```jldoctest
interp = PINOODEInterpolation(phi, θ)
x = rand(2, 50, 10)
interp(x)
```
"""
(f::PINOODEInterpolation)(x) = f.phi(x, f.θ)

"""
Override interpolation method for PINOODEInterpolation
## Arguments
# * `p`: The parameters points on which the solution is to be interpolated.
# * `t`: The time points on which the solution is to be interpolated.
## Example
```jldoctest
interp = PINOODEInterpolation(phi, θ)
p,t = rand(1, 50, 10), rand(1, 50, 10)
interp(p, t)
```
"""
function (f::PINOODEInterpolation)(p, t)
if f.phi.model isa DeepONet
f.phi((p, t), f.θ)
elseif f.phi.model isa Chain
f.phi(reduce(vcat, (p, t)), f.θ)
else
error("Only DeepONet and Chain neural networks are supported with PINO ODE")
end
end

SciMLBase.interp_summary(::PINOODEInterpolation) = "Trained neural network interpolation"
SciMLBase.allowscomplex(::PINOODE) = true

#TODO
function (sol::SciMLBase.AbstractODESolution)(t::AbstractArray)
# p,t = sol.t
# sol.interp(reduce(vcat, (p, t)))
sol.interp(t)
end
function (sol::SciMLBase.AbstractODESolution)(t::Tuple)
# p,t = sol.t
# sol.interp((p, t))
sol.interp(t)
p, _ = sol.t
sol.interp(p, t)
end

function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
Expand Down
72 changes: 33 additions & 39 deletions test/PINO_ode_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ end
@testitem "Example Chain du = cos(p * t)" tags=[:pinoode] setup=[PINOODETestSetup] begin
using NeuralPDE, Lux, OptimizationOptimisers, NeuralOperators, Random
equation = (u, p, t) -> cos(p * t)
tspan = (0.0f0, 1.0f0)
tspan = (0.0, 1.0)
u0 = 1.0
prob = ODEProblem(equation, u0, tspan)
chain = Chain(
Expand All @@ -44,19 +44,20 @@ end
ground_analytic = (u0, p, t) -> u0 + sin(p * t) / (p)
p, t = get_trainset(chain, bounds, 50, tspan, 0.025)
ground_solution = ground_analytic.(u0, p, t)
predict_sol = sol(reduce(vcat, (p, t)))
predict_sol = sol.interp(p, t)
predict_sol = sol.interp(reduce(vcat, (p, t)))
@test ground_solutionpredict_sol rtol=0.05
p, t = get_trainset(chain, bounds, 100, tspan, 0.01)
ground_solution = ground_analytic.(u0, p, t)
predict_sol = sol(reduce(vcat, (p, t)))
predict_sol = sol.interp(p, t)
@test ground_solutionpredict_sol rtol=0.05
end

#Test DeepONet
@testitem "Example DeepONet du = cos(p * t)" tags=[:pinoode] setup=[PINOODETestSetup] begin
using NeuralPDE, Lux, OptimizationOptimisers, NeuralOperators, Random
equation = (u, p, t) -> cos(p * t)
tspan = (0.0f0, 1.0f0)
tspan = (0.0, 1.0)
u0 = 1.0
prob = ODEProblem(equation, u0, tspan)
deeponet = NeuralOperators.DeepONet(
Expand Down Expand Up @@ -84,29 +85,30 @@ end
ground_analytic = (u0, p, t) -> u0 + sin(p * t) / (p)
p, t = get_trainset(deeponet, bounds, 50, tspan, 0.025)
ground_solution = ground_analytic.(u0, p, vec(t))
predict_sol = sol((p, t))
predict_sol = sol(t)
predict_sol = sol.interp(p, t)
@test ground_solutionpredict_sol rtol=0.05
p, t = get_trainset(deeponet, bounds, 100, tspan, 0.01)
ground_solution = ground_analytic.(u0, p, vec(t))
predict_sol = sol((p, t))
predict_sol = sol.interp(p, t)
@test ground_solutionpredict_sol rtol=0.05
end

@testitem "Example du = cos(p * t) + u" tags=[:pinoode] setup=[PINOODETestSetup] begin
using NeuralPDE, Lux, OptimizationOptimisers, NeuralOperators, Random
eq_(u, p, t) = cos(p * t) + u
tspan = (0.0f0, 1.0f0)
u0 = 1.0f0
tspan = (0.0, 1.0)
u0 = 1.0
prob = ODEProblem(eq_, u0, tspan)
deeponet = NeuralOperators.DeepONet(
Chain(
Dense(1 => 10, Lux.tanh_fast), Dense(10 => 10, Lux.tanh_fast), Dense(10 => 10)),
Chain(Dense(1 => 10, Lux.tanh_fast), Dense(10 => 10, Lux.tanh_fast),
Dense(10 => 10, Lux.tanh_fast)))
bounds = [(0.1f0, 2.0f0)]
bounds = [(0.1, 2.0)]
number_of_parameters = 40
dt = (tspan[2] - tspan[1]) / 40
strategy = GridTraining(0.1f0)
strategy = GridTraining(0.1)
opt = OptimizationOptimisers.Adam(0.01)
alg = PINOODE(deeponet, opt, bounds, number_of_parameters; strategy = strategy)
sol = solve(prob, alg, verbose = false, maxiters = 4000)
Expand All @@ -116,53 +118,45 @@ end
(p^2 + 1)
p, t = get_trainset(deeponet, bounds, number_of_parameters, tspan, dt)
ground_solution = ground_analytic_.(u0, p, vec(t))
predict_sol = sol((p, t))
predict_sol = sol.interp(p, t)
@test ground_solutionpredict_sol rtol=0.05
end

@testitem "Example with data du = p*t^2" tags=[:pinoode] setup=[PINOODETestSetup] begin
using NeuralPDE, Lux, OptimizationOptimisers, NeuralOperators, Random
equation = (u, p, t) -> p * t^2
tspan = (0.0f0, 1.0f0)
u0 = 0.0f0
tspan = (0.0, 1.0)
u0 = 0.0
prob = ODEProblem(equation, u0, tspan)
deeponet = NeuralOperators.DeepONet(
Chain(
Dense(1 => 10, Lux.tanh_fast), Dense(10 => 10, Lux.tanh_fast), Dense(10 => 10)),
Chain(Dense(1 => 10, Lux.tanh_fast), Dense(10 => 10, Lux.tanh_fast),
Dense(10 => 10, Lux.tanh_fast)))
bounds = [(0.0f0, 10.0f0)]
bounds = [(0.0, 10.0)]
number_of_parameters = 60
dt = (tspan[2] - tspan[1]) / 40
strategy = StochasticTraining(60)
opt = OptimizationOptimisers.Adam(0.03)
opt = OptimizationOptimisers.Adam(0.01)

#generate data
ground_analytic = (u0, p, t) -> u0 + p * t^3 / 3

function get_data()
sol = ground_analytic.(u0, p, vec(t))
tuple_ = (p, t)
sol, tuple_
end
u = rand(1, 50)
v = rand(1, 40, 1)
θ, st = Lux.setup(Random.default_rng(), deeponet)
c = deeponet((u, v), θ, st)[1]
sol = ground_analytic.(u0, p, vec(t))
p, t = get_trainset(deeponet, bounds, number_of_parameters, tspan, dt)
data, tuple_ = get_data()
function additional_loss_(phi, θ)
u = phi(tuple_, θ)
u = phi((p, t), θ)
norm = prod(size(u))
sum(abs2, u .- data) / norm
sum(abs2, u .- sol) / norm
end

alg = PINOODE(
deeponet, opt, bounds, number_of_parameters; strategy = strategy,
additional_loss = additional_loss_)
sol = solve(prob, alg, verbose = false, maxiters = 2000)
sol = solve(prob, alg, verbose = true, maxiters = 3000)

p, t = get_trainset(deeponet, bounds, number_of_parameters, tspan, dt)
ground_solution = ground_analytic.(u0, p, vec(t))
predict_sol = sol((p, t))
predict_sol = sol.interp(p, t)
@test ground_solutionpredict_sol rtol=0.05
end

Expand Down Expand Up @@ -200,12 +194,12 @@ end
end
(p, t) = get_trainset(chain, bounds, 20, tspan, 0.1)
ground_solution_ = ground_solution_f(p, t)
predict = sol(reduce(vcat, (p, t)))[1, :, :]
predict = sol.interp(p, t)[1, :, :]
@test ground_solution_predict rtol=0.05

p, t = get_trainset(chain, bounds, 50, tspan, 0.025)
ground_solution_ = ground_solution_f(p, t)
predict_sol = sol(reduce(vcat, (p, t)))[1, :, :]
predict_sol = sol.interp(p, t)[1, :, :]
@test ground_solution_predict_sol rtol=0.05
end

Expand Down Expand Up @@ -243,21 +237,21 @@ end

(p, t) = get_trainset(deeponet, bounds, 50, tspan, 0.025)
ground_solution_ = ground_solution_f(p, t)
predict = sol((p, t))
predict = sol.interp(p, t)
@test ground_solution_predict rtol=0.05

p, t = get_trainset(deeponet, bounds, 100, tspan, 0.01)
ground_solution_ = ground_solution_f(p, t)
predict = sol((p, t))
predict = sol.interp(p, t)
@test ground_solution_predict rtol=0.05
end

#vector output
@testitem "Example du = [cos(p * t), sin(p * t)]" tags=[:pinoode] setup=[PINOODETestSetup] begin
using NeuralPDE, Lux, OptimizationOptimisers, NeuralOperators, Random
equation = (u, p, t) -> [cos(p * t), sin(p * t)]
tspan = (0.0f0, 1.0f0)
u0 = [1.0f0, 0.0f0]
tspan = (0.0, 1.0)
u0 = [1.0, 0.0]
prob = ODEProblem(equation, u0, tspan)
input_branch_size = 1
chain = Chain(
Expand All @@ -269,7 +263,7 @@ end
strategy = StochasticTraining(300)
opt = OptimizationOptimisers.Adam(0.01)
alg = PINOODE(chain, opt, bounds, number_of_parameters; strategy = strategy)
sol = solve(prob, alg, verbose = true, maxiters = 6000)
sol = solve(prob, alg, verbose = false, maxiters = 6000)

ground_solution = (u0, p, t) -> [1 + sin(p * t) / p, 1 / p - cos(p * t) / p]
function ground_solution_f(p, t)
Expand All @@ -288,14 +282,14 @@ end
end
p, t = get_trainset(chain, bounds, 50, tspan, 0.025)
ground_solution_ = ground_solution_f(p, t)
predict = sol(reduce(vcat, (p, t)))
predict = sol.interp(p, t)
@test ground_solution_[1, :, :]predict[1, :, :] rtol=0.05
@test ground_solution_[2, :, :]predict[2, :, :] rtol=0.05
@test ground_solution_predict rtol=0.05

p, t = get_trainset(chain, bounds, 300, tspan, 0.01)
ground_solution_ = ground_solution_f(p, t)
predict = sol(reduce(vcat, (p, t)))
predict = sol.interp(p, t)
@test ground_solution_[1, :, :]predict[1, :, :] rtol=0.05
@test ground_solution_[2, :, :]predict[2, :, :] rtol=0.05
@test ground_solution_predict rtol=0.3
Expand Down

0 comments on commit c5a456a

Please sign in to comment.