|
7 | 7 | *
|
8 | 8 | */
|
9 | 9 |
|
| 10 | +#include <tvm/arith/analyzer.h> |
10 | 11 | #include <tvm/script/ir_builder/tir/ir.h>
|
11 | 12 |
|
12 | 13 | namespace tvm {
|
13 | 14 | namespace tl {
|
14 | 15 |
|
| 16 | +constexpr const char *tilelang_is_cpu_kernel_frame = |
| 17 | + "tilelang.is_cpu_kernel_frame"; |
| 18 | + |
15 | 19 | using namespace script::ir_builder::tir;
|
16 | 20 |
|
| 21 | +static ForFrame MakeIterVarFrame(std::string name, PrimExpr dom) { |
| 22 | + using namespace tvm::tir; |
| 23 | + Var var = Var(name); |
| 24 | + // Create a frame that represents a loop over the given domain. |
| 25 | + ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>(); |
| 26 | + n->vars.push_back(var); |
| 27 | + n->doms.push_back(Range(0, dom)); |
| 28 | + n->f_make_for_loop = [](Array<Var> vars, Array<Range> doms, |
| 29 | + Stmt body) -> Stmt { |
| 30 | + ICHECK_EQ(vars.size(), 1); |
| 31 | + ICHECK_EQ(doms.size(), 1); |
| 32 | + return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, body); |
| 33 | + }; |
| 34 | + return ForFrame(n); |
| 35 | +} |
| 36 | + |
17 | 37 | ForFrame ParallelFor(Array<PrimExpr> extents,
|
18 | 38 | Map<String, ObjectRef> annotations) {
|
19 | 39 | using namespace tvm::tir;
|
@@ -121,24 +141,43 @@ KernelLaunchFrame KernelLaunch(Array<PrimExpr> grid_size,
|
121 | 141 | Array<PrimExpr> block_size,
|
122 | 142 | Map<String, ObjectRef> attrs) {
|
123 | 143 | ObjectPtr<KernelLaunchFrameNode> n = make_object<KernelLaunchFrameNode>();
|
124 |
| - ICHECK(grid_size.size() <= 3); |
125 |
| - if (grid_size.size() > 0) |
126 |
| - n->frames.push_back(LaunchThread("blockIdx.x", grid_size[0])); |
127 |
| - if (grid_size.size() > 1) |
128 |
| - n->frames.push_back(LaunchThread("blockIdx.y", grid_size[1])); |
129 |
| - if (grid_size.size() > 2) |
130 |
| - n->frames.push_back(LaunchThread("blockIdx.z", grid_size[2])); |
131 |
| - if (block_size.defined()) { |
132 |
| - ICHECK(block_size.size() <= 3); |
133 |
| - if (block_size.size() > 0) |
134 |
| - n->frames.push_back(LaunchThread("threadIdx.x", block_size[0])); |
135 |
| - if (block_size.size() > 1) |
136 |
| - n->frames.push_back(LaunchThread("threadIdx.y", block_size[1])); |
137 |
| - if (block_size.size() > 2) |
138 |
| - n->frames.push_back(LaunchThread("threadIdx.z", block_size[2])); |
| 144 | + |
| 145 | + // If the kernel is a CPU kernel, we don't need to launch any threads. |
| 146 | + bool is_cpu_kernel_frame = |
| 147 | + attrs.defined() && attrs.count(tilelang_is_cpu_kernel_frame); |
| 148 | + |
| 149 | + if (is_cpu_kernel_frame) { |
| 150 | + ICHECK(grid_size.size() >= 0); |
| 151 | + ICHECK(block_size.size() == 0) << "CPU kernel cannot have block size"; |
| 152 | + ICHECK(attrs.defined()); |
| 153 | + // create grid loop var |
| 154 | + for (int i = 0; i < grid_size.size(); i++) { |
| 155 | + n->frames.push_back( |
| 156 | + MakeIterVarFrame("block_var_" + std::to_string(i), grid_size[i])); |
| 157 | + } |
| 158 | + // Launch CPU Kernel |
139 | 159 | } else {
|
140 |
| - n->frames.push_back(Block("")); |
| 160 | + // Launch GPU Kernel |
| 161 | + ICHECK(grid_size.size() <= 3); |
| 162 | + if (grid_size.size() > 0) |
| 163 | + n->frames.push_back(LaunchThread("blockIdx.x", grid_size[0])); |
| 164 | + if (grid_size.size() > 1) |
| 165 | + n->frames.push_back(LaunchThread("blockIdx.y", grid_size[1])); |
| 166 | + if (grid_size.size() > 2) |
| 167 | + n->frames.push_back(LaunchThread("blockIdx.z", grid_size[2])); |
| 168 | + if (block_size.defined()) { |
| 169 | + ICHECK(block_size.size() <= 3); |
| 170 | + if (block_size.size() > 0) |
| 171 | + n->frames.push_back(LaunchThread("threadIdx.x", block_size[0])); |
| 172 | + if (block_size.size() > 1) |
| 173 | + n->frames.push_back(LaunchThread("threadIdx.y", block_size[1])); |
| 174 | + if (block_size.size() > 2) |
| 175 | + n->frames.push_back(LaunchThread("threadIdx.z", block_size[2])); |
| 176 | + } else { |
| 177 | + n->frames.push_back(Block("")); |
| 178 | + } |
141 | 179 | }
|
| 180 | + |
142 | 181 | if (attrs.defined()) {
|
143 | 182 | auto empty_block = Block("");
|
144 | 183 | empty_block->annotations = attrs;
|
|
0 commit comments