Skip to content

Commit

Permalink
Add row count check validator
Browse files Browse the repository at this point in the history
Signed-off-by: popcorny <[email protected]>
  • Loading branch information
popcornylu committed Dec 4, 2024
1 parent 88399d8 commit 4e37a0d
Show file tree
Hide file tree
Showing 9 changed files with 125 additions and 35 deletions.
2 changes: 2 additions & 0 deletions recce/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ def server(host, port, state_file=None, **kwargs):
from .server import app, AppState
from rich.console import Console

RecceConfig(config_file=kwargs.get('config'))

handle_debug_flag(**kwargs)
is_review = kwargs.get('review', False)
is_cloud = kwargs.get('cloud', False)
Expand Down
8 changes: 7 additions & 1 deletion recce/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,17 @@ def _verify_preset_checks(self):
raise ValueError(f'Check type is required for check "{check}"')
if check_type == 'linage_diff' or check_type == 'schema_diff':
validator = CheckValidator()
elif check_type == 'row_count_diff':
from recce.tasks.rowcount import RowCountDiffCheckValidator
validator = RowCountDiffCheckValidator()
else:
validator = CheckValidator()
validator.validate(check)
except Exception as e:
raise RecceConfigException(f'Load preset check failed from "{self.config_file}"\n{check}', cause=e)
import json
raise RecceConfigException(
f"Load preset checks failed from '{self.config_file}'\n{json.dumps(check, indent=2)}",
cause=e)

def generate_template(self):
data = yaml.CommentedMap(
Expand Down
4 changes: 2 additions & 2 deletions recce/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ class Check(BaseModel):
name: str
description: Optional[str] = None
type: RunType
params: Optional[dict] = None
view_options: Optional[dict] = None
params: dict = {}
view_options: dict = {}
check_id: UUID4 = Field(default_factory=uuid.uuid4)
is_checked: bool = False
is_preset: bool = False
Expand Down
9 changes: 5 additions & 4 deletions recce/tasks/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,10 @@ def validate(self, check: dict):
except Exception as e:
raise ValueError(f'Invalid check format. {str(e)}')

self.validate_params(check)
self.validate_check(check)

def validate_params(self, check: Check):
def validate_check(self, check: Check):
"""
Validate the check. This is supposed to be overridden by subclass. Throw ValueError if the check is invalid.
"""
pass
# if check.params is None:
# raise ValueError(f'"params" cannot be empty')
2 changes: 1 addition & 1 deletion recce/tasks/lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@
class LineageDiffCheckValidator(CheckValidator):

@override
def validate_params(self, check: Check):
def validate_check(self, check: Check):
if check.params is None and check.view_options is None:
raise ValueError('"params" or "view_options" must be provided')
53 changes: 33 additions & 20 deletions recce/tasks/rowcount.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,28 @@
from typing import TypedDict, Optional, Union, List
from typing import Optional, Union, List, Literal

from pydantic import BaseModel
from typing_extensions import override

from recce.core import default_context
from recce.models import Check
from recce.tasks import Task
from recce.tasks.core import TaskResultDiffer
from recce.tasks.core import TaskResultDiffer, CheckValidator
from recce.tasks.query import QueryMixin


class RowCountDiffParams(TypedDict, total=False):
node_names: Optional[list[str]]
node_ids: Optional[list[str]]
select: Optional[str]
exclude: Optional[str]
packages: Optional[list[str]]
view_mode: Optional[str]
class RowCountDiffParams(BaseModel):
node_names: Optional[list[str]] = None
node_ids: Optional[list[str]] = None
select: Optional[str] = None
exclude: Optional[str] = None
packages: Optional[list[str]] = None
view_mode: Optional[Literal['all', 'changed_models']] = None


class RowCountDiffTask(Task, QueryMixin):
def __init__(self, params: RowCountDiffParams):
def __init__(self, params: dict):
super().__init__()
self.params = params if params is not None else {}
self.params = RowCountDiffParams(**params) if params is not None else RowCountDiffParams()
self.connection = None

def _query_row_count(self, dbt_adapter, model_name, base=False):
Expand Down Expand Up @@ -47,22 +51,22 @@ def execute_dbt(self):
dbt_adapter = default_context().adapter

query_candidates = []
if self.params.get('node_ids', []) or self.params.get('node_names', []):
for node_id in self.params.get('node_ids', []):
if self.params.node_ids or self.params.node_names:
for node_id in self.params.node_ids or []:
name = dbt_adapter.get_node_name_by_id(node_id)
if name:
query_candidates.append(name)
for node in self.params.get('node_names', []):
for node in self.params.node_names or []:
query_candidates.append(node)
else:
def countable(unique_id):
return unique_id.startswith('model') or unique_id.startswith('snapshot') or unique_id.startswith('seed')

node_ids = dbt_adapter.select_nodes(
select=self.params.get('select', None),
exclude=self.params.get('exclude', None),
packages=self.params.get('packages', None),
view_mode=self.params.get('view_mode', None)
select=self.params.select,
exclude=self.params.exclude,
packages=self.params.packages,
view_mode=self.params.view_mode,
)
node_ids = list(filter(countable, node_ids))
for node_id in node_ids:
Expand Down Expand Up @@ -95,9 +99,9 @@ def execute_sqlmesh(self):

query_candidates = []

for node_id in self.params.get('node_ids', []):
for node_id in self.node_ids or []:
query_candidates.append(node_id)
for node_name in self.params.get('node_names', []):
for node_name in self.params.node_names or []:
query_candidates.append(node_name)

from recce.adapter.sqlmesh_adapter import SqlmeshAdapter
Expand Down Expand Up @@ -175,3 +179,12 @@ def _get_related_node_ids(self) -> Union[List[str], None]:
def _get_changed_nodes(self) -> Union[List[str], None]:
if self.changes:
return self.changes.affected_root_keys.items


class RowCountDiffCheckValidator(CheckValidator):
@override
def validate_check(self, check: Check):
try:
RowCountDiffParams(**check.params)
except Exception as e:
raise ValueError(f'Invalid params: str{e}')
3 changes: 2 additions & 1 deletion tests/adapter/dbt_adapter/dbt_test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,11 @@ def create_model(
curr_csv=None,
depends_on=[],
disabled=False,
unique_id=None,
package_name="recce_test",
):
# unique_id = f"model.{package_name}.{model_name}"
unique_id = model_name
unique_id = unique_id if unique_id else model_name

def _add_model_to_manifest(base, raw_code):
if base:
Expand Down
11 changes: 6 additions & 5 deletions tests/tasks/test_preset_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ def test_default_validator():
"params": {},
})

