Skip to content

Commit 84b5a86

Browse files
csullivanylc
authored andcommitted
[Hexagon] Launcher modifications to make use of the new device API (apache#9356)
* Use custom lookup_linked_params function in app/hexagon_launcher. * Use new device api in hexagon launcher. * Ensure app/hexagon_launcher uses kDLCPU for external allocations so that CopyDataFromTo knows how to handle the provided memory. * Use loadfile_hexagon in app/hexagon_launcher. * Update hexagon launcher's get_output method to utilize NDArray's for copying in order to exercise the DeviceAPI. * Apply clang formatting
1 parent 21fbca4 commit 84b5a86

File tree

3 files changed

+23
-6
lines changed

3 files changed

+23
-6
lines changed

apps/hexagon_launcher/launcher_core.cc

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,12 +148,13 @@ const tvm::runtime::PackedFunc get_module_func(tvm::runtime::Module module,
148148
}
149149

150150
void reset_device_api() {
151-
const tvm::runtime::PackedFunc api = get_runtime_func("device_api.cpu");
151+
const tvm::runtime::PackedFunc api = get_runtime_func("device_api.hexagon.v2");
152152
tvm::runtime::Registry::Register("device_api.hexagon", true).set_body(api);
153153
}
154154

155155
tvm::runtime::Module load_module(const std::string& file_name) {
156-
static const tvm::runtime::PackedFunc loader = get_runtime_func("runtime.module.loadfile_so");
156+
static const tvm::runtime::PackedFunc loader =
157+
get_runtime_func("runtime.module.loadfile_hexagon");
157158
tvm::runtime::TVMRetValue rv = loader(file_name);
158159
if (rv.type_code() == kTVMModuleHandle) {
159160
return rv.operator tvm::runtime::Module();
@@ -169,7 +170,10 @@ tvm::runtime::Module create_graph_executor(const std::string& graph_json,
169170
uint64_t device_type = device.device_type;
170171
uint64_t device_id = device.device_id;
171172

173+
std::string linked_params = "tvm.runtime.hexagon.lookup_linked_params";
174+
const tvm::runtime::PackedFunc lookup_linked_params = get_runtime_func(linked_params);
172175
// Use default param lookup function (linked into the module).
173-
tvm::runtime::TVMRetValue rv = create_executor(graph_json, graph_module, device_type, device_id);
176+
tvm::runtime::TVMRetValue rv =
177+
create_executor(graph_json, graph_module, lookup_linked_params, device_type, device_id);
174178
return rv.operator tvm::runtime::Module();
175179
}

apps/hexagon_launcher/launcher_core.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ struct Model {
8989

9090
static tvm::Device device() { return tvm::Device{static_cast<DLDeviceType>(kDLHexagon), 0}; }
9191

92+
static tvm::Device external() { return tvm::Device{static_cast<DLDeviceType>(kDLCPU), 0}; }
93+
9294
tvm::runtime::PackedFunc run;
9395
};
9496

apps/hexagon_launcher/launcher_hexagon.cc

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ extern "C" {
2626
#include <qurt_hvx.h>
2727
}
2828

29+
#include <tvm/runtime/object.h>
30+
2931
#include <algorithm>
3032
#include <memory>
3133
#include <string>
@@ -106,7 +108,7 @@ AEEResult __QAIC_HEADER(launcher_rpc_set_input)(remote_handle64 handle, int inpu
106108

107109
DLTensor tensor{
108110
const_cast<unsigned char*>(input_value),
109-
Model::device(),
111+
Model::external(),
110112
meta->ndim,
111113
meta->dtype,
112114
const_cast<int64_t*>(meta->shape),
@@ -153,6 +155,16 @@ AEEResult __QAIC_HEADER(launcher_rpc_get_output)(remote_handle64 handle, int out
153155
tvm::runtime::PackedFunc get_output = get_module_func(TheModel->graph_executor, "get_output");
154156
tvm::runtime::NDArray output = get_output(output_idx);
155157

158+
std::vector<int64_t> shape_vec{output->shape, output->shape + output->ndim};
159+
160+
auto* container = new tvm::runtime::NDArray::Container(
161+
static_cast<void*>(output_value), shape_vec, output->dtype, Model::external());
162+
container->SetDeleter([](tvm::Object* container) {
163+
delete static_cast<tvm::runtime::NDArray::Container*>(container);
164+
});
165+
166+
tvm::runtime::NDArray host_output(GetObjectPtr<tvm::Object>(container));
167+
156168
if (meta_size != 0) {
157169
auto* meta = reinterpret_cast<tensor_meta*>(output_meta);
158170
if (meta_size < meta->meta_size(output->ndim)) {
@@ -170,8 +182,7 @@ AEEResult __QAIC_HEADER(launcher_rpc_get_output)(remote_handle64 handle, int out
170182
return error_too_small(__func__, "value_size", value_size, data_size);
171183
}
172184

173-
auto data = reinterpret_cast<decltype(output_value)>(output->data);
174-
std::copy(data, data + data_size, output_value);
185+
host_output.CopyFrom(output);
175186
}
176187

177188
return AEE_SUCCESS;

0 commit comments

Comments
 (0)