@@ -221,6 +221,9 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
221221 } else if (name == " set_input" ) {
222222 return PackedFunc (
223223 [sptr_to_self, this ](TVMArgs args, TVMRetValue* rv) { SetInput (args[0 ], args, 1 ); });
224+ } else if (name == " set_input_with_index" ) {
225+ return PackedFunc (
226+ [sptr_to_self, this ](TVMArgs args, TVMRetValue* rv) { SetInputWithIndex (args[0 ], args); });
224227 } else if (name == " load_late_bound_consts" ) {
225228 return PackedFunc ([this ](TVMArgs args, TVMRetValue* rv) {
226229 CHECK_EQ (args.size (), 1 );
@@ -267,6 +270,46 @@ void VirtualMachine::SetInput(std::string func_name, TVMArgs args, int offset) {
267270 inputs_.emplace (func_name, func_args);
268271}
269272
273+ void VirtualMachine::SetInputWithIndex (std::string func_name, TVMArgs args) {
274+ ICHECK (exec_) << " The executable is not created yet." ;
275+ auto gvit = exec_->global_map .find (func_name);
276+ ICHECK (gvit != exec_->global_map .end ()) << " Cannot find function " << func_name;
277+ auto func_index = gvit->second ;
278+ const auto & vm_func = exec_->functions [func_index];
279+ const auto & param_names = vm_func.params ;
280+ ICHECK_EQ (args.size (), 3 ) << " The expected number of arguments is 3 (func_name, index, tensor)" ;
281+ // TODO(vvchernov): Looks like it should be checked earlier and in other place
282+ ICHECK_EQ (param_names.size (), vm_func.param_device_indexes .size ())
283+ << " The number of provided parameters doesn't match the number of assigned devices" ;
284+ if (inputs_.count (func_name)) {
285+ ICHECK_EQ (inputs_[func_name].size (), param_names.size ())
286+ << " The size of function" << func_name
287+ << " doesn't match the number of provided parameters" ;
288+ } else {
289+ std::vector<ObjectRef> func_args (param_names.size ());
290+ inputs_.emplace (func_name, func_args);
291+ }
292+ ICHECK_EQ (args[1 ].type_code (), kTVMArgInt ) << " The second argument doesn't match integer index" ;
293+ int inp_index = args[1 ];
294+ auto & input_tensors = inputs_[func_name];
295+ Device dev = GetDevice (vm_func.param_device_indexes [inp_index]);
296+
297+ if (args[2 ].type_code () == kTVMDLTensorHandle ) {
298+ // Automatically convert input DLTensors to NDArray
299+ DLTensor* tensor = args[2 ];
300+ std::vector<int64_t > shape;
301+ for (int64_t i = 0 ; i < tensor->ndim ; i++) {
302+ shape.push_back (tensor->shape [i]);
303+ }
304+ NDArray ary = NDArray::Empty (shape, tensor->dtype , dev);
305+ ary.CopyFrom (tensor);
306+ input_tensors[inp_index] = ary;
307+ } else {
308+ ObjectRef obj = CopyTo (args[2 ], dev);
309+ input_tensors[inp_index] = obj;
310+ }
311+ }
312+
270313inline Device VirtualMachine::GetDevice (Index device_index) const {
271314 ICHECK_GE (devices_.size (), device_index) << " invalid device index: " << device_index;
272315 return devices_[device_index];
0 commit comments