From ecede3fb1c4077c3053cb0eb6fc7df755d75c0ad Mon Sep 17 00:00:00 2001 From: Sid Jha Date: Sun, 29 Sep 2024 14:59:01 -0700 Subject: [PATCH] Minor fix --- lotus/models/openai_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lotus/models/openai_model.py b/lotus/models/openai_model.py index 0d992a9f..a828ca2b 100644 --- a/lotus/models/openai_model.py +++ b/lotus/models/openai_model.py @@ -52,6 +52,7 @@ def __init__( self.use_chat = provider in ["openai", "dbrx", "ollama"] self.max_batch_size = max_batch_size self.max_ctx_len = max_ctx_len + self.hf_name = hf_name if hf_name is not None else model self.kwargs = { "model": model, @@ -68,7 +69,7 @@ def __init__( if self.provider == "openai": self.tokenizer = tiktoken.encoding_for_model(model) else: - self.tokenizer = AutoTokenizer.from_pretrained(hf_name) + self.tokenizer = AutoTokenizer.from_pretrained(self.hf_name) def handle_chat_request(self, messages: List, **kwargs: Dict[str, Any]) -> Union[List, Tuple[List, List]]: """Handle single chat request to OpenAI server.