Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for white-box explainers to alibi-explain runtime #1279

Merged
merged 20 commits into from
Jul 6, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ class ExplainerDependencyReference:
_ANCHOR_TABULAR_TAG = "anchor_tabular"
_KERNEL_SHAP_TAG = "kernel_shap"
_INTEGRATED_GRADIENTS_TAG = "integrated_gradients"
_TREE_SHAP_TAG = "tree_shap"
_TREE_PARTIAL_DEPENDENCE_TAG = "tree_partial_dependence"
_TREE_PARTIAL_DEPENDENCE_VARIANCE_TAG = "tree_partial_dependence_variance"


# NOTE: to add new explainers populate the below dict with a new
Expand All @@ -30,6 +33,7 @@ class ExplainerDependencyReference:

_BLACKBOX_MODULE = "mlserver_alibi_explain.explainers.black_box_runtime"
_INTEGRATED_GRADIENTS_MODULE = "mlserver_alibi_explain.explainers.integrated_gradients"
_WHITEBOX_SKLEARN_MODULE = "mlserver_alibi_explain.explainers.sklearn_api_runtime"

_TAG_TO_RT_IMPL: Dict[str, ExplainerDependencyReference] = {
_ANCHOR_IMAGE_TAG: ExplainerDependencyReference(
Expand Down Expand Up @@ -57,6 +61,21 @@ class ExplainerDependencyReference:
runtime_class=f"{_INTEGRATED_GRADIENTS_MODULE}.IntegratedGradientsWrapper",
alibi_class="alibi.explainers.IntegratedGradients",
),
_TREE_SHAP_TAG: ExplainerDependencyReference(
explainer_name=_TREE_SHAP_TAG,
runtime_class=f"{_WHITEBOX_SKLEARN_MODULE}.AlibiExplainSKLearnAPIRuntime",
alibi_class="alibi.explainers.TreeShap",
),
_TREE_PARTIAL_DEPENDENCE_TAG: ExplainerDependencyReference(
explainer_name=_TREE_PARTIAL_DEPENDENCE_TAG,
runtime_class=f"{_WHITEBOX_SKLEARN_MODULE}.AlibiExplainSKLearnAPIRuntime",
alibi_class="alibi.explainers.TreePartialDependence",
),
_TREE_PARTIAL_DEPENDENCE_VARIANCE_TAG: ExplainerDependencyReference(
adriangonz marked this conversation as resolved.
Show resolved Hide resolved
explainer_name=_TREE_PARTIAL_DEPENDENCE_VARIANCE_TAG,
runtime_class=f"{_WHITEBOX_SKLEARN_MODULE}.AlibiExplainSKLearnAPIRuntime",
alibi_class="alibi.explainers.PartialDependenceVariance",
),
}


Expand All @@ -66,6 +85,9 @@ class ExplainerEnum(str, Enum):
anchor_tabular = _ANCHOR_TABULAR_TAG
kernel_shap = _KERNEL_SHAP_TAG
integrated_gradients = _INTEGRATED_GRADIENTS_TAG
tree_shap = _TREE_SHAP_TAG
tree_partial_dependence = _TREE_PARTIAL_DEPENDENCE_TAG
tree_partial_dependence_variance = _TREE_PARTIAL_DEPENDENCE_VARIANCE_TAG


def get_mlmodel_class_as_str(tag: Union[ExplainerEnum, str]) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ async def load(self) -> bool:
# TODO: use init explainer field instead?
if self.alibi_explain_settings.init_parameters is not None:
init_parameters = self.alibi_explain_settings.init_parameters
init_parameters["predictor"] = self._infer_impl
adriangonz marked this conversation as resolved.
Show resolved Hide resolved
self._model = self._explainer_class(**init_parameters) # type: ignore
self._model = self._explainer_class(self._infer_impl, **init_parameters) # type: ignore
else:
self._model = await self._load_from_uri(self._infer_impl)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Any

import joblib
import lightgbm as lgb
from xgboost.core import XGBoostError
from lightgbm.basic import LightGBMError

from mlserver_xgboost.xgboost import _load_sklearn_interface as load_xgb_model
from mlserver.errors import InvalidModelURI
from mlserver_alibi_explain.explainers.white_box_runtime import AlibiExplainWhiteBoxRuntime


class AlibiExplainSKLearnAPIRuntime(AlibiExplainWhiteBoxRuntime):
    """
    Runtime for white-box explainers that require access to a tree-based model
    matching the SKLearn API, such as a sklearn, XGBoost, or LightGBM model.
    Example explainers include TreeShap and TreePartialDependence.
    """

    async def _get_inference_model(self) -> Any:
        """Load the underlying inference model from ``infer_uri``.

        The artifact is loaded by trying each supported format in turn:
        sklearn (joblib) first, then XGBoost, then LightGBM. LightGBM is
        tried last as it emits a warning when loading fails.

        Returns
        -------
        The loaded model object.

        Raises
        ------
        InvalidModelURI
            If the artifact cannot be loaded by any supported library.
        """
        inference_model_path = self.alibi_explain_settings.infer_uri
        # Attempt to load model in order: sklearn, XGBoost, LightGBM
        # (fixed: the previous comment claimed XGBoost was tried first).
        # TODO - add support for CatBoost (would require model_type =
        # 'classifier' or 'regressor' in settings)
        try:
            # Try to load as sklearn model first
            model = joblib.load(inference_model_path)
        except (IndexError, KeyError, IOError):
            try:
                # Try to load as XGBoost model
                model = load_xgb_model(inference_model_path)
            except XGBoostError:
                try:
                    # Try to load as LightGBM model (do this last as it
                    # raises a warning if not successful)
                    model = lgb.Booster(model_file=inference_model_path)
                except LightGBMError:
                    # No supported format matched the artifact
                    raise InvalidModelURI(self.name, inference_model_path)

        return model
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABC
from typing import Any, Type
from typing import Any, Type, Dict

from alibi.api.interfaces import Explainer
from alibi.api.interfaces import Explainer, Explanation

from mlserver import ModelSettings
from mlserver_alibi_explain.common import AlibiExplainSettings
Expand Down Expand Up @@ -29,17 +29,23 @@ def __init__(self, settings: ModelSettings, explainer_class: Type[Explainer]):
super().__init__(settings, explainer_settings)

async def load(self) -> bool:
# white box explainers requires access to the full inference model
self._inference_model = await self._get_inference_model()

if self.alibi_explain_settings.init_parameters is not None:
# Instantiate explainer with init parameters (and give it full inference model)
init_parameters = self.alibi_explain_settings.init_parameters
# white box explainers requires access to the inference model
init_parameters["model"] = self._inference_model
self._model = self._explainer_class(**init_parameters) # type: ignore
self._model = self._explainer_class(self._inference_model, **init_parameters) # type: ignore
adriangonz marked this conversation as resolved.
Show resolved Hide resolved
else:
# Load explainer from URI (and give it full inference model)
self._model = await self._load_from_uri(self._inference_model)

return True

def _explain_impl(self, input_data: Any, explain_parameters: Dict) -> Explanation:
adriangonz marked this conversation as resolved.
Show resolved Hide resolved
# TODO: how are we going to deal with that?
assert self._inference_model is not None, "Inference model is not set"
return self._model.explain(input_data, **explain_parameters)

    async def _get_inference_model(self) -> Any:
        """Load and return the underlying inference model; must be implemented by subclasses."""
        raise NotImplementedError
4 changes: 1 addition & 3 deletions runtimes/alibi-explain/mlserver_alibi_explain/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,7 @@ async def _async_explain_impl(
)

async def _load_from_uri(self, predictor: Any) -> Explainer:
# load the model from disk
# full model is passed as `predictor`
# load the model from disk
"""Load the explainer from disk, and pass the predictor"""
model_parameters: Optional[ModelParameters] = self.settings.parameters
if model_parameters is None:
raise ModelParametersMissing(self.name)
Expand Down
3 changes: 3 additions & 0 deletions runtimes/alibi-explain/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ packages = [{include = "mlserver_alibi_explain"}]
[tool.poetry.dependencies]
python = "^3.8.1,<3.12"
mlserver = "*"
mlserver_sklearn = "*"
adriangonz marked this conversation as resolved.
Show resolved Hide resolved
mlserver_xgboost = "*"
mlserver_lightgbm = "*"
orjson = "*"
alibi = {extras = ["shap", "tensorflow"], version = "*"}

Expand Down
15 changes: 15 additions & 0 deletions runtimes/alibi-explain/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import AsyncIterable, Dict, Any, Iterable
from unittest.mock import patch
from typing import Type
import joblib

from httpx import AsyncClient
from fastapi import FastAPI
Expand All @@ -17,6 +18,8 @@
from alibi.api.interfaces import Explanation, Explainer
from alibi.explainers import AnchorImage

from sklearn.base import BaseEstimator

from mlserver import MLModel
from mlserver.handlers import DataPlane, ModelRepositoryHandlers
from mlserver.parallel import InferencePoolRegistry
Expand All @@ -33,6 +36,7 @@
from mlserver_alibi_explain.runtime import AlibiExplainRuntime, AlibiExplainRuntimeBase

from .helpers.tf_model import get_tf_mnist_model_uri, TFMNISTModel
from .helpers.sk_model import get_sk_income_model_uri, get_income_data
from .helpers.run_async import run_async_as_sync
from .helpers.metrics import unregister_metrics

Expand Down Expand Up @@ -323,3 +327,14 @@ def _train_anchor_image_explainer() -> None:

_ANCHOR_IMAGE_DIR.mkdir(parents=True)
anchor_image.save(_ANCHOR_IMAGE_DIR)


@pytest.fixture(scope="module")
def sk_income_model() -> BaseEstimator:
    """Module-scoped sklearn classifier trained on the adult income dataset."""
    return joblib.load(get_sk_income_model_uri())


@pytest.fixture(scope="module")
def income_data() -> dict:
    """Module-scoped adult income dataset dictionary."""
    data = get_income_data()
    return data
66 changes: 66 additions & 0 deletions runtimes/alibi-explain/tests/helpers/sk_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import os
from pathlib import Path
import joblib

from sklearn.ensemble import GradientBoostingClassifier
adriangonz marked this conversation as resolved.
Show resolved Hide resolved
from alibi.datasets import fetch_adult

from mlserver import MLModel
from mlserver.codecs import NumpyCodec
from mlserver.types import InferenceRequest, InferenceResponse


_MODEL_PATH = Path(os.path.dirname(__file__)).parent / ".data" / "sk_income" / "model.joblib"
adriangonz marked this conversation as resolved.
Show resolved Hide resolved


def get_sk_income_model_uri() -> Path:
if not _MODEL_PATH.exists():
_train_sk_income()
return _MODEL_PATH


class SKIncomeModel(MLModel):
    """Test MLModel wrapper around the sklearn adult-income classifier."""

    async def predict(self, payload: InferenceRequest) -> InferenceResponse:
        """Decode the first input, run the sklearn model, and encode the result."""
        np_codec = NumpyCodec
        model_input = payload.inputs[0]
        input_data = np_codec.decode_input(model_input)
        # Fix: sklearn estimators are not callable — the original
        # `self._model(input_data)` would raise TypeError. Use .predict().
        output_data = self._model.predict(input_data)
        return InferenceResponse(
            model_name=self.name,
            outputs=[np_codec.encode_output("predict", output_data)]
        )

    async def load(self) -> bool:
        # get_sk_income_model_uri trains and caches the model if missing.
        self._model = joblib.load(get_sk_income_model_uri())
        return True
adriangonz marked this conversation as resolved.
Show resolved Hide resolved


def _train_sk_income() -> None:
    """Train a GradientBoostingClassifier on the adult income data and save it to disk."""
    data = get_income_data()
    X_train, Y_train = data['X'], data['Y']

    model = GradientBoostingClassifier(n_estimators=50)
    model.fit(X_train, Y_train)

    # exist_ok=True: the .data directory may already exist from other test
    # fixtures; without it mkdir raises FileExistsError on re-runs.
    _MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(model, _MODEL_PATH)


def get_income_data() -> dict:
    """Fetch the adult income dataset and package it into a dictionary."""
    print('Generating adult dataset...')
    adult = fetch_adult()

    # Package into dictionary
    return {
        'X': adult.data,
        'Y': adult.target,
        'feature_names': adult.feature_names,
        'category_map': adult.category_map,
        'target_names': adult.target_names,
    }
116 changes: 116 additions & 0 deletions runtimes/alibi-explain/tests/helpers/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from typing import Union, Literal, Optional, Tuple
from pathlib import Path
import numpy as np

from alibi.api.interfaces import Explainer

