diff --git a/borsar/stats.py b/borsar/stats.py index 03e0d5a..e705dc7 100644 --- a/borsar/stats.py +++ b/borsar/stats.py @@ -5,7 +5,8 @@ # - [x] return residuals # - [ ] consider returning a dictionary of additional info if required # (for example coefficients and SE) -def compute_regression_t(data, preds, return_p=False, return_residuals=False): +def compute_regression_t(data, preds, return_p=False, return_residuals=False, + return_all=False): '''Compute regression t values for whole multidimensional data space. Parameters @@ -19,6 +20,16 @@ def compute_regression_t(data, preds, return_p=False, return_residuals=False): If ``True`` - also return p values. Defaults to ``False``. return_residuals : bool If ``True`` - also return regression residuals. Defaults to ``False``. + return_all : bool + If ``True`` - return all outputs as a dictionary. Defaults to + ``False``. The outputs include: + * ``'coefs'`` - regression coefficients + * ``'SE'`` - standard errors + * ``'t'`` - t values + * ``'p'`` - p values (only if ``return_p`` is ``True``) + * ``'resid'`` - regression residuals + * ``'df'`` - degrees of freedom + * ``'predicted'`` - predicted values Returns ------- @@ -49,18 +60,30 @@ def compute_regression_t(data, preds, return_p=False, return_residuals=False): SE = np.sqrt(MSE * np.diag(np.linalg.pinv(preds.T @ preds))[:, np.newaxis]) t_vals = (coefs / SE).reshape([n_preds, *original_shape[1:]]) - out = (t_vals,) + if return_all: + out = {'coefs': coefs.reshape([n_preds, *original_shape[1:]]), + 'SE': SE.reshape([n_preds, *original_shape[1:]]), + 't': t_vals, + 'df': df, + 'predicted': prediction.reshape(original_shape)} + else: + out = (t_vals,) + if return_p: from scipy.stats import t p_vals = t.cdf(-np.abs(t_vals), df) * 2. - out += (p_vals,) - if return_residuals: + if return_all: + out['p'] = p_vals + else: + out += (p_vals,) + + if return_residuals and not return_all: residuals = residuals.reshape(original_shape) out += (residuals,) # make sure we return a tuple only if we have more than one output - if len(out) == 1: + if len(out) == 1 and not return_all: return out[0] else: return out