[mlir][gpu] Add pass for emulating unsupported types. #138087

Open
wants to merge 5 commits into
base: main
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -146,6 +146,12 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
// Map strings to float types.
std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name);

// Map strings to Int types.
std::optional<IntegerType> parseIntType(MLIRContext *ctx, StringRef name);

// Map strings to int or float types.
std::optional<Type> parseIntOrFloatType(MLIRContext *ctx, StringRef name);

} // namespace arith
} // namespace mlir

20 changes: 20 additions & 0 deletions mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -16,6 +16,8 @@
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include <optional>
@@ -87,6 +89,24 @@ void populateGpuLowerClusteredSubgroupReduceToDPPPatterns(
RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset,
PatternBenefit benefit = 1);

/// Set up a type converter to convert unsupported source types to
/// supported target types.
void populateImitateUnsupportedTypesTypeConverter(TypeConverter &typeConverter,
ArrayRef<Type> sourceTypes,
ArrayRef<Type> targetTypes);

/// Collect a set of patterns needed to imitate unsupported source types
/// using supported target types.
void populateImitateUnsupportedTypesConversionPatterns(
RewritePatternSet &patterns, TypeConverter &typeConverter,
ArrayRef<Type> sourceTypes, ArrayRef<Type> targetTypes,
DenseMap<StringAttr, FunctionType> &convertedFuncTypes);

/// Set up a dialect conversion to reject operations on unsupported
/// float types.
void configureImitateUnsupportedTypesLegality(ConversionTarget &target,
TypeConverter &typeConverter);

/// Collect all patterns to rewrite ops within the GPU dialect.
inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
populateGpuAllReducePatterns(patterns);
53 changes: 53 additions & 0 deletions mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
@@ -258,4 +258,57 @@ def GpuSPIRVAttachTarget: Pass<"spirv-attach-target", ""> {
];
}

def GpuImitateUnsupportedTypes : Pass<"imitate-unsupported-types", "::mlir::ModuleOp"> {
let summary = "Imitate unsupported types with supported types of the same bitwidth.";
let description = [{
This pass imitates (bitcast/reinterpret_cast) unsupported types
with supported types of the same bitwidth. The imitation is done
by bitcasting the unsupported types to the supported types using
arith.bitcast, so the source type and the destination type must
have the same bitwidth.

The imitation is often needed when the GPU target (dialect/IR) does not
support a certain type but the underlying architecture does. Take SPIR-V,
for example: it does not support bf16, but an underlying architecture
(e.g., Intel PVC GPU) that uses SPIR-V for code generation does.
Therefore, bf16 is valid neither as a data type to pass to a GPU kernel
nor for use inside the kernel. To use the bf16 data type in a SPIR-V
kernel (as a kernel parameter or inside the kernel), bf16 has to be
bitcast (similar to C++ reinterpret_cast) to a supported type (e.g., i16
for Intel GPUs). The SPIR-V kernel can then carry values in the imitated
type (i16). However, i16 is not the same as bf16 (integer vs. float), so
the computation cannot operate directly on the imitated type (i16).

Therefore, this transformation pass is intended to be used in conjunction
with other transformation passes such as `EmulateUnsupportedFloats` and
`ExtendUnsupportedTypes` that extend the bitwidth of bf16 to f32 and
vice versa.

Finally, the target (dialect/IR) usually provides instructions that can
take advantage of these generated patterns (bf16->i16->f32,
f32->bf16->i16) and convert them to the supported types.
For example, Intel provides SPIR-V extension ops that can
take imitated bf16 (i16) and convert it to f32 and vice versa.
https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_bfloat16_conversion.asciidoc
https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertbf16tof-spirvintelconvertbf16tofop
https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertftobf16-spirvintelconvertftobf16op

}];
Collaborator

@joker-eph May 1, 2025

The implementation only touches specific ops, but from the pass description it's absolutely not clear to me what is the scope here, especially considering we have also EmulateUnsupportedFloats

Contributor Author

I updated the description to provide an overall view. Please let me know if it's enough.


let options = [
ListOption<"sourceTypeStrs", "source-types", "std::string",
"MLIR types without type support on a given target">,
ListOption<"targetTypeStrs", "target-types", "std::string",
"MLIR types to convert the unsupported source types to">,
];

let dependentDialects = [
"::mlir::gpu::GPUDialect",
"::mlir::arith::ArithDialect",
"::mlir::memref::MemRefDialect"
];
}


#endif // MLIR_DIALECT_GPU_PASSES
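
Editorial note on the pass description above: the bit-level meaning of the bf16->i16->f32 chain can be sketched in plain C++ without MLIR. This is only an illustration under stated assumptions, not the pass's implementation. The names `BF16`, `imitateAsI16`, `extendToF32`, and `truncateToBF16Bits` are hypothetical. It relies on bf16 having the same layout as the upper 16 bits of an IEEE-754 binary32 value, and the truncation shown rounds toward zero, which may differ from what a hardware conversion instruction does.

```cpp
#include <cstdint>
#include <cstring>
#include <limits>

static_assert(sizeof(float) == 4 && std::numeric_limits<float>::is_iec559,
              "sketch assumes float is IEEE-754 binary32");

// Hypothetical stand-in for a bf16 value: just its 16 raw bits.
struct BF16 {
  std::uint16_t bits;
};

// "Imitation": the bits are unchanged, only the type differs. This mirrors
// what a same-bitwidth bitcast (bf16 -> i16) means.
inline std::uint16_t imitateAsI16(BF16 v) { return v.bits; }

// Extension: place the 16 bits in the upper half of a 32-bit word and
// reinterpret that word as a float.
inline float extendToF32(std::uint16_t imitated) {
  std::uint32_t wide = static_cast<std::uint32_t>(imitated) << 16;
  float out;
  std::memcpy(&out, &wide, sizeof(out));
  return out;
}

// Truncation: keep the upper 16 bits of the float's bit pattern.
// Assumption: round toward zero, chosen here only for simplicity.
inline std::uint16_t truncateToBF16Bits(float f) {
  std::uint32_t wide;
  std::memcpy(&wide, &f, sizeof(wide));
  return static_cast<std::uint16_t>(wide >> 16);
}
```

The first step never changes any bits, which is why the description requires the source and target types to have the same bitwidth. Only the second step, which belongs to the extension passes or target conversion ops mentioned in the description, gives the bits a floating-point meaning.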
25 changes: 25 additions & 0 deletions mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -380,4 +380,29 @@ std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name) {
.Default(std::nullopt);
}

/// Map strings to Int types.
std::optional<IntegerType> parseIntType(MLIRContext *ctx, StringRef name) {
Builder b(ctx);
return llvm::StringSwitch<std::optional<IntegerType>>(name)
.Case("i1", b.getIntegerType(1))
.Case("i2", b.getIntegerType(2))
.Case("i4", b.getIntegerType(4))
.Case("i6", b.getIntegerType(6))
.Case("i8", b.getIntegerType(8))
.Case("i16", b.getIntegerType(16))
.Case("i32", b.getIntegerType(32))
.Case("i64", b.getIntegerType(64))
.Case("i80", b.getIntegerType(80))
.Case("i128", b.getIntegerType(128))
.Default(std::nullopt);
}
/// Map strings to Int or Float types.
std::optional<Type> parseIntOrFloatType(MLIRContext *ctx, StringRef name) {
if (auto floatTy = parseFloatType(ctx, name))
return *floatTy;
if (auto intTy = parseIntType(ctx, name))
return *intTy;
return std::nullopt;
}

} // namespace mlir::arith
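
Editorial note on the helpers above: `parseIntOrFloatType` tries the float names first, falls back to the integer names, and otherwise returns `std::nullopt`. The sketch below reproduces that lookup order with MLIR-free stand-ins and adds the same-bitwidth check that the pass description calls for. Everything in it is hypothetical (`TypeInfo`, `parseFloatName`, `parseIntName`, `parseIntOrFloatName`, `pairsAreBitcastCompatible`). The accepted names are a small illustrative subset, and requiring the two lists to have equal length is an assumption, since the pass's own validation is not part of this excerpt.

```cpp
#include <cstddef>
#include <optional>
#include <string>
#include <string_view>
#include <vector>

// Hypothetical stand-in for an MLIR type: only what the check needs.
struct TypeInfo {
  bool isFloat;
  unsigned bitwidth;
};

// Illustrative subset of float type names.
inline std::optional<TypeInfo> parseFloatName(std::string_view name) {
  if (name == "bf16" || name == "f16") return TypeInfo{true, 16};
  if (name == "f32") return TypeInfo{true, 32};
  if (name == "f64") return TypeInfo{true, 64};
  return std::nullopt;
}

// Illustrative subset of integer type names.
inline std::optional<TypeInfo> parseIntName(std::string_view name) {
  if (name == "i8") return TypeInfo{false, 8};
  if (name == "i16") return TypeInfo{false, 16};
  if (name == "i32") return TypeInfo{false, 32};
  if (name == "i64") return TypeInfo{false, 64};
  return std::nullopt;
}

// Same lookup order as the helper in the diff: float first, then integer.
inline std::optional<TypeInfo> parseIntOrFloatName(std::string_view name) {
  if (auto floatTy = parseFloatName(name)) return floatTy;
  if (auto intTy = parseIntName(name)) return intTy;
  return std::nullopt;
}

// True when both lists have the same length (assumption), every name
// parses, and each source/target pair has the same bitwidth.
inline bool pairsAreBitcastCompatible(const std::vector<std::string> &sources,
                                      const std::vector<std::string> &targets) {
  if (sources.size() != targets.size()) return false;
  for (std::size_t i = 0; i < sources.size(); ++i) {
    auto src = parseIntOrFloatName(sources[i]);
    auto dst = parseIntOrFloatName(targets[i]);
    if (!src || !dst || src->bitwidth != dst->bitwidth) return false;
  }
  return true;
}
```

With these stand-ins, a `bf16`/`i16` pair is accepted while a `bf16`/`i32` pair or an unrecognized name such as `i12` is rejected. A fixed name table like the one in the diff rejects any width it does not list in the same way.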
5 changes: 3 additions & 2 deletions mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -23,7 +23,7 @@ add_mlir_dialect_library(MLIRGPUDialect
MLIRMemRefDialect
MLIRSideEffectInterfaces
MLIRSupport
)
)

add_mlir_dialect_library(MLIRGPUTransforms
Transforms/AllReduceLowering.cpp
@@ -42,6 +42,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
Transforms/SPIRVAttachTarget.cpp
Transforms/SubgroupIdRewriter.cpp
Transforms/SubgroupReduceLowering.cpp
Transforms/ImitateUnsupportedTypes.cpp

OBJECT

@@ -76,7 +77,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
MLIRROCDLTarget
MLIRTransformUtils
MLIRVectorDialect
)
)

add_subdirectory(TransformOps)
add_subdirectory(Pipelines)