Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
dwreeves committed Oct 30, 2023
1 parent 2cda884 commit a3bab12
Show file tree
Hide file tree
Showing 22 changed files with 124 additions and 50 deletions.
7 changes: 6 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@ jobs:
postgres:
image: postgres
env:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: postgres
POSTGRES_DB: dbt_linreg
ports:
- 5432:5432
options: >-
--health-cmd pg_isready
--health-interval 10s
Expand Down Expand Up @@ -55,5 +58,7 @@ jobs:
run: ./run test "${DBT_TARGET}"
env:
DBT_TARGET: ${{ matrix.db_target }}
POSTGRES_HOST: localhost
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: dbt_linreg
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
dbt.duckdb
dbt.duckdb.wal
.user.yml
docs/site/
integration_tests/seeds/*.csv
Expand Down
15 changes: 12 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,13 @@ group by
**dbt_linreg** should work with most SQL databases, but so far, testing has been done for the following database tools:

- Snowflake
- Postgres
- DuckDB
- Postgres\*

If `dbt_linreg` does not work in your database tool, please let me know in a bug report and I can make sure it is supported.

_* Syntactically supported, but not performant._

# API

The only function available in the public API is the `dbt_linreg.ols()` macro.
Expand Down Expand Up @@ -256,7 +258,7 @@ This method calculates regression coefficients using the Moore-Penrose pseudo-in
Specify these in a dict using the `method_options=` kwarg:

- **safe** (default = `True`): If True, returns null coefficients instead of an error when X is perfectly multicollinear. If False, a negative value will be passed into a SQRT(), and most SQL engines will raise an error when this happens.
- **subquery_optimization** (default = `True`): If True, nested subqueries are used during some of the steps to optimize the query speed. If false, the query is flattened. Note that turning this off can significantly degrade performance.
- **subquery_optimization** (default: `True`): If True, nested subqueries are used during some of the steps to optimize the query speed. If false, the query is flattened.

## `fwl` method

Expand All @@ -270,10 +272,12 @@ Ridge regression is implemented using the augmentation technique described in Ex

There are a few reasons why this method is discouraged over the `chol` method:

- 🐌 It tends to be much slower, and struggles to efficiently calculate large number of columns.
- 🐌 It tends to be much slower in OLAP systems, and struggles to efficiently calculate large number of columns.
- 📊 It does not calculate standard errors.
- 😕 For ridge regression, coefficients are not accurate; they tend to be off by a magnitude of ~0.01%.

So when should you use `fwl`? The main use case is in OLTP systems (e.g. Postgres) for unregularized coefficient estimation. Long story short, the `chol` method relies on subquery optimization to be more performant than `fwl`; however, OLTP systems do not benefit at all from subquery optimization. This means that `fwl` is slightly more performant in this context.

# Notes

- ⚠️ **If your coefficients are null, it does not mean dbt_linreg is broken, it most likely means your feature columns are perfectly multicollinear.** If you are 100% sure that is not the issue, please file a bug report with a minimally reproducible example.
Expand All @@ -283,6 +287,11 @@ There are a few reasons why this method is discouraged over the `chol` method:
- An array input (e.g. `alpha=[0.01, 0.02, 0.03, 0.04, 0.05]`) will apply an alpha of `0.01` to the first column, `0.02` to the second column, etc.
- `alpha` is equivalent to what TEoSL refers to as "lambda," times the sample size N. That is to say: `α ≡ λ * N`.

- Regularization as currently implemented for the `chol` method tends to be very slow in OLTP systems (e.g. Postgres), but is very performant in OLAP systems (e.g. Snowflake, DuckDB, BigQuery, Redshift). As dbt is more commonly used in OLAP contexts, the code base is optimized for the OLAP use case.
- That said, it may be possible to make regularization in OLTP more performant (e.g. with augmentation of the design matrix), so PRs are welcome.

- Regression coefficients in Postgres are always `numeric` types.

### Possible future features

Some things I am thinking about working on down the line:
Expand Down
2 changes: 1 addition & 1 deletion integration_tests/dbt_project.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: "dbt_linreg_tests"
version: "0.2.1"
version: "0.2.3"

require-dbt-version: [">=1.0.0", "<2.0.0"]

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{{
config(
materialized="table"
materialized="table",
tags=["skip-postgres"]
)
}}
select * from {{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{{
config(
materialized="table"
materialized="table",
tags=["skip-postgres"]
)
}}
select * from {{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{{
config(
materialized="table"
materialized="table",
tags=["skip-postgres"]
)
}}
select * from {{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{{
config(
materialized="table"
materialized="table",
tags=["skip-postgres"]
)
}}
select * from {{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{{
config(
materialized="table"
materialized="table",
tags=["skip-postgres"]
)
}}
select * from {{
Expand All @@ -10,7 +11,8 @@ select * from {{
exog=['x1', 'x2', 'x3'],
group_by=['gb_var'],
format='long',
method='chol'
method='chol',
method_options={'subquery_optimization': True}
)
}} as linreg
order by gb_var, variable_name
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{{
config(
materialized="table"
materialized="table",
tags=["skip-postgres"]
)
}}
select * from {{
Expand Down
3 changes: 2 additions & 1 deletion integration_tests/models/simple_10var_regression_long.sql
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
{{
config(
materialized="view",
enabled=False
enabled=False,
tags=["skip-postgres"]
)
}}
select * from {{
Expand Down
3 changes: 2 additions & 1 deletion integration_tests/models/simple_4var_regression_wide.sql
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{{
config(
materialized="table"
materialized="table",
tags=["skip-postgres"]
)
}}
select * from {{
Expand Down
3 changes: 2 additions & 1 deletion integration_tests/models/simple_5var_regression_long.sql
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{{
config(
materialized="table"
materialized="table",
tags=["skip-postgres"]
)
}}
select * from {{
Expand Down
3 changes: 2 additions & 1 deletion integration_tests/models/simple_5var_regression_wide.sql
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{{
config(
materialized="table"
materialized="table",
tags=["skip-postgres"]
)
}}
select * from {{
Expand Down
4 changes: 2 additions & 2 deletions integration_tests/models/simple_8var_regression_wide.sql
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
{{
config(
materialized="view",
tags=["perftest"],
enabled=False
tags=["perftest", "skip-postgres"],
enabled=False,
)
}}
select * from {{
Expand Down
8 changes: 3 additions & 5 deletions integration_tests/profiles/profiles.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@ dbt_linreg_profile:
type: duckdb
path: dbt.duckdb
dbt-postgres:
# This is configured for Github Actions.
# For local configuration, set env vars.
type: postgres
user: '{{ env_var("POSTGRES_USER", "postgres") }}'
password: '{{ env_var("POSTGRES_PASSWORD", "postgres") }}'
user: '{{ env_var("POSTGRES_USER") }}'
password: '{{ env_var("POSTGRES_PASSWORD") }}'
host: '{{ env_var("POSTGRES_HOST", "localhost") }}'
port: '{{ env_var("POSTGRES_PORT", "5432") | as_number }}'
dbname: '{{ env_var("POSTGRES_DB", "postgres") }}'
dbname: '{{ env_var("POSTGRES_DB", "dbt_linreg") }}'
schema: '{{ env_var("POSTGRES_SCHEMA", "public") }}'
11 changes: 11 additions & 0 deletions integration_tests/selectors.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
selectors:
- name: dbt-duckdb-selector
definition: 'fqn:*'
- name: dbt-postgres-selector
# Postgres runs into memory / performance issues for some of these queries.
# Resolving this and making Postgres more performant is a TODO.
definition:
union:
- 'fqn:*'
- exclude:
- '@tag:skip-postgres'
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ expected as (
)

select base.variable_name
from {{ ref('groups_matrix_regression_chol') }} as base
from {{ ref('groups_matrix_regression_chol_optimized') }} as base
full outer join expected
on
base.gb_var = expected.gb_var
Expand Down
12 changes: 6 additions & 6 deletions integration_tests/tests/test_long_format_options.sql
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ base as (
find_unstripped_quotes as (

select
max(vname = '"xa"') as should_be_true,
max(vname = 'xa') as should_be_false
cast(max(cast(vname = '"xa"' as integer)) as boolean) as should_be_true,
cast(max(cast(vname = 'xa' as integer)) as boolean) as should_be_false
from base
where not strip_quotes

Expand All @@ -20,8 +20,8 @@ find_unstripped_quotes as (
dodge_unstripped_quotes as (

select
max(vname = 'xa') as should_be_true,
max(vname = '"xa"') as should_be_false
cast(max(cast(vname = 'xa' as integer)) as boolean) as should_be_true,
cast(max(cast(vname = '"xa"' as integer)) as boolean) as should_be_false
from base
where strip_quotes

Expand All @@ -30,8 +30,8 @@ dodge_unstripped_quotes as (
coef_col_name as (

select
max(vname = 'constant_term') as should_be_true,
max(vname = 'const') as should_be_false
cast(max(cast(vname = 'constant_term' as integer)) as boolean) as should_be_true,
cast(max(cast(vname = 'const' as integer)) as boolean) as should_be_false
from base

)
Expand Down
3 changes: 3 additions & 0 deletions macros/linear_regression/ols_impl_chol/_ols_impl_chol.sql
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{# In some warehouses, you can reference newly created column aliases
in the query you wrote.
If that's not available, the previous calc will be in the dict. #}
{% macro _cell_or_alias(i, j, d, prefix=none) %}
{{ return(
adapter.dispatch('_cell_or_alias', 'dbt_linreg')
Expand Down Expand Up @@ -60,6 +61,7 @@
{% if i == j %}
{% do d.update({(i, j): dbt_linreg._safe_sqrt(x=ns.s, safe=safe)}) %}
{% else %}
{# do d.update({(i, j): '('~ns.s~')/nullif('~dbt_linreg._cell_or_alias(i=j, j=j, d=d) ~ ', 0)'}) #}
{% do d.update({(i, j): '('~ns.s~')/'~dbt_linreg._cell_or_alias(i=j, j=j, d=d)}) %}
{% endif %}
{% endif %}
Expand All @@ -81,6 +83,7 @@
{% endfor %}
{% set ns.numerator = ns.numerator~')' %}
{% endif %}
{# do d.update({(i, j): '('~ns.numerator~'/nullif(i'~j~'j'~j~', 0))'}) #}
{% do d.update({(i, j): '('~ns.numerator~'/i'~j~'j'~j~')'}) %}
{% endfor %}
{{ return(d) }}
Expand Down
29 changes: 22 additions & 7 deletions macros/linear_regression/utils.sql
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
{%- if add_constant %}
select
{{ dbt_linreg._unalias_gb_cols(group_by, prefix='b') | indent(2) }}
'{{ format_options.get('constant_name', 'const') }}' as {{ format_options.get('variable_column_name', 'variable_name') }},
{{ dbt.string_literal(format_options.get('constant_name', 'const')) }} as {{ format_options.get('variable_column_name', 'variable_name') }},
{{ dbt_linreg._maybe_round('x0_coef', format_options.get('round')) }} as {{ format_options.get('coefficient_column_name', 'coefficient') }}{% if calculate_standard_error %},
{{ dbt_linreg._maybe_round('x0_stderr', format_options.get('round')) }} as {{ format_options.get('standard_error_column_name', 'standard_error') }},
{{ dbt_linreg._maybe_round('x0_coef/x0_stderr', format_options.get('round')) }} as {{ format_options.get('t_statistic_column_name', 't_statistic') }}
Expand All @@ -59,7 +59,7 @@ union all
{%- for i in exog_aliased %}
select
{{ dbt_linreg._unalias_gb_cols(group_by, prefix='b') | indent(2) }}
'{{ dbt_linreg._strip_quotes(exog[loop.index0], format_options) }}' as {{ format_options.get('variable_column_name', 'variable_name') }},
{{ dbt.string_literal(dbt_linreg._strip_quotes(exog[loop.index0], format_options)) }} as {{ format_options.get('variable_column_name', 'variable_name') }},
{{ dbt_linreg._maybe_round(i~'_coef', format_options.get('round')) }} as {{ format_options.get('coefficient_column_name', 'coefficient') }}{% if calculate_standard_error %},
{{ dbt_linreg._maybe_round(i~'_stderr', format_options.get('round')) }} as {{ format_options.get('standard_error_column_name', 'standard_error') }},
{{ dbt_linreg._maybe_round(i~'_coef/'~i~'_stderr', format_options.get('round')) }} as {{ format_options.get('t_statistic_column_name', 't_statistic') }}
Expand Down Expand Up @@ -154,12 +154,27 @@ gb{{ loop.index }} as {{ gb }},

{# Round the final coefficient if the user specifies the `round` format
option. Otherwise, keep as is. #}

{% macro _maybe_round(x, round_) %}
{% if round_ is not none %}
{{ return('round(' ~ x ~ ', ' ~ round_ ~ ')') }}
{% else %}
{{ return(x) }}
{% endif %}
{{ return(
adapter.dispatch('_maybe_round', 'dbt_linreg')(x, round_)
) }}
{% endmacro %}

{% macro default___maybe_round(x, round_) %}
{% if round_ is not none %}
{{ return('round(' ~ x ~ ', ' ~ round_ ~ ')') }}
{% else %}
{{ return(x) }}
{% endif %}
{% endmacro %}

{% macro postgres___maybe_round(x, round_) %}
{% if round_ is not none %}
{{ return('round((' ~ x ~ ')::numeric, ' ~ round_ ~ ')') }}
{% else %}
{{ return('(' ~ x ~ ')::numeric') }}
{% endif %}
{% endmacro %}

{# Alias and write group by columns in a standard way. #}
Expand Down
Loading

0 comments on commit a3bab12

Please sign in to comment.