@@ -380,6 +380,7 @@ def __init__(self, exe, device, memory_cfg=None):
380380 self ._get_num_outputs = self .module ["get_num_outputs" ]
381381 self ._get_input_index = self .module ["get_input_index" ]
382382 self ._set_input = self .module ["set_input" ]
383+ self ._set_one_input = self .module ["set_one_input" ]
383384 self ._setup_device (device , memory_cfg )
384385
385386 def _setup_device (self , dev , memory_cfg ):
@@ -450,6 +451,24 @@ def set_input(self, func_name, *args, **kwargs):
450451 cargs = convert (args )
451452 self ._set_input (func_name , * cargs )
452453
454+ def set_one_input (self , func_name , ** kwargs ):
455+ """Set the one input tensor with tag to a function.
456+
457+ Parameters
458+ ----------
459+ func_name : str
460+ The name of the function.
461+
462+ kwargs: dict of str or int to tvm.runtime.NDArray or np.ndarray
463+ Named arguments to the function.
464+ """
465+ assert len (kwargs ) == 1
466+ tag = kwargs .keys ()[0 ]
467+ if isinstance (tag , str ):
468+ func_params = self ._exec .get_function_params (func_name )
469+ assert tag in func_params
470+ self ._set_one_input (func_name , tag , kwargs [tag ])
471+
453472 def invoke (self , func_name , * args , ** kwargs ):
454473 """Invoke a function.
455474
0 commit comments