Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use hatch fmt, fix formatting issues #5

Merged
merged 2 commits into from
Jun 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ env:
jobs:
test:
strategy:
fail-fast: false
matrix:
os:
- ubuntu-latest
Expand All @@ -32,10 +33,9 @@ jobs:
3.12
cache: pip
- run: pip install hatch pre-commit
# TODO enable later
# - run: hatch fmt --check
- run: git fetch origin main
- run: pre-commit run --from-ref origin/main --to-ref HEAD
- run: hatch fmt --check
- run: |
hatch run dbt clean
hatch run dbt seed
Expand Down
2 changes: 2 additions & 0 deletions dbt_pumpkin/exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class PumpkinError(Exception):
pass
42 changes: 26 additions & 16 deletions dbt_pumpkin/pumpkin.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import json
import os
import shutil
import tempfile
from functools import cached_property
from pathlib import Path
from typing import Dict, List, Set, Union
from typing import Union

from ruamel.yaml import YAML

Expand All @@ -20,16 +22,20 @@
dbtRunnerResult,
default_project_dir,
)
from dbt_pumpkin.exception import PumpkinError

yaml = YAML(typ="safe")

Resource = Union[SourceDefinition, ModelNode, SnapshotNode, SeedNode]


class Pumpkin:

def __init__(
self, project_dir: str = None, profiles_dir: str = None, selects: List[str] = None, excludes: List[str] = None
self,
project_dir: str | None = None,
profiles_dir: str | None = None,
selects: list[str] | None = None,
excludes: list[str] | None = None,
) -> None:
self.project_dir = project_dir
self.profiles_dir = profiles_dir
Expand All @@ -52,7 +58,7 @@ def manifest(self) -> Manifest:
return res.result

@cached_property
def selected_resource_ids(self) -> Dict[str, Set[str]]:
def selected_resource_ids(self) -> dict[str, set[str]]:
"""
Returns a dictionary mapping resource type to a set of resource identifiers
"""
Expand All @@ -72,7 +78,7 @@ def selected_resource_ids(self) -> Dict[str, Set[str]]:
if not res.success:
raise res.exception

result: Dict[str, Set[str]] = {}
result: dict[str, set[str]] = {}
for raw_resource in res.result:
resource = json.loads(raw_resource)
resource_type = resource["resource_type"]
Expand All @@ -82,28 +88,30 @@ def selected_resource_ids(self) -> Dict[str, Set[str]]:
return result

@cached_property
def selected_resources(self) -> List[Resource]:
results: List[Resource] = []
def selected_resources(self) -> list[Resource]:
results: list[Resource] = []

for resource_type, resource_ids in self.selected_resource_ids.items():
resource_by_id = self.manifest.sources if resource_type == "source" else self.manifest.nodes
results += [resource_by_id[id] for id in resource_ids]
results += [resource_by_id[res_id] for res_id in resource_ids]

return results

@cached_property
def selected_resource_actual_schemas(self) -> Dict[str, List[ColumnInfo]]:
def selected_resource_actual_schemas(self) -> dict[str, list[ColumnInfo]]:
src_macros_path = Path(__file__).parent / "macros"

if not src_macros_path.exists() or not src_macros_path.is_dir():
raise Exception(f"Macros directory is not found or doesn't exist: {src_macros_path}")
msg = f"Macros directory is not found or doesn't exist: {src_macros_path}"
raise PumpkinError(msg)

project_dir = Path(self.project_dir or os.environ.get("DBT_PROJECT_DIR", None) or default_project_dir())

project_yml_path = project_dir / "dbt_project.yml"

if not project_yml_path.exists() or not project_yml_path.is_file():
raise Exception(f"dbt_project.ym is not found or doesn't exist: {project_yml_path}")
msg = f"dbt_project.yml is not found or doesn't exist: {project_yml_path}"
raise PumpkinError(msg)

operation_args = {
resource.unique_id: [resource.database, resource.schema, resource.identifier]
Expand All @@ -115,14 +123,14 @@ def selected_resource_actual_schemas(self) -> Dict[str, List[ColumnInfo]]:
"name": "dbt_pumpkin",
"version": "0.1.0",
"profile": project_yml["profile"],
# TODO copy vars?
# TODO: copy vars?
"vars": {
# workaround for too long CMD on Windows
"get_column_types_args": operation_args
},
}

jinja_log_messages: List[str] = []
jinja_log_messages: list[str] = []

def event_callback(event: EventMsg):
if event.info.name == "JinjaLogInfo":
Expand All @@ -144,11 +152,13 @@ def event_callback(event: EventMsg):
if not res.success:
raise res.exception

assert jinja_log_messages
if not jinja_log_messages:
msg = "No schema retrieved from database"
raise PumpkinError(msg)

column_types_response = json.loads(jinja_log_messages[0])

return {
id: [ColumnInfo(name=c["name"], data_type=c["data_type"]) for c in columns]
for id, columns in column_types_response.items()
res_id: [ColumnInfo(name=c["name"], data_type=c["data_type"]) for c in columns]
for res_id, columns in column_types_response.items()
}
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ dbt = ["1.5", "1.6", "1.7", "1.8"]
python = ["3.12"]
dbt = ["1.7", "1.8"]

[tool.ruff]
# cache directory for GH actions
extend-exclude = [".cache"]

[tool.yamlfix]
explicit_start = false
whitelines = 1
Expand Down
7 changes: 5 additions & 2 deletions tests/test_hatch.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import os
import sys

import dbt
import dbt.version


def test_expected_python_version():
sys_version = str(sys.version_info.major) + "." + str(sys.version_info.minor)
expected_version = os.environ.get('EXPECTED_PYTHON_VERSION')
expected_version = os.environ.get("EXPECTED_PYTHON_VERSION")
assert sys_version == expected_version


def test_expected_dbt_version():
sys_version = dbt.version.get_installed_version().major + "." + dbt.version.get_installed_version().minor
expected_version = os.environ.get('EXPECTED_DBT_VERSION')
expected_version = os.environ.get("EXPECTED_DBT_VERSION")
assert sys_version == expected_version
4 changes: 2 additions & 2 deletions tests/test_pumpkin.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import List
from __future__ import annotations

import pytest

from dbt_pumpkin.dbt_compat import ColumnInfo
from dbt_pumpkin.pumpkin import Pumpkin


def pumpkin(selects: List[str] = None, excludes: List[str] = None) -> Pumpkin:
def pumpkin(selects: list[str] | None = None, excludes: list[str] | None = None) -> Pumpkin:
return Pumpkin("tests/my_pumpkin", "tests/my_pumpkin", selects, excludes)


Expand Down