Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][MemRef] Add more ops to narrow type support, strided metadata expansion #102228

Merged
merged 1 commit into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 90 additions & 3 deletions mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,46 @@ struct ConvertMemRefAssumeAlignment final
}
};

//===----------------------------------------------------------------------===//
// ConvertMemRefCopy
//===----------------------------------------------------------------------===//

/// Conversion pattern for `memref.copy` used by the narrow-type emulation.
///
/// The copy is re-created on the source and target values provided by the
/// adaptor. If both original operands are ranked memrefs and their layouts
/// differ, the pattern reports a match failure instead of rewriting.
struct ConvertMemRefCopy final : OpConversionPattern<memref::CopyOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Either value is null when the corresponding operand type is not a
// `MemRefType`; the layout comparison below only runs when both are set.
auto maybeRankedSource = dyn_cast<MemRefType>(op.getSource().getType());
auto maybeRankedDest = dyn_cast<MemRefType>(op.getTarget().getType());
if (maybeRankedSource && maybeRankedDest &&
maybeRankedSource.getLayout() != maybeRankedDest.getLayout())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Please add { } for multi-line statements.

return rewriter.notifyMatchFailure(
op, llvm::formatv("memref.copy emulation with distinct layouts ({0} "
"and {1}) is currently unimplemented",
maybeRankedSource.getLayout(),
maybeRankedDest.getLayout()));
// Layouts agree (or at least one operand is not a `MemRefType`): rebuild the
// copy on the adaptor's operands.
rewriter.replaceOpWithNewOp<memref::CopyOp>(op, adaptor.getSource(),
adaptor.getTarget());
return success();
}
};

//===----------------------------------------------------------------------===//
// ConvertMemRefDealloc
//===----------------------------------------------------------------------===//

/// Conversion pattern for `memref.dealloc` used by the narrow-type emulation.
///
/// The op carries no result, so the rewrite simply re-creates the dealloc on
/// the memref value supplied by the adaptor. The pattern always succeeds.
struct ConvertMemRefDealloc final : OpConversionPattern<memref::DeallocOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::DeallocOp deallocOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Build the replacement on the adaptor's operand rather than on the
    // operand of the original op.
    Value memrefToFree = adaptor.getMemref();
    rewriter.replaceOpWithNewOp<memref::DeallocOp>(deallocOp, memrefToFree);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefLoad
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -300,6 +340,30 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
}
};

//===----------------------------------------------------------------------===//
// ConvertMemRefMemorySpaceCast
//===----------------------------------------------------------------------===//

/// Conversion pattern for `memref.memory_space_cast` used by the narrow-type
/// emulation.
///
/// The result type of the original cast is run through the pattern's type
/// converter and the cast is re-created from the adaptor's source value to
/// that converted type. If the type converter yields a null type, a match
/// failure naming the unconverted type is reported.
struct ConvertMemRefMemorySpaceCast final
    : OpConversionPattern<memref::MemorySpaceCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type convertedResultTy =
        getTypeConverter()->convertType(castOp.getDest().getType());

    if (convertedResultTy) {
      rewriter.replaceOpWithNewOp<memref::MemorySpaceCastOp>(
          castOp, convertedResultTy, adaptor.getSource());
      return success();
    }

    // A null type means the converter could not handle the result type.
    return rewriter.notifyMatchFailure(
        castOp->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                        castOp.getDest().getType()));
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefReinterpretCast
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -490,6 +554,28 @@ struct ConvertMemRefCollapseShape final
}
};

/// `memref.expand_shape` turns into a no-op under narrow-type emulation:
/// memrefs are flattened to a single dimension as part of the emulation, so
/// any expansion would simply be undone again.
///
/// The op is replaced by the adaptor's source value. The pattern fails to
/// match unless that value has a rank-1 `MemRefType`.
struct ConvertMemRefExpandShape final
    : OpConversionPattern<memref::ExpandShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ExpandShapeOp expandShapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value flattenedSrc = adaptor.getSrc();
    auto flattenedTy = dyn_cast<MemRefType>(flattenedSrc.getType());

    // Only forward the source when it is a one-dimensional memref.
    if (!flattenedTy || flattenedTy.getRank() != 1)
      return failure();

    rewriter.replaceOp(expandShapeOp, flattenedSrc);
    return success();
  }
};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
Expand All @@ -502,9 +588,10 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(

// Populate `memref.*` conversion patterns.
patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
ConvertMemRefAllocation<memref::AllocaOp>,
ConvertMemRefCollapseShape, ConvertMemRefLoad,
ConvertMemrefStore, ConvertMemRefAssumeAlignment,
ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefCopy,
ConvertMemRefDealloc, ConvertMemRefCollapseShape,
ConvertMemRefExpandShape, ConvertMemRefLoad, ConvertMemrefStore,
ConvertMemRefAssumeAlignment, ConvertMemRefMemorySpaceCast,
ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
typeConverter, patterns.getContext());
memref::populateResolveExtractStridedMetadataPatterns(patterns);
Expand Down
87 changes: 87 additions & 0 deletions mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,41 @@ struct ExtractStridedMetadataOpCollapseShapeFolder
}
};

/// Rewrites `extract_strided_metadata(expand_shape)` into values computed
/// directly from the expand_shape: the sizes and strides of the expanded
/// shape are resolved, with dimensions divided into static and dynamic parts
/// as needed.
///
/// The pattern does not match when the source of the metadata op is not
/// produced by a `memref.expand_shape`, and reports a match failure when the
/// metadata cannot be resolved.
struct ExtractStridedMetadataOpExpandShapeFolder
    : OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto reshapeOp = op.getSource().getDefiningOp<memref::ExpandShapeOp>();
    if (!reshapeOp)
      return failure();

    FailureOr<StridedMetadata> resolved =
        resolveReshapeStridedMetadata<memref::ExpandShapeOp>(
            rewriter, reshapeOp, getExpandedSizes, getExpandedStrides);
    if (failed(resolved)) {
      return rewriter.notifyMatchFailure(
          op, "failed to resolve metadata in terms of source expand_shape op");
    }

    // Materialize offset, sizes and strides (in that order) as index values
    // at the location of the expand_shape op.
    Location reshapeLoc = reshapeOp.getLoc();
    Value offsetVal =
        getValueOrCreateConstantIndexOp(rewriter, reshapeLoc, resolved->offset);
    SmallVector<Value> sizeVals =
        getValueOrCreateConstantIndexOp(rewriter, reshapeLoc, resolved->sizes);
    SmallVector<Value> strideVals = getValueOrCreateConstantIndexOp(
        rewriter, reshapeLoc, resolved->strides);

    // Replacement values follow the result order of the metadata op:
    // base pointer, offset, sizes..., strides...
    SmallVector<Value> replacements;
    replacements.push_back(resolved->basePtr);
    replacements.push_back(offsetVal);
    replacements.append(sizeVals.begin(), sizeVals.end());
    replacements.append(strideVals.begin(), strideVals.end());

    rewriter.replaceOp(op, replacements);
    return success();
  }
};

/// Replace `base, offset, sizes, strides =
/// extract_strided_metadata(allocLikeOp)`
///
Expand Down Expand Up @@ -1060,6 +1095,54 @@ class ExtractStridedMetadataOpCastFolder
}
};

/// Replace `base, offset, sizes, strides = extract_strided_metadata(
/// memory_space_cast(src) to dstTy)`
/// with
/// ```
/// oldBase, offset, sizes, strides = extract_strided_metadata(src)
/// destBaseTy = type(oldBase) with memory space from dstTy
/// base = memory_space_cast(oldBase) to destBaseTy
/// ```
///
/// In other words, propagate metadata extraction across memory space casts.
///
/// The cast of the base buffer is only created when the base buffer result of
/// the original op has uses; otherwise that result is replaced with a null
/// value.
class ExtractStridedMetadataOpMemorySpaceCastFolder
: public OpRewritePattern<memref::ExtractStridedMetadataOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult
matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
PatternRewriter &rewriter) const override {
Location loc = extractStridedMetadataOp.getLoc();
Value source = extractStridedMetadataOp.getSource();
auto memSpaceCastOp = source.getDefiningOp<memref::MemorySpaceCastOp>();
if (!memSpaceCastOp)
return failure();
// Extract the metadata from the input of the cast instead of its result.
auto newExtractStridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(
loc, memSpaceCastOp.getSource());
SmallVector<Value> results(newExtractStridedMetadata.getResults());
// As with most other strided metadata rewrite patterns, don't introduce
// a use of the base pointer where none existed. This needs to happen here,
// as opposed to in later dead-code elimination, because these patterns are
// sometimes used during dialect conversion (see EmulateNarrowType, for
// example), so adding spurious usages would cause a pre-legalization value
// to be live that would be dead had this pattern not run.
if (!extractStridedMetadataOp.getBaseBuffer().use_empty()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to check for this here? This seems like something that DCE should be able to handle after the fact. So ignore the use_empty() case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As with other rewrite patterns in this file, we do need to check this here. I've added a longer explanatory comment as to why

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, that makes sense... Thanks for the explanation.

// Rebuild the base buffer type with the memory space of the cast's result
// and cast the newly extracted base buffer to it.
auto baseBuffer = results[0];
auto baseBufferType = cast<MemRefType>(baseBuffer.getType());
MemRefType::Builder newTypeBuilder(baseBufferType);
newTypeBuilder.setMemorySpace(
memSpaceCastOp.getResult().getType().getMemorySpace());
results[0] = rewriter.create<memref::MemorySpaceCastOp>(
loc, Type{newTypeBuilder}, baseBuffer);
} else {
// The base buffer is unused: replace it with a null value so that no new
// use of the extracted base buffer is introduced.
results[0] = nullptr;
}
rewriter.replaceOp(extractStridedMetadataOp, results);
return success();
}
};

/// Replace `base, offset =
/// extract_strided_metadata(extract_strided_metadata(src)#0)`
/// With
Expand Down Expand Up @@ -1099,11 +1182,13 @@ void memref::populateExpandStridedMetadataPatterns(
ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
ExtractStridedMetadataOpCollapseShapeFolder,
ExtractStridedMetadataOpExpandShapeFolder,
ExtractStridedMetadataOpGetGlobalFolder,
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
ExtractStridedMetadataOpReinterpretCastFolder,
ExtractStridedMetadataOpSubviewFolder,
ExtractStridedMetadataOpCastFolder,
ExtractStridedMetadataOpMemorySpaceCastFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
patterns.getContext());
}
Expand All @@ -1113,11 +1198,13 @@ void memref::populateResolveExtractStridedMetadataPatterns(
patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
ExtractStridedMetadataOpCollapseShapeFolder,
ExtractStridedMetadataOpExpandShapeFolder,
ExtractStridedMetadataOpGetGlobalFolder,
ExtractStridedMetadataOpSubviewFolder,
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
ExtractStridedMetadataOpReinterpretCastFolder,
ExtractStridedMetadataOpCastFolder,
ExtractStridedMetadataOpMemorySpaceCastFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
patterns.getContext());
}
Expand Down
68 changes: 68 additions & 0 deletions mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ func.func @memref_i8() -> i8 {
%c3 = arith.constant 3 : index
%m = memref.alloc() : memref<4xi8, 1>
%v = memref.load %m[%c3] : memref<4xi8, 1>
memref.dealloc %m : memref<4xi8, 1>
return %v : i8
}
// CHECK-LABEL: func @memref_i8()
// CHECK: %[[M:.+]] = memref.alloc() : memref<4xi8, 1>
// CHECK-NEXT: %[[V:.+]] = memref.load %[[M]][%{{.+}}] : memref<4xi8, 1>
// CHECK-NEXT: memref.dealloc %[[M]]
// CHECK-NEXT: return %[[V]]

