Commit 62f45b4
feat: final version of shap explanations
ireneisdoomed committed Feb 13, 2025
1 parent 37b83ac commit 62f45b4
Showing 3 changed files with 22 additions and 109 deletions.
12 changes: 6 additions & 6 deletions src/gentropy/assets/schemas/l2g_predictions.json
@@ -44,18 +44,18 @@
"name": "shapValue",
"nullable": true,
"type": "float"
},
{
"metadata": {},
"name": "scaledProbability",
"nullable": true,
"type": "float"
}
],
"type": "struct"
},
"type": "array"
}
},
{
"name": "shapBaseValue",
"type": "float",
"nullable": true,
"metadata": {}
}
]
}
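With this change, each entry in the features array keeps only its raw shapValue, and the explainer's expected value is stored once per prediction in the new top-level shapBaseValue field. Because the SHAP values are now computed directly in probability space (see the l2g_prediction.py diff below), the base value plus the per-feature contributions should approximately reconstruct the prediction score. A minimal sketch of that check (not part of this commit; the parquet path is hypothetical, and the score column name follows the rest of the predictions schema):

# Sketch only: inspect L2G predictions under the updated schema.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
predictions = spark.read.parquet("l2g_predictions.parquet")  # hypothetical path

# shapBaseValue + sum(features[i].shapValue) should approximate `score`,
# since contributions are now expressed on the probability scale.
predictions.selectExpr(
    "score",
    "shapBaseValue",
    "aggregate(features, cast(0.0 as double), (acc, x) -> acc + x.shapValue) as shapSum",
).show(5)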
91 changes: 16 additions & 75 deletions src/gentropy/dataset/l2g_prediction.py
@@ -10,7 +10,6 @@
 import shap
 from pyspark.sql import DataFrame
 from pyspark.sql.types import StructType
-from scipy.special import expit
 
 from gentropy.common.schemas import parse_spark_schema
 from gentropy.common.session import Session
@@ -199,13 +198,10 @@ def explain(
         for i, feature in enumerate(features_list):
             pdf[f"shap_{feature}"] = [row[i] for row in shap_values]
 
-        # Normalise feature contributions so they sum to final probability
-        scaled_pdf = L2GPrediction._normalise_feature_contributions(pdf, base_value)
-
         spark_session = self.df.sparkSession
         return L2GPrediction(
             _df=(
-                spark_session.createDataFrame(scaled_pdf.to_dict(orient="records"))
+                spark_session.createDataFrame(pdf.to_dict(orient="records"))
                 .withColumn(
                     "features",
                     f.array(
@@ -216,14 +212,12 @@
f.col(f"shap_{feature}")
.cast("float")
.alias("shapValue"),
f.col(f"scaled_prob_shap_{feature}")
.cast("float")
.alias("scaledProbability"),
)
for feature in features_list
)
),
)
.withColumn("shapBaseValue", f.lit(base_value).cast("float"))
.select(*L2GPrediction.get_schema().names)
),
_schema=self.get_schema(),
@@ -234,7 +228,7 @@ def explain(
     def _explain(
         model: LocusToGeneModel, pdf: pd_dataframe
     ) -> tuple[float, list[list[float]]]:
-        """Calculate SHAP values. Output is log odds ratio (raw mode).
+        """Calculate SHAP values. Output is in probability form (approximated from the log odds ratios).
 
         Args:
             model (LocusToGeneModel): L2G model
@@ -244,10 +238,21 @@
             tuple[float, list[list[float]]]: A tuple containing:
                 - base_value (float): Base value of the model
                 - shap_values (list[list[float]]): SHAP values for prediction
+
+        Raises:
+            AttributeError: If model.training_data is not set, seed dataset to get shapley values cannot be created.
         """
+        if not model.training_data:
+            raise AttributeError(
+                "`model.training_data` is missing, seed dataset to get shapley values cannot be created."
+            )
+        background_data = model.training_data._df.select(
+            *model.features_list
+        ).toPandas()
         explainer = shap.TreeExplainer(
             model.model,
-            feature_perturbation="tree_path_dependent",
+            data=background_data,
+            model_output="probability",
         )
         if pdf.shape[0] >= 10_000:
             logging.warning(
@@ -257,73 +262,9 @@
             pdf.to_numpy(),
             check_additivity=False,
         )
-        base_value = expit(explainer.expected_value[0])
+        base_value = explainer.expected_value
         return (base_value, shap_values)

-    @staticmethod
-    def _normalise_feature_contributions(
-        pdf: pd_dataframe, base_log_odds: float
-    ) -> pd_dataframe:
-        """Normalize SHAP contributions to probability space while preserving directionality.
-
-        Args:
-            pdf (pd_dataframe): Input dataframe with SHAP values and scores
-            base_log_odds (float): Base log-odds from the SHAP explainer
-
-        Returns:
-            pd_dataframe: Output dataframe with normalized probability contributions
-        """
-        # Calculate base probability and sigmoid derivative
-        prob_base = expit(base_log_odds)
-        sigmoid_slope = prob_base * (1 - prob_base)  # Derivative at base log-odds
-
-        # ----------------------------------
-        # 1. Linear Approximation Phase
-        # ----------------------------------
-        # Convert SHAP values to directional probability deltas
-        shap_cols = [col for col in pdf.columns if col.startswith("shap_")]
-        linear_deltas = pdf[shap_cols] * sigmoid_slope
-
-        # ----------------------------------
-        # 2. Base Probability Distribution
-        # ----------------------------------
-        # Calculate total absolute SHAP magnitude per row
-        total_abs_shap = (
-            pdf[shap_cols].abs().sum(axis=1).replace(0, 1)  # Avoid division by zero
-        )
-
-        # Distribute base probability proportionally to SHAP magnitudes
-        base_distribution = (
-            pdf[shap_cols].abs().div(total_abs_shap, axis=0).mul(prob_base, axis=0)
-        )
-
-        # ----------------------------------
-        # 3. Contribution Scaling Phase
-        # ----------------------------------
-        # Calculate required probability adjustment
-        target_diff = pdf["score"] - prob_base
-
-        # Calculate scaling factor for linear deltas
-        raw_delta_sum = linear_deltas.sum(axis=1).replace(
-            0, 1
-        )  # Avoid division by zero
-        scaling_factor = target_diff / raw_delta_sum
-
-        # Scale deltas to match target probability difference
-        scaled_deltas = linear_deltas.mul(scaling_factor, axis=0)
-
-        # ----------------------------------
-        # 4. Final Contribution Calculation
-        # ----------------------------------
-        # Combine base distribution and scaled deltas
-        final_contributions = base_distribution + scaled_deltas
-
-        # Assign results to new columns
-        for col in shap_cols:
-            feature_name = col.replace("shap_", "")
-            pdf[f"scaled_prob_shap_{feature_name}"] = final_contributions[col]
-        return pdf
-
     def add_features(
         self: L2GPrediction,
         feature_matrix: L2GFeatureMatrix,
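Taken together, the l2g_prediction.py changes drop the post-hoc redistribution in _normalise_feature_contributions in favour of SHAP values computed natively in probability space: the explainer is now built with a background dataset taken from the model's training data, and model_output="probability" makes the base value plus contributions additive on the probability scale. A self-contained sketch of that pattern (the toy model and synthetic data are illustrative assumptions, not gentropy code):

# Sketch only: interventional SHAP in probability space.
import numpy as np
import shap
from sklearn.ensemble import GradientBoostingClassifier

rng = np.random.default_rng(42)
X = rng.normal(size=(500, 4))
y = (X[:, 0] + 0.5 * X[:, 1] > 0).astype(int)
model = GradientBoostingClassifier().fit(X, y)

# Passing a background dataset switches TreeExplainer to interventional
# feature perturbation, which is what model_output="probability" requires.
explainer = shap.TreeExplainer(model, data=X, model_output="probability")

shap_values = explainer.shap_values(X[:5])
base_value = float(np.atleast_1d(explainer.expected_value)[0])  # scalar or 1-element array, depending on shap version

# Additivity holds on the probability scale: base value plus per-feature
# contributions recovers predict_proba for each row.
reconstructed = base_value + shap_values.sum(axis=1)
print(np.allclose(reconstructed, model.predict_proba(X[:5])[:, 1]))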
28 changes: 0 additions & 28 deletions tests/gentropy/dataset/test_l2g_prediction.py

This file was deleted.
