Skip to content

Commit

Permalink
feat: optimisation of qc step (opentargets#813)
Browse files Browse the repository at this point in the history
* feat: optimisation of qc step

* fix: adding Z2 filter

* fix: v1
  • Loading branch information
addramir authored Oct 3, 2024
1 parent 8876fc1 commit fca55be
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 30 deletions.
6 changes: 4 additions & 2 deletions src/gentropy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ class FinemapperConfig(StepConfig):


@dataclass
class GWASQCStep(StepConfig):
class SummaryStatisticsQCStepConfig(StepConfig):
"""GWAS QC step configuration."""

gwas_path: str = MISSING
Expand Down Expand Up @@ -614,7 +614,9 @@ def register_config() -> None:
group="step", name="window_based_clumping", node=WindowBasedClumpingStepConfig
)
cs.store(group="step", name="susie_finemapping", node=FinemapperConfig)
cs.store(group="step", name="summary_statistics_qc", node=GWASQCStep)
cs.store(
group="step", name="summary_statistics_qc", node=SummaryStatisticsQCStepConfig
)
cs.store(
group="step", name="locus_breaker_clumping", node=LocusBreakerClumpingConfig
)
Expand Down
31 changes: 6 additions & 25 deletions src/gentropy/method/sumstat_quality_controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,11 @@ def _calculate_logpval(z2: float) -> float:
@staticmethod
def sumstat_qc_pz_check(
gwas_for_qc: SummaryStatistics,
limit: int = 10_000_000,
) -> DataFrame:
"""The PZ check for QC of GWAS summary statistics. It runs linear regression between reported p-values and p-values inferred from z-scores.
Args:
gwas_for_qc (SummaryStatistics): The instance of the SummaryStatistics class.
limit (int): The limit for the number of variants to be used for the estimation.
Returns:
DataFrame: PySpark DataFrame with the results of the linear regression for each study.
Expand All @@ -86,17 +84,10 @@ def sumstat_qc_pz_check(
SummaryStatisticsQC._calculate_logpval, t.DoubleType()
)

window = Window.partitionBy("studyId").orderBy("studyId")

gwas_df = (
gwas_df.withColumn("row_num", row_number().over(window))
.filter(f.col("row_num") <= limit)
.drop("row_num")
)

qc_c = (
gwas_df.withColumn("zscore", f.col("beta") / f.col("standardError"))
.withColumn("new_logpval", calculate_logpval_udf(f.col("zscore") ** 2))
gwas_df.withColumn("Z2", (f.col("beta") / f.col("standardError")) ** 2)
.filter(f.col("Z2") <= 100)
.withColumn("new_logpval", calculate_logpval_udf(f.col("Z2")))
.withColumn("log_mantissa", log10("pValueMantissa"))
.withColumn(
"diffpval",
Expand Down Expand Up @@ -194,24 +185,16 @@ def sumstat_n_eff_check(
@staticmethod
def gc_lambda_check(
gwas_for_qc: SummaryStatistics,
limit: int = 10_000_000,
) -> DataFrame:
"""The genomic control lambda check for QC of GWAS summary statistics.
Args:
gwas_for_qc (SummaryStatistics): The instance of the SummaryStatistics class.
limit (int): The limit for the number of variants to be used for the estimation.
Returns:
DataFrame: PySpark DataFrame with the genomic control lambda for each study.
"""
gwas_df = gwas_for_qc._df
window = Window.partitionBy("studyId").orderBy("studyId")
gwas_df = (
gwas_df.withColumn("row_num", row_number().over(window))
.filter(f.col("row_num") <= limit)
.drop("row_num")
)

qc_c = (
gwas_df.select("studyId", "beta", "standardError")
Expand Down Expand Up @@ -254,22 +237,20 @@ def number_of_snps(
@staticmethod
def get_quality_control_metrics(
gwas: SummaryStatistics,
limit: int = 100_000_000,
pval_threshold: float = 5e-8,
pval_threshold: float = 1e-8,
) -> DataFrame:
"""The function calculates the quality control metrics for the summary statistics.
Args:
gwas (SummaryStatistics): The instance of the SummaryStatistics class.
limit (int): The limit for the number of variants to be used for the estimation.
pval_threshold (float): The threshold for the p-value.
Returns:
DataFrame: PySpark DataFrame with the quality control metrics for the summary statistics.
"""
qc1 = SummaryStatisticsQC.sumstat_qc_beta_check(gwas_for_qc=gwas)
qc2 = SummaryStatisticsQC.sumstat_qc_pz_check(gwas_for_qc=gwas, limit=limit)
qc4 = SummaryStatisticsQC.gc_lambda_check(gwas_for_qc=gwas, limit=limit)
qc2 = SummaryStatisticsQC.sumstat_qc_pz_check(gwas_for_qc=gwas)
qc4 = SummaryStatisticsQC.gc_lambda_check(gwas_for_qc=gwas)
qc5 = SummaryStatisticsQC.number_of_snps(
gwas_for_qc=gwas, pval_threshold=pval_threshold
)
Expand Down
2 changes: 1 addition & 1 deletion src/gentropy/sumstat_qc_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(

(
SummaryStatisticsQC.get_quality_control_metrics(
gwas=gwas, limit=100_000_000, pval_threshold=pval_threshold
gwas=gwas, pval_threshold=pval_threshold
)
.write.mode(session.write_mode)
.parquet(output_path)
Expand Down
4 changes: 2 additions & 2 deletions tests/gentropy/method/test_qc_of_sumstats.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_qc_functions(
) -> None:
"""Test all sumstat qc functions."""
gwas = sample_summary_statistics.sanity_filter()
QC = SummaryStatisticsQC.get_quality_control_metrics(gwas=gwas, limit=100000)
QC = SummaryStatisticsQC.get_quality_control_metrics(gwas=gwas, pval_threshold=5e-8)
QC = QC.toPandas()

assert QC["n_variants"].iloc[0] == 1663
Expand Down Expand Up @@ -55,7 +55,7 @@ def test_several_studyid(
)
gwas._df = gwas_df

QC = SummaryStatisticsQC.get_quality_control_metrics(gwas=gwas, limit=100000)
QC = SummaryStatisticsQC.get_quality_control_metrics(gwas=gwas)
QC = QC.toPandas()
assert QC.shape == (2, 7)

Expand Down

0 comments on commit fca55be

Please sign in to comment.