diff --git a/backend/app/pkgs/tools/llm_basic.py b/backend/app/pkgs/tools/llm_basic.py index 41042342..965bf990 100644 --- a/backend/app/pkgs/tools/llm_basic.py +++ b/backend/app/pkgs/tools/llm_basic.py @@ -1,6 +1,7 @@ import threading import time import openai +import litellm from app.pkgs.tools.llm_interface import LLMInterface from config import LLM_MODEL from config import GPT_KEYS @@ -37,12 +38,21 @@ def chatCompletion(self, context): openai.api_key = get_next_api_key() print("chartGPT - get api key:"+openai.api_key, flush=True) - response = openai.ChatCompletion.create( - model= LLM_MODEL, - messages=context, - max_tokens=12000, - temperature=0, - ) + if LLM_MODEL in litellm.models: + # see litellm supported models here: https://litellm.readthedocs.io/en/latest/supported/ + response = litellm.completion( + model= LLM_MODEL, + messages=context, + max_tokens=12000, + temperature=0, + ) + else: + response = openai.ChatCompletion.create( + model= LLM_MODEL, + messages=context, + max_tokens=12000, + temperature=0, + ) response_text = response["choices"][0]["message"]["content"] print("chartGPT - response_text:"+response_text, flush=True) diff --git a/requirements.txt b/requirements.txt index 388ce27a..c1001cf5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,5 @@ flask-sqlalchemy flask-cors==3.0.10 pyyaml python-gitlab -openai==0.27.8 \ No newline at end of file +openai==0.27.8 +litellm==0.1.226 \ No newline at end of file