Skip to content

Commit

Permalink
Merge remote-tracking branch 'refs/remotes/origin/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
timvieira committed Jun 18, 2024
2 parents c922f37 + 5536871 commit d0b71e6
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
3 changes: 2 additions & 1 deletion bench/run_spider_llama2_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,9 @@ def main():
prompt_formatter = PromptFormatter(spider_train_data, spider_schemas)

model = 'meta-llama/Llama-2-7b-chat-hf'
model.replace('7b', args.model_size)
model = model.replace('7b', args.model_size)
access_token = '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'
logger.info(f"using model {model}")

tokenizer = AutoTokenizer.from_pretrained(model, token=access_token)
pipe = pipeline(
Expand Down
16 changes: 7 additions & 9 deletions genparse/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,6 @@ def __repr__(self):

from functools import lru_cache


@lru_cache(None)
def make_mock_llm(**kwargs):
from genparse.util import hf_tokenizer
Expand All @@ -325,17 +324,16 @@ class MockLLM(LM):
Uniform distribution over next token; used for testing.
"""

def __init__(self, V, eos):
def __init__(self, V, eos, _p=None):
n = len(V)
self._p = Float.chart({w: 1 / n for w in V})
self._logp = Float.chart({w: -np.log(n) for w in V})
super().__init__(
eos=eos,
V=V,
)
self._p = np.array([1 / n for _ in range(len(V))]) if _p is None else _p
self._logp = np.log(self._p)
self._decode = list(V)
self._encode = {x: i for i, x in enumerate(self._decode)}
super().__init__(eos=eos, V=V)

def p_next(self, _):
return self._p
return LazyProb(self._p, self._encode, self._decode)

def __call__(self, x):
assert x[-1] == self.eos
Expand Down

0 comments on commit d0b71e6

Please sign in to comment.