Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support faster processing using pandas or polars functions in IterableDataset.map() #7370

Merged
merged 6 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/process.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,7 @@ The [`~Dataset.with_format`] function also changes the format of a column, excep

<Tip>

🤗 Datasets also provides support for other common data formats such as NumPy, Pandas, and JAX. Check out the [Using Datasets with TensorFlow](https://huggingface.co/docs/datasets/master/en/use_with_tensorflow#using-totfdataset) guide for more details on how to efficiently create a TensorFlow dataset.
🤗 Datasets also provides support for other common data formats such as NumPy, TensorFlow, JAX, Arrow, Pandas and Polars. Check out the [Using Datasets with TensorFlow](https://huggingface.co/docs/datasets/master/en/use_with_tensorflow#using-totfdataset) guide for more details on how to efficiently create a TensorFlow dataset.

</Tip>

Expand Down
6 changes: 3 additions & 3 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2457,7 +2457,7 @@ def formatted_as(

Args:
type (`str`, *optional*):
Output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'pandas', 'arrow', 'jax']`.
Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'jax', 'arrow', 'pandas', 'polars']`.
`None` means `__getitem__` returns python objects (default).
columns (`List[str]`, *optional*):
Columns to format in the output.
Expand Down Expand Up @@ -2491,7 +2491,7 @@ def set_format(

Args:
type (`str`, *optional*):
Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'pandas', 'arrow', 'jax']`.
Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'jax', 'arrow', 'pandas', 'polars']`.
`None` means `__getitem__` returns python objects (default).
columns (`List[str]`, *optional*):
Columns to format in the output.
Expand Down Expand Up @@ -2644,7 +2644,7 @@ def with_format(

Args:
type (`str`, *optional*):
Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'pandas', 'arrow', 'jax']`.
Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'jax', 'arrow', 'pandas', 'polars']`.
`None` means `__getitem__` returns python objects (default).
columns (`List[str]`, *optional*):
Columns to format in the output.
Expand Down
9 changes: 4 additions & 5 deletions src/datasets/dataset_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ def formatted_as(

Args:
type (`str`, *optional*):
Output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'pandas', 'arrow', 'jax']`.
Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'jax', 'arrow', 'pandas', 'polars']`.
`None` means `__getitem__` returns python objects (default).
columns (`List[str]`, *optional*):
Columns to format in the output.
Expand Down Expand Up @@ -563,7 +563,7 @@ def set_format(

Args:
type (`str`, *optional*):
Output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'pandas', 'arrow', 'jax']`.
Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'jax', 'arrow', 'pandas', 'polars']`.
`None` means `__getitem__` returns python objects (default).
columns (`List[str]`, *optional*):
Columns to format in the output.
Expand Down Expand Up @@ -670,7 +670,7 @@ def with_format(

Args:
type (`str`, *optional*):
Output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'pandas', 'arrow', 'jax']`.
Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'jax', 'arrow', 'pandas', 'polars']`.
`None` means `__getitem__` returns python objects (default).
columns (`List[str]`, *optional*):
Columns to format in the output.
Expand Down Expand Up @@ -1821,12 +1821,11 @@ def with_format(
) -> "IterableDatasetDict":
"""
Return a dataset with the specified format.
The 'pandas' format is currently not implemented.

Args:

type (`str`, *optional*):
Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'arrow', 'jax']`.
Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'jax', 'arrow', 'pandas', 'polars']`.
`None` means it returns python objects (default).

Example:
Expand Down
1 change: 1 addition & 0 deletions src/datasets/formatting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Formatter,
PandasFormatter,
PythonFormatter,
TableFormatter,
TensorFormatter,
format_table,
query_table,
Expand Down
15 changes: 13 additions & 2 deletions src/datasets/formatting/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,15 @@ def recursive_tensorize(self, data_struct: dict):
raise NotImplementedError


class ArrowFormatter(Formatter[pa.Table, pa.Array, pa.Table]):
class TableFormatter(Formatter[RowFormat, ColumnFormat, BatchFormat]):
table_type: str
column_type: str


class ArrowFormatter(TableFormatter[pa.Table, pa.Array, pa.Table]):
table_type = "arrow table"
column_type = "arrow array"

def format_row(self, pa_table: pa.Table) -> pa.Table:
return self.simple_arrow_extractor().extract_row(pa_table)

Expand Down Expand Up @@ -465,7 +473,10 @@ def format_batch(self, pa_table: pa.Table) -> Mapping:
return batch


class PandasFormatter(Formatter[pd.DataFrame, pd.Series, pd.DataFrame]):
class PandasFormatter(TableFormatter[pd.DataFrame, pd.Series, pd.DataFrame]):
table_type = "pandas dataframe"
column_type = "pandas series"

def format_row(self, pa_table: pa.Table) -> pd.DataFrame:
row = self.pandas_arrow_extractor().extract_row(pa_table)
row = self.pandas_features_decoder.decode_row(row)
Expand Down
8 changes: 5 additions & 3 deletions src/datasets/formatting/polars_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import sys
from collections.abc import Mapping
from functools import partial
from typing import TYPE_CHECKING, Optional

Expand All @@ -23,7 +22,7 @@
from ..features import Features
from ..features.features import decode_nested_example
from ..utils.py_utils import no_op_if_value_is_null
from .formatting import BaseArrowExtractor, TensorFormatter
from .formatting import BaseArrowExtractor, TableFormatter


if TYPE_CHECKING:
Expand Down Expand Up @@ -98,7 +97,10 @@ def decode_batch(self, batch: "pl.DataFrame") -> "pl.DataFrame":
return self.decode_row(batch)


class PolarsFormatter(TensorFormatter[Mapping, "pl.DataFrame", Mapping]):
class PolarsFormatter(TableFormatter["pl.DataFrame", "pl.Series", "pl.DataFrame"]):
table_type = "polars dataframe"
column_type = "polars series"

def __init__(self, features=None, **np_array_kwargs):
super().__init__(features=features)
self.np_array_kwargs = np_array_kwargs
Expand Down
Loading