-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][Complex] Fix bug in MergeComplexBitcast
#74271
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Complex] Fix bug in MergeComplexBitcast
#74271
Conversation
When two `complex.bitcast` ops are folded and the resulting bitcast is a non-complex -> non-complex bitcast, an `arith.bitcast` should be generated. Otherwise, the generated `complex.bitcast` op is invalid. Also remove a pattern that convertes non-complex -> non-complex `complex.bitcast` ops to `arith.bitcast`. Such `complex.bitcast` ops are invalid and should not appear in the input. Note: This bug can only be triggered by running with `-debug` (which will should intermediate IR that does not verify) or with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` (llvm#74270).
@llvm/pr-subscribers-mlir-complex Author: Matthias Springer (matthias-springer) ChangesWhen two Also remove a pattern that convertes non-complex -> non-complex Note: This bug can only be triggered by running with Full diff: https://github.com/llvm/llvm-project/pull/74271.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 8fd914dd107ff..6d8706775758e 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -100,7 +100,8 @@ LogicalResult BitcastOp::verify() {
}
if (isa<ComplexType>(operandType) == isa<ComplexType>(resultType)) {
- return emitOpError("requires input or output is a complex type");
+ return emitOpError(
+ "requires that either input or output has a complex type");
}
if (isa<ComplexType>(resultType))
@@ -125,8 +126,15 @@ struct MergeComplexBitcast final : OpRewritePattern<BitcastOp> {
LogicalResult matchAndRewrite(BitcastOp op,
PatternRewriter &rewriter) const override {
if (auto defining = op.getOperand().getDefiningOp<BitcastOp>()) {
- rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
- defining.getOperand());
+ if (isa<ComplexType>(op.getType()) ||
+ isa<ComplexType>(defining.getOperand().getType())) {
+ // complex.bitcast requires that input or output is complex.
+ rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
+ defining.getOperand());
+ } else {
+ rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
+ defining.getOperand());
+ }
return success();
}
@@ -155,24 +163,9 @@ struct MergeArithBitcast final : OpRewritePattern<arith::BitcastOp> {
}
};
-struct ArithBitcast final : OpRewritePattern<BitcastOp> {
- using OpRewritePattern<complex::BitcastOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(BitcastOp op,
- PatternRewriter &rewriter) const override {
- if (isa<ComplexType>(op.getType()) ||
- isa<ComplexType>(op.getOperand().getType()))
- return failure();
-
- rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
- op.getOperand());
- return success();
- }
-};
-
void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ArithBitcast, MergeComplexBitcast, MergeArithBitcast>(context);
+ results.add<MergeComplexBitcast, MergeArithBitcast>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Complex/invalid.mlir b/mlir/test/Dialect/Complex/invalid.mlir
index 51b1b0fda202a..ba6995b727bc2 100644
--- a/mlir/test/Dialect/Complex/invalid.mlir
+++ b/mlir/test/Dialect/Complex/invalid.mlir
@@ -25,7 +25,7 @@ func.func @complex_constant_two_different_element_types() {
// -----
func.func @complex_bitcast_i64(%arg0 : i64) {
- // expected-error @+1 {{op requires input or output is a complex type}}
+ // expected-error @+1 {{op requires that either input or output has a complex type}}
%0 = complex.bitcast %arg0: i64 to f64
return
}
|
@llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesWhen two Also remove a pattern that convertes non-complex -> non-complex Note: This bug can only be triggered by running with Full diff: https://github.com/llvm/llvm-project/pull/74271.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 8fd914dd107ff..6d8706775758e 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -100,7 +100,8 @@ LogicalResult BitcastOp::verify() {
}
if (isa<ComplexType>(operandType) == isa<ComplexType>(resultType)) {
- return emitOpError("requires input or output is a complex type");
+ return emitOpError(
+ "requires that either input or output has a complex type");
}
if (isa<ComplexType>(resultType))
@@ -125,8 +126,15 @@ struct MergeComplexBitcast final : OpRewritePattern<BitcastOp> {
LogicalResult matchAndRewrite(BitcastOp op,
PatternRewriter &rewriter) const override {
if (auto defining = op.getOperand().getDefiningOp<BitcastOp>()) {
- rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
- defining.getOperand());
+ if (isa<ComplexType>(op.getType()) ||
+ isa<ComplexType>(defining.getOperand().getType())) {
+ // complex.bitcast requires that input or output is complex.
+ rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
+ defining.getOperand());
+ } else {
+ rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
+ defining.getOperand());
+ }
return success();
}
@@ -155,24 +163,9 @@ struct MergeArithBitcast final : OpRewritePattern<arith::BitcastOp> {
}
};
-struct ArithBitcast final : OpRewritePattern<BitcastOp> {
- using OpRewritePattern<complex::BitcastOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(BitcastOp op,
- PatternRewriter &rewriter) const override {
- if (isa<ComplexType>(op.getType()) ||
- isa<ComplexType>(op.getOperand().getType()))
- return failure();
-
- rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
- op.getOperand());
- return success();
- }
-};
-
void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ArithBitcast, MergeComplexBitcast, MergeArithBitcast>(context);
+ results.add<MergeComplexBitcast, MergeArithBitcast>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Complex/invalid.mlir b/mlir/test/Dialect/Complex/invalid.mlir
index 51b1b0fda202a..ba6995b727bc2 100644
--- a/mlir/test/Dialect/Complex/invalid.mlir
+++ b/mlir/test/Dialect/Complex/invalid.mlir
@@ -25,7 +25,7 @@ func.func @complex_constant_two_different_element_types() {
// -----
func.func @complex_bitcast_i64(%arg0 : i64) {
- // expected-error @+1 {{op requires input or output is a complex type}}
+ // expected-error @+1 {{op requires that either input or output has a complex type}}
%0 = complex.bitcast %arg0: i64 to f64
return
}
|
Should we fix "bugs" like this one? Is it actually bug? I think there is at the moment no requirement that the IR has to verify after each pattern application. I was looking into this because I had to debug a pass that applies multiple patterns and I wanted to see how an op was getting simplified. So I was running with |
When two
complex.bitcast
ops are folded and the resulting bitcast is a non-complex -> non-complex bitcast, anarith.bitcast
should be generated. Otherwise, the generatedcomplex.bitcast
op is invalid.Also remove a pattern that convertes non-complex -> non-complex
complex.bitcast
ops toarith.bitcast
. Suchcomplex.bitcast
ops are invalid and should not appear in the input.Note: This bug can only be triggered by running with
-debug
(which will should intermediate IR that does not verify) or withMLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
(#74270).