diff --git a/CHANGELOG.md b/CHANGELOG.md index a4185113..6c1deb7e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,6 +51,10 @@ and this project adheres to [Semantic Versioning][]. ## [0.2.3] - 2024-09-25 +### Major + +- Added attributes at the SpatialData object level (`.attrs`) + ### Minor - Added `clip: bool = False` parameter to `polygon_query()` #670 diff --git a/src/spatialdata/_core/_deepcopy.py b/src/spatialdata/_core/_deepcopy.py index 1ce04409..cb38d46d 100644 --- a/src/spatialdata/_core/_deepcopy.py +++ b/src/spatialdata/_core/_deepcopy.py @@ -47,7 +47,8 @@ def _(sdata: SpatialData) -> SpatialData: elements_dict = {} for _, element_name, element in sdata.gen_elements(): elements_dict[element_name] = deepcopy(element) - return SpatialData.from_elements_dict(elements_dict) + deepcopied_attrs = _deepcopy(sdata.attrs) + return SpatialData.from_elements_dict(elements_dict, attrs=deepcopied_attrs) @deepcopy.register(DataArray) diff --git a/src/spatialdata/_core/concatenate.py b/src/spatialdata/_core/concatenate.py index 4f2e1a6a..6b91e22d 100644 --- a/src/spatialdata/_core/concatenate.py +++ b/src/spatialdata/_core/concatenate.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections import defaultdict -from collections.abc import Iterable +from collections.abc import Callable, Iterable from copy import copy # Should probably go up at the top from itertools import chain from typing import Any @@ -9,6 +9,7 @@ import numpy as np from anndata import AnnData +from anndata._core.merge import StrategiesLiteral, resolve_merge_strategy from spatialdata._core._utils import _find_common_table_keys from spatialdata._core.spatialdata import SpatialData @@ -80,6 +81,7 @@ def concatenate( concatenate_tables: bool = False, obs_names_make_unique: bool = True, modify_tables_inplace: bool = False, + attrs_merge: StrategiesLiteral | Callable[[list[dict[Any, Any]]], dict[Any, Any]] | None = None, **kwargs: Any, ) -> SpatialData: """ @@ -108,6 +110,8 @@ def 
concatenate( modify_tables_inplace Whether to modify the tables in place. If `True`, the tables will be modified in place. If `False`, the tables will be copied before modification. Copying is enabled by default but can be disabled for performance reasons. + attrs_merge + How the elements of `.attrs` are selected. Uses the same set of strategies as the `uns_merge` argument of [anndata.concat](https://anndata.readthedocs.io/en/latest/generated/anndata.concat.html) kwargs See :func:`anndata.concat` for more details. @@ -188,12 +192,16 @@ def concatenate( else: merged_tables[k] = v + attrs_merge = resolve_merge_strategy(attrs_merge) + attrs = attrs_merge([sdata.attrs for sdata in sdatas]) + sdata = SpatialData( images=merged_images, labels=merged_labels, points=merged_points, shapes=merged_shapes, tables=merged_tables, + attrs=attrs, ) if obs_names_make_unique: for table in sdata.tables.values(): diff --git a/src/spatialdata/_core/operations/_utils.py b/src/spatialdata/_core/operations/_utils.py index e78fccd3..00b65500 100644 --- a/src/spatialdata/_core/operations/_utils.py +++ b/src/spatialdata/_core/operations/_utils.py @@ -134,7 +134,7 @@ def transform_to_data_extent( set_transformation(el, transformation={coordinate_system: Identity()}, set_all=True) for k, v in sdata.tables.items(): sdata_to_return_elements[k] = v.copy() - return SpatialData.from_elements_dict(sdata_to_return_elements) + return SpatialData.from_elements_dict(sdata_to_return_elements, attrs=sdata.attrs) def _parse_element( diff --git a/src/spatialdata/_core/operations/rasterize.py b/src/spatialdata/_core/operations/rasterize.py index 2270cc6e..b7bdb929 100644 --- a/src/spatialdata/_core/operations/rasterize.py +++ b/src/spatialdata/_core/operations/rasterize.py @@ -304,7 +304,7 @@ def rasterize( new_labels[new_name] = rasterized else: raise RuntimeError(f"Unsupported model {model} detected as return type of rasterize().") - return SpatialData(images=new_images, labels=new_labels, 
tables=data.tables) + return SpatialData(images=new_images, labels=new_labels, tables=data.tables, attrs=data.attrs) parsed_data = _parse_element(element=data, sdata=sdata, element_var_name="data", sdata_var_name="sdata") model = get_model(parsed_data) diff --git a/src/spatialdata/_core/operations/transform.py b/src/spatialdata/_core/operations/transform.py index 85c97537..769e19d2 100644 --- a/src/spatialdata/_core/operations/transform.py +++ b/src/spatialdata/_core/operations/transform.py @@ -299,7 +299,7 @@ def _( new_elements[element_type][k] = transform( v, transformation, to_coordinate_system=to_coordinate_system, maintain_positioning=maintain_positioning ) - return SpatialData(**new_elements) + return SpatialData(**new_elements, attrs=data.attrs) @transform.register(DataArray) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index 83d82659..ecba815e 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -529,7 +529,7 @@ def _( tables = _get_filtered_or_unfiltered_tables(filter_table, new_elements, sdata) - return SpatialData(**new_elements, tables=tables) + return SpatialData(**new_elements, tables=tables, attrs=sdata.attrs) @bounding_box_query.register(DataArray) @@ -885,7 +885,7 @@ def _( tables = _get_filtered_or_unfiltered_tables(filter_table, new_elements, sdata) - return SpatialData(**new_elements, tables=tables) + return SpatialData(**new_elements, tables=tables, attrs=sdata.attrs) @polygon_query.register(DataArray) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 2e6ad4bc..db9b91ab 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -3,7 +3,7 @@ import hashlib import os import warnings -from collections.abc import Generator +from collections.abc import Generator, Mapping from itertools import chain from pathlib import Path from typing import 
TYPE_CHECKING, Any, Literal @@ -122,6 +122,7 @@ def __init__( points: dict[str, DaskDataFrame] | None = None, shapes: dict[str, GeoDataFrame] | None = None, tables: dict[str, AnnData] | Tables | None = None, + attrs: Mapping[Any, Any] | None = None, ) -> None: self._path: Path | None = None @@ -131,6 +132,7 @@ def __init__( self._points: Points = Points(shared_keys=self._shared_keys) self._shapes: Shapes = Shapes(shared_keys=self._shared_keys) self._tables: Tables = Tables(shared_keys=self._shared_keys) + self.attrs = attrs if attrs else {} # type: ignore[assignment] # Workaround to allow for backward compatibility if isinstance(tables, AnnData): @@ -216,7 +218,9 @@ def validate_table_in_spatialdata(self, table: AnnData) -> None: ) @staticmethod - def from_elements_dict(elements_dict: dict[str, SpatialElement | AnnData]) -> SpatialData: + def from_elements_dict( + elements_dict: dict[str, SpatialElement | AnnData], attrs: Mapping[Any, Any] | None = None + ) -> SpatialData: """ Create a SpatialData object from a dict of elements. @@ -225,38 +229,20 @@ def from_elements_dict(elements_dict: dict[str, SpatialElement | AnnData]) -> Sp elements_dict Dict of elements. The keys are the names of the elements and the values are the elements. A table can be present in the dict, but only at most one; its name is not used and can be anything. + attrs + Additional attributes to store in the SpatialData object. Returns ------- The SpatialData object. 
""" - d: dict[str, dict[str, SpatialElement] | AnnData | None] = { - "images": {}, - "labels": {}, - "points": {}, - "shapes": {}, - "tables": {}, - } - for k, e in elements_dict.items(): - schema = get_model(e) - if schema in (Image2DModel, Image3DModel): - assert isinstance(d["images"], dict) - d["images"][k] = e - elif schema in (Labels2DModel, Labels3DModel): - assert isinstance(d["labels"], dict) - d["labels"][k] = e - elif schema == PointsModel: - assert isinstance(d["points"], dict) - d["points"][k] = e - elif schema == ShapesModel: - assert isinstance(d["shapes"], dict) - d["shapes"][k] = e - elif schema == TableModel: - assert isinstance(d["tables"], dict) - d["tables"][k] = e - else: - raise ValueError(f"Unknown schema {schema}") - return SpatialData(**d) # type: ignore[arg-type] + warnings.warn( + 'This method is deprecated and will be removed in a future release. Use "SpatialData.init_from_elements(' + ')" instead. For the momment, such methods will be automatically called.', + DeprecationWarning, + stacklevel=2, + ) + return SpatialData.init_from_elements(elements=elements_dict, attrs=attrs) @staticmethod def get_annotated_regions(table: AnnData) -> str | list[str]: @@ -712,7 +698,7 @@ def filter_by_coordinate_system( set(), filter_tables, "cs", include_orphan_tables, element_names=element_names_in_coordinate_system ) - return SpatialData(**elements, tables=tables) + return SpatialData(**elements, tables=tables, attrs=self.attrs) # TODO: move to relational query with refactor def _filter_tables( @@ -954,7 +940,7 @@ def transform_to_coordinate_system( if element_type not in elements: elements[element_type] = {} elements[element_type][element_name] = transformed - return SpatialData(**elements, tables=sdata.tables) + return SpatialData(**elements, tables=sdata.tables, attrs=self.attrs) def elements_are_self_contained(self) -> dict[str, bool]: """ @@ -1179,7 +1165,8 @@ def write( self._validate_can_safely_write_to_path(file_path, overwrite=overwrite) 
store = parse_url(file_path, mode="w").store - _ = zarr.group(store=store, overwrite=overwrite) + zarr_group = zarr.group(store=store, overwrite=overwrite) + self.write_attrs(zarr_group=zarr_group) store.close() for element_type, element_name, element in self.gen_elements(): @@ -1583,7 +1570,36 @@ def _element_type_and_name_from_element_path(self, element_path: str) -> tuple[s element_type, element_name = element_path.split("/") return element_type, element_name - def write_metadata(self, element_name: str | None = None, consolidate_metadata: bool | None = None) -> None: + def write_attrs(self, format: SpatialDataFormat | None = None, zarr_group: zarr.Group | None = None) -> None: + from spatialdata._io.format import _parse_formats + + parsed = _parse_formats(formats=format) + + store = None + + if zarr_group is None: + assert self.is_backed(), "The SpatialData object must be backed by a Zarr store to write attrs." + store = parse_url(self.path, mode="r+").store + zarr_group = zarr.group(store=store, overwrite=False) + + version = parsed["SpatialData"].spatialdata_format_version + version_specific_attrs = parsed["SpatialData"].attrs_to_dict() + attrs_to_write = {"spatialdata_attrs": {"version": version} | version_specific_attrs} | self.attrs + + try: + zarr_group.attrs.put(attrs_to_write) + except TypeError as e: + raise TypeError("Invalid attribute in SpatialData.attrs") from e + + if store is not None: + store.close() + + def write_metadata( + self, + element_name: str | None = None, + consolidate_metadata: bool | None = None, + write_attrs: bool = True, + ) -> None: """ Write the metadata of a single element, or of all elements, to the Zarr store, without rewriting the data. @@ -1618,6 +1634,9 @@ def write_metadata(self, element_name: str | None = None, consolidate_metadata: # TODO: write .uns['spatialdata_attrs'] metadata for AnnData. # TODO: write .attrs['spatialdata_attrs'] metadata for DaskDataFrame. 
+ if write_attrs: + self.write_attrs() + if consolidate_metadata is None and self.has_consolidated_metadata(): consolidate_metadata = True if consolidate_metadata: @@ -2103,9 +2122,11 @@ def _find_element(self, element_name: str) -> tuple[str, str, SpatialElement | A return found[0] @classmethod - @_deprecation_alias(table="tables", version="0.1.0") def init_from_elements( - cls, elements: dict[str, SpatialElement], tables: AnnData | dict[str, AnnData] | None = None + cls, + elements: dict[str, SpatialElement], + tables: AnnData | dict[str, AnnData] | None = None, + attrs: Mapping[Any, Any] | None = None, ) -> SpatialData: """ Create a SpatialData object from a dict of named elements and an optional table. @@ -2116,6 +2137,8 @@ def init_from_elements( A dict of named elements. tables An optional table or dictionary of tables + attrs + Additional attributes to store in the SpatialData object. Returns ------- @@ -2130,11 +2153,33 @@ def init_from_elements( element_type = "labels" elif model == PointsModel: element_type = "points" + elif model == TableModel: + element_type = "tables" else: assert model == ShapesModel element_type = "shapes" elements_dict.setdefault(element_type, {})[name] = element - return cls(**elements_dict, tables=tables) + # when the "tables" argument is removed, we can remove all this if block + if tables is not None: + warnings.warn( + 'The "tables" argument is deprecated and will be removed in a future version. Please ' + "specify the tables in the `elements` argument. Until the removal occurs, the `elements` " + "variable will be automatically populated with the tables if the `tables` argument is not None.", + DeprecationWarning, + stacklevel=2, + ) + if "tables" in elements_dict: + raise ValueError( + "The tables key is already present in the elements dictionary. Please do not specify " + "the `tables` argument."
+ ) + elements_dict["tables"] = {} + if isinstance(tables, AnnData): + elements_dict["tables"]["table"] = tables + else: + for name, table in tables.items(): + elements_dict["tables"][name] = table + return cls(**elements_dict, attrs=attrs) def subset( self, element_names: list[str], filter_tables: bool = True, include_orphan_tables: bool = False @@ -2173,7 +2218,7 @@ def subset( include_orphan_tables, elements_dict=elements_dict, ) - return SpatialData(**elements_dict, tables=tables) + return SpatialData(**elements_dict, tables=tables, attrs=self.attrs) def __getitem__(self, item: str) -> SpatialElement: """ @@ -2255,6 +2300,45 @@ def __delitem__(self, key: str) -> None: element_type, _, _ = self._find_element(key) getattr(self, element_type).__delitem__(key) + @property + def attrs(self) -> dict[Any, Any]: + """ + Dictionary of global attributes on this SpatialData object. + + Notes + ----- + Operations on SpatialData objects such as `subset()`, `query()`, ..., will pass the `.attrs` by + reference. If you want to modify the `.attrs` without affecting the original object, you should + either use `copy.deepcopy(sdata.attrs)` or eventually copy the SpatialData object using + `spatialdata.deepcopy()`. + """ + return self._attrs + + @attrs.setter + def attrs(self, value: Mapping[Any, Any]) -> None: + """ + Set the global attributes on this SpatialData object. + + Parameters + ---------- + value + The new attributes to set. + + Notes + ----- + If a dict is passed, the attrs will be passed by reference, else if a mapping is passed, + the mapping will be cast to a dict (shallow copy), i.e. if the mapping contains a dict inside, + that dict will be passed by reference. + """ + if isinstance(value, dict): + # even if we call dict(value), we still get a shallow copy. For example, dict({'a': {'b': 1}}) will return + # a new dict, {'b': 1} is passed by reference. For this reason, we just pass .attrs by reference, which is + # more performant.
The user can always use copy.deepcopy(sdata.attrs), or spatialdata.deepcopy(sdata), to + # get the attrs deepcopied. + self._attrs = value + else: + self._attrs = dict(value) + class QueryManager: """Perform queries on SpatialData objects.""" diff --git a/src/spatialdata/_io/format.py b/src/spatialdata/_io/format.py index b992c287..abd1700e 100644 --- a/src/spatialdata/_io/format.py +++ b/src/spatialdata/_io/format.py @@ -48,6 +48,20 @@ class SpatialDataFormat(CurrentFormat): pass +class SpatialDataContainerFormatV01(SpatialDataFormat): + @property + def spatialdata_format_version(self) -> str: + return "0.1" + + def attrs_from_dict(self, metadata: dict[str, Any]) -> dict[str, Any]: + return {} + + def attrs_to_dict(self) -> dict[str, str | dict[str, Any]]: + from spatialdata import __version__ + + return {"spatialdata_software_version": __version__} + + class RasterFormatV01(SpatialDataFormat): """Formatter for raster data.""" @@ -201,6 +215,7 @@ def validate_table( CurrentShapesFormat = ShapesFormatV02 CurrentPointsFormat = PointsFormatV01 CurrentTablesFormat = TablesFormatV01 +CurrentSpatialDataContainerFormats = SpatialDataContainerFormatV01 ShapesFormats = { "0.1": ShapesFormatV01(), @@ -215,6 +230,9 @@ def validate_table( RasterFormats = { "0.1": RasterFormatV01(), } +SpatialDataContainerFormats = { + "0.1": SpatialDataContainerFormatV01(), +} def _parse_formats(formats: SpatialDataFormat | list[SpatialDataFormat] | None) -> dict[str, SpatialDataFormat]: @@ -223,6 +241,7 @@ def _parse_formats(formats: SpatialDataFormat | list[SpatialDataFormat] | None) "shapes": CurrentShapesFormat(), "points": CurrentPointsFormat(), "tables": CurrentTablesFormat(), + "SpatialData": CurrentSpatialDataContainerFormats(), } if formats is None: return parsed @@ -236,6 +255,7 @@ def _parse_formats(formats: SpatialDataFormat | list[SpatialDataFormat] | None) "shapes": False, "points": False, "tables": False, + "SpatialData": False, } def _check_modified(element_type: str) -> None: 
@@ -256,6 +276,9 @@ def _check_modified(element_type: str) -> None: elif any(isinstance(fmt, type(v)) for v in RasterFormats.values()): _check_modified("raster") parsed["raster"] = fmt + elif any(isinstance(fmt, type(v)) for v in SpatialDataContainerFormats.values()): + _check_modified("SpatialData") + parsed["SpatialData"] = fmt else: raise ValueError(f"Unsupported format {fmt}") return parsed diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 8dc2788a..0be7c8f4 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -138,12 +138,22 @@ def read_zarr(store: str | Path | zarr.Group, selection: None | tuple[str] = Non logger.debug(f"Found {count} elements in {group}") + # read attrs metadata + attrs = f.attrs.asdict() + if "spatialdata_attrs" in attrs: + # when refactoring the read_zarr function into reading components separately (and according to the version), + # we can move the code below (.pop()) into attrs_from_dict() + attrs.pop("spatialdata_attrs") + else: + attrs = None + sdata = SpatialData( images=images, labels=labels, points=points, shapes=shapes, tables=tables, + attrs=attrs, ) sdata.path = Path(store) return sdata diff --git a/tests/core/operations/test_rasterize.py b/tests/core/operations/test_rasterize.py index 9ce5618b..ea8a9d63 100644 --- a/tests/core/operations/test_rasterize.py +++ b/tests/core/operations/test_rasterize.py @@ -205,7 +205,7 @@ def test_rasterize_shapes(): ) adata.obs["cat_values"] = adata.obs["cat_values"].astype("category") adata = TableModel.parse(adata, region=element_name, region_key="region", instance_key="instance_id") - sdata = SpatialData.init_from_elements({element_name: gdf[["geometry"]]}, table=adata) + sdata = SpatialData.init_from_elements({element_name: gdf[["geometry"]], "table": adata}) def _rasterize(element: GeoDataFrame, **kwargs) -> SpatialImage: return _rasterize_test_alternative_calls(element=element, sdata=sdata, element_name=element_name, **kwargs)
@@ -320,7 +320,7 @@ def test_rasterize_points(): ) adata.obs["gene"] = adata.obs["gene"].astype("category") adata = TableModel.parse(adata, region=element_name, region_key="region", instance_key="instance_id") - sdata = SpatialData.init_from_elements({element_name: ddf[["x", "y"]]}, table=adata) + sdata = SpatialData.init_from_elements({element_name: ddf[["x", "y"]], "table": adata}) def _rasterize(element: DaskDataFrame, **kwargs) -> SpatialImage: return _rasterize_test_alternative_calls(element=element, sdata=sdata, element_name=element_name, **kwargs) diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index a0d7ea2c..bf63fc4b 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -389,9 +389,15 @@ def test_no_shared_transformations() -> None: def test_init_from_elements(full_sdata: SpatialData) -> None: + # this first code block needs to be removed when the tables argument is removed from init_from_elements() all_elements = {name: el for _, name, el in full_sdata._gen_elements()} - sdata = SpatialData.init_from_elements(all_elements, table=full_sdata.table) - for element_type in ["images", "labels", "points", "shapes"]: + sdata = SpatialData.init_from_elements(all_elements, tables=full_sdata["table"]) + for element_type in ["images", "labels", "points", "shapes", "tables"]: + assert set(getattr(sdata, element_type).keys()) == set(getattr(full_sdata, element_type).keys()) + + all_elements = {name: el for _, name, el in full_sdata._gen_elements(include_table=True)} + sdata = SpatialData.init_from_elements(all_elements) + for element_type in ["images", "labels", "points", "shapes", "tables"]: assert set(getattr(sdata, element_type).keys()) == set(getattr(full_sdata, element_type).keys()) diff --git a/tests/core/test_deepcopy.py b/tests/core/test_deepcopy.py index 8e1c427a..b21cc925 100644 --- a/tests/core/test_deepcopy.py 
+++ b/tests/core/test_deepcopy.py @@ -1,5 +1,6 @@ from pandas.testing import assert_frame_equal +from spatialdata import SpatialData from spatialdata._core._deepcopy import deepcopy as _deepcopy from spatialdata.testing import assert_spatial_data_objects_are_identical @@ -45,3 +46,18 @@ def test_deepcopy(full_sdata): assert_spatial_data_objects_are_identical(full_sdata, copied) assert_spatial_data_objects_are_identical(full_sdata, copied_again) + + +def test_deepcopy_attrs(points: SpatialData) -> None: + some_attrs = {"a": {"b": 0}} + points.attrs = some_attrs + + # before deepcopy + sub_points = points.subset(["points_0"]) + assert sub_points.attrs is some_attrs + assert sub_points.attrs["a"] is some_attrs["a"] + + # after deepcopy + sub_points_deepcopy = _deepcopy(sub_points) + assert sub_points_deepcopy.attrs is not some_attrs + assert sub_points_deepcopy.attrs["a"] is not some_attrs["a"] diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index cc54fe04..abfe4eaa 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -581,6 +581,30 @@ def test_incremental_io_valid_name(points: SpatialData) -> None: _check_valid_name(points.delete_element_from_disk) +def test_incremental_io_attrs(points: SpatialData) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + f = os.path.join(tmpdir, "data.zarr") + my_attrs = {"a": "b", "c": 1} + points.attrs = my_attrs + points.write(f) + + # test that the attributes are written to disk + sdata = SpatialData.read(f) + assert sdata.attrs == my_attrs + + # test incremental io attrs (write_attrs()) + sdata.attrs["c"] = 2 + sdata.write_attrs() + sdata2 = SpatialData.read(f) + assert sdata2.attrs["c"] == 2 + + # test incremental io attrs (write_metadata()) + sdata.attrs["c"] = 3 + sdata.write_metadata() + sdata2 = SpatialData.read(f) + assert sdata2.attrs["c"] == 3 + + cached_sdata_blobs = blobs()