diff --git a/app.py b/app.py index 29c3b5e..10068df 100644 --- a/app.py +++ b/app.py @@ -38,7 +38,7 @@ def server(input, output, session): _model = reactive.value(model_path) composition_server("composition", _dataframe) export_server("export") - dgea_server("dgea", _adata) + dgea_server("dgea", _adata, _model) tree_server("tree", _tree) app = App(app_ui, server) \ No newline at end of file diff --git a/dgea/dgea_scvi.py b/dgea/dgea_scvi.py index 4e4073b..594dfbb 100644 --- a/dgea/dgea_scvi.py +++ b/dgea/dgea_scvi.py @@ -17,7 +17,7 @@ def dgea_ui(): ) @module.server -def dgea_server(input, output, session, _adata: reactive.Value[ad.AnnData]): +def dgea_server(input, output, session, _adata: reactive.Value[ad.AnnData], _model: reactive.Value[str]): _counts = reactive.value(None) _uniques = reactive.value([]) @@ -32,9 +32,9 @@ def dgea_server(input, output, session, _adata: reactive.Value[ad.AnnData]): _filtered_genes = reactive.value(None) _filtered_counts = reactive.value(None) - run_dgea_server("run_dgea", _adata, _result, _counts, _reference, _alternative, _uniques, _contrast) + run_dgea_server("run_dgea", _adata, _result, _counts, _reference, _alternative, _model, _uniques, _contrast) filter_dgea_server("filter_dgea", _adata, _counts, _uniques, _result, _filtered_result, _filtered_genes, - _filtered_counts, _reference, _alternative, _contrast, _log10_p, _lfc) + _filtered_counts, _reference, _alternative, _contrast, _model, _log10_p, _lfc) plot_dgea_server("plot_dgea", _filtered_counts, _contrast, _reference, _alternative, _result, _log10_p, _lfc) diff --git a/dgea/filter_dgea_scvi.py b/dgea/filter_dgea_scvi.py index fec3dcd..bc74050 100644 --- a/dgea/filter_dgea_scvi.py +++ b/dgea/filter_dgea_scvi.py @@ -17,7 +17,7 @@ def filter_dgea_ui(): def filter_dgea_server(input, output, session, _adata, _counts, _uniques, _result, _filtered_result, _filtered_genes, _filtered_counts, - _reference, _alternative, _contrast, + _reference, _alternative, _contrast, _model, _log10_p, _lfc ): @@ -57,11 +57,12 @@ def update_result(): reference = _reference.get() alternative = _alternative.get() contrast = _contrast.get() + model = _model.get() if None in (reference, alternative, contrast): return - res_df = scanvi_dgea(adata, contrast, reference, alternative) + res_df = scanvi_dgea(adata, contrast, reference, alternative, model) res_counts = get_normalized_counts(adata) _result.set(res_df) _counts.set(res_counts)