Skip to content

Commit 05c583d

Browse files
vvchernovValery Chernov
authored andcommitted
[VirtualMachine] new method allowing to set one input tensor by its index or name (apache#10293)
* set_input_with_index was implemented for VM * clean code * add getInputIndexFromName. add function descriptions. lint fix * fix lint * transfer comparison of parameter names number and assigned devices number to VMFunction constructor * add GetVMFunctionWithName to Executable API * clean code * add SetInputWithName (set_input_with_name) to VM API * join SetInputWithIndex and SetInputWithName to SetOneInputTensor (set_one_input) to VM API, the joined methods were removed * fix lint * some fixes after review * add set_one_input method to python API of VirtualMachine * pytests for set_input and set_one_input methods of VirtualMachine were implemented and checked * CI restart * construct simple model for pytests by relay instead of onnx tools (need for correct CI) Co-authored-by: Valery Chernov <valery.chernov@deelvin.com>
1 parent 653b6e0 commit 05c583d

File tree

6 files changed

+322
-54
lines changed

6 files changed

+322
-54
lines changed

include/tvm/runtime/vm/executable.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,13 @@ class Executable : public ModuleNode {
218218
*/
219219
void SetLib(const runtime::Module& lib);
220220

221+
/*!
222+
* \brief Get VMFunction.
223+
* \param func_name The function's name.
224+
* \return VMFunction.
225+
*/
226+
const VMFunction& GetVMFunctionWithName(const std::string& func_name) const;
227+
221228
/*!
222229
* \brief Get the arity of the VMFunction.
223230
* \param func Function name.

include/tvm/runtime/vm/vm.h

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,9 @@ struct VMFunction {
9393
params(std::move(params)),
9494
instructions(std::move(instructions)),
9595
register_file_size(register_file_size),
96-
param_device_indexes(std::move(param_device_indexes)) {}
96+
param_device_indexes(std::move(param_device_indexes)) {
97+
ICHECK_EQ(params.size(), param_device_indexes.size());
98+
}
9799

98100
VMFunction() = default;
99101

@@ -270,6 +272,15 @@ class VirtualMachine : public runtime::ModuleNode {
270272
*/
271273
void SetInput(std::string name, TVMArgs args, int offset);
272274

275+
/*!
276+
* \brief Set one input tensor with index or name to a function.
277+
* \param name The function name.
278+
* \param tag index or name of the input tensor .
279+
* \param tensor the input tensor. If the tensor is not of the correct device for the function,
280+
* they will be copied to the device.
281+
*/
282+
void SetOneInput(std::string name, const TVMArgValue& tag, const TVMArgValue& tensor);
283+
273284
/*!
274285
* \brief Internal hook for profiling the start of an op.
275286
*
@@ -286,6 +297,48 @@ class VirtualMachine : public runtime::ModuleNode {
286297
*/
287298
virtual void OpStopHook();
288299

300+
private:
301+
/*!
302+
* \brief Get index of input tensor from its name.
303+
* \param func_name The function's name.
304+
* \param input_name The input tensor name.
305+
* \return The input tensor index.
306+
*/
307+
int64_t GetInputIndexFromVMFunction(const std::string& func_name,
308+
const std::string& input_name) const;
309+
310+
/*!
311+
* \brief Get index of input tensor from its name.
312+
* \param params parameter names.
313+
* \param input_name The input tensor name.
314+
* \return The input tensor index.
315+
*/
316+
int64_t GetInputIndexFromName(const std::vector<std::string>& params,
317+
const std::string& input_name) const;
318+
319+
/*!
320+
* \brief Check executable exists and get VM function from it.
321+
* \param func_name The function's name.
322+
* \return VM function.
323+
*/
324+
const VMFunction& CheckAndGetVMFunction(const std::string& func_name) const;
325+
326+
/*!
327+
* \brief Creats inputs_ field, if it exists check its size.
328+
* \param func_name The function's name.
329+
* \param size inputs_ field size.
330+
* \return VM function.
331+
*/
332+
void CreateInputsOrCheckSize(const std::string& func_name, size_t size);
333+
334+
/*!
335+
* \brief Set one input tensor with given index to set of input tensors if need copy to given
336+
* device. \param tensors the input tensors set (destination) \param tensor some tensor (not
337+
* neccessary DLTensor). \param index The input tensor index. \param dev device to copy if need.
338+
*/
339+
void SetInputTensorWithIndex(std::vector<ObjectRef>& tensors, // NOLINT(*)
340+
const TVMArgValue& tensor, int index, Device dev);
341+
289342
protected:
290343
/*! \brief The virtual machine's packed function table. */
291344
std::vector<PackedFunc> packed_funcs_;

python/tvm/runtime/vm.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,7 @@ def __init__(self, exe, device, memory_cfg=None):
380380
self._get_num_outputs = self.module["get_num_outputs"]
381381
self._get_input_index = self.module["get_input_index"]
382382
self._set_input = self.module["set_input"]
383+
self._set_one_input = self.module["set_one_input"]
383384
self._setup_device(device, memory_cfg)
384385

385386
def _setup_device(self, dev, memory_cfg):
@@ -450,6 +451,30 @@ def set_input(self, func_name, *args, **kwargs):
450451
cargs = convert(args)
451452
self._set_input(func_name, *cargs)
452453

454+
def set_one_input(self, func_name, *args, **kwargs):
455+
"""Set the one input tensor with tag to a function.
456+
457+
Parameters
458+
----------
459+
func_name : str
460+
The name of the function.
461+
args : [str or int, tvm.runtime.NDArray]
462+
name or index of tensor and input tensor, optional
463+
kwargs: dict of str or int to tvm.runtime.NDArray, optional
464+
taged arguments to the function.
465+
Only args or kwargs should exist
466+
"""
467+
if kwargs:
468+
assert len(kwargs) == 1
469+
tag = next(iter(kwargs))
470+
if isinstance(tag, str):
471+
func_params = self._exec.get_function_params(func_name)
472+
assert tag in func_params
473+
self._set_one_input(func_name, tag, kwargs[tag])
474+
else:
475+
assert len(args) == 2
476+
self._set_one_input(func_name, args[0], args[1])
477+
453478
def invoke(self, func_name, *args, **kwargs):
454479
"""Invoke a function.
455480

src/runtime/vm/executable.cc

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -109,27 +109,20 @@ PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtr<Obje
109109
}
110110
}
111111

112-
int Executable::GetFunctionArity(std::string func_name) const {
112+
const VMFunction& Executable::GetVMFunctionWithName(const std::string& func_name) const {
113113
auto it = global_map.find(func_name);
114-
if (it == global_map.end()) {
115-
LOG(ERROR) << "Cannot find function " << func_name << " in executable";
116-
return -1;
117-
}
118-
const auto& func = functions[it->second];
114+
ICHECK(it != global_map.end()) << "Cannot find function " << func_name << " in executable";
115+
return functions[it->second];
116+
}
117+
118+
int Executable::GetFunctionArity(std::string func_name) const {
119+
const auto& func = GetVMFunctionWithName(func_name);
119120
return func.params.size();
120121
}
121122

122123
std::string Executable::GetFunctionParameterName(std::string func_name, uint32_t index) const {
123-
auto it = global_map.find(func_name);
124-
if (it == global_map.end()) {
125-
LOG(ERROR) << "Cannot find function " << func_name << " in executable";
126-
return "";
127-
}
128-
const auto& func = functions[it->second];
129-
if (index > func.params.size()) {
130-
LOG(ERROR) << "Invalid parameter index";
131-
return "";
132-
}
124+
const auto& func = GetVMFunctionWithName(func_name);
125+
ICHECK_LT(index, func.params.size()) << "Invalid parameter index";
133126
return func.params[index];
134127
}
135128

src/runtime/vm/vm.cc

Lines changed: 85 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -190,17 +190,7 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
190190
} else if (name == "get_input_index") {
191191
return TypedPackedFunc<int64_t(std::string, std::string)>(
192192
[this](std::string input_name, std::string func_name) {
193-
auto gvit = exec_->global_map.find(func_name);
194-
ICHECK(gvit != exec_->global_map.end()) << "Cannot find function " << func_name;
195-
auto func_index = gvit->second;
196-
const auto& vm_func = exec_->functions[func_index];
197-
const auto& param_names = vm_func.params;
198-
for (uint64_t i = 0; i < param_names.size(); i++) {
199-
if (input_name == param_names[i]) {
200-
return static_cast<int64_t>(i);
201-
}
202-
}
203-
return static_cast<int64_t>(-1);
193+
return GetInputIndexFromVMFunction(func_name, input_name);
204194
});
205195
} else if (name == "init") {
206196
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
@@ -221,6 +211,12 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
221211
} else if (name == "set_input") {
222212
return PackedFunc(
223213
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { SetInput(args[0], args, 1); });
214+
} else if (name == "set_one_input") {
215+
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
216+
ICHECK_EQ(args.size(), 3) << "The expected number of arguments is 3 "
217+
<< "(func_name, index or name, tensor)";
218+
SetOneInput(args[0], args[1], args[2]);
219+
});
224220
} else if (name == "load_late_bound_consts") {
225221
return PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
226222
CHECK_EQ(args.size(), 1);
@@ -234,39 +230,91 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
234230
}
235231

236232
void VirtualMachine::SetInput(std::string func_name, TVMArgs args, int offset) {
237-
ICHECK(exec_) << "The executable is not created yet.";
238-
auto gvit = exec_->global_map.find(func_name);
239-
ICHECK(gvit != exec_->global_map.end()) << "Cannot find function " << func_name;
240-
auto func_index = gvit->second;
241-
const auto& vm_func = exec_->functions[func_index];
242-
const auto& param_names = vm_func.params;
243-
ICHECK_EQ(args.size() - offset, param_names.size())
233+
const auto& vm_func = CheckAndGetVMFunction(func_name);
234+
size_t params_num = vm_func.params.size();
235+
ICHECK_EQ(args.size() - offset, params_num)
244236
<< "The number of provided parameters doesn't match the number of arguments";
245-
ICHECK_EQ(param_names.size(), vm_func.param_device_indexes.size())
246-
<< "The number of provided parameters doesn't match the number of assigned devices";
247-
std::vector<ObjectRef> func_args(param_names.size());
237+
std::vector<ObjectRef> func_args(params_num);
248238
for (int i = offset; i < args.size(); ++i) {
249-
Device dev = GetDevice(vm_func.param_device_indexes[i - offset]);
250-
251-
if (args[i].type_code() == kTVMDLTensorHandle) {
252-
// Automatically convert input DLTensors to NDArray
253-
DLTensor* tensor = args[i];
254-
std::vector<int64_t> shape;
255-
for (int64_t i = 0; i < tensor->ndim; i++) {
256-
shape.push_back(tensor->shape[i]);
257-
}
258-
NDArray ary = NDArray::Empty(shape, tensor->dtype, dev);
259-
ary.CopyFrom(tensor);
260-
func_args[i - offset] = ary;
261-
} else {
262-
ObjectRef obj = CopyTo(args[i], dev);
263-
func_args[i - offset] = obj;
264-
}
239+
int index = i - offset;
240+
Device dev = GetDevice(vm_func.param_device_indexes[index]);
241+
SetInputTensorWithIndex(func_args, args[i], index, dev);
265242
}
266243
inputs_.erase(func_name);
267244
inputs_.emplace(func_name, func_args);
268245
}
269246

247+
void VirtualMachine::SetOneInput(std::string func_name, const TVMArgValue& tag,
248+
const TVMArgValue& tensor) {
249+
const auto& vm_func = CheckAndGetVMFunction(func_name);
250+
size_t params_num = vm_func.params.size();
251+
252+
int inp_index;
253+
if (tag.type_code() == kTVMArgInt) {
254+
inp_index = tag;
255+
} else if (tag.type_code() == kTVMStr) {
256+
inp_index = static_cast<int>(GetInputIndexFromName(vm_func.params, tag));
257+
} else {
258+
LOG(FATAL) << "The type of input tensor tag (" << tag.type_code()
259+
<< ") doesn't match integer or string";
260+
}
261+
ICHECK_LT(inp_index, params_num);
262+
263+
CreateInputsOrCheckSize(func_name, params_num);
264+
Device dev = GetDevice(vm_func.param_device_indexes[inp_index]);
265+
SetInputTensorWithIndex(inputs_[func_name], tensor, inp_index, dev);
266+
}
267+
268+
int64_t VirtualMachine::GetInputIndexFromVMFunction(const std::string& func_name,
269+
const std::string& input_name) const {
270+
const auto& vm_func = CheckAndGetVMFunction(func_name);
271+
return GetInputIndexFromName(vm_func.params, input_name);
272+
}
273+
274+
int64_t VirtualMachine::GetInputIndexFromName(const std::vector<std::string>& params,
275+
const std::string& input_name) const {
276+
// TODO(vvchernov): excess integer type?
277+
for (uint64_t i = 0; i < params.size(); i++) {
278+
if (input_name == params[i]) {
279+
return static_cast<int64_t>(i);
280+
}
281+
}
282+
return static_cast<int64_t>(-1);
283+
}
284+
285+
const VMFunction& VirtualMachine::CheckAndGetVMFunction(const std::string& func_name) const {
286+
ICHECK(exec_) << "The executable is not created yet.";
287+
return exec_->GetVMFunctionWithName(func_name);
288+
}
289+
290+
void VirtualMachine::CreateInputsOrCheckSize(const std::string& func_name, size_t size) {
291+
if (inputs_.count(func_name)) {
292+
ICHECK_EQ(inputs_[func_name].size(), size)
293+
<< "The size of function" << func_name
294+
<< " doesn't match the number of provided parameters";
295+
} else {
296+
std::vector<ObjectRef> func_args(size);
297+
inputs_.emplace(func_name, func_args);
298+
}
299+
}
300+
301+
void VirtualMachine::SetInputTensorWithIndex(std::vector<ObjectRef>& tensors,
302+
const TVMArgValue& inp_tensor, int index, Device dev) {
303+
if (inp_tensor.type_code() == kTVMDLTensorHandle) {
304+
// Automatically convert input DLTensors to NDArray
305+
DLTensor* tensor = inp_tensor;
306+
std::vector<int64_t> shape;
307+
for (int64_t i = 0; i < tensor->ndim; i++) {
308+
shape.push_back(tensor->shape[i]);
309+
}
310+
NDArray ary = NDArray::Empty(shape, tensor->dtype, dev);
311+
ary.CopyFrom(tensor);
312+
tensors[index] = ary;
313+
} else {
314+
tensors[index] = CopyTo(inp_tensor, dev);
315+
}
316+
}
317+
270318
inline Device VirtualMachine::GetDevice(Index device_index) const {
271319
ICHECK_GE(devices_.size(), device_index) << "invalid device index: " << device_index;
272320
return devices_[device_index];

0 commit comments

Comments
 (0)