-
Notifications
You must be signed in to change notification settings - Fork 13.3k
[mlir] Add encoding attribute to VectorType. #99029
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
RankedTensorType already has encoding attribute. Adding it to VectorType as well.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core Author: Alexander Belyaev (pifon2a) ChangesRankedTensorType already has encoding attribute. Adding it to VectorType as well. Full diff: https://github.com/llvm/llvm-project/pull/99029.diff 6 Files Affected:
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 5579b138668d2..564a27e01240c 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -307,12 +307,13 @@ class VectorType::Builder {
/// Build from another VectorType.
explicit Builder(VectorType other)
: elementType(other.getElementType()), shape(other.getShape()),
- scalableDims(other.getScalableDims()) {}
+ scalableDims(other.getScalableDims()), encoding(other.getEncoding()) {}
/// Build from scratch.
Builder(ArrayRef<int64_t> shape, Type elementType,
- ArrayRef<bool> scalableDims = {})
- : elementType(elementType), shape(shape), scalableDims(scalableDims) {}
+ ArrayRef<bool> scalableDims = {}, Attribute encoding = nullptr)
+ : elementType(elementType), shape(shape), scalableDims(scalableDims),
+ encoding(encoding) {}
Builder &setShape(ArrayRef<int64_t> newShape,
ArrayRef<bool> newIsScalableDim = {}) {
@@ -342,14 +343,20 @@ class VectorType::Builder {
return *this;
}
+ Builder &setEncoding(Attribute newEncoding) {
+ encoding = newEncoding;
+ return *this;
+ }
+
operator VectorType() {
- return VectorType::get(shape, elementType, scalableDims);
+ return VectorType::get(shape, elementType, scalableDims, encoding);
}
private:
Type elementType;
CopyOnWriteArrayRef<int64_t> shape;
CopyOnWriteArrayRef<bool> scalableDims;
+ Attribute encoding;
};
/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 4cade83dd3c32..7edc8d228b340 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1060,6 +1060,7 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
vector-dim-list := (static-dim-list `x`)?
static-dim-list ::= static-dim (`x` static-dim)*
static-dim ::= (decimal-literal | `[` decimal-literal `]`)
+ encoding ::= attribute-value
```
The vector type represents a SIMD style vector used by target-specific
@@ -1072,6 +1073,10 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
Vector shapes must be positive decimal integers. 0D vectors are allowed by
omitting the dimension: `vector<f32>`.
+ The `encoding` attribute provides additional information on the vector.
+ An empty attribute denotes a straightforward vector without any specific
+ structure.
+
Note: hexadecimal integer literals are not allowed in vector type
declarations, `vector<0x42xi32>` is invalid because it is interpreted as a
2D vector with shape `(0, 42)` and zero shapes are not allowed.
@@ -1094,17 +1099,22 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
// A 3D mixed fixed/scalable vector in which only the inner dimension is
// scalable.
vector<2x[4]x8xf32>
+
+ // Vector with an encoding attribute (where #ENCODING is a named alias).
+ vector<4x2xf64, #ENCODING>
```
}];
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
"Type":$elementType,
- ArrayRefParameter<"bool">:$scalableDims
+ ArrayRefParameter<"bool">:$scalableDims,
+ "Attribute":$encoding
);
let builders = [
TypeBuilderWithInferredContext<(ins
"ArrayRef<int64_t>":$shape, "Type":$elementType,
- CArg<"ArrayRef<bool>", "{}">:$scalableDims
+ CArg<"ArrayRef<bool>", "{}">:$scalableDims,
+ CArg<"Attribute", "{}">:$encoding
), [{
// While `scalableDims` is optional, its default value should be
// `false` for every dim in `shape`.
@@ -1113,7 +1123,8 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
isScalableVec.resize(shape.size(), false);
scalableDims = isScalableVec;
}
- return $_get(elementType.getContext(), shape, elementType, scalableDims);
+ auto ctx = elementType.getContext();
+ return $_get(ctx, shape, elementType, scalableDims, encoding);
}]>
];
let extraClassDeclaration = [{
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 0b46c96bbc04d..8d38d5d1c4dea 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -448,6 +448,7 @@ Type Parser::parseTupleType() {
/// vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`
/// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)?
/// static-dim-list ::= decimal-literal (`x` decimal-literal)*
+/// encoding ::= attribute-value
///
VectorType Parser::parseVectorType() {
consumeToken(Token::kw_vector);
@@ -467,14 +468,29 @@ VectorType Parser::parseVectorType() {
// Parse the element type.
auto typeLoc = getToken().getLoc();
auto elementType = parseType();
+
+ // Parse an optional encoding attribute.
+ Attribute encoding;
+ if (consumeIf(Token::comma)) {
+ auto parseResult = parseOptionalAttribute(encoding);
+ if (parseResult.has_value()) {
+ if (failed(parseResult.value()))
+ return nullptr;
+ if (auto v = dyn_cast_or_null<VerifiableTensorEncoding>(encoding)) {
+ if (failed(v.verifyEncoding(dimensions, elementType,
+ [&] { return emitError(); })))
+ return nullptr;
+ }
+ }
+ }
+
if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
return nullptr;
-
if (!VectorType::isValidElementType(elementType))
return emitError(typeLoc, "vector elements must be int/index/float type"),
nullptr;
- return VectorType::get(dimensions, elementType, scalableDims);
+ return VectorType::get(dimensions, elementType, scalableDims, encoding);
}
/// Parse a dimension list in a vector type. This populates the dimension list.
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 13eb18036eeec..f336ca061a55b 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2622,6 +2622,11 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
os << 'x';
}
printType(vectorTy.getElementType());
+ // Only print the encoding attribute value if set.
+ if (vectorTy.getEncoding()) {
+ os << ", ";
+ printAttribute(vectorTy.getEncoding());
+ }
os << '>';
})
.Case<RankedTensorType>([&](RankedTensorType tensorTy) {
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 179797cb943a1..b15a35d1e4126 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -227,7 +227,8 @@ LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
- ArrayRef<bool> scalableDims) {
+ ArrayRef<bool> scalableDims,
+ Attribute encoding) {
if (!isValidElementType(elementType))
return emitError()
<< "vector elements must be int/index/float type but got "
@@ -242,6 +243,9 @@ LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
return emitError() << "number of dims must match, got "
<< scalableDims.size() << " and " << shape.size();
+ if (auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
+ if (failed(v.verifyEncoding(shape, elementType, emitError)))
+ return failure();
return success();
}
@@ -260,7 +264,7 @@ VectorType VectorType::scaleElementBitwidth(unsigned scale) {
VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const {
return VectorType::get(shape.value_or(getShape()), elementType,
- getScalableDims());
+ getScalableDims(), getEncoding());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index cace1fefa43d6..57ccbd6c02da5 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -73,6 +73,9 @@ func.func private @float_types(f80, f128)
// CHECK: func private @vectors(vector<f32>, vector<1xf32>, vector<2x4xf32>)
func.func private @vectors(vector<f32>, vector<1 x f32>, vector<2x4xf32>)
+// CHECK: func private @vector_encoding(vector<16x32xf64, "indexed">)
+func.func private @vector_encoding(vector<16x32xf64, "indexed">)
+
// CHECK: func private @tensors(tensor<*xf32>, tensor<*xvector<2x4xf32>>, tensor<1x?x4x?x?xi32>, tensor<i8>)
func.func private @tensors(tensor<* x f32>, tensor<* x vector<2x4xf32>>,
tensor<1x?x4x?x?xi32>, tensor<i8>)
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am for this change, but there was a large discussion on whether this should happen on or not on discourse. I think it would be worthwhile signalling that discourse thread.
RankedTensorType already has encoding attribute. Adding it to VectorType as well.