Skip to content

Commit

Permalink
Apply code review suggestions for the mask= implementation in `LocB…
Browse files Browse the repository at this point in the history
…ody`
  • Loading branch information
jrycw committed Jan 22, 2025
1 parent 4cbce52 commit be7c141
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
23 changes: 15 additions & 8 deletions great_tables/_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,23 +850,30 @@ def resolve_mask(
cols_excl = [*(stub_var if excl_stub else []), *(group_var if excl_group else [])]

# `df.select()` raises `ColumnNotFoundError` if columns are missing from the original DataFrame.
masked = frame.select(expr)
for col_excl in cols_excl:
masked = masked.drop(col_excl, strict=False)
masked = frame.select(expr).drop(cols_excl, strict=False)

# Validate that `masked.columns` exist in the `frame_cols`
if not (set(masked.columns).issubset(set(frame_cols))):
raise ValueError("The `mask` may reference columns not in the original DataFrame.")
missing = set(masked.columns) - set(frame_cols)
if missing:
raise ValueError(
"The `mask` expression produces extra columns, with names not in the original DataFrame."
f"\n\nExtra columns: {missing}"
)

# Validate that row lengths are equal
if masked.height != frame.height:
raise ValueError("The DataFrame length after applying `mask` differs from the original.")
raise ValueError(
"The DataFrame length after applying `mask` differs from the original."
"\n\n* Original length: {frame.height}"
"\n* Mask length: {masked.height}"
)

cellpos_data: list[tuple[int, int, str]] = [] # column, row, colname for `CellPos`
col_idx_map = {colname: frame_cols.index(colname) for colname in frame_cols}
for row_idx, row_dict in enumerate(masked.iter_rows(named=True)):
for colname, value in row_dict.items():
if value: # select only when `value` is True
col_idx = frame_cols.index(colname)
col_idx = col_idx_map[colname]
cellpos_data.append((col_idx, row_idx, colname))
return cellpos_data

Expand Down Expand Up @@ -916,14 +923,14 @@ def _(loc: LocStub, data: GTData) -> set[int]:

@resolve.register
def _(loc: LocBody, data: GTData) -> list[CellPos]:
cols = resolve_cols_i(data=data, expr=loc.columns)
if (loc.columns is not None or loc.rows is not None) and loc.mask is not None:
raise ValueError(
"Cannot specify the `mask` argument along with `columns` or `rows` in `loc.body()`."
)

if loc.mask is None:
rows = resolve_rows_i(data=data, expr=loc.rows)
cols = resolve_cols_i(data=data, expr=loc.columns)
# TODO: dplyr arranges by `Var1`, and does distinct (since you can tidyselect the same
# thing multiple times
cell_pos = [
Expand Down
4 changes: 3 additions & 1 deletion tests/test_tab_create_modify.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,9 @@ def test_tab_style_loc_body_mask_not_polars_expression_raises(gt2: GT):
def test_tab_style_loc_body_mask_columns_not_inside_raises(gt2: GT):
style = CellStyleFill(color="blue")
mask = pl.len()
err_msg = "The `mask` may reference columns not in the original DataFrame."
err_msg = (
"The `mask` expression produces extra columns, with names not in the original DataFrame."
)

with pytest.raises(ValueError) as exc_info:
tab_style(gt2, style, LocBody(mask=mask))
Expand Down

0 comments on commit be7c141

Please sign in to comment.