Skip to content

Commit

Permalink
Merge branch 'fix/metrics_sklearn_1.4' into 'dev'
Browse files Browse the repository at this point in the history
fix metrics for sklearn version 1.4.0

See merge request cdd/QSPRpred!167
  • Loading branch information
HellevdM committed Feb 7, 2024
2 parents 4c6e58a + f6c008b commit ec966c0
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion qsprpred/models/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from abc import ABC, abstractmethod

import numpy as np
import sklearn
from sklearn.metrics import get_scorer
from sklearn.metrics._scorer import _BaseScorer

Expand Down Expand Up @@ -68,7 +69,17 @@ class probability predictions.
"""
# Convert predictions to correct shape for sklearn scorer
if isinstance(y_pred, list):
if self.scorer.__class__.__name__ == "_PredictScorer":

convert_to_discrete = False
if sklearn.__version__ < "1.4.0":
if self.scorer.__class__.__name__ == "_PredictScorer":
convert_to_discrete = True
elif self.scorer._response_method == "predict":
convert_to_discrete = True
elif "predict_proba" not in self.scorer._response_method:
convert_to_discrete = True

if convert_to_discrete:
# convert to discrete values
y_pred = [np.argmax(yp, axis=1) for yp in y_pred]
else:
Expand Down

0 comments on commit ec966c0

Please sign in to comment.