Skip to content

Commit

Permalink
feat: Add model-based explainability
Browse files Browse the repository at this point in the history
  • Loading branch information
T0217 committed Sep 20, 2024
1 parent 139afc7 commit bbad788
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 10 deletions.
3 changes: 2 additions & 1 deletion src/sdqc_check/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .causality import CausalAnalysis
from .classification import ClassificationModel
from .explainability import (
ShapFeatureImportance, PFIFeatureImportance
ShapFeatureImportance, PFIFeatureImportance, ModelBasedFeatureImportance
)
from .statistical_test import (
data_corr, identify_data_types, CategoricalTest, NumericalTest
Expand All @@ -13,6 +13,7 @@
'ClassificationModel',
'ShapFeatureImportance',
'PFIFeatureImportance',
'ModelBasedFeatureImportance',
'data_corr',
'identify_data_types',
'CategoricalTest',
Expand Down
4 changes: 2 additions & 2 deletions src/sdqc_check/explainability/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from .shap import ShapFeatureImportance
from .pfi import PFIFeatureImportance
from .model_based import ModelBasedFeatureImportance

__all__ = [
'ShapFeatureImportance',
'PFIFeatureImportance',
'ModelBasedFeatureImportance',
]


45 changes: 45 additions & 0 deletions src/sdqc_check/explainability/model_based.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from .base import BaseFeatureImportance
import pandas as pd
import numpy as np


class ModelBasedFeatureImportance(BaseFeatureImportance):
"""
Model-Based Feature Importance calculator.
This class uses model-specific attributes to compute feature importance scores.
Parameters
----------
See base class BaseFeatureImportance for parameter details.
"""

def compute_feature_importance(self) -> pd.DataFrame:
"""
Compute Model-Based Feature Importance.
This method uses model-specific attributes to calculate feature importance.
Returns
-------
pd.DataFrame
A DataFrame with features and their model-based importance scores, sorted in descending order.
Raises
------
AttributeError:
If the model does not have the required attribute for feature importance.
"""
if hasattr(self.model, "feature_importances_"):
importances = self.model.feature_importances_
elif hasattr(self.model, "coef_"):
importances = np.abs(self.model.coef_[0])
else:
raise AttributeError(
"No attribute 'feature_importances_' or 'coef_' found for the model."
)
importance = pd.DataFrame({
'feature': self.X_train.columns,
'importance': importances
})
return importance.sort_values('importance', ascending=False)
8 changes: 5 additions & 3 deletions src/sdqc_integration/sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sdqc_check import (
CausalAnalysis,
ClassificationModel,
ModelBasedFeatureImportance,
ShapFeatureImportance,
PFIFeatureImportance,
data_corr,
Expand Down Expand Up @@ -272,7 +273,7 @@ def explainability_step(
Perform explainability analysis on the best model.
This method applies the specified explainability algorithm
(SHAP or PFI) to interpret the predictions of the best
(Model-Based, SHAP or PFI) to interpret the predictions of the best
performing model.
Parameters
Expand Down Expand Up @@ -308,21 +309,22 @@ def _get_explainability_class(self) -> Type:
Returns
-------
Type
The explainability class (SHAP or PFI) to be used.
The explainability class (Model-Based, SHAP or PFI) to be used.
Raises
------
ValueError
If the specified explainability algorithm is not supported.
"""
explainability_classes = {
'model_based': ModelBasedFeatureImportance,
'shap': ShapFeatureImportance,
'pfi': PFIFeatureImportance
}
if self.explainability_algorithm not in explainability_classes:
raise ValueError(
f"Algorithm {self.explainability_algorithm} not supported. "
"Please choose from 'shap' or 'pfi'."
"Please choose from 'model_based', 'shap' or 'pfi'."
)
return explainability_classes[self.explainability_algorithm]

Expand Down
12 changes: 9 additions & 3 deletions tests/test_sdqc_check/test_explainability/test_explainability.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sdqc_check import ShapFeatureImportance, PFIFeatureImportance
from sdqc_check import (
ShapFeatureImportance, PFIFeatureImportance, ModelBasedFeatureImportance
)


@pytest.fixture
Expand All @@ -22,7 +24,9 @@ def sample_data():
return model, X_train, X_test, y_train, y_test


@pytest.mark.parametrize("FeatureImportance", [ShapFeatureImportance, PFIFeatureImportance])
@pytest.mark.parametrize("FeatureImportance", [
ShapFeatureImportance, PFIFeatureImportance, ModelBasedFeatureImportance
])
def test_feature_importance(sample_data, FeatureImportance):
model, X_train, X_test, y_train, y_test = sample_data
importance = FeatureImportance(model, X_train, X_test, y_test)
Expand All @@ -38,7 +42,9 @@ def test_feature_importance(sample_data, FeatureImportance):
assert result['importance'].is_monotonic_decreasing


@pytest.mark.parametrize("FeatureImportance", [ShapFeatureImportance, PFIFeatureImportance])
@pytest.mark.parametrize("FeatureImportance", [
ShapFeatureImportance, PFIFeatureImportance, ModelBasedFeatureImportance
])
def test_feature_importance_random_seed(sample_data, FeatureImportance):
model, X_train, X_test, y_train, y_test = sample_data
importance1 = FeatureImportance(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_sdqc_integration/test_sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_causal_analysis_step(sample_data):
assert isinstance(synthetic_matrix, np.ndarray)


@pytest.mark.parametrize('explainability_algorithm', ['shap', 'pfi'])
@pytest.mark.parametrize('explainability_algorithm', ['model_based', 'shap', 'pfi'])
def test_explainability_step(sample_data, explainability_algorithm):
raw_data, synthetic_data = sample_data
analysis = SequentialAnalysis(
Expand Down

0 comments on commit bbad788

Please sign in to comment.