fix mlserver infer with BYTES #1213

Merged 3 commits on Jun 16, 2023
Changes from 2 commits
12 changes: 8 additions & 4 deletions mlserver/batch_processing.py
@@ -5,6 +5,7 @@
import tritonclient.http.aio as httpclient

import asyncio
import numpy as np
import aiofiles
import logging
import click
@@ -101,17 +102,21 @@ def from_inference_request(
cls, inference_request: InferenceRequest, binary_data: bool
) -> "TritonRequest":
inputs = []

for request_input in inference_request.inputs or []:
new_input = httpclient.InferInput(
request_input.name, request_input.shape, request_input.datatype
)
request_input_np = NumpyCodec.decode_input(request_input)

# Change datatype if BYTES to satisfy Tritonclient checks
Contributor Author:

I have the feeling this should be something handled in the NumpyCodec, but it may be too tritonclient-specific...

Contributor:

Good point - do you know what's returned by NumpyCodec.decode_input at the moment?

Contributor Author (@RafalSkolasinski, Jun 13, 2023):

Tritonclient expects

    elif np_dtype == np.object_ or np_dtype.type == np.bytes_:
        return "BYTES"
    return None

Otherwise it returns None, which was causing the issue.

As we do not set anything there explicitly (I also saw this by adding extra logs earlier during debugging), we just get '<U1':

In [2]: import numpy as np

In [3]: x = np.array(["x"])

In [4]: x.dtype == np.bytes_
Out[4]: False

In [5]: x.dtype
Out[5]: dtype('<U1')
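
To make the failure concrete, a minimal sketch of the check described above, assuming tritonclient exposes np_to_triton_dtype under tritonclient.utils:

    import numpy as np
    from tritonclient.utils import np_to_triton_dtype  # assumed import path

    x = np.array(["x"])
    print(x.dtype)                      # dtype('<U1') -- unicode, not np.object_ or np.bytes_
    print(np_to_triton_dtype(x.dtype))  # None, which is what was causing the failure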

Contributor:

I think what's going on is that the NumpyCodec encodes that list as a tensor of strings (i.e. using np.dtype(str)), but Triton doesn't know how to handle that case:

if is_list_of(data, str):
# Handle special case of strings being treated as Numpy arrays
return np.dtype(str)

I can't remember why we had to add that np.dtype(str) special case, but I would be keen to keep it as is to avoid any potential side effects. Having said that, it's unclear whether changing the dtype here may cause any issues downstream (although I'd expect that strings in np arrays are quite an edge case).
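
As a minimal sketch of how the fix in this diff plays out for a BYTES input (mirroring the payload added in tests/testdata/batch_processing/bytes.txt; the printed dtypes follow the discussion above rather than verified output):

    import numpy as np
    from mlserver.types import RequestInput
    from mlserver.codecs import NumpyCodec

    # Same payload as the new bytes.txt test data
    request_input = RequestInput(
        name="input-0", shape=[1, 3], datatype="BYTES", data=["a", "b", "c"]
    )

    decoded = NumpyCodec.decode_input(request_input)
    print(decoded.dtype)  # a unicode dtype such as <U1, which tritonclient cannot map

    # The cast added in this PR before handing the array to tritonclient
    if request_input.datatype == "BYTES":
        decoded = decoded.astype(np.object_)
    print(decoded.dtype)  # object, which tritonclient maps to BYTES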

if request_input.datatype == "BYTES":
request_input_np = request_input_np.astype(np.object_)

new_input.set_data_from_numpy(
NumpyCodec.decode_input(request_input),
request_input_np,
binary_data=binary_data,
)
inputs.append(new_input)

outputs = []
for request_output in inference_request.outputs or []:
new_output = httpclient.InferRequestedOutput(
@@ -208,7 +213,6 @@ def preprocess_items(
)
invalid_inputs.append(_serialize_validation_error(item.index, e))
batched = BatchedRequests(inference_requests)

# Set `id` for batched requests - if only single request use its own id
if len(inference_requests) == 1:
batched.merged_request.id = inference_request.id
15 changes: 15 additions & 0 deletions tests/batch_processing/conftest.py
@@ -1,14 +1,22 @@
import pytest
import os

from mlserver import MLModel, MLServer, ModelSettings

from ..conftest import TESTDATA_PATH
from ..fixtures import EchoModel


@pytest.fixture()
def single_input():
return os.path.join(TESTDATA_PATH, "batch_processing", "single.txt")


@pytest.fixture()
def bytes_input():
return os.path.join(TESTDATA_PATH, "batch_processing", "bytes.txt")


@pytest.fixture()
def invalid_input():
return os.path.join(TESTDATA_PATH, "batch_processing", "invalid.txt")
@@ -27,3 +35,10 @@ def many_input()
@pytest.fixture()
def single_input_with_id():
return os.path.join(TESTDATA_PATH, "batch_processing", "single_with_id.txt")

@pytest.fixture
async def echo_model(mlserver: MLServer) -> MLModel:
model_settings = ModelSettings(
name="echo-model", implementation=EchoModel
)
return await mlserver._model_registry.load(model_settings)
46 changes: 45 additions & 1 deletion tests/batch_processing/test_rest.py
@@ -5,7 +5,7 @@

from mlserver.batch_processing import process_batch
from mlserver.settings import Settings

from mlserver import MLModel

from ..utils import RESTClient

@@ -53,6 +53,50 @@ async def test_single(
raise RuntimeError(f"Response id is not a valid UUID; got {response['id']}")


async def test_bytes(
tmp_path: str,
echo_model: MLModel,
rest_client: RESTClient,
settings: Settings,
bytes_input: str,
):
await rest_client.wait_until_ready()
model_name = "echo-model"
url = f"{settings.host}:{settings.http_port}"
output_file = os.path.join(tmp_path, "output.txt")

await process_batch(
model_name=model_name,
url=url,
workers=1,
retries=1,
input_data_path=bytes_input,
output_data_path=output_file,
binary_data=False,
batch_size=1,
transport="rest",
request_headers={},
batch_interval=0,
batch_jitter=0,
timeout=60,
use_ssl=False,
insecure=False,
verbose=True,
extra_verbose=True,
)

with open(output_file) as f:
response = json.load(f)

assert response["outputs"][0]["data"] == ["a", "b", "c"]
assert response["id"] is not None and response["id"] != ""
assert response["parameters"]["batch_index"] == 0
try:
_ = UUID(response["id"])
except ValueError:
raise RuntimeError(f"Response id is not a valid UUID; got {response['id']}")


async def test_invalid(
tmp_path: str,
rest_client: RESTClient,
30 changes: 29 additions & 1 deletion tests/fixtures.py
@@ -15,7 +15,12 @@
from typing import Dict, List

from mlserver import MLModel
from mlserver.types import InferenceRequest, InferenceResponse, Parameters
from mlserver.types import (
InferenceRequest,
InferenceResponse,
ResponseOutput,
Parameters,
)
from mlserver.codecs import NumpyCodec, decode_args, StringCodec
from mlserver.handlers.custom import custom_handler
from mlserver.errors import MLServerError
@@ -100,3 +105,26 @@ async def predict(self, inference_request: InferenceRequest) -> InferenceResponse:
StringCodec.encode_output("sklearn_version", [self._sklearn_version]),
],
)


class EchoModel(MLModel):
async def load(self) -> bool:
print("Echo Model Initialized")
return await super().load()

async def predict(self, payload: InferenceRequest) -> InferenceResponse:
return InferenceResponse(
id=payload.id,
model_name=self.name,
model_version=self.version,
outputs=[
ResponseOutput(
name=input.name,
shape=input.shape,
datatype=input.datatype,
data=input.data,
parameters=input.parameters,
)
for input in payload.inputs
],
)
1 change: 1 addition & 0 deletions tests/testdata/batch_processing/bytes.txt
@@ -0,0 +1 @@
{"inputs":[{"name":"input-0","shape":[1,3],"datatype":"BYTES","data":["a","b","c"]}]}