
[Transform] Only use gc runtime allocator for stack-like alloca ops #287

Merged (17 commits, Sep 4, 2024)
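In brief: this change lowers stack-like memref.alloca ops to the gc runtime allocator (cpuruntime.alloc / cpuruntime.dealloc, thread-local inside parallel regions) instead of lowering memref.alloc / memref.dealloc as before, and only when the alloca is large enough to be worth a heap allocation. A minimal before/after sketch, assuming the behavior exercised by the tests in this PR:

func.func @example() {
  %m0 = memref.alloca() : memref<1024xf32>
  return
}

// After --convert-memref-to-cpuruntime (4096 bytes, above the threshold):
func.func @example() {
  %m0 = cpuruntime.alloc() : memref<1024xf32>
  cpuruntime.dealloc %m0 : memref<1024xf32>
  return
}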
119 changes: 75 additions & 44 deletions lib/gc/Transforms/MemRefToCPURuntime.cpp
@@ -28,6 +28,8 @@ namespace gc {

namespace {

constexpr uint64_t STACK_ALLOC_THRESHOLD = 128;

bool hasParallelParent(Operation *op) {
// Check if the parent contains a forall / parallel loop
for (Operation *parentOp = op->getParentOp(); parentOp != nullptr;
@@ -38,9 +40,38 @@ bool hasParallelParent(Operation *op) {
}
return false;
}
struct AlignedAllocLowering : public OpRewritePattern<memref::AllocOp> {
using OpRewritePattern<memref::AllocOp>::OpRewritePattern;
LogicalResult matchAndRewrite(memref::AllocOp op,

uint64_t getMemRefSizeInBytes(MemRefType memrefType) {
if (ShapedType::isDynamicShape(memrefType.getShape()))
return UINT64_MAX;
ShapedType shapeType = cast<ShapedType>(memrefType);
int elementSize = shapeType.getElementTypeBitWidth() / 8;
AffineMap layout = memrefType.getLayout().getAffineMap();
ArrayRef<int64_t> shape = memrefType.getShape();
if (!layout.isIdentity()) {
int64_t offset;
SmallVector<int64_t, 4> strides;
if (failed(getStridesAndOffset(memrefType, strides, offset))) {
return UINT64_MAX;
}

int totalSize = elementSize;
for (size_t i = 0; i < shape.size(); ++i) {
totalSize *= (i == shape.size() - 1) ? strides[i] : shape[i];
}
return totalSize;
} else {
int totalSize = elementSize;
for (int64_t dim : shape) {
totalSize *= dim;
}
return totalSize;
}
}

struct AlignedAllocLowering : public OpRewritePattern<memref::AllocaOp> {
using OpRewritePattern<memref::AllocaOp>::OpRewritePattern;
LogicalResult matchAndRewrite(memref::AllocaOp op,
PatternRewriter &rewriter) const final {
auto loc = op->getLoc();
MemRefType type = op.getMemref().getType();
@@ -54,66 +85,66 @@ struct AlignedAllocLowering : public OpRewritePattern<memref::AllocOp> {
return success();
}
};

struct AlignedDeallocLowering : public OpRewritePattern<memref::DeallocOp> {
using OpRewritePattern<memref::DeallocOp>::OpRewritePattern;
LogicalResult matchAndRewrite(memref::DeallocOp op,
PatternRewriter &rewriter) const final {
auto loc = op->getLoc();
Value memref = op.getMemref();
cpuruntime::DeallocOp newDeallocOp =
rewriter.create<cpuruntime::DeallocOp>(loc, memref);
if (hasParallelParent(op))
newDeallocOp.setThreadLocal(true);
rewriter.eraseOp(op);
return success();
}
};

struct ConvertMemRefToCPURuntime
: public impl::ConvertMemRefToCPURuntimeBase<ConvertMemRefToCPURuntime> {

void runOnOperation() final {
auto *ctx = &getContext();
// Create a local set to store operations that should not be transformed.
llvm::SmallSet<Operation *, 16> noTransformOps;

// Walk through the module to find func::FuncOp instances.
// Create a deallocOp corresponding to each alloca's allocation scope.
getOperation()->walk([&](func::FuncOp funcOp) {
BufferViewFlowAnalysis analysis(funcOp);
// Now walk through the operations within the func::FuncOp.
funcOp.walk([&](Operation *op) {
if (op->hasTrait<OpTrait::ReturnLike>()) {
for (Value operand : op->getOperands()) {
if (isa<MemRefType>(operand.getType())) {
auto aliases = analysis.resolveReverse(operand);
// Check if any of the returned memref is allocated within scope.
for (auto &&alias : aliases) {
if (Operation *allocOp =
alias.getDefiningOp<memref::AllocOp>()) {
noTransformOps.insert(allocOp);
}
}
}
}
// Vector to store alloca operations
SmallVector<memref::AllocaOp, 16> allocaOps;
// Collect all alloca operations
funcOp.walk([&](memref::AllocaOp allocaOp) {
uint64_t allocSize =
getMemRefSizeInBytes(allocaOp.getResult().getType());
if (allocSize < STACK_ALLOC_THRESHOLD) {
noTransformOps.insert(allocaOp);
return;
}
allocaOps.push_back(allocaOp);
});

// Create dealloc operations in reverse order of alloca operations
for (auto allocaOp = allocaOps.rbegin(); allocaOp != allocaOps.rend();
++allocaOp) {
Operation *scopeOp =
(*allocaOp)
->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
OpBuilder builder(*allocaOp);
Region &scopeRegion = scopeOp->getRegion(0);
// Set the insertion point to the end of the region before the
// terminator
Block &lastBlock = scopeRegion.back();
builder.setInsertionPointToEnd(&lastBlock);
if (!lastBlock.empty() &&
lastBlock.back().hasTrait<OpTrait::IsTerminator>()) {
builder.setInsertionPoint(&lastBlock.back());
}

// Create the dealloc operation
auto deallocOp = builder.create<cpuruntime::DeallocOp>(
(*allocaOp).getLoc(), (*allocaOp).getResult());
if (hasParallelParent(*allocaOp)) {
deallocOp.setThreadLocal(true);
}
}
});

// add lowering target
ConversionTarget target(getContext());
// Make all operations legal by default.
target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
target.addDynamicallyLegalOp<memref::AllocOp, memref::DeallocOp>(
[&](Operation *op) {
// Return true if the operation is in the noTransformOps set, making
// it dynamically legal.
return noTransformOps.find(op) != noTransformOps.end();
});
target.addDynamicallyLegalOp<memref::AllocaOp>([&](Operation *op) {
// Return true if the operation is in the noTransformOps set, making
// it dynamically legal.
return noTransformOps.find(op) != noTransformOps.end();
});
// set pattern
RewritePatternSet patterns(ctx);
patterns.add<AlignedAllocLowering>(ctx);
patterns.add<AlignedDeallocLowering>(ctx);
// perform conversion
if (failed(
applyFullConversion(getOperation(), target, std::move(patterns)))) {
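To make the threshold concrete: for identity layouts, getMemRefSizeInBytes is the element size in bytes times the product of the static extents, and an alloca is left on the stack only when that size is below STACK_ALLOC_THRESHOLD (128 bytes). Worked cases:

  memref<1024xf32> : 1024 * 4 = 4096 bytes >= 128 -> lowered to cpuruntime.alloc
  memref<16xf32>   :   16 * 4 =   64 bytes <  128 -> left as memref.alloca
  memref<?xindex>  : dynamic shape, size taken as UINT64_MAX -> always lowered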
8 changes: 5 additions & 3 deletions lib/gc/Transforms/Pipeline.cpp
@@ -24,6 +24,7 @@
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Passes.h"
#include <climits>

#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h"
#include "gc/Dialect/Linalgx/LinalgxDialect.h"
@@ -110,9 +111,10 @@ void populateBufferizationPasses(mlir::OpPassManager &pm) {
opt.hoistStaticAllocs = true;
pm.addPass(bufferization::createBufferResultsToOutParamsPass(opt));
pm.addPass(bufferization::createDropEquivalentBufferResultsPass());
pm.addNestedPass<func::FuncOp>(
bufferization::createPromoteBuffersToStackPass());
bufferization::BufferDeallocationPipelineOptions deallocOption;
pm.addNestedPass<func::FuncOp>(bufferization::createPromoteBuffersToStackPass(
/*maxAllocSizeInBytes*/ UINT_MAX,
/*maxRankOfAllocatedMemRef*/ 8));
mlir::bufferization::BufferDeallocationPipelineOptions deallocOption;
bufferization::buildBufferDeallocationPipeline(pm, deallocOption);
pm.addPass(createBufferizationToMemRefPass());
populateCleanUpPasses(pm);
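The pipeline now promotes buffers to the stack with an effectively unlimited byte size (UINT_MAX) and a rank limit of 8, rather than the pass defaults. A plausible reading of the change: promote-buffers-to-stack turns every eligible statically sized memref.alloc into a memref.alloca, and the memref-to-cpuruntime pass above then decides per alloca, by the 128-byte threshold, whether it truly stays on the stack or is routed through the gc runtime. A minimal sketch of the promotion step, assuming upstream MLIR promote-buffers-to-stack semantics:

func.func @promote() {
  %0 = memref.alloc() : memref<64xf32>
  memref.dealloc %0 : memref<64xf32>
  return
}

// After promote-buffers-to-stack (static shape, rank 1 <= 8):
func.func @promote() {
  %0 = memref.alloca() : memref<64xf32>
  return
}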
113 changes: 67 additions & 46 deletions test/mlir/test/gc/Dialect/CPURuntime/memref-to-cpuruntime.mlir
@@ -1,68 +1,89 @@
// RUN: gc-opt --split-input-file --convert-memref-to-cpuruntime %s -verify-diagnostics | FileCheck %s
func.func @alloc() {
// CHECK-LABEL: func @alloc()

func.func @alloca() {
// CHECK-LABEL: func @alloca()
// CHECK: %[[m0:.*]] = cpuruntime.alloc() : memref<1024xf32>
%m0 = memref.alloc() : memref<1024xf32>
%m0 = memref.alloca() : memref<1024xf32>
scf.forall (%i) in (32) {
}
// CHECK: cpuruntime.dealloc %[[m0]] : memref<1024xf32>
cpuruntime.dealloc %m0 : memref<1024xf32>
return
}

func.func @thread_alloc() {
// CHECK-LABEL: func.func @thread_alloc()
// CHECK: %[[m0:.*]] = cpuruntime.alloc thread_local() : memref<1024xf32>
func.func @thread_alloca() {
// CHECK-LABEL: func.func @thread_alloca()
// CHECK-NEXT: scf.forall {{.*}} {
// CHECK-NEXT: %[[m0:.*]] = cpuruntime.alloc thread_local() : memref<1024xf32>
// CHECK-NEXT: cpuruntime.dealloc thread_local %[[m0]] : memref<1024xf32>
// CHECK-NEXT: }
scf.forall (%i) in (32) {
%0 = memref.alloc() : memref<1024xf32>
// CHECK: cpuruntime.dealloc thread_local %[[m0]] : memref<1024xf32>
memref.dealloc %0 : memref<1024xf32>
%0 = memref.alloca() : memref<1024xf32>
}
return
}

func.func @return_alloc() -> memref<32x18xf32> {
// CHECK-LABEL: func @return_alloc() -> memref<32x18xf32>
// CHECK: %[[m0:.*]] = memref.alloc() : memref<32x18xf32>
%0 = memref.alloc() : memref<32x18xf32>
return %0 : memref<32x18xf32>
func.func @dynamic_ranked_alloca(%arg0: memref<*xf32>) {
// CHECK-LABEL: func @dynamic_ranked_alloca(%arg0: memref<*xf32>)
// CHECK: %[[RANK:.*]] = memref.rank %{{.*}} : memref<*xf32>
// CHECK: %[[m0:.*]] = cpuruntime.alloc(%[[RANK]]) : memref<?xindex>
// CHECK: cpuruntime.dealloc %[[m0]] : memref<?xindex>
%0 = memref.rank %arg0 : memref<*xf32>
%alloca = memref.alloca(%0) : memref<?xindex>
return
}

func.func @yield_alloc() -> memref<32x18xf32> {
// CHECK-LABEL: func @yield_alloc() -> memref<32x18xf32>
// CHECK: %[[m0:.*]] = memref.alloc() : memref<32x18xf32>
%c32 = arith.constant 32 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%lastBuffer = memref.alloc() : memref<32x18xf32>
scf.for %arg3 = %c0 to %c32 step %c1 iter_args(%arg1 = %lastBuffer) -> (memref<32x18xf32>) {
%newBuffer = memref.alloc() : memref<32x18xf32>
memref.dealloc %arg1 : memref<32x18xf32>
scf.yield %newBuffer : memref<32x18xf32>
func.func @loop_nested_if_alloca(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<64xf32>) {
// CHECK-LABEL: func @loop_nested_if_alloca(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<64xf32>)
// CHECK: %[[m0:.*]] = cpuruntime.alloc() : memref<64xf32>
%alloca = memref.alloca() : memref<64xf32>
%0 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %arg3) -> (memref<64xf32>) {
%1 = arith.cmpi eq, %arg5, %arg1 : index
%2 = scf.if %1 -> (memref<64xf32>) {
// CHECK: yield %[[m0]] : memref<64xf32>
scf.yield %alloca : memref<64xf32>
} else {
// CHECK: %[[m1:.*]] = memref.alloca() : memref<2xf32>
%alloca_0 = memref.alloca() : memref<2xf32>
scf.yield %arg6 : memref<64xf32>
}
scf.yield %2 : memref<64xf32>
}
return %lastBuffer : memref<32x18xf32>
// CHECK: cpuruntime.dealloc %[[m0]] : memref<64xf32>
return
}

func.func @return_view_alloc() -> memref<16xf32> {
// CHECK-LABEL: func @return_view_alloc() -> memref<16xf32>
// CHECK: %[[m0:.*]] = memref.alloc() : memref<128xi8>
%c0 = arith.constant 0: index
%f0 = arith.constant 0.0: f32
%alloc = memref.alloc() : memref<128xi8>
%view = memref.view %alloc[%c0][] : memref<128xi8> to memref<32xf32>
%subview = memref.subview %view[0][16][1] : memref<32xf32> to memref<16xf32>
return %subview : memref<16xf32>
func.func @alloca_sequence() {
// CHECK-LABEL: func @alloca_sequence()
// CHECK: %[[m0:.*]] = cpuruntime.alloc() : memref<128xf32>
// CHECK: %[[m1:.*]] = cpuruntime.alloc() : memref<128xf32>
// CHECK: %[[m2:.*]] = cpuruntime.alloc() : memref<128xf32>
%alloc = memref.alloca() : memref<128xf32>
%alloc_0 = memref.alloca() : memref<128xf32>
%alloc_1 = memref.alloca() : memref<128xf32>
// CHECK: cpuruntime.dealloc %[[m2]] : memref<128xf32>
// CHECK: cpuruntime.dealloc %[[m1]] : memref<128xf32>
// CHECK: cpuruntime.dealloc %[[m0]] : memref<128xf32>
return
}

func.func @alloc_dealloc_view() {
// CHECK-LABEL: func @alloc_dealloc_view()
// CHECK: %[[m0:.*]] = cpuruntime.alloc() : memref<128xi8>
%c0 = arith.constant 0: index
%f0 = arith.constant 0.0: f32
%alloc = memref.alloc() : memref<128xi8>
%view = memref.view %alloc[%c0][] : memref<128xi8> to memref<32xf32>
%subview = memref.subview %view[0][16][1] : memref<32xf32> to memref<16xf32>
// CHECK: cpuruntime.dealloc
memref.dealloc %subview : memref<16xf32>
func.func @nested_alloca() {
// CHECK-LABEL: func @nested_alloca()
// CHECK: %[[m0:.*]] = cpuruntime.alloc() : memref<512xf32>
// CHECK-NEXT: scf.forall {{.*}} {
// CHECK-NEXT: %[[m1:.*]] = cpuruntime.alloc thread_local() : memref<32xf32>
// CHECK-NEXT: cpuruntime.dealloc thread_local %[[m1]] : memref<32xf32>
// CHECK-NEXT: }
// CHECK-NEXT: scf.forall {{.*}} {
// CHECK-NEXT: %[[m2:.*]] = cpuruntime.alloc thread_local() : memref<64xf32>
// CHECK-NEXT: cpuruntime.dealloc thread_local %[[m2]] : memref<64xf32>
// CHECK-NEXT: }
// CHECK: cpuruntime.dealloc %[[m0]] : memref<512xf32>
%0 = memref.alloca() : memref<512xf32>
scf.forall (%i) in (32) {
%1 = memref.alloca() : memref<32xf32>
}
scf.forall (%i) in (32) {
%1 = memref.alloca() : memref<64xf32>
}
return
}
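
One path the tests above do not exercise is an alloca below the 128-byte threshold, which the pass should leave untouched. A hypothetical test in the same style (illustrative only, not part of this PR; 16 * 4 = 64 bytes < 128):

func.func @small_alloca() {
  // CHECK-LABEL: func @small_alloca()
  // CHECK: memref.alloca() : memref<16xf32>
  // CHECK-NOT: cpuruntime.alloc
  %0 = memref.alloca() : memref<16xf32>
  return
}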