diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp index 95844e1bb686..1af5628208da 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp @@ -1531,8 +1531,19 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) { } case CK_MatrixCast: llvm_unreachable("NYI"); - case CK_VectorSplat: - llvm_unreachable("NYI"); + case CK_VectorSplat: { + // Create a vector object and fill all elements with the same scalar value. + assert(DestTy->isVectorType() && "CK_VectorSplat to non-vector type"); + mlir::Value Value = Visit(E); + SmallVector Elements; + auto VecType = CGF.getCIRType(DestTy).dyn_cast(); + auto NumElements = VecType.getSize(); + for (uint64_t Index = 0; Index < NumElements; ++Index) { + Elements.push_back(Value); + } + return CGF.getBuilder().create( + CGF.getLoc(E->getSourceRange()), VecType, Elements); + } case CK_FixedPointCast: llvm_unreachable("NYI"); case CK_FixedPointToBoolean: @@ -1660,13 +1671,23 @@ mlir::Value ScalarExprEmitter::VisitInitListExpr(InitListExpr *E) { assert(!UnimplementedFeature::scalableVectors() && "NYI: scalable vector init"); assert(!UnimplementedFeature::vectorConstants() && "NYI: vector constants"); + auto VectorType = + CGF.getCIRType(E->getType()).dyn_cast(); SmallVector Elements; for (Expr *init : E->inits()) { Elements.push_back(Visit(init)); } + // Zero-initialize any remaining values. + if (NumInitElements < VectorType.getSize()) { + mlir::Value ZeroValue = CGF.getBuilder().create( + CGF.getLoc(E->getSourceRange()), VectorType.getEltType(), + CGF.getBuilder().getZeroInitAttr(VectorType.getEltType())); + for (uint64_t i = NumInitElements; i < VectorType.getSize(); ++i) { + Elements.push_back(ZeroValue); + } + } return CGF.getBuilder().create( - CGF.getLoc(E->getSourceRange()), CGF.getCIRType(E->getType()), - Elements); + CGF.getLoc(E->getSourceRange()), VectorType, Elements); } if (NumInitElements == 0) { diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index 1f51087b472e..8abf57dd942e 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -428,9 +428,12 @@ LogicalResult CastOp::verify() { return success(); } case cir::CastKind::bitcast: { - if (!srcType.dyn_cast() || - !resType.dyn_cast()) - return emitOpError() << "requires !cir.ptr type for source and result"; + if ((!srcType.isa() || + !resType.isa()) && + (!srcType.isa() || + !resType.isa())) + return emitOpError() + << "requires !cir.ptr or !cir.vector type for source and result"; return success(); } case cir::CastKind::floating: { diff --git a/clang/test/CIR/CodeGen/vectype.cpp b/clang/test/CIR/CodeGen/vectype.cpp index dc86b96abd1f..3be34dc9c0b4 100644 --- a/clang/test/CIR/CodeGen/vectype.cpp +++ b/clang/test/CIR/CodeGen/vectype.cpp @@ -15,6 +15,22 @@ void vector_int_test(int x) { vi4 b = { x, 5, 6, x + 1 }; // CHECK: %{{[0-9]+}} = cir.vec.create(%{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}} : !s32i, !s32i, !s32i, !s32i) : + // Incomplete vector initialization. + vi4 bb = { x, x + 1 }; + // CHECK: %[[#zero:]] = cir.const(#cir.int<0> : !s32i) : !s32i + // CHECK: %{{[0-9]+}} = cir.vec.create(%{{[0-9]+}}, %{{[0-9]+}}, %[[#zero]], %[[#zero]] : !s32i, !s32i, !s32i, !s32i) : + + // Scalar to vector conversion, a.k.a. vector splat. Only valid as an + // operand of a binary operator, not as a regular conversion. + bb = a + 7; + // CHECK: %[[#seven:]] = cir.const(#cir.int<7> : !s32i) : !s32i + // CHECK: %{{[0-9]+}} = cir.vec.create(%[[#seven]], %[[#seven]], %[[#seven]], %[[#seven]] : !s32i, !s32i, !s32i, !s32i) : + + // Vector to vector conversion + vd2 bbb = { }; + bb = (vi4)bbb; + // CHECK: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.vector), !cir.vector + // Extract element int c = a[x]; // CHECK: %{{[0-9]+}} = cir.vec.extract %{{[0-9]+}}[%{{[0-9]+}} : !s32i] : @@ -76,6 +92,17 @@ void vector_double_test(int x, double y) { vd2 b = { y, y + 1.0 }; // CHECK: %{{[0-9]+}} = cir.vec.create(%{{[0-9]+}}, %{{[0-9]+}} : !cir.double, !cir.double) : + // Incomplete vector initialization + vd2 bb = { y }; + // CHECK: [[#dzero:]] = cir.const(#cir.fp<0.000000e+00> : !cir.double) : !cir.double + // CHECK: %{{[0-9]+}} = cir.vec.create(%{{[0-9]+}}, %[[#dzero]] : !cir.double, !cir.double) : + + // Scalar to vector conversion, a.k.a. vector splat. Only valid as an + // operand of a binary operator, not as a regular conversion. + bb = a + 2.5; + // CHECK: %[[#twohalf:]] = cir.const(#cir.fp<2.500000e+00> : !cir.double) : !cir.double + // CHECK: %{{[0-9]+}} = cir.vec.create(%[[#twohalf]], %[[#twohalf]] : !cir.double, !cir.double) : + // Extract element double c = a[x]; // CHECK: %{{[0-9]+}} = cir.vec.extract %{{[0-9]+}}[%{{[0-9]+}} : !s32i] : diff --git a/clang/test/CIR/IR/invalid.cir b/clang/test/CIR/IR/invalid.cir index 2ff17558d866..8e5545adf1f9 100644 --- a/clang/test/CIR/IR/invalid.cir +++ b/clang/test/CIR/IR/invalid.cir @@ -141,7 +141,7 @@ cir.func @cast3(%p: !cir.ptr) { !u32i = !cir.int cir.func @cast4(%p: !cir.ptr) { - %2 = cir.cast(bitcast, %p : !cir.ptr), !u32i // expected-error {{requires !cir.ptr type for source and result}} + %2 = cir.cast(bitcast, %p : !cir.ptr), !u32i // expected-error {{requires !cir.ptr or !cir.vector type for source and result}} cir.return } diff --git a/clang/test/CIR/Lowering/vectype.cpp b/clang/test/CIR/Lowering/vectype.cpp index 34458427acb3..1139a815476b 100644 --- a/clang/test/CIR/Lowering/vectype.cpp +++ b/clang/test/CIR/Lowering/vectype.cpp @@ -45,6 +45,12 @@ void vector_int_test(int x) { // CHECK: %[[#T57:]] = llvm.insertelement %[[#T48]], %[[#T55]][%[[#T56]] : i64] : vector<4xi32> // CHECK: llvm.store %[[#T57]], %[[#T5:]] : vector<4xi32>, !llvm.ptr + // Vector to vector conversion + vd2 bb = (vd2)b; + // CHECK: %[[#bval:]] = llvm.load %[[#bmem:]] : !llvm.ptr -> vector<4xi32> + // CHECK: %[[#bbval:]] = llvm.bitcast %[[#bval]] : vector<4xi32> to vector<2xf64> + // CHECK: llvm.store %[[#bbval]], %[[#bbmem:]] : vector<2xf64>, !llvm.ptr + // Extract element. int c = a[x]; // CHECK: %[[#T58:]] = llvm.load %[[#T3]] : !llvm.ptr -> vector<4xi32>