[AMD][ROCDL] Add packed conversions fp8/bf8->bf16 and fp8/bf8->fp32 in ROCDL dialect #131850
…/bf8->fp32
Add packed conversions fp8/bf8->bf16 in gfx950 and fp8/bf8->fp32 in gfx942.
Update amdgpu.ext_packed_fp8 lowering to use ROCDL CvtPkF32Fp8Op.
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-backend-amdgpu Author: Yi Qian (yiqian1) Changes
Patch is 48.85 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/131850.diff 11 Files Affected:
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 3acc383923ca8..3ed6e84d19044 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -85,11 +85,12 @@ def AMDGPU_ExtPackedFp8Op :
AMDGPU_Op<"ext_packed_fp8", [Pure]>,
Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN,
VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>:$source,
- ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
- Results<(outs F32:$res)> {
- let summary = "Extend one of a vector of packed fp8 values to a float";
+ ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$wordIndex)>,
+ Results<(outs FixedVectorOfLengthAndType<[2], [F32]>:$res)> {
+ let summary = "Extend a vector of packed fp8 values to two floats";
+
let description = [{
- Extend the value `source[index]` to a 32-bit float and return it.
+ Extend the two 8-bit floats in `source[wordIndex]` to two 32-bit floats and return them.
This rather unusual signature arises from the fact that AMD GPUs cannot
easily work with sub 32-bit quantities, so the compiler intrinsics for
@@ -97,11 +98,11 @@ def AMDGPU_ExtPackedFp8Op :
this operation) take packed vectors of 4 such floats.
If the passed-in vector has fewer than four elements, or the input is scalar,
- the remaining values in the <4 x i8> will be filled with with
+ the remaining values in the <4 x i8> will be filled with
undefined values as needed.
}];
let assemblyFormat = [{
- attr-dict $source `[` $index `]` `:` type($source) `to` type($res)
+ attr-dict $source `[` $wordIndex `]` `:` type($source) `to` type($res)
}];
}
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index f194e70ee275b..9a433202e3149 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -681,26 +681,26 @@ def ROCDL_CvtPkRtz:
}];
}
-def ROCDL_CvtScaleF32PkFp8F16 :
+def ROCDL_CvtScaleF32PkFp8F16Op :
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f16", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert f16 to packed fp8";
let description = [{
- Scale `src` by the exponent in `scale` then convert to packed fp8.
- Store the result in low/high word based on $wordSel, preserving the other word.
+ Scale `src` by the exponent in `scale`, then convert to packed fp8.
+ Store the result in low/high word of `old` based on $wordSel, preserving the other word.
}];
let assemblyFormat = [{
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
}];
}
-def ROCDL_CvtScaleF32PkFp8Bf16 :
+def ROCDL_CvtScaleF32PkFp8Bf16Op :
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.bf16", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert packed bf16 to packed fp8";
let description = [{
- Scale `src` by the exponent in `scale` then convert to packed fp8.
- Store the result in low/high word based on $wordSel, preserving the other word.
+ Scale `src` by the exponent in `scale`, then convert to packed fp8.
+ Store the result in low/high word of `old` based on $wordSel, preserving the other word.
}];
let assemblyFormat = [{
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
@@ -708,13 +708,13 @@ def ROCDL_CvtScaleF32PkFp8Bf16 :
}
-def ROCDL_CvtScaleF32PkBf8F16 :
+def ROCDL_CvtScaleF32PkBf8F16Op :
ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f16", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert f16 to packed bf8";
let description = [{
- Scale `src` by the exponent in `scale` then convert to packed bf8.
- Store the result in low/high word based on $wordSel, preserving the other word.
+ Scale `src` by the exponent in `scale`, then convert to packed bf8.
+ Store the result in low/high word of `old` based on $wordSel, preserving the other word.
}];
let assemblyFormat = [{
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
@@ -722,26 +722,26 @@ def ROCDL_CvtScaleF32PkBf8F16 :
}
-def ROCDL_CvtScaleF32PkBf8Bf16 :
+def ROCDL_CvtScaleF32PkBf8Bf16Op :
ROCDL_IntrOp<"cvt.scalef32.pk.bf8.bf16", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert bf16 to packed bf8";
let description = [{
- Scale `src` by the exponent in `scale` then convert to packed bf8.
- Store the result in low/high word based on $wordSel, preserving the other word.
+ Scale `src` by the exponent in `scale`, then convert to packed bf8.
+ Store the result in low/high word of `old` based on $wordSel, preserving the other word.
}];
let assemblyFormat = [{
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
}];
}
-def ROCDL_CvtScaleF32SrFp8F16 :
+def ROCDL_CvtScaleF32SrFp8F16Op :
ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f16", [], [], [Pure], 1>,
Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert f16 to packed fp8 using stochastic rounding";
let description = [{
- Scale `src` by the exponent in `scale` then convert to packed p8 with stochastic rounding
- using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
+ Scale `src` by the exponent in `scale`, then convert to packed fp8 with stochastic rounding
+ using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
}];
let assemblyFormat = [{
@@ -749,13 +749,13 @@ def ROCDL_CvtScaleF32SrFp8F16 :
}];
}
-def ROCDL_CvtScaleF32SrBf8F16 :
+def ROCDL_CvtScaleF32SrBf8F16Op :
ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f16", [], [], [Pure], 1>,
Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert f16 to packed bf8 using stochastic rounding";
let description = [{
- Scale `src` by the exponent in `scale` then convert to packed bf8 with stochastic rounding
- using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
+ Scale `src` by the exponent in `scale`, then convert to packed bf8 with stochastic rounding
+ using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
}];
let assemblyFormat = [{
@@ -763,13 +763,13 @@ def ROCDL_CvtScaleF32SrBf8F16 :
}];
}
-def ROCDL_CvtScaleF32SrFp8Bf16 :
+def ROCDL_CvtScaleF32SrFp8Bf16Op :
ROCDL_IntrOp<"cvt.scalef32.sr.fp8.bf16", [], [], [Pure], 1>,
Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert packed bf16 to packed fp8 using stochastic rounding";
let description = [{
- Scale `src` by the exponent in `scale` then convert to packed fp8 with stochastic rounding
- using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
+ Scale `src` by the exponent in `scale`, then convert to packed fp8 with stochastic rounding
+ using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
}];
let assemblyFormat = [{
@@ -777,13 +777,13 @@ def ROCDL_CvtScaleF32SrFp8Bf16 :
}];
}
-def ROCDL_CvtScaleF32SrBf8Bf16:
+def ROCDL_CvtScaleF32SrBf8Bf16Op :
ROCDL_IntrOp<"cvt.scalef32.sr.bf8.bf16", [], [], [Pure], 1>,
Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert bf16 to packed fp8 using stochastic rounding";
let description = [{
- Scale `src` by the exponent in `scale` then convert to packed p8 with stochastic rounding
- using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
+ Scale `src` by the exponent in `scale`, then convert to packed bf8 with stochastic rounding
+ using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
}];
let assemblyFormat = [{
@@ -791,48 +791,74 @@ def ROCDL_CvtScaleF32SrBf8Bf16:
}];
}
-def ROCDL_CvtScaleF32PkF16Fp8 :
+def ROCDL_CvtScaleF32PkF16Fp8Op :
ROCDL_IntrOp<"cvt.scalef32.pk.f16.fp8", [], [], [Pure], 1>,
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
- let summary = "Scale and convert fp8 to packed f16";
- let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
- then convert to packed f16.
+ let summary = "Convert fp8 to packed f16 and scale";
+ let description = [{ Convert `src` based on $wordSel to packed f16, then scale
+ the packed values by the exponent in `scale`.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
}];
}
-def ROCDL_CvtScaleF32PkF16Bf8 :
+def ROCDL_CvtScaleF32PkF16Bf8Op :
ROCDL_IntrOp<"cvt.scalef32.pk.f16.bf8", [], [], [Pure], 1>,
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
- let summary = "Scale and convert bf8 to packed f16";
- let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
- then convert to packed f16.
+ let summary = "Convert bf8 to packed f16 and scale";
+ let description = [{ Convert `src` based on $wordSel to packed f16, then scale
+ the packed values by the exponent in `scale`.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
}];
}
-def ROCDL_CvtScaleF16Fp8 :
+def ROCDL_CvtScaleF32PkBf16Fp8Op :
+ ROCDL_IntrOp<"cvt.scalef32.pk.bf16.fp8", [], [], [Pure], 1>,
+ Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
+ let summary = "Convert fp8 to packed bf16 and scale";
+ let description = [{ Convert `src` based on $wordSel to packed bf16, then scale
+ the packed values by the exponent in `scale`.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
+ }];
+}
+
+def ROCDL_CvtScaleF32PkBf16Bf8Op :
+ ROCDL_IntrOp<"cvt.scalef32.pk.bf16.bf8", [], [], [Pure], 1>,
+ Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
+ let summary = "Convert bf8 to packed bf16 and scale";
+ let description = [{ Convert `src` based on $wordSel to packed bf16, then scale
+ the packed values by the exponent in `scale`.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
+ }];
+}
+
+def ROCDL_CvtScaleF16Fp8Op :
ROCDL_IntrOp<"cvt.scalef32.f16.fp8", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> {
let summary = "Scale and convert fp8 to f16";
- let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
- then convert to f16 store into the `byteSel`th byte of `old`, preserving the others.
+ let description = [{ Convert `src` based on $wordSel to f16, then scale the value
+ by the exponent in `scale`. Store the result into the `byteSel`th byte of `old`,
+ preserving the others.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
}];
}
-def ROCDL_CvtScaleF16Bf8 :
+def ROCDL_CvtScaleF16Bf8Op :
ROCDL_IntrOp<"cvt.scalef32.f16.bf8", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> {
let summary = "Scale and convert fp8 to f16";
- let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
- then convert to f16 store into the `byteSel`th byte of `old`, preserving the others.
+ let description = [{ Convert `src` based on $wordSel to f16, then scale the value
+ by the exponent in `scale`. Store the result into the `byteSel`th byte of `old`,
+ preserving the others.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
@@ -842,25 +868,25 @@ def ROCDL_CvtScaleF16Bf8 :
//===---------------------------------------------------------------------===//
// 32-bit float intrinsics
//===---------------------------------------------------------------------===//
-def ROCDL_CvtScale32PkF32Fp8 :
+def ROCDL_CvtScaleF32PkF32Fp8Op :
ROCDL_IntrOp<"cvt.scalef32.pk.f32.fp8", [], [], [Pure], 1>,
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert packed fp8 to packed f32";
let description = [{
- Scale `src` by the exponent in `scale` then convert to packed fp32.
- Store the result in low/high word based on $wordSel, preserving the other word.
+ Convert `src` based on $wordSel to packed fp32, then scale the packed values by
+ the exponent in `scale`. Store the result in a vector.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
}];
}
-def ROCDL_CvtScale32PkF32Bf8 :
+def ROCDL_CvtScaleF32PkF32Bf8Op :
ROCDL_IntrOp<"cvt.scalef32.pk.f32.bf8", [], [], [Pure], 1>,
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert packed bf8 to packed f32";
let description = [{
- Scale `src` by the exponent in `scale` then convert to packed fp32.
- Store the result in low/high word based on $wordSel, preserving the other word.
+ Convert `src` based on $wordSel to packed fp32, then scale the packed values by
+ the exponent in `scale`. Store the result in a vector.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
@@ -869,7 +895,7 @@ def ROCDL_CvtScale32PkF32Bf8 :
//===---------------------------------------------------------------------===//
// 8-bit float scale intrinsics
//===---------------------------------------------------------------------===//
-def ROCDL_CvtScaleF32PkFp8F32:
+def ROCDL_CvtScaleF32PkFp8F32Op :
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f32", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32:$scale, I1:$wordSel)> {
let summary = "Scale and convert two f32's to packed fp8";
@@ -882,7 +908,7 @@ def ROCDL_CvtScaleF32PkFp8F32:
}];
}
-def ROCDL_CvtScaleF32PkBf8F32:
+def ROCDL_CvtScaleF32PkBf8F32Op :
ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f32", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert two f32's to packed bf8";
@@ -895,7 +921,7 @@ def ROCDL_CvtScaleF32PkBf8F32:
}];
}
-def ROCDL_CvtScaleF32SrFp8F32:
+def ROCDL_CvtScaleF32SrFp8F32Op :
ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f32", [], [], [Pure], 1>,
Arguments<(ins I32:$old, F32:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert f32 to fp8 using stochastic rounding";
@@ -909,7 +935,7 @@ def ROCDL_CvtScaleF32SrFp8F32:
}
-def ROCDL_CvtScaleF32SrBf8F32:
+def ROCDL_CvtScaleF32SrBf8F32Op :
ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f32", [], [], [Pure], 1>,
Arguments<(ins I32:$old, F32:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert f32 to bf8 using stochastic rounding";
@@ -978,6 +1004,29 @@ def ROCDL_CvtScaleF32Fp8Op :
}];
}
+def ROCDL_CvtPkF32Fp8Op :
+ ROCDL_IntrOp<"cvt.pk.f32.fp8", [], [], [Pure], 1>,
+ Arguments<(ins I32:$src, I1:$wordSel)> {
+ let summary = "Convert packed fp8 to packed f32";
+ let description = [{
+ Convert `src` based on $wordSel to packed fp32.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `[` $wordSel `]` `:` type($res)
+ }];
+}
+
+def ROCDL_CvtPkF32Bf8Op :
+ ROCDL_IntrOp<"cvt.pk.f32.bf8", [], [], [Pure], 1>,
+ Arguments<(ins I32:$src, I1:$wordSel)> {
+ let summary = "Convert packed bf8 to packed f32";
+ let description = [{
+ Convert `src` based on $wordSel to packed fp32.
+ }];
+ let assemblyFormat = [{
+ attr-dict $src `[` $wordSel `]` `:` type($res)
+ }];
+}
def ROCDL_CvtPkBf8F32Op :
ROCDL_IntrOp<"cvt.pk.bf8.f32", [], [], [Pure], 1>,
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 949424db7c4d6..768d21384412d 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -977,13 +977,13 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
source = longVec;
}
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
- Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
+ Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
- rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
- wordSel);
+ rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
+ wordSel);
} else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
- rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
- wordSel);
+ rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
+ wordSel);
}
return success();
}
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 27be54728c1a1..f9b685d1e90f6 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -83,14 +83,15 @@ static bool isSupportedF8(Type elementType, Chipset chipset) {
return false;
}
-static Value castF32To(Type elementType, Value f32, Location loc,
+static Value castF32To(Type desType, Value f32, Location loc,
PatternRewriter &rewriter) {
+ Type elementType = getElementTypeOrSelf(desType);
if (elementType.isF32())
return f32;
if (elementType.getIntOrFloatBitWidth() < 32)
- return rewriter.create<arith::TruncFOp>(loc, elementType, f32);
+ return rewriter.create<arith::TruncFOp>(loc, desType, f32);
if (elementType.getIntOrFloatBitWidth() > 32)
- return rewriter.create<arith::ExtFOp>(loc, elementType, f32);
+ return rewriter.create<arith::ExtFOp>(loc, desType, f32);
llvm_unreachable("The only 32-bit float type is f32");
}
@@ -110,10 +111,12 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
Location loc = op.getLoc();
Value in = op.getIn();
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
+ VectorType extResType = VectorType::get(2, rewriter.getF32Type());
if (!inVecType) {
- Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
- loc, rewriter.getF32Type(), in, 0);
- Value result = castF32To(outElemType, asFloat, loc, rewriter);
+ Value asFloats =
+ rewriter.create<amdgpu::ExtPackedFp8Op>(loc, extResType, in, 0);
+ Value resFloat = rewriter.create<vector::ExtractOp>(loc, asFloats, 0);
+ Value result = castF32To(outElemType, resFloat, loc, rewriter);
rewriter.replaceOp(op, result);
return success();
}
@@ -150,11 +153,18 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
int64_t elemsThisOp = std::min(numElements, i + 4) - i;
Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
loc, in, i, elemsThisOp, 1);
- for (int64_t j = 0; j < elemsThisOp; ++j) {
- Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
- loc, rewriter.getF32Type(), inSlice, j);
- Value asType = castF32To(outElemType, asFloat, loc, rewriter);
- result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
+ for (int64_t j = 0; j < elemsThisOp; j += 2) {
+ Value asFloats = rewriter.create<amdgpu::ExtPackedFp8Op>(loc, extResType,
+ inSlice, j / 2);
+ Type desType = VectorType::get(2, outElemType);
+ Value asType = castF32To(desType, asFloats, loc, rewriter);
+ if (i + j + 1 < numElements)
+ ...
[truncated]
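To make the new two-element result of `amdgpu.ext_packed_fp8` concrete, the following host-side reference model may help. It is only a sketch under stated assumptions, not the lowering itself: the names `decodeE4M3FN` and `extPackedFp8` are hypothetical, it models only the OCP `F8E4M3FN` encoding (the FNUZ and E5M2 variants differ), and it assumes that word 0 is the low 16 bits of the packed `i32` and that the low byte of a word is element 0.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

// Decode one OCP E4M3FN byte: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits.
// Assumption: only F8E4M3FN is modeled; the FNUZ variants use another bias and NaN code.
inline float decodeE4M3FN(uint8_t bits) {
  const int sign = bits >> 7;
  const int exponent = (bits >> 3) & 0xF;
  const int mantissa = bits & 0x7;
  float magnitude;
  if (exponent == 0xF && mantissa == 0x7)
    magnitude = std::numeric_limits<float>::quiet_NaN(); // the only NaN pattern, no infinities
  else if (exponent == 0)
    magnitude = std::ldexp(static_cast<float>(mantissa), -9); // subnormal: (m / 8) * 2^-6
  else
    magnitude = std::ldexp(static_cast<float>(8 + mantissa), exponent - 10); // (1 + m / 8) * 2^(e - 7)
  return sign ? -magnitude : magnitude;
}

// Hypothetical reference model of the packed extension: pick one 16-bit word of
// the packed i32 and extend both of its bytes to f32.
// Assumption: wordIndex 0 selects the low word, and the low byte is element 0.
inline std::array<float, 2> extPackedFp8(uint32_t packed, unsigned wordIndex) {
  assert(wordIndex <= 1 && "wordIndex is constrained to 0 or 1");
  const uint32_t word = (packed >> (16 * wordIndex)) & 0xFFFFu;
  return {decodeE4M3FN(static_cast<uint8_t>(word & 0xFF)),
          decodeE4M3FN(static_cast<uint8_t>(word >> 8))};
}
```

Under these assumptions, `extPackedFp8(0x3C40C038u, 0)` yields `{1.0, -2.0}` and `extPackedFp8(0x3C40C038u, 1)` yields `{2.0, 1.5}`.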
OK, I'm just going to check so it's out there: what's the portability on these?
Second, I object to unconditionally using the packed instructions; it wastes a register.
This code should use the packed instructions where there is actually a vector, and fall back to the scalar ones for the odd final element and for the scalar case.
To rephrase the portability question: what prevents us from generating the ->bf16 instructions on gfx942?

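The fallback requested in this comment could be planned roughly as follows. This is a sketch of the idea only, not code from the PR: `ExtStep` and `planFp8Extension` are hypothetical names, and the function merely lists which conversions would be emitted per four-element slice (packed for each full pair, scalar for an odd trailing element or a lone scalar) instead of building any MLIR ops.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical record of one conversion the lowering would emit.
struct ExtStep {
  bool packed;        // true: two-element packed conversion; false: scalar conversion
  int64_t sliceStart; // index of the first source element of the 4-element slice
  int64_t index;      // word index (packed) or byte index (scalar) inside the slice
};

// Walk the source in slices of up to four fp8 values, mirroring the loop in the
// patch, but use a packed conversion only when both elements of a pair exist.
inline std::vector<ExtStep> planFp8Extension(int64_t numElements) {
  assert(numElements >= 0 && "element count must be non-negative");
  std::vector<ExtStep> steps;
  for (int64_t i = 0; i < numElements; i += 4) {
    const int64_t elemsThisOp = std::min(numElements, i + 4) - i;
    int64_t j = 0;
    for (; j + 1 < elemsThisOp; j += 2)
      steps.push_back({true, i, j / 2}); // full pair: packed conversion of word j / 2
    if (j < elemsThisOp)
      steps.push_back({false, i, j});    // odd tail: scalar conversion of byte j
  }
  return steps;
}
```

For seven elements this plans three packed conversions followed by one scalar conversion of byte 2 in the second slice; for a single element it plans one scalar conversion and no packed ones.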
There are no ->bf16 instructions on gfx942. They're only available on gfx950.

So where in the code do we prevent those instructions from being emitted on gfx942?

Hmm, I don't think we prevent it in the code.

You'll just get an "invalid ISA instruction for that particular arch" error.

@yiqian1 Please add guards to prevent using the bf16 instructions on gfx942. That is,
on gfx942 needs to go via

An arith.extf to any non-f32 type will go via f32, as shown in these test cases
and
Note that fp8/bf8->bf16 conversions are currently not used in lowering amdgpu.ext_packed_fp8. We can add them for gfx950 in the future if we want. In this PR, I just updated amdgpu.ext_packed_fp8 to use the packed fp8/bf8->fp32 conversions, which are available on both gfx942 and gfx950.
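The kind of guard discussed in this thread might look like the sketch below. It is illustrative only: `Fp8ExtTarget` and `hasPackedFp8Ext` are hypothetical names, the chipset is passed as a plain string rather than through MLIR's chipset type, and the function encodes only what the thread states (packed fp8/bf8->fp32 on gfx942 and gfx950, packed fp8/bf8->bf16 on gfx950 only). Every other chipset is conservatively reported as unsupported because the thread says nothing about it.

```cpp
#include <string_view>

// Destination element type of the packed fp8/bf8 extension.
enum class Fp8ExtTarget { F32, BF16 };

// Hypothetical guard: true only for the chipset/target combinations that the
// review thread explicitly confirms. Anything else is treated as unsupported.
constexpr bool hasPackedFp8Ext(std::string_view chipset, Fp8ExtTarget target) {
  const bool isGfx942 = chipset == "gfx942";
  const bool isGfx950 = chipset == "gfx950";
  switch (target) {
  case Fp8ExtTarget::F32:
    return isGfx942 || isGfx950; // stated as available on both
  case Fp8ExtTarget::BF16:
    return isGfx950;             // stated as gfx950 only
  }
  return false;
}
```

A lowering that wanted direct ->bf16 conversions would consult such a check first and otherwise extend to f32 and truncate, which matches the "go via f32" behaviour described above.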
Ah, then the PR title confused me: I thought you were plumbing those through the lowering. So, yeah, I'll take another look and then lift my hold, on the assumption that the bf16 conversions will be follow-up work. (You might want to adjust the PR title, though, to make it clearer that you're not supporting packed ext to bf16 with this change.)
LGTM
Pulls in llvm/llvm-project#131850 for type conversion ops.
Pulls in these changes:
- Fp8/BF8 type conversion ops (llvm/llvm-project#131850)
- Accept Triples in createTargetMachine() (llvm/llvm-project#130940)

Co-authored-by: Lei Zhang <antiagainst@gmail.com>
- Add calls to `populateVectorGatherLoweringPatterns` and `populateVectorGatherToConditionalLoadPatterns` as needed to handle upstream llvm/llvm-project#132206
- Update ROCDL tests to account for llvm/llvm-project#131850
- Disable `tosa.reduce_prod` and `tosa.clamp` tests on integers temporarily since they're failing Tosa verification - see #20422

Explicitly set `amdhsa_code_object_version` when constructing ROCm LLVM modules, explicitly pass `-mcode-object-version=N` to microkernel compiles, and explicitly set the `__oclc_ABI_version` global, so that we don't get hit with errors about being unable to run our binaries from the compiler default code object version getting bumped further than what our users or CI support. The default code object version in LLVM as of this commit is 6, we pin to 5. The followup work to update to v6 is #20423.