[mlir][linalg] Add pattern to bubble-up pack through expand shape op #93529

Merged

merged 12 commits into llvm:main on Jun 18, 2024

Conversation

adam-smnk
Contributor

Extends the bubble-up pack through reshape pattern to handle pack propagation through expand shape ops.

Extends bubble-up pack through reshape pattern to handle pack propagation
through expand shape ops.
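
For illustration, here is one of the test cases added by this patch as a before/after pair. Before the rewrite, the pack consumes the expanded tensor:

  func.func @bubble_up_pack_inner_expanded_through_expand(%arg0: tensor<32x64xf32>) -> tensor<32x4x4x4xf32> {
    %empty = tensor.empty() : tensor<32x4x4x4xf32>
    %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
    %pack = tensor.pack %expanded inner_dims_pos = [2] inner_tiles = [4] into %empty : tensor<32x4x16xf32> -> tensor<32x4x4x4xf32>
    return %pack : tensor<32x4x4x4xf32>
  }

After the rewrite, the pack is applied to the original source with its inner_dims_pos projected onto the pre-expansion dimensions, and the expand_shape is re-applied on top of the packed result:

  func.func @bubble_up_pack_inner_expanded_through_expand(%arg0: tensor<32x64xf32>) -> tensor<32x4x4x4xf32> {
    %empty = tensor.empty() : tensor<32x16x4xf32>
    %pack = tensor.pack %arg0 inner_dims_pos = [1] inner_tiles = [4] into %empty : tensor<32x64xf32> -> tensor<32x16x4xf32>
    %expanded = tensor.expand_shape %pack [[0], [1, 2], [3]] output_shape [32, 4, 4, 4] : tensor<32x16x4xf32> into tensor<32x4x4x4xf32>
    return %expanded : tensor<32x4x4x4xf32>
  }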
@llvmbot
Member

llvmbot commented May 28, 2024

@llvm/pr-subscribers-mlir

Author: Adam Siemieniuk (adam-smnk)

Changes

Extends the bubble-up pack through reshape pattern to handle pack propagation through expand shape ops.


Patch is 20.87 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/93529.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (+104)
  • (modified) mlir/test/Dialect/Linalg/data-layout-propagation.mlir (+204)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 2bea083ac2d78..73a86caa2fbcb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -17,6 +17,8 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SetOperations.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include <optional>
@@ -694,6 +696,105 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
   return success();
 }
 
+/// Project dimsPos to their collapsed positions in the reassocIndices.
+///
+/// For example, given dimsPos [0, 1, 2, 4] and matching reassocIndices
+/// [[0], [1, 2], [3], [4]], it returns [0, 1, 1, 3]: pos 0 falls in the
+/// reassociation group [0] at index 0, pos 1 and 2 fall in the group [1, 2]
+/// at index 1, and pos 4 falls in the group [4] at index 3.
+static SmallVector<int64_t>
+projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos,
+                             ArrayRef<ReassociationIndices> reassocIndices) {
+  SmallVector<int64_t> projectedPos;
+
+  // Map each dimension to the position of its corresponding reassociation
+  // group.
+  for (auto pos : dimsPos) {
+    for (auto [idx, indices] : llvm::enumerate(reassocIndices)) {
+      // If the dimension is present in the current indices group, the group
+      // position within the reassociation map is the desired projected
+      // dimension position.
+      if (llvm::any_of(indices,
+                       [&](int64_t expandDim) { return expandDim == pos; })) {
+        projectedPos.push_back(idx);
+        break;
+      }
+    }
+  }
+  assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection");
+
+  return projectedPos;
+}
+
+/// Bubble up pack op through expand shape op.
+static LogicalResult
+bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
+                                 tensor::PackOp packOp,
+                                 PatternRewriter &rewriter) {
+  // Cannot propagate shape expansion if there is a non-identity outer
+  // dimensions permutation.
+  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
+  if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
+    return rewriter.notifyMatchFailure(
+        packOp, "expects outer_dims_perm is empty or an identity permutation");
+  }
+
+  // Validate dimensions' relations between shape expansion and packing.
+  SmallVector<ReassociationIndices, 4> reassoc =
+      expandOp.getReassociationIndices();
+  ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos();
+  llvm::SetVector<int64_t> packDimsPos(packInnerDims.begin(),
+                                       packInnerDims.end());
+
+  for (auto [idx, indices] : llvm::enumerate(reassoc)) {
+    llvm::SetVector<int64_t> expandDimPos(indices.begin(), indices.end());
+    llvm::SetVector<int64_t> packedDims =
+        llvm::set_intersection(packDimsPos, expandDimPos);
+
+    // The expanded dimension is not packed - simply continue.
+    if (packedDims.empty())
+      continue;
+    // Shape expansion cannot be propagated when multiple expanded dimensions
+    // are packed.
+    if (packedDims.size() > 1)
+      return rewriter.notifyMatchFailure(
+          packOp, "only one of the expanded dimensions can be packed");
+    // Only the inner-most dim should be packed. Otherwise, the element order
+    // would be affected by the operation reordering.
+    if (packedDims[0] != indices.back())
+      return rewriter.notifyMatchFailure(
+          packOp, "can only pack the inner-most expanded dimension");
+  }
+
+  // Project pack.inner_dims_pos to positions before shape expansion.
+  SmallVector<int64_t> projectedInnerDimsPos =
+      projectDimsPosIntoReassocPos(packInnerDims, reassoc);
+
+  // Project the shape expansion onto the new packed shape.
+  // The pack.outer_dims_perm is restricted to identity, so the permutation
+  // can be omitted for simplicity.
+  RankedTensorType newPackType = tensor::PackOp::inferPackedType(
+      expandOp.getSrcType(), packOp.getStaticInnerTiles(),
+      projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
+  auto reassocExpand =
+      getReassociationIndicesForReshape(newPackType, packOp.getDestType());
+  if (!reassocExpand)
+    return rewriter.notifyMatchFailure(
+        packOp, "could not reassociate dims after bubbling up");
+
+  Value destTensor = tensor::PackOp::createDestinationTensor(
+      rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(),
+      projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
+  Value packedVal = rewriter.create<tensor::PackOp>(
+      packOp.getLoc(), expandOp.getSrc(), destTensor, projectedInnerDimsPos,
+      packOp.getMixedTiles(), packOp.getPaddingValue(),
+      /*outerDimsPerm=*/SmallVector<int64_t>{});
+
+  Value newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
+      packOp.getLoc(), packOp.getDestType(), packedVal, *reassocExpand);
+  rewriter.replaceOp(packOp, newExpandOp);
+
+  return success();
+}
+
 class BubbleUpPackOpThroughReshapeOp final
     : public OpRewritePattern<tensor::PackOp> {
 public:
@@ -723,6 +824,9 @@ class BubbleUpPackOpThroughReshapeOp final
         .Case([&](tensor::CollapseShapeOp op) {
           return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
         })
+        .Case([&](tensor::ExpandShapeOp op) {
+          return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter);
+        })
         .Default([](Operation *) { return failure(); });
   }
 
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 9140904620acd..43f9799357df5 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -988,6 +988,210 @@ func.func @no_bubble_up_pack_through_non_divisible_collapse(%1: tensor<3072x64x4
 
 // -----
 
+func.func @bubble_up_pack_outer_expanded_through_expand(%arg0: tensor<32x64xf32>) -> tensor<4x2x64x4xf32> {
+  %empty = tensor.empty() : tensor<4x2x64x4xf32>
+  %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
+  %pack = tensor.pack %expanded inner_dims_pos = [1] inner_tiles = [4] into %empty : tensor<4x8x64xf32> -> tensor<4x2x64x4xf32>
+  return %pack : tensor<4x2x64x4xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_outer_expanded_through_expand(
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x64x4xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]] : tensor<32x64xf32> -> tensor<8x64x4xf32>
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3]] output_shape [4, 2, 64, 4] : tensor<8x64x4xf32> into tensor<4x2x64x4xf32>
+// CHECK:         return %[[EXPANDED]] : tensor<4x2x64x4xf32>
+
+// -----
+
+func.func @bubble_up_pack_inner_expanded_through_expand(%arg0: tensor<32x64xf32>) -> tensor<32x4x4x4xf32> {
+  %empty = tensor.empty() : tensor<32x4x4x4xf32>
+  %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
+  %pack = tensor.pack %expanded inner_dims_pos = [2] inner_tiles = [4] into %empty : tensor<32x4x16xf32> -> tensor<32x4x4x4xf32>
+  return %pack : tensor<32x4x4x4xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_inner_expanded_through_expand(
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<32x16x4xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1] inner_tiles = [4] into %[[EMPTY]] : tensor<32x64xf32> -> tensor<32x16x4xf32>
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3]] output_shape [32, 4, 4, 4] : tensor<32x16x4xf32> into tensor<32x4x4x4xf32>
+// CHECK:         return %[[EXPANDED]] : tensor<32x4x4x4xf32>
+
+// -----
+
+func.func @bubble_up_pack_non_expanded_dims_through_expand(%arg0: tensor<32x64x16xf32>) -> tensor<8x2x32x16x4xf32> {
+  %empty = tensor.empty() : tensor<8x2x32x16x4xf32>
+  %expanded = tensor.expand_shape %arg0 [[0], [1, 2], [3]] output_shape [32, 2, 32, 16] : tensor<32x64x16xf32> into tensor<32x2x32x16xf32>
+  %pack = tensor.pack %expanded inner_dims_pos = [0] inner_tiles = [4] into %empty : tensor<32x2x32x16xf32> -> tensor<8x2x32x16x4xf32>
+  return %pack : tensor<8x2x32x16x4xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_non_expanded_dims_through_expand(
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x64x16x4xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]] : tensor<32x64x16xf32> -> tensor<8x64x16x4xf32>
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3], [4]] output_shape [8, 2, 32, 16, 4] : tensor<8x64x16x4xf32> into tensor<8x2x32x16x4xf32>
+// CHECK:         return %[[EXPANDED]] : tensor<8x2x32x16x4xf32>
+
+// -----
+
+func.func @bubble_up_pack_through_expand_dynamic(%arg0: tensor<?x64xf32>) -> tensor<?x4x2x8xf32> {
+  %c0 = arith.constant 0 : index
+  %dim = tensor.dim %arg0, %c0 : tensor<?x64xf32>
+  %empty = tensor.empty(%dim) : tensor<?x4x2x8xf32>
+  %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [%dim, 4, 16] : tensor<?x64xf32> into tensor<?x4x16xf32>
+  %pack = tensor.pack %expanded inner_dims_pos = [2] inner_tiles = [8] into %empty : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32>
+  return %pack : tensor<?x4x2x8xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_through_expand_dynamic(
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
+// CHECK:         %[[DIM_INPUT:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x64xf32>
+// CHECK:         %[[EMPTY:.+]] = tensor.empty(%[[DIM_INPUT]]) : tensor<?x8x8xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1] inner_tiles = [8] into %[[EMPTY]] : tensor<?x64xf32> -> tensor<?x8x8xf32>
+// CHECK:         %[[DIM_PACK:.+]] = tensor.dim %[[PACK]], %[[C0]] : tensor<?x8x8xf32>
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3]] output_shape [%[[DIM_PACK]], 4, 2, 8] : tensor<?x8x8xf32> into tensor<?x4x2x8xf32>
+// CHECK:         return %[[EXPANDED]] : tensor<?x4x2x8xf32>
+
+// -----
+
+func.func @bubble_up_pack_non_expanded_padding_through_expand(%arg0: tensor<32x60xf32>) -> tensor<4x2x8x4x8xf32> {
+  %cst = arith.constant 3.000000e+00 : f32
+  %empty = tensor.empty() : tensor<4x2x8x4x8xf32>
+  %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x60xf32> into tensor<4x8x60xf32>
+  %pack = tensor.pack %expanded padding_value(%cst : f32) inner_dims_pos = [1, 2] inner_tiles = [4, 8] into %empty : tensor<4x8x60xf32> -> tensor<4x2x8x4x8xf32>
+  return %pack : tensor<4x2x8x4x8xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_non_expanded_padding_through_expand(
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-DAG:     %[[CST:.+]] = arith.constant 3.000000e+00 : f32
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x8x4x8xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] padding_value(%[[CST]] : f32) inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %[[EMPTY]] : tensor<32x60xf32> -> tensor<8x8x4x8xf32>
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]] output_shape [4, 2, 8, 4, 8] : tensor<8x8x4x8xf32> into tensor<4x2x8x4x8xf32>
+// CHECK:         return %[[EXPANDED]] : tensor<4x2x8x4x8xf32>
+
+// -----
+
+func.func @bubble_up_pack_outer_dims_perm_identity_through_expand(%arg0: tensor<32x64xf32>) -> tensor<4x2x32x4x2xf32> {
+  %empty = tensor.empty() : tensor<4x2x32x4x2xf32>
+  %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
+  %pack = tensor.pack %expanded outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [4, 2] into %empty : tensor<4x8x64xf32> -> tensor<4x2x32x4x2xf32>
+  return %pack : tensor<4x2x32x4x2xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_outer_dims_perm_identity_through_expand(
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x32x4x2xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [4, 2] into %[[EMPTY]] : tensor<32x64xf32> -> tensor<8x32x4x2xf32>
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]] output_shape [4, 2, 32, 4, 2] : tensor<8x32x4x2xf32> into tensor<4x2x32x4x2xf32>
+// CHECK:         return %[[EXPANDED]] : tensor<4x2x32x4x2xf32>
+
+// -----
+
+func.func @bubble_up_pack_multiple_dims_through_expand(%arg0: tensor<32x64x16xf32>) -> tensor<8x2x4x8x4x8x2xf32> {
+  %empty = tensor.empty() : tensor<8x2x4x8x4x8x2xf32>
+  %expanded = tensor.expand_shape %arg0 [[0], [1, 2], [3]] output_shape [32, 2, 32, 16] : tensor<32x64x16xf32> into tensor<32x2x32x16xf32>
+  %pack = tensor.pack %expanded inner_dims_pos = [0, 2, 3] inner_tiles = [4, 8, 2] into %empty : tensor<32x2x32x16xf32> -> tensor<8x2x4x8x4x8x2xf32>
+  return %pack : tensor<8x2x4x8x4x8x2xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_multiple_dims_through_expand(
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x8x8x4x8x2xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [0, 1, 2] inner_tiles = [4, 8, 2] into %[[EMPTY]] : tensor<32x64x16xf32> -> tensor<8x8x8x4x8x2xf32>
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3], [4], [5], [6]] output_shape [8, 2, 4, 8, 4, 8, 2] : tensor<8x8x8x4x8x2xf32> into tensor<8x2x4x8x4x8x2xf32>
+// CHECK:         return %[[EXPANDED]] : tensor<8x2x4x8x4x8x2xf32>
+
+// -----
+
+func.func @bubble_up_pack_inner_dims_reorder_through_expand(%arg0: tensor<32x64xf32>) -> tensor<4x2x4x16x4xf32> {
+  %empty = tensor.empty() : tensor<4x2x4x16x4xf32>
+  %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
+  %pack = tensor.pack %expanded inner_dims_pos = [2, 1] inner_tiles = [16, 4] into %empty : tensor<4x8x64xf32> -> tensor<4x2x4x16x4xf32>
+  return %pack : tensor<4x2x4x16x4xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_inner_dims_reorder_through_expand(
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x4x16x4xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1, 0] inner_tiles = [16, 4] into %[[EMPTY]] : tensor<32x64xf32> -> tensor<8x4x16x4xf32>
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]] output_shape [4, 2, 4, 16, 4] : tensor<8x4x16x4xf32> into tensor<4x2x4x16x4xf32>
+// CHECK:         return %[[EXPANDED]] : tensor<4x2x4x16x4xf32>
+
+// -----
+
+func.func @bubble_up_pack_multiple_different_expanded_dims_through_expand(%arg0: tensor<32x64x16xf32>) -> tensor<4x2x2x8x16x4x4xf32> {
+  %empty = tensor.empty() : tensor<4x2x2x8x16x4x4xf32>
+  %expanded = tensor.expand_shape %arg0 [[0, 1], [2, 3], [4]] output_shape [4, 8, 2, 32, 16] : tensor<32x64x16xf32> into tensor<4x8x2x32x16xf32>
+  %pack = tensor.pack %expanded inner_dims_pos = [1, 3] inner_tiles = [4, 4] into %empty : tensor<4x8x2x32x16xf32> -> tensor<4x2x2x8x16x4x4xf32>
+  return %pack : tensor<4x2x2x8x16x4x4xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_multiple_different_expanded_dims_through_expand(
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x16x16x4x4xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %[[EMPTY]] : tensor<32x64x16xf32> -> tensor<8x16x16x4x4xf32>
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2, 3], [4], [5], [6]] output_shape [4, 2, 2, 8, 16, 4, 4] : tensor<8x16x16x4x4xf32> into tensor<4x2x2x8x16x4x4xf32>
+// CHECK:         return %[[EXPANDED]] : tensor<4x2x2x8x16x4x4xf32>
+
+// -----
+
+func.func @no_bubble_up_pack_outer_dims_permutation_through_expand(%arg0: tensor<32x64xf32>) -> tensor<32x4x2x4x2xf32> {
+  %empty = tensor.empty() : tensor<32x4x2x4x2xf32>
+  %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
+  %pack = tensor.pack %expanded outer_dims_perm = [2, 0, 1] inner_dims_pos = [1, 2] inner_tiles = [4, 2] into %empty : tensor<4x8x64xf32> -> tensor<32x4x2x4x2xf32>
+  return %pack : tensor<32x4x2x4x2xf32>
+}
+// CHECK-LABEL: func.func @no_bubble_up_pack_outer_dims_permutation_through_expand(
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<32x4x2x4x2xf32>
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[EXPANDED]] outer_dims_perm = [2, 0, 1] inner_dims_pos = [1, 2] inner_tiles = [4, 2] into %[[EMPTY]] : tensor<4x8x64xf32> -> tensor<32x4x2x4x2xf32>
+// CHECK:         return %[[PACK]] : tensor<32x4x2x4x2xf32>
+
+// -----
+
+func.func @no_bubble_up_pack_multiple_same_expanded_dim_through_expand(%arg0: tensor<32x64xf32>) -> tensor<2x2x64x2x4xf32> {
+  %empty = tensor.empty() : tensor<2x2x64x2x4xf32>
+  %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
+  %pack = tensor.pack %expanded inner_dims_pos = [0, 1] inner_tiles = [2, 4] into %empty : tensor<4x8x64xf32> -> tensor<2x2x64x2x4xf32>
+  return %pack : tensor<2x2x64x2x4xf32>
+}
+// CHECK-LABEL: func.func @no_bubble_up_pack_multiple_same_expanded_dim_through_expand(
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<2x2x64x2x4xf32>
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[EXPANDED]] inner_dims_pos = [0, 1] inner_tiles = [2, 4] into %[[EMPTY]] : tensor<4x8x64xf32> -> tensor<2x2x64x2x4xf32>
+// CHECK:         return %[[PACK]] : tensor<2x2x64x2x4xf32>
+
+// -----
+
+func.func @no_bubble_up_pack_non_innermost_expanded_dim_through_expand(%arg0: tensor<32x64xf32>) -> tensor<2x8x64x2xf32> {
+  %empty = tensor.empty() : tensor<2x8x64x2xf32>
+  %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
+  %pack = tensor.pack %expanded inner_dims_pos = [0] inner_tiles = [2] into %empty : tensor<4x8x64xf32> -> tensor<2x8x64x2xf32>
+  return %pack : tensor<2x8x64x2xf32>
+}
+// CHECK-LABEL: func.func @no_bubble_up_pack_non_innermost_expanded_dim_through_expand(
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<2x8x64x2xf32>
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[EXPANDED]] inner_dims_pos = [0] inner_tiles = [2] into %[[EMPTY]] : tensor<4x8x64xf32> -> tensor<2x8x64x2xf32>
+// CHECK:         return %[[PACK]] : tensor<2x8x64x2xf32>
+
+// -----
+
+func.func @no_bubble_up_pack_expanded_padding_through_expand_cannot_reassociate(%arg0: tensor<30x60xf32>) -> tensor<3x2x60x8xf32> {
+  %cst = arith.constant 3.000000e+00 : f32
+  %empty = tensor.empty() : tensor<3x2x60x8xf32>
+  %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [3, 10, 60] : tensor<30x60xf32> into tensor<3x10x60xf32>
+  %pack = tensor.pack %expanded padding_value(%cst : f32) inner_dims_pos = [1] inner_tiles = [8] into %empty : tensor<3x10x60xf32> -> tensor<3x2x60x8xf32>
+  return %pack : tensor<3x2x60x8xf32>
+}
+// CHECK-LABEL: func.func @no_bubble_up_pack_expanded_padding_through_expand_cannot_reassociate(
+// CHECK-SAME:      %[[ARG0:[a-...
[truncated]

@llvmbot
Member

llvmbot commented May 28, 2024

@llvm/pr-subscribers-mlir-linalg

Author: Adam Siemieniuk (adam-smnk)

[Same patch summary and diff as in the comment above.]

@hanhanW hanhanW requested a review from pashu123 May 28, 2024 18:26
Member

@pashu123 pashu123 left a comment

Initial comment.

@MaheshRavishankar MaheshRavishankar requested a review from Max191 May 30, 2024 15:28
@adam-smnk
Contributor Author

@pashu123 @Max191 ping

Contributor

@hanhanW hanhanW left a comment

LGTM, thanks for pushing on this

@adam-smnk adam-smnk merged commit a945f55 into llvm:main Jun 18, 2024
7 checks passed
Member

Something in this new code is triggering a crash (compiler assert) in the downstream IREE project: iree-org/iree#17734. If I revert this PR locally, the crash goes away.

I don't have a reduced test case yet and the input program is large (12MB) + specific to our downstream project.

  • Assert + stack trace:

    Assertion failed: input.size() == permutation.size() && "expected input rank to equal permutation rank", file D:\dev\projects\iree\third_party\llvm-project\mlir\include\mlir/Dialect/Utils/IndexingUtils.h, line 204
    Please report issues to https://github.com/iree-org/iree/issues and include the crash backtrace.
    Stack dump:
    0.	Program arguments: D:\\dev\\projects\\iree-build\\tools\\iree-compile.exe D:/tmp/open_llama_3b_v2/open-llama-3b-v2-f16.mlir --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu-features=host -o D:/tmp/open_llama_3b_v2/open-llama-3b-v2-f16_cpu.vmfb --mlir-print-ir-before-all --mlir-elide-elementsattrs-if-larger=8 --mlir-elide-resource-strings-if-larger=8 --mlir-disable-threading
    Exception Code: 0x80000003
     #0 0x00007ff64dcf8e95 HandleAbort D:\dev\projects\iree\third_party\llvm-project\llvm\lib\Support\Windows\Signals.inc:425:0
     #1 0x00007ffe7d561881 (C:\WINDOWS\System32\ucrtbase.dll+0x71881)
     #2 0x00007ffe7d562851 (C:\WINDOWS\System32\ucrtbase.dll+0x72851)
     #3 0x00007ffe7d5641b5 (C:\WINDOWS\System32\ucrtbase.dll+0x741b5)
     #4 0x00007ffe7d5644f1 (C:\WINDOWS\System32\ucrtbase.dll+0x744f1)
     #5 0x00007ff651c4f70f mlir::applyPermutation<class llvm::SmallVector<__int64, 2>>(class llvm::ArrayRef<class llvm::SmallVector<__int64, 2>>, class llvm::ArrayRef<__int64>) D:\dev\projects\iree\third_party\llvm-project\mlir\include\mlir\Dialect\Utils\IndexingUtils.h:205:0
     #6 0x00007ff651c49c11 mlir::applyPermutation<class llvm::SmallVector<__int64, 2>>(class llvm::SmallVectorImpl<class llvm::SmallVector<__int64, 2>> const &, class llvm::ArrayRef<__int64>) D:\dev\projects\iree\third_party\llvm-project\mlir\include\mlir\Dialect\Utils\IndexingUtils.h:214:0
     #7 0x00007ff651c41c2b mlir::applyPermutationToVector<class llvm::SmallVector<__int64, 2>, 1>(class llvm::SmallVector<class llvm::SmallVector<__int64, 2>, 1> &, class llvm::ArrayRef<__int64>) D:\dev\projects\iree\third_party\llvm-project\mlir\include\mlir\Dialect\Utils\IndexingUtils.h:225:0
     #8 0x00007ff6531d1eab `anonymous namespace'::applyPermutationAndReindexReassoc D:\dev\projects\iree\third_party\llvm-project\mlir\lib\Dialect\Linalg\Transforms\DataLayoutPropagation.cpp:610:0
     #9 0x00007ff6531d268d `anonymous namespace'::bubbleUpPackOpThroughCollapseShape D:\dev\projects\iree\third_party\llvm-project\mlir\lib\Dialect\Linalg\Transforms\DataLayoutPropagation.cpp:687:0
    #10 0x00007ff6531d3cd6 `anonymous namespace'::BubbleUpPackOpThroughReshapeOp::matchAndRewrite D:\dev\projects\iree\third_party\llvm-project\mlir\lib\Dialect\Linalg\Transforms\DataLayoutPropagation.cpp:849:0
    #11 0x00007ff650b2bbe4 mlir::detail::OpOrInterfaceRewritePatternBase<class mlir::tensor::PackOp>::matchAndRewrite(class mlir::Operation *, class mlir::PatternRewriter &) const D:\dev\projects\iree\third_party\llvm-project\mlir\include\mlir\IR\PatternMatch.h:332:0
    #12 0x00007ff65209e8eb <lambda_033eed04a8a10a7b33015298d48d216a>::operator() D:\dev\projects\iree\third_party\llvm-project\mlir\lib\Rewrite\PatternApplicator.cpp:212:0
    #13 0x00007ff65209c275 mlir::PatternApplicator::matchAndRewrite(class mlir::Operation *, class mlir::PatternRewriter &, class llvm::function_ref<(class mlir::Pattern const &)>, class llvm::function_ref<(class mlir::Pattern const &)>, class llvm::function_ref<(class mlir::Pattern const &)>) D:\dev\projects\iree\third_party\llvm-project\mlir\lib\Rewrite\PatternApplicator.cpp:233:0
    #14 0x00007ff650f1f91e `anonymous namespace'::GreedyPatternRewriteDriver::processWorklist D:\dev\projects\iree\third_party\llvm-project\mlir\lib\Transforms\Utils\GreedyPatternRewriteDriver.cpp:617:0
    #15 0x00007ff650f220e2 llvm::function_ref<void __cdecl(void)>::callback_fn<<lambda_56efa1fe2231a48e07ce9bd5369059af> > D:\dev\projects\iree\third_party\llvm-project\llvm\include\llvm\ADT\STLFunctionalExtras.h:45:0
    #16 0x00007ff650f214ae `anonymous namespace'::RegionPatternRewriteDriver::simplify D:\dev\projects\iree\third_party\llvm-project\mlir\lib\Transforms\Utils\GreedyPatternRewriteDriver.cpp:872:0
    #17 0x00007ff650f1d38e mlir::applyPatternsAndFoldGreedily(class mlir::Region &, class mlir::FrozenRewritePatternSet const &, class mlir::GreedyRewriteConfig, bool *) D:\dev\projects\iree\third_party\llvm-project\mlir\lib\Transforms\Utils\GreedyPatternRewriteDriver.cpp:920:0
    #18 0x00007ff651c78d1d mlir::iree_compiler::GlobalOptimization::`anonymous namespace'::DataLayoutPropagationPass::runOnOperation D:\dev\projects\iree\compiler\src\iree\compiler\GlobalOptimization\DataLayoutPropagation.cpp:31:0
    #19 0x00007ff64e0cead0 llvm::function_ref<void __cdecl(void)>::callback_fn<<lambda_e8f8990a45bf3495636c03506b9db479> > D:\dev\projects\iree\third_party\llvm-project\llvm\include\llvm\ADT\STLFunctionalExtras.h:45:0
    #20 0x00007ff64e0c8637 mlir::detail::OpToOpPassAdaptor::run(class mlir::Pass *, class mlir::Operation *, class mlir::AnalysisManager, bool, unsigned int) D:\dev\projects\iree\third_party\llvm-project\mlir\lib\Pass\Pass.cpp:533:0
    #21 0x00007ff64e0c883d mlir::detail::OpToOpPassAdaptor::runPipeline(class mlir::OpPassManager &, class mlir::Operation *, class mlir::AnalysisManager, bool, unsigned int, class mlir::PassInstrumentor *, struct mlir::PassInstrumentation::PipelineParentInfo const *) D:\dev\projects\iree\third_party\llvm-project\mlir\lib\Pass\Pass.cpp:593:0
    #22 0x00007ff64e0c77bb mlir::detail::OpToOpPassAdaptor::runOnOperationImpl(bool) D:\dev\projects\iree\third_party\llvm-project\mlir\lib\Pass\Pass.cpp:734:0
    #23 0x00007ff64e0ceb23 llvm::function_ref<void __cdecl(void)>::callback_fn<<lambda_e8f8990a45bf3495636c03506b9db479> > D:\dev\projects\iree\third_party\llvm-project\llvm\include\llvm\ADT\STLFunctionalExtras.h:45:0
    #24 0x00007ff64e0c8637 mlir::detail::OpToOpPassAdaptor::run(class mlir::Pass *, class mlir::Operation *, class mlir::AnalysisManager, bool, unsigned int) D:\dev\projects\iree\third_party\llvm-project\mlir\lib\Pass\Pass.cpp:533:0
    #25 0x00007ff64e0c883d mlir::detail::OpToOpPassAdaptor::runPipeline(class mlir::OpPassManager &, class mlir::Operation *, class mlir::AnalysisManager, bool, unsigned int, class mlir::PassInstrumentor *, struct mlir::PassInstrumentation::PipelineParentInfo const *) D:\dev\projects\iree\third_party\llvm-project\mlir\lib\Pass\Pass.cpp:593:0
    #26 0x00007ff64e0c6d7b mlir::PassManager::runPasses(class mlir::Operation *, class mlir::AnalysisManager) D:\dev\projects\iree\third_party\llvm-project\mlir\lib\Pass\Pass.cpp:904:0
    #27 0x00007ff64e0c6b3e mlir::PassManager::run(class mlir::Operation *) D:\dev\projects\iree\third_party\llvm-project\mlir\lib\Pass\Pass.cpp:883:0
    #28 0x00007ff64dca71c7 mlir::iree_compiler::embed::`anonymous namespace'::Invocation::runPipeline D:\dev\projects\iree\compiler\src\iree\compiler\API\Internal\CompilerDriver.cpp:995:0
    #29 0x00007ff64dc657ac <lambda_139d4d9eb9ed714e768e1c22e93f7b10>::operator() D:\dev\projects\iree\compiler\src\iree\compiler\Tools\iree_compile_lib.cc:254:0
    #30 0x00007ff64dc5ba18 mlir::iree_compiler::runIreecMain(int, char **) D:\dev\projects\iree\compiler\src\iree\compiler\Tools\iree_compile_lib.cc:355:0
    #31 0x00007ff658023d34 __scrt_common_main_seh d:\a01\_work\43\s\src\vctools\crt\vcstartup\src\startup\exe_common.inl:288:0
    #32 0x00007ffe7e987344 (C:\WINDOWS\System32\KERNEL32.DLL+0x17344)
    #33 0x00007ffe7f9bcc91 (C:\WINDOWS\SYSTEM32\ntdll.dll+0x4cc91)
    
  • Here's a bit of printf debugging (the crashing collapse/pack pair is restated as a standalone sketch after this list):

    // 100s of these, which are fine
    // Calling bubbleUpPackOpThroughCollapseShape with tensor::CollapseShapeOp:
    %collapsed_116 = tensor.collapse_shape %114 [[0, 1], [2], [3]] : tensor<4x32x100x?xf16> into tensor<128x100x?xf16>
    // ...and tensor::PackOp:
    %pack_119 = tensor.pack %collapsed_116 padding_value(%cst_27 : f16) outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 1] into %117 : tensor<128x100x?xf16> -> tensor<128x?x100x8x1xf16>
    
    // crash right after this
    // Calling bubbleUpPackOpThroughCollapseShape with tensor::CollapseShapeOp:
    %collapsed_2681 = tensor.collapse_shape %expanded_2680 [[0], [1, 2], [3]] : tensor<4x32x1x100xf16> into tensor<4x32x100xf16>
    // ...and tensor::PackOp:
    %pack_2682 = tensor.pack %collapsed_2681 inner_dims_pos = [0, 2] inner_tiles = [1, 1] into %2209 : tensor<4x32x100xf16> -> tensor<4x32x100x1x1xf16>
  • Here's the IR before we call into this code and crash (13000 lines, can try reducing): https://gist.github.com/ScottTodd/d5f9721307e78cada067a81e60a471c0
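
For reference, the two ops from the printf debugging above as a minimal standalone sketch (a hypothetical reconstruction: the function wrapper and %dest are invented here; %dest stands in for %2209, whose producer is not shown):

  func.func @hypothetical_repro(%src: tensor<4x32x1x100xf16>) -> tensor<4x32x100x1x1xf16> {
    // Hypothetical destination; the producer of the original %2209 is unknown.
    %dest = tensor.empty() : tensor<4x32x100x1x1xf16>
    %collapsed = tensor.collapse_shape %src [[0], [1, 2], [3]] : tensor<4x32x1x100xf16> into tensor<4x32x100xf16>
    %pack = tensor.pack %collapsed inner_dims_pos = [0, 2] inner_tiles = [1, 1] into %dest : tensor<4x32x100xf16> -> tensor<4x32x100x1x1xf16>
    return %pack : tensor<4x32x100x1x1xf16>
  }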

Contributor Author

If I revert this PR locally, the crash goes away.

Interesting; looking at the stack dump, it's calling bubbleUpPackOpThroughCollapseShape, which is unrelated and should be untouched by this PR.

Contributor

Interesting; looking at the stack dump, it's calling bubbleUpPackOpThroughCollapseShape, which is unrelated and should be untouched by this PR.

There could be a few issues. The propagation through the expand_shape op changes the graph and triggers the failure. The issue could be in either the expand_shape patterns or the collapse_shape patterns.

Contributor Author

My first guess is that this PR bubbles pack through expand_shape ops that previously acted as a barrier, which now allows more bubbling to occur.
So it either bubbled a pack through some expand that it shouldn't have, or exposed an edge case in the collapse_shape part of the code.
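
As a purely hypothetical sketch of that guess (shapes invented here, not taken from the failing model), a pack that previously stopped at an expand_shape can now bubble past it and land directly on the result of a collapse_shape:

  func.func @hypothetical(%src: tensor<4x8x64xf32>) -> tensor<32x4x4x4xf32> {
    // Hypothetical IR, for illustration only.
    %empty = tensor.empty() : tensor<32x4x4x4xf32>
    %collapsed = tensor.collapse_shape %src [[0, 1], [2]] : tensor<4x8x64xf32> into tensor<32x64xf32>
    %expanded = tensor.expand_shape %collapsed [[0], [1, 2]] output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
    %pack = tensor.pack %expanded inner_dims_pos = [2] inner_tiles = [4] into %empty : tensor<32x4x16xf32> -> tensor<32x4x4x4xf32>
    return %pack : tensor<32x4x4x4xf32>
  }

Once the expand pattern fires here, the rewritten pack consumes %collapsed directly, so bubbleUpPackOpThroughCollapseShape runs on IR it could not reach before.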


@AmosLewis AmosLewis Jun 26, 2024

I got the same issue for mit-b0 and 3 more models. The issue can be fixed by reverting this commit. Here is the stack trace:

(gdb) bt
#0  __pthread_kill_implementation (no_tid=0, signo=6, threadid=140737352719616) at ./nptl/pthread_kill.c:44
#1  __pthread_kill_internal (signo=6, threadid=140737352719616) at ./nptl/pthread_kill.c:78
#2  __GI___pthread_kill (threadid=140737352719616, signo=signo@entry=6) at ./nptl/pthread_kill.c:89
#3  0x00007fffdd442476 in __GI_raise (sig=sig@entry=6) at ../sysdeps/posix/raise.c:26
#4  0x00007fffdd4287f3 in __GI_abort () at ./stdlib/abort.c:79
#5  0x00007fffdd42871b in __assert_fail_base (fmt=0x7fffdd5dd130 "%s%s%s:%u: %s%sAssertion `%s' failed.\n%n", 
    assertion=0x7fffe22dc65a "input.size() == permutation.size() && \"expected input rank to equal permutation rank\"", 
    file=0x7fffe12f935b "iree/third_party/llvm-project/mlir/include/mlir/Dialect/Utils/IndexingUtils.h", line=204, function=<optimized out>) at ./assert/assert.c:92
#6  0x00007fffdd439e96 in __GI___assert_fail (assertion=0x7fffe22dc65a "input.size() == permutation.size() && \"expected input rank to equal permutation rank\"", 
    file=0x7fffe12f935b "iree/third_party/llvm-project/mlir/include/mlir/Dialect/Utils/IndexingUtils.h", line=204, 
    function=0x7fffe23b6a01 "SmallVector<T> mlir::applyPermutation(ArrayRef<T>, ArrayRef<int64_t>) [T = llvm::SmallVector<long, 2>]") at ./assert/assert.c:101
#7  0x00007fffec368e98 in mlir::applyPermutation<llvm::SmallVector<long, 2u> > (input=..., permutation=...)
    at /home/chi/src/iree/third_party/llvm-project/mlir/include/mlir/Dialect/Utils/IndexingUtils.h:203

#8  0x00007fffec368dd9 in mlir::applyPermutation<llvm::SmallVector<long, 2u> > (input=..., permutation=...)
    at /home/chi/src/iree/third_party/llvm-project/mlir/include/mlir/Dialect/Utils/IndexingUtils.h:214
#9  0x00007fffec368979 in mlir::applyPermutationToVector<llvm::SmallVector<long, 2u>, 1u> (inVec=..., permutation=...)
    at /home/chi/src/iree/third_party/llvm-project/mlir/include/mlir/Dialect/Utils/IndexingUtils.h:225
#10 0x00007fffef5599c5 in (anonymous namespace)::applyPermutationAndReindexReassoc (reassocIndices=..., permutation=...)
    at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp:608
#11 0x00007fffef5594e6 in (anonymous namespace)::bubbleUpPackOpThroughCollapseShape (collapseOp=..., packOp=..., rewriter=...)
    at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp:685
#12 0x00007fffef559057 in (anonymous namespace)::BubbleUpPackOpThroughReshapeOp::matchAndRewrite(mlir::tensor::PackOp, mlir::PatternRewriter&) const::{lambda(mlir::tensor::CollapseShapeOp)#1}::operator()(mlir::tensor::CollapseShapeOp) const (this=0x7fffffff8fc0, op=...) at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp:851
#13 0x00007fffef558fef in llvm::TypeSwitch<mlir::Operation*, mlir::LogicalResult>::Case<mlir::tensor::CollapseShapeOp, (anonymous namespace)::BubbleUpPackOpThroughReshapeOp::matchAndRewrite(mlir::tensor::PackOp, mlir::PatternRewriter&) const::{lambda(mlir::tensor::CollapseShapeOp)#1}>((anonymous namespace)::BubbleUpPackOpThroughReshapeOp::matchAndRewrite(mlir::tensor::PackOp, mlir::PatternRewriter&) const::{lambda(mlir::tensor::CollapseShapeOp)#1}&&) (this=0x7fffffff8fd0, caseFn=...) at /home/chi/src/iree/third_party/llvm-project/llvm/include/llvm/ADT/TypeSwitch.h:102
#14 0x00007fffef558b45 in llvm::detail::TypeSwitchBase<llvm::TypeSwitch<mlir::Operation*, mlir::LogicalResult>, mlir::Operation*>::Case<(anonymous namespace)::BubbleUpPackOpThroughReshapeOp::matchAndRewrite(mlir::tensor::PackOp, mlir::PatternRewriter&) const::{lambda(mlir::tensor::CollapseShapeOp)#1}>((anonymous namespace)::BubbleUpPackOpThroughReshapeOp::matchAndRewrite(mlir::tensor::PackOp, mlir::PatternRewriter&) const::{lambda(mlir::tensor::CollapseShapeOp)#1}&&) (this=0x7fffffff8fd0, caseFn=...) at /home/chi/src/iree/third_party/llvm-project/llvm/include/llvm/ADT/TypeSwitch.h:60
#15 0x00007fffef558a91 in (anonymous namespace)::BubbleUpPackOpThroughReshapeOp::matchAndRewrite (this=0x5555557239d0, packOp=..., rewriter=...)
    at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp:850
#16 0x00007fffec3b87eb in mlir::detail::OpOrInterfaceRewritePatternBase<mlir::tensor::PackOp>::matchAndRewrite (this=0x5555557239d0, op=0x555556649190, rewriter=...)
    at /home/chi/src/iree/third_party/llvm-project/mlir/include/mlir/IR/PatternMatch.h:331
#17 0x00007ffff16cf52e in mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<mlir::LogicalResult (mlir::Pattern const&)>)::$_0::operator()() const (this=0x7fffffff92b0)
    at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Rewrite/PatternApplicator.cpp:212
#18 0x00007ffff16cf385 in llvm::function_ref<void ()>::callback_fn<mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<mlir::LogicalResult (mlir::Pattern const&)>)::$_0>(long) (callable=140737488327344)
    at /home/chi/src/iree/third_party/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:45
#19 0x00007fffe93074a9 in llvm::function_ref<void ()>::operator()() const (this=0x7fffffff91f0) at /home/chi/src/iree/third_party/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68
#20 0x00007ffff16d0d55 in mlir::MLIRContext::executeAction<mlir::ApplyPatternAction, mlir::Pattern const&>(llvm::function_ref<void ()>, llvm::ArrayRef<mlir::IRUnit>, mlir::Pattern const&) (
    this=0x5555555ec710, actionFn=..., irUnits=..., args=...) at /home/chi/src/iree/third_party/llvm-project/mlir/include/mlir/IR/MLIRContext.h:275
#21 0x00007ffff16cde27 in mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<mlir::LogicalResult (mlir::Pattern const&)>) (this=0x7fffffff9ed0, op=0x555556649190, rewriter=..., canApply=..., onFailure=..., onSuccess=...)
    at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Rewrite/PatternApplicator.cpp:195
#22 0x00007ffff16898db in (anonymous namespace)::GreedyPatternRewriteDriver::processWorklist (this=0x7fffffff9dd0)
    at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp:615
#23 0x00007ffff1688b61 in (anonymous namespace)::RegionPatternRewriteDriver::simplify(bool*) &&::$_2::operator()() const (this=0x7fffffff9c80)
    at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp:874
#24 0x00007ffff1688b35 in llvm::function_ref<void ()>::callback_fn<(anonymous namespace)::RegionPatternRewriteDriver::simplify(bool*) &&::$_2>(long) (callable=140737488329856)
    at /home/chi/src/iree/third_party/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:45
#25 0x00007fffe93074a9 in llvm::function_ref<void ()>::operator()() const (this=0x7fffffff9c20) at /home/chi/src/iree/third_party/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68
#26 0x00007ffff1688285 in mlir::MLIRContext::executeAction<(anonymous namespace)::GreedyPatternRewriteIteration, long&>(llvm::function_ref<void ()>, llvm::ArrayRef<mlir::IRUnit>, long&) (
    this=0x5555555ec710, actionFn=..., irUnits=..., args=@0x7fffffff9d88: 2) at /home/chi/src/iree/third_party/llvm-project/mlir/include/mlir/IR/MLIRContext.h:275
#27 0x00007ffff168670e in (anonymous namespace)::RegionPatternRewriteDriver::simplify(bool*) && (this=0x7fffffff9dd0, changed=0x7fffffff9fd7)
    at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp:872
#28 0x00007ffff16863f7 in mlir::applyPatternsAndFoldGreedily (region=..., patterns=..., config=..., changed=0x7fffffff9fd7)
    at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp:919
#29 0x00007fffe924a105 in mlir::applyPatternsAndFoldGreedily (op=0x5555557d1cf0, patterns=..., config=..., changed=0x0)
    at /home/chi/src/iree/third_party/llvm-project/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h:159
#30 0x00007fffec31f23b in mlir::iree_compiler::GlobalOptimization::(anonymous namespace)::DataLayoutPropagationPass::runOnOperation (this=0x55555663f1d0)
    at /home/chi/src/iree/compiler/src/iree/compiler/GlobalOptimization/DataLayoutPropagation.cpp:31
#31 0x00007fffe980335b in mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::$_1::operator()() const (this=0x7fffffffa428)
    at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:527
#32 0x00007fffe98032f5 in llvm::function_ref<void ()>::callback_fn<mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::$_1>(long) (
    callable=140737488331816) at /home/chi/src/iree/third_party/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:45
#33 0x00007fffe93074a9 in llvm::function_ref<void ()>::operator()() const (this=0x7fffffffa3b0) at /home/chi/src/iree/third_party/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68
#34 0x00007fffe9806175 in mlir::MLIRContext::executeAction<mlir::PassExecutionAction, mlir::Pass&>(llvm::function_ref<void ()>, llvm::ArrayRef<mlir::IRUnit>, mlir::Pass&) (this=0x5555555ec710, 
    actionFn=..., irUnits=..., args=...) at /home/chi/src/iree/third_party/llvm-project/mlir/include/mlir/IR/MLIRContext.h:275
#35 0x00007fffe97feab3 in mlir::detail::OpToOpPassAdaptor::run (pass=0x55555663f1d0, op=0x5555557d1cf0, am=..., verifyPasses=true, parentInitGeneration=1)
    at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:521
#36 0x00007fffe97ff034 in mlir::detail::OpToOpPassAdaptor::runPipeline (pm=..., op=0x5555557d1cf0, am=..., verifyPasses=true, parentInitGeneration=1, instrumentor=0x5555557f5ee0, 
    parentInfo=0x7fffffffaae0) at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:593
#37 0x00007fffe98045e5 in mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_0::operator()(mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo&) const (
    this=0x7fffffffaa78, opInfo=...) at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:813
#38 0x00007fffe9804269 in mlir::failableParallelForEach<__gnu_cxx::__normal_iterator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo*, std::vector<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo, std::allocator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo> > >, mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_0&>(mlir::MLIRContext*, __gnu_cxx::__normal_iterator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo*, std::vector<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo, std::allocator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo> > >, __gnu_cxx::__normal_iterator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo*, std::vector<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo, std::allocator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo> > >, mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_0&) (context=0x5555555ec710, begin={passManagerIdx = 0, op = 0x5555557d1cf0, am = {impl = 0x5555557d2ce0}}, 
    end={passManagerIdx = 129, op = 0x55555578f6a0, am = {impl = 0x55555578ef00}}, func=...) at /home/chi/src/iree/third_party/llvm-project/mlir/include/mlir/IR/Threading.h:46
#39 0x00007fffe98002eb in mlir::failableParallelForEach<std::vector<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo, std::allocator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo> >&, mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_0&>(mlir::MLIRContext*, std::vector<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo, std::allocator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo> >&, mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_0&) (
    context=0x5555555ec710, range=std::vector of length 1, capacity 1 = {...}, func=...) at /home/chi/src/iree/third_party/llvm-project/mlir/include/mlir/IR/Threading.h:92
#40 0x00007fffe97ffbfa in mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl (this=0x555555804b30, verifyPasses=true) at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:823
#41 0x00007fffe97ff727 in mlir::detail::OpToOpPassAdaptor::runOnOperation (this=0x555555804b30, verifyPasses=true) at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:714
#42 0x00007fffe9803346 in mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::$_1::operator()() const (this=0x7fffffffade8)
    at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:525
#43 0x00007fffe98032f5 in llvm::function_ref<void ()>::callback_fn<mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::$_1>(long) (
    callable=140737488334312) at /home/chi/src/iree/third_party/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:45
#44 0x00007fffe93074a9 in llvm::function_ref<void ()>::operator()() const (this=0x7fffffffad70) at /home/chi/src/iree/third_party/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68
#45 0x00007fffe9806175 in mlir::MLIRContext::executeAction<mlir::PassExecutionAction, mlir::Pass&>(llvm::function_ref<void ()>, llvm::ArrayRef<mlir::IRUnit>, mlir::Pass&) (this=0x5555555ec710, 
    actionFn=..., irUnits=..., args=...) at /home/chi/src/iree/third_party/llvm-project/mlir/include/mlir/IR/MLIRContext.h:275
#46 0x00007fffe97feab3 in mlir::detail::OpToOpPassAdaptor::run (pass=0x555555804b30, op=0x5555557f5690, am=..., verifyPasses=true, parentInitGeneration=1)
    at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:521
#47 0x00007fffe97ff034 in mlir::detail::OpToOpPassAdaptor::runPipeline (pm=..., op=0x5555557f5690, am=..., verifyPasses=true, parentInitGeneration=1, instrumentor=0x0, parentInfo=0x0)
    at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:593
#48 0x00007fffe9800a78 in mlir::PassManager::runPasses (this=0x555555739cb0, op=0x5555557f5690, am=...) at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:904
#49 0x00007fffe98009a2 in mlir::PassManager::run (this=0x555555739cb0, op=0x5555557f5690) at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:884
#50 0x00007fffe9251cba in mlir::iree_compiler::embed::(anonymous namespace)::Invocation::runPipeline (this=0x55555565add0, pipeline=IREE_COMPILER_PIPELINE_STD)
    at /home/chi/src/iree/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp:995
#51 0x00007fffe9251593 in ireeCompilerInvocationPipeline (inv=0x55555565add0, pipeline=IREE_COMPILER_PIPELINE_STD)
    at /home/chi/src/iree/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp:1430
#52 0x00007fffe978a88e in mlir::iree_compiler::runIreecMain(int, char**)::$_2::operator()(iree_compiler_source_t*) const (this=0x7fffffffc0e8, source=0x55555565aba0)
    at /home/chi/src/iree/compiler/src/iree/compiler/Tools/iree_compile_lib.cc:254
#53 0x00007fffe9789d1e in mlir::iree_compiler::runIreecMain (argc=4, argv=0x7fffffffdbd8) at /home/chi/src/iree/compiler/src/iree/compiler/Tools/iree_compile_lib.cc:355
#54 0x00007fffe929baab in ireeCompilerRunMain (argc=4, argv=0x7fffffffdbd8) at /home/chi/src/iree/compiler/src/iree/compiler/API/Internal/IREECompileToolEntryPoint.cpp:12
#55 0x00005555555557a2 in main (argc=4, argv=0x7fffffffdbd8) at /home/chi/src/iree/tools/iree-compile-main.cc:9

Contributor

@adam-smnk I encountered the same bug on my end.
The root cause of the bug is that outerDimsPerm is an optional attribute that may be empty. However, applyPermutationAndReindexReassoc assumes that outerDimsPerm is non-empty, so the rank-equality assertion in applyPermutation fires. One possible solution is to fill an empty outerDimsPerm with the identity permutation ([0, 1, 2, ...]).
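A minimal sketch of such a guard, assuming it lives in bubbleUpPackOpThroughCollapseShape (the names srcRank and newReassocIndices are illustrative assumptions, not the actual patch):

    // Hypothetical guard: fall back to the identity permutation when
    // outer_dims_perm is absent, so applyPermutation receives a
    // permutation whose size matches its input rank.
    SmallVector<int64_t> outerDimsPerm =
        llvm::to_vector(packOp.getOuterDimsPerm());
    if (outerDimsPerm.empty())
      outerDimsPerm = llvm::to_vector(llvm::seq<int64_t>(0, srcRank));
    applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);

With an identity fallback of the right size, the input.size() == permutation.size() assertion can no longer fire for packs that omit outer_dims_perm.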

adam-smnk pushed a commit that referenced this pull request Jun 27, 2024
lravenclaw pushed a commit to lravenclaw/llvm-project that referenced this pull request Jul 3, 2024
AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this pull request Jul 9, 2024