Skip to content

Revert "[MLIR] Add bufferization state class to OneShotBufferization pass" #141012

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 22, 2025

Conversation

mscuttari
Copy link
Member

@mscuttari mscuttari commented May 22, 2025

Reverts #138143

The PR for the BufferizationState is temporarily reverted due to API incompatibilities that have been initially missed during the update and were not catched by PR checks.

@llvmbot
Copy link
Member

llvmbot commented May 22, 2025

@llvm/pr-subscribers-mlir-mlprogram
@llvm/pr-subscribers-mlir-sparse
@llvm/pr-subscribers-mlir-shape
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir-cf

Author: Michele Scuttari (mscuttari)

Changes

Reverts llvm/llvm-project#138143

The PR for the BufferizationState is temporarily reverted due to API incompatibilities that have been initially missed during the update and were not catched by pre-merge checks.


Patch is 52.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/141012.diff

27 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (-14)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td (+1-2)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+5-10)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h (-6)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h (-1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h (-1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h (+1-3)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (-1)
  • (modified) mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp (+4-8)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (-4)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+4-8)
  • (modified) mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp (+2-7)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp (+3-20)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+3-8)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+3-6)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp (+5-4)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp (+6-6)
  • (modified) mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp (+1-2)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp (+2-5)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp (+8-17)
  • (modified) mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp (+4-11)
  • (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+9-18)
  • (modified) mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp (+2-4)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp (+1-2)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp (+1-4)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+16-32)
  • (modified) mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp (+5-10)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 43c97d57e1834..cb6ef8bc17220 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -578,20 +578,6 @@ class AnalysisState {
       insideMutuallyExclusiveRegionsCache;
 };
 
-/// BufferizationState provides information about the state of the IR during the
-/// bufferization process.
-class BufferizationState {
-public:
-  /// Get a reference to the collection of cached symbol tables.
-  SymbolTableCollection &getSymbolTables();
-
-private:
-  /// The cached symbol tables.
-  /// The user is expected to update / invalidate the cached symbol tables if
-  /// the bufferized operation has the Symbol or SymbolTable traits.
-  SymbolTableCollection symbolTables;
-};
-
 /// Create an AllocTensorOp for the given shaped value (memref or tensor).
 /// If `copy` is set, the shaped value is copied. Otherwise, a tensor with
 /// undefined contents is allocated.
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index b599a9f053215..95022d7d665d2 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -426,8 +426,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         /*retType=*/"::llvm::LogicalResult",
         /*methodName=*/"bufferize",
         /*args=*/(ins "::mlir::RewriterBase &":$rewriter,
-                      "const ::mlir::bufferization::BufferizationOptions &":$options,
-                      "::mlir::bufferization::BufferizationState &":$state),
+                      "const ::mlir::bufferization::BufferizationOptions &":$options),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           llvm_unreachable("bufferize not implemented");
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index dafa4b9b183f2..7a1a701bea6dc 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -93,8 +93,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
 
   let extraClassDeclaration = [{
     LogicalResult bufferize(RewriterBase &rewriter,
-                            const BufferizationOptions &options,
-                            BufferizationState &state);
+                            const BufferizationOptions &options);
 
     bool resultBufferizesToMemoryWrite(OpResult opResult,
                                        const AnalysisState &state);
@@ -283,8 +282,7 @@ def Bufferization_MaterializeInDestinationOp
 
   let extraClassDeclaration = [{
     LogicalResult bufferize(RewriterBase &rewriter,
-                            const BufferizationOptions &options,
-                            BufferizationState &state);
+                            const BufferizationOptions &options);
 
     bool bufferizesToMemoryRead(OpOperand &opOperand,
                                 const AnalysisState &state);
@@ -377,8 +375,7 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
     }
 
     LogicalResult bufferize(RewriterBase &rewriter,
-                            const BufferizationOptions &options,
-                            BufferizationState &state);
+                            const BufferizationOptions &options);
   }];
 }
 
@@ -461,8 +458,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     //===------------------------------------------------------------------===//
 
     LogicalResult bufferize(RewriterBase &rewriter,
-                            const BufferizationOptions &options,
-                            BufferizationState &state) const {
+                            const BufferizationOptions &options) const {
       // to_tensor/to_buffer pairs fold away after bufferization.
       return success();
     }
@@ -554,8 +550,7 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
     }
 
     LogicalResult bufferize(RewriterBase &rewriter,
-                            const BufferizationOptions &options,
-                            BufferizationState &state);
+                            const BufferizationOptions &options);
   }];
 
   let assemblyFormat = [{
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
index c08bd6c436133..e5f3b6d571f43 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
@@ -29,7 +29,6 @@ class GlobalOp;
 } // namespace memref
 
 namespace bufferization {
-class BufferizationState;
 
 /// A simple analysis that detects allocation operations.
 class BufferPlacementAllocs {
@@ -123,14 +122,9 @@ class BufferPlacementTransformationBase {
 // Globals are created lazily at the top of the enclosing ModuleOp with pretty
 // names. Duplicates are avoided.
 FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp constantOp,
-                                         SymbolTableCollection &symbolTables,
                                          uint64_t alignment,
                                          Attribute memorySpace = {});
 
-void removeSymbol(Operation *op, BufferizationState &state);
-
-void insertSymbol(Operation *op, BufferizationState &state);
-
 } // namespace bufferization
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index 70e3defee0867..d5cb8d8eb673c 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -45,7 +45,6 @@ struct BufferizationStatistics {
 /// additional buffer copies or set "options.copyBeforeWrite = true". The
 /// general bufferization entry point is `runOneShotBufferize`.
 LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
-                          BufferizationState &bufferizationState,
                           BufferizationStatistics *statistics = nullptr);
 
 /// Bufferize the signature of `block` and its callers (i.e., ops that have the
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index 15189d2c1cb87..673027f76190d 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -270,7 +270,6 @@ LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state,
 /// Run One-Shot Bufferize on the given op: Analysis + Bufferization
 LogicalResult
 runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options,
-                    BufferizationState &state,
                     BufferizationStatistics *statistics = nullptr);
 
 } // namespace bufferization
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
index 2cf801dd1d951..4e5f5e9c730fa 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
@@ -20,7 +20,6 @@ namespace bufferization {
 struct BufferizationStatistics;
 class OneShotAnalysisState;
 struct OneShotBufferizationOptions;
-class BufferizationState;
 
 /// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in
 /// `state`.
@@ -39,7 +38,6 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
 ///   will be inserted only to these FuncOps.
 llvm::LogicalResult
 bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options,
-                  BufferizationState &state,
                   BufferizationStatistics *statistics = nullptr);
 
 /// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
