Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support mask= argument in LocBody #566

Merged
merged 13 commits into from
Jan 22, 2025
Merged
78 changes: 70 additions & 8 deletions great_tables/_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,14 @@ class LocBody(Loc):
rows
The rows to target. Can either be a single row name or a series of row names provided in a
list.
mask
The cells to target. If the underlying wrapped DataFrame is a Polars DataFrame,
you can pass a Polars expression for cell-based selection. This argument must be used
exclusively and cannot be combined with the `columns=` or `rows=` arguments.

:::{.callout-warning}
`mask=` is still experimental.
:::

Returns
-------
Expand Down Expand Up @@ -539,6 +547,7 @@ class LocBody(Loc):

columns: SelectExpr = None
rows: RowSelectExpr = None
mask: PlExpr | None = None


@dataclass
Expand Down Expand Up @@ -823,6 +832,52 @@ def resolve_rows_i(
)


def resolve_mask(
    data: GTData | list[str],
    expr: PlExpr,
    excl_stub: bool = True,
    excl_group: bool = True,
) -> list[tuple[int, int, str]]:
    """Return the data needed to create `CellPos` objects from a Polars mask expression.

    The expression is evaluated against the table's underlying frame with `select()`. Every
    cell of the result holding a truthy value is reported as a targeted cell.

    Parameters
    ----------
    data
        The table object. Its `_tbl_data` frame is what `expr` is evaluated against, and its
        `_boxhead` supplies the stub and row-group column names.
        NOTE(review): the annotation also admits `list[str]`, but the body reads
        `data._tbl_data` and `data._boxhead` -- confirm whether a list is ever passed here.
    expr
        A Polars expression producing one value per original row for each selected column.
    excl_stub
        Whether stub columns are dropped from the mask result before cells are collected.
    excl_group
        Whether row-group columns are dropped from the mask result before cells are collected.

    Returns
    -------
    list[tuple[int, int, str]]
        One `(column index, row index, column name)` tuple per truthy cell, ordered row by row.
        The column index refers to the position of the column in the original frame.

    Raises
    ------
    ValueError
        If `expr` is not a Polars expression, if the mask result contains column names that are
        not in the original frame, or if the mask result has a different number of rows than
        the original frame.
    """
    if not isinstance(expr, PlExpr):
        raise ValueError("Only Polars expressions can be passed to the `mask` argument.")

    frame: PlDataFrame = data._tbl_data
    frame_cols = frame.columns

    stub_var = data._boxhead.vars_from_type(ColInfoTypeEnum.stub)
    group_var = data._boxhead.vars_from_type(ColInfoTypeEnum.row_group)
    cols_excl = [*(stub_var if excl_stub else []), *(group_var if excl_group else [])]

    # `df.select()` raises `ColumnNotFoundError` if columns are missing from the original DataFrame.
    # `strict=False` lets `drop()` ignore excluded columns that the mask did not produce.
    masked = frame.select(expr).drop(cols_excl, strict=False)

    # Validate that `masked.columns` exist in the `frame_cols`
    missing = set(masked.columns) - set(frame_cols)
    if missing:
        raise ValueError(
            "The `mask` expression produces extra columns, with names not in the original DataFrame."
            f"\n\nExtra columns: {missing}"
        )

    # Validate that row lengths are equal
    if masked.height != frame.height:
        # These must be f-strings, otherwise the placeholders are printed verbatim.
        raise ValueError(
            "The DataFrame length after applying `mask` differs from the original."
            f"\n\n* Original length: {frame.height}"
            f"\n* Mask length: {masked.height}"
        )

    cellpos_data: list[tuple[int, int, str]] = []  # column, row, colname for `CellPos`
    # Map each column name to its position in the *original* frame in a single pass.
    col_idx_map = {colname: col_idx for col_idx, colname in enumerate(frame_cols)}
    for row_idx, row_dict in enumerate(masked.iter_rows(named=True)):
        for colname, value in row_dict.items():
            if value:  # select only truthy cells (null and False are skipped)
                cellpos_data.append((col_idx_map[colname], row_idx, colname))
    return cellpos_data


# Resolve generic ======================================================================


Expand Down Expand Up @@ -868,15 +923,22 @@ def _(loc: LocStub, data: GTData) -> set[int]:

@resolve.register
def _(loc: LocBody, data: GTData) -> list[CellPos]:
    """Resolve a `LocBody` location into the list of body cells it targets.

    Cells are selected in one of two mutually exclusive ways: through `loc.mask`, which is
    handed to `resolve_mask()`, or through `loc.columns` / `loc.rows`, in which case every
    combination of the resolved columns and rows is targeted.

    Raises
    ------
    ValueError
        If `loc.mask` is given together with `loc.columns` or `loc.rows`.
    """
    if (loc.columns is not None or loc.rows is not None) and loc.mask is not None:
        raise ValueError(
            "Cannot specify the `mask` argument along with `columns` or `rows` in `loc.body()`."
        )

    if loc.mask is not None:
        # Cell-based selection: each tuple is already (column index, row index, column name).
        cellpos_data = resolve_mask(data=data, expr=loc.mask)
        return [CellPos(*cellpos) for cellpos in cellpos_data]

    rows = resolve_rows_i(data=data, expr=loc.rows)
    cols = resolve_cols_i(data=data, expr=loc.columns)
    # TODO: dplyr arranges by `Var1`, and does distinct (since you can tidyselect the same
    # thing multiple times
    return [CellPos(col[1], row[1], colname=col[0]) for col, row in itertools.product(cols, rows)]


Expand Down
76 changes: 76 additions & 0 deletions tests/test_tab_create_modify.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,19 @@
from great_tables._locations import LocBody
from great_tables._styles import CellStyleFill
from great_tables._tab_create_modify import tab_style
from polars import selectors as cs


@pytest.fixture
def gt():
    """Provide a small pandas-backed `GT` table with integer columns `x` and `y`."""
    frame = pd.DataFrame({"x": [1, 2], "y": [4, 5]})
    return GT(frame)


@pytest.fixture
def gt2():
    """Provide a small Polars-backed `GT` table with integer columns `x` and `y`."""
    frame = pl.DataFrame({"x": [1, 2], "y": [4, 5]})
    return GT(frame)


def test_tab_style(gt: GT):
style = CellStyleFill(color="blue")
new_gt = tab_style(gt, style, LocBody(["x"], [0]))
Expand Down Expand Up @@ -71,3 +77,73 @@ def test_tab_style_font_from_column():

assert rendered_html.find('<td style="font-family: Helvetica;" class="gt_row gt_right">1</td>')
assert rendered_html.find('<td style="font-family: Courier;" class="gt_row gt_right">2</td>')


def test_tab_style_loc_body_mask(gt2: GT):
    """A Polars mask passed to `LocBody` styles exactly the cells where it evaluates true."""
    fill = CellStyleFill(color="blue")
    styled = tab_style(gt2, fill, LocBody(mask=cs.numeric().gt(1.5)))

    # The input table keeps no styles; the returned table holds one entry per matching cell.
    assert len(gt2._styles) == 0
    assert len(styled._styles) == 3

    # With x = [1, 2] and y = [4, 5], only the cell (row 0, "x") fails `> 1.5`.
    expected_cells = [(0, "y"), (1, "x"), (1, "y")]
    for style_info, (rownum, colname) in zip(styled._styles, expected_cells):
        assert style_info.styles[0] is fill
        assert style_info.rownum == rownum
        assert style_info.colname == colname


def test_tab_style_loc_body_raises(gt2: GT):
    """Combining `mask=` with either `columns=` or `rows=` raises a `ValueError`."""
    style = CellStyleFill(color="blue")
    mask = cs.numeric().gt(1.5)
    err_msg = "Cannot specify the `mask` argument along with `columns` or `rows` in `loc.body()`."

    # Each entry pairs `mask=` with one of the arguments it may not be combined with.
    conflicting_kwargs = ({"columns": ["x"]}, {"rows": [0]})
    for kwargs in conflicting_kwargs:
        with pytest.raises(ValueError) as exc_info:
            tab_style(gt2, style, LocBody(mask=mask, **kwargs))

        assert err_msg in exc_info.value.args[0]


def test_tab_style_loc_body_mask_not_polars_expression_raises(gt2: GT):
    """A `mask=` value that is not a Polars expression is rejected with a `ValueError`."""
    fill = CellStyleFill(color="blue")
    not_an_expr = "fake expression"
    expected_msg = "Only Polars expressions can be passed to the `mask` argument."

    with pytest.raises(ValueError) as raised:
        tab_style(gt2, fill, LocBody(mask=not_an_expr))

    message = raised.value.args[0]
    assert expected_msg in message


def test_tab_style_loc_body_mask_columns_not_inside_raises(gt2: GT):
    """A mask whose result has a column name absent from the table raises a `ValueError`."""
    fill = CellStyleFill(color="blue")
    # `pl.len()` is not aliased to "x" or "y", the only columns in the fixture table.
    unnamed_mask = pl.len()
    expected_msg = (
        "The `mask` expression produces extra columns, with names not in the original DataFrame."
    )

    with pytest.raises(ValueError) as raised:
        tab_style(gt2, fill, LocBody(mask=unnamed_mask))

    message = raised.value.args[0]
    assert expected_msg in message


def test_tab_style_loc_body_mask_rows_not_equal_raises(gt2: GT):
    """A mask whose result has a different row count than the table raises a `ValueError`."""
    fill = CellStyleFill(color="blue")
    # Aliased to the existing column "x" so that only the row-count check can fail.
    aggregated_mask = pl.len().alias("x")
    expected_msg = "The DataFrame length after applying `mask` differs from the original."

    with pytest.raises(ValueError) as raised:
        tab_style(gt2, fill, LocBody(mask=aggregated_mask))

    message = raised.value.args[0]
    assert expected_msg in message
Loading