-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir] Fix lower_unpack when dynamic dimensions are involved
#68423
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
When lowering `tensor.unpack`, we need to use the sizes of the destination tensor in the final `tensor.extract_slice` operation. Prior to this patch, when the destination tensor had dynamic dimensions, we would compute them from the result of the `tensor.unpack` operation instead of its destination argument. This would produce invalid IR because the `tensor.dim` operations would need to appear before the `tensor.extract_slice` operation, but the input of the `tensor.dim` operations would consume the final result of the lowering of `tensor.unpack`, which happens after the `tensor.extract_slice` operation. In other words, the definition wouldn't dominate its uses. I.e., we were generating: ``` %dynDim = tensor.dim %defLater, ... <-- %defLater defined below %res = tensor.extract_slice ..., %dynDim, ... %defLater = linalg.copy (ins %res) ``` Note: I checked the implementation of `lower_pack` and the code is correct as far as I can tell.
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg ChangesWhen lowering This would produce invalid IR because the I.e., we were generating: Note: I checked the implementation of Full diff: https://github.com/llvm/llvm-project/pull/68423.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 8183b40ad7346f4..bca343cf8777149 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -467,7 +467,7 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
loc, destTensorType, reshapeOp->getResult(0),
SmallVector<OpFoldResult>(destRank, zero),
- tensor::getMixedSizes(rewriter, loc, unPackOp->getResult(0)),
+ tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
SmallVector<OpFoldResult>(destRank, one));
// 7. Inject a copy to preserve DPS.
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index c71feddcc1c8486..ad6c6a6f6199cc6 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -133,7 +133,7 @@ func.func @unpack(%arg0: tensor<17x2x16x16x32x8xf32>, %arg1: tensor<129x47x16x16
// CHECK-SAME: : tensor<17x8x2x32x16x16xf32> into tensor<136x64x16x16xf32>
// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0, 0, 0] [129, 47, 16, 16] [1, 1, 1, 1]
// CHECK-SAME: : tensor<136x64x16x16xf32> to tensor<129x47x16x16xf32>
- // CHECK: linalg.copy ins(%[[SLICE]] : tensor<129x47x16x16xf32>)
+ // CHECK: linalg.copy ins(%[[SLICE]] : tensor<129x47x16x16xf32>)
// CHECK-SAME: outs(%[[ARG1]] : tensor<129x47x16x16xf32>)
%pack = tensor.unpack %arg0 inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg1
: tensor<17x2x16x16x32x8xf32> -> tensor<129x47x16x16xf32>
@@ -397,3 +397,40 @@ transform.sequence failures(propagate) {
transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
-> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
}
+
+// -----
+
+// Check that we can lower unpack with dynamic dimensions in the destination.
+// CHECK-LABEL: func.func @unpack_with_dynamic_dest(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<32x2x49x16x16xf32>, %[[ARG1:.*]]: tensor<32x?x?xf32>)
+// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<32x2x16x49x16xf32>
+// CHECK: %[[TRAN:.*]] = linalg.transpose
+// CHECK-SAME: ins(%[[ARG0]] : tensor<32x2x49x16x16xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<32x2x16x49x16xf32>)
+// CHECK-SAME: permutation = [0, 1, 3, 2, 4]
+// CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0], [1, 2], [3, 4]]
+// CHECK-SAME: : tensor<32x2x16x49x16xf32> into tensor<32x32x784xf32>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM1:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<32x?x?xf32>
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[DIM2:.*]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<32x?x?xf32>
+// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0, 0] [32, %[[DIM1]], %[[DIM2]]] [1, 1, 1]
+// CHECK-SAME: : tensor<32x32x784xf32> to tensor<32x?x?xf32>
+// CHECK: linalg.copy ins(%[[SLICE]] : tensor<32x?x?xf32>)
+// CHECK-SAME: outs(%[[ARG1]] : tensor<32x?x?xf32>)
+func.func @unpack_with_dynamic_dest(%arg0: tensor<32x2x49x16x16xf32>, %arg1: tensor<32x?x?xf32>) -> tensor<32x?x?xf32> {
+ %pack = tensor.unpack %arg0 inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %arg1
+ : tensor<32x2x49x16x16xf32> -> tensor<32x?x?xf32>
+ return %pack : tensor<32x?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+ %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+ : (!transform.any_op) -> !transform.op<"tensor.unpack">
+ transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
+ -> (!transform.op<"tensor.empty">,
+ !transform.op<"linalg.transpose">,
+ !transform.op<"tensor.collapse_shape">,
+ !transform.op<"tensor.extract_slice">)
+}
|
When lowering
tensor.unpack, we need to use the sizes of the destination tensor in the finaltensor.extract_sliceoperation. Prior to this patch, when the destination tensor had dynamic dimensions, we would compute them from the result of thetensor.unpackoperation instead of its destination argument.This would produce invalid IR because the
tensor.dimoperations would need to appear before thetensor.extract_sliceoperation, but the input of thetensor.dimoperations would consume the final result of the lowering oftensor.unpack, which happens after thetensor.extract_sliceoperation. In other words, the definition wouldn't dominate its uses.I.e., we were generating:
Note: I checked the implementation of
lower_packand the code is correct as far as I can tell.