3 changes: 3 additions & 0 deletions .gitmodules
@@ -0,0 +1,3 @@
[submodule "external/llvm-project"]
path = external/llvm-project
url = https://github.com/llvm/llvm-project.git
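Not part of the patch: a minimal sketch of how a checkout picks up the new submodule after pulling this change, assuming a standard git workflow. The path comes from the .gitmodules entry above; the pinned revision is recorded by the superproject, so no URL is needed locally.

```sh
# Fetch the LLVM sources at the revision pinned by this repository.
git submodule update --init external/llvm-project
```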
6 changes: 1 addition & 5 deletions build_tools/install_mlir.sh
@@ -3,10 +3,7 @@ set -e
td="$(realpath $(dirname $0)/..)"

# Find LLVM source (assumes it is adjacent to this directory).
if [ -z "$LLVM_SRC_DIR" ]; then
LLVM_SRC_DIR="$td/../llvm-project"
fi
LLVM_SRC_DIR="$(realpath "$LLVM_SRC_DIR")"
LLVM_SRC_DIR="$(realpath "${LLVM_SRC_DIR:-$td/external/llvm-project}")"

if ! [ -f "$LLVM_SRC_DIR/llvm/CMakeLists.txt" ]; then
echo "Expected LLVM_SRC_DIR variable to be set correctly (got '$LLVM_SRC_DIR')"
@@ -42,4 +39,3 @@ cmake -GNinja \
-DLLVM_ENABLE_RTTI=On

cmake --build "$build_mlir" --target install
#cmake --build "$build_mlir" --target install
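With the `${LLVM_SRC_DIR:-$td/external/llvm-project}` default above, the script now finds LLVM in the submodule unless told otherwise. An illustrative sketch of both invocations follows; the override path is an assumption, not taken from the patch.

```sh
# Uses external/llvm-project (the submodule) by default.
./build_tools/install_mlir.sh

# An out-of-tree LLVM checkout can still be pointed to explicitly.
LLVM_SRC_DIR="$HOME/src/llvm-project" ./build_tools/install_mlir.sh
```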
1 change: 1 addition & 0 deletions external/llvm-project
Submodule llvm-project added at 0b161d
36 changes: 4 additions & 32 deletions include/npcomp/Dialect/TCP/IR/TCPOps.td
@@ -39,7 +39,7 @@ Broadcasts `operand` to the shape `shape`.

It is undefined behavior if such a broadcast is not legal.
}];
let arguments = (ins AnyRankedTensor:$operand, Shape_ShapeType:$shape);
let arguments = (ins AnyRankedTensor:$operand, Shape_ExtentTensorType:$shape);
let results = (outs AnyRankedTensor:$result);
}

@@ -54,7 +54,7 @@ def TCP_AllocMemRefOp : TCP_Op<"alloc_memref", []> {
let description = [{
Allocates a memref of the given shape.
}];
let arguments = (ins Shape_ShapeType:$shape);
let arguments = (ins Shape_ExtentTensorType:$shape);
let results = (outs AnyMemRef:$memref);
let assemblyFormat = "$shape attr-dict `:` type($memref)";
}
@@ -91,8 +91,7 @@ def TCP_GetGlobalMemrefOp : TCP_Op<"get_global_memref"> {
// produce the error in practice. The ops like shape.broadcast itself, when
// lowered, immediately produce errors.
// TODO: This should eventually be moved to a shape dialect.
def TCP_ShapeObserveErrorOp : TCP_Op<"shape_observe_error",
[DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
def TCP_ShapeObserveErrorOp : TCP_Op<"shape_observe_error", []> {
let summary = "Observes the fact that a shape might be an error.";
let description = [{
This op is a structural placeholder that captures a shape such that it
@@ -103,37 +102,10 @@ def TCP_ShapeObserveErrorOp : TCP_Op<"shape_observe_error",
effecting ops, is not very well-defined, and needs to be worked
on/redesigned.
}];
let arguments = (ins Shape_ShapeType:$shape);
let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);
// TODO: ODS seems to create redeclared class members if we remove this,
// resulting in C++ compilation errors.
let results = (outs NoneType:$dummy);
}

// TODO: This probably belongs in the shape dialect.
def TCP_GetExtentOp : TCP_Op<"get_extent",
[NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Gets the specified extent from a shape.";
let description = [{
Gets the specified extent from a shape.

This op has undefined behavior if the shape is an error.
}];
let arguments = (ins Shape_ShapeType:$shape, I64Attr:$dim);
let results = (outs Index:$extent);
let assemblyFormat = "$shape `,` $dim attr-dict";

let builders = [
// Helper to pass a simple integer instead of an integer attr.
OpBuilder<
[{
OpBuilder &builder, OperationState &result,
Value shape, int64_t dim
}],
[{
build(builder, result, shape, builder.getI64IntegerAttr(dim));
}]
>
];
}

#endif // TCP_OPS
13 changes: 11 additions & 2 deletions lib/Conversion/TCFToTCP/TCFToTCP.cpp
@@ -19,6 +19,12 @@ using namespace mlir;
using namespace mlir::NPCOMP;

namespace {

RankedTensorType getExtentTensorType(Builder &builder) {
return RankedTensorType::get({ShapedType::kDynamicSize},
builder.getIndexType());
}

class ConvertAdd : public OpRewritePattern<tcf::AddOp> {
public:
using OpRewritePattern::OpRewritePattern;
@@ -34,6 +40,9 @@ class ConvertAdd : public OpRewritePattern<tcf::AddOp> {
Value broadcastedShape = rewriter.create<shape::BroadcastOp>(
op.getLoc(), lhsShape, rhsShape, /*error=*/nullptr);
rewriter.create<tcp::ShapeObserveErrorOp>(op.getLoc(), broadcastedShape);
Value broadcastedExtents = rewriter.create<shape::ToExtentTensorOp>(
op.getLoc(), getExtentTensorType(rewriter), broadcastedShape);

// TODO: It's annoying to do the dynamic broadcast above then
// do the static transfer function here. Would be nice if they could
// somehow be unified.
@@ -43,9 +52,9 @@ class ConvertAdd : public OpRewritePattern<tcf::AddOp> {
auto resultType =
RankedTensorType::get(broadcastedStaticShape, lhsType.getElementType());
Value lhsBroadcasted = rewriter.create<tcp::BroadcastToOp>(
op.getLoc(), resultType, op.lhs(), broadcastedShape);
op.getLoc(), resultType, op.lhs(), broadcastedExtents);
Value rhsBroadcasted = rewriter.create<tcp::BroadcastToOp>(
op.getLoc(), resultType, op.rhs(), broadcastedShape);
op.getLoc(), resultType, op.rhs(), broadcastedExtents);
Value add = rewriter.create<tcp::AddOp>(op.getLoc(), op.getType(),
lhsBroadcasted, rhsBroadcasted);
rewriter.replaceOp(op, add);
10 changes: 5 additions & 5 deletions lib/Conversion/TCPToLinalg/TCPToLinalg.cpp
@@ -31,11 +31,11 @@ class ConvertAdd : public OpRewritePattern<tcp::AddOp> {
auto genericOp = rewriter.create<linalg::GenericOp>(
op.getLoc(), llvm::makeArrayRef({op.getType()}),
ValueRange({op.lhs(), op.rhs()}),
/*args_in=*/rewriter.getI64IntegerAttr(2),
/*args_out=*/rewriter.getI64IntegerAttr(1),
/*indexing_maps=*/rewriter.getAffineMapArrayAttr(accesses),
/*iterator_types=*/rewriter.getStrArrayAttr(iterators), /*doc=*/nullptr,
/*library_call=*/nullptr);
/*args_in=*/2,
/*args_out=*/1,
/*indexing_maps=*/accesses,
/*iterator_types=*/iterators,
/*function_ref=*/nullptr);

Region &region = genericOp.region();
Block *block = rewriter.createBlock(&region, region.begin());
25 changes: 0 additions & 25 deletions lib/Dialect/TCP/IR/TCPOps.cpp
@@ -58,31 +58,6 @@ static LogicalResult verifyGetGlobalMemrefOp(GetGlobalMemrefOp op) {
return success();
}

//===----------------------------------------------------------------------===//
// ShapeObserveErrorOp
//===----------------------------------------------------------------------===//

LogicalResult ShapeObserveErrorOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
inferredReturnTypes.push_back(NoneType::get(context));
return success();
}

//===----------------------------------------------------------------------===//
// GetExtentOp
//===----------------------------------------------------------------------===//

LogicalResult
GetExtentOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
ValueRange operands, DictionaryAttr attributes,
RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
inferredReturnTypes.push_back(IndexType::get(context));
return success();
}

namespace mlir {
namespace NPCOMP {
namespace tcp {
22 changes: 11 additions & 11 deletions lib/E2E/E2E.cpp
@@ -201,16 +201,11 @@ class LowerLinalgLoopDimOp : public OpRewritePattern<DimOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(DimOp op,
PatternRewriter &rewriter) const override {
// TODO: Remove this const pattern when lowering to shape.get_extent.
auto constIndex = op.getConstantIndex();
if (!constIndex)
return failure();

auto allocMemRef = op.memrefOrTensor().getDefiningOp<tcp::AllocMemRefOp>();
if (!allocMemRef)
return rewriter.notifyMatchFailure(op, "could not find alloc_memref");
rewriter.replaceOpWithNewOp<tcp::GetExtentOp>(op, allocMemRef.shape(),
*constIndex);
rewriter.replaceOpWithNewOp<shape::GetExtentOp>(
op, rewriter.getIndexType(), allocMemRef.shape(), op.index());

Review comment — @silvasean (Contributor), Aug 3, 2020:

Did this CL delete all uses of tcp::GetExtentOp? If so, we can remove tcp::GetExtentOp from the dialect. (That was on my TODO list; sorry you had to bear the brunt of this update.)

D'oh, I see it's in your next patch. Sorry for the noise.
return success();
}
};
Expand All @@ -231,7 +226,7 @@ class LowerLinalgLoopDimOps
// remove this.
return !op.memrefOrTensor().getDefiningOp<tcp::AllocMemRefOp>();
});
target.addLegalOp<tcp::GetExtentOp>();
target.addLegalOp<shape::GetExtentOp>();
if (failed(applyPartialConversion(func, target, patterns))) {
return signalPassFailure();
}
@@ -261,7 +256,8 @@ class LowerAllocMemRefOp : public OpRewritePattern<tcp::AllocMemRefOp> {
SmallVector<Value, 6> dynamicExtents;
for (int i = 0, e = memrefType.getRank(); i < e; i++) {
if (memrefType.isDynamicDim(i)) {
auto extent = rewriter.create<tcp::GetExtentOp>(op.getLoc(), shape, i);
auto extent =
rewriter.create<shape::GetExtentOp>(op.getLoc(), shape, i);
dynamicExtents.push_back(extent);
}
}
@@ -281,8 +277,9 @@ class LowerAllocMemRefOps
patterns.insert<LowerAllocMemRefOp>(context);
ConversionTarget target(*context);
target.addIllegalOp<tcp::AllocMemRefOp>();
target.addLegalOp<tcp::GetExtentOp>();
target.addLegalOp<shape::GetExtentOp>();
target.addLegalOp<AllocOp>();
target.addLegalOp<ConstantOp>();
if (failed(applyPartialConversion(func, target, patterns))) {
return signalPassFailure();
}
@@ -433,8 +430,11 @@ void mlir::NPCOMP::createE2ELoweringPipeline(
pm.addPass(createLowerRankedShapesPass());

// Run some cleanups.
// TODO: Some folding and DCE of dangling ops is still needed here. Once the
// invariants above are tightened up, the canonicalize should be moved into
// the optimize block.
pm.addPass(createCanonicalizerPass());
if (options.optimize) {
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
}
