Releases: cdpierse/transformers-interpret
v0.10.0 Fix Multi Label Activations
ImageClassificationExplainer 🖼️
This is a hugely exciting release for us as it is our first foray into the domain of computer vision. With this update, we are adding support for image classification models inside the Huggingface Transformers ecosystem. We are very excited to bring a simple API for calculating and visualizing attributions for vision transformers and their numerous variants in just 3 lines of code.
ImageClassificationExplainer (#105)
The ImageClassificationExplainer is designed to work with all models from the Transformers library that are trained for image classification (Swin, ViT, etc). It provides attributions for every pixel in that image that can be easily visualized using the explainer's built-in visualize method.
Initialising an image classification explainer is very simple; all you need is an image classification model finetuned or trained to work with Huggingface, along with its feature extractor.
For this example we are using google/vit-base-patch16-224, a Vision Transformer (ViT) model pre-trained on ImageNet-21k that predicts from 1000 possible classes.
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
from transformers_interpret import ImageClassificationExplainer
from PIL import Image
import requests
model_name = "google/vit-base-patch16-224"
model = AutoModelForImageClassification.from_pretrained(model_name)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
# With both the model and feature extractor initialized we are now able to get explanations on an image, we will use a simple image of a golden retriever.
image_link = "https://imagesvc.meredithcorp.io/v3/mm/image?url=https%3A%2F%2Fstatic.onecms.io%2Fwp-content%2Fuploads%2Fsites%2F47%2F2020%2F08%2F16%2Fgolden-retriever-177213599-2000.jpg"
image = Image.open(requests.get(image_link, stream=True).raw)
image_classification_explainer = ImageClassificationExplainer(model=model, feature_extractor=feature_extractor)
image_attributions = image_classification_explainer(
image
)
print(image_attributions.shape)
Which will print the shape of the attributions tensor:
>>> torch.Size([1, 3, 224, 224])
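If you want to work with the raw attributions directly, the tensor can be collapsed into a single per-pixel saliency map. The snippet below is a minimal sketch under the assumption that image_attributions is the standard torch tensor whose shape is printed above; the variable names are illustrative.
import torch
# Collapse the channel dimension so each pixel has a single attribution value.
pixel_attributions = image_attributions.squeeze(0).sum(dim=0)  # shape: (224, 224)
# Locate the most positively attributed pixel (flattened index -> row, col).
top_idx = torch.argmax(pixel_attributions).item()
row, col = divmod(top_idx, pixel_attributions.shape[1])
print(row, col)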
Visualizing Image Attributions
Because we are dealing with images, visualization is even more straightforward than with text models.
Attributions can be easily visualized using the visualize method of the explainer. There are currently 4 supported visualization methods:
- heatmap - a heatmap of positive and negative attributions is drawn using the dimensions of the image.
- overlay - the heatmap is overlaid on a grayscaled version of the original image.
- masked_image - the absolute value of attributions is used to create a mask over the original image.
- alpha_scaling - sets the alpha channel (transparency) of each pixel to be equal to the normalized attribution value.
Heatmap
image_classification_explainer.visualize(
method="heatmap",
side_by_side=True,
outlier_threshold=0.03
)
Overlay
image_classification_explainer.visualize(
method="overlay",
side_by_side=True,
outlier_threshold=0.03
)
Masked Image
image_classification_explainer.visualize(
method="masked_image",
side_by_side=True,
outlier_threshold=0.03
)
Alpha Scaling
image_classification_explainer.visualize(
method="alpha_scaling",
side_by_side=True,
outlier_threshold=0.03
)
PairwiseSequenceClassificationExplainer, RoBERTa bug fixes, GH Actions migration
Release version 0.8.1
Lots of changes big and small with this release:
PairwiseSequenceClassificationExplainer (#87, #82, #58)
This has been a frequently requested feature and one that I am very happy to release, especially as I have recently wanted to explain the outputs of CrossEncoder models myself.
The PairwiseSequenceClassificationExplainer is a variant of the SequenceClassificationExplainer that is designed to work with classification models that expect the input sequence to be two inputs separated by the model's separator token. Common examples of this are NLI models and Cross-Encoders, which are commonly used to score the similarity of two inputs to one another.
This explainer calculates pairwise attributions for two passed inputs text1 and text2 using the model and tokenizer given in the constructor.
Also, since a common use case for pairwise sequence classification is to compare the similarity of two inputs, models of this nature typically have only a single output node rather than one for each class. The pairwise sequence classification explainer has some useful utility functions to make interpreting single-node outputs clearer.
By default, for models that output a single node, the attributions are with respect to the inputs pushing the score closer to 1.0; however, if you want to see the attributions with respect to scores closer to 0.0, you can pass flip_sign=True when calling the explainer. For similarity-based models this is useful, as the model might predict a score closer to 0.0 for the two inputs, in which case we flip the attributions' sign to explain why the two inputs are dissimilar.
Example Usage
For this example we are using "cross-encoder/ms-marco-MiniLM-L-6-v2", a high-quality cross-encoder trained on MS MARCO, a passage-ranking dataset for question answering and machine reading comprehension.
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret.explainers.sequence_classification import PairwiseSequenceClassificationExplainer
model = AutoModelForSequenceClassification.from_pretrained("cross-encoder/ms-marco-MiniLM-L-6-v2")
tokenizer = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L-6-v2")
pairwise_explainer = PairwiseSequenceClassificationExplainer(model, tokenizer)
# the pairwise explainer requires two string inputs to be passed, in this case given the nature of the model
# we pass a query string and a context string. The question we are asking of our model is "does this context contain a valid answer to our question"
# the higher the score the better the fit.
query = "How many people live in Berlin?"
context = "Berlin has a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers."
pairwise_attr = pairwise_explainer(query, context)
Which returns the following attributions:
>>> pairwise_attr
[('[CLS]', 0.0),
('how', -0.037558652124213034),
('many', -0.40348581975409786),
('people', -0.29756140282349425),
('live', -0.48979015417391764),
('in', -0.17844527885888117),
('berlin', 0.3737346097442739),
('?', -0.2281428913480142),
('[SEP]', 0.0),
('berlin', 0.18282430604641564),
('has', 0.039114659489254834),
('a', 0.0820056652212297),
('population', 0.35712150914643026),
('of', 0.09680870840224687),
('3', 0.04791760029513795),
(',', 0.040330986539774266),
('520', 0.16307677913176166),
(',', -0.005919693904602767),
('03', 0.019431649515841844),
('##1', -0.0243808667024702),
('registered', 0.07748341753369632),
('inhabitants', 0.23904087299731255),
('in', 0.07553221327346359),
('an', 0.033112821611999875),
('area', -0.025378852244447532),
('of', 0.026526373859562906),
('89', 0.0030700151809002147),
('##1', -0.000410387092186983),
('.', -0.0193147139126114),
('82', 0.0073800833347678774),
('square', 0.028988305990861576),
('kilometers', 0.02071182933829008),
('.', -0.025901070914318036),
('[SEP]', 0.0)]
Visualize Pairwise Classification attributions
Visualizing the pairwise attributions is no different from the sequence classification explainer. We can see that in both the query and the context there is a lot of positive attribution for the word berlin, as well as for the words population and inhabitants in the context, good signs that our model understands the textual context of the question asked.
pairwise_explainer.visualize("cross_encoder_attr.html")
If we were more interested in highlighting the input attributions that pushed the model away from the positive class of this single node output we could pass:
pairwise_attr = pairwise_explainer(query, context, flip_sign=True)
This simply inverts the sign of the attributions ensuring that they are with respect to the model outputting 0 rather than 1.
RoBERTa Consistency Improvements (#65)
Thanks to some great detective work by @dvsrepo, @jogonba2, @databill86, and @VDuchauffour on this issue over the last year, we've been able to identify what looks to be the main culprit responsible for the misalignment between the scores this package gives for RoBERTa-based models and their actual outputs in the transformers package.
Because this package has to create reference ids for each input type (input_ids, position_ids, token_type_ids) to create a baseline, we try to emulate the outputs of the model's tokenizer in an automated fashion. For most BERT-based models this works great, but as I have learned from reading this thread (#65), there were significant issues with RoBERTa.
It seems that the main reason for this is that RoBERTa implements position_ids in a very different manner to BERT (read this and this for extra context). Since we were passing completely incorrect values for position_ids, it appears to have thrown the model's predictions off. This release does not fully fix the issue, but it does bypass the passing of incorrect position_ids by simply not passing them to the forward function. We've done this by creating a flag that recognises certain model architectures as being incompatible with how we create position_ids. According to the Transformers docs, when position_ids are not passed:
They are an optional parameter. If no position_ids are passed to the model, the IDs are automatically created as absolute positional embeddings.
So this solution should be good for most situations; however, ideally in the future we will look into creating RoBERTa-compatible position_ids within the package itself.
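For extra context on why this matters, here is a minimal sketch (not the package's internal code) of the two position id schemes; the helper names are illustrative. BERT-style absolute ids simply count from 0, whereas RoBERTa starts counting after its padding index and keeps padding positions pinned to that index, so BERT-style ids fed to a RoBERTa model shift every position embedding.
import torch

def bert_style_position_ids(input_ids: torch.Tensor) -> torch.Tensor:
    # Absolute positions 0..seq_len-1, the scheme BERT expects.
    seq_len = input_ids.size(1)
    return torch.arange(seq_len).unsqueeze(0).expand_as(input_ids)

def roberta_style_position_ids(input_ids: torch.Tensor, padding_idx: int = 1) -> torch.Tensor:
    # RoBERTa counts non-padding tokens starting at padding_idx + 1
    # and leaves padding tokens at padding_idx.
    mask = input_ids.ne(padding_idx).int()
    return torch.cumsum(mask, dim=1) * mask + padding_idx

input_ids = torch.tensor([[0, 31414, 232, 2, 1, 1]])  # <s> ... </s> <pad> <pad>
print(bert_style_position_ids(input_ids))     # tensor([[0, 1, 2, 3, 4, 5]])
print(roberta_style_position_ids(input_ids))  # tensor([[2, 3, 4, 5, 1, 1]])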
Move to GH actions
This release also moves our testing suite from CircleCI to GH Actions; GH Actions has proven to be easier to integrate with and much more convenient.
Other
- Set the minimum Python version to 3.7. As of December 2021, Python 3.6 is no longer officially supported by the Python team, so we have also removed support for it.
- Housekeeping and cleanup around the codebase
v0.7.2 - TokenClassificationExplainer (NER) 🧍🏻♀️🌎 🏢
TokenClassificationExplainer (#91)
This incredible release is all thanks to a fantastic community contribution from @pabvald, who implemented the entire TokenClassificationExplainer class, as well as all of its tests and associated docs. A huge thank you again to Pablo for this amazing work; it has been on my to-do list for over a year, I greatly appreciate this contribution, and I know the community will too.
This new explainer is designed to work with any and all models in the HuggingFace Transformers package that are of the kind {Model}ForTokenClassification, which are models commonly used for tasks such as Named Entity Recognition (NER) and Part-of-Speech (POS) tagging.
The TokenClassificationExplainer returns a dictionary mapping each word in a given sequence to a label in the model's trained labels configuration. Token classification models work on a word-by-word basis, so the structure of this explainer's output is that each word maps to another dictionary which contains two keys, label and attribution_scores, where label is a string indicating the predicted label and attribution_scores is another dict mapping words to scores for the given root word key.
How to use
from transformers import AutoModelForTokenClassification, AutoTokenizer
from transformers_interpret import TokenClassificationExplainer
MODEL_PATH = 'dslim/bert-base-NER'
model = AutoModelForTokenClassification.from_pretrained(MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
ner_explainer = TokenClassificationExplainer(model=model, tokenizer=tokenizer)
sample_text = "Tim Cook is CEO of Apple."
attributions = ner_explainer(sample_text)
print(attributions)
Expand to see word attribution dictionary
{'[CLS]': {'label': 'O',
'attribution_scores': [('[CLS]', 0.0),
('Tim', 0.346423320984119),
('Cook', 0.5334609978768102),
('is', -0.40334870049983335),
('CEO', -0.3101234375976895),
('of', 0.512072192130804),
('Apple', -0.17249370683345489),
('.', 0.21111967418861474),
('[SEP]', 0.0)]},
'Tim': {'label': 'B-PER',
'attribution_scores': [('[CLS]', 0.0),
('Tim', 0.6097200124017794),
('Cook', 0.7418433507979225),
('is', 0.2277328676307869),
('CEO', 0.12913824237676577),
('of', 0.0658425121482477),
('Apple', 0.06830320263790929),
('.', -0.01924683905463743),
('[SEP]', 0.0)]},
'Cook': {'label': 'I-PER',
'attribution_scores': [('[CLS]', 0.0),
('Tim', 0.5523936725613293),
('Cook', 0.8009957951991128),
('is', 0.1804967026709793),
('CEO', 0.12327788007775593),
('of', 0.042470529981614845),
('Apple', 0.057217721910403266),
('.', -0.020318897077615642),
('[SEP]', 0.0)]},
'is': {'label': 'O',
'attribution_scores': [('[CLS]', 0.0),
('Tim', 0.24614651317657982),
('Cook', -0.009088703281476993),
('is', 0.9216954069405697),
('CEO', 0.026992140219729874),
('of', 0.2520559406534854),
('Apple', -0.09920548911190433),
('.', 0.12531705560714215),
('[SEP]', 0.0)]},
'CEO': {'label': 'O',
'attribution_scores': [('[CLS]', 0.0),
('Tim', 0.3124910273039106),
('Cook', 0.3625517589427658),
('is', 0.3507524148134499),
('CEO', 0.37196988201878567),
('of', 0.645668212957734),
('Apple', -0.27458958091134866),
('.', 0.13126252757894524),
('[SEP]', 0.0)]},
'of': {'label': 'O',
'attribution_scores': [('[CLS]', 0.0),
('Tim', 0.021065140560775575),
('Cook', 0.05638048932919909),
('is', 0.16774739397504396),
('CEO', 0.043009122581603866),
('of', 0.9340829137500298),
('Apple', -0.11144488868920191),
('.', 0.2854079089492836),
('[SEP]', 0.0)]},
'Apple': {'label': 'B-ORG',
'attribution_scores': [('[CLS]', 0.0),
('Tim', -0.017330599088927878),
('Cook', -0.04074196463435918),
('is', -0.08738080703156076),
('CEO', 0.23234519803002726),
('of', 0.12270125701886334),
('Apple', 0.9561624229708163),
('.', -0.08436746169241069),
('[SEP]', 0.0)]},
'.': {'label': 'O',
'attribution_scores': [('[CLS]', 0.0),
('Tim', 0.052863660537099254),
('Cook', -0.0694824371223385),
('is', -0.18074653059003534),
('CEO', 0.021118463602210605),
('of', 0.06322422431822372),
('Apple', -0.6286955666244136),
('.', 0.748336093254276),
('[SEP]', 0.0)]},
'[SEP]': {'label': 'O',
'attribution_scores': [('[CLS]', 0.0),
('Tim', 0.29980967625881066),
('Cook', -0.22297477338851293),
('is', -0.050889312336460345),
('CEO', 0.11157068443843984),
('of', 0.25200059104116196),
('Apple', -0.8839047143031845),
('.', -0.023808126035021283),
('[SEP]', 0.0)]}}
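To make the structure concrete, here is a small sketch of reading a single entry out of the dictionary above; it assumes attributions is exactly the dict printed above.
# Predicted label for the word "Apple" and its strongest supporting token.
apple = attributions["Apple"]
print(apple["label"])  # 'B-ORG'
strongest = max(apple["attribution_scores"], key=lambda pair: pair[1])
print(strongest)       # ('Apple', 0.956...)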
Visualizing explanations
With a single call to the visualize() method we get a nice inline display of which inputs caused the activations that led to classifying each token into a particular class.
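Assuming the TokenClassificationExplainer follows the same convention as the other explainers shown in these notes, an HTML file path can optionally be passed to save the rendering (the file name is just an example):
ner_explainer.visualize("ner_viz.html")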
Ignore indexes
To save computation time, we can indicate a list of token indexes that we want to ignore. The explainer will not compute explanations for these tokens, although attributions of these tokens will be calculated to explain the predictions over other tokens.
attributions_2 = ner_explainer(sample_text, ignored_indexes=[0, 3, 4, 5])
When we visualize these attributions, the result is much more concise.
Ignore labels
In a similar way, we can also tell the explainer to ignore certain labels, e.g. we might not be interested in seeing the explanations of those tokens that are classified as 'O'.
attributions_3 = ner_explainer(sample_text, ignored_labels=['O'])
Which results in:
v0.6.0 - MultiLabelClassificationExplainer 🏷️🏷️🏷️🏷️
MultiLabelClassificationExplainer (#79)
Extends the existing sequence classification explainer into a new explainer that independently produces attributions for each label in the model regardless of what the predicted class is. This allows users to better inspect and interpret model predictions across all classes, particularly in situations where classifiers might be used in a multilabel fashion.
The MultiLabelClassificationExplainer returns a dictionary mapping labels/classes to a list of word attributions; additionally, the visualize() method will display the entire table of attributions for each label.
This has been a much-requested feature for a number of months, so we're very happy to finally get it released.
CC: @MichalMalyska @rhettdsouza13 @fraserprice @JensVN98 @dheerajiiitv
How to use
This explainer is an extension of the SequenceClassificationExplainer and is thus compatible with all sequence classification models from the Transformers package. The key change in this explainer is that it calculates attributions for each label in the model's config and returns a dictionary of word attributions w.r.t. each label. The visualize() method also displays a table of attributions, with attributions calculated per label.
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import MultiLabelClassificationExplainer
model_name = "j-hartmann/emotion-english-distilroberta-base"
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
cls_explainer = MultiLabelClassificationExplainer(model, tokenizer)
word_attributions = cls_explainer("There were many aspects of the film I liked, but it was frightening and gross in parts. My parents hated it.")
This produces a dictionary of word attributions, mapping each label to a list of tuples of each word and its attribution score.
Click to see word attribution dictionary
>>> word_attributions
{'anger': [('<s>', 0.0),
('There', 0.09002208622000409),
('were', -0.025129709879675187),
('many', -0.028852677974079328),
('aspects', -0.06341968013631565),
('of', -0.03587626320752477),
('the', -0.014813095892961287),
('film', -0.14087587475098232),
('I', 0.007367876912617766),
('liked', -0.09816592066307557),
(',', -0.014259517291745674),
('but', -0.08087144668471376),
('it', -0.10185214349220136),
('was', -0.07132244710777856),
('frightening', -0.4125361737439814),
('and', -0.021761663818889918),
('gross', -0.10423745223600908),
('in', -0.02383646952201854),
('parts', -0.027137622525091033),
('.', -0.02960415694062459),
('My', 0.05642774605113695),
('parents', 0.11146648216326158),
('hated', 0.8497975489280364),
('it', 0.05358116678115284),
('.', -0.013566277162080632),
('', 0.09293256725788422),
('</s>', 0.0)],
'disgust': [('<s>', 0.0),
('There', -0.035296263203072),
('were', -0.010224922196739717),
('many', -0.03747571761725605),
('aspects', 0.007696321643436715),
('of', 0.0026740873113235107),
('the', 0.0025752851265661335),
('film', -0.040890035285783645),
('I', -0.014710007408208579),
('liked', 0.025696806663391577),
(',', -0.00739107098314569),
('but', 0.007353791868893654),
('it', -0.00821368234753605),
('was', 0.005439709067819798),
('frightening', -0.8135974168445725),
('and', -0.002334953123414774),
('gross', 0.2366024374426269),
('in', 0.04314772995234148),
('parts', 0.05590472194035334),
('.', -0.04362554293972562),
('My', -0.04252694977895808),
('parents', 0.051580790911406944),
('hated', 0.5067406070057585),
('it', 0.0527491071885104),
('.', -0.008280280618652273),
('', 0.07412384603053103),
('</s>', 0.0)],
'fear': [('<s>', 0.0),
('There', -0.019615758046045408),
('were', 0.008033402634196246),
('many', 0.027772367717635423),
('aspects', 0.01334130725685673),
('of', 0.009186049991879768),
('the', 0.005828877177384549),
('film', 0.09882910753644959),
('I', 0.01753565003544039),
('liked', 0.02062597344466885),
(',', -0.004469530636560965),
('but', -0.019660439408176984),
('it', 0.0488084071292538),
('was', 0.03830859527501167),
('frightening', 0.9526443954511705),
('and', 0.02535156284103706),
('gross', -0.10635301961551227),
('in', -0.019190425328209065),
('parts', -0.01713006453323631),
('.', 0.015043169035757302),
('My', 0.017068079071414916),
('parents', -0.0630781275517486),
('hated', -0.23630028921273583),
('it', -0.056057044429020306),
('.', 0.0015102052077844612),
('', -0.010045048665404609),
('</s>', 0.0)],
'joy': [('<s>', 0.0),
('There', 0.04881772670614576),
('were', -0.0379316152427468),
('many', -0.007955371089444285),
('aspects', 0.04437296429416574),
('of', -0.06407011137335743),
('the', -0.07331568926973099),
('film', 0.21588462483311055),
('I', 0.04885724513463952),
('liked', 0.5309510543276107),
(',', 0.1339765195225006),
('but', 0.09394079060730279),
('it', -0.1462792330432028),
('was', -0.1358591558323458),
('frightening', -0.22184169339341142),
('and', -0.07504142930419291),
('gross', -0.005472075984252812),
('in', -0.0942152657437379),
('parts', -0.19345218754215965),
('.', 0.11096247277185402),
('My', 0.06604512262645984),
('parents', 0.026376541098236207),
('hated', -0.4988319510231699),
('it', -0.17532499366236615),
('.', -0.022609976138939034),
('', -0.43417114685294833),
('</s>', 0.0)],
'neutral': [('<s>', 0.0),
('There', 0.045984598036642205),
('were', 0.017142566357474697),
('many', 0.011419348619472542),
('aspects', 0.02558593440287365),
('of', 0.0186162232003498),
('the', 0.015616416841815963),
('film', -0.021190511300570092),
('I', -0.03572427925026324),
('liked', 0.027062554960050455),
(',', 0.02089914209290366),
('but', 0.025872618597570115),
('it', -0.002980407262316265),
('was', -0.022218157611174086),
('frightening', -0.2982516449116045),
('and', -0.01604643529040792),
('gross', -0.04573829263548096),
('in', -0.006511536166676108),
('parts', -0.011744224307968652),
('.', -0.01817041167875332),
('My', -0.07362312722231429),
('parents', -0.06910711601816408),
('hated', -0.9418903509267312),
('it', 0.022201795222373488),
('.', 0.025694319747309045),
('', 0.04276690822325994),
('</s>', 0.0)],
'sadness': [('<s>', 0.0),
('There', 0.028237893283377526),
('were', -0.04489910545229568),
('many', 0.004996044977269471),
('aspects', -0.1231292680125582),
('of', -0.04552690725956671),
('the', -0.022077819961347042),
('film', -0.14155752357877663),
('I', 0.04135347872193571),
('liked', -0.3097732540526099),
(',', 0.045114660009053134),
('but', 0.0963352125332619),
('it', -0.08120617610094617),
('was', -0.08516150809170213),
('frightening', -0.10386889639962761),
('and', -0.03931986389970189),
('gross', -0.2145059013625132),
('in', -0.03465423285571697),
('parts', -0.08676627134611635),
('.', 0.19025217371906333),
('My', 0.2582092561303794),
('parents', 0.15432351476960307),
('hated', 0.7262186310977987),
('it', -0.029160655114499095),
('.', -0.002758524253450406),
('', -0.33846410359182094),
('</s>', 0.0)],
'surprise': [('<s>', 0.0),
('There', 0.07196110795254315),
('were', 0.1434314520711312),
('many', 0.08812238369489701),
('aspects', 0.013432396769890982),
('of', -0.07127508805657243),
('the', -0.14079766624810955),
('film', -0.16881201614906485),
('I', 0.040595668935112135),
('liked', 0.03239855530171577),
(',', -0.17676382558158257),
('but', -0.03797939330341559),
('it', -0.029191325089641736),
('was', 0.01758013584108571),
('frightening', -0.221738963726823),
('and', -0.05126920277135527),
('gross', -0.33986913466614044),
('in', -0.018180366628697),
('parts', 0.02939418603252064),
('.', 0.018080129971003226),
('My', -0.08060162218059498),
('parents', 0.04351719139081836),
...
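And, as mentioned above, the visualize() method renders the full per-label attribution table. Assuming it follows the same convention as the other explainers in these notes, an HTML file path can optionally be passed to save the output (the file name is an example):
cls_explainer.visualize("multilabel_viz.html")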
v0.5.1 - ZeroShotClassificationExplainer Improvements, Memory optimizations, and custom steps
Zero Shot Classification Explainer Improvements (#49)
- Changes the default behavior of the zero shot classification explainer: it now calculates attributions for each label by default and displays the attributions for every label in the visualization. This required some major reorganization of the former implementation.
Memory Optimizations (#54)
- Every explainer instance can now take an optional parameter internal_batch_size.
- This helps prevent issues where the explainer would cause OOM errors because all the steps (50 by default) used to calculate the attributions are batched together.
- For large models like Longformer etc. it is recommended to select very low values (1 or 2) for internal_batch_size (#51).
- This addition has been extremely helpful in stabilizing the performance of the streamlit demo app, which prior to this update was crashing frequently. Lowering internal_batch_size should greatly reduce memory overhead in situations where more than a single batch of gradients would cause OOM.
Internal Batch Size Example
cls_explainer('A very short 100 character text here!', internal_batch_size=1)
Custom Steps For Attribution (#54)
- Explainer instances can now also accept another optional parameter n_steps. The default value for n_steps in Captum is 50.
- n_steps controls the number of steps used to calculate the approximate attributions from the baseline inputs to the true inputs.
- Higher values for n_steps should result in less noisy approximations than lower values but longer calculation times.
- If n_steps is set to a particularly high value, it is highly recommended to set internal_batch_size to a low value to prevent OOM issues; a combined call is shown after the example below.
N Steps Example
cls_explainer('A very short 100 character text here!', n_steps=100)
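As noted in the bullets above, a high n_steps pairs well with a small internal_batch_size; a combined call might look like this (the values are purely illustrative):
cls_explainer('A very short 100 character text here!', n_steps=200, internal_batch_size=2)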
v0.4.0 - ZeroShotClassificationExplainer, Custom Labels for SequenceClassificationExplainer
Zero-Shot Classification Explainer (#19, #40)
This release introduces a new explainer, ZeroShotClassificationExplainer, that allows word attributions to be calculated in a zero-shot manner with the appropriate NLI-type models. To achieve this we implement zero shot exactly how the Hugging Face team did, based off the following paper. A list of models compatible with this explainer can be found on the transformers model hub.
Below is an example of instantiating and calling the ZeroShotClassificationExplainer:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import ZeroShotClassificationExplainer
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")
model = AutoModelForSequenceClassification.from_pretrained("facebook/bart-large-mnli")
zero_shot_explainer = ZeroShotClassificationExplainer(model, tokenizer)
word_attributions = zero_shot_explainer(
"Today apple released the new Macbook showing off a range of new features found in the proprietary silicon chip computer. ",
labels = ["finance", "technology", "sports"], # any labels you want can be passed, even a list with one label
)
Which will return the following list of tuples:
>>> word_attributions
[('<s>', 0.0),
('Today', 0.0),
('apple', 0.22505152647747717),
('released', -0.16164146624851905),
('the', 0.5026975657258089),
('new', 0.052589263167955536),
('Mac', 0.2528325960993759),
('book', -0.06445090203729663),
('showing', -0.21204922293777534),
('off', 0.06319714817612732),
('a', 0.032048012090796815),
('range', 0.08553079346908955),
('of', 0.1409201107994034),
('new', 0.0515261917112576),
('features', -0.09656406466213506),
('found', 0.02336613296843605),
('in', -0.0011649894272190678),
('the', 0.14229640664777807),
('proprietary', -0.23169065661847646),
('silicon', 0.5963924257008087),
('chip', -0.19908474233975806),
('computer', 0.030620295844734646),
('.', 0.1995076958535378)]
We can find out which label was predicted with:
>>> zero_shot_explainer.predicted_label
'technology (entailment)'
For the ZeroShotClassificationExplainer, the visualize() method returns a table similar to the SequenceClassificationExplainer.
zero_shot_explainer.visualize("zero_shot.html")
Custom Labels For Sequence Classification - @lalitpagaria (#25, #41)
This contribution by @lalitpagaria adds the ability to pass custom class labels that replace the default labels found in the model's config.
This is a very useful addition, as it is quite common for popular trained models to not have their label names set, resulting in labels that look like "LABEL_0", "LABEL_1", .... This can make the sequence classification explainer's visualization particularly hard to understand and not very readable.
Custom labels are passed to the SequenceClassificationExplainer's constructor, and the number of labels passed must equal the number of already existing labels:
seq_explainer = SequenceClassificationExplainer(
DISTILBERT_MODEL, DISTILBERT_TOKENIZER, custom_labels=["sad", "happy"]
)
Now the class at the 0th index corresponds to the label "sad" and the class at the 1st index corresponds to "happy".
This is a really nice addition that makes Transformers Interpret more usable with a whole range of sequence classification models that don't have labels set. Thanks @lalitpagaria.
General Cleanup and Housekeeping
- Cleaned up a number of flake8 reported linting errors
- Improved the docstring for the QA explainer (more to come, however); the QA explainer still needs some finalization
- Added some increased testing coverage
- Added a contribution guideline
v0.3.1 Constructor changes, Question Answering Explainer (experimental)
Changes to explainer constructor
In previous versions the constructor for explainers allowed text to be passed along with the model and tokenizer; text could also be passed to the explainer instance and would replace the text passed in the constructor. This behavior was confusing and also didn't work well when integrating future explainers. This version changes that behavior so that only the model and tokenizer are passed to the constructor and text is always passed to the instance.
Old
cls_explainer = SequenceClassificationExplainer(
"I love you, I like you",
model,
tokenizer)
cls_explainer("I hate you, I loathe you") # overwrites intial text
New
cls_explainer = SequenceClassificationExplainer(
model,
tokenizer)
cls_explainer("I love you, I like you")
Question Answering Explainer (Experimental release)
This release adds an initial implementation of an explainer for question-answering models from the Huggingface library, called the QuestionAnsweringExplainer. This explainer is still somewhat experimental and has not been tested with a wide range of models and architectures.
I'm still figuring some things out, such as the default embeddings to explain attributions for. In some cases I can get the average attributions for three embedding types combined (word, position, token type), while in others this is not possible, so I currently default to word attributions. While this was fine for the classification explainer, position ids and token type ids play a big role in calculating attributions for question-answering models.
How to create a QA explainer
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
from transformers_interpret import QuestionAnsweringExplainer
tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")
model = AutoModelForQuestionAnswering.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")
qa_explainer = QuestionAnsweringExplainer(
model,
tokenizer,
)
Getting word attributions
context = """
In Artificial Intelligence and machine learning, Natural Language Processing relates to the usage of machines to process and understand human language.
Many researchers currently work in this space.
"""
word_attributions = qa_explainer(
"What is natural language processing ?",
context,
)
word_attributions is a dict with two keys, start and end; the value for each key is a list of tuples. The values in start are the word attributions for the predicted start position from the context, and the values in end are for the predicted end position.
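For example, here is a small sketch of pulling out the most influential token for each predicted position, assuming word_attributions has exactly the start/end structure described above:
# Token with the highest attribution for the predicted start and end positions.
top_start = max(word_attributions["start"], key=lambda pair: pair[1])
top_end = max(word_attributions["end"], key=lambda pair: pair[1])
print(top_start, top_end)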
We can get the text span for the predicted answer with:
>>> qa_explainer.predicted_answer
'usage of machines to process and understand human language'
Like the classification explainer, attributions can be visualized with:
qa_explainer.visualize("bert_qa_viz.html")
This will create a table of two rows, the first for start position attributions and the second for end position attributions.
Please report any bugs or quirks you find with this new explainer. Thanks!
v0.2.0
Support Attributions for Multiple Embedding Types
- SequenceClassificationExplainer now has support for word attributions for both word_embeddings and position_embeddings for models where position_ids are part of the model's forward method. Embeddings for attribution can be set with the class's call method.
from transformers import AutoTokenizer, AutoModelForSequenceClassification
tokenizer = AutoTokenizer.from_pretrained("aychang/roberta-base-imdb")
model = AutoModelForSequenceClassification.from_pretrained("aychang/roberta-base-imdb")
Word Embedding Attributions
from transformers_interpret import SequenceClassificationExplainer
cls_explainer = SequenceClassificationExplainer(
"This was a really good film I enjoyed it a lot",
model,
tokenizer)
attributions = cls_explainer(embedding_type=0) # 0 = word
>>> attributions.word_attributions
[('<s>', 0.0),
('This', -0.3240508614377356),
('was', 0.1438011922867732),
('a', 0.2243325698743557),
('really', 0.2303368793560317),
('good', -0.0600901206724276),
('film', 0.01613507050261139),
('I', 0.002752767414682212),
('enjoyed', 0.36666383287176274),
('it', 0.46981294407030466),
('a', 0.15187907852049023),
('lot', 0.6235539369814076),
('</s>', 0.0)]
Position Embedding Attributions
from transformers_interpret import SequenceClassificationExplainer
cls_explainer = SequenceClassificationExplainer(
"This was a really good film I enjoyed it a lot",
model,
tokenizer)
attributions = cls_explainer(embedding_type=1) # 1 = position
>>> attributions.word_attributions
[('<s>', 0.0),
('This', -0.011571866816239364),
('was', 0.9746020664206717),
('a', 0.06633740353266766),
('really', 0.007891184021722232),
('good', 0.11340512797772889),
('film', -0.1035443669783489),
('I', -0.030966387400513003),
('enjoyed', -0.07312861129345115),
('it', -0.062475007741951326),
('a', 0.05681161636240444),
('lot', 0.04342110477675596),
('</s>', 0.08154160609887448)]
Additional Functionality Added To Base Explainer
To support multiple embedding types for the classification explainer, a number of handlers were added to the BaseExplainer to allow this functionality to be added easily to future explainers.
- BaseExplainer inspects the signature of a model's forward function and determines whether it receives position_ids and token_type_ids (a sketch of the idea is shown below). For example, BERT models take both as optional parameters whereas DistilBERT does not.
- From this inspection, the available embedding types are set in the BaseExplainer rather than in explainers that inherit from it.
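The inspection itself is plain Python; a minimal sketch of the idea (not the package's exact code) using inspect.signature, with bert-base-uncased purely as an example model:
import inspect
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
forward_params = inspect.signature(model.forward).parameters

# BERT's forward accepts both; DistilBERT's forward has neither parameter.
print("position_ids" in forward_params)    # True for BERT
print("token_type_ids" in forward_params)  # True for BERT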
Misc
- Updated tests; many of the tests in the suite now exercise 3 different architectures: BERT, DistilBERT, and GPT-2. This helps iron out any issues with the slight variations that these models have.
v0.1.11
- Classification explainer models now use model.get_input_embeddings() by default. This is equivalent to calculating attributions exclusively w.r.t. word embeddings. This makes attribution calculation much more streamlined for many more models and will greatly improve compatibility with different architectures. (#21)
- SequenceClassificationExplainer's visualize method now returns the HTML object by default. 99c6db7