diff --git a/keras_wrapper/extra/evaluation.py b/keras_wrapper/extra/evaluation.py
index 6bc7181..ee79f30 100644
--- a/keras_wrapper/extra/evaluation.py
+++ b/keras_wrapper/extra/evaluation.py
@@ -242,24 +242,53 @@ def multiclass_metrics(pred_list, verbose, extra_vars, split):
 
     # Compute accuracy
     accuracy = sklearn_metrics.accuracy_score(y_gt, y_pred)
-    accuracy_balanced = sklearn_metrics.accuracy_score(y_gt, y_pred, sample_weight=sample_weights)
+    accuracy_3 = __top_k_accuracy(y_gt, y_pred, 3)
+    accuracy_5 = __top_k_accuracy(y_gt, y_pred, 5)
+    # accuracy_balanced = sklearn_metrics.accuracy_score(y_gt, y_pred, sample_weight=sample_weights)
+
+    # The following two lines should both provide the same measure (balanced accuracy)
+    # _, accuracy_balanced, _, _ = sklearn_metrics.precision_recall_fscore_support(y_gt, y_pred, average='macro')
+    accuracy_balanced = sklearn_metrics.balanced_accuracy_score(y_gt, y_pred)
+
     # Compute Precision, Recall and F1 score
     precision, recall, f1, _ = sklearn_metrics.precision_recall_fscore_support(y_gt, y_pred, average='micro')
 
     if verbose > 0:
         logging.info('Accuracy: %f' % accuracy)
+        logging.info('Accuracy top-3: %f' % accuracy_3)
+        logging.info('Accuracy top-5: %f' % accuracy_5)
         logging.info('Balanced Accuracy: %f' % accuracy_balanced)
         logging.info('Precision: %f' % precision)
         logging.info('Recall: %f' % recall)
         logging.info('F1 score: %f' % f1)
 
     return {'accuracy': accuracy,
+            'accuracy_top_3': accuracy_3,
+            'accuracy_top_5': accuracy_5,
             'accuracy_balanced': accuracy_balanced,
             'precision': precision,
             'recall': recall,
             'f1': f1}
 
 
+def __top_k_accuracy(truths, preds, k):
+    """
+    Both preds and truths have the same shape m by n (m is the number of predictions and n is the number of classes).
+
+    :param truths: ground-truth matrix (one-hot rows)
+    :param preds: prediction score matrix
+    :param k: number of top-scoring classes to consider
+    :return: top-k accuracy over the m samples
+    """
+    best_k = np.argsort(preds, axis=1)[:, -k:]
+    ts = np.argmax(truths, axis=1)
+    successes = 0
+    for i in range(ts.shape[0]):
+        if ts[i] in best_k[i, :]:
+            successes += 1
+    return float(successes) / ts.shape[0]
+
+
 def semantic_segmentation_accuracy(pred_list, verbose, extra_vars, split):
     """
     Semantic Segmentation Accuracy metric
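
For reference, a minimal standalone sketch of what the new top-k metric computes, assuming one-hot ground-truth rows and per-class score rows. The function name, toy arrays, and expected outputs below are illustrative only and are not part of the patch:

import numpy as np

def top_k_accuracy(truths, preds, k):
    # Indices of the k highest-scoring classes per sample.
    best_k = np.argsort(preds, axis=1)[:, -k:]
    # True class index per sample (one-hot rows -> argmax).
    true_classes = np.argmax(truths, axis=1)
    hits = sum(true_classes[i] in best_k[i, :] for i in range(true_classes.shape[0]))
    return float(hits) / true_classes.shape[0]

# Hypothetical toy data: 2 samples, 4 classes.
y_gt = np.array([[0, 0, 1, 0],
                 [1, 0, 0, 0]])
y_scores = np.array([[0.1, 0.2, 0.3, 0.4],   # true class 2 is second best -> counted at k=2, missed at k=1
                     [0.5, 0.1, 0.3, 0.1]])  # true class 0 is best -> counted at both k
print(top_k_accuracy(y_gt, y_scores, 2))  # -> 1.0
print(top_k_accuracy(y_gt, y_scores, 1))  # -> 0.5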