Weight correction for CharacterProposal (#5)
benlebrun committed Jun 18, 2024
1 parent 5536871 commit cc0a898
Showing 6 changed files with 587 additions and 118 deletions.
241 changes: 145 additions & 96 deletions genparse/proposal/character.py
@@ -55,7 +55,7 @@ def __init__(self, *, llm, guide):

def sample(self, prompt, max_tokens=float('inf'), verbosity=0, **kwargs):
context = ''
- P = 1
+ W = 1
t = 0
while True:
t += 1
@@ -64,13 +64,13 @@ def sample(self, prompt, max_tokens=float('inf'), verbosity=0, **kwargs):
p_llm = self.llm.p_next(prompt + context)
with self.timer['cfg+trie'](t=len(context)):
self._update_trie(p_llm)
- token, p_token, _, _ = self._guided_sample_trie(
+ token, weight_update = self._guided_sample_trie(
self.root, context, verbosity=verbosity, **kwargs
)
else:
token = self.guide.eos
- p_token = 1
- P *= p_token
+ weight_update = 1
+ W *= weight_update
if self.guide.eos == token:
break
if verbosity > 0:
@@ -79,21 +79,46 @@ def sample(self, prompt, max_tokens=float('inf'), verbosity=0, **kwargs):
if verbosity > 0:
print()
self.timer.compare()
- return (context, P)
+ return (context, W)

async def sample_next_token(
- self, prompt, context, verbosity=0, compare_time=False, **kwargs
+ self,
+ prompt,
+ context,
+ verbosity=0,
+ compare_time=False,
+ correct_weights=True,
+ **kwargs,
):
"""
Proposes a token and incremental weight update.
Args:
prompt : The LLM prompt.
context : The previous generated tokens.
verbosity : > 1 prints sampling process.
compare_time : true compares time spent in LLM to cfg+trie.
correct_weights : whether to correct the importance weights with RAVI.
false leads to probabilistically incorrect inference.
Returns:
token : Proposed LLM token.
weight_update : Incremental SMC weight update.
"""
with self.timer['llm'](t=len(context)):
p_llm = await self.llm.p_next(prompt + context)
with self.timer['cfg+trie'](t=len(context)):
self._update_trie(p_llm)
- (path, llm_prob, guide_prob, proposal_prob) = self._guided_sample_trie(
- self.root, context, verbosity=verbosity, **kwargs
- )
+ if correct_weights:
+ (token, weight_update) = self._guided_sample_trie(
+ self.root, context, verbosity=verbosity, **kwargs
+ )
+ else:
+ (token, weight_update) = self._guided_sample_trie_uncorrected(
+ self.root, context, verbosity=verbosity, **kwargs
+ )
if compare_time:
self.timer.compare()
- return (path, llm_prob, guide_prob, proposal_prob)
+ return (token, weight_update)

def __deepcopy__(self, memo):
cpy = type(self).__new__(type(self))
@@ -117,6 +142,111 @@ def __deepcopy__(self, memo):
return cpy

def _guided_sample_trie(self, root, context, draw=sample_dict, verbosity=0):
"""
Samples a token from the trie and computes the incremental SMC weight update.
The procedure below, justified using RAVI, determines how the token is sampled and how the weight update is computed.
1. Sample a subset $S$ of the token vocabulary by sampling a path through the trie.
2. Compute the *unnormalized target* $p(x)$ of each $x \in S$ as $p_\text{LLM}(x)\,p_\text{CFG}(x)$.
* $p_\text{LLM}(x)$ is the mass at the corresponding leaf of the trie;
* $p_\text{CFG}(x)$ is the product of the guide's next-character probabilities along the path to $x$.
3. Compute the (local) weight $w(x)$ of each token as $\frac{p(x)}{\Pr(x \in S)}$, where $\Pr(x \in S)$ is the *inclusion probability*.
* In the character proposal, $\Pr(x \in S)$ is the probability of the path prefix leading to $x$.
4. Normalize the weights of the tokens in $S$ and sample one of them.
5. Return the incremental SMC weight update $w' = \sum_{x \in S} w(x)$.
"""
curr = root
path = []

inclusion_prob = 1 # path prefix probability
cfg_prob = 1

weights = Float.chart()

children = self.children
mass = self.mass

if verbosity > 1:
print(colors.line(80))
while True:
children_curr = children[curr]
mass_curr = mass[curr]

p1 = Float.chart((a, mass[c] / mass_curr) for a, c in children_curr.items())

p2 = self.guide.p_next(context + ''.join(path)).trim()

if None in p1:
token = ''.join(path)

weights[token] = (mass[children_curr[None]] * cfg_prob) / inclusion_prob

if verbosity > 1:
print(
colors.blue % 'ADDED TOKEN TO S',
repr(token),
'weight=',
weights[token],
'token prob=',
mass[children_curr[None]] * cfg_prob,
'inclusion prob=',
inclusion_prob,
)

_q = (p1 * p2).trim()

if verbosity > 1:
print(colors.yellow % 'calling context=', repr(''.join(context)))
print(colors.yellow % 'partial token=', repr(''.join(path)))
if not _q:
print('llm (top 10) =', p1.top(10))
print('guide (top 10) =', p2.top(10))
print('_q (top 10) =', _q.top(10))

if not _q:
break

q = _q.normalize()

a = draw(q)
inclusion_prob *= q[a]
cfg_prob *= p2[a]

curr = children_curr[a]

if verbosity > 1:
print(colors.orange % 'action', repr(a), 'context', repr(''.join(path)))

path.append(a)

normalized_weights = weights.normalize()

if verbosity > 1:
print(colors.light.green % 'token weights:', weights)
print(colors.light.green % 'token probs:', normalized_weights)

token = draw(normalized_weights)
weight_update = weights.sum()

if verbosity > 1:
print(colors.orange % 'sampled token=', repr(token))
print(colors.orange % 'weight update=', weight_update)

return (token, weight_update)

def _guided_sample_trie_uncorrected(
self, root, context, draw=sample_dict, verbosity=0
):
"""
This function samples a token from the trie and computes the incremental weight update.
WARNING: This function is probabilistically incorrect; it produces biased estimates.
The returned weight update is given as p_llm(x) * p_cfg(x) / q(x,S) where x is the proposed token
and S is the path through the trie from which x is sampled.
"""
curr = root
path = []
guide_prob = 1
@@ -177,96 +307,15 @@ def _guided_sample_trie(self, root, context, draw=sample_dict, verbosity=0):
if verbosity > 1:
print(colors.light.green % 'p exits:', exits)

- path = draw(exits)
+ token = draw(exits)

if verbosity > 1:
print(colors.orange % 'picked exit', repr(path))

- proposal_prob *= exits[path]
+ proposal_prob *= exits[token]

- llm_prob = mass[self.word2leaf[path]]
+ llm_prob = mass[self.word2leaf[token]]

- return (path, llm_prob, guide_prob, proposal_prob)
+ weight_update = (llm_prob * guide_prob) / proposal_prob
+
+ return (token, weight_update)

- def _enumerate_paths(self, context):
- # Used for debugging
- # MAKE SURE TO CALL proposal._update_trie(p_llm) BEFORE RUNNING
-
- curr = self.root
- children = self.children
- mass = self.mass
- paths = []
- exits = Float.chart()
-
- def _enum_paths(chars, trace, children_curr, mass_curr, proposal_prob, exits):
- p1 = Float.chart((a, mass[c] / mass_curr) for a, c in children_curr.items())
- p2 = self.guide.p_next(context + ''.join(chars)).trim()
-
- if None in p1:
- exits[''.join(chars)] = mass[children_curr[None]]
-
- _q = (p1 * p2).trim()
-
- if not _q:
- # no more paths to explore
- exits = exits.normalize()
- these_paths = []
- for token, exit_p in exits.items():
- new_trace = trace.copy()
- new_trace.append(
- {
- 'name': 'exit',
- 'outcome': token,
- 'prob': exit_p,
- 'dist': exits,
- }
- )
- these_paths.append(
- {
- 'token': token,
- 'proposal_prob': proposal_prob * exit_p,
- 'trace': new_trace,
- }
- )
- return these_paths
- else:
- # keep exploring paths
- q = _q.normalize()
- for a, q_prob in q.items():
- curr = children_curr[a]
- new_chars = chars.copy()
- new_chars.append(a)
-
- new_exits = exits.copy()
-
- new_trace = trace.copy()
- new_trace.append(
- {
- 'name': f'char {len(new_chars)}',
- 'outcome': a,
- 'prob': q_prob,
- 'dist': q,
- }
- )
- paths.extend(
- _enum_paths(
- chars=new_chars,
- trace=new_trace,
- children_curr=children[curr],
- mass_curr=mass[curr],
- proposal_prob=proposal_prob * q_prob,
- exits=new_exits,
- )
- )
- return []
-
- _enum_paths(
- chars=[],
- children_curr=children[curr],
- mass_curr=mass[curr],
- proposal_prob=1,
- exits=exits,
- trace=[],
- )
-
- return paths
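
Aside (editorial, not part of this commit): the RAVI-corrected procedure documented in _guided_sample_trie can be checked on a toy example. The sketch below uses plain Python dicts instead of Float.chart and invents the LLM leaf masses, CFG path products, and inclusion probabilities for a two-token set S; it only illustrates steps 2-5 of the docstring — local weights w(x) = p_llm(x) * p_cfg(x) / Pr(x in S), sampling a token from the normalized weights, and returning the sum of the weights as the incremental update — in contrast to the uncorrected update p_llm(x) * p_cfg(x) / q(x, S) used by _guided_sample_trie_uncorrected.

import random

# Toy numbers (made up for illustration); in the proposal these come from the
# LLM leaf masses, the guide's next-character products, and the path-prefix
# probabilities of the sampled trie path.
S = {
    'ab': {'p_llm': 0.10, 'p_cfg': 0.50, 'incl': 0.40},
    'ac': {'p_llm': 0.05, 'p_cfg': 0.30, 'incl': 0.24},
}

# Steps 2-3: unnormalized target p(x) = p_llm(x) * p_cfg(x), local weight w(x) = p(x) / Pr(x in S).
weights = {x: d['p_llm'] * d['p_cfg'] / d['incl'] for x, d in S.items()}

# Step 4: sample a token proportional to its local weight.
token = random.choices(list(weights), weights=list(weights.values()))[0]

# Step 5: the incremental SMC weight update is the sum of the local weights,
# independent of which token was drawn.
weight_update = sum(weights.values())

print(token, weight_update)

Sampling from the normalized local weights while reporting their sum is what the docstring above relies on for probabilistically correct inference; weighting only the sampled token by its own target over its proposal probability (the uncorrected variant) gives biased estimates.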
7 changes: 5 additions & 2 deletions genparse/proposal/token.py
@@ -53,7 +53,10 @@ async def sample_next_token(
p_llm = await self.llm.p_next(prompt + context)

with self.timer['cfg+trie'](t=len(context)):
- Q = Float.chart(take(self.K, self.traverse_trie(context, p_llm))).normalize()
+ Q = Float.chart(
+ take(self.K - 1, self.traverse_trie(context, p_llm))
+ ).normalize()
+ rest = bottom_K(p_llm, len(self.llm.V) - self.K - 1)
token = sample_dict(Q)

llm_prob = p_llm[self.old_eos if token == self.new_eos else token]
Expand All @@ -62,7 +65,7 @@ async def sample_next_token(
if compare_time:
self.timer.compare()

- return (token, llm_prob, guide_prob, Q[token])
+ return (token, llm_prob + guide_prob, Q[token])

def _update_internal(self):
# overrides base method. Takes max rather than sum of internal nodes
9 changes: 2 additions & 7 deletions genparse/steer.py
@@ -225,18 +225,13 @@ def __init__(self, llm, guide, proposal, prompt, max_tokens, verbosity=0):
self.verbosity = verbosity

async def step(self):
- (
- token,
- llm_prob,
- guide_prob,
- proposal_prob,
- ) = await self.proposal.sample_next_token(
+ (token, weight_update) = await self.proposal.sample_next_token(
prompt=self.prompt,
context=''.join(self.context),
compare_time=(self.verbosity > 1),
)
self.context.append(token)
- self.weight += np.log(llm_prob) + np.log(guide_prob) - np.log(proposal_prob)
+ self.weight += np.log(weight_update)
self.max_tokens -= 1

if self.verbosity > 1:
Expand Down
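
Aside (editorial, not part of this commit): the steer.py change consumes exactly the (token, weight_update) pair documented in CharacterProposal.sample_next_token. A minimal sketch of such a driver loop, assuming a hypothetical proposal object exposing the same async interface and a hypothetical eos string:

import asyncio
import numpy as np

async def run_particle(proposal, prompt, max_tokens=10, eos='</s>'):
    # Hypothetical driver (sketch): accumulate the SMC log-weight from the
    # incremental weight updates returned by the proposal.
    context, log_weight = '', 0.0
    for _ in range(max_tokens):
        token, weight_update = await proposal.sample_next_token(
            prompt=prompt, context=context
        )
        log_weight += np.log(weight_update)
        if token == eos:
            break
        context += token
    return context, log_weight

# Usage, given any object with the async sample_next_token interface above:
# asyncio.run(run_particle(proposal, 'SELECT '))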