Add LotusUsageLimitException (#129)
We can set `usage_limit` to make sure we do not exceed a token or cost limit.
Raises `LotusUsageLimitException` if a limit is exceeded.
sidjha1 authored Feb 21, 2025
1 parent: eda795b · commit: 877dff1
Showing 4 changed files with 135 additions and 8 deletions.
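At a glance, the new parameter is used like this (a minimal sketch based on the docs and tests in this commit; the model name is illustrative, and any field left unset keeps its default of float("inf"), i.e. unlimited):

    from lotus.models import LM
    from lotus.types import UsageLimit

    # Cap only total spend; the token limits keep their default of float("inf").
    lm = LM(model="gpt-4o-mini", usage_limit=UsageLimit(total_cost_limit=1.00))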
docs/llm.rst (40 changes: 37 additions & 3 deletions)
@@ -8,8 +8,8 @@ Example models include but not limited to: OpenAI, Ollama, vLLM
 
 Example
 ---------
-To run a model, you can use the LM class. We use the liteLLMm library to interface with the model. This allows
-ypu to use any model provider that is supported by liteLLM
+To run a model, you can use the LM class. We use the LiteLLM library to interface with the model. This allows
+you to use any model provider that is supported by LiteLLM.
 
 Creating a LM object for gpt-4o
 
@@ -33,4 +33,38 @@ Creating a LM object to use Meta-Llama-3-8B-Instruct on vLLM
     lm = LM(model='hosted_vllm/meta-llama/Meta-Llama-3-8B-Instruct',
             api_base='http://localhost:8000/v1',
             max_ctx_len=8000,
-            max_tokens=1000)
+            max_tokens=1000)
+
+Usage Limits
+------------
+
+The LM class supports setting usage limits to control costs and token consumption. You can set limits on:
+
+- Prompt tokens
+- Completion tokens
+- Total tokens
+- Total cost
+
+When any limit is exceeded, a ``LotusUsageLimitException`` will be raised.
+
+Example setting usage limits:
+
+.. code-block:: python
+
+    import pandas as pd
+
+    from lotus.models import LM
+    from lotus.types import UsageLimit, LotusUsageLimitException
+
+    # Set limits
+    usage_limit = UsageLimit(
+        prompt_tokens_limit=4000,
+        completion_tokens_limit=1000,
+        total_tokens_limit=3000,
+        total_cost_limit=1.00
+    )
+    lm = LM(model="gpt-4o", usage_limit=usage_limit)
+
+    try:
+        course_df = pd.read_csv("course_df.csv")
+        course_df = course_df.sem_filter("What {Course Name} requires a lot of math")
+    except LotusUsageLimitException as e:
+        print(f"Usage limit exceeded: {e}")
+        # Handle the exception as needed
lotus/models/lm.py (44 changes: 39 additions & 5 deletions)
@@ -13,7 +13,14 @@
 
 import lotus
 from lotus.cache import CacheFactory
-from lotus.types import LMOutput, LMStats, LogprobsForCascade, LogprobsForFilterCascade
+from lotus.types import (
+    LMOutput,
+    LMStats,
+    LogprobsForCascade,
+    LogprobsForFilterCascade,
+    LotusUsageLimitException,
+    UsageLimit,
+)
 
 logging.getLogger("LiteLLM").setLevel(logging.CRITICAL)
 logging.getLogger("httpx").setLevel(logging.CRITICAL)
@@ -29,8 +36,22 @@ def __init__(
         max_batch_size: int = 64,
         tokenizer: Tokenizer | None = None,
         cache=None,
+        usage_limit: UsageLimit = UsageLimit(),
         **kwargs: dict[str, Any],
     ):
+        """Language Model class for interacting with various LLM providers.
+
+        Args:
+            model (str): Name of the model to use. Defaults to "gpt-4o-mini".
+            temperature (float): Sampling temperature. Defaults to 0.0.
+            max_ctx_len (int): Maximum context length in tokens. Defaults to 128000.
+            max_tokens (int): Maximum number of tokens to generate. Defaults to 512.
+            max_batch_size (int): Maximum batch size for concurrent requests. Defaults to 64.
+            tokenizer (Tokenizer | None): Custom tokenizer instance. Defaults to None.
+            cache: Cache instance to use. Defaults to None.
+            usage_limit (UsageLimit): Usage limits for the model. Defaults to UsageLimit().
+            **kwargs: Additional keyword arguments passed to the underlying LLM API.
+        """
         self.model = model
         self.max_ctx_len = max_ctx_len
         self.max_tokens = max_tokens
@@ -39,6 +60,7 @@ def __init__(
         self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
 
         self.stats: LMStats = LMStats()
+        self.usage_limit = usage_limit
 
         self.cache = cache or CacheFactory.create_default_cache()
 
@@ -72,9 +94,11 @@ def __call__(
             uncached_responses = self._process_uncached_messages(
                 uncached_data, all_kwargs, show_progress_bar, progress_bar_desc
             )
-            if lotus.settings.enable_cache:
-                # Add new responses to cache
-                for resp, (_, hash) in zip(uncached_responses, uncached_data):
+
+            # Add new responses to cache and update stats
+            for resp, (_, hash) in zip(uncached_responses, uncached_data):
+                self._update_stats(resp)
+                if lotus.settings.enable_cache:
                     self._cache_response(resp, hash)
 
         # Merge all responses in original order and extract outputs
@@ -115,7 +139,6 @@ def _cache_response(self, response, hash):
         """Caches a response and updates stats if successful."""
         if isinstance(response, OpenAIError):
             raise response
-        self._update_stats(response)
         self.cache.insert(hash, response)
 
     def _hash_messages(self, messages: list[dict[str, str]], kwargs: dict[str, Any]) -> str:
@@ -138,6 +161,17 @@ def _update_stats(self, response: ModelResponse):
         self.stats.total_usage.completion_tokens += response.usage.completion_tokens
         self.stats.total_usage.total_tokens += response.usage.total_tokens
 
+        # Check if any usage limits are exceeded
+        if (
+            self.stats.total_usage.prompt_tokens > self.usage_limit.prompt_tokens_limit
+            or self.stats.total_usage.completion_tokens > self.usage_limit.completion_tokens_limit
+            or self.stats.total_usage.total_tokens > self.usage_limit.total_tokens_limit
+            or self.stats.total_usage.total_cost > self.usage_limit.total_cost_limit
+        ):
+            raise LotusUsageLimitException(
+                f"Usage limit exceeded. Current usage: {self.stats.total_usage}, Limit: {self.usage_limit}"
+            )
+
         try:
             self.stats.total_usage.total_cost += completion_cost(completion_response=response)
         except litellm.exceptions.NotFoundError as e:
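The semantics of the new check are strict-greater: a call that lands exactly on a limit does not raise, and only the first update that pushes a cumulative counter past its cap does. Restated as a standalone predicate (an illustrative sketch with hypothetical names, not lotus's API):

    from dataclasses import dataclass

    @dataclass
    class Usage:  # hypothetical stand-in for LMStats.total_usage
        prompt_tokens: int = 0
        completion_tokens: int = 0
        total_tokens: int = 0
        total_cost: float = 0.0

    @dataclass
    class Limit:  # mirrors the UsageLimit fields added in lotus/types.py
        prompt_tokens_limit: float = float("inf")
        completion_tokens_limit: float = float("inf")
        total_tokens_limit: float = float("inf")
        total_cost_limit: float = float("inf")

    def exceeds(usage: Usage, limit: Limit) -> bool:
        # Raise-worthy once ANY cumulative counter strictly exceeds its cap.
        return (
            usage.prompt_tokens > limit.prompt_tokens_limit
            or usage.completion_tokens > limit.completion_tokens_limit
            or usage.total_tokens > limit.total_tokens_limit
            or usage.total_cost > limit.total_cost_limit
        )

    assert exceeds(Usage(total_tokens=51), Limit(total_tokens_limit=50))
    assert not exceeds(Usage(total_tokens=50), Limit(total_tokens_limit=50))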
lotus/types.py (26 changes: 26 additions & 0 deletions)
@@ -162,3 +162,29 @@ class SerializationFormat(Enum):
     JSON = "json"
     XML = "xml"
     DEFAULT = "default"
+
+
+################################################################################
+# Utility
+################################################################################
+@dataclass
+class UsageLimit:
+    prompt_tokens_limit: float = float("inf")
+    completion_tokens_limit: float = float("inf")
+    total_tokens_limit: float = float("inf")
+    total_cost_limit: float = float("inf")
+
+
+################################################################################
+# Exception related
+################################################################################
+class LotusException(Exception):
+    """Base class for all Lotus exceptions."""
+
+    pass
+
+
+class LotusUsageLimitException(LotusException):
+    """Exception raised when the usage limit is exceeded."""
+
+    pass
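Because `LotusUsageLimitException` subclasses the new `LotusException` base class, callers can catch either the specific error or every Lotus error with one handler. A short sketch (assuming the package is importable; the raise is only for demonstration):

    from lotus.types import LotusException, LotusUsageLimitException

    try:
        raise LotusUsageLimitException("usage limit exceeded")
    except LotusException as e:
        # A base-class handler also catches the usage-limit error.
        print(f"caught: {e}")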
tests/test_lm.py (33 changes: 33 additions & 0 deletions)
@@ -1,8 +1,41 @@
+import pytest
+
 from lotus.models import LM
+from lotus.types import LotusUsageLimitException, UsageLimit
 from tests.base_test import BaseTest
 
 
 class TestLM(BaseTest):
     def test_lm_initialization(self):
         lm = LM(model="gpt-4o-mini")
         assert isinstance(lm, LM)
+
+    def test_lm_token_usage_limit(self):
+        # Test prompt token limit
+        usage_limit = UsageLimit(prompt_tokens_limit=100)
+        lm = LM(model="gpt-4o-mini", usage_limit=usage_limit)
+        short_prompt = "What is the capital of France? Respond in one word."
+        messages = [[{"role": "user", "content": short_prompt}]]
+        lm(messages)
+
+        long_prompt = "What is the capital of France? Respond in one word." * 50
+        messages = [[{"role": "user", "content": long_prompt}]]
+        with pytest.raises(LotusUsageLimitException):
+            lm(messages)
+
+        # Test completion token limit
+        usage_limit = UsageLimit(completion_tokens_limit=10)
+        lm = LM(model="gpt-4o-mini", usage_limit=usage_limit)
+        long_response_prompt = "Write a 100 word essay about the history of France"
+        messages = [[{"role": "user", "content": long_response_prompt}]]
+        with pytest.raises(LotusUsageLimitException):
+            lm(messages)
+
+        # Test total token limit
+        usage_limit = UsageLimit(total_tokens_limit=50)
+        lm = LM(model="gpt-4o-mini", usage_limit=usage_limit)
+        messages = [[{"role": "user", "content": short_prompt}]]
+        lm(messages)  # First call should work
+        with pytest.raises(LotusUsageLimitException):
+            for _ in range(5):  # Multiple calls to exceed total limit
+                lm(messages)
