@@ -211,12 +211,9 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
211211 } else if (name == " set_input" ) {
212212 return PackedFunc (
213213 [sptr_to_self, this ](TVMArgs args, TVMRetValue* rv) { SetInput (args[0 ], args, 1 ); });
214- } else if (name == " set_input_with_index " ) {
214+ } else if (name == " set_one_input " ) {
215215 return PackedFunc (
216- [sptr_to_self, this ](TVMArgs args, TVMRetValue* rv) { SetInputWithIndex (args[0 ], args); });
217- } else if (name == " set_input_with_name" ) {
218- return PackedFunc (
219- [sptr_to_self, this ](TVMArgs args, TVMRetValue* rv) { SetInputWithName (args[0 ], args); });
216+ [sptr_to_self, this ](TVMArgs args, TVMRetValue* rv) { SetOneInputTensor (args[0 ], args); });
220217 } else if (name == " load_late_bound_consts" ) {
221218 return PackedFunc ([this ](TVMArgs args, TVMRetValue* rv) {
222219 CHECK_EQ (args.size (), 1 );
@@ -244,25 +241,19 @@ void VirtualMachine::SetInput(std::string func_name, TVMArgs args, int offset) {
244241 inputs_.emplace (func_name, func_args);
245242}
246243
247- void VirtualMachine::SetInputWithIndex (std::string func_name, TVMArgs args) {
244+ void VirtualMachine::SetOneInputTensor (std::string func_name, TVMArgs args) {
245+ ICHECK_EQ (args.size (), 3 ) << " The expected number of arguments is 3 (func_name, index or name, tensor)" ;
248246 const auto & vm_func = checkAndGetVMFunction (func_name);
249247 size_t params_num = vm_func.params .size ();
250- ICHECK_EQ (args.size (), 3 ) << " The expected number of arguments is 3 (func_name, index, tensor)" ;
251- ICHECK_EQ (args[1 ].type_code (), kTVMArgInt ) << " The second argument type doesn't match integer" ;
252- int inp_index = args[1 ];
253- ICHECK_LT (inp_index, params_num);
254248
255- createInputsOrCheckSize (func_name, params_num);
256- Device dev = GetDevice (vm_func.param_device_indexes [inp_index]);
257- SetInputTensorWithIndex (inputs_[func_name], args[2 ], inp_index, dev);
258- }
259-
260- void VirtualMachine::SetInputWithName (std::string func_name, TVMArgs args) {
261- const auto & vm_func = checkAndGetVMFunction (func_name);
262- size_t params_num = vm_func.params .size ();
263- ICHECK_EQ (args.size (), 3 ) << " The expected number of arguments is 3 (func_name, name, tensor)" ;
264- ICHECK_EQ (args[1 ].type_code (), kTVMStr ) << " The second argument type doesn't match string" ;
265- int inp_index = int (getInputIndexFromName (vm_func.params , args[1 ]));
249+ int inp_index;
250+ if (args[1 ].type_code () == kTVMArgInt ) {
251+ inp_index = args[1 ];
252+ } else if (args[1 ].type_code () == kTVMStr ) {
253+ inp_index = int (getInputIndexFromName (vm_func.params , args[1 ]));
254+ } else {
255+ LOG (FATAL) << " The second argument type (" << args[1 ].type_code () << " ) doesn't match integer or string" ;
256+ }
266257 ICHECK_LT (inp_index, params_num);
267258
268259 createInputsOrCheckSize (func_name, params_num);
0 commit comments