From 110775809ad114e190132290657a86b2c292a878 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Mon, 11 Jan 2021 20:42:10 +0000 Subject: [PATCH] Revert "[mlir][linalg] Support parsing attributes in named op spec" This reverts commit df86f15f0c53c395dac5a14aba08745bc12b9b9b. The gcc-5 build was broken by this change: mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp:1275:77: required from here /usr/include/c++/5/ext/new_allocator.h:120:4: error: no matching function for call to 'std::pair, {anonymous}::TCParser::RegisteredAttr>::pair(llvm::StringRef&, {anonymous}::TCParser::RegisteredAttr' --- .../test-linalg-ods-gen.tc | 22 --- .../mlir-linalg-ods-gen.cpp | 180 +----------------- 2 files changed, 4 insertions(+), 198 deletions(-) diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc index 1ef12876063778..f81380f02bb382 100644 --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc @@ -72,25 +72,3 @@ ods_def : def test3(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) { C(b, m, n) = std_addf(std_mulf(A(b, m, k), B(k, n))); } - -// Test attribute definitions -// ODS-LABEL: def Test4Op -// ODS: F32ArrayAttr:$array_attr, -// ODS: F32:$f32_attr, -// ODS: RankedF32ElementsAttr<[4]>:$fvec_attr, -// ODS: I32:$i32_attr, -// ODS: RankedI32ElementsAttr<[5, 6]>:$ivec_attr, -// ODS: OptionalAttr:$optional_attr -// -ods_def : -def test4(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) -attr( - f32_attr: f32, - i32_attr: i32, - fvec_attr: 4xf32, - ivec_attr: 5x6xi32, - array_attr : f32[], - optional_attr? : f32 -) { - C(b, m, n) = std_addf(std_mulf(A(b, m, k), B(k, n))); -} diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp index e7ab5edc1118b6..592e6cb774fbf6 100644 --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -20,17 +20,11 @@ #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/SetVector.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/ADT/StringSwitch.h" -#include "llvm/ADT/Twine.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/ToolOutputFile.h" -#include - #define DEBUG_TYPE "linalg-ods-gen" static llvm::cl::OptionCategory ODSGenCat("Linalg ODS Gen"); @@ -85,14 +79,11 @@ class Token { gt, l_brace, l_paren, - l_square, lt, minus, plus, - question, r_brace, r_paren, - r_square, semicolon, star, @@ -100,7 +91,6 @@ class Token { kw_def, FIRST_KEYWORD = kw_def, kw_ods_def, - kw_attr_def, kw_floordiv, kw_ceildiv, kw_mod, @@ -161,10 +151,6 @@ class Lexer { Token emitError(llvm::SMLoc loc, const Twine &msg); Token emitError(const char *loc, const Twine &msg); - /// Change the position of the lexer cursor. The next token we lex will start - /// at the designated point in the input. - void resetPointer(const char *newPtr) { curPtr = newPtr; } - private: Token formToken(Token::Kind kind, const char *tokStart) { return Token(kind, StringRef(tokStart, curPtr - tokStart)); @@ -261,14 +247,10 @@ Token Lexer::lexToken() { return formToken(Token::Kind::l_brace, tokStart); case '(': return formToken(Token::Kind::l_paren, tokStart); - case '[': - return formToken(Token::Kind::l_square, tokStart); case '}': return formToken(Token::Kind::r_brace, tokStart); case ')': return formToken(Token::Kind::r_paren, tokStart); - case ']': - return formToken(Token::Kind::r_square, tokStart); case '<': return formToken(Token::Kind::lt, tokStart); case '>': @@ -281,8 +263,6 @@ Token Lexer::lexToken() { return formToken(Token::Kind::semicolon, tokStart); case '*': return formToken(Token::Kind::star, tokStart); - case '?': - return formToken(Token::Kind::question, tokStart); case '/': if (*curPtr == '/') { skipComment(); @@ -309,7 +289,6 @@ Token Lexer::lexIdentifier(const char *tokStart) { // Check to see if this identifier is a keyword. StringRef str(tokStart, curPtr - tokStart); Token::Kind kind = StringSwitch(str) - .Case("attr", Token::Kind::kw_attr_def) .Case("def", Token::Kind::kw_def) .Case("ods_def", Token::Kind::kw_ods_def) .Case("floordiv", Token::Kind::kw_floordiv) @@ -373,40 +352,29 @@ class Parser { "shouldn't advance past EOF or errors"); curToken = lexer.lexToken(); } - void consumeToken(Token::Kind kind) { assert(curToken.getKind() == kind && "unexpected token"); curToken = lexer.lexToken(); } - LogicalResult parseToken(Token::Kind kind, const Twine &msg) { if (curToken.getKind() != kind) return emitError(curToken.getLoc(), msg); consumeToken(); return success(); } - - /// Parses an optional token and returns failure if failed to parse. - LogicalResult parseOptionalToken(Token::Kind kind) { - return success(consumeIf(kind)); - } - LogicalResult emitError(llvm::SMLoc loc, const Twine &msg) { lexer.emitError(loc, msg); return failure(); } - LogicalResult emitError(const Twine &msg) { return emitError(curToken.getLoc(), msg); } - bool consumeIf(Token::Kind kind) { if (curToken.isNot(kind)) return false; consumeToken(kind); return true; } - LogicalResult parseCommaSeparatedList(llvm::function_ref parseElement) { // Non-empty case starts with an element. @@ -420,7 +388,6 @@ class Parser { } return success(); } - LogicalResult parseCommaSeparatedListUntil(Token::Kind rightToken, llvm::function_ref parseElement, @@ -994,8 +961,6 @@ class TCParser { LogicalResult parseTensorUse(TensorUse &result, ComprehensionParsingState &state); - LogicalResult parseAttrDef(); - /// Parses a tensor expression. LogicalResult parseExpression(TensorUse currentDefinition, std::unique_ptr &result, @@ -1045,29 +1010,15 @@ class TCParser { unsigned index; }; - //===--------------------------------------------------------------------===// - // Internal bookkeeping of attributes. - //===--------------------------------------------------------------------===// - struct RegisteredAttr { - StringRef elementType; - SmallVector vectorDims; - bool isArray; - bool isOptional; - }; - //===--------------------------------------------------------------------===// // Per-TC def state. //===--------------------------------------------------------------------===// /// Symbols are per TC def. AffineSymbolList symbols; - /// Tensors are per TC def. llvm::StringMap registeredTensors; unsigned nextRegisteredTensorIndex; - /// Attributes are per TC def. - std::map registeredAttrs; - Parser &parser; }; } // namespace @@ -1219,72 +1170,6 @@ LogicalResult TCParser::parseTensorUse(TensorUse &result, return success(); } -/// Parse the information for an attribute def of the form: -/// -/// affine-expr-list ::= affine-expr (`,` affine-expr )* -/// attr-id ::= bare-id (`?`)? -/// dim-list ::= (integer-literal 'x')+ -/// attr-typedef ::= dim-list? type (`[` `]`)? -/// attr-def ::= attr-id `:` attr-typedef -LogicalResult TCParser::parseAttrDef() { - auto attrLoc = parser.curToken.getLoc(); - StringRef attrName = parser.curToken.getSpelling(); - if (failed(parser.parseToken(Token::Kind::id, "expected an id"))) - return failure(); - bool isOptional = succeeded(parser.parseOptionalToken(Token::Kind::question)); - if (failed(parser.parseToken(Token::Kind::colon, "expected colon"))) - return failure(); - - // Parse the attribute's type. We don't expect the type to be arbitrary - // complex, so just use this ad-hoc handling here. - - // Parse potential dimension list - SmallVector vectorDims; - while (parser.curToken.is(Token::Kind::integer)) { - vectorDims.push_back(parser.curToken.getUInt64IntegerValue().getValue()); - parser.consumeToken(); - - StringRef spelling = parser.curToken.getSpelling(); - if (spelling[0] != 'x') - return parser.emitError(parser.curToken.getLoc(), - "expected 'x' in dimension list"); - - // If we had a prefix of 'x', lex the next token immediately after the 'x'. - if (spelling.size() != 1) - parser.lexer.resetPointer(spelling.data() + 1); - - parser.consumeToken(); - } - - StringRef elementType = parser.curToken.getSpelling(); - if (failed(parser.parseToken(Token::Kind::id, "expected an id"))) - return failure(); - - bool isArray = false; - auto arrayLoc = parser.curToken.getLoc(); - if (succeeded(parser.parseOptionalToken(Token::Kind::l_square))) { - isArray = true; - if (failed(parser.parseToken(Token::Kind::r_square, "expected ']'"))) - return failure(); - } - - if (!vectorDims.empty() && isArray) - return parser.emitError(arrayLoc, "unsupported vector array attribute"); - - auto iterBoolPair = registeredAttrs.emplace( - attrName, RegisteredAttr{elementType, vectorDims, isArray, isOptional}); - if (!iterBoolPair.second) - return parser.emitError(attrLoc, - "Failed to register attribute '" + attrName + "'"); - - LLVM_DEBUG(llvm::dbgs() << "Recorded: " << (isOptional ? "[optional]" : "") - << " " << attrName << " " - << "with type: " << elementType - << (isArray ? "[]" : "") << "\n"); - - return success(); -} - /// Parses a tensor expression of the form: /// /// op-spec ::= bare-id `<` reduction-dims-list `>` @@ -1456,13 +1341,10 @@ TCParser::parseOneComprehension(StringRef cppOpName, StringRef linalgOpName, /// Parse and print the information for a ODS def. /// /// tensor-def-list ::= tensor-def (`,` tensor-def )* -/// attr-def-list ::= attr-def (`,` attr-def )* /// /// comprehension-list ::= comprehension comprehension* /// -/// tc-attr-def ::= `attr` `(` attr-def-list `)` /// tc-def ::= `def` bare-id `(`tensor-def-list`)` `->` `(` tensor-def-list`)` -/// (tc-attr-def)? /// `{` comprehension-list `}` /// /// ods-def ::= `ods_def` `<` bare-id `>` `:` tc-def @@ -1471,7 +1353,6 @@ TCParser::parseOneComprehension(StringRef cppOpName, StringRef linalgOpName, /// contain only expressions involving symbols and constants), but can /// otherwise contain arbitrary affine expressions. LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) { - // Parse def header (including C++ op name) if (failed(parser.parseToken(Token::Kind::kw_ods_def, "expected 'ods_def' to define a TC ODS")) || failed(parser.parseToken(Token::Kind::lt, "expected '<'"))) @@ -1483,15 +1364,12 @@ LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) { failed(parser.parseToken(Token::Kind::gt, "expected '>'")) || failed(parser.parseToken(Token::Kind::colon, "expected ':'"))) return failure(); - if (failed(parser.parseToken(Token::Kind::kw_def, "expected 'def' to define a TC"))) return failure(); StringRef tcName = parser.curToken.getSpelling(); LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing TC: " << tcName << "\n"); - - // Parse input/output tensor definitions if (failed(parser.parseToken(Token::Kind::id, "expected id")) || failed(parser.parseToken(Token::Kind::l_paren, "expected '('"))) return failure(); @@ -1514,16 +1392,6 @@ LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) { Token::Kind::r_paren, parseOutputDef, /*allowEmptyList=*/false))) return failure(); - // Parse optional attribute definitions - if (succeeded(parser.parseOptionalToken(Token::Kind::kw_attr_def))) { - if (failed(parser.parseToken(Token::Kind::l_paren, "expected '('"))) - return failure(); - if (failed(parser.parseCommaSeparatedListUntil( - Token::Kind::r_paren, std::bind(&TCParser::parseAttrDef, this), - /*allowEmptyList=*/false))) - return failure(); - } - // Since we don't declare symbols separately, we discover them eagerly: each // newly encountered id in a tensor shape expression is treated as a new // symbolic. At this point, all tensors have been parsed and all the symbols @@ -1582,52 +1450,12 @@ LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) { void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName, StringRef linalgOpName, ComprehensionParsingState &state) { - SmallVector attributes; - for (const auto &attr : registeredAttrs) { - llvm::StringRef name = attr.first; - - llvm::StringRef elementType = attr.second.elementType; - std::string odsType = llvm::StringSwitch(elementType) - .Case("f32", "F32") - .Case("i32", "I32") - .Default(""); - if (odsType.empty()) { - parser.emitError("unimplemented support for attribute element type: " + - elementType); - return; - } - - const auto &dims = attr.second.vectorDims; - if (!dims.empty()) { - SmallVector dimStrs; - for (uint64_t dim : dims) - dimStrs.push_back(std::to_string(dim)); - odsType = llvm::formatv("Ranked{0}ElementsAttr<[{1}]>", odsType, - llvm::join(dimStrs, ", ")); - } - - assert(dims.empty() || !attr.second.isArray); - if (attr.second.isArray) - odsType = llvm::formatv("{0}ArrayAttr", odsType); - - if (attr.second.isOptional) - odsType = llvm::formatv("OptionalAttr<{0}>", odsType); - - attributes.push_back(llvm::formatv("{0}:${1}", odsType, name)); - } - - std::string attrList = llvm::join(attributes, ",\n"); - if (!attrList.empty()) - attrList = ",\n" + attrList; - const char *header = R"FMT( def {0} : LinalgStructuredBase_Op<"{1}", [ AttrSizedOperandSegments, DeclareOpInterfaceMethods, SingleBlockImplicitTerminator<"YieldOp">]> { - let arguments = (ins - Variadic:$inputs, - Variadic:$outputs{4} - ); + let arguments = (ins Variadic:$inputs, + Variadic:$outputs); let results = (outs Variadic:$result_tensors); let regions = (region AnyRegion:$region); @@ -1687,7 +1515,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName, static std::function getRegionBuilder() {{ return regionBuilder; } // Generic methods. - static unsigned getNumRegionArgs() {{ return {5}; } + static unsigned getNumRegionArgs() {{ return {4}; } std::string getLibraryCallName() {{ return generateLibraryCallName(getOperation()); } @@ -1703,7 +1531,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName, } os << llvm::formatv(header, cppOpName, linalgOpName, nInputs, nOutputs, - attrList, state.orderedTensorArgs.size()); + state.orderedTensorArgs.size()); } /// Print the C++ StructuredOpsInterface impl of `iterator_types`.