Skip to content

Commit

Permalink
Issue314 (#324)
Browse files Browse the repository at this point in the history
* Increase default timeout

* Remove cell limit

* Increase payload limit for tcp server

* Remove unnecessary db query work

* Add more caching

* Increase verbosity

* Increase buffer size for tcp transfer

* Revert buffer, waiting indefinitely

* Make docker context slimmer

* Include some cells data logic

* Table formatting

* Correct typo

* Add section to doc

* Include byte order clarification and outline of test

* Force int representation

* Add some debug

* Cache some intermediate results for performance on cells data

* Fix unhashable keys

* Switch endian

* Deal with implications of hashable type change

* More list to tuple

* Restore timeouts

* Fix table format

* Investigate test failure

* Update test

* Reduce precision of squidpy test

* Fix rounding algorithm

* Including max and min location data

* Add feature names endpoint

* Remove debug

* Update doc link

* Update link target

* Describe preview method
  • Loading branch information
jimmymathews authored Jun 6, 2024
1 parent 5b5f37c commit 52ed062
Show file tree
Hide file tree
Showing 31 changed files with 709 additions and 68 deletions.
4 changes: 3 additions & 1 deletion build/build_scripts/.dockerignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
.git
docs/image_assets/**
.venv
venv
venv
**/.mypy_cache
**/.mypy
32 changes: 32 additions & 0 deletions docs/cells.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@

# Data structure to represent summarized cell-level data for one slide

For memory- and time-efficient manipulation, a simple binary data structure is used to represent the cell data for one slide.

The memory map is as follows:

| Byte range start | Byte range end | Number of bytes | Section | Description | Data type |
|------------------|----------------|-----------------|---------|----------------------------------------------------------|---------------------------------------|
| 1 | 4 | 4 | Header | The number of cells represented by this serialization. | 32-bit integer, big-endian byte order |
| 5 | 8 | 4 | | Minimum first pixel coordinate occurring in this file. | 32-bit integer, big-endian byte order |
| 9 | 12 | 4 | | Maximum first pixel coordinate occurring in this file. | 32-bit integer, big-endian byte order |
| 13 | 16 | 4 | | Minimum second pixel coordinate occurring in this file. | 32-bit integer, big-endian byte order |
| 17 | 20 | 4 | | Maximum second pixel coordinate occurring in this file. | 32-bit integer, big-endian byte order |
| 21 | 24 | 4 | Cell 1 | Cell 1 index integer. | 32-bit integer, big-endian byte order |
| 25 | 28 | 4 | | Cell 1 location's first pixel coordinate integer. | 32-bit integer, big-endian byte order |
| 29 | 32 | 4 | | Cell 1 location's second pixel coordinate integer. | 32-bit integer, big-endian byte order |
| 33 | 40 | 8 | | Cell 1 phenotype membership bit-mask, up to 64 channels. | 64-bit mask |
| 41               | 44             | 4               | Cell 2  | Cell 2 index integer.                                      | 32-bit integer, big-endian byte order |
| ...              | ...            | ...             | ...     | ...                                                        | ...                                   |

The ellipsis represents repetition of the per-cell section once for each cell. This is 4 + 4 + 4 + 8 = 20 bytes per cell. The "header" preceding the per-cell sections is 20 bytes.

A representation of an example of the cell sections can be found [here](https://github.com/nadeemlab/SPT/blob/main/test/apiserver/module_tests/celldata.dump).

There is a convenient way to preview the contents at the command line using `xxd`:

```sh
tail -c +21 payload.bin | xxd -b -c 20
```

(The initial `tail` command strips out the header.)
88 changes: 74 additions & 14 deletions spatialprofilingtoolbox/apiserver/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import secure

from spatialprofilingtoolbox.db.simple_method_cache import simple_function_cache
from spatialprofilingtoolbox.db.study_tokens import StudyCollectionNaming
from spatialprofilingtoolbox.ondemand.service_client import OnDemandRequester
from spatialprofilingtoolbox.db.exchange_data_formats.study import StudyHandle
Expand All @@ -29,8 +30,9 @@
PhenotypeCount,
UnivariateMetricsComputationResult,
CellData,
AvailableGNN,
AvailableGNN
)
from spatialprofilingtoolbox.db.exchange_data_formats.cells import BitMaskFeatureNames
from spatialprofilingtoolbox.db.exchange_data_formats.metrics import UMAPChannel
from spatialprofilingtoolbox.db.querying import query
from spatialprofilingtoolbox.apiserver.app.validation import (
Expand Down Expand Up @@ -71,7 +73,7 @@
},
)

