Refactor model wrapper (#1056)
* remove model wrapper and simplify class creation

* lint fixes
sakoush authored Mar 24, 2023
1 parent 1dfd221 commit 137d5a6
Showing 4 changed files with 37 additions and 150 deletions.
87 changes: 8 additions & 79 deletions runtimes/alibi-explain/mlserver_alibi_explain/runtime.py
@@ -1,20 +1,17 @@
import json
import asyncio
import numpy as np
import functools
import json
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Optional, Dict

from typing import Any, Optional, List, Dict

import numpy as np
import pandas as pd
from alibi.api.interfaces import Explanation, Explainer
from alibi.saving import load_explainer
from concurrent.futures import ThreadPoolExecutor

from mlserver.codecs import (
    NumpyRequestCodec,
    InputCodecLike,
    StringCodec,
    RequestCodecLike,
)
from mlserver.errors import ModelParametersMissing
from mlserver.handlers import custom_handler
@@ -24,14 +21,10 @@
from mlserver.types import (
    InferenceRequest,
    InferenceResponse,
    RequestInput,
    MetadataModelResponse,
    Parameters,
    MetadataTensor,
    ResponseOutput,
)
from mlserver.utils import get_model_uri

from mlserver_alibi_explain.alibi_dependency_reference import (
    get_mlmodel_class_as_str,
    get_alibi_class_as_str,
@@ -57,6 +50,7 @@ def __init__(
        self._executor = ThreadPoolExecutor()
        super().__init__(settings)

    @custom_handler(rest_path="/explain")
    async def explain_v1_output(self, request: InferenceRequest) -> Response:
        """
        A custom endpoint to return explanation results in plain json format (no v2
@@ -148,10 +142,10 @@ def _explain_impl(self, input_data: Any, explain_parameters: Dict) -> Explanation
        raise NotImplementedError


class AlibiExplainRuntime(MLModel):
class AlibiExplainRuntime:
    """Wrapper / Factory class for specific alibi explain runtimes"""

    def __init__(self, settings: ModelSettings):
    def __new__(cls, settings: ModelSettings):
        # TODO: we probably want to validate the enum more sanely here
        # we do not want to construct a specific alibi settings here because
        # it might be dependent on type
@@ -165,69 +159,4 @@ def __init__(self, settings: ModelSettings):

        alibi_class = import_and_get_class(get_alibi_class_as_str(explainer_type))

        self._rt = rt_class(settings, alibi_class)

    @property
    def name(self) -> str:
        return self._rt.name

    @property
    def version(self) -> Optional[str]:
        return self._rt.version

    @property
    def settings(self) -> ModelSettings:
        return self._rt.settings

    @property
    def inputs(self) -> Optional[List[MetadataTensor]]:
        return self._rt.inputs

    @inputs.setter
    def inputs(self, value: List[MetadataTensor]):
        self._rt.inputs = value

    @property
    def outputs(self) -> Optional[List[MetadataTensor]]:
        return self._rt.outputs

    @outputs.setter
    def outputs(self, value: List[MetadataTensor]):
        self._rt.outputs = value

    @property  # type: ignore
    def ready(self) -> bool:  # type: ignore
        return self._rt.ready

    @ready.setter
    def ready(self, value: bool):
        self._rt.ready = value

    def decode(
        self,
        request_input: RequestInput,
        default_codec: Optional[InputCodecLike] = None,
    ) -> Any:
        return self._rt.decode(request_input, default_codec)

    def decode_request(
        self,
        inference_request: InferenceRequest,
        default_codec: Optional[RequestCodecLike] = None,
    ) -> Any:
        return self._rt.decode_request(inference_request, default_codec)

    async def metadata(self) -> MetadataModelResponse:
        return await self._rt.metadata()

    async def load(self) -> bool:
        return await self._rt.load()

    async def predict(self, payload: InferenceRequest) -> InferenceResponse:
        return await self._rt.predict(payload)

    # we add _explain_v1_output here to enable the registration and routing of custom
    # endpoint to `_rt.explain_v1_output`
    @custom_handler(rest_path="/explain")
    async def _explain_v1_output(self, request: InferenceRequest) -> Response:
        return await self._rt.explain_v1_output(request)
        return rt_class(settings, alibi_class)
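Reviewer note: the crux of the refactor is visible in this file. `AlibiExplainRuntime` used to subclass `MLModel` and forward every property and method to an inner `self._rt`; now its `__new__` simply returns an instance of the resolved runtime class, so all of the delegation above could be deleted. A minimal sketch of the same `__new__` factory pattern, with hypothetical class names and a plain `settings` dict (not MLServer's actual API):

```python
from typing import Union


class TabularRuntime:
    def __init__(self, settings: dict):
        self.settings = settings


class ImageRuntime:
    def __init__(self, settings: dict):
        self.settings = settings


class ExplainerRuntime:
    """Factory: constructing this class hands back a concrete runtime."""

    def __new__(cls, settings: dict) -> Union[TabularRuntime, ImageRuntime]:
        # When __new__ returns an object that is not an instance of cls,
        # Python skips calling cls.__init__ on it entirely.
        if settings["explainer_type"] == "tabular":
            return TabularRuntime(settings)
        return ImageRuntime(settings)


rt = ExplainerRuntime({"explainer_type": "tabular"})
assert isinstance(rt, TabularRuntime)  # the factory dispatched on type
assert not isinstance(rt, ExplainerRuntime)  # the factory class never leaks out
```

One trade-off of this approach: `isinstance(x, AlibiExplainRuntime)` no longer holds for the returned object, which is why the tests below gain explicit `isinstance` assertions against the concrete classes.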
2 changes: 2 additions & 0 deletions runtimes/alibi-explain/tests/conftest.py
@@ -234,6 +234,7 @@ async def anchor_image_runtime_with_remote_predict_patch(
            ),
        )
    )
    assert isinstance(rt, MLModel)
    await rt.load()

    yield rt
@@ -254,6 +255,7 @@ async def integrated_gradients_runtime() -> AlibiExplainRuntime:
            ),
        )
    )
    assert isinstance(rt, MLModel)
    await rt.load()

    return rt
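Reviewer note: the new `assert isinstance(rt, MLModel)` lines in these fixtures do double duty: a runtime sanity check that the factory really produced an `MLModel` subclass, and (presumably, given the "lint fixes" in the commit message) a type-narrowing hint so that `rt.load()` and friends type-check now that `AlibiExplainRuntime` itself no longer subclasses `MLModel`. A small sketch of the idiom, with hypothetical names:

```python
class Base:
    async def load(self) -> bool:
        return True


class Factory:
    """Constructing Factory actually hands back a Base instance."""

    def __new__(cls):
        # With an unannotated __new__, type checkers generally still treat
        # Factory() as returning Factory, which has no load() method.
        return Base()


rt = Factory()
assert isinstance(rt, Base)  # narrows the static type and documents intent
# After the assert, `await rt.load()` type-checks as well as runs.
```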
79 changes: 14 additions & 65 deletions runtimes/alibi-explain/tests/test_alibi_runtime_base.py
@@ -8,22 +8,27 @@
from alibi.api.interfaces import Explanation
from numpy.testing import assert_array_equal

from mlserver import ModelSettings, MLModel
from mlserver import ModelSettings
from mlserver.codecs import NumpyCodec
from mlserver.types import (
    InferenceRequest,
    InferenceResponse,
    Parameters,
    RequestInput,
    ResponseOutput,
    MetadataTensor,
)
from mlserver_alibi_explain.common import (
    convert_from_bytes,
    remote_predict,
    AlibiExplainSettings,
)
from mlserver_alibi_explain.errors import InvalidExplanationShape
from mlserver_alibi_explain.explainers.black_box_runtime import (
    AlibiExplainBlackBoxRuntime,
)
from mlserver_alibi_explain.explainers.integrated_gradients import (
    IntegratedGradientsWrapper,
)
from mlserver_alibi_explain.runtime import AlibiExplainRuntime, AlibiExplainRuntimeBase
from .helpers.run_async import run_async_as_sync, run_sync_as_async

@@ -53,6 +58,7 @@ async def test_integrated_gradients__smoke(
            )
        ],
    )
    assert isinstance(integrated_gradients_runtime, IntegratedGradientsWrapper)
    response = await integrated_gradients_runtime.predict(inference_request)
    res = convert_from_bytes(response.outputs[0], ty=str)
    res_dict = json.dumps(res)
@@ -82,6 +88,9 @@ async def test_anchors__smoke(
            )
        ],
    )
    assert isinstance(
        anchor_image_runtime_with_remote_predict_patch, AlibiExplainBlackBoxRuntime
    )
    response = await anchor_image_runtime_with_remote_predict_patch.predict(
        inference_request
    )
@@ -123,67 +132,6 @@ def _sync_request(*args, **kwargs):
    assert isinstance(res, InferenceResponse)


async def test_alibi_runtime_wrapper(custom_runtime_tf: MLModel):
    """
    Checks that the wrapper returns the expected values from the underlying rt
    """

    class _MockInit(AlibiExplainRuntime):
        def __init__(self, settings: ModelSettings):
            self._rt = custom_runtime_tf

    data = np.random.randn(10, 28, 28, 1) * 255
    inference_request = InferenceRequest(
        parameters=Parameters(content_type=NumpyCodec.ContentType),
        inputs=[
            RequestInput(
                name="predict",
                shape=data.shape,
                data=data.tolist(),
                datatype="FP32",
            )
        ],
    )

    # settings object is dummy and discarded
    wrapper = _MockInit(ModelSettings(name="foo", implementation=AlibiExplainRuntime))

    assert wrapper.settings == custom_runtime_tf.settings
    assert wrapper.name == custom_runtime_tf.name
    assert wrapper.version == custom_runtime_tf.version
    assert wrapper.inputs == custom_runtime_tf.inputs
    assert wrapper.outputs == custom_runtime_tf.outputs
    assert wrapper.ready == custom_runtime_tf.ready

    assert await wrapper.metadata() == await custom_runtime_tf.metadata()
    assert await wrapper.predict(inference_request) == await custom_runtime_tf.predict(
        inference_request
    )

    # check setters
    dummy_shape_metadata = [
        MetadataTensor(
            name="dummy",
            datatype="FP32",
            shape=[1, 2],
        )
    ]
    wrapper.inputs = dummy_shape_metadata
    custom_runtime_tf.inputs = dummy_shape_metadata
    assert wrapper.inputs == custom_runtime_tf.inputs

    wrapper.outputs = dummy_shape_metadata
    custom_runtime_tf.outputs = dummy_shape_metadata
    assert wrapper.outputs == custom_runtime_tf.outputs

    wrapper_public_funcs = list(filter(lambda x: not x.startswith("_"), dir(wrapper)))
    expected_public_funcs = list(
        filter(lambda x: not x.startswith("_"), dir(custom_runtime_tf))
    )

    assert wrapper_public_funcs == expected_public_funcs


def fake_predictor(x):
    return x

@@ -289,7 +237,8 @@ async def test_v1_invalid_predict(
    async def _mocked_predict(request: InferenceRequest) -> InferenceResponse:
        return response

    with patch.object(integrated_gradients_runtime._rt, "predict", _mocked_predict):
    with patch.object(integrated_gradients_runtime, "predict", _mocked_predict):
        request = InferenceRequest(inputs=[])
        with pytest.raises(InvalidExplanationShape):
            await integrated_gradients_runtime._rt.explain_v1_output(request)
            assert isinstance(integrated_gradients_runtime, AlibiExplainRuntimeBase)
            await integrated_gradients_runtime.explain_v1_output(request)
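Reviewer note: with the wrapper gone, the test patches `predict` on the runtime object itself instead of reaching through the private `_rt`. A self-contained sketch of the same `patch.object` technique on an async method (class and values are illustrative):

```python
import asyncio
from unittest.mock import patch


class Runtime:
    async def predict(self, payload: str) -> str:
        return f"real:{payload}"


async def main() -> None:
    rt = Runtime()

    async def _mocked_predict(payload: str) -> str:
        # Patching on the instance installs a plain attribute, so the
        # replacement is called without `self`.
        return f"mocked:{payload}"

    # patch.object swaps the attribute only for the duration of the block.
    with patch.object(rt, "predict", _mocked_predict):
        assert await rt.predict("x") == "mocked:x"

    assert await rt.predict("x") == "real:x"  # original restored on exit


asyncio.run(main())
```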
19 changes: 13 additions & 6 deletions runtimes/alibi-explain/tests/test_black_box.py
@@ -36,7 +36,7 @@
from mlserver_alibi_explain.explainers.black_box_runtime import (
    AlibiExplainBlackBoxRuntime,
)
from mlserver_alibi_explain.runtime import AlibiExplainRuntime
from mlserver_alibi_explain.runtime import AlibiExplainRuntime, AlibiExplainRuntimeBase
from .helpers.run_async import run_sync_as_async
from .helpers.tf_model import get_tf_mnist_model_uri

@@ -74,8 +74,11 @@ async def test_predict_impl(

    # [batch, image_x, image_y, channel]
    data = np.random.randn(10, 28, 28, 1) * 255
    assert isinstance(
        anchor_image_runtime_with_remote_predict_patch, AlibiExplainBlackBoxRuntime
    )
    actual_result = await run_sync_as_async(
        anchor_image_runtime_with_remote_predict_patch._rt._infer_impl, data
        anchor_image_runtime_with_remote_predict_patch._infer_impl, data
    )

    # now we go via the inference model and see if we get the same results
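Reviewer note: `run_sync_as_async` here is a local test helper for invoking the synchronous `_infer_impl` from an async test. For reference, a common way to achieve the same thing with only the standard library is `asyncio.to_thread` (a sketch, not the helper's actual implementation):

```python
import asyncio
import time


def infer_impl(data: float) -> float:
    time.sleep(0.01)  # stand-in for blocking inference work
    return data * 2


async def test_like_flow() -> None:
    # Run the blocking callable on a worker thread so the event loop stays free.
    result = await asyncio.to_thread(infer_impl, 21.0)
    assert result == 42.0


asyncio.run(test_like_flow())
```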
@@ -110,6 +113,9 @@ async def test_end_2_end(
):
    # in this test we are getting explanation and making sure that is the same one
    # as returned by alibi directly
    assert isinstance(
        anchor_image_runtime_with_remote_predict_patch, AlibiExplainBlackBoxRuntime
    )
    runtime_result = await anchor_image_runtime_with_remote_predict_patch.predict(
        payload
    )
@@ -136,10 +142,11 @@ async def test_end_2_end_explain_v1_output(
):
    # in this test we get raw explanation as opposed to v2

    response = (
        await anchor_image_runtime_with_remote_predict_patch._rt.explain_v1_output(
            payload
        )
    )
    assert isinstance(
        anchor_image_runtime_with_remote_predict_patch, AlibiExplainRuntimeBase
    )
    response = await anchor_image_runtime_with_remote_predict_patch.explain_v1_output(
        payload
    )

    response_body = json.loads(response.body)
