Skip to content

[MLIR] Make resolveCallable customizable in CallOpInterface #100361

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions mlir/include/mlir/Interfaces/CallInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,20 @@ namespace mlir {
struct CallInterfaceCallable : public PointerUnion<SymbolRefAttr, Value> {
using PointerUnion<SymbolRefAttr, Value>::PointerUnion;
};
} // namespace mlir

/// Include the generated interface declarations.
#include "mlir/Interfaces/CallInterfaces.h.inc"
class CallOpInterface;

namespace call_interface_impl {

/// Resolve the callable operation for given callee to a CallableOpInterface, or
/// nullptr if a valid callable was not resolved. `symbolTable` is an optional
/// parameter that will allow for using a cached symbol table for symbol lookups
/// instead of performing an O(N) scan.
Operation *resolveCallable(CallOpInterface call, SymbolTableCollection *symbolTable = nullptr);

} // namespace call_interface_impl

} // namespace mlir

namespace llvm {

Expand All @@ -41,4 +51,7 @@ struct CastInfo<To, const mlir::CallInterfaceCallable>

} // namespace llvm

/// Include the generated interface declarations.
#include "mlir/Interfaces/CallInterfaces.h.inc"

#endif // MLIR_INTERFACES_CALLINTERFACES_H
32 changes: 22 additions & 10 deletions mlir/include/mlir/Interfaces/CallInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,29 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
Returns the operands within this call that are used as arguments to the
callee as a mutable range.
}],
"::mlir::MutableOperandRange", "getArgOperandsMutable">,
"::mlir::MutableOperandRange", "getArgOperandsMutable"
>,
InterfaceMethod<[{
Resolve the callable operation for given callee to a
CallableOpInterface, or nullptr if a valid callable was not resolved.
`symbolTable` parameter allow for using a cached symbol table for symbol
lookups instead of performing an O(N) scan.
}],
"::mlir::Operation *", "resolveCallableInTable", (ins "::mlir::SymbolTableCollection *":$symbolTable),
/*methodBody=*/[{}], /*defaultImplementation=*/[{
return ::mlir::call_interface_impl::resolveCallable($_op, symbolTable);
}]
>,
InterfaceMethod<[{
Resolve the callable operation for given callee to a
CallableOpInterface, or nullptr if a valid callable was not resolved.
}],
"::mlir::Operation *", "resolveCallable", (ins),
/*methodBody=*/[{}], /*defaultImplementation=*/[{
return ::mlir::call_interface_impl::resolveCallable($_op);
}]
>
];

let extraClassDeclaration = [{
/// Resolve the callable operation for given callee to a
/// CallableOpInterface, or nullptr if a valid callable was not resolved.
/// `symbolTable` is an optional parameter that will allow for using a
/// cached symbol table for symbol lookups instead of performing an O(N)
/// scan.
::mlir::Operation *resolveCallable(::mlir::SymbolTableCollection *symbolTable = nullptr);
}];
}

/// Interface for callable operations.
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Analysis/CallGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ CallGraphNode *CallGraph::lookupNode(Region *region) const {
CallGraphNode *
CallGraph::resolveCallable(CallOpInterface call,
SymbolTableCollection &symbolTable) const {
Operation *callable = call.resolveCallable(&symbolTable);
Operation *callable = call.resolveCallableInTable(&symbolTable);
if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callable))
if (auto *node = lookupNode(callableOp.getCallableRegion()))
return node;
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
}

void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
Operation *callableOp = call.resolveCallable(&symbolTable);
Operation *callableOp = call.resolveCallableInTable(&symbolTable);

// A call to a externally-defined callable has unknown predecessors.
const auto isExternalCallable = [this](Operation *op) {
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation(
CallOpInterface call, const AbstractDenseLattice &after,
AbstractDenseLattice *before) {
// Find the callee.
Operation *callee = call.resolveCallable(&symbolTable);
Operation *callee = call.resolveCallableInTable(&symbolTable);

auto callable = dyn_cast_or_null<CallableOpInterface>(callee);
// No region means the callee is only declared in this module.
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
// For function calls, connect the arguments of the entry blocks to the
// operands of the call op that are forwarded to these arguments.
if (auto call = dyn_cast<CallOpInterface>(op)) {
Operation *callableOp = call.resolveCallable(&symbolTable);
Operation *callableOp = call.resolveCallableInTable(&symbolTable);
if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
// Not all operands of a call op forward to arguments. Such operands are
// stored in `unaccounted`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -824,7 +824,7 @@ FailureOr<Operation *> BufferDeallocation::handleInterface(CallOpInterface op) {
// the function is referenced by SSA value instead of a Symbol, it's assumed
// to be public. (And we cannot easily change the type of the SSA value
// anyway.)
Operation *funcOp = op.resolveCallable(state.getSymbolTable());
Operation *funcOp = op.resolveCallableInTable(state.getSymbolTable());
bool isPrivate = false;
if (auto symbol = dyn_cast_or_null<SymbolOpInterface>(funcOp))
isPrivate = symbol.isPrivate() && !symbol.isDeclaration();
Expand Down
12 changes: 4 additions & 8 deletions mlir/lib/Interfaces/CallInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,17 @@ using namespace mlir;
// CallOpInterface
//===----------------------------------------------------------------------===//

/// Resolve the callable operation for given callee to a CallableOpInterface, or
/// nullptr if a valid callable was not resolved. `symbolTable` is an optional
/// parameter that will allow for using a cached symbol table for symbol lookups
/// instead of performing an O(N) scan.
Operation *
CallOpInterface::resolveCallable(SymbolTableCollection *symbolTable) {
CallInterfaceCallable callable = getCallableForCallee();
call_interface_impl::resolveCallable(CallOpInterface call, SymbolTableCollection *symbolTable) {
CallInterfaceCallable callable = call.getCallableForCallee();
if (auto symbolVal = dyn_cast<Value>(callable))
return symbolVal.getDefiningOp();

// If the callable isn't a value, lookup the symbol reference.
auto symbolRef = callable.get<SymbolRefAttr>();
if (symbolTable)
return symbolTable->lookupNearestSymbolFrom(getOperation(), symbolRef);
return SymbolTable::lookupNearestSymbolFrom(getOperation(), symbolRef);
return symbolTable->lookupNearestSymbolFrom(call.getOperation(), symbolRef);
return SymbolTable::lookupNearestSymbolFrom(call.getOperation(), symbolRef);
}

//===----------------------------------------------------------------------===//
Expand Down
Loading