Skip to content

[mlir][Transforms] Dialect Conversion: Simplify block conversion API #94866

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions mlir/docs/DialectConversion.md
Original file line number Diff line number Diff line change
Expand Up @@ -372,19 +372,23 @@ class TypeConverter {
From the perspective of type conversion, the types of block arguments are a bit
special. Throughout the conversion process, blocks may move between regions of
different operations. Given this, the conversion of the types for blocks must be
done explicitly via a conversion pattern. To convert the types of block
arguments within a Region, a custom hook on the `ConversionPatternRewriter` must
be invoked; `convertRegionTypes`. This hook uses a provided type converter to
apply type conversions to all blocks within a given region, and all blocks that
move into that region. As noted above, the conversions performed by this method
use the argument materialization hook on the `TypeConverter`. This hook also
takes an optional `TypeConverter::SignatureConversion` parameter that applies a
custom conversion to the entry block of the region. The types of the entry block
arguments are often tied semantically to details on the operation, e.g. func::FuncOp,
AffineForOp, etc. To convert the signature of just the region entry block, and
not any other blocks within the region, the `applySignatureConversion` hook may
be used instead. A signature conversion, `TypeConverter::SignatureConversion`,
can be built programmatically:
done explicitly via a conversion pattern.

To convert the types of block arguments within a Region, a custom hook on the
`ConversionPatternRewriter` must be invoked; `convertRegionTypes`. This hook
uses a provided type converter to apply type conversions to all blocks of a
given region. As noted above, the conversions performed by this method use the
argument materialization hook on the `TypeConverter`. This hook also takes an
optional `TypeConverter::SignatureConversion` parameter that applies a custom
conversion to the entry block of the region. The types of the entry block
arguments are often tied semantically to the operation, e.g.,
`func::FuncOp`, `AffineForOp`, etc.

To convert the signature of just one given block, the
`applySignatureConversion` hook can be used.

A signature conversion, `TypeConverter::SignatureConversion`, can be built
programmatically:

```c++
class SignatureConversion {
Expand Down
49 changes: 25 additions & 24 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,8 @@ class TypeConverter {
/// Attempts a 1-1 type conversion, expecting the result type to be
/// `TargetType`. Returns the converted type cast to `TargetType` on success,
/// and a null type on conversion or cast failure.
template <typename TargetType> TargetType convertType(Type t) const {
template <typename TargetType>
TargetType convertType(Type t) const {
return dyn_cast_or_null<TargetType>(convertType(t));
}

Expand Down Expand Up @@ -661,42 +662,42 @@ class ConversionPatternRewriter final : public PatternRewriter {
public:
~ConversionPatternRewriter() override;

/// Apply a signature conversion to the entry block of the given region. This
/// replaces the entry block with a new block containing the updated
/// signature. The new entry block to the region is returned for convenience.
/// If no block argument types are changing, the entry original block will be
/// Apply a signature conversion to given block. This replaces the block with
/// a new block containing the updated signature. The operations of the given
/// block are inlined into the newly-created block, which is returned.
///
/// If no block argument types are changing, the original block will be
/// left in place and returned.
///
/// If provided, `converter` will be used for any materializations.
/// A signature converison must be provided. (Type converters can construct
/// a signature conversion with `convertBlockSignature`.)
///
/// Optionally, a type converter can be provided to build materializations.
/// Note: If no type converter was provided or the type converter does not
/// specify any suitable argument/target materialization rules, the dialect
/// conversion may fail to legalize unresolved materializations.
Block *
applySignatureConversion(Region *region,
applySignatureConversion(Block *block,
TypeConverter::SignatureConversion &conversion,
const TypeConverter *converter = nullptr);

/// Convert the types of block arguments within the given region. This
/// Apply a signature conversion to each block in the given region. This
/// replaces each block with a new block containing the updated signature. If
/// an updated signature would match the current signature, the respective
/// block is left in place as is.
/// block is left in place as is. (See `applySignatureConversion` for
/// details.) The new entry block of the region is returned.
///
/// SignatureConversions are computed with the specified type converter.
/// This function returns "failure" if the type converter failed to compute
/// a SignatureConversion for at least one block.
///
/// The entry block may have a special conversion if `entryConversion` is
/// provided. On success, the new entry block to the region is returned for
/// convenience. Otherwise, failure is returned.
/// Optionally, a special SignatureConversion can be specified for the entry
/// block. This is because the types of the entry block arguments are often
/// tied semantically to the operation.
FailureOr<Block *> convertRegionTypes(
Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion = nullptr);

/// Convert the types of block arguments within the given region except for
/// the entry region. This replaces each non-entry block with a new block
/// containing the updated signature. If an updated signature would match the
/// current signature, the respective block is left in place as is.
///
/// If special conversion behavior is needed for the non-entry blocks (for
/// example, we need to convert only a subset of a BB arguments), such
/// behavior can be specified in blockConversions.
LogicalResult convertNonEntryRegionTypes(
Region *region, const TypeConverter &converter,
ArrayRef<TypeConverter::SignatureConversion> blockConversions);

/// Replace all the uses of the block argument `from` with value `to`.
void replaceUsesOfBlockArgument(BlockArgument from, Value to);

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
signatureConverter.remapInput(0, newIndVar);
for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
signatureConverter.remapInput(i, header->getArgument(i));
body = rewriter.applySignatureConversion(&forOp.getRegion(),
body = rewriter.applySignatureConversion(&forOp.getRegion().front(),
signatureConverter);

// Move the blocks from the forOp into the loopOp. This is the body of the
Expand Down
20 changes: 8 additions & 12 deletions mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,27 +106,23 @@ struct FunctionNonEntryBlockConversion
ConversionPatternRewriter &rewriter) const override {
rewriter.startOpModification(op);
Region &region = op.getFunctionBody();
SmallVector<TypeConverter::SignatureConversion, 2> conversions;

for (Block &block : llvm::drop_begin(region, 1)) {
conversions.emplace_back(block.getNumArguments());
TypeConverter::SignatureConversion &back = conversions.back();
for (Block &block :
llvm::make_early_inc_range(llvm::drop_begin(region, 1))) {
TypeConverter::SignatureConversion conversion(
/*numOrigInputs=*/block.getNumArguments());

for (BlockArgument blockArgument : block.getArguments()) {
int idx = blockArgument.getArgNumber();

if (blockArgsToDetensor.count(blockArgument))
back.addInputs(idx, {getTypeConverter()->convertType(
block.getArgumentTypes()[idx])});
conversion.addInputs(idx, {getTypeConverter()->convertType(
block.getArgumentTypes()[idx])});
else
back.addInputs(idx, {block.getArgumentTypes()[idx]});
conversion.addInputs(idx, {block.getArgumentTypes()[idx]});
}
}

if (failed(rewriter.convertNonEntryRegionTypes(&region, *typeConverter,
conversions))) {
rewriter.cancelOpModification(op);
return failure();
rewriter.applySignatureConversion(&block, conversion, getTypeConverter());
}

rewriter.finalizeOpModification(op);
Expand Down
123 changes: 27 additions & 96 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -839,27 +839,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
// Type Conversion
//===--------------------------------------------------------------------===//

/// Attempt to convert the signature of the given block, if successful a new
/// block is returned containing the new arguments. Returns `block` if it did
/// not require conversion.
FailureOr<Block *> convertBlockSignature(
ConversionPatternRewriter &rewriter, Block *block,
const TypeConverter *converter,
TypeConverter::SignatureConversion *conversion = nullptr);

/// Convert the types of non-entry block arguments within the given region.
LogicalResult convertNonEntryRegionTypes(
ConversionPatternRewriter &rewriter, Region *region,
const TypeConverter &converter,
ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});

/// Apply a signature conversion on the given region, using `converter` for
/// materializations if not null.
Block *
applySignatureConversion(ConversionPatternRewriter &rewriter, Region *region,
TypeConverter::SignatureConversion &conversion,
const TypeConverter *converter);

/// Convert the types of block arguments within the given region.
FailureOr<Block *>
convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
Expand Down Expand Up @@ -1294,34 +1273,6 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
//===----------------------------------------------------------------------===//
// Type Conversion

FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
ConversionPatternRewriter &rewriter, Block *block,
const TypeConverter *converter,
TypeConverter::SignatureConversion *conversion) {
if (conversion)
return applySignatureConversion(rewriter, block, converter, *conversion);

// If a converter wasn't provided, and the block wasn't already converted,
// there is nothing we can do.
if (!converter)
return failure();

// Try to convert the signature for the block with the provided converter.
if (auto conversion = converter->convertBlockSignature(block))
return applySignatureConversion(rewriter, block, converter, *conversion);
return failure();
}

Block *ConversionPatternRewriterImpl::applySignatureConversion(
ConversionPatternRewriter &rewriter, Region *region,
TypeConverter::SignatureConversion &conversion,
const TypeConverter *converter) {
if (!region->empty())
return *convertBlockSignature(rewriter, &region->front(), converter,
&conversion);
return nullptr;
}

FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
ConversionPatternRewriter &rewriter, Region *region,
const TypeConverter &converter,
Expand All @@ -1330,42 +1281,29 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
if (region->empty())
return nullptr;

if (failed(convertNonEntryRegionTypes(rewriter, region, converter)))
return failure();

FailureOr<Block *> newEntry = convertBlockSignature(
rewriter, &region->front(), &converter, entryConversion);
return newEntry;
}

LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
ConversionPatternRewriter &rewriter, Region *region,
const TypeConverter &converter,
ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
regionToConverter[region] = &converter;
if (region->empty())
return success();

// Convert the arguments of each block within the region.
int blockIdx = 0;
assert((blockConversions.empty() ||
blockConversions.size() == region->getBlocks().size() - 1) &&
"expected either to provide no SignatureConversions at all or to "
"provide a SignatureConversion for each non-entry block");

// Convert the arguments of each non-entry block within the region.
for (Block &block :
llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
TypeConverter::SignatureConversion *blockConversion =
blockConversions.empty()
? nullptr
: const_cast<TypeConverter::SignatureConversion *>(
&blockConversions[blockIdx++]);

if (failed(convertBlockSignature(rewriter, &block, &converter,
blockConversion)))
// Compute the signature for the block with the provided converter.
std::optional<TypeConverter::SignatureConversion> conversion =
converter.convertBlockSignature(&block);
if (!conversion)
return failure();
}
return success();
// Convert the block with the computed signature.
applySignatureConversion(rewriter, &block, &converter, *conversion);
}

// Convert the entry block. If an entry signature conversion was provided,
// use that one. Otherwise, compute the signature with the type converter.
if (entryConversion)
return applySignatureConversion(rewriter, &region->front(), &converter,
*entryConversion);
std::optional<TypeConverter::SignatureConversion> conversion =
converter.convertBlockSignature(&region->front());
if (!conversion)
return failure();
return applySignatureConversion(rewriter, &region->front(), &converter,
*conversion);
}

Block *ConversionPatternRewriterImpl::applySignatureConversion(
Expand Down Expand Up @@ -1676,12 +1614,12 @@ void ConversionPatternRewriter::eraseBlock(Block *block) {
}

Block *ConversionPatternRewriter::applySignatureConversion(
Region *region, TypeConverter::SignatureConversion &conversion,
Block *block, TypeConverter::SignatureConversion &conversion,
const TypeConverter *converter) {
assert(!impl->wasOpReplaced(region->getParentOp()) &&
assert(!impl->wasOpReplaced(block->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
return impl->applySignatureConversion(*this, region, conversion, converter);
return impl->applySignatureConversion(*this, block, converter, conversion);
}

FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
Expand All @@ -1693,16 +1631,6 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
return impl->convertRegionTypes(*this, region, converter, entryConversion);
}

LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
Region *region, const TypeConverter &converter,
ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
assert(!impl->wasOpReplaced(region->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
return impl->convertNonEntryRegionTypes(*this, region, converter,
blockConversions);
}

void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
Value to) {
LLVM_DEBUG({
Expand Down Expand Up @@ -2231,11 +2159,14 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
// If the region of the block has a type converter, try to convert the block
// directly.
if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
if (failed(impl.convertBlockSignature(rewriter, block, converter))) {
std::optional<TypeConverter::SignatureConversion> conversion =
converter->convertBlockSignature(block);
if (!conversion) {
LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
"block"));
return failure();
}
impl.applySignatureConversion(rewriter, block, converter, *conversion);
continue;
}

Expand Down
5 changes: 3 additions & 2 deletions mlir/test/lib/Dialect/Test/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1516,8 +1516,9 @@ struct TestTestSignatureConversionNoConverter
if (failed(
converter.convertSignatureArgs(entry->getArgumentTypes(), result)))
return failure();
rewriter.modifyOpInPlace(
op, [&] { rewriter.applySignatureConversion(&region, result); });
rewriter.modifyOpInPlace(op, [&] {
rewriter.applySignatureConversion(&region.front(), result);
});
return success();
}

Expand Down
Loading