Skip to content

[CIR] Upstream TernaryOp for VectorType #142393

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 5, 2025

Conversation

AmrDeveloper
Copy link
Member

This change adds support for the Ternary op for VectorType

Issue #136487

@llvmbot llvmbot added clang Clang issues not falling into any other category ClangIR Anything related to the ClangIR project labels Jun 2, 2025
@llvmbot
Copy link
Member

llvmbot commented Jun 2, 2025

@llvm/pr-subscribers-clangir

@llvm/pr-subscribers-clang

Author: Amr Hesham (AmrDeveloper)

Changes

This change adds support for the Ternary op for VectorType

Issue #136487


Full diff: https://github.com/llvm/llvm-project/pull/142393.diff

7 Files Affected:

  • (modified) clang/include/clang/CIR/Dialect/IR/CIROps.td (+36)
  • (modified) clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp (+30)
  • (modified) clang/lib/CIR/Dialect/IR/CIRDialect.cpp (+18)
  • (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp (+16-1)
  • (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h (+10)
  • (modified) clang/test/CIR/CodeGen/vector-ext.cpp (+15)
  • (modified) clang/test/CIR/CodeGen/vector.cpp (+16-1)
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 07851610a2abd..eb02d849b79f6 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2190,4 +2190,40 @@ def VecShuffleDynamicOp : CIR_Op<"vec.shuffle.dynamic",
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// VecTernaryOp
+//===----------------------------------------------------------------------===//
+
+def VecTernaryOp : CIR_Op<"vec.ternary",
+                   [Pure, AllTypesMatch<["result", "vec1", "vec2"]>]> {
+  let summary = "The `cond ? a : b` ternary operator for vector types";
+  let description = [{
+    The `cir.vec.ternary` operation represents the C/C++ ternary operator,
+    `?:`, for vector types, which does a `select` on individual elements of the
+    vectors. Unlike a regular `?:` operator, there is no short circuiting. All
+    three arguments are always evaluated.  Because there is no short
+    circuiting, there are no regions in this operation, unlike cir.ternary.
+
+    The first argument is a vector of integral type. The second and third
+    arguments are vectors of the same type and have the same number of elements
+    as the first argument.
+
+    The result is a vector of the same type as the second and third arguments.
+    Each element of the result is `(bool)a[n] ? b[n] : c[n]`.
+  }];
+
+  let arguments = (ins
+    IntegerVector:$cond,
+    CIR_VectorType:$vec1,
+    CIR_VectorType:$vec2
+  );
+
+  let results = (outs CIR_VectorType:$result);
+  let assemblyFormat = [{
+    `(` $cond `,` $vec1 `,` $vec2 `)` `:` qualified(type($cond)) `,`
+    qualified(type($vec1)) attr-dict
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // CLANG_CIR_DIALECT_IR_CIROPS_TD
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
index 8448c164a5e58..5ae727dff1095 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
@@ -193,6 +193,36 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
                                 e->getSourceRange().getBegin());
   }
 
+  mlir::Value
+  VisitAbstractConditionalOperator(const AbstractConditionalOperator *e) {
+    mlir::Location loc = cgf.getLoc(e->getSourceRange());
+    Expr *condExpr = e->getCond();
+    Expr *lhsExpr = e->getTrueExpr();
+    Expr *rhsExpr = e->getFalseExpr();
+
+    // OpenCL: If the condition is a vector, we can treat this condition like
+    // the select function.
+    if ((cgf.getLangOpts().OpenCL && condExpr->getType()->isVectorType()) ||
+        condExpr->getType()->isExtVectorType()) {
+      cgf.getCIRGenModule().errorNYI(loc,
+                                     "TernaryOp OpenCL VectorType condition");
+      return {};
+    }
+
+    if (condExpr->getType()->isVectorType() ||
+        condExpr->getType()->isSveVLSBuiltinType()) {
+      assert(condExpr->getType()->isVectorType() && "?: op for SVE vector NYI");
+      mlir::Value condValue = Visit(condExpr);
+      mlir::Value lhsValue = Visit(lhsExpr);
+      mlir::Value rhsValue = Visit(rhsExpr);
+      return builder.create<cir::VecTernaryOp>(loc, condValue, lhsValue,
+                                               rhsValue);
+    }
+
+    cgf.getCIRGenModule().errorNYI(loc, "TernaryOp for non vector types");
+    return {};
+  }
+
   mlir::Value VisitMemberExpr(MemberExpr *e);
 
   mlir::Value VisitInitListExpr(InitListExpr *e);
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 36f050de9f8bb..1236c455304a9 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -1589,6 +1589,24 @@ LogicalResult cir::VecShuffleDynamicOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// VecTernaryOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult cir::VecTernaryOp::verify() {
+  // Verify that the condition operand has the same number of elements as the
+  // other operands.  (The automatic verification already checked that all
+  // operands are vector types and that the second and third operands are the
+  // same type.)
+  if (mlir::cast<cir::VectorType>(getCond().getType()).getSize() !=
+      getVec1().getType().getSize()) {
+    return emitOpError() << ": the number of elements in "
+                         << getCond().getType() << " and "
+                         << getVec1().getType() << " don't match";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index b07e61638c3b4..e5a26260dc8cc 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1708,7 +1708,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
                CIRToLLVMVecExtractOpLowering,
                CIRToLLVMVecInsertOpLowering,
                CIRToLLVMVecCmpOpLowering,
-               CIRToLLVMVecShuffleDynamicOpLowering
+               CIRToLLVMVecShuffleDynamicOpLowering,
+               CIRToLLVMVecTernaryOpLowering
       // clang-format on
       >(converter, patterns.getContext());
 
@@ -1916,6 +1917,20 @@ mlir::LogicalResult CIRToLLVMVecShuffleDynamicOpLowering::matchAndRewrite(
   return mlir::success();
 }
 
+mlir::LogicalResult CIRToLLVMVecTernaryOpLowering::matchAndRewrite(
+    cir::VecTernaryOp op, OpAdaptor adaptor,
+    mlir::ConversionPatternRewriter &rewriter) const {
+  // Convert `cond` into a vector of i1, then use that in a `select` op.
+  mlir::Value bitVec = rewriter.create<mlir::LLVM::ICmpOp>(
+      op.getLoc(), mlir::LLVM::ICmpPredicate::ne, adaptor.getCond(),
+      rewriter.create<mlir::LLVM::ZeroOp>(
+          op.getCond().getLoc(),
+          typeConverter->convertType(op.getCond().getType())));
+  rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(
+      op, bitVec, adaptor.getVec1(), adaptor.getVec2());
+  return mlir::success();
+}
+
 std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
   return std::make_unique<ConvertCIRToLLVMPass>();
 }
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
index 6b8862db2c8be..ed369ff15a710 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
@@ -363,6 +363,16 @@ class CIRToLLVMVecShuffleDynamicOpLowering
                   mlir::ConversionPatternRewriter &) const override;
 };
 
+class CIRToLLVMVecTernaryOpLowering
+    : public mlir::OpConversionPattern<cir::VecTernaryOp> {
+public:
+  using mlir::OpConversionPattern<cir::VecTernaryOp>::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(cir::VecTernaryOp op, OpAdaptor,
+                  mlir::ConversionPatternRewriter &) const override;
+};
+
 } // namespace direct
 } // namespace cir
 
diff --git a/clang/test/CIR/CodeGen/vector-ext.cpp b/clang/test/CIR/CodeGen/vector-ext.cpp
index 8a0479fc1d088..53258845c2169 100644
--- a/clang/test/CIR/CodeGen/vector-ext.cpp
+++ b/clang/test/CIR/CodeGen/vector-ext.cpp
@@ -1091,3 +1091,18 @@ void foo17() {
 // OGCG: %[[VEC_A:.*]] = alloca <2 x double>, align 16
 // OGCG: %[[TMP:.*]] = load <2 x double>, ptr %[[VEC_A]], align 16
 // OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>
+
+void foo20() {
+  vi4 a;
+  vi4 b;
+  vi4 c;
+  vi4 r = c ? a : b;
+}
+
+// CIR: %[[RES:.*]] = cir.vec.ternary({{.*}}, {{.*}}, {{.*}}) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
+
+// LLVM: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer
+// LLVM: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> {{.*}}, <4 x i32> {{.*}}
+
+// OGCG: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer
+// OGCG: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> {{.*}}, <4 x i32> {{.*}}
diff --git a/clang/test/CIR/CodeGen/vector.cpp b/clang/test/CIR/CodeGen/vector.cpp
index 4c50f68a56162..49f142d110a81 100644
--- a/clang/test/CIR/CodeGen/vector.cpp
+++ b/clang/test/CIR/CodeGen/vector.cpp
@@ -1069,4 +1069,19 @@ void foo17() {
 
 // OGCG: %[[VEC_A:.*]] = alloca <2 x double>, align 16
 // OGCG: %[[TMP:.*]] = load <2 x double>, ptr %[[VEC_A]], align 16
-// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>
\ No newline at end of file
+// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>
+
+void foo20() {
+  vi4 a;
+  vi4 b;
+  vi4 c;
+  vi4 r = c ? a : b;
+}
+
+// CIR: %[[RES:.*]] = cir.vec.ternary({{.*}}, {{.*}}, {{.*}}) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
+
+// LLVM: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer
+// LLVM: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> {{.*}}, <4 x i32> {{.*}}
+
+// OGCG: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer
+// OGCG: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> {{.*}}, <4 x i32> {{.*}}

@AmrDeveloper
Copy link
Member Author

@xlauko
Copy link
Contributor

xlauko commented Jun 2, 2025

To simplify these kind of dependencies in PRs I suggest stacked PRs in the future, it would allow you to use CIR_VectorOfIntType here and mark that this PR needs [CIR][NFC] Upstream VectorType support in helper function #142222

LLVM docs has some guidance on this: https://llvm.org/docs/GitHub.html#stacked-pull-requests

I am personally using graphite, which has also pretty good interactive demo tutorial.

@AmrDeveloper AmrDeveloper force-pushed the cir_upstream_vec_ternary branch from 120e9a6 to 4aa56e8 Compare June 2, 2025 20:48
Copy link
Contributor

@andykaylor andykaylor left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good, but I would like to wait for #138156 to be merged first.

vi4 b;
vi4 c;
vi4 r = c ? a : b;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to see a test case where the operands are expressions. Something like

vi4 r = (a > b) ? (a - b) : (b - a);

@xlauko
Copy link
Contributor

xlauko commented Jun 3, 2025

Can you also add test where condition has different vector element type from operands?

@AmrDeveloper AmrDeveloper force-pushed the cir_upstream_vec_ternary branch from 1a5ce80 to 603e1aa Compare June 3, 2025 12:02
@AmrDeveloper
Copy link
Member Author

I addressed all comments in this PR 👍🏻

Copy link
Contributor

@andykaylor andykaylor left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

Copy link
Member

@bcardosolopes bcardosolopes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM (it will be nice to have a folder too in some follow up PR)

@AmrDeveloper
Copy link
Member Author

LGTM (it will be nice to have a folder too in some follow up PR)

Sure, I will submit PR for the folder as a follow-up (#142393 (comment)) 😉

Copy link
Contributor

@xlauko xlauko left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

@AmrDeveloper AmrDeveloper merged commit 5e21f2b into llvm:main Jun 5, 2025
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang Clang issues not falling into any other category ClangIR Anything related to the ClangIR project
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants