Skip to content

Commit

Permalink
[CIR] Vector types - part 4 (llvm#490)
Browse files Browse the repository at this point in the history
This is part 4 of implementing vector types and vector operations in
ClangIR, issue llvm#284. This change has three small additions.

Implement a "vector splat" conversion, which converts a scalar into
vector, initializing all the elements of the vector with the scalar.

Implement incomplete initialization of a vector, where the number of
explicit initializers is less than the number of elements in the vector.
The rest of the elements are implicitly zero initialized.

Implement conversions between different vector types. The language rules
require that the two types be the same size (in bytes, not necessarily
in the number of elements). These conversions are always implemented
with a bitcast.

The first two changes only required changes to the AST -> ClangIR code
gen. There are no changes to the ClangIR dialect, so no changes to the
LLVM lowering were needed.

The third part only required a change to a validation rule. The code to
implement a vector bitcast was already present. The compiler just needed
to stop rejecting it as invalid ClangIR.
  • Loading branch information
dkolsen-pgi authored and lanza committed Jun 20, 2024
1 parent db71b36 commit 6100462
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 8 deletions.
29 changes: 25 additions & 4 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1531,8 +1531,19 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
}
case CK_MatrixCast:
llvm_unreachable("NYI");
case CK_VectorSplat:
llvm_unreachable("NYI");
case CK_VectorSplat: {
// Create a vector object and fill all elements with the same scalar value.
assert(DestTy->isVectorType() && "CK_VectorSplat to non-vector type");
mlir::Value Value = Visit(E);
SmallVector<mlir::Value, 16> Elements;
auto VecType = CGF.getCIRType(DestTy).dyn_cast<mlir::cir::VectorType>();
auto NumElements = VecType.getSize();
for (uint64_t Index = 0; Index < NumElements; ++Index) {
Elements.push_back(Value);
}
return CGF.getBuilder().create<mlir::cir::VecCreateOp>(
CGF.getLoc(E->getSourceRange()), VecType, Elements);
}
case CK_FixedPointCast:
llvm_unreachable("NYI");
case CK_FixedPointToBoolean:
Expand Down Expand Up @@ -1660,13 +1671,23 @@ mlir::Value ScalarExprEmitter::VisitInitListExpr(InitListExpr *E) {
assert(!UnimplementedFeature::scalableVectors() &&
"NYI: scalable vector init");
assert(!UnimplementedFeature::vectorConstants() && "NYI: vector constants");
auto VectorType =
CGF.getCIRType(E->getType()).dyn_cast<mlir::cir::VectorType>();
SmallVector<mlir::Value, 16> Elements;
for (Expr *init : E->inits()) {
Elements.push_back(Visit(init));
}
// Zero-initialize any remaining values.
if (NumInitElements < VectorType.getSize()) {
mlir::Value ZeroValue = CGF.getBuilder().create<mlir::cir::ConstantOp>(
CGF.getLoc(E->getSourceRange()), VectorType.getEltType(),
CGF.getBuilder().getZeroInitAttr(VectorType.getEltType()));
for (uint64_t i = NumInitElements; i < VectorType.getSize(); ++i) {
Elements.push_back(ZeroValue);
}
}
return CGF.getBuilder().create<mlir::cir::VecCreateOp>(
CGF.getLoc(E->getSourceRange()), CGF.getCIRType(E->getType()),
Elements);
CGF.getLoc(E->getSourceRange()), VectorType, Elements);
}

if (NumInitElements == 0) {
Expand Down
9 changes: 6 additions & 3 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -428,9 +428,12 @@ LogicalResult CastOp::verify() {
return success();
}
case cir::CastKind::bitcast: {
if (!srcType.dyn_cast<mlir::cir::PointerType>() ||
!resType.dyn_cast<mlir::cir::PointerType>())
return emitOpError() << "requires !cir.ptr type for source and result";
if ((!srcType.isa<mlir::cir::PointerType>() ||
!resType.isa<mlir::cir::PointerType>()) &&
(!srcType.isa<mlir::cir::VectorType>() ||
!resType.isa<mlir::cir::VectorType>()))
return emitOpError()
<< "requires !cir.ptr or !cir.vector type for source and result";
return success();
}
case cir::CastKind::floating: {
Expand Down
27 changes: 27 additions & 0 deletions clang/test/CIR/CodeGen/vectype.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,22 @@ void vector_int_test(int x) {
vi4 b = { x, 5, 6, x + 1 };
// CHECK: %{{[0-9]+}} = cir.vec.create(%{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}} : !s32i, !s32i, !s32i, !s32i) : <!s32i x 4>

// Incomplete vector initialization.
vi4 bb = { x, x + 1 };
// CHECK: %[[#zero:]] = cir.const(#cir.int<0> : !s32i) : !s32i
// CHECK: %{{[0-9]+}} = cir.vec.create(%{{[0-9]+}}, %{{[0-9]+}}, %[[#zero]], %[[#zero]] : !s32i, !s32i, !s32i, !s32i) : <!s32i x 4>

// Scalar to vector conversion, a.k.a. vector splat. Only valid as an
// operand of a binary operator, not as a regular conversion.
bb = a + 7;
// CHECK: %[[#seven:]] = cir.const(#cir.int<7> : !s32i) : !s32i
// CHECK: %{{[0-9]+}} = cir.vec.create(%[[#seven]], %[[#seven]], %[[#seven]], %[[#seven]] : !s32i, !s32i, !s32i, !s32i) : <!s32i x 4>

// Vector to vector conversion
vd2 bbb = { };
bb = (vi4)bbb;
// CHECK: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.vector<!cir.double x 2>), !cir.vector<!s32i x 4>

// Extract element
int c = a[x];
// CHECK: %{{[0-9]+}} = cir.vec.extract %{{[0-9]+}}[%{{[0-9]+}} : !s32i] : <!s32i x 4>
Expand Down Expand Up @@ -76,6 +92,17 @@ void vector_double_test(int x, double y) {
vd2 b = { y, y + 1.0 };
// CHECK: %{{[0-9]+}} = cir.vec.create(%{{[0-9]+}}, %{{[0-9]+}} : !cir.double, !cir.double) : <!cir.double x 2>

// Incomplete vector initialization
vd2 bb = { y };
// CHECK: [[#dzero:]] = cir.const(#cir.fp<0.000000e+00> : !cir.double) : !cir.double
// CHECK: %{{[0-9]+}} = cir.vec.create(%{{[0-9]+}}, %[[#dzero]] : !cir.double, !cir.double) : <!cir.double x 2>

// Scalar to vector conversion, a.k.a. vector splat. Only valid as an
// operand of a binary operator, not as a regular conversion.
bb = a + 2.5;
// CHECK: %[[#twohalf:]] = cir.const(#cir.fp<2.500000e+00> : !cir.double) : !cir.double
// CHECK: %{{[0-9]+}} = cir.vec.create(%[[#twohalf]], %[[#twohalf]] : !cir.double, !cir.double) : <!cir.double x 2>

// Extract element
double c = a[x];
// CHECK: %{{[0-9]+}} = cir.vec.extract %{{[0-9]+}}[%{{[0-9]+}} : !s32i] : <!cir.double x 2>
Expand Down
2 changes: 1 addition & 1 deletion clang/test/CIR/IR/invalid.cir
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ cir.func @cast3(%p: !cir.ptr<!u32i>) {

!u32i = !cir.int<u, 32>
cir.func @cast4(%p: !cir.ptr<!u32i>) {
%2 = cir.cast(bitcast, %p : !cir.ptr<!u32i>), !u32i // expected-error {{requires !cir.ptr type for source and result}}
%2 = cir.cast(bitcast, %p : !cir.ptr<!u32i>), !u32i // expected-error {{requires !cir.ptr or !cir.vector type for source and result}}
cir.return
}

Expand Down
6 changes: 6 additions & 0 deletions clang/test/CIR/Lowering/vectype.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ void vector_int_test(int x) {
// CHECK: %[[#T57:]] = llvm.insertelement %[[#T48]], %[[#T55]][%[[#T56]] : i64] : vector<4xi32>
// CHECK: llvm.store %[[#T57]], %[[#T5:]] : vector<4xi32>, !llvm.ptr

// Vector to vector conversion
vd2 bb = (vd2)b;
// CHECK: %[[#bval:]] = llvm.load %[[#bmem:]] : !llvm.ptr -> vector<4xi32>
// CHECK: %[[#bbval:]] = llvm.bitcast %[[#bval]] : vector<4xi32> to vector<2xf64>
// CHECK: llvm.store %[[#bbval]], %[[#bbmem:]] : vector<2xf64>, !llvm.ptr

// Extract element.
int c = a[x];
// CHECK: %[[#T58:]] = llvm.load %[[#T3]] : !llvm.ptr -> vector<4xi32>
Expand Down

0 comments on commit 6100462

Please sign in to comment.