Skip to content

Commit

Permalink
FIX bad evaluation of the prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
tomMoral committed Dec 13, 2023
1 parent 2f92422 commit c27678d
Showing 1 changed file with 11 additions and 26 deletions.
37 changes: 11 additions & 26 deletions stroke/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,34 +43,21 @@ def score_function(self, Y_true: np.array, Y_pred: np.array):
Sørensen–Dice coefficient.
"""
y_true = np.array(Y_true.y_true)
assert len(y_true) == len(Y_pred.y_pred)
if len(Y_pred.y_pred) == 0:
return 0
estimator = Y_pred.y_pred[0].estimator

fscore = 0
# Load example to ensure that the size fits
dat = estimator.predict(BIDSLoader.load_image_tuple(Y_pred.y_pred[0].pred))
# Have to unpack if y_true is bool
# Using proxy of y_true.shape != y_pred.shape to indicate that data needs to be unpacked
must_unpack = y_true[0, ...].shape != dat.shape

for idx, prediction_object in enumerate(Y_pred.y_pred):
# First sample is already loaded; let's not waste the loading.
if idx != 0:
dat = BIDSLoader.load_image_tuple(prediction_object.pred)

# Note: If you want to get the weighted mean, use
# self.calc_score_parts
if must_unpack:
unpacked_y_sample = np.array(
self.unpack_data(y_true[idx, ...], dat.shape), dtype=dat.dtype
for y_true_i, prediction_object in zip(y_true, Y_pred.y_pred):
dat = estimator.predict(BIDSLoader.load_image_tuple(prediction_object.pred))

# Using proxy of y_true.shape != y_pred.shape to indicate that data needs to be unpacked
if y_true_i.shape != dat.shape:
y_true_i = np.array(
self.unpack_data(y_true_i, dat.shape), dtype=dat.dtype
)
# unpacked_y_sample = np.array(np.unpackbits(y_true[idx, ...]), dtype=dat.dtype)
unpacked_y_sample = unpacked_y_sample.reshape(dat.shape)
sd_score = self.calc_score(dat, unpacked_y_sample)
else:
sd_score = self.calc_score(dat, y_true[idx, ...])
fscore += sd_score
fscore += self.calc_score(dat, y_true_i)

# Return the mean score
return fscore / (idx + 1)
Expand Down Expand Up @@ -132,10 +119,8 @@ def calc_score_parts(array_0: np.array, array_1: np.array):
tuple
Tuple containing (overlap, sum(array_0), sum(array_1)
"""
array_0_reshape = np.reshape(array_0, (1, np.prod(array_0.shape)))
array_1_reshape = np.reshape(array_1, (np.prod((array_1.shape)), 1))
overlap = 2 * array_0_reshape @ array_1_reshape
return (overlap[0][0], np.sum(array_0), np.sum(array_1))
overlap = 2 * array_0.ravel() @ array_1.ravel()
return (overlap, np.sum(array_0), np.sum(array_1))

@staticmethod
def check_y_pred_dimensions(array_0: np.array, array_1: np.array):
Expand Down

0 comments on commit c27678d

Please sign in to comment.