diff --git a/dgea/dgea_scvi.py b/dgea/dgea_scvi.py index 594dfbb..e56449a 100644 --- a/dgea/dgea_scvi.py +++ b/dgea/dgea_scvi.py @@ -20,10 +20,13 @@ def dgea_ui(): def dgea_server(input, output, session, _adata: reactive.Value[ad.AnnData], _model: reactive.Value[str]): _counts = reactive.value(None) _uniques = reactive.value([]) + _sub_uniques = reactive.value([]) _contrast = reactive.value(None) _reference = reactive.value(None) _alternative = reactive.value(None) + _sub_category = reactive.value(None) + _chosen_values = reactive.value([]) _log10_p = reactive.value(0.05) _lfc = reactive.value(1) @@ -32,9 +35,9 @@ def dgea_server(input, output, session, _adata: reactive.Value[ad.AnnData], _mod _filtered_genes = reactive.value(None) _filtered_counts = reactive.value(None) - run_dgea_server("run_dgea", _adata, _result, _counts, _reference, _alternative, _model, _uniques, _contrast) + run_dgea_server("run_dgea", _adata, _result, _counts, _reference, _alternative, _model, _uniques, _contrast, _sub_category, _chosen_values, _sub_uniques) filter_dgea_server("filter_dgea", _adata, _counts, _uniques, _result, _filtered_result, _filtered_genes, - _filtered_counts, _reference, _alternative, _contrast, _model, _log10_p, _lfc) + _filtered_counts, _reference, _alternative, _contrast, _sub_category, _sub_uniques, _chosen_values, _model, _log10_p, _lfc) plot_dgea_server("plot_dgea", _filtered_counts, _contrast, _reference, _alternative, _result, _log10_p, _lfc) diff --git a/dgea/filter_dgea_scvi.py b/dgea/filter_dgea_scvi.py index bc74050..8436cb8 100644 --- a/dgea/filter_dgea_scvi.py +++ b/dgea/filter_dgea_scvi.py @@ -6,6 +6,7 @@ @module.ui def filter_dgea_ui(): return ui.div( + ui.output_ui("select_values"), ui.output_ui("select_reference"), ui.output_ui("select_alternative"), ui.input_slider("log10_pscore", "Ropability in Reference (significance threshold)", min=0, max=20, step=0.01, value=3), @@ -17,10 +18,17 @@ def filter_dgea_ui(): def filter_dgea_server(input, output, session, _adata, _counts, _uniques, _result, _filtered_result, _filtered_genes, _filtered_counts, - _reference, _alternative, _contrast, _model, + _reference, _alternative, _contrast, _sub_category, _sub_uniques, _chosen_values, _model, _log10_p, _lfc ): + @output + @render.ui + def select_values(): + sub_uniques = _sub_uniques.get() + + return ui.input_select("value", "Choose values:", choices=sub_uniques, selectize=True, multiple=True, selected=sub_uniques) + @output @render.ui def select_reference(): @@ -46,27 +54,12 @@ def select_alternative(): @reactive.effect def update_filters(): + _chosen_values.set(input["value"].get()) _reference.set(input["reference"].get()) _alternative.set(input["alternative"].get()) _log10_p.set(input["log10_pscore"].get()) _lfc.set(input["lfc"].get()) - @reactive.effect - def update_result(): - adata = _adata.get() - reference = _reference.get() - alternative = _alternative.get() - contrast = _contrast.get() - model = _model.get() - - if None in (reference, alternative, contrast): - return - - res_df = scanvi_dgea(adata, contrast, reference, alternative, model) - res_counts = get_normalized_counts(adata) - _result.set(res_df) - _counts.set(res_counts) - @reactive.effect def filter_result(): result = _result.get() diff --git a/dgea/plot_dgea_scvi.py b/dgea/plot_dgea_scvi.py index 2fc0b15..dfce88d 100644 --- a/dgea/plot_dgea_scvi.py +++ b/dgea/plot_dgea_scvi.py @@ -38,9 +38,6 @@ def plot_dgea_server(input, output, session, @render.plot def plot_heatmap(): counts_df = _filtered_counts.get() - contrast = _contrast.get() - reference = _reference.get() - alternative = _alternative.get() if counts_df is None: return None diff --git a/dgea/run_dgea_scvi.py b/dgea/run_dgea_scvi.py index 82e9afc..a590426 100644 --- a/dgea/run_dgea_scvi.py +++ b/dgea/run_dgea_scvi.py @@ -8,7 +8,8 @@ @module.ui def run_dgea_ui(): return ui.div(ui.output_ui("contrast_selector"), - ui.input_task_button("run", "Run analysis")) + ui.output_ui("subset_selector"), + ui.input_task_button("run", "Run analysis")) @module.server def run_dgea_server(input, output, session, @@ -19,7 +20,10 @@ def run_dgea_server(input, output, session, _alternative: reactive.Value[str], _model: reactive.Value[str], _uniques: reactive.Value[list], - _contrast: reactive.Value[str]): + _contrast: reactive.Value[str], + _sub_category: reactive.Value[str], + _chosen_values: reactive.Value[list], + _sub_uniques: reactive.Value[list]): _category_columns = reactive.value([]) _numeric_columns = reactive.value([]) @@ -38,10 +42,22 @@ def contrast_selector(): return ui.input_select("contrast", "Contrast", choices=columns, selected=columns[0]) + @output + @render.ui + def subset_selector(): + columns = _category_columns.get() + + return ui.input_select("sub_category", "Category to subset", choices=columns, selected=columns[0]) + @reactive.effect def update_contrast(): contrast = input["contrast"].get() _contrast.set(contrast) + + @reactive.effect + def update_subset(): + sub_category = input["sub_category"].get() + _sub_category.set(sub_category) @reactive.effect @reactive.event(input["run"]) @@ -51,7 +67,10 @@ def handle_run(): referece = _reference.get() alternative = _alternative.get() model = _model.get() - run_scanvi(adata, contrast, referece, alternative, model) + sub_category = _sub_category.get() + chosen_values = _chosen_values.get() + adata_subset = adata[adata.obs[sub_category].isin(chosen_values)].copy() + run_scanvi(adata_subset, contrast, referece, alternative, model) @reactive.effect def update_scanvi_result(): @@ -75,3 +94,12 @@ def update_uniques(): return uniques = adata.obs[contrast].unique().tolist() _uniques.set(uniques) + + @reactive.effect + def update_sub_uniques(): + adata = _adata.get() + sub_category = _sub_category.get() + if adata is None or sub_category is None: + return + sub_uniques = adata.obs[sub_category].unique().tolist() + _sub_uniques.set(sub_uniques)