Skip to content

Commit

Permalink
lazy loading of images
Browse files Browse the repository at this point in the history
  • Loading branch information
harshitgupta412 committed Nov 13, 2024
1 parent 3d1cca8 commit da28054
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 154 deletions.
70 changes: 70 additions & 0 deletions .github/tests/multimodality_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import os

import pandas as pd
import pytest

import lotus
from lotus.models import LM
from lotus.dtype_extensions import ImageArray

################################################################################
# Setup
################################################################################
# Set logger level to DEBUG
lotus.logger.setLevel("DEBUG")

# Environment flags to enable/disable tests
ENABLE_OPENAI_TESTS = os.getenv("ENABLE_OPENAI_TESTS", "false").lower() == "true"

MODEL_NAME_TO_ENABLED = {
"gpt-4o-mini": ENABLE_OPENAI_TESTS,
"gpt-4o": ENABLE_OPENAI_TESTS,
}
ENABLED_MODEL_NAMES = set([model_name for model_name, is_enabled in MODEL_NAME_TO_ENABLED.items() if is_enabled])


def get_enabled(*candidate_models: str) -> list[str]:
return [model for model in candidate_models if model in ENABLED_MODEL_NAMES]


@pytest.fixture(scope="session")
def setup_models():
models = {}

for model_path in ENABLED_MODEL_NAMES:
models[model_path] = LM(model=model_path)

return models


@pytest.fixture(autouse=True)
def print_usage_after_each_test(setup_models):
yield # this runs the test
models = setup_models
for model_name, model in models.items():
print(f"\nUsage stats for {model_name} after test:")
model.print_total_usage()
model.reset_stats()


################################################################################
# Standard tests
################################################################################
@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
def test_filter_operation(setup_models, model):
lm = setup_models[model]
lotus.settings.configure(lm=lm)

# Test filter operation on an easy dataframe
image_url = [
"https://img.etsystatic.com/il/4bee20/1469037676/il_340x270.1469037676_iiti.jpg?version=0",
"https://thumbs.dreamstime.com/b/comida-r%C3%A1pida-nachos-con-el-sause-del-tomate-ejemplo-exhausto-de-la-acuarela-mano-aislado-en-blanco-150936354.jpg",
"https://i1.wp.com/www.alloverthemap.net/wp-content/uploads/2014/02/2012-09-25-12.46.15.jpg?resize=400%2C284&ssl=1",
"https://i.pinimg.com/236x/a4/3a/65/a43a65683a0314f29b66402cebdcf46d.jpg",
"https://pravme.ru/wp-content/uploads/2018/01/sobor-Bogord-1.jpg"
]
df = pd.DataFrame({"image": ImageArray(image_url)})
user_instruction = "{image} represents food"
filtered_df = df.sem_filter(user_instruction)

assert image_url[1] in filtered_df["image"].values
30 changes: 30 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,33 @@ jobs:
ENABLE_OPENAI_TESTS: true
ENABLE_LOCAL_TESTS: true
run: pytest .github/tests/rm_tests.py

multimodal_test:
name: Multimodality Tests
runs-on: ubuntu-latest
timeout-minutes: 5

steps:
- name: Checkout code
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.10'

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -e .
pip install pytest
- name: Set OpenAI API Key
run: echo "OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> $GITHUB_ENV

- name: Run Multimodality tests
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
ENABLE_OPENAI_TESTS: true
run: pytest .github/tests/multimodality_tests.py
225 changes: 77 additions & 148 deletions lotus/dtype_extensions/image.py
Original file line number Diff line number Diff line change
@@ -1,172 +1,101 @@
from pandas.api.extensions import ExtensionDtype, ExtensionArray
import io
import sys
from typing import List, Optional, Union, Sequence
import numpy as np
from pandas.api.extensions import ExtensionArray, ExtensionDtype
from PIL import Image
import io
import numpy as np

from lotus.utils import fetch_image

class ImageDtype(ExtensionDtype):
"""Custom dtype for Images."""
name = 'image'
type = Image.Image
na_value = None

