From 5ee7c4b930c93d571385aba22b3000932364ea21 Mon Sep 17 00:00:00 2001 From: Denis Kokorin Date: Wed, 3 Jul 2024 19:32:08 +0300 Subject: [PATCH] Add ResourceLoader.resolve_adapter_name --- dbt_pumpkin/loader.py | 30 +++++++++++++++++-- ...get_column_types.sql => lookup_tables.sql} | 6 ++-- dbt_pumpkin/macros/resolve_adapter_name.sql | 16 ++++++++++ tests/test_loader.py | 4 +++ 4 files changed, 50 insertions(+), 6 deletions(-) rename dbt_pumpkin/macros/{get_column_types.sql => lookup_tables.sql} (82%) create mode 100644 dbt_pumpkin/macros/resolve_adapter_name.sql diff --git a/dbt_pumpkin/loader.py b/dbt_pumpkin/loader.py index 2a45db6..1861a26 100644 --- a/dbt_pumpkin/loader.py +++ b/dbt_pumpkin/loader.py @@ -36,6 +36,7 @@ def __init__(self, project_params: ProjectParams, resource_params: ResourceParam self._resource_ids: dict[ResourceType, set[ResourceID]] = None self._resources: list[Resource] = None self._tables: list[Table] = None + self._adapter_name: str = None self._yaml = YAML(typ="safe") def _do_load_manifest(self) -> Manifest: @@ -226,7 +227,7 @@ def _create_pumpkin_project(self, project_vars: dict[str, any]) -> Path: return pumpkin_dir def _run_operation( - self, operation_name: str, project_vars: dict[str, any], result_callback: Callable[[dict], None] + self, operation_name: str, project_vars: dict[str, any] | None, result_callback: Callable[[any], None] ): pumpkin_dir = self._create_pumpkin_project(project_vars) @@ -257,7 +258,7 @@ def _do_lookup_tables(self) -> list[Table]: raw_resources = self._raw_resources project_vars = { - "get_column_types_args": { + "lookup_tables_args": { str(resource.unique_id): [resource.database, resource.schema, resource.identifier] for resource in raw_resources }, @@ -273,7 +274,7 @@ def on_result(result: dict): tables.append(table) logger.info("Looked up %s / %s: %s", len(tables), len(raw_resources), table.resource_id) - self._run_operation("get_column_types", project_vars, on_result) + self._run_operation("lookup_tables", project_vars, on_result) logger.info("Found %s tables", len(tables)) @@ -283,3 +284,26 @@ def lookup_tables(self): if self._tables is None: self._tables = self._do_lookup_tables() return self._tables + + def _do_resolve_adapter_name(self) -> str: + logger.info("Resolving adapter name") + + adapter_names: list[str] = [] + + def on_result(result: str): + adapter_names.append(result) + logger.debug("Resolved adapter name: %s", result) + + self._run_operation("resolve_adapter_name", project_vars=None, result_callback=on_result) + + if len(adapter_names) != 1: + msg = f"Expected exactly 1 adapter name, got: {adapter_names}" + raise PumpkinError(msg) + + return adapter_names[0] + + def resolve_adapter_name(self) -> str: + if self._adapter_name is None: + self._adapter_name = self._do_resolve_adapter_name() + + return self._adapter_name diff --git a/dbt_pumpkin/macros/get_column_types.sql b/dbt_pumpkin/macros/lookup_tables.sql similarity index 82% rename from dbt_pumpkin/macros/get_column_types.sql rename to dbt_pumpkin/macros/lookup_tables.sql index 822a35a..f58cdb3 100644 --- a/dbt_pumpkin/macros/get_column_types.sql +++ b/dbt_pumpkin/macros/lookup_tables.sql @@ -1,5 +1,5 @@ -{% macro get_column_types() %} - {% for resource_id, database_schema_identifier in var('get_column_types_args').items() %} +{% macro lookup_tables() %} + {% for resource_id, database_schema_identifier in var('lookup_tables_args').items() %} {% set database, schema, identifier = database_schema_identifier %} {% set relation = adapter.get_relation(database, schema, identifier) %} @@ -22,6 +22,6 @@ {% do result.update({'columns': columns}) %} {% endif %} - {{ log(tojson( {'get_column_types': result} )) }} + {{ log(tojson( {'lookup_tables': result} )) }} {% endfor %} {% endmacro %} diff --git a/dbt_pumpkin/macros/resolve_adapter_name.sql b/dbt_pumpkin/macros/resolve_adapter_name.sql new file mode 100644 index 0000000..11fedd2 --- /dev/null +++ b/dbt_pumpkin/macros/resolve_adapter_name.sql @@ -0,0 +1,16 @@ +{% macro resolve_adapter_name() %} + {% set adapter_name = adapter.dispatch('resolve_adapter_name')() %} + {{ log(tojson( {'resolve_adapter_name': adapter_name} )) }} +{% endmacro %} + +{% macro default__resolve_adapter_name() -%} + {{ return('default') }} +{%- endmacro %} + +{% macro snowflake__resolve_adapter_name() %} + {{ return('snowflake') }} +{% endmacro %} + +{% macro duckdb__resolve_adapter_name() %} + {{ return('duckdb') }} +{% endmacro %} diff --git a/tests/test_loader.py b/tests/test_loader.py index 87694c4..60da7cc 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -411,3 +411,7 @@ def test_selected_resource_tables(loader_all): ], ), } + + +def test_resolve_adapter_name(loader_all): + assert loader_all.resolve_adapter_name() == "duckdb"