From 4e37a0d95b639948fb252a12169a6c0e7d87c5cb Mon Sep 17 00:00:00 2001 From: popcorny Date: Wed, 4 Dec 2024 16:16:11 +0800 Subject: [PATCH] Add row count check validator Signed-off-by: popcorny --- recce/cli.py | 2 + recce/config.py | 8 ++- recce/models/types.py | 4 +- recce/tasks/core.py | 9 +-- recce/tasks/lineage.py | 2 +- recce/tasks/rowcount.py | 53 +++++++++------ tests/adapter/dbt_adapter/dbt_test_helper.py | 3 +- tests/tasks/test_preset_checks.py | 11 ++-- tests/tasks/test_row_count.py | 68 +++++++++++++++++++- 9 files changed, 125 insertions(+), 35 deletions(-) diff --git a/recce/cli.py b/recce/cli.py index 5a0e4c12..d3712c10 100644 --- a/recce/cli.py +++ b/recce/cli.py @@ -225,6 +225,8 @@ def server(host, port, state_file=None, **kwargs): from .server import app, AppState from rich.console import Console + RecceConfig(config_file=kwargs.get('config')) + handle_debug_flag(**kwargs) is_review = kwargs.get('review', False) is_cloud = kwargs.get('cloud', False) diff --git a/recce/config.py b/recce/config.py index 41da5c51..8f061a2a 100644 --- a/recce/config.py +++ b/recce/config.py @@ -42,11 +42,17 @@ def _verify_preset_checks(self): raise ValueError(f'Check type is required for check "{check}"') if check_type == 'linage_diff' or check_type == 'schema_diff': validator = CheckValidator() + elif check_type == 'row_count_diff': + from recce.tasks.rowcount import RowCountDiffCheckValidator + validator = RowCountDiffCheckValidator() else: validator = CheckValidator() validator.validate(check) except Exception as e: - raise RecceConfigException(f'Load preset check failed from "{self.config_file}"\n{check}', cause=e) + import json + raise RecceConfigException( + f"Load preset checks failed from '{self.config_file}'\n{json.dumps(check, indent=2)}", + cause=e) def generate_template(self): data = yaml.CommentedMap( diff --git a/recce/models/types.py b/recce/models/types.py index f27bbb52..707319ea 100644 --- a/recce/models/types.py +++ b/recce/models/types.py @@ -55,8 +55,8 @@ class Check(BaseModel): name: str description: Optional[str] = None type: RunType - params: Optional[dict] = None - view_options: Optional[dict] = None + params: dict = {} + view_options: dict = {} check_id: UUID4 = Field(default_factory=uuid.uuid4) is_checked: bool = False is_preset: bool = False diff --git a/recce/tasks/core.py b/recce/tasks/core.py index d0ed1d0f..f8903abb 100644 --- a/recce/tasks/core.py +++ b/recce/tasks/core.py @@ -127,9 +127,10 @@ def validate(self, check: dict): except Exception as e: raise ValueError(f'Invalid check format. {str(e)}') - self.validate_params(check) + self.validate_check(check) - def validate_params(self, check: Check): + def validate_check(self, check: Check): + """ + Validate the check. This is supposed to be overridden by subclass. Throw ValueError if the check is invalid. + """ pass - # if check.params is None: - # raise ValueError(f'"params" cannot be empty') diff --git a/recce/tasks/lineage.py b/recce/tasks/lineage.py index 48d85029..2a26c7f4 100644 --- a/recce/tasks/lineage.py +++ b/recce/tasks/lineage.py @@ -7,6 +7,6 @@ class LineageDiffCheckValidator(CheckValidator): @override - def validate_params(self, check: Check): + def validate_check(self, check: Check): if check.params is None and check.view_options is None: raise ValueError('"params" or "view_options" must be provided') diff --git a/recce/tasks/rowcount.py b/recce/tasks/rowcount.py index cd74fade..c6e26b1e 100644 --- a/recce/tasks/rowcount.py +++ b/recce/tasks/rowcount.py @@ -1,24 +1,28 @@ -from typing import TypedDict, Optional, Union, List +from typing import Optional, Union, List, Literal + +from pydantic import BaseModel +from typing_extensions import override from recce.core import default_context +from recce.models import Check from recce.tasks import Task -from recce.tasks.core import TaskResultDiffer +from recce.tasks.core import TaskResultDiffer, CheckValidator from recce.tasks.query import QueryMixin -class RowCountDiffParams(TypedDict, total=False): - node_names: Optional[list[str]] - node_ids: Optional[list[str]] - select: Optional[str] - exclude: Optional[str] - packages: Optional[list[str]] - view_mode: Optional[str] +class RowCountDiffParams(BaseModel): + node_names: Optional[list[str]] = None + node_ids: Optional[list[str]] = None + select: Optional[str] = None + exclude: Optional[str] = None + packages: Optional[list[str]] = None + view_mode: Optional[Literal['all', 'changed_models']] = None class RowCountDiffTask(Task, QueryMixin): - def __init__(self, params: RowCountDiffParams): + def __init__(self, params: dict): super().__init__() - self.params = params if params is not None else {} + self.params = RowCountDiffParams(**params) if params is not None else RowCountDiffParams() self.connection = None def _query_row_count(self, dbt_adapter, model_name, base=False): @@ -47,22 +51,22 @@ def execute_dbt(self): dbt_adapter = default_context().adapter query_candidates = [] - if self.params.get('node_ids', []) or self.params.get('node_names', []): - for node_id in self.params.get('node_ids', []): + if self.params.node_ids or self.params.node_names: + for node_id in self.params.node_ids or []: name = dbt_adapter.get_node_name_by_id(node_id) if name: query_candidates.append(name) - for node in self.params.get('node_names', []): + for node in self.params.node_names or []: query_candidates.append(node) else: def countable(unique_id): return unique_id.startswith('model') or unique_id.startswith('snapshot') or unique_id.startswith('seed') node_ids = dbt_adapter.select_nodes( - select=self.params.get('select', None), - exclude=self.params.get('exclude', None), - packages=self.params.get('packages', None), - view_mode=self.params.get('view_mode', None) + select=self.params.select, + exclude=self.params.exclude, + packages=self.params.packages, + view_mode=self.params.view_mode, ) node_ids = list(filter(countable, node_ids)) for node_id in node_ids: @@ -95,9 +99,9 @@ def execute_sqlmesh(self): query_candidates = [] - for node_id in self.params.get('node_ids', []): + for node_id in self.node_ids or []: query_candidates.append(node_id) - for node_name in self.params.get('node_names', []): + for node_name in self.params.node_names or []: query_candidates.append(node_name) from recce.adapter.sqlmesh_adapter import SqlmeshAdapter @@ -175,3 +179,12 @@ def _get_related_node_ids(self) -> Union[List[str], None]: def _get_changed_nodes(self) -> Union[List[str], None]: if self.changes: return self.changes.affected_root_keys.items + + +class RowCountDiffCheckValidator(CheckValidator): + @override + def validate_check(self, check: Check): + try: + RowCountDiffParams(**check.params) + except Exception as e: + raise ValueError(f'Invalid params: str{e}') diff --git a/tests/adapter/dbt_adapter/dbt_test_helper.py b/tests/adapter/dbt_adapter/dbt_test_helper.py index 890ba637..cfbf0743 100644 --- a/tests/adapter/dbt_adapter/dbt_test_helper.py +++ b/tests/adapter/dbt_adapter/dbt_test_helper.py @@ -56,10 +56,11 @@ def create_model( curr_csv=None, depends_on=[], disabled=False, + unique_id=None, package_name="recce_test", ): # unique_id = f"model.{package_name}.{model_name}" - unique_id = model_name + unique_id = unique_id if unique_id else model_name def _add_model_to_manifest(base, raw_code): if base: diff --git a/tests/tasks/test_preset_checks.py b/tests/tasks/test_preset_checks.py index 2d982d5b..54027822 100644 --- a/tests/tasks/test_preset_checks.py +++ b/tests/tasks/test_preset_checks.py @@ -11,11 +11,12 @@ def test_default_validator(): "params": {}, }) - # Failed because no params - with pytest.raises(ValueError): - CheckValidator().validate({ - "type": "row_count_diff", - }) + # Failed "name" type + # with pytest.raises(ValueError): + CheckValidator().validate({ + "name": 123, + "type": "row_count_diff", + }) def test_query_diff_validator(): diff --git a/tests/tasks/test_row_count.py b/tests/tasks/test_row_count.py index bb04aab6..64d28b73 100644 --- a/tests/tasks/test_row_count.py +++ b/tests/tasks/test_row_count.py @@ -1,3 +1,5 @@ +import pytest + from recce.tasks import RowCountDiffTask @@ -15,7 +17,7 @@ def test_row_count(dbt_test_helper): 2,Bob,25 """ - dbt_test_helper.create_model("customers", csv_data_base, csv_data_curr) + dbt_test_helper.create_model("customers", csv_data_base, csv_data_curr, unique_id='model.customers') task = RowCountDiffTask(dict(node_names=['customers'])) run_result = task.execute() assert run_result['customers']['base'] == 2 @@ -26,6 +28,11 @@ def test_row_count(dbt_test_helper): assert run_result['customers_']['base'] is None assert run_result['customers_']['curr'] is None + task = RowCountDiffTask(dict(node_ids=['model.customers'])) + run_result = task.execute() + assert run_result['customers']['base'] == 2 + assert run_result['customers']['curr'] == 3 + def test_row_count_with_selector(dbt_test_helper): csv_data_1 = """ @@ -51,3 +58,62 @@ def test_row_count_with_selector(dbt_test_helper): run_result = task.execute() assert len(run_result) == 2 + +def test_validator(): + from recce.tasks.rowcount import RowCountDiffCheckValidator + + validator = RowCountDiffCheckValidator() + + def validate(params: dict): + validator.validate({ + 'name': 'test', + 'type': 'row_count_diff', + 'params': params, + }) + + # Select all modesl + validate({}) + + # Select by node name + validate({ + 'node_names': ['abc'], + }) + with pytest.raises(ValueError): + validate({ + 'node_names': [123], + }) + with pytest.raises(ValueError): + validate({ + 'node_names': 'abc', + }) + + # Select by node id + validate({ + 'node_ids': ['model.abc'], + }) + + # Select by selector + validate({ + 'select': 'customers', + 'exclude': 'customers', + 'packages': ['jaffle_shop'], + 'view_mode': 'all', + }) + + # packages should be an array + with pytest.raises(ValueError): + validate({ + 'packages': 'jaffle_shop', + }) + + # view_mode should be 'all' or 'changed_models' + validate({ + 'view_mode': None, + }) + validate({ + 'view_mode': 'all', + }) + with pytest.raises(ValueError): + validate({ + 'view_mode': 'abc', + })