From 9b0f19db306f9dc91e50887ca464c6b629f5f512 Mon Sep 17 00:00:00 2001 From: giovp Date: Wed, 4 Sep 2024 18:48:03 -0700 Subject: [PATCH 1/5] remove None --- src/spatialdata/_core/query/relational_query.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 76de711c..b16f9bea 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -145,8 +145,8 @@ def _( # TODO: replace function use throughout repo by `join_sdata_spatialelement_table` def _filter_table_by_elements( - table: AnnData | None, elements_dict: dict[str, dict[str, Any]], match_rows: bool = False -) -> AnnData | None: + table: AnnData, elements_dict: dict[str, dict[str, Any]], match_rows: bool = False +) -> AnnData: """ Filter an AnnData table to keep only the rows that are in the elements. @@ -168,8 +168,6 @@ def _filter_table_by_elements( assert any( len(elements) > 0 for elements in elements_dict.values() ), "elements_dict must contain at least one dict which contains at least one element" - if table is None: - return None to_keep = np.zeros(len(table), dtype=bool) region_key = table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY] instance_key = table.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] From e935dc0f0f8544e3008d51e8e2e22dd5bf9b74d8 Mon Sep 17 00:00:00 2001 From: giovp Date: Wed, 4 Sep 2024 18:55:50 -0700 Subject: [PATCH 2/5] refactor --- .../_core/query/relational_query.py | 81 ++++++++++++------- 1 file changed, 50 insertions(+), 31 deletions(-) diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index b16f9bea..03cf24b3 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -163,40 +163,59 @@ def _filter_table_by_elements( ------- The filtered table (eventually with reordered rows), or None if the input table was None. """ - assert set(elements_dict.keys()).issubset({"images", "labels", "shapes", "points"}) - assert len(elements_dict) > 0, "elements_dict must not be empty" - assert any( - len(elements) > 0 for elements in elements_dict.values() - ), "elements_dict must contain at least one dict which contains at least one element" + + def _validate_elements_dict(elements_dict: dict[str, dict[str, Any]]) -> None: + assert set(elements_dict.keys()).issubset({"images", "labels", "shapes", "points"}) + assert len(elements_dict) > 0, "elements_dict must not be empty" + assert any( + len(elements) > 0 for elements in elements_dict.values() + ), "elements_dict must contain at least one dict which contains at least one element" + + def _get_table_keys(table: AnnData) -> tuple[str, str]: + return ( + table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY], + table.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY], + ) + + def _get_element_instances(element: SpatialElement) -> ArrayLike | None: + if get_model(element) in [Labels2DModel, Labels3DModel]: + if isinstance(element, DataArray): + instances = da.unique(element.data).compute() + else: + assert isinstance(element, DataTree) + v = element["scale0"].values() + assert len(v) == 1 + xdata = next(iter(v)) + instances = da.unique(xdata.data).compute() + return np.sort(instances) + if get_model(element) == ShapesModel: + return element.index.to_numpy() + if get_model(element) == PointsModel: + return element.compute().index.to_numpy() + return None + + def _get_matching_indices( + table: AnnData, region_key: str, instance_key: str, name: str, instances: ArrayLike + ) -> ArrayLike: + return ((table.obs[region_key] == name) & (table.obs[instance_key].isin(instances))).to_numpy() + + def _filter_table(table: AnnData, to_keep: ArrayLike) -> AnnData: + table.obs = pd.DataFrame(table.obs) + return table[to_keep, :] + + _validate_elements_dict(elements_dict) to_keep = np.zeros(len(table), dtype=bool) - region_key = table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY] - instance_key = table.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] - instances = None - for _, elements in elements_dict.items(): + region_key, instance_key = _get_table_keys(table) + + for elements in elements_dict.values(): for name, element in elements.items(): - if get_model(element) == Labels2DModel or get_model(element) == Labels3DModel: - if isinstance(element, DataArray): - # get unique labels value (including 0 if present) - instances = da.unique(element.data).compute() - else: - assert isinstance(element, DataTree) - v = element["scale0"].values() - assert len(v) == 1 - xdata = next(iter(v)) - # can be slow - instances = da.unique(xdata.data).compute() - instances = np.sort(instances) - elif get_model(element) == ShapesModel: - instances = element.index.to_numpy() - elif get_model(element) == PointsModel: - instances = element.compute().index.to_numpy() - else: - continue - indices = ((table.obs[region_key] == name) & (table.obs[instance_key].isin(instances))).to_numpy() - to_keep = to_keep | indices + instances = _get_element_instances(element) + if instances is not None: + indices = _get_matching_indices(table, region_key, instance_key, name, instances) + to_keep |= indices + original_table = table - table.obs = pd.DataFrame(table.obs) - table = table[to_keep, :] + table = _filter_table(table, to_keep) if match_rows: assert instances is not None assert isinstance(instances, np.ndarray) From a8c0dd3d8b7396f98fdd42b8e06f4cca33d149a0 Mon Sep 17 00:00:00 2001 From: giovp Date: Wed, 4 Sep 2024 19:12:56 -0700 Subject: [PATCH 3/5] refactor --- src/spatialdata/_core/query/relational_query.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 03cf24b3..1b66aab5 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -145,7 +145,7 @@ def _( # TODO: replace function use throughout repo by `join_sdata_spatialelement_table` def _filter_table_by_elements( - table: AnnData, elements_dict: dict[str, dict[str, Any]], match_rows: bool = False + table: AnnData | list[AnnData], elements_dict: dict[str, dict[str, Any]], match_rows: bool = False ) -> AnnData: """ Filter an AnnData table to keep only the rows that are in the elements. @@ -216,6 +216,7 @@ def _filter_table(table: AnnData, to_keep: ArrayLike) -> AnnData: original_table = table table = _filter_table(table, to_keep) + if match_rows: assert instances is not None assert isinstance(instances, np.ndarray) From 28fd89139c2bc08aa98e1f105d70de41416bac51 Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 9 Sep 2024 17:39:16 -0700 Subject: [PATCH 4/5] fix element assertion --- .../_core/query/relational_query.py | 32 ++++--------------- 1 file changed, 6 insertions(+), 26 deletions(-) diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 1b66aab5..f43f4ee5 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -110,8 +110,10 @@ def get_element_instances( def _( element: DataArray | DataTree, return_background: bool = False, -) -> pd.Index: +) -> pd.Index | None: model = get_model(element) + if model in [Image2DModel, Image3DModel]: + return None assert model in [Labels2DModel, Labels3DModel], "Expected a `Labels` element. Found an `Image` instead." if isinstance(element, DataArray): # get unique labels value (including 0 if present) @@ -171,29 +173,6 @@ def _validate_elements_dict(elements_dict: dict[str, dict[str, Any]]) -> None: len(elements) > 0 for elements in elements_dict.values() ), "elements_dict must contain at least one dict which contains at least one element" - def _get_table_keys(table: AnnData) -> tuple[str, str]: - return ( - table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY], - table.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY], - ) - - def _get_element_instances(element: SpatialElement) -> ArrayLike | None: - if get_model(element) in [Labels2DModel, Labels3DModel]: - if isinstance(element, DataArray): - instances = da.unique(element.data).compute() - else: - assert isinstance(element, DataTree) - v = element["scale0"].values() - assert len(v) == 1 - xdata = next(iter(v)) - instances = da.unique(xdata.data).compute() - return np.sort(instances) - if get_model(element) == ShapesModel: - return element.index.to_numpy() - if get_model(element) == PointsModel: - return element.compute().index.to_numpy() - return None - def _get_matching_indices( table: AnnData, region_key: str, instance_key: str, name: str, instances: ArrayLike ) -> ArrayLike: @@ -205,11 +184,12 @@ def _filter_table(table: AnnData, to_keep: ArrayLike) -> AnnData: _validate_elements_dict(elements_dict) to_keep = np.zeros(len(table), dtype=bool) - region_key, instance_key = _get_table_keys(table) + _, region_key, instance_key = get_table_keys(table) for elements in elements_dict.values(): for name, element in elements.items(): - instances = _get_element_instances(element) + model = get_model(element) + instances = get_element_instances(element) if instances is not None: indices = _get_matching_indices(table, region_key, instance_key, name, instances) to_keep |= indices From 1d0519e6d70ba55f39c6d222fa0b55dd0526d5c4 Mon Sep 17 00:00:00 2001 From: giovp Date: Sun, 15 Sep 2024 19:27:59 -0700 Subject: [PATCH 5/5] fix bug of filtering --- .../_core/query/relational_query.py | 25 ------------------- src/spatialdata/_core/spatialdata.py | 11 ++++++-- .../operations/test_spatialdata_operations.py | 12 ++++++--- 3 files changed, 18 insertions(+), 30 deletions(-) diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index f43f4ee5..9ed18071 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -58,31 +58,6 @@ def get_element_annotators(sdata: SpatialData, element_name: str) -> set[str]: return table_names -def _filter_table_by_element_names(table: AnnData | None, element_names: str | list[str]) -> AnnData | None: - """ - Filter an AnnData table to keep only the rows that are in the coordinate system. - - Parameters - ---------- - table - The table to filter; if None, returns None - element_names - The element_names to keep in the tables obs.region column - - Returns - ------- - The filtered table, or None if the input table was None - """ - if table is None or not table.uns.get(TableModel.ATTRS_KEY): - return None - table_mapping_metadata = table.uns[TableModel.ATTRS_KEY] - region_key = table_mapping_metadata[TableModel.REGION_KEY_KEY] - table.obs = pd.DataFrame(table.obs) - table = table[table.obs[region_key].isin(element_names)].copy() - table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = table.obs[region_key].unique().tolist() - return table - - @singledispatch def get_element_instances( element: SpatialElement, diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 1c3affa2..6fa01b48 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -735,10 +735,17 @@ def _filter_tables( continue # each mode here requires paths or elements, using assert here to avoid mypy errors. if by == "cs": - from spatialdata._core.query.relational_query import _filter_table_by_element_names + from spatialdata._core.query.relational_query import _filter_table_by_elements assert element_names is not None - table = _filter_table_by_element_names(table, element_names) + elements_dict = {} + for element_type in ["images", "labels", "shapes", "points"]: + elements = getattr(self, element_type) + if elements: # Check if the dictionary is not empty + elements_dict[element_type] = { + name: elements[name] for name in element_names if name in elements + } + table = _filter_table_by_elements(table, elements_dict=elements_dict) if len(table) != 0: tables[table_name] = table elif by == "elements": diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index 8a59147f..19a055a8 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -135,9 +135,15 @@ def test_filter_by_coordinate_system(full_sdata: SpatialData) -> None: def test_filter_by_coordinate_system_also_table(full_sdata: SpatialData) -> None: from spatialdata.models import TableModel - rng = np.random.default_rng(seed=0) - full_sdata["table"].obs["annotated_shapes"] = rng.choice(["circles", "poly"], size=full_sdata["table"].shape[0]) - adata = full_sdata["table"] + adata = full_sdata["table"].copy() + + circles_instances = full_sdata["circles"].index.values + poly_instances = full_sdata["poly"].index.values + + adata = adata[: len(circles_instances) + len(poly_instances), :].copy() + adata.obs["annotated_shapes"] = ["circles"] * len(circles_instances) + ["poly"] * len(poly_instances) + adata.obs["instance_id"] = np.concatenate([circles_instances, poly_instances]) + del adata.uns[TableModel.ATTRS_KEY] del full_sdata.tables["table"] full_sdata.table = TableModel.parse(