SHapley Additive exPlanations (SHAP):
The ability to correctly interpret a prediction model’s output is extremely important. It engenders appropriate user trust, provides insight into how a model may be improved, and supports understanding of the process being modeled. In some applications, simple models (e.g., linear models) are often preferred for their ease of interpretation, even if they may be less accurate than complex ones. However, the growing availability of big data has increased the benefits of using complex models, bringing to the forefront the trade-off between accuracy and interpretability of a model’s output. Several methods that predate SHAP, such as LIME and DeepLIFT, address this problem, but an understanding of how these methods relate and when one is preferable to another has been lacking.
SHAP presents a unified approach to interpreting model predictions that brings several earlier methods, including LIME and DeepLIFT, under one framework. SHAP assigns each feature an importance value for a particular prediction, and the resulting methods show improved computational performance and/or better consistency with human intuition than previous approaches.
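The importance values SHAP assigns are Shapley values from cooperative game theory: a feature's value is its marginal contribution to the prediction, averaged over all subsets of the other features with weights |S|!(n-|S|-1)!/n!. The sketch below computes them exactly by enumerating subsets for a toy three-feature model. It assumes that a "missing" feature is represented by replacing it with a fixed baseline value (one of several possible masking choices), and the model, input and baseline are all hypothetical. Exact enumeration needs 2^n model evaluations, which is why practical SHAP explainers rely on approximations.

```python
from itertools import combinations
from math import factorial


def shapley_values(model, x, baseline):
    """Exact Shapley values of model(x); features outside a coalition
    are replaced by the corresponding baseline value (an assumption)."""
    n = len(x)

    def value(coalition):
        # Features in the coalition keep their value from x; the rest use the baseline.
        z = [x[j] if j in coalition else baseline[j] for j in range(n)]
        return model(z)

    phi = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        total = 0.0
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n! for each coalition S of this size.
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                total += weight * (value(s | {i}) - value(s))
        phi.append(total)
    return phi, value(set())


# Hypothetical toy model with an interaction between features 0 and 1.
def toy_model(z):
    return 2.0 * z[0] + 1.0 * z[1] + 3.0 * z[0] * z[1] - 0.5 * z[2]


x = [1.0, 1.0, 1.0]
baseline = [0.0, 0.0, 0.0]
phi, base_value = shapley_values(toy_model, x, baseline)
print("feature importances:", phi)
# Additivity: the base value plus all attributions reproduces the prediction.
print(base_value + sum(phi), toy_model(x))
```

The interaction term is split evenly between features 0 and 1, and the attributions add up to the prediction: this additivity property is what makes SHAP values comparable across features.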
Let us apply SHAP to interpret language models on classification and generation tasks.
Note: All the code used for plots and snippets in this article is hosted in this repo.
SHAP for Classification:
For this example, let us consider the six-class ‘emotion’ classification dataset from HuggingFace (HF) Datasets and explore the predictions of the ‘nateraw/bert-base-uncased-emotion’ HF model, which is already fine-tuned on the ‘emotion’ dataset. The dataset contains a ‘text’ input and an emotion label; here is the distribution of the labels:
Once the model and tokenizer are loaded, we wrap them with ‘shap.Explainer’ as shown below:
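The loading and wrapping code itself is not reproduced in this text. A minimal sketch, assuming the pattern used in the SHAP documentation's emotion example: a `transformers` text-classification pipeline that returns scores for every class is passed to `shap.Explainer`, which (per those examples) takes the class names from the model configuration. The function name `build_emotion_explainer` is hypothetical, and the article's own code in the repo may differ.

```python
def build_emotion_explainer(model_name="nateraw/bert-base-uncased-emotion"):
    """Load the fine-tuned model and tokenizer and wrap them in a SHAP explainer."""
    # Imports are deferred so nothing heavy is loaded until the function is called.
    import shap
    import transformers

    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    model = transformers.AutoModelForSequenceClassification.from_pretrained(model_name)
    # top_k=None asks the pipeline for a score for every class
    # (older transformers versions used return_all_scores=True instead).
    classifier = transformers.pipeline(
        "text-classification", model=model, tokenizer=tokenizer, top_k=None
    )
    # The explainer's outputs are named after the model's labels, which is
    # what makes indexing such as shap_values[:, :, "joy"] possible later on.
    return shap.Explainer(classifier)


# Usage (downloads the model on the first call):
# explainer = build_emotion_explainer()
# shap_values = explainer(data["text"][0:50])
```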
The ‘explainer’ computes ‘shap_values’, which in turn are used to plot a variety of graphs for assessing the predictions. Here is an example where the first 50 samples of the dataset are passed to the explainer:
shap_values = explainer(data['text'][0:50])
Using these shap_values, we can explore dataset-level feature impact scores for each class. Here is an example that plots the top features by magnitude for the ‘joy’ class across the 50 samples we passed to the explainer:
shap.plots.bar(shap_values[:,:,"joy"].mean(0))
Top features that positively impact the ‘joy’ class:
shap.plots.bar(shap_values[:,:,"joy"].mean(0), order=shap.Explanation.argsort.flip)
Top features that negatively impact the ‘joy’ class:
shap.plots.bar(shap_values[:,:,"joy"].mean(0), order=shap.Explanation.argsort)
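Conceptually, these three plots average each token's SHAP value for the ‘joy’ output over the explained samples and then sort the averages: the default bar order ranks them by absolute value, `argsort.flip` from most positive downward, and `argsort` from most negative upward. The NumPy sketch below reproduces that ordering on a small, made-up matrix of per-sample SHAP values; it illustrates the logic and is not SHAP's plotting code. It also assumes every sample has the same token columns, whereas for real text SHAP has to line up tokens across samples of different lengths.

```python
import numpy as np

# Hypothetical SHAP values for the "joy" output: rows are samples, columns are tokens.
tokens = np.array(["happy", "sad", "love", "the", "angry"])
joy_shap = np.array([
    [0.40, -0.30, 0.20, 0.01, -0.10],
    [0.30, -0.50, 0.10, -0.01, -0.20],
    [0.50, -0.10, 0.30, 0.00, -0.45],
])

mean_impact = joy_shap.mean(axis=0)  # analogous to shap_values[:, :, "joy"].mean(0)

by_magnitude = tokens[np.argsort(-np.abs(mean_impact))]  # largest |mean| first
most_positive = tokens[np.argsort(mean_impact)[::-1]]  # like order=argsort.flip
most_negative = tokens[np.argsort(mean_impact)]  # like order=argsort

print(dict(zip(tokens, mean_impact.round(3))))
print("by magnitude:", list(by_magnitude))
print("most positive first:", list(most_positive))
print("most negative first:", list(most_negative))
```

Because the bar plots use the signed mean, a token that pushes ‘joy’ up in some samples and down in others can end up with a small average even if it matters a lot per sample.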
The above plots give an overall idea of which words are important for a particular class. To inspect individual samples, SHAP offers other, interactive plots.
To visually check the model's predictions for all classes, here is an example for the last two samples:
shap.plots.text(shap_values[-2:])
In the plot above, ‘Input Text’ is self-explanatory and ‘Output Text’ lists the class names separated by spaces. Hovering over any class name highlights the parts of the input text that contributed most to that class (red: positive impact; blue: negative impact).
If we have a sentence and want to check the impact of its phrases on a particular prediction, ‘shap_values’ again comes to the rescue:
Note: I used the ‘IMDB’ dataset for this example because its input texts are longer; the explainer here wraps a sentiment model with a ‘POSITIVE’ class.
shap.plots.text(shap_values[:,:,"POSITIVE"])
SHAP for Generation:
For generation, SHAP attributes each generated token to the input tokens: it masks parts of the input and measures how the model's log-odds of producing that output token change. The result is visualized with the same text heatmap we used earlier.
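In generation, each output token plays the role a class plays in classification, so the explanation is a matrix of input-token by output-token attributions, and hovering over an output token in the plot shows one column of it. The sketch below computes such a matrix exactly with the permutation form of the Shapley value (each input token's marginal contribution averaged over every ordering of the prompt). The prompt, output tokens and `toy_log_odds` scorer are made-up stand-ins for a real language model, and simply dropping prompt tokens stands in for SHAP's text masking; SHAP itself approximates these values rather than enumerating orderings.

```python
from itertools import permutations

import numpy as np

prompt = ["the", "storm", "hit"]
generated = ["damage", "sunny"]

# Hypothetical evidence: how much each prompt token raises the log-odds of each
# generated token (a real model would be queried here instead).
EVIDENCE = {
    "damage": {"storm": 2.0, "hit": 1.5},
    "sunny": {"storm": -1.0},
}


def toy_log_odds(kept_tokens, output_token):
    """Log-odds of output_token given only the prompt tokens in kept_tokens."""
    score = -1.0  # prior log-odds with an empty prompt
    score += sum(EVIDENCE[output_token].get(t, 0.0) for t in kept_tokens)
    if output_token == "damage" and "storm" in kept_tokens and "hit" in kept_tokens:
        score += 1.0  # interaction: the two words together add extra evidence
    return score


def generation_shapley(prompt, generated):
    """Rows: prompt tokens; columns: generated tokens."""
    orders = list(permutations(range(len(prompt))))
    attributions = np.zeros((len(prompt), len(generated)))
    for k, out_tok in enumerate(generated):
        for order in orders:
            kept = []
            prev = toy_log_odds(kept, out_tok)
            for i in order:
                kept.append(prompt[i])
                cur = toy_log_odds(kept, out_tok)
                attributions[i, k] += cur - prev  # marginal contribution of token i
                prev = cur
        attributions[:, k] /= len(orders)
    return attributions


attr = generation_shapley(prompt, generated)
for k, out_tok in enumerate(generated):
    top = prompt[int(np.argmax(np.abs(attr[:, k])))]
    print(out_tok, dict(zip(prompt, attr[:, k].round(2))), "strongest input:", top)
```

Reading down a column answers the question the interactive plot answers on hover: which input words made this particular output token more or less likely.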
Here is an example of summarization with ‘distilbart’:
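import shap
# model, tokenizer: a distilbart summarization model and its tokenizer; dataset: a summarization dataset with a 'document' column (all loaded beforehand)
s = dataset['document'][0:1]
explainer = shap.Explainer(model, tokenizer)
shap_values = explainer(s)
shap.plots.text(shap_values)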
Here is an example of open-ended text generation with ‘gpt-2’:
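import shap
# model, tokenizer: GPT-2 (e.g. loaded with AutoModelForCausalLM) and its tokenizer, loaded beforehand;
# the SHAP documentation's GPT-2 example also sets model.config.is_decoder = True
explainer = shap.Explainer(model, tokenizer)
s = ['Two Muslims']
shap_values = explainer(s)
shap.plots.text(shap_values)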
With the input prompt ‘Two Muslims’, notice how the generated text relates to violence. This becomes even more evident when hovering over ‘killed’: the highlighting shows that ‘Muslims’ in the input contributed to generating ‘killed’. This is a negative side of these huge language models. The same issue has been found in GPT-3 (see here).
Language Interpretability Tool (LIT):
The Language Interpretability Tool (LIT) is for researchers and practitioners looking to understand NLP model behavior through a visual, interactive, and extensible tool.
Use LIT to ask and answer questions like:
- What kind of examples does my model perform poorly on?
- Why did my model make this prediction? Can it be attributed to adversarial behavior, or to undesirable priors from the training set?
- Does my model behave consistently if I change things like textual style, verb tense, or pronoun gender?
LIT contains many built-in capabilities but is also customizable, with the ability to add custom interpretability techniques, metrics calculations, counterfactual generators, visualizations, and more.
The biggest advantage of LIT is its interactive UI, where you can compare multiple models, change data samples on the fly, and visualize all the data points in a 3-D space, which gives a good sense of how the model is performing.
Here is the general layout of LIT:
A detailed user guide to the layout is available here to help you get familiar with the UI.
At first glance, we notice the ‘Embeddings’ section, which projects the model's embeddings of all data samples into a 3-D space and colors each point by its label. Here the orange and blue points are segregated fairly well, which quickly suggests that the model separates the two classes reasonably well.
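To get three coordinates per example, the Embeddings module reduces each high-dimensional embedding with a projection method such as PCA or UMAP. A minimal NumPy sketch of a PCA projection to 3-D, run on random stand-in embeddings for two synthetic classes whose means differ (the data are made up, and this is not LIT's own projection code):

```python
import numpy as np


def pca_3d(embeddings):
    """Project row-vector embeddings onto their top three principal components."""
    centered = embeddings - embeddings.mean(axis=0)
    # Rows of vt are the principal directions, sorted by explained variance.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:3].T


rng = np.random.default_rng(0)
# Hypothetical 768-dimensional embeddings for two labels with shifted means.
label_0 = rng.normal(0.0, 1.0, size=(50, 768))
label_1 = rng.normal(0.5, 1.0, size=(50, 768))
points = pca_3d(np.vstack([label_0, label_1]))
labels = np.array([0] * 50 + [1] * 50)

print(points.shape)  # (100, 3)
# Average position of each class along the first principal component.
print(points[labels == 0, 0].mean(), points[labels == 1, 0].mean())
```

When the classes differ strongly in embedding space, that difference dominates the first principal component, which is why well-separated colors in the 3-D view hint that the model distinguishes the labels.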
The ‘Data Table’ section shows all the data points with their labels, predictions, etc. With the ‘Datapoint Editor’, one can quickly edit a data sample (for example, changing the gender to examine model bias) and compare it with the original sample. We can also add new data samples to the dataset.
The ‘Performance’ tab shows the model's accuracy, precision, recall, F1 score and confusion matrix without our having to compute them explicitly.
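For reference, these metrics follow the standard definitions. A small pure-Python sketch for a binary task, using made-up labels and predictions and treating label 1 as the positive class (`binary_metrics` is a hypothetical helper, not part of LIT):

```python
from collections import Counter


def binary_metrics(labels, preds, positive=1):
    """Accuracy, precision, recall, F1 and a confusion matrix for binary labels."""
    counts = Counter(zip(labels, preds))  # (true, predicted) -> count
    tp = counts[(positive, positive)]
    fp = sum(c for (t, p), c in counts.items() if t != positive and p == positive)
    fn = sum(c for (t, p), c in counts.items() if t == positive and p != positive)
    accuracy = sum(c for (t, p), c in counts.items() if t == p) / len(labels)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    classes = sorted(set(labels) | set(preds))
    # Rows are true labels, columns are predicted labels.
    confusion = [[counts[(t, p)] for p in classes] for t in classes]
    return {"accuracy": accuracy, "precision": precision, "recall": recall,
            "f1": f1, "confusion": confusion}


labels = [1, 0, 1, 1, 0, 0, 1, 0]
preds = [1, 0, 0, 1, 0, 1, 1, 0]
print(binary_metrics(labels, preds))
```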
The ‘Performance’ tab also shows the spread of all data points, as shown below:
The ‘Explanations’ tab shows various token-level salience maps (gradient-based ones as well as LIME) for a particular prediction. Here is an example of a wrong prediction:
The model predicted this sample wrongly with 85% confidence, so we can explore the salience maps to see which tokens drove the wrong prediction. The ‘LIME’ explanation makes it clear that words such as ‘never’, ‘loses’, ‘grim’ and ‘situation’ have large negative scores; together they pull the overall prediction toward the negative side, which is why the predicted label is ‘0’.
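LIME explains a single prediction by perturbing the input (here, dropping words), querying the model on each perturbed version, and fitting a linear model weighted by proximity to the original; the fitted coefficients are the word scores. The NumPy sketch below is a simplified version of that idea, not LIT's or the `lime` package's implementation: it uses plain weighted least squares instead of ridge regression and a simplified exponential kernel on cosine distance, and `toy_sentiment_score` is a hypothetical stand-in for the classifier's score.

```python
import numpy as np


def lime_word_weights(words, predict_fn, num_samples=500, kernel_width=0.25, seed=0):
    """LIME-style word weights: fit a locally weighted linear model on word masks."""
    rng = np.random.default_rng(seed)
    d = len(words)
    masks = np.ones((num_samples, d))
    for row in masks[1:]:  # row 0 stays the unperturbed sentence
        drop = rng.choice(d, size=rng.integers(1, d), replace=False)
        row[drop] = 0.0  # remove between 1 and d - 1 words
    texts = [" ".join(w for w, keep in zip(words, m) if keep) for m in masks]
    y = np.array([predict_fn(t) for t in texts])
    # Cosine distance between each mask and the full sentence (all ones).
    distance = 1.0 - np.sqrt(masks.sum(axis=1) / d)
    weights = np.exp(-(distance ** 2) / kernel_width ** 2)
    # Weighted least squares with an intercept column.
    X = np.hstack([np.ones((num_samples, 1)), masks])
    sw = np.sqrt(weights)
    coef, *_ = np.linalg.lstsq(X * sw[:, None], y * sw, rcond=None)
    return dict(zip(words, coef[1:]))


# Hypothetical sentiment scorer: sums lexicon weights of the words present.
LEXICON = {"never": -2.0, "loses": -1.5, "grim": -1.0, "situation": -0.5, "hope": 1.0}


def toy_sentiment_score(text):
    return sum(LEXICON.get(w, 0.0) for w in text.split())


words = "he never loses hope in a grim situation".split()
weights = lime_word_weights(words, toy_sentiment_score)
for word, w in sorted(weights.items(), key=lambda kv: kv[1]):
    print(f"{word:>10} {w:+.2f}")
```

Because the toy scorer is exactly linear in which words are present, the fitted weights recover its lexicon values; with a real classifier they only approximate the model's behavior near the original sentence.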
The ‘Counterfactuals’ tab generates new data points from existing ones, for example by replacing some words or by scrambling the words of an example.
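Both kinds of generator are simple text transformations. A minimal sketch of each, with hypothetical function names (LIT ships its own generators, whose interfaces are not shown here). Comparing the model's prediction on the original and on the counterfactual is how one probes the consistency questions listed earlier, such as changing pronoun gender.

```python
import random


def replace_words(text, substitutions):
    """Counterfactual by word substitution, e.g. swapping gendered pronouns."""
    return " ".join(substitutions.get(w, w) for w in text.split())


def scramble_words(text, seed=0):
    """Counterfactual by shuffling word order while keeping the same words."""
    words = text.split()
    random.Random(seed).shuffle(words)
    return " ".join(words)


original = "he said the movie was great"
print(replace_words(original, {"he": "she"}))  # she said the movie was great
print(scramble_words(original))
```

A large prediction change under a pronoun swap points to bias, while a model that barely reacts to scrambling may be relying on individual words rather than sentence structure.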
More demos for classification, regression, summarization, gender bias, and using LIT in notebooks can be found here.
To make LIT work with custom models and datasets, a few code changes are needed. This notebook walks through the whole process: it trains a ‘DistilBert’ model on a ‘news classification’ dataset, integrates both into LIT, and renders the LIT UI inside the notebook itself.
SHAP vs LIT:
Having seen the capabilities of both SHAP and LIT, the immediate question that pops up is ‘What should I use?’
This can be answered by considering the bigger picture:
If I want to know the tokens in the dataset that matter most for the model's predictions, or if I have to assess the model across all the available classes, then I should consider SHAP.
But if I want to visualize predictions and gradients, or add, change and compare data points on the fly, then I should consider LIT.
This article was co-authored by Saichandra Pandraju and Sakthi Ganesh