Skip to content

Commit

Permalink
Merge pull request #376 from DataRecce/feature/drc-538-enhancement-su…
Browse files Browse the repository at this point in the history
…pport-to-select-removed-node-by-graph-operator

[Enhancement] Support to select removed node by graph operator
  • Loading branch information
popcornylu authored Jul 4, 2024
2 parents a5c8484 + 9eca81b commit f63b5b2
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 14 deletions.
28 changes: 26 additions & 2 deletions recce/adapter/dbt_adapter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,13 +631,37 @@ def select_nodes(self, select: Optional[str] = None, exclude: Optional[str] = No
spec = parse_difference(select_list, exclude_list, "eager")
else:
spec = parse_difference(select_list, exclude_list)

manifest = self.manifest.deepcopy()

for (key, node) in self.previous_state.manifest.nodes.items():
if key not in manifest.nodes:
manifest.nodes[key] = node

for (key, node) in self.previous_state.manifest.sources.items():
if key not in manifest.sources:
manifest.sources[key] = node

for (key, node) in self.previous_state.manifest.exposures.items():
if key not in manifest.exposures:
manifest.exposures[key] = node

for (key, node) in self.previous_state.manifest.metrics.items():
if key not in manifest.metrics:
manifest.metrics[key] = node

if hasattr(self.previous_state.manifest, 'semantic_models'):
for (key, node) in self.previous_state.manifest.semantic_models.items():
if key not in manifest.semantic_models:
manifest.semantic_models[key] = node

compiler = Compiler(self.runtime_config)
# disable to print compile states
tmp_func = dbt.compilation.print_compile_stats
dbt.compilation.print_compile_stats = lambda x: None
graph = compiler.compile(self.manifest, write=False)
graph = compiler.compile(manifest, write=False)
dbt.compilation.print_compile_stats = tmp_func
selector = NodeSelector(graph, self.manifest, previous_state=self.previous_state)
selector = NodeSelector(graph, manifest, previous_state=self.previous_state)

return selector.get_selected(spec)

Expand Down
24 changes: 12 additions & 12 deletions tests/adapter/dbt_adapter/dbt_test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self):
self.adapter.execute(f"CREATE schema IF NOT EXISTS {self.curr_schema}")
self.adapter.set_artifacts(self.base_manifest, self.curr_manifest)

def create_model(self, model_name, base_csv, curr_csv, depends_on=[]):
def create_model(self, model_name, base_csv=None, curr_csv=None, depends_on=[]):
package_name = "recce_test"
# unique_id = f"model.{package_name}.{model_name}"
unique_id = model_name
Expand Down Expand Up @@ -84,19 +84,19 @@ def _add_model_to_manifest(base, raw_code):
manifest.add_node_nofile(node)
return node

base_csv = textwrap.dedent(base_csv)
curr_csv = textwrap.dedent(curr_csv)

_add_model_to_manifest(True, base_csv)
_add_model_to_manifest(False, curr_csv)

import pandas as pd
df_base = pd.read_csv(StringIO(base_csv))
df_curr = pd.read_csv(StringIO(curr_csv))
dbt_adapter = self.adapter
with dbt_adapter.connection_named('create model'):
dbt_adapter.execute(f"CREATE TABLE {self.base_schema}.{model_name} AS SELECT * FROM df_base")
dbt_adapter.execute(f"CREATE TABLE {self.curr_schema}.{model_name} AS SELECT * FROM df_curr")
import pandas as pd
if base_csv:
base_csv = textwrap.dedent(base_csv)
_add_model_to_manifest(True, base_csv)
df_base = pd.read_csv(StringIO(base_csv))
dbt_adapter.execute(f"CREATE TABLE {self.base_schema}.{model_name} AS SELECT * FROM df_base")
if curr_csv:
curr_csv = textwrap.dedent(curr_csv)
_add_model_to_manifest(False, curr_csv)
df_curr = pd.read_csv(StringIO(curr_csv))
dbt_adapter.execute(f"CREATE TABLE {self.curr_schema}.{model_name} AS SELECT * FROM df_curr")
self.adapter.set_artifacts(self.base_manifest, self.curr_manifest)

def remove_model(self, model_name):
Expand Down
29 changes: 29 additions & 0 deletions tests/adapter/dbt_adapter/test_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,32 @@ def test_select(dbt_test_helper):

node_ids = adapter.select_nodes("state:modified,resource_type:snapshot")
assert len(node_ids) == 1


def test_select_removed_by_graph(dbt_test_helper):
csv_data_curr = """
customer_id,name,age
1,Alice,30
2,Bob,25
3,Charlie,35
"""

csv_data_base = """
customer_id,name,age
1,Alice,35
2,Bob,25
3,Charlie,35
"""

dbt_test_helper.create_model("customers_1", csv_data_base, csv_data_curr)
dbt_test_helper.create_model("customers_2", base_csv=csv_data_base, depends_on=["customers_1"])
dbt_test_helper.create_model("customers_3", curr_csv=csv_data_curr, depends_on=["customers_1"])
adapter: DbtAdapter = dbt_test_helper.context.adapter

# Test graph operation
node_ids = adapter.select_nodes("customers_1+")
assert len(node_ids) == 3
node_ids = adapter.select_nodes("+customers_2")
assert len(node_ids) == 2
node_ids = adapter.select_nodes("+customers_3")
assert len(node_ids) == 2

0 comments on commit f63b5b2

Please sign in to comment.