Skip to content

Commit cfecf52

Browse files
author
Valery Chernov
committed
add getInputIndexFromName. lint fix
1 parent e9eb686 commit cfecf52

File tree

2 files changed

+16
-12
lines changed
  • include/tvm/runtime/vm
  • src/runtime/vm

2 files changed

+16
-12
lines changed

include/tvm/runtime/vm/vm.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,11 +296,13 @@ class VirtualMachine : public runtime::ModuleNode {
296296
virtual void OpStopHook();
297297

298298
private:
299+
int64_t getInputIndexFromName(const std::string& input_name,
300+
const std::string& func_name) const;
299301
const VMFunction& checkAndGetVMFunction(const std::string& func_name) const;
300302
void SetInputTensorWithIndex(std::vector<ObjectRef>& tensors,
301303
const TVMArgValue& tensor,
302304
int index,
303-
Device dev);
305+
Device dev); // NOLINT(*)
304306

305307
protected:
306308
/*! \brief The virtual machine's packed function table. */

src/runtime/vm/vm.cc

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -190,17 +190,7 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
190190
} else if (name == "get_input_index") {
191191
return TypedPackedFunc<int64_t(std::string, std::string)>(
192192
[this](std::string input_name, std::string func_name) {
193-
auto gvit = exec_->global_map.find(func_name);
194-
ICHECK(gvit != exec_->global_map.end()) << "Cannot find function " << func_name;
195-
auto func_index = gvit->second;
196-
const auto& vm_func = exec_->functions[func_index];
197-
const auto& param_names = vm_func.params;
198-
for (uint64_t i = 0; i < param_names.size(); i++) {
199-
if (input_name == param_names[i]) {
200-
return static_cast<int64_t>(i);
201-
}
202-
}
203-
return static_cast<int64_t>(-1);
193+
return getInputIndexFromName(input_name, func_name);
204194
});
205195
} else if (name == "init") {
206196
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
@@ -277,6 +267,18 @@ void VirtualMachine::SetInputWithIndex(std::string func_name, TVMArgs args) {
277267
SetInputTensorWithIndex(inputs_[func_name], args[2], inp_index, dev);
278268
}
279269

270+
int64_t VirtualMachine::getInputIndexFromName(const std::string& input_name,
271+
const std::string& func_name) const {
272+
const auto& vm_func = checkAndGetVMFunction(func_name);
273+
const auto& param_names = vm_func.params;
274+
for (uint64_t i = 0; i < param_names.size(); i++) {
275+
if (input_name == param_names[i]) {
276+
return static_cast<int64_t>(i);
277+
}
278+
}
279+
return static_cast<int64_t>(-1);
280+
}
281+
280282
const VMFunction& VirtualMachine::checkAndGetVMFunction(const std::string& func_name) const {
281283
ICHECK(exec_) << "The executable is not created yet.";
282284
auto gvit = exec_->global_map.find(func_name);

0 commit comments

Comments
 (0)