diff --git a/imagenet_utils.py b/imagenet_utils.py index 8a61fcc..5cf3d17 100644 --- a/imagenet_utils.py +++ b/imagenet_utils.py @@ -42,7 +42,7 @@ def decode_predictions(preds, top=5): CLASS_INDEX = json.load(open(fpath)) results = [] for pred in preds: - top_indices = np.argpartition(pred, -top)[-top:][::-1] + top_indices = pred.argsort()[-top:][::-1] result = [tuple(CLASS_INDEX[str(i)]) + (pred[i],) for i in top_indices] results.append(result) return results