Skip to content

Commit

Permalink
Rename TableVectorizer parameters (#947)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromedockes authored Jun 21, 2024
1 parent adfe24d commit a47c0a5
Show file tree
Hide file tree
Showing 14 changed files with 83 additions and 92 deletions.
6 changes: 6 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ Major changes
used twice (go through 2 different transformers).
:pr:`902` by :user:`Jérôme Dockès <jeromedockes>`.

* Some parameters of :class:`TableVectorizer` have been renamed:
`high_cardinality_transformer` → `high_cardinality`,
`low_cardinality_transformer` → `low_cardinality`,
`datetime_transformer` → `datetime`, `numeric_transformer` → `numeric`.
:pr:`947` by :user:`Jérôme Dockès <jeromedockes>`.

* The :class:`GapEncoder` and :class:`MinHashEncoder` are now a single-column
transformers: their ``fit``, ``fit_transform`` and ``transform`` methods
accept a single column (a pandas or polars Series). Dataframes and numpy
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/bench_gap_divergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def benchmark(max_iter_e_step, dataset_name):
(
"encoding",
TableVectorizer(
high_cardinality_transformer=ModifiedGapEncoder(
high_cardinality=ModifiedGapEncoder(
min_iter=5,
max_iter=5,
max_iter_e_step=max_iter_e_step,
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/bench_tablevectorizer_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def benchmark(
):
tv = TableVectorizer(
cardinality_threshold=tv_cardinality_threshold,
high_cardinality_transformer=MinHashEncoder(n_components=minhash_n_components),
high_cardinality=MinHashEncoder(n_components=minhash_n_components),
)

dataset = dataset_map[dataset_name]
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/run_on_openml_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@

classification_pipeline = Pipeline(
[
("vectorizer", TableVectorizer(high_cardinality_transformer=MinHashEncoder())),
("vectorizer", TableVectorizer(high_cardinality=MinHashEncoder())),
("classifier", HistGradientBoostingClassifier()),
]
)

regression_pipeline = Pipeline(
[
("vectorizer", TableVectorizer(high_cardinality_transformer=MinHashEncoder())),
("vectorizer", TableVectorizer(high_cardinality=MinHashEncoder())),
("regressor", HistGradientBoostingRegressor()),
]
)
Expand Down
3 changes: 1 addition & 2 deletions examples/01_encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,7 @@
from skrub import MinHashEncoder, ToCategorical

vectorizer = TableVectorizer(
low_cardinality_transformer=ToCategorical(),
high_cardinality_transformer=MinHashEncoder(),
low_cardinality=ToCategorical(), high_cardinality=MinHashEncoder()
)
pipeline = make_pipeline(
vectorizer, HistGradientBoostingRegressor(categorical_features="from_dtype")
Expand Down
8 changes: 2 additions & 6 deletions examples/03_datetime_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,7 @@
#
# Here, for example, we want it to extract the day of the week.

table_vec = TableVectorizer(
datetime_transformer=DatetimeEncoder(add_weekday=True),
).fit(X)
table_vec = TableVectorizer(datetime=DatetimeEncoder(add_weekday=True)).fit(X)
pprint(table_vec.get_feature_names_out())

###############################################################################
Expand Down Expand Up @@ -257,9 +255,7 @@
###############################################################################
from sklearn.inspection import permutation_importance

table_vec = TableVectorizer(
datetime_transformer=DatetimeEncoder(add_weekday=True),
)
table_vec = TableVectorizer(datetime=DatetimeEncoder(add_weekday=True))

# In this case, we don't use a pipeline, because we want to compute the
# importance of the features created by the DatetimeEncoder
Expand Down
4 changes: 1 addition & 3 deletions examples/08_join_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,7 @@
# columns, and doesn't interact with numerical columns.
from skrub import DatetimeEncoder, TableVectorizer

table_vectorizer = TableVectorizer(
datetime_transformer=DatetimeEncoder(add_weekday=True)
)
table_vectorizer = TableVectorizer(datetime=DatetimeEncoder(add_weekday=True))
X_date_encoded = table_vectorizer.fit_transform(X)
X_date_encoded.head()

Expand Down
6 changes: 2 additions & 4 deletions examples/FIXME/07_grid_searching_with_the_tablevectorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,7 @@

from skrub import MinHashEncoder

tv = TableVectorizer(
high_cardinality_transformer=MinHashEncoder(),
)
tv = TableVectorizer(high_cardinality=MinHashEncoder())
tv.fit(X)

pprint(tv.transformers_)
Expand Down Expand Up @@ -117,7 +115,7 @@

pipeline = make_pipeline(
TableVectorizer(
high_cardinality_transformer=GapEncoder(),
high_cardinality=GapEncoder(),
specific_transformers=[
("mh_dep_name", MinHashEncoder(), ["department_name"]),
],
Expand Down
2 changes: 1 addition & 1 deletion skrub/_interpolation_joiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from skrub._minhash_encoder import MinHashEncoder
from skrub._table_vectorizer import TableVectorizer

DEFAULT_VECTORIZER = TableVectorizer(high_cardinality_transformer=MinHashEncoder())
DEFAULT_VECTORIZER = TableVectorizer(high_cardinality=MinHashEncoder())
DEFAULT_REGRESSOR = HistGradientBoostingRegressor()
DEFAULT_CLASSIFIER = HistGradientBoostingClassifier()

Expand Down
60 changes: 29 additions & 31 deletions skrub/_table_vectorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class TableVectorizer(TransformerMixin, BaseEstimator):
String and categorical columns with a count of unique values smaller
than a given threshold (40 by default). Category encoding schemes such
as one-hot encoding, ordinal encoding etc. are typically appropriate
for low-cardinality columns.
for columns with few unique values.
- high_cardinality:
String and categorical columns with many unique values, such as
free-form text. Such columns have so many distinct values that it is
Expand All @@ -140,8 +140,7 @@ class TableVectorizer(TransformerMixin, BaseEstimator):
multivariate transformations are therefore not supported.
The transformer for each kind of column can be configured with the
corresponding ``*_transformer`` parameter: ``numeric_transformer``,
``datetime_transformer``, ...
corresponding parameter.
A transformer can be a scikit-learn Transformer (an object providing the
``fit``, ``fit_transform`` and ``transform`` methods), a clone of which
Expand All @@ -156,30 +155,33 @@ class TableVectorizer(TransformerMixin, BaseEstimator):
.. note::
The ``specific_transformers`` parameter is likely to be removed in a
future version of ``skrub``, when better utilities for building complex
The ``specific_transformers`` parameter will be removed in a future
version of ``skrub``, when better utilities for building complex
pipelines are introduced.
Parameters
----------
cardinality_threshold : int, default=40
String and categorical features with a number of unique values strictly
smaller than this threshold are considered ``low_cardinality``, the
rest are considered ``high_cardinality``.
smaller than this threshold are handled by the transformer ``low_cardinality``, the
rest are handled by the transformer ``high_cardinality``.
low_cardinality_transformer : transformer, "passthrough" or "drop", optional
The transformer for ``low_cardinality`` columns. The default is a
low_cardinality : transformer, "passthrough" or "drop", optional
The transformer for string or categorical columns with strictly fewer
than ``cardinality_threshold`` unique values. The default is a
``OneHotEncoder``.
high_cardinality_transformer : transformer, "passthrough" or "drop", optional
The transformer for ``high_cardinality`` columns. The default is a
``GapEncoder`` with 30 components (30 output columns for each input).
high_cardinality : transformer, "passthrough" or "drop", optional
The transformer for string or categorical columns with at least
``cardinality_threshold`` unique values. The default is a ``GapEncoder``
with 30 components (30 output columns for each input).
numeric_transformer : transformer, "passthrough" or "drop", optional
The transformer for ``numeric`` columns. The default is passthrough.
numeric : transformer, "passthrough" or "drop", optional
The transformer for numeric columns (floats, ints, booleans). The
default is passthrough.
datetime_transformer : transformer, "passthrough" or "drop", optional
The transformer for ``datetime`` columns. The default is
datetime : transformer, "passthrough" or "drop", optional
The transformer for date and datetime columns. The default is
``DatetimeEncoder``, which extracts features such as year, month, etc.
specific_transformers : list of (transformer, list of column names) pairs, optional
Expand Down Expand Up @@ -407,26 +409,22 @@ def __init__(
self,
*,
cardinality_threshold=40,
low_cardinality_transformer=LOW_CARDINALITY_TRANSFORMER,
high_cardinality_transformer=HIGH_CARDINALITY_TRANSFORMER,
numeric_transformer=NUMERIC_TRANSFORMER,
datetime_transformer=DATETIME_TRANSFORMER,
low_cardinality=LOW_CARDINALITY_TRANSFORMER,
high_cardinality=HIGH_CARDINALITY_TRANSFORMER,
numeric=NUMERIC_TRANSFORMER,
datetime=DATETIME_TRANSFORMER,
specific_transformers=(),
n_jobs=None,
):
self.cardinality_threshold = cardinality_threshold
self.low_cardinality_transformer = _utils.clone_if_default(
low_cardinality_transformer, LOW_CARDINALITY_TRANSFORMER
self.low_cardinality = _utils.clone_if_default(
low_cardinality, LOW_CARDINALITY_TRANSFORMER
)
self.high_cardinality_transformer = _utils.clone_if_default(
high_cardinality_transformer, HIGH_CARDINALITY_TRANSFORMER
)
self.numeric_transformer = _utils.clone_if_default(
numeric_transformer, NUMERIC_TRANSFORMER
)
self.datetime_transformer = _utils.clone_if_default(
datetime_transformer, DATETIME_TRANSFORMER
self.high_cardinality = _utils.clone_if_default(
high_cardinality, HIGH_CARDINALITY_TRANSFORMER
)
self.numeric = _utils.clone_if_default(numeric, NUMERIC_TRANSFORMER)
self.datetime = _utils.clone_if_default(datetime, DATETIME_TRANSFORMER)
self.specific_transformers = specific_transformers
self.n_jobs = n_jobs

Expand Down Expand Up @@ -562,7 +560,7 @@ def add_step(steps, transformer, cols, allow_reject=False):
]:
self._named_encoders[name] = add_step(
self._encoders,
getattr(self, f"{name}_transformer"),
getattr(self, name),
cols & selector - _created_by(*self._encoders),
)

Expand Down
24 changes: 12 additions & 12 deletions skrub/_tabular_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ def tabular_learner(estimator, *, n_jobs=None):
>>> tabular_learner('regressor') # doctest: +SKIP
Pipeline(steps=[('tablevectorizer',
TableVectorizer(high_cardinality_transformer=MinHashEncoder(),
low_cardinality_transformer=ToCategorical())),
TableVectorizer(high_cardinality=MinHashEncoder(),
low_cardinality=ToCategorical())),
('histgradientboostingregressor',
HistGradientBoostingRegressor(categorical_features='from_dtype'))])
Expand All @@ -118,8 +118,8 @@ def tabular_learner(estimator, *, n_jobs=None):
>>> tabular_learner('classifier') # doctest: +SKIP
Pipeline(steps=[('tablevectorizer',
TableVectorizer(high_cardinality_transformer=MinHashEncoder(),
low_cardinality_transformer=ToCategorical())),
TableVectorizer(high_cardinality=MinHashEncoder(),
low_cardinality=ToCategorical())),
('histgradientboostingclassifier',
HistGradientBoostingClassifier(categorical_features='from_dtype'))])
Expand Down Expand Up @@ -192,21 +192,21 @@ def tabular_learner(estimator, *, n_jobs=None):
>>> tabular_learner('classifier') # doctest: +SKIP
Pipeline(steps=[('tablevectorizer',
TableVectorizer(high_cardinality_transformer=MinHashEncoder(),
low_cardinality_transformer=ToCategorical())),
TableVectorizer(high_cardinality=MinHashEncoder(),
low_cardinality=ToCategorical())),
('histgradientboostingclassifier',
HistGradientBoostingClassifier(categorical_features='from_dtype'))])
- A :obj:`MinHashEncoder` is used as the
``high_cardinality_transformer``. This encoder provides good
``high_cardinality``. This encoder provides good
performance when the supervised estimator is based on a decision tree
or ensemble of trees, as is the case for the
:obj:`~sklearn.ensemble.HistGradientBoostingClassifier`. Unlike the
default :obj:`GapEncoder`, the :obj:`MinHashEncoder` does not produce
interpretable features. However, it is much faster and uses less
memory.
- The ``low_cardinality_transformer`` does not one-hot encode features.
- The ``low_cardinality`` does not one-hot encode features.
The :obj:`~sklearn.ensemble.HistGradientBoostingClassifier` has built-in
support for categorical data which is more efficient than one-hot
encoding. Therefore the selected encoder, :obj:`ToCategorical`, simply
Expand Down Expand Up @@ -257,13 +257,13 @@ def tabular_learner(estimator, *, n_jobs=None):
and getattr(estimator, "categorical_features", None) == "from_dtype"
):
vectorizer.set_params(
low_cardinality_transformer=ToCategorical(),
high_cardinality_transformer=MinHashEncoder(),
low_cardinality=ToCategorical(),
high_cardinality=MinHashEncoder(),
)
elif isinstance(estimator, _TREE_ENSEMBLE_CLASSES):
vectorizer.set_params(
low_cardinality_transformer=OrdinalEncoder(),
high_cardinality_transformer=MinHashEncoder(),
low_cardinality=OrdinalEncoder(),
high_cardinality=MinHashEncoder(),
)
steps = [vectorizer]
if not hasattr(estimator, "_get_tags") or not estimator._get_tags().get(
Expand Down
2 changes: 1 addition & 1 deletion skrub/tests/test_interpolation_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def test_join_on_date(df_module):
aux_key="date",
regressor=KNeighborsRegressor(1),
)
.set_params(vectorizer__datetime_transformer__resolution=None)
.set_params(vectorizer__datetime__resolution=None)
.fit_transform(sales)
)
assert_array_equal(ns.to_list(ns.col(transformed, "temp")), [-10, 10])
Expand Down
Loading

0 comments on commit a47c0a5

Please sign in to comment.