diff --git a/dsp/modules/hf_client.py b/dsp/modules/hf_client.py
index de9f986f6..71cc22666 100644
--- a/dsp/modules/hf_client.py
+++ b/dsp/modules/hf_client.py
@@ -307,11 +307,13 @@ def run_server(self, port, model_name=None, model_path=None, env_variable=None,
         docker_process.wait()


 class Together(HFModel):
-    def __init__(self, model, **kwargs):
+    def __init__(self, model, api_base="https://api.together.xyz/v1", api_key=None, **kwargs):
         super().__init__(model=model, is_client=True)
         self.session = requests.Session()
-        self.api_base = os.getenv("TOGETHER_API_BASE")
-        self.token = os.getenv("TOGETHER_API_KEY")
+        self.api_base = os.getenv("TOGETHER_API_BASE") or api_base
+        assert not self.api_base.endswith("/"), "Together base URL shouldn't end with /"
+        self.token = os.getenv("TOGETHER_API_KEY") or api_key
+
         self.model = model
         self.use_inst_template = False
@@ -338,8 +340,6 @@ def __init__(self, model, **kwargs):
         on_backoff=backoff_hdlr,
     )
     def _generate(self, prompt, use_chat_api=False, **kwargs):
-        url = f"{self.api_base}"
-
         kwargs = {**self.kwargs, **kwargs}

         stop = kwargs.get("stop")
@@ -367,6 +367,7 @@ def _generate(self, prompt, use_chat_api=False, **kwargs):
                 "stop": stop,
             }
         else:
+            url = f"{self.api_base}/completions"
             body = {
                 "model": self.model,
                 "prompt": prompt,
@@ -384,9 +385,9 @@ def _generate(self, prompt, use_chat_api=False, **kwargs):
             with self.session.post(url, headers=headers, json=body) as resp:
                 resp_json = resp.json()
                 if use_chat_api:
-                    completions = [resp_json['output'].get('choices', [])[0].get('message', {}).get('content', "")]
+                    completions = [resp_json.get('choices', [])[0].get('message', {}).get('content', "")]
                 else:
-                    completions = [resp_json['output'].get('choices', [])[0].get('text', "")]
+                    completions = [resp_json.get('choices', [])[0].get('text', "")]
                 response = {"prompt": prompt, "choices": [{"text": c} for c in completions]}
                 return response
         except Exception as e:
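
A minimal usage sketch against the patched constructor; the model name and placeholder key below are illustrative only and not part of this diff. The TOGETHER_API_BASE / TOGETHER_API_KEY environment variables still take precedence over the explicit arguments when they are set.

    from dsp.modules.hf_client import Together

    # Credentials and base URL can now be passed directly instead of only via
    # environment variables; env vars override these arguments when present.
    lm = Together(
        model="mistralai/Mixtral-8x7B-Instruct-v0.1",  # illustrative model name
        api_base="https://api.together.xyz/v1",        # default added by this patch (no trailing slash)
        api_key="<your-together-api-key>",             # placeholder
    )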