Skip to content

Commit

Permalink
Torchserve-Kservev2-Bert-Explanations update (kubeflow#2043)
Browse files Browse the repository at this point in the history
Signed-off-by: Shrinath Suresh <shrinath@ideas2it.com>
  • Loading branch information
shrinath-suresh authored Feb 18, 2022
1 parent 7464e0c commit 286e2a1
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 13 deletions.
4 changes: 2 additions & 2 deletions docs/samples/v1beta1/torchserve/v2/bert/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Download the preprocess script from [here](./sequence_classification/Transformer
```bash
torch-model-archiver --model-name BERTSeqClassification --version 1.0 \
--serialized-file Transformer_model/pytorch_model.bin \
--handler ./Transformer_handler_generalized_v2.py \
--handler ./Transformer_kserve_handler.py \
--extra-files "Transformer_model/config.json,./setup_config.json,./Seq_classification_artifacts/index_to_name.json,./Transformer_handler_generalized.py"
```

Expand Down Expand Up @@ -79,7 +79,7 @@ curl -v -H "Host: ${SERVICE_HOSTNAME}" http://${INGRESS_HOST}:${INGRESS_PORT}/v2
Expected output

```bash
{"id": "fa0f4a16-24be-4e82-822b-7ce21cff1016", "model_name": "bert_test", "model_version": "1", "outputs": [{"name": "explain", "shape": [], "datatype": "BYTES", "data": [{"words": ["[CLS]", "[unused65]", "[unused103]", "[unused106]", "[unused106]", "[unused104]", "[unused97]", "[CLS]", "[unused109]", "[MASK]", "[unused31]", "[unused99]", "[unused96]", "[unused110]", "[unused31]", "[unused109]", "[CLS]", "[unused107]", "[unused106]", "[unused109]", "[unused111]", "[CLS]", "[UNK]", "[unused31]", "[unused106]", "[unused105]", "[unused31]", "[unused111]", "[unused99]", "[CLS]", "[unused31]", "[CLS]", "[unused98]", "[unused106]", "[unused105]", "[unused106]", "[unused104]", "[unused116]", "[SEP]"], "importances": [-0.5779647849140105, 0.017149979253482668, 0.02520071691362777, 0.10127131153071542, 0.11157838511306105, 0.10381272285539787, 0.11320268752645515, -0.18749022141160918, 0.09715615163453448, -0.23825046155397892, 0.07830538237901745, 0.052386644292540425, 0.06916019909789417, 0.0489200370513321, 0.06125091233381835, 0.10910945892939933, -0.20546550665577787, 0.03657186541090417, 0.03873832137700618, 0.07419369954398138, 0.03729456936648431, -0.2576498669080684, -0.14288095272100626, 0.04121622648595307, 0.06318685560063542, 0.012703899463284731, 0.03181142622138418, 0.03485410565174061, 0.049515843720263124, -0.18949917348232484, 0.03956454265824759, -0.2113086240763918, 0.028525852720988263, 0.04318882441540453, 0.018988349248547743, 0.07123601660669067, 0.061472429104257806, 0.023899392506903514, 0.49172702017614983], "delta": 0.9374768388549066}]}]}
{"id": "d3b15cad-50a2-4eaf-80ce-8b0a428bd298", "model_name": "BERTSeqClassification", "model_version": "1.0", "outputs": [{"name": "explain", "shape": [], "datatype": "BYTES", "data": [{"words": ["[CLS]", "bloomberg", "has", "decided", "to", "publish", "a", "new", "report", "on", "the", "global", "economy", ".", "[SEP]"], "importances": [0.0, -0.43571255624310423, -0.11062097534384648, 0.11323803203829622, 0.05438679692935377, -0.11364841625009202, 0.15214504085858935, -0.0013061684457894148, 0.05712844103997178, -0.02296408323390218, 0.1937543236757826, -0.12138265438655091, 0.20713335609474381, -0.8044260616647264, 0.0], "delta": -0.019047775223331675}]}]}
```


Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
import ast
import torch
import logging
from Transformer_handler_generalized import TransformersSeqClassifierHandler
from Transformer_handler_generalized import (
TransformersSeqClassifierHandler,
captum_sequence_forward,
construct_input_ref,
summarize_attributions,
get_word_token,
)
import json
from captum.attr import LayerIntegratedGradients

logger = logging.getLogger(__name__)
# TODO Extend the example for token classification, question answering and batch inputs


class TransformersSeqClassifierHandlerV2(TransformersSeqClassifierHandler):
class TransformersKserveHandler(TransformersSeqClassifierHandler):
# Constructor: does no setup of its own; it only runs the initializer of the
# parent class (TransformersSeqClassifierHandler).
def __init__(self):
super(TransformersKserveHandler, self).__init__()

def preprocess(self, requests):
"""Basic text preprocessing, based on the user's chocie of application mode.
Args:
Expand Down Expand Up @@ -46,12 +56,7 @@ def preprocess(self, requests):
input_text = data.get("body")
if isinstance(input_text, (bytes, bytearray)):
input_text = input_text.decode("utf-8")
if (
self.setup_config["captum_explanation"]
and not self.setup_config["mode"] == "question_answering"
):
input_text_target = ast.literal_eval(input_text)
input_text = input_text_target["text"]
input_text = json.loads(input_text)["text"]
max_length = self.setup_config["max_length"]
logger.info("Received text: '%s'", input_text)

Expand All @@ -74,3 +79,47 @@ def preprocess(self, requests):
input_ids_batch = torch.cat((input_ids_batch, input_ids), 0)
attention_mask_batch = torch.cat((attention_mask_batch, attention_mask), 0)
return (input_ids_batch, attention_mask_batch)

def get_insights(self, input_batch, text, target):
    """Compute per-token importances for the text carried in an explain request.

    Layer integrated gradients are run over the model's embedding layer when
    captum explanation has been enabled through ``setup_config``.

    Args:
        input_batch: Batches of token IDs of the text. Not read by this method.
        text (str): JSON document with a "text" key holding the sentence to
            explain and, optionally, a "target" key holding the label to
            attribute against.
        target (int): Label to attribute against when the JSON document has
            no "target" key.

    Returns:
        (list): A one-element list containing a dict with the keys "words"
        (tokens of the input), "importances" (summarized attributions) and
        "delta" (convergence delta of the first sample).

    Raises:
        AttributeError: If captum explanation is disabled in ``setup_config``
            and no ``self.lig`` attribution object is available.
    """
    payload = json.loads(text)
    sentence = payload["text"]
    # The target embedded in the request wins; the ``target`` argument is only
    # a fallback so a payload without the key no longer fails with KeyError.
    self.target = payload.get("target", target)

    if self.setup_config["captum_explanation"]:
        embedding_layer = getattr(self.model, self.setup_config["embedding_name"])
        self.lig = LayerIntegratedGradients(
            captum_sequence_forward, embedding_layer.embeddings
        )
    else:
        logger.warning("Captum Explanation is not chosen and will not be available")
        if getattr(self, "lig", None) is None:
            # Fail before tokenizing instead of hitting an opaque missing
            # attribute further down; the exception type matches what the
            # later ``self.lig`` access would have raised.
            raise AttributeError(
                "Captum explanation is disabled in setup_config; "
                "no attribution method is available for get_insights"
            )

    input_ids, ref_input_ids, attention_mask = construct_input_ref(
        sentence, self.tokenizer, self.device, self.setup_config["mode"]
    )

    attributions, delta = self.lig.attribute(
        inputs=input_ids,
        baselines=ref_input_ids,
        target=self.target,
        additional_forward_args=(attention_mask, 0, self.model),
        return_convergence_delta=True,
    )

    response = {
        "words": get_word_token(input_ids, self.tokenizer),
        "importances": summarize_attributions(attributions).tolist(),
        "delta": delta[0].tolist(),
    }
    return [response]
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"name": "4b7c7d4a-51e4-43c8-af61-04639f6ef4bc",
"shape": -1,
"datatype": "BYTES",
"data": "this year business is good"
}]
"data": "{\"text\":\"Bloomberg has decided to publish a new report on the global economy.\", \"target\":1}"
}
]
}

0 comments on commit 286e2a1

Please sign in to comment.