Skip to content

Commit

Permalink
Fix dev gluon test (mlflow#4574)
Browse files Browse the repository at this point in the history
* Fix dev gluon test

Signed-off-by: harupy <hkawamura0130@gmail.com>

* clean up imports

Signed-off-by: harupy <hkawamura0130@gmail.com>

* import mxnet

Signed-off-by: harupy <hkawamura0130@gmail.com>

* remove unsued import

Signed-off-by: harupy <hkawamura0130@gmail.com>
  • Loading branch information
harupy authored Jul 19, 2021
1 parent ee6ae1a commit 2b7a0ea
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 21 deletions.
15 changes: 9 additions & 6 deletions mlflow/gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,20 +55,23 @@ def load_model(model_uri, ctx):
model = mlflow.gluon.load_model("runs:/" + gluon_random_data_run.info.run_id + "/model")
model(nd.array(np.random.rand(1000, 1, 32)))
"""
import mxnet
import mxnet as mx
from mxnet import gluon
from mxnet import sym

local_model_path = _download_artifact_from_uri(artifact_uri=model_uri)

model_arch_path = os.path.join(local_model_path, "data", _MODEL_SAVE_PATH) + "-symbol.json"
model_params_path = os.path.join(local_model_path, "data", _MODEL_SAVE_PATH) + "-0000.params"
symbol = sym.load(model_arch_path)
inputs = sym.var("data", dtype="float32")
net = gluon.SymbolBlock(symbol, inputs)
if Version(mxnet.__version__) >= Version("2.0.0"):
net.load_parameters(model_params_path, ctx)

if Version(mx.__version__) >= Version("2.0.0"):
return gluon.SymbolBlock.imports(
model_arch_path, input_names=["data"], param_file=model_params_path, ctx=ctx
)
else:
symbol = sym.load(model_arch_path)
inputs = sym.var("data", dtype="float32")
net = gluon.SymbolBlock(symbol, inputs)
net.collect_params().load(model_params_path, ctx)
return net

Expand Down
26 changes: 14 additions & 12 deletions tests/gluon/test_gluon_model_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import yaml

import mxnet as mx
import mxnet.ndarray as nd
import numpy as np
import pandas as pd
import pytest
Expand All @@ -30,9 +29,13 @@

if Version(mx.__version__) >= Version("2.0.0"):
from mxnet.gluon.metric import Accuracy # pylint: disable=import-error

array_module = mx.np
else:
from mxnet.metric import Accuracy # pylint: disable=import-error

array_module = mx.nd


@pytest.fixture
def model_path(tmpdir):
Expand All @@ -49,18 +52,17 @@ def gluon_custom_env(tmpdir):
@pytest.fixture(scope="module")
def model_data():
mnist = mx.test_utils.get_mnist()
train_data = nd.array(mnist["train_data"].reshape(-1, 784))
train_label = nd.array(mnist["train_label"])
test_data = nd.array(mnist["test_data"].reshape(-1, 784))
train_data = array_module.array(mnist["train_data"].reshape(-1, 784))
train_label = array_module.array(mnist["train_label"])
test_data = array_module.array(mnist["test_data"].reshape(-1, 784))
return train_data, train_label, test_data


@pytest.fixture(scope="module")
def gluon_model(model_data):
train_data, train_label, _ = model_data
train_data_loader = DataLoader(
list(zip(train_data, train_label)), batch_size=128, last_batch="discard"
)
dataset = mx.gluon.data.ArrayDataset(train_data, train_label)
train_data_loader = DataLoader(dataset, batch_size=128, last_batch="discard")
model = HybridSequential()
model.add(Dense(128, activation="relu"))
model.add(Dense(64, activation="relu"))
Expand All @@ -85,12 +87,12 @@ def gluon_model(model_data):
@pytest.mark.large
def test_model_save_load(gluon_model, model_data, model_path):
_, _, test_data = model_data
expected = nd.argmax(gluon_model(test_data), axis=1)
expected = array_module.argmax(gluon_model(test_data), axis=1)

mlflow.gluon.save_model(gluon_model, model_path)
# Loading Gluon model
model_loaded = mlflow.gluon.load_model(model_path, ctx.cpu())
actual = nd.argmax(model_loaded(test_data), axis=1)
actual = array_module.argmax(model_loaded(test_data), axis=1)
assert all(expected == actual)
# Loading pyfunc model
pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
Expand Down Expand Up @@ -128,7 +130,7 @@ def test_signature_and_examples_are_saved_correctly(gluon_model, model_data):
def test_model_log_load(gluon_model, model_data, model_path):
# pylint: disable=unused-argument
_, _, test_data = model_data
expected = nd.argmax(gluon_model(test_data), axis=1)
expected = array_module.argmax(gluon_model(test_data), axis=1)

artifact_path = "model"
with mlflow.start_run():
Expand All @@ -139,7 +141,7 @@ def test_model_log_load(gluon_model, model_data, model_path):

# Loading Gluon model
model_loaded = mlflow.gluon.load_model(model_uri, ctx.cpu())
actual = nd.argmax(model_loaded(test_data), axis=1)
actual = array_module.argmax(model_loaded(test_data), axis=1)
assert all(expected == actual)
# Loading pyfunc model
pyfunc_loaded = mlflow.pyfunc.load_model(model_uri)
Expand Down Expand Up @@ -238,7 +240,7 @@ def test_model_log_persists_requirements_in_mlflow_model_directory(gluon_model,
@pytest.mark.large
def test_gluon_model_serving_and_scoring_as_pyfunc(gluon_model, model_data):
_, _, test_data = model_data
expected = nd.argmax(gluon_model(test_data), axis=1)
expected = array_module.argmax(gluon_model(test_data), axis=1)

artifact_path = "model"
with mlflow.start_run():
Expand Down
12 changes: 9 additions & 3 deletions tests/gluon_autolog/test_gluon_autolog.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import warnings

import mxnet as mx
import mxnet.ndarray as nd
import numpy as np
import pytest
from mxnet.gluon import Trainer
Expand All @@ -19,16 +18,23 @@

if Version(mx.__version__) >= Version("2.0.0"):
from mxnet.gluon.metric import Accuracy # pylint: disable=import-error

array_module = mx.np
else:
from mxnet.metric import Accuracy # pylint: disable=import-error

array_module = mx.nd


class LogsDataset(Dataset):
def __init__(self):
self.len = 1000

def __getitem__(self, idx):
return nd.array(np.random.rand(1, 32)), nd.full(1, random.randint(0, 10), dtype="float32")
return (
array_module.array(np.random.rand(1, 32)),
array_module.full(1, random.randint(0, 10), dtype="float32"),
)

def __len__(self):
return self.len
Expand Down Expand Up @@ -139,7 +145,7 @@ def test_gluon_autolog_model_can_load_from_artifact(gluon_random_data_run):
assert "model" in artifacts
ctx = mx.cpu()
model = mlflow.gluon.load_model("runs:/" + gluon_random_data_run.info.run_id + "/model", ctx)
model(nd.array(np.random.rand(1000, 1, 32)))
model(array_module.array(np.random.rand(1000, 1, 32)))


@pytest.mark.large
Expand Down

0 comments on commit 2b7a0ea

Please sign in to comment.