@@ -75,7 +75,7 @@ using namespace mlir::bufferization;
using namespace mlir::bufferization::func_ext;

/// A mapping of FuncOps to their callers.
-using FuncCallerMap = DenseMap<FunctionOpInterface, DenseSet<Operation *>>;
+using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>;

/// Get or create FuncAnalysisState.
static FuncAnalysisState &
@@ -88,11 +88,10 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {

/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
-static Operation *getAssumedUniqueReturnOp(FunctionOpInterface funcOp) {
-  Operation *returnOp = nullptr;
-  for (Block &b : funcOp.getFunctionBody()) {
-    auto candidateOp = b.getTerminator();
-    if (candidateOp && candidateOp->hasTrait<OpTrait::ReturnLike>()) {
+static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
+  func::ReturnOp returnOp;
+  for (Block &b : funcOp.getBody()) {
+    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
      if (returnOp)
        return nullptr;
      returnOp = candidateOp;
@@ -127,16 +126,16 @@ static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
/// Store function BlockArguments that are equivalent to/aliasing a returned
/// value in FuncAnalysisState.
static LogicalResult
-aliasingFuncOpBBArgsAnalysis(FunctionOpInterface funcOp,
-                             OneShotAnalysisState &state,
+aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
                             FuncAnalysisState &funcState) {
-  if (funcOp.getFunctionBody().empty()) {
+  if (funcOp.getBody().empty()) {
    // No function body available. Conservatively assume that every tensor
    // return value may alias with any tensor bbArg.
-    for (const auto &inputIt : llvm::enumerate(funcOp.getArgumentTypes())) {
+    FunctionType type = funcOp.getFunctionType();
+    for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
      if (!isa<TensorType>(inputIt.value()))
        continue;
-      for (const auto &resultIt : llvm::enumerate(funcOp.getResultTypes())) {
+      for (const auto &resultIt : llvm::enumerate(type.getResults())) {
        if (!isa<TensorType>(resultIt.value()))
          continue;
        int64_t returnIdx = resultIt.index();
@@ -148,7 +147,7 @@ aliasingFuncOpBBArgsAnalysis(FunctionOpInterface funcOp,
  }

  // Support only single return-terminated block in the function.
-  Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
+  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
  assert(returnOp && "expected func with single return op");

  for (OpOperand &returnVal : returnOp->getOpOperands())
@@ -169,8 +168,8 @@ aliasingFuncOpBBArgsAnalysis(FunctionOpInterface funcOp,
  return success();
}

-static void annotateFuncArgAccess(FunctionOpInterface funcOp, int64_t idx,
-                                  bool isRead, bool isWritten) {
+static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
+                                  bool isWritten) {
  OpBuilder b(funcOp.getContext());
  Attribute accessType;
  if (isRead && isWritten) {
@@ -190,12 +189,12 @@ static void annotateFuncArgAccess(FunctionOpInterface funcOp, int64_t idx,
/// function with unknown ops, we conservatively assume that such ops bufferize
/// to a read + write.
static LogicalResult
-funcOpBbArgReadWriteAnalysis(FunctionOpInterface funcOp,
-                             OneShotAnalysisState &state,
+funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
                             FuncAnalysisState &funcState) {
-  for (int64_t idx = 0, e = funcOp.getNumArguments(); idx < e; ++idx) {
+  for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
+       ++idx) {
    // Skip non-tensor arguments.
-    if (!isa<TensorType>(funcOp.getArgumentTypes()[idx]))
+    if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
      continue;
    bool isRead;
    bool isWritten;
@@ -205,7 +204,7 @@ funcOpBbArgReadWriteAnalysis(FunctionOpInterface funcOp,
      StringRef str = accessAttr.getValue();
      isRead = str == "read" || str == "read-write";
      isWritten = str == "write" || str == "read-write";
-    } else if (funcOp.getFunctionBody().empty()) {
+    } else if (funcOp.getBody().empty()) {
      // If the function has no body, conservatively assume that all args are
      // read + written.
      isRead = true;
@@ -231,32 +230,33 @@ funcOpBbArgReadWriteAnalysis(FunctionOpInterface funcOp,

/// Remove bufferization attributes on FuncOp arguments.
static void removeBufferizationAttributes(BlockArgument bbArg) {
-  auto funcOp = cast<FunctionOpInterface>(bbArg.getOwner()->getParentOp());
+  auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kBufferLayoutAttrName);
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kWritableAttrName);
}

-static FunctionOpInterface getCalledFunction(CallOpInterface callOp) {
+/// Return the func::FuncOp called by `callOp`.
+static func::FuncOp getCalledFunction(func::CallOp callOp) {
  SymbolRefAttr sym =
      llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
  if (!sym)
    return nullptr;
-  return dyn_cast_or_null<FunctionOpInterface>(
+  return dyn_cast_or_null<func::FuncOp>(
      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}

/// Gather equivalence info of CallOps.
/// Note: This only adds new equivalence info if the called function was already
/// analyzed.
// TODO: This does not handle cyclic function call graphs etc.
-static void equivalenceAnalysis(FunctionOpInterface funcOp,
+static void equivalenceAnalysis(func::FuncOp funcOp,
                                OneShotAnalysisState &state,
                                FuncAnalysisState &funcState) {
-  funcOp->walk([&](CallOpInterface callOp) {
-    FunctionOpInterface calledFunction = getCalledFunction(callOp);
-    assert(calledFunction && "could not retrieved called FunctionOpInterface");
+  funcOp->walk([&](func::CallOp callOp) {
+    func::FuncOp calledFunction = getCalledFunction(callOp);
+    assert(calledFunction && "could not retrieved called func::FuncOp");

    // No equivalence info available for the called function.
    if (!funcState.equivalentFuncArgs.count(calledFunction))
@@ -267,7 +267,7 @@ static void equivalenceAnalysis(FunctionOpInterface funcOp,
      int64_t bbargIdx = it.second;
      if (!state.isInPlace(callOp->getOpOperand(bbargIdx)))
        continue;
-      Value returnVal = callOp->getResult(returnIdx);
+      Value returnVal = callOp.getResult(returnIdx);
      Value argVal = callOp->getOperand(bbargIdx);
      state.unionEquivalenceClasses(returnVal, argVal);
    }
@@ -277,9 +277,11 @@ static void equivalenceAnalysis(FunctionOpInterface funcOp,
}

/// Return "true" if the given function signature has tensor semantics.
-static bool hasTensorSignature(FunctionOpInterface funcOp) {
-  return llvm::any_of(funcOp.getArgumentTypes(), llvm::IsaPred<TensorType>) ||
-         llvm::any_of(funcOp.getResultTypes(), llvm::IsaPred<TensorType>);
+static bool hasTensorSignature(func::FuncOp funcOp) {
+  return llvm::any_of(funcOp.getFunctionType().getInputs(),
+                      llvm::IsaPred<TensorType>) ||
+         llvm::any_of(funcOp.getFunctionType().getResults(),
+                      llvm::IsaPred<TensorType>);
}

/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
@@ -289,16 +291,16 @@ static bool hasTensorSignature(FunctionOpInterface funcOp) {
/// retrieve the called FuncOp from any func::CallOp.
static LogicalResult
getFuncOpsOrderedByCalls(ModuleOp moduleOp,
-                         SmallVectorImpl<FunctionOpInterface> &orderedFuncOps,
+                         SmallVectorImpl<func::FuncOp> &orderedFuncOps,
                         FuncCallerMap &callerMap) {
  // For each FuncOp, the set of functions called by it (i.e. the union of
  // symbols of all nested func::CallOp).
-  DenseMap<FunctionOpInterface, DenseSet<FunctionOpInterface>> calledBy;
+  DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
  // For each FuncOp, the number of func::CallOp it contains.
-  DenseMap<FunctionOpInterface, unsigned> numberCallOpsContainedInFuncOp;
-  WalkResult res = moduleOp.walk([&](FunctionOpInterface funcOp) -> WalkResult {
-    if (!funcOp.getFunctionBody().empty()) {
-      Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
+  DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
+  WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
+    if (!funcOp.getBody().empty()) {
+      func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
      if (!returnOp)
        return funcOp->emitError()
               << "cannot bufferize a FuncOp with tensors and "
@@ -307,10 +309,9 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,

    // Collect function calls and populate the caller map.
    numberCallOpsContainedInFuncOp[funcOp] = 0;
-    return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
-      FunctionOpInterface calledFunction = getCalledFunction(callOp);
-      assert(calledFunction &&
-             "could not retrieved called FunctionOpInterface");
+    return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
+      func::FuncOp calledFunction = getCalledFunction(callOp);
+      assert(calledFunction && "could not retrieved called func::FuncOp");
      // If the called function does not have any tensors in its signature, then
      // it is not necessary to bufferize the callee before the caller.
      if (!hasTensorSignature(calledFunction))
@@ -348,11 +349,11 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
/// most generic layout map as function return types. After bufferizing the
/// entire function body, a more concise memref type can potentially be used for
/// the return type of the function.
-static void foldMemRefCasts(FunctionOpInterface funcOp) {
-  if (funcOp.getFunctionBody().empty())
+static void foldMemRefCasts(func::FuncOp funcOp) {
+  if (funcOp.getBody().empty())
    return;

-  Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
+  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
  SmallVector<Type> resultTypes;

  for (OpOperand &operand : returnOp->getOpOperands()) {
@@ -364,8 +365,8 @@ static void foldMemRefCasts(FunctionOpInterface funcOp) {
    }
  }

-  auto newFuncType = FunctionType::get(funcOp.getContext(),
-                                       funcOp.getArgumentTypes(), resultTypes);
+  auto newFuncType = FunctionType::get(
+      funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
  funcOp.setType(newFuncType);
}

@@ -378,7 +379,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
  FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);

  // A list of functions in the order in which they are analyzed + bufferized.
-  SmallVector<FunctionOpInterface> orderedFuncOps;
+  SmallVector<func::FuncOp> orderedFuncOps;

  // A mapping of FuncOps to their callers.
  FuncCallerMap callerMap;
@@ -387,7 +388,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
    return failure();

  // Analyze ops.
-  for (FunctionOpInterface funcOp : orderedFuncOps) {
+  for (func::FuncOp funcOp : orderedFuncOps) {
    if (!state.getOptions().isOpAllowed(funcOp))
      continue;

@@ -415,7 +416,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,

void mlir::bufferization::removeBufferizationAttributesInModule(
    ModuleOp moduleOp) {
-  moduleOp.walk([&](FunctionOpInterface op) {
+  moduleOp.walk([&](func::FuncOp op) {
    for (BlockArgument bbArg : op.getArguments())
      removeBufferizationAttributes(bbArg);
  });
@@ -429,7 +430,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
  IRRewriter rewriter(moduleOp.getContext());

  // A list of functions in the order in which they are analyzed + bufferized.
-  SmallVector<FunctionOpInterface> orderedFuncOps;
+  SmallVector<func::FuncOp> orderedFuncOps;

  // A mapping of FuncOps to their callers.
  FuncCallerMap callerMap;
@@ -438,11 +439,11 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
    return failure();

  // Bufferize functions.
-  for (FunctionOpInterface funcOp : orderedFuncOps) {
+  for (func::FuncOp funcOp : orderedFuncOps) {
    // Note: It would be good to apply cleanups here but we cannot as aliasInfo
    // would be invalidated.

-    if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getName())) {
+    if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
      // This function was not analyzed and RaW conflicts were not resolved.
      // Buffer copies must be inserted before every write.
      OneShotBufferizationOptions updatedOptions = options;
@@ -462,7 +463,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
  // Bufferize all other ops.
  for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
    // Functions were already bufferized.
-    if (isa<FunctionOpInterface>(&op))
+    if (isa<func::FuncOp>(&op))
      continue;
    if (failed(bufferizeOp(&op, options, statistics)))
      return failure();
@@ -489,12 +490,12 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
  // FuncOps whose names are specified in options.noAnalysisFuncFilter will
  // not be analyzed. Ops in these FuncOps will not be analyzed as well.
  OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) {
-    auto func = dyn_cast<FunctionOpInterface>(op);
+    auto func = dyn_cast<func::FuncOp>(op);
    if (!func)
-      func = op->getParentOfType<FunctionOpInterface>();
+      func = op->getParentOfType<func::FuncOp>();
    if (func)
      return llvm::is_contained(options.noAnalysisFuncFilter,
-                                func.getName());
+                                func.getSymName());
    return false;
  };
  OneShotBufferizationOptions updatedOptions(options);