Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
ronanstokes-db committed Apr 10, 2024
1 parent 88dcd39 commit 7d593ce
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
4 changes: 2 additions & 2 deletions dbldatagen/constraints/positive_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def _generate_filter_expression(self):
""" Generate a filter expression that may be used for filtering"""
expressions = [F.col(colname) for colname in self._columns]
if self._strict:
filters = [col.isNotNull() & col > 0 for col in expressions]
filters = [col.isNotNull() & (col > 0) for col in expressions]
else:
filters = [col.isNotNull() & col >= 0 for col in expressions]
filters = [col.isNotNull() & (col >= 0) for col in expressions]

return self.combineConstraintExpressions(filters)
26 changes: 20 additions & 6 deletions tests/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,29 +88,43 @@ def test_scalar_relation(self, generationSpec1, column, operation, literalValue,
rowCount = testDataDF.count()
assert rowCount == expectedRows

def testNegativeValues(self, generationSpec1):
@pytest.mark.parametrize("columns, strictFlag, expectedRows",
[
("positive_and_negative", True, 99),
("positive_and_negative", False, 100),
("positive_and_negative", "skip", 100),
])
def testNegativeValues(self, generationSpec1, columns, strictFlag, expectedRows):
testDataSpec = (generationSpec1
.withConstraints([SqlExpr("id < 100"),
SqlExpr("id > 0")])
.withConstraint(NegativeValues("positive_and_negative"))
.withConstraint(NegativeValues(columns, strict=strictFlag) if strictFlag != "skip"
else NegativeValues(columns))
)

testDataDF = testDataSpec.build()

rowCount = testDataDF.count()
assert rowCount == 99

def testPositiveValues(self, generationSpec1):
@pytest.mark.parametrize("columns, strictFlag, expectedRows",
[
("positive_and_negative", True, 99),
("positive_and_negative", False, 100),
("positive_and_negative", "skip", 100),
])
def testPositiveValues(self, generationSpec1, columns, strictFlag, expectedRows):
testDataSpec = (generationSpec1
.withConstraints([SqlExpr("id < 100"),
.withConstraints([SqlExpr("id < 200"),
SqlExpr("id > 0")])
.withConstraint(PositiveValues("positive_and_negative"))
.withConstraint(PositiveValues(columns, strict=strictFlag) if strictFlag != "skip"
else PositiveValues(columns))
)

testDataDF = testDataSpec.build()

rowCount = testDataDF.count()
assert rowCount == 99
assert rowCount == expectedRows

def test_scalar_relation_bad(self, generationSpec1):
with pytest.raises(ValueError):
Expand Down

0 comments on commit 7d593ce

Please sign in to comment.