Skip to content

Commit 0c6b378

Browse files
author
Siyuan Feng
committed
support config fragment shape and layout using intrinsic
1 parent 4bb9b54 commit 0c6b378

File tree

8 files changed

+264
-25
lines changed

8 files changed

+264
-25
lines changed

include/tvm/ir.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1310,6 +1310,16 @@ constexpr const char* opengl_stage_scope = "opengl_stage_scope";
13101310
*/
13111311
constexpr const char* device_scope = "device_scope";
13121312

1313+
/*!
1314+
 * \brief Mark the shape of the TensorCore fragment
1315+
*/
1316+
constexpr const char* fragment_shape = "fragment_shape";
1317+
1318+
/*!
1319+
 * \brief Mark the layout of the TensorCore fragment
1320+
*/
1321+
constexpr const char* fragment_layout = "fragment_layout";
1322+
13131323
/*!
13141324
* \brief Check if attr_key is a pragma key extension
13151325
* \param attr_key The attr key to be compared
@@ -1319,6 +1329,7 @@ inline bool IsPragmaKey(const std::string& attr_key) {
13191329
return attr_key.compare(0, 7, "pragma_") == 0;
13201330
}
13211331

1332+
13221333
} // namespace attr
13231334

13241335
/*! \brief namespace of TVM Intrinsic functions */
@@ -1559,7 +1570,6 @@ constexpr const char* tvm_load_matrix_sync = "tvm_load_matrix_sync";
15591570
constexpr const char* tvm_mma_sync = "tvm_mma_sync";
15601571
constexpr const char* tvm_fill_fragment = "tvm_fill_fragment";
15611572
constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync";
1562-
constexpr const char* tvm_access_fragement = "tvm_access_fragement";
15631573

15641574
} // namespace intrinsic
15651575

include/tvm/ir_pass.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,14 @@ LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target);
525525
*/
526526
LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target);
527527

528+
/*!
529+
 * \brief Infer the TensorCore fragment information using tensor intrinsics
530+
*
531+
 * \param f The device function to be transformed
532+
 * \return Transformed function.
533+
*/
534+
LoweredFunc InferFragment(LoweredFunc f);
535+
528536
/*!
529537
* \brief Verify if memory accesses are legal for a specific target device type.
530538
*

python/tvm/build_module.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,7 @@ def _build_for_device(flist, target, target_host):
464464
func = ir_pass.ThreadSync(func, "global")
465465
func = ir_pass.ThreadSync(func, "shared")
466466
func = ir_pass.ThreadSync(func, "warp")
467+
func = ir_pass.InferFragment(func)
467468
warp_size = target.thread_warp_size
468469
func = ir_pass.LowerThreadAllreduce(func, warp_size)
469470
fsplits = [s for s in ir_pass.SplitHostDevice(func)]

src/api/api_pass.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,5 +160,6 @@ REGISTER_PASS(VerifyGPUCode);
160160
REGISTER_PASS(DecorateDeviceScope);
161161
REGISTER_PASS(InstrumentBoundCheckers);
162162
REGISTER_PASS(VerifyCompactBuffer);
163+
REGISTER_PASS(InferFragment)
163164
} // namespace ir
164165
} // namespace tvm

src/codegen/codegen_cuda.cc

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -305,39 +305,39 @@ void CodeGenCUDA::PrintStorageScope(
305305
void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) {
306306
if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) {
307307
need_mma_h_ = true;
308-
CHECK_EQ(op->args.size(), 3U);
308+
CHECK_EQ(op->args.size(), 6U);
309309
os << "nvcuda::wmma::fill_fragment(";
310310
this->PrintExpr(op->args[0], os);
311311
os << "[";
312-
this->PrintExpr(op->args[1], os);
312+
this->PrintExpr(op->args[4], os);
313313
os << "], ";
314-
this->PrintExpr(op->args[2], os);
314+
this->PrintExpr(op->args[5], os);
315315
os << ")";
316316
} else if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync)) {
317317
need_mma_h_ = true;
318-
CHECK_EQ(op->args.size(), 4U);
318+
CHECK_EQ(op->args.size(), 8U);
319319
os << "nvcuda::wmma::load_matrix_sync(";
320320
this->PrintExpr(op->args[0], os);
321321
os << "[";
322-
this->PrintExpr(op->args[1], os);
322+
this->PrintExpr(op->args[4], os);
323323
os << "], ";
324-
this->PrintExpr(op->args[2], os);
324+
this->PrintExpr(op->args[5], os);
325325
os << ", ";
326-
this->PrintExpr(op->args[3], os);
326+
this->PrintExpr(op->args[6], os);
327327
os << ")";
328328
} else if (op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) {
329329
need_mma_h_ = true;
330-
CHECK_EQ(op->args.size(), 5U);
330+
CHECK_EQ(op->args.size(), 8U);
331331
os << "nvcuda::wmma::store_matrix_sync(";
332-
this->PrintExpr(op->args[2], os);
332+
this->PrintExpr(op->args[5], os);
333333
os << ", ";
334334
this->PrintExpr(op->args[0], os);
335335
os << "[";
336-
this->PrintExpr(op->args[1], os);
336+
this->PrintExpr(op->args[4], os);
337337
os << "], ";
338-
this->PrintExpr(op->args[3], os);
339-
if (const StringImm *str = op->args[4].as<StringImm>()) {
340-
os << ", nvcuda::wmma::" << str->value;
338+
this->PrintExpr(op->args[6], os);
339+
if (const StringImm *str = op->args[7].as<StringImm>()) {
340+
os << ", nvcuda::wmma::mem_" << str->value;
341341
} else {
342342
LOG(FATAL) << "Invalid parameters";
343343
}
@@ -357,6 +357,19 @@ void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) {
357357
}
358358
}
359359

360+
void CodeGenCUDA::VisitStmt_(const AttrStmt* op) {
361+
if (op->attr_key == attr::fragment_shape) {
362+
const Variable* buffer = op->node.as<Variable>();
363+
const StringImm* shape_str = op->value.as<StringImm>();
364+
fragment_shapes[buffer] = shape_str->value;
365+
} else if (op->attr_key == attr::fragment_layout) {
366+
const Variable* buffer = op->node.as<Variable>();
367+
const StringImm* layout_str = op->value.as<StringImm>();
368+
fragment_layouts[buffer] = layout_str->value;
369+
}
370+
CodeGenC::VisitStmt_(op);
371+
}
372+
360373
void CodeGenCUDA::VisitStmt_(const Allocate* op) {
361374
CHECK(!is_zero(op->condition));
362375
std::string vid = AllocVarID(op->buffer_var.get());
@@ -383,7 +396,7 @@ void CodeGenCUDA::VisitStmt_(const Allocate* op) {
383396
<< "Accumulator only support half and float type for now";
384397
}
385398
constant_size /= 256;
386-
PrintWmmaScope(scope, op->type, stream);
399+
PrintWmmaScope(scope, op->type, buffer, stream);
387400
} else {
388401
PrintStorageScope(scope, stream);
389402
stream << ' ';
@@ -498,18 +511,23 @@ void CodeGenCUDA::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*
498511
PrintConst(op, os, this);
499512
}
500513

501-
void CodeGenCUDA::PrintWmmaScope(const std::string &scope, Type t, std::ostream &os) {
514+
void CodeGenCUDA::PrintWmmaScope(const std::string &scope, Type t, const Variable* variable, std::ostream &os) {
502515
std::stringstream type;
503516
PrintType(t, type);
517+
std::string shape_str = fragment_shapes[variable];
504518
if (scope == "wmma.matrix_a") {
505519
need_mma_h_ = true;
506-
os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, " << type.str() << ", nvcuda::wmma::row_major>";
520+
std::string layout_str = fragment_layouts[variable];
521+
os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, "
522+
<< shape_str << ", " << type.str() << ", nvcuda::wmma::" << layout_str <<">";
507523
} else if (scope == "wmma.matrix_b") {
508524
need_mma_h_ = true;
509-
os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, "<< type.str() << ", nvcuda::wmma::row_major>";
525+
std::string layout_str = fragment_layouts[variable];
526+
os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, "
527+
<< shape_str << ", " << type.str() << ", nvcuda::wmma::" << layout_str <<">";
510528
} else if (scope == "wmma.accumulator") {
511529
need_mma_h_ = true;
512-
os << "nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, "<< type.str() << ">";
530+
os << "nvcuda::wmma::fragment<nvcuda::wmma::accumulator, " << shape_str << ", "<< type.str() << ">";
513531
}
514532
}
515533

src/codegen/codegen_cuda.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ class CodeGenCUDA final : public CodeGenC {
6363
void VisitExpr_(const Call *op, std::ostream& os) final;
6464
void VisitStmt_(const Evaluate *op) final;
6565
void VisitStmt_(const Allocate *op) final;
66+
void VisitStmt_(const AttrStmt *op) final;
6667

6768
private:
6869
// Whether global barrier is needed.
@@ -79,8 +80,11 @@ class CodeGenCUDA final : public CodeGenC {
7980
bool need_math_constants_h_{false};
8081
// whether need mma.h
8182
bool need_mma_h_{false};
83+
84+
std::unordered_map<const Variable*, std::string> fragment_shapes;
85+
std::unordered_map<const Variable*, std::string> fragment_layouts;
8286
friend void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p);
83-
void PrintWmmaScope(const std::string& scope, Type t, std::ostream& os);
87+
void PrintWmmaScope(const std::string& scope, Type t, const Variable* variable, std::ostream& os);
8488
};
8589

8690
} // namespace codegen

src/pass/infer_fragment.cc

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* Copyright (c) 2019 by Contributors
22+
 * \file infer_fragment.cc
23+
*/
24+
#include <tvm/ir.h>
25+
#include <tvm/ir_pass.h>
26+
#include <tvm/ir_mutator.h>
27+
#include <tvm/ir_visitor.h>
28+
#include <unordered_map>
29+
#include <unordered_set>
30+
#include "ir_util.h"
31+
#include "storage_access.h"
32+
#include "../runtime/thread_storage_scope.h"
33+
34+
namespace tvm {
35+
namespace ir {
36+
37+
class FragmentGetter : public IRVisitor {
38+
public:
39+
struct FragmentInfo {
40+
int m, n, k;
41+
std::string layout;
42+
FragmentInfo() = default;
43+
FragmentInfo(int _m, int _n, int _k, const std::string& _layout)
44+
: m(_m), n(_n), k(_k), layout(_layout) {}
45+
};
46+
47+
void Visit_(const Call* op) final {
48+
IRVisitor::Visit_(op);
49+
50+
if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync) ||
51+
op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) {
52+
CHECK_EQ(op->args.size(), 8U);
53+
const Variable* buffer_var = op->args[0].as<Variable>();
54+
CHECK(buffer_var);
55+
const IntImm* m = op->args[1].as<IntImm>();
56+
const IntImm* n = op->args[2].as<IntImm>();
57+
const IntImm* k = op->args[3].as<IntImm>();
58+
const StringImm* layout = op->args[7].as<StringImm>();
59+
CHECK(m);
60+
CHECK(n);
61+
CHECK(k);
62+
CHECK(layout);
63+
64+
std::string scope = scopes[buffer_var];
65+
if (fragments.count(buffer_var)) {
66+
FragmentInfo info = fragments[buffer_var];
67+
CHECK_EQ(m->value, info.m);
68+
CHECK_EQ(n->value, info.n);
69+
CHECK_EQ(k->value, info.k);
70+
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
71+
CHECK_EQ(layout->value, info.layout);
72+
}
73+
} else {
74+
FragmentInfo info;
75+
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
76+
info = FragmentInfo(m->value, n->value, k->value, layout->value);
77+
} else if (scope == "wmma.accumulator") {
78+
info = FragmentInfo(m->value, n->value, k->value, "");
79+
}
80+
fragments[buffer_var] = info;
81+
}
82+
} else if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) {
83+
CHECK_EQ(op->args.size(), 6U);
84+
const Variable* buffer_var = op->args[0].as<Variable>();
85+
CHECK(buffer_var);
86+
const IntImm* m = op->args[1].as<IntImm>();
87+
const IntImm* n = op->args[2].as<IntImm>();
88+
const IntImm* k = op->args[3].as<IntImm>();
89+
CHECK(m);
90+
CHECK(n);
91+
CHECK(k);
92+
93+
std::string scope = scopes[buffer_var];
94+
CHECK_EQ(scope, "wmma.accumulator");
95+
if (fragments.count(buffer_var)) {
96+
FragmentInfo info = fragments[buffer_var];
97+
CHECK_EQ(m->value, info.m);
98+
CHECK_EQ(n->value, info.n);
99+
CHECK_EQ(k->value, info.k);
100+
} else {
101+
FragmentInfo info(m->value, n->value, k->value, "");
102+
fragments[buffer_var] = info;
103+
}
104+
}
105+
}
106+
107+
void Visit_(const AttrStmt* op) final {
108+
if (op->attr_key == attr::storage_scope) {
109+
const Variable* buffer = op->node.as<Variable>();
110+
CHECK(buffer);
111+
scopes[buffer] = op->value.as<StringImm>()->value;
112+
}
113+
IRVisitor::Visit_(op);
114+
}
115+
116+
std::unordered_map<const Variable*, std::string> scopes;
117+
std::unordered_map<const Variable*, FragmentInfo> fragments;
118+
};
119+
120+
class FragmentChecker : public IRVisitor {
121+
public:
122+
FragmentChecker(const FragmentGetter &getter) : fragment_getter(getter) {}
123+
124+
void Visit_(const Call* op) final {
125+
if (op->is_intrinsic(intrinsic::tvm_mma_sync)) {
126+
CHECK_EQ(op->args.size(), 8U);
127+
const Variable* buffer_var_d = op->args[0].as<Variable>();
128+
const Variable* buffer_var_a = op->args[2].as<Variable>();
129+
const Variable* buffer_var_b = op->args[4].as<Variable>();
130+
const Variable* buffer_var_c = op->args[6].as<Variable>();
131+
CHECK(buffer_var_d);
132+
CHECK(buffer_var_a);
133+
CHECK(buffer_var_b);
134+
CHECK(buffer_var_c);
135+
CHECK(CheckShape(buffer_var_d, buffer_var_a));
136+
CHECK(CheckShape(buffer_var_d, buffer_var_b));
137+
CHECK(CheckShape(buffer_var_d, buffer_var_c));
138+
}
139+
}
140+
private:
141+
bool CheckShape(const Variable* buffer1, const Variable* buffer2) {
142+
CHECK(fragment_getter.fragments.count(buffer1));
143+
CHECK(fragment_getter.fragments.count(buffer2));
144+
FragmentGetter::FragmentInfo info1 = fragment_getter.fragments.at(buffer1);
145+
FragmentGetter::FragmentInfo info2 = fragment_getter.fragments.at(buffer2);
146+
return info1.m == info2.m && info1.n == info2.n && info1.k == info2.k;
147+
148+
}
149+
const FragmentGetter &fragment_getter;
150+
151+
};
152+
153+
class InferFragmenter : public IRMutator {
154+
public:
155+
InferFragmenter(const FragmentGetter &getter) : fragment_getter(getter) {}
156+
157+
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
158+
Stmt stmt = IRMutator::Mutate_(op, s);
159+
const Variable* buffer = op->buffer_var.get();
160+
if (fragment_getter.fragments.count(buffer)) {
161+
FragmentGetter::FragmentInfo info = fragment_getter.fragments.at(buffer);
162+
std::string shape = std::to_string(info.n) + ", " +
163+
std::to_string(info.m) + ", " +
164+
std::to_string(info.k);
165+
Expr shape_expr = StringImm::make(shape);
166+
Stmt shape_attr = AttrStmt::make(op->buffer_var, attr::fragment_shape, shape_expr, stmt);
167+
if (info.layout != "") {
168+
Stmt layout_attr = AttrStmt::make(op->buffer_var, attr::fragment_layout,
169+
StringImm::make(info.layout), shape_attr);
170+
return layout_attr;
171+
} else {
172+
return shape_attr;
173+
}
174+
}
175+
return stmt;
176+
}
177+
private:
178+
const FragmentGetter &fragment_getter;
179+
};
180+
181+
Stmt InferFragment(Stmt stmt) {
182+
FragmentGetter getter;
183+
getter.Visit(stmt);
184+
FragmentChecker(getter).Visit(stmt);
185+
stmt = InferFragmenter(getter).Mutate(stmt);
186+
return stmt;
187+
}
188+
189+
LoweredFunc InferFragment(LoweredFunc f) {
190+
CHECK_NE(f->func_type, kHostFunc);
191+
auto n = make_node<LoweredFuncNode>(*f.operator->());
192+
n->body = InferFragment(f->body);
193+
return LoweredFunc(n);
194+
}
195+
196+
} // namespace ir
197+
} // namespace tvm

0 commit comments

Comments
 (0)