Skip to content

Commit

Permalink
Merge pull request #56 from gustaveroussy/dev
Browse files Browse the repository at this point in the history
overlay segmentation #55
  • Loading branch information
quentinblampey authored Apr 22, 2024
2 parents abfa6f2 + 2546693 commit 580c8e6
Show file tree
Hide file tree
Showing 11 changed files with 197 additions and 17 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
## [1.0.11] - 2024-xx-xx

### Added
- Can overlay a custom segmentation (merge boundaries)
- Xenium Explorer selection(s) can be added as shapes in a SpatialData object

### Changed
- Rename `Aggregator.update_table` to `Aggregator.compute_table`

## [1.0.10] - 2024-04-08

### Added
Expand Down
4 changes: 4 additions & 0 deletions docs/api/io.explorer.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,7 @@
::: sopa.io.explorer.save_column_csv
options:
show_root_heading: true

::: sopa.io.explorer.add_explorer_selection
options:
show_root_heading: true
4 changes: 4 additions & 0 deletions docs/api/segmentation/aggregate.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
::: sopa.segmentation.aggregate.overlay_segmentation
options:
show_root_heading: true

::: sopa.segmentation.aggregate.average_channels
options:
show_root_heading: true
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/api_usage.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@
"source": [
"aggregator = sopa.segmentation.Aggregator(sdata, image_key=image_key, shapes_key=shapes_key)\n",
"\n",
"aggregator.update_table(gene_column=gene_column, average_intensities=True)"
"aggregator.compute_table(gene_column=gene_column, average_intensities=True)"
]
},
{
Expand Down
2 changes: 2 additions & 0 deletions sopa/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ class SopaKeys:
BAYSOR_BOUNDARIES = "baysor_boundaries"
PATCHES = "sopa_patches"
TABLE = "table"
OLD_TABLE = "old_table"
CELL_OVERLAY_KEY = "is_overlay"

BOUNDS = "bboxes"
PATCHES_ILOCS = "ilocs"
Expand Down
2 changes: 1 addition & 1 deletion sopa/cli/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def aggregate(
sdata = read_zarr_standardized(sdata_path, warn=True)

aggregator = Aggregator(sdata, image_key=image_key, shapes_key=method_name)
aggregator.update_table(
aggregator.compute_table(
gene_column, average_intensities, expand_radius_ratio, min_transcripts, min_intensity_ratio
)

Expand Down
2 changes: 1 addition & 1 deletion sopa/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .explorer import write, align
from .explorer import write, align, add_explorer_selection
from .standardize import write_standardized
from .reader.cosmx import cosmx
from .reader.merscope import merscope
Expand Down
2 changes: 1 addition & 1 deletion sopa/io/explorer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
from .table import write_cell_categories, write_gene_counts, save_column_csv
from .shapes import write_polygons
from .converter import write, write_metadata
from .utils import str_cell_id, int_cell_id
from .utils import str_cell_id, int_cell_id, add_explorer_selection
53 changes: 53 additions & 0 deletions sopa/io/explorer/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
from __future__ import annotations

from pathlib import Path

import geopandas as gpd
import pandas as pd
from shapely import Polygon
from spatialdata import SpatialData
from spatialdata.models import ShapesModel
from spatialdata.transformations import get_transformation

from ..._sdata import get_element


def explorer_file_path(path: str, filename: str, is_dir: bool):
path: Path = Path(path)
Expand Down Expand Up @@ -28,3 +39,45 @@ def str_cell_id(cell_id: int) -> str:
cell_id, coef = divmod(cell_id, 16)
coefs.append(coef)
return "".join([chr(97 + coef) for coef in coefs][::-1]) + "-1"


def _selection_to_polygon(df, pixel_size):
return Polygon(df[["X", "Y"]].values / pixel_size)


def xenium_explorer_selection(
path: str | Path, pixel_size: float = 0.2125, return_list: bool = False
) -> Polygon:
df = pd.read_csv(path, skiprows=2)

if "Selection" not in df:
polygon = _selection_to_polygon(df, pixel_size)
return [polygon] if return_list else polygon

return [_selection_to_polygon(sub_df, pixel_size) for _, sub_df in df.groupby("Selection")]


def add_explorer_selection(
sdata: SpatialData,
path: str,
shapes_key: str,
image_key: str | None = None,
pixel_size: float = 0.2125,
):
"""After saving a selection on the Xenium Explorer, it will add all polygons inside `sdata.shapes[shapes_key]`
Args:
sdata: A `SpatialData` object
path: The path to the `coordinates.csv` selection file
shapes_key: The name to provide to the shapes
image_key: The original image name
pixel_size: Number of microns in a pixel. It must be the same value as the one used in `sopa.io.write`
"""
polys = xenium_explorer_selection(path, pixel_size=pixel_size, return_list=True)
image = get_element(sdata, "images", image_key)

transformations = get_transformation(image, get_all=True).copy()

sdata.shapes[shapes_key] = ShapesModel.parse(
gpd.GeoDataFrame(geometry=polys), transformations=transformations
)
2 changes: 1 addition & 1 deletion sopa/segmentation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import shapes, aggregate, methods, patching, stainings
from .patching import Patches2D, BaysorPatches
from .aggregate import Aggregator
from .aggregate import Aggregator, overlay_segmentation
from .stainings import StainingSegmentation
132 changes: 120 additions & 12 deletions sopa/segmentation/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
from functools import partial

import anndata
import dask.array as da
import dask.dataframe as dd
import geopandas as gpd
Expand All @@ -26,15 +27,53 @@
get_item,
get_spatial_image,
save_shapes,
save_table,
to_intrinsic,
)
from .._sdata import save_table as _save_table
from .._sdata import to_intrinsic
from ..io.explorer.utils import str_cell_id
from . import shapes

log = logging.getLogger(__name__)


def overlay_segmentation(
sdata: SpatialData,
shapes_key: str,
gene_column: str | None = None,
area_ratio_threshold: float = 0.25,
image_key: str | None = None,
save_table: bool = False,
):
"""Overlay a segmentation on top of an existing segmentation
Args:
sdata: A `SpatialData` object
shapes_key: The key of the new shapes to be added
gene_column: Key of the points dataframe containing the genes names
area_ratio_threshold: Threshold between 0 and 1. For each original cell overlapping with a new cell, we compute the overlap-area/cell-area, if above the threshold the cell is removed.
image_key: Optional key of the original image
save_table: Whether to save the new table on-disk or not
"""
average_intensities = False

if "table" in sdata.tables and SopaKeys.UNS_KEY in sdata.tables["table"].uns:
sopa_attrs = sdata.tables["table"].uns[SopaKeys.UNS_KEY]

if sopa_attrs[SopaKeys.UNS_HAS_TRANSCRIPTS]:
assert gene_column is not None, "Need 'gene_column' argument to count transcripts"
else:
gene_column = gene_column
average_intensities = sopa_attrs[SopaKeys.UNS_HAS_INTENSITIES]

aggr = Aggregator(sdata, image_key=image_key, shapes_key=shapes_key)
aggr.overlay_segmentation(
gene_column=gene_column,
average_intensities=average_intensities,
area_ratio_threshold=area_ratio_threshold,
save_table=save_table,
)


class Aggregator:
"""Perform transcript count and channel averaging over a `SpatialData` object"""

Expand Down Expand Up @@ -64,14 +103,61 @@ def __init__(
self.geo_df = self.sdata[shapes_key]

self.table = None
if SopaKeys.TABLE in sdata.tables:
table = sdata.tables[SopaKeys.TABLE]
if len(self.geo_df) != table.n_obs:
log.warn("Not using existing table (aggregating on a different number of cells)")
else:
if SopaKeys.TABLE in self.sdata.tables:
table = self.sdata.tables[SopaKeys.TABLE]
if len(self.geo_df) == table.n_obs:
log.info("Using existing table for aggregation")
self.table = table

def standardize_table(self):
def overlay_segmentation(
self,
gene_column: str | None = None,
average_intensities: bool = True,
area_ratio_threshold: float = 0.25,
save_table: bool = True,
):
old_table: AnnData = self.sdata.tables[SopaKeys.TABLE]
self.sdata.tables[SopaKeys.OLD_TABLE] = old_table
del self.sdata.tables[SopaKeys.TABLE]

old_shapes_key = old_table.uns["spatialdata_attrs"]["region"]
instance_key = old_table.uns["spatialdata_attrs"]["instance_key"]

if isinstance(old_shapes_key, list):
assert (
len(old_shapes_key) == 1
), "Can't overlap segmentation on multi-region SpatialData object"
old_shapes_key = old_shapes_key[0]

old_geo_df = self.sdata[old_shapes_key]
geo_df = to_intrinsic(self.sdata, self.geo_df, old_geo_df)

gdf_join = gpd.sjoin(old_geo_df, geo_df)
gdf_join["geometry_right"] = gdf_join["index_right"].map(lambda i: geo_df.geometry.iloc[i])
gdf_join["overlap_ratio"] = gdf_join.apply(_overlap_area_ratio, axis=1)
gdf_join: gpd.GeoDataFrame = gdf_join[gdf_join.overlap_ratio >= area_ratio_threshold]

table_crop = old_table[~np.isin(old_table.obs[instance_key], gdf_join.index)].copy()
table_crop.obs[SopaKeys.CELL_OVERLAY_KEY] = False

self.compute_table(
gene_column=gene_column, average_intensities=average_intensities, save_table=False
)
self.table.obs[SopaKeys.CELL_OVERLAY_KEY] = True

self.table = anndata.concat(
[table_crop, self.table], uns_merge="first", join="outer", fill_value=0
)
_fillna(self.table.obs)

self.shapes_key = f"{old_shapes_key}+{self.shapes_key}"
geo_df_cropped = old_geo_df.loc[~old_geo_df.index.isin(gdf_join.index)]
self.geo_df = pd.concat([geo_df_cropped, geo_df], join="outer", axis=0)
self.geo_df.attrs = old_geo_df.attrs

self.standardized_table(save_table=save_table)

def standardized_table(self, save_table: bool = True):
self.table.obs_names = list(map(str_cell_id, range(self.table.n_obs)))

self.geo_df.index = list(self.table.obs_names)
Expand Down Expand Up @@ -102,7 +188,9 @@ def standardize_table(self):
)

self.sdata.tables[SopaKeys.TABLE] = self.table
save_table(self.sdata, SopaKeys.TABLE)

if save_table:
_save_table(self.sdata, SopaKeys.TABLE)

def filter_cells(self, where_filter: np.ndarray):
log.info(f"Filtering {where_filter.sum()} cells")
Expand All @@ -114,13 +202,18 @@ def filter_cells(self, where_filter: np.ndarray):
if self.table is not None:
self.table = self.table[~where_filter]

def update_table(
def update_table(self, *args, **kwargs):
log.warn("'update_table' is deprecated, use 'compute_table' instead")
self.compute_table(*args, **kwargs)

def compute_table(
self,
gene_column: str | None = None,
average_intensities: bool = True,
expand_radius_ratio: float = 0,
min_transcripts: int = 0,
min_intensity_ratio: float = 0,
save_table: bool = True,
):
"""Perform aggregation and update the spatialdata table
Expand All @@ -130,6 +223,7 @@ def update_table(
expand_radius_ratio: Cells polygons will be expanded by `expand_radius_ratio * mean_radius` for channels averaging **only**. This help better aggregate boundary stainings
min_transcripts: Minimum amount of transcript to keep a cell
min_intensity_ratio: Cells whose mean channel intensity is less than `min_intensity_ratio * quantile_90` will be filtered
save_table: Whether the table should be saved on disk or not
"""
does_count = (
self.table is not None and isinstance(self.table.X, csr_matrix)
Expand Down Expand Up @@ -183,7 +277,21 @@ def update_table(
SopaKeys.UNS_HAS_INTENSITIES: average_intensities,
}

self.standardize_table()
self.standardized_table(save_table=save_table)


def _overlap_area_ratio(row) -> float:
poly: Polygon = row["geometry"]
poly_right: Polygon = row["geometry_right"]
return poly.intersection(poly_right).area / poly.area


def _fillna(df: pd.DataFrame):
for key in df:
if df[key].dtype == "category":
df[key] = df[key].cat.add_categories("NA").fillna("NA")
else:
df[key] = df[key].fillna(0)


def average_channels(
Expand Down Expand Up @@ -315,7 +423,7 @@ def _count_transcripts_aligned(

X = coo_matrix((len(geo_df), len(gene_names)), dtype=int)
adata = AnnData(X=X, var=pd.DataFrame(index=gene_names))
adata.obs_names = geo_df.index
adata.obs_names = geo_df.index.astype(str)

geo_df = geo_df.reset_index()

Expand Down

0 comments on commit 580c8e6

Please sign in to comment.