Skip to content

Commit d5b7c32

Browse files
committed
No need to rename and undo rename! Just inline the first occurance.
1 parent ef7f52a commit d5b7c32

File tree

1 file changed

+45
-98
lines changed

1 file changed

+45
-98
lines changed

mlir/lib/Catalyst/Transforms/InlineNestedModules.cpp

Lines changed: 45 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -185,14 +185,17 @@ struct AnnotateWithFullyQualifiedName : public OpInterfaceRewritePattern<SymbolO
185185

186186
struct RenameFunctionsPattern : public RewritePattern {
187187
/// This overload constructs a pattern that matches any operation type.
188-
RenameFunctionsPattern(MLIRContext *context, SmallVector<Operation *> *symbolTables)
189-
: RewritePattern(MatchAnyOpTypeTag(), 1, context), _symbolTables(symbolTables)
188+
RenameFunctionsPattern(MLIRContext *context, SmallVector<Operation *> *symbolTables,
189+
llvm::SmallSet<StringRef, 8> *externalFuncDeclNames)
190+
: RewritePattern(MatchAnyOpTypeTag(), 1, context), _symbolTables(symbolTables),
191+
_externalFuncDeclNames(externalFuncDeclNames)
190192
{
191193
}
192194

193195
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override;
194196

195197
SmallVector<Operation *> *_symbolTables;
198+
llvm::SmallSet<StringRef, 8> *_externalFuncDeclNames;
196199
};
197200

