@@ -688,9 +688,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
688688 UnresolvedMaterializationRewrite (
689689 ConversionPatternRewriterImpl &rewriterImpl,
690690 UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr ,
691- MaterializationKind kind = MaterializationKind::Target)
692- : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
693- converterAndKind (converter, kind) {}
691+ MaterializationKind kind = MaterializationKind::Target);
694692
695693 static bool classof (const IRRewrite *rewrite) {
696694 return rewrite->getKind () == Kind::UnresolvedMaterialization;
@@ -730,26 +728,6 @@ static bool hasRewrite(R &&rewrites, Operation *op) {
730728 });
731729}
732730
733- // / Find the single rewrite object of the specified type and block among the
734- // / given rewrites. In debug mode, asserts that there is mo more than one such
735- // / object. Return "nullptr" if no object was found.
736- template <typename RewriteTy, typename R>
737- static RewriteTy *findSingleRewrite (R &&rewrites, Block *block) {
738- RewriteTy *result = nullptr ;
739- for (auto &rewrite : rewrites) {
740- auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get ());
741- if (rewriteTy && rewriteTy->getBlock () == block) {
742- #ifndef NDEBUG
743- assert (!result && " expected single matching rewrite" );
744- result = rewriteTy;
745- #else
746- return rewriteTy;
747- #endif // NDEBUG
748- }
749- }
750- return result;
751- }
752-
753731// ===----------------------------------------------------------------------===//
754732// ConversionPatternRewriterImpl
755733// ===----------------------------------------------------------------------===//
@@ -892,10 +870,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
892870
893871 bool wasErased (void *ptr) const { return erased.contains (ptr); }
894872
895- bool wasErased (OperationRewrite *rewrite) const {
896- return wasErased (rewrite->getOperation ());
897- }
898-
899873 void notifyOperationErased (Operation *op) override { erased.insert (op); }
900874
901875 void notifyBlockErased (Block *block) override { erased.insert (block); }
@@ -935,8 +909,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
935909 // / to modify/access them is invalid rewriter API usage.
936910 SetVector<Operation *> replacedOps;
937911
938- // / A set of all unresolved materializations.
939- DenseSet<Operation *> unresolvedMaterializations;
912+ // / A mapping of all unresolved materializations (UnrealizedConversionCastOp)
913+ // / to the corresponding rewrite objects.
914+ DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
915+ unresolvedMaterializations;
940916
941917 // / The current type converter, or nullptr if no type converter is currently
942918 // / active.
@@ -1058,12 +1034,20 @@ void CreateOperationRewrite::rollback() {
10581034 op->erase ();
10591035}
10601036
1037+ UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite (
1038+ ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
1039+ const TypeConverter *converter, MaterializationKind kind)
1040+ : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
1041+ converterAndKind(converter, kind) {
1042+ rewriterImpl.unresolvedMaterializations [op] = this ;
1043+ }
1044+
10611045void UnresolvedMaterializationRewrite::rollback () {
10621046 if (getMaterializationKind () == MaterializationKind::Target) {
10631047 for (Value input : op->getOperands ())
10641048 rewriterImpl.mapping .erase (input);
10651049 }
1066- rewriterImpl.unresolvedMaterializations .erase (op );
1050+ rewriterImpl.unresolvedMaterializations .erase (getOperation () );
10671051 op->erase ();
10681052}
10691053
@@ -1345,7 +1329,6 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
13451329 builder.setInsertionPoint (ip.getBlock (), ip.getPoint ());
13461330 auto convertOp =
13471331 builder.create <UnrealizedConversionCastOp>(loc, outputType, inputs);
1348- unresolvedMaterializations.insert (convertOp);
13491332 appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
13501333 return convertOp.getResult (0 );
13511334}
@@ -1382,10 +1365,12 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
13821365 for (auto [newValue, result] : llvm::zip (newValues, op->getResults ())) {
13831366 if (!newValue) {
13841367 // This result was dropped and no replacement value was provided.
1385- if (unresolvedMaterializations.contains (op)) {
1386- // Do not create another materializations if we are erasing a
1387- // materialization.
1388- continue ;
1368+ if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
1369+ if (unresolvedMaterializations.contains (castOp)) {
1370+ // Do not create another materializations if we are erasing a
1371+ // materialization.
1372+ continue ;
1373+ }
13891374 }
13901375
13911376 // Materialize a replacement value "out of thin air".
@@ -2499,15 +2484,12 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
24992484
25002485 // Gather all unresolved materializations.
25012486 SmallVector<UnrealizedConversionCastOp> allCastOps;
2502- DenseMap<Operation *, UnresolvedMaterializationRewrite *> rewriteMap;
2503- for (std::unique_ptr<IRRewrite> &rewrite : rewriterImpl.rewrites ) {
2504- auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get ());
2505- if (!mat)
2506- continue ;
2507- if (rewriterImpl.eraseRewriter .wasErased (mat))
2487+ const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
2488+ &materializations = rewriterImpl.unresolvedMaterializations ;
2489+ for (auto it : materializations) {
2490+ if (rewriterImpl.eraseRewriter .wasErased (it.first ))
25082491 continue ;
2509- allCastOps.push_back (mat->getOperation ());
2510- rewriteMap[mat->getOperation ()] = mat;
2492+ allCastOps.push_back (it.first );
25112493 }
25122494
25132495 // Reconcile all UnrealizedConversionCastOps that were inserted by the
@@ -2520,8 +2502,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
25202502 if (config.buildMaterializations ) {
25212503 IRRewriter rewriter (rewriterImpl.context , config.listener );
25222504 for (UnrealizedConversionCastOp castOp : remainingCastOps) {
2523- auto it = rewriteMap .find (castOp. getOperation () );
2524- assert (it != rewriteMap .end () && " inconsistent state" );
2505+ auto it = materializations .find (castOp);
2506+ assert (it != materializations .end () && " inconsistent state" );
25252507 if (failed (legalizeUnresolvedMaterialization (rewriter, it->second )))
25262508 return failure ();
25272509 }
0 commit comments