diff --git a/anoexpress/anoexpress.py b/anoexpress/anoexpress.py index fec5632..0aacd69 100644 --- a/anoexpress/anoexpress.py +++ b/anoexpress/anoexpress.py @@ -110,7 +110,7 @@ def filter_low_counts(data_df, data_type, analysis, gene_id, count_threshold=5, return data_df.query("GeneID in @mask") -def data(data_type, analysis, microarray=False, gene_id=None, sort_by=None, annotations=False, pvalue_filter=None, low_count_filter=None, fraction_na_allowed=None): +def data(data_type, analysis, microarray=False, gene_id=None, sample_query=None, sort_by=None, annotations=False, pvalue_filter=None, low_count_filter=None, fraction_na_allowed=None): """ Load the combined data for a given analysis and sample query @@ -126,6 +126,9 @@ def data(data_type, analysis, microarray=False, gene_id=None, sort_by=None, anno gene_id: str or list, optional A string (AGAP/AFUN identifier or genomic span in the format 2L:500-10000), or list of strings, or path to a file containing a list of gene ids in the first column. Input file can be .tsv, .txt, or .csv, or .xlsx. + sample_query: str, optional + A string containing a pandas query to subset the samples of interest from the comparison metadata file. For example, to plot only the + samples from Burkina Faso, use "country == 'Burkina Faso'". Defaults to None. sort_by: {"median", "mean", "agap", "position", None}, optional sort by median/mean of fold changes (descending), or by AGAP, or by position in the genome, or dont sort input gene ids. annotations: bool, optional @@ -163,6 +166,11 @@ def data(data_type, analysis, microarray=False, gene_id=None, sort_by=None, anno # subset to the species comparisons of interest df = df.loc[:, metadata_ids] + if sample_query: + # subset to the sample ids of interest + mask = df_metadata.eval(sample_query) + df = df.loc[:, mask] + # subset to the gene ids of interest including reading file if gene_id is not None: gene_id = resolve_gene_id(gene_id=gene_id, analysis=analysis) @@ -219,7 +227,21 @@ def _sort_genes(df, analysis, sort_by=None): return df.iloc[sort_idxs, :].copy() -def plot_gene_expression(gene_id, analysis="gamb_colu_arab_fun", microarray=False, title=None, plot_type='strip', sort_by='agap', pvalue_filter=None, width=1600, height=None, save_html=None): + +def query_fc_count_data(fc_data, count_data, comparison_metadata, sample_metadata, query): + mask = comparison_metadata.eval(query).to_list() + comparison_metadata = comparison_metadata[mask] + fc_data = fc_data.loc[:, mask] + + resistant_strains = comparison_metadata['resistant'].to_list() + sample_mask = sample_metadata.eval("condition in @resistant_strains").to_list() + sample_metadata = sample_metadata[sample_mask] + count_data = count_data.loc[:, sample_mask] + + return fc_data, count_data, comparison_metadata, sample_metadata + + +def plot_gene_expression(gene_id, analysis="gamb_colu_arab_fun", microarray=False, sample_query=None, title=None, plot_type='strip', sort_by='agap', pvalue_filter=None, width=1600, height=None, save_html=None): """Plot fold changes of provided AGAP gene IDs from RNA-Seq meta-analysis dataset @@ -234,6 +256,9 @@ def plot_gene_expression(gene_id, analysis="gamb_colu_arab_fun", microarray=Fals present, due to the process of finding orthologs. microarray: bool, optional whether to include the IR-Tex microarray data in the plot + sample_query: str, optional + A string containing a pandas query to subset the samples of interest from the comparison metadata file. For example, to plot only the + samples from Burkina Faso, use "country == 'Burkina Faso'". Defaults to None. title : str Plot title plot_type : {"strip", "boxplot"}, optional @@ -258,10 +283,14 @@ def plot_gene_expression(gene_id, analysis="gamb_colu_arab_fun", microarray=Fals df_samples = sample_metadata(analysis=analysis) # load fold change data, make long format and merge with metadata for hovertext - fc_data = data(data_type="fcs", analysis=analysis, microarray=microarray, gene_id=gene_id, sort_by=sort_by, annotations=True, pvalue_filter=pvalue_filter).reset_index() + fc_data = data(data_type="fcs", analysis=analysis, microarray=microarray, sample_query=sample_query, gene_id=gene_id, sort_by=sort_by, annotations=True, pvalue_filter=pvalue_filter).reset_index() # load count data, make long format and merge with metadata for hovertext - count_data = data(data_type="log2counts", analysis=analysis, microarray=microarray, gene_id=gene_id, sort_by=None) + count_data = data(data_type="log2counts", analysis=analysis, microarray=microarray, gene_id=gene_id, sample_query=sample_query, sort_by=None) count_data = count_data.loc[fc_data['GeneID']].reset_index() + + if sample_query: + fc_data, count_data, df_metadata, df_samples = query_fc_count_data(fc_data=fc_data, count_data=count_data, df_metadata=df_metadata, sample_metadata=df_samples, query=sample_query) + count_data = count_data.melt(id_vars='GeneID', var_name='sampleID', value_name='log2_counts') count_data = count_data.merge(df_samples, how='left').assign(counts = lambda x: np.round(2**x.log2_counts, 0)) diff --git a/tests/test_anoexpress.py b/tests/test_anoexpress.py index 2a6174b..14ea1a4 100644 --- a/tests/test_anoexpress.py +++ b/tests/test_anoexpress.py @@ -124,6 +124,27 @@ def test_data_low_count_filter(low_count_filter=10): assert isinstance(data_low_df, pd.DataFrame) assert data_df.shape[0] > data_low_df.shape[0] +def test_data_sample_query(query="country == 'Burkina Faso'"): + + data_df = xpress.data( + data_type="fcs", + analysis="gamb_colu", + microarray=False, + sort_by=None, + ) + + data_low_df = xpress.data( + data_type="fcs", + analysis="gamb_colu", + microarray=False, + sample_query=query, + sort_by=None) + + assert data_low_df is not None + assert not data_low_df.empty + assert isinstance(data_low_df, pd.DataFrame) + assert data_df.shape[1] > data_low_df.shape[1] + @pytest.mark.parametrize( "gene_ids", [gene, gene_ids]