From 610046226127c4338bba1da1caf4a26cde20a334 Mon Sep 17 00:00:00 2001 From: David Olsen Date: Tue, 5 Mar 2024 11:31:16 -0800 Subject: [PATCH] [CIR] Vector types - part 4 (#490) This is part 4 of implementing vector types and vector operations in ClangIR, issue #284. This change has three small additions. Implement a "vector splat" conversion, which converts a scalar into vector, initializing all the elements of the vector with the scalar. Implement incomplete initialization of a vector, where the number of explicit initializers is less than the number of elements in the vector. The rest of the elements are implicitly zero initialized. Implement conversions between different vector types. The language rules require that the two types be the same size (in bytes, not necessarily in the number of elements). These conversions are always implemented with a bitcast. The first two changes only required changes to the AST -> ClangIR code gen. There are no changes to the ClangIR dialect, so no changes to the LLVM lowering were needed. The third part only required a change to a validation rule. The code to implement a vector bitcast was already present. The compiler just needed to stop rejecting it as invalid ClangIR. --- clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp | 29 +++++++++++++++++++--- clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 9 ++++--- clang/test/CIR/CodeGen/vectype.cpp | 27 ++++++++++++++++++++ clang/test/CIR/IR/invalid.cir | 2 +- clang/test/CIR/Lowering/vectype.cpp | 6 +++++ 5 files changed, 65 insertions(+), 8 deletions(-) diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp index 95844e1bb686..1af5628208da 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp @@ -1531,8 +1531,19 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) { } case CK_MatrixCast: llvm_unreachable("NYI"); - case CK_VectorSplat: - llvm_unreachable("NYI"); + case CK_VectorSplat: { + // Create a vector object and fill all elements with the same scalar value. + assert(DestTy->isVectorType() && "CK_VectorSplat to non-vector type"); + mlir::Value Value = Visit(E); + SmallVector Elements; + auto VecType = CGF.getCIRType(DestTy).dyn_cast(); + auto NumElements = VecType.getSize(); + for (uint64_t Index = 0; Index < NumElements; ++Index) { + Elements.push_back(Value); + } + return CGF.getBuilder().create( + CGF.getLoc(E->getSourceRange()), VecType, Elements); + } case CK_FixedPointCast: llvm_unreachable("NYI"); case CK_FixedPointToBoolean: @@ -1660,13 +1671,23 @@ mlir::Value ScalarExprEmitter::VisitInitListExpr(InitListExpr *E) { assert(!UnimplementedFeature::scalableVectors() && "NYI: scalable vector init"); assert(!UnimplementedFeature::vectorConstants() && "NYI: vector constants"); + auto VectorType = + CGF.getCIRType(E->getType()).dyn_cast(); SmallVector Elements; for (Expr *init : E->inits()) { Elements.push_back(Visit(init)); } + // Zero-initialize any remaining values. + if (NumInitElements < VectorType.getSize()) { + mlir::Value ZeroValue = CGF.getBuilder().create( + CGF.getLoc(E->getSourceRange()), VectorType.getEltType(), + CGF.getBuilder().getZeroInitAttr(VectorType.getEltType())); + for (uint64_t i = NumInitElements; i < VectorType.getSize(); ++i) { + Elements.push_back(ZeroValue); + } + } return CGF.getBuilder().create( - CGF.getLoc(E->getSourceRange()), CGF.getCIRType(E->getType()), - Elements); + CGF.getLoc(E->getSourceRange()), VectorType, Elements); } if (NumInitElements == 0) { diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index 1f51087b472e..8abf57dd942e 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -428,9 +428,12 @@ LogicalResult CastOp::verify() { return success(); } case cir::CastKind::bitcast: { - if (!srcType.dyn_cast() || - !resType.dyn_cast()) - return emitOpError() << "requires !cir.ptr type for source and result"; + if ((!srcType.isa() || + !resType.isa()) && + (!srcType.isa() || + !resType.isa())) + return emitOpError() + << "requires !cir.ptr or !cir.vector type for source and result"; return success(); } case cir::CastKind::floating: { diff --git a/clang/test/CIR/CodeGen/vectype.cpp b/clang/test/CIR/CodeGen/vectype.cpp index dc86b96abd1f..3be34dc9c0b4 100644 --- a/clang/test/CIR/CodeGen/vectype.cpp +++ b/clang/test/CIR/CodeGen/vectype.cpp @@ -15,6 +15,22 @@ void vector_int_test(int x) { vi4 b = { x, 5, 6, x + 1 }; // CHECK: %{{[0-9]+}} = cir.vec.create(%{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}} : !s32i, !s32i, !s32i, !s32i) : + // Incomplete vector initialization. + vi4 bb = { x, x + 1 }; + // CHECK: %[[#zero:]] = cir.const(#cir.int<0> : !s32i) : !s32i + // CHECK: %{{[0-9]+}} = cir.vec.create(%{{[0-9]+}}, %{{[0-9]+}}, %[[#zero]], %[[#zero]] : !s32i, !s32i, !s32i, !s32i) : + + // Scalar to vector conversion, a.k.a. vector splat. Only valid as an + // operand of a binary operator, not as a regular conversion. + bb = a + 7; + // CHECK: %[[#seven:]] = cir.const(#cir.int<7> : !s32i) : !s32i + // CHECK: %{{[0-9]+}} = cir.vec.create(%[[#seven]], %[[#seven]], %[[#seven]], %[[#seven]] : !s32i, !s32i, !s32i, !s32i) : + + // Vector to vector conversion + vd2 bbb = { }; + bb = (vi4)bbb; + // CHECK: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.vector), !cir.vector + // Extract element int c = a[x]; // CHECK: %{{[0-9]+}} = cir.vec.extract %{{[0-9]+}}[%{{[0-9]+}} : !s32i] : @@ -76,6 +92,17 @@ void vector_double_test(int x, double y) { vd2 b = { y, y + 1.0 }; // CHECK: %{{[0-9]+}} = cir.vec.create(%{{[0-9]+}}, %{{[0-9]+}} : !cir.double, !cir.double) : + // Incomplete vector initialization + vd2 bb = { y }; + // CHECK: [[#dzero:]] = cir.const(#cir.fp<0.000000e+00> : !cir.double) : !cir.double + // CHECK: %{{[0-9]+}} = cir.vec.create(%{{[0-9]+}}, %[[#dzero]] : !cir.double, !cir.double) : + + // Scalar to vector conversion, a.k.a. vector splat. Only valid as an + // operand of a binary operator, not as a regular conversion. + bb = a + 2.5; + // CHECK: %[[#twohalf:]] = cir.const(#cir.fp<2.500000e+00> : !cir.double) : !cir.double + // CHECK: %{{[0-9]+}} = cir.vec.create(%[[#twohalf]], %[[#twohalf]] : !cir.double, !cir.double) : + // Extract element double c = a[x]; // CHECK: %{{[0-9]+}} = cir.vec.extract %{{[0-9]+}}[%{{[0-9]+}} : !s32i] : diff --git a/clang/test/CIR/IR/invalid.cir b/clang/test/CIR/IR/invalid.cir index 2ff17558d866..8e5545adf1f9 100644 --- a/clang/test/CIR/IR/invalid.cir +++ b/clang/test/CIR/IR/invalid.cir @@ -141,7 +141,7 @@ cir.func @cast3(%p: !cir.ptr) { !u32i = !cir.int cir.func @cast4(%p: !cir.ptr) { - %2 = cir.cast(bitcast, %p : !cir.ptr), !u32i // expected-error {{requires !cir.ptr type for source and result}} + %2 = cir.cast(bitcast, %p : !cir.ptr), !u32i // expected-error {{requires !cir.ptr or !cir.vector type for source and result}} cir.return } diff --git a/clang/test/CIR/Lowering/vectype.cpp b/clang/test/CIR/Lowering/vectype.cpp index 34458427acb3..1139a815476b 100644 --- a/clang/test/CIR/Lowering/vectype.cpp +++ b/clang/test/CIR/Lowering/vectype.cpp @@ -45,6 +45,12 @@ void vector_int_test(int x) { // CHECK: %[[#T57:]] = llvm.insertelement %[[#T48]], %[[#T55]][%[[#T56]] : i64] : vector<4xi32> // CHECK: llvm.store %[[#T57]], %[[#T5:]] : vector<4xi32>, !llvm.ptr + // Vector to vector conversion + vd2 bb = (vd2)b; + // CHECK: %[[#bval:]] = llvm.load %[[#bmem:]] : !llvm.ptr -> vector<4xi32> + // CHECK: %[[#bbval:]] = llvm.bitcast %[[#bval]] : vector<4xi32> to vector<2xf64> + // CHECK: llvm.store %[[#bbval]], %[[#bbmem:]] : vector<2xf64>, !llvm.ptr + // Extract element. int c = a[x]; // CHECK: %[[#T58:]] = llvm.load %[[#T3]] : !llvm.ptr -> vector<4xi32>