diff --git a/src/gentropy/assets/schemas/l2g_predictions.json b/src/gentropy/assets/schemas/l2g_predictions.json
index 2043b3bee..1d100bf94 100644
--- a/src/gentropy/assets/schemas/l2g_predictions.json
+++ b/src/gentropy/assets/schemas/l2g_predictions.json
@@ -44,18 +44,18 @@
               "name": "shapValue",
               "nullable": true,
               "type": "float"
-            },
-            {
-              "metadata": {},
-              "name": "scaledProbability",
-              "nullable": true,
-              "type": "float"
             }
           ],
           "type": "struct"
         },
         "type": "array"
       }
+    },
+    {
+      "name": "shapBaseValue",
+      "type": "float",
+      "nullable": true,
+      "metadata": {}
     }
   ]
 }
diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py
index 973baf382..a04a4507a 100644
--- a/src/gentropy/dataset/l2g_prediction.py
+++ b/src/gentropy/dataset/l2g_prediction.py
@@ -10,7 +10,6 @@
 import shap
 from pyspark.sql import DataFrame
 from pyspark.sql.types import StructType
-from scipy.special import expit
 
 from gentropy.common.schemas import parse_spark_schema
 from gentropy.common.session import Session
@@ -199,13 +198,10 @@ def explain(
         for i, feature in enumerate(features_list):
             pdf[f"shap_{feature}"] = [row[i] for row in shap_values]
 
-        # Normalise feature contributions so they sum to final probability
-        scaled_pdf = L2GPrediction._normalise_feature_contributions(pdf, base_value)
-
         spark_session = self.df.sparkSession
         return L2GPrediction(
             _df=(
-                spark_session.createDataFrame(scaled_pdf.to_dict(orient="records"))
+                spark_session.createDataFrame(pdf.to_dict(orient="records"))
                 .withColumn(
                     "features",
                     f.array(
@@ -216,14 +212,12 @@
                                 f.col(f"shap_{feature}")
                                 .cast("float")
                                 .alias("shapValue"),
-                                f.col(f"scaled_prob_shap_{feature}")
-                                .cast("float")
-                                .alias("scaledProbability"),
                             )
                             for feature in features_list
                         )
                     ),
                 )
+                .withColumn("shapBaseValue", f.lit(base_value).cast("float"))
                 .select(*L2GPrediction.get_schema().names)
             ),
             _schema=self.get_schema(),
@@ -234,7 +228,7 @@
     def _explain(
         model: LocusToGeneModel, pdf: pd_dataframe
     ) -> tuple[float, list[list[float]]]:
-        """Calculate SHAP values. Output is log odds ratio (raw mode).
+        """Calculate SHAP values. Output is in probability form (approximated from the log odds ratios).
 
         Args:
             model (LocusToGeneModel): L2G model
@@ -244,10 +238,21 @@
             pdf (pd_dataframe): Pandas dataframe containing the feature matrix
 
         Returns:
             tuple[float, list[list[float]]]: A tuple containing:
                 - base_value (float): Base value of the model
                 - shap_values (list[list[float]]): SHAP values for prediction
+
+        Raises:
+            AttributeError: If model.training_data is not set, seed dataset to get shapley values cannot be created.
         """
+        if not model.training_data:
+            raise AttributeError(
+                "`model.training_data` is missing, seed dataset to get shapley values cannot be created."
+            )
+        background_data = model.training_data._df.select(
+            *model.features_list
+        ).toPandas()
         explainer = shap.TreeExplainer(
             model.model,
-            feature_perturbation="tree_path_dependent",
+            data=background_data,
+            model_output="probability",
         )
         if pdf.shape[0] >= 10_000:
             logging.warning(
@@ -257,73 +262,9 @@
             pdf.to_numpy(),
             check_additivity=False,
         )
-        base_value = expit(explainer.expected_value[0])
+        base_value = explainer.expected_value
         return (base_value, shap_values)
 
-    @staticmethod
-    def _normalise_feature_contributions(
-        pdf: pd_dataframe, base_log_odds: float
-    ) -> pd_dataframe:
-        """Normalize SHAP contributions to probability space while preserving directionality.
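Note on the explainer change above: with a background dataset and model_output="probability", interventional Tree SHAP produces contributions that, together with the base value, already sum to the model's predicted probability. That additivity is what makes the bespoke _normalise_feature_contributions rescaling (and the scaledProbability column) redundant, replaced by a single shapBaseValue. A minimal, self-contained sketch of the property, using a stand-in GradientBoostingClassifier on synthetic data rather than the actual LocusToGeneModel:

import numpy as np
import shap
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier

# Stand-in classifier and data; the patch uses model.model and
# model.training_data instead.
X, y = make_classification(n_samples=500, n_features=4, random_state=42)
model = GradientBoostingClassifier(random_state=42).fit(X, y)

# Background data plays the role of the training data selected in the patch.
background = X[:100]

explainer = shap.TreeExplainer(
    model,
    data=background,             # enables interventional perturbation
    model_output="probability",  # SHAP values expressed in probability space
)
explained = X[:5]
shap_values = explainer.shap_values(explained)
base_value = explainer.expected_value  # stored as shapBaseValue in the patch

# Additivity: base value plus per-feature contributions reconstructs the
# predicted probability (up to numerical tolerance).
reconstructed = base_value + shap_values.sum(axis=1)
predicted = model.predict_proba(explained)[:, 1]
assert np.allclose(reconstructed, predicted, atol=1e-2)

The check_additivity=False flag kept in the patch skips shap's built-in version of this same assertion, trading the safety check for speed on large feature matrices.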
-
-        Args:
-            pdf (pd_dataframe): Input dataframe with SHAP values and scores
-            base_log_odds (float): Base log-odds from the SHAP explainer
-
-        Returns:
-            pd_dataframe: Output dataframe with normalized probability contributions
-        """
-        # Calculate base probability and sigmoid derivative
-        prob_base = expit(base_log_odds)
-        sigmoid_slope = prob_base * (1 - prob_base)  # Derivative at base log-odds
-
-        # ----------------------------------
-        # 1. Linear Approximation Phase
-        # ----------------------------------
-        # Convert SHAP values to directional probability deltas
-        shap_cols = [col for col in pdf.columns if col.startswith("shap_")]
-        linear_deltas = pdf[shap_cols] * sigmoid_slope
-
-        # ----------------------------------
-        # 2. Base Probability Distribution
-        # ----------------------------------
-        # Calculate total absolute SHAP magnitude per row
-        total_abs_shap = (
-            pdf[shap_cols].abs().sum(axis=1).replace(0, 1)  # Avoid division by zero
-        )
-
-        # Distribute base probability proportionally to SHAP magnitudes
-        base_distribution = (
-            pdf[shap_cols].abs().div(total_abs_shap, axis=0).mul(prob_base, axis=0)
-        )
-
-        # ----------------------------------
-        # 3. Contribution Scaling Phase
-        # ----------------------------------
-        # Calculate required probability adjustment
-        target_diff = pdf["score"] - prob_base
-
-        # Calculate scaling factor for linear deltas
-        raw_delta_sum = linear_deltas.sum(axis=1).replace(
-            0, 1
-        )  # Avoid division by zero
-        scaling_factor = target_diff / raw_delta_sum
-
-        # Scale deltas to match target probability difference
-        scaled_deltas = linear_deltas.mul(scaling_factor, axis=0)
-
-        # ----------------------------------
-        # 4. Final Contribution Calculation
-        # ----------------------------------
-        # Combine base distribution and scaled deltas
-        final_contributions = base_distribution + scaled_deltas
-
-        # Assign results to new columns
-        for col in shap_cols:
-            feature_name = col.replace("shap_", "")
-            pdf[f"scaled_prob_shap_{feature_name}"] = final_contributions[col]
-        return pdf
-
     def add_features(
         self: L2GPrediction,
         feature_matrix: L2GFeatureMatrix,
diff --git a/tests/gentropy/dataset/test_l2g_prediction.py b/tests/gentropy/dataset/test_l2g_prediction.py
deleted file mode 100644
index 7a150f519..000000000
--- a/tests/gentropy/dataset/test_l2g_prediction.py
+++ /dev/null
@@ -1,28 +0,0 @@
-"""Test L2G Prediction methods."""
-
-import numpy as np
-import pandas as pd
-
-from gentropy.dataset.l2g_prediction import L2GPrediction
-
-
-def test_normalise_feature_contributions() -> None:
-    """Tests that scaled probabilities per feature add up to the probability inferred by the model."""
-    df = pd.DataFrame(
-        {
-            "score": [0.45],  # Final probability
-            "shap_feature1": [-3.85],
-            "shap_feature2": [3.015],
-            "shap_feature3": [0.063],
-        }
-    )
-    base_log_odds = 0.56
-    scaled_df = L2GPrediction._normalise_feature_contributions(df, base_log_odds)
-    reconstructed_prob = (
-        scaled_df["scaled_prob_shap_feature1"].sum()
-        + scaled_df["scaled_prob_shap_feature2"].sum()
-        + scaled_df["scaled_prob_shap_feature3"].sum()
-    )
-    assert np.allclose(
-        reconstructed_prob, df["score"], atol=1e-6
-    ), "SHAP probability contributions do not sum to the expected probability."
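The deleted test above covered only the removed rescaling helper. A hypothetical replacement (not part of this patch) could assert the same end-to-end property directly against the new shapBaseValue column. The column names mirror the flattened explain() output; the numbers are illustrative values chosen to satisfy additivity exactly:

"""Sketch of a possible follow-up test for probability-space SHAP output."""

import numpy as np
import pandas as pd


def test_shap_values_sum_to_score() -> None:
    """shapBaseValue plus the per-feature SHAP values should reconstruct the score."""
    # A single explained prediction, shaped like one flattened explain() row.
    row = pd.Series(
        {
            "score": 0.45,  # Final probability predicted by the model
            "shapBaseValue": 0.27,
            "shap_feature1": 0.10,
            "shap_feature2": 0.05,
            "shap_feature3": 0.03,
        }
    )
    shap_cols = [col for col in row.index if col.startswith("shap_")]
    reconstructed = row["shapBaseValue"] + row[shap_cols].sum()
    # The patch passes check_additivity=False in production, so a looser
    # tolerance may be needed against real model output.
    assert np.isclose(reconstructed, row["score"], atol=1e-6)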