[mlir] Change tensor.extract/insert to take static/dynamic indices. #104488

Open
wants to merge 1 commit into base: main

Conversation

cathyzhyi (Contributor)

This changes the ODS of the tensor.extract/insert ops. Some new builder methods are added, and the verifiers/canonicalizers are updated. One of the canonicalization patterns of shape.shape_of is also updated.

@llvmbot (Collaborator)

llvmbot commented Aug 15, 2024

@llvm/pr-subscribers-mlir-tensor
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-shape

Author: Yi Zhang (cathyzhyi)

Changes

This changes the ODS of the tensor.extract/insert ops. Some new builder methods are added, and the verifiers/canonicalizers are updated. One of the canonicalization patterns of shape.shape_of is also updated.


Full diff: https://github.com/llvm/llvm-project/pull/104488.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+48-4)
  • (modified) mlir/lib/Dialect/Shape/IR/Shape.cpp (+26)
  • (modified) mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td (-6)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+108-9)
  • (modified) mlir/test/Dialect/Shape/canonicalize.mlir (+13)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+8-4)
  • (modified) mlir/test/Dialect/Tensor/invalid.mlir (+36-2)
  • (modified) mlir/test/Dialect/Tensor/ops.mlir (+6)
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index cafc3d91fd1e9d..997d0ccb28d769 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -332,12 +332,37 @@ def Tensor_ExtractOp : Tensor_Op<"extract", [
     ```mlir
     %4 = tensor.extract %t[%1, %2] : tensor<4x4xi32>
     %5 = tensor.extract %rt[%1, %2] : tensor<?x?xi32>
+    %6 = tensor.extract %rt[3, 4] : tensor<?x?xi32>
+    %7 = tensor.extract %rt[%1, 4] : tensor<?x?xi32>
     ```
   }];
 
-  let arguments = (ins AnyRankedTensor:$tensor, Variadic<Index>:$indices);
+  let arguments = (ins
+    AnyRankedTensor:$tensor,
+    Variadic<Index>:$indices,
+    DenseI64ArrayAttr:$static_indices
+  );
   let results = (outs AnyType:$result);
-  let assemblyFormat = "$tensor `[` $indices `]` attr-dict `:` type($tensor)";
+  let assemblyFormat = [{
+    $tensor ``
+    custom<DynamicIndexList>($indices, $static_indices)
+    attr-dict `:` type($tensor)
+  }];
+
+  let builders = [
+    // Build an ExtractOp with mixed static and dynamic indexes.
+    OpBuilder<(ins "Value":$tensor, "ArrayRef<OpFoldResult>":$indexes,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build an ExtractOp with mixed static, dynamic indexes and inferred result type.
+    OpBuilder<(ins "Type":$resultType, "Value":$tensor, "ArrayRef<OpFoldResult>":$indexes,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build an ExtractOp with dynamic indexes.
+    OpBuilder<(ins "Value":$source, CArg<"ValueRange", "{}">:$indexes,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build an ExtractOp with dynamic indexes and inferred result type.
+    OpBuilder<(ins "Type":$resultType, "Value":$source, CArg<"ValueRange", "{}">:$indexes,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+  ];
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
@@ -808,16 +833,35 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
 
   let arguments = (ins AnyType:$scalar,
                        AnyRankedTensor:$dest,
-                       Variadic<Index>:$indices);
+                       Variadic<Index>:$indices,
+                       DenseI64ArrayAttr:$static_indices
+  );
   let results = (outs AnyRankedTensor:$result);
   let assemblyFormat = [{
-    $scalar `into` $dest `[` $indices `]` attr-dict `:` type($dest)
+    $scalar `into`
+    $dest `` custom<DynamicIndexList>($indices, $static_indices)
+    attr-dict `:` type($dest)
   }];
 
   let extraClassDeclaration = [{
     MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
   }];
 
+  let builders = [
+    // Build an InsertOp with mixed static and dynamic indexes.
+    OpBuilder<(ins "Value":$scalar, "Value":$dest, "ArrayRef<OpFoldResult>":$indexes,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build an InsertOp with mixed static, dynamic indexes and inferred result type.
+    OpBuilder<(ins "Type":$resultType, "Value":$scalar, "Value":$dest, "ArrayRef<OpFoldResult>":$indexes,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build an InsertOp with dynamic indexes.
+    OpBuilder<(ins "Value":$scalar, "Value":$dest,  CArg<"ValueRange", "{}">:$indexes,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build an InsertOp with dynamic indexes and inferred result type.
+    OpBuilder<(ins "Type":$resultType, "Value":$scalar, "Value":$dest, CArg<"ValueRange", "{}">:$indexes,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+  ];
+
   let hasFolder = 1;
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 8eb8e579954faa..89184f2162c2c4 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1736,6 +1736,32 @@ struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
   }
 };
 
+struct ExtractFromShapeOfExtentTensor
+    : public OpRewritePattern<tensor::ExtractOp> {
+  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractOp op,
+                                PatternRewriter &rewriter) const override {
+    auto tensorShapeOfOp = op.getTensor().getDefiningOp<shape::ShapeOfOp>();
+    if (!tensorShapeOfOp)
+      return rewriter.notifyMatchFailure(op, "producer is not shape.shape_of");
+
+    int64_t staticIndice = op.getStaticIndices()[0];
+    Type indexType = rewriter.getIndexType();
+    Value indice =
+        staticIndice != ShapedType::kDynamic
+            ? tensorShapeOfOp->getDialect()
+                  ->materializeConstant(
+                      rewriter, IntegerAttr::get(indexType, staticIndice),
+                      indexType, op.getLoc())
+                  ->getResult(0)
+            : op.getIndices()[0];
+    rewriter.replaceOpWithNewOp<tensor::DimOp>(op, tensorShapeOfOp.getArg(),
+                                               indice);
+    return success();
+  }
+};
+
 // Canonicalize
 // ```
 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
index cb294ae2978fce..e135105d6980b6 100644
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -44,9 +44,3 @@ def SizeToIndexToSizeCanonicalization : Pat<
 def TensorCastConstShape : Pat <
   (Tensor_CastOp:$res (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg),
   [(HasStaticShape $res)]>;
-
-// tensor.extract from shape_of -> tensor.dim. We can take the first index
-// because shape_of always returns a 1D tensor.
-def ExtractFromShapeOfExtentTensor : Pat<
-  (Tensor_ExtractOp (Shape_ShapeOfOp $arg), $indices),
-  (Tensor_DimOp $arg, (TakeFront $indices))>;
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e11c6aaccf74dd..bb4d3eccc7377f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
@@ -27,7 +28,9 @@
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/MathExtras.h"
 #include <algorithm>
 #include <optional>
@@ -39,6 +42,19 @@ using llvm::divideCeilSigned;
 using llvm::divideFloorSigned;
 using llvm::mod;
 
+static LogicalResult
+checkTensorRankMatchIndices(Value tensor, ValueRange dynamicIndices,
+                            ArrayRef<int64_t> staticIndices) {
+  auto tensorType = llvm::cast<RankedTensorType>(tensor.getType());
+  int64_t dynamicDimCount = llvm::count_if(staticIndices, [](int64_t element) {
+    return element == ShapedType::kDynamic;
+  });
+  if (tensorType.getRank() != staticIndices.size() ||
+      dynamicDimCount != static_cast<int64_t>(dynamicIndices.size()))
+    return LogicalResult::failure();
+  return LogicalResult::success();
+}
+
 /// Materialize a single constant operation from a given attribute value with
 /// the desired resultant type.
 Operation *TensorDialect::materializeConstant(OpBuilder &builder,
@@ -1120,10 +1136,49 @@ void ExtractOp::getAsmResultNames(
   setNameFn(getResult(), "extracted");
 }
 
+// Build an ExtractOp with mixed static and dynamic indexes.
+void ExtractOp::build(OpBuilder &b, OperationState &result, Value tensor,
+                      ArrayRef<OpFoldResult> indices,
+                      ArrayRef<NamedAttribute> attrs) {
+  Type resultType = llvm::cast<TensorType>(tensor.getType()).getElementType();
+  build(b, result, resultType, tensor, indices, attrs);
+}
+
+// Build an ExtractOp with mixed static, dynamic indexes and inferred result
+// Type.
+void ExtractOp::build(OpBuilder &b, OperationState &result, Type resultType,
+                      Value tensor, ArrayRef<OpFoldResult> indices,
+                      ArrayRef<NamedAttribute> attrs) {
+  SmallVector<int64_t> staticIndices;
+  SmallVector<Value> dynamicIndices;
+  dispatchIndexOpFoldResults(indices, dynamicIndices, staticIndices);
+  result.addAttributes(attrs);
+  build(b, result, resultType, tensor, dynamicIndices,
+        b.getDenseI64ArrayAttr(staticIndices));
+}
+
+// Build an ExtractOp with dynamic indexes and inferred result type.
+void ExtractOp::build(OpBuilder &b, OperationState &result, Type resultType,
+                      Value tensor, ValueRange indices,
+                      ArrayRef<NamedAttribute> attrs) {
+  SmallVector<OpFoldResult> indicesValues = llvm::to_vector<4>(
+      llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; }));
+  build(b, result, resultType, tensor, indicesValues, attrs);
+}
+
+// Build an ExtractOp with dynamic indexes.
+void ExtractOp::build(OpBuilder &b, OperationState &result, Value tensor,
+                      ValueRange indices, ArrayRef<NamedAttribute> attrs) {
+  Type resultType = llvm::cast<TensorType>(tensor.getType()).getElementType();
+  SmallVector<OpFoldResult> indicesValues = llvm::to_vector<4>(
+      llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; }));
+  build(b, result, resultType, tensor, indicesValues, attrs);
+}
+
 LogicalResult ExtractOp::verify() {
   // Verify the # indices match if we have a ranked type.
-  auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
-  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
+  if (failed(checkTensorRankMatchIndices(getTensor(), getIndices(),
+                                         getStaticIndices())))
     return emitOpError("incorrect number of indices for extract_element");
   return success();
 }
@@ -1137,12 +1192,18 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
 
   // Collect the constant indices into the tensor.
   SmallVector<uint64_t, 8> indices;
-  for (Attribute indice : adaptor.getIndices()) {
-    if (!indice || !llvm::isa<IntegerAttr>(indice))
-      return {};
-    indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
+  auto dynamicIndicesIt = adaptor.getIndices().begin();
+  for (int64_t i : getStaticIndices()) {
+    if (i != ShapedType::kDynamic) {
+      indices.push_back(i);
+    } else {
+      Attribute indice = *dynamicIndicesIt;
+      if (!indice || !llvm::isa<IntegerAttr>(indice))
+        return {};
+      indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
+      dynamicIndicesIt++;
+    }
   }
-
   // Fold extract(from_elements(...)).
   if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
     auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
@@ -1354,10 +1415,48 @@ void InsertOp::getAsmResultNames(
   setNameFn(getResult(), "inserted");
 }
 
+// Build an ExtractOp with mixed static and dynamic indexes.
+void InsertOp::build(OpBuilder &b, OperationState &result, Value scalar,
+                     Value dest, ArrayRef<OpFoldResult> indices,
+                     ArrayRef<NamedAttribute> attrs) {
+  build(b, result, dest.getType(), scalar, dest, indices, attrs);
+}
+
+// Build an InsertOp with mixed static, dynamic indexes and inferred result
+// Type.
+void InsertOp::build(OpBuilder &b, OperationState &result, Type resultType,
+                     Value scalar, Value dest, ArrayRef<OpFoldResult> indices,
+                     ArrayRef<NamedAttribute> attrs) {
+  SmallVector<int64_t> staticIndices;
+  SmallVector<Value> dynamicIndices;
+  dispatchIndexOpFoldResults(indices, dynamicIndices, staticIndices);
+  result.addAttributes(attrs);
+  build(b, result, resultType, scalar, dest, dynamicIndices,
+        b.getDenseI64ArrayAttr(staticIndices));
+}
+
+// Build an ExtractOp with dynamic indexes and inferred result type.
+void InsertOp::build(OpBuilder &b, OperationState &result, Type resultType,
+                     Value scalar, Value dest, ValueRange indices,
+                     ArrayRef<NamedAttribute> attrs) {
+  SmallVector<OpFoldResult> indicesValues = llvm::to_vector<4>(
+      llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; }));
+  build(b, result, resultType, scalar, dest, indicesValues, attrs);
+}
+
+// Build an InsertOp with dynamic indexes.
+void InsertOp::build(OpBuilder &b, OperationState &result, Value scalar,
+                     Value dest, ValueRange indices,
+                     ArrayRef<NamedAttribute> attrs) {
+  SmallVector<OpFoldResult> indicesValues = llvm::to_vector<4>(
+      llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; }));
+  build(b, result, dest.getType(), scalar, dest, indicesValues, attrs);
+}
+
 LogicalResult InsertOp::verify() {
   // Verify the # indices match if we have a ranked type.
-  auto destType = llvm::cast<RankedTensorType>(getDest().getType());
-  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
+  if (failed(checkTensorRankMatchIndices(getDest(), getIndices(),
+                                         getStaticIndices())))
     return emitOpError("incorrect number of indices");
   return success();
 }
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 5b98a7790debf2..8c04e574dbc518 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1519,6 +1519,19 @@ func.func @extract_shapeof(%arg0 : tensor<?x?xf64>) -> index {
  return %result : index
 }
 
+// -----
+
+// CHECK-LABEL: func @extract_shapeof_static_indice
+// CHECK-SAME:    %[[ARG0:.*]]: tensor<?x?xf64>
+func.func @extract_shapeof_static_indice(%arg0 : tensor<?x?xf64>) -> index {
+// CHECK:        %[[C1:.*]] = arith.constant 1
+ %shape = shape.shape_of %arg0 : tensor<?x?xf64> -> tensor<2xindex>
+// CHECK:        %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C1]]
+ %result = tensor.extract %shape[1] : tensor<2xindex>
+// CHECK:        return %[[DIM]]
+ return %result : index
+}
+
 
 // -----
 
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 4b8efde78cc23c..8f7c7478669b4f 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -137,11 +137,12 @@ func.func @fold_concat(%arg0: tensor<1x2x?xi32>) -> (tensor<1x2x3xi32>, tensor<1
 // -----
 
 // CHECK-LABEL: func @fold_extract
-func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
+func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, i32, complex<f32>) {
   %const_0 = arith.constant 0 : index
   %const_1 = arith.constant 1 : index
   %const_3 = arith.constant 3 : index
   // CHECK-DAG: [[C64:%.+]] = arith.constant 64 : i32
+  // CHECK-DAG: [[CNEG1:%.+]] = arith.constant -1 : i32
   // CHECK-DAG: [[C0:%.+]] = arith.constant 0.{{0*}}e+00 : f16
   // CHECK-DAG: [[CM2:%.+]] = arith.constant -2.{{0*}}e+00 : f16
 
@@ -162,13 +163,16 @@ func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
   %3 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32>
   %ext_4 = tensor.extract %3[%const_1, %const_0, %const_3] : tensor<2x1x4xi32>
 
+  // Fold an extract into a dense tensor with mixed dynamic and static indexes.
+  %ext_5 = tensor.extract %3[%const_1, 0, 2] : tensor<2x1x4xi32>
+
   // Fold an extract into a complex constant.
   // CHECK-DAG: [[C5:%.+]] = complex.constant [1.200000e+00 : f32, 2.300000e+00 : f32] : complex<f32>
   %4 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
-  %ext_5 = tensor.extract %4[] : tensor<complex<f32>>
+  %ext_6 = tensor.extract %4[] : tensor<complex<f32>>
 
-  // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]]
-  return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5 : f32, f16, f16, i32, complex<f32>
+  // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[CNEG1]], [[C5]]
+  return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5, %ext_6: f32, f16, f16, i32, i32, complex<f32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 41b6529f64afa3..8c594ddacb8d33 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -64,7 +64,7 @@ func.func @concat_static_shape_mismatch(%arg0: tensor<3xf32>) {
 
 // -----
 
-func.func @extract_too_many_indices(%arg0: tensor<?xf32>) {
+func.func @extract_too_few_indices(%arg0: tensor<?xf32>) {
   // expected-error@+1 {{incorrect number of indices for extract_element}}
   %0 = tensor.extract %arg0[] : tensor<?xf32>
   return
@@ -72,7 +72,24 @@ func.func @extract_too_many_indices(%arg0: tensor<?xf32>) {
 
 // -----
 
-func.func @insert_too_many_indices(%arg0: f32, %arg1: tensor<?xf32>) {
+func.func @extract_too_many_static_indices(%arg0: tensor<?xf32>) {
+  // expected-error@+1 {{incorrect number of indices for extract_element}}
+  %0 = tensor.extract %arg0[2, 3] : tensor<?xf32>
+  return
+}
+
+// -----
+
+func.func @extract_too_many_mixed_indices(%arg0: tensor<?xf32>) {
+  %c1 = arith.constant 1 : index
+  // expected-error@+1 {{incorrect number of indices for extract_element}}
+  %0 = tensor.extract %arg0[%c1, 2, 3] : tensor<?xf32>
+  return
+}
+
+// -----
+
+func.func @insert_too_few_indices(%arg0: f32, %arg1: tensor<?xf32>) {
   // expected-error@+1 {{incorrect number of indices}}
   %0 = tensor.insert %arg0 into %arg1[] : tensor<?xf32>
   return
@@ -80,6 +97,23 @@ func.func @insert_too_many_indices(%arg0: f32, %arg1: tensor<?xf32>) {
 
 // -----
 
+func.func @insert_too_many_indices(%arg0: f32, %arg1: tensor<?xf32>) {
+  // expected-error@+1 {{incorrect number of indices}}
+  %0 = tensor.insert %arg0 into %arg1[2, 3] : tensor<?xf32>
+  return
+}
+
+// -----
+
+func.func @insert_too_many_mixed_indices(%arg0: f32, %arg1: tensor<?xf32>) {
+  %c1 = arith.constant 1 : index
+  // expected-error@+1 {{incorrect number of indices}}
+  %0 = tensor.insert %arg0 into %arg1[%c1, 2, 3] : tensor<?xf32>
+  return
+}
+
+// -----
+
 func.func @tensor.from_elements_wrong_result_type() {
   // expected-error@+2 {{'tensor.from_elements' invalid kind of type specified}}
   %c0 = arith.constant 0 : i32
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 378137a14b59ff..0a4cd08239c5b4 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -58,6 +58,9 @@ func.func @empty_with_encoding(%sz: index) -> tensor<5x?x6xf32, "foo"> {
 func.func @extract(%arg0: tensor<?x?x?xf32>, %arg1: index) {
   // CHECK: tensor.extract %[[TENSOR]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] : tensor<?x?x?xf32>
   %0 = tensor.extract %arg0[%arg1, %arg1, %arg1] : tensor<?x?x?xf32>
+
+  // CHECK: tensor.extract %[[TENSOR]][%[[INDEX]], 2, 3] : tensor<?x?x?xf32>
+  %1 = tensor.extract %arg0[%arg1, 2, 3] : tensor<?x?x?xf32>
   return
 }
 
@@ -70,6 +73,9 @@ func.func @extract(%arg0: tensor<?x?x?xf32>, %arg1: index) {
 func.func @insert(%arg0: f32, %arg1: index, %arg2: tensor<?x?x?xf32>) {
   // CHECK: tensor.insert %[[SCALAR]] into %[[DEST1]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] : tensor<?x?x?xf32>
   %0 = tensor.insert %arg0 into %arg2[%arg1, %arg1, %arg1] : tensor<?x?x?xf32>
+
+  // CHECK: tensor.insert %[[SCALAR]] into %[[DEST1]][%[[INDEX]], 2, 3] : tensor<?x?x?xf32>
+  %1 = tensor.insert %arg0 into %arg2[%arg1, 2, 3] : tensor<?x?x?xf32>
   return
 }
 

@jpienaar self-requested a review, August 15, 2024 21:37
@matthias-springer (Member) left a comment


There are a few more places that have to be updated. Basically everything that calls ExtractOp::getIndices or InsertOp::getIndices. E.g., Tensor/Transforms/BufferizableOpInterfaceImpl.cpp.

$scalar `into` $dest `[` $indices `]` attr-dict `:` type($dest)
$scalar `into`
$dest `` custom<DynamicIndexList>($indices, $static_indices)
attr-dict `:` type($dest)
}];

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
Member:

Both ops should have a getMixedIndices function, same as getMixedOffsets etc. of InsertSliceOp/ExtractSliceOp.
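For illustration, here is the shape such a helper could take, mirroring MLIR's getMixedValues utility. This is a self-contained C++ sketch: FoldResult and the kDynamic sentinel value are plain stand-ins for the real MLIR types (OpFoldResult, ShapedType::kDynamic), so the concrete names and values here are assumptions, not the PR's code.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <variant>
#include <vector>

// Stand-in for ShapedType::kDynamic; the concrete sentinel value is an
// assumption for this sketch.
constexpr int64_t kDynamic = INT64_MIN;

// Stand-in for OpFoldResult: either a folded constant or an SSA value name.
using FoldResult = std::variant<int64_t, std::string>;

// Rebuild the mixed index list the way a getMixedIndices accessor would:
// every kDynamic slot in the static array consumes the next dynamic
// operand; every other slot yields its constant directly.
std::vector<FoldResult>
getMixedIndices(const std::vector<int64_t> &staticIndices,
                const std::vector<std::string> &dynamicIndices) {
  std::vector<FoldResult> mixed;
  size_t dynIt = 0;
  for (int64_t s : staticIndices) {
    if (s == kDynamic)
      mixed.push_back(dynamicIndices[dynIt++]);
    else
      mixed.push_back(s);
  }
  assert(dynIt == dynamicIndices.size() && "leftover dynamic indices");
  return mixed;
}
```

So an op like tensor.extract %rt[%1, 4] would store static_indices = [kDynamic, 4] plus one dynamic operand, and the accessor reassembles the mixed list [%1, 4].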

Contributor (Author):

Thanks for the suggestion! Do you think a new interface like MixedIndicesInterface is needed for this?

Member:

That could be useful. (But can also be done without.) I tried something like that in the past (https://reviews.llvm.org/D156899), but I didn't land it for some reason... Don't really remember why. There was also an RFC (https://discourse.llvm.org/t/rfc-more-opfoldresult-and-mixed-indices-in-ops-that-deal-with-shaped-values/72510). I would read through that discussion before adding a new interface.

Member:

+1. I think there is something more general to be done here; this is a good starting point, and from here we can see other pain points and think about generalizing.

// Build an ExtractOp with mixed static, dynamic indexes and inferred result type.
OpBuilder<(ins "Type":$resultType, "Value":$tensor, "ArrayRef<OpFoldResult>":$indexes,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
// Build an ExtractOp with dynamic indexes.
Member:

There are some spelling inconsistencies: indexes vs. indices.


int64_t staticIndice = op.getStaticIndices()[0];
Type indexType = rewriter.getIndexType();
Value indice =
Member:

typo: indices

Member:

Rather index?

Member:

(funnily enough it seems indice is index in Spanish)

if (i != ShapedType::kDynamic) {
indices.push_back(i);
} else {
Attribute indice = *dynamicIndicesIt;
Member:

typo: indices




int64_t staticIndice = op.getStaticIndices()[0];
Type indexType = rewriter.getIndexType();
Value indice =
staticIndice != ShapedType::kDynamic
Member:

Prefer using isDynamic(staticIndex)

if (!tensorShapeOfOp)
return rewriter.notifyMatchFailure(op, "producer is not shape.shape_of");

int64_t staticIndice = op.getStaticIndices()[0];
Member:

It's sort of weird that a static index could be dynamic ... I seem to recall poking at this on a previous review: why not store only static indices in one list and only dynamic ones in the other, and then use the type to differentiate? That would result in more operations for indexing. Not something to address here, as this PR keeps the existing form.
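The encoding under discussion can be sketched in isolation: a mixed index list is split into a "static" array (with a sentinel marking dynamic slots) plus a list of dynamic operands, the way dispatchIndexOpFoldResults does in the PR. The snippet below uses plain C++ stand-ins; FoldResult and the sentinel value are assumptions for illustration, not MLIR's actual types.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <utility>
#include <variant>
#include <vector>

// Stand-in for ShapedType::kDynamic; the concrete value is an assumption.
constexpr int64_t kDynamic = INT64_MIN;

// Stand-in for OpFoldResult: either a folded constant or an SSA value name.
using FoldResult = std::variant<int64_t, std::string>;

// Split a mixed index list: constants go into the static array, SSA values
// into the operand list, and the static array records kDynamic as a
// placeholder marker for each dynamic slot.
std::pair<std::vector<int64_t>, std::vector<std::string>>
splitMixedIndices(const std::vector<FoldResult> &mixed) {
  std::vector<int64_t> staticIndices;
  std::vector<std::string> dynamicIndices;
  for (const FoldResult &fr : mixed) {
    if (const int64_t *c = std::get_if<int64_t>(&fr)) {
      staticIndices.push_back(*c);
    } else {
      staticIndices.push_back(kDynamic);
      dynamicIndices.push_back(std::get<std::string>(fr));
    }
  }
  return {staticIndices, dynamicIndices};
}
```

The "weirdness" raised above is visible here: the static array can contain the value kDynamic, i.e. a "static" slot that is not actually static.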

ArrayRef<int64_t> staticIndices) {
auto tensorType = llvm::cast<RankedTensorType>(tensor.getType());
int64_t dynamicDimCount = llvm::count_if(staticIndices, [](int64_t element) {
return element == ShapedType::kDynamic;
Member:

Same



@joker-eph (Collaborator) left a comment


What is the motivation for this?

In general I see it as an anti-pattern for an operation to accept both an attribute and an SSA value for the same thing.
I thought the "experiment" with the few ops that adopted this form in the past was, in the end, deemed a "bad idea"?

@MaheshRavishankar (Contributor)

MaheshRavishankar commented Aug 19, 2024

What is the motivation for this?

In general I see this as an anti-pattern to have operation accept both an attribute and an SSA value.
I thought the "experiment" with the few ops that adopted this form in the past was deemed like an accepted "bad idea" in the end?

I think there is a long history of discussion here, but I for one prefer the mixed static/dynamic list. It is quite readable, and constant values are folded into the operation itself instead of having to look up use-def chains. So I would very much prefer having a general utility of the form being added in this PR. So +1 from me!

@joker-eph (Collaborator)

joker-eph commented Aug 19, 2024

To be clear: it has nothing to do with textual readability but instead more with the uniformity of the APIs (or lack thereof) and the complexity imposed on every user.

@MaheshRavishankar (Contributor)

To be clear: it has nothing to do with textual readability but instead more with the uniformity of the APIs (or lack thereof) and the complexity imposed on every user.

Could you explain more? It's opt-in based on the op. So what is the complexity you are referring to?

@joker-eph (Collaborator)

It's opt-in based on the op.

I don't quite get what you mean by this?

So what is the complexity you are referring to?

The C++ API provides accessors for either Value or Attribute, which describe the same thing and are both optional (exactly one of them must be present, in a mutually exclusive way).
MLIR core infra is just not set up to support this use case very well (we could have improved it to do so, but it isn't a pattern that has been widely accepted, so that hasn't happened). For example, you can use getOperands() to get a ValueRange, and ODS can generate ValueRange accessors for named operands, but this "idiom" is not abstracted anywhere in the same way, and the complexity is pushed onto every API user.

@MaheshRavishankar (Contributor)

It's opt-in based on the op.

I don't quite get what you mean by this?

I mean it only matters to the op that wants to use this paradigm. So in that sense it is opt-in.

So what is the complexity you are referring to?

The C++ API provides accessors for either Value or Attribute which describe the same thing and are both optional (one of them must be present, in an exclusionary way).

Not sure that is true. As I understand it, the attribute is always required; the Values are needed only if the attribute has a sentinel value saying the dimension is dynamic.

MLIR core infra is just not setup to support this use case very well (we could have improved it to do this, but it isn't a pattern that has been widely accepted and so it hasn't happened). For example you can use getOperands() to get a ValueRange, and ODS can generate accessors for ValueRange for named operands, but this "idiom" is not correctly abstracted anywhere in the same way and the complexity is pushed on every API users.

That is true. Most ops that use this paradigm rely on custom methods to get the mixed static/dynamic values. This PR, IMO, starts to build support for this paradigm and makes its usage more uniform.

@joker-eph (Collaborator)

Most ops that use this use custom methods to get mixed static/dynamic values. This PR IMO is starting to build support for this paradigm and make it's usage more uniform

So far you have motivated this by "it is nicer to read": I would strongly object to doing anything at this scale for that reason. I am sure there are other reasons, but I'd want to see this well motivated (why this op and not all operations?) and carefully considered, because this isn't trivial.

In the past this pattern started with other motivations (@nicolasvasilache or @ftynse), but the tradeoffs didn't pan out, as far as I can remember.

@matthias-springer (Member)

Another point that wasn't mentioned yet: Static indices can be verified. Dynamic indices should not be verified as per verifier guidelines.

@MaheshRavishankar (Contributor)

MaheshRavishankar commented Aug 20, 2024

Most ops that use this use custom methods to get mixed static/dynamic values. This PR IMO is starting to build support for this paradigm and make its usage more uniform.

So far you motivated this by “it is nicer to read”: I would strongly object to doing anything this scale for this reason. I am sure there are others though but I’d want to see this well motivated (why this op and not just all operations?), and carefully considered, because this isn’t trivial.

in the past this pattern started with other motivations (@nicolasvasilache or @ftynse ), but the tradeoffs didn’t pan out as far as I can remember.

Agreed that "nicer to read" isn't a good enough reason (but it is a good side effect). For me it's more about the fact that things that are constant, like constant values in the list of offsets, sizes and strides of slices (or indices in this PR), are "embedded" into the operation itself. I don't need to look at the producer of an operand to see if it is a constant (and, as Matthias mentioned, the verifier can check these constant values). So that's why it's a +1 from me on this.

basically

%0 = tensor.extract %b[10, 20] : tensor<32x64xf32>

is more self-contained than

%c10 = arith.constant 10 : index
%c20 = arith.constant 20 : index
%0 = tensor.extract %b[%c10, %c20] : tensor<32x64xf32>

and also allows you to verify that [10, 20] is in-bounds of %b without having to look at producers of operands.
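That in-bounds check can be sketched as follows. This is an illustrative sketch only, not the actual `tensor.extract` verifier; `kDynamic` is an assumed sentinel, and dynamic slots and dynamic dimensions are skipped, matching the guideline that only static values are verified.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <string>
#include <vector>

// Assumed sentinel for a dynamic index or dynamic dimension size.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Returns an error message, or an empty string on success. With the
// indices embedded in the op, this needs nothing but the op itself:
// no walk over the producers of operands.
std::string verifyInBounds(const std::vector<int64_t> &staticIndices,
                           const std::vector<int64_t> &shape) {
  if (staticIndices.size() != shape.size())
    return "rank of indices does not match rank of tensor";
  for (size_t i = 0; i < staticIndices.size(); ++i) {
    int64_t idx = staticIndices[i];
    if (idx == kDynamic || shape[i] == kDynamic)
      continue; // nothing to check statically
    if (idx < 0 || idx >= shape[i])
      return "index out of bounds at dimension " + std::to_string(i);
  }
  return "";
}
```

For the example above, `verifyInBounds({10, 20}, {32, 64})` succeeds, while a shape of `{32, 16}` would be rejected at verification time rather than at runtime.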

@joker-eph
Collaborator

For me its more about things that are constant,

Either they are constant and always attributes, or they are "maybe" constant and that's an SSA value that you can match. It's the same with every single operation though as far as I can tell. I don't see what's special here.
(hence my comment: if this were the intended pattern, we'd have built it into the core infra and generalized it as a first-class thing).

@MaheshRavishankar
Contributor

For me its more about things that are constant,

Either they are constant and always attributes, or they are "maybe" constant and that's an SSA value that you can match. It's the same with every single operation though as far as I can tell. I don't see what's special here.
(hence my comment: if this were the intended pattern, we'd have built it into the core infra and generalized it as a first-class thing).

Sure... What would first-class support for this look like? Could we build it in now?

@joker-eph
Collaborator

Sure... What would first-class support for this look like? Could we build it in now?

I don't know what it would look like (something like: replace Value with OpFoldResult for all APIs on the Operation class to begin with? Changes to ODS?).
Right now MLIR is just designed as "constants are defined with separate operations"; this is not a simple change and it would need a very strong motivation.
(Attributes are not a means to fold SSA values; they are meant for information required by lowering.)

@jpienaar
Member

I don't see this as proposing a change to MLIR core. This is a change to the Tensor dialect, making the insert and extract ops more uniform within the dialect itself.

I currently have inputs where I have O(40k) scalar constant operations whose only reason for existing is indexing. It's inefficient, it's more difficult to test, and it's cumbersome to write (using a little lazy cache per isolated-from-above context to avoid ending up with O(400k) scalar constant operations that get CSE'd later). It's not special to indexing ops that one could capture constant values, but it is very wasteful for them (I forget the exact figure, but it's 10x higher storage than just the parameter).

We could also just introduce new ops, tensor.insert_static and tensor.extract_static, that are all-static. That is an option. It introduces a different set of conditional "expansions" (e.g., where insert is expected today, one would need to consider the static insert too).

Now wrt the API, indeed this change doesn't alter those. An option is to add methods like std::optional<int64_t> getConstantDim(i), which would either look at the constant property or match on the operand to try to extract a constant value, and std::optional<Value> getDynamicDim(i), which would just return the operand, else nullopt. This still results in conditionals, but for the static side that's not new (match is already a conditional) and it is unseen by the user, while for the dynamic side I actually don't know how it's being used except for codegen later, and there it is most likely after checking whether it's already constant :) This also doesn't do any automatic promotion to static args, so it only applies if the op was created that way.
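A sketch of those two accessors, simplified: this is over plain vectors rather than real MLIR attributes/operands, the names and signatures follow the comment above but are otherwise assumptions, `int64_t` again stands in for an SSA `Value`, and the "match on the operand" fallback for constants is omitted.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

// Assumed sentinel marking a dynamic slot in the static array.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

struct IndexedOp {
  std::vector<int64_t> staticIndices;  // kDynamic marks a dynamic slot
  std::vector<int64_t> dynamicValues;  // stand-ins for SSA operands

  // Constant value of index i, if it is static; nullopt otherwise.
  std::optional<int64_t> getConstantDim(size_t i) const {
    if (staticIndices[i] == kDynamic)
      return std::nullopt;
    return staticIndices[i];
  }

  // Dynamic operand backing index i, if any; nullopt if it is static.
  std::optional<int64_t> getDynamicDim(size_t i) const {
    if (staticIndices[i] != kDynamic)
      return std::nullopt;
    size_t pos = 0; // count earlier dynamic slots to locate our operand
    for (size_t j = 0; j < i; ++j)
      if (staticIndices[j] == kDynamic)
        ++pos;
    return dynamicValues[pos];
  }
};
```

Callers still branch on the optionals, but the conditional stays behind the accessor rather than being re-derived by matching producers at every use site.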

@joker-eph
Collaborator

joker-eph commented Aug 22, 2024

I don't see this as proposing a change to MLIR core

Nobody said otherwise I believe.

We could also just introduce new ops: tensor.insert_static and tensor.extract_static that is all static

I don’t have an objection in principle to this if there is a use case for static indexing.

Thanks for providing a motivation, by the way; we can discuss tradeoffs with concrete info :)
(In this case it seems like the usual missing map for pooling constants, likely something we should invest in as well!)

@matthias-springer
Member

What do you think about OpBuilder::createOrFold? It currently returns Value, but I think it should return OpFoldResult. I tried to change that at some point (but it was too much work). As a follow-up, I wanted to update more tensor ops to store "mixed" values (similar to what this PR is doing). Then the API would work nicely with createOrFold, in the sense that an op builder can take the result of a createOrFold and does not have to build an op for a constant value. But that never happened...

6 participants