diff --git a/genparse/proposal/token.py b/genparse/proposal/token.py index d5b50f82..c9e0d811 100644 --- a/genparse/proposal/token.py +++ b/genparse/proposal/token.py @@ -39,18 +39,20 @@ def __init__(self, *, llm, guide, K=None): super().__init__(words, old_eos=llm.eos, new_eos=guide.eos) - def _p_next(self, context, K=None): + def _p_next(self, context, K=None, execute_model_req=None, **kwargs): with self.timer['llm'](t=len(context)): - p_llm = self.llm.p_next(self._prompt + context) + p_llm = self.llm.p_next(self._prompt + context, execute_model_req=execute_model_req) with self.timer['cfg+trie'](t=len(context)): return Float.chart(take(K, self.traverse_trie(context, p_llm))).normalize() async def sample_next_token( - self, prompt, context, verbosity=0, compare_time=False, draw=sample_dict, **kwargs + self, prompt, context, verbosity=0, compare_time=False, draw=sample_dict, + execute_model_req=None, + **kwargs ): with self.timer['llm'](t=len(context)): - p_llm = await self.llm.p_next(prompt + context) + p_llm = await self.llm.p_next(prompt + context, execute_model_req=execute_model_req) with self.timer['cfg+trie'](t=len(context)): Q = Float.chart( diff --git a/vllm_runtime.pdf b/vllm_runtime.pdf deleted file mode 100644 index 29242fc8..00000000 Binary files a/vllm_runtime.pdf and /dev/null differ diff --git a/vllm_runtime.pkl b/vllm_runtime.pkl deleted file mode 100644 index a8b6f384..00000000 Binary files a/vllm_runtime.pkl and /dev/null differ