Skip to content

Commit

Permalink
StereoExpData supports multiple spatial infomation.
Browse files Browse the repository at this point in the history
  • Loading branch information
tanliwei-coder committed Sep 13, 2024
1 parent 9af84c8 commit 09fc33b
Show file tree
Hide file tree
Showing 14 changed files with 296 additions and 206 deletions.
38 changes: 13 additions & 25 deletions stereo/algorithm/paste/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,15 +252,11 @@ def stack_slices_pairwise(

# new_slices = []
for i in range(len(slices)):
if isinstance(slices[i], AnnBasedStereoExpData):
if slices[i].position_z is not None:
slices[i].adata.obsm['spatial_paste_pairwise'] = np.concatenate((new_coor[i], slices[i].position_z), axis=1)
else:
slices[i].adata.obsm['spatial_paste_pairwise'] = new_coor[i]
slices[i].spatial_key = 'spatial_paste_pairwise'
if slices[i].position_z is not None:
slices[i].cells_matrix['spatial_paste_pairwise'] = np.concatenate((new_coor[i], slices[i].position_z), axis=1)
else:
slices[i].raw_position = slices[i].position
slices[i].position = new_coor[i]
slices[i].cells_matrix['spatial_paste_pairwise'] = new_coor[i]
slices[i].spatial_key = 'spatial_paste_pairwise'

if not output_params:
return slices
Expand Down Expand Up @@ -321,25 +317,17 @@ def stack_slices_center(
new_coor.append(y)

for i in range(len(slices)):
if isinstance(slices[i], AnnBasedStereoExpData):
if slices[i].position_z is not None:
slices[i].adata.obsm['spatial_paste_center'] = np.concatenate((new_coor[i], slices[i].position_z), axis=1)
else:
slices[i].adata.obsm['spatial_paste_center'] = new_coor[i]
slices[i].spatial_key = 'spatial_paste_center'
if slices[i].position_z is not None:
slices[i].cells_matrix['spatial_paste_center'] = np.concatenate((new_coor[i], slices[i].position_z), axis=1)
else:
slices[i].raw_position = slices[i].position
slices[i].position = new_coor[i]

if isinstance(center_slice, AnnBasedStereoExpData):
if center_slice.position_z is not None:
center_slice.adata.obsm['spatial_paste_center'] = np.concatenate((center_slice.position, center_slice.position_z), axis=1)
else:
center_slice.adata.obsm['spatial_paste_center'] = c
center_slice.spatial_key = 'spatial_paste_center'
slices[i].cells_matrix['spatial_paste_center'] = new_coor[i]
slices[i].spatial_key = 'spatial_paste_center'
if center_slice.position_z is not None:
center_slice.cells_matrix['spatial_paste_center'] = np.concatenate((center_slice.position, center_slice.position_z), axis=1)
else:
center_slice.raw_position = center_slice.position
center_slice.position = c
center_slice.cells_matrix['spatial_paste_center'] = c
center_slice.spatial_key = 'spatial_paste_center'

if not output_params:
return center_slice, slices
else:
Expand Down
19 changes: 8 additions & 11 deletions stereo/algorithm/paste/paste.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
Optional
)

try:
from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()
except:
pass

from stereo.algorithm.ms_algorithm_base import MSDataAlgorithmBase
from stereo.log_manager import logger
from .helper import stack_slices_center
from .helper import stack_slices_pairwise
from .methods import center_align
from .methods import pairwise_align

from .helper import stack_slices_center, stack_slices_pairwise
from .methods import center_align, pairwise_align

class Paste(MSDataAlgorithmBase):
def main(
Expand All @@ -33,12 +36,6 @@ def main(
"""
if method not in ('pairwise', 'center'):
raise ValueError(f'Error method({method}), it must be one of pairwise and center')

try:
from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()
except:
pass

logger.info(f'Using method {method}')
if method == 'pairwise':
Expand Down
6 changes: 6 additions & 0 deletions stereo/algorithm/st_gears/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from typing import Optional, Union, List

try:
from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()
except:
pass

from stereo.algorithm.ms_algorithm_base import MSDataAlgorithmBase
from stereo.core.stereo_exp_data import AnnBasedStereoExpData
from stereo.io.reader import stereo_to_anndata
Expand Down
4 changes: 2 additions & 2 deletions stereo/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@
@time:2021/03/17
"""
# flake8: noqa
from .st_pipeline import StPipeline
from .stereo_exp_data import StereoExpData
from .st_pipeline import StPipeline, AnnBasedStPipeline
from .stereo_exp_data import StereoExpData, AnnBasedStereoExpData
9 changes: 6 additions & 3 deletions stereo/core/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,10 @@ def sub_set(self, index):
:param index: a numpy array of index info.
:return: the subset of Cell object.
"""

if self.cell_border is not None:
self.cell_border = self.cell_border[index]
if isinstance(index, pd.Series):
index = index.to_numpy()
if self.cell_border is not None:
self.cell_border = self.cell_border[index]
self._obs = self._obs.iloc[index].copy()
for col in self._obs.columns:
if self._obs[col].dtype.name == 'category':
Expand Down Expand Up @@ -303,6 +302,10 @@ def __contains__(self, item):
def _obs(self):
return self.__based_ann_data.obs

@property
def obs(self):
return self.__based_ann_data.obs

@property
def matrix(self):
return self.__based_ann_data.obsm
Expand Down
30 changes: 26 additions & 4 deletions stereo/core/ms_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
import numpy as np
import pandas as pd

from . import StPipeline, StereoExpData
from . import(
StPipeline, AnnBasedStPipeline,
StereoExpData, AnnBasedStereoExpData
)
from .ms_pipeline import MSDataPipeLine
from ..plots.plot_collection import PlotCollection

Expand Down Expand Up @@ -270,16 +273,35 @@ class _MSDataStruct(object):
_data_dict: Dict[int, str] = field(default_factory=dict)
__idx_generator: int = _default_idx()

def __check_data_list(self, data_list):
if not isinstance(data_list, list):
raise TypeError('data_list must be a list object')
first_data = data_list[0]
for data in data_list[1:]:
if type(data) != type(first_data):
raise TypeError('each data in data_list must be the same type, available types: StereoExpData and AnnBasedStereoExpData')
return data_list

def __post_init__(self) -> object:
while len(self._data_list) > len(self._names):
self._names.append(self.__get_auto_key())
if not self._name_dict or not self._data_dict:
self.reset_name(default_key=False)
self.__check_data_list(self._data_list)
return self

def __iter__(self):
return iter(self._data_list)

@property
def data_list(self):
return self._data_list

@data_list.setter
def data_list(self, data_list: List[StereoExpData]):
self.__check_data_list(data_list)
assert len(data_list) == len(self._names), 'length of data_list must be equal to length of names'
self._data_list = list(data_list)

@property
def merged_data(self):
Expand Down Expand Up @@ -729,7 +751,7 @@ def integrate(self, scope=None, remove_existed=False, **kwargs):
batch_tags = None
else:
batch_tags = [self._names.index(name) for name in self[scope].names]
merged_data = merge(*data_list, var_type=self._var_type, batch_tags=batch_tags)
merged_data = merge(*data_list, var_type=self._var_type, batch_tags=batch_tags, **kwargs)
else:
merged_data = deepcopy(data_list[0])
batch = self._names.index(self[scope].names[0])
Expand Down Expand Up @@ -1066,5 +1088,5 @@ def write(self, filename, to_mudata=False):
return write_h5mu(self, filename)


TL = type('TL', (MSDataPipeLine,), {'ATTR_NAME': 'tl', "BASE_CLASS": StPipeline})
PLT = type('PLT', (MSDataPipeLine,), {'ATTR_NAME': 'plt', "BASE_CLASS": PlotCollection})
TL = type('TL', (MSDataPipeLine,), {'ATTR_NAME': 'tl', "BASE_CLASS": None})
PLT = type('PLT', (MSDataPipeLine,), {'ATTR_NAME': 'plt', "BASE_CLASS": None})
34 changes: 13 additions & 21 deletions stereo/core/ms_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __getitem__(self, item):

class MSDataPipeLine(object):
ATTR_NAME = 'tl'
BASE_CLASS = StPipeline
BASE_CLASS = None

def __init__(self, _ms_data):
super().__init__()
Expand All @@ -34,6 +34,7 @@ def __init__(self, _ms_data):
self._key_record = dict()
self.__mode = "integrate"
self.__scope = slice(None)
self.__class__.BASE_CLASS = getattr(self.ms_data[0], self.__class__.ATTR_NAME).__class__

@property
def result(self):
Expand Down Expand Up @@ -91,12 +92,6 @@ def _use_integrate_method(self, item, *args, **kwargs):
scope = kwargs.get("scope", slice(None))
del kwargs["scope"]

# if item in {"cal_qc", "filter_cells", "filter_genes", "sctransform", "log1p", "normalize_total",
# "scale", "raw_checkpoint", "batches_integrate"}:
# if scope != slice(None):
# raise Exception(f'{item} use integrate should use all sample')
# ms_data_view = self.ms_data
# elif scope == slice(None):
if len(self.ms_data[scope]) == len(self.ms_data):
ms_data_view = self.ms_data
if ms_data_view.merged_data is None:
Expand All @@ -107,15 +102,8 @@ def _use_integrate_method(self, item, *args, **kwargs):
scope_key = self.ms_data.generate_scope_key(ms_data_view._names)
self.ms_data.scopes_data[scope_key] = ms_data_view.merged_data

# def set_result_key_method(key):
# self.result_keys.setdefault(scope_key, [])
# if key in self.result_keys[scope_key]:
# self.result_keys[scope_key].remove(key)
# self.result_keys[scope_key].append(key)

# ms_data_view.merged_data.tl.result.set_result_key_method = set_result_key_method

new_attr = self.__class__.BASE_CLASS.__dict__.get(item, None)
# new_attr = self.__class__.BASE_CLASS.__dict__.get(item, None)
new_attr = getattr(self.__class__.BASE_CLASS, item, None)
if new_attr is None:
if self.__class__.ATTR_NAME == "tl":
from stereo.algorithm.algorithm_base import AlgorithmBase
Expand All @@ -134,7 +122,8 @@ def _use_integrate_method(self, item, *args, **kwargs):

logger.info(f'data_obj(idx=0) in ms_data start to run {item}')
return new_attr(
ms_data_view.merged_data.__getattribute__(self.__class__.ATTR_NAME),
# ms_data_view.merged_data.__getattribute__(self.__class__.ATTR_NAME),
getattr(ms_data_view.merged_data, self.__class__.ATTR_NAME),
*args,
**kwargs
)
Expand All @@ -146,23 +135,26 @@ def _run_isolated_method(self, item, *args, **kwargs):
if "scope" in kwargs:
del kwargs["scope"]

new_attr = self.__class__.BASE_CLASS.__dict__.get(item, None)
# new_attr = self.__class__.BASE_CLASS.__dict__.get(item, None)
new_attr = getattr(self.__class__.BASE_CLASS, item, None)
if self.__class__.ATTR_NAME == 'tl':
n_jobs = min(len(ms_data_view.data_list), cpu_count())
else:
n_jobs = 1
if new_attr:
def log_delayed_task(idx, *arg, **kwargs):
def log_delayed_task(idx, obj, *arg, **kwargs):
logger.info(f'data_obj(idx={idx}) in ms_data start to run {item}')
if self.__class__.ATTR_NAME == 'plt':
out_path = kwargs.get('out_path', None)
if out_path is not None:
path_name, ext = os.path.splitext(out_path)
kwargs['out_path'] = f'{path_name}_{idx}{ext}'
new_attr(*arg, **kwargs)
tl_or_plt = getattr(obj, self.__class__.ATTR_NAME)
new_attr(tl_or_plt, *arg, **kwargs)

Parallel(n_jobs=n_jobs, backend='threading', verbose=100)(
delayed(log_delayed_task)(idx, obj.__getattribute__(self.__class__.ATTR_NAME), *args, **kwargs)
# delayed(log_delayed_task)(idx, obj.__getattribute__(self.__class__.ATTR_NAME), *args, **kwargs)
delayed(log_delayed_task)(idx, obj, *args, **kwargs)
for idx, obj in enumerate(ms_data_view.data_list)
)
else:
Expand Down
Loading

0 comments on commit 09fc33b

Please sign in to comment.