Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Typing Issues #22

Merged
merged 4 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions .github/tests/lm_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,33 @@ def test_join(setup_models):
assert joined_pairs == expected_pairs


def test_join_cascade(setup_models):
    """Check how ``sem_join`` splits work between the helper and the large model.

    The same join is run twice with different ``cascade_threshold`` values.
    With a threshold of 0 the stats are expected to attribute every filter to
    the helper model; with a threshold of 1.01 they are expected to attribute
    every filter to the large model. Both runs must produce the same pairs.
    """
    helper_model, large_model = setup_models
    lotus.settings.configure(lm=large_model, helper_lm=helper_model)

    schools = pd.DataFrame({"School": ["UC Berkeley", "Stanford"]})
    school_types = pd.DataFrame({"School Type": ["Public School", "Private School"]})
    join_instruction = "{School} is a {School Type}"
    expected_pairs = {("UC Berkeley", "Public School"), ("Stanford", "Private School")}

    # Each scenario: (cascade_threshold, filters resolved by large model,
    # filters resolved by helper model).
    scenarios = [
        (0, 0, 4),  # all joins resolved by the helper model
        (1.01, 4, 0),  # all joins resolved by the large model
    ]
    for threshold, large_count, helper_count in scenarios:
        joined_df, stats = schools.sem_join(
            school_types, join_instruction, cascade_threshold=threshold, return_stats=True
        )
        joined_pairs = set(zip(joined_df["School"], joined_df["School Type"]))
        assert joined_pairs == expected_pairs
        # Pass stats as the assertion message so a failure shows the full breakdown.
        assert stats["filters_resolved_by_large_model"] == large_count, stats
        assert stats["filters_resolved_by_helper_model"] == helper_count, stats


def test_map_fewshot(setup_models):
gpt_4o_mini, _ = setup_models
lotus.settings.configure(lm=gpt_4o_mini)
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.9'
python-version: '3.10'

- name: Install dependencies
run: |
Expand All @@ -43,7 +43,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.9'
python-version: '3.10'

- name: Install dependencies
run: |
Expand Down
2 changes: 1 addition & 1 deletion .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ version: 2
build:
os: ubuntu-20.04
tools:
python: "3.9"
python: "3.10"

sphinx:
configuration: docs/conf.py
Expand Down
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