@@ -52,7 +50,7 @@ void removeBufferizationAttributesInModule(ModuleOp moduleOp);
 llvm::LogicalResult runOneShotModuleBufferize(
     ModuleOp moduleOp,
     const bufferization::OneShotBufferizationOptions &options,
-    BufferizationState &state, BufferizationStatistics *statistics = nullptr);
+    BufferizationStatistics *statistics = nullptr);
 
 } // namespace bufferization
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2eef0a06d0eb4..4f90fc8831bc6 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -30,7 +30,6 @@ namespace mlir {
 namespace bufferization {
 class AllocTensorOp;
 class OneShotAnalysisState;
-class BufferizationState;
 } // namespace bufferization
 
 namespace linalg {
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index f646326ffc58f..5e69a98db8f1e 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -24,8 +24,7 @@ struct ConstantOpInterface
     : public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
                                                     arith::ConstantOp> {
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto constantOp = cast<arith::ConstantOp>(op);
     auto type = dyn_cast<RankedTensorType>(constantOp.getType());
 
@@ -47,8 +46,7 @@ struct ConstantOpInterface
     // Create global memory segment and replace tensor with memref pointing to
     // that memory segment.
     FailureOr<memref::GlobalOp> globalOp =
-        getGlobalFor(constantOp, state.getSymbolTables(),
-                     options.bufferAlignment, memorySpace);
+        getGlobalFor(constantOp, options.bufferAlignment, memorySpace);
     if (failed(globalOp))
       return failure();
     memref::GlobalOp globalMemref = *globalOp;
@@ -85,8 +83,7 @@ struct IndexCastOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto castOp = cast<arith::IndexCastOp>(op);
     auto resultTensorType = cast<TensorType>(castOp.getType());
 
@@ -134,8 +131,7 @@ struct SelectOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto selectOp = cast<arith::SelectOp>(op);
     Location loc = selectOp.getLoc();
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 14fa4c1ed8159..1fc34051680f1 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -125,10 +125,6 @@ void AnalysisState::resetCache() {
   insideMutuallyExclusiveRegionsCache.clear();
 }
 
-SymbolTableCollection &BufferizationState::getSymbolTables() {
-  return symbolTables;
-}
-
 Region *bufferization::getNextEnclosingRepetitiveRegion(
     Region *region, const BufferizationOptions &options) {
   assert(isRepetitiveRegion(region, options) && "expected repetitive region");
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 91eccb0ab7430..ecd2ef15546a4 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -149,8 +149,7 @@ void mlir::bufferization::populateDynamicDimSizes(
 //===----------------------------------------------------------------------===//
 
 LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
-                                       const BufferizationOptions &options,
-                                       BufferizationState &state) {
+                                       const BufferizationOptions &options) {
   OpBuilder::InsertionGuard g(rewriter);
   Location loc = getLoc();
 
@@ -530,8 +529,7 @@ void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
 //===----------------------------------------------------------------------===//
 
 LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
-                                         const BufferizationOptions &options,
-                                         BufferizationState &state) {
+                                         const BufferizationOptions &options) {
   FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
   if (failed(buffer))
     return failure();
@@ -578,8 +576,7 @@ MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
 
 LogicalResult
 MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
-                                      const BufferizationOptions &options,
-                                      BufferizationState &state) {
+                                      const BufferizationOptions &options) {
   bool tensorDest = isa<TensorType>(getDest().getType());
   Value buffer;
   if (tensorDest) {
@@ -864,8 +861,7 @@ void ToBufferOp::getCanonicalizationPatterns(RewritePatternSet &results,
 }
 
 LogicalResult ToBufferOp::bufferize(RewriterBase &rewriter,
-                                    const BufferizationOptions &options,
-                                    BufferizationState &state) {
+                                    const BufferizationOptions &options) {
   // Fold to_buffer(to_tensor(x)) to x. Insert a cast if necessary.
   (void)foldToBufferToTensorPair(rewriter, *this, options);
   // Note: The return value of `bufferize` indicates whether there was an error
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index db1eb20512033..a1d7bb995fc73 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -83,8 +83,6 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
   }
 
   auto payloadOps = state.getPayloadOps(getTarget());
-  BufferizationState bufferizationState;
-
   for (Operation *target : payloadOps) {
     if (!isa<ModuleOp, FunctionOpInterface>(target))
       return emitSilenceableError() << "expected module or function target";
@@ -92,12 +90,10 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
     if (options.bufferizeFunctionBoundaries) {
       if (!moduleOp)
         return emitSilenceableError() << "expected module target";
-      if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options,
-                                                          bufferizationState)))
+      if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
         return emitSilenceableError() << "bufferization failed";
     } else {
-      if (failed(bufferization::runOneShotBufferize(target, options,
-                                                    bufferizationState)))
+      if (failed(bufferization::runOneShotBufferize(target, options)))
         return emitSilenceableError() << "bufferization failed";
     }
   }
@@ -166,7 +162,6 @@ class BufferizationTransformDialectExtension
     registerTransformOps<
 #define GET_OP_LIST
 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
-
         >();
   }
 };
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index ff2c83d228dbb..c2e90764b1335 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -103,9 +103,8 @@ BufferPlacementTransformationBase::BufferPlacementTransformationBase(
 //===----------------------------------------------------------------------===//
 
 FailureOr<memref::GlobalOp>
-bufferization::getGlobalFor(arith::ConstantOp constantOp,
-                            SymbolTableCollection &symbolTables,
-                            uint64_t alignment, Attribute memorySpace) {
+bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
+                            Attribute memorySpace) {
   auto type = cast<RankedTensorType>(constantOp.getType());
   auto moduleOp = constantOp->getParentOfType<ModuleOp>();
   if (!moduleOp)
@@ -128,7 +127,7 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp,
   // Create a builder without an insertion point. We will insert using the
   // symbol table to guarantee unique names.
   OpBuilder globalBuilder(moduleOp.getContext());
-  SymbolTable &symbolTable = symbolTables.getSymbolTable(moduleOp);
+  SymbolTable symbolTable(moduleOp);
 
   // Create a pretty name.
   SmallString<64> buf;
@@ -159,19 +158,3 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp,
   global->moveBefore(&moduleOp.front());
   return global;
 }
-
-namespace mlir::bufferization {
-void removeSymbol(Operation *op, BufferizationState &state) {
-  SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
-      op->getParentWithTrait<OpTrait::SymbolTable>());
-
-  symbolTable.remove(op);
-}
-
-void insertSymbol(Operation *op, BufferizationState &state) {
-  SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
-      op->getParentWithTrait<OpTrait::SymbolTable>());
-
-  symbolTable.insert(op);
-}
-} // namespace mlir::bufferization
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 38de525316f7a..824b505517119 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -161,13 +161,10 @@ struct OneShotBufferizePass
       return signalPassFailure();
     }
 
-    BufferizationState state;
-
     BufferizationStatistics statistics;
     ModuleOp moduleOp = getOperation();
     if (opt.bufferizeFunctionBoundaries) {
-      if (failed(
-              runOneShotModuleBufferize(moduleOp, opt, state, &statistics))) {
+      if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics))) {
         signalPassFailure();
         return;
       }
@@ -178,7 +175,7 @@ struct OneShotBufferizePass
                   "'bufferize-function-boundaries'");
         return signalPassFailure();
       }
-      if (failed(runOneShotBufferize(moduleOp, opt, state, &statistics))) {
+      if (failed(runOneShotBufferize(moduleOp, opt, &statistics))) {
         signalPassFailure();
         return;
       }
@@ -278,7 +275,6 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
 
 LogicalResult bufferization::bufferizeOp(Operation *op,
                                          const BufferizationOptions &options,
-                                         BufferizationState &bufferizationState,
                                          BufferizationStatistics *statistics) {
   if (options.copyBeforeWrite) {
     AnalysisState state(options);
@@ -335,8 +331,7 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
                << "//===-------------------------------------------===//\n"
                << "IR after bufferizing: " << nextOp->getName() << "\n");
     rewriter.setInsertionPoint(nextOp);
-    if (failed(
...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented May 22, 2025

@llvm/pr-subscribers-mlir-scf

Author: Michele Scuttari (mscuttari)

Changes

Reverts llvm/llvm-project#138143

The PR for the BufferizationState is temporarily reverted due to API incompatibilities that have been initially missed during the update and were not catched by pre-merge checks.


Patch is 52.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/141012.diff

27 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (-14)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td (+1-2)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+5-10)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h (-6)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h (-1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h (-1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h (+1-3)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (-1)
  • (modified) mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp (+4-8)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (-4)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+4-8)
  • (modified) mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp (+2-7)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp (+3-20)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+3-8)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+3-6)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp (+5-4)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp (+6-6)
  • (modified) mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp (+1-2)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp (+2-5)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp (+8-17)
  • (modified) mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp (+4-11)
  • (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+9-18)
  • (modified) mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp (+2-4)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp (+1-2)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp (+1-4)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+16-32)
  • (modified) mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp (+5-10)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 43c97d57e1834..cb6ef8bc17220 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -578,20 +578,6 @@ class AnalysisState {
       insideMutuallyExclusiveRegionsCache;
 };
 
-/// BufferizationState provides information about the state of the IR during the
-/// bufferization process.
-class BufferizationState {
-public:
-  /// Get a reference to the collection of cached symbol tables.
-  SymbolTableCollection &getSymbolTables();
-
-private:
-  /// The cached symbol tables.
-  /// The user is expected to update / invalidate the cached symbol tables if
-  /// the bufferized operation has the Symbol or SymbolTable traits.
-  SymbolTableCollection symbolTables;
-};
-
 /// Create an AllocTensorOp for the given shaped value (memref or tensor).
 /// If `copy` is set, the shaped value is copied. Otherwise, a tensor with
 /// undefined contents is allocated.
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index b599a9f053215..95022d7d665d2 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -426,8 +426,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         /*retType=*/"::llvm::LogicalResult",
         /*methodName=*/"bufferize",
         /*args=*/(ins "::mlir::RewriterBase &":$rewriter,
-                      "const ::mlir::bufferization::BufferizationOptions &":$options,
-                      "::mlir::bufferization::BufferizationState &":$state),
+                      "const ::mlir::bufferization::BufferizationOptions &":$options),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           llvm_unreachable("bufferize not implemented");
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index dafa4b9b183f2..7a1a701bea6dc 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -93,8 +93,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
 
   let extraClassDeclaration = [{
     LogicalResult bufferize(RewriterBase &rewriter,
-                            const BufferizationOptions &options,
-                            BufferizationState &state);
+                            const BufferizationOptions &options);
 
     bool resultBufferizesToMemoryWrite(OpResult opResult,
                                        const AnalysisState &state);
@@ -283,8 +282,7 @@ def Bufferization_MaterializeInDestinationOp
 
   let extraClassDeclaration = [{
     LogicalResult bufferize(RewriterBase &rewriter,
-                            const BufferizationOptions &options,
-                            BufferizationState &state);
+                            const BufferizationOptions &options);
 
     bool bufferizesToMemoryRead(OpOperand &opOperand,
                                 const AnalysisState &state);
@@ -377,8 +375,7 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
     }
 
     LogicalResult bufferize(RewriterBase &rewriter,
-                            const BufferizationOptions &options,
-                            BufferizationState &state);
+                            const BufferizationOptions &options);
   }];
 }
 
