Skip to content

Commit

Permalink
feat(validation): adding credible set variant validation (opentargets#757)

Browse files Browse the repository at this point in the history

* feat(validation): adding logic to validate credible sets against variant index

* fix: tidying docstrings
  • Loading branch information
DSuveges authored Sep 12, 2024
1 parent a49ae9a commit 6469bf5
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 1 deletion.
66 changes: 65 additions & 1 deletion src/gentropy/dataset/study_locus.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import numpy as np
import pyspark.sql.functions as f
from pyspark.sql.types import FloatType, StringType
from pyspark.sql.types import ArrayType, FloatType, StringType

from gentropy.common.schemas import parse_spark_schema
from gentropy.common.spark_helpers import (
Expand All @@ -18,6 +18,7 @@
from gentropy.common.utils import get_logsum, parse_region
from gentropy.dataset.dataset import Dataset
from gentropy.dataset.study_locus_overlap import StudyLocusOverlap
from gentropy.dataset.variant_index import VariantIndex
from gentropy.method.clump import LDclumping

if TYPE_CHECKING:
Expand Down Expand Up @@ -47,6 +48,7 @@ class StudyLocusQualityCheck(Enum):
FAILED_STUDY (str): Flagging study loci if the study has failed QC
MISSING_STUDY (str): Flagging study loci if the study is not found in the study index as a reference
DUPLICATED_STUDYLOCUS_ID (str): Study-locus identifier is not unique.
INVALID_VARIANT_IDENTIFIER (str): Flagging study loci where identifier of any tagging variant was not found in the variant index
"""

SUBSIGNIFICANT_FLAG = "Subsignificant p-value"
Expand All @@ -65,6 +67,9 @@ class StudyLocusQualityCheck(Enum):
FAILED_STUDY = "Study has failed quality controls"
MISSING_STUDY = "Study not found in the study index"
DUPLICATED_STUDYLOCUS_ID = "Non-unique study locus identifier"
INVALID_VARIANT_IDENTIFIER = (
"Some variant identifiers of this locus were not found in variant index"
)


class CredibleInterval(Enum):
Expand Down Expand Up @@ -141,6 +146,65 @@ def validate_study(self: StudyLocus, study_index: StudyIndex) -> StudyLocus:
_schema=self.get_schema(),
)

def validate_variant_identifiers(
    self: StudyLocus, variant_index: VariantIndex
) -> StudyLocus:
    """Flagging study loci, where tagging variant identifiers are not found in variant index.

    A study locus gets the ``INVALID_VARIANT_IDENTIFIER`` quality-control flag
    if ANY variant in its `locus` column (the credible set) cannot be resolved
    against the provided variant index.

    Args:
        variant_index (VariantIndex): Variant index to resolve variant identifiers.

    Returns:
        StudyLocus: Updated study locus with quality control flags.
    """
    # NOTE(review): despite the original wording, this guard is about the
    # *study locus* schema (`self.df`), not the variant index — if the study
    # locus has no `qualityControls` column yet, start from a null
    # array<string> so update_quality_flag below can still append to it.
    qc_select_expression = (
        f.col("qualityControls")
        if "qualityControls" in self.df.columns
        else f.lit(None).cast(ArrayType(StringType()))
    )

    # Find out which study loci have variants not in the variant index:
    flag = (
        self.df
        # Exploding locus: one row per (studyLocusId, tagging variant).
        # Rows with an empty or null `locus` array drop out here, so those
        # study loci get a null `toFlag` after the left join below and are
        # left unflagged.
        .select("studyLocusId", f.explode("locus").alias("locus"))
        .select("studyLocusId", "locus.variantId")
        # Join with variant index variants:
        .join(
            variant_index.df.select(
                "variantId", f.lit(True).alias("inVariantIndex")
            ),
            on="variantId",
            how="left",
        )
        # Flagging variants not in the variant index (a null after the left
        # join means the identifier was not resolved):
        .withColumn("inVariantIndex", f.col("inVariantIndex").isNotNull())
        # Flagging study loci with ANY variants not in the variant index:
        .groupBy("studyLocusId")
        .agg(f.collect_set("inVariantIndex").alias("inVariantIndex"))
        .select(
            "studyLocusId",
            # True when at least one tagging variant was missing:
            f.array_contains("inVariantIndex", False).alias("toFlag"),
        )
    )

    return StudyLocus(
        _df=(
            self.df.join(flag, on="studyLocusId", how="left")
            .withColumn(
                "qualityControls",
                self.update_quality_flag(
                    qc_select_expression,
                    f.col("toFlag"),
                    StudyLocusQualityCheck.INVALID_VARIANT_IDENTIFIER,
                ),
            )
            .drop("toFlag")
        ),
        _schema=self.get_schema(),
    )

def validate_lead_pvalue(self: StudyLocus, pvalue_cutoff: float) -> StudyLocus:
"""Flag associations below significant threshold.
Expand Down
89 changes: 89 additions & 0 deletions tests/gentropy/dataset/test_study_locus.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
)
from gentropy.dataset.study_locus_overlap import StudyLocusOverlap
from gentropy.dataset.summary_statistics import SummaryStatistics
from gentropy.dataset.variant_index import VariantIndex


@pytest.mark.parametrize(
Expand Down Expand Up @@ -562,6 +563,94 @@ def test_annotate_locus_statistics_boundaries(
)


class TestStudyLocusVariantValidation:
    """Collection of tests for StudyLocus variant validation."""

    # Variant index fixture rows: (variantId, chromosome, position, ref, alt).
    # Note that "v5" is deliberately NOT listed here — it is used below as an
    # identifier that cannot be resolved against the variant index.
    VARIANT_DATA = [
        ("v1", "c1", 1, "r", "a"),
        ("v2", "c1", 2, "r", "a"),
        ("v3", "c1", 3, "r", "a"),
        ("v4", "c1", 4, "r", "a"),
    ]
    VARIANT_HEADERS = [
        "variantId",
        "chromosome",
        "position",
        "referenceAllele",
        "alternateAllele",
    ]

    # Study locus fixture rows: (studyLocusId, leadVariantId, studyId, tagVariantId).
    STUDYLOCUS_DATA = [
        # First studylocus passes qc: v1, v2 and v3 all resolve in the variant index.
        (1, "v1", "s1", "v1"),
        (1, "v1", "s1", "v2"),
        (1, "v1", "s1", "v3"),
        # Second studylocus FAILS qc: v5 is absent from the variant index, so
        # this locus is expected to get the INVALID_VARIANT_IDENTIFIER flag.
        (2, "v1", "s1", "v1"),
        (2, "v1", "s1", "v5"),
    ]
    STUDYLOCUS_HEADER = ["studyLocusId", "variantId", "studyId", "tagVariantId"]

    @pytest.fixture(autouse=True)
    def _setup(self: TestStudyLocusVariantValidation, spark: SparkSession) -> None:
        """Setup study locus for testing."""
        self.variant_index = VariantIndex(
            _df=spark.createDataFrame(
                self.VARIANT_DATA, self.VARIANT_HEADERS
            ).withColumn("position", f.col("position").cast(t.IntegerType())),
            _schema=VariantIndex.get_schema(),
        )

        self.credible_set = StudyLocus(
            _df=(
                spark.createDataFrame(self.STUDYLOCUS_DATA, self.STUDYLOCUS_HEADER)
                .withColumn("studyLocusId", f.col("studyLocusId").cast(t.LongType()))
                .withColumn("qualityControls", f.array())
                # Collapse the flat (studyLocusId, tagVariantId) rows into one
                # row per study locus with a `locus` array of structs:
                .groupBy("studyLocusId", "variantId", "studyId")
                .agg(
                    f.collect_set(
                        f.struct(f.col("tagVariantId").alias("variantId"))
                    ).alias("locus")
                )
            ),
            _schema=StudyLocus.get_schema(),
        )

    def test_validation_return_type(self: TestStudyLocusVariantValidation) -> None:
        """Testing if the validation returns the right type."""
        assert isinstance(
            self.credible_set.validate_variant_identifiers(self.variant_index),
            StudyLocus,
        )

    def test_validation_no_data_loss(self: TestStudyLocusVariantValidation) -> None:
        """Testing if the validation returns same number of rows."""
        assert (
            self.credible_set.validate_variant_identifiers(
                self.variant_index
            ).df.count()
            == self.credible_set.df.count()
        )

    def test_validation_correctness(self: TestStudyLocusVariantValidation) -> None:
        """Testing if the validation flags the right number of variants."""
        # Execute validation:
        validated = self.credible_set.validate_variant_identifiers(
            self.variant_index
        ).df

        # Make sure there's only one study locus with failed variants:
        assert validated.filter(f.size("qualityControls") > 0).count() == 1

        # Check that the right one (studyLocusId 2, containing v5) is flagged:
        assert (
            validated.filter(
                (f.size("qualityControls") > 0) & (f.col("studyLocusId") == 2)
            ).count()
            == 1
        )


class TestStudyLocusValidation:
"""Collection of tests for StudyLocus validation."""

Expand Down

0 comments on commit 6469bf5

Please sign in to comment.