diff --git a/.github/tests/multimodality_tests.py b/.github/tests/multimodality_tests.py
index 669b33bc..c6311d81 100644
--- a/.github/tests/multimodality_tests.py
+++ b/.github/tests/multimodality_tests.py
@@ -138,6 +138,27 @@ def test_topk_operation(setup_models, model):
 
     top_2_actual = set(sorted_df["image"].values)
     assert top_2_expected == top_2_actual
 
 
+@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
+def test_topk_with_groupby_operation(setup_models, model):
+    """Run a grouped sem_topk over image/element pairs and check both groups remain."""
+    image_url = [
+        "https://img.etsystatic.com/il/4bee20/1469037676/il_340x270.1469037676_iiti.jpg?version=0",
+        "https://i1.wp.com/www.alloverthemap.net/wp-content/uploads/2014/02/2012-09-25-12.46.15.jpg?resize=400%2C284&ssl=1",
+        "https://i.pinimg.com/236x/a4/3a/65/a43a65683a0314f29b66402cebdcf46d.jpg",
+        "https://pravme.ru/wp-content/uploads/2018/01/sobor-Bogord-1.jpg",
+    ]
+    elements = ["doll", "bird"]
+    image_df = pd.DataFrame({"image": ImageArray(image_url)})
+    element_df = pd.DataFrame({"element": elements})
+
+    # Cross join: every image is paired with every candidate element.
+    df = image_df.join(element_df, how="cross")
+    # Assert on the frame returned by sem_topk; the input frame always holds
+    # both elements, so checking it would pass no matter what the operator did.
+    topk_df = df.sem_topk("the {image} is most likely an {element}", K=1, group_by=["element"])
+    assert len(set(topk_df["element"])) == 2
+
+
 @pytest.mark.parametrize("model", get_enabled("clip-ViT-B-32"))
diff --git a/lotus/dtype_extensions/image.py b/lotus/dtype_extensions/image.py
index dd2a522c..9fecab50 100644
--- a/lotus/dtype_extensions/image.py
+++ b/lotus/dtype_extensions/image.py
@@ -82,6 +82,24 @@ def copy(self) -> "ImageArray":
         new_array._cached_images = self._cached_images.copy()
         return new_array
 
+    @classmethod
+    def _concat_same_type(cls, to_concat: Sequence["ImageArray"]) -> "ImageArray":
+        """
+        Concatenate multiple ImageArray instances into a single one.
+
+        Declared as a classmethod because pandas invokes this hook on the
+        array class (``type(arr)._concat_same_type(arrays)``).
+
+        Args:
+            to_concat (Sequence[ImageArray]): A sequence of ImageArray instances to concatenate.
+
+        Returns:
+            ImageArray: A new ImageArray containing all elements from the input arrays.
+        """
+        # Join the raw backing arrays, then rebuild through the standard constructor.
+        combined_data = np.concatenate([arr._data for arr in to_concat])
+        return cls._from_sequence(combined_data)
+
     @classmethod
     def _from_sequence(cls, scalars, dtype=None, copy=False):
         if copy: