-
Notifications
You must be signed in to change notification settings - Fork 14.1k
[mlir][vector] Allow integer indices in vector.extract/insert ops #115808
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
`vector.extract` and `vector.insert` can currently take an `i64` constant or an `index` type value as indices. The `index` type will usually lower to an `i32` or `i64` type. However, we are often indexing really small vector dimensions where smaller integers could be used. This PR extends both ops to accept any integer value as indices. For example: ``` %0 = vector.extract %arg0[%i32_idx : i32] : vector<8x16xf32> from vector<4x8x16xf32> %1 = vector.extract %arg0[%i8_idx, %i8_idx : i8] : vector<16xf32> from vector<4x8x16xf32> %2 = vector.extract %arg0[%i8_idx, 5, %i8_idx : i8] : f32 from vector<4x8x16xf32> ``` This led to some changes to the ops' parser and printer. When a value index is provided, the index type is printed as part of the index list. All the value indices provided must match that type. When no value index is provided, no index type is printed.
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir-sme Author: Diego Caballero (dcaballe) Changes
This led to some changes to the ops' parser and printer. When a value index is provided, the index type is printed as part of the index list. All the value indices provided must match that type. When no value index is provided, no index type is printed. Patch is 84.45 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/115808.diff 22 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index c5b08d6aa022b1..dad08305b2a645 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -695,14 +695,14 @@ def Vector_ExtractOp :
%1 = vector.extract %0[3]: vector<8x16xf32> from vector<4x8x16xf32>
%2 = vector.extract %0[2, 1, 3]: f32 from vector<4x8x16xf32>
%3 = vector.extract %1[]: vector<f32> from vector<f32>
- %4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32>
- %5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32>
+ %4 = vector.extract %0[%a, %b, %c : index] : f32 from vector<4x8x16xf32>
+ %5 = vector.extract %0[2, %b : index] : vector<16xf32> from vector<4x8x16xf32>
```
}];
let arguments = (ins
AnyVectorOfAnyRank:$vector,
- Variadic<Index>:$dynamic_position,
+ Variadic<AnySignlessIntegerOrIndex>:$dynamic_position,
DenseI64ArrayAttr:$static_position
);
let results = (outs AnyType:$result);
@@ -737,7 +737,8 @@ def Vector_ExtractOp :
let assemblyFormat = [{
$vector ``
- custom<DynamicIndexList>($dynamic_position, $static_position)
+ custom<SameTypeDynamicIndexList>($dynamic_position, $static_position,
+ type($dynamic_position))
attr-dict `:` type($result) `from` type($vector)
}];
@@ -883,15 +884,15 @@ def Vector_InsertOp :
%2 = vector.insert %0, %1[3] : vector<8x16xf32> into vector<4x8x16xf32>
%5 = vector.insert %3, %4[2, 1, 3] : f32 into vector<4x8x16xf32>
%8 = vector.insert %6, %7[] : f32 into vector<f32>
- %11 = vector.insert %9, %10[%a, %b, %c] : vector<f32> into vector<4x8x16xf32>
- %12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32>
+ %11 = vector.insert %9, %10[%a, %b, %c : index] : vector<f32> into vector<4x8x16xf32>
+ %12 = vector.insert %4, %10[2, %b : index] : vector<16xf32> into vector<4x8x16xf32>
```
}];
let arguments = (ins
AnyType:$source,
AnyVectorOfAnyRank:$dest,
- Variadic<Index>:$dynamic_position,
+ Variadic<AnySignlessIntegerOrIndex>:$dynamic_position,
DenseI64ArrayAttr:$static_position
);
let results = (outs AnyVectorOfAnyRank:$result);
@@ -926,7 +927,9 @@ def Vector_InsertOp :
}];
let assemblyFormat = [{
- $source `,` $dest custom<DynamicIndexList>($dynamic_position, $static_position)
+ $source `,` $dest
+ custom<SameTypeDynamicIndexList>($dynamic_position, $static_position,
+ type($dynamic_position))
attr-dict `:` type($source) `into` type($dest)
}];
@@ -1344,7 +1347,7 @@ def Vector_TransferReadOp :
%a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref<?x?x?x?xf32>
// Update the temporary gathered slice with the individual element
%slice = memref.load %tmp : memref<vector<3x4x5xf32>> -> vector<3x4x5xf32>
- %updated = vector.insert %a, %slice[%i, %j, %k] : f32 into vector<3x4x5xf32>
+ %updated = vector.insert %a, %slice[%i, %j, %k : index] : f32 into vector<3x4x5xf32>
memref.store %updated, %tmp : memref<vector<3x4x5xf32>>
}}}
// At this point we gathered the elements from the original
@@ -1367,7 +1370,7 @@ def Vector_TransferReadOp :
%a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref<?x?x?x?xf32>
%slice = memref.load %tmp : memref<vector<3x4x5xf32>> -> vector<3x4x5xf32>
// Here we only store to the first element in dimension one
- %updated = vector.insert %a, %slice[%i, 0, %k] : f32 into vector<3x4x5xf32>
+ %updated = vector.insert %a, %slice[%i, 0, %k : index] : f32 into vector<3x4x5xf32>
memref.store %updated, %tmp : memref<vector<3x4x5xf32>>
}}
// At this point we gathered the elements from the original
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index a7222794f320b2..699dd1da863b6f 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -794,16 +794,26 @@ class AsmParser {
};
/// Parse a list of comma-separated items with an optional delimiter. If a
- /// delimiter is provided, then an empty list is allowed. If not, then at
+ /// delimiter is provided, then an empty list is allowed. If not, then at
/// least one element will be parsed.
///
+ /// `parseSuffixFn` is an optional function to parse any suffix that can be
+ /// appended to the comma separated list within the delimiter.
+ ///
/// contextMessage is an optional message appended to "expected '('" sorts of
/// diagnostics when parsing the delimeters.
- virtual ParseResult
+ virtual ParseResult parseCommaSeparatedList(
+ Delimiter delimiter, function_ref<ParseResult()> parseElementFn,
+ std::optional<function_ref<ParseResult()>> parseSuffixFn = std::nullopt,
+ StringRef contextMessage = StringRef()) = 0;
+ ParseResult
parseCommaSeparatedList(Delimiter delimiter,
function_ref<ParseResult()> parseElementFn,
- StringRef contextMessage = StringRef()) = 0;
-
+ StringRef contextMessage) {
+ return parseCommaSeparatedList(delimiter, parseElementFn,
+ /*parseSuffixFn=*/std::nullopt,
+ contextMessage);
+ }
/// Parse a comma separated list of elements that must have at least one entry
/// in it.
ParseResult
@@ -1319,6 +1329,9 @@ class AsmParser {
virtual ParseResult
parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
+ /// Parse an optional colon followed by a type.
+ virtual ParseResult parseOptionalColonType(Type &result) = 0;
+
/// Parse a keyword followed by a type.
ParseResult parseKeywordType(const char *keyword, Type &result) {
return failure(parseKeyword(keyword) || parseType(result));
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index 3dcbd2f1af1936..1971c25a8f20b1 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -96,8 +96,10 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
/// in `integers` is `kDynamic` or (2) the next value otherwise. If `valueTypes`
/// is non-empty, it is expected to contain as many elements as `values`
/// indicating their types. This allows idiomatic printing of mixed value and
-/// integer attributes in a list. E.g.
-/// `[%arg0 : index, 7, 42, %arg42 : i32]`.
+/// integer attributes in a list. E.g., `[%arg0 : index, 7, 42, %arg42 : i32]`.
+/// If `hasSameTypeDynamicValues` is `true`, `valueTypes` are expected to be the
+/// same and only one type is printed at the end of the list. E.g.,
+/// `[0, %arg2, 3, %arg42, 2 : i8]`.
///
/// Indices can be scalable. For example, "4" in "[2, [4], 8]" is scalable.
/// This notation is similar to how scalable dims are marked when defining
@@ -108,7 +110,8 @@ void printDynamicIndexList(
OpAsmPrinter &printer, Operation *op, OperandRange values,
ArrayRef<int64_t> integers, ArrayRef<bool> scalables,
TypeRange valueTypes = TypeRange(),
- AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+ AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square,
+ bool hasSameTypeDynamicValues = false);
inline void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
OperandRange values,
ArrayRef<int64_t> integers,
@@ -123,6 +126,13 @@ inline void printDynamicIndexList(
return printDynamicIndexList(printer, op, values, integers, {}, valueTypes,
delimiter);
}
+inline void printSameTypeDynamicIndexList(
+ OpAsmPrinter &printer, Operation *op, OperandRange values,
+ ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
+ AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+ return printDynamicIndexList(printer, op, values, integers, {}, valueTypes,
+ delimiter, /*hasSameTypeDynamicValues=*/true);
+}
/// Parser hook for custom directive in assemblyFormat.
///
@@ -150,7 +160,8 @@ ParseResult parseDynamicIndexList(
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals,
SmallVectorImpl<Type> *valueTypes = nullptr,
- AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+ AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square,
+ bool hasSameTypeDynamicValues = false);
inline ParseResult
parseDynamicIndexList(OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
@@ -188,6 +199,16 @@ inline ParseResult parseDynamicIndexList(
return parseDynamicIndexList(parser, values, integers, scalableVals,
&valueTypes, delimiter);
}
+inline ParseResult parseSameTypeDynamicIndexList(
+ OpAsmParser &parser,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+ DenseI64ArrayAttr &integers, SmallVectorImpl<Type> &valueTypes,
+ AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+ DenseBoolArrayAttr scalableVals = {};
+ return parseDynamicIndexList(parser, values, integers, scalableVals,
+ &valueTypes, delimiter,
+ /*hasSameTypeDynamicValues=*/true);
+}
/// Verify that a the `values` has as many elements as the number of entries in
/// `attr` for which `isDynamic` evaluates to true.
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index 04250f63dcd253..4d5b93ec09d175 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -340,12 +340,16 @@ class AsmParserImpl : public BaseT {
/// Parse a list of comma-separated items with an optional delimiter. If a
/// delimiter is provided, then an empty list is allowed. If not, then at
/// least one element will be parsed.
- ParseResult parseCommaSeparatedList(Delimiter delimiter,
- function_ref<ParseResult()> parseElt,
- StringRef contextMessage) override {
- return parser.parseCommaSeparatedList(delimiter, parseElt, contextMessage);
+ ParseResult parseCommaSeparatedList(
+ Delimiter delimiter, function_ref<ParseResult()> parseElt,
+ std::optional<function_ref<ParseResult()>> parseSuffix,
+ StringRef contextMessage) override {
+ return parser.parseCommaSeparatedList(delimiter, parseElt, parseSuffix,
+ contextMessage);
}
+ using BaseT::parseCommaSeparatedList;
+
//===--------------------------------------------------------------------===//
// Keyword Parsing
//===--------------------------------------------------------------------===//
@@ -590,6 +594,17 @@ class AsmParserImpl : public BaseT {
return parser.parseTypeListNoParens(result);
}
+ /// Parse an optional colon followed by a type.
+ ParseResult parseOptionalColonType(Type &result) override {
+ SmallVector<Type, 1> types;
+ ParseResult parseResult = parseOptionalColonTypeList(types);
+ if (llvm::succeeded(parseResult) && types.size() > 1)
+ return emitError(getCurrentLocation(), "expected single type");
+ if (!types.empty())
+ result = types[0];
+ return parseResult;
+ }
+
ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
bool allowDynamic,
bool withTrailingX) override {
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 8f19487d80fa39..6476910f71eb7f 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -80,10 +80,10 @@ AsmParserCodeCompleteContext::~AsmParserCodeCompleteContext() = default;
/// Parse a list of comma-separated items with an optional delimiter. If a
/// delimiter is provided, then an empty list is allowed. If not, then at
/// least one element will be parsed.
-ParseResult
-Parser::parseCommaSeparatedList(Delimiter delimiter,
- function_ref<ParseResult()> parseElementFn,
- StringRef contextMessage) {
+ParseResult Parser::parseCommaSeparatedList(
+ Delimiter delimiter, function_ref<ParseResult()> parseElementFn,
+ std::optional<function_ref<ParseResult()>> parseSuffixFn,
+ StringRef contextMessage) {
switch (delimiter) {
case Delimiter::None:
break;
@@ -144,6 +144,9 @@ Parser::parseCommaSeparatedList(Delimiter delimiter,
return failure();
}
+ if (parseSuffixFn && (*parseSuffixFn)())
+ return failure();
+
switch (delimiter) {
case Delimiter::None:
return success();
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index bf91831798056b..1ebca05bbcb2ef 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -46,10 +46,17 @@ class Parser {
/// Parse a list of comma-separated items with an optional delimiter. If a
/// delimiter is provided, then an empty list is allowed. If not, then at
/// least one element will be parsed.
+ ParseResult parseCommaSeparatedList(
+ Delimiter delimiter, function_ref<ParseResult()> parseElementFn,
+ std::optional<function_ref<ParseResult()>> parseSuffixFn = std::nullopt,
+ StringRef contextMessage = StringRef());
ParseResult
parseCommaSeparatedList(Delimiter delimiter,
function_ref<ParseResult()> parseElementFn,
- StringRef contextMessage = StringRef());
+ StringRef contextMessage) {
+ return parseCommaSeparatedList(delimiter, parseElementFn, std::nullopt,
+ contextMessage);
+ }
/// Parse a comma separated list of elements that must have at least one entry
/// in it.
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 55965d9c2a531d..c5c3353bf0477f 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -501,13 +501,14 @@ struct VectorOuterProductToArmSMELowering
///
/// Example:
/// ```
-/// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
+/// %el = vector.extract %tile[%row, %col : index] : i32 from
+/// vector<[4]x[4]xi32>
/// ```
/// Becomes:
/// ```
/// %slice = arm_sme.extract_tile_slice %tile[%row]
/// : vector<[4]xi32> from vector<[4]x[4]xi32>
-/// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
+/// %el = vector.extract %slice[%col : index] : i32 from vector<[4]xi32>
/// ```
struct VectorExtractToArmSMELowering
: public OpRewritePattern<vector::ExtractOp> {
@@ -561,8 +562,9 @@ struct VectorExtractToArmSMELowering
/// ```
/// %slice = arm_sme.extract_tile_slice %tile[%row]
/// : vector<[4]xi32> from vector<[4]x[4]xi32>
-/// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
-/// %new_tile = arm_sme.insert_tile_slice %new_slice, %tile[%row]
+/// %new_slice = vector.insert %el, %slice[%col : index] : i32 into
+/// vector<[4]xi32> %new_tile = arm_sme.insert_tile_slice %new_slice,
+/// %tile[%row]
/// : vector<[4]xi32> into vector<[4]x[4]xi32>
/// ```
struct VectorInsertToArmSMELowering
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 3a4dc806efe976..b623a86c53ee71 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1050,10 +1050,10 @@ getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
/// %vscale = vector.vscale
/// %c4_vscale = arith.muli %vscale, %c4 : index
/// scf.for %idx = %c0 to %c4_vscale step %c1 {
-/// %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32>
-/// %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32>
-/// %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32>
-/// %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32>
+/// %4 = vector.extract %0[%idx : index] : f32 from vector<[4]xf32>
+/// %5 = vector.extract %1[%idx : index] : f32 from vector<[4]xf32>
+/// %6 = vector.extract %2[%idx : index] : f32 from vector<[4]xf32>
+/// %7 = vector.extract %3[%idx : index] : f32 from vector<[4]xf32>
/// %slice_i = affine.apply #map(%idx)[%i]
/// %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32>
/// vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]}
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index ca33636336bf0c..8e44ff60eec874 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -114,7 +114,8 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
OperandRange values,
ArrayRef<int64_t> integers,
ArrayRef<bool> scalables, TypeRange valueTypes,
- AsmParser::Delimiter delimiter) {
+ AsmParser::Delimiter delimiter,
+ bool hasSameTypeDynamicValues) {
char leftDelimiter = getLeftDelimiter(delimiter);
char rightDelimiter = getRightDelimiter(delimiter);
printer << leftDelimiter;
@@ -130,7 +131,7 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
printer << "[";
if (ShapedType::isDynamic(integer)) {
printer << values[dynamicValIdx];
- if (!valueTypes.empty())
+ if (!hasSameTypeDynamicValues && !valueTypes.empty())
printer << " : " << valueTypes[dynamicValIdx];
++dynamicValIdx;
} else {
@@ -142,6 +143,13 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
scalableIndexIdx++;
});
+ if (hasSameTypeDynamicValues && !valueTypes.empty()) {
+ assert(std::all_of(valueTypes.begin(), valueTypes.end(),
+ [&](Type type) { return type == valueTypes[0]; }) &&
+ "Expected the same value types");
+ printer << " : " << valueTypes[0];
+ }
+
printer << rightDelimiter;
}
@@ -149,7 +157,8 @@ ParseResult mlir::parseDynamicIndexList(
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalables,
- SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) {
+ SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter,
+ bool hasSameTypeDynamicValues) {
SmallVector<int64_t, 4> integerVals;
SmallVector<bool, 4> scalableVals;
@@ -163,7 +172,8 @@ ParseResult mlir::parseDynamicIndexList(
if (res.has_value() && succeeded(res.value())) {
values.push_back(operand);
integerVals.push_back(ShapedType::kDynamic);
- if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
+ if (!hasSameTypeDynamicValues && valueTypes &&
+ parser.parseColonType(valueTypes->emplace_back()))
return failure();
} else {
int64_t integer;
@@ -178,10 +188,34 @@ ParseResult mlir::parseDynamicIndexList(
return failure();
return success();
};
+ auto parseColonType = [&]() -> ParseResult {
+ if (hasSameTypeDynamicValues) {
+ assert(valueTypes && "Expected non-null value types");
+ assert(valueTypes->empty() && "Expected no parsed value types");
+
+ Type dynValType;
+ if (parser.parseOptionalColonType(dynValType))
+ return failure();
+
+ if (!dynValType && !values.empty())
+ return parser.emitError(parser.getNameLoc())
+ << "expected a type for dynamic indices";
+ if (dynValType) {
+ if (values.empty())
+ return parser.emitError(parser.getNameLoc())
+ << "expected no type for constant indices";
+
+ // Broadcast the single type to all the dynamic values.
+ valueTypes->append(values.size(), dynValType);
+ ...
[truncated]
|
@llvm/pr-subscribers-mlir-spirv Author: Diego Caballero (dcaballe) Changes
This led to some changes to the ops' parser and printer. When a value index is provided, the index type is printed as part of the index list. All the value indices provided must match that type. When no value index is provided, no index type is printed. Patch is 84.45 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/115808.diff 22 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index c5b08d6aa022b1..dad08305b2a645 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -695,14 +695,14 @@ def Vector_ExtractOp :
%1 = vector.extract %0[3]: vector<8x16xf32> from vector<4x8x16xf32>
%2 = vector.extract %0[2, 1, 3]: f32 from vector<4x8x16xf32>
%3 = vector.extract %1[]: vector<f32> from vector<f32>
- %4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32>
- %5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32>
+ %4 = vector.extract %0[%a, %b, %c : index] : f32 from vector<4x8x16xf32>
+ %5 = vector.extract %0[2, %b : index] : vector<16xf32> from vector<4x8x16xf32>
```
}];
let arguments = (ins
AnyVectorOfAnyRank:$vector,
- Variadic<Index>:$dynamic_position,
+ Variadic<AnySignlessIntegerOrIndex>:$dynamic_position,
DenseI64ArrayAttr:$static_position
);
let results = (outs AnyType:$result);
@@ -737,7 +737,8 @@ def Vector_ExtractOp :
let assemblyFormat = [{
$vector ``
- custom<DynamicIndexList>($dynamic_position, $static_position)
+ custom<SameTypeDynamicIndexList>($dynamic_position, $static_position,
+ type($dynamic_position))
attr-dict `:` type($result) `from` type($vector)
}];
@@ -883,15 +884,15 @@ def Vector_InsertOp :
%2 = vector.insert %0, %1[3] : vector<8x16xf32> into vector<4x8x16xf32>
%5 = vector.insert %3, %4[2, 1, 3] : f32 into vector<4x8x16xf32>
%8 = vector.insert %6, %7[] : f32 into vector<f32>
- %11 = vector.insert %9, %10[%a, %b, %c] : vector<f32> into vector<4x8x16xf32>
- %12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32>
+ %11 = vector.insert %9, %10[%a, %b, %c : index] : vector<f32> into vector<4x8x16xf32>
+ %12 = vector.insert %4, %10[2, %b : index] : vector<16xf32> into vector<4x8x16xf32>
```
}];
let arguments = (ins
AnyType:$source,
AnyVectorOfAnyRank:$dest,
- Variadic<Index>:$dynamic_position,
+ Variadic<AnySignlessIntegerOrIndex>:$dynamic_position,
DenseI64ArrayAttr:$static_position
);
let results = (outs AnyVectorOfAnyRank:$result);
@@ -926,7 +927,9 @@ def Vector_InsertOp :
}];
let assemblyFormat = [{
- $source `,` $dest custom<DynamicIndexList>($dynamic_position, $static_position)
+ $source `,` $dest
+ custom<SameTypeDynamicIndexList>($dynamic_position, $static_position,
+ type($dynamic_position))
attr-dict `:` type($source) `into` type($dest)
}];
@@ -1344,7 +1347,7 @@ def Vector_TransferReadOp :
%a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref<?x?x?x?xf32>
// Update the temporary gathered slice with the individual element
%slice = memref.load %tmp : memref<vector<3x4x5xf32>> -> vector<3x4x5xf32>
- %updated = vector.insert %a, %slice[%i, %j, %k] : f32 into vector<3x4x5xf32>
+ %updated = vector.insert %a, %slice[%i, %j, %k : index] : f32 into vector<3x4x5xf32>
memref.store %updated, %tmp : memref<vector<3x4x5xf32>>
}}}
// At this point we gathered the elements from the original
@@ -1367,7 +1370,7 @@ def Vector_TransferReadOp :
%a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref<?x?x?x?xf32>
%slice = memref.load %tmp : memref<vector<3x4x5xf32>> -> vector<3x4x5xf32>
// Here we only store to the first element in dimension one
- %updated = vector.insert %a, %slice[%i, 0, %k] : f32 into vector<3x4x5xf32>
+ %updated = vector.insert %a, %slice[%i, 0, %k : index] : f32 into vector<3x4x5xf32>
memref.store %updated, %tmp : memref<vector<3x4x5xf32>>
}}
// At this point we gathered the elements from the original
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index a7222794f320b2..699dd1da863b6f 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -794,16 +794,26 @@ class AsmParser {
};
/// Parse a list of comma-separated items with an optional delimiter. If a
- /// delimiter is provided, then an empty list is allowed. If not, then at
+ /// delimiter is provided, then an empty list is allowed. If not, then at
/// least one element will be parsed.
///
+ /// `parseSuffixFn` is an optional function to parse any suffix that can be
+ /// appended to the comma separated list within the delimiter.
+ ///
/// contextMessage is an optional message appended to "expected '('" sorts of
/// diagnostics when parsing the delimeters.
- virtual ParseResult
+ virtual ParseResult parseCommaSeparatedList(
+ Delimiter delimiter, function_ref<ParseResult()> parseElementFn,
+ std::optional<function_ref<ParseResult()>> parseSuffixFn = std::nullopt,
+ StringRef contextMessage = StringRef()) = 0;
+ ParseResult
parseCommaSeparatedList(Delimiter delimiter,
function_ref<ParseResult()> parseElementFn,
- StringRef contextMessage = StringRef()) = 0;
-
+ StringRef contextMessage) {
+ return parseCommaSeparatedList(delimiter, parseElementFn,
+ /*parseSuffixFn=*/std::nullopt,
+ contextMessage);
+ }
/// Parse a comma separated list of elements that must have at least one entry
/// in it.
ParseResult
@@ -1319,6 +1329,9 @@ class AsmParser {
virtual ParseResult
parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
+ /// Parse an optional colon followed by a type.
+ virtual ParseResult parseOptionalColonType(Type &result) = 0;
+
/// Parse a keyword followed by a type.
ParseResult parseKeywordType(const char *keyword, Type &result) {
return failure(parseKeyword(keyword) || parseType(result));
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index 3dcbd2f1af1936..1971c25a8f20b1 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -96,8 +96,10 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
/// in `integers` is `kDynamic` or (2) the next value otherwise. If `valueTypes`
/// is non-empty, it is expected to contain as many elements as `values`
/// indicating their types. This allows idiomatic printing of mixed value and
-/// integer attributes in a list. E.g.
-/// `[%arg0 : index, 7, 42, %arg42 : i32]`.
+/// integer attributes in a list. E.g., `[%arg0 : index, 7, 42, %arg42 : i32]`.
+/// If `hasSameTypeDynamicValues` is `true`, `valueTypes` are expected to be the
+/// same and only one type is printed at the end of the list. E.g.,
+/// `[0, %arg2, 3, %arg42, 2 : i8]`.
///
/// Indices can be scalable. For example, "4" in "[2, [4], 8]" is scalable.
/// This notation is similar to how scalable dims are marked when defining
@@ -108,7 +110,8 @@ void printDynamicIndexList(
OpAsmPrinter &printer, Operation *op, OperandRange values,
ArrayRef<int64_t> integers, ArrayRef<bool> scalables,
TypeRange valueTypes = TypeRange(),
- AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+ AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square,
+ bool hasSameTypeDynamicValues = false);
inline void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
OperandRange values,
ArrayRef<int64_t> integers,
@@ -123,6 +126,13 @@ inline void printDynamicIndexList(
return printDynamicIndexList(printer, op, values, integers, {}, valueTypes,
delimiter);
}
+inline void printSameTypeDynamicIndexList(
+ OpAsmPrinter &printer, Operation *op, OperandRange values,
+ ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
+ AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+ return printDynamicIndexList(printer, op, values, integers, {}, valueTypes,
+ delimiter, /*hasSameTypeDynamicValues=*/true);
+}
/// Parser hook for custom directive in assemblyFormat.
///
@@ -150,7 +160,8 @@ ParseResult parseDynamicIndexList(
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals,
SmallVectorImpl<Type> *valueTypes = nullptr,
- AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+ AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square,
+ bool hasSameTypeDynamicValues = false);
inline ParseResult
parseDynamicIndexList(OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
@@ -188,6 +199,16 @@ inline ParseResult parseDynamicIndexList(
return parseDynamicIndexList(parser, values, integers, scalableVals,
&valueTypes, delimiter);
}
+inline ParseResult parseSameTypeDynamicIndexList(
+ OpAsmParser &parser,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+ DenseI64ArrayAttr &integers, SmallVectorImpl<Type> &valueTypes,
+ AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+ DenseBoolArrayAttr scalableVals = {};
+ return parseDynamicIndexList(parser, values, integers, scalableVals,
+ &valueTypes, delimiter,
+ /*hasSameTypeDynamicValues=*/true);
+}
/// Verify that a the `values` has as many elements as the number of entries in
/// `attr` for which `isDynamic` evaluates to true.
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index 04250f63dcd253..4d5b93ec09d175 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -340,12 +340,16 @@ class AsmParserImpl : public BaseT {
/// Parse a list of comma-separated items with an optional delimiter. If a
/// delimiter is provided, then an empty list is allowed. If not, then at
/// least one element will be parsed.
- ParseResult parseCommaSeparatedList(Delimiter delimiter,
- function_ref<ParseResult()> parseElt,
- StringRef contextMessage) override {
- return parser.parseCommaSeparatedList(delimiter, parseElt, contextMessage);
+ ParseResult parseCommaSeparatedList(
+ Delimiter delimiter, function_ref<ParseResult()> parseElt,
+ std::optional<function_ref<ParseResult()>> parseSuffix,
+ StringRef contextMessage) override {
+ return parser.parseCommaSeparatedList(delimiter, parseElt, parseSuffix,
+ contextMessage);
}
+ using BaseT::parseCommaSeparatedList;
+
//===--------------------------------------------------------------------===//
// Keyword Parsing
//===--------------------------------------------------------------------===//
@@ -590,6 +594,17 @@ class AsmParserImpl : public BaseT {
return parser.parseTypeListNoParens(result);
}
+ /// Parse an optional colon followed by a type.
+ ParseResult parseOptionalColonType(Type &result) override {
+ SmallVector<Type, 1> types;
+ ParseResult parseResult = parseOptionalColonTypeList(types);
+ if (llvm::succeeded(parseResult) && types.size() > 1)
+ return emitError(getCurrentLocation(), "expected single type");
+ if (!types.empty())
+ result = types[0];
+ return parseResult;
+ }
+
ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
bool allowDynamic,
bool withTrailingX) override {
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 8f19487d80fa39..6476910f71eb7f 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -80,10 +80,10 @@ AsmParserCodeCompleteContext::~AsmParserCodeCompleteContext() = default;
/// Parse a list of comma-separated items with an optional delimiter. If a
/// delimiter is provided, then an empty list is allowed. If not, then at
/// least one element will be parsed.
-ParseResult
-Parser::parseCommaSeparatedList(Delimiter delimiter,
- function_ref<ParseResult()> parseElementFn,
- StringRef contextMessage) {
+ParseResult Parser::parseCommaSeparatedList(
+ Delimiter delimiter, function_ref<ParseResult()> parseElementFn,
+ std::optional<function_ref<ParseResult()>> parseSuffixFn,
+ StringRef contextMessage) {
switch (delimiter) {
case Delimiter::None:
break;
@@ -144,6 +144,9 @@ Parser::parseCommaSeparatedList(Delimiter delimiter,
return failure();
}
+ if (parseSuffixFn && (*parseSuffixFn)())
+ return failure();
+
switch (delimiter) {
case Delimiter::None:
return success();
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index bf91831798056b..1ebca05bbcb2ef 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -46,10 +46,17 @@ class Parser {
/// Parse a list of comma-separated items with an optional delimiter. If a
/// delimiter is provided, then an empty list is allowed. If not, then at
/// least one element will be parsed.
+ ParseResult parseCommaSeparatedList(
+ Delimiter delimiter, function_ref<ParseResult()> parseElementFn,
+ std::optional<function_ref<ParseResult()>> parseSuffixFn = std::nullopt,
+ StringRef contextMessage = StringRef());
ParseResult
parseCommaSeparatedList(Delimiter delimiter,
function_ref<ParseResult()> parseElementFn,
- StringRef contextMessage = StringRef());
+ StringRef contextMessage) {
+ return parseCommaSeparatedList(delimiter, parseElementFn, std::nullopt,
+ contextMessage);
+ }
/// Parse a comma separated list of elements that must have at least one entry
/// in it.
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 55965d9c2a531d..c5c3353bf0477f 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -501,13 +501,14 @@ struct VectorOuterProductToArmSMELowering
///
/// Example:
/// ```
-/// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
+/// %el = vector.extract %tile[%row, %col : index] : i32 from
+/// vector<[4]x[4]xi32>
/// ```
/// Becomes:
/// ```
/// %slice = arm_sme.extract_tile_slice %tile[%row]
/// : vector<[4]xi32> from vector<[4]x[4]xi32>
-/// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
+/// %el = vector.extract %slice[%col : index] : i32 from vector<[4]xi32>
/// ```
struct VectorExtractToArmSMELowering
: public OpRewritePattern<vector::ExtractOp> {
@@ -561,8 +562,9 @@ struct VectorExtractToArmSMELowering
/// ```
/// %slice = arm_sme.extract_tile_slice %tile[%row]
/// : vector<[4]xi32> from vector<[4]x[4]xi32>
-/// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
-/// %new_tile = arm_sme.insert_tile_slice %new_slice, %tile[%row]
+/// %new_slice = vector.insert %el, %slice[%col : index] : i32 into
+/// vector<[4]xi32> %new_tile = arm_sme.insert_tile_slice %new_slice,
+/// %tile[%row]
/// : vector<[4]xi32> into vector<[4]x[4]xi32>
/// ```
struct VectorInsertToArmSMELowering
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 3a4dc806efe976..b623a86c53ee71 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1050,10 +1050,10 @@ getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
/// %vscale = vector.vscale
/// %c4_vscale = arith.muli %vscale, %c4 : index
/// scf.for %idx = %c0 to %c4_vscale step %c1 {
-/// %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32>
-/// %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32>
-/// %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32>
-/// %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32>
+/// %4 = vector.extract %0[%idx : index] : f32 from vector<[4]xf32>
+/// %5 = vector.extract %1[%idx : index] : f32 from vector<[4]xf32>
+/// %6 = vector.extract %2[%idx : index] : f32 from vector<[4]xf32>
+/// %7 = vector.extract %3[%idx : index] : f32 from vector<[4]xf32>
/// %slice_i = affine.apply #map(%idx)[%i]
/// %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32>
/// vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]}
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index ca33636336bf0c..8e44ff60eec874 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -114,7 +114,8 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
OperandRange values,
ArrayRef<int64_t> integers,
ArrayRef<bool> scalables, TypeRange valueTypes,
- AsmParser::Delimiter delimiter) {
+ AsmParser::Delimiter delimiter,
+ bool hasSameTypeDynamicValues) {
char leftDelimiter = getLeftDelimiter(delimiter);
char rightDelimiter = getRightDelimiter(delimiter);
printer << leftDelimiter;
@@ -130,7 +131,7 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
printer << "[";
if (ShapedType::isDynamic(integer)) {
printer << values[dynamicValIdx];
- if (!valueTypes.empty())
+ if (!hasSameTypeDynamicValues && !valueTypes.empty())
printer << " : " << valueTypes[dynamicValIdx];
++dynamicValIdx;
} else {
@@ -142,6 +143,13 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
scalableIndexIdx++;
});
+ if (hasSameTypeDynamicValues && !valueTypes.empty()) {
+ assert(std::all_of(valueTypes.begin(), valueTypes.end(),
+ [&](Type type) { return type == valueTypes[0]; }) &&
+ "Expected the same value types");
+ printer << " : " << valueTypes[0];
+ }
+
printer << rightDelimiter;
}
@@ -149,7 +157,8 @@ ParseResult mlir::parseDynamicIndexList(
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalables,
- SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) {
+ SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter,
+ bool hasSameTypeDynamicValues) {
SmallVector<int64_t, 4> integerVals;
SmallVector<bool, 4> scalableVals;
@@ -163,7 +172,8 @@ ParseResult mlir::parseDynamicIndexList(
if (res.has_value() && succeeded(res.value())) {
values.push_back(operand);
integerVals.push_back(ShapedType::kDynamic);
- if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
+ if (!hasSameTypeDynamicValues && valueTypes &&
+ parser.parseColonType(valueTypes->emplace_back()))
return failure();
} else {
int64_t integer;
@@ -178,10 +188,34 @@ ParseResult mlir::parseDynamicIndexList(
return failure();
return success();
};
+ auto parseColonType = [&]() -> ParseResult {
+ if (hasSameTypeDynamicValues) {
+ assert(valueTypes && "Expected non-null value types");
+ assert(valueTypes->empty() && "Expected no parsed value types");
+
+ Type dynValType;
+ if (parser.parseOptionalColonType(dynValType))
+ return failure();
+
+ if (!dynValType && !values.empty())
+ return parser.emitError(parser.getNameLoc())
+ << "expected a type for dynamic indices";
+ if (dynValType) {
+ if (values.empty())
+ return parser.emitError(parser.getNameLoc())
+ << "expected no type for constant indices";
+
+ // Broadcast the single type to all the dynamic values.
+ valueTypes->append(values.size(), dynValType);
+ ...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for working on this - mostly makes sense, but I have a few questions :)
What's the end-goal here - this would be quite desirable for vector.gather
. Importantly, how should we decide what type to use for the index variables? This change creates a tricky decision point for code-gen 🤔
Btw, it would be good to add new tests to vector-to-llvm.mlir to demonstrate the impact of this on the actual "end result" (i.e. LLVM IR). I guess this makes little sense if things don't change at the LLVM level?
@@ -271,6 +304,38 @@ func.func @insert_0d(%a: f32, %b: vector<f32>) { | |||
%1 = vector.insert %a, %b[0] : f32 into vector<f32> | |||
} | |||
|
|||
// ----- | |||
func.func @extract_vector_mixed_index_types(%arg0 : f32, %arg1 : vector<8x16xf32>, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
func.func @extract_vector_mixed_index_types(%arg0 : f32, %arg1 : vector<8x16xf32>, | |
func.func @insert_vector_mixed_index_types(%arg0 : f32, %arg1 : vector<8x16xf32>, |
Same comment for the tests below. Note that you are inserting a single value rather than a vector (@extract_vector
-> @insert_vector
-> @insert_value
?)
mlir/test/Dialect/Vector/ops.mlir
Outdated
|
||
// CHECK-LABEL: @extract_val_int | ||
// CHECK-SAME: %[[VEC:.+]]: vector<4x8x16xf32>, %[[I32_IDX:.+]]: i32, %[[I8_IDX:.+]]: i8 | ||
func.func @extract_val_int(%arg0: vector<4x8x16xf32>, %i32_idx: i32, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nit] This is clear today, but I know that my future self will be grateful for a bit descriptive name :)
func.func @extract_val_int(%arg0: vector<4x8x16xf32>, %i32_idx: i32, | |
func.func @extract_val_idx_as_int(%arg0: vector<4x8x16xf32>, %i32_idx: i32, |
@@ -274,7 +274,7 @@ func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vect | |||
// CHECK: spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32> | |||
func.func @insert_dynamic_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> { | |||
%idx = arith.constant 2 : index | |||
%0 = vector.insert %val, %arg0[%idx] : f32 into vector<4xf32> | |||
%0 = vector.insert %val, %arg0[%idx : index] : f32 into vector<4xf32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you also add tests for i8 and maybe i1 as the insert/extract index types? These require type conversion in the general case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey Jakub, I added i32 and i8 tests but they need to modify the spirv conversion, as you mentioned. Would you mind helping with that? I have no idea about that pass. How would you like to proceed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you file an issue for this and link to it in the PR description and perhaps in some TODO in the code?
The end goal is to be able to default to the widest type but use narrower types when needed and it's safe. That should be simple enough. This shouldn't add more complexity to codegen than we already have. We already have to deal with I would go even further:
Exactly! The same principle applies. However, note that
AFAIK, LLVM's |
So I've been trying to figure out the right mechanism to select the right index size 😅 Suggestions are much appreciated :) At a very coarse grain level we could use the architecture pointer size, but this way we'd be mostly switching between 32 and 64 bits. #not-good-enough :)
Are you thinking that the default for |
Yes and no. What I mean here is that gather indices are limited to the memory in the system. Extracts/inserts... not necessarily... For example, could we create a |
Is this something we can land already? Any other comments? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thank you!
Please wait for @kuhar to also approve.
New year's ping :) Hopefully we can land it before it gets too stale. |
@@ -96,8 +96,10 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final | |||
/// in `integers` is `kDynamic` or (2) the next value otherwise. If `valueTypes` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The grammar in this sentence is broken: is kDynamic
. Can you fix that also?
/// integer attributes in a list. E.g., `[%arg0 : index, 7, 42, %arg42 : i32]`. | ||
/// If `hasSameTypeDynamicValues` is `true`, `valueTypes` are expected to be the | ||
/// same and only one type is printed at the end of the list. E.g., | ||
/// `[0, %arg2, 3, %arg42, 2 : i8]`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Clarify that the number of types in valueTypes
must match the number of dynamic elements, even if hasSameTypeDynamicValues
is set.
Btw, have you considered changing the API such that valueTypes
contains only a single value in case of hasSameTypeDynamicValues
? That would seem more natural to me.
@@ -96,8 +96,10 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final | |||
/// in `integers` is `kDynamic` or (2) the next value otherwise. If `valueTypes` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why does this function have valueTypes
parameter? The type can be taken from the SSA values in values
. Is it possible to remove valueTypes
? I think a bool printTypes
should be sufficient.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
valueTypes
is needed for integers
. We are representing integers
with int64_t
but their actual type comes from valueTypes
. Perhaps we should rename this to itemTypes
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
values
can also be empty
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wait, the types are not used for integers
(!). I think valueTypes
is needed to match the function signature expected by custom($values, $integers, type($values)). We could verify that both $values
and $type($values)
match and then use just one of them.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I looked into this a bit. I think valueTypes
in needed in the parser, so that the user-specified types can be checked against the actual types in resolveOperands
. printDynamicIndexList
just has it for consistency.
@@ -114,7 +114,8 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, | |||
OperandRange values, | |||
ArrayRef<int64_t> integers, | |||
ArrayRef<bool> scalables, TypeRange valueTypes, | |||
AsmParser::Delimiter delimiter) { | |||
AsmParser::Delimiter delimiter, | |||
bool hasSameTypeDynamicValues) { | |||
char leftDelimiter = getLeftDelimiter(delimiter); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In case we keep the valueTypes
parameter, I think there should be an assert
that checks the number of elements in valueTypes
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, can we assert that TypeRange(values) == valueTypes
?
@@ -142,14 +143,22 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, | |||
scalableIndexIdx++; | |||
}); | |||
|
|||
if (hasSameTypeDynamicValues && !valueTypes.empty()) { | |||
assert(std::all_of(valueTypes.begin(), valueTypes.end(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
llvm::all_equal
@@ -178,10 +188,34 @@ ParseResult mlir::parseDynamicIndexList( | |||
return failure(); | |||
return success(); | |||
}; | |||
auto parseColonType = [&]() -> ParseResult { | |||
if (hasSameTypeDynamicValues) { | |||
assert(valueTypes && "Expected non-null value types"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is valueTypes
requires when hasSameTypeDynamicValues
is "true", but it is not required when hasSameTypeDynamicValues
is "false"?
} | ||
|
||
using BaseT::parseCommaSeparatedList; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are there two overloads, one calling the super implementation and the other one calling parser.parseCommaSeparatedList
?
return parser.parseCommaSeparatedList(delimiter, parseElt, contextMessage); | ||
ParseResult parseCommaSeparatedList( | ||
Delimiter delimiter, function_ref<ParseResult()> parseElt, | ||
std::optional<function_ref<ParseResult()>> parseSuffix, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you move parseSuffix
to the end and make it function_ref<ParseResult()> parseSuffix = nullptr
? function_ref
is rarely used with std::optional
. Also, maybe this function would not be needed at all then?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The function is needed to parse the suffix, which is optional. I'm using a function_ref
similar to what we do for parseElt
.
StringRef contextMessage) { | ||
ParseResult Parser::parseCommaSeparatedList( | ||
Delimiter delimiter, function_ref<ParseResult()> parseElementFn, | ||
std::optional<function_ref<ParseResult()>> parseSuffixFn, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't function_ref nullable on its own? This is the case for plain std::function
. I'd like to avoid double nullability where possible.
Thanks for the cleanup, @matthias-springer! I'm revisiting some implementation details and making it more generic. Hang tight! :) |
…ter handlers. This PR addresses part of the feedback provided in llvm#115808.
…ter handlers. This PR addresses part of the feedback provided in llvm#115808.
…ter handlers. This PR addresses part of the feedback provided in llvm#115808.
…oads (llvm#122436) llvm#115808 adds additional `custom<>` parser/printer variants. The overall list of overloads/variants is getting larger. This commit removes overloads that are not needed, to keep the parser/printer simple.
…ter handlers. This PR addresses part of the feedback provided in llvm#115808.
hey @dcaballe , do you have the cycles to progress this? It would be great to see it in-tree :) |
Yes, actually, I put quite some time on this internally and was discussing with @matthias-springer about it. Unfortunately, we hit a dead-end and I have to backtrack some of the generalization changes so the current PR is pretty close to how the final state would look like, at least in terms of functionality. |
vector.extract
andvector.insert
can currently take ani64
constant or anindex
type value as indices. Theindex
type will usually lower to ani32
ori64
type. However, we are often indexing really small vector dimensions where smaller integers could be used. This PR extends both ops to accept any integer value as indices. For example:This led to some changes to the ops' parser and printer. When a value index is provided, the index type is printed as part of the index list. All the value indices provided must match that type. When no value index is provided, no index type is printed.