Skip to content

Commit

Permalink
Merge pull request #20 from dwreeves/0.2.5-bugfixes-galore
Browse files Browse the repository at this point in the history
0.2.5 - bug fixes
  • Loading branch information
dwreeves authored Aug 24, 2024
2 parents a8fde13 + 5a55b21 commit c7fb8fc
Show file tree
Hide file tree
Showing 12 changed files with 1,387 additions and 971 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ repos:
- id: trailing-whitespace

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.284
rev: v0.6.2
hooks:
- id: ruff

Expand Down
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Changelog

### `0.2.5`

- Fix bug where `exog` and `group_by` did not handle `str` inputs e.g. `exog="x"`.
- Fix bug where `group_by` for `method='fwl'` with exactly 1 exog variable did not work. (Explanation: `method='fwl'` dispatches to a different macro for the special case of 1 exog variable, and `group_by` was not implemented correctly here.)
- Fix bug where `safe` mode did not work for `method='chol'`.
- Improved docs by hiding everything except `ols()`, improved description of `ols()` macro, and added missing arg.

### `0.2.4`

- Fix minor incompatibility with Redshift; contributed by [@steelcd](https://github.com/steelcd).
Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ def ols(
format_options: Optional[dict[str, Any]] = None,
group_by: Optional[Union[str, list[str]]] = None,
alpha: Optional[Union[float, list[float]]] = None,
method: Literal['chol', 'fwl'] = 'chol'
method: Literal['chol', 'fwl'] = 'chol',
method_options: Optional[dict[str, Any]] = None
):
...
```
Expand Down
2 changes: 1 addition & 1 deletion dbt_project.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: "dbt_linreg"
version: "0.2.4"
version: "0.2.5"

# 1.2 is required because of modules.itertools.
require-dbt-version: [">=1.2.0", "<2.0.0"]
Expand Down
4 changes: 4 additions & 0 deletions macros/linear_regression/ols.sql
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
{% else %}
{% set exog = [exog] %}
{% endif %}
{% elif exog is string %}
{% set exog = [exog] %}
{% endif %}

{% if group_by is not iterable %}
Expand All @@ -68,6 +70,8 @@
{% else %}
{% set group_by = [group_by] %}
{% endif %}
{% elif group_by is string %}
{% set group_by = [group_by] %}
{% endif %}

{% if alpha is not iterable and alpha is not none %}
Expand Down
10 changes: 5 additions & 5 deletions macros/linear_regression/ols_impl_chol/_ols_impl_chol.sql
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
{{ return(d) }}
{% endmacro %}
{% macro _forward_substitution(li) %}
{% macro _forward_substitution(li, safe=true) %}
{% set d = {} %}
{% for i, j in modules.itertools.combinations_with_replacement(li, 2) %}
{% set ns = namespace() %}
Expand All @@ -86,7 +86,7 @@
{% endfor %}
{% set ns.numerator = ns.numerator~')' %}
{% endif %}
{% if adapter.type() == "postgres" %}
{% if safe %}
{% do d.update({(i, j): '('~ns.numerator~'/nullif(i'~j~'j'~j~', 0))'}) %}
{% else %}
{% do d.update({(i, j): '('~ns.numerator~'/i'~j~'j'~j~')'}) %}
Expand Down Expand Up @@ -118,7 +118,7 @@
)) }}
{%- endif %}
{%- set subquery_optimization = method_options.get('subquery_optimization', True) %}
{%- set safe_sqrt = method_options.get('safe', True) %}
{%- set safe_mode = method_options.get('safe', True) %}
{%- set calculate_standard_error = format_options.get('calculate_standard_error', (not alpha)) and format == 'long' %}
{%- if alpha and calculate_standard_error %}
{% do log(
Expand Down Expand Up @@ -171,7 +171,7 @@ _dbt_linreg_xtx as (
),
_dbt_linreg_chol as (
{%- set d = dbt_linreg._cholesky_decomposition(li=xcols, subquery_optimization=subquery_optimization, safe=safe_sqrt) %}
{%- set d = dbt_linreg._cholesky_decomposition(li=xcols, subquery_optimization=subquery_optimization, safe=safe_mode) %}
{%- if subquery_optimization %}
{%- for i in (xcols | reverse) %}
select
Expand Down Expand Up @@ -202,7 +202,7 @@ _dbt_linreg_chol as (
),
_dbt_linreg_inverse_chol as (
{#- The optimal way to calculate is to do each diagonal at a time. #}
{%- set d = dbt_linreg._forward_substitution(li=xcols) %}
{%- set d = dbt_linreg._forward_substitution(li=xcols, safe=safe_mode) %}
{%- if subquery_optimization %}
{%- for gap in (range(0, upto) | reverse) %}
select *,
Expand Down
2 changes: 1 addition & 1 deletion macros/linear_regression/ols_impl_special/_ols_1var.sql
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ _dbt_linreg_cmeans as (
{%- endif %}
_dbt_linreg_base as (
select
{{ dbt_linreg._gb_cols(group_by, trailing_comma=True) | indent(4) }}
{{ dbt_linreg._alias_gb_cols(group_by) | indent(4) }}
{%- if alpha and add_constant %}
b.{{ endog }} - _dbt_linreg_cmeans.y as y,
b.{{ exog[0] }} - _dbt_linreg_cmeans.x1 as x1,
Expand Down
File renamed without changes.
114 changes: 111 additions & 3 deletions macros/schema.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,31 @@ macros:
```
{% endraw %}
The macro renders a subquery; in some database engines, such as Postgres, it is required to alias all subqueries.
You may also select from a CTE; in this case, just pass a string referencing the CTE:
{% raw %}
```sql
{{
config(
materialized="table"
)
}}
with my_data as (
select * from {{ ref('simple_matrix') }}
)
select * from {{
dbt_linreg.ols(
table='my_data',
endog='y',
exog=['xa', 'xb', 'xc'],
format='long',
format_options={'round': 5}
)
}}
```
{% endraw %}
The macro renders a subquery, inclusive of parentheses.
Please see the README / full documentation for more information: [https://dwreeves.github.io/dbt_linreg/](https://dwreeves.github.io/dbt_linreg/)
arguments:
Expand All @@ -53,6 +77,9 @@ macros:
- name: format_options
type: dict
description: See **Formats and format options** section in the README for more.
- name: group_by
type: string or list of numbers
description: If specified, the regression will be grouped by these variables, and individual regressions will run on each group.
- name: alpha
type: number or list of numbers
description: If not null, the regression will be run as a ridge regression with a penalty of `alpha`. See **Notes** section in the README for more information.
Expand All @@ -62,18 +89,36 @@ macros:
- name: method_options
type: dict
description: Options specific to the estimation method. See **Methods and method options** in the README for more.
# Everything down here is just for intermediary calculations.
# Better to hide this stuff to reduce confusion when reading docs.
# Everything down here is just for intermediary calculations or helper functions.
# There is no point to showing these in the docs.
# The truly curious can just look at the source code.
#
# Please generate the below with the following command:
# >>> python scripts.py gen-hide-macros-yaml
- name: _alias_exog
docs:
show: false
- name: _alias_gb_cols
docs:
show: false
- name: _cell_or_alias
docs:
show: false
- name: _cholesky_decomposition
docs:
show: false
- name: _filter_and_center_if_alpha
docs:
show: false
- name: _filter_if_alpha
docs:
show: false
- name: _format_wide_variable_column
docs:
show: false
- name: _forward_substitution
docs:
show: false
- name: _gb_cols
docs:
show: false
Expand All @@ -83,21 +128,84 @@ macros:
- name: _maybe_round
docs:
show: false
- name: _ols_0var
docs:
show: false
- name: _ols_1var
docs:
show: false
- name: _ols_chol
docs:
show: false
- name: _ols_fwl
docs:
show: false
- name: _orth_x_intercept
docs:
show: false
- name: _orth_x_slope
docs:
show: false
- name: _regress_or_alias
docs:
show: false
- name: _safe_sqrt
docs:
show: false
- name: _strip_quotes
docs:
show: false
- name: _traverse_intercepts
docs:
show: false
- name: _traverse_slopes
docs:
show: false
- name: _unalias_gb_cols
docs:
show: false
- name: bigquery___safe_sqrt
docs:
show: false
- name: default___cell_or_alias
docs:
show: false
- name: default___maybe_round
docs:
show: false
- name: default___regress_or_alias
docs:
show: false
- name: default___safe_sqrt
docs:
show: false
- name: default__regress
docs:
show: false
- name: duckdb___cell_or_alias
docs:
show: false
- name: duckdb___regress_or_alias
docs:
show: false
- name: final_select
docs:
show: false
- name: postgres___maybe_round
docs:
show: false
- name: redshift___maybe_round
docs:
show: false
- name: regress
docs:
show: false
- name: snowflake___cell_or_alias
docs:
show: false
- name: snowflake___regress_or_alias
docs:
show: false
- name: snowflake__regress
docs:
show: false
Loading

0 comments on commit c7fb8fc

Please sign in to comment.