diff --git a/recce/adapter/dbt_adapter/__init__.py b/recce/adapter/dbt_adapter/__init__.py index e31ce3c4..d6d74123 100644 --- a/recce/adapter/dbt_adapter/__init__.py +++ b/recce/adapter/dbt_adapter/__init__.py @@ -631,13 +631,37 @@ def select_nodes(self, select: Optional[str] = None, exclude: Optional[str] = No spec = parse_difference(select_list, exclude_list, "eager") else: spec = parse_difference(select_list, exclude_list) + + manifest = self.manifest.deepcopy() + + for (key, node) in self.previous_state.manifest.nodes.items(): + if key not in manifest.nodes: + manifest.nodes[key] = node + + for (key, node) in self.previous_state.manifest.sources.items(): + if key not in manifest.sources: + manifest.sources[key] = node + + for (key, node) in self.previous_state.manifest.exposures.items(): + if key not in manifest.exposures: + manifest.exposures[key] = node + + for (key, node) in self.previous_state.manifest.metrics.items(): + if key not in manifest.metrics: + manifest.metrics[key] = node + + if hasattr(self.previous_state.manifest, 'semantic_models'): + for (key, node) in self.previous_state.manifest.semantic_models.items(): + if key not in manifest.semantic_models: + manifest.semantic_models[key] = node + compiler = Compiler(self.runtime_config) # disable to print compile states tmp_func = dbt.compilation.print_compile_stats dbt.compilation.print_compile_stats = lambda x: None - graph = compiler.compile(self.manifest, write=False) + graph = compiler.compile(manifest, write=False) dbt.compilation.print_compile_stats = tmp_func - selector = NodeSelector(graph, self.manifest, previous_state=self.previous_state) + selector = NodeSelector(graph, manifest, previous_state=self.previous_state) return selector.get_selected(spec) diff --git a/tests/adapter/dbt_adapter/dbt_test_helper.py b/tests/adapter/dbt_adapter/dbt_test_helper.py index 6f04f376..39cf8d2b 100644 --- a/tests/adapter/dbt_adapter/dbt_test_helper.py +++ b/tests/adapter/dbt_adapter/dbt_test_helper.py @@ 
-41,7 +41,7 @@ def __init__(self): self.adapter.execute(f"CREATE schema IF NOT EXISTS {self.curr_schema}") self.adapter.set_artifacts(self.base_manifest, self.curr_manifest) - def create_model(self, model_name, base_csv, curr_csv, depends_on=[]): + def create_model(self, model_name, base_csv=None, curr_csv=None, depends_on=[]): package_name = "recce_test" # unique_id = f"model.{package_name}.{model_name}" unique_id = model_name @@ -84,19 +84,19 @@ def _add_model_to_manifest(base, raw_code): manifest.add_node_nofile(node) return node - base_csv = textwrap.dedent(base_csv) - curr_csv = textwrap.dedent(curr_csv) - - _add_model_to_manifest(True, base_csv) - _add_model_to_manifest(False, curr_csv) - - import pandas as pd - df_base = pd.read_csv(StringIO(base_csv)) - df_curr = pd.read_csv(StringIO(curr_csv)) dbt_adapter = self.adapter with dbt_adapter.connection_named('create model'): - dbt_adapter.execute(f"CREATE TABLE {self.base_schema}.{model_name} AS SELECT * FROM df_base") - dbt_adapter.execute(f"CREATE TABLE {self.curr_schema}.{model_name} AS SELECT * FROM df_curr") + import pandas as pd + if base_csv: + base_csv = textwrap.dedent(base_csv) + _add_model_to_manifest(True, base_csv) + df_base = pd.read_csv(StringIO(base_csv)) + dbt_adapter.execute(f"CREATE TABLE {self.base_schema}.{model_name} AS SELECT * FROM df_base") + if curr_csv: + curr_csv = textwrap.dedent(curr_csv) + _add_model_to_manifest(False, curr_csv) + df_curr = pd.read_csv(StringIO(curr_csv)) + dbt_adapter.execute(f"CREATE TABLE {self.curr_schema}.{model_name} AS SELECT * FROM df_curr") self.adapter.set_artifacts(self.base_manifest, self.curr_manifest) def remove_model(self, model_name): diff --git a/tests/adapter/dbt_adapter/test_selector.py b/tests/adapter/dbt_adapter/test_selector.py index 08463fce..5703d636 100644 --- a/tests/adapter/dbt_adapter/test_selector.py +++ b/tests/adapter/dbt_adapter/test_selector.py @@ -70,3 +70,32 @@ def test_select(dbt_test_helper): node_ids = 
adapter.select_nodes("state:modified,resource_type:snapshot")
     assert len(node_ids) == 1
+
+
+def test_select_removed_by_graph(dbt_test_helper):
+    """Check graph-operator selection when a model has only base or only current data."""
+    base_rows = """
+        customer_id,name,age
+        1,Alice,35
+        2,Bob,25
+        3,Charlie,35
+        """
+
+    curr_rows = """
+        customer_id,name,age
+        1,Alice,30
+        2,Bob,25
+        3,Charlie,35
+        """
+
+    # customers_1 is created with both datasets; customers_2 is given base data
+    # only and customers_3 current data only. Both depend on customers_1.
+    dbt_test_helper.create_model("customers_1", base_rows, curr_rows)
+    dbt_test_helper.create_model("customers_2", base_csv=base_rows, depends_on=["customers_1"])
+    dbt_test_helper.create_model("customers_3", curr_csv=curr_rows, depends_on=["customers_1"])
+    adapter: DbtAdapter = dbt_test_helper.context.adapter
+
+    # Test graph operation
+    expected_counts = {"customers_1+": 3, "+customers_2": 2, "+customers_3": 2}
+    for selector, expected in expected_counts.items():
+        assert len(adapter.select_nodes(selector)) == expected