Skip to content

[mlir] Add encoding attribute to VectorType. #99029

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

pifon2a
Copy link
Contributor

@pifon2a pifon2a commented Jul 16, 2024

RankedTensorType already has encoding attribute. Adding it to VectorType as well.

RankedTensorType already has encoding attribute. Adding it to VectorType as
well.
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir mlir:ods labels Jul 16, 2024
@llvmbot
Copy link
Member

llvmbot commented Jul 16, 2024

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-ods

@llvm/pr-subscribers-mlir-core

Author: Alexander Belyaev (pifon2a)

Changes

RankedTensorType already has encoding attribute. Adding it to VectorType as well.


Full diff: https://github.com/llvm/llvm-project/pull/99029.diff

6 Files Affected:

  • (modified) mlir/include/mlir/IR/BuiltinTypes.h (+11-4)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+14-3)
  • (modified) mlir/lib/AsmParser/TypeParser.cpp (+18-2)
  • (modified) mlir/lib/IR/AsmPrinter.cpp (+5)
  • (modified) mlir/lib/IR/BuiltinTypes.cpp (+6-2)
  • (modified) mlir/test/IR/parser.mlir (+3)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 5579b138668d2..564a27e01240c 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -307,12 +307,13 @@ class VectorType::Builder {
   /// Build from another VectorType.
   explicit Builder(VectorType other)
       : elementType(other.getElementType()), shape(other.getShape()),
-        scalableDims(other.getScalableDims()) {}
+        scalableDims(other.getScalableDims()), encoding(other.getEncoding()) {}
 
   /// Build from scratch.
   Builder(ArrayRef<int64_t> shape, Type elementType,
-          ArrayRef<bool> scalableDims = {})
-      : elementType(elementType), shape(shape), scalableDims(scalableDims) {}
+          ArrayRef<bool> scalableDims = {}, Attribute encoding = nullptr)
+      : elementType(elementType), shape(shape), scalableDims(scalableDims),
+        encoding(encoding) {}
 
   Builder &setShape(ArrayRef<int64_t> newShape,
                     ArrayRef<bool> newIsScalableDim = {}) {
@@ -342,14 +343,20 @@ class VectorType::Builder {
     return *this;
   }
 
+  Builder &setEncoding(Attribute newEncoding) {
+    encoding = newEncoding;
+    return *this;
+  }
+
   operator VectorType() {
-    return VectorType::get(shape, elementType, scalableDims);
+    return VectorType::get(shape, elementType, scalableDims, encoding);
   }
 
 private:
   Type elementType;
   CopyOnWriteArrayRef<int64_t> shape;
   CopyOnWriteArrayRef<bool> scalableDims;
+  Attribute encoding;
 };
 
 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 4cade83dd3c32..7edc8d228b340 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1060,6 +1060,7 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
     vector-dim-list := (static-dim-list `x`)?
     static-dim-list ::= static-dim (`x` static-dim)*
     static-dim ::= (decimal-literal | `[` decimal-literal `]`)
