From 25b65ff14d1f73b3a6e633eba4628b22ffbe163d Mon Sep 17 00:00:00 2001 From: Robert Kirchner Date: Wed, 11 Sep 2024 10:49:43 -0500 Subject: [PATCH] fix parse completion request body - rk --- src/main/serve/routers/open_ai_router.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/main/serve/routers/open_ai_router.py b/src/main/serve/routers/open_ai_router.py index e1ebbec..03598a0 100644 --- a/src/main/serve/routers/open_ai_router.py +++ b/src/main/serve/routers/open_ai_router.py @@ -24,9 +24,8 @@ def chat_completions_endpoint(): body = request.get_json(force=True) prompt = _construct_chat_prompt(body) - - max_tokens = int(body['max_tokens']) if body['max_tokens'] is not None else 100 - completion = llm.completion(prompt, max_tokens, parse_temp(float(body['temperature'])), stops=body['stop'], repetition_penalty=body['frequency_penalty']) + max_tokens = int(body['max_tokens']) if 'max_tokens' in body else 100 + completion = llm.completion(prompt, max_tokens, parse_temp(float(body['temperature']) if 'temperature' in body else 0), stops=body['stop'] if 'stop' in body else None, repetition_penalty=body['frequency_penalty'] if 'frequency_penalty' in body else None) prompt_tokens = len(encoding.encode(prompt)) completion_tokens = len(encoding.encode(completion)) chat_response = { @@ -42,7 +41,7 @@ def chat_completions_endpoint(): "content": f"{completion}", }, "logprobs": None, - "finish_reason": "length" if body['stop'] is None or not completion.endswith(body['stop']) else "stop" + "finish_reason": _get_finish_reason(body, completion) }], "usage": { "prompt_tokens": prompt_tokens, @@ -55,8 +54,8 @@ def chat_completions_endpoint(): @app.route("/v1/completions", methods=['POST']) def completions_endpoint(): body = request.get_json(force=True) - max_tokens = int(body['max_tokens']) if body['max_tokens'] is not None else 100 - completion = llm.completion(body['prompt'], max_tokens, parse_temp(float(body['temperature'])), 
stops=body['stop'], repetition_penalty=body['frequency_penalty']) + max_tokens = int(body['max_tokens']) if 'max_tokens' in body else 100 + completion = llm.completion(body['prompt'], max_tokens, parse_temp(float(body['temperature']) if 'temperature' in body else 0), stops=body['stop'] if 'stop' in body else None, repetition_penalty=body['frequency_penalty'] if 'frequency_penalty' in body else None) prompt_tokens = len(encoding.encode(body['prompt'])) completion_tokens = len(encoding.encode(completion)) @@ -71,7 +70,7 @@ def completions_endpoint(): "text": completion, "index": 0, "logprobs": None, - "finish_reason": "length" if body['stop'] is None or not completion.endswith(body['stop']) else "stop" + "finish_reason": _get_finish_reason(body, completion) } ], "usage": { @@ -82,3 +81,10 @@ } return jsonify(completion_response) + +def _get_finish_reason(body: dict, completion: str) -> str: + stop = body.get('stop') or [] + for s in [stop] if isinstance(stop, str) else stop: + if completion.endswith(s): + return "stop" + return "length" \ No newline at end of file