
[Transform] Hoist thread-local allocator within the nested parallel loops #283


Merged · 44 commits · Sep 23, 2024

Commits (44)
45a02a4
[mlir][Memref] Add memref-merge optimization
May 8, 2024
1615878
[tests] Add example MLIR-unittest in lit
May 9, 2024
d1225c0
format
May 9, 2024
d107580
Merge branch 'yijie/unittest' into yijie/mem-merge
May 9, 2024
eaf2667
add test
May 9, 2024
f868629
update doc
May 9, 2024
504c785
doc
May 9, 2024
ef3d150
handle i1
May 10, 2024
21fe5fa
Merge remote-tracking branch 'origin/main' into yijie/mem-merge
May 15, 2024
5db4be4
trigger
May 15, 2024
84a17ac
fix
May 15, 2024
87a7fb6
remove cprt
May 15, 2024
42d612b
Merge branch 'main' of https://github.com/intel/graph-compiler into y…
May 20, 2024
115fd66
update
May 20, 2024
26adb18
fix lint
May 20, 2024
f93f1d2
Merge remote-tracking branch 'origin/main' into yijie/mem-merge
Jun 3, 2024
36354ea
rename
Jun 3, 2024
82ab370
Merge branch 'main' of https://github.com/intel/graph-compiler into y…
Aug 14, 2024
1d3b887
fix
Aug 14, 2024
e285e99
make checker happy
Aug 14, 2024
5990627
fix tidy
Aug 14, 2024
eecc19f
make you happy
Aug 15, 2024
b3541f0
Merge branch 'main' of https://github.com/intel/graph-compiler into y…
Aug 15, 2024
57ddeab
Merge branch 'main' of https://github.com/intel/graph-compiler into y…
Aug 20, 2024
e5a2f83
fix
Aug 20, 2024
66309fc
Merge branch 'main' of https://github.com/intel/graph-compiler into y…
Aug 21, 2024
5953b13
port memref-hoist
ciyongch Aug 21, 2024
ddf69b0
update test cases
ciyongch Aug 21, 2024
8401ca0
update mlir test
ciyongch Aug 26, 2024
43f27a7
refactor and add new cases
ciyongch Aug 26, 2024
75a88aa
fix tidy
ciyongch Aug 26, 2024
a81249d
Merge branch 'main' into ciyong/memref_hoist_v2
ciyongch Sep 9, 2024
8f5325d
Merge branch 'main' into ciyong/memref_hoist_v2
ciyongch Sep 12, 2024
5c8fdf4
restore unchanged code
ciyongch Sep 12, 2024
d2c7cdc
Merge remote-tracking branch 'origin/main' into ciyong/memref_hoist_v2
ciyongch Sep 12, 2024
fdede5b
address comments
ciyongch Sep 13, 2024
4708126
Merge remote-tracking branch 'origin/main' into ciyong/memref_hoist_v2
ciyongch Sep 18, 2024
cc25380
Merge branch 'main' into ciyong/memref_hoist_v2
ciyongch Sep 20, 2024
514664e
address comment
ciyongch Sep 20, 2024
ac8452f
address comment
ciyongch Sep 20, 2024
99a2c34
fix lint
ciyongch Sep 20, 2024
1e80344
support multiple vars within a forall loop
ciyongch Sep 23, 2024
8daedf9
Merge remote-tracking branch 'origin/main' into ciyong/memref_hoist_v2
ciyongch Sep 23, 2024
991e1fc
fix lint
ciyongch Sep 23, 2024
223 changes: 208 additions & 15 deletions lib/gc/Transforms/MergeAllocTickBased.cpp
@@ -30,6 +30,65 @@ using namespace special_ticks;
/// and default memory space.
static bool isMemRefTypeOk(MemRefType type) { return type.hasStaticShape(); }

static inline int64_t getSizeInBytes(MemRefType &memType) {
// treat bool (i1) as 1 byte. It may not be true for all targets, but we at
// least have a large enough size for i1
int64_t size = memType.getElementTypeBitWidth() / 8;
size = (size > 0) ? size : 1;
for (auto v : memType.getShape()) {
size *= v;
}
return size;
}
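
// For reference (hypothetical types, not taken from this PR's tests), sub-byte
// element types such as i1 are rounded up to one byte per element:
//
//   %a = memref.alloc() : memref<8x64xf32>  // 8 * 64 * 4 = 2048 bytes
//   %b = memref.alloc() : memref<4x16xi1>   // i1 counted as 1 byte -> 64 bytes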

static bool needsHoistOutOfParallelLoop(Operation *op) {
Operation *parent =
op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
if (isa_and_nonnull<scf::ForallOp>(parent)) {
// check whether the current allocation sits between nested parallel loops
// and is used inside an inner parallel loop
SmallVector<Operation *, 4> parallelOpInCurBlock;
Block *curBlock = op->getBlock();
for (auto &curOp : curBlock->getOperations()) {
if (isa<scf::ForallOp>(curOp)) {
parallelOpInCurBlock.push_back(&curOp);
}
}

if (parallelOpInCurBlock.empty())
return false;

for (auto *use : op->getUsers()) {
for (auto *parallelOp : parallelOpInCurBlock) {
if (parallelOp->isAncestor(use)) {
return true;
}
}
}
}

return false;
}
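
// For illustration (hypothetical IR, not taken from this PR's lit tests), the
// predicate is expected to return true for an allocation shaped like %buf
// below: it is created between two nested, statically bounded scf.forall
// loops and only used inside the inner one, so each outer thread needs its
// own private copy.
//
//   %cst = arith.constant 0.0 : f32
//   scf.forall (%i) in (4) {
//     %buf = memref.alloc() : memref<8x64xf32>
//     scf.forall (%j) in (8) {
//       memref.store %cst, %buf[%j, %j] : memref<8x64xf32>
//     }
//   }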

static bool isForallLoopBoundStatic(Operation *op) {
auto forallOp = dyn_cast<scf::ForallOp>(op);
if (!forallOp)
return false;

auto lbs = forallOp.getMixedLowerBound();
auto ubs = forallOp.getMixedUpperBound();
auto steps = forallOp.getMixedStep();
auto allConstantValue = [](SmallVector<OpFoldResult> vals) -> bool {
return llvm::all_of(vals, [](OpFoldResult val) {
std::optional<int64_t> const_val = getConstantIntValue(val);
return const_val.has_value();
});
};

return allConstantValue(lbs) && allConstantValue(ubs) &&
allConstantValue(steps);
}

void Tick::update(int64_t tick) {
if (tick == UNTRACEABLE_ACCESS) {
firstAccess = UNTRACEABLE_ACCESS;
@@ -180,28 +239,60 @@ bool TickCollecter::isMergeableAlloc(TickCollecterStates *s, Operation *op,
// trait, and is not scf.for
Operation *TickCollecter::getAllocScope(TickCollecterStates *s,
Operation *op) const {
auto parent = op;
Operation *parent = op;
bool moveToUpperParellelLoop = needsHoistOutOfParallelLoop(op);

for (;;) {
parent = parent->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
if (!parent) {
return nullptr;
}
if (!isa<scf::ForOp>(parent)) {
return parent;
}

if (isa<scf::ForOp>(parent))
continue;

if (isa<scf::ForallOp>(parent) &&
(moveToUpperParellelLoop && isForallLoopBoundStatic(parent)))
continue;

return parent;
}
}

FailureOr<size_t> TickCollecter::getAllocSize(TickCollecterStates *s,
Operation *op) const {
auto refType = cast<MemRefType>(op->getResultTypes().front());
int64_t size = refType.getElementTypeBitWidth() / 8;
// treat bool (i1) as 1 byte. It may not be true for all targets, but we at
// least have a large enough size for i1
size = (size != 0) ? size : 1;
for (auto v : refType.getShape()) {
size *= v;

// Get the total number of threads from the outermost parallel loop down to
// the current level of the parallel loop that the allocation is located in.
int64_t numThreads = 1;
if (needsHoistOutOfParallelLoop(op)) {
Operation *parent =
op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
while (auto forallOp = dyn_cast<scf::ForallOp>(parent)) {
if (!isForallLoopBoundStatic(forallOp))
break;

OpBuilder builder{forallOp->getContext()};
std::optional<int64_t> numIterations;
for (auto [lb, ub, step] : llvm::zip(forallOp.getLowerBound(builder),
forallOp.getUpperBound(builder),
forallOp.getStep(builder))) {
numIterations = constantTripCount(lb, ub, step);
if (numIterations.has_value()) {
numThreads *= numIterations.value();
} else {
return op->emitError("Expecting static loop range!");
}
}

parent = parent->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
}
}
assert(numThreads > 0);

int64_t size = getSizeInBytes(refType);
size *= numThreads;
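// Worked example (hypothetical sizes, not from this PR's tests): a
// memref<8x64xf32> allocation needs 8 * 64 * 4 = 2048 bytes per thread; under
// statically bounded forall levels of 4 and 2 threads, numThreads = 4 * 2 = 8
// and getAllocSize reports 8 * 2048 = 16384 bytes, i.e. one private slot per
// thread in the merged buffer. A non-constant bound on any enclosing
// scf.forall aborts with the "Expecting static loop range!" error instead.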
if (size > 0) {
return static_cast<size_t>(size);
}
@@ -391,11 +482,113 @@ Value MergeAllocDefaultMutator::buildView(OpBuilder &builder, Block *scope,
Value mergedAlloc,
int64_t byteOffset) const {
builder.setInsertionPoint(origAllocOp);
auto byteShift =
builder.create<arith::ConstantIndexOp>(origAllocOp->getLoc(), byteOffset);
return builder.create<memref::ViewOp>(origAllocOp->getLoc(),
origAllocOp->getResultTypes().front(),
mergedAlloc, byteShift, ValueRange{});
auto loc = origAllocOp->getLoc();
auto byteShift = builder.create<arith::ConstantIndexOp>(loc, byteOffset);

bool moveToUpperParellelLoop = needsHoistOutOfParallelLoop(origAllocOp);
Operation *parent =
origAllocOp->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
if (!moveToUpperParellelLoop || !parent || !isa<scf::ForallOp>(parent))
return builder.create<memref::ViewOp>(loc,
origAllocOp->getResultTypes().front(),
mergedAlloc, byteShift, ValueRange{});

// compute the aggregated (linearized) induction variable across the forall nest
Value inductVar;
bool isOuterMostLoop = true;
int64_t innerLoopUpperBound = 1;
while (parent) {
if (auto forallOp = dyn_cast<scf::ForallOp>(parent)) {
if (isForallLoopBoundStatic(forallOp)) {
SmallVector<Value> ubs = forallOp.getUpperBound(builder);
SmallVector<Value> lbs = forallOp.getLowerBound(builder);
SmallVector<Value> steps = forallOp.getStep(builder);
SmallVector<Value> inductionVars = forallOp.getInductionVars();

auto getCurrentVar = [&loc, &builder](Value var, Value lb,
Value step) -> Value {
if (!isConstantIntValue(lb, 0))
var = builder.create<arith::SubIOp>(loc, var, lb);

if (!isConstantIntValue(step, 1))
var = builder.create<arith::DivSIOp>(loc, var, step);
return var;
};

auto getAggregatedVar =
[&loc, &builder, &getCurrentVar](
const SmallVector<Value> &_lbs, const SmallVector<Value> &_ubs,
const SmallVector<Value> &_steps,
const SmallVector<Value> &_inductVars) -> Value {
Value var;
if (_ubs.size() == 1) {
var = getCurrentVar(_inductVars[0], _lbs[0], _steps[0]);
return var;
} else {
bool isFirstLoop = true;
for (auto [lb, ub, step, inductVar] :
llvm::zip(_lbs, _ubs, _steps, _inductVars)) {
if (isFirstLoop) {
var = getCurrentVar(inductVar, lb, step);
isFirstLoop = false;
} else {
Value cur_var = getCurrentVar(inductVar, lb, step);
std::optional<int64_t> bound = constantTripCount(lb, ub, step);
assert(bound.has_value());
Value boundVal =
builder.create<arith::ConstantIndexOp>(loc, bound.value());
Value tmpVal =
builder.create<arith::MulIOp>(loc, var, boundVal);
var = builder.create<arith::AddIOp>(loc, tmpVal, cur_var);
}
}
return var;
}
};

if (isOuterMostLoop) {
inductVar = getAggregatedVar(lbs, ubs, steps, inductionVars);
isOuterMostLoop = false;
} else {
Value currentVar = getAggregatedVar(lbs, ubs, steps, inductionVars);

Value innerLoopBoundVal =
builder.create<arith::ConstantIndexOp>(loc, innerLoopUpperBound);
Value intermediateVal =
builder.create<arith::MulIOp>(loc, currentVar, innerLoopBoundVal);
inductVar =
builder.create<arith::AddIOp>(loc, inductVar, intermediateVal);
}
// accumulate the aggregated loop bound (total trip count) of this loop level
for (auto [lb, ub, step] : llvm::zip(lbs, ubs, steps)) {
std::optional<int64_t> cur_bound = constantTripCount(lb, ub, step);
assert(cur_bound.has_value());
innerLoopUpperBound *= cur_bound.value();
}
}
}

parent = parent->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
}

if (!isOuterMostLoop) {
// get the per-thread size in bytes of the original allocation
auto memType = cast<MemRefType>(origAllocOp->getResultTypes().front());
int64_t size = getSizeInBytes(memType);
Value origSize = builder.create<arith::ConstantIndexOp>(loc, size);
Value offsetPerThread =
builder.create<arith::MulIOp>(loc, inductVar, origSize);
Value byteShiftPerThread =
builder.create<arith::AddIOp>(loc, byteShift, offsetPerThread);

return builder.create<memref::ViewOp>(
loc, origAllocOp->getResultTypes().front(), mergedAlloc,
byteShiftPerThread, ValueRange{});
} else {
return builder.create<memref::ViewOp>(loc,
origAllocOp->getResultTypes().front(),
mergedAlloc, byteShift, ValueRange{});
}
}

LogicalResult
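For orientation, the following is a hand-written sketch (not one of the PR's lit tests; the function and value names are invented) of the IR shape buildView aims to produce when an allocation is hoisted out of a statically bounded scf.forall nest: the merged buffer is sized for all threads, the enclosing induction variables are linearized into a thread id, and each thread's memref.view is shifted by byteOffset + thread_id * per_thread_size. Non-zero lower bounds and non-unit steps would additionally be normalized with arith.subi / arith.divsi before the linearization.

func.func @after_hoist_sketch() {
  %cst = arith.constant 0.0 : f32
  // one merged allocation hoisted above the parallel loops,
  // sized for all 4 * 2 = 8 threads (8 * 2048 bytes)
  %merged = memref.alloc() : memref<16384xi8>
  scf.forall (%i, %j) in (4, 2) {
    // linearized induction variable: %i * 2 + %j
    %c2 = arith.constant 2 : index
    %c2048 = arith.constant 2048 : index
    %t0 = arith.muli %i, %c2 : index
    %tid = arith.addi %t0, %j : index
    // per-thread byte shift: byteOffset (0 for this alloc) + tid * 2048
    %shift = arith.muli %tid, %c2048 : index
    %view = memref.view %merged[%shift][] : memref<16384xi8> to memref<8x64xf32>
    scf.forall (%k) in (8) {
      memref.store %cst, %view[%k, %k] : memref<8x64xf32>
    }
  }
  return
}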
2 changes: 1 addition & 1 deletion test/mlir/test/gc/Transforms/buffer-merge-lifetime.mlir
@@ -113,7 +113,7 @@ func.func @alias_ref(%pred : i1) {
// CHECK-DAG: func.func @escape_from_if() attributes {__mergealloc_scope = [[TOPSCOPE5:[0-9]+]]
func.func @escape_from_if() {
%ctrue = arith.constant 1 : i1
// check that f lives at the whole range of the following scf.if
// check that f lives at the whole range of the following scf.if
// CHECK-DAG: %[[F:.*]] = memref.alloc() {__mergealloc_lifetime = array<i64: [[TOPSCOPE5]], 4, 13>}
%f = memref.alloc() : memref<8x64xf32>
// tick of the scf.if starts from 4 and ends at 14