Skip to content

[CIR] Vector types - part 4 #490

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1531,8 +1531,19 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
}
case CK_MatrixCast:
llvm_unreachable("NYI");
case CK_VectorSplat:
llvm_unreachable("NYI");
case CK_VectorSplat: {
// Create a vector object and fill all elements with the same scalar value.
assert(DestTy->isVectorType() && "CK_VectorSplat to non-vector type");
mlir::Value Value = Visit(E);
SmallVector<mlir::Value, 16> Elements;
auto VecType = CGF.getCIRType(DestTy).dyn_cast<mlir::cir::VectorType>();
auto NumElements = VecType.getSize();
for (int Index = 0; Index < NumElements; ++Index) {
Elements.push_back(Value);
}
return CGF.getBuilder().create<mlir::cir::VecCreateOp>(
CGF.getLoc(E->getSourceRange()), VecType, Elements);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It'd be nice to get a specific vec splat op, so analysis wouldn't need to check every element to find out if something is a splat or not.

}
case CK_FixedPointCast:
llvm_unreachable("NYI");
case CK_FixedPointToBoolean:
Expand Down Expand Up @@ -1660,13 +1671,23 @@ mlir::Value ScalarExprEmitter::VisitInitListExpr(InitListExpr *E) {
assert(!UnimplementedFeature::scalableVectors() &&
"NYI: scalable vector init");
assert(!UnimplementedFeature::vectorConstants() && "NYI: vector constants");
auto VectorType =
CGF.getCIRType(E->getType()).dyn_cast<mlir::cir::VectorType>();
SmallVector<mlir::Value, 16> Elements;
for (Expr *init : E->inits()) {
Elements.push_back(Visit(init));
}
// Zero-initialize any remaining values.
if (NumInitElements < VectorType.getSize()) {
mlir::Value ZeroValue = CGF.getBuilder().create<mlir::cir::ConstantOp>(
CGF.getLoc(E->getSourceRange()), VectorType.getEltType(),
CGF.getBuilder().getZeroInitAttr(VectorType.getEltType()));
for (int i = NumInitElements; i < VectorType.getSize(); ++i) {
Elements.push_back(ZeroValue);
}
}
return CGF.getBuilder().create<mlir::cir::VecCreateOp>(
CGF.getLoc(E->getSourceRange()), CGF.getCIRType(E->getType()),
Elements);
CGF.getLoc(E->getSourceRange()), VectorType, Elements);
}

