Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support mask= argument in LocBody #566

Merged
merged 13 commits into from
Jan 22, 2025
Merged
57 changes: 50 additions & 7 deletions great_tables/_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,7 @@

columns: SelectExpr = None
rows: RowSelectExpr = None
mask: PlExpr | None = None


@dataclass
Expand Down Expand Up @@ -823,6 +824,41 @@
)


def resolve_mask_i(
data: GTData | list[str],
expr: PlExpr,
excl_stub: bool = True,
excl_group: bool = True,
) -> list[tuple[int, int, str]]:
"""Return data for creating `CellPos`, based on expr"""
if not isinstance(expr, PlExpr):
raise ValueError("Only Polars expressions can be passed to the `mask` argument.")

Check warning on line 835 in great_tables/_locations.py

View check run for this annotation

Codecov / codecov/patch

great_tables/_locations.py#L835

Added line #L835 was not covered by tests

stub_var = data._boxhead.vars_from_type(ColInfoTypeEnum.stub)
group_var = data._boxhead.vars_from_type(ColInfoTypeEnum.row_group)
cols_excl = [*(stub_var if excl_stub else []), *(group_var if excl_group else [])]

frame: PlDataFrame = data._tbl_data
df = frame.select(expr)
Copy link
Collaborator

@machow machow Dec 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if the rules for the mask were as follows?

  • the mask expression should, when run via frame.select(mask_expr), return a DataFrame that:
    • columns: has the same or fewer columns (names must match original frame, but could be in a different order).
    • rows: has the same number of rows
  • after running the expression, we validate this right away
  • the mask is assumed to be in the original row order
  • in this case we can capture the column name and row number of each cell (e.g. with enumerate(mask_result.iter_rows())).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for summarizing the rules.

It seems that for the last part, we only need a single loop to gather all the information required by CellPos.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about this version?

For prototyping, this version uses assert for validation, which can later be replaced with raise ValueError().

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've replaced assert with raise ValueError() and added more test cases.


newer_row_attr, older_row_attr = "with_row_index", "with_row_count"
row_number_colname = "__row_number__"
row_attr = newer_row_attr if hasattr(df, newer_row_attr) else older_row_attr

# Add row numbers after `df.select()`, as the `__row_number__` column type is `UInt32`.
df_with_row_number = getattr(df, row_attr)(name=row_number_colname)

select_columns = [col for col in df.columns if col not in cols_excl]

cellpos_data: list[tuple[int, int, str]] = [] # column, row, colname for `CellPos`
for row_dict in df_with_row_number.iter_rows(named=True):
for col_idx, colname in enumerate(select_columns):
if row_dict[colname]: # select only when `row_dict[colname]` is True
row_idx = row_dict[row_number_colname]
cellpos_data.append((col_idx, row_idx, colname))
return cellpos_data


# Resolve generic ======================================================================


Expand Down Expand Up @@ -869,14 +905,21 @@
@resolve.register
def _(loc: LocBody, data: GTData) -> list[CellPos]:
cols = resolve_cols_i(data=data, expr=loc.columns)
rows = resolve_rows_i(data=data, expr=loc.rows)

# TODO: dplyr arranges by `Var1`, and does distinct (since you can tidyselect the same
# thing multiple times
cell_pos = [
CellPos(col[1], row[1], colname=col[0]) for col, row in itertools.product(cols, rows)
]
if (loc.columns is not None or loc.rows is not None) and loc.mask is not None:
raise ValueError(

Check warning on line 909 in great_tables/_locations.py

View check run for this annotation

Codecov / codecov/patch

great_tables/_locations.py#L909

Added line #L909 was not covered by tests
"Cannot specify the `mask` argument along with `columns` or `rows` in `loc.body()`."
)

if loc.mask is None:
rows = resolve_rows_i(data=data, expr=loc.rows)
jrycw marked this conversation as resolved.
Show resolved Hide resolved
# TODO: dplyr arranges by `Var1`, and does distinct (since you can tidyselect the same
# thing multiple times
cell_pos = [
CellPos(col[1], row[1], colname=col[0]) for col, row in itertools.product(cols, rows)
]
else:
cellpos_data = resolve_mask_i(data=data, expr=loc.mask)
cell_pos = [CellPos(*cellpos) for cellpos in cellpos_data]
return cell_pos


Expand Down
25 changes: 25 additions & 0 deletions tests/test_tab_create_modify.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from great_tables._locations import LocBody
from great_tables._styles import CellStyleFill
from great_tables._tab_create_modify import tab_style
from polars import selectors as cs


@pytest.fixture
Expand Down Expand Up @@ -33,6 +34,30 @@ def test_tab_style_multiple_columns(gt: GT):
assert new_gt._styles[0].styles[0] is style


def test_tab_style_loc_body_mask():
    """Check that `LocBody(mask=...)` styles exactly the cells the mask marks True.

    The mask `cs.numeric().gt(1.5)` is applied to a two-column frame
    (`x = [1, 2]`, `y = [4, 5]`), so three style entries are expected, in
    row-major order: (row 0, "y"), (row 1, "x"), (row 1, "y").
    """
    base_table = GT(pl.DataFrame({"x": [1, 2], "y": [4, 5]}))
    fill = CellStyleFill(color="blue")
    styled_table = tab_style(base_table, fill, LocBody(mask=cs.numeric().gt(1.5)))

    # `tab_style` must not mutate the table it was given.
    assert len(base_table._styles) == 0
    assert len(styled_table._styles) == 3

    expected_cells = [(0, "y"), (1, "x"), (1, "y")]
    for style_info, (rownum, colname) in zip(styled_table._styles, expected_cells):
        # Each entry carries the very same style object that was passed in.
        assert style_info.styles[0] is fill
        assert style_info.rownum == rownum
        assert style_info.colname == colname


def test_tab_style_google_font(gt: GT):
new_gt = tab_style(
gt,
Expand Down
Loading