
Commit d72b1f1

Ivy Zhang and zhczhong authored
[Transform] Only use gc runtime allocator for stack-like alloca ops (#287)
* fallback memref transformation when alloc / dealloc not in FILO fashion
* align gc bufferization pipeline
* only convert stack-style alloca to gc runtime allocator, with dealloc inserted
* add test

---------

Co-authored-by: Zhong, Zhicong <zhicong.zhong@intel.com>
1 parent bc7f33b commit d72b1f1
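In effect, after this change only stack-like memref.alloca ops of at least 128 bytes are lowered to cpuruntime.alloc, and the matching cpuruntime.dealloc ops are inserted before the scope terminator in reverse order, so frees follow the first-in/last-out (FILO) discipline of a stack. A minimal standalone C++ sketch of that ordering idea (illustrative names only, not code from this commit):

#include <cstdio>
#include <string>
#include <vector>

int main() {
  // Allocas in the order they appear in the function body.
  std::vector<std::string> allocas = {"%m0", "%m1", "%m2"};

  for (const std::string &a : allocas)
    std::printf("cpuruntime.alloc   -> %s\n", a.c_str());

  // Iterate in reverse so the last alloca is freed first (FILO),
  // matching the dealloc order checked in the alloca_sequence test below.
  for (auto it = allocas.rbegin(); it != allocas.rend(); ++it)
    std::printf("cpuruntime.dealloc -> %s\n", it->c_str());
  return 0;
}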

File tree

3 files changed: +147 -93 lines changed


lib/gc/Transforms/MemRefToCPURuntime.cpp

Lines changed: 75 additions & 44 deletions
@@ -28,6 +28,8 @@ namespace gc {
 
 namespace {
 
+constexpr uint64_t STACK_ALLOC_THRESHOLD = 128;
+
 bool hasParallelParent(Operation *op) {
   // Check if the parent contains a forall / parallel loop
   for (Operation *parentOp = op->getParentOp(); parentOp != nullptr;
@@ -38,9 +40,38 @@ bool hasParallelParent(Operation *op) {
   }
   return false;
 }
-struct AlignedAllocLowering : public OpRewritePattern<memref::AllocOp> {
-  using OpRewritePattern<memref::AllocOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(memref::AllocOp op,
+
+uint64_t getMemRefSizeInBytes(MemRefType memrefType) {
+  if (ShapedType::isDynamicShape(memrefType.getShape()))
+    return UINT64_MAX;
+  ShapedType shapeType = cast<ShapedType>(memrefType);
+  int elementSize = shapeType.getElementTypeBitWidth() / 8;
+  AffineMap layout = memrefType.getLayout().getAffineMap();
+  ArrayRef<int64_t> shape = memrefType.getShape();
+  if (!layout.isIdentity()) {
+    int64_t offset;
+    SmallVector<int64_t, 4> strides;
+    if (failed(getStridesAndOffset(memrefType, strides, offset))) {
+      return UINT64_MAX;
+    }
+
+    int totalSize = elementSize;
+    for (size_t i = 0; i < shape.size(); ++i) {
+      totalSize *= (i == shape.size() - 1) ? strides[i] : shape[i];
+    }
+    return totalSize;
+  } else {
+    int totalSize = elementSize;
+    for (int64_t dim : shape) {
+      totalSize *= dim;
+    }
+    return totalSize;
+  }
+}
+
+struct AlignedAllocLowering : public OpRewritePattern<memref::AllocaOp> {
+  using OpRewritePattern<memref::AllocaOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(memref::AllocaOp op,
                                 PatternRewriter &rewriter) const final {
     auto loc = op->getLoc();
     MemRefType type = op.getMemref().getType();
@@ -54,66 +85,66 @@ struct AlignedAllocLowering : public OpRewritePattern<memref::AllocOp> {
     return success();
   }
 };
-
-struct AlignedDeallocLowering : public OpRewritePattern<memref::DeallocOp> {
-  using OpRewritePattern<memref::DeallocOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(memref::DeallocOp op,
-                                PatternRewriter &rewriter) const final {
-    auto loc = op->getLoc();
-    Value memref = op.getMemref();
-    cpuruntime::DeallocOp newDeallocOp =
-        rewriter.create<cpuruntime::DeallocOp>(loc, memref);
-    if (hasParallelParent(op))
-      newDeallocOp.setThreadLocal(true);
-    rewriter.eraseOp(op);
-    return success();
-  }
-};
-
 struct ConvertMemRefToCPURuntime
     : public impl::ConvertMemRefToCPURuntimeBase<ConvertMemRefToCPURuntime> {
 
   void runOnOperation() final {
     auto *ctx = &getContext();
-    // Create a local set to store operations that should not be transformed.
    llvm::SmallSet<Operation *, 16> noTransformOps;
 
-    // Walk through the module to find func::FuncOp instances.
+    // Create deallocOp corresponding to the alloca's location
    getOperation()->walk([&](func::FuncOp funcOp) {
-      BufferViewFlowAnalysis analysis(funcOp);
-      // Now walk through the operations within the func::FuncOp.
-      funcOp.walk([&](Operation *op) {
-        if (op->hasTrait<OpTrait::ReturnLike>()) {
-          for (Value operand : op->getOperands()) {
-            if (isa<MemRefType>(operand.getType())) {
-              auto aliases = analysis.resolveReverse(operand);
-              // Check if any of the returned memref is allocated within scope.
-              for (auto &&alias : aliases) {
-                if (Operation *allocOp =
-                        alias.getDefiningOp<memref::AllocOp>()) {
-                  noTransformOps.insert(allocOp);
-                }
-              }
-            }
-          }
+      // Vector to store alloca operations
+      SmallVector<memref::AllocaOp, 16> allocaOps;
+      // Collect all alloca operations
+      funcOp.walk([&](memref::AllocaOp allocaOp) {
+        uint64_t allocSize =
+            getMemRefSizeInBytes(allocaOp.getResult().getType());
+        if (allocSize < STACK_ALLOC_THRESHOLD) {
+          noTransformOps.insert(allocaOp);
+          return;
        }
+        allocaOps.push_back(allocaOp);
      });
+
+      // Create dealloc operations in reverse order of alloca operations
+      for (auto allocaOp = allocaOps.rbegin(); allocaOp != allocaOps.rend();
+           ++allocaOp) {
+        Operation *scopeOp =
+            (*allocaOp)
+                ->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
+        OpBuilder builder(*allocaOp);
+        Region &scopeRegion = scopeOp->getRegion(0);
+        // Set the insertion point to the end of the region before the
+        // terminator
+        Block &lastBlock = scopeRegion.back();
+        builder.setInsertionPointToEnd(&lastBlock);
+        if (!lastBlock.empty() &&
+            lastBlock.back().hasTrait<OpTrait::IsTerminator>()) {
+          builder.setInsertionPoint(&lastBlock.back());
+        }
+
+        // Create the dealloc operation
+        auto deallocOp = builder.create<cpuruntime::DeallocOp>(
+            (*allocaOp).getLoc(), (*allocaOp).getResult());
+        if (hasParallelParent(*allocaOp)) {
+          deallocOp.setThreadLocal(true);
+        }
+      }
    });
 
    // add lowering target
    ConversionTarget target(getContext());
    // Make all operations legal by default.
    target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
-    target.addDynamicallyLegalOp<memref::AllocOp, memref::DeallocOp>(
-        [&](Operation *op) {
-          // Return true if the operation is in the noTransformOps set, making
-          // it dynamically legal.
-          return noTransformOps.find(op) != noTransformOps.end();
-        });
+    target.addDynamicallyLegalOp<memref::AllocaOp>([&](Operation *op) {
+      // Return true if the operation is in the noTransformOps set, making
+      // it dynamically legal.
+      return noTransformOps.find(op) != noTransformOps.end();
+    });
    // set pattern
    RewritePatternSet patterns(ctx);
    patterns.add<AlignedAllocLowering>(ctx);
-    patterns.add<AlignedDeallocLowering>(ctx);
    // perform conversion
    if (failed(
            applyFullConversion(getOperation(), target, std::move(patterns)))) {
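For reference, the classification above reduces to a byte-size check against STACK_ALLOC_THRESHOLD. A minimal standalone C++ sketch of the identity-layout case, with illustrative helper names (identityLayoutSizeInBytes and staysOnStack are not from this repository):

#include <cstdint>
#include <cstdio>
#include <vector>

constexpr uint64_t kStackAllocThreshold = 128; // mirrors STACK_ALLOC_THRESHOLD

// Element size in bytes times all dims, as getMemRefSizeInBytes computes
// for identity layouts.
uint64_t identityLayoutSizeInBytes(uint64_t elementBits,
                                   const std::vector<int64_t> &shape) {
  uint64_t bytes = elementBits / 8;
  for (int64_t dim : shape)
    bytes *= dim;
  return bytes;
}

// Allocas below the threshold keep memref.alloca; larger ones are lowered to
// cpuruntime.alloc with a matching dealloc inserted before the terminator.
bool staysOnStack(uint64_t elementBits, const std::vector<int64_t> &shape) {
  return identityLayoutSizeInBytes(elementBits, shape) < kStackAllocThreshold;
}

int main() {
  std::printf("memref<2xf32>    -> %s\n", staysOnStack(32, {2}) ? "stack" : "runtime");
  std::printf("memref<1024xf32> -> %s\n", staysOnStack(32, {1024}) ? "stack" : "runtime");
  return 0;
}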

lib/gc/Transforms/Pipeline.cpp

Lines changed: 5 additions & 3 deletions
@@ -24,6 +24,7 @@
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/Passes.h"
+#include <climits>
 
 #include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h"
 #include "gc/Dialect/Linalgx/LinalgxDialect.h"
@@ -110,9 +111,10 @@ void populateBufferizationPasses(mlir::OpPassManager &pm) {
   opt.hoistStaticAllocs = true;
   pm.addPass(bufferization::createBufferResultsToOutParamsPass(opt));
   pm.addPass(bufferization::createDropEquivalentBufferResultsPass());
-  pm.addNestedPass<func::FuncOp>(
-      bufferization::createPromoteBuffersToStackPass());
-  bufferization::BufferDeallocationPipelineOptions deallocOption;
+  pm.addNestedPass<func::FuncOp>(bufferization::createPromoteBuffersToStackPass(
+      /*maxAllocSizeInBytes*/ UINT_MAX,
+      /*maxRankOfAllocatedMemRef*/ 8));
+  mlir::bufferization::BufferDeallocationPipelineOptions deallocOption;
   bufferization::buildBufferDeallocationPipeline(pm, deallocOption);
   pm.addPass(createBufferizationToMemRefPass());
   populateCleanUpPasses(pm);
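The promote-to-stack step now passes explicit options instead of the defaults: every eligible heap alloc becomes an alloca, and the later MemRefToCPURuntime pass decides which allocas truly stay on the stack. A hedged sketch of the same call in isolation, assuming the upstream MLIR bufferization API (buildExamplePipeline is an illustrative name, not a function in this repository):

#include <climits>

#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"

// Promote heap allocs of any size (UINT_MAX) and rank up to 8 to allocas;
// MemRefToCPURuntime then keeps only those below STACK_ALLOC_THRESHOLD bytes
// as real stack allocations.
void buildExamplePipeline(mlir::OpPassManager &pm) {
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::bufferization::createPromoteBuffersToStackPass(
          /*maxAllocSizeInBytes=*/UINT_MAX,
          /*maxRankOfAllocatedMemRef=*/8));
}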
Lines changed: 67 additions & 46 deletions
@@ -1,68 +1,89 @@
 // RUN: gc-opt --split-input-file --convert-memref-to-cpuruntime %s -verify-diagnostics | FileCheck %s
-func.func @alloc() {
-  // CHECK-LABEL: func @alloc()
+
+func.func @alloca() {
+  // CHECK-LABEL: func @alloca()
   // CHECK: %[[m0:.*]] = cpuruntime.alloc() : memref<1024xf32>
-  %m0 = memref.alloc() : memref<1024xf32>
+  %m0 = memref.alloca() : memref<1024xf32>
   scf.forall (%i) in (32) {
   }
   // CHECK: cpuruntime.dealloc %[[m0]] : memref<1024xf32>
-  cpuruntime.dealloc %m0 : memref<1024xf32>
   return
 }
 
-func.func @thread_alloc() {
-  // CHECK-LABEL: func.func @thread_alloc()
-  // CHECK: %[[m0:.*]] = cpuruntime.alloc thread_local() : memref<1024xf32>
+func.func @thread_alloca() {
+  // CHECK-LABEL: func.func @thread_alloca()
+  // CHECK-NEXT: scf.forall {{.*}} {
+  // CHECK-NEXT: %[[m0:.*]] = cpuruntime.alloc thread_local() : memref<1024xf32>
+  // CHECK-NEXT: cpuruntime.dealloc thread_local %[[m0]] : memref<1024xf32>
+  // CHECK-NEXT: }
   scf.forall (%i) in (32) {
-    %0 = memref.alloc() : memref<1024xf32>
-    // CHECK: cpuruntime.dealloc thread_local %[[m0]] : memref<1024xf32>
-    memref.dealloc %0 : memref<1024xf32>
+    %0 = memref.alloca() : memref<1024xf32>
  }
  return
 }
 
-func.func @return_alloc() -> memref<32x18xf32> {
-  // CHECK-LABEL: func @return_alloc() -> memref<32x18xf32>
-  // CHECK: %[[m0:.*]] = memref.alloc() : memref<32x18xf32>
-  %0 = memref.alloc() : memref<32x18xf32>
-  return %0 : memref<32x18xf32>
+func.func @dynamic_ranked_alloca(%arg0: memref<*xf32>) {
+  // CHECK-LABEL: func @dynamic_ranked_alloca(%arg0: memref<*xf32>)
+  // CHECK: %[[RANK:.*]] = memref.rank %{{.*}} : memref<*xf32>
+  // CHECK: %[[m0:.*]] = cpuruntime.alloc(%[[RANK]]) : memref<?xindex>
+  // CHECK: cpuruntime.dealloc %[[m0]] : memref<?xindex>
+  %0 = memref.rank %arg0 : memref<*xf32>
+  %alloca = memref.alloca(%0) : memref<?xindex>
+  return
 }
 
-func.func @yield_alloc() -> memref<32x18xf32> {
-  // CHECK-LABEL: func @yield_alloc() -> memref<32x18xf32>
-  // CHECK: %[[m0:.*]] = memref.alloc() : memref<32x18xf32>
-  %c32 = arith.constant 32 : index
-  %c1 = arith.constant 1 : index
-  %c0 = arith.constant 0 : index
-  %lastBuffer = memref.alloc() : memref<32x18xf32>
-  scf.for %arg3 = %c0 to %c32 step %c1 iter_args(%arg1 = %lastBuffer) -> (memref<32x18xf32>) {
-    %newBuffer = memref.alloc() : memref<32x18xf32>
-    memref.dealloc %arg1 : memref<32x18xf32>
-    scf.yield %newBuffer : memref<32x18xf32>
+func.func @loop_nested_if_alloca(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<64xf32>) {
+  // CHECK-LABEL: func @loop_nested_if_alloca(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<64xf32>)
+  // CHECK: %[[m0:.*]] = cpuruntime.alloc() : memref<64xf32>
+  %alloca = memref.alloca() : memref<64xf32>
+  %0 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %arg3) -> (memref<64xf32>) {
+    %1 = arith.cmpi eq, %arg5, %arg1 : index
+    %2 = scf.if %1 -> (memref<64xf32>) {
+      // CHECK: yield %[[m0]] : memref<64xf32>
+      scf.yield %alloca : memref<64xf32>
+    } else {
+      // CHECK: %[[m1:.*]] = memref.alloca() : memref<2xf32>
+      %alloca_0 = memref.alloca() : memref<2xf32>
+      scf.yield %arg6 : memref<64xf32>
+    }
+    scf.yield %2 : memref<64xf32>
  }
-  return %lastBuffer : memref<32x18xf32>
+  // CHECK: cpuruntime.dealloc %[[m0]] : memref<64xf32>
+  return
 }
 
-func.func @return_view_alloc() -> memref<16xf32> {
-  // CHECK-LABEL: func @return_view_alloc() -> memref<16xf32>
-  // CHECK: %[[m0:.*]] = memref.alloc() : memref<128xi8>
-  %c0 = arith.constant 0: index
-  %f0 = arith.constant 0.0: f32
-  %alloc = memref.alloc() : memref<128xi8>
-  %view = memref.view %alloc[%c0][] : memref<128xi8> to memref<32xf32>
-  %subview = memref.subview %view[0][16][1] : memref<32xf32> to memref<16xf32>
-  return %subview : memref<16xf32>
+func.func @alloca_sequence() {
+  // CHECK-LABEL: func @alloca_sequence()
+  // CHECK: %[[m0:.*]] = cpuruntime.alloc() : memref<128xf32>
+  // CHECK: %[[m1:.*]] = cpuruntime.alloc() : memref<128xf32>
+  // CHECK: %[[m2:.*]] = cpuruntime.alloc() : memref<128xf32>
+  %alloc = memref.alloca() : memref<128xf32>
+  %alloc_0 = memref.alloca() : memref<128xf32>
+  %alloc_1 = memref.alloca() : memref<128xf32>
+  // CHECK: cpuruntime.dealloc %[[m2]] : memref<128xf32>
+  // CHECK: cpuruntime.dealloc %[[m1]] : memref<128xf32>
+  // CHECK: cpuruntime.dealloc %[[m0]] : memref<128xf32>
+  return
 }
 
-func.func @alloc_dealloc_view() {
-  // CHECK-LABEL: func @alloc_dealloc_view()
-  // CHECK: %[[m0:.*]] = cpuruntime.alloc() : memref<128xi8>
-  %c0 = arith.constant 0: index
-  %f0 = arith.constant 0.0: f32
-  %alloc = memref.alloc() : memref<128xi8>
-  %view = memref.view %alloc[%c0][] : memref<128xi8> to memref<32xf32>
-  %subview = memref.subview %view[0][16][1] : memref<32xf32> to memref<16xf32>
-  // CHECK: cpuruntime.dealloc
-  memref.dealloc %subview : memref<16xf32>
+func.func @nested_alloca() {
+  // CHECK-LABEL: func @nested_alloca()
+  // CHECK: %[[m0:.*]] = cpuruntime.alloc() : memref<512xf32>
+  // CHECK-NEXT: scf.forall {{.*}} {
+  // CHECK-NEXT: %[[m1:.*]] = cpuruntime.alloc thread_local() : memref<32xf32>
+  // CHECK-NEXT: cpuruntime.dealloc thread_local %[[m1]] : memref<32xf32>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: scf.forall {{.*}} {
+  // CHECK-NEXT: %[[m2:.*]] = cpuruntime.alloc thread_local() : memref<64xf32>
+  // CHECK-NEXT: cpuruntime.dealloc thread_local %[[m2]] : memref<64xf32>
+  // CHECK-NEXT: }
+  // CHECK: cpuruntime.dealloc %[[m0]] : memref<512xf32>
+  %0 = memref.alloca() : memref<512xf32>
+  scf.forall (%i) in (32) {
+    %1 = memref.alloca() : memref<32xf32>
+  }
+  scf.forall (%i) in (32) {
+    %1 = memref.alloca() : memref<64xf32>
+  }
  return
 }
