From d48bf53f4af0496f5b7be5087cf6ba22d6897234 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Tue, 14 Jan 2025 19:10:09 +0100 Subject: [PATCH 1/5] add pandas and polars formatting in iterabledataset --- src/datasets/formatting/__init__.py | 1 + src/datasets/formatting/formatting.py | 15 +++- src/datasets/formatting/polars_formatter.py | 8 +- src/datasets/iterable_dataset.py | 94 +++++++++++++++------ 4 files changed, 86 insertions(+), 32 deletions(-) diff --git a/src/datasets/formatting/__init__.py b/src/datasets/formatting/__init__.py index 8aa21d37bd2..9771618c7b9 100644 --- a/src/datasets/formatting/__init__.py +++ b/src/datasets/formatting/__init__.py @@ -22,6 +22,7 @@ Formatter, PandasFormatter, PythonFormatter, + TableFormatter, TensorFormatter, format_table, query_table, diff --git a/src/datasets/formatting/formatting.py b/src/datasets/formatting/formatting.py index ddd77353519..c0b31cc16c5 100644 --- a/src/datasets/formatting/formatting.py +++ b/src/datasets/formatting/formatting.py @@ -429,7 +429,15 @@ def recursive_tensorize(self, data_struct: dict): raise NotImplementedError -class ArrowFormatter(Formatter[pa.Table, pa.Array, pa.Table]): +class TableFormatter(Formatter[RowFormat, ColumnFormat, BatchFormat]): + table_type: str + column_type: str + + +class ArrowFormatter(TableFormatter[pa.Table, pa.Array, pa.Table]): + table_type = "arrow table" + column_type = "arrow array" + def format_row(self, pa_table: pa.Table) -> pa.Table: return self.simple_arrow_extractor().extract_row(pa_table) @@ -465,7 +473,10 @@ def format_batch(self, pa_table: pa.Table) -> Mapping: return batch -class PandasFormatter(Formatter[pd.DataFrame, pd.Series, pd.DataFrame]): +class PandasFormatter(TableFormatter[pd.DataFrame, pd.Series, pd.DataFrame]): + table_type = "pandas dataframe" + column_type = "pandas series" + def format_row(self, pa_table: pa.Table) -> pd.DataFrame: row = self.pandas_arrow_extractor().extract_row(pa_table) row = self.pandas_features_decoder.decode_row(row) diff --git a/src/datasets/formatting/polars_formatter.py b/src/datasets/formatting/polars_formatter.py index 543bde52dd0..7ea2f783aec 100644 --- a/src/datasets/formatting/polars_formatter.py +++ b/src/datasets/formatting/polars_formatter.py @@ -13,7 +13,6 @@ # limitations under the License. import sys -from collections.abc import Mapping from functools import partial from typing import TYPE_CHECKING, Optional @@ -23,7 +22,7 @@ from ..features import Features from ..features.features import decode_nested_example from ..utils.py_utils import no_op_if_value_is_null -from .formatting import BaseArrowExtractor, TensorFormatter +from .formatting import BaseArrowExtractor, TableFormatter if TYPE_CHECKING: @@ -98,7 +97,10 @@ def decode_batch(self, batch: "pl.DataFrame") -> "pl.DataFrame": return self.decode_row(batch) -class PolarsFormatter(TensorFormatter[Mapping, "pl.DataFrame", Mapping]): +class PolarsFormatter(TableFormatter["pl.DataFrame", "pl.Series", "pl.DataFrame"]): + table_type = "polars dataframe" + column_type = "polars series" + def __init__(self, features=None, **np_array_kwargs): super().__init__(features=features) self.np_array_kwargs = np_array_kwargs diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index f2d47ff64b9..c21ffc02d7c 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -10,13 +10,21 @@ import fsspec.asyn import numpy as np +import pandas as pd import pyarrow as pa from . 
import config from .arrow_dataset import Dataset, DatasetInfoMixin from .features import Features from .features.features import FeatureType, _align_features, _check_if_features_can_be_aligned, cast_to_python_objects -from .formatting import PythonFormatter, TensorFormatter, get_format_type_from_alias, get_formatter +from .formatting import ( + ArrowFormatter, + PythonFormatter, + TableFormatter, + TensorFormatter, + get_format_type_from_alias, + get_formatter, +) from .info import DatasetInfo from .splits import NamedSplit, Split from .table import cast_table_to_features, read_schema_from_file, table_cast @@ -966,6 +974,19 @@ def shard_data_sources( ) +def _table_output_to_arrow(output) -> pa.Table: + if isinstance(output, pa.Table): + return output + if isinstance(output, (pd.DataFrame, pd.Series)): + return pa.Table.from_pandas(output) + if config.POLARS_AVAILABLE and "polars" in sys.modules: + import polars as pl + + if isinstance(output, (pl.DataFrame, pl.Series)): + return output.to_arrow() + return output + + class MappedExamplesIterable(_BaseExamplesIterable): def __init__( self, @@ -994,22 +1015,22 @@ def __init__( self.formatting = formatting # required for iter_arrow self._features = features # sanity checks - if formatting and formatting.format_type == "arrow": + if formatting and formatting.is_table: # batch_size should match for iter_arrow if not isinstance(ex_iterable, RebatchedArrowExamplesIterable): raise ValueError( - "The Arrow-formatted MappedExamplesIterable has underlying iterable" + f"The {formatting.format_type.capitalize()}-formatted MappedExamplesIterable has underlying iterable" f"that is a {type(ex_iterable).__name__} instead of a RebatchedArrowExamplesIterable." ) elif ex_iterable.batch_size != (batch_size if batched else 1): raise ValueError( - f"The Arrow-formatted MappedExamplesIterable has batch_size={batch_size if batched else 1} which is" + f"The {formatting.format_type.capitalize()}-formatted MappedExamplesIterable has batch_size={batch_size if batched else 1} which is" f"different from {ex_iterable.batch_size=} from its underlying iterable." 
) @property def iter_arrow(self): - if self.formatting and self.formatting.format_type == "arrow": + if self.formatting and self.formatting.is_table: return self._iter_arrow @property @@ -1030,7 +1051,7 @@ def _init_state_dict(self) -> dict: return self._state_dict def __iter__(self): - if self.formatting and self.formatting.format_type == "arrow": + if self.formatting and self.formatting.is_table: formatter = PythonFormatter() for key, pa_table in self._iter_arrow(max_chunksize=1): yield key, formatter.format_row(pa_table) @@ -1156,6 +1177,7 @@ def _iter(self): yield key, transformed_example def _iter_arrow(self, max_chunksize: Optional[int] = None) -> Iterator[Tuple[Key, pa.Table]]: + formatter: TableFormatter = get_formatter(self.formatting.format_type) if self.formatting else ArrowFormatter() if self.ex_iterable.iter_arrow: iterator = self.ex_iterable.iter_arrow() else: @@ -1182,18 +1204,23 @@ def _iter_arrow(self, max_chunksize: Optional[int] = None) -> Iterator[Tuple[Key ): return # first build the batch - function_args = [pa_table] if self.input_columns is None else [pa_table[col] for col in self.input_columns] + function_args = ( + [formatter.format_batch(pa_table)] + if self.input_columns is None + else [pa_table[col] for col in self.input_columns] + ) if self.with_indices: if self.batched: function_args.append([current_idx + i for i in range(len(pa_table))]) else: function_args.append(current_idx) # then apply the transform - output_table = self.function(*function_args, **self.fn_kwargs) + output = self.function(*function_args, **self.fn_kwargs) + output_table = _table_output_to_arrow(output) if not isinstance(output_table, pa.Table): raise TypeError( - f"Provided `function` which is applied to pyarrow tables returns a variable of type " - f"{type(output_table)}. Make sure provided `function` returns a a pyarrow table to update the dataset." + f"Provided `function` which is applied to {formatter.table_type} returns a variable of type " + f"{type(output_table)}. Make sure provided `function` returns a {formatter.table_type} to update the dataset." ) # we don't need to merge results for consistency with Dataset.map which merges iif both input and output are dicts # then remove the unwanted columns @@ -1280,16 +1307,16 @@ def __init__( self.fn_kwargs = fn_kwargs or {} self.formatting = formatting # required for iter_arrow # sanity checks - if formatting and formatting.format_type == "arrow": + if formatting and formatting.is_table: # batch_size should match for iter_arrow if not isinstance(ex_iterable, RebatchedArrowExamplesIterable): raise ValueError( - "The Arrow-formatted FilteredExamplesIterable has underlying iterable" + f"The {formatting.format_type.capitalize()}-formatted FilteredExamplesIterable has underlying iterable" f"that is a {type(ex_iterable).__name__} instead of a RebatchedArrowExamplesIterable." ) elif ex_iterable.batch_size != (batch_size if batched else 1): raise ValueError( - f"The Arrow-formatted FilteredExamplesIterable has batch_size={batch_size if batched else 1} which is" + f"The {formatting.format_type.capitalize()}-formatted FilteredExamplesIterable has batch_size={batch_size if batched else 1} which is" f"different from {ex_iterable.batch_size=} from its underlying iterable." 
) @@ -1392,6 +1419,7 @@ def _iter(self): yield key, example def _iter_arrow(self, max_chunksize: Optional[int] = None): + formatter = get_formatter(self.formatting) if self.formatting else ArrowFormatter() if self.ex_iterable.iter_arrow: iterator = self.ex_iterable.iter_arrow() else: @@ -1415,14 +1443,24 @@ def _iter_arrow(self, max_chunksize: Optional[int] = None): ): return - function_args = [pa_table] if self.input_columns is None else [pa_table[col] for col in self.input_columns] + function_args = ( + [formatter.format_batch(pa_table)] + if self.input_columns is None + else [pa_table[col] for col in self.input_columns] + ) if self.with_indices: if self.batched: function_args.append([current_idx + i for i in range(len(pa_table))]) else: function_args.append(current_idx) # then apply the transform - mask = self.function(*function_args, **self.fn_kwargs) + output = self.function(*function_args, **self.fn_kwargs) + mask = _table_output_to_arrow(output) + if not isinstance(mask, (pa.Array, pa.BooleanScalar)): + raise TypeError( + f"Provided `function` which is applied to {formatter.table_type} returns a variable of type " + f"{type(output_table)}. Make sure provided `function` returns a {formatter.column_type} to update the dataset." + ) # return output if self.batched: output_table = pa_table.filter(mask) @@ -1734,11 +1772,13 @@ def _apply_feature_types_on_batch( class FormattingConfig: format_type: Optional[str] - def __post_init__(self): - if self.format_type == "pandas": - raise NotImplementedError( - "The 'pandas' formatting is not implemented for iterable datasets. You can use 'numpy' or 'arrow' instead." - ) + @property + def is_table(self) -> bool: + return isinstance(get_formatter(self.format_type), TableFormatter) + + @property + def is_tensor(self) -> bool: + return isinstance(get_formatter(self.format_type), TensorFormatter) class FormattedExamplesIterable(_BaseExamplesIterable): @@ -1757,7 +1797,7 @@ def __init__( @property def iter_arrow(self): - if self.ex_iterable.iter_arrow and (not self.formatting or self.formatting.format_type == "arrow"): + if self.ex_iterable.iter_arrow and (not self.formatting or self.formatting.is_table): return self._iter_arrow @property @@ -1773,7 +1813,7 @@ def _init_state_dict(self) -> dict: return self._state_dict def __iter__(self): - if not self.formatting or self.formatting.format_type == "arrow": + if not self.formatting or self.formatting.is_table: formatter = PythonFormatter() else: formatter = get_formatter( @@ -2093,7 +2133,7 @@ def _iter_pytorch(self): else: format_dict = None - if self._formatting and (ex_iterable.iter_arrow or self._formatting == "arrow"): + if self._formatting and (ex_iterable.iter_arrow or self._formatting.is_table): if ex_iterable.iter_arrow: iterator = ex_iterable.iter_arrow() else: @@ -2133,7 +2173,7 @@ def _prepare_ex_iterable_for_iteration( self, batch_size: int = 1, drop_last_batch: bool = False ) -> _BaseExamplesIterable: ex_iterable = self._ex_iterable - if self._formatting and (ex_iterable.iter_arrow or self._formatting.format_type == "arrow"): + if self._formatting and (ex_iterable.iter_arrow or self._formatting.is_table): ex_iterable = RebatchedArrowExamplesIterable( ex_iterable, batch_size=batch_size, drop_last_batch=drop_last_batch ) @@ -2189,7 +2229,7 @@ def __iter__(self): else: format_dict = None - if self._formatting and (ex_iterable.iter_arrow or self._formatting.format_type == "arrow"): + if self._formatting and (ex_iterable.iter_arrow or self._formatting.is_table): if ex_iterable.iter_arrow: 
iterator = ex_iterable.iter_arrow() else: @@ -2225,7 +2265,7 @@ def iter(self, batch_size: int, drop_last_batch: bool = False): format_dict = None ex_iterable = self._prepare_ex_iterable_for_iteration(batch_size=batch_size, drop_last_batch=drop_last_batch) - if self._formatting and (ex_iterable.iter_arrow or self._formatting == "arrow"): + if self._formatting and (ex_iterable.iter_arrow or self._formatting.is_table): if ex_iterable.iter_arrow: iterator = ex_iterable.iter_arrow() else: @@ -2516,7 +2556,7 @@ def map( else self._info.features ) - if self._formatting and self._formatting.format_type == "arrow": + if self._formatting and self._formatting.is_table: # apply formatting before iter_arrow to keep map examples iterable happy ex_iterable = FormattedExamplesIterable( ex_iterable, From 7386a84dacbb0dea8cd0e7c9a716e03cc1c7e328 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Tue, 14 Jan 2025 19:24:27 +0100 Subject: [PATCH 2/5] fix tests --- src/datasets/iterable_dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index c21ffc02d7c..ebff7397440 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -1220,7 +1220,7 @@ def _iter_arrow(self, max_chunksize: Optional[int] = None) -> Iterator[Tuple[Key if not isinstance(output_table, pa.Table): raise TypeError( f"Provided `function` which is applied to {formatter.table_type} returns a variable of type " - f"{type(output_table)}. Make sure provided `function` returns a {formatter.table_type} to update the dataset." + f"{type(output)}. Make sure provided `function` returns a {formatter.table_type} to update the dataset." ) # we don't need to merge results for consistency with Dataset.map which merges iif both input and output are dicts # then remove the unwanted columns @@ -1419,7 +1419,7 @@ def _iter(self): yield key, example def _iter_arrow(self, max_chunksize: Optional[int] = None): - formatter = get_formatter(self.formatting) if self.formatting else ArrowFormatter() + formatter = get_formatter(self.formatting.format_type) if self.formatting else ArrowFormatter() if self.ex_iterable.iter_arrow: iterator = self.ex_iterable.iter_arrow() else: @@ -1456,10 +1456,10 @@ def _iter_arrow(self, max_chunksize: Optional[int] = None): # then apply the transform output = self.function(*function_args, **self.fn_kwargs) mask = _table_output_to_arrow(output) - if not isinstance(mask, (pa.Array, pa.BooleanScalar)): + if not isinstance(mask, (bool, pa.Array, pa.BooleanScalar)): raise TypeError( f"Provided `function` which is applied to {formatter.table_type} returns a variable of type " - f"{type(output_table)}. Make sure provided `function` returns a {formatter.column_type} to update the dataset." + f"{type(output)}. Make sure provided `function` returns a {formatter.column_type} to update the dataset." 
) # return output if self.batched: From 4435fa4453636e5e6f12f978a0c34dc35bdded00 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Tue, 14 Jan 2025 19:30:10 +0100 Subject: [PATCH 3/5] docs --- docs/source/process.mdx | 2 +- src/datasets/arrow_dataset.py | 6 +++--- src/datasets/dataset_dict.py | 9 ++++----- src/datasets/iterable_dataset.py | 3 +-- 4 files changed, 9 insertions(+), 11 deletions(-) diff --git a/docs/source/process.mdx b/docs/source/process.mdx index 198b7509456..456a57f2d44 100644 --- a/docs/source/process.mdx +++ b/docs/source/process.mdx @@ -647,7 +647,7 @@ The [`~Dataset.with_format`] function also changes the format of a column, excep -🤗 Datasets also provides support for other common data formats such as NumPy, Pandas, and JAX. Check out the [Using Datasets with TensorFlow](https://huggingface.co/docs/datasets/master/en/use_with_tensorflow#using-totfdataset) guide for more details on how to efficiently create a TensorFlow dataset. +🤗 Datasets also provides support for other common data formats such as NumPy, TensorFlow, JAX, Arrow, Pandas and Polars. Check out the [Using Datasets with TensorFlow](https://huggingface.co/docs/datasets/master/en/use_with_tensorflow#using-totfdataset) guide for more details on how to efficiently create a TensorFlow dataset. diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index f272b3ec24b..f005b2374d1 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -2457,7 +2457,7 @@ def formatted_as( Args: type (`str`, *optional*): - Output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'pandas', 'arrow', 'jax']`. + Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'jax', 'arrow', 'pandas', 'polars']`. `None` means `__getitem__`` returns python objects (default). columns (`List[str]`, *optional*): Columns to format in the output. @@ -2491,7 +2491,7 @@ def set_format( Args: type (`str`, *optional*): - Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'pandas', 'arrow', 'jax']`. + Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'jax', 'arrow', 'pandas', 'polars']`. `None` means `__getitem__` returns python objects (default). columns (`List[str]`, *optional*): Columns to format in the output. @@ -2644,7 +2644,7 @@ def with_format( Args: type (`str`, *optional*): - Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'pandas', 'arrow', 'jax']`. + Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'jax', 'arrow', 'pandas', 'polars']`. `None` means `__getitem__` returns python objects (default). columns (`List[str]`, *optional*): Columns to format in the output. diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py index acb824b2bfc..5c9deddb73e 100644 --- a/src/datasets/dataset_dict.py +++ b/src/datasets/dataset_dict.py @@ -527,7 +527,7 @@ def formatted_as( Args: type (`str`, *optional*): - Output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'pandas', 'arrow', 'jax']`. + Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'jax', 'arrow', 'pandas', 'polars']`. `None` means `__getitem__` returns python objects (default). columns (`List[str]`, *optional*): Columns to format in the output. @@ -563,7 +563,7 @@ def set_format( Args: type (`str`, *optional*): - Output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'pandas', 'arrow', 'jax']`. 
+ Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'jax', 'arrow', 'pandas', 'polars']`. `None` means `__getitem__` returns python objects (default). columns (`List[str]`, *optional*): Columns to format in the output. @@ -670,7 +670,7 @@ def with_format( Args: type (`str`, *optional*): - Output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'pandas', 'arrow', 'jax']`. + Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'jax', 'arrow', 'pandas', 'polars']`. `None` means `__getitem__` returns python objects (default). columns (`List[str]`, *optional*): Columns to format in the output. @@ -1821,12 +1821,11 @@ def with_format( ) -> "IterableDatasetDict": """ Return a dataset with the specified format. - The 'pandas' format is currently not implemented. Args: type (`str`, *optional*): - Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'arrow', 'jax']`. + Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'jax', 'arrow', 'pandas', 'polars']`. `None` means it returns python objects (default). Example: diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index ebff7397440..317cc0b1723 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -2407,12 +2407,11 @@ def with_format( ) -> "IterableDataset": """ Return a dataset with the specified format. - The 'pandas' format is currently not implemented. Args: type (`str`, *optional*): - Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'arrow', 'jax']`. + Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'jax', 'arrow', 'pandas', 'polars']`. `None` means it returns python objects (default). Example: From 83090e107e2c2f79c58979bdb0013a254afe5b64 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Thu, 30 Jan 2025 12:01:09 +0100 Subject: [PATCH 4/5] fix ci --- tests/test_iterable_dataset.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index f3f7ee3106f..a88bf64590f 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -51,6 +51,7 @@ assert_arrow_memory_doesnt_increase, is_rng_equal, require_dill_gt_0_3_2, + require_jax, require_not_windows, require_numpy1_on_windows, require_pyspark, @@ -1681,7 +1682,12 @@ def test_iterable_dataset_features_cast_to_python(): assert list(dataset) == [{"timestamp": pd.Timestamp(2020, 1, 1).to_pydatetime(), "array": [1] * 5, "id": 0}] -@pytest.mark.parametrize("format_type", [None, "torch", "python", "tf", "tensorflow", "np", "numpy", "jax"]) +@require_torch +@require_tf +@require_jax +@pytest.mark.parametrize( + "format_type", [None, "torch", "python", "tf", "tensorflow", "np", "numpy", "jax", "arrow", "pd", "pandas"] +) def test_iterable_dataset_with_format(dataset: IterableDataset, format_type): formatted_dataset = dataset.with_format(format_type) assert formatted_dataset._formatting.format_type == get_format_type_from_alias(format_type) From 1f6a22c75a3893a80e8dad7b14eba29f3eabf66b Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Thu, 30 Jan 2025 14:12:18 +0100 Subject: [PATCH 5/5] add tests --- tests/test_iterable_dataset.py | 40 ++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index a88bf64590f..bd79863f9c3 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -54,6 +54,7 @@ 
require_jax, require_not_windows, require_numpy1_on_windows, + require_polars, require_pyspark, require_tf, require_torch, @@ -2224,6 +2225,45 @@ def g(): mock_python_arrow_extractor.assert_not_called() +def test_format_arrow(dataset: IterableDataset): + ds = dataset.with_format("arrow") + assert isinstance(next(iter(ds)), pa.Table) + assert isinstance(next(iter(ds.iter(batch_size=4))), pa.Table) + assert len(next(iter(ds))) == 1 + assert len(next(iter(ds.iter(batch_size=4)))) == 4 + ds = ds.map(lambda t: t.append_column("new_col", pa.array([0] * len(t)))) + ds = ds.map(lambda t: t.append_column("new_col_batched", pa.array([1] * len(t))), batched=True) + ds = ds.with_format(None) + assert next(iter(ds)) == {**next(iter(dataset)), "new_col": 0, "new_col_batched": 1} + + +def test_format_pandas(dataset: IterableDataset): + ds = dataset.with_format("pandas") + assert isinstance(next(iter(ds)), pd.DataFrame) + assert isinstance(next(iter(ds.iter(batch_size=4))), pd.DataFrame) + assert len(next(iter(ds))) == 1 + assert len(next(iter(ds.iter(batch_size=4)))) == 4 + ds = ds.map(lambda df: df.assign(new_col=[0] * len(df))) + ds = ds.map(lambda df: df.assign(new_col_batched=[1] * len(df)), batched=True) + ds = ds.with_format(None) + assert next(iter(ds)) == {**next(iter(dataset)), "new_col": 0, "new_col_batched": 1} + + +@require_polars +def test_format_polars(dataset: IterableDataset): + import polars as pl + + ds = dataset.with_format("polars") + assert isinstance(next(iter(ds)), pl.DataFrame) + assert isinstance(next(iter(ds.iter(batch_size=4))), pl.DataFrame) + assert len(next(iter(ds))) == 1 + assert len(next(iter(ds.iter(batch_size=4)))) == 4 + ds = ds.map(lambda df: df.with_columns(pl.Series([0] * len(df)).alias("new_col"))) + ds = ds.map(lambda df: df.with_columns(pl.Series([1] * len(df)).alias("new_col_batched")), batched=True) + ds = ds.with_format(None) + assert next(iter(ds)) == {**next(iter(dataset)), "new_col": 0, "new_col_batched": 1} + + @pytest.mark.parametrize("num_shards1, num_shards2, num_workers", [(2, 1, 1), (2, 2, 2), (1, 3, 1), (4, 3, 3)]) def test_interleave_dataset_with_sharding(num_shards1, num_shards2, num_workers): from torch.utils.data import DataLoader
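
A minimal usage sketch of the pandas/polars formatting introduced by this series, mirroring the new tests above. The generator, column names, and the derived "upper" column are illustrative only and are not part of the patch:

import pandas as pd
from datasets import IterableDataset

def gen():
    # small illustrative stream of examples
    for i in range(8):
        yield {"id": i, "text": f"example {i}"}

ds = IterableDataset.from_generator(gen)

# pandas-formatted iteration: single items and batches come out as pd.DataFrame
pd_ds = ds.with_format("pandas")
row = next(iter(pd_ds))                       # 1-row DataFrame
batch = next(iter(pd_ds.iter(batch_size=4)))  # 4-row DataFrame

# map() receives a DataFrame and may return one; the output is converted back to Arrow internally
pd_ds = pd_ds.map(lambda df: df.assign(upper=df["text"].str.upper()), batched=True)

# polars works the same way when it is installed
try:
    import polars as pl

    pl_ds = ds.with_format("polars").map(
        lambda df: df.with_columns(pl.col("text").str.to_uppercase().alias("upper")),
        batched=True,
    )
    print(next(iter(pl_ds)))  # 1-row pl.DataFrame with the extra "upper" column
except ImportError:
    pass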