Skip to content

Commit

Permalink
ENH: return_all option for compute_regression
Browse files Browse the repository at this point in the history
  • Loading branch information
mmagnuski committed Dec 13, 2023
1 parent a9e3952 commit a7edea4
Showing 1 changed file with 28 additions and 5 deletions.
33 changes: 28 additions & 5 deletions borsar/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
# - [x] return residuals
# - [ ] consider returning a dictionary of additional info if required
# (for example coefficients and SE)
def compute_regression_t(data, preds, return_p=False, return_residuals=False):
def compute_regression_t(data, preds, return_p=False, return_residuals=False,
return_all=False):
'''Compute regression t values for whole multidimensional data space.
Parameters
Expand All @@ -19,6 +20,16 @@ def compute_regression_t(data, preds, return_p=False, return_residuals=False):
If ``True`` - also return p values. Defaults to ``False``.
return_residuals : bool
If ``True`` - also return regression residuals. Defaults to ``False``.
return_all : bool
If ``True`` - return all outputs as a dictionary. Defaults to
``False``. The outputs include:
* ``'coefs'`` - regression coefficients
* ``'SE'`` - standard errors
* ``'t'`` - t values
* ``'p'`` - p values (only if ``return_p`` is ``True``)
* ``'resid'`` - regression residuals
* ``'df'`` - degrees of freedom
* ``'predicted'`` - predicted values
Returns
-------
Expand Down Expand Up @@ -49,18 +60,30 @@ def compute_regression_t(data, preds, return_p=False, return_residuals=False):
SE = np.sqrt(MSE * np.diag(np.linalg.pinv(preds.T @ preds))[:, np.newaxis])
t_vals = (coefs / SE).reshape([n_preds, *original_shape[1:]])

out = (t_vals,)
if return_all:
out = {'coefs': coefs.reshape([n_preds, *original_shape[1:]]),
'SE': SE.reshape([n_preds, *original_shape[1:]]),
't': t_vals,
'df': df,
'predicted': prediction.reshape(original_shape)}
else:
out = (t_vals,)

if return_p:
from scipy.stats import t
p_vals = t.cdf(-np.abs(t_vals), df) * 2.
out += (p_vals,)

if return_residuals:
if return_all:
out['p'] = p_vals
else:
out += (p_vals,)

if return_residuals and not return_all:
residuals = residuals.reshape(original_shape)
out += (residuals,)

# make sure we return a tuple only if we have more than one output
if len(out) == 1:
if len(out) == 1 and not return_all:
return out[0]
else:
return out
Expand Down

0 comments on commit a7edea4

Please sign in to comment.