-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][LLVM] Delete getFixedVectorType
and getScalableVectorType
#135051
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][LLVM] Delete getFixedVectorType
and getScalableVectorType
#135051
Conversation
8b0377d
to
120778f
Compare
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir-nvgpu Author: Matthias Springer (matthias-springer) Changes: The LLVM dialect no longer has its own vector types. It uses `mlir::VectorType` everywhere. Remove `LLVM::getFixedVectorType/getScalableVectorType` and use `VectorType::get` instead. Depends on #134981. Full diff: https://github.com/llvm/llvm-project/pull/135051.diff 7 Files Affected:
diff --git a/mlir/docs/Dialects/LLVM.md b/mlir/docs/Dialects/LLVM.md
index 468f69c419071..4b5d518ca4eab 100644
--- a/mlir/docs/Dialects/LLVM.md
+++ b/mlir/docs/Dialects/LLVM.md
@@ -336,10 +336,6 @@ compatible with the LLVM dialect:
vector type compatible with the LLVM dialect;
- `llvm::ElementCount LLVM::getVectorNumElements(Type)` - returns the number
of elements in any vector type compatible with the LLVM dialect;
-- `Type LLVM::getFixedVectorType(Type, unsigned)` - gets a fixed vector type
- with the given element type and size; the resulting type is either a
- built-in or an LLVM dialect vector type depending on which one supports the
- given element type.
#### Examples of Compatible Vector Types
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index a2a76c49a2bda..17561f79d135a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -126,14 +126,6 @@ Type getVectorType(Type elementType, unsigned numElements,
/// and length.
Type getVectorType(Type elementType, const llvm::ElementCount &numElements);
-/// Creates an LLVM dialect-compatible type with the given element type and
-/// length.
-Type getFixedVectorType(Type elementType, unsigned numElements);
-
-/// Creates an LLVM dialect-compatible type with the given element type and
-/// length.
-Type getScalableVectorType(Type elementType, unsigned numElements);
-
/// Returns the size of the given primitive LLVM dialect-compatible type
/// (including vectors) in bits, for example, the size of i16 is 16 and
/// the size of vector<4xi16> is 64. Returns 0 for non-primitive
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 51507c6507b69..69fa62c8196e4 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -61,13 +61,13 @@ static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
static Type inferIntrinsicResultType(Type vectorResultType) {
MLIRContext *ctx = vectorResultType.getContext();
auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
- auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
+ auto f16x2Ty = VectorType::get(2, Float16Type::get(ctx));
auto i32Ty = IntegerType::get(ctx, 32);
- auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
+ auto i32x2Ty = VectorType::get(2, i32Ty);
Type f64Ty = Float64Type::get(ctx);
- Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
+ Type f64x2Ty = VectorType::get(2, f64Ty);
Type f32Ty = Float32Type::get(ctx);
- Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
+ Type f32x2Ty = VectorType::get(2, f32Ty);
if (a.getElementType() == f16x2Ty) {
return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
@@ -85,7 +85,7 @@ static Type inferIntrinsicResultType(Type vectorResultType) {
ctx,
SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
}
- if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) {
+ if (a.getElementType() == VectorType::get(1, f32Ty)) {
return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
}
@@ -106,11 +106,11 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
Type i32Ty = rewriter.getI32Type();
Type f32Ty = rewriter.getF32Type();
Type f64Ty = rewriter.getF64Type();
- Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
- Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
- Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
- Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
- Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
+ Type f16x2Ty = VectorType::get(2, rewriter.getF16Type());
+ Type i32x2Ty = VectorType::get(2, i32Ty);
+ Type f64x2Ty = VectorType::get(2, f64Ty);
+ Type f32x2Ty = VectorType::get(2, f32Ty);
+ Type f32x1Ty = VectorType::get(1, f32Ty);
auto makeConst = [&](int32_t index) -> Value {
return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
@@ -181,9 +181,9 @@ static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
Type f64Ty = b.getF64Type();
Type f32Ty = b.getF32Type();
Type i64Ty = b.getI64Type();
- Type i8x4Ty = LLVM::getFixedVectorType(b.getI8Type(), 4);
- Type i4x8Ty = LLVM::getFixedVectorType(b.getIntegerType(4), 8);
- Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
+ Type i8x4Ty = VectorType::get(4, b.getI8Type());
+ Type i4x8Ty = VectorType::get(8, b.getIntegerType(4));
+ Type f32x1Ty = VectorType::get(1, f32Ty);
auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
@@ -268,8 +268,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
if (!vectorResultType) {
return failure();
}
- Type innerVectorType = LLVM::getFixedVectorType(
- vectorResultType.getElementType(), vectorResultType.getDimSize(1));
+ Type innerVectorType = VectorType::get(vectorResultType.getDimSize(1),
+ vectorResultType.getElementType());
int64_t num32BitRegs = vectorResultType.getDimSize(0);
@@ -627,8 +627,7 @@ struct NVGPUMmaSparseSyncLowering
// Bitcast the sparse metadata from vector<2xf16> to an i32.
Value sparseMetadata = adaptor.getSparseMetadata();
- if (sparseMetadata.getType() !=
- LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
+ if (sparseMetadata.getType() != VectorType::get(2, rewriter.getI16Type()))
return op->emitOpError() << "Expected metadata type to be LLVM "
"VectorType of 2 i16 elements";
sparseMetadata =
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index b3c2a29309528..29cf38c1fefea 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -851,18 +851,6 @@ Type mlir::LLVM::getVectorType(Type elementType,
/*isScalable=*/false);
}
-Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) {
- assert(VectorType::isValidElementType(elementType) &&
- "incompatible element type");
- return VectorType::get(numElements, elementType);
-}
-
-Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned numElements) {
- // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as
- // scalable/non-scalable.
- return VectorType::get(numElements, elementType, /*scalableDims=*/true);
-}
-
llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
assert(isCompatibleType(type) &&
"expected a type compatible with the LLVM dialect");
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 09bff6101edd3..b9d6952f67671 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -144,7 +144,7 @@ LogicalResult BulkStoreOp::verify() {
std::optional<mlir::NVVM::MMATypes>
MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
auto half2Type =
- LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2);
+ VectorType::get(2, Float16Type::get(operandElType.getContext()));
if (operandElType.isF64())
return NVVM::MMATypes::f64;
if (operandElType.isF16() || operandElType == half2Type)
@@ -243,7 +243,8 @@ void MmaOp::print(OpAsmPrinter &p) {
p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
// Print the types of the operands and result.
- p << " : " << "(";
+ p << " : "
+ << "(";
llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
frags[1].regs[0].getType(),
frags[2].regs[0].getType()},
@@ -404,7 +405,7 @@ LogicalResult MmaOp::verify() {
MLIRContext *context = getContext();
auto f16Ty = Float16Type::get(context);
auto i32Ty = IntegerType::get(context, 32);
- auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
+ auto f16x2Ty = VectorType::get(2, f16Ty);
auto f32Ty = Float32Type::get(context);
auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
@@ -506,7 +507,7 @@ LogicalResult MmaOp::verify() {
expectedA.emplace_back(1, f64Ty);
expectedB.emplace_back(1, f64Ty);
expectedC.emplace_back(2, f64Ty);
- // expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2));
+ // expectedC.emplace_back(1, VectorType::get(2, f64Ty));
expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
context, SmallVector<Type>(2, f64Ty)));
allowedShapes.push_back({8, 8, 4});
@@ -992,7 +993,9 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
ss << "},";
// Need to map read/write registers correctly.
regCnt = (regCnt * 2);
- ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
+ ss << " $" << (regCnt) << ","
+ << " $" << (regCnt + 1) << ","
+ << " p";
if (getTypeD() != WGMMATypes::s32) {
ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
}
@@ -1219,7 +1222,7 @@ llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
: CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)
#define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col) \
- [&]() -> auto { \
+ [&]() -> auto{ \
switch (dims) { \
case 1: \
return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile); \
@@ -1234,7 +1237,8 @@ llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
default: \
llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp."); \
} \
- }()
+ } \
+ ()
llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) {
@@ -1364,13 +1368,14 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
: TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)
#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta) \
- [&]() -> auto { \
+ [&]() -> auto{ \
if (src_fmt == Tcgen05CpSrcFormat::B6x16_P32) \
return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta); \
if (src_fmt == Tcgen05CpSrcFormat::B4x16_P64) \
return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta); \
return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
- }()
+ } \
+ ()
llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
auto curOp = cast<NVVM::Tcgen05CpOp>(op);
diff --git a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
index 39cca7d363e0d..e80360aa08ed5 100644
--- a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
+++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
@@ -103,16 +103,15 @@ nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {
Type elType = type.vectorType.getElementType();
if (elType.isF16()) {
- return FragmentElementInfo{
- LLVM::getFixedVectorType(Float16Type::get(ctx), 2), 2, 32,
- inferNumRegistersPerMatrixFragment(type)};
+ return FragmentElementInfo{VectorType::get(2, Float16Type::get(ctx)), 2, 32,
+ inferNumRegistersPerMatrixFragment(type)};
}
// f64 operand
Type f64Ty = Float64Type::get(ctx);
if (elType.isF64()) {
return isAccum
- ? FragmentElementInfo{LLVM::getFixedVectorType(f64Ty, 2), 2, 128,
+ ? FragmentElementInfo{VectorType::get(2, f64Ty), 2, 128,
inferNumRegistersPerMatrixFragment(type)}
: FragmentElementInfo{f64Ty, 1, 64,
inferNumRegistersPerMatrixFragment(type)};
@@ -120,30 +119,27 @@ nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {
// int8 operand
if (elType.isInteger(8)) {
- return FragmentElementInfo{
- LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4, 32,
- inferNumRegistersPerMatrixFragment(type)};
+ return FragmentElementInfo{VectorType::get(4, IntegerType::get(ctx, 8)), 4,
+ 32, inferNumRegistersPerMatrixFragment(type)};
}
// int4 operand
if (elType.isInteger(4)) {
- return FragmentElementInfo{
- LLVM::getFixedVectorType(IntegerType::get(ctx, 4), 8), 8, 32,
- inferNumRegistersPerMatrixFragment(type)};
+ return FragmentElementInfo{VectorType::get(8, IntegerType::get(ctx, 4)), 8,
+ 32, inferNumRegistersPerMatrixFragment(type)};
}
// Integer 32bit acc operands
if (elType.isInteger(32)) {
- return FragmentElementInfo{
- LLVM::getFixedVectorType(IntegerType::get(ctx, 32), 2), 2, 64,
- inferNumRegistersPerMatrixFragment(type)};
+ return FragmentElementInfo{VectorType::get(2, IntegerType::get(ctx, 32)), 2,
+ 64, inferNumRegistersPerMatrixFragment(type)};
}
// Floating point 32bit operands
if (elType.isF32()) {
Type f32Ty = Float32Type::get(ctx);
return isAccum
- ? FragmentElementInfo{LLVM::getFixedVectorType(f32Ty, 2), 2, 64,
+ ? FragmentElementInfo{VectorType::get(2, f32Ty), 2, 64,
inferNumRegistersPerMatrixFragment(type)}
: FragmentElementInfo{f32Ty, 1, 32,
inferNumRegistersPerMatrixFragment(type)};
diff --git a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
index bc9765fff2953..c46aa3e80d51a 100644
--- a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
+++ b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
@@ -124,14 +124,15 @@ class TypeFromLLVMIRTranslatorImpl {
/// Translates the given fixed-vector type.
Type translate(llvm::FixedVectorType *type) {
- return LLVM::getFixedVectorType(translateType(type->getElementType()),
- type->getNumElements());
+ return VectorType::get(type->getNumElements(),
+ translateType(type->getElementType()));
}
/// Translates the given scalable-vector type.
Type translate(llvm::ScalableVectorType *type) {
- return LLVM::getScalableVectorType(translateType(type->getElementType()),
- type->getMinNumElements());
+ return VectorType::get(type->getMinNumElements(),
+ translateType(type->getElementType()),
+ /*scalable=*/true);
}
/// Translates the given target extension type.
|
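@llvm/pr-subscribers-mlir-llvm Author: Matthias Springer (matthias-springer) Changes: The LLVM dialect no longer has its own vector types. It uses `mlir::VectorType` everywhere. Remove `LLVM::getFixedVectorType/getScalableVectorType` and use `VectorType::get` instead. Depends on #134981. Full diff: https://github.com/llvm/llvm-project/pull/135051.diff 7 Files Affected: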
diff --git a/mlir/docs/Dialects/LLVM.md b/mlir/docs/Dialects/LLVM.md
index 468f69c419071..4b5d518ca4eab 100644
--- a/mlir/docs/Dialects/LLVM.md
+++ b/mlir/docs/Dialects/LLVM.md
@@ -336,10 +336,6 @@ compatible with the LLVM dialect:
vector type compatible with the LLVM dialect;
- `llvm::ElementCount LLVM::getVectorNumElements(Type)` - returns the number
of elements in any vector type compatible with the LLVM dialect;
-- `Type LLVM::getFixedVectorType(Type, unsigned)` - gets a fixed vector type
- with the given element type and size; the resulting type is either a
- built-in or an LLVM dialect vector type depending on which one supports the
- given element type.
#### Examples of Compatible Vector Types
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index a2a76c49a2bda..17561f79d135a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -126,14 +126,6 @@ Type getVectorType(Type elementType, unsigned numElements,
/// and length.
Type getVectorType(Type elementType, const llvm::ElementCount &numElements);
-/// Creates an LLVM dialect-compatible type with the given element type and
-/// length.
-Type getFixedVectorType(Type elementType, unsigned numElements);
-
-/// Creates an LLVM dialect-compatible type with the given element type and
-/// length.
-Type getScalableVectorType(Type elementType, unsigned numElements);
-
/// Returns the size of the given primitive LLVM dialect-compatible type
/// (including vectors) in bits, for example, the size of i16 is 16 and
/// the size of vector<4xi16> is 64. Returns 0 for non-primitive
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 51507c6507b69..69fa62c8196e4 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -61,13 +61,13 @@ static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
static Type inferIntrinsicResultType(Type vectorResultType) {
MLIRContext *ctx = vectorResultType.getContext();
auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
- auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
+ auto f16x2Ty = VectorType::get(2, Float16Type::get(ctx));
auto i32Ty = IntegerType::get(ctx, 32);
- auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
+ auto i32x2Ty = VectorType::get(2, i32Ty);
Type f64Ty = Float64Type::get(ctx);
- Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
+ Type f64x2Ty = VectorType::get(2, f64Ty);
Type f32Ty = Float32Type::get(ctx);
- Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
+ Type f32x2Ty = VectorType::get(2, f32Ty);
if (a.getElementType() == f16x2Ty) {
return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
@@ -85,7 +85,7 @@ static Type inferIntrinsicResultType(Type vectorResultType) {
ctx,
SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
}
- if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) {
+ if (a.getElementType() == VectorType::get(1, f32Ty)) {
return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
}
@@ -106,11 +106,11 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
Type i32Ty = rewriter.getI32Type();
Type f32Ty = rewriter.getF32Type();
Type f64Ty = rewriter.getF64Type();
- Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
- Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
- Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
- Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
- Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
+ Type f16x2Ty = VectorType::get(2, rewriter.getF16Type());
+ Type i32x2Ty = VectorType::get(2, i32Ty);
+ Type f64x2Ty = VectorType::get(2, f64Ty);
+ Type f32x2Ty = VectorType::get(2, f32Ty);
+ Type f32x1Ty = VectorType::get(1, f32Ty);
auto makeConst = [&](int32_t index) -> Value {
return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
@@ -181,9 +181,9 @@ static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
Type f64Ty = b.getF64Type();
Type f32Ty = b.getF32Type();
Type i64Ty = b.getI64Type();
- Type i8x4Ty = LLVM::getFixedVectorType(b.getI8Type(), 4);
- Type i4x8Ty = LLVM::getFixedVectorType(b.getIntegerType(4), 8);
- Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
+ Type i8x4Ty = VectorType::get(4, b.getI8Type());
+ Type i4x8Ty = VectorType::get(8, b.getIntegerType(4));
+ Type f32x1Ty = VectorType::get(1, f32Ty);
auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
@@ -268,8 +268,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
if (!vectorResultType) {
return failure();
}
- Type innerVectorType = LLVM::getFixedVectorType(
- vectorResultType.getElementType(), vectorResultType.getDimSize(1));
+ Type innerVectorType = VectorType::get(vectorResultType.getDimSize(1),
+ vectorResultType.getElementType());
int64_t num32BitRegs = vectorResultType.getDimSize(0);
@@ -627,8 +627,7 @@ struct NVGPUMmaSparseSyncLowering
// Bitcast the sparse metadata from vector<2xf16> to an i32.
Value sparseMetadata = adaptor.getSparseMetadata();
- if (sparseMetadata.getType() !=
- LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
+ if (sparseMetadata.getType() != VectorType::get(2, rewriter.getI16Type()))
return op->emitOpError() << "Expected metadata type to be LLVM "
"VectorType of 2 i16 elements";
sparseMetadata =
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index b3c2a29309528..29cf38c1fefea 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -851,18 +851,6 @@ Type mlir::LLVM::getVectorType(Type elementType,
/*isScalable=*/false);
}
-Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) {
- assert(VectorType::isValidElementType(elementType) &&
- "incompatible element type");
- return VectorType::get(numElements, elementType);
-}
-
-Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned numElements) {
- // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as
- // scalable/non-scalable.
- return VectorType::get(numElements, elementType, /*scalableDims=*/true);
-}
-
llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
assert(isCompatibleType(type) &&
"expected a type compatible with the LLVM dialect");
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 09bff6101edd3..b9d6952f67671 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -144,7 +144,7 @@ LogicalResult BulkStoreOp::verify() {
std::optional<mlir::NVVM::MMATypes>
MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
auto half2Type =
- LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2);
+ VectorType::get(2, Float16Type::get(operandElType.getContext()));
if (operandElType.isF64())
return NVVM::MMATypes::f64;
if (operandElType.isF16() || operandElType == half2Type)
@@ -243,7 +243,8 @@ void MmaOp::print(OpAsmPrinter &p) {
p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
// Print the types of the operands and result.
- p << " : " << "(";
+ p << " : "
+ << "(";
llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
frags[1].regs[0].getType(),
frags[2].regs[0].getType()},
@@ -404,7 +405,7 @@ LogicalResult MmaOp::verify() {
MLIRContext *context = getContext();
auto f16Ty = Float16Type::get(context);
auto i32Ty = IntegerType::get(context, 32);
- auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
+ auto f16x2Ty = VectorType::get(2, f16Ty);
auto f32Ty = Float32Type::get(context);
auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
@@ -506,7 +507,7 @@ LogicalResult MmaOp::verify() {
expectedA.emplace_back(1, f64Ty);
expectedB.emplace_back(1, f64Ty);
expectedC.emplace_back(2, f64Ty);
- // expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2));
+ // expectedC.emplace_back(1, VectorType::get(2, f64Ty));
expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
context, SmallVector<Type>(2, f64Ty)));
allowedShapes.push_back({8, 8, 4});
@@ -992,7 +993,9 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
ss << "},";
// Need to map read/write registers correctly.
regCnt = (regCnt * 2);
- ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
+ ss << " $" << (regCnt) << ","
+ << " $" << (regCnt + 1) << ","
+ << " p";
if (getTypeD() != WGMMATypes::s32) {
ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
}
@@ -1219,7 +1222,7 @@ llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
: CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)
#define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col) \
- [&]() -> auto { \
+ [&]() -> auto{ \
switch (dims) { \
case 1: \
return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile); \
@@ -1234,7 +1237,8 @@ llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
default: \
llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp."); \
} \
- }()
+ } \
+ ()
llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) {
@@ -1364,13 +1368,14 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
: TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)
#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta) \
- [&]() -> auto { \
+ [&]() -> auto{ \
if (src_fmt == Tcgen05CpSrcFormat::B6x16_P32) \
return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta); \
if (src_fmt == Tcgen05CpSrcFormat::B4x16_P64) \
return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta); \
return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
- }()
+ } \
+ ()
llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
auto curOp = cast<NVVM::Tcgen05CpOp>(op);
diff --git a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
index 39cca7d363e0d..e80360aa08ed5 100644
--- a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
+++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
@@ -103,16 +103,15 @@ nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {
Type elType = type.vectorType.getElementType();
if (elType.isF16()) {
- return FragmentElementInfo{
- LLVM::getFixedVectorType(Float16Type::get(ctx), 2), 2, 32,
- inferNumRegistersPerMatrixFragment(type)};
+ return FragmentElementInfo{VectorType::get(2, Float16Type::get(ctx)), 2, 32,
+ inferNumRegistersPerMatrixFragment(type)};
}
// f64 operand
Type f64Ty = Float64Type::get(ctx);
if (elType.isF64()) {
return isAccum
- ? FragmentElementInfo{LLVM::getFixedVectorType(f64Ty, 2), 2, 128,
+ ? FragmentElementInfo{VectorType::get(2, f64Ty), 2, 128,
inferNumRegistersPerMatrixFragment(type)}
: FragmentElementInfo{f64Ty, 1, 64,
inferNumRegistersPerMatrixFragment(type)};
@@ -120,30 +119,27 @@ nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {
// int8 operand
if (elType.isInteger(8)) {
- return FragmentElementInfo{
- LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4, 32,
- inferNumRegistersPerMatrixFragment(type)};
+ return FragmentElementInfo{VectorType::get(4, IntegerType::get(ctx, 8)), 4,
+ 32, inferNumRegistersPerMatrixFragment(type)};
}
// int4 operand
if (elType.isInteger(4)) {
- return FragmentElementInfo{
- LLVM::getFixedVectorType(IntegerType::get(ctx, 4), 8), 8, 32,
- inferNumRegistersPerMatrixFragment(type)};
+ return FragmentElementInfo{VectorType::get(8, IntegerType::get(ctx, 4)), 8,
+ 32, inferNumRegistersPerMatrixFragment(type)};
}
// Integer 32bit acc operands
if (elType.isInteger(32)) {
- return FragmentElementInfo{
- LLVM::getFixedVectorType(IntegerType::get(ctx, 32), 2), 2, 64,
- inferNumRegistersPerMatrixFragment(type)};
+ return FragmentElementInfo{VectorType::get(2, IntegerType::get(ctx, 32)), 2,
+ 64, inferNumRegistersPerMatrixFragment(type)};
}
// Floating point 32bit operands
if (elType.isF32()) {
Type f32Ty = Float32Type::get(ctx);
return isAccum
- ? FragmentElementInfo{LLVM::getFixedVectorType(f32Ty, 2), 2, 64,
+ ? FragmentElementInfo{VectorType::get(2, f32Ty), 2, 64,
inferNumRegistersPerMatrixFragment(type)}
: FragmentElementInfo{f32Ty, 1, 32,
inferNumRegistersPerMatrixFragment(type)};
diff --git a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
index bc9765fff2953..c46aa3e80d51a 100644
--- a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
+++ b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
@@ -124,14 +124,15 @@ class TypeFromLLVMIRTranslatorImpl {
/// Translates the given fixed-vector type.
Type translate(llvm::FixedVectorType *type) {
- return LLVM::getFixedVectorType(translateType(type->getElementType()),
- type->getNumElements());
+ return VectorType::get(type->getNumElements(),
+ translateType(type->getElementType()));
}
/// Translates the given scalable-vector type.
Type translate(llvm::ScalableVectorType *type) {
- return LLVM::getScalableVectorType(translateType(type->getElementType()),
- type->getMinNumElements());
+ return VectorType::get(type->getMinNumElements(),
+ translateType(type->getElementType()),
+ /*scalable=*/true);
}
/// Translates the given target extension type.
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
120778f
to
0982387
Compare
0982387
to
ce792f0
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for the cleanup.
This addresses a comment on #135051.
I have tracked down a crash in the slp vectorizer to this PR. Unfortunately, it is deep inside some highly complex private code when optimizing with thinlto. I'm working on a reduced testcase, but until then, here is the stack trace of the failure.
|
Progress is slow going, but the failing intrinsic is llvm.umax, and it fails because it has no operands at all, which is why getOperandEntry asserts. |
@Sterling-Augustine, while the title talks about "Vector"-things, this is an MLIR-only change and I can't see how it can crash clang. I suspect your test is flaky or your bisection didn't go as expected. |
…lvm#135051) The LLVM dialect no longer has its own vector types. It uses `mlir::VectorType` everywhere. Remove `LLVM::getFixedVectorType/getScalableVectorType` and use `VectorType::get` instead. This commit addresses a [comment](llvm#133286 (comment)) on the PR that deleted the LLVM vector types.
This addresses a comment on llvm#135051.
…et (#1585) Use VectorType::get instead of the `getFixedVectorType` function because it was already removed in PR llvm/llvm-project#135051
The LLVM dialect no longer has its own vector types. It uses `mlir::VectorType` everywhere. Remove `LLVM::getFixedVectorType/getScalableVectorType` and use `VectorType::get` instead. This commit addresses a comment on the PR that deleted the LLVM vector types.