From 0861152c750a1a15580bed601e23c0ec7f809302 Mon Sep 17 00:00:00 2001 From: tanliwei Date: Fri, 26 Jul 2024 15:43:40 +0800 Subject: [PATCH] add function for converting MSData to MuData --- stereo/core/ms_data.py | 16 +++- stereo/core/ms_pipeline.py | 8 ++ stereo/core/st_pipeline.py | 40 ++++++--- stereo/io/__init__.py | 4 +- stereo/io/reader.py | 177 ++++++++++++++++++++++++++++--------- stereo/io/writer.py | 132 ++++++++++++++++++++++----- 6 files changed, 298 insertions(+), 79 deletions(-) diff --git a/stereo/core/ms_data.py b/stereo/core/ms_data.py index 59d36dad..1882a714 100644 --- a/stereo/core/ms_data.py +++ b/stereo/core/ms_data.py @@ -348,6 +348,10 @@ def relationship(self, value: str): @property def relationship_info(self): return self._relationship_info + + @relationship_info.setter + def relationship_info(self, value: dict): + self._relationship_info = value def reset_position(self, mode='integrate'): if mode == 'integrate' and self.merged_data: @@ -863,10 +867,12 @@ def to_integrate( if scope_key in self._scopes_data: self._scopes_data[scope_key].tl.reset_key_record('cluster', res_key) self._scopes_data[scope_key].tl.result.set_result_key_method(res_key) + self._scopes_data[scope_key].cells[res_key] = self._scopes_data[scope_key].cells[res_key].astype('category') if self._merged_data is not None: self._merged_data.tl.reset_key_record('cluster', res_key) self._merged_data.tl.result.set_result_key_method(res_key) + self._merged_data.cells[res_key] = self._merged_data.cells[res_key].astype('category') elif type == 'var': raise NotImplementedError else: @@ -1033,9 +1039,13 @@ def __str__(self): def __repr__(self): return self.__str__() - def write(self, filename): - from stereo.io.writer import write_h5ms - write_h5ms(self, filename) + def write(self, filename, to_mudata=False): + if not to_mudata: + from stereo.io.writer import write_h5ms + write_h5ms(self, filename) + else: + from stereo.io.writer import write_h5mu + return write_h5mu(self, filename) TL = type('TL', (MSDataPipeLine,), {'ATTR_NAME': 'tl', "BASE_CLASS": StPipeline}) diff --git a/stereo/core/ms_pipeline.py b/stereo/core/ms_pipeline.py index 508f743d..92524c49 100644 --- a/stereo/core/ms_pipeline.py +++ b/stereo/core/ms_pipeline.py @@ -106,6 +106,14 @@ def _use_integrate_method(self, item, *args, **kwargs): scope_key = self.ms_data.generate_scope_key(ms_data_view._names) self.ms_data.scopes_data[scope_key] = ms_data_view.merged_data + def set_result_key_method(key): + self.result_keys.setdefault(scope_key, []) + if key in self.result_keys[scope_key]: + self.result_keys[scope_key].remove(key) + self.result_keys[scope_key].append(key) + + ms_data_view.merged_data.tl.result.set_result_key_method = set_result_key_method + # def callback_func(key, value): # # key_name = "scope_[" + ",".join( # # [str(self.ms_data._names.index(name)) for name in ms_data_view._names]) + "]" diff --git a/stereo/core/st_pipeline.py b/stereo/core/st_pipeline.py index ec4d363c..a3c9894a 100644 --- a/stereo/core/st_pipeline.py +++ b/stereo/core/st_pipeline.py @@ -58,7 +58,7 @@ def __init__(self, data: Union[StereoExpData, AnnBasedStereoExpData]): self.data: Union[StereoExpData, AnnBasedStereoExpData] = data self.result = Result(data) self._raw: Union[StereoExpData, AnnBasedStereoExpData] = None - self.key_record = {'hvg': [], 'pca': [], 'neighbors': [], 'umap': [], 'cluster': [], 'marker_genes': []} + self._key_record = {'hvg': [], 'pca': [], 'neighbors': [], 'umap': [], 'cluster': [], 'marker_genes': []} # self.reset_key_record = self._reset_key_record def __getattr__(self, item): @@ -81,6 +81,10 @@ def __getattr__(self, item): f'{item} not existed, please check the function name you called!' ) + @property + def key_record(self): + return self._key_record + @property def raw(self) -> Union[StereoExpData, AnnBasedStereoExpData]: """ @@ -107,7 +111,14 @@ def reset_raw_data(self): :return: """ - self.data = self.raw + # self.data = self.raw + self.data.exp_matrix = copy.deepcopy(self.raw.exp_matrix) + self.data.cells = copy.deepcopy(self.raw.cells) + self.data.genes = copy.deepcopy(self.raw.genes) + self.data.position = copy.deepcopy(self.raw.position) + self.data.position_z = copy.deepcopy(self.raw.position_z) + from stereo.preprocess.qc import cal_qc + cal_qc(self.data) def raw_checkpoint(self): """ @@ -335,7 +346,7 @@ def filter_coordinates(self, filter_raw whether to filter raw data meanwhile. inplace - whether to inplace the previous data or return a new data. + whether to replace the previous data or return a new data. Returns -------------------- @@ -373,7 +384,7 @@ def filter_by_clusters( filter_raw whether to filter raw data meanwhile. inplace - whether to inplace the previous data or return a new data. + whether to replace the previous data or return a new data. Returns -------------------- @@ -408,7 +419,7 @@ def log1p(self, Parameters ----------------- inplace - whether to inplcae previous data or get a new express matrix after normalization of log1p. + whether to replace previous data or get a new express matrix after normalization of log1p. res_key the key to get targeted result from `self.result`. @@ -437,7 +448,7 @@ def normalize_total(self, the number of total counts per cell after normalization, if `None`, each cell has a total count equal to the median of total counts for all cells before normalization. inplace - whether to inplcae previous data or get a new express matrix after normalize_total. + whether to replace previous data or get a new express matrix after normalize_total. res_key the key to get targeted result from `self.result`. @@ -468,7 +479,7 @@ def scale(self, max_value truncate to this value after scaling, if `None`, do not truncate. inplace - whether to inplace the previous data or get a new express matrix after scaling. + whether to replace the previous data or get a new express matrix after scaling. res_key the key to get targeted result from `self.result`. @@ -489,7 +500,7 @@ def quantile(self, inplace=True, res_key='quantile'): Normalize the columns of X to each have the same distribution. Given an expression matrix of M genes by N samples, quantile normalization ensures all samples have the same spread of data (by construction). - :param inplace: whether inplace the original data or get a new express matrix after quantile. + :param inplace: whether replace the original data or get a new express matrix after quantile. :param res_key: the key for getting the result from the self.result. :return: """ @@ -507,7 +518,7 @@ def disksmooth_zscore(self, r=20, inplace=True, res_key='disksmooth_zscore'): for each position, given a radius, calculate the z-score within this circle as final normalized value. :param r: radius for normalization. - :param inplace: whether inplace the original data or get a new express matrix after disksmooth_zscore. + :param inplace: whether replace the original data or get a new express matrix after disksmooth_zscore. :param res_key: the key for getting the result from the self.result. :return: """ @@ -647,7 +658,7 @@ def subset_by_hvg(self, hvg_res_key, use_raw=False, inplace=True): get the subset by the result of highly variable genes. :param hvg_res_key: the key of highly varialbe genes to getting the result. - :param inplace: whether inplace the data or get a new data after highly variable genes, which only save the + :param inplace: whether replace the data or get a new data after highly variable genes, which only save the data info of highly variable genes. :return: a StereoExpData object. """ @@ -1212,6 +1223,7 @@ def spatial_hotspot(self, hs = spatial_hotspot(data, model=model, n_neighbors=n_neighbors, n_jobs=n_jobs, fdr_threshold=fdr_threshold, min_gene_threshold=min_gene_threshold, outdir=outdir) self.result[res_key] = hs + self.reset_key_record('spatial_hotspot', res_key) @logit def gaussian_smooth(self, @@ -1230,7 +1242,7 @@ def gaussian_smooth(self, Also too high value may cause overfitting, and low value may cause poor smoothing effect. :param pca_res_key: the key of PCA to get targeted result from `self.result`. :param n_jobs: the number of parallel jobs to run, if `-1`, all CPUs will be used. - :param inplace: whether to inplace the previous express matrix or get a new StereoExpData object with the new express matrix. # noqa + :param inplace: whether to replace the previous express matrix or get a new StereoExpData object with the new express matrix. # noqa :return: An object of StereoExpData with the express matrix processed by Gaussian smooting. """ @@ -1630,3 +1642,9 @@ def __init__(self, based_ann_data: AnnData, data: AnnBasedStereoExpData): def raw_checkpoint(self): super().raw_checkpoint() self.data._ann_data.raw = self.data._ann_data + + @property + def key_record(self): + if 'key_record' not in self.data.adata.uns: + self.data.adata.uns['key_record'] = self._key_record + return self.data.adata.uns['key_record'] diff --git a/stereo/io/__init__.py b/stereo/io/__init__.py index 4c165416..62073af6 100644 --- a/stereo/io/__init__.py +++ b/stereo/io/__init__.py @@ -17,12 +17,14 @@ read_gef_info, read_seurat_h5ad, read_h5ad, - read_h5ms + read_h5ms, + mudata_to_msdata ) from .writer import ( write, write_h5ad, write_h5ms, + write_h5mu, write_mid_gef, update_gef ) diff --git a/stereo/io/reader.py b/stereo/io/reader.py index 3a2ed8d7..cec38f88 100644 --- a/stereo/io/reader.py +++ b/stereo/io/reader.py @@ -13,9 +13,8 @@ 2022/02/09 read raw data and result """ from copy import deepcopy -from typing import Optional -from typing import Union -from natsort import natsorted +from typing import Optional, Union, List +import re import h5py import numpy as np @@ -34,7 +33,6 @@ from stereo.core.result import _BaseResult from stereo.io import h5ad from stereo.io.utils import( - remove_genes_number, integrate_matrix_by_genes, transform_marker_genes_to_anndata, get_gem_comments @@ -436,14 +434,20 @@ def _read_stereo_h5_result(key_record: dict, data: StereoExpData, f: Union[h5py. data.tl.result[res_key][data_key] = h5ad.read_group(f[full_key]) def _read_anndata_from_group(f: h5py.Group) -> AnnBasedStereoExpData: - from anndata._io.specs.registry import read_elem + from distutils.version import StrictVersion + from anndata import __version__ as anndata_version + + if StrictVersion(anndata_version) < StrictVersion('0.8.0'): + from anndata._io.utils import read_attribute as read_elem + else: + from anndata._io.specs.registry import read_elem adata = AnnData( **{k: read_elem(f[k]) for k in f.keys()} ) data = AnnBasedStereoExpData(based_ann_data=adata) - if 'key_record' in adata.uns: - data.tl.key_record = {k: list(v) for k, v in adata.uns['key_record'].items()} - del adata.uns['key_record'] + # if 'key_record' in adata.uns: + # data.tl.key_record = {k: list(v) for k, v in adata.uns['key_record'].items()} + # del adata.uns['key_record'] data.merged = f.attrs.get('merged', False) data.spatial_key = f.attrs.get('spatial_key', 'spatial') return data @@ -475,21 +479,21 @@ def read_h5ms(file_path, use_raw=True, use_result=True): slice_keys = list(f[k].keys()) slice_keys.sort(key=lambda k: int(k.split('_')[1])) for one_slice_key in slice_keys: - data = _read_stereo_h5ad_from_group(f[k][one_slice_key], StereoExpData(), use_raw, use_result) - # encoding_type = f[k][one_slice_key].attrs.get('encoding-type', 'stereo_exp_data') - # if encoding_type == 'stereo_exp_data': - # data = _read_stereo_h5ad_from_group(f[k][one_slice_key], StereoExpData(), use_raw, use_result) - # else: - # data = _read_anndata_from_group(f[k][one_slice_key]) + # data = _read_stereo_h5ad_from_group(f[k][one_slice_key], StereoExpData(), use_raw, use_result) + encoding_type = f[k][one_slice_key].attrs.get('encoding-type', 'stereo_exp_data') + if encoding_type == 'anndata': + data = _read_anndata_from_group(f[k][one_slice_key]) + else: + data = _read_stereo_h5ad_from_group(f[k][one_slice_key], StereoExpData(), use_raw, use_result) data_list.append(data) elif k == 'sample_merged': for mk in f[k].keys(): - scope_data = _read_stereo_h5ad_from_group(f[k][mk], StereoExpData(), use_raw, use_result) - # encoding_type = f[k][mk].attrs.get('encoding-type', 'stereo_exp_data') - # if encoding_type == 'stereo_exp_data': - # scope_data = _read_stereo_h5ad_from_group(f[k][mk], StereoExpData(), use_raw, use_result) - # else: - # scope_data = _read_anndata_from_group(f[k][mk]) + # scope_data = _read_stereo_h5ad_from_group(f[k][mk], StereoExpData(), use_raw, use_result) + encoding_type = f[k][mk].attrs.get('encoding-type', 'stereo_exp_data') + if encoding_type == 'anndata': + scope_data = _read_anndata_from_group(f[k][mk]) + else: + scope_data = _read_stereo_h5ad_from_group(f[k][mk], StereoExpData(), use_raw, use_result) scopes_data[mk] = scope_data if f[k][mk].attrs is not None: merged_from_all = f[k][mk].attrs.get('merged_from_all', False) @@ -917,6 +921,8 @@ def stereo_to_anndata( adata.uns['resolution'] = data.attr['resolution'] if data.bin_type == 'cell_bins' and data.cells.cell_border is not None: adata.obsm['cell_border'] = data.cells.cell_border + if 'key_record' not in adata.uns: + adata.uns['key_record'] = deepcopy(data.tl.key_record) if data.sn is not None: if isinstance(data.sn, str): @@ -1003,6 +1009,12 @@ def stereo_to_anndata( for res_key in data.tl.key_record[key]: uns_key = _BaseResult.RENAME_DICT.get(res_key, res_key) adata.uns[uns_key] = transform_marker_genes_to_anndata(data.tl.result[res_key]) + elif key == 'spatial_hotspot': + for res_key in data.tl.key_record[key]: + if res_key in adata.uns: + del adata.uns[res_key] + if 'key_record' in adata.uns: + adata.uns['key_record']['spatial_hotspot'] = [] else: continue @@ -1192,15 +1204,13 @@ def read_gef( if len(gene_id[0]) == 0: gene_name_index = True - # gene_names = remove_genes_number(gene_names) if gene_name_index: if len(gene_id[0]) > 0: exp_matrix, gene_names = integrate_matrix_by_genes(gene_names, cell_num, - exp_matrix.data, exp_matrix.indices, exp_matrix.indptr) + exp_matrix.data, exp_matrix.indices, exp_matrix.indptr) data.genes = Gene(gene_name=gene_names) else: data.genes = Gene(gene_name=gene_id) - # data.genes['gene_name_underline'] = gene_names data.genes['real_gene_name'] = gene_names data.exp_matrix = exp_matrix if is_sparse else exp_matrix.toarray() @@ -1234,14 +1244,12 @@ def read_gef( data.position[:, 1] = cells['y'] if len(gene_id[0]) == 0: gene_name_index = True - # gene_names = remove_genes_number(gene_names) if gene_name_index: if len(gene_id[0]) > 0: exp_matrix, gene_names = integrate_matrix_by_genes(gene_names, cell_num, count, indices, indptr) data.genes = Gene(gene_name=gene_names) else: data.genes = Gene(gene_name=gene_id) - # data.genes['gene_name_underline'] = gene_names data.genes['real_gene_name'] = gene_names data.exp_matrix = exp_matrix if is_sparse else exp_matrix.toarray() @@ -1291,15 +1299,13 @@ def read_gef( exp_matrix = csr_matrix((count, (cell_ind, gene_ind)), shape=(cell_num, gene_num), dtype=np.uint32) if len(gene_id[0]) == 0: gene_name_index = True - # gene_names = remove_genes_number(gene_names) if gene_name_index: if len(gene_id[0]) > 0: exp_matrix, gene_names = integrate_matrix_by_genes(gene_names, cell_num, - exp_matrix.data, exp_matrix.indices, exp_matrix.indptr) + exp_matrix.data, exp_matrix.indices, exp_matrix.indptr) data.genes = Gene(gene_name=gene_names) else: data.genes = Gene(gene_name=gene_id) - # data.genes['gene_name_underline'] = gene_names data.genes['real_gene_name'] = gene_names data.exp_matrix = exp_matrix if is_sparse else exp_matrix.toarray() @@ -1327,15 +1333,13 @@ def read_gef( cell_ind, gene_ind, count = gef.get_sparse_matrix_indices2() exp_matrix = csr_matrix((count, (cell_ind, gene_ind)), shape=(cell_num, gene_num), dtype=np.uint32) - # gene_names = remove_genes_number(gene_names) if gene_name_index: if len(gene_id[0]) > 0: exp_matrix, gene_names = integrate_matrix_by_genes(gene_names, cell_num, - exp_matrix.data, exp_matrix.indices, exp_matrix.indptr) + exp_matrix.data, exp_matrix.indices, exp_matrix.indptr) data.genes = Gene(gene_name=gene_names) else: data.genes = Gene(gene_name=gene_id) - # data.genes['gene_name_underline'] = gene_names data.genes['real_gene_name'] = gene_names data.exp_matrix = exp_matrix if is_sparse else exp_matrix.toarray() @@ -1440,16 +1444,101 @@ def read_gef_info(file_path: str): return info_dict -# @ReadWriteUtils.check_file_exists -# def read_h5ad(file_path: str, flavor: str = 'scanpy'): -# ''' -# :param file_path: h5ad file path. -# :return: `StereoExpData`-like `AnnBasedStereoExpData` obj -# ''' -# if flavor == 'scanpy': -# from stereo.core.stereo_exp_data import AnnBasedStereoExpData -# return AnnBasedStereoExpData(file_path) -# elif flavor == 'seurat': -# raise NotImplementedError -# else: -# raise Exception +@ReadWriteUtils.check_file_exists +def mudata_to_msdata( + file_path: str = None, + sample_names: Optional[Union[np.ndarray, List[str], None]] = None, + scope_names: Optional[Union[np.ndarray, List[str], None]] = None, + entire_merged_data_name: Optional[str] = None +): + """ + Read a h5mu file and convert it to a MSData object. + + :param file_path: The path of the MuData file, defaults to None + :param sample_names: The names of single samples that are saved in the MuData object, defaults to None, + if None, the names starting with 'sample_' will be used. + :param scope_names: The names of merged samples that are saved in the MuData object, defaults to None, + if None, the names like 'scope_[0,1,2]' will be used. + :param entire_merged_data_name: The name of the merged sample which is merged from all samples, default to None, + if None, use the one like 'scope_[0,1,2]' whose square brackets contain index sequence of all samples. + + :return: The MSData object. + """ + try: + from mudata import read_h5mu + except ImportError: + raise ImportError("Please install mudata first: `pip install mudata`.") + from stereo.core.ms_data import MSData + + mudata = read_h5mu(file_path) + + mod_keys = list(mudata.mod.keys()) + if sample_names is None: + sample_names = [] + left_mod_keys = [] + for k in mod_keys: + match = re.match(r'^sample_\d+$', k) + if match: + sample_names.append(k) + else: + left_mod_keys.append(k) + sample_names.sort(key=lambda x: int(x.split('_')[1])) + mod_keys = left_mod_keys + + data_list = [AnnBasedStereoExpData(based_ann_data=mudata[n]) for n in sample_names if n in mudata.mod] + if len(data_list) == 0: + raise ValueError("No sample data found in the MuData object.") + if 'names' in mudata.uns: + names = list(mudata.uns['names']) + else: + names = sample_names + + var_type = mudata.uns.get('var_type', 'intersect') + relationship = mudata.uns.get('relationship', 'other') + relationship_info = mudata.uns.get('relationship_info', {}) + + ms_data = MSData( + _data_list=data_list, + _names=names, + _var_type=var_type, + _relationship=relationship, + _relationship_info=relationship_info + ) + + if entire_merged_data_name is None: + entire_merged_data_name = ms_data.generate_scope_key(ms_data.names) + entire_merged_data = None + + if scope_names is None: + scope_names = [] + left_mod_keys = [] + for k in mod_keys: + match = re.match(r'^scope_\[\d+(,\d+)*\]$', k) + if match: + scope_names.append(k) + else: + left_mod_keys.append(k) + mod_keys = left_mod_keys + + scopes_data = { + n: AnnBasedStereoExpData(based_ann_data=mudata[n]) for n in scope_names if n in mudata.mod + } + for k in scopes_data.keys(): + if k == entire_merged_data_name: + entire_merged_data = scopes_data[k] + if not re.match(r'^scope_\[\d+(,\d+)*\]$', k): + del scopes_data[k] + entire_merged_data_name = ms_data.generate_scope_key(ms_data.names) + scopes_data[entire_merged_data_name] = entire_merged_data + break + if len(scopes_data) > 0: + ms_data.scopes_data = scopes_data + ms_data.merged_data = entire_merged_data + + if 'result_keys' in mudata.uns: + for n, k in mudata.uns['result_keys'].items(): + if n not in ms_data.scopes_data: + continue + ms_data.tl.result_keys[n] = list(k) + + return ms_data diff --git a/stereo/io/writer.py b/stereo/io/writer.py index c2ca6ee6..a11361f0 100644 --- a/stereo/io/writer.py +++ b/stereo/io/writer.py @@ -14,6 +14,8 @@ import pickle from copy import deepcopy from os import environ +from typing import Optional, Literal +from tqdm import tqdm import h5py import numpy as np @@ -24,6 +26,7 @@ ) from stereo.core.stereo_exp_data import StereoExpData, AnnBasedStereoExpData +from stereo.core.ms_data import MSData from stereo.io import h5ad, stereo_to_anndata from stereo.log_manager import logger, LogManager @@ -258,7 +261,12 @@ def _write_one_h5ad_result(data, f, key_record): h5ad.write(item, f, f'{res_key}@{key}@co_occurrence', save_as_matrix=True) def _write_one_anndata(f: h5py.Group, data: AnnBasedStereoExpData): - from anndata._io.specs.registry import write_elem + from distutils.version import StrictVersion + from anndata import __version__ as anndata_version + if StrictVersion(anndata_version) < StrictVersion("0.8.0"): + from anndata._io.utils import write_attribute as write_elem + else: + from anndata._io.specs.registry import write_elem try: LogManager.stop_logging() adata = stereo_to_anndata(data, flavor='scanpy', split_batches=False) @@ -270,7 +278,6 @@ def _write_one_anndata(f: h5py.Group, data: AnnBasedStereoExpData): adata.strings_to_categoricals() if adata.raw is not None: adata.strings_to_categoricals(adata.raw.var) - adata.uns['key_record'] = data.tl.key_record f.attrs.setdefault("encoding-type", "anndata") f.attrs.setdefault("encoding-version", "0.1.0") @@ -290,7 +297,7 @@ def _write_one_anndata(f: h5py.Group, data: AnnBasedStereoExpData): write_elem(f, "layers", dict(adata.layers), dataset_kwargs=dataset_kwargs) write_elem(f, "uns", dict(adata.uns), dataset_kwargs=dataset_kwargs) -def write_h5ms(ms_data, output: str): +def write_h5ms(ms_data, output: str, anndata_as_anndata: bool = True): """ Save an object of MSData into a h5 file whose suffix is 'h5ms'. @@ -301,11 +308,11 @@ def write_h5ms(ms_data, output: str): f.create_group('sample') for idx, data in enumerate(ms_data._data_list): f['sample'].create_group(f'sample_{idx}') - _write_one_h5ad(f['sample'][f'sample_{idx}'], data, use_raw=True, use_result=True) - # if isinstance(data, AnnBasedStereoExpData): - # _write_one_anndata(f['sample'][f'sample_{idx}'], data) - # else: - # _write_one_h5ad(f['sample'][f'sample_{idx}'], data) + # _write_one_h5ad(f['sample'][f'sample_{idx}'], data, use_raw=True, use_result=True) + if anndata_as_anndata and isinstance(data, AnnBasedStereoExpData): + _write_one_anndata(f['sample'][f'sample_{idx}'], data) + else: + _write_one_h5ad(f['sample'][f'sample_{idx}'], data, use_raw=True, use_result=True) # if ms_data._merged_data: # f.create_group('sample_merged') # _write_one_h5ad(f['sample_merged'], ms_data._merged_data) @@ -313,13 +320,14 @@ def write_h5ms(ms_data, output: str): f.create_group('sample_merged') for scope_key, merged_data in ms_data.scopes_data.items(): g = f['sample_merged'].create_group(scope_key) - if ms_data.merged_data and id(ms_data.merged_data) == id(merged_data): + # if ms_data.merged_data and id(ms_data.merged_data) == id(merged_data): + if merged_data is ms_data.merged_data: g.attrs['merged_from_all'] = True - _write_one_h5ad(g, merged_data, use_raw=True, use_result=True) - # if isinstance(merged_data, AnnBasedStereoExpData): - # _write_one_anndata(g, merged_data) - # else: - # _write_one_h5ad(g, merged_data) + # _write_one_h5ad(g, merged_data, use_raw=True, use_result=True) + if anndata_as_anndata and isinstance(merged_data, AnnBasedStereoExpData): + _write_one_anndata(g, merged_data) + else: + _write_one_h5ad(g, merged_data, use_raw=True, use_result=True) h5ad.write_list(f, 'names', ms_data.names) h5ad.write_dataframe(f, 'obs', ms_data.obs) h5ad.write_dataframe(f, 'var', ms_data.var) @@ -344,6 +352,8 @@ def write_mid_gef(data: StereoExpData, output: str): """ Write the StereoExpData object into a GEF (.h5) file. + The raw.exp_matrix will be used if it is not None, otherwise the data.exp_matrix will be used. + Parameters --------------------- data @@ -358,16 +368,28 @@ def write_mid_gef(data: StereoExpData, output: str): logger.info("The output standard gef file only contains one expression matrix with mid count." "Please make sure the expression matrix of StereoExpData object is mid count without normaliztion.") import numpy.lib.recfunctions as rfn - final_exp = [] # [(x_1,y_1,umi_1),(x_2,y_2,umi_2)] + final_exp_list = [] # [(x_1,y_1,umi_1),(x_2,y_2,umi_2)] final_gene = [] # [(A,offset,count)] - exp_np = data.exp_matrix.toarray() + # exp_np = data.exp_matrix.toarray() + + if data.raw is not None: + exp_np = data.raw.exp_matrix + if data.raw.shape != data.shape: + cells_isin = data.raw.cell_names.isin(data.cell_names) + genes_isin = data.raw.gene_names.isin(data.gene_names) + exp_np = exp_np[cells_isin, :][:, genes_isin] + else: + exp_np = data.exp_matrix - for i in range(exp_np.shape[1]): + for i in tqdm(range(exp_np.shape[1]), total=exp_np.shape[1]): gene_exp = exp_np[:, i] + if issparse(gene_exp): + gene_exp = gene_exp.toarray().flatten() c_idx = np.nonzero(gene_exp)[0] # idx for all cells - zipped = np.concatenate((data.position[c_idx], gene_exp[c_idx].reshape(c_idx.shape[0], 1)), axis=1) - for k in zipped: - final_exp.append(k) + final_exp_list.append(np.concatenate((data.position[c_idx], gene_exp[c_idx].reshape(c_idx.shape[0], 1)), axis=1)) + # zipped = np.concatenate((data.position[c_idx], gene_exp[c_idx].reshape(c_idx.shape[0], 1)), axis=1) + # for k in zipped: + # final_exp.append(k) # count g_len = len(final_gene) @@ -377,6 +399,7 @@ def write_mid_gef(data: StereoExpData, output: str): offset = last_offset + last_count count = c_idx.shape[0] final_gene.append((g_name, offset, count)) + final_exp = np.concatenate(final_exp_list, axis=0) final_exp_np = rfn.unstructured_to_structured( np.array(final_exp, dtype=int), np.dtype([('x', np.uint32), ('y', np.uint32), ('count', np.uint16)])) genetyp = np.dtype({'names': ['gene', 'offset', 'count'], 'formats': ['S32', np.uint32, np.uint32]}) @@ -477,3 +500,72 @@ def update_gef(data: StereoExpData, gef_file: str, cluster_res_key: str): h5f['cellBin']['cell']['cellTypeID'] = celltid del h5f['cellBin']['cellTypeList'] h5f['cellBin']['cellTypeList'] = groups_code + + +def write_h5mu(ms_data: MSData, output: str = None, compression: Optional[Literal["gzip", "lzf"]] = 'gzip'): + """ + Convert the MSData to a MuData and save it as a h5mu file. + + The single samples saved in MSData.data_list are named as 'sample_{i}'. + The scope data merged from some samples are named starting with 'scope_[{i0,i1,i2...}]'. + + :param ms_data: The object of MSData to be converted and saved. + :param output: The path of file into which MSData is saved, + if None, Only convert the MSData to a MuData object. + :param compression: The compression method used to save the h5mu file. + + :return: The MuData object. + """ + + try: + from mudata import MuData + except ImportError: + raise ImportError("Please install the mudata: pip install mudata.") + + adata_list = [] + adata_keys = [] + for i, data in enumerate(ms_data.data_list): + adata = stereo_to_anndata(data, flavor='scanpy', split_batches=False) + saved_name = f"sample_{i}" + # adata_dict[saved_name] = adata + adata_list.append(adata) + adata_keys.append(saved_name) + + merged_adata_list = [] + merged_adata_keys = [] + merged_adata_all = None + for scope_name, merged_data in ms_data.scopes_data.items(): + adata = stereo_to_anndata(merged_data, flavor='scanpy', split_batches=False) + # saved_name = f"merged_{scope_name}" + # adata_dict[scope_name] = adata + merged_adata_list.append(adata) + merged_adata_keys.append(scope_name) + if merged_data is ms_data.merged_data: + merged_adata_all = adata + + new_ms_data = MSData( + _data_list=[AnnBasedStereoExpData(based_ann_data=adata) for adata in adata_list], + _names=deepcopy(ms_data.names), + _merged_data=AnnBasedStereoExpData(based_ann_data=merged_adata_all), + _scopes_data={key: AnnBasedStereoExpData(based_ann_data=adata) for key, adata in zip(merged_adata_keys, merged_adata_list)}, + _var_type=ms_data.var_type, + _relationship=ms_data.relationship, + _relationship_info=deepcopy(ms_data.relationship_info) + ) + new_ms_data.tl.result_keys = deepcopy(ms_data.tl.result_keys) + # new_ms_data.tl._reset_result_keys() + + adata_dict = {key: adata for key, adata in zip(adata_keys, adata_list)} + adata_dict.update({key: adata for key, adata in zip(merged_adata_keys, merged_adata_list)}) + mudata = MuData(adata_dict) + + mudata.uns['names'] = new_ms_data.names + mudata.uns['var_type'] = new_ms_data.var_type + mudata.uns['relationship'] = new_ms_data.relationship + mudata.uns['relationship_info'] = new_ms_data.relationship_info + mudata.uns['result_keys'] = new_ms_data.tl.result_keys + + if output is not None: + mudata.write_h5mu(output, compression=compression) + + return mudata