diff --git a/pytorch_tabnet/metrics.py b/pytorch_tabnet/metrics.py
index e8ad8181..c9a263ec 100644
--- a/pytorch_tabnet/metrics.py
+++ b/pytorch_tabnet/metrics.py
@@ -9,6 +9,7 @@
     log_loss,
     balanced_accuracy_score,
     mean_squared_log_error,
+    classification_report,
 )
 import torch
 
@@ -403,6 +404,38 @@ def __call__(self, y_true, y_score):
         return np.sqrt(mean_squared_log_error(y_true, y_score))
 
 
+class ClassificationReport(Metric):
+    """
+    Classification report: per-class precision, recall and F1 scores.
+    Scikit-implementation:
+    https://scikit-learn.org/stable/modules/generated/sklearn.metrics.classification_report.html
+    """
+
+    def __init__(self):
+        self._name = "classification_report"
+        self._maximize = False
+
+    def __call__(self, y_true, y_score):
+        """
+        Compute precision, recall and F1 scores of predictions for each target class.
+
+        Parameters
+        ----------
+        y_true : np.ndarray
+            Target matrix or vector
+        y_score : np.ndarray
+            Score matrix or vector
+
+        Returns
+        -------
+        str
+            Text report of per-class precision, recall, F1 score and support
+        """
+        # classification_report expects hard labels, so convert scores to predicted classes
+        y_pred = np.argmax(y_score, axis=1)
+        return classification_report(y_true, y_pred)
+
+
 class UnsupervisedMetric(Metric):
     """
     Unsupervised metric
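
Usage sketch (reviewer note, not part of the diff): assuming the standard TabNetClassifier API and the ClassificationReport metric as added above, the report can be produced by calling the metric directly on validation targets and predicted class probabilities. The synthetic dataset, variable names, and training settings below are illustrative only.

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from pytorch_tabnet.tab_model import TabNetClassifier
from pytorch_tabnet.metrics import ClassificationReport

# Illustrative synthetic data; any tabular classification dataset works.
X, y = make_classification(n_samples=1000, n_features=20, n_informative=8,
                           n_classes=3, random_state=0)
X_train, X_valid, y_train, y_valid = train_test_split(X, y, random_state=0)

clf = TabNetClassifier(verbose=0)
clf.fit(X_train, y_train, eval_set=[(X_valid, y_valid)], max_epochs=10)

# The metric takes raw class scores (predict_proba output) and returns
# sklearn's text report of per-class precision, recall, F1 and support.
report = ClassificationReport()(y_valid, clf.predict_proba(X_valid))
print(report)

Since the report is a string rather than a float, calling the metric directly like this keeps it out of the numeric eval_metric history that drives logging and early stopping.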