-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
3d1cca8
commit da28054
Showing
5 changed files
with
189 additions
and
154 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import os | ||
|
||
import pandas as pd | ||
import pytest | ||
|
||
import lotus | ||
from lotus.models import LM | ||
from lotus.dtype_extensions import ImageArray | ||
|
||
################################################################################ | ||
# Setup | ||
################################################################################ | ||
# Set logger level to DEBUG | ||
lotus.logger.setLevel("DEBUG") | ||
|
||
# Environment flags to enable/disable tests | ||
ENABLE_OPENAI_TESTS = os.getenv("ENABLE_OPENAI_TESTS", "false").lower() == "true" | ||
|
||
MODEL_NAME_TO_ENABLED = { | ||
"gpt-4o-mini": ENABLE_OPENAI_TESTS, | ||
"gpt-4o": ENABLE_OPENAI_TESTS, | ||
} | ||
ENABLED_MODEL_NAMES = set([model_name for model_name, is_enabled in MODEL_NAME_TO_ENABLED.items() if is_enabled]) | ||
|
||
|
||
def get_enabled(*candidate_models: str) -> list[str]:
    """Return the candidate model names that are currently enabled.

    Candidates are kept in the order they were passed; any name that is not a
    member of ``ENABLED_MODEL_NAMES`` is dropped.
    """
    enabled: list[str] = []
    for name in candidate_models:
        if name in ENABLED_MODEL_NAMES:
            enabled.append(name)
    return enabled
|
||
|
||
@pytest.fixture(scope="session")
def setup_models():
    """Create one ``LM`` per enabled model name, shared across the test session.

    Returns:
        A dict mapping each name in ``ENABLED_MODEL_NAMES`` to ``LM(model=name)``.
    """
    return {name: LM(model=name) for name in ENABLED_MODEL_NAMES}
|
||
|
||
@pytest.fixture(autouse=True)
def print_usage_after_each_test(setup_models):
    """Print, then reset, the usage statistics of every model after each test."""
    # No setup is needed; the test body executes while the fixture is suspended here.
    yield
    for name in setup_models:
        lm = setup_models[name]
        print(f"\nUsage stats for {name} after test:")
        lm.print_total_usage()
        lm.reset_stats()
|
||
|
||
################################################################################ | ||
# Standard tests | ||
################################################################################ | ||
@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
def test_filter_operation(setup_models, model):
    """Check that a semantic filter over image URLs keeps the nachos image."""
    lotus.settings.configure(lm=setup_models[model])

    # The second URL (index 1) is the one the filter is expected to keep.
    nachos_url = "https://thumbs.dreamstime.com/b/comida-r%C3%A1pida-nachos-con-el-sause-del-tomate-ejemplo-exhausto-de-la-acuarela-mano-aislado-en-blanco-150936354.jpg"
    urls = [
        "https://img.etsystatic.com/il/4bee20/1469037676/il_340x270.1469037676_iiti.jpg?version=0",
        nachos_url,
        "https://i1.wp.com/www.alloverthemap.net/wp-content/uploads/2014/02/2012-09-25-12.46.15.jpg?resize=400%2C284&ssl=1",
        "https://i.pinimg.com/236x/a4/3a/65/a43a65683a0314f29b66402cebdcf46d.jpg",
        "https://pravme.ru/wp-content/uploads/2018/01/sobor-Bogord-1.jpg",
    ]

    frame = pd.DataFrame({"image": ImageArray(urls)})
    kept = frame.sem_filter("{image} represents food")

    assert nachos_url in kept["image"].values
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,172 +1,101 @@ | ||
from pandas.api.extensions import ExtensionDtype, ExtensionArray | ||
import io | ||
import sys | ||
from typing import List, Optional, Union, Sequence | ||
import numpy as np | ||
from pandas.api.extensions import ExtensionArray, ExtensionDtype | ||
from PIL import Image | ||
import io | ||
import numpy as np | ||
|
||
from lotus.utils import fetch_image | ||
|
||
class ImageDtype(ExtensionDtype):
    """Pandas extension dtype for columns backed by ``ImageArray``."""

    # String identifier of this dtype.
    name = "image"
    # Scalar type reported for individual elements.
    type = Image.Image
    # Value used to represent a missing element.
    na_value = None

    @classmethod
    def construct_array_type(cls):
        """Return the array class associated with this dtype."""
        return ImageArray
