Skip to content

Commit 6c654b5

Browse files
[mlir][linalg][bufferize] Support std.select bufferization
This op is an example for how to deal with ops who's OpResult may aliasing with one of multiple OpOperands. Differential Revision: https://reviews.llvm.org/D116868
1 parent 5642ce5 commit 6c654b5

File tree

17 files changed

+333
-60
lines changed

17 files changed

+333
-60
lines changed

mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- LinalgInterfaceImpl.h - Linalg Impl. of BufferizableOpInterface ----===//
1+
//===- AffineInterfaceImpl.h - Affine Impl. of BufferizableOpInterface ----===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.

mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -346,18 +346,18 @@ class BufferizationState {
346346
/// In the above example, Values with a star satisfy the condition. When
347347
/// starting the traversal from Value 1, the resulting SetVector is:
348348
/// { 2, 7, 8, 5 }
349-
llvm::SetVector<Value> findValueInReverseUseDefChain(
349+
SetVector<Value> findValueInReverseUseDefChain(
350350
Value value, llvm::function_ref<bool(Value)> condition) const;
351351

352-
/// Find the Value of the last preceding write of a given Value.
352+
/// Find the Values of the last preceding write of a given Value.
353353
///
354354
/// Note: Unknown ops are handled conservatively and assumed to be writes.
355355
/// Furthermore, BlockArguments are also assumed to be writes. There is no
356356
/// analysis across block boundaries.
357357
///
358358
/// Note: When reaching an end of the reverse SSA use-def chain, that value
359359
/// is returned regardless of whether it is a memory write or not.
360-
Value findLastPrecedingWrite(Value value) const;
360+
SetVector<Value> findLastPrecedingWrite(Value value) const;
361361

362362
/// Creates a memref allocation.
363363
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,

mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ runComprehensiveBufferize(ModuleOp moduleOp,
3131

3232
namespace std_ext {
3333

34-
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
34+
void registerModuleBufferizationExternalModels(DialectRegistry &registry);
3535

3636
} // namespace std_ext
3737
} // namespace comprehensive_bufferize
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===- StdInterfaceImpl.h - Standard Impl. of BufferizableOpInterface- ----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H
10+
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H
11+
12+
namespace mlir {
13+
14+
class DialectRegistry;
15+
16+
namespace linalg {
17+
namespace comprehensive_bufferize {
18+
namespace std_ext {
19+
20+
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
21+
22+
} // namespace std_ext
23+
} // namespace comprehensive_bufferize
24+
} // namespace linalg
25+
} // namespace mlir
26+
27+
#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -305,26 +305,18 @@ llvm::SetVector<Value> mlir::linalg::comprehensive_bufferize::
305305
return result;
306306
}
307307

308-
// Find the Value of the last preceding write of a given Value.
309-
Value mlir::linalg::comprehensive_bufferize::BufferizationState::
310-
findLastPrecedingWrite(Value value) const {
311-
SetVector<Value> result =
312-
findValueInReverseUseDefChain(value, [&](Value value) {
313-
Operation *op = value.getDefiningOp();
314-
if (!op)
315-
return true;
316-
auto bufferizableOp = options.dynCastBufferizableOp(op);
317-
if (!bufferizableOp)
318-
return true;
319-
return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
320-
});
321-
322-
// To simplify the analysis, `scf.if` ops are considered memory writes. There
323-
// are currently no other ops where one OpResult may alias with multiple
324-
// OpOperands. Therefore, this function should return exactly one result at
325-
// the moment.
326-
assert(result.size() == 1 && "expected exactly one result");
327-
return result.front();
308+
// Find the Values of the last preceding write of a given Value.
309+
llvm::SetVector<Value> mlir::linalg::comprehensive_bufferize::
310+
BufferizationState::findLastPrecedingWrite(Value value) const {
311+
return findValueInReverseUseDefChain(value, [&](Value value) {
312+
Operation *op = value.getDefiningOp();
313+
if (!op)
314+
return true;
315+
auto bufferizableOp = options.dynCastBufferizableOp(op);
316+
if (!bufferizableOp)
317+
return true;
318+
return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
319+
});
328320
}
329321

330322
mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
@@ -404,15 +396,19 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer(
404396
createAlloc(rewriter, loc, operandBuffer, options.createDeallocs);
405397
if (failed(resultBuffer))
406398
return failure();
407-
// Do not copy if the last preceding write of `operand` is an op that does
399+
// Do not copy if the last preceding writes of `operand` are ops that do
408400
// not write (skipping ops that merely create aliases). E.g., InitTensorOp.
409401
// Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
410402
// use-def chain, it returns that value, regardless of whether it is a
411403
// memory write or not.
412-
Value lastWrite = findLastPrecedingWrite(operand);
413-
if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
414-
if (!bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(), *this))
415-
return resultBuffer;
404+
SetVector<Value> lastWrites = findLastPrecedingWrite(operand);
405+
if (llvm::none_of(lastWrites, [&](Value lastWrite) {
406+
if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
407+
return bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(),
408+
*this);
409+
return true;
410+
}))
411+
return resultBuffer;
416412
// Do not copy if the copied data is never read.
417413
OpResult aliasingOpResult = getAliasingOpResult(opOperand);
418414
if (aliasingOpResult && !bufferizesToMemoryRead(opOperand) &&

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ set(LLVM_OPTIONAL_SOURCES
77
LinalgInterfaceImpl.cpp
88
ModuleBufferization.cpp
99
SCFInterfaceImpl.cpp
10+
StdInterfaceImpl.cpp
1011
TensorInterfaceImpl.cpp
1112
VectorInterfaceImpl.cpp
1213
)
@@ -61,6 +62,14 @@ add_mlir_dialect_library(MLIRSCFBufferizableOpInterfaceImpl
6162
MLIRSCF
6263
)
6364

65+
add_mlir_dialect_library(MLIRStdBufferizableOpInterfaceImpl
66+
StdInterfaceImpl.cpp
67+
68+
LINK_LIBS PUBLIC
69+
MLIRBufferizableOpInterface
70+
MLIRStandard
71+
)
72+
6473
add_mlir_dialect_library(MLIRTensorBufferizableOpInterfaceImpl
6574
TensorInterfaceImpl.cpp
6675

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,8 @@ static bool hasReadAfterWriteInterference(
219219
for (OpOperand *uRead : usesRead) {
220220
Operation *readingOp = uRead->getOwner();
221221

222-
// Find most recent write of uRead by following the SSA use-def chain. E.g.:
222+
// Find most recent writes of uRead by following the SSA use-def chain.
223+
// E.g.:
223224
//
224225
// %0 = "writing_op"(%t) : tensor<?x32> -> tensor<?xf32>
225226
// %1 = "aliasing_op"(%0) : tensor<?x32> -> tensor<?xf32>
@@ -228,7 +229,7 @@ static bool hasReadAfterWriteInterference(
228229
// In the above example, if uRead is the OpOperand of reading_op, lastWrite
229230
// is %0. Note that operations that create an alias but do not write (such
230231
// as ExtractSliceOp) are skipped.
231-
Value lastWrite = state.findLastPrecedingWrite(uRead->get());
232+
SetVector<Value> lastWrites = state.findLastPrecedingWrite(uRead->get());
232233

233234
// Look for conflicting memory writes. Potential conflicts are writes to an
234235
// alias that have been decided to bufferize inplace.
@@ -265,35 +266,38 @@ static bool hasReadAfterWriteInterference(
265266
if (insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp))
266267
continue;
267268

268-
// No conflict if the conflicting write happens before the last
269-
// write.
270-
if (Operation *writingOp = lastWrite.getDefiningOp()) {
271-
if (happensBefore(conflictingWritingOp, writingOp, domInfo))
272-
// conflictingWritingOp happens before writingOp. No conflict.
273-
continue;
274-
// No conflict if conflictingWritingOp is contained in writingOp.
275-
if (writingOp->isProperAncestor(conflictingWritingOp))
276-
continue;
277-
} else {
278-
auto bbArg = lastWrite.cast<BlockArgument>();
279-
Block *block = bbArg.getOwner();
280-
if (!block->findAncestorOpInBlock(*conflictingWritingOp))
281-
// conflictingWritingOp happens outside of the block. No
282-
// conflict.
283-
continue;
284-
}
269+
// Check all possible last writes.
270+
for (Value lastWrite : lastWrites) {
271+
// No conflict if the conflicting write happens before the last
272+
// write.
273+
if (Operation *writingOp = lastWrite.getDefiningOp()) {
274+
if (happensBefore(conflictingWritingOp, writingOp, domInfo))
275+
// conflictingWritingOp happens before writingOp. No conflict.
276+
continue;
277+
// No conflict if conflictingWritingOp is contained in writingOp.
278+
if (writingOp->isProperAncestor(conflictingWritingOp))
279+
continue;
280+
} else {
281+
auto bbArg = lastWrite.cast<BlockArgument>();
282+
Block *block = bbArg.getOwner();
283+
if (!block->findAncestorOpInBlock(*conflictingWritingOp))
284+
// conflictingWritingOp happens outside of the block. No
285+
// conflict.
286+
continue;
287+
}
285288

286-
// No conflict if the conflicting write and the last write are the same
287-
// use.
288-
if (state.getAliasingOpResult(*uConflictingWrite) == lastWrite)
289-
continue;
289+
// No conflict if the conflicting write and the last write are the same
290+
// use.
291+
if (state.getAliasingOpResult(*uConflictingWrite) == lastWrite)
292+
continue;
290293

291-
// All requirements are met. Conflict found!
294+
// All requirements are met. Conflict found!
292295

293-
if (options.printConflicts)
294-
annotateConflict(uRead, uConflictingWrite, lastWrite);
296+
if (options.printConflicts)
297+
annotateConflict(uRead, uConflictingWrite, lastWrite);
295298

296-
return true;
299+
return true;
300+
}
297301
}
298302
}
299303

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -938,7 +938,7 @@ struct FuncOpInterface
938938
} // namespace mlir
939939

940940
void mlir::linalg::comprehensive_bufferize::std_ext::
941-
registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
941+
registerModuleBufferizationExternalModels(DialectRegistry &registry) {
942942
registry.addOpInterface<CallOp, std_ext::CallOpInterface>();
943943
registry.addOpInterface<ReturnOp, std_ext::ReturnOpInterface>();
944944
registry.addOpInterface<FuncOp, std_ext::FuncOpInterface>();
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
//===- StdInterfaceImpl.cpp - Standard Impl. of BufferizableOpInterface ---===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
10+
11+
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
12+
#include "mlir/Dialect/StandardOps/IR/Ops.h"
13+
#include "mlir/IR/Dialect.h"
14+
#include "mlir/IR/Operation.h"
15+
16+
namespace mlir {
17+
namespace linalg {
18+
namespace comprehensive_bufferize {
19+
namespace std_ext {
20+
21+
/// Bufferization of std.select. Just replace the operands.
22+
struct SelectOpInterface
23+
: public BufferizableOpInterface::ExternalModel<SelectOpInterface,
24+
SelectOp> {
25+
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
26+
const BufferizationState &state) const {
27+
return false;
28+
}
29+
30+
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
31+
const BufferizationState &state) const {
32+
return false;
33+
}
34+
35+
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
36+
const BufferizationState &state) const {
37+
return op->getOpResult(0) /*result*/;
38+
}
39+
40+
SmallVector<OpOperand *>
41+
getAliasingOpOperand(Operation *op, OpResult opResult,
42+
const BufferizationState &state) const {
43+
return {&op->getOpOperand(1) /*true_value*/,
44+
&op->getOpOperand(2) /*false_value*/};
45+
}
46+
47+
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
48+
const BufferizationState &state) const {
49+
auto selectOp = cast<SelectOp>(op);
50+
// `getBuffer` introduces copies if an OpOperand bufferizes out-of-place.
51+
// TODO: It would be more efficient to copy the result of the `select` op
52+
// instead of its OpOperands. In the worst case, 2 copies are inserted at
53+
// the moment (one for each tensor). When copying the op result, only one
54+
// copy would be needed.
55+
Value trueBuffer =
56+
*state.getBuffer(rewriter, selectOp->getOpOperand(1) /*true_value*/);
57+
Value falseBuffer =
58+
*state.getBuffer(rewriter, selectOp->getOpOperand(2) /*false_value*/);
59+
replaceOpWithNewBufferizedOp<SelectOp>(
60+
rewriter, op, selectOp.getCondition(), trueBuffer, falseBuffer);
61+
return success();
62+
}
63+
64+
BufferRelation bufferRelation(Operation *op, OpResult opResult,
65+
const BufferizationAliasInfo &aliasInfo,
66+
const BufferizationState &state) const {
67+
return BufferRelation::None;
68+
}
69+
};
70+
71+
} // namespace std_ext
72+
} // namespace comprehensive_bufferize
73+
} // namespace linalg
74+
} // namespace mlir
75+
76+
void mlir::linalg::comprehensive_bufferize::std_ext::
77+
registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
78+
registry.addOpInterface<SelectOp, std_ext::SelectOpInterface>();
79+
}

mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
4949
MLIRSCF
5050
MLIRSCFBufferizableOpInterfaceImpl
5151
MLIRSCFTransforms
52+
MLIRStdBufferizableOpInterfaceImpl
5253
MLIRPass
5354
MLIRStandard
5455
MLIRStandardOpsTransforms

mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
1818
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
1919
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
20+
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
2021
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h"
2122
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
2223
#include "mlir/Dialect/Linalg/Passes.h"
@@ -51,6 +52,7 @@ struct LinalgComprehensiveModuleBufferize
5152
bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry);
5253
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
5354
scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
55+
std_ext::registerModuleBufferizationExternalModels(registry);
5456
std_ext::registerBufferizableOpInterfaceExternalModels(registry);
5557
tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
5658
vector_ext::registerBufferizableOpInterfaceExternalModels(registry);

0 commit comments

Comments
 (0)