diff --git a/app.py b/app.py index 687b429..0257765 100644 --- a/app.py +++ b/app.py @@ -5,7 +5,7 @@ from composition import composition_server, composition_ui from export import export_ui, export_server -from dgea.dgea_1 import dgea_server, dgea_ui +from dgea.dgea_scvi_define import dgea_server, dgea_ui from tree import tree_server, tree_ui with open("data/config.json") as f: @@ -13,6 +13,7 @@ adata = sc.read_h5ad("data/" + config["adata"]) tree = pickle.load(open("data/" + config["tree"], "rb")) if "tree" in config else None name = config["name"] + model_path = config["model_path"] categorical_columns = adata.obs.select_dtypes(include="category").columns.to_list() @@ -34,6 +35,7 @@ def server(input, output, session): _dataframe = reactive.value(adata.obs) _adata = reactive.value(adata) _tree = reactive.value(tree) + _model = reactive.value(model_path) composition_server("composition", _dataframe) export_server("export") dgea_server("dgea", _adata) diff --git a/data/config.json b/data/config.json index f5d0142..33b8237 100644 --- a/data/config.json +++ b/data/config.json @@ -1,5 +1,6 @@ { "adata": "merged.h5ad", "tree": "scarches.tree.pkl", - "name": "possible_atlas" + "name": "possible_atlas", + "model_path": "./data" } \ No newline at end of file diff --git a/dgea/dgea_1.py b/dgea/dgea_scanvi_define.py similarity index 100% rename from dgea/dgea_1.py rename to dgea/dgea_scanvi_define.py diff --git a/dgea/dgea_scvi.py b/dgea/dgea_scvi.py index db3020c..b40d73b 100644 --- a/dgea/dgea_scvi.py +++ b/dgea/dgea_scvi.py @@ -1,43 +1,30 @@ import numpy as np import pandas as pd import scanpy as sc -import random import anndata as ad -import os -import torch from scvi.model import SCANVI, SCVI -def scanvi_dgea(adata:ad.AnnData, groupby:str, reference:str, alternative:str): - - directory_model = "model_scvi_test/" - scanvi_subdirectory = os.path.join(directory_model, "scanvi_model") - - try: - loaded_model = SCVI.load(directory_model, adata=None) - except Exception as e: - loaded_model = None - - if isinstance(loaded_model, SCVI): - os.makedirs(scanvi_subdirectory, exist_ok=True) - scanvi_model = SCANVI.from_scvi_model(loaded_model, unlabeled_category = 'Unknown', labels_key="cell_type") - scanvi_model.save(scanvi_subdirectory, overwrite=True) - directory_model = scanvi_subdirectory - print('is scvi') +def scanvi_dgea(adata:ad.AnnData, groupby:str, reference:str, alternative:str, directory_model:str): + + if 'cell_type' in adata.obs.columns: + model_type = SCANVI + print('is scavi') else: + model_type = SCVI print('is scanvi') - SCANVI.prepare_query_anndata(adata = adata, reference_model=directory_model) + model_type.prepare_query_anndata(adata = adata, reference_model=directory_model) - scanvi_model = SCANVI.load_query_data(adata, directory_model) + model = model_type.load_query_data(adata, directory_model) groups = np.array(adata.obs[groupby].unique()) idx1 = adata.obs[groupby] == reference idx2 = adata.obs[groupby] == alternative - dge_change = scanvi_model.differential_expression(adata=adata, groupby=groupby, idx1=idx1, idx2=idx2, mode="change") + dge_change = model.differential_expression(adata=adata, groupby=groupby, idx1=idx1, idx2=idx2, mode="change") epsilon = 1e-10 dge_change['proba_not_de'] = np.maximum(dge_change["proba_not_de"], epsilon) @@ -56,6 +43,8 @@ def get_normalized_counts(adata): df_counts = pd.DataFrame(dense_matrix, index=adata.obs_names, columns=adata.var_names) return df_counts -#adata = sc.read_h5ad('model_scvi_test/adata.h5ad') -#dge_test = scanvi_dgea(adata, "cell_type", "Endothelial", "Epithelial") -#print(dge_test.head()) +if __name__ == '__main__': + print('Running DGEA test') + adata = sc.read_h5ad('/workspaces/SIMBA-Downstream_1/data/atlas.h5ad') + dge_test = scanvi_dgea(adata, "cell_type", "Endothelial", "Epithelial", './data') + print(dge_test.head()) diff --git a/dgea/run_dgea_scvi.py b/dgea/run_dgea_scvi.py index 6b0696e..14152d8 100644 --- a/dgea/run_dgea_scvi.py +++ b/dgea/run_dgea_scvi.py @@ -17,6 +17,7 @@ def run_dgea_server(input, output, session, _counts: reactive.Value[pd.DataFrame], _reference: reactive.Value[str], _alternative: reactive.Value[str], + _model: reactive.Value[str], _uniques: reactive.Value[list], _contrast: reactive.Value[str]): _category_columns = reactive.value([]) @@ -49,7 +50,8 @@ def handle_run(): contrast = _contrast.get() referece = _reference.get() alternative = _alternative.get() - run_scanvi(adata, contrast, referece, alternative) + model = _model.get() + run_scanvi(adata, contrast, referece, alternative, model) @reactive.effect def update_scanvi_result(): @@ -59,8 +61,8 @@ def update_scanvi_result(): @ui.bind_task_button(button_id="run") @reactive.extended_task - async def run_scanvi(adata, contrast, reference, alternative): - dge_change = scanvi_dgea(adata, contrast, reference, alternative) + async def run_scanvi(adata, contrast, reference, alternative, model): + dge_change = scanvi_dgea(adata, contrast, reference, alternative, model) counts = get_normalized_counts(adata) return dge_change, counts