Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support mask= argument in LocBody #566

Merged
merged 13 commits into from
Jan 22, 2025
Merged
69 changes: 62 additions & 7 deletions great_tables/_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,14 @@
rows
The rows to target. Can either be a single row name or a series of row names provided in a
list.
mask
The cells to target. If the underlying wrapped DataFrame is a Polars DataFrame,
you can pass a Polars expression for cell-based selection. This argument must be used
exclusively and cannot be combined with the `columns=` or `rows=` arguments.

:::{.callout-warning}
`mask=` is still experimental.
:::

Returns
-------
Expand Down Expand Up @@ -539,6 +547,7 @@

columns: SelectExpr = None
rows: RowSelectExpr = None
mask: PlExpr | None = None


@dataclass
Expand Down Expand Up @@ -823,6 +832,45 @@
)


def resolve_mask(
data: GTData | list[str],
expr: PlExpr,
excl_stub: bool = True,
excl_group: bool = True,
) -> list[tuple[int, int, str]]:
"""Return data for creating `CellPos`, based on expr"""
if not isinstance(expr, PlExpr):
raise ValueError("Only Polars expressions can be passed to the `mask` argument.")

frame: PlDataFrame = data._tbl_data
frame_cols = frame.columns

stub_var = data._boxhead.vars_from_type(ColInfoTypeEnum.stub)
group_var = data._boxhead.vars_from_type(ColInfoTypeEnum.row_group)
cols_excl = [*(stub_var if excl_stub else []), *(group_var if excl_group else [])]

# `df.select()` raises `ColumnNotFoundError` if columns are missing from the original DataFrame.
masked = frame.select(expr)
jrycw marked this conversation as resolved.
Show resolved Hide resolved
for col_excl in cols_excl:
masked = masked.drop(col_excl, strict=False)

Check warning on line 855 in great_tables/_locations.py

View check run for this annotation

Codecov / codecov/patch

great_tables/_locations.py#L855

Added line #L855 was not covered by tests

# Validate that `masked.columns` exist in the `frame_cols`
if not (set(masked.columns).issubset(set(frame_cols))):
raise ValueError("The `mask` may reference columns not in the original DataFrame.")
jrycw marked this conversation as resolved.
Show resolved Hide resolved

# Validate that row lengths are equal
if masked.height != frame.height:
raise ValueError("The DataFrame length after applying `mask` differs from the original.")
jrycw marked this conversation as resolved.
Show resolved Hide resolved

cellpos_data: list[tuple[int, int, str]] = [] # column, row, colname for `CellPos`
for row_idx, row_dict in enumerate(masked.iter_rows(named=True)):
for colname, value in row_dict.items():
if value: # select only when `value` is True
col_idx = frame_cols.index(colname)
jrycw marked this conversation as resolved.
Show resolved Hide resolved
cellpos_data.append((col_idx, row_idx, colname))
return cellpos_data


# Resolve generic ======================================================================


Expand Down Expand Up @@ -869,14 +917,21 @@
@resolve.register
def _(loc: LocBody, data: GTData) -> list[CellPos]:
cols = resolve_cols_i(data=data, expr=loc.columns)
rows = resolve_rows_i(data=data, expr=loc.rows)

# TODO: dplyr arranges by `Var1`, and does distinct (since you can tidyselect the same
# thing multiple times
cell_pos = [
CellPos(col[1], row[1], colname=col[0]) for col, row in itertools.product(cols, rows)
]
if (loc.columns is not None or loc.rows is not None) and loc.mask is not None:
raise ValueError(
"Cannot specify the `mask` argument along with `columns` or `rows` in `loc.body()`."
)

if loc.mask is None:
rows = resolve_rows_i(data=data, expr=loc.rows)
jrycw marked this conversation as resolved.
Show resolved Hide resolved
# TODO: dplyr arranges by `Var1`, and does distinct (since you can tidyselect the same
# thing multiple times
cell_pos = [
CellPos(col[1], row[1], colname=col[0]) for col, row in itertools.product(cols, rows)
]
else:
cellpos_data = resolve_mask(data=data, expr=loc.mask)
cell_pos = [CellPos(*cellpos) for cellpos in cellpos_data]
return cell_pos


Expand Down
74 changes: 74 additions & 0 deletions tests/test_tab_create_modify.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,19 @@
from great_tables._locations import LocBody
from great_tables._styles import CellStyleFill
from great_tables._tab_create_modify import tab_style
from polars import selectors as cs


@pytest.fixture
def gt():
return GT(pd.DataFrame({"x": [1, 2], "y": [4, 5]}))


@pytest.fixture
def gt2():
return GT(pl.DataFrame({"x": [1, 2], "y": [4, 5]}))


def test_tab_style(gt: GT):
style = CellStyleFill(color="blue")
new_gt = tab_style(gt, style, LocBody(["x"], [0]))
Expand Down Expand Up @@ -71,3 +77,71 @@ def test_tab_style_font_from_column():

assert rendered_html.find('<td style="font-family: Helvetica;" class="gt_row gt_right">1</td>')
assert rendered_html.find('<td style="font-family: Courier;" class="gt_row gt_right">2</td>')


def test_tab_style_loc_body_mask(gt2: GT):
style = CellStyleFill(color="blue")
new_gt = tab_style(gt2, style, LocBody(mask=cs.numeric().gt(1.5)))

assert len(gt2._styles) == 0
assert len(new_gt._styles) == 3

xy_0y, xy_1x, xy_1y = new_gt._styles

assert xy_0y.styles[0] is style
assert xy_1x.styles[0] is style
assert xy_1y.styles[0] is style

assert xy_0y.rownum == 0
assert xy_0y.colname == "y"

assert xy_1x.rownum == 1
assert xy_1x.colname == "x"

assert xy_1y.rownum == 1
assert xy_1y.colname == "y"


def test_tab_style_loc_body_raises(gt2: GT):
style = CellStyleFill(color="blue")
mask = cs.numeric().gt(1.5)
err_msg = "Cannot specify the `mask` argument along with `columns` or `rows` in `loc.body()`."

with pytest.raises(ValueError) as exc_info:
tab_style(gt2, style, LocBody(columns=["x"], mask=mask))
assert err_msg in exc_info.value.args[0]

with pytest.raises(ValueError) as exc_info:
tab_style(gt2, style, LocBody(rows=[0], mask=mask))

assert err_msg in exc_info.value.args[0]


def test_tab_style_loc_body_mask_not_polars_expression_raises(gt2: GT):
style = CellStyleFill(color="blue")
mask = "fake expression"
err_msg = "Only Polars expressions can be passed to the `mask` argument."

with pytest.raises(ValueError) as exc_info:
tab_style(gt2, style, LocBody(mask=mask))
assert err_msg in exc_info.value.args[0]


def test_tab_style_loc_body_mask_columns_not_inside_raises(gt2: GT):
style = CellStyleFill(color="blue")
mask = pl.len()
err_msg = "The `mask` may reference columns not in the original DataFrame."

with pytest.raises(ValueError) as exc_info:
tab_style(gt2, style, LocBody(mask=mask))
assert err_msg in exc_info.value.args[0]


def test_tab_style_loc_body_mask_rows_not_equal_raises(gt2: GT):
style = CellStyleFill(color="blue")
mask = pl.len().alias("x")
err_msg = "The DataFrame length after applying `mask` differs from the original."

with pytest.raises(ValueError) as exc_info:
tab_style(gt2, style, LocBody(mask=mask))
assert err_msg in exc_info.value.args[0]
Loading