From bdd60a77025cc767382833385ff8cbed0760667a Mon Sep 17 00:00:00 2001 From: Ananya Kumar Date: Mon, 7 Nov 2022 16:46:40 -0800 Subject: [PATCH] Optionally return platt scaling classifier. --- calibration/utils.py | 4 +++- setup.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/calibration/utils.py b/calibration/utils.py index 5ae7f10..f462f25 100644 --- a/calibration/utils.py +++ b/calibration/utils.py @@ -453,7 +453,7 @@ def bootstrap_std(data: List[T], estimator=None, num_samples=100) -> Tuple[float # Re-Calibration utilities. -def get_platt_scaler(model_probs, labels): +def get_platt_scaler(model_probs, labels, get_clf=False): clf = LogisticRegression(C=1e10, solver='lbfgs') eps = 1e-12 model_probs = model_probs.astype(dtype=np.float64) @@ -468,6 +468,8 @@ def calibrator(probs): x = x * clf.coef_[0] + clf.intercept_ output = 1 / (1 + np.exp(-x)) return output + if get_clf: + return calibrator, clf return calibrator diff --git a/setup.py b/setup.py index 789eff4..c70cf9d 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name="uncertainty-calibration", - version="0.1.2", + version="0.1.3", author="Ananya Kumar", author_email="skywalker94@gmail.com", description="Utilities to calibrate model uncertainties and measure calibration.",