Skip to content

Commit 1ae550a

Browse files
author
Siyuan Feng
committed
address comment
1 parent 225e2c9 commit 1ae550a

File tree

11 files changed

+122
-29
lines changed

11 files changed

+122
-29
lines changed

include/tvm/ir.h

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1329,7 +1329,6 @@ inline bool IsPragmaKey(const std::string& attr_key) {
13291329
return attr_key.compare(0, 7, "pragma_") == 0;
13301330
}
13311331

1332-
13331332
} // namespace attr
13341333

13351334
/*! \brief namespace of TVM Intrinsic functions */
@@ -1564,11 +1563,52 @@ constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit";
15641563
*/
15651564
constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce";
15661565
/*!
1567-
* \brief tvm intrinsic for tensor core opeartors.
1566+
* \brief tvm intrinsic for tensor core load operators.
1567+
*
1568+
 *  void tvm_load_matrix_sync(Var fragment, UIntImm m, UIntImm n, UIntImm k,
1569+
* Expr index, Expr buffer_ptr, Expr stride,
1570+
* StringImm layout) {
1571+
* // m, n, k are the shape of wmma fragment.
1572+
* // Determine fragment layout(column-major or row major) by layout.
1573+
* // fragments must be in 'wmma.matrix_a' or 'wmma.matrix_b' scope.
1574+
* nvcuda::wmma::load_matrix_sync(fragment[index], buffer_ptr, stride);
1575+
* }
15681576
*/
15691577
constexpr const char* tvm_load_matrix_sync = "tvm_load_matrix_sync";
1578+
/*!
1579+
* \brief tvm intrinsic for tensor core mma_sync operators.
1580+
*
1581+
* void tvm_mma_sync(Var fragment_d, Expr index_d,
1582+
* Var fragment_a, Expr index_a,
1583+
* Var fragment_b, Expr index_b,
1584+
* Var fragment_c, Expr index_c) {
1585+
* nvcuda::wmma::mma_sync(fragment_d[index_d], fragment_a[index_a],
1586+
* fragment_b[index_b], fragment_c[index_c]);
1587+
* }
1588+
*/
15701589
constexpr const char* tvm_mma_sync = "tvm_mma_sync";
1590+
/*!
1591+
* \brief tvm intrinsic for tensor core fill_fragment operators.
1592+
*
1593+
 *  void tvm_fill_fragment(Var fragment, UIntImm m, UIntImm n, UIntImm k,
1594+
* Expr index, Expr value) {
1595+
* // m, n, k are the shape of wmma fragment
1596+
* // fragments must be in 'wmma.accumulator' scope.
1597+
* nvcuda::wmma::fill_fragment(fragment[index], value);
1598+
* }
1599+
*/
15711600
constexpr const char* tvm_fill_fragment = "tvm_fill_fragment";
1601+
/*!
1602+
* \brief tvm intrinsic for tensor core store operators.
1603+
*
1604+
 *  void tvm_store_matrix_sync(Var fragment, UIntImm m, UIntImm n, UIntImm k,
1605+
* Expr index, Expr buffer_ptr, Expr stride,
1606+
* StringImm layout) {
1607+
* // m, n, k are the shape of wmma fragment
1608+
* // fragments must be in 'wmma.accumulator' scope.
1609+
* nvcuda::wmma::store_matrix_sync(fragment[index], buffer_ptr, stride, layout);
1610+
* }
1611+
*/
15721612
constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync";
15731613

15741614
} // namespace intrinsic

include/tvm/ir_pass.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,15 @@ Stmt LiftAttrScope(Stmt stmt, std::string attr_key);
359359
*/
360360
Stmt RewriteUnsafeSelect(Stmt stmt);
361361

362+
/*!
363+
* \brief Lower attached storage access information.
364+
* Do this pass after all storage access analysis finish.
365+
*
366+
* \param stmt The stmt to be transformed
367+
* \return Transformed stmt.
368+
*/
369+
Stmt LowerStorageAccessInfo(Stmt stmt);
370+
362371
/*!
363372
* \brief Decorate the stmt with a device scope, this is helpful for
364373
* hardware accelerator without thread blocks.
@@ -505,13 +514,13 @@ LoweredFunc CombineContextCall(LoweredFunc f);
505514
LoweredFunc PointerValueTypeRewrite(LoweredFunc f);
506515

507516
/*!
508-
* \brief Lower attached storage access information.
517+
* \brief Lower attached storage access information on device.
509518
* Do this pass after all storage access analysis finish.
510519
*
511520
* \param func The device function to be lowered.
512521
* \return Transformed function.
513522
*/
514-
LoweredFunc LowerStorageAccessInfo(LoweredFunc func);
523+
LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc func);
515524

