11
11
// ===----------------------------------------------------------------------===//
12
12
13
13
#include " ./Tiling.hpp"
14
+ #include " gc/Dialect/Arith/Utils/EasyBuild.h"
15
+ #include " gc/IR/EasyBuild.h"
16
+ #include " gc/IR/EasyBuildSCF.h"
14
17
#include " mlir/AsmParser/AsmParser.h"
15
18
#include " mlir/Dialect/Affine/IR/AffineOps.h"
16
19
#include " mlir/Dialect/Func/IR/FuncOps.h"
@@ -179,6 +182,7 @@ struct OuterLoopGenerationResult {
179
182
SmallVector<Operation *> tiledOps;
180
183
/// The `scf.for` operations that iterate over the tiles.
181
184
SmallVector<LoopLikeOpInterface> loops;
185
+ SmallVector<LoopLikeOpInterface> reductionLoops;
182
186
/// Values to use as replacements for the untiled op. Is the same size as the
183
187
/// number of results of the untiled op.
184
188
SmallVector<Value> replacements;
@@ -192,6 +196,8 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
192
196
auto nestedTileSizes = option.nestedTileSizes ;
193
197
auto loopType = option.loopType ;
194
198
auto loopDim = option.loopDim ;
199
+ SmallVector<mlir::utils::IteratorType> iteratorTypes =
200
+ linalgOp.getIteratorTypesArray ();
195
201
196
202
if (loopType.size () != loopDim.size () ||
197
203
loopDim.size () != nestedTileSizes.size ()) {
@@ -228,6 +234,13 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
228
234
return failure ();
229
235
b.replaceOp (currentOp, tilingResult->replacements );
230
236
currentOp = dyn_cast<linalg::LinalgOp>(tilingResult->tiledOps .back ());
237
+
238
+ for (auto [dim, loop] : llvm::zip (currentDim, tilingResult->loops )) {
239
+ if (iteratorTypes[dim] == mlir::utils::IteratorType::reduction) {
240
+ result.reductionLoops .push_back (loop);
241
+ }
242
+ result.loops .push_back (loop);
243
+ }
231
244
} else if (type == OuterLoopGenerationOption::LoopType::ForallOp) {
232
245
SmallVector<OpFoldResult> tileSizes (
233
246
currentOp.getNumLoops (), getAsIndexOpFoldResult (b.getContext (), 0 ));
@@ -395,16 +408,16 @@ getOprandDimType(linalg::LinalgOp &linalgOp) {
395
408
}
396
409
397
410
/*
398
- forall([PM, PN]: [MThreads, NThreads) {
399
- for(PK : KThreads) {
411
+ matmul(A, B) -> C
412
+ ---------------->
413
+ forall([PM, PN, PK]: [MThreads, NThreads, KThreads]) {
400
414
CSlice = [KThreads, PM * MOuterBlock: (PM + 1) * MOuterBlock,
401
415
PN * NOuterBlock: (PN + 1) * NOuterBlock]
402
416
ASlice = A[PM * MOuterBlock: (PM + 1) * MOuterBlock, PK * KOuterBlock: (PK
403
417
+ 1) * KOuterBlock]
404
418
BSlice = B[PK * KOuterBlock: (PK + 1) * KOuterBlock, PN *
405
419
NOuterBlock: (PN + 1) * NOuterBlock] CSlice2 = CSlice[PK, PM * MOuterBlock: (PM
406
420
+ 1) * MOuterBlock, PN * NOuterBlock: (PN + 1) * NOuterBlock]
407
-
408
421
MNumBlock = MOuterBlock / MBlock
409
422
NNumBlock = NOuterBlock / NBlock
410
423
KNumBlock = KOuterBlock / KBlock
@@ -426,9 +439,8 @@ iin_block_: (in + 1) * iin_block_] (init with 0 when ok == 0)
426
439
A=ASlice3, B=BSlice3, C=CSlice4, onlyUpdate=(ok!=0));
427
440
}
428
441
}
429
- }
430
- C = final_reduce(CSlice)
431
442
}
443
+ C = final_reduce(CSlice)
432
444
*/
433
445
struct deepTileMatmul : public OpInterfaceRewritePattern <linalg::LinalgOp> {
434
446
using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
@@ -508,12 +520,14 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
508
520
// Options consumed by innerBodyGeneration.
//  - hasFillOp: when true, innerBodyGeneration wraps the tiled op in an
//    scf.if whose "then" branch first creates a linalg::FillOp on the op's
//    first DPS init and clones the op onto the filled result; the "else"
//    branch clones the op unchanged.
//  - fillValue: the value passed to the created linalg::FillOp.
//  - KLoopHandles: loops whose first region's first block argument is compared
//    against index 0; the fill branch is taken only when every comparison
//    holds. The caller passes the reduction loops recorded during outer-loop
//    generation.
struct innerBodyGenerationOption {
509
521
bool hasFillOp = false ;
510
522
Value fillValue;
523
+ SmallVector<LoopLikeOpInterface> KLoopHandles;
511
524
};
512
525
513
526
LogicalResult
514
527
innerBodyGeneration (RewriterBase &rewriter, linalg::LinalgOp originOp,
515
528
linalg::LinalgOp currentOp,
516
529
const innerBodyGenerationOption &option) const {
530
+ mlir::easybuild::EasyBuilder eb{rewriter, originOp.getLoc ()};
517
531
auto operandDimTypes = getOprandDimType (originOp);
518
532
MatmulConfig cfg = getDefaultMatmulConfig (originOp);
519
533
auto AShape = originOp.getShape (originOp.getDpsInputOperand (0 ));
@@ -656,18 +670,31 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
656
670
currentOp = matmul;
657
671
658
672
if (option.hasFillOp ) {
659
- // TODO: support partial K in sinsngle threads, control flow may need
660
- // easy builder support
661
673
rewriter.setInsertionPointAfter (currentOp);
662
- auto fillOp = rewriter.create <linalg::FillOp>(
663
- currentOp->getLoc (), option.fillValue , currentOp.getDpsInits ()[0 ]);
664
- IRMapping mapping;
665
- mapping.map (currentOp.getDpsInits ()[0 ], fillOp.getResult (0 ));
666
- auto res = rewriter.clone (*(currentOp.getOperation ()), mapping);
667
- rewriter.replaceOp (currentOp, res);
668
- currentOp = dyn_cast<linalg::LinalgOp>(res);
674
+
675
+ auto cond = eb (true );
676
+ for (auto loop : option.KLoopHandles ) {
677
+ auto induceVar = eb.wrap <mlir::easybuild::EBUnsigned>(
678
+ loop.getLoopRegions ().front ()->front ().getArgument (0 ));
679
+ auto currentCond = induceVar == eb.toIndex (0 );
680
+ cond = cond & currentCond;
681
+ }
682
+ EB_scf_if (cond, {currentOp.getDpsInits ()[0 ].getType ()}) {
683
+ auto fillOp = rewriter.create <linalg::FillOp>(
684
+ currentOp->getLoc (), option.fillValue , currentOp.getDpsInits ()[0 ]);
685
+ IRMapping mapping;
686
+ mapping.map (currentOp.getDpsInits ()[0 ], fillOp.getResult (0 ));
687
+ auto res = rewriter.clone (*(currentOp.getOperation ()), mapping);
688
+ eb.yield (res->getResult (0 ));
689
+ }
690
+ EB_else {
691
+ auto res = rewriter.clone (*(currentOp.getOperation ()));
692
+ eb.yield (res->getResult (0 ));
693
+ }
694
+ auto ifOp = eb.getLastOperaion ();
695
+ rewriter.replaceOp (currentOp, ifOp);
696
+ ifOp->getParentOfType <func::FuncOp>().dump ();
669
697
}
670
- currentOp.getOperation ()->getParentOfType <func::FuncOp>().dump ();
671
698
return success ();
672
699
}
673
700
@@ -685,7 +712,6 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
685
712
// if-else block)
686
713
bool hasFillOp = false ;
687
714
Value fillValue;
688
- SmallVector<LoopLikeOpInterface> KLoopHandle;
689
715
if (auto op = dyn_cast<linalg::FillOp>(
690
716
linalgOp.getDpsInits ()[0 ].getDefiningOp ())) {
691
717
hasFillOp = true ;
@@ -707,7 +733,8 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
707
733
// Step 3 inner loop generation, convert the linalg.generic to brgemm
708
734
if (failed (innerBodyGeneration (
709
735
rewriter, matmulOp, linalgOp,
710
- innerBodyGenerationOption{hasFillOp, fillValue}))) {
736
+ innerBodyGenerationOption{hasFillOp, fillValue,
737
+ outerLoopResult->reductionLoops }))) {
711
738
return failure ();
712
739
}
713
740
return success ();
0 commit comments