Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

0.2.5 - bug fixes #20

Merged
merged 2 commits into from
Aug 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ repos:
- id: trailing-whitespace

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.284
rev: v0.6.2
hooks:
- id: ruff

Expand Down
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Changelog

### `0.2.5`

- Fix bug where `exog` and `group_by` did not handle `str` inputs e.g. `exog="x"`.
- Fix bug where `group_by` for `method='fwl'` with exactly 1 exog variable did not work. (Explanation: `method='fwl'` dispatches to a different macro for the special case of 1 exog variable, and `group_by` was not implemented correctly here.)
- Fix bug where `safe` mode did not work for `method='chol'`.
- Improved docs by hiding everything except `ols()`, improved description of `ols()` macro, and added missing arg.

### `0.2.4`

- Fix minor incompatibility with Redshift; contributed by [@steelcd](https://github.com/steelcd).
Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ def ols(
format_options: Optional[dict[str, Any]] = None,
group_by: Optional[Union[str, list[str]]] = None,
alpha: Optional[Union[float, list[float]]] = None,
method: Literal['chol', 'fwl'] = 'chol'
method: Literal['chol', 'fwl'] = 'chol',
method_options: Optional[dict[str, Any]] = None
):
...
```
Expand Down
2 changes: 1 addition & 1 deletion dbt_project.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: "dbt_linreg"
version: "0.2.4"
version: "0.2.5"

# 1.2 is required because of modules.itertools.
require-dbt-version: [">=1.2.0", "<2.0.0"]
Expand Down
4 changes: 4 additions & 0 deletions macros/linear_regression/ols.sql
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
{% else %}
{% set exog = [exog] %}
{% endif %}
{% elif exog is string %}
{% set exog = [exog] %}
{% endif %}

{% if group_by is not iterable %}
Expand All @@ -68,6 +70,8 @@
{% else %}
{% set group_by = [group_by] %}
{% endif %}
{% elif group_by is string %}
{% set group_by = [group_by] %}
{% endif %}

{% if alpha is not iterable and alpha is not none %}
Expand Down
10 changes: 5 additions & 5 deletions macros/linear_regression/ols_impl_chol/_ols_impl_chol.sql
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
{{ return(d) }}
{% endmacro %}

{% macro _forward_substitution(li) %}
{% macro _forward_substitution(li, safe=true) %}
{% set d = {} %}
{% for i, j in modules.itertools.combinations_with_replacement(li, 2) %}
{% set ns = namespace() %}
Expand All @@ -86,7 +86,7 @@
{% endfor %}
{% set ns.numerator = ns.numerator~')' %}
{% endif %}
{% if adapter.type() == "postgres" %}
{% if safe %}
{% do d.update({(i, j): '('~ns.numerator~'/nullif(i'~j~'j'~j~', 0))'}) %}
{% else %}
{% do d.update({(i, j): '('~ns.numerator~'/i'~j~'j'~j~')'}) %}
Expand Down Expand Up @@ -118,7 +118,7 @@
)) }}
{%- endif %}
{%- set subquery_optimization = method_options.get('subquery_optimization', True) %}
{%- set safe_sqrt = method_options.get('safe', True) %}
{%- set safe_mode = method_options.get('safe', True) %}
{%- set calculate_standard_error = format_options.get('calculate_standard_error', (not alpha)) and format == 'long' %}
{%- if alpha and calculate_standard_error %}
{% do log(
Expand Down Expand Up @@ -171,7 +171,7 @@ _dbt_linreg_xtx as (
),
_dbt_linreg_chol as (

{%- set d = dbt_linreg._cholesky_decomposition(li=xcols, subquery_optimization=subquery_optimization, safe=safe_sqrt) %}
{%- set d = dbt_linreg._cholesky_decomposition(li=xcols, subquery_optimization=subquery_optimization, safe=safe_mode) %}
{%- if subquery_optimization %}
{%- for i in (xcols | reverse) %}
select
Expand Down Expand Up @@ -202,7 +202,7 @@ _dbt_linreg_chol as (
),
_dbt_linreg_inverse_chol as (
{#- The optimal way to calculate is to do each diagonal at a time. #}
{%- set d = dbt_linreg._forward_substitution(li=xcols) %}
{%- set d = dbt_linreg._forward_substitution(li=xcols, safe=safe_mode) %}
{%- if subquery_optimization %}
{%- for gap in (range(0, upto) | reverse) %}
select *,
Expand Down
2 changes: 1 addition & 1 deletion macros/linear_regression/ols_impl_special/_ols_1var.sql
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ _dbt_linreg_cmeans as (
{%- endif %}
_dbt_linreg_base as (
select
{{ dbt_linreg._gb_cols(group_by, trailing_comma=True) | indent(4) }}
{{ dbt_linreg._alias_gb_cols(group_by) | indent(4) }}
{%- if alpha and add_constant %}
b.{{ endog }} - _dbt_linreg_cmeans.y as y,
b.{{ exog[0] }} - _dbt_linreg_cmeans.x1 as x1,
Expand Down
114 changes: 111 additions & 3 deletions macros/schema.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,31 @@ macros:
```
{% endraw %}

The macro renders a subquery; in some database engines, such as Postgres, it is required to alias all subqueries.
You may also select from a CTE; in this case, just pass a string referencing the CTE:

{% raw %}
```sql
{{
config(
materialized="table"
)
}}
with my_data as (
select * from {{ ref('simple_matrix') }}
)
select * from {{
dbt_linreg.ols(
table='my_data',
endog='y',
exog=['xa', 'xb', 'xc'],
format='long',
format_options={'round': 5}
)
}}
```
{% endraw %}

The macro renders a subquery, inclusive of parentheses.

Please see the README / full documentation for more information: [https://dwreeves.github.io/dbt_linreg/](https://dwreeves.github.io/dbt_linreg/)
arguments:
Expand All @@ -53,6 +77,9 @@ macros:
- name: format_options
type: dict
description: See **Formats and format options** section in the README for more.
- name: group_by
type: string or list of numbers
description: If specified, the regression will be grouped by these variables, and individual regressions will run on each group.
- name: alpha
type: number or list of numbers
description: If not null, the regression will be run as a ridge regression with a penalty of `alpha`. See **Notes** section in the README for more information.
Expand All @@ -62,18 +89,36 @@ macros:
- name: method_options
type: dict
description: Options specific to the estimation method. See **Methods and method options** in the README for more.
# Everything down here is just for intermediary calculations.
# Better to hide this stuff to reduce confusion when reading docs.
# Everything down here is just for intermediary calculations or helper functions.
# There is no point to showing these in the docs.
# The truly curious can just look at the source code.
#
# Please generate the below with the following command:
# >>> python scripts.py gen-hide-macros-yaml
- name: _alias_exog
docs:
show: false
- name: _alias_gb_cols
docs:
show: false
- name: _cell_or_alias
docs:
show: false
- name: _cholesky_decomposition
docs:
show: false
- name: _filter_and_center_if_alpha
docs:
show: false
- name: _filter_if_alpha
docs:
show: false
- name: _format_wide_variable_column
docs:
show: false
- name: _forward_substitution
docs:
show: false
- name: _gb_cols
docs:
show: false
Expand All @@ -83,21 +128,84 @@ macros:
- name: _maybe_round
docs:
show: false
- name: _ols_0var
docs:
show: false
- name: _ols_1var
docs:
show: false
- name: _ols_chol
docs:
show: false
- name: _ols_fwl
docs:
show: false
- name: _orth_x_intercept
docs:
show: false
- name: _orth_x_slope
docs:
show: false
- name: _regress_or_alias
docs:
show: false
- name: _safe_sqrt
docs:
show: false
- name: _strip_quotes
docs:
show: false
- name: _traverse_intercepts
docs:
show: false
- name: _traverse_slopes
docs:
show: false
- name: _unalias_gb_cols
docs:
show: false
- name: bigquery___safe_sqrt
docs:
show: false
- name: default___cell_or_alias
docs:
show: false
- name: default___maybe_round
docs:
show: false
- name: default___regress_or_alias
docs:
show: false
- name: default___safe_sqrt
docs:
show: false
- name: default__regress
docs:
show: false
- name: duckdb___cell_or_alias
docs:
show: false
- name: duckdb___regress_or_alias
docs:
show: false
- name: final_select
docs:
show: false
- name: postgres___maybe_round
docs:
show: false
- name: redshift___maybe_round
docs:
show: false
- name: regress
docs:
show: false
- name: snowflake___cell_or_alias
docs:
show: false
- name: snowflake___regress_or_alias
docs:
show: false
- name: snowflake__regress
docs:
show: false
Loading
Loading