@@ -185,14 +185,17 @@ struct AnnotateWithFullyQualifiedName : public OpInterfaceRewritePattern<SymbolO
185185
186186struct RenameFunctionsPattern : public RewritePattern {
187187 // / This overload constructs a pattern that matches any operation type.
188- RenameFunctionsPattern (MLIRContext *context, SmallVector<Operation *> *symbolTables)
189- : RewritePattern(MatchAnyOpTypeTag(), 1 , context), _symbolTables(symbolTables)
188+ RenameFunctionsPattern (MLIRContext *context, SmallVector<Operation *> *symbolTables,
189+ llvm::SmallSet<StringRef, 8 > *externalFuncDeclNames)
190+ : RewritePattern(MatchAnyOpTypeTag(), 1 , context), _symbolTables(symbolTables),
191+ _externalFuncDeclNames (externalFuncDeclNames)
190192 {
191193 }
192194
193195 LogicalResult matchAndRewrite (Operation *op, PatternRewriter &rewriter) const override ;
194196
195197 SmallVector<Operation *> *_symbolTables;
198+ llvm::SmallSet<StringRef, 8 > *_externalFuncDeclNames;
196199};
197200
198201static constexpr llvm::StringRef hasBeenRenamedAttrName = " catalyst.unique_names" ;
@@ -236,14 +239,14 @@ LogicalResult RenameFunctionsPattern::matchAndRewrite(Operation *child,
236239
237240 // We should not rename external function declarations, as they can be
238241 // names required by other APIs.
239- // During inlining, on-the-fly they still need to be renamed, otherwise module
240- // verifier complains. So we save the original external API name and rename
241- // them back after inlining is finished .
242+ // We record these external func decls during the rename pattern.
243+ // Then during the actual inlining stage, only the first occurance of the per-module
244+ // func decls of these external decls should be inlined .
242245 if (isa<func::FuncOp>(op)) {
243246 auto f = cast<func::FuncOp>(op);
244247 if (f.isExternal ()) {
245- op. setAttr ( " original_external_API_name " ,
246- rewriter. getStringAttr (f. getName ())) ;
248+ _externalFuncDeclNames-> insert (f. getName ());
249+ continue ;
247250 }
248251 }
249252
@@ -265,7 +268,14 @@ LogicalResult RenameFunctionsPattern::matchAndRewrite(Operation *child,
265268
266269struct InlineNestedModule : public RewritePattern {
267270 // / This overload constructs a pattern that matches any operation type.
268- InlineNestedModule (MLIRContext *context) : RewritePattern(MatchAnyOpTypeTag(), 1 , context) {}
271+ InlineNestedModule (MLIRContext *context,
272+ const llvm::SmallSet<StringRef, 8 > &externalFuncDeclNames)
273+ : RewritePattern(MatchAnyOpTypeTag(), 1 , context),
274+ _externalFuncDeclNames (externalFuncDeclNames)
275+ {
276+ }
277+
278+ mutable llvm::SmallSet<StringRef, 8 > alreadyInlinedFuncDeclNames;
269279
270280 LogicalResult matchAndRewrite (Operation *op, PatternRewriter &rewriter) const override
271281 {
@@ -277,6 +287,28 @@ struct InlineNestedModule : public RewritePattern {
277287 }
278288
279289 auto parent = op->getParentOp ();
290+ assert (parent->hasTrait <OpTrait::SymbolTable>() &&
291+ " the direct parent of a qnode module must be a module op" );
292+
293+ // Look for the func decls in the current qnode module
294+ // If it is a recorded external func decl, erase it if it already has been inlined.
295+ SmallVector<Operation *> _erasureWorklist;
296+ op->walk ([&](func::FuncOp f) {
297+ StringRef funcName = f.getName ();
298+ if (f.isExternal () && _externalFuncDeclNames.contains (funcName)) {
299+
300+ if (alreadyInlinedFuncDeclNames.contains (funcName)) {
301+ _erasureWorklist.push_back (f);
302+ }
303+ else {
304+ alreadyInlinedFuncDeclNames.insert (funcName);
305+ }
306+ }
307+ });
308+ for (auto op : _erasureWorklist) {
309+ rewriter.eraseOp (op);
310+ }
311+
280312 // Can't generalize getting a region other than the zero-th one.
281313 rewriter.inlineRegionBefore (op->getRegion (0 ), &parent->getRegion (0 ).back ());
282314 Block *inlinedBlock = &parent->getRegion (0 ).front ();
@@ -288,6 +320,8 @@ struct InlineNestedModule : public RewritePattern {
288320
289321 return success ();
290322 }
323+
324+ llvm::SmallSet<StringRef, 8 > _externalFuncDeclNames;
291325};
292326
293327struct SymbolReplacerPattern
@@ -365,60 +399,6 @@ struct NestedToFlatCallPattern : public OpRewritePattern<catalyst::LaunchKernelO
365399 const DenseMap<SymbolRefAttr, SymbolRefAttr> *_map;
366400};
367401
368- struct RestoreExternalFuncDeclNamePattern : public OpRewritePattern <func::FuncOp> {
369- using OpRewritePattern<func::FuncOp>::OpRewritePattern;
370-
371- RestoreExternalFuncDeclNamePattern (MLIRContext *context)
372- : OpRewritePattern<func::FuncOp>::OpRewritePattern(context)
373- {
374- }
375-
376- mutable llvm::SmallSet<StringRef, 8 > CreatedAPIFuncNames;
377-
378- LogicalResult matchAndRewrite (func::FuncOp op, PatternRewriter &rewriter) const override
379- {
380- if (!op->hasAttr (" original_external_API_name" )) {
381- return failure ();
382- }
383-
384- // Create func decl that matches the original name for the external API.
385- // This must happen only once.
386- StringRef APIFuncName =
387- cast<StringAttr>(op->getAttr (" original_external_API_name" )).getValue ();
388- if (!CreatedAPIFuncNames.contains (APIFuncName)) {
389- auto APIFuncDecl = rewriter.create <func::FuncOp>(
390- op->getLoc (), rewriter.getStringAttr (APIFuncName), op.getFunctionType ());
391- APIFuncDecl.setPrivate ();
392- CreatedAPIFuncNames.insert (APIFuncName);
393- return success ();
394- }
395- return failure ();
396- }
397- };
398-
399- struct UpdateCalleeToExternalAPINamesPattern : public OpRewritePattern <func::CallOp> {
400- using OpRewritePattern<func::CallOp>::OpRewritePattern;
401-
402- UpdateCalleeToExternalAPINamesPattern (
403- MLIRContext *context, mlir::DenseMap<StringRef, StringAttr> uniqued_names_to_APIName)
404- : OpRewritePattern<func::CallOp>::OpRewritePattern(context),
405- _uniqued_names_to_APIName (uniqued_names_to_APIName)
406- {
407- }
408-
409- mlir::DenseMap<StringRef, StringAttr> _uniqued_names_to_APIName;
410-
411- LogicalResult matchAndRewrite (func::CallOp op, PatternRewriter &rewriter) const override
412- {
413- if (_uniqued_names_to_APIName.contains (op.getCallee ())) {
414- op.setCallee (_uniqued_names_to_APIName.at (op.getCallee ()));
415- return success ();
416- }
417-
418- return failure ();
419- }
420- };
421-
422402struct CleanupPattern : public RewritePattern {
423403 // / This overload constructs a pattern that matches any operation type.
424404 CleanupPattern (MLIRContext *context) : RewritePattern(MatchAnyOpTypeTag(), 1 , context) {}
@@ -431,9 +411,6 @@ struct CleanupPattern : public RewritePattern {
431411 }
432412 rewriter.modifyOpInPlace (op, [&] { op->removeAttr (fullyQualifiedNameAttr); });
433413
434- if (op->hasAttr (" original_external_API_name" )) {
435- op->erase ();
436- }
437414 return success ();
438415 }
439416};
@@ -472,34 +449,6 @@ struct InlineNestedSymbolTablePass : PassWrapper<InlineNestedSymbolTablePass, Op
472449 int _stopAfterStep;
473450 InlineNestedSymbolTablePass (int stopAfter) : _stopAfterStep(stopAfter) {}
474451
475- void restoreExternalFuncDeclName (Operation *symbolTable, MLIRContext *context,
476- GreedyRewriteConfig config)
477- {
478- mlir::DenseMap<StringRef, StringAttr> uniqued_names_to_APIName;
479- symbolTable->walk ([&](func::FuncOp funcOp) {
480- if (funcOp->hasAttr (" original_external_API_name" )) {
481- uniqued_names_to_APIName[funcOp.getSymName ()] =
482- cast<StringAttr>(funcOp->getAttr (" original_external_API_name" ));
483- }
484- });
485- RewritePatternSet restoreExternalFuncDeclName (context);
486- restoreExternalFuncDeclName.add <RestoreExternalFuncDeclNamePattern>(context);
487- bool run = _stopAfterStep >= 5 || _stopAfterStep == 0 ;
488- if (run && failed (applyPatternsGreedily (symbolTable, std::move (restoreExternalFuncDeclName),
489- config))) {
490- signalPassFailure ();
491- }
492-
493- RewritePatternSet updateCalleeToExternalAPINames (context);
494- updateCalleeToExternalAPINames.add <UpdateCalleeToExternalAPINamesPattern>(
495- context, uniqued_names_to_APIName);
496- run = _stopAfterStep >= 5 || _stopAfterStep == 0 ;
497- if (run && failed (applyPatternsGreedily (
498- symbolTable, std::move (updateCalleeToExternalAPINames), config))) {
499- signalPassFailure ();
500- }
501- }
502-
503452 void runOnOperation () override
504453 {
505454 // Here we are in a root module/symbol table
@@ -528,15 +477,16 @@ struct InlineNestedSymbolTablePass : PassWrapper<InlineNestedSymbolTablePass, Op
528477 return WalkResult::skip ();
529478 });
530479
531- renameFunctions.add <RenameFunctionsPattern>(context, &symbolTables);
480+ llvm::SmallSet<StringRef, 8 > externalFuncDeclNames;
481+ renameFunctions.add <RenameFunctionsPattern>(context, &symbolTables, &externalFuncDeclNames);
532482
533483 bool run = _stopAfterStep >= 2 || _stopAfterStep == 0 ;
534484 if (run && failed (applyPatternsGreedily (symbolTable, std::move (renameFunctions), config))) {
535485 signalPassFailure ();
536486 }
537487
538488 RewritePatternSet inlineNested (context);
539- inlineNested.add <InlineNestedModule>(context);
489+ inlineNested.add <InlineNestedModule>(context, externalFuncDeclNames );
540490 run = _stopAfterStep >= 3 || _stopAfterStep == 0 ;
541491 if (run && failed (applyPatternsGreedily (symbolTable, std::move (inlineNested), config))) {
542492 signalPassFailure ();
@@ -568,9 +518,6 @@ struct InlineNestedSymbolTablePass : PassWrapper<InlineNestedSymbolTablePass, Op
568518 signalPassFailure ();
569519 }
570520
571- // Restore external API func decl names and update calls.
572- restoreExternalFuncDeclName (symbolTable, context, config);
573-
574521 RewritePatternSet cleanup (context);
575522 cleanup.add <CleanupPattern>(context);
576523 run = _stopAfterStep >= 5 || _stopAfterStep == 0 ;
0 commit comments