fix: validation name mapping (opentargets#753)
* fix: use mapping instead of enum values in valid_rows

* fix: typos

* fix: swap valid and invalid paths

---------

Co-authored-by: Szymon Szyszkowski <[email protected]>
project-defiant and Szymon Szyszkowski authored Sep 10, 2024
1 parent 151b4ec commit 0b216f6
Showing 5 changed files with 41 additions and 34 deletions.
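At the interface level, the change replaces the flat list of QC category values with a name-to-value mapping, so valid_rows can be called with stable enum member names instead of the human-readable strings. A minimal, illustrative sketch of that shape (the stand-in enum below carries only the member exercised in the updated tests; the real StudyQualityCheck has more entries):

from enum import Enum

# Stand-in for StudyQualityCheck, reduced to one member purely for illustration.
class StudyQualityCheck(Enum):
    DUPLICATED_STUDY = "The identifier of this study is not unique."

# Old interface: get_QC_categories() exposed only the category values.
categories = [member.value for member in StudyQualityCheck]

# New interface: get_QC_mappings() maps flag name -> category value.
mappings = {member.name: member.value for member in StudyQualityCheck}

print(categories)  # ['The identifier of this study is not unique.']
print(mappings)    # {'DUPLICATED_STUDY': 'The identifier of this study is not unique.'}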
src/gentropy/dataset/dataset.py (27 changes: 19 additions & 8 deletions)
@@ -86,13 +86,15 @@ def get_QC_column_name(cls: type[Self]) -> str | None:
         return None

     @classmethod
-    def get_QC_categories(cls: type[Self]) -> list[str]:
-        """Method to get the QC categories for this dataset. Returns empty list unless overriden by child classes.
+    def get_QC_mappings(cls: type[Self]) -> dict[str, str]:
+        """Method to get the mapping between QC flag and corresponding QC category value.
+
+        Returns empty dict unless overriden by child classes.

         Returns:
-            list[str]: Column name
+            dict[str, str]: Mapping between flag name and QC column category value.
         """
-        return []
+        return {}

     @classmethod
     def from_parquet(
@@ -193,22 +195,31 @@ def validate_schema(self: Dataset) -> None:
     def valid_rows(self: Self, invalid_flags: list[str], invalid: bool = False) -> Self:
         """Filters `Dataset` according to a list of quality control flags. Only `Dataset` classes with a QC column can be validated.

+        This method checks do following steps:
+        - Check if the Dataset contains a QC column.
+        - Check if the invalid_flags exist in the QC mappings flags.
+        - Filter the Dataset according to the invalid_flags and invalid parameters.
+
         Args:
             invalid_flags (list[str]): List of quality control flags to be excluded.
-            invalid (bool): If True returns the invalid rows, instead of the valids. Defaults to False.
+            invalid (bool): If True returns the invalid rows, instead of the valid. Defaults to False.

         Returns:
             Self: filtered dataset.

         Raises:
             ValueError: If the Dataset does not contain a QC column.
+            ValueError: If the invalid_flags elements do not exist in QC mappings flags.
         """
         # If the invalid flags are not valid quality checks (enum) for this Dataset we raise an error:
+        invalid_reasons = []
         for flag in invalid_flags:
-            if flag not in self.get_QC_categories():
+            if flag not in self.get_QC_mappings():
                 raise ValueError(
-                    f"{flag} is not a valid QC flag for {type(self).__name__} ({self.get_QC_categories()})."
+                    f"{flag} is not a valid QC flag for {type(self).__name__} ({self.get_QC_mappings()})."
                 )
+            reason = self.get_QC_mappings()[flag]
+            invalid_reasons.append(reason)

         qc_column_name = self.get_QC_column_name()
         # If Dataset (class) does not contain QC column we raise an error:
@@ -222,7 +233,7 @@ def valid_rows(self: Self, invalid_flags: list[str], invalid: bool = False) -> Self:
         qc = f.when(f.col(column).isNull(), f.array()).otherwise(f.col(column))

         filterCondition = ~f.arrays_overlap(
-            f.array([f.lit(i) for i in invalid_flags]), qc
+            f.array([f.lit(i) for i in invalid_reasons]), qc
         )
         # Returning the filtered dataset:
         if invalid:
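A hedged usage sketch of the updated valid_rows contract, assuming a StudyIndex instance named study_index is already loaded (it is not constructed here): callers now pass flag names, which valid_rows resolves to category values through get_QC_mappings() before applying the arrays_overlap filter shown above.

# study_index is assumed to be an existing StudyIndex instance.
valid = study_index.valid_rows(["DUPLICATED_STUDY"])                  # rows not carrying the flag
flagged = study_index.valid_rows(["DUPLICATED_STUDY"], invalid=True)  # rows carrying the flag

# A name that is not a key of get_QC_mappings() raises ValueError,
# as exercised by the updated test file below.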
src/gentropy/dataset/study_index.py (8 changes: 4 additions & 4 deletions)
@@ -118,13 +118,13 @@ def get_QC_column_name(cls: type[StudyIndex]) -> str:
         return "qualityControls"

     @classmethod
-    def get_QC_categories(cls: type[StudyIndex]) -> list[str]:
-        """Return the quality control categories.
+    def get_QC_mappings(cls: type[StudyIndex]) -> dict[str, str]:
+        """Quality control flag to QC column category mappings.

         Returns:
-            list[str]: The quality control categories.
+            dict[str, str]: Mapping between flag name and QC column category value.
         """
-        return [member.value for member in StudyQualityCheck]
+        return {member.name: member.value for member in StudyQualityCheck}

     @classmethod
     def aggregate_and_map_ancestries(
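Note that the qualityControls column itself still stores the category value strings; the mapping only lets callers refer to flags by enum member name. A small sketch of what the StudyIndex mapping looks like (only the entry confirmed by the updated tests is shown; the remaining StudyQualityCheck members are omitted):

mapping = StudyIndex.get_QC_mappings()
# e.g. mapping["DUPLICATED_STUDY"] == "The identifier of this study is not unique."
# Other StudyQualityCheck members map analogously (flag name -> stored category string).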
src/gentropy/dataset/study_locus.py (8 changes: 4 additions & 4 deletions)
@@ -421,13 +421,13 @@ def get_QC_column_name(cls: type[StudyLocus]) -> str:
         return "qualityControls"

     @classmethod
-    def get_QC_categories(cls: type[StudyLocus]) -> list[str]:
-        """Quality control categories.
+    def get_QC_mappings(cls: type[StudyLocus]) -> dict[str, str]:
+        """Quality control flag to QC column category mappings.

         Returns:
-            list[str]: List of quality control categories.
+            dict[str, str]: Mapping between flag name and QC column category value.
         """
-        return [member.value for member in StudyLocusQualityCheck]
+        return {member.name: member.value for member in StudyLocusQualityCheck}

     def filter_by_study_type(
         self: StudyLocus, study_type: str, study_index: StudyIndex
src/gentropy/study_locus_validation.py (10 changes: 5 additions & 5 deletions)
@@ -48,10 +48,10 @@ def __init__(
             .validate_unique_study_locus_id()  # Flagging duplicated study locus ids
         ).persist()  # we will need this for 2 types of outputs

-        study_locus_with_qc.valid_rows(invalid_qc_reasons).df.write.parquet(
-            invalid_study_locus_path
-        )
-
         study_locus_with_qc.valid_rows(
             invalid_qc_reasons, invalid=True
-        ).df.write.parquet(valid_study_locus_path)
+        ).df.write.parquet(invalid_study_locus_path)
+
+        study_locus_with_qc.valid_rows(invalid_qc_reasons).df.write.parquet(
+            valid_study_locus_path
+        )
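The swap above corrects the output routing of the validation step: rows that pass QC are now written to valid_study_locus_path and rows carrying any of the requested flags go to invalid_study_locus_path. Restated as a two-line sketch with the same objects as in the diff (no new code):

study_locus_with_qc.valid_rows(invalid_qc_reasons).df.write.parquet(valid_study_locus_path)                  # rows passing QC
study_locus_with_qc.valid_rows(invalid_qc_reasons, invalid=True).df.write.parquet(invalid_study_locus_path)  # flagged rows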
tests/gentropy/dataset/test_dataset_exclusion.py (22 changes: 9 additions & 13 deletions)
@@ -16,9 +16,9 @@ class TestDataExclusion:
     the right rows are excluded.
     """

-    CORRECT_FILTER = ["The identifier of this study is not unique."]
-    INCORRECT_FILTER = ["Some mock flag."]
-    ALL_FILTERS = [member.value for member in StudyQualityCheck]
+    CORRECT_FLAG = ["DUPLICATED_STUDY"]
+    INCORRECT_FLAG = ["UNKNOWN_CATEGORY"]
+    ALL_FLAGS = [member.name for member in StudyQualityCheck]

     DATASET = [
         # Good study no flag:
@@ -52,8 +52,8 @@ def _setup(self: TestDataExclusion, spark: SparkSession) -> None:
     @pytest.mark.parametrize(
         "filter_, expected",
         [
-            (CORRECT_FILTER, ["S1", "S2"]),
-            (ALL_FILTERS, ["S1"]),
+            (CORRECT_FLAG, ["S1", "S2"]),
+            (ALL_FLAGS, ["S1"]),
         ],
     )
     def test_valid_rows(
@@ -72,8 +72,8 @@ def test_valid_rows(
     @pytest.mark.parametrize(
         "filter_, expected",
         [
-            (CORRECT_FILTER, ["S3"]),
-            (ALL_FILTERS, ["S2", "S3"]),
+            (CORRECT_FLAG, ["S3"]),
+            (ALL_FLAGS, ["S2", "S3"]),
         ],
     )
     def test_invalid_rows(
@@ -90,11 +90,7 @@ def test_invalid_rows(
     def test_failing_quality_flag(self: TestDataExclusion) -> None:
         """Test invalid quality flag."""
         with pytest.raises(ValueError):
-            self.study_index.valid_rows(
-                self.INCORRECT_FILTER, invalid=True
-            ).df.collect()
+            self.study_index.valid_rows(self.INCORRECT_FLAG, invalid=True).df.collect()

         with pytest.raises(ValueError):
-            self.study_index.valid_rows(
-                self.INCORRECT_FILTER, invalid=False
-            ).df.collect()
+            self.study_index.valid_rows(self.INCORRECT_FLAG, invalid=False).df.collect()
