diff --git a/CHANGES.rst b/CHANGES.rst
index a747ed462..2e938e041 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -62,6 +62,10 @@ Minor changes
   the "sample" tab panel by specifying ``n_rows``.
   :pr:`1083` by :user:`Jérôme Dockès `.
 
+* the `TableReport` used to raise an exception when the dataframe contained
+  unhashable types such as python lists. This has been fixed in :pr:`1087` by
+  :user:`Jérôme Dockès `.
+
 Release 0.3.0
 =============
 
diff --git a/skrub/_dataframe/_common.py b/skrub/_dataframe/_common.py
index dee8347e9..66b8879a9 100644
--- a/skrub/_dataframe/_common.py
+++ b/skrub/_dataframe/_common.py
@@ -719,10 +719,10 @@ def _to_string_pandas(col):
 
 @to_string.specialize("polars", argument_type="Column")
 def _to_string_polars(col):
-    if col.dtype != pl.Object:
+    # polars raises an error when trying to cast those types to string directly
+    # so we have to use map_elements
+    if col.dtype not in (pl.Object, pl.List, pl.Array):
         return _cast_polars(col, pl.String)
-    # Objects are mere passengers in polars dataframes and we can't do
-    # anything with them; cast raises an exception.
    # polars emits a performance warning when using map_elements
     with warnings.catch_warnings():
         warnings.filterwarnings(
diff --git a/skrub/_reporting/_data/templates/column-summary.html b/skrub/_reporting/_data/templates/column-summary.html
index fa2a18de3..4d74bc3cf 100644
--- a/skrub/_reporting/_data/templates/column-summary.html
+++ b/skrub/_reporting/_data/templates/column-summary.html
@@ -86,7 +86,7 @@

                         Most frequent values
-                        {% for value in column.value_counts %}
+                        {% for (value, count) in column.value_counts %}
                         {% set val_id = "{}-freq-value-{}".format(col_id, loop.index0) %}
                         {% set val_id = "{}-freq-value-list".format(col_id) %}
                         
-                             data-copy-text="{{ column.value_counts.keys() | list }}">{{ column.value_counts.keys() | list }}
+                             data-copy-text="{{ column.most_frequent_values }}">{{ column.most_frequent_values }}
                         {{ buttons.copybutton(val_id) }}
diff --git a/skrub/_reporting/_html.py b/skrub/_reporting/_html.py
index 1cdf3c766..f2e0f7721 100644
--- a/skrub/_reporting/_html.py
+++ b/skrub/_reporting/_html.py
@@ -33,6 +33,13 @@
 }
 
 
+def _is_null(value):
+    isna = pd.isna(value)
+    if isinstance(isna, bool):
+        return isna
+    return False
+
+
 def _get_jinja_env():
     env = jinja2.Environment(
         loader=jinja2.FileSystemLoader(
@@ -47,7 +54,7 @@ def _get_jinja_env():
         "svg_to_img_src",
     ]:
         env.filters[function_name] = getattr(_utils, function_name)
-    env.filters["is_null"] = pd.isna
+    env.filters["is_null"] = _is_null
     env.globals["random_string"] = random_string
     return env
 
diff --git a/skrub/_reporting/_plotting.py b/skrub/_reporting/_plotting.py
index 28f49ede5..cd1664010 100644
--- a/skrub/_reporting/_plotting.py
+++ b/skrub/_reporting/_plotting.py
@@ -109,9 +109,8 @@ def value_counts(value_counts, n_unique, n_rows, color=COLOR_0):
 
     Parameters
     ----------
-    value_counts : dict
-        The keys are values, and values are counts. Must be sorted from most to
-        least frequent.
+    value_counts : list
+        Pairs of (value, count). Must be sorted from most to least frequent.
 
     n_unique : int
         Cardinality of the plotted column, used to determine if all unique
@@ -129,8 +128,8 @@
     str
         The plot as a XML string.
     """
-    values = [_utils.ellide_string_short(s) for s in value_counts.keys()][::-1]
-    counts = list(value_counts.values())[::-1]
+    values = [_utils.ellide_string_short(v) for v, _ in value_counts][::-1]
+    counts = [c for _, c in value_counts][::-1]
     if n_unique > len(value_counts):
         title = f"{len(value_counts)} most frequent"
     else:
diff --git a/skrub/_reporting/_summarize.py b/skrub/_reporting/_summarize.py
index b9ed736e8..96c8e083e 100644
--- a/skrub/_reporting/_summarize.py
+++ b/skrub/_reporting/_summarize.py
@@ -154,10 +154,11 @@ def _add_value_counts(summary, column, *, dataframe_summary, with_plots):
     summary["unique_proportion"] = n_unique / max(1, dataframe_summary["n_rows"])
     summary["high_cardinality"] = n_unique >= _HIGH_CARDINALITY_THRESHOLD
     summary["value_counts"] = value_counts
+    summary["most_frequent_values"] = [v for v, _ in value_counts]
 
     if n_unique == 1:
         summary["value_is_constant"] = True
-        summary["constant_value"] = next(iter(value_counts.keys()))
+        summary["constant_value"] = value_counts[0][0]
     else:
         summary["value_is_constant"] = False
     if with_plots:
diff --git a/skrub/_reporting/_utils.py b/skrub/_reporting/_utils.py
index a01ab009b..0b057fb22 100644
--- a/skrub/_reporting/_utils.py
+++ b/skrub/_reporting/_utils.py
@@ -37,7 +37,7 @@ def top_k_value_counts(column, k):
     n_unique = sbd.shape(counts)[0]
     counts = sbd.sort(counts, by="count", descending=True)
     counts = sbd.slice(counts, k)
-    return n_unique, dict(zip(*to_dict(counts).values()))
+    return n_unique, list(zip(*to_dict(counts).values()))
 
 
 def quantiles(column):
diff --git a/skrub/_reporting/tests/test_summarize.py b/skrub/_reporting/tests/test_summarize.py
index 87ca21123..414a7e61e 100644
--- a/skrub/_reporting/tests/test_summarize.py
+++ b/skrub/_reporting/tests/test_summarize.py
@@ -74,7 +74,8 @@ def test_summarize(monkeypatch, df_module, air_quality, order_by, with_plots):
         "plot_names": [],
         "position": 0,
         "unique_proportion": 0.118,
-        "value_counts": {"London": 8, "Paris": 9},
+        "value_counts": [("Paris", 9), ("London", 8)],
+        "most_frequent_values": ["Paris", "London"],
         "value_is_constant": False,
     }
 
diff --git a/skrub/_reporting/tests/test_table_report.py b/skrub/_reporting/tests/test_table_report.py
index 799061bf4..8a931c114 100644
--- a/skrub/_reporting/tests/test_table_report.py
+++ b/skrub/_reporting/tests/test_table_report.py
@@ -74,3 +74,10 @@ def test_empty_dataframe(df_module):
 def test_open(pd_module, browser_mock):
     TableReport(pd_module.example_dataframe, title="the title").open()
     assert b"the title" in browser_mock.content
+
+
+def test_non_hashable_values(df_module):
+    # non-regression test for #1066
+    df = df_module.make_dataframe(dict(a=[[1, 2, 3], None, [4]]))
+    html = TableReport(df).html()
+    assert "[1, 2, 3]" in html
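
A minimal usage sketch of the behavior this patch enables, mirroring the
non-regression test above. It assumes the patched skrub is installed and uses
plain pandas; the column name "a" and the sample values are arbitrary:

    import pandas as pd
    from skrub import TableReport

    # A column holding unhashable values (python lists) used to make the
    # TableReport raise; with this patch the report renders and the list
    # values appear as strings in the generated HTML.
    df = pd.DataFrame({"a": [[1, 2, 3], None, [4]]})
    html = TableReport(df).html()
    assert "[1, 2, 3]" in html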