diff --git a/docs/tutorials/api_usage.ipynb b/docs/tutorials/api_usage.ipynb
index 5584c341..bcfb6885 100644
--- a/docs/tutorials/api_usage.ipynb
+++ b/docs/tutorials/api_usage.ipynb
@@ -17,7 +17,9 @@
    "source": [
     "## Save the `SpatialData` object\n",
     "\n",
-    "For this tutorial, we use a generated dataset. The command below will generate and save it on disk (you can change the path `tuto.zarr` to save it somewhere else).\n",
+    "For this tutorial, we use a generated dataset. You can expect a total runtime of a few minutes.\n",
+    "\n",
+    "The command below will generate and save it on disk (you can change the path `tuto.zarr` to save it somewhere else).\n",
     "\n",
     "See the commented lines below to load your own data, or see the [`sopa.io` API](../../api/io)."
    ]
diff --git a/docs/tutorials/cli_usage.md b/docs/tutorials/cli_usage.md
index d71011ad..a1625be6 100644
--- a/docs/tutorials/cli_usage.md
+++ b/docs/tutorials/cli_usage.md
@@ -2,7 +2,9 @@ Here, we provide a minimal example of command line usage. For more details and t
 
 ## Save the `SpatialData` object
 
-For this tutorial, we use a generated dataset. The command below will generate and save it on disk (you can change the path `tuto.zarr` to save it somewhere else). If you want to load your own data: choose the right panel below, or see the [`sopa read` CLI documentation](`../../cli/#sopa-read`).
+For this tutorial, we use a generated dataset. You can expect a total runtime of a few minutes.
+
+The command below will generate and save it on disk (you can change the path `tuto.zarr` to save it somewhere else). If you want to load your own data: choose the right panel below, or see the [`sopa read` CLI documentation](`../../cli/#sopa-read`).
 
 === "Tutorial"
     ```sh
diff --git a/sopa/embedding/patches.py b/sopa/embedding/patches.py
index ae15d78d..7865b7fe 100644
--- a/sopa/embedding/patches.py
+++ b/sopa/embedding/patches.py
@@ -1,7 +1,14 @@
+from __future__ import annotations
+
 import logging
-from typing import Callable
 
-import dask as da
+try:
+    import torch
+except ImportError:
+    raise ImportError(
+        "For patch embedding, you need `torch` (and perhaps `torchvision`). Consider installing the sopa WSI extra: `pip install 'sopa[wsi]'`"
+    )
+
 import numpy as np
 import tqdm
 from multiscale_spatial_image import MultiscaleSpatialImage
@@ -10,6 +17,7 @@
 from spatialdata.models import Image2DModel
 from spatialdata.transformations import Scale
 
+import sopa.embedding.models as models
 from sopa._constants import SopaKeys
 from sopa._sdata import get_intrinsic_cs, get_key
 from sopa.segmentation import Patches2D
@@ -30,13 +38,23 @@ def _get_best_level_for_downsample(
 
 
 def _get_extraction_parameters(
-    tiff_metadata: dict, magnification: int, patch_width: int
+    tiff_metadata: dict,
+    patch_width: int,
+    level: int | None,
+    magnification: int | None,
 ) -> tuple[int, int, int, bool]:
     """
     Given the metadata for the slide, a target magnification and a patch width,
     it returns the best scale to get it from (level), a resize factor (resize_factor)
    and the corresponding patch size at scale0 (patch_width)
     """
+    if level is None and magnification is None:
+        log.warn("Both level and magnification arguments are None. Using level=0 by default.")
+        level = 0
+
+    if magnification is None:
+        return level, 1, patch_width, True
+
     if tiff_metadata["properties"].get("tiffslide.objective-power"):
         objective_power = int(tiff_metadata["properties"].get("tiffslide.objective-power"))
         downsample = objective_power / magnification
@@ -56,61 +74,102 @@ def _get_extraction_parameters(
     return level, resize_factor, patch_width, True
 
 
-def _numpy_patch(
-    image: MultiscaleSpatialImage,
-    box: tuple[int, int, int, int],
-    level: int,
-    resize_factor: float,
-    coordinate_system: str,
-) -> np.ndarray:
-    """Extract a numpy patch from the MultiscaleSpatialImage given a bounding box"""
-    import cv2
+class Embedder:
+    def __init__(
+        self,
+        image: MultiscaleSpatialImage | SpatialImage,
+        model_name: str,
+        patch_width: int,
+        level: int | None = 0,
+        magnification: int | None = None,
+        device: str = "cpu",
+    ):
+        self.image = image
+        self.model_name = model_name
+        self.device = device
+
+        self.cs = get_intrinsic_cs(None, image)
+
+        tiff_metadata = image.attrs.get("metadata", {})
+        self.level, self.resize_factor, self.patch_width, success = _get_extraction_parameters(
+            tiff_metadata, patch_width, level, magnification
+        )
+        if not success:
+            log.error("Error retrieving the image mpp, skipping tile embedding.")
+            return False
+
+        assert hasattr(
+            models, model_name
+        ), f"'{model_name}' is not a valid model name under `sopa.embedding.models`. Valid names are: {', '.join(models.__all__)}"
+
+        self.model: torch.nn.Module = getattr(models, model_name)()
+        self.model.eval().to(device)
+
+    def _resize(self, patch: np.ndarray):
+        import cv2
+
+        patch = patch.transpose(1, 2, 0)
+        dim = (
+            int(patch.shape[0] * self.resize_factor),
+            int(patch.shape[1] * self.resize_factor),
+        )
+        patch = cv2.resize(patch, dim)
+        return patch.transpose(2, 0, 1)
 
-    multiscale_patch = bounding_box_query(
-        image, ("y", "x"), box[:2][::-1], box[2:][::-1], coordinate_system
-    )
-    patch = np.array(
-        next(iter(multiscale_patch[f"scale{level}"].values())).transpose("y", "x", "c")
-    )
+    def _torch_patch(
+        self,
+        box: tuple[int, int, int, int],
+    ) -> torch.Tensor:
+        """Extract a torch tensor patch from the image given a bounding box"""
+        image_patch = bounding_box_query(
+            self.image, ("y", "x"), box[:2][::-1], box[2:][::-1], self.cs
+        )
 
-    if resize_factor != 1:
-        dim = (int(patch.shape[0] * resize_factor), int(patch.shape[1] * resize_factor))
-        patch = cv2.resize(patch, dim)
+        if isinstance(self.image, MultiscaleSpatialImage):
+            image_patch = next(iter(image_patch[f"scale{self.level}"].values()))
 
-    return patch.transpose(2, 0, 1)
+        patch = image_patch.compute().data
+        if self.resize_factor != 1:
+            patch = self._resize(patch)
 
-def embed_batch(model_name: str, device: str) -> tuple[Callable, int]:
-    import torch
+        return torch.tensor(patch / 255.0, dtype=torch.float32)
+
+    def _torch_batch(self, bboxes: np.ndarray):
+        batch = [self._torch_patch(box) for box in bboxes]
+
+        max_y = max(img.shape[1] for img in batch)
+        max_x = max(img.shape[2] for img in batch)
+
+        def _pad(patch: torch.Tensor, max_y: int, max_x: int) -> torch.Tensor:
+            pad_x, pad_y = max_x - patch.shape[2], max_y - patch.shape[1]
+            return torch.nn.functional.pad(patch, (0, pad_x, 0, pad_y), value=0)
 
-    import sopa.embedding.models as models
+        return torch.stack([_pad(patch, max_y, max_x) for patch in batch])
 
-    assert hasattr(
-        models, model_name
-    ), f"'{model_name}' is not a valid model name under `sopa.embedding.models`. Valid names are: {', '.join(models.__all__)}"
+    @property
+    def output_dim(self):
+        return self.model.output_dim
 
-    model: torch.nn.Module = getattr(models, model_name)()
-    model.eval().to(device)
+    @torch.no_grad()
+    def embed_bboxes(self, bboxes: np.ndarray) -> torch.Tensor:
+        patches = self._torch_batch(bboxes)  # shape (B, 3, Y, X)
 
-    def _(patch: np.ndarray):
-        torch_patch = torch.tensor(patch / 255.0, dtype=torch.float32)
-        if len(torch_patch.shape) == 3:
-            torch_patch = torch_patch.unsqueeze(0)
-        with torch.no_grad():
-            embedding = model(torch_patch.to(device)).squeeze()
-        return embedding.cpu()
+        if len(patches.shape) == 3:
+            patches = patches.unsqueeze(0)
 
-    return _, model.output_dim
+        embedding = self.model(patches.to(self.device)).squeeze()
+        return embedding.cpu()  # shape (B, output_dim)
 
 
 def embed_wsi_patches(
     sdata: SpatialData,
     model_name: str,
-    magnification: int,
     patch_width: int,
+    level: int | None = 0,
+    magnification: int | None = None,
     image_key: str | None = None,
     batch_size: int = 32,
-    num_workers: int = 1,
     device: str = "cpu",
 ) -> SpatialImage | bool:
     """Create an image made of patch embeddings of a WSI image.
@@ -121,11 +180,11 @@
     Args:
         sdata: A `SpatialData` object
         model_name: Name of the computer vision model to be used. One of `Resnet50Features`, `HistoSSLFeatures`, or `DINOv2Features`.
-        magnification: The target magnification.
         patch_width: Width of the patches for which the embeddings will be computed.
+        level: Image level on which the embedding is performed. Either `level` or `magnification` should be provided.
+        magnification: The target magnification on which the embedding is performed. If `magnification` is provided, the `level` argument will be automatically computed.
         image_key: Optional image key of the WSI image, unnecessary if there is only one image.
         batch_size: Mini-batch size used during inference.
-        num_workers: Number of workers used to extract patches.
         device: Device used for the computer vision model.
 
     Returns:
@@ -134,47 +193,28 @@
     image_key = get_key(sdata, "images", image_key)
     image = sdata.images[image_key]
 
-    assert isinstance(
-        image, MultiscaleSpatialImage
-    ), "Only `MultiscaleSpatialImage` images are supported"
+    embedder = Embedder(image, model_name, patch_width, level, magnification, device)
 
-    tiff_metadata = image.attrs["metadata"]
-    coordinate_system = get_intrinsic_cs(sdata, image)
+    patches = Patches2D(sdata, image_key, embedder.patch_width, 0)
+    embedding_image = np.zeros((embedder.output_dim, *patches.shape), dtype=np.float32)
 
-    embedder, output_dim = embed_batch(model_name=model_name, device=device)
-
-    level, resize_factor, patch_width, success = _get_extraction_parameters(
-        tiff_metadata, magnification, patch_width
-    )
-    if not success:
-        log.error(f"Error retrieving the mpp for {image_key}, skipping tile embedding.")
-        return False
-
-    patches = Patches2D(sdata, image_key, patch_width, 0)
-    embedding_image = np.zeros((output_dim, *patches.shape), dtype=np.float32)
-
-    log.info(f"Computing {len(patches)} embeddings at level {level}")
+    log.info(f"Computing {len(patches)} embeddings at level {embedder.level}")
 
     for i in tqdm.tqdm(range(0, len(patches), batch_size)):
-        patch_boxes = patches.bboxes[i : i + batch_size]
-
-        get_batches = [
-            da.delayed(_numpy_patch)(image, box, level, resize_factor, coordinate_system)
-            for box in patch_boxes
-        ]
-        batch = da.compute(*get_batches, num_workers=num_workers)
-        embedding = embedder(np.stack(batch))
+        embedding = embedder.embed_bboxes(patches.bboxes[i : i + batch_size])
 
-        loc_x, loc_y = patches.ilocs[i : i + len(batch)].T
+        loc_x, loc_y = patches.ilocs[i : i + len(embedding)].T
         embedding_image[:, loc_y, loc_x] = embedding.T
 
     embedding_image = SpatialImage(embedding_image, dims=("c", "y", "x"))
     embedding_image = Image2DModel.parse(
         embedding_image,
-        transformations={coordinate_system: Scale([patch_width, patch_width], axes=("x", "y"))},
+        transformations={
+            embedder.cs: Scale([embedder.patch_width, embedder.patch_width], axes=("x", "y"))
+        },
    )
-    embedding_image.coords["y"] = patch_width * embedding_image.coords["y"]
-    embedding_image.coords["x"] = patch_width * embedding_image.coords["x"]
+    embedding_image.coords["y"] = embedder.patch_width * embedding_image.coords["y"]
+    embedding_image.coords["x"] = embedder.patch_width * embedding_image.coords["x"]
 
     embedding_key = f"sopa_{model_name}"
     sdata.add_image(embedding_key, embedding_image)
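
Usage note (not part of the diff): a minimal sketch of how the updated `embed_wsi_patches` signature could be called after this change. It assumes a `SpatialData` object `sdata` that already holds the WSI image (loaded elsewhere, e.g. with `sopa.io`), that the function is imported from `sopa/embedding/patches.py` as in this file, that the torch-based WSI extra is installed (`pip install 'sopa[wsi]'`), and the patch width of 224 is an arbitrary illustrative value.

```python
# Illustrative sketch only (see assumptions above): exercising the new
# `level` / `magnification` arguments of `embed_wsi_patches`.
from sopa.embedding.patches import embed_wsi_patches

# `sdata` is assumed to already contain the WSI image (loaded beforehand).

# Option 1: pin an explicit image pyramid level (level=0 is the default).
embed_wsi_patches(sdata, "Resnet50Features", patch_width=224, level=1)

# Option 2: give a target magnification; the level and resize factor are then
# derived from the slide metadata (an error is logged if the mpp is missing).
embed_wsi_patches(
    sdata,
    "HistoSSLFeatures",
    patch_width=224,
    magnification=10,
    batch_size=32,
    device="cpu",
)
```

Either way, the resulting embedding image is added to `sdata` under the key `sopa_{model_name}`, as in the function body above.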