From 7d593cec173049b1ca96f91e223d1086d07d71de Mon Sep 17 00:00:00 2001 From: ronanstokes-db Date: Wed, 10 Apr 2024 14:25:55 -0700 Subject: [PATCH] wip --- dbldatagen/constraints/positive_values.py | 4 ++-- tests/test_constraints.py | 26 +++++++++++++++++------ 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/dbldatagen/constraints/positive_values.py b/dbldatagen/constraints/positive_values.py index b34ef9f7..cffc56ee 100644 --- a/dbldatagen/constraints/positive_values.py +++ b/dbldatagen/constraints/positive_values.py @@ -31,8 +31,8 @@ def _generate_filter_expression(self): """ Generate a filter expression that may be used for filtering""" expressions = [F.col(colname) for colname in self._columns] if self._strict: - filters = [col.isNotNull() & col > 0 for col in expressions] + filters = [col.isNotNull() & (col > 0) for col in expressions] else: - filters = [col.isNotNull() & col >= 0 for col in expressions] + filters = [col.isNotNull() & (col >= 0) for col in expressions] return self.combineConstraintExpressions(filters) \ No newline at end of file diff --git a/tests/test_constraints.py b/tests/test_constraints.py index 001b9109..6f82d41e 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -88,11 +88,18 @@ def test_scalar_relation(self, generationSpec1, column, operation, literalValue, rowCount = testDataDF.count() assert rowCount == expectedRows - def testNegativeValues(self, generationSpec1): + @pytest.mark.parametrize("columns, strictFlag, expectedRows", + [ + ("positive_and_negative", True, 99), + ("positive_and_negative", False, 100), + ("positive_and_negative", "skip", 100), + ]) + def testNegativeValues(self, generationSpec1, columns, strictFlag, expectedRows): testDataSpec = (generationSpec1 .withConstraints([SqlExpr("id < 100"), SqlExpr("id > 0")]) - .withConstraint(NegativeValues("positive_and_negative")) + .withConstraint(NegativeValues(columns, strict=strictFlag) if strictFlag != "skip" + else NegativeValues(columns)) ) testDataDF = testDataSpec.build() @@ -100,17 +107,24 @@ def testNegativeValues(self, generationSpec1): rowCount = testDataDF.count() assert rowCount == 99 - def testPositiveValues(self, generationSpec1): + @pytest.mark.parametrize("columns, strictFlag, expectedRows", + [ + ("positive_and_negative", True, 99), + ("positive_and_negative", False, 100), + ("positive_and_negative", "skip", 100), + ]) + def testPositiveValues(self, generationSpec1, columns, strictFlag, expectedRows): testDataSpec = (generationSpec1 - .withConstraints([SqlExpr("id < 100"), + .withConstraints([SqlExpr("id < 200"), SqlExpr("id > 0")]) - .withConstraint(PositiveValues("positive_and_negative")) + .withConstraint(PositiveValues(columns, strict=strictFlag) if strictFlag != "skip" + else PositiveValues(columns)) ) testDataDF = testDataSpec.build() rowCount = testDataDF.count() - assert rowCount == 99 + assert rowCount == expectedRows def test_scalar_relation_bad(self, generationSpec1): with pytest.raises(ValueError):