// CHECK32-LABEL: func @memref_i8()
Expand All @@ -21,6 +23,7 @@ func.func @memref_i8() -> i8 {
// CHECK32: %[[CAST:.+]] = arith.index_cast %[[C24]] : index to i32
// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[V]], %[[CAST]]
// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i8
// CHECK32-NEXT: memref.dealloc %[[M]]
// CHECK32-NEXT: return %[[TRUNC]]

// -----
Expand Down Expand Up @@ -485,3 +488,68 @@ func.func @memref_collapse_shape_i4(%idx0 : index, %idx1 : index) -> i4 {
// CHECK32-NOT: memref.collapse_shape
// CHECK32: memref.load %[[ALLOC]][%{{.*}}] : memref<4096xi32>

// -----

// Allocates a 256x128 i4 memref, expands it to 32x8x128 and loads one element
// from the expanded view. The checks after this function expect the
// `memref.expand_shape` to be gone and the load to read from the allocation.
func.func @memref_expand_shape_i4(%idx0 : index, %idx1 : index, %idx2 : index) -> i4 {
%arr = memref.alloc() : memref<256x128xi4>
%expand = memref.expand_shape %arr[[0, 1], [2]] output_shape [32, 8, 128] : memref<256x128xi4> into memref<32x8x128xi4>
%1 = memref.load %expand[%idx0, %idx1, %idx2] : memref<32x8x128xi4>
return %1 : i4
}

// CHECK-LABEL: func.func @memref_expand_shape_i4(
// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<16384xi8>
// CHECK-NOT: memref.expand_shape
// CHECK: memref.load %[[ALLOC]][%{{.*}}] : memref<16384xi8>

// CHECK32-LABEL: func.func @memref_expand_shape_i4(
// CHECK32: %[[ALLOC:.*]] = memref.alloc() : memref<4096xi32>
// CHECK32-NOT: memref.expand_shape
// CHECK32: memref.load %[[ALLOC]][%{{.*}}] : memref<4096xi32>

// -----

// Casts an i4 memref from memory space 1 to the default memory space and
// returns it. The checks after this function expect the cast to be kept, but
// on the one-dimensional i8 (CHECK) or i32 (CHECK32) memref types.
func.func @memref_memory_space_cast_i4(%arg0: memref<32x128xi4, 1>) -> memref<32x128xi4> {
%cast = memref.memory_space_cast %arg0 : memref<32x128xi4, 1> to memref<32x128xi4>
return %cast : memref<32x128xi4>
}

// CHECK-LABEL: func.func @memref_memory_space_cast_i4(
// CHECK-SAME: %[[ARG0:.*]]: memref<2048xi8, 1>
// CHECK: %[[CAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<2048xi8, 1> to memref<2048xi8>
// CHECK: return %[[CAST]]

// CHECK32-LABEL: func.func @memref_memory_space_cast_i4(
// CHECK32-SAME: %[[ARG0:.*]]: memref<512xi32, 1>
// CHECK32: %[[CAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<512xi32, 1> to memref<512xi32>
// CHECK32: return %[[CAST]]

// -----

// Copies between two i4 memrefs of the same shape that differ only in memory
// space. The checks after this function expect a `memref.copy` between the
// two converted function arguments.
func.func @memref_copy_i4(%arg0: memref<32x128xi4, 1>, %arg1: memref<32x128xi4>) {
memref.copy %arg0, %arg1 : memref<32x128xi4, 1> to memref<32x128xi4>
return
}

// CHECK-LABEL: func.func @memref_copy_i4(
// CHECK-SAME: %[[ARG0:.*]]: memref<2048xi8, 1>, %[[ARG1:.*]]: memref<2048xi8>
// CHECK: memref.copy %[[ARG0]], %[[ARG1]]
// CHECK: return

// CHECK32-LABEL: func.func @memref_copy_i4(
// CHECK32-SAME: %[[ARG0:.*]]: memref<512xi32, 1>, %[[ARG1:.*]]: memref<512xi32>
// CHECK32: memref.copy %[[ARG0]], %[[ARG1]]
// CHECK32: return

// -----

// Negative test: copies from a memref without an explicit layout into one
// with the strided layout `!colMajor`. The expected-error annotation inside
// the function requires legalization of that `memref.copy` to fail.
!colMajor = memref<8x8xi4, strided<[1, 8]>>
func.func @copy_distinct_layouts(%idx : index) -> i4 {
%c0 = arith.constant 0 : index
%arr = memref.alloc() : memref<8x8xi4>
%arr2 = memref.alloc() : !colMajor
// expected-error @+1 {{failed to legalize operation 'memref.copy' that was explicitly marked illegal}}
memref.copy %arr, %arr2 : memref<8x8xi4> to !colMajor
%ld = memref.load %arr2[%c0, %c0] : !colMajor
return %ld : i4
}
38 changes: 38 additions & 0 deletions mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1553,3 +1553,41 @@ func.func @extract_strided_metadata_of_collapse_shape(%base: memref<5x4xf32>)
// CHECK-DAG: %[[STEP:.*]] = arith.constant 1 : index
// CHECK: %[[BASE:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata
// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZE]], %[[STEP]] : memref<f32>, index, index, index

// -----

// Extracts strided metadata from the result of a memory space cast and
// returns all four results, so the base buffer is used. The checks after this
// function expect the metadata to be extracted first and the base buffer to
// be cast afterwards.
func.func @extract_strided_metadata_of_memory_space_cast(%base: memref<20xf32>)
-> (memref<f32, 1>, index, index, index) {

%memory_space_cast = memref.memory_space_cast %base : memref<20xf32> to memref<20xf32, 1>

%base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %memory_space_cast :
memref<20xf32, 1> -> memref<f32, 1>, index, index, index

return %base_buffer, %offset, %size, %stride :
memref<f32, 1>, index, index, index
}

// CHECK-LABEL: func @extract_strided_metadata_of_memory_space_cast
// CHECK-DAG: %[[OFFSET:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[SIZE:.*]] = arith.constant 20 : index
// CHECK-DAG: %[[STEP:.*]] = arith.constant 1 : index
// CHECK: %[[BASE:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata
// CHECK: %[[CAST:.*]] = memref.memory_space_cast %[[BASE]]
// CHECK: return %[[CAST]], %[[OFFSET]], %[[SIZE]], %[[STEP]] : memref<f32, 1>, index, index, index

// -----

// Same as the previous test, except that the base buffer result is not
// returned and therefore has no uses. The check after this function expects
// no `memref.memory_space_cast` to remain.
func.func @extract_strided_metadata_of_memory_space_cast_no_base(%base: memref<20xf32>)
-> (index, index, index) {

%memory_space_cast = memref.memory_space_cast %base : memref<20xf32> to memref<20xf32, 1>

%base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %memory_space_cast :
memref<20xf32, 1> -> memref<f32, 1>, index, index, index

return %offset, %size, %stride : index, index, index
}

// CHECK-LABEL: func @extract_strided_metadata_of_memory_space_cast_no_base
// CHECK-NOT: memref.memory_space_cast
Loading