
[WIP] Implement vectorized NUTS #228

Closed
sethaxen wants to merge 45 commits into master from iterativenuts

Conversation

sethaxen
Member

@sethaxen sethaxen commented Oct 30, 2020

As discussed in #140, the recursion in the build_tree function used for the NUTS algorithm is problematic for vectorization. This PR introduces an iterative build_tree function that is equivalent to the recursive one.

To facilitate comparison while working on this PR, I have added a type parameter to the signature of NUTS to indicate whether to use recursion or iteration. This is used for benchmarking and regression tests. My intention if this is accepted is to revert that commit and delete the recursive build_tree, making iterative NUTS the default.

In preliminary benchmarks, the iterative and recursive versions are similar in runtime, allocations, and memory usage, though for some target distributions one may be slightly faster than the other. As this gets further along, I'll add more rigorous benchmarks. In the meantime, all regression tests pass, and reviews are welcome.

Edit: Since we are now keeping both recursive and iterative NUTS, this PR now implements vectorized iterative NUTS.
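
For readers unfamiliar with the trick, here is a minimal sketch of how the recursion can be flattened (illustrative names only: take_leapfrog_step and combine stand in for the package's actual leapfrog and tree-merging code, and all u-turn/termination checks are omitted):

    # Visit the 2^j leaves of a depth-j subtree in order, keeping a cache of
    # completed subtrees; merging whenever the leaf index has trailing one-bits
    # reproduces exactly the merges the recursive build_tree would perform.
    function build_tree_iterative(leaf, j)
        cache = typeof(leaf)[]
        tree = leaf
        for i in 0:(2^j - 1)
            # advance from the forward endpoint of whatever we built so far
            i > 0 && (tree = take_leapfrog_step(tree))
            if iseven(i)
                push!(cache, tree)      # this leaf starts a fresh subtree
            else
                # an odd index completes trailing_ones(i) subtrees; merge each
                # with the most recent cached subtree of the same size
                for _ in 1:trailing_ones(i)
                    tree = combine(pop!(cache), tree)
                end
                push!(cache, tree)
            end
        end
        return only(cache)
    end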

@sethaxen sethaxen requested review from yebai and xukai92 October 30, 2020 08:20
@codecov

codecov bot commented Oct 30, 2020

Codecov Report

Merging #228 into master will increase coverage by 0.01%.
The diff coverage is 87.11%.


@@            Coverage Diff             @@
##           master     #228      +/-   ##
==========================================
+ Coverage   89.13%   89.15%   +0.01%     
==========================================
  Files          16       16              
  Lines         672      802     +130     
==========================================
+ Hits          599      715     +116     
- Misses         73       87      +14     
Impacted Files Coverage Δ
src/trajectory.jl 90.03% <85.20%> (-5.94%) ⬇️
src/hamiltonian.jl 95.45% <100.00%> (+1.16%) ⬆️
src/integrator.jl 95.91% <100.00%> (+2.04%) ⬆️
src/utilities.jl 97.05% <100.00%> (+2.05%) ⬆️
src/contrib/forwarddiff.jl 100.00% <0.00%> (+64.28%) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update d435a0b...569eeba.

@sethaxen
Member Author

sethaxen commented Nov 1, 2020

Alternatively, we could keep both the iterative and recursive NUTS implementations and only use the iterative one for vectorized NUTS (dispatching on a matrix of positions). The advantage of this is that we can always check consistency of iterative and recursive NUTS by passing either a vector of positions or an n-by-1 matrix of positions.
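
Something like the following (hypothetical call signatures; h, nuts, and the sample entry point here stand in for whatever ends up dispatching on the position type):

    using Random

    # Hypothetical consistency test: a vector of positions dispatches to the
    # recursive implementation, a d×1 matrix to the vectorized iterative one.
    θ0 = randn(5)
    chain_vec = sample(MersenneTwister(1), h, nuts, θ0, 1_000)
    chain_mat = sample(MersenneTwister(1), h, nuts, reshape(θ0, :, 1), 1_000)
    # with identical RNG seeding the two code paths should agree
    @assert vec(last(chain_mat)) ≈ last(chain_vec)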

src/trajectory.jl
j == 0 && return state′.tree, state′.sampler, state′.termination

# TODO: allocate state_cache in `transition` and reuse for all subtrees
state_cache = Vector{typeof(state′)}(undef, state_cache_size)
Member

Just to check my understanding: the reason we have the -1 in L727 and L815 is that we are using the state_cache in a binary format, so state_cache[end] is the least significant digit.

Member Author

The reverse iteration is because we combine a subtree with the most recent subtree of the same size. And that produces a subtree that must be combined with the next most recently merged subtree, which will always be the previous entry in the cache. subtree_cache_combine_range isn't very intuitive, but it gives us this range, as checked against a simplified recursion. I'll add a suitable test.
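
A self-contained illustration of the range in question (hypothetical function name, mirroring the description above rather than the PR's exact code):

    # After finishing leaf i (0-based), the number of pending merges equals the
    # number of trailing one-bits of i, and they consume the most recent cache
    # entries first, hence a reverse range:
    combine_range(i, cache_len) = cache_len:-1:(cache_len - trailing_ones(i) + 1)

    combine_range(3, 2)  # leaf 0b011 -> 2:-1:1 (merge the two newest entries)
    combine_range(5, 2)  # leaf 0b101 -> 2:-1:2 (merge only the newest entry)
    combine_range(7, 3)  # leaf 0b111 -> 3:-1:1 (merge all three)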

@xukai92
Member

xukai92 commented Nov 1, 2020

I agree we should keep both for a while.
As you said, doing consistency checks on more targets can give us confidence in both implementations.

@sethaxen
Member Author

sethaxen commented Nov 5, 2020

Okay, NUTS is completely vectorized, and for a single chain it produces the same output for recursive and iterative NUTS. I've tested 100 parallel chains on mvnormal and mvcauchy distributions, and the chains all seem to target the right distributions. What's left now is lots of unit and integration tests of individual functions and eliminating redundant function calls. Along the way I'll try to add better tests for recursive NUTS. Since this refactor ended up producing more atomic functions, that should be easier.

I did need to make one functional change to recursive NUTS for consistency with the iterative version. Namely, if a proposed subtree has terminated, we still draw a random number for the Metropolis-Hastings step but still reject the subtree. This is necessary because vectorizing iterative NUTS requires drawing the random number for all chains even if some have already terminated.
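
In pseudocode, the change amounts to something like this (all names illustrative; accept_prob and terminated are per-chain vectors assumed for the sketch):

    using Random
    rng = MersenneTwister(0)
    nchains = 4
    accept_prob = fill(0.5, nchains)          # per-chain MH acceptance probability
    terminated = [false, true, false, false]  # chains whose proposed subtree terminated
    u = rand(rng, nchains)                    # drawn for every chain, even terminated ones,
                                              # so all chains consume the same RNG stream
    is_accept = (u .< accept_prob) .& .!terminated  # terminated chains always reject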

One property of vectorized NUTS is that if the target distribution is heavy-tailed, there's a high chance one of the chains ends up in the tails, and since all chains are synchronized, all chains must wait for that one to terminate before continuing, which destroys performance (this property is known for TFP's vectorized NUTS; see tensorflow/probability#728 (comment)). Hence, in practice a user might switch to multithreading or multiprocessing when they find that a distribution is really slow to sample.

@xukai92
Member

xukai92 commented Nov 8, 2020

I did need to make one functional change to recursive NUTS for consistency with the iterative version. Namely, if a proposed subtree has terminated, we still draw a random number for the Metropolis-Hastings step but still reject the subtree. This is necessary because vectorizing iterative NUTS requires drawing the random number for all chains even if some have already terminated.

Sounds fine.

I will take a closer look at the PR (the new commits) this week.

Member

@xukai92 xukai92 left a comment

Thanks for the work on vectorised operations. One thing I notice is that it currently doesn't support passing a vector of RNGs. Making such an interface work can also help testing: we can pass a vector of the same RNG, and we should expect all the chains to be identical.
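
A sketch of what that interface could enable in tests (hypothetical; the point is just that identically seeded per-chain RNGs should give identical chains):

    using Random

    nchains = 4
    rngs = [MersenneTwister(42) for _ in 1:nchains]  # same seed in every slot
    u = map(rand, rngs)        # one draw per chain; all entries are equal here
    @assert all(==(first(u)), u)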

end

@inline colwise_dot(x::AbstractVector, y::AbstractVector) = dot(x, y)
# TODO: this function needs a custom GPU kernel, e.g. using KernelAbstractions
Member

What does this comment mean?
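
For context, the vectorized counterpart presumably looks something like the method below (an assumption based on the name; the broadcast-then-reduce form allocates a temporary and launches generic kernels on GPU, which is likely why a fused custom kernel is suggested):

    # Column-wise dot products: column k of x and y belongs to chain k.
    @inline colwise_dot(x::AbstractMatrix, y::AbstractMatrix) = vec(sum(x .* y; dims=1))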

ℓw = logaddexp(s1.ℓw, s2.ℓw)
zcand = rand(rng) < exp(s1.ℓw - ℓw) ? s1.zcand : s2.zcand
return MultinomialTS(zcand, ℓw)
end
# Batched method: ℓw holds one log weight per chain, so the log-sum-exp and
# the acceptance comparison are broadcast across chains.
function combine(rng::AbstractRNG, s1::MultinomialTS, s2::MultinomialTS)
ℓw = logaddexp.(s1.ℓw, s2.ℓw)
is_accept = rand.(rng) .< exp.(s1.ℓw .- ℓw)
Member

@xukai92 xukai92 Nov 14, 2020

Why do we have the broadcasted rand for rng here even though we have rng::AbstractRNG?

tree_left = accept!(tree′, tree_left, is_accept)
return tree_left, tree_right
end

"""
Recursively build a tree for a given depth `j`.
"""
function build_tree(
Member

I realise that after this refactoring of the existing NUTS, we cannot check whether the new implementation gives the same results as the old one. Any ideas on how to do this?

@sethaxen
Member Author

sethaxen commented Jan 5, 2021

Thanks @xukai92 for the review. I got busy and have had to put this on hold. I also think the PR might need to be broken into a few more modest PRs.

One reason I paused this was because in informal benchmarks, batched iterative NUTS for N chains was universally slower than running N chains in serial. I saw this happening for two reasons, which I'll document here before I forget:

  1. The abstraction of storing left and right leaves is challenging for batching, because the control flow for picking the next leaf to use as a starting point requires a copy:

         function left_right_subtrees(tree, tree′, v)
             tree_left, tree_right = deepcopy(tree), deepcopy(tree′)
             is_accept = isone.(v)
             tree_right = accept!(tree, tree_right, is_accept)
             tree_left = accept!(tree′, tree_left, is_accept)
             return tree_left, tree_right
         end

     These copies dominated the runtime. We could avoid this control flow by instead storing the front and back leaves, where a subtree is always expanded in the forward direction. Copies would then be needed only when beginning a new tree doubling or when checking whether to accept a doubling (see the sketch after this list). This would require a refactor of the subtrees.
  2. When batching NUTS, tree doublings occur until all chains have terminated. As the number of chains N increases, the probability of any one chain hitting some depth also increases. This is especially bad when the posterior has heavy tails. I thought there would be some trade-off, like for a small number of chains, batched iterative NUTS would be faster, but I did not see this in practice, perhaps because of issue (1). Certainly it seems to be the case that parallelizing chains on a GPU would be a bad idea.
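
A sketch of the front/back idea from point 1 (hypothetical types; the existing code stores left/right leaves instead):

    # A subtree tracks the leaf where it began (front) and the leaf it most
    # recently reached (back); expansion always continues from `back` in the
    # subtree's own direction of travel, so no copy is needed while expanding.
    struct SubtreeEnds{T}
        front::T
        back::T
    end

    # Expanding forward just replaces `back`; `front` is untouched.
    expand(s::SubtreeEnds, new_leaf) = SubtreeEnds(s.front, new_leaf)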

I'd like to revisit this in more detail in the future. For example, perhaps a termination criterion that scales better could be used (e.g. something between "terminate all chains when any chain has terminated" and "terminate only when all chains have terminated"), but we must be careful to ensure that such a criterion targets the same distribution and does so efficiently. I'm sure @junpenglao has had similar thoughts re: TFP's batched iterative NUTS implementation.
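
A rough cost model for point 2 (assumed numbers, purely illustrative): a NUTS transition costs about 2^depth leapfrog steps, and the synchronized batch always runs to the maximum depth over chains.

    # Simulate the synchronization cost for 100 chains with random per-chain
    # termination depths; the gap grows with chain count and tail heaviness.
    depths = rand(1:10, 100)          # depth at which each chain would u-turn
    serial_cost = sum(2 .^ depths)    # leapfrog steps if chains run independently
    batched_cost = length(depths) * 2^maximum(depths)  # all chains step together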

@junpenglao

junpenglao commented Jan 5, 2021

Drive-by comments on:

  1. The abstraction of storing left and right leaves is challenging for batching, because the control flow for picking the next leaf to use as a starting point requires a copy. These copies dominated the runtime. We could avoid this control flow by instead storing the front and back leaves, where a subtree is always expanded in the forward direction. Copies would then be needed only when beginning a new tree doubling or when checking whether to accept a doubling. This would require a refactor of the subtrees.

I am not familiar with Julia code, but probably you could avoid copying by doing some indexing and in-place updates of the tensor.
Also, I am not sure if it is still valid to only extend the subtree in one direction; from Betancourt's HMC paper: "In order to ensure a valid correction, this transition from states to trajectories has to exhibit a certain reversibility. Formally we require that the probability of transitions to a trajectory is the same regardless of from which state in that trajectory we might begin" (page 49 of https://arxiv.org/pdf/1701.02434.pdf).

@sethaxen
Member Author

sethaxen commented Jan 6, 2021

I am not familiar with Julia code, but probably you could avoid copying by doing some indexing and in-place updates of the tensor.

It's true, there may be other refactors that can avoid the copy; this was just an obvious one. Thanks for the links. I'm not familiar with TensorFlow code but will digest it when I work on this again.

Also, I am not sure if it is still valid to only extend the subtree in one direction; from Betancourt's HMC paper: "In order to ensure a valid correction, this transition from states to trajectories has to exhibit a certain reversibility. Formally we require that the probability of transitions to a trajectory is the same regardless of from which state in that trajectory we might begin" (page 49 of https://arxiv.org/pdf/1701.02434.pdf).

Perhaps we're talking about different things? The NUTS algorithm proceeds by tree doubling. Each time the tree is doubled, a Bernoulli random variable determines whether we double on the left or the right. Let's call our original tree A and the new tree B. Tree B is constructed in the same way as A (i.e. via tree doubling with all of the same u-turn checks before accepting new subtrees), except that it is only extended in one direction (otherwise it could overlap A). Informally, as long as tree B does not encounter a termination condition internally or when merged with A, then given the merged tree A+B, it is not possible to determine whether tree A or B was the original tree. Due to the recursive nature of the tree doubling, we cannot determine which point on the trajectory was the starting point, so we have reversibility. The extension of B in one direction to balance the tree is what I am referring to.

@junpenglao

The extension of B in one direction to balance the tree is what I am referring to.

Maybe it is more an implementation detail in your PR, then. In isolation, B is always in one direction, as you select the starting point (left or right of A) and the direction of the momentum (-1 if left, +1 if right). After B is complete, you add it to A (reverse(B)+A or A+B, where + is concatenation conceptually). So IIUC you are saying you will only do A+B, which I worry might invalidate the termination criteria.

@sethaxen
Member Author

sethaxen commented Jan 6, 2021

After B is complete, you add it to A (reverse(B)+A or A+B, where + is concatenation conceptually). So IIUC you are saying you will only do A+B, which I worry might invalidate the termination criteria.

Ah, no, that is not what I am saying. reverse(B)+A or A+B would be the operations, implicitly.
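
In pseudocode, one doubling step as described in this exchange (all names illustrative; build_subtree, endpoint, and concat are stand-ins):

    # Each doubling grows a new tree B of the same size as A on a random side;
    # B is built only in the direction away from A, and the merge is then
    # reverse(B)+A or A+B depending on the side chosen.
    dir = rand(rng, Bool) ? +1 : -1                # Bernoulli choice of side
    B   = build_subtree(endpoint(A, dir), depth, dir)
    A   = dir == +1 ? concat(A, B) : concat(reverse(B), A)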

@xukai92
Member

xukai92 commented Jan 6, 2021

Thanks @xukai92 for the review. I got busy and have had to put this on hold. I also think the PR might need to be broken into a few more modest PRs.

No worries. Maybe we can start with a PR that only updates the interfaces/abstractions for the recursive NUTS without modifying any functionality? We can make sure everything stays the same with the new interfaces. Then pushing another PR to add the vectorised version would be smoother.

@yebai yebai closed this Nov 1, 2024
@yebai yebai deleted the iterativenuts branch November 1, 2024 19:23
@yebai
Member

yebai commented Nov 1, 2024

This is getting a bit out of date from the current master branch. Please feel free to open a new PR if needed.
