Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Problem Interface #35

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/DiffEqUncertainty.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@ module DiffEqUncertainty

using DiffEqBase, Statistics, Distributions, Reexport
@reexport using Quadrature

abstract type AbstractUncertaintyProblem end

include("uncertainty_utils.jl")
include("uncertainty_problems.jl")
include("probints.jl")
include("koopman.jl")

export ProbIntsUncertainty,AdaptiveProbIntsUncertainty
export expectation, centralmoment, Koopman, MonteCarlo
export AbstractUncertaintyProblem, ExpectationProblem
export solve, expectation, centralmoment, Koopman, MonteCarlo

end
81 changes: 78 additions & 3 deletions src/koopman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,88 @@ struct MonteCarlo <: AbstractExpectationAlgorithm end
@inline tuplejoin(x, y) = (x..., y...)
@inline tuplejoin(x, y, z...) = (x..., tuplejoin(y, z...)...)

_rand(x::T) where T <: Sampleable = rand(x)
_rand(x) = x

function __make_map(prob::ODEProblem, args...; kwargs...)
(u,p) -> solve(remake(prob,u0=u,p=p), args...; kwargs...)
end

"""
solve(prob, expalg::Koopman, args; kwargs...)

Solves the ExpectationProblem using via the Koopman expectation

Both args and kwargs are passed to DifferentialEquation solver

Special kwargs:
- maxiters: max quadrature iterations (default 1000000)
- batch: solve quadrature using n-batch processing (default 0, off)
- quadalg: quadrature algorithm to use (default HCubatureJL())
- ireltol, iabstol: integration relative and absolute tolerances (default 1e-2, 1e-2)
"""
function DiffEqBase.solve(prob::ExpectationProblem, ::Koopman, args...;
maxiters=1000000,
batch=0,
quadalg=HCubatureJL(),
ireltol=1e-2, iabstol=1e-2,
kwargs...)

jargs = tuplejoin(args)
jkwargs = tuplejoin(prob.kwargs, kwargs)

# build the integrand for ∫(Ug)(x) * f(x) dx
integrand = function(xq, pq)
xqsize = size(xq)
neval = length(xqsize) > 1 ? xqsize[2] : 0
# solve ODE and evaluate observable
if neval == 0
# scalar solution
_X = prob.comp_func(prob.to_phys(xq,pq)...)
_prob = remake(prob.ode_prob, u0=_X[1], p=_X[2])
Ug = prob.g(solve(_prob, jargs...; jkwargs...))
f0 = prob.f0_func(_X...)
I = Ug .* f0
else
# ensemble solution
_X = map(xi->prob.comp_func(prob.to_phys(xi,pq)...), eachcol(xq))
prob_func(prob, i, _) = remake(prob, u0=_X[i][1], p=_X[i][2])
output_func(sol,i) = prob.g(sol)*prob.f0_func(_X[i][1], _X[i][2]), false
_prob = EnsembleProblem(prob.ode_prob, prob_func=prob_func, output_func=output_func)
I = hcat(solve(_prob, jargs...; trajectories=neval, jkwargs...)[:]...)
end

return I
end

# solve the integral using quadrature methods
intprob = QuadratureProblem(integrand, prob.quad_lb, prob.quad_ub, prob.Tscalar.(prob.p_quad), batch=batch, nout=prob.nout)
sol = solve(intprob, quadalg, reltol=ireltol, abstol=iabstol, maxiters=maxiters)
end

"""
solve(prob, expalg::MonteCarlo, args; kwargs...)

Solves the ExpectationProblem using via the Monte Carlo integration

Both args and kwargs are passed to DifferentialEquation solver

"""
function DiffEqBase.solve(prob::ExpectationProblem, ::MonteCarlo, args...; trajectories,kwargs...)
jargs = tuplejoin(prob.args, args)
jkwargs = tuplejoin(prob.kwargs, kwargs)

prob_func = function (prob, i, repeat)
_u0, _p = prob.comp_func(prob.samp_func()...)
remake(prob, u0=_u0, p=_p)
end

output_func(sol, i) = (prob.g(sol), false)

monte_prob = EnsembleProblem(prob.ode_prob;
output_func=output_func,
prob_func=prob_func)
sol = solve(monte_prob, jargs...;trajectories=trajectories,jkwargs...)
mean(sol.u)
end

