Skip to content

[mlir][IR] Remove isF...() type API for low-precision FP types #123326

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions mlir/include/mlir/IR/CommonTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -329,31 +329,31 @@ def F64 : F<64>;
def F80 : F<80>;
def F128 : F<128>;

def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
def BF16 : Type<CPred<"::llvm::isa<BFloat16Type>($_self)">, "bfloat16 type">,
BuildableType<"$_builder.getType<BFloat16Type>()">;
def TF32 : Type<CPred<"$_self.isTF32()">, "tf32 type">,
def TF32 : Type<CPred<"::llvm::isa<FloatTF32Type>($_self)">, "tf32 type">,
BuildableType<"$_builder.getType<FloatTF32Type>()">;
def F8E4M3FN : Type<CPred<"$_self.isFloat8E4M3FN()">, "f8E4M3FN type">,
def F8E4M3FN : Type<CPred<"::llvm::isa<Float8E4M3FNType>($_self)">, "f8E4M3FN type">,
BuildableType<"$_builder.getType<Float8E4M3FNType>()">;
def F8E5M2 : Type<CPred<"$_self.isFloat8E5M2()">, "f8E5M2 type">,
def F8E5M2 : Type<CPred<"::llvm::isa<Float8E5M2Type>($_self)">, "f8E5M2 type">,
BuildableType<"$_builder.getType<Float8E5M2Type>()">;
def F8E4M3 : Type<CPred<"$_self.isFloat8E4M3()">, "f8E4M3 type">,
def F8E4M3 : Type<CPred<"::llvm::isa<Float8E4M3Type>($_self)">, "f8E4M3 type">,
BuildableType<"$_builder.getType<Float8E4M3Type>()">;
def F8E4M3FNUZ : Type<CPred<"$_self.isFloat8E4M3FNUZ()">, "f8E4M3FNUZ type">,
def F8E4M3FNUZ : Type<CPred<"::llvm::isa<Float8E4M3FNUZType>($_self)">, "f8E4M3FNUZ type">,
BuildableType<"$_builder.getType<Float8E4M3FNUZType>()">;
def F8E4M3B11FNUZ : Type<CPred<"$_self.isFloat8E4M3B11FNUZ()">, "f8E4M3B11FNUZ type">,
def F8E4M3B11FNUZ : Type<CPred<"::llvm::isa<Float8E4M3B11FNUZType>($_self)">, "f8E4M3B11FNUZ type">,
BuildableType<"$_builder.getType<Float8E4M3B11FNUZType>()">;
def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
def F8E5M2FNUZ : Type<CPred<"::llvm::isa<Float8E5M2FNUZType>($_self)">, "f8E5M2FNUZ type">,
BuildableType<"$_builder.getType<Float8E5M2FNUZType>()">;
def F8E3M4 : Type<CPred<"$_self.isFloat8E3M4()">, "f8E3M4 type">,
def F8E3M4 : Type<CPred<"::llvm::isa<Float8E3M4Type>($_self)">, "f8E3M4 type">,
BuildableType<"$_builder.getType<Float8E3M4Type>()">;
def F4E2M1FN : Type<CPred<"$_self.isFloat4E2M1FN()">, "f4E2M1FN type">,
def F4E2M1FN : Type<CPred<"::llvm::isa<Float4E2M1FNType>($_self)">, "f4E2M1FN type">,
BuildableType<"$_builder.getType<Float4E2M1FNType>()">;
def F6E2M3FN : Type<CPred<"$_self.isFloat6E2M3FN()">, "f6E2M3FN type">,
def F6E2M3FN : Type<CPred<"::llvm::isa<Float6E2M3FNType>($_self)">, "f6E2M3FN type">,
BuildableType<"$_builder.getType<Float6E2M3FNType>()">;
def F6E3M2FN : Type<CPred<"$_self.isFloat6E3M2FN()">, "f6E3M2FN type">,
def F6E3M2FN : Type<CPred<"::llvm::isa<Float6E3M2FNType($_self)">, "f6E3M2FN type">,
BuildableType<"$_builder.getType<Float6E3M2FNType>()">;
def F8E8M0FNU : Type<CPred<"$_self.isFloat8E8M0FNU()">, "f8E8M0FNU type">,
def F8E8M0FNU : Type<CPred<"::llvm::isa<Float8E8M0FNUType>($_self)">, "f8E8M0FNU type">,
BuildableType<"$_builder.getType<Float8E8M0FNUType>()">;

def AnyComplex : Type<CPred<"::llvm::isa<::mlir::ComplexType>($_self)">,
Expand Down
11 changes: 0 additions & 11 deletions mlir/include/mlir/IR/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,17 +125,6 @@ class Type {
// Convenience predicates. This is only for floating point types,
// derived types should use isa/dyn_cast.
bool isIndex() const;
bool isFloat4E2M1FN() const;
bool isFloat6E2M3FN() const;
bool isFloat6E3M2FN() const;
bool isFloat8E5M2() const;
bool isFloat8E4M3() const;
bool isFloat8E4M3FN() const;
bool isFloat8E5M2FNUZ() const;
bool isFloat8E4M3FNUZ() const;
bool isFloat8E4M3B11FNUZ() const;
bool isFloat8E3M4() const;
bool isFloat8E8M0FNU() const;
bool isBF16() const;
bool isF16() const;
bool isTF32() const;
Expand Down
42 changes: 26 additions & 16 deletions mlir/lib/CAPI/IR/BuiltinTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ MlirTypeID mlirFloat4E2M1FNTypeGetTypeID() {
}

bool mlirTypeIsAFloat4E2M1FN(MlirType type) {
return unwrap(type).isFloat4E2M1FN();
return llvm::isa<Float4E2M1FNType>(unwrap(type));
}

MlirType mlirFloat4E2M1FNTypeGet(MlirContext ctx) {
Expand All @@ -102,7 +102,7 @@ MlirTypeID mlirFloat6E2M3FNTypeGetTypeID() {
}

bool mlirTypeIsAFloat6E2M3FN(MlirType type) {
return unwrap(type).isFloat6E2M3FN();
return llvm::isa<Float6E2M3FNType>(unwrap(type));
}

MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx) {
Expand All @@ -114,7 +114,7 @@ MlirTypeID mlirFloat6E3M2FNTypeGetTypeID() {
}

bool mlirTypeIsAFloat6E3M2FN(MlirType type) {
return unwrap(type).isFloat6E3M2FN();
return llvm::isa<Float6E3M2FNType>(unwrap(type));
}

MlirType mlirFloat6E3M2FNTypeGet(MlirContext ctx) {
Expand All @@ -126,7 +126,7 @@ MlirTypeID mlirFloat8E5M2TypeGetTypeID() {
}

bool mlirTypeIsAFloat8E5M2(MlirType type) {
return unwrap(type).isFloat8E5M2();
return llvm::isa<Float8E5M2Type>(unwrap(type));
}

MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) {
Expand All @@ -138,7 +138,7 @@ MlirTypeID mlirFloat8E4M3TypeGetTypeID() {
}

bool mlirTypeIsAFloat8E4M3(MlirType type) {
return unwrap(type).isFloat8E4M3();
return llvm::isa<Float8E4M3Type>(unwrap(type));
}

MlirType mlirFloat8E4M3TypeGet(MlirContext ctx) {
Expand All @@ -150,7 +150,7 @@ MlirTypeID mlirFloat8E4M3FNTypeGetTypeID() {
}

bool mlirTypeIsAFloat8E4M3FN(MlirType type) {
return unwrap(type).isFloat8E4M3FN();
return llvm::isa<Float8E4M3FNType>(unwrap(type));
}

MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) {
Expand All @@ -162,7 +162,7 @@ MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID() {
}

bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) {
return unwrap(type).isFloat8E5M2FNUZ();
return llvm::isa<Float8E5M2FNUZType>(unwrap(type));
}

MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx) {
Expand All @@ -174,7 +174,7 @@ MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID() {
}

bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) {
return unwrap(type).isFloat8E4M3FNUZ();
return llvm::isa<Float8E4M3FNUZType>(unwrap(type));
}

MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) {
Expand All @@ -186,7 +186,7 @@ MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID() {
}

bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type) {
return unwrap(type).isFloat8E4M3B11FNUZ();
return llvm::isa<Float8E4M3B11FNUZType>(unwrap(type));
}

MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) {
Expand All @@ -198,7 +198,7 @@ MlirTypeID mlirFloat8E3M4TypeGetTypeID() {
}

bool mlirTypeIsAFloat8E3M4(MlirType type) {
return unwrap(type).isFloat8E3M4();
return llvm::isa<Float8E3M4Type>(unwrap(type));
}

MlirType mlirFloat8E3M4TypeGet(MlirContext ctx) {
Expand All @@ -210,7 +210,7 @@ MlirTypeID mlirFloat8E8M0FNUTypeGetTypeID() {
}

bool mlirTypeIsAFloat8E8M0FNU(MlirType type) {
return unwrap(type).isFloat8E8M0FNU();
return llvm::isa<Float8E8M0FNUType>(unwrap(type));
}

MlirType mlirFloat8E8M0FNUTypeGet(MlirContext ctx) {
Expand All @@ -221,15 +221,19 @@ MlirTypeID mlirBFloat16TypeGetTypeID() {
return wrap(BFloat16Type::getTypeID());
}

bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
bool mlirTypeIsABF16(MlirType type) {
return llvm::isa<BFloat16Type>(unwrap(type));
}

MlirType mlirBF16TypeGet(MlirContext ctx) {
return wrap(BFloat16Type::get(unwrap(ctx)));
}

MlirTypeID mlirFloat16TypeGetTypeID() { return wrap(Float16Type::getTypeID()); }

bool mlirTypeIsAF16(MlirType type) { return unwrap(type).isF16(); }
bool mlirTypeIsAF16(MlirType type) {
return llvm::isa<Float16Type>(unwrap(type));
}

MlirType mlirF16TypeGet(MlirContext ctx) {
return wrap(Float16Type::get(unwrap(ctx)));
Expand All @@ -239,23 +243,29 @@ MlirTypeID mlirFloatTF32TypeGetTypeID() {
return wrap(FloatTF32Type::getTypeID());
}

bool mlirTypeIsATF32(MlirType type) { return unwrap(type).isTF32(); }
bool mlirTypeIsATF32(MlirType type) {
return llvm::isa<FloatTF32Type>(unwrap(type));
}

MlirType mlirTF32TypeGet(MlirContext ctx) {
return wrap(FloatTF32Type::get(unwrap(ctx)));
}

MlirTypeID mlirFloat32TypeGetTypeID() { return wrap(Float32Type::getTypeID()); }

bool mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); }
bool mlirTypeIsAF32(MlirType type) {
return llvm::isa<Float32Type>(unwrap(type));
}

MlirType mlirF32TypeGet(MlirContext ctx) {
return wrap(Float32Type::get(unwrap(ctx)));
}

MlirTypeID mlirFloat64TypeGetTypeID() { return wrap(Float64Type::getTypeID()); }

bool mlirTypeIsAF64(MlirType type) { return unwrap(type).isF64(); }
bool mlirTypeIsAF64(MlirType type) {
return llvm::isa<Float64Type>(unwrap(type));
}

MlirType mlirF64TypeGet(MlirContext ctx) {
return wrap(Float64Type::get(unwrap(ctx)));
Expand Down
38 changes: 20 additions & 18 deletions mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -564,38 +564,40 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
return ROCDL::mfma_f64_4x4x4f64::getOperationName();
}

if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() && chipset >= kGfx940) {
if (isa<Float8E5M2FNUZType>(sourceElem) && destElem.isF32() &&
chipset >= kGfx940) {
// Known to be correct because there are no scalar f8 instructions and
// because a length mismatch will have been caught by the verifier.
Type sourceBElem =
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
if (sourceBElem.isFloat8E5M2FNUZ())
if (isa<Float8E5M2FNUZType>(sourceBElem))
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
if (sourceBElem.isFloat8E4M3FNUZ())
if (isa<Float8E4M3FNUZType>(sourceBElem))
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
}
if (m == 32 && n == 32 && k == 16 && b == 1) {
if (sourceBElem.isFloat8E5M2FNUZ())
if (isa<Float8E5M2FNUZType>(sourceBElem))
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
if (sourceBElem.isFloat8E4M3FNUZ())
if (isa<Float8E4M3FNUZType>(sourceBElem))
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
}
}

if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && chipset >= kGfx940) {
if (isa<Float8E4M3FNUZType>(sourceElem) && destElem.isF32() &&
chipset >= kGfx940) {
Type sourceBElem =
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
if (sourceBElem.isFloat8E5M2FNUZ())
if (isa<Float8E5M2FNUZType>(sourceBElem))
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
if (sourceBElem.isFloat8E4M3FNUZ())
if (isa<Float8E4M3FNUZType>(sourceBElem))
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
}
if (m == 32 && n == 32 && k == 16 && b == 1) {
if (sourceBElem.isFloat8E5M2FNUZ())
if (isa<Float8E5M2FNUZType>(sourceBElem))
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
if (sourceBElem.isFloat8E4M3FNUZ())
if (isa<Float8E4M3FNUZType>(sourceBElem))
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
}
}
Expand Down Expand Up @@ -623,9 +625,9 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
if (elemSourceType.isFloat8E4M3FN() && elemDestType.isF32())
if (isa<Float8E4M3FNType>(elemSourceType) && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
if (elemSourceType.isFloat8E5M2() && elemDestType.isF32())
if (isa<Float8E5M2Type>(elemSourceType) && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
return std::nullopt;
}
Expand Down Expand Up @@ -803,10 +805,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
}
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
if (sourceElemType.isFloat8E5M2FNUZ()) {
if (isa<Float8E5M2FNUZType>(sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
wordSel);
} else if (sourceElemType.isFloat8E4M3FNUZ()) {
} else if (isa<Float8E4M3FNUZType>(sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
wordSel);
}
Expand Down Expand Up @@ -838,10 +840,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());

