Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
dwreeves committed Jan 7, 2025
1 parent d36580f commit 72e0a43
Show file tree
Hide file tree
Showing 10 changed files with 148 additions and 89 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:
sudo apt-get install
chmod +x ./run
uv venv
uv sync --group python-dev
uv sync --extra python-dev
uv pip install -U "dbt-core==$DBT_CORE_VERSION" "dbt-${DBT_TARGET}==$DBT_CORE_VERSION"
env:
UV_NO_SYNC: true
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

- Official support for Clickhouse!
- Rename `format=` and `format_options=` to `output=` and `output_options=` to make the API consistent with **dbt_pca**.
- Allow for setting method and output options globally with `vars:`

### `0.2.6`

Expand Down
63 changes: 47 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -226,24 +226,38 @@ This has been deprecated to make **dbt_linreg**'s API more consistent with **dbt

### Options for `output='long'`

- **round** (default = `None`): If not None, round all coefficients to `round` number of digits.
- **constant_name** (default = `'const'`): String name that refers to constant term.
- **variable_column_name** (default = `'variable_name'`): Column name storing strings of variable names.
- **coefficient_column_name** (default = `'coefficient'`): Column name storing model coefficients.
- **strip_quotes** (default = `True`): If true, strip outer quotes from column names if provided; if false, always use string literals.
- **round** (`int`; default = `None`): If not None, round all coefficients to `round` number of digits.
- **constant_name** (`string`; default = `'const'`): String name that refers to constant term.
- **variable_column_name** (`string`; default = `'variable_name'`): Column name storing strings of variable names.
- **coefficient_column_name** (`string`; default = `'coefficient'`): Column name storing model coefficients.
- **strip_quotes** (`bool`; default = `True`): If true, strip outer quotes from column names if provided; if false, always use string literals.

These options are available for `output='long'` only when `method='chol'`:

- **calculate_standard_error** (default = `True if not alpha else False`): If true, provide the standard error in the output.
- **standard_error_column_name** (default = `'standard_error'`): Column name storing the standard error for the parameter.
- **t_statistic_column_name** (default = `'t_statistic'`): Column name storing the t-statistic for the parameter.
- **calculate_standard_error** (`bool`; default = `True if not alpha else False`): If true, provide the standard error in the output.
- **standard_error_column_name** (`string`; default = `'standard_error'`): Column name storing the standard error for the parameter.
- **t_statistic_column_name** (`string`; default = `'t_statistic'`): Column name storing the t-statistic for the parameter.

### Options for `output='wide'`

- **round** (default = `None`): If not None, round all coefficients to `round` number of digits.
- **constant_name** (default = `'const'`): String name that refers to constant term.
- **variable_column_prefix** (default = `None`): If not None, prefix all variable columns with this. (Does NOT delimit, so make sure to include your own underscore if you'd want that.)
- **variable_column_suffix** (default = `None`): If not None, suffix all variable columns with this. (Does NOT delimit, so make sure to include your own underscore if you'd want that.)
- **round** (`int`; default = `None`): If not None, round all coefficients to `round` number of digits.
- **constant_name** (`string`; default = `'const'`): String name that refers to constant term.
- **variable_column_prefix** (`string`; default = `None`): If not None, prefix all variable columns with this. (Does NOT delimit, so make sure to include your own underscore if you'd want that.)
- **variable_column_suffix** (`string`; default = `None`): If not None, suffix all variable columns with this. (Does NOT delimit, so make sure to include your own underscore if you'd want that.)

## Setting output options globally

Output options can be set globally via `vars`, e.g. in your `dbt_project.yml`:

```yaml
# dbt_project.yml
vars:
dbt_linreg:
output_options:
round: 5
```

Output options passed via `ols()` always take precedence over globally set output options.

# Methods and method options

Expand All @@ -262,8 +276,9 @@ This method calculates regression coefficients using the Moore-Penrose pseudo-in

Specify these in a dict using the `method_options=` kwarg:

- **safe** (default = `True`): If True, returns null coefficients instead of an error when X is perfectly multicollinear. If False, a negative value will be passed into a SQRT(), and most SQL engines will raise an error when this happens.
- **subquery_optimization** (default: `True`): If True, nested subqueries are used during some of the steps to optimize the query speed. If false, the query is flattened.
- **safe** (`bool`; default: `True`): If True, returns null coefficients instead of an error when X is perfectly multicollinear. If False, a negative value will be passed into a SQRT(), and most SQL engines will raise an error when this happens.
- **subquery_optimization** (`bool`; default = `True`): If True, nested subqueries are used during some of the steps to optimize the query speed. If false, the query is flattened.
- **intra_select_aliasing** (`bool`; default = `[depends on db]`): If True, within a single select statement, column aliases are used to refer to other columns created during that select. This can significantly reduce the text of a SQL query, but not all SQL engines support this. By default, for supported databases,

## `fwl` method

Expand Down Expand Up @@ -299,11 +314,27 @@ So when should you use `fwl`? The main use case is in OLTP systems (e.g. Postgre

- Regression coefficients in Postgres are always `numeric` types.

### Possible future features
## Setting method options globally

Method options can be set globally via `vars`, e.g. in your `dbt_project.yml`. Each `method` gets its own config, e.g. `dbt_linreg: chol: ...`. Here is an example:

```yaml
# dbt_project.yml
vars:
dbt_linreg:
method_options:
chol:
intra_select_aliasing: true
```

Method options passed via `ols()` always take precedence over globally set method options.

# Possible future features

Some things that could happen in the future:

- Weighted least squares (WLS)
- Efficient multivariate regression (i.e. multiple endogenous vectors sharing a single design matrix)
- P-values
- Heteroskedasticity robust standard errors
- Recursive CTE implementations + long formatted inputs
Expand Down Expand Up @@ -332,7 +363,7 @@ There is no closed-form solution to L1 regularization, which makes it very very

### Is the `group_by=[...]` argument like categorical variables / one-hot encodings?

No. You should think of the group by more as a [seemingly unrelated regressions](https://en.wikipedia.org/wiki/Seemingly_unrelated_regressions) implementation than as a categorical variable implementation. It's running multiple regressions and each individual partition is its own `y` vector and `X` matrix. This is _not_ a replacement for dummy variables.
No. The `group_by` runs a linear regressions within each group, and each individual partition is its own `y` vector and `X` matrix. This is _not_ a replacement for dummy variables.

### Why aren't categorical variables / one-hot encodings supported?

Expand Down
2 changes: 1 addition & 1 deletion integration_tests/tests/test_long_format_options.sql
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ with
base as (

select strip_quotes, vname, co
from {{ ref("long_output_options") }}
from {{ ref("long_format_options") }}

),

Expand Down
2 changes: 1 addition & 1 deletion integration_tests/tests/test_wide_format_options.sql
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ with base as (
"fooxa_bar",
fooxb_bar
from
{{ ref("wide_output_options") }}
{{ ref("wide_format_options") }}

)

Expand Down
46 changes: 27 additions & 19 deletions macros/linear_regression/ols_impl_chol/_ols_impl_chol.sql
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,33 @@
in the query you wrote.
If that's not available, the previous calc will be in the dict. #}
{% macro _cell_or_alias(i, j, d, prefix=none) %}
{% macro _cell_or_alias(i, j, d, prefix=none, isa=none) %}
{% if isa is not none %}
{% if isa %}
{{ return((prefix if prefix is not none else '') ~ 'i' ~ i ~ 'j' ~ j) }}
{% else %}
{{ return(d[(i, j)]) }}
{% endif %}
{% endif %}
{{ return(
adapter.dispatch('_cell_or_alias', 'dbt_linreg')
(i, j, d, prefix)
(i, j, d, prefix, isa)
) }}
{% endmacro %}
{% macro default___cell_or_alias(i, j, d, prefix=none) %}
{% macro default___cell_or_alias(i, j, d, prefix=none, isa=none) %}
{{ return(d[(i, j)]) }}
{% endmacro %}
{% macro snowflake___cell_or_alias(i, j, d, prefix=none) %}
{% macro snowflake___cell_or_alias(i, j, d, prefix=none, isa=none) %}
{{ return((prefix if prefix is not none else '') ~ 'i' ~ i ~ 'j' ~ j) }}
{% endmacro %}
{% macro duckdb___cell_or_alias(i, j, d, prefix=none) %}
{% macro duckdb___cell_or_alias(i, j, d, prefix=none, isa=none) %}
{{ return((prefix if prefix is not none else '') ~ 'i' ~ i ~ 'j' ~ j) }}
{% endmacro %}
{% macro clickhouse___cell_or_alias(i, j, d, prefix=none) %}
{% macro clickhouse___cell_or_alias(i, j, d, prefix=none, isa=none) %}
{{ return((prefix if prefix is not none else '') ~ 'i' ~ i ~ 'j' ~ j) }}
{% endmacro %}
Expand All @@ -46,7 +53,7 @@
{{ return('sqrt('~x~')') }}
{% endmacro %}
{% macro _cholesky_decomposition(li, subquery_optimization=True, safe=True) %}
{% macro _cholesky_decomposition(li, subquery_optimization=true, safe=true, isa=none) %}
{% set d = {} %}
{% for i in li %}
{% for j in range(li[0], i + 1) %}
Expand All @@ -57,18 +64,18 @@
{% set ns.s = 'x'~j~'x'~i %}
{% for k in range(li[0], j) %}
{% if subquery_optimization and i != j %}
{% set ns.s = ns.s~'-'~dbt_linreg._cell_or_alias(i=i, j=k, d=d)~'*i'~j~'j'~k %}
{% set ns.s = ns.s~'-'~dbt_linreg._cell_or_alias(i=i, j=k, d=d, isa=isa)~'*i'~j~'j'~k %}
{% else %}
{% set ns.s = ns.s~'-'~dbt_linreg._cell_or_alias(i=i, j=k, d=d)~'*'~dbt_linreg._cell_or_alias(i=j, j=k, d=d) %}
{% set ns.s = ns.s~'-'~dbt_linreg._cell_or_alias(i=i, j=k, d=d, isa=isa)~'*'~dbt_linreg._cell_or_alias(i=j, j=k, d=d, isa=isa) %}
{% endif %}
{% endfor %}
{% if i == j %}
{% do d.update({(i, j): dbt_linreg._safe_sqrt(x=ns.s, safe=safe)}) %}
{% else %}
{% if adapter.type() == "postgres" %}
{% do d.update({(i, j): '('~ns.s~')/nullif('~dbt_linreg._cell_or_alias(i=j, j=j, d=d) ~ ', 0)'}) %}
{% if safe %}
{% do d.update({(i, j): '('~ns.s~')/nullif('~dbt_linreg._cell_or_alias(i=j, j=j, d=d, isa=isa) ~ ', 0)'}) %}
{% else %}
{% do d.update({(i, j): '('~ns.s~')/'~dbt_linreg._cell_or_alias(i=j, j=j, d=d)}) %}
{% do d.update({(i, j): '('~ns.s~')/'~dbt_linreg._cell_or_alias(i=j, j=j, d=d, isa=isa)}) %}
{% endif %}
{% endif %}
{% endif %}
Expand All @@ -77,7 +84,7 @@
{{ return(d) }}
{% endmacro %}
{% macro _forward_substitution(li, safe=true) %}
{% macro _forward_substitution(li, safe=true, isa=none) %}
{% set d = {} %}
{% for i, j in modules.itertools.combinations_with_replacement(li, 2) %}
{% set ns = namespace() %}
Expand All @@ -86,7 +93,7 @@
{% else %}
{% set ns.numerator = '(' %}
{% for k in range(i, j) %}
{% set ns.numerator = ns.numerator~'-i'~j~'j'~k~'*'~dbt_linreg._cell_or_alias(i=i, j=k, d=d, prefix="inv_") %}
{% set ns.numerator = ns.numerator~'-i'~j~'j'~k~'*'~dbt_linreg._cell_or_alias(i=i, j=k, d=d, prefix="inv_", isa=isa) %}
{% endfor %}
{% set ns.numerator = ns.numerator~')' %}
{% endif %}
Expand Down Expand Up @@ -121,9 +128,10 @@
alpha=alpha
)) }}
{%- endif %}
{%- set subquery_optimization = method_options.get('subquery_optimization', True) %}
{%- set safe_mode = method_options.get('safe', True) %}
{%- set calculate_standard_error = output_options.get('calculate_standard_error', (not alpha)) and output == 'long' %}
{%- set subquery_optimization = dbt_linreg._get_method_option('chol', 'subquery_optimization', method_options, true) %}
{%- set safe_mode = dbt_linreg._get_method_option('chol', 'safe', method_options, true) %}
{% set isa = dbt_linreg._get_method_option('chol', 'intra_select_aliasing', method_options) %}
{%- set calculate_standard_error = dbt_linreg._get_output_option('calculate_standard_error', output_options, (not alpha) and output == 'long') %}
{%- if alpha and calculate_standard_error %}
{% do log(
'Warning: Standard errors are NOT designed to take into account ridge regression regularization.'
Expand Down Expand Up @@ -175,7 +183,7 @@ _dbt_linreg_xtx as (
),
_dbt_linreg_chol as (
{%- set d = dbt_linreg._cholesky_decomposition(li=xcols, subquery_optimization=subquery_optimization, safe=safe_mode) %}
{%- set d = dbt_linreg._cholesky_decomposition(li=xcols, subquery_optimization=subquery_optimization, safe=safe_mode, isa=isa) %}
{%- if subquery_optimization %}
{%- for i in (xcols | reverse) %}
select
Expand Down Expand Up @@ -206,7 +214,7 @@ _dbt_linreg_chol as (
),
_dbt_linreg_inverse_chol as (
{#- The optimal way to calculate is to do each diagonal at a time. #}
{%- set d = dbt_linreg._forward_substitution(li=xcols, safe=safe_mode) %}
{%- set d = dbt_linreg._forward_substitution(li=xcols, safe=safe_mode, isa=isa) %}
{%- if subquery_optimization %}
{%- for gap in (range(0, upto) | reverse) %}
select *,
Expand Down
52 changes: 30 additions & 22 deletions macros/linear_regression/utils/utils.sql
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,21 @@
{# Every OLS method ends with a "_dbt_linreg_final_coefs" CTE with a common
interface. This interface can then be transformed in a standard way using the
final_select() macro, which formats the output for the user. #}
{% macro final_select(exog=None,
exog_aliased=None,
group_by=None,
add_constant=True,
output=None,
output_options=None,
calculate_standard_error=False) -%}
{% macro final_select(exog=none,
exog_aliased=none,
group_by=none,
add_constant=true,
output=none,
output_options=none,
calculate_standard_error=false) -%}
{%- if output == 'long' %}
{%- if add_constant %}
select
{{ dbt_linreg._unalias_gb_cols(group_by, prefix='b') | indent(2) }}
{{ dbt.string_literal(output_options.get('constant_name', 'const')) }} as {{ output_options.get('variable_column_name', 'variable_name') }},
{{ dbt_linreg._maybe_round('x0_coef', output_options.get('round')) }} as {{ output_options.get('coefficient_column_name', 'coefficient') }}{% if calculate_standard_error %},
{{ dbt_linreg._maybe_round('x0_stderr', output_options.get('round')) }} as {{ output_options.get('standard_error_column_name', 'standard_error') }},
{{ dbt_linreg._maybe_round('x0_coef/x0_stderr', output_options.get('round')) }} as {{ output_options.get('t_statistic_column_name', 't_statistic') }}
{{ dbt.string_literal(dbt_linreg._get_output_option('constant_name', output_options, 'const')) }} as {{ dbt_linreg._get_output_option('variable_column_name', output_options, 'variable_name') }},
{{ dbt_linreg._maybe_round('x0_coef', dbt_linreg._get_output_option('round', output_options)) }} as {{ dbt_linreg._get_output_option('coefficient_column_name', output_options, 'coefficient') }}{% if calculate_standard_error %},
{{ dbt_linreg._maybe_round('x0_stderr', dbt_linreg._get_output_option('round', output_options)) }} as {{ dbt_linreg._get_output_option('standard_error_column_name', output_options, 'standard_error') }},
{{ dbt_linreg._maybe_round('x0_coef/x0_stderr', dbt_linreg._get_output_option('round', output_options)) }} as {{ dbt_linreg._get_output_option('t_statistic_column_name', output_options, 't_statistic') }}
{%- endif %}
from _dbt_linreg_final_coefs as b
{%- if calculate_standard_error %}
Expand All @@ -59,10 +59,10 @@ union all
{%- for i in exog_aliased %}
select
{{ dbt_linreg._unalias_gb_cols(group_by, prefix='b') | indent(2) }}
{{ dbt.string_literal(dbt_linreg._strip_quotes(exog[loop.index0], output_options)) }} as {{ output_options.get('variable_column_name', 'variable_name') }},
{{ dbt_linreg._maybe_round(i~'_coef', output_options.get('round')) }} as {{ output_options.get('coefficient_column_name', 'coefficient') }}{% if calculate_standard_error %},
{{ dbt_linreg._maybe_round(i~'_stderr', output_options.get('round')) }} as {{ output_options.get('standard_error_column_name', 'standard_error') }},
{{ dbt_linreg._maybe_round(i~'_coef/'~i~'_stderr', output_options.get('round')) }} as {{ output_options.get('t_statistic_column_name', 't_statistic') }}
{{ dbt.string_literal(dbt_linreg._strip_quotes(exog[loop.index0], output_options)) }} as {{ dbt_linreg._get_output_option('variable_column_name', output_options, 'variable_name') }},
{{ dbt_linreg._maybe_round(i~'_coef', dbt_linreg._get_output_option('round', output_options)) }} as {{ dbt_linreg._get_output_option('coefficient_column_name', output_options, 'coefficient') }}{% if calculate_standard_error %},
{{ dbt_linreg._maybe_round(i~'_stderr', dbt_linreg._get_output_option('round', output_options)) }} as {{ dbt_linreg._get_output_option('standard_error_column_name', output_options, 'standard_error') }},
{{ dbt_linreg._maybe_round(i~'_coef/'~i~'_stderr', dbt_linreg._get_output_option('round', output_options)) }} as {{ dbt_linreg._get_output_option('t_statistic_column_name', output_options, 't_statistic') }}
{%- endif %}
from _dbt_linreg_final_coefs as b
{%- if calculate_standard_error %}
Expand All @@ -76,13 +76,13 @@ union all
select
{%- if add_constant -%}
{{ dbt_linreg._unalias_gb_cols(group_by) | indent(2) }}
{{ dbt_linreg._maybe_round('x0_coef', output_options.get('round')) }} as {{ dbt_linreg._format_wide_variable_column(output_options.get('constant_name', 'const'), output_options) }}
{{ dbt_linreg._maybe_round('x0_coef', dbt_linreg._get_output_option('round', output_options)) }} as {{ dbt_linreg._format_wide_variable_column(dbt_linreg._get_output_option('constant_name', output_options, 'const'), output_options) }}
{%- if exog_aliased -%}
,
{%- endif -%}
{%- endif -%}
{%- for i in exog_aliased %}
{{ dbt_linreg._maybe_round(i~'_coef', output_options.get('round')) }} as {{ dbt_linreg._format_wide_variable_column(exog[loop.index0], output_options) }}
{{ dbt_linreg._maybe_round(i~'_coef', dbt_linreg._get_output_option('round', output_options)) }} as {{ dbt_linreg._format_wide_variable_column(exog[loop.index0], output_options) }}
{%- if not loop.last -%}
,
{%- endif %}
Expand All @@ -102,7 +102,7 @@ select * from _dbt_linreg_final_coefs
In this situation, we want to strip the double quotes when presenting
outputs in a long format. #}
{% macro _strip_quotes(x, output_options) -%}
{% if output_options.get('strip_quotes') | default(True) %}
{% if dbt_linreg._get_output_option('strip_quotes', output_options) | default(True) %}
{% if x[0] == '"' and x[-1] == '"' and (x | length) > 1 %}
{{ return(x[1:-1]) }}
{% endif %}
Expand All @@ -117,11 +117,11 @@ select * from _dbt_linreg_final_coefs
{% else %}
{% set _add_quotes = False %}
{% endif %}
{% if output_options.get('variable_column_prefix') %}
{% set x = output_options.get('variable_column_prefix') ~ x %}
{% if dbt_linreg._get_output_option('variable_column_prefix', output_options) %}
{% set x = dbt_linreg._get_output_option('variable_column_prefix', output_options) ~ x %}
{% endif %}
{% if output_options.get('variable_column_suffix') %}
{% set x = x ~ output_options.get('variable_column_suffix') %}
{% if dbt_linreg._get_output_option('variable_column_suffix', output_options) %}
{% set x = x ~ dbt_linreg._get_output_option('variable_column_suffix', output_options) %}
{% endif %}
{% if _add_quotes %}
{% set x = '"' ~ x ~ '"' %}
Expand Down Expand Up @@ -227,3 +227,11 @@ on
{%- endfor %}
{%- endif %}
{%- endmacro %}

{% macro _get_output_option(field, output_options, default=none) %}
{{ return(output_options.get(field, var("dbt_linreg", {}).get("output_options", {}).get(field, default))) }}
{% endmacro %}

{% macro _get_method_option(method, field, method_options, default=none) %}
{{ return(method_options.get(field, var("dbt_linreg", {}).get("method_options", {}).get("method", {}).get(field, default))) }}
{% endmacro %}
Loading

0 comments on commit 72e0a43

Please sign in to comment.