Skip to content

Commit 26b71a3

Browse files
[LLVM][CodeGen] Add lowering for scalable vector bfloat operations.
Specifically: fabs, fadd, fceil, fdiv, ffloor, fma, fmax, fmaxnm, fmin, fminnm, fmul, fnearbyint, fneg, frint, fround, froundeven, fsub, fsqrt & ftrunc
1 parent 3e3780e commit 26b71a3

File tree

12 files changed

+1970
-851
lines changed

12 files changed

+1970
-851
lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ class VectorLegalizer {
141141
SDValue ExpandSELECT(SDNode *Node);
142142
std::pair<SDValue, SDValue> ExpandLoad(SDNode *N);
143143
SDValue ExpandStore(SDNode *N);
144+
SDValue ExpandBF16Arith(SDNode *Node);
144145
SDValue ExpandFNEG(SDNode *Node);
145146
SDValue ExpandFABS(SDNode *Node);
146147
SDValue ExpandFCOPYSIGN(SDNode *Node);
@@ -1070,13 +1071,21 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
10701071
break;
10711072
case ISD::FMINNUM:
10721073
case ISD::FMAXNUM:
1074+
if (SDValue Expanded = ExpandBF16Arith(Node)) {
1075+
Results.push_back(Expanded);
1076+
return;
1077+
}
10731078
if (SDValue Expanded = TLI.expandFMINNUM_FMAXNUM(Node, DAG)) {
10741079
Results.push_back(Expanded);
10751080
return;
10761081
}
10771082
break;
10781083
case ISD::FMINIMUM:
10791084
case ISD::FMAXIMUM:
1085+
if (SDValue Expanded = ExpandBF16Arith(Node)) {
1086+
Results.push_back(Expanded);
1087+
return;
1088+
}
10801089
Results.push_back(TLI.expandFMINIMUM_FMAXIMUM(Node, DAG));
10811090
return;
10821091
case ISD::SMIN:
@@ -1197,6 +1206,24 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
11971206
case ISD::UCMP:
11981207
Results.push_back(TLI.expandCMP(Node, DAG));
11991208
return;
1209+
1210+
case ISD::FADD:
1211+
case ISD::FMUL:
1212+
case ISD::FMA:
1213+
case ISD::FDIV:
1214+
case ISD::FCEIL:
1215+
case ISD::FFLOOR:
1216+
case ISD::FNEARBYINT:
1217+
case ISD::FRINT:
1218+
case ISD::FROUND:
1219+
case ISD::FROUNDEVEN:
1220+
case ISD::FTRUNC:
1221+
case ISD::FSQRT:
1222+
if (SDValue Expanded = ExpandBF16Arith(Node)) {
1223+
Results.push_back(Expanded);
1224+
return;
1225+
}
1226+
break;
12001227
}
12011228

12021229
SDValue Unrolled = DAG.UnrollVectorOp(Node);
@@ -1874,6 +1901,11 @@ void VectorLegalizer::ExpandFSUB(SDNode *Node,
18741901
TLI.isOperationLegalOrCustom(ISD::FADD, VT))
18751902
return; // Defer to LegalizeDAG
18761903

1904+
if (SDValue Expanded = ExpandBF16Arith(Node)) {
1905+
Results.push_back(Expanded);
1906+
return;
1907+
}
1908+
18771909
SDValue Tmp = DAG.UnrollVectorOp(Node);
18781910
Results.push_back(Tmp);
18791911
}
@@ -2134,6 +2166,67 @@ bool VectorLegalizer::tryExpandVecMathCall(
21342166
return tryExpandVecMathCall(Node, LC, Results);
21352167
}
21362168

2169+
// Try to lower BFloat arithmetic by performing the same operation on operands
2170+
// that have been promoted to Float32, the result of which is then truncated.
2171+
// If promotion requires non-legal types the operation is split with the
2172+
// promotion occuring during a successive call to this function.
2173+
SDValue VectorLegalizer::ExpandBF16Arith(SDNode *Node) {
2174+
EVT VT = Node->getValueType(0);
2175+
if (VT.getVectorElementType() != MVT::bf16)
2176+
return SDValue();
2177+
2178+
SDLoc DL(Node);
2179+
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2180+
unsigned Opcode = Node->getOpcode();
2181+
2182+
// Can we promote to float and try again?
2183+
2184+
EVT PromoteVT = VT.changeVectorElementType(MVT::f32);
2185+
if (TLI.isTypeLegal(PromoteVT)) {
2186+
// Don't expand if the result is likely to be unrolled anyway.
2187+
if (!TLI.isOperationLegalOrCustom(Opcode, PromoteVT))
2188+
return SDValue();
2189+
2190+
SmallVector<SDValue, 4> Ops;
2191+
for (const SDValue &V : Node->op_values())
2192+
Ops.push_back(DAG.getNode(ISD::FP_EXTEND, DL, PromoteVT, V));
2193+
2194+
SDValue PromotedOp = DAG.getNode(Opcode, DL, PromoteVT, Ops);
2195+
return DAG.getNode(ISD::FP_ROUND, DL, VT, PromotedOp,
2196+
DAG.getIntPtrConstant(0, DL, true));
2197+
}
2198+
2199+
// Can we split the vector and try again?
2200+
2201+
if (VT.getVectorMinNumElements() == 1)
2202+
return SDValue();
2203+
2204+
EVT LoVT, HiVT;
2205+
std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT);
2206+
2207+
// Restrict expansion to cases where both parts can be concatenated.
2208+
if (LoVT != HiVT || !TLI.isTypeLegal(LoVT))
2209+
return SDValue();
2210+
2211+
// Don't expand if the result is likely to be unrolled anyway.
2212+
if (!TLI.isOperationLegalOrCustom(Opcode, LoVT) &&
2213+
!TLI.isOperationLegalOrCustom(Opcode,
2214+
LoVT.changeVectorElementType(MVT::f32)))
2215+
return SDValue();
2216+
2217+
SmallVector<SDValue, 4> LoOps, HiOps;
2218+
for (const SDValue &V : Node->op_values()) {
2219+
SDValue Lo, Hi;
2220+
std::tie(Lo, Hi) = DAG.SplitVector(V, DL, LoVT, HiVT);
2221+
LoOps.push_back(Lo);
2222+
HiOps.push_back(Hi);
2223+
}
2224+
2225+
SDValue SplitOpLo = DAG.getNode(Opcode, DL, LoVT, LoOps);
2226+
SDValue SplitOpHi = DAG.getNode(Opcode, DL, HiVT, HiOps);
2227+
return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, SplitOpLo, SplitOpHi);
2228+
}
2229+
21372230
void VectorLegalizer::UnrollStrictFPOp(SDNode *Node,
21382231
SmallVectorImpl<SDValue> &Results) {
21392232
EVT VT = Node->getValueType(0);

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1663,12 +1663,44 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
16631663
for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) {
16641664
setOperationAction(ISD::BITCAST, VT, Custom);
16651665
setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
1666+
setOperationAction(ISD::FABS, VT, Legal);
1667+
setOperationAction(ISD::FCEIL, VT, Expand);
1668+
setOperationAction(ISD::FDIV, VT, Expand);
1669+
setOperationAction(ISD::FFLOOR, VT, Expand);
1670+
setOperationAction(ISD::FNEARBYINT, VT, Expand);
1671+
setOperationAction(ISD::FNEG, VT, Legal);
16661672
setOperationAction(ISD::FP_EXTEND, VT, Custom);
16671673
setOperationAction(ISD::FP_ROUND, VT, Custom);
1668-
setOperationAction(ISD::MLOAD, VT, Custom);
1674+
setOperationAction(ISD::FRINT, VT, Expand);
1675+
setOperationAction(ISD::FROUND, VT, Expand);
1676+
setOperationAction(ISD::FROUNDEVEN, VT, Expand);
1677+
setOperationAction(ISD::FSQRT, VT, Expand);
1678+
setOperationAction(ISD::FTRUNC, VT, Expand);
16691679
setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
1680+
setOperationAction(ISD::MLOAD, VT, Custom);
16701681
setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
16711682
setOperationAction(ISD::VECTOR_SPLICE, VT, Custom);
1683+
1684+
if (!Subtarget->hasSVEB16B16()) {
1685+
setOperationAction(ISD::FADD, VT, Expand);
1686+
setOperationAction(ISD::FMA, VT, Expand);
1687+
setOperationAction(ISD::FMAXIMUM, VT, Expand);
1688+
setOperationAction(ISD::FMAXNUM, VT, Expand);
1689+
setOperationAction(ISD::FMINIMUM, VT, Expand);
1690+
setOperationAction(ISD::FMINNUM, VT, Expand);
1691+
setOperationAction(ISD::FMUL, VT, Expand);
1692+
setOperationAction(ISD::FSUB, VT, Expand);
1693+
1694+
} else {
1695+
setOperationAction(ISD::FADD, VT, Legal);
1696+
setOperationAction(ISD::FMA, VT, Custom);
1697+
setOperationAction(ISD::FMAXIMUM, VT, Custom);
1698+
setOperationAction(ISD::FMAXNUM, VT, Custom);
1699+
setOperationAction(ISD::FMINIMUM, VT, Custom);
1700+
setOperationAction(ISD::FMINNUM, VT, Custom);
1701+
setOperationAction(ISD::FMUL, VT, Legal);
1702+
setOperationAction(ISD::FSUB, VT, Legal);
1703+
}
16721704
}
16731705

16741706
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,13 @@ let Predicates = [HasSVEorSME] in {
663663
defm FABS_ZPmZ : sve_int_un_pred_arit_1_fp<0b100, "fabs", AArch64fabs_mt>;
664664
defm FNEG_ZPmZ : sve_int_un_pred_arit_1_fp<0b101, "fneg", AArch64fneg_mt>;
665665

666+
foreach VT = [nxv2bf16, nxv4bf16, nxv8bf16] in {
667+
def : Pat<(VT (fabs VT:$op)),
668+
(AND_ZI $op, (i64 (logical_imm64_XFORM(i64 0x7fff7fff7fff7fff))))>;
669+
def : Pat<(VT (fneg VT:$op)),
670+
(EOR_ZI $op, (i64 (logical_imm64_XFORM(i64 0x8000800080008000))))>;
671+
}
672+
666673
// zext(cmpeq(x, splat(0))) -> cnot(x)
667674
def : Pat<(nxv16i8 (zext (nxv16i1 (AArch64setcc_z (nxv16i1 (SVEAllActive):$Pg), nxv16i8:$Op2, (SVEDup0), SETEQ)))),
668675
(CNOT_ZPmZ_B $Op2, $Pg, $Op2)>;

llvm/lib/Target/AArch64/SVEInstrFormats.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2299,6 +2299,8 @@ multiclass sve_fp_3op_u_zd_bfloat<bits<3> opc, string asm, SDPatternOperator op>
22992299
def NAME : sve_fp_3op_u_zd<0b00, opc, asm, ZPR16>;
23002300

23012301
def : SVE_2_Op_Pat<nxv8bf16, op, nxv8bf16, nxv8bf16, !cast<Instruction>(NAME)>;
2302+
def : SVE_2_Op_Pat<nxv4bf16, op, nxv4bf16, nxv4bf16, !cast<Instruction>(NAME)>;
2303+
def : SVE_2_Op_Pat<nxv2bf16, op, nxv2bf16, nxv2bf16, !cast<Instruction>(NAME)>;
23022304
}
23032305

23042306
multiclass sve_fp_3op_u_zd_ftsmul<bits<3> opc, string asm, SDPatternOperator op> {
@@ -9078,6 +9080,8 @@ multiclass sve_fp_bin_pred_bfloat<SDPatternOperator op> {
90789080
def _UNDEF : PredTwoOpPseudo<NAME, ZPR16, FalseLanesUndef>;
90799081

90809082
def : SVE_3_Op_Pat<nxv8bf16, op, nxv8i1, nxv8bf16, nxv8bf16, !cast<Pseudo>(NAME # _UNDEF)>;
9083+
def : SVE_3_Op_Pat<nxv4bf16, op, nxv4i1, nxv4bf16, nxv4bf16, !cast<Pseudo>(NAME # _UNDEF)>;
9084+
def : SVE_3_Op_Pat<nxv2bf16, op, nxv2i1, nxv2bf16, nxv2bf16, !cast<Pseudo>(NAME # _UNDEF)>;
90819085
}
90829086

90839087
// Predicated pseudo floating point three operand instructions.
@@ -9099,6 +9103,8 @@ multiclass sve_fp_3op_pred_bfloat<SDPatternOperator op> {
90999103
def _UNDEF : PredThreeOpPseudo<NAME, ZPR16, FalseLanesUndef>;
91009104

91019105
def : SVE_4_Op_Pat<nxv8bf16, op, nxv8i1, nxv8bf16, nxv8bf16, nxv8bf16, !cast<Instruction>(NAME # _UNDEF)>;
9106+
def : SVE_4_Op_Pat<nxv4bf16, op, nxv4i1, nxv4bf16, nxv4bf16, nxv4bf16, !cast<Instruction>(NAME # _UNDEF)>;
9107+
def : SVE_4_Op_Pat<nxv2bf16, op, nxv2i1, nxv2bf16, nxv2bf16, nxv2bf16, !cast<Instruction>(NAME # _UNDEF)>;
91029108
}
91039109

91049110
// Predicated pseudo integer two operand instructions.

llvm/test/CodeGen/AArch64/atomicrmw-fmax.ll

Lines changed: 30 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -652,90 +652,46 @@ define <2 x half> @test_atomicrmw_fmax_v2f16_seq_cst_align4(ptr %ptr, <2 x half>
652652
define <2 x bfloat> @test_atomicrmw_fmax_v2bf16_seq_cst_align4(ptr %ptr, <2 x bfloat> %value) #0 {
653653
; NOLSE-LABEL: test_atomicrmw_fmax_v2bf16_seq_cst_align4:
654654
; NOLSE: // %bb.0:
655-
; NOLSE-NEXT: // kill: def $d0 killed $d0 def $q0
656-
; NOLSE-NEXT: mov h1, v0.h[1]
657-
; NOLSE-NEXT: fmov w10, s0
658-
; NOLSE-NEXT: mov w8, #32767 // =0x7fff
659-
; NOLSE-NEXT: lsl w10, w10, #16
660-
; NOLSE-NEXT: fmov w9, s1
661-
; NOLSE-NEXT: fmov s1, w10
662-
; NOLSE-NEXT: lsl w9, w9, #16
663-
; NOLSE-NEXT: fmov s0, w9
655+
; NOLSE-NEXT: movi v1.4s, #1
656+
; NOLSE-NEXT: movi v2.4s, #127, msl #8
657+
; NOLSE-NEXT: shll v0.4s, v0.4h, #16
664658
; NOLSE-NEXT: .LBB7_1: // %atomicrmw.start
665659
; NOLSE-NEXT: // =>This Inner Loop Header: Depth=1
666-
; NOLSE-NEXT: ldaxr w9, [x0]
667-
; NOLSE-NEXT: fmov s2, w9
668-
; NOLSE-NEXT: mov h3, v2.h[1]
669-
; NOLSE-NEXT: fmov w11, s2
670-
; NOLSE-NEXT: lsl w11, w11, #16
671-
; NOLSE-NEXT: fmov w10, s3
672-
; NOLSE-NEXT: fmov s3, w11
673-
; NOLSE-NEXT: lsl w10, w10, #16
674-
; NOLSE-NEXT: fmaxnm s3, s3, s1
675-
; NOLSE-NEXT: fmov s2, w10
676-
; NOLSE-NEXT: fmaxnm s2, s2, s0
677-
; NOLSE-NEXT: fmov w11, s3
678-
; NOLSE-NEXT: ubfx w13, w11, #16, #1
679-
; NOLSE-NEXT: add w11, w11, w8
680-
; NOLSE-NEXT: fmov w10, s2
681-
; NOLSE-NEXT: add w11, w13, w11
682-
; NOLSE-NEXT: lsr w11, w11, #16
683-
; NOLSE-NEXT: ubfx w12, w10, #16, #1
684-
; NOLSE-NEXT: add w10, w10, w8
685-
; NOLSE-NEXT: fmov s3, w11
686-
; NOLSE-NEXT: add w10, w12, w10
687-
; NOLSE-NEXT: lsr w10, w10, #16
688-
; NOLSE-NEXT: fmov s2, w10
689-
; NOLSE-NEXT: mov v3.h[1], v2.h[0]
690-
; NOLSE-NEXT: fmov w10, s3
691-
; NOLSE-NEXT: stlxr w11, w10, [x0]
692-
; NOLSE-NEXT: cbnz w11, .LBB7_1
660+
; NOLSE-NEXT: ldaxr w8, [x0]
661+
; NOLSE-NEXT: fmov s3, w8
662+
; NOLSE-NEXT: shll v3.4s, v3.4h, #16
663+
; NOLSE-NEXT: fmaxnm v3.4s, v3.4s, v0.4s
664+
; NOLSE-NEXT: ushr v4.4s, v3.4s, #16
665+
; NOLSE-NEXT: and v4.16b, v4.16b, v1.16b
666+
; NOLSE-NEXT: add v3.4s, v4.4s, v3.4s
667+
; NOLSE-NEXT: addhn v3.4h, v3.4s, v2.4s
668+
; NOLSE-NEXT: fmov w9, s3
669+
; NOLSE-NEXT: stlxr w10, w9, [x0]
670+
; NOLSE-NEXT: cbnz w10, .LBB7_1
693671
; NOLSE-NEXT: // %bb.2: // %atomicrmw.end
694-
; NOLSE-NEXT: fmov d0, x9
672+
; NOLSE-NEXT: fmov d0, x8
695673
; NOLSE-NEXT: ret
696674
;
697675
; LSE-LABEL: test_atomicrmw_fmax_v2bf16_seq_cst_align4:
698676
; LSE: // %bb.0:
699-
; LSE-NEXT: // kill: def $d0 killed $d0 def $q0
700-
; LSE-NEXT: mov h1, v0.h[1]
701-
; LSE-NEXT: fmov w10, s0
702-
; LSE-NEXT: mov w8, #32767 // =0x7fff
677+
; LSE-NEXT: movi v1.4s, #1
678+
; LSE-NEXT: movi v2.4s, #127, msl #8
679+
; LSE-NEXT: shll v3.4s, v0.4h, #16
703680
; LSE-NEXT: ldr s0, [x0]
704-
; LSE-NEXT: lsl w10, w10, #16
705-
; LSE-NEXT: fmov w9, s1
706-
; LSE-NEXT: fmov s2, w10
707-
; LSE-NEXT: lsl w9, w9, #16
708-
; LSE-NEXT: fmov s1, w9
709681
; LSE-NEXT: .LBB7_1: // %atomicrmw.start
710682
; LSE-NEXT: // =>This Inner Loop Header: Depth=1
711-
; LSE-NEXT: mov h3, v0.h[1]
712-
; LSE-NEXT: fmov w10, s0
713-
; LSE-NEXT: lsl w10, w10, #16
714-
; LSE-NEXT: fmov w9, s3
715-
; LSE-NEXT: fmov s4, w10
716-
; LSE-NEXT: lsl w9, w9, #16
717-
; LSE-NEXT: fmaxnm s4, s4, s2
718-
; LSE-NEXT: fmov s3, w9
719-
; LSE-NEXT: fmaxnm s3, s3, s1
720-
; LSE-NEXT: fmov w10, s4
721-
; LSE-NEXT: ubfx w12, w10, #16, #1
722-
; LSE-NEXT: add w10, w10, w8
723-
; LSE-NEXT: fmov w9, s3
724-
; LSE-NEXT: add w10, w12, w10
725-
; LSE-NEXT: lsr w10, w10, #16
726-
; LSE-NEXT: ubfx w11, w9, #16, #1
727-
; LSE-NEXT: add w9, w9, w8
728-
; LSE-NEXT: fmov s4, w10
729-
; LSE-NEXT: add w9, w11, w9
730-
; LSE-NEXT: lsr w9, w9, #16
731-
; LSE-NEXT: fmov s3, w9
732-
; LSE-NEXT: fmov w9, s0
733-
; LSE-NEXT: mov v4.h[1], v3.h[0]
734-
; LSE-NEXT: mov w11, w9
735-
; LSE-NEXT: fmov w10, s4
736-
; LSE-NEXT: casal w11, w10, [x0]
737-
; LSE-NEXT: fmov s0, w11
738-
; LSE-NEXT: cmp w11, w9
683+
; LSE-NEXT: shll v4.4s, v0.4h, #16
684+
; LSE-NEXT: fmov w8, s0
685+
; LSE-NEXT: fmaxnm v4.4s, v4.4s, v3.4s
686+
; LSE-NEXT: mov w10, w8
687+
; LSE-NEXT: ushr v5.4s, v4.4s, #16
688+
; LSE-NEXT: and v5.16b, v5.16b, v1.16b
689+
; LSE-NEXT: add v4.4s, v5.4s, v4.4s
690+
; LSE-NEXT: addhn v4.4h, v4.4s, v2.4s
691+
; LSE-NEXT: fmov w9, s4
692+
; LSE-NEXT: casal w10, w9, [x0]
693+
; LSE-NEXT: fmov s0, w10
694+
; LSE-NEXT: cmp w10, w8
739695
; LSE-NEXT: b.ne .LBB7_1
740696
; LSE-NEXT: // %bb.2: // %atomicrmw.end
741697
; LSE-NEXT: // kill: def $d0 killed $d0 killed $q0

0 commit comments

Comments
 (0)