Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add possibility to simulate images in batches #95

Draft
wants to merge 18 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 197 additions & 0 deletions examples/notebooks/cosem_multi_sample.ipynb

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions src/microsim/schema/detectors/_camera.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,15 @@ def simulate(
# dark current
avg_dark_e = self.dark_current * exposure_s + self.clock_induced_charge
if not isinstance(avg_dark_e, float):
new_shape = avg_dark_e.shape + (1,) * (detected_photons.ndim - 1)
new_shape = (1,) + avg_dark_e.shape + (1,) * (detected_photons.ndim - 2)
avg_dark_e = np.asarray(avg_dark_e).reshape(new_shape) # type: ignore [assignment]
thermal_electrons = xp.poisson_rvs(avg_dark_e, shape=detected_photons.shape)
total_electrons = detected_photons + thermal_electrons

# cap total electrons to full-well-capacity
total_electrons = xp.minimum(total_electrons, self.full_well)

if binning > 1:
if binning > 1: # TODO: this function might not work with batch dim
total_electrons = self.apply_pre_quantization_binning(
total_electrons, binning
)
Expand All @@ -141,7 +141,7 @@ def simulate(
gray_values = self.quantize_electrons(total_electrons, xp)

# sCMOS binning
if binning > 1:
if binning > 1: # TODO: this function might not work with batch dim
gray_values = self.apply_post_quantization_binning(gray_values, binning)

# ADC saturation
Expand Down
1 change: 1 addition & 0 deletions src/microsim/schema/dimensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class Axis(str, Enum):
T = "t" # Time
F = "f" # Fluorophore
W = "w" # Wavelength
S = "s" # Sample

def __repr__(self) -> str:
    """Return the canonical ``<Axis.NAME>`` debug representation."""
    return "<Axis.{}>".format(self.name)
Expand Down
38 changes: 17 additions & 21 deletions src/microsim/schema/modality/_simple_psf.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@ def psf(

def render(
self,
truth: xrDataArray, # (F, Z, Y, X)
truth: xrDataArray, # (S, F, Z, Y, X)
em_rates: xrDataArray, # (C, F, W)
objective_lens: ObjectiveLens,
settings: Settings,
xp: NumpyAPI,
) -> xrDataArray:
"""Render a 3D image of the truth for F fluorophores, in C channels."""
"""Render a batch of 3D images of truth for F fluorophores, in C channels."""
# for every channel in the emission rates...
channels = []
for ch in em_rates.coords[Axis.C].values:
Expand Down Expand Up @@ -81,12 +81,13 @@ def render(
# stack the fluorophores together to create the channel
channels.append(xp.stack(fluors, axis=0))

return DataArray(
channels = DataArray(
channels,
dims=[Axis.C, Axis.F, Axis.Z, Axis.Y, Axis.X],
dims=[Axis.C, Axis.F, Axis.S, Axis.Z, Axis.Y, Axis.X],
coords={
Axis.C: em_rates.coords[Axis.C],
Axis.F: truth.coords[Axis.F],
Axis.S: truth.coords[Axis.S],
Axis.Z: truth.coords[Axis.Z],
Axis.Y: truth.coords[Axis.Y],
Axis.X: truth.coords[Axis.X],
Expand All @@ -95,8 +96,10 @@ def render(
"space": truth.attrs["space"],
"objective": objective_lens,
"units": "photons",
"long_name": "Optical Image",
},
)
return channels.transpose(Axis.S, ...) # put batch dim first

def _summed_weighted_psf(
self,
Expand Down Expand Up @@ -152,6 +155,8 @@ def _summed_weighted_psf(
xp=xp,
)
summed_psf += psf * weight

summed_psf = summed_psf[None, ...] # add batch dim
return summed_psf # type: ignore [no-any-return]


Expand Down Expand Up @@ -200,7 +205,7 @@ class Identity(_PSFModality):

def render(
self,
truth: xrDataArray, # (F, Z, Y, X)
truth: xrDataArray, # (S, F, Z, Y, X)
em_rates: xrDataArray, # (C, F, W)
*args: Any,
**kwargs: Any,
Expand All @@ -211,23 +216,14 @@ def render(
already convolved with the PSF. Therefore, we simply compute the emission flux
for each fluorophore and each channel.
"""
em_image = em_rates.sum(Axis.W) * truth
return DataArray(
em_image,
dims=[Axis.C, Axis.F, Axis.Z, Axis.Y, Axis.X],
coords={
Axis.C: em_rates.coords[Axis.C],
Axis.F: truth.coords[Axis.F],
Axis.Z: truth.coords[Axis.Z],
Axis.Y: truth.coords[Axis.Y],
Axis.X: truth.coords[Axis.X],
},
attrs={
"space": truth.attrs["space"],
"objective": "",
"units": "photons",
},
em_image = (em_rates.sum(Axis.W) * truth).transpose(Axis.S, ...)
em_image.attrs.update(
units="photons",
objective="",
space=truth.attrs["space"],
long_name="Optical Image",
)
return em_image


def bin_spectrum(
Expand Down
2 changes: 2 additions & 0 deletions src/microsim/schema/sample/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def cache_path(self) -> tuple[str, ...] | None:

def render(self, space: xrDataArray, xp: NumpyAPI | None = None) -> xrDataArray:
"""Render the fluorophore distribution into the given space."""
# This would need to change, as with additional batch dim there'd be a mismatch
# between distribution and space dims
dist = self.distribution.render(space, xp)
if isinstance(self.concentration, float | int):
return dist * self.concentration
Expand Down
133 changes: 88 additions & 45 deletions src/microsim/schema/simulation.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import logging
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Annotated
from typing import TYPE_CHECKING, Annotated, Any

import numpy as np
import pandas as pd
import xarray as xr
from pydantic import AfterValidator, Field, model_validator
from annotated_types import MinLen
from pydantic import AfterValidator, Field, field_validator, model_validator

from microsim._data_array import ArrayProtocol, from_cache, to_cache
from microsim.util import microsim_cache
Expand Down Expand Up @@ -53,7 +54,7 @@ class Simulation(SimBaseModel):

truth_space: Space
output_space: Space | None = None
sample: Sample
samples: Annotated[list[Sample], MinLen(1)]
modality: Modality = Field(default_factory=Widefield)
objective_lens: ObjectiveLens = Field(default_factory=ObjectiveLens)
channels: list[OpticalConfig] = Field(default_factory=lambda: [FITC])
Expand Down Expand Up @@ -93,6 +94,28 @@ def _resolve_spaces(self) -> "Self":
self.output_space.reference = self.truth_space
return self

@model_validator(mode="after")
def _check_fluorophores_equal_in_samples(self) -> "Self":
fp_names = [{lbl.fluorophore.name for lbl in s.labels} for s in self.samples]
if len({frozenset(s) for s in fp_names}) != 1:
raise ValueError(
"All samples in the batch must use the same set of fluorophores."
)
Comment on lines +99 to +103
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious about this. I can see that this would be necessary in the case where we assume that the sample "axis" is just an extension of a non-jagged array. But, maybe sample should be special in that case, and maybe samples should always be iterated over in a for-loop rather than processed in a vectorized fashion? (might help with memory bloat too)

that can be decided later. It's definitely easier to start with this restriction and relax it later.

return self

@field_validator("samples", mode="before")
@classmethod
def _samples_to_list(cls, value: Any) -> list[Any]:
    """Coerce a single ``Sample`` (or a tuple) into the list form of `samples`.

    Must run in ``mode="before"``: with the default ``mode="after"`` a bare
    ``Sample`` would already have been rejected by the ``list[Sample]`` type
    validation before this coercion could ever apply, defeating the purpose
    of accepting the legacy single-sample input.
    """
    return [value] if not isinstance(value, list | tuple) else value

@property
def sample(self) -> Sample:
    """Deprecated accessor kept for backward compatibility.

    Returns the first entry of ``samples`` (the field that replaced the old
    required ``sample`` field) and emits a ``DeprecationWarning`` steering
    callers to ``samples``.
    """
    warnings.warn(
        "The `sample` attribute is deprecated. Use `samples` instead.",
        DeprecationWarning,
        stacklevel=2,  # attribute the warning to the caller's line, not this shim
    )
    return self.samples[0]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose we don't have to worry about the case where there are zero samples right? sample was previously a required field, so there should always be a sample?

If that's the case, we might want to explicitly require that in the samples field definition. That can be done by using annotated-types

    samples: Annotated[list[Sample], MinLen(1)]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I totally agree with your point. However, I don't fully get the relation with the lines you pointed your comment to


@property
def _xp(self) -> "NumpyAPI":
return self.settings.backend_module()
Expand All @@ -101,7 +124,7 @@ def ground_truth(self) -> xr.DataArray:
"""Return the ground truth data.

Returns position and quantity of fluorophores in the sample. The return array
has dimensions (F, Z, Y, X). The units are fluorophores counts.
has dimensions (S, F, Z, Y, X). The units are fluorophores counts.

Examples
--------
Expand All @@ -112,36 +135,51 @@ def ground_truth(self) -> xr.DataArray:
"""
if not hasattr(self, "_ground_truth"):
xp = self._xp
# make empty space into which we'll add the ground truth
# TODO: this is wasteful... label.render should probably
# accept the space object directly
truth = self.truth_space.create(array_creator=xp.zeros)

# render each ground truth
label_data = []
for label in self.sample.labels:
cache_path = self._truth_cache_path(
label, self.truth_space, self.settings.random_seed
)
if self.settings.cache.read and cache_path and cache_path.exists():
data = from_cache(cache_path, xp=xp).astype(
self.settings.float_dtype
)
logging.info(
f"Loaded ground truth for {label} from cache: {cache_path}"

# render each sample
# TODO: iterating over sample is a bottleneck for batched approach... Ideas?
truths: list[xr.DataArray] = []
for sample in self.samples:
# render each label in the sample

# make empty space into which we'll add the ground truth
# TODO: this is wasteful... label.render should probably
# accept the space object directly
truth = self.truth_space.create(array_creator=xp.zeros)

label_data = []
for label in sample.labels:
# TODO: differentiate caching for each label
cache_path = self._truth_cache_path(
label, self.truth_space, self.settings.random_seed
)
else:
data = label.render(truth, xp=xp)
if self.settings.cache.write and cache_path:
to_cache(data, cache_path, dtype=np.uint16)

label_data.append(data)

# concat along the F axis
fluors = [lbl.fluorophore for lbl in self.sample.labels]
truth = xr.concat(label_data, dim=pd.Index(fluors, name=Axis.F))
truth.attrs.update(units="fluorophores", long_name="Ground Truth")
self._ground_truth = truth
if self.settings.cache.read and cache_path and cache_path.exists():
data = from_cache(cache_path, xp=xp).astype(
self.settings.float_dtype
)
logging.info(
f"Loaded ground truth for {label} from cache: {cache_path}"
)
else:
data = label.render(truth, xp=xp)
if self.settings.cache.write and cache_path:
to_cache(data, cache_path, dtype=np.uint16)

label_data.append(data)

# concat along the F axis
fluors = [lbl.fluorophore for lbl in sample.labels]
truth = xr.concat(label_data, dim=pd.Index(fluors, name=Axis.F))
truths.append(truth)

# concat along the S (sample/batch) axis
self._ground_truth = xr.concat(
truths, dim=pd.Index(range(len(truths)), name=Axis.S)
) # TODO: is there a better way to give coords to this axis?
self._ground_truth.attrs.update(
units="fluorophores", long_name="Ground Truth"
)

return self._ground_truth

def filtered_emission_rates(self) -> xr.DataArray:
Expand Down Expand Up @@ -183,7 +221,7 @@ def emission_flux(self) -> xr.DataArray:

This multiplies the per-fluorophore emission rates by the ground truth data to
get the total emission flux for each voxel in the ground truth. The return
array has dimensions (C, F, Z, Y, X). The units are photons/s.
array has dimensions (S, C, F, Z, Y, X). The units are photons/s.

Note, this integrates over all wavelengths. For finer control over the emission
spectrum, you may wish to directly combine `filtered_emission_rates` with the
Expand All @@ -202,10 +240,12 @@ def emission_flux(self) -> xr.DataArray:
return truth

# total photons/s emitted by each fluorophore in each channel
total_flux = self.filtered_emission_rates().sum(Axis.W) * truth
total_flux = (self.filtered_emission_rates().sum(Axis.W) * truth).transpose(
Axis.S, ...
)
total_flux.attrs.update(units="photon/sec", long_name="Emission Flux")

# (C, F, Z, Y, X)
# (S, C, F, Z, Y, X)
return total_flux

def optical_image_per_fluor(self) -> xr.DataArray:
Expand All @@ -214,11 +254,11 @@ def optical_image_per_fluor(self) -> xr.DataArray:
This is the emission from each fluorophore in each channel, after filtering by
the optical configuration and convolution with the PSF.

The return array has dimensions (C, F, Z, Y, X). The units are photons/s.
The return array has dimensions (S, C, F, Z, Y, X). The units are photons/s.
"""
# (C, F, Z, Y, X)
# (S, C, F, Z, Y, X)
return self.modality.render(
self.ground_truth(), # (F, Z, Y, X)
self.ground_truth(), # (S, F, Z, Y, X)
self.filtered_emission_rates(), # (C, F, W)
objective_lens=self.objective_lens,
settings=self.settings,
Expand All @@ -230,10 +270,10 @@ def optical_image(self) -> xr.DataArray:

This is the same as `optical_image_per_fluor`, but sums the contributions of all
fluorophores in each channel (which a detector would not know). The return
array has dimensions (C, Z, Y, X). The units are photons/s.
array has dimensions (S, C, Z, Y, X). The units are photons/s.
"""
oipf = self.optical_image_per_fluor()
return oipf.sum(Axis.F) # (C, Z, Y, X)
return oipf.sum(Axis.F) # (S, C, Z, Y, X)

def digital_image(
self,
Expand All @@ -245,13 +285,16 @@ def digital_image(
"""Return the digital image as captured by the detector.

This down-scales the optical image to the output space, and simulates the
detector response. The return array has dimensions (C, Z, Y, X). The units
are gray values, based on the bit-depth of the detector. If there is no
detector response. The return array has dimensions (S, C, [F], Z, Y, X). The
units are gray values, based on the bit-depth of the detector. If there is no
detector or `with_detector_noise` is False, the units are simply photons.

NOTE: the input `optical_image` can only contain a fluorophore dimension (e.g.,
it can be the output of `optical_image_per_fluor()`).
"""
if optical_image is None:
optical_image = self.optical_image()
image = optical_image # (C, Z, Y, X)
image = optical_image # (S, C, [F], Z, Y, X)

# downscale to output space
# TODO: consider how we would integrate detector pixel size
Expand Down Expand Up @@ -280,7 +323,7 @@ def digital_image(
image = image * (ch_exposures / 1000)
image.attrs.update(units="photons")

# (C, Z, Y, X)
# (S, C, Z, Y, X)
return image

def run(self) -> xr.DataArray:
Expand Down
Loading