@@ -1318,8 +1318,8 @@ FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
13181318
13191319// / Rewrite linalg.winograd_filter_transform. The data layout of the filter is
13201320// / FHWC. The transformation matrix is 2-dimensional. We need to extract H x W
1321- // / from FHWC first. We need to generate 2 levels of loops to iterate on F and
1322- // / C. After the rewriting, we get
1321+ // / from FHWC first. We generate 2 levels of loops to iterate on F and C. After
1322+ // / the rewriting, we get
13231323// /
13241324// / scf.for %f = lo_f to hi_f step 1
13251325// / scf.for %c = lo_c to hi_c step 1
@@ -1333,30 +1333,42 @@ decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
13331333
13341334// / Rewrite linalg.winograd_input_transform. The data layout of the input is
13351335// / NHWC. The transformation matrix is 2-dimensional. We need to extract H x W
1336- // / from NHWC first. We need to generate 2 levels of loops to iterate on N and
1337- // / C. After the rewriting, we get
1338- // /
1339- // / scf.for %n = lo_n to hi_n step 1
1340- // / scf.for %c = lo_c to hi_c step 1
1341- // / %extracted = extract input<h x w> from input<n x h x w x c>
1342- // / %ret = linalg.matmul BT, %extracted
1343- // / %ret = linalg.matmul %ret, B
1344- // / %inserted = insert %ret into input<h x w x n x c>
1336+ // / from NHWC first. We generate 4 levels of loops to iterate on N, C, tileH,
1337+ // / and tileW. After the rewriting, we get
1338+ // /
1339+ // / scf.for %h = 0 to tileH step 1
1340+ // / scf.for %w = 0 to tileW step 1
1341+ // / scf.for %n = 0 to N step 1
1342+ // / scf.for %c = 0 to C step 1
1343+ // / %extracted = extract tile<alphaH x alphaW> from
1344+ // / %input<N x H x W x C>
1345+ // / at [%n, (%h x m), (%w x m), %c]
1346+ // / %ret = linalg.matmul BT, %extracted
1347+ // / %ret = linalg.matmul %ret, B
1348+ // / %inserted = insert %ret<alphaH x alphaW> into
1349+ // / %output<alphaH x alphaW x tileH x tileW x N x C>
1350+ // / at [0, 0, %h, %w, %n, %c]
13451351FailureOr<Operation *>
13461352decomposeWinogradInputTransformOp (RewriterBase &rewriter,
13471353 linalg::WinogradInputTransformOp op);
13481354
13491355// / Rewrite linalg.winograd_output_transform. The data layout of the output is
13501356// / HWNF. The transformation matrix is 2-dimensional. We need to extract H x W
1351- // / from HWNF first. We need to generate 2 levels of loops to iterate on N and
1352- // / F. After the transformation, we get
1353- // /
1354- // / scf.for %n = lo_n to hi_n step 1
1355- // / scf.for %f = lo_f to hi_f step 1
1356- // / %extracted = extract input<h x w> from result<h x w x n x f>
1357- // / %ret = linalg.matmul AT, %extracted
1358- // / %ret = linalg.matmul %ret, A
1359- // / %inserted = insert %ret into ret<n x h x w x f>
1357+ // / from HWNF first. We generate 4 levels of loops to iterate on N, F, tileH,
1358+ // / and tileW. After the transformation, we get
1359+ // /
1360+ // / scf.for %h = 0 to tileH step 1
1361+ // / scf.for %w = 0 to tileW step 1
1362+ // / scf.for %n = 0 to N step 1
1363+ // / scf.for %f = 0 to F step 1
1364+ // / %extracted = extract tile<alphaH x alphaW> from
1365+ // / %input<alphaH x alphaW x tileH x tileW x N x F>
1366+ // / at [0, 0, %h, %w, %n, %f]
1367+ // / %ret = linalg.matmul AT, %extracted
1368+ // / %ret = linalg.matmul %ret, A
1369+ // / %inserted = insert %ret<m x m> into
1370+ // / %output<N x H x W x F>
1371+ // / at [%n, (%h x m), (%w x m), %f]
13601372FailureOr<Operation *>
13611373decomposeWinogradOutputTransformOp (RewriterBase &rewriter,
13621374 linalg::WinogradOutputTransformOp op);
0 commit comments