Conversation

@CoTinker
Contributor

This PR fixes an issue in `ConcatOpInterface` where bufferizing `tensor.concat` fails when the concat dimension is dynamic on the inputs while the result type is static. The fix unifies the computation by using `OpFoldResult`, avoiding the need to handle dynamic and static dimension values separately. Fixes #162776.
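
For reference, the pattern that previously failed is the one from the linked issue, reproduced as the new test case below: `%f` and `%g` have a dynamic size along the concat dimension, while the result type pins it to a static 10. In the old code the static result shape selected the constant-offset path, so `getDimSize` on a dynamic operand dimension returned `ShapedType::kDynamic`, which was then accumulated into the running offset; this is presumably the integer overflow reported in #162776.

func.func @tensor.concat_mixed_dynamic_static(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>, %h: tensor<8x2xf32>) -> tensor<8x10xf32> {
  %0 = tensor.concat dim(1) %f, %g, %h : (tensor<8x?xf32>, tensor<8x?xf32>, tensor<8x2xf32>) -> tensor<8x10xf32>
  return %0 : tensor<8x10xf32>
}

With the fix, every operand's concat-dim size is obtained as an `OpFoldResult` via `memref::getMixedSize`, and the running offset is accumulated with `affine::makeComposedFoldedAffineApply`, so static and dynamic operands share one code path.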
@llvmbot
Member

llvmbot commented Nov 19, 2025

@llvm/pr-subscribers-mlir-tensor

Author: Longsheng Mou (CoTinker)

Changes

This PR fixes an issue in ConcatOpInterface where tensor.concat fails when the concat dimension is dynamic while the result type is static. The fix unifies the computation by using OpFoldResult, avoiding the need to separately handle dynamic and static dimension values. Fixes #162776.


Full diff: https://github.com/llvm/llvm-project/pull/168676.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+16-41)
  • (modified) mlir/test/Dialect/Tensor/bufferize.mlir (+33-7)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index c607ece418dff..5482cedae71d7 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1132,35 +1132,23 @@ struct ConcatOpInterface
 
     // Extract the dimension for the concat op
     uint64_t concatDim = concatOp.getDim();
-    bool dynamicConcatDim = false;
 
     SmallVector<OpFoldResult> offsets(tensorType.getRank(),
                                       rewriter.getIndexAttr(0));
     SmallVector<OpFoldResult> strides(tensorType.getRank(),
                                       rewriter.getIndexAttr(1));
-    SmallVector<OpFoldResult> sizes;
-
-    for (const auto &[dimIdx, dimSize] :
-         llvm::enumerate(tensorType.getShape())) {
-      if (dimSize == ShapedType::kDynamic) {
-        auto dimOp = memref::DimOp::create(rewriter, loc, dstBuffer, dimIdx);
-        sizes.push_back(dimOp.getResult());
-        if (dimIdx == concatDim)
-          dynamicConcatDim = true;
-      } else {
-        sizes.push_back(rewriter.getIndexAttr(dimSize));
-      }
-    }
-
-    int64_t concatDimOffset = 0;
-    std::optional<Value> dynamicOffset;
-    std::optional<Value> dynamicSize;
-    if (dynamicConcatDim) {
-      // One or more operands have dynamic size, so we must accumulate the
-      // offset with arith ops.
-      dynamicOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
-    }
+    SmallVector<OpFoldResult> sizes =
+        memref::getMixedSizes(rewriter, loc, dstBuffer);
+
+    AffineExpr d0, d1;
+    bindDims(rewriter.getContext(), d0, d1);
+    // Add two integers.
+    auto sum = [&](OpFoldResult v1, OpFoldResult v2) {
+      return affine::makeComposedFoldedAffineApply(rewriter, loc, d0 + d1,
+                                                   {v1, v2});
+    };
 
+    OpFoldResult concatDimOffset = rewriter.getIndexAttr(0);
     for (auto operand : concatOp.getInputs()) {
       // Get the buffer for the operand.
       FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options, state);
@@ -1171,18 +1159,10 @@ struct ConcatOpInterface
       // so the offset on that axis must accumulate through the loop, and the
       // size must change to the size of the current operand.
       auto operandTensorType = cast<RankedTensorType>(operand.getType());
-      int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim);
-
-      if (dynamicConcatDim) {
-        offsets[concatDim] = dynamicOffset.value();
-        dynamicSize =
-            memref::DimOp::create(rewriter, loc, *srcBuffer, concatDim)
-                .getResult();
-        sizes[concatDim] = dynamicSize.value();
-      } else {
-        sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
-        offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
-      }
+      offsets[concatDim] = concatDimOffset;
+      OpFoldResult concatDimSize =
+          memref::getMixedSize(rewriter, loc, *srcBuffer, concatDim);
+      sizes[concatDim] = concatDimSize;
 
       // Create a subview of the destination buffer.
       auto dstMemrefType = cast<MemRefType>(memrefType);
@@ -1197,12 +1177,7 @@ struct ConcatOpInterface
       if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
         return failure();
 
-      if (dynamicConcatDim) {
-        dynamicOffset = arith::AddIOp::create(
-            rewriter, loc, dynamicOffset.value(), dynamicSize.value());
-      } else {
-        concatDimOffset += operandConcatDimSize;
-      }
+      concatDimOffset = sum(concatDimOffset, concatDimSize);
     }
 
     replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 5eb2360a29b8f..be8ce20d8f154 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -678,11 +678,9 @@ func.func @tensor.concat_different_shapes(%f: tensor<8x4xf32>, %g: tensor<8x5xf3
 // CHECK-DAG:       %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
 // CHECK:           %[[ALLOC:.*]] = memref.alloc
 // CHECK-SAME:                                    memref<8x?xf32>
-// CHECK-DAG:       %[[OFFSET:.*]] = arith.constant 0 : index
-// CHECK:           %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, %[[F_DIM]]] [1, 1]
+// CHECK:           %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, %[[F_DIM]]] [1, 1]
 // CHECK:           memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
-// CHECK:           %[[OFFSET_2:.*]] = arith.addi %[[OFFSET]], %[[F_DIM]] : index
-// CHECK:           %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [8, %[[G_DIM]]] [1, 1]
+// CHECK:           %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [8, %[[G_DIM]]] [1, 1]
 // CHECK:           memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
 // CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
 // CHECK:           return %[[RET]]
@@ -706,10 +704,9 @@ func.func @tensor.concat_dynamic(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>) -> te
 // CHECK:           %[[ALLOC:.*]] = memref.alloc
 // CHECK-SAME:                                    memref<?x?xf32>
 // CHECK-DAG:       %[[NON_CONCAT_DIM:.*]] = memref.dim %[[ALLOC]], %[[c0]]
-// CHECK:           %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[c0]]] [%[[NON_CONCAT_DIM]], %[[F_DIM]]] [1, 1]
+// CHECK:           %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [%[[NON_CONCAT_DIM]], %[[F_DIM]]] [1, 1]
 // CHECK:           memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
-// CHECK:           %[[OFFSET_2:.*]] = arith.addi %[[c0]], %[[F_DIM]] : index
-// CHECK:           %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [%[[NON_CONCAT_DIM]], %[[G_DIM]]] [1, 1]
+// CHECK:           %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [%[[NON_CONCAT_DIM]], %[[G_DIM]]] [1, 1]
 // CHECK:           memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
 // CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
 // CHECK:           return %[[RET]]
@@ -721,6 +718,35 @@ func.func @tensor.concat_dynamic_nonconcat_dim(%f: tensor<?x?xf32>, %g: tensor<?
 
 // -----
 
+// CHECK:  #[[$sum_map:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+
+// CHECK-LABEL:   func @tensor.concat_mixed_dynamic_static(
+// CHECK-SAME:        %[[F:.*]]: tensor<8x?xf32>, %[[G:.*]]: tensor<8x?xf32>,
+// CHECK-SAME:        %[[H:.*]]: tensor<8x2xf32>)
+// CHECK-DAG:       %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
+// CHECK-DAG:       %[[G_MEMREF:.*]] = bufferization.to_buffer %[[G]]
+// CHECK-DAG:       %[[H_MEMREF:.*]] = bufferization.to_buffer %[[H]]
+// CHECK-DAG:       %[[ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x10xf32>
+// CHECK-DAG:       %[[c1:.*]] = arith.constant 1 : index
+// CHECK:           %[[F_DIM:.*]] = memref.dim %[[F_MEMREF]], %[[c1]]
+// CHECK:           %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, %[[F_DIM]]] [1, 1]
+// CHECK:           memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
+// CHECK:           %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
+// CHECK:           %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [8, %[[G_DIM]]] [1, 1]
+// CHECK:           memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
+// CHECK:           %[[OFFSET:.*]] = affine.apply #[[$sum_map]]()[%[[F_DIM]], %[[G_DIM]]]
+// CHECK:           %[[SUBVIEW3:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, 2] [1, 1]
+// CHECK:           memref.copy %[[H_MEMREF]], %[[SUBVIEW3]]
+// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK:           return %[[RET]]
+// CHECK:         }
+func.func @tensor.concat_mixed_dynamic_static(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>, %h: tensor<8x2xf32>) -> tensor<8x10xf32> {
+  %0 = tensor.concat dim(1) %f, %g, %h : (tensor<8x?xf32>, tensor<8x?xf32>, tensor<8x2xf32>) -> tensor<8x10xf32>
+  return %0 : tensor<8x10xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @tensor.splat_dynamic(
 // CHECK-SAME:  %[[F:[a-zA-Z0-9_]+]]: f32
 // CHECK-SAME:  %[[M:[a-zA-Z0-9_]+]]: index
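
Since both the running offset and the per-operand size are `OpFoldResult`s, `makeComposedFoldedAffineApply` folds the sum back into an attribute whenever both addends are constants, so the fully static case still lowers to constant subview offsets with no `affine.apply` (and, as before, no `arith.addi`). A rough sketch of that case; the buffer names and the elided memref/layout types are placeholders, not exact CHECK output:

// Input: all concat-dim sizes static.
%0 = tensor.concat dim(1) %a, %b : (tensor<8x4xf32>, tensor<8x5xf32>) -> tensor<8x9xf32>

// Expected lowering shape (sketch):
//   %alloc = memref.alloc() : memref<8x9xf32>
//   %sv0 = memref.subview %alloc[0, 0] [8, 4] [1, 1]   // offset attr 0
//   memref.copy %a_buf, %sv0
//   %sv1 = memref.subview %alloc[0, 4] [8, 5] [1, 1]   // 0 + 4 folded to 4
//   memref.copy %b_buf, %sv1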

@github-actions

github-actions bot commented Nov 19, 2025

🐧 Linux x64 Test Results

  • 7151 tests passed
  • 594 tests skipped

@CoTinker
Contributor Author

Friendly ping~

// CHECK: return %[[RET]]
// CHECK: }
func.func @tensor.concat_mixed_dynamic_static(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>, %h: tensor<8x2xf32>) -> tensor<8x10xf32> {
%0 = tensor.concat dim(1) %f, %g, %h : (tensor<8x?xf32>, tensor<8x?xf32>, tensor<8x2xf32>) -> tensor<8x10xf32>
Member

This looks odd to me: I expected that the result dimension must be dynamic if and only if one of the input dimensions is dynamic. We follow this design in other operations such as `tensor.collapse_shape`; e.g., see `CollapseShapeOp::inferCollapsedType` for details. Can the verifier be made more strict instead? @qedawkins

Contributor Author

@CoTinker commented Nov 24, 2025

Thanks. Actually, the docs give an example like this:

// Dynamic + dynamic -> static
%0 = tensor.concat dim(1) %0, %1, %2 :
(tensor<3x?xf32>, tensor<3x2xf32>, tensor<3x?xf32>) -> tensor<3x10xf32>

But setting aside whether a stricter verifier check is reasonable, the way this PR unifies the computation using `OpFoldResult` should still be valuable, shouldn't it?

Contributor

I intentionally allowed dynamic + dynamic -> static, since such situations can arise and forcing the result to be dynamic would require a `tensor.cast` to reintroduce the static info. Arguably collapse could do the same, though in the collapse case it's probably much less likely that someone is intentionally collapsing two truly dynamic dimensions into a static one, so being defensive was probably a win there.

Member

Is there a problem with an explicit `tensor.cast`?

It would be nice to have a consistent op design across the tensor dialect. I believe one reason we chose input dynamicity == output dynamicity for `collapse_shape`/`expand_shape` is that it allows better error messages: if there is only one allowable output type, the verifier can print it in its error message.
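
For concreteness, the alternative being asked about would look something like this (a sketch, reusing the shapes from the docs example quoted above):

// Result dynamicity matches input dynamicity...
%concat = tensor.concat dim(1) %0, %1, %2 : (tensor<3x?xf32>, tensor<3x2xf32>, tensor<3x?xf32>) -> tensor<3x?xf32>
// ...and the static result type is recovered with an explicit cast.
%static = tensor.cast %concat : tensor<3x?xf32> to tensor<3x10xf32>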

Contributor


Sorry for the delayed response. We aren't currently using the functionality for dynamic + dynamic -> static downstream so we wouldn't notice if it was removed (right now), but in general casts hinder optimization by cluttering use-def chains.

As an example, imagine if we added a tensor.split op as the inverse of concat. The folder for it would look for tensor.split(tensor.concat) but the cast would get in the way.
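
A sketch of how the cast gets in the way; `tensor.split` is hypothetical here, exactly as in the comment above:

// Without the cast, a split-of-concat folder sees the concat directly as the
// producer of the value being split and could forward the original operands:
//   %c = tensor.concat dim(1) %a, %b : (...) -> tensor<3x10xf32>
//   %p, %q = tensor.split %c ...              // hypothetical op
// With the stricter verifier, a tensor.cast sits in between, so the defining
// op of the split's operand is the cast, and the folder only fires if it is
// taught to look through casts:
//   %c = tensor.concat dim(1) %a, %b : (...) -> tensor<3x?xf32>
//   %s = tensor.cast %c : tensor<3x?xf32> to tensor<3x10xf32>
//   %p, %q = tensor.split %s ...              // hypothetical op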

Contributor

@qedawkins left a comment

Not to circumvent the discussion, but the change itself here LGTM.

@CoTinker merged commit 039f883 into llvm:main on Dec 2, 2025 (10 checks passed).
@CoTinker deleted the concat branch on December 2, 2025 at 03:49.

Closes: [mlir] integer overflow in BufferizableOpInterfaceImpl.cpp (#162776)