To set up for development, create a conda environment, install lotus, and install additional dev dependencies.
```
conda create -n lotus python=3.9 -y
conda create -n lotus python=3.10 -y
conda activate lotus
git clone git@github.com:stanford-futuredata/lotus.git
pip install -e .
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ LOTUS offers a number of semantic operators in a Pandas-like API, some of which

# Installation
```
conda create -n lotus python=3.9 -y
conda create -n lotus python=3.10 -y
conda activate lotus
pip install lotus-ai
```
Expand Down
4 changes: 2 additions & 2 deletions docs/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Requirements
------------

* OS: MacOS, Linux
* Python: 3.9
* Python: 3.10

Install with pip
----------------
Expand All @@ -16,6 +16,6 @@ You can install Lotus using pip:

.. code-block:: console

$ conda create -n lotus python=3.9 -y
$ conda create -n lotus python=3.10 -y
$ conda activate lotus
$ pip install lotus-ai
18 changes: 9 additions & 9 deletions lotus/models/colbertv2_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pickle
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any

from lotus.models.rm import RM

Expand All @@ -8,9 +8,9 @@ class ColBERTv2Model(RM):
"""ColBERTv2 Model"""

def __init__(self, **kwargs):
self.docs: Optional[List[str]] = None
self.kwargs: Dict[str, Any] = {"doc_maxlen": 300, "nbits": 2, **kwargs}
self.index_dir: Optional[str] = None
self.docs: list[str] | None = None
self.kwargs: dict[str, Any] = {"doc_maxlen": 300, "nbits": 2, **kwargs}
self.index_dir: str | None = None

from colbert import Indexer, Searcher
from colbert.infra import ColBERTConfig, Run, RunConfig
Expand All @@ -21,7 +21,7 @@ def __init__(self, **kwargs):
self.Run = Run
self.RunConfig = RunConfig

def index(self, docs: List[str], index_dir: str, **kwargs: Dict[str, Any]) -> None:
def index(self, docs: list[str], index_dir: str, **kwargs: dict[str, Any]) -> None:
kwargs = {**self.kwargs, **kwargs}
checkpoint = "colbert-ir/colbertv2.0"

Expand All @@ -41,15 +41,15 @@ def load_index(self, index_dir: str) -> None:
with open(f"experiments/lotus/indexes/{index_dir}/index/docs", "rb") as fp:
self.docs = pickle.load(fp)

def get_vectors_from_index(self, index_dir: str, ids: List[int]) -> List:
def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> list:
raise NotImplementedError("This method is not implemented for ColBERTv2Model")

def __call__(
self,
queries: Union[str, List[str], List[List[float]]],
queries: str | list[str] | list[list[float]],
k: int,
**kwargs: Dict[str, Any],
) -> Tuple[List[float], List[int]]:
**kwargs: dict[str, Any],
) -> tuple[list[float], list[int]]:
if isinstance(queries, str):
queries = [queries]

Expand Down
6 changes: 2 additions & 4 deletions lotus/models/cross_encoder_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import List, Optional

import torch
from sentence_transformers import CrossEncoder

Expand All @@ -17,15 +15,15 @@ class CrossEncoderModel(Reranker):
def __init__(
self,
model: str = "mixedbread-ai/mxbai-rerank-large-v1",
device: Optional[str] = None,
device: str | None = None,
**kwargs,
):
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = device
self.model = CrossEncoder(model, device=device, **kwargs)

def __call__(self, query: str, docs: List[str], k: int) -> List[int]:
def __call__(self, query: str, docs: list[str], k: int) -> list[int]:
results = self.model.rank(query, docs, top_k=k)
results = [result["corpus_id"] for result in results]
return results
37 changes: 19 additions & 18 deletions lotus/models/e5_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import pickle
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any

import numpy as np
import torch
Expand All @@ -14,18 +14,18 @@
class E5Model(RM):
"""E5 retriever model"""

def __init__(self, model: str = "intfloat/e5-base-v2", device: Optional[str] = None, **kwargs):
def __init__(self, model: str = "intfloat/e5-base-v2", device: str | None = None, **kwargs: dict[str, Any]) -> None:
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = device
self.tokenizer = AutoTokenizer.from_pretrained(model)
self.model = AutoModel.from_pretrained(model).to(self.device)
self.faiss_index = None
self.index_dir = None
self.docs = None
self.kwargs = {"normalize": True, "index_type": "Flat", **kwargs}
self.batch_size = 100
self.vecs = None
self.index_dir: str | None = None
self.docs: list[str] | None = None
self.kwargs: dict[str, Any] = {"normalize": True, "index_type": "Flat", **kwargs}
self.batch_size: int = 100
self.vecs: np.ndarray[Any, np.dtype[np.float32]] | None = None

import faiss

Expand All @@ -45,7 +45,7 @@ def average_pool(self, last_hidden_states: torch.Tensor, attention_mask: torch.T
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

def embed(self, docs: List[str], **kwargs: Dict[str, Any]) -> np.ndarray:
def embed(self, docs: list[str], **kwargs: dict[str, Any]) -> np.ndarray[Any, np.dtype[np.float32]]:
"""Run the embedding model.

Args:
Expand All @@ -55,10 +55,11 @@ def embed(self, docs: List[str], **kwargs: Dict[str, Any]) -> np.ndarray:
Embeddings of the documents.
"""

kwargs = {**self.kwargs, **kwargs}
kwargs = {**self.kwargs, **dict(kwargs)}

batch_size = kwargs.get("batch_size", self.batch_size)

assert isinstance(batch_size, int), "batch_size must be an integer"

# Calculating the embedding dimension
total_docs = len(docs)
first_batch = self.tokenizer(docs[:1], return_tensors="pt", padding=True, truncation=True).to(self.device)
Expand All @@ -79,7 +80,7 @@ def embed(self, docs: List[str], **kwargs: Dict[str, Any]) -> np.ndarray:

return embeddings.numpy(force=True)

def index(self, docs: List[str], index_dir: str, **kwargs: Dict[str, Any]) -> None:
def index(self, docs: list[str], index_dir: str, **kwargs: dict[str, Any]) -> None:
# Make index directory
os.makedirs(index_dir, exist_ok=True)

Expand Down Expand Up @@ -110,17 +111,17 @@ def load_index(self, index_dir: str) -> None:
self.vecs = pickle.load(fp)

@classmethod
def get_vectors_from_index(self, index_dir: str, ids: List[int]) -> List:
def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> list[np.ndarray[Any, np.dtype[np.float32]]]:
with open(f"{index_dir}/vecs", "rb") as fp:
vecs = pickle.load(fp)
vecs: np.ndarray[Any, np.dtype[np.float32]] = pickle.load(fp)

return vecs[ids]

def load_vecs(self, index_dir: str, ids: List[int]) -> List:
def load_vecs(self, index_dir: str, ids: list[int]) -> list:
"""loads vectors to the rm and returns them
Args:
index_dir (str): Directory of the index.
ids (List[int]): The ids of the vectors to retrieve
ids (list[int]): The ids of the vectors to retrieve

Returns:
The vectors matching the specified ids.
Expand All @@ -134,10 +135,10 @@ def load_vecs(self, index_dir: str, ids: List[int]) -> List:

def __call__(
self,
queries: Union[str, List[str], List[List[float]]],
queries: str | list[str] | list[list[float]],
k: int,
**kwargs: Dict[str, Any],
) -> Tuple[List[float], List[int]]:
**kwargs: dict[str, Any],
) -> tuple[list[float], list[int]]:
if isinstance(queries, str):
queries = [queries]

Expand Down
34 changes: 23 additions & 11 deletions lotus/models/lm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple, Union
from typing import Any


class LM(ABC):
Expand All @@ -9,42 +9,54 @@ def _init__(self):
pass

@abstractmethod
def count_tokens(self, prompt: Union[str, list]) -> int:
def count_tokens(self, prompt: str | list) -> int:
"""
Counts the number of tokens in the given prompt.

Args:
prompt (Union[str, list]): The prompt to count tokens for. This can be a string or a list of messages.
prompt (str | list): The prompt to count tokens for. This can be a string or a list of messages.

Returns:
int: The number of tokens in the prompt.
"""
pass

def format_logprobs_for_cascade(self, logprobs: List) -> Tuple[List[List[str]], List[List[float]]]:
def format_logprobs_for_cascade(self, logprobs: list) -> tuple[list[list[str]], list[list[float]]]:
"""
Formats the logprobs for the cascade.

Args:
logprobs (List): The logprobs to format.
logprobs (list): The logprobs to format.

Returns:
Tuple[List[List[str]], List[List[float]]]: A tuple containing the tokens and their corresponding confidences.
tuple[list[list[str]], list[list[float]]]: A tuple containing the tokens and their corresponding confidences.
"""
pass

@abstractmethod
def __call__(
self, messages_batch: Union[List, List[List]], **kwargs: Dict[str, Any]
) -> Union[List, Tuple[List, List]]:
self, messages_batch: list | list[list], **kwargs: dict[str, Any]
) -> list[str] | tuple[list[str], list[dict[str, Any]]]:
"""Invoke the LLM.

Args:
messages_batch (Union[List, List[List]]): Either one prompt or a list of prompts in message format.
kwargs (Dict[str, Any]): Additional keyword arguments. They can be used to specify inference parameters.
messages_batch (list | list[list]): Either one prompt or a list of prompts in message format.
kwargs (dict[str, Any]): Additional keyword arguments. They can be used to specify inference parameters.

Returns:
Union[List, Tuple[List, List]]: A list of outputs for each prompt in the batch. If logprobs is specified in the keyword arguments,
list[str] | tuple[list[str], list[dict[str, Any]]]: A list of outputs for each prompt in the batch. If logprobs is specified in the keyword arguments,
then a list of logprobs is also returned.
"""
pass

@property
@abstractmethod
def max_ctx_len(self) -> int:
"""The maximum context length of the LLM."""
pass

@property
@abstractmethod
def max_tokens(self) -> int:
"""The maximum number of tokens that can be generated by the LLM."""
pass
Loading
Loading