Value result;
if (resultElemType.isFloat8E5M2FNUZ())
if (isa<Float8E5M2FNUZType>(resultElemType))
result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
existing, wordSel);
else if (resultElemType.isFloat8E4M3FNUZ())
else if (isa<Float8E4M3FNUZType>(resultElemType))
result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
existing, wordSel);

Expand Down Expand Up @@ -873,10 +875,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());

Value result;
if (resultElemType.isFloat8E5M2FNUZ())
if (isa<Float8E5M2FNUZType>(resultElemType))
result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
existing, byteSel);
else if (resultElemType.isFloat8E4M3FNUZ())
else if (isa<Float8E4M3FNUZType>(resultElemType))
result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
existing, byteSel);

Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
return failure();
inType = inVecType.getElementType();
}
return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(inType));
}

void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
Expand Down Expand Up @@ -216,7 +216,7 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
if (inType && inType.getWidth() <= 8 && saturateFP8)
// Conversion between 8-bit floats is not supported with truncation enabled.
return failure();
return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(outType));
}

void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
Expand Down
9 changes: 4 additions & 5 deletions mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,11 +299,10 @@ Type LLVMTypeConverter::convertFloatType(FloatType type) const {
return type;

// F4, F6, F8 types are converted to integer types with the same bit width.
if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() ||
type.isFloat4E2M1FN() || type.isFloat6E2M3FN() || type.isFloat6E3M2FN() ||
type.isFloat8E8M0FNU())
if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
Float8E8M0FNUType>(type))
return IntegerType::get(&getContext(), type.getWidth());