@@ -461,8 +458,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     //===------------------------------------------------------------------===//
 
     LogicalResult bufferize(RewriterBase &rewriter,
-                            const BufferizationOptions &options,
-                            BufferizationState &state) const {
+                            const BufferizationOptions &options) const {
       // to_tensor/to_buffer pairs fold away after bufferization.
       return success();
     }
@@ -554,8 +550,7 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
     }
 
     LogicalResult bufferize(RewriterBase &rewriter,
-                            const BufferizationOptions &options,
-                            BufferizationState &state);
+                            const BufferizationOptions &options);
   }];
 
   let assemblyFormat = [{
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
index c08bd6c436133..e5f3b6d571f43 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
@@ -29,7 +29,6 @@ class GlobalOp;
 } // namespace memref
 
 namespace bufferization {
-class BufferizationState;
 
 /// A simple analysis that detects allocation operations.
 class BufferPlacementAllocs {
@@ -123,14 +122,9 @@ class BufferPlacementTransformationBase {
 // Globals are created lazily at the top of the enclosing ModuleOp with pretty
 // names. Duplicates are avoided.
 FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp constantOp,
-                                         SymbolTableCollection &symbolTables,
                                          uint64_t alignment,
                                          Attribute memorySpace = {});
 
-void removeSymbol(Operation *op, BufferizationState &state);
-
-void insertSymbol(Operation *op, BufferizationState &state);
-
 } // namespace bufferization
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index 70e3defee0867..d5cb8d8eb673c 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -45,7 +45,6 @@ struct BufferizationStatistics {
 /// additional buffer copies or set "options.copyBeforeWrite = true". The
 /// general bufferization entry point is `runOneShotBufferize`.
 LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
-                          BufferizationState &bufferizationState,
                           BufferizationStatistics *statistics = nullptr);
 
 /// Bufferize the signature of `block` and its callers (i.e., ops that have the
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index 15189d2c1cb87..673027f76190d 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -270,7 +270,6 @@ LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state,
 /// Run One-Shot Bufferize on the given op: Analysis + Bufferization
 LogicalResult
 runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options,
-                    BufferizationState &state,
                     BufferizationStatistics *statistics = nullptr);
 
 } // namespace bufferization
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
index 2cf801dd1d951..4e5f5e9c730fa 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
@@ -20,7 +20,6 @@ namespace bufferization {
 struct BufferizationStatistics;
 class OneShotAnalysisState;
 struct OneShotBufferizationOptions;
-class BufferizationState;
 
 /// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in
 /// `state`.
@@ -39,7 +38,6 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
 ///   will be inserted only to these FuncOps.
 llvm::LogicalResult
 bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options,
-                  BufferizationState &state,
                   BufferizationStatistics *statistics = nullptr);
 
 /// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
@@ -52,7 +50,7 @@ void removeBufferizationAttributesInModule(ModuleOp moduleOp);
 llvm::LogicalResult runOneShotModuleBufferize(
     ModuleOp moduleOp,
     const bufferization::OneShotBufferizationOptions &options,
-    BufferizationState &state, BufferizationStatistics *statistics = nullptr);
+    BufferizationStatistics *statistics = nullptr);
 
 } // namespace bufferization
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2eef0a06d0eb4..4f90fc8831bc6 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -30,7 +30,6 @@ namespace mlir {
 namespace bufferization {
 class AllocTensorOp;
 class OneShotAnalysisState;
-class BufferizationState;
 } // namespace bufferization
 
 namespace linalg {
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index f646326ffc58f..5e69a98db8f1e 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -24,8 +24,7 @@ struct ConstantOpInterface
     : public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
                                                     arith::ConstantOp> {
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto constantOp = cast<arith::ConstantOp>(op);
     auto type = dyn_cast<RankedTensorType>(constantOp.getType());
 
@@ -47,8 +46,7 @@ struct ConstantOpInterface
     // Create global memory segment and replace tensor with memref pointing to
     // that memory segment.
     FailureOr<memref::GlobalOp> globalOp =
-        getGlobalFor(constantOp, state.getSymbolTables(),
-                     options.bufferAlignment, memorySpace);
+        getGlobalFor(constantOp, options.bufferAlignment, memorySpace);
     if (failed(globalOp))
       return failure();
     memref::GlobalOp globalMemref = *globalOp;
@@ -85,8 +83,7 @@ struct IndexCastOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto castOp = cast<arith::IndexCastOp>(op);
     auto resultTensorType = cast<TensorType>(castOp.getType());
 
@@ -134,8 +131,7 @@ struct SelectOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto selectOp = cast<arith::SelectOp>(op);
     Location loc = selectOp.getLoc();
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 14fa4c1ed8159..1fc34051680f1 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -125,10 +125,6 @@ void AnalysisState::resetCache() {
   insideMutuallyExclusiveRegionsCache.clear();
 }
 
-SymbolTableCollection &BufferizationState::getSymbolTables() {
-  return symbolTables;
-}
-
 Region *bufferization::getNextEnclosingRepetitiveRegion(
     Region *region, const BufferizationOptions &options) {
   assert(isRepetitiveRegion(region, options) && "expected repetitive region");
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 91eccb0ab7430..ecd2ef15546a4 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -149,8 +149,7 @@ void mlir::bufferization::populateDynamicDimSizes(
 //===----------------------------------------------------------------------===//
 
 LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
-                                       const BufferizationOptions &options,
-                                       BufferizationState &state) {
+                                       const BufferizationOptions &options) {
   OpBuilder::InsertionGuard g(rewriter);
   Location loc = getLoc();
 
@@ -530,8 +529,7 @@ void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
 //===----------------------------------------------------------------------===//
 
 LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
-                                         const BufferizationOptions &options,
-                                         BufferizationState &state) {
+                                         const BufferizationOptions &options) {
   FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
   if (failed(buffer))
     return failure();
@@ -578,8 +576,7 @@ MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
 
 LogicalResult
 MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
-                                      const BufferizationOptions &options,
-                                      BufferizationState &state) {
+                                      const BufferizationOptions &options) {
   bool tensorDest = isa<TensorType>(getDest().getType());
   Value buffer;
   if (tensorDest) {
@@ -864,8 +861,7 @@ void ToBufferOp::getCanonicalizationPatterns(RewritePatternSet &results,
 }
 
 LogicalResult ToBufferOp::bufferize(RewriterBase &rewriter,
-                                    const BufferizationOptions &options,
-                                    BufferizationState &state) {
+                                    const BufferizationOptions &options) {
   // Fold to_buffer(to_tensor(x)) to x. Insert a cast if necessary.
   (void)foldToBufferToTensorPair(rewriter, *this, options);
   // Note: The return value of `bufferize` indicates whether there was an error
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index db1eb20512033..a1d7bb995fc73 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -83,8 +83,6 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
   }
 
   auto payloadOps = state.getPayloadOps(getTarget());
-  BufferizationState bufferizationState;
-
   for (Operation *target : payloadOps) {
     if (!isa<ModuleOp, FunctionOpInterface>(target))
       return emitSilenceableError() << "expected module or function target";
@@ -92,12 +90,10 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
     if (options.bufferizeFunctionBoundaries) {
       if (!moduleOp)
         return emitSilenceableError() << "expected module target";
-      if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options,
-                                                          bufferizationState)))
+      if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
         return emitSilenceableError() << "bufferization failed";
     } else {
-      if (failed(bufferization::runOneShotBufferize(target, options,
-                                                    bufferizationState)))
+      if (failed(bufferization::runOneShotBufferize(target, options)))
         return emitSilenceableError() << "bufferization failed";
     }
   }
@@ -166,7 +162,6 @@ class BufferizationTransformDialectExtension
     registerTransformOps<
 #define GET_OP_LIST
 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
-
         >();
   }
 };
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index ff2c83d228dbb..c2e90764b1335 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -103,9 +103,8 @@ BufferPlacementTransformationBase::BufferPlacementTransformationBase(
 //===----------------------------------------------------------------------===//
 
 FailureOr<memref::GlobalOp>
-bufferization::getGlobalFor(arith::ConstantOp constantOp,
-                            SymbolTableCollection &symbolTables,
-                            uint64_t alignment, Attribute memorySpace) {
+bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
+                            Attribute memorySpace) {
   auto type = cast<RankedTensorType>(constantOp.getType());
   auto moduleOp = constantOp->getParentOfType<ModuleOp>();
   if (!moduleOp)
@@ -128,7 +127,7 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp,
   // Create a builder without an insertion point. We will insert using the
   // symbol table to guarantee unique names.
   OpBuilder globalBuilder(moduleOp.getContext());
-  SymbolTable &symbolTable = symbolTables.getSymbolTable(moduleOp);
+  SymbolTable symbolTable(moduleOp);
 
   // Create a pretty name.
   SmallString<64> buf;
@@ -159,19 +158,3 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp,
   global->moveBefore(&moduleOp.front());
   return global;
 }
-
-namespace mlir::bufferization {
-void removeSymbol(Operation *op, BufferizationState &state) {
-  SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
-      op->getParentWithTrait<OpTrait::SymbolTable>());
-
-  symbolTable.remove(op);
-}
-
-void insertSymbol(Operation *op, BufferizationState &state) {
-  SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
-      op->getParentWithTrait<OpTrait::SymbolTable>());
-
-  symbolTable.insert(op);
-}
-} // namespace mlir::bufferization
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 38de525316f7a..824b505517119 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -161,13 +161,10 @@ struct OneShotBufferizePass
       return signalPassFailure();
     }
 
