Skip to content

Commit

Permalink
Fix the node with disabled nodes
Browse files Browse the repository at this point in the history
Signed-off-by: popcorny <[email protected]>
  • Loading branch information
popcornylu committed Jul 15, 2024
1 parent 97f6068 commit 054e3fe
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 25 deletions.
32 changes: 10 additions & 22 deletions recce/adapter/dbt_adapter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,28 +632,16 @@ def select_nodes(self, select: Optional[str] = None, exclude: Optional[str] = No
else:
spec = parse_difference(select_list, exclude_list)

manifest = self.manifest.deepcopy()

for (key, node) in self.previous_state.manifest.nodes.items():
if key not in manifest.nodes:
manifest.nodes[key] = node

for (key, node) in self.previous_state.manifest.sources.items():
if key not in manifest.sources:
manifest.sources[key] = node

for (key, node) in self.previous_state.manifest.exposures.items():
if key not in manifest.exposures:
manifest.exposures[key] = node

for (key, node) in self.previous_state.manifest.metrics.items():
if key not in manifest.metrics:
manifest.metrics[key] = node

if hasattr(self.previous_state.manifest, 'semantic_models'):
for (key, node) in self.previous_state.manifest.semantic_models.items():
if key not in manifest.semantic_models:
manifest.semantic_models[key] = node
manifest = Manifest()
manifest_prev = self.previous_state.manifest
manifest_curr = self.manifest

manifest.nodes = {**manifest_prev.nodes, **manifest_curr.nodes}
manifest.sources = {**manifest_prev.sources, **manifest_curr.sources}
manifest.exposures = {**manifest_prev.exposures, **manifest_curr.exposures}
manifest.metrics = {**manifest_prev.metrics, **manifest_curr.metrics}
if hasattr(manifest_prev, 'semantic_models'):
manifest.semantic_models = {**manifest_prev.semantic_models, **manifest_curr.semantic_models}

compiler = Compiler(self.runtime_config)
# disable to print compile states
Expand Down
9 changes: 6 additions & 3 deletions tests/adapter/dbt_adapter/dbt_test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self):
self.adapter.execute(f"CREATE schema IF NOT EXISTS {self.curr_schema}")
self.adapter.set_artifacts(self.base_manifest, self.curr_manifest)

def create_model(self, model_name, base_csv=None, curr_csv=None, depends_on=[]):
def create_model(self, model_name, base_csv=None, curr_csv=None, depends_on=[], disabled=False):
package_name = "recce_test"
# unique_id = f"model.{package_name}.{model_name}"
unique_id = model_name
Expand Down Expand Up @@ -81,7 +81,11 @@ def _add_model_to_manifest(base, raw_code):
"nodes": depends_on
},
})
manifest.add_node_nofile(node)

if disabled:
manifest.add_disabled_nofile(node)
else:
manifest.add_node_nofile(node)
return node

dbt_adapter = self.adapter
Expand All @@ -97,7 +101,6 @@ def _add_model_to_manifest(base, raw_code):
_add_model_to_manifest(False, curr_csv)
df_curr = pd.read_csv(StringIO(curr_csv))
dbt_adapter.execute(f"CREATE TABLE {self.curr_schema}.{model_name} AS SELECT * FROM df_curr")
self.adapter.set_artifacts(self.base_manifest, self.curr_manifest)

def remove_model(self, model_name):
dbt_adapter = self.adapter
Expand Down
24 changes: 24 additions & 0 deletions tests/adapter/dbt_adapter/test_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,27 @@ def test_select_removed_by_graph(dbt_test_helper):
assert len(node_ids) == 2
node_ids = adapter.select_nodes("+customers_3")
assert len(node_ids) == 2


def test_select_with_disabled(dbt_test_helper):
csv_data_curr = """
customer_id,name,age
1,Alice,30
2,Bob,25
3,Charlie,35
"""

csv_data_base = """
customer_id,name,age
1,Alice,35
2,Bob,25
3,Charlie,35
"""

dbt_test_helper.create_model("customers_1", csv_data_base, csv_data_curr)
dbt_test_helper.create_model("customers_2", csv_data_base, csv_data_curr, disabled=True)
adapter: DbtAdapter = dbt_test_helper.context.adapter

# Test graph operation
node_ids = adapter.select_nodes("customers_1+")
assert len(node_ids) == 1

0 comments on commit 054e3fe

Please sign in to comment.