Alibi-explain runtime kernel_shap fixes (#1017)
* remove unused constant

* add batch and infer output in Alibi Settings (sketched below the file summary)

* add batch condition

* add comment on the explainer types for non-batch

* define `output` to be used for inference payload

* refactor variable

* add tests for specified output

* wire up output

* add test for explainer batch

* add note in test

* add missing await in load from uri

* lint fix
sakoush authored Feb 28, 2023
1 parent 2bac2d5 commit 71da660
Showing 5 changed files with 191 additions and 27 deletions.
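The two new settings this commit introduces, `explainer_batch` and `infer_output`, are plumbed through the runtime's `extra` parameters. Below is a minimal, hypothetical settings sketch mirroring the new batch test in this commit; the model name, `infer_uri` and `infer_output` values are placeholders (the `kernel_shap` type is taken from the commit title), not values from the repository.

from mlserver import ModelSettings
from mlserver_alibi_explain import AlibiExplainRuntime

# Hypothetical model settings using the two new Alibi-explain options:
# - explainer_batch: explain the whole batch instead of only its first element
# - infer_output: which inference-model output to explain when several are returned
settings = ModelSettings(
    name="income-explainer",  # placeholder
    implementation=AlibiExplainRuntime,
    parameters={
        "extra": {
            "infer_uri": "http://localhost:8080/v2/models/income/infer",  # placeholder
            "explainer_type": "kernel_shap",
            "explainer_batch": True,
            "infer_output": "predict_proba",  # placeholder output name
        }
    },
)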
22 changes: 18 additions & 4 deletions runtimes/alibi-explain/mlserver_alibi_explain/common.py
@@ -13,6 +13,7 @@
InferenceRequest,
Parameters,
MetadataModelResponse,
RequestOutput,
)
from mlserver.utils import generate_uuid

@@ -22,8 +23,6 @@

EXPLAINER_TYPE_TAG = "explainer_type"

_MAX_RETRY_ATTEMPT = 3

ENV_PREFIX_ALIBI_EXPLAIN_SETTINGS = "MLSERVER_MODEL_ALIBI_EXPLAIN_"
EXPLAIN_PARAMETERS_TAG = "explain_parameters"
SELDON_SKIP_LOGGING_HEADER = "Seldon-Skip-Logging"
@@ -92,6 +91,10 @@ class Config:

infer_uri: str
explainer_type: str
explainer_batch: Optional[bool] = False
# In cases where the inference model can output multiple fields and
# we are interested in a specific field for explanation
infer_output: Optional[str]
init_parameters: Optional[dict]
ssl_verify_path: Optional[str]

@@ -105,10 +108,14 @@ def import_and_get_class(class_path: str) -> type:
def to_v2_inference_request(
input_data: Union[np.ndarray, List[str]],
metadata: Optional[MetadataModelResponse],
output: Optional[str],
) -> InferenceRequest:
"""
Encode numpy payload to v2 protocol.
If `output` is set, it takes precedence over any outputs that are set in the
`metadata` given that the user then is explicitly defining an output.
Note: We only fetch the first-input name and the list of outputs from the metadata
endpoint currently. We should consider wider reconciliation with data types etc.
@@ -118,19 +125,26 @@ def to_v2_inference_request(
Numpy ndarray to encode
metadata
Extra metadata that can help encode the payload.
output
The output we care about to explain from the inference model.
"""

# MLServer does not really care about a correct input name!
input_name = _DEFAULT_INPUT_NAME
id_name = generate_uuid()
default_outputs = []
outputs = []

if output:
outputs = [RequestOutput(name=output)]

if metadata is not None:
if metadata.inputs:
# we only support a big single input numpy
input_name = metadata.inputs[0].name
if metadata.outputs:
outputs = metadata.outputs
if not output:
default_outputs = [metadata.outputs[0]]

# For List[str] (e.g. AnchorText), we use StringCodec for input
input_payload_codec = StringCodec if type(input_data) == list else NumpyCodec
@@ -146,6 +160,6 @@ def to_v2_inference_request(
use_bytes=False,
)
],
outputs=outputs,
outputs=outputs if outputs else default_outputs,
)
return v2_request
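For illustration, a short sketch of how the `output` precedence above resolves, using the same placeholder names as the tests added in this commit (none of them exist outside the tests):

import numpy as np
from mlserver.types import MetadataModelResponse, MetadataTensor
from mlserver_alibi_explain.common import to_v2_inference_request

metadata = MetadataModelResponse(
    name="dummy",
    platform="dummy",
    inputs=[MetadataTensor(name="input_name", datatype="dummy", shape=[])],
    outputs=[
        MetadataTensor(name="output_name", datatype="dummy", shape=[]),
        MetadataTensor(name="output_name_2", datatype="dummy", shape=[]),
    ],
)

# An explicit `output` takes precedence over the metadata outputs:
request = to_v2_inference_request(np.zeros([2, 4]), metadata, "output_name_2")
# request.outputs == [RequestOutput(name="output_name_2")]

# Without `output`, only the first metadata output is requested:
request = to_v2_inference_request(np.zeros([2, 4]), metadata, None)
# request.outputs == [RequestOutput(name="output_name")]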
@@ -53,12 +53,14 @@ async def load(self) -> bool:
return self.ready

def _explain_impl(self, input_data: Any, explain_parameters: Dict) -> Explanation:
# if we get a list of strings, we can only explain the first elem and there
# is no way of just sending a plain string in v2, it has to be in a list
# as the encoding is List[str] with content_type "BYTES"
# we also assume that the explain data will contain a batch dimension, and in
# current implementation we will only explain the first data element.
input_data = input_data[0]
if not self.alibi_explain_settings.explainer_batch:
# if we get a list of strings, we can only explain the first elem and there
# is no way of just sending a plain string in v2, it has to be in a list
# as the encoding is List[str] with content_type "BYTES"
# we also assume that the explain data will contain a batch dimension,
# and in current implementation we will only explain the first data element.
# this is for explainers that do not support batch, e.g. anchors
input_data = input_data[0]

return self._model.explain(input_data, **explain_parameters)

@@ -74,7 +76,11 @@ def _infer_impl(self, input_data: Union[np.ndarray, List[str]]) -> np.ndarray:
meta_url, ssl_verify_path=self.ssl_verify_path
)

v2_request = to_v2_inference_request(input_data, self.infer_metadata)
v2_request = to_v2_inference_request(
input_data=input_data,
metadata=self.infer_metadata,
output=self.alibi_explain_settings.infer_output,
)
v2_response = remote_predict(
v2_payload=v2_request,
predictor_url=self.infer_uri,
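A minimal sketch of what the `explainer_batch` branch above means for the data the explainer receives; the shapes mirror the new batch test further down, and the explainer call itself is elided:

import numpy as np

input_data = np.array([[1.0, 2.0], [3.0, 4.0]])  # batched payload, shape (2, 2)

explainer_batch = False  # e.g. anchors, which do not support batch
if not explainer_batch:
    # Non-batch explainers only see the first element of the batch.
    input_data = input_data[0]  # shape (2,)
# With explainer_batch=True the full (2, 2) array is passed unchanged
# to self._model.explain(...).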
@@ -37,7 +37,7 @@ async def load(self) -> bool:
init_parameters["model"] = self._inference_model
self._model = self._explainer_class(**init_parameters) # type: ignore
else:
self._model = self._load_from_uri(self._inference_model)
self._model = await self._load_from_uri(self._inference_model)

self.ready = True
return self.ready
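The added `await` above matters because calling an async method without awaiting it assigns a coroutine object instead of the loaded explainer. A toy illustration of the failure mode; the class and URI are purely hypothetical, only the `_load_from_uri` name comes from the diff:

import asyncio

class ToyRuntime:
    async def _load_from_uri(self, uri: str) -> str:
        return f"explainer loaded from {uri}"

    async def load_buggy(self) -> None:
        # Missing await: self._model becomes a coroutine (and Python warns
        # that the coroutine was never awaited).
        self._model = self._load_from_uri("gs://bucket/explainer")

    async def load_fixed(self) -> None:
        self._model = await self._load_from_uri("gs://bucket/explainer")

rt = ToyRuntime()
asyncio.run(rt.load_buggy())
print(type(rt._model))  # <class 'coroutine'>
asyncio.run(rt.load_fixed())
print(rt._model)  # explainer loaded from gs://bucket/explainer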
10 changes: 4 additions & 6 deletions runtimes/alibi-explain/tests/test_alibi_runtime_base.py
@@ -1,11 +1,10 @@
import json
from typing import Any, Dict
from unittest.mock import patch

import alibi.explainers.anchors.anchor_tabular
import pytest
import numpy as np

from typing import Any, Dict
from unittest.mock import patch
import pytest
from alibi.api.interfaces import Explanation
from numpy.testing import assert_array_equal

@@ -24,9 +23,8 @@
remote_predict,
AlibiExplainSettings,
)
from mlserver_alibi_explain.runtime import AlibiExplainRuntime, AlibiExplainRuntimeBase
from mlserver_alibi_explain.errors import InvalidExplanationShape

from mlserver_alibi_explain.runtime import AlibiExplainRuntime, AlibiExplainRuntimeBase
from .helpers.run_async import run_async_as_sync, run_sync_as_async

"""
164 changes: 155 additions & 9 deletions runtimes/alibi-explain/tests/test_black_box.py
@@ -1,31 +1,42 @@
import json
import os
from pathlib import Path
from unittest.mock import patch, MagicMock
from unittest.mock import MagicMock
from unittest.mock import patch

import alibi.explainers.anchors.anchor_tabular
import numpy as np
import pytest
import tensorflow as tf
from alibi.api.interfaces import Explanation
from alibi.saving import load_explainer
from numpy.testing import assert_allclose

from mlserver import MLModel
from mlserver.codecs import NumpyCodec, StringCodec
from mlserver import ModelSettings, MLModel
from mlserver.codecs import NumpyCodec
from mlserver.codecs import StringCodec
from mlserver.types import (
InferenceRequest,
InferenceResponse,
Parameters,
RequestInput,
MetadataModelResponse,
MetadataTensor,
)
from mlserver.types import (
MetadataModelResponse,
RequestOutput,
)
from mlserver_alibi_explain import AlibiExplainRuntime
from mlserver_alibi_explain.common import (
convert_from_bytes,
)
from mlserver_alibi_explain.common import (
to_v2_inference_request,
_DEFAULT_INPUT_NAME,
)

from mlserver_alibi_explain.explainers.black_box_runtime import (
AlibiExplainBlackBoxRuntime,
)
from mlserver_alibi_explain.runtime import AlibiExplainRuntime
from .helpers.run_async import run_sync_as_async
from .helpers.tf_model import get_tf_mnist_model_uri

@@ -137,12 +148,13 @@ async def test_end_2_end_explain_v1_output(


@pytest.mark.parametrize(
"payload, metadata, expected_v2_request",
"payload, metadata, output, expected_v2_request",
[
# numpy payload
(
np.zeros([2, 4]),
None,
None,
InferenceRequest(
id=_DEFAULT_ID_NAME,
parameters=Parameters(content_type=NumpyCodec.ContentType),
@@ -169,6 +181,7 @@ async def test_end_2_end_explain_v1_output(
MetadataTensor(name="output_name", datatype="dummy", shape=[])
],
),
None,
InferenceRequest(
id=_DEFAULT_ID_NAME,
parameters=Parameters(content_type=NumpyCodec.ContentType),
@@ -186,10 +199,93 @@ async def test_end_2_end_explain_v1_output(
], # inserted from metadata above
),
),
# multiple outputs in the metadata, return only the first output
(
np.zeros([2, 4]),
MetadataModelResponse(
name="dummy",
platform="dummy",
inputs=[MetadataTensor(name="input_name", datatype="dummy", shape=[])],
outputs=[
MetadataTensor(name="output_name", datatype="dummy", shape=[]),
MetadataTensor(name="output_name_2", datatype="dummy", shape=[]),
],
),
None,
InferenceRequest(
id=_DEFAULT_ID_NAME,
parameters=Parameters(content_type=NumpyCodec.ContentType),
inputs=[
RequestInput(
parameters=Parameters(content_type=NumpyCodec.ContentType),
name="input_name", # inserted from metadata above
data=np.zeros([2, 4]).flatten().tolist(),
shape=[2, 4],
datatype="FP64", # default for np.zeros
)
],
outputs=[
RequestOutput(name="output_name"),
], # inserted from metadata above
),
),
# Specified output
(
np.zeros([2, 4]),
None,
"output_name",
InferenceRequest(
id=_DEFAULT_ID_NAME,
parameters=Parameters(content_type=NumpyCodec.ContentType),
inputs=[
RequestInput(
parameters=Parameters(content_type=NumpyCodec.ContentType),
name=_DEFAULT_INPUT_NAME,
data=np.zeros([2, 4]).flatten().tolist(),
shape=[2, 4],
datatype="FP64", # default for np.zeros
)
],
outputs=[
RequestOutput(name="output_name"),
], # inserted from output
),
),
# Specified output and metadata
(
np.zeros([2, 4]),
MetadataModelResponse(
name="dummy",
platform="dummy",
inputs=[MetadataTensor(name="input_name", datatype="dummy", shape=[])],
outputs=[
MetadataTensor(name="output_name", datatype="dummy", shape=[]),
MetadataTensor(name="output_name_2", datatype="dummy", shape=[]),
],
),
"output_name_2",
InferenceRequest(
id=_DEFAULT_ID_NAME,
parameters=Parameters(content_type=NumpyCodec.ContentType),
inputs=[
RequestInput(
parameters=Parameters(content_type=NumpyCodec.ContentType),
name="input_name", # from metadata
data=np.zeros([2, 4]).flatten().tolist(),
shape=[2, 4],
datatype="FP64", # default for np.zeros
)
],
outputs=[
RequestOutput(name="output_name_2"),
], # from output
),
),
# List[str] payload
(
["dummy", "dummy text"],
None,
None,
InferenceRequest(
id=_DEFAULT_ID_NAME,
parameters=Parameters(content_type=StringCodec.ContentType),
@@ -211,6 +307,56 @@ async def test_end_2_end_explain_v1_output(
"mlserver_alibi_explain.common.generate_uuid",
MagicMock(return_value=_DEFAULT_ID_NAME),
)
def test_encode_inference_request__as_expected(payload, metadata, expected_v2_request):
encoded_request = to_v2_inference_request(payload, metadata)
def test_encode_inference_request__as_expected(
payload, metadata, output, expected_v2_request
):
encoded_request = to_v2_inference_request(payload, metadata, output)
assert encoded_request == expected_v2_request


@pytest.mark.parametrize(
"batch",
[True, False, None],
)
async def test_backbox_explain_with_batch(batch):
data_np = np.array([[1.0, 2.0], [3.0, 4.0]])

def _explain_impl(input_data: np.ndarray) -> Explanation:
if batch:
assert input_data.shape == (2, 2)
else:
assert input_data.shape == (2,) # we have returned the first element
return Explanation(meta={}, data={})

rt = AlibiExplainBlackBoxRuntime(
settings=ModelSettings(
name="foo",
implementation=AlibiExplainRuntime,
parameters={
"extra": {
"infer_uri": "dum",
"explainer_type": "dum",
"explainer_batch": batch,
}
},
),
explainer_class=alibi.explainers.anchors.anchor_tabular.AnchorTabular,
)
rt._model = alibi.explainers.anchors.anchor_tabular.AnchorTabular(
lambda x: x, ["a"]
)
rt._model.explain = _explain_impl

inference_request = InferenceRequest(
inputs=[
RequestInput(
name="predict",
shape=data_np.shape,
data=data_np.tolist(),
datatype="FP32",
)
],
)

res = await rt.predict(inference_request)
assert isinstance(res, InferenceResponse)
