1818 */
1919
2020/* !
21- * Copyright (c) 2019 by Contributors
21+ * Copyright (c) 2019 by Contributors
22+ * \brief Infer TensorCore metadata from tensor intrinsic.
2223 * \file tensorcore_fragment.cc
2324 */
2425#include < tvm/ir.h>
3435namespace tvm {
3536namespace ir {
3637
38+ // Get fragment information from tensor intrinsics
3739class FragmentGetter : public IRVisitor {
3840 public:
41+ // fragment metadata
3942 struct FragmentInfo {
43+ // fragment shape
4044 int m, n, k;
45+ // fragment layout (row-major or column-major)
4146 std::string layout;
4247 FragmentInfo () = default ;
4348 FragmentInfo (int _m, int _n, int _k, const std::string& _layout)
@@ -49,9 +54,11 @@ class FragmentGetter : public IRVisitor {
4954
5055 if (op->is_intrinsic (intrinsic::tvm_load_matrix_sync) ||
5156 op->is_intrinsic (intrinsic::tvm_store_matrix_sync)) {
57+ // Get shape and layout information from load and store intrinsic
5258 CHECK_EQ (op->args .size (), 8U );
5359 const Variable* buffer_var = op->args [0 ].as <Variable>();
5460 CHECK (buffer_var);
61+ // Get shape
5562 const IntImm* m = op->args [1 ].as <IntImm>();
5663 const IntImm* n = op->args [2 ].as <IntImm>();
5764 const IntImm* k = op->args [3 ].as <IntImm>();
@@ -63,6 +70,7 @@ class FragmentGetter : public IRVisitor {
6370
6471 std::string scope = scopes[buffer_var];
6572 if (fragments.count (buffer_var)) {
73+ // check if the fragment has met before
6674 FragmentInfo info = fragments[buffer_var];
6775 CHECK_EQ (m->value , info.m );
6876 CHECK_EQ (n->value , info.n );
@@ -71,6 +79,7 @@ class FragmentGetter : public IRVisitor {
7179 CHECK_EQ (layout->value , info.layout );
7280 }
7381 } else {
82+ // store metadata
7483 FragmentInfo info;
7584 if (scope == " wmma.matrix_a" || scope == " wmma.matrix_b" ) {
7685 info = FragmentInfo (m->value , n->value , k->value , layout->value );
@@ -80,9 +89,11 @@ class FragmentGetter : public IRVisitor {
8089 fragments[buffer_var] = info;
8190 }
8291 } else if (op->is_intrinsic (intrinsic::tvm_fill_fragment)) {
92+ // Get shape information from fill intrinsic
8393 CHECK_EQ (op->args .size (), 6U );
8494 const Variable* buffer_var = op->args [0 ].as <Variable>();
8595 CHECK (buffer_var);
96+ // Get shape
8697 const IntImm* m = op->args [1 ].as <IntImm>();
8798 const IntImm* n = op->args [2 ].as <IntImm>();
8899 const IntImm* k = op->args [3 ].as <IntImm>();
@@ -91,6 +102,7 @@ class FragmentGetter : public IRVisitor {
91102 CHECK (k);
92103
93104 std::string scope = scopes[buffer_var];
105+ // Only wmma.accumulator can use tvm_fill_fragment
94106 CHECK_EQ (scope, " wmma.accumulator" );
95107 if (fragments.count (buffer_var)) {
96108 FragmentInfo info = fragments[buffer_var];
@@ -104,6 +116,7 @@ class FragmentGetter : public IRVisitor {
104116 }
105117 }
106118
119+ // Get memory scope
107120 void Visit_ (const AttrStmt* op) final {
108121 if (op->attr_key == attr::storage_scope) {
109122 const Variable* buffer = op->node .as <Variable>();
@@ -113,15 +126,19 @@ class FragmentGetter : public IRVisitor {
113126 IRVisitor::Visit_ (op);
114127 }
115128
129+ // Memory scope for allocations
116130 std::unordered_map<const Variable*, std::string> scopes;
131+ // Fragment metadata for all fragments
117132 std::unordered_map<const Variable*, FragmentInfo> fragments;
118133};
119134
135+ // Check shape of fragment making sure it is a valid shape for tvm_mma_sync
120136class FragmentChecker : public IRVisitor {
121137 public:
122138 explicit FragmentChecker (const FragmentGetter &getter) : fragment_getter(getter) {}
123139
124140 void Visit_ (const Call* op) final {
141+ // Check shape when calling tvm_mma_sync
125142 if (op->is_intrinsic (intrinsic::tvm_mma_sync)) {
126143 CHECK_EQ (op->args .size (), 8U );
127144 const Variable* buffer_var_d = op->args [0 ].as <Variable>();
@@ -132,24 +149,28 @@ class FragmentChecker : public IRVisitor {
132149 CHECK (buffer_var_a);
133150 CHECK (buffer_var_b);
134151 CHECK (buffer_var_c);
152+
153+ // Check all fragment A, B, C and D have the same shape
135154 CHECK (CheckShape (buffer_var_d, buffer_var_a));
136155 CHECK (CheckShape (buffer_var_d, buffer_var_b));
137156 CHECK (CheckShape (buffer_var_d, buffer_var_c));
138157 }
139158 }
140159
141160 private:
161+ // A tool for checking shapes of two fragments
142162 bool CheckShape (const Variable* buffer1, const Variable* buffer2) {
143163 CHECK (fragment_getter.fragments .count (buffer1));
144164 CHECK (fragment_getter.fragments .count (buffer2));
145165 FragmentGetter::FragmentInfo info1 = fragment_getter.fragments .at (buffer1);
146166 FragmentGetter::FragmentInfo info2 = fragment_getter.fragments .at (buffer2);
147167 return info1.m == info2.m && info1.n == info2.n && info1.k == info2.k ;
148168 }
149-
169+ // Fragment infomation
150170 const FragmentGetter &fragment_getter;
151171};
152172
173+ // Store the metadata into attributes
153174class InferFragmenter : public IRMutator {
154175 public:
155176 explicit InferFragmenter (const FragmentGetter &getter) : fragment_getter(getter) {}
@@ -158,13 +179,17 @@ class InferFragmenter : public IRMutator {
158179 Stmt stmt = IRMutator::Mutate_ (op, s);
159180 const Variable* buffer = op->buffer_var .get ();
160181 if (fragment_getter.fragments .count (buffer)) {
182+ // Add attribute to fragments allocation
161183 FragmentGetter::FragmentInfo info = fragment_getter.fragments .at (buffer);
184+
185+ // Add shape attribute to all fragments
162186 std::string shape = std::to_string (info.n ) + " , " +
163187 std::to_string (info.m ) + " , " +
164188 std::to_string (info.k );
165189 Expr shape_expr = StringImm::make (shape);
166190 Stmt shape_attr = AttrStmt::make (op->buffer_var , attr::fragment_shape, shape_expr, stmt);
167191 if (info.layout != " " ) {
192+ // Add shape attribute to matrix_a and matrix_b
168193 Stmt layout_attr = AttrStmt::make (op->buffer_var , attr::fragment_layout,
169194 StringImm::make (info.layout ), shape_attr);
170195 return layout_attr;
@@ -176,6 +201,7 @@ class InferFragmenter : public IRMutator {
176201 }
177202
178203 private:
204+ // Fragment infomation
179205 const FragmentGetter &fragment_getter;
180206};
181207
0 commit comments