@classmethod
def construct_array_type(cls):
return ImageArray


class ImageArray(ExtensionArray):
"""ExtensionArray for storing Images."""

def __init__(self, values):
"""Initialize array with validation."""
values = self._validate_values(values)
self._data = np.array(values, dtype=object)
def __init__(self, values: Sequence[Optional[Image.Image]]):
self._data = values
self._dtype = ImageDtype()

@staticmethod
def _validate_values(values):
"""Validate that all values are Images or None."""
if isinstance(values, (ImageArray, np.ndarray)):
values = values.tolist()

validated = []
for i, val in enumerate(values):
if val is None:
validated.append(None)
elif isinstance(val, Image.Image):
validated.append(val)
else:
raise TypeError(
f"Value at index {i} has type {type(val).__name__}. "
"Expected .Image.Image or None."
)
return validated


def __getitem__(self, item: Union[int, slice, Sequence[int]]) -> Union[Image.Image, 'ImageArray']:
if isinstance(item, (int, np.integer)):
return fetch_image(self._data[item])
if isinstance(item, slice):
return ImageArray(self._data[item])
return ImageArray([self._data[i] for i in item])

def __setitem__(self, key: Union[int, slice, Sequence[int]], value: Union[Image.Image, Sequence[Image.Image]]):
if isinstance(key, (int, np.integer)):
self._data[key] = value
else:
for i, k in enumerate(key):
self._data[k] = value[i]

def isna(self) -> np.ndarray:
return np.array([img is None for img in self._data], dtype=bool)

def take(self, indexer: Sequence[int], allow_fill: bool = False, fill_value: Optional[Image.Image] = None) -> 'ImageArray':
if allow_fill:
fill_value = fill_value if fill_value is not None else self.dtype.na_value
result = [self._data[idx] if idx >= 0 else fill_value for idx in indexer]
else:
result = [self._data[idx] for idx in indexer]
return ImageArray(result)

def copy(self) -> 'ImageArray':
return ImageArray([img.copy() if img and hasattr(img, "copy") else img for img in self._data])

@classmethod
def _from_sequence(cls, scalars, dtype=None, copy=False):
"""Create ImageArray from sequence of scalars."""
scalars = cls._validate_values(scalars)
if copy:
scalars = [img.copy() if img is not None else None for img in scalars]
return cls(scalars)

return cls([img.copy() if img and copy and hasattr(img, 'copy') else img for img in scalars])

@classmethod
def _from_factorized(cls, values, original):
"""Create ImageArray from factorized values."""
return original

def __getitem__(self, item):
"""Get item(s) from array."""
result = self._data[item]
if isinstance(item, (int, np.integer)):
return result
return ImageArray(result)

def __len__(self):
"""Length of array."""

@classmethod
def _concat_same_type(cls, to_concat: Sequence['ImageArray']) -> 'ImageArray':
combined = [img for array in to_concat for img in array._data]
return cls(combined)

def __len__(self) -> int:
return len(self._data)
def __eq__(self, other):
"""Equality comparison."""

def __eq__(self, other) -> np.ndarray:
# check if other is iterable
if isinstance(other, ImageArray):
return np.array([
_compare_images(img1, img2)
for img1, img2 in zip(self._data, other._data)
])
elif isinstance(other, (Image.Image, type(None))):
return np.array([
_compare_images(img, other)
for img in self._data
])
return NotImplemented

def __setitem__(self, key, value):
"""Set item(s) in array with validation."""
if isinstance(key, (int, np.integer)):
if not (isinstance(value, Image.Image) or value is None):
raise TypeError(
f"Cannot set value of type {type(value).__name__}. "
"Expected Image.Image or None."
)
self._data[key] = value
else:
value = self._validate_values(value)
self._data[key] = value

return np.array([_compare_images(img1, img2) for img1, img2 in zip(self._data, other._data)], dtype=bool)
if hasattr(other, '__iter__') and not isinstance(other, str):
return np.array([_compare_images(img1, img2) for img1, img2 in zip(self._data, other)], dtype=bool)
return np.array([_compare_images(img, other) for img in self._data], dtype=bool)

