From e45c48c318a9b7aad0d136cbaa9411f42fbbc009 Mon Sep 17 00:00:00 2001
From: Matthew Wallace
Date: Tue, 16 Jan 2024 14:44:31 -0700
Subject: [PATCH] Support Black & MyPy pre-commit hooks (#167)

* Add black and mypy as pre-commit hooks

* Run black formatter on all files

* Add MyPy configuration & make tweaks and ignore errors to make MyPy pass

* Add .git-blame-ignore-revs to ignore `black` changes in git blame

* Update changelog
---
 .git-blame-ignore-revs              |  2 ++
 .github/workflows/main.yml          |  1 +
 .pre-commit-config.yaml             | 45 +++++++++++++++++++++++--
 CHANGELOG.md                        |  3 +-
 Makefile                            | 19 +++++++++--
 dbt/adapters/mariadb/__init__.py    |  2 +-
 dbt/adapters/mariadb/connections.py | 11 +++----
 dbt/adapters/mariadb/impl.py        | 51 ++++++++++++-----------------
 dbt/adapters/mysql/__init__.py      |  4 ++-
 dbt/adapters/mysql/connections.py   | 11 +++----
 dbt/adapters/mysql/impl.py          | 48 +++++++++++----------------
 dbt/adapters/mysql5/__init__.py     |  4 ++-
 dbt/adapters/mysql5/connections.py  | 11 +++----
 dbt/adapters/mysql5/impl.py         | 49 +++++++++++----------------
 dev-requirements.txt                |  2 ++
 mypy.ini                            |  2 ++
 tests/unit/utils.py                 |  4 +--
 17 files changed, 148 insertions(+), 121 deletions(-)
 create mode 100644 .git-blame-ignore-revs
 create mode 100644 mypy.ini

diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
new file mode 100644
index 0000000..4323bf5
--- /dev/null
+++ b/.git-blame-ignore-revs
@@ -0,0 +1,2 @@
+# Ran `black` on all files
+99a125c82b846fd25aed432ed67a9ad982bbe0ad
diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 2d3db96..66093db 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -58,6 +58,7 @@ jobs:
           python -m pip install -r dev-requirements.txt
           python -m pip --version
           pre-commit --version
+          mypy --version
           dbt --version
 
       - name: Run pre-commit hooks
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 5e392cb..3d80b95 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,11 +1,12 @@
 # For more on configuring pre-commit hooks (see https://pre-commit.com/)
 
+# Force all unspecified python hooks to run python3
 default_language_version:
   python: python3
 
 repos:
 - repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v3.2.0
+  rev: v4.4.0
   hooks:
   - id: check-yaml
     args: [--unsafe]
@@ -13,10 +14,50 @@ repos:
   - id: check-json
   - id: end-of-file-fixer
   - id: trailing-whitespace
   - id: check-case-conflict
+- repo: https://github.com/psf/black
+  rev: 23.1.0
+  hooks:
+  - id: black
+    additional_dependencies: ['click~=8.1']
+    args:
+    - "--line-length=99"
+    - "--target-version=py38"
+  - id: black
+    alias: black-check
+    stages: [manual]
+    additional_dependencies: ['click~=8.1']
+    args:
+    - "--line-length=99"
+    - "--target-version=py38"
+    - "--check"
+    - "--diff"
 - repo: https://github.com/pycqa/flake8
-  rev: 4.0.1
+  rev: 6.0.0
   hooks:
   - id: flake8
   - id: flake8
     alias: flake8-check
     stages: [manual]
+- repo: https://github.com/pre-commit/mirrors-mypy
+  rev: v1.1.1
+  hooks:
+  - id: mypy
+    # N.B.: Mypy is... a bit fragile.
+    #
+    # By using `language: system` we run this hook in the local
+    # environment instead of a pre-commit isolated one. This is needed
+    # to ensure mypy correctly parses the project.
+
+    # It may cause trouble in that it adds environment variables out
+    # of our control to the mix. Unfortunately, there's nothing we can
+    # do about it, per pre-commit's author.
+    # See https://github.com/pre-commit/pre-commit/issues/730 for details.
+    args: [--show-error-codes, --ignore-missing-imports, --explicit-package-bases]
+    files: ^dbt/adapters/.*
+    language: system
+  - id: mypy
+    alias: mypy-check
+    stages: [manual]
+    args: [--show-error-codes, --pretty, --ignore-missing-imports, --explicit-package-bases]
+    files: ^dbt/adapters
+    language: system
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 85d01e2..0186095 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,7 @@
 
 ### Features
 - Migrate CircleCI to GitHub Actions ([#120](https://github.com/dbeatty10/dbt-mysql/issues/120))
+- Support Black & MyPy pre-commit hooks ([#138](https://github.com/dbeatty10/dbt-mysql/issues/138))
 
 ### Fixes
 - Fix incremental composite keys ([#144](https://github.com/dbeatty10/dbt-mysql/issues/144))
@@ -9,7 +10,7 @@
 
 ### Contributors
 - [@moszutij](https://github.com/moszutij) ([#146](https://github.com/dbeatty10/dbt-mysql/pull/146), [#144](https://github.com/dbeatty10/dbt-mysql/issues/144))
 - [@wesen](https://github.com/wesen) ([#146](https://github.com/dbeatty10/dbt-mysql/pull/146))
-- [@mwallace582](https://github.com/mwallace582) ([#162](https://github.com/dbeatty10/dbt-mysql/pull/162))
+- [@mwallace582](https://github.com/mwallace582) ([#162](https://github.com/dbeatty10/dbt-mysql/pull/162), [#138](https://github.com/dbeatty10/dbt-mysql/issues/138))
 
 ## dbt-mysql 1.1.0 (Feb 5, 2023)
diff --git a/Makefile b/Makefile
index d6aa5b5..e59c742 100644
--- a/Makefile
+++ b/Makefile
@@ -11,15 +11,26 @@ dev-uninstall: ## Uninstalls all packages while maintaining the virtual environm
 	pip freeze | grep -v "^-e" | cut -d "@" -f1 | xargs pip uninstall -y
 	pip uninstall -y dbt-mysql
 
+.PHONY: mypy
+mypy: ## Runs mypy against staged changes for static type checking.
+	@\
+	pre-commit run --hook-stage manual mypy-check | grep -v "INFO"
+
 .PHONY: flake8
 flake8: ## Runs flake8 against staged changes to enforce style guide.
 	@\
 	pre-commit run --hook-stage manual flake8-check | grep -v "INFO"
 
+.PHONY: black
+black: ## Runs black against staged changes to enforce style guide.
+	@\
+	pre-commit run --hook-stage manual black-check -v | grep -v "INFO"
+
 .PHONY: lint
-lint: ## Runs flake8 code checks against staged changes.
+lint: ## Runs flake8 and mypy code checks against staged changes.
 	@\
-	pre-commit run flake8-check --hook-stage manual | grep -v "INFO";
+	pre-commit run flake8-check --hook-stage manual | grep -v "INFO"; \
+	pre-commit run mypy-check --hook-stage manual | grep -v "INFO"
 
 .PHONY: linecheck
 linecheck: ## Checks for all Python lines 100 characters or more
@@ -35,7 +46,9 @@ unit: ## Runs unit tests with py38.
 test: ## Runs unit tests with py38 and code checks against staged changes.
 	@\
 	tox -p -e py38; \
-	pre-commit run flake8-check --hook-stage manual | grep -v "INFO";
+	pre-commit run black-check --hook-stage manual | grep -v "INFO"; \
+	pre-commit run flake8-check --hook-stage manual | grep -v "INFO"; \
+	pre-commit run mypy-check --hook-stage manual | grep -v "INFO"
 
 .PHONY: integration
 integration: ## Runs mysql integration tests with py38.
diff --git a/dbt/adapters/mariadb/__init__.py b/dbt/adapters/mariadb/__init__.py
index 820e9d1..f6f09ad 100644
--- a/dbt/adapters/mariadb/__init__.py
+++ b/dbt/adapters/mariadb/__init__.py
@@ -9,7 +9,7 @@
 
 
 Plugin = AdapterPlugin(
-    adapter=MariaDBAdapter,
+    adapter=MariaDBAdapter,  # type: ignore[arg-type]
     credentials=MariaDBCredentials,
     include_path=mariadb.PACKAGE_PATH,
 )
diff --git a/dbt/adapters/mariadb/connections.py b/dbt/adapters/mariadb/connections.py
index d85ec29..879b260 100644
--- a/dbt/adapters/mariadb/connections.py
+++ b/dbt/adapters/mariadb/connections.py
@@ -16,10 +16,10 @@
 
 @dataclass(init=False)
 class MariaDBCredentials(Credentials):
-    server: str
+    server: str = ""
     port: Optional[int] = None
-    database: Optional[str] = None
-    schema: str
+    database: str = ""
+    schema: str = ""
     username: Optional[str] = None
     password: Optional[str] = None
     charset: Optional[str] = None
@@ -95,7 +95,6 @@ def open(cls, connection):
             connection.handle = mysql.connector.connect(**kwargs)
             connection.state = "open"
         except mysql.connector.Error:
-
             try:
                 logger.debug(
                     "Failed connection without supplying the `database`. "
@@ -108,10 +107,8 @@ def open(cls, connection):
                 connection.handle = mysql.connector.connect(**kwargs)
                 connection.state = "open"
             except mysql.connector.Error as e:
-
                 logger.debug(
-                    "Got an error when attempting to open a MariaDB "
-                    "connection: '{}'".format(e)
+                    "Got an error when attempting to open a MariaDB " "connection: '{}'".format(e)
                 )
 
                 connection.handle = None
diff --git a/dbt/adapters/mariadb/impl.py b/dbt/adapters/mariadb/impl.py
index 6a16bcb..cdb6273 100644
--- a/dbt/adapters/mariadb/impl.py
+++ b/dbt/adapters/mariadb/impl.py
@@ -1,6 +1,6 @@
 from concurrent.futures import Future
 from dataclasses import asdict
-from typing import Optional, List, Dict, Any, Iterable
+from typing import Optional, List, Dict, Any, Iterable, Tuple
 
 import agate
 import dbt
@@ -12,6 +12,7 @@
 from dbt.adapters.mariadb import MariaDBRelation
 from dbt.adapters.mariadb import MariaDBColumn
 from dbt.adapters.base import BaseRelation
+from dbt.contracts.graph.manifest import Manifest
 from dbt.clients.agate_helper import DEFAULT_TYPE_TESTER
 from dbt.events import AdapterLogger
 from dbt.utils import executor
@@ -38,8 +39,8 @@ def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
     def quote(self, identifier):
         return "`{}`".format(identifier)
 
-    def list_relations_without_caching(
-        self, schema_relation: MariaDBRelation
+    def list_relations_without_caching(  # type: ignore[override]
+        self, schema_relation: MariaDBRelation  # type: ignore[override]
     ) -> List[MariaDBRelation]:
         kwargs = {"schema_relation": schema_relation}
         try:
@@ -62,20 +63,16 @@ def list_relations_without_caching(
                     f"got {len(row)} values, expected 4"
                 )
             _, name, _schema, relation_type = row
-            relation = self.Relation.create(
-                schema=_schema, identifier=name, type=relation_type
-            )
+            relation = self.Relation.create(schema=_schema, identifier=name, type=relation_type)
             relations.append(relation)
 
         return relations
 
-    def get_columns_in_relation(self, relation: Relation) -> List[MariaDBColumn]:
+    def get_columns_in_relation(self, relation: MariaDBRelation) -> List[MariaDBColumn]:
         rows: List[agate.Row] = super().get_columns_in_relation(relation)
         return self.parse_show_columns(relation, rows)
 
-    def _get_columns_for_catalog(
-        self, relation: MariaDBRelation
-    ) -> Iterable[Dict[str, Any]]:
+    def _get_columns_for_catalog(self, relation: MariaDBRelation) -> Iterable[Dict[str, Any]]:
         columns = self.get_columns_in_relation(relation)
 
         for column in columns:
@@ -87,7 +84,7 @@ def _get_columns_for_catalog(
             yield as_dict
 
     def get_relation(
-        self, database: str, schema: str, identifier: str
+        self, database: Optional[str], schema: str, identifier: str
     ) -> Optional[BaseRelation]:
         if not self.Relation.include_policy.database:
             database = None
@@ -95,7 +92,7 @@ def get_relation(
         return super().get_relation(database, schema, identifier)
 
     def parse_show_columns(
-        self, relation: Relation, raw_rows: List[agate.Row]
+        self, relation: MariaDBRelation, raw_rows: List[agate.Row]
     ) -> List[MariaDBColumn]:
         return [
             MariaDBColumn(
@@ -112,12 +109,12 @@ def parse_show_columns(
             for idx, column in enumerate(raw_rows)
         ]
 
-    def get_catalog(self, manifest):
+    def get_catalog(self, manifest: Manifest) -> Tuple[agate.Table, List[Exception]]:
         schema_map = self._get_catalog_schemas(manifest)
+
         if len(schema_map) > 1:
             dbt.exceptions.raise_compiler_error(
-                f"Expected only one database in get_catalog, found "
-                f"{list(schema_map)}"
+                f"Expected only one database in get_catalog, found " f"{list(schema_map)}"
             )
 
         with executor(self.config) as tpe:
@@ -145,8 +142,7 @@ def _get_one_catalog(
     ) -> agate.Table:
         if len(schemas) != 1:
             dbt.exceptions.raise_compiler_error(
-                f"Expected only one schema in mariadb _get_one_catalog, found "
-                f"{schemas}"
+                f"Expected only one schema in mariadb _get_one_catalog, found " f"{schemas}"
             )
 
         database = information_schema.database
@@ -155,13 +151,11 @@ def _get_one_catalog(
         columns: List[Dict[str, Any]] = []
         for relation in self.list_relations(database, schema):
             logger.debug("Getting table schema for relation {}", relation)
-            columns.extend(self._get_columns_for_catalog(relation))
+            columns.extend(self._get_columns_for_catalog(relation))  # type: ignore[arg-type]
 
         return agate.Table.from_object(columns, column_types=DEFAULT_TYPE_TESTER)
 
     def check_schema_exists(self, database, schema):
-        results = self.execute_macro(
-            LIST_SCHEMAS_MACRO_NAME, kwargs={"database": database}
-        )
+        results = self.execute_macro(LIST_SCHEMAS_MACRO_NAME, kwargs={"database": database})
         exists = True if schema in [row[0] for row in results] else False
         return exists
@@ -179,9 +173,7 @@ def update_column_sql(
             clause += f" where {where_clause}"
         return clause
 
-    def timestamp_add_sql(
-        self, add_to: str, number: int = 1, interval: str = "hour"
-    ) -> str:
+    def timestamp_add_sql(self, add_to: str, number: int = 1, interval: str = "hour") -> str:
         # for backwards compatibility, we're compelled to set some sort of
         # default. A lot of searching has lead me to believe that the
         # '+ interval' syntax used in postgres/redshift is relatively common
@@ -205,9 +197,10 @@ def string_add_sql(
 
     def get_rows_different_sql(
         self,
-        relation_a: MariaDBRelation,
-        relation_b: MariaDBRelation,
+        relation_a: MariaDBRelation,  # type: ignore[override]
+        relation_b: MariaDBRelation,  # type: ignore[override]
         column_names: Optional[List[str]] = None,
+        except_operator: str = "",  # Required to match BaseRelation.get_rows_different_sql()
     ) -> str:
         # This method only really exists for test reasons
         names: List[str]
@@ -221,12 +214,10 @@ def get_rows_different_sql(
         alias_b = "B"
         columns_csv_a = ", ".join([f"{alias_a}.{name}" for name in names])
         columns_csv_b = ", ".join([f"{alias_b}.{name}" for name in names])
-        join_condition = " AND ".join(
-            [f"{alias_a}.{name} = {alias_b}.{name}" for name in names]
-        )
+        join_condition = " AND ".join([f"{alias_a}.{name} = {alias_b}.{name}" for name in names])
         first_column = names[0]
 
-        # There is no EXCEPT or MINUS operator, so we need to simulate it
+        # MariaDB doesn't have an EXCEPT or MINUS operator, so we need to simulate it
         COLUMNS_EQUAL_SQL = """
             SELECT
                 row_count_diff.difference as row_count_difference,
diff --git a/dbt/adapters/mysql/__init__.py b/dbt/adapters/mysql/__init__.py
index 654b023..1357d08 100644
--- a/dbt/adapters/mysql/__init__.py
+++ b/dbt/adapters/mysql/__init__.py
@@ -9,5 +9,7 @@
 
 
 Plugin = AdapterPlugin(
-    adapter=MySQLAdapter, credentials=MySQLCredentials, include_path=mysql.PACKAGE_PATH
+    adapter=MySQLAdapter,  # type: ignore[arg-type]
+    credentials=MySQLCredentials,
+    include_path=mysql.PACKAGE_PATH,
 )
diff --git a/dbt/adapters/mysql/connections.py b/dbt/adapters/mysql/connections.py
index 6a4e285..087aa4a 100644
--- a/dbt/adapters/mysql/connections.py
+++ b/dbt/adapters/mysql/connections.py
@@ -16,10 +16,10 @@
 
 @dataclass(init=False)
 class MySQLCredentials(Credentials):
-    server: str
+    server: str = ""
     port: Optional[int] = None
-    database: Optional[str] = None
-    schema: str
+    database: str = ""
+    schema: str = ""
     username: Optional[str] = None
     password: Optional[str] = None
     charset: Optional[str] = None
@@ -91,7 +91,6 @@ def open(cls, connection):
             connection.handle = mysql.connector.connect(**kwargs)
             connection.state = "open"
         except mysql.connector.Error:
-
             try:
                 logger.debug(
" @@ -104,10 +103,8 @@ def open(cls, connection): connection.handle = mysql.connector.connect(**kwargs) connection.state = "open" except mysql.connector.Error as e: - logger.debug( - "Got an error when attempting to open a mysql " - "connection: '{}'".format(e) + "Got an error when attempting to open a mysql " "connection: '{}'".format(e) ) connection.handle = None diff --git a/dbt/adapters/mysql/impl.py b/dbt/adapters/mysql/impl.py index f1e11d1..41afacf 100644 --- a/dbt/adapters/mysql/impl.py +++ b/dbt/adapters/mysql/impl.py @@ -1,6 +1,6 @@ from concurrent.futures import Future from dataclasses import asdict -from typing import Optional, List, Dict, Any, Iterable +from typing import Optional, List, Dict, Any, Iterable, Tuple import agate import dbt @@ -12,6 +12,7 @@ from dbt.adapters.mysql import MySQLRelation from dbt.adapters.mysql import MySQLColumn from dbt.adapters.base import BaseRelation +from dbt.contracts.graph.manifest import Manifest from dbt.clients.agate_helper import DEFAULT_TYPE_TESTER from dbt.events import AdapterLogger from dbt.utils import executor @@ -38,8 +39,8 @@ def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str: def quote(self, identifier): return "`{}`".format(identifier) - def list_relations_without_caching( - self, schema_relation: MySQLRelation + def list_relations_without_caching( # type: ignore[override] + self, schema_relation: MySQLRelation # type: ignore[override] ) -> List[MySQLRelation]: kwargs = {"schema_relation": schema_relation} try: @@ -62,20 +63,16 @@ def list_relations_without_caching( f"got {len(row)} values, expected 4" ) _, name, _schema, relation_type = row - relation = self.Relation.create( - schema=_schema, identifier=name, type=relation_type - ) + relation = self.Relation.create(schema=_schema, identifier=name, type=relation_type) relations.append(relation) return relations - def get_columns_in_relation(self, relation: Relation) -> List[MySQLColumn]: + def get_columns_in_relation(self, relation: MySQLRelation) -> List[MySQLColumn]: rows: List[agate.Row] = super().get_columns_in_relation(relation) return self.parse_show_columns(relation, rows) - def _get_columns_for_catalog( - self, relation: MySQLRelation - ) -> Iterable[Dict[str, Any]]: + def _get_columns_for_catalog(self, relation: MySQLRelation) -> Iterable[Dict[str, Any]]: columns = self.get_columns_in_relation(relation) for column in columns: @@ -87,7 +84,7 @@ def _get_columns_for_catalog( yield as_dict def get_relation( - self, database: str, schema: str, identifier: str + self, database: Optional[str], schema: str, identifier: str ) -> Optional[BaseRelation]: if not self.Relation.include_policy.database: database = None @@ -95,7 +92,7 @@ def get_relation( return super().get_relation(database, schema, identifier) def parse_show_columns( - self, relation: Relation, raw_rows: List[agate.Row] + self, relation: MySQLRelation, raw_rows: List[agate.Row] ) -> List[MySQLColumn]: return [ MySQLColumn( @@ -112,13 +109,12 @@ def parse_show_columns( for idx, column in enumerate(raw_rows) ] - def get_catalog(self, manifest): + def get_catalog(self, manifest: Manifest) -> Tuple[agate.Table, List[Exception]]: schema_map = self._get_catalog_schemas(manifest) if len(schema_map) > 1: dbt.exceptions.raise_compiler_error( - f"Expected only one database in get_catalog, found " - f"{list(schema_map)}" + f"Expected only one database in get_catalog, found " f"{list(schema_map)}" ) with executor(self.config) as tpe: @@ -146,8 +142,7 @@ def _get_one_catalog( ) -> agate.Table: if 
         if len(schemas) != 1:
             dbt.exceptions.raise_compiler_error(
-                f"Expected only one schema in mysql _get_one_catalog, found "
-                f"{schemas}"
+                f"Expected only one schema in mysql _get_one_catalog, found " f"{schemas}"
             )
 
         database = information_schema.database
@@ -156,13 +151,11 @@ def _get_one_catalog(
         columns: List[Dict[str, Any]] = []
         for relation in self.list_relations(database, schema):
             logger.debug("Getting table schema for relation {}", relation)
-            columns.extend(self._get_columns_for_catalog(relation))
+            columns.extend(self._get_columns_for_catalog(relation))  # type: ignore[arg-type]
 
         return agate.Table.from_object(columns, column_types=DEFAULT_TYPE_TESTER)
 
     def check_schema_exists(self, database, schema):
-        results = self.execute_macro(
-            LIST_SCHEMAS_MACRO_NAME, kwargs={"database": database}
-        )
+        results = self.execute_macro(LIST_SCHEMAS_MACRO_NAME, kwargs={"database": database})
         exists = True if schema in [row[0] for row in results] else False
         return exists
@@ -180,9 +173,7 @@ def update_column_sql(
             clause += f" where {where_clause}"
         return clause
 
-    def timestamp_add_sql(
-        self, add_to: str, number: int = 1, interval: str = "hour"
-    ) -> str:
+    def timestamp_add_sql(self, add_to: str, number: int = 1, interval: str = "hour") -> str:
         # for backwards compatibility, we're compelled to set some sort of
         # default. A lot of searching has lead me to believe that the
         # '+ interval' syntax used in postgres/redshift is relatively common
@@ -206,9 +197,10 @@ def string_add_sql(
 
     def get_rows_different_sql(
         self,
-        relation_a: MySQLRelation,
-        relation_b: MySQLRelation,
+        relation_a: MySQLRelation,  # type: ignore[override]
+        relation_b: MySQLRelation,  # type: ignore[override]
         column_names: Optional[List[str]] = None,
+        except_operator: str = "",  # Required to match BaseRelation.get_rows_different_sql()
     ) -> str:
         # This method only really exists for test reasons
         names: List[str]
@@ -222,9 +214,7 @@ def get_rows_different_sql(
         alias_b = "B"
         columns_csv_a = ", ".join([f"{alias_a}.{name}" for name in names])
         columns_csv_b = ", ".join([f"{alias_b}.{name}" for name in names])
-        join_condition = " AND ".join(
-            [f"{alias_a}.{name} = {alias_b}.{name}" for name in names]
-        )
+        join_condition = " AND ".join([f"{alias_a}.{name} = {alias_b}.{name}" for name in names])
         first_column = names[0]
 
         # MySQL doesn't have an EXCEPT or MINUS operator, so we need to simulate it
diff --git a/dbt/adapters/mysql5/__init__.py b/dbt/adapters/mysql5/__init__.py
index 8f23e58..00d41e6 100644
--- a/dbt/adapters/mysql5/__init__.py
+++ b/dbt/adapters/mysql5/__init__.py
@@ -9,5 +9,7 @@
 
 
 Plugin = AdapterPlugin(
-    adapter=MySQLAdapter, credentials=MySQLCredentials, include_path=mysql5.PACKAGE_PATH
+    adapter=MySQLAdapter,  # type: ignore[arg-type]
+    credentials=MySQLCredentials,
+    include_path=mysql5.PACKAGE_PATH,
 )
diff --git a/dbt/adapters/mysql5/connections.py b/dbt/adapters/mysql5/connections.py
index c8c1d20..e6c6ed8 100644
--- a/dbt/adapters/mysql5/connections.py
+++ b/dbt/adapters/mysql5/connections.py
@@ -16,10 +16,10 @@
 
 @dataclass(init=False)
 class MySQLCredentials(Credentials):
-    server: str
+    server: str = ""
     port: Optional[int] = None
-    database: Optional[str] = None
-    schema: str
+    database: str = ""
+    schema: str = ""
     username: Optional[str] = None
     password: Optional[str] = None
     charset: Optional[str] = None
@@ -95,7 +95,6 @@ def open(cls, connection):
             connection.handle = mysql.connector.connect(**kwargs)
             connection.state = "open"
         except mysql.connector.Error:
-
             try:
                 logger.debug(
                     "Failed connection without supplying the `database`. "
@@ -108,10 +107,8 @@ def open(cls, connection):
                 connection.handle = mysql.connector.connect(**kwargs)
                 connection.state = "open"
             except mysql.connector.Error as e:
-
                 logger.debug(
-                    "Got an error when attempting to open a mysql "
-                    "connection: '{}'".format(e)
+                    "Got an error when attempting to open a mysql " "connection: '{}'".format(e)
                 )
 
                 connection.handle = None
diff --git a/dbt/adapters/mysql5/impl.py b/dbt/adapters/mysql5/impl.py
index 2582c83..65f84b4 100644
--- a/dbt/adapters/mysql5/impl.py
+++ b/dbt/adapters/mysql5/impl.py
@@ -1,6 +1,6 @@
 from concurrent.futures import Future
 from dataclasses import asdict
-from typing import Optional, List, Dict, Any, Iterable
+from typing import Optional, List, Dict, Any, Iterable, Tuple
 
 import agate
 import dbt
@@ -12,6 +12,7 @@
 from dbt.adapters.mysql5 import MySQLRelation
 from dbt.adapters.mysql5 import MySQLColumn
 from dbt.adapters.base import BaseRelation
+from dbt.contracts.graph.manifest import Manifest
 from dbt.clients.agate_helper import DEFAULT_TYPE_TESTER
 from dbt.events import AdapterLogger
 from dbt.utils import executor
@@ -38,8 +39,8 @@ def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
     def quote(self, identifier):
         return "`{}`".format(identifier)
 
-    def list_relations_without_caching(
-        self, schema_relation: MySQLRelation
+    def list_relations_without_caching(  # type: ignore[override]
+        self, schema_relation: MySQLRelation  # type: ignore[override]
     ) -> List[MySQLRelation]:
         kwargs = {"schema_relation": schema_relation}
         try:
@@ -62,20 +63,16 @@ def list_relations_without_caching(
                     f"got {len(row)} values, expected 4"
                 )
             _, name, _schema, relation_type = row
-            relation = self.Relation.create(
-                schema=_schema, identifier=name, type=relation_type
-            )
+            relation = self.Relation.create(schema=_schema, identifier=name, type=relation_type)
             relations.append(relation)
 
         return relations
 
-    def get_columns_in_relation(self, relation: Relation) -> List[MySQLColumn]:
+    def get_columns_in_relation(self, relation: MySQLRelation) -> List[MySQLColumn]:
         rows: List[agate.Row] = super().get_columns_in_relation(relation)
         return self.parse_show_columns(relation, rows)
 
-    def _get_columns_for_catalog(
-        self, relation: MySQLRelation
-    ) -> Iterable[Dict[str, Any]]:
+    def _get_columns_for_catalog(self, relation: MySQLRelation) -> Iterable[Dict[str, Any]]:
         columns = self.get_columns_in_relation(relation)
 
         for column in columns:
@@ -87,7 +84,7 @@ def _get_columns_for_catalog(
             yield as_dict
 
     def get_relation(
-        self, database: str, schema: str, identifier: str
+        self, database: Optional[str], schema: str, identifier: str
     ) -> Optional[BaseRelation]:
         if not self.Relation.include_policy.database:
             database = None
@@ -95,7 +92,7 @@ def get_relation(
         return super().get_relation(database, schema, identifier)
 
     def parse_show_columns(
-        self, relation: Relation, raw_rows: List[agate.Row]
+        self, relation: MySQLRelation, raw_rows: List[agate.Row]
     ) -> List[MySQLColumn]:
         return [
             MySQLColumn(
@@ -112,12 +109,12 @@ def parse_show_columns(
             for idx, column in enumerate(raw_rows)
         ]
 
-    def get_catalog(self, manifest):
+    def get_catalog(self, manifest: Manifest) -> Tuple[agate.Table, List[Exception]]:
         schema_map = self._get_catalog_schemas(manifest)
+
         if len(schema_map) > 1:
             dbt.exceptions.raise_compiler_error(
-                f"Expected only one database in get_catalog, found "
-                f"{list(schema_map)}"
+                f"Expected only one database in get_catalog, found " f"{list(schema_map)}"
             )
 
         with executor(self.config) as tpe:
@@ -145,8 +142,7 @@ def _get_one_catalog(
     ) -> agate.Table:
         if len(schemas) != 1:
             dbt.exceptions.raise_compiler_error(
-                f"Expected only one schema in mysql5 _get_one_catalog, found "
-                f"{schemas}"
+                f"Expected only one schema in mysql5 _get_one_catalog, found " f"{schemas}"
             )
 
         database = information_schema.database
@@ -155,13 +151,11 @@ def _get_one_catalog(
         columns: List[Dict[str, Any]] = []
         for relation in self.list_relations(database, schema):
             logger.debug("Getting table schema for relation {}", relation)
-            columns.extend(self._get_columns_for_catalog(relation))
+            columns.extend(self._get_columns_for_catalog(relation))  # type: ignore[arg-type]
 
         return agate.Table.from_object(columns, column_types=DEFAULT_TYPE_TESTER)
 
     def check_schema_exists(self, database, schema):
-        results = self.execute_macro(
-            LIST_SCHEMAS_MACRO_NAME, kwargs={"database": database}
-        )
+        results = self.execute_macro(LIST_SCHEMAS_MACRO_NAME, kwargs={"database": database})
         exists = True if schema in [row[0] for row in results] else False
         return exists
@@ -179,9 +173,7 @@ def update_column_sql(
             clause += f" where {where_clause}"
         return clause
 
-    def timestamp_add_sql(
-        self, add_to: str, number: int = 1, interval: str = "hour"
-    ) -> str:
+    def timestamp_add_sql(self, add_to: str, number: int = 1, interval: str = "hour") -> str:
         # for backwards compatibility, we're compelled to set some sort of
         # default. A lot of searching has lead me to believe that the
         # '+ interval' syntax used in postgres/redshift is relatively common
@@ -205,9 +197,10 @@ def string_add_sql(
 
     def get_rows_different_sql(
         self,
-        relation_a: MySQLRelation,
-        relation_b: MySQLRelation,
+        relation_a: MySQLRelation,  # type: ignore[override]
+        relation_b: MySQLRelation,  # type: ignore[override]
         column_names: Optional[List[str]] = None,
+        except_operator: str = "",  # Required to match BaseRelation.get_rows_different_sql()
     ) -> str:
         # This method only really exists for test reasons
         names: List[str]
@@ -221,9 +214,7 @@ def get_rows_different_sql(
         alias_b = "B"
         columns_csv_a = ", ".join([f"{alias_a}.{name}" for name in names])
         columns_csv_b = ", ".join([f"{alias_b}.{name}" for name in names])
-        join_condition = " AND ".join(
-            [f"{alias_a}.{name} = {alias_b}.{name}" for name in names]
-        )
+        join_condition = " AND ".join([f"{alias_a}.{name} = {alias_b}.{name}" for name in names])
         first_column = names[0]
 
         # MySQL doesn't have an EXCEPT or MINUS operator, so we need to simulate it
diff --git a/dev-requirements.txt b/dev-requirements.txt
index b52651f..7acf8ba 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -5,12 +5,14 @@ git+https://github.com/dbt-labs/dbt-core.git@1.2.latest#egg=dbt-tests-adapter&su
 
 # if version 1.x or greater -> pin to major version
 # if version 0.x -> pin to minor
+black~=22.12
 bumpversion~=0.6.0
 ddtrace~=2.3
 flake8~=6.1
 flaky~=3.7
 freezegun~=1.3
 ipdb~=0.13.13
+mypy==1.7.1  # patch updates have historically introduced breaking changes
 pre-commit~=3.5
 pre-commit-hooks~=4.5
 pytest~=7.4
diff --git a/mypy.ini b/mypy.ini
new file mode 100644
index 0000000..b6e6035
--- /dev/null
+++ b/mypy.ini
@@ -0,0 +1,2 @@
+[mypy]
+namespace_packages = True
diff --git a/tests/unit/utils.py b/tests/unit/utils.py
index 07371d3..bac51d6 100644
--- a/tests/unit/utils.py
+++ b/tests/unit/utils.py
@@ -71,9 +71,7 @@ def project_from_dict(project, profile, packages=None, selectors=None, cli_vars=
     return partial.render(renderer)
 
 
-def config_from_parts_or_dicts(
-    project, profile, packages=None, selectors=None, cli_vars="{}"
-):
+def config_from_parts_or_dicts(project, profile, packages=None, selectors=None, cli_vars="{}"):
     from dbt.config import Project, Profile, RuntimeConfig
     from copy import deepcopy