Skip to content

Commit

Permalink
fix(dspy): together client response parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
atmiguel committed Jun 5, 2024
1 parent f18e604 commit 285dc90
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions dsp/modules/hf_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,11 +307,13 @@ def run_server(self, port, model_name=None, model_path=None, env_variable=None,
docker_process.wait()

class Together(HFModel):
def __init__(self, model, **kwargs):
def __init__(self, model, api_base="https://api.together.xyz/v1", api_key=None, **kwargs):
super().__init__(model=model, is_client=True)
self.session = requests.Session()
self.api_base = os.getenv("TOGETHER_API_BASE")
self.token = os.getenv("TOGETHER_API_KEY")
self.api_base = os.getenv("TOGETHER_API_BASE") or api_base
assert not self.api_base.endswith("/"), "Together base URL shouldn't end with /"
self.token = os.getenv("TOGETHER_API_KEY") or api_key

self.model = model

self.use_inst_template = False
Expand All @@ -338,8 +340,6 @@ def __init__(self, model, **kwargs):
on_backoff=backoff_hdlr,
)
def _generate(self, prompt, use_chat_api=False, **kwargs):
url = f"{self.api_base}"

kwargs = {**self.kwargs, **kwargs}

stop = kwargs.get("stop")
Expand Down Expand Up @@ -367,6 +367,7 @@ def _generate(self, prompt, use_chat_api=False, **kwargs):
"stop": stop,
}
else:
url = f"{self.api_base}/completions"
body = {
"model": self.model,
"prompt": prompt,
Expand All @@ -384,9 +385,9 @@ def _generate(self, prompt, use_chat_api=False, **kwargs):
with self.session.post(url, headers=headers, json=body) as resp:
resp_json = resp.json()
if use_chat_api:
completions = [resp_json['output'].get('choices', [])[0].get('message', {}).get('content', "")]
completions = [resp_json.get('choices', [])[0].get('message', {}).get('content', "")]
else:
completions = [resp_json['output'].get('choices', [])[0].get('text', "")]
completions = [resp_json.get('choices', [])[0].get('text', "")]
response = {"prompt": prompt, "choices": [{"text": c} for c in completions]}
return response
except Exception as e:
Expand Down

0 comments on commit 285dc90

Please sign in to comment.