# Failed because no params
with pytest.raises(ValueError):
CheckValidator().validate({
"type": "row_count_diff",
})
# Failed "name" type
# with pytest.raises(ValueError):
CheckValidator().validate({
"name": 123,
"type": "row_count_diff",
})


def test_query_diff_validator():
Expand Down
68 changes: 67 additions & 1 deletion tests/tasks/test_row_count.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

from recce.tasks import RowCountDiffTask


Expand All @@ -15,7 +17,7 @@ def test_row_count(dbt_test_helper):
2,Bob,25
"""

dbt_test_helper.create_model("customers", csv_data_base, csv_data_curr)
dbt_test_helper.create_model("customers", csv_data_base, csv_data_curr, unique_id='model.customers')
task = RowCountDiffTask(dict(node_names=['customers']))
run_result = task.execute()
assert run_result['customers']['base'] == 2
Expand All @@ -26,6 +28,11 @@ def test_row_count(dbt_test_helper):
assert run_result['customers_']['base'] is None
assert run_result['customers_']['curr'] is None

task = RowCountDiffTask(dict(node_ids=['model.customers']))
run_result = task.execute()
assert run_result['customers']['base'] == 2
assert run_result['customers']['curr'] == 3


def test_row_count_with_selector(dbt_test_helper):
csv_data_1 = """
Expand All @@ -51,3 +58,62 @@ def test_row_count_with_selector(dbt_test_helper):
run_result = task.execute()
assert len(run_result) == 2


def test_validator():
from recce.tasks.rowcount import RowCountDiffCheckValidator

validator = RowCountDiffCheckValidator()

def validate(params: dict):
validator.validate({
'name': 'test',
'type': 'row_count_diff',
'params': params,
})

# Select all modesl
validate({})

# Select by node name
validate({
'node_names': ['abc'],
})
with pytest.raises(ValueError):
validate({
'node_names': [123],
})
with pytest.raises(ValueError):
validate({
'node_names': 'abc',
})

# Select by node id
validate({
'node_ids': ['model.abc'],
})

# Select by selector
validate({
'select': 'customers',
'exclude': 'customers',
'packages': ['jaffle_shop'],
'view_mode': 'all',
})

# packages should be an array
with pytest.raises(ValueError):
validate({
'packages': 'jaffle_shop',
})

# view_mode should be 'all' or 'changed_models'
validate({
'view_mode': None,
})
validate({
'view_mode': 'all',
})
with pytest.raises(ValueError):
validate({
'view_mode': 'abc',
})

0 comments on commit 4e37a0d

Please sign in to comment.