Skip to content

[mlir][LLVM] Delete getFixedVectorType and getScalableVectorType #135051

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions mlir/docs/Dialects/LLVM.md
Original file line number Diff line number Diff line change
Expand Up @@ -336,10 +336,6 @@ compatible with the LLVM dialect:
vector type compatible with the LLVM dialect;
- `llvm::ElementCount LLVM::getVectorNumElements(Type)` - returns the number
of elements in any vector type compatible with the LLVM dialect;
- `Type LLVM::getFixedVectorType(Type, unsigned)` - gets a fixed vector type
with the given element type and size; the resulting type is either a
built-in or an LLVM dialect vector type depending on which one supports the
given element type.

#### Examples of Compatible Vector Types

Expand Down
8 changes: 0 additions & 8 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,6 @@ Type getVectorType(Type elementType, unsigned numElements,
/// and length.
Type getVectorType(Type elementType, const llvm::ElementCount &numElements);

/// Creates an LLVM dialect-compatible type with the given element type and
/// length.
Type getFixedVectorType(Type elementType, unsigned numElements);

/// Creates an LLVM dialect-compatible type with the given element type and
/// length.
Type getScalableVectorType(Type elementType, unsigned numElements);

/// Returns the size of the given primitive LLVM dialect-compatible type
/// (including vectors) in bits, for example, the size of i16 is 16 and
/// the size of vector<4xi16> is 64. Returns 0 for non-primitive
Expand Down
33 changes: 16 additions & 17 deletions mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,13 @@ static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
static Type inferIntrinsicResultType(Type vectorResultType) {
MLIRContext *ctx = vectorResultType.getContext();
auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
auto f16x2Ty = VectorType::get(2, Float16Type::get(ctx));
auto i32Ty = IntegerType::get(ctx, 32);
auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
auto i32x2Ty = VectorType::get(2, i32Ty);
Type f64Ty = Float64Type::get(ctx);
Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
Type f64x2Ty = VectorType::get(2, f64Ty);
Type f32Ty = Float32Type::get(ctx);
Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
Type f32x2Ty = VectorType::get(2, f32Ty);
if (a.getElementType() == f16x2Ty) {
return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
Expand All @@ -85,7 +85,7 @@ static Type inferIntrinsicResultType(Type vectorResultType) {
ctx,
SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
}
if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) {
if (a.getElementType() == VectorType::get(1, f32Ty)) {
return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
}
Expand All @@ -106,11 +106,11 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
Type i32Ty = rewriter.getI32Type();
Type f32Ty = rewriter.getF32Type();
Type f64Ty = rewriter.getF64Type();
Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
Type f16x2Ty = VectorType::get(2, rewriter.getF16Type());
Type i32x2Ty = VectorType::get(2, i32Ty);
Type f64x2Ty = VectorType::get(2, f64Ty);
Type f32x2Ty = VectorType::get(2, f32Ty);
Type f32x1Ty = VectorType::get(1, f32Ty);

auto makeConst = [&](int32_t index) -> Value {
return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
Expand Down Expand Up @@ -181,9 +181,9 @@ static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
Type f64Ty = b.getF64Type();
Type f32Ty = b.getF32Type();
Type i64Ty = b.getI64Type();
Type i8x4Ty = LLVM::getFixedVectorType(b.getI8Type(), 4);
Type i4x8Ty = LLVM::getFixedVectorType(b.getIntegerType(4), 8);
Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
Type i8x4Ty = VectorType::get(4, b.getI8Type());
Type i4x8Ty = VectorType::get(8, b.getIntegerType(4));
Type f32x1Ty = VectorType::get(1, f32Ty);
auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());

for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
Expand Down Expand Up @@ -268,8 +268,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
if (!vectorResultType) {
return failure();
}
Type innerVectorType = LLVM::getFixedVectorType(
vectorResultType.getElementType(), vectorResultType.getDimSize(1));
Type innerVectorType = VectorType::get(vectorResultType.getDimSize(1),
vectorResultType.getElementType());

int64_t num32BitRegs = vectorResultType.getDimSize(0);

Expand Down Expand Up @@ -627,8 +627,7 @@ struct NVGPUMmaSparseSyncLowering

// Bitcast the sparse metadata from vector<2xi16> to an i32.
Value sparseMetadata = adaptor.getSparseMetadata();
if (sparseMetadata.getType() !=
LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
if (sparseMetadata.getType() != VectorType::get(2, rewriter.getI16Type()))
return op->emitOpError() << "Expected metadata type to be LLVM "
"VectorType of 2 i16 elements";
sparseMetadata =
Expand Down
12 changes: 0 additions & 12 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -851,18 +851,6 @@ Type mlir::LLVM::getVectorType(Type elementType,
/*isScalable=*/false);
}

Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) {
assert(VectorType::isValidElementType(elementType) &&
"incompatible element type");
return VectorType::get(numElements, elementType);
}

Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned numElements) {
// LLVM vectors are always 1-D, hence only 1 bool is required to mark it as
// scalable/non-scalable.
return VectorType::get(numElements, elementType, /*scalableDims=*/true);
}

llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
assert(isCompatibleType(type) &&
"expected a type compatible with the LLVM dialect");
Expand Down
13 changes: 8 additions & 5 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ LogicalResult BulkStoreOp::verify() {
std::optional<mlir::NVVM::MMATypes>
MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
auto half2Type =
LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2);
VectorType::get(2, Float16Type::get(operandElType.getContext()));
if (operandElType.isF64())
return NVVM::MMATypes::f64;
if (operandElType.isF16() || operandElType == half2Type)
Expand Down Expand Up @@ -243,7 +243,8 @@ void MmaOp::print(OpAsmPrinter &p) {
p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);

// Print the types of the operands and result.
p << " : " << "(";
p << " : "
<< "(";
llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
frags[1].regs[0].getType(),
frags[2].regs[0].getType()},
Expand Down Expand Up @@ -404,7 +405,7 @@ LogicalResult MmaOp::verify() {
MLIRContext *context = getContext();
auto f16Ty = Float16Type::get(context);
auto i32Ty = IntegerType::get(context, 32);
auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
auto f16x2Ty = VectorType::get(2, f16Ty);
auto f32Ty = Float32Type::get(context);
auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
Expand Down Expand Up @@ -506,7 +507,7 @@ LogicalResult MmaOp::verify() {
expectedA.emplace_back(1, f64Ty);
expectedB.emplace_back(1, f64Ty);
expectedC.emplace_back(2, f64Ty);
// expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2));
// expectedC.emplace_back(1, VectorType::get(2, f64Ty));
expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
context, SmallVector<Type>(2, f64Ty)));
allowedShapes.push_back({8, 8, 4});
Expand Down Expand Up @@ -992,7 +993,9 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
ss << "},";
// Need to map read/write registers correctly.
regCnt = (regCnt * 2);
ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
ss << " $" << (regCnt) << ","
<< " $" << (regCnt + 1) << ","
<< " p";
if (getTypeD() != WGMMATypes::s32) {
ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
}
Expand Down
24 changes: 10 additions & 14 deletions mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,47 +103,43 @@ nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {

Type elType = type.vectorType.getElementType();
if (elType.isF16()) {
return FragmentElementInfo{
LLVM::getFixedVectorType(Float16Type::get(ctx), 2), 2, 32,
inferNumRegistersPerMatrixFragment(type)};
return FragmentElementInfo{VectorType::get(2, Float16Type::get(ctx)), 2, 32,
inferNumRegistersPerMatrixFragment(type)};
}

// f64 operand
Type f64Ty = Float64Type::get(ctx);
if (elType.isF64()) {
return isAccum
? FragmentElementInfo{LLVM::getFixedVectorType(f64Ty, 2), 2, 128,
? FragmentElementInfo{VectorType::get(2, f64Ty), 2, 128,
inferNumRegistersPerMatrixFragment(type)}
: FragmentElementInfo{f64Ty, 1, 64,
inferNumRegistersPerMatrixFragment(type)};
}

// int8 operand
if (elType.isInteger(8)) {
return FragmentElementInfo{
LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4, 32,
inferNumRegistersPerMatrixFragment(type)};
return FragmentElementInfo{VectorType::get(4, IntegerType::get(ctx, 8)), 4,
32, inferNumRegistersPerMatrixFragment(type)};
}

// int4 operand
if (elType.isInteger(4)) {
return FragmentElementInfo{
LLVM::getFixedVectorType(IntegerType::get(ctx, 4), 8), 8, 32,
inferNumRegistersPerMatrixFragment(type)};
return FragmentElementInfo{VectorType::get(8, IntegerType::get(ctx, 4)), 8,
32, inferNumRegistersPerMatrixFragment(type)};
}

// Integer 32bit acc operands
if (elType.isInteger(32)) {
return FragmentElementInfo{
LLVM::getFixedVectorType(IntegerType::get(ctx, 32), 2), 2, 64,
inferNumRegistersPerMatrixFragment(type)};
return FragmentElementInfo{VectorType::get(2, IntegerType::get(ctx, 32)), 2,
64, inferNumRegistersPerMatrixFragment(type)};
}

// Floating point 32bit operands
if (elType.isF32()) {
Type f32Ty = Float32Type::get(ctx);
return isAccum
? FragmentElementInfo{LLVM::getFixedVectorType(f32Ty, 2), 2, 64,
? FragmentElementInfo{VectorType::get(2, f32Ty), 2, 64,
inferNumRegistersPerMatrixFragment(type)}
: FragmentElementInfo{f32Ty, 1, 32,
inferNumRegistersPerMatrixFragment(type)};
Expand Down
9 changes: 5 additions & 4 deletions mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,15 @@ class TypeFromLLVMIRTranslatorImpl {

/// Translates the given fixed-vector type.
Type translate(llvm::FixedVectorType *type) {
return LLVM::getFixedVectorType(translateType(type->getElementType()),
type->getNumElements());
return VectorType::get(type->getNumElements(),
translateType(type->getElementType()));
}

/// Translates the given scalable-vector type.
Type translate(llvm::ScalableVectorType *type) {
return LLVM::getScalableVectorType(translateType(type->getElementType()),
type->getMinNumElements());
return VectorType::get(type->getMinNumElements(),
translateType(type->getElementType()),
/*scalableDims=*/true);
}

/// Translates the given target extension type.
Expand Down