[RISCV] Treat bf16->f32 as separate ExtKind in combineOp_VLToVWOp_VL. #144653
Conversation
This allows us to better track the narrow type we need and fixes miscompiles when f16->f32 and bf16->f32 extends are mixed. Fixes llvm#144651. Still need to add tests, but it's late and I need sleep.
@llvm/pr-subscribers-backend-risc-v

Author: Craig Topper (topperc)

Full diff: https://github.com/llvm/llvm-project/pull/144653.diff

2 Files Affected:
- llvm/lib/Target/RISCV/RISCVISelLowering.cpp
- llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmaccbf16.ll
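For context, the miscompile in #144651 occurs when an f16->f32 extend and a bf16->f32 extend feed the same FMA: with a single FPExt kind, the combiner cannot tell the two narrow source types apart. A minimal IR reproducer, lifted from the new `mix` test added in this patch:

; Both operands reach the combine as FP_EXTEND_VL; before this patch they
; shared ExtKind::FPExt, so the fold could assume one narrow type for both.
define <4 x float> @mix(<4 x float> %rd, <4 x half> %a, <4 x bfloat> %b) {
  %a_ext = fpext <4 x half> %a to <4 x float>
  %b_ext = fpext <4 x bfloat> %b to <4 x float>
  %fma = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %a_ext, <4 x float> %b_ext, <4 x float> %rd)
  ret <4 x float> %fma
}

With the separate BF16Ext bit, mixed operands no longer satisfy canFoldToVWWithSameExtension, and the test's CHECK lines verify the safe lowering: explicit widening converts (vfwcvt.f.f.v / vfwcvtbf16.f.f.v) followed by vfmacc.vv.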
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index e670567bd1844..f7d447e03af94 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16309,7 +16309,12 @@ namespace {
// apply a combine.
struct CombineResult;
-enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 };
+enum ExtKind : uint8_t {
+ ZExt = 1 << 0,
+ SExt = 1 << 1,
+ FPExt = 1 << 2,
+ BF16Ext = 1 << 3
+};
/// Helper class for folding sign/zero extensions.
/// In particular, this class is used for the following combines:
/// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w
@@ -16344,8 +16349,10 @@ struct NodeExtensionHelper {
/// instance, a splat constant (e.g., 3), would support being both sign and
/// zero extended.
bool SupportsSExt;
- /// Records if this operand is like being floating-Point extended.
+ /// Records if this operand is like being floating point extended.
bool SupportsFPExt;
+ /// Records if this operand is extended from bf16.
+ bool SupportsBF16Ext;
/// This boolean captures whether we care if this operand would still be
/// around after the folding happens.
bool EnforceOneUse;
@@ -16381,6 +16388,7 @@ struct NodeExtensionHelper {
case ExtKind::ZExt:
return RISCVISD::VZEXT_VL;
case ExtKind::FPExt:
+ case ExtKind::BF16Ext:
return RISCVISD::FP_EXTEND_VL;
}
llvm_unreachable("Unknown ExtKind enum");
@@ -16402,13 +16410,6 @@ struct NodeExtensionHelper {
if (Source.getValueType() == NarrowVT)
return Source;
- // vfmadd_vl -> vfwmadd_vl can take bf16 operands
- if (Source.getValueType().getVectorElementType() == MVT::bf16) {
- assert(Root->getSimpleValueType(0).getVectorElementType() == MVT::f32 &&
- Root->getOpcode() == RISCVISD::VFMADD_VL);
- return Source;
- }
-
unsigned ExtOpc = getExtOpc(*SupportsExt);
// If we need an extension, we should be changing the type.
@@ -16451,7 +16452,8 @@ struct NodeExtensionHelper {
// Determine the narrow size.
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
- MVT EltVT = SupportsExt == ExtKind::FPExt
+ MVT EltVT = SupportsExt == ExtKind::BF16Ext ? MVT::bf16
+ : SupportsExt == ExtKind::FPExt
? MVT::getFloatingPointVT(NarrowSize)
: MVT::getIntegerVT(NarrowSize);
@@ -16628,17 +16630,17 @@ struct NodeExtensionHelper {
EnforceOneUse = false;
}
- bool isSupportedFPExtend(SDNode *Root, MVT NarrowEltVT,
- const RISCVSubtarget &Subtarget) {
+ bool isSupportedFPExtend(MVT NarrowEltVT, const RISCVSubtarget &Subtarget) {
+ if (NarrowEltVT == MVT::f32)
+ return true;
// Any f16 extension will need zvfh
- if (NarrowEltVT == MVT::f16 && !Subtarget.hasVInstructionsF16())
- return false;
- // The only bf16 extension we can do is vfmadd_vl -> vfwmadd_vl with
- // zvfbfwma
- if (NarrowEltVT == MVT::bf16 && (!Subtarget.hasStdExtZvfbfwma() ||
- Root->getOpcode() != RISCVISD::VFMADD_VL))
- return false;
- return true;
+ if (NarrowEltVT == MVT::f16 && Subtarget.hasVInstructionsF16())
+ return true;
+ return false;
+ }
+
+ bool isSupportedBF16Extend(MVT NarrowEltVT, const RISCVSubtarget &Subtarget) {
+ return NarrowEltVT == MVT::bf16 && Subtarget.hasStdExtZvfbfwma();
}
/// Helper method to set the various fields of this struct based on the
@@ -16648,6 +16650,7 @@ struct NodeExtensionHelper {
SupportsZExt = false;
SupportsSExt = false;
SupportsFPExt = false;
+ SupportsBF16Ext = false;
EnforceOneUse = true;
unsigned Opc = OrigOperand.getOpcode();
// For the nodes we handle below, we end up using their inputs directly: see
@@ -16679,9 +16682,11 @@ struct NodeExtensionHelper {
case RISCVISD::FP_EXTEND_VL: {
MVT NarrowEltVT =
OrigOperand.getOperand(0).getSimpleValueType().getVectorElementType();
- if (!isSupportedFPExtend(Root, NarrowEltVT, Subtarget))
- break;
- SupportsFPExt = true;
+ if (isSupportedFPExtend(NarrowEltVT, Subtarget))
+ SupportsFPExt = true;
+ if (isSupportedBF16Extend(NarrowEltVT, Subtarget))
+ SupportsBF16Ext = true;
+
break;
}
case ISD::SPLAT_VECTOR:
@@ -16698,16 +16703,16 @@ struct NodeExtensionHelper {
if (Op.getOpcode() != ISD::FP_EXTEND)
break;
- if (!isSupportedFPExtend(Root, Op.getOperand(0).getSimpleValueType(),
- Subtarget))
- break;
-
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
unsigned ScalarBits = Op.getOperand(0).getValueSizeInBits();
if (NarrowSize != ScalarBits)
break;
- SupportsFPExt = true;
+ if (isSupportedFPExtend(Op.getOperand(0).getSimpleValueType(), Subtarget))
+ SupportsFPExt = true;
+ if (isSupportedBF16Extend(Op.getOperand(0).getSimpleValueType(),
+ Subtarget))
+ SupportsBF16Ext = true;
break;
}
default:
@@ -16940,6 +16945,11 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
/*RHSExt=*/{ExtKind::FPExt});
+ if ((AllowExtMask & ExtKind::BF16Ext) && LHS.SupportsBF16Ext &&
+ RHS.SupportsBF16Ext && Root->getOpcode() == RISCVISD::VFMADD_VL)
+ return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
+ Root, LHS, /*LHSExt=*/{ExtKind::BF16Ext}, RHS,
+ /*RHSExt=*/{ExtKind::BF16Ext});
return std::nullopt;
}
@@ -16953,9 +16963,10 @@ static std::optional<CombineResult>
canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
- return canFoldToVWWithSameExtensionImpl(
- Root, LHS, RHS, ExtKind::ZExt | ExtKind::SExt | ExtKind::FPExt, DAG,
- Subtarget);
+ return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS,
+ ExtKind::ZExt | ExtKind::SExt |
+ ExtKind::FPExt | ExtKind::BF16Ext,
+ DAG, Subtarget);
}
/// Check if \p Root follows a pattern Root(LHS, ext(RHS))
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmaccbf16.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmaccbf16.ll
index 1639f21f243d8..aec970adff51e 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmaccbf16.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmaccbf16.ll
@@ -1,8 +1,8 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc < %s -mtriple=riscv32 -mattr=+v,+zvfbfwma -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFWMA
-; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfbfwma -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFWMA
-; RUN: llc < %s -mtriple=riscv32 -mattr=+v,+zvfbfmin -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFMIN
-; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfbfmin -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFMIN
+; RUN: llc < %s -mtriple=riscv32 -mattr=+v,+zvfh,+zvfbfwma -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFWMA
+; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfh,+zvfbfwma -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFWMA
+; RUN: llc < %s -mtriple=riscv32 -mattr=+v,+zvfh,+zvfbfmin -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFMIN
+; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfh,+zvfbfmin -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFMIN
define <1 x float> @vfwmaccbf16_vv_v1f32(<1 x float> %a, <1 x bfloat> %b, <1 x bfloat> %c) {
; ZVFBFWMA-LABEL: vfwmaccbf16_vv_v1f32:
@@ -295,3 +295,53 @@ define <32 x float> @vfwmaccbf32_vf_v32f32(<32 x float> %a, bfloat %b, <32 x bfl
%res = call <32 x float> @llvm.fma.v32f32(<32 x float> %b.ext, <32 x float> %c.ext, <32 x float> %a)
ret <32 x float> %res
}
+
+define <4 x float> @vfwmaccbf16_vf_v4f32_scalar_extend(<4 x float> %rd, bfloat %a, <4 x bfloat> %b) local_unnamed_addr #0 {
+; ZVFBFWMA-LABEL: vfwmaccbf16_vf_v4f32_scalar_extend:
+; ZVFBFWMA: # %bb.0:
+; ZVFBFWMA-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
+; ZVFBFWMA-NEXT: vfwmaccbf16.vf v8, fa0, v9
+; ZVFBFWMA-NEXT: ret
+;
+; ZVFBFMIN-LABEL: vfwmaccbf16_vf_v4f32_scalar_extend:
+; ZVFBFMIN: # %bb.0:
+; ZVFBFMIN-NEXT: fmv.x.w a0, fa0
+; ZVFBFMIN-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
+; ZVFBFMIN-NEXT: vfwcvtbf16.f.f.v v10, v9
+; ZVFBFMIN-NEXT: slli a0, a0, 16
+; ZVFBFMIN-NEXT: fmv.w.x fa5, a0
+; ZVFBFMIN-NEXT: vsetvli zero, zero, e32, m1, ta, ma
+; ZVFBFMIN-NEXT: vfmacc.vf v8, fa5, v10
+; ZVFBFMIN-NEXT: ret
+ %b_ext = fpext <4 x bfloat> %b to <4 x float>
+ %a_extend = fpext bfloat %a to float
+ %a_insert = insertelement <4 x float> poison, float %a_extend, i64 0
+ %a_shuffle = shufflevector <4 x float> %a_insert, <4 x float> poison, <4 x i32> zeroinitializer
+ %fma = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %a_shuffle, <4 x float> %b_ext, <4 x float> %rd)
+ ret <4 x float> %fma
+}
+
+; Negative test with a mix of bfloat and half fpext.
+define <4 x float> @mix(<4 x float> %rd, <4 x half> %a, <4 x bfloat> %b) {
+; ZVFBFWMA-LABEL: mix:
+; ZVFBFWMA: # %bb.0:
+; ZVFBFWMA-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
+; ZVFBFWMA-NEXT: vfwcvt.f.f.v v11, v9
+; ZVFBFWMA-NEXT: vfwcvtbf16.f.f.v v9, v10
+; ZVFBFWMA-NEXT: vsetvli zero, zero, e32, m1, ta, ma
+; ZVFBFWMA-NEXT: vfmacc.vv v8, v11, v9
+; ZVFBFWMA-NEXT: ret
+;
+; ZVFBFMIN-LABEL: mix:
+; ZVFBFMIN: # %bb.0:
+; ZVFBFMIN-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
+; ZVFBFMIN-NEXT: vfwcvt.f.f.v v11, v9
+; ZVFBFMIN-NEXT: vfwcvtbf16.f.f.v v9, v10
+; ZVFBFMIN-NEXT: vsetvli zero, zero, e32, m1, ta, ma
+; ZVFBFMIN-NEXT: vfmacc.vv v8, v11, v9
+; ZVFBFMIN-NEXT: ret
+ %a_ext = fpext <4 x half> %a to <4 x float>
+ %b_ext = fpext <4 x bfloat> %b to <4 x float>
+ %fma = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %a_ext, <4 x float> %b_ext, <4 x float> %rd)
+ ret <4 x float> %fma
+}
This avoids a root opcode check elsewhere.
LGTM