From c7ed69ad59bd012a829c318035f9997c7f510d2c Mon Sep 17 00:00:00 2001 From: Harshit Gupta <59705530+harshitgupta412@users.noreply.github.com> Date: Wed, 27 Nov 2024 19:25:44 -0800 Subject: [PATCH] Image Type and support in sem_operators (#33) Changes: - Added ImageArray type for storing images - loads the image lazily from the stored data. **Caution:** There are no explicit checks on what is stored; the only way to confirm whether the data is proper is by accessing it. - Updated sem_ops mentioned in the title to allow multimodal data. - Major updates in `lotus/templates/task_instructions.py` and retriever_models to support more general types. - Added `df2multimodal_info` to load multimodal data properly - Clip model can be used through sentence transformer. - updated the user prompts to send images properly (`user_message_formatter` and `context_formatter`) - index and __call__ can now take pd.Series or Images as input. - Tests in `.github/tests/multimodality_tests.py` - Example in `examples/op_examples/multimodal_ops` --------- Co-authored-by: Harshit Gupta Co-authored-by: liana313 --- .github/tests/multimodality_tests.py | 208 +++++++++++++++++ .github/workflows/tests.yml | 70 ++++++ examples/op_examples/multimodal_ops/filter.py | 23 ++ .../op_examples/multimodal_ops/images/0.png | Bin 0 -> 553 bytes .../op_examples/multimodal_ops/images/1.png | Bin 0 -> 353 bytes .../op_examples/multimodal_ops/images/4.png | Bin 0 -> 463 bytes .../op_examples/multimodal_ops/images/5.png | Bin 0 -> 549 bytes .../op_examples/multimodal_ops/images/9.png | Bin 0 -> 460 bytes examples/op_examples/multimodal_ops/join.py | 22 ++ examples/op_examples/multimodal_ops/map.py | 21 ++ examples/op_examples/multimodal_ops/topk.py | 21 ++ lotus/__init__.py | 3 +- lotus/dtype_extensions/__init__.py | 21 ++ lotus/dtype_extensions/image.py | 146 ++++++++++++ lotus/models/colbertv2_rm.py | 10 +- lotus/models/faiss_rm.py | 16 +- lotus/models/litellm_rm.py | 7 +- lotus/models/rm.py | 6 +- lotus/models/sentence_transformers_rm.py | 7 +- lotus/sem_ops/sem_extract.py | 10 +- lotus/sem_ops/sem_filter.py | 56 ++--- lotus/sem_ops/sem_index.py | 3 +- lotus/sem_ops/sem_join.py | 209 +++++++++--------- lotus/sem_ops/sem_map.py | 22 +- lotus/sem_ops/sem_sim_join.py | 27 +-- lotus/sem_ops/sem_topk.py | 38 ++-- lotus/templates/task_instructions.py | 203 +++++++++++------ lotus/utils.py | 53 +++++ 28 files changed, 936 insertions(+), 266 deletions(-) create mode 100644 .github/tests/multimodality_tests.py create mode 100644 examples/op_examples/multimodal_ops/filter.py create mode 100644 examples/op_examples/multimodal_ops/images/0.png create mode 100644 examples/op_examples/multimodal_ops/images/1.png create mode 100644 examples/op_examples/multimodal_ops/images/4.png create mode 100644 examples/op_examples/multimodal_ops/images/5.png create mode 100644 examples/op_examples/multimodal_ops/images/9.png create mode 100644 examples/op_examples/multimodal_ops/join.py create mode 100644 examples/op_examples/multimodal_ops/map.py create mode 100644 examples/op_examples/multimodal_ops/topk.py create mode 100644 lotus/dtype_extensions/__init__.py create mode 100644 lotus/dtype_extensions/image.py diff --git a/.github/tests/multimodality_tests.py b/.github/tests/multimodality_tests.py new file mode 100644 index 00000000..669b33bc --- /dev/null +++ b/.github/tests/multimodality_tests.py @@ -0,0 +1,208 @@ +import os + +import pandas as pd +import pytest + +import lotus +from lotus.dtype_extensions import ImageArray +from lotus.models import LM, SentenceTransformersRM + +################################################################################ +# Setup +################################################################################ +# Set logger level to DEBUG +lotus.logger.setLevel("DEBUG") + +# Environment flags to enable/disable tests +ENABLE_OPENAI_TESTS = os.getenv("ENABLE_OPENAI_TESTS", "false").lower() == "true" +ENABLE_LOCAL_TESTS = os.getenv("ENABLE_LOCAL_TESTS", "false").lower() == "true" + +MODEL_NAME_TO_ENABLED = { + "gpt-4o-mini": ENABLE_OPENAI_TESTS, + "clip-ViT-B-32": ENABLE_LOCAL_TESTS, +} +ENABLED_MODEL_NAMES = set([model_name for model_name, is_enabled in MODEL_NAME_TO_ENABLED.items() if is_enabled]) + +MODEL_NAME_TO_CLS = { + "clip-ViT-B-32": SentenceTransformersRM, + "gpt-4o-mini": LM, +} + + +def get_enabled(*candidate_models: str) -> list[str]: + return [model for model in candidate_models if model in ENABLED_MODEL_NAMES] + + +@pytest.fixture(scope="session") +def setup_models(): + models = {} + + for model_path in ENABLED_MODEL_NAMES: + models[model_path] = MODEL_NAME_TO_CLS[model_path](model=model_path) + + return models + + +@pytest.fixture(autouse=True) +def print_usage_after_each_test(setup_models): + yield # this runs the test + models = setup_models + for model_name, model in models.items(): + if not isinstance(model, LM): + continue + print(f"\nUsage stats for {model_name} after test:") + model.print_total_usage() + model.reset_stats() + + +################################################################################ +# Standard tests +################################################################################ +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) +def test_filter_operation(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) + + # Test filter operation on an easy dataframe + image_url = [ + "https://img.etsystatic.com/il/4bee20/1469037676/il_340x270.1469037676_iiti.jpg?version=0", + "https://thumbs.dreamstime.com/b/comida-r%C3%A1pida-nachos-con-el-sause-del-tomate-ejemplo-exhausto-de-la-acuarela-mano-aislado-en-blanco-150936354.jpg", + "https://i1.wp.com/www.alloverthemap.net/wp-content/uploads/2014/02/2012-09-25-12.46.15.jpg?resize=400%2C284&ssl=1", + "https://i.pinimg.com/236x/a4/3a/65/a43a65683a0314f29b66402cebdcf46d.jpg", + "https://pravme.ru/wp-content/uploads/2018/01/sobor-Bogord-1.jpg", + ] + df = pd.DataFrame({"image": ImageArray(image_url)}) + user_instruction = "{image} represents food" + filtered_df = df.sem_filter(user_instruction) + + expected_image_url = ImageArray( + [ + "https://thumbs.dreamstime.com/b/comida-r%C3%A1pida-nachos-con-el-sause-del-tomate-ejemplo-exhausto-de-la-acuarela-mano-aislado-en-blanco-150936354.jpg", + ] + ) + + assert expected_image_url == filtered_df["image"] + + +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) +def test_join_operation(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) + + # Test filter operation on an easy dataframe + image_url = [ + "https://img.etsystatic.com/il/4bee20/1469037676/il_340x270.1469037676_iiti.jpg?version=0", + "https://i1.wp.com/www.alloverthemap.net/wp-content/uploads/2014/02/2012-09-25-12.46.15.jpg?resize=400%2C284&ssl=1", + "https://i.pinimg.com/236x/a4/3a/65/a43a65683a0314f29b66402cebdcf46d.jpg", + "https://pravme.ru/wp-content/uploads/2018/01/sobor-Bogord-1.jpg", + ] + elements = ["doll", "bird"] + image_df = pd.DataFrame({"image": ImageArray(image_url)}) + element_df = pd.DataFrame({"element": elements}) + user_instruction = "{image} contains {element}" + joined_df = image_df.sem_join(element_df, user_instruction) + + expected_result = [ + ("https://img.etsystatic.com/il/4bee20/1469037676/il_340x270.1469037676_iiti.jpg?version=0", "doll"), + ("https://i.pinimg.com/236x/a4/3a/65/a43a65683a0314f29b66402cebdcf46d.jpg", "bird"), + ] + + assert expected_result == list(zip(joined_df["image"], joined_df["element"])) + + +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) +def test_topk_operation(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) + + # Test filter operation on an easy dataframe + image_url = [ + "https://img.etsystatic.com/il/4bee20/1469037676/il_340x270.1469037676_iiti.jpg?version=0", + "https://thumbs.dreamstime.com/b/comida-r%C3%A1pida-nachos-con-el-sause-del-tomate-ejemplo-exhausto-de-la-acuarela-mano-aislado-en-blanco-150936354.jpg", + "https://i1.wp.com/www.alloverthemap.net/wp-content/uploads/2014/02/2012-09-25-12.46.15.jpg?resize=400%2C284&ssl=1", + "https://i.pinimg.com/236x/a4/3a/65/a43a65683a0314f29b66402cebdcf46d.jpg", + "https://pravme.ru/wp-content/uploads/2018/01/sobor-Bogord-1.jpg", + ] + df = pd.DataFrame({"image": ImageArray(image_url)}) + user_instruction = "{image} represents living beings" + top_2_expected = set( + [ + "https://i.pinimg.com/236x/a4/3a/65/a43a65683a0314f29b66402cebdcf46d.jpg", + "https://pravme.ru/wp-content/uploads/2018/01/sobor-Bogord-1.jpg", + ] + ) + + strategies = ["quick", "heap", "naive"] + for strategy in strategies: + sorted_df = df.sem_topk(user_instruction, K=2, strategy=strategy) + + top_2_actual = set(sorted_df["image"].values) + assert top_2_expected == top_2_actual + + +@pytest.mark.parametrize("model", get_enabled("clip-ViT-B-32")) +def test_search_operation(setup_models, model): + rm = setup_models[model] + lotus.settings.configure(rm=rm) + + image_url = [ + "https://img.etsystatic.com/il/4bee20/1469037676/il_340x270.1469037676_iiti.jpg?version=0", + "https://i1.wp.com/www.alloverthemap.net/wp-content/uploads/2014/02/2012-09-25-12.46.15.jpg?resize=400%2C284&ssl=1", + "https://i.pinimg.com/236x/a4/3a/65/a43a65683a0314f29b66402cebdcf46d.jpg", + "https://pravme.ru/wp-content/uploads/2018/01/sobor-Bogord-1.jpg", + ] + + expected_result = set(["https://i.pinimg.com/236x/a4/3a/65/a43a65683a0314f29b66402cebdcf46d.jpg"]) + + df = pd.DataFrame({"image": ImageArray(image_url)}) + df = df.sem_index("image", "index_dir") + df = df.sem_search("image", "bird", K=1) + assert set(df["image"].values) == expected_result + + +@pytest.mark.parametrize("model", get_enabled("clip-ViT-B-32")) +def test_sim_join_operation_image_index(setup_models, model): + rm = setup_models[model] + lotus.settings.configure(rm=rm) + + image_url = [ + "https://img.etsystatic.com/il/4bee20/1469037676/il_340x270.1469037676_iiti.jpg?version=0", + "https://i1.wp.com/www.alloverthemap.net/wp-content/uploads/2014/02/2012-09-25-12.46.15.jpg?resize=400%2C284&ssl=1", + "https://i.pinimg.com/236x/a4/3a/65/a43a65683a0314f29b66402cebdcf46d.jpg", + "https://pravme.ru/wp-content/uploads/2018/01/sobor-Bogord-1.jpg", + ] + elements = ["doll", "bird"] + + image_df = pd.DataFrame({"image": ImageArray(image_url)}).sem_index("image", "index_dir") + element_df = pd.DataFrame({"element": elements}) + + joined_df = element_df.sem_sim_join(image_df, right_on="image", left_on="element", K=1) + + expected_result = [ + ("https://img.etsystatic.com/il/4bee20/1469037676/il_340x270.1469037676_iiti.jpg?version=0", "doll"), + ("https://i.pinimg.com/236x/a4/3a/65/a43a65683a0314f29b66402cebdcf46d.jpg", "bird"), + ] + assert expected_result == list(zip(joined_df["image"], joined_df["element"])) + + +@pytest.mark.parametrize("model", get_enabled("clip-ViT-B-32")) +def test_sim_join_operation_text_index(setup_models, model): + rm = setup_models[model] + lotus.settings.configure(rm=rm) + + image_url = [ + "https://img.etsystatic.com/il/4bee20/1469037676/il_340x270.1469037676_iiti.jpg?version=0", + "https://i.pinimg.com/236x/a4/3a/65/a43a65683a0314f29b66402cebdcf46d.jpg", + ] + elements = ["doll", "bird"] + + image_df = pd.DataFrame({"image": ImageArray(image_url)}) + element_df = pd.DataFrame({"element": elements}).sem_index("element", "index_dir") + + joined_df = image_df.sem_sim_join(element_df, left_on="image", right_on="element", K=1) + + expected_result = [ + ("https://img.etsystatic.com/il/4bee20/1469037676/il_340x270.1469037676_iiti.jpg?version=0", "doll"), + ("https://i.pinimg.com/236x/a4/3a/65/a43a65683a0314f29b66402cebdcf46d.jpg", "bird"), + ] + assert expected_result == list(zip(joined_df["image"], joined_df["element"])) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index cace5f6f..796088bc 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -39,6 +39,14 @@ jobs: - name: Checkout code uses: actions/checkout@v3 + - name: Cache pip dependencies + uses: actions/cache@v3 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pip- + - name: Set up Python uses: actions/setup-python@v4 with: @@ -63,6 +71,14 @@ jobs: - name: Checkout code uses: actions/checkout@v3 + - name: Cache pip dependencies + uses: actions/cache@v3 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pip- + - name: Set up Python uses: actions/setup-python@v4 with: @@ -93,6 +109,14 @@ jobs: - name: Checkout code uses: actions/checkout@v3 + - name: Cache pip dependencies + uses: actions/cache@v3 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pip- + - name: Set up Python uses: actions/setup-python@v4 with: @@ -137,6 +161,14 @@ jobs: - name: Checkout code uses: actions/checkout@v3 + - name: Cache pip dependencies + uses: actions/cache@v3 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pip- + - name: Set up Python uses: actions/setup-python@v4 with: @@ -158,3 +190,41 @@ jobs: ENABLE_OPENAI_TESTS: true ENABLE_LOCAL_TESTS: true run: pytest .github/tests/rm_tests.py + + multimodal_test: + name: Multimodality Tests + runs-on: ubuntu-latest + timeout-minutes: 5 + + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Cache pip dependencies + uses: actions/cache@v3 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pip- + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install -e . + pip install pytest + + - name: Set OpenAI API Key + run: echo "OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> $GITHUB_ENV + + - name: Run Multimodality tests + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + ENABLE_OPENAI_TESTS: true + run: pytest .github/tests/multimodality_tests.py \ No newline at end of file diff --git a/examples/op_examples/multimodal_ops/filter.py b/examples/op_examples/multimodal_ops/filter.py new file mode 100644 index 00000000..d2d2e597 --- /dev/null +++ b/examples/op_examples/multimodal_ops/filter.py @@ -0,0 +1,23 @@ +import os + +import pandas as pd + +import lotus +from lotus.dtype_extensions import ImageArray +from lotus.models import LM + +lotus.settings.configure(lm=LM(model="gpt-4o-mini")) + +# The images folder contain images representing digits taken from MNIST dataset +image_file_names = os.listdir("images") # get all file in the folder + +# file names are the same as the digit represented by image +labels = [os.path.splitext(image)[0] for image in image_file_names] +image_paths = [os.path.join("images", image) for image in image_file_names] + +df = pd.DataFrame({"image": ImageArray(image_paths), + "label": labels, + "image_path": image_paths}) + +df = df.sem_filter("{image} represents number 1") +print(df) diff --git a/examples/op_examples/multimodal_ops/images/0.png b/examples/op_examples/multimodal_ops/images/0.png new file mode 100644 index 0000000000000000000000000000000000000000..789ddac91a1a9cd091367b1914a1e7f3af10e4cd GIT binary patch literal 553 zcmV+^0@nSBP)fXnG7OYmUVrU&1QGI9U?XwjVHuQr>E0NQ52SC z!{M+j%et;xmW7D(`MlX|CK8FCpl(*H)oeDqEDggbl}gWXZWx9ckH@a!K5qc9ZTq_^ rKblY|R4f)RXti36L?SPVzj-n5{6$L5jYax400000NkvXXu0mjfi~s{F literal 0 HcmV?d00001 diff --git a/examples/op_examples/multimodal_ops/images/1.png b/examples/op_examples/multimodal_ops/images/1.png new file mode 100644 index 0000000000000000000000000000000000000000..e44e0c9cc7b83424faf62d8375a297655a5aad17 GIT binary patch literal 353 zcmV-n0iOPeP)#tAggiiPOo;45lsrLvL`7a84^R^kC~;9S5s6#XMnn`U zhG9=Eg@aSq&Y5z*-u-r9e+J+`pPb|INJLue?RLu}CP{L?-w~0BeBb|hVDSPfgjg<@ zZ_oBEVHi%+gos_&2_aUi)w`nSc}ginM8r6bIfeu{olZnVL{(Msh_-DL5hC7hw;%`@ zL@8z4wgyepltBc5!{J~$bUvRg%VH4s`~5u6h*;OP<2VeWl#-q)ih@A|fIQFfnayUy zAiAzQj-wIFvSbhe;Cj6p@p8GW*J~csbO;fnDB=-|qR?8KFA_p9h_34@rHJV9cw||| zus8r00000NkvXXu0mjf7%7|! literal 0 HcmV?d00001 diff --git a/examples/op_examples/multimodal_ops/images/4.png b/examples/op_examples/multimodal_ops/images/4.png new file mode 100644 index 0000000000000000000000000000000000000000..7d87808b56c627b668d4d9db7eb6875c8b12069c GIT binary patch literal 463 zcmV;=0WkiFP)MA5m69?P$)zQVeOmG=dD)DCK?C?Ow(*On{PTr#K#@4(r7dQU^bhvP*D^q zrS?&e#{&TUexHSAvsstRWmBY-3W5Lt>-Cz2o+o26nb<*VwHl>#I-TxzyT`ZO-Fh?{ zF)X*+oylYhg+e-=27pSXVi?AI)aUbEo}TEDT7He+zciQWl(n@3Z)DN1Ck7c zmFrKpE606@+xO*tKkpwhec$(+PCd{0b)KH{JOKRX??EgU6Te6#66tihTCHle+ES^+ z^E}J4$z<|2U-Fqquh&bZQisFgbUJYy_xXIc(bZ~|&1NBlhr?kw9Jbr-+bDL|Xti3K z&8AYRyc@ku$oQv;QX+i^cEv+iW(8L;?U9 zkH?)(=lQ1@jmBIqw_dM*kL>w;rfK>~om#Dq#bWdM{B{ZS`+cw1`;1PdQn}sk(P;Ex zr$7jw87+b!%H{HhE!FLINs^SwWDo1?c6+5#SuU5iy5sSfN~MGq^+VouA3_Kr6wZ^$ nq*kjTgff}TZnt~*f&VaXoA4WNq4hz900000NkvXXu0mjf6C(kC literal 0 HcmV?d00001 diff --git a/examples/op_examples/multimodal_ops/images/9.png b/examples/op_examples/multimodal_ops/images/9.png new file mode 100644 index 0000000000000000000000000000000000000000..405b2f66aa1f69755a8b66eabac052fd60577f2c GIT binary patch literal 460 zcmV;-0W35oBt1E%(iXM^AHgc8;yo(nzn5-#y%8H(>$F{ z0Dy?D>w2CC05lv9hr(Pq9M0$S*XtDk=JR>6Sd2s>i9}*DnGo7+Higkztwv2vr&C2y zs93F5skqm1g*y(f-5dixA{!5f3X|vf7TB%gTM9Z=e@qWKMjuVT;s2Gh#%jNRz zc0gQxulFJPH-$ps z`FtW`y list: + """ + Converts data to proper base data type. + - For original pandas data types, this is returns tolist(). + - For ImageDtype, this returns list of PIL.Image.Image. + """ + if isinstance(data, pd.Series): + if isinstance(data.dtype, ImageDtype): + return [data.array.get_image(i) for i in range(len(data))] + return data.tolist() + + return data + + +__all__ = ["ImageDtype", "ImageArray", "convert_to_base_data"] diff --git a/lotus/dtype_extensions/image.py b/lotus/dtype_extensions/image.py new file mode 100644 index 00000000..dd2a522c --- /dev/null +++ b/lotus/dtype_extensions/image.py @@ -0,0 +1,146 @@ +import sys +from typing import Sequence, Union + +import numpy as np +import pandas as pd +from pandas.api.extensions import ExtensionArray, ExtensionDtype +from PIL import Image + +from lotus.utils import fetch_image + + +class ImageDtype(ExtensionDtype): + name = "image" + type = Image.Image + na_value = None + + @classmethod + def construct_array_type(cls): + return ImageArray + + +class ImageArray(ExtensionArray): + def __init__(self, values): + self._data = np.asarray(values, dtype=object) + self._dtype = ImageDtype() + self.allowed_image_types = ["Image", "base64"] + self._cached_images: dict[tuple[int, str], str | Image.Image | None] = {} # Cache for loaded images + + def __getitem__(self, item: int | slice | Sequence[int]) -> np.ndarray: + result = self._data[item] + + if isinstance(item, (int, np.integer)): + # Return the raw value for display purposes + return result + + return ImageArray(result) + + def __setitem__(self, key, value) -> None: + """Set one or more values inplace, with cache invalidation.""" + if isinstance(key, np.ndarray): + if key.dtype == bool: + key = np.where(key)[0] + key = key.tolist() + if isinstance(key, (int, np.integer)): + key = [key] + if isinstance(key, slice): + key = range(*key.indices(len(self))) + if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)): + for idx, val in zip(key, value): + self._data[idx] = val + self._invalidate_cache(idx) + else: + for idx in key: + self._data[idx] = value + self._invalidate_cache(idx) + + def _invalidate_cache(self, idx: int) -> None: + """Remove an item from the cache.""" + for image_type in self.allowed_image_types: + if (idx, image_type) in self._cached_images: + del self._cached_images[(idx, image_type)] + + def get_image(self, idx: int, image_type: str = "Image") -> Union[Image.Image, str, None]: + """Explicit method to fetch and return the actual image""" + if (idx, image_type) not in self._cached_images: + image_result = fetch_image(self._data[idx], image_type) + assert image_result is None or isinstance(image_result, (Image.Image, str)) + self._cached_images[(idx, image_type)] = image_result + return self._cached_images[(idx, image_type)] + + def isna(self) -> np.ndarray: + return pd.isna(self._data) + + def take(self, indices: Sequence[int], allow_fill: bool = False, fill_value=None) -> "ImageArray": + result = self._data.take(indices, axis=0) + if allow_fill and fill_value is not None: + result[indices == -1] = fill_value + return ImageArray(result) + + def copy(self) -> "ImageArray": + new_array = ImageArray(self._data.copy()) + new_array._cached_images = self._cached_images.copy() + return new_array + + @classmethod + def _from_sequence(cls, scalars, dtype=None, copy=False): + if copy: + scalars = np.array(scalars, dtype=object, copy=True) + return cls(scalars) + + def __len__(self) -> int: + return len(self._data) + + def __eq__(self, other) -> np.ndarray: # type: ignore + if isinstance(other, ImageArray): + return np.array([_compare_images(img1, img2) for img1, img2 in zip(self._data, other._data)], dtype=bool) + + if hasattr(other, "__iter__") and not isinstance(other, str): + if len(other) != len(self): + return np.repeat(False, len(self)) + return np.array([_compare_images(img1, img2) for img1, img2 in zip(self._data, other)], dtype=bool) + return np.array([_compare_images(img, other) for img in self._data], dtype=bool) + + @property + def dtype(self) -> ImageDtype: + return self._dtype + + @property + def nbytes(self) -> int: + return sum(sys.getsizeof(img) for img in self._data if img) + + def __repr__(self) -> str: + return f"ImageArray([{', '.join([f'' if img is not None else 'None' for img in self._data[:5]])}, ...])" + + def _formatter(self, boxed: bool = False): + return lambda x: f"" if x is not None else "None" + + def to_numpy(self, dtype=None, copy=False, na_value=None) -> np.ndarray: + """Convert the ImageArray to a numpy array.""" + pil_images = [] + for i, img_data in enumerate(self._data): + if isinstance(img_data, np.ndarray): + image = self.get_image(i) + pil_images.append(image) + else: + pil_images.append(img_data) + result = np.empty(len(self), dtype=object) + result[:] = pil_images + return result + + def __array__(self, dtype=None) -> np.ndarray: + """Numpy array interface.""" + return self.to_numpy(dtype=dtype) + + +def _compare_images(img1, img2) -> bool: + if img1 is None or img2 is None: + return img1 is img2 + + # Only fetch images when actually comparing + if isinstance(img1, Image.Image) or isinstance(img2, Image.Image): + img1 = fetch_image(img1) + img2 = fetch_image(img2) + return img1.size == img2.size and img1.mode == img2.mode and img1.tobytes() == img2.tobytes() + else: + return img1 == img2 diff --git a/lotus/models/colbertv2_rm.py b/lotus/models/colbertv2_rm.py index 018af594..2bd8ed99 100644 --- a/lotus/models/colbertv2_rm.py +++ b/lotus/models/colbertv2_rm.py @@ -2,7 +2,9 @@ from typing import Any import numpy as np +import pandas as pd from numpy.typing import NDArray +from PIL import Image from lotus.models.rm import RM from lotus.types import RMOutput @@ -20,14 +22,15 @@ def __init__(self) -> None: self.kwargs: dict[str, Any] = {"doc_maxlen": 300, "nbits": 2} self.index_dir: str | None = None - def index(self, docs: list[str], index_dir: str, **kwargs: dict[str, Any]) -> None: + def index(self, docs: pd.Series, index_dir: str, **kwargs: dict[str, Any]) -> None: + _docs = docs.tolist() kwargs = {**self.kwargs, **kwargs} checkpoint = "colbert-ir/colbertv2.0" with Run().context(RunConfig(nranks=1, experiment="lotus")): config = ColBERTConfig(doc_maxlen=kwargs["doc_maxlen"], nbits=kwargs["nbits"], kmeans_niters=4) indexer = Indexer(checkpoint=checkpoint, config=config) - indexer.index(name=f"{index_dir}/index", collection=docs, overwrite=True) + indexer.index(name=f"{index_dir}/index", collection=_docs, overwrite=True) with open(f"experiments/lotus/indexes/{index_dir}/index/docs", "wb") as fp: pickle.dump(docs, fp) @@ -45,7 +48,7 @@ def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.f def __call__( self, - queries: str | list[str] | NDArray[np.float64], + queries: str | Image.Image | list | NDArray[np.float64], K: int, **kwargs: dict[str, Any], ) -> RMOutput: @@ -56,6 +59,7 @@ def __call__( searcher = Searcher(index=f"{self.index_dir}/index", collection=self.docs) # make queries a dict with keys as query ids + assert isinstance(queries, list) queries_dict = {i: q for i, q in enumerate(queries)} all_results = searcher.search_all(queries_dict, k=K).todict() diff --git a/lotus/models/faiss_rm.py b/lotus/models/faiss_rm.py index 205129df..ace1fd92 100644 --- a/lotus/models/faiss_rm.py +++ b/lotus/models/faiss_rm.py @@ -5,7 +5,9 @@ import faiss import numpy as np +import pandas as pd from numpy.typing import NDArray +from PIL import Image from lotus.models.rm import RM from lotus.types import RMOutput @@ -20,7 +22,7 @@ def __init__(self, factory_string: str = "Flat", metric=faiss.METRIC_INNER_PRODU self.faiss_index: faiss.Index | None = None self.vecs: NDArray[np.float64] | None = None - def index(self, docs: list[str], index_dir: str, **kwargs: dict[str, Any]) -> None: + def index(self, docs: pd.Series, index_dir: str, **kwargs: dict[str, Any]) -> None: vecs = self._embed(docs) self.faiss_index = faiss.index_factory(vecs.shape[1], self.factory_string, self.metric) self.faiss_index.add(vecs) @@ -42,12 +44,14 @@ def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.f vecs: NDArray[np.float64] = pickle.load(fp) return vecs[ids] - def __call__(self, queries: str | list[str] | NDArray[np.float64], K: int, **kwargs: dict[str, Any]) -> RMOutput: - if isinstance(queries, str): + def __call__( + self, queries: pd.Series | str | Image.Image | list | NDArray[np.float64], K: int, **kwargs: dict[str, Any] + ) -> RMOutput: + if isinstance(queries, str) or isinstance(queries, Image.Image): queries = [queries] - if isinstance(queries[0], str): - embedded_queries = self._embed([str(q) for q in queries]) + if not isinstance(queries, np.ndarray): + embedded_queries = self._embed(queries) else: embedded_queries = np.asarray(queries, dtype=np.float32) @@ -58,5 +62,5 @@ def __call__(self, queries: str | list[str] | NDArray[np.float64], K: int, **kwa return RMOutput(distances=distances, indices=indices) @abstractmethod - def _embed(self, docs: list[str]) -> NDArray[np.float64]: + def _embed(self, docs: pd.Series | list) -> NDArray[np.float64]: pass diff --git a/lotus/models/litellm_rm.py b/lotus/models/litellm_rm.py index cadb4cf5..3c7a06e4 100644 --- a/lotus/models/litellm_rm.py +++ b/lotus/models/litellm_rm.py @@ -1,9 +1,11 @@ import faiss import numpy as np +import pandas as pd from litellm import embedding from litellm.types.utils import EmbeddingResponse from numpy.typing import NDArray +from lotus.dtype_extensions import convert_to_base_data from lotus.models.faiss_rm import FaissRM @@ -19,11 +21,12 @@ def __init__( self.model: str = model self.max_batch_size: int = max_batch_size - def _embed(self, docs: list[str]) -> NDArray[np.float64]: + def _embed(self, docs: pd.Series | list) -> NDArray[np.float64]: all_embeddings = [] for i in range(0, len(docs), self.max_batch_size): batch = docs[i : i + self.max_batch_size] - response: EmbeddingResponse = embedding(model=self.model, input=batch) + _batch = convert_to_base_data(batch) + response: EmbeddingResponse = embedding(model=self.model, input=_batch) embeddings = np.array([d["embedding"] for d in response.data]) all_embeddings.append(embeddings) return np.vstack(all_embeddings) diff --git a/lotus/models/rm.py b/lotus/models/rm.py index 330d7cd5..c58a0c38 100644 --- a/lotus/models/rm.py +++ b/lotus/models/rm.py @@ -2,7 +2,9 @@ from typing import Any import numpy as np +import pandas as pd from numpy.typing import NDArray +from PIL import Image from lotus.types import RMOutput @@ -14,7 +16,7 @@ def __init__(self) -> None: self.index_dir: str | None = None @abstractmethod - def index(self, docs: list[str], index_dir: str, **kwargs: dict[str, Any]) -> None: + def index(self, docs: pd.Series, index_dir: str, **kwargs: dict[str, Any]) -> None: """Create index and store it to a directory. Args: @@ -49,7 +51,7 @@ def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.f @abstractmethod def __call__( self, - queries: str | list[str] | NDArray[np.float64], + queries: pd.Series | str | Image.Image | list | NDArray[np.float64], K: int, **kwargs: dict[str, Any], ) -> RMOutput: diff --git a/lotus/models/sentence_transformers_rm.py b/lotus/models/sentence_transformers_rm.py index bbcd36f9..1f86ce7e 100644 --- a/lotus/models/sentence_transformers_rm.py +++ b/lotus/models/sentence_transformers_rm.py @@ -1,9 +1,11 @@ import faiss import numpy as np +import pandas as pd import torch from numpy.typing import NDArray from sentence_transformers import SentenceTransformer +from lotus.dtype_extensions import convert_to_base_data from lotus.models.faiss_rm import FaissRM @@ -23,12 +25,13 @@ def __init__( self.normalize_embeddings: bool = normalize_embeddings self.transformer: SentenceTransformer = SentenceTransformer(model, device=device) - def _embed(self, docs: list[str]) -> NDArray[np.float64]: + def _embed(self, docs: pd.Series | list) -> NDArray[np.float64]: all_embeddings = [] for i in range(0, len(docs), self.max_batch_size): batch = docs[i : i + self.max_batch_size] + _batch = convert_to_base_data(batch) torch_embeddings = self.transformer.encode( - batch, convert_to_tensor=True, normalize_embeddings=self.normalize_embeddings + _batch, convert_to_tensor=True, normalize_embeddings=self.normalize_embeddings ) assert isinstance(torch_embeddings, torch.Tensor) cpu_embeddings = torch_embeddings.cpu().numpy() diff --git a/lotus/sem_ops/sem_extract.py b/lotus/sem_ops/sem_extract.py index 35979fe7..515e56cb 100644 --- a/lotus/sem_ops/sem_extract.py +++ b/lotus/sem_ops/sem_extract.py @@ -1,4 +1,4 @@ -from typing import Callable +from typing import Any, Callable import pandas as pd @@ -11,7 +11,7 @@ def sem_extract( - docs: list[str], + docs: list[dict[str, Any]], model: LM, output_cols: dict[str, str | None], extract_quotes: bool = False, @@ -21,7 +21,7 @@ def sem_extract( Extracts attributes and values from a list of documents using a model. Args: - docs (list[str]): The list of documents to extract from. + docs (list[dict[str, Any]]): The list of documents to extract from. model (lotus.models.LM): The model to use. output_cols (dict[str, str | None]): A mapping from desired output column names to optional descriptions. extract_quotes (bool): Whether to extract quotes for the output columns. Defaults to False. @@ -88,10 +88,10 @@ def __call__( if column not in self._obj.columns: raise ValueError(f"Column {column} not found in DataFrame") - docs = task_instructions.df2text(self._obj, input_cols) + multimodal_data = task_instructions.df2multimodal_info(self._obj, input_cols) out = sem_extract( - docs=docs, + docs=multimodal_data, model=lotus.settings.lm, output_cols=output_cols, extract_quotes=extract_quotes, diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py index 1fced2d1..bad0d5e3 100644 --- a/lotus/sem_ops/sem_filter.py +++ b/lotus/sem_ops/sem_filter.py @@ -13,11 +13,11 @@ def sem_filter( - docs: list[str], + docs: list[dict[str, Any]], model: lotus.models.LM, user_instruction: str, default: bool = True, - examples_df_txt: list[str] | None = None, + examples_multimodal_data: list[dict[str, Any]] | None = None, examples_answers: list[bool] | None = None, cot_reasoning: list[str] | None = None, strategy: str | None = None, @@ -27,11 +27,11 @@ def sem_filter( Filters a list of documents based on a given user instruction using a language model. Args: - docs (list[str]): The list of documents to filter. + docs (list[dict[str, Any]]): The list of documents to filter. Each document is a tuple of text and images. model (lotus.models.LM): The language model used for filtering. user_instruction (str): The user instruction for filtering. default (bool): The default value for filtering in case of parsing errors. Defaults to True. - examples_df_txt (list[str] | None): The text for examples. Defaults to None. + examples_multimodal_data (list[dict[str, Any]] | None): The text for examples. Defaults to None. examples_answers (list[bool] | None): The answers for examples. Defaults to None. cot_reasoning (list[str] | None): The reasoning for CoT. Defaults to None. logprobs (bool): Whether to return log probabilities. Defaults to False. @@ -42,7 +42,7 @@ def sem_filter( inputs = [] for doc in docs: prompt = lotus.templates.task_instructions.filter_formatter( - doc, user_instruction, examples_df_txt, examples_answers, cot_reasoning, strategy + doc, user_instruction, examples_multimodal_data, examples_answers, cot_reasoning, strategy ) lotus.logger.debug(f"input to model: {prompt}") inputs.append(prompt) @@ -60,7 +60,7 @@ def sem_filter( def learn_filter_cascade_thresholds( - sample_df_txt: list[str], + sample_multimodal_data: list[dict[str, Any]], lm: lotus.models.LM, formatted_usr_instr: str, default: bool, @@ -69,7 +69,7 @@ def learn_filter_cascade_thresholds( delta: float, helper_true_probs: list[float], sample_correction_factors: NDArray[np.float64], - examples_df_txt: list[str] | None = None, + examples_multimodal_data: list[dict[str, Any]] | None = None, examples_answers: list[bool] | None = None, cot_reasoning: list[str] | None = None, strategy: str | None = None, @@ -80,11 +80,11 @@ def learn_filter_cascade_thresholds( try: large_outputs = sem_filter( - sample_df_txt, + sample_multimodal_data, lm, formatted_usr_instr, default=default, - examples_df_txt=examples_df_txt, + examples_multimodal_data=examples_multimodal_data, examples_answers=examples_answers, cot_reasoning=cot_reasoning, strategy=strategy, @@ -169,16 +169,16 @@ def __call__( if column not in self._obj.columns: raise ValueError(f"Column {column} not found in DataFrame") - df_txt = task_instructions.df2text(self._obj, col_li) - lotus.logger.debug(df_txt) + multimodal_data = task_instructions.df2multimodal_info(self._obj, col_li) + lotus.logger.debug(multimodal_data) formatted_usr_instr = lotus.nl_expression.nle2str(user_instruction, col_li) - examples_df_txt = None + examples_multimodal_data = None examples_answers = None cot_reasoning = None if examples is not None: assert "Answer" in examples.columns, "Answer must be a column in examples dataframe" - examples_df_txt = task_instructions.df2text(examples, col_li) + examples_multimodal_data = task_instructions.df2multimodal_info(examples, col_li) examples_answers = examples["Answer"].tolist() if strategy == "cot": @@ -188,12 +188,12 @@ def __call__( pos_cascade_threshold, neg_cascade_threshold = None, None if learn_cascade_threshold_sample_percentage is not None: # Get few-shot examples for small LM - helper_examples_df_txt = None + helper_examples_multimodal_data = None helper_examples_answers = None helper_cot_reasoning = None if helper_examples is not None: assert "Answer" in helper_examples.columns, "Answer must be a column in examples dataframe" - helper_examples_df_txt = task_instructions.df2text(helper_examples, col_li) + helper_examples_multimodal_data = task_instructions.df2multimodal_info(helper_examples, col_li) helper_examples_answers = helper_examples["Answer"].tolist() if helper_strategy == "cot": @@ -212,11 +212,11 @@ def __call__( # Run small LM and get logits helper_output = sem_filter( - df_txt, + multimodal_data, lotus.settings.helper_lm, formatted_usr_instr, default=default, - examples_df_txt=helper_examples_df_txt, + examples_multimodal_data=helper_examples_multimodal_data, examples_answers=helper_examples_answers, cot_reasoning=helper_cot_reasoning, logprobs=True, @@ -232,12 +232,12 @@ def __call__( helper_true_probs, learn_cascade_threshold_sample_percentage ) sample_df = self._obj.loc[sample_indices] - sample_df_txt = task_instructions.df2text(sample_df, col_li) + sample_multimodal_data = task_instructions.df2multimodal_info(sample_df, col_li) sample_helper_true_probs = [helper_true_probs[i] for i in sample_indices] sample_correction_factors = correction_factors[sample_indices] pos_cascade_threshold, neg_cascade_threshold = learn_filter_cascade_thresholds( - sample_df_txt=sample_df_txt, + sample_multimodal_data=sample_multimodal_data, lm=lotus.settings.lm, formatted_usr_instr=formatted_usr_instr, default=default, @@ -246,7 +246,7 @@ def __call__( delta=failure_probability / 2, helper_true_probs=sample_helper_true_probs, sample_correction_factors=sample_correction_factors, - examples_df_txt=examples_df_txt, + examples_multimodal_data=examples_multimodal_data, examples_answers=examples_answers, cot_reasoning=cot_reasoning, strategy=strategy, @@ -277,9 +277,9 @@ def __call__( lotus.logger.info(f"Num routed to smaller model: {len(high_conf_idxs)}") stats["num_routed_to_helper_model"] = len(high_conf_idxs) - outputs: list[bool] = [False] * len(df_txt) - raw_outputs: list[str] = [""] * len(df_txt) - explanations: list[str | None] = [None] * len(df_txt) + outputs: list[bool] = [False] * len(multimodal_data) + raw_outputs: list[str] = [""] * len(multimodal_data) + explanations: list[str | None] = [None] * len(multimodal_data) assert all(isinstance(x, str) for x in helper_output.explanations) or all( x is None for x in helper_output.explanations @@ -291,14 +291,14 @@ def __call__( # Send low confidence samples to large LM if any low_conf_idxs = sorted([i for i in range(len(helper_outputs)) if i not in high_conf_idxs]) - low_conf_df_txt = [df_txt[idx] for idx in low_conf_idxs] + low_conf_multimodal_data = [multimodal_data[idx] for idx in low_conf_idxs] if low_conf_idxs: large_output = sem_filter( - low_conf_df_txt, + low_conf_multimodal_data, lotus.settings.lm, formatted_usr_instr, default=default, - examples_df_txt=examples_df_txt, + examples_multimodal_data=examples_multimodal_data, examples_answers=examples_answers, cot_reasoning=cot_reasoning, strategy=strategy, @@ -314,11 +314,11 @@ def __call__( else: output = sem_filter( - df_txt, + multimodal_data, lotus.settings.lm, formatted_usr_instr, default=default, - examples_df_txt=examples_df_txt, + examples_multimodal_data=examples_multimodal_data, examples_answers=examples_answers, cot_reasoning=cot_reasoning, strategy=strategy, diff --git a/lotus/sem_ops/sem_index.py b/lotus/sem_ops/sem_index.py index a0c0955d..80e1cb45 100644 --- a/lotus/sem_ops/sem_index.py +++ b/lotus/sem_ops/sem_index.py @@ -30,10 +30,9 @@ def __call__(self, col_name: str, index_dir: str) -> pd.DataFrame: Returns: pd.DataFrame: The DataFrame with the index directory saved. """ - documents = self._obj[col_name].tolist() rm = lotus.settings.rm if rm is None: raise AttributeError("Must set rm in lotus.settings") - rm.index(documents, index_dir) + rm.index(self._obj[col_name], index_dir) self._obj.attrs["index_dirs"][col_name] = index_dir return self._obj diff --git a/lotus/sem_ops/sem_join.py b/lotus/sem_ops/sem_join.py index e5760102..5248998f 100644 --- a/lotus/sem_ops/sem_join.py +++ b/lotus/sem_ops/sem_join.py @@ -19,7 +19,7 @@ def sem_join( col2_label: str, model: lotus.models.LM, user_instruction: str, - examples_df_txt: list[str] | None = None, + examples_multimodal_data: list[dict[str, Any]] | None = None, examples_answers: list[bool] | None = None, cot_reasoning: list[str] | None = None, default: bool = True, @@ -37,7 +37,7 @@ def sem_join( col2_label (str): The label for the second column. model (lotus.models.LM): The model to use. user_instruction (str): The user instruction for join. - examples_df_txt (list[str] | None): The examples dataframe text. Defaults to None. + examples_multimodal_data (list[str] | None): The examples dataframe text. Defaults to None. examples_answers (list[bool] | None): The answers for examples. Defaults to None. cot_reasoning (list[str] | None): The reasoning for CoT. Defaults to None. default (bool): The default value for the join in case of parsing errors. Defaults to True. @@ -51,15 +51,17 @@ def sem_join( join_results = [] + left_multimodal_data = task_instructions.df2multimodal_info(l1.to_frame(col1_label), [col1_label]) + right_multimodal_data = task_instructions.df2multimodal_info(l2.to_frame(col2_label), [col2_label]) # for i1 in enumerate(l1): - for id1, i1 in zip(ids1, l1): + for id1, i1 in zip(ids1, left_multimodal_data): # perform llm filter - modified_docs = l2.apply(lambda doc: f"{col1_label}: {i1}\n{col2_label}: {doc}") + modified_docs = task_instructions.merge_multimodal_info([i1], right_multimodal_data) output = sem_filter( modified_docs, model, user_instruction, - examples_df_txt=examples_df_txt, + examples_multimodal_data=examples_multimodal_data, examples_answers=examples_answers, cot_reasoning=cot_reasoning, default=default, @@ -104,7 +106,7 @@ def sem_join_cascade( precision_target: float, sampling_percentage: float = 0.1, failure_probability: float = 0.2, - examples_df_txt: list[str] | None = None, + examples_multimodal_data: list[dict[str, Any]] | None = None, examples_answers: list[bool] | None = None, map_instruction: str | None = None, map_examples: pd.DataFrame | None = None, @@ -114,7 +116,7 @@ def sem_join_cascade( ) -> SemanticJoinOutput: """ Joins two series using a cascade helper model and a large model. - + Args: l1 (pd.Series): The first series. l2 (pd.Series): The second series. @@ -127,20 +129,20 @@ def sem_join_cascade( precision_target (float): The target precision. sampling_percentage (float): The percentage of the data to sample. Defaults to 0.1. failure_probability (float): The failure probability. Defaults to 0.2. - examples_df_txt list[str] | None): The examples dataframe text. Defaults to None. + examples_multimodal_data (list[dict[str, Any]] | None): The examples multimodal data. Defaults to None. examples_answers (list[bool] | None): The answers for examples. Defaults to None. map_instruction (str | None): The map instruction. Defaults to None. map_examples (pd.DataFrame | None): The map examples. Defaults to None. cot_reasoning (list[str] | None): The reasoning for CoT. Defaults to None. default (bool): The default value for the join in case of parsing errors. Defaults to True. strategy (str | None): The reasoning strategy. Defaults to None. - + Returns: SemanticJoinOutput: The join results, filter outputs, all raw outputs, all explanations, and stats. - + Note that filter_outputs, all_raw_outputs, and all_explanations are empty list because the helper model do not generate these outputs. - + SemanticJoinOutput.stats: join_resolved_by_helper_model: total number of join records resolved by the helper model join_helper_positive: number of high confidence positive results from the helper model @@ -168,26 +170,26 @@ def sem_join_cascade( user_instruction, sampling_percentage=sampling_percentage, failure_probability=failure_probability, - examples_df_txt=examples_df_txt, + examples_multimodal_data=examples_multimodal_data, examples_answers=examples_answers, map_instruction=map_instruction, map_examples=map_examples, cot_reasoning=cot_reasoning, default=default, strategy=strategy, - ) + ) num_helper = len(helper_high_conf) num_large = len(helper_low_conf) - + # Accept helper results with high confidence - join_results = [(row['_left_id'], row['_right_id'], None) for _, row in helper_high_conf.iterrows()] + join_results = [(row["_left_id"], row["_right_id"], None) for _, row in helper_high_conf.iterrows()] # Send low confidence rows to large LM for unique_l1 in helper_low_conf[col1_label].unique(): - unique_l1_id = helper_low_conf[helper_low_conf[col1_label] == unique_l1]['_left_id'].iloc[0] + unique_l1_id = helper_low_conf[helper_low_conf[col1_label] == unique_l1]["_left_id"].iloc[0] l2_for_l1 = helper_low_conf[helper_low_conf[col1_label] == unique_l1][col2_label] - l2_for_l1_index = helper_low_conf[helper_low_conf[col1_label] == unique_l1]['_right_id'] + l2_for_l1_index = helper_low_conf[helper_low_conf[col1_label] == unique_l1]["_right_id"] large_join_output = sem_join( pd.Series([unique_l1]), l2_for_l1, @@ -197,25 +199,27 @@ def sem_join_cascade( col2_label, lotus.settings.lm, user_instruction, - examples_df_txt=examples_df_txt, + examples_multimodal_data=examples_multimodal_data, examples_answers=examples_answers, cot_reasoning=cot_reasoning, default=default, strategy=strategy, ) - + join_results.extend(large_join_output.join_results) lotus.logger.debug(f"outputs: {filter_outputs}") lotus.logger.debug(f"explanations: {all_explanations}") # Log join cascade stats: - stats = {"join_resolved_by_helper_model": num_helper + num_helper_high_conf_neg, - "join_helper_positive": num_helper, - "join_helper_negative": num_helper_high_conf_neg, - "join_resolved_by_large_model": num_large, - "optimized_join_cost": join_optimization_cost, - "total_LM_calls": join_optimization_cost + num_large} + stats = { + "join_resolved_by_helper_model": num_helper + num_helper_high_conf_neg, + "join_helper_positive": num_helper, + "join_helper_negative": num_helper_high_conf_neg, + "join_resolved_by_large_model": num_large, + "optimized_join_cost": join_optimization_cost, + "total_LM_calls": join_optimization_cost + num_large, + } return SemanticJoinOutput( join_results=join_results, @@ -226,21 +230,16 @@ def sem_join_cascade( ) -def run_sem_sim_join( - l1: pd.Series, - l2: pd.Series, - col1_label: str, - col2_label: str -) -> pd.DataFrame: +def run_sem_sim_join(l1: pd.Series, l2: pd.Series, col1_label: str, col2_label: str) -> pd.DataFrame: """ Wrapper function to run sem_sim_join in sem_join then calibrate the scores for approximate join - + Args: l1 (pd.Series): The first series. l2 (pd.Series): The second series. col1_label (str): The label for the first column. col2_label (str): The label for the second column. - + Returns: pd.DataFrame: The similarity join results. """ @@ -257,15 +256,10 @@ def run_sem_sim_join( K = len(l2) * len(l1) # Run sem_sim_join as helper on the sampled data - out = l1_df.sem_sim_join( - l2_df, - left_on=col1_label, - right_on=col2_label, - K=K, - keep_index=True) + out = l1_df.sem_sim_join(l2_df, left_on=col1_label, right_on=col2_label, K=K, keep_index=True) # Correct helper scores - out['_scores'] = calibrate_sem_sim_join(out['_scores'].tolist()) + out["_scores"] = calibrate_sem_sim_join(out["_scores"].tolist()) return out @@ -274,18 +268,18 @@ def map_l1_to_l2( col1_label: str, col2_label: str, map_instruction: str | None = None, - map_examples: pd.DataFrame | None = None + map_examples: pd.DataFrame | None = None, ) -> tuple[pd.DataFrame, str]: """ Wrapper function to run sem_map in sem_join. - + Args: l1 (pd.Series): The first series. col1_label (str): The label for the first column. col2_label (str): The label for the second column. map_instruction (str): The map instruction. Defaults to None. map_examples (pd.DataFrame): The map examples. Defaults to None. - + Returns: tuple[pd.DataFrame, str]: The mapped DataFrame and the mapped column name. """ @@ -293,12 +287,12 @@ def map_l1_to_l2( real_left_on = col1_label.split(":left")[0] else: real_left_on = col1_label - + if ":right" in col2_label: real_right_on = col2_label.split(":right")[0] else: real_right_on = col2_label - + inst = "" if map_instruction: inst = map_instruction @@ -309,12 +303,9 @@ def map_l1_to_l2( # Transform l1 into DataFrame for sem_map l1_df = l1.to_frame(name=real_left_on) mapped_col1_name = f"_{col1_label}" - + # Map l1 to l2 - out = l1_df.sem_map( - inst, - suffix=mapped_col1_name, - examples=map_examples) + out = l1_df.sem_map(inst, suffix=mapped_col1_name, examples=map_examples) out = out.rename(columns={real_left_on: col1_label}) return out, mapped_col1_name @@ -330,7 +321,7 @@ def join_optimizer( user_instruction: str, sampling_percentage: float = 0.1, failure_probability: float = 0.2, - examples_df_txt: list[str] | None = None, + examples_multimodal_data: list[dict[str, Any]] | None = None, examples_answers: list[bool] | None = None, map_instruction: str | None = None, map_examples: pd.DataFrame | None = None, @@ -339,9 +330,9 @@ def join_optimizer( strategy: str | None = None, ) -> tuple[pd.DataFrame, pd.DataFrame, int, int]: """ - Find most cost-effective join plan between Search-Filter and Map-Search-Filter + Find most cost-effective join plan between Search-Filter and Map-Search-Filter while satisfying the recall and precision target. - + Args: recall_target (float): The target recall. precision_target (float): The target precision. @@ -352,20 +343,20 @@ def join_optimizer( user_instruction (str): The user instruction for join. sampling_percentage (float): The percentage of the data to sample. Defaults to 0.1. failure_probability (float): The failure probability. Defaults to 0.2. - examples_df_txt (list[str] | None): The examples dataframe text. Defaults to None. + examples_multimodal_data (list[dict[str, Any]] | None): The examples multimodal data. Defaults to None. examples_answers (list[bool] | None): The answers for examples. Defaults to None. map_instruction (str | None): The map instruction. Defaults to None. map_examples (pd.DataFrame | None): The map examples. Defaults to None. cot_reasoning (list[str] | None): The reasoning for CoT. Defaults to None. default (bool): The default value for the join in case of parsing errors. Defaults to True. strategy (str | None): The reasoning strategy. Defaults to None. - + returns: tuple[pd.DataFrame, pd.DataFrame]: The high confidence and low confidence join results. int: The number of high confidence negative results. int: The number of LM calls from optimizing join plan. """ - + # Helper is currently default to similiarity join if lotus.settings.helper_lm is not None: lotus.logger.debug("Helper model is not supported yet. Default to similarity join.") @@ -373,64 +364,72 @@ def join_optimizer( # Learn search-filter thresholds sf_helper_join = run_sem_sim_join(l1, l2, col1_label, col2_label) sf_t_pos, sf_t_neg, sf_learn_cost = learn_join_cascade_threshold( - sf_helper_join, - recall_target, - precision_target, + sf_helper_join, + recall_target, + precision_target, col1_label, - col2_label, - user_instruction, + col2_label, + user_instruction, sampling_percentage, delta=failure_probability / 2, - examples_df_txt=examples_df_txt, + examples_multimodal_data=examples_multimodal_data, examples_answers=examples_answers, cot_reasoning=cot_reasoning, default=default, - strategy=strategy) - sf_high_conf = sf_helper_join[sf_helper_join['_scores'] >= sf_t_pos] - sf_high_conf_neg = len(sf_helper_join[sf_helper_join['_scores'] <= sf_t_neg]) - sf_low_conf = sf_helper_join[(sf_helper_join['_scores'] < sf_t_pos) & (sf_helper_join['_scores'] > sf_t_neg)] + strategy=strategy, + ) + sf_high_conf = sf_helper_join[sf_helper_join["_scores"] >= sf_t_pos] + sf_high_conf_neg = len(sf_helper_join[sf_helper_join["_scores"] <= sf_t_neg]) + sf_low_conf = sf_helper_join[(sf_helper_join["_scores"] < sf_t_pos) & (sf_helper_join["_scores"] > sf_t_neg)] sf_cost = len(sf_low_conf) # Learn map-search-filter thresholds - mapped_l1, mapped_col1_label = map_l1_to_l2(l1, col1_label, col2_label, map_instruction=map_instruction, map_examples=map_examples) + mapped_l1, mapped_col1_label = map_l1_to_l2( + l1, col1_label, col2_label, map_instruction=map_instruction, map_examples=map_examples + ) msf_helper_join = run_sem_sim_join(mapped_l1, l2, mapped_col1_label, col2_label) msf_t_pos, msf_t_neg, msf_learn_cost = learn_join_cascade_threshold( - msf_helper_join, - recall_target, - precision_target, + msf_helper_join, + recall_target, + precision_target, col1_label, col2_label, - user_instruction, + user_instruction, sampling_percentage, delta=failure_probability / 2, - examples_df_txt=examples_df_txt, + examples_multimodal_data=examples_multimodal_data, examples_answers=examples_answers, cot_reasoning=cot_reasoning, default=default, - strategy=strategy) - msf_high_conf = msf_helper_join[msf_helper_join['_scores'] >= msf_t_pos] - msf_high_conf_neg = len(msf_helper_join[msf_helper_join['_scores'] <= msf_t_neg]) - msf_low_conf = msf_helper_join[(msf_helper_join['_scores'] < msf_t_pos) & (msf_helper_join['_scores'] > msf_t_neg)] + strategy=strategy, + ) + msf_high_conf = msf_helper_join[msf_helper_join["_scores"] >= msf_t_pos] + msf_high_conf_neg = len(msf_helper_join[msf_helper_join["_scores"] <= msf_t_neg]) + msf_low_conf = msf_helper_join[(msf_helper_join["_scores"] < msf_t_pos) & (msf_helper_join["_scores"] > msf_t_neg)] msf_cost = len(msf_low_conf) - msf_learn_cost += len(l1) # cost from map l1 to l2 + msf_learn_cost += len(l1) # cost from map l1 to l2 # Select the cheaper join plan lotus.logger.info("Join Optimizer: plan cost analysis:") lotus.logger.info(f" Search-Filter: {sf_cost} LLM calls.") - lotus.logger.info(f" Search-Filter: accept {len(sf_high_conf)} helper positive results, {sf_high_conf_neg} helper negative results.") + lotus.logger.info( + f" Search-Filter: accept {len(sf_high_conf)} helper positive results, {sf_high_conf_neg} helper negative results." + ) lotus.logger.info(f" Map-Search-Filter: {msf_cost} LLM calls.") - lotus.logger.info(f" Map-Search-Filter: accept {len(msf_high_conf)} helper positive results, {msf_high_conf_neg} helper negative results.") + lotus.logger.info( + f" Map-Search-Filter: accept {len(msf_high_conf)} helper positive results, {msf_high_conf_neg} helper negative results." + ) learning_cost = sf_learn_cost + msf_learn_cost if sf_cost < msf_cost: lotus.logger.info("Proceeding with Search-Filter") - sf_high_conf = sf_high_conf.sort_values(by='_scores', ascending=False) - sf_low_conf = sf_low_conf.sort_values(by='_scores', ascending=False) + sf_high_conf = sf_high_conf.sort_values(by="_scores", ascending=False) + sf_low_conf = sf_low_conf.sort_values(by="_scores", ascending=False) return sf_high_conf, sf_low_conf, sf_high_conf_neg, learning_cost else: lotus.logger.info("Proceeding with Map-Search-Filter") - msf_high_conf = msf_high_conf.sort_values(by='_scores', ascending=False) - msf_low_conf = msf_low_conf.sort_values(by='_scores', ascending=False) + msf_high_conf = msf_high_conf.sort_values(by="_scores", ascending=False) + msf_low_conf = msf_low_conf.sort_values(by="_scores", ascending=False) return msf_high_conf, msf_low_conf, msf_high_conf_neg, learning_cost @@ -443,14 +442,14 @@ def learn_join_cascade_threshold( user_instruction: str, sampling_percentage: float = 0.1, delta: float = 0.2, - examples_df_txt: list[str] | None = None, + examples_multimodal_data: list[dict[str, Any]] | None = None, examples_answers: list[bool] | None = None, cot_reasoning: list[str] | None = None, default: bool = True, strategy: str | None = None, ) -> tuple[float, float, int]: """ - Extract a small sample of the data and find the optimal threshold pair that satisfies the recall and + Extract a small sample of the data and find the optimal threshold pair that satisfies the recall and precision target. Args: @@ -462,7 +461,7 @@ def learn_join_cascade_threshold( user_instruction (str): The user instruction for join. sampling_percentage (float): The percentage of the data to sample. Defaults to 0.1. delta (float): The failure probability. Defaults to 0.2. - examples_df_txt (list[str] | None): The examples dataframe text. Defaults to None. + examples_multimodal_data (list[dict[str, Any]] | None): The examples multimodal data. Defaults to None. examples_answers (list[bool] | None): The answers for examples. Defaults to None. cot_reasoning (list[str] | None): The reasoning for CoT. Defaults to None. default (bool): The default value for the join in case of parsing errors. Defaults to True. @@ -471,25 +470,25 @@ def learn_join_cascade_threshold( tuple: The positive threshold, negative threshold, and the number of LM calls from learning thresholds. """ # Sample a small subset of the helper join result - helper_scores = helper_join['_scores'].tolist() - + helper_scores = helper_join["_scores"].tolist() + sample_indices, correction_factors = importance_sampling(helper_scores, sampling_percentage) lotus.logger.info(f"Sampled {len(sample_indices)} out of {len(helper_scores)} helper join results.") sample_df = helper_join.iloc[sample_indices] - sample_scores = sample_df['_scores'].tolist() + sample_scores = sample_df["_scores"].tolist() sample_correction_factors = correction_factors[sample_indices] col_li = [col1_label, col2_label] - sample_df_txt = task_instructions.df2text(sample_df, col_li) + sample_multimodal_data = task_instructions.df2multimodal_info(sample_df, col_li) try: output = sem_filter( - sample_df_txt, + sample_multimodal_data, lotus.settings.lm, user_instruction, default=default, - examples_df_txt=examples_df_txt, + examples_multimodal_data=examples_multimodal_data, examples_answers=examples_answers, cot_reasoning=cot_reasoning, strategy=strategy, @@ -501,7 +500,7 @@ def learn_join_cascade_threshold( sample_correction_factors=sample_correction_factors, recall_target=recall_target, precision_target=precision_target, - delta=delta + delta=delta, ) lotus.logger.info(f"Learned cascade thresholds: {(pos_threshold, neg_threshold)}") @@ -510,7 +509,7 @@ def learn_join_cascade_threshold( lotus.logger.error(f"Error while learning filter cascade thresholds: {e}") lotus.logger.error("Default to full join.") return 1.0, 0.0, len(sample_indices) - + return pos_threshold, neg_threshold, len(sample_indices) @@ -606,14 +605,12 @@ def __call__( assert left_on is not None, "Column not found in left dataframe" assert right_on is not None, "Column not found in right dataframe" - examples_df_txt = None + examples_multimodal_data = None examples_answers = None cot_reasoning = None if examples is not None: assert "Answer" in examples.columns, "Answer must be a column in examples dataframe" - examples_df_txt = [] - for idx, row in examples.iterrows(): - examples_df_txt.append(f"{left_on}: {row[real_left_on]}\n{right_on}: {row[real_right_on]}") + examples_multimodal_data = task_instructions.df2multimodal_info(examples, [real_left_on, real_right_on]) examples_answers = examples["Answer"].tolist() if strategy == "cot": @@ -622,11 +619,15 @@ def __call__( num_full_join = len(self._obj) * len(other) - if (cascade_args is not None) and \ - (cascade_args.recall_target is not None or cascade_args.precision_target is not None) and \ - (num_full_join >= lotus.settings.min_join_cascade_size): + if ( + (cascade_args is not None) + and (cascade_args.recall_target is not None or cascade_args.precision_target is not None) + and (num_full_join >= lotus.settings.min_join_cascade_size) + ): cascade_args.recall_target = 1.0 if cascade_args.recall_target is None else cascade_args.recall_target - cascade_args.precision_target = 1.0 if cascade_args.precision_target is None else cascade_args.precision_target + cascade_args.precision_target = ( + 1.0 if cascade_args.precision_target is None else cascade_args.precision_target + ) output = sem_join_cascade( self._obj[real_left_on], other[real_right_on], @@ -639,7 +640,7 @@ def __call__( cascade_args.precision_target, sampling_percentage=cascade_args.sampling_percentage, failure_probability=cascade_args.failure_probability, - examples_df_txt=examples_df_txt, + examples_multimodal_data=examples_multimodal_data, examples_answers=examples_answers, map_instruction=cascade_args.map_instruction, map_examples=cascade_args.map_examples, @@ -657,7 +658,7 @@ def __call__( right_on, lotus.settings.lm, join_instruction, - examples_df_txt=examples_df_txt, + examples_multimodal_data=examples_multimodal_data, examples_answers=examples_answers, cot_reasoning=cot_reasoning, default=default, diff --git a/lotus/sem_ops/sem_map.py b/lotus/sem_ops/sem_map.py index 9074c094..99d8a84d 100644 --- a/lotus/sem_ops/sem_map.py +++ b/lotus/sem_ops/sem_map.py @@ -1,4 +1,4 @@ -from typing import Callable +from typing import Any, Callable import pandas as pd @@ -10,11 +10,11 @@ def sem_map( - docs: list[str], + docs: list[dict[str, Any]], model: lotus.models.LM, user_instruction: str, postprocessor: Callable[[list[str], bool], SemanticMapPostprocessOutput] = map_postprocess, - examples_df_txt: list[str] | None = None, + examples_multimodal_data: list[dict[str, Any]] | None = None, examples_answers: list[str] | None = None, cot_reasoning: list[str] | None = None, strategy: str | None = None, @@ -23,11 +23,11 @@ def sem_map( Maps a list of documents to a list of outputs using a model. Args: - docs (list[str]): The list of documents to map. + docs (list[dict[str, Any]]): The list of documents to map. model (lotus.models.LM): The model to use. user_instruction (str): The user instruction for map. postprocessor (Callable): The postprocessor for the model outputs. Defaults to map_postprocess. - examples_df_txt (list[str] | None): The text for examples. Defaults to None. + examples_multimodal_data (list[dict[str, Any]] | None): The text for examples. Defaults to None. examples_answers (list[str] | None): The answers for examples. Defaults to None. cot_reasoning (list[str] | None): The reasoning for CoT. Defaults to None. @@ -38,7 +38,7 @@ def sem_map( inputs = [] for doc in docs: prompt = lotus.templates.task_instructions.map_formatter( - doc, user_instruction, examples_df_txt, examples_answers, cot_reasoning, strategy=strategy + doc, user_instruction, examples_multimodal_data, examples_answers, cot_reasoning, strategy=strategy ) lotus.logger.debug(f"input to model: {prompt}") lotus.logger.debug(f"inputs content to model: {[x.get('content') for x in prompt]}") @@ -101,15 +101,15 @@ def __call__( if column not in self._obj.columns: raise ValueError(f"Column {column} not found in DataFrame") - df_txt = task_instructions.df2text(self._obj, col_li) + multimodal_data = task_instructions.df2multimodal_info(self._obj, col_li) formatted_usr_instr = lotus.nl_expression.nle2str(user_instruction, col_li) - examples_df_txt = None + examples_multimodal_data = None examples_answers = None cot_reasoning = None if examples is not None: assert "Answer" in examples.columns, "Answer must be a column in examples dataframe" - examples_df_txt = task_instructions.df2text(examples, col_li) + examples_multimodal_data = task_instructions.df2multimodal_info(examples, col_li) examples_answers = examples["Answer"].tolist() if strategy == "cot": @@ -117,11 +117,11 @@ def __call__( cot_reasoning = examples["Reasoning"].tolist() output = sem_map( - df_txt, + multimodal_data, lotus.settings.lm, formatted_usr_instr, postprocessor=postprocessor, - examples_df_txt=examples_df_txt, + examples_multimodal_data=examples_multimodal_data, examples_answers=examples_answers, cot_reasoning=cot_reasoning, strategy=strategy, diff --git a/lotus/sem_ops/sem_sim_join.py b/lotus/sem_ops/sem_sim_join.py index debe2b7a..b1fd9860 100644 --- a/lotus/sem_ops/sem_sim_join.py +++ b/lotus/sem_ops/sem_sim_join.py @@ -64,9 +64,9 @@ def __call__( try: queries = rm.get_vectors_from_index(query_index_dir, self._obj.index) except NotImplementedError: - queries = self._obj[left_on].tolist() + queries = self._obj[left_on] else: - queries = self._obj[left_on].tolist() + queries = self._obj[left_on] # load index to search over try: @@ -95,19 +95,16 @@ def __call__( df1["_left_id"] = df1.index df2["_right_id"] = df2.index temp_df = pd.DataFrame(join_results, columns=["_left_id", "_right_id", "_scores" + score_suffix]) - joined_df = ( - df1.join( - temp_df.set_index("_left_id"), - how="right", - on="_left_id", - ) - .join( - df2.set_index("_right_id"), - how="left", - on="_right_id", - lsuffix=lsuffix, - rsuffix=rsuffix, - ) + joined_df = df1.join( + temp_df.set_index("_left_id"), + how="right", + on="_left_id", + ).join( + df2.set_index("_right_id"), + how="left", + on="_right_id", + lsuffix=lsuffix, + rsuffix=rsuffix, ) if not keep_index: joined_df.drop(columns=["_left_id", "_right_id"], inplace=True) diff --git a/lotus/sem_ops/sem_topk.py b/lotus/sem_ops/sem_topk.py index 1db8b514..0af1b475 100644 --- a/lotus/sem_ops/sem_topk.py +++ b/lotus/sem_ops/sem_topk.py @@ -11,7 +11,7 @@ def get_match_prompt_binary( - doc1: str, doc2: str, user_instruction: str, strategy: str | None = None + doc1: dict[str, Any], doc2: dict[str, Any], user_instruction: str, strategy: str | None = None ) -> list[dict[str, Any]]: if strategy == "zs-cot": sys_prompt = ( @@ -29,13 +29,13 @@ def get_match_prompt_binary( "NUMBER must be either 1 or 2, depending on which document is most relevant.\n" 'You must pick a number and cannot say things like "None" or "Neither"' ) - - prompt = f"Question: {user_instruction}\n\n" + prompt = [{"type": "text", "text": f"Question: {user_instruction}\n"}] for idx, doc in enumerate([doc1, doc2]): - prompt = f"{prompt}\nDocument {idx+1}:\n{doc}\n" + content_text, content_image_inputs = task_instructions.context_formatter(doc) + prompt += [{"type": "text", "text": f"\nDocument {idx+1}:\n{content_text}"}, *content_image_inputs] - messages = [{"role": "system", "content": sys_prompt}] - messages.append({"role": "user", "content": prompt}) + messages: list[dict[str, Any]] = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": prompt}] + lotus.logger.debug(f"Prompt: {messages}") return messages @@ -56,7 +56,7 @@ def parse_ans_binary(answer: str) -> bool: def compare_batch_binary( - pairs: list[tuple[str, str]], user_instruction: str, strategy: str | None = None + pairs: list[tuple[dict[str, Any], dict[str, Any]]], user_instruction: str, strategy: str | None = None ) -> tuple[list[bool], int]: match_prompts = [] tokens = 0 @@ -69,7 +69,7 @@ def compare_batch_binary( def compare_batch_binary_cascade( - pairs: list[tuple[str, str]], + pairs: list[tuple[dict[str, Any], dict[str, Any]]], user_instruction: str, cascade_threshold: float, strategy: str | None = None, @@ -118,7 +118,7 @@ def compare_batch_binary_cascade( def llm_naive_sort( - docs: list[str], + docs: list[dict[str, Any]], user_instruction: str, strategy: str | None = None, ) -> SemanticTopKOutput: @@ -157,7 +157,7 @@ def llm_naive_sort( def llm_quicksort( - docs: list[str], + docs: list[dict[str, Any]], user_instruction: str, K: int, embedding: bool = False, @@ -168,7 +168,7 @@ def llm_quicksort( Sorts the documents using quicksort. Args: - docs (list[str]): The list of documents to sort. + docs (list[dict[str, Any]]): The list of documents to sort. user_instruction (str): The user instruction for sorting. K (int): The number of documents to return. embedding (bool): Whether to use embedding optimization. @@ -257,7 +257,7 @@ class HeapDoc: total_tokens: int = 0 strategy: str | None = None - def __init__(self, doc: str, user_instruction: str, idx: int) -> None: + def __init__(self, doc: dict[str, Any], user_instruction: str, idx: int) -> None: self.doc = doc self.user_instruction = user_instruction self.idx = idx @@ -271,7 +271,7 @@ def __lt__(self, other: "HeapDoc") -> bool: def llm_heapsort( - docs: list[str], + docs: list[dict[str, Any]], user_instruction: str, K: int, strategy: str | None = None, @@ -280,7 +280,7 @@ def llm_heapsort( Sorts the documents using a heap. Args: - docs (list[str]): The list of documents to sort. + docs (list[dict[str, Any]]): The list of documents to sort. user_instruction (str): The user instruction for sorting. K (int): The number of documents to return. @@ -380,13 +380,13 @@ def __call__( col_name, user_instruction, len(self._obj) ) - df_txt = task_instructions.df2text(self._obj, col_li) - lotus.logger.debug(f"df_txt: {df_txt}") + multimodal_data = task_instructions.df2multimodal_info(self._obj, col_li) + lotus.logger.debug(f"multimodal_data: {multimodal_data}") formatted_usr_instr = lotus.nl_expression.nle2str(user_instruction, col_li) if method in ["quick", "quick-sem"]: output = llm_quicksort( - df_txt, + multimodal_data, formatted_usr_instr, K, embedding=method == "quick-sem", @@ -394,10 +394,10 @@ def __call__( cascade_threshold=cascade_threshold, ) elif method == "heap": - output = llm_heapsort(df_txt, formatted_usr_instr, K, strategy=strategy) + output = llm_heapsort(multimodal_data, formatted_usr_instr, K, strategy=strategy) elif method == "naive": output = llm_naive_sort( - df_txt, + multimodal_data, formatted_usr_instr, strategy=strategy, ) diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py index 4fab675c..3804c32f 100644 --- a/lotus/templates/task_instructions.py +++ b/lotus/templates/task_instructions.py @@ -1,10 +1,61 @@ +from typing import Any + import pandas as pd +from lotus.dtype_extensions import ImageDtype + + +def context_formatter( + multimodal_data: dict[str, Any] | str, +) -> tuple[str, list[dict[str, str]]]: + if isinstance(multimodal_data, str): + text = multimodal_data + image_inputs: list[dict[str, str]] = [] + elif isinstance(multimodal_data, dict): + image_data: dict[str, str] = multimodal_data.get("image", {}) + _image_inputs: list[tuple[dict, dict]] = [ + ( + { + "type": "text", + "text": f"[{key.capitalize()}]: \n", + }, + { + "type": "image_url", + "image_url": {"url": base64_image}, + }, + ) + for key, base64_image in image_data.items() + ] + image_inputs = [m for image_input in _image_inputs for m in image_input] + text = multimodal_data["text"] or "" + else: + raise ValueError("multimodal_data must be a dictionary or a string") + return text, image_inputs + + +def user_message_formatter( + multimodal_data: dict[str, Any] | str, + user_instruction_with_tag: str | None = None, +) -> dict[str, Any]: + text, image_inputs = context_formatter(multimodal_data) + if not image_inputs or len(image_inputs) == 0: + return { + "role": "user", + "content": f"Context:\n{text}\n\n{user_instruction_with_tag}", + } + content = [{"type": "text", "text": f"Context:\n{text}"}] + image_inputs + if user_instruction_with_tag: + content.append({"type": "text", "text": f"\n\n{user_instruction_with_tag}"}) + return { + "role": "user", + "content": content, + } + def filter_formatter_cot( - df_text: str, + multimodal_data: dict[str, Any], user_instruction: str, - examples_df_text: list[str], + examples_multimodal_data: list[dict[str, Any]], examples_answer: list[bool], cot_reasoning: list[str], ) -> list[dict[str, str]]: @@ -17,16 +68,13 @@ def filter_formatter_cot( {"role": "system", "content": sys_instruction}, ] - for idx in range(len(examples_df_text)): - ex_df_txt = examples_df_text[idx] + for idx in range(len(examples_multimodal_data)): + ex_multimodal_data = examples_multimodal_data[idx] ex_ans = examples_answer[idx] cot = cot_reasoning[idx] messages.extend( [ - { - "role": "user", - "content": f"Context:\n{ex_df_txt}\n\nClaim: {user_instruction}", - }, + user_message_formatter(ex_multimodal_data, f"Claim: {user_instruction}"), { "role": "assistant", "content": f"Reasoning:\n{cot}\n\nAnswer: {ex_ans}", @@ -34,12 +82,12 @@ def filter_formatter_cot( ] ) - messages.append({"role": "user", "content": f"Context:\n{df_text}\n\nClaim: {user_instruction}"}) + messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}")) return messages def filter_formatter_zs_cot( - df_text: str, + multimodal_data: dict[str, Any], user_instruction: str, ) -> list[dict[str, str]]: sys_instruction = ( @@ -51,23 +99,25 @@ def filter_formatter_zs_cot( {"role": "system", "content": sys_instruction}, ] - messages.append({"role": "user", "content": f"Context:\n{df_text}\n\nClaim: {user_instruction}"}) + messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}")) return messages def filter_formatter( - df_text: str, + multimodal_data: dict[str, Any], user_instruction: str, - examples_df_text: list[str] | None = None, + examples_multimodal_data: list[dict[str, Any]] | None = None, examples_answer: list[bool] | None = None, cot_reasoning: list[str] | None = None, strategy: str | None = None, ) -> list[dict[str, str]]: if cot_reasoning: - assert examples_df_text is not None and examples_answer is not None - return filter_formatter_cot(df_text, user_instruction, examples_df_text, examples_answer, cot_reasoning) + assert examples_multimodal_data is not None and examples_answer is not None + return filter_formatter_cot( + multimodal_data, user_instruction, examples_multimodal_data, examples_answer, cot_reasoning + ) elif strategy == "zs-cot": - return filter_formatter_zs_cot(df_text, user_instruction) + return filter_formatter_zs_cot(multimodal_data, user_instruction) sys_instruction = ( "The user will provide a claim and some relevant context.\n" @@ -78,27 +128,27 @@ def filter_formatter( {"role": "system", "content": sys_instruction}, ] - if examples_df_text: + if examples_multimodal_data: assert examples_answer is not None - for ex_df_txt, ex_ans in zip(examples_df_text, examples_answer): + assert isinstance(examples_multimodal_data, list) and isinstance(examples_answer, list) + for i in range(len(examples_multimodal_data)): + ex_multimodal_data = examples_multimodal_data[i] + ex_ans = examples_answer[i] messages.extend( [ - { - "role": "user", - "content": f"Context:\n{ex_df_txt}\n\nClaim: {user_instruction}", - }, + user_message_formatter(ex_multimodal_data, f"Claim: {user_instruction}"), {"role": "assistant", "content": str(ex_ans)}, ] ) - messages.append({"role": "user", "content": f"Context:\n{df_text}\n\nClaim: {user_instruction}"}) + messages.append(user_message_formatter(multimodal_data, f"Claim: {user_instruction}")) return messages def map_formatter_cot( - df_text: str, + multimodal_data: dict[str, Any], user_instruction: str, - examples_df_text: list[str], + examples_multimodal_data: list[dict[str, Any]], examples_answer: list[str], cot_reasoning: list[str], ) -> list[dict[str, str]]: @@ -111,16 +161,13 @@ def map_formatter_cot( {"role": "system", "content": sys_instruction}, ] - for idx in range(len(examples_df_text)): - ex_df_txt = examples_df_text[idx] + for idx in range(len(examples_multimodal_data)): + ex_df_txt = examples_multimodal_data[idx] ex_ans = examples_answer[idx] cot = cot_reasoning[idx] messages.extend( [ - { - "role": "user", - "content": f"Context:\n{ex_df_txt}\nInstruction: {user_instruction}", - }, + user_message_formatter(ex_df_txt, f"Instruction: {user_instruction}"), { "role": "assistant", "content": f"Reasoning:\n{cot}\n\nAnswer: {ex_ans}", @@ -128,17 +175,12 @@ def map_formatter_cot( ] ) - messages.append( - { - "role": "user", - "content": f"Context:\n{df_text}\n\nInstruction: {user_instruction}", - } - ) + messages.append(user_message_formatter(multimodal_data, f"Instruction: {user_instruction}")) return messages def map_formatter_zs_cot( - df_text: str, + multimodal_data: dict[str, Any], user_instruction: str, ) -> list[dict[str, str]]: sys_instruction = ( @@ -150,28 +192,25 @@ def map_formatter_zs_cot( {"role": "system", "content": sys_instruction}, ] - messages.append( - { - "role": "user", - "content": f"Context:\n{df_text}\nInstruction: {user_instruction}", - } - ) + messages.append(user_message_formatter(multimodal_data, f"Instruction: {user_instruction}")) return messages def map_formatter( - df_text: str, + multimodal_data: dict[str, Any], user_instruction: str, - examples_df_text: list[str] | None = None, + examples_multimodal_data: list[dict[str, Any]] | None = None, examples_answer: list[str] | None = None, cot_reasoning: list[str] | None = None, strategy: str | None = None, ) -> list[dict[str, str]]: if cot_reasoning: - assert examples_df_text is not None and examples_answer is not None - return map_formatter_cot(df_text, user_instruction, examples_df_text, examples_answer, cot_reasoning) + assert examples_multimodal_data is not None and examples_answer is not None + return map_formatter_cot( + multimodal_data, user_instruction, examples_multimodal_data, examples_answer, cot_reasoning + ) elif strategy == "zs-cot": - return map_formatter_zs_cot(df_text, user_instruction) + return map_formatter_zs_cot(multimodal_data, user_instruction) sys_instruction = ( "The user will provide an instruction and some relevant context.\n" @@ -181,30 +220,22 @@ def map_formatter( {"role": "system", "content": sys_instruction}, ] - if examples_df_text: + if examples_multimodal_data: assert examples_answer is not None - for ex_df_txt, ex_ans in zip(examples_df_text, examples_answer): + for ex_df_txt, ex_ans in zip(examples_multimodal_data, examples_answer): messages.extend( [ - { - "role": "user", - "content": f"Context:\n{ex_df_txt}\n\nInstruction: {user_instruction}", - }, + user_message_formatter(ex_df_txt, f"Instruction: {user_instruction}"), {"role": "assistant", "content": str(ex_ans)}, ] ) - messages.append( - { - "role": "user", - "content": f"Context:\n{df_text}\n\nInstruction: {user_instruction}", - } - ) + messages.append(user_message_formatter(multimodal_data, f"Instruction: {user_instruction}")) return messages def extract_formatter( - df_text: str, output_cols: dict[str, str | None], extract_quotes: bool = True + multimodal_data: dict[str, Any], output_cols: dict[str, str | None], extract_quotes: bool = True ) -> list[dict[str, str]]: output_col_names = list(output_cols.keys()) # Set the description to be the key if no value is provided @@ -227,10 +258,7 @@ def extract_formatter( messages = [ {"role": "system", "content": sys_instruction}, - { - "role": "user", - "content": f"Context:\n{df_text}", - }, + user_message_formatter(multimodal_data), ] return messages @@ -244,9 +272,52 @@ def format_row(x: pd.Series, cols: list[str]) -> str: # take cols that are in df cols = [col for col in cols if col in df.columns] + if len(cols) == 0: + return [""] * len(df) formatted_rows: list[str] = df.apply(lambda x: format_row(x, cols), axis=1).tolist() return formatted_rows +def df2multimodal_info(df: pd.DataFrame, cols: list[str]) -> list[dict[str, Any]]: + """ + Formats the given DataFrame into a string containing info from cols. + Return a list of dictionaries, each containing text and image data. + """ + image_cols = [col for col in cols if isinstance(df[col].dtype, ImageDtype)] + text_cols = [col for col in cols if col not in image_cols] + text_rows = df2text(df, text_cols) + multimodal_data = [ + { + "text": text_rows[i], + "image": {col.capitalize(): df[col].array.get_image(i, "base64") for col in image_cols}, + } + for i in range(len(df)) + ] + return multimodal_data + + +def merge_multimodal_info(first: list[dict[str, Any]], second: list[dict[str, Any]]) -> list[dict[str, Any]]: + """ + Merges two multimodal info lists into one. Each row of first is merged with each row of second. + + Args: + first: list of multimodal info dictionaries + second: list of multimodal info dictionaries + + Returns: + list of merged multimodal info dictionaries + """ + return [ + { + "text": f"{first[i]['text']}\n{second[j]['text']}" + if first[i]["text"] != "" and second[j]["text"] != "" + else first[i]["text"] + second[j]["text"], + "image": {**first[i]["image"], **second[j]["image"]}, + } + for i in range(len(first)) + for j in range(len(second)) + ] + + def li2text(li: list[str], name: str) -> str: return "".join([f"[{name}] {li[i]}\n" for i in range(len(li))]) diff --git a/lotus/utils.py b/lotus/utils.py index b53b8cce..1f86347c 100644 --- a/lotus/utils.py +++ b/lotus/utils.py @@ -1,6 +1,11 @@ +import base64 +from io import BytesIO from typing import Callable +import numpy as np import pandas as pd +import requests # type: ignore +from PIL import Image import lotus @@ -53,3 +58,51 @@ def ret( return list(map(int, indices.flatten().tolist())) return ret + + +def fetch_image(image: str | np.ndarray | Image.Image | None, image_type: str = "Image") -> Image.Image | str | None: + if image is None: + return None + + image_obj = None + if isinstance(image, Image.Image): + image_obj = image + elif isinstance(image, np.ndarray): + image_obj = Image.fromarray(image.astype("uint8")) + elif image.startswith("http://") or image.startswith("https://"): + image_obj = Image.open(requests.get(image, stream=True).raw) + elif image.startswith("file://"): + image_obj = Image.open(image[7:]) + elif image.startswith("data:image"): + if "base64," in image: + _, base64_data = image.split("base64,", 1) + data = base64.b64decode(base64_data) + image_obj = Image.open(BytesIO(data)) + elif image.startswith("s3://"): + from botocore.exceptions import NoCredentialsError, PartialCredentialsError + + try: + import boto3 + + s3 = boto3.client("s3") + bucket_name, key = image[5:].split("/", 1) # Split after "s3://" + response = s3.get_object(Bucket=bucket_name, Key=key) + image_data = response["Body"].read() + image_obj = Image.open(BytesIO(image_data)) + except (NoCredentialsError, PartialCredentialsError) as e: + raise ValueError("AWS credentials not found or incomplete.") from e + except Exception as e: + raise ValueError(f"Failed to fetch image from S3: {e}") from e + else: + image_obj = Image.open(image) + if image_obj is None: + raise ValueError( + f"Unrecognized image input, support local path, http url, base64, S3, and PIL.Image, got {image}" + ) + image_obj = image_obj.convert("RGB") + if image_type == "base64": + buffered = BytesIO() + image_obj.save(buffered, format="PNG") + return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode("utf-8") + + return image_obj