516525
/*!
517526
* \brief Lower intrinsic function calls.

python/tvm/build_module.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -494,8 +494,8 @@ def _build_for_device(flist, target, target_host):
494494
assert not fdevice
495495

496496
target_host = _target.create(target_host)
497-
fdevice = [ir_pass.LowerStorageAccessInfo(x) for x in fdevice]
498-
fhost = [ir_pass.LowerStorageAccessInfo(x) for x in fhost]
497+
fdevice = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fdevice]
498+
fhost = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fhost]
499499
fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice]
500500
fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost]
501501
fhost = [ir_pass.CombineContextCall(x) for x in fhost]

src/api/api_pass.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,14 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit")
118118
});
119119
});
120120

121+
TVM_REGISTER_API("ir_pass.LowerStorageAccess")
122+
.set_body([](TVMArgs args, TVMRetValue *ret) {
123+
LoweredFunc f = args[0];
124+
auto n = make_node<LoweredFuncNode>(*f.operator->());
125+
n->body = LowerStorageAccessInfo(f->body);
126+
*ret = LoweredFunc(n);
127+
});
128+
121129
// make from two arguments
122130
#define REGISTER_PASS(PassName) \
123131
TVM_REGISTER_API("ir_pass."#PassName) \
@@ -140,6 +148,7 @@ REGISTER_PASS(SplitHostDevice);
140148
REGISTER_PASS(StorageRewrite);
141149
REGISTER_PASS(CoProcSync);
142150
REGISTER_PASS(LowerStorageAccessInfo);
151+
REGISTER_PASS(LowerDeviceStorageAccessInfo);
143152
REGISTER_PASS(InjectVirtualThread);
144153
REGISTER_PASS(InjectPrefetch);
145154
REGISTER_PASS(InjectDoubleBuffer);

src/codegen/build_module.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -517,15 +517,15 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
517517
for (size_t i = 0; i < fhost.size(); ++i) {
518518
auto func = fhost[i];
519519
func = ir::BindDeviceType(func, target->device_type);
520-
func = ir::LowerStorageAccessInfo(func);
520+
func = ir::LowerDeviceStorageAccessInfo(func);
521521
func = ir::LowerTVMBuiltin(func);
522522
fhost.Set(i, func);
523523
}
524524

525525
for (size_t i = 0; i < fhost.size(); ++i) {
526526
auto func = fhost[i];
527527
func = ir::LowerIntrin(func, target_host->target_name);
528-
func = ir::LowerStorageAccessInfo(func);
528+
func = ir::LowerDeviceStorageAccessInfo(func);
529529
func = ir::CombineContextCall(func);
530530
fhost.Set(i, func);
531531
}

src/pass/infer_fragment.cc

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
*/
1919