from mlserver.types import InferenceRequest, Parameters, RequestInput
from mlserver.codecs import NumpyCodec
from mlserver.settings import ModelSettings, ModelParameters
from mlserver_alibi_explain import AlibiExplainRuntime
from mlserver_alibi_explain.common import AlibiExplainSettings, import_and_get_class
from mlserver_alibi_explain.alibi_dependency_reference import get_alibi_class_as_str
from .sk_model import get_sk_income_model_uri


def train_explainer(
    explainer_tag: str,
    save_dir: Optional[Path],
    fit: Union[np.ndarray, Literal[False, 'no-data']] = False,
    *args,
    **kwargs
) -> Explainer:
    """
    Train and (optionally) save an explainer.

    Parameters
    ----------
    explainer_tag
        Registry tag used to resolve the alibi explainer class.
    save_dir
        Directory to save the explainer to, or `None` to skip saving.
    fit
        Data to fit the explainer on, `False` to skip fitting, or
        `'no-data'` to call `fit()` with no arguments (e.g. `TreeShap`
        with the path-dependent algorithm).
    """
    # Instantiate explainer
    klass = import_and_get_class(get_alibi_class_as_str(explainer_tag))
    explainer = klass(*args, **kwargs)

    # Fit explainer. Fix: `fit` may be a numpy array, for which both
    # `if fit:` and `fit == 'no-data'` are invalid (ambiguous truth value /
    # elementwise comparison) — branch on the type first.
    if isinstance(fit, np.ndarray):
        explainer.fit(fit)
    elif fit == 'no-data':
        explainer.fit()

    # Save explainer
    if save_dir:
        explainer.save(save_dir)

    return explainer


def build_request(data: np.ndarray, **explain_kwargs) -> InferenceRequest:
    """
    Build an inference request from a numpy array.
    """
    request_input = RequestInput(
        name="predict",
        shape=data.shape,
        data=data.tolist(),
        datatype="FP32",
    )
    request_params = Parameters(
        content_type=NumpyCodec.ContentType,
        explain_parameters=explain_kwargs,
    )
    return InferenceRequest(parameters=request_params, inputs=[request_input])


def build_test_case(explainer_type: str, init_kwargs: dict, explain_kwargs: dict,
                    fit: Union[np.ndarray, Literal[False, 'no-data']], save_dir: Optional[Path],
                    payload: np.ndarray) \
        -> Tuple[ModelSettings, Explainer, InferenceRequest, dict]:
    """
    Function to build a test case for a given explainer type. The function returns a model
    settings object, an explainer object, an inference request object and a dictionary of
    explain parameters.

    Parameters
    ----------
    explainer_type
        The type of explainer to build.
    init_kwargs
        Instantiation kwargs for the explainer.
    explain_kwargs
        Explain kwargs for the explainer.
    fit
        Data to fit the explainer on, `False` if no fit is required, or `'no-data'` to fit the
        explainer without data e.g. for `TreeShap` with path-dependent algorithm.
    save_dir
        Directory to save the explainer to, and then pass to `uri` in `ModelParameters`. If
        `None`, the explainer will not be saved to disk, and `init_parameters` will be
        specified in `ModelSettings` instead.
    payload
        The payload to send as request to the explainer.
    """
    # Build explainer
    explainer = train_explainer(
        explainer_type,
        save_dir,
        fit=fit,
        **init_kwargs
    )

    # Explainer model settings
    model_params = {}
    alibi_explain_settings = {
        'explainer_type': explainer_type,
        'infer_uri': str(get_sk_income_model_uri()),
    }
    if save_dir:
        model_params['uri'] = str(save_dir)
    else:
        init_params = init_kwargs.copy()
        # The runtime injects the predictor/model at load time, so strip it
        # from init_parameters. Fix: use a default so explainers whose init
        # kwargs don't include 'predictor' (e.g. white-box explainers that
        # take 'model') no longer raise KeyError.
        init_params.pop('predictor', None)  # TODO: Will need to add `model`, `predict_fn` here eventually
        alibi_explain_settings['init_parameters'] = init_params
    model_params['extra'] = AlibiExplainSettings(**alibi_explain_settings)

    model_settings = ModelSettings(
        name="foo",
        implementation=AlibiExplainRuntime,
        parameters=ModelParameters(**model_params),
    )

    # Inference request
    inference_request = build_request(payload, **explain_kwargs)
    return model_settings, explainer, inference_request, explain_kwargs
Loading