Skip to content

Commit 5e21f2b

Browse files
authored
[CIR] Upstream TernaryOp for VectorType (#142393)
This change adds support for the Ternary op for VectorType Issue #136487
1 parent 8b167db commit 5e21f2b

File tree

6 files changed

+151
-9
lines changed

6 files changed

+151
-9
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2194,4 +2194,40 @@ def VecShuffleDynamicOp : CIR_Op<"vec.shuffle.dynamic",
21942194
let hasVerifier = 1;
21952195
}
21962196

2197+
//===----------------------------------------------------------------------===//
2198+
// VecTernaryOp
2199+
//===----------------------------------------------------------------------===//
2200+
2201+
def VecTernaryOp : CIR_Op<"vec.ternary",
2202+
[Pure, AllTypesMatch<["result", "lhs", "rhs"]>]> {
2203+
let summary = "The `cond ? a : b` ternary operator for vector types";
2204+
let description = [{
2205+
The `cir.vec.ternary` operation represents the C/C++ ternary operator,
2206+
`?:`, for vector types, which does a `select` on individual elements of the
2207+
vectors. Unlike a regular `?:` operator, there is no short circuiting. All
2208+
three arguments are always evaluated. Because there is no short
2209+
circuiting, there are no regions in this operation, unlike cir.ternary.
2210+
2211+
The first argument is a vector of integral type. The second and third
2212+
arguments are vectors of the same type and have the same number of elements
2213+
as the first argument.
2214+
2215+
The result is a vector of the same type as the second and third arguments.
2216+
Each element of the result is `(bool)a[n] ? b[n] : c[n]`.
2217+
}];
2218+
2219+
let arguments = (ins
2220+
CIR_VectorOfIntType:$cond,
2221+
CIR_VectorType:$lhs,
2222+
CIR_VectorType:$rhs
2223+
);
2224+
2225+
let results = (outs CIR_VectorType:$result);
2226+
let assemblyFormat = [{
2227+
`(` $cond `,` $lhs`,` $rhs `)` `:` qualified(type($cond)) `,`
2228+
qualified(type($lhs)) attr-dict
2229+
}];
2230+
let hasVerifier = 1;
2231+
}
2232+
21972233
#endif // CLANG_CIR_DIALECT_IR_CIROPS_TD

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1954,19 +1954,28 @@ mlir::Value ScalarExprEmitter::VisitAbstractConditionalOperator(
19541954
}
19551955
}
19561956

