Revert "[MLIR] Add bufferization state class to OneShotBufferization pass" #141012

Merged · 1 commit · May 22, 2025
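For orientation, a minimal before/after sketch of the interface change this revert undoes (signatures condensed from the diff below; namespaces, includes, and unrelated parameters omitted):

```cpp
// Pre-revert: BufferizableOpInterface::bufferize() and the one-shot entry
// points threaded a BufferizationState through, mainly to reach a cached
// SymbolTableCollection.
LogicalResult bufferize(RewriterBase &rewriter,
                        const BufferizationOptions &options,
                        BufferizationState &state);
LogicalResult runOneShotBufferize(Operation *op,
                                  const OneShotBufferizationOptions &options,
                                  BufferizationState &state,
                                  BufferizationStatistics *statistics = nullptr);

// Post-revert: the state parameter is gone everywhere; bufferize() takes only
// the rewriter and the options.
LogicalResult bufferize(RewriterBase &rewriter,
                        const BufferizationOptions &options);
LogicalResult runOneShotBufferize(Operation *op,
                                  const OneShotBufferizationOptions &options,
                                  BufferizationStatistics *statistics = nullptr);
```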
@@ -578,20 +578,6 @@ class AnalysisState {
insideMutuallyExclusiveRegionsCache;
};

/// BufferizationState provides information about the state of the IR during the
/// bufferization process.
class BufferizationState {
public:
/// Get a reference to the collection of cached symbol tables.
SymbolTableCollection &getSymbolTables();

private:
/// The cached symbol tables.
/// The user is expected to update / invalidate the cached symbol tables if
/// the bufferized operation has the Symbol or SymbolTable traits.
SymbolTableCollection symbolTables;
};

/// Create an AllocTensorOp for the given shaped value (memref or tensor).
/// If `copy` is set, the shaped value is copied. Otherwise, a tensor with
/// undefined contents is allocated.
@@ -426,8 +426,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"::llvm::LogicalResult",
/*methodName=*/"bufferize",
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
"const ::mlir::bufferization::BufferizationOptions &":$options,
"::mlir::bufferization::BufferizationState &":$state),
"const ::mlir::bufferization::BufferizationOptions &":$options),
/*methodBody=*/"",
/*defaultImplementation=*/[{
llvm_unreachable("bufferize not implemented");
15 changes: 5 additions & 10 deletions mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -93,8 +93,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",

let extraClassDeclaration = [{
LogicalResult bufferize(RewriterBase &rewriter,
const BufferizationOptions &options,
BufferizationState &state);
const BufferizationOptions &options);

bool resultBufferizesToMemoryWrite(OpResult opResult,
const AnalysisState &state);
@@ -283,8 +282,7 @@ def Bufferization_MaterializeInDestinationOp

let extraClassDeclaration = [{
LogicalResult bufferize(RewriterBase &rewriter,
const BufferizationOptions &options,
BufferizationState &state);
const BufferizationOptions &options);

bool bufferizesToMemoryRead(OpOperand &opOperand,
const AnalysisState &state);
@@ -377,8 +375,7 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
}

LogicalResult bufferize(RewriterBase &rewriter,
const BufferizationOptions &options,
BufferizationState &state);
const BufferizationOptions &options);
}];
}

@@ -461,8 +458,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
//===------------------------------------------------------------------===//

LogicalResult bufferize(RewriterBase &rewriter,
const BufferizationOptions &options,
BufferizationState &state) const {
const BufferizationOptions &options) const {
// to_tensor/to_buffer pairs fold away after bufferization.
return success();
}
@@ -554,8 +550,7 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
}

LogicalResult bufferize(RewriterBase &rewriter,
const BufferizationOptions &options,
BufferizationState &state);
const BufferizationOptions &options);
}];

let assemblyFormat = [{
@@ -29,7 +29,6 @@ class GlobalOp;
} // namespace memref

namespace bufferization {
class BufferizationState;

/// A simple analysis that detects allocation operations.
class BufferPlacementAllocs {
@@ -123,14 +122,9 @@ class BufferPlacementTransformationBase {
// Globals are created lazily at the top of the enclosing ModuleOp with pretty
// names. Duplicates are avoided.
FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp constantOp,
SymbolTableCollection &symbolTables,
uint64_t alignment,
Attribute memorySpace = {});

void removeSymbol(Operation *op, BufferizationState &state);

void insertSymbol(Operation *op, BufferizationState &state);

} // namespace bufferization
} // namespace mlir

@@ -45,7 +45,6 @@ struct BufferizationStatistics {
/// additional buffer copies or set "options.copyBeforeWrite = true". The
/// general bufferization entry point is `runOneShotBufferize`.
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
BufferizationState &bufferizationState,
BufferizationStatistics *statistics = nullptr);

/// Bufferize the signature of `block` and its callers (i.e., ops that have the
@@ -270,7 +270,6 @@ LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state,
/// Run One-Shot Bufferize on the given op: Analysis + Bufferization
LogicalResult
runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options,
BufferizationState &state,
BufferizationStatistics *statistics = nullptr);

} // namespace bufferization
@@ -20,7 +20,6 @@ namespace bufferization {
struct BufferizationStatistics;
class OneShotAnalysisState;
struct OneShotBufferizationOptions;
class BufferizationState;

/// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in
/// `state`.
@@ -39,7 +38,6 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
/// will be inserted only to these FuncOps.
llvm::LogicalResult
bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options,
BufferizationState &state,
BufferizationStatistics *statistics = nullptr);

/// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
@@ -52,7 +50,7 @@ void removeBufferizationAttributesInModule(ModuleOp moduleOp);
llvm::LogicalResult runOneShotModuleBufferize(
ModuleOp moduleOp,
const bufferization::OneShotBufferizationOptions &options,
BufferizationState &state, BufferizationStatistics *statistics = nullptr);
BufferizationStatistics *statistics = nullptr);

} // namespace bufferization
} // namespace mlir
1 change: 0 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -30,7 +30,6 @@ namespace mlir {
namespace bufferization {
class AllocTensorOp;
class OneShotAnalysisState;
class BufferizationState;
} // namespace bufferization

namespace linalg {
@@ -24,8 +24,7 @@ struct ConstantOpInterface
: public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
arith::ConstantOp> {
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options,
BufferizationState &state) const {
const BufferizationOptions &options) const {
auto constantOp = cast<arith::ConstantOp>(op);
auto type = dyn_cast<RankedTensorType>(constantOp.getType());

@@ -47,8 +46,7 @@ struct ConstantOpInterface
// Create global memory segment and replace tensor with memref pointing to
// that memory segment.
FailureOr<memref::GlobalOp> globalOp =
getGlobalFor(constantOp, state.getSymbolTables(),
options.bufferAlignment, memorySpace);
getGlobalFor(constantOp, options.bufferAlignment, memorySpace);
if (failed(globalOp))
return failure();
memref::GlobalOp globalMemref = *globalOp;
@@ -85,8 +83,7 @@ struct IndexCastOpInterface
}

LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options,
BufferizationState &state) const {
const BufferizationOptions &options) const {
auto castOp = cast<arith::IndexCastOp>(op);
auto resultTensorType = cast<TensorType>(castOp.getType());

@@ -134,8 +131,7 @@ struct SelectOpInterface
}

LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options,
BufferizationState &state) const {
const BufferizationOptions &options) const {
auto selectOp = cast<arith::SelectOp>(op);
Location loc = selectOp.getLoc();

@@ -125,10 +125,6 @@ void AnalysisState::resetCache() {
insideMutuallyExclusiveRegionsCache.clear();
}

SymbolTableCollection &BufferizationState::getSymbolTables() {
return symbolTables;
}

Region *bufferization::getNextEnclosingRepetitiveRegion(
Region *region, const BufferizationOptions &options) {
assert(isRepetitiveRegion(region, options) && "expected repetitive region");
12 changes: 4 additions & 8 deletions mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -149,8 +149,7 @@ void mlir::bufferization::populateDynamicDimSizes(
//===----------------------------------------------------------------------===//

LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
const BufferizationOptions &options,
BufferizationState &state) {
const BufferizationOptions &options) {
OpBuilder::InsertionGuard g(rewriter);
Location loc = getLoc();

@@ -530,8 +529,7 @@ void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
//===----------------------------------------------------------------------===//

LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
const BufferizationOptions &options,
BufferizationState &state) {
const BufferizationOptions &options) {
FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
if (failed(buffer))
return failure();
@@ -578,8 +576,7 @@ MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,

LogicalResult
MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
const BufferizationOptions &options,
BufferizationState &state) {
const BufferizationOptions &options) {
bool tensorDest = isa<TensorType>(getDest().getType());
Value buffer;
if (tensorDest) {
@@ -864,8 +861,7 @@ void ToBufferOp::getCanonicalizationPatterns(RewritePatternSet &results,
}

LogicalResult ToBufferOp::bufferize(RewriterBase &rewriter,
const BufferizationOptions &options,
BufferizationState &state) {
const BufferizationOptions &options) {
// Fold to_buffer(to_tensor(x)) to x. Insert a cast if necessary.
(void)foldToBufferToTensorPair(rewriter, *this, options);
// Note: The return value of `bufferize` indicates whether there was an error
@@ -83,21 +83,17 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
}

auto payloadOps = state.getPayloadOps(getTarget());
BufferizationState bufferizationState;

for (Operation *target : payloadOps) {
if (!isa<ModuleOp, FunctionOpInterface>(target))
return emitSilenceableError() << "expected module or function target";
auto moduleOp = dyn_cast<ModuleOp>(target);
if (options.bufferizeFunctionBoundaries) {
if (!moduleOp)
return emitSilenceableError() << "expected module target";
if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options,
bufferizationState)))
if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
return emitSilenceableError() << "bufferization failed";
} else {
if (failed(bufferization::runOneShotBufferize(target, options,
bufferizationState)))
if (failed(bufferization::runOneShotBufferize(target, options)))
return emitSilenceableError() << "bufferization failed";
}
}
@@ -166,7 +162,6 @@ class BufferizationTransformDialectExtension
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"

>();
}
};
23 changes: 3 additions & 20 deletions mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -103,9 +103,8 @@ BufferPlacementTransformationBase::BufferPlacementTransformationBase(
//===----------------------------------------------------------------------===//

FailureOr<memref::GlobalOp>
bufferization::getGlobalFor(arith::ConstantOp constantOp,
SymbolTableCollection &symbolTables,
uint64_t alignment, Attribute memorySpace) {
bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
Attribute memorySpace) {
auto type = cast<RankedTensorType>(constantOp.getType());
auto moduleOp = constantOp->getParentOfType<ModuleOp>();
if (!moduleOp)
@@ -128,7 +127,7 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp,
// Create a builder without an insertion point. We will insert using the
// symbol table to guarantee unique names.
OpBuilder globalBuilder(moduleOp.getContext());
SymbolTable &symbolTable = symbolTables.getSymbolTable(moduleOp);
SymbolTable symbolTable(moduleOp);

// Create a pretty name.
SmallString<64> buf;
@@ -159,19 +158,3 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp,
global->moveBefore(&moduleOp.front());
return global;
}

namespace mlir::bufferization {
void removeSymbol(Operation *op, BufferizationState &state) {
SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
op->getParentWithTrait<OpTrait::SymbolTable>());

symbolTable.remove(op);
}

void insertSymbol(Operation *op, BufferizationState &state) {
SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
op->getParentWithTrait<OpTrait::SymbolTable>());

symbolTable.insert(op);
}
} // namespace mlir::bufferization
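As a usage note, the practical effect of this hunk is the symbol-table lookup strategy in `getGlobalFor`: the reverted change fetched a `SymbolTable` cached in `BufferizationState::getSymbolTables()`, while the restored code builds one on demand. A minimal sketch of the post-revert pattern follows; `lookupGlobal` and its `name` parameter are illustrative helpers, not part of the patch:

```cpp
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/SymbolTable.h"

using namespace mlir;

// Post-revert pattern: construct a SymbolTable for the enclosing module on
// every call. The reverted change instead reused the SymbolTableCollection
// cached in BufferizationState, avoiding repeated module rescans when many
// constants are bufferized.
static memref::GlobalOp lookupGlobal(arith::ConstantOp constantOp,
                                     StringRef name) {
  auto moduleOp = constantOp->getParentOfType<ModuleOp>();
  if (!moduleOp)
    return nullptr;
  SymbolTable symbolTable(moduleOp); // rebuilt on each lookup after the revert
  return symbolTable.lookup<memref::GlobalOp>(name);
}
```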
11 changes: 3 additions & 8 deletions mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -161,13 +161,10 @@ struct OneShotBufferizePass
return signalPassFailure();
}

BufferizationState state;

BufferizationStatistics statistics;
ModuleOp moduleOp = getOperation();
if (opt.bufferizeFunctionBoundaries) {
if (failed(
runOneShotModuleBufferize(moduleOp, opt, state, &statistics))) {
if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics))) {
signalPassFailure();
return;
}
@@ -178,7 +175,7 @@
"'bufferize-function-boundaries'");
return signalPassFailure();
}
if (failed(runOneShotBufferize(moduleOp, opt, state, &statistics))) {
if (failed(runOneShotBufferize(moduleOp, opt, &statistics))) {
signalPassFailure();
return;
}
@@ -278,7 +275,6 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {

LogicalResult bufferization::bufferizeOp(Operation *op,
const BufferizationOptions &options,
BufferizationState &bufferizationState,
BufferizationStatistics *statistics) {
if (options.copyBeforeWrite) {
AnalysisState state(options);
@@ -335,8 +331,7 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
<< "//===-------------------------------------------===//\n"
<< "IR after bufferizing: " << nextOp->getName() << "\n");
rewriter.setInsertionPoint(nextOp);
if (failed(
bufferizableOp.bufferize(rewriter, options, bufferizationState))) {
if (failed(bufferizableOp.bufferize(rewriter, options))) {
LLVM_DEBUG(llvm::dbgs()
<< "failed to bufferize\n"
<< "//===-------------------------------------------===//\n");
@@ -239,8 +239,7 @@ struct CallOpInterface
/// All function arguments are writable. It is the responsibility of the
/// CallOp to insert buffer copies where necessary.
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options,
BufferizationState &state) const {
const BufferizationOptions &options) const {
func::CallOp callOp = cast<func::CallOp>(op);

// 1. Compute the result types of the new CallOp.
@@ -350,8 +349,7 @@ struct ReturnOpInterface
}

LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options,
BufferizationState &state) const {
const BufferizationOptions &options) const {
#ifndef NDEBUG
auto returnOp = cast<func::ReturnOp>(op);
assert(isa<FuncOp>(returnOp->getParentOp()) &&
@@ -420,8 +418,7 @@ struct FuncOpInterface
/// All function bbArgs are writable unless they are explicitly marked as
/// read-only. Callers must insert copies when needed.
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options,
BufferizationState &state) const {
const BufferizationOptions &options) const {
auto funcOp = cast<FuncOp>(op);
FunctionType funcType = funcOp.getFunctionType();
