Skip to content

Commit 39c5553

Browse files
committed
[mlir] Make resolveCallable customizable in CallOpInterface
Allow customization of the `resolveCallable` method in the `CallOpInterface`. This change allows for operations implementing this interface to provide their own logic for resolving callables. - Introduce the `resolveCallable` method, which does not include the optional symbol table parameter. This method replaces the previously existing extra class declaration `resolveCallable`. - Introduce the `resolveCallableInTable` method, which incorporates the symbol table parameter. This method replaces the previous extra class declaration `resolveCallable` that used the optional symbol table parameter.
1 parent 7543d09 commit 39c5553

File tree

8 files changed

+47
-26
lines changed

8 files changed

+47
-26
lines changed

mlir/include/mlir/Interfaces/CallInterfaces.h

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,20 @@ namespace mlir {
2323
struct CallInterfaceCallable : public PointerUnion<SymbolRefAttr, Value> {
2424
using PointerUnion<SymbolRefAttr, Value>::PointerUnion;
2525
};
26-
} // namespace mlir
2726

28-
/// Include the generated interface declarations.
29-
#include "mlir/Interfaces/CallInterfaces.h.inc"
27+
class CallOpInterface;
28+
29+
namespace call_interface_impl {
30+
31+
/// Resolve the callable operation for given callee to a CallableOpInterface, or
32+
/// nullptr if a valid callable was not resolved. `symbolTable` is an optional
33+
/// parameter that will allow for using a cached symbol table for symbol lookups
34+
/// instead of performing an O(N) scan.
35+
Operation *resolveCallable(CallOpInterface call, SymbolTableCollection *symbolTable = nullptr);
36+
37+
} // namespace call_interface_impl
38+
39+
} // namespace mlir
3040

3141
namespace llvm {
3242

@@ -41,4 +51,7 @@ struct CastInfo<To, const mlir::CallInterfaceCallable>
4151

4252
} // namespace llvm
4353

54+
/// Include the generated interface declarations.
55+
#include "mlir/Interfaces/CallInterfaces.h.inc"
56+
4457
#endif // MLIR_INTERFACES_CALLINTERFACES_H

mlir/include/mlir/Interfaces/CallInterfaces.td

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,29 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
5959
Returns the operands within this call that are used as arguments to the
6060
callee as a mutable range.
6161
}],
62-
"::mlir::MutableOperandRange", "getArgOperandsMutable">,
62+
"::mlir::MutableOperandRange", "getArgOperandsMutable"
63+
>,
64+
InterfaceMethod<[{
65+
Resolve the callable operation for given callee to a
66+
CallableOpInterface, or nullptr if a valid callable was not resolved.
67+
`symbolTable` parameter allow for using a cached symbol table for symbol
68+
lookups instead of performing an O(N) scan.
69+
}],
70+
"::mlir::Operation *", "resolveCallableInTable", (ins "::mlir::SymbolTableCollection *":$symbolTable),
71+
/*methodBody=*/[{}], /*defaultImplementation=*/[{
72+
return ::mlir::call_interface_impl::resolveCallable($_op, symbolTable);
73+
}]
74+
>,
75+
InterfaceMethod<[{
76+
Resolve the callable operation for given callee to a
77+
CallableOpInterface, or nullptr if a valid callable was not resolved.
78+
}],
79+
"::mlir::Operation *", "resolveCallable", (ins),
80+
/*methodBody=*/[{}], /*defaultImplementation=*/[{
81+
return ::mlir::call_interface_impl::resolveCallable($_op);
82+
}]
83+
>
6384
];
64-
65-
let extraClassDeclaration = [{
66-
/// Resolve the callable operation for given callee to a
67-
/// CallableOpInterface, or nullptr if a valid callable was not resolved.
68-
/// `symbolTable` is an optional parameter that will allow for using a
69-
/// cached symbol table for symbol lookups instead of performing an O(N)
70-
/// scan.
71-
::mlir::Operation *resolveCallable(::mlir::SymbolTableCollection *symbolTable = nullptr);
72-
}];
7385
}
7486

7587
/// Interface for callable operations.

mlir/lib/Analysis/CallGraph.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ CallGraphNode *CallGraph::lookupNode(Region *region) const {
146146
CallGraphNode *
147147
CallGraph::resolveCallable(CallOpInterface call,
148148
SymbolTableCollection &symbolTable) const {
149-
Operation *callable = call.resolveCallable(&symbolTable);
149+
Operation *callable = call.resolveCallableInTable(&symbolTable);
150150
if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callable))
151151
if (auto *node = lookupNode(callableOp.getCallableRegion()))
152152
return node;

mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
297297
}
298298

299299
void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
300-
Operation *callableOp = call.resolveCallable(&symbolTable);
300+
Operation *callableOp = call.resolveCallableInTable(&symbolTable);
301301

302302
// A call to a externally-defined callable has unknown predecessors.
303303
const auto isExternalCallable = [this](Operation *op) {

mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation(
284284
CallOpInterface call, const AbstractDenseLattice &after,
285285
AbstractDenseLattice *before) {
286286
// Find the callee.
287-
Operation *callee = call.resolveCallable(&symbolTable);
287+
Operation *callee = call.resolveCallableInTable(&symbolTable);
288288

289289
auto callable = dyn_cast_or_null<CallableOpInterface>(callee);
290290
// No region means the callee is only declared in this module.

mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
442442
// For function calls, connect the arguments of the entry blocks to the
443443
// operands of the call op that are forwarded to these arguments.
444444
if (auto call = dyn_cast<CallOpInterface>(op)) {
445-
Operation *callableOp = call.resolveCallable(&symbolTable);
445+
Operation *callableOp = call.resolveCallableInTable(&symbolTable);
446446
if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
447447
// Not all operands of a call op forward to arguments. Such operands are
448448
// stored in `unaccounted`.

mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -824,7 +824,7 @@ FailureOr<Operation *> BufferDeallocation::handleInterface(CallOpInterface op) {
824824
// the function is referenced by SSA value instead of a Symbol, it's assumed
825825
// to be public. (And we cannot easily change the type of the SSA value
826826
// anyway.)
827-
Operation *funcOp = op.resolveCallable(state.getSymbolTable());
827+
Operation *funcOp = op.resolveCallableInTable(state.getSymbolTable());
828828
bool isPrivate = false;
829829
if (auto symbol = dyn_cast_or_null<SymbolOpInterface>(funcOp))
830830
isPrivate = symbol.isPrivate() && !symbol.isDeclaration();

mlir/lib/Interfaces/CallInterfaces.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,17 @@ using namespace mlir;
1414
// CallOpInterface
1515
//===----------------------------------------------------------------------===//
1616

17-
/// Resolve the callable operation for given callee to a CallableOpInterface, or
18-
/// nullptr if a valid callable was not resolved. `symbolTable` is an optional
19-
/// parameter that will allow for using a cached symbol table for symbol lookups
20-
/// instead of performing an O(N) scan.
2117
Operation *
22-
CallOpInterface::resolveCallable(SymbolTableCollection *symbolTable) {
23-
CallInterfaceCallable callable = getCallableForCallee();
18+
call_interface_impl::resolveCallable(CallOpInterface call, SymbolTableCollection *symbolTable) {
19+
CallInterfaceCallable callable = call.getCallableForCallee();
2420
if (auto symbolVal = dyn_cast<Value>(callable))
2521
return symbolVal.getDefiningOp();
2622

2723
// If the callable isn't a value, lookup the symbol reference.
2824
auto symbolRef = callable.get<SymbolRefAttr>();
2925
if (symbolTable)
30-
return symbolTable->lookupNearestSymbolFrom(getOperation(), symbolRef);
31-
return SymbolTable::lookupNearestSymbolFrom(getOperation(), symbolRef);
26+
return symbolTable->lookupNearestSymbolFrom(call.getOperation(), symbolRef);
27+
return SymbolTable::lookupNearestSymbolFrom(call.getOperation(), symbolRef);
3228
}
3329

3430
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)