|
||
|
||
class ImageArray(ExtensionArray): | ||
"""ExtensionArray for storing Images.""" | ||
|
||
def __init__(self, values): | ||
"""Initialize array with validation.""" | ||
values = self._validate_values(values) | ||
self._data = np.array(values, dtype=object) | ||
def __init__(self, values: Sequence[Optional[Image.Image]]): | ||
self._data = values | ||
self._dtype = ImageDtype() | ||
|
||
@staticmethod | ||
def _validate_values(values): | ||
"""Validate that all values are Images or None.""" | ||
if isinstance(values, (ImageArray, np.ndarray)): | ||
values = values.tolist() | ||
|
||
validated = [] | ||
for i, val in enumerate(values): | ||
if val is None: | ||
validated.append(None) | ||
elif isinstance(val, Image.Image): | ||
validated.append(val) | ||
else: | ||
raise TypeError( | ||
f"Value at index {i} has type {type(val).__name__}. " | ||
"Expected .Image.Image or None." | ||
) | ||
return validated | ||
|
||
|
||
def __getitem__(self, item: Union[int, slice, Sequence[int]]) -> Union[Image.Image, 'ImageArray']:
    """Return one element, or a new ``ImageArray`` with the selected elements.

    Args:
        item: An integer position, a slice, a boolean mask with one entry per
            element, or a sequence of integer positions.

    Returns:
        For an integer position, the stored value passed through
        ``fetch_image``. For every other selector, a new ``ImageArray`` over
        the selected stored values (these are not passed through
        ``fetch_image``).

    Raises:
        IndexError: If a boolean mask does not have one entry per element.
    """
    if isinstance(item, (int, np.integer)):
        return fetch_image(self._data[item])
    if isinstance(item, slice):
        return ImageArray(self._data[item])

    # Boolean masks (e.g. from ``series[mask]``) select by truth value; iterating
    # them as positions would pick elements 0/1 or fail on numpy booleans.
    selector = np.asarray(item)
    if selector.dtype == np.bool_:
        if len(selector) != len(self._data):
            raise IndexError(
                f"Boolean index has wrong length: {len(selector)} instead of {len(self._data)}"
            )
        return ImageArray([value for value, keep in zip(self._data, selector) if keep])

    return ImageArray([self._data[i] for i in item])
|
||
def __setitem__(self, key: Union[int, slice, Sequence[int]], value: Union[Image.Image, Sequence[Image.Image]]):
    """Assign one value, or several values, in place.

    Args:
        key: An integer position, a slice, a boolean mask, or a sequence of
            integer positions.
        value: A single value, or an iterable supplying one value per selected
            position. A string is always treated as a single value, and a
            single value is assigned to every selected position.

    Raises:
        IndexError: If an iterable ``value`` supplies fewer entries than there
            are selected positions.
    """
    if isinstance(key, (int, np.integer)):
        self._data[key] = value
        return

    # Normalise every multi-element selector to a list of integer positions.
    if isinstance(key, slice):
        positions = list(range(*key.indices(len(self._data))))
    else:
        selector = np.asarray(key)
        if selector.dtype == np.bool_:
            positions = list(np.flatnonzero(selector))
        else:
            positions = list(key)

    # Strings are iterable but are single elements here; without this check a
    # string would be written one character per position.
    if hasattr(value, '__iter__') and not isinstance(value, str):
        replacements = list(value)
    else:
        replacements = [value] * len(positions)

    for i, position in enumerate(positions):
        self._data[position] = replacements[i]
|
||
def isna(self) -> np.ndarray:
    """Return a boolean array that is True wherever the stored value is ``None``."""
    missing_flags = (img is None for img in self._data)
    return np.fromiter(missing_flags, dtype=bool)
