From a3bab12976e520d433ecdd7a5e9340c3fea5ebc8 Mon Sep 17 00:00:00 2001 From: dwreeves Date: Sun, 29 Oct 2023 22:03:34 -0400 Subject: [PATCH] fixes --- .github/workflows/tests.yml | 7 ++- .gitignore | 1 + README.md | 15 ++++-- integration_tests/dbt_project.yml | 2 +- .../collinear_matrix_regression_chol.sql | 3 +- ...ear_matrix_regression_chol_unoptimized.sql | 3 +- ...collinear_matrix_ridge_regression_chol.sql | 3 +- ...trix_ridge_regression_chol_unoptimized.sql | 3 +- ...oups_matrix_regression_chol_optimized.sql} | 6 ++- ...ups_matrix_regression_chol_unoptimized.sql | 3 +- .../models/simple_10var_regression_long.sql | 3 +- .../models/simple_4var_regression_wide.sql | 3 +- .../models/simple_5var_regression_long.sql | 3 +- .../models/simple_5var_regression_wide.sql | 3 +- .../models/simple_8var_regression_wide.sql | 4 +- integration_tests/profiles/profiles.yml | 8 ++-- integration_tests/selectors.yml | 11 +++++ ...oups_matrix_regression_chol_optimized.sql} | 2 +- .../tests/test_long_format_options.sql | 12 ++--- .../ols_impl_chol/_ols_impl_chol.sql | 3 ++ macros/linear_regression/utils.sql | 29 +++++++++--- run | 47 ++++++++++++++----- 22 files changed, 124 insertions(+), 50 deletions(-) rename integration_tests/models/{groups_matrix_regression_chol.sql => groups_matrix_regression_chol_optimized.sql} (64%) create mode 100644 integration_tests/selectors.yml rename integration_tests/tests/{test_groups_matrix_regression_chol.sql => test_groups_matrix_regression_chol_optimized.sql} (96%) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index f657785..2e7972e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -19,8 +19,11 @@ jobs: postgres: image: postgres env: + POSTGRES_USER: postgres POSTGRES_PASSWORD: postgres - POSTGRES_DB: postgres + POSTGRES_DB: dbt_linreg + ports: + - 5432:5432 options: >- --health-cmd pg_isready --health-interval 10s @@ -55,5 +58,7 @@ jobs: run: ./run test "${DBT_TARGET}" env: DBT_TARGET: ${{ 
matrix.db_target }} + POSTGRES_HOST: localhost POSTGRES_USER: postgres POSTGRES_PASSWORD: postgres + POSTGRES_DB: dbt_linreg diff --git a/.gitignore b/.gitignore index c076aed..0ce0028 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ dbt.duckdb +dbt.duckdb.wal .user.yml docs/site/ integration_tests/seeds/*.csv diff --git a/README.md b/README.md index b8a4230..cbab499 100644 --- a/README.md +++ b/README.md @@ -168,11 +168,13 @@ group by **dbt_linreg** should work with most SQL databases, but so far, testing has been done for the following database tools: - Snowflake -- Postgres - DuckDB +- Postgres\* If `dbt_linreg` does not work in your database tool, please let me know in a bug report and I can make sure it is supported. +_* Syntactically supported, but not performant._ + # API The only function available in the public API is the `dbt_linreg.ols()` macro. @@ -256,7 +258,7 @@ This method calculates regression coefficients using the Moore-Penrose pseudo-in Specify these in a dict using the `method_options=` kwarg: - **safe** (default = `True`): If True, returns null coefficients instead of an error when X is perfectly multicollinear. If False, a negative value will be passed into a SQRT(), and most SQL engines will raise an error when this happens. -- **subquery_optimization** (default = `True`): If True, nested subqueries are used during some of the steps to optimize the query speed. If false, the query is flattened. Note that turning this off can significantly degrade performance. +- **subquery_optimization** (default = `True`): If True, nested subqueries are used during some of the steps to optimize the query speed. If False, the query is flattened. ## `fwl` method @@ -270,10 +272,12 @@ Ridge regression is implemented using the augmentation technique described in Ex There are a few reasons why this method is discouraged over the `chol` method: -- 🐌 It tends to be much slower, and struggles to efficiently calculate large number of columns. 
+- 🐌 It tends to be much slower in OLAP systems, and struggles to efficiently handle a large number of columns. - 📊 It does not calculate standard errors. - 😕 For ridge regression, coefficients are not accurate; they tend to be off by a magnitude of ~0.01%. +So when should you use `fwl`? The main use case is in OLTP systems (e.g. Postgres) for unregularized coefficient estimation. Long story short, the `chol` method relies on subquery optimization to be more performant than `fwl`; however, OLTP systems do not benefit at all from subquery optimization. This means that `fwl` is slightly more performant in this context. + # Notes - ⚠️ **If your coefficients are null, it does not mean dbt_linreg is broken, it most likely means your feature columns are perfectly multicollinear.** If you are 100% sure that is not the issue, please file a bug report with a minimally reproducible example. @@ -283,6 +287,11 @@ There are a few reasons why this method is discouraged over the `chol` method: - An array input (e.g. `alpha=[0.01, 0.02, 0.03, 0.04, 0.05]`) will apply an alpha of `0.01` to the first column, `0.02` to the second column, etc. - `alpha` is equivalent to what TEoSL refers to as "lambda," times the sample size N. That is to say: `α ≡ λ * N`. +- Regularization as currently implemented for the `chol` method tends to be very slow in OLTP systems (e.g. Postgres), but is very performant in OLAP systems (e.g. Snowflake, DuckDB, BigQuery, Redshift). As dbt is more commonly used in OLAP contexts, the code base is optimized for the OLAP use case. + - That said, it may be possible to make regularization in OLTP more performant (e.g. with augmentation of the design matrix), so PRs are welcome. + +- Regression coefficients in Postgres are always `numeric` types. 
+ ### Possible future features Some things I am thinking about working on down the line: diff --git a/integration_tests/dbt_project.yml b/integration_tests/dbt_project.yml index c85ee09..88472f4 100644 --- a/integration_tests/dbt_project.yml +++ b/integration_tests/dbt_project.yml @@ -1,5 +1,5 @@ name: "dbt_linreg_tests" -version: "0.2.1" +version: "0.2.3" require-dbt-version: [">=1.0.0", "<2.0.0"] diff --git a/integration_tests/models/collinear_matrix_regression_chol.sql b/integration_tests/models/collinear_matrix_regression_chol.sql index 470875d..0874ed7 100644 --- a/integration_tests/models/collinear_matrix_regression_chol.sql +++ b/integration_tests/models/collinear_matrix_regression_chol.sql @@ -1,6 +1,7 @@ {{ config( - materialized="table" + materialized="table", + tags=["skip-postgres"] ) }} select * from {{ diff --git a/integration_tests/models/collinear_matrix_regression_chol_unoptimized.sql b/integration_tests/models/collinear_matrix_regression_chol_unoptimized.sql index 4bae3ab..ae8a0da 100644 --- a/integration_tests/models/collinear_matrix_regression_chol_unoptimized.sql +++ b/integration_tests/models/collinear_matrix_regression_chol_unoptimized.sql @@ -1,6 +1,7 @@ {{ config( - materialized="table" + materialized="table", + tags=["skip-postgres"] ) }} select * from {{ diff --git a/integration_tests/models/collinear_matrix_ridge_regression_chol.sql b/integration_tests/models/collinear_matrix_ridge_regression_chol.sql index ba7d4e3..ccb622e 100644 --- a/integration_tests/models/collinear_matrix_ridge_regression_chol.sql +++ b/integration_tests/models/collinear_matrix_ridge_regression_chol.sql @@ -1,6 +1,7 @@ {{ config( - materialized="table" + materialized="table", + tags=["skip-postgres"] ) }} select * from {{ diff --git a/integration_tests/models/collinear_matrix_ridge_regression_chol_unoptimized.sql b/integration_tests/models/collinear_matrix_ridge_regression_chol_unoptimized.sql index d7465d5..e0e5160 100644 --- 
a/integration_tests/models/collinear_matrix_ridge_regression_chol_unoptimized.sql +++ b/integration_tests/models/collinear_matrix_ridge_regression_chol_unoptimized.sql @@ -1,6 +1,7 @@ {{ config( - materialized="table" + materialized="table", + tags=["skip-postgres"] ) }} select * from {{ diff --git a/integration_tests/models/groups_matrix_regression_chol.sql b/integration_tests/models/groups_matrix_regression_chol_optimized.sql similarity index 64% rename from integration_tests/models/groups_matrix_regression_chol.sql rename to integration_tests/models/groups_matrix_regression_chol_optimized.sql index 83cb626..0b6a836 100644 --- a/integration_tests/models/groups_matrix_regression_chol.sql +++ b/integration_tests/models/groups_matrix_regression_chol_optimized.sql @@ -1,6 +1,7 @@ {{ config( - materialized="table" + materialized="table", + tags=["skip-postgres"] ) }} select * from {{ @@ -10,7 +11,8 @@ select * from {{ exog=['x1', 'x2', 'x3'], group_by=['gb_var'], format='long', - method='chol' + method='chol', + method_options={'subquery_optimization': True} ) }} as linreg order by gb_var, variable_name diff --git a/integration_tests/models/groups_matrix_regression_chol_unoptimized.sql b/integration_tests/models/groups_matrix_regression_chol_unoptimized.sql index 7c3a3f0..0a6e718 100644 --- a/integration_tests/models/groups_matrix_regression_chol_unoptimized.sql +++ b/integration_tests/models/groups_matrix_regression_chol_unoptimized.sql @@ -1,6 +1,7 @@ {{ config( - materialized="table" + materialized="table", + tags=["skip-postgres"] ) }} select * from {{ diff --git a/integration_tests/models/simple_10var_regression_long.sql b/integration_tests/models/simple_10var_regression_long.sql index 6c881a1..e1d16c2 100644 --- a/integration_tests/models/simple_10var_regression_long.sql +++ b/integration_tests/models/simple_10var_regression_long.sql @@ -1,7 +1,8 @@ {{ config( materialized="view", - enabled=False + enabled=False, + tags=["skip-postgres"] ) }} select * from {{ 
diff --git a/integration_tests/models/simple_4var_regression_wide.sql b/integration_tests/models/simple_4var_regression_wide.sql index fd4f7f7..8c6cedb 100644 --- a/integration_tests/models/simple_4var_regression_wide.sql +++ b/integration_tests/models/simple_4var_regression_wide.sql @@ -1,6 +1,7 @@ {{ config( - materialized="table" + materialized="table", + tags=["skip-postgres"] ) }} select * from {{ diff --git a/integration_tests/models/simple_5var_regression_long.sql b/integration_tests/models/simple_5var_regression_long.sql index 37a6215..466fbfc 100644 --- a/integration_tests/models/simple_5var_regression_long.sql +++ b/integration_tests/models/simple_5var_regression_long.sql @@ -1,6 +1,7 @@ {{ config( - materialized="table" + materialized="table", + tags=["skip-postgres"] ) }} select * from {{ diff --git a/integration_tests/models/simple_5var_regression_wide.sql b/integration_tests/models/simple_5var_regression_wide.sql index d5e59ff..9c21289 100644 --- a/integration_tests/models/simple_5var_regression_wide.sql +++ b/integration_tests/models/simple_5var_regression_wide.sql @@ -1,6 +1,7 @@ {{ config( - materialized="table" + materialized="table", + tags=["skip-postgres"] ) }} select * from {{ diff --git a/integration_tests/models/simple_8var_regression_wide.sql b/integration_tests/models/simple_8var_regression_wide.sql index 3ede86e..3b0f9b4 100644 --- a/integration_tests/models/simple_8var_regression_wide.sql +++ b/integration_tests/models/simple_8var_regression_wide.sql @@ -1,8 +1,8 @@ {{ config( materialized="view", - tags=["perftest"], - enabled=False + tags=["perftest", "skip-postgres"], + enabled=False, ) }} select * from {{ diff --git a/integration_tests/profiles/profiles.yml b/integration_tests/profiles/profiles.yml index 9557e16..a2b454b 100644 --- a/integration_tests/profiles/profiles.yml +++ b/integration_tests/profiles/profiles.yml @@ -5,12 +5,10 @@ dbt_linreg_profile: type: duckdb path: dbt.duckdb dbt-postgres: - # This is configured for Github 
Actions. - # For local configuration, set env vars. type: postgres - user: '{{ env_var("POSTGRES_USER", "postgres") }}' - password: '{{ env_var("POSTGRES_PASSWORD", "postgres") }}' + user: '{{ env_var("POSTGRES_USER") }}' + password: '{{ env_var("POSTGRES_PASSWORD") }}' host: '{{ env_var("POSTGRES_HOST", "localhost") }}' port: '{{ env_var("POSTGRES_PORT", "5432") | as_number }}' - dbname: '{{ env_var("POSTGRES_DB", "postgres") }}' + dbname: '{{ env_var("POSTGRES_DB", "dbt_linreg") }}' schema: '{{ env_var("POSTGRES_SCHEMA", "public") }}' diff --git a/integration_tests/selectors.yml b/integration_tests/selectors.yml new file mode 100644 index 0000000..37e3554 --- /dev/null +++ b/integration_tests/selectors.yml @@ -0,0 +1,11 @@ +selectors: + - name: dbt-duckdb-selector + definition: 'fqn:*' + - name: dbt-postgres-selector + # Postgres runs into memory / performance issues for some of these queries. + # Resolving this and making Postgres more performant is a TODO. + definition: + union: + - 'fqn:*' + - exclude: + - '@tag:skip-postgres' diff --git a/integration_tests/tests/test_groups_matrix_regression_chol.sql b/integration_tests/tests/test_groups_matrix_regression_chol_optimized.sql similarity index 96% rename from integration_tests/tests/test_groups_matrix_regression_chol.sql rename to integration_tests/tests/test_groups_matrix_regression_chol_optimized.sql index 3bd35cd..22075ae 100644 --- a/integration_tests/tests/test_groups_matrix_regression_chol.sql +++ b/integration_tests/tests/test_groups_matrix_regression_chol_optimized.sql @@ -21,7 +21,7 @@ expected as ( ) select base.variable_name -from {{ ref('groups_matrix_regression_chol') }} as base +from {{ ref('groups_matrix_regression_chol_optimized') }} as base full outer join expected on base.gb_var = expected.gb_var diff --git a/integration_tests/tests/test_long_format_options.sql b/integration_tests/tests/test_long_format_options.sql index 5b35bbb..1d8f12b 100644 --- 
a/integration_tests/tests/test_long_format_options.sql +++ b/integration_tests/tests/test_long_format_options.sql @@ -10,8 +10,8 @@ base as ( find_unstripped_quotes as ( select - max(vname = '"xa"') as should_be_true, - max(vname = 'xa') as should_be_false + cast(max(cast(vname = '"xa"' as integer)) as boolean) as should_be_true, + cast(max(cast(vname = 'xa' as integer)) as boolean) as should_be_false from base where not strip_quotes @@ -20,8 +20,8 @@ find_unstripped_quotes as ( dodge_unstripped_quotes as ( select - max(vname = 'xa') as should_be_true, - max(vname = '"xa"') as should_be_false + cast(max(cast(vname = 'xa' as integer)) as boolean) as should_be_true, + cast(max(cast(vname = '"xa"' as integer)) as boolean) as should_be_false from base where strip_quotes @@ -30,8 +30,8 @@ dodge_unstripped_quotes as ( coef_col_name as ( select - max(vname = 'constant_term') as should_be_true, - max(vname = 'const') as should_be_false + cast(max(cast(vname = 'constant_term' as integer)) as boolean) as should_be_true, + cast(max(cast(vname = 'const' as integer)) as boolean) as should_be_false from base ) diff --git a/macros/linear_regression/ols_impl_chol/_ols_impl_chol.sql b/macros/linear_regression/ols_impl_chol/_ols_impl_chol.sql index 26c8674..20b5073 100644 --- a/macros/linear_regression/ols_impl_chol/_ols_impl_chol.sql +++ b/macros/linear_regression/ols_impl_chol/_ols_impl_chol.sql @@ -1,6 +1,7 @@ {# In some warehouses, you can reference newly created column aliases in the query you wrote. If that's not available, the previous calc will be in the dict. 
#} + {% macro _cell_or_alias(i, j, d, prefix=none) %} {{ return( adapter.dispatch('_cell_or_alias', 'dbt_linreg') @@ -60,6 +61,7 @@ {% if i == j %} {% do d.update({(i, j): dbt_linreg._safe_sqrt(x=ns.s, safe=safe)}) %} {% else %} + {# do d.update({(i, j): '('~ns.s~')/nullif('~dbt_linreg._cell_or_alias(i=j, j=j, d=d) ~ ', 0)'}) #} {% do d.update({(i, j): '('~ns.s~')/'~dbt_linreg._cell_or_alias(i=j, j=j, d=d)}) %} {% endif %} {% endif %} @@ -81,6 +83,7 @@ {% endfor %} {% set ns.numerator = ns.numerator~')' %} {% endif %} + {# do d.update({(i, j): '('~ns.numerator~'/nullif(i'~j~'j'~j~', 0))'}) #} {% do d.update({(i, j): '('~ns.numerator~'/i'~j~'j'~j~')'}) %} {% endfor %} {{ return(d) }} diff --git a/macros/linear_regression/utils.sql b/macros/linear_regression/utils.sql index e3f66eb..d5cbaab 100644 --- a/macros/linear_regression/utils.sql +++ b/macros/linear_regression/utils.sql @@ -43,7 +43,7 @@ {%- if add_constant %} select {{ dbt_linreg._unalias_gb_cols(group_by, prefix='b') | indent(2) }} - '{{ format_options.get('constant_name', 'const') }}' as {{ format_options.get('variable_column_name', 'variable_name') }}, + {{ dbt.string_literal(format_options.get('constant_name', 'const')) }} as {{ format_options.get('variable_column_name', 'variable_name') }}, {{ dbt_linreg._maybe_round('x0_coef', format_options.get('round')) }} as {{ format_options.get('coefficient_column_name', 'coefficient') }}{% if calculate_standard_error %}, {{ dbt_linreg._maybe_round('x0_stderr', format_options.get('round')) }} as {{ format_options.get('standard_error_column_name', 'standard_error') }}, {{ dbt_linreg._maybe_round('x0_coef/x0_stderr', format_options.get('round')) }} as {{ format_options.get('t_statistic_column_name', 't_statistic') }} @@ -59,7 +59,7 @@ union all {%- for i in exog_aliased %} select {{ dbt_linreg._unalias_gb_cols(group_by, prefix='b') | indent(2) }} - '{{ dbt_linreg._strip_quotes(exog[loop.index0], format_options) }}' as {{ format_options.get('variable_column_name', 
'variable_name') }}, + {{ dbt.string_literal(dbt_linreg._strip_quotes(exog[loop.index0], format_options)) }} as {{ format_options.get('variable_column_name', 'variable_name') }}, {{ dbt_linreg._maybe_round(i~'_coef', format_options.get('round')) }} as {{ format_options.get('coefficient_column_name', 'coefficient') }}{% if calculate_standard_error %}, {{ dbt_linreg._maybe_round(i~'_stderr', format_options.get('round')) }} as {{ format_options.get('standard_error_column_name', 'standard_error') }}, {{ dbt_linreg._maybe_round(i~'_coef/'~i~'_stderr', format_options.get('round')) }} as {{ format_options.get('t_statistic_column_name', 't_statistic') }} @@ -154,12 +154,27 @@ gb{{ loop.index }} as {{ gb }}, {# Round the final coefficient if the user specifies the `round` format option. Otherwise, keep as is. #} + {% macro _maybe_round(x, round_) %} -{% if round_ is not none %} - {{ return('round(' ~ x ~ ', ' ~ round_ ~ ')') }} -{% else %} - {{ return(x) }} -{% endif %} + {{ return( + adapter.dispatch('_maybe_round', 'dbt_linreg')(x, round_) + ) }} +{% endmacro %} + +{% macro default___maybe_round(x, round_) %} + {% if round_ is not none %} + {{ return('round(' ~ x ~ ', ' ~ round_ ~ ')') }} + {% else %} + {{ return(x) }} + {% endif %} +{% endmacro %} + +{% macro postgres___maybe_round(x, round_) %} + {% if round_ is not none %} + {{ return('round((' ~ x ~ ')::numeric, ' ~ round_ ~ ')') }} + {% else %} + {{ return('(' ~ x ~ ')::numeric') }} + {% endif %} {% endmacro %} {# Alias and write group by columns in a standard way. 
#} diff --git a/run b/run index a78ab51..3562fc6 100755 --- a/run +++ b/run @@ -2,27 +2,48 @@ set -eo pipefail +if [ -f .env ]; then + # shellcheck disable=SC2002,SC2046 + export $(cat .env | xargs) +fi + function setup { poetry install poetry run pre-commit install } -function testloc { - # rm -f integration_tests/dbt.duckdb - export DBT_PROFILES_DIR=./integration_tests/profiles - poetry run dbt deps --project-dir ./integration_tests - # poetry run dbt compile --project-dir ./integration_tests --select tag:perftest - poetry run dbt run --project-dir ./integration_tests --select tag:perftest -} - function test { local target="${1-"dbt-duckdb"}" - rm -f dbt.duckdb + + if [ -z "${GITHUB_ACTIONS}" ] && [ "${target}" = "dbt-postgres" ]; + then + createdb "${POSTGRES_DB-"dbt_linreg"}" || true + fi + + if [ -z "${GITHUB_ACTIONS}" ] && [ "${target}" = "dbt-duckdb" ]; + then + rm -f dbt.duckdb + fi + poetry run python scripts.py gen-test-cases --skip-if-exists - poetry run dbt deps --project-dir ./integration_tests --profiles-dir ./integration_tests/profiles --target "${target}" - poetry run dbt seed --project-dir ./integration_tests --profiles-dir ./integration_tests/profiles --target "${target}" - poetry run dbt run --project-dir ./integration_tests --profiles-dir ./integration_tests/profiles --target "${target}" - poetry run dbt test --project-dir ./integration_tests --profiles-dir ./integration_tests/profiles --target "${target}" + poetry run dbt deps \ + --project-dir ./integration_tests \ + --profiles-dir ./integration_tests/profiles \ + --target "${target}" + poetry run dbt seed \ + --project-dir ./integration_tests \ + --profiles-dir ./integration_tests/profiles \ + --target "${target}" + poetry run dbt run \ + --project-dir ./integration_tests \ + --profiles-dir ./integration_tests/profiles \ + --target "${target}" \ + --selector "${target}-selector" + poetry run dbt test \ + --project-dir ./integration_tests \ + --profiles-dir ./integration_tests/profiles \ + 
--target "${target}" \ + --selector "${target}-selector" } function lint {