Skip to content

Commit 0c89cbb

Browse files
authored
[X86][FP16] Widen 128/256-bit CVTTP2xI to 512-bit when VLX not enabled (#142763)
1 parent 0487db1 commit 0c89cbb

File tree

3 files changed

+448
-15
lines changed

3 files changed

+448
-15
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2354,6 +2354,11 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
23542354
setOperationAction(ISD::LLRINT, MVT::v8f16, Legal);
23552355
}
23562356

2357+
setOperationAction(ISD::FP_TO_SINT, MVT::v8i16, Custom);
2358+
setOperationAction(ISD::STRICT_FP_TO_SINT, MVT::v8i16, Custom);
2359+
setOperationAction(ISD::FP_TO_UINT, MVT::v8i16, Custom);
2360+
setOperationAction(ISD::STRICT_FP_TO_UINT, MVT::v8i16, Custom);
2361+
23572362
if (Subtarget.hasVLX()) {
23582363
setGroup(MVT::v8f16);
23592364
setGroup(MVT::v16f16);
@@ -2369,10 +2374,6 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
23692374
setOperationAction(ISD::UINT_TO_FP, MVT::v8i16, Legal);
23702375
setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::v8i16, Legal);
23712376

2372-
setOperationAction(ISD::FP_TO_SINT, MVT::v8i16, Custom);
2373-
setOperationAction(ISD::STRICT_FP_TO_SINT, MVT::v8i16, Custom);
2374-
setOperationAction(ISD::FP_TO_UINT, MVT::v8i16, Custom);
2375-
setOperationAction(ISD::STRICT_FP_TO_UINT, MVT::v8i16, Custom);
23762377
setOperationAction(ISD::FP_ROUND, MVT::v8f16, Legal);
23772378
setOperationAction(ISD::STRICT_FP_ROUND, MVT::v8f16, Legal);
23782379
setOperationAction(ISD::FP_EXTEND, MVT::v8f32, Custom);
@@ -19999,10 +20000,12 @@ static SDValue promoteXINT_TO_FP(SDValue Op, const SDLoc &dl,
1999920000

2000020001
static bool isLegalConversion(MVT VT, MVT FloatVT, bool IsSigned,
2000120002
const X86Subtarget &Subtarget) {
20002-
if (VT == MVT::v4i32 && Subtarget.hasSSE2() && IsSigned)
20003-
return true;
20004-
if (VT == MVT::v8i32 && Subtarget.hasAVX() && IsSigned)
20005-
return true;
20003+
if (FloatVT.getScalarType() != MVT::f16 || Subtarget.hasVLX()) {
20004+
if (VT == MVT::v4i32 && Subtarget.hasSSE2() && IsSigned)
20005+
return true;
20006+
if (VT == MVT::v8i32 && Subtarget.hasAVX() && IsSigned)
20007+
return true;
20008+
}
2000620009
if (Subtarget.hasVLX() && (VT == MVT::v4i32 || VT == MVT::v8i32))
2000720010
return true;
2000820011
if (Subtarget.useAVX512Regs()) {
@@ -21541,6 +21544,7 @@ SDValue X86TargetLowering::LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const {
2154121544
bool IsStrict = Op->isStrictFPOpcode();
2154221545
bool IsSigned = Op.getOpcode() == ISD::FP_TO_SINT ||
2154321546
Op.getOpcode() == ISD::STRICT_FP_TO_SINT;
21547+
bool HasVLX = Subtarget.hasVLX();
2154421548
MVT VT = Op->getSimpleValueType(0);
2154521549
SDValue Src = Op.getOperand(IsStrict ? 1 : 0);
2154621550
SDValue Chain = IsStrict ? Op->getOperand(0) : SDValue();
@@ -21571,7 +21575,7 @@ SDValue X86TargetLowering::LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const {
2157121575
else
2157221576
Opc = IsSigned ? X86ISD::CVTTP2SI : X86ISD::CVTTP2UI;
2157321577

21574-
if (!IsSigned && !Subtarget.hasVLX()) {
21578+
if (!IsSigned && !HasVLX) {
2157521579
assert(Subtarget.useAVX512Regs() && "Unexpected features!");
2157621580
// Widen to 512-bits.
2157721581
ResVT = MVT::v8i32;
@@ -21601,22 +21605,33 @@ SDValue X86TargetLowering::LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const {
2160121605
}
2160221606

2160321607
if (Subtarget.hasFP16() && SrcVT.getVectorElementType() == MVT::f16) {
21604-
if (VT == MVT::v8i16 || VT == MVT::v16i16 || VT == MVT::v32i16)
21608+
if ((HasVLX && (VT == MVT::v8i16 || VT == MVT::v16i16)) ||
21609+
VT == MVT::v32i16)
2160521610
return Op;
2160621611

2160721612
MVT ResVT = VT;
2160821613
MVT EleVT = VT.getVectorElementType();
2160921614
if (EleVT != MVT::i64)
2161021615
ResVT = EleVT == MVT::i32 ? MVT::v4i32 : MVT::v8i16;
2161121616

21612-
if (SrcVT != MVT::v8f16) {
21617+
if (SrcVT == MVT::v2f16 || SrcVT == MVT::v4f16) {
2161321618
SDValue Tmp =
2161421619
IsStrict ? DAG.getConstantFP(0.0, dl, SrcVT) : DAG.getUNDEF(SrcVT);
2161521620
SmallVector<SDValue, 4> Ops(SrcVT == MVT::v2f16 ? 4 : 2, Tmp);
2161621621
Ops[0] = Src;
2161721622
Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v8f16, Ops);
2161821623
}
2161921624

21625+
if (!HasVLX) {
21626+
assert(Subtarget.useAVX512Regs() && "Unexpected features!");
21627+
// Widen to 512-bits.
21628+
unsigned IntSize = EleVT.getSizeInBits();
21629+
unsigned Num = IntSize > 16 ? 512 / IntSize : 32;
21630+
ResVT = MVT::getVectorVT(EleVT, Num);
21631+
Src = widenSubVector(MVT::getVectorVT(MVT::f16, Num), Src, IsStrict,
21632+
Subtarget, DAG, dl);
21633+
}
21634+
2162021635
if (IsStrict) {
2162121636
Res = DAG.getNode(IsSigned ? X86ISD::STRICT_CVTTP2SI
2162221637
: X86ISD::STRICT_CVTTP2UI,
@@ -21629,7 +21644,8 @@ SDValue X86TargetLowering::LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const {
2162921644

2163021645
// TODO: Need to add exception check code for strict FP.
2163121646
if (EleVT.getSizeInBits() < 16) {
21632-
ResVT = MVT::getVectorVT(EleVT, 8);
21647+
if (HasVLX)
21648+
ResVT = MVT::getVectorVT(EleVT, 8);
2163321649
Res = DAG.getNode(ISD::TRUNCATE, dl, ResVT, Res);
2163421650
}
2163521651

@@ -34139,12 +34155,10 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
3413934155
}
3414034156

3414134157
if (IsStrict) {
34142-
Opc = IsSigned ? X86ISD::STRICT_CVTTP2SI : X86ISD::STRICT_CVTTP2UI;
3414334158
Res =
3414434159
DAG.getNode(Opc, dl, {ResVT, MVT::Other}, {N->getOperand(0), Src});
3414534160
Chain = Res.getValue(1);
3414634161
} else {
34147-
Opc = IsSigned ? X86ISD::CVTTP2SI : X86ISD::CVTTP2UI;
3414834162
Res = DAG.getNode(Opc, dl, ResVT, Src);
3414934163
}
3415034164

@@ -44161,7 +44175,12 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
4416144175
// Conversions.
4416244176
// TODO: Add more CVT opcodes when we have test coverage.
4416344177
case X86ISD::CVTTP2SI:
44164-
case X86ISD::CVTTP2UI:
44178+
case X86ISD::CVTTP2UI: {
44179+
if (Op.getOperand(0).getValueType().getVectorElementType() == MVT::f16 &&
44180+
!Subtarget.hasVLX())
44181+
break;
44182+
[[fallthrough]];
44183+
}
4416544184
case X86ISD::CVTPH2PS: {
4416644185
SDLoc DL(Op);
4416744186
unsigned Scale = SizeInBits / ExtSizeInBits;

0 commit comments

Comments
 (0)