-
Notifications
You must be signed in to change notification settings - Fork 49
Expand file tree
/
Copy pathAirHostModule.cpp
More file actions
152 lines (120 loc) · 4.27 KB
/
Copy pathAirHostModule.cpp
File metadata and controls
152 lines (120 loc) · 4.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
//===- AirHostModule.cpp ----------------------------------------*- C++ -*-===//
//
// Copyright (C) 2021-2022, Xilinx Inc.
// Copyright (C) 2022, Advanced Micro Devices, Inc.
// SPDX-License-Identifier: MIT
//
//===----------------------------------------------------------------------===//
#include <iostream>
#include <nanobind/nanobind.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/vector.h>
#include "air.hpp"
#include "hsa/hsa.h"
// Two-level stringification: MACRO_STRINGIFY(X) first lets the preprocessor
// expand X and then quotes the expansion, whereas STRINGIFY(X) alone would
// quote the macro name itself. Used to turn VERSION_INFO into a string.
#define STRINGIFY(x) #x
#define MACRO_STRINGIFY(x) STRINGIFY(x)
// Short alias for the nanobind namespace used throughout this file.
namespace nb = nanobind;
namespace {
/// @brief Registers the libairhost bindings on @p m.
///
/// Exposes libxaie/runtime lifecycle calls, the module/segment/herd
/// descriptor classes, HSA agent and queue helpers, and raw 32-bit register
/// access. NB_MODULE calls it with the `host` submodule.
///
/// @param m Python module that receives the bindings.
void defineAIRHostModule(nb::module_ &m) {
  // --- libxaie / runtime lifecycle ----------------------------------------
  // The libxaie context is handed to Python as an opaque integer; Python must
  // pass the same value back to deinit_libxaie. (The rv_policy on a scalar
  // return has no effect; it is kept for interface stability.)
  m.def(
      "init_libxaie", []() -> uint64_t { return (uint64_t)air_init_libxaie(); },
      nb::rv_policy::reference);
  m.def("deinit_libxaie", [](uint64_t ctx) -> void {
    air_deinit_libxaie((air_libxaie_ctx_t)ctx);
  });
  // NOTE(review): air_init / air_shut_down results are widened to uint64_t;
  // presumably status codes -- confirm against air.hpp.
  m.def("init", []() -> uint64_t { return (uint64_t)air_init(); });
  m.def("shut_down", []() -> uint64_t { return (uint64_t)air_shut_down(); });

  // --- Descriptors ----------------------------------------------------------
  // Descriptor pointers are returned with rv_policy::reference, so Python
  // never takes ownership of the memory behind them.
  nb::class_<air_module_desc_t>(m, "ModuleDescriptor")
      .def(
          "getSegments",
          [](const air_module_desc_t &d) -> std::vector<air_segment_desc_t *> {
            std::vector<air_segment_desc_t *> segments;
            // Final size is known up front; avoid incremental regrowth.
            segments.reserve(static_cast<size_t>(d.segment_length));
            for (uint64_t i = 0; i < d.segment_length; i++)
              segments.push_back(d.segment_descs[i]);
            return segments;
          },
          nb::rv_policy::reference);

  nb::class_<air_segment_desc_t>(m, "SegmentDescriptor")
      .def(
          "getHerds",
          [](const air_segment_desc_t &d) -> std::vector<air_herd_desc_t *> {
            std::vector<air_herd_desc_t *> herds;
            herds.reserve(static_cast<size_t>(d.herd_length));
            for (uint64_t i = 0; i < d.herd_length; i++)
              herds.push_back(d.herd_descs[i]);
            return herds;
          },
          nb::rv_policy::reference)
      // The name is built from (pointer, length) rather than treated as a
      // NUL-terminated C string.
      .def("getName", [](const air_segment_desc_t &d) -> std::string {
        return {d.name, static_cast<size_t>(d.name_length)};
      });

  nb::class_<air_herd_desc_t>(m, "HerdDescriptor")
      .def("getName", [](const air_herd_desc_t &d) -> std::string {
        return {d.name, static_cast<size_t>(d.name_length)};
      });

  // --- Module loading -------------------------------------------------------
  m.def("module_load_from_file",
        [](const std::string &filename, hsa_agent_t *agent,
           hsa_queue_t *q) -> air_module_handle_t {
          return air_module_load_from_file(filename.c_str(), agent, q);
        });
  m.def("module_unload", &air_module_unload);
  m.def("get_module_descriptor", &air_module_get_desc,
        nb::rv_policy::reference);

  // --- HSA agents and queues ------------------------------------------------
  // Registered as opaque classes so agent/queue objects can round-trip
  // through Python into the other bindings.
  nb::class_<hsa_agent_t> Agent(m, "Agent");
  m.def(
      "get_agents",
      []() -> std::vector<hsa_agent_t> {
        std::vector<hsa_agent_t> agents;
        // NOTE(review): the result of air_get_agents (if any) is ignored; a
        // failure surfaces only as a short or empty list.
        air_get_agents(agents);
        return agents;
      },
      nb::rv_policy::reference);

  nb::class_<hsa_queue_t> Queue(m, "Queue");
  // Creates a single-producer queue sized to the agent's reported maximum.
  // Returns None to Python when either HSA call fails.
  // NOTE(review): nothing in this module destroys the queue
  // (hsa_queue_destroy) -- confirm who owns its lifetime.
  m.def(
      "queue_create",
      [](const hsa_agent_t &a) -> hsa_queue_t * {
        hsa_queue_t *q = nullptr;
        uint32_t aie_max_queue_size(0);
        // Query the queue size the agent supports.
        auto queue_size_ret = hsa_agent_get_info(
            a, HSA_AGENT_INFO_QUEUE_MAX_SIZE, &aie_max_queue_size);
        if (queue_size_ret != HSA_STATUS_SUCCESS)
          return nullptr;
        // Create the queue (no error callback, no private/group segments).
        auto queue_create_ret =
            hsa_queue_create(a, aie_max_queue_size, HSA_QUEUE_TYPE_SINGLE,
                             nullptr, nullptr, 0, 0, &q);
        // Compare against the named status, consistent with the check above.
        if (queue_create_ret != HSA_STATUS_SUCCESS)
          return nullptr;
        return q;
      },
      nb::rv_policy::reference);

  // --- Raw register access --------------------------------------------------
  m.def(
      "read32", [](uint64_t addr) -> uint32_t { return air_read32(addr); },
      nb::rv_policy::copy);
  m.def("write32", [](uint64_t addr, uint32_t val) -> void {
    return air_write32(addr, val);
  });
  m.def(
      "get_tile_addr",
      [](uint32_t col, uint32_t row) -> uint64_t {
        return air_get_tile_addr(col, row);
      },
      nb::rv_policy::copy);
}
} // namespace
NB_MODULE(_airRt, m) {
  // Top-level docstring of the _airRt extension module.
  m.doc() = R"pbdoc(
AIR Runtime Python bindings
--------------------------
.. currentmodule:: _airRt
.. autosummary::
:toctree: _generate
)pbdoc";

  // __version__ comes from the build-time VERSION_INFO macro when it is
  // defined, and is "dev" otherwise.
#ifdef VERSION_INFO
  const char *version_string = MACRO_STRINGIFY(VERSION_INFO);
#else
  const char *version_string = "dev";
#endif
  m.attr("__version__") = version_string;

  // All libairhost bindings live under the `host` submodule.
  nb::module_ host_module = m.def_submodule("host", "libairhost bindings");
  defineAIRHostModule(host_module);
}