@@ -221,6 +221,9 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
221221 } else if (name == " set_input" ) {
222222 return PackedFunc (
223223 [sptr_to_self, this ](TVMArgs args, TVMRetValue* rv) { SetInput (args[0 ], args, 1 ); });
224+ } else if (name == " set_input_with_index" ) {
225+ return PackedFunc (
226+ [sptr_to_self, this ](TVMArgs args, TVMRetValue* rv) { SetInputWithIndex (args[0 ], args); });
224227 } else if (name == " load_late_bound_consts" ) {
225228 return PackedFunc ([this ](TVMArgs args, TVMRetValue* rv) {
226229 CHECK_EQ (args.size (), 1 );
@@ -267,6 +270,47 @@ void VirtualMachine::SetInput(std::string func_name, TVMArgs args, int offset) {
267270 inputs_.emplace (func_name, func_args);
268271}
269272
273+ void VirtualMachine::SetInputWithIndex (std::string func_name, TVMArgs args) {
274+ ICHECK (exec_) << " The executable is not created yet." ;
275+ auto gvit = exec_->global_map .find (func_name);
276+ ICHECK (gvit != exec_->global_map .end ()) << " Cannot find function " << func_name;
277+ auto func_index = gvit->second ;
278+ const auto & vm_func = exec_->functions [func_index];
279+ const auto & param_names = vm_func.params ;
280+ ICHECK_EQ (args.size (), 3 )
281+ << " The expected number of arguments is 3 (func_name, index, tensor)" ;
282+ // TODO(vvchernov): Looks like it should be checked earlier and in other place
283+ ICHECK_EQ (param_names.size (), vm_func.param_device_indexes .size ())
284+ << " The number of provided parameters doesn't match the number of assigned devices" ;
285+ if (inputs_.count (func_name)) {
286+ ICHECK_EQ (inputs_[func_name].size (), param_names.size ())
287+ << " The size of function" << func_name << " doesn't match the number of provided parameters" ;
288+ } else {
289+ std::vector<ObjectRef> func_args (param_names.size ());
290+ inputs_.emplace (func_name, func_args);
291+ }
292+ ICHECK_EQ (args[1 ].type_code (), kTVMArgInt )
293+ << " The second argument doesn't match integer index" ;
294+ int inp_index = args[1 ];
295+ auto & input_tensors = inputs_[func_name];
296+ Device dev = GetDevice (vm_func.param_device_indexes [inp_index]);
297+
298+ if (args[2 ].type_code () == kTVMDLTensorHandle ) {
299+ // Automatically convert input DLTensors to NDArray
300+ DLTensor* tensor = args[2 ];
301+ std::vector<int64_t > shape;
302+ for (int64_t i = 0 ; i < tensor->ndim ; i++) {
303+ shape.push_back (tensor->shape [i]);
304+ }
305+ NDArray ary = NDArray::Empty (shape, tensor->dtype , dev);
306+ ary.CopyFrom (tensor);
307+ input_tensors[inp_index] = ary;
308+ } else {
309+ ObjectRef obj = CopyTo (args[2 ], dev);
310+ input_tensors[inp_index] = obj;
311+ }
312+ }
313+
270314inline Device VirtualMachine::GetDevice (Index device_index) const {
271315 ICHECK_GE (devices_.size (), device_index) << " invalid device index: " << device_index;
272316 return devices_[device_index];
0 commit comments