198201
static constexpr llvm::StringRef hasBeenRenamedAttrName = "catalyst.unique_names";
@@ -236,14 +239,14 @@ LogicalResult RenameFunctionsPattern::matchAndRewrite(Operation *child,
236239

237240
// We should not rename external function declarations, as they can be
238241
// names required by other APIs.
239-
// During inlining, on-the-fly they still need to be renamed, otherwise module
240-
// verifier complains. So we save the original external API name and rename
241-
// them back after inlining is finished.
242+
// We record these external func decls during the rename pattern.
243+
// Then during the actual inlining stage, only the first occurance of the per-module
244+
// func decls of these external decls should be inlined.
242245
if (isa<func::FuncOp>(op)) {
243246
auto f = cast<func::FuncOp>(op);
244247
if (f.isExternal()) {
245-
op.setAttr("original_external_API_name",
246-
rewriter.getStringAttr(f.getName()));
248+
_externalFuncDeclNames->insert(f.getName());
249+
continue;
247250
}
248251
}
249252

@@ -265,7 +268,14 @@ LogicalResult RenameFunctionsPattern::matchAndRewrite(Operation *child,
265268

266269
struct InlineNestedModule : public RewritePattern {
267270
/// This overload constructs a pattern that matches any operation type.
268-
InlineNestedModule(MLIRContext *context) : RewritePattern(MatchAnyOpTypeTag(), 1, context) {}
271+
InlineNestedModule(MLIRContext *context,
272+
const llvm::SmallSet<StringRef, 8> &externalFuncDeclNames)
273+
: RewritePattern(MatchAnyOpTypeTag(), 1, context),
274+
_externalFuncDeclNames(externalFuncDeclNames)
275+
{
276+
}
277+
278+
mutable llvm::SmallSet<StringRef, 8> alreadyInlinedFuncDeclNames;
269279

270280
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override
271281
{
@@ -277,6 +287,28 @@ struct InlineNestedModule : public RewritePattern {
277287
}
278288

279289
auto parent = op->getParentOp();
290+
assert(parent->hasTrait<OpTrait::SymbolTable>() &&
291+
"the direct parent of a qnode module must be a module op");
292+
293+
// Look for the func decls in the current qnode module
294+
// If it is a recorded external func decl, erase it if it already has been inlined.
295+
SmallVector<Operation *> _erasureWorklist;
296+
op->walk([&](func::FuncOp f) {
297+
StringRef funcName = f.getName();
298+
if (f.isExternal() && _externalFuncDeclNames.contains(funcName)) {
299+
300+
if (alreadyInlinedFuncDeclNames.contains(funcName)) {
301+
_erasureWorklist.push_back(f);
302+
}
303+
else {
304+
alreadyInlinedFuncDeclNames.insert(funcName);
305+
}
306+
}
307+
});
308+
for (auto op : _erasureWorklist) {
309+
rewriter.eraseOp(op);
310+
}
311+
280312
// Can't generalize getting a region other than the zero-th one.
281313
rewriter.inlineRegionBefore(op->getRegion(0), &parent->getRegion(0).back());
282314
Block *inlinedBlock = &parent->getRegion(0).front();
@@ -288,6 +320,8 @@ struct InlineNestedModule : public RewritePattern {
288320

289321
return success();
290322
}
323+
324+
llvm::SmallSet<StringRef, 8> _externalFuncDeclNames;
291325
};
292326

293327
struct SymbolReplacerPattern
@@ -365,60 +399,6 @@ struct NestedToFlatCallPattern : public OpRewritePattern<catalyst::LaunchKernelO
365399
const DenseMap<SymbolRefAttr, SymbolRefAttr> *_map;
366400
};
367401

368-
struct RestoreExternalFuncDeclNamePattern : public OpRewritePattern<func::FuncOp> {
369-
using OpRewritePattern<func::FuncOp>::OpRewritePattern;
370-
371-
RestoreExternalFuncDeclNamePattern(MLIRContext *context)
372-
: OpRewritePattern<func::FuncOp>::OpRewritePattern(context)
373-
{
374-
}
375-
376-
mutable llvm::SmallSet<StringRef, 8> CreatedAPIFuncNames;
377-
378-
LogicalResult matchAndRewrite(func::FuncOp op, PatternRewriter &rewriter) const override
379-
{
380-
if (!op->hasAttr("original_external_API_name")) {
381-
return failure();
382-
}
383-
384-
// Create func decl that matches the original name for the external API.
385-
// This must happen only once.
386-
StringRef APIFuncName =
387-
cast<StringAttr>(op->getAttr("original_external_API_name")).getValue();
388-
if (!CreatedAPIFuncNames.contains(APIFuncName)) {
389-
auto APIFuncDecl = rewriter.create<func::FuncOp>(
390-
op->getLoc(), rewriter.getStringAttr(APIFuncName), op.getFunctionType());
391-
APIFuncDecl.setPrivate();
392-
CreatedAPIFuncNames.insert(APIFuncName);
393-
return success();
394-
}
395-
return failure();
396-
}
397-
};
398-
399-
struct UpdateCalleeToExternalAPINamesPattern : public OpRewritePattern<func::CallOp> {
400-
using OpRewritePattern<func::CallOp>::OpRewritePattern;
401-
402-
UpdateCalleeToExternalAPINamesPattern(
403-
MLIRContext *context, mlir::DenseMap<StringRef, StringAttr> uniqued_names_to_APIName)
404-
: OpRewritePattern<func::CallOp>::OpRewritePattern(context),
405-
_uniqued_names_to_APIName(uniqued_names_to_APIName)
406-
{
407-
}
408-
409-
mlir::DenseMap<StringRef, StringAttr> _uniqued_names_to_APIName;
410-
411-
LogicalResult matchAndRewrite(func::CallOp op, PatternRewriter &rewriter) const override
412-
{
413-
if (_uniqued_names_to_APIName.contains(op.getCallee())) {
414-
op.setCallee(_uniqued_names_to_APIName.at(op.getCallee()));
415-
return success();
416-
}
417-
418-
return failure();
419-
}
420-
};
421-
422402
struct CleanupPattern : public RewritePattern {
423403
/// This overload constructs a pattern that matches any operation type.
424404
CleanupPattern(MLIRContext *context) : RewritePattern(MatchAnyOpTypeTag(), 1, context) {}
@@ -431,9 +411,6 @@ struct CleanupPattern : public RewritePattern {
431411
}
432412
rewriter.modifyOpInPlace(op, [&] { op->removeAttr(fullyQualifiedNameAttr); });
433413

434-
if (op->hasAttr("original_external_API_name")) {
435-
op->erase();
436-
}
437414
return success();
438415
}
439416
};
@@ -472,34 +449,6 @@ struct InlineNestedSymbolTablePass : PassWrapper<InlineNestedSymbolTablePass, Op
472449
int _stopAfterStep;
473450
InlineNestedSymbolTablePass(int stopAfter) : _stopAfterStep(stopAfter) {}
474451

475-
void restoreExternalFuncDeclName(Operation *symbolTable, MLIRContext *context,
476-
GreedyRewriteConfig config)
477-
{
478-
mlir::DenseMap<StringRef, StringAttr> uniqued_names_to_APIName;
479-
symbolTable->walk([&](func::FuncOp funcOp) {
480-
if (funcOp->hasAttr("original_external_API_name")) {
481-
uniqued_names_to_APIName[funcOp.getSymName()] =
482-
cast<StringAttr>(funcOp->getAttr("original_external_API_name"));
483-
}
484-
});
485-
RewritePatternSet restoreExternalFuncDeclName(context);
486-
restoreExternalFuncDeclName.add<RestoreExternalFuncDeclNamePattern>(context);
487-
bool run = _stopAfterStep >= 5 || _stopAfterStep == 0;
488-
if (run && failed(applyPatternsGreedily(symbolTable, std::move(restoreExternalFuncDeclName),
489-
config))) {
490-
signalPassFailure();
491-
}
492-
493-
RewritePatternSet updateCalleeToExternalAPINames(context);
494-
updateCalleeToExternalAPINames.add<UpdateCalleeToExternalAPINamesPattern>(
495-
context, uniqued_names_to_APIName);
496-
run = _stopAfterStep >= 5 || _stopAfterStep == 0;
497-
if (run && failed(applyPatternsGreedily(
498-
symbolTable, std::move(updateCalleeToExternalAPINames), config))) {
499-
signalPassFailure();
500-
}
501-
}
502-
503452
void runOnOperation() override
504453
{
505454
// Here we are in a root module/symbol table
@@ -528,15 +477,16 @@ struct InlineNestedSymbolTablePass : PassWrapper<InlineNestedSymbolTablePass, Op
528477
return WalkResult::skip();
529478
});
530479

