Initial implementation for fast conditional simulation---but really just
a jumping-off point. Also added a TODO list to the README for others to
see and (hopefully) take an interest in.
cgeoga committed Dec 27, 2024
1 parent ada3865 commit 04d9fd0
Showing 4 changed files with 79 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "Vecchia"
uuid = "8d73829f-f4b0-474a-9580-cecc8e084068"
authors = ["Chris Geoga <[email protected]>"]
version = "0.9.12"
version = "0.9.13"

[deps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
36 changes: 36 additions & 0 deletions README.md
@@ -195,6 +195,42 @@ really need them myself (at least, not beyond what I can do with this current
pattern) so I'm not feeling super motivated to think hard about the best design
choice.

# Wanted/planned changes (contributions welcome!)

- More docstrings!
- Move several deps to weakdeps. Really, this package should only have
non-stdlib deps of `StaticArrays` and `NearestNeighbors` (or some other kNN
package, see below). All the other deps have to do with estimation, but with
the extension system those could be moved out of the `[deps]` and the package
could be made a lot leaner. This is purely an issue of finding the
time---basically no engineering is required here.
- Changing from `NearestNeighbors.KDTree` to a dynamic object for kNN queries.
As of now, for configurations that pick conditioning points based on nearest
neighbors, an entirely new static tree is constructed in each iteration when a
new point gets added. This isn't actually as slow as you would think because
`NearestNeighbors` is so darn fast, but obviously this can and should be
improved. The ideal alternative is some data structure that takes
`SVector{D,Float64}`s and rebalances to keep an O(k log n) worst-case query. A
real dream would be for the query to also be non-allocating to open the door for
parallelization, but even a fast dynamic tree object would be a big improvement.
- Conditional simulations were recently added, but that implementation would
hugely benefit from somebody kicking the tires and playing with details and
smart defaults/guardrails.
- It would be interesting to at some point benchmark the potential improvement
from using memoization for kernel evaluations. In the rchol approach, there is
the `use_tiles={true,false}` kwarg, which effectively does manual book-keeping
to avoid ever evaluating the kernel for the same pair of points twice. But it
may be more elegant and just as fast to use memoization. This is probably
10-20 lines of code and an hour to benchmark and play with, so it would be a
great first way to tinker with Vecchia stuff.
- API refinement/seeking feedback. For the most part, it is me and students in
my orbit who use this package. But I'd love for it to be more widely adopted,
and so I'd love for the interface to be polished. For example: figuring out
how best to more properly support mean functions would be nice. Another
example: I've just put a bunch of print warnings in the code about permutation
footguns. But obviously it would be better to design the interface so that
there is no chance of a user getting mixed up by that.
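To make the weakdeps item concrete, here is a hypothetical sketch of what the `Project.toml` layout could look like after the migration. The `StaticArrays` and `NearestNeighbors` UUIDs are the registered ones; the extension name `VecchiaEstimationExt` is made up for illustration, and the real migration would move all of the estimation-related deps, not just `ForwardDiff`:

```toml
[deps]
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"

[weakdeps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

[extensions]
VecchiaEstimationExt = ["ForwardDiff"]
```

With this layout, the estimation code would only load when a user does `using ForwardDiff` alongside `Vecchia`, keeping the base package lean.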
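The rebuild-per-point pattern that the kNN bullet describes (and that `cond_sim` in this commit uses) looks roughly like the sketch below. All names here are illustrative; the point is just that every iteration pays the full static tree construction cost, which a dynamic, rebalancing structure would amortize:

```julia
using StaticArrays, NearestNeighbors

pts     = [SVector{2,Float64}(rand(2)) for _ in 1:500]  # "given" points
newpts  = [SVector{2,Float64}(rand(2)) for _ in 1:10]   # points added one at a time
k       = 5

for j in eachindex(newpts)
  # an entirely new static tree is built from scratch every iteration:
  tree = KDTree(vcat(pts, newpts[1:(j-1)]))
  (idxs, dists) = knn(tree, newpts[j], min(k, length(pts) + j - 1))
  # ...sort idxs and use them as a conditioning set, as in cond_sim...
end
```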
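The memoization item really is only a handful of lines. Here is one hypothetical sketch (`MemoKernel` is made up, not part of the package); note that the cache is only valid for a fixed `params`, so a wrapper like this would need to be created fresh for each likelihood evaluation, mirroring what `use_tiles=true` accomplishes within a single `rchol` call:

```julia
# A callable wrapper that caches kernel evaluations by point pair.
struct MemoKernel{K,P,T}
  k::K
  cache::Dict{Tuple{P,P},T}
end
MemoKernel(k, ::Type{P}, ::Type{T}) where {P,T} = MemoKernel(k, Dict{Tuple{P,P},T}())

function (mk::MemoKernel)(x, y, params)
  # get! only calls the kernel if (x, y) has not been seen before.
  get!(mk.cache, (x, y)) do
    mk.k(x, y, params)
  end
end
```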

# Citation

If you use this software in your work, **particularly if you actually use
44 changes: 41 additions & 3 deletions src/predict_sim.jl
@@ -1,7 +1,7 @@

function posterior_cov(vc::VecchiaConfig{H,D,F}, params,
pred_pts::Vector{SVector{D,Float64}};
ncondition=maximum(length, vc.condix)) where{H,D,F}
function dense_posterior(vc::VecchiaConfig{H,D,F}, params,
pred_pts::Vector{SVector{D,Float64}};
ncondition=maximum(length, vc.condix)) where{H,D,F}
# get the given data points and make a tree for fast conditioning set
# collection.
(n, m) = (sum(length, vc.pts), length(pred_pts))
@@ -43,3 +43,41 @@ function posterior_cov(vc::VecchiaConfig{H,D,F}, params,
(;cond_mean, cond_var)
end


function cond_sim(vc::Vecchia.VecchiaConfig{H,D,F}, params,
pred_pts::Vector{SVector{D,Float64}};
ncondition=maximum(length, vc.condix)) where{H,D,F}
if ncondition < 30
@warn "For small numbers of conditioning points (< 30, from slight anecdata), this method can give poor results. See issue #10 on GitHub for more details."
end
# get the given data points and make a tree for fast conditioning set
# collection.
(n, m) = (sum(length, vc.pts), length(pred_pts))
pts = reduce(vcat, vc.pts)
# create the new conditioning set elements for the joint configuration of
# given and prediction points.
jcondix = copy(vc.condix)
sizehint!(jcondix, length(vc.condix) + length(pred_pts)) # could also pre-allocate
# TODO (cg 2024/12/27 10:17): I really would like to switch to a dynamic tree
# object for kNN queries. I expect that that is the clear bottleneck here.
for k in eachindex(pred_pts)
tree = KDTree(vcat(pts, pred_pts[1:(k-1)]))
k_cond_ixs = NearestNeighbors.knn(tree, pred_pts[k], min(ncondition, n+(k-1)))[1]
sort!(k_cond_ixs)
push!(jcondix, k_cond_ixs)
end
# create the augmented/joint point list (for now, just singleton predictions):
jpts = vcat(vc.pts, [[x] for x in pred_pts])
# create the final joint config object. The data being passed in here isn't a
# compliant size, but we'll never touch it.
jcfg = Vecchia.VecchiaConfig(vc.kernel, [hcat(NaN) for _ in eachindex(jpts)],
jpts, jcondix)
Us = sparse(Vecchia.rchol(jcfg, params))
Usnn = Us[1:n, 1:n]
# conditional simulation using standard tricks with Cholesky factors:
data = reduce(vcat, vc.data)
rawwn = randn(length(pred_pts), size(data, 2))
jwn = vcat(Usnn'*data, rawwn)
sims = (Us' \ jwn)[(n+1):end, :]
return sims
end
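One way to read the "standard tricks with Cholesky factors" comment in `cond_sim`, sketched here under the assumption that `rchol` produces a factor of the joint precision, $\Sigma^{-1} = U U^\top$ with

$$
U = \begin{pmatrix} U_{11} & U_{12} \\ 0 & U_{22} \end{pmatrix}
$$

partitioned into data (first $n$) and prediction blocks: solving $U^\top x = (U_{11}^\top y,\; w)$ with $w \sim \mathcal{N}(0, I)$ gives $x_1 = y$ and

$$
x_2 = U_{22}^{-\top}\left(w - U_{12}^\top y\right),
$$

which has mean $-U_{22}^{-\top} U_{12}^\top y = -Q_{22}^{-1} Q_{21} y$ and covariance $(U_{22} U_{22}^\top)^{-1} = Q_{22}^{-1}$ (with $Q = U U^\top$ partitioned conformably). That is precisely the conditional distribution of the prediction block given the data, which is what the `vcat(Usnn'*data, rawwn)` followed by `Us' \ jwn` computes.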

2 changes: 1 addition & 1 deletion test/runtests.jl
@@ -68,7 +68,7 @@ end
data = cholesky([kernel(x, y, (1.0, 0.1)) for x in pts, y in pts]).L*randn(length(pts))

cfg = Vecchia.nosortknnconfig(data, pts, 50, kernel)
(test_cond_mean, test_cond_var) = Vecchia.posterior_cov(cfg, [1.0, 0.1], ppts, ncondition=50)
(test_cond_mean, test_cond_var) = Vecchia.dense_posterior(cfg, [1.0, 0.1], ppts, ncondition=50)

S1 = [kernel(x, y, (1.0, 0.1)) for x in pts, y in pts]
S12 = [kernel(x, y, (1.0, 0.1)) for x in pts, y in ppts]
