Merge branch 'main' into vllm
benlebrun authored Jun 20, 2024
2 parents bca561c + 92c2b60 commit dc48360
Showing 6 changed files with 434 additions and 278 deletions.
50 changes: 30 additions & 20 deletions genparse/proposal/character.py
@@ -5,6 +5,8 @@
from genparse.proposal.trie_numba import TokenCharacterTrie
from genparse.semiring import Float

from inspect import iscoroutinefunction


class CharacterProposal(TokenCharacterTrie):
"""
@@ -58,6 +60,7 @@ def sample(
):
context = ''
W = 1
P = 1
t = 0
while True:
t += 1
@@ -66,28 +69,29 @@
p_llm = self.llm.p_next(prompt + context)
with self.timer['cfg+trie'](t=len(context)):
self._update_trie(p_llm)
token, weight_update = self._guided_sample_trie(
token, proposal_p, weight_update = self._guided_sample_trie(
context, verbosity=verbosity, draw=draw, **kwargs
)
else:
token = self.guide.eos
weight_update = 1
proposal_p = 1
W *= weight_update
P *= proposal_p
if self.guide.eos == token:
break
if verbosity > 0:
print(colors.cyan % token, end=colors.magenta % '|')
context += token
if verbosity > 0:
print()
return (context, W)
return (context, P, W)

async def sample_next_token(
self,
prompt,
context,
verbosity=0,
compare_time=False,
correct_weights=True,
draw=sample_dict,
execute_model_req=None,
@@ -100,28 +104,31 @@ async def sample_next_token(
prompt : The LLM prompt.
context : The previous generated tokens.
verbosity : > 1 prints sampling process.
compare_time : true compares time spent in LLM to cfg+trie.
correct_weights : whether to correct the importance weights with RAVI.
false leads to probabilistically incorrect inference.
Returns:
token : Proposed LLM token.
weight_update : Incremental SMC weight update.
"""
with self.timer['llm'](t=len(context)):

if iscoroutinefunction(self.llm.p_next):
p_llm = await self.llm.p_next(prompt + context, execute_model_req=execute_model_req)
with self.timer['cfg+trie'](t=len(context)):
self._update_trie(p_llm)
if correct_weights:
(token, weight_update) = self._guided_sample_trie(
context, verbosity=verbosity, draw=draw, **kwargs
)
else:
(token, weight_update) = self._guided_sample_trie_uncorrected(
context, verbosity=verbosity, draw=draw, **kwargs
)
if compare_time:
self.timer.compare()
return (token, weight_update)
else:
p_llm = self.llm.p_next(prompt + context, execute_model_req=execute_model_req)

self._update_trie(p_llm)

if correct_weights:
(token, proposal_p, weight_update) = self._guided_sample_trie(
context, draw=draw, verbosity=verbosity, **kwargs
)
else:
(token, proposal_p, weight_update) = self._guided_sample_trie_uncorrected(
context, draw=draw, verbosity=verbosity, **kwargs
)

return (token, proposal_p, weight_update)


def __deepcopy__(self, memo):
cpy = type(self).__new__(type(self))
@@ -165,6 +172,7 @@ def _guided_sample_trie(self, context, draw, verbosity=0):

inclusion_prob = 1 # path prefix probability
cfg_prob = 1
proposal_p = 1 # probability of trace

weights = Float.chart()

Expand Down Expand Up @@ -216,6 +224,7 @@ def _guided_sample_trie(self, context, draw, verbosity=0):
a = draw(q)
inclusion_prob *= q[a]
cfg_prob *= p2[a]
proposal_p *= q[a]

curr = children_curr[a]

@@ -231,13 +240,14 @@ def _guided_sample_trie(self, context, draw, verbosity=0):
print(colors.light.green % 'token probs:', normalized_weights)

token = draw(normalized_weights)
proposal_p *= normalized_weights[token]
weight_update = weights.sum()

if verbosity > 1:
print(colors.orange % 'sampled token=', repr(token))
print(colors.orange % 'weight update=', weight_update)

return (token, weight_update)
return (token, proposal_p, weight_update)

def _guided_sample_trie_uncorrected(self, context, draw, verbosity=0):
"""
@@ -319,4 +329,4 @@ def _guided_sample_trie_uncorrected(self, context, draw, verbosity=0):

weight_update = (llm_prob * guide_prob) / proposal_prob

return (token, weight_update)
return (token, proposal_prob, weight_update)
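As a sanity check on the uncorrected update computed just above, weight_update = (llm_prob * guide_prob) / proposal_prob, here is a toy numerical sketch (not part of the diff; all probabilities are made up):

# Toy numbers (made up) for the uncorrected importance weight
# w = p_llm(token) * p_cfg(token) / q(token), as returned by
# _guided_sample_trie_uncorrected.
llm_prob = 0.10       # token probability under the language model
guide_prob = 0.50     # token probability under the CFG guide
proposal_prob = 0.25  # probability that the proposal sampled this token
weight_update = (llm_prob * guide_prob) / proposal_prob
assert abs(weight_update - 0.2) < 1e-12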
72 changes: 52 additions & 20 deletions genparse/proposal/token.py
@@ -6,6 +6,8 @@
from genparse.proposal.trie import TokenCharacterTrie
from genparse.semiring import Float

from inspect import iscoroutinefunction

# TODO: It's tempting to require proposal distributions to implement the `LM`
# interface, but it might be difficult to correctly implement `__call__` and
# `p_next` as proposal distributions may only be distributions over sample paths
@@ -39,35 +41,65 @@ def __init__(self, *, llm, guide, K=None):

super().__init__(words, old_eos=llm.eos, new_eos=guide.eos)


def _p_next(self, context, K=None, execute_model_req=None, **kwargs):
with self.timer['llm'](t=len(context)):
p_llm = self.llm.p_next(self._prompt + context, execute_model_req=execute_model_req)
p_llm = self.llm.p_next(self._prompt + context, execute_model_req=execute_model_req)
return Float.chart(take(K, self.traverse_trie(context, p_llm))).normalize()


with self.timer['cfg+trie'](t=len(context)):
return Float.chart(take(K, self.traverse_trie(context, p_llm))).normalize()
async def sample_next_token(self, prompt, context, draw=sample_dict, **kwargs):
"""
Proposes a token and an incremental weight update.
The following procedure, justified using RAVI, describes how we sample a token and compute the incremental SMC weight update.
1. Sample a subset S of size K of the token vocabulary by
a. enumerating the top K - 1 tokens
b. sampling a wildcard token from the remainder of the vocabulary proportional to p_llm(x)
2. Compute the *unnormalized target* p(x) = p_llm(x) p_cfg(x) for each x \in S.
3. Compute the (local) weight w(x) of each token as p(x) / Pr(x \in S), where Pr(x \in S) is the *inclusion probability*.
* Pr(x \in S) = 1 if x is in the top K - 1
* Pr(x \in S) \propto p_llm(x) for the wildcard token
4. Renormalize the weights of the tokens in S and sample one of them.
5. Set the incremental SMC weight update w' = \sum_{x \in S} w(x).
"""

async def sample_next_token(
self, prompt, context, verbosity=0, compare_time=False, draw=sample_dict,
self, prompt, context, verbosity=0, draw=sample_dict,
execute_model_req=None,
**kwargs
):
with self.timer['llm'](t=len(context)):
if iscoroutinefunction(self.llm.p_next):
p_llm = await self.llm.p_next(prompt + context, execute_model_req=execute_model_req)
else:
p_llm = self.llm.p_next(prompt + context, execute_model_req=execute_model_req)

# enumerate top K - 1 tokens
Ws = Float.chart(take(self.K - 1, self.traverse_trie(context, p_llm)))

# sample wildcard token from p_llm
P_wc = Float.chart({x: p for x, p in p_llm.items() if x not in Ws}).normalize()
wildcard = draw(P_wc)
proposal_p = P_wc[wildcard]

# compute wildcard weight
p_cfg_wc = 1
for i, c in enumerate(wildcard):
p_cfg_wc *= self.guide.p_next(context + wildcard[:i])[c]
Ws[wildcard] = (
p_llm[self.old_eos if wildcard == self.new_eos else wildcard]
* p_cfg_wc
/ P_wc[wildcard]
)

# sample token from weights and compute update
Ws_norm = Ws.normalize()
token = draw(Ws_norm)
proposal_p *= Ws_norm[token]
weight_update = Ws.sum()

return (token, proposal_p, weight_update)

with self.timer['cfg+trie'](t=len(context)):
Q = Float.chart(
take(self.K - 1, self.traverse_trie(context, p_llm))
).normalize()
token = draw(Q)

llm_prob = p_llm[self.old_eos if token == self.new_eos else token]
guide_prob = self._p_guide[token]

if compare_time:
self.timer.compare()

# temp fix because hfppl step now requires only two return values
return (token, llm_prob * guide_prob / Q[token])

def _update_internal(self):
# overrides base method. Takes max rather than sum of internal nodes
6 changes: 2 additions & 4 deletions genparse/steer.py
@@ -382,10 +382,8 @@ def __init__(self, llm, guide, proposal, prompt, max_tokens, verbosity=0):
self.verbosity = verbosity

async def step(self):
(token, weight_update) = await self.proposal.sample_next_token(
prompt=self.prompt,
context=''.join(self.context),
compare_time=(self.verbosity > 1),
(token, _, weight_update) = await self.proposal.sample_next_token(
prompt=self.prompt, context=''.join(self.context)
)
self.context.append(token)
self.weight += np.log(weight_update)
(Diffs for the remaining 3 changed files were not loaded.)
