diff --git a/.github/tests/multimodality_tests.py b/.github/tests/multimodality_tests.py new file mode 100644 index 00000000..390718ac --- /dev/null +++ b/.github/tests/multimodality_tests.py @@ -0,0 +1,70 @@ +import os + +import pandas as pd +import pytest + +import lotus +from lotus.models import LM +from lotus.dtype_extensions import ImageArray + +################################################################################ +# Setup +################################################################################ +# Set logger level to DEBUG +lotus.logger.setLevel("DEBUG") + +# Environment flags to enable/disable tests +ENABLE_OPENAI_TESTS = os.getenv("ENABLE_OPENAI_TESTS", "false").lower() == "true" + +MODEL_NAME_TO_ENABLED = { + "gpt-4o-mini": ENABLE_OPENAI_TESTS, + "gpt-4o": ENABLE_OPENAI_TESTS, +} +ENABLED_MODEL_NAMES = set([model_name for model_name, is_enabled in MODEL_NAME_TO_ENABLED.items() if is_enabled]) + + +def get_enabled(*candidate_models: str) -> list[str]: + return [model for model in candidate_models if model in ENABLED_MODEL_NAMES] + + +@pytest.fixture(scope="session") +def setup_models(): + models = {} + + for model_path in ENABLED_MODEL_NAMES: + models[model_path] = LM(model=model_path) + + return models + + +@pytest.fixture(autouse=True) +def print_usage_after_each_test(setup_models): + yield # this runs the test + models = setup_models + for model_name, model in models.items(): + print(f"\nUsage stats for {model_name} after test:") + model.print_total_usage() + model.reset_stats() + + +################################################################################ +# Standard tests +################################################################################ +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) +def test_filter_operation(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) + + # Test filter operation on an easy dataframe + image_url = [ + "https://img.etsystatic.com/il/4bee20/1469037676/il_340x270.1469037676_iiti.jpg?version=0", + "https://thumbs.dreamstime.com/b/comida-r%C3%A1pida-nachos-con-el-sause-del-tomate-ejemplo-exhausto-de-la-acuarela-mano-aislado-en-blanco-150936354.jpg", + "https://i1.wp.com/www.alloverthemap.net/wp-content/uploads/2014/02/2012-09-25-12.46.15.jpg?resize=400%2C284&ssl=1", + "https://i.pinimg.com/236x/a4/3a/65/a43a65683a0314f29b66402cebdcf46d.jpg", + "https://pravme.ru/wp-content/uploads/2018/01/sobor-Bogord-1.jpg" + ] + df = pd.DataFrame({"image": ImageArray(image_url)}) + user_instruction = "{image} represents food" + filtered_df = df.sem_filter(user_instruction) + + assert image_url[1] in filtered_df["image"].values \ No newline at end of file diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 07a9f3ea..29d6cd9c 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -158,3 +158,33 @@ jobs: ENABLE_OPENAI_TESTS: true ENABLE_LOCAL_TESTS: true run: pytest .github/tests/rm_tests.py + + multimodal_test: + name: Multimodality Tests + runs-on: ubuntu-latest + timeout-minutes: 5 + + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install -e . + pip install pytest + + - name: Set OpenAI API Key + run: echo "OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> $GITHUB_ENV + + - name: Run Multimodality tests + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + ENABLE_OPENAI_TESTS: true + run: pytest .github/tests/multimodality_tests.py \ No newline at end of file diff --git a/lotus/dtype_extensions/image.py b/lotus/dtype_extensions/image.py index 46b3210b..410930d0 100644 --- a/lotus/dtype_extensions/image.py +++ b/lotus/dtype_extensions/image.py @@ -1,172 +1,101 @@ -from pandas.api.extensions import ExtensionDtype, ExtensionArray +import io +import sys +from typing import List, Optional, Union, Sequence import numpy as np +from pandas.api.extensions import ExtensionArray, ExtensionDtype from PIL import Image -import io +import numpy as np + +from lotus.utils import fetch_image class ImageDtype(ExtensionDtype): - """Custom dtype for Images.""" name = 'image' type = Image.Image na_value = None - + @classmethod def construct_array_type(cls): return ImageArray + class ImageArray(ExtensionArray): - """ExtensionArray for storing Images.""" - - def __init__(self, values): - """Initialize array with validation.""" - values = self._validate_values(values) - self._data = np.array(values, dtype=object) + def __init__(self, values: Sequence[Optional[Image.Image]]): + self._data = values self._dtype = ImageDtype() - - @staticmethod - def _validate_values(values): - """Validate that all values are Images or None.""" - if isinstance(values, (ImageArray, np.ndarray)): - values = values.tolist() - - validated = [] - for i, val in enumerate(values): - if val is None: - validated.append(None) - elif isinstance(val, Image.Image): - validated.append(val) - else: - raise TypeError( - f"Value at index {i} has type {type(val).__name__}. " - "Expected .Image.Image or None." - ) - return validated - + + def __getitem__(self, item: Union[int, slice, Sequence[int]]) -> Union[Image.Image, 'ImageArray']: + if isinstance(item, (int, np.integer)): + return fetch_image(self._data[item]) + if isinstance(item, slice): + return ImageArray(self._data[item]) + return ImageArray([self._data[i] for i in item]) + + def __setitem__(self, key: Union[int, slice, Sequence[int]], value: Union[Image.Image, Sequence[Image.Image]]): + if isinstance(key, (int, np.integer)): + self._data[key] = value + else: + for i, k in enumerate(key): + self._data[k] = value[i] + + def isna(self) -> np.ndarray: + return np.array([img is None for img in self._data], dtype=bool) + + def take(self, indexer: Sequence[int], allow_fill: bool = False, fill_value: Optional[Image.Image] = None) -> 'ImageArray': + if allow_fill: + fill_value = fill_value if fill_value is not None else self.dtype.na_value + result = [self._data[idx] if idx >= 0 else fill_value for idx in indexer] + else: + result = [self._data[idx] for idx in indexer] + return ImageArray(result) + + def copy(self) -> 'ImageArray': + return ImageArray([img.copy() if img and hasattr(img, "copy") else img for img in self._data]) + @classmethod def _from_sequence(cls, scalars, dtype=None, copy=False): - """Create ImageArray from sequence of scalars.""" - scalars = cls._validate_values(scalars) - if copy: - scalars = [img.copy() if img is not None else None for img in scalars] - return cls(scalars) - + return cls([img.copy() if img and copy and hasattr(img, 'copy') else img for img in scalars]) + @classmethod def _from_factorized(cls, values, original): - """Create ImageArray from factorized values.""" return original - - def __getitem__(self, item): - """Get item(s) from array.""" - result = self._data[item] - if isinstance(item, (int, np.integer)): - return result - return ImageArray(result) - - def __len__(self): - """Length of array.""" + + @classmethod + def _concat_same_type(cls, to_concat: Sequence['ImageArray']) -> 'ImageArray': + combined = [img for array in to_concat for img in array._data] + return cls(combined) + + def __len__(self) -> int: return len(self._data) - - def __eq__(self, other): - """Equality comparison.""" + + def __eq__(self, other) -> np.ndarray: + # check if other is iterable if isinstance(other, ImageArray): - return np.array([ - _compare_images(img1, img2) - for img1, img2 in zip(self._data, other._data) - ]) - elif isinstance(other, (Image.Image, type(None))): - return np.array([ - _compare_images(img, other) - for img in self._data - ]) - return NotImplemented - - def __setitem__(self, key, value): - """Set item(s) in array with validation.""" - if isinstance(key, (int, np.integer)): - if not (isinstance(value, Image.Image) or value is None): - raise TypeError( - f"Cannot set value of type {type(value).__name__}. " - "Expected Image.Image or None." - ) - self._data[key] = value - else: - value = self._validate_values(value) - self._data[key] = value - + return np.array([_compare_images(img1, img2) for img1, img2 in zip(self._data, other._data)], dtype=bool) + if hasattr(other, '__iter__') and not isinstance(other, str): + return np.array([_compare_images(img1, img2) for img1, img2 in zip(self._data, other)], dtype=bool) + return np.array([_compare_images(img, other) for img in self._data], dtype=bool) + @property - def dtype(self): - """Return the dtype object.""" + def dtype(self) -> ImageDtype: return self._dtype - + @property - def nbytes(self): - """Return number of bytes in memory.""" - return sum( - len(img_to_bytes(img)) if img is not None else 0 - for img in self._data - ) - - def isna(self): - """Return boolean array indicating missing values.""" - return np.array([img is None for img in self._data]) - - def take(self, indexer, allow_fill=False, fill_value=None): - """Take elements from array.""" - if allow_fill: - if fill_value is not None and not (isinstance(fill_value, Image.Image) or fill_value is None): - raise TypeError( - f"Fill value must be Image.Image or None, not {type(fill_value).__name__}" - ) - if fill_value is None: - fill_value = self.dtype.na_value - - result = np.array([ - self._data[idx] if idx >= 0 else fill_value - for idx in indexer - ]) - else: - result = self._data.take(indexer) - - return ImageArray(result) - - def copy(self): - """Return deep copy of array.""" - return ImageArray([ - img.copy() if img is not None else None - for img in self._data - ]) - - @classmethod - def _concat_same_type(cls, to_concat): - """Concatenate multiple arrays.""" - return cls(np.concatenate([array._data for array in to_concat])) - - def interpolate(self, method='linear', axis=0, limit=None, inplace=False, - limit_direction=None, limit_area=None, downcast=None, **kwargs): - """Interpolate missing values.""" - return self.copy() if not inplace else self - -def _compare_images(img1, img2): - """Compare two Images for equality.""" - if img1 is None and img2 is None: - return True + def nbytes(self) -> int: + return sum(sys.getsizeof(img) for img in self._data if img) + + def __repr__(self) -> str: + return f"ImageArray([{', '.join(['' if img is not None else 'None' for img in self._data[:5]])}, ...])" + + def _formatter(self, boxed: bool = False): + return lambda x: '' if x is not None else 'None' + + +def _compare_images(img1: Optional[Image.Image], img2: Optional[Image.Image]) -> bool: if img1 is None or img2 is None: - return False - if img1.size != img2.size: - return False - if img1.mode != img2.mode: - return False - return np.array_equal(np.array(img1), np.array(img2)) - -def img_to_bytes(img): - """Convert Image to bytes.""" - if img is None: - return None - buf = io.BytesIO() - img.save(buf, format=img.format or 'PNG') - return buf.getvalue() - -def bytes_to_img(b): - """Convert bytes to Image.""" - if b is None: - return None - return Image.open(io.BytesIO(b)) \ No newline at end of file + return img1 is img2 + if isinstance(img1, Image.Image) or isinstance(img2, Image.Image): + img1 = fetch_image(img1) + img2 = fetch_image(img2) + return (img1.size == img2.size and img1.mode == img2.mode and img1.tobytes() == img2.tobytes()) + else: + return img1 == img2 diff --git a/lotus/templates/task_instructions.py b/lotus/templates/task_instructions.py index f8afa95f..c889c799 100644 --- a/lotus/templates/task_instructions.py +++ b/lotus/templates/task_instructions.py @@ -10,7 +10,7 @@ def filter_user_message_formatter( text = multimodal_data image_inputs = [] - if isinstance(multimodal_data, list): + if isinstance(multimodal_data, dict): _image_inputs = [ [{ "type": "text", @@ -19,13 +19,13 @@ def filter_user_message_formatter( { "type": "image_url", "image_url": { - "url": f"data:image/jpeg;base64,{base64_image}" + "url": base64_image }, }] for key, base64_image in multimodal_data["image"].items() ] image_inputs = [m for image_input in _image_inputs for m in image_input] - text = multimodal_data["text"] + text = multimodal_data["text"] or "" return { "role": "user", @@ -261,6 +261,8 @@ def format_row(x: pd.Series, cols: list[str]) -> str: # take cols that are in df cols = [col for col in cols if col in df.columns] + if len(cols) == 0: + return [""]*len(df) formatted_rows: list[str] = df.apply(lambda x: format_row(x, cols), axis=1).tolist() return formatted_rows diff --git a/lotus/utils.py b/lotus/utils.py index 6d956016..7e9ce345 100644 --- a/lotus/utils.py +++ b/lotus/utils.py @@ -57,7 +57,7 @@ def ret( return ret -def fetch_image(image: str | Image.Image, size_factor: int = 28, image_type: str = "Image") -> Image.Image: +def fetch_image(image: str | Image.Image | None, size_factor: int = 28, image_type: str = "Image") -> Image.Image | None: """ Fetches an image from the internet or loads it from a file. @@ -69,13 +69,17 @@ def fetch_image(image: str | Image.Image, size_factor: int = 28, image_type: str Returns: Image.Image: The image. """ + + if image is None: + return None + assert image_type in ["Image", "base64"], f"image_type must be Image or base64, got {image_type}" - + image = qwen_vl_utils.fetch_image({"image": image}, size_factor) if image_type == "base64": buffered = BytesIO() image.save(buffered, format="PNG") - return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode() + return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode("utf-8") return image