Skip to content

Commit

Permalink
fixed bug in MockLLM
Browse files Browse the repository at this point in the history
  • Loading branch information
benlebrun committed Jun 18, 2024
1 parent 4b0f5d8 commit 5536871
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions genparse/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,6 @@ def __repr__(self):

from functools import lru_cache


@lru_cache(None)
def make_mock_llm(**kwargs):
from genparse.util import hf_tokenizer
Expand All @@ -325,17 +324,16 @@ class MockLLM(LM):
Uniform distribution over next token; used for testing.
"""

def __init__(self, V, eos, _p=None):
    """Initialize a mock LM with a (by default uniform) next-token distribution.

    Args:
        V: vocabulary — an iterable of tokens; iteration order fixes the
            index assignment used by ``_encode``/``_decode``.
        eos: end-of-sequence token, forwarded to the ``LM`` base class.
        _p: optional explicit probability vector aligned with ``list(V)``;
            when ``None`` (the default), the uniform distribution
            ``1/len(V)`` is used. Keeping the default preserves the
            original two-argument call signature.
    """
    n = len(V)
    # Uniform probabilities unless an explicit vector was supplied.
    # np.full builds the constant vector directly instead of going
    # through a Python list comprehension.
    self._p = np.full(n, 1 / n) if _p is None else _p
    self._logp = np.log(self._p)
    # Token <-> index maps; p_next uses these to wrap _p in a LazyProb.
    self._decode = list(V)
    self._encode = {x: i for i, x in enumerate(self._decode)}
    super().__init__(eos=eos, V=V)

def p_next(self, _):
    """Return the next-token distribution; the context argument is ignored.

    Wraps the stored probability vector in a ``LazyProb`` together with
    the token<->index maps built in ``__init__``, so callers can look up
    probabilities by token rather than by raw index.
    """
    return LazyProb(self._p, self._encode, self._decode)

def __call__(self, x):
assert x[-1] == self.eos
Expand Down

0 comments on commit 5536871

Please sign in to comment.