@property
def dtype(self):
"""Return the dtype object."""
def dtype(self) -> ImageDtype:
return self._dtype

@property
def nbytes(self):
"""Return number of bytes in memory."""
return sum(
len(img_to_bytes(img)) if img is not None else 0
for img in self._data
)

def isna(self):
"""Return boolean array indicating missing values."""
return np.array([img is None for img in self._data])

def take(self, indexer, allow_fill=False, fill_value=None):
"""Take elements from array."""
if allow_fill:
if fill_value is not None and not (isinstance(fill_value, Image.Image) or fill_value is None):
raise TypeError(
f"Fill value must be Image.Image or None, not {type(fill_value).__name__}"
)
if fill_value is None:
fill_value = self.dtype.na_value

result = np.array([
self._data[idx] if idx >= 0 else fill_value
for idx in indexer
])
else:
result = self._data.take(indexer)

return ImageArray(result)

def copy(self):
"""Return deep copy of array."""
return ImageArray([
img.copy() if img is not None else None
for img in self._data
])

@classmethod
def _concat_same_type(cls, to_concat):
"""Concatenate multiple arrays."""
return cls(np.concatenate([array._data for array in to_concat]))

def interpolate(self, method='linear', axis=0, limit=None, inplace=False,
limit_direction=None, limit_area=None, downcast=None, **kwargs):
"""Interpolate missing values."""
return self.copy() if not inplace else self

def _compare_images(img1, img2):
"""Compare two Images for equality."""
if img1 is None and img2 is None:
return True
def nbytes(self) -> int:
return sum(sys.getsizeof(img) for img in self._data if img)

def __repr__(self) -> str:
return f"ImageArray([{', '.join(['<Image>' if img is not None else 'None' for img in self._data[:5]])}, ...])"

def _formatter(self, boxed: bool = False):
return lambda x: '<Image>' if x is not None else 'None'


def _compare_images(img1: Optional[Image.Image], img2: Optional[Image.Image]) -> bool:
if img1 is None or img2 is None:
return False
if img1.size != img2.size:
return False
if img1.mode != img2.mode:
return False
return np.array_equal(np.array(img1), np.array(img2))

def img_to_bytes(img):
"""Convert Image to bytes."""
if img is None:
return None
buf = io.BytesIO()
img.save(buf, format=img.format or 'PNG')
return buf.getvalue()

def bytes_to_img(b):
"""Convert bytes to Image."""
if b is None:
return None
return Image.open(io.BytesIO(b))
return img1 is img2
if isinstance(img1, Image.Image) or isinstance(img2, Image.Image):
img1 = fetch_image(img1)
img2 = fetch_image(img2)
return (img1.size == img2.size and img1.mode == img2.mode and img1.tobytes() == img2.tobytes())
else:
return img1 == img2
8 changes: 5 additions & 3 deletions lotus/templates/task_instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def filter_user_message_formatter(
text = multimodal_data
image_inputs = []

if isinstance(multimodal_data, list):
if isinstance(multimodal_data, dict):
_image_inputs = [
[{
"type": "text",
Expand All @@ -19,13 +19,13 @@ def filter_user_message_formatter(
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
"url": base64_image
},
}]
for key, base64_image in multimodal_data["image"].items()
]
image_inputs = [m for image_input in _image_inputs for m in image_input]
text = multimodal_data["text"]
text = multimodal_data["text"] or ""

return {
"role": "user",
Expand Down Expand Up @@ -261,6 +261,8 @@ def format_row(x: pd.Series, cols: list[str]) -> str:

# take cols that are in df
cols = [col for col in cols if col in df.columns]
if len(cols) == 0:
return [""]*len(df)
formatted_rows: list[str] = df.apply(lambda x: format_row(x, cols), axis=1).tolist()
return formatted_rows

Expand Down
Loading

0 comments on commit da28054

Please sign in to comment.