Add Ollama Support (#10)
Adds support for Ollama, along with examples of using various LM providers.
This change also introduces an `hf_name` parameter on the `OpenAIModel` class,
which is useful when the served model name and the Hugging Face tokenizer name
differ. It replaces the hardcoded name-to-tokenizer mapping previously needed
for Databricks models.

Tested Ollama by running the operator examples against Llama 3.2 3B served
locally on a MacBook. Verified that the OpenAI examples still work.
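
As a minimal sketch of the new parameter (not part of this commit; the model/tokenizer pairing is taken from the removed DBRX_NAME_TO_MODEL mapping, and api_key/api_base are omitted):

from lotus.models import OpenAIModel

# The serving endpoint knows the model as "databricks-dbrx-instruct", while the
# tokenizer lives on Hugging Face under "databricks/dbrx-instruct"; hf_name
# bridges the two without a hardcoded mapping.
lm = OpenAIModel(
    model="databricks-dbrx-instruct",
    hf_name="databricks/dbrx-instruct",
    provider="dbrx",
)
# When hf_name is omitted, it defaults to `model`, so OpenAI-style setups are unchanged.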
sidjha1 authored Sep 29, 2024
1 parent 32c37b5 commit 2dad842
Showing 4 changed files with 78 additions and 15 deletions.
20 changes: 20 additions & 0 deletions examples/provider_examples/oai.py
@@ -0,0 +1,20 @@
import pandas as pd

import lotus
from lotus.models import OpenAIModel

# Defaults to OpenAI's gpt-4o-mini and reads OPENAI_API_KEY from the environment.
lm = OpenAIModel()

lotus.settings.configure(lm=lm)
data = {
    "Course Name": [
        "Probability and Random Processes",
        "Optimization Methods in Engineering",
        "Digital Design and Integrated Circuits",
        "Computer Security",
    ]
}
df = pd.DataFrame(data)
user_instruction = "{Course Name} requires a lot of math"
df = df.sem_filter(user_instruction)
print(df)
25 changes: 25 additions & 0 deletions examples/provider_examples/ollama.py
@@ -0,0 +1,25 @@
import pandas as pd

import lotus
from lotus.models import OpenAIModel

# Ollama exposes an OpenAI-compatible API at http://localhost:11434/v1 by default.
lm = OpenAIModel(
    api_base="http://localhost:11434/v1",
    model="llama3.2",  # model tag as Ollama knows it
    hf_name="meta-llama/Llama-3.2-3B-Instruct",  # matching Hugging Face tokenizer
    provider="ollama",
)

lotus.settings.configure(lm=lm)
data = {
    "Course Name": [
        "Probability and Random Processes",
        "Optimization Methods in Engineering",
        "Digital Design and Integrated Circuits",
        "Computer Security",
    ]
}
df = pd.DataFrame(data)
user_instruction = "{Course Name} requires a lot of math"
df = df.sem_filter(user_instruction)
print(df)
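
Note: running this example assumes an Ollama server is already up locally (e.g. started with `ollama serve`) and that the model has been pulled (`ollama pull llama3.2`).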
24 changes: 24 additions & 0 deletions examples/provider_examples/vllm.py
@@ -0,0 +1,24 @@
import pandas as pd

import lotus
from lotus.models import OpenAIModel

# Point at a local vLLM server, which exposes an OpenAI-compatible API
# (port 8000 by default).
lm = OpenAIModel(
    model="meta-llama/Meta-Llama-3.1-70B-Instruct",
    api_base="http://localhost:8000/v1",
    provider="vllm",
)

lotus.settings.configure(lm=lm)
data = {
    "Course Name": [
        "Probability and Random Processes",
        "Optimization Methods in Engineering",
        "Digital Design and Integrated Circuits",
        "Computer Security",
    ]
}
df = pd.DataFrame(data)
user_instruction = "{Course Name} requires a lot of math"
df = df.sem_filter(user_instruction)
print(df)
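
Note: this example assumes a vLLM OpenAI-compatible server is already serving the model, e.g. launched with something like `python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3.1-70B-Instruct` (the exact launch command depends on your vLLM version).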
24 changes: 9 additions & 15 deletions lotus/models/openai_model.py
@@ -1,3 +1,4 @@
+import os
 import threading
 from typing import Any, Dict, List, Optional, Tuple, Union

@@ -11,13 +12,6 @@
 import lotus
 from lotus.models.lm import LM

-# Mapping from Databricks model names to their Hugging Face model names for tokenizers
-DBRX_NAME_TO_MODEL = {
-    "databricks-dbrx-instruct": "databricks/dbrx-instruct",
-    "databricks-llama-2-70b-chat": "meta-llama/Llama-2-70b-chat-hf",
-    "databricks-mixtral-8x7b-instruct": "mistralai/Mixtral-8x7B-Instruct-v0.1",
-}

 ERRORS = (openai.RateLimitError, openai.APIError)


@@ -46,18 +40,20 @@ class OpenAIModel(LM):
     def __init__(
         self,
         model: str = "gpt-4o-mini",
+        hf_name: Optional[str] = None,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
         provider: str = "openai",
-        max_batch_size=64,
-        max_ctx_len=4096,
+        max_batch_size: int = 64,
+        max_ctx_len: int = 4096,
         **kwargs: Dict[str, Any],
     ):
         super().__init__()
         self.provider = provider
-        self.use_chat = provider in ["openai", "dbrx"]
+        self.use_chat = provider in ["openai", "dbrx", "ollama"]
         self.max_batch_size = max_batch_size
         self.max_ctx_len = max_ctx_len
+        self.hf_name = hf_name if hf_name is not None else model

         self.kwargs = {
             "model": model,
@@ -68,16 +64,14 @@ def __init__
             **kwargs,
         }

-        self.client = OpenAI(api_key=api_key, base_url=api_base)
+        api_key = api_key or os.environ.get("OPENAI_API_KEY", "None")
+        self.client = OpenAI(api_key=api_key if api_key else "None", base_url=api_base)

         self.kwargs["model"] = model
         # TODO: Refactor this
         if self.provider == "openai":
             self.tokenizer = tiktoken.encoding_for_model(model)
-        elif model in DBRX_NAME_TO_MODEL:
-            self.tokenizer = AutoTokenizer.from_pretrained(DBRX_NAME_TO_MODEL[model])
         else:
-            self.tokenizer = AutoTokenizer.from_pretrained(model)
+            self.tokenizer = AutoTokenizer.from_pretrained(self.hf_name)

     def handle_chat_request(self, messages: List, **kwargs: Dict[str, Any]) -> Union[List, Tuple[List, List]]:
         """Handle single chat request to OpenAI server.