|
||
def take(self, indexer: Sequence[int], allow_fill: bool = False, fill_value: Optional[Image.Image] = None) -> 'ImageArray':
    """Build a new ``ImageArray`` from the elements at the given positions.

    Args:
        indexer: Positions to take.
        allow_fill: When True, every negative position yields ``fill_value``
            instead of an element; when False, positions are used to index
            the stored data directly.
        fill_value: Value used for negative positions when ``allow_fill`` is
            True; ``None`` falls back to the dtype's ``na_value``.
    """
    if not allow_fill:
        return ImageArray([self._data[idx] for idx in indexer])

    if fill_value is None:
        fill_value = self.dtype.na_value

    picked = []
    for idx in indexer:
        picked.append(self._data[idx] if idx >= 0 else fill_value)
    return ImageArray(picked)
|
||
def copy(self) -> 'ImageArray':
    """Return a new ``ImageArray`` with each element copied where possible.

    Truthy elements that expose a ``copy`` attribute are duplicated via
    ``.copy()``; everything else is reused as-is.
    """
    duplicates = []
    for img in self._data:
        if img and hasattr(img, "copy"):
            duplicates.append(img.copy())
        else:
            duplicates.append(img)
    return ImageArray(duplicates)
|
||
@classmethod | ||
def _from_sequence(cls, scalars, dtype=None, copy=False): | ||
"""Create ImageArray from sequence of scalars.""" | ||
scalars = cls._validate_values(scalars) | ||
if copy: | ||
scalars = [img.copy() if img is not None else None for img in scalars] | ||
return cls(scalars) | ||
|
||
return cls([img.copy() if img and copy and hasattr(img, 'copy') else img for img in scalars]) | ||
|
||
@classmethod
def _from_factorized(cls, values, original):
    """Rebuild an ImageArray from the unique values produced by factorization.

    Args:
        values: The unique values found while factorizing.
        original: The array that was factorized; not used for the contents of
            the result.

    Returns:
        A new array holding exactly ``values``. Returning ``original`` here
        would hand the caller every element (duplicates included) as the
        "uniques".
    """
    return cls(list(values))
|
||
def __getitem__(self, item): | ||
"""Get item(s) from array.""" | ||
result = self._data[item] | ||
if isinstance(item, (int, np.integer)): | ||
return result | ||
return ImageArray(result) | ||
|
||
def __len__(self): | ||
"""Length of array.""" | ||
|
||
@classmethod
def _concat_same_type(cls, to_concat: Sequence['ImageArray']) -> 'ImageArray':
    """Concatenate several arrays into one, preserving element order."""
    merged = []
    for array in to_concat:
        merged.extend(array._data)
    return cls(merged)
|
||
def __len__(self) -> int:
    """Return the number of stored elements."""
    size = len(self._data)
    return size
def __eq__(self, other): | ||
"""Equality comparison.""" | ||
|
||
def __eq__(self, other) -> np.ndarray: | ||
# check if other is iterable | ||
if isinstance(other, ImageArray): | ||
return np.array([ | ||
_compare_images(img1, img2) | ||
for img1, img2 in zip(self._data, other._data) | ||
]) | ||
elif isinstance(other, (Image.Image, type(None))): | ||
return np.array([ | ||
_compare_images(img, other) | ||
for img in self._data | ||
]) | ||
return NotImplemented | ||
|
||
def __setitem__(self, key, value): | ||
"""Set item(s) in array with validation.""" | ||
if isinstance(key, (int, np.integer)): | ||
if not (isinstance(value, Image.Image) or value is None): | ||
raise TypeError( | ||
f"Cannot set value of type {type(value).__name__}. " | ||
"Expected Image.Image or None." | ||
) | ||
self._data[key] = value | ||
else: | ||
value = self._validate_values(value) | ||
self._data[key] = value | ||
|
||
return np.array([_compare_images(img1, img2) for img1, img2 in zip(self._data, other._data)], dtype=bool) | ||
if hasattr(other, '__iter__') and not isinstance(other, str): | ||
return np.array([_compare_images(img1, img2) for img1, img2 in zip(self._data, other)], dtype=bool) | ||
return np.array([_compare_images(img, other) for img in self._data], dtype=bool) | ||
|
||
@property | ||
def dtype(self): | ||
"""Return the dtype object.""" | ||
def dtype(self) -> ImageDtype: | ||
return self._dtype | ||
|
||
@property | ||
def nbytes(self): | ||
"""Return number of bytes in memory.""" | ||
return sum( | ||
len(img_to_bytes(img)) if img is not None else 0 | ||
for img in self._data | ||
) | ||
|
||
def isna(self): | ||
"""Return boolean array indicating missing values.""" | ||
return np.array([img is None for img in self._data]) | ||
|
||
def take(self, indexer, allow_fill=False, fill_value=None): | ||
"""Take elements from array.""" | ||
if allow_fill: | ||
if fill_value is not None and not (isinstance(fill_value, Image.Image) or fill_value is None): | ||
raise TypeError( | ||
f"Fill value must be Image.Image or None, not {type(fill_value).__name__}" | ||
) | ||
if fill_value is None: | ||
fill_value = self.dtype.na_value | ||
|
||
result = np.array([ | ||
self._data[idx] if idx >= 0 else fill_value | ||
for idx in indexer | ||
]) | ||
else: | ||
result = self._data.take(indexer) | ||
|
||
return ImageArray(result) | ||
|
||
def copy(self): | ||
"""Return deep copy of array.""" | ||
return ImageArray([ | ||
img.copy() if img is not None else None | ||
for img in self._data | ||
]) | ||
|
||
@classmethod | ||
def _concat_same_type(cls, to_concat): | ||
"""Concatenate multiple arrays.""" | ||
return cls(np.concatenate([array._data for array in to_concat])) | ||
|
||
def interpolate(self, method='linear', axis=0, limit=None, inplace=False, | ||
limit_direction=None, limit_area=None, downcast=None, **kwargs): | ||
"""Interpolate missing values.""" | ||
return self.copy() if not inplace else self | ||
|
||
def _compare_images(img1, img2): | ||
"""Compare two Images for equality.""" | ||
if img1 is None and img2 is None: | ||
return True | ||
def nbytes(self) -> int: | ||
return sum(sys.getsizeof(img) for img in self._data if img) | ||
|
||
def __repr__(self) -> str:
    """Return a short preview of the array.

    At most the first five elements are shown, each as ``<Image>`` or
    ``None``. A trailing ``...`` is added only when elements were actually
    left out, so short and empty arrays are not shown as truncated.
    """
    preview_limit = 5
    shown = ['<Image>' if img is not None else 'None' for img in self._data[:preview_limit]]
    if len(self._data) > preview_limit:
        shown.append('...')
    return f"ImageArray([{', '.join(shown)}])"
|
||
def _formatter(self, boxed: bool = False):
    """Return the callable used to render a single element as text."""

    def render(x):
        if x is None:
            return 'None'
        return '<Image>'

    return render
|
||
|
||
def _compare_images(img1: Optional[Image.Image], img2: Optional[Image.Image]) -> bool: | ||
if img1 is None or img2 is None: | ||
return False | ||
if img1.size != img2.size: | ||
return False | ||
if img1.mode != img2.mode: | ||
return False | ||
return np.array_equal(np.array(img1), np.array(img2)) | ||
|
||
def img_to_bytes(img): | ||
"""Convert Image to bytes.""" | ||
if img is None: | ||
return None | ||
buf = io.BytesIO() | ||
img.save(buf, format=img.format or 'PNG') | ||
return buf.getvalue() | ||
|
||
def bytes_to_img(b): | ||
"""Convert bytes to Image.""" | ||
if b is None: | ||
return None | ||
return Image.open(io.BytesIO(b)) | ||
return img1 is img2 | ||
if isinstance(img1, Image.Image) or isinstance(img2, Image.Image): | ||
img1 = fetch_image(img1) | ||
img2 = fetch_image(img2) | ||
return (img1.size == img2.size and img1.mode == img2.mode and img1.tobytes() == img2.tobytes()) | ||
else: | ||
return img1 == img2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.