fixing p_next for vllm
lyutyuh committed Jun 20, 2024
1 parent 7dd732f commit b1cddb3
Showing 3 changed files with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions genparse/proposal/token.py
@@ -39,18 +39,20 @@ def __init__(self, *, llm, guide, K=None):
 
         super().__init__(words, old_eos=llm.eos, new_eos=guide.eos)
 
-    def _p_next(self, context, K=None):
+    def _p_next(self, context, K=None, execute_model_req=None, **kwargs):
         with self.timer['llm'](t=len(context)):
-            p_llm = self.llm.p_next(self._prompt + context)
+            p_llm = self.llm.p_next(self._prompt + context, execute_model_req=execute_model_req)
 
         with self.timer['cfg+trie'](t=len(context)):
             return Float.chart(take(K, self.traverse_trie(context, p_llm))).normalize()
 
     async def sample_next_token(
-        self, prompt, context, verbosity=0, compare_time=False, draw=sample_dict, **kwargs
+        self, prompt, context, verbosity=0, compare_time=False, draw=sample_dict,
+        execute_model_req=None,
+        **kwargs
     ):
         with self.timer['llm'](t=len(context)):
-            p_llm = await self.llm.p_next(prompt + context)
+            p_llm = await self.llm.p_next(prompt + context, execute_model_req=execute_model_req)
 
         with self.timer['cfg+trie'](t=len(context)):
             Q = Float.chart(
Binary file removed vllm_runtime.pdf
Binary file removed vllm_runtime.pkl
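
For context, the likely failure mode this commit fixes: a vLLM-backed caller passing execute_model_req (the request object vLLM's scheduler builds for a model step) into sample_next_token had it swallowed by **kwargs and never forwarded to llm.p_next. Below is a minimal before/after sketch of that plumbing pattern; StubLLM, ProposalBefore, and ProposalAfter are hypothetical stand-ins, not genparse's real API — only the keyword forwarding mirrors the diff.

import asyncio

class StubLLM:
    """Hypothetical stand-in for genparse's vLLM-backed LLM wrapper."""

    async def p_next(self, context, execute_model_req=None):
        # A real vLLM backend would need the scheduler's request object
        # (ExecuteModelRequest in vLLM) to run its model step; this
        # stand-in only checks that the request actually arrived.
        assert execute_model_req is not None, 'request was dropped!'
        return {'a': 0.9, '</s>': 0.1}

class ProposalBefore:
    def __init__(self, llm):
        self.llm = llm

    async def sample_next_token(self, prompt, context, **kwargs):
        # Bug: execute_model_req lands in **kwargs and is never forwarded.
        return await self.llm.p_next(prompt + context)

class ProposalAfter:
    def __init__(self, llm):
        self.llm = llm

    async def sample_next_token(self, prompt, context,
                                execute_model_req=None, **kwargs):
        # Fix: the request object is a named parameter and is passed through.
        return await self.llm.p_next(prompt + context,
                                     execute_model_req=execute_model_req)

async def demo():
    req = object()  # placeholder for a real vLLM ExecuteModelRequest
    try:
        await ProposalBefore(StubLLM()).sample_next_token(
            'p', 'c', execute_model_req=req)
    except AssertionError as e:
        print('before:', e)  # before: request was dropped!
    print('after:', await ProposalAfter(StubLLM()).sample_next_token(
        'p', 'c', execute_model_req=req))

asyncio.run(demo())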
