Skip to content

Commit

Permalink
Support dbt 1.8
Browse files Browse the repository at this point in the history
Signed-off-by: popcorny <[email protected]>
  • Loading branch information
popcornylu committed May 7, 2024
1 parent 5d7699b commit d52f048
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 89 deletions.
143 changes: 64 additions & 79 deletions recce/adapter/dbt_adapter.py → recce/adapter/dbt_adapter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import os
import uuid
from contextlib import contextmanager
from dataclasses import dataclass
from dataclasses import dataclass, fields
from typing import Callable, Dict, List, Optional, Tuple, Iterator, Any

import agate
import dbt.adapters.factory

# Reference: https://github.com/AltimateAI/vscode-dbt-power-user/blob/master/dbt_core_integration.py
from dbt.mp_context import get_mp_context

get_adapter_orig = dbt.adapters.factory.get_adapter

Expand All @@ -27,7 +27,15 @@ def get_adapter(config):
from dbt.adapters.factory import get_adapter_class_by_name
from dbt.adapters.sql import SQLAdapter
from dbt.config.runtime import RuntimeConfig
from dbt.adapters.contracts.connection import Connection
from .dbt_version import DbtVersion

dbt_version = DbtVersion()

if dbt_version < 'v1.8':
from dbt.contracts.connection import Connection
else:
from dbt.adapters.contracts.connection import Connection

from dbt.contracts.graph.manifest import Manifest, WritableManifest, MacroManifest
from dbt.contracts.graph.nodes import ManifestNode
from dbt.contracts.results import CatalogArtifact
Expand Down Expand Up @@ -66,53 +74,23 @@ def on_created(self, event):
self.callback(event)


class DbtVersionTool:

def __init__(self):
from dbt import version as dbt_version
self.dbt_version = self.parse(dbt_version.__version__)

@staticmethod
def parse(version: str):
from packaging import version as v
return v.parse(version)

def as_version(self, other):
from packaging.version import Version
if isinstance(other, Version):
return other
if isinstance(other, str):
return self.parse(other)
return self.parse(str(other))

def __ge__(self, other):
return self.dbt_version >= self.as_version(other)

def __gt__(self, other):
return self.dbt_version > self.as_version(other)

def __lt__(self, other):
return self.dbt_version < self.as_version(other)

def __le__(self, other):
return self.dbt_version <= self.as_version(other)

def __eq__(self, other):
return self.dbt_version.release[:2] == self.as_version(other).release[:2]

def __str__(self):
return ".".join([str(x) for x in list(self.dbt_version.release)])


dbt_version = DbtVersionTool()
def merge_tables(tables: List[agate.Table]) -> agate.Table:
if dbt_version < 'v1.8':
from dbt.clients.agate_helper import merge_tables
return merge_tables(tables)
else:
from dbt_common.clients.agate_helper import _merged_column_types
return _merged_column_types(tables)


def as_manifest(m: WritableManifest) -> Manifest:
# data = m.__dict__
# all_fields = set([x.name for x in fields(Manifest)])
# new_data = {k: v for k, v in data.items() if k in all_fields}
# return Manifest(**new_data)
return Manifest.from_writable_manifest(m)
if dbt_version < 'v1.8':
data = m.__dict__
all_fields = set([x.name for x in fields(Manifest)])
new_data = {k: v for k, v in data.items() if k in all_fields}
return Manifest(**new_data)
else:
return Manifest.from_writable_manifest(m)


def load_manifest(path: str = None, data: dict = None):
Expand Down Expand Up @@ -199,16 +177,23 @@ def load(cls, artifacts: ArtifactsRoot = None, **kwargs):

from dbt.exceptions import DbtProjectError
try:
from dbt_common.context import set_invocation_context, get_invocation_context
# adapter
if dbt_version < 'v1.8':
runtime_config = RuntimeConfig.from_args(args)
adapter_name = runtime_config.credentials.type
adapter_cls = get_adapter_class_by_name(adapter_name)
adapter: SQLAdapter = adapter_cls(runtime_config)
else:
from dbt_common.context import set_invocation_context, get_invocation_context
from dbt.mp_context import get_mp_context

