Skip to content

Commit bcd3847

Browse files
author
Valery Chernov
committed
add set_one_input method to python API of VirtualMachine
1 parent 5056a7d commit bcd3847

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

python/tvm/runtime/vm.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,7 @@ def __init__(self, exe, device, memory_cfg=None):
380380
self._get_num_outputs = self.module["get_num_outputs"]
381381
self._get_input_index = self.module["get_input_index"]
382382
self._set_input = self.module["set_input"]
383+
self._set_one_input = self.module["set_one_input"]
383384
self._setup_device(device, memory_cfg)
384385

385386
def _setup_device(self, dev, memory_cfg):
@@ -450,6 +451,24 @@ def set_input(self, func_name, *args, **kwargs):
450451
cargs = convert(args)
451452
self._set_input(func_name, *cargs)
452453

454+
def set_one_input(self, func_name, **kwargs):
455+
"""Set the one input tensor with tag to a function.
456+
457+
Parameters
458+
----------
459+
func_name : str
460+
The name of the function.
461+
462+
kwargs: dict of str or int to tvm.runtime.NDArray or np.ndarray
463+
Named arguments to the function.
464+
"""
465+
assert len(kwargs) == 1
466+
tag = kwargs.keys()[0]
467+
if isinstance(tag, str):
468+
func_params = self._exec.get_function_params(func_name)
469+
assert tag in func_params
470+
self._set_one_input(func_name, tag, kwargs[tag])
471+
453472
def invoke(self, func_name, *args, **kwargs):
454473
"""Invoke a function.
455474

0 commit comments

Comments
 (0)