-    BufferizationState state;
-
     BufferizationStatistics statistics;
     ModuleOp moduleOp = getOperation();
     if (opt.bufferizeFunctionBoundaries) {
-      if (failed(
-              runOneShotModuleBufferize(moduleOp, opt, state, &statistics))) {
+      if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics))) {
         signalPassFailure();
         return;
       }
@@ -178,7 +175,7 @@ struct OneShotBufferizePass
                   "'bufferize-function-boundaries'");
         return signalPassFailure();
       }
-      if (failed(runOneShotBufferize(moduleOp, opt, state, &statistics))) {
+      if (failed(runOneShotBufferize(moduleOp, opt, &statistics))) {
         signalPassFailure();
         return;
       }
@@ -278,7 +275,6 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
 
 LogicalResult bufferization::bufferizeOp(Operation *op,
                                          const BufferizationOptions &options,
-                                         BufferizationState &bufferizationState,
                                          BufferizationStatistics *statistics) {
   if (options.copyBeforeWrite) {
     AnalysisState state(options);
@@ -335,8 +331,7 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
                << "//===-------------------------------------------===//\n"
                << "IR after bufferizing: " << nextOp->getName() << "\n");
     rewriter.setInsertionPoint(nextOp);
-    if (failed(
...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented May 22, 2025

@llvm/pr-subscribers-mlir-tensor

Author: Michele Scuttari (mscuttari)

Changes

Reverts llvm/llvm-project#138143

The PR for the BufferizationState is temporarily reverted due to API incompatibilities that have been initially missed during the update and were not catched by pre-merge checks.


Patch is 52.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/141012.diff

27 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (-14)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td (+1-2)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+5-10)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h (-6)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h (-1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h (-1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h (+1-3)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (-1)
  • (modified) mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp (+4-8)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (-4)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+4-8)
  • (modified) mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp (+2-7)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp (+3-20)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+3-8)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+3-6)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp (+5-4)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp (+6-6)
  • (modified) mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp (+1-2)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp (+2-5)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp (+8-17)
  • (modified) mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp (+4-11)
  • (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+9-18)
  • (modified) mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp (+2-4)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp (+1-2)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp (+1-4)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+16-32)
  • (modified) mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp (+5-10)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 43c97d57e1834..cb6ef8bc17220 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -578,20 +578,6 @@ class AnalysisState {
       insideMutuallyExclusiveRegionsCache;
 };
 
-/// BufferizationState provides information about the state of the IR during the
-/// bufferization process.
-class BufferizationState {
-public:
-  /// Get a reference to the collection of cached symbol tables.
-  SymbolTableCollection &getSymbolTables();
-
-private:
-  /// The cached symbol tables.
-  /// The user is expected to update / invalidate the cached symbol tables if
-  /// the bufferized operation has the Symbol or SymbolTable traits.
-  SymbolTableCollection symbolTables;
-};
-
 /// Create an AllocTensorOp for the given shaped value (memref or tensor).
 /// If `copy` is set, the shaped value is copied. Otherwise, a tensor with
 /// undefined contents is allocated.
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index b599a9f053215..95022d7d665d2 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -426,8 +426,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         /*retType=*/"::llvm::LogicalResult",
         /*methodName=*/"bufferize",
         /*args=*/(ins "::mlir::RewriterBase &":$rewriter,
-                      "const ::mlir::bufferization::BufferizationOptions &":$options,
-                      "::mlir::bufferization::BufferizationState &":$state),
+                      "const ::mlir::bufferization::BufferizationOptions &":$options),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           llvm_unreachable("bufferize not implemented");
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index dafa4b9b183f2..7a1a701bea6dc 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -93,8 +93,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
 
   let extraClassDeclaration = [{
     LogicalResult bufferize(RewriterBase &rewriter,
-                            const BufferizationOptions &options,
-                            BufferizationState &state);
+                            const BufferizationOptions &options);
 
     bool resultBufferizesToMemoryWrite(OpResult opResult,
                                        const AnalysisState &state);
@@ -283,8 +282,7 @@ def Bufferization_MaterializeInDestinationOp
 
   let extraClassDeclaration = [{
     LogicalResult bufferize(RewriterBase &rewriter,
-                            const BufferizationOptions &options,
-                            BufferizationState &state);
+                            const BufferizationOptions &options);
 
     bool bufferizesToMemoryRead(OpOperand &opOperand,
                                 const AnalysisState &state);
@@ -377,8 +375,7 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
     }
 
     LogicalResult bufferize(RewriterBase &rewriter,
-                            const BufferizationOptions &options,
-                            BufferizationState &state);
+                            const BufferizationOptions &options);
   }];
 }
 
@@ -461,8 +458,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     //===------------------------------------------------------------------===//
 
     LogicalResult bufferize(RewriterBase &rewriter,
-                            const BufferizationOptions &options,
-                            BufferizationState &state) const {
+                            const BufferizationOptions &options) const {
       // to_tensor/to_buffer pairs fold away after bufferization.
       return success();
     }
@@ -554,8 +550,7 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
     }
 
     LogicalResult bufferize(RewriterBase &rewriter,
-                            const BufferizationOptions &options,
-                            BufferizationState &state);
+                            const BufferizationOptions &options);
   }];
 
   let assemblyFormat = [{
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
index c08bd6c436133..e5f3b6d571f43 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
@@ -29,7 +29,6 @@ class GlobalOp;
 } // namespace memref
 
 namespace bufferization {
-class BufferizationState;
 
 /// A simple analysis that detects allocation operations.
 class BufferPlacementAllocs {
@@ -123,14 +122,9 @@ class BufferPlacementTransformationBase {
 // Globals are created lazily at the top of the enclosing ModuleOp with pretty
 // names. Duplicates are avoided.
 FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp constantOp,
-                                         SymbolTableCollection &symbolTables,
                                          uint64_t alignment,
                                          Attribute memorySpace = {});
 
-void removeSymbol(Operation *op, BufferizationState &state);
-
-void insertSymbol(Operation *op, BufferizationState &state);
-
 } // namespace bufferization
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index 70e3defee0867..d5cb8d8eb673c 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -45,7 +45,6 @@ struct BufferizationStatistics {
 /// additional buffer copies or set "options.copyBeforeWrite = true". The
 /// general bufferization entry point is `runOneShotBufferize`.
 LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
-                          BufferizationState &bufferizationState,
                           BufferizationStatistics *statistics = nullptr);
 
 /// Bufferize the signature of `block` and its callers (i.e., ops that have the
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index 15189d2c1cb87..673027f76190d 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -270,7 +270,6 @@ LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state,
 /// Run One-Shot Bufferize on the given op: Analysis + Bufferization
 LogicalResult
 runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options,
-                    BufferizationState &state,
                     BufferizationStatistics *statistics = nullptr);
 
 } // namespace bufferization
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
index 2cf801dd1d951..4e5f5e9c730fa 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
@@ -20,7 +20,6 @@ namespace bufferization {
 struct BufferizationStatistics;
 class OneShotAnalysisState;
 struct OneShotBufferizationOptions;
-class BufferizationState;
 
 /// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in
 /// `state`.
@@ -39,7 +38,6 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
 ///   will be inserted only to these FuncOps.
 llvm::LogicalResult
 bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options,
-                  BufferizationState &state,
                   BufferizationStatistics *statistics = nullptr);
 
 /// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
