-
Notifications
You must be signed in to change notification settings - Fork 14.1k
[mlir][vector] Allow integer indices in vector.extract/insert ops #115808
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -96,8 +96,10 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final | |
/// in `integers` is `kDynamic` or (2) the next value otherwise. If `valueTypes` | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why does this function have There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wait, the types are not used for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I looked into this a bit. I think |
||
/// is non-empty, it is expected to contain as many elements as `values` | ||
/// indicating their types. This allows idiomatic printing of mixed value and | ||
/// integer attributes in a list. E.g. | ||
/// `[%arg0 : index, 7, 42, %arg42 : i32]`. | ||
/// integer attributes in a list. E.g., `[%arg0 : index, 7, 42, %arg42 : i32]`. | ||
/// If `hasSameTypeDynamicValues` is `true`, `valueTypes` are expected to be the | ||
/// same and only one type is printed at the end of the list. E.g., | ||
/// `[0, %arg2, 3, %arg42, 2 : i8]`. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Clarify that the number of types in Btw, have you considered changing the API such that |
||
/// | ||
/// Indices can be scalable. For example, "4" in "[2, [4], 8]" is scalable. | ||
/// This notation is similar to how scalable dims are marked when defining | ||
|
@@ -108,7 +110,8 @@ void printDynamicIndexList( | |
OpAsmPrinter &printer, Operation *op, OperandRange values, | ||
ArrayRef<int64_t> integers, ArrayRef<bool> scalables, | ||
TypeRange valueTypes = TypeRange(), | ||
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square); | ||
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square, | ||
bool hasSameTypeDynamicValues = false); | ||
inline void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, | ||
OperandRange values, | ||
ArrayRef<int64_t> integers, | ||
|
@@ -123,6 +126,13 @@ inline void printDynamicIndexList( | |
return printDynamicIndexList(printer, op, values, integers, {}, valueTypes, | ||
delimiter); | ||
} | ||
inline void printSameTypeDynamicIndexList( | ||
OpAsmPrinter &printer, Operation *op, OperandRange values, | ||
ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(), | ||
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) { | ||
return printDynamicIndexList(printer, op, values, integers, {}, valueTypes, | ||
delimiter, /*hasSameTypeDynamicValues=*/true); | ||
} | ||
|
||
/// Parser hook for custom directive in assemblyFormat. | ||
/// | ||
|
@@ -150,7 +160,8 @@ ParseResult parseDynamicIndexList( | |
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, | ||
DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals, | ||
SmallVectorImpl<Type> *valueTypes = nullptr, | ||
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square); | ||
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square, | ||
bool hasSameTypeDynamicValues = false); | ||
inline ParseResult | ||
parseDynamicIndexList(OpAsmParser &parser, | ||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, | ||
|
@@ -188,6 +199,16 @@ inline ParseResult parseDynamicIndexList( | |
return parseDynamicIndexList(parser, values, integers, scalableVals, | ||
&valueTypes, delimiter); | ||
} | ||
inline ParseResult parseSameTypeDynamicIndexList( | ||
OpAsmParser &parser, | ||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, | ||
DenseI64ArrayAttr &integers, SmallVectorImpl<Type> &valueTypes, | ||
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) { | ||
DenseBoolArrayAttr scalableVals = {}; | ||
return parseDynamicIndexList(parser, values, integers, scalableVals, | ||
&valueTypes, delimiter, | ||
/*hasSameTypeDynamicValues=*/true); | ||
} | ||
|
||
/// Verify that a the `values` has as many elements as the number of entries in | ||
/// `attr` for which `isDynamic` evaluates to true. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -340,12 +340,16 @@ class AsmParserImpl : public BaseT { | |
/// Parse a list of comma-separated items with an optional delimiter. If a | ||
/// delimiter is provided, then an empty list is allowed. If not, then at | ||
/// least one element will be parsed. | ||
ParseResult parseCommaSeparatedList(Delimiter delimiter, | ||
function_ref<ParseResult()> parseElt, | ||
StringRef contextMessage) override { | ||
return parser.parseCommaSeparatedList(delimiter, parseElt, contextMessage); | ||
ParseResult parseCommaSeparatedList( | ||
Delimiter delimiter, function_ref<ParseResult()> parseElt, | ||
std::optional<function_ref<ParseResult()>> parseSuffix, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you move There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The function is needed to parse the suffix, which is optional. I'm using a |
||
StringRef contextMessage) override { | ||
return parser.parseCommaSeparatedList(delimiter, parseElt, parseSuffix, | ||
contextMessage); | ||
} | ||
|
||
using BaseT::parseCommaSeparatedList; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why are there two overloads, one calling the super implementation and the other one calling |
||
|
||
//===--------------------------------------------------------------------===// | ||
// Keyword Parsing | ||
//===--------------------------------------------------------------------===// | ||
|
@@ -590,6 +594,17 @@ class AsmParserImpl : public BaseT { | |
return parser.parseTypeListNoParens(result); | ||
} | ||
|
||
/// Parse an optional colon followed by a type. | ||
ParseResult parseOptionalColonType(Type &result) override { | ||
SmallVector<Type, 1> types; | ||
ParseResult parseResult = parseOptionalColonTypeList(types); | ||
if (llvm::succeeded(parseResult) && types.size() > 1) | ||
return emitError(getCurrentLocation(), "expected single type"); | ||
if (!types.empty()) | ||
result = types[0]; | ||
return parseResult; | ||
} | ||
|
||
ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions, | ||
bool allowDynamic, | ||
bool withTrailingX) override { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -80,10 +80,10 @@ AsmParserCodeCompleteContext::~AsmParserCodeCompleteContext() = default; | |
/// Parse a list of comma-separated items with an optional delimiter. If a | ||
/// delimiter is provided, then an empty list is allowed. If not, then at | ||
/// least one element will be parsed. | ||
ParseResult | ||
Parser::parseCommaSeparatedList(Delimiter delimiter, | ||
function_ref<ParseResult()> parseElementFn, | ||
StringRef contextMessage) { | ||
ParseResult Parser::parseCommaSeparatedList( | ||
Delimiter delimiter, function_ref<ParseResult()> parseElementFn, | ||
std::optional<function_ref<ParseResult()>> parseSuffixFn, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't function_ref nullable on its own? This is the case for plain |
||
StringRef contextMessage) { | ||
switch (delimiter) { | ||
case Delimiter::None: | ||
break; | ||
|
@@ -144,6 +144,9 @@ Parser::parseCommaSeparatedList(Delimiter delimiter, | |
return failure(); | ||
} | ||
|
||
if (parseSuffixFn && (*parseSuffixFn)()) | ||
return failure(); | ||
|
||
switch (delimiter) { | ||
case Delimiter::None: | ||
return success(); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -114,7 +114,8 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, | |
OperandRange values, | ||
ArrayRef<int64_t> integers, | ||
ArrayRef<bool> scalables, TypeRange valueTypes, | ||
AsmParser::Delimiter delimiter) { | ||
AsmParser::Delimiter delimiter, | ||
bool hasSameTypeDynamicValues) { | ||
char leftDelimiter = getLeftDelimiter(delimiter); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In case we keep the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, can we assert that |
||
char rightDelimiter = getRightDelimiter(delimiter); | ||
printer << leftDelimiter; | ||
|
@@ -130,7 +131,7 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, | |
printer << "["; | ||
if (ShapedType::isDynamic(integer)) { | ||
printer << values[dynamicValIdx]; | ||
if (!valueTypes.empty()) | ||
if (!hasSameTypeDynamicValues && !valueTypes.empty()) | ||
printer << " : " << valueTypes[dynamicValIdx]; | ||
++dynamicValIdx; | ||
} else { | ||
|
@@ -142,14 +143,22 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, | |
scalableIndexIdx++; | ||
}); | ||
|
||
if (hasSameTypeDynamicValues && !valueTypes.empty()) { | ||
assert(std::all_of(valueTypes.begin(), valueTypes.end(), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
[&](Type type) { return type == valueTypes[0]; }) && | ||
"Expected the same value types"); | ||
printer << " : " << valueTypes[0]; | ||
} | ||
|
||
printer << rightDelimiter; | ||
} | ||
|
||
ParseResult mlir::parseDynamicIndexList( | ||
OpAsmParser &parser, | ||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, | ||
DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalables, | ||
SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) { | ||
SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter, | ||
bool hasSameTypeDynamicValues) { | ||
|
||
SmallVector<int64_t, 4> integerVals; | ||
SmallVector<bool, 4> scalableVals; | ||
|
@@ -163,7 +172,8 @@ ParseResult mlir::parseDynamicIndexList( | |
if (res.has_value() && succeeded(res.value())) { | ||
values.push_back(operand); | ||
integerVals.push_back(ShapedType::kDynamic); | ||
if (valueTypes && parser.parseColonType(valueTypes->emplace_back())) | ||
if (!hasSameTypeDynamicValues && valueTypes && | ||
parser.parseColonType(valueTypes->emplace_back())) | ||
return failure(); | ||
} else { | ||
int64_t integer; | ||
|
@@ -178,10 +188,34 @@ ParseResult mlir::parseDynamicIndexList( | |
return failure(); | ||
return success(); | ||
}; | ||
auto parseColonType = [&]() -> ParseResult { | ||
if (hasSameTypeDynamicValues) { | ||
assert(valueTypes && "Expected non-null value types"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is |
||
assert(valueTypes->empty() && "Expected no parsed value types"); | ||
|
||
Type dynValType; | ||
if (parser.parseOptionalColonType(dynValType)) | ||
return failure(); | ||
|
||
if (!dynValType && !values.empty()) | ||
return parser.emitError(parser.getNameLoc()) | ||
<< "expected a type for dynamic indices"; | ||
if (dynValType) { | ||
if (values.empty()) | ||
return parser.emitError(parser.getNameLoc()) | ||
<< "expected no type for constant indices"; | ||
|
||
// Broadcast the single type to all the dynamic values. | ||
valueTypes->append(values.size(), dynValType); | ||
} | ||
} | ||
return success(); | ||
}; | ||
if (parser.parseCommaSeparatedList(delimiter, parseIntegerOrValue, | ||
" in dynamic index list")) | ||
parseColonType, " in dynamic index list")) | ||
return parser.emitError(parser.getNameLoc()) | ||
<< "expected SSA value or integer"; | ||
<< "expected a valid list of SSA values or integers"; | ||
|
||
integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals); | ||
scalables = parser.getBuilder().getDenseBoolArrayAttr(scalableVals); | ||
return success(); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The grammar in this sentence is broken:
is kDynamic
. Can you fix that also?