-
Notifications
You must be signed in to change notification settings - Fork 13.8k
[mlir] Add sub-byte type emulation support for memref.collapse_shape
#89962
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This PR add support for `memref.collapse_shape` to sub-byte type emulation. The `memref.collapse_shape` becomes a no-opt given that we are flattening the memref as part of the emulation (i.e., we are collapsing all the dimensions).
@llvm/pr-subscribers-mlir-memref @llvm/pr-subscribers-mlir Author: Diego Caballero (dcaballe) ChangesThis PR adds support for Full diff: https://github.com/llvm/llvm-project/pull/89962.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 4449733f0daf06..77c108aab48070 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -13,7 +13,6 @@
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -24,7 +23,6 @@
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <type_traits>
@@ -430,6 +428,33 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ConvertMemRefCollapseShape
+//===----------------------------------------------------------------------===//
+
+/// Emulating a `memref.collapse_shape` becomes a no-op after emulation given
+/// that we flatten memrefs to a single dimension as part of the emulation and
+/// there is no dimension to collapse any further.
+struct ConvertMemRefCollapseShape final
+ : OpConversionPattern<memref::CollapseShapeOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Value srcVal = adaptor.getSrc();
+ auto newTy = dyn_cast<MemRefType>(srcVal.getType());
+ if (!newTy)
+ return failure();
+
+ if (newTy.getRank() != 1)
+ return failure();
+
+ rewriter.replaceOp(collapseShapeOp, srcVal);
+ return success();
+ }
+};
+
} // end anonymous namespace
//===----------------------------------------------------------------------===//
@@ -442,7 +467,8 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
// Populate `memref.*` conversion patterns.
patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
- ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefLoad,
+ ConvertMemRefAllocation<memref::AllocaOp>,
+ ConvertMemRefCollapseShape, ConvertMemRefLoad,
ConvertMemrefStore, ConvertMemRefAssumeAlignment,
ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
typeConverter, patterns.getContext());
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index fd37b7ff0a2713..435dcc944778db 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -430,3 +430,23 @@ func.func @rank_zero_memref_store(%arg0: i4) -> () {
// CHECK32: %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i32
// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
// CHECK32: return
+
+// -----
+
+func.func @memref_collapse_shape_i4(%idx0 : index, %idx1 : index) -> i4 {
+ %arr = memref.alloc() : memref<32x8x128xi4>
+ %collapse = memref.collapse_shape %arr[[0, 1], [2]] : memref<32x8x128xi4> into memref<256x128xi4>
+ %1 = memref.load %collapse[%idx0, %idx1] : memref<256x128xi4>
+ return %1 : i4
+}
+
+// CHECK-LABEL: func.func @memref_collapse_shape_i4(
+// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<16384xi8>
+// CHECK-NOT: memref.collapse_shape
+// CHECK: memref.load %[[ALLOC]][%{{.*}}] : memref<16384xi8>
+
+// CHECK32-LABEL: func.func @memref_collapse_shape_i4(
+// CHECK32: %[[ALLOC:.*]] = memref.alloc() : memref<4096xi32>
+// CHECK32-NOT: memref.collapse_shape
+// CHECK32: memref.load %[[ALLOC]][%{{.*}}] : memref<4096xi32>
+
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! Thank you!
This PR adds a new pattern to the set of patterns used to resolve the offset, sizes and stride of a memref. Similar to `ExtractStridedMetadataOpSubviewFolder`, the new pattern resolves strided_metadata(collapse_shape) directly, without introduce a reshape_cast op.
This PR adds support for
memref.collapse_shape
to sub-byte type emulation. Thememref.collapse_shape
becomes a no-opt given that we are flattening the memref as part of the emulation (i.e., we are collapsing all the dimensions).