@@ -1624,76 +1624,6 @@ OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
1624
1624
// InsertOp
1625
1625
// ===----------------------------------------------------------------------===//
1626
1626
1627
- namespace {
1628
-
1629
- // / Pattern to fold an insert op of a constant destination and scalar to a new
1630
- // / constant.
1631
- // /
1632
- // / Example:
1633
- // / ```
1634
- // / %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
1635
- // / %c0 = arith.constant 0 : index
1636
- // / %c4_f32 = arith.constant 4.0 : f32
1637
- // / %1 = tensor.insert %c4_f32 into %0[%c0] : tensor<4xf32>
1638
- // / ```
1639
- // / is rewritten into:
1640
- // / ```
1641
- // / %1 = arith.constant dense<[4.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
1642
- // / ```
1643
- class InsertOpConstantFold final : public OpRewritePattern<InsertOp> {
1644
- public:
1645
- using OpRewritePattern<InsertOp>::OpRewritePattern;
1646
-
1647
- LogicalResult matchAndRewrite (InsertOp insertOp,
1648
- PatternRewriter &rewriter) const override {
1649
- // Requires a ranked tensor type.
1650
- auto destType =
1651
- llvm::dyn_cast<RankedTensorType>(insertOp.getDest ().getType ());
1652
- if (!destType)
1653
- return failure ();
1654
-
1655
- // Pattern requires constant indices
1656
- SmallVector<uint64_t , 8 > indices;
1657
- for (OpFoldResult indice : getAsOpFoldResult (insertOp.getIndices ())) {
1658
- auto indiceAttr = dyn_cast<Attribute>(indice);
1659
- if (!indiceAttr)
1660
- return failure ();
1661
- indices.push_back (llvm::cast<IntegerAttr>(indiceAttr).getInt ());
1662
- }
1663
-
1664
- // Requires a constant scalar to insert
1665
- OpFoldResult scalar = getAsOpFoldResult (insertOp.getScalar ());
1666
- Attribute scalarAttr = dyn_cast<Attribute>(scalar);
1667
- if (!scalarAttr)
1668
- return failure ();
1669
-
1670
- if (auto constantOp = dyn_cast_or_null<arith::ConstantOp>(
1671
- insertOp.getDest ().getDefiningOp ())) {
1672
- if (auto sourceAttr =
1673
- llvm::dyn_cast<ElementsAttr>(constantOp.getValue ())) {
1674
- // Update the attribute at the inserted index.
1675
- auto sourceValues = sourceAttr.getValues <Attribute>();
1676
- auto flattenedIndex = sourceAttr.getFlattenedIndex (indices);
1677
- std::vector<Attribute> updatedValues;
1678
- updatedValues.reserve (sourceAttr.getNumElements ());
1679
- for (unsigned i = 0 ; i < sourceAttr.getNumElements (); ++i) {
1680
- updatedValues.push_back (i == flattenedIndex ? scalarAttr
1681
- : sourceValues[i]);
1682
- }
1683
- rewriter.replaceOpWithNewOp <arith::ConstantOp>(
1684
- insertOp, sourceAttr.getType (),
1685
- DenseElementsAttr::get (cast<ShapedType>(sourceAttr.getType ()),
1686
- updatedValues));
1687
- return success ();
1688
- }
1689
- }
1690
-
1691
- return failure ();
1692
- }
1693
- };
1694
-
1695
- } // namespace
1696
-
1697
1627
void InsertOp::getAsmResultNames (
1698
1628
function_ref<void (Value, StringRef)> setNameFn) {
1699
1629
setNameFn (getResult (), " inserted" );
@@ -1717,11 +1647,6 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
1717
1647
return {};
1718
1648
}
1719
1649
1720
- void InsertOp::getCanonicalizationPatterns (RewritePatternSet &results,
1721
- MLIRContext *context) {
1722
- results.add <InsertOpConstantFold>(context);
1723
- }
1724
-
1725
1650
// ===----------------------------------------------------------------------===//
1726
1651
// GenerateOp
1727
1652
// ===----------------------------------------------------------------------===//
0 commit comments