start litellm support, not yet working
njbrake committed Feb 6, 2025
1 parent 9c2fe5b commit 69bdbca
Showing 18 changed files with 1,072 additions and 229 deletions.
19 changes: 4 additions & 15 deletions lumigator/backend/backend/api/deps.py
@@ -11,7 +11,7 @@
from backend.db import session_manager
from backend.repositories.datasets import DatasetRepository
from backend.repositories.jobs import JobRepository, JobResultRepository
from backend.services.completions import MistralCompletionService, OpenAICompletionService
from backend.services.completions import LiteLLMCompletionService as CompletionService
from backend.services.datasets import DatasetService
from backend.services.experiments import ExperimentService
from backend.services.jobs import JobService
@@ -96,19 +96,8 @@ def get_workflow_service(
WorkflowServiceDep = Annotated[WorkflowService, Depends(get_workflow_service)]


def get_mistral_completion_service() -> MistralCompletionService:
return MistralCompletionService()
def get_completion_service() -> CompletionService:
return CompletionService()


MistralCompletionServiceDep = Annotated[
MistralCompletionService, Depends(get_mistral_completion_service)
]


def get_openai_completion_service() -> OpenAICompletionService:
return OpenAICompletionService()


OpenAICompletionServiceDep = Annotated[
OpenAICompletionService, Depends(get_openai_completion_service)
]
CompletionServiceDep = Annotated[CompletionService, Depends(get_completion_service)]
16 changes: 3 additions & 13 deletions lumigator/backend/backend/api/routes/completions.py
@@ -3,7 +3,7 @@
from fastapi import APIRouter, status
from lumigator_schemas.completions import CompletionRequest

from backend.api.deps import MistralCompletionServiceDep, OpenAICompletionServiceDep
from backend.api.deps import CompletionServiceDep
from backend.services.exceptions.base_exceptions import ServiceError
from backend.services.exceptions.completion_exceptions import CompletionUpstreamError

@@ -19,16 +19,6 @@ def completion_exception_mappings() -> dict[type[ServiceError], HTTPStatus]:
}


@router.get("/")
def list_vendors():
return [VENDOR_MISTRAL, VENDOR_OPENAI]


@router.post(f"/{VENDOR_MISTRAL}")
def get_mistral_completion(request: CompletionRequest, service: MistralCompletionServiceDep):
return service.get_completions_response(request)


@router.post(f"/{VENDOR_OPENAI}")
def get_openai_completion(request: CompletionRequest, service: OpenAICompletionServiceDep):
@router.post("/")
def get_completion(request: CompletionRequest, service: CompletionServiceDep) -> dict:
return service.get_completions_response(request)
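
For reference, a minimal sketch of calling the unified route, assuming the completions router is mounted under /api/v1/completions (the mount prefix is not part of this diff) and that CompletionRequest exposes text and model_name, as the service code further down suggests:

import requests

resp = requests.post(
    "http://localhost:8000/api/v1/completions/",
    json={
        "text": "LiteLLM lets one client reach several providers through a single interface.",
        "model_name": "gpt-4o-mini",  # must be a key of SUPPORTED_CONFIGS
    },
)
resp.raise_for_status()
print(resp.json()["text"])  # CompletionResponse serialized by FastAPI
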
41 changes: 37 additions & 4 deletions lumigator/backend/backend/config_templates.py
@@ -1,6 +1,8 @@
# Evaluation templates
from lumigator_schemas.jobs import JobType

from backend.settings import settings

seq2seq_eval_template = """{{
"name": "{job_name}/{job_id}",
"model": {{ "path": "{model_uri}" }},
@@ -127,7 +129,6 @@
"dataset": {{ "path": "{dataset_path}" }}
}}"""


oai_infer_template = """{{
"name": "{job_name}/{job_id}",
"dataset": {{ "path": "{dataset_path}" }},
@@ -151,12 +152,44 @@
}}"""


# LiteLLM supports many more options, but for now we keep things simple with just a few
# opinionated configurations: https://docs.litellm.ai/docs/providers
SUPPORTED_CONFIGS = {
"open-mistral-7b": {
"litellm_params": {
"model": "mistral/open-mistral-7b",
},
"max_tokens": 256,
"temperature": 1,
"top_p": 1,
"prompt": settings.DEFAULT_SUMMARIZER_PROMPT,
},
"gpt-4o-mini": {
"litellm_params": {
"model": "text-completion-openai/gpt-4o-mini",
},
"max_tokens": 256,
"temperature": 1,
"top_p": 1,
"prompt": settings.DEFAULT_SUMMARIZER_PROMPT,
},
"gpt-4o": {
"litellm_params": {
"model": "text-completion-openai/gpt-4o",
},
"max_tokens": 256,
"temperature": 1,
"top_p": 1,
"prompt": settings.DEFAULT_SUMMARIZER_PROMPT,
},
}

templates = {
JobType.INFERENCE: {
"default": default_infer_template,
"oai://gpt-4o-mini": oai_infer_template,
"oai://gpt-4o": oai_eval_template,
"mistral://open-mistral-7b": oai_infer_template,
SUPPORTED_CONFIGS["gpt-4o-mini"]["litellm_params"]["model"]: oai_infer_template,
SUPPORTED_CONFIGS["gpt-4o"]["litellm_params"]["model"]: oai_infer_template,
SUPPORTED_CONFIGS["open-mistral-7b"]["litellm_params"]["model"]: oai_infer_template,
"llamafile://mistralai/Mistral-7B-Instruct-v0.2": oai_infer_template,
},
JobType.EVALUATION: {
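
The inference templates are now keyed by the provider-prefixed LiteLLM model string instead of oai:// or mistral:// URIs. A minimal sketch of how a key resolves, assuming the SUPPORTED_CONFIGS and templates dicts shown above:

from lumigator_schemas.jobs import JobType

from backend.config_templates import SUPPORTED_CONFIGS, templates

model_key = SUPPORTED_CONFIGS["open-mistral-7b"]["litellm_params"]["model"]
print(model_key)  # "mistral/open-mistral-7b"
template = templates[JobType.INFERENCE][model_key]  # resolves to oai_infer_template
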
79 changes: 16 additions & 63 deletions lumigator/backend/backend/services/completions.py
@@ -1,14 +1,10 @@
from abc import ABC, abstractmethod
from http import HTTPStatus

from litellm import APIError, OpenAIError, completion
from lumigator_schemas.completions import CompletionRequest, CompletionResponse
from mistralai.client import MistralClient
from mistralai.exceptions import MistralAPIException, MistralException
from mistralai.models.chat_completion import ChatMessage
from openai import APIError, OpenAI, OpenAIError

from backend.config_templates import SUPPORTED_CONFIGS
from backend.services.exceptions.completion_exceptions import CompletionUpstreamError
from backend.settings import settings


class CompletionService(ABC):
@@ -17,73 +13,30 @@ def get_completions_response(self, request: CompletionRequest) -> CompletionResp
pass


class MistralCompletionService(CompletionService):
def __init__(self):
if settings.MISTRAL_API_KEY is None:
raise Exception("MISTRAL_API_KEY is not set")
self.client = MistralClient(api_key=settings.MISTRAL_API_KEY)
self.model = "open-mistral-7b"
self.max_tokens = 256
self.temperature = 1
self.top_p = 1
self.prompt = settings.DEFAULT_SUMMARIZER_PROMPT

def get_completions_response(self, request: CompletionRequest) -> CompletionResponse:
"""Gets a completion response from the API.
:param request: the request (text) to be completed
:raises CompletionUpstreamError: if there is an exception interacting with Mistral
"""
service_name = "Mistral"
try:
response = self.client.chat(
model=self.model,
messages=[
ChatMessage(role="system", content=f"{self.prompt}"),
ChatMessage(role="user", content=f"{request.text}"),
],
temperature=self.temperature,
max_tokens=self.max_tokens,
top_p=self.top_p,
)
response = response.choices[0].message.content
return CompletionResponse(text=response)
except MistralAPIException as e:
raise CompletionUpstreamError(service_name, HTTPStatus(e.http_status).phrase, e) from e
except MistralException as e:
raise CompletionUpstreamError(
service_name, "unexpected error getting completions response", e
) from e


class OpenAICompletionService(CompletionService):
def __init__(self):
if settings.OAI_API_KEY is None:
raise Exception("OPENAI_API_KEY is not set")
self.client = OpenAI(api_key=settings.OAI_API_KEY)
self.model = "gpt-4o-mini"
self.max_tokens = 256
self.temperature = 1
self.top_p = 1
self.prompt = settings.DEFAULT_SUMMARIZER_PROMPT

class LiteLLMCompletionService(CompletionService):
def get_completions_response(self, request: CompletionRequest) -> CompletionResponse:
"""Gets a completion response from the API.
:param request: the request (text) to be completed
:raises CompletionUpstreamError: if there is an exception interacting with OpenAI
"""
service_name = "OpenAI"
model = request.model_name
model_config = SUPPORTED_CONFIGS.get(model)
service_name = "LiteLLM"
if not model_config:
raise CompletionUpstreamError(
service_name, f"model {model} is not supported by Lumigator", None
)
try:
response = self.client.chat.completions.create(
model=self.model,
response = completion(
model=model_config["litellm_params"]["model"],
messages=[
{"role": "system", "content": self.prompt},
{"role": "system", "content": model_config["prompt"]},
{"role": "user", "content": request.text},
],
temperature=self.temperature,
max_tokens=self.max_tokens,
top_p=self.top_p,
temperature=model_config["temperature"],
max_tokens=model_config["max_tokens"],
top_p=model_config["top_p"],
)
response = response.choices[0].message.content
return CompletionResponse(text=response)
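
A hedged usage sketch of the new service, assuming CompletionRequest accepts text and model_name (both referenced above) and that LiteLLM can pick up the relevant provider credentials (for example MISTRAL_API_KEY or OPENAI_API_KEY) from the environment:

from lumigator_schemas.completions import CompletionRequest

from backend.services.completions import LiteLLMCompletionService

service = LiteLLMCompletionService()
request = CompletionRequest(
    text="Summarize: LiteLLM normalizes provider-specific SDKs behind one completion() call.",
    model_name="open-mistral-7b",  # any key of SUPPORTED_CONFIGS
)
response = service.get_completions_response(request)
print(response.text)
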
15 changes: 2 additions & 13 deletions lumigator/backend/backend/services/jobs.py
@@ -306,17 +306,6 @@ def _get_config_template(self, job_type: str, model_name: str) -> str:

return config_template

def _set_model_type(self, request) -> str:
"""Sets model URL based on protocol address"""
if request.model.startswith("oai://"):
model_url = settings.OAI_API_URL
elif request.model.startswith("mistral://"):
model_url = settings.MISTRAL_API_URL
else:
model_url = request.model_url

return model_url

def _validate_config(self, job_type: str, config_template: str, config_params: dict):
if job_type == JobType.INFERENCE:
InferenceJobConfig.model_validate_json(config_template.format(**config_params))
@@ -350,7 +339,7 @@ def _get_job_params(self, job_type: JobType, record, request: BaseModel) -> dict
# this section differs between inference and eval
if job_type == JobType.EVALUATION:
job_params = job_params | {
"model_url": self._set_model_type(request),
"model_url": request.model_url,
"skip_inference": request.skip_inference,
"system_prompt": request.system_prompt,
}
@@ -359,7 +348,7 @@ def _get_job_params(self, job_type: JobType, record, request: BaseModel) -> dict
"accelerator": request.accelerator,
"frequency_penalty": request.frequency_penalty,
"max_tokens": request.max_tokens,
"model_url": self._set_model_type(request),
"model_url": request.model_url,
"output_field": request.output_field,
"revision": request.revision,
"system_prompt": request.system_prompt,
2 changes: 0 additions & 2 deletions lumigator/backend/backend/settings.py
@@ -55,8 +55,6 @@ def TRACKING_BACKEND_URI(self) -> str:  # noqa: N802
raise ValueError(f"Unsupported tracking backend: {self.TRACKING_BACKEND}")

# Served models
OAI_API_URL: str = "https://api.openai.com/v1"
MISTRAL_API_URL: str = "https://api.mistral.ai/v1"
DEFAULT_SUMMARIZER_PROMPT: str = "You are a helpful assistant, expert in text summarization. For every prompt you receive, provide a summary of its contents in at most two sentences." # noqa: E501

# Eval job details
(changed file path not shown)
@@ -61,10 +61,6 @@ def test_set_explicit_inference_job_params(job_record, job_service):
"http://localhost:8000/v1/chat/completions",
"http://localhost:8000/v1/chat/completions",
),
# openai model (from API)
("oai://gpt-4-turbo", None, settings.OAI_API_URL),
# mistral model (from API)
("mistral://open-mistral-7b", None, settings.MISTRAL_API_URL),
],
)
def test_set_model(job_service, model, input_model_url, returned_model_url):
@@ -75,5 +71,5 @@ def test_set_model(job_service, model, input_model_url, returned_model_url):
model_url=input_model_url,
dataset="d34dd34d-d34d-d34d-d34d-d34dd34dd34d",
)
model_url = job_service._set_model_type(request)
model_url = request.model_url
assert model_url == returned_model_url
3 changes: 2 additions & 1 deletion lumigator/backend/pyproject.toml
@@ -25,7 +25,8 @@ dependencies = [
"python-dotenv>=1.0.1",
"alembic>=1.13.3",
"lumigator-schemas",
"mlflow>=2.20.0"
"mlflow>=2.20.0",
"litellm>=1.43.1",
]

[tool.uv]