Skip to content

Commit 7c8cfbb

Browse files
committed
Init C buffer with easy builder
1 parent 6a7a6ed commit 7c8cfbb

File tree

3 files changed

+48
-20
lines changed

3 files changed

+48
-20
lines changed

lib/gc/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ gc_set_mlir_link_components(MLIR_LINK_COMPONENTS
77
add_mlir_library(GCPasses
88
OneDNNGraphToLinalg.cpp
99
DeepTileContractionNamedOp.cpp
10+
Tiling.cpp
1011

1112
ADDITIONAL_HEADER_DIRS
1213
${PROJECT_SOURCE_DIR}/include

lib/gc/Transforms/DeepTileContractionNamedOp.cpp

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "./Tiling.hpp"
14+
#include "gc/Dialect/Arith/Utils/EasyBuild.h"
15+
#include "gc/IR/EasyBuild.h"
16+
#include "gc/IR/EasyBuildSCF.h"
1417
#include "mlir/AsmParser/AsmParser.h"
1518
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1619
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -179,6 +182,7 @@ struct OuterLoopGenerationResult {
179182
SmallVector<Operation *> tiledOps;
180183
/// The `scf.for` operations that iterate over the tiles.
181184
SmallVector<LoopLikeOpInterface> loops;
185+
SmallVector<LoopLikeOpInterface> reductionLoops;
182186
/// Values to use as replacements for the untiled op. Is the same size as the
183187
/// number of results of the untiled op.
184188
SmallVector<Value> replacements;
@@ -192,6 +196,8 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
192196
auto nestedTileSizes = option.nestedTileSizes;
193197
auto loopType = option.loopType;
194198
auto loopDim = option.loopDim;
199+
SmallVector<mlir::utils::IteratorType> iteratorTypes =
200+
linalgOp.getIteratorTypesArray();
195201

196202
if (loopType.size() != loopDim.size() ||
197203
loopDim.size() != nestedTileSizes.size()) {
@@ -228,6 +234,13 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
228234
return failure();
229235
b.replaceOp(currentOp, tilingResult->replacements);
230236
currentOp = dyn_cast<linalg::LinalgOp>(tilingResult->tiledOps.back());
237+
238+
for (auto [dim, loop] : llvm::zip(currentDim, tilingResult->loops)) {
239+
if (iteratorTypes[dim] == mlir::utils::IteratorType::reduction) {
240+
result.reductionLoops.push_back(loop);
241+
}
242+
result.loops.push_back(loop);
243+
}
231244
} else if (type == OuterLoopGenerationOption::LoopType::ForallOp) {
232245
SmallVector<OpFoldResult> tileSizes(
233246
currentOp.getNumLoops(), getAsIndexOpFoldResult(b.getContext(), 0));
@@ -395,16 +408,16 @@ getOprandDimType(linalg::LinalgOp &linalgOp) {
395408
}
396409

397410
/*
398-
forall([PM, PN]: [MThreads, NThreads) {
399-
for(PK : KThreads) {
411+
matmul(A, B) -> C
412+
---------------->
413+
forall([PM, PN, PK]: [MThreads, NThreads, KThreads]) {
400414
CSlice = [KThreads, PM * MOuterBlock: (PM + 1) * MOuterBlock,
401415
PN * NOuterBlock: (PN + 1) * NOuterBlock]
402416
ASlice = A[PM * MOuterBlock: (PM + 1) * MOuterBlock, PK * KOuterBlock: (PK
403417
+ 1) * KOuterBlock]
404418
BSlice = B[PK * KOuterBlock: (PK + 1) * KOuterBlock, PN *
405419
NOuterBlock: (PN + 1) * NOuterBlock] CSlice2 = CSlice[PK, PM * MOuterBlock: (PM
406420
+ 1) * MOuterBlock, PN * NOuterBlock: (PN + 1) * NOuterBlock]
407-
408421
MNumBlock = MOuterBlock / MBlock
409422
NNumBlock = NOuterBlock / NBlock
410423
KNumBlock = KOuterBlock / KBlock
@@ -426,9 +439,8 @@ iin_block_: (in + 1) * iin_block_] (init with 0 when ok == 0)
426439
A=ASlice3, B=BSlice3, C=CSlice4, onlyUpdate=(ok!=0));
427440
}
428441
}
429-
}
430-
C = final_reduce(CSlice)
431442
}
443+
C = final_reduce(CSlice)
432444
*/
433445
struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
434446
using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
@@ -508,12 +520,14 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
508520
struct innerBodyGenerationOption {
509521
bool hasFillOp = false;
510522
Value fillValue;
523+
SmallVector<LoopLikeOpInterface> KLoopHandles;
511524
};
512525

513526
LogicalResult
514527
innerBodyGeneration(RewriterBase &rewriter, linalg::LinalgOp originOp,
515528
linalg::LinalgOp currentOp,
516529
const innerBodyGenerationOption &option) const {
530+
mlir::easybuild::EasyBuilder eb{rewriter, originOp.getLoc()};
517531
auto operandDimTypes = getOprandDimType(originOp);
518532
MatmulConfig cfg = getDefaultMatmulConfig(originOp);
519533
auto AShape = originOp.getShape(originOp.getDpsInputOperand(0));
@@ -656,18 +670,31 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
656670
currentOp = matmul;
657671

658672
if (option.hasFillOp) {
659-
// TODO: support partial K in sinsngle threads, control flow may need
660-
// easy builder support
661673
rewriter.setInsertionPointAfter(currentOp);
662-
auto fillOp = rewriter.create<linalg::FillOp>(
663-
currentOp->getLoc(), option.fillValue, currentOp.getDpsInits()[0]);
664-
IRMapping mapping;
665-
mapping.map(currentOp.getDpsInits()[0], fillOp.getResult(0));
666-
auto res = rewriter.clone(*(currentOp.getOperation()), mapping);
667-
rewriter.replaceOp(currentOp, res);
668-
currentOp = dyn_cast<linalg::LinalgOp>(res);
674+
675+
auto cond = eb(true);
676+
for (auto loop : option.KLoopHandles) {
677+
auto induceVar = eb.wrap<mlir::easybuild::EBUnsigned>(
678+
loop.getLoopRegions().front()->front().getArgument(0));
679+
auto currentCond = induceVar == eb.toIndex(0);
680+
cond = cond & currentCond;
681+
}
682+
EB_scf_if(cond, {currentOp.getDpsInits()[0].getType()}) {
683+
auto fillOp = rewriter.create<linalg::FillOp>(
684+
currentOp->getLoc(), option.fillValue, currentOp.getDpsInits()[0]);
685+
IRMapping mapping;
686+
mapping.map(currentOp.getDpsInits()[0], fillOp.getResult(0));
687+
auto res = rewriter.clone(*(currentOp.getOperation()), mapping);
688+
eb.yield(res->getResult(0));
689+
}
690+
EB_else {
691+
auto res = rewriter.clone(*(currentOp.getOperation()));
692+
eb.yield(res->getResult(0));
693+
}
694+
auto ifOp = eb.getLastOperaion();
695+
rewriter.replaceOp(currentOp, ifOp);
696+
ifOp->getParentOfType<func::FuncOp>().dump();
669697
}
670-
currentOp.getOperation()->getParentOfType<func::FuncOp>().dump();
671698
return success();
672699
}
673700

@@ -685,7 +712,6 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
685712
// if-else block)
686713
bool hasFillOp = false;
687714
Value fillValue;
688-
SmallVector<LoopLikeOpInterface> KLoopHandle;
689715
if (auto op = dyn_cast<linalg::FillOp>(
690716
linalgOp.getDpsInits()[0].getDefiningOp())) {
691717
hasFillOp = true;
@@ -707,7 +733,8 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
707733
// Step 3 inner loop generation, convert the linalg.generic to brgemm
708734
if (failed(innerBodyGeneration(
709735
rewriter, matmulOp, linalgOp,
710-
innerBodyGenerationOption{hasFillOp, fillValue}))) {
736+
innerBodyGenerationOption{hasFillOp, fillValue,
737+
outerLoopResult->reductionLoops}))) {
711738
return failure();
712739
}
713740
return success();

lib/gc/Transforms/Tiling.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -844,7 +844,7 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op,
844844
return b.notifyMatchFailure(op, "if tile sizes are present it must have as "
845845
"many elements as number of threads");
846846

847-
if (redDims.front() >= numThreads.size())
847+
if ((unsigned)redDims.front() >= numThreads.size())
848848
return b.notifyMatchFailure(
849849
op, "reduction dimension must be mapped to threads");
850850

@@ -914,7 +914,7 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op,
914914
}
915915

916916
auto nonZeroDimIdx = 0;
917-
for (auto dim = 0; dim < numThreads.size(); dim++) {
917+
for (auto dim = 0UL; dim < numThreads.size(); dim++) {
918918
if (!isConstantIntValue(numThreads[dim], 0)) {
919919
if (llvm::find(redDims, dim) != redDims.end())
920920
outOffsets[dim] = forallOp.getInductionVars()[nonZeroDimIdx];
@@ -973,7 +973,7 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op,
973973
int64_t offIdx = 0;
974974
int64_t sizeIdx = 0;
975975
int64_t nonZeroDimIdx = 0;
976-
for (int64_t i = 0; i < numThreads.size(); ++i) {
976+
for (auto i = 0UL; i < numThreads.size(); ++i) {
977977
if (llvm::find(redDims, i) != redDims.end()) {
978978
if (hasReductionThreads) {
979979
resultOffsetsRank.push_back(

0 commit comments

Comments
 (0)