-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][vector] Canonicalize/fold 'order preserving' transposes #135841
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -5575,6 +5575,34 @@ LogicalResult ShapeCastOp::verify() { | |||
return success(); | ||||
} | ||||
|
||||
namespace { | ||||
|
||||
/// Return true if `transpose` does not permute a pair of non-unit dims. | ||||
/// By `order preserving` we mean that the flattened versions of the input and | ||||
/// output vectors are (numerically) identical. In other words `transpose` is | ||||
/// effectively a shape cast. | ||||
bool isOrderPreserving(TransposeOp transpose) { | ||||
ArrayRef<int64_t> permutation = transpose.getPermutation(); | ||||
VectorType sourceType = transpose.getSourceVectorType(); | ||||
ArrayRef<int64_t> inShape = sourceType.getShape(); | ||||
ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims(); | ||||
auto isNonScalableUnitDim = [&](int64_t dim) { | ||||
return inShape[dim] == 1 && !inDimIsScalable[dim]; | ||||
}; | ||||
int64_t current = 0; | ||||
for (auto p : permutation) { | ||||
if (!isNonScalableUnitDim(p)) { | ||||
if (p < current) { | ||||
return false; | ||||
} | ||||
current = p; | ||||
} | ||||
} | ||||
return true; | ||||
} | ||||
|
||||
} // namespace | ||||
|
||||
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { | ||||
|
||||
VectorType resultType = getType(); | ||||
|
@@ -5583,17 +5611,32 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { | |||
if (getSource().getType() == resultType) | ||||
return getSource(); | ||||
|
||||
// Y = shape_cast(shape_cast(X))) | ||||
// -> X, if X and Y have same type | ||||
// -> shape_cast(X) otherwise. | ||||
if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) { | ||||
VectorType srcType = otherOp.getSource().getType(); | ||||
if (resultType == srcType) | ||||
return otherOp.getSource(); | ||||
setOperand(otherOp.getSource()); | ||||
// shape_cast(shape_cast(x)) -> shape_cast(x) | ||||
if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) { | ||||
setOperand(precedingShapeCast.getSource()); | ||||
return getResult(); | ||||
} | ||||
|
||||
// shape_cast(transpose(x)) -> shape_cast(x) | ||||
if (auto transpose = getSource().getDefiningOp<TransposeOp>()) { | ||||
// This folder does | ||||
// shape_cast(transpose) -> shape_cast | ||||
// But another pattern, ConvertIllegalShapeCastOpsToTransposes, does | ||||
// shape_cast -> shape_cast(transpose) | ||||
// i.e. the complete opposite. When paired, these 2 patterns can cause | ||||
// infinite cycles in pattern rewriting. | ||||
// ConvertIllegalShapeCastOpsToTransposes only matches on scalable | ||||
// vectors, so by disabling this folder for scalable vectors the | ||||
// cycle is avoided. | ||||
// TODO: Check if ConvertIllegalShapeCastOpsToTransposes is | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you know why we generate illegal shape cast ops in first place? It sounds like something that shouldn't happen... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. They're illegal according to an arm specific lowering target, which I'm not familiar with
I think @banach-space suspects that actually they're not illegal and will investigate the removal of this constraint here (and the pattern ConvertIllegalShapeCastOpsToTransposes). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "Illegal" in this context means "Illegal as an SME target - transform before invoking Tidying this up is high on my TODO list (will look into it this week). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looks like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🎉 thanks @banach-space |
||||
// still needed. If it's not, then we can fold here. | ||||
if (!transpose.getType().isScalable() && isOrderPreserving(transpose)) { | ||||
setOperand(transpose.getVector()); | ||||
return getResult(); | ||||
} | ||||
return {}; | ||||
} | ||||
|
||||
// Y = shape_cast(broadcast(X)) | ||||
// -> X, if X and Y have same type | ||||
if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) { | ||||
|
@@ -5619,7 +5662,7 @@ namespace { | |||
/// Helper function that computes a new vector type based on the input vector | ||||
/// type by removing the trailing one dims: | ||||
/// | ||||
/// vector<4x1x1xi1> --> vector<4x1> | ||||
/// vector<4x1x1xi1> --> vector<4x1xi1> | ||||
/// | ||||
static VectorType trimTrailingOneDims(VectorType oldType) { | ||||
ArrayRef<int64_t> oldShape = oldType.getShape(); | ||||
|
@@ -6086,6 +6129,32 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> { | |||
} | ||||
}; | ||||
|
||||
/// Folds transpose(shape_cast) into a new shape_cast. | ||||
class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> { | ||||
public: | ||||
using OpRewritePattern::OpRewritePattern; | ||||
|
||||
LogicalResult matchAndRewrite(TransposeOp transposeOp, | ||||
PatternRewriter &rewriter) const override { | ||||
auto shapeCastOp = | ||||
transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>(); | ||||
if (!shapeCastOp) | ||||
return failure(); | ||||
if (!isOrderPreserving(transposeOp)) | ||||
return failure(); | ||||
|
||||
VectorType resultType = transposeOp.getType(); | ||||
|
||||
// We don't need to check isValidShapeCast at this point, because it is | ||||
// guaranteed that merging the transpose into the the shape_cast is a valid | ||||
// shape_cast, because the transpose just inserts/removes ones. | ||||
|
||||
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType, | ||||
shapeCastOp.getSource()); | ||||
return success(); | ||||
} | ||||
}; | ||||
|
||||
/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is | ||||
/// 'order preserving', where 'order preserving' means the flattened | ||||
/// inputs and outputs of the transpose have identical (numerical) values. | ||||
|
@@ -6184,8 +6253,8 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> { | |||
|
||||
void vector::TransposeOp::getCanonicalizationPatterns( | ||||
RewritePatternSet &results, MLIRContext *context) { | ||||
results.add<FoldTransposeCreateMask, TransposeFolder, FoldTransposeSplat, | ||||
FoldTransposeBroadcast>(context); | ||||
results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder, | ||||
FoldTransposeSplat, FoldTransposeBroadcast>(context); | ||||
} | ||||
|
||||
//===----------------------------------------------------------------------===// | ||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably a case for
static
: https://llvm.org/docs/CodingStandards.html#restrict-visibilityThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. I intend to change this in a follow-up PR.