Skip to content

Commit

Permalink
fix: missing NNRODE tests
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 15, 2024
1 parent 22c316b commit 93d270e
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 31 deletions.
52 changes: 26 additions & 26 deletions src/PDE_BPINN.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,28 +62,28 @@ mutable struct PDELogTargetDensity{
end
end

function LogDensityProblems.logdensity(Tar::PDELogTargetDensity, θ)
function LogDensityProblems.logdensity(ltd::PDELogTargetDensity, θ)
# for parameter estimation neccesarry to use multioutput case
return Tar.full_loglikelihood(setparameters(Tar, θ),
Tar.allstd) + priorlogpdf(Tar, θ) + L2LossData(Tar, θ)
# + L2loss2(Tar, θ)
return ltd.full_loglikelihood(setparameters(ltd, θ),
ltd.allstd) + priorlogpdf(ltd, θ) + L2LossData(ltd, θ)
# + L2loss2(ltd, θ)
end

function setparameters(Tar::PDELogTargetDensity, θ)
names = Tar.names
ps_new = θ[1:(end - Tar.extraparams)]
ps = Tar.init_params
function setparameters(ltd::PDELogTargetDensity, θ)
names = ltd.names
ps_new = θ[1:(end - ltd.extraparams)]
ps = ltd.init_params

# multioutput case for Lux chains, for each depvar ps would contain Lux ComponentVectors
# which we use for mapping current ahmc sampled vector of parameters onto NNs
i = 0
Luxparams = [vector_to_parameters(ps_new[((i += length(ps[x])) - length(ps[x]) + 1):i],
ps[x]) for x in names]

a = ComponentArray(NamedTuple{Tar.names}(i for i in Luxparams))
a = ComponentArray(NamedTuple{ltd.names}(i for i in Luxparams))

if Tar.extraparams > 0
b = θ[(end - Tar.extraparams + 1):end]
if ltd.extraparams > 0
b = θ[(end - ltd.extraparams + 1):end]
return ComponentArray(;
depvar = a,
p = b)
Expand All @@ -93,22 +93,22 @@ function setparameters(Tar::PDELogTargetDensity, θ)
end
end

LogDensityProblems.dimension(Tar::PDELogTargetDensity) = Tar.dim
LogDensityProblems.dimension(ltd::PDELogTargetDensity) = ltd.dim

function LogDensityProblems.capabilities(::PDELogTargetDensity)
LogDensityProblems.LogDensityOrder{1}()
end

# L2 losses loglikelihood(needed mainly for ODE parameter estimation)
function L2LossData(Tar::PDELogTargetDensity, θ)
Φ = Tar.Φ
init_params = Tar.init_params
dataset = Tar.dataset
function L2LossData(ltd::PDELogTargetDensity, θ)
Φ = ltd.Φ
init_params = ltd.init_params
dataset = ltd.dataset
sumt = 0
L2stds = Tar.allstd[3]
L2stds = ltd.allstd[3]
# each dep var has a diff dataset depending on its indep var and their domains
# these datasets are matrices of first col-dep var and remaining cols-all indep var
# Tar.init_params is needed to construct a vector of parameters into a ComponentVector
# ltd.init_params is needed to construct a vector of parameters into a ComponentVector

# dataset of form Vector[matrix_x, matrix_y, matrix_z]
# matrix_i is of form [i,indvar1,indvar2,..] (needed in case if heterogenous domains)
Expand All @@ -118,13 +118,13 @@ function L2LossData(Tar::PDELogTargetDensity, θ)
# dataset[i][:, 2:end] -> indepvar cols of a particular depvar's dataset
# dataset[i][:, 1] -> depvar col of depvar's dataset

if Tar.extraparams > 0
if ltd.extraparams > 0
for i in eachindex(Φ)
sumt += logpdf(
MvNormal(
Φ[i](dataset[i][:, 2:end]',
vector_to_parameters(θ[1:(end - Tar.extraparams)],
init_params)[Tar.names[i]])[1,
vector_to_parameters(θ[1:(end - ltd.extraparams)],
init_params)[ltd.names[i]])[1,
:],
Diagonal(abs2.(ones(size(dataset[i])[1]) .*
L2stds[i]))),
Expand All @@ -136,23 +136,23 @@ function L2LossData(Tar::PDELogTargetDensity, θ)
end

# priors for NN parameters + ODE constants
function priorlogpdf(Tar::PDELogTargetDensity, θ)
allparams = Tar.priors
function priorlogpdf(ltd::PDELogTargetDensity, θ)
allparams = ltd.priors
# Vector of ode parameters priors
invpriors = allparams[2:end]

# nn weights
nnwparams = allparams[1]

if Tar.extraparams > 0
if ltd.extraparams > 0
invlogpdf = sum(
logpdf(invpriors[length(θ) - i + 1], θ[i])
for i in (length(θ) - Tar.extraparams + 1):length(θ);
for i in (length(θ) - ltd.extraparams + 1):length(θ);
init = 0.0)

return (invlogpdf
+
logpdf(nnwparams, θ[1:(length(θ) - Tar.extraparams)]))
logpdf(nnwparams, θ[1:(length(θ) - ltd.extraparams)]))
end
return logpdf(nnwparams, θ)
end
Expand Down
1 change: 0 additions & 1 deletion src/rode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ function SciMLBase.__solve(
total_loss(θ, _) = inner_f(θ, phi)
optf = OptimizationFunction(total_loss, AutoZygote())


plen = maxiters === nothing ? 6 : ndigits(maxiters)
callback = function (p, l)
if verbose
Expand Down
5 changes: 1 addition & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,9 @@ end
@time @safetestset "AdaptiveLoss" include("adaptive_loss_tests.jl")
end

#=
# Fails because it uses sciml_train
if GROUP == "All" || GROUP == "NNRODE"
@time @safetestset "NNRODE" begin include("NNRODE_tests.jl") end
@time @safetestset "NNRODE" include("NNRODE_tests.jl")
end
=#

if GROUP == "All" || GROUP == "Forward"
@time @safetestset "Forward" include("forward_tests.jl")
Expand Down

0 comments on commit 93d270e

Please sign in to comment.