CELL_DATA_CELL_LIMIT = 100001
CELL_DATA_CELL_LIMIT = 5 * int(pow(10, 6))


def custom_openapi():
Expand Down Expand Up @@ -196,6 +198,7 @@ async def get_anonymous_phenotype_counts_fast(
"""
return _get_anonymous_phenotype_counts_fast(positive_marker, negative_marker, study)


def _get_anonymous_phenotype_counts_fast(
positive_marker: ValidChannelListPositives,
negative_marker: ValidChannelListNegatives,
Expand All @@ -204,6 +207,7 @@ def _get_anonymous_phenotype_counts_fast(
number_cells = cast(int, query().get_number_cells(study))
return get_phenotype_counts(positive_marker, negative_marker, study, number_cells)


@app.get("/request-spatial-metrics-computation/")
async def request_spatial_metrics_computation(
study: ValidStudy,
Expand All @@ -218,8 +222,8 @@ async def request_spatial_metrics_computation(
]
markers: list[list[str]] = []
for criterion in criteria:
markers.append(criterion.positive_markers)
markers.append(criterion.negative_markers)
markers.append(list(criterion.positive_markers))
markers.append(list(criterion.negative_markers))
return get_squidpy_metrics(study, markers, feature_class, radius=radius)


Expand Down Expand Up @@ -317,6 +321,25 @@ def _get_importance_composition(
)


@simple_function_cache()
def get_phenotype_counts_cached(
    positives: tuple[str, ...],
    negatives: tuple[str, ...],
    study: str,
    number_cells: int,
    selected: tuple[int, ...] | None,
) -> PhenotypeCounts:
    """Memoized phenotype-count lookup via the on-demand counts service.

    All arguments are hashable (tuples rather than lists/sets) so that the
    simple function cache can key on them; `selected` is an optional sorted
    tuple of cell identifiers restricting the count, or None for all cells.
    It is converted back to a set before being passed to the requester.
    """
    with OnDemandRequester(service='counts') as requester:
        counts = requester.get_counts_by_specimen(
            positives,
            negatives,
            study,
            number_cells,
            set(selected) if selected is not None else None,
        )
    return counts


def get_phenotype_counts(
positive_marker: ValidChannelListPositives,
negative_marker: ValidChannelListNegatives,
Expand All @@ -327,15 +350,13 @@ def get_phenotype_counts(
"""For each specimen, return the fraction of selected/all cells expressing the phenotype."""
positive_markers = [m for m in positive_marker if m != '']
negative_markers = [m for m in negative_marker if m != '']
with OnDemandRequester(service='counts') as requester:
counts = requester.get_counts_by_specimen(
positive_markers,
negative_markers,
study,
number_cells,
cells_selected,
)
return counts
return get_phenotype_counts_cached(
tuple(positive_markers),
tuple(negative_markers),
study,
number_cells,
tuple(sorted(list(cells_selected))) if cells_selected is not None else None,
)


def get_proximity_metrics(
Expand Down Expand Up @@ -389,6 +410,44 @@ def match(c: PhenotypeCount) -> bool:
return payload


@app.get("/cell-data-binary/")
async def get_cell_data_binary(
study: ValidStudy,
sample: Annotated[str, Query(max_length=512)],
):
"""
Get streaming cell-level location and phenotype data in a custom binary format.
The format is documented [here](https://github.com/nadeemlab/SPT/blob/main/docs/cells.md).
"""
if not sample in query().get_sample_names(study):
raise HTTPException(status_code=404, detail=f'Sample "{sample}" does not exist.')
number_cells = cast(int, query().get_number_cells(study))
def match(c: PhenotypeCount) -> bool:
return c.specimen == sample
count = tuple(filter(
match,
get_phenotype_counts([], [], study, number_cells,
).counts))[0].count
if count is None or count > CELL_DATA_CELL_LIMIT:
message = f'Sample "{sample}" has too many cells: {count}.'
raise HTTPException(status_code=404, detail=message)
data = query().get_cells_data(study, sample)
input_buffer = BytesIO(data)
input_buffer.seek(0)
def streaming_iteration():
yield from input_buffer
return StreamingResponse(streaming_iteration(), media_type="application/octet-stream")


@app.get("/cell-data-binary-feature-names/")
async def get_cell_data_binary_feature_names(study: ValidStudy) -> BitMaskFeatureNames:
"""
Get the features corresponding to the channels in the binary/bitmask representation of a cell's
channel positivity/negativity assignments.
"""
return query().get_ordered_feature_names(study)


@app.get("/visualization-plots/")
async def get_plots(
study: ValidStudy,
Expand All @@ -402,7 +461,8 @@ async def get_plot_high_resolution(
study: ValidStudy,
channel: ValidChannel,
):
"""One full-resolution UMAP plot (for the given channel in the given study), provided as a
"""
One full-resolution UMAP plot (for the given channel in the given study), provided as a
streaming PNG.
"""
umap = query().get_umap(study, channel)
Expand Down
1 change: 1 addition & 0 deletions spatialprofilingtoolbox/db/accessors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from spatialprofilingtoolbox.db.accessors.phenotypes import PhenotypesAccess
from spatialprofilingtoolbox.db.accessors.study import StudyAccess
from spatialprofilingtoolbox.db.accessors.umap import UMAPAccess
from spatialprofilingtoolbox.db.accessors.cells import CellsAccess
160 changes: 160 additions & 0 deletions spatialprofilingtoolbox/db/accessors/cells.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
"""Convenience accessor of all cell data for a given sample."""
from pickle import loads as pickle_loads
from json import loads as json_loads
from typing import Any
from typing import Iterable
from itertools import islice
from itertools import product

from psycopg2.extensions import cursor as Psycopg2Cursor

from spatialprofilingtoolbox.db.exchange_data_formats.cells import CellsData
from spatialprofilingtoolbox.db.exchange_data_formats.cells import BitMaskFeatureNames
from spatialprofilingtoolbox.db.exchange_data_formats.metrics import Channel
from spatialprofilingtoolbox.db.database_connection import SimpleReadOnlyProvider
from spatialprofilingtoolbox.standalone_utilities.log_formats import colorized_logger

logger = colorized_logger(__name__)


class CellsAccess(SimpleReadOnlyProvider):
    """Retrieve cell-level data for a sample.

    Serializes per-sample cell locations and phenotype bitmasks into the
    binary format documented in docs/cells.md: a 20-byte header followed by
    20 bytes per cell, with all integers in big-endian byte order.
    """

    def get_cells_data(self, sample: str) -> CellsData:
        """Return the full binary cells payload (header plus per-cell records)."""
        return CellsAccess._zip_location_and_phenotype_data(
            self._get_location_data(sample),
            self._get_phenotype_data(sample),
        )

    def get_ordered_feature_names(self) -> BitMaskFeatureNames:
        """Return channel symbols ordered by their bitmask bit position.

        The ordering is by ascending target index from the study's stored
        'expressions_index' metadata blob.
        """
        expressions_index = json_loads(bytearray(self.fetch_one_or_else(
            '''
            SELECT blob_contents
            FROM ondemand_studies_index osi
            WHERE blob_type='expressions_index' ;
            ''',
            (),
            self.cursor,
            'No feature metadata for the given study.',
        )).decode('utf-8'))[''][0]
        lookup1: dict[str, int] = expressions_index['target index lookup']
        lookup2: dict[str, str] = expressions_index['target by symbol']
        target_from_index = {value: key for key, value in lookup1.items()}
        symbol_from_target = {value: key for key, value in lookup2.items()}
        indices = sorted(target_from_index.keys())
        names = tuple(symbol_from_target[target_from_index[i]] for i in indices)
        return BitMaskFeatureNames(
            names=tuple(Channel(symbol=n) for n in names)
        )

    def _get_location_data(self, sample: str) -> dict[int, tuple[float, float]]:
        """Load the pickled centroid map (cell identifier -> (x, y)) for one sample."""
        by_sample = pickle_loads(
            self.fetch_one_or_else(
                '''
                SELECT blob_contents
                FROM ondemand_studies_index
                WHERE specimen=%s AND blob_type='centroids' ;
                ''',
                (sample,),
                self.cursor,
                f'Requested centroids data for "{sample}" not found in database.'
            )
        )
        return by_sample[sample]

    def _get_phenotype_data(self, sample: str) -> dict[int, bytes]:
        """Load the phenotype blob: 16 bytes per cell, an 8-byte little-endian
        identifier followed by an 8-byte expression bitmask."""
        index_and_expressions = bytearray(self.fetch_one_or_else(
            '''
            SELECT blob_contents
            FROM ondemand_studies_index
            WHERE specimen=%s AND blob_type='feature_matrix' ;
            ''',
            (sample,),
            self.cursor,
            f'Requested phenotype data for "{sample}" not found in database.',
        ))
        byte_count = len(index_and_expressions)
        if byte_count % 16 != 0:
            message = f'Expected 16 bytes per cell in binary representation of phenotype data, got {byte_count}.'
            logger.error(message)
            raise ValueError(message)
        return dict(
            (int.from_bytes(batch[0:8], 'little'), bytes(batch[8:16]))
            for batch in self._batched(iter(index_and_expressions), 16)
        )

    @staticmethod
    def _batched(iterable: Iterable, batch_size: int):
        """Yield consecutive tuples of up to `batch_size` items from `iterable`."""
        iterator = iter(iterable)
        while batch := tuple(islice(iterator, batch_size)):
            yield batch

    @classmethod
    def _zip_location_and_phenotype_data(
        cls,
        location_data: dict[int, tuple[float, float]],
        phenotype_data: dict[int, bytes],
    ) -> CellsData:
        """Combine location and phenotype maps into the documented binary payload.

        Raises ValueError on a cell-set mismatch, an empty cell set, or an
        unexpected serialized size.
        """
        identifiers = sorted(location_data.keys())
        _identifiers = sorted(phenotype_data.keys())
        if _identifiers != identifiers:
            message = 'Mismatch of cell sets for location and phenotype data.'
            logger.error(message)
            raise ValueError(message)
        if not identifiers:
            # Header extrema are undefined with no cells; fail explicitly
            # rather than with an IndexError/min() error.
            message = 'No cells found in location and phenotype data.'
            logger.error(message)
            raise ValueError(message)
        cls._check_consecutive(identifiers)
        combined = tuple(
            (i, location_data[i], phenotype_data[i])
            for i in identifiers
        )
        serial = b''.join(map(cls._format_cell_bytes, combined))
        if len(serial) % 20 != 0:
            message = f'Expected exactly 20 bytes per cell to be created. Got total {len(serial)}.'
            logger.error(message)
            raise ValueError(message)
        cell_count = len(serial) // 20

        first_coordinates = [pair[0] for pair in location_data.values()]
        second_coordinates = [pair[1] for pair in location_data.values()]
        # Header: cell count, then min/max of each coordinate, as 4-byte
        # big-endian integers (byteorder made explicit; defaulting requires
        # Python 3.11+).
        header = b''.join(
            int(value).to_bytes(4, 'big')
            for value in (
                cell_count,
                min(first_coordinates),
                max(first_coordinates),
                min(second_coordinates),
                max(second_coordinates),
            )
        )
        return b''.join((header, serial))

    @classmethod
    def _check_consecutive(cls, identifiers: list[int]):
        """Log a warning (once) if identifiers are not consecutive from the first value."""
        offset = identifiers[0]
        for position, identifier in enumerate(identifiers):
            if identifier != position + offset:
                message = (
                    f'Identifiers {identifiers[0]}..{identifiers[-1]} not consecutive: '
                    f'{identifier} should be {position + offset}.'
                )
                logger.warning(message)
                break

    @classmethod
    def _format_cell_bytes(cls, args: tuple[int, tuple[float, float], bytes]) -> bytes:
        """Serialize one cell: 4-byte index, 4-byte x, 4-byte y (big-endian), 8-byte mask."""
        identifier, location, phenotype = args
        return b''.join((
            identifier.to_bytes(4, 'big'),
            int(location[0]).to_bytes(4, 'big'),
            int(location[1]).to_bytes(4, 'big'),
            phenotype,
        ))

    @staticmethod
    def fetch_one_or_else(
        query: str,
        args: tuple,
        cursor: Psycopg2Cursor,
        error_message: str,
    ) -> Any:
        """Execute `query` and return the first column of the first row,
        raising ValueError with `error_message` if no row is found."""
        cursor.execute(query, args)
        fetched = cursor.fetchone()
        if fetched is None:
            logger.error(error_message)
            raise ValueError(error_message)
        return fetched[0]
6 changes: 4 additions & 2 deletions spatialprofilingtoolbox/db/accessors/phenotypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def get_phenotype_criteria(self, study: str, phenotype_symbol: str) -> Phenotype
negatives = sorted([
marker for marker, polarity in rows if polarity == 'negative'
])
return PhenotypeCriteria(positive_markers=positives, negative_markers=negatives)
return PhenotypeCriteria(positive_markers=tuple(positives), negative_markers=tuple(negatives))

def get_phenotype_criteria_by_identifier(
self,
Expand All @@ -84,7 +84,9 @@ def get_phenotype_criteria_by_identifier(
rows = self.cursor.fetchall()
positives = sorted([str(row[0]) for row in rows if row[1] == 'positive'])
negatives = sorted([str(row[0]) for row in rows if row[1] == 'negative'])
return PhenotypeCriteria(positive_markers=positives, negative_markers=negatives)
return PhenotypeCriteria(
positive_markers=tuple(positives), negative_markers=tuple(negatives),
)

def get_channel_names(self, study: str) -> tuple[str, ...]:
components = StudyAccess(self.cursor).get_study_components(study)
Expand Down
Loading

0 comments on commit 52ed062

Please sign in to comment.