Skip to content

Commit 508eea4

Browse files
committed
Update comments
1 parent 73bbca1 commit 508eea4

File tree

1 file changed

+32
-20
lines changed

1 file changed

+32
-20
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1341,8 +1341,8 @@ FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
13411341

13421342
/// Rewrite linalg.winograd_filter_transform. The data layout of the filter is
13431343
/// FHWC. The transformation matrix is 2-dimension. We need to extract H x W
1344-
/// from FHWC first. We need to generate 2 levels of loops to iterate on F and
1345-
/// C. After the rewriting, we get
1344+
/// from FHWC first. We generate 2 levels of loops to iterate on F and C. After
1345+
/// the rewriting, we get
13461346
///
13471347
/// scf.for %f = lo_f to hi_f step 1
13481348
/// scf.for %c = lo_c to hi_c step 1
@@ -1356,30 +1356,42 @@ decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
13561356

13571357
/// Rewrite linalg.winograd_input_transform. The data layout of the input is
13581358
/// NHWC. The transformation matrix is 2-dimension. We need to extract H x W
1359-
/// from NHWC first. We need to generate 2 levels of loops to iterate on N and
1360-
/// C. After the rewriting, we get
1361-
///
1362-
/// scf.for %n = lo_n to hi_n step 1
1363-
/// scf.for %c = lo_c to hi_c step 1
1364-
/// %extracted = extract input<h x w> from input<n x h x w x c>
1365-
/// %ret = linalg.matmul BT, %extracted
1366-
/// %ret = linalg.matmul %ret, B
1367-
/// %inserted = insert %ret into input<h x w x n x c>
1359+
/// from NHWC first. We generate 4 levels of loops to iterate on N, C, tileH,
1360+
/// and tileW. After the rewriting, we get
1361+
///
1362+
/// scf.for %h = 0 to tileH step 1
1363+
/// scf.for %w = 0 to tileW step 1
1364+
/// scf.for %n = 0 to N step 1
1365+
/// scf.for %c = 0 to C step 1
1366+
/// %extracted = extract %extracted<alphaH x alphaW> from
1367+
/// %input<N x H x W x C>
1368+
/// at [%n, (%h x m), (%w x m), %c]
1369+
/// %ret = linalg.matmul BT, %extracted
1370+
/// %ret = linalg.matmul %ret, B
1371+
/// %inserted = insert %ret<alphaH x alphaW> into
1372+
/// %output<alphaH x alphaW x tileH x tileW x N x C>
1373+
/// at [0, 0, %h, %w, %n, %c]
13681374
FailureOr<Operation *>
13691375
decomposeWinogradInputTransformOp(RewriterBase &rewriter,
13701376
linalg::WinogradInputTransformOp op);
13711377

13721378
/// Rewrite linalg.winograd_output_transform. The data layout of the output is
13731379
/// HWNF. The transformation matrix is 2-dimension. We need to extract H x W
1374-
/// from HWNF first. We need to generate 2 levels of loops to iterate on N and
1375-
/// F. After the transformation, we get
1376-
///
1377-
/// scf.for %n = lo_n to hi_n step 1
1378-
/// scf.for %f = lo_f to hi_f step 1
1379-
/// %extracted = extract input<h x w> from result<h x w x n x f>
1380-
/// %ret = linalg.matmul AT, %extracted
1381-
/// %ret = linalg.matmul %ret, A
1382-
/// %inserted = insert %ret into ret<n x h x w x f>
1380+
/// from HWNF first. We generate 4 levels of loops to iterate on N, F, tileH,
1381+
/// and tileW. After the transformation, we get
1382+
///
1383+
/// scf.for %h = 0 to tileH step 1
1384+
/// scf.for %w = 0 to tileW step 1
1385+
/// scf.for %n = 0 to N step 1
1386+
/// scf.for %f = 0 to F step 1
1387+
/// %extracted = extract %extracted<alphaH x alphaW> from
1388+
/// %input<alphaH x alphaW x tileH x tileW x N x F>
1389+
/// at [0, 0, %h, %w, %n, %f]
1390+
/// %ret = linalg.matmul AT, %extracted
1391+
/// %ret = linalg.matmul %ret, A
1392+
/// %inserted = insert %ret<alphaH x alphaW> into
1393+
/// output<N x H x W x F>
1394+
/// at [%n, (%h x m), (%w x m), %f]
13831395
FailureOr<Operation *>
13841396
decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
13851397
linalg::WinogradOutputTransformOp op);

0 commit comments

Comments
 (0)