Skip to content

Commit c03efd8

Browse files
committed
[mlir][bufferization][NFC] Introduce BufferDeallocationOpInterface
This new interface allows operations to implement custom handling of ownership values and insertion of dealloc operations which is useful when an op cannot implement the interfaces supported by default by the buffer deallocation pass (e.g., because they are not exactly compatible or because there are some additional semantics to it that would render the default implementations in buffer deallocation invalid, or because no interfaces exist for this kind of behavior and it's not worth introducing one plus a default implementation in buffer deallocation). Additionally, it can also be used to provide more efficient handling for a specific op than the interface based default implementations can.
1 parent 3ae76e4 commit c03efd8

File tree

15 files changed

+798
-522
lines changed

15 files changed

+798
-522
lines changed
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
//===- BufferDeallocationOpInterface.h --------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERDEALLOCATIONOPINTERFACE_H_
10+
#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERDEALLOCATIONOPINTERFACE_H_
11+
12+
#include "mlir/Analysis/Liveness.h"
13+
#include "mlir/IR/Operation.h"
14+
#include "mlir/IR/SymbolTable.h"
15+
#include "mlir/Support/LLVM.h"
16+
17+
namespace mlir {
18+
namespace bufferization {
19+
20+
/// Compare two SSA values in a deterministic manner. Two block arguments are
21+
/// ordered by argument number, block arguments are always less than operation
22+
/// results, and operation results are ordered by the `isBeforeInBlock` order of
23+
/// their defining operation.
24+
struct ValueComparator {
25+
bool operator()(const Value &lhs, const Value &rhs) const;
26+
};
27+
28+
/// This class is used to track the ownership of values. The ownership can
29+
/// either be not initialized yet ('Uninitialized' state), set to a unique SSA
30+
/// value which indicates the ownership at runtime (or statically if it is a
31+
/// constant value) ('Unique' state), or it cannot be represented in a single
32+
/// SSA value ('Unknown' state). An artificial example of a case where ownership
33+
/// cannot be represented in a single i1 SSA value could be the following:
34+
/// `%0 = test.non_deterministic_select %arg0, %arg1 : i32`
35+
/// Since the operation does not provide us a separate boolean indicator on
36+
/// which of the two operands was selected, we would need to either insert an
37+
/// alias check at runtime to determine if `%0` aliases with `%arg0` or `%arg1`,
38+
/// or insert a `bufferization.clone` operation to get a fresh buffer which we
39+
/// could assign ownership to.
40+
///
41+
/// The three states this class can represent form a lattice on a partial order:
42+
/// forall X in SSA values. uninitialized < unique(X) < unknown
43+
/// forall X, Y in SSA values.
44+
/// unique(X) == unique(Y) iff X and Y always evaluate to the same value
45+
/// unique(X) != unique(Y) otherwise
46+
class Ownership {
47+
public:
48+
/// Constructor that creates an 'Uninitialized' ownership. This is needed for
49+
/// default-construction when used in DenseMap.
50+
Ownership() = default;
51+
52+
/// Constructor that creates an 'Unique' ownership. This is a non-explicit
53+
/// constructor to allow implicit conversion from 'Value'.
54+
Ownership(Value indicator);
55+
56+
/// Get an ownership value in 'Unknown' state.
57+
static Ownership getUnknown();
58+
/// Get an ownership value in 'Unique' state with 'indicator' as parameter.
59+
static Ownership getUnique(Value indicator);
60+
/// Get an ownership value in 'Uninitialized' state.
61+
static Ownership getUninitialized();
62+
63+
/// Check if this ownership value is in the 'Uninitialized' state.
64+
bool isUninitialized() const;
65+
/// Check if this ownership value is in the 'Unique' state.
66+
bool isUnique() const;
67+
/// Check if this ownership value is in the 'Unknown' state.
68+
bool isUnknown() const;
69+
70+
/// If this ownership value is in 'Unique' state, this function can be used to
71+
/// get the indicator parameter. Using this function in any other state is UB.
72+
Value getIndicator() const;
73+
74+
/// Get the join of the two-element subset {this,other}. Does not modify
75+
/// 'this'.
76+
Ownership getCombined(Ownership other) const;
77+
78+
/// Modify 'this' ownership to be the join of the current 'this' and 'other'.
79+
void combine(Ownership other);
80+
81+
private:
82+
enum class State {
83+
Uninitialized,
84+
Unique,
85+
Unknown,
86+
};
87+
88+
// The indicator value is only relevant in the 'Unique' state.
89+
Value indicator;
90+
State state = State::Uninitialized;
91+
};
92+
93+
/// Options for BufferDeallocationOpInterface-based buffer deallocation.
94+
struct DeallocationOptions {
95+
// A pass option indicating whether private functions should be modified to
96+
// pass the ownership of MemRef values instead of adhering to the function
97+
// boundary ABI.
98+
bool privateFuncDynamicOwnership = false;
99+
};
100+
101+
/// This class collects all the state that we need to perform the buffer
102+
/// deallocation pass with associated helper functions such that we have easy
103+
/// access to it in the BufferDeallocationOpInterface implementations and the
104+
/// BufferDeallocation pass.
105+
class DeallocationState {
106+
public:
107+
DeallocationState(Operation *op);
108+
109+
// The state should always be passed by reference.
110+
DeallocationState(const DeallocationState &) = delete;
111+
112+
/// Small helper function to update the ownership map by taking the current
113+
/// ownership ('Uninitialized' state if not yet present), computing the join
114+
/// with the passed ownership and storing this new value in the map. By
115+
/// default, it will be performed for the block where 'owned' is defined. If
116+
/// the ownership of the given value should be updated for another block, the
117+
/// 'block' argument can be explicitly passed.
118+
void updateOwnership(Value memref, Ownership ownership,
119+
Block *block = nullptr);
120+
121+
/// Removes ownerships associated with all values in the passed range for
122+
/// 'block'.
123+
void resetOwnerships(ValueRange memrefs, Block *block);
124+
125+
/// Returns the ownership of 'memref' for the given basic block.
126+
Ownership getOwnership(Value memref, Block *block) const;
127+
128+
/// Remember the given 'memref' to deallocate it at the end of the 'block'.
129+
void addMemrefToDeallocate(Value memref, Block *block);
130+
131+
/// Forget about a MemRef that we originally wanted to deallocate at the end
132+
/// of 'block', possibly because it already gets deallocated before the end of
133+
/// the block.
134+
void dropMemrefToDeallocate(Value memref, Block *block);
135+
136+
/// Return a sorted list of MemRef values which are live at the start of the
137+
/// given block.
138+
void getLiveMemrefsIn(Block *block, SmallVectorImpl<Value> &memrefs);
139+
140+
/// Given an SSA value of MemRef type, this function queries the ownership and
141+
/// if it is not already in the 'Unique' state, potentially inserts IR to get
142+
/// a new SSA value, returned as the first element of the pair, which has
143+
/// 'Unique' ownership and can be used instead of the passed Value with the
144+
/// the ownership indicator returned as the second element of the pair.
145+
std::pair<Value, Value> getMemrefWithUniqueOwnership(OpBuilder &builder,
146+
Value memref);
147+
148+
/// Given two basic blocks and the values passed via block arguments to the
149+
/// destination block, compute the list of MemRefs that have to be retained in
150+
/// the 'fromBlock' to not run into a use-after-free situation.
151+
/// This list consists of the MemRefs in the successor operand list of the
152+
/// terminator and the MemRefs in the 'out' set of the liveness analysis
153+
/// intersected with the 'in' set of the destination block.
154+
///
155+
/// toRetain = filter(successorOperands + (liveOut(fromBlock) insersect
156+
/// liveIn(toBlock)), isMemRef)
157+
void getMemrefsToRetain(Block *fromBlock, Block *toBlock,
158+
ValueRange destOperands,
159+
SmallVectorImpl<Value> &toRetain) const;
160+
161+
/// For a given block, computes the list of MemRefs that potentially need to
162+
/// be deallocated at the end of that block. This list also contains values
163+
/// that have to be retained (and are thus part of the list returned by
164+
/// `getMemrefsToRetain`) and is computed by taking the MemRefs in the 'in'
165+
/// set of the liveness analysis of 'block' appended by the set of MemRefs
166+
/// allocated in 'block' itself and subtracted by the set of MemRefs
167+
/// deallocated in 'block'.
168+
/// Note that we don't have to take the intersection of the liveness 'in' set
169+
/// with the 'out' set of the predecessor block because a value that is in the
170+
/// 'in' set must be defined in an ancestor block that dominates all direct
171+
/// predecessors and thus the 'in' set of this block is a subset of the 'out'
172+
/// sets of each predecessor.
173+
///
174+
/// memrefs = filter((liveIn(block) U
175+
/// allocated(block) U arguments(block)) \ deallocated(block), isMemRef)
176+
///
177+
/// The list of conditions is then populated by querying the internal
178+
/// datastructures for the ownership value of that MemRef.
179+
LogicalResult
180+
getMemrefsAndConditionsToDeallocate(OpBuilder &builder, Location loc,
181+
Block *block,
182+
SmallVectorImpl<Value> &memrefs,
183+
SmallVectorImpl<Value> &conditions) const;
184+
185+
/// Returns the symbol cache to lookup functions from call operations to check
186+
/// attributes on the function operation.
187+
SymbolTableCollection *getSymbolTable() { return &symbolTable; }
188+
189+
private:
190+
// Symbol cache to lookup functions from call operations to check attributes
191+
// on the function operation.
192+
SymbolTableCollection symbolTable;
193+
194+
// Mapping from each SSA value with MemRef type to the associated ownership in
195+
// each block.
196+
DenseMap<std::pair<Value, Block *>, Ownership> ownershipMap;
197+
198+
// Collects the list of MemRef values that potentially need to be deallocated
199+
// per block. It is also fine (albeit not efficient) to add MemRef values that
200+
// don't have to be deallocated, but only when the ownership is not 'Unknown'.
201+
DenseMap<Block *, SmallVector<Value>> memrefsToDeallocatePerBlock;
202+
203+
// The underlying liveness analysis to compute fine grained information about
204+
// alloc and dealloc positions.
205+
Liveness liveness;
206+
};
207+
208+
} // namespace bufferization
209+
} // namespace mlir
210+
211+
//===----------------------------------------------------------------------===//
212+
// Buffer Deallocation Interface
213+
//===----------------------------------------------------------------------===//
214+
215+
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h.inc"
216+
217+
#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERDEALLOCATIONOPINTERFACE_H_
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
//===-- BufferDeallocationOpInterface.td -------------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef BUFFER_DEALLOCATION_OP_INTERFACE
10+
#define BUFFER_DEALLOCATION_OP_INTERFACE
11+
12+
include "mlir/IR/OpBase.td"
13+
14+
def BufferDeallocationOpInterface :
15+
OpInterface<"BufferDeallocationOpInterface"> {
16+
let description = [{
17+
An op interface for Buffer Deallocation. Ops that implement this interface
18+
can provide custom logic for computing the ownership of OpResults, modify
19+
the operation to properly pass the ownership values around, and insert
20+
`bufferization.dealloc` operations when necessary.
21+
}];
22+
let cppNamespace = "::mlir::bufferization";
23+
let methods = [
24+
InterfaceMethod<
25+
/*desc=*/[{
26+
This method takes the current deallocation state and transformation
27+
options and updates the deallocation state as necessary for the
28+
operation implementing this interface. It may also insert
29+
`bufferization.dealloc` operations and rebuild itself with different
30+
result types. For operations implementing this interface all other
31+
interface handlers (e.g., default handlers for interfaces like
32+
RegionBranchOpInterface, CallOpInterface, etc.) are skipped by the
33+
deallocation pass. On success, either the current operation or one of
34+
the newly inserted operations is returned from which on the driver
35+
should continue the processing. On failure, the deallocation pass
36+
will terminate. It is recommended to emit a useful error message in
37+
that case.
38+
}],
39+
/*retType=*/"FailureOr<Operation *>",
40+
/*methodName=*/"process",
41+
/*args=*/(ins "DeallocationState &":$state,
42+
"const DeallocationOptions &":$options)>
43+
];
44+
}
45+
46+
#endif // BUFFER_DEALLOCATION_OP_INTERFACE

mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
add_mlir_dialect(BufferizationOps bufferization)
22
add_mlir_doc(BufferizationOps BufferizationOps Dialects/ -gen-dialect-doc)
33
add_mlir_interface(AllocationOpInterface)
4+
add_mlir_interface(BufferDeallocationOpInterface)
45
add_mlir_interface(BufferizableOpInterface)
56
add_mlir_interface(SubsetInsertionOpInterface)
67

mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -121,14 +121,6 @@ class BufferPlacementTransformationBase {
121121
Liveness liveness;
122122
};
123123

124-
/// Compare two SSA values in a deterministic manner. Two block arguments are
125-
/// ordered by argument number, block arguments are always less than operation
126-
/// results, and operation results are ordered by the `isBeforeInBlock` order of
127-
/// their defining operation.
128-
struct ValueComparator {
129-
bool operator()(const Value &lhs, const Value &rhs) const;
130-
};
131-
132124
// Create a global op for the given tensor-valued constant in the program.
133125
// Globals are created lazily at the top of the enclosing ModuleOp with pretty
134126
// names. Duplicates are avoided.
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//===- BufferDeallocationOpInterfaceImpl.h ----------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_CONTROLFLOW_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
10+
#define MLIR_DIALECT_CONTROLFLOW_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
11+
12+
namespace mlir {
13+
14+
class DialectRegistry;
15+
16+
namespace cf {
17+
void registerBufferDeallocationOpInterfaceExternalModels(
18+
DialectRegistry &registry);
19+
} // namespace cf
20+
} // namespace mlir
21+
22+
#endif // MLIR_DIALECT_CONTROLFLOW_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H

mlir/include/mlir/InitAllDialects.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
3030
#include "mlir/Dialect/Complex/IR/Complex.h"
3131
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
32+
#include "mlir/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.h"
3233
#include "mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h"
3334
#include "mlir/Dialect/DLTI/DLTI.h"
3435
#include "mlir/Dialect/EmitC/IR/EmitC.h"
@@ -138,6 +139,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
138139
registry);
139140
builtin::registerCastOpInterfaceExternalModels(registry);
140141
cf::registerBufferizableOpInterfaceExternalModels(registry);
142+
cf::registerBufferDeallocationOpInterfaceExternalModels(registry);
141143
linalg::registerBufferizableOpInterfaceExternalModels(registry);
142144
linalg::registerTilingInterfaceExternalModels(registry);
143145
linalg::registerValueBoundsOpInterfaceExternalModels(registry);

0 commit comments

Comments
 (0)