diff --git a/src/maml/utils/_stats.py b/src/maml/utils/_stats.py index 2c631a40..8a1cd047 100644 --- a/src/maml/utils/_stats.py +++ b/src/maml/utils/_stats.py @@ -309,9 +309,9 @@ def power_mean(data: list[float], weights: list[float] | None = None, p: int = 1 assert abs(sum(weights) - 1) < 1e-3 if p == 0: - return np.prod([i**j for i, j in zip(data, weights, strict=False)]).item() + return np.prod([i**j for i, j in zip(data, weights)]).item() - s = np.sum([j * i**p for i, j in zip(data, weights, strict=False)]) + s = np.sum([j * i**p for i, j in zip(data, weights)]) return s ** (1.0 / p) @staticmethod