Skip to content

[mlir] Add sub-byte type emulation support for memref.collapse_shape #89962

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 26, 2024

Conversation

dcaballe
Copy link
Contributor

This PR adds support for memref.collapse_shape to sub-byte type emulation. The memref.collapse_shape becomes a no-opt given that we are flattening the memref as part of the emulation (i.e., we are collapsing all the dimensions).

This PR add support for `memref.collapse_shape` to sub-byte type
emulation. The `memref.collapse_shape` becomes a no-opt given that we
are flattening the memref as part of the emulation (i.e., we are
collapsing all the dimensions).
@llvmbot
Copy link
Member

llvmbot commented Apr 24, 2024

@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

Changes

This PR adds support for memref.collapse_shape to sub-byte type emulation. The memref.collapse_shape becomes a no-opt given that we are flattening the memref as part of the emulation (i.e., we are collapsing all the dimensions).


Full diff: https://github.com/llvm/llvm-project/pull/89962.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp (+29-3)
  • (modified) mlir/test/Dialect/MemRef/emulate-narrow-type.mlir (+20)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 4449733f0daf06..77c108aab48070 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -13,7 +13,6 @@
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -24,7 +23,6 @@
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/MathExtras.h"
 #include <cassert>
 #include <type_traits>
 
@@ -430,6 +428,33 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertMemRefCollapseShape
+//===----------------------------------------------------------------------===//
+
+/// Emulating a `memref.collapse_shape` becomes a no-op after emulation given
+/// that we flatten memrefs to a single dimension as part of the emulation and
+/// there is no dimension to collapse any further.
+struct ConvertMemRefCollapseShape final
+    : OpConversionPattern<memref::CollapseShapeOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value srcVal = adaptor.getSrc();
+    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
+    if (!newTy)
+      return failure();
+
+    if (newTy.getRank() != 1)
+      return failure();
+
+    rewriter.replaceOp(collapseShapeOp, srcVal);
+    return success();
+  }
+};
+
 } // end anonymous namespace
 
 //===----------------------------------------------------------------------===//
@@ -442,7 +467,8 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
 
   // Populate `memref.*` conversion patterns.
   patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
-               ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefLoad,
+               ConvertMemRefAllocation<memref::AllocaOp>,
+               ConvertMemRefCollapseShape, ConvertMemRefLoad,
                ConvertMemrefStore, ConvertMemRefAssumeAlignment,
                ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
       typeConverter, patterns.getContext());
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index fd37b7ff0a2713..435dcc944778db 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -430,3 +430,23 @@ func.func @rank_zero_memref_store(%arg0: i4) -> () {
 //       CHECK32:   %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i32
 //       CHECK32:   %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
 //       CHECK32:   return
+
+// -----
+
+func.func @memref_collapse_shape_i4(%idx0 : index, %idx1 : index) -> i4 {
+  %arr = memref.alloc() : memref<32x8x128xi4>
+  %collapse = memref.collapse_shape %arr[[0, 1], [2]] : memref<32x8x128xi4> into memref<256x128xi4>
+  %1 = memref.load %collapse[%idx0, %idx1] : memref<256x128xi4>
+  return %1 : i4
+}
+
+// CHECK-LABEL:   func.func @memref_collapse_shape_i4(
+//       CHECK:     %[[ALLOC:.*]] = memref.alloc() : memref<16384xi8>
+//   CHECK-NOT:     memref.collapse_shape
+//       CHECK:     memref.load %[[ALLOC]][%{{.*}}] : memref<16384xi8>
+
+// CHECK32-LABEL:   func.func @memref_collapse_shape_i4(
+//       CHECK32:     %[[ALLOC:.*]] = memref.alloc() : memref<4096xi32>
+//   CHECK32-NOT:     memref.collapse_shape
+//       CHECK32:     memref.load %[[ALLOC]][%{{.*}}] : memref<4096xi32>
+

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Thank you!

This PR adds a new pattern to the set of patterns used to resolve the
offset, sizes and stride of a memref. Similar to `ExtractStridedMetadataOpSubviewFolder`,
the new pattern resolves strided_metadata(collapse_shape) directly,
without introduce a reshape_cast op.
@dcaballe dcaballe merged commit 571831a into llvm:main Apr 26, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants