diff --git a/keras_wrapper/extra/evaluation.py b/keras_wrapper/extra/evaluation.py
index 5112788..f845236 100644
--- a/keras_wrapper/extra/evaluation.py
+++ b/keras_wrapper/extra/evaluation.py
@@ -249,7 +249,7 @@ def multiclass_metrics(pred_list, verbose, extra_vars, split):
     accuracy = sklearn_metrics.accuracy_score(y_gt, y_pred)
     acc_top_n = {}
     for topn in top_n_accuracies:
-        acc_top_n[topn] = __top_k_accuracy(y_gt, y_pred, topn)
+        acc_top_n[topn] = __top_k_accuracy(y_gt, pred_list, topn)
 
     # accuracy_balanced = sklearn_metrics.accuracy_score(y_gt, y_pred, sample_weight=sample_weights, )
     # The following two lines should both provide the same measure (balanced accuracy)
@@ -287,7 +287,7 @@ def __top_k_accuracy(truths, preds, k):
     :param k:
     :return:
     """
-    best_k = np.argsort(preds, axis=1)[:, -k:]
+    best_k = np.argsort(preds, axis=1)[:, -k:][:, ::-1]
     ts = np.argmax(truths, axis=1)
     successes = 0
     for i in range(ts.shape[0]):
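
Two fixes here. The first hunk passes the raw per-class scores (`pred_list`) to `__top_k_accuracy` instead of the already-argmaxed hard predictions (`y_pred`): the helper sorts its `preds` argument along axis 1 to rank classes by score, which only makes sense for an `(n_samples, n_classes)` score matrix; if `y_pred` is the usual 1-D label vector, there is no axis 1 to sort over at all. The second hunk reverses the slice so `best_k` lists the top-k classes best-first; the membership test that follows is order-insensitive, so this mainly makes `best_k` read in descending-score order. Below is a minimal, self-contained sketch of the corrected computation; the function name `top_k_accuracy` and the toy data are illustrative, not taken from the library.

```python
import numpy as np


def top_k_accuracy(truths, preds, k):
    """Fraction of samples whose true class is among the k highest-scoring
    predictions. `truths` is one-hot with shape (n_samples, n_classes);
    `preds` is a score/probability matrix of the same shape."""
    # argsort is ascending, so the last k columns hold the top-k classes;
    # [:, ::-1] orders them best-first, mirroring the fix in the diff.
    best_k = np.argsort(preds, axis=1)[:, -k:][:, ::-1]
    ts = np.argmax(truths, axis=1)
    return np.mean([ts[i] in best_k[i] for i in range(ts.shape[0])])


# Three samples, four classes: true classes are 2, 0, 1.
one_hot_truths = np.eye(4)[[2, 0, 1]]
probs = np.array([[0.1, 0.2, 0.6, 0.1],   # class 2 ranked 1st
                  [0.3, 0.4, 0.2, 0.1],   # class 0 ranked 2nd
                  [0.1, 0.1, 0.3, 0.5]])  # class 1 ranked 4th
print(top_k_accuracy(one_hot_truths, probs, 1))  # 0.333...
print(top_k_accuracy(one_hot_truths, probs, 2))  # 0.666...
```

Passing the hard labels `np.array([2, 0, 1])` in place of `probs` would raise an axis error in `np.argsort(..., axis=1)`, which is the failure mode the first hunk removes.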