feat: add possibility to simulate images in batches #95
base: main
Changes from all commits
adc3ebd
7b51271
af726d4
c348a82
7887206
d27a44e
f003f39
a23e079
86a20e6
381a95c
e416deb
38954bb
6da8278
9b02463
a9ae073
f468461
2dd3b77
f1d9f71
@@ -1,12 +1,13 @@
 import logging
 import warnings
 from pathlib import Path
-from typing import TYPE_CHECKING, Annotated
+from typing import TYPE_CHECKING, Annotated, Any
 
 import numpy as np
 import pandas as pd
 import xarray as xr
-from pydantic import AfterValidator, Field, model_validator
+from annotated_types import MinLen
+from pydantic import AfterValidator, Field, field_validator, model_validator
 
 from microsim._data_array import ArrayProtocol, from_cache, to_cache
 from microsim.util import microsim_cache
@@ -53,7 +54,7 @@ class Simulation(SimBaseModel):
 
     truth_space: Space
     output_space: Space | None = None
-    sample: Sample
+    samples: Annotated[list[Sample], MinLen(1)]
     modality: Modality = Field(default_factory=Widefield)
     objective_lens: ObjectiveLens = Field(default_factory=ObjectiveLens)
     channels: list[OpticalConfig] = Field(default_factory=lambda: [FITC])
@@ -93,6 +94,28 @@ def _resolve_spaces(self) -> "Self":
             self.output_space.reference = self.truth_space
         return self
 
+    @model_validator(mode="after")
+    def _check_fluorophores_equal_in_samples(self) -> "Self":
+        fp_names = [{lbl.fluorophore.name for lbl in s.labels} for s in self.samples]
+        if len({frozenset(s) for s in fp_names}) != 1:
+            raise ValueError(
+                "All samples in the batch must use the same set of fluorophores."
+            )
+        return self
+
+    @field_validator("samples")
+    def _samples_to_list(value: Any) -> list[Any]:
+        return [value] if not isinstance(value, list | tuple) else value
+
+    @property
+    def sample(self) -> Sample:
+        warnings.warn(
+            "The `sample` attribute is deprecated. Use `samples` instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return self.samples[0]
+
     @property
     def _xp(self) -> "NumpyAPI":
         return self.settings.backend_module()

Review comment: I suppose we don't have to worry about the case where there are zero samples, right? If that's the case, we might want to explicitly require that in the annotation: `samples: Annotated[list[Sample], MinLen(1)]`.

Reply: I totally agree with your point. However, I don't fully get the relation to the lines your comment points at.
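To make the validation behavior discussed above concrete, here is a minimal, self-contained sketch (toy models, not microsim's actual classes) of a `samples` field constrained with `MinLen(1)`, single-sample coercion, and a deprecated `sample` accessor. For illustration, the coercion is written as an explicit `mode="before"` classmethod validator, which differs slightly from the diff above; it assumes pydantic v2.

```python
import warnings
from typing import Annotated, Any

from annotated_types import MinLen
from pydantic import BaseModel, ValidationError, field_validator


class Sample(BaseModel):
    """Toy stand-in for microsim's Sample."""

    name: str


class Sim(BaseModel):
    """Toy stand-in for the Simulation model."""

    samples: Annotated[list[Sample], MinLen(1)]

    @field_validator("samples", mode="before")
    @classmethod
    def _coerce_to_list(cls, value: Any) -> Any:
        # allow passing a single sample instead of a list of samples
        return [value] if not isinstance(value, (list, tuple)) else value

    @property
    def sample(self) -> Sample:
        warnings.warn(
            "`sample` is deprecated, use `samples` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.samples[0]


print(Sim(samples=Sample(name="beads")).samples)  # coerced to a one-element list

try:
    Sim(samples=[])  # rejected by MinLen(1)
except ValidationError as err:
    print(err.errors()[0]["type"])  # "too_short"
```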
@@ -101,7 +124,7 @@ def ground_truth(self) -> xr.DataArray:
         """Return the ground truth data.
 
         Returns position and quantity of fluorophores in the sample. The return array
-        has dimensions (F, Z, Y, X). The units are fluorophores counts.
+        has dimensions (S, F, Z, Y, X). The units are fluorophores counts.
 
         Examples
         --------
@@ -112,36 +135,51 @@ def ground_truth(self) -> xr.DataArray:
         """
         if not hasattr(self, "_ground_truth"):
             xp = self._xp
-            # make empty space into which we'll add the ground truth
-            # TODO: this is wasteful... label.render should probably
-            # accept the space object directly
-            truth = self.truth_space.create(array_creator=xp.zeros)
-
-            # render each ground truth
-            label_data = []
-            for label in self.sample.labels:
-                cache_path = self._truth_cache_path(
-                    label, self.truth_space, self.settings.random_seed
-                )
-                if self.settings.cache.read and cache_path and cache_path.exists():
-                    data = from_cache(cache_path, xp=xp).astype(
-                        self.settings.float_dtype
-                    )
-                    logging.info(
-                        f"Loaded ground truth for {label} from cache: {cache_path}"
+
+            # render each ground sample
+            # TODO: iterating over sample is a bottleneck for batched approach... Ideas?
+            truths: list[xr.DataArray] = []
+            for sample in self.samples:
+                # render each label in the sample
+
+                # make empty space into which we'll add the ground truth
+                # TODO: this is wasteful... label.render should probably
+                # accept the space object directly
+                truth = self.truth_space.create(array_creator=xp.zeros)
+
+                label_data = []
+                for label in sample.labels:
+                    # TODO: differentiate caching for each label
+                    cache_path = self._truth_cache_path(
+                        label, self.truth_space, self.settings.random_seed
                     )
-                else:
-                    data = label.render(truth, xp=xp)
-                    if self.settings.cache.write and cache_path:
-                        to_cache(data, cache_path, dtype=np.uint16)
-
-                label_data.append(data)
-
-            # concat along the F axis
-            fluors = [lbl.fluorophore for lbl in self.sample.labels]
-            truth = xr.concat(label_data, dim=pd.Index(fluors, name=Axis.F))
-            truth.attrs.update(units="fluorophores", long_name="Ground Truth")
-            self._ground_truth = truth
+                    if self.settings.cache.read and cache_path and cache_path.exists():
+                        data = from_cache(cache_path, xp=xp).astype(
+                            self.settings.float_dtype
+                        )
+                        logging.info(
+                            f"Loaded ground truth for {label} from cache: {cache_path}"
+                        )
+                    else:
+                        data = label.render(truth, xp=xp)
+                        if self.settings.cache.write and cache_path:
+                            to_cache(data, cache_path, dtype=np.uint16)
+
+                    label_data.append(data)
+
+                # concat along the F axis
+                fluors = [lbl.fluorophore for lbl in sample.labels]
+                truth = xr.concat(label_data, dim=pd.Index(fluors, name=Axis.F))
+                truths.append(truth)
+
+            # concat along B axis
+            self._ground_truth = xr.concat(
+                truths, dim=pd.Index(range(len(truths)), name=Axis.S)
+            )  # TODO: is there a better way to give coords to this axis?
+            self._ground_truth.attrs.update(
+                units="fluorophores", long_name="Ground Truth"
+            )
+
         return self._ground_truth
 
     def filtered_emission_rates(self) -> xr.DataArray:
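As a side note on the pattern above: the batched ground truth is just `xr.concat` along a new leading index. A minimal, self-contained sketch with toy data (not microsim's API) showing how such an (S, F, Z, Y, X) array is built and sliced back into per-sample arrays:

```python
import numpy as np
import pandas as pd
import xarray as xr

fluors = ["EGFP", "mCherry"]  # stand-ins for Fluorophore objects
per_sample = [
    xr.DataArray(
        np.random.poisson(5.0, size=(len(fluors), 4, 8, 8)).astype(float),
        dims=("F", "Z", "Y", "X"),
        coords={"F": fluors},
    )
    for _ in range(3)  # three samples in the batch
]

# stack along a new leading "S" axis, mirroring the ground_truth() change
batch = xr.concat(per_sample, dim=pd.Index(range(len(per_sample)), name="S"))
print(batch.dims)            # ('S', 'F', 'Z', 'Y', 'X')
print(batch.isel(S=0).dims)  # one sample again: ('F', 'Z', 'Y', 'X')
```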
@@ -183,7 +221,7 @@ def emission_flux(self) -> xr.DataArray:
 
         This multiplies the per-fluorophore emission rates by the ground truth data to
         get the total emission flux for each voxel in the ground truth. The return
-        array has dimensions (C, F, Z, Y, X). The units are photons/s.
+        array has dimensions (S, C, F, Z, Y, X). The units are photons/s.
 
         Note, this integrates over all wavelengths. For finer control over the emission
         spectrum, you may wish to directly combine filtered_emission_rates with the
@@ -202,10 +240,12 @@ def emission_flux(self) -> xr.DataArray:
             return truth
 
         # total photons/s emitted by each fluorophore in each channel
-        total_flux = self.filtered_emission_rates().sum(Axis.W) * truth
+        total_flux = (self.filtered_emission_rates().sum(Axis.W) * truth).transpose(
+            Axis.S, ...
+        )
         total_flux.attrs.update(units="photon/sec", long_name="Emission Flux")
 
-        # (C, F, Z, Y, X)
+        # (S, C, F, Z, Y, X)
         return total_flux
 
     def optical_image_per_fluor(self) -> xr.DataArray:
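The `.transpose(Axis.S, ...)` call above leans on xarray accepting an `Ellipsis` to mean "all remaining dimensions, in their current order", which is what moves the new sample axis to the front. A tiny sketch with toy dimension names:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.zeros((2, 3, 4)), dims=("C", "F", "S"))

# move "S" to the front, keep the relative order of the other dims
print(da.transpose("S", ...).dims)  # ('S', 'C', 'F')
```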
@@ -214,11 +254,11 @@ def optical_image_per_fluor(self) -> xr.DataArray:
         This is the emission from each fluorophore in each channel, after filtering by
         the optical configuration and convolution with the PSF.
 
-        The return array has dimensions (C, F, Z, Y, X). The units are photons/s.
+        The return array has dimensions (S, C, F, Z, Y, X). The units are photons/s.
         """
-        # (C, F, Z, Y, X)
+        # (S, C, F, Z, Y, X)
         return self.modality.render(
-            self.ground_truth(),  # (F, Z, Y, X)
+            self.ground_truth(),  # (S, F, Z, Y, X)
             self.filtered_emission_rates(),  # (C, F, W)
             objective_lens=self.objective_lens,
             settings=self.settings,
@@ -230,10 +270,10 @@ def optical_image(self) -> xr.DataArray:
 
         This is the same as `optical_image_per_fluor`, but sums the contributions of all
         fluorophores in each channel (which a detector would not know). The return
-        array has dimensions (C, Z, Y, X). The units are photons/s.
+        array has dimensions (S, C, Z, Y, X). The units are photons/s.
         """
         oipf = self.optical_image_per_fluor()
-        return oipf.sum(Axis.F)  # (C, Z, Y, X)
+        return oipf.sum(Axis.F)  # (S, C, Z, Y, X)
 
     def digital_image(
         self,
@@ -245,13 +285,16 @@ def digital_image(
         """Return the digital image as captured by the detector.
 
         This down-scales the optical image to the output space, and simulates the
-        detector response. The return array has dimensions (C, Z, Y, X). The units
-        are gray values, based on the bit-depth of the detector. If there is no
+        detector response. The return array has dimensions (S, C, [F], Z, Y, X). The
+        units are gray values, based on the bit-depth of the detector. If there is no
         detector or `with_detector_noise` is False, the units are simply photons.
+
+        NOTE: the input `optical_image` can only contain a fluorophore dimension (e.g.,
+        it can be the output of `optical_image_per_fluor()`).
         """
         if optical_image is None:
             optical_image = self.optical_image()
-        image = optical_image  # (C, Z, Y, X)
+        image = optical_image  # (S, C, [F], Z, Y, X)
 
         # downscale to output space
         # TODO: consider how we would integrate detector pixel size
@@ -280,7 +323,7 @@ def digital_image(
             image = image * (ch_exposures / 1000)
             image.attrs.update(units="photons")
 
-        # (C, Z, Y, X)
+        # (S, C, Z, Y, X)
         return image
 
     def run(self) -> xr.DataArray:
Review comment: Curious about this. I can see that this would be necessary in the case where we assume that the sample "axis" is just an extension of a non-jagged array. But maybe sample should be special in that case, and maybe samples should always be iterated over in a for-loop rather than processed in a vectorized fashion? (That might help with memory bloat too.)

Reply: That can be decided later. It's definitely easier to start with this restriction and relax it later.
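To make the trade-off in this thread concrete: if samples are stacked into one non-jagged array, `xr.concat` over samples whose fluorophore coordinates differ falls back to an outer join and pads with NaNs, which is why the validator requires an identical fluorophore set across samples. A small, self-contained illustration with toy data (not microsim's API):

```python
import numpy as np
import pandas as pd
import xarray as xr


def sample_truth(fluors: list[str]) -> xr.DataArray:
    # one toy (F, Z, Y, X) ground truth per sample
    return xr.DataArray(
        np.ones((len(fluors), 2, 4, 4)),
        dims=("F", "Z", "Y", "X"),
        coords={"F": fluors},
    )


a = sample_truth(["EGFP", "mCherry"])
b = sample_truth(["EGFP", "mVenus"])  # different fluorophore set

batch = xr.concat([a, b], dim=pd.Index([0, 1], name="S"))
print(batch.sizes["F"])            # 3: the union of all fluorophores
print(bool(batch.isnull().any()))  # True: NaN padding where a fluorophore is absent
```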