Skip to content

Commit

Permalink
Merge pull request #30 from ramp-kits/FIX_bad_evaluation
Browse files Browse the repository at this point in the history
FIX bad evaluation of the prediction
  • Loading branch information
frcaud authored Dec 14, 2023
2 parents 2f92422 + 987bd04 commit 208a4a3
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 29 deletions.
9 changes: 7 additions & 2 deletions stroke/bids_workflow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import numpy as np
from rampwf.utils.importing import import_module_from_source
from stroke import stroke_config
from stroke.scoring import DiceCoeff
from stroke.bids_loader import BIDSLoader


Expand Down Expand Up @@ -61,13 +63,16 @@ def train_submission(

for idx in range(0, len(train_is), batch_size):
# Get tuples to load
data_to_load = [X_array[i] for i in train_is[idx : idx + batch_size]]
target_to_load = [y_array[i] for i in train_is[idx : idx + batch_size]]
data_to_load = [X_array[i] for i in train_is[idx:idx + batch_size]]
target_to_load = [y_array[i] for i in train_is[idx:idx + batch_size]]
# Load data
data = BIDSLoader.load_image_tuple_list(data_to_load)
target = BIDSLoader.load_image_tuple_list(
target_to_load, dtype=stroke_config.data_types["target"]
)
target = np.array([
DiceCoeff.unpack_data(y, X.shape) for y, X in zip(target, data)
])

# Fit
self.estimator.fit_partial(data, target)
Expand Down
39 changes: 12 additions & 27 deletions stroke/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,37 +43,24 @@ def score_function(self, Y_true: np.array, Y_pred: np.array):
Sørensen–Dice coefficient.
"""
y_true = np.array(Y_true.y_true)
assert len(y_true) == len(Y_pred.y_pred)
if len(Y_pred.y_pred) == 0:
return 0
estimator = Y_pred.y_pred[0].estimator

fscore = 0
# Load example to ensure that the size fits
dat = estimator.predict(BIDSLoader.load_image_tuple(Y_pred.y_pred[0].pred))
# Have to unpack if y_true is bool
# Using proxy of y_true.shape != y_pred.shape to indicate that data needs to be unpacked
must_unpack = y_true[0, ...].shape != dat.shape

for idx, prediction_object in enumerate(Y_pred.y_pred):
# First sample is already loaded; let's not waste the loading.
if idx != 0:
dat = BIDSLoader.load_image_tuple(prediction_object.pred)

# Note: If you want to get the weighted mean, use
# self.calc_score_parts
if must_unpack:
unpacked_y_sample = np.array(
self.unpack_data(y_true[idx, ...], dat.shape), dtype=dat.dtype
for y_true_i, prediction_object in zip(y_true, Y_pred.y_pred):
dat = estimator.predict(BIDSLoader.load_image_tuple(prediction_object.pred))

# Using proxy of y_true.shape != y_pred.shape to indicate that data needs to be unpacked
if y_true_i.shape != dat.shape:
y_true_i = np.array(
self.unpack_data(y_true_i, dat.shape), dtype=dat.dtype
)
# unpacked_y_sample = np.array(np.unpackbits(y_true[idx, ...]), dtype=dat.dtype)
unpacked_y_sample = unpacked_y_sample.reshape(dat.shape)
sd_score = self.calc_score(dat, unpacked_y_sample)
else:
sd_score = self.calc_score(dat, y_true[idx, ...])
fscore += sd_score
fscore += self.calc_score(dat, y_true_i)

# Return the mean score
return fscore / (idx + 1)
return fscore / len(y_true)

@staticmethod
def unpack_data(array_0: np.array, output_shape: np.array):
Expand Down Expand Up @@ -132,10 +119,8 @@ def calc_score_parts(array_0: np.array, array_1: np.array):
tuple
Tuple containing (overlap, sum(array_0), sum(array_1))
"""
array_0_reshape = np.reshape(array_0, (1, np.prod(array_0.shape)))
array_1_reshape = np.reshape(array_1, (np.prod((array_1.shape)), 1))
overlap = 2 * array_0_reshape @ array_1_reshape
return (overlap[0][0], np.sum(array_0), np.sum(array_1))
overlap = 2 * array_0.ravel() @ array_1.ravel()
return (overlap, np.sum(array_0), np.sum(array_1))

@staticmethod
def check_y_pred_dimensions(array_0: np.array, array_1: np.array):
Expand Down

0 comments on commit 208a4a3

Please sign in to comment.