@@ -23,6 +23,8 @@ namespace bufferization {
2323using namespace mlir ;
2424using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn;
2525using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
26+ using AllocDynamicSizesMap =
27+ llvm::DenseMap<func::FuncOp, SmallVector<SmallVector<Value>>>;
2628
2729// / Return `true` if the given MemRef type has a fully dynamic layout.
2830static bool hasFullyDynamicLayoutMap (MemRefType type) {
@@ -43,6 +45,50 @@ static bool hasStaticIdentityLayout(MemRefType type) {
4345 return type.getLayout ().isIdentity ();
4446}
4547
48+ // / Return the dynamic shapes of the `memref` based on the defining op. If the
49+ // / complete dynamic shape fails to be captured, return an empty value.
50+ // / Currently, only function block arguments are supported for capturing.
51+ static SmallVector<Value> getDynamicSize (Value memref, func::FuncOp funcOp) {
52+ Operation *defOp = memref.getDefiningOp ();
53+ if (!defOp)
54+ return {};
55+ auto operands = defOp->getOperands ();
56+ SmallVector<Value> dynamicSizes;
57+ for (Value size : operands) {
58+ if (!isa<IndexType>(size.getType ()))
59+ continue ;
60+
61+ BlockArgument sizeSrc = dyn_cast<BlockArgument>(size);
62+ if (!sizeSrc)
63+ return {};
64+ auto arguments = funcOp.getArguments ();
65+ auto iter = llvm::find (arguments, sizeSrc);
66+ if (iter == arguments.end ())
67+ return {};
68+ dynamicSizes.push_back (*iter);
69+ }
70+ return dynamicSizes;
71+ }
72+
73+ // / Returns the dynamic sizes at the callee, through the call relationship
74+ // / between the caller and callee.
75+ static SmallVector<Value> mapDynamicSizeAtCaller (func::CallOp call,
76+ func::FuncOp callee,
77+ ValueRange dynamicSizes) {
78+ SmallVector<Value> mappedDynamicSizes;
79+ for (Value size : dynamicSizes) {
80+ for (auto [src, dst] :
81+ llvm::zip_first (call.getOperands (), callee.getArguments ())) {
82+ if (size != dst)
83+ continue ;
84+ mappedDynamicSizes.push_back (src);
85+ }
86+ }
87+ assert (mappedDynamicSizes.size () == dynamicSizes.size () &&
88+ " could not find all dynamic sizes" );
89+ return mappedDynamicSizes;
90+ }
91+
4692// Updates the func op and entry block.
4793//
4894// Any args appended to the entry block are added to `appendedEntryArgs`.
@@ -109,6 +155,7 @@ updateFuncOp(func::FuncOp func,
109155// the given out-params.
110156static LogicalResult
111157updateReturnOps (func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
158+ AllocDynamicSizesMap &map,
112159 const bufferization::BufferResultsToOutParamsOpts &options) {
113160 auto res = func.walk ([&](func::ReturnOp op) {
114161 SmallVector<Value, 6 > copyIntoOutParams;
@@ -120,12 +167,22 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
120167 keepAsReturnOperands.push_back (operand);
121168 }
122169 OpBuilder builder (op);
170+ SmallVector<SmallVector<Value>> dynamicSizes;
123171 for (auto [orig, arg] : llvm::zip (copyIntoOutParams, appendedEntryArgs)) {
124- if (options.hoistStaticAllocs &&
172+ bool hoistStaticAllocs =
173+ options.hoistStaticAllocs &&
174+ cast<MemRefType>(orig.getType ()).hasStaticShape ();
175+ bool hoistDynamicAllocs =
176+ options.hoistDynamicAllocs &&
177+ !cast<MemRefType>(orig.getType ()).hasStaticShape ();
178+ if ((hoistStaticAllocs || hoistDynamicAllocs) &&
125179 isa_and_nonnull<bufferization::AllocationOpInterface>(
126- orig.getDefiningOp ()) &&
127- mlir::cast<MemRefType>(orig.getType ()).hasStaticShape ()) {
180+ orig.getDefiningOp ())) {
128181 orig.replaceAllUsesWith (arg);
182+ if (hoistDynamicAllocs) {
183+ SmallVector<Value> dynamicSize = getDynamicSize (orig, func);
184+ dynamicSizes.push_back (dynamicSize);
185+ }
129186 orig.getDefiningOp ()->erase ();
130187 } else {
131188 if (failed (options.memCpyFn (builder, op.getLoc (), orig, arg)))
@@ -134,6 +191,10 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
134191 }
135192 func::ReturnOp::create (builder, op.getLoc (), keepAsReturnOperands);
136193 op.erase ();
194+ auto dynamicSizePair =
195+ std::pair<func::FuncOp, SmallVector<SmallVector<Value>>>(func,
196+ dynamicSizes);
197+ map.insert (dynamicSizePair);
137198 return WalkResult::advance ();
138199 });
139200 return failure (res.wasInterrupted ());
@@ -142,7 +203,7 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
142203// Updates all CallOps in the scope of the given ModuleOp by allocating
143204// temporary buffers for newly introduced out params.
144205static LogicalResult
145- updateCalls (ModuleOp module ,
206+ updateCalls (ModuleOp module , const AllocDynamicSizesMap &map,
146207 const bufferization::BufferResultsToOutParamsOpts &options) {
147208 bool didFail = false ;
148209 SymbolTable symtab (module );
@@ -166,8 +227,15 @@ updateCalls(ModuleOp module,
166227 }
167228 SmallVector<Value, 6 > outParams;
168229 OpBuilder builder (op);
230+ SmallVector<SmallVector<Value>> dynamicSizes = map.lookup (callee);
231+ size_t dynamicSizesIndex = 0 ;
169232 for (Value memref : replaceWithOutParams) {
170- if (!cast<MemRefType>(memref.getType ()).hasStaticShape ()) {
233+ SmallVector<Value> dynamicSize = dynamicSizes.size () > dynamicSizesIndex
234+ ? dynamicSizes[dynamicSizesIndex]
235+ : SmallVector<Value>();
236+ bool memrefStaticShape =
237+ cast<MemRefType>(memref.getType ()).hasStaticShape ();
238+ if (!memrefStaticShape && dynamicSize.empty ()) {
171239 op.emitError ()
172240 << " cannot create out param for dynamically shaped result" ;
173241 didFail = true ;
@@ -177,8 +245,15 @@ updateCalls(ModuleOp module,
177245 auto allocType =
178246 MemRefType::get (memrefType.getShape (), memrefType.getElementType (),
179247 AffineMap (), memrefType.getMemorySpace ());
248+
249+ if (memrefStaticShape) {
250+ dynamicSize = {};
251+ } else {
252+ ++dynamicSizesIndex;
253+ dynamicSize = mapDynamicSizeAtCaller (op, callee, dynamicSize);
254+ }
180255 auto maybeOutParam =
181- options.allocationFn (builder, op.getLoc (), allocType);
256+ options.allocationFn (builder, op.getLoc (), allocType, dynamicSize );
182257 if (failed (maybeOutParam)) {
183258 op.emitError () << " failed to create allocation op" ;
184259 didFail = true ;
@@ -213,6 +288,9 @@ updateCalls(ModuleOp module,
213288LogicalResult mlir::bufferization::promoteBufferResultsToOutParams (
214289 ModuleOp module ,
215290 const bufferization::BufferResultsToOutParamsOpts &options) {
291+ // It maps the shape source of the dynamic shape memref returned by each
292+ // function.
293+ AllocDynamicSizesMap map;
216294 for (auto func : module .getOps <func::FuncOp>()) {
217295 if (!options.filterFn (&func))
218296 continue ;
@@ -222,11 +300,11 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
222300 return failure ();
223301 if (func.isExternal ())
224302 continue ;
225- if (failed (updateReturnOps (func, appendedEntryArgs, options))) {
303+ if (failed (updateReturnOps (func, appendedEntryArgs, map, options))) {
226304 return failure ();
227305 }
228306 }
229- if (failed (updateCalls (module , options)))
307+ if (failed (updateCalls (module , map, options)))
230308 return failure ();
231309 return success ();
232310}
@@ -243,6 +321,8 @@ struct BufferResultsToOutParamsPass
243321 options.addResultAttribute = true ;
244322 if (hoistStaticAllocs)
245323 options.hoistStaticAllocs = true ;
324+ if (hoistDynamicAllocs)
325+ options.hoistDynamicAllocs = true ;
246326
247327 if (failed (bufferization::promoteBufferResultsToOutParams (getOperation (),
248328 options)))
0 commit comments