Add ollama support and provider examples
sidjha1 committed Sep 29, 2024
1 parent 9f0db7d commit d06f68c
Showing 4 changed files with 74 additions and 14 deletions.
20 changes: 20 additions & 0 deletions examples/provider_examples/oai.py
@@ -0,0 +1,20 @@
import pandas as pd

import lotus
from lotus.models import OpenAIModel

lm = OpenAIModel()

lotus.settings.configure(lm=lm)
data = {
    "Course Name": [
        "Probability and Random Processes",
        "Optimization Methods in Engineering",
        "Digital Design and Integrated Circuits",
        "Computer Security",
    ]
}
df = pd.DataFrame(data)
user_instruction = "{Course Name} requires a lot of math"
df = df.sem_filter(user_instruction)
print(df)
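The OpenAI example above leaves api_key unset, so it relies on the client picking up credentials from the environment. The snippet below is a small sketch of passing the key explicitly instead; the OPENAI_API_KEY variable name is the openai client's convention, not something defined in this diff.

# Sketch of supplying the key explicitly; assumes OPENAI_API_KEY is set in the shell.
import os

from lotus.models import OpenAIModel

lm = OpenAIModel(api_key=os.environ["OPENAI_API_KEY"])  # equivalent to the default env-var lookup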
25 changes: 25 additions & 0 deletions examples/provider_examples/ollama.py
@@ -0,0 +1,25 @@
import pandas as pd

import lotus
from lotus.models import OpenAIModel

lm = OpenAIModel(
    api_base="http://localhost:11434/v1",
    model="llama3.2",
    hf_name="meta-llama/Llama-3.2-3B-Instruct",
    provider="ollama",
)

lotus.settings.configure(lm=lm)
data = {
    "Course Name": [
        "Probability and Random Processes",
        "Optimization Methods in Engineering",
        "Digital Design and Integrated Circuits",
        "Computer Security",
    ]
}
df = pd.DataFrame(data)
user_instruction = "{Course Name} requires a lot of math"
df = df.sem_filter(user_instruction)
print(df)
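The Ollama example assumes a local Ollama server is already exposing its OpenAI-compatible API at http://localhost:11434/v1 and that the llama3.2 model has been pulled; hf_name points at the Hugging Face repo used only for tokenization (see the openai_model.py change below). A minimal pre-flight check under those assumptions:

# Sketch of a pre-flight check; the endpoint matches the example above, the rest is an assumption.
import requests

try:
    requests.get("http://localhost:11434/v1/models", timeout=2).raise_for_status()
except requests.RequestException as exc:
    raise SystemExit(
        "Ollama is not reachable at http://localhost:11434 -- "
        "start it with `ollama serve` and fetch the model with `ollama pull llama3.2`"
    ) from exc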
24 changes: 24 additions & 0 deletions examples/provider_examples/vllm.py
@@ -0,0 +1,24 @@
import pandas as pd

import lotus
from lotus.models import OpenAIModel

lm = OpenAIModel(
    model="meta-llama/Meta-Llama-3.1-70B-Instruct",
    api_base="http://localhost:8000/v1",
    provider="vllm",
)

lotus.settings.configure(lm=lm)
data = {
    "Course Name": [
        "Probability and Random Processes",
        "Optimization Methods in Engineering",
        "Digital Design and Integrated Circuits",
        "Computer Security",
    ]
}
df = pd.DataFrame(data)
user_instruction = "{Course Name} requires a lot of math"
df = df.sem_filter(user_instruction)
print(df)
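The vLLM example presumes an OpenAI-compatible vLLM server is already listening on http://localhost:8000/v1 and serving the same model name. One way to start it, sketched via subprocess (vLLM must be installed and the hardware able to host the 70B weights; the flags shown are standard vLLM options, not part of this diff):

# Sketch: launch vLLM's OpenAI-compatible server to back the example above.
import subprocess

server = subprocess.Popen(
    [
        "python", "-m", "vllm.entrypoints.openai.api_server",
        "--model", "meta-llama/Meta-Llama-3.1-70B-Instruct",
        "--port", "8000",
    ]
)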
19 changes: 5 additions & 14 deletions lotus/models/openai_model.py
@@ -11,13 +11,6 @@
 import lotus
 from lotus.models.lm import LM
 
-# Mapping from Databricks model names to their Hugging Face model names for tokenizers
-DBRX_NAME_TO_MODEL = {
-    "databricks-dbrx-instruct": "databricks/dbrx-instruct",
-    "databricks-llama-2-70b-chat": "meta-llama/Llama-2-70b-chat-hf",
-    "databricks-mixtral-8x7b-instruct": "mistralai/Mixtral-8x7B-Instruct-v0.1",
-}
-
 ERRORS = (openai.RateLimitError, openai.APIError)
 
 
@@ -46,16 +39,17 @@ class OpenAIModel(LM):
     def __init__(
         self,
         model: str = "gpt-4o-mini",
+        hf_name: Optional[str] = None,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
         provider: str = "openai",
-        max_batch_size=64,
-        max_ctx_len=4096,
+        max_batch_size: int = 64,
+        max_ctx_len: int = 4096,
         **kwargs: Dict[str, Any],
     ):
         super().__init__()
         self.provider = provider
-        self.use_chat = provider in ["openai", "dbrx"]
+        self.use_chat = provider in ["openai", "dbrx", "ollama"]
         self.max_batch_size = max_batch_size
         self.max_ctx_len = max_ctx_len
 
@@ -70,14 +64,11 @@ def __init__(
 
         self.client = OpenAI(api_key=api_key, base_url=api_base)
 
         self.kwargs["model"] = model
-        # TODO: Refactor this
         if self.provider == "openai":
             self.tokenizer = tiktoken.encoding_for_model(model)
-        elif model in DBRX_NAME_TO_MODEL:
-            self.tokenizer = AutoTokenizer.from_pretrained(DBRX_NAME_TO_MODEL[model])
         else:
-            self.tokenizer = AutoTokenizer.from_pretrained(model)
+            self.tokenizer = AutoTokenizer.from_pretrained(hf_name)
 
     def handle_chat_request(self, messages: List, **kwargs: Dict[str, Any]) -> Union[List, Tuple[List, List]]:
         """Handle single chat request to OpenAI server.
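A consequence of the tokenizer change above: every provider other than "openai" now resolves its tokenizer from hf_name, so non-OpenAI configurations should pass it explicitly, as the Ollama example in this commit does. A short sketch using only the constructor arguments shown in this diff (the model names are illustrative):

# Tokenizer selection after this commit: tiktoken for the default provider,
# Hugging Face's AutoTokenizer(hf_name) for everything else.
from lotus.models import OpenAIModel

openai_lm = OpenAIModel(model="gpt-4o-mini")  # provider="openai" -> tiktoken.encoding_for_model
local_lm = OpenAIModel(
    model="llama3.2",
    hf_name="meta-llama/Llama-3.2-3B-Instruct",  # loaded via AutoTokenizer.from_pretrained
    api_base="http://localhost:11434/v1",
    provider="ollama",
)

As written, omitting hf_name for a non-"openai" provider would hand None to AutoTokenizer.from_pretrained, so the argument is effectively required in that case.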
