From 6469bf5dc4311675a90ae60e3ed1a8f0b4da9349 Mon Sep 17 00:00:00 2001 From: Daniel Suveges Date: Thu, 12 Sep 2024 14:17:21 +0100 Subject: [PATCH] feat(validation): adding credible set variant validation (#757) * feat(validation): adding logic to validate credible sets against variant index * fix: tidying docstrings --- src/gentropy/dataset/study_locus.py | 66 +++++++++++++++- tests/gentropy/dataset/test_study_locus.py | 89 ++++++++++++++++++++++ 2 files changed, 154 insertions(+), 1 deletion(-) diff --git a/src/gentropy/dataset/study_locus.py b/src/gentropy/dataset/study_locus.py index edf9dc8be..b59d57650 100644 --- a/src/gentropy/dataset/study_locus.py +++ b/src/gentropy/dataset/study_locus.py @@ -8,7 +8,7 @@ import numpy as np import pyspark.sql.functions as f -from pyspark.sql.types import FloatType, StringType +from pyspark.sql.types import ArrayType, FloatType, StringType from gentropy.common.schemas import parse_spark_schema from gentropy.common.spark_helpers import ( @@ -18,6 +18,7 @@ from gentropy.common.utils import get_logsum, parse_region from gentropy.dataset.dataset import Dataset from gentropy.dataset.study_locus_overlap import StudyLocusOverlap +from gentropy.dataset.variant_index import VariantIndex from gentropy.method.clump import LDclumping if TYPE_CHECKING: @@ -47,6 +48,7 @@ class StudyLocusQualityCheck(Enum): FAILED_STUDY (str): Flagging study loci if the study has failed QC MISSING_STUDY (str): Flagging study loci if the study is not found in the study index as a reference DUPLICATED_STUDYLOCUS_ID (str): Study-locus identifier is not unique. 
+ INVALID_VARIANT_IDENTIFIER (str): Flagging study loci where identifier of any tagging variant was not found in the variant index """ SUBSIGNIFICANT_FLAG = "Subsignificant p-value" @@ -65,6 +67,9 @@ class StudyLocusQualityCheck(Enum): FAILED_STUDY = "Study has failed quality controls" MISSING_STUDY = "Study not found in the study index" DUPLICATED_STUDYLOCUS_ID = "Non-unique study locus identifier" + INVALID_VARIANT_IDENTIFIER = ( + "Some variant identifiers of this locus were not found in variant index" + ) class CredibleInterval(Enum): @@ -141,6 +146,65 @@ def validate_study(self: StudyLocus, study_index: StudyIndex) -> StudyLocus: _schema=self.get_schema(), ) + def validate_variant_identifiers( + self: StudyLocus, variant_index: VariantIndex + ) -> StudyLocus: + """Flagging study loci, where tagging variant identifiers are not found in variant index. + + Args: + variant_index (VariantIndex): Variant index to resolve variant identifiers. + + Returns: + StudyLocus: Updated study locus with quality control flags. 
+ """ + # QC column might not be present in the variant index schema, so we have to be ready to handle it: + qc_select_expression = ( + f.col("qualityControls") + if "qualityControls" in self.df.columns + else f.lit(None).cast(ArrayType(StringType())) + ) + + # Find out which study loci have variants not in the variant index: + flag = ( + self.df + # Exploding locus: + .select("studyLocusId", f.explode("locus").alias("locus")) + .select("studyLocusId", "locus.variantId") + # Join with variant index variants: + .join( + variant_index.df.select( + "variantId", f.lit(True).alias("inVariantIndex") + ), + on="variantId", + how="left", + ) + # Flagging variants not in the variant index: + .withColumn("inVariantIndex", f.col("inVariantIndex").isNotNull()) + # Flagging study loci with ANY variants not in the variant index: + .groupBy("studyLocusId") + .agg(f.collect_set("inVariantIndex").alias("inVariantIndex")) + .select( + "studyLocusId", + f.array_contains("inVariantIndex", False).alias("toFlag"), + ) + ) + + return StudyLocus( + _df=( + self.df.join(flag, on="studyLocusId", how="left") + .withColumn( + "qualityControls", + self.update_quality_flag( + qc_select_expression, + f.col("toFlag"), + StudyLocusQualityCheck.INVALID_VARIANT_IDENTIFIER, + ), + ) + .drop("toFlag") + ), + _schema=self.get_schema(), + ) + def validate_lead_pvalue(self: StudyLocus, pvalue_cutoff: float) -> StudyLocus: """Flag associations below significant threshold. 
diff --git a/tests/gentropy/dataset/test_study_locus.py b/tests/gentropy/dataset/test_study_locus.py index 9b40796db..c7538b28b 100644 --- a/tests/gentropy/dataset/test_study_locus.py +++ b/tests/gentropy/dataset/test_study_locus.py @@ -27,6 +27,7 @@ ) from gentropy.dataset.study_locus_overlap import StudyLocusOverlap from gentropy.dataset.summary_statistics import SummaryStatistics +from gentropy.dataset.variant_index import VariantIndex @pytest.mark.parametrize( @@ -562,6 +563,94 @@ def test_annotate_locus_statistics_boundaries( ) +class TestStudyLocusVariantValidation: + """Collection of tests for StudyLocus variant validation.""" + + VARIANT_DATA = [ + ("v1", "c1", 1, "r", "a"), + ("v2", "c1", 2, "r", "a"), + ("v3", "c1", 3, "r", "a"), + ("v4", "c1", 4, "r", "a"), + ] + VARIANT_HEADERS = [ + "variantId", + "chromosome", + "position", + "referenceAllele", + "alternateAllele", + ] + + STUDYLOCUS_DATA = [ + # First studylocus passes qc: + (1, "v1", "s1", "v1"), + (1, "v1", "s1", "v2"), + (1, "v1", "s1", "v3"), + # Second studylocus fails qc (v5 is not in the variant index): + (2, "v1", "s1", "v1"), + (2, "v1", "s1", "v5"), + ] + STUDYLOCUS_HEADER = ["studyLocusId", "variantId", "studyId", "tagVariantId"] + + @pytest.fixture(autouse=True) + def _setup(self: TestStudyLocusVariantValidation, spark: SparkSession) -> None: + """Setup study locus for testing.""" + self.variant_index = VariantIndex( + _df=spark.createDataFrame( + self.VARIANT_DATA, self.VARIANT_HEADERS + ).withColumn("position", f.col("position").cast(t.IntegerType())), + _schema=VariantIndex.get_schema(), + ) + + self.credible_set = StudyLocus( + _df=( + spark.createDataFrame(self.STUDYLOCUS_DATA, self.STUDYLOCUS_HEADER) + .withColumn("studyLocusId", f.col("studyLocusId").cast(t.LongType())) + .withColumn("qualityControls", f.array()) + .groupBy("studyLocusId", "variantId", "studyId") + .agg( + f.collect_set( + f.struct(f.col("tagVariantId").alias("variantId")) + ).alias("locus") + ) + ), + _schema=StudyLocus.get_schema(), + ) + + def 
test_validation_return_type(self: TestStudyLocusVariantValidation) -> None: + """Testing if the validation returns the right type.""" + assert isinstance( + self.credible_set.validate_variant_identifiers(self.variant_index), + StudyLocus, + ) + + def test_validation_no_data_loss(self: TestStudyLocusVariantValidation) -> None: + """Testing if the validation returns same number of rows.""" + assert ( + self.credible_set.validate_variant_identifiers( + self.variant_index + ).df.count() + == self.credible_set.df.count() + ) + + def test_validation_correctness(self: TestStudyLocusVariantValidation) -> None: + """Testing if the validation flags the right number of variants.""" + # Execute validation: + validated = self.credible_set.validate_variant_identifiers( + self.variant_index + ).df + + # Make sure there's only one study locus with failed variants: + assert validated.filter(f.size("qualityControls") > 0).count() == 1 + + # Check that the right one is flagged: + assert ( + validated.filter( + (f.size("qualityControls") > 0) & (f.col("studyLocusId") == 2) + ).count() + == 1 + ) + + class TestStudyLocusValidation: """Collection of tests for StudyLocus validation."""