Skip to content

Commit 4b70c32

Browse files
committed
fixup! fixup! [mlir][vector] Remove MatrixMultiplyOp and FlatTransposeOp from Vector dialect
Address comment from Adam
1 parent 4d86070 commit 4b70c32

File tree

3 files changed

+34
-29
lines changed

3 files changed

+34
-29
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -303,11 +303,27 @@ void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
303303
void populateVectorToFromElementsToShuffleTreePatterns(
304304
RewritePatternSet &patterns, PatternBenefit benefit = 1);
305305

306-
/// TODO
307-
void populateVectorContractToMatrixMultiply(RewritePatternSet &patterns);
306+
/// Populate the pattern set with the following patterns:
307+
///
308+
/// [ContractionOpToMatmulOpLowering]
309+
/// Lowers `vector.transpose` to `llvm.intr.matrix.flat_transpose`.
310+
///
311+
/// Given the high benefit, this will be prioriotised over other
312+
/// transpose-lowering patterns. As such, the convert-vector-to-llvm pass will
313+
/// only run this registration conditionally.
314+
void populateVectorContractToMatrixMultiply(RewritePatternSet &patterns,
315+
PatternBenefit benefit = 100);
308316

309-
/// TODO
310-
void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns);
317+
/// Populate the pattern set with the following patterns:
318+
///
319+
/// [TransposeOpLowering]
320+
/// Lowers `vector.contract` to `llvm.intr.matrix.multiply`.
321+
///
322+
/// Given the high benefit, this will be prioriotised over other
323+
/// contract-lowering patterns. As such, the convert-vector-to-llvm pass will
324+
/// only run this registration conditionally.
325+
void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns,
326+
PatternBenefit benefit = 100);
311327

312328
} // namespace vector
313329
} // namespace mlir

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2139,16 +2139,9 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
21392139
return res;
21402140
}
21412141

2142-
/// Progressive lowering of TransposeOp.
2143-
/// One:
2144-
/// %x = vector.transpose %y, [1, 0]
2145-
/// is replaced by:
2146-
/// %z = arith.constant dense<0.000000e+00>
2147-
/// %0 = vector.extract %y[0, 0]
2148-
/// %1 = vector.insert %0, %z [0, 0]
2149-
/// ..
2150-
/// %x = vector.insert .., .. [.., ..]
2151-
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
2142+
/// Lowers vector.transpose to llvm.intr.matrix.transpose
2143+
class TransposeOpToMatrixTransposeOpLowering
2144+
: public OpRewritePattern<vector::TransposeOp> {
21522145
public:
21532146
using OpRewritePattern<TransposeOp>::OpRewritePattern;
21542147

@@ -2191,24 +2184,15 @@ void mlir::vector::populateVectorRankReducingFMAPattern(
21912184
patterns.add<VectorFMAOpNDRewritePattern>(patterns.getContext());
21922185
}
21932186

2194-
/// Pattern to lower `vector.contract` to `llvm.intr.matrix.multiply`.
2195-
///
2196-
/// Given the high benefit, this will be prioriotised over other
2197-
/// contract-lowering patterns. As such, the convert-vector-to-llvm pass will
2198-
/// only run this registration conditionally.
21992187
void mlir::vector::populateVectorContractToMatrixMultiply(
2200-
RewritePatternSet &patterns) {
2201-
patterns.add<ContractionOpToMatmulOpLowering>(patterns.getContext(), 100);
2188+
RewritePatternSet &patterns, PatternBenefit benefit) {
2189+
patterns.add<ContractionOpToMatmulOpLowering>(patterns.getContext(), benefit);
22022190
}
22032191

2204-
/// Pattern to lower `vector.transpose` to `llvm.intr.matrix.flat_transpose`.
2205-
///
2206-
/// Given the high benefit, this will be prioriotised over other
2207-
/// transpose-lowering patterns. As such, the convert-vector-to-llvm pass will
2208-
/// only run this registration conditionally.
22092192
void mlir::vector::populateVectorTransposeToFlatTranspose(
2210-
RewritePatternSet &patterns) {
2211-
patterns.add<TransposeOpLowering>(patterns.getContext(), 100);
2193+
RewritePatternSet &patterns, PatternBenefit benefit) {
2194+
patterns.add<TransposeOpToMatrixTransposeOpLowering>(patterns.getContext(),
2195+
benefit);
22122196
}
22132197

22142198
/// Populate the given list with patterns that convert from Vector to LLVM.

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,16 +71,21 @@ void ConvertVectorToLLVMPass::runOnOperation() {
7171
populateVectorBroadcastLoweringPatterns(patterns);
7272
populateVectorContractLoweringPatterns(patterns, vectorContractLowering);
7373
if (vectorContractLowering == vector::VectorContractLowering::Matmul) {
74+
// This pattern creates a dependency on the LLVM dialect, hence we don't
75+
// include it in `populateVectorContractLoweringPatterns` that is part of
76+
// the Vector dialect (and should not depend on LLVM).
7477
populateVectorContractToMatrixMultiply(patterns);
7578
}
7679
populateVectorMaskOpLoweringPatterns(patterns);
7780
populateVectorShapeCastLoweringPatterns(patterns);
7881
populateVectorInterleaveLoweringPatterns(patterns);
7982
populateVectorTransposeLoweringPatterns(patterns, vectorTransposeLowering);
8083
if (vectorTransposeLowering == vector::VectorTransposeLowering::Flat) {
84+
// This pattern creates a dependency on the LLVM dialect, hence we don't
85+
// include it in `populateVectorTransposeLoweringPatterns` that is part of
86+
// the Vector dialect (and should not depend on LLVM).
8187
populateVectorTransposeToFlatTranspose(patterns);
8288
}
83-
populateVectorTransposeLoweringPatterns(patterns, vectorTransposeLowering);
8489
// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
8590
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
8691
populateVectorMaskMaterializationPatterns(patterns,

0 commit comments

Comments
 (0)