Skip to content

Commit 53eb213

Browse files
author
Valery Chernov
committed
set_input_with_index was implemented for VM
1 parent f583a70 commit 53eb213

File tree

2 files changed

+53
-0
lines changed
  • include/tvm/runtime/vm
  • src/runtime/vm

2 files changed

+53
-0
lines changed

include/tvm/runtime/vm/vm.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,15 @@ class VirtualMachine : public runtime::ModuleNode {
270270
*/
271271
void SetInput(std::string name, TVMArgs args, int offset);
272272

273+
/*!
274+
* \brief Set input tensor with index to a function.
275+
* \param name The function name
276+
* \param args args[1:] are two arguments (index, tensor) to the
277+
* function. If the tensor is not of the correct device for the function,
278+
* they will be copied to the device.
279+
*/
280+
void SetInputWithIndex(std::string name, TVMArgs args);
281+
273282
/*!
274283
* \brief Internal hook for profiling the start of an op.
275284
*

src/runtime/vm/vm.cc

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,9 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
221221
} else if (name == "set_input") {
222222
return PackedFunc(
223223
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { SetInput(args[0], args, 1); });
224+
} else if (name == "set_input_with_index") {
225+
return PackedFunc(
226+
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { SetInputWithIndex(args[0], args); });
224227
} else if (name == "load_late_bound_consts") {
225228
return PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
226229
CHECK_EQ(args.size(), 1);
@@ -267,6 +270,47 @@ void VirtualMachine::SetInput(std::string func_name, TVMArgs args, int offset) {
267270
inputs_.emplace(func_name, func_args);
268271
}
269272

273+
void VirtualMachine::SetInputWithIndex(std::string func_name, TVMArgs args) {
274+
ICHECK(exec_) << "The executable is not created yet.";
275+
auto gvit = exec_->global_map.find(func_name);
276+
ICHECK(gvit != exec_->global_map.end()) << "Cannot find function " << func_name;
277+
auto func_index = gvit->second;
278+
const auto& vm_func = exec_->functions[func_index];
279+
const auto& param_names = vm_func.params;
280+
ICHECK_EQ(args.size(), 3)
281+
<< "The expected number of arguments is 3 (func_name, index, tensor)";
282+
// TODO(vvchernov): Looks like it should be checked earlier and in other place
283+
ICHECK_EQ(param_names.size(), vm_func.param_device_indexes.size())
284+
<< "The number of provided parameters doesn't match the number of assigned devices";
285+
if (inputs_.count(func_name)) {
286+
ICHECK_EQ(inputs_[func_name].size(), param_names.size())
287+
<< "The size of function" << func_name << " doesn't match the number of provided parameters";
288+
} else {
289+
std::vector<ObjectRef> func_args(param_names.size());
290+
inputs_.emplace(func_name, func_args);
291+
}
292+
ICHECK_EQ(args[1].type_code(), kTVMArgInt)
293+
<< "The second argument doesn't match integer index";
294+
int inp_index = args[1];
295+
auto& input_tensors = inputs_[func_name];
296+
Device dev = GetDevice(vm_func.param_device_indexes[inp_index]);
297+
298+
if (args[2].type_code() == kTVMDLTensorHandle) {
299+
// Automatically convert input DLTensors to NDArray
300+
DLTensor* tensor = args[2];
301+
std::vector<int64_t> shape;
302+
for (int64_t i = 0; i < tensor->ndim; i++) {
303+
shape.push_back(tensor->shape[i]);
304+
}
305+
NDArray ary = NDArray::Empty(shape, tensor->dtype, dev);
306+
ary.CopyFrom(tensor);
307+
input_tensors[inp_index] = ary;
308+
} else {
309+
ObjectRef obj = CopyTo(args[2], dev);
310+
input_tensors[inp_index] = obj;
311+
}
312+
}
313+
270314
inline Device VirtualMachine::GetDevice(Index device_index) const {
271315
ICHECK_GE(devices_.size(), device_index) << "invalid device index: " << device_index;
272316
return devices_[device_index];

0 commit comments

Comments
 (0)