Skip to content

Commit b630e05

Browse files
authored
Add extra params to MLModel (#783)
1 parent f7b35b8 commit b630e05

File tree

4 files changed

+36
-12
lines changed

4 files changed

+36
-12
lines changed

src/viam/services/mlmodel/client.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
from typing import Dict, Optional
1+
from typing import Dict, Mapping, Optional
22

33
from grpclib.client import Channel
44
from numpy.typing import NDArray
55

66
from viam.proto.service.mlmodel import InferRequest, InferResponse, MetadataRequest, MetadataResponse, MLModelServiceStub
77
from viam.resource.rpc_client_base import ReconfigurableResourceRPCClientBase
88
from viam.services.mlmodel.utils import flat_tensors_to_ndarrays, ndarrays_to_flat_tensors
9+
from viam.utils import ValueTypes, dict_to_struct
910

1011
from .mlmodel import Metadata, MLModel
1112

@@ -16,14 +17,21 @@ def __init__(self, name: str, channel: Channel):
1617
self.client = MLModelServiceStub(channel)
1718
super().__init__(name)
1819

19-
async def infer(self, input_tensors: Dict[str, NDArray], *, timeout: Optional[float] = None, **kwargs) -> Dict[str, NDArray]:
20+
async def infer(
21+
self,
22+
input_tensors: Dict[str, NDArray],
23+
*,
24+
extra: Optional[Mapping[str, ValueTypes]] = None,
25+
timeout: Optional[float] = None,
26+
**kwargs,
27+
) -> Dict[str, NDArray]:
2028
md = kwargs.get("metadata", self.Metadata()).proto
21-
request = InferRequest(name=self.name, input_tensors=ndarrays_to_flat_tensors(input_tensors))
29+
request = InferRequest(name=self.name, input_tensors=ndarrays_to_flat_tensors(input_tensors), extra=dict_to_struct(extra))
2230
response: InferResponse = await self.client.Infer(request, timeout=timeout, metadata=md)
2331
return flat_tensors_to_ndarrays(response.output_tensors)
2432

25-
async def metadata(self, *, timeout: Optional[float] = None, **kwargs) -> Metadata:
33+
async def metadata(self, *, extra: Optional[Mapping[str, ValueTypes]] = None, timeout: Optional[float] = None, **kwargs) -> Metadata:
2634
md = kwargs.get("metadata", self.Metadata()).proto
27-
request = MetadataRequest(name=self.name)
35+
request = MetadataRequest(name=self.name, extra=dict_to_struct(extra))
2836
response: MetadataResponse = await self.client.Metadata(request, timeout=timeout, metadata=md)
2937
return response.metadata

src/viam/services/mlmodel/mlmodel.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import abc
2-
from typing import Dict, Final, Optional
2+
from typing import Dict, Final, Optional, Mapping
33

44
from numpy.typing import NDArray
55

66
from viam.proto.service.mlmodel import Metadata
77
from viam.resource.types import RESOURCE_NAMESPACE_RDK, RESOURCE_TYPE_SERVICE, Subtype
8+
from viam.utils import ValueTypes
89

910
from ..service_base import ServiceBase
1011

@@ -25,7 +26,13 @@ class MLModel(ServiceBase):
2526
)
2627

2728
@abc.abstractmethod
28-
async def infer(self, input_tensors: Dict[str, NDArray], *, timeout: Optional[float]) -> Dict[str, NDArray]:
29+
async def infer(
30+
self,
31+
input_tensors: Dict[str, NDArray],
32+
*,
33+
extra: Optional[Mapping[str, ValueTypes]] = None,
34+
timeout: Optional[float] = None,
35+
) -> Dict[str, NDArray]:
2936
"""Take an already ordered input tensor as an array, make an inference on the model, and return an output tensor map.
3037
3138
::
@@ -50,7 +57,7 @@ async def infer(self, input_tensors: Dict[str, NDArray], *, timeout: Optional[fl
5057
...
5158

5259
@abc.abstractmethod
53-
async def metadata(self, *, timeout: Optional[float]) -> Metadata:
60+
async def metadata(self, *, extra: Optional[Mapping[str, ValueTypes]] = None, timeout: Optional[float] = None) -> Metadata:
5461
"""Get the metadata (such as name, type, expected tensor/array shape, inputs, and outputs) associated with the ML model.
5562
5663
::

src/viam/services/mlmodel/service.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from viam.proto.service.mlmodel import InferRequest, InferResponse, MetadataRequest, MetadataResponse, MLModelServiceBase
44
from viam.resource.rpc_service_base import ResourceRPCServiceBase
55
from viam.services.mlmodel.utils import flat_tensors_to_ndarrays, ndarrays_to_flat_tensors
6+
from viam.utils import struct_to_dict
67

78
from .mlmodel import MLModel
89

@@ -19,8 +20,9 @@ async def Infer(self, stream: Stream[InferRequest, InferResponse]) -> None:
1920
assert request is not None
2021
name = request.name
2122
mlmodel = self.get_resource(name)
23+
extra = struct_to_dict(request.extra)
2224
timeout = stream.deadline.time_remaining() if stream.deadline else None
23-
output_tensors = await mlmodel.infer(input_tensors=flat_tensors_to_ndarrays(request.input_tensors), timeout=timeout)
25+
output_tensors = await mlmodel.infer(input_tensors=flat_tensors_to_ndarrays(request.input_tensors), extra=extra, timeout=timeout)
2426
response = InferResponse(output_tensors=ndarrays_to_flat_tensors(output_tensors))
2527
await stream.send_message(response)
2628

@@ -29,7 +31,8 @@ async def Metadata(self, stream: Stream[MetadataRequest, MetadataResponse]) -> N
2931
assert request is not None
3032
name = request.name
3133
mlmodel = self.get_resource(name)
34+
extra = struct_to_dict(request.extra)
3235
timeout = stream.deadline.time_remaining() if stream.deadline else None
33-
metadata = await mlmodel.metadata(timeout=timeout)
36+
metadata = await mlmodel.metadata(extra=extra, timeout=timeout)
3437
response = MetadataResponse(metadata=metadata)
3538
await stream.send_message(response)

tests/mocks/services.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -515,13 +515,19 @@ def __init__(self, name: str):
515515

516516
super().__init__(name)
517517

518-
async def infer(self, input_tensors: Dict[str, NDArray], *, timeout: Optional[float] = None) -> Dict[str, NDArray]:
518+
async def infer(
519+
self,
520+
input_tensors: Dict[str, NDArray],
521+
*,
522+
extra: Optional[Mapping[str, ValueTypes]] = None,
523+
timeout: Optional[float] = None,
524+
) -> Dict[str, NDArray]:
519525
self.timeout = timeout
520526
request_data = ndarrays_to_flat_tensors(input_tensors)
521527
response_data = flat_tensors_to_ndarrays(request_data)
522528
return response_data
523529

524-
async def metadata(self, *, timeout: Optional[float] = None) -> Metadata:
530+
async def metadata(self, *, extra: Optional[Mapping[str, ValueTypes]] = None, timeout: Optional[float] = None) -> Metadata:
525531
self.timeout = timeout
526532
return self.META
527533

0 commit comments

Comments
 (0)