Skip to content

Commit 4d9cda2

Browse files
dkolsen-pgilanza
authored andcommitted
[CIR] Vector types - part 4 (#490)
This is part 4 of implementing vector types and vector operations in ClangIR, issue #284. This change has three small additions. Implement a "vector splat" conversion, which converts a scalar into vector, initializing all the elements of the vector with the scalar. Implement incomplete initialization of a vector, where the number of explicit initializers is less than the number of elements in the vector. The rest of the elements are implicitly zero initialized. Implement conversions between different vector types. The language rules require that the two types be the same size (in bytes, not necessarily in the number of elements). These conversions are always implemented with a bitcast. The first two changes only required changes to the AST -> ClangIR code gen. There are no changes to the ClangIR dialect, so no changes to the LLVM lowering were needed. The third part only required a change to a validation rule. The code to implement a vector bitcast was already present. The compiler just needed to stop rejecting it as invalid ClangIR.
1 parent e427633 commit 4d9cda2

File tree

5 files changed

+65
-8
lines changed

5 files changed

+65
-8
lines changed

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1531,8 +1531,19 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
15311531
}
15321532
case CK_MatrixCast:
15331533
llvm_unreachable("NYI");
1534-
case CK_VectorSplat:
1535-
llvm_unreachable("NYI");
1534+
case CK_VectorSplat: {
1535+
// Create a vector object and fill all elements with the same scalar value.
1536+
assert(DestTy->isVectorType() && "CK_VectorSplat to non-vector type");
1537+
mlir::Value Value = Visit(E);
1538+
SmallVector<mlir::Value, 16> Elements;
1539+
auto VecType = CGF.getCIRType(DestTy).dyn_cast<mlir::cir::VectorType>();
1540+
auto NumElements = VecType.getSize();
1541+
for (uint64_t Index = 0; Index < NumElements; ++Index) {
1542+
Elements.push_back(Value);
1543+
}
1544+
return CGF.getBuilder().create<mlir::cir::VecCreateOp>(
1545+
CGF.getLoc(E->getSourceRange()), VecType, Elements);
1546+
}
15361547
case CK_FixedPointCast:
15371548
llvm_unreachable("NYI");
15381549
case CK_FixedPointToBoolean:
@@ -1660,13 +1671,23 @@ mlir::Value ScalarExprEmitter::VisitInitListExpr(InitListExpr *E) {
16601671
assert(!UnimplementedFeature::scalableVectors() &&
16611672
"NYI: scalable vector init");
16621673
assert(!UnimplementedFeature::vectorConstants() && "NYI: vector constants");
1674+
auto VectorType =
1675+
CGF.getCIRType(E->getType()).dyn_cast<mlir::cir::VectorType>();
16631676
SmallVector<mlir::Value, 16> Elements;
16641677
for (Expr *init : E->inits()) {
16651678
Elements.push_back(Visit(init));
16661679
}
1680+
// Zero-initialize any remaining values.
1681+
if (NumInitElements < VectorType.getSize()) {
1682+
mlir::Value ZeroValue = CGF.getBuilder().create<mlir::cir::ConstantOp>(
1683+
CGF.getLoc(E->getSourceRange()), VectorType.getEltType(),
1684+
CGF.getBuilder().getZeroInitAttr(VectorType.getEltType()));
1685+
for (uint64_t i = NumInitElements; i < VectorType.getSize(); ++i) {
1686+
Elements.push_back(ZeroValue);
1687+
}
1688+
}
16671689
return CGF.getBuilder().create<mlir::cir::VecCreateOp>(
1668-
CGF.getLoc(E->getSourceRange()), CGF.getCIRType(E->getType()),
1669-
Elements);
1690+
CGF.getLoc(E->getSourceRange()), VectorType, Elements);
16701691
}
16711692

