From 1238386e454dffc517f863cb628318a800452163 Mon Sep 17 00:00:00 2001 From: benlebrun Date: Thu, 20 Jun 2024 10:52:11 -0400 Subject: [PATCH] Weight corrections for token and character proposals (#5) --- genparse/proposal/character.py | 48 +++--- genparse/proposal/token.py | 72 ++++++--- genparse/steer.py | 6 +- tests/test_character_proposal.py | 156 +++++++++--------- tests/test_token_proposal.py | 204 ++++++++++++++++++++++++ tests/test_utils/proposal_testing.py | 226 +++++++++------------------ 6 files changed, 431 insertions(+), 281 deletions(-) diff --git a/genparse/proposal/character.py b/genparse/proposal/character.py index dad2f9d0..35bd8a0a 100644 --- a/genparse/proposal/character.py +++ b/genparse/proposal/character.py @@ -5,6 +5,8 @@ from genparse.proposal.trie_numba import TokenCharacterTrie from genparse.semiring import Float +from inspect import iscoroutinefunction + class CharacterProposal(TokenCharacterTrie): """ @@ -58,6 +60,7 @@ def sample( ): context = '' W = 1 + P = 1 t = 0 while True: t += 1 @@ -66,13 +69,15 @@ def sample( p_llm = self.llm.p_next(prompt + context) with self.timer['cfg+trie'](t=len(context)): self._update_trie(p_llm) - token, weight_update = self._guided_sample_trie( + token, proposal_p, weight_update = self._guided_sample_trie( context, verbosity=verbosity, draw=draw, **kwargs ) else: token = self.guide.eos weight_update = 1 + proposal_p = 1 W *= weight_update + P *= proposal_p if self.guide.eos == token: break if verbosity > 0: @@ -80,14 +85,13 @@ def sample( context += token if verbosity > 0: print() - return (context, W) + return (context, P, W) async def sample_next_token( self, prompt, context, verbosity=0, - compare_time=False, correct_weights=True, draw=sample_dict, **kwargs, @@ -106,21 +110,24 @@ async def sample_next_token( token : Proposed LLM token. weight_update : Incremental SMC weight update. 
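+ proposal_p : Probability of the sampled trace (the product of the probabilities of each random draw made while proposing the token).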
""" - with self.timer['llm'](t=len(context)): + + if iscoroutinefunction(self.llm.p_next): p_llm = await self.llm.p_next(prompt + context) - with self.timer['cfg+trie'](t=len(context)): - self._update_trie(p_llm) - if correct_weights: - (token, weight_update) = self._guided_sample_trie( - context, verbosity=verbosity, draw=draw, **kwargs - ) - else: - (token, weight_update) = self._guided_sample_trie_uncorrected( - context, verbosity=verbosity, draw=draw, **kwargs - ) - if compare_time: - self.timer.compare() - return (token, weight_update) + else: + p_llm = self.llm.p_next(prompt + context) + + self._update_trie(p_llm) + + if correct_weights: + (token, proposal_p, weight_update) = self._guided_sample_trie( + context, draw=draw, verbosity=verbosity, **kwargs + ) + else: + (token, proposal_p, weight_update) = self._guided_sample_trie_uncorrected( + context, draw=draw, verbosity=verbosity, **kwargs + ) + + return (token, proposal_p, weight_update) def __deepcopy__(self, memo): cpy = type(self).__new__(type(self)) @@ -164,6 +171,7 @@ def _guided_sample_trie(self, context, draw, verbosity=0): inclusion_prob = 1 # path prefix probability cfg_prob = 1 + proposal_p = 1 # probability of trace weights = Float.chart() @@ -215,6 +223,7 @@ def _guided_sample_trie(self, context, draw, verbosity=0): a = draw(q) inclusion_prob *= q[a] cfg_prob *= p2[a] + proposal_p *= q[a] curr = children_curr[a] @@ -230,13 +239,14 @@ def _guided_sample_trie(self, context, draw, verbosity=0): print(colors.light.green % 'token probs:', normalized_weights) token = draw(normalized_weights) + proposal_p *= normalized_weights[token] weight_update = weights.sum() if verbosity > 1: print(colors.orange % 'sampled token=', repr(token)) print(colors.orange % 'weight update=', weight_update) - return (token, weight_update) + return (token, proposal_p, weight_update) def _guided_sample_trie_uncorrected(self, context, draw, verbosity=0): """ @@ -318,4 +328,4 @@ def _guided_sample_trie_uncorrected(self, context, draw, verbosity=0): weight_update = (llm_prob * guide_prob) / proposal_prob - return (token, weight_update) + return (token, proposal_prob, weight_update) diff --git a/genparse/proposal/token.py b/genparse/proposal/token.py index d5b50f82..a862f2b1 100644 --- a/genparse/proposal/token.py +++ b/genparse/proposal/token.py @@ -6,6 +6,8 @@ from genparse.proposal.trie import TokenCharacterTrie from genparse.semiring import Float +from inspect import iscoroutinefunction + # TODO: It's tempting to require proposal distributions to implement the `LM` # interface, but it might be difficult to correctly implement `__call__` and # `p_next` as proposal distributions may only be distributions over sample paths @@ -40,32 +42,56 @@ def __init__(self, *, llm, guide, K=None): super().__init__(words, old_eos=llm.eos, new_eos=guide.eos) def _p_next(self, context, K=None): - with self.timer['llm'](t=len(context)): - p_llm = self.llm.p_next(self._prompt + context) + p_llm = self.llm.p_next(self._prompt + context) + return Float.chart(take(K, self.traverse_trie(context, p_llm))).normalize() - with self.timer['cfg+trie'](t=len(context)): - return Float.chart(take(K, self.traverse_trie(context, p_llm))).normalize() + async def sample_next_token(self, prompt, context, draw=sample_dict, **kwargs): + """ + Proposes a token and incremental weight update. + + The following procedure, justified using RAVI, gives the way we sample a token and compute the incremental SMC weight update. + + 1. Sample a subset S of size K of the token vocabulary by + a. 
enumerating the top K - 1 tokens + b. sampling a wildcard token from the remainder of the vocabulary proportional to p_llm(x) + 2. Compute the *unnormalized target* p(x) = p_llm(x)p_cfg(x) of each x \in S. + 3. Compute the (local) weight w(x) of each token as p(x)/Pr(x \in S) where Pr(x \in S) is the *inclusion probability*. + * Pr(x \in S) = 1 if x is in the top K - 1 + * Pr(x \in S) \propto p_llm(x) for the wildcard token + 4. Renormalize the weights of the tokens in S and sample one of them. + 5. Set the incremental SMC weight update w' = \sum_{x \in S} w(x). + """ - async def sample_next_token( - self, prompt, context, verbosity=0, compare_time=False, draw=sample_dict, **kwargs - ): - with self.timer['llm'](t=len(context)): + if iscoroutinefunction(self.llm.p_next): p_llm = await self.llm.p_next(prompt + context) - - with self.timer['cfg+trie'](t=len(context)): - Q = Float.chart( - take(self.K - 1, self.traverse_trie(context, p_llm)) - ).normalize() - token = draw(Q) - - llm_prob = p_llm[self.old_eos if token == self.new_eos else token] - guide_prob = self._p_guide[token] - - if compare_time: - self.timer.compare() - - # temp fix because hfppl step now requires only two return values - return (token, llm_prob * guide_prob / Q[token]) + else: + p_llm = self.llm.p_next(prompt + context) + + # enumerate top K - 1 tokens + Ws = Float.chart(take(self.K - 1, self.traverse_trie(context, p_llm))) + + # sample wildcard token from p_llm + P_wc = Float.chart({x: p for x, p in p_llm.items() if x not in Ws}).normalize() + wildcard = draw(P_wc) + proposal_p = P_wc[wildcard] + + # compute wildcard weight + p_cfg_wc = 1 + for i, c in enumerate(wildcard): + p_cfg_wc *= self.guide.p_next(context + wildcard[:i])[c] + Ws[wildcard] = ( + p_llm[self.old_eos if wildcard == self.new_eos else wildcard] + * p_cfg_wc + / P_wc[wildcard] + ) + + # sample token from weights and compute update + Ws_norm = Ws.normalize() + token = draw(Ws_norm) + proposal_p *= Ws_norm[token] + weight_update = Ws.sum() + + return (token, proposal_p, weight_update) def _update_internal(self): # overrides base method.
Takes max rather than sum of internal nodes diff --git a/genparse/steer.py b/genparse/steer.py index fbab62eb..7d8597bf 100644 --- a/genparse/steer.py +++ b/genparse/steer.py @@ -225,10 +225,8 @@ def __init__(self, llm, guide, proposal, prompt, max_tokens, verbosity=0): self.verbosity = verbosity async def step(self): - (token, weight_update) = await self.proposal.sample_next_token( - prompt=self.prompt, - context=''.join(self.context), - compare_time=(self.verbosity > 1), + (token, _, weight_update) = await self.proposal.sample_next_token( + prompt=self.prompt, context=''.join(self.context) ) self.context.append(token) self.weight += np.log(weight_update) diff --git a/tests/test_character_proposal.py b/tests/test_character_proposal.py index cd975cb6..ec78b73f 100644 --- a/tests/test_character_proposal.py +++ b/tests/test_character_proposal.py @@ -2,7 +2,6 @@ import numpy as np from arsenal import colors, timeit -from arsenal.maths import assert_equal from genparse.cfglm import CFGLM, BoolMaskCFGLM, locally_normalize from genparse.lm import GreedilyTokenizedLLM @@ -17,13 +16,9 @@ def test_timothy(): pcfg = CFGLM( locally_normalize( - LarkStuff( - r""" - - start: /[ ]*Tim(othy)?[ ](Fabbri[ ])?Vieira\./ - - """ - ).char_cfg(0.99), + LarkStuff(r""" start: /[ ]*Tim(othy)?[ ](Fabbri[ ])?Vieira\./""").char_cfg( + 0.99 + ), tol=1e-100, ) ) @@ -38,7 +33,7 @@ def test_timothy(): for _ in range(10): print('----------------------------------') with timeit('sample'): - ys, w = proposal.sample(prompt, verbosity=1, max_tokens=50) + ys, _, w = proposal.sample(prompt, verbosity=1, max_tokens=50) W[ys] += w @@ -104,7 +99,7 @@ def todo_chomsky(): for _ in range(10): print('----------------------------------') with timeit('sample'): - ys, w = proposal.sample(prompt, verbosity=1) + ys, _, w = proposal.sample(prompt, verbosity=1) W[ys] += w @@ -153,7 +148,7 @@ def todo_fruit(): for _ in range(10): print('----------------------------------') with timeit('sample'): - ys, w = proposal.sample(prompt, verbosity=1) + ys, _, w = proposal.sample(prompt, verbosity=1) W[ys] += w @@ -163,9 +158,11 @@ def todo_fruit(): from test_utils.proposal_testing import ( - make_character_proposal, enumerate_traces, enumerate_target, + make_character_proposal, + assert_proper_weighting, + assert_unbiased_Z, ) @@ -184,7 +181,8 @@ def test_normalizing_constant_unbiased(): ' W', ' O', ' S', - ' s' ' WHE', + ' s', + ' WHE', ' ORD', ' SEL', ' ORD', @@ -194,7 +192,8 @@ def test_normalizing_constant_unbiased(): ' SELE', ' ORDE', ' stat', - ' stad' ' SELECT', + ' stad', + ' SELECT', ' WHERE', ' ORDER', ' state', @@ -219,40 +218,22 @@ def test_normalizing_constant_unbiased(): WS: /[ ]/ """ - proposal = make_character_proposal(V=V, grammar=grammar, uniform=True) + proposal = make_character_proposal(V=V, guide_spec=grammar, uniform=True) - # E_{(x,S) ~ q(x,S)}[w(x,S)] = \sum_{x,S} q(x,S) * w(x,S) - Z_hat = lambda traces: sum([z.weight * z.score for z in traces]) - - context = ' ' prompt = '' - traces = enumerate_traces(proposal, prompt, context) - target = enumerate_target(proposal, prompt, context) - - want = target.sum() - have = Z_hat(traces) + context = ' ' - assert_equal(want, have, tol=1e-8) + assert_unbiased_Z(prompt, context, proposal, tol=1e-8) - context = ' SELECT' prompt = '' - traces = enumerate_traces(proposal, prompt, context) - target = enumerate_target(proposal, prompt, context) - - want = target.sum() - have = Z_hat(traces) + context = ' SELECT' - assert_equal(want, have, tol=1e-8) + assert_unbiased_Z(prompt, context, proposal, 
tol=1e-8) - context = ' SELECT * FROM data' prompt = '' - traces = enumerate_traces(proposal, prompt, context) - target = enumerate_target(proposal, prompt, context) - - want = target.sum() - have = Z_hat(traces) + context = ' SELECT * FROM data' - assert_equal(want, have, tol=1e-8) + assert_unbiased_Z(prompt, context, proposal, tol=1e-8) def test_proper_weighting(): @@ -274,6 +255,10 @@ def test_proper_weighting(): np.random.seed(0) random.seed(0) + ################# + # Boolean guide # + ################# + V = {' ', ' a', ' b', '▪'} grammar = r""" @@ -283,21 +268,12 @@ def test_proper_weighting(): WS: /[ ]/ """ - proposal = make_character_proposal(V=V, uniform=True, grammar=grammar) + proposal = make_character_proposal(V=V, uniform=True, guide_spec=grammar) - contxt = '' prompt = '' - traces = enumerate_traces(proposal, prompt, contxt) - target = enumerate_target(proposal, prompt, contxt) - - pi_hat = lambda traces, x: sum( - [tr.weight * tr.score for tr in traces if tr.token == x] - ) + context = '' - for x in proposal.llm.V: - have = pi_hat(traces, x) - want = target[x] - assert_equal(have, want, tol=1e-8) + assert_proper_weighting(prompt, context, proposal, tol=1e-8) V = { '▪', @@ -326,44 +302,62 @@ def test_proper_weighting(): } grammar = r""" - start: WS? "SELECT" WS select_expr WS "FROM" WS from_expr [WS "WHERE" WS bool_condition] [WS "GROUP BY" WS var_list] [WS "ORDER BY" WS orderby_expr] WS EOS - EOS: "▪" - select_expr: STAR | select_list - bool_condition: bool_expr | "(" bool_condition WS "AND" WS bool_condition ")" | "(" bool_condition WS "OR" WS bool_condition ")" - bool_expr: var "=" value | var ">" value | var "<" value - from_expr: "data" - orderby_expr: var_list WS "ASC" | var_list WS "DESC" - select_list: select_var ("," WS select_var)* - var_list: var ("," WS var)* - select_var: var | "AVG(" var ")" | "MEDIAN(" var ")" | "COUNT(" var ")" - var: "state" | "stadium" - value: NUMBER | "'red'" - STAR: "*" - NUMBER: /\d+/ - WS: /[ ]/ + start: WS? 
"SELECT" WS select_expr WS "FROM" WS from_expr [WS "WHERE" WS bool_condition] [WS "GROUP BY" WS var_list] [WS "ORDER BY" WS orderby_expr] WS EOS + EOS: "▪" + select_expr: STAR | select_list + bool_condition: bool_expr | "(" bool_condition WS "AND" WS bool_condition ")" | "(" bool_condition WS "OR" WS bool_condition ")" + bool_expr: var "=" value | var ">" value | var "<" value + from_expr: "data" + orderby_expr: var_list WS "ASC" | var_list WS "DESC" + select_list: select_var ("," WS select_var)* + var_list: var ("," WS var)* + select_var: var | "AVG(" var ")" | "MEDIAN(" var ")" | "COUNT(" var ")" + var: "state" | "stadium" + value: NUMBER | "'red'" + STAR: "*" + NUMBER: /\d+/ + WS: /[ ]/ """ - proposal = make_character_proposal(V=V, grammar=grammar, uniform=True) + proposal = make_character_proposal(V=V, guide_spec=grammar, uniform=True) + + prompt = '' + context = ' SELECT' + + assert_proper_weighting(prompt, context, proposal, tol=1e-8) + + prompt = '' + context = ' SELECT * FROM data' + + assert_proper_weighting(prompt, context, proposal, tol=1e-8) + + ####################### + # Probabilistic guide # + ####################### + + pcfg = CFGLM.from_string( + """ + + 1: S -> a + 1: S -> a a + 2: S -> a a a + + """ + ) + + V = {'a', 'aa', 'aaa', '▪'} + + proposal = make_character_proposal(V=V, guide_spec=pcfg, uniform=True) - contxt = ' SELECT' prompt = '' - traces = enumerate_traces(proposal, prompt, contxt) - target = enumerate_target(proposal, prompt, contxt) + context = '' - for x in proposal.llm.V: - have = pi_hat(traces, x) - want = target[x] - assert_equal(have, want, tol=1e-8) + assert_proper_weighting(prompt, context, proposal, tol=1e-8) - contxt = ' SELECT * FROM data' prompt = '' - traces = enumerate_traces(proposal, prompt, contxt) - target = enumerate_target(proposal, prompt, contxt) + context = 'a' - for x in proposal.llm.V: - have = pi_hat(traces, x) - want = target[x] - assert_equal(have, want, tol=1e-8) + assert_proper_weighting(prompt, context, proposal, tol=1e-8) if __name__ == '__main__': diff --git a/tests/test_token_proposal.py b/tests/test_token_proposal.py index a0a32610..35edbda2 100644 --- a/tests/test_token_proposal.py +++ b/tests/test_token_proposal.py @@ -1,10 +1,15 @@ from arsenal import timeit +from arsenal.maths import assert_equal + +import random +import numpy as np from genparse.cfglm import add_EOS, locally_normalize from genparse.experimental.earley import EarleyLM from genparse.lm import make_mock_llm from genparse.proposal import TokenProposal from genparse.util import LarkStuff +from genparse import CFGLM # TODO: test equivalence of `traverse_trie` and `traverse_naive`. # def traverse_naive(self, context): @@ -100,6 +105,205 @@ def test_basic_aligned_model_iql_small(): print(proposal.sample()) +from test_utils.proposal_testing import ( + enumerate_traces, + enumerate_target, + make_token_proposal, + assert_proper_weighting, + assert_unbiased_Z, +) + + +def test_normalizing_constant_unbiased(): + """ + The expected importance weight should provide an unbiased estimate of the normalizing constant. + That is, we expect E_{(x,S) ~ q(x,S)}[w(x,S)] = Σ_x p(x). + """ + np.random.seed(0) + random.seed(0) + + V = { + '▪', + ' ', + ' ', + ' W', + ' O', + ' S', + ' s', + ' WHE', + ' ORD', + ' SEL', + ' ORD', + ' sta', + ' WHER', + ' ORDE', + ' SELE', + ' ORDE', + ' stat', + ' stad', + ' SELECT', + ' WHERE', + ' ORDER', + ' state', + ' stadium', + } + + grammar = r""" + start: WS? 
"SELECT" WS select_expr WS "FROM" WS from_expr [WS "WHERE" WS bool_condition] [WS "GROUP BY" WS var_list] [WS "ORDER BY" WS orderby_expr] WS EOS + EOS: "▪" + select_expr: STAR | select_list + bool_condition: bool_expr | "(" bool_condition WS "AND" WS bool_condition ")" | "(" bool_condition WS "OR" WS bool_condition ")" + bool_expr: var "=" value | var ">" value | var "<" value + from_expr: "data" + orderby_expr: var_list WS "ASC" | var_list WS "DESC" + select_list: select_var ("," WS select_var)* + var_list: var ("," WS var)* + select_var: var | "AVG(" var ")" | "MEDIAN(" var ")" | "COUNT(" var ")" + var: "state" | "stadium" + value: NUMBER | "'red'" + STAR: "*" + NUMBER: /\d+/ + WS: /[ ]/ + """ + + proposal = make_token_proposal(V=V, guide_spec=grammar, K=10, uniform=True) + + prompt = '' + context = ' ' + + assert_unbiased_Z(prompt, context, proposal, tol=1e-8) + + prompt = '' + context = ' SELECT' + + assert_unbiased_Z(prompt, context, proposal, tol=1e-8) + + prompt = '' + context = ' SELECT * FROM data' + + assert_unbiased_Z(prompt, context, proposal, tol=1e-8) + + +def test_proper_weighting(): + """ + A particle (x,w) is *properly weighted* for unnormalized density p' if, for any function f, + + E_{(x,w) ~ \\tilde{q}}[f(x)w] = Σ_x p'(x) f(x) + + where Z normalizes p'. In our case, we have that + + E_{(x,w) ~ \\tilde{q}}[f(x)w] = E_{(x,S) ~ q}[f(x)w(x,S)] + + Thus, we expect + + E_{(x,S) ~ q}[f(x)w(x,S)] = Σ_x p'(x) f(x) + + for the local product of experts distributions. We test this for f(x) = δ(x', x) for all x' ∈ V. + """ + np.random.seed(0) + random.seed(0) + + V = {' ', ' a', ' b', '▪'} + + grammar = r""" + start: WS x EOS + EOS: "▪" + x: "a" | "b" | "ab" + WS: /[ ]/ + """ + + proposal = make_token_proposal(V=V, guide_spec=grammar, K=2, uniform=True) + + prompt = '' + context = '' + + assert_proper_weighting(prompt, context, proposal, tol=1e-8) + + V = { + '▪', + ' ', + ' ', + ' W', + ' O', + ' S', + ' s', + ' WHE', + ' ORD', + ' SEL', + ' ORD', + ' sta', + ' WHER', + ' ORDE', + ' SELE', + ' ORDE', + ' stat', + ' stad', + ' SELECT', + ' WHERE', + ' ORDER', + ' state', + ' stadium', + } + + grammar = r""" + start: WS? 
"SELECT" WS select_expr WS "FROM" WS from_expr [WS "WHERE" WS bool_condition] [WS "GROUP BY" WS var_list] [WS "ORDER BY" WS orderby_expr] WS EOS + EOS: "▪" + select_expr: STAR | select_list + bool_condition: bool_expr | "(" bool_condition WS "AND" WS bool_condition ")" | "(" bool_condition WS "OR" WS bool_condition ")" + bool_expr: var "=" value | var ">" value | var "<" value + from_expr: "data" + orderby_expr: var_list WS "ASC" | var_list WS "DESC" + select_list: select_var ("," WS select_var)* + var_list: var ("," WS var)* + select_var: var | "AVG(" var ")" | "MEDIAN(" var ")" | "COUNT(" var ")" + var: "state" | "stadium" + value: NUMBER | "'red'" + STAR: "*" + NUMBER: /\d+/ + WS: /[ ]/ + """ + + proposal = make_token_proposal(V=V, guide_spec=grammar, K=10, uniform=False) + + prompt = '' + context = ' SELECT' + + assert_proper_weighting(prompt, context, proposal, tol=1e-8) + + prompt = '' + context = ' SELECT * FROM data' + + assert_proper_weighting(prompt, context, proposal, tol=1e-8) + + ####################### + # Probabilistic guide # + ####################### + + pcfg = CFGLM.from_string( + """ + + 1: S -> a + 1: S -> a a + 2: S -> a a a + + """ + ) + + V = {'a', 'aa', 'aaa', '▪'} + + proposal = make_token_proposal(V=V, guide_spec=pcfg, K=2, uniform=True) + + prompt = '' + context = '' + + assert_proper_weighting(prompt, context, proposal, tol=1e-8) + + prompt = '' + context = 'a' + + assert_proper_weighting(prompt, context, proposal, tol=1e-8) + + if __name__ == '__main__': from arsenal import testing_framework diff --git a/tests/test_utils/proposal_testing.py b/tests/test_utils/proposal_testing.py index 445646ee..b65a4039 100644 --- a/tests/test_utils/proposal_testing.py +++ b/tests/test_utils/proposal_testing.py @@ -1,174 +1,92 @@ -########################### -# Trace enumeration utils +from genparse.lm import MockLLM, LM +from genparse.proposal import TokenProposal, CharacterProposal +from genparse.cfglm import EarleyBoolMaskCFGLM +from genparse.util import LarkStuff +from arsenal.maths import random_dist, assert_equal -from genparse import Float -import copy + +def _make_guide(guide_spec): + if isinstance(guide_spec, str): + return EarleyBoolMaskCFGLM(LarkStuff(guide_spec).char_cfg(0.99, ignore='[ ]?')) + elif isinstance(guide_spec, LM): + return guide_spec + else: + raise ValueError('Unknown guide specification') + + +def _make_mock_llm(V, uniform): + return MockLLM(V=V, eos='▪', _p=None if uniform else random_dist(len(V))) + + +def make_character_proposal(V, guide_spec, uniform=False): + llm = _make_mock_llm(V, uniform) + guide = _make_guide(guide_spec) + + return CharacterProposal(llm=llm, guide=guide) -class Trace: - def __init__(self): - self.score = 1 - self.choices = [] - self.path = None - self.token = None - self.weight = None +def make_token_proposal(V, guide_spec, K, uniform=False): + llm = _make_mock_llm(V, uniform) + guide = _make_guide(guide_spec) - def record(self, name, outcome, dist): - self.score *= dist[outcome] - self.choices.append( - {'name': name, 'outcome': outcome, 'p': dist[outcome], 'dist': dist} - ) + return TokenProposal(llm=llm, guide=guide, K=K) - def __repr__(self): - return f'{self.path}→`{self.token}` : {self.weight}' + +from genparse.inference import TraceSWOR +from genparse import Float +import asyncio def enumerate_traces(proposal, prompt, context): - p_llm = proposal.llm.p_next(prompt + context) - proposal._update_trie(p_llm) - - curr = proposal.root - children = proposal.children - mass = proposal.mass - - def _enum_traces(chars, trace, 
children_curr, mass_curr, cfg_p, inc_p, weights): - p1 = Float.chart((a, mass[c] / mass_curr) for a, c in children_curr.items()) - p2 = proposal.guide.p_next(context + ''.join(chars)).trim() - - if None in p1: - weights[''.join(chars)] = (mass[children_curr[None]] * cfg_p) / inc_p - - _q = (p1 * p2).trim() - - traces = [] - if not _q: - normalized_weights = weights.normalize() - for token in normalized_weights.keys(): - new_trace = copy.deepcopy(trace) - new_trace.record('exit', token, normalized_weights) - - new_trace.token = token - new_trace.weight = weights.sum() - new_trace.path = '→'.join(chars) - - traces.append(new_trace) - else: - q = _q.normalize() - for a, q_ in q.items(): - curr = children_curr[a] - new_chars = chars.copy() - new_chars.append(a) - - new_trace = copy.deepcopy(trace) - new_trace.record(f'char {len(new_chars)}', a, q) - - traces.extend( - _enum_traces( - chars=new_chars, - trace=new_trace, - children_curr=children[curr], - mass_curr=mass[curr], - inc_p=inc_p * q_, - cfg_p=cfg_p * p2[a], - weights=weights.copy(), - ) - ) - - return traces - - return _enum_traces( - chars=[], - trace=Trace(), - children_curr=children[curr], - mass_curr=mass[curr], - cfg_p=1, - inc_p=1, - weights=Float.chart(), - ) - - -def enumerate_traces_uncorrected(proposal, prompt, context): - p_llm = proposal.llm.p_next(prompt + context) - proposal._update_trie(p_llm) - - curr = proposal.root - children = proposal.children - mass = proposal.mass - - def _enum_traces(chars, trace, children_curr, mass_curr, cfg_p, exits): - p1 = Float.chart((a, mass[c] / mass_curr) for a, c in children_curr.items()) - p2 = proposal.guide.p_next(context + ''.join(chars)).trim() - - if None in p1: - exits[''.join(chars)] = mass[children_curr[None]] - - _q = (p1 * p2).trim() - - traces = [] - if not _q: - # no more paths to explore - exits_norm = exits.normalize() - for token, exit_p in exits_norm.items(): - new_trace = copy.deepcopy(trace) - new_trace.record('exit', token, exits_norm) - - new_trace.token = token - new_trace.weight = ( - cfg_p * mass[proposal.word2leaf[token]] / new_trace.score - ) - new_trace.path = '→'.join(chars) - - traces.append(new_trace) - else: - q = _q.normalize() - for a, q_ in q.items(): - curr = children_curr[a] - new_chars = chars.copy() - new_chars.append(a) - - new_trace = copy.deepcopy(trace) - new_trace.record(f'char {len(new_chars)}', a, q) - - traces.extend( - _enum_traces( - chars=new_chars, - trace=new_trace, - children_curr=children[curr], - cfg_p=cfg_p * p2[a], - mass_curr=mass[curr], - exits=exits.copy(), - ) - ) - return traces - - return _enum_traces( - chars=[], - trace=Trace(), - children_curr=children[curr], - mass_curr=mass[curr], - cfg_p=1, - exits=Float.chart(), - ) + """ + This function uses program tracing and sampling without replacement to compute + + E_{(x,w) ~ q'}[ δ(x, x') * w ] = E_{(x,S) ~ q}[ δ(x, x') * w(x,S) ] + = Σ_{x,S} δ(x, x') * q(x,S) * w(x,S) + + for each x' in V. + + Its use is to check whether our proposal satisfies properties like proper weighting through exact enumeration. + """ + tracer = TraceSWOR() + P = Float.chart() + # sample without replacement until all traces have been exhausted + while tracer.root.mass > 0: + with tracer: + s, q, w = asyncio.run( + proposal.sample_next_token(draw=tracer, prompt=prompt, context=context) + ) + P[s] += w * q + return (P, tracer) def enumerate_target(proposal, prompt, context): + """ + This function exactly computes the unnormalized local POE target over next tokens given a prompt and context. 
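+ + Concretely, for each token x in the LLM's vocabulary, the target is p'(x) = p_llm(x | prompt + context) * \prod_i p_guide(x_i | context + x_{<i}), where the product ranges over the characters x_i of x.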
+ """ p_next = Float.chart() for token in proposal.llm.V: cfg_prob = 1 - for i in range(0, len(token)): - cfg_prob *= proposal.guide.p_next(context + token[:i])[token[i]] + for i, c in enumerate(token): + cfg_prob *= proposal.guide.p_next(context + token[:i])[c] p_next[token] = cfg_prob * proposal.llm.p_next(prompt + context)[token] return p_next -def make_character_proposal(V, grammar, uniform=False): - from genparse.lm import MockLLM - from genparse.proposal import CharacterProposal - from genparse.cfglm import EarleyBoolMaskCFGLM - from genparse.util import LarkStuff - from arsenal.maths import random_dist +def assert_proper_weighting(prompt, context, proposal, tol=1e-8): + pi_q, _ = enumerate_traces(proposal, prompt, context) + pi_true = enumerate_target(proposal, prompt, context) - llm = MockLLM(V=V, eos='▪', _p=None if uniform else random_dist(len(V))) - guide = EarleyBoolMaskCFGLM(LarkStuff(grammar).char_cfg(0.99, ignore='[ ]?')) + for x in proposal.llm.V: + have = pi_q[x] + want = pi_true[x] + assert_equal(have, want, tol=tol) - return CharacterProposal(llm=llm, guide=guide) + +def assert_unbiased_Z(prompt, context, proposal, tol=1e-8): + pi_q, _ = enumerate_traces(proposal, prompt, context) + pi_true = enumerate_target(proposal, prompt, context) + + have = pi_q.sum() + want = pi_true.sum() + assert_equal(have, want, tol=tol)