
Commit acf3ae8

Support complex topo
1 parent 4363915 commit acf3ae8

File tree

1 file changed: +98 -63 lines changed

lib/gc/Transforms/CST.cpp

Lines changed: 98 additions & 63 deletions
@@ -30,7 +30,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/Debug.h"

-#include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp"
+// #include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp"

 namespace mlir {
 namespace gc {
@@ -300,12 +300,12 @@ static constexpr int DATA_SIZE_EXPANDING_THRESHOLD = 8;
 // void *allocator(size_t size) { return std::aligned_alloc(64, size); }
 // void deallocator(void *ptr) { std::free(ptr); }

-std::shared_ptr<ConstCacheProxy> createConstCacheProxy(size_t size) {
-  // simply allocate buffer and return
-  std::shared_ptr<void> base = std::shared_ptr<void>{
-      std::aligned_alloc(64, size), [](void *p) { std::free(p); }};
-  return std::make_shared<ConstCacheProxy>(base, base.get(), size, true);
-}
+// std::shared_ptr<ConstCacheProxy> createConstCacheProxy(size_t size) {
+//   // simply allocate buffer and return
+//   std::shared_ptr<void> base = std::shared_ptr<void>{
+//       std::aligned_alloc(64, size), [](void *p) { std::free(p); }};
+//   return std::make_shared<ConstCacheProxy>(base, base.get(), size, true);
+// }

 size_t divideAndCeil(size_t x, size_t y) { return (x + y - 1) / y; }
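Aside: the now-disabled createConstCacheProxy wraps a 64-byte-aligned allocation in a std::shared_ptr<void> with a custom deleter, so the buffer is freed when the last owner goes away. A minimal C++17 sketch of just that ownership pattern (the size is made up; std::aligned_alloc requires it to be a multiple of the alignment):

#include <cstdlib>
#include <memory>

int main() {
  std::size_t size = 256; // must be a multiple of the 64-byte alignment
  // The custom deleter releases the aligned buffer with std::free.
  std::shared_ptr<void> base{std::aligned_alloc(64, size),
                             [](void *p) { std::free(p); }};
  std::shared_ptr<void> alias = base; // copies share ownership of the buffer
  return alias.get() == base.get() ? 0 : 1;
}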

@@ -329,12 +329,12 @@ struct constGraphTensorCacheManager {
       totalSize += divideAndCeil(buffersSize[i], 64) * 64;
     }
     llvm::dbgs() << "Alloc total size: " << totalSize << '\n';
-    auto base = createConstCacheProxy(totalSize);
+    // auto base = createConstCacheProxy(totalSize);
     std::vector<uint64_t> globalIds(buffersSize.size());
     size_t offset = 0;
     for (size_t i = 0; i < buffersSize.size(); i++) {
       llvm::dbgs() << "Alloc offset: " << offset << '\n';
-      regCachedTensor(cachedTensorGlobalId, base, offset);
+      // regCachedTensor(cachedTensorGlobalId, base, offset);
       globalIds[i] = cachedTensorGlobalId;
       ++cachedTensorGlobalId;
       offset += divideAndCeil(buffersSize[i], 64) * 64;
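The allocation loop above is plain alignment arithmetic: every buffer is rounded up to a 64-byte boundary and offsets accumulate accordingly. A standalone sketch of the same computation (buffer sizes invented for illustration):

#include <cstddef>
#include <cstdio>
#include <vector>

// Same helper as in the pass: ceil(x / y), used to round sizes up.
static size_t divideAndCeil(size_t x, size_t y) { return (x + y - 1) / y; }

int main() {
  std::vector<size_t> buffersSize = {100, 64, 1};
  size_t offset = 0;
  for (size_t s : buffersSize) {
    std::printf("buffer of %zu bytes at offset %zu\n", s, offset);
    offset += divideAndCeil(s, 64) * 64; // next buffer stays 64-byte aligned
  }
  std::printf("total size: %zu\n", offset); // 128 + 64 + 64 = 256
}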
@@ -431,11 +431,11 @@ void CST::runOnOperation() {
   // values of folded constant weights in original block
   SmallVector<Value> outputValues;
   Value v;
-  // TODO: solve complicated topology. Currently we only handle simple topology
-  // where one constant weight input will and only will produce one constant
-  // output and each constant weight only contributes to one constant output.
+  // Support complicated topology.
   for (size_t id = 0; id < block.getNumArguments(); ++id) {
     if (constArgsIndexes.count(id) == 1) {
+      // The constant ops are all single-input single-output.
+      bool simpleTopo = true;
       auto arg = block.getArgument(id);
       if (!isa<TensorType>(arg.getType())) {
         continue;
@@ -444,54 +444,72 @@ void CST::runOnOperation() {
       v = dyn_cast<Value>(arg);
       inputValues.push_back(v);
       SmallVector<Value> valuesOnTheWay = {v}; // the constant tensors
+      std::deque<Value> dq;
+      dq.push_back(v);
       // For v -> pack1 -> pack2 -> matmul, we need the type of output of pack2
-      while (!v.getUsers().empty()) {
-        // v.getUsers().size() should be 1
-        Operation *user = *(v.getUsers().begin());
-        // If user is not const or user has multiple operands, we reach the end
-        if (!isInConstantSubgraph(user) || !singleOperand(user)) {
-          outputTypes.push_back(v.getType());
-          outputValues.push_back(v);
-          break;
+      while (!dq.empty()) {
+        v = dq.front();
+        dq.pop_front();
+        // If the children ops of v are not all constant, we end at v.
+        if (std::any_of(v.getUsers().begin(), v.getUsers().end(),
+                        [](Operation *child) {
+                          return !isInConstantSubgraph(child);
+                        })) {
+          if (std::find(outputValues.begin(), outputValues.end(), v) ==
+              outputValues.end()) {
+            outputTypes.push_back(v.getType());
+            outputValues.push_back(v);
+          }
+          continue;
+        }
+        if (!v.hasOneUse()) {
+          simpleTopo = false;
+        }
+        // The children ops of v are all constant; push their results to the
+        // queue.
+        for (Operation *child : v.getUsers()) {
+          if (!singleOperand(child) || child->getResults().size() > 1) {
+            simpleTopo = false;
+          }
+          for (OpResult result : child->getResults()) {
+            auto r = dyn_cast<Value>(result);
+            dq.push_back(r);
+            valuesOnTheWay.push_back(r);
+          }
         }
-        // user should have only 1 output value
-        OpResult result = *(user->result_begin());
-        v = dyn_cast<Value>(result);
-        valuesOnTheWay.push_back(v);
       }

       // If the data size of outputValue is much greater than the size of
       // inputValue, do not fold it. Compare data size changes during the
       // traversal to find the last op that satisfies this condition.
-      int64_t initSize =
-          getTensorSize(dyn_cast<TensorType>(valuesOnTheWay[0].getType()));
-      if (!isa<TensorType>(outputTypes.back()) ||
-          initSize * DATA_SIZE_EXPANDING_THRESHOLD <
-              getTensorSize(dyn_cast<TensorType>(outputTypes.back()))) {
-        size_t lastIdx = 0;
-        for (size_t i = 1; i < valuesOnTheWay.size(); ++i) {
-          int64_t size =
-              getTensorSize(dyn_cast<TensorType>(valuesOnTheWay[i].getType()));
-          if (initSize * DATA_SIZE_EXPANDING_THRESHOLD > size) {
-            lastIdx = i;
+      if (simpleTopo) {
+        int64_t initSize =
+            getTensorSize(dyn_cast<TensorType>(valuesOnTheWay[0].getType()));
+        if (!isa<TensorType>(outputTypes.back()) ||
+            initSize * DATA_SIZE_EXPANDING_THRESHOLD <
+                getTensorSize(dyn_cast<TensorType>(outputTypes.back()))) {
+          size_t lastIdx = 0;
+          for (size_t i = 1; i < valuesOnTheWay.size(); ++i) {
+            int64_t size = getTensorSize(
+                dyn_cast<TensorType>(valuesOnTheWay[i].getType()));
+            if (initSize * DATA_SIZE_EXPANDING_THRESHOLD > size) {
+              lastIdx = i;
+            }
+          }
+          if (lastIdx == 0) { // no suitable value found
+            inputTypes.pop_back();
+            outputTypes.pop_back();
+            inputValues.pop_back();
+            outputValues.pop_back();
+            constArgsIndexes.erase(id);
+          } else {
+            outputTypes.back() = valuesOnTheWay[lastIdx].getType();
+            outputValues.back() = valuesOnTheWay[lastIdx];
           }
-        }
-        if (lastIdx == 0) { // no suitable value found
-          inputTypes.pop_back();
-          outputTypes.pop_back();
-          inputValues.pop_back();
-          outputValues.pop_back();
-          constArgsIndexes.erase(id);
-        } else {
-          outputTypes.back() = valuesOnTheWay[lastIdx].getType();
-          outputValues.back() = valuesOnTheWay[lastIdx];
         }
       }
     }
   }
-  if (inputTypes.size() != outputTypes.size()) {
-    return;
-  }

   FunctionType foldFuncType =
       FunctionType::get(context, inputTypes, outputTypes);
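The core of this change: the old code walked a single-user chain (while (!v.getUsers().empty()), assuming exactly one user per value), while the new code does a breadth-first traversal over a std::deque worklist, recording an output wherever a non-constant user appears and clearing simpleTopo on any fan-out or multi-result op. A self-contained sketch of that worklist pattern on a toy graph (the Node type is illustrative; the pass traverses MLIR Values and additionally de-duplicates outputs with std::find):

#include <cstdio>
#include <deque>
#include <vector>

struct Node {
  bool constant;             // stand-in for isInConstantSubgraph(op)
  std::vector<Node *> users; // stand-in for Value::getUsers()
};

int main() {
  // v -> a; a fans out to b and d; b and d both feed the non-constant c.
  Node c{false, {}};
  Node b{true, {&c}};
  Node d{true, {&c}};
  Node a{true, {&b, &d}};
  Node v{true, {&a}};
  bool simpleTopo = true;
  std::deque<Node *> dq{&v};
  while (!dq.empty()) {
    Node *n = dq.front();
    dq.pop_front();
    bool hasNonConstUser = false;
    for (Node *u : n->users)
      hasNonConstUser |= !u->constant;
    if (hasNonConstUser) {
      std::printf("reached a constant-region output\n");
      continue; // n becomes an output; do not expand past it
    }
    if (n->users.size() > 1)
      simpleTopo = false; // fan-out: no longer a simple chain
    for (Node *u : n->users)
      dq.push_back(u); // expand the constant region breadth-first
  }
  std::printf("simpleTopo = %d\n", simpleTopo); // prints 0 here
}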
@@ -548,30 +566,34 @@ void CST::runOnOperation() {
   moduleOp.push_back(foldFunc);
   symbolTable.insert(foldFunc);

+  // the indexes of args to the folding func.
   SmallVector<int32_t> foldArgs;
+  // the indexes of folded args.
   SmallVector<int32_t> foldIds;
+  // the indexes of args to the computing func.
   SmallVector<int32_t> computeArgs;

   // modify the BlockArguments of block
   size_t oriNumArgs = block.getNumArguments();
-  size_t argIdx = 0;
+  // Add the folded args to the end of BlockArguments list
+  for (size_t id = 0; id < outputValues.size(); ++id) {
+    auto loc = block.getArgument(id).getLoc();
+    BlockArgument foldArg =
+        block.insertArgument(oriNumArgs + id, outputTypes[id], loc);
+    outputValues[id].replaceUsesWithIf(foldArg, [&](OpOperand &val) {
+      Operation *op = val.getOwner();
+      return op->getBlock() == &block;
+    });
+    foldIds.push_back(id + oriNumArgs);
+  }
+  // Erase the operations on constant args
   for (size_t id = 0; id < oriNumArgs; ++id) {
     if (constArgsIndexes.count(id) == 1) {
       foldArgs.push_back(id);
-      foldIds.push_back(argIdx + oriNumArgs);
-      computeArgs.push_back(argIdx + oriNumArgs);
-      auto loc = block.getArgument(id).getLoc();
-      BlockArgument foldArg =
-          block.insertArgument(id, outputTypes[argIdx], loc);
-      outputValues[argIdx].replaceUsesWithIf(foldArg, [&](OpOperand &val) {
-        Operation *op = val.getOwner();
-        return op->getBlock() == &block;
-      });
-
       std::deque<Value> dq;
       SmallVector<Operation *> opsToErase;
       std::unordered_set<Operation *> opsToEraseSet;
-      dq.push_back(block.getArgument(id + 1));
+      dq.push_back(block.getArgument(id));
       while (!dq.empty()) {
         Value v = dq.front();
         dq.pop_front();
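Inserting the fold arguments at oriNumArgs + id (the tail) rather than at id keeps every original argument's index stable, which is what lets the erase loop below address constant args by their old indexes. A toy illustration of the index math (plain vectors standing in for the MLIR block argument list):

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  std::vector<const char *> args = {"w0", "x", "w1"}; // w* are constant args
  size_t oriNumArgs = args.size();
  std::vector<size_t> foldIds;
  // Append one fold arg per folded output; original indexes 0..2 unchanged.
  for (size_t id = 0; id < 2; ++id) {
    args.push_back("fold");
    foldIds.push_back(oriNumArgs + id);
  }
  for (size_t id : foldIds)
    std::printf("fold arg at index %zu\n", id); // prints 3 and 4
}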
@@ -586,16 +608,26 @@ void CST::runOnOperation() {
           opsToEraseSet.insert(op);
         }
       }
-
       for (auto it = opsToErase.rbegin(); it != opsToErase.rend(); ++it) {
         (*it)->erase();
       }
-      block.eraseArgument(id + 1);
-      ++argIdx;
     } else {
       computeArgs.push_back(id);
     }
   }
+  // Erase the constant args in BlockArguments list
+  llvm::BitVector argsToErase;
+  for (size_t id = 0; id < oriNumArgs; ++id) {
+    if (constArgsIndexes.count(id) == 1) {
+      argsToErase.push_back(true);
+    } else {
+      argsToErase.push_back(false);
+    }
+  }
+  for (size_t id = 0; id < outputValues.size(); ++id) {
+    argsToErase.push_back(false);
+  }
+  block.eraseArguments(argsToErase);

   for (auto id : foldIds) {
     foldArgs.insert(foldArgs.end(), id);
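Erasing the arguments through a mask in one batch avoids the index bookkeeping the old code needed (eraseArgument(id + 1) had to skip over the fold argument just inserted at position id). A plain-vector sketch of mask-based batch erasure, standing in for llvm::BitVector plus Block::eraseArguments (names illustrative):

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> args = {10, 11, 12, 13, 14}; // pretend block arguments
  std::vector<bool> eraseMask = {true, false, true, false, false};
  // Decide all erasures against the original indexes, then apply at once;
  // survivor indexes never shift while the mask is being built.
  std::vector<int> kept;
  for (std::size_t i = 0; i < args.size(); ++i)
    if (!eraseMask[i])
      kept.push_back(args[i]);
  args = kept;
  for (int x : args)
    std::printf("%d ", x); // prints: 11 13 14
  std::printf("\n");
}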
@@ -604,6 +636,9 @@ void CST::runOnOperation() {
   addGlobalI32Array(moduleOp, moduleOp.getLoc(), builder, "__fold_args",
                     foldArgs);

+  for (auto id : foldIds) {
+    computeArgs.insert(computeArgs.end(), id);
+  }
   computeArgs.insert(computeArgs.begin(), computeArgs.size());
   addGlobalI32Array(moduleOp, moduleOp.getLoc(), builder, "__compute_args",
                     computeArgs);
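The computeArgs.insert(computeArgs.begin(), computeArgs.size()) line gives "__compute_args" a length-prefixed int32 layout: element count first, then the argument indexes, with the folded ids appended at the end. A standalone sketch of building and reading that layout (the reader loop is illustrative, not part of the pass):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int32_t> computeArgs = {1, 3}; // non-constant arg indexes
  std::vector<int32_t> foldIds = {5, 6};     // folded-arg indexes
  // Append the folded ids, then prepend the element count, mirroring the
  // "__compute_args" layout in the pass.
  computeArgs.insert(computeArgs.end(), foldIds.begin(), foldIds.end());
  computeArgs.insert(computeArgs.begin(),
                     static_cast<int32_t>(computeArgs.size()));
  int32_t n = computeArgs[0]; // a reader recovers the list from the prefix
  for (int32_t i = 1; i <= n; ++i)
    std::printf("arg %d\n", computeArgs[i]);
}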