2020
/*!
21-
* Copyright (c) 2019 by Contributors
21+
* Copyright (c) 2019 by Contributors
22+
* \brief Infer TensorCore metadata from tensor intrinsic.
2223
* \file tensorcore_fragment.cc
2324
*/
2425
#include <tvm/ir.h>
@@ -34,10 +35,14 @@
3435
namespace tvm {
3536
namespace ir {
3637

38+
// Get fragment information from tensor intrinsics
3739
class FragmentGetter : public IRVisitor {
3840
public:
41+
// fragment metadata
3942
struct FragmentInfo {
43+
// fragment shape
4044
int m, n, k;
45+
// fragment layout (row-major or column-major)
4146
std::string layout;
4247
FragmentInfo() = default;
4348
FragmentInfo(int _m, int _n, int _k, const std::string& _layout)
@@ -49,9 +54,11 @@ class FragmentGetter : public IRVisitor {
4954

5055
if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync) ||
5156
op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) {
57+
// Get shape and layout information from load and store intrinsic
5258
CHECK_EQ(op->args.size(), 8U);
5359
const Variable* buffer_var = op->args[0].as<Variable>();
5460
CHECK(buffer_var);
61+
// Get shape
5562
const IntImm* m = op->args[1].as<IntImm>();
5663
const IntImm* n = op->args[2].as<IntImm>();
5764
const IntImm* k = op->args[3].as<IntImm>();
@@ -63,6 +70,7 @@ class FragmentGetter : public IRVisitor {
6370

6471
std::string scope = scopes[buffer_var];
6572
if (fragments.count(buffer_var)) {
73+
// check if the fragment has met before
6674
FragmentInfo info = fragments[buffer_var];
6775
CHECK_EQ(m->value, info.m);
6876
CHECK_EQ(n->value, info.n);
@@ -71,6 +79,7 @@ class FragmentGetter : public IRVisitor {
7179
CHECK_EQ(layout->value, info.layout);
7280
}
7381
} else {
82+
// store metadata
7483
FragmentInfo info;
7584
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
7685
info = FragmentInfo(m->value, n->value, k->value, layout->value);
@@ -80,9 +89,11 @@ class FragmentGetter : public IRVisitor {
8089
fragments[buffer_var] = info;
8190
}
8291
} else if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) {
92+
// Get shape information from fill intrinsic
8393
CHECK_EQ(op->args.size(), 6U);
8494
const Variable* buffer_var = op->args[0].as<Variable>();
8595
CHECK(buffer_var);
96+
// Get shape
8697
const IntImm* m = op->args[1].as<IntImm>();
8798
const IntImm* n = op->args[2].as<IntImm>();
8899
const IntImm* k = op->args[3].as<IntImm>();
@@ -91,6 +102,7 @@ class FragmentGetter : public IRVisitor {
91102
CHECK(k);
92103

93104
std::string scope = scopes[buffer_var];
105+
// Only wmma.accumulator can use tvm_fill_fragment
94106
CHECK_EQ(scope, "wmma.accumulator");
95107
if (fragments.count(buffer_var)) {
96108
FragmentInfo info = fragments[buffer_var];
@@ -104,6 +116,7 @@ class FragmentGetter : public IRVisitor {
104116
}
105117
}
106118

119+
// Get memory scope
107120
void Visit_(const AttrStmt* op) final {
108121
if (op->attr_key == attr::storage_scope) {
109122
const Variable* buffer = op->node.as<Variable>();
@@ -113,15 +126,19 @@ class FragmentGetter : public IRVisitor {
113126
IRVisitor::Visit_(op);
114127
}
115128

129+
// Memory scope for allocations
116130
std::unordered_map<const Variable*, std::string> scopes;
131+
// Fragment metadata for all fragments
117132
std::unordered_map<const Variable*, FragmentInfo> fragments;
118133
};
119134

135+
// Check shape of fragment making sure it is a valid shape for tvm_mma_sync
120136
class FragmentChecker : public IRVisitor {
121137
public:
122138
explicit FragmentChecker(const FragmentGetter &getter) : fragment_getter(getter) {}
123139

124140
void Visit_(const Call* op) final {
141+
// Check shape when calling tvm_mma_sync
125142
if (op->is_intrinsic(intrinsic::tvm_mma_sync)) {
126143
CHECK_EQ(op->args.size(), 8U);
127144
const Variable* buffer_var_d = op->args[0].as<Variable>();
@@ -132,24 +149,28 @@ class FragmentChecker : public IRVisitor {
132149
CHECK(buffer_var_a);
133150
CHECK(buffer_var_b);
134151
CHECK(buffer_var_c);
152+
153+
// Check all fragment A, B, C and D have the same shape
135154
CHECK(CheckShape(buffer_var_d, buffer_var_a));
136155
CHECK(CheckShape(buffer_var_d, buffer_var_b));
137156
CHECK(CheckShape(buffer_var_d, buffer_var_c));
138157
}
139158
}
140159

141160
private:
161+
// A tool for checking shapes of two fragments
142162
bool CheckShape(const Variable* buffer1, const Variable* buffer2) {
143163
CHECK(fragment_getter.fragments.count(buffer1));
144164
CHECK(fragment_getter.fragments.count(buffer2));
145165
FragmentGetter::FragmentInfo info1 = fragment_getter.fragments.at(buffer1);
146166
FragmentGetter::FragmentInfo info2 = fragment_getter.fragments.at(buffer2);
147167
return info1.m == info2.m && info1.n == info2.n && info1.k == info2.k;
148168
}
149-
169+
  // Fragment information
150170
const FragmentGetter &fragment_getter;
151171
};
152172

173+
// Store the metadata into attributes
153174
class InferFragmenter : public IRMutator {
154175
public:
155176
explicit InferFragmenter(const FragmentGetter &getter) : fragment_getter(getter) {}
@@ -158,13 +179,17 @@ class InferFragmenter : public IRMutator {
158179
Stmt stmt = IRMutator::Mutate_(op, s);
159180
const Variable* buffer = op->buffer_var.get();
160181
if (fragment_getter.fragments.count(buffer)) {
182+
// Add attribute to fragments allocation
161183
FragmentGetter::FragmentInfo info = fragment_getter.fragments.at(buffer);
184+
185+
// Add shape attribute to all fragments
162186
std::string shape = std::to_string(info.n) + ", " +
163187
std::to_string(info.m) + ", " +
164188
std::to_string(info.k);
165189
Expr shape_expr = StringImm::make(shape);
166190
Stmt shape_attr = AttrStmt::make(op->buffer_var, attr::fragment_shape, shape_expr, stmt);
167191
if (info.layout != "") {
192+
      // Add layout attribute to matrix_a and matrix_b
168193
Stmt layout_attr = AttrStmt::make(op->buffer_var, attr::fragment_layout,
169194
StringImm::make(info.layout), shape_attr);
170195
return layout_attr;
@@ -176,6 +201,7 @@ class InferFragmenter : public IRMutator {
176201
}
177202

178203
private:
204+
  // Fragment information
179205
const FragmentGetter &fragment_getter;
180206
};
181207

src/pass/storage_access.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ Stmt LowerStorageAccessInfo(Stmt stmt) {
341341
return StorageAccessInfoLower().Mutate(stmt);
342342
}
343343

344-
LoweredFunc LowerStorageAccessInfo(LoweredFunc f) {
344+
LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc f) {
345345
auto n = make_node<LoweredFuncNode>(*f.operator->());
346346
n->body = LowerStorageAccessInfo(f->body);
347347
return LoweredFunc(n);

tests/python/unittest/test_schedule_tensor_core.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ def test_tensor_core_batch_matmal():
191191
s[BF].tensorize(BF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_b'))
192192
s[C].tensorize(kernel_i, intrin_wmma_store_matrix())
193193
s[CF].tensorize(_i, intrin_wmma_gemm())
194+
194195
func = tvm.build(s, [A, B, C], 'cuda')
195196

196197
ctx = tvm.gpu(0)

tests/scripts/task_lint.sh

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,19 +30,19 @@ trap cleanup 0
3030
echo "Check file types..."
3131
python3 tests/lint/check_file_type.py
3232

33-
#echo "Check ASF license header..."
34-
#java -jar /bin/apache-rat.jar -E tests/lint/rat-excludes -d . | (grep "== File" > /tmp/$$.apache-rat.txt || true)
35-
#if grep --quiet -E "File" /tmp/$$.apache-rat.txt; then
36-
# echo "Need to add ASF header to the following files."
37-
# echo "----------------File List----------------"
38-
# cat /tmp/$$.apache-rat.txt
39-
# echo "-----------------------------------------"
40-
# echo "Use the following steps to add the headers:"
41-
# echo "- Create file_list.txt in your text editor"
42-
# echo "- Copy paste the above content in file-list into file_list.txt"
43-
# echo "- python3 tests/lint/add_asf_header.py file_list.txt"
44-
# exit 1
45-
#fi
33+
echo "Check ASF license header..."
34+
java -jar /bin/apache-rat.jar -E tests/lint/rat-excludes -d . | (grep "== File" > /tmp/$$.apache-rat.txt || true)
35+
if grep --quiet -E "File" /tmp/$$.apache-rat.txt; then
36+
echo "Need to add ASF header to the following files."
37+
echo "----------------File List----------------"
38+
cat /tmp/$$.apache-rat.txt
39+
echo "-----------------------------------------"
40+
echo "Use the following steps to add the headers:"
41+
echo "- Create file_list.txt in your text editor"
42+
echo "- Copy paste the above content in file-list into file_list.txt"
43+
echo "- python3 tests/lint/add_asf_header.py file_list.txt"
44+
exit 1
45+
fi
4646

4747
echo "Check codestyle of c++ code..."
4848
make cpplint

tutorials/optimize/opt_conv_tensorcore.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,12 @@
5555
import numpy as np
5656
from tvm.contrib import nvcc
5757

58+
# Use nvcc compiler for better perf
59+
@tvm.register_func
60+
def tvm_callback_cuda_compile(code):
61+
ptx = nvcc.compile_cuda(code, target="ptx")
62+
return ptx
63+
5864
# The sizes of inputs and filters
5965
batch_size = 256
6066
height = 14
@@ -250,10 +256,10 @@ def intrin_func(ins, outs):
250256

251257

252258
# Define tiling sizes
253-
block_row_warps = 2
254-
block_col_warps = 4
255-
warp_row_tiles = 4
256-
warp_col_tiles = 2
259+
block_row_warps = 4
260+
block_col_warps = 2
261+
warp_row_tiles = 2
262+
warp_col_tiles = 4
257263
warp_size = 32
258264
chunk = 2
259265

@@ -333,7 +339,8 @@ def intrin_func(ins, outs):
333339

334340
ctx = tvm.gpu(0)
335341
if nvcc.have_tensorcore(ctx.compute_version):
336-
func = tvm.build(s, [A, W, Conv], 'cuda')
342+
with tvm.build_config(auto_unroll_max_step=16):
343+
func = tvm.build(s, [A, W, Conv], 'cuda')
337344
a_np = np.random.uniform(size=data_shape).astype(A.dtype)
338345
w_np = np.random.uniform(size=kernel_shape).astype(W.dtype)
339346
a = tvm.nd.array(a_np, ctx)

0 commit comments

Comments
 (0)