Merge pull request #33 from gustaveroussy/dev
Dev
quentinblampey authored Mar 7, 2024
2 parents a3d438e + 7a80581 commit 54e5999
Showing 3 changed files with 118 additions and 74 deletions.
4 changes: 3 additions & 1 deletion docs/tutorials/api_usage.ipynb
@@ -17,7 +17,9 @@
"source": [
"## Save the `SpatialData` object\n",
"\n",
"For this tutorial, we use a generated dataset. The command below will generate and save it on disk (you can change the path `tuto.zarr` to save it somewhere else).\n",
"For this tutorial, we use a generated dataset. You can expect a total runtime of a few minutes.\n",
"\n",
"The command below will generate and save it on disk (you can change the path `tuto.zarr` to save it somewhere else).\n",
"\n",
"See the commented lines below to load your own data, or see the [`sopa.io` API](../../api/io)."
]
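For reference, here is a minimal sketch of what this generation step looks like in Python. It assumes sopa's `sopa.io.uniform` toy-dataset generator; the exact call is not shown in this hunk.

```python
import sopa.io

# Generate a small synthetic dataset and persist it as a Zarr store.
sdata = sopa.io.uniform()
sdata.write("tuto.zarr")  # change this path to save the data somewhere else
```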
4 changes: 3 additions & 1 deletion docs/tutorials/cli_usage.md
@@ -2,7 +2,9 @@ Here, we provide a minimal example of command line usage. For more details and t

## Save the `SpatialData` object

For this tutorial, we use a generated dataset. The command below will generate and save it on disk (you can change the path `tuto.zarr` to save it somewhere else). If you want to load your own data: choose the right panel below, or see the [`sopa read` CLI documentation](`../../cli/#sopa-read`).
For this tutorial, we use a generated dataset. You can expect a total runtime of a few minutes.

The command below will generate and save it on disk (you can change the path `tuto.zarr` to save it somewhere else). If you want to load your own data: choose the right panel below, or see the [`sopa read` CLI documentation](`../../cli/#sopa-read`).

=== "Tutorial"
```sh
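The tutorial command itself is truncated in this hunk. As a hedged sketch, the generation step via the CLI looks like this (flags assumed from the `sopa read` documentation linked above):

```sh
# Generate the toy dataset and save it at tuto.zarr (assumed flags)
sopa read . --sdata-path tuto.zarr --technology uniform
```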
184 changes: 112 additions & 72 deletions sopa/embedding/patches.py
@@ -1,7 +1,14 @@
from __future__ import annotations

import logging
from typing import Callable

import dask as da
try:
import torch
except ImportError:
raise ImportError(
"For patch embedding, you need `torch` (and perhaps `torchvision`). Consider installing the sopa WSI extra: `pip install 'sopa[wsi]'`"
)

import numpy as np
import tqdm
from multiscale_spatial_image import MultiscaleSpatialImage
@@ -10,6 +17,7 @@
from spatialdata.models import Image2DModel
from spatialdata.transformations import Scale

import sopa.embedding.models as models
from sopa._constants import SopaKeys
from sopa._sdata import get_intrinsic_cs, get_key
from sopa.segmentation import Patches2D
@@ -30,13 +38,23 @@ def _get_best_level_for_downsample(


def _get_extraction_parameters(
tiff_metadata: dict, magnification: int, patch_width: int
tiff_metadata: dict,
patch_width: int,
level: int | None,
magnification: int | None,
) -> tuple[int, int, int, bool]:
"""
Given the metadata for the slide, a target magnification and a patch width,
it returns the best scale to get it from (level), a resize factor (resize_factor)
and the corresponding patch size at scale0 (patch_width)
"""
if level is None and magnification is None:
log.warning("Both level and magnification arguments are None. Using level=0 by default.")
level = 0

if magnification is None:
return level, 1, patch_width, True

if tiff_metadata["properties"].get("tiffslide.objective-power"):
objective_power = int(tiff_metadata["properties"].get("tiffslide.objective-power"))
downsample = objective_power / magnification
@@ -56,61 +74,102 @@ def _get_extraction_parameters(
return level, resize_factor, patch_width, True
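# Illustration (hypothetical numbers, not part of this commit): for a slide scanned
# at 40x objective power, with pyramid downsamples (1, 4, 16), magnification=10,
# and patch_width=224:
#   downsample = 40 / 10 = 4.0
#   level = 1 (the highest level whose downsample does not exceed 4.0)
#   resize_factor = 4 / 4.0 = 1.0 (no residual resize needed after reading level 1)
#   patch width at scale0 = 224 * 4.0 = 896 pixels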


def _numpy_patch(
image: MultiscaleSpatialImage,
box: tuple[int, int, int, int],
level: int,
resize_factor: float,
coordinate_system: str,
) -> np.ndarray:
"""Extract a numpy patch from the MultiscaleSpatialImage given a bounding box"""
import cv2
class Embedder:
def __init__(
self,
image: MultiscaleSpatialImage | SpatialImage,
model_name: str,
patch_width: int,
level: int | None = 0,
magnification: int | None = None,
device: str = "cpu",
):
self.image = image
self.model_name = model_name
self.device = device

self.cs = get_intrinsic_cs(None, image)

tiff_metadata = image.attrs.get("metadata", {})
self.level, self.resize_factor, self.patch_width, success = _get_extraction_parameters(
tiff_metadata, patch_width, level, magnification
)
if not success:
log.error("Error retrieving the image mpp, skipping tile embedding.")
raise ValueError("Could not retrieve the image mpp")  # __init__ cannot return a value

assert hasattr(
models, model_name
), f"'{model_name}' is not a valid model name under `sopa.embedding.models`. Valid names are: {', '.join(models.__all__)}"

self.model: torch.nn.Module = getattr(models, model_name)()
self.model.eval().to(device)

def _resize(self, patch: np.ndarray):
import cv2

patch = patch.transpose(1, 2, 0)
dim = (
int(patch.shape[0] * self.resize_factor),
int(patch.shape[1] * self.resize_factor),
)
patch = cv2.resize(patch, dim)
return patch.transpose(2, 0, 1)

multiscale_patch = bounding_box_query(
image, ("y", "x"), box[:2][::-1], box[2:][::-1], coordinate_system
)
patch = np.array(
next(iter(multiscale_patch[f"scale{level}"].values())).transpose("y", "x", "c")
)
def _torch_patch(
self,
box: tuple[int, int, int, int],
) -> torch.Tensor:
"""Extract a patch of the image for the given bounding box, as a torch tensor"""
image_patch = bounding_box_query(
self.image, ("y", "x"), box[:2][::-1], box[2:][::-1], self.cs
)

if resize_factor != 1:
dim = (int(patch.shape[0] * resize_factor), int(patch.shape[1] * resize_factor))
patch = cv2.resize(patch, dim)
if isinstance(self.image, MultiscaleSpatialImage):
image_patch = next(iter(image_patch[f"scale{self.level}"].values()))

return patch.transpose(2, 0, 1)
patch = image_patch.compute().data

if self.resize_factor != 1:
patch = self._resize(patch)

def embed_batch(model_name: str, device: str) -> tuple[Callable, int]:
import torch
return torch.tensor(patch / 255.0, dtype=torch.float32)

def _torch_batch(self, bboxes: np.ndarray):
batch = [self._torch_patch(box) for box in bboxes]

max_y = max(img.shape[1] for img in batch)
max_x = max(img.shape[2] for img in batch)

def _pad(patch: torch.Tensor, max_y: int, max_x: int) -> torch.Tensor:
pad_x, pad_y = max_x - patch.shape[2], max_y - patch.shape[1]
return torch.nn.functional.pad(patch, (0, pad_x, 0, pad_y), value=0)

import sopa.embedding.models as models
return torch.stack([_pad(patch, max_y, max_x) for patch in batch])
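# Illustration (made-up shapes, not from this commit): patches of shapes
# (3, 200, 224) and (3, 224, 224) get (pad_x, pad_y) = (0, 24) and (0, 0)
# respectively, are right/bottom padded to a common (3, 224, 224), then
# stacked into a (2, 3, 224, 224) batch.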

assert hasattr(
models, model_name
), f"'{model_name}' is not a valid model name under `sopa.embedding.models`. Valid names are: {', '.join(models.__all__)}"
@property
def output_dim(self):
return self.model.output_dim

model: torch.nn.Module = getattr(models, model_name)()
model.eval().to(device)
@torch.no_grad()
def embed_bboxes(self, bboxes: np.ndarray) -> torch.Tensor:
patches = self._torch_batch(bboxes) # shape (B,3,Y,X)

def _(patch: np.ndarray):
torch_patch = torch.tensor(patch / 255.0, dtype=torch.float32)
if len(torch_patch.shape) == 3:
torch_patch = torch_patch.unsqueeze(0)
with torch.no_grad():
embedding = model(torch_patch.to(device)).squeeze()
return embedding.cpu()
if len(patches.shape) == 3:
patches = patches.unsqueeze(0)

return _, model.output_dim
embedding = self.model(patches.to(self.device)).squeeze()
return embedding.cpu() # shape (B, output_dim)
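# Hedged usage sketch (made-up bounding boxes; a box is (xmin, ymin, xmax, ymax)
# in the image's intrinsic coordinate system):
#   embedder = Embedder(image, "Resnet50Features", patch_width=224, magnification=10)
#   embeddings = embedder.embed_bboxes(np.array([[0, 0, 224, 224], [224, 0, 448, 224]]))
#   embeddings.shape  # (2, embedder.output_dim)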


def embed_wsi_patches(
sdata: SpatialData,
model_name: str,
magnification: int,
patch_width: int,
level: int | None = 0,
magnification: int | None = None,
image_key: str | None = None,
batch_size: int = 32,
num_workers: int = 1,
device: str = "cpu",
) -> SpatialImage | bool:
"""Create an image made of patch embeddings of a WSI image.
@@ -121,11 +180,11 @@ def embed_wsi_patches(
Args:
sdata: A `SpatialData` object
model_name: Name of the computer vision model to be used. One of `Resnet50Features`, `HistoSSLFeatures`, or `DINOv2Features`.
magnification: The target magnification.
patch_width: Width of the patches for which the embeddings will be computed.
level: Image level on which the embedding is performed. Either `level` or `magnification` should be provided.
magnification: The target magnification on which the embedding is performed. If `magnification` is provided, the `level` argument will be automatically computed.
image_key: Optional image key of the WSI image, unnecessary if there is only one image.
batch_size: Mini-batch size used during inference.
num_workers: Number of workers used to extract patches.
device: Device used for the computer vision model.
Returns:
@@ -134,47 +193,28 @@
image_key = get_key(sdata, "images", image_key)
image = sdata.images[image_key]

assert isinstance(
image, MultiscaleSpatialImage
), "Only `MultiscaleSpatialImage` images are supported"
embedder = Embedder(image, model_name, patch_width, level, magnification, device)

tiff_metadata = image.attrs["metadata"]
coordinate_system = get_intrinsic_cs(sdata, image)
patches = Patches2D(sdata, image_key, embedder.patch_width, 0)
embedding_image = np.zeros((embedder.output_dim, *patches.shape), dtype=np.float32)

embedder, output_dim = embed_batch(model_name=model_name, device=device)

level, resize_factor, patch_width, success = _get_extraction_parameters(
tiff_metadata, magnification, patch_width
)
if not success:
log.error(f"Error retrieving the mpp for {image_key}, skipping tile embedding.")
return False

patches = Patches2D(sdata, image_key, patch_width, 0)
embedding_image = np.zeros((output_dim, *patches.shape), dtype=np.float32)

log.info(f"Computing {len(patches)} embeddings at level {level}")
log.info(f"Computing {len(patches)} embeddings at level {embedder.level}")

for i in tqdm.tqdm(range(0, len(patches), batch_size)):
patch_boxes = patches.bboxes[i : i + batch_size]

get_batches = [
da.delayed(_numpy_patch)(image, box, level, resize_factor, coordinate_system)
for box in patch_boxes
]
batch = da.compute(*get_batches, num_workers=num_workers)
embedding = embedder(np.stack(batch))
embedding = embedder.embed_bboxes(patches.bboxes[i : i + batch_size])

loc_x, loc_y = patches.ilocs[i : i + len(batch)].T
loc_x, loc_y = patches.ilocs[i : i + len(embedding)].T
embedding_image[:, loc_y, loc_x] = embedding.T

embedding_image = SpatialImage(embedding_image, dims=("c", "y", "x"))
embedding_image = Image2DModel.parse(
embedding_image,
transformations={coordinate_system: Scale([patch_width, patch_width], axes=("x", "y"))},
transformations={
embedder.cs: Scale([embedder.patch_width, embedder.patch_width], axes=("x", "y"))
},
)
embedding_image.coords["y"] = patch_width * embedding_image.coords["y"]
embedding_image.coords["x"] = patch_width * embedding_image.coords["x"]
embedding_image.coords["y"] = embedder.patch_width * embedding_image.coords["y"]
embedding_image.coords["x"] = embedder.patch_width * embedding_image.coords["x"]

embedding_key = f"sopa_{model_name}"
sdata.add_image(embedding_key, embedding_image)
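# Hedged end-to-end sketch (hypothetical path, not from this commit):
#   sdata = spatialdata.read_zarr("slide.zarr")
#   embed_wsi_patches(sdata, "Resnet50Features", patch_width=224, magnification=10)
#   on success, the embeddings are added to sdata as the image "sopa_Resnet50Features"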