+    encoding ::= attribute-value
     ```
 
     The vector type represents a SIMD style vector used by target-specific
@@ -1072,6 +1073,10 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
     Vector shapes must be positive decimal integers. 0D vectors are allowed by
     omitting the dimension: `vector<f32>`.
 
+    The `encoding` attribute provides additional information on the vector.
+    An empty attribute denotes a straightforward vector without any specific
+    structure.
+
     Note: hexadecimal integer literals are not allowed in vector type
     declarations, `vector<0x42xi32>` is invalid because it is interpreted as a
     2D vector with shape `(0, 42)` and zero shapes are not allowed.
@@ -1094,17 +1099,22 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
     // A 3D mixed fixed/scalable vector in which only the inner dimension is
     // scalable.
     vector<2x[4]x8xf32>
+
+    // Vector with an encoding attribute (where #ENCODING is a named alias).
+    vector<4x2xf64, #ENCODING>
     ```
   }];
   let parameters = (ins
     ArrayRefParameter<"int64_t">:$shape,
     "Type":$elementType,
-    ArrayRefParameter<"bool">:$scalableDims
+    ArrayRefParameter<"bool">:$scalableDims,
+    "Attribute":$encoding
   );
   let builders = [
     TypeBuilderWithInferredContext<(ins
       "ArrayRef<int64_t>":$shape, "Type":$elementType,
-      CArg<"ArrayRef<bool>", "{}">:$scalableDims
+      CArg<"ArrayRef<bool>", "{}">:$scalableDims,
+      CArg<"Attribute", "{}">:$encoding
     ), [{
       // While `scalableDims` is optional, its default value should be
       // `false` for every dim in `shape`.
@@ -1113,7 +1123,8 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
         isScalableVec.resize(shape.size(), false);
         scalableDims = isScalableVec;
       }
-      return $_get(elementType.getContext(), shape, elementType, scalableDims);
+      auto ctx = elementType.getContext();
+      return $_get(ctx, shape, elementType, scalableDims, encoding);
     }]>
   ];
   let extraClassDeclaration = [{
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 0b46c96bbc04d..8d38d5d1c4dea 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -448,6 +448,7 @@ Type Parser::parseTupleType() {
 /// vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`
 /// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)?
 /// static-dim-list ::= decimal-literal (`x` decimal-literal)*
+/// encoding ::= attribute-value
 ///
 VectorType Parser::parseVectorType() {
   consumeToken(Token::kw_vector);
@@ -467,14 +468,29 @@ VectorType Parser::parseVectorType() {
   // Parse the element type.
   auto typeLoc = getToken().getLoc();
   auto elementType = parseType();
+
+  // Parse an optional encoding attribute.
+  Attribute encoding;
+  if (consumeIf(Token::comma)) {
+    auto parseResult = parseOptionalAttribute(encoding);
+    if (parseResult.has_value()) {
+      if (failed(parseResult.value()))
+        return nullptr;
+      if (auto v = dyn_cast_or_null<VerifiableTensorEncoding>(encoding)) {
+        if (failed(v.verifyEncoding(dimensions, elementType,
+                                    [&] { return emitError(); })))
+          return nullptr;
+      }
+    }
+  }
+
   if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
     return nullptr;
-
   if (!VectorType::isValidElementType(elementType))
     return emitError(typeLoc, "vector elements must be int/index/float type"),
            nullptr;
 
-  return VectorType::get(dimensions, elementType, scalableDims);
+  return VectorType::get(dimensions, elementType, scalableDims, encoding);
 }
 
 /// Parse a dimension list in a vector type. This populates the dimension list.
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 13eb18036eeec..f336ca061a55b 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2622,6 +2622,11 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
           os << 'x';
         }
         printType(vectorTy.getElementType());
+        // Only print the encoding attribute value if set.
+        if (vectorTy.getEncoding()) {
+          os << ", ";
+          printAttribute(vectorTy.getEncoding());
+        }
         os << '>';
       })
       .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 179797cb943a1..b15a35d1e4126 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -227,7 +227,8 @@ LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
 
 LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  ArrayRef<int64_t> shape, Type elementType,
-                                 ArrayRef<bool> scalableDims) {
+                                 ArrayRef<bool> scalableDims,
+                                 Attribute encoding) {
   if (!isValidElementType(elementType))
     return emitError()
            << "vector elements must be int/index/float type but got "
@@ -242,6 +243,9 @@ LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
     return emitError() << "number of dims must match, got "
                        << scalableDims.size() << " and " << shape.size();
 
+  if (auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
+    if (failed(v.verifyEncoding(shape, elementType, emitError)))
+      return failure();
   return success();
 }
 
@@ -260,7 +264,7 @@ VectorType VectorType::scaleElementBitwidth(unsigned scale) {
 VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
                                  Type elementType) const {
   return VectorType::get(shape.value_or(getShape()), elementType,
-                         getScalableDims());
+                         getScalableDims(), getEncoding());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index cace1fefa43d6..57ccbd6c02da5 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -73,6 +73,9 @@ func.func private @float_types(f80, f128)
 // CHECK: func private @vectors(vector<f32>, vector<1xf32>, vector<2x4xf32>)
 func.func private @vectors(vector<f32>, vector<1 x f32>, vector<2x4xf32>)
 
+// CHECK: func private @vector_encoding(vector<16x32xf64, "indexed">)
+func.func private @vector_encoding(vector<16x32xf64, "indexed">)
+
 // CHECK: func private @tensors(tensor<*xf32>, tensor<*xvector<2x4xf32>>, tensor<1x?x4x?x?xi32>, tensor<i8>)
 func.func private @tensors(tensor<* x f32>, tensor<* x vector<2x4xf32>>,
                  tensor<1x?x4x?x?xi32>, tensor<i8>)

@pifon2a pifon2a requested a review from ftynse July 16, 2024 12:54
Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am for this change, but there was a large discussion on whether this should happen on or not on discourse. I think it would be worthwhile signalling that discourse thread.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:ods mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants