Skip to content

Revert "[MLIR] Make resolveCallable customizable in CallOpInterface" #107984

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 10, 2024

Conversation

matthias-springer
Copy link
Member

Reverts #100361

This commit caused some linker errors. (Missing MLIRCallInterfaces dependency.)

@matthias-springer matthias-springer merged commit 7574042 into main Sep 10, 2024
5 of 6 checks passed
@matthias-springer matthias-springer deleted the revert-100361-main branch September 10, 2024 08:24
@llvmbot llvmbot added mlir mlir:bufferization Bufferization infrastructure labels Sep 10, 2024
@llvmbot
Copy link
Member

llvmbot commented Sep 10, 2024

@llvm/pr-subscribers-mlir-bufferization

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

Reverts llvm/llvm-project#100361

This commit caused some linker errors. (Missing MLIRCallInterfaces dependency.)


Full diff: https://github.com/llvm/llvm-project/pull/107984.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Interfaces/CallInterfaces.h (+3-16)
  • (modified) mlir/include/mlir/Interfaces/CallInterfaces.td (+10-22)
  • (modified) mlir/lib/Analysis/CallGraph.cpp (+1-1)
  • (modified) mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp (+1-1)
  • (modified) mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp (+1-1)
  • (modified) mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp (+1-1)
  • (modified) mlir/lib/Interfaces/CallInterfaces.cpp (+8-4)
diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.h b/mlir/include/mlir/Interfaces/CallInterfaces.h
index 58c37f01caef09..7dbcddb01b241e 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.h
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.h
@@ -23,21 +23,11 @@ namespace mlir {
 struct CallInterfaceCallable : public PointerUnion<SymbolRefAttr, Value> {
   using PointerUnion<SymbolRefAttr, Value>::PointerUnion;
 };
-
-class CallOpInterface;
-
-namespace call_interface_impl {
-
-/// Resolve the callable operation for given callee to a CallableOpInterface, or
-/// nullptr if a valid callable was not resolved.  `symbolTable` is an optional
-/// parameter that will allow for using a cached symbol table for symbol lookups
-/// instead of performing an O(N) scan.
-Operation *resolveCallable(CallOpInterface call, SymbolTableCollection *symbolTable = nullptr);
-
-} // namespace call_interface_impl
-
 } // namespace mlir
 
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/CallInterfaces.h.inc"
+
 namespace llvm {
 
 // Allow llvm::cast style functions.
@@ -51,7 +41,4 @@ struct CastInfo<To, const mlir::CallInterfaceCallable>
 
 } // namespace llvm
 
-/// Include the generated interface declarations.
-#include "mlir/Interfaces/CallInterfaces.h.inc"
-
 #endif // MLIR_INTERFACES_CALLINTERFACES_H
diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td
index c6002da0d491ce..752de74e6e4d7e 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.td
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.td
@@ -59,29 +59,17 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
         Returns the operands within this call that are used as arguments to the
         callee as a mutable range.
       }],
-      "::mlir::MutableOperandRange", "getArgOperandsMutable"
-    >,
-    InterfaceMethod<[{
-        Resolve the callable operation for given callee to a
-        CallableOpInterface, or nullptr if a valid callable was not resolved.
-        `symbolTable` parameter allow for using a cached symbol table for symbol
-        lookups instead of performing an O(N) scan.
-      }],
-      "::mlir::Operation *", "resolveCallableInTable", (ins "::mlir::SymbolTableCollection *":$symbolTable),
-      /*methodBody=*/[{}], /*defaultImplementation=*/[{
-        return ::mlir::call_interface_impl::resolveCallable($_op, symbolTable);
-      }]
-    >,
-    InterfaceMethod<[{
-        Resolve the callable operation for given callee to a
-        CallableOpInterface, or nullptr if a valid callable was not resolved.
-    }],
-      "::mlir::Operation *", "resolveCallable", (ins),
-      /*methodBody=*/[{}], /*defaultImplementation=*/[{
-        return ::mlir::call_interface_impl::resolveCallable($_op);
-      }]
-    >
+      "::mlir::MutableOperandRange", "getArgOperandsMutable">,
   ];
+
+  let extraClassDeclaration = [{
+    /// Resolve the callable operation for given callee to a
+    /// CallableOpInterface, or nullptr if a valid callable was not resolved.
+    /// `symbolTable` is an optional parameter that will allow for using a
+    /// cached symbol table for symbol lookups instead of performing an O(N)
+    /// scan.
+    ::mlir::Operation *resolveCallable(::mlir::SymbolTableCollection *symbolTable = nullptr);
+  }];
 }
 
 /// Interface for callable operations.
diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp
index 780c7caee767c1..ccd4676632136b 100644
--- a/mlir/lib/Analysis/CallGraph.cpp
+++ b/mlir/lib/Analysis/CallGraph.cpp
@@ -146,7 +146,7 @@ CallGraphNode *CallGraph::lookupNode(Region *region) const {
 CallGraphNode *
 CallGraph::resolveCallable(CallOpInterface call,
                            SymbolTableCollection &symbolTable) const {
-  Operation *callable = call.resolveCallableInTable(&symbolTable);
+  Operation *callable = call.resolveCallable(&symbolTable);
   if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callable))
     if (auto *node = lookupNode(callableOp.getCallableRegion()))
       return node;
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index beb68018a3b16e..532480b6fad57d 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -297,7 +297,7 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
 }
 
 void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
-  Operation *callableOp = call.resolveCallableInTable(&symbolTable);
+  Operation *callableOp = call.resolveCallable(&symbolTable);
 
   // A call to a externally-defined callable has unknown predecessors.
   const auto isExternalCallable = [this](Operation *op) {
diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index 300c6e5f9b8919..37f4ceaaa56cee 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -284,7 +284,7 @@ void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation(
     CallOpInterface call, const AbstractDenseLattice &after,
     AbstractDenseLattice *before) {
   // Find the callee.
-  Operation *callee = call.resolveCallableInTable(&symbolTable);
+  Operation *callee = call.resolveCallable(&symbolTable);
 
   auto callable = dyn_cast_or_null<CallableOpInterface>(callee);
   // No region means the callee is only declared in this module.
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 1bd6defef90be0..4a73f21a18aae7 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -442,7 +442,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
   // For function calls, connect the arguments of the entry blocks to the
   // operands of the call op that are forwarded to these arguments.
   if (auto call = dyn_cast<CallOpInterface>(op)) {
-    Operation *callableOp = call.resolveCallableInTable(&symbolTable);
+    Operation *callableOp = call.resolveCallable(&symbolTable);
     if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
       // Not all operands of a call op forward to arguments. Such operands are
       // stored in `unaccounted`.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index b973618004497b..ca5d0688b5b594 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -824,7 +824,7 @@ FailureOr<Operation *> BufferDeallocation::handleInterface(CallOpInterface op) {
   // the function is referenced by SSA value instead of a Symbol, it's assumed
   // to be public. (And we cannot easily change the type of the SSA value
   // anyway.)
-  Operation *funcOp = op.resolveCallableInTable(state.getSymbolTable());
+  Operation *funcOp = op.resolveCallable(state.getSymbolTable());
   bool isPrivate = false;
   if (auto symbol = dyn_cast_or_null<SymbolOpInterface>(funcOp))
     isPrivate = symbol.isPrivate() && !symbol.isDeclaration();
diff --git a/mlir/lib/Interfaces/CallInterfaces.cpp b/mlir/lib/Interfaces/CallInterfaces.cpp
index 47f8021f50cd28..455684d8e2ea7c 100644
--- a/mlir/lib/Interfaces/CallInterfaces.cpp
+++ b/mlir/lib/Interfaces/CallInterfaces.cpp
@@ -14,17 +14,21 @@ using namespace mlir;
 // CallOpInterface
 //===----------------------------------------------------------------------===//
 
+/// Resolve the callable operation for given callee to a CallableOpInterface, or
+/// nullptr if a valid callable was not resolved. `symbolTable` is an optional
+/// parameter that will allow for using a cached symbol table for symbol lookups
+/// instead of performing an O(N) scan.
 Operation *
-call_interface_impl::resolveCallable(CallOpInterface call, SymbolTableCollection *symbolTable) {
-  CallInterfaceCallable callable = call.getCallableForCallee();
+CallOpInterface::resolveCallable(SymbolTableCollection *symbolTable) {
+  CallInterfaceCallable callable = getCallableForCallee();
   if (auto symbolVal = dyn_cast<Value>(callable))
     return symbolVal.getDefiningOp();
 
   // If the callable isn't a value, lookup the symbol reference.
   auto symbolRef = callable.get<SymbolRefAttr>();
   if (symbolTable)
-    return symbolTable->lookupNearestSymbolFrom(call.getOperation(), symbolRef);
-  return SymbolTable::lookupNearestSymbolFrom(call.getOperation(), symbolRef);
+    return symbolTable->lookupNearestSymbolFrom(getOperation(), symbolRef);
+  return SymbolTable::lookupNearestSymbolFrom(getOperation(), symbolRef);
 }
 
 //===----------------------------------------------------------------------===//

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:bufferization Bufferization infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants