Skip to content

Commit

Permalink
add function for converting MSData to MuData
Browse files Browse the repository at this point in the history
  • Loading branch information
tanliwei-coder committed Jul 26, 2024
1 parent 24aa57f commit 0861152
Show file tree
Hide file tree
Showing 6 changed files with 298 additions and 79 deletions.
16 changes: 13 additions & 3 deletions stereo/core/ms_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,10 @@ def relationship(self, value: str):
@property
def relationship_info(self):
    """Auxiliary metadata attached to the inter-sample relationship.

    Companion to the `relationship` property; the dict's schema is
    caller-defined (not constrained here).
    """
    return self._relationship_info

@relationship_info.setter
def relationship_info(self, value: dict):
    """Replace the stored relationship metadata wholesale."""
    self._relationship_info = value

def reset_position(self, mode='integrate'):
if mode == 'integrate' and self.merged_data:
Expand Down Expand Up @@ -863,10 +867,12 @@ def to_integrate(
if scope_key in self._scopes_data:
self._scopes_data[scope_key].tl.reset_key_record('cluster', res_key)
self._scopes_data[scope_key].tl.result.set_result_key_method(res_key)
self._scopes_data[scope_key].cells[res_key] = self._scopes_data[scope_key].cells[res_key].astype('category')

if self._merged_data is not None:
self._merged_data.tl.reset_key_record('cluster', res_key)
self._merged_data.tl.result.set_result_key_method(res_key)
self._merged_data.cells[res_key] = self._merged_data.cells[res_key].astype('category')
elif type == 'var':
raise NotImplementedError
else:
Expand Down Expand Up @@ -1033,9 +1039,13 @@ def __str__(self):
def __repr__(self):
return self.__str__()

def write(self, filename):
from stereo.io.writer import write_h5ms
write_h5ms(self, filename)
def write(self, filename, to_mudata=False):
    """Persist this MSData object to disk.

    Parameters
    ----------
    filename
        Target path for the output file.
    to_mudata
        If True, convert to MuData and save as h5mu; otherwise save in the
        native h5ms format.

    NOTE(review): only the h5mu branch propagates the writer's return value;
    the h5ms branch returns None — presumably intentional, confirm.
    """
    if to_mudata:
        from stereo.io.writer import write_h5mu
        return write_h5mu(self, filename)
    from stereo.io.writer import write_h5ms
    write_h5ms(self, filename)


# Dynamically built pipeline delegate: a MSDataPipeLine subclass that forwards
# calls to StPipeline; ATTR_NAME 'tl' suggests it is exposed as `.tl` on MSData.
TL = type('TL', (MSDataPipeLine,), {'ATTR_NAME': 'tl', "BASE_CLASS": StPipeline})
Expand Down
8 changes: 8 additions & 0 deletions stereo/core/ms_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,14 @@ def _use_integrate_method(self, item, *args, **kwargs):
scope_key = self.ms_data.generate_scope_key(ms_data_view._names)
self.ms_data.scopes_data[scope_key] = ms_data_view.merged_data

def set_result_key_method(key):
    # Record `key` as the most recent result key for this scope, keeping
    # the per-scope list free of duplicates (move-to-end semantics).
    scope_keys = self.result_keys.setdefault(scope_key, [])
    if key in scope_keys:
        scope_keys.remove(key)
    scope_keys.append(key)

# Redirect the merged data's result-key bookkeeping into this pipeline's
# per-scope record.
ms_data_view.merged_data.tl.result.set_result_key_method = set_result_key_method

# def callback_func(key, value):
# # key_name = "scope_[" + ",".join(
# # [str(self.ms_data._names.index(name)) for name in ms_data_view._names]) + "]"
Expand Down
40 changes: 29 additions & 11 deletions stereo/core/st_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(self, data: Union[StereoExpData, AnnBasedStereoExpData]):
self.data: Union[StereoExpData, AnnBasedStereoExpData] = data
self.result = Result(data)
self._raw: Union[StereoExpData, AnnBasedStereoExpData] = None
self.key_record = {'hvg': [], 'pca': [], 'neighbors': [], 'umap': [], 'cluster': [], 'marker_genes': []}
self._key_record = {'hvg': [], 'pca': [], 'neighbors': [], 'umap': [], 'cluster': [], 'marker_genes': []}
# self.reset_key_record = self._reset_key_record

def __getattr__(self, item):
Expand All @@ -81,6 +81,10 @@ def __getattr__(self, item):
f'{item} not existed, please check the function name you called!'
)

@property
def key_record(self):
    """dict: maps analysis categories ('hvg', 'pca', 'neighbors', 'umap',
    'cluster', 'marker_genes') to the list of result keys recorded for each.
    """
    return self._key_record

@property
def raw(self) -> Union[StereoExpData, AnnBasedStereoExpData]:
"""
Expand All @@ -107,7 +111,14 @@ def reset_raw_data(self):
:return:
"""
self.data = self.raw
# self.data = self.raw
self.data.exp_matrix = copy.deepcopy(self.raw.exp_matrix)
self.data.cells = copy.deepcopy(self.raw.cells)
self.data.genes = copy.deepcopy(self.raw.genes)
self.data.position = copy.deepcopy(self.raw.position)
self.data.position_z = copy.deepcopy(self.raw.position_z)
from stereo.preprocess.qc import cal_qc
cal_qc(self.data)

