Skip to content

Commit

Permalink
Merge pull request #57 from sanjaynagi/implement-minimum-counts-filter
Browse files Browse the repository at this point in the history
allow file input to gsea, filter_low_counts + sample_query
  • Loading branch information
sanjaynagi authored Jan 9, 2024
2 parents 56fb44b + 1efbe9a commit e05bef6
Show file tree
Hide file tree
Showing 4 changed files with 2,886 additions and 1,958 deletions.
133 changes: 102 additions & 31 deletions anoexpress/anoexpress.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,39 @@ def metadata(analysis, microarray=False):
return metadata


def data(data_type, analysis, microarray=False, gene_id=None, sort_by=None, annotations=False, pvalue_filter=None, fraction_na_allowed=None):
def resolve_gene_id(gene_id, analysis):
    """
    Resolve a user-supplied ``gene_id`` argument into gene identifiers.

    Parameters
    ----------
    gene_id: str or list
        Either a list of gene ids (returned unchanged), a single gene id string
        (returned unchanged), a genomic span such as ``2L:500-10000``, or a path
        to a ``.tsv``/``.txt``/``.csv``/``.xlsx`` file whose first column holds
        the gene ids (the file is read without a header row).
    analysis: str
        The analysis being queried. Genomic spans cannot be resolved when this
        is ``'fun'``.

    Returns
    -------
    str or list
        A list of gene ids when a file or genomic span was given, otherwise the
        input unchanged.

    Raises
    ------
    ValueError
        If a genomic span is requested for the ``'fun'`` analysis.
    """
    # Anything that is not a string (e.g. an explicit list of ids) passes through.
    if not isinstance(gene_id, str):
        return gene_id

    # File inputs are tested first so that a path such as '2L_genes.txt' or
    # 'X_list.csv' is read as a file rather than mistaken for a genomic span.
    if gene_id.endswith(('.tsv', '.txt')):
        return pd.read_csv(gene_id, sep="\t", header=None).iloc[:, 0].to_list()
    if gene_id.endswith('.csv'):
        return pd.read_csv(gene_id, header=None).iloc[:, 0].to_list()
    if gene_id.endswith('.xlsx'):
        return pd.read_excel(gene_id, header=None).iloc[:, 0].to_list()

    # A string starting with a chromosome arm name is treated as a genomic span.
    if gene_id.startswith(('2L', '2R', '3L', '3R', 'X', '2RL', '3RL')):
        if analysis == 'fun':
            # The original used `assert "<message>"`, which is always true and so
            # never fired; raise explicitly (and before the optional import).
            raise ValueError(
                "Unfortunately the genome feature file in malariagen_data does not contain AFUN identifiers, so we cannot subset by genomic span for An. funestus."
            )
        # Imported lazily so the dependency is only needed for span queries.
        import malariagen_data
        ag3 = malariagen_data.Ag3()
        gff = ag3.genome_features(region=gene_id).query("type == 'gene'")
        return gff.ID.to_list()

    # A plain gene identifier string.
    return gene_id

def filter_low_counts(data_df, data_type, analysis, gene_id, count_threshold=5, func=np.nanmedian):
    """
    Remove genes whose summarised read count does not exceed a threshold.

    Parameters
    ----------
    data_df: pd.DataFrame
        The dataframe to filter. Rows are selected with a ``GeneID`` query, so
        ``GeneID`` must be a column or an index level.
    data_type: str
        The type of data held in ``data_df``. If it is not ``'log2counts'``,
        the log2 count data is loaded via ``data()`` to compute the filter.
    analysis: str
        Passed to ``data()`` when count data has to be loaded.
    gene_id: str or list or None
        Passed to ``data()`` when count data has to be loaded.
    count_threshold: int, optional
        Genes are kept only if ``2 ** func(log2 counts)`` is strictly greater
        than this value. Defaults to 5.
    func: callable, optional
        Row-wise summary applied to the log2 counts. Defaults to np.nanmedian.

    Returns
    -------
    pd.DataFrame
        ``data_df`` restricted to the genes passing the filter.
    """
    if data_type != 'log2counts':
        # data_df does not hold counts, so fetch the matching log2 count data
        count_df = data(data_type='log2counts', analysis=analysis, gene_id=gene_id)
    else:
        count_df = data_df

    # counts are stored as log2, so back-transform the per-gene summary
    mask = 2**count_df.apply(func=func, axis=1) > count_threshold

    print(f"Removing {(mask == 0).sum()} genes with median counts below the threshold ({count_threshold})")

    passing = mask[mask].index
    if 'GeneID' in passing.names:
        # With a MultiIndex (e.g. GeneID plus annotation levels) .to_list()
        # would give tuples that never match GeneID; use the GeneID level only.
        kept_ids = passing.get_level_values('GeneID').to_list()
    else:
        kept_ids = passing.to_list()

    return data_df.query("GeneID in @kept_ids")

