
Commit

refactor: use explicit imports
avik-pal committed Oct 14, 2024
1 parent f64cd21 commit 7e0e580
Showing 13 changed files with 108 additions and 94 deletions.
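
The change replaces blanket `using Pkg` statements with explicit `using Pkg: name` imports, so the origin of every binding used in the package is visible at the top of the module. A minimal sketch of the difference, using a stdlib example for illustration rather than code from this commit:

```julia
# Implicit: every exported name of Statistics enters scope.
using Statistics
mean([1, 2, 3])

# Explicit: only the listed bindings are imported, so each usage stays traceable
# and tooling such as ExplicitImports.jl can verify no implicit imports remain.
using Statistics: Statistics, mean
mean([1, 2, 3])
```

The src/NeuralPDE.jl hunk below applies this pattern to every dependency of the package.
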
6 changes: 5 additions & 1 deletion Project.toml
@@ -37,6 +37,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
@@ -57,6 +58,7 @@ DiffEqNoiseProcess = "5.20"
Distributions = "0.25.107"
DocStringExtensions = "0.9.3"
DomainSets = "0.7"
ExplicitImports = "1.10.1"
Flux = "0.14.22"
ForwardDiff = "0.10.36"
Functors = "0.4.12"
@@ -86,6 +88,7 @@ RuntimeGeneratedFunctions = "0.5.12"
SafeTestsets = "0.1"
SciMLBase = "2.56"
Statistics = "1.10"
SymbolicIndexingInterface = "0.3.31"
SymbolicUtils = "3.7.2"
Symbolics = "6.14"
Test = "1.10"
@@ -96,6 +99,7 @@ julia = "1.10"
[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
@@ -108,4 +112,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "CUDA", "Flux", "LineSearches", "LuxCUDA", "LuxCore", "LuxLib", "MethodOfLines", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "Test"]
test = ["Aqua", "CUDA", "ExplicitImports", "Flux", "LineSearches", "LuxCUDA", "LuxCore", "LuxLib", "MethodOfLines", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "Test"]
4 changes: 2 additions & 2 deletions docs/src/tutorials/neural_adapter.md
@@ -60,7 +60,7 @@ chain2 = Lux.Chain(Dense(2, inner_, af),
Dense(inner_, inner_, af),
Dense(inner_, 1))
initp, st = Lux.setup(Random.default_rng(), chain2)
init_params2 = Float64.(ComponentArrays.ComponentArray(initp))
init_params2 = Float64.(ComponentArray(initp))
# the rule by which the training will take place is described here in loss function
function loss(cord, θ)
@@ -226,7 +226,7 @@ chain2 = Lux.Chain(Dense(2, inner_, af),
Dense(inner_, 1))
initp, st = Lux.setup(Random.default_rng(), chain2)
init_params2 = Float64.(ComponentArrays.ComponentArray(initp))
init_params2 = Float64.(ComponentArray(initp))
@named pde_system = PDESystem(eq, bcs, domains, [x, y], [u(x, y)])
6 changes: 3 additions & 3 deletions src/BPINN_ode.jl
@@ -220,14 +220,14 @@ function SciMLBase.__solve(prob::SciMLBase.ODEProblem,
ninv = length(param)
t = collect(eltype(saveat), prob.tspan[1]:saveat:prob.tspan[2])

if chain isa Lux.AbstractLuxLayer
θinit, st = Lux.setup(Random.default_rng(), chain)
if chain isa AbstractLuxLayer
θinit, st = LuxCore.setup(Random.default_rng(), chain)
θ = [vector_to_parameters(samples[i][1:(end - ninv)], θinit)
for i in 1:max(draw_samples - draw_samples ÷ 10, draw_samples - 1000)]

luxar = [chain(t', θ[i], st)[1] for i in 1:numensemble]
# only need for size
θinit = collect(ComponentArrays.ComponentArray(θinit))
θinit = collect(ComponentArray(θinit))
else
throw(error("Only Lux.AbstractLuxLayer neural networks are supported"))
end
67 changes: 38 additions & 29 deletions src/NeuralPDE.jl
@@ -3,45 +3,54 @@ $(DocStringExtensions.README)
"""
module NeuralPDE

using DocStringExtensions
using Reexport, Statistics
@reexport using SciMLBase
@reexport using ModelingToolkit

using Zygote, ForwardDiff, Random, Distributions
using Adapt, DiffEqNoiseProcess
using Optimization
using OptimizationOptimisers, OptimizationOptimJL
using Integrals, Cubature
using RuntimeGeneratedFunctions
using Statistics
using ArrayInterface

using Symbolics: wrap, unwrap, arguments, operation
using SymbolicUtils
using AdvancedHMC, LogDensityProblems, LinearAlgebra, Functors, MCMCChains
using MonteCarloMeasurements: Particles
using ModelingToolkit: value, nameof, toexpr, build_expr, expand_derivatives, Interval,
infimum, supremum
import DomainSets
using DomainSets: Domain, ClosedInterval, AbstractInterval, leftendpoint, rightendpoint,
ProductDomain
using SciMLBase: @add_kwonly, parameterless_type

using ADTypes: AutoForwardDiff, AutoZygote
using ADTypes: ADTypes, AutoForwardDiff, AutoZygote
using Adapt: Adapt, adapt
using AdvancedHMC: AdvancedHMC, DiagEuclideanMetric, HMC, HMCDA, Hamiltonian,
JitteredLeapfrog, Leapfrog, MassMatrixAdaptor, NUTS, StanHMCAdaptor,
StepSizeAdaptor, TemperedLeapfrog, find_good_stepsize
using ArrayInterface: ArrayInterface, parameterless_type
using ChainRulesCore: ChainRulesCore, @non_differentiable, @ignore_derivatives
using Cubature: Cubature
using ComponentArrays: ComponentArrays, ComponentArray, getdata, getaxes
using ConcreteStructs: @concrete
using Functors: fmap
using Distributions: Distributions, Distribution, MvNormal, Normal, dim, logpdf
using DiffEqNoiseProcess: DiffEqNoiseProcess
using DocStringExtensions: DocStringExtensions, FIELDS
using DomainSets: DomainSets, AbstractInterval, leftendpoint, rightendpoint, ProductDomain
using ForwardDiff: ForwardDiff
using Functors: Functors, fmap
using Integrals: Integrals, CubatureJLh, QuadGKJL
using LinearAlgebra: Diagonal
using LogDensityProblems: LogDensityProblems
using Lux: Lux, Chain, Dense, SkipConnection, StatefulLuxLayer
using Lux: FromFluxAdaptor, recursive_eltype
using LuxCore: AbstractLuxLayer, AbstractLuxWrapperLayer, AbstractLuxContainerLayer
using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxWrapperLayer
using MCMCChains: MCMCChains, Chains, sample
using ModelingToolkit: ModelingToolkit, Num, PDESystem, toexpr, expand_derivatives, infimum,
supremum
using MonteCarloMeasurements: Particles
using Optimisers: Optimisers, Adam
using Optimization: Optimization
using OptimizationOptimisers: OptimizationOptimisers
using OptimizationOptimJL: OptimizationOptimJL
using Random: Random, AbstractRNG
using RecursiveArrayTools: DiffEqArray
using Reexport: @reexport
using RuntimeGeneratedFunctions: RuntimeGeneratedFunctions, @RuntimeGeneratedFunction
using SciMLBase: SciMLBase, BatchIntegralFunction, IntegralProblem,
OptimizationFunction, OptimizationProblem, ReturnCode, discretize,
isinplace, solve, symbolic_discretize
using Statistics: Statistics, mean
using Symbolics: Symbolics, unwrap, arguments, operation, build_expr
using SymbolicUtils: SymbolicUtils
using SymbolicIndexingInterface: SymbolicIndexingInterface
using QuasiMonteCarlo: QuasiMonteCarlo, LatinHypercubeSample
using WeightInitializers: glorot_uniform, zeros32
using Zygote: Zygote

import LuxCore: initialparameters, initialstates, parameterlength

import LuxCore: initialparameters, initialstates, parameterlength, statelength
@reexport using SciMLBase, ModelingToolkit

RuntimeGeneratedFunctions.init(@__MODULE__)

12 changes: 6 additions & 6 deletions src/PDE_BPINN.jl
@@ -80,15 +80,15 @@ function setparameters(Tar::PDELogTargetDensity, θ)
Luxparams = [vector_to_parameters(ps_new[((i += length(ps[x])) - length(ps[x]) + 1):i],
ps[x]) for x in names]

a = ComponentArrays.ComponentArray(NamedTuple{Tar.names}(i for i in Luxparams))
a = ComponentArray(NamedTuple{Tar.names}(i for i in Luxparams))

if Tar.extraparams > 0
b = θ[(end - Tar.extraparams + 1):end]
return ComponentArrays.ComponentArray(;
return ComponentArray(;
depvar = a,
p = b)
else
return ComponentArrays.ComponentArray(;
return ComponentArray(;
depvar = a)
end
end
@@ -126,8 +126,8 @@ function L2LossData(Tar::PDELogTargetDensity, θ)
vector_to_parameters(θ[1:(end - Tar.extraparams)],
init_params)[Tar.names[i]])[1,
:],
LinearAlgebra.Diagonal(abs2.(ones(size(dataset[i])[1]) .*
L2stds[i]))),
Diagonal(abs2.(ones(size(dataset[i])[1]) .*
L2stds[i]))),
dataset[i][:, 1])
end
return sumt
@@ -350,7 +350,7 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
# add init_params for NN params
priors = [
MvNormal(priorsNNw[1] * ones(nparameters),
LinearAlgebra.Diagonal(abs2.(priorsNNw[2] .* ones(nparameters))))
Diagonal(abs2.(priorsNNw[2] .* ones(nparameters))))
]

# append Ode params to all paramvector - initial_θ
39 changes: 17 additions & 22 deletions src/advancedHMC_MCMC.jl
@@ -43,7 +43,7 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I,
init_params,
estim_collocate)
end
function LogTargetDensity(dim, prob, chain::Lux.AbstractLuxLayer, st, strategy,
function LogTargetDensity(dim, prob, chain::AbstractLuxLayer, st, strategy,
dataset,
priors, phystd, l2std, autodiff, physdt, extraparams,
init_params::NamedTuple, estim_collocate)
@@ -73,7 +73,7 @@ the sampled parameters are of exotic type `Dual` due to ForwardDiff's autodiff t
"""
function vector_to_parameters(ps_new::AbstractVector,
ps::Union{NamedTuple, ComponentArrays.ComponentVector})
@assert length(ps_new) == Lux.parameterlength(ps)
@assert length(ps_new) == LuxCore.parameterlength(ps)
i = 1
function get_ps(x)
z = reshape(view(ps_new, i:(i + length(x) - 1)), size(x))
@@ -139,7 +139,7 @@ function L2loss2(Tar::LogTargetDensity, θ)
# can add phystd[i] for u[i]
physlogprob += logpdf(
MvNormal(deri_physsol[i, :],
LinearAlgebra.Diagonal(map(abs2,
Diagonal(map(abs2,
(Tar.l2std[i] * 4.0) .*
ones(length(nnsol[i, :]))))),
nnsol[i, :])
@@ -166,8 +166,8 @@ function L2LossData(Tar::LogTargetDensity, θ)
# for u[i] ith vector must be added to dataset,nn[1,:] is the dx in lotka_volterra
L2logprob += logpdf(
MvNormal(nn[i, :],
LinearAlgebra.Diagonal(abs2.(Tar.l2std[i] .*
ones(length(Tar.dataset[i]))))),
Diagonal(abs2.(Tar.l2std[i] .*
ones(length(Tar.dataset[i]))))),
Tar.dataset[i])
end
return L2logprob
@@ -308,8 +308,8 @@ function innerdiff(Tar::LogTargetDensity, f, autodiff::Bool, t::AbstractVector,
# N dimensional vector if N outputs for NN(each row has logpdf of u[i] where u is vector of dependant variables)
return [logpdf(
MvNormal(vals[i, :],
LinearAlgebra.Diagonal(abs2.(Tar.phystd[i] .*
ones(length(vals[i, :]))))),
Diagonal(abs2.(Tar.phystd[i] .*
ones(length(vals[i, :]))))),
zeros(length(vals[i, :]))) for i in 1:length(Tar.prob.u0)]
end

@@ -338,21 +338,19 @@ function priorweights(Tar::LogTargetDensity, θ)
end
end

function generate_Tar(chain::Lux.AbstractLuxLayer, init_params)
θ, st = Lux.setup(Random.default_rng(), chain)
return init_params, chain, st
function generate_Tar(chain::AbstractLuxLayer, init_params)
return init_params, chain, LuxCore.initialstates(Random.default_rng(), chain)
end

function generate_Tar(chain::Lux.AbstractLuxLayer, init_params::Nothing)
θ, st = Lux.setup(Random.default_rng(), chain)
function generate_Tar(chain::AbstractLuxLayer, ::Nothing)
θ, st = LuxCore.setup(Random.default_rng(), chain)
return θ, chain, st
end

"""
NN OUTPUT AT t,θ ~ phi(t,θ).
"""
function (f::LogTargetDensity{C, S})(t::AbstractVector,
θ) where {C <: Lux.AbstractLuxLayer, S}
function (f::LogTargetDensity{C, S})(t::AbstractVector, θ) where {C <: AbstractLuxLayer, S}
eltypeθ, typeθ = eltype(θ), parameterless_type(ComponentArrays.getdata(θ))
θ = vector_to_parameters(θ, f.init_params)
t_ = convert.(eltypeθ, adapt(typeθ, t'))
@@ -361,8 +359,7 @@ function (f::LogTargetDensity{C, S})(t::Number,
f.prob.u0 .+ (t' .- f.prob.tspan[1]) .* y
end

function (f::LogTargetDensity{C, S})(t::Number,
θ) where {C <: Lux.AbstractLuxLayer, S}
function (f::LogTargetDensity{C, S})(t::Number, θ) where {C <: AbstractLuxLayer, S}
eltypeθ, typeθ = eltype(θ), parameterless_type(ComponentArrays.getdata(θ))
θ = vector_to_parameters(θ, f.init_params)
t_ = convert.(eltypeθ, adapt(typeθ, [t]))
@@ -506,8 +503,7 @@ function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain;
MCMCkwargs = (n_leapfrog = 30,),
progress = false, verbose = false,
estim_collocate = false)
!(chain isa Lux.AbstractLuxLayer) &&
(chain = adapt(FromFluxAdaptor(false, false), chain))
!(chain isa AbstractLuxLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
# NN parameter prior mean and variance(PriorsNN must be a tuple)
if isinplace(prob)
throw(error("The BPINN ODE solver only supports out-of-place ODE definitions, i.e. du=f(u,p,t)."))
@@ -526,8 +522,7 @@ function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain;
throw(error("Dataset Required for Parameter Estimation."))
end

if chain isa Lux.AbstractLuxLayer
# Lux-Named Tuple
if chain isa AbstractLuxLayer
initial_nnθ, recon, st = generate_Tar(chain, init_params)
else
error("Only Lux.AbstractLuxLayer Neural networks are supported")
@@ -542,14 +537,14 @@ function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain;
# eltype(physdt) cause needs Float64 for find_good_stepsize
# Lux chain(using component array later as vector_to_parameter need namedtuple)
initial_θ = collect(eltype(physdt),
vcat(ComponentArrays.ComponentArray(initial_nnθ)))
vcat(ComponentArray(initial_nnθ)))

# adding ode parameter estimation
nparameters = length(initial_θ)
ninv = length(param)
priors = [
MvNormal(priorsNNw[1] * ones(nparameters),
LinearAlgebra.Diagonal(abs2.(priorsNNw[2] .* ones(nparameters))))
Diagonal(abs2.(priorsNNw[2] .* ones(nparameters))))
]

# append Ode params to all paramvector
4 changes: 2 additions & 2 deletions src/dae_solve.jl
@@ -112,8 +112,8 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractDAEProblem,

if chain isa Lux.AbstractLuxLayer || chain isa Flux.Chain
phi, init_params = generate_phi_θ(chain, t0, u0, init_params)
init_params = ComponentArrays.ComponentArray(;
depvar = ComponentArrays.ComponentArray(init_params))
init_params = ComponentArray(;
depvar = ComponentArray(init_params))
else
error("Only Lux.AbstractLuxLayer and Flux.Chain neural networks are supported")
end
16 changes: 8 additions & 8 deletions src/discretize.jl
@@ -393,39 +393,39 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, discretization::Ab
# This is done because Float64 is almost always better for these applications
if chain isa AbstractArray
x = map(chain) do x
_x = ComponentArrays.ComponentArray(Lux.initialparameters(
_x = ComponentArray(LuxCore.initialparameters(
Random.default_rng(),
x))
Float64.(_x) # No ComponentArray GPU support
end
names = ntuple(i -> depvars[i], length(chain))
init_params = ComponentArrays.ComponentArray(NamedTuple{names}(i
init_params = ComponentArray(NamedTuple{names}(i
for i in x))
else
init_params = Float64.(ComponentArrays.ComponentArray(Lux.initialparameters(
init_params = Float64.(ComponentArray(LuxCore.initialparameters(
Random.default_rng(),
chain)))
end
else
init_params = init_params
end

flat_init_params = if init_params isa ComponentArrays.ComponentArray
flat_init_params = if init_params isa ComponentArray
init_params
elseif multioutput
@assert length(init_params) == length(depvars)
names = ntuple(i -> depvars[i], length(init_params))
x = ComponentArrays.ComponentArray(NamedTuple{names}(i for i in init_params))
x = ComponentArray(NamedTuple{names}(i for i in init_params))
else
ComponentArrays.ComponentArray(init_params)
ComponentArray(init_params)
end

flat_init_params = if !param_estim && multioutput
ComponentArrays.ComponentArray(; depvar = flat_init_params)
ComponentArray(; depvar = flat_init_params)
elseif !param_estim && !multioutput
flat_init_params
else
ComponentArrays.ComponentArray(; depvar = flat_init_params, p = default_p)
ComponentArray(; depvar = flat_init_params, p = default_p)
end

if length(flat_init_params) == 0 && !Base.isconcretetype(eltype(flat_init_params))
8 changes: 6 additions & 2 deletions src/ode_solve.jl
@@ -117,9 +117,13 @@ function ODEPhi(model::AbstractLuxLayer, t0::Number, u0, st)
return ODEPhi(u0, t0, StatefulLuxLayer{true}(model, nothing, st))
end

function generate_phi_θ(chain::AbstractLuxLayer, t, u0, ::Nothing)
θ, st = LuxCore.setup(Random.default_rng(), chain)
return ODEPhi(chain, t, u0, st), θ
end

function generate_phi_θ(chain::AbstractLuxLayer, t, u0, init_params)
θ, st = Lux.setup(Random.default_rng(), chain)
init_params === nothing && (init_params = θ)
st = LuxCore.initialstates(Random.default_rng(), chain)
return ODEPhi(chain, t, u0, st), init_params
end

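
The ode_solve.jl hunk above splits `generate_phi_θ` into two methods, turning the old `init_params === nothing` runtime check into dispatch on `::Nothing`. A minimal, self-contained sketch of that pattern with illustrative names:

```julia
# Caller supplied parameters: use them as-is.
setup_params(init_params, default) = init_params
# No parameters supplied: fall back to a freshly generated default.
setup_params(::Nothing, default) = default

setup_params((w = 1.0,), (w = 0.0,))  # => (w = 1.0,)
setup_params(nothing, (w = 0.0,))     # => (w = 0.0,)
```
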