
v0.4.0 - ZeroShotClassificationExplainer, Custom Labels for SequenceClassificationExplainer

@cdpierse cdpierse released this 25 May 17:07
· 98 commits to master since this release

Zero-Shot Classification Explainer (#19, #40)

This release introduces a new explainer, ZeroShotClassificationExplainer, which calculates word attributions in a zero-shot manner using NLI-type models. Zero-shot classification is implemented exactly as the Hugging Face team did it, based on the following paper. A list of models compatible with this explainer can be found on the transformers model hub.

Below is an example of instantiating and calling the ZeroShotClassificationExplainer:

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import ZeroShotClassificationExplainer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")

model = AutoModelForSequenceClassification.from_pretrained("facebook/bart-large-mnli")


zero_shot_explainer = ZeroShotClassificationExplainer(model, tokenizer)


word_attributions = zero_shot_explainer(
    "Today apple released the new Macbook showing off a range of new features found in the proprietary silicon chip computer. ",
    labels=["finance", "technology", "sports"],  # any labels can be passed, even a list with a single label
)

This returns the following list of tuples:

>>> word_attributions
[('<s>', 0.0),
 ('Today', 0.0),
 ('apple', 0.22505152647747717),
 ('released', -0.16164146624851905),
 ('the', 0.5026975657258089),
 ('new', 0.052589263167955536),
 ('Mac', 0.2528325960993759),
 ('book', -0.06445090203729663),
 ('showing', -0.21204922293777534),
 ('off', 0.06319714817612732),
 ('a', 0.032048012090796815),
 ('range', 0.08553079346908955),
 ('of', 0.1409201107994034),
 ('new', 0.0515261917112576),
 ('features', -0.09656406466213506),
 ('found', 0.02336613296843605),
 ('in', -0.0011649894272190678),
 ('the', 0.14229640664777807),
 ('proprietary', -0.23169065661847646),
 ('silicon', 0.5963924257008087),
 ('chip', -0.19908474233975806),
 ('computer', 0.030620295844734646),
 ('.', 0.1995076958535378)]
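Because the result is a plain list of (token, score) tuples, it can be ranked with ordinary Python. The sketch below uses a subset of the values shown above; the helper names `top_attributions` and `most_negative` are illustrative and not part of the library. It assumes the usual reading of attribution scores, where higher values indicate a stronger contribution towards the predicted label.

```python
# A subset of the (token, score) tuples shown above.
word_attributions = [
    ("apple", 0.22505152647747717),
    ("released", -0.16164146624851905),
    ("the", 0.5026975657258089),
    ("Mac", 0.2528325960993759),
    ("proprietary", -0.23169065661847646),
    ("silicon", 0.5963924257008087),
]


def top_attributions(attributions, k=3):
    """Return the k tokens with the highest attribution scores (hypothetical helper)."""
    return sorted(attributions, key=lambda pair: pair[1], reverse=True)[:k]


def most_negative(attributions, k=2):
    """Return the k tokens with the lowest attribution scores (hypothetical helper)."""
    return sorted(attributions, key=lambda pair: pair[1])[:k]


print(top_attributions(word_attributions))
print(most_negative(word_attributions))
```

In this subset, "silicon", "the" and "Mac" rank highest, while "proprietary" and "released" have the lowest scores.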

We can find out which label was predicted with:

>>> zero_shot_explainer.predicted_label   
'technology (entailment)'
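The "(entailment)" suffix reflects the NLI framing: each candidate label is turned into a hypothesis, and the label whose hypothesis is most strongly entailed by the text is chosen. The following is a minimal sketch of that selection step only, not the library's implementation. The hypothesis template, the function names, and the `toy_entailment_score` stand-in (which returns made-up values in place of a real NLI model) are all assumptions for illustration.

```python
from typing import Callable, List, Tuple

# Assumed template for illustration; the explainer may use a different one.
TEMPLATE = "This example is {}."


def build_hypotheses(labels: List[str], template: str = TEMPLATE) -> List[str]:
    """Turn each candidate label into an NLI hypothesis sentence."""
    return [template.format(label) for label in labels]


def pick_label(
    text: str,
    labels: List[str],
    entailment_score: Callable[[str, str], float],
    template: str = TEMPLATE,
) -> Tuple[str, float]:
    """Score every (text, hypothesis) pair and return the best label with its score."""
    hypotheses = build_hypotheses(labels, template)
    scores = [entailment_score(text, hypothesis) for hypothesis in hypotheses]
    best = max(range(len(labels)), key=scores.__getitem__)
    return labels[best], scores[best]


def toy_entailment_score(premise: str, hypothesis: str) -> float:
    """Stand-in for an NLI model: returns invented scores keyed on the label."""
    fake_scores = {"finance": 0.10, "technology": 0.85, "sports": 0.05}
    for label, score in fake_scores.items():
        if label in hypothesis:
            return score
    return 0.0


label, score = pick_label(
    "Today apple released the new Macbook.",
    ["finance", "technology", "sports"],
    toy_entailment_score,
)
print(f"{label} (entailment)")
```

In a real setting, `entailment_score` would wrap an NLI model such as the one loaded above and return the probability of the entailment class.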

For the ZeroShotClassificationExplainer, the visualize() method returns a table similar to the one produced by the SequenceClassificationExplainer.

zero_shot_explainer.visualize("zero_shot.html")

Custom Labels For Sequence Classification - @lalitpagaria (#25, #41)

This contribution by @lalitpagaria adds the ability to pass custom class labels that replace the default labels found in the model's config.

This is a very useful addition, as popular trained models often ship without label names set, which results in labels such as "LABEL_0", "LABEL_1", and so on. Such labels make the sequence classification explainer's visualization hard to read and interpret.

Custom labels are passed to the SequenceClassificationExplainer's constructor, and the number of labels passed must equal the number of labels the model already has:

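from transformers_interpret import SequenceClassificationExplainer

# DISTILBERT_MODEL and DISTILBERT_TOKENIZER are an already loaded model and tokenizer
seq_explainer = SequenceClassificationExplainer(
    DISTILBERT_MODEL, DISTILBERT_TOKENIZER, custom_labels=["sad", "happy"]
)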

Now the class at the 0th index corresponds to the label "sad" and the class at the 1st index corresponds to "happy".
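The index-to-label mapping and the length requirement described above can be sketched in plain Python. This is an illustration of the rule, not the library's code; `apply_custom_labels` is a hypothetical helper, and `id2label` stands in for the mapping typically found in a model's config.

```python
from typing import Dict, List


def apply_custom_labels(id2label: Dict[int, str], custom_labels: List[str]) -> Dict[int, str]:
    """Replace default labels with custom ones, matched by class index (hypothetical helper)."""
    if len(custom_labels) != len(id2label):
        raise ValueError(
            f"Expected {len(id2label)} custom labels, got {len(custom_labels)}"
        )
    # Class indices are sorted so that custom_labels[0] maps to the lowest index.
    return {idx: label for idx, label in zip(sorted(id2label), custom_labels)}


default_labels = {0: "LABEL_0", 1: "LABEL_1"}
print(apply_custom_labels(default_labels, ["sad", "happy"]))
```

Passing a list with the wrong number of labels raises a `ValueError` in this sketch, mirroring the requirement that the counts match.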

This is a really nice addition that makes Transformers Interpret more usable with the whole range of sequence classification models that don't have labels set. Thanks @lalitpagaria.

General Cleanup and Housekeeping

  • Cleaned up a number of linting errors reported by flake8
  • Improved the docstring for the QA explainer (more to come); the QA explainer itself still needs some finalization
  • Increased test coverage
  • Added a contribution guideline