diff --git a/recce/adapter/dbt_adapter.py b/recce/adapter/dbt_adapter/__init__.py similarity index 87% rename from recce/adapter/dbt_adapter.py rename to recce/adapter/dbt_adapter/__init__.py index 8b331dec..a2fcfea1 100644 --- a/recce/adapter/dbt_adapter.py +++ b/recce/adapter/dbt_adapter/__init__.py @@ -3,13 +3,13 @@ import os import uuid from contextlib import contextmanager -from dataclasses import dataclass +from dataclasses import dataclass, fields from typing import Callable, Dict, List, Optional, Tuple, Iterator, Any import agate import dbt.adapters.factory + # Reference: https://github.com/AltimateAI/vscode-dbt-power-user/blob/master/dbt_core_integration.py -from dbt.mp_context import get_mp_context get_adapter_orig = dbt.adapters.factory.get_adapter @@ -27,7 +27,15 @@ def get_adapter(config): from dbt.adapters.factory import get_adapter_class_by_name from dbt.adapters.sql import SQLAdapter from dbt.config.runtime import RuntimeConfig -from dbt.adapters.contracts.connection import Connection +from .dbt_version import DbtVersion + +dbt_version = DbtVersion() + +if dbt_version < 'v1.8': + from dbt.contracts.connection import Connection +else: + from dbt.adapters.contracts.connection import Connection + from dbt.contracts.graph.manifest import Manifest, WritableManifest, MacroManifest from dbt.contracts.graph.nodes import ManifestNode from dbt.contracts.results import CatalogArtifact @@ -66,53 +74,23 @@ def on_created(self, event): self.callback(event) -class DbtVersionTool: - - def __init__(self): - from dbt import version as dbt_version - self.dbt_version = self.parse(dbt_version.__version__) - - @staticmethod - def parse(version: str): - from packaging import version as v - return v.parse(version) - - def as_version(self, other): - from packaging.version import Version - if isinstance(other, Version): - return other - if isinstance(other, str): - return self.parse(other) - return self.parse(str(other)) - - def __ge__(self, other): - return self.dbt_version 
>= self.as_version(other) - - def __gt__(self, other): - return self.dbt_version > self.as_version(other) - - def __lt__(self, other): - return self.dbt_version < self.as_version(other) - - def __le__(self, other): - return self.dbt_version <= self.as_version(other) - - def __eq__(self, other): - return self.dbt_version.release[:2] == self.as_version(other).release[:2] - - def __str__(self): - return ".".join([str(x) for x in list(self.dbt_version.release)]) - - -dbt_version = DbtVersionTool() +def merge_tables(tables: List[agate.Table]) -> agate.Table: + if dbt_version < 'v1.8': + from dbt.clients.agate_helper import merge_tables + return merge_tables(tables) + else: + from dbt_common.clients.agate_helper import merge_tables + return merge_tables(tables) def as_manifest(m: WritableManifest) -> Manifest: - # data = m.__dict__ - # all_fields = set([x.name for x in fields(Manifest)]) - # new_data = {k: v for k, v in data.items() if k in all_fields} - # return Manifest(**new_data) - return Manifest.from_writable_manifest(m) + if dbt_version < 'v1.8': + data = m.__dict__ + all_fields = set([x.name for x in fields(Manifest)]) + new_data = {k: v for k, v in data.items() if k in all_fields} + return Manifest(**new_data) + else: + return Manifest.from_writable_manifest(m) def load_manifest(path: str = None, data: dict = None): @@ -199,16 +177,23 @@ def load(cls, artifacts: ArtifactsRoot = None, **kwargs): from dbt.exceptions import DbtProjectError try: - from dbt_common.context import set_invocation_context, get_invocation_context + # adapter + if dbt_version < 'v1.8': + runtime_config = RuntimeConfig.from_args(args) + adapter_name = runtime_config.credentials.type + adapter_cls = get_adapter_class_by_name(adapter_name) + adapter: SQLAdapter = adapter_cls(runtime_config) + else: + from dbt_common.context import set_invocation_context, get_invocation_context + from dbt.mp_context import get_mp_context - set_invocation_context({}) - 
get_invocation_context()._env = dict(os.environ) - runtime_config = RuntimeConfig.from_args(args) + set_invocation_context({}) + get_invocation_context()._env = dict(os.environ) + runtime_config = RuntimeConfig.from_args(args) + adapter_name = runtime_config.credentials.type + adapter_cls = get_adapter_class_by_name(adapter_name) + adapter: SQLAdapter = adapter_cls(runtime_config, get_mp_context()) - # adapter - adapter_name = runtime_config.credentials.type - adapter_cls = get_adapter_class_by_name(adapter_name) - adapter: SQLAdapter = adapter_cls(runtime_config, get_mp_context()) adapter.connections.set_connection_name() runtime_config.adapter = adapter @@ -247,21 +232,20 @@ def print_lineage_info(self): print(f" Catalog: {self.curr_catalog.metadata.generated_at if self.curr_catalog else 'N/A'}") def get_columns(self, model: str, base=False) -> List[Column]: - from dbt.context.providers import generate_runtime_macro_context - relation = self.create_relation(model, base) - - macro_manifest = MacroManifest(self.manifest.macros) - self.adapter.set_macro_resolver(macro_manifest) - self.adapter.set_macro_context_generator(generate_runtime_macro_context) - return self.adapter.execute_macro( - 'get_columns_in_relation', - kwargs={"relation": relation}) - - # return self.adapter.execute_macro( - # 'get_columns_in_relation', - # kwargs={"relation": relation}, - # manifest=self.manifest) + if dbt_version < 'v1.8': + return self.adapter.execute_macro( + 'get_columns_in_relation', + kwargs={"relation": relation}, + manifest=self.manifest) + else: + from dbt.context.providers import generate_runtime_macro_context + macro_manifest = MacroManifest(self.manifest.macros) + self.adapter.set_macro_resolver(macro_manifest) + self.adapter.set_macro_context_generator(generate_runtime_macro_context) + return self.adapter.execute_macro( + 'get_columns_in_relation', + kwargs={"relation": relation}) def get_model(self, model_id: str, base=False): manifest = self.curr_manifest if base is 
False else self.base_manifest @@ -381,16 +365,17 @@ def generate_sql(self, sql_template: str, base: bool = False, context: Dict = {} node = parser.parse_remote(sql_template, node_id) process_node(self.runtime_config, manifest, node) - # compiler = self.adapter.get_compiler() - # compiler.compile_node(node, manifest, context) - from dbt.context.providers import generate_runtime_model_context - from dbt.clients import jinja - jinja_ctx = generate_runtime_model_context(node, self.runtime_config, manifest) - jinja_ctx.update(context) - compiled_code = jinja.get_rendered(sql_template, jinja_ctx, node) - - # return node.compiled_code - return compiled_code + if dbt_version < 'v1.8': + compiler = self.adapter.get_compiler() + compiler.compile_node(node, manifest, context) + return node.compiled_code + else: + from dbt.context.providers import generate_runtime_model_context + from dbt.clients import jinja + jinja_ctx = generate_runtime_model_context(node, self.runtime_config, manifest) + jinja_ctx.update(context) + compiled_code = jinja.get_rendered(sql_template, jinja_ctx, node) + return compiled_code def execute(self, sql: str, auto_begin: bool = False, fetch: bool = False, limit: Optional[int] = None) -> Tuple[ any, agate.Table]: diff --git a/recce/adapter/dbt_adapter/dbt_version.py b/recce/adapter/dbt_adapter/dbt_version.py new file mode 100644 index 00000000..9f74970d --- /dev/null +++ b/recce/adapter/dbt_adapter/dbt_version.py @@ -0,0 +1,36 @@ +class DbtVersion: + + def __init__(self): + from dbt import version as dbt_version + self.dbt_version = self.parse(dbt_version.__version__) + + @staticmethod + def parse(version: str): + from packaging import version as v + return v.parse(version) + + def as_version(self, other): + from packaging.version import Version + if isinstance(other, Version): + return other + if isinstance(other, str): + return self.parse(other) + return self.parse(str(other)) + + def __ge__(self, other): + return self.dbt_version 
>= self.as_version(other) + + def __gt__(self, other): + return self.dbt_version > self.as_version(other) + + def __lt__(self, other): + return self.dbt_version < self.as_version(other) + + def __le__(self, other): + return self.dbt_version <= self.as_version(other) + + def __eq__(self, other): + return self.dbt_version.release[:2] == self.as_version(other).release[:2] + + def __str__(self): + return ".".join([str(x) for x in list(self.dbt_version.release)]) \ No newline at end of file diff --git a/recce/apis/run_func.py b/recce/apis/run_func.py index 859c893b..2ab0940b 100644 --- a/recce/apis/run_func.py +++ b/recce/apis/run_func.py @@ -81,7 +81,11 @@ def fn(): asyncio.run_coroutine_threadsafe(update_run_result(run.run_id, result, None), loop) return result except BaseException as e: - from dbt_common.exceptions import DbtDatabaseError + from recce.adapter.dbt_adapter import dbt_version + if dbt_version < 'v1.8': + from dbt.exceptions import DbtDatabaseError + else: + from dbt_common.exceptions import DbtDatabaseError if isinstance(e, DbtDatabaseError): if str(e).find('100051') and run.type == RunType.PROFILE_DIFF: # Snowflake error '100051 (22012): Division by zero"' diff --git a/recce/tasks/dataframe.py b/recce/tasks/dataframe.py index bb691d79..19f5ddd8 100644 --- a/recce/tasks/dataframe.py +++ b/recce/tasks/dataframe.py @@ -32,15 +32,18 @@ class DataFrame(BaseModel): @staticmethod def from_agate(table: 'agate.Table', limit: t.Optional[int] = None, more: t.Optional[bool] = None): - # import dbt.clients.agate_helper + from recce.adapter.dbt_adapter import dbt_version + if dbt_version < 'v1.8': + import dbt.clients.agate_helper as agate_helper + else: + import dbt_common.clients.agate_helper as agate_helper - import dbt_common.clients.agate_helper import agate columns = [] for col_name, col_type in zip(table.column_names, table.column_types): - has_integer = hasattr(dbt_common.clients.agate_helper, 'Integer') + has_integer = hasattr(agate_helper, 'Integer') if 
isinstance(col_type, agate.Number): col_type = DataFrameColumnType.NUMBER diff --git a/recce/tasks/profile.py b/recce/tasks/profile.py index c543eab1..d8750240 100644 --- a/recce/tasks/profile.py +++ b/recce/tasks/profile.py @@ -1,12 +1,11 @@ from typing import TypedDict, List import agate -from dbt_common.clients.agate_helper import merge_tables from pydantic import BaseModel +from recce.adapter.dbt_adapter import DbtAdapter, merge_tables from .core import Task, TaskResultDiffer from .dataframe import DataFrame -from ..adapter.dbt_adapter import DbtAdapter from ..core import default_context from ..exceptions import RecceException @@ -100,7 +99,6 @@ def _profile_column(self, dbt_adapter: DbtAdapter, relation, column): base=False, # always false because we use the macro in current manifest context=dict(relation=relation, column_name=column_name, column_type=column_type) ) - print(sql) except Exception as e: raise RecceException(f"Failed to generate SQL for profiling column: {column_name}") from e diff --git a/recce/tasks/query.py b/recce/tasks/query.py index c688b384..3b99ad1f 100644 --- a/recce/tasks/query.py +++ b/recce/tasks/query.py @@ -1,4 +1,4 @@ -from typing import TypedDict, Optional, Tuple, Union, List +from typing import TypedDict, Optional, Tuple, List import agate from pydantic import BaseModel @@ -76,7 +76,7 @@ def __init__(self, params: QueryParams): self.connection = None def execute_dbt(self): - from ..adapter.dbt_adapter import DbtAdapter + from recce.adapter.dbt_adapter import DbtAdapter dbt_adapter: DbtAdapter = default_context().adapter limit = QUERY_LIMIT @@ -189,7 +189,7 @@ def _query_diff_join(self, dbt_adapter, sql_template: str, primary_keys: List[st ) def execute_dbt(self): - from ..adapter.dbt_adapter import DbtAdapter + from recce.adapter.dbt_adapter import DbtAdapter dbt_adapter: DbtAdapter = default_context().adapter with dbt_adapter.connection_named("query"):