-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
MERGE latest changes on master with calibration
- Loading branch information
Showing
24 changed files
with
853 additions
and
53 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,33 @@ | ||
import os | ||
|
||
import utils_io | ||
import utils_data | ||
import utils_snapshots | ||
|
||
import model_slda | ||
calc_nef_map_pi_DK = model_slda.calc_nef_map_pi_DK | ||
|
||
# TODO discard this line | ||
# calc_nef_map_pi_DK = model_slda.calc_nef_map_pi_DK | ||
|
||
PC_REPO_DIR = os.path.sep.join( | ||
os.path.abspath(__file__).split(os.path.sep)[:-2]) | ||
|
||
## Create version attrib | ||
__version__ = None | ||
version_txt_path = os.path.join(PC_REPO_DIR, 'version.txt') | ||
if os.path.exists(version_txt_path): | ||
with open(version_txt_path, 'r') as f: | ||
__version__ = f.readline().strip() | ||
|
||
## Create requirements attrib | ||
__requirements__ = None | ||
reqs_txt_path = os.path.join(PC_REPO_DIR, 'requirements.txt') | ||
if os.path.exists(reqs_txt_path): | ||
with open(reqs_txt_path, 'r') as f: | ||
__requirements__ = [] | ||
for line in f.readlines(): | ||
__requirements__.append(line.strip()) | ||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
import numpy as np | ||
from scipy.special import expit | ||
import matplotlib.gridspec as gridspec | ||
import matplotlib.pyplot as plt | ||
|
||
def plot_binary_clf_calibration_curve_and_histograms( | ||
info_per_bin=None, | ||
fig_kws=dict( | ||
figsize=(1.4*3, 1.4*4), | ||
tight_layout=True), | ||
): | ||
fig_h = plt.figure(**fig_kws) | ||
ax_grid = gridspec.GridSpec( | ||
nrows=4, ncols=1, | ||
height_ratios=[1, 1, 4, 0.1], | ||
) | ||
ax_cal = fig_h.add_subplot(ax_grid[2,0]) | ||
ax_TP = fig_h.add_subplot(ax_grid[0,0]) | ||
ax_TN = fig_h.add_subplot(ax_grid[1,0]) | ||
|
||
# Plot calibration curve | ||
# First, lay down idealized line from 0-1 | ||
unit_grid = np.linspace(0, 1, 10) | ||
ax_cal.plot( | ||
unit_grid, unit_grid, 'k--', alpha=0.5) | ||
# Then, plot actual-vs-expected fractions on top | ||
ax_cal.plot( | ||
info_per_bin['xcenter_per_bin'], | ||
info_per_bin['fracTP_per_bin'], | ||
'ks-') | ||
ax_cal.set_ylabel('frac. true positive') | ||
ax_cal.set_xlabel('predicted proba.') | ||
|
||
# Plot TP histogram | ||
ax_TP.bar( | ||
info_per_bin['xcenter_per_bin'], | ||
info_per_bin['countTP_per_bin'], | ||
width=0.9*info_per_bin['xwidth_per_bin'], | ||
color='b') | ||
|
||
# Plot TN histogram | ||
ax_TN.bar( | ||
info_per_bin['xcenter_per_bin'], | ||
info_per_bin['countTN_per_bin'], | ||
width=0.9*info_per_bin['xwidth_per_bin'], | ||
color='r') | ||
for ax in [ax_cal, ax_TP, ax_TN]: | ||
ax.set_xlim([0, 1]) | ||
ax_cal.set_ylim([0, 1]) | ||
|
||
def calc_binary_clf_calibration_per_bin( | ||
y_true, y_prob, | ||
bins=10): | ||
""" | ||
""" | ||
if y_prob.min() < 0 or y_prob.max() > 1: | ||
raise ValueError("y_prob has values outside [0, 1]") | ||
|
||
bins = np.asarray(bins) | ||
if bins.ndim == 1 and bins.size > 1: | ||
bin_edges = bins | ||
else: | ||
bin_edges = np.linspace(0, 1, int(bins) + 1) | ||
if bin_edges[-1] == 1.0: | ||
bin_edges[-1] += 1e-8 | ||
assert bin_edges.ndim == 1 | ||
assert bin_edges.size > 2 | ||
nbins = bin_edges.size - 1 | ||
# Assign each predicted probability into one bin | ||
# from 0, 1, ... nbins | ||
binids = np.digitize(y_prob, bin_edges) - 1 | ||
assert binids.max() <= nbins | ||
assert binids.min() >= 0 | ||
|
||
count_per_bin = np.bincount(binids, minlength=nbins) | ||
countTP_per_bin = np.bincount(binids, minlength=nbins, weights=y_true == 1) | ||
countTN_per_bin = np.bincount(binids, minlength=nbins, weights=y_true == 0) | ||
|
||
# This divide will (and should) yield nan | ||
# if any bin has no content | ||
fracTP_per_bin = countTP_per_bin / np.asarray(count_per_bin, dtype=np.float64) | ||
|
||
info_per_bin = dict( | ||
count_per_bin=count_per_bin, | ||
countTP_per_bin=countTP_per_bin, | ||
countTN_per_bin=countTN_per_bin, | ||
fracTP_per_bin=fracTP_per_bin, | ||
xcenter_per_bin=0.5 * (bin_edges[:-1] + bin_edges[1:]), | ||
xwidth_per_bin=(bin_edges[1:] - bin_edges[:-1]), | ||
bin_edges=bin_edges, | ||
) | ||
return info_per_bin | ||
|
||
|
||
if __name__ == '__main__': | ||
prng = np.random.RandomState(0) | ||
thr_true = prng.rand(100000) | ||
u_true = 0.65 * prng.randn(100000) | ||
y_true = np.asarray(expit(u_true) >= thr_true, dtype=np.float32) | ||
y_prob = expit(u_true) | ||
|
||
bins = 20 | ||
|
||
info_per_bin = calc_binary_clf_calibration_per_bin( | ||
y_true=y_true, | ||
y_prob=y_prob, | ||
bins=bins) | ||
bin_edges = info_per_bin['bin_edges'] | ||
for bb in range(bin_edges.size - 1): | ||
print "bin [%.2f, %.2f] count %5d fracTP %.3f" % ( | ||
bin_edges[bb], | ||
bin_edges[bb+1], | ||
info_per_bin['count_per_bin'][bb], | ||
info_per_bin['fracTP_per_bin'][bb], | ||
) | ||
|
||
plot_binary_clf_calibration_curve_and_histograms( | ||
info_per_bin=info_per_bin) | ||
|
||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
6 changes: 6 additions & 0 deletions
6
pc_toolbox/model_slda/est_local_params__vb_qpiDir_qzCat/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from calc_elbo_for_many_docs__vb_qpiDir_qzCat import ( | ||
calc_elbo_for_many_docs) | ||
|
||
from calc_N_d_K__vb_qpiDir_qzCat import ( | ||
calc_N_d_K__vb_coord_ascent__many_tries, | ||
calc_N_d_K__vb_coord_ascent) |
Oops, something went wrong.