Skip to content

Commit

Permalink
fixed some naming errors and both model options
Browse files Browse the repository at this point in the history
  • Loading branch information
serareif committed Jun 14, 2024
1 parent ea5028d commit f37c73b
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 30 deletions.
4 changes: 3 additions & 1 deletion app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@

from composition import composition_server, composition_ui
from export import export_ui, export_server
from dgea.dgea_1 import dgea_server, dgea_ui
from dgea.dgea_scvi_define import dgea_server, dgea_ui
from tree import tree_server, tree_ui

with open("data/config.json") as f:
config = json.load(f)
adata = sc.read_h5ad("data/" + config["adata"])
tree = pickle.load(open("data/" + config["tree"], "rb")) if "tree" in config else None
name = config["name"]
model_path = config["model_path"]

categorical_columns = adata.obs.select_dtypes(include="category").columns.to_list()

Expand All @@ -34,6 +35,7 @@ def server(input, output, session):
_dataframe = reactive.value(adata.obs)
_adata = reactive.value(adata)
_tree = reactive.value(tree)
_model = reactive.value(model_path)
composition_server("composition", _dataframe)
export_server("export")
dgea_server("dgea", _adata)
Expand Down
3 changes: 2 additions & 1 deletion data/config.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{
"adata": "merged.h5ad",
"tree": "scarches.tree.pkl",
"name": "possible_atlas"
"name": "possible_atlas",
"model_path": "./data"
}
File renamed without changes.
39 changes: 14 additions & 25 deletions dgea/dgea_scvi.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,30 @@
import numpy as np
import pandas as pd
import scanpy as sc
import random
import anndata as ad
import os

import torch
from scvi.model import SCANVI, SCVI

def scanvi_dgea(adata:ad.AnnData, groupby:str, reference:str, alternative:str):

directory_model = "model_scvi_test/"
scanvi_subdirectory = os.path.join(directory_model, "scanvi_model")

try:
loaded_model = SCVI.load(directory_model, adata=None)
except Exception as e:
loaded_model = None

if isinstance(loaded_model, SCVI):
os.makedirs(scanvi_subdirectory, exist_ok=True)
scanvi_model = SCANVI.from_scvi_model(loaded_model, unlabeled_category = 'Unknown', labels_key="cell_type")
scanvi_model.save(scanvi_subdirectory, overwrite=True)
directory_model = scanvi_subdirectory
print('is scvi')
def scanvi_dgea(adata:ad.AnnData, groupby:str, reference:str, alternative:str, directory_model:str):

if 'cell_type' in adata.obs.columns:
model_type = SCANVI
print('is scavi')

else:
model_type = SCVI
print('is scanvi')

SCANVI.prepare_query_anndata(adata = adata, reference_model=directory_model)
model_type.prepare_query_anndata(adata = adata, reference_model=directory_model)

scanvi_model = SCANVI.load_query_data(adata, directory_model)
model = model_type.load_query_data(adata, directory_model)

groups = np.array(adata.obs[groupby].unique())

idx1 = adata.obs[groupby] == reference
idx2 = adata.obs[groupby] == alternative

dge_change = scanvi_model.differential_expression(adata=adata, groupby=groupby, idx1=idx1, idx2=idx2, mode="change")
dge_change = model.differential_expression(adata=adata, groupby=groupby, idx1=idx1, idx2=idx2, mode="change")

epsilon = 1e-10
dge_change['proba_not_de'] = np.maximum(dge_change["proba_not_de"], epsilon)
Expand All @@ -56,6 +43,8 @@ def get_normalized_counts(adata):
df_counts = pd.DataFrame(dense_matrix, index=adata.obs_names, columns=adata.var_names)
return df_counts

#adata = sc.read_h5ad('model_scvi_test/adata.h5ad')
#dge_test = scanvi_dgea(adata, "cell_type", "Endothelial", "Epithelial")
#print(dge_test.head())
if __name__ == '__main__':
print('Running DGEA test')
adata = sc.read_h5ad('/workspaces/SIMBA-Downstream_1/data/atlas.h5ad')
dge_test = scanvi_dgea(adata, "cell_type", "Endothelial", "Epithelial", './data')
print(dge_test.head())
8 changes: 5 additions & 3 deletions dgea/run_dgea_scvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def run_dgea_server(input, output, session,
_counts: reactive.Value[pd.DataFrame],
_reference: reactive.Value[str],
_alternative: reactive.Value[str],
_model: reactive.Value[str],
_uniques: reactive.Value[list],
_contrast: reactive.Value[str]):
_category_columns = reactive.value([])
Expand Down Expand Up @@ -49,7 +50,8 @@ def handle_run():
contrast = _contrast.get()
referece = _reference.get()
alternative = _alternative.get()
run_scanvi(adata, contrast, referece, alternative)
model = _model.get()
run_scanvi(adata, contrast, referece, alternative, model)

@reactive.effect
def update_scanvi_result():
Expand All @@ -59,8 +61,8 @@ def update_scanvi_result():

@ui.bind_task_button(button_id="run")
@reactive.extended_task
async def run_scanvi(adata, contrast, reference, alternative):
dge_change = scanvi_dgea(adata, contrast, reference, alternative)
async def run_scanvi(adata, contrast, reference, alternative, model):
dge_change = scanvi_dgea(adata, contrast, reference, alternative, model)

counts = get_normalized_counts(adata)
return dge_change, counts
Expand Down

0 comments on commit f37c73b

Please sign in to comment.