Skip to content

Commit d1cad22

Browse files
authored
Reland [MLIR] Make resolveCallable customizable in CallOpInterface (#107989)
Relands #100361 with fixed dependencies.
1 parent 4d55f0b commit d1cad22

File tree

11 files changed

+52
-27
lines changed

11 files changed

+52
-27
lines changed

mlir/include/mlir/Interfaces/CallInterfaces.h

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,21 @@ namespace mlir {
2323
struct CallInterfaceCallable : public PointerUnion<SymbolRefAttr, Value> {
2424
using PointerUnion<SymbolRefAttr, Value>::PointerUnion;
2525
};
26-
} // namespace mlir
2726

28-
/// Include the generated interface declarations.
29-
#include "mlir/Interfaces/CallInterfaces.h.inc"
27+
class CallOpInterface;
28+
29+
namespace call_interface_impl {
30+
31+
/// Resolve the callable operation for given callee to a CallableOpInterface, or
32+
/// nullptr if a valid callable was not resolved. `symbolTable` is an optional
33+
/// parameter that will allow for using a cached symbol table for symbol lookups
34+
/// instead of performing an O(N) scan.
35+
Operation *resolveCallable(CallOpInterface call,
36+
SymbolTableCollection *symbolTable = nullptr);
37+
38+
} // namespace call_interface_impl
39+
40+
} // namespace mlir
3041

3142
namespace llvm {
3243

@@ -41,4 +52,7 @@ struct CastInfo<To, const mlir::CallInterfaceCallable>
4152

4253
} // namespace llvm
4354

55+
/// Include the generated interface declarations.
56+
#include "mlir/Interfaces/CallInterfaces.h.inc"
57+
4458
#endif // MLIR_INTERFACES_CALLINTERFACES_H

mlir/include/mlir/Interfaces/CallInterfaces.td

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,29 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
5959
Returns the operands within this call that are used as arguments to the
6060
callee as a mutable range.
6161
}],
62-
"::mlir::MutableOperandRange", "getArgOperandsMutable">,
62+
"::mlir::MutableOperandRange", "getArgOperandsMutable"
63+
>,
64+
InterfaceMethod<[{
65+
Resolve the callable operation for given callee to a
66+
CallableOpInterface, or nullptr if a valid callable was not resolved.
67+
`symbolTable` parameter allow for using a cached symbol table for symbol
68+
lookups instead of performing an O(N) scan.
69+
}],
70+
"::mlir::Operation *", "resolveCallableInTable", (ins "::mlir::SymbolTableCollection *":$symbolTable),
71+
/*methodBody=*/[{}], /*defaultImplementation=*/[{
72+
return ::mlir::call_interface_impl::resolveCallable($_op, symbolTable);
73+
}]
74+
>,
75+
InterfaceMethod<[{
76+
Resolve the callable operation for given callee to a
77+
CallableOpInterface, or nullptr if a valid callable was not resolved.
78+
}],
79+
"::mlir::Operation *", "resolveCallable", (ins),
80+
/*methodBody=*/[{}], /*defaultImplementation=*/[{
81+
return ::mlir::call_interface_impl::resolveCallable($_op);
82+
}]
83+
>
6384
];
64-
65-
let extraClassDeclaration = [{
66-
/// Resolve the callable operation for given callee to a
67-
/// CallableOpInterface, or nullptr if a valid callable was not resolved.
68-
/// `symbolTable` is an optional parameter that will allow for using a
69-
/// cached symbol table for symbol lookups instead of performing an O(N)
70-
/// scan.
71-
::mlir::Operation *resolveCallable(::mlir::SymbolTableCollection *symbolTable = nullptr);
72-
}];
7385
}
7486

7587
/// Interface for callable operations.

mlir/lib/Analysis/CallGraph.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ CallGraphNode *CallGraph::lookupNode(Region *region) const {
146146
CallGraphNode *
147147
CallGraph::resolveCallable(CallOpInterface call,
148148
SymbolTableCollection &symbolTable) const {
149-
Operation *callable = call.resolveCallable(&symbolTable);
149+
Operation *callable = call.resolveCallableInTable(&symbolTable);
150150
if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callable))
151151
if (auto *node = lookupNode(callableOp.getCallableRegion()))
152152
return node;

mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
297297
}
298298

299299
void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
300-
Operation *callableOp = call.resolveCallable(&symbolTable);
300+
Operation *callableOp = call.resolveCallableInTable(&symbolTable);
301301

302302
// A call to a externally-defined callable has unknown predecessors.
303303
const auto isExternalCallable = [this](Operation *op) {

mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation(
284284
CallOpInterface call, const AbstractDenseLattice &after,
285285
AbstractDenseLattice *before) {
286286
// Find the callee.
287-
Operation *callee = call.resolveCallable(&symbolTable);
287+
Operation *callee = call.resolveCallableInTable(&symbolTable);
288288

289289
auto callable = dyn_cast_or_null<CallableOpInterface>(callee);
290290
// No region means the callee is only declared in this module.

mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
442442
// For function calls, connect the arguments of the entry blocks to the
443443
// operands of the call op that are forwarded to these arguments.
444444
if (auto call = dyn_cast<CallOpInterface>(op)) {
445-
Operation *callableOp = call.resolveCallable(&symbolTable);
445+
Operation *callableOp = call.resolveCallableInTable(&symbolTable);
446446
if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
447447
// Not all operands of a call op forward to arguments. Such operands are
448448
// stored in `unaccounted`.

mlir/lib/Dialect/Async/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRAsyncDialect
88
MLIRAsyncOpsIncGen
99

1010
LINK_LIBS PUBLIC
11+
MLIRCallInterfaces
1112
MLIRControlFlowInterfaces
1213
MLIRFunctionInterfaces
1314
MLIRDialect

mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
2727
LINK_LIBS PUBLIC
2828
MLIRArithDialect
2929
MLIRBufferizationDialect
30+
MLIRCallInterfaces
3031
MLIRControlFlowInterfaces
3132
MLIRFuncDialect
3233
MLIRFunctionInterfaces
@@ -42,4 +43,3 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
4243
MLIRViewLikeInterface
4344
MLIRSupport
4445
)
45-

mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -824,7 +824,7 @@ FailureOr<Operation *> BufferDeallocation::handleInterface(CallOpInterface op) {
824824
// the function is referenced by SSA value instead of a Symbol, it's assumed
825825
// to be public. (And we cannot easily change the type of the SSA value
826826
// anyway.)
827-
Operation *funcOp = op.resolveCallable(state.getSymbolTable());
827+
Operation *funcOp = op.resolveCallableInTable(state.getSymbolTable());
828828
bool isPrivate = false;
829829
if (auto symbol = dyn_cast_or_null<SymbolOpInterface>(funcOp))
830830
isPrivate = symbol.isPrivate() && !symbol.isDeclaration();

mlir/lib/Interfaces/CallInterfaces.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,18 @@ using namespace mlir;
1414
// CallOpInterface
1515
//===----------------------------------------------------------------------===//
1616

17-
/// Resolve the callable operation for given callee to a CallableOpInterface, or
18-
/// nullptr if a valid callable was not resolved. `symbolTable` is an optional
19-
/// parameter that will allow for using a cached symbol table for symbol lookups
20-
/// instead of performing an O(N) scan.
2117
Operation *
22-
CallOpInterface::resolveCallable(SymbolTableCollection *symbolTable) {
23-
CallInterfaceCallable callable = getCallableForCallee();
18+
call_interface_impl::resolveCallable(CallOpInterface call,
19+
SymbolTableCollection *symbolTable) {
20+
CallInterfaceCallable callable = call.getCallableForCallee();
2421
if (auto symbolVal = dyn_cast<Value>(callable))
2522
return symbolVal.getDefiningOp();
2623

2724
// If the callable isn't a value, lookup the symbol reference.
2825
auto symbolRef = callable.get<SymbolRefAttr>();
2926
if (symbolTable)
30-
return symbolTable->lookupNearestSymbolFrom(getOperation(), symbolRef);
31-
return SymbolTable::lookupNearestSymbolFrom(getOperation(), symbolRef);
27+
return symbolTable->lookupNearestSymbolFrom(call.getOperation(), symbolRef);
28+
return SymbolTable::lookupNearestSymbolFrom(call.getOperation(), symbolRef);
3229
}
3330

3431
//===----------------------------------------------------------------------===//

mlir/lib/Transforms/Utils/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ add_mlir_library(MLIRTransformUtils
1616

1717
LINK_LIBS PUBLIC
1818
MLIRAnalysis
19+
MLIRCallInterfaces
1920
MLIRControlFlowInterfaces
2021
MLIRFunctionInterfaces
2122
MLIRLoopLikeInterface

0 commit comments

Comments
 (0)