Skip to content

Commit

Permalink
Revert "[mlir][linalg] Support parsing attributes in named op spec"
Browse files Browse the repository at this point in the history
This reverts commit df86f15.

The gcc-5 build was broken by this change:

  mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp:1275:77:   required from here
  /usr/include/c++/5/ext/new_allocator.h:120:4: error: no matching function for call to 'std::pair<const std::__cxx11::basic_string<char>, {anonymous}::TCParser::RegisteredAttr>::pair(llvm::StringRef&, {anonymous}::TCParser::RegisteredAttr'
  • Loading branch information
joker-eph committed Jan 11, 2021
1 parent cceb1bf commit 1107758
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 198 deletions.
22 changes: 0 additions & 22 deletions mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
Original file line number Diff line number Diff line change
Expand Up @@ -72,25 +72,3 @@ ods_def<Test3Op> :
def test3(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) {
C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(k, n)));
}

// Test attribute definitions
// ODS-LABEL: def Test4Op
// ODS: F32ArrayAttr:$array_attr,
// ODS: F32:$f32_attr,
// ODS: RankedF32ElementsAttr<[4]>:$fvec_attr,
// ODS: I32:$i32_attr,
// ODS: RankedI32ElementsAttr<[5, 6]>:$ivec_attr,
// ODS: OptionalAttr<F32>:$optional_attr
//
ods_def<Test4Op> :
def test4(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N))
attr(
f32_attr: f32,
i32_attr: i32,
fvec_attr: 4xf32,
ivec_attr: 5x6xi32,
array_attr : f32[],
optional_attr? : f32
) {
C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(k, n)));
}
180 changes: 4 additions & 176 deletions mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,11 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/ToolOutputFile.h"

#include <map>

#define DEBUG_TYPE "linalg-ods-gen"

static llvm::cl::OptionCategory ODSGenCat("Linalg ODS Gen");
Expand Down Expand Up @@ -85,22 +79,18 @@ class Token {
gt,
l_brace,
l_paren,
l_square,
lt,
minus,
plus,
question,
r_brace,
r_paren,
r_square,
semicolon,
star,

// Keywords.
kw_def,
FIRST_KEYWORD = kw_def,
kw_ods_def,
kw_attr_def,
kw_floordiv,
kw_ceildiv,
kw_mod,
Expand Down Expand Up @@ -161,10 +151,6 @@ class Lexer {
Token emitError(llvm::SMLoc loc, const Twine &msg);
Token emitError(const char *loc, const Twine &msg);

/// Change the position of the lexer cursor. The next token we lex will start
/// at the designated point in the input.
void resetPointer(const char *newPtr) { curPtr = newPtr; }

private:
Token formToken(Token::Kind kind, const char *tokStart) {
return Token(kind, StringRef(tokStart, curPtr - tokStart));
Expand Down Expand Up @@ -261,14 +247,10 @@ Token Lexer::lexToken() {
return formToken(Token::Kind::l_brace, tokStart);
case '(':
return formToken(Token::Kind::l_paren, tokStart);
case '[':
return formToken(Token::Kind::l_square, tokStart);
case '}':
return formToken(Token::Kind::r_brace, tokStart);
case ')':
return formToken(Token::Kind::r_paren, tokStart);
case ']':
return formToken(Token::Kind::r_square, tokStart);
case '<':
return formToken(Token::Kind::lt, tokStart);
case '>':
Expand All @@ -281,8 +263,6 @@ Token Lexer::lexToken() {
return formToken(Token::Kind::semicolon, tokStart);
case '*':
return formToken(Token::Kind::star, tokStart);
case '?':
return formToken(Token::Kind::question, tokStart);
case '/':
if (*curPtr == '/') {
skipComment();
Expand All @@ -309,7 +289,6 @@ Token Lexer::lexIdentifier(const char *tokStart) {
// Check to see if this identifier is a keyword.
StringRef str(tokStart, curPtr - tokStart);
Token::Kind kind = StringSwitch<Token::Kind>(str)
.Case("attr", Token::Kind::kw_attr_def)
.Case("def", Token::Kind::kw_def)
.Case("ods_def", Token::Kind::kw_ods_def)
.Case("floordiv", Token::Kind::kw_floordiv)
Expand Down Expand Up @@ -373,40 +352,29 @@ class Parser {
"shouldn't advance past EOF or errors");
curToken = lexer.lexToken();
}

void consumeToken(Token::Kind kind) {
assert(curToken.getKind() == kind && "unexpected token");
curToken = lexer.lexToken();
}

LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
if (curToken.getKind() != kind)
return emitError(curToken.getLoc(), msg);
consumeToken();
return success();
}

/// Parses an optional token and returns failure if failed to parse.
LogicalResult parseOptionalToken(Token::Kind kind) {
return success(consumeIf(kind));
}

LogicalResult emitError(llvm::SMLoc loc, const Twine &msg) {
lexer.emitError(loc, msg);
return failure();
}

LogicalResult emitError(const Twine &msg) {
return emitError(curToken.getLoc(), msg);
}

bool consumeIf(Token::Kind kind) {
if (curToken.isNot(kind))
return false;
consumeToken(kind);
return true;
}

LogicalResult
parseCommaSeparatedList(llvm::function_ref<ParseResult()> parseElement) {
// Non-empty case starts with an element.
Expand All @@ -420,7 +388,6 @@ class Parser {
}
return success();
}

LogicalResult
parseCommaSeparatedListUntil(Token::Kind rightToken,
llvm::function_ref<ParseResult()> parseElement,
Expand Down Expand Up @@ -994,8 +961,6 @@ class TCParser {
LogicalResult parseTensorUse(TensorUse &result,
ComprehensionParsingState &state);

LogicalResult parseAttrDef();

/// Parses a tensor expression.
LogicalResult parseExpression(TensorUse currentDefinition,
std::unique_ptr<Expression> &result,
Expand Down Expand Up @@ -1045,29 +1010,15 @@ class TCParser {
unsigned index;
};

//===--------------------------------------------------------------------===//
// Internal bookkeeping of attributes.
//===--------------------------------------------------------------------===//
struct RegisteredAttr {
StringRef elementType;
SmallVector<uint64_t, 4> vectorDims;
bool isArray;
bool isOptional;
};

//===--------------------------------------------------------------------===//
// Per-TC def state.
//===--------------------------------------------------------------------===//
/// Symbols are per TC def.
AffineSymbolList symbols;

/// Tensors are per TC def.
llvm::StringMap<RegisteredTensor> registeredTensors;
unsigned nextRegisteredTensorIndex;

/// Attributes are per TC def.
std::map<std::string, RegisteredAttr> registeredAttrs;

Parser &parser;
};
} // namespace
Expand Down Expand Up @@ -1219,72 +1170,6 @@ LogicalResult TCParser::parseTensorUse(TensorUse &result,
return success();
}

/// Parse the information for an attribute def of the form:
///
/// affine-expr-list ::= affine-expr (`,` affine-expr )*
/// attr-id ::= bare-id (`?`)?
/// dim-list ::= (integer-literal 'x')+
/// attr-typedef ::= dim-list? type (`[` `]`)?
/// attr-def ::= attr-id `:` attr-typedef
LogicalResult TCParser::parseAttrDef() {
auto attrLoc = parser.curToken.getLoc();
StringRef attrName = parser.curToken.getSpelling();
if (failed(parser.parseToken(Token::Kind::id, "expected an id")))
return failure();
bool isOptional = succeeded(parser.parseOptionalToken(Token::Kind::question));
if (failed(parser.parseToken(Token::Kind::colon, "expected colon")))
return failure();

// Parse the attribute's type. We don't expect the type to be arbitrary
// complex, so just use this ad-hoc handling here.

// Parse potential dimension list
SmallVector<uint64_t, 4> vectorDims;
while (parser.curToken.is(Token::Kind::integer)) {
vectorDims.push_back(parser.curToken.getUInt64IntegerValue().getValue());
parser.consumeToken();

StringRef spelling = parser.curToken.getSpelling();
if (spelling[0] != 'x')
return parser.emitError(parser.curToken.getLoc(),
"expected 'x' in dimension list");

// If we had a prefix of 'x', lex the next token immediately after the 'x'.
if (spelling.size() != 1)
parser.lexer.resetPointer(spelling.data() + 1);

parser.consumeToken();
}

StringRef elementType = parser.curToken.getSpelling();
if (failed(parser.parseToken(Token::Kind::id, "expected an id")))
return failure();

bool isArray = false;
auto arrayLoc = parser.curToken.getLoc();
if (succeeded(parser.parseOptionalToken(Token::Kind::l_square))) {
isArray = true;
if (failed(parser.parseToken(Token::Kind::r_square, "expected ']'")))
return failure();
}

if (!vectorDims.empty() && isArray)
return parser.emitError(arrayLoc, "unsupported vector array attribute");

auto iterBoolPair = registeredAttrs.emplace(
attrName, RegisteredAttr{elementType, vectorDims, isArray, isOptional});
if (!iterBoolPair.second)
return parser.emitError(attrLoc,
"Failed to register attribute '" + attrName + "'");

LLVM_DEBUG(llvm::dbgs() << "Recorded: " << (isOptional ? "[optional]" : "")
<< " " << attrName << " "
<< "with type: " << elementType
<< (isArray ? "[]" : "") << "\n");

return success();
}

/// Parses a tensor expression of the form:
///
/// op-spec ::= bare-id `<` reduction-dims-list `>`
Expand Down Expand Up @@ -1456,13 +1341,10 @@ TCParser::parseOneComprehension(StringRef cppOpName, StringRef linalgOpName,
/// Parse and print the information for a ODS def.
///
/// tensor-def-list ::= tensor-def (`,` tensor-def )*
/// attr-def-list ::= attr-def (`,` attr-def )*
///
/// comprehension-list ::= comprehension comprehension*
///
/// tc-attr-def ::= `attr` `(` attr-def-list `)`
/// tc-def ::= `def` bare-id `(`tensor-def-list`)` `->` `(` tensor-def-list`)`
/// (tc-attr-def)?
/// `{` comprehension-list `}`
///
/// ods-def ::= `ods_def` `<` bare-id `>` `:` tc-def
Expand All @@ -1471,7 +1353,6 @@ TCParser::parseOneComprehension(StringRef cppOpName, StringRef linalgOpName,
/// contain only expressions involving symbols and constants), but can
/// otherwise contain arbitrary affine expressions.
LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
// Parse def header (including C++ op name)
if (failed(parser.parseToken(Token::Kind::kw_ods_def,
"expected 'ods_def' to define a TC ODS")) ||
failed(parser.parseToken(Token::Kind::lt, "expected '<'")))
Expand All @@ -1483,15 +1364,12 @@ LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
failed(parser.parseToken(Token::Kind::gt, "expected '>'")) ||
failed(parser.parseToken(Token::Kind::colon, "expected ':'")))
return failure();

if (failed(parser.parseToken(Token::Kind::kw_def,
"expected 'def' to define a TC")))
return failure();

StringRef tcName = parser.curToken.getSpelling();
LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing TC: " << tcName << "\n");

// Parse input/output tensor definitions
if (failed(parser.parseToken(Token::Kind::id, "expected id")) ||
failed(parser.parseToken(Token::Kind::l_paren, "expected '('")))
return failure();
Expand All @@ -1514,16 +1392,6 @@ LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
Token::Kind::r_paren, parseOutputDef, /*allowEmptyList=*/false)))
return failure();

// Parse optional attribute definitions
if (succeeded(parser.parseOptionalToken(Token::Kind::kw_attr_def))) {
if (failed(parser.parseToken(Token::Kind::l_paren, "expected '('")))
return failure();
if (failed(parser.parseCommaSeparatedListUntil(
Token::Kind::r_paren, std::bind(&TCParser::parseAttrDef, this),
/*allowEmptyList=*/false)))
return failure();
}

// Since we don't declare symbols separately, we discover them eagerly: each
// newly encountered id in a tensor shape expression is treated as a new
// symbolic. At this point, all tensors have been parsed and all the symbols
Expand Down Expand Up @@ -1582,52 +1450,12 @@ LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
StringRef linalgOpName,
ComprehensionParsingState &state) {
SmallVector<std::string, 4> attributes;
for (const auto &attr : registeredAttrs) {
llvm::StringRef name = attr.first;

llvm::StringRef elementType = attr.second.elementType;
std::string odsType = llvm::StringSwitch<std::string>(elementType)
.Case("f32", "F32")
.Case("i32", "I32")
.Default("");
if (odsType.empty()) {
parser.emitError("unimplemented support for attribute element type: " +
elementType);
return;
}

const auto &dims = attr.second.vectorDims;
if (!dims.empty()) {
SmallVector<std::string, 4> dimStrs;
for (uint64_t dim : dims)
dimStrs.push_back(std::to_string(dim));
odsType = llvm::formatv("Ranked{0}ElementsAttr<[{1}]>", odsType,
llvm::join(dimStrs, ", "));
}

assert(dims.empty() || !attr.second.isArray);
if (attr.second.isArray)
odsType = llvm::formatv("{0}ArrayAttr", odsType);

if (attr.second.isOptional)
odsType = llvm::formatv("OptionalAttr<{0}>", odsType);

attributes.push_back(llvm::formatv("{0}:${1}", odsType, name));
}

std::string attrList = llvm::join(attributes, ",\n");
if (!attrList.empty())
attrList = ",\n" + attrList;

const char *header = R"FMT( def {0} : LinalgStructuredBase_Op<"{1}", [
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
SingleBlockImplicitTerminator<"YieldOp">]> {
let arguments = (ins
Variadic<AnyShaped>:$inputs,
Variadic<AnyShaped>:$outputs{4}
);
let arguments = (ins Variadic<AnyShaped>:$inputs,
Variadic<AnyShaped>:$outputs);
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
let regions = (region AnyRegion:$region);
Expand Down Expand Up @@ -1687,7 +1515,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
static std::function<void(Block &)> getRegionBuilder() {{ return regionBuilder; }
// Generic methods.
static unsigned getNumRegionArgs() {{ return {5}; }
static unsigned getNumRegionArgs() {{ return {4}; }
std::string getLibraryCallName() {{
return generateLibraryCallName(getOperation());
}
Expand All @@ -1703,7 +1531,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
}

os << llvm::formatv(header, cppOpName, linalgOpName, nInputs, nOutputs,
attrList, state.orderedTensorArgs.size());
state.orderedTensorArgs.size());
}

/// Print the C++ StructuredOpsInterface impl of `iterator_types`.
Expand Down

0 comments on commit 1107758

Please sign in to comment.