[mlir][vector] Allow integer indices in vector.extract/insert ops #115808

Open
wants to merge 2 commits into
base: main
from

dcaballe
Contributor

`vector.extract` and `vector.insert` currently accept indices given either as
`i64` constants or as values of `index` type. The `index` type usually lowers to
an `i32` or `i64` type. However, the vector dimensions being indexed are often
small enough that narrower integers would suffice. This PR extends both
ops to accept values of any integer type as indices. For example:

```
  %0 = vector.extract %arg0[%i32_idx : i32] : vector<8x16xf32> from vector<4x8x16xf32>
  %1 = vector.extract %arg0[%i8_idx, %i8_idx : i8] : vector<16xf32> from vector<4x8x16xf32>
  %2 = vector.extract %arg0[%i8_idx, 5, %i8_idx : i8] : f32 from vector<4x8x16xf32>
```

This required changes to the ops' parser and printer. When at least one value index is provided,
a single index type is printed at the end of the index list, and all value indices must
have that type. When only constant indices are provided, no index type is printed.
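
The printing rule described above can be sketched in a few lines of standalone C++. This is only an illustration under stated assumptions, not the code in the patch: `printSameTypeIndexList` and `kDynamicSketch` are hypothetical names, SSA values and types are modeled as plain strings, and the sentinel stands in for the marker the real printer uses for dynamic positions.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical sentinel: marks a position supplied by a value rather than
// by a constant. It only plays the role of the dynamic marker in the patch.
constexpr int64_t kDynamicSketch = std::numeric_limits<int64_t>::min();

// Prints a mixed list of constant and value indices. The shared type of the
// value indices is printed once, at the end of the list, and only when at
// least one value index is present.
std::string printSameTypeIndexList(const std::vector<int64_t> &positions,
                                   const std::vector<std::string> &valueNames,
                                   const std::string &valueType) {
  std::ostringstream os;
  os << '[';
  std::size_t nextValue = 0;
  for (std::size_t i = 0; i < positions.size(); ++i) {
    if (i != 0)
      os << ", ";
    if (positions[i] == kDynamicSketch) {
      assert(nextValue < valueNames.size() && "missing value index name");
      os << valueNames[nextValue++];
    } else {
      os << positions[i];
    }
  }
  if (!valueNames.empty())
    os << " : " << valueType;
  os << ']';
  return os.str();
}
```

Under these assumptions, a list with two `i8` value indices around the constant `5` prints as `[%i8_idx, 5, %i8_idx : i8]`, while a list of constants only prints as `[2, 1, 3]` with no type.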
@llvmbot
Member

llvmbot commented Nov 12, 2024

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir-sme

Author: Diego Caballero (dcaballe)

Changes

vector.extract and vector.insert can currently take an i64 constant or an index type value as indices. The index type will usually lower to an i32 or i64 type. However, we are often indexing really small vector dimensions where smaller integers could be used. This PR extends both ops to accept any integer value as indices. For example:

  %0 = vector.extract %arg0[%i32_idx : i32] : vector<8x16xf32> from vector<4x8x16xf32>
  %1 = vector.extract %arg0[%i8_idx, %i8_idx : i8] : vector<16xf32> from vector<4x8x16xf32>
  %2 = vector.extract %arg0[%i8_idx, 5, %i8_idx : i8] : f32 from vector<4x8x16xf32>

This led to some changes to the ops' parser and printer. When a value index is provided, the index type is printed as part of the index list. All the value indices provided must match that type. When no value index is provided, no index type is printed.


Patch is 84.45 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/115808.diff

22 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+13-10)
  • (modified) mlir/include/mlir/IR/OpImplementation.h (+17-4)
  • (modified) mlir/include/mlir/Interfaces/ViewLikeInterface.h (+25-4)
  • (modified) mlir/lib/AsmParser/AsmParserImpl.h (+19-4)
  • (modified) mlir/lib/AsmParser/Parser.cpp (+7-4)
  • (modified) mlir/lib/AsmParser/Parser.h (+8-1)
  • (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+6-4)
  • (modified) mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (+4-4)
  • (modified) mlir/lib/Interfaces/ViewLikeInterface.cpp (+40-6)
  • (modified) mlir/test/Conversion/VectorToArmSME/unsupported.mlir (+5-5)
  • (modified) mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir (+49-49)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+8-8)
  • (modified) mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir (+4-4)
  • (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+6-6)
  • (modified) mlir/test/Dialect/ArmSME/outer-product-fusion.mlir (+2-2)
  • (modified) mlir/test/Dialect/ArmSME/vector-legalization.mlir (+13-13)
  • (modified) mlir/test/Dialect/Linalg/hoisting.mlir (+2-2)
  • (modified) mlir/test/Dialect/Linalg/transform-ops-invalid.mlir (+1-1)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+8-8)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+65)
  • (modified) mlir/test/Dialect/Vector/ops.mlir (+39-12)
  • (modified) mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir (+12-12)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index c5b08d6aa022b1..dad08305b2a645 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -695,14 +695,14 @@ def Vector_ExtractOp :
     %1 = vector.extract %0[3]: vector<8x16xf32> from vector<4x8x16xf32>
     %2 = vector.extract %0[2, 1, 3]: f32 from vector<4x8x16xf32>
     %3 = vector.extract %1[]: vector<f32> from vector<f32>
-    %4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32>
-    %5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32>
+    %4 = vector.extract %0[%a, %b, %c : index] : f32 from vector<4x8x16xf32>
+    %5 = vector.extract %0[2, %b : index] : vector<16xf32> from vector<4x8x16xf32>
     ```
   }];
 
   let arguments = (ins
     AnyVectorOfAnyRank:$vector,
-    Variadic<Index>:$dynamic_position,
+    Variadic<AnySignlessIntegerOrIndex>:$dynamic_position,
     DenseI64ArrayAttr:$static_position
   );
   let results = (outs AnyType:$result);
@@ -737,7 +737,8 @@ def Vector_ExtractOp :
 
   let assemblyFormat = [{
     $vector ``
-    custom<DynamicIndexList>($dynamic_position, $static_position)
+    custom<SameTypeDynamicIndexList>($dynamic_position, $static_position,
+                                     type($dynamic_position))
     attr-dict `:` type($result) `from` type($vector)
   }];
 
@@ -883,15 +884,15 @@ def Vector_InsertOp :
     %2 = vector.insert %0, %1[3] : vector<8x16xf32> into vector<4x8x16xf32>
     %5 = vector.insert %3, %4[2, 1, 3] : f32 into vector<4x8x16xf32>
     %8 = vector.insert %6, %7[] : f32 into vector<f32>
-    %11 = vector.insert %9, %10[%a, %b, %c] : vector<f32> into vector<4x8x16xf32>
-    %12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32>
+    %11 = vector.insert %9, %10[%a, %b, %c : index] : vector<f32> into vector<4x8x16xf32>
+    %12 = vector.insert %4, %10[2, %b : index] : vector<16xf32> into vector<4x8x16xf32>
     ```
   }];
 
   let arguments = (ins
     AnyType:$source,
     AnyVectorOfAnyRank:$dest,
-    Variadic<Index>:$dynamic_position,
+    Variadic<AnySignlessIntegerOrIndex>:$dynamic_position,
     DenseI64ArrayAttr:$static_position
   );
   let results = (outs AnyVectorOfAnyRank:$result);
@@ -926,7 +927,9 @@ def Vector_InsertOp :
   }];
 
   let assemblyFormat = [{
-    $source `,` $dest custom<DynamicIndexList>($dynamic_position, $static_position)
+    $source `,` $dest
+    custom<SameTypeDynamicIndexList>($dynamic_position, $static_position,
+                                     type($dynamic_position))
     attr-dict `:` type($source) `into` type($dest)
   }];
 
@@ -1344,7 +1347,7 @@ def Vector_TransferReadOp :
           %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref<?x?x?x?xf32>
           // Update the temporary gathered slice with the individual element
           %slice = memref.load %tmp : memref<vector<3x4x5xf32>> -> vector<3x4x5xf32>
-          %updated = vector.insert %a, %slice[%i, %j, %k] : f32 into vector<3x4x5xf32>
+          %updated = vector.insert %a, %slice[%i, %j, %k : index] : f32 into vector<3x4x5xf32>
           memref.store %updated, %tmp : memref<vector<3x4x5xf32>>
     }}}
     // At this point we gathered the elements from the original
@@ -1367,7 +1370,7 @@ def Vector_TransferReadOp :
         %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref<?x?x?x?xf32>
         %slice = memref.load %tmp : memref<vector<3x4x5xf32>> -> vector<3x4x5xf32>
         // Here we only store to the first element in dimension one
-        %updated = vector.insert %a, %slice[%i, 0, %k] : f32 into vector<3x4x5xf32>
+        %updated = vector.insert %a, %slice[%i, 0, %k : index] : f32 into vector<3x4x5xf32>
         memref.store %updated, %tmp : memref<vector<3x4x5xf32>>
     }}
     // At this point we gathered the elements from the original
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index a7222794f320b2..699dd1da863b6f 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -794,16 +794,26 @@ class AsmParser {
   };
 
   /// Parse a list of comma-separated items with an optional delimiter.  If a
-  /// delimiter is provided, then an empty list is allowed.  If not, then at
+  /// delimiter is provided, then an empty list is allowed. If not, then at
   /// least one element will be parsed.
   ///
+  /// `parseSuffixFn` is an optional function to parse any suffix that can be
+  /// appended to the comma separated list within the delimiter.
+  ///
   /// contextMessage is an optional message appended to "expected '('" sorts of
   /// diagnostics when parsing the delimeters.
-  virtual ParseResult
+  virtual ParseResult parseCommaSeparatedList(
+      Delimiter delimiter, function_ref<ParseResult()> parseElementFn,
+      std::optional<function_ref<ParseResult()>> parseSuffixFn = std::nullopt,
+      StringRef contextMessage = StringRef()) = 0;
+  ParseResult
   parseCommaSeparatedList(Delimiter delimiter,
                           function_ref<ParseResult()> parseElementFn,
-                          StringRef contextMessage = StringRef()) = 0;
-
+                          StringRef contextMessage) {
+    return parseCommaSeparatedList(delimiter, parseElementFn,
+                                   /*parseSuffixFn=*/std::nullopt,
+                                   contextMessage);
+  }
   /// Parse a comma separated list of elements that must have at least one entry
   /// in it.
   ParseResult
@@ -1319,6 +1329,9 @@ class AsmParser {
   virtual ParseResult
   parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
 
+  /// Parse an optional colon followed by a type.
+  virtual ParseResult parseOptionalColonType(Type &result) = 0;
+
   /// Parse a keyword followed by a type.
   ParseResult parseKeywordType(const char *keyword, Type &result) {
     return failure(parseKeyword(keyword) || parseType(result));
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index 3dcbd2f1af1936..1971c25a8f20b1 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -96,8 +96,10 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
 /// in `integers` is `kDynamic` or (2) the next value otherwise. If `valueTypes`
 /// is non-empty, it is expected to contain as many elements as `values`
 /// indicating their types. This allows idiomatic printing of mixed value and
-/// integer attributes in a list. E.g.
-/// `[%arg0 : index, 7, 42, %arg42 : i32]`.
+/// integer attributes in a list. E.g., `[%arg0 : index, 7, 42, %arg42 : i32]`.
+/// If `hasSameTypeDynamicValues` is `true`, `valueTypes` are expected to be the
+/// same and only one type is printed at the end of the list. E.g.,
+/// `[0, %arg2, 3, %arg42, 2 : i8]`.
 ///
 /// Indices can be scalable. For example, "4" in "[2, [4], 8]" is scalable.
 /// This notation is similar to how scalable dims are marked when defining
@@ -108,7 +110,8 @@ void printDynamicIndexList(
     OpAsmPrinter &printer, Operation *op, OperandRange values,
     ArrayRef<int64_t> integers, ArrayRef<bool> scalables,
     TypeRange valueTypes = TypeRange(),
-    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square,
+    bool hasSameTypeDynamicValues = false);
 inline void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
                                   OperandRange values,
                                   ArrayRef<int64_t> integers,
@@ -123,6 +126,13 @@ inline void printDynamicIndexList(
   return printDynamicIndexList(printer, op, values, integers, {}, valueTypes,
                                delimiter);
 }
+inline void printSameTypeDynamicIndexList(
+    OpAsmPrinter &printer, Operation *op, OperandRange values,
+    ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+  return printDynamicIndexList(printer, op, values, integers, {}, valueTypes,
+                               delimiter, /*hasSameTypeDynamicValues=*/true);
+}
 
 /// Parser hook for custom directive in assemblyFormat.
 ///
@@ -150,7 +160,8 @@ ParseResult parseDynamicIndexList(
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
     DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals,
     SmallVectorImpl<Type> *valueTypes = nullptr,
-    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square,
+    bool hasSameTypeDynamicValues = false);
 inline ParseResult
 parseDynamicIndexList(OpAsmParser &parser,
                       SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
@@ -188,6 +199,16 @@ inline ParseResult parseDynamicIndexList(
   return parseDynamicIndexList(parser, values, integers, scalableVals,
                                &valueTypes, delimiter);
 }
+inline ParseResult parseSameTypeDynamicIndexList(
+    OpAsmParser &parser,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+    DenseI64ArrayAttr &integers, SmallVectorImpl<Type> &valueTypes,
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+  DenseBoolArrayAttr scalableVals = {};
+  return parseDynamicIndexList(parser, values, integers, scalableVals,
+                               &valueTypes, delimiter,
+                               /*hasSameTypeDynamicValues=*/true);
+}
 
 /// Verify that a the `values` has as many elements as the number of entries in
 /// `attr` for which `isDynamic` evaluates to true.
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index 04250f63dcd253..4d5b93ec09d175 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -340,12 +340,16 @@ class AsmParserImpl : public BaseT {
   /// Parse a list of comma-separated items with an optional delimiter.  If a
   /// delimiter is provided, then an empty list is allowed.  If not, then at
   /// least one element will be parsed.
-  ParseResult parseCommaSeparatedList(Delimiter delimiter,
-                                      function_ref<ParseResult()> parseElt,
-                                      StringRef contextMessage) override {
-    return parser.parseCommaSeparatedList(delimiter, parseElt, contextMessage);
+  ParseResult parseCommaSeparatedList(
+      Delimiter delimiter, function_ref<ParseResult()> parseElt,
+      std::optional<function_ref<ParseResult()>> parseSuffix,
+      StringRef contextMessage) override {
+    return parser.parseCommaSeparatedList(delimiter, parseElt, parseSuffix,
+                                          contextMessage);
   }
 
+  using BaseT::parseCommaSeparatedList;
+
   //===--------------------------------------------------------------------===//
   // Keyword Parsing
   //===--------------------------------------------------------------------===//
@@ -590,6 +594,17 @@ class AsmParserImpl : public BaseT {
     return parser.parseTypeListNoParens(result);
   }
 
+  /// Parse an optional colon followed by a type.
+  ParseResult parseOptionalColonType(Type &result) override {
+    SmallVector<Type, 1> types;
+    ParseResult parseResult = parseOptionalColonTypeList(types);
+    if (llvm::succeeded(parseResult) && types.size() > 1)
+      return emitError(getCurrentLocation(), "expected single type");
+    if (!types.empty())
+      result = types[0];
+    return parseResult;
+  }
+
   ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
                                  bool allowDynamic,
                                  bool withTrailingX) override {
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 8f19487d80fa39..6476910f71eb7f 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -80,10 +80,10 @@ AsmParserCodeCompleteContext::~AsmParserCodeCompleteContext() = default;
 /// Parse a list of comma-separated items with an optional delimiter.  If a
 /// delimiter is provided, then an empty list is allowed.  If not, then at
 /// least one element will be parsed.
-ParseResult
-Parser::parseCommaSeparatedList(Delimiter delimiter,
-                                function_ref<ParseResult()> parseElementFn,
-                                StringRef contextMessage) {
+ParseResult Parser::parseCommaSeparatedList(
+    Delimiter delimiter, function_ref<ParseResult()> parseElementFn,
+    std::optional<function_ref<ParseResult()>> parseSuffixFn,
+    StringRef contextMessage) {
   switch (delimiter) {
   case Delimiter::None:
     break;
@@ -144,6 +144,9 @@ Parser::parseCommaSeparatedList(Delimiter delimiter,
       return failure();
   }
 
+  if (parseSuffixFn && (*parseSuffixFn)())
+    return failure();
+
   switch (delimiter) {
   case Delimiter::None:
     return success();
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index bf91831798056b..1ebca05bbcb2ef 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -46,10 +46,17 @@ class Parser {
   /// Parse a list of comma-separated items with an optional delimiter.  If a
   /// delimiter is provided, then an empty list is allowed.  If not, then at
   /// least one element will be parsed.
+  ParseResult parseCommaSeparatedList(
+      Delimiter delimiter, function_ref<ParseResult()> parseElementFn,
+      std::optional<function_ref<ParseResult()>> parseSuffixFn = std::nullopt,
+      StringRef contextMessage = StringRef());
   ParseResult
   parseCommaSeparatedList(Delimiter delimiter,
                           function_ref<ParseResult()> parseElementFn,
-                          StringRef contextMessage = StringRef());
+                          StringRef contextMessage) {
+    return parseCommaSeparatedList(delimiter, parseElementFn, std::nullopt,
+                                   contextMessage);
+  }
 
   /// Parse a comma separated list of elements that must have at least one entry
   /// in it.
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 55965d9c2a531d..c5c3353bf0477f 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -501,13 +501,14 @@ struct VectorOuterProductToArmSMELowering
 ///
 /// Example:
 /// ```
-/// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
+/// %el = vector.extract %tile[%row, %col : index] : i32 from
+/// vector<[4]x[4]xi32>
 /// ```
 /// Becomes:
 /// ```
 /// %slice = arm_sme.extract_tile_slice %tile[%row]
 ///            : vector<[4]xi32> from vector<[4]x[4]xi32>
-/// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
+/// %el = vector.extract %slice[%col : index] : i32 from vector<[4]xi32>
 /// ```
 struct VectorExtractToArmSMELowering
     : public OpRewritePattern<vector::ExtractOp> {
@@ -561,8 +562,9 @@ struct VectorExtractToArmSMELowering
 /// ```
 /// %slice = arm_sme.extract_tile_slice %tile[%row]
 ///            : vector<[4]xi32> from vector<[4]x[4]xi32>
-/// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
-/// %new_tile = arm_sme.insert_tile_slice %new_slice, %tile[%row]
+/// %new_slice = vector.insert %el, %slice[%col : index] : i32 into
+/// vector<[4]xi32> %new_tile = arm_sme.insert_tile_slice %new_slice,
+/// %tile[%row]
 ///               : vector<[4]xi32> into vector<[4]x[4]xi32>
 /// ```
 struct VectorInsertToArmSMELowering
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 3a4dc806efe976..b623a86c53ee71 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1050,10 +1050,10 @@ getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
 /// %vscale = vector.vscale
 /// %c4_vscale = arith.muli %vscale, %c4 : index
 /// scf.for %idx = %c0 to %c4_vscale step %c1 {
-///   %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32>
-///   %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32>
-///   %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32>
-///   %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32>
+///   %4 = vector.extract %0[%idx : index] : f32 from vector<[4]xf32>
+///   %5 = vector.extract %1[%idx : index] : f32 from vector<[4]xf32>
+///   %6 = vector.extract %2[%idx : index] : f32 from vector<[4]xf32>
+///   %7 = vector.extract %3[%idx : index] : f32 from vector<[4]xf32>
 ///   %slice_i = affine.apply #map(%idx)[%i]
 ///   %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32>
 ///   vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]}
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index ca33636336bf0c..8e44ff60eec874 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -114,7 +114,8 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
                                  OperandRange values,
                                  ArrayRef<int64_t> integers,
                                  ArrayRef<bool> scalables, TypeRange valueTypes,
-                                 AsmParser::Delimiter delimiter) {
+                                 AsmParser::Delimiter delimiter,
+                                 bool hasSameTypeDynamicValues) {
   char leftDelimiter = getLeftDelimiter(delimiter);
   char rightDelimiter = getRightDelimiter(delimiter);
   printer << leftDelimiter;
@@ -130,7 +131,7 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
       printer << "[";
     if (ShapedType::isDynamic(integer)) {
       printer << values[dynamicValIdx];
-      if (!valueTypes.empty())
+      if (!hasSameTypeDynamicValues && !valueTypes.empty())
         printer << " : " << valueTypes[dynamicValIdx];
       ++dynamicValIdx;
     } else {
@@ -142,6 +143,13 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
     scalableIndexIdx++;
   });
 
+  if (hasSameTypeDynamicValues && !valueTypes.empty()) {
+    assert(std::all_of(valueTypes.begin(), valueTypes.end(),
+                       [&](Type type) { return type == valueTypes[0]; }) &&
+           "Expected the same value types");
+    printer << " : " << valueTypes[0];
+  }
+
   printer << rightDelimiter;
 }
 
@@ -149,7 +157,8 @@ ParseResult mlir::parseDynamicIndexList(
     OpAsmParser &parser,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
     DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalables,
-    SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) {
+    SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter,
+    bool hasSameTypeDynamicValues) {
 
   SmallVector<int64_t, 4> integerVals;
   SmallVector<bool, 4> scalableVals;
@@ -163,7 +172,8 @@ ParseResult mlir::parseDynamicIndexList(
     if (res.has_value() && succeeded(res.value())) {
       values.push_back(operand);
       integerVals.push_back(ShapedType::kDynamic);
-      if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
+      if (!hasSameTypeDynamicValues && valueTypes &&
+          parser.parseColonType(valueTypes->emplace_back()))
         return failure();
     } else {
       int64_t integer;
@@ -178,10 +188,34 @@ ParseResult mlir::parseDynamicIndexList(
       return failure();
     return success();
   };
+  auto parseColonType = [&]() -> ParseResult {
+    if (hasSameTypeDynamicValues) {
+      assert(valueTypes && "Expected non-null value types");
+      assert(valueTypes->empty() && "Expected no parsed value types");
+
+      Type dynValType;
+      if (parser.parseOptionalColonType(dynValType))
+        return failure();
+
+      if (!dynValType && !values.empty())
+        return parser.emitError(parser.getNameLoc())
+               << "expected a type for dynamic indices";
+      if (dynValType) {
+        if (values.empty())
+          return parser.emitError(parser.getNameLoc())
+                 << "expected no type for constant indices";
+
+        // Broadcast the single type to all the dynamic values.
+        valueTypes->append(values.size(), dynValType);
+   ...
[truncated]
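
The patch is cut off in the middle of the parser-side check, so the following is only a rough sketch of the visible part of that logic, not the actual implementation. It assumes the list elements have already been split into strings, that value indices start with `%`, and that constants are well-formed integers. `resolveSameTypeIndexList`, `ParsedIndexListSketch`, and `kDynamicPos` are hypothetical names; the two error messages are the ones that appear in the diff.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>
#include <string>
#include <vector>

// Hypothetical sentinel for positions that are supplied by a value.
constexpr int64_t kDynamicPos = std::numeric_limits<int64_t>::min();

struct ParsedIndexListSketch {
  std::vector<int64_t> staticPositions; // kDynamicPos for value indices
  std::vector<std::string> values;      // names of the value indices
  std::vector<std::string> valueTypes;  // one entry per value index
};

// Applies the rules visible in the diff: a trailing type is required when
// value indices are present, rejected when only constants are present, and
// otherwise broadcast to every value index.
std::optional<ParsedIndexListSketch>
resolveSameTypeIndexList(const std::vector<std::string> &elements,
                         const std::optional<std::string> &trailingType,
                         std::string &error) {
  ParsedIndexListSketch parsed;
  for (const std::string &element : elements) {
    if (!element.empty() && element[0] == '%') {
      parsed.values.push_back(element);
      parsed.staticPositions.push_back(kDynamicPos);
    } else {
      // Assumption: constants are well-formed; std::stoll throws otherwise.
      parsed.staticPositions.push_back(std::stoll(element));
    }
  }
  if (!trailingType && !parsed.values.empty()) {
    error = "expected a type for dynamic indices";
    return std::nullopt;
  }
  if (trailingType && parsed.values.empty()) {
    error = "expected no type for constant indices";
    return std::nullopt;
  }
  if (trailingType)
    parsed.valueTypes.assign(parsed.values.size(), *trailingType);
  return parsed;
}
```

Broadcasting the single type keeps one type entry per value index, which matches the comment in the diff about appending the type once for each dynamic value.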

@llvmbot
Member

llvmbot commented Nov 12, 2024

@llvm/pr-subscribers-mlir-spirv

Author: Diego Caballero (dcaballe)

+  virtual ParseResult parseOptionalColonType(Type &result) = 0;
+
   /// Parse a keyword followed by a type.
   ParseResult parseKeywordType(const char *keyword, Type &result) {
     return failure(parseKeyword(keyword) || parseType(result));
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index 3dcbd2f1af1936..1971c25a8f20b1 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -96,8 +96,10 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
 /// in `integers` is `kDynamic` or (2) the next value otherwise. If `valueTypes`
 /// is non-empty, it is expected to contain as many elements as `values`
 /// indicating their types. This allows idiomatic printing of mixed value and
-/// integer attributes in a list. E.g.
-/// `[%arg0 : index, 7, 42, %arg42 : i32]`.
+/// integer attributes in a list. E.g., `[%arg0 : index, 7, 42, %arg42 : i32]`.
+/// If `hasSameTypeDynamicValues` is `true`, `valueTypes` are expected to be the
+/// same and only one type is printed at the end of the list. E.g.,
+/// `[0, %arg2, 3, %arg42, 2 : i8]`.
 ///
 /// Indices can be scalable. For example, "4" in "[2, [4], 8]" is scalable.
 /// This notation is similar to how scalable dims are marked when defining
@@ -108,7 +110,8 @@ void printDynamicIndexList(
     OpAsmPrinter &printer, Operation *op, OperandRange values,
     ArrayRef<int64_t> integers, ArrayRef<bool> scalables,
     TypeRange valueTypes = TypeRange(),
-    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square,
+    bool hasSameTypeDynamicValues = false);
 inline void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
                                   OperandRange values,
                                   ArrayRef<int64_t> integers,
@@ -123,6 +126,13 @@ inline void printDynamicIndexList(
   return printDynamicIndexList(printer, op, values, integers, {}, valueTypes,
                                delimiter);
 }
+inline void printSameTypeDynamicIndexList(
+    OpAsmPrinter &printer, Operation *op, OperandRange values,
+    ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+  return printDynamicIndexList(printer, op, values, integers, {}, valueTypes,
+                               delimiter, /*hasSameTypeDynamicValues=*/true);
+}
 
 /// Parser hook for custom directive in assemblyFormat.
 ///
@@ -150,7 +160,8 @@ ParseResult parseDynamicIndexList(
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
     DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals,
     SmallVectorImpl<Type> *valueTypes = nullptr,
-    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square,
+    bool hasSameTypeDynamicValues = false);
 inline ParseResult
 parseDynamicIndexList(OpAsmParser &parser,
                       SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
@@ -188,6 +199,16 @@ inline ParseResult parseDynamicIndexList(
   return parseDynamicIndexList(parser, values, integers, scalableVals,
                                &valueTypes, delimiter);
 }
+inline ParseResult parseSameTypeDynamicIndexList(
+    OpAsmParser &parser,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+    DenseI64ArrayAttr &integers, SmallVectorImpl<Type> &valueTypes,
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+  DenseBoolArrayAttr scalableVals = {};
+  return parseDynamicIndexList(parser, values, integers, scalableVals,
+                               &valueTypes, delimiter,
+                               /*hasSameTypeDynamicValues=*/true);
+}
 
 /// Verify that a the `values` has as many elements as the number of entries in
 /// `attr` for which `isDynamic` evaluates to true.
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index 04250f63dcd253..4d5b93ec09d175 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -340,12 +340,16 @@ class AsmParserImpl : public BaseT {
   /// Parse a list of comma-separated items with an optional delimiter.  If a
   /// delimiter is provided, then an empty list is allowed.  If not, then at
   /// least one element will be parsed.
-  ParseResult parseCommaSeparatedList(Delimiter delimiter,
-                                      function_ref<ParseResult()> parseElt,
-                                      StringRef contextMessage) override {
-    return parser.parseCommaSeparatedList(delimiter, parseElt, contextMessage);
+  ParseResult parseCommaSeparatedList(
+      Delimiter delimiter, function_ref<ParseResult()> parseElt,
+      std::optional<function_ref<ParseResult()>> parseSuffix,
+      StringRef contextMessage) override {
+    return parser.parseCommaSeparatedList(delimiter, parseElt, parseSuffix,
+                                          contextMessage);
   }
 
+  using BaseT::parseCommaSeparatedList;
+
   //===--------------------------------------------------------------------===//
   // Keyword Parsing
   //===--------------------------------------------------------------------===//
@@ -590,6 +594,17 @@ class AsmParserImpl : public BaseT {
     return parser.parseTypeListNoParens(result);
   }
 
+  /// Parse an optional colon followed by a type.
+  ParseResult parseOptionalColonType(Type &result) override {
+    SmallVector<Type, 1> types;
+    ParseResult parseResult = parseOptionalColonTypeList(types);
+    if (llvm::succeeded(parseResult) && types.size() > 1)
+      return emitError(getCurrentLocation(), "expected single type");
+    if (!types.empty())
+      result = types[0];
+    return parseResult;
+  }
+
   ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
                                  bool allowDynamic,
                                  bool withTrailingX) override {
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 8f19487d80fa39..6476910f71eb7f 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -80,10 +80,10 @@ AsmParserCodeCompleteContext::~AsmParserCodeCompleteContext() = default;
 /// Parse a list of comma-separated items with an optional delimiter.  If a
 /// delimiter is provided, then an empty list is allowed.  If not, then at
 /// least one element will be parsed.
-ParseResult
-Parser::parseCommaSeparatedList(Delimiter delimiter,
-                                function_ref<ParseResult()> parseElementFn,
-                                StringRef contextMessage) {
+ParseResult Parser::parseCommaSeparatedList(
+    Delimiter delimiter, function_ref<ParseResult()> parseElementFn,
+    std::optional<function_ref<ParseResult()>> parseSuffixFn,
+    StringRef contextMessage) {
   switch (delimiter) {
   case Delimiter::None:
     break;
@@ -144,6 +144,9 @@ Parser::parseCommaSeparatedList(Delimiter delimiter,
       return failure();
   }
 
+  if (parseSuffixFn && (*parseSuffixFn)())
+    return failure();
+
   switch (delimiter) {
   case Delimiter::None:
     return success();
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index bf91831798056b..1ebca05bbcb2ef 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -46,10 +46,17 @@ class Parser {
   /// Parse a list of comma-separated items with an optional delimiter.  If a
   /// delimiter is provided, then an empty list is allowed.  If not, then at
   /// least one element will be parsed.
+  ParseResult parseCommaSeparatedList(
+      Delimiter delimiter, function_ref<ParseResult()> parseElementFn,
+      std::optional<function_ref<ParseResult()>> parseSuffixFn = std::nullopt,
+      StringRef contextMessage = StringRef());
   ParseResult
   parseCommaSeparatedList(Delimiter delimiter,
                           function_ref<ParseResult()> parseElementFn,
-                          StringRef contextMessage = StringRef());
+                          StringRef contextMessage) {
+    return parseCommaSeparatedList(delimiter, parseElementFn, std::nullopt,
+                                   contextMessage);
+  }
 
   /// Parse a comma separated list of elements that must have at least one entry
   /// in it.
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 55965d9c2a531d..c5c3353bf0477f 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -501,13 +501,14 @@ struct VectorOuterProductToArmSMELowering
 ///
 /// Example:
 /// ```
-/// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
+/// %el = vector.extract %tile[%row, %col : index] : i32 from
+/// vector<[4]x[4]xi32>
 /// ```
 /// Becomes:
 /// ```
 /// %slice = arm_sme.extract_tile_slice %tile[%row]
 ///            : vector<[4]xi32> from vector<[4]x[4]xi32>
-/// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
+/// %el = vector.extract %slice[%col : index] : i32 from vector<[4]xi32>
 /// ```
 struct VectorExtractToArmSMELowering
     : public OpRewritePattern<vector::ExtractOp> {
@@ -561,8 +562,9 @@ struct VectorExtractToArmSMELowering
 /// ```
 /// %slice = arm_sme.extract_tile_slice %tile[%row]
 ///            : vector<[4]xi32> from vector<[4]x[4]xi32>
-/// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
-/// %new_tile = arm_sme.insert_tile_slice %new_slice, %tile[%row]
+/// %new_slice = vector.insert %el, %slice[%col : index] : i32 into
+/// vector<[4]xi32> %new_tile = arm_sme.insert_tile_slice %new_slice,
+/// %tile[%row]
 ///               : vector<[4]xi32> into vector<[4]x[4]xi32>
 /// ```
 struct VectorInsertToArmSMELowering
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 3a4dc806efe976..b623a86c53ee71 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1050,10 +1050,10 @@ getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
 /// %vscale = vector.vscale
 /// %c4_vscale = arith.muli %vscale, %c4 : index
 /// scf.for %idx = %c0 to %c4_vscale step %c1 {
-///   %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32>
-///   %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32>
-///   %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32>
-///   %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32>
+///   %4 = vector.extract %0[%idx : index] : f32 from vector<[4]xf32>
+///   %5 = vector.extract %1[%idx : index] : f32 from vector<[4]xf32>
+///   %6 = vector.extract %2[%idx : index] : f32 from vector<[4]xf32>
+///   %7 = vector.extract %3[%idx : index] : f32 from vector<[4]xf32>
 ///   %slice_i = affine.apply #map(%idx)[%i]
 ///   %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32>
 ///   vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]}
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index ca33636336bf0c..8e44ff60eec874 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -114,7 +114,8 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
                                  OperandRange values,
                                  ArrayRef<int64_t> integers,
                                  ArrayRef<bool> scalables, TypeRange valueTypes,
-                                 AsmParser::Delimiter delimiter) {
+                                 AsmParser::Delimiter delimiter,
+                                 bool hasSameTypeDynamicValues) {
   char leftDelimiter = getLeftDelimiter(delimiter);
   char rightDelimiter = getRightDelimiter(delimiter);
   printer << leftDelimiter;
@@ -130,7 +131,7 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
       printer << "[";
     if (ShapedType::isDynamic(integer)) {
       printer << values[dynamicValIdx];
-      if (!valueTypes.empty())
+      if (!hasSameTypeDynamicValues && !valueTypes.empty())
         printer << " : " << valueTypes[dynamicValIdx];
       ++dynamicValIdx;
     } else {
@@ -142,6 +143,13 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
     scalableIndexIdx++;
   });
 
+  if (hasSameTypeDynamicValues && !valueTypes.empty()) {
+    assert(std::all_of(valueTypes.begin(), valueTypes.end(),
+                       [&](Type type) { return type == valueTypes[0]; }) &&
+           "Expected the same value types");
+    printer << " : " << valueTypes[0];
+  }
+
   printer << rightDelimiter;
 }
 
@@ -149,7 +157,8 @@ ParseResult mlir::parseDynamicIndexList(
     OpAsmParser &parser,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
     DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalables,
-    SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) {
+    SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter,
+    bool hasSameTypeDynamicValues) {
 
   SmallVector<int64_t, 4> integerVals;
   SmallVector<bool, 4> scalableVals;
@@ -163,7 +172,8 @@ ParseResult mlir::parseDynamicIndexList(
     if (res.has_value() && succeeded(res.value())) {
       values.push_back(operand);
       integerVals.push_back(ShapedType::kDynamic);
-      if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
+      if (!hasSameTypeDynamicValues && valueTypes &&
+          parser.parseColonType(valueTypes->emplace_back()))
         return failure();
     } else {
       int64_t integer;
@@ -178,10 +188,34 @@ ParseResult mlir::parseDynamicIndexList(
       return failure();
     return success();
   };
+  auto parseColonType = [&]() -> ParseResult {
+    if (hasSameTypeDynamicValues) {
+      assert(valueTypes && "Expected non-null value types");
+      assert(valueTypes->empty() && "Expected no parsed value types");
+
+      Type dynValType;
+      if (parser.parseOptionalColonType(dynValType))
+        return failure();
+
+      if (!dynValType && !values.empty())
+        return parser.emitError(parser.getNameLoc())
+               << "expected a type for dynamic indices";
+      if (dynValType) {
+        if (values.empty())
+          return parser.emitError(parser.getNameLoc())
+                 << "expected no type for constant indices";
+
+        // Broadcast the single type to all the dynamic values.
+        valueTypes->append(values.size(), dynValType);
+   ...
[truncated]
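
The rule visible in the hunk above (a trailing type is required when the list has value indices, rejected when it has only constants, and otherwise broadcast to every value index) can be sketched outside MLIR. This is a standalone illustration, not the PR's implementation: `ParsedIndexList` and `resolveSameTypeIndices` are hypothetical names, and types are modeled as plain strings.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for the parser state after reading `[...]`:
// the SSA value indices that were seen and the optional `: type` suffix.
struct ParsedIndexList {
  std::vector<std::string> dynamicValues;  // e.g. {"%a", "%b"}
  std::optional<std::string> suffixType;   // e.g. "i8"; empty if no suffix
};

// Returns one type per dynamic value, or std::nullopt if the list is invalid.
std::optional<std::vector<std::string>>
resolveSameTypeIndices(const ParsedIndexList &list) {
  const bool hasDynamic = !list.dynamicValues.empty();
  // Mirrors "expected a type for dynamic indices".
  if (hasDynamic && !list.suffixType)
    return std::nullopt;
  // Mirrors "expected no type for constant indices".
  if (!hasDynamic && list.suffixType)
    return std::nullopt;
  if (!hasDynamic)
    return std::vector<std::string>{};
  // Broadcast the single suffix type to all the dynamic values.
  return std::vector<std::string>(list.dynamicValues.size(), *list.suffixType);
}
```

Under these assumptions, `[%a, 5, %b : i8]` resolves to two `i8` types, while `[%a, 5]` and `[2, 5 : i8]` are both rejected.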

@dcaballe dcaballe requested a review from c-rhodes November 12, 2024 05:20
Contributor

@banach-space banach-space left a comment

Thank you for working on this - mostly makes sense, but I have a few questions :)

What's the end-goal here - this would be quite desirable for vector.gather. Importantly, how should we decide what type to use for the index variables? This change creates a tricky decision point for code-gen 🤔

Btw, it would be good to add new tests to vector-to-llvm.mlir to demonstrate the impact of this on the actual "end result" (i.e. LLVM IR). I guess this makes little sense if things don't change at the LLVM level?

@@ -271,6 +304,38 @@ func.func @insert_0d(%a: f32, %b: vector<f32>) {
%1 = vector.insert %a, %b[0] : f32 into vector<f32>
}

// -----
func.func @extract_vector_mixed_index_types(%arg0 : f32, %arg1 : vector<8x16xf32>,
Contributor

Suggested change
- func.func @extract_vector_mixed_index_types(%arg0 : f32, %arg1 : vector<8x16xf32>,
+ func.func @insert_vector_mixed_index_types(%arg0 : f32, %arg1 : vector<8x16xf32>,

Same comment for the tests below. Note that you are inserting a single value rather than a vector (@extract_vector -> @insert_vector -> @insert_value?)


// CHECK-LABEL: @extract_val_int
// CHECK-SAME: %[[VEC:.+]]: vector<4x8x16xf32>, %[[I32_IDX:.+]]: i32, %[[I8_IDX:.+]]: i8
func.func @extract_val_int(%arg0: vector<4x8x16xf32>, %i32_idx: i32,
Contributor

[nit] This is clear today, but I know that my future self will be grateful for a slightly more descriptive name :)

Suggested change
- func.func @extract_val_int(%arg0: vector<4x8x16xf32>, %i32_idx: i32,
+ func.func @extract_val_idx_as_int(%arg0: vector<4x8x16xf32>, %i32_idx: i32,

@@ -274,7 +274,7 @@ func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vect
// CHECK: spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32>
func.func @insert_dynamic_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> {
%idx = arith.constant 2 : index
- %0 = vector.insert %val, %arg0[%idx] : f32 into vector<4xf32>
+ %0 = vector.insert %val, %arg0[%idx : index] : f32 into vector<4xf32>
Member

Could you also add tests for i8 and maybe i1 as the insert/extract index types? These require type conversion in the general case.

Contributor Author

Hey Jakub, I added i32 and i8 tests but they need to modify the spirv conversion, as you mentioned. Would you mind helping with that? I have no idea about that pass. How would you like to proceed?

Member

Could you file an issue for this and link to it in the PR description and perhaps in some TODO in the code?

@dcaballe
Contributor Author

What's the end-goal here - Importantly, how should we decide what type to use for the index variables? This change creates a tricky decision point for code-gen

The end goal is to be able to default to the widest type but use narrower types when needed and it's safe. That should be simple enough. This shouldn't add more complexity to codegen than we already have. We already have to deal with i64 constant and index variable indices.
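
As a rough illustration of "use narrower types when it's safe", the check can be reduced to asking which integer width covers every valid position of a dimension. The helper below is hypothetical and not part of the PR. It assumes indices are interpreted as signed values, which is the conservative reading, and considers only the widths 8, 16, 32 and 64.

```cpp
#include <cassert>
#include <cstdint>
#include <initializer_list>

// Hypothetical helper: the smallest of the common integer widths whose
// non-negative signed range [0, 2^(w-1) - 1] covers every valid position
// 0 .. dimSize - 1 of a vector dimension.
unsigned narrowestIndexWidth(int64_t dimSize) {
  assert(dimSize > 0 && "vector dimensions are positive");
  const int64_t maxIndex = dimSize - 1;
  for (unsigned width : {8u, 16u, 32u}) {
    const int64_t maxSigned = (int64_t{1} << (width - 1)) - 1;
    if (maxIndex <= maxSigned)
      return width;
  }
  // Fall back to the widest type.
  return 64;
}
```

With this sketch, every dimension of `vector<4x8x16xf32>` could be indexed with `i8`, matching the `i8` examples in the PR description.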

I would go even further: vector.insert/extract are not memory related operations so I'm not sure using the index type is even safe here. It's a bit of a stretch but for a 32-bit machine, I'm pretty sure I can come up with an example where a 32-bit vector.extract/insert index overflows :). IMO, using integers would make more sense. Maybe I'm missing something.

this would be quite desirable for vector.gather

Exactly! The same principle applies. However, note that vector.gather is a memory related op, vector.extract/insert are not.

I guess this makes little sense if things don't change at the LLVM level?

AFAIK, LLVM's extract/insertelement instructions accept any integer type.

@banach-space
Contributor

The end goal is to be able to default to the widest type but use narrower types when needed and it's safe.

So I've been trying to figure out the right mechanism to select the right index size 😅 Suggestions are much appreciated :) At a very coarse grain level we could use the architecture pointer size, but this way we'd be mostly switching between 32 and 64 bits. #not-good-enough :)

However, note that vector.gather is a memory related op, vector.extract/insert are not.

Are you thinking that the default for vector.extract/insert should be narrower than for vector.gather? For example, i32 and i64, respectively, on a 64 bit machine?

@dcaballe
Contributor Author

Are you thinking that the default for vector.extract/insert should be narrower than for vector.gather? For example, i32 and i64, respectively, on a 64 bit machine?

Yes and no. What I mean here is that gather indices are limited to the memory in the system. Extracts/inserts... not necessarily... For example, could we create an i1 vector with more elements than memory bytes in the system and then extract one element from it? 😄 Perhaps not very realistic and nothing we should worry about, but that's the difference I see between gather and extract/insert indices.
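
The arithmetic behind that thought experiment can be written down directly. The numbers below rest on two assumptions that are not stated in the thread: a 32-bit target with 2^32 addressable bytes, and `i1` elements packed eight per byte.

```cpp
#include <cassert>
#include <cstdint>

// Assumption: a 32-bit target can address 2^32 bytes.
constexpr uint64_t kAddressableBytes = uint64_t{1} << 32;
// Assumption: i1 elements are packed eight per byte, giving 2^35 elements.
constexpr uint64_t kPackedI1Elements = kAddressableBytes * 8;
// Largest value a 32-bit index can represent when treated as unsigned.
constexpr uint64_t kMaxIndex32 = (uint64_t{1} << 32) - 1;

// True if `position` is representable in a 32-bit index.
constexpr bool fitsIn32BitIndex(uint64_t position) {
  return position <= kMaxIndex32;
}
```

Under these assumptions the last element of such a vector sits at position 2^35 - 1, which a 32-bit `index` cannot represent, whereas every byte address still fits.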

@dcaballe
Contributor Author

Is this something we can land already? Any other comments?

Contributor

@banach-space banach-space left a comment

LGTM, thank you!

Please wait for @kuhar to also approve.

@dcaballe
Contributor Author

dcaballe commented Jan 8, 2025

New year's ping :)

Hopefully we can land it before it gets too stale.

@@ -96,8 +96,10 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
/// in `integers` is `kDynamic` or (2) the next value otherwise. If `valueTypes`
Member

The grammar in this sentence is broken: is kDynamic. Can you fix that also?

/// integer attributes in a list. E.g., `[%arg0 : index, 7, 42, %arg42 : i32]`.
/// If `hasSameTypeDynamicValues` is `true`, `valueTypes` are expected to be the
/// same and only one type is printed at the end of the list. E.g.,
/// `[0, %arg2, 3, %arg42, 2 : i8]`.
Member

Clarify that the number of types in valueTypes must match the number of dynamic elements, even if hasSameTypeDynamicValues is set.

Btw, have you considered changing the API such that valueTypes contains only a single value in case of hasSameTypeDynamicValues? That would seem more natural to me.
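
The alternative suggested here, passing a single type rather than one type per dynamic value, might look like the following standalone sketch. It is not the MLIR API: `printSameTypeList` is a hypothetical name, SSA values and types are modeled as strings, and `kDynamicSentinel` only stands in for `ShapedType::kDynamic`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <string>
#include <vector>

// Hypothetical sentinel marking "take the next SSA value" in `integers`.
constexpr int64_t kDynamicSentinel = std::numeric_limits<int64_t>::min();

// Prints mixed constant/value indices with one trailing type for all values.
std::string printSameTypeList(const std::vector<std::string> &values,
                              const std::vector<int64_t> &integers,
                              const std::string &type) {
  std::string out = "[";
  std::size_t dynamicIdx = 0;
  for (std::size_t i = 0; i < integers.size(); ++i) {
    if (i != 0)
      out += ", ";
    if (integers[i] == kDynamicSentinel)
      out += values[dynamicIdx++];  // Consume the next value index.
    else
      out += std::to_string(integers[i]);
  }
  // The type is printed only when at least one value index is present.
  if (!values.empty())
    out += " : " + type;
  return out + "]";
}
```

With a single `type` argument there is no per-value type list whose length or uniformity needs to be asserted.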

@@ -96,8 +96,10 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
/// in `integers` is `kDynamic` or (2) the next value otherwise. If `valueTypes`
Member

Why does this function have valueTypes parameter? The type can be taken from the SSA values in values. Is it possible to remove valueTypes? I think a bool printTypes should be sufficient.

Contributor Author

valueTypes is needed for integers. We are representing integers with int64_t but their actual type comes from valueTypes. Perhaps we should rename this to itemTypes?

Contributor Author

values can also be empty

Contributor Author

Wait, the types are not used for integers (!). I think valueTypes is needed to match the function signature expected by custom($values, $integers, type($values)). We could verify that both $values and $type($values) match and then use just one of them.

Member

I looked into this a bit. I think valueTypes is needed in the parser, so that the user-specified types can be checked against the actual types in resolveOperands. printDynamicIndexList just has it for consistency.

@@ -114,7 +114,8 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
OperandRange values,
ArrayRef<int64_t> integers,
ArrayRef<bool> scalables, TypeRange valueTypes,
- AsmParser::Delimiter delimiter) {
+ AsmParser::Delimiter delimiter,
+ bool hasSameTypeDynamicValues) {
char leftDelimiter = getLeftDelimiter(delimiter);
Member

In case we keep the valueTypes parameter, I think there should be an assert that checks the number of elements in valueTypes.

Member

Also, can we assert that TypeRange(values) == valueTypes?

@@ -142,14 +143,22 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
scalableIndexIdx++;
});

if (hasSameTypeDynamicValues && !valueTypes.empty()) {
assert(std::all_of(valueTypes.begin(), valueTypes.end(),
Member

llvm::all_equal
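
For context, `llvm::all_equal` (from `llvm/ADT/STLExtras.h`) checks that every element of a range equals the first, so the assert could presumably shrink to `assert(llvm::all_equal(valueTypes) && "Expected the same value types")`. The std-only sketch below shows the equivalent check without LLVM headers. `allEqual` is a hypothetical stand-in, not the LLVM function.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <vector>

// Std-only stand-in: true if no two adjacent elements differ, which is the
// same as every element being equal to the first (vacuously true if empty).
template <typename Range>
bool allEqual(const Range &range) {
  return std::adjacent_find(range.begin(), range.end(),
                            std::not_equal_to<>()) == range.end();
}
```

Compared with the `std::all_of` form in the diff, this avoids indexing `valueTypes[0]` inside a lambda.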

@@ -178,10 +188,34 @@ ParseResult mlir::parseDynamicIndexList(
return failure();
return success();
};
auto parseColonType = [&]() -> ParseResult {
if (hasSameTypeDynamicValues) {
assert(valueTypes && "Expected non-null value types");
Member

Why is valueTypes required when hasSameTypeDynamicValues is "true", but not required when hasSameTypeDynamicValues is "false"?

}

using BaseT::parseCommaSeparatedList;
Member

Why are there two overloads, one calling the super implementation and the other one calling parser.parseCommaSeparatedList?
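
One plausible reading of the `using` declaration is C++ name hiding: overriding one overload in a derived class hides the base class's other overloads of the same name unless they are re-exported. The sketch below reproduces that situation with hypothetical `Base`/`Impl` classes and a simplified `parseList` signature. It is not the `AsmParser` code.

```cpp
#include <cassert>
#include <string>

struct Base {
  virtual ~Base() = default;
  // Full form, implemented by the derived class.
  virtual std::string parseList(const std::string &elt, bool hasSuffix,
                                const std::string &ctx) = 0;
  // Non-virtual convenience overload that forwards to the full form.
  std::string parseList(const std::string &elt, const std::string &ctx) {
    return parseList(elt, /*hasSuffix=*/false, ctx);
  }
};

struct Impl : Base {
  std::string parseList(const std::string &elt, bool hasSuffix,
                        const std::string &ctx) override {
    return elt + (hasSuffix ? "+suffix" : "") + "@" + ctx;
  }
  // Without this declaration, the override above hides the two-argument
  // overload inherited from Base, and calls to it on an Impl fail to compile.
  using Base::parseList;
};
```

Here `Impl{}.parseList("x", "ctx")` resolves to the forwarding overload in `Base`, which then dispatches to the override.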

- return parser.parseCommaSeparatedList(delimiter, parseElt, contextMessage);
+ ParseResult parseCommaSeparatedList(
+ Delimiter delimiter, function_ref<ParseResult()> parseElt,
+ std::optional<function_ref<ParseResult()>> parseSuffix,
Member

Can you move parseSuffix to the end and make it function_ref<ParseResult()> parseSuffix = nullptr? function_ref is rarely used with std::optional. Also, maybe this function would not be needed at all then?

Contributor Author

The function is needed to parse the suffix, which is optional. I'm using a function_ref similar to what we do for parseElt.

- StringRef contextMessage) {
+ ParseResult Parser::parseCommaSeparatedList(
+ Delimiter delimiter, function_ref<ParseResult()> parseElementFn,
+ std::optional<function_ref<ParseResult()>> parseSuffixFn,
Member

Isn't function_ref nullable on its own? This is the case for plain std::function. I'd like to avoid double nullability where possible.
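
The "double nullability" concern can be illustrated with `std::function`, which the comment uses as its point of comparison. The sketch below is only an analogy: `Callback`, `runSuffixOptional` and `runSuffix` are hypothetical names, and `llvm::function_ref` is not used so that the example stays self-contained.

```cpp
#include <cassert>
#include <functional>
#include <optional>

using Callback = std::function<bool()>;

// Double nullability: the optional can be empty, and the callback it holds
// can be empty too, so a careful caller has to check both layers.
bool runSuffixOptional(std::optional<Callback> suffix = std::nullopt) {
  if (suffix && *suffix)
    return (*suffix)();
  return true;  // Nothing to run counts as success.
}

// Single nullability: an empty callback already means "no suffix".
bool runSuffix(Callback suffix = nullptr) {
  if (suffix)
    return suffix();
  return true;  // Nothing to run counts as success.
}
```

The second form has one "absent" state instead of two, which is the simplification the comment asks for.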

matthias-springer added a commit that referenced this pull request Jan 10, 2025
…oads (#122436)

#115808 adds additional `custom<>` parser/printer variants. The overall
list of overloads/variants is getting larger.

This commit removes overloads that are not needed, to keep the
parser/printer simple.
@dcaballe
Copy link
Contributor Author

Thanks for the cleanup, @matthias-springer! I'm revisiting some implementation details and making it more generic. Hang tight! :)

dcaballe added a commit to dcaballe/llvm-project that referenced this pull request Jan 11, 2025
…ter handlers.

This PR addresses part of the feedback provided in llvm#115808.
BaiXilin pushed a commit to BaiXilin/llvm-fix-vnni-instr-types that referenced this pull request Jan 12, 2025
…oads (llvm#122436)

llvm#115808 adds additional `custom<>` parser/printer variants. The overall
list of overloads/variants is getting larger.

This commit removes overloads that are not needed, to keep the
parser/printer simple.
dcaballe added a commit to dcaballe/llvm-project that referenced this pull request Jan 13, 2025
…ter handlers.

This PR addresses part of the feedback provided in llvm#115808.
dcaballe added a commit that referenced this pull request Jan 15, 2025
…ter handlers (#122555)

This PR addresses part of the feedback provided in #115808.
@banach-space
Contributor

Hey @dcaballe, do you have the cycles to progress this? It would be great to see it in-tree :)

@dcaballe
Contributor Author

Yes, actually, I put quite some time into this internally and discussed it with @matthias-springer. Unfortunately, we hit a dead end and I have to backtrack some of the generalization changes, so the current PR is pretty close to what the final state would look like, at least in terms of functionality.