1957+
QualType condType = condExpr->getType();
1958+
19571959
// OpenCL: If the condition is a vector, we can treat this condition like
19581960
// the select function.
1959-
if ((cgf.getLangOpts().OpenCL && condExpr->getType()->isVectorType()) ||
1960-
condExpr->getType()->isExtVectorType()) {
1961+
if ((cgf.getLangOpts().OpenCL && condType->isVectorType()) ||
1962+
condType->isExtVectorType()) {
19611963
assert(!cir::MissingFeatures::vectorType());
19621964
cgf.cgm.errorNYI(e->getSourceRange(), "vector ternary op");
19631965
}
19641966

1965-
if (condExpr->getType()->isVectorType() ||
1966-
condExpr->getType()->isSveVLSBuiltinType()) {
1967-
assert(!cir::MissingFeatures::vecTernaryOp());
1968-
cgf.cgm.errorNYI(e->getSourceRange(), "vector ternary op");
1969-
return {};
1967+
if (condType->isVectorType() || condType->isSveVLSBuiltinType()) {
1968+
if (!condType->isVectorType()) {
1969+
assert(!cir::MissingFeatures::vecTernaryOp());
1970+
cgf.cgm.errorNYI(loc, "TernaryOp for SVE vector");
1971+
return {};
1972+
}
1973+
1974+
mlir::Value condValue = Visit(condExpr);
1975+
mlir::Value lhsValue = Visit(lhsExpr);
1976+
mlir::Value rhsValue = Visit(rhsExpr);
1977+
return builder.create<cir::VecTernaryOp>(loc, condValue, lhsValue,
1978+
rhsValue);
19701979
}
19711980

19721981
// If this is a really simple expression (like x ? 4 : 5), emit this as a

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1589,6 +1589,23 @@ LogicalResult cir::VecShuffleDynamicOp::verify() {
15891589
return success();
15901590
}
15911591

1592+
//===----------------------------------------------------------------------===//
1593+
// VecTernaryOp
1594+
//===----------------------------------------------------------------------===//
1595+
1596+
LogicalResult cir::VecTernaryOp::verify() {
1597+
// Verify that the condition operand has the same number of elements as the
1598+
// other operands. (The automatic verification already checked that all
1599+
// operands are vector types and that the second and third operands are the
1600+
// same type.)
1601+
if (getCond().getType().getSize() != getLhs().getType().getSize()) {
1602+
return emitOpError() << ": the number of elements in "
1603+
<< getCond().getType() << " and " << getLhs().getType()
1604+
<< " don't match";
1605+
}
1606+
return success();
1607+
}
1608+
15921609
//===----------------------------------------------------------------------===//
15931610
// TableGen'd op method definitions
15941611
//===----------------------------------------------------------------------===//

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1731,7 +1731,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
17311731
CIRToLLVMVecExtractOpLowering,
17321732
CIRToLLVMVecInsertOpLowering,
17331733
CIRToLLVMVecCmpOpLowering,
1734-
CIRToLLVMVecShuffleDynamicOpLowering
1734+
CIRToLLVMVecShuffleDynamicOpLowering,
1735+
CIRToLLVMVecTernaryOpLowering
17351736
// clang-format on
17361737
>(converter, patterns.getContext());
17371738

@@ -1936,6 +1937,20 @@ mlir::LogicalResult CIRToLLVMVecShuffleDynamicOpLowering::matchAndRewrite(
19361937
return mlir::success();
19371938
}
19381939

1940+
mlir::LogicalResult CIRToLLVMVecTernaryOpLowering::matchAndRewrite(
1941+
cir::VecTernaryOp op, OpAdaptor adaptor,
1942+
mlir::ConversionPatternRewriter &rewriter) const {
1943+
// Convert `cond` into a vector of i1, then use that in a `select` op.
1944+
mlir::Value bitVec = rewriter.create<mlir::LLVM::ICmpOp>(
1945+
op.getLoc(), mlir::LLVM::ICmpPredicate::ne, adaptor.getCond(),
1946+
rewriter.create<mlir::LLVM::ZeroOp>(
1947+
op.getCond().getLoc(),
1948+
typeConverter->convertType(op.getCond().getType())));
1949+
rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(
1950+
op, bitVec, adaptor.getLhs(), adaptor.getRhs());
1951+
return mlir::success();
1952+
}
1953+
19391954
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
19401955
return std::make_unique<ConvertCIRToLLVMPass>();
19411956
}

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,16 @@ class CIRToLLVMVecShuffleDynamicOpLowering
368368
mlir::ConversionPatternRewriter &) const override;
369369
};
370370

371+
class CIRToLLVMVecTernaryOpLowering
372+
: public mlir::OpConversionPattern<cir::VecTernaryOp> {
373+
public:
374+
using mlir::OpConversionPattern<cir::VecTernaryOp>::OpConversionPattern;
375+
376+
mlir::LogicalResult
377+
matchAndRewrite(cir::VecTernaryOp op, OpAdaptor,
378+
mlir::ConversionPatternRewriter &) const override;
379+
};
380+
371381
} // namespace direct
372382
} // namespace cir
373383

clang/test/CIR/CodeGen/vector.cpp

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1069,4 +1069,59 @@ void foo17() {
10691069

10701070
// OGCG: %[[VEC_A:.*]] = alloca <2 x double>, align 16
10711071
// OGCG: %[[TMP:.*]] = load <2 x double>, ptr %[[VEC_A]], align 16
1072-
// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>
1072+
// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>
1073+
1074+
void foo20() {
1075+
vi4 a;
1076+
vi4 b;
1077+
vi4 c;
1078+
vi4 r = c ? a : b;
1079+
}
1080+
1081+
// CIR: %[[RES:.*]] = cir.vec.ternary({{.*}}, {{.*}}, {{.*}}) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
1082+
1083+
// LLVM: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer
1084+
// LLVM: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> {{.*}}, <4 x i32> {{.*}}
1085+
1086+
// OGCG: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer
1087+
// OGCG: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> {{.*}}, <4 x i32> {{.*}}
1088+
1089+
void foo21() {
1090+
vi4 a;
1091+
vi4 b;
1092+
vi4 r = (a > b) ? (a - b) : (b - a);
1093+
}
1094+
1095+
// CIR: %[[VEC_COND:.*]] = cir.vec.cmp(gt, {{.*}}, {{.*}}) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
1096+
// CIR: %[[LHS:.*]] = cir.binop(sub, {{.*}}, {{.*}}) : !cir.vector<4 x !s32i>
1097+
// CIR: %[[RHS:.*]] = cir.binop(sub, {{.*}}, {{.*}}) : !cir.vector<4 x !s32i>
1098+
// CIR: %[[RES:.*]] = cir.vec.ternary(%[[VEC_COND]], %[[LHS]], %[[RHS]]) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
1099+
1100+
// LLVM: %[[CMP:.*]] = icmp sgt <4 x i32> {{.*}}, {{.*}}
1101+
// LLVM: %[[SEXT:.*]] = sext <4 x i1> %[[CMP]] to <4 x i32>
1102+
// LLVM: %[[LHS:.*]] = sub <4 x i32> {{.*}}, {{.*}}
1103+
// LLVM: %[[RHS:.*]] = sub <4 x i32> {{.*}}, {{.*}}
1104+
// LLVM: %[[VEC_COND:.*]] = icmp ne <4 x i32> %[[SEXT]], zeroinitializer
1105+
// LLVM: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> %[[LHS]], <4 x i32> %[[RHS]]
1106+
1107+
// OGCG: %[[CMP:.*]] = icmp sgt <4 x i32> {{.*}}, {{.*}}
1108+
// OGCG: %[[SEXT:.*]] = sext <4 x i1> %[[CMP]] to <4 x i32>
1109+
// OGCG: %[[LHS:.*]] = sub <4 x i32> {{.*}}, {{.*}}
1110+
// OGCG: %[[RHS:.*]] = sub <4 x i32> {{.*}}, {{.*}}
1111+
// OGCG: %[[VEC_COND:.*]] = icmp ne <4 x i32> %[[SEXT]], zeroinitializer
1112+
// OGCG: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x i32> %[[LHS]], <4 x i32> %[[RHS]]
1113+
1114+
void foo22() {
1115+
vf4 a;
1116+
vf4 b;
1117+
vi4 c;
1118+
vf4 r = c ? a : b;
1119+
}
1120+
1121+
// CIR: %[[RES:.*]] = cir.vec.ternary({{.*}}, {{.*}}, {{.*}}) : !cir.vector<4 x !s32i>, !cir.vector<4 x !cir.float>
1122+
1123+
// LLVM: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer
1124+
// LLVM: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x float> {{.*}}, <4 x float> {{.*}}
1125+
1126+
// OGCG: %[[VEC_COND:.*]] = icmp ne <4 x i32> {{.*}}, zeroinitializer
1127+
// OGCG: %[[RES:.*]] = select <4 x i1> %[[VEC_COND]], <4 x float> {{.*}}, <4 x float> {{.*}}

0 commit comments

Comments
 (0)