[mlir][tensor] Fix bug in ConcatOpInterface
#168676
Conversation
This PR fixes an issue in `ConcatOpInterface` where bufferization of `tensor.concat` fails when the concat dimension is dynamic while the result type is static. The fix unifies the offset and size computation by using `OpFoldResult`, avoiding the need to handle dynamic and static dimension values separately. Fixes #162776.
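For reference, the previously failing shape combination, taken verbatim from the test added in this PR: the inputs' concat dimension is dynamic, but the result type is fully static.

func.func @tensor.concat_mixed_dynamic_static(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>, %h: tensor<8x2xf32>) -> tensor<8x10xf32> {
  %0 = tensor.concat dim(1) %f, %g, %h : (tensor<8x?xf32>, tensor<8x?xf32>, tensor<8x2xf32>) -> tensor<8x10xf32>
  return %0 : tensor<8x10xf32>
}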
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir

Author: Longsheng Mou (CoTinker)

Changes: This PR fixes an issue in `ConcatOpInterface` where `tensor.concat` fails when the concat dimension is dynamic while the result type is static. Full diff: https://github.com/llvm/llvm-project/pull/168676.diff (2 files affected):
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index c607ece418dff..5482cedae71d7 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1132,35 +1132,23 @@ struct ConcatOpInterface
// Extract the dimension for the concat op
uint64_t concatDim = concatOp.getDim();
- bool dynamicConcatDim = false;
SmallVector<OpFoldResult> offsets(tensorType.getRank(),
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides(tensorType.getRank(),
rewriter.getIndexAttr(1));
- SmallVector<OpFoldResult> sizes;
-
- for (const auto &[dimIdx, dimSize] :
- llvm::enumerate(tensorType.getShape())) {
- if (dimSize == ShapedType::kDynamic) {
- auto dimOp = memref::DimOp::create(rewriter, loc, dstBuffer, dimIdx);
- sizes.push_back(dimOp.getResult());
- if (dimIdx == concatDim)
- dynamicConcatDim = true;
- } else {
- sizes.push_back(rewriter.getIndexAttr(dimSize));
- }
- }
-
- int64_t concatDimOffset = 0;
- std::optional<Value> dynamicOffset;
- std::optional<Value> dynamicSize;
- if (dynamicConcatDim) {
- // One or more operands have dynamic size, so we must accumulate the
- // offset with arith ops.
- dynamicOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
- }
+ SmallVector<OpFoldResult> sizes =
+ memref::getMixedSizes(rewriter, loc, dstBuffer);
+
+ AffineExpr d0, d1;
+ bindDims(rewriter.getContext(), d0, d1);
+ // Add two integers.
+ auto sum = [&](OpFoldResult v1, OpFoldResult v2) {
+ return affine::makeComposedFoldedAffineApply(rewriter, loc, d0 + d1,
+ {v1, v2});
+ };
+ OpFoldResult concatDimOffset = rewriter.getIndexAttr(0);
for (auto operand : concatOp.getInputs()) {
// Get the buffer for the operand.
FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options, state);
@@ -1171,18 +1159,10 @@ struct ConcatOpInterface
// so the offset on that axis must accumulate through the loop, and the
// size must change to the size of the current operand.
auto operandTensorType = cast<RankedTensorType>(operand.getType());
- int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim);
-
- if (dynamicConcatDim) {
- offsets[concatDim] = dynamicOffset.value();
- dynamicSize =
- memref::DimOp::create(rewriter, loc, *srcBuffer, concatDim)
- .getResult();
- sizes[concatDim] = dynamicSize.value();
- } else {
- sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
- offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
- }
+ offsets[concatDim] = concatDimOffset;
+ OpFoldResult concatDimSize =
+ memref::getMixedSize(rewriter, loc, *srcBuffer, concatDim);
+ sizes[concatDim] = concatDimSize;
// Create a subview of the destination buffer.
auto dstMemrefType = cast<MemRefType>(memrefType);
@@ -1197,12 +1177,7 @@ struct ConcatOpInterface
if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
return failure();
- if (dynamicConcatDim) {
- dynamicOffset = arith::AddIOp::create(
- rewriter, loc, dynamicOffset.value(), dynamicSize.value());
- } else {
- concatDimOffset += operandConcatDimSize;
- }
+ concatDimOffset = sum(concatDimOffset, concatDimSize);
}
replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 5eb2360a29b8f..be8ce20d8f154 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -678,11 +678,9 @@ func.func @tensor.concat_different_shapes(%f: tensor<8x4xf32>, %g: tensor<8x5xf3
// CHECK-DAG: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
// CHECK: %[[ALLOC:.*]] = memref.alloc
// CHECK-SAME: memref<8x?xf32>
-// CHECK-DAG: %[[OFFSET:.*]] = arith.constant 0 : index
-// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, %[[F_DIM]]] [1, 1]
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, %[[F_DIM]]] [1, 1]
// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
-// CHECK: %[[OFFSET_2:.*]] = arith.addi %[[OFFSET]], %[[F_DIM]] : index
-// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [8, %[[G_DIM]]] [1, 1]
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [8, %[[G_DIM]]] [1, 1]
// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
// CHECK: return %[[RET]]
@@ -706,10 +704,9 @@ func.func @tensor.concat_dynamic(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>) -> te
// CHECK: %[[ALLOC:.*]] = memref.alloc
// CHECK-SAME: memref<?x?xf32>
// CHECK-DAG: %[[NON_CONCAT_DIM:.*]] = memref.dim %[[ALLOC]], %[[c0]]
-// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[c0]]] [%[[NON_CONCAT_DIM]], %[[F_DIM]]] [1, 1]
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [%[[NON_CONCAT_DIM]], %[[F_DIM]]] [1, 1]
// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
-// CHECK: %[[OFFSET_2:.*]] = arith.addi %[[c0]], %[[F_DIM]] : index
-// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [%[[NON_CONCAT_DIM]], %[[G_DIM]]] [1, 1]
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [%[[NON_CONCAT_DIM]], %[[G_DIM]]] [1, 1]
// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
// CHECK: return %[[RET]]
@@ -721,6 +718,35 @@ func.func @tensor.concat_dynamic_nonconcat_dim(%f: tensor<?x?xf32>, %g: tensor<?
// -----
+// CHECK: #[[$sum_map:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+
+// CHECK-LABEL: func @tensor.concat_mixed_dynamic_static(
+// CHECK-SAME: %[[F:.*]]: tensor<8x?xf32>, %[[G:.*]]: tensor<8x?xf32>,
+// CHECK-SAME: %[[H:.*]]: tensor<8x2xf32>)
+// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
+// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_buffer %[[G]]
+// CHECK-DAG: %[[H_MEMREF:.*]] = bufferization.to_buffer %[[H]]
+// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x10xf32>
+// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+// CHECK: %[[F_DIM:.*]] = memref.dim %[[F_MEMREF]], %[[c1]]
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, %[[F_DIM]]] [1, 1]
+// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
+// CHECK: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [8, %[[G_DIM]]] [1, 1]
+// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
+// CHECK: %[[OFFSET:.*]] = affine.apply #[[$sum_map]]()[%[[F_DIM]], %[[G_DIM]]]
+// CHECK: %[[SUBVIEW3:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, 2] [1, 1]
+// CHECK: memref.copy %[[H_MEMREF]], %[[SUBVIEW3]]
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: return %[[RET]]
+// CHECK: }
+func.func @tensor.concat_mixed_dynamic_static(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>, %h: tensor<8x2xf32>) -> tensor<8x10xf32> {
+ %0 = tensor.concat dim(1) %f, %g, %h : (tensor<8x?xf32>, tensor<8x?xf32>, tensor<8x2xf32>) -> tensor<8x10xf32>
+ return %0 : tensor<8x10xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @tensor.splat_dynamic(
// CHECK-SAME: %[[F:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[M:[a-zA-Z0-9_]+]]: index
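Spelled out, the interesting part of the bufferized output for the mixed test reduces to the IR below (a sketch assembled from the CHECK lines above; SSA names are illustrative). The running offset only materializes as an affine.apply once a dynamic size enters the sum; for purely static prefixes, makeComposedFoldedAffineApply folds the addition away and the subview offset stays a constant.

// Offset of the third operand: sum of the two dynamic sizes before it.
%offset = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%f_dim, %g_dim]
%subview = memref.subview %alloc[0, %offset] [8, 2] [1, 1]
    : memref<8x10xf32> to memref<8x2xf32, strided<[10, 1], offset: ?>>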
Friendly ping~
Review thread on the new test in bufferize.mlir:

// CHECK: return %[[RET]]
// CHECK: }
func.func @tensor.concat_mixed_dynamic_static(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>, %h: tensor<8x2xf32>) -> tensor<8x10xf32> {
  %0 = tensor.concat dim(1) %f, %g, %h : (tensor<8x?xf32>, tensor<8x?xf32>, tensor<8x2xf32>) -> tensor<8x10xf32>
This looks odd to me: I expected that the result dimension must be dynamic if and only if one of the input dimensions is dynamic. We follow this design in other operations such as tensor.collapse_shape; e.g., see CollapseShapeOp::inferCollapsedType for details. Can the verifier be made more strict instead? @qedawkins
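For contrast, a sketch of the stricter convention being referenced: tensor.collapse_shape infers the collapsed type from the operand, so a group containing a dynamic dimension must collapse to a dynamic dimension, and the second form below is rejected by the verifier.

%ok = tensor.collapse_shape %t [[0, 1]] : tensor<2x?xf32> into tensor<?xf32>
// The following would fail verification:
// %ko = tensor.collapse_shape %t [[0, 1]] : tensor<2x?xf32> into tensor<10xf32>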
Thanks, actually the docs give an example like this (mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td, lines 154 to 156 at c33e50b):

// Dynamic + dynamic -> static
%0 = tensor.concat dim(1) %0, %1, %2 :
    (tensor<3x?xf32>, tensor<3x2xf32>, tensor<3x?xf32>) -> tensor<3x10xf32>

But setting aside whether this verifier check is reasonable, the way this PR unifies the computation using `OpFoldResult` should still be valuable?
I intentionally allowed dynamic + dynamic -> static, since situations can arise where forcing the result to be dynamic would require a `tensor.cast` to reintroduce the static info. Arguably collapse could do the same, though in the collapse case it's much less likely that someone is intentionally collapsing two truly dynamic dimensions into a static one, so being defensive was probably a win there.
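Concretely, reusing the types from the docs example above, a stricter verifier would force the static result to be recovered with a cast (a sketch, not code from this PR):

%0 = tensor.concat dim(1) %a, %b, %c
    : (tensor<3x?xf32>, tensor<3x2xf32>, tensor<3x?xf32>) -> tensor<3x?xf32>
%1 = tensor.cast %0 : tensor<3x?xf32> to tensor<3x10xf32>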
Is there a problem with an explicit `tensor.cast`? It would be nice to have a consistent op design across the tensor dialect. I believe one reason we chose input dynamicity == output dynamicity for collapse_shape/expand_shape is that it enables better error messages: when there is only one allowable output type, the verifier can print it in its diagnostics.
Sorry for the delayed response. We aren't currently using dynamic + dynamic -> static downstream, so (right now) we wouldn't notice if it were removed, but in general casts hinder optimization by cluttering use-def chains. As an example, imagine we added a `tensor.split` op as the inverse of concat: its folder would look for tensor.split(tensor.concat), but the cast would get in the way.
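Sketching that hazard with the hypothetical `tensor.split` (the op does not exist; generic syntax is used to make that explicit):

// A split(concat) folder could match this pair directly:
//   %p:3 = "tensor.split"(%0) {dim = 1 : i64} : (tensor<3x10xf32>) -> (...)
// But if a tensor.cast sits between the concat and the split, the folder
// sees split(cast(concat)) and must look through the cast or bail out.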
qedawkins left a comment:
Not to circumvent the discussion, but the change itself here LGTM.