Skip to content

Commit

Permalink
Adding attrs at the SpatialData object level (scverse#711)
Browse files Browse the repository at this point in the history
* read/write attrs to/from disk

* pass attrs to queries

* use same strategy as anndata to concat attrs

* import callable from collections

* add SpatialData.write_attrs method

* fixed tests

* fix initializers, deprecated from_elements_dict()

* fixes around attrs behavior; tests

* added root-level SpatialData versioning

* writing spatialdata software version in .zattrs root-level metadata

---------

Co-authored-by: Luca Marconato <[email protected]>
  • Loading branch information
quentinblampey and LucaMarconato authored Dec 1, 2024
1 parent 72dbffd commit 02396ea
Show file tree
Hide file tree
Showing 14 changed files with 224 additions and 48 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ and this project adheres to [Semantic Versioning][].

## [0.2.3] - 2024-09-25

### Major

- Added attributes at the SpatialData object level (`.attrs`)

### Minor

- Added `clip: bool = False` parameter to `polygon_query()` #670
Expand Down
3 changes: 2 additions & 1 deletion src/spatialdata/_core/_deepcopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def _(sdata: SpatialData) -> SpatialData:
elements_dict = {}
for _, element_name, element in sdata.gen_elements():
elements_dict[element_name] = deepcopy(element)
return SpatialData.from_elements_dict(elements_dict)
deepcopied_attrs = _deepcopy(sdata.attrs)
return SpatialData.from_elements_dict(elements_dict, attrs=deepcopied_attrs)


@deepcopy.register(DataArray)
Expand Down
10 changes: 9 additions & 1 deletion src/spatialdata/_core/concatenate.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from __future__ import annotations

from collections import defaultdict
from collections.abc import Iterable
from collections.abc import Callable, Iterable
from copy import copy # Should probably go up at the top
from itertools import chain
from typing import Any
from warnings import warn

import numpy as np
from anndata import AnnData
from anndata._core.merge import StrategiesLiteral, resolve_merge_strategy

from spatialdata._core._utils import _find_common_table_keys
from spatialdata._core.spatialdata import SpatialData
Expand Down Expand Up @@ -80,6 +81,7 @@ def concatenate(
concatenate_tables: bool = False,
obs_names_make_unique: bool = True,
modify_tables_inplace: bool = False,
attrs_merge: StrategiesLiteral | Callable[[list[dict[Any, Any]]], dict[Any, Any]] | None = None,
**kwargs: Any,
) -> SpatialData:
"""
Expand Down Expand Up @@ -108,6 +110,8 @@ def concatenate(
modify_tables_inplace
Whether to modify the tables in place. If `True`, the tables will be modified in place. If `False`, the tables
will be copied before modification. Copying is enabled by default but can be disabled for performance reasons.
attrs_merge
How the elements of `.attrs` are selected. Uses the same set of strategies as the `uns_merge` argument of [anndata.concat](https://anndata.readthedocs.io/en/latest/generated/anndata.concat.html)
kwargs
See :func:`anndata.concat` for more details.
Expand Down Expand Up @@ -188,12 +192,16 @@ def concatenate(
else:
merged_tables[k] = v

attrs_merge = resolve_merge_strategy(attrs_merge)
attrs = attrs_merge([sdata.attrs for sdata in sdatas])

sdata = SpatialData(
images=merged_images,
labels=merged_labels,
points=merged_points,
shapes=merged_shapes,
tables=merged_tables,
attrs=attrs,
)
if obs_names_make_unique:
for table in sdata.tables.values():
Expand Down
2 changes: 1 addition & 1 deletion src/spatialdata/_core/operations/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def transform_to_data_extent(
set_transformation(el, transformation={coordinate_system: Identity()}, set_all=True)
for k, v in sdata.tables.items():
sdata_to_return_elements[k] = v.copy()
return SpatialData.from_elements_dict(sdata_to_return_elements)
return SpatialData.from_elements_dict(sdata_to_return_elements, attrs=sdata.attrs)


def _parse_element(
Expand Down
2 changes: 1 addition & 1 deletion src/spatialdata/_core/operations/rasterize.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def rasterize(
new_labels[new_name] = rasterized
else:
raise RuntimeError(f"Unsupported model {model} detected as return type of rasterize().")
return SpatialData(images=new_images, labels=new_labels, tables=data.tables)
return SpatialData(images=new_images, labels=new_labels, tables=data.tables, attrs=data.attrs)

parsed_data = _parse_element(element=data, sdata=sdata, element_var_name="data", sdata_var_name="sdata")
model = get_model(parsed_data)
Expand Down
2 changes: 1 addition & 1 deletion src/spatialdata/_core/operations/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def _(
new_elements[element_type][k] = transform(
v, transformation, to_coordinate_system=to_coordinate_system, maintain_positioning=maintain_positioning
)
return SpatialData(**new_elements)
return SpatialData(**new_elements, attrs=data.attrs)


@transform.register(DataArray)
Expand Down
4 changes: 2 additions & 2 deletions src/spatialdata/_core/query/spatial_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ def _(

tables = _get_filtered_or_unfiltered_tables(filter_table, new_elements, sdata)

return SpatialData(**new_elements, tables=tables)
return SpatialData(**new_elements, tables=tables, attrs=sdata.attrs)


@bounding_box_query.register(DataArray)
Expand Down Expand Up @@ -885,7 +885,7 @@ def _(

tables = _get_filtered_or_unfiltered_tables(filter_table, new_elements, sdata)

return SpatialData(**new_elements, tables=tables)
return SpatialData(**new_elements, tables=tables, attrs=sdata.attrs)


@polygon_query.register(DataArray)
Expand Down
158 changes: 121 additions & 37 deletions src/spatialdata/_core/spatialdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import hashlib
import os
import warnings
from collections.abc import Generator
from collections.abc import Generator, Mapping
from itertools import chain
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
Expand Down Expand Up @@ -122,6 +122,7 @@ def __init__(
points: dict[str, DaskDataFrame] | None = None,
shapes: dict[str, GeoDataFrame] | None = None,
tables: dict[str, AnnData] | Tables | None = None,
attrs: Mapping[Any, Any] | None = None,
) -> None:
self._path: Path | None = None

Expand All @@ -131,6 +132,7 @@ def __init__(
self._points: Points = Points(shared_keys=self._shared_keys)
self._shapes: Shapes = Shapes(shared_keys=self._shared_keys)
self._tables: Tables = Tables(shared_keys=self._shared_keys)
self.attrs = attrs if attrs else {} # type: ignore[assignment]

# Workaround to allow for backward compatibility
if isinstance(tables, AnnData):
Expand Down Expand Up @@ -216,7 +218,9 @@ def validate_table_in_spatialdata(self, table: AnnData) -> None:
)

@staticmethod
def from_elements_dict(elements_dict: dict[str, SpatialElement | AnnData]) -> SpatialData:
def from_elements_dict(
elements_dict: dict[str, SpatialElement | AnnData], attrs: Mapping[Any, Any] | None = None
) -> SpatialData:
"""
Create a SpatialData object from a dict of elements.
Expand All @@ -225,38 +229,20 @@ def from_elements_dict(elements_dict: dict[str, SpatialElement | AnnData]) -> Sp
elements_dict
Dict of elements. The keys are the names of the elements and the values are the elements.
A table can be present in the dict, but only at most one; its name is not used and can be anything.
attrs
Additional attributes to store in the SpatialData object.
Returns
-------
The SpatialData object.
"""
d: dict[str, dict[str, SpatialElement] | AnnData | None] = {
"images": {},
"labels": {},
"points": {},
"shapes": {},
"tables": {},
}
for k, e in elements_dict.items():
schema = get_model(e)
if schema in (Image2DModel, Image3DModel):
assert isinstance(d["images"], dict)
d["images"][k] = e
elif schema in (Labels2DModel, Labels3DModel):
assert isinstance(d["labels"], dict)
d["labels"][k] = e
elif schema == PointsModel:
assert isinstance(d["points"], dict)
d["points"][k] = e
elif schema == ShapesModel:
assert isinstance(d["shapes"], dict)
d["shapes"][k] = e
elif schema == TableModel:
assert isinstance(d["tables"], dict)
d["tables"][k] = e
else:
raise ValueError(f"Unknown schema {schema}")
return SpatialData(**d) # type: ignore[arg-type]
warnings.warn(
'This method is deprecated and will be removed in a future release. Use "SpatialData.init_from_elements('
')" instead. For the momment, such methods will be automatically called.',
DeprecationWarning,
stacklevel=2,
)
return SpatialData.init_from_elements(elements=elements_dict, attrs=attrs)

@staticmethod
def get_annotated_regions(table: AnnData) -> str | list[str]:
Expand Down Expand Up @@ -712,7 +698,7 @@ def filter_by_coordinate_system(
set(), filter_tables, "cs", include_orphan_tables, element_names=element_names_in_coordinate_system
)

return SpatialData(**elements, tables=tables)
return SpatialData(**elements, tables=tables, attrs=self.attrs)

# TODO: move to relational query with refactor
def _filter_tables(
Expand Down Expand Up @@ -954,7 +940,7 @@ def transform_to_coordinate_system(
if element_type not in elements:
elements[element_type] = {}
elements[element_type][element_name] = transformed
return SpatialData(**elements, tables=sdata.tables)
return SpatialData(**elements, tables=sdata.tables, attrs=self.attrs)

def elements_are_self_contained(self) -> dict[str, bool]:
"""
Expand Down Expand Up @@ -1179,7 +1165,8 @@ def write(
self._validate_can_safely_write_to_path(file_path, overwrite=overwrite)

store = parse_url(file_path, mode="w").store
_ = zarr.group(store=store, overwrite=overwrite)
zarr_group = zarr.group(store=store, overwrite=overwrite)
self.write_attrs(zarr_group=zarr_group)
store.close()

for element_type, element_name, element in self.gen_elements():
Expand Down Expand Up @@ -1583,7 +1570,36 @@ def _element_type_and_name_from_element_path(self, element_path: str) -> tuple[s
element_type, element_name = element_path.split("/")
return element_type, element_name

def write_metadata(self, element_name: str | None = None, consolidate_metadata: bool | None = None) -> None:
def write_attrs(self, format: SpatialDataFormat | None = None, zarr_group: zarr.Group | None = None) -> None:
from spatialdata._io.format import _parse_formats

parsed = _parse_formats(formats=format)

store = None

if zarr_group is None:
assert self.is_backed(), "The SpatialData object must be backed by a Zarr store to write attrs."
store = parse_url(self.path, mode="r+").store
zarr_group = zarr.group(store=store, overwrite=False)

version = parsed["SpatialData"].spatialdata_format_version
version_specific_attrs = parsed["SpatialData"].attrs_to_dict()
attrs_to_write = {"spatialdata_attrs": {"version": version} | version_specific_attrs} | self.attrs

try:
zarr_group.attrs.put(attrs_to_write)
except TypeError as e:
raise TypeError("Invalid attribute in SpatialData.attrs") from e

if store is not None:
store.close()

def write_metadata(
self,
element_name: str | None = None,
consolidate_metadata: bool | None = None,
write_attrs: bool = True,
) -> None:
"""
Write the metadata of a single element, or of all elements, to the Zarr store, without rewriting the data.
Expand Down Expand Up @@ -1618,6 +1634,9 @@ def write_metadata(self, element_name: str | None = None, consolidate_metadata:
# TODO: write .uns['spatialdata_attrs'] metadata for AnnData.
# TODO: write .attrs['spatialdata_attrs'] metadata for DaskDataFrame.

if write_attrs:
self.write_attrs()

if consolidate_metadata is None and self.has_consolidated_metadata():
consolidate_metadata = True
if consolidate_metadata:
Expand Down Expand Up @@ -2103,9 +2122,11 @@ def _find_element(self, element_name: str) -> tuple[str, str, SpatialElement | A
return found[0]

@classmethod
@_deprecation_alias(table="tables", version="0.1.0")
def init_from_elements(
cls, elements: dict[str, SpatialElement], tables: AnnData | dict[str, AnnData] | None = None
cls,
elements: dict[str, SpatialElement],
tables: AnnData | dict[str, AnnData] | None = None,
attrs: Mapping[Any, Any] | None = None,
) -> SpatialData:
"""
Create a SpatialData object from a dict of named elements and an optional table.
Expand All @@ -2116,6 +2137,8 @@ def init_from_elements(
A dict of named elements.
tables
An optional table or dictionary of tables
attrs
Additional attributes to store in the SpatialData object.
Returns
-------
Expand All @@ -2130,11 +2153,33 @@ def init_from_elements(
element_type = "labels"
elif model == PointsModel:
element_type = "points"
elif model == TableModel:
element_type = "tables"
else:
assert model == ShapesModel
element_type = "shapes"
elements_dict.setdefault(element_type, {})[name] = element
return cls(**elements_dict, tables=tables)
# when the "tables" argument is removed, we can remove all this if block
if tables is not None:
warnings.warn(
'The "tables" argument is deprecated and will be removed in a future version. Please '
"specifies the tables in the `elements` argument. Until the removal occurs, the `elements` "
"variable will be automatically populated with the tables if the `tables` argument is not None.",
DeprecationWarning,
stacklevel=2,
)
if "tables" in elements_dict:
raise ValueError(
"The tables key is already present in the elements dictionary. Please do not specify "
"the `tables` argument."
)
elements_dict["tables"] = {}
if isinstance(tables, AnnData):
elements_dict["tables"]["table"] = tables
else:
for name, table in tables.items():
elements_dict["tables"][name] = table
return cls(**elements_dict, attrs=attrs)

def subset(
self, element_names: list[str], filter_tables: bool = True, include_orphan_tables: bool = False
Expand Down Expand Up @@ -2173,7 +2218,7 @@ def subset(
include_orphan_tables,
elements_dict=elements_dict,
)
return SpatialData(**elements_dict, tables=tables)
return SpatialData(**elements_dict, tables=tables, attrs=self.attrs)

def __getitem__(self, item: str) -> SpatialElement:
"""
Expand Down Expand Up @@ -2255,6 +2300,45 @@ def __delitem__(self, key: str) -> None:
element_type, _, _ = self._find_element(key)
getattr(self, element_type).__delitem__(key)

@property
def attrs(self) -> dict[Any, Any]:
"""
Dictionary of global attributes on this SpatialData object.
Notes
-----
Operations on SpatialData objects such as `subset()`, `query()`, ..., will pass the `.attrs` by
reference. If you want to modify the `.attrs` without affecting the original object, you should
either use `copy.deepcopy(sdata.attrs)` or eventually copy the SpatialData object using
`spatialdata.deepcopy()`.
"""
return self._attrs

@attrs.setter
def attrs(self, value: Mapping[Any, Any]) -> None:
"""
Set the global attributes on this SpatialData object.
Parameters
----------
value
The new attributes to set.
Notes
-----
If a dict is passed, the attrs will be passed by reference, else if a mapping is passed,
the mapping will be casted to a dict (shallow copy), i.e. if the mapping contains a dict inside,
that dict will be passed by reference.
"""
if isinstance(value, dict):
# even if we call dict(value), we still get a shallow copy. For example, dict({'a': {'b': 1}}) will return
# a new dict, {'b': 1} is passed by reference. For this reason, we just pass .attrs by reference, which is
# more performant. The user can always use copy.deepcopy(sdata.attrs), or spatialdata.deepcopy(sdata), to
# get the attrs deepcopied.
self._attrs = value
else:
self._attrs = dict(value)


class QueryManager:
"""Perform queries on SpatialData objects."""
Expand Down
Loading

0 comments on commit 02396ea

Please sign in to comment.