Skip to content

Commit

Permalink
✨ feat:
Browse files Browse the repository at this point in the history
1. Update the metrics for binary image.
2. Use unittest library to check the resutls.
3. Update texamples.
  • Loading branch information
lartpang committed Jan 11, 2023
1 parent f3756f2 commit b228260
Show file tree
Hide file tree
Showing 3 changed files with 378 additions and 271 deletions.
193 changes: 141 additions & 52 deletions examples/metric_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,13 @@ def _to_list_or_scalar(item):
}


class CalTotalMetricV1:
class MetricRecorderV1:
def __init__(self):
"""
用于统计各种指标的类
https://github.com/lartpang/Py-SOD-VOS-EvalToolkit/blob/81ce89da6813fdd3e22e3f20e3a09fe1e4a1a87c/utils/recorders/metric_recorder.py
主要应用于旧版本实现中的五个指标,即mae/fm/sm/em/wfm。推荐使用V2版本。
"""
self.mae = INDIVADUAL_METRIC_MAPPING["mae"]()
self.fm = INDIVADUAL_METRIC_MAPPING["fm"]()
Expand Down Expand Up @@ -103,46 +105,76 @@ def get_results(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:


BINARY_CLASSIFICATION_METRIC_MAPPING = {
"fmeasure": {
"handler": py_sod_metrics.FmeasureHandler,
"kwargs": dict(with_dynamic=True, with_adaptive=True, with_binary=True, beta=0.3),
},
"precision": {
"handler": py_sod_metrics.PrecisionHandler,
"kwargs": dict(with_dynamic=True, with_adaptive=False, with_binary=False),
},
"recall": {
"handler": py_sod_metrics.RecallHandler,
"kwargs": dict(with_dynamic=True, with_adaptive=False, with_binary=False),
},
"iou": {
"handler": py_sod_metrics.IOUHandler,
"kwargs": dict(with_dynamic=True, with_adaptive=True, with_binary=True),
},
"dice": {
"handler": py_sod_metrics.DICEHandler,
"kwargs": dict(with_dynamic=True, with_adaptive=True, with_binary=True),
},
"specificity": {
"handler": py_sod_metrics.SpecificityHandler,
"kwargs": dict(with_dynamic=True, with_adaptive=True, with_binary=True),
},
"ber": {
"handler": py_sod_metrics.BERHandler,
"kwargs": dict(with_dynamic=True, with_adaptive=True, with_binary=True),
}
# 灰度数据指标
"fm": py_sod_metrics.FmeasureHandler(with_adaptive=True, with_dynamic=True, beta=0.3),
"f1": py_sod_metrics.FmeasureHandler(with_adaptive=True, with_dynamic=True, beta=0.1),
"pre": py_sod_metrics.PrecisionHandler(with_adaptive=True, with_dynamic=True),
"rec": py_sod_metrics.RecallHandler(with_adaptive=True, with_dynamic=True),
"iou": py_sod_metrics.IOUHandler(with_adaptive=True, with_dynamic=True),
"dice": py_sod_metrics.DICEHandler(with_adaptive=True, with_dynamic=True),
"spec": py_sod_metrics.SpecificityHandler(with_adaptive=True, with_dynamic=True),
"ber": py_sod_metrics.BERHandler(with_adaptive=True, with_dynamic=True),
# 二值化数据指标的特殊情况一:各个样本独立计算指标后取平均
"sample_bifm": py_sod_metrics.FmeasureHandler(
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True, beta=0.3
),
"sample_bif1": py_sod_metrics.FmeasureHandler(
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True, beta=1
),
"sample_bipre": py_sod_metrics.PrecisionHandler(
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True
),
"sample_birec": py_sod_metrics.RecallHandler(
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True
),
"sample_biiou": py_sod_metrics.IOUHandler(
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True
),
"sample_bidice": py_sod_metrics.DICEHandler(
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True
),
"sample_bispec": py_sod_metrics.SpecificityHandler(
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True
),
"sample_biber": py_sod_metrics.BERHandler(
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True
),
# 二值化数据指标的特殊情况二:汇总所有样本的tp、fp、tn、fn后整体计算指标
"overall_bifm": py_sod_metrics.FmeasureHandler(
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True, beta=0.3
),
"overall_bif1": py_sod_metrics.FmeasureHandler(
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True, beta=1
),
"overall_bipre": py_sod_metrics.PrecisionHandler(
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True
),
"overall_birec": py_sod_metrics.RecallHandler(
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True
),
"overall_biiou": py_sod_metrics.IOUHandler(
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True
),
"overall_bidice": py_sod_metrics.DICEHandler(
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True
),
"overall_bispec": py_sod_metrics.SpecificityHandler(
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True
),
"overall_biber": py_sod_metrics.BERHandler(
with_adaptive=False, with_dynamic=False, with_binary=True, sample_based=True
),
}


class CalTotalMetricV2:
# 'fm' is replaced by 'fmeasure' in BINARY_CLASSIFICATION_METRIC_MAPPING
class MetricRecorderV2:
suppoted_metrics = ["mae", "em", "sm", "wfm"] + sorted(
BINARY_CLASSIFICATION_METRIC_MAPPING.keys()
[k for k in BINARY_CLASSIFICATION_METRIC_MAPPING.keys() if not k.startswith(('sample_', 'overall_'))]
)

def __init__(self, metric_names=None):
def __init__(self, metric_names=("sm", "wfm", "mae", "fmeasure", "em")):
"""
用于统计各种指标的类
用于统计各种指标的类,支持更多的指标,更好的兼容性。
"""
if not metric_names:
metric_names = self.suppoted_metrics
Expand All @@ -161,24 +193,18 @@ def __init__(self, metric_names=None):
has_existed = True
metric_handler = BINARY_CLASSIFICATION_METRIC_MAPPING[metric_name]
self.metric_objs["fmeasurev2"].add_handler(
# instantiate inside the class instead of outside the class
metric_handler["handler"](**metric_handler["kwargs"])
handler_name=metric_name,
metric_handler=metric_handler["handler"](**metric_handler["kwargs"]),
)

def update(self, pre: np.ndarray, gt: np.ndarray):
def step(self, pre: np.ndarray, gt: np.ndarray):
assert pre.shape == gt.shape, (pre.shape, gt.shape)
assert pre.dtype == gt.dtype == np.uint8, (pre.dtype, gt.dtype)

for m_obj in self.metric_objs.values():
m_obj.step(pre, gt)

def show(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
"""
返回指标计算结果:
- 曲线数据(sequential)
- 数值指标(numerical)
"""
def get_all_results(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
sequential_results = {}
numerical_results = {}
for m_name, m_obj in self.metric_objs.items():
Expand All @@ -187,15 +213,12 @@ def show(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
for _name, results in info.items():
dynamic_results = results.get("dynamic")
adaptive_results = results.get("adaptive")
binary_results = results.get('binary')
if dynamic_results is not None:
sequential_results[_name] = np.flip(dynamic_results)
numerical_results[f"max{_name}"] = dynamic_results.max()
numerical_results[f"avg{_name}"] = dynamic_results.mean()
if adaptive_results is not None:
numerical_results[f"adp{_name}"] = adaptive_results
if binary_results is not None:
numerical_results[f"bi{_name}"] = binary_results
else:
results = info[m_name]
if m_name in ("wfm", "sm", "mae"):
Expand All @@ -204,9 +227,9 @@ def show(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
sequential_results[m_name] = np.flip(results["curve"])
numerical_results.update(
{
"maxe": results["curve"].max(),
"avge": results["curve"].mean(),
"adpe": results["adp"],
"maxem": results["curve"].max(),
"avgem": results["curve"].mean(),
"adpem": results["adp"],
}
)
else:
Expand All @@ -219,15 +242,81 @@ def show(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
numerical_results = ndarray_to_basetype(numerical_results)
return {"sequential": sequential_results, "numerical": numerical_results}

def show(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
return self.get_all_results(num_bits=num_bits, return_ndarray=return_ndarray)["numerical"]


class BinaryMetricRecorder:
suppoted_metrics = ["mae", "sm", "wfm"] + sorted(
[k for k in BINARY_CLASSIFICATION_METRIC_MAPPING.keys() if k.startswith(('sample_', 'overall_'))]
)

def __init__(self, metric_names=("bif1", "biprecision", "birecall", "biiou")):
"""
用于统计各种指标的类,主要适用于对单通道灰度图计算二值图像的指标。
"""
if not metric_names:
metric_names = self.suppoted_metrics
assert all(
[m in self.suppoted_metrics for m in metric_names]
), f"Only support: {self.suppoted_metrics}"

self.metric_objs = {}
has_existed = False
for metric_name in metric_names:
if metric_name in INDIVADUAL_METRIC_MAPPING:
self.metric_objs[metric_name] = INDIVADUAL_METRIC_MAPPING[metric_name]()
else: # metric_name in BINARY_CLASSIFICATION_METRIC_MAPPING
if not has_existed: # only init once
self.metric_objs["fmeasurev2"] = py_sod_metrics.FmeasureV2()
has_existed = True
metric_handler = BINARY_CLASSIFICATION_METRIC_MAPPING[metric_name]
self.metric_objs["fmeasurev2"].add_handler(
handler_name=metric_name,
metric_handler=metric_handler["handler"](**metric_handler["kwargs"]),
)

def step(self, pre: np.ndarray, gt: np.ndarray):
assert pre.shape == gt.shape, (pre.shape, gt.shape)
assert pre.dtype == gt.dtype == np.uint8, (pre.dtype, gt.dtype)

for m_obj in self.metric_objs.values():
m_obj.step(pre, gt)

def get_all_results(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
numerical_results = {}
for m_name, m_obj in self.metric_objs.items():
info = m_obj.get_results()
if m_name == "fmeasurev2":
for _name, results in info.items():
binary_results = results.get("binary")
if binary_results is not None:
numerical_results[_name] = binary_results
else:
results = info[m_name]
if m_name in ("mae", "sm", "wfm"):
numerical_results[m_name] = results
else:
raise NotImplementedError(m_name)

if num_bits is not None and isinstance(num_bits, int):
numerical_results = {k: v.round(num_bits) for k, v in numerical_results.items()}
if not return_ndarray:
numerical_results = ndarray_to_basetype(numerical_results)
return {"numerical": numerical_results}

def show(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
return self.get_all_results(num_bits=num_bits, return_ndarray=return_ndarray)["numerical"]


if __name__ == "__main__":
data_loader = ...
model = ...

cal_total_seg_metrics = CalTotalMetricV1()
cal_total_seg_metrics = MetricRecorderV2()
for batch in data_loader:
seg_preds = model(batch)
for seg_pred in seg_preds:
mask_array = ...
cal_total_seg_metrics.step(seg_pred, mask_array)
fixed_seg_results = cal_total_seg_metrics.get_results()
fixed_seg_results = cal_total_seg_metrics.show()
Loading

0 comments on commit b228260

Please sign in to comment.