diff --git a/keras_wrapper/extra/evaluation.py b/keras_wrapper/extra/evaluation.py index d2e5bcf..5112788 100644 --- a/keras_wrapper/extra/evaluation.py +++ b/keras_wrapper/extra/evaluation.py @@ -247,7 +247,7 @@ def multiclass_metrics(pred_list, verbose, extra_vars, split): # Compute accuracy top_n_accuracies = [3, 5] accuracy = sklearn_metrics.accuracy_score(y_gt, y_pred) - acc_top_n = [] + acc_top_n = {} for topn in top_n_accuracies: acc_top_n[topn] = __top_k_accuracy(y_gt, y_pred, topn) # accuracy_balanced = sklearn_metrics.accuracy_score(y_gt, y_pred, sample_weight=sample_weights, )