diff --git a/stroke/scoring.py b/stroke/scoring.py index 139a481..a7f476e 100644 --- a/stroke/scoring.py +++ b/stroke/scoring.py @@ -43,34 +43,21 @@ def score_function(self, Y_true: np.array, Y_pred: np.array): Sørensen–Dice coefficient. """ y_true = np.array(Y_true.y_true) + assert len(y_true) == len(Y_pred.y_pred) if len(Y_pred.y_pred) == 0: return 0 estimator = Y_pred.y_pred[0].estimator fscore = 0 - # Load example to ensure that the size fits - dat = estimator.predict(BIDSLoader.load_image_tuple(Y_pred.y_pred[0].pred)) - # Have to unpack if y_true is bool - # Using proxy of y_true.shape != y_pred.shape to indicate that data needs to be unpacked - must_unpack = y_true[0, ...].shape != dat.shape - - for idx, prediction_object in enumerate(Y_pred.y_pred): - # First sample is already loaded; let's not waste the loading. - if idx != 0: - dat = BIDSLoader.load_image_tuple(prediction_object.pred) - - # Note: If you want to get the weighted mean, use - # self.calc_score_parts - if must_unpack: - unpacked_y_sample = np.array( - self.unpack_data(y_true[idx, ...], dat.shape), dtype=dat.dtype + for y_true_i, prediction_object in zip(y_true, Y_pred.y_pred): + dat = estimator.predict(BIDSLoader.load_image_tuple(prediction_object.pred)) + + # Using proxy of y_true.shape != y_pred.shape to indicate that data needs to be unpacked + if y_true_i.shape != dat.shape: + y_true_i = np.array( + self.unpack_data(y_true_i, dat.shape), dtype=dat.dtype ) - # unpacked_y_sample = np.array(np.unpackbits(y_true[idx, ...]), dtype=dat.dtype) - unpacked_y_sample = unpacked_y_sample.reshape(dat.shape) - sd_score = self.calc_score(dat, unpacked_y_sample) - else: - sd_score = self.calc_score(dat, y_true[idx, ...]) - fscore += sd_score + fscore += self.calc_score(dat, y_true_i) # Return the mean score return fscore / (idx + 1) @@ -132,10 +119,8 @@ def calc_score_parts(array_0: np.array, array_1: np.array): tuple Tuple containing (overlap, sum(array_0), sum(array_1) """ - array_0_reshape = np.reshape(array_0, (1, np.prod(array_0.shape))) - array_1_reshape = np.reshape(array_1, (np.prod((array_1.shape)), 1)) - overlap = 2 * array_0_reshape @ array_1_reshape - return (overlap[0][0], np.sum(array_0), np.sum(array_1)) + overlap = 2 * array_0.ravel() @ array_1.ravel() + return (overlap, np.sum(array_0), np.sum(array_1)) @staticmethod def check_y_pred_dimensions(array_0: np.array, array_1: np.array):