An End-to-End Guide to Model Explainability

Priyanka Last Updated : 16 Oct, 2024

This article was published as a part of the Data Science Blogathon.

In this article, we will learn about model explainability and the different ways to interpret a machine learning model.

What is Model Explainability?

Model explainability refers to being able to understand how a machine learning model arrives at its predictions. For example, if a healthcare model predicts whether a patient is suffering from a particular disease, medical practitioners need to know which parameters the model takes into account and whether the model contains any bias. It is therefore necessary that, once the model is deployed in the real world, its developers can explain it.

Why is Model Explainability required?

  1. Being able to interpret a model increases trust in it. This becomes all the more important in high-stakes domains such as healthcare, law, and credit lending. For example, if a model is predicting cancer, the healthcare providers should be aware of the variables the model uses.

  2. Once we understand a model, we can detect if there is any bias present in the model. For example – If a healthcare model has been trained on the American population, it might not be suitable for Asian people.

  3. Model Explainability becomes important while debugging a model during the development phase.

  4. Model Explainability is critical for getting models vetted by regulatory authorities such as the Food and Drug Administration (FDA), national regulatory authorities, etc. It also helps to determine whether the models are suitable for deployment in real life.

How to develop Model Understanding?

Here we have two options at our disposal:

Option 1: Build models that are inherently interpretable – Glass Box Models.

For example – In a linear regression model of the form y = b0 + b1*x, we know that when x increases by one unit, y changes by b1 units, keeping other factors constant.
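
To make the glass-box idea concrete, here is a minimal sketch that fits y = b0 + b1*x with ordinary least squares and checks that a one-unit increase in x moves the prediction by b1. The synthetic data (slope 2.5, intercept 1.0) and the helper name `predict` are assumptions made purely for illustration.

```python
import numpy as np

# Hypothetical synthetic data: y depends linearly on x (slope 2.5, intercept 1.0) plus noise.
rng = np.random.default_rng(0)
x = rng.uniform(0, 10, size=200)
y = 1.0 + 2.5 * x + rng.normal(0, 0.1, size=200)

# Fit y = b0 + b1*x by ordinary least squares.
design = np.column_stack([np.ones_like(x), x])
(b0, b1), *_ = np.linalg.lstsq(design, y, rcond=None)

def predict(x_value):
    """Prediction of the fitted linear model."""
    return b0 + b1 * x_value

# Raising x by one unit changes the prediction by b1, whatever the starting point.
delta = predict(4.0) - predict(3.0)
print(f"b0={b0:.2f}, b1={b1:.2f}, change in y for +1 in x={delta:.2f}")
```

The coefficient itself is the explanation here, which is why no separate interpretation method is needed for such a model.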

Option 2: Post-hoc explanation of pre-built models – Black Box Models

For example – In a deep learning model, the model developers are not aware of how the input variables have combined to produce a particular output.

 
Glass Box Models                          | Black Box Models
Simple                                    | Complex
Interpretable                             | Not easily interpretable
Often lower accuracy on complex problems  | Often higher accuracy on complex problems
Examples – Linear Models, Decision Tree   | Examples – Random Forest, Deep Learning

Ways to interpret a Model

There are two ways to interpret a model: global interpretation and local interpretation.

Global Interpretation | Local Interpretation
Helps in understanding how the model makes decisions overall | Helps in understanding how the model makes a decision for a single instance
Explains the complete behavior of the model | Explains individual predictions
Helps in understanding the suitability of the model for deployment | Helps in understanding the behavior of the model in a local neighborhood
Example – Understanding how the model predicts the risk of disease across patients | Example – Understanding why a specific person has a high risk of a disease

Local Interpretation

We will discuss the following methods of local interpretation:

  • LIME (Local Interpretable Model-agnostic Explanations)
  • SHAP (SHapley Additive exPlanations)

 

LIME (Local Interpretable Model-Agnostic Explanations)

LIME provides a local interpretation by modifying the feature values of a single data sample and observing the impact on the output. It builds a surrogate model from the generated samples and the corresponding model predictions; any interpretable model can serve as the surrogate. Because LIME is model-agnostic, it can be used with any model.

Steps involved in LIME:

  1. It creates perturbed (synthetic) variations of the given observation.
  2. It calculates the distance between the perturbed samples and the original observation. We can specify the distance measure to use.
  3. It makes predictions on the perturbed data using the black-box model.
  4. It picks the “m” features that best describe the complex model’s output on the perturbed data. Here, we decide the number of features, i.e. the value of “m” we want to use.
  5. It fits a simple model to the perturbed data using those “m” features, with the similarity scores (derived from the distances) as weights.
  6. The weights from the simple model are used to provide explanations for the complex model’s local behavior.
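
The steps above can be sketched in a few lines of NumPy. This is a simplified illustration, not the `lime` library's implementation: it assumes Gaussian perturbations, Euclidean distance, an exponential kernel, and it skips feature selection by keeping every feature (m equals the number of features). The `black_box` function and `lime_sketch` helper are hypothetical names.

```python
import numpy as np

def black_box(X):
    # Hypothetical stand-in for a complex model: a nonlinear score in [0, 1].
    return 1.0 / (1.0 + np.exp(-(2.0 * X[:, 0] - X[:, 1] ** 2)))

def lime_sketch(instance, predict_fn, rng, n_samples=5000, scale=0.3, kernel_width=0.75):
    # Step 1: perturb the observation (assumption: Gaussian noise around it).
    samples = instance + rng.normal(0.0, scale, size=(n_samples, instance.size))
    # Step 2: distance to the original observation, turned into a similarity score.
    distances = np.linalg.norm(samples - instance, axis=1)
    weights = np.exp(-(distances ** 2) / kernel_width ** 2)
    # Step 3: query the black-box model on the perturbed data.
    preds = predict_fn(samples)
    # Steps 4-5: fit a weighted linear surrogate (all features kept in this sketch).
    design = np.column_stack([np.ones(n_samples), samples - instance])
    sqrt_w = np.sqrt(weights)
    coef, *_ = np.linalg.lstsq(design * sqrt_w[:, None], preds * sqrt_w, rcond=None)
    # Step 6: the surrogate's coefficients describe the local behavior.
    return coef[0], coef[1:]

rng = np.random.default_rng(0)
instance = np.array([0.5, 1.0])
intercept, local_coefs = lime_sketch(instance, black_box, rng)
print("local intercept:", round(float(intercept), 3))
print("local feature weights:", np.round(local_coefs, 3))
```

Around this instance the surrogate assigns a positive weight to the first feature and a negative weight to the second, matching the direction in which each one pushes the black-box score locally.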

SHAP (SHapley Additive exPlanations)

SHAP shows the impact of each feature by comparing the effect of its actual value against a baseline. The baseline is the average of the model's predictions. SHAP values let us express any prediction as the baseline plus the sum of the effects of each feature value.

The main disadvantage of SHAP is that computing Shapley values can be expensive. The Shapley values of individual predictions can also be aggregated to perform global interpretation.
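
The additive property can be checked with a brute-force Shapley computation on a tiny model. This is only a sketch under stated assumptions: the toy `model`, the background data, and the value function (features outside a coalition are replaced by background samples) are illustrative choices, and enumerating every coalition is feasible only for a handful of features. The `shap` library relies on much faster algorithms.

```python
import itertools
import math
import numpy as np

def model(X):
    # Hypothetical model: one linear term plus an interaction term.
    return 3.0 * X[:, 0] + 2.0 * X[:, 1] * X[:, 2]

def value(subset, x, background):
    # Average prediction when the features in `subset` are fixed to x
    # and the remaining features take their background values.
    data = background.copy()
    idx = list(subset)
    if idx:
        data[:, idx] = x[idx]
    return model(data).mean()

def shapley_values(x, background):
    n = x.size
    phi = np.zeros(n)
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            for subset in itertools.combinations(others, size):
                weight = (math.factorial(size) * math.factorial(n - size - 1)
                          / math.factorial(n))
                gain = value(subset + (i,), x, background) - value(subset, x, background)
                phi[i] += weight * gain
    return phi

rng = np.random.default_rng(0)
background = rng.normal(size=(50, 3))
x = np.array([1.0, 2.0, -1.0])

phi = shapley_values(x, background)
baseline = model(background).mean()      # average prediction
prediction = model(x[None, :])[0]
print("baseline + sum of effects:", baseline + phi.sum())
print("model prediction:         ", prediction)
```

The two printed numbers agree, which is the "prediction as a sum of feature effects" property described above. The nested loops also show why the cost grows quickly with the number of features.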

Global Interpretation

We will discuss the following methods of global interpretation:

  • PDP (Partial Dependence Plot)
  • ICE (Individual Conditional Expectation)

PDP (Partial Dependence Plot)

PDP explains the global behavior of a model by showing the marginal effect of a predictor on the predicted response, averaged over the data.

It shows the relationship between the target variable and a feature variable. Such a relationship could be complex, monotonic, or a simple linear one. The plot assumes that the feature of interest (whose partial dependence is being computed) is not highly correlated with the other features. If the features of the model are correlated, PDP can give a misleading interpretation because it averages over unrealistic feature combinations. PDP can also be expensive to compute for complex models such as neural networks.
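
The averaging behind a PDP can be written out by hand. The sketch below assumes a hypothetical `toy_model` and random data. For each grid value it overwrites the feature of interest in every row, predicts, and averages the predictions.

```python
import numpy as np

def partial_dependence_curve(predict_fn, X, feature_idx, grid):
    """Average prediction when one feature is forced to each grid value."""
    curve = []
    for v in grid:
        X_mod = X.copy()
        X_mod[:, feature_idx] = v      # overwrite the feature for every row
        curve.append(predict_fn(X_mod).mean())
    return np.array(curve)

def toy_model(X):
    # Hypothetical model: linear in feature 0, quadratic in feature 1.
    return 2.0 * X[:, 0] + X[:, 1] ** 2

rng = np.random.default_rng(0)
X = rng.normal(size=(300, 2))
grid = np.linspace(-2, 2, 5)

pdp = partial_dependence_curve(toy_model, X, feature_idx=0, grid=grid)
for v, p in zip(grid, pdp):
    print(f"feature 0 = {v:+.1f} -> average prediction {p:.3f}")
```

The resulting curve is a straight line with slope 2, the marginal effect of feature 0 in this toy model. Because each row keeps its other feature values unchanged, strongly correlated features would produce unrealistic rows, which is the limitation noted above.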

ICE (Individual Conditional Expectation)

ICE is an extension of PDP (a global method), but ICE plots are more intuitive to understand. Using ICE, we can explain heterogeneous relationships that the PDP average hides. While PDP supports explanations of two features at once, with ICE we can explain only one feature at a time.

ICE plots the predicted outcome for each individual instance across different values of a feature while keeping the values of the other features constant. Averaging these per-instance curves gives the PDP.
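
A small sketch shows why ICE can reveal what a PDP hides. It assumes a hypothetical `toy_model` in which the effect of feature 0 flips sign depending on feature 1. The data is constructed so that half of the rows fall on each side.

```python
import numpy as np

def ice_curves(predict_fn, X, feature_idx, grid):
    """One curve per row: predictions as one feature sweeps the grid."""
    curves = np.empty((X.shape[0], grid.size))
    for k, v in enumerate(grid):
        X_mod = X.copy()
        X_mod[:, feature_idx] = v
        curves[:, k] = predict_fn(X_mod)
    return curves

def toy_model(X):
    # Hypothetical model: feature 0 helps or hurts depending on the sign of feature 1.
    return X[:, 0] * np.sign(X[:, 1])

rng = np.random.default_rng(0)
X = np.column_stack([rng.normal(size=100), np.tile([1.0, -1.0], 50)])
grid = np.linspace(-2, 2, 5)

curves = ice_curves(toy_model, X, feature_idx=0, grid=grid)
pdp = curves.mean(axis=0)                                  # averaging ICE curves gives the PDP
slopes = (curves[:, -1] - curves[:, 0]) / (grid[-1] - grid[0])

print("PDP values:", pdp)
print("distinct ICE slopes:", np.unique(slopes))
```

The PDP is flat, suggesting feature 0 has no effect, while the individual curves split into two groups with opposite slopes. This is the kind of heterogeneous relationship ICE is meant to expose.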

Hands-on learning model explainability methods

We will explore the different model interpretation methods using the famous “Pima Indians Diabetes Database”  to predict whether a patient has diabetes or not.

Dataset can be downloaded here.

Python Code:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt  # used later for the PDP and ICE plots

import warnings
warnings.filterwarnings('ignore')

# index_col=[0] uses the first CSV column as the row index,
# so that column is not available as a model feature below
diabetes_df = pd.read_csv( "diabetes.csv", index_col=[0] )

print(diabetes_df.head())
diabetes_df.info()
<class 'pandas.core.frame.DataFrame'>
Int64Index: 768 entries, 6 to 1
Data columns (total 8 columns):
 #   Column                    Non-Null Count  Dtype  
---  ------                    --------------  -----  
 0   Glucose                   768 non-null    int64  
 1   BloodPressure             768 non-null    int64  
 2   SkinThickness             768 non-null    int64  
 3   Insulin                   768 non-null    int64  
 4   BMI                       768 non-null    float64
 5   DiabetesPedigreeFunction  768 non-null    float64
 6   Age                       768 non-null    int64  
 7   Outcome                   768 non-null    int64  
dtypes: float64(2), int64(6)
memory usage: 54.0 KB

Select the Feature Columns

X_features = list( diabetes_df.columns )
X_features.remove( "Outcome" )

Split the Dataset

from sklearn.model_selection import train_test_split
X_train, X_test, \
y_train, y_test = train_test_split( diabetes_df[X_features],
                                    diabetes_df.Outcome,
                                    test_size = 0.3,
                                    random_state = 100 )
X_train.shape
(537, 7)
X_test.shape
(231, 7)

Build a RandomForest Model

from sklearn.ensemble import RandomForestClassifier

rf_clf = RandomForestClassifier( n_estimators = 100,
                                 max_features = 0.2,
                                 max_depth = 10,
                                 max_samples = 0.5)
rf_clf.fit(X_train, y_train)
RandomForestClassifier(max_depth=10, max_features=0.2, max_samples=0.5)
y_pred_prob = rf_clf.predict_proba( X_test )[:,1]
y_pred = rf_clf.predict( X_test )

Understanding the Model using ELI5

  • Install ELI5
!pip install eli5
Requirement already satisfied: eli5 in /usr/local/lib/python3.7/dist-packages (0.11.0)
import eli5.sklearn
eli5.explain_weights( 
    rf_clf, 
    feature_names = X_features )
Weight             Feature
0.2326 ± 0.1904    Glucose
0.1686 ± 0.1484    BMI
0.1480 ± 0.1107    DiabetesPedigreeFunction
0.1445 ± 0.1298    Age
0.1214 ± 0.0904    BloodPressure
0.1004 ± 0.0842    SkinThickness
0.0844 ± 0.0768    Insulin

Explaining the prediction for a Non-diabetes case

X_test.iloc[0]
Glucose                     79.000
BloodPressure               60.000
SkinThickness               42.000
Insulin                     48.000
BMI                         43.500
DiabetesPedigreeFunction     0.678
Age                         23.000
Name: 1, dtype: float64
y_test.iloc[0]
0
eli5.explain_prediction( rf_clf,
                         X_test.iloc[0],
                         target_names = ['Non-diabetes', 'Diabetes'] )

y=Non-diabetes (probability 0.869) top features

Contribution    Feature
+0.656          <BIAS>
+0.162          Glucose
+0.106          Insulin
+0.081          Age
+0.047          BloodPressure
-0.049          DiabetesPedigreeFunction
-0.066          SkinThickness
-0.067          BMI

Explaining the prediction for a borderline Non-diabetes case

X_test.iloc[4]
Glucose                     123.000
BloodPressure                70.000
SkinThickness                44.000
Insulin                      94.000
BMI                          33.100
DiabetesPedigreeFunction      0.374
Age                          40.000
Name: 9, dtype: float64
y_test.iloc[4]
0
eli5.explain_prediction( rf_clf,
                         X_test.iloc[4],
                         target_names = ['Non-diabetes', 'Diabetes'] )

y=Non-diabetes (probability 0.524) top features

Contribution    Feature
+0.656          <BIAS>
+0.072          Insulin
+0.013          DiabetesPedigreeFunction
+0.002          Glucose
-0.011          BloodPressure
-0.017          BMI
-0.067          SkinThickness
-0.125          Age

Partial Dependence Plots (PDPs)

import matplotlib.pyplot as plt  # needed for plt.subplots below
from sklearn.inspection import PartialDependenceDisplay

Effect of Insulin on Diabetes

fig, ax = plt.subplots(figsize=(12, 6))
ax.set_title("Partial Dependency Plot")
PartialDependenceDisplay.from_estimator(rf_clf,
                                        X_test,
                                        features = ['Insulin'],
                                        feature_names = X_features,
                                        ax = ax);

Effect of Glucose on Diabetes

fig, ax = plt.subplots(figsize=(12, 6))
ax.set_title("Partial Dependency Plot")
PartialDependenceDisplay.from_estimator(rf_clf,
                                        X_test,
                                        features = ['Glucose'],
                                        feature_names = X_features,
                                        ax = ax)
<sklearn.inspection._plot.partial_dependence.PartialDependenceDisplay at 0x7f1e39141e10>

LIME – Local Interpretable Model-agnostic Explanations

!pip install lime
Requirement already satisfied: lime in /usr/local/lib/python3.7/dist-packages (0.2.0.1)
import lime
import lime.lime_tabular
explainer = (lime
             .lime_tabular
             .LimeTabularExplainer(training_data = X_train.to_numpy(), 
                                   training_labels = y_train,                                   
                                   feature_names = X_features, 
                                   class_names = ['Non-diabetes','Diabetes'],
                                   kernel_width=3,
                                   verbose = True ))

Explaining a case of Non-Diabetes

X_test.iloc[0]
Glucose                     79.000
BloodPressure               60.000
SkinThickness               42.000
Insulin                     48.000
BMI                         43.500
DiabetesPedigreeFunction     0.678
Age                         23.000
Name: 1, dtype: float64
exp = explainer.explain_instance( X_test.iloc[0].to_numpy(), 
                                  rf_clf.predict_proba )
Intercept 0.3836565649244127
Prediction_local [0.32668078]
Right: 0.13058823529411764
exp.show_in_notebook(show_table=True, show_all=False)

Explaining a borderline case (row 4, true label Non-diabetes)

exp = explainer.explain_instance( X_test.iloc[4].to_numpy(), 
                                  rf_clf.predict_proba )
Intercept 0.3299430371346654
Prediction_local [0.45386335]
Right: 0.47641411034143033
exp.show_in_notebook(show_table=True, show_all=False)

Using Shapley Values

  • Install SHAP


!pip install shap
Requirement already satisfied: shap in /usr/local/lib/python3.7/dist-packages (0.40.0)
import shap

Explaining a case of Non-diabetes

row_to_show = 1
data_for_prediction = X_test.iloc[row_to_show]
data_for_prediction_array = data_for_prediction.values.reshape(1, -1)
rf_clf.predict_proba(data_for_prediction_array)
array([[0.90805083, 0.09194917]])
explainer = shap.TreeExplainer(rf_clf)

shap_values = explainer.shap_values(data_for_prediction_array)
shap.initjs()
shap.force_plot( explainer.expected_value[1], 
                 shap_values[1], 
                 data_for_prediction,
                 figsize=(20, 2) )


Explaining a borderline case (row 4, true label Non-diabetes)

row_to_show = 4
data_for_prediction = X_test.iloc[row_to_show]
data_for_prediction_array = data_for_prediction.values.reshape(1, -1)
rf_clf.predict_proba(data_for_prediction_array)

shap_values = explainer.shap_values(data_for_prediction_array)

shap.initjs()
shap.force_plot( explainer.expected_value[1], 
                 shap_values[1], 
                 data_for_prediction,
                 figsize=(20, 2) )


Global Explanation of SHAP

explainer = shap.TreeExplainer( rf_clf )

shap_values = explainer.shap_values( X_train )
shap.summary_plot( shap_values[1], X_train, plot_type = 'dot' )
explainer.expected_value
array([0.65552239, 0.34447761])
shap.initjs()
shap.force_plot( explainer.expected_value[1], 
                 shap_values[1], 
                 X_train )


Individual Conditional Expectations (ICE)

import matplotlib.pyplot as plt

from sklearn.inspection import PartialDependenceDisplay

fig, ax = plt.subplots(figsize=(12, 8))
ax.set_title("Individual Conditional Expectations")

display = PartialDependenceDisplay.from_estimator(
    rf_clf,
    X_train,
    features=["Age"],
    kind="individual",
    subsample=100,
    n_jobs=3,
    grid_resolution=20,
    random_state=0,
    ice_lines_kw={"color": "tab:blue", "alpha": 0.5, "linewidth": 0.5},
    ax = ax
)
fig, ax = plt.subplots(figsize=(12, 8))
ax.set_title("Individual Conditional Expectations")

display = PartialDependenceDisplay.from_estimator(
    rf_clf,
    X_train,
    features=["Age"],
    kind="both",
    subsample=100,
    n_jobs=3,
    grid_resolution=20,
    random_state=0,
    ice_lines_kw={"color": "tab:blue", "alpha": 0.5, "linewidth": 0.5},
    pd_line_kw={"color": "tab:orange", "linestyle": "--"},
    ax = ax
)


For full code visit Github

Conclusion

Machine learning models are often seen as black boxes. In this article, we have seen how such models can be explained and why it is important to do so, and we have discussed ways to interpret and explain a model. Explainable AI (XAI) is an emerging field, and it may become possible to automate the interpretation of ML models in the near future.


Priyanka Dalmia is pursuing MBA from IIM Bangalore. She likes to write about Product Management, Machine Learning & Quantitative finance. Follow her on Twitter @Quant_Dalmia
