Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexandro Hou committed Nov 23, 2024
2 parents fddcb83 + 816c808 commit 3b0cd1a
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 96 deletions.
6 changes: 3 additions & 3 deletions dbt-athena/test.env.example → dbt-athena/.env.example
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
DBT_TEST_ATHENA_S3_STAGING_DIR=
DBT_TEST_ATHENA_S3_TMP_TABLE_DIR=
DBT_TEST_ATHENA_REGION_NAME=
DBT_TEST_ATHENA_THREADS=
DBT_TEST_ATHENA_POLL_INTERVAL=
DBT_TEST_ATHENA_DATABASE=
DBT_TEST_ATHENA_SCHEMA=
DBT_TEST_ATHENA_WORK_GROUP=
DBT_TEST_ATHENA_THREADS=
DBT_TEST_ATHENA_POLL_INTERVAL=
DBT_TEST_ATHENA_NUM_RETRIES=
DBT_TEST_ATHENA_AWS_PROFILE_NAME=
DBT_TEST_ATHENA_SPARK_WORK_GROUP=
1 change: 0 additions & 1 deletion dbt-athena/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ check-sdist = [
]

[tool.pytest]
env_files = ["test.env"]
testpaths = [
"tests/unit",
"tests/functional",
Expand Down
89 changes: 0 additions & 89 deletions dbt-athena/tests/conftest.py

This file was deleted.

26 changes: 26 additions & 0 deletions dbt-athena/tests/functional/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os

import pytest

# Import the functional fixtures as a plugin
# Note: fixtures with session scope need to be local
pytest_plugins = ["dbt.tests.fixtures.project"]


# The profile dictionary, used to write out profiles.yml
@pytest.fixture(scope="class")
def dbt_profile_target():
return {
"type": "athena",
"s3_staging_dir": os.getenv("DBT_TEST_ATHENA_S3_STAGING_DIR"),
"s3_tmp_table_dir": os.getenv("DBT_TEST_ATHENA_S3_TMP_TABLE_DIR"),
"region_name": os.getenv("DBT_TEST_ATHENA_REGION_NAME"),
"database": os.getenv("DBT_TEST_ATHENA_DATABASE"),
"schema": os.getenv("DBT_TEST_ATHENA_SCHEMA"),
"work_group": os.getenv("DBT_TEST_ATHENA_WORK_GROUP"),
"threads": int(os.getenv("DBT_TEST_ATHENA_THREADS", "1")),
"poll_interval": float(os.getenv("DBT_TEST_ATHENA_POLL_INTERVAL", "1.0")),
"num_retries": int(os.getenv("DBT_TEST_ATHENA_NUM_RETRIES", "2")),
"aws_profile_name": os.getenv("DBT_TEST_ATHENA_AWS_PROFILE_NAME") or None,
"spark_work_group": os.getenv("DBT_TEST_ATHENA_SPARK_WORK_GROUP"),
}
60 changes: 57 additions & 3 deletions dbt-athena/tests/unit/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,25 @@
from io import StringIO
import os
from unittest.mock import MagicMock, patch

import boto3
import pytest

from .constants import AWS_REGION
from .utils import MockAWSService
from dbt_common.events import get_event_manager
from dbt_common.events.base_types import EventLevel
from dbt_common.events.logger import LineFormat, LoggerConfig, NoFilter

from dbt.adapters.athena import connections
from dbt.adapters.athena.connections import AthenaCredentials

from tests.unit.utils import MockAWSService
from tests.unit import constants


@pytest.fixture(scope="class")
def athena_client():
with patch.object(boto3.session.Session, "client", return_value=MagicMock()) as mock_athena_client:
return mock_athena_client


@pytest.fixture(scope="function")
Expand All @@ -13,9 +29,47 @@ def aws_credentials():
os.environ["AWS_SECRET_ACCESS_KEY"] = "testing"
os.environ["AWS_SECURITY_TOKEN"] = "testing"
os.environ["AWS_SESSION_TOKEN"] = "testing"
os.environ["AWS_DEFAULT_REGION"] = AWS_REGION
os.environ["AWS_DEFAULT_REGION"] = constants.AWS_REGION


@patch.object(connections, "AthenaCredentials")
@pytest.fixture(scope="class")
def athena_credentials():
return AthenaCredentials(
database=constants.DATA_CATALOG_NAME,
schema=constants.DATABASE_NAME,
s3_staging_dir=constants.S3_STAGING_DIR,
region_name=constants.AWS_REGION,
work_group=constants.ATHENA_WORKGROUP,
spark_work_group=constants.SPARK_WORKGROUP,
)


@pytest.fixture()
def mock_aws_service(aws_credentials) -> MockAWSService:
return MockAWSService()


@pytest.fixture(scope="function")
def dbt_error_caplog() -> StringIO:
return _setup_custom_caplog("dbt_error", EventLevel.ERROR)


@pytest.fixture(scope="function")
def dbt_debug_caplog() -> StringIO:
return _setup_custom_caplog("dbt_debug", EventLevel.DEBUG)


def _setup_custom_caplog(name: str, level: EventLevel):
string_buf = StringIO()
capture_config = LoggerConfig(
name=name,
level=level,
use_colors=False,
line_format=LineFormat.PlainText,
filter=NoFilter,
output_stream=string_buf,
)
event_manager = get_event_manager()
event_manager.add_logger(capture_config)
return string_buf

0 comments on commit 3b0cd1a

Please sign in to comment.