set_invocation_context({})
get_invocation_context()._env = dict(os.environ)
runtime_config = RuntimeConfig.from_args(args)
set_invocation_context({})
get_invocation_context()._env = dict(os.environ)
runtime_config = RuntimeConfig.from_args(args)
adapter_name = runtime_config.credentials.type
adapter_cls = get_adapter_class_by_name(adapter_name)
adapter: SQLAdapter = adapter_cls(runtime_config, get_mp_context())

# adapter
adapter_name = runtime_config.credentials.type
adapter_cls = get_adapter_class_by_name(adapter_name)
adapter: SQLAdapter = adapter_cls(runtime_config, get_mp_context())
adapter.connections.set_connection_name()
runtime_config.adapter = adapter

Expand Down Expand Up @@ -247,21 +232,20 @@ def print_lineage_info(self):
print(f" Catalog: {self.curr_catalog.metadata.generated_at if self.curr_catalog else 'N/A'}")

def get_columns(self, model: str, base=False) -> List[Column]:
from dbt.context.providers import generate_runtime_macro_context

relation = self.create_relation(model, base)

macro_manifest = MacroManifest(self.manifest.macros)
self.adapter.set_macro_resolver(macro_manifest)
self.adapter.set_macro_context_generator(generate_runtime_macro_context)
return self.adapter.execute_macro(
'get_columns_in_relation',
kwargs={"relation": relation})

# return self.adapter.execute_macro(
# 'get_columns_in_relation',
# kwargs={"relation": relation},
# manifest=self.manifest)
if dbt_version < 'v1.8':
return self.adapter.execute_macro(
'get_columns_in_relation',
kwargs={"relation": relation},
manifest=self.manifest)
else:
from dbt.context.providers import generate_runtime_macro_context
macro_manifest = MacroManifest(self.manifest.macros)
self.adapter.set_macro_resolver(macro_manifest)
self.adapter.set_macro_context_generator(generate_runtime_macro_context)
return self.adapter.execute_macro(
'get_columns_in_relation',
kwargs={"relation": relation})

def get_model(self, model_id: str, base=False):
manifest = self.curr_manifest if base is False else self.base_manifest
Expand Down Expand Up @@ -381,16 +365,17 @@ def generate_sql(self, sql_template: str, base: bool = False, context: Dict = {}
node = parser.parse_remote(sql_template, node_id)
process_node(self.runtime_config, manifest, node)

# compiler = self.adapter.get_compiler()
# compiler.compile_node(node, manifest, context)
from dbt.context.providers import generate_runtime_model_context
from dbt.clients import jinja
jinja_ctx = generate_runtime_model_context(node, self.runtime_config, manifest)
jinja_ctx.update(context)
compiled_code = jinja.get_rendered(sql_template, jinja_ctx, node)

# return node.compiled_code
return compiled_code
if dbt_version < dbt_version.parse('v1.8'):
compiler = self.adapter.get_compiler()
compiler.compile_node(node, manifest, context)
return node.compiled_code
else:
from dbt.context.providers import generate_runtime_model_context
from dbt.clients import jinja
jinja_ctx = generate_runtime_model_context(node, self.runtime_config, manifest)
jinja_ctx.update(context)
compiled_code = jinja.get_rendered(sql_template, jinja_ctx, node)
return compiled_code

def execute(self, sql: str, auto_begin: bool = False, fetch: bool = False, limit: Optional[int] = None) -> Tuple[
any, agate.Table]:
Expand Down
36 changes: 36 additions & 0 deletions recce/adapter/dbt_adapter/dbt_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
class DbtVersion:

def __init__(self):
from dbt import version as dbt_version
self.dbt_version = self.parse(dbt_version.__version__)

@staticmethod
def parse(version: str):
from packaging import version as v
return v.parse(version)

def as_version(self, other):
from packaging.version import Version
if isinstance(other, Version):
return other
if isinstance(other, str):
return self.parse(other)
return self.parse(str(other))

def __ge__(self, other):
return self.dbt_version >= self.as_version(other)

def __gt__(self, other):
return self.dbt_version > self.as_version(other)

def __lt__(self, other):
return self.dbt_version < self.as_version(other)

def __le__(self, other):
return self.dbt_version <= self.as_version(other)

def __eq__(self, other):
return self.dbt_version.release[:2] == self.as_version(other).release[:2]

def __str__(self):
return ".".join([str(x) for x in list(self.dbt_version.release)])
6 changes: 5 additions & 1 deletion recce/apis/run_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,11 @@ def fn():
asyncio.run_coroutine_threadsafe(update_run_result(run.run_id, result, None), loop)
return result
except BaseException as e:
from dbt_common.exceptions import DbtDatabaseError
from recce.adapter.dbt_adapter import dbt_version
if dbt_version < 'v1.8':
from dbt.exceptions import DbtDatabaseError
else:
from dbt_common.exceptions import DbtDatabaseError
if isinstance(e, DbtDatabaseError):
if str(e).find('100051') and run.type == RunType.PROFILE_DIFF:
# Snowflake error '100051 (22012): Division by zero"'
Expand Down
9 changes: 6 additions & 3 deletions recce/tasks/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,18 @@ class DataFrame(BaseModel):

@staticmethod
def from_agate(table: 'agate.Table', limit: t.Optional[int] = None, more: t.Optional[bool] = None):
# import dbt.clients.agate_helper
from recce.adapter.dbt_adapter import dbt_version
if dbt_version < 'v1.8':
import dbt.clients.agate_helper as agate_helper
else:
import dbt_common.clients.agate_helper as agate_helper

import dbt_common.clients.agate_helper
import agate
columns = []

for col_name, col_type in zip(table.column_names, table.column_types):

has_integer = hasattr(dbt_common.clients.agate_helper, 'Integer')
has_integer = hasattr(agate_helper, 'Integer')

if isinstance(col_type, agate.Number):
col_type = DataFrameColumnType.NUMBER
Expand Down
4 changes: 1 addition & 3 deletions recce/tasks/profile.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from typing import TypedDict, List

import agate
from dbt_common.clients.agate_helper import merge_tables
from pydantic import BaseModel

from recce.adapter.dbt_adapter import DbtAdapter, merge_tables
from .core import Task, TaskResultDiffer
from .dataframe import DataFrame
from ..adapter.dbt_adapter import DbtAdapter
from ..core import default_context
from ..exceptions import RecceException

Expand Down Expand Up @@ -100,7 +99,6 @@ def _profile_column(self, dbt_adapter: DbtAdapter, relation, column):
base=False, # always false because we use the macro in current manifest
context=dict(relation=relation, column_name=column_name, column_type=column_type)
)
print(sql)
except Exception as e:
raise RecceException(f"Failed to generate SQL for profiling column: {column_name}") from e

Expand Down
6 changes: 3 additions & 3 deletions recce/tasks/query.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TypedDict, Optional, Tuple, Union, List
from typing import TypedDict, Optional, Tuple, List

import agate
from pydantic import BaseModel
Expand Down Expand Up @@ -76,7 +76,7 @@ def __init__(self, params: QueryParams):
self.connection = None

def execute_dbt(self):
from ..adapter.dbt_adapter import DbtAdapter
from recce.adapter.dbt_adapter import DbtAdapter
dbt_adapter: DbtAdapter = default_context().adapter

limit = QUERY_LIMIT
Expand Down Expand Up @@ -189,7 +189,7 @@ def _query_diff_join(self, dbt_adapter, sql_template: str, primary_keys: List[st
)

def execute_dbt(self):
from ..adapter.dbt_adapter import DbtAdapter
from recce.adapter.dbt_adapter import DbtAdapter
dbt_adapter: DbtAdapter = default_context().adapter

with dbt_adapter.connection_named("query"):
Expand Down

0 comments on commit d52f048

Please sign in to comment.