From bbad7880f4e129e2fce6cd68828187d34869ea8c Mon Sep 17 00:00:00 2001 From: T Date: Fri, 20 Sep 2024 14:45:43 +0800 Subject: [PATCH] feat: Add model-based explainability --- src/sdqc_check/__init__.py | 3 +- src/sdqc_check/explainability/__init__.py | 4 +- src/sdqc_check/explainability/model_based.py | 45 +++++++++++++++++++ src/sdqc_integration/sequential.py | 8 ++-- .../test_explainability.py | 12 +++-- .../test_sdqc_integration/test_sequential.py | 2 +- 6 files changed, 64 insertions(+), 10 deletions(-) create mode 100644 src/sdqc_check/explainability/model_based.py diff --git a/src/sdqc_check/__init__.py b/src/sdqc_check/__init__.py index c7216d6..b815bf7 100644 --- a/src/sdqc_check/__init__.py +++ b/src/sdqc_check/__init__.py @@ -1,7 +1,7 @@ from .causality import CausalAnalysis from .classification import ClassificationModel from .explainability import ( - ShapFeatureImportance, PFIFeatureImportance + ShapFeatureImportance, PFIFeatureImportance, ModelBasedFeatureImportance ) from .statistical_test import ( data_corr, identify_data_types, CategoricalTest, NumericalTest @@ -13,6 +13,7 @@ 'ClassificationModel', 'ShapFeatureImportance', 'PFIFeatureImportance', + 'ModelBasedFeatureImportance', 'data_corr', 'identify_data_types', 'CategoricalTest', diff --git a/src/sdqc_check/explainability/__init__.py b/src/sdqc_check/explainability/__init__.py index 2c8de96..e243b50 100644 --- a/src/sdqc_check/explainability/__init__.py +++ b/src/sdqc_check/explainability/__init__.py @@ -1,9 +1,9 @@ from .shap import ShapFeatureImportance from .pfi import PFIFeatureImportance +from .model_based import ModelBasedFeatureImportance __all__ = [ 'ShapFeatureImportance', 'PFIFeatureImportance', + 'ModelBasedFeatureImportance', ] - - diff --git a/src/sdqc_check/explainability/model_based.py b/src/sdqc_check/explainability/model_based.py new file mode 100644 index 0000000..eb08195 --- /dev/null +++ b/src/sdqc_check/explainability/model_based.py @@ -0,0 +1,45 @@ +from .base import BaseFeatureImportance +import pandas as pd +import numpy as np + + +class ModelBasedFeatureImportance(BaseFeatureImportance): + """ + Model-Based Feature Importance calculator. + + This class uses model-specific attributes to compute feature importance scores. + + Parameters + ---------- + See base class BaseFeatureImportance for parameter details. + """ + + def compute_feature_importance(self) -> pd.DataFrame: + """ + Compute Model-Based Feature Importance. + + This method uses model-specific attributes to calculate feature importance. + + Returns + ------- + pd.DataFrame + A DataFrame with features and their model-based importance scores, sorted in descending order. + + Raises + ------ + AttributeError: + If the model does not have the required attribute for feature importance. + """ + if hasattr(self.model, "feature_importances_"): + importances = self.model.feature_importances_ + elif hasattr(self.model, "coef_"): + importances = np.abs(self.model.coef_[0]) + else: + raise AttributeError( + "No attribute 'feature_importances_' or 'coef_' found for the model." + ) + importance = pd.DataFrame({ + 'feature': self.X_train.columns, + 'importance': importances + }) + return importance.sort_values('importance', ascending=False) diff --git a/src/sdqc_integration/sequential.py b/src/sdqc_integration/sequential.py index 4dda499..11bae39 100644 --- a/src/sdqc_integration/sequential.py +++ b/src/sdqc_integration/sequential.py @@ -10,6 +10,7 @@ from sdqc_check import ( CausalAnalysis, ClassificationModel, + ModelBasedFeatureImportance, ShapFeatureImportance, PFIFeatureImportance, data_corr, @@ -272,7 +273,7 @@ def explainability_step( Perform explainability analysis on the best model. This method applies the specified explainability algorithm - (SHAP or PFI) to interpret the predictions of the best + (Model-Based, SHAP or PFI) to interpret the predictions of the best performing model. Parameters @@ -308,7 +309,7 @@ def _get_explainability_class(self) -> Type: Returns ------- Type - The explainability class (SHAP or PFI) to be used. + The explainability class (Model-Based, SHAP or PFI) to be used. Raises ------ @@ -316,13 +317,14 @@ def _get_explainability_class(self) -> Type: If the specified explainability algorithm is not supported. """ explainability_classes = { + 'model_based': ModelBasedFeatureImportance, 'shap': ShapFeatureImportance, 'pfi': PFIFeatureImportance } if self.explainability_algorithm not in explainability_classes: raise ValueError( f"Algorithm {self.explainability_algorithm} not supported. " - "Please choose from 'shap' or 'pfi'." + "Please choose from 'model_based', 'shap' or 'pfi'." ) return explainability_classes[self.explainability_algorithm] diff --git a/tests/test_sdqc_check/test_explainability/test_explainability.py b/tests/test_sdqc_check/test_explainability/test_explainability.py index fcd9a1e..2d753c3 100644 --- a/tests/test_sdqc_check/test_explainability/test_explainability.py +++ b/tests/test_sdqc_check/test_explainability/test_explainability.py @@ -4,7 +4,9 @@ from sklearn.model_selection import train_test_split from sklearn.ensemble import RandomForestClassifier from sklearn.svm import SVC -from sdqc_check import ShapFeatureImportance, PFIFeatureImportance +from sdqc_check import ( + ShapFeatureImportance, PFIFeatureImportance, ModelBasedFeatureImportance +) @pytest.fixture @@ -22,7 +24,9 @@ def sample_data(): return model, X_train, X_test, y_train, y_test -@pytest.mark.parametrize("FeatureImportance", [ShapFeatureImportance, PFIFeatureImportance]) +@pytest.mark.parametrize("FeatureImportance", [ + ShapFeatureImportance, PFIFeatureImportance, ModelBasedFeatureImportance +]) def test_feature_importance(sample_data, FeatureImportance): model, X_train, X_test, y_train, y_test = sample_data importance = FeatureImportance(model, X_train, X_test, y_test) @@ -38,7 +42,9 @@ def test_feature_importance(sample_data, FeatureImportance): assert result['importance'].is_monotonic_decreasing -@pytest.mark.parametrize("FeatureImportance", [ShapFeatureImportance, PFIFeatureImportance]) +@pytest.mark.parametrize("FeatureImportance", [ + ShapFeatureImportance, PFIFeatureImportance, ModelBasedFeatureImportance +]) def test_feature_importance_random_seed(sample_data, FeatureImportance): model, X_train, X_test, y_train, y_test = sample_data importance1 = FeatureImportance( diff --git a/tests/test_sdqc_integration/test_sequential.py b/tests/test_sdqc_integration/test_sequential.py index 6ad2731..e15c7c7 100644 --- a/tests/test_sdqc_integration/test_sequential.py +++ b/tests/test_sdqc_integration/test_sequential.py @@ -50,7 +50,7 @@ def test_causal_analysis_step(sample_data): assert isinstance(synthetic_matrix, np.ndarray) -@pytest.mark.parametrize('explainability_algorithm', ['shap', 'pfi']) +@pytest.mark.parametrize('explainability_algorithm', ['model_based', 'shap', 'pfi']) def test_explainability_step(sample_data, explainability_algorithm): raw_data, synthetic_data = sample_data analysis = SequentialAnalysis(