Skip to content

Commit e972232

Browse files
committed
Update comments
1 parent b4e5c7c commit e972232

File tree

1 file changed

+32
-20
lines changed

1 file changed

+32
-20
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1318,8 +1318,8 @@ FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
13181318

13191319
/// Rewrite linalg.winograd_filter_transform. The data layout of the filter is
13201320
/// FHWC. The transformation matrix is 2-dimension. We need to extract H x W
1321-
/// from FHWC first. We need to generate 2 levels of loops to iterate on F and
1322-
/// C. After the rewriting, we get
1321+
/// from FHWC first. We generate 2 levels of loops to iterate on F and C. After
1322+
/// the rewriting, we get
13231323
///
13241324
/// scf.for %f = lo_f to hi_f step 1
13251325
/// scf.for %c = lo_c to hi_c step 1
@@ -1333,30 +1333,42 @@ decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
13331333

13341334
/// Rewrite linalg.winograd_input_transform. The data layout of the input is
13351335
/// NHWC. The transformation matrix is 2-dimension. We need to extract H x W
1336-
/// from NHWC first. We need to generate 2 levels of loops to iterate on N and
1337-
/// C. After the rewriting, we get
1338-
///
1339-
/// scf.for %n = lo_n to hi_n step 1
1340-
/// scf.for %c = lo_c to hi_c step 1
1341-
/// %extracted = extract input<h x w> from input<n x h x w x c>
1342-
/// %ret = linalg.matmul BT, %extracted
1343-
/// %ret = linalg.matmul %ret, B
1344-
/// %inserted = insert %ret into input<h x w x n x c>
1336+
/// from NHWC first. We generate 4 levels of loops to iterate on N, C, tileH,
1337+
/// and tileW. After the rewriting, we get
1338+
///
1339+
/// scf.for %h = 0 to tileH step 1
1340+
/// scf.for %w = 0 to tileW step 1
1341+
/// scf.for %n = 0 to N step 1
1342+
/// scf.for %c = 0 to C step 1
1343+
/// %extracted = extract %extracted<alphaH x alphaW> from
1344+
/// %input<N x H x W x C>
1345+
/// at [%n, (%h x m), (%w x m), %c]
1346+
/// %ret = linalg.matmul BT, %extracted
1347+
/// %ret = linalg.matmul %ret, B
1348+
/// %inserted = insert %ret<alphaH x alphaW> into
1349+
/// %output<alphaH x alphaW x tileH x tileW x N x C>
1350+
/// at [0, 0, %h, %w, %n, %c]
13451351
FailureOr<Operation *>
13461352
decomposeWinogradInputTransformOp(RewriterBase &rewriter,
13471353
linalg::WinogradInputTransformOp op);
13481354

13491355
/// Rewrite linalg.winograd_output_transform. The data layout of the output is
13501356
/// HWNF. The transformation matrix is 2-dimension. We need to extract H x W
1351-
/// from HWNF first. We need to generate 2 levels of loops to iterate on N and
1352-
/// F. After the transformation, we get
1353-
///
1354-
/// scf.for %n = lo_n to hi_n step 1
1355-
/// scf.for %f = lo_f to hi_f step 1
1356-
/// %extracted = extract input<h x w> from result<h x w x n x f>
1357-
/// %ret = linalg.matmul AT, %extracted
1358-
/// %ret = linalg.matmul %ret, A
1359-
/// %inserted = insert %ret into ret<n x h x w x f>
1357+
/// from HWNF first. We generate 4 levels of loops to iterate on N, F, tileH,
1358+
/// and tileW. After the transformation, we get
1359+
///
1360+
/// scf.for %h = 0 to tileH step 1
1361+
/// scf.for %w = 0 to tileW step 1
1362+
/// scf.for %n = 0 to N step 1
1363+
/// scf.for %f = 0 to F step 1
1364+
/// %extracted = extract %extracted<alphaH x alphaW> from
1365+
/// %input<alphaH x alphaW x tileH x tileW x N x F>
1366+
/// at [0, 0, %h, %w, %n, %f]
1367+
/// %ret = linalg.matmul AT, %extracted
1368+
/// %ret = linalg.matmul %ret, A
1369+
/// %inserted = insert %ret<alphaH x alphaW> into
1370+
/// output<N x H x W x F>
1371+
/// at [%n, (%h x m), (%w x m), %f]
13601372
FailureOr<Operation *>
13611373
decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
13621374
linalg::WinogradOutputTransformOp op);

0 commit comments

Comments
 (0)