Sherif akoush/quickfix/create artifacts on fly (#374)
* train models on the fly

* delete artifacts as we create them on the fly

* lint, format
sakoush authored Nov 9, 2021
1 parent f06f1b1 commit 42eac47
Showing 7 changed files with 79 additions and 8 deletions.
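
The common thread across these files: test artifacts (a TF MNIST model and an AnchorImage explainer) are no longer checked in as binaries; helper functions train and save them the first time a test asks for them. A minimal sketch of that lazy-artifact pattern (names here are illustrative, not from the commit):

# Illustrative sketch of the lazy-artifact pattern used in this commit;
# names such as get_artifact_path and _train_and_save are hypothetical.
from pathlib import Path

_ARTIFACT_PATH = Path(__file__).parent / "data" / "some_artifact.bin"


def _train_and_save(path: Path) -> None:
    # Stand-in for the real training code (e.g. fitting and saving a model).
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_bytes(b"trained-artifact")


def get_artifact_path() -> Path:
    # Train and serialise the artifact on first use instead of committing it.
    if not _ARTIFACT_PATH.exists():
        _train_and_save(_ARTIFACT_PATH)
    return _ARTIFACT_PATH
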
26 changes: 23 additions & 3 deletions runtimes/alibi-explain/tests/conftest.py
@@ -5,8 +5,11 @@
from typing import AsyncIterable
from unittest.mock import patch

import tensorflow as tf

import nest_asyncio
import pytest
from alibi.explainers import AnchorImage
from fastapi import FastAPI
from fastapi.testclient import TestClient

@@ -18,14 +21,15 @@
from mlserver.settings import ModelSettings, ModelParameters, Settings
from mlserver_alibi_explain.common import AlibiExplainSettings
from mlserver_alibi_explain.runtime import AlibiExplainRuntime
-from helpers.tf_model import TFMNISTModel
+from helpers.tf_model import TFMNISTModel, get_tf_mnist_model_uri

# allow nesting loop
# in our case this allows multiple runtimes to execute
# in the same thread for testing reasons
nest_asyncio.apply()

TESTS_PATH = Path(os.path.dirname(__file__))
_ANCHOR_IMAGE_DIR = TESTS_PATH / "data" / "mnist_anchor_image"


# TODO: how to make this in utils?
@@ -120,8 +124,16 @@ def rest_client(rest_app: FastAPI) -> TestClient:
    return TestClient(rest_app)


@pytest.fixture
def anchor_image_directory() -> Path:
    if not _ANCHOR_IMAGE_DIR.exists():
        _train_anchor_image_explainer()
    return _ANCHOR_IMAGE_DIR


@pytest.fixture
async def anchor_image_runtime_with_remote_predict_patch(
    anchor_image_directory,
    custom_runtime_tf: MLModel,
    remote_predict_mock_path: str = "mlserver_alibi_explain.common.remote_predict",
) -> AlibiExplainRuntime:
@@ -153,7 +165,7 @@ def mock_predict(*args, **kwargs):
        ModelSettings(
            parallel_workers=0,
            parameters=ModelParameters(
-                uri=str(TESTS_PATH / "data" / "mnist_anchor_image"),
+                uri=str(anchor_image_directory),
                extra=AlibiExplainSettings(
                    explainer_type="anchor_image", infer_uri="dummy_call"
                ),
@@ -174,11 +186,19 @@ async def integrated_gradients_runtime() -> AlibiExplainRuntime:
                extra=AlibiExplainSettings(
                    init_parameters={"n_steps": 50, "method": "gausslegendre"},
                    explainer_type="integrated_gradients",
-                    infer_uri=str(TESTS_PATH / "data" / "tf_mnist" / "model.h5"),
+                    infer_uri=str(get_tf_mnist_model_uri()),
                )
            ),
        )
    )
    await rt.load()

    return rt


def _train_anchor_image_explainer() -> None:
    model = tf.keras.models.load_model(get_tf_mnist_model_uri())
    anchor_image = AnchorImage(model.predict, (28, 28, 1))

    _ANCHOR_IMAGE_DIR.mkdir(parents=True)
    anchor_image.save(_ANCHOR_IMAGE_DIR)
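
For orientation, a hedged sketch of how a test could consume the new anchor_image_directory fixture; the test name and assertions are illustrative and not part of this commit. On first use the fixture calls _train_anchor_image_explainer(), so reruns reuse the saved directory.

# Illustrative test, not from the commit. Assumes the conftest.py fixture above;
# the alibi.saving import path for load_explainer is an assumption (the symbol
# itself is used in test_black_box.py below).
import tensorflow as tf
from alibi.saving import load_explainer

from helpers.tf_model import get_tf_mnist_model_uri


def test_anchor_image_artifacts_created_on_the_fly(anchor_image_directory):
    # The fixture guarantees the explainer directory exists, training it if needed.
    assert anchor_image_directory.exists()

    # The saved explainer can be re-attached to the lazily trained MNIST model.
    model = tf.keras.models.load_model(get_tf_mnist_model_uri())
    explainer = load_explainer(anchor_image_directory, model.predict)
    assert explainer is not None
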
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed runtimes/alibi-explain/tests/data/tf_mnist/model.h5
Binary file not shown.
55 changes: 54 additions & 1 deletion runtimes/alibi-explain/tests/helpers/tf_model.py
@@ -2,14 +2,23 @@
from pathlib import Path

import tensorflow as tf
from tensorflow.keras.layers import Activation, Conv2D, Dense, Dropout
from tensorflow.keras.layers import Flatten, Input, MaxPooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.utils import to_categorical

from mlserver import MLModel
from mlserver.codecs import NumpyCodec
from mlserver.types import InferenceRequest, InferenceResponse


_MODEL_PATH = Path(os.path.dirname(__file__)).parent / "data" / "tf_mnist" / "model.h5"


def get_tf_mnist_model_uri() -> Path:
-    return Path(os.path.dirname(__file__)).parent / "data" / "tf_mnist" / "model.h5"
+    if not _MODEL_PATH.exists():
+        _train_tf_mnist()
+    return _MODEL_PATH


class TFMNISTModel(MLModel):
@@ -27,3 +36,47 @@ async def load(self) -> bool:
        self._model = tf.keras.models.load_model(get_tf_mnist_model_uri())
        self.ready = True
        return self.ready


def _train_tf_mnist() -> None:
    train, test = tf.keras.datasets.mnist.load_data()
    X_train, y_train = train
    X_test, y_test = test

    X_train = X_train.reshape(-1, 28, 28, 1).astype("float64") / 255
    X_test = X_test.reshape(-1, 28, 28, 1).astype("float64") / 255
    y_train = to_categorical(y_train, 10)
    y_test = to_categorical(y_test, 10)

    inputs = Input(shape=(X_train.shape[1:]), dtype=tf.float64)
    x = Conv2D(64, 2, padding="same", activation="relu")(inputs)
    x = MaxPooling2D(pool_size=2)(x)
    x = Dropout(0.3)(x)

    x = Conv2D(32, 2, padding="same", activation="relu")(x)
    x = MaxPooling2D(pool_size=2)(x)
    x = Dropout(0.3)(x)

    x = Flatten()(x)
    x = Dense(256, activation="relu")(x)
    x = Dropout(0.5)(x)
    logits = Dense(10, name="logits")(x)
    outputs = Activation("softmax", name="softmax")(logits)

    model = Model(inputs=inputs, outputs=outputs)
    model.compile(
        loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"]
    )

    # train model
    model.fit(
        X_train,
        y_train,
        epochs=6,
        batch_size=256,
        verbose=1,
        validation_data=(X_test, y_test),
    )

    _MODEL_PATH.parent.mkdir(parents=True)
    model.save(_MODEL_PATH, save_format="h5")
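
A small usage sketch (illustrative, not part of the commit): callers don't need to know whether model.h5 was already on disk or has just been trained, since get_tf_mnist_model_uri() handles both cases.

# Illustrative usage; assumes the helpers above are importable as helpers.tf_model.
import numpy as np
import tensorflow as tf

from helpers.tf_model import get_tf_mnist_model_uri

# The first call trains and saves model.h5 (a few epochs of MNIST); later calls
# simply return the cached path.
model = tf.keras.models.load_model(get_tf_mnist_model_uri())

# The model expects (N, 28, 28, 1) inputs and returns (N, 10) softmax scores.
probs = model.predict(np.zeros((1, 28, 28, 1)))
print(probs.shape)  # (1, 10)
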
6 changes: 2 additions & 4 deletions runtimes/alibi-explain/tests/test_black_box.py
@@ -71,11 +71,9 @@ async def test_predict_impl(


@pytest.fixture()
-def alibi_anchor_image_model():
+def alibi_anchor_image_model(anchor_image_directory):
    inference_model = tf.keras.models.load_model(get_tf_mnist_model_uri())
-    model = load_explainer(
-        TESTS_PATH / "data" / "mnist_anchor_image", inference_model.__call__
-    )
+    model = load_explainer(anchor_image_directory, inference_model.predict)
    return model


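Two points worth noting about the reworked fixture: it now receives anchor_image_directory, so the explainer is trained on demand rather than read from a committed artifact, and it reloads the explainer with inference_model.predict, matching how the explainer was built in conftest.py (AnchorImage(model.predict, ...)). A hedged sketch of a test that could consume it; the test body is illustrative, not from the commit.

# Illustrative test, not from the commit; assumes the alibi_anchor_image_model
# fixture above and alibi's documented AnchorImage.explain(image) API.
import numpy as np


def test_alibi_anchor_image_model_explains(alibi_anchor_image_model):
    # One MNIST-shaped image; values are arbitrary since this only checks wiring.
    image = np.random.rand(28, 28, 1)

    explanation = alibi_anchor_image_model.explain(image)

    # Alibi explanations carry their metadata, including the explainer name.
    assert explanation.meta["name"] == "AnchorImage"
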
