[WIP] Implement vectorized NUTS #228
Conversation
Codecov Report
@@            Coverage Diff             @@
##           master     #228      +/-   ##
==========================================
+ Coverage   89.13%   89.15%    +0.01%
==========================================
  Files          16       16
  Lines         672      802      +130
==========================================
+ Hits          599      715      +116
- Misses         73       87       +14
==========================================
Continue to review full report at Codecov.
Alternatively, we could keep both the iterative and recursive NUTS implementations and only use the iterative one for vectorized NUTS (dispatching on a matrix of positions). The advantage of this is that we can always check consistency of iterative and recursive NUTS by passing either a vector of positions or an n-by-1 matrix of positions.
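As a toy illustration of the dispatch idea (the function name here is made up, not the package API):

```julia
# Vector of positions ⇒ single-chain (recursive) path;
# matrix of positions (one column per chain) ⇒ vectorized (iterative) path.
tree_builder(θ0::AbstractVector) = :recursive
tree_builder(θ0::AbstractMatrix) = :iterative

θ = randn(5)
tree_builder(θ)                  # :recursive
tree_builder(reshape(θ, :, 1))   # :iterative — the same chain as an n-by-1 matrix
```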
j == 0 && return state′.tree, state′.sampler, state′.termination

# TODO: allocate state_cache in `transition` and reuse for all subtrees
state_cache = Vector{typeof(state′)}(undef, state_cache_size)
Just to check my understanding: the reason we have the `-1` in L727 and L815 is because we are using the `state_cache` in a binary format, thus `state_cache[end]` is the least significant digit.
The reverse iteration is because we combine a subtree with the most recent subtree of the same size, and that produces a subtree that must be combined with the next most recently merged subtree, which will always be the previous entry in the cache. `subtree_cache_combine_range` isn't very intuitive, but it gives us this range, as checked against a simplified recursion. I'll add a suitable test.
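For intuition, here is a hedged sketch of the binary-counter bookkeeping (the function and layout below are assumptions for illustration, not the PR's `subtree_cache_combine_range`): after building leaf `i` (0-indexed), the number of cache entries to merge with it equals the number of trailing ones in `i`, and with `state_cache[end]` holding the smallest subtree those entries form a reverse range at the end of the cache.

```julia
# Hypothetical stand-in for the combine-range logic, assuming the newest
# (smallest) subtree sits at the end of the cache.
function combine_range_sketch(i::Int, cache_len::Int)
    n = trailing_ones(i)                      # consecutive merges triggered by leaf i
    return cache_len:-1:(cache_len - n + 1)   # reverse range over the most recent entries
end

combine_range_sketch(3, 3)   # 3:-1:2 — leaf 0b011 merges with the two newest entries
combine_range_sketch(7, 3)   # 3:-1:1 — leaf 0b111 collapses the whole cache
combine_range_sketch(4, 3)   # empty  — leaf 0b100 starts a new subtree, no merges
```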
I agree we should keep both for a while.
This reverts commit 2c59d70.
For chains that have already terminated, continued leapfrogs are still parallelized, but the resulting subtrees are always discarded.
Okay, NUTS is completely vectorized, and for a single chain it produces the same output for recursive and iterative NUTS. I've tested with 100 parallel chains on mvnormal and mvcauchy distributions, and the chains all seem to target the right distributions. What's left now is lots of unit and integration tests of individual functions, plus trying to eliminate redundant function calls. Along the way I'll try to add better tests for recursive NUTS; since this refactor ended up producing more atomic functions, that should be easier.

I did need to make one functional change to recursive NUTS for consistency with the iterative version: if a proposed subtree has terminated, we still draw a random number for the Metropolis-Hastings step, but we still reject the tree. This is necessary because vectorizing iterative NUTS requires drawing the random number for all chains even if some have already terminated.

One property of vectorized NUTS is that if the target distribution is heavy-tailed, there's a high chance one of the chains ends up in the tails, and since all chains are synchronized, every chain must wait for that one to terminate before continuing, which destroys performance (this property is known for TFP's vectorized NUTS, see tensorflow/probability#728 (comment)). Hence, in practice a user might switch to multithreading or parallel processing when they find that a distribution is really slow to sample.
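A minimal sketch (not the PR's code) of the synchronization point described above: the Metropolis-Hastings draw is made for every chain so the RNG stream stays aligned, and a mask then forces rejection for chains whose trajectories have already terminated.

```julia
using Random

function masked_accept(rng::AbstractRNG, ℓw_new::AbstractVector, ℓw_total::AbstractVector,
                       terminated::AbstractVector{Bool})
    u = rand(rng, length(ℓw_total))        # one uniform draw per chain, even for terminated ones
    is_accept = u .< exp.(ℓw_new .- ℓw_total)
    return is_accept .& .!terminated       # chains that already terminated always reject
end

masked_accept(MersenneTwister(1), [-1.0, -2.0], [-0.5, -0.5], [false, true])
# the second chain can never be accepted because it has already terminated
```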
Sounds fine. I will take a closer look at the PR (the new commits) this week.
Thanks for the work on vectorised operations. One thing I notice is that it currently doesn't support passing a vector of RNGs. Making such an interface work would also help with testing: we could pass a vector of identically seeded RNGs and expect all the chains to be identical.
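To illustrate the suggestion, a minimal sketch in plain Julia (not the package API) of the vector-of-RNGs pattern: one RNG per chain, all seeded identically, so per-chain draws agree across chains.

```julia
using Random

# One RNG per chain, all seeded identically; broadcasting `rand` over the
# vector gives one draw per chain, and identical seeds give identical chains.
nchains = 4
rngs = [MersenneTwister(1) for _ in 1:nchains]

u = rand.(rngs)              # one uniform draw per chain
@assert all(==(u[1]), u)     # identical seeds ⇒ identical draws across chains
```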
end

@inline colwise_dot(x::AbstractVector, y::AbstractVector) = dot(x, y)
# TODO: this function needs a custom GPU kernel, e.g. using KernelAbstractions
What does this comment mean?
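For context, `colwise_dot` on matrices presumably computes one dot product per column (i.e. per chain). A sketch of such a method (not the PR's code) and why a fused kernel might help:

```julia
using LinearAlgebra

# The broadcast version allocates the intermediate `x .* y`, which is one
# reason a custom fused GPU kernel (e.g. via KernelAbstractions) could help.
colwise_dot_sketch(x::AbstractMatrix, y::AbstractMatrix) = vec(sum(x .* y; dims=1))

x, y = randn(3, 4), randn(3, 4)
@assert colwise_dot_sketch(x, y) ≈ [dot(x[:, j], y[:, j]) for j in 1:size(x, 2)]
```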
    ℓw = logaddexp(s1.ℓw, s2.ℓw)
    zcand = rand(rng) < exp(s1.ℓw - ℓw) ? s1.zcand : s2.zcand
    return MultinomialTS(zcand, ℓw)
end
function combine(rng::AbstractRNG, s1::MultinomialTS, s2::MultinomialTS)
    ℓw = logaddexp.(s1.ℓw, s2.ℓw)
    is_accept = rand.(rng) .< exp.(s1.ℓw .- ℓw)
Why do we have the broadcast `rand.` for `rng` here even though we have `rng::AbstractRNG`?
    tree_left = accept!(tree′, tree_left, is_accept)
    return tree_left, tree_right
end

"""
Recursively build a tree for a given depth `j`.
"""
function build_tree(
I realise that after this refactoring of the existing NUTS, we cannot check whether the new implementation gives the same results as the old one. Any ideas on how to do this?
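One generic pattern that could work (the functions below are placeholders, not AdvancedHMC APIs) is to run both implementations with identically seeded RNGs and require the exact same draws:

```julia
using Random, Test

# Placeholder stand-ins for the recursive and iterative code paths; the point
# is only the test pattern: identical seeds should give identical samples.
old_impl(rng, n) = [randn(rng) for _ in 1:n]
new_impl(rng, n) = map(_ -> randn(rng), 1:n)

@test old_impl(MersenneTwister(42), 10) == new_impl(MersenneTwister(42), 10)
```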
Thanks @xukai92 for the review. I got busy and have had to put this on hold. I also think the PR might need to be broken into a few more modest PRs. One reason I paused this was because in informal benchmarks, batched iterative NUTS for N chains was universally slower than running N chains in serial. I saw this happening for two reasons, which I'll document here before I forget:
I'd like to revisit this in more detail in the future. For example, perhaps a different termination criterion that scales better could be used (e.g. something between "terminate all chains when any chain has terminated" and "terminate only when all chains have terminated"), but we must be careful to ensure that such a criterion targets the same distribution and does so efficiently. I'm sure @junpenglao has had similar thoughts RE TFP's batched iterative NUTS implementation.
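To make the two extremes (and a hypothetical middle ground) concrete, they can be written as reductions over a per-chain termination mask; whether an intermediate rule still targets the correct distribution is exactly the open question above.

```julia
using Statistics

# Batched stopping rules as reductions over a per-chain termination mask.
stop_when_any(terminated::AbstractVector{Bool}) = any(terminated)            # stop as soon as one chain terminates
stop_when_all(terminated::AbstractVector{Bool}) = all(terminated)            # run until every chain has terminated
stop_when_frac(terminated::AbstractVector{Bool}, p) = mean(terminated) >= p  # hypothetical middle ground

stop_when_frac([true, true, false, false], 0.5)  # true: half the chains are done
```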
Drive-by comments on:
I am not familiar with Julia code, but you could probably avoid copying by doing some indexing and in-place updates of the tensor.
It's true, there may be other refactors that avoid the copy; this was just an obvious one. Thanks for the links. I'm not familiar with TensorFlow code but will go through them when I work on this again.
Perhaps we're talking about different things? The NUTS algorithm proceeds by tree doubling. Each time the tree is doubled, a Bernoulli random variable determines whether we double on the left or the right. Let's call our original tree A and the new tree B. Tree B is constructed in the same way as A (i.e. via tree doubling with all of the same U-turn checks before accepting new subtrees), except that it is only extended in one direction (otherwise it could overlap A). Informally, as long as tree B does not encounter a termination condition internally or when merged with A, then given the merged tree A+B it is not possible to determine whether tree A or B was the original tree. Due to the recursive nature of the tree doubling, we cannot determine which point on the trajectory was the starting point, so we have reversibility. The extension of B in one direction to balance the tree is what I am referring to.
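As a minimal sketch of the doubling loop being described (assumed names, not the PR's code), the direction draw and one-sided extension look roughly like this:

```julia
using Random

function doubling_directions(rng::AbstractRNG, max_depth::Int)
    left, right = 0, 0                     # leftmost/rightmost leaf positions so far
    for depth in 0:max_depth-1
        go_right = rand(rng, Bool)         # Bernoulli draw for the direction of tree B
        if go_right
            right += 2^depth               # B has 2^depth leaves, all to the right
        else
            left -= 2^depth                # or all to the left
        end
        # real NUTS would build B here with the usual U-turn checks and stop
        # early if B terminates internally or the merged tree A+B makes a U-turn
    end
    return left, right
end

doubling_directions(MersenneTwister(0), 3)  # right - left + 1 == 2^3 leaf positions
```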
Maybe it is more of an implementation detail in your PR then - in isolation, B is always in one direction, as you select the starting point (left or right of A) and the direction of the momentum (-1 if left, 1 if right). After B is complete, you add that to A (
Ah, no, that is not what I am saying.
No worries. Maybe we can start with a PR that only updates the interfaces/abstractions for the recursive NUTS without modifying any functionality? We can make sure that everything stays the same with the new interfaces. Then pushing another PR to add the vectorised version would be smoother.
This is getting a bit out of date with the current master branch. Please feel free to open a new PR if needed.
As discussed in #140, the recursion in the `build_tree` function used for the NUTS algorithm is problematic for vectorization. This PR introduces an iterative `build_tree` function that is equivalent to the recursive one.

To facilitate comparison while working on this PR, I have added a type parameter to the signature of `NUTS` to indicate whether to use recursion or iteration. This is used for benchmarking and regression tests. My intention, if this is accepted, is to revert that commit and delete the recursive `build_tree`, making iterative NUTS the default.

In preliminary benchmarks, the iterative and recursive versions are similar in runtime, allocations, and memory usage, though for some target distributions one may be slightly faster than the other. As this gets further along, I'll add more rigorous benchmarks. All regression tests currently pass, and reviews are welcome.

Edit: Since we are now keeping both recursive and iterative NUTS, this PR now implements vectorized iterative NUTS.
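A hypothetical illustration of the type-parameter idea (none of these names are the package's actual types): the strategy is carried in the sampler's type, so dispatch picks the matching tree-building variant.

```julia
# Hypothetical types for illustration only; not AdvancedHMC's actual API.
abstract type TreeBuildStrategy end
struct Recursive <: TreeBuildStrategy end
struct Iterative <: TreeBuildStrategy end

struct NUTSSketch{S<:TreeBuildStrategy}
    max_depth::Int
end

build_tree_sketch(::NUTSSketch{Recursive}) = "recursive build_tree"
build_tree_sketch(::NUTSSketch{Iterative}) = "iterative build_tree"

build_tree_sketch(NUTSSketch{Iterative}(10))  # selects the iterative variant
```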