
Commit 85660f5

Merge branch 'main' of https://github.com/microsoft/TileLang into main

2 parents: 4e4f317 + 9356eb6

File tree: 16 files changed, +1716 -69 lines


CMakeLists.txt

Lines changed: 2 additions & 0 deletions

@@ -99,6 +99,8 @@ tilelang_file_glob(GLOB TILE_LANG_SRCS
     src/transform/*.cc
     src/op/*.cc
     src/target/utils.cc
+    src/target/codegen_cpp.cc
+    src/target/rt_mod_cpp.cc
 )
 
 # Include CUDA source files if CUDA is enabled

src/ir.cc

Lines changed: 55 additions & 16 deletions

@@ -7,13 +7,33 @@
  *
  */
 
+#include <tvm/arith/analyzer.h>
 #include <tvm/script/ir_builder/tir/ir.h>
 
 namespace tvm {
 namespace tl {
 
+constexpr const char *tilelang_is_cpu_kernel_frame =
+    "tilelang.is_cpu_kernel_frame";
+
 using namespace script::ir_builder::tir;
 
+static ForFrame MakeIterVarFrame(std::string name, PrimExpr dom) {
+  using namespace tvm::tir;
+  Var var = Var(name);
+  // Create a frame that represents a loop over the given domain.
+  ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
+  n->vars.push_back(var);
+  n->doms.push_back(Range(0, dom));
+  n->f_make_for_loop = [](Array<Var> vars, Array<Range> doms,
+                          Stmt body) -> Stmt {
+    ICHECK_EQ(vars.size(), 1);
+    ICHECK_EQ(doms.size(), 1);
+    return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, body);
+  };
+  return ForFrame(n);
+}
+
 ForFrame ParallelFor(Array<PrimExpr> extents,
                      Map<String, ObjectRef> annotations) {
   using namespace tvm::tir;
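The `f_make_for_loop` callback above emits an ordinary serial `For` node with no thread binding, which is what lets grid dimensions become plain loops on CPU. As a rough sketch of what one such frame expands to (this uses TVM's C++ TIR API; the helper function, extent, and placeholder body are illustrative, not part of this commit):

    // Sketch: roughly the statement a MakeIterVarFrame("block_var_0", extent)
    // frame produces once it closes -- a serial loop, no thread binding.
    #include <tvm/tir/expr.h>
    #include <tvm/tir/stmt.h>
    using namespace tvm;
    using namespace tvm::tir;

    Stmt MakeSketchLoop(PrimExpr extent) {
      Var block_var_0("block_var_0");
      Stmt body = Evaluate(0);  // placeholder for the kernel body
      return For(block_var_0, /*min=*/0, /*extent=*/extent,
                 ForKind::kSerial, body);
    }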
@@ -121,24 +141,43 @@ KernelLaunchFrame KernelLaunch(Array<PrimExpr> grid_size,
                                Array<PrimExpr> block_size,
                                Map<String, ObjectRef> attrs) {
   ObjectPtr<KernelLaunchFrameNode> n = make_object<KernelLaunchFrameNode>();
-  ICHECK(grid_size.size() <= 3);
-  if (grid_size.size() > 0)
-    n->frames.push_back(LaunchThread("blockIdx.x", grid_size[0]));
-  if (grid_size.size() > 1)
-    n->frames.push_back(LaunchThread("blockIdx.y", grid_size[1]));
-  if (grid_size.size() > 2)
-    n->frames.push_back(LaunchThread("blockIdx.z", grid_size[2]));
-  if (block_size.defined()) {
-    ICHECK(block_size.size() <= 3);
-    if (block_size.size() > 0)
-      n->frames.push_back(LaunchThread("threadIdx.x", block_size[0]));
-    if (block_size.size() > 1)
-      n->frames.push_back(LaunchThread("threadIdx.y", block_size[1]));
-    if (block_size.size() > 2)
-      n->frames.push_back(LaunchThread("threadIdx.z", block_size[2]));
+
+  // If the kernel is a CPU kernel, we don't need to launch any threads.
+  bool is_cpu_kernel_frame =
+      attrs.defined() && attrs.count(tilelang_is_cpu_kernel_frame);
+
+  if (is_cpu_kernel_frame) {
+    // Launch CPU kernel: emit one serial grid loop per grid dimension.
+    ICHECK(attrs.defined());
+    ICHECK(block_size.size() == 0) << "CPU kernel cannot have block size";
+    for (size_t i = 0; i < grid_size.size(); i++) {
+      n->frames.push_back(
+          MakeIterVarFrame("block_var_" + std::to_string(i), grid_size[i]));
+    }
   } else {
-    n->frames.push_back(Block(""));
+    // Launch GPU kernel: bind blockIdx/threadIdx launch threads.
+    ICHECK(grid_size.size() <= 3);
+    if (grid_size.size() > 0)
+      n->frames.push_back(LaunchThread("blockIdx.x", grid_size[0]));
+    if (grid_size.size() > 1)
+      n->frames.push_back(LaunchThread("blockIdx.y", grid_size[1]));
+    if (grid_size.size() > 2)
+      n->frames.push_back(LaunchThread("blockIdx.z", grid_size[2]));
+    if (block_size.defined()) {
+      ICHECK(block_size.size() <= 3);
+      if (block_size.size() > 0)
+        n->frames.push_back(LaunchThread("threadIdx.x", block_size[0]));
+      if (block_size.size() > 1)
+        n->frames.push_back(LaunchThread("threadIdx.y", block_size[1]));
+      if (block_size.size() > 2)
+        n->frames.push_back(LaunchThread("threadIdx.z", block_size[2]));
+    } else {
+      n->frames.push_back(Block(""));
+    }
   }
+
   if (attrs.defined()) {
     auto empty_block = Block("");
     empty_block->annotations = attrs;
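Nothing in this hunk sets the attribute itself; `KernelLaunch` only tests for the key's presence via `attrs.count(...)`, so the stored value is irrelevant. A hypothetical call site (the value and 2-D grid are illustrative, not from this commit) could look like:

    // Hypothetical: request serial grid loops instead of blockIdx/threadIdx
    // bindings. Only the key's presence is checked, not its value.
    Map<String, ObjectRef> attrs;
    attrs.Set("tilelang.is_cpu_kernel_frame", Bool(true));
    KernelLaunchFrame frame =
        KernelLaunch(/*grid_size=*/{8, 8}, /*block_size=*/{}, attrs);

With a 2-D grid, the frame stack then holds the two nested `block_var_0`/`block_var_1` serial loops built by `MakeIterVarFrame`.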

src/op/elem.cc

Lines changed: 14 additions & 5 deletions

@@ -138,6 +138,8 @@ For Copy::MakeSIMTLoop(arith::Analyzer *analyzer) const {
 }
 
 Stmt Copy::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
+  Target target = T.target;
+  bool is_cpu_target = target->GetTargetDeviceType() == kDLCPU;
   Stmt ldsm_stmt = LowerLDSMCopy(T, analyzer);
   if (ldsm_stmt.defined())
     return ldsm_stmt;
@@ -148,12 +150,19 @@ Stmt Copy::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   auto simt_loop = MakeSIMTLoop(analyzer);
   auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
 
+  For vectorized_thread_loop;
   auto par_op = std::make_unique<ParallelOp>(fused_loop);
-  par_op->InferLayout({T.target, T.block_size, T.layout_map, T.buffer_remap},
-                      InferLevel::kFree);
-  auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
-                                   par_op->GetLoopLayout());
-  auto vectorized_thread_loop = VectorizeLoop(thread_loop);
+
+  if (is_cpu_target) {
+    vectorized_thread_loop = VectorizeLoop(fused_loop);
+  } else {
+    par_op->InferLayout({T.target, T.block_size, T.layout_map, T.buffer_remap},
+                        InferLevel::kFree);
+    auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
+                                     par_op->GetLoopLayout());
+    vectorized_thread_loop = VectorizeLoop(thread_loop);
+  }
+
   if (par_op->GetPredicate(T.thread_var).defined()) {
     return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
                       vectorized_thread_loop);
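For intuition on the branch above: on GPU the fused copy loop is first partitioned so each thread owns a chunk, and only the per-thread remainder is vectorized; on CPU there is no thread variable, so `VectorizeLoop` consumes the fused loop whole. A hand-written analogue of the two lowered shapes (illustrative only; `N`, `V`, and the buffers are not from this commit):

    // GPU-like shape: thread t owns a contiguous chunk of V elements;
    // the inner loop is what VectorizeLoop turns into vector accesses.
    void copy_gpu_like(const float *src, float *dst, int t, int V) {
      for (int v = 0; v < V; v++)
        dst[t * V + v] = src[t * V + v];
    }

    // CPU-like shape: no thread partitioning; the whole fused loop of
    // N elements is vectorized directly.
    void copy_cpu_like(const float *src, float *dst, int N) {
      for (int i = 0; i < N; i++)
        dst[i] = src[i];
    }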
