diff --git a/src/gentropy/dataset/study_locus.py b/src/gentropy/dataset/study_locus.py index 30b91541b..b829990ce 100644 --- a/src/gentropy/dataset/study_locus.py +++ b/src/gentropy/dataset/study_locus.py @@ -433,7 +433,8 @@ def _qc_subsignificant_associations( def qc_abnormal_pips( self: StudyLocus, sum_pips_lower_threshold: float = 0.99, - sum_pips_upper_threshold: float = 1.0001, # Set slightly above 1 to account for floating point errors + # Set slightly above 1 to account for floating point errors + sum_pips_upper_threshold: float = 1.0001, ) -> StudyLocus: """Filter study-locus by sum of posterior inclusion probabilities to ensure that the sum of PIPs is within a given range. @@ -691,6 +692,7 @@ def flag_trans_qtls( """Flagging transQTL credible sets based on genomic location of the measured gene. Process: + 0. Make sure that the `isTransQtl` column does not exist (remove it if it exists) 1. Enrich study-locus dataset with geneId based on study metadata. (only QTL studies are considered) 2. Enrich with transcription start site and chromosome of the studied gegne. 3. Flagging any tagging variant of QTL credible sets, if chromosome is different from the gene or distance is above the threshold. @@ -709,6 +711,12 @@ def flag_trans_qtls( if "geneId" not in study_index.df.columns: return self + # We have to remove the column `isTransQtl` to ensure the column is not duplicated + # The duplication can happen when one reads the StudyLocus from parquet with + # predefined schema that already contains the `isTransQtl` column.
+ if "isTransQtl" in self.df.columns: + self.df = self.df.drop("isTransQtl") + # Process study index: processed_studies = ( study_index.df diff --git a/src/gentropy/method/colocalisation.py b/src/gentropy/method/colocalisation.py index c05b26187..b45b91920 100644 --- a/src/gentropy/method/colocalisation.py +++ b/src/gentropy/method/colocalisation.py @@ -208,11 +208,15 @@ class Coloc(ColocalisationMethodInterface): Attributes: PSEUDOCOUNT (float): Pseudocount to avoid log(0). Defaults to 1e-10. + OVERLAP_SIZE_CUTOFF (int): Minimum number of overlapping variants before filtering. Defaults to 5. + POSTERIOR_CUTOFF (float): Minimum overlapping posterior probability cutoff for small overlaps. Defaults to 0.5. """ METHOD_NAME: str = "COLOC" METHOD_METRIC: str = "h4" PSEUDOCOUNT: float = 1e-10 + OVERLAP_SIZE_CUTOFF: int = 5 + POSTERIOR_CUTOFF: float = 0.5 @staticmethod def _get_posteriors(all_bfs: NDArray[np.float64]) -> DenseVector: @@ -277,7 +281,15 @@ def colocalise( ) .select("*", "statistics.*") # Before summing log_BF columns nulls need to be filled with 0: - .fillna(0, subset=["left_logBF", "right_logBF"]) + .fillna( + 0, + subset=[ + "left_logBF", + "right_logBF", + "left_posteriorProbability", + "right_posteriorProbability", + ], + ) # Sum of log_BFs for each pair of signals .withColumn( "sum_log_bf", @@ -305,9 +317,18 @@ def colocalise( fml.array_to_vector(f.collect_list(f.col("right_logBF"))).alias( "right_logBF" ), + fml.array_to_vector( + f.collect_list(f.col("left_posteriorProbability")) + ).alias("left_posteriorProbability"), + fml.array_to_vector( + f.collect_list(f.col("right_posteriorProbability")) + ).alias("right_posteriorProbability"), fml.array_to_vector(f.collect_list(f.col("sum_log_bf"))).alias( "sum_log_bf" ), + f.collect_list(f.col("tagVariantSource")).alias( + "tagVariantSourceList" + ), ) .withColumn("logsum1", logsum(f.col("left_logBF"))) .withColumn("logsum2", logsum(f.col("right_logBF"))) @@ -327,10 +348,39 @@ def colocalise( # h3 
.withColumn("sumlogsum", f.col("logsum1") + f.col("logsum2")) .withColumn("max", f.greatest("sumlogsum", "logsum12")) + .withColumn( + "anySnpBothSidesHigh", + f.aggregate( + f.transform( + f.arrays_zip( + fml.vector_to_array(f.col("left_posteriorProbability")), + fml.vector_to_array( + f.col("right_posteriorProbability") + ), + f.col("tagVariantSourceList"), + ), + # row["0"] = left PP, row["1"] = right PP, row["tagVariantSourceList"] + lambda row: f.when( + (row["tagVariantSourceList"] == "both") + & (row["0"] > Coloc.POSTERIOR_CUTOFF) + & (row["1"] > Coloc.POSTERIOR_CUTOFF), + 1.0, + ).otherwise(0.0), + ), + f.lit(0.0), + lambda acc, x: acc + x, + ) + > 0, # True if sum of these 1.0's > 0 + ) + .filter( + (f.col("numberColocalisingVariants") > Coloc.OVERLAP_SIZE_CUTOFF) + | (f.col("anySnpBothSidesHigh")) + ) .withColumn( "logdiff", f.when( - f.col("sumlogsum") == f.col("logsum12"), Coloc.PSEUDOCOUNT + (f.col("sumlogsum") == f.col("logsum12")), + Coloc.PSEUDOCOUNT, ).otherwise( f.col("max") + f.log( @@ -382,6 +432,10 @@ def colocalise( "lH2bf", "lH3bf", "lH4bf", + "left_posteriorProbability", + "right_posteriorProbability", + "tagVariantSourceList", + "anySnpBothSidesHigh", ) .withColumn("colocalisationMethod", f.lit(cls.METHOD_NAME)) .join( diff --git a/tests/gentropy/dataset/test_study_locus.py b/tests/gentropy/dataset/test_study_locus.py index f8e59d97e..ce0f951d0 100644 --- a/tests/gentropy/dataset/test_study_locus.py +++ b/tests/gentropy/dataset/test_study_locus.py @@ -2,6 +2,7 @@ from __future__ import annotations +from pathlib import Path from typing import Any import pyspark.sql.functions as f @@ -18,6 +19,8 @@ StructType, ) +from gentropy.common.schemas import SchemaValidationError +from gentropy.common.session import Session from gentropy.dataset.colocalisation import Colocalisation from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix from gentropy.dataset.ld_index import LDIndex @@ -1209,7 +1212,6 @@ class TestTransQtlFlagging: ] 
STUDY_LOCUS_COLUMNS = ["studyLocusId", "variantId", "studyId"] - STUDY_DATA = [ ("s1", "p1", "qtl", "g1"), ("s2", "p2", "gwas", None), @@ -1221,21 +1223,21 @@ class TestTransQtlFlagging: GENE_COLUMNS = ["id", "strand", "start", "end", "chromosome", "tss"] @pytest.fixture(autouse=True) - def _setup(self: TestTransQtlFlagging, spark: SparkSession) -> None: + def _setup(self: TestTransQtlFlagging, session: Session) -> None: """Setup study locus for testing.""" self.study_locus = StudyLocus( _df=( - spark.createDataFrame( + session.spark.createDataFrame( self.STUDY_LOCUS_DATA, self.STUDY_LOCUS_COLUMNS ).withColumn("locus", f.array(f.struct("variantId"))) ) ) self.study_index = StudyIndex( - _df=spark.createDataFrame(self.STUDY_DATA, self.STUDY_COLUMNS) + _df=session.spark.createDataFrame(self.STUDY_DATA, self.STUDY_COLUMNS) ) self.target_index = TargetIndex( _df=( - spark.createDataFrame(self.GENE_DATA, self.GENE_COLUMNS).select( + session.spark.createDataFrame(self.GENE_DATA, self.GENE_COLUMNS).select( f.struct( f.col("strand").cast(IntegerType()).alias("strand"), "start", @@ -1283,3 +1285,27 @@ def test_correctness_found_trans(self: TestTransQtlFlagging) -> None: assert ( self.qtl_flagged.df.filter(f.col("isTransQtl")).count() == 2 ), "Expected number of rows differ from observed." + + def test_add_flag_if_column_is_present( + self: TestTransQtlFlagging, tmp_path: Path, session: Session + ) -> None: + """Test adding flag if the `isTransQtl` column is already present. + + When reading the dataset, the reader will add the `isTransQtl` column to + the schema, which can cause column duplication captured only by Dataset schema validation. + + This test ensures that the column is dropped before the `flag_trans_qtls` is run. 
+ """ + dataset_path = str(tmp_path / "study_locus") + self.study_locus.df.write.parquet(dataset_path) + schema_validated_study_locus = StudyLocus.from_parquet(session, dataset_path) + assert ( + "isTransQtl" in schema_validated_study_locus.df.columns + ), "`isTransQtl` column is missing after reading the dataset." + # Rerun the flag addition and check if any error is raised by the schema validation + try: + schema_validated_study_locus.flag_trans_qtls( + self.study_index, self.target_index, self.THRESHOLD + ) + except SchemaValidationError: + pytest.fail("Failed to validate the schema when adding isTransQtl flag") diff --git a/tests/gentropy/method/test_colocalisation_method.py b/tests/gentropy/method/test_colocalisation_method.py index 5b05d724b..78a66f732 100644 --- a/tests/gentropy/method/test_colocalisation_method.py +++ b/tests/gentropy/method/test_colocalisation_method.py @@ -43,6 +43,8 @@ def test_coloc(mock_study_locus_overlap: StudyLocusOverlap) -> None: "right_logBF": 10.5, "left_beta": 0.1, "right_beta": 0.2, + "left_posteriorProbability": 0.91, + "right_posteriorProbability": 0.92, }, }, ], @@ -57,7 +59,7 @@ def test_coloc(mock_study_locus_overlap: StudyLocusOverlap) -> None: }, ], ), - # associations with multiple overlapping SNPs + # Case with mismatched posterior probabilities: ( # observed overlap [ @@ -68,10 +70,12 @@ def test_coloc(mock_study_locus_overlap: StudyLocusOverlap) -> None: "chromosome": "1", "tagVariantId": "snp1", "statistics": { - "left_logBF": 10.3, + "left_logBF": 1.2, "right_logBF": 10.5, - "left_beta": 0.1, + "left_beta": 0.001, "right_beta": 0.2, + "left_posteriorProbability": 0.001, + "right_posteriorProbability": 0.92, }, }, { @@ -82,23 +86,177 @@ def test_coloc(mock_study_locus_overlap: StudyLocusOverlap) -> None: "tagVariantId": "snp2", "statistics": { "left_logBF": 10.3, - "right_logBF": 10.5, + "right_logBF": 3.8, "left_beta": 0.3, - "right_beta": 0.5, + "right_beta": 0.005, + "left_posteriorProbability": 0.91, + 
"right_posteriorProbability": 0.01, + }, + }, + ], + # expected coloc + [], + ), + # Case of an overlap with significant PP overlap: + ( + # observed overlap + [ + { + "leftStudyLocusId": "1", + "rightStudyLocusId": "2", + "rightStudyType": "eqtl", + "chromosome": "1", + "tagVariantId": "snp1", + "statistics": { + "left_logBF": 10.2, + "right_logBF": 10.5, + "left_beta": 0.5, + "right_beta": 0.2, + "left_posteriorProbability": 0.91, + "right_posteriorProbability": 0.92, + }, + }, + { + "leftStudyLocusId": "1", + "rightStudyLocusId": "2", + "rightStudyType": "eqtl", + "chromosome": "1", + "tagVariantId": "snp2", + "statistics": { + "left_logBF": 1.2, + "right_logBF": 3.8, + "left_beta": 0.003, + "right_beta": 0.005, + "left_posteriorProbability": 0.001, + "right_posteriorProbability": 0.01, }, }, ], # expected coloc [ { - "h0": 4.6230151407950416e-5, - "h1": 2.749086942648107e-4, - "h2": 3.357742374172504e-4, - "h3": 9.983447421747411e-4, - "h4": 0.9983447421747356, + "h0": 1.02277006860577e-4, + "h1": 2.7519169183135977e-4, + "h2": 3.718812819512325e-4, + "h3": 1.3533048074295033e-6, + "h4": 0.9992492967145488, }, ], ), + # Case where the overlap source is ["left", "both", "both"]: + ( + # observed overlap + [ + { + "leftStudyLocusId": "1", + "rightStudyLocusId": "2", + "rightStudyType": "eqtl", + "chromosome": "1", + "tagVariantId": "snp1", + "statistics": { + "left_logBF": 1.2, + "right_logBF": None, + "left_beta": 0.003, + "right_beta": None, + "left_posteriorProbability": 0.001, + "right_posteriorProbability": 0.01, + }, + }, + { + "leftStudyLocusId": "1", + "rightStudyLocusId": "2", + "rightStudyType": "eqtl", + "chromosome": "1", + "tagVariantId": "snp2", + "statistics": { + "left_logBF": 1.2, + "right_logBF": 3.8, + "left_beta": 0.003, + "right_beta": 0.005, + "left_posteriorProbability": 0.001, + "right_posteriorProbability": 0.01, + }, + }, + { + "leftStudyLocusId": "1", + "rightStudyLocusId": "2", + "rightStudyType": "eqtl", + "chromosome": "1", + 
"tagVariantId": "snp3", + "statistics": { + "left_logBF": 10.2, + "right_logBF": 10.5, + "left_beta": 0.5, + "right_beta": 0.2, + "left_posteriorProbability": 0.91, + "right_posteriorProbability": 0.92, + }, + }, + ], + # expected coloc + [ + { + "h0": 1.02277006860577e-4, + "h1": 2.752255943423052e-4, + "h2": 3.718914358059273e-4, + "h3": 1.5042926116520848e-6, + "h4": 0.9992491016906891, + }, + ], + ), + # Case where PPs are high on the left, but low on the right: + ( + # observed overlap + [ + { + "leftStudyLocusId": "1", + "rightStudyLocusId": "2", + "rightStudyType": "eqtl", + "chromosome": "1", + "tagVariantId": "snp1", + "statistics": { + "left_logBF": 1.2, + "right_logBF": None, + "left_beta": 0.003, + "right_beta": None, + "left_posteriorProbability": 0.001, + "right_posteriorProbability": 0.01, + }, + }, + { + "leftStudyLocusId": "1", + "rightStudyLocusId": "2", + "rightStudyType": "eqtl", + "chromosome": "1", + "tagVariantId": "snp2", + "statistics": { + "left_logBF": 1.2, + "right_logBF": 3.8, + "left_beta": 0.003, + "right_beta": 0.005, + "left_posteriorProbability": 0.001, + "right_posteriorProbability": 0.01, + }, + }, + { + "leftStudyLocusId": "1", + "rightStudyLocusId": "2", + "rightStudyType": "eqtl", + "chromosome": "1", + "tagVariantId": "snp3", + "statistics": { + "left_logBF": 10.2, + "right_logBF": 10.5, + "left_beta": 0.5, + "right_beta": 0.2, + "left_posteriorProbability": 0.36, + "right_posteriorProbability": 0.92, + }, + }, + ], + # expected coloc + [], + ), ], ) def test_coloc_semantic( @@ -111,24 +269,45 @@ def test_coloc_semantic( _df=spark.createDataFrame(observed_data, schema=StudyLocusOverlap.get_schema()), _schema=StudyLocusOverlap.get_schema(), ) - observed_coloc_pdf = ( - Coloc.colocalise(observed_overlap) - .df.select("h0", "h1", "h2", "h3", "h4") - .toPandas() - ) - expected_coloc_pdf = ( - spark.createDataFrame(expected_data) - .select("h0", "h1", "h2", "h3", "h4") - .toPandas() - ) - assert_frame_equal( - observed_coloc_pdf, 
- expected_coloc_pdf, - check_exact=False, - check_dtype=True, + observed_coloc_df = Coloc.colocalise(observed_overlap).df + + # Define schema for the expected DataFrame + result_schema = StructType( + [ + StructField("h0", DoubleType(), True), + StructField("h1", DoubleType(), True), + StructField("h2", DoubleType(), True), + StructField("h3", DoubleType(), True), + StructField("h4", DoubleType(), True), + ] ) + if not expected_data: + expected_coloc_df = spark.createDataFrame([], schema=result_schema) + else: + expected_coloc_df = spark.createDataFrame(expected_data, schema=result_schema) + + if observed_coloc_df.rdd.isEmpty(): + observed_coloc_df = spark.createDataFrame([], schema=result_schema) + + observed_coloc_df = observed_coloc_df.select("h0", "h1", "h2", "h3", "h4") + + observed_coloc_pdf = observed_coloc_df.toPandas() + expected_coloc_pdf = expected_coloc_df.toPandas() + + if expected_coloc_pdf.empty: + assert ( + observed_coloc_pdf.empty + ), f"Expected an empty DataFrame, but got:\n{observed_coloc_pdf}" + else: + assert_frame_equal( + observed_coloc_pdf, + expected_coloc_pdf, + check_exact=False, + check_dtype=True, + ) + def test_coloc_no_logbf( spark: SparkSession, @@ -151,8 +330,8 @@ def test_coloc_no_logbf( "right_logBF": None, "left_beta": 0.1, "right_beta": 0.2, - "left_posteriorProbability": None, - "right_posteriorProbability": None, + "left_posteriorProbability": 0.91, + "right_posteriorProbability": 0.92, }, # irrelevant for COLOC } ], @@ -212,8 +391,8 @@ def test_coloc_no_betas(spark: SparkSession) -> None: "right_logBF": 10.3, "left_beta": None, "right_beta": None, - "left_posteriorProbability": None, - "right_posteriorProbability": None, + "left_posteriorProbability": 0.91, + "right_posteriorProbability": 0.92, }, # irrelevant for COLOC } ],