Skip to content

Commit 18b68a2

Browse files
author
Valery Chernov
committed
join SetInputWithIndex and SetInputWithName to SetOneInputTensor (set_one_input) to VM API, the joined methods were removed
1 parent 602a192 commit 18b68a2

File tree

2 files changed

+15
-33
lines changed
  • include/tvm/runtime/vm
  • src/runtime/vm

2 files changed

+15
-33
lines changed

include/tvm/runtime/vm/vm.h

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -274,22 +274,13 @@ class VirtualMachine : public runtime::ModuleNode {
274274
void SetInput(std::string name, TVMArgs args, int offset);
275275

276276
/*!
277-
* \brief Set input tensor with index to a function.
277+
* \brief Set one input tensor with index or name to a function.
278278
* \param name The function name
279-
* \param args args[1:] are two arguments (index, tensor) to the
279+
* \param args args[1:] are two arguments (index or name, tensor) to the
280280
* function. If the tensor is not of the correct device for the function,
281281
* they will be copied to the device.
282282
*/
283-
void SetInputWithIndex(std::string name, TVMArgs args);
284-
285-
/*!
286-
* \brief Set input tensor with name to a function.
287-
* \param name The function name
288-
* \param args args[1:] are two arguments (name, tensor) to the
289-
* function. If the tensor is not of the correct device for the function,
290-
* they will be copied to the device.
291-
*/
292-
void SetInputWithName(std::string name, TVMArgs args);
283+
void SetOneInputTensor(std::string func_name, TVMArgs args);
293284

294285
/*!
295286
* \brief Internal hook for profiling the start of an op.

src/runtime/vm/vm.cc

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -211,12 +211,9 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
211211
} else if (name == "set_input") {
212212
return PackedFunc(
213213
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { SetInput(args[0], args, 1); });
214-
} else if (name == "set_input_with_index") {
214+
} else if (name == "set_one_input") {
215215
return PackedFunc(
216-
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { SetInputWithIndex(args[0], args); });
217-
} else if (name == "set_input_with_name") {
218-
return PackedFunc(
219-
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { SetInputWithName(args[0], args); });
216+
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { SetOneInputTensor(args[0], args); });
220217
} else if (name == "load_late_bound_consts") {
221218
return PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
222219
CHECK_EQ(args.size(), 1);
@@ -244,25 +241,19 @@ void VirtualMachine::SetInput(std::string func_name, TVMArgs args, int offset) {
244241
inputs_.emplace(func_name, func_args);
245242
}
246243

247-
void VirtualMachine::SetInputWithIndex(std::string func_name, TVMArgs args) {
244+
void VirtualMachine::SetOneInputTensor(std::string func_name, TVMArgs args) {
245+
ICHECK_EQ(args.size(), 3) << "The expected number of arguments is 3 (func_name, index or name, tensor)";
248246
const auto& vm_func = checkAndGetVMFunction(func_name);
249247
size_t params_num = vm_func.params.size();
250-
ICHECK_EQ(args.size(), 3) << "The expected number of arguments is 3 (func_name, index, tensor)";
251-
ICHECK_EQ(args[1].type_code(), kTVMArgInt) << "The second argument type doesn't match integer";
252-
int inp_index = args[1];
253-
ICHECK_LT(inp_index, params_num);
254248

255-
createInputsOrCheckSize(func_name, params_num);
256-
Device dev = GetDevice(vm_func.param_device_indexes[inp_index]);
257-
SetInputTensorWithIndex(inputs_[func_name], args[2], inp_index, dev);
258-
}
259-
260-
void VirtualMachine::SetInputWithName(std::string func_name, TVMArgs args) {
261-
const auto& vm_func = checkAndGetVMFunction(func_name);
262-
size_t params_num = vm_func.params.size();
263-
ICHECK_EQ(args.size(), 3) << "The expected number of arguments is 3 (func_name, name, tensor)";
264-
ICHECK_EQ(args[1].type_code(), kTVMStr) << "The second argument type doesn't match string";
265-
int inp_index = int(getInputIndexFromName(vm_func.params, args[1]));
249+
int inp_index;
250+
if (args[1].type_code() == kTVMArgInt) {
251+
inp_index = args[1];
252+
} else if (args[1].type_code() == kTVMStr) {
253+
inp_index = int(getInputIndexFromName(vm_func.params, args[1]));
254+
} else {
255+
LOG(FATAL) << "The second argument type (" << args[1].type_code() << ") doesn't match integer or string";
256+
}
266257
ICHECK_LT(inp_index, params_num);
267258

268259
createInputsOrCheckSize(func_name, params_num);

0 commit comments

Comments
 (0)