Commit eb67078

Merge pull request #79 from cdpierse/feature/multilabel-classification-explainer

MultiLabel Classification Explainer

cdpierse authored Mar 3, 2022
2 parents 2f22c58 + 8e66166
Showing 24 changed files with 570 additions and 297 deletions.
Binary file added .DS_Store
Binary file not shown.
12 changes: 12 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,12 @@
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v2.3.0
+    hooks:
+      - id: check-yaml
+      - id: end-of-file-fixer
+      - id: trailing-whitespace
+  - repo: https://github.com/psf/black
+    rev: 21.12b0
+    hooks:
+      - id: black
+        args: [--line-length=120, --target-version=py38]
237 changes: 237 additions & 0 deletions README.md
@@ -39,6 +39,8 @@ Check out the streamlit [demo app here](https://share.streamlit.io/cdpierse/tran
- [Sequence Classification Explainer](#sequence-classification-explainer)
- [Visualize Classification attributions](#visualize-classification-attributions)
- [Explaining Attributions for Non Predicted Class](#explaining-attributions-for-non-predicted-class)
- [MultiLabel Classification Explainer](#multilabel-classification-explainer)
- [Visualize MultiLabel Classification attributions](#visualize-multilabel-classification-attributions)
- [Zero Shot Classification Explainer](#zero-shot-classification-explainer)
- [Visualize Zero Shot Classification attributions](#visualize-zero-shot-classification-attributions)
- [Question Answering Explainer (Experimental)](#question-answering-explainer-experimental)
@@ -173,6 +175,241 @@ Getting attributions for different classes is particularly insightful for multic
For a detailed explanation of this example, please check out this [multiclass classification notebook.](notebooks/multiclass_classification_example.ipynb)


</details>

### MultiLabel Classification Explainer

<details><summary>Click to expand</summary>

This explainer is an extension of the `SequenceClassificationExplainer` and is thus compatible with all sequence classification models from the Transformers package. The key difference in this explainer is that it calculates attributions for each label in the model's config and returns a dictionary of word attributions with respect to each label. The `visualize()` method also displays a table of attributions, with attributions calculated per label.

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import MultiLabelClassificationExplainer

model_name = "j-hartmann/emotion-english-distilroberta-base"
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)


cls_explainer = MultiLabelClassificationExplainer(model, tokenizer)


word_attributions = cls_explainer("There were many aspects of the film I liked, but it was frightening and gross in parts. My parents hated it.")
```
This produces a dictionary of word attributions mapping each label to a list of tuples, one per word, containing the word and its attribution score.
<details><summary>Click to see word attribution dictionary</summary>

```python
>>> word_attributions
{'anger': [('<s>', 0.0),
('There', 0.09002208622000409),
('were', -0.025129709879675187),
('many', -0.028852677974079328),
('aspects', -0.06341968013631565),
('of', -0.03587626320752477),
('the', -0.014813095892961287),
('film', -0.14087587475098232),
('I', 0.007367876912617766),
('liked', -0.09816592066307557),
(',', -0.014259517291745674),
('but', -0.08087144668471376),
('it', -0.10185214349220136),
('was', -0.07132244710777856),
('frightening', -0.4125361737439814),
('and', -0.021761663818889918),
('gross', -0.10423745223600908),
('in', -0.02383646952201854),
('parts', -0.027137622525091033),
('.', -0.02960415694062459),
('My', 0.05642774605113695),
('parents', 0.11146648216326158),
('hated', 0.8497975489280364),
('it', 0.05358116678115284),
('.', -0.013566277162080632),
('', 0.09293256725788422),
('</s>', 0.0)],
'disgust': [('<s>', 0.0),
('There', -0.035296263203072),
('were', -0.010224922196739717),
('many', -0.03747571761725605),
('aspects', 0.007696321643436715),
('of', 0.0026740873113235107),
('the', 0.0025752851265661335),
('film', -0.040890035285783645),
('I', -0.014710007408208579),
('liked', 0.025696806663391577),
(',', -0.00739107098314569),
('but', 0.007353791868893654),
('it', -0.00821368234753605),
('was', 0.005439709067819798),
('frightening', -0.8135974168445725),
('and', -0.002334953123414774),
('gross', 0.2366024374426269),
('in', 0.04314772995234148),
('parts', 0.05590472194035334),
('.', -0.04362554293972562),
('My', -0.04252694977895808),
('parents', 0.051580790911406944),
('hated', 0.5067406070057585),
('it', 0.0527491071885104),
('.', -0.008280280618652273),
('', 0.07412384603053103),
('</s>', 0.0)],
'fear': [('<s>', 0.0),
('There', -0.019615758046045408),
('were', 0.008033402634196246),
('many', 0.027772367717635423),
('aspects', 0.01334130725685673),
('of', 0.009186049991879768),
('the', 0.005828877177384549),
('film', 0.09882910753644959),
('I', 0.01753565003544039),
('liked', 0.02062597344466885),
(',', -0.004469530636560965),
('but', -0.019660439408176984),
('it', 0.0488084071292538),
('was', 0.03830859527501167),
('frightening', 0.9526443954511705),
('and', 0.02535156284103706),
('gross', -0.10635301961551227),
('in', -0.019190425328209065),
('parts', -0.01713006453323631),
('.', 0.015043169035757302),
('My', 0.017068079071414916),
('parents', -0.0630781275517486),
('hated', -0.23630028921273583),
('it', -0.056057044429020306),
('.', 0.0015102052077844612),
('', -0.010045048665404609),
('</s>', 0.0)],
'joy': [('<s>', 0.0),
('There', 0.04881772670614576),
('were', -0.0379316152427468),
('many', -0.007955371089444285),
('aspects', 0.04437296429416574),
('of', -0.06407011137335743),
('the', -0.07331568926973099),
('film', 0.21588462483311055),
('I', 0.04885724513463952),
('liked', 0.5309510543276107),
(',', 0.1339765195225006),
('but', 0.09394079060730279),
('it', -0.1462792330432028),
('was', -0.1358591558323458),
('frightening', -0.22184169339341142),
('and', -0.07504142930419291),
('gross', -0.005472075984252812),
('in', -0.0942152657437379),
('parts', -0.19345218754215965),
('.', 0.11096247277185402),
('My', 0.06604512262645984),
('parents', 0.026376541098236207),
('hated', -0.4988319510231699),
('it', -0.17532499366236615),
('.', -0.022609976138939034),
('', -0.43417114685294833),
('</s>', 0.0)],
'neutral': [('<s>', 0.0),
('There', 0.045984598036642205),
('were', 0.017142566357474697),
('many', 0.011419348619472542),
('aspects', 0.02558593440287365),
('of', 0.0186162232003498),
('the', 0.015616416841815963),
('film', -0.021190511300570092),
('I', -0.03572427925026324),
('liked', 0.027062554960050455),
(',', 0.02089914209290366),
('but', 0.025872618597570115),
('it', -0.002980407262316265),
('was', -0.022218157611174086),
('frightening', -0.2982516449116045),
('and', -0.01604643529040792),
('gross', -0.04573829263548096),
('in', -0.006511536166676108),
('parts', -0.011744224307968652),
('.', -0.01817041167875332),
('My', -0.07362312722231429),
('parents', -0.06910711601816408),
('hated', -0.9418903509267312),
('it', 0.022201795222373488),
('.', 0.025694319747309045),
('', 0.04276690822325994),
('</s>', 0.0)],
'sadness': [('<s>', 0.0),
('There', 0.028237893283377526),
('were', -0.04489910545229568),
('many', 0.004996044977269471),
('aspects', -0.1231292680125582),
('of', -0.04552690725956671),
('the', -0.022077819961347042),
('film', -0.14155752357877663),
('I', 0.04135347872193571),
('liked', -0.3097732540526099),
(',', 0.045114660009053134),
('but', 0.0963352125332619),
('it', -0.08120617610094617),
('was', -0.08516150809170213),
('frightening', -0.10386889639962761),
('and', -0.03931986389970189),
('gross', -0.2145059013625132),
('in', -0.03465423285571697),
('parts', -0.08676627134611635),
('.', 0.19025217371906333),
('My', 0.2582092561303794),
('parents', 0.15432351476960307),
('hated', 0.7262186310977987),
('it', -0.029160655114499095),
('.', -0.002758524253450406),
('', -0.33846410359182094),
('</s>', 0.0)],
'surprise': [('<s>', 0.0),
('There', 0.07196110795254315),
('were', 0.1434314520711312),
('many', 0.08812238369489701),
('aspects', 0.013432396769890982),
('of', -0.07127508805657243),
('the', -0.14079766624810955),
('film', -0.16881201614906485),
('I', 0.040595668935112135),
('liked', 0.03239855530171577),
(',', -0.17676382558158257),
('but', -0.03797939330341559),
('it', -0.029191325089641736),
('was', 0.01758013584108571),
('frightening', -0.221738963726823),
('and', -0.05126920277135527),
('gross', -0.33986913466614044),
('in', -0.018180366628697),
('parts', 0.02939418603252064),
('.', 0.018080129971003226),
('My', -0.08060162218059498),
('parents', 0.04351719139081836),
('hated', -0.6919028585285265),
('it', 0.0009574844165327357),
('.', -0.059473118237873344),
('', -0.465690452620123),
('</s>', 0.0)]}
```
</details>
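Since each label maps to a plain list of `(token, score)` tuples, the dictionary is straightforward to post-process. As a minimal sketch (using an abbreviated, hypothetical slice of the output above, and a helper name of our own invention), here is one way to pull out the most positively attributed word per label:

```python
# Abbreviated, hypothetical slice of the word_attributions structure shown
# above (scores rounded); the real explainer returns one entry per label.
word_attributions = {
    "fear": [("<s>", 0.0), ("frightening", 0.95), ("hated", -0.24), ("</s>", 0.0)],
    "joy": [("<s>", 0.0), ("liked", 0.53), ("hated", -0.50), ("</s>", 0.0)],
}


def top_word_per_label(attributions):
    """Return the most positively attributed word for each label."""
    return {
        label: max(pairs, key=lambda pair: pair[1])[0]
        for label, pairs in attributions.items()
    }


print(top_word_per_label(word_attributions))
# → {'fear': 'frightening', 'joy': 'liked'}
```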


#### Visualize MultiLabel Classification attributions

Sometimes the numeric attributions can be difficult to read, particularly when there is a lot of text. To help with that, we also provide the `visualize()` method, which uses Captum's built-in visualization library to create an HTML file highlighting the attributions. For this explainer, attributions are shown with respect to each label.

If you are in a notebook, calls to the `visualize()` method will display the visualization in-line. Alternatively you can pass a filepath in as an argument and an HTML file will be created, allowing you to view the explanation HTML in your browser.

```python
cls_explainer.visualize("multilabel_viz.html")
```

<a href="https://github.com/cdpierse/transformers-interpret/blob/master/images/multilabel_example.png">
<img src="https://github.com/cdpierse/transformers-interpret/blob/master/images/multilabel_example.png" width="80%" height="80%" align="center"/>
</a>


</details>

### Zero Shot Classification Explainer
Binary file added images/.DS_Store
Binary file added images/multilabel_example.png
Binary file added notebooks/.DS_Store
4 changes: 1 addition & 3 deletions notebooks/multiclass_classification_example.ipynb
@@ -31,9 +31,7 @@
 "outputs": [],
 "source": [
  "tokenizer = AutoTokenizer.from_pretrained(\"sampathkethineedi/industry-classification\")\n",
- "model = AutoModelForSequenceClassification.from_pretrained(\n",
- "    \"sampathkethineedi/industry-classification\"\n",
- ")"
+ "model = AutoModelForSequenceClassification.from_pretrained(\"sampathkethineedi/industry-classification\")"
 ]
},
{
6 changes: 3 additions & 3 deletions requirements.txt
@@ -1,4 +1,4 @@
 pytest==5.4.2
-captum==0.3.1
-transformers==4.3.2
-ipython==7.31.1
+captum==0.4.1
+transformers==4.15.0
+ipython==7.31.1
2 changes: 1 addition & 1 deletion setup.py
@@ -22,7 +22,7 @@
             "test",
         ]
     ),
-    version="0.5.2",
+    version="0.6.0",
     license="Apache-2.0",
     description="Transformers Interpret is a model explainability tool designed to work exclusively with 🤗 transformers.",
     long_description=long_description,
Binary file added test/.DS_Store
28 changes: 7 additions & 21 deletions test/test_explainer.py
@@ -123,9 +123,7 @@ def test_explainer_init_cuda():

 def test_explainer_make_input_reference_pair():
     explainer = DummyExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
-    input_ids, ref_input_ids, len_inputs = explainer._make_input_reference_pair(
-        "this is a test string"
-    )
+    input_ids, ref_input_ids, len_inputs = explainer._make_input_reference_pair("this is a test string")
     assert isinstance(input_ids, Tensor)
     assert isinstance(ref_input_ids, Tensor)
     assert isinstance(len_inputs, int)
@@ -139,9 +137,7 @@ def test_explainer_make_input_reference_pair():

 def test_explainer_make_input_reference_pair_gpt2():
     explainer = DummyExplainer(GPT2_MODEL, GPT2_TOKENIZER)
-    input_ids, ref_input_ids, len_inputs = explainer._make_input_reference_pair(
-        "this is a test string"
-    )
+    input_ids, ref_input_ids, len_inputs = explainer._make_input_reference_pair("this is a test string")
     assert isinstance(input_ids, Tensor)
     assert isinstance(ref_input_ids, Tensor)
     assert isinstance(len_inputs, int)
@@ -151,9 +147,7 @@ def test_explainer_make_input_reference_pair_gpt2():

 def test_explainer_make_input_token_type_pair_no_sep_idx():
     explainer = DummyExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
-    input_ids, ref_input_ids, len_inputs = explainer._make_input_reference_pair(
-        "this is a test string"
-    )
+    input_ids, ref_input_ids, len_inputs = explainer._make_input_reference_pair("this is a test string")
     (
         token_type_ids,
         ref_token_type_ids,
@@ -169,9 +163,7 @@ def test_explainer_make_input_token_type_pair_no_sep_idx():

 def test_explainer_make_input_token_type_pair_sep_idx():
     explainer = DummyExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
-    input_ids, ref_input_ids, len_inputs = explainer._make_input_reference_pair(
-        "this is a test string"
-    )
+    input_ids, ref_input_ids, len_inputs = explainer._make_input_reference_pair("this is a test string")
     (
         token_type_ids,
         ref_token_type_ids,
@@ -187,12 +179,8 @@

 def test_explainer_make_input_reference_position_id_pair():
     explainer = DummyExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
-    input_ids, ref_input_ids, len_inputs = explainer._make_input_reference_pair(
-        "this is a test string"
-    )
-    position_ids, ref_position_ids = explainer._make_input_reference_position_id_pair(
-        input_ids
-    )
+    input_ids, ref_input_ids, len_inputs = explainer._make_input_reference_pair("this is a test string")
+    position_ids, ref_position_ids = explainer._make_input_reference_position_id_pair(input_ids)

     assert ref_position_ids[0][0] == torch.zeros(len(input_ids[0]))[0]
     for i, val in enumerate(position_ids[0]):
@@ -201,9 +189,7 @@

 def test_explainer_make_attention_mask():
     explainer = DummyExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
-    input_ids, ref_input_ids, len_inputs = explainer._make_input_reference_pair(
-        "this is a test string"
-    )
+    input_ids, ref_input_ids, len_inputs = explainer._make_input_reference_pair("this is a test string")
     attention_mask = explainer._make_attention_mask(input_ids)
     assert len(attention_mask[0]) == len(input_ids[0])
     for i, val in enumerate(attention_mask[0]):
