Skip to content

Commit ba6533d

Browse files
author
Valery Chernov
committed
add set_one_input method to python API of VirtualMachine
1 parent 5056a7d commit ba6533d

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

python/tvm/runtime/vm.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,7 @@ def __init__(self, exe, device, memory_cfg=None):
380380
self._get_num_outputs = self.module["get_num_outputs"]
381381
self._get_input_index = self.module["get_input_index"]
382382
self._set_input = self.module["set_input"]
383+
self._set_one_input = self.module["set_one_input"]
383384
self._setup_device(device, memory_cfg)
384385

385386
def _setup_device(self, dev, memory_cfg):
@@ -450,6 +451,24 @@ def set_input(self, func_name, *args, **kwargs):
450451
cargs = convert(args)
451452
self._set_input(func_name, *cargs)
452453

454+
def set_one_input(self, func_name, **kwargs):
455+
"""Set the one input tensor with tag to a function.
456+
457+
Parameters
458+
----------
459+
func_name : str
460+
The name of the function.
461+
462+
kwargs: dict of str or int to tvm.runtime.NDArray or np.ndarray
463+
Named arguments to the function.
464+
"""
465+
assert len(kwargs) == 1
466+
tag = kwargs.keys()[0]
467+
if isinstance(tag, str):
468+
func_params = self._exec.get_function_params(func_name)
469+
assert tag in func_params
470+
self._set_one_input(func_name, tag, kwargs[tag])
471+
453472
def invoke(self, func_name, *args, **kwargs):
454473
"""Invoke a function.
455474

src/runtime/vm/vm.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,8 @@ void VirtualMachine::SetInput(std::string func_name, TVMArgs args, int offset) {
244244
inputs_.emplace(func_name, func_args);
245245
}
246246

247-
void VirtualMachine::SetOneInput(std::string func_name, const TVMArgValue& tag, const TVMArgValue& tensor) {
247+
void VirtualMachine::SetOneInput(std::string func_name, const TVMArgValue& tag,
248+
const TVMArgValue& tensor) {
248249
const auto& vm_func = CheckAndGetVMFunction(func_name);
249250
size_t params_num = vm_func.params.size();
250251

0 commit comments

Comments
 (0)