From d9c1e0e1c5858143ecf2d6648dbc5e690b3cedbb Mon Sep 17 00:00:00 2001 From: Andreas Eisenbarth Date: Mon, 8 Jul 2024 21:01:46 +0200 Subject: [PATCH] Draft SpatialData.filter() --- CHANGELOG.md | 2 + src/spatialdata/_core/spatialdata.py | 77 +++++++++++++++++++++++++++- 2 files changed, 78 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b5b5d3b3..fd792ee5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning][]. ## [0.x.x] - 2024-xx-xx +- Added `SpatialData.filter()` method for subsetting by `obs` and `var` @aeisenbarth + ## [0.2.1] - 2024-07-04 ### Minor diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 41de30e0..b4e0ddb1 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -3,11 +3,12 @@ import hashlib import os import warnings -from collections.abc import Generator +from collections.abc import Generator, Iterable from itertools import chain from pathlib import Path from typing import TYPE_CHECKING, Any, Literal +import numpy as np import pandas as pd import zarr from anndata import AnnData @@ -2063,6 +2064,80 @@ def subset( ) return SpatialData(**elements_dict, tables=tables) + def filter( + self, + elements: Iterable[str] | None = None, + regions: Iterable[str] | None = None, + tables: Iterable[str] | None = None, + obs_keys: Iterable[str] | None = None, + var_keys: Iterable[str] | None = None, + var_names: Iterable[str] | None = None, + layers: Iterable[str] | None = None, + ) -> SpatialData: + """ + Filter a SpatialData object to contain only specified elements or table entries. + + Parameters + ---------- + elements: Names of elements to include. Defaults to []. + regions: Regions to include in the table. Defaults to regions of all selected elements. + tables: Names of tables to include. Defaults to ["table"]. + obs_keys: Names of obs columns to include. Defaults to []. + var_keys: Names of var columns to include. Defaults to []. + var_names: Names of variables (X columns) to include. Defaults to []. + layers: Names of X layers to include. Defaults to []. + + Returns + ------- + A new SpatialData object + """ + elements = [] if elements is None else list(elements) + regions = elements if regions is None else regions + obs_keys = [] if obs_keys is None else obs_keys + var_keys = [] if var_keys is None else var_keys + var_names = [] if var_names is None else list(var_names) # iterable and sized + tables = ["table"] if tables is None else tables + layers = [] if layers is None else layers + + sdata_subset = self.subset(element_names=elements, filter_tables=True) if elements else SpatialData() + # We rely on `subset` returning an unbacked copy, so we don't modifying the original data. + assert not sdata_subset.is_backed() + # Further filtering on the tables + for table_name, table in list(sdata_subset.tables.items()): + if table_name not in tables: + del sdata_subset.tables[table_name] + continue + _, region_key, instance_key = get_table_keys(table) + obs_keys = list(obs_keys) + if instance_key not in obs_keys: + obs_keys.insert(0, instance_key) + if region_key not in obs_keys: + obs_keys.insert(0, region_key) + # Preserve order by checking "isin" instead of slicing. Also guarantees no duplicates. + table_subset = table[ + table.obs[region_key].isin(regions), + table.var_names.isin(var_names), + ] + layers_subset = ( + {key: layer for key, layer in table_subset.layers.items() if key in layers} + if table_subset.layers is not None and len(var_names) > 0 + else None + ) + table_subset = TableModel.parse( + AnnData( + X=table_subset.X if len(var_names) > 0 else None, + obs=table_subset.obs.loc[:, table_subset.obs.columns.isin(obs_keys)], + var=table_subset.var.loc[:, table_subset.var.columns.isin(var_keys)], + layers=layers_subset, + ), + region_key=region_key, + instance_key=instance_key, + region=np.unique(table_subset.obs[region_key]).tolist(), + ) + del sdata_subset.tables[table_name] + sdata_subset.tables[table_name] = table_subset + return sdata_subset + def __getitem__(self, item: str) -> SpatialElement: """ Return the element with the given name.