Skip to content

[MLIR] Cache symbol tables during OneShotBufferization analyses #138125

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ struct FuncAnalysisState : public OneShotAnalysisState::Extension {
/// analyzed.
DenseMap<FuncOp, FuncOpAnalysisState> analyzedFuncOps;

/// A collection of cached SymbolTables used for faster function lookup.
mutable SymbolTableCollection symbolTables;

/// This function is called right before analyzing the given FuncOp. It
/// initializes the data structures for the FuncOp in this state object.
void startFunctionAnalysis(FuncOp funcOp);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,29 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
}

/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp) {
static FuncOp getCalledFunction(CallOpInterface callOp,
SymbolTableCollection &symbolTables) {
SymbolRefAttr sym =
llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
if (!sym)
return nullptr;
return dyn_cast_or_null<FuncOp>(
SymbolTable::lookupNearestSymbolFrom(callOp, sym));
symbolTables.lookupNearestSymbolFrom(callOp, sym));
}

/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp,
const AnalysisState &state) {
auto &oneShotAnalysisState = static_cast<const OneShotAnalysisState &>(state);

if (auto *funcAnalysisState =
oneShotAnalysisState.getExtension<FuncAnalysisState>()) {
// Use the cached symbol tables.
return getCalledFunction(callOp, funcAnalysisState->symbolTables);
}

SymbolTableCollection symbolTables;
return getCalledFunction(callOp, symbolTables);
}

/// Get FuncAnalysisState.
Expand Down Expand Up @@ -135,7 +151,7 @@ struct CallOpInterface
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
func::CallOp callOp = cast<func::CallOp>(op);
FuncOp funcOp = getCalledFunction(callOp);
FuncOp funcOp = getCalledFunction(callOp, state);
assert(funcOp && "expected CallOp to a FuncOp");

if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
Expand All @@ -150,7 +166,7 @@ struct CallOpInterface
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
func::CallOp callOp = cast<func::CallOp>(op);
FuncOp funcOp = getCalledFunction(callOp);
FuncOp funcOp = getCalledFunction(callOp, state);
assert(funcOp && "expected CallOp to a FuncOp");

if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
Expand All @@ -165,7 +181,7 @@ struct CallOpInterface
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
func::CallOp callOp = cast<func::CallOp>(op);
FuncOp funcOp = getCalledFunction(callOp);
FuncOp funcOp = getCalledFunction(callOp, state);
assert(funcOp && "expected CallOp to a FuncOp");
if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
// FuncOp not analyzed yet. Any OpResult may be aliasing.
Expand Down Expand Up @@ -199,7 +215,11 @@ struct CallOpInterface
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) const {
auto callOp = cast<func::CallOp>(op);
FuncOp funcOp = getCalledFunction(callOp);

// TODO Avoid recomputing the symbol tables every time.
SymbolTableCollection symbolTable;

FuncOp funcOp = getCalledFunction(callOp, symbolTable);
assert(funcOp && "expected CallOp to a FuncOp");

// If the callee was already bufferized, we can directly take the type from
Expand Down Expand Up @@ -243,7 +263,11 @@ struct CallOpInterface
// 2. Rewrite tensor operands as memrefs based on type of the already
// bufferized callee.
SmallVector<Value> newOperands;
FuncOp funcOp = getCalledFunction(callOp);

// TODO Avoid recomputing the symbol tables every time.
SymbolTableCollection symbolTable;

FuncOp funcOp = getCalledFunction(callOp, symbolTable);
assert(funcOp && "expected CallOp to a FuncOp");
FunctionType funcType = funcOp.getFunctionType();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,13 +280,15 @@ static void removeBufferizationAttributes(BlockArgument bbArg) {
}

/// Return the func::FuncOp called by `callOp`.
static func::FuncOp getCalledFunction(func::CallOp callOp) {
static func::FuncOp
getCalledFunction(func::CallOp callOp,
mlir::SymbolTableCollection &symbolTable) {
SymbolRefAttr sym =
llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
if (!sym)
return nullptr;
return dyn_cast_or_null<func::FuncOp>(
SymbolTable::lookupNearestSymbolFrom(callOp, sym));
symbolTable.lookupNearestSymbolFrom(callOp, sym));
}

/// Return "true" if the given function signature has tensor semantics.
Expand Down Expand Up @@ -314,11 +316,15 @@ static LogicalResult getFuncOpsOrderedByCalls(
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
// For each FuncOp, the number of func::CallOp it contains.
DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;

// TODO Avoid recomputing the symbol tables every time.
mlir::SymbolTableCollection symbolTable;

for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
// Collect function calls and populate the caller map.
numberCallOpsContainedInFuncOp[funcOp] = 0;
WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
func::FuncOp calledFunction = getCalledFunction(callOp);
func::FuncOp calledFunction = getCalledFunction(callOp, symbolTable);
assert(calledFunction && "could not retrieved called func::FuncOp");
// If the called function does not have any tensors in its signature, then
// it is not necessary to bufferize the callee before the caller.
Expand Down