Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update #24

Merged
merged 3 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:
sudo apt-get install
chmod +x ./run
uv venv
uv sync --group python-dev
uv sync --extra python-dev
uv pip install -U "dbt-core==$DBT_CORE_VERSION" "dbt-${DBT_TARGET}==$DBT_CORE_VERSION"
env:
UV_NO_SYNC: true
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

- Official support for Clickhouse!
- Rename `format=` and `format_options=` to `output=` and `output_options=` to make the API consistent with **dbt_pca**.
- Allow for setting method and output options globally with `vars:`

### `0.2.6`

Expand Down
81 changes: 59 additions & 22 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ Reasons to use **dbt_linreg**:
- 📱 **Simple interface:** Just define a `table=` (which works with `ref()`, `source()`, and CTEs), a y-variable with `endog=`, your x-variables in a list with `exog=...`, and you're all set! Note that the API is loosely based on Statsmodels's naming conventions.
- 🤖 **Support for ridge regression:** Just pass in `alpha=scalar` or `alpha=[scalar1, scalar2, ...]` to regularize your regressions. (Note: regressors are not automatically standardized.)
- 🤸‍ **Flexibility:** Tons of formatting options available to return coefficients the way you want.
- 💪 **Durable and tested:** The API provides feedback on parsing errors, and everything in this code base has been tested (check the continuous integration).
- 🤗 **User friendly:** The API provides comprehensive feedback on input errors.
- 💪 **Durable and tested:** Everything in this code base is tested against equivalent regressions performed in Statsmodels with high precision assertions (between 10e-6 to 10e-7, depending on the database engine).

# Installation

Expand Down Expand Up @@ -169,14 +170,19 @@ group by

**dbt_linreg** should work with most SQL databases, but so far, testing has been done for the following database tools:

- Snowflake
- DuckDB
- Clickhouse
- Postgres\*
| Database | Supported | Precision asserted in CI\* | Supported since version |
|----------------|-----------|----------------------------|-------------------------|
| **Snowflake** | ✅ | _n/a_ | 0.1.0 |
| **DuckDB** | ✅ | 10e-7 | 0.1.0 |
| **Postgres**† | ✅ | 10e-7 | 0.2.3 |
| **Redshift** | ✅ | _n/a_ | 0.2.4 |
| **Clickhouse** | ✅ | 10e-6 | 0.3.0 |

If **dbt_linreg** does not work in your database tool, please let me know in a bug report.

> _* Minimal support. Postgres is syntactically supported, but is not performant under certain circumstances._
> _\* Precision is for test cases using the **collinear_matrix** for unregularized regressions, in comparison to the output of the same regression in the Python package Statsmodels using `sm.OLS().fit(method="pinv")`. For example, coefficients for unregularized regressions performed in DuckDB are asserted to be within 10e-7 of Statsmodels._

> _† Minimal support for Postgres. Postgres is syntactically supported, but is not performant under certain circumstances._

# API

Expand Down Expand Up @@ -226,24 +232,38 @@ This has been deprecated to make **dbt_linreg**'s API more consistent with **dbt

### Options for `output='long'`

- **round** (default = `None`): If not None, round all coefficients to `round` number of digits.
- **constant_name** (default = `'const'`): String name that refers to constant term.
- **variable_column_name** (default = `'variable_name'`): Column name storing strings of variable names.
- **coefficient_column_name** (default = `'coefficient'`): Column name storing model coefficients.
- **strip_quotes** (default = `True`): If true, strip outer quotes from column names if provided; if false, always use string literals.
- **round** (`int`; default = `None`): If not None, round all coefficients to `round` number of digits.
- **constant_name** (`string`; default = `'const'`): String name that refers to constant term.
- **variable_column_name** (`string`; default = `'variable_name'`): Column name storing strings of variable names.
- **coefficient_column_name** (`string`; default = `'coefficient'`): Column name storing model coefficients.
- **strip_quotes** (`bool`; default = `True`): If true, strip outer quotes from column names if provided; if false, always use string literals.

These options are available for `output='long'` only when `method='chol'`:

- **calculate_standard_error** (default = `True if not alpha else False`): If true, provide the standard error in the output.
- **standard_error_column_name** (default = `'standard_error'`): Column name storing the standard error for the parameter.
- **t_statistic_column_name** (default = `'t_statistic'`): Column name storing the t-statistic for the parameter.
- **calculate_standard_error** (`bool`; default = `True if not alpha else False`): If true, provide the standard error in the output.
- **standard_error_column_name** (`string`; default = `'standard_error'`): Column name storing the standard error for the parameter.
- **t_statistic_column_name** (`string`; default = `'t_statistic'`): Column name storing the t-statistic for the parameter.

### Options for `output='wide'`

- **round** (default = `None`): If not None, round all coefficients to `round` number of digits.
- **constant_name** (default = `'const'`): String name that refers to constant term.
- **variable_column_prefix** (default = `None`): If not None, prefix all variable columns with this. (Does NOT delimit, so make sure to include your own underscore if you'd want that.)
- **variable_column_suffix** (default = `None`): If not None, suffix all variable columns with this. (Does NOT delimit, so make sure to include your own underscore if you'd want that.)
- **round** (`int`; default = `None`): If not None, round all coefficients to `round` number of digits.
- **constant_name** (`string`; default = `'const'`): String name that refers to constant term.
- **variable_column_prefix** (`string`; default = `None`): If not None, prefix all variable columns with this. (Does NOT delimit, so make sure to include your own underscore if you'd want that.)
- **variable_column_suffix** (`string`; default = `None`): If not None, suffix all variable columns with this. (Does NOT delimit, so make sure to include your own underscore if you'd want that.)

## Setting output options globally

Output options can be set globally via `vars`, e.g. in your `dbt_project.yml`:

```yaml
# dbt_project.yml
vars:
dbt_linreg:
output_options:
round: 5
```

Output options passed via `ols()` always take precedence over globally set output options.

# Methods and method options

Expand All @@ -262,8 +282,9 @@ This method calculates regression coefficients using the Moore-Penrose pseudo-in

Specify these in a dict using the `method_options=` kwarg:

- **safe** (default = `True`): If True, returns null coefficients instead of an error when X is perfectly multicollinear. If False, a negative value will be passed into a SQRT(), and most SQL engines will raise an error when this happens.
- **subquery_optimization** (default: `True`): If True, nested subqueries are used during some of the steps to optimize the query speed. If false, the query is flattened.
- **safe** (`bool`; default: `True`): If True, returns null coefficients instead of an error when X is perfectly multicollinear. If False, a negative value may be passed into a SQRT() or a divide by zero may occur, and most SQL engines will raise an error when this happens.
- **subquery_optimization** (`bool`; default = `True`): If True, nested subqueries are used during some of the steps to optimize the query speed. If false, the query is flattened.
- **intra_select_aliasing** (`bool`; default = `[depends on db]`): If True, within a single select statement, column aliases are used to refer to other columns created during that select. This can significantly reduce the text of a SQL query, but not all SQL engines support this. By default, for all databases officially supported by **dbt_linreg**, the best option is already selected. For unsupported databases, the default is `False` for broad compatibility, so if you are running **dbt_linreg** in an officially unsupported database engine which supports this feature, you may want to modify this option globally in your `vars` to be `true`.

## `fwl` method

Expand Down Expand Up @@ -299,11 +320,27 @@ So when should you use `fwl`? The main use case is in OLTP systems (e.g. Postgre

- Regression coefficients in Postgres are always `numeric` types.

### Possible future features
## Setting method options globally

Method options can be set globally via `vars`, e.g. in your `dbt_project.yml`. Each `method` gets its own config, e.g. `dbt_linreg: chol: ...`. Here is an example:

```yaml
# dbt_project.yml
vars:
dbt_linreg:
method_options:
chol:
intra_select_aliasing: true
```

Method options passed via `ols()` always take precedence over globally set method options.

# Possible future features

Some things that could happen in the future:

- Weighted least squares (WLS)
- Efficient multivariate regression (i.e. multiple endogenous vectors sharing a single design matrix)
- P-values
- Heteroskedasticity robust standard errors
- Recursive CTE implementations + long formatted inputs
Expand Down Expand Up @@ -332,7 +369,7 @@ There is no closed-form solution to L1 regularization, which makes it very very

### Is the `group_by=[...]` argument like categorical variables / one-hot encodings?

No. You should think of the group by more as a [seemingly unrelated regressions](https://en.wikipedia.org/wiki/Seemingly_unrelated_regressions) implementation than as a categorical variable implementation. It's running multiple regressions and each individual partition is its own `y` vector and `X` matrix. This is _not_ a replacement for dummy variables.
No. The `group_by` runs a linear regressions within each group, and each individual partition is its own `y` vector and `X` matrix. This is _not_ a replacement for dummy variables.

### Why aren't categorical variables / one-hot encodings supported?

Expand Down
2 changes: 1 addition & 1 deletion integration_tests/tests/test_long_format_options.sql
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ with
base as (

select strip_quotes, vname, co
from {{ ref("long_output_options") }}
from {{ ref("long_format_options") }}

),

Expand Down
2 changes: 1 addition & 1 deletion integration_tests/tests/test_wide_format_options.sql
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ with base as (
"fooxa_bar",
fooxb_bar
from
{{ ref("wide_output_options") }}
{{ ref("wide_format_options") }}

)

Expand Down
46 changes: 27 additions & 19 deletions macros/linear_regression/ols_impl_chol/_ols_impl_chol.sql
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,33 @@
in the query you wrote.
If that's not available, the previous calc will be in the dict. #}

{% macro _cell_or_alias(i, j, d, prefix=none) %}
{% macro _cell_or_alias(i, j, d, prefix=none, isa=none) %}
{% if isa is not none %}
{% if isa %}
{{ return((prefix if prefix is not none else '') ~ 'i' ~ i ~ 'j' ~ j) }}
{% else %}
{{ return(d[(i, j)]) }}
{% endif %}
{% endif %}
{{ return(
adapter.dispatch('_cell_or_alias', 'dbt_linreg')
(i, j, d, prefix)
(i, j, d, prefix, isa)
) }}
{% endmacro %}

{% macro default___cell_or_alias(i, j, d, prefix=none) %}
{% macro default___cell_or_alias(i, j, d, prefix=none, isa=none) %}
{{ return(d[(i, j)]) }}
{% endmacro %}

{% macro snowflake___cell_or_alias(i, j, d, prefix=none) %}
{% macro snowflake___cell_or_alias(i, j, d, prefix=none, isa=none) %}
{{ return((prefix if prefix is not none else '') ~ 'i' ~ i ~ 'j' ~ j) }}
{% endmacro %}

{% macro duckdb___cell_or_alias(i, j, d, prefix=none) %}
{% macro duckdb___cell_or_alias(i, j, d, prefix=none, isa=none) %}
{{ return((prefix if prefix is not none else '') ~ 'i' ~ i ~ 'j' ~ j) }}
{% endmacro %}

{% macro clickhouse___cell_or_alias(i, j, d, prefix=none) %}
{% macro clickhouse___cell_or_alias(i, j, d, prefix=none, isa=none) %}
{{ return((prefix if prefix is not none else '') ~ 'i' ~ i ~ 'j' ~ j) }}
{% endmacro %}

Expand All @@ -46,7 +53,7 @@
{{ return('sqrt('~x~')') }}
{% endmacro %}

{% macro _cholesky_decomposition(li, subquery_optimization=True, safe=True) %}
{% macro _cholesky_decomposition(li, subquery_optimization=true, safe=true, isa=none) %}
{% set d = {} %}
{% for i in li %}
{% for j in range(li[0], i + 1) %}
Expand All @@ -57,18 +64,18 @@
{% set ns.s = 'x'~j~'x'~i %}
{% for k in range(li[0], j) %}
{% if subquery_optimization and i != j %}
{% set ns.s = ns.s~'-'~dbt_linreg._cell_or_alias(i=i, j=k, d=d)~'*i'~j~'j'~k %}
{% set ns.s = ns.s~'-'~dbt_linreg._cell_or_alias(i=i, j=k, d=d, isa=isa)~'*i'~j~'j'~k %}
{% else %}
{% set ns.s = ns.s~'-'~dbt_linreg._cell_or_alias(i=i, j=k, d=d)~'*'~dbt_linreg._cell_or_alias(i=j, j=k, d=d) %}
{% set ns.s = ns.s~'-'~dbt_linreg._cell_or_alias(i=i, j=k, d=d, isa=isa)~'*'~dbt_linreg._cell_or_alias(i=j, j=k, d=d, isa=isa) %}
{% endif %}
{% endfor %}
{% if i == j %}
{% do d.update({(i, j): dbt_linreg._safe_sqrt(x=ns.s, safe=safe)}) %}
{% else %}
{% if adapter.type() == "postgres" %}
{% do d.update({(i, j): '('~ns.s~')/nullif('~dbt_linreg._cell_or_alias(i=j, j=j, d=d) ~ ', 0)'}) %}
{% if safe %}
{% do d.update({(i, j): '('~ns.s~')/nullif('~dbt_linreg._cell_or_alias(i=j, j=j, d=d, isa=isa) ~ ', 0)'}) %}
{% else %}
{% do d.update({(i, j): '('~ns.s~')/'~dbt_linreg._cell_or_alias(i=j, j=j, d=d)}) %}
{% do d.update({(i, j): '('~ns.s~')/'~dbt_linreg._cell_or_alias(i=j, j=j, d=d, isa=isa)}) %}
{% endif %}
{% endif %}
{% endif %}
Expand All @@ -77,7 +84,7 @@
{{ return(d) }}
{% endmacro %}

{% macro _forward_substitution(li, safe=true) %}
{% macro _forward_substitution(li, safe=true, isa=none) %}
{% set d = {} %}
{% for i, j in modules.itertools.combinations_with_replacement(li, 2) %}
{% set ns = namespace() %}
Expand All @@ -86,7 +93,7 @@
{% else %}
{% set ns.numerator = '(' %}
{% for k in range(i, j) %}
{% set ns.numerator = ns.numerator~'-i'~j~'j'~k~'*'~dbt_linreg._cell_or_alias(i=i, j=k, d=d, prefix="inv_") %}
{% set ns.numerator = ns.numerator~'-i'~j~'j'~k~'*'~dbt_linreg._cell_or_alias(i=i, j=k, d=d, prefix="inv_", isa=isa) %}
{% endfor %}
{% set ns.numerator = ns.numerator~')' %}
{% endif %}
Expand Down Expand Up @@ -121,9 +128,10 @@
alpha=alpha
)) }}
{%- endif %}
{%- set subquery_optimization = method_options.get('subquery_optimization', True) %}
{%- set safe_mode = method_options.get('safe', True) %}
{%- set calculate_standard_error = output_options.get('calculate_standard_error', (not alpha)) and output == 'long' %}
{%- set subquery_optimization = dbt_linreg._get_method_option('chol', 'subquery_optimization', method_options, true) %}
{%- set safe_mode = dbt_linreg._get_method_option('chol', 'safe', method_options, true) %}
{% set isa = dbt_linreg._get_method_option('chol', 'intra_select_aliasing', method_options) %}
{%- set calculate_standard_error = dbt_linreg._get_output_option('calculate_standard_error', output_options, (not alpha) and output == 'long') %}
{%- if alpha and calculate_standard_error %}
{% do log(
'Warning: Standard errors are NOT designed to take into account ridge regression regularization.'
Expand Down Expand Up @@ -175,7 +183,7 @@ _dbt_linreg_xtx as (
),
_dbt_linreg_chol as (

{%- set d = dbt_linreg._cholesky_decomposition(li=xcols, subquery_optimization=subquery_optimization, safe=safe_mode) %}
{%- set d = dbt_linreg._cholesky_decomposition(li=xcols, subquery_optimization=subquery_optimization, safe=safe_mode, isa=isa) %}
{%- if subquery_optimization %}
{%- for i in (xcols | reverse) %}
select
Expand Down Expand Up @@ -206,7 +214,7 @@ _dbt_linreg_chol as (
),
_dbt_linreg_inverse_chol as (
{#- The optimal way to calculate is to do each diagonal at a time. #}
{%- set d = dbt_linreg._forward_substitution(li=xcols, safe=safe_mode) %}
{%- set d = dbt_linreg._forward_substitution(li=xcols, safe=safe_mode, isa=isa) %}
{%- if subquery_optimization %}
{%- for gap in (range(0, upto) | reverse) %}
select *,
Expand Down
Loading
Loading