Skip to content

Commit 66aa9a2

Browse files
authored
[mlir][bufferization] Implement BufferDeallocationOpInterface for scf.forall.in_parallel (#66351)
The scf.forall.in_parallel terminator operation has a nested graph region with the NoTerminator trait. Such regions are not supported by the default implementations. Therefore, this commit adds a specialized implementation for this operation which only covers the case where the nested region is empty. This is sufficient because, after bufferization, ops like tensor.parallel_insert_slice have already been converted to memref operations that reside only in the scf.forall body, so the nested region of scf.forall.in_parallel ends up empty.
1 parent 9e739fd commit 66aa9a2

File tree

5 files changed

+137
-0
lines changed

5 files changed

+137
-0
lines changed
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//===- BufferDeallocationOpInterfaceImpl.h ----------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_SCF_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
10+
#define MLIR_DIALECT_SCF_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
11+
12+
namespace mlir {
13+
14+
class DialectRegistry;
15+
16+
namespace scf {
17+
/// Registers the SCF dialect's external models for the
/// BufferDeallocationOpInterface with the given `registry`.
void registerBufferDeallocationOpInterfaceExternalModels(
18+
DialectRegistry &registry);
19+
} // namespace scf
20+
} // namespace mlir
21+
22+
#endif // MLIR_DIALECT_SCF_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H

mlir/include/mlir/InitAllDialects.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@
6060
#include "mlir/Dialect/Quant/QuantOps.h"
6161
#include "mlir/Dialect/SCF/IR/SCF.h"
6262
#include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"
63+
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
64+
#include "mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h"
6365
#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
6466
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
6567
#include "mlir/Dialect/Shape/IR/Shape.h"
@@ -149,6 +151,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
149151
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
150152
memref::registerValueBoundsOpInterfaceExternalModels(registry);
151153
memref::registerMemorySlotExternalModels(registry);
154+
scf::registerBufferDeallocationOpInterfaceExternalModels(registry);
152155
scf::registerBufferizableOpInterfaceExternalModels(registry);
153156
scf::registerValueBoundsOpInterfaceExternalModels(registry);
154157
shape::registerBufferizableOpInterfaceExternalModels(registry);
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
//===- BufferDeallocationOpInterfaceImpl.cpp ------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h"
10+
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
11+
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12+
#include "mlir/Dialect/SCF/IR/SCF.h"
13+
14+
using namespace mlir;
15+
using namespace mlir::bufferization;
16+
17+
namespace {
18+
/// The `scf.forall.in_parallel` terminator is special in a few ways:
19+
/// * It does not implement the BranchOpInterface or
20+
/// RegionBranchTerminatorOpInterface, but the ParallelCombiningOpInterface
21+
/// which is not supported by BufferDeallocation.
22+
/// * It has a graph-like region which only allows one specific tensor op
23+
/// * After bufferization the nested region is always empty
24+
/// For these reasons we provide custom deallocation logic via this external
25+
/// model.
26+
///
27+
/// Example:
28+
/// ```mlir
29+
/// scf.forall (%arg1) in (%arg0) {
30+
/// %alloc = memref.alloc() : memref<2xf32>
31+
/// ...
32+
/// <implicit in_parallel terminator here>
33+
/// }
34+
/// ```
35+
/// gets transformed to
36+
/// ```mlir
37+
/// scf.forall (%arg1) in (%arg0) {
38+
/// %alloc = memref.alloc() : memref<2xf32>
39+
/// ...
40+
/// bufferization.dealloc (%alloc : memref<2xf32>) if (%true)
41+
/// <implicit in_parallel terminator here>
42+
/// }
43+
/// ```
44+
struct InParallelOpInterface
45+
: public BufferDeallocationOpInterface::ExternalModel<InParallelOpInterface,
46+
scf::InParallelOp> {
47+
/// Handles buffer deallocation at an `scf.forall.in_parallel` terminator.
///
/// Emits an error on `op` and bails out if the terminator's body is not
/// empty. Otherwise, queries `state` for the memrefs (and their conditions)
/// to deallocate and for the memrefs to retain for the block containing `op`.
/// If there is anything to deallocate or retain, a `bufferization.dealloc`
/// is created with a builder anchored at `op`, and the ownerships of the
/// retained values are replaced by the dealloc's updated conditions.
///
/// Returns `op` on success. `options` is not used by this implementation.
FailureOr<Operation *> process(Operation *op, DeallocationState &state,
48+
const DeallocationOptions &options) const {
49+
auto inParallelOp = cast<scf::InParallelOp>(op);
50+
OpBuilder builder(op);
51+
// Only the empty-body form of the terminator is handled; anything else is
// reported as an error on the op.
if (!inParallelOp.getBody()->empty())
52+
return op->emitError("only supported when nested region is empty");
53+
54+
// Collect the values to deallocate and retain and use them to create the
55+
// dealloc operation.
56+
Block *block = op->getBlock();
57+
SmallVector<Value> memrefs, conditions, toRetain;
58+
if (failed(state.getMemrefsAndConditionsToDeallocate(
59+
builder, op->getLoc(), block, memrefs, conditions)))
60+
return failure();
61+
62+
// NOTE(review): no destination block and no forwarded operands are passed
// here, presumably because this terminator has no successor block -- confirm.
state.getMemrefsToRetain(block, /*toBlock=*/nullptr, {}, toRetain);
63+
// Nothing to deallocate and nothing to retain: no dealloc op is created.
if (memrefs.empty() && toRetain.empty())
64+
return op;
65+
66+
auto deallocOp = builder.create<bufferization::DeallocOp>(
67+
op->getLoc(), memrefs, conditions, toRetain);
68+
69+
// We want to replace the current ownership of the retained values with the
70+
// result values of the dealloc operation as they are always unique.
71+
state.resetOwnerships(deallocOp.getRetained(), block);
72+
// Pair each retained value with the matching updated condition of the dealloc.
for (auto [retained, ownership] :
73+
llvm::zip(deallocOp.getRetained(), deallocOp.getUpdatedConditions()))
74+
state.updateOwnership(retained, ownership, block);
75+
76+
return op;
77+
}
78+
};
79+
80+
} // namespace
81+
82+
/// Adds an extension to `registry` whose callback attaches the
/// `InParallelOpInterface` external model to `scf::InParallelOp` in the
/// given context.
void mlir::scf::registerBufferDeallocationOpInterfaceExternalModels(
83+
DialectRegistry &registry) {
84+
// The unary `+` turns the captureless lambda into a plain function pointer.
// The `dialect` argument is not used by the callback.
// NOTE(review): presumably the callback runs once the SCF dialect is loaded
// into a context -- confirm against the DialectRegistry documentation.
registry.addExtension(+[](MLIRContext *ctx, SCFDialect *dialect) {
85+
InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
86+
});
87+
}

mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_mlir_dialect_library(MLIRSCFTransforms
2+
BufferDeallocationOpInterfaceImpl.cpp
23
BufferizableOpInterfaceImpl.cpp
34
Bufferize.cpp
45
ForToWhile.cpp
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation \
2+
// RUN: -buffer-deallocation-simplification -split-input-file %s | FileCheck %s
3+
4+
// Test input: one buffer is allocated outside an `scf.forall` and another one
// inside its body. The loop body has no explicit terminator, and neither
// buffer is explicitly freed in this input.
func.func @parallel_insert_slice(%arg0: index) {
5+
%c0 = arith.constant 0 : index
6+
%alloc = memref.alloc() : memref<2xf32>
7+
scf.forall (%arg1) in (%arg0) {
8+
%alloc0 = memref.alloc() : memref<2xf32>
9+
%0 = memref.load %alloc[%c0] : memref<2xf32>
10+
linalg.fill ins(%0 : f32) outs(%alloc0 : memref<2xf32>)
11+
}
12+
return
13+
}
14+
15+
// CHECK-LABEL: func @parallel_insert_slice
16+
// CHECK-SAME: (%arg0: index)
17+
// CHECK: [[ALLOC0:%.+]] = memref.alloc(
18+
// CHECK: scf.forall
19+
// CHECK: [[ALLOC1:%.+]] = memref.alloc(
20+
// CHECK: bufferization.dealloc ([[ALLOC1]] : memref<2xf32>) if (%true
21+
// CHECK-NOT: retain
22+
// CHECK: }
23+
// CHECK: bufferization.dealloc ([[ALLOC0]] : memref<2xf32>) if (%true
24+
// CHECK-NOT: retain

0 commit comments

Comments
 (0)