def data(data_type, analysis, microarray=False, gene_id=None, sample_query=None, sort_by=None, annotations=False, pvalue_filter=None, low_count_filter=None, fraction_na_allowed=None):
"""
Load the combined data for a given analysis and sample query
Expand All @@ -94,13 +126,18 @@ def data(data_type, analysis, microarray=False, gene_id=None, sort_by=None, anno
gene_id: str or list, optional
A string (AGAP/AFUN identifier or genomic span in the format 2L:500-10000), or list of strings, or path to a file containing a list of gene ids in the first column.
Input file can be .tsv, .txt, or .csv, or .xlsx.
sample_query: str, optional
A string containing a pandas query to subset the samples of interest from the comparison metadata file. For example, to plot only the
samples from Burkina Faso, use "country == 'Burkina Faso'". Defaults to None.
sort_by: {"median", "mean", "agap", "position", None}, optional
sort by median/mean of fold changes (descending), or by AGAP, or by position in the genome, or dont sort input gene ids.
annotations: bool, optional
whether to add gene name and description to the dataframe as index. Default is False.
pvalue_filter: float, optional
if provided, fold-change entries with an adjusted p-value below the threshold will be set to NaN. Default is None.
ignored if the data_type is not 'fcs'.
low_count_filter: int, optional
if provided, genes with a median count below the threshold will be removed from the dataframe. Default is None.
fraction_na_allowed: float, optional
fraction of missing values allowed in the data. Defaults to 0.5
Expand Down Expand Up @@ -129,23 +166,14 @@ def data(data_type, analysis, microarray=False, gene_id=None, sort_by=None, anno
# subset to the species comparisons of interest
df = df.loc[:, metadata_ids]

if sample_query:
# subset to the sample ids of interest
mask = df_metadata.eval(sample_query).to_list()
df = df.loc[:, mask]

# subset to the gene ids of interest including reading file
if gene_id is not None:
if isinstance(gene_id, str):
if gene_id.startswith(('2L', '2R', '3L', '3R', 'X', '2RL', '3RL')):
import malariagen_data
if analysis == 'fun':
assert "Unfortunately the genome feature file in malariagen_data does not contain AFUN identifiers, so we cannot subset by genomic span for An. funestus."
else:
ag3 = malariagen_data.Ag3()
gff = ag3.genome_features(region=gene_id).query("type == 'gene'")
gene_id = gff.ID.to_list()
elif gene_id.endswith(('.tsv', '.txt')):
gene_id = pd.read_csv(gene_id, sep="\t", header=None).iloc[:, 0].to_list()
elif gene_id.endswith('.csv'):
gene_id = pd.read_csv(gene_id, header=None).iloc[:, 0].to_list()
elif gene_id.endswith('.xlsx'):
gene_id = pd.read_excel(gene_id, header=None).iloc[:, 0].to_list()
gene_id = resolve_gene_id(gene_id=gene_id, analysis=analysis)
df = df.query("GeneID in @gene_id")

if annotations: # add gene name and description to the dataframe as index
Expand All @@ -158,8 +186,12 @@ def data(data_type, analysis, microarray=False, gene_id=None, sort_by=None, anno
# sort genes
df = _sort_genes(df=df, analysis=analysis, sort_by=sort_by)

# remove low count genes
if low_count_filter is not None:
df = filter_low_counts(data_df=df, data_type=data_type, analysis=analysis, gene_id=gene_id, count_threshold=low_count_filter, func=np.nanmedian)

# remove genes with lots of NA
if fraction_na_allowed:
# remove genes with lots of NA
df = filter_nas(df=df, fraction_na_allowed=fraction_na_allowed)

return df
Expand Down Expand Up @@ -195,7 +227,21 @@ def _sort_genes(df, analysis, sort_by=None):

return df.iloc[sort_idxs, :].copy()

def plot_gene_expression(gene_id, analysis="gamb_colu_arab_fun", microarray=False, title=None, plot_type='strip', sort_by='agap', pvalue_filter=None, width=1600, height=None, save_html=None):

def query_fc_count_data(fc_data, count_data, comparison_metadata, sample_metadata, query):
    """
    Subset fold-change and count data, with their metadata, using a query.

    The query is evaluated against ``comparison_metadata``; the resulting
    boolean mask selects its rows and the columns of ``fc_data``. Samples are
    then kept when their ``condition`` is one of the ``resistant`` values of
    the retained comparisons, and that mask selects the rows of
    ``sample_metadata`` and the columns of ``count_data``.

    NOTE(review): the visible call in plot_gene_expression passes
    ``df_metadata=`` rather than ``comparison_metadata=`` - confirm the call
    site matches this signature.

    Returns
    -------
    tuple
        (fc_data, count_data, comparison_metadata, sample_metadata), each subset.
    """
    # boolean mask over comparisons (metadata rows / fold-change columns)
    keep_comparisons = comparison_metadata.eval(query).to_list()
    kept_comparison_metadata = comparison_metadata[keep_comparisons]

    # boolean mask over samples (metadata rows / count columns);
    # `resistant_strains` is referenced by name inside the eval expression
    resistant_strains = kept_comparison_metadata['resistant'].to_list()
    keep_samples = sample_metadata.eval("condition in @resistant_strains").to_list()

    return (
        fc_data.loc[:, keep_comparisons],
        count_data.loc[:, keep_samples],
        kept_comparison_metadata,
        sample_metadata[keep_samples],
    )


def plot_gene_expression(gene_id, analysis="gamb_colu_arab_fun", microarray=False, sample_query=None, title=None, plot_type='strip', sort_by='agap', pvalue_filter=None, width=1600, height=None, save_html=None):
"""Plot fold changes of provided AGAP gene IDs from RNA-Seq
meta-analysis dataset
Expand All @@ -210,13 +256,18 @@ def plot_gene_expression(gene_id, analysis="gamb_colu_arab_fun", microarray=Fals
present, due to the process of finding orthologs.
microarray: bool, optional
whether to include the IR-Tex microarray data in the plot
sample_query: str, optional
A string containing a pandas query to subset the samples of interest from the comparison metadata file. For example, to plot only the
samples from Burkina Faso, use "country == 'Burkina Faso'". Defaults to None.
title : str
Plot title
plot_type : {"strip", "boxplot"}, optional
valid options are 'strip' or 'boxplot'
sort_by : {"median", "mean", "agap", None}, optional
sort by median/mean of fold changes (descending), or by AGAP, or dont sort input gene ids.
identifier
pvalue_filter: float, optional
if provided, fold-change entries with an adjusted p-value below the threshold will be removed from the plot. Default is None.
width : int
Width in pixels of the plotly figure
height: int, optional
Expand All @@ -232,10 +283,14 @@ def plot_gene_expression(gene_id, analysis="gamb_colu_arab_fun", microarray=Fals
df_samples = sample_metadata(analysis=analysis)

# load fold change data, make long format and merge with metadata for hovertext
fc_data = data(data_type="fcs", analysis=analysis, microarray=microarray, gene_id=gene_id, sort_by=sort_by, annotations=True, pvalue_filter=pvalue_filter).reset_index()
fc_data = data(data_type="fcs", analysis=analysis, microarray=microarray, sample_query=sample_query, gene_id=gene_id, sort_by=sort_by, annotations=True, pvalue_filter=pvalue_filter).reset_index()
# load count data, make long format and merge with metadata for hovertext
count_data = data(data_type="log2counts", analysis=analysis, microarray=microarray, gene_id=gene_id, sort_by=None)
count_data = data(data_type="log2counts", analysis=analysis, microarray=microarray, gene_id=gene_id, sample_query=sample_query, sort_by=None)
count_data = count_data.loc[fc_data['GeneID']].reset_index()

if sample_query:
fc_data, count_data, df_metadata, df_samples = query_fc_count_data(fc_data=fc_data, count_data=count_data, df_metadata=df_metadata, sample_metadata=df_samples, query=sample_query)

count_data = count_data.melt(id_vars='GeneID', var_name='sampleID', value_name='log2_counts')
count_data = count_data.merge(df_samples, how='left').assign(counts = lambda x: np.round(2**x.log2_counts, 0))

Expand Down Expand Up @@ -381,10 +436,11 @@ def filter_nas(df, fraction_na_allowed):
"""
n_cols = df.shape[1]
na_mask = df.apply(lambda x: x.isna().sum() / n_cols > fraction_na_allowed, axis=1)
print(f"Removing {na_mask.sum()} genes with higher proportion of NAs than the threshold ({fraction_na_allowed})")
return df.loc[~na_mask, :]


def load_candidates(analysis, name='median', func=np.nanmedian, query_annotation=None, query_fc=None, microarray=False, fraction_na_allowed=None):
def load_candidates(analysis, name='median', func=np.nanmedian, query_annotation=None, query_fc=None, microarray=False, low_count_filter=None, fraction_na_allowed=None):
"""
Load the candidate genes for a given analysis. Optionally, filter by annotation or fold change data.
Expand All @@ -403,6 +459,8 @@ def load_candidates(analysis, name='median', func=np.nanmedian, query_annotation
filter genes by fold change. Defaults to None
microarray: bool, optional
whether to include the IR-Tex microarray data in the requested data. Default is False.
low_count_filter: int, optional
if provided, genes with a median count below the threshold will be removed before gene ranking. Default is None.
fraction_na_allowed: float, optional
fraction of missing values allowed in the data. Defaults to 0.5
Expand All @@ -411,7 +469,7 @@ def load_candidates(analysis, name='median', func=np.nanmedian, query_annotation
fc_ranked: pd.DataFrame
"""

fc_data = data(data_type='fcs', analysis=analysis, microarray=microarray, annotations=True, sort_by=None, fraction_na_allowed=fraction_na_allowed)
fc_data = data(data_type='fcs', analysis=analysis, microarray=microarray, annotations=True, sort_by=None, low_count_filter=low_count_filter, fraction_na_allowed=fraction_na_allowed)

if query_annotation is not None:
gene_annot_df = load_annotations()
Expand All @@ -429,7 +487,8 @@ def load_candidates(analysis, name='median', func=np.nanmedian, query_annotation

return(fc_ranked)

def load_genes_for_enrichment(analysis, func, gene_ids, percentile):

def load_genes_for_enrichment(analysis, func, gene_ids, percentile, microarray, low_count_filter=None):

assert func is not None or gene_ids is not None, "either a ranking function (func) or gene_ids must be provided"
assert func is None or gene_ids is None, "Only a ranking function (func) or gene_ids must be provided, not both"
Expand All @@ -440,15 +499,15 @@ def load_genes_for_enrichment(analysis, func, gene_ids, percentile):

if func:
# get top % percentile genes ranked by func
fc_ranked = load_candidates(analysis=analysis, name='enrich', func=func)
fc_ranked = load_candidates(analysis=analysis, name='enrich', func=func, microarray=microarray, low_count_filter=low_count_filter)
percentile_idx = fc_ranked.reset_index()['GeneID'].unique().shape[0] * percentile
top_geneIDs = fc_ranked.reset_index().loc[:, 'GeneID'][:int(percentile_idx)]
elif gene_ids:
top_geneIDs = gene_ids
top_geneIDs = resolve_gene_id(gene_id=gene_ids, analysis=analysis)

return top_geneIDs, fc_genes

def go_hypergeometric(analysis, func=None, gene_ids=None, percentile=0.05):
def go_hypergeometric(analysis, func=None, gene_ids=None, percentile=0.05, microarray=False, low_count_filter=None):
"""
Perform a hypergeometric test on GO terms of the top % percentile genes ranked by user input function, or on
a user inputted gene_id list
Expand All @@ -464,13 +523,17 @@ def go_hypergeometric(analysis, func=None, gene_ids=None, percentile=0.05):
list of gene ids to perform hypergeometric test on. Defaults to None
percentile: float, optional
percentile of genes to use for the enriched set in hypergeometric test. Defaults to 0.05
microarray: bool, optional
whether to include the IR-Tex microarray data in the gene ranking. Default is False.
low_count_filter: int, optional
if provided, genes with a median count below the threshold will be removed before gene ranking. Default is None.
Returns
-------
go_hypergeo_results: pd.DataFrame
"""

top_geneIDs, fc_genes = load_genes_for_enrichment(analysis=analysis, func=func, gene_ids=gene_ids, percentile=percentile)
top_geneIDs, fc_genes = load_genes_for_enrichment(analysis=analysis, func=func, gene_ids=gene_ids, percentile=percentile, microarray=microarray, low_count_filter=low_count_filter)

# load gene annotation file
gaf_df = pd.read_csv("https://raw.githubusercontent.com/sanjaynagi/AnoExpress/main/resources/AgamP4.gaf", sep="\t")
Expand All @@ -490,7 +553,7 @@ def go_hypergeometric(analysis, func=None, gene_ids=None, percentile=0.05):
return(hyper_geo)


def pfam_hypergeometric(analysis, func=None, gene_ids=None, percentile=0.05):
def pfam_hypergeometric(analysis, func=None, gene_ids=None, percentile=0.05, microarray=False, low_count_filter=None):
"""
Perform a hypergeometric test on PFAM domains of the top % percentile genes ranked by user input function,
or on a user inputted gene_id list
Expand All @@ -508,13 +571,17 @@ def pfam_hypergeometric(analysis, func=None, gene_ids=None, percentile=0.05):
list of gene ids to perform hypergeometric test on. Defaults to None
percentile: float, optional
percentile of genes to use for the enriched set in hypergeometric test. Defaults to 0.05
microarray: bool, optional
whether to include the IR-Tex microarray data in the gene ranking. Default is False.
low_count_filter: int, optional
if provided, genes with a median count below the threshold will be removed before gene ranking. Default is None.
Returns
-------
pfam_hypergeo_results: pd.DataFrame
"""

top_geneIDs, fc_genes = load_genes_for_enrichment(analysis=analysis, func=func, gene_ids=gene_ids, percentile=percentile)
top_geneIDs, fc_genes = load_genes_for_enrichment(analysis=analysis, func=func, gene_ids=gene_ids, percentile=percentile, microarray=microarray, low_count_filter=low_count_filter)

# load gene annotation file
pfam_df = pd.read_csv("https://github.com/sanjaynagi/AnoExpress/blob/main/resources/Anogam_long.pep_Pfamscan.seqs.gz?raw=true", sep="\s+", header=None, compression='gzip').iloc[:, [0,4]]
Expand All @@ -534,7 +601,7 @@ def pfam_hypergeometric(analysis, func=None, gene_ids=None, percentile=0.05):

return(hyper_geo)

def kegg_hypergeometric(analysis, func=None, gene_ids=None, percentile=0.05):
def kegg_hypergeometric(analysis, func=None, gene_ids=None, percentile=0.05, microarray=False, low_count_filter=None):
"""
Perform a hypergeometric test on KEGG annotations of the top % percentile genes ranked by user input function.
Expand All @@ -551,13 +618,17 @@ def kegg_hypergeometric(analysis, func=None, gene_ids=None, percentile=0.05):
list of gene ids to perform hypergeometric test on. Defaults to None
percentile: float, optional
percentile of genes to use for the enriched set in hypergeometric test. Defaults to 0.05
microarray: bool, optional
whether to include the IR-Tex microarray data in the gene ranking. Default is False.
low_count_filter: int, optional
if provided, genes with a median count below the threshold will be removed before gene ranking. Default is None.
Returns
-------
go_hypergeo_results: pd.DataFrame
"""

top_geneIDs, fc_genes = load_genes_for_enrichment(analysis=analysis, func=func, gene_ids=gene_ids, percentile=percentile)
top_geneIDs, fc_genes = load_genes_for_enrichment(analysis=analysis, func=func, gene_ids=gene_ids, percentile=percentile, microarray=microarray, low_count_filter=low_count_filter)

# load gene annotation file
kegg_df = pd.read_csv("https://raw.githubusercontent.com/sanjaynagi/AnoExpress/main/resources/AgamP4.kegg", sep="\t")
Expand Down
Loading

0 comments on commit e05bef6

Please sign in to comment.