[mlir] Add encoding attribute to VectorType. (#99029)
Open
Conversation
RankedTensorType already has an encoding attribute. Adding it to VectorType as well.
Member
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core Author: Alexander Belyaev (pifon2a). Changes: RankedTensorType already has an encoding attribute. Adding it to VectorType as well. Full diff: https://github.com/llvm/llvm-project/pull/99029.diff — 6 files affected:
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 5579b138668d2..564a27e01240c 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -307,12 +307,13 @@ class VectorType::Builder {
/// Build from another VectorType.
explicit Builder(VectorType other)
: elementType(other.getElementType()), shape(other.getShape()),
- scalableDims(other.getScalableDims()) {}
+ scalableDims(other.getScalableDims()), encoding(other.getEncoding()) {}
/// Build from scratch.
Builder(ArrayRef<int64_t> shape, Type elementType,
- ArrayRef<bool> scalableDims = {})
- : elementType(elementType), shape(shape), scalableDims(scalableDims) {}
+ ArrayRef<bool> scalableDims = {}, Attribute encoding = nullptr)
+ : elementType(elementType), shape(shape), scalableDims(scalableDims),
+ encoding(encoding) {}
Builder &setShape(ArrayRef<int64_t> newShape,
ArrayRef<bool> newIsScalableDim = {}) {
@@ -342,14 +343,20 @@ class VectorType::Builder {
return *this;
}
+ Builder &setEncoding(Attribute newEncoding) {
+ encoding = newEncoding;
+ return *this;
+ }
+
operator VectorType() {
- return VectorType::get(shape, elementType, scalableDims);
+ return VectorType::get(shape, elementType, scalableDims, encoding);
}
private:
Type elementType;
CopyOnWriteArrayRef<int64_t> shape;
CopyOnWriteArrayRef<bool> scalableDims;
+ Attribute encoding;
};
/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 4cade83dd3c32..7edc8d228b340 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1060,6 +1060,7 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
vector-dim-list := (static-dim-list `x`)?
static-dim-list ::= static-dim (`x` static-dim)*
static-dim ::= (decimal-literal | `[` decimal-literal `]`)
+ encoding ::= attribute-value
```
The vector type represents a SIMD style vector used by target-specific
@@ -1072,6 +1073,10 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
Vector shapes must be positive decimal integers. 0D vectors are allowed by
omitting the dimension: `vector<f32>`.
+ The `encoding` attribute provides additional information on the vector.
+ An empty attribute denotes a straightforward vector without any specific
+ structure.
+
Note: hexadecimal integer literals are not allowed in vector type
declarations, `vector<0x42xi32>` is invalid because it is interpreted as a
2D vector with shape `(0, 42)` and zero shapes are not allowed.
@@ -1094,17 +1099,22 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
// A 3D mixed fixed/scalable vector in which only the inner dimension is
// scalable.
vector<2x[4]x8xf32>
+
+ // Vector with an encoding attribute (where #ENCODING is a named alias).
+ vector<4x2xf64, #ENCODING>
```
}];
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
"Type":$elementType,
- ArrayRefParameter<"bool">:$scalableDims
+ ArrayRefParameter<"bool">:$scalableDims,
+ "Attribute":$encoding
);
let builders = [
TypeBuilderWithInferredContext<(ins
"ArrayRef<int64_t>":$shape, "Type":$elementType,
- CArg<"ArrayRef<bool>", "{}">:$scalableDims
+ CArg<"ArrayRef<bool>", "{}">:$scalableDims,
+ CArg<"Attribute", "{}">:$encoding
), [{
// While `scalableDims` is optional, its default value should be
// `false` for every dim in `shape`.
@@ -1113,7 +1123,8 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
isScalableVec.resize(shape.size(), false);
scalableDims = isScalableVec;
}
- return $_get(elementType.getContext(), shape, elementType, scalableDims);
+ auto ctx = elementType.getContext();
+ return $_get(ctx, shape, elementType, scalableDims, encoding);
}]>
];
let extraClassDeclaration = [{
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 0b46c96bbc04d..8d38d5d1c4dea 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -448,6 +448,7 @@ Type Parser::parseTupleType() {
/// vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`
/// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)?
/// static-dim-list ::= decimal-literal (`x` decimal-literal)*
+/// encoding ::= attribute-value
///
VectorType Parser::parseVectorType() {
consumeToken(Token::kw_vector);
@@ -467,14 +468,29 @@ VectorType Parser::parseVectorType() {
// Parse the element type.
auto typeLoc = getToken().getLoc();
auto elementType = parseType();
+
+ // Parse an optional encoding attribute.
+ Attribute encoding;
+ if (consumeIf(Token::comma)) {
+ auto parseResult = parseOptionalAttribute(encoding);
+ if (parseResult.has_value()) {
+ if (failed(parseResult.value()))
+ return nullptr;
+ if (auto v = dyn_cast_or_null<VerifiableTensorEncoding>(encoding)) {
+ if (failed(v.verifyEncoding(dimensions, elementType,
+ [&] { return emitError(); })))
+ return nullptr;
+ }
+ }
+ }
+
if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
return nullptr;
-
if (!VectorType::isValidElementType(elementType))
return emitError(typeLoc, "vector elements must be int/index/float type"),
nullptr;
- return VectorType::get(dimensions, elementType, scalableDims);
+ return VectorType::get(dimensions, elementType, scalableDims, encoding);
}
/// Parse a dimension list in a vector type. This populates the dimension list.
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 13eb18036eeec..f336ca061a55b 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2622,6 +2622,11 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
os << 'x';
}
printType(vectorTy.getElementType());
+ // Only print the encoding attribute value if set.
+ if (vectorTy.getEncoding()) {
+ os << ", ";
+ printAttribute(vectorTy.getEncoding());
+ }
os << '>';
})
.Case<RankedTensorType>([&](RankedTensorType tensorTy) {
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 179797cb943a1..b15a35d1e4126 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -227,7 +227,8 @@ LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
- ArrayRef<bool> scalableDims) {
+ ArrayRef<bool> scalableDims,
+ Attribute encoding) {
if (!isValidElementType(elementType))
return emitError()
<< "vector elements must be int/index/float type but got "
@@ -242,6 +243,9 @@ LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
return emitError() << "number of dims must match, got "
<< scalableDims.size() << " and " << shape.size();
+ if (auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
+ if (failed(v.verifyEncoding(shape, elementType, emitError)))
+ return failure();
return success();
}
@@ -260,7 +264,7 @@ VectorType VectorType::scaleElementBitwidth(unsigned scale) {
VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const {
return VectorType::get(shape.value_or(getShape()), elementType,
- getScalableDims());
+ getScalableDims(), getEncoding());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index cace1fefa43d6..57ccbd6c02da5 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -73,6 +73,9 @@ func.func private @float_types(f80, f128)
// CHECK: func private @vectors(vector<f32>, vector<1xf32>, vector<2x4xf32>)
func.func private @vectors(vector<f32>, vector<1 x f32>, vector<2x4xf32>)
+// CHECK: func private @vector_encoding(vector<16x32xf64, "indexed">)
+func.func private @vector_encoding(vector<16x32xf64, "indexed">)
+
// CHECK: func private @tensors(tensor<*xf32>, tensor<*xvector<2x4xf32>>, tensor<1x?x4x?x?xi32>, tensor<i8>)
func.func private @tensors(tensor<* x f32>, tensor<* x vector<2x4xf32>>,
tensor<1x?x4x?x?xi32>, tensor<i8>)
|
MaheshRavishankar
requested changes
Jul 16, 2024
Contributor
MaheshRavishankar
left a comment
There was a problem hiding this comment.
I am for this change, but there was a large discussion on Discourse about whether this should happen or not. I think it would be worthwhile signaling this change on that Discourse thread.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit. This suggestion is invalid because no changes were made to the code. Suggestions cannot be applied while the pull request is closed. Suggestions cannot be applied while viewing a subset of changes. Only one suggestion per line can be applied in a batch. Add this suggestion to a batch that can be applied as a single commit. Applying suggestions on deleted lines is not supported. You must change the existing code in this line in order to create a valid suggestion. Outdated suggestions cannot be applied. This suggestion has been applied or marked resolved. Suggestions cannot be applied from pending reviews. Suggestions cannot be applied on multi-line comments. Suggestions cannot be applied while the pull request is queued to merge. Suggestion cannot be applied right now. Please check back later.
RankedTensorType already has an encoding attribute. Adding it to VectorType as well.