// Other floating-point types: A custom type conversion rule must be
Expand Down
8 changes: 4 additions & 4 deletions mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1254,8 +1254,8 @@ struct NVGPUWarpgroupMmaOpLowering
wgmmaK = 8;
} else if (inputElemType.isF16() || inputElemType.isBF16()) {
wgmmaK = 16;
} else if (inputElemType.isFloat8E4M3FN() ||
inputElemType.isFloat8E5M2() || inputElemType.isInteger(16)) {
} else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) ||
inputElemType.isInteger(16)) {
wgmmaK = 32;
} else if (inputElemType.isInteger(1)) {
wgmmaK = 256;
Expand All @@ -1276,9 +1276,9 @@ struct NVGPUWarpgroupMmaOpLowering
return NVVM::WGMMATypes::f16;
if (elemType.isBF16())
return NVVM::WGMMATypes::bf16;
if (elemType.isFloat8E4M3FN())
if (isa<Float8E4M3FNType>(elemType))
return NVVM::WGMMATypes::e4m3;
if (elemType.isFloat8E5M2())
if (isa<Float8E5M2Type>(elemType))
return NVVM::WGMMATypes::e5m2;
if (elemType.isInteger(1))
return NVVM::WGMMATypes::b1;
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,14 +272,14 @@ LogicalResult MFMAOp::verify() {
}

Type sourceBType = getSourceB().getType();
if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ()) {
if (isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(sourceElem)) {
int64_t sourceBLen = 1;
Type sourceBElem = sourceBType;
if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
sourceBLen = sourceBVector.getNumElements();
sourceBElem = sourceBVector.getElementType();
}
if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ())
if (!isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(sourceBElem))
return emitOpError("expected both source operands to have f8 elements");
if (sourceLen != sourceBLen)
return emitOpError(
Expand Down
Loading
Loading