Skip to content

Commit c66b72f

Browse files
authored
[mlir][tensor] remove tensor.insert constant folding out of canonicalization (#142671)
Follow ups from #142458 In particular concerns that indiscriminately folding tensor constants can lead to bloating the IR as these can be arbitrarily large. Signed-off-by: Asra Ali <asraa@google.com>
1 parent 49386f4 commit c66b72f

File tree

3 files changed

+0
-92
lines changed

3 files changed

+0
-92
lines changed

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -827,7 +827,6 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
827827

828828
let hasFolder = 1;
829829
let hasVerifier = 1;
830-
let hasCanonicalizer = 1;
831830
}
832831

833832
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 0 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1624,76 +1624,6 @@ OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
16241624
// InsertOp
16251625
//===----------------------------------------------------------------------===//
16261626

1627-
namespace {
1628-
1629-
/// Pattern to fold an insert op of a constant destination and scalar to a new
1630-
/// constant.
1631-
///
1632-
/// Example:
1633-
/// ```
1634-
/// %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
1635-
/// %c0 = arith.constant 0 : index
1636-
/// %c4_f32 = arith.constant 4.0 : f32
1637-
/// %1 = tensor.insert %c4_f32 into %0[%c0] : tensor<4xf32>
1638-
/// ```
1639-
/// is rewritten into:
1640-
/// ```
1641-
/// %1 = arith.constant dense<[4.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
1642-
/// ```
1643-
class InsertOpConstantFold final : public OpRewritePattern<InsertOp> {
1644-
public:
1645-
using OpRewritePattern<InsertOp>::OpRewritePattern;
1646-
1647-
LogicalResult matchAndRewrite(InsertOp insertOp,
1648-
PatternRewriter &rewriter) const override {
1649-
// Requires a ranked tensor type.
1650-
auto destType =
1651-
llvm::dyn_cast<RankedTensorType>(insertOp.getDest().getType());
1652-
if (!destType)
1653-
return failure();
1654-
1655-
// Pattern requires constant indices
1656-
SmallVector<uint64_t, 8> indices;
1657-
for (OpFoldResult indice : getAsOpFoldResult(insertOp.getIndices())) {
1658-
auto indiceAttr = dyn_cast<Attribute>(indice);
1659-
if (!indiceAttr)
1660-
return failure();
1661-
indices.push_back(llvm::cast<IntegerAttr>(indiceAttr).getInt());
1662-
}
1663-
1664-
// Requires a constant scalar to insert
1665-
OpFoldResult scalar = getAsOpFoldResult(insertOp.getScalar());
1666-
Attribute scalarAttr = dyn_cast<Attribute>(scalar);
1667-
if (!scalarAttr)
1668-
return failure();
1669-
1670-
if (auto constantOp = dyn_cast_or_null<arith::ConstantOp>(
1671-
insertOp.getDest().getDefiningOp())) {
1672-
if (auto sourceAttr =
1673-
llvm::dyn_cast<ElementsAttr>(constantOp.getValue())) {
1674-
// Update the attribute at the inserted index.
1675-
auto sourceValues = sourceAttr.getValues<Attribute>();
1676-
auto flattenedIndex = sourceAttr.getFlattenedIndex(indices);
1677-
std::vector<Attribute> updatedValues;
1678-
updatedValues.reserve(sourceAttr.getNumElements());
1679-
for (unsigned i = 0; i < sourceAttr.getNumElements(); ++i) {
1680-
updatedValues.push_back(i == flattenedIndex ? scalarAttr
1681-
: sourceValues[i]);
1682-
}
1683-
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
1684-
insertOp, sourceAttr.getType(),
1685-
DenseElementsAttr::get(cast<ShapedType>(sourceAttr.getType()),
1686-
updatedValues));
1687-
return success();
1688-
}
1689-
}
1690-
1691-
return failure();
1692-
}
1693-
};
1694-
1695-
} // namespace
1696-
16971627
void InsertOp::getAsmResultNames(
16981628
function_ref<void(Value, StringRef)> setNameFn) {
16991629
setNameFn(getResult(), "inserted");
@@ -1717,11 +1647,6 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
17171647
return {};
17181648
}
17191649

1720-
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
1721-
MLIRContext *context) {
1722-
results.add<InsertOpConstantFold>(context);
1723-
}
1724-
17251650
//===----------------------------------------------------------------------===//
17261651
// GenerateOp
17271652
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -231,22 +231,6 @@ func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
231231
return %ins_1 : tensor<4xf32>
232232
}
233233

234-
235-
// -----
236-
237-
func.func @canonicalize_insert_after_constant() -> (tensor<2x2xi32>) {
238-
// Fold an insert into a splat.
239-
// CHECK: %[[C4:.+]] = arith.constant dense<{{\[\[}}1, 2], [4, 4]]> : tensor<2x2xi32>
240-
// CHECK-LITERAL:
241-
// CHECK-NEXT: return %[[C4]]
242-
%cst = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
243-
%c0 = arith.constant 0 : index
244-
%c1 = arith.constant 1 : index
245-
%c4_i32 = arith.constant 4 : i32
246-
%inserted = tensor.insert %c4_i32 into %cst[%c1, %c0] : tensor<2x2xi32>
247-
return %inserted : tensor<2x2xi32>
248-
}
249-
250234
// -----
251235

252236
// CHECK-LABEL: func @extract_from_tensor.cast

0 commit comments

Comments
 (0)