From 2baff66a4cd1fb8178c543694e84488d5f193a56 Mon Sep 17 00:00:00 2001 From: lart Date: Fri, 24 Feb 2023 16:39:20 +0800 Subject: [PATCH] Add overall accuracy and kappa. --- examples/metric_recorder.py | 101 +++++------- examples/test_metrics.py | 300 ++++++++++++++++++++++++----------- py_sod_metrics/__init__.py | 6 +- py_sod_metrics/fmeasurev2.py | 103 ++++++++---- 4 files changed, 330 insertions(+), 180 deletions(-) diff --git a/examples/metric_recorder.py b/examples/metric_recorder.py index e568b3e..815d604 100644 --- a/examples/metric_recorder.py +++ b/examples/metric_recorder.py @@ -104,72 +104,53 @@ def get_results(self, num_bits: int = 3, return_ndarray: bool = False) -> dict: return {"sequential": sequential_results, "numerical": numerical_results} +sample_gray = dict(with_adaptive=True, with_dynamic=True) +sample_bin = dict(with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True) +overall_bin = dict(with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=False) BINARY_CLASSIFICATION_METRIC_MAPPING = { # 灰度数据指标 - "fm": py_sod_metrics.FmeasureHandler(with_adaptive=True, with_dynamic=True, beta=0.3), - "f1": py_sod_metrics.FmeasureHandler(with_adaptive=True, with_dynamic=True, beta=0.1), - "pre": py_sod_metrics.PrecisionHandler(with_adaptive=True, with_dynamic=True), - "rec": py_sod_metrics.RecallHandler(with_adaptive=True, with_dynamic=True), - "iou": py_sod_metrics.IOUHandler(with_adaptive=True, with_dynamic=True), - "dice": py_sod_metrics.DICEHandler(with_adaptive=True, with_dynamic=True), - "spec": py_sod_metrics.SpecificityHandler(with_adaptive=True, with_dynamic=True), - "ber": py_sod_metrics.BERHandler(with_adaptive=True, with_dynamic=True), + "fm": py_sod_metrics.FmeasureHandler(**sample_gray, beta=0.3), + "f1": py_sod_metrics.FmeasureHandler(**sample_gray, beta=0.1), + "pre": py_sod_metrics.PrecisionHandler(**sample_gray), + "rec": py_sod_metrics.RecallHandler(**sample_gray), + "iou": 
py_sod_metrics.IOUHandler(**sample_gray), + "dice": py_sod_metrics.DICEHandler(**sample_gray), + "spec": py_sod_metrics.SpecificityHandler(**sample_gray), + "ber": py_sod_metrics.BERHandler(**sample_gray), + "oa": py_sod_metrics.OverallAccuracyHandler(**sample_gray), + "kappa": py_sod_metrics.KappaHandler(**sample_gray), # 二值化数据指标的特殊情况一:各个样本独立计算指标后取平均 - "sample_bifm": py_sod_metrics.FmeasureHandler( - with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True, beta=0.3 - ), - "sample_bif1": py_sod_metrics.FmeasureHandler( - with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True, beta=1 - ), - "sample_bipre": py_sod_metrics.PrecisionHandler( - with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True - ), - "sample_birec": py_sod_metrics.RecallHandler( - with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True - ), - "sample_biiou": py_sod_metrics.IOUHandler( - with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True - ), - "sample_bidice": py_sod_metrics.DICEHandler( - with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True - ), - "sample_bispec": py_sod_metrics.SpecificityHandler( - with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True - ), - "sample_biber": py_sod_metrics.BERHandler( - with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True - ), + "sample_bifm": py_sod_metrics.FmeasureHandler(**sample_bin, beta=0.3), + "sample_bif1": py_sod_metrics.FmeasureHandler(**sample_bin, beta=1), + "sample_bipre": py_sod_metrics.PrecisionHandler(**sample_bin), + "sample_birec": py_sod_metrics.RecallHandler(**sample_bin), + "sample_biiou": py_sod_metrics.IOUHandler(**sample_bin), + "sample_bidice": py_sod_metrics.DICEHandler(**sample_bin), + "sample_bispec": py_sod_metrics.SpecificityHandler(**sample_bin), + "sample_biber": py_sod_metrics.BERHandler(**sample_bin), + "sample_bioa": 
py_sod_metrics.OverallAccuracyHandler(**sample_bin), + "sample_bikappa": py_sod_metrics.KappaHandler(**sample_bin), # 二值化数据指标的特殊情况二:汇总所有样本的tp、fp、tn、fn后整体计算指标 - "overall_bifm": py_sod_metrics.FmeasureHandler( - with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True, beta=0.3 - ), - "overall_bif1": py_sod_metrics.FmeasureHandler( - with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True, beta=1 - ), - "overall_bipre": py_sod_metrics.PrecisionHandler( - with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True - ), - "overall_birec": py_sod_metrics.RecallHandler( - with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True - ), - "overall_biiou": py_sod_metrics.IOUHandler( - with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True - ), - "overall_bidice": py_sod_metrics.DICEHandler( - with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True - ), - "overall_bispec": py_sod_metrics.SpecificityHandler( - with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True - ), - "overall_biber": py_sod_metrics.BERHandler( - with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True - ), + "overall_bifm": py_sod_metrics.FmeasureHandler(**overall_bin, beta=0.3), + "overall_bif1": py_sod_metrics.FmeasureHandler(**overall_bin, beta=1), + "overall_bipre": py_sod_metrics.PrecisionHandler(**overall_bin), + "overall_birec": py_sod_metrics.RecallHandler(**overall_bin), + "overall_biiou": py_sod_metrics.IOUHandler(**overall_bin), + "overall_bidice": py_sod_metrics.DICEHandler(**overall_bin), + "overall_bispec": py_sod_metrics.SpecificityHandler(**overall_bin), + "overall_biber": py_sod_metrics.BERHandler(**overall_bin), + "overall_bioa": py_sod_metrics.OverallAccuracyHandler(**overall_bin), + "overall_bikappa": py_sod_metrics.KappaHandler(**overall_bin), } class MetricRecorderV2: suppoted_metrics = ["mae", "em", "sm", "wfm"] + sorted( - 
[k for k in BINARY_CLASSIFICATION_METRIC_MAPPING.keys() if not k.startswith(('sample_', 'overall_'))] + [ + k + for k in BINARY_CLASSIFICATION_METRIC_MAPPING.keys() + if not k.startswith(("sample_", "overall_")) + ] ) def __init__(self, metric_names=("sm", "wfm", "mae", "fmeasure", "em")): @@ -248,7 +229,11 @@ def show(self, num_bits: int = 3, return_ndarray: bool = False) -> dict: class BinaryMetricRecorder: suppoted_metrics = ["mae", "sm", "wfm"] + sorted( - [k for k in BINARY_CLASSIFICATION_METRIC_MAPPING.keys() if k.startswith(('sample_', 'overall_'))] + [ + k + for k in BINARY_CLASSIFICATION_METRIC_MAPPING.keys() + if k.startswith(("sample_", "overall_")) + ] ) def __init__(self, metric_names=("bif1", "biprecision", "birecall", "biiou")): diff --git a/examples/test_metrics.py b/examples/test_metrics.py index e698af8..4fab8c0 100644 --- a/examples/test_metrics.py +++ b/examples/test_metrics.py @@ -6,6 +6,8 @@ import os import sys import unittest +from pprint import pprint + import cv2 sys.path.append("..") @@ -17,39 +19,44 @@ EM = py_sod_metrics.Emeasure() MAE = py_sod_metrics.MAE() -sample_binary = dict(with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True) -overall_binary = dict( - with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=False -) +sample_gray = dict(with_adaptive=True, with_dynamic=True) +sample_bin = dict(with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True) +overall_bin = dict(with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=False) FMv2 = py_sod_metrics.FmeasureV2( metric_handlers={ # 灰度数据指标 - "fm": py_sod_metrics.FmeasureHandler(with_adaptive=True, with_dynamic=True, beta=0.3), - "f1": py_sod_metrics.FmeasureHandler(with_adaptive=True, with_dynamic=True, beta=0.1), - "pre": py_sod_metrics.PrecisionHandler(with_adaptive=True, with_dynamic=True), - "rec": py_sod_metrics.RecallHandler(with_adaptive=True, with_dynamic=True), - "iou": 
py_sod_metrics.IOUHandler(with_adaptive=True, with_dynamic=True), - "dice": py_sod_metrics.DICEHandler(with_adaptive=True, with_dynamic=True), - "spec": py_sod_metrics.SpecificityHandler(with_adaptive=True, with_dynamic=True), - "ber": py_sod_metrics.BERHandler(with_adaptive=True, with_dynamic=True), + "fm": py_sod_metrics.FmeasureHandler(**sample_gray, beta=0.3), + "f1": py_sod_metrics.FmeasureHandler(**sample_gray, beta=0.1), + "pre": py_sod_metrics.PrecisionHandler(**sample_gray), + "rec": py_sod_metrics.RecallHandler(**sample_gray), + "iou": py_sod_metrics.IOUHandler(**sample_gray), + "dice": py_sod_metrics.DICEHandler(**sample_gray), + "spec": py_sod_metrics.SpecificityHandler(**sample_gray), + "ber": py_sod_metrics.BERHandler(**sample_gray), + "oa": py_sod_metrics.OverallAccuracyHandler(**sample_gray), + "kappa": py_sod_metrics.KappaHandler(**sample_gray), # 二值化数据指标的特殊情况一:各个样本独立计算指标后取平均 - "sample_bifm": py_sod_metrics.FmeasureHandler(**sample_binary, beta=0.3), - "sample_bif1": py_sod_metrics.FmeasureHandler(**sample_binary, beta=1), - "sample_bipre": py_sod_metrics.PrecisionHandler(**sample_binary), - "sample_birec": py_sod_metrics.RecallHandler(**sample_binary), - "sample_biiou": py_sod_metrics.IOUHandler(**sample_binary), - "sample_bidice": py_sod_metrics.DICEHandler(**sample_binary), - "sample_bispec": py_sod_metrics.SpecificityHandler(**sample_binary), - "sample_biber": py_sod_metrics.BERHandler(**sample_binary), + "sample_bifm": py_sod_metrics.FmeasureHandler(**sample_bin, beta=0.3), + "sample_bif1": py_sod_metrics.FmeasureHandler(**sample_bin, beta=1), + "sample_bipre": py_sod_metrics.PrecisionHandler(**sample_bin), + "sample_birec": py_sod_metrics.RecallHandler(**sample_bin), + "sample_biiou": py_sod_metrics.IOUHandler(**sample_bin), + "sample_bidice": py_sod_metrics.DICEHandler(**sample_bin), + "sample_bispec": py_sod_metrics.SpecificityHandler(**sample_bin), + "sample_biber": py_sod_metrics.BERHandler(**sample_bin), + "sample_bioa": 
py_sod_metrics.OverallAccuracyHandler(**sample_bin), + "sample_bikappa": py_sod_metrics.KappaHandler(**sample_bin), # 二值化数据指标的特殊情况二:汇总所有样本的tp、fp、tn、fn后整体计算指标 - "overall_bifm": py_sod_metrics.FmeasureHandler(**overall_binary, beta=0.3), - "overall_bif1": py_sod_metrics.FmeasureHandler(**overall_binary, beta=1), - "overall_bipre": py_sod_metrics.PrecisionHandler(**overall_binary), - "overall_birec": py_sod_metrics.RecallHandler(**overall_binary), - "overall_biiou": py_sod_metrics.IOUHandler(**overall_binary), - "overall_bidice": py_sod_metrics.DICEHandler(**overall_binary), - "overall_bispec": py_sod_metrics.SpecificityHandler(**overall_binary), - "overall_biber": py_sod_metrics.BERHandler(**overall_binary), + "overall_bifm": py_sod_metrics.FmeasureHandler(**overall_bin, beta=0.3), + "overall_bif1": py_sod_metrics.FmeasureHandler(**overall_bin, beta=1), + "overall_bipre": py_sod_metrics.PrecisionHandler(**overall_bin), + "overall_birec": py_sod_metrics.RecallHandler(**overall_bin), + "overall_biiou": py_sod_metrics.IOUHandler(**overall_bin), + "overall_bidice": py_sod_metrics.DICEHandler(**overall_bin), + "overall_bispec": py_sod_metrics.SpecificityHandler(**overall_bin), + "overall_biber": py_sod_metrics.BERHandler(**overall_bin), + "overall_bioa": py_sod_metrics.OverallAccuracyHandler(**overall_bin), + "overall_bikappa": py_sod_metrics.KappaHandler(**overall_bin), } ) @@ -77,6 +84,80 @@ mae = MAE.get_results()["mae"] fmv2 = FMv2.get_results() +curr_results = { + "MAE": mae, + "Smeasure": sm, + "wFmeasure": wfm, + # E-measure for sod + "adpEm": em["adp"], + "meanEm": em["curve"].mean(), + "maxEm": em["curve"].max(), + # F-measure for sod + "adpFm": fm["adp"], + "meanFm": fm["curve"].mean(), + "maxFm": fm["curve"].max(), + # general F-measure + "adpfm": fmv2["fm"]["adaptive"], + "meanfm": fmv2["fm"]["dynamic"].mean(), + "maxfm": fmv2["fm"]["dynamic"].max(), + "sample_bifm": fmv2["sample_bifm"]["binary"], + "overall_bifm": fmv2["overall_bifm"]["binary"], + # precision 
+ "adppre": fmv2["pre"]["adaptive"], + "meanpre": fmv2["pre"]["dynamic"].mean(), + "maxpre": fmv2["pre"]["dynamic"].max(), + "sample_bipre": fmv2["sample_bipre"]["binary"], + "overall_bipre": fmv2["overall_bipre"]["binary"], + # recall + "adprec": fmv2["rec"]["adaptive"], + "meanrec": fmv2["rec"]["dynamic"].mean(), + "maxrec": fmv2["rec"]["dynamic"].max(), + "sample_birec": fmv2["sample_birec"]["binary"], + "overall_birec": fmv2["overall_birec"]["binary"], + # dice + "adpdice": fmv2["dice"]["adaptive"], + "meandice": fmv2["dice"]["dynamic"].mean(), + "maxdice": fmv2["dice"]["dynamic"].max(), + "sample_bidice": fmv2["sample_bidice"]["binary"], + "overall_bidice": fmv2["overall_bidice"]["binary"], + # iou + "adpiou": fmv2["iou"]["adaptive"], + "meaniou": fmv2["iou"]["dynamic"].mean(), + "maxiou": fmv2["iou"]["dynamic"].max(), + "sample_biiou": fmv2["sample_biiou"]["binary"], + "overall_biiou": fmv2["overall_biiou"]["binary"], + # f1 score + "adpf1": fmv2["f1"]["adaptive"], + "meanf1": fmv2["f1"]["dynamic"].mean(), + "maxf1": fmv2["f1"]["dynamic"].max(), + "sample_bif1": fmv2["sample_bif1"]["binary"], + "overall_bif1": fmv2["overall_bif1"]["binary"], + # specificity + "adpspec": fmv2["spec"]["adaptive"], + "meanspec": fmv2["spec"]["dynamic"].mean(), + "maxspec": fmv2["spec"]["dynamic"].max(), + "sample_bispec": fmv2["sample_bispec"]["binary"], + "overall_bispec": fmv2["overall_bispec"]["binary"], + # ber + "adpber": fmv2["ber"]["adaptive"], + "meanber": fmv2["ber"]["dynamic"].mean(), + "maxber": fmv2["ber"]["dynamic"].max(), + "sample_biber": fmv2["sample_biber"]["binary"], + "overall_biber": fmv2["overall_biber"]["binary"], + # overall accuracy + "adpoa": fmv2["oa"]["adaptive"], + "meanoa": fmv2["oa"]["dynamic"].mean(), + "maxoa": fmv2["oa"]["dynamic"].max(), + "sample_bioa": fmv2["sample_bioa"]["binary"], + "overall_bioa": fmv2["overall_bioa"]["binary"], + # kappa + "adpkappa": fmv2["kappa"]["adaptive"], + "meankappa": fmv2["kappa"]["dynamic"].mean(), + "maxkappa": 
fmv2["kappa"]["dynamic"].max(), + "sample_bikappa": fmv2["sample_bikappa"]["binary"], + "overall_bikappa": fmv2["overall_bikappa"]["binary"], +} + default_results = { "v1_2_3": { "Smeasure": 0.9029763868504661, @@ -110,6 +191,8 @@ "adpf1": 0.5825795996723205, "adpfm": 0.5816750824038355, "adpiou": 0.5141023436626048, + "adpkappa": 0.6568702977598276, + "adpoa": 0.9391947016812359, "adppre": 0.583200007681871, "adprec": 0.5777548546727481, "adpspec": 0.9512882075256152, @@ -120,6 +203,8 @@ "maxf1": 0.6031100666167747, "maxfm": 0.5886784581120638, "maxiou": 0.5201569938888494, + "maxkappa": 0.6759493461328753, + "maxoa": 0.9654783867686053, "maxpre": 0.6396783912301717, "maxrec": 0.6666666666666666, "maxspec": 0.9965927890353435, @@ -130,6 +215,8 @@ "meanf1": 0.5821115124232528, "meanfm": 0.577051059518767, "meaniou": 0.49816648786971, + "meankappa": 0.6443053495487194, + "meanoa": 0.9596413706286032, "meanpre": 0.5857695537152126, "meanrec": 0.5599653001125341, "meanspec": 0.9742186408675534, @@ -138,6 +225,8 @@ "overall_bif1": 0.8510675335753017, "overall_bifm": 0.8525259082995088, "overall_biiou": 0.740746352327995, + "overall_bikappa": 0.7400114676102276, + "overall_bioa": 0.965778, "overall_bipre": 0.8537799277020065, "overall_birec": 0.8483723190115916, "overall_bispec": 0.9810724910256526, @@ -146,6 +235,8 @@ "sample_bif1": 0.5738376903441331, "sample_bifm": 0.5829998670906196, "sample_biiou": 0.5039622042094377, + "sample_bikappa": 0.6510635726572914, + "sample_bioa": 0.964811758770181, "sample_bipre": 0.5916996553523113, "sample_birec": 0.5592859147614985, "sample_bispec": 0.9799569090918337, @@ -153,84 +244,107 @@ }, } + class CheckMetricTestCase(unittest.TestCase): @classmethod def setUpClass(cls): - cls.results = default_results["v1_4_0"] + print("Current results:") + pprint(curr_results) + cls.default_results = default_results["v1_4_0"] def test_sm(self): - self.assertEqual(sm, self.results["Smeasure"]) + self.assertEqual(curr_results["Smeasure"], 
self.default_results["Smeasure"]) def test_wfm(self): - self.assertEqual(wfm, self.results["wFmeasure"]) + self.assertEqual(curr_results["wFmeasure"], self.default_results["wFmeasure"]) def test_mae(self): - self.assertEqual(mae, self.results["MAE"]) + self.assertEqual(curr_results["MAE"], self.default_results["MAE"]) def test_fm(self): - self.assertEqual(fm["adp"], self.results["adpFm"]) - self.assertEqual(fm["curve"].mean(), self.results["meanFm"]) - self.assertEqual(fm["curve"].max(), self.results["maxFm"]) + self.assertEqual(curr_results["adpFm"], self.default_results["adpFm"]) + self.assertEqual(curr_results["meanFm"], self.default_results["meanFm"]) + self.assertEqual(curr_results["maxFm"], self.default_results["maxFm"]) + + self.assertEqual(curr_results["adpfm"], self.default_results["adpfm"]) + self.assertEqual(curr_results["meanfm"], self.default_results["meanfm"]) + self.assertEqual(curr_results["maxfm"], self.default_results["maxfm"]) - def test_em(self): - self.assertEqual(em["adp"], self.results["adpEm"]) - self.assertEqual(em["curve"].mean(), self.results["meanEm"]) - self.assertEqual(em["curve"].max(), self.results["maxEm"]) - - def test_fmv2(self): - self.assertEqual(fmv2["fm"]["adaptive"], self.results["adpfm"]) - self.assertEqual(fmv2["fm"]["dynamic"].mean(), self.results["meanfm"]) - self.assertEqual(fmv2["fm"]["dynamic"].max(), self.results["maxfm"]) # 对齐v1版本 - self.assertEqual(fmv2["fm"]["adaptive"], self.results["adpFm"]) - self.assertEqual(fmv2["fm"]["dynamic"].mean(), self.results["meanFm"]) - self.assertEqual(fmv2["fm"]["dynamic"].max(), self.results["maxFm"]) - - self.assertEqual(fmv2["f1"]["adaptive"], self.results["adpf1"]) - self.assertEqual(fmv2["f1"]["dynamic"].mean(), self.results["meanf1"]) - self.assertEqual(fmv2["f1"]["dynamic"].max(), self.results["maxf1"]) - - self.assertEqual(fmv2["pre"]["adaptive"], self.results["adppre"]) - self.assertEqual(fmv2["pre"]["dynamic"].mean(), self.results["meanpre"]) - 
self.assertEqual(fmv2["pre"]["dynamic"].max(), self.results["maxpre"]) - - self.assertEqual(fmv2["rec"]["adaptive"], self.results["adprec"]) - self.assertEqual(fmv2["rec"]["dynamic"].mean(), self.results["meanrec"]) - self.assertEqual(fmv2["rec"]["dynamic"].max(), self.results["maxrec"]) - - self.assertEqual(fmv2["spec"]["adaptive"], self.results["adpspec"]) - self.assertEqual(fmv2["spec"]["dynamic"].mean(), self.results["meanspec"]) - self.assertEqual(fmv2["spec"]["dynamic"].max(), self.results["maxspec"]) - - self.assertEqual(fmv2["iou"]["adaptive"], self.results["adpiou"]) - self.assertEqual(fmv2["iou"]["dynamic"].mean(), self.results["meaniou"]) - self.assertEqual(fmv2["iou"]["dynamic"].max(), self.results["maxiou"]) - - self.assertEqual(fmv2["dice"]["adaptive"], self.results["adpdice"]) - self.assertEqual(fmv2["dice"]["dynamic"].mean(), self.results["meandice"]) - self.assertEqual(fmv2["dice"]["dynamic"].max(), self.results["maxdice"]) - - self.assertEqual(fmv2["ber"]["adaptive"], self.results["adpber"]) - self.assertEqual(fmv2["ber"]["dynamic"].mean(), self.results["meanber"]) - self.assertEqual(fmv2["ber"]["dynamic"].max(), self.results["maxber"]) - - self.assertEqual(fmv2["sample_bifm"]["binary"], self.results["sample_bifm"]) - self.assertEqual(fmv2["sample_bif1"]["binary"], self.results["sample_bif1"]) - self.assertEqual(fmv2["sample_bipre"]["binary"], self.results["sample_bipre"]) - self.assertEqual(fmv2["sample_birec"]["binary"], self.results["sample_birec"]) - self.assertEqual(fmv2["sample_biiou"]["binary"], self.results["sample_biiou"]) - self.assertEqual(fmv2["sample_bidice"]["binary"], self.results["sample_bidice"]) - self.assertEqual(fmv2["sample_bispec"]["binary"], self.results["sample_bispec"]) - self.assertEqual(fmv2["sample_biber"]["binary"], self.results["sample_biber"]) - - self.assertEqual(fmv2["overall_bifm"]["binary"], self.results["overall_bifm"]) - self.assertEqual(fmv2["overall_bif1"]["binary"], self.results["overall_bif1"]) - 
self.assertEqual(fmv2["overall_bipre"]["binary"], self.results["overall_bipre"]) - self.assertEqual(fmv2["overall_birec"]["binary"], self.results["overall_birec"]) - self.assertEqual(fmv2["overall_biiou"]["binary"], self.results["overall_biiou"]) - self.assertEqual(fmv2["overall_bidice"]["binary"], self.results["overall_bidice"]) - self.assertEqual(fmv2["overall_bispec"]["binary"], self.results["overall_bispec"]) - self.assertEqual(fmv2["overall_biber"]["binary"], self.results["overall_biber"]) + self.assertEqual(curr_results["adpFm"], self.default_results["adpfm"]) + self.assertEqual(curr_results["meanFm"], self.default_results["meanfm"]) + self.assertEqual(curr_results["maxFm"], self.default_results["maxfm"]) + + self.assertEqual(curr_results["sample_bifm"], self.default_results["sample_bifm"]) + self.assertEqual(curr_results["overall_bifm"], self.default_results["overall_bifm"]) + + def test_em(self): + self.assertEqual(curr_results["adpEm"], self.default_results["adpEm"]) + self.assertEqual(curr_results["meanEm"], self.default_results["meanEm"]) + self.assertEqual(curr_results["maxEm"], self.default_results["maxEm"]) + + def test_f1(self): + self.assertEqual(curr_results["adpf1"], self.default_results["adpf1"]) + self.assertEqual(curr_results["meanf1"], self.default_results["meanf1"]) + self.assertEqual(curr_results["maxf1"], self.default_results["maxf1"]) + self.assertEqual(curr_results["sample_bif1"], self.default_results["sample_bif1"]) + self.assertEqual(curr_results["overall_bif1"], self.default_results["overall_bif1"]) + + def test_pre(self): + self.assertEqual(curr_results["adppre"], self.default_results["adppre"]) + self.assertEqual(curr_results["meanpre"], self.default_results["meanpre"]) + self.assertEqual(curr_results["maxpre"], self.default_results["maxpre"]) + self.assertEqual(curr_results["sample_bipre"], self.default_results["sample_bipre"]) + self.assertEqual(curr_results["overall_bipre"], self.default_results["overall_bipre"]) + + def 
test_rec(self): + self.assertEqual(curr_results["adprec"], self.default_results["adprec"]) + self.assertEqual(curr_results["meanrec"], self.default_results["meanrec"]) + self.assertEqual(curr_results["maxrec"], self.default_results["maxrec"]) + self.assertEqual(curr_results["sample_birec"], self.default_results["sample_birec"]) + self.assertEqual(curr_results["overall_birec"], self.default_results["overall_birec"]) + + def test_iou(self): + self.assertEqual(curr_results["adpiou"], self.default_results["adpiou"]) + self.assertEqual(curr_results["meaniou"], self.default_results["meaniou"]) + self.assertEqual(curr_results["maxiou"], self.default_results["maxiou"]) + self.assertEqual(curr_results["sample_biiou"], self.default_results["sample_biiou"]) + self.assertEqual(curr_results["overall_biiou"], self.default_results["overall_biiou"]) + + def test_dice(self): + self.assertEqual(curr_results["adpdice"], self.default_results["adpdice"]) + self.assertEqual(curr_results["meandice"], self.default_results["meandice"]) + self.assertEqual(curr_results["maxdice"], self.default_results["maxdice"]) + self.assertEqual(curr_results["sample_bidice"], self.default_results["sample_bidice"]) + self.assertEqual(curr_results["overall_bidice"], self.default_results["overall_bidice"]) + + def test_spec(self): + self.assertEqual(curr_results["adpspec"], self.default_results["adpspec"]) + self.assertEqual(curr_results["meanspec"], self.default_results["meanspec"]) + self.assertEqual(curr_results["maxspec"], self.default_results["maxspec"]) + self.assertEqual(curr_results["sample_bispec"], self.default_results["sample_bispec"]) + self.assertEqual(curr_results["overall_bispec"], self.default_results["overall_bispec"]) + + def test_ber(self): + self.assertEqual(curr_results["adpber"], self.default_results["adpber"]) + self.assertEqual(curr_results["meanber"], self.default_results["meanber"]) + self.assertEqual(curr_results["maxber"], self.default_results["maxber"]) + 
self.assertEqual(curr_results["sample_biber"], self.default_results["sample_biber"]) + self.assertEqual(curr_results["overall_biber"], self.default_results["overall_biber"]) + + def test_oa(self): + self.assertEqual(curr_results["adpoa"], self.default_results["adpoa"]) + self.assertEqual(curr_results["meanoa"], self.default_results["meanoa"]) + self.assertEqual(curr_results["maxoa"], self.default_results["maxoa"]) + self.assertEqual(curr_results["sample_bioa"], self.default_results["sample_bioa"]) + self.assertEqual(curr_results["overall_bioa"], self.default_results["overall_bioa"]) + + def test_kappa(self): + self.assertEqual(curr_results["adpkappa"], self.default_results["adpkappa"]) + self.assertEqual(curr_results["meankappa"], self.default_results["meankappa"]) + self.assertEqual(curr_results["maxkappa"], self.default_results["maxkappa"]) + self.assertEqual(curr_results["sample_bikappa"], self.default_results["sample_bikappa"]) + self.assertEqual(curr_results["overall_bikappa"], self.default_results["overall_bikappa"]) if __name__ == "__main__": diff --git a/py_sod_metrics/__init__.py b/py_sod_metrics/__init__.py index 1e336d7..307409d 100755 --- a/py_sod_metrics/__init__.py +++ b/py_sod_metrics/__init__.py @@ -1,13 +1,15 @@ # -*- coding: utf-8 -*- from py_sod_metrics.fmeasurev2 import ( + BERHandler, DICEHandler, - FmeasureV2, FmeasureHandler, + FmeasureV2, IOUHandler, + KappaHandler, + OverallAccuracyHandler, PrecisionHandler, RecallHandler, SpecificityHandler, - BERHandler, ) from py_sod_metrics.sod_metrics import ( MAE, diff --git a/py_sod_metrics/fmeasurev2.py b/py_sod_metrics/fmeasurev2.py index c87a44f..37079f9 100644 --- a/py_sod_metrics/fmeasurev2.py +++ b/py_sod_metrics/fmeasurev2.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- +import abc + import numpy as np -from .utils import get_adaptive_threshold, prepare_data, TYPE -import abc +from .utils import TYPE, get_adaptive_threshold, prepare_data class _BaseHandler: @@ -37,6 +38,12 @@ def __init__( def 
__call__(self, *args, **kwds): pass + @staticmethod + def divide(numerator, denominator): + denominator = np.array(denominator, dtype=TYPE) + np.divide(numerator, denominator, out=denominator, where=denominator != 0) + return denominator + class IOUHandler(_BaseHandler): """Intersection over Union @@ -46,10 +53,7 @@ class IOUHandler(_BaseHandler): def __call__(self, tp, fp, tn, fn): # ious = np.where(Ps + FNs == 0, 0, TPs / (Ps + FNs)) - numerator = tp - denominator = np.array(tp + fp + fn, dtype=TYPE) - np.divide(numerator, denominator, out=denominator, where=denominator != 0) - return denominator + return self.divide(tp, tp + fp + fn) class SpecificityHandler(_BaseHandler): @@ -60,10 +64,7 @@ class SpecificityHandler(_BaseHandler): def __call__(self, tp, fp, tn, fn): # specificities = np.where(TNs + FPs == 0, 0, TNs / (TNs + FPs)) - numerator = tn - denominator = np.array(tn + fp, dtype=TYPE) - np.divide(numerator, denominator, out=denominator, where=denominator != 0) - return denominator + return self.divide(tn, tn + fp) class DICEHandler(_BaseHandler): @@ -74,10 +75,62 @@ class DICEHandler(_BaseHandler): def __call__(self, tp, fp, tn, fn): # dices = np.where(TPs + FPs == 0, 0, 2 * TPs / (T + Ps)) - numerator = 2 * tp - denominator = np.array(tp + fn + tp + fp, dtype=TYPE) - np.divide(numerator, denominator, out=denominator, where=denominator != 0) - return denominator + return self.divide(2 * tp, tp + fn + tp + fp) + + +class OverallAccuracyHandler(_BaseHandler): + """OverallAccuracy + + oa = overall_accuracy = (tp + tn) / (tp + fp + tn + fn) + """ + + def __call__(self, tp, fp, tn, fn): + # oas = np.where(TPs + FPs + TNs + FNs == 0, 0, (TPs + TNs) / (TPs + FPs + TNs + FNs)) + return self.divide(tp + tn, tp + fp + tn + fn) + + +class KappaHandler(_BaseHandler): + """KappaAccuracy + + kappa = (oa - p_) / (1 - p_) + p_ = [(tp + fp)(tp + fn) + (tn + fn)(tn + tp)] / (tp + fp + tn + fn)^2 + """ + + def __init__( + self, + with_dynamic: bool, + with_adaptive: bool, + *, + with_binary: 
bool = False, + sample_based: bool = True, + beta: float = 0.3, + ): + """ + Args: + with_dynamic (bool, optional): Record dynamic results for max/avg/curve versions. + with_adaptive (bool, optional): Record adaptive results for adp version. + with_binary (bool, optional): Record binary results for binary version. + sample_based (bool, optional): Whether to average the metric of each sample or calculate + the metric of the dataset. Defaults to True. + beta (float, optional): β^2 in F-measure; stored as self.beta but not used by the kappa computation. Defaults to 0.3. + """ + super().__init__( + with_dynamic=with_dynamic, + with_adaptive=with_adaptive, + with_binary=with_binary, + sample_based=sample_based, + ) + + self.beta = beta + self.oa = OverallAccuracyHandler(False, False) + + def __call__(self, tp, fp, tn, fn): + oa = self.oa(tp, fp, tn, fn) + hpy_p = self.divide( + (tp + fp) * (tp + fn) + (tn + fn) * (tn + tp), + (tp + fp + tn + fn) ** 2, + ) + return self.divide(oa - hpy_p, 1 - hpy_p) class PrecisionHandler(_BaseHandler): @@ -88,10 +141,7 @@ class PrecisionHandler(_BaseHandler): def __call__(self, tp, fp, tn, fn): # precisions = np.where(Ps == 0, 0, TPs / Ps) - numerator = tp - denominator = np.array(tp + fp, dtype=TYPE) - np.divide(numerator, denominator, out=denominator, where=denominator != 0) - return denominator + return self.divide(tp, tp + fp) class RecallHandler(_BaseHandler): @@ -102,10 +152,7 @@ class RecallHandler(_BaseHandler): def __call__(self, tp, fp, tn, fn): # recalls = np.where(TPs == 0, 0, TPs / T) - numerator = tp - denominator = np.array(tp + fn, dtype=TYPE) - np.divide(numerator, denominator, out=denominator, where=denominator != 0) - return denominator + return self.divide(tp, tp + fn) class BERHandler(_BaseHandler): @@ -146,7 +193,12 @@ def __init__( the metric of the dataset. Defaults to True. beta (bool, optional): β^2 in F-measure. Defaults to 0.3. 
""" - super().__init__(with_dynamic=with_dynamic, with_adaptive=with_adaptive, with_binary=with_binary, sample_based=sample_based) + super().__init__( + with_dynamic=with_dynamic, + with_adaptive=with_adaptive, + with_binary=with_binary, + sample_based=sample_based, + ) self.beta = beta self.precision = PrecisionHandler(False, False) @@ -159,10 +211,7 @@ def __call__(self, tp, fp, tn, fn): p = self.precision(tp, fp, tn, fn) r = self.recall(tp, fp, tn, fn) - numerator = (self.beta + 1) * p * r - denominator = np.array(self.beta * p + r, dtype=TYPE) - np.divide(numerator, denominator, out=denominator, where=denominator != 0) - return denominator + return self.divide((self.beta + 1) * p * r, self.beta * p + r) class FmeasureV2: