-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[CIR] Upstream TernaryOp for VectorType #142393
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-clangir @llvm/pr-subscribers-clang Author: Amr Hesham (AmrDeveloper) ChangesThis change adds support for the Ternary op for VectorType Issue #136487 Full diff: https://github.com/llvm/llvm-project/pull/142393.diff 7 Files Affected:
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 07851610a2abd..eb02d849b79f6 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2190,4 +2190,40 @@ def VecShuffleDynamicOp : CIR_Op<"vec.shuffle.dynamic",
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// VecTernaryOp
+//===----------------------------------------------------------------------===//
+
+def VecTernaryOp : CIR_Op<"vec.ternary",
+ [Pure, AllTypesMatch<["result", "vec1", "vec2"]>]> {
+ let summary = "The `cond ? a : b` ternary operator for vector types";
+ let description = [{
+ The `cir.vec.ternary` operation represents the C/C++ ternary operator,
+ `?:`, for vector types, which does a `select` on individual elements of the
+ vectors. Unlike a regular `?:` operator, there is no short circuiting. All
+ three arguments are always evaluated. Because there is no short
+ circuiting, there are no regions in this operation, unlike cir.ternary.
+
+ The first argument is a vector of integral type. The second and third
+ arguments are vectors of the same type and have the same number of elements
+ as the first argument.
+
+ The result is a vector of the same type as the second and third arguments.
+ Each element of the result is `(bool)a[n] ? b[n] : c[n]`.
+ }];
+
+ let arguments = (ins
+ IntegerVector:$cond,
+ CIR_VectorType:$vec1,
+ CIR_VectorType:$vec2
+ );
+
+ let results = (outs CIR_VectorType:$result);
+ let assemblyFormat = [{
+ `(` $cond `,` $vec1 `,` $vec2 `)` `:` qualified(type($cond)) `,`
+ qualified(type($vec1)) attr-dict
+ }];
+ let hasVerifier = 1;
+}
+
#endif // CLANG_CIR_DIALECT_IR_CIROPS_TD
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
index 8448c164a5e58..5ae727dff1095 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
@@ -193,6 +193,36 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
e->getSourceRange().getBegin());
}
+ mlir::Value
+ VisitAbstractConditionalOperator(const AbstractConditionalOperator *e) {
+ mlir::Location loc = cgf.getLoc(e->getSourceRange());
+ Expr *condExpr = e->getCond();
+ Expr *lhsExpr = e->getTrueExpr();
+ Expr *rhsExpr = e->getFalseExpr();
+
+ // OpenCL: If the condition is a vector, we can treat this condition like
+ // the select function.
+ if ((cgf.getLangOpts().OpenCL && condExpr->getType()->isVectorType()) ||
+ condExpr->getType()->isExtVectorType()) {
+ cgf.getCIRGenModule().errorNYI(loc,
+ "TernaryOp OpenCL VectorType condition");
+ return {};
+ }
+
+ if (condExpr->getType()->isVectorType() ||
+ condExpr->getType()->isSveVLSBuiltinType()) {
+ assert(condExpr->getType()->isVectorType() && "?: op for SVE vector NYI");
+ mlir::Value condValue = Visit(condExpr);
+ mlir::Value lhsValue = Visit(lhsExpr);
+ mlir::Value rhsValue = Visit(rhsExpr);
+ return builder.create<cir::VecTernaryOp>(loc, condValue, lhsValue,
+ rhsValue);
+ }
+
+ cgf.getCIRGenModule().errorNYI(loc, "TernaryOp for non vector types");
+ return {};
+ }
+
mlir::Value VisitMemberExpr(MemberExpr *e);
mlir::Value VisitInitListExpr(InitListExpr *e);
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 36f050de9f8bb..1236c455304a9 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -1589,6 +1589,24 @@ LogicalResult cir::VecShuffleDynamicOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// VecTernaryOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult cir::VecTernaryOp::verify() {
+ // Verify that the condition operand has the same number of elements as the
+ // other operands. (The automatic verification already checked that all
+ // operands are vector types and that the second and third operands are the
+ // same type.)
+ if (mlir::cast<cir::VectorType>(getCond().getType()).getSize() !=
+ getVec1().getType().getSize()) {
+ return emitOpError() << ": the number of elements in "
+ << getCond().getType() << " and "
+ << getVec1().getType() << " don't match";
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index b07e61638c3b4..e5a26260dc8cc 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1708,7 +1708,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
CIRToLLVMVecExtractOpLowering,
CIRToLLVMVecInsertOpLowering,
CIRToLLVMVecCmpOpLowering,
- CIRToLLVMVecShuffleDynamicOpLowering
+ CIRToLLVMVecShuffleDynamicOpLowering,
+ CIRToLLVMVecTernaryOpLowering
// clang-format on
>(converter, patterns.getContext());
@@ -1916,6 +1917,20 @@ mlir::LogicalResult CIRToLLVMVecShuffleDynamicOpLowering::matchAndRewrite(
return mlir::success();
}
+mlir::LogicalResult CIRToLLVMVecTernaryOpLowering::matchAndRewrite(
+ cir::VecTernaryOp op, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const {
+ // Convert `cond` into a vector of i1, then use that in a `select` op.
+ mlir::Value bitVec = rewriter.create<mlir::LLVM::ICmpOp>(
+ op.getLoc(), mlir::LLVM::ICmpPredicate::ne, adaptor.getCond(),
+ rewriter.create<mlir::LLVM::ZeroOp>(
+ op.getCond().getLoc(),
+ typeConverter->convertType(op.getCond().getType())));
+ rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(
+ op, bitVec, adaptor.getVec1(), adaptor.getVec2());
+ return mlir::success();
+}
+
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
return std::make_unique<ConvertCIRToLLVMPass>();
}
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
index 6b8862db2c8be..ed369ff15a710 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
@@ -363,6 +363,16 @@ class CIRToLLVMVecShuffleDynamicOpLowering
mlir::ConversionPatternRewriter &) const override;
};
+class CIRToLLVMVecTernaryOpLowering
+ : public mlir::OpConversionPattern<cir::VecTernaryOp> {
+public:
+ using mlir::OpConversionPattern<cir::VecTernaryOp>::OpConversionPattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(cir::VecTernaryOp op, OpAdaptor,
+ mlir::ConversionPatternRewriter &) const override;
+};
+
} // namespace direct
} // namespace cir
diff --git a/clang/test/CIR/CodeGen/vector-ext.cpp b/clang/test/CIR/CodeGen/vector-ext.cpp
index 8a0479fc1d088..53258845c2169 100644
--- a/clang/test/CIR/CodeGen/vector-ext.cpp
+++ b/clang/test/CIR/CodeGen/vector-ext.cpp
@@ -1091,3 +1091,18 @@ void foo17() {
// OGCG: %[[VEC_A:.*]] = alloca <2 x double>, align 16
// OGCG: %[[TMP:.*]] = load <2 x double>, ptr %[[VEC_A]], align 16
// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>
+
+void foo20() {
+ vi4 a;
+ vi4 b;
+ vi4 c;
+ vi4 r = c ? a : b;
+}
+
+// CIR: %[[RES:.*]] = cir.vec.ternary({{.*}}, {{.*}}, {{.*}}) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
+
+// LLVM: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer
+// LLVM: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> {{.*}}, <4 x i32> {{.*}}
+
+// OGCG: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer
+// OGCG: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> {{.*}}, <4 x i32> {{.*}}
diff --git a/clang/test/CIR/CodeGen/vector.cpp b/clang/test/CIR/CodeGen/vector.cpp
index 4c50f68a56162..49f142d110a81 100644
--- a/clang/test/CIR/CodeGen/vector.cpp
+++ b/clang/test/CIR/CodeGen/vector.cpp
@@ -1069,4 +1069,19 @@ void foo17() {
// OGCG: %[[VEC_A:.*]] = alloca <2 x double>, align 16
// OGCG: %[[TMP:.*]] = load <2 x double>, ptr %[[VEC_A]], align 16
-// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>
\ No newline at end of file
+// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>
+
+void foo20() {
+ vi4 a;
+ vi4 b;
+ vi4 c;
+ vi4 r = c ? a : b;
+}
+
+// CIR: %[[RES:.*]] = cir.vec.ternary({{.*}}, {{.*}}, {{.*}}) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
+
+// LLVM: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer
+// LLVM: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> {{.*}}, <4 x i32> {{.*}}
+
+// OGCG: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer
+// OGCG: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> {{.*}}, <4 x i32> {{.*}}
|
|
To simplify these kind of dependencies in PRs I suggest stacked PRs in the future, it would allow you to use LLVM docs has some guidance on this: https://llvm.org/docs/GitHub.html#stacked-pull-requests I am personally using graphite, which has also pretty good interactive demo tutorial. |
120e9a6
to
4aa56e8
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks good, but I would like to wait for #138156 to be merged first.
vi4 b; | ||
vi4 c; | ||
vi4 r = c ? a : b; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd like to see a test case where the operands are expressions. Something like
vi4 r = (a > b) ? (a - b) : (b - a);
Can you also add test where condition has different vector element type from operands? |
1a5ce80
to
603e1aa
Compare
I addressed all comments in this PR 👍🏻 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM (it will be nice to have a folder too in some follow up PR)
Sure, I will submit PR for the folder as a follow-up (#142393 (comment)) 😉 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
This change adds support for the Ternary op for VectorType
Issue #136487