Add LotusUsageLimitException #129

Merged 5 commits on Feb 21, 2025

40 changes: 37 additions & 3 deletions docs/llm.rst
@@ -8,8 +8,8 @@ Example models include but not limited to: OpenAI, Ollama, vLLM

Example
---------
-To run a model, you can use the LM class. We use the liteLLMm library to interface with the model. This allows
-ypu to use any model provider that is supported by liteLLM
+To run a model, you can use the LM class. We use the LiteLLM library to interface with the model. This allows
+you to use any model provider that is supported by LiteLLM.
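
Because LiteLLM handles the provider routing, other backends follow the same pattern. As a minimal sketch (not part of this diff), a locally served Ollama model could be referenced through LiteLLM's ``ollama/`` prefix; the model name is illustrative and assumes a local Ollama server is running:

.. code-block:: python

    from lotus.models import LM

    # Hypothetical example: LiteLLM routes "ollama/..." models to a local Ollama server.
    lm = LM(model="ollama/llama3.1")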

Creating a LM object for gpt-4o

@@ -33,4 +33,38 @@ Creating a LM object to use Meta-Llama-3-8B-Instruct on vLLM
lm = LM(model='hosted_vllm/meta-llama/Meta-Llama-3-8B-Instruct',
api_base='http://localhost:8000/v1',
max_ctx_len=8000,
max_tokens=1000)

Usage Limits
------------
The LM class supports setting usage limits to control costs and token consumption. You can set limits on:

- Prompt tokens
- Completion tokens
- Total tokens
- Total cost

When any limit is exceeded, a ``LotusUsageLimitException`` will be raised.

Example setting usage limits:

.. code-block:: python

    import lotus
    import pandas as pd

    from lotus.models import LM
    from lotus.types import UsageLimit, LotusUsageLimitException

    # Set limits
    usage_limit = UsageLimit(
        prompt_tokens_limit=4000,
        completion_tokens_limit=1000,
        total_tokens_limit=3000,
        total_cost_limit=1.00,
    )
    lm = LM(model="gpt-4o", usage_limit=usage_limit)
    lotus.settings.configure(lm=lm)

    try:
        course_df = pd.read_csv("course_df.csv")
        course_df = course_df.sem_filter("{Course Name} requires a lot of math")
    except LotusUsageLimitException as e:
        print(f"Usage limit exceeded: {e}")
        # Handle the exception as needed
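
Once the exception is raised (or at any point in a pipeline), the accumulated usage that the limit check compares against can be inspected through the model's stats object. A minimal sketch, assuming the ``lm`` object from the example above:

.. code-block:: python

    # Inspect the counters tracked in lm.stats.total_usage.
    usage = lm.stats.total_usage
    print(usage.prompt_tokens, usage.completion_tokens, usage.total_tokens)
    print(f"Estimated cost so far: ${usage.total_cost:.4f}")
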
44 changes: 39 additions & 5 deletions lotus/models/lm.py
@@ -13,7 +13,14 @@

import lotus
from lotus.cache import CacheFactory
-from lotus.types import LMOutput, LMStats, LogprobsForCascade, LogprobsForFilterCascade
+from lotus.types import (
+    LMOutput,
+    LMStats,
+    LogprobsForCascade,
+    LogprobsForFilterCascade,
+    LotusUsageLimitException,
+    UsageLimit,
+)

logging.getLogger("LiteLLM").setLevel(logging.CRITICAL)
logging.getLogger("httpx").setLevel(logging.CRITICAL)
@@ -29,8 +36,22 @@ def __init__(
max_batch_size: int = 64,
tokenizer: Tokenizer | None = None,
cache=None,
usage_limit: UsageLimit = UsageLimit(),
**kwargs: dict[str, Any],
):
"""Language Model class for interacting with various LLM providers.

Args:
model (str): Name of the model to use. Defaults to "gpt-4o-mini".
temperature (float): Sampling temperature. Defaults to 0.0.
max_ctx_len (int): Maximum context length in tokens. Defaults to 128000.
max_tokens (int): Maximum number of tokens to generate. Defaults to 512.
max_batch_size (int): Maximum batch size for concurrent requests. Defaults to 64.
tokenizer (Tokenizer | None): Custom tokenizer instance. Defaults to None.
cache: Cache instance to use. Defaults to None.
usage_limit (UsageLimit): Usage limits for the model. Defaults to UsageLimit().
**kwargs: Additional keyword arguments passed to the underlying LLM API.
"""
self.model = model
self.max_ctx_len = max_ctx_len
self.max_tokens = max_tokens
@@ -39,6 +60,7 @@ def __init__(
self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)

self.stats: LMStats = LMStats()
self.usage_limit = usage_limit

self.cache = cache or CacheFactory.create_default_cache()

@@ -72,9 +94,11 @@ def __call__(
uncached_responses = self._process_uncached_messages(
uncached_data, all_kwargs, show_progress_bar, progress_bar_desc
)
-if lotus.settings.enable_cache:
-    # Add new responses to cache
-    for resp, (_, hash) in zip(uncached_responses, uncached_data):

+# Add new responses to cache and update stats
+for resp, (_, hash) in zip(uncached_responses, uncached_data):
+    self._update_stats(resp)
+    if lotus.settings.enable_cache:
Comment on lines -75 to +101 (Collaborator Author): Bug fix from #74

self._cache_response(resp, hash)

# Merge all responses in original order and extract outputs
@@ -115,7 +139,6 @@ def _cache_response(self, response, hash):
"""Caches a response and updates stats if successful."""
if isinstance(response, OpenAIError):
raise response
-self._update_stats(response)
self.cache.insert(hash, response)

def _hash_messages(self, messages: list[dict[str, str]], kwargs: dict[str, Any]) -> str:
@@ -138,6 +161,17 @@ def _update_stats(self, response: ModelResponse):
self.stats.total_usage.completion_tokens += response.usage.completion_tokens
self.stats.total_usage.total_tokens += response.usage.total_tokens

# Check if any usage limits are exceeded
if (
self.stats.total_usage.prompt_tokens > self.usage_limit.prompt_tokens_limit
or self.stats.total_usage.completion_tokens > self.usage_limit.completion_tokens_limit
or self.stats.total_usage.total_tokens > self.usage_limit.total_tokens_limit
or self.stats.total_usage.total_cost > self.usage_limit.total_cost_limit
):
raise LotusUsageLimitException(
f"Usage limit exceeded. Current usage: {self.stats.total_usage}, Limit: {self.usage_limit}"
)

try:
self.stats.total_usage.total_cost += completion_cost(completion_response=response)
except litellm.exceptions.NotFoundError as e:
26 changes: 26 additions & 0 deletions lotus/types.py
@@ -162,3 +162,29 @@ class SerializationFormat(Enum):
JSON = "json"
XML = "xml"
DEFAULT = "default"


################################################################################
# Utility
################################################################################
@dataclass
class UsageLimit:
prompt_tokens_limit: float = float("inf")
completion_tokens_limit: float = float("inf")
total_tokens_limit: float = float("inf")
total_cost_limit: float = float("inf")


################################################################################
# Exception related
################################################################################
class LotusException(Exception):
"""Base class for all Lotus exceptions."""

pass


class LotusUsageLimitException(LotusException):
"""Exception raised when the usage limit is exceeded."""

pass
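
Because every field of ``UsageLimit`` defaults to ``float("inf")``, a caller can cap a single dimension and leave the others unbounded. A minimal sketch, assuming only the dataclass above (not part of this diff); note that since ``LotusUsageLimitException`` derives from ``LotusException``, handlers that catch the base class also catch limit violations.

.. code-block:: python

    from lotus.types import UsageLimit

    # Cap only the total spend; the token limits stay at float("inf") and never trigger.
    cost_only_limit = UsageLimit(total_cost_limit=0.50)
    assert cost_only_limit.prompt_tokens_limit == float("inf")
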
33 changes: 33 additions & 0 deletions tests/test_lm.py
@@ -1,8 +1,41 @@
import pytest

from lotus.models import LM
from lotus.types import LotusUsageLimitException, UsageLimit
from tests.base_test import BaseTest


class TestLM(BaseTest):
def test_lm_initialization(self):
lm = LM(model="gpt-4o-mini")
assert isinstance(lm, LM)

def test_lm_token_usage_limit(self):
# Test prompt token limit
usage_limit = UsageLimit(prompt_tokens_limit=100)
lm = LM(model="gpt-4o-mini", usage_limit=usage_limit)
short_prompt = "What is the capital of France? Respond in one word."
messages = [[{"role": "user", "content": short_prompt}]]
lm(messages)

long_prompt = "What is the capital of France? Respond in one word." * 50
messages = [[{"role": "user", "content": long_prompt}]]
with pytest.raises(LotusUsageLimitException):
lm(messages)

# Test completion token limit
usage_limit = UsageLimit(completion_tokens_limit=10)
lm = LM(model="gpt-4o-mini", usage_limit=usage_limit)
long_response_prompt = "Write a 100 word essay about the history of France"
messages = [[{"role": "user", "content": long_response_prompt}]]
with pytest.raises(LotusUsageLimitException):
lm(messages)

# Test total token limit
usage_limit = UsageLimit(total_tokens_limit=50)
lm = LM(model="gpt-4o-mini", usage_limit=usage_limit)
messages = [[{"role": "user", "content": short_prompt}]]
lm(messages) # First call should work
with pytest.raises(LotusUsageLimitException):
for _ in range(5): # Multiple calls to exceed total limit
lm(messages)
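
The tests above exercise the three token limits; the cost limit could be covered with the same pattern. A minimal sketch of such a test (not part of this PR), assuming LiteLLM can compute a completion cost for the model so that ``total_cost`` actually grows:

.. code-block:: python

    def test_lm_cost_limit(self):
        # Hypothetical addition to TestLM: a near-zero cost budget should be
        # exceeded after a couple of calls.
        usage_limit = UsageLimit(total_cost_limit=1e-7)
        lm = LM(model="gpt-4o-mini", usage_limit=usage_limit)
        messages = [[{"role": "user", "content": "Say hello in one word."}]]
        with pytest.raises(LotusUsageLimitException):
            for _ in range(5):  # repeated calls so the accumulated cost exceeds the budget
                lm(messages)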