Skip to content

Commit

Permalink
Minor bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Eaaguilart committed Oct 18, 2019
1 parent 893f99b commit f93864c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions keras_wrapper/extra/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(self,
each_n_epochs=1,
max_eval_samples=None,
extra_vars=None,
normalize=True,
normalize=False,
normalization_type=None,
output_types=None,
is_text=False,
Expand Down Expand Up @@ -261,7 +261,7 @@ def __init__(self,

else:
# Convert min_pred_multilabel to list
if isinstance(self.min_pred_multilabel, list):
if not isinstance(self.min_pred_multilabel, list):
self.min_pred_multilabel = [self.min_pred_multilabel for _ in self.gt_pos]

super(EvalPerformance, self).__init__()
Expand Down
4 changes: 2 additions & 2 deletions keras_wrapper/extra/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def multiclass_metrics(pred_list, verbose, extra_vars, split):
accuracy = sklearn_metrics.accuracy_score(y_gt, y_pred)
accuracy_balanced = sklearn_metrics.accuracy_score(y_gt, y_pred, sample_weight=sample_weights)
# Compute Precision, Recall and F1 score
avrg = extra_vars.get('average_mode', None)
avrg = extra_vars.get('average_mode', 'macro')
precision, recall, f1, _ = sklearn_metrics.precision_recall_fscore_support(y_gt, y_pred, average=avrg)
# Compute Confusion Matrix
cf = sklearn_metrics.confusion_matrix(np.argmax(y_gt, -1), np.argmax(y_pred, -1))
Expand All @@ -271,7 +271,7 @@ def multiclass_metrics(pred_list, verbose, extra_vars, split):
# Compute top 5 fp classes
top5_fps = np.argpartition(cf * neg_identity, -5)[:, -5:][:, ::-1]
# Compute top 5 accuracy
arg_top5_pred = np.argpartition(y_pred, -5)[:, -5:]
arg_top5_pred = np.argpartition(pred_list, -5)[:, -5:]
arg_gt = np.argmax(y_gt, -1)
top5_acc = np.mean(np.max(arg_top5_pred == np.repeat(np.expand_dims(arg_gt, -1), 5, -1), -1))

Expand Down

0 comments on commit f93864c

Please sign in to comment.