From 9823d40814d72fc587c18a346e5acbe5981a225f Mon Sep 17 00:00:00 2001 From: Francois Caud Date: Thu, 14 Dec 2023 17:33:49 +0100 Subject: [PATCH] fix assert error --- stroke/scoring.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/stroke/scoring.py b/stroke/scoring.py index a6014a3..60a8b08 100644 --- a/stroke/scoring.py +++ b/stroke/scoring.py @@ -43,14 +43,17 @@ def score_function(self, Y_true: np.array, Y_pred: np.array): Sørensen–Dice coefficient. """ y_true = np.array(Y_true.y_true) - assert len(y_true) == len(Y_pred.y_pred) + if len(Y_pred.y_pred) != 0: + assert len(y_true) == len(Y_pred.y_pred) if len(Y_pred.y_pred) == 0: return 0 estimator = Y_pred.y_pred[0].estimator fscore = 0 for y_true_i, prediction_object in zip(y_true, Y_pred.y_pred): - dat = estimator.predict(BIDSLoader.load_image_tuple(prediction_object.pred)) + dat = estimator.predict( + BIDSLoader.load_image_tuple(prediction_object.pred) + ) # Using proxy of y_true.shape != y_pred.shape to indicate that data needs to be unpacked if y_true_i.shape != dat.shape: