From 8bb9d508e3dd9132648db223e19c80609787e822 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Fri, 9 Aug 2019 07:33:34 -0700 Subject: [PATCH] External library name mangling support for linalg. This CL introduces the ability to generate the external library name for Linalg operations. The problem is that neither mlir or C support overloading and we want a simplified form of name mangling that is still reasonable to read. This CL creates the name of the external call that Linalg expects from the operation name and the type of its arguments. The interface library names are updated and use new cases are added for FillOp. PiperOrigin-RevId: 262556833 --- .../mlir/Linalg/IR/LinalgLibraryOps.td | 19 +++++------ .../mlir/include/mlir/Linalg/IR/LinalgOps.h | 22 ++++++++++++ third_party/mlir/lib/Linalg/IR/LinalgOps.cpp | 33 ++++++++++++++++++ .../Linalg/Transforms/LowerToLLVMDialect.cpp | 34 +++++++++---------- 4 files changed, 81 insertions(+), 27 deletions(-) diff --git a/third_party/mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td b/third_party/mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td index 547a2c4cf81..998d68ba806 100644 --- a/third_party/mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td +++ b/third_party/mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td @@ -80,10 +80,9 @@ class LinalgLibraryBase_Op props> class LinalgLibrary_Op props> : LinalgLibraryBase_Op { - - code classDeclaration = [{ - StringRef getLibraryCallName() { - return "linalg_}] # mnemonic # [{"; + code libraryCallName = [{ + std::string getLibraryCallName() { + return generateLibraryCallName(getOperation()); } }]; } @@ -138,7 +137,7 @@ def CopyOp : LinalgLibrary_Op<"copy", [NInputsAndOutputs<1, 1>]> { return build( builder, result, input, output, AffineMapAttr(), AffineMapAttr()); }]>]; - let extraClassDeclaration = classDeclaration # [{ + let extraClassDeclaration = libraryCallName # [{ unsigned getNumParallelLoops() { auto *view = *(getOperands().begin()); return view->getType().cast().getRank(); @@ -151,7 +150,7 @@ def CopyOp : LinalgLibrary_Op<"copy", [NInputsAndOutputs<1, 1>]> { def FillOp : LinalgLibrary_Op<"fill", [NInputsAndOutputs<0, 1>]> { let arguments = (ins View, AnyTypeOf<[AnyFloat, AnyInteger, AnyVector]>); - let extraClassDeclaration = classDeclaration # [{ + let extraClassDeclaration = libraryCallName # [{ unsigned getNumParallelLoops() { auto *view = *(getOperands().begin()); return view->getType().cast().getRank(); @@ -170,7 +169,7 @@ def DotOp : LinalgLibrary_Op<"dot", NLoopTypes<0, 1, 0>, ViewRanks<[1, 1, 0]>]> { let arguments = (ins View, View, View); - let extraClassDeclaration = classDeclaration; + let extraClassDeclaration = libraryCallName; } def MatvecOp : LinalgLibrary_Op<"matvec", @@ -178,7 +177,7 @@ def MatvecOp : LinalgLibrary_Op<"matvec", NLoopTypes<1, 1, 0>, ViewRanks<[2, 1, 1]>]> { let arguments = (ins View, View, View); - let extraClassDeclaration = classDeclaration; + let extraClassDeclaration = libraryCallName; } def MatmulOp : LinalgLibrary_Op<"matmul", @@ -186,7 +185,7 @@ def MatmulOp : LinalgLibrary_Op<"matmul", NLoopTypes<2, 1, 0>, ViewRanks<[2, 2, 2]>]> { let arguments = (ins View, View, View); - let extraClassDeclaration = classDeclaration; + let extraClassDeclaration = libraryCallName; } def ConvOp : LinalgLibrary_Op<"conv", [NInputsAndOutputs<2, 1>]> { @@ -211,7 +210,7 @@ def ConvOp : LinalgLibrary_Op<"conv", [NInputsAndOutputs<2, 1>]> { let arguments = (ins View:$filter, View:$input, View:$output, OptionalAttr:$strides, OptionalAttr:$dilations); - let extraClassDeclaration = classDeclaration # [{ + let extraClassDeclaration = libraryCallName # [{ // TODO(ntv) extend to support more than 1 dimensions and potentially // grouping too. unsigned getNumBatchDimensions() { return 1; } diff --git a/third_party/mlir/include/mlir/Linalg/IR/LinalgOps.h b/third_party/mlir/include/mlir/Linalg/IR/LinalgOps.h index 4085d066324..3187f4f80ef 100644 --- a/third_party/mlir/include/mlir/Linalg/IR/LinalgOps.h +++ b/third_party/mlir/include/mlir/Linalg/IR/LinalgOps.h @@ -186,6 +186,28 @@ class StoreOp } }; +/// Returns the name mangled library call name to disambiguate between different +/// overloads at the C level. The name mangling scheme is basic and uses MLIR +/// type names: +/// 1. form a string which is the concatenation of the linalg op name with all +/// the operand type names, separate by underscores; +/// 2. drop the `linalg.` prefix, and the `<`, `>`, `?` symbols from the type. +/// Assumes `op` is a LinalgOp. +/// +/// Examples: +/// +/// 1. linalg.fill(%A, %f) : !linalg.view, f32 +/// name mangles into `linalg_fill_viewf32_f32_impl` +/// +/// 2. linalg.dot(%A, %B, %C) : +/// !linalg.view, !linalg.view, !linalg.view +/// name mangles into `linalg_dot_viewxf32_viewxf32_viewf32_impl` +/// +/// 3. linalg.matmul(...) : +/// !linalg.view, !linalg.view, !linalg.view +/// name mangles into `linalg_matmul_viewxxf32_viewxxf32_viewxxf32_impl` +std::string generateLibraryCallName(Operation *op); + #define GET_OP_CLASSES #include "mlir/Linalg/IR/LinalgOps.h.inc" diff --git a/third_party/mlir/lib/Linalg/IR/LinalgOps.cpp b/third_party/mlir/lib/Linalg/IR/LinalgOps.cpp index 6549508fc09..bce2b32be77 100644 --- a/third_party/mlir/lib/Linalg/IR/LinalgOps.cpp +++ b/third_party/mlir/lib/Linalg/IR/LinalgOps.cpp @@ -37,6 +37,7 @@ #include "mlir/Transforms/FoldUtils.h" #include "llvm/ADT/StringSet.h" +#include "llvm/Support/raw_ostream.h" using namespace mlir; using namespace mlir::edsc; @@ -1085,3 +1086,35 @@ SmallVector mlir::linalg::loopToOperandRangesMaps(Operation *op) { } llvm_unreachable("Missing loopToOperandRangesMaps for op"); } + +static void appendMangledType(llvm::raw_string_ostream &ss, Type t) { + if (auto view = t.dyn_cast()) { + ss << "view"; + for (unsigned i = 0, e = view.getRank(); i < e; ++i) + ss << "x"; + appendMangledType(ss, view.getElementType()); + } else if (auto vec = t.dyn_cast()) { + ss << "vector"; + interleave( + vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; }); + appendMangledType(ss, vec.getElementType()); + } else if (t.isIntOrIndexOrFloat()) { + ss << t; + } else { + llvm_unreachable("Invalid type for linalg library name mangling"); + } +} + +std::string mlir::linalg::generateLibraryCallName(Operation *op) { + assert(isa(op)); + std::string name(op->getName().getStringRef().str()); + name.reserve(128); + std::replace(name.begin(), name.end(), '.', '_'); + llvm::raw_string_ostream ss(name); + ss << "_"; + auto types = op->getOperandTypes(); + interleave( + types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); }, + [&]() { ss << "_"; }); + return ss.str(); +} diff --git a/third_party/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/third_party/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index 6967a9dd331..a45f943bea8 100644 --- a/third_party/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/third_party/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -39,6 +39,8 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/LowerAffine.h" #include "mlir/Transforms/Passes.h" + +#include "llvm/ADT/SetVector.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" @@ -545,7 +547,7 @@ static FuncOp getLLVMLibraryCallImplDefinition(FuncOp libFn) { } SmallVector fnArgTypes; for (auto t : libFn.getType().getInputs()) { - assert(t.isa() && + assert(t && t.isa() && "Expected LLVM Type for argument while generating library Call " "Implementation Definition"); fnArgTypes.push_back(t.cast().getPointerTo()); @@ -577,12 +579,8 @@ getLLVMLibraryCallDeclaration(Operation *op, LLVMTypeConverter &lowering, // Get the Function type consistent with LLVM Lowering. SmallVector inputTypes; - for (auto operand : op->getOperands()) { - // TODO(ravishankarm): convertLinalgType handles only a subset of Linalg - // types. Handle other types (as well as non-Linalg types) either here or in - // convertLinalgType. - inputTypes.push_back(convertLinalgType(operand->getType(), lowering)); - } + for (auto operand : op->getOperands()) + inputTypes.push_back(lowering.convertType(operand->getType())); assert(op->getNumResults() == 0 && "Library call for linalg operation can be generated only for ops that " "have void return types"); @@ -632,15 +630,15 @@ class LinalgTypeConverter : public LLVMTypeConverter { return convertLinalgType(t, *this); } - void addLibraryFnDeclaration(FuncOp fn) { - libraryFnDeclarations.push_back(fn); - } + void addLibraryFnDeclaration(FuncOp fn) { libraryFnDeclarations.insert(fn); } - ArrayRef getLibraryFnDeclarations() { return libraryFnDeclarations; } + ArrayRef getLibraryFnDeclarations() { + return libraryFnDeclarations.getArrayRef(); + } private: /// List of library functions declarations needed during dialect conversion - SmallVector libraryFnDeclarations; + llvm::SetVector libraryFnDeclarations; }; } // end anonymous namespace @@ -676,11 +674,13 @@ static void populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter, OwningRewritePatternList &patterns, MLIRContext *ctx) { - patterns.insert, LinalgOpConversion, - LoadOpConversion, RangeOpConversion, SliceOpConversion, - StoreOpConversion, ViewOpConversion>(ctx, converter); + patterns + .insert, LinalgOpConversion, + LinalgOpConversion, LoadOpConversion, RangeOpConversion, + SliceOpConversion, StoreOpConversion, ViewOpConversion>( + ctx, converter); } namespace {