Skip to content

Commit 1b705f8

Browse files
author
Valery Chernov
committed
set_input_with_index was implemented for VM
1 parent f583a70 commit 1b705f8

File tree

2 files changed

+52
-0
lines changed
  • include/tvm/runtime/vm
  • src/runtime/vm

2 files changed

+52
-0
lines changed

include/tvm/runtime/vm/vm.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,15 @@ class VirtualMachine : public runtime::ModuleNode {
270270
*/
271271
void SetInput(std::string name, TVMArgs args, int offset);
272272

273+
/*!
274+
* \brief Set input tensor with index to a function.
275+
* \param name The function name
276+
* \param args args[1:] are two arguments (index, tensor) to the
277+
* function. If the tensor is not of the correct device for the function,
278+
* they will be copied to the device.
279+
*/
280+
void SetInputWithIndex(std::string name, TVMArgs args);
281+
273282
/*!
274283
* \brief Internal hook for profiling the start of an op.
275284
*

src/runtime/vm/vm.cc

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,9 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
221221
} else if (name == "set_input") {
222222
return PackedFunc(
223223
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { SetInput(args[0], args, 1); });
224+
} else if (name == "set_input_with_index") {
225+
return PackedFunc(
226+
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { SetInputWithIndex(args[0], args); });
224227
} else if (name == "load_late_bound_consts") {
225228
return PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
226229
CHECK_EQ(args.size(), 1);
@@ -267,6 +270,46 @@ void VirtualMachine::SetInput(std::string func_name, TVMArgs args, int offset) {
267270
inputs_.emplace(func_name, func_args);
268271
}
269272

273+
void VirtualMachine::SetInputWithIndex(std::string func_name, TVMArgs args) {
274+
ICHECK(exec_) << "The executable is not created yet.";
275+
auto gvit = exec_->global_map.find(func_name);
276+
ICHECK(gvit != exec_->global_map.end()) << "Cannot find function " << func_name;
277+
auto func_index = gvit->second;
278+
const auto& vm_func = exec_->functions[func_index];
279+
const auto& param_names = vm_func.params;
280+
ICHECK_EQ(args.size(), 3) << "The expected number of arguments is 3 (func_name, index, tensor)";
281+
// TODO(vvchernov): Looks like it should be checked earlier and in other place
282+
ICHECK_EQ(param_names.size(), vm_func.param_device_indexes.size())
283+
<< "The number of provided parameters doesn't match the number of assigned devices";
284+
if (inputs_.count(func_name)) {
285+
ICHECK_EQ(inputs_[func_name].size(), param_names.size())
286+
<< "The size of function" << func_name
287+
<< " doesn't match the number of provided parameters";
288+
} else {
289+
std::vector<ObjectRef> func_args(param_names.size());
290+
inputs_.emplace(func_name, func_args);
291+
}
292+
ICHECK_EQ(args[1].type_code(), kTVMArgInt) << "The second argument doesn't match integer index";
293+
int inp_index = args[1];
294+
auto& input_tensors = inputs_[func_name];
295+
Device dev = GetDevice(vm_func.param_device_indexes[inp_index]);
296+
297+
if (args[2].type_code() == kTVMDLTensorHandle) {
298+
// Automatically convert input DLTensors to NDArray
299+
DLTensor* tensor = args[2];
300+
std::vector<int64_t> shape;
301+
for (int64_t i = 0; i < tensor->ndim; i++) {
302+
shape.push_back(tensor->shape[i]);
303+
}
304+
NDArray ary = NDArray::Empty(shape, tensor->dtype, dev);
305+
ary.CopyFrom(tensor);
306+
input_tensors[inp_index] = ary;
307+
} else {
308+
ObjectRef obj = CopyTo(args[2], dev);
309+
input_tensors[inp_index] = obj;
310+
}
311+
}
312+
270313
inline Device VirtualMachine::GetDevice(Index device_index) const {
271314
ICHECK_GE(devices_.size(), device_index) << "invalid device index: " << device_index;
272315
return devices_[device_index];

0 commit comments

Comments
 (0)