-
Notifications
You must be signed in to change notification settings - Fork 11.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][MemRef] Add more ops to narrow type support, strided metadata expansion #102228
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -726,6 +726,41 @@ struct ExtractStridedMetadataOpCollapseShapeFolder | |
} | ||
}; | ||
|
||
/// Pattern to replace `extract_strided_metadata(expand_shape)` | ||
/// with the results of computing the sizes and strides on the expanded shape | ||
/// and dividing up dimensions into static and dynamic parts as needed. | ||
struct ExtractStridedMetadataOpExpandShapeFolder | ||
: OpRewritePattern<memref::ExtractStridedMetadataOp> { | ||
using OpRewritePattern::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op, | ||
PatternRewriter &rewriter) const override { | ||
auto expandShapeOp = op.getSource().getDefiningOp<memref::ExpandShapeOp>(); | ||
if (!expandShapeOp) | ||
return failure(); | ||
|
||
FailureOr<StridedMetadata> stridedMetadata = | ||
resolveReshapeStridedMetadata<memref::ExpandShapeOp>( | ||
rewriter, expandShapeOp, getExpandedSizes, getExpandedStrides); | ||
if (failed(stridedMetadata)) { | ||
return rewriter.notifyMatchFailure( | ||
op, "failed to resolve metadata in terms of source expand_shape op"); | ||
} | ||
|
||
Location loc = expandShapeOp.getLoc(); | ||
SmallVector<Value> results; | ||
results.push_back(stridedMetadata->basePtr); | ||
results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, | ||
stridedMetadata->offset)); | ||
results.append( | ||
getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes)); | ||
results.append(getValueOrCreateConstantIndexOp(rewriter, loc, | ||
stridedMetadata->strides)); | ||
rewriter.replaceOp(op, results); | ||
return success(); | ||
} | ||
}; | ||
|
||
/// Replace `base, offset, sizes, strides = | ||
/// extract_strided_metadata(allocLikeOp)` | ||
/// | ||
|
@@ -1060,6 +1095,54 @@ class ExtractStridedMetadataOpCastFolder | |
} | ||
}; | ||
|
||
/// Replace `base, offset, sizes, strides = extract_strided_metadata( | ||
/// memory_space_cast(src) to dstTy)` | ||
/// with | ||
/// ``` | ||
/// oldBase, offset, sizes, strides = extract_strided_metadata(src) | ||
/// destBaseTy = type(oldBase) with memory space from destTy | ||
/// base = memory_space_cast(oldBase) to destBaseTy | ||
/// ``` | ||
/// | ||
/// In other words, propagate metadata extraction accross memory space casts. | ||
class ExtractStridedMetadataOpMemorySpaceCastFolder | ||
: public OpRewritePattern<memref::ExtractStridedMetadataOp> { | ||
using OpRewritePattern::OpRewritePattern; | ||
|
||
LogicalResult | ||
matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp, | ||
PatternRewriter &rewriter) const override { | ||
Location loc = extractStridedMetadataOp.getLoc(); | ||
Value source = extractStridedMetadataOp.getSource(); | ||
auto memSpaceCastOp = source.getDefiningOp<memref::MemorySpaceCastOp>(); | ||
if (!memSpaceCastOp) | ||
return failure(); | ||
auto newExtractStridedMetadata = | ||
rewriter.create<memref::ExtractStridedMetadataOp>( | ||
loc, memSpaceCastOp.getSource()); | ||
SmallVector<Value> results(newExtractStridedMetadata.getResults()); | ||
// As with most other strided metadata rewrite patterns, don't introduce | ||
// a use of the base pointer where non existed. This needs to happen here, | ||
// as opposed to in later dead-code elimination, because these patterns are | ||
// sometimes used during dialect conversion (see EmulateNarrowType, for | ||
// example), so adding spurious usages would cause a pre-legalization value | ||
// to be live that would be dead had this pattern not run. | ||
if (!extractStridedMetadataOp.getBaseBuffer().use_empty()) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to check for this here? This seems like something that DCE should be able to handle after the fact. So ignore the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As with other rewrite patterns in this file, we do need to check this here. I've added a longer explanatory comment as to why There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, that makes sense... Thanks for the explanation. |
||
auto baseBuffer = results[0]; | ||
auto baseBufferType = cast<MemRefType>(baseBuffer.getType()); | ||
MemRefType::Builder newTypeBuilder(baseBufferType); | ||
newTypeBuilder.setMemorySpace( | ||
memSpaceCastOp.getResult().getType().getMemorySpace()); | ||
results[0] = rewriter.create<memref::MemorySpaceCastOp>( | ||
loc, Type{newTypeBuilder}, baseBuffer); | ||
} else { | ||
results[0] = nullptr; | ||
} | ||
rewriter.replaceOp(extractStridedMetadataOp, results); | ||
return success(); | ||
} | ||
}; | ||
|
||
/// Replace `base, offset = | ||
/// extract_strided_metadata(extract_strided_metadata(src)#0)` | ||
/// With | ||
|
@@ -1099,11 +1182,13 @@ void memref::populateExpandStridedMetadataPatterns( | |
ExtractStridedMetadataOpAllocFolder<memref::AllocOp>, | ||
ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>, | ||
ExtractStridedMetadataOpCollapseShapeFolder, | ||
ExtractStridedMetadataOpExpandShapeFolder, | ||
ExtractStridedMetadataOpGetGlobalFolder, | ||
RewriteExtractAlignedPointerAsIndexOfViewLikeOp, | ||
ExtractStridedMetadataOpReinterpretCastFolder, | ||
ExtractStridedMetadataOpSubviewFolder, | ||
ExtractStridedMetadataOpCastFolder, | ||
ExtractStridedMetadataOpMemorySpaceCastFolder, | ||
ExtractStridedMetadataOpExtractStridedMetadataFolder>( | ||
patterns.getContext()); | ||
} | ||
|
@@ -1113,11 +1198,13 @@ void memref::populateResolveExtractStridedMetadataPatterns( | |
patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>, | ||
ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>, | ||
ExtractStridedMetadataOpCollapseShapeFolder, | ||
ExtractStridedMetadataOpExpandShapeFolder, | ||
ExtractStridedMetadataOpGetGlobalFolder, | ||
ExtractStridedMetadataOpSubviewFolder, | ||
RewriteExtractAlignedPointerAsIndexOfViewLikeOp, | ||
ExtractStridedMetadataOpReinterpretCastFolder, | ||
ExtractStridedMetadataOpCastFolder, | ||
ExtractStridedMetadataOpMemorySpaceCastFolder, | ||
ExtractStridedMetadataOpExtractStridedMetadataFolder>( | ||
patterns.getContext()); | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Please add
{
}
for multi-line statements.