531-
renameFunctions.add<RenameFunctionsPattern>(context, &symbolTables);
480+
llvm::SmallSet<StringRef, 8> externalFuncDeclNames;
481+
renameFunctions.add<RenameFunctionsPattern>(context, &symbolTables, &externalFuncDeclNames);
532482

533483
bool run = _stopAfterStep >= 2 || _stopAfterStep == 0;
534484
if (run && failed(applyPatternsGreedily(symbolTable, std::move(renameFunctions), config))) {
535485
signalPassFailure();
536486
}
537487

538488
RewritePatternSet inlineNested(context);
539-
inlineNested.add<InlineNestedModule>(context);
489+
inlineNested.add<InlineNestedModule>(context, externalFuncDeclNames);
540490
run = _stopAfterStep >= 3 || _stopAfterStep == 0;
541491
if (run && failed(applyPatternsGreedily(symbolTable, std::move(inlineNested), config))) {
542492
signalPassFailure();
@@ -568,9 +518,6 @@ struct InlineNestedSymbolTablePass : PassWrapper<InlineNestedSymbolTablePass, Op
568518
signalPassFailure();
569519
}
570520

571-
// Restore external API func decl names and update calls.
572-
restoreExternalFuncDeclName(symbolTable, context, config);
573-
574521
RewritePatternSet cleanup(context);
575522
cleanup.add<CleanupPattern>(context);
576523
run = _stopAfterStep >= 5 || _stopAfterStep == 0;

0 commit comments

Comments
 (0)