@@ -1341,8 +1341,8 @@ FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
13411341
13421342// / Rewrite linalg.winograd_filter_transform. The data layout of the filter is
13431343// / FHWC. The transformation matrix is 2-dimension. We need to extract H x W
1344- // / from FHWC first. We need to generate 2 levels of loops to iterate on F and
1345- // / C. After the rewriting, we get
1344+ // / from FHWC first. We generate 2 levels of loops to iterate on F and C. After
1345+ // / the rewriting, we get
13461346// /
13471347// / scf.for %f = lo_f to hi_f step 1
13481348// / scf.for %c = lo_c to hi_c step 1
@@ -1356,30 +1356,42 @@ decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
13561356
13571357// / Rewrite linalg.winograd_input_transform. The data layout of the input is
13581358// / NHWC. The transformation matrix is 2-dimension. We need to extract H x W
1359- // / from NHWC first. We need to generate 2 levels of loops to iterate on N and
1360- // / C. After the rewriting, we get
1361- // /
1362- // / scf.for %n = lo_n to hi_n step 1
1363- // / scf.for %c = lo_c to hi_c step 1
1364- // / %extracted = extract input<h x w> from input<n x h x w x c>
1365- // / %ret = linalg.matmul BT, %extracted
1366- // / %ret = linalg.matmul %ret, B
1367- // / %inserted = insert %ret into input<h x w x n x c>
1359+ // / from NHWC first. We generate 4 levels of loops to iterate on N, C, tileH,
1360+ // / and tileW. After the rewriting, we get
1361+ // /
1362+ // / scf.for %h = 0 to tileH step 1
1363+ // / scf.for %w = 0 to tileW step 1
1364+ // / scf.for %n = 0 to N step 1
1365+ // / scf.for %c = 0 to C step 1
1366+ // / %extracted = extract %extracted<alphaH x alphaW> from
1367+ // / %input<N x H x W x C>
1368+ // / at [%n, (%h x m), (%w x m), %c]
1369+ // / %ret = linalg.matmul BT, %extracted
1370+ // / %ret = linalg.matmul %ret, B
1371+ // / %inserted = insert %ret<alphaH x alphaW> into
1372+ // / %output<alphaH x alphaW x tileH x tileW x N x C>
1373+ // / at [0, 0, %h, %w, %n, %c]
13681374FailureOr<Operation *>
13691375decomposeWinogradInputTransformOp (RewriterBase &rewriter,
13701376 linalg::WinogradInputTransformOp op);
13711377
13721378// / Rewrite linalg.winograd_output_transform. The data layout of the output is
13731379// / HWNF. The transformation matrix is 2-dimension. We need to extract H x W
1374- // / from HWNF first. We need to generate 2 levels of loops to iterate on N and
1375- // / F. After the transformation, we get
1376- // /
1377- // / scf.for %n = lo_n to hi_n step 1
1378- // / scf.for %f = lo_f to hi_f step 1
1379- // / %extracted = extract input<h x w> from result<h x w x n x f>
1380- // / %ret = linalg.matmul AT, %extracted
1381- // / %ret = linalg.matmul %ret, A
1382- // / %inserted = insert %ret into ret<n x h x w x f>
1380+ // / from HWNF first. We generate 4 levels of loops to iterate on N, F, tileH,
1381+ // / and tileW. After the transformation, we get
1382+ // /
1383+ // / scf.for %h = 0 to tileH step 1
1384+ // / scf.for %w = 0 to tileW step 1
1385+ // / scf.for %n = 0 to N step 1
1386+ // / scf.for %f = 0 to F step 1
1387+ // / %extracted = extract %extracted<alphaH x alphaW> from
1388+ // / %input<alphaH x alphaW x tileH x tileW x N x F>
1389+ // / at [0, 0, %h, %w, %n, %f]
1390+ // / %ret = linalg.matmul AT, %extracted
1391+ // / %ret = linalg.matmul %ret, A
1392+ // / %inserted = insert %ret<alphaH x alphaW> into
1393+ // / output<N x H x W x F>
1394+ // / at [%n, (%h x m), (%w x m), %f]
13831395FailureOr<Operation *>
13841396decomposeWinogradOutputTransformOp (RewriterBase &rewriter,
13851397 linalg::WinogradOutputTransformOp op);
0 commit comments