Faster computation of the E-measure; for details see: https://www.yuque.com/lart/blog/lwgt38
lartpang committed Nov 30, 2020
1 parent 174b45a commit 49ae833
Showing 1 changed file with 75 additions and 76 deletions.
151 changes: 75 additions & 76 deletions sod_metrics/__init__.py
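The speed-up comes from replacing 256 explicit binarizations with a single 256-bin histogram followed by a reversed cumulative sum, so the per-threshold pixel counts for every threshold are obtained in one pass. Below is a minimal self-contained sketch of that equivalence on toy arrays; the names fg_fg_hist and fg_fg_numel_w_thrs mirror the diff, everything else is illustrative and not part of the library API.

# Standalone sketch of the cumulative-histogram trick (toy data only).
import numpy as np

rng = np.random.default_rng(0)
pred = rng.random((64, 64))        # prediction in [0, 1]
gt = rng.random((64, 64)) > 0.5    # binary ground truth

# Naive route: binarize at each of the 256 thresholds and count.
naive_counts = np.array([
    np.count_nonzero((pred >= threshold) & gt)
    for threshold in np.linspace(0, 1, 256)
])

# Histogram route: quantize once, histogram the pred values on gt-foreground
# pixels, then flip + cumsum so entry i is the number of those pixels whose
# quantized value is >= 255 - i (thresholds enumerated from high to low).
pred_u8 = (pred * 255).astype(np.uint8)
fg_fg_hist, _ = np.histogram(pred_u8[gt], bins=np.linspace(0, 256, 257))
fg_fg_numel_w_thrs = np.cumsum(np.flip(fg_fg_hist))

# Up to uint8 quantization edge cases the two routes agree; note the reversed
# threshold order of the histogram route.
print(np.array_equal(naive_counts, np.flip(fg_fg_numel_w_thrs)))  # expect True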
@@ -1,7 +1,7 @@
import numpy as np
from scipy.ndimage import convolve, distance_transform_edt as bwdist

__version__ = '1.1.1'
__version__ = '1.2.1'

_EPS = 1e-16
_TYPE = np.float64
@@ -249,103 +249,102 @@ def cal_adaptive_em(self, pred: np.ndarray, gt: np.ndarray) -> float:
adaptive_em = self.cal_em_with_threshold(pred, gt, threshold=adaptive_threshold)
return adaptive_em

def cal_changeable_em(self, pred: np.ndarray, gt: np.ndarray) -> list:
changeable_ems = [
self.cal_em_with_threshold(pred, gt, threshold=threshold)
for threshold in np.linspace(0, 1, 256)
]
def cal_changeable_em(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
changeable_ems = self.cal_em_with_cumsumhistogram(pred, gt)
return changeable_ems

def cal_em_with_threshold(self, pred: np.ndarray, gt: np.ndarray, threshold: float) -> float:
"""
Variable naming convention inside this function:
<pred attribute (foreground fg / background bg)>_<gt attribute (foreground fg / background bg)>_<meaning>
If only pred or only gt is involved, the other attribute slot is replaced with `_`.
"""
binarized_pred = pred >= threshold
fg_fg_numel = np.count_nonzero(binarized_pred & gt)
fg_bg_numel = np.count_nonzero(binarized_pred & ~gt)

fg___numel = fg_fg_numel + fg_bg_numel
bg___numel = self.gt_size - fg___numel

if self.gt_fg_numel == 0:
binarized_pred_bg_numel = np.count_nonzero(~binarized_pred)
enhanced_matrix_sum = binarized_pred_bg_numel
enhanced_matrix_sum = bg___numel
elif self.gt_fg_numel == self.gt_size:
binarized_pred_fg_numel = np.count_nonzero(binarized_pred)
enhanced_matrix_sum = binarized_pred_fg_numel
enhanced_matrix_sum = fg___numel
else:
enhanced_matrix_sum = self.cal_enhanced_matrix(binarized_pred, gt)
parts_numel, combinations = self.generate_parts_numel_combinations(
fg_fg_numel=fg_fg_numel, fg_bg_numel=fg_bg_numel,
pred_fg_numel=fg___numel, pred_bg_numel=bg___numel,
)

results_parts = []
for i, (part_numel, combination) in enumerate(zip(parts_numel, combinations)):
align_matrix_value = 2 * (combination[0] * combination[1]) / \
(combination[0] ** 2 + combination[1] ** 2 + _EPS)
enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4
results_parts.append(enhanced_matrix_value * part_numel)
enhanced_matrix_sum = sum(results_parts)

em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS)
return em

def cal_enhanced_matrix(self, binarized_pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
# demeaned_pred = pred - pred.mean()
# demeaned_gt = gt - gt.mean()
fg_fg_numel = np.count_nonzero(binarized_pred & gt)
fg_bg_numel = np.count_nonzero(binarized_pred & ~gt)
# bg_fg_numel = np.count_nonzero(~binarized_pred & gt)
def cal_em_with_cumsumhistogram(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
"""
Variable naming convention inside this function:
<pred attribute (foreground fg / background bg)>_<gt attribute (foreground fg / background bg)>_<meaning>
If only pred or only gt is involved, the other attribute slot is replaced with `_`.
"""
pred = (pred * 255).astype(np.uint8)
bins = np.linspace(0, 256, 257)
fg_fg_hist, _ = np.histogram(pred[gt], bins=bins)
fg_bg_hist, _ = np.histogram(pred[~gt], bins=bins)
fg_fg_numel_w_thrs = np.cumsum(np.flip(fg_fg_hist), axis=0)
fg_bg_numel_w_thrs = np.cumsum(np.flip(fg_bg_hist), axis=0)

fg___numel_w_thrs = fg_fg_numel_w_thrs + fg_bg_numel_w_thrs
bg___numel_w_thrs = self.gt_size - fg___numel_w_thrs

if self.gt_fg_numel == 0:
enhanced_matrix_sum = bg___numel_w_thrs
elif self.gt_fg_numel == self.gt_size:
enhanced_matrix_sum = fg___numel_w_thrs
else:
parts_numel_w_thrs, combinations = self.generate_parts_numel_combinations(
fg_fg_numel=fg_fg_numel_w_thrs, fg_bg_numel=fg_bg_numel_w_thrs,
pred_fg_numel=fg___numel_w_thrs, pred_bg_numel=bg___numel_w_thrs,
)

results_parts = np.empty(shape=(4, 256), dtype=np.float64)
for i, (part_numel, combination) in enumerate(zip(parts_numel_w_thrs, combinations)):
align_matrix_value = 2 * (combination[0] * combination[1]) / \
(combination[0] ** 2 + combination[1] ** 2 + _EPS)
enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4
results_parts[i] = enhanced_matrix_value * part_numel
enhanced_matrix_sum = results_parts.sum(axis=0)

em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS)
return em

def generate_parts_numel_combinations(self, fg_fg_numel, fg_bg_numel, pred_fg_numel, pred_bg_numel):
bg_fg_numel = self.gt_fg_numel - fg_fg_numel
# bg_bg_numel = np.count_nonzero(~binarized_pred & ~gt)
bg_bg_numel = self.gt_size - (fg_fg_numel + fg_bg_numel + bg_fg_numel)
bg_bg_numel = pred_bg_numel - bg_fg_numel

parts_numel = [fg_fg_numel, fg_bg_numel, bg_fg_numel, bg_bg_numel]

mean_pred_value = (fg_fg_numel + fg_bg_numel) / self.gt_size
mean_pred_value = pred_fg_numel / self.gt_size
mean_gt_value = self.gt_fg_numel / self.gt_size

demeaned_pred_fg_value = 1 - mean_pred_value
demeaned_pred_bg_value = 0 - mean_pred_value
demeaned_gt_fg_value = 1 - mean_gt_value
demeaned_gt_bg_value = 0 - mean_gt_value

combinations = [(demeaned_pred_fg_value, demeaned_gt_fg_value),
(demeaned_pred_fg_value, demeaned_gt_bg_value),
(demeaned_pred_bg_value, demeaned_gt_fg_value),
(demeaned_pred_bg_value, demeaned_gt_bg_value)]

results_parts = []
for part_numel, combination in zip(parts_numel, combinations):
# align_matrix = 2 * (demeaned_gt * demeaned_pred) / (demeaned_gt ** 2 + demeaned_pred ** 2 + _EPS)
align_matrix_value = 2 * (combination[0] * combination[1]) / \
(combination[0] ** 2 + combination[1] ** 2 + _EPS)
# enhanced_matrix = (align_matrix + 1) ** 2 / 4
enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4
results_parts.append(enhanced_matrix_value * part_numel)

# enhanced_matrix = enhanced_matrix.sum()
enhanced_matrix = sum(results_parts)
return enhanced_matrix

# def cal_enhanced_matrix_parallel(self, binarized_pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
# """
# Packing the four parts into arrays and vectorizing with numpy removes the for loop, but in practice the gain over cal_enhanced_matrix is negative, probably because the resulting arrays are too small.
# """
# # demeaned_pred = pred - pred.mean()
# # demeaned_gt = gt - gt.mean()
# fg_fg_numel = np.count_nonzero(binarized_pred & gt)
# fg_bg_numel = np.count_nonzero(binarized_pred & ~gt)
# pred_fg_numel = fg_fg_numel + fg_bg_numel
#
# # bg_fg_numel = np.count_nonzero(~binarized_pred & gt)
# bg_fg_numel = self.gt_fg_numel - fg_fg_numel
# # bg_bg_numel = np.count_nonzero(~binarized_pred & ~gt)
# bg_bg_numel = self.gt_size - (pred_fg_numel + bg_fg_numel)
#
# parts_numel = np.array([fg_fg_numel, fg_bg_numel, bg_fg_numel, bg_bg_numel], dtype=_TYPE)
#
# mean_pred_value = pred_fg_numel / self.gt_size
# mean_gt_value = self.gt_fg_numel / self.gt_size
#
# demeaned_pred_fg_value = 1 - mean_pred_value
# demeaned_pred_bg_value = 0 - mean_pred_value
# demeaned_gt_fg_value = 1 - mean_gt_value
# demeaned_gt_bg_value = 0 - mean_gt_value
#
# combinations = np.array([[demeaned_pred_fg_value, demeaned_gt_fg_value],
# [demeaned_pred_fg_value, demeaned_gt_bg_value],
# [demeaned_pred_bg_value, demeaned_gt_fg_value],
# [demeaned_pred_bg_value, demeaned_gt_bg_value]], dtype=_TYPE)
#
# # align_matrix = 2 * (demeaned_gt * demeaned_pred) / (demeaned_gt ** 2 + demeaned_pred ** 2 + _EPS)
# align_matrix_value = 2 * combinations.prod(axis=-1) / ((combinations ** 2).sum(axis=-1, keepdims=True) +
# _EPS)
# # enhanced_matrix = (align_matrix + 1) ** 2 / 4
# enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4
# # enhanced_matrix = enhanced_matrix.sum()
# enhanced_matrix_sum = (enhanced_matrix_value * parts_numel).sum()
# return enhanced_matrix_sum
combinations = [
(demeaned_pred_fg_value, demeaned_gt_fg_value),
(demeaned_pred_fg_value, demeaned_gt_bg_value),
(demeaned_pred_bg_value, demeaned_gt_fg_value),
(demeaned_pred_bg_value, demeaned_gt_bg_value)
]
return parts_numel, combinations

def get_results(self) -> dict:
adaptive_em = np.mean(np.array(self.adaptive_ems, dtype=_TYPE))
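For reference, the closed form that generate_parts_numel_combinations relies on: once pred is binarized and gt is binary, the demeaned maps take only two values each, so the enhanced alignment matrix is constant on the four regions fg_fg, fg_bg, bg_fg, bg_bg, and its sum reduces to four scalars weighted by region pixel counts. A small self-contained check of that identity on toy data (not the library class):

# Toy check that the pixelwise enhanced alignment matrix collapses to four
# region-wise constants (illustrative only; _EPS matches the module constant).
import numpy as np

_EPS = 1e-16
rng = np.random.default_rng(0)
binarized_pred = rng.random((32, 32)) > 0.6
gt = rng.random((32, 32)) > 0.5

# Direct pixelwise computation.
demeaned_pred = binarized_pred - binarized_pred.mean()
demeaned_gt = gt - gt.mean()
align_matrix = 2 * demeaned_pred * demeaned_gt / (demeaned_pred ** 2 + demeaned_gt ** 2 + _EPS)
pixelwise_sum = ((align_matrix + 1) ** 2 / 4).sum()

# Region-wise computation: one constant per (pred, gt) foreground/background pair.
fg_fg = np.count_nonzero(binarized_pred & gt)
fg_bg = np.count_nonzero(binarized_pred & ~gt)
bg_fg = np.count_nonzero(~binarized_pred & gt)
bg_bg = gt.size - fg_fg - fg_bg - bg_fg
mean_pred, mean_gt = binarized_pred.mean(), gt.mean()
region_sum = 0.0
for numel, (d_pred, d_gt) in zip(
        [fg_fg, fg_bg, bg_fg, bg_bg],
        [(1 - mean_pred, 1 - mean_gt), (1 - mean_pred, 0 - mean_gt),
         (0 - mean_pred, 1 - mean_gt), (0 - mean_pred, 0 - mean_gt)]):
    align_value = 2 * d_pred * d_gt / (d_pred ** 2 + d_gt ** 2 + _EPS)
    region_sum += (align_value + 1) ** 2 / 4 * numel

print(np.isclose(pixelwise_sum, region_sum))  # expect True

The all-foreground and all-background ground-truth cases bypass this formula in the diff above and are handled by the dedicated if/elif branches.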