16721693
if (NumInitElements == 0) {

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -428,9 +428,12 @@ LogicalResult CastOp::verify() {
428428
return success();
429429
}
430430
case cir::CastKind::bitcast: {
431-
if (!srcType.dyn_cast<mlir::cir::PointerType>() ||
432-
!resType.dyn_cast<mlir::cir::PointerType>())
433-
return emitOpError() << "requires !cir.ptr type for source and result";
431+
if ((!srcType.isa<mlir::cir::PointerType>() ||
432+
!resType.isa<mlir::cir::PointerType>()) &&
433+
(!srcType.isa<mlir::cir::VectorType>() ||
434+
!resType.isa<mlir::cir::VectorType>()))
435+
return emitOpError()
436+
<< "requires !cir.ptr or !cir.vector type for source and result";
434437
return success();
435438
}
436439
case cir::CastKind::floating: {

clang/test/CIR/CodeGen/vectype.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,22 @@ void vector_int_test(int x) {
1515
vi4 b = { x, 5, 6, x + 1 };
1616
// CHECK: %{{[0-9]+}} = cir.vec.create(%{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}} : !s32i, !s32i, !s32i, !s32i) : <!s32i x 4>
1717

18+
// Incomplete vector initialization.
19+
vi4 bb = { x, x + 1 };
20+
// CHECK: %[[#zero:]] = cir.const(#cir.int<0> : !s32i) : !s32i
21+
// CHECK: %{{[0-9]+}} = cir.vec.create(%{{[0-9]+}}, %{{[0-9]+}}, %[[#zero]], %[[#zero]] : !s32i, !s32i, !s32i, !s32i) : <!s32i x 4>
22+
23+
// Scalar to vector conversion, a.k.a. vector splat. Only valid as an
24+
// operand of a binary operator, not as a regular conversion.
25+
bb = a + 7;
26+
// CHECK: %[[#seven:]] = cir.const(#cir.int<7> : !s32i) : !s32i
27+
// CHECK: %{{[0-9]+}} = cir.vec.create(%[[#seven]], %[[#seven]], %[[#seven]], %[[#seven]] : !s32i, !s32i, !s32i, !s32i) : <!s32i x 4>
28+
29+
// Vector to vector conversion
30+
vd2 bbb = { };
31+
bb = (vi4)bbb;
32+
// CHECK: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.vector<!cir.double x 2>), !cir.vector<!s32i x 4>
33+
1834
// Extract element
1935
int c = a[x];
2036
// CHECK: %{{[0-9]+}} = cir.vec.extract %{{[0-9]+}}[%{{[0-9]+}} : !s32i] : <!s32i x 4>
@@ -76,6 +92,17 @@ void vector_double_test(int x, double y) {
7692
vd2 b = { y, y + 1.0 };
7793
// CHECK: %{{[0-9]+}} = cir.vec.create(%{{[0-9]+}}, %{{[0-9]+}} : !cir.double, !cir.double) : <!cir.double x 2>
7894

95+
// Incomplete vector initialization
96+
vd2 bb = { y };
97+
// CHECK: [[#dzero:]] = cir.const(#cir.fp<0.000000e+00> : !cir.double) : !cir.double
98+
// CHECK: %{{[0-9]+}} = cir.vec.create(%{{[0-9]+}}, %[[#dzero]] : !cir.double, !cir.double) : <!cir.double x 2>
99+
100+
// Scalar to vector conversion, a.k.a. vector splat. Only valid as an
101+
// operand of a binary operator, not as a regular conversion.
102+
bb = a + 2.5;
103+
// CHECK: %[[#twohalf:]] = cir.const(#cir.fp<2.500000e+00> : !cir.double) : !cir.double
104+
// CHECK: %{{[0-9]+}} = cir.vec.create(%[[#twohalf]], %[[#twohalf]] : !cir.double, !cir.double) : <!cir.double x 2>
105+
79106
// Extract element
80107
double c = a[x];
81108
// CHECK: %{{[0-9]+}} = cir.vec.extract %{{[0-9]+}}[%{{[0-9]+}} : !s32i] : <!cir.double x 2>

clang/test/CIR/IR/invalid.cir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ cir.func @cast3(%p: !cir.ptr<!u32i>) {
141141

142142
!u32i = !cir.int<u, 32>
143143
cir.func @cast4(%p: !cir.ptr<!u32i>) {
144-
%2 = cir.cast(bitcast, %p : !cir.ptr<!u32i>), !u32i // expected-error {{requires !cir.ptr type for source and result}}
144+
%2 = cir.cast(bitcast, %p : !cir.ptr<!u32i>), !u32i // expected-error {{requires !cir.ptr or !cir.vector type for source and result}}
145145
cir.return
146146
}
147147

clang/test/CIR/Lowering/vectype.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ void vector_int_test(int x) {
4545
// CHECK: %[[#T57:]] = llvm.insertelement %[[#T48]], %[[#T55]][%[[#T56]] : i64] : vector<4xi32>
4646
// CHECK: llvm.store %[[#T57]], %[[#T5:]] : vector<4xi32>, !llvm.ptr
4747

48+
// Vector to vector conversion
49+
vd2 bb = (vd2)b;
50+
// CHECK: %[[#bval:]] = llvm.load %[[#bmem:]] : !llvm.ptr -> vector<4xi32>
51+
// CHECK: %[[#bbval:]] = llvm.bitcast %[[#bval]] : vector<4xi32> to vector<2xf64>
52+
// CHECK: llvm.store %[[#bbval]], %[[#bbmem:]] : vector<2xf64>, !llvm.ptr
53+
4854
// Extract element.
4955
int c = a[x];
5056
// CHECK: %[[#T58:]] = llvm.load %[[#T3]] : !llvm.ptr -> vector<4xi32>

0 commit comments

Comments
 (0)