Skip to content

Commit

Permalink
Merge branch 'main' into require-alias-false
Browse files Browse the repository at this point in the history
  • Loading branch information
Dewwi authored Nov 26, 2024
2 parents 5967de5 + a425cc6 commit 8b01938
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 102 deletions.
6 changes: 3 additions & 3 deletions dbt-athena/test.env.example → dbt-athena/.env.example
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
DBT_TEST_ATHENA_S3_STAGING_DIR=
DBT_TEST_ATHENA_S3_TMP_TABLE_DIR=
DBT_TEST_ATHENA_REGION_NAME=
DBT_TEST_ATHENA_THREADS=
DBT_TEST_ATHENA_POLL_INTERVAL=
DBT_TEST_ATHENA_DATABASE=
DBT_TEST_ATHENA_SCHEMA=
DBT_TEST_ATHENA_WORK_GROUP=
DBT_TEST_ATHENA_THREADS=
DBT_TEST_ATHENA_POLL_INTERVAL=
DBT_TEST_ATHENA_NUM_RETRIES=
DBT_TEST_ATHENA_AWS_PROFILE_NAME=
DBT_TEST_ATHENA_SPARK_WORK_GROUP=
1 change: 0 additions & 1 deletion dbt-athena/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ check-sdist = [
]

[tool.pytest]
env_files = ["test.env"]
testpaths = [
"tests/unit",
"tests/functional",
Expand Down
9 changes: 4 additions & 5 deletions dbt-athena/src/dbt/adapters/athena/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def clean_up_table(self, relation: AthenaRelation) -> None:

@available
def generate_unique_temporary_table_suffix(self, suffix_initial: str = "__dbt_tmp") -> str:
return f"{suffix_initial}_{str(uuid4())}"
return f"{suffix_initial}_{str(uuid4()).replace('-', '_')}"

def quote(self, identifier: str) -> str:
return f"{self.quote_character}{identifier}{self.quote_character}"
Expand Down Expand Up @@ -1209,22 +1209,21 @@ def _generate_snapshot_migration_sql(self, relation: AthenaRelation, table_colum
- Copy the content of the staging table to the final table
- Delete the staging table
"""
col_csv = f",\n{' ' * 16}".join(table_columns)
col_csv = f", \n{' ' * 16}".join(table_columns)
staging_relation = relation.incorporate(
path={"identifier": relation.identifier + "__dbt_tmp_migration_staging"}
)
ctas = dedent(
f"""\
select
{col_csv},
{col_csv} ,
dbt_snapshot_at as dbt_updated_at,
dbt_valid_from,
if(dbt_valid_to > cast('9000-01-01' as timestamp), null, dbt_valid_to) as dbt_valid_to,
dbt_scd_id
from {relation}
where dbt_change_type != 'delete'
;
"""
;"""
)
staging_sql = self.execute_macro(
"create_table_as", kwargs=dict(temporary=True, relation=staging_relation, compiled_code=ctas)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
{% set tmp_table_suffix = '__dbt_tmp' %}
{% endif %}

{% if unique_tmp_table_suffix == True and table_type == 'iceberg' %}
{% set tmp_table_suffix = adapter.generate_unique_temporary_table_suffix() %}
{% endif %}

{% set old_tmp_relation = adapter.get_relation(identifier=target_relation.identifier ~ tmp_table_suffix,
schema=schema,
database=database) %}
Expand Down
89 changes: 0 additions & 89 deletions dbt-athena/tests/conftest.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test__unique_tmp_table_suffix(self, project, capsys):
model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}"
expected_unique_table_name_re = (
r"unique_tmp_table_suffix__dbt_tmp_"
r"[0-9a-fA-F]{8}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{12}"
r"[0-9a-fA-F]{8}_[0-9a-fA-F]{4}_[0-9a-fA-F]{4}_[0-9a-fA-F]{4}_[0-9a-fA-F]{12}"
)

first_model_run = run_dbt(
Expand Down
26 changes: 26 additions & 0 deletions dbt-athena/tests/functional/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os

import pytest

# Import the functional fixtures as a plugin
# Note: fixtures with session scope need to be local
pytest_plugins = ["dbt.tests.fixtures.project"]


# The profile dictionary, used to write out profiles.yml
@pytest.fixture(scope="class")
def dbt_profile_target():
return {
"type": "athena",
"s3_staging_dir": os.getenv("DBT_TEST_ATHENA_S3_STAGING_DIR"),
"s3_tmp_table_dir": os.getenv("DBT_TEST_ATHENA_S3_TMP_TABLE_DIR"),
"region_name": os.getenv("DBT_TEST_ATHENA_REGION_NAME"),
"database": os.getenv("DBT_TEST_ATHENA_DATABASE"),
"schema": os.getenv("DBT_TEST_ATHENA_SCHEMA"),
"work_group": os.getenv("DBT_TEST_ATHENA_WORK_GROUP"),
"threads": int(os.getenv("DBT_TEST_ATHENA_THREADS", "1")),
"poll_interval": float(os.getenv("DBT_TEST_ATHENA_POLL_INTERVAL", "1.0")),
"num_retries": int(os.getenv("DBT_TEST_ATHENA_NUM_RETRIES", "2")),
"aws_profile_name": os.getenv("DBT_TEST_ATHENA_AWS_PROFILE_NAME") or None,
"spark_work_group": os.getenv("DBT_TEST_ATHENA_SPARK_WORK_GROUP"),
}
60 changes: 57 additions & 3 deletions dbt-athena/tests/unit/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,25 @@
from io import StringIO
import os
from unittest.mock import MagicMock, patch

import boto3
import pytest

from .constants import AWS_REGION
from .utils import MockAWSService
from dbt_common.events import get_event_manager
from dbt_common.events.base_types import EventLevel
from dbt_common.events.logger import LineFormat, LoggerConfig, NoFilter

from dbt.adapters.athena import connections
from dbt.adapters.athena.connections import AthenaCredentials

from tests.unit.utils import MockAWSService
from tests.unit import constants


@pytest.fixture(scope="class")
def athena_client():
with patch.object(boto3.session.Session, "client", return_value=MagicMock()) as mock_athena_client:
return mock_athena_client


@pytest.fixture(scope="function")
Expand All @@ -13,9 +29,47 @@ def aws_credentials():
os.environ["AWS_SECRET_ACCESS_KEY"] = "testing"
os.environ["AWS_SECURITY_TOKEN"] = "testing"
os.environ["AWS_SESSION_TOKEN"] = "testing"
os.environ["AWS_DEFAULT_REGION"] = AWS_REGION
os.environ["AWS_DEFAULT_REGION"] = constants.AWS_REGION


@patch.object(connections, "AthenaCredentials")
@pytest.fixture(scope="class")
def athena_credentials():
return AthenaCredentials(
database=constants.DATA_CATALOG_NAME,
schema=constants.DATABASE_NAME,
s3_staging_dir=constants.S3_STAGING_DIR,
region_name=constants.AWS_REGION,
work_group=constants.ATHENA_WORKGROUP,
spark_work_group=constants.SPARK_WORKGROUP,
)


@pytest.fixture()
def mock_aws_service(aws_credentials) -> MockAWSService:
return MockAWSService()


@pytest.fixture(scope="function")
def dbt_error_caplog() -> StringIO:
return _setup_custom_caplog("dbt_error", EventLevel.ERROR)


@pytest.fixture(scope="function")
def dbt_debug_caplog() -> StringIO:
return _setup_custom_caplog("dbt_debug", EventLevel.DEBUG)


def _setup_custom_caplog(name: str, level: EventLevel):
string_buf = StringIO()
capture_config = LoggerConfig(
name=name,
level=level,
use_colors=False,
line_format=LineFormat.PlainText,
filter=NoFilter,
output_stream=string_buf,
)
event_manager = get_event_manager()
event_manager.add_logger(capture_config)
return string_buf

0 comments on commit 8b01938

Please sign in to comment.