From 49c9e857ac3ca2af55d271c0837162ef766ebe9c Mon Sep 17 00:00:00 2001 From: Matthew Wallace Date: Thu, 11 Jan 2024 16:42:01 -0700 Subject: [PATCH] Update dbt-core to 1.5 and implement support for model contracts (#163) Update dbt-core to 1.5.9 and implement support for constraints * Update changelog * Fix unit test error - The error was: `AttributeError: 'Namespace' object has no attribute 'MACRO_DEBUGGING'` * Allow Unix socket connection rather than just TCP (#165) --- CHANGELOG.md | 4 +- dbt/adapters/mariadb/column.py | 8 + dbt/adapters/mariadb/connections.py | 17 +- dbt/adapters/mariadb/impl.py | 20 +- dbt/adapters/mysql/column.py | 8 + dbt/adapters/mysql/connections.py | 17 +- dbt/adapters/mysql/impl.py | 20 +- dbt/adapters/mysql5/column.py | 8 + dbt/adapters/mysql5/connections.py | 17 +- dbt/adapters/mysql5/impl.py | 20 +- dbt/include/mariadb/macros/adapters.sql | 54 ++- dbt/include/mysql/macros/adapters.sql | 49 ++- dbt/include/mysql5/macros/adapters.sql | 47 ++- .../adapter/constraints/fixtures.py | 320 +++++++++++++++ .../adapter/constraints/test_constraints.py | 365 ++++++++++++++++++ tests/unit/test_adapter.py | 1 - tests/unit/utils.py | 9 +- 17 files changed, 963 insertions(+), 21 deletions(-) create mode 100644 tests/functional/adapter/constraints/fixtures.py create mode 100644 tests/functional/adapter/constraints/test_constraints.py diff --git a/CHANGELOG.md b/CHANGELOG.md index b462295..ac88daf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,8 @@ ### Features - Migrate CircleCI to GitHub Actions ([#120](https://github.com/dbeatty10/dbt-mysql/issues/120)) - Support dbt v1.4 ([#146](https://github.com/dbeatty10/dbt-mysql/pull/146)) +- Support dbt v1.5 ([#145](https://github.com/dbeatty10/dbt-mysql/issues/145)) +- Support connecting via UNIX sockets ([#164](https://github.com/dbeatty10/dbt-mysql/issues/164)) ### Fixes - Fix incremental composite keys ([#144](https://github.com/dbeatty10/dbt-mysql/issues/144)) @@ -11,7 +13,7 @@ - [@lpezet](https://github.com/lpezet) ([#146](https://github.com/dbeatty10/dbt-mysql/pull/146)) - [@moszutij](https://github.com/moszutij) ([#146](https://github.com/dbeatty10/dbt-mysql/pull/146), [#144](https://github.com/dbeatty10/dbt-mysql/issues/144)) - [@wesen](https://github.com/wesen) ([#146](https://github.com/dbeatty10/dbt-mysql/pull/146)) -- [@mwallace582](https://github.com/mwallace582) ([#162](https://github.com/dbeatty10/dbt-mysql/pull/162)) +- [@mwallace582](https://github.com/mwallace582) ([#162](https://github.com/dbeatty10/dbt-mysql/pull/162), [#163](https://github.com/dbeatty10/dbt-mysql/pull/163), [#164](https://github.com/dbeatty10/dbt-mysql/issues/164)) ## dbt-mysql 1.1.0 (Feb 5, 2023) diff --git a/dbt/adapters/mariadb/column.py b/dbt/adapters/mariadb/column.py index c7c47eb..adc941e 100644 --- a/dbt/adapters/mariadb/column.py +++ b/dbt/adapters/mariadb/column.py @@ -8,6 +8,14 @@ @dataclass class MariaDBColumn(Column): + TYPE_LABELS = { + "STRING": "TEXT", + "VAR_STRING": "TEXT", + "LONG": "INTEGER", + "LONGLONG": "INTEGER", + "INT": "INTEGER", + "TIMESTAMP": "DATETIME", + } table_database: Optional[str] = None table_schema: Optional[str] = None table_name: Optional[str] = None diff --git a/dbt/adapters/mariadb/connections.py b/dbt/adapters/mariadb/connections.py index cd50ea9..6b30f9a 100644 --- a/dbt/adapters/mariadb/connections.py +++ b/dbt/adapters/mariadb/connections.py @@ -1,6 +1,7 @@ from contextlib import contextmanager import mysql.connector +import mysql.connector.constants import dbt.exceptions from dbt.adapters.sql import SQLConnectionManager @@ -16,7 +17,8 @@ @dataclass(init=False) class MariaDBCredentials(Credentials): - server: str + server: Optional[str] = None + unix_socket: Optional[str] = None port: Optional[int] = None database: Optional[str] = None schema: str @@ -61,6 +63,7 @@ def _connection_keys(self): """ return ( "server", + "unix_socket", "port", "database", "schema", @@ -80,7 +83,6 @@ def open(cls, connection): credentials = cls.get_credentials(connection.credentials) kwargs = {} - kwargs["host"] = credentials.server kwargs["user"] = credentials.username kwargs["passwd"] = credentials.password kwargs["buffered"] = True @@ -88,6 +90,11 @@ def open(cls, connection): if credentials.ssl_disabled: kwargs["ssl_disabled"] = credentials.ssl_disabled + if credentials.server: + kwargs["host"] = credentials.server + elif credentials.unix_socket: + kwargs["unix_socket"] = credentials.unix_socket + if credentials.port: kwargs["port"] = credentials.port @@ -172,3 +179,9 @@ def get_response(cls, cursor) -> AdapterResponse: rows_affected=num_rows, code=code ) + + @classmethod + def data_type_code_to_name(cls, type_code: int) -> str: + field_type_values = mysql.connector.constants.FieldType.desc.values() + mapping = {code: name for (code, name) in field_type_values} + return mapping[type_code] diff --git a/dbt/adapters/mariadb/impl.py b/dbt/adapters/mariadb/impl.py index 2557f36..7289e7a 100644 --- a/dbt/adapters/mariadb/impl.py +++ b/dbt/adapters/mariadb/impl.py @@ -12,6 +12,8 @@ from dbt.adapters.mariadb import MariaDBRelation from dbt.adapters.mariadb import MariaDBColumn from dbt.adapters.base import BaseRelation +from dbt.contracts.graph.nodes import ConstraintType +from dbt.adapters.base.impl import ConstraintSupport from dbt.clients.agate_helper import DEFAULT_TYPE_TESTER from dbt.events import AdapterLogger from dbt.utils import executor @@ -27,6 +29,19 @@ class MariaDBAdapter(SQLAdapter): Column = MariaDBColumn ConnectionManager = MariaDBConnectionManager + CONSTRAINT_SUPPORT = { + ConstraintType.check: ConstraintSupport.ENFORCED, + ConstraintType.not_null: ConstraintSupport.ENFORCED, + ConstraintType.unique: ConstraintSupport.ENFORCED, + ConstraintType.primary_key: ConstraintSupport.ENFORCED, + # While Foreign Keys are indeed supported, they're not supported in + # CREATE TABLE AS SELECT statements, which is what DBT uses. + # + # It is possible to use a `post-hook` to add a foreign key after the + # table is created. + ConstraintType.foreign_key: ConstraintSupport.NOT_SUPPORTED, + } + @classmethod def date_function(cls): return "current_date()" @@ -36,7 +51,8 @@ def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str: return "timestamp" - def quote(self, identifier): + @classmethod + def quote(cls, identifier: str) -> str: return "`{}`".format(identifier) def list_relations_without_caching( @@ -157,7 +173,7 @@ def _get_one_catalog( columns: List[Dict[str, Any]] = [] for relation in self.list_relations(database, schema): - logger.debug("Getting table schema for relation {}", relation) + logger.debug("Getting table schema for relation {}", str(relation)) columns.extend(self._get_columns_for_catalog(relation)) return agate.Table.from_object(columns, column_types=DEFAULT_TYPE_TESTER) diff --git a/dbt/adapters/mysql/column.py b/dbt/adapters/mysql/column.py index 9ce3786..931593c 100644 --- a/dbt/adapters/mysql/column.py +++ b/dbt/adapters/mysql/column.py @@ -8,6 +8,14 @@ @dataclass class MySQLColumn(Column): + TYPE_LABELS = { + "STRING": "TEXT", + "VAR_STRING": "TEXT", + "LONG": "INTEGER", + "LONGLONG": "INTEGER", + "INT": "INTEGER", + "TIMESTAMP": "DATETIME", + } table_database: Optional[str] = None table_schema: Optional[str] = None table_name: Optional[str] = None diff --git a/dbt/adapters/mysql/connections.py b/dbt/adapters/mysql/connections.py index 42880f6..d8932dd 100644 --- a/dbt/adapters/mysql/connections.py +++ b/dbt/adapters/mysql/connections.py @@ -1,6 +1,7 @@ from contextlib import contextmanager import mysql.connector +import mysql.connector.constants import dbt.exceptions from dbt.adapters.sql import SQLConnectionManager @@ -16,7 +17,8 @@ @dataclass(init=False) class MySQLCredentials(Credentials): - server: str + server: Optional[str] = None + unix_socket: Optional[str] = None port: Optional[int] = None database: Optional[str] = None schema: str @@ -60,6 +62,7 @@ def _connection_keys(self): """ return ( "server", + "unix_socket", "port", "database", "schema", @@ -79,11 +82,15 @@ def open(cls, connection): credentials = cls.get_credentials(connection.credentials) kwargs = {} - kwargs["host"] = credentials.server kwargs["user"] = credentials.username kwargs["passwd"] = credentials.password kwargs["buffered"] = True + if credentials.server: + kwargs["host"] = credentials.server + elif credentials.unix_socket: + kwargs["unix_socket"] = credentials.unix_socket + if credentials.port: kwargs["port"] = credentials.port @@ -168,3 +175,9 @@ def get_response(cls, cursor) -> AdapterResponse: rows_affected=num_rows, code=code ) + + @classmethod + def data_type_code_to_name(cls, type_code: int) -> str: + field_type_values = mysql.connector.constants.FieldType.desc.values() + mapping = {code: name for (code, name) in field_type_values} + return mapping[type_code] diff --git a/dbt/adapters/mysql/impl.py b/dbt/adapters/mysql/impl.py index 7e449ef..df76878 100644 --- a/dbt/adapters/mysql/impl.py +++ b/dbt/adapters/mysql/impl.py @@ -12,6 +12,8 @@ from dbt.adapters.mysql import MySQLRelation from dbt.adapters.mysql import MySQLColumn from dbt.adapters.base import BaseRelation +from dbt.contracts.graph.nodes import ConstraintType +from dbt.adapters.base.impl import ConstraintSupport from dbt.clients.agate_helper import DEFAULT_TYPE_TESTER from dbt.events import AdapterLogger from dbt.utils import executor @@ -27,6 +29,19 @@ class MySQLAdapter(SQLAdapter): Column = MySQLColumn ConnectionManager = MySQLConnectionManager + CONSTRAINT_SUPPORT = { + ConstraintType.check: ConstraintSupport.ENFORCED, + ConstraintType.not_null: ConstraintSupport.ENFORCED, + ConstraintType.unique: ConstraintSupport.ENFORCED, + ConstraintType.primary_key: ConstraintSupport.ENFORCED, + # While Foreign Keys are indeed supported, they're not supported in + # CREATE TABLE AS SELECT statements, which is what DBT uses. + # + # It is possible to use a `post-hook` to add a foreign key after the + # table is created. + ConstraintType.foreign_key: ConstraintSupport.NOT_SUPPORTED, + } + @classmethod def date_function(cls): return "current_date()" @@ -36,7 +51,8 @@ def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str: return "timestamp" - def quote(self, identifier): + @classmethod + def quote(cls, identifier: str) -> str: return "`{}`".format(identifier) def list_relations_without_caching( @@ -157,7 +173,7 @@ def _get_one_catalog( columns: List[Dict[str, Any]] = [] for relation in self.list_relations(database, schema): - logger.debug("Getting table schema for relation {}", relation) + logger.debug("Getting table schema for relation {}", str(relation)) columns.extend(self._get_columns_for_catalog(relation)) return agate.Table.from_object(columns, column_types=DEFAULT_TYPE_TESTER) diff --git a/dbt/adapters/mysql5/column.py b/dbt/adapters/mysql5/column.py index 9ce3786..931593c 100644 --- a/dbt/adapters/mysql5/column.py +++ b/dbt/adapters/mysql5/column.py @@ -8,6 +8,14 @@ @dataclass class MySQLColumn(Column): + TYPE_LABELS = { + "STRING": "TEXT", + "VAR_STRING": "TEXT", + "LONG": "INTEGER", + "LONGLONG": "INTEGER", + "INT": "INTEGER", + "TIMESTAMP": "DATETIME", + } table_database: Optional[str] = None table_schema: Optional[str] = None table_name: Optional[str] = None diff --git a/dbt/adapters/mysql5/connections.py b/dbt/adapters/mysql5/connections.py index 6199ff5..f1481a2 100644 --- a/dbt/adapters/mysql5/connections.py +++ b/dbt/adapters/mysql5/connections.py @@ -1,6 +1,7 @@ from contextlib import contextmanager import mysql.connector +import mysql.connector.constants import dbt.exceptions from dbt.adapters.sql import SQLConnectionManager @@ -16,7 +17,8 @@ @dataclass(init=False) class MySQLCredentials(Credentials): - server: str + server: Optional[str] = None + unix_socket: Optional[str] = None port: Optional[int] = None database: Optional[str] = None schema: str @@ -61,6 +63,7 @@ def _connection_keys(self): """ return ( "server", + "unix_socket", "port", "database", "schema", @@ -80,7 +83,6 @@ def open(cls, connection): credentials = cls.get_credentials(connection.credentials) kwargs = {} - kwargs["host"] = credentials.server kwargs["user"] = credentials.username kwargs["passwd"] = credentials.password kwargs["buffered"] = True @@ -88,6 +90,11 @@ def open(cls, connection): if credentials.ssl_disabled: kwargs["ssl_disabled"] = credentials.ssl_disabled + if credentials.server: + kwargs["host"] = credentials.server + elif credentials.unix_socket: + kwargs["unix_socket"] = credentials.unix_socket + if credentials.port: kwargs["port"] = credentials.port @@ -172,3 +179,9 @@ def get_response(cls, cursor) -> AdapterResponse: rows_affected=num_rows, code=code ) + + @classmethod + def data_type_code_to_name(cls, type_code: int) -> str: + field_type_values = mysql.connector.constants.FieldType.desc.values() + mapping = {code: name for (code, name) in field_type_values} + return mapping[type_code] diff --git a/dbt/adapters/mysql5/impl.py b/dbt/adapters/mysql5/impl.py index e0d61a3..c5f7aa4 100644 --- a/dbt/adapters/mysql5/impl.py +++ b/dbt/adapters/mysql5/impl.py @@ -12,6 +12,8 @@ from dbt.adapters.mysql5 import MySQLRelation from dbt.adapters.mysql5 import MySQLColumn from dbt.adapters.base import BaseRelation +from dbt.contracts.graph.nodes import ConstraintType +from dbt.adapters.base.impl import ConstraintSupport from dbt.clients.agate_helper import DEFAULT_TYPE_TESTER from dbt.events import AdapterLogger from dbt.utils import executor @@ -27,6 +29,19 @@ class MySQLAdapter(SQLAdapter): Column = MySQLColumn ConnectionManager = MySQLConnectionManager + CONSTRAINT_SUPPORT = { + ConstraintType.check: ConstraintSupport.NOT_SUPPORTED, + ConstraintType.not_null: ConstraintSupport.ENFORCED, + ConstraintType.unique: ConstraintSupport.ENFORCED, + ConstraintType.primary_key: ConstraintSupport.ENFORCED, + # While Foreign Keys are indeed supported, they're not supported in + # CREATE TABLE AS SELECT statements, which is what DBT uses. + # + # It is possible to use a `post-hook` to add a foreign key after the + # table is created. + ConstraintType.foreign_key: ConstraintSupport.NOT_SUPPORTED, + } + @classmethod def date_function(cls): return "current_date()" @@ -36,7 +51,8 @@ def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str: return "timestamp" - def quote(self, identifier): + @classmethod + def quote(cls, identifier: str) -> str: return "`{}`".format(identifier) def list_relations_without_caching( @@ -156,7 +172,7 @@ def _get_one_catalog( columns: List[Dict[str, Any]] = [] for relation in self.list_relations(database, schema): - logger.debug("Getting table schema for relation {}", relation) + logger.debug("Getting table schema for relation {}", str(relation)) columns.extend(self._get_columns_for_catalog(relation)) return agate.Table.from_object(columns, column_types=DEFAULT_TYPE_TESTER) diff --git a/dbt/include/mariadb/macros/adapters.sql b/dbt/include/mariadb/macros/adapters.sql index 51f7476..90fa9bd 100644 --- a/dbt/include/mariadb/macros/adapters.sql +++ b/dbt/include/mariadb/macros/adapters.sql @@ -39,14 +39,29 @@ create {% if temporary: -%}temporary{%- endif %} table {{ relation.include(database=False) }} - {{ sql }} + {% set contract_config = config.get('contract') %} + {% if contract_config.enforced %} + {{ get_assert_columns_equivalent(sql) }} + {{ get_table_columns_and_constraints() }} + {%- set sql = get_select_subquery(sql) %} + ( + {{ sql }} + ) + {% else %} + {{ sql }} + {% endif %} {% endmacro %} {% macro mariadb__create_view_as(relation, sql) -%} {%- set sql_header = config.get('sql_header', none) -%} {{ sql_header if sql_header is not none }} - create view {{ relation }} as + create view {{ relation }} + {% set contract_config = config.get('contract') %} + {% if contract_config.enforced %} + {{ get_assert_columns_equivalent(sql) }} + {%- endif %} + as {{ sql }} {%- endmacro %} @@ -101,3 +116,38 @@ {% macro mariadb__generate_database_name(custom_database_name=none, node=none) -%} {% do return(None) %} {%- endmacro %} + +{% macro mariadb__get_phony_data_for_type(data_type) %} + {# + The types that MariaDB supports in CAST statements are NOT the same as the + types that are supported in table definitions. This is a bit of a hack to + work around the known mismatches. + #} + {%- if data_type.lower() == 'integer' -%} + 0 + {%- elif data_type.lower() == 'text' -%} + '' + {%- elif data_type.lower() == 'integer unsigned' -%} + cast(null as unsigned) + {%- elif data_type.lower() == 'integer signed' -%} + cast(null as signed) + {%- else -%} + cast(null as {{ data_type }}) + {%- endif -%} +{% endmacro %} + +{% macro mariadb__get_empty_schema_sql(columns) %} + {%- set col_err = [] -%} + select + {% for i in columns %} + {%- set col = columns[i] -%} + {%- if col['data_type'] is not defined -%} + {{ col_err.append(col['name']) }} + {%- endif -%} + {% set col_name = adapter.quote(col['name']) if col.get('quote') else col['name'] %} + {{ mariadb__get_phony_data_for_type(col['data_type']) }} as {{ col_name }}{{ ", " if not loop.last }} + {%- endfor -%} + {%- if (col_err | length) > 0 -%} + {{ exceptions.column_type_missing(column_names=col_err) }} + {%- endif -%} +{% endmacro %} diff --git a/dbt/include/mysql/macros/adapters.sql b/dbt/include/mysql/macros/adapters.sql index eada89e..3b8b4d5 100644 --- a/dbt/include/mysql/macros/adapters.sql +++ b/dbt/include/mysql/macros/adapters.sql @@ -39,9 +39,17 @@ create {% if temporary: -%}temporary{%- endif %} table {{ relation.include(database=False) }} - as ( - {{ sql }} - ) + {% set contract_config = config.get('contract') %} + {% if contract_config.enforced %} + {{ get_assert_columns_equivalent(sql) }} + {{ get_table_columns_and_constraints() }} + {%- set sql = get_select_subquery(sql) %} + {% else %} + as + {% endif %} + ( + {{ sql }} + ) {% endmacro %} {% macro mysql__current_timestamp() -%} @@ -95,3 +103,38 @@ {% macro mysql__generate_database_name(custom_database_name=none, node=none) -%} {% do return(None) %} {%- endmacro %} + +{% macro mysql__get_phony_data_for_type(data_type) %} + {# + The types that MySQL supports in CAST statements are NOT the same as the + types that are supported in table definitions. This is a bit of a hack to + work around the known mismatches. + #} + {%- if data_type.lower() == 'integer' -%} + 0 + {%- elif data_type.lower() == 'text' -%} + '' + {%- elif data_type.lower() == 'integer unsigned' -%} + cast(null as unsigned) + {%- elif data_type.lower() == 'integer signed' -%} + cast(null as signed) + {%- else -%} + cast(null as {{ data_type }}) + {%- endif -%} +{% endmacro %} + +{% macro mysql__get_empty_schema_sql(columns) %} + {%- set col_err = [] -%} + select + {% for i in columns %} + {%- set col = columns[i] -%} + {%- if col['data_type'] is not defined -%} + {{ col_err.append(col['name']) }} + {%- endif -%} + {% set col_name = adapter.quote(col['name']) if col.get('quote') else col['name'] %} + {{ mysql__get_phony_data_for_type(col['data_type']) }} as {{ col_name }}{{ ", " if not loop.last }} + {%- endfor -%} + {%- if (col_err | length) > 0 -%} + {{ exceptions.column_type_missing(column_names=col_err) }} + {%- endif -%} +{% endmacro %} diff --git a/dbt/include/mysql5/macros/adapters.sql b/dbt/include/mysql5/macros/adapters.sql index 624d282..5c0c9bc 100644 --- a/dbt/include/mysql5/macros/adapters.sql +++ b/dbt/include/mysql5/macros/adapters.sql @@ -39,7 +39,17 @@ create {% if temporary: -%}temporary{%- endif %} table {{ relation.include(database=False) }} - {{ sql }} + {% set contract_config = config.get('contract') %} + {% if contract_config.enforced %} + {{ get_assert_columns_equivalent(sql) }} + {{ get_table_columns_and_constraints() }} + {%- set sql = get_select_subquery(sql) %} + ( + {{ sql }} + ) + {% else %} + {{ sql }} + {% endif %} {% endmacro %} {% macro mysql5__current_timestamp() -%} @@ -93,3 +103,38 @@ {% macro mysql5__generate_database_name(custom_database_name=none, node=none) -%} {% do return(None) %} {%- endmacro %} + +{% macro mysql5__get_phony_data_for_type(data_type) %} + {# + The types that MySQL supports in CAST statements are NOT the same as the + types that are supported in table definitions. This is a bit of a hack to + work around the known mismatches. + #} + {%- if data_type.lower() == 'integer' -%} + 0 + {%- elif data_type.lower() == 'text' -%} + '' + {%- elif data_type.lower() == 'integer unsigned' -%} + cast(null as unsigned) + {%- elif data_type.lower() == 'integer signed' -%} + cast(null as signed) + {%- else -%} + cast(null as {{ data_type }}) + {%- endif -%} +{% endmacro %} + +{% macro mysql5__get_empty_schema_sql(columns) %} + {%- set col_err = [] -%} + select + {% for i in columns %} + {%- set col = columns[i] -%} + {%- if col['data_type'] is not defined -%} + {{ col_err.append(col['name']) }} + {%- endif -%} + {% set col_name = adapter.quote(col['name']) if col.get('quote') else col['name'] %} + {{ mysql5__get_phony_data_for_type(col['data_type']) }} as {{ col_name }}{{ ", " if not loop.last }} + {%- endfor -%} + {%- if (col_err | length) > 0 -%} + {{ exceptions.column_type_missing(column_names=col_err) }} + {%- endif -%} +{% endmacro %} diff --git a/tests/functional/adapter/constraints/fixtures.py b/tests/functional/adapter/constraints/fixtures.py new file mode 100644 index 0000000..2b860f4 --- /dev/null +++ b/tests/functional/adapter/constraints/fixtures.py @@ -0,0 +1,320 @@ +# model breaking constraints +my_model_with_nulls_sql = """ +{{ + config( + materialized = "table" + ) +}} + +select + -- null value for 'id' + CAST(null AS UNSIGNED) as id, + -- change the color as well (to test rollback) + 'red' as color, + '2019-01-01' as date_day +""" + + +my_model_view_with_nulls_sql = """ +{{ + config( + materialized = "view" + ) +}} + +select + -- null value for 'id' + CAST(null AS UNSIGNED) as id, + -- change the color as well (to test rollback) + 'red' as color, + '2019-01-01' as date_day +""" + +my_model_incremental_with_nulls_sql = """ +{{ + config( + materialized = "incremental", + on_schema_change='append_new_columns' ) +}} + +select + -- null value for 'id' + CAST(null AS UNSIGNED) as id, + -- change the color as well (to test rollback) + 'red' as color, + '2019-01-01' as date_day +""" + +# model columns data types different to schema definitions +my_model_contract_sql_header_sql = """ +{{ + config( + materialized = "table" + ) +}} + +select 'Kolkata' as column_name +""" + +my_model_incremental_contract_sql_header_sql = """ +{{ + config( + materialized = "incremental", + on_schema_change="append_new_columns" + ) +}} + +select 'Kolkata' as column_name +""" + +constrained_model_schema_yml = """ +version: 2 +models: + - name: my_model + config: + contract: + enforced: true + constraints: + - type: check + expression: (id > 0) + - type: check + expression: id >= 1 + - type: primary_key + columns: [ id ] + - type: unique + columns: [ color(10), date_day(20) ] + name: strange_uniqueness_requirement + - type: foreign_key + columns: [ id ] + expression: {schema}.foreign_key_model (id) + columns: + - name: id + data_type: integer + description: hello + constraints: + - type: not_null + tests: + - unique + - name: color + data_type: text + - name: date_day + data_type: text + - name: foreign_key_model + config: + contract: + enforced: true + columns: + - name: id + data_type: integer + constraints: + - type: unique + - type: primary_key +""" + +model_quoted_column_schema_yml = """ +version: 2 +models: + - name: my_model + config: + contract: + enforced: true + materialized: table + constraints: + - type: check + # this one is the on the user + expression: (`from` = 'blue') + columns: [ '`from`' ] + columns: + - name: id + data_type: integer + description: hello + constraints: + - type: not_null + tests: + - unique + - name: from # reserved word + quote: true + data_type: text + constraints: + - type: not_null + - name: date_day + data_type: text +""" + +# MariaDB does not support multiple column-level CHECK constraints +mariadb_model_schema_yml = """ +version: 2 +models: + - name: my_model + config: + contract: + enforced: true + columns: + - name: id + data_type: integer + description: hello + constraints: + - type: not_null + - type: primary_key + - type: check + expression: id > 0 AND id >= 1 + tests: + - unique + - name: color + data_type: text + - name: date_day + data_type: text + - name: my_model_error + config: + contract: + enforced: true + columns: + - name: id + data_type: integer + description: hello + constraints: + - type: not_null + - type: primary_key + - type: check + expression: (id > 0) + tests: + - unique + - name: color + data_type: text + - name: date_day + data_type: text + - name: my_model_wrong_order + config: + contract: + enforced: true + columns: + - name: id + data_type: integer + description: hello + constraints: + - type: not_null + - type: primary_key + - type: check + expression: (id > 0) + tests: + - unique + - name: color + data_type: text + - name: date_day + data_type: text + - name: my_model_wrong_name + config: + contract: + enforced: true + columns: + - name: id + data_type: integer + description: hello + constraints: + - type: not_null + - type: primary_key + - type: check + expression: (id > 0) + tests: + - unique + - name: color + data_type: text + - name: date_day + data_type: text +""" + +# MariaDB does not support multiple column-level CHECK constraints +# Additionally, MariaDB requires CHECK constraints to come last +mariadb_model_fk_constraint_schema_yml = """ +version: 2 +models: + - name: my_model + config: + contract: + enforced: true + columns: + - name: id + data_type: integer + description: hello + constraints: + - type: not_null + - type: primary_key + - type: foreign_key + expression: {schema}.foreign_key_model (id) + - type: unique + - type: check + expression: id > 0 AND id >= 1 + tests: + - unique + - name: color + data_type: text + - name: date_day + data_type: text + - name: my_model_error + config: + contract: + enforced: true + columns: + - name: id + data_type: integer + description: hello + constraints: + - type: not_null + - type: primary_key + - type: check + expression: (id > 0) + tests: + - unique + - name: color + data_type: text + - name: date_day + data_type: text + - name: my_model_wrong_order + config: + contract: + enforced: true + columns: + - name: id + data_type: integer + description: hello + constraints: + - type: not_null + - type: primary_key + - type: check + expression: (id > 0) + tests: + - unique + - name: color + data_type: text + - name: date_day + data_type: text + - name: my_model_wrong_name + config: + contract: + enforced: true + columns: + - name: id + data_type: integer + description: hello + constraints: + - type: not_null + - type: primary_key + - type: check + expression: (id > 0) + tests: + - unique + - name: color + data_type: text + - name: date_day + data_type: text + - name: foreign_key_model + config: + contract: + enforced: true + columns: + - name: id + data_type: integer + constraints: + - type: unique + - type: primary_key +""" diff --git a/tests/functional/adapter/constraints/test_constraints.py b/tests/functional/adapter/constraints/test_constraints.py new file mode 100644 index 0000000..78b6710 --- /dev/null +++ b/tests/functional/adapter/constraints/test_constraints.py @@ -0,0 +1,365 @@ +import pytest + +from dbt.tests.adapter.constraints.test_constraints import ( + BaseTableConstraintsColumnsEqual, + BaseViewConstraintsColumnsEqual, + BaseTableContractSqlHeader, + BaseIncrementalContractSqlHeader, + BaseIncrementalConstraintsColumnsEqual, + BaseConstraintsRuntimeDdlEnforcement, + BaseConstraintsRollback, + BaseIncrementalConstraintsRuntimeDdlEnforcement, + BaseIncrementalConstraintsRollback, + BaseModelConstraintsRuntimeEnforcement, + BaseConstraintQuotedColumn, +) + +from dbt.tests.adapter.constraints.fixtures import ( + my_incremental_model_sql, + model_contract_header_schema_yml, + model_schema_yml, + my_model_wrong_order_depends_on_fk_sql, + foreign_key_model_sql, + my_model_with_quoted_column_name_sql, + my_model_incremental_wrong_order_depends_on_fk_sql, + model_fk_constraint_schema_yml, +) + +from tests.functional.adapter.constraints.fixtures import ( + my_model_with_nulls_sql, + my_model_incremental_with_nulls_sql, + my_model_contract_sql_header_sql, + my_model_incremental_contract_sql_header_sql, + mariadb_model_schema_yml, + mariadb_model_fk_constraint_schema_yml, + constrained_model_schema_yml, + model_quoted_column_schema_yml, +) + + +class MySQLColumnEqualSetup: + @pytest.fixture + def int_type(self): + return "INTEGER" + + @pytest.fixture + def schema_int_type(self): + return "INTEGER" + + @pytest.fixture + def data_types(self, int_type, schema_int_type, string_type): + # sql_column_value, schema_data_type, error_data_type + return [ + ["1", schema_int_type, int_type], + ["'str'", string_type, string_type], + ["cast('2019-01-01' as date)", "date", "DATE"], + ["cast('2013-11-03 00:00:00' as datetime)", "datetime", "DATETIME"], + ] + + +class TestMySQLTableConstraintsColumnsEqual( + MySQLColumnEqualSetup, BaseTableConstraintsColumnsEqual +): + pass + + +class TestMySQLViewConstraintsColumnsEqual( + MySQLColumnEqualSetup, BaseViewConstraintsColumnsEqual +): + pass + + +class TestMySQLIncrementalConstraintsColumnsEqual( + MySQLColumnEqualSetup, BaseIncrementalConstraintsColumnsEqual +): + pass + + +class TestMySQLTableContractsSqlHeader(BaseTableContractSqlHeader): + @pytest.fixture(scope="class") + def models(self): + return { + "my_model_contract_sql_header.sql": my_model_contract_sql_header_sql, + "constraints_schema.yml": model_contract_header_schema_yml, + } + + +class TestMySQLIncrementalContractsSqlHeader(BaseIncrementalContractSqlHeader): + @pytest.fixture(scope="class") + def models(self): + return { + "my_model_contract_sql_header.sql": my_model_incremental_contract_sql_header_sql, + "constraints_schema.yml": model_contract_header_schema_yml, + } + + +# MySQL 5 does not support CHECK constraints +_expected_mysql5_ddl_enforcement_sql = """ + create table ( + id integer not null primary key unique, + color text, + date_day text + ) + ( + select id, color, date_day + from ( + -- depends_on: + select + 'blue' as color, + 1 as id, + '2019-01-01' as date_day + ) as model_subq + ) +""" + +# MariaDB does not support multiple column-level CHECK constraints +# Additionally, MariaDB requires CHECK constraints to come last +_expected_mariadb_ddl_enforcement_sql = """ + create table ( + id integer not null primary key unique check (id > 0 AND id >= 1), + color text, + date_day text + ) + ( + select id, color, date_day + from ( + -- depends_on: + select + 'blue' as color, + 1 as id, + '2019-01-01' as date_day + ) as model_subq + ) +""" + +_expected_mysql_ddl_enforcement_sql = """ + create table ( + id integer not null primary key check ((id > 0)) check (id >= 1) unique, + color text, + date_day text + ) + ( + select id, color, date_day + from ( + -- depends_on: + select + 'blue' as color, + 1 as id, + '2019-01-01' as date_day + ) as model_subq + ) +""" + + +class TestMySQLTableConstraintsDdlEnforcement(BaseConstraintsRuntimeDdlEnforcement): + @pytest.fixture(scope="class") + def models(self, dbt_profile_target): + if dbt_profile_target["type"] == "mariadb": + return { + "my_model.sql": my_model_incremental_wrong_order_depends_on_fk_sql, + "foreign_key_model.sql": foreign_key_model_sql, + "constraints_schema.yml": mariadb_model_fk_constraint_schema_yml, + } + else: + return { + "my_model.sql": my_model_incremental_wrong_order_depends_on_fk_sql, + "foreign_key_model.sql": foreign_key_model_sql, + "constraints_schema.yml": model_fk_constraint_schema_yml, + } + + @pytest.fixture(scope="class") + def expected_sql(self, dbt_profile_target): + if dbt_profile_target["type"] == "mysql5": + return _expected_mysql5_ddl_enforcement_sql + elif dbt_profile_target["type"] == "mariadb": + return _expected_mariadb_ddl_enforcement_sql + else: + return _expected_mysql_ddl_enforcement_sql + + +class TestMySQLIncrementalConstraintsDdlEnforcement( + BaseIncrementalConstraintsRuntimeDdlEnforcement +): + @pytest.fixture(scope="class") + def models(self, dbt_profile_target): + if dbt_profile_target["type"] == "mariadb": + return { + "my_model.sql": my_model_incremental_wrong_order_depends_on_fk_sql, + "foreign_key_model.sql": foreign_key_model_sql, + "constraints_schema.yml": mariadb_model_fk_constraint_schema_yml, + } + else: + return { + "my_model.sql": my_model_incremental_wrong_order_depends_on_fk_sql, + "foreign_key_model.sql": foreign_key_model_sql, + "constraints_schema.yml": model_fk_constraint_schema_yml, + } + + @pytest.fixture(scope="class") + def expected_sql(self, dbt_profile_target): + if dbt_profile_target["type"] == "mysql5": + return _expected_mysql5_ddl_enforcement_sql + elif dbt_profile_target["type"] == "mariadb": + return _expected_mariadb_ddl_enforcement_sql + else: + return _expected_mysql_ddl_enforcement_sql + + +class TestMySQLTableConstraintsRollback(BaseConstraintsRollback): + @pytest.fixture(scope="class") + def models(self, dbt_profile_target): + if dbt_profile_target["type"] == "mariadb": + return { + "my_model.sql": my_incremental_model_sql, + "constraints_schema.yml": mariadb_model_schema_yml, + } + else: + return { + "my_model.sql": my_incremental_model_sql, + "constraints_schema.yml": model_schema_yml, + } + + @pytest.fixture(scope="class") + def expected_error_messages(self): + return ["Column 'id' cannot be null"] + + @pytest.fixture(scope="class") + def null_model_sql(self): + return my_model_with_nulls_sql + + +class TestMySQLIncrementalConstraintsRollback(BaseIncrementalConstraintsRollback): + @pytest.fixture(scope="class") + def models(self, dbt_profile_target): + if dbt_profile_target["type"] == "mariadb": + return { + "my_model.sql": my_incremental_model_sql, + "constraints_schema.yml": mariadb_model_schema_yml, + } + else: + return { + "my_model.sql": my_incremental_model_sql, + "constraints_schema.yml": model_schema_yml, + } + + @pytest.fixture(scope="class") + def expected_error_messages(self): + return ["Column 'id' cannot be null"] + + @pytest.fixture(scope="class") + def null_model_sql(self): + return my_model_incremental_with_nulls_sql + + +# MySQL 5 does not support CHECK constraints +_expected_mysql5_runtime_enforcement_sql = """ + create table ( + id integer not null, + color text, + date_day text, + primary key (id), + constraint strange_uniqueness_requirement unique (color(10), date_day(20)) + ) + ( + select id, color, date_day + from ( + -- depends_on: + select + 'blue' as color, + 1 as id, + '2019-01-01' as date_day + ) as model_subq + ) +""" + +_expected_mysql_runtime_enforcement_sql = """ + create table ( + id integer not null, + color text, + date_day text, + check ((id > 0)), + check (id >= 1), + primary key (id), + constraint strange_uniqueness_requirement unique (color(10), date_day(20)) + ) + ( + select id, color, date_day + from ( + -- depends_on: + select + 'blue' as color, + 1 as id, + '2019-01-01' as date_day + ) as model_subq + ) +""" + + +class TestMySQLModelConstraintsRuntimeEnforcement(BaseModelConstraintsRuntimeEnforcement): + @pytest.fixture(scope="class") + def models(self): + return { + "my_model.sql": my_model_wrong_order_depends_on_fk_sql, + "foreign_key_model.sql": foreign_key_model_sql, + "constraints_schema.yml": constrained_model_schema_yml, + } + + @pytest.fixture(scope="class") + def expected_sql(self, dbt_profile_target): + if dbt_profile_target["type"] == "mysql5": + return _expected_mysql5_runtime_enforcement_sql + else: + return _expected_mysql_runtime_enforcement_sql + + +# MySQL 5 does not support CHECK constraints +_expected_mysql5_quoted_column_sql = """ + create table ( + id integer not null, + `from` text not null, + date_day text + ) + ( + select id, `from`, date_day + from ( + select + 'blue' as `from`, + 1 as id, + '2019-01-01' as date_day + ) as model_subq + ) +""" + +_expected_mysql_quoted_column_sql = """ + create table ( + id integer not null, + `from` text not null, + date_day text, + check ((`from` = 'blue')) + ) + ( + select id, `from`, date_day + from ( + select + 'blue' as `from`, + 1 as id, + '2019-01-01' as date_day + ) as model_subq + ) +""" + + +class TestMySQLConstraintQuotedColumn(BaseConstraintQuotedColumn): + @pytest.fixture(scope="class") + def models(self): + return { + "my_model.sql": my_model_with_quoted_column_name_sql, + "constraints_schema.yml": model_quoted_column_schema_yml, + } + + @pytest.fixture(scope="class") + def expected_sql(self, dbt_profile_target): + if dbt_profile_target["type"] == "mysql5": + return _expected_mysql5_quoted_column_sql + else: + return _expected_mysql_quoted_column_sql diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index 8c499d5..06af385 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -8,7 +8,6 @@ class TestMySQLAdapter(unittest.TestCase): def setUp(self): - pass flags.STRICT_MODE = True profile_cfg = { diff --git a/tests/unit/utils.py b/tests/unit/utils.py index 07371d3..74d8483 100644 --- a/tests/unit/utils.py +++ b/tests/unit/utils.py @@ -44,6 +44,13 @@ def profile_from_dict(profile, profile_name, cli_vars="{}"): cli_vars = parse_cli_vars(cli_vars) renderer = ProfileRenderer(cli_vars) + + # in order to call dbt's internal profile rendering, we need to set the + # flags global. This is a bit of a hack, but it's the best way to do it. + from dbt.flags import set_from_args + from argparse import Namespace + + set_from_args(Namespace(), None) return Profile.from_raw_profile_info( profile, profile_name, @@ -72,7 +79,7 @@ def project_from_dict(project, profile, packages=None, selectors=None, cli_vars= def config_from_parts_or_dicts( - project, profile, packages=None, selectors=None, cli_vars="{}" + project, profile, packages=None, selectors=None, cli_vars={} ): from dbt.config import Project, Profile, RuntimeConfig from copy import deepcopy