[mlir][Transforms] Dialect Conversion: Simplify materialization fn result type #113031

Merged: 1 commit, Oct 23, 2024
28 changes: 14 additions & 14 deletions mlir/include/mlir/Transforms/DialectConversion.h
@@ -163,14 +163,14 @@ class TypeConverter {

/// All of the following materializations require function objects that are
/// convertible to the following form:
/// `std::optional<Value>(OpBuilder &, T, ValueRange, Location)`,
/// `Value(OpBuilder &, T, ValueRange, Location)`,
/// where `T` is any subclass of `Type`. This function is responsible for
/// creating an operation, using the OpBuilder and Location provided, that
/// "casts" a range of values into a single value of the given type `T`. It
/// must return a Value of the type `T` on success, an `std::nullopt` if
/// it failed but other materialization can be attempted, and `nullptr` on
/// unrecoverable failure. Materialization functions must be provided when a
/// type conversion may persist after the conversion has finished.
/// must return a Value of the type `T` on success and `nullptr` if
/// it failed but other materialization should be attempted. Materialization
/// functions must be provided when a type conversion may persist after the
/// conversion has finished.
///
/// Note: Target materializations may optionally accept an additional Type
/// parameter, which is the original type of the SSA value.
@@ -335,14 +335,14 @@ class TypeConverter {
/// conversion.
///
/// Arguments: builder, result type, inputs, location
using MaterializationCallbackFn = std::function<std::optional<Value>(
OpBuilder &, Type, ValueRange, Location)>;
using MaterializationCallbackFn =
std::function<Value(OpBuilder &, Type, ValueRange, Location)>;

/// The signature of the callback used to materialize a target conversion.
///
/// Arguments: builder, result type, inputs, location, original type
using TargetMaterializationCallbackFn = std::function<std::optional<Value>(
OpBuilder &, Type, ValueRange, Location, Type)>;
using TargetMaterializationCallbackFn =
std::function<Value(OpBuilder &, Type, ValueRange, Location, Type)>;

/// The signature of the callback used to convert a type attribute.
using TypeAttributeConversionCallbackFn =
@@ -396,10 +396,10 @@ class TypeConverter {
MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
return [callback = std::forward<FnT>(callback)](
OpBuilder &builder, Type resultType, ValueRange inputs,
Location loc) -> std::optional<Value> {
Location loc) -> Value {
if (T derivedType = dyn_cast<T>(resultType))
return callback(builder, derivedType, inputs, loc);
return std::nullopt;
return Value();
};
}

@@ -417,10 +417,10 @@
wrapTargetMaterialization(FnT &&callback) const {
return [callback = std::forward<FnT>(callback)](
OpBuilder &builder, Type resultType, ValueRange inputs,
Location loc, Type originalType) -> std::optional<Value> {
Location loc, Type originalType) -> Value {
if (T derivedType = dyn_cast<T>(resultType))
return callback(builder, derivedType, inputs, loc, originalType);
return std::nullopt;
return Value();
};
}
/// With callback of form:
@@ -433,7 +433,7 @@
return wrapTargetMaterialization<T>(
[callback = std::forward<FnT>(callback)](
OpBuilder &builder, T resultType, ValueRange inputs, Location loc,
Type originalType) -> std::optional<Value> {
Type originalType) -> Value {
return callback(builder, resultType, inputs, loc);
});
}
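
A minimal sketch (not part of this PR's diff) of how a materialization is registered under the simplified callback form; the `converter` instance and the single-input check are illustrative assumptions:

// Sketch only: a target materialization with the new Value-returning
// signature. `converter` is a hypothetical TypeConverter instance.
converter.addTargetMaterialization(
    [](OpBuilder &builder, Type resultType, ValueRange inputs,
       Location loc) -> Value {
      // Returning a null Value means "this callback does not apply; try the
      // next registered materialization". There is no std::optional layer.
      if (inputs.size() != 1)
        return Value();
      return builder
          .create<UnrealizedConversionCastOp>(loc, resultType, inputs)
          .getResult(0);
    });
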
4 changes: 2 additions & 2 deletions mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -282,9 +282,9 @@ class AsyncRuntimeTypeConverter : public TypeConverter {
// Use UnrealizedConversionCast as the bridge so that we don't need to pull
// in patterns for other dialects.
auto addUnrealizedCast = [](OpBuilder &builder, Type type,
ValueRange inputs, Location loc) {
ValueRange inputs, Location loc) -> Value {
auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
return std::optional<Value>(cast.getResult(0));
return cast.getResult(0);
};

addSourceMaterialization(addUnrealizedCast);
49 changes: 23 additions & 26 deletions mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -158,36 +158,35 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
// original block argument type. The dialect conversion framework will then
// insert a target materialization from the original block argument type to
// a legal type.
addArgumentMaterialization(
[&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
Location loc) -> std::optional<Value> {
if (inputs.size() == 1) {
// Bare pointers are not supported for unranked memrefs because a
// memref descriptor cannot be built just from a bare pointer.
return std::nullopt;
}
Value desc = UnrankedMemRefDescriptor::pack(builder, loc, *this,
resultType, inputs);
// An argument materialization must return a value of type
// `resultType`, so insert a cast from the memref descriptor type
// (!llvm.struct) to the original memref type.
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
.getResult(0);
});
addArgumentMaterialization([&](OpBuilder &builder,
UnrankedMemRefType resultType,
ValueRange inputs, Location loc) {
if (inputs.size() == 1) {
// Bare pointers are not supported for unranked memrefs because a
// memref descriptor cannot be built just from a bare pointer.
return Value();
}
Value desc =
UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
// An argument materialization must return a value of type
// `resultType`, so insert a cast from the memref descriptor type
// (!llvm.struct) to the original memref type.
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
.getResult(0);
});
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
ValueRange inputs, Location loc) {
Value desc;
if (inputs.size() == 1) {
// This is a bare pointer. We allow bare pointers only for function entry
// blocks.
BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
if (!barePtr)
return std::nullopt;
return Value();
Block *block = barePtr.getOwner();
if (!block->isEntryBlock() ||
!isa<FunctionOpInterface>(block->getParentOp()))
return std::nullopt;
return Value();
desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
inputs[0]);
} else {
@@ -202,19 +201,17 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
// Add generic source and target materializations to handle cases where
// non-LLVM types persist after an LLVM conversion.
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
ValueRange inputs, Location loc) {
if (inputs.size() != 1)
return std::nullopt;
return Value();

return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
ValueRange inputs, Location loc) {
if (inputs.size() != 1)
return std::nullopt;
return Value();

return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
8 changes: 3 additions & 5 deletions mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
@@ -16,12 +16,10 @@ using namespace mlir;

namespace {

std::optional<Value> materializeAsUnrealizedCast(OpBuilder &builder,
Type resultType,
ValueRange inputs,
Location loc) {
Value materializeAsUnrealizedCast(OpBuilder &builder, Type resultType,
ValueRange inputs, Location loc) {
if (inputs.size() != 1)
return std::nullopt;
return Value();

return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
8 changes: 4 additions & 4 deletions mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -659,9 +659,9 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
/// This function is meant to handle the **compute** side; so it does not
/// involve storage classes in its logic. The storage side is expected to be
/// handled by MemRef conversion logic.
static std::optional<Value> castToSourceType(const spirv::TargetEnv &targetEnv,
OpBuilder &builder, Type type,
ValueRange inputs, Location loc) {
static Value castToSourceType(const spirv::TargetEnv &targetEnv,
OpBuilder &builder, Type type, ValueRange inputs,
Location loc) {
// We can only cast one value in SPIR-V.
if (inputs.size() != 1) {
auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
@@ -1459,7 +1459,7 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
Location loc) {
auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
return std::optional<Value>(cast.getResult(0));
return cast.getResult(0);
});
}

@@ -425,8 +425,7 @@ mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
addConversion(convertIterSpaceType);

addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp,
ValueRange inputs,
Location loc) -> std::optional<Value> {
ValueRange inputs, Location loc) -> Value {
return builder
.create<UnrealizedConversionCastOp>(loc, TypeRange(spTp), inputs)
.getResult(0);
@@ -60,11 +60,10 @@ SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {

// Required by scf.for 1:N type conversion.
addSourceMaterialization([](OpBuilder &builder, RankedTensorType tp,
ValueRange inputs,
Location loc) -> std::optional<Value> {
ValueRange inputs, Location loc) -> Value {
if (!getSparseTensorEncoding(tp))
// Not a sparse tensor.
return std::nullopt;
return Value();
// Sparsifier knows how to cancel out these casts.
return genTuple(builder, loc, tp, inputs);
});
14 changes: 7 additions & 7 deletions mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -153,29 +153,29 @@ void transform::TypeConversionCastShapeDynamicDimsOp::
converter.addSourceMaterialization([ignoreDynamicInfo](
OpBuilder &builder, Type resultType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
Location loc) -> Value {
if (inputs.size() != 1) {
return std::nullopt;
return Value();
}
Value input = inputs[0];
if (!ignoreDynamicInfo &&
!tensor::preservesStaticInformation(resultType, input.getType())) {
return std::nullopt;
return Value();
}
if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
return std::nullopt;
return Value();
}
return builder.create<tensor::CastOp>(loc, resultType, input).getResult();
});
converter.addTargetMaterialization([](OpBuilder &builder, Type resultType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
Location loc) -> Value {
if (inputs.size() != 1) {
return std::nullopt;
return Value();
}
Value input = inputs[0];
if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
return std::nullopt;
return Value();
}
return builder.create<tensor::CastOp>(loc, resultType, input).getResult();
});
8 changes: 4 additions & 4 deletions mlir/lib/Dialect/Tosa/Transforms/TosaTypeConverters.cpp
@@ -33,18 +33,18 @@ void mlir::tosa::populateTosaTypeConversion(TypeConverter &converter) {
});
converter.addSourceMaterialization([&](OpBuilder &builder, Type resultType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
Location loc) -> Value {
if (inputs.size() != 1)
return std::nullopt;
return Value();

return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});
converter.addTargetMaterialization([&](OpBuilder &builder, Type resultType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
Location loc) -> Value {
if (inputs.size() != 1)
return std::nullopt;
return Value();

return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
13 changes: 6 additions & 7 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2812,8 +2812,8 @@ Value TypeConverter::materializeArgumentConversion(OpBuilder &builder,
ValueRange inputs) const {
for (const MaterializationCallbackFn &fn :
llvm::reverse(argumentMaterializations))
if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
return *result;
if (Value result = fn(builder, resultType, inputs, loc))
return result;
return nullptr;
}

@@ -2822,8 +2822,8 @@ Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
ValueRange inputs) const {
for (const MaterializationCallbackFn &fn :
llvm::reverse(sourceMaterializations))
if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
return *result;
if (Value result = fn(builder, resultType, inputs, loc))
return result;
return nullptr;
}

@@ -2833,9 +2833,8 @@ Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
Type originalType) const {
for (const TargetMaterializationCallbackFn &fn :
llvm::reverse(targetMaterializations))
if (std::optional<Value> result =
fn(builder, resultType, inputs, loc, originalType))
return *result;
if (Value result = fn(builder, resultType, inputs, loc, originalType))
return result;
return nullptr;
}

@@ -180,9 +180,8 @@ buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input,
///
/// This function has been copied (with small adaptions) from
/// TestDecomposeCallGraphTypes.cpp.
static std::optional<Value> buildMakeTupleOp(OpBuilder &builder,
TupleType resultType,
ValueRange inputs, Location loc) {
static Value buildMakeTupleOp(OpBuilder &builder, TupleType resultType,
ValueRange inputs, Location loc) {
// Build one value for each element at this nesting level.
SmallVector<Value> elements;
elements.reserve(resultType.getTypes().size());
@@ -201,13 +200,13 @@ static std::optional<Value> buildMakeTupleOp(OpBuilder &builder,
inputIt += numNestedFlattenedTypes;

// Recurse on the values for the nested TupleType.
std::optional<Value> res = buildMakeTupleOp(builder, nestedTupleType,
nestedFlattenedelements, loc);
if (!res.has_value())
return {};
Value res = buildMakeTupleOp(builder, nestedTupleType,
nestedFlattenedelements, loc);
if (!res)
return Value();

// The tuple constructed by the conversion is the element value.
elements.push_back(res.value());
elements.push_back(res);
} else {
// Base case: take one input as is.
elements.push_back(*inputIt++);
2 changes: 1 addition & 1 deletion mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp
@@ -59,7 +59,7 @@ struct TestEmulateWideIntPass
// TODO: Consider extending `arith.bitcast` to support scalar-to-1D-vector
// casts (and vice versa) and using it insted of `llvm.bitcast`.
auto addBitcast = [](OpBuilder &builder, Type type, ValueRange inputs,
Location loc) -> std::optional<Value> {
Location loc) -> Value {
auto cast = builder.create<LLVM::BitcastOp>(loc, type, inputs);
return cast->getResult(0);
};
15 changes: 7 additions & 8 deletions mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
@@ -43,9 +43,8 @@ static LogicalResult buildDecomposeTuple(OpBuilder &builder, Location loc,
/// Creates a `test.make_tuple` op out of the given inputs building a tuple of
/// type `resultType`. If that type is nested, each nested tuple is built
/// recursively with another `test.make_tuple` op.
static std::optional<Value> buildMakeTupleOp(OpBuilder &builder,
TupleType resultType,
ValueRange inputs, Location loc) {
static Value buildMakeTupleOp(OpBuilder &builder, TupleType resultType,
ValueRange inputs, Location loc) {
// Build one value for each element at this nesting level.
SmallVector<Value> elements;
elements.reserve(resultType.getTypes().size());
@@ -64,13 +63,13 @@ static std::optional<Value> buildMakeTupleOp(OpBuilder &builder,
inputIt += numNestedFlattenedTypes;

// Recurse on the values for the nested TupleType.
std::optional<Value> res = buildMakeTupleOp(builder, nestedTupleType,
nestedFlattenedelements, loc);
if (!res.has_value())
return {};
Value res = buildMakeTupleOp(builder, nestedTupleType,
nestedFlattenedelements, loc);
if (!res)
return Value();

// The tuple constructed by the conversion is the element value.
elements.push_back(res.value());
elements.push_back(res);
} else {
// Base case: take one input as is.
elements.push_back(*inputIt++);
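
For out-of-tree type converters, the migration implied by this change is mechanical. The sketch below is hypothetical downstream code (not taken from this PR) showing the old and new callback forms:

// Before this change: callbacks returned std::optional<Value>, where
// std::nullopt meant "try other materializations" and nullptr meant
// unrecoverable failure.
//
// converter.addSourceMaterialization(
//     [](OpBuilder &b, Type t, ValueRange ins,
//        Location loc) -> std::optional<Value> {
//       if (ins.size() != 1)
//         return std::nullopt;
//       return b.create<UnrealizedConversionCastOp>(loc, t, ins).getResult(0);
//     });

// After this change: callbacks return Value; a null Value means "try other
// materializations", and the separate "unrecoverable failure" result is gone.
converter.addSourceMaterialization(
    [](OpBuilder &b, Type t, ValueRange ins, Location loc) -> Value {
      if (ins.size() != 1)
        return Value();
      return b.create<UnrealizedConversionCastOp>(loc, t, ins).getResult(0);
    });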