From 55368714b4553c64cc2cc3f75e1b9f26c7d282d0 Mon Sep 17 00:00:00 2001 From: benlebrun Date: Tue, 18 Jun 2024 12:33:15 -0400 Subject: [PATCH] fixed bug in MockLLM --- genparse/lm.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/genparse/lm.py b/genparse/lm.py index a0bb84ad..3b9ea360 100644 --- a/genparse/lm.py +++ b/genparse/lm.py @@ -311,7 +311,6 @@ def __repr__(self): from functools import lru_cache - @lru_cache(None) def make_mock_llm(**kwargs): from genparse.util import hf_tokenizer @@ -325,17 +324,16 @@ class MockLLM(LM): Uniform distribution over next token; used for testing. """ - def __init__(self, V, eos): + def __init__(self, V, eos, _p=None): n = len(V) - self._p = Float.chart({w: 1 / n for w in V}) - self._logp = Float.chart({w: -np.log(n) for w in V}) - super().__init__( - eos=eos, - V=V, - ) + self._p = np.array([1 / n for _ in range(len(V))]) if _p is None else _p + self._logp = np.log(self._p) + self._decode = list(V) + self._encode = {x: i for i, x in enumerate(self._decode)} + super().__init__(eos=eos, V=V) def p_next(self, _): - return self._p + return LazyProb(self._p, self._encode, self._decode) def __call__(self, x): assert x[-1] == self.eos