diff --git a/src/PDE_BPINN.jl b/src/PDE_BPINN.jl index 0b81033fa1..e180449bbf 100644 --- a/src/PDE_BPINN.jl +++ b/src/PDE_BPINN.jl @@ -62,17 +62,17 @@ mutable struct PDELogTargetDensity{ end end -function LogDensityProblems.logdensity(Tar::PDELogTargetDensity, θ) +function LogDensityProblems.logdensity(ltd::PDELogTargetDensity, θ) # for parameter estimation neccesarry to use multioutput case - return Tar.full_loglikelihood(setparameters(Tar, θ), - Tar.allstd) + priorlogpdf(Tar, θ) + L2LossData(Tar, θ) - # + L2loss2(Tar, θ) + return ltd.full_loglikelihood(setparameters(ltd, θ), + ltd.allstd) + priorlogpdf(ltd, θ) + L2LossData(ltd, θ) + # + L2loss2(ltd, θ) end -function setparameters(Tar::PDELogTargetDensity, θ) - names = Tar.names - ps_new = θ[1:(end - Tar.extraparams)] - ps = Tar.init_params +function setparameters(ltd::PDELogTargetDensity, θ) + names = ltd.names + ps_new = θ[1:(end - ltd.extraparams)] + ps = ltd.init_params # multioutput case for Lux chains, for each depvar ps would contain Lux ComponentVectors # which we use for mapping current ahmc sampled vector of parameters onto NNs @@ -80,10 +80,10 @@ function setparameters(Tar::PDELogTargetDensity, θ) Luxparams = [vector_to_parameters(ps_new[((i += length(ps[x])) - length(ps[x]) + 1):i], ps[x]) for x in names] - a = ComponentArray(NamedTuple{Tar.names}(i for i in Luxparams)) + a = ComponentArray(NamedTuple{ltd.names}(i for i in Luxparams)) - if Tar.extraparams > 0 - b = θ[(end - Tar.extraparams + 1):end] + if ltd.extraparams > 0 + b = θ[(end - ltd.extraparams + 1):end] return ComponentArray(; depvar = a, p = b) @@ -93,22 +93,22 @@ function setparameters(Tar::PDELogTargetDensity, θ) end end -LogDensityProblems.dimension(Tar::PDELogTargetDensity) = Tar.dim +LogDensityProblems.dimension(ltd::PDELogTargetDensity) = ltd.dim function LogDensityProblems.capabilities(::PDELogTargetDensity) LogDensityProblems.LogDensityOrder{1}() end # L2 losses loglikelihood(needed mainly for ODE parameter estimation) -function L2LossData(Tar::PDELogTargetDensity, θ) - Φ = Tar.Φ - init_params = Tar.init_params - dataset = Tar.dataset +function L2LossData(ltd::PDELogTargetDensity, θ) + Φ = ltd.Φ + init_params = ltd.init_params + dataset = ltd.dataset sumt = 0 - L2stds = Tar.allstd[3] + L2stds = ltd.allstd[3] # each dep var has a diff dataset depending on its indep var and their domains # these datasets are matrices of first col-dep var and remaining cols-all indep var - # Tar.init_params is needed to construct a vector of parameters into a ComponentVector + # ltd.init_params is needed to construct a vector of parameters into a ComponentVector # dataset of form Vector[matrix_x, matrix_y, matrix_z] # matrix_i is of form [i,indvar1,indvar2,..] (needed in case if heterogenous domains) @@ -118,13 +118,13 @@ function L2LossData(Tar::PDELogTargetDensity, θ) # dataset[i][:, 2:end] -> indepvar cols of a particular depvar's dataset # dataset[i][:, 1] -> depvar col of depvar's dataset - if Tar.extraparams > 0 + if ltd.extraparams > 0 for i in eachindex(Φ) sumt += logpdf( MvNormal( Φ[i](dataset[i][:, 2:end]', - vector_to_parameters(θ[1:(end - Tar.extraparams)], - init_params)[Tar.names[i]])[1, + vector_to_parameters(θ[1:(end - ltd.extraparams)], + init_params)[ltd.names[i]])[1, :], Diagonal(abs2.(ones(size(dataset[i])[1]) .* L2stds[i]))), @@ -136,23 +136,23 @@ function L2LossData(Tar::PDELogTargetDensity, θ) end # priors for NN parameters + ODE constants -function priorlogpdf(Tar::PDELogTargetDensity, θ) - allparams = Tar.priors +function priorlogpdf(ltd::PDELogTargetDensity, θ) + allparams = ltd.priors # Vector of ode parameters priors invpriors = allparams[2:end] # nn weights nnwparams = allparams[1] - if Tar.extraparams > 0 + if ltd.extraparams > 0 invlogpdf = sum( logpdf(invpriors[length(θ) - i + 1], θ[i]) - for i in (length(θ) - Tar.extraparams + 1):length(θ); + for i in (length(θ) - ltd.extraparams + 1):length(θ); init = 0.0) return (invlogpdf + - logpdf(nnwparams, θ[1:(length(θ) - Tar.extraparams)])) + logpdf(nnwparams, θ[1:(length(θ) - ltd.extraparams)])) end return logpdf(nnwparams, θ) end diff --git a/src/rode_solve.jl b/src/rode_solve.jl index e4ef6ad930..bd388ea6d5 100644 --- a/src/rode_solve.jl +++ b/src/rode_solve.jl @@ -99,7 +99,6 @@ function SciMLBase.__solve( total_loss(θ, _) = inner_f(θ, phi) optf = OptimizationFunction(total_loss, AutoZygote()) - plen = maxiters === nothing ? 6 : ndigits(maxiters) callback = function (p, l) if verbose diff --git a/test/runtests.jl b/test/runtests.jl index b7d9039673..4d0034cdd3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -51,12 +51,9 @@ end @time @safetestset "AdaptiveLoss" include("adaptive_loss_tests.jl") end - #= - # Fails because it uses sciml_train if GROUP == "All" || GROUP == "NNRODE" - @time @safetestset "NNRODE" begin include("NNRODE_tests.jl") end + @time @safetestset "NNRODE" include("NNRODE_tests.jl") end - =# if GROUP == "All" || GROUP == "Forward" @time @safetestset "Forward" include("forward_tests.jl")