-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[CIR] Implement folder for VecTernaryOp #142946
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CIR] Implement folder for VecTernaryOp #142946
Conversation
@llvm/pr-subscribers-clang Author: Amr Hesham (AmrDeveloper) ChangesThis change adds a folder for the VecTernaryOp Issue #136487 Full diff: https://github.com/llvm/llvm-project/pull/142946.diff 4 Files Affected:
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 00878f7dd8ed7..eb439f7aa1527 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2228,7 +2228,9 @@ def VecTernaryOp : CIR_Op<"vec.ternary",
`(` $cond `,` $lhs`,` $rhs `)` `:` qualified(type($cond)) `,`
qualified(type($lhs)) attr-dict
}];
+
let hasVerifier = 1;
+ let hasFolder = 1;
}
#endif // CLANG_CIR_DIALECT_IR_CIROPS_TD
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index fa7fb592a3cd6..f585254d3340b 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -1638,6 +1638,41 @@ LogicalResult cir::VecTernaryOp::verify() {
return success();
}
+OpFoldResult cir::VecTernaryOp::fold(FoldAdaptor adaptor) {
+ mlir::Attribute cond = adaptor.getCond();
+ mlir::Attribute lhs = adaptor.getLhs();
+ mlir::Attribute rhs = adaptor.getRhs();
+
+ if (mlir::isa_and_nonnull<cir::ConstVectorAttr>(cond) &&
+ mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) &&
+ mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs)) {
+ auto condVec = mlir::cast<cir::ConstVectorAttr>(cond);
+ auto lhsVec = mlir::cast<cir::ConstVectorAttr>(lhs);
+ auto rhsVec = mlir::cast<cir::ConstVectorAttr>(rhs);
+
+ mlir::ArrayAttr condElts = condVec.getElts();
+
+ SmallVector<mlir::Attribute, 16> elements;
+ elements.reserve(condElts.size());
+
+ for (const auto &[idx, condAttr] :
+ llvm::enumerate(condElts.getAsRange<cir::IntAttr>())) {
+ if (condAttr.getSInt()) {
+ elements.push_back(lhsVec.getElts()[idx]);
+ continue;
+ }
+
+ elements.push_back(rhsVec.getElts()[idx]);
+ }
+
+ cir::VectorType vecTy = getLhs().getType();
+ return cir::ConstVectorAttr::get(
+ vecTy, mlir::ArrayAttr::get(getContext(), elements));
+ }
+
+ return {};
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
index 7d03e374c27e8..aa3e97033cdda 100644
--- a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
@@ -138,10 +138,10 @@ void CIRCanonicalizePass::runOnOperation() {
assert(!cir::MissingFeatures::complexRealOp());
assert(!cir::MissingFeatures::complexImagOp());
assert(!cir::MissingFeatures::callOp());
- // CastOp, UnaryOp, VecExtractOp and VecShuffleDynamicOp are here to perform
- // a manual `fold` in applyOpPatternsGreedily.
+ // CastOp, UnaryOp, VecExtractOp, VecShuffleDynamicOp and VecTernaryOp are
+ // here to perform a manual `fold` in applyOpPatternsGreedily.
if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp,
- VecExtractOp, VecShuffleDynamicOp>(op))
+ VecExtractOp, VecShuffleDynamicOp, VecTernaryOp>(op))
ops.push_back(op);
});
diff --git a/clang/test/CIR/Transforms/vector-ternary-fold.cir b/clang/test/CIR/Transforms/vector-ternary-fold.cir
new file mode 100644
index 0000000000000..f2e18576da74b
--- /dev/null
+++ b/clang/test/CIR/Transforms/vector-ternary-fold.cir
@@ -0,0 +1,20 @@
+// RUN: cir-opt %s -cir-canonicalize -o - | FileCheck %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @vector_ternary_fold_test() -> !cir.vector<4 x !s32i> {
+ %cond = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<0> : !s32i]> : !cir.vector<4 x !s32i>
+ %lhs = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i>
+ %rhs = cir.const #cir.const_vector<[#cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
+ %res = cir.vec.ternary(%cond, %lhs, %rhs) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
+ cir.return %res : !cir.vector<4 x !s32i>
+ }
+
+ // [1, 0, 1, 0] ? [1, 2, 3, 4] : [5, 6, 7, 8] Will be fold to [1, 6, 3, 8]
+ // CHECK: cir.func @vector_ternary_fold_test() -> !cir.vector<4 x !s32i> {
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<6> : !s32i, #cir.int<3> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
+}
+
+
|
@llvm/pr-subscribers-clangir Author: Amr Hesham (AmrDeveloper) ChangesThis change adds a folder for the VecTernaryOp Issue #136487 Full diff: https://github.com/llvm/llvm-project/pull/142946.diff 4 Files Affected:
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 00878f7dd8ed7..eb439f7aa1527 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2228,7 +2228,9 @@ def VecTernaryOp : CIR_Op<"vec.ternary",
`(` $cond `,` $lhs`,` $rhs `)` `:` qualified(type($cond)) `,`
qualified(type($lhs)) attr-dict
}];
+
let hasVerifier = 1;
+ let hasFolder = 1;
}
#endif // CLANG_CIR_DIALECT_IR_CIROPS_TD
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index fa7fb592a3cd6..f585254d3340b 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -1638,6 +1638,41 @@ LogicalResult cir::VecTernaryOp::verify() {
return success();
}
+OpFoldResult cir::VecTernaryOp::fold(FoldAdaptor adaptor) {
+ mlir::Attribute cond = adaptor.getCond();
+ mlir::Attribute lhs = adaptor.getLhs();
+ mlir::Attribute rhs = adaptor.getRhs();
+
+ if (mlir::isa_and_nonnull<cir::ConstVectorAttr>(cond) &&
+ mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) &&
+ mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs)) {
+ auto condVec = mlir::cast<cir::ConstVectorAttr>(cond);
+ auto lhsVec = mlir::cast<cir::ConstVectorAttr>(lhs);
+ auto rhsVec = mlir::cast<cir::ConstVectorAttr>(rhs);
+
+ mlir::ArrayAttr condElts = condVec.getElts();
+
+ SmallVector<mlir::Attribute, 16> elements;
+ elements.reserve(condElts.size());
+
+ for (const auto &[idx, condAttr] :
+ llvm::enumerate(condElts.getAsRange<cir::IntAttr>())) {
+ if (condAttr.getSInt()) {
+ elements.push_back(lhsVec.getElts()[idx]);
+ continue;
+ }
+
+ elements.push_back(rhsVec.getElts()[idx]);
+ }
+
+ cir::VectorType vecTy = getLhs().getType();
+ return cir::ConstVectorAttr::get(
+ vecTy, mlir::ArrayAttr::get(getContext(), elements));
+ }
+
+ return {};
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
index 7d03e374c27e8..aa3e97033cdda 100644
--- a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
@@ -138,10 +138,10 @@ void CIRCanonicalizePass::runOnOperation() {
assert(!cir::MissingFeatures::complexRealOp());
assert(!cir::MissingFeatures::complexImagOp());
assert(!cir::MissingFeatures::callOp());
- // CastOp, UnaryOp, VecExtractOp and VecShuffleDynamicOp are here to perform
- // a manual `fold` in applyOpPatternsGreedily.
+ // CastOp, UnaryOp, VecExtractOp, VecShuffleDynamicOp and VecTernaryOp are
+ // here to perform a manual `fold` in applyOpPatternsGreedily.
if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp,
- VecExtractOp, VecShuffleDynamicOp>(op))
+ VecExtractOp, VecShuffleDynamicOp, VecTernaryOp>(op))
ops.push_back(op);
});
diff --git a/clang/test/CIR/Transforms/vector-ternary-fold.cir b/clang/test/CIR/Transforms/vector-ternary-fold.cir
new file mode 100644
index 0000000000000..f2e18576da74b
--- /dev/null
+++ b/clang/test/CIR/Transforms/vector-ternary-fold.cir
@@ -0,0 +1,20 @@
+// RUN: cir-opt %s -cir-canonicalize -o - | FileCheck %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @vector_ternary_fold_test() -> !cir.vector<4 x !s32i> {
+ %cond = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<0> : !s32i]> : !cir.vector<4 x !s32i>
+ %lhs = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i>
+ %rhs = cir.const #cir.const_vector<[#cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
+ %res = cir.vec.ternary(%cond, %lhs, %rhs) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
+ cir.return %res : !cir.vector<4 x !s32i>
+ }
+
+ // [1, 0, 1, 0] ? [1, 2, 3, 4] : [5, 6, 7, 8] Will be fold to [1, 6, 3, 8]
+ // CHECK: cir.func @vector_ternary_fold_test() -> !cir.vector<4 x !s32i> {
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<6> : !s32i, #cir.int<3> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
+}
+
+
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks good. I just have a few minor suggestions.
llvm::enumerate(condElts.getAsRange<cir::IntAttr>())) { | ||
if (condAttr.getSInt()) { | ||
elements.push_back(lhsVec.getElts()[idx]); | ||
continue; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this would be more natural as an if-else
rather than using continue
.
if (mlir::isa_and_nonnull<cir::ConstVectorAttr>(cond) && | ||
mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) && | ||
mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if (mlir::isa_and_nonnull<cir::ConstVectorAttr>(cond) && | |
mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) && | |
mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs)) { | |
if (!mlir::isa_and_nonnull<cir::ConstVectorAttr>(cond) || | |
!mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) || | |
!mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs)) | |
return {}; |
// CastOp, UnaryOp, VecExtractOp and VecShuffleDynamicOp are here to perform | ||
// a manual `fold` in applyOpPatternsGreedily. | ||
// CastOp, UnaryOp, VecExtractOp, VecShuffleDynamicOp and VecTernaryOp are | ||
// here to perform a manual `fold` in applyOpPatternsGreedily. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Amongst the operations that are here to perform a manual fold are.... (This is Monty Python joke, but we really should just make this a general comment that many of the operations in this list are just here so we can fold them.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can split them into two if statements, one for the ops that are here to perform manual fold and another for the other ops
if (isa<BrOp, BrCondOp, ScopeOp, SwitchOp, SelectOp>(op))
ops.push_back(op);
// Operations to perform manual `fold` in applyOpPatternsGreedily.
if (isa<CastOp, UnaryOp, VecExtractOp, VecShuffleDynamicOp, VecTernaryOp>(
op))
ops.push_back(op);
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's not really necessary. I think we've passed the point where anyone is going to be checking each one to see why it's here. Also, some operations may be handled explicitly and also have a folder.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, you are right, i will convert it to general comment
This change adds a folder for the VecTernaryOp Issue llvm#136487
This change adds a folder for the VecTernaryOp
Issue #136487