@@ -52,7 +50,7 @@ void removeBufferizationAttributesInModule(ModuleOp moduleOp);
 llvm::LogicalResult runOneShotModuleBufferize(
     ModuleOp moduleOp,
     const bufferization::OneShotBufferizationOptions &options,
-    BufferizationState &state, BufferizationStatistics *statistics = nullptr);
+    BufferizationStatistics *statistics = nullptr);
 
 } // namespace bufferization
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2eef0a06d0eb4..4f90fc8831bc6 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -30,7 +30,6 @@ namespace mlir {
 namespace bufferization {
 class AllocTensorOp;
 class OneShotAnalysisState;
-class BufferizationState;
 } // namespace bufferization
 
 namespace linalg {
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index f646326ffc58f..5e69a98db8f1e 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -24,8 +24,7 @@ struct ConstantOpInterface
     : public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
                                                     arith::ConstantOp> {
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto constantOp = cast<arith::ConstantOp>(op);
     auto type = dyn_cast<RankedTensorType>(constantOp.getType());
 
@@ -47,8 +46,7 @@ struct ConstantOpInterface
     // Create global memory segment and replace tensor with memref pointing to
     // that memory segment.
     FailureOr<memref::GlobalOp> globalOp =
-        getGlobalFor(constantOp, state.getSymbolTables(),
-                     options.bufferAlignment, memorySpace);
+        getGlobalFor(constantOp, options.bufferAlignment, memorySpace);
     if (failed(globalOp))
       return failure();
     memref::GlobalOp globalMemref = *globalOp;
@@ -85,8 +83,7 @@ struct IndexCastOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto castOp = cast<arith::IndexCastOp>(op);
     auto resultTensorType = cast<TensorType>(castOp.getType());
 
@@ -134,8 +131,7 @@ struct SelectOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto selectOp = cast<arith::SelectOp>(op);
     Location loc = selectOp.getLoc();
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 14fa4c1ed8159..1fc34051680f1 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -125,10 +125,6 @@ void AnalysisState::resetCache() {
   insideMutuallyExclusiveRegionsCache.clear();
 }
 
-SymbolTableCollection &BufferizationState::getSymbolTables() {
-  return symbolTables;
-}
-
 Region *bufferization::getNextEnclosingRepetitiveRegion(
     Region *region, const BufferizationOptions &options) {
   assert(isRepetitiveRegion(region, options) && "expected repetitive region");
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 91eccb0ab7430..ecd2ef15546a4 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -149,8 +149,7 @@ void mlir::bufferization::populateDynamicDimSizes(
 //===----------------------------------------------------------------------===//
 
 LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
-                                       const BufferizationOptions &options,
-                                       BufferizationState &state) {
+                                       const BufferizationOptions &options) {
   OpBuilder::InsertionGuard g(rewriter);
   Location loc = getLoc();
 
@@ -530,8 +529,7 @@ void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
 //===----------------------------------------------------------------------===//
 
 LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
-                                         const BufferizationOptions &options,
-                                         BufferizationState &state) {
+                                         const BufferizationOptions &options) {
   FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
   if (failed(buffer))
     return failure();
@@ -578,8 +576,7 @@ MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
 
 LogicalResult
 MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
-                                      const BufferizationOptions &options,
-                                      BufferizationState &state) {
+                                      const BufferizationOptions &options) {
   bool tensorDest = isa<TensorType>(getDest().getType());
   Value buffer;
   if (tensorDest) {
@@ -864,8 +861,7 @@ void ToBufferOp::getCanonicalizationPatterns(RewritePatternSet &results,
 }
 
 LogicalResult ToBufferOp::bufferize(RewriterBase &rewriter,
-                                    const BufferizationOptions &options,
-                                    BufferizationState &state) {
+                                    const BufferizationOptions &options) {
   // Fold to_buffer(to_tensor(x)) to x. Insert a cast if necessary.
   (void)foldToBufferToTensorPair(rewriter, *this, options);
   // Note: The return value of `bufferize` indicates whether there was an error
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index db1eb20512033..a1d7bb995fc73 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -83,8 +83,6 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
   }
 
   auto payloadOps = state.getPayloadOps(getTarget());
-  BufferizationState bufferizationState;
-
   for (Operation *target : payloadOps) {
     if (!isa<ModuleOp, FunctionOpInterface>(target))
       return emitSilenceableError() << "expected module or function target";
@@ -92,12 +90,10 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
     if (options.bufferizeFunctionBoundaries) {
       if (!moduleOp)
         return emitSilenceableError() << "expected module target";
-      if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options,
-                                                          bufferizationState)))
+      if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
         return emitSilenceableError() << "bufferization failed";
     } else {
-      if (failed(bufferization::runOneShotBufferize(target, options,
-                                                    bufferizationState)))
+      if (failed(bufferization::runOneShotBufferize(target, options)))
         return emitSilenceableError() << "bufferization failed";
     }
   }
@@ -166,7 +162,6 @@ class BufferizationTransformDialectExtension
     registerTransformOps<
 #define GET_OP_LIST
 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
-
         >();
   }
 };
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index ff2c83d228dbb..c2e90764b1335 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -103,9 +103,8 @@ BufferPlacementTransformationBase::BufferPlacementTransformationBase(
 //===----------------------------------------------------------------------===//
 
 FailureOr<memref::GlobalOp>
-bufferization::getGlobalFor(arith::ConstantOp constantOp,
-                            SymbolTableCollection &symbolTables,
-                            uint64_t alignment, Attribute memorySpace) {
+bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
+                            Attribute memorySpace) {
   auto type = cast<RankedTensorType>(constantOp.getType());
   auto moduleOp = constantOp->getParentOfType<ModuleOp>();
   if (!moduleOp)
@@ -128,7 +127,7 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp,
   // Create a builder without an insertion point. We will insert using the
   // symbol table to guarantee unique names.
   OpBuilder globalBuilder(moduleOp.getContext());
-  SymbolTable &symbolTable = symbolTables.getSymbolTable(moduleOp);
+  SymbolTable symbolTable(moduleOp);
 
   // Create a pretty name.
   SmallString<64> buf;
@@ -159,19 +158,3 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp,
   global->moveBefore(&moduleOp.front());
   return global;
 }
-
-namespace mlir::bufferization {
-void removeSymbol(Operation *op, BufferizationState &state) {
-  SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
-      op->getParentWithTrait<OpTrait::SymbolTable>());
-
-  symbolTable.remove(op);
-}
-
-void insertSymbol(Operation *op, BufferizationState &state) {
-  SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
-      op->getParentWithTrait<OpTrait::SymbolTable>());
-
-  symbolTable.insert(op);
-}
-} // namespace mlir::bufferization
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 38de525316f7a..824b505517119 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -161,13 +161,10 @@ struct OneShotBufferizePass
       return signalPassFailure();
     }
 
-    BufferizationState state;
-
     BufferizationStatistics statistics;
     ModuleOp moduleOp = getOperation();
     if (opt.bufferizeFunctionBoundaries) {
-      if (failed(
-              runOneShotModuleBufferize(moduleOp, opt, state, &statistics))) {
+      if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics))) {
         signalPassFailure();
         return;
       }
@@ -178,7 +175,7 @@ struct OneShotBufferizePass
                   "'bufferize-function-boundaries'");
         return signalPassFailure();
       }
-      if (failed(runOneShotBufferize(moduleOp, opt, state, &statistics))) {
+      if (failed(runOneShotBufferize(moduleOp, opt, &statistics))) {
         signalPassFailure();
         return;
       }
@@ -278,7 +275,6 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
 
 LogicalResult bufferization::bufferizeOp(Operation *op,
                                          const BufferizationOptions &options,
-                                         BufferizationState &bufferizationState,
                                          BufferizationStatistics *statistics) {
   if (options.copyBeforeWrite) {
     AnalysisState state(options);
@@ -335,8 +331,7 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
                << "//===-------------------------------------------===//\n"
                << "IR after bufferizing: " << nextOp->getName() << "\n");
     rewriter.setInsertionPoint(nextOp);
-    if (failed(
...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented May 22, 2025

@llvm/pr-subscribers-mlir-linalg

Author: Michele Scuttari (mscuttari)

Changes

Reverts llvm/llvm-project#138143

The PR for the BufferizationState is temporarily reverted due to API incompatibilities that have been initially missed during the update and were not catched by pre-merge checks.


Patch is 52.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/141012.diff

27 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (-14)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td (+1-2)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+5-10)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h (-6)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h (-1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h (-1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h (+1-3)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (-1)
  • (modified) mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp (+4-8)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (-4)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+4-8)
  • (modified) mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp (+2-7)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp (+3-20)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+3-8)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+3-6)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp (+5-4)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp (+6-6)
  • (modified) mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp (+1-2)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp (+2-5)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp (+8-17)
  • (modified) mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp (+4-11)
  • (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+9-18)
  • (modified) mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp (+2-4)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp (+1-2)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp (+1-4)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+16-32)
  • (modified) mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp (+5-10)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 43c97d57e1834..cb6ef8bc17220 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -578,20 +578,6 @@ class AnalysisState {
       insideMutuallyExclusiveRegionsCache;
 };
 
-/// BufferizationState provides information about the state of the IR during the
-/// bufferization process.
-class BufferizationState {
-public:
-  /// Get a reference to the collection of cached symbol tables.
-  SymbolTableCollection &getSymbolTables();
-
-private:
-  /// The cached symbol tables.
-  /// The user is expected to update / invalidate the cached symbol tables if
-  /// the bufferized operation has the Symbol or SymbolTable traits.
-  SymbolTableCollection symbolTables;
-};
-
 /// Create an AllocTensorOp for the given shaped value (memref or tensor).
 /// If `copy` is set, the shaped value is copied. Otherwise, a tensor with
 /// undefined contents is allocated.
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index b599a9f053215..95022d7d665d2 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -426,8 +426,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         /*retType=*/"::llvm::LogicalResult",
         /*methodName=*/"bufferize",
         /*args=*/(ins "::mlir::RewriterBase &":$rewriter,
-                      "const ::mlir::bufferization::BufferizationOptions &":$options,
-                      "::mlir::bufferization::BufferizationState &":$state),
+                      "const ::mlir::bufferization::BufferizationOptions &":$options),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           llvm_unreachable("bufferize not implemented");
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index dafa4b9b183f2..7a1a701bea6dc 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -93,8 +93,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
 
   let extraClassDeclaration = [{
     LogicalResult bufferize(RewriterBase &rewriter,
-                            const BufferizationOptions &options,
-                            BufferizationState &state);
+                            const BufferizationOptions &options);
 
     bool resultBufferizesToMemoryWrite(OpResult opResult,
                                        const AnalysisState &state);
@@ -283,8 +282,7 @@ def Bufferization_MaterializeInDestinationOp
 
   let extraClassDeclaration = [{
     LogicalResult bufferize(RewriterBase &rewriter,
-                            const BufferizationOptions &options,
-                            BufferizationState &state);
+                            const BufferizationOptions &options);
 
     bool bufferizesToMemoryRead(OpOperand &opOperand,
                                 const AnalysisState &state);
@@ -377,8 +375,7 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
     }
 
     LogicalResult bufferize(RewriterBase &rewriter,
-                            const BufferizationOptions &options,
-                            BufferizationState &state);
+                            const BufferizationOptions &options);
   }];
 }
 
