diff --git a/tests/test_fit.py b/tests/test_fit.py
index fc59117..17aa8ad 100644
--- a/tests/test_fit.py
+++ b/tests/test_fit.py
@@ -14,7 +14,7 @@
     explain_timeseries,
 )
 from wise_pizza.segment_data import SegmentData
-from wise_pizza.solver import solve_lasso, solve_lp
+from wise_pizza.solve.solver import solve_lasso, solve_lp
 from wise_pizza.time import create_time_basis
 from wise_pizza.plotting_time import plot_time
 
@@ -33,7 +33,7 @@
 # Too long, delete some values for quick starts, e.g. by deleting the parameters in nan_percent, size_one_percent
 deltas_test_values = [
     ("totals", "split_fits", "force_dim", "extra_dim"),  # how
-    ("lp", "lasso"),  # solver
+    ("lp", "lasso", "tree"),  # solver
     (True,),  # plot_is_static
     (explain_changes_in_average, explain_changes_in_totals),  # function
     (0.0, 90.0),  # nan_percent
@@ -44,7 +44,7 @@
 
 # possible values for explain_levels
 levels_test_values = [
-    ("lp", "lasso"),  # solver
+    ("lp", "lasso", "tree"),  # solver
     (0.0, 90.0),  # nan_percent
     (0.0, 90.0),  # size_one_percent
 ]
@@ -136,9 +136,9 @@ def test_categorical():
     print("yay!")
 
 
-@pytest.mark.parametrize("nan_percent", [0.0, 1.0])
-def test_synthetic_template(nan_percent: float):
-    all_data = synthetic_data(init_len=1000)
+@pytest.mark.parametrize("nan_percent, clustering", [[0.0, False], [1.0, False]])
+def test_synthetic_template(nan_percent: float, clustering: bool):
+    all_data = synthetic_data(init_len=10000, dim_values=5)
     data = all_data.data
 
     data.loc[(data["dim0"] == 0) & (data["dim1"] == 1), "totals"] += 100
@@ -155,6 +155,7 @@ def test_synthetic_template(nan_percent: float):
         min_segments=5,
         verbose=1,
         solver="lp",
+        cluster_values=clustering,
     )
     print("***")
     for s in sf.segments:
@@ -167,6 +168,38 @@ def test_synthetic_template(nan_percent: float):
     print("yay!")
 
 
+@pytest.mark.parametrize("nan_percent", [0.0, 1.0])
+def test_synthetic_template_tree(nan_percent: float):
+    all_data = synthetic_data(init_len=1000)
+    data = all_data.data
+
+    data.loc[(data["dim0"] == 0) & (data["dim1"] == 1), "totals"] += 200
+    data.loc[(data["dim1"] == 0) & (data["dim2"] == 1), "totals"] += 300
+
+    if nan_percent > 0:
+        data = values_to_nan(data, nan_percent)
+    sf = explain_levels(
+        data,
+        dims=all_data.dimensions,
+        total_name=all_data.segment_total,
+        size_name=all_data.segment_size,
+        max_depth=2,
+        min_segments=5,
+        verbose=1,
+        solver="tree",
+    )
+    print("***")
+    for s in sf.segments:
+        print(s)
+
+    # TODO: insert appropriate asserts
+    # assert abs(sf.segments[0]["coef"] - 300) < 2
+    # assert abs(sf.segments[1]["coef"] - 100) < 2
+
+    # sf.plot()
+    print("yay!")
+
+
 @pytest.mark.parametrize("nan_percent", [0.0, 1.0])
 def test_synthetic_ts_template(nan_percent: float):
     all_data = synthetic_ts_data(init_len=10000)
diff --git a/wise_pizza/cluster.py b/wise_pizza/cluster.py
index f4e185f..6090af7 100644
--- a/wise_pizza/cluster.py
+++ b/wise_pizza/cluster.py
@@ -1,3 +1,6 @@
+from typing import List, Dict, Tuple
+from collections import defaultdict
+
 import numpy as np
 import pandas as pd
 from sklearn.preprocessing import PowerTransformer
@@ -18,17 +21,27 @@ def guided_kmeans(X: np.ndarray, power_transform: bool = True) -> np.ndarray:
         X = X.values
 
     if power_transform:
-        if len(X[X > 0] > 1):
-            X[X > 0] = PowerTransformer(standardize=False).fit_transform(X[X > 0].reshape(-1, 1)).reshape(-1)
-        if len(X[X < 0] > 1):
-            X[X < 0] = -PowerTransformer(standardize=False).fit_transform(-X[X < 0].reshape(-1, 1)).reshape(-1)
+        if len(X[X > 0]) > 1:
+            X[X > 0] = (
+                PowerTransformer(standardize=False)
+                .fit_transform(X[X > 0].reshape(-1, 1))
+                .reshape(-1)
+            )
+        if len(X[X < 0]) > 1:
+            X[X < 0] = (
+                -PowerTransformer(standardize=False)
+                .fit_transform(-X[X < 0].reshape(-1, 1))
+                .reshape(-1)
+            )
 
     best_score = -1
     best_labels = None
     best_n = -1
     # If we allow 2 clusters, it almost always just splits positive vs negative - boring!
     for n_clusters in range(3, int(len(X) / 2) + 1):
-        cluster_labels = KMeans(n_clusters=n_clusters, init="k-means++", n_init=10).fit_predict(X)
+        cluster_labels = KMeans(
+            n_clusters=n_clusters, init="k-means++", n_init=10
+        ).fit_predict(X)
         score = silhouette_score(X, cluster_labels)
         # print(n_clusters, score)
         if score > best_score:
@@ -45,3 +58,55 @@ def to_matrix(labels: np.ndarray) -> np.ndarray:
     for i in labels.unique():
         out[labels == i, i] = 1.0
     return out
+
+
+def make_clusters(dim_df: pd.DataFrame, dims: List[str]):
+    cluster_names = {}
+    for dim in dims:
+        if len(dim_df[dim].unique()) >= 6:  # otherwise what's the point in clustering?
+            grouped_df = (
+                dim_df[[dim, "totals", "weights"]].groupby(dim, as_index=False).sum()
+            )
+            grouped_df["avg"] = grouped_df["totals"] / grouped_df["weights"]
+            grouped_df["cluster"], _ = guided_kmeans(grouped_df["avg"])
+            pre_clusters = (
+                grouped_df[["cluster", dim]]
+                .groupby("cluster")
+                .agg({dim: lambda x: "@@".join(x)})
+                .values
+            )
+            # filter out clusters with only one element
+            these_clusters = [c for c in pre_clusters.reshape(-1) if "@@" in c]
+            # create short cluster names
+            for i, c in enumerate(these_clusters):
+                cluster_names[f"{dim}_cluster_{i + 1}"] = c
+    return cluster_names
+
+
+def nice_cluster_names(x: List[Dict[str, List[str]]]) -> Tuple[List[Dict], Dict]:
+    # first pass just populate cluster names
+    cluster_strings = defaultdict(set)
+    for xx in x:
+        for dim, v in xx.items():
+            if len(v) > 1:
+                cluster_strings[dim].add("@@".join(v))
+
+    cluster_names = {}
+    reverse_cluster_names = {}
+    for dim, clusters in cluster_strings.items():
+        reverse_cluster_names[dim] = {}
+        for i, c in enumerate(clusters):
+            cluster_names[f"{dim}_cluster_{i + 1}"] = c
+            reverse_cluster_names[dim][c] = f"{dim}_cluster_{i + 1}"
+
+    col_defs = []
+    for xx in x:
+        this_def = {}
+        for dim, v in xx.items():
+            if len(v) > 1:
+                this_def[dim] = reverse_cluster_names[dim]["@@".join(v)]
+            else:
+                this_def[dim] = v[0]
+        col_defs.append(this_def)
+
+    return col_defs, cluster_names
diff --git a/wise_pizza/explain.py b/wise_pizza/explain.py
index 0c3c3d5..58cb74f 100644
--- a/wise_pizza/explain.py
+++ b/wise_pizza/explain.py
@@ -361,6 +361,7 @@ def explain_timeseries(
     max_depth: int = 2,
     solver: str = "omp",
     verbose: bool = False,
+    constrain_signs: bool = False,
     cluster_values: bool = False,
     time_basis: Optional[pd.DataFrame] = None,
     fit_log_space: bool = False,
@@ -388,7 +389,10 @@ def explain_timeseries(
     fit_sizes = True
 
     if fit_log_space:
-        tf = LogTransform(offset=1, weight_pow_sc=log_space_weight_sc)
+        tf = LogTransform(
+            offset=1,
+            weight_pow_sc=log_space_weight_sc,
+        )
     else:
         tf = IdentityTransform()
 
@@ -415,6 +419,7 @@ def explain_timeseries(
         max_depth=max_depth,
         solver=solver,
         verbose=verbose,
+        constrain_signs=constrain_signs,
         cluster_values=cluster_values,
         time_basis=time_basis,
     )
@@ -441,6 +446,7 @@ def explain_timeseries(
         max_depth=max_depth,
         solver=solver,
         verbose=verbose,
+        constrain_signs=constrain_signs,
         cluster_values=cluster_values,
         time_basis=time_basis,
     )
@@ -477,6 +483,7 @@ def explain_timeseries(
         max_depth=max_depth,
         solver=solver,
         verbose=verbose,
+        constrain_signs=constrain_signs,
         cluster_values=cluster_values,
         time_basis=time_basis,
     )
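
Illustrative usage sketch (not part of the patch): the new "tree" solver end to end, mirroring test_synthetic_template_tree above. The toy frame is invented; explain_levels, the argument names, and solver="tree" come from this PR.

import pandas as pd

from wise_pizza.explain import explain_levels

# Invented toy data: two categorical dims, additive totals and sizes
df = pd.DataFrame(
    {
        "dim0": ["a", "a", "b", "b"] * 25,
        "dim1": ["x", "y"] * 50,
        "totals": [10.0, 1.0, 1.0, 1.0] * 25,
        "sizes": [1.0] * 100,
    }
)
sf = explain_levels(
    df,
    dims=["dim0", "dim1"],
    total_name="totals",
    size_name="sizes",
    max_depth=2,
    min_segments=5,
    solver="tree",  # segments come out non-overlapping, unlike "lasso"/"lp"
)
for s in sf.segments:
    print(s)
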
diff --git a/wise_pizza/slicer.py b/wise_pizza/slicer.py
index 89c576c..10ac069 100644
--- a/wise_pizza/slicer.py
+++ b/wise_pizza/slicer.py
@@ -8,12 +8,14 @@
 import pandas as pd
 from scipy.sparse import csc_matrix, diags
 
-from wise_pizza.find_alpha import clean_up_min_max, find_alpha
+from wise_pizza.solve.find_alpha import clean_up_min_max, find_alpha
 from wise_pizza.make_matrix import sparse_dummy_matrix
-from wise_pizza.cluster import guided_kmeans
+from wise_pizza.cluster import make_clusters
 from wise_pizza.preselect import HeuristicSelector
 from wise_pizza.time import extend_dataframe
 from wise_pizza.slicer_facades import SliceFinderPredictFacade
+from wise_pizza.solve.tree import tree_solver
+from wise_pizza.solve.solver import solve_lasso
 
 
 def _summary(obj) -> str:
@@ -116,7 +118,7 @@ def fit(
         @param max_segments: Maximum number of segments to find, defaults to min_segments
         @param min_depth: Minimum number of dimension to constrain in segment definition
         @param max_depth: Maximum number of dimension to constrain in segment definition
-        @param solver: If this equals to "lp" uses the LP solver, else uses the (recommended) Lasso solver
+        @param solver: Valid values are "lasso" (default), "tree" (for non-overlapping segments), "omp", or "lp"
         @param verbose: If set to a truish value, lots of debug info is printed to console
         @param force_dim: To add dim
         @param force_add_up: To force add up
@@ -125,6 +127,8 @@ def fit(
             group of segments from the same dimension with similar naive averages
         """
 
+
+        assert solver.lower() in ["lasso", "tree", "omp", "lp"]
         min_segments, max_segments = clean_up_min_max(min_segments, max_segments)
         if verbose is not None:
             self.verbose = verbose
@@ -139,12 +143,16 @@ def fit(
         assert min(weights) >= 0
         assert np.sum(np.abs(totals[weights == 0])) == 0
 
+        # Cast all dimension values to strings
+        dim_df = dim_df.astype(str)
+
         dims = list(dim_df.columns)
         # sort the dataframe by dimension values,
         # making sure the other vectors stay aligned
         dim_df = dim_df.reset_index(drop=True)
         dim_df["totals"] = totals
         dim_df["weights"] = weights
+
         if time_col is not None:
             dim_df["__time"] = time_col
             dim_df = pd.merge(dim_df, time_basis, left_on="__time", right_index=True)
@@ -176,70 +184,72 @@ def fit(
         # of dimension values with similar outcomes
         clusters = defaultdict(list)
         self.cluster_names = {}
-        if cluster_values:
-            for dim in dims:
-                if (
-                    len(dim_df[dim].unique()) >= 6
-                ):  # otherwise what's the point in clustering?
-                    grouped_df = (
-                        dim_df[[dim, "totals", "weights"]]
-                        .groupby(dim, as_index=False)
-                        .sum()
-                    )
-                    grouped_df["avg"] = grouped_df["totals"] / grouped_df["weights"]
-                    grouped_df["cluster"], _ = guided_kmeans(grouped_df["avg"])
-                    pre_clusters = (
-                        grouped_df[["cluster", dim]]
-                        .groupby("cluster")
-                        .agg({dim: lambda x: "@@".join(x)})
-                        .values
-                    )
-                    # filter out clusters with only one element
-                    these_clusters = [c for c in pre_clusters.reshape(-1) if "@@" in c]
-                    # create short cluster names
-                    for i, c in enumerate(these_clusters):
-                        self.cluster_names[f"{dim}_cluster_{i+1}"] = c
+
+        if solver == "tree":
+            if cluster_values:
+                warnings.warn(
+                    "Ignoring cluster_values argument as tree solver makes its own clusters"
+                )
+            self.X, self.col_defs, self.cluster_names = tree_solver(
+                dim_df=dim_df,
+                dims=dims,
+                time_basis=self.time_basis,
+                num_leaves=max_segments,
+            )
+            self.nonzeros = np.array(range(self.X.shape[1]))
+            Xw = csc_matrix(diags(self.weights) @ self.X)
+            self.reg = solve_lasso(
+                Xw.toarray(),
+                self.totals,
+                alpha=1e-5,
+                verbose=self.verbose,
+                fit_intercept=False,
+            )
+        else:
+            if cluster_values:
+                self.cluster_names = make_clusters(dim_df, dims)
+            for dim in dims:
                 clusters[dim] = [
                     c for c in self.cluster_names.keys() if c.startswith(dim)
                 ]
 
-        dim_df = dim_df[dims]  # if time_col is None else dims + ["__time"]]
-        self.dim_df = dim_df
-
-        # lazy calculation of the dummy matrix (calculation can be very slow)
-        if (
-            list(dim_df.columns) != self.dims
-            or max_depth != self.max_depth
-            or self.X is not None
-            and len(dim_df) != self.X.shape[1]
-        ):
-            self.X, self.col_defs = self._init_mat(
-                dim_df,
-                min_depth,
-                max_depth,
-                force_dim=force_dim,
-                clusters=clusters,
-                time_basis=self.time_basis,
+            dim_df = dim_df[dims]  # if time_col is None else dims + ["__time"]]
+            self.dim_df = dim_df
+            # lazy calculation of the dummy matrix (calculation can be very slow)
+            if (
+                list(dim_df.columns) != self.dims
+                or max_depth != self.max_depth
+                or self.X is not None
+                and len(dim_df) != self.X.shape[1]
+            ):
+                self.X, self.col_defs = self._init_mat(
+                    dim_df,
+                    min_depth,
+                    max_depth,
+                    force_dim=force_dim,
+                    clusters=clusters,
+                    time_basis=self.time_basis,
+                )
+                assert len(self.col_defs) == self.X.shape[1]
+                self.min_depth = min_depth
+                self.max_depth = max_depth
+                self.dims = list(dim_df.columns)
+
+            Xw = csc_matrix(diags(self.weights) @ self.X)
+
+            if self.verbose:
+                print("Starting solve!")
+            self.reg, self.nonzeros = find_alpha(
+                Xw,
+                self.totals,
+                max_nonzeros=max_segments,
+                solver=solver,
+                min_nonzeros=min_segments,
+                verbose=self.verbose,
+                adding_up_regularizer=force_add_up,
+                constrain_signs=constrain_signs,
             )
-        assert len(self.col_defs) == self.X.shape[1]
-        self.min_depth = min_depth
-        self.max_depth = max_depth
-        self.dims = list(dim_df.columns)
-        Xw = csc_matrix(diags(self.weights) @ self.X)
-
-        if self.verbose:
-            print("Starting solve!")
-        self.reg, self.nonzeros = find_alpha(
-            Xw,
-            self.totals,
-            max_nonzeros=max_segments,
-            solver=solver,
-            min_nonzeros=min_segments,
-            verbose=self.verbose,
-            adding_up_regularizer=force_add_up,
-            constrain_signs=constrain_signs,
-        )
 
         if self.verbose:
             print("Solver done!!")
diff --git a/wise_pizza/solve/__init__.py b/wise_pizza/solve/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/wise_pizza/find_alpha.py b/wise_pizza/solve/find_alpha.py
similarity index 98%
rename from wise_pizza/find_alpha.py
rename to wise_pizza/solve/find_alpha.py
index aae0fbb..ec454fe 100644
--- a/wise_pizza/find_alpha.py
+++ b/wise_pizza/solve/find_alpha.py
@@ -1,9 +1,9 @@
 import numpy as np
 import pandas as pd
 from scipy.sparse import vstack, csr_array, issparse
-from scipy.linalg import svd, expm
+from scipy.linalg import svd
 
-from wise_pizza.solver import solve_lasso, solve_lp, solve_omp
+from wise_pizza.solve.solver import solve_lasso, solve_lp, solve_omp
 
 
 def find_alpha(
@@ -122,7 +122,7 @@ def find_alpha(
         mat = X
         y = y_
 
-    if solver=="omp":
+    if solver == "omp":
         reg, nonzeros = solve_omp(mat.toarray(), y, min_nonzeros)
         return reg, nonzeros
 
@@ -143,9 +143,6 @@ def print_errors(a: np.ndarray):
     if verbose:
         print_errors(np.zeros(X.shape[1]))
 
-
-
-
     while len(nonzeros) < min_nonzeros:
         alpha /= 2
         reg = solve(
diff --git a/wise_pizza/solve/fitter.py b/wise_pizza/solve/fitter.py
new file mode 100644
index 0000000..4f44254
--- /dev/null
+++ b/wise_pizza/solve/fitter.py
@@ -0,0 +1,40 @@
+from typing import List
+from abc import ABC, abstractmethod
+
+import numpy as np
+
+
+class Fitter(ABC):
+    @abstractmethod
+    def fit(self, X, y, sample_weight=None):
+        pass
+
+    @abstractmethod
+    def predict(self, X):
+        pass
+
+    def fit_predict(self, X, y, sample_weight=None):
+        self.fit(X, y, sample_weight)
+        return self.predict(X)
+
+    def error(self, X, y, sample_weight=None):
+        err = y - self.predict(X)
+        if sample_weight is not None:
+            err *= sample_weight
+        return np.nansum(err**2)
+
+
+class AverageFitter(Fitter):
+    def __init__(self):
+        self.avg = None
+
+    def fit(self, X, y, sample_weight=None):
+        y = np.array(y)
+        if sample_weight is None:
+            self.avg = np.nanmean(y)
+        else:
+            sample_weight = np.array(sample_weight)
+            self.avg = np.nansum(y * sample_weight) / np.nansum(sample_weight)
+
+    def predict(self, X):
+        return np.full(X.shape[0], self.avg)
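
Illustrative sketch (not part of the patch): the Fitter contract from fitter.py above, exercised with AverageFitter; the numbers are invented.

import numpy as np

from wise_pizza.solve.fitter import AverageFitter

f = AverageFitter()
X = np.zeros((4, 1))  # AverageFitter ignores the features entirely
y = np.array([1.0, 2.0, 3.0, 4.0])
w = np.array([1.0, 1.0, 1.0, 7.0])
print(f.fit_predict(X, y, sample_weight=w))  # four copies of the weighted mean, 3.4
print(f.error(X, y, sample_weight=w))  # sum of squared weight-scaled residuals
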
diff --git a/wise_pizza/solver.py b/wise_pizza/solve/solver.py
similarity index 100%
rename from wise_pizza/solver.py
rename to wise_pizza/solve/solver.py
diff --git a/wise_pizza/solve/tree.py b/wise_pizza/solve/tree.py
new file mode 100644
index 0000000..eda97c5
--- /dev/null
+++ b/wise_pizza/solve/tree.py
@@ -0,0 +1,180 @@
+import copy
+from typing import Optional, List, Dict, Tuple
+
+import numpy as np
+import pandas as pd
+from scipy.sparse import csc_matrix
+
+from .weighted_quantiles import weighted_quantiles
+from .fitter import AverageFitter, Fitter
+from wise_pizza.cluster import nice_cluster_names
+
+
+def tree_solver(
+    dim_df: pd.DataFrame,
+    dims: List[str],
+    time_basis: Optional[pd.DataFrame] = None,
+    max_depth: int = 3,
+    num_leaves: Optional[int] = None,
+):
+    if time_basis is None:
+        fitter = AverageFitter()
+    else:
+        raise NotImplementedError("Time fitter not yet implemented")
+        # fitter = TimeFitter(dims, list(time_basis.columns))
+
+    df = dim_df.copy().reset_index(drop=True)
+    df["__avg"] = df["totals"] / df["weights"]
+    df["__avg"] = df["__avg"].fillna(df["__avg"].mean())
+
+    root = ModelNode(df=df, fitter=fitter, dims=dims)
+
+    build_tree(root=root, num_leaves=num_leaves, max_depth=max_depth)
+
+    leaves = get_leaves(root)
+
+    col_defs, cluster_names = nice_cluster_names([leaf.dim_split for leaf in leaves])
+
+    for i, leaf in enumerate(leaves):
+        leaf.df["Segment_id"] = i
+
+    re_df = pd.concat([leaf.df for leaf in leaves]).sort_values(dims)
+    X = pd.get_dummies(re_df["Segment_id"]).values
+
+    return csc_matrix(X), col_defs, cluster_names
+
+
+def error(x: np.ndarray, y: np.ndarray) -> float:
+    return np.sum((x - y) ** 2)
+
+
+def target_encode(df: pd.DataFrame, dim: str) -> dict:
+    df = df[[dim, "totals", "weights"]]
+    agg = df.groupby(dim, as_index=False).sum()
+    agg["__avg"] = agg["totals"] / agg["weights"]
+    agg["__avg"] = agg["__avg"].fillna(agg["__avg"].mean())
+    enc_map = {k: v for k, v in zip(agg[dim], agg["__avg"])}
+
+    if np.isnan(np.array(list(enc_map.values()))).any():
+        raise ValueError("NaNs in encoded values")
+    return enc_map
+
+
+class ModelNode:
+    def __init__(
+        self,
+        df: pd.DataFrame,
+        fitter: Fitter,
+        dims: List[str],
+        dim_split: Optional[Dict[str, List]] = None,
+        depth: int = 0,
+    ):
+        self.df = df
+        self.fitter = fitter
+        self.dims = dims
+        self._best_submodels = None
+        self._error_improvement = float("-inf")
+        self.children = None
+        self.dim_split = dim_split or {}
+        self.depth = depth
+        self.model = None
+
+    @property
+    def error(self):
+        if self.model is None:
+            self.model = copy.deepcopy(self.fitter)
+            self.model.fit(
+                X=self.df[self.dims],
+                y=self.df["__avg"],  # fit and score against the same (average) target
+                sample_weight=self.df["weights"],
+            )
+        return self.model.error(
+            self.df[self.dims], self.df["__avg"], self.df["weights"]
+        )
+
+    @property
+    def error_improvement(self):
+        if self._best_submodels is None:
+            best_error = float("inf")
+            for dim in self.dims:
+                if len(self.df[dim].unique()) == 1:
+                    continue
+                enc_map = target_encode(self.df, dim)
+                self.df[dim + "_encoded"] = self.df[dim].apply(lambda x: enc_map[x])
+                if np.any(np.isnan(self.df[dim + "_encoded"])):  # pragma: no cover
+                    raise ValueError("NaNs in encoded values")
+                # Get split candidates for brute force search
+                deciles = np.array([q / 10.0 for q in range(1, 10)])
+
+                splits = weighted_quantiles(
+                    self.df[dim + "_encoded"], deciles, self.df["weights"]
+                )
+
+                for split in np.unique(splits):
+                    left = self.df[self.df[dim + "_encoded"] < split]
+                    right = self.df[self.df[dim + "_encoded"] >= split]
+                    if len(left) == 0 or len(right) == 0:
+                        continue
+                    dim_values1 = [k for k, v in enc_map.items() if v < split]
+                    dim_values2 = [k for k, v in enc_map.items() if v >= split]
+                    left_candidate = ModelNode(
+                        df=left,
+                        fitter=self.fitter,
+                        dims=self.dims,
+                        dim_split={**self.dim_split, **{dim: dim_values1}},
+                        depth=self.depth + 1,
+                    )
+                    right_candidate = ModelNode(
+                        df=right,
+                        fitter=self.fitter,
+                        dims=self.dims,
+                        dim_split={**self.dim_split, **{dim: dim_values2}},
+                        depth=self.depth + 1,
+                    )
+
+                    err = left_candidate.error + right_candidate.error
+                    if err < best_error:
+                        best_error = err
+                        self._error_improvement = self.error - best_error
+                        self._best_submodels = (left_candidate, right_candidate)
+
+        return self._error_improvement
+
+
+def mod_improvement(improvement: float, depth: int, max_depth: int) -> float:
+    if depth < max_depth:
+        return improvement
+    else:
+        return float("-inf")
+
+
+def get_best_subtree_result(
+    node: ModelNode, max_depth: Optional[int] = 1000
+) -> ModelNode:
+    if node.children is None or node.depth >= max_depth:
+        return node
+    else:
+        node1 = get_best_subtree_result(node.children[0], max_depth)
+        node2 = get_best_subtree_result(node.children[1], max_depth)
+        improvement1 = mod_improvement(node1.error_improvement, node1.depth, max_depth)
+        improvement2 = mod_improvement(node2.error_improvement, node2.depth, max_depth)
+        if improvement1 > improvement2:
+            return node1
+        else:
+            return node2
+
+
+def build_tree(root: ModelNode, num_leaves: int, max_depth: Optional[int] = 1000):
+    for _ in range(num_leaves - 1):
+        best_node = get_best_subtree_result(root, max_depth)
+        if best_node.error_improvement > 0:
+            best_node.children = best_node._best_submodels
+        else:
+            break
+
+
+def get_leaves(node: ModelNode) -> List[ModelNode]:
+    if node.children is None:
+        return [node]
+    else:
+        return get_leaves(node.children[0]) + get_leaves(node.children[1])
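
Illustrative sketch (not part of the patch): what tree_solver above returns on an invented toy frame; the "totals"/"weights" column names follow the dim_df convention used by SliceFinder.fit, and num_leaves is the leaf budget (max_segments in the slicer).

import pandas as pd

from wise_pizza.solve.tree import tree_solver

df = pd.DataFrame(
    {
        "dim0": list("aabb") * 3,
        "dim1": list("xyxy") * 3,
        "totals": [1.0, 1.0, 5.0, 5.0] * 3,
        "weights": [1.0] * 12,
    }
)
X, col_defs, cluster_names = tree_solver(dim_df=df, dims=["dim0", "dim1"], num_leaves=3)
print(X.shape)  # (12, n_leaves): one one-hot column per leaf, so segments are disjoint
print(col_defs)  # e.g. [{"dim0": "a"}, {"dim0": "b"}]: the dim values defining each leaf
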
diff --git a/wise_pizza/solve/weighted_quantiles.py b/wise_pizza/solve/weighted_quantiles.py
new file mode 100644
index 0000000..5fef17a
--- /dev/null
+++ b/wise_pizza/solve/weighted_quantiles.py
@@ -0,0 +1,17 @@
+import numpy as np
+
+
+def weighted_quantiles(values, quantiles, sample_weight):
+    """Compute weighted quantiles of a 1D array of values."""
+    values_ = np.array(values)
+    sample_weight_ = np.array(sample_weight)
+    nice = ~np.isnan(values_) & ~np.isnan(sample_weight_)
+    if np.any(~nice):
+        raise ValueError("Data contains NaNs")
+    sorter = np.argsort(values_)
+    sorted_values = values_[sorter]
+    sorted_weights = sample_weight_[sorter]
+    w_quantiles = np.cumsum(sorted_weights) - 0.5 * sorted_weights
+    w_quantiles /= np.sum(sorted_weights)
+
+    return np.interp(quantiles, w_quantiles, sorted_values)
diff --git a/wise_pizza/transform.py b/wise_pizza/transform.py
index 3193b3c..e61837b 100644
--- a/wise_pizza/transform.py
+++ b/wise_pizza/transform.py
@@ -39,7 +39,7 @@ def inverse_transform_totals_weights(
         w = self.inverse_transform_weight(t_w, t_mean)
         return mean * w, w
 
-    def test_transforms(self, total, weights, eps=1e-6):
+    def test_transforms(self, total, weights, eps=1e-4):
         mean = total / weights
         t_mean = self.transform_mean(mean)
         assert almost_equals(mean, self.inverse_transform_mean(t_mean), eps)
@@ -71,19 +71,28 @@ def inverse_transform_weight(self, w: np.ndarray, x: np.ndarray) -> np.ndarray:
 
 class LogTransform(TransformWithWeights):
     def __init__(
-        self, offset: float, weight_pow_sc: float = 0.1, max_inverse: float = 1e6
+        self, offset: float, weight_pow_sc: float = 0.1, cap_inverse: bool = True
     ):
         self.offset = offset
         self.weight_pow_sc = weight_pow_sc
-        self.max_inverse = max_inverse
+        self.cap_inverse = cap_inverse
+        if cap_inverse:
+            self.max_inverse = 0.0
+        else:
+            self.max_inverse = None
 
     def transform_mean(self, x: np.ndarray) -> np.ndarray:
+        if self.cap_inverse:
+            self.max_inverse = np.maximum(self.max_inverse, 2 * x.max())
         return np.log(self.offset + x)
 
     def inverse_transform_mean(self, x: np.ndarray) -> np.ndarray:
-        return np.maximum(
-            0.0, np.exp(np.minimum(x, np.log(self.max_inverse))) - self.offset
-        )
+        if self.cap_inverse:
+            return np.maximum(
+                0.0, np.exp(np.minimum(x, np.log(self.max_inverse))) - self.offset
+            )
+        else:
+            return np.maximum(0.0, np.exp(x) - self.offset)
 
     def transform_weight(self, w: np.ndarray, mean: np.ndarray) -> np.ndarray:
         # pure math would give weight_pow_sc = 1, but then
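
Illustrative sketch (not part of the patch): the cap_inverse behaviour added to LogTransform above. transform_mean records a cap of twice the largest mean it has seen, and inverse_transform_mean uses that cap to clip exploded log-space predictions; the numbers are invented.

import numpy as np

from wise_pizza.transform import LogTransform

tf = LogTransform(offset=1)  # cap_inverse=True by default
x = np.array([0.0, 9.0, 99.0])
t = tf.transform_mean(x)  # log(1 + x); side effect: max_inverse becomes 2 * 99 = 198
print(tf.inverse_transform_mean(t))  # round-trips back to [ 0.  9. 99.]
big = np.array([100.0])  # a wildly out-of-range log-space prediction
print(tf.inverse_transform_mean(big))  # clipped to max_inverse - offset = 197.0
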