[mlir] Change tensor.extract/insert to take static/dynamic indices. #104488

Open · wants to merge 1 commit into main
52 changes: 48 additions & 4 deletions mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -332,12 +332,37 @@ def Tensor_ExtractOp : Tensor_Op<"extract", [
```mlir
%4 = tensor.extract %t[%1, %2] : tensor<4x4xi32>
%5 = tensor.extract %rt[%1, %2] : tensor<?x?xi32>
%6 = tensor.extract %rt[3, 4] : tensor<?x?xi32>
%7 = tensor.extract %rt[%1, 4] : tensor<?x?xi32>
```
}];

let arguments = (ins AnyRankedTensor:$tensor, Variadic<Index>:$indices);
let arguments = (ins
AnyRankedTensor:$tensor,
Variadic<Index>:$indices,
DenseI64ArrayAttr:$static_indices
);
let results = (outs AnyType:$result);
let assemblyFormat = "$tensor `[` $indices `]` attr-dict `:` type($tensor)";
let assemblyFormat = [{
$tensor ``
custom<DynamicIndexList>($indices, $static_indices)
attr-dict `:` type($tensor)
}];

let builders = [
// Build an ExtractOp with mixed static and dynamic indexes and inferred result type.
OpBuilder<(ins "Value":$tensor, "ArrayRef<OpFoldResult>":$indexes,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
// Build an ExtractOp with mixed static and dynamic indexes and an explicit result type.
OpBuilder<(ins "Type":$resultType, "Value":$tensor, "ArrayRef<OpFoldResult>":$indexes,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
// Build an ExtractOp with dynamic indexes and inferred result type.
[Review comment, Member] There are some spelling inconsistencies: indexes vs. indices.

OpBuilder<(ins "Value":$source, CArg<"ValueRange", "{}">:$indexes,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
// Build an ExtractOp with dynamic indexes and an explicit result type.
OpBuilder<(ins "Type":$resultType, "Value":$source, CArg<"ValueRange", "{}">:$indexes,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
];

let hasCanonicalizer = 1;
let hasFolder = 1;
@@ -808,16 +833,35 @@ def Tensor_InsertOp : Tensor_Op<"insert", [

let arguments = (ins AnyType:$scalar,
AnyRankedTensor:$dest,
Variadic<Index>:$indices);
Variadic<Index>:$indices,
DenseI64ArrayAttr:$static_indices
);
let results = (outs AnyRankedTensor:$result);
let assemblyFormat = [{
$scalar `into` $dest `[` $indices `]` attr-dict `:` type($dest)
$scalar `into`
$dest `` custom<DynamicIndexList>($indices, $static_indices)
attr-dict `:` type($dest)
}];
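// Printed form under the new format (illustrative sketch mirroring the
// extract examples above; %f, %dest, and %c1 are assumed values):
//   %1 = tensor.insert %f into %dest[%c1, 4] : tensor<?x?xf32>
//   %2 = tensor.insert %f into %dest[3, 4] : tensor<?x?xf32>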

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
[Review comment, Member] Both ops should have a getMixedIndices function, same as getMixedOffsets etc. of InsertSliceOp/ExtractSliceOp.

[Review comment, Author] Thanks for the suggestion! Do you think a new interface like MixedIndicesInterface is needed for this?

[Review comment, Member] That could be useful (but it can also be done without one). I tried something like that in the past (https://reviews.llvm.org/D156899) but didn't land it; I don't remember why. There was also an RFC (https://discourse.llvm.org/t/rfc-more-opfoldresult-and-mixed-indices-in-ops-that-deal-with-shaped-values/72510). I would read through that discussion before adding a new interface.

[Review comment, Member] +1. I think there is something more general to be done here; this is a good starting point, and we can look at other pain points and think about generalizing afterwards. A minimal sketch of the suggested accessor follows (illustrative only, not part of this diff), assuming the getMixedValues helper from mlir/Dialect/Utils/StaticValueUtils.h and the getStaticIndices/getIndices accessors generated by this patch:
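// Sketch only: returns all indices as OpFoldResults, static entries as
// IntegerAttrs and dynamic entries as the corresponding index SSA values.
SmallVector<OpFoldResult> getMixedIndices() {
  Builder b(getContext());
  return getMixedValues(getStaticIndices(), getIndices(), b);
}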

}];

let builders = [
// Build an InsertOp with mixed static and dynamic indexes and inferred result type.
OpBuilder<(ins "Value":$scalar, "Value":$dest, "ArrayRef<OpFoldResult>":$indexes,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
// Build an InsertOp with mixed static and dynamic indexes and an explicit result type.
OpBuilder<(ins "Type":$resultType, "Value":$scalar, "Value":$dest, "ArrayRef<OpFoldResult>":$indexes,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
// Build an InsertOp with dynamic indexes and inferred result type.
OpBuilder<(ins "Value":$scalar, "Value":$dest, CArg<"ValueRange", "{}">:$indexes,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
// Build an InsertOp with dynamic indexes and an explicit result type.
OpBuilder<(ins "Type":$resultType, "Value":$scalar, "Value":$dest, CArg<"ValueRange", "{}">:$indexes,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
];

let hasFolder = 1;
let hasVerifier = 1;
}
26 changes: 26 additions & 0 deletions mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1736,6 +1736,32 @@ struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
}
};

struct ExtractFromShapeOfExtentTensor
: public OpRewritePattern<tensor::ExtractOp> {
using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

LogicalResult matchAndRewrite(tensor::ExtractOp op,
PatternRewriter &rewriter) const override {
auto tensorShapeOfOp = op.getTensor().getDefiningOp<shape::ShapeOfOp>();
if (!tensorShapeOfOp)
return rewriter.notifyMatchFailure(op, "producer is not shape.shape_of");

int64_t staticIndice = op.getStaticIndices()[0];
[Review comment, Member] It's sort of weird that a static index can be dynamic... I seem to recall poking at this on a previous review: why not store only static entries in one list and only dynamic ones in the other, and use the type to differentiate? That would result in more operations for indexing, though. Not something to address here, as this PR keeps the existing form.

Type indexType = rewriter.getIndexType();
Value indice =
[Review comment, Member] typo: indices
[Review comment, Member] Rather, index?
[Review comment, Member] (Funnily enough, "indice" is "index" in Spanish.)

staticIndice != ShapedType::kDynamic
[Review comment, Member] Prefer using isDynamic(staticIndex).

? tensorShapeOfOp->getDialect()
->materializeConstant(
rewriter, IntegerAttr::get(indexType, staticIndice),
indexType, op.getLoc())
->getResult(0)
: op.getIndices()[0];
rewriter.replaceOpWithNewOp<tensor::DimOp>(op, tensorShapeOfOp.getArg(),
indice);
return success();
}
};
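// For illustration (assumed effect of the pattern above; %t is a placeholder):
//   %shape = shape.shape_of %t : tensor<?x?xf32> -> tensor<2xindex>
//   %d = tensor.extract %shape[1] : tensor<2xindex>
// canonicalizes to:
//   %c1 = arith.constant 1 : index
//   %d = tensor.dim %t, %c1 : tensor<?x?xf32>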

// Canonicalize
// ```
// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
6 changes: 0 additions & 6 deletions mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -44,9 +44,3 @@ def SizeToIndexToSizeCanonicalization : Pat<
def TensorCastConstShape : Pat <
(Tensor_CastOp:$res (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg),
[(HasStaticShape $res)]>;

// tensor.extract from shape_of -> tensor.dim. We can take the first index
// because shape_of always returns a 1D tensor.
def ExtractFromShapeOfExtentTensor : Pat<
(Tensor_ExtractOp (Shape_ShapeOfOp $arg), $indices),
(Tensor_DimOp $arg, (TakeFront $indices))>;
114 changes: 105 additions & 9 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -39,6 +39,19 @@ using llvm::divideCeilSigned;
using llvm::divideFloorSigned;
using llvm::mod;

static LogicalResult
checkTensorRankMatchIndices(Value tensor, ValueRange dynamicIndices,
ArrayRef<int64_t> staticIndices) {
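  // Illustrative example of the encoding checked here (assumed from this
  // patch): for %0 = tensor.extract %t[%i, 3] : tensor<4x4xi32>,
  // staticIndices is [kDynamic, 3] and dynamicIndices is [%i]; the tensor
  // rank (2) must equal staticIndices.size(), and the number of kDynamic
  // entries (1) must equal dynamicIndices.size().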
auto tensorType = llvm::cast<RankedTensorType>(tensor.getType());
int64_t dynamicDimCount = llvm::count_if(staticIndices, [](int64_t element) {
return element == ShapedType::kDynamic;
[Review comment, Member] Same: prefer isDynamic over comparing against ShapedType::kDynamic.

});
if (tensorType.getRank() != staticIndices.size() ||
dynamicDimCount != static_cast<int64_t>(dynamicIndices.size()))
return LogicalResult::failure();
return LogicalResult::success();
}

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *TensorDialect::materializeConstant(OpBuilder &builder,
@@ -1120,10 +1133,49 @@ void ExtractOp::getAsmResultNames(
setNameFn(getResult(), "extracted");
}

// Build an ExtractOp with mixed static and dynamic indexes and inferred result
// type.
void ExtractOp::build(OpBuilder &b, OperationState &result, Value tensor,
ArrayRef<OpFoldResult> indices,
ArrayRef<NamedAttribute> attrs) {
Type resultType = llvm::cast<TensorType>(tensor.getType()).getElementType();
build(b, result, resultType, tensor, indices, attrs);
}

// Build an ExtractOp with mixed static and dynamic indexes and an explicit
// result type.
void ExtractOp::build(OpBuilder &b, OperationState &result, Type resultType,
Value tensor, ArrayRef<OpFoldResult> indices,
ArrayRef<NamedAttribute> attrs) {
SmallVector<int64_t> staticIndices;
SmallVector<Value> dynamicIndices;
dispatchIndexOpFoldResults(indices, dynamicIndices, staticIndices);
result.addAttributes(attrs);
build(b, result, resultType, tensor, dynamicIndices,
b.getDenseI64ArrayAttr(staticIndices));
}

// Build an ExtractOp with dynamic indexes and an explicit result type.
void ExtractOp::build(OpBuilder &b, OperationState &result, Type resultType,
Value tensor, ValueRange indices,
ArrayRef<NamedAttribute> attrs) {
SmallVector<OpFoldResult> indicesValues = llvm::to_vector<4>(
llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; }));
build(b, result, resultType, tensor, indicesValues, attrs);
}

// Build an ExtractOp with dynamic indexes and inferred result type.
void ExtractOp::build(OpBuilder &b, OperationState &result, Value tensor,
ValueRange indices, ArrayRef<NamedAttribute> attrs) {
Type resultType = llvm::cast<TensorType>(tensor.getType()).getElementType();
SmallVector<OpFoldResult> indicesValues = llvm::to_vector<4>(
llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; }));
build(b, result, resultType, tensor, indicesValues, attrs);
}
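// Usage sketch for the mixed-index builder above (illustrative only; `b`,
// `loc`, `tensor`, and `dynIdx` are assumed to exist at the call site):
//   SmallVector<OpFoldResult> mixed = {dynIdx, b.getIndexAttr(4)};
//   auto extracted = b.create<tensor::ExtractOp>(loc, tensor, mixed);
// creates `tensor.extract %t[%i, 4]` with one dynamic and one static index.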

LogicalResult ExtractOp::verify() {
// Verify the # indices match if we have a ranked type.
auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
if (failed(checkTensorRankMatchIndices(getTensor(), getIndices(),
getStaticIndices())))
return emitOpError("incorrect number of indices for extract_element");
return success();
}
@@ -1137,12 +1189,18 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {

// Collect the constant indices into the tensor.
SmallVector<uint64_t, 8> indices;
for (Attribute indice : adaptor.getIndices()) {
if (!indice || !llvm::isa<IntegerAttr>(indice))
return {};
indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
auto dynamicIndicesIt = adaptor.getIndices().begin();
for (int64_t i : getStaticIndices()) {
if (i != ShapedType::kDynamic) {
indices.push_back(i);
} else {
Attribute indice = *dynamicIndicesIt;
[Review comment, Member] typo: indices
if (!indice || !llvm::isa<IntegerAttr>(indice))
return {};
indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
dynamicIndicesIt++;
}
}

// Fold extract(from_elements(...)).
if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
@@ -1354,10 +1412,48 @@ void InsertOp::getAsmResultNames(
setNameFn(getResult(), "inserted");
}

// Build an InsertOp with mixed static and dynamic indexes and inferred result
// type.
void InsertOp::build(OpBuilder &b, OperationState &result, Value scalar,
Value dest, ArrayRef<OpFoldResult> indices,
ArrayRef<NamedAttribute> attrs) {
build(b, result, dest.getType(), scalar, dest, indices, attrs);
}

// Build an InsertOp with mixed static and dynamic indexes and an explicit
// result type.
void InsertOp::build(OpBuilder &b, OperationState &result, Type resultType,
Value scalar, Value dest, ArrayRef<OpFoldResult> indices,
ArrayRef<NamedAttribute> attrs) {
SmallVector<int64_t> staticIndices;
SmallVector<Value> dynamicIndices;
dispatchIndexOpFoldResults(indices, dynamicIndices, staticIndices);
result.addAttributes(attrs);
build(b, result, resultType, scalar, dest, dynamicIndices,
b.getDenseI64ArrayAttr(staticIndices));
}

// Build an InsertOp with dynamic indexes and an explicit result type.
void InsertOp::build(OpBuilder &b, OperationState &result, Type resultType,
Value scalar, Value dest, ValueRange indices,
ArrayRef<NamedAttribute> attrs) {
SmallVector<OpFoldResult> indicesValues = llvm::to_vector<4>(
llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; }));
build(b, result, resultType, scalar, dest, indicesValues, attrs);
}

// Build an InsertOp with dynamic indexes and inferred result type.
void InsertOp::build(OpBuilder &b, OperationState &result, Value scalar,
Value dest, ValueRange indices,
ArrayRef<NamedAttribute> attrs) {
SmallVector<OpFoldResult> indicesValues = llvm::to_vector<4>(
llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; }));
build(b, result, dest.getType(), scalar, dest, indicesValues, attrs);
}

LogicalResult InsertOp::verify() {
// Verify the # indices match if we have a ranked type.
auto destType = llvm::cast<RankedTensorType>(getDest().getType());
if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
if (failed(checkTensorRankMatchIndices(getDest(), getIndices(),
getStaticIndices())))
return emitOpError("incorrect number of indices");
return success();
}
13 changes: 13 additions & 0 deletions mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1519,6 +1519,19 @@ func.func @extract_shapeof(%arg0 : tensor<?x?xf64>) -> index {
return %result : index
}

// -----

// CHECK-LABEL: func @extract_shapeof_static_indice
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?xf64>
func.func @extract_shapeof_static_indice(%arg0 : tensor<?x?xf64>) -> index {
// CHECK: %[[C1:.*]] = arith.constant 1
%shape = shape.shape_of %arg0 : tensor<?x?xf64> -> tensor<2xindex>
// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C1]]
%result = tensor.extract %shape[1] : tensor<2xindex>
// CHECK: return %[[DIM]]
return %result : index
}


// -----

12 changes: 8 additions & 4 deletions mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -137,11 +137,12 @@ func.func @fold_concat(%arg0: tensor<1x2x?xi32>) -> (tensor<1x2x3xi32>, tensor<1
// -----

// CHECK-LABEL: func @fold_extract
func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, i32, complex<f32>) {
%const_0 = arith.constant 0 : index
%const_1 = arith.constant 1 : index
%const_3 = arith.constant 3 : index
// CHECK-DAG: [[C64:%.+]] = arith.constant 64 : i32
// CHECK-DAG: [[CNEG1:%.+]] = arith.constant -1 : i32
// CHECK-DAG: [[C0:%.+]] = arith.constant 0.{{0*}}e+00 : f16
// CHECK-DAG: [[CM2:%.+]] = arith.constant -2.{{0*}}e+00 : f16

Expand All @@ -162,13 +163,16 @@ func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
%3 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32>
%ext_4 = tensor.extract %3[%const_1, %const_0, %const_3] : tensor<2x1x4xi32>

// Fold an extract into a dense tensor with mixed dynamic and static indexes.
%ext_5 = tensor.extract %3[%const_1, 0, 2] : tensor<2x1x4xi32>

// Fold an extract into a complex constant.
// CHECK-DAG: [[C5:%.+]] = complex.constant [1.200000e+00 : f32, 2.300000e+00 : f32] : complex<f32>
%4 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
%ext_5 = tensor.extract %4[] : tensor<complex<f32>>
%ext_6 = tensor.extract %4[] : tensor<complex<f32>>

// CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]]
return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5 : f32, f16, f16, i32, complex<f32>
// CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[CNEG1]], [[C5]]
return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5, %ext_6: f32, f16, f16, i32, i32, complex<f32>
}

// -----
Expand Down
38 changes: 36 additions & 2 deletions mlir/test/Dialect/Tensor/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -64,22 +64,56 @@ func.func @concat_static_shape_mismatch(%arg0: tensor<3xf32>) {

// -----

func.func @extract_too_many_indices(%arg0: tensor<?xf32>) {
func.func @extract_too_few_indices(%arg0: tensor<?xf32>) {
// expected-error@+1 {{incorrect number of indices for extract_element}}
%0 = tensor.extract %arg0[] : tensor<?xf32>
return
}

// -----

func.func @insert_too_many_indices(%arg0: f32, %arg1: tensor<?xf32>) {
func.func @extract_too_many_static_indices(%arg0: tensor<?xf32>) {
// expected-error@+1 {{incorrect number of indices for extract_element}}
%0 = tensor.extract %arg0[2, 3] : tensor<?xf32>
return
}

// -----

func.func @extract_too_many_mixed_indices(%arg0: tensor<?xf32>) {
%c1 = arith.constant 1 : index
// expected-error@+1 {{incorrect number of indices for extract_element}}
%0 = tensor.extract %arg0[%c1, 2, 3] : tensor<?xf32>
return
}

// -----

func.func @insert_too_few_indices(%arg0: f32, %arg1: tensor<?xf32>) {
// expected-error@+1 {{incorrect number of indices}}
%0 = tensor.insert %arg0 into %arg1[] : tensor<?xf32>
return
}

// -----

func.func @insert_too_many_indices(%arg0: f32, %arg1: tensor<?xf32>) {
// expected-error@+1 {{incorrect number of indices}}
%0 = tensor.insert %arg0 into %arg1[2, 3] : tensor<?xf32>
return
}

// -----

func.func @insert_too_many_mixed_indices(%arg0: f32, %arg1: tensor<?xf32>) {
%c1 = arith.constant 1 : index
// expected-error@+1 {{incorrect number of indices}}
%0 = tensor.insert %arg0 into %arg1[%c1, 2, 3] : tensor<?xf32>
return
}

// -----

func.func @tensor.from_elements_wrong_result_type() {
// expected-error@+2 {{'tensor.from_elements' invalid kind of type specified}}
%c0 = arith.constant 0 : i32