Skip to content

Commit

Permalink
External library name mangling support for linalg.
Browse files Browse the repository at this point in the history
This CL introduces the ability to generate the external library name for Linalg operations.
The problem is that neither mlir or C support overloading and we want a simplified form of name mangling that is still reasonable to read.
This CL creates the name of the external call that Linalg expects from the operation name and the type of its arguments.

The interface library names are updated and use new cases are added for FillOp.

PiperOrigin-RevId: 262556833
  • Loading branch information
Nicolas Vasilache authored and tensorflower-gardener committed Aug 9, 2019
1 parent a475e3e commit 8bb9d50
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 27 deletions.
19 changes: 9 additions & 10 deletions third_party/mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,9 @@ class LinalgLibraryBase_Op<string mnemonic, list<OpTrait> props>

class LinalgLibrary_Op<string mnemonic, list<OpTrait> props>
: LinalgLibraryBase_Op<mnemonic, props> {

code classDeclaration = [{
StringRef getLibraryCallName() {
return "linalg_}] # mnemonic # [{";
code libraryCallName = [{
std::string getLibraryCallName() {
return generateLibraryCallName(getOperation());
}
}];
}
Expand Down Expand Up @@ -138,7 +137,7 @@ def CopyOp : LinalgLibrary_Op<"copy", [NInputsAndOutputs<1, 1>]> {
return build(
builder, result, input, output, AffineMapAttr(), AffineMapAttr());
}]>];
let extraClassDeclaration = classDeclaration # [{
let extraClassDeclaration = libraryCallName # [{
unsigned getNumParallelLoops() {
auto *view = *(getOperands().begin());
return view->getType().cast<ViewType>().getRank();
Expand All @@ -151,7 +150,7 @@ def CopyOp : LinalgLibrary_Op<"copy", [NInputsAndOutputs<1, 1>]> {

def FillOp : LinalgLibrary_Op<"fill", [NInputsAndOutputs<0, 1>]> {
let arguments = (ins View, AnyTypeOf<[AnyFloat, AnyInteger, AnyVector]>);
let extraClassDeclaration = classDeclaration # [{
let extraClassDeclaration = libraryCallName # [{
unsigned getNumParallelLoops() {
auto *view = *(getOperands().begin());
return view->getType().cast<ViewType>().getRank();
Expand All @@ -170,23 +169,23 @@ def DotOp : LinalgLibrary_Op<"dot",
NLoopTypes<0, 1, 0>,
ViewRanks<[1, 1, 0]>]> {
let arguments = (ins View, View, View);
let extraClassDeclaration = classDeclaration;
let extraClassDeclaration = libraryCallName;
}

def MatvecOp : LinalgLibrary_Op<"matvec",
[NInputsAndOutputs<2, 1>,
NLoopTypes<1, 1, 0>,
ViewRanks<[2, 1, 1]>]> {
let arguments = (ins View, View, View);
let extraClassDeclaration = classDeclaration;
let extraClassDeclaration = libraryCallName;
}

def MatmulOp : LinalgLibrary_Op<"matmul",
[NInputsAndOutputs<2, 1>,
NLoopTypes<2, 1, 0>,
ViewRanks<[2, 2, 2]>]> {
let arguments = (ins View, View, View);
let extraClassDeclaration = classDeclaration;
let extraClassDeclaration = libraryCallName;
}

def ConvOp : LinalgLibrary_Op<"conv", [NInputsAndOutputs<2, 1>]> {
Expand All @@ -211,7 +210,7 @@ def ConvOp : LinalgLibrary_Op<"conv", [NInputsAndOutputs<2, 1>]> {
let arguments = (ins View:$filter, View:$input, View:$output,
OptionalAttr<I64ArrayAttr>:$strides,
OptionalAttr<I64ArrayAttr>:$dilations);
let extraClassDeclaration = classDeclaration # [{
let extraClassDeclaration = libraryCallName # [{
// TODO(ntv) extend to support more than 1 dimensions and potentially
// grouping too.
unsigned getNumBatchDimensions() { return 1; }
Expand Down
22 changes: 22 additions & 0 deletions third_party/mlir/include/mlir/Linalg/IR/LinalgOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,28 @@ class StoreOp
}
};

/// Returns the name mangled library call name to disambiguate between different
/// overloads at the C level. The name mangling scheme is basic and uses MLIR
/// type names:
/// 1. form a string which is the concatenation of the linalg op name with all
/// the operand type names, separate by underscores;
/// 2. drop the `linalg.` prefix, and the `<`, `>`, `?` symbols from the type.
/// Assumes `op` is a LinalgOp.
///
/// Examples:
///
/// 1. linalg.fill(%A, %f) : !linalg.view<f32>, f32
/// name mangles into `linalg_fill_viewf32_f32_impl`
///
/// 2. linalg.dot(%A, %B, %C) :
/// !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
/// name mangles into `linalg_dot_viewxf32_viewxf32_viewf32_impl`
///
/// 3. linalg.matmul(...) :
/// !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
/// name mangles into `linalg_matmul_viewxxf32_viewxxf32_viewxxf32_impl`
std::string generateLibraryCallName(Operation *op);

#define GET_OP_CLASSES
#include "mlir/Linalg/IR/LinalgOps.h.inc"

Expand Down
33 changes: 33 additions & 0 deletions third_party/mlir/lib/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "mlir/Transforms/FoldUtils.h"

#include "llvm/ADT/StringSet.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::edsc;
Expand Down Expand Up @@ -1085,3 +1086,35 @@ SmallVector<AffineMap, 4> mlir::linalg::loopToOperandRangesMaps(Operation *op) {
}
llvm_unreachable("Missing loopToOperandRangesMaps for op");
}

static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
if (auto view = t.dyn_cast<ViewType>()) {
ss << "view";
for (unsigned i = 0, e = view.getRank(); i < e; ++i)
ss << "x";
appendMangledType(ss, view.getElementType());
} else if (auto vec = t.dyn_cast<VectorType>()) {
ss << "vector";
interleave(
vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
appendMangledType(ss, vec.getElementType());
} else if (t.isIntOrIndexOrFloat()) {
ss << t;
} else {
llvm_unreachable("Invalid type for linalg library name mangling");
}
}

std::string mlir::linalg::generateLibraryCallName(Operation *op) {
assert(isa<LinalgOp>(op));
std::string name(op->getName().getStringRef().str());
name.reserve(128);
std::replace(name.begin(), name.end(), '.', '_');
llvm::raw_string_ostream ss(name);
ss << "_";
auto types = op->getOperandTypes();
interleave(
types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); },
[&]() { ss << "_"; });
return ss.str();
}
34 changes: 17 additions & 17 deletions third_party/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/LowerAffine.h"
#include "mlir/Transforms/Passes.h"

#include "llvm/ADT/SetVector.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
Expand Down Expand Up @@ -545,7 +547,7 @@ static FuncOp getLLVMLibraryCallImplDefinition(FuncOp libFn) {
}
SmallVector<Type, 4> fnArgTypes;
for (auto t : libFn.getType().getInputs()) {
assert(t.isa<LLVMType>() &&
assert(t && t.isa<LLVMType>() &&
"Expected LLVM Type for argument while generating library Call "
"Implementation Definition");
fnArgTypes.push_back(t.cast<LLVMType>().getPointerTo());
Expand Down Expand Up @@ -577,12 +579,8 @@ getLLVMLibraryCallDeclaration(Operation *op, LLVMTypeConverter &lowering,

// Get the Function type consistent with LLVM Lowering.
SmallVector<Type, 4> inputTypes;
for (auto operand : op->getOperands()) {
// TODO(ravishankarm): convertLinalgType handles only a subset of Linalg
// types. Handle other types (as well as non-Linalg types) either here or in
// convertLinalgType.
inputTypes.push_back(convertLinalgType(operand->getType(), lowering));
}
for (auto operand : op->getOperands())
inputTypes.push_back(lowering.convertType(operand->getType()));
assert(op->getNumResults() == 0 &&
"Library call for linalg operation can be generated only for ops that "
"have void return types");
Expand Down Expand Up @@ -632,15 +630,15 @@ class LinalgTypeConverter : public LLVMTypeConverter {
return convertLinalgType(t, *this);
}

void addLibraryFnDeclaration(FuncOp fn) {
libraryFnDeclarations.push_back(fn);
}
void addLibraryFnDeclaration(FuncOp fn) { libraryFnDeclarations.insert(fn); }

ArrayRef<FuncOp> getLibraryFnDeclarations() { return libraryFnDeclarations; }
ArrayRef<FuncOp> getLibraryFnDeclarations() {
return libraryFnDeclarations.getArrayRef();
}

private:
/// List of library functions declarations needed during dialect conversion
SmallVector<FuncOp, 2> libraryFnDeclarations;
llvm::SetVector<FuncOp> libraryFnDeclarations;
};
} // end anonymous namespace

Expand Down Expand Up @@ -676,11 +674,13 @@ static void
populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter,
OwningRewritePatternList &patterns,
MLIRContext *ctx) {
patterns.insert<BufferAllocOpConversion, BufferDeallocOpConversion,
BufferSizeOpConversion, DimOpConversion,
LinalgOpConversion<DotOp>, LinalgOpConversion<MatmulOp>,
LoadOpConversion, RangeOpConversion, SliceOpConversion,
StoreOpConversion, ViewOpConversion>(ctx, converter);
patterns
.insert<BufferAllocOpConversion, BufferDeallocOpConversion,
BufferSizeOpConversion, DimOpConversion,
LinalgOpConversion<DotOp>, LinalgOpConversion<FillOp>,
LinalgOpConversion<MatmulOp>, LoadOpConversion, RangeOpConversion,
SliceOpConversion, StoreOpConversion, ViewOpConversion>(
ctx, converter);
}

namespace {
Expand Down

0 comments on commit 8bb9d50

Please sign in to comment.