Commit

started investigating issue #15
timvieira committed Jun 29, 2024
1 parent ae7c123 commit c8e6fed
Showing 4 changed files with 52 additions and 12 deletions.
genparse/cfglm.py (4 additions, 0 deletions)
@@ -89,6 +89,10 @@ def __call__(self, context):
     def clear_cache(self):
         self.model.clear_cache()
 
+    @classmethod
+    def from_string(cls, x, semiring=Boolean, **kwargs):
+        return cls(CFG.from_string(x, semiring), **kwargs)
+
 
 EarleyBoolMaskCFGLM = BoolMaskCFGLM

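The new classmethod lets a boolean-masked CFG LM be built directly from a grammar string; the regression test added below uses it exactly this way. A minimal usage sketch, mirroring the grammar strings in tests/test_token_proposal.py:

from genparse.cfglm import BoolMaskCFGLM

# Build the guide straight from a grammar string, as the new test below does.
guide = BoolMaskCFGLM.from_string(
    """
    1: S -> a
    1: S -> a a
    1: S -> a a a
    """
)
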
genparse/chart.py (3 additions, 1 deletion)
@@ -120,7 +120,9 @@ def sort(self, **kwargs):
         return self.semiring.chart((k, self[k]) for k in sorted(self, **kwargs))
 
     def sort_descending(self):
-        return self.semiring.chart((k, self[k]) for k in sorted(self, lambda k: -self[k]))
+        return self.semiring.chart(
+            (k, self[k]) for k in sorted(self, key=lambda k: -self[k])
+        )
 
     def normalize(self):
         Z = self.sum()
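The one-line change above is a straightforward bug fix: in Python 3, sorted() accepts the comparison function only through the keyword-only key argument, so the old positional call raised a TypeError instead of sorting in descending order. A minimal illustration, using a plain dict as a stand-in for the semiring chart:

chart = {'a': 0.2, 'b': 0.5, 'c': 0.3}

# Old form: the lambda is passed positionally, which Python 3 rejects.
# sorted(chart, lambda k: -chart[k])   # TypeError: sorted expected 1 argument, got 2

# Fixed form: pass it as the keyword-only `key` argument.
print(sorted(chart, key=lambda k: -chart[k]))   # ['b', 'c', 'a']
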
genparse/proposal/token.py (15 additions, 3 deletions)
@@ -56,7 +56,7 @@ async def sample_next_token(
         p_llm=None,
         **kwargs,
     ):
-        """
+        r"""
         Proposes a token and incremental weight update.
 
         The following procedure, justified using RAVI, gives the way we sample a token and compute the incremental SMC weight update.
@@ -94,11 +94,16 @@ async def sample_next_token(
         Ws = Float.chart(take(self.K - 1, self.traverse_trie(context, p_llm)))
 
         # sample wildcard token from p_llm
-        P_wc = Float.chart({x: p for x, p in p_llm.items() if x not in Ws}).normalize()
+        P_wc = Float.chart({x: p for x, p in p_llm.items() if x not in Ws and p > 0})
+
+        # TODO: P_wc could be empty!
+        # print(f'{P_wc=}')
+
+        P_wc = P_wc.normalize()
         wildcard = draw(P_wc)
         proposal_p = P_wc[wildcard]
 
-        # compute wild card weight
+        # compute the wildcard's weight
         p_cfg_wc = 1
         with self.timer['cfg+trie'](t=len(context)):
             for i, c in enumerate(wildcard):
@@ -109,8 +114,15 @@
                 / P_wc[wildcard]
             )
 
+        # TODO: Ws[wildcard] could be zero!
+        # print(f'{Ws[wildcard]=}')
+
         # sample token from weights and compute update
         Ws_norm = Ws.normalize()
+
+        # TODO: Ws_norm could be empty!
+        # print(f'{Ws=} {P_wc=} {Ws_norm=}')
+
         token = draw(Ws_norm)
         proposal_p *= Ws_norm[token]
         weight_update = Ws.sum()
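The TODOs added above all point at the edge cases behind issue #15: the proposal's weight charts are renormalized via normalize(), which divides by the chart's total mass, and the wildcard path can leave that mass at zero or leave a chart empty altogether. A simplified stand-in (not the genparse Float.chart API) just to make the two flagged cases concrete:

# Simplified chart mirroring the normalize() pattern shown in genparse/chart.py
# above; an illustration of the failure mode, not the library implementation.
class Chart(dict):
    def sum(self):
        return sum(self.values())

    def normalize(self):
        Z = self.sum()
        return Chart((k, v / Z) for k, v in self.items())

# "Ws[wildcard] could be zero!": if the wildcard is the only entry and its
# weight comes out as 0, the total mass is 0 and normalize() divides by zero.
Ws = Chart({'aaa': 0.0})
# Ws.normalize()   # ZeroDivisionError: float division by zero

# "P_wc could be empty!": an empty chart normalizes without error (there is
# nothing to divide), but a subsequent draw() then has no support to sample from.
P_wc = Chart()
print(P_wc.normalize())   # {}
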
tests/test_token_proposal.py (30 additions, 8 deletions)
@@ -1,12 +1,12 @@
 import numpy as np
 from arsenal import timeit
 
 from genparse.util import set_seed
-from genparse.cfglm import add_EOS, locally_normalize
+from genparse.cfglm import add_EOS, locally_normalize, BoolMaskCFGLM
 from genparse.experimental.earley import EarleyLM
-from genparse.lm import make_mock_llm
+from genparse.lm import make_mock_llm, MockLLM
 from genparse.proposal import TokenProposal
 from genparse.util import LarkStuff
-from genparse import CFGLM
 from genparse.proposal.util import (
     mock_token_proposal,
     assert_proper_weighting,
@@ -54,7 +54,6 @@ def test_basic_aligned_model_iql_small():
     )
 
     guide = EarleyLM(cfg)
-    # guide = CFGLM(cfg)
 
     proposal = TokenProposal(guide=guide, llm=llm)
 
@@ -170,14 +169,14 @@ def test_normalizing_constant_unbiased():
 
 
 def test_proper_weighting():
-    """
+    r"""
     A particle (x,w) is *properly weighted* for unnormalized density p' if, for any function f,
 
-        E_{(x,w) ~ \\tilde{q}}[f(x)w] = Σ_x p'(x) f(x)
+        E_{(x,w) ~ \tilde{q}}[f(x)w] = Σ_x p'(x) f(x)
 
     where Z normalizes p'. In our case, we have that
 
-        E_{(x,w) ~ \\tilde{q}}[f(x)w] = E_{(x,S) ~ q}[f(x)w(x,S)]
+        E_{(x,w) ~ \tilde{q}}[f(x)w] = E_{(x,S) ~ q}[f(x)w(x,S)]
 
     Thus, we expect
@@ -261,7 +260,7 @@ def test_proper_weighting():
     # Probabilistic guide #
     #######################
 
-    pcfg = CFGLM.from_string(
+    pcfg = EarleyLM.from_string(
         """
         1: S -> a
@@ -286,6 +285,29 @@
     assert_proper_weighting(prompt, context, proposal, tol=1e-8)
 
 
+# TODO: fix this error!
+def todo_github_issue_15_wildcard_divide_by_zero():
+    guide = BoolMaskCFGLM.from_string(
+        """
+        1: S -> a
+        1: S -> a a
+        1: S -> a a a
+        """
+    )
+
+    V = ['a', 'aa', 'aaa', '▪']
+
+    llm = MockLLM(V=V, eos='▪', _p=np.array([0, 0, 1, 0]))
+
+    proposal = TokenProposal(llm=llm, guide=guide, K=1)
+
+    context = 'aa'
+
+    assert_proper_weighting('', context, proposal, tol=1e-8)
+
+
 if __name__ == '__main__':
     from arsenal import testing_framework
 
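The new test (prefixed todo_ so it is not yet picked up as a regular test_ function) appears built to hit exactly that wildcard path: with K=1 the proposal enumerates take(self.K - 1, ...) = 0 tokens from the trie, so Ws starts empty and everything is routed through the wildcard branch, while the MockLLM's _p puts all probability on the third vocabulary item, 'aaa'. A small sketch of the K=1 arithmetic, with take defined locally as a stand-in for the take(self.K - 1, ...) call shown in the token.py diff:

from itertools import islice

def take(n, iterable):
    # local stand-in for the take() used in genparse/proposal/token.py
    return list(islice(iterable, n))

K = 1                                   # the value used in the new test
candidates = iter(['a', 'aa', 'aaa'])   # hypothetical trie enumeration
Ws = take(K - 1, candidates)
print(Ws)   # [] -- the enumerated-token chart is empty, so its total mass is 0
            # and the wildcard path must carry everything (see the TODOs above)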