def raw_checkpoint(self):
"""
Expand Down Expand Up @@ -335,7 +346,7 @@ def filter_coordinates(self,
filter_raw
whether to filter raw data meanwhile.
inplace
whether to inplace the previous data or return a new data.
whether to replace the previous data or return a new data.
Returns
--------------------
Expand Down Expand Up @@ -373,7 +384,7 @@ def filter_by_clusters(
filter_raw
whether to filter raw data meanwhile.
inplace
whether to inplace the previous data or return a new data.
whether to replace the previous data or return a new data.
Returns
--------------------
Expand Down Expand Up @@ -408,7 +419,7 @@ def log1p(self,
Parameters
-----------------
inplace
whether to inplcae previous data or get a new express matrix after normalization of log1p.
whether to replace previous data or get a new express matrix after normalization of log1p.
res_key
the key to get targeted result from `self.result`.
Expand Down Expand Up @@ -437,7 +448,7 @@ def normalize_total(self,
the number of total counts per cell after normalization, if `None`, each cell has a
total count equal to the median of total counts for all cells before normalization.
inplace
whether to inplcae previous data or get a new express matrix after normalize_total.
whether to replace previous data or get a new express matrix after normalize_total.
res_key
the key to get targeted result from `self.result`.
Expand Down Expand Up @@ -468,7 +479,7 @@ def scale(self,
max_value
truncate to this value after scaling, if `None`, do not truncate.
inplace
whether to inplace the previous data or get a new express matrix after scaling.
whether to replace the previous data or get a new express matrix after scaling.
res_key
the key to get targeted result from `self.result`.
Expand All @@ -489,7 +500,7 @@ def quantile(self, inplace=True, res_key='quantile'):
Normalize the columns of X to each have the same distribution. Given an expression matrix of M genes by N
samples, quantile normalization ensures all samples have the same spread of data (by construction).
:param inplace: whether inplace the original data or get a new express matrix after quantile.
:param inplace: whether replace the original data or get a new express matrix after quantile.
:param res_key: the key for getting the result from the self.result.
:return:
"""
Expand All @@ -507,7 +518,7 @@ def disksmooth_zscore(self, r=20, inplace=True, res_key='disksmooth_zscore'):
for each position, given a radius, calculate the z-score within this circle as final normalized value.
:param r: radius for normalization.
:param inplace: whether inplace the original data or get a new express matrix after disksmooth_zscore.
:param inplace: whether replace the original data or get a new express matrix after disksmooth_zscore.
:param res_key: the key for getting the result from the self.result.
:return:
"""
Expand Down Expand Up @@ -647,7 +658,7 @@ def subset_by_hvg(self, hvg_res_key, use_raw=False, inplace=True):
get the subset by the result of highly variable genes.
:param hvg_res_key: the key of highly variable genes used to get the result.
:param inplace: whether inplace the data or get a new data after highly variable genes, which only save the
:param inplace: whether replace the data or get a new data after highly variable genes, which only save the
data info of highly variable genes.
:return: a StereoExpData object.
"""
Expand Down Expand Up @@ -1212,6 +1223,7 @@ def spatial_hotspot(self,
hs = spatial_hotspot(data, model=model, n_neighbors=n_neighbors, n_jobs=n_jobs, fdr_threshold=fdr_threshold,
min_gene_threshold=min_gene_threshold, outdir=outdir)
self.result[res_key] = hs
self.reset_key_record('spatial_hotspot', res_key)

@logit
def gaussian_smooth(self,
Expand All @@ -1230,7 +1242,7 @@ def gaussian_smooth(self,
Note that too high a value may cause overfitting, while too low a value may lead to a poor smoothing effect.
:param pca_res_key: the key of PCA to get targeted result from `self.result`.
:param n_jobs: the number of parallel jobs to run, if `-1`, all CPUs will be used.
:param inplace: whether to inplace the previous express matrix or get a new StereoExpData object with the new express matrix. # noqa
:param inplace: whether to replace the previous express matrix or get a new StereoExpData object with the new express matrix. # noqa
:return: An object of StereoExpData with the expression matrix processed by Gaussian smoothing.
"""
Expand Down Expand Up @@ -1630,3 +1642,9 @@ def __init__(self, based_ann_data: AnnData, data: AnnBasedStereoExpData):
def raw_checkpoint(self):
    """Save the current state as the raw checkpoint, then mirror it into the
    backing AnnData's `.raw` so AnnData-based consumers see the same snapshot.
    """
    super().raw_checkpoint()
    self.data._ann_data.raw = self.data._ann_data

@property
def key_record(self):
    """Key record persisted in the backing AnnData's `uns`.

    Seeded from the in-memory `_key_record` on first access, so subsequent
    mutations are stored with (and serialized alongside) the AnnData object.

    NOTE(review): reads `self.data.adata` whereas `raw_checkpoint` uses
    `self.data._ann_data` — presumably the same object; confirm.
    """
    uns = self.data.adata.uns
    if 'key_record' not in uns:
        uns['key_record'] = self._key_record
    return uns['key_record']
4 changes: 3 additions & 1 deletion stereo/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
read_gef_info,
read_seurat_h5ad,
read_h5ad,
read_h5ms
read_h5ms,
mudata_to_msdata
)
from .writer import (
write,
write_h5ad,
write_h5ms,
write_h5mu,
write_mid_gef,
update_gef
)
Loading

0 comments on commit 0861152

Please sign in to comment.