Skip to content

Commit 1820102

Browse files
authored
Wasm fmuladd relaxed (#163177)
Reland #161355 after fixing up the cross-project-tests for the Wasm SIMD intrinsics. Original commit message: Lower v4f32 and v2f64 fmuladd calls to relaxed_madd instructions. If we have FP16, then lower v8f16 fmuladds to FMA. I've introduced an ISD node for fmuladd to preserve the rounding ambiguity through legalization / combine / isel.
1 parent 095cad6 commit 1820102

File tree

15 files changed

+1447
-53
lines changed

15 files changed

+1447
-53
lines changed

cross-project-tests/intrinsic-header-tests/wasm_simd128.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1511,13 +1511,13 @@ v128_t test_f16x8_convert_u16x8(v128_t a) {
15111511
}
15121512

15131513
// CHECK-LABEL: test_f16x8_relaxed_madd:
1514-
// CHECK: f16x8.relaxed_madd{{$}}
1514+
// CHECK: f16x8.madd{{$}}
15151515
v128_t test_f16x8_relaxed_madd(v128_t a, v128_t b, v128_t c) {
15161516
return wasm_f16x8_relaxed_madd(a, b, c);
15171517
}
15181518

15191519
// CHECK-LABEL: test_f16x8_relaxed_nmadd:
1520-
// CHECK: f16x8.relaxed_nmadd{{$}}
1520+
// CHECK: f16x8.nmadd{{$}}
15211521
v128_t test_f16x8_relaxed_nmadd(v128_t a, v128_t b, v128_t c) {
15221522
return wasm_f16x8_relaxed_nmadd(a, b, c);
15231523
}

llvm/include/llvm/CodeGen/ISDOpcodes.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,12 @@ enum NodeType {
514514
/// separately rounded operations.
515515
FMAD,
516516

517+
/// FMULADD - Performs a * b + c, either with or without intermediate rounding.
518+
/// It is expected that this will be illegal for most targets, as it usually
519+
/// makes sense to split this or use an FMA. But some targets, such as
520+
/// WebAssembly, can directly support these semantics.
521+
FMULADD,
522+
517523
/// FCOPYSIGN(X, Y) - Return the value of X with the sign of Y. NOTE: This
518524
/// DAG node does not require that X and Y have the same type, just that
519525
/// they are both floating point. X and the result must have the same type.

llvm/include/llvm/Target/TargetSelectionDAG.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,7 @@ def fdiv : SDNode<"ISD::FDIV" , SDTFPBinOp>;
535535
def frem : SDNode<"ISD::FREM" , SDTFPBinOp>;
536536
def fma : SDNode<"ISD::FMA" , SDTFPTernaryOp, [SDNPCommutative]>;
537537
def fmad : SDNode<"ISD::FMAD" , SDTFPTernaryOp, [SDNPCommutative]>;
538+
def fmuladd : SDNode<"ISD::FMULADD" , SDTFPTernaryOp, [SDNPCommutative]>;
538539
def fabs : SDNode<"ISD::FABS" , SDTFPUnaryOp>;
539540
def fminnum : SDNode<"ISD::FMINNUM" , SDTFPBinOp,
540541
[SDNPCommutative, SDNPAssociative]>;

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,7 @@ namespace {
509509
SDValue visitFMUL(SDNode *N);
510510
template <class MatchContextClass> SDValue visitFMA(SDNode *N);
511511
SDValue visitFMAD(SDNode *N);
512+
SDValue visitFMULADD(SDNode *N);
512513
SDValue visitFDIV(SDNode *N);
513514
SDValue visitFREM(SDNode *N);
514515
SDValue visitFSQRT(SDNode *N);
@@ -1991,6 +1992,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
19911992
case ISD::FMUL: return visitFMUL(N);
19921993
case ISD::FMA: return visitFMA<EmptyMatchContext>(N);
19931994
case ISD::FMAD: return visitFMAD(N);
1995+
case ISD::FMULADD: return visitFMULADD(N);
19941996
case ISD::FDIV: return visitFDIV(N);
19951997
case ISD::FREM: return visitFREM(N);
19961998
case ISD::FSQRT: return visitFSQRT(N);
@@ -18444,6 +18446,21 @@ SDValue DAGCombiner::visitFMAD(SDNode *N) {
1844418446
return SDValue();
1844518447
}
1844618448

18449+
SDValue DAGCombiner::visitFMULADD(SDNode *N) {
18450+
SDValue N0 = N->getOperand(0);
18451+
SDValue N1 = N->getOperand(1);
18452+
SDValue N2 = N->getOperand(2);
18453+
EVT VT = N->getValueType(0);
18454+
SDLoc DL(N);
18455+
18456+
// Constant fold FMULADD.
18457+
if (SDValue C =
18458+
DAG.FoldConstantArithmetic(ISD::FMULADD, DL, VT, {N0, N1, N2}))
18459+
return C;
18460+
18461+
return SDValue();
18462+
}
18463+
1844718464
// Combine multiple FDIVs with the same divisor into multiple FMULs by the
1844818465
// reciprocal.
1844918466
// E.g., (a / D; b / D;) -> (recip = 1.0 / D; a * recip; b * recip)

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5786,6 +5786,7 @@ bool SelectionDAG::canCreateUndefOrPoison(SDValue Op, const APInt &DemandedElts,
57865786
case ISD::FCOPYSIGN:
57875787
case ISD::FMA:
57885788
case ISD::FMAD:
5789+
case ISD::FMULADD:
57895790
case ISD::FP_EXTEND:
57905791
case ISD::FP_TO_SINT_SAT:
57915792
case ISD::FP_TO_UINT_SAT:
@@ -5904,6 +5905,7 @@ bool SelectionDAG::isKnownNeverNaN(SDValue Op, const APInt &DemandedElts,
59045905
case ISD::FCOSH:
59055906
case ISD::FTANH:
59065907
case ISD::FMA:
5908+
case ISD::FMULADD:
59075909
case ISD::FMAD: {
59085910
if (SNaN)
59095911
return true;
@@ -7231,7 +7233,7 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
72317233
}
72327234

72337235
// Handle fma/fmad/fmuladd special cases.
7234-
if (Opcode == ISD::FMA || Opcode == ISD::FMAD) {
7236+
if (Opcode == ISD::FMA || Opcode == ISD::FMAD || Opcode == ISD::FMULADD) {
72357237
assert(VT.isFloatingPoint() && "This operator only applies to FP types!");
72367238
assert(Ops[0].getValueType() == VT && Ops[1].getValueType() == VT &&
72377239
Ops[2].getValueType() == VT && "FMA types must match!");
@@ -7242,7 +7244,7 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
72427244
APFloat V1 = C1->getValueAPF();
72437245
const APFloat &V2 = C2->getValueAPF();
72447246
const APFloat &V3 = C3->getValueAPF();
7245-
if (Opcode == ISD::FMAD) {
7247+
if (Opcode == ISD::FMAD || Opcode == ISD::FMULADD) {
72467248
V1.multiply(V2, APFloat::rmNearestTiesToEven);
72477249
V1.add(V3, APFloat::rmNearestTiesToEven);
72487250
} else

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6996,6 +6996,13 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
69966996
getValue(I.getArgOperand(0)),
69976997
getValue(I.getArgOperand(1)),
69986998
getValue(I.getArgOperand(2)), Flags));
6999+
} else if (TLI.isOperationLegalOrCustom(ISD::FMULADD, VT)) {
7000+
// TODO: Support splitting the vector.
7001+
setValue(&I, DAG.getNode(ISD::FMULADD, sdl,
7002+
getValue(I.getArgOperand(0)).getValueType(),
7003+
getValue(I.getArgOperand(0)),
7004+
getValue(I.getArgOperand(1)),
7005+
getValue(I.getArgOperand(2)), Flags));
69997006
} else {
70007007
// TODO: Intrinsic calls should have fast-math-flags.
70017008
SDValue Mul = DAG.getNode(

llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,7 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
310310
case ISD::FMA: return "fma";
311311
case ISD::STRICT_FMA: return "strict_fma";
312312
case ISD::FMAD: return "fmad";
313+
case ISD::FMULADD: return "fmuladd";
313314
case ISD::FREM: return "frem";
314315
case ISD::STRICT_FREM: return "strict_frem";
315316
case ISD::FCOPYSIGN: return "fcopysign";

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7676,6 +7676,7 @@ SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
76767676
break;
76777677
}
76787678
case ISD::FMA:
7679+
case ISD::FMULADD:
76797680
case ISD::FMAD: {
76807681
if (!Flags.hasNoSignedZeros())
76817682
break;

llvm/lib/CodeGen/TargetLoweringBase.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -815,7 +815,8 @@ void TargetLoweringBase::initActions() {
815815
ISD::FTAN, ISD::FACOS,
816816
ISD::FASIN, ISD::FATAN,
817817
ISD::FCOSH, ISD::FSINH,
818-
ISD::FTANH, ISD::FATAN2},
818+
ISD::FTANH, ISD::FATAN2,
819+
ISD::FMULADD},
819820
VT, Expand);
820821

821822
// Overflow operations default to expand

llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,15 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
317317
setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, T, Custom);
318318
}
319319

320+
if (Subtarget->hasFP16()) {
321+
setOperationAction(ISD::FMA, MVT::v8f16, Legal);
322+
}
323+
324+
if (Subtarget->hasRelaxedSIMD()) {
325+
setOperationAction(ISD::FMULADD, MVT::v4f32, Legal);
326+
setOperationAction(ISD::FMULADD, MVT::v2f64, Legal);
327+
}
328+
320329
// Partial MLA reductions.
321330
for (auto Op : {ISD::PARTIAL_REDUCE_SMLA, ISD::PARTIAL_REDUCE_UMLA}) {
322331
setPartialReduceMLAAction(Op, MVT::v4i32, MVT::v16i8, Legal);
@@ -1120,6 +1129,18 @@ WebAssemblyTargetLowering::getPreferredVectorAction(MVT VT) const {
11201129
return TargetLoweringBase::getPreferredVectorAction(VT);
11211130
}
11221131

1132+
bool WebAssemblyTargetLowering::isFMAFasterThanFMulAndFAdd(
1133+
const MachineFunction &MF, EVT VT) const {
1134+
if (!Subtarget->hasFP16() || !VT.isVector())
1135+
return false;
1136+
1137+
EVT ScalarVT = VT.getScalarType();
1138+
if (!ScalarVT.isSimple())
1139+
return false;
1140+
1141+
return ScalarVT.getSimpleVT().SimpleTy == MVT::f16;
1142+
}
1143+
11231144
bool WebAssemblyTargetLowering::shouldSimplifyDemandedVectorElts(
11241145
SDValue Op, const TargetLoweringOpt &TLO) const {
11251146
// ISel process runs DAGCombiner after legalization; this step is called

0 commit comments

Comments
 (0)