if (NumInitElements == 0) {
Expand Down
9 changes: 6 additions & 3 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -428,9 +428,12 @@ LogicalResult CastOp::verify() {
return success();
}
case cir::CastKind::bitcast: {
if (!srcType.dyn_cast<mlir::cir::PointerType>() ||
!resType.dyn_cast<mlir::cir::PointerType>())
return emitOpError() << "requires !cir.ptr type for source and result";
if ((!srcType.isa<mlir::cir::PointerType>() ||
!resType.isa<mlir::cir::PointerType>()) &&
(!srcType.isa<mlir::cir::VectorType>() ||
!resType.isa<mlir::cir::VectorType>()))
return emitOpError()
<< "requires !cir.ptr or !cir.vector type for source and result";
return success();
}
case cir::CastKind::floating: {
Expand Down
27 changes: 27 additions & 0 deletions clang/test/CIR/CodeGen/vectype.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,22 @@ void vector_int_test(int x) {
vi4 b = { x, 5, 6, x + 1 };
// CHECK: %{{[0-9]+}} = cir.vec.create(%{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}} : !s32i, !s32i, !s32i, !s32i) : <!s32i x 4>

// Incomplete vector initialization.
vi4 bb = { x, x + 1 };
// CHECK: %[[#zero:]] = cir.const(#cir.int<0> : !s32i) : !s32i
// CHECK: %{{[0-9]+}} = cir.vec.create(%{{[0-9]+}}, %{{[0-9]+}}, %[[#zero]], %[[#zero]] : !s32i, !s32i, !s32i, !s32i) : <!s32i x 4>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Random comment: If this was a constant vector (which isn't the case), this reminds me of something we discussed in previous patches it would be good to have #cir.const_vec of sorts, and similarly to ConstArrayAttr, we could have a trailing zero indicator. Adding a comment here so I can populate an issue.


// Scalar to vector conversion, a.k.a. vector splat. Only valid as an
// operand of a binary operator, not as a regular conversion.
bb = a + 7;
// CHECK: %[[#seven:]] = cir.const(#cir.int<7> : !s32i) : !s32i
// CHECK: %{{[0-9]+}} = cir.vec.create(%[[#seven]], %[[#seven]], %[[#seven]], %[[#seven]] : !s32i, !s32i, !s32i, !s32i) : <!s32i x 4>

// Vector to vector conversion
vd2 bbb = { };
bb = (vi4)bbb;
// CHECK: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.vector<!cir.double x 2>), !cir.vector<!s32i x 4>

// Extract element
int c = a[x];
// CHECK: %{{[0-9]+}} = cir.vec.extract %{{[0-9]+}}[%{{[0-9]+}} : !s32i] : <!s32i x 4>
Expand Down Expand Up @@ -76,6 +92,17 @@ void vector_double_test(int x, double y) {
vd2 b = { y, y + 1.0 };
// CHECK: %{{[0-9]+}} = cir.vec.create(%{{[0-9]+}}, %{{[0-9]+}} : !cir.double, !cir.double) : <!cir.double x 2>

// Incomplete vector initialization
vd2 bb = { y };
// CHECK: [[#dzero:]] = cir.const(#cir.fp<0.000000e+00> : !cir.double) : !cir.double
// CHECK: %{{[0-9]+}} = cir.vec.create(%{{[0-9]+}}, %[[#dzero]] : !cir.double, !cir.double) : <!cir.double x 2>

// Scalar to vector conversion, a.k.a. vector splat. Only valid as an
// operand of a binary operator, not as a regular conversion.
bb = a + 2.5;
// CHECK: %[[#twohalf:]] = cir.const(#cir.fp<2.500000e+00> : !cir.double) : !cir.double
// CHECK: %{{[0-9]+}} = cir.vec.create(%[[#twohalf]], %[[#twohalf]] : !cir.double, !cir.double) : <!cir.double x 2>

// Extract element
double c = a[x];
// CHECK: %{{[0-9]+}} = cir.vec.extract %{{[0-9]+}}[%{{[0-9]+}} : !s32i] : <!cir.double x 2>
Expand Down
2 changes: 1 addition & 1 deletion clang/test/CIR/IR/invalid.cir
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ cir.func @cast3(%p: !cir.ptr<!u32i>) {

!u32i = !cir.int<u, 32>
cir.func @cast4(%p: !cir.ptr<!u32i>) {
%2 = cir.cast(bitcast, %p : !cir.ptr<!u32i>), !u32i // expected-error {{requires !cir.ptr type for source and result}}
%2 = cir.cast(bitcast, %p : !cir.ptr<!u32i>), !u32i // expected-error {{requires !cir.ptr or !cir.vector type for source and result}}
cir.return
}

Expand Down
6 changes: 6 additions & 0 deletions clang/test/CIR/Lowering/vectype.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ void vector_int_test(int x) {
// CHECK: %[[#T57:]] = llvm.insertelement %[[#T48]], %[[#T55]][%[[#T56]] : i64] : vector<4xi32>
// CHECK: llvm.store %[[#T57]], %[[#T5:]] : vector<4xi32>, !llvm.ptr

// Vector to vector conversion
vd2 bb = (vd2)b;
// CHECK: %[[#bval:]] = llvm.load %[[#bmem:]] : !llvm.ptr -> vector<4xi32>
// CHECK: %[[#bbval:]] = llvm.bitcast %[[#bval]] : vector<4xi32> to vector<2xf64>
// CHECK: llvm.store %[[#bbval]], %[[#bbmem:]] : vector<2xf64>, !llvm.ptr

// Extract element.
int c = a[x];
// CHECK: %[[#T58:]] = llvm.load %[[#T3]] : !llvm.ptr -> vector<4xi32>
Expand Down