Skip to content

Commit 459de73

Browse files
authored
[CIR] Upstream converting vector types (#142012)
This change adds support for ConvertVectorExpr to convert between vector types with the same size Issue #136487
1 parent d721d4e commit 459de73

File tree

4 files changed

+60
-3
lines changed

4 files changed

+60
-3
lines changed

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,14 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
185185
return {};
186186
}
187187

188+
mlir::Value VisitConvertVectorExpr(ConvertVectorExpr *e) {
189+
// __builtin_convertvector is an element-wise cast, and is implemented as a
190+
// regular cast. The back end handles casts of vectors correctly.
191+
return emitScalarConversion(Visit(e->getSrcExpr()),
192+
e->getSrcExpr()->getType(), e->getType(),
193+
e->getSourceRange().getBegin());
194+
}
195+
188196
mlir::Value VisitMemberExpr(MemberExpr *e);
189197

190198
mlir::Value VisitInitListExpr(InitListExpr *e);
@@ -277,7 +285,12 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
277285
"Obsolete code. Don't use mlir::IntegerType with CIR.");
278286

279287
mlir::Type fullDstTy = dstTy;
280-
assert(!cir::MissingFeatures::vectorType());
288+
if (mlir::isa<cir::VectorType>(srcTy) &&
289+
mlir::isa<cir::VectorType>(dstTy)) {
290+
// Use the element types of the vectors to figure out the CastKind.
291+
srcTy = mlir::dyn_cast<cir::VectorType>(srcTy).getElementType();
292+
dstTy = mlir::dyn_cast<cir::VectorType>(dstTy).getElementType();
293+
}
281294

282295
std::optional<cir::CastKind> castKind;
283296

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,16 @@ LogicalResult cir::ContinueOp::verify() {
286286
//===----------------------------------------------------------------------===//
287287

288288
LogicalResult cir::CastOp::verify() {
289-
const mlir::Type resType = getResult().getType();
290-
const mlir::Type srcType = getSrc().getType();
289+
mlir::Type resType = getResult().getType();
290+
mlir::Type srcType = getSrc().getType();
291+
292+
if (mlir::isa<cir::VectorType>(srcType) &&
293+
mlir::isa<cir::VectorType>(resType)) {
294+
// Use the element type of the vector to verify the cast kind. (Except for
295+
// bitcast, see below.)
296+
srcType = mlir::dyn_cast<cir::VectorType>(srcType).getElementType();
297+
resType = mlir::dyn_cast<cir::VectorType>(resType).getElementType();
298+
}
291299

292300
switch (getKind()) {
293301
case cir::CastKind::int_to_bool: {

clang/test/CIR/CodeGen/vector-ext.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -emit-llvm %s -o %t.ll
66
// RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG
77

8+
typedef unsigned short vus2 __attribute__((ext_vector_type(2)));
89
typedef int vi4 __attribute__((ext_vector_type(4)));
910
typedef int vi6 __attribute__((ext_vector_type(6)));
1011
typedef unsigned int uvi4 __attribute__((ext_vector_type(4)));
@@ -1073,3 +1074,20 @@ void foo16() {
10731074
// OGCG: %[[SHUF_IDX_3:.*]] = extractelement <6 x i32> %[[MASK]], i64 3
10741075
// OGCG: %[[SHUF_ELE_3:.*]] = extractelement <6 x i32> %[[TMP_A]], i32 %[[SHUF_IDX_3]]
10751076
// OGCG: %[[SHUF_INS_3:.*]] = insertelement <6 x i32> %[[SHUF_INS_2]], i32 %[[SHUF_ELE_3]], i64 3
1077+
1078+
void foo17() {
1079+
vd2 a;
1080+
vus2 W = __builtin_convertvector(a, vus2);
1081+
}
1082+
1083+
// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<2 x !cir.double>, !cir.ptr<!cir.vector<2 x !cir.double>>, ["a"]
1084+
// CIR: %[[TMP:.*]] = cir.load{{.*}} %[[VEC_A]] : !cir.ptr<!cir.vector<2 x !cir.double>>, !cir.vector<2 x !cir.double>
1085+
// CIR: %[[RES:.*]] = cir.cast(float_to_int, %[[TMP]] : !cir.vector<2 x !cir.double>), !cir.vector<2 x !u16i>
1086+
1087+
// LLVM: %[[VEC_A:.*]] = alloca <2 x double>, i64 1, align 16
1088+
// LLVM: %[[TMP:.*]] = load <2 x double>, ptr %[[VEC_A]], align 16
1089+
// LLVM: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>
1090+
1091+
// OGCG: %[[VEC_A:.*]] = alloca <2 x double>, align 16
1092+
// OGCG: %[[TMP:.*]] = load <2 x double>, ptr %[[VEC_A]], align 16
1093+
// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>

clang/test/CIR/CodeGen/vector.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -emit-llvm %s -o %t.ll
66
// RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG
77

8+
typedef unsigned short vus2 __attribute__((vector_size(4)));
89
typedef int vi4 __attribute__((vector_size(16)));
910
typedef int vi6 __attribute__((vector_size(24)));
1011
typedef unsigned int uvi4 __attribute__((vector_size(16)));
@@ -1052,3 +1053,20 @@ void foo16() {
10521053
// OGCG: %[[SHUF_IDX_3:.*]] = extractelement <6 x i32> %[[MASK]], i64 3
10531054
// OGCG: %[[SHUF_ELE_3:.*]] = extractelement <6 x i32> %[[TMP_A]], i32 %[[SHUF_IDX_3]]
10541055
// OGCG: %[[SHUF_INS_3:.*]] = insertelement <6 x i32> %[[SHUF_INS_2]], i32 %[[SHUF_ELE_3]], i64 3
1056+
1057+
void foo17() {
1058+
vd2 a;
1059+
vus2 W = __builtin_convertvector(a, vus2);
1060+
}
1061+
1062+
// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<2 x !cir.double>, !cir.ptr<!cir.vector<2 x !cir.double>>, ["a"]
1063+
// CIR: %[[TMP:.*]] = cir.load{{.*}} %[[VEC_A]] : !cir.ptr<!cir.vector<2 x !cir.double>>, !cir.vector<2 x !cir.double>
1064+
// CIR: %[[RES:.*]] = cir.cast(float_to_int, %[[TMP]] : !cir.vector<2 x !cir.double>), !cir.vector<2 x !u16i>
1065+
1066+
// LLVM: %[[VEC_A:.*]] = alloca <2 x double>, i64 1, align 16
1067+
// LLVM: %[[TMP:.*]] = load <2 x double>, ptr %[[VEC_A]], align 16
1068+
// LLVM: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>
1069+
1070+
// OGCG: %[[VEC_A:.*]] = alloca <2 x double>, align 16
1071+
// OGCG: %[[TMP:.*]] = load <2 x double>, ptr %[[VEC_A]], align 16
1072+
// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>

0 commit comments

Comments
 (0)