Skip to content

Commit

Permalink
Allow linalg.view to change the underlying elemental type.
Browse files Browse the repository at this point in the history
This CL adds the ability for linalg.view to act as a bitcast operation.
This will be used when promoting views into faster memory and casting to vector types.

In the process, linalg.view is moved to ODS.

PiperOrigin-RevId: 262556246
  • Loading branch information
Nicolas Vasilache authored and tensorflower-gardener committed Aug 9, 2019
1 parent de9771e commit a475e3e
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 133 deletions.
41 changes: 0 additions & 41 deletions third_party/mlir/include/mlir/Linalg/IR/LinalgOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,47 +186,6 @@ class StoreOp
}
};

/// The "linalg.view" op produces a linalg.view which is a multi-dimensional
/// range abstraction on top of an underlying linalg.buffer. This gives an
/// indexing structure to an otherwise non-indexable linalg.buffer.
///
/// A "linalg.view" takes a buffer and a variadic number of ranges and produces
/// a `view` of the same elemental type as the buffer and of rank the number of
/// ranges:
///
/// ```{.mlir}
/// %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
/// %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
/// %3 = linalg.view %1[%2, %2] : !linalg.view<?x?xf32>
/// ```
class ViewOp : public Op<ViewOp, OpTrait::VariadicOperands, OpTrait::OneResult,
OpTrait::HasNoSideEffect> {
enum { FirstIndexingOperand = 1 };

public:
using Op::Op;

// Hooks to customize the behavior of this op.
static llvm::StringRef getOperationName() { return "linalg.view"; }
static void build(Builder *b, OperationState *result, Value *buffer,
llvm::ArrayRef<Value *> indexings);
LogicalResult verify();
static ParseResult parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);

// Op-specific functionality.
unsigned getRank() { return getViewType().getRank(); }
Type getElementType() { return getViewType().getElementType(); }
ViewType getViewType() { return getType().cast<ViewType>(); }
Value *getSupportingBuffer() { return getOperand(0); }
// Get the underlying indexing at a given rank.
Value *getIndexing(unsigned rank) { return *(getIndexings().begin() + rank); }
// Get all the indexings in this view.
Operation::operand_range getIndexings() {
return {operand_begin() + ViewOp::FirstIndexingOperand, operand_end()};
}
};

#define GET_OP_CLASSES
#include "mlir/Linalg/IR/LinalgOps.h.inc"

Expand Down
45 changes: 45 additions & 0 deletions third_party/mlir/include/mlir/Linalg/IR/LinalgOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,51 @@ def SubViewOp : Linalg_Op<"subview", [NoSideEffect]>,
}];
}

def ViewOp : Linalg_Op<"view", [NoSideEffect]>,
Arguments<(ins Buffer:$buffer, Variadic<Range>:$ranges)>,
Results<(outs View)> {
let summary = "view operation";
let description = [{
The "linalg.view" op produces a linalg.view which is a multi-dimensional
range abstraction on top of an underlying linalg.buffer. This gives an
indexing structure to an otherwise non-indexable linalg.buffer.

A "linalg.view" takes a buffer and a variadic number of ranges and produces
a `view` of rank the number of ranges. The elemental type may not match the
buffer element type:

Examples:
```
%1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
%2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
%3 = linalg.view %1[%2, %2] : !linalg.view<?x?xvector<4xf32>>
```
}];

let builders = [OpBuilder<
"Builder *b, OperationState *result, Value *buffer, "
"ArrayRef<Value *> ranges, Type resultType = Type(), "
"ArrayRef<NamedAttribute> attrs = {}">];

let verifier = [{
if (getViewType().getRank() != llvm::size(ranges()))
return emitOpError("the view rank must be the number of its ranges");
return success();
}];

let extraClassDeclaration = [{
enum { FirstIndexingOperand = 1 };
unsigned getRank() { return getViewType().getRank(); }
Type getElementType() { return getViewType().getElementType(); }
ViewType getViewType() { return getType().cast<ViewType>(); }
/// Get the underlying indexing at a given rank.
Value *getRange(unsigned rank) {
assert(rank < getRank() && "rank overflow");
return *(ranges().begin() + rank);
}
}];
}

def YieldOp : Linalg_Op<"yield", [NativeOpTrait<"IsTerminator">]>,
Arguments<(ins Variadic<AnyType>:$values)> {
let summary = "Linalg yield operation";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ Value *Aliases::find(Value *v) {
return it.first->second;
}
if (auto view = dyn_cast_or_null<ViewOp>(v->getDefiningOp())) {
auto it = aliases.insert(std::make_pair(v, view.getSupportingBuffer()));
auto it = aliases.insert(std::make_pair(v, view.buffer()));
return it.first->second;
}
if (auto view = dyn_cast_or_null<SubViewOp>(v->getDefiningOp())) {
Expand Down
166 changes: 78 additions & 88 deletions third_party/mlir/lib/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ SimplifyDimOp::matchAndRewrite(linalg::DimOp dimOp,
Value *min, *max, *step;
if (view) {
// Cannot traverse block arguments, fail.
if (isa<BlockArgument>(view.getIndexing(dim)))
if (isa<BlockArgument>(view.getRange(dim)))
return matchFailure();
// Record min, max, step for further processing.
auto range = cast<RangeOp>(view.getIndexing(dim)->getDefiningOp());
auto range = cast<RangeOp>(view.getRange(dim)->getDefiningOp());
std::tie(min, max, step) =
std::make_tuple(range.min(), range.max(), range.step());
} else if (subView) {
Expand Down Expand Up @@ -414,97 +414,15 @@ LogicalResult mlir::linalg::StoreOp::verify() {
return success();
}

//////////////////////////////////////////////////////////////////////////////
// ViewOp
//////////////////////////////////////////////////////////////////////////////
void mlir::linalg::ViewOp::build(Builder *b, OperationState *result,
Value *buffer, ArrayRef<Value *> indexings) {
BufferType bufferType = buffer->getType().cast<BufferType>();
result->addOperands({buffer});
result->addOperands(indexings);
assert(
std::none_of(indexings.begin(), indexings.end(),
[](Value *v) { return !v->getType().isa<RangeType>(); }) &&
"linalg.view takes only arguments of type linalg.range");

Type elementType = bufferType.getElementType();
result->addTypes(
{ViewType::get(b->getContext(), elementType, indexings.size())});
}

LogicalResult mlir::linalg::ViewOp::verify() {
if (llvm::empty(getOperands()))
return emitOpError(
"requires at least a buffer operand followed by indexings");
auto bufferType = getOperand(0)->getType().dyn_cast<BufferType>();
if (!bufferType)
return emitOpError("first operand must be of BufferType");
unsigned index = 0;
for (auto indexing : getIndexings()) {
if (!indexing->getType().isa<RangeType>()) {
return emitOpError() << index << "^th index must be of range type";
}
++index;
}
if (getViewType().getRank() != index)
return emitOpError()
<< "the rank of the view must be the number of its indexings";
return success();
}

ParseResult mlir::linalg::ViewOp::parse(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType bufferInfo;
SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
Type bType, type;
if (parser->parseOperand(bufferInfo) ||
parser->parseOperandList(indexingsInfo, OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColon() || parser->parseType(bType) ||
parser->parseArrow() || parser->parseType(type)) {
return failure();
}

BufferType bufferType = bType.dyn_cast<BufferType>();
if (!bufferType) {
return parser->emitError(parser->getNameLoc(), "buffer type expected");
}

ViewType viewType = type.dyn_cast<ViewType>();
if (!viewType)
return parser->emitError(parser->getNameLoc(), "view type expected");
if (viewType.getRank() != indexingsInfo.size())
return parser->emitError(parser->getNameLoc(), "expected")
<< viewType.getRank() << " range indexings";
return failure(
parser->resolveOperand(bufferInfo, bufferType, result->operands) ||
(!indexingsInfo.empty() &&
parser->resolveOperands(indexingsInfo, RangeType::get(type.getContext()),
result->operands)) ||
parser->addTypeToList(viewType, result->types));
}

// A ViewOp prints as:
//
// ```{.mlir}
// linalg.view %0[%1, %2] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
// ```
//
// Where %0 is an ssa-value holding a buffer, %1 and %2 are ssa-value each
// holding a range.
void mlir::linalg::ViewOp::print(OpAsmPrinter *p) {
*p << getOperationName() << " " << *getSupportingBuffer() << "[";
interleave(
getIndexings().begin(), getIndexings().end(), [&](Value *v) { *p << *v; },
[&]() { *p << ", "; });
*p << "] : " << getSupportingBuffer()->getType() << " -> " << getType();
}

///////////////////// Operations defined with Tablegen /////////////////////////
// For such operations that do not correspond to library calls (i.e. defined in
// LinalgOps.td), we define an overloaded `print` function and a
// parse`className` function.

//===----------------------------------------------------------------------===//
// BufferAllocOp
//===----------------------------------------------------------------------===//

static void print(OpAsmPrinter *p, BufferAllocOp op) {
*p << op.getOperationName() << " ";
if (!llvm::empty(op.size()))
Expand Down Expand Up @@ -544,6 +462,10 @@ static LogicalResult verify(BufferAllocOp op) {
return success();
}

//===----------------------------------------------------------------------===//
// BufferDeallocOp
//===----------------------------------------------------------------------===//

static void print(OpAsmPrinter *p, BufferDeallocOp op) {
*p << op.getOperationName() << " " << *op.buffer();
p->printOptionalAttrDict(op.getAttrs());
Expand All @@ -565,6 +487,10 @@ static void print(OpAsmPrinter *p, BufferSizeOp op) {
*p << " : " << op.getOperand()->getType();
}

//===----------------------------------------------------------------------===//
// BufferSizeOp
//===----------------------------------------------------------------------===//

static ParseResult parseBufferSizeOp(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType op;
Expand Down Expand Up @@ -747,6 +673,66 @@ static LogicalResult verify(GenericOp op) {
return success();
}

//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//
void mlir::linalg::ViewOp::build(Builder *b, OperationState *result,
Value *buffer, ArrayRef<Value *> ranges,
Type resultType,
ArrayRef<NamedAttribute> attrs) {
if (!resultType) {
Type elementType = buffer->getType().cast<BufferType>().getElementType();
resultType = ViewType::get(b->getContext(), elementType, ranges.size());
}
build(b, result, resultType, buffer, ranges);
result->addAttributes(attrs);
}

static ParseResult parseViewOp(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType bufferInfo;
SmallVector<OpAsmParser::OperandType, 8> rangesInfo;
Type bType, vType;
if (parser->parseOperand(bufferInfo) ||
parser->parseOperandList(rangesInfo, OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColon() || parser->parseType(bType) ||
parser->parseArrow() || parser->parseType(vType)) {
return failure();
}

BufferType bufferType = bType.dyn_cast<BufferType>();
if (!bufferType) {
return parser->emitError(parser->getNameLoc(), "buffer type expected");
}

ViewType viewType = vType.dyn_cast<ViewType>();
if (!viewType)
return parser->emitError(parser->getNameLoc(), "view type expected");
if (viewType.getRank() != rangesInfo.size())
return parser->emitError(parser->getNameLoc(), "expected")
<< viewType.getRank() << " range ranges";
return failure(
parser->resolveOperand(bufferInfo, bufferType, result->operands) ||
(!rangesInfo.empty() &&
parser->resolveOperands(rangesInfo, RangeType::get(vType.getContext()),
result->operands)) ||
parser->addTypeToList(viewType, result->types));
}

// A ViewOp prints as:
//
// ```{.mlir}
// linalg.view %0[%1, %2] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
// ```
//
// Where %0 is an ssa-value holding a buffer, %1 and %2 are ssa-value each
// holding a range.
static void print(OpAsmPrinter *p, ViewOp op) {
*p << op.getOperationName() << " " << *op.buffer() << "[";
interleaveComma(op.ranges(), *p, [&](Value *v) { *p << *v; });
*p << "] : " << op.buffer()->getType() << " -> " << op.getType();
}

//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -808,6 +794,10 @@ static void print(OpAsmPrinter *p, SubViewOp op) {
*p << " : " << op.getViewType();
}

//===----------------------------------------------------------------------===//
// SubViewOp
//===----------------------------------------------------------------------===//

static ParseResult parseSubViewOp(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType inputView, resultView;
Type viewType;
Expand Down
2 changes: 1 addition & 1 deletion third_party/mlir/lib/Linalg/IR/LinalgTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ using namespace mlir::linalg;
mlir::linalg::LinalgDialect::LinalgDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
addTypes<BufferType, RangeType, ViewType>();
addOperations<LoadOp, RangeOp, StoreOp, SliceOp, ViewOp>();
addOperations<LoadOp, RangeOp, StoreOp, SliceOp>();
addOperations<
#define GET_OP_LIST
#include "mlir/Linalg/IR/LinalgOps.cpp.inc"
Expand Down
4 changes: 2 additions & 2 deletions third_party/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -512,9 +512,9 @@ class ViewOpConversion : public LLVMOpLowering {
desc = insertvalue(viewDescriptorTy, desc, baseOffset, pos(1));

// Compute and insert view sizes (max - min along the range).
int numIndexings = llvm::size(viewOp.getIndexings());
int numRanges = llvm::size(viewOp.ranges());
Value *runningStride = constant(int64Ty, IntegerAttr::get(indexTy, 1));
for (int i = numIndexings - 1; i >= 0; --i) {
for (int i = numRanges - 1; i >= 0; --i) {
// Update stride.
Value *rangeDescriptor = operands[1 + i];
Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
Expand Down

0 comments on commit a475e3e

Please sign in to comment.