Commit a4a41f4
Fixed balanced accuracy and added top-3 and top-5 accuracy calculation.
MarcBS committed Sep 10, 2019
1 parent 9c5a28e commit a4a41f4
Showing 1 changed file with 30 additions and 1 deletion.
keras_wrapper/extra/evaluation.py
@@ -242,24 +242,53 @@ def multiclass_metrics(pred_list, verbose, extra_vars, split):

     # Compute accuracy
     accuracy = sklearn_metrics.accuracy_score(y_gt, y_pred)
-    accuracy_balanced = sklearn_metrics.accuracy_score(y_gt, y_pred, sample_weight=sample_weights)
+    accuracy_3 = __top_k_accuracy(y_gt, y_pred, 3)
+    accuracy_5 = __top_k_accuracy(y_gt, y_pred, 5)
+    # accuracy_balanced = sklearn_metrics.accuracy_score(y_gt, y_pred, sample_weight=sample_weights)
+
+    # The following two lines should both provide the same measure (balanced accuracy)
+    # _, accuracy_balanced, _, _ = sklearn_metrics.precision_recall_fscore_support(y_gt, y_pred, average='macro')
+    accuracy_balanced = sklearn_metrics.balanced_accuracy_score(y_gt, y_pred)

     # Compute Precision, Recall and F1 score
     precision, recall, f1, _ = sklearn_metrics.precision_recall_fscore_support(y_gt, y_pred, average='micro')

     if verbose > 0:
         logging.info('Accuracy: %f' % accuracy)
+        logging.info('Accuracy top-3: %f' % accuracy_3)
+        logging.info('Accuracy top-5: %f' % accuracy_5)
         logging.info('Balanced Accuracy: %f' % accuracy_balanced)
         logging.info('Precision: %f' % precision)
         logging.info('Recall: %f' % recall)
         logging.info('F1 score: %f' % f1)

     return {'accuracy': accuracy,
+            'accuracy_top_3': accuracy_3,
+            'accuracy_top_5': accuracy_5,
             'accuracy_balanced': accuracy_balanced,
             'precision': precision,
             'recall': recall,
             'f1': f1}


+def __top_k_accuracy(truths, preds, k):
+    """
+    Computes top-k accuracy. truths and preds are both of shape m by n,
+    where m is the number of samples and n is the number of classes.
+    :param truths: binary (one-hot) ground-truth matrix
+    :param preds: per-class prediction score matrix
+    :param k: number of highest-scored classes to consider per sample
+    :return: fraction of samples whose true class is among their k highest-scored classes
+    """
+    best_k = np.argsort(preds, axis=1)[:, -k:]
+    ts = np.argmax(truths, axis=1)
+    successes = 0
+    for i in range(ts.shape[0]):
+        if ts[i] in best_k[i, :]:
+            successes += 1
+    return float(successes) / ts.shape[0]
+
+
 def semantic_segmentation_accuracy(pred_list, verbose, extra_vars, split):
     """
     Semantic Segmentation Accuracy metric
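
For reference, a small usage sketch of the new __top_k_accuracy helper (not part of the commit; the arrays are invented for illustration, and the import assumes the keras_wrapper package from this repository is on the path):

    import numpy as np

    from keras_wrapper.extra.evaluation import __top_k_accuracy

    # 3 samples, 4 classes: one-hot ground truth and per-class scores
    truths = np.array([[0, 0, 1, 0],
                       [1, 0, 0, 0],
                       [0, 1, 0, 0]])
    preds = np.array([[0.1, 0.2, 0.3, 0.4],   # true class 2 is ranked second
                      [0.5, 0.1, 0.2, 0.2],   # true class 0 is ranked first
                      [0.3, 0.1, 0.4, 0.2]])  # true class 1 is ranked last

    print(__top_k_accuracy(truths, preds, 1))  # 1/3: only the second sample is a top-1 hit
    print(__top_k_accuracy(truths, preds, 2))  # 2/3: the third sample misses even at k=2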
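
The comment retained in the hunk says that macro-averaged recall and sklearn's balanced_accuracy_score (available since scikit-learn 0.20) measure the same thing. A quick sanity check on made-up labels, not part of the commit:

    import numpy as np
    from sklearn import metrics as sklearn_metrics

    y_gt = [0, 0, 0, 1, 1, 2]
    y_pred = [0, 1, 0, 1, 1, 2]

    # Per-class recalls are 2/3, 1 and 1, so both values equal (2/3 + 1 + 1) / 3
    _, macro_recall, _, _ = sklearn_metrics.precision_recall_fscore_support(y_gt, y_pred, average='macro')
    balanced = sklearn_metrics.balanced_accuracy_score(y_gt, y_pred)
    assert np.isclose(macro_recall, balanced)

Note also that for single-label data, average='micro' makes the precision, recall, and F1 computed in the hunk all equal plain accuracy.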