@@ -461,8 +458,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     //===------------------------------------------------------------------===//
 
     LogicalResult bufferize(RewriterBase &rewriter,
-                            const BufferizationOptions &options,
-                            BufferizationState &state) const {
+                            const BufferizationOptions &options) const {
       // to_tensor/to_buffer pairs fold away after bufferization.
       return success();
     }
@@ -554,8 +550,7 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
     }
 
     LogicalResult bufferize(RewriterBase &rewriter,
-                            const BufferizationOptions &options,
-                            BufferizationState &state);
+                            const BufferizationOptions &options);
   }];
 
   let assemblyFormat = [{
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
index c08bd6c436133..e5f3b6d571f43 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
@@ -29,7 +29,6 @@ class GlobalOp;
 } // namespace memref
 
 namespace bufferization {
-class BufferizationState;
 
 /// A simple analysis that detects allocation operations.
 class BufferPlacementAllocs {
@@ -123,14 +122,9 @@ class BufferPlacementTransformationBase {
 // Globals are created lazily at the top of the enclosing ModuleOp with pretty
 // names. Duplicates are avoided.
 FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp constantOp,
-                                         SymbolTableCollection &symbolTables,
                                          uint64_t alignment,
                                          Attribute memorySpace = {});
 
-void removeSymbol(Operation *op, BufferizationState &state);
-
-void insertSymbol(Operation *op, BufferizationState &state);
-
 } // namespace bufferization
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index 70e3defee0867..d5cb8d8eb673c 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -45,7 +45,6 @@ struct BufferizationStatistics {
 /// additional buffer copies or set "options.copyBeforeWrite = true". The
 /// general bufferization entry point is `runOneShotBufferize`.
 LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
-                          BufferizationState &bufferizationState,
                           BufferizationStatistics *statistics = nullptr);
 
 /// Bufferize the signature of `block` and its callers (i.e., ops that have the
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index 15189d2c1cb87..673027f76190d 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -270,7 +270,6 @@ LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state,
 /// Run One-Shot Bufferize on the given op: Analysis + Bufferization
 LogicalResult
 runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options,
-                    BufferizationState &state,
                     BufferizationStatistics *statistics = nullptr);
 
 } // namespace bufferization
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
index 2cf801dd1d951..4e5f5e9c730fa 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
@@ -20,7 +20,6 @@ namespace bufferization {
 struct BufferizationStatistics;
 class OneShotAnalysisState;
 struct OneShotBufferizationOptions;
-class BufferizationState;
 
 /// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in
 /// `state`.
@@ -39,7 +38,6 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
 ///   will be inserted only to these FuncOps.
 llvm::LogicalResult
 bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options,
-                  BufferizationState &state,
                   BufferizationStatistics *statistics = nullptr);
 
 /// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
@@ -52,7 +50,7 @@ void removeBufferizationAttributesInModule(ModuleOp moduleOp);
 llvm::LogicalResult runOneShotModuleBufferize(
     ModuleOp moduleOp,
     const bufferization::OneShotBufferizationOptions &options,
-    BufferizationState &state, BufferizationStatistics *statistics = nullptr);
+    BufferizationStatistics *statistics = nullptr);
 
 } // namespace bufferization
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2eef0a06d0eb4..4f90fc8831bc6 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -30,7 +30,6 @@ namespace mlir {
 namespace bufferization {
 class AllocTensorOp;
 class OneShotAnalysisState;
-class BufferizationState;
 } // namespace bufferization
 
 namespace linalg {
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index f646326ffc58f..5e69a98db8f1e 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -24,8 +24,7 @@ struct ConstantOpInterface
     : public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
                                                     arith::ConstantOp> {
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto constantOp = cast<arith::ConstantOp>(op);
     auto type = dyn_cast<RankedTensorType>(constantOp.getType());
 
@@ -47,8 +46,7 @@ struct ConstantOpInterface
     // Create global memory segment and replace tensor with memref pointing to
     // that memory segment.
     FailureOr<memref::GlobalOp> globalOp =
-        getGlobalFor(constantOp, state.getSymbolTables(),
-                     options.bufferAlignment, memorySpace);
+        getGlobalFor(constantOp, options.bufferAlignment, memorySpace);
     if (failed(globalOp))
       return failure();
     memref::GlobalOp globalMemref = *globalOp;
@@ -85,8 +83,7 @@ struct IndexCastOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto castOp = cast<arith::IndexCastOp>(op);
     auto resultTensorType = cast<TensorType>(castOp.getType());
 
@@ -134,8 +131,7 @@ struct SelectOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto selectOp = cast<arith::SelectOp>(op);
     Location loc = selectOp.getLoc();
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 14fa4c1ed8159..1fc34051680f1 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -125,10 +125,6 @@ void AnalysisState::resetCache() {
   insideMutuallyExclusiveRegionsCache.clear();
 }
 
-SymbolTableCollection &BufferizationState::getSymbolTables() {
-  return symbolTables;
-}
-
 Region *bufferization::getNextEnclosingRepetitiveRegion(
     Region *region, const BufferizationOptions &options) {
   assert(isRepetitiveRegion(region, options) && "expected repetitive region");
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 91eccb0ab7430..ecd2ef15546a4 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -149,8 +149,7 @@ void mlir::bufferization::populateDynamicDimSizes(
 //===----------------------------------------------------------------------===//
 
 LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
-                                       const BufferizationOptions &options,
-                                       BufferizationState &state) {
+                                       const BufferizationOptions &options) {
   OpBuilder::InsertionGuard g(rewriter);
   Location loc = getLoc();
 
@@ -530,8 +529,7 @@ void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
 //===----------------------------------------------------------------------===//
 
 LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
-                                         const BufferizationOptions &options,
-                                         BufferizationState &state) {
+                                         const BufferizationOptions &options) {
   FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
   if (failed(buffer))
     return failure();
@@ -578,8 +576,7 @@ MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
 
 LogicalResult
 MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
-                                      const BufferizationOptions &options,
-                                      BufferizationState &state) {
+                                      const BufferizationOptions &options) {
   bool tensorDest = isa<TensorType>(getDest().getType());
   Value buffer;
   if (tensorDest) {
@@ -864,8 +861,7 @@ void ToBufferOp::getCanonicalizationPatterns(RewritePatternSet &results,
 }
 
 LogicalResult ToBufferOp::bufferize(RewriterBase &rewriter,
-                                    const BufferizationOptions &options,
-                                    BufferizationState &state) {
+                                    const BufferizationOptions &options) {
   // Fold to_buffer(to_tensor(x)) to x. Insert a cast if necessary.
   (void)foldToBufferToTensorPair(rewriter, *this, options);
   // Note: The return value of `bufferize` indicates whether there was an error
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index db1eb20512033..a1d7bb995fc73 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -83,8 +83,6 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
   }
 
   auto payloadOps = state.getPayloadOps(getTarget());
-  BufferizationState bufferizationState;
-
   for (Operation *target : payloadOps) {
     if (!isa<ModuleOp, FunctionOpInterface>(target))
       return emitSilenceableError() << "expected module or function target";
@@ -92,12 +90,10 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
     if (options.bufferizeFunctionBoundaries) {
       if (!moduleOp)
         return emitSilenceableError() << "expected module target";
-      if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options,
-                                                          bufferizationState)))
+      if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
         return emitSilenceableError() << "bufferization failed";
     } else {
-      if (failed(bufferization::runOneShotBufferize(target, options,
-                                                    bufferizationState)))
+      if (failed(bufferization::runOneShotBufferize(target, options)))
         return emitSilenceableError() << "bufferization failed";
     }
   }
@@ -166,7 +162,6 @@ class BufferizationTransformDialectExtension
     registerTransformOps<
 #define GET_OP_LIST
 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
-
         >();
   }
 };
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index ff2c83d228dbb..c2e90764b1335 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -103,9 +103,8 @@ BufferPlacementTransformationBase::BufferPlacementTransformationBase(
 //===----------------------------------------------------------------------===//
 
 FailureOr<memref::GlobalOp>
-bufferization::getGlobalFor(arith::ConstantOp constantOp,
-                            SymbolTableCollection &symbolTables,
-                            uint64_t alignment, Attribute memorySpace) {
+bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
+                            Attribute memorySpace) {
   auto type = cast<RankedTensorType>(constantOp.getType());
   auto moduleOp = constantOp->getParentOfType<ModuleOp>();
   if (!moduleOp)
@@ -128,7 +127,7 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp,
   // Create a builder without an insertion point. We will insert using the
   // symbol table to guarantee unique names.
   OpBuilder globalBuilder(moduleOp.getContext());
-  SymbolTable &symbolTable = symbolTables.getSymbolTable(moduleOp);
+  SymbolTable symbolTable(moduleOp);
 
   // Create a pretty name.
   SmallString<64> buf;
@@ -159,19 +158,3 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp,
   global->moveBefore(&moduleOp.front());
   return global;
 }
-
-namespace mlir::bufferization {
-void removeSymbol(Operation *op, BufferizationState &state) {
-  SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
-      op->getParentWithTrait<OpTrait::SymbolTable>());
-
-  symbolTable.remove(op);
-}
-
-void insertSymbol(Operation *op, BufferizationState &state) {
-  SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
-      op->getParentWithTrait<OpTrait::SymbolTable>());
-
-  symbolTable.insert(op);
-}
-} // namespace mlir::bufferization
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 38de525316f7a..824b505517119 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -161,13 +161,10 @@ struct OneShotBufferizePass
       return signalPassFailure();
     }
 
