Skip to content

[mlir][LLVM] Delete LLVMFixedVectorType and LLVMScalableVectorType #133286

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 3 additions & 16 deletions mlir/docs/Dialects/LLVM.md
Original file line number Diff line number Diff line change
Expand Up @@ -327,20 +327,7 @@ multiple of some fixed size in case of _scalable_ vectors, and the element type.
Vectors cannot be nested and only 1D vectors are supported. Scalable vectors are
still considered 1D.

LLVM dialect uses built-in vector types for _fixed_-size vectors of built-in
types, and provides additional types for fixed-sized vectors of LLVM dialect
types (`LLVMFixedVectorType`) and scalable vectors of any types
(`LLVMScalableVectorType`). These two additional types share the following
syntax:

```
llvm-vec-type ::= `!llvm.vec<` (`?` `x`)? integer-literal `x` type `>`
```

Note that the sets of element types supported by built-in and LLVM dialect
vector types are mutually exclusive, e.g., the built-in vector type does not
accept `!llvm.ptr` and the LLVM dialect fixed-width vector type does not
accept `i32`.
The LLVM dialect uses built-in vector type.

The following functions are provided to operate on any kind of the vector types
compatible with the LLVM dialect:
Expand All @@ -360,8 +347,8 @@ compatible with the LLVM dialect:

```mlir
vector<42 x i32> // Vector of 42 32-bit integers.
!llvm.vec<42 x ptr> // Vector of 42 pointers.
!llvm.vec<? x 4 x i32> // Scalable vector of 32-bit integers with
vector<42 x !llvm.ptr> // Vector of 42 pointers.
vector<[4] x i32> // Scalable vector of 32-bit integers with
// size divisible by 4.
!llvm.array<2 x vector<2 x i32>> // Array of 2 vectors of 2 32-bit integers.
!llvm.array<2 x vec<2 x ptr>> // Array of 2 vectors of 2 pointers.
Expand Down
1 change: 0 additions & 1 deletion mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ namespace LLVM {
}

DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType, "llvm.void");
DEFINE_TRIVIAL_LLVM_TYPE(LLVMPPCFP128Type, "llvm.ppc_fp128");
DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType, "llvm.token");
DEFINE_TRIVIAL_LLVM_TYPE(LLVMLabelType, "llvm.label");
DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType, "llvm.metadata");
Expand Down
77 changes: 13 additions & 64 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -288,70 +288,6 @@ def LLVMPointerType : LLVMType<"LLVMPointer", "ptr", [
];
}

//===----------------------------------------------------------------------===//
// LLVMFixedVectorType
//===----------------------------------------------------------------------===//

def LLVMFixedVectorType : LLVMType<"LLVMFixedVector", "vec"> {
let summary = "LLVM fixed vector type";
let description = [{
LLVM dialect vector type that supports all element types that are supported
in LLVM vectors but that are not supported by the builtin MLIR vector type.
E.g., LLVMFixedVectorType supports LLVM pointers as element type.
}];

let typeName = "llvm.fixed_vec";

let parameters = (ins "Type":$elementType, "unsigned":$numElements);
let assemblyFormat = [{
`<` $numElements `x` custom<PrettyLLVMType>($elementType) `>`
}];

let genVerifyDecl = 1;

let builders = [
TypeBuilderWithInferredContext<(ins "Type":$elementType,
"unsigned":$numElements)>
];

let extraClassDeclaration = [{
/// Checks if the given type can be used in a vector type.
static bool isValidElementType(Type type);
}];
}

//===----------------------------------------------------------------------===//
// LLVMScalableVectorType
//===----------------------------------------------------------------------===//

def LLVMScalableVectorType : LLVMType<"LLVMScalableVector", "vec"> {
let summary = "LLVM scalable vector type";
let description = [{
LLVM dialect scalable vector type, represents a sequence of elements of
unknown length that is known to be divisible by some constant. These
elements can be processed as one in SIMD context.
}];

let typeName = "llvm.scalable_vec";

let parameters = (ins "Type":$elementType, "unsigned":$minNumElements);
let assemblyFormat = [{
`<` `?` `x` $minNumElements `x` ` ` custom<PrettyLLVMType>($elementType) `>`
}];

let genVerifyDecl = 1;

let builders = [
TypeBuilderWithInferredContext<(ins "Type":$elementType,
"unsigned":$minNumElements)>
];

let extraClassDeclaration = [{
/// Checks if the given type can be used in a vector type.
static bool isValidElementType(Type type);
}];
}

//===----------------------------------------------------------------------===//
// LLVMTargetExtType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -400,4 +336,17 @@ def LLVMX86AMXType : LLVMType<"LLVMX86AMX", "x86_amx"> {
}];
}

//===----------------------------------------------------------------------===//
// LLVMPPCFP128Type
//===----------------------------------------------------------------------===//

def LLVMPPCFP128Type : LLVMType<"LLVMPPCFP128", "ppc_fp128",
[DeclareTypeInterfaceMethods<FloatTypeInterface, ["getFloatSemantics"]>]> {
let summary = "128 bit FP type with IBM double-double semantics";
let description = [{
A 128 bit floating-point type with IBM double-double semantics.
See S_PPCDoubleDouble in APFloat.h for details.
}];
}

#endif // LLVMTYPES_TD
64 changes: 25 additions & 39 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -685,10 +685,6 @@ GEPIndicesAdaptor<ValueRange> GEPOp::getIndices() {
static Type extractVectorElementType(Type type) {
if (auto vectorType = llvm::dyn_cast<VectorType>(type))
return vectorType.getElementType();
if (auto scalableVectorType = llvm::dyn_cast<LLVMScalableVectorType>(type))
return scalableVectorType.getElementType();
if (auto fixedVectorType = llvm::dyn_cast<LLVMFixedVectorType>(type))
return fixedVectorType.getElementType();
return type;
}

Expand Down Expand Up @@ -725,20 +721,18 @@ static void destructureIndices(Type currType, ArrayRef<GEPArg> indices,
if (rawConstantIndices.size() == 1 || !currType)
continue;

currType =
TypeSwitch<Type, Type>(currType)
.Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType,
LLVMArrayType>([](auto containerType) {
return containerType.getElementType();
})
.Case([&](LLVMStructType structType) -> Type {
int64_t memberIndex = rawConstantIndices.back();
if (memberIndex >= 0 && static_cast<size_t>(memberIndex) <
structType.getBody().size())
return structType.getBody()[memberIndex];
return nullptr;
})
.Default(Type(nullptr));
currType = TypeSwitch<Type, Type>(currType)
.Case<VectorType, LLVMArrayType>([](auto containerType) {
return containerType.getElementType();
})
.Case([&](LLVMStructType structType) -> Type {
int64_t memberIndex = rawConstantIndices.back();
if (memberIndex >= 0 && static_cast<size_t>(memberIndex) <
structType.getBody().size())
return structType.getBody()[memberIndex];
return nullptr;
})
.Default(Type(nullptr));
}
}

Expand Down Expand Up @@ -839,11 +833,11 @@ verifyStructIndices(Type baseGEPType, unsigned indexPos,
return verifyStructIndices(elementTypes[gepIndex], indexPos + 1,
indices, emitOpError);
})
.Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType,
LLVMArrayType>([&](auto containerType) -> LogicalResult {
return verifyStructIndices(containerType.getElementType(), indexPos + 1,
indices, emitOpError);
})
.Case<VectorType, LLVMArrayType>(
[&](auto containerType) -> LogicalResult {
return verifyStructIndices(containerType.getElementType(),
indexPos + 1, indices, emitOpError);
})
.Default([&](auto otherType) -> LogicalResult {
return emitOpError()
<< "type " << otherType << " cannot be indexed (index #"
Expand Down Expand Up @@ -3157,35 +3151,30 @@ OpFoldResult LLVM::ZeroOp::fold(FoldAdaptor) {
//===----------------------------------------------------------------------===//

/// Compute the total number of elements in the given type, also taking into
/// account nested types. Supported types are `VectorType`, `LLVMArrayType` and
/// `LLVMFixedVectorType`. Everything else is treated as a scalar.
/// account nested types. Supported types are `VectorType` and `LLVMArrayType`.
/// Everything else is treated as a scalar.
static int64_t getNumElements(Type t) {
if (auto vecType = dyn_cast<VectorType>(t))
if (auto vecType = dyn_cast<VectorType>(t)) {
assert(!vecType.isScalable() &&
"number of elements of a scalable vector type is unknown");
return vecType.getNumElements() * getNumElements(vecType.getElementType());
}
if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t))
return arrayType.getNumElements() *
getNumElements(arrayType.getElementType());
if (auto vecType = dyn_cast<LLVMFixedVectorType>(t))
return vecType.getNumElements() * getNumElements(vecType.getElementType());
assert(!isa<LLVM::LLVMScalableVectorType>(t) &&
"number of elements of a scalable vector type is unknown");
return 1;
}

/// Check if the given type is a scalable vector type or a vector/array type
/// that contains a nested scalable vector type.
static bool hasScalableVectorType(Type t) {
if (isa<LLVM::LLVMScalableVectorType>(t))
return true;
if (auto vecType = dyn_cast<VectorType>(t)) {
if (vecType.isScalable())
return true;
return hasScalableVectorType(vecType.getElementType());
}
if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t))
return hasScalableVectorType(arrayType.getElementType());
if (auto vecType = dyn_cast<LLVMFixedVectorType>(t))
return hasScalableVectorType(vecType.getElementType());
return false;
}

Expand Down Expand Up @@ -3265,8 +3254,7 @@ LogicalResult LLVM::ConstantOp::verify() {
<< "scalable vector type requires a splat attribute";
return success();
}
if (!isa<VectorType, LLVM::LLVMArrayType, LLVM::LLVMFixedVectorType>(
getType()))
if (!isa<VectorType, LLVM::LLVMArrayType>(getType()))
return emitOpError() << "expected vector or array type";
// The number of elements of the attribute and the type must match.
int64_t attrNumElements;
Expand Down Expand Up @@ -3515,8 +3503,7 @@ LogicalResult LLVM::BitcastOp::verify() {
if (!resultType)
return success();

auto isVector =
llvm::IsaPred<VectorType, LLVMScalableVectorType, LLVMFixedVectorType>;
auto isVector = llvm::IsaPred<VectorType>;

// Due to bitcast requiring both operands to be of the same size, it is not
// possible for only one of the two to be a pointer of vectors.
Expand Down Expand Up @@ -3982,7 +3969,6 @@ void LLVMDialect::initialize() {

// clang-format off
addTypes<LLVMVoidType,
LLVMPPCFP128Type,
LLVMTokenType,
LLVMLabelType,
LLVMMetadataType>();
Expand Down
6 changes: 0 additions & 6 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,6 @@ static bool isSupportedTypeForConversion(Type type) {
if (isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(type))
return false;

// LLVM vector types are only used for either pointers or target specific
// types. These types cannot be casted in the general case, thus the memory
// optimizations do not support them.
if (isa<LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType>(type))
return false;

if (auto vectorType = dyn_cast<VectorType>(type)) {
// Vectors of pointers cannot be casted.
if (isa<LLVM::LLVMPointerType>(vectorType.getElementType()))
Expand Down
44 changes: 1 addition & 43 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ static StringRef getTypeKeyword(Type type) {
.Case<LLVMMetadataType>([&](Type) { return "metadata"; })
.Case<LLVMFunctionType>([&](Type) { return "func"; })
.Case<LLVMPointerType>([&](Type) { return "ptr"; })
.Case<LLVMFixedVectorType, LLVMScalableVectorType>(
[&](Type) { return "vec"; })
.Case<LLVMArrayType>([&](Type) { return "array"; })
.Case<LLVMStructType>([&](Type) { return "struct"; })
.Case<LLVMTargetExtType>([&](Type) { return "target"; })
Expand Down Expand Up @@ -104,8 +102,7 @@ void mlir::LLVM::detail::printType(Type type, AsmPrinter &printer) {
printer << getTypeKeyword(type);

llvm::TypeSwitch<Type>(type)
.Case<LLVMPointerType, LLVMArrayType, LLVMFixedVectorType,
LLVMScalableVectorType, LLVMFunctionType, LLVMTargetExtType,
.Case<LLVMPointerType, LLVMArrayType, LLVMFunctionType, LLVMTargetExtType,
LLVMStructType>([&](auto type) { type.print(printer); });
}

Expand All @@ -115,44 +112,6 @@ void mlir::LLVM::detail::printType(Type type, AsmPrinter &printer) {

static ParseResult dispatchParse(AsmParser &parser, Type &type);

/// Parses an LLVM dialect vector type.
/// llvm-type ::= `vec<` `? x`? integer `x` llvm-type `>`
/// Supports both fixed and scalable vectors.
static Type parseVectorType(AsmParser &parser) {
SmallVector<int64_t, 2> dims;
SMLoc dimPos, typePos;
Type elementType;
SMLoc loc = parser.getCurrentLocation();
if (parser.parseLess() || parser.getCurrentLocation(&dimPos) ||
parser.parseDimensionList(dims, /*allowDynamic=*/true) ||
parser.getCurrentLocation(&typePos) ||
dispatchParse(parser, elementType) || parser.parseGreater())
return Type();

// We parsed a generic dimension list, but vectors only support two forms:
// - single non-dynamic entry in the list (fixed vector);
// - two elements, the first dynamic (indicated by ShapedType::kDynamic)
// and the second
// non-dynamic (scalable vector).
if (dims.empty() || dims.size() > 2 ||
((dims.size() == 2) ^ (ShapedType::isDynamic(dims[0]))) ||
(dims.size() == 2 && ShapedType::isDynamic(dims[1]))) {
parser.emitError(dimPos)
<< "expected '? x <integer> x <type>' or '<integer> x <type>'";
return Type();
}

bool isScalable = dims.size() == 2;
if (isScalable)
return parser.getChecked<LLVMScalableVectorType>(loc, elementType, dims[1]);
if (elementType.isSignlessIntOrFloat()) {
parser.emitError(typePos)
<< "cannot use !llvm.vec for built-in primitives, use 'vector' instead";
return Type();
}
return parser.getChecked<LLVMFixedVectorType>(loc, elementType, dims[0]);
}

/// Attempts to set the body of an identified structure type. Reports a parsing
/// error at `subtypesLoc` in case of failure.
static LLVMStructType trySetStructBody(LLVMStructType type,
Expand Down Expand Up @@ -311,7 +270,6 @@ static Type dispatchParse(AsmParser &parser, bool allowAny = true) {
.Case("metadata", [&] { return LLVMMetadataType::get(ctx); })
.Case("func", [&] { return LLVMFunctionType::parse(parser); })
.Case("ptr", [&] { return LLVMPointerType::parse(parser); })
.Case("vec", [&] { return parseVectorType(parser); })
.Case("array", [&] { return LLVMArrayType::parse(parser); })
.Case("struct", [&] { return LLVMStructType::parse(parser); })
.Case("target", [&] { return LLVMTargetExtType::parse(parser); })
Expand Down
Loading
Loading