Skip to content

Commit

Permalink
feat: decouple feature generation from L2G training step (opentargets…
Browse files Browse the repository at this point in the history
…#823)

* fix: join between gold standard and credible set based on studyId and variantId

* fix: minor bugs to generate the feature matrix

* fix(colocalisation): safeguard existing rightStudyType when applying `append_study_metadata`

* feat: add feature generation step (working interactively)

* feat(l2g): remove feature generation from `LocusToGeneStep`

* fix: correct feature names

* feat: filter gwas credible sets in `L2GPrediction.from_credible_set`

* chore: update docs

* chore: pass credible set to `L2GGoldStandard.build_feature_matrix` in test

* chore: uncomment code
  • Loading branch information
ireneisdoomed authored Oct 9, 2024
1 parent 60f6bfa commit b7dce8f
Show file tree
Hide file tree
Showing 8 changed files with 195 additions and 118 deletions.
4 changes: 3 additions & 1 deletion docs/python_api/steps/l2g.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
---
title: locus_to_gene
title: Locus to Gene (L2G)
---

::: gentropy.l2g.LocusToGeneFeatureMatrixStep

::: gentropy.l2g.LocusToGeneStep
60 changes: 56 additions & 4 deletions src/gentropy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,6 @@ class LocusToGeneConfig(StepConfig):
predictions_path: str = MISSING
credible_set_path: str = MISSING
variant_index_path: str = MISSING
colocalisation_path: str = MISSING
study_index_path: str = MISSING
gene_index_path: str = MISSING
model_path: str | None = None
feature_matrix_path: str | None = None
gold_standard_curation_path: str | None = None
Expand Down Expand Up @@ -282,10 +279,60 @@ class LocusToGeneConfig(StepConfig):
wandb_run_name: str | None = None
hf_hub_repo_id: str | None = "opentargets/locus_to_gene"
download_from_hub: bool = True
write_feature_matrix: bool = True
_target_: str = "gentropy.l2g.LocusToGeneStep"


@dataclass
class LocusToGeneFeatureMatrixConfig(StepConfig):
    """Locus to gene feature matrix step configuration."""

    # Spark session overrides — presumably sized for the shuffle-heavy feature
    # joins (large driver/executor memory, 800 shuffle partitions); TODO confirm
    # against the step's actual workload.
    session: Any = field(
        default_factory=lambda: {
            "extended_spark_conf": {
                "spark.driver.memory": "48g",
                "spark.executor.memory": "48g",
                "spark.sql.shuffle.partitions": "800",
            }
        }
    )
    # Input datasets: only the credible set is mandatory; the remaining inputs
    # are optional and only needed by the features that depend on them.
    credible_set_path: str = MISSING
    variant_index_path: str | None = None
    colocalisation_path: str | None = None
    study_index_path: str | None = None
    gene_index_path: str | None = None
    # Output location where the computed feature matrix is written.
    feature_matrix_path: str = MISSING
    # Names of the features to compute; these must match the feature names
    # registered by the L2G feature classes.
    features_list: list[str] = field(
        default_factory=lambda: [
            # max CLPP for each (study, locus, gene) aggregating over a specific qtl type
            "eQtlColocClppMaximum",
            "pQtlColocClppMaximum",
            "sQtlColocClppMaximum",
            "tuQtlColocClppMaximum",
            # max H4 for each (study, locus, gene) aggregating over a specific qtl type
            "eQtlColocH4Maximum",
            "pQtlColocH4Maximum",
            "sQtlColocH4Maximum",
            "tuQtlColocH4Maximum",
            # distance to gene footprint
            "distanceSentinelFootprint",
            "distanceSentinelFootprintNeighbourhood",
            "distanceFootprintMean",
            "distanceFootprintMeanNeighbourhood",
            # distance to gene tss
            "distanceTssMean",
            "distanceTssMeanNeighbourhood",
            "distanceSentinelTss",
            "distanceSentinelTssNeighbourhood",
            # vep
            "vepMaximum",
            "vepMaximumNeighbourhood",
            "vepMean",
            "vepMeanNeighbourhood",
        ]
    )
    # Hydra instantiation target: the step class this config drives.
    _target_: str = "gentropy.l2g.LocusToGeneFeatureMatrixStep"


@dataclass
class PICSConfig(StepConfig):
"""PICS step configuration."""
Expand Down Expand Up @@ -597,6 +644,11 @@ def register_config() -> None:
cs.store(group="step", name="ld_based_clumping", node=LDBasedClumpingConfig)
cs.store(group="step", name="ld_index", node=LDIndexConfig)
cs.store(group="step", name="locus_to_gene", node=LocusToGeneConfig)
cs.store(
group="step",
name="locus_to_gene_feature_matrix",
node=LocusToGeneFeatureMatrixConfig,
)
cs.store(group="step", name="finngen_studies", node=FinngenStudiesConfig)

cs.store(
Expand Down
14 changes: 11 additions & 3 deletions src/gentropy/dataset/colocalisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def extract_maximum_coloc_probability_per_region_and_gene(
"""
from gentropy.colocalisation import ColocalisationStep

valid_qtls = list(EqtlCatalogueStudyIndex.method_to_study_type_mapping.values())
valid_qtls = list(
set(EqtlCatalogueStudyIndex.method_to_study_type_mapping.values())
)
if filter_by_qtl and filter_by_qtl not in valid_qtls:
raise ValueError(f"There are no studies with QTL type {filter_by_qtl}")

Expand Down Expand Up @@ -91,7 +93,7 @@ def extract_maximum_coloc_probability_per_region_and_gene(
self.append_study_metadata(
study_locus,
study_index,
metadata_cols=["geneId"],
metadata_cols=["geneId", "studyType"],
colocalisation_side="right",
)
# it also filters based on method and qtl type
Expand Down Expand Up @@ -147,6 +149,12 @@ def append_study_metadata(
)
.distinct()
)
coloc_df = (
# drop `rightStudyType` in case it is requested
self.df.drop("rightStudyType")
if "studyType" in metadata_cols and colocalisation_side == "right"
else self.df
)
return (
# Append that to the respective side of the colocalisation dataset
study_loci_w_metadata.selectExpr(
Expand All @@ -155,5 +163,5 @@ def append_study_metadata(
f"{col} as {colocalisation_side}{col[0].upper() + col[1:]}"
for col in metadata_cols
],
).join(self.df, f"{colocalisation_side}StudyLocusId", "right")
).join(coloc_df, f"{colocalisation_side}StudyLocusId", "right")
)
4 changes: 2 additions & 2 deletions src/gentropy/dataset/l2g_features/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ class DistanceSentinelFootprintFeature(L2GFeature):

fill_na_value = 500_000
feature_dependency_type = VariantIndex
feature_name = "distanceSentinelFootprintMinimum"
feature_name = "distanceSentinelFootprint"

@classmethod
def compute(
Expand Down Expand Up @@ -388,7 +388,7 @@ class DistanceSentinelFootprintNeighbourhoodFeature(L2GFeature):

fill_na_value = 500_000
feature_dependency_type = VariantIndex
feature_name = "DistanceSentinelFootprintNeighbourhoodFeature"
feature_name = "distanceSentinelFootprintNeighbourhood"

@classmethod
def compute(
Expand Down
15 changes: 12 additions & 3 deletions src/gentropy/dataset/l2g_gold_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from gentropy.common.schemas import parse_spark_schema
from gentropy.common.spark_helpers import get_record_with_maximum_value
from gentropy.dataset.dataset import Dataset
from gentropy.dataset.study_locus import StudyLocus

if TYPE_CHECKING:
from pyspark.sql import DataFrame
Expand Down Expand Up @@ -107,11 +108,13 @@ def process_gene_interactions(
def build_feature_matrix(
self: L2GGoldStandard,
full_feature_matrix: L2GFeatureMatrix,
credible_set: StudyLocus,
) -> L2GFeatureMatrix:
"""Return a feature matrix for study loci in the gold standard.
Args:
full_feature_matrix (L2GFeatureMatrix): Feature matrix for all study loci to join on
credible_set (StudyLocus): Full credible sets to annotate the feature matrix with variant and study IDs and perform the join
Returns:
L2GFeatureMatrix: Feature matrix for study loci in the gold standard
Expand All @@ -120,10 +123,16 @@ def build_feature_matrix(

return L2GFeatureMatrix(
_df=full_feature_matrix._df.join(
f.broadcast(self.df.drop("variantId", "studyId", "sources")),
on=["studyLocusId", "geneId"],
credible_set.df.select("studyLocusId", "variantId", "studyId"),
"studyLocusId",
"left",
)
.join(
f.broadcast(self.df.drop("studyLocusId", "sources")),
on=["studyId", "variantId", "geneId"],
how="inner",
),
)
.distinct(),
with_gold_standard=True,
)

Expand Down
17 changes: 10 additions & 7 deletions src/gentropy/dataset/l2g_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Type

import pyspark.sql.functions as f

from gentropy.common.schemas import parse_spark_schema
from gentropy.common.session import Session
from gentropy.dataset.dataset import Dataset
from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix
from gentropy.dataset.study_locus import StudyLocus
from gentropy.method.l2g.feature_factory import L2GFeatureInputLoader
from gentropy.method.l2g.model import LocusToGeneModel

if TYPE_CHECKING:
Expand Down Expand Up @@ -40,8 +41,8 @@ def from_credible_set(
cls: Type[L2GPrediction],
session: Session,
credible_set: StudyLocus,
feature_matrix: L2GFeatureMatrix,
features_list: list[str],
features_input_loader: L2GFeatureInputLoader,
model_path: str | None,
hf_token: str | None = None,
download_from_hub: bool = True,
Expand All @@ -51,8 +52,8 @@ def from_credible_set(
Args:
session (Session): Session object that contains the Spark session
credible_set (StudyLocus): Dataset containing credible sets from GWAS only
feature_matrix (L2GFeatureMatrix): Dataset containing all credible sets and their annotations
features_list (list[str]): List of features to use for the model
features_input_loader (L2GFeatureInputLoader): Loader with all feature dependencies
model_path (str | None): Path to the model file. It can be either in the filesystem or the name on the Hugging Face Hub (in the form of username/repo_name).
hf_token (str | None): Hugging Face token to download the model from the Hub. Only required if the model is private.
download_from_hub (bool): Whether to download the model from the Hugging Face Hub. Defaults to True.
Expand All @@ -70,10 +71,12 @@ def from_credible_set(

# Prepare data
fm = (
L2GFeatureMatrix.from_features_list(
study_loci_to_annotate=credible_set,
features_list=features_list,
features_input_loader=features_input_loader,
L2GFeatureMatrix(
_df=(
credible_set.df.filter(f.col("studyType") == "gwas")
.select("studyLocusId")
.join(feature_matrix._df, "studyLocusId")
)
)
.fill_na()
.select_features(features_list)
Expand Down
Loading

0 comments on commit b7dce8f

Please sign in to comment.