diff --git a/textgrad/engine/vllm.py b/textgrad/engine/vllm.py
index 69f241a..c898a63 100644
--- a/textgrad/engine/vllm.py
+++ b/textgrad/engine/vllm.py
@@ -18,6 +18,7 @@ def __init__(
         self,
         model_string="meta-llama/Meta-Llama-3-8B-Instruct",
         system_prompt=DEFAULT_SYSTEM_PROMPT,
+        **llm_config,
     ):
         root = platformdirs.user_cache_dir("textgrad")
         cache_path = os.path.join(root, f"cache_vllm_{model_string}.db")
@@ -25,7 +26,7 @@ def __init__(
 
         self.model_string = model_string
         self.system_prompt = system_prompt
-        self.client = LLM(self.model_string)
+        self.client = LLM(self.model_string, **llm_config)
         self.tokenizer = self.client.get_tokenizer()
 
     def generate(
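
With this change, any extra keyword arguments given to the engine constructor are forwarded verbatim to vLLM's LLM(...). A minimal usage sketch follows; the class name ChatVLLM is an assumption (the diff does not show it), while tensor_parallel_size, gpu_memory_utilization, and dtype are standard vllm.LLM constructor options:

from textgrad.engine.vllm import ChatVLLM  # class name assumed; the diff does not show it

# Keyword arguments beyond model_string/system_prompt are passed
# straight through to vllm.LLM via the new **llm_config.
engine = ChatVLLM(
    model_string="meta-llama/Meta-Llama-3-8B-Instruct",
    tensor_parallel_size=2,        # shard the model across 2 GPUs
    gpu_memory_utilization=0.90,   # cap vLLM's share of GPU memory
    dtype="float16",               # load weights in half precision
)

Previously these engine options could not be set without editing the engine class, since LLM(self.model_string) was called with no additional arguments.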