Skip to content

Commit

Permalink
fix error when report input contains lists skrub-data#1066
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromedockes committed Sep 25, 2024
1 parent 91d1f05 commit a33b9ea
Show file tree
Hide file tree
Showing 8 changed files with 29 additions and 14 deletions.
6 changes: 3 additions & 3 deletions skrub/_dataframe/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,10 +719,10 @@ def _to_string_pandas(col):

@to_string.specialize("polars", argument_type="Column")
def _to_string_polars(col):
if col.dtype != pl.Object:
# polars raises an error when trying to cast those types to string directly
# so we have to use map_elements
if col.dtype not in (pl.Object, pl.List, pl.Array):
return _cast_polars(col, pl.String)
# Objects are mere passengers in polars dataframes and we can't do
# anything with them; cast raises an exception.
# polars emits a performance warning when using map_elements
with warnings.catch_warnings():
warnings.filterwarnings(
Expand Down
4 changes: 2 additions & 2 deletions skrub/_reporting/_data/templates/column-summary.html
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ <h3 class="margin-r-t">
<summary>Most frequent values</summary>
<div class="shrink">
<div class="copybutton-grid">
{% for value in column.value_counts %}
{% for (value, count) in column.value_counts %}
{% set val_id = "{}-freq-value-{}".format(col_id, loop.index0) %}
<div class="box" data-test="frequent-value-{{ loop.index0 }}">
<pre id="{{ val_id }}"
Expand All @@ -99,7 +99,7 @@ <h3 class="margin-r-t">
{% set val_id = "{}-freq-value-list".format(col_id) %}
<div class="box">
<pre id="{{ val_id }}"
data-copy-text="{{ column.value_counts | list }}">{{ column.value_counts.keys() | list }}</pre>
data-copy-text="{{ column.most_frequent_values }}">{{ column.most_frequent_values }}</pre>
{{ buttons.copybutton(val_id) }}
</div>
</div>
Expand Down
9 changes: 8 additions & 1 deletion skrub/_reporting/_html.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@
}


def _is_null(value):
isna = pd.isna(value)
if isinstance(isna, bool):
return isna
return False


def _get_jinja_env():
env = jinja2.Environment(
loader=jinja2.FileSystemLoader(
Expand All @@ -47,7 +54,7 @@ def _get_jinja_env():
"svg_to_img_src",
]:
env.filters[function_name] = getattr(_utils, function_name)
env.filters["is_null"] = pd.isna
env.filters["is_null"] = _is_null
env.globals["random_string"] = random_string
return env

Expand Down
9 changes: 4 additions & 5 deletions skrub/_reporting/_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,8 @@ def value_counts(value_counts, n_unique, n_rows, color=COLOR_0):
Parameters
----------
value_counts : dict
The keys are values, and values are counts. Must be sorted from most to
least frequent.
value_counts : list
Pairs of (value, count). Must be sorted from most to least frequent.
n_unique : int
Cardinality of the plotted column, used to determine if all unique
Expand All @@ -129,8 +128,8 @@ def value_counts(value_counts, n_unique, n_rows, color=COLOR_0):
str
The plot as a XML string.
"""
values = [_utils.ellide_string_short(s) for s in value_counts.keys()][::-1]
counts = list(value_counts.values())[::-1]
values = [_utils.ellide_string_short(v) for v, _ in value_counts][::-1]
counts = [c for _, c in value_counts][::-1]
if n_unique > len(value_counts):
title = f"{len(value_counts)} most frequent"
else:
Expand Down
3 changes: 2 additions & 1 deletion skrub/_reporting/_summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,11 @@ def _add_value_counts(summary, column, *, dataframe_summary, with_plots):
summary["unique_proportion"] = n_unique / max(1, dataframe_summary["n_rows"])
summary["high_cardinality"] = n_unique >= _HIGH_CARDINALITY_THRESHOLD
summary["value_counts"] = value_counts
summary["most_frequent_values"] = [v for v, _ in value_counts]

if n_unique == 1:
summary["value_is_constant"] = True
summary["constant_value"] = next(iter(value_counts.keys()))
summary["constant_value"] = value_counts[0][0]
else:
summary["value_is_constant"] = False
if with_plots:
Expand Down
2 changes: 1 addition & 1 deletion skrub/_reporting/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def top_k_value_counts(column, k):
n_unique = sbd.shape(counts)[0]
counts = sbd.sort(counts, by="count", descending=True)
counts = sbd.slice(counts, k)
return n_unique, dict(zip(*to_dict(counts).values()))
return n_unique, list(zip(*to_dict(counts).values()))


def quantiles(column):
Expand Down
3 changes: 2 additions & 1 deletion skrub/_reporting/tests/test_summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def test_summarize(monkeypatch, df_module, air_quality, order_by, with_plots):
"plot_names": [],
"position": 0,
"unique_proportion": 0.118,
"value_counts": {"London": 8, "Paris": 9},
"value_counts": [("Paris", 9), ("London", 8)],
"most_frequent_values": ["Paris", "London"],
"value_is_constant": False,
}

Expand Down
7 changes: 7 additions & 0 deletions skrub/_reporting/tests/test_table_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,10 @@ def test_empty_dataframe(df_module):
def test_open(pd_module, browser_mock):
TableReport(pd_module.example_dataframe, title="the title").open()
assert b"the title" in browser_mock.content


def test_non_hashable_values(df_module):
# non-regression test for #1066
df = df_module.make_dataframe(dict(a=[[1, 2, 3], None, [4]]))
html = TableReport(df).html()
assert "[1, 2, 3]" in html

0 comments on commit a33b9ea

Please sign in to comment.