function expectation(g::Function, prob::ODEProblem, u0, p, expalg::Koopman, args...;
u0_CoV=(u,p)->u, p_CoV=(u,p)->p,
maxiters=1000000,
Expand Down
85 changes: 85 additions & 0 deletions src/uncertainty_problems.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@

struct ExpectationProblem{T, Tg, Tq, Tp, Tf, Ts, Tc, Tb, Tr, To, Tk} <: AbstractUncertaintyProblem
Tscalar::Type{T}
nout::Int64
g::Tg
to_quad::Tq
to_phys::Tp
f0_func::Tf
samp_func::Ts
comp_func::Tc
quad_lb::Tb
quad_ub::Tb
p_quad::Tr
ode_prob::To
kwargs::Tk
end

DEFAULT_COMP_FUNC(x,p) = (x,p)


# Builds problem from (arrays) of u0 and p distribution(s)
function ExpectationProblem(g::Function, u0_dist, p_dist, prob::ODEProblem, nout=1;
comp_func=DEFAULT_COMP_FUNC, lower_bounds=nothing, upper_bounds=nothing, kwargs...)

T = promote_type([eltype.(u0_dist)..., eltype.(p_dist)...]...)

(nu, usizes, umask) = _dist_mask.(u0_dist) |> _dist_mask_reduce
(np, psizes, pmask) = _dist_mask.(p_dist) |> _dist_mask_reduce
dist_mask = Bool.(vcat(umask, pmask))

# map physical x-state, p-params to quadrature state and params
to_quad = function(x,p)
esv = [x;p]
return (view(esv,dist_mask), view(esv, .!(dist_mask)))
end

# map quadrature x-state, p-params to physical state and params
to_phys = function(x,p)
x_it, p_it = 0, 0
esv = map(1:length(dist_mask)) do idx
dist_mask[idx] ? T(x[x_it+=1]) : T(p[p_it+=1])
end
return (view(esv,1:nu), view(esv,(nu+1):(nu+np)))
end

# evaluate the f0 (joint) distribution
f0_func = function(u,p)
# we need to use this to play nicely with Zygote...
_u = let u_it=1
map(1:length(u0_dist)) do idx
ii = usizes[idx] > 1 ? (u_it:(u_it+usizes[idx]-1)) : u_it
u_it += usizes[idx]
# u[ii]
view(u, ii)
end
end
_p = let p_it=1
map(1:length(p_dist)) do idx
ii = psizes[idx] > 1 ? (p_it:(p_it+psizes[idx]-1)) : p_it
p_it += psizes[idx]
# p[ii]
view(p, ii)
end
end
return prod(_pdf(a,b) for (a,b) in zip(u0_dist,_u)) * prod(_pdf(a,b) for (a,b) in zip(p_dist,_p))
end

# sample from (joint) distribution
samp_func() = comp_func(vcat(_rand.(u0_dist)...), vcat(_rand.(p_dist)...))

# compute the bounds
lb = isnothing(lower_bounds) ? to_quad(comp_func(vcat(_minimum.(u0_dist)...), vcat(_minimum.(p_dist)...))...)[1] : lower_bounds
ub = isnothing(upper_bounds) ? to_quad(comp_func(vcat(_maximum.(u0_dist)...), vcat(_maximum.(p_dist)...))...)[1] : upper_bounds

# compute "static" quadrature parameters
p_quad = to_quad(comp_func(vcat(mean.(u0_dist)...), vcat(mean.(p_dist)...))...)[2]

return ExpectationProblem(T,nout,g,to_quad,to_phys,f0_func,samp_func,comp_func,lb,ub,p_quad,prob,kwargs)
end

# Builds problem from (array) of u0 distribution(s)
function ExpectationProblem(g::Function, u0_dist, prob::ODEProblem, nout=1; kwargs...)
T = promote_type(eltype.(u0_dist)...)
return ExpectationProblem(g,u0_dist,Vector{T}(),prob,nout=nout; kwargs...)
end
25 changes: 25 additions & 0 deletions src/uncertainty_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@

# wraps Distribtuions.jl
# note: MultivariateDistributions do not support minimum/maximum. Setting to -Inf/Inf with
# the knowledge that it will cause quadrature integration to fail if Koopman is used. To use
# MultivariateDistributions the upper/lower bounds should be set with the kwargs.
_minimum(f::T) where T <: MultivariateDistribution = -Inf .* ones(eltype(f), size(f)...)
_minimum(f) = minimum(f)
_maximum(f::T) where T <: MultivariateDistribution = Inf .* ones(eltype(f), size(f)...)
_maximum(f) = maximum(f)
_rand(f::T) where T <: Distribution = rand(f)
_rand(x) = x
_pdf(f::T, x) where T <: MultivariateDistribution = pdf(f,x) # needed to not iterate for MV
_pdf(f::T, x) where T <: Distribution = pdf.(f,x)
_pdf(f, x) = one(eltype(x))

_dist_mask(::Nothing) = (0, Vector{Bool}())
_dist_mask(x) = (length(x), repeat([isa(x, Distribution)], length(x),))
_dist_mask_reduce(x::T) where T <: AbstractArray = (sum(first.(x)), vcat(first.(x)), vcat(last.(x)...))
_dist_mask_reduce(x::T) where T <: Tuple{Int64, Vector{Bool}} = (x[1], [x[1]], x[2])

# creates a tuple of idices, or ranges, from array partition lengths
function accumulated_range(partition_lengths)
c = [0, cumsum(partition_lengths)...]
return Tuple(c[i]+1==c[i+1] ? c[i+1] : (c[i]+1):c[i+1] for i in 1:length(c)-1)
end
106 changes: 105 additions & 1 deletion test/koopman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,20 @@ prob = ODEProblem(eom!,u0,tspan,p)
A = [0.0 1.0; -p[1] -p[2]]
u0s_dist = [Uniform(1,10), Uniform(2,6)]

@testset "Koopman solve, nout = 1" begin
@info "Koopman solve, nout = 1"
g(sol) = sol[1,end]
exp_prob = ExpectationProblem(g, u0s_dist, p, prob)
analytical = (exp(A*tspan[end])*mean.(u0s_dist))[1]

for alg ∈ quadalgs
@info "$alg"
@test solve(exp_prob, Koopman(), Tsit5(); quadalg=alg)[1] ≈ analytical rtol=1e-2
end
end

@testset "Koopman Expectation, nout = 1" begin
@info "Koopman Expectation, nout = 1"
g(sol) = sol[1,end]
analytical = (exp(A*tspan[end])*mean.(u0s_dist))[1]

Expand All @@ -31,7 +44,20 @@ u0s_dist = [Uniform(1,10), Uniform(2,6)]
end
end

@testset "Koopman solve, nout > 1" begin
@info "Koopman solve, nout > 1"
g(sol) = sol[:,end]
exp_prob = ExpectationProblem(g, u0s_dist, p, prob,2)
analytical = (exp(A*tspan[end])*mean.(u0s_dist))[1]

for alg ∈ quadalgs
@info "$alg"
@test solve(exp_prob, Koopman(), Tsit5(); quadalg=alg)[1] ≈ analytical rtol=1e-2
end
end

@testset "Koopman Expectation, nout > 1" begin
@info "Koopman Expectation, nout > 1"
g(sol) = sol[:,end]
analytical = (exp(A*tspan[end])*mean.(u0s_dist))

Expand All @@ -41,7 +67,25 @@ end
end
end

@testset "Koopman solve, nout = 1, batch" begin
@info "Koopman solve, nout = 1, batch"
g(sol) = sol[1,end]
exp_prob = ExpectationProblem(g, u0s_dist, p, prob)
analytical = (exp(A*tspan[end])*mean.(u0s_dist))[1]

for bmode ∈ batchmode
for alg ∈ quadalgs
if alg isa HCubatureJL
continue
end
@info "nout = 1, batch mode = $bmode, $alg"
@test solve(exp_prob, Koopman(), Tsit5(), bmode; quadalg=alg, batch=15)[1] ≈ analytical rtol=1e-2
end
end
end

@testset "Koopman Expectation, nout = 1, batch" begin
@info "Koopman Expectation, nout = 1, batch"
g(sol) = sol[1,end]
analytical = (exp(A*tspan[end])*mean.(u0s_dist))[1]

Expand All @@ -56,7 +100,25 @@ end
end
end

@testset "Koopman solve, nout > 1, batch" begin
@info "Koopman solve, nout > 1, batch"
g(sol) = sol[:,end]
exp_prob = ExpectationProblem(g, u0s_dist, p, prob, 2)
analytical = (exp(A*tspan[end])*mean.(u0s_dist))

for bmode ∈ batchmode
for alg ∈ quadalgs
if alg isa HCubatureJL
continue
end
@info "nout = 2, batch mode = $bmode, $alg"
@test solve(exp_prob, Koopman(), Tsit5(), bmode; quadalg=alg, batch=15) ≈ analytical rtol=1e-2
end
end
end

@testset "Koopman Expectation, nout > 1, batch" begin
@info "Koopman Expectation, nout > 1, batch"
g(sol) = sol[:,end]
analytical = (exp(A*tspan[end])*mean.(u0s_dist))

Expand All @@ -75,7 +137,26 @@ end

############## Koopman AD ###############

@testset "Koopman solve AD" begin
@info "Koopman solve AD"
g(sol) = sol[1,end]
loss = function(p, alg;lb=nothing, ub=nothing)
exp_prob = ExpectationProblem(g, u0s_dist, p, prob;lower_bounds=lb, upper_bounds=ub)
solve(exp_prob, Koopman(), Tsit5(); quadalg=alg)[1]
end
dp1 = FiniteDiff.finite_difference_gradient(p->loss(p, HCubatureJL()),p)
for alg ∈ quadalgs
@info "$alg, ForwardDiff"
alg isa HCubatureJL ?
(@test ForwardDiff.gradient(p->loss(p,alg;lb=[1.,2.],ub=[10.,6.]),p) ≈ dp1 rtol=1e-2) :
(@test_broken ForwardDiff.gradient(p->loss(p,alg),p) ≈ dp1 rtol=1e-2)
@info "$alg, Zygote"
@test Zygote.gradient(p->loss(p,alg),p)[1] ≈ dp1 rtol=1e-2
end
end

@testset "Koopman Expectation AD" begin
@info "Koopman Expectation AD"
g(sol) = sol[1,end]
loss(p, alg) = expectation(g, prob, u0s_dist, p, Koopman(), Tsit5(); quadalg = alg)[1]
dp1 = FiniteDiff.finite_difference_gradient(p->loss(p, HCubatureJL()),p)
Expand All @@ -87,7 +168,29 @@ end
end
end

@testset "Koopman solve AD, batch" begin
@info "Koopman solve AD, batch"
g(sol) = sol[1,end]
loss = function(p, alg, bmode;lb=nothing, ub=nothing)
exp_prob = ExpectationProblem(g, u0s_dist, p, prob;lower_bounds=lb, upper_bounds=ub)
solve(exp_prob, Koopman(), Tsit5(); quadalg=alg)[1]
end
dp1 = FiniteDiff.finite_difference_gradient(p->loss(p, CubatureJLh(), EnsembleThreads()),p)
for bmode ∈ batchmode
for alg ∈ quadalgs
if alg isa HCubatureJL #no batch support
continue
end
@info "$bmode, $alg, ForwardDiff"
@test_broken ForwardDiff.gradient(p->loss(p,alg,bmode),p) ≈ dp1 rtol=1e-2
@info "$bmode, $alg, Zygote"
@test Zygote.gradient(p->loss(p,alg,bmode),p)[1] ≈ dp1 rtol=1e-2
end
end
end

@testset "Koopman Expectation AD, batch" begin
@info "Koopman Expectation AD, batch"
g(sol) = sol[1,end]
loss(p, alg, bmode) = expectation(g, prob, u0s_dist, p, Koopman(), Tsit5(), bmode; quadalg = alg, batch = 10)[1]
dp1 = FiniteDiff.finite_difference_gradient(p->loss(p, CubatureJLh(), EnsembleThreads()),p)
Expand Down Expand Up @@ -121,6 +224,7 @@ prob = ODEProblem(eom!,u0,tspan,p)
u0s_dist = [Uniform(1,10)]

@testset "Koopman Central Moment" begin
@info "Koopman Central Moment"
g(sol) = sol[1,end]
analytical = [0.0, exp(2*p[1]*tspan[end])*var(u0s_dist[1]), 0.0]

Expand All @@ -139,6 +243,7 @@ u0s_dist = [Uniform(1,10)]
end

@testset "Koopman Central Moment, batch" begin
@info "Koopman Central Moment, batch"
g(sol) = sol[1,end]
analytical = [0.0, exp(2*p[1]*tspan[end])*var(u0s_dist[1]), 0.0]

Expand All @@ -158,4 +263,3 @@ end
end
end
end

Loading