-    BufferizationState state;
-
     BufferizationStatistics statistics;
     ModuleOp moduleOp = getOperation();
     if (opt.bufferizeFunctionBoundaries) {
-      if (failed(
-              runOneShotModuleBufferize(moduleOp, opt, state, &statistics))) {
+      if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics))) {
         signalPassFailure();
         return;
       }
@@ -178,7 +175,7 @@ struct OneShotBufferizePass
                   "'bufferize-function-boundaries'");
         return signalPassFailure();
       }
-      if (failed(runOneShotBufferize(moduleOp, opt, state, &statistics))) {
+      if (failed(runOneShotBufferize(moduleOp, opt, &statistics))) {
         signalPassFailure();
         return;
       }
@@ -278,7 +275,6 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
 
 LogicalResult bufferization::bufferizeOp(Operation *op,
                                          const BufferizationOptions &options,
-                                         BufferizationState &bufferizationState,
                                          BufferizationStatistics *statistics) {
   if (options.copyBeforeWrite) {
     AnalysisState state(options);
@@ -335,8 +331,7 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
                << "//===-------------------------------------------===//\n"
                << "IR after bufferizing: " << nextOp->getName() << "\n");
     rewriter.setInsertionPoint(nextOp);
-    if (failed(
...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented May 22, 2025

@llvm/pr-subscribers-mlir

Author: Michele Scuttari (mscuttari)

Changes

Reverts llvm/llvm-project#138143

The PR for the BufferizationState is temporarily reverted due to API incompatibilities that have been initially missed during the update and were not catched by pre-merge checks.


Patch is 52.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/141012.diff

27 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (-14)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td (+1-2)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+5-10)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h (-6)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h (-1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h (-1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h (+1-3)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (-1)
  • (modified) mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp (+4-8)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (-4)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+4-8)
  • (modified) mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp (+2-7)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp (+3-20)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+3-8)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+3-6)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp (+5-4)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp (+6-6)
  • (modified) mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp (+1-2)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp (+2-5)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp (+8-17)
  • (modified) mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp (+4-11)
  • (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+9-18)
  • (modified) mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp (+2-4)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp (+1-2)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp (+1-4)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+16-32)
  • (modified) mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp (+5-10)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 43c97d57e1834..cb6ef8bc17220 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -578,20 +578,6 @@ class AnalysisState {
       insideMutuallyExclusiveRegionsCache;
 };
 
-/// BufferizationState provides information about the state of the IR during the
-/// bufferization process.
-class BufferizationState {
-public:
-  /// Get a reference to the collection of cached symbol tables.
-  SymbolTableCollection &getSymbolTables();
-
-private:
-  /// The cached symbol tables.
-  /// The user is expected to update / invalidate the cached symbol tables if
-  /// the bufferized operation has the Symbol or SymbolTable traits.
-  SymbolTableCollection symbolTables;
-};
-
 /// Create an AllocTensorOp for the given shaped value (memref or tensor).
 /// If `copy` is set, the shaped value is copied. Otherwise, a tensor with
 /// undefined contents is allocated.
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index b599a9f053215..95022d7d665d2 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -426,8 +426,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         /*retType=*/"::llvm::LogicalResult",
         /*methodName=*/"bufferize",
         /*args=*/(ins "::mlir::RewriterBase &":$rewriter,
-                      "const ::mlir::bufferization::BufferizationOptions &":$options,
-                      "::mlir::bufferization::BufferizationState &":$state),
+                      "const ::mlir::bufferization::BufferizationOptions &":$options),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           llvm_unreachable("bufferize not implemented");
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index dafa4b9b183f2..7a1a701bea6dc 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -93,8 +93,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
 
   let extraClassDeclaration = [{
     LogicalResult bufferize(RewriterBase &rewriter,
-                            const BufferizationOptions &options,
-                            BufferizationState &state);
+                            const BufferizationOptions &options);
 
     bool resultBufferizesToMemoryWrite(OpResult opResult,
                                        const AnalysisState &state);
@@ -283,8 +282,7 @@ def Bufferization_MaterializeInDestinationOp
 
   let extraClassDeclaration = [{
     LogicalResult bufferize(RewriterBase &rewriter,
-                            const BufferizationOptions &options,
-                            BufferizationState &state);
+                            const BufferizationOptions &options);
 
     bool bufferizesToMemoryRead(OpOperand &opOperand,
                                 const AnalysisState &state);
@@ -377,8 +375,7 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
     }
 
     LogicalResult bufferize(RewriterBase &rewriter,
-                            const BufferizationOptions &options,
-                            BufferizationState &state);
+                            const BufferizationOptions &options);
   }];
 }
 
@@ -461,8 +458,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     //===------------------------------------------------------------------===//
 
     LogicalResult bufferize(RewriterBase &rewriter,
-                            const BufferizationOptions &options,
-                            BufferizationState &state) const {
+                            const BufferizationOptions &options) const {
       // to_tensor/to_buffer pairs fold away after bufferization.
       return success();
     }
@@ -554,8 +550,7 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
     }
 
     LogicalResult bufferize(RewriterBase &rewriter,
-                            const BufferizationOptions &options,
-                            BufferizationState &state);
+                            const BufferizationOptions &options);
   }];
 
   let assemblyFormat = [{
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
index c08bd6c436133..e5f3b6d571f43 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
@@ -29,7 +29,6 @@ class GlobalOp;
 } // namespace memref
 
 namespace bufferization {
-class BufferizationState;
 
 /// A simple analysis that detects allocation operations.
 class BufferPlacementAllocs {
@@ -123,14 +122,9 @@ class BufferPlacementTransformationBase {
 // Globals are created lazily at the top of the enclosing ModuleOp with pretty
 // names. Duplicates are avoided.
 FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp constantOp,
-                                         SymbolTableCollection &symbolTables,
                                          uint64_t alignment,
                                          Attribute memorySpace = {});
 
-void removeSymbol(Operation *op, BufferizationState &state);
-
-void insertSymbol(Operation *op, BufferizationState &state);
-
 } // namespace bufferization
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index 70e3defee0867..d5cb8d8eb673c 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -45,7 +45,6 @@ struct BufferizationStatistics {
 /// additional buffer copies or set "options.copyBeforeWrite = true". The
 /// general bufferization entry point is `runOneShotBufferize`.
 LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
-                          BufferizationState &bufferizationState,
                           BufferizationStatistics *statistics = nullptr);
 
 /// Bufferize the signature of `block` and its callers (i.e., ops that have the
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index 15189d2c1cb87..673027f76190d 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -270,7 +270,6 @@ LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state,
 /// Run One-Shot Bufferize on the given op: Analysis + Bufferization
 LogicalResult
 runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options,
-                    BufferizationState &state,
                     BufferizationStatistics *statistics = nullptr);
 
 } // namespace bufferization
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
index 2cf801dd1d951..4e5f5e9c730fa 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
@@ -20,7 +20,6 @@ namespace bufferization {
 struct BufferizationStatistics;
 class OneShotAnalysisState;
 struct OneShotBufferizationOptions;
-class BufferizationState;
 
 /// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in
 /// `state`.
@@ -39,7 +38,6 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
 ///   will be inserted only to these FuncOps.
 llvm::LogicalResult
 bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options,
-                  BufferizationState &state,
                   BufferizationStatistics *statistics = nullptr);
 
 /// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
