Skip to content

Commit 5ed4fc1

Browse files
committed
support 2Dx4D/5D case
1 parent c897de1 commit 5ed4fc1

File tree

2 files changed

+133
-82
lines changed

2 files changed

+133
-82
lines changed

lib/gc/Transforms/DeepTileContractionNamedOp.cpp

Lines changed: 112 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -136,13 +136,14 @@ MatmulConfig getDefaultMatmulConfig(linalg::LinalgOp &linalgOp) {
136136
cfg.KBlock = 64;
137137
cfg.MThreads = 2;
138138
cfg.NThreads = 2;
139-
cfg.KThreads = 2;
139+
cfg.KThreads = 1;
140140
return cfg;
141141
}
142142

143-
static Value tensorViewRankedTensor(RewriterBase &rewriter,
144-
RankedTensorType outTensorType,
145-
Value value) {
143+
static Value
144+
tensorViewRankedTensor(RewriterBase &rewriter, RankedTensorType outTensorType,
145+
Value value,
146+
ArrayRef<int64_t> permutation = SmallVector<int64_t>{}) {
146147
// TODO: add support for plain layout transpose
147148
Value result, currentValue = value;
148149
auto loc = currentValue.getLoc();
@@ -175,33 +176,57 @@ static Value tensorViewRankedTensor(RewriterBase &rewriter,
175176

176177
if (outShape.size() < inShape.size()) {
177178
SmallVector<ReassociationIndices> reassocIndices;
178-
ReassociationIndices firstEntry;
179-
for (auto i = 0UL; i < inShape.size() - outShape.size() + 1; i++) {
180-
firstEntry.push_back(i);
181-
}
182-
reassocIndices.push_back(firstEntry);
183-
for (auto i = inShape.size() - outShape.size() + 1UL; i < inShape.size();
184-
i++) {
185-
reassocIndices.push_back({(int)i});
179+
uint64_t outIdx = 0UL, inIdx = 0UL;
180+
while (inIdx < inShape.size() && outIdx < outShape.size()) {
181+
ReassociationIndices firstEntry;
182+
auto remaining = outShape[outIdx++];
183+
if (remaining == 1) {
184+
firstEntry.push_back(inIdx++);
185+
reassocIndices.push_back(firstEntry);
186+
continue;
187+
}
188+
while (remaining > 1) {
189+
remaining /= inShape[inIdx];
190+
firstEntry.push_back(inIdx++);
191+
}
192+
reassocIndices.push_back(firstEntry);
186193
}
187194
result = rewriter.create<tensor::CollapseShapeOp>(
188195
loc, outTensorType, currentValue, reassocIndices);
189196
} else if (outShape.size() > inShape.size()) {
190197
SmallVector<ReassociationIndices> reassocIndices;
191-
ReassociationIndices firstEntry;
192-
for (auto i = 0UL; i < outShape.size() - inShape.size() + 1; i++) {
193-
firstEntry.push_back((int)i);
194-
}
195-
reassocIndices.push_back(firstEntry);
196-
for (auto i = outShape.size() - inShape.size() + 1UL; i < outShape.size();
197-
i++) {
198-
reassocIndices.push_back({(int)i});
198+
uint64_t outIdx = 0UL, inIdx = 0UL;
199+
while (outIdx < outShape.size() && inIdx < inShape.size()) {
200+
ReassociationIndices firstEntry;
201+
auto remaining = inShape[inIdx++];
202+
if (remaining == 1) {
203+
firstEntry.push_back(outIdx++);
204+
reassocIndices.push_back(firstEntry);
205+
continue;
206+
}
207+
while (remaining > 1) {
208+
remaining /= outShape[outIdx];
209+
firstEntry.push_back(outIdx++);
210+
}
211+
reassocIndices.push_back(firstEntry);
199212
}
200213
result = rewriter.create<tensor::ExpandShapeOp>(
201214
loc, outTensorType, currentValue, reassocIndices);
202215
} else {
203216
result = rewriter.create<tensor::CastOp>(loc, outTensorType, currentValue);
204217
}
218+
219+
if (!permutation.empty()) {
220+
SmallVector<int64_t> transposeShape;
221+
for (auto idx : permutation) {
222+
transposeShape.push_back(outShape[idx]);
223+
}
224+
auto initOp = rewriter.create<tensor::EmptyOp>(loc, transposeShape,
225+
tensorElementType);
226+
auto transposeOp = rewriter.create<linalg::TransposeOp>(
227+
loc, result, initOp->getResult(0), permutation);
228+
result = transposeOp->getResult(0);
229+
}
205230
return result;
206231
}
207232

@@ -345,6 +370,7 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
345370
return b.notifyMatchFailure(
346371
linalgOp, "currentOp should not has pure buffer semantics");
347372
linalg::LinalgOp currentOp = linalgOp;
373+
348374
for (auto loopTypeIter : llvm::enumerate(loopType)) {
349375
auto [i, loopType] = loopTypeIter;
350376
auto currentDim = loopDim[i];
@@ -486,6 +512,8 @@ static LogicalResult setStaticSizeForExtractSliceOp(RewriterBase &rewriter,
486512
bool isExtract,
487513
SmallVector<int64_t> size,
488514
int shrinDimNum = 0) {
515+
OpBuilder::InsertionGuard guard(rewriter);
516+
rewriter.setInsertionPoint(op);
489517
if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(op)) {
490518
SmallVector<OpFoldResult> mixedOffsets = extractSlice.getMixedOffsets();
491519
SmallVector<OpFoldResult> mixedSizes = extractSlice.getMixedSizes();
@@ -514,6 +542,8 @@ static LogicalResult setStaticSizeForExtractSliceOp(RewriterBase &rewriter,
514542
static LogicalResult setStaticSizeForInsertSliceOp(RewriterBase &rewriter,
515543
Operation *op, Value source,
516544
SmallVector<int64_t> size) {
545+
OpBuilder::InsertionGuard guard(rewriter);
546+
rewriter.setInsertionPoint(op);
517547
if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(op)) {
518548
SmallVector<OpFoldResult> mixedOffsets = insertSlice.getMixedOffsets();
519549
SmallVector<OpFoldResult> mixedSizes = insertSlice.getMixedSizes();
@@ -575,35 +605,34 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
575605
linalgOp.getReductionDims(KDimPos);
576606
getMatmulParallelDims(linalgOp, 0, MDimPos);
577607
getMatmulParallelDims(linalgOp, 1, NDimPos);
578-
bool useBlockedLayout = KDimPos.size() > 1;
579608

580609
OuterLoopGenerationOption option;
581610
auto iteratorTypes = linalgOp.getIteratorTypesArray();
582611
auto KFirstDim = (int)getOprandDim(linalgOp, KDimPos[0], 1);
583612
auto MFirstDim = (int)getOprandDim(linalgOp, MDimPos[0], 0);
584613
auto NFirstDim = (int)getOprandDim(linalgOp, NDimPos[0], 1);
585614
auto KParallelBlockSize =
586-
useBlockedLayout
615+
KDimPos.size() > 1
587616
? divAndCeil(KFirstDim, cfg.KThreads)
588617
: divAndCeil(divAndCeil(KFirstDim, cfg.KBlock), cfg.KThreads) *
589618
cfg.KBlock;
590619
auto MParallelBlockSize =
591-
useBlockedLayout
620+
MDimPos.size() > 1
592621
? divAndCeil(MFirstDim, cfg.MThreads)
593622
: divAndCeil(divAndCeil(MFirstDim, cfg.MBlock), cfg.MThreads) *
594623
cfg.MBlock;
595624
auto NParallelBlockSize =
596-
useBlockedLayout
625+
NDimPos.size() > 1
597626
? divAndCeil(NFirstDim, cfg.NThreads)
598627
: divAndCeil(divAndCeil(NFirstDim, cfg.NBlock), cfg.NThreads) *
599628
cfg.NBlock;
600-
auto KOuterBlockSize = useBlockedLayout
629+
auto KOuterBlockSize = KDimPos.size() > 1
601630
? (cfg.KBlock - 1) / cfg.innerMostKBlock + 1
602631
: cfg.KBlock;
603-
auto MOuterBlockSize = useBlockedLayout
632+
auto MOuterBlockSize = MDimPos.size() > 1
604633
? (cfg.MBlock - 1) / cfg.innerMostMBlock + 1
605634
: cfg.MBlock;
606-
auto NOuterBlockSize = useBlockedLayout
635+
auto NOuterBlockSize = NDimPos.size() > 1
607636
? (cfg.NBlock - 1) / cfg.innerMostNBlock + 1
608637
: cfg.NBlock;
609638
// Outer
@@ -631,11 +660,23 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
631660
option.loopDim.emplace_back(SmallVector<int>{dim});
632661
}
633662
// Inner
634-
if (!useBlockedLayout) {
663+
if (KDimPos.size() == 1) {
635664
option.nestedTileSizes.emplace_back(SmallVector<int>{cfg.KBlock});
636665
option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
637666
option.loopDim.emplace_back(SmallVector<int>{(int)KDimPos.back()});
638667
}
668+
if (MDimPos.size() == 1) {
669+
option.nestedTileSizes.emplace_back(
670+
SmallVector<int>{cfg.innerMostMBlock});
671+
option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
672+
option.loopDim.emplace_back(SmallVector<int>{(int)MDimPos.back()});
673+
}
674+
if (NDimPos.size() == 1) {
675+
option.nestedTileSizes.emplace_back(
676+
SmallVector<int>{cfg.innerMostNBlock});
677+
option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
678+
option.loopDim.emplace_back(SmallVector<int>{(int)NDimPos.back()});
679+
}
639680
for (auto dim = 0UL; dim < linalgOp.getNumLoops(); dim++) {
640681
if (dim != MDimPos.back() && dim != NDimPos.back() &&
641682
iteratorTypes[dim] != mlir::utils::IteratorType::reduction) {
@@ -658,17 +699,24 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
658699
linalg::LinalgOp originOp,
659700
linalg::LinalgOp currentOp,
660701
innerBodyGenerationOption &option) const {
702+
661703
mlir::easybuild::EasyBuilder eb{rewriter, originOp.getLoc()};
662704
auto operandDimTypes = getOprandDimType(originOp);
663705
MatmulConfig cfg = getDefaultMatmulConfig(originOp);
664706
auto AShape = originOp.getShape(originOp.getDpsInputOperand(0));
665707
auto BShape = originOp.getShape(originOp.getDpsInputOperand(1));
666708
auto CShape = originOp.getShape(originOp.getDpsInitOperand(0));
667-
bool useBlockedLayout = BShape.size() > 2;
709+
710+
auto MDimNum = std::count_if((*operandDimTypes)[0].begin(),
711+
(*operandDimTypes)[0].end(),
712+
[](DimType d) { return d == DimType::M; });
713+
auto NDimNum = std::count_if((*operandDimTypes)[1].begin(),
714+
(*operandDimTypes)[1].end(),
715+
[](DimType d) { return d == DimType::N; });
668716
// TODO: support plain in/block out format
669717
SmallVector<int64_t> AInnermostDims, BInnermostDims, CInnermostDims;
670-
if (useBlockedLayout) {
671-
bool firstM = true, firstK = true, firstN = true;
718+
bool firstM = true, firstK = true, firstN = true;
719+
if (MDimNum > 1) {
672720
for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[0])) {
673721
if (iter == DimType::M && firstM) {
674722
AInnermostDims.push_back(1);
@@ -682,21 +730,6 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
682730
AInnermostDims.push_back(AShape[idx]);
683731
}
684732
}
685-
firstN = true;
686-
firstK = true;
687-
for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[1])) {
688-
if (iter == DimType::N && firstN) {
689-
BInnermostDims.push_back(1);
690-
firstN = false;
691-
} else if (iter == DimType::Batch) {
692-
BInnermostDims.push_back(1);
693-
} else if (iter == DimType::K && firstK) {
694-
BInnermostDims.push_back(cfg.KBlock / cfg.innerMostKBlock);
695-
firstK = false;
696-
} else {
697-
BInnermostDims.push_back(BShape[idx]);
698-
}
699-
}
700733
firstM = true;
701734
firstN = true;
702735
for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[2])) {
@@ -716,11 +749,29 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
716749
AInnermostDims = SmallVector<int64_t>{cfg.innerMostMBlock,
717750
cfg.KBlock / cfg.innerMostKBlock *
718751
cfg.innerMostKBlock};
752+
CInnermostDims =
753+
SmallVector<int64_t>{cfg.innerMostMBlock, cfg.innerMostNBlock};
754+
}
755+
if (NDimNum > 1) {
756+
firstN = true;
757+
firstK = true;
758+
for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[1])) {
759+
if (iter == DimType::N && firstN) {
760+
BInnermostDims.push_back(1);
761+
firstN = false;
762+
} else if (iter == DimType::Batch) {
763+
BInnermostDims.push_back(1);
764+
} else if (iter == DimType::K && firstK) {
765+
BInnermostDims.push_back(cfg.KBlock / cfg.innerMostKBlock);
766+
firstK = false;
767+
} else {
768+
BInnermostDims.push_back(BShape[idx]);
769+
}
770+
}
771+
} else {
719772
BInnermostDims = SmallVector<int64_t>{cfg.KBlock / cfg.innerMostKBlock *
720773
cfg.innerMostKBlock,
721774
cfg.innerMostNBlock};
722-
CInnermostDims =
723-
SmallVector<int64_t>{cfg.innerMostMBlock, cfg.innerMostNBlock};
724775
}
725776

726777
OpBuilder::InsertionGuard guard(rewriter);
@@ -747,35 +798,35 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
747798
AInnermostDims, useShrinkedLayout))) {
748799
return failure();
749800
}
750-
751801
// View the tensor to brgemm required format
752802
Value dataOprand = tensorViewRankedTensor(
753803
rewriter,
754804
mlir::RankedTensorType::get(
755-
useBlockedLayout
756-
? SmallVector<int64_t>(AInnermostDims.begin() + 1,
757-
AInnermostDims.end())
758-
: SmallVector<int64_t>{1, AInnermostDims[0], AInnermostDims[1]},
805+
MDimNum > 1 ? SmallVector<int64_t>(AInnermostDims.begin() + 1,
806+
AInnermostDims.end())
807+
: SmallVector<int64_t>{cfg.innerMostMBlock,
808+
cfg.KBlock / cfg.innerMostKBlock,
809+
cfg.innerMostKBlock},
759810
dataType.getElementType()),
760-
currentOp.getDpsInputs()[0]);
811+
currentOp.getDpsInputs()[0],
812+
MDimNum == 1 ? SmallVector<int64_t>{1, 0, 2} : SmallVector<int64_t>{});
761813
Value weightOprand = tensorViewRankedTensor(
762814
rewriter,
763815
mlir::RankedTensorType::get(
764-
useBlockedLayout
765-
? SmallVector<int64_t>(BInnermostDims.begin() + 1,
766-
BInnermostDims.end())
767-
: SmallVector<int64_t>{1, BInnermostDims[0], BInnermostDims[1]},
816+
NDimNum > 1 ? SmallVector<int64_t>(BInnermostDims.begin() + 1,
817+
BInnermostDims.end())
818+
: SmallVector<int64_t>{cfg.KBlock / cfg.innerMostKBlock,
819+
cfg.innerMostKBlock,
820+
cfg.innerMostNBlock},
768821
weightType.getElementType()),
769822
currentOp.getDpsInputs()[1]);
770823
Value resultOprand = tensorViewRankedTensor(
771824
rewriter,
772825
mlir::RankedTensorType::get(
773-
SmallVector<int64_t>(CInnermostDims.begin() +
774-
(useBlockedLayout ? 2 : 0),
826+
SmallVector<int64_t>(CInnermostDims.begin() + (MDimNum > 1 ? 2 : 0),
775827
CInnermostDims.end()),
776828
resultType.getElementType()),
777829
currentOp.getDpsInits()[0]);
778-
779830
// Create the brgemm op and replace the origin linalg op
780831
linalg::LinalgOp matmul;
781832
if (BInnermostDims.size() == 4 || BInnermostDims.size() == 2) {

test/gc/Transform/deepTileContractionNamedOp.mlir

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
// RUN: gc-opt --split-input-file --deep-tile-contraction-named-op %s
22

3-
// // -----
3+
// -----
44

5-
// /// CHECK-LABEL: @matmul_4Dx4D_f32
6-
// func.func @matmul_4Dx4D_f32(%arg0: tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32> {
7-
// %cst = arith.constant dense<1.000000e+00> : tensor<128x128x32x32x1xf32>
8-
// %cst_0 = arith.constant 0.000000e+00 : f32
9-
// %0 = tensor.empty() : tensor<128x128x32x32xf32>
10-
// %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32>
11-
// %2 = linalgx.mm4d_vnni ins(%arg0, %cst : tensor<128x128x32x32xf32>, tensor<128x128x32x32x1xf32>) outs(%1 : tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32>
12-
// return %2 : tensor<128x128x32x32xf32>
13-
// }
5+
/// CHECK-LABEL: @matmul_4Dx4D_f32
6+
func.func @matmul_4Dx4D_f32(%arg0: tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32> {
7+
%cst = arith.constant dense<1.000000e+00> : tensor<128x128x32x32x1xf32>
8+
%cst_0 = arith.constant 0.000000e+00 : f32
9+
%0 = tensor.empty() : tensor<128x128x32x32xf32>
10+
%1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32>
11+
%2 = linalgx.mm4d_vnni ins(%arg0, %cst : tensor<128x128x32x32xf32>, tensor<128x128x32x32x1xf32>) outs(%1 : tensor<128x128x32x32xf32>) -> tensor<128x128x32x32xf32>
12+
return %2 : tensor<128x128x32x32xf32>
13+
}
1414

1515
// -----
1616

@@ -24,7 +24,7 @@ func.func @matmul_2Dx2D_f32(%arg0: tensor<4096x4096xf32>) -> tensor<4096x4096xf3
2424
return %2 : tensor<4096x4096xf32>
2525
}
2626

27-
// // -----
27+
// -----
2828

2929
// /// CHECK-LABEL: @matmul_2Dx4D_f32
3030
// func.func @matmul_2Dx4D_f32(%arg0: tensor<4096x4096xf32>) -> tensor<4096x4096xf32> {
@@ -48,15 +48,15 @@ func.func @matmul_4Dx4D_bf16(%arg0: tensor<128x128x32x32xbf16>) -> tensor<128x12
4848
return %2 : tensor<128x128x32x32xbf16>
4949
}
5050

51-
// // -----
51+
// -----
5252

53-
// /// CHECK-LABEL: @matmul_2Dx4D_bf16
54-
// func.func @matmul_4Dx4D_bf16(%arg0: tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16> {
55-
// %cst = arith.constant dense<1.000000e+00> : tensor<128x128x16x32x2xbf16>
56-
// %cst_0 = arith.constant 0.000000e+00 : bf16
57-
// %0 = tensor.empty() : tensor<4096x4096xbf16>
58-
// %1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16>
59-
// %2 = linalgx.mm2d_vnni ins(%arg0, %cst : tensor<4096x4096xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16>
60-
// return %2 : tensor<4096x4096xbf16>
61-
// }
53+
/// CHECK-LABEL: @matmul_2Dx4D_bf16
54+
func.func @matmul_2Dx4D_bf16(%arg0: tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16> {
55+
%cst = arith.constant dense<1.000000e+00> : tensor<128x128x16x32x2xbf16>
56+
%cst_0 = arith.constant 0.000000e+00 : bf16
57+
%0 = tensor.empty() : tensor<4096x4096xbf16>
58+
%1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16>
59+
%2 = linalgx.mm2d_vnni ins(%arg0, %cst : tensor<4096x4096xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16>
60+
return %2 : tensor<4096x4096xbf16>
61+
}
6262

0 commit comments

Comments
 (0)