feature/decouple adapters from core #1026
Merged · 15 commits · Jan 10, 2024
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20240102-152030.yaml
@@ -0,0 +1,6 @@
kind: Features
body: Update base adapter references as part of decoupling migration
time: 2024-01-02T15:20:30.038221-08:00
custom:
Author: colin-rogers-dbt
Issue: "1067"
35 changes: 17 additions & 18 deletions dbt/adapters/bigquery/connections.py
@@ -5,7 +5,9 @@
from contextlib import contextmanager
from dataclasses import dataclass, field

from dbt.events.contextvars import get_node_info
from dbt.common.invocation import get_invocation_id

from dbt.common.events.contextvars import get_node_info
from mashumaro.helper import pass_through

from functools import lru_cache
@@ -25,23 +27,21 @@
)

from dbt.adapters.bigquery import gcloud
from dbt.clients import agate_helper
from dbt.config.profile import INVALID_PROFILE_MESSAGE
from dbt.tracking import active_user
from dbt.contracts.connection import ConnectionState, AdapterResponse
from dbt.exceptions import (
FailedToConnectError,
from dbt.common.clients import agate_helper
from dbt.adapters.contracts.connection import ConnectionState, AdapterResponse
from dbt.common.exceptions import (
DbtRuntimeError,
DbtDatabaseError,
DbtProfileError,
DbtConfigError,
)
from dbt.common.exceptions import DbtDatabaseError
from dbt.adapters.exceptions.connection import FailedToConnectError
from dbt.adapters.base import BaseConnectionManager, Credentials
from dbt.events import AdapterLogger
from dbt.events.functions import fire_event
from dbt.events.types import SQLQuery
from dbt.version import __version__ as dbt_version
from dbt.adapters.events.logging import AdapterLogger
from dbt.adapters.events.types import SQLQuery
from dbt.common.events.functions import fire_event
from dbt.adapters.bigquery import __version__ as dbt_version

from dbt.dataclass_schema import ExtensibleDbtClassMixin, StrEnum
from dbt.common.dataclass_schema import ExtensibleDbtClassMixin, StrEnum

logger = AdapterLogger("BigQuery")

@@ -85,7 +85,7 @@ def get_bigquery_defaults(scopes=None) -> Tuple[Any, Optional[str]]:
credentials, _ = google.auth.default(scopes=scopes)
return credentials, _
except google.auth.exceptions.DefaultCredentialsError as e:
raise DbtProfileError(INVALID_PROFILE_MESSAGE.format(error_string=e))
raise DbtConfigError(f"Failed to authenticate with supplied credentials\nerror:\n{e}")


class Priority(StrEnum):
@@ -382,7 +382,7 @@ def get_bigquery_client(cls, profile_credentials):
execution_project = profile_credentials.execution_project
location = getattr(profile_credentials, "location", None)

info = client_info.ClientInfo(user_agent=f"dbt-{dbt_version}")
info = client_info.ClientInfo(user_agent=f"dbt-bigquery-{dbt_version.version}")
return google.cloud.bigquery.Client(
execution_project,
creds,
@@ -470,8 +470,7 @@ def raw_execute(

labels = self.get_labels_from_query_comment()

if active_user:
labels["dbt_invocation_id"] = active_user.invocation_id
labels["dbt_invocation_id"] = get_invocation_id()

job_params = {
"use_legacy_sql": use_legacy_sql,
2 changes: 1 addition & 1 deletion dbt/adapters/bigquery/dataset.py
@@ -1,7 +1,7 @@
from typing import List
from google.cloud.bigquery import Dataset, AccessEntry

from dbt.events import AdapterLogger
from dbt.adapters.events.logging import AdapterLogger

logger = AdapterLogger("BigQuery")

8 changes: 4 additions & 4 deletions dbt/adapters/bigquery/gcloud.py
@@ -1,6 +1,6 @@
from dbt.events import AdapterLogger
import dbt.exceptions
from dbt.clients.system import run_cmd
from dbt.adapters.events.logging import AdapterLogger
import dbt.common.exceptions
from dbt.common.clients.system import run_cmd

NOT_INSTALLED_MSG = """
dbt requires the gcloud SDK to be installed to authenticate with BigQuery.
@@ -25,4 +25,4 @@ def setup_default_credentials():
if gcloud_installed():
run_cmd(".", ["gcloud", "auth", "application-default", "login"])
else:
raise dbt.exceptions.DbtRuntimeError(NOT_INSTALLED_MSG)
raise dbt.common.exceptions.DbtRuntimeError(NOT_INSTALLED_MSG)
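
The exception moves in this file follow the rule applied throughout the PR: generic errors now live in dbt.common.exceptions, while adapter-specific ones live under dbt.adapters.exceptions. A summary of the mapping, based only on imports visible in this diff:

# was dbt.exceptions.{DbtRuntimeError, DbtDatabaseError, CompilationError, DbtConfigError}
from dbt.common.exceptions import (
    DbtRuntimeError,
    DbtDatabaseError,
    CompilationError,
    DbtConfigError,
)
# was dbt.exceptions.{DbtValidationError, NotImplementedError}
from dbt.common.exceptions.base import DbtValidationError, NotImplementedError
# was dbt.exceptions.FailedToConnectError
from dbt.adapters.exceptions.connection import FailedToConnectError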
63 changes: 34 additions & 29 deletions dbt/adapters/bigquery/impl.py
@@ -1,11 +1,15 @@
from dataclasses import dataclass
import json
import threading
from multiprocessing.context import SpawnContext

import time
from typing import Any, Dict, List, Optional, Type, Set, Union
from typing import Any, Dict, List, Optional, Type, Set, Union, FrozenSet, Tuple, Iterable

import agate
from dbt import ui # type: ignore
from dbt.adapters.contracts.relation import RelationConfig

import dbt.common.exceptions.base
from dbt.adapters.base import ( # type: ignore
AdapterConfig,
BaseAdapter,
@@ -17,17 +21,15 @@
available,
)
from dbt.adapters.cache import _make_ref_key_dict # type: ignore
import dbt.clients.agate_helper
from dbt.contracts.connection import AdapterResponse
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.nodes import ColumnLevelConstraint, ConstraintType, ModelLevelConstraint # type: ignore
from dbt.dataclass_schema import dbtClassMixin
import dbt.deprecations
from dbt.events import AdapterLogger
from dbt.events.functions import fire_event
from dbt.events.types import SchemaCreation, SchemaDrop
import dbt.exceptions
from dbt.utils import filter_null_values
import dbt.common.clients.agate_helper
from dbt.adapters.contracts.connection import AdapterResponse
from dbt.common.contracts.constraints import ColumnLevelConstraint, ConstraintType, ModelLevelConstraint # type: ignore
from dbt.common.dataclass_schema import dbtClassMixin
from dbt.adapters.events.logging import AdapterLogger
from dbt.common.events.functions import fire_event
from dbt.adapters.events.types import SchemaCreation, SchemaDrop
import dbt.common.exceptions
from dbt.common.utils import filter_null_values
import google.api_core
import google.auth
import google.oauth2
@@ -116,8 +118,8 @@ class BigQueryAdapter(BaseAdapter):
ConstraintType.foreign_key: ConstraintSupport.ENFORCED,
}

def __init__(self, config) -> None:
super().__init__(config)
def __init__(self, config, mp_context: SpawnContext) -> None:
super().__init__(config, mp_context)
self.connections: BigQueryConnectionManager = self.connections

###
@@ -145,7 +147,9 @@ def drop_relation(self, relation: BigQueryRelation) -> None:
conn.handle.delete_table(table_ref, not_found_ok=True)

def truncate_relation(self, relation: BigQueryRelation) -> None:
raise dbt.exceptions.NotImplementedError("`truncate` is not implemented for this adapter!")
raise dbt.common.exceptions.base.NotImplementedError(
"`truncate` is not implemented for this adapter!"
)

def rename_relation(
self, from_relation: BigQueryRelation, to_relation: BigQueryRelation
@@ -160,7 +164,7 @@ def rename_relation(
or from_relation.type == RelationType.View
or to_relation.type == RelationType.View
):
raise dbt.exceptions.DbtRuntimeError(
raise dbt.common.exceptions.DbtRuntimeError(
"Renaming of views is not currently supported in BigQuery"
)

@@ -386,7 +390,7 @@ def copy_table(self, source, destination, materialization):
elif materialization == "table":
write_disposition = WRITE_TRUNCATE
else:
raise dbt.exceptions.CompilationError(
raise dbt.common.exceptions.CompilationError(
'Copy table materialization must be "copy" or "table", but '
f"config.get('copy_materialization', 'table') was "
f"{materialization}"
@@ -433,11 +437,11 @@ def poll_until_job_completes(cls, job, timeout):
job.reload()

if job.state != "DONE":
raise dbt.exceptions.DbtRuntimeError("BigQuery Timeout Exceeded")
raise dbt.common.exceptions.DbtRuntimeError("BigQuery Timeout Exceeded")

elif job.error_result:
message = "\n".join(error["message"].strip() for error in job.errors)
raise dbt.exceptions.DbtRuntimeError(message)
raise dbt.common.exceptions.DbtRuntimeError(message)

def _bq_table_to_relation(self, bq_table) -> Union[BigQueryRelation, None]:
if bq_table is None:
@@ -454,15 +458,14 @@ def _bq_table_to_relation(self, bq_table) -> Union[BigQueryRelation, None]:
@classmethod
def warning_on_hooks(cls, hook_type):
msg = "{} is not supported in bigquery and will be ignored"
warn_msg = dbt.ui.color(msg, ui.COLOR_FG_YELLOW)
logger.info(warn_msg)
logger.info(msg)

@available
def add_query(self, sql, auto_begin=True, bindings=None, abridge_sql_log=False):
if self.nice_connection_name() in ["on-run-start", "on-run-end"]:
self.warning_on_hooks(self.nice_connection_name())
else:
raise dbt.exceptions.NotImplementedError(
raise dbt.common.exceptions.base.NotImplementedError(
"`add_query` is not implemented for this adapter!"
)

@@ -679,14 +682,16 @@ def upload_file(
self.poll_until_job_completes(job, timeout)

@classmethod
def _catalog_filter_table(cls, table: agate.Table, manifest: Manifest) -> agate.Table:
def _catalog_filter_table(
cls, table: agate.Table, used_schemas: FrozenSet[Tuple[str, str]]
) -> agate.Table:
table = table.rename(
column_names={col.name: col.name.replace("__", ":") for col in table.columns}
)
return super()._catalog_filter_table(table, manifest)
return super()._catalog_filter_table(table, used_schemas)

def _get_catalog_schemas(self, manifest: Manifest) -> SchemaSearchMap:
candidates = super()._get_catalog_schemas(manifest)
def _get_catalog_schemas(self, relation_config: Iterable[RelationConfig]) -> SchemaSearchMap:
candidates = super()._get_catalog_schemas(relation_config)
db_schemas: Dict[str, Set[str]] = {}
result = SchemaSearchMap()

@@ -772,7 +777,7 @@ def describe_relation(
bq_table = self.get_bq_table(relation)
parser = BigQueryMaterializedViewConfig
else:
raise dbt.exceptions.DbtRuntimeError(
raise dbt.common.exceptions.DbtRuntimeError(
f"The method `BigQueryAdapter.describe_relation` is not implemented "
f"for the relation type: {relation.type}"
)
@@ -838,7 +843,7 @@ def string_add_sql(
elif location == "prepend":
return f"concat('{value}', {add_to})"
else:
raise dbt.exceptions.DbtRuntimeError(
raise dbt.common.exceptions.DbtRuntimeError(
f'Got an unexpected location value of "{location}"'
)
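
The constructor change above is the most visible API break in this file: BaseAdapter subclasses now receive a multiprocessing context alongside the config. A sketch of instantiation under the new signature (config construction elided; the PR's unit tests use a spawn context):

from multiprocessing import get_context

from dbt.adapters.bigquery import BigQueryAdapter

adapter = BigQueryAdapter(config, get_context("spawn"))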

2 changes: 1 addition & 1 deletion dbt/adapters/bigquery/python_submissions.py
@@ -1,7 +1,7 @@
import uuid
from typing import Dict, Union

from dbt.events import AdapterLogger
from dbt.adapters.events.logging import AdapterLogger

from dbt.adapters.base import PythonJobHelper
from google.api_core.future.polling import POLLING_PREDICATE
18 changes: 8 additions & 10 deletions dbt/adapters/bigquery/relation.py
@@ -2,7 +2,6 @@
from typing import FrozenSet, Optional, TypeVar

from itertools import chain, islice
from dbt.context.providers import RuntimeConfigObject
from dbt.adapters.base.relation import BaseRelation, ComponentName, InformationSchema
from dbt.adapters.relation_configs import RelationConfigChangeAction
from dbt.adapters.bigquery.relation_configs import (
@@ -12,10 +11,9 @@
BigQueryOptionsConfigChange,
BigQueryPartitionConfigChange,
)
from dbt.contracts.graph.nodes import ModelNode
from dbt.contracts.relation import RelationType
from dbt.exceptions import CompilationError
from dbt.utils import filter_null_values
from dbt.adapters.contracts.relation import RelationType, RelationConfig
from dbt.common.exceptions import CompilationError
from dbt.common.utils.dict import filter_null_values


Self = TypeVar("Self", bound="BigQueryRelation")
@@ -63,19 +61,19 @@ def dataset(self):
return self.schema

@classmethod
def materialized_view_from_model_node(
cls, model_node: ModelNode
def materialized_view_from_relation_config(
cls, relation_config: RelationConfig
) -> BigQueryMaterializedViewConfig:
return BigQueryMaterializedViewConfig.from_model_node(model_node) # type: ignore
return BigQueryMaterializedViewConfig.from_relation_config(relation_config) # type: ignore

@classmethod
def materialized_view_config_changeset(
cls,
existing_materialized_view: BigQueryMaterializedViewConfig,
runtime_config: RuntimeConfigObject,
relation_config: RelationConfig,
) -> Optional[BigQueryMaterializedViewConfigChangeset]:
config_change_collection = BigQueryMaterializedViewConfigChangeset()
new_materialized_view = cls.materialized_view_from_model_node(runtime_config.model)
new_materialized_view = cls.materialized_view_from_relation_config(relation_config)

if new_materialized_view.options != existing_materialized_view.options:
config_change_collection.options = BigQueryOptionsConfigChange(
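
materialized_view_config_changeset now accepts the generic RelationConfig protocol rather than dbt-core's RuntimeConfigObject, which is why the macro changes later in this diff pass config.model. A hedged sketch of the call flow (adapter, existing_relation, and model are assumed to exist):

existing = adapter.describe_relation(existing_relation)  # BigQueryMaterializedViewConfig
changeset = BigQueryRelation.materialized_view_config_changeset(
    existing_materialized_view=existing,
    relation_config=model,  # a RelationConfig, e.g. config.model in a macro
)
if changeset:
    ...  # apply the collected option/cluster/partition changes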
18 changes: 9 additions & 9 deletions dbt/adapters/bigquery/relation_configs/_base.py
@@ -1,17 +1,17 @@
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Dict

import agate
from dbt.adapters.base.relation import Policy
from dbt.adapters.relation_configs import RelationConfigBase
from google.cloud.bigquery import Table as BigQueryTable
from typing_extensions import Self

from dbt.adapters.bigquery.relation_configs._policies import (
BigQueryIncludePolicy,
BigQueryQuotePolicy,
)
from dbt.contracts.graph.nodes import ModelNode
from dbt.contracts.relation import ComponentName
from dbt.adapters.contracts.relation import ComponentName, RelationConfig


@dataclass(frozen=True, eq=True, unsafe_hash=True)
@@ -25,25 +25,25 @@ def quote_policy(cls) -> Policy:
return BigQueryQuotePolicy()

@classmethod
def from_model_node(cls, model_node: ModelNode) -> "BigQueryBaseRelationConfig":
relation_config = cls.parse_model_node(model_node)
relation = cls.from_dict(relation_config)
def from_relation_config(cls, relation_config: RelationConfig) -> Self:
relation_config_dict = cls.parse_relation_config(relation_config)
relation = cls.from_dict(relation_config_dict)
return relation # type: ignore

@classmethod
def parse_model_node(cls, model_node: ModelNode) -> dict:
def parse_relation_config(cls, relation_config: RelationConfig) -> Dict:
raise NotImplementedError(
"`parse_model_node()` needs to be implemented on this RelationConfigBase instance"
)

@classmethod
def from_bq_table(cls, table: BigQueryTable) -> "BigQueryBaseRelationConfig":
def from_bq_table(cls, table: BigQueryTable) -> Self:
relation_config = cls.parse_bq_table(table)
relation = cls.from_dict(relation_config)
return relation # type: ignore

@classmethod
def parse_bq_table(cls, table: BigQueryTable) -> dict:
def parse_bq_table(cls, table: BigQueryTable) -> Dict:
raise NotImplementedError("`parse_bq_table()` is not implemented for this relation type")

@classmethod
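
The base class keeps a small template-method contract: from_relation_config() and from_bq_table() handle dict-to-dataclass construction, while each subclass overrides parse_relation_config() and parse_bq_table(). Roughly, for a concrete config class such as BigQueryMaterializedViewConfig:

# dbt-side description of the model, via the RelationConfig protocol
dbt_config = BigQueryMaterializedViewConfig.from_relation_config(relation_config)

# database-side description of the existing object, via the BigQuery client
bq_config = BigQueryMaterializedViewConfig.from_bq_table(bq_table)

# both are frozen dataclasses, so the two views can be compared field by field
options_changed = dbt_config.options != bq_config.options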
12 changes: 6 additions & 6 deletions dbt/adapters/bigquery/relation_configs/_cluster.py
@@ -2,8 +2,9 @@
from typing import Any, Dict, FrozenSet, Optional

from dbt.adapters.relation_configs import RelationConfigChange
from dbt.contracts.graph.nodes import ModelNode
from dbt.adapters.contracts.relation import RelationConfig
from google.cloud.bigquery import Table as BigQueryTable
from typing_extensions import Self

from dbt.adapters.bigquery.relation_configs._base import BigQueryBaseRelationConfig

@@ -22,16 +23,15 @@ class BigQueryClusterConfig(BigQueryBaseRelationConfig):
fields: FrozenSet[str]

@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> "BigQueryClusterConfig":
def from_dict(cls, config_dict: Dict[str, Any]) -> Self:
kwargs_dict = {"fields": config_dict.get("fields")}
cluster: "BigQueryClusterConfig" = super().from_dict(kwargs_dict) # type: ignore
return cluster
return super().from_dict(kwargs_dict) # type: ignore

@classmethod
def parse_model_node(cls, model_node: ModelNode) -> Dict[str, Any]:
def parse_relation_config(cls, relation_config: RelationConfig) -> Dict[str, Any]:
config_dict = {}

if cluster_by := model_node.config.extra.get("cluster_by"):
if cluster_by := relation_config.config.extra.get("cluster_by"): # type: ignore
# users may input a single field as a string
if isinstance(cluster_by, str):
cluster_by = [cluster_by]
26 changes: 15 additions & 11 deletions dbt/adapters/bigquery/relation_configs/_materialized_view.py
@@ -1,8 +1,10 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional

from dbt.contracts.graph.nodes import ModelNode
from dbt.contracts.relation import ComponentName
from dbt.adapters.contracts.relation import (
RelationConfig,
ComponentName,
)
from google.cloud.bigquery import Table as BigQueryTable

from dbt.adapters.bigquery.relation_configs._base import BigQueryBaseRelationConfig
@@ -63,21 +65,23 @@ def from_dict(cls, config_dict: Dict[str, Any]) -> "BigQueryMaterializedViewConf
return materialized_view

@classmethod
def parse_model_node(cls, model_node: ModelNode) -> Dict[str, Any]:
def parse_relation_config(cls, relation_config: RelationConfig) -> Dict[str, Any]:
config_dict = {
"table_id": model_node.identifier,
"dataset_id": model_node.schema,
"project_id": model_node.database,
"table_id": relation_config.identifier,
"dataset_id": relation_config.schema,
"project_id": relation_config.database,
# despite this being a foreign object, there will always be options because of defaults
"options": BigQueryOptionsConfig.parse_model_node(model_node),
"options": BigQueryOptionsConfig.parse_relation_config(relation_config),
}

# optional
if "partition_by" in model_node.config:
config_dict.update({"partition": PartitionConfig.parse_model_node(model_node)})
if "partition_by" in relation_config.config:
config_dict.update({"partition": PartitionConfig.parse_model_node(relation_config)})

if "cluster_by" in model_node.config:
config_dict.update({"cluster": BigQueryClusterConfig.parse_model_node(model_node)})
if "cluster_by" in relation_config.config:
config_dict.update(
{"cluster": BigQueryClusterConfig.parse_relation_config(relation_config)}
)

return config_dict

17 changes: 10 additions & 7 deletions dbt/adapters/bigquery/relation_configs/_options.py
Original file line number Diff line number Diff line change
@@ -3,8 +3,9 @@
from typing import Any, Dict, Optional

from dbt.adapters.relation_configs import RelationConfigChange
from dbt.contracts.graph.nodes import ModelNode
from dbt.adapters.contracts.relation import RelationConfig
from google.cloud.bigquery import Table as BigQueryTable
from typing_extensions import Self

from dbt.adapters.bigquery.relation_configs._base import BigQueryBaseRelationConfig
from dbt.adapters.bigquery.utility import bool_setting, float_setting, sql_escape
@@ -78,7 +79,7 @@ def formatted_option(name: str) -> Optional[Any]:
return options

@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> "BigQueryOptionsConfig":
def from_dict(cls, config_dict: Dict[str, Any]) -> Self:
setting_formatters = {
"enable_refresh": bool_setting,
"refresh_interval_minutes": float_setting,
@@ -102,13 +103,13 @@ def formatted_setting(name: str) -> Any:
if kwargs_dict["enable_refresh"] is False:
kwargs_dict.update({"refresh_interval_minutes": None, "max_staleness": None})

options: "BigQueryOptionsConfig" = super().from_dict(kwargs_dict) # type: ignore
options: Self = super().from_dict(kwargs_dict) # type: ignore
return options

@classmethod
def parse_model_node(cls, model_node: ModelNode) -> Dict[str, Any]:
def parse_relation_config(cls, relation_config: RelationConfig) -> Dict[str, Any]:
config_dict = {
option: model_node.config.extra.get(option)
option: relation_config.config.extra.get(option) # type: ignore
for option in [
"enable_refresh",
"refresh_interval_minutes",
@@ -121,11 +122,13 @@ def parse_model_node(cls, model_node: ModelNode) -> Dict[str, Any]:
}

# update dbt-specific versions of these settings
if hours_to_expiration := model_node.config.extra.get("hours_to_expiration"):
if hours_to_expiration := relation_config.config.extra.get( # type: ignore
"hours_to_expiration"
): # type: ignore
config_dict.update(
{"expiration_timestamp": datetime.now() + timedelta(hours=hours_to_expiration)}
)
if not model_node.config.persist_docs:
if not relation_config.config.persist_docs: # type: ignore
del config_dict["description"]

return config_dict
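
For reference, the hours_to_expiration branch converts dbt's relative setting into the absolute expiration_timestamp that BigQuery stores. A worked example with a hypothetical config value:

from datetime import datetime, timedelta

extra = {"hours_to_expiration": 12}  # illustrative model config
if hours_to_expiration := extra.get("hours_to_expiration"):
    # a run at 2024-01-10 09:00 yields an expiration of 2024-01-10 21:00
    expiration_timestamp = datetime.now() + timedelta(hours=hours_to_expiration)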
16 changes: 9 additions & 7 deletions dbt/adapters/bigquery/relation_configs/_partition.py
@@ -1,10 +1,10 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import dbt.common.exceptions
from dbt.adapters.relation_configs import RelationConfigChange
from dbt.contracts.graph.nodes import ModelNode
from dbt.dataclass_schema import dbtClassMixin, ValidationError
import dbt.exceptions
from dbt.adapters.contracts.relation import RelationConfig
from dbt.common.dataclass_schema import dbtClassMixin, ValidationError
from google.cloud.bigquery.table import Table as BigQueryTable


@@ -92,24 +92,26 @@ def parse(cls, raw_partition_by) -> Optional["PartitionConfig"]:
}
)
except ValidationError as exc:
raise dbt.exceptions.DbtValidationError("Could not parse partition config") from exc
raise dbt.common.exceptions.base.DbtValidationError(
"Could not parse partition config"
) from exc
except TypeError:
raise dbt.exceptions.CompilationError(
raise dbt.common.exceptions.CompilationError(
f"Invalid partition_by config:\n"
f" Got: {raw_partition_by}\n"
f' Expected a dictionary with "field" and "data_type" keys'
)

@classmethod
def parse_model_node(cls, model_node: ModelNode) -> Dict[str, Any]:
def parse_model_node(cls, relation_config: RelationConfig) -> Dict[str, Any]:
"""
Parse model node into a raw config for `PartitionConfig.parse`
- Note:
This doesn't currently collect `time_ingestion_partitioning` and `copy_partitions`
because this was built for materialized views, which do not support those settings.
"""
config_dict = model_node.config.extra.get("partition_by")
config_dict = relation_config.config.extra.get("partition_by") # type: ignore
if "time_ingestion_partitioning" in config_dict:
del config_dict["time_ingestion_partitioning"]
if "copy_partitions" in config_dict:
4 changes: 2 additions & 2 deletions dbt/adapters/bigquery/utility.py
@@ -1,7 +1,7 @@
import json
from typing import Any, Optional

import dbt.exceptions
import dbt.common.exceptions


def bool_setting(value: Optional[Any] = None) -> Optional[bool]:
@@ -41,5 +41,5 @@ def float_setting(value: Optional[Any] = None) -> Optional[float]:

def sql_escape(string):
if not isinstance(string, str):
raise dbt.exceptions.CompilationError(f"cannot escape a non-string: {string}")
raise dbt.common.exceptions.CompilationError(f"cannot escape a non-string: {string}")
return json.dumps(string)[1:-1]
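
sql_escape itself is unchanged apart from the exception's new home; it leans on json.dumps for escaping and strips the surrounding quotes:

# json.dumps('line one\nline "two"') returns '"line one\\nline \\"two\\""';
# slicing [1:-1] drops the outer quotes:
sql_escape('line one\nline "two"')  # -> 'line one\\nline \\"two\\"'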
(macro file — filename not captured in this view)
@@ -20,6 +20,6 @@

{% macro bigquery__get_materialized_view_configuration_changes(existing_relation, new_config) %}
{% set _existing_materialized_view = adapter.describe_relation(existing_relation) %}
{% set _configuration_changes = existing_relation.materialized_view_config_changeset(_existing_materialized_view, new_config) %}
{% set _configuration_changes = existing_relation.materialized_view_config_changeset(_existing_materialized_view, new_config.model) %}
{% do return(_configuration_changes) %}
{% endmacro %}
(macro file — filename not captured in this view)
@@ -1,6 +1,6 @@
{% macro bigquery__get_create_materialized_view_as_sql(relation, sql) %}

{%- set materialized_view = adapter.Relation.materialized_view_from_model_node(config.model) -%}
{%- set materialized_view = adapter.Relation.materialized_view_from_relation_config(config.model) -%}

create materialized view if not exists {{ relation }}
{% if materialized_view.partition %}{{ partition_by(materialized_view.partition) }}{% endif %}
(macro file — filename not captured in this view)
@@ -1,6 +1,6 @@
{% macro bigquery__get_replace_materialized_view_as_sql(relation, sql) %}

{%- set materialized_view = adapter.Relation.materialized_view_from_model_node(config.model) -%}
{%- set materialized_view = adapter.Relation.materialized_view_from_relation_config(config.model) -%}

create or replace materialized view if not exists {{ relation }}
{% if materialized_view.partition %}{{ partition_by(materialized_view.partition) }}{% endif %}
1 change: 0 additions & 1 deletion dev-requirements.txt
@@ -2,7 +2,6 @@
# TODO: how to automate switching from develop to version branches?
git+https://github.com/dbt-labs/dbt-core.git#egg=dbt-core&subdirectory=core
git+https://github.com/dbt-labs/dbt-core.git#egg=dbt-tests-adapter&subdirectory=tests/adapter

# if version 1.x or greater -> pin to major version
# if version 0.x -> pin to minor
black~=23.12
(functional test file — filename not captured in this view)
@@ -1,7 +1,7 @@
import pytest

from dbt.adapters.base.relation import BaseRelation
from dbt.contracts.relation import RelationType
from dbt.adapters.contracts.relation import RelationType
from dbt.tests.util import get_connection, run_dbt

from dbt.adapters.bigquery.relation_configs import BigQueryMaterializedViewConfig
(functional test file — filename not captured in this view)
@@ -3,7 +3,7 @@
import pytest

from dbt.adapters.base.relation import BaseRelation
from dbt.contracts.relation import RelationType
from dbt.adapters.contracts.relation import RelationType
from dbt.tests.adapter.materialized_view.files import MY_TABLE, MY_VIEW
from dbt.tests.util import (
get_connection,
84 changes: 55 additions & 29 deletions tests/unit/test_bigquery_adapter.py
@@ -1,3 +1,6 @@
from multiprocessing import get_context
from unittest import mock

import agate
import decimal
import string
@@ -7,21 +10,28 @@
import unittest
from unittest.mock import patch, MagicMock, create_autospec

import dbt.dataclass_schema
import dbt.common.dataclass_schema
import dbt.common.exceptions.base
from dbt.adapters.bigquery.relation_configs import PartitionConfig
from dbt.adapters.bigquery import BigQueryAdapter, BigQueryRelation
from dbt.adapters.bigquery import Plugin as BigQueryPlugin
from google.cloud.bigquery.table import Table
from dbt.adapters.bigquery.connections import _sanitize_label, _VALIDATE_LABEL_LENGTH_LIMIT
from dbt.adapters.base.query_headers import MacroQueryStringSetter
from dbt.clients import agate_helper
import dbt.exceptions
from dbt.common.clients import agate_helper
import dbt.common.exceptions
from dbt.context.manifest import generate_query_header_context
from dbt.contracts.files import FileHash
from dbt.contracts.graph.manifest import ManifestStateCheck
from dbt.logger import GLOBAL_LOGGER as logger # noqa
from dbt.context.providers import RuntimeConfigObject
from dbt.context.providers import RuntimeConfigObject, generate_runtime_macro_context

from google.cloud.bigquery import AccessEntry

from .utils import config_from_parts_or_dicts, inject_adapter, TestAdapterConversions
from .utils import (
config_from_parts_or_dicts,
inject_adapter,
TestAdapterConversions,
load_internal_manifest_macros,
)


def _bq_conn():
@@ -146,6 +156,21 @@ def setUp(self):
}
self.qh_patch = None

@mock.patch("dbt.parser.manifest.ManifestLoader.build_manifest_state_check")
def _mock_state_check(self):
all_projects = self.all_projects
return ManifestStateCheck(
vars_hash=FileHash.from_contents("vars"),
project_hashes={name: FileHash.from_contents(name) for name in all_projects},
profile_hash=FileHash.from_contents("profile"),
)

self.load_state_check = mock.patch(
"dbt.parser.manifest.ManifestLoader.build_manifest_state_check"
)
self.mock_state_check = self.load_state_check.start()
self.mock_state_check.side_effect = _mock_state_check

def tearDown(self):
if self.qh_patch:
self.qh_patch.stop()
@@ -155,20 +180,22 @@ def get_adapter(self, target) -> BigQueryAdapter:
project = self.project_cfg.copy()
profile = self.raw_profile.copy()
profile["target"] = target

config = config_from_parts_or_dicts(
project=project,
profile=profile,
)
adapter = BigQueryAdapter(config)

adapter.connections.query_header = MacroQueryStringSetter(config, MagicMock(macros={}))
adapter = BigQueryAdapter(config, get_context("spawn"))
adapter.set_macro_resolver(load_internal_manifest_macros(config))
adapter.set_macro_context_generator(generate_runtime_macro_context)
adapter.connections.set_query_header(
generate_query_header_context(config, adapter.get_macro_resolver())
)

self.qh_patch = patch.object(adapter.connections.query_header, "add")
self.mock_query_header_add = self.qh_patch.start()
self.mock_query_header_add.side_effect = lambda q: "/* dbt */\n{}".format(q)

inject_adapter(adapter, BigQueryPlugin)
inject_adapter(adapter)
return adapter


@@ -187,7 +214,7 @@ def test_acquire_connection_oauth_no_project_validations(
connection = adapter.acquire_connection("dummy")
self.assertEqual(connection.type, "bigquery")

except dbt.exceptions.DbtValidationError as e:
except dbt.common.exceptions.base.DbtValidationError as e:
self.fail("got DbtValidationError: {}".format(str(e)))

except BaseException:
@@ -204,7 +231,7 @@ def test_acquire_connection_oauth_validations(self, mock_open_connection):
connection = adapter.acquire_connection("dummy")
self.assertEqual(connection.type, "bigquery")

except dbt.exceptions.DbtValidationError as e:
except dbt.common.exceptions.base.DbtValidationError as e:
self.fail("got DbtValidationError: {}".format(str(e)))

except BaseException:
@@ -228,7 +255,7 @@ def test_acquire_connection_dataproc_serverless(
connection = adapter.acquire_connection("dummy")
self.assertEqual(connection.type, "bigquery")

except dbt.exceptions.ValidationException as e:
except dbt.common.exceptions.ValidationException as e:
self.fail("got ValidationException: {}".format(str(e)))

except BaseException:
@@ -245,7 +272,7 @@ def test_acquire_connection_service_account_validations(self, mock_open_connecti
connection = adapter.acquire_connection("dummy")
self.assertEqual(connection.type, "bigquery")

except dbt.exceptions.DbtValidationError as e:
except dbt.common.exceptions.base.DbtValidationError as e:
self.fail("got DbtValidationError: {}".format(str(e)))

except BaseException:
@@ -262,7 +289,7 @@ def test_acquire_connection_oauth_token_validations(self, mock_open_connection):
connection = adapter.acquire_connection("dummy")
self.assertEqual(connection.type, "bigquery")

except dbt.exceptions.DbtValidationError as e:
except dbt.common.exceptions.base.DbtValidationError as e:
self.fail("got DbtValidationError: {}".format(str(e)))

except BaseException:
@@ -279,7 +306,7 @@ def test_acquire_connection_oauth_credentials_validations(self, mock_open_connec
connection = adapter.acquire_connection("dummy")
self.assertEqual(connection.type, "bigquery")

except dbt.exceptions.DbtValidationError as e:
except dbt.common.exceptions.base.DbtValidationError as e:
self.fail("got DbtValidationError: {}".format(str(e)))

except BaseException:
@@ -298,7 +325,7 @@ def test_acquire_connection_impersonated_service_account_validations(
connection = adapter.acquire_connection("dummy")
self.assertEqual(connection.type, "bigquery")

except dbt.exceptions.DbtValidationError as e:
except dbt.common.exceptions.base.DbtValidationError as e:
self.fail("got DbtValidationError: {}".format(str(e)))

except BaseException:
@@ -316,7 +343,7 @@ def test_acquire_connection_priority(self, mock_open_connection):
self.assertEqual(connection.type, "bigquery")
self.assertEqual(connection.credentials.priority, "batch")

except dbt.exceptions.DbtValidationError as e:
except dbt.common.exceptions.base.DbtValidationError as e:
self.fail("got DbtValidationError: {}".format(str(e)))

mock_open_connection.assert_not_called()
@@ -331,7 +358,7 @@ def test_acquire_connection_maximum_bytes_billed(self, mock_open_connection):
self.assertEqual(connection.type, "bigquery")
self.assertEqual(connection.credentials.maximum_bytes_billed, 0)

except dbt.exceptions.DbtValidationError as e:
except dbt.common.exceptions.base.DbtValidationError as e:
self.fail("got DbtValidationError: {}".format(str(e)))

mock_open_connection.assert_not_called()
@@ -379,7 +406,7 @@ def test_location_user_agent(self, mock_bq, mock_auth_default):


class HasUserAgent:
PAT = re.compile(r"dbt-\d+\.\d+\.\d+((a|b|rc)\d+)?")
PAT = re.compile(r"dbt-bigquery-\d+\.\d+\.\d+((a|b|rc)\d+)?")

def __eq__(self, other):
compare = getattr(other, "user_agent", "")
@@ -482,7 +509,7 @@ def test_invalid_relation(self):
},
"quote_policy": {"identifier": False, "schema": True},
}
with self.assertRaises(dbt.dataclass_schema.ValidationError):
with self.assertRaises(dbt.common.dataclass_schema.ValidationError):
BigQueryRelation.validate(kwargs)


@@ -554,10 +581,10 @@ def test_copy_table_materialization_incremental(self):
def test_parse_partition_by(self):
adapter = self.get_adapter("oauth")

with self.assertRaises(dbt.exceptions.DbtValidationError):
with self.assertRaises(dbt.common.exceptions.base.DbtValidationError):
adapter.parse_partition_by("date(ts)")

with self.assertRaises(dbt.exceptions.DbtValidationError):
with self.assertRaises(dbt.common.exceptions.base.DbtValidationError):
adapter.parse_partition_by("ts")

self.assertEqual(
@@ -709,7 +736,7 @@ def test_parse_partition_by(self):
)

# Invalid, should raise an error
with self.assertRaises(dbt.exceptions.DbtValidationError):
with self.assertRaises(dbt.common.exceptions.base.DbtValidationError):
adapter.parse_partition_by({})

# passthrough
@@ -778,8 +805,7 @@ def test_view_kms_key_name(self):

class TestBigQueryFilterCatalog(unittest.TestCase):
def test__catalog_filter_table(self):
manifest = MagicMock()
manifest.get_used_schemas.return_value = [["a", "B"], ["a", "1234"]]
used_schemas = [["a", "B"], ["a", "1234"]]
column_names = ["table_name", "table_database", "table_schema", "something"]
rows = [
["foo", "a", "b", "1234"], # include
@@ -789,7 +815,7 @@ def test__catalog_filter_table(self):
]
table = agate.Table(rows, column_names, agate_helper.DEFAULT_TYPE_TESTER)

result = BigQueryAdapter._catalog_filter_table(table, manifest)
result = BigQueryAdapter._catalog_filter_table(table, used_schemas)
assert len(result) == 3
for row in result.rows:
assert isinstance(row["table_schema"], str)
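
The reworked get_adapter above is the canonical recipe in this PR for standing up an adapter in tests without a full dbt-core runtime. Condensed, with the helper names used in this diff:

from multiprocessing import get_context

from dbt.context.manifest import generate_query_header_context
from dbt.context.providers import generate_runtime_macro_context

adapter = BigQueryAdapter(config, get_context("spawn"))
adapter.set_macro_resolver(load_internal_manifest_macros(config))
adapter.set_macro_context_generator(generate_runtime_macro_context)
adapter.connections.set_query_header(
    generate_query_header_context(config, adapter.get_macro_resolver())
)
inject_adapter(adapter)  # the plugin argument was dropped in this PR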
8 changes: 4 additions & 4 deletions tests/unit/test_bigquery_connection_manager.py
@@ -6,20 +6,20 @@
from requests.exceptions import ConnectionError
from unittest.mock import patch, MagicMock, Mock, ANY

import dbt.dataclass_schema
import dbt.common.dataclass_schema

from dbt.adapters.bigquery import BigQueryCredentials
from dbt.adapters.bigquery import BigQueryRelation
from dbt.adapters.bigquery.connections import BigQueryConnectionManager
import dbt.exceptions
import dbt.common.exceptions
from dbt.logger import GLOBAL_LOGGER as logger # noqa


class TestBigQueryConnectionManager(unittest.TestCase):
def setUp(self):
credentials = Mock(BigQueryCredentials)
profile = Mock(query_comment=None, credentials=credentials)
self.connections = BigQueryConnectionManager(profile=profile)
self.connections = BigQueryConnectionManager(profile=profile, mp_context=Mock())

self.mock_client = Mock(dbt.adapters.bigquery.impl.google.cloud.bigquery.Client)
self.mock_connection = MagicMock()
@@ -127,7 +127,7 @@ def test_query_and_results_timeout(self, mock_bq):
self.mock_client.query = Mock(
return_value=Mock(result=lambda *args, **kwargs: time.sleep(4))
)
with pytest.raises(dbt.exceptions.DbtRuntimeError) as exc:
with pytest.raises(dbt.common.exceptions.DbtRuntimeError) as exc:
self.connections._query_and_results(
self.mock_client,
"sql",
32 changes: 15 additions & 17 deletions tests/unit/utils.py
@@ -9,7 +9,8 @@

import agate
import pytest
from dbt.dataclass_schema import ValidationError

from dbt.common.dataclass_schema import ValidationError
from dbt.config.project import PartialProject


@@ -123,19 +124,17 @@ def inject_plugin(plugin):


def inject_plugin_for(config):
# from dbt.adapters.postgres import Plugin, PostgresAdapter
from dbt.adapters.factory import FACTORY

FACTORY.load_plugin(config.credentials.type)
adapter = FACTORY.get_adapter(config)
return adapter


def inject_adapter(value, plugin):
def inject_adapter(value):
"""Inject the given adapter into the adapter factory, so your hand-crafted
artisanal adapter will be available from get_adapter() as if dbt loaded it.
"""
inject_plugin(plugin)
from dbt.adapters.factory import FACTORY

key = value.type()
@@ -229,7 +228,7 @@ def assert_fails_validation(dct, cls):


def generate_name_macros(package):
from dbt.contracts.graph.parsed import ParsedMacro
from dbt.contracts.graph.nodes import Macro
from dbt.node_types import NodeType

name_sql = {}
@@ -243,13 +242,12 @@ def generate_name_macros(package):
name_sql[name] = sql

for name, sql in name_sql.items():
pm = ParsedMacro(
pm = Macro(
name=name,
resource_type=NodeType.Macro,
unique_id=f"macro.{package}.{name}",
package_name=package,
original_file_path=normalize("macros/macro.sql"),
root_path="./dbt_packages/root",
path=normalize("macros/macro.sql"),
macro_sql=sql,
)
@@ -258,7 +256,7 @@ def generate_name_macros(package):

class TestAdapterConversions(TestCase):
def _get_tester_for(self, column_type):
from dbt.clients import agate_helper
from dbt.common.clients import agate_helper

if column_type is agate.TimeDelta: # dbt never makes this!
return agate.TimeDelta()
@@ -280,7 +278,7 @@ def _make_table_of(self, rows, column_types):


def MockMacro(package, name="my_macro", **kwargs):
from dbt.contracts.graph.parsed import ParsedMacro
from dbt.contracts.graph.nodes import Macro
from dbt.node_types import NodeType

mock_kwargs = dict(
@@ -292,7 +290,7 @@ def MockMacro(package, name="my_macro", **kwargs):

mock_kwargs.update(kwargs)

macro = mock.MagicMock(spec=ParsedMacro, **mock_kwargs)
macro = mock.MagicMock(spec=Macro, **mock_kwargs)
macro.name = name
return macro

@@ -311,10 +309,10 @@ def MockGenerateMacro(package, component="some_component", **kwargs):

def MockSource(package, source_name, name, **kwargs):
from dbt.node_types import NodeType
from dbt.contracts.graph.parsed import ParsedSourceDefinition
from dbt.contracts.graph.nodes import SourceDefinition

src = mock.MagicMock(
__class__=ParsedSourceDefinition,
__class__=SourceDefinition,
resource_type=NodeType.Source,
source_name=source_name,
package_name=package,
@@ -328,14 +326,14 @@ def MockSource(package, source_name, name, **kwargs):

def MockNode(package, name, resource_type=None, **kwargs):
from dbt.node_types import NodeType
from dbt.contracts.graph.parsed import ParsedModelNode, ParsedSeedNode
from dbt.contracts.graph.nodes import ModelNode, SeedNode

if resource_type is None:
resource_type = NodeType.Model
if resource_type == NodeType.Model:
cls = ParsedModelNode
cls = ModelNode
elif resource_type == NodeType.Seed:
cls = ParsedSeedNode
cls = SeedNode
else:
raise ValueError(f"I do not know how to handle {resource_type}")
node = mock.MagicMock(
@@ -352,10 +350,10 @@ def MockNode(package, name, resource_type=None, **kwargs):

def MockDocumentation(package, name, **kwargs):
from dbt.node_types import NodeType
from dbt.contracts.graph.parsed import ParsedDocumentation
from dbt.contracts.graph.nodes import Documentation

doc = mock.MagicMock(
__class__=ParsedDocumentation,
__class__=Documentation,
resource_type=NodeType.Documentation,
package_name=package,
search_name=name,
2 changes: 1 addition & 1 deletion tox.ini
@@ -28,7 +28,7 @@ passenv =
DD_ENV
DD_SERVICE
commands =
bigquery: {envpython} -m pytest {posargs} -vv tests/functional -k "not TestPython" --profile service_account
bigquery: {envpython} -m pytest -n auto {posargs} -vv tests/functional -k "not TestPython" --profile service_account
deps =
-rdev-requirements.txt
-e.
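
Note: the added -n auto flag runs the functional suite in parallel across available cores; this relies on pytest-xdist, which is presumably provided through dev-requirements.txt.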