@@ -52,7 +50,7 @@ void removeBufferizationAttributesInModule(ModuleOp moduleOp);
 llvm::LogicalResult runOneShotModuleBufferize(
     ModuleOp moduleOp,
     const bufferization::OneShotBufferizationOptions &options,
-    BufferizationState &state, BufferizationStatistics *statistics = nullptr);
+    BufferizationStatistics *statistics = nullptr);
 
 } // namespace bufferization
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2eef0a06d0eb4..4f90fc8831bc6 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -30,7 +30,6 @@ namespace mlir {
 namespace bufferization {
 class AllocTensorOp;
 class OneShotAnalysisState;
-class BufferizationState;
 } // namespace bufferization
 
 namespace linalg {
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index f646326ffc58f..5e69a98db8f1e 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -24,8 +24,7 @@ struct ConstantOpInterface
     : public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
                                                     arith::ConstantOp> {
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto constantOp = cast<arith::ConstantOp>(op);
     auto type = dyn_cast<RankedTensorType>(constantOp.getType());
 
@@ -47,8 +46,7 @@ struct ConstantOpInterface
     // Create global memory segment and replace tensor with memref pointing to
     // that memory segment.
     FailureOr<memref::GlobalOp> globalOp =
-        getGlobalFor(constantOp, state.getSymbolTables(),
-                     options.bufferAlignment, memorySpace);
+        getGlobalFor(constantOp, options.bufferAlignment, memorySpace);
     if (failed(globalOp))
       return failure();
     memref::GlobalOp globalMemref = *globalOp;
@@ -85,8 +83,7 @@ struct IndexCastOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto castOp = cast<arith::IndexCastOp>(op);
     auto resultTensorType = cast<TensorType>(castOp.getType());
 
@@ -134,8 +131,7 @@ struct SelectOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          const BufferizationOptions &options,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto selectOp = cast<arith::SelectOp>(op);
     Location loc = selectOp.getLoc();
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 14fa4c1ed8159..1fc34051680f1 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -125,10 +125,6 @@ void AnalysisState::resetCache() {
   insideMutuallyExclusiveRegionsCache.clear();
 }
 
-SymbolTableCollection &BufferizationState::getSymbolTables() {
-  return symbolTables;
-}
-
 Region *bufferization::getNextEnclosingRepetitiveRegion(
     Region *region, const BufferizationOptions &options) {
   assert(isRepetitiveRegion(region, options) && "expected repetitive region");
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 91eccb0ab7430..ecd2ef15546a4 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -149,8 +149,7 @@ void mlir::bufferization::populateDynamicDimSizes(
 //===----------------------------------------------------------------------===//
 
 LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
-                                       const BufferizationOptions &options,
-                                       BufferizationState &state) {
+                                       const BufferizationOptions &options) {
   OpBuilder::InsertionGuard g(rewriter);
   Location loc = getLoc();
 
@@ -530,8 +529,7 @@ void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
 //===----------------------------------------------------------------------===//
 
 LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
-                                         const BufferizationOptions &options,
-                                         BufferizationState &state) {
+                                         const BufferizationOptions &options) {
   FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
   if (failed(buffer))
     return failure();
@@ -578,8 +576,7 @@ MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
 
 LogicalResult
 MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
-                                      const BufferizationOptions &options,
-                                      BufferizationState &state) {
+                                      const BufferizationOptions &options) {
   bool tensorDest = isa<TensorType>(getDest().getType());
   Value buffer;
   if (tensorDest) {
@@ -864,8 +861,7 @@ void ToBufferOp::getCanonicalizationPatterns(RewritePatternSet &results,
 }
 
 LogicalResult ToBufferOp::bufferize(RewriterBase &rewriter,
-                                    const BufferizationOptions &options,
-                                    BufferizationState &state) {
+                                    const BufferizationOptions &options) {
   // Fold to_buffer(to_tensor(x)) to x. Insert a cast if necessary.
   (void)foldToBufferToTensorPair(rewriter, *this, options);
   // Note: The return value of `bufferize` indicates whether there was an error
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index db1eb20512033..a1d7bb995fc73 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -83,8 +83,6 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
   }
 
   auto payloadOps = state.getPayloadOps(getTarget());
-  BufferizationState bufferizationState;
-
   for (Operation *target : payloadOps) {
     if (!isa<ModuleOp, FunctionOpInterface>(target))
       return emitSilenceableError() << "expected module or function target";
@@ -92,12 +90,10 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
     if (options.bufferizeFunctionBoundaries) {
       if (!moduleOp)
         return emitSilenceableError() << "expected module target";
-      if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options,
-                                                          bufferizationState)))
+      if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
         return emitSilenceableError() << "bufferization failed";
     } else {
-      if (failed(bufferization::runOneShotBufferize(target, options,
-                                                    bufferizationState)))
+      if (failed(bufferization::runOneShotBufferize(target, options)))
         return emitSilenceableError() << "bufferization failed";
     }
   }
@@ -166,7 +162,6 @@ class BufferizationTransformDialectExtension
     registerTransformOps<
 #define GET_OP_LIST
 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
-
         >();
   }
 };
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index ff2c83d228dbb..c2e90764b1335 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -103,9 +103,8 @@ BufferPlacementTransformationBase::BufferPlacementTransformationBase(
 //===----------------------------------------------------------------------===//
 
 FailureOr<memref::GlobalOp>
-bufferization::getGlobalFor(arith::ConstantOp constantOp,
-                            SymbolTableCollection &symbolTables,
-                            uint64_t alignment, Attribute memorySpace) {
+bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
+                            Attribute memorySpace) {
   auto type = cast<RankedTensorType>(constantOp.getType());
   auto moduleOp = constantOp->getParentOfType<ModuleOp>();
   if (!moduleOp)
@@ -128,7 +127,7 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp,
   // Create a builder without an insertion point. We will insert using the
   // symbol table to guarantee unique names.
   OpBuilder globalBuilder(moduleOp.getContext());
-  SymbolTable &symbolTable = symbolTables.getSymbolTable(moduleOp);
+  SymbolTable symbolTable(moduleOp);
 
   // Create a pretty name.
   SmallString<64> buf;
@@ -159,19 +158,3 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp,
   global->moveBefore(&moduleOp.front());
   return global;
 }
-
-namespace mlir::bufferization {
-void removeSymbol(Operation *op, BufferizationState &state) {
-  SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
-      op->getParentWithTrait<OpTrait::SymbolTable>());
-
-  symbolTable.remove(op);
-}
-
-void insertSymbol(Operation *op, BufferizationState &state) {
-  SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
-      op->getParentWithTrait<OpTrait::SymbolTable>());
-
-  symbolTable.insert(op);
-}
-} // namespace mlir::bufferization
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 38de525316f7a..824b505517119 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -161,13 +161,10 @@ struct OneShotBufferizePass
       return signalPassFailure();
     }
 
-    BufferizationState state;
-
     BufferizationStatistics statistics;
     ModuleOp moduleOp = getOperation();
     if (opt.bufferizeFunctionBoundaries) {
-      if (failed(
-              runOneShotModuleBufferize(moduleOp, opt, state, &statistics))) {
+      if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics))) {
         signalPassFailure();
         return;
       }
@@ -178,7 +175,7 @@ struct OneShotBufferizePass
                   "'bufferize-function-boundaries'");
         return signalPassFailure();
       }
-      if (failed(runOneShotBufferize(moduleOp, opt, state, &statistics))) {
+      if (failed(runOneShotBufferize(moduleOp, opt, &statistics))) {
         signalPassFailure();
         return;
       }
@@ -278,7 +275,6 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
 
 LogicalResult bufferization::bufferizeOp(Operation *op,
                                          const BufferizationOptions &options,
-                                         BufferizationState &bufferizationState,
                                          BufferizationStatistics *statistics) {
   if (options.copyBeforeWrite) {
     AnalysisState state(options);
@@ -335,8 +331,7 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
                << "//===-------------------------------------------===//\n"
                << "IR after bufferizing: " << nextOp->getName() << "\n");
     rewriter.setInsertionPoint(nextOp);
-    if (failed(
...
[truncated]

Copy link
Contributor

@kazutakahirata kazutakahirata left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks!

@mscuttari mscuttari merged commit 72a8893 into main May 22, 2025
23 of 24 checks passed
@mscuttari mscuttari deleted the revert-138143-bufferization-state branch May 22, 2025 07:25
@joker-eph
Copy link
Collaborator

How did this pass the premerge in the first place?

@mscuttari
Copy link
Member Author

How did this pass the premerge in the first place?

That's what I'm also wondering. The build was fine and checks were passing both locally (at least the mlir-check ones) and on the build bots. As I stated in the new PR (#141019), it was just one signature which I forgot to update, but the others within the same compilation unit were addressed. As of now I really don't know how this could be possible.

@mscuttari
Copy link
Member Author

I've found the cause of the problem.
#140355 introduced the implementation of the bufferization interface for tensor::ConcatOp. However my branch, even if rebased to avoid conflicts, was on an older commit. The checks therefore passed before the merge, but not anymore after the commits got squashed and applied on head.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants