diff --git a/qsprpred/models/metrics.py b/qsprpred/models/metrics.py index f06e485f..2e6d0e88 100644 --- a/qsprpred/models/metrics.py +++ b/qsprpred/models/metrics.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod import numpy as np +import sklearn from sklearn.metrics import get_scorer from sklearn.metrics._scorer import _BaseScorer @@ -68,7 +69,17 @@ class probability predictions. """ # Convert predictions to correct shape for sklearn scorer if isinstance(y_pred, list): - if self.scorer.__class__.__name__ == "_PredictScorer": + + convert_to_discrete = False + if sklearn.__version__ < "1.4.0": + if self.scorer.__class__.__name__ == "_PredictScorer": + convert_to_discrete = True + elif self.scorer._response_method == "predict": + convert_to_discrete = True + elif "predict_proba" not in self.scorer._response_method: + convert_to_discrete = True + + if convert_to_discrete: # convert to discrete values y_pred = [np.argmax(yp, axis=1) for yp in y_pred] else: