Commit c57b67e
Merge pull request #1213 from RafalSkolasinski/batch-bytes
fix mlserver infer with BYTES
RafalSkolasinski authored and Adrian Gonzalez-Martin committed Jun 16, 2023
1 parent f9f51ed commit c57b67e
Showing 5 changed files with 97 additions and 6 deletions.
12 changes: 8 additions & 4 deletions mlserver/batch_processing.py
@@ -5,6 +5,7 @@
import tritonclient.http.aio as httpclient

import asyncio
import numpy as np
import aiofiles
import logging
import click
@@ -101,17 +102,21 @@ def from_inference_request(
cls, inference_request: InferenceRequest, binary_data: bool
) -> "TritonRequest":
inputs = []

for request_input in inference_request.inputs or []:
new_input = httpclient.InferInput(
request_input.name, request_input.shape, request_input.datatype
)
request_input_np = NumpyCodec.decode_input(request_input)

# Change datatype if BYTES to satisfy Tritonclient checks
if request_input.datatype == "BYTES":
request_input_np = request_input_np.astype(np.object_)

new_input.set_data_from_numpy(
NumpyCodec.decode_input(request_input),
request_input_np,
binary_data=binary_data,
)
inputs.append(new_input)

outputs = []
for request_output in inference_request.outputs or []:
new_output = httpclient.InferRequestedOutput(
@@ -208,7 +213,6 @@ def preprocess_items(
)
invalid_inputs.append(_serialize_validation_error(item.index, e))
batched = BatchedRequests(inference_requests)

# Set `id` for batched requests - if only single request use its own id
if len(inference_requests) == 1:
batched.merged_request.id = inference_request.id
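Not part of the diff, but useful context for the change above: tritonclient validates the numpy dtype of whatever is passed to InferInput.set_data_from_numpy() against the declared Triton datatype, and for "BYTES" it expects an np.object_ array. A minimal sketch of the cast, assuming the decoded payload comes back with a numpy string dtype (the tensor name and shape mirror the bytes.txt test data below):

import numpy as np
import tritonclient.http.aio as httpclient

# A decoded BYTES payload may arrive as a numpy string array (e.g. dtype
# '<U1'), which tritonclient's checks reject for a "BYTES" input.
arr = np.array(["a", "b", "c"]).reshape(1, 3)

infer_input = httpclient.InferInput("input-0", [1, 3], "BYTES")

# Casting to np.object_ satisfies tritonclient's dtype check; without it,
# set_data_from_numpy() raises instead of sending the request.
infer_input.set_data_from_numpy(arr.astype(np.object_), binary_data=False)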
14 changes: 14 additions & 0 deletions tests/batch_processing/conftest.py
@@ -1,14 +1,22 @@
import pytest
import os

from mlserver import MLModel, MLServer, ModelSettings

from ..conftest import TESTDATA_PATH
from ..fixtures import EchoModel


@pytest.fixture()
def single_input():
return os.path.join(TESTDATA_PATH, "batch_processing", "single.txt")


@pytest.fixture()
def bytes_input():
return os.path.join(TESTDATA_PATH, "batch_processing", "bytes.txt")


@pytest.fixture()
def invalid_input():
return os.path.join(TESTDATA_PATH, "batch_processing", "invalid.txt")
@@ -27,3 +35,9 @@ def many_input():
@pytest.fixture()
def single_input_with_id():
return os.path.join(TESTDATA_PATH, "batch_processing", "single_with_id.txt")


@pytest.fixture
async def echo_model(mlserver: MLServer) -> MLModel:
model_settings = ModelSettings(name="echo-model", implementation=EchoModel)
return await mlserver._model_registry.load(model_settings)
46 changes: 45 additions & 1 deletion tests/batch_processing/test_rest.py
@@ -5,7 +5,7 @@

from mlserver.batch_processing import process_batch
from mlserver.settings import Settings

from mlserver import MLModel

from ..utils import RESTClient

@@ -53,6 +53,50 @@ async def test_single(
raise RuntimeError(f"Response id is not a valid UUID; got {response['id']}")


async def test_bytes(
tmp_path: str,
echo_model: MLModel,
rest_client: RESTClient,
settings: Settings,
bytes_input: str,
):
await rest_client.wait_until_ready()
model_name = "echo-model"
url = f"{settings.host}:{settings.http_port}"
output_file = os.path.join(tmp_path, "output.txt")

await process_batch(
model_name=model_name,
url=url,
workers=1,
retries=1,
input_data_path=bytes_input,
output_data_path=output_file,
binary_data=False,
batch_size=1,
transport="rest",
request_headers={},
batch_interval=0,
batch_jitter=0,
timeout=60,
use_ssl=False,
insecure=False,
verbose=True,
extra_verbose=True,
)

with open(output_file) as f:
response = json.load(f)

assert response["outputs"][0]["data"] == ["a", "b", "c"]
assert response["id"] is not None and response["id"] != ""
assert response["parameters"]["batch_index"] == 0
try:
_ = UUID(response["id"])
except ValueError:
raise RuntimeError(f"Response id is not a valid UUID; got {response['id']}")


async def test_invalid(
tmp_path: str,
rest_client: RESTClient,
30 changes: 29 additions & 1 deletion tests/fixtures.py
@@ -15,7 +15,12 @@
from typing import Dict, List

from mlserver import MLModel
from mlserver.types import InferenceRequest, InferenceResponse, Parameters
from mlserver.types import (
InferenceRequest,
InferenceResponse,
ResponseOutput,
Parameters,
)
from mlserver.codecs import NumpyCodec, decode_args, StringCodec
from mlserver.handlers.custom import custom_handler
from mlserver.errors import MLServerError
@@ -100,3 +105,26 @@ async def predict(self, inference_request: InferenceRequest) -> InferenceResponse:
StringCodec.encode_output("sklearn_version", [self._sklearn_version]),
],
)


class EchoModel(MLModel):
async def load(self) -> bool:
print("Echo Model Initialized")
return await super().load()

async def predict(self, payload: InferenceRequest) -> InferenceResponse:
return InferenceResponse(
id=payload.id,
model_name=self.name,
model_version=self.version,
outputs=[
ResponseOutput(
name=input.name,
shape=input.shape,
datatype=input.datatype,
data=input.data,
parameters=input.parameters,
)
for input in payload.inputs
],
)
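EchoModel mirrors every request input back as a ResponseOutput, which is what lets test_bytes assert that BYTES data survives the round trip unchanged. A minimal sketch of driving it directly, outside the batch pipeline (not part of this commit; assumes it is run from the repository root with mlserver installed):

import asyncio

from mlserver import ModelSettings
from mlserver.types import InferenceRequest, RequestInput
from tests.fixtures import EchoModel


async def main():
    # Same settings as the echo_model fixture in tests/batch_processing/conftest.py.
    settings = ModelSettings(name="echo-model", implementation=EchoModel)
    model = EchoModel(settings)
    await model.load()

    request = InferenceRequest(
        inputs=[
            RequestInput(
                name="input-0", shape=[1, 3], datatype="BYTES", data=["a", "b", "c"]
            )
        ]
    )
    response = await model.predict(request)
    print(response.outputs[0].data)  # expected to echo back ["a", "b", "c"]


asyncio.run(main())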
1 change: 1 addition & 0 deletions tests/testdata/batch_processing/bytes.txt
@@ -0,0 +1 @@
{"inputs":[{"name":"input-0","shape":[1,3],"datatype":"BYTES","data":["a","b","c"]}]}
