@@ -3787,7 +3787,7 @@ static SDValue insert128BitVector(SDValue Result, SDValue Vec, unsigned IdxVal,
3787
3787
static SDValue widenSubVector(MVT VT, SDValue Vec, bool ZeroNewElements,
3788
3788
const X86Subtarget &Subtarget, SelectionDAG &DAG,
3789
3789
const SDLoc &dl) {
3790
- assert(Vec.getValueSizeInBits().getFixedValue() < VT.getFixedSizeInBits() &&
3790
+ assert(Vec.getValueSizeInBits().getFixedValue() <= VT.getFixedSizeInBits() &&
3791
3791
Vec.getValueType().getScalarType() == VT.getScalarType() &&
3792
3792
"Unsupported vector widening type");
3793
3793
SDValue Res = ZeroNewElements ? getZeroVector(VT, Subtarget, DAG, dl)
@@ -3801,7 +3801,7 @@ static SDValue widenSubVector(MVT VT, SDValue Vec, bool ZeroNewElements,
3801
3801
static SDValue widenSubVector(SDValue Vec, bool ZeroNewElements,
3802
3802
const X86Subtarget &Subtarget, SelectionDAG &DAG,
3803
3803
const SDLoc &dl, unsigned WideSizeInBits) {
3804
- assert(Vec.getValueSizeInBits() < WideSizeInBits &&
3804
+ assert(Vec.getValueSizeInBits() <= WideSizeInBits &&
3805
3805
(WideSizeInBits % Vec.getScalarValueSizeInBits()) == 0 &&
3806
3806
"Unsupported vector widening type");
3807
3807
unsigned WideNumElts = WideSizeInBits / Vec.getScalarValueSizeInBits();
@@ -19982,22 +19982,18 @@ static SDValue truncateVectorWithPACK(unsigned Opcode, EVT DstVT, SDValue In,
19982
19982
if (SrcVT == DstVT)
19983
19983
return In;
19984
19984
19985
- // We only support vector truncation to 64bits or greater from a
19986
- // 128bits or greater source.
19987
- unsigned DstSizeInBits = DstVT.getSizeInBits();
19988
- unsigned SrcSizeInBits = SrcVT.getSizeInBits();
19989
- if ((DstSizeInBits % 64) != 0 || (SrcSizeInBits % 128) != 0)
19990
- return SDValue();
19991
-
19992
19985
unsigned NumElems = SrcVT.getVectorNumElements();
19993
19986
if (!isPowerOf2_32(NumElems))
19994
19987
return SDValue();
19995
19988
19996
- LLVMContext &Ctx = *DAG.getContext();
19989
+ unsigned DstSizeInBits = DstVT.getSizeInBits();
19990
+ unsigned SrcSizeInBits = SrcVT.getSizeInBits();
19997
19991
assert(DstVT.getVectorNumElements() == NumElems && "Illegal truncation");
19998
19992
assert(SrcSizeInBits > DstSizeInBits && "Illegal truncation");
19999
19993
19994
+ LLVMContext &Ctx = *DAG.getContext();
20000
19995
EVT PackedSVT = EVT::getIntegerVT(Ctx, SrcVT.getScalarSizeInBits() / 2);
19996
+ EVT PackedVT = EVT::getVectorVT(Ctx, PackedSVT, NumElems);
20001
19997
20002
19998
// Pack to the largest type possible:
20003
19999
// vXi64/vXi32 -> PACK*SDW and vXi16 -> PACK*SWB.
@@ -20008,14 +20004,16 @@ static SDValue truncateVectorWithPACK(unsigned Opcode, EVT DstVT, SDValue In,
20008
20004
OutVT = MVT::i16;
20009
20005
}
20010
20006
20011
- // 128bit -> 64bit truncate - PACK 128-bit src in the lower subvector .
20012
- if (SrcVT.is128BitVector() ) {
20007
+ // Sub-128-bit truncation - widen to 128-bit src and pack in the lower half.
20008
+ if (SrcSizeInBits <= 128 ) {
20013
20009
InVT = EVT::getVectorVT(Ctx, InVT, 128 / InVT.getSizeInBits());
20014
20010
OutVT = EVT::getVectorVT(Ctx, OutVT, 128 / OutVT.getSizeInBits());
20011
+ In = widenSubVector(In, false, Subtarget, DAG, DL, 128);
20015
20012
In = DAG.getBitcast(InVT, In);
20016
20013
SDValue Res = DAG.getNode(Opcode, DL, OutVT, In, DAG.getUNDEF(InVT));
20017
- Res = extractSubVector(Res, 0, DAG, DL, 64);
20018
- return DAG.getBitcast(DstVT, Res);
20014
+ Res = extractSubVector(Res, 0, DAG, DL, SrcSizeInBits / 2);
20015
+ Res = DAG.getBitcast(PackedVT, Res);
20016
+ return truncateVectorWithPACK(Opcode, DstVT, Res, DL, DAG, Subtarget);
20019
20017
}
20020
20018
20021
20019
// Split lower/upper subvectors.
@@ -20061,15 +20059,13 @@ static SDValue truncateVectorWithPACK(unsigned Opcode, EVT DstVT, SDValue In,
20061
20059
return DAG.getBitcast(DstVT, Res);
20062
20060
20063
20061
// If 512bit -> 128bit truncate another stage.
20064
- EVT PackedVT = EVT::getVectorVT(Ctx, PackedSVT, NumElems);
20065
20062
Res = DAG.getBitcast(PackedVT, Res);
20066
20063
return truncateVectorWithPACK(Opcode, DstVT, Res, DL, DAG, Subtarget);
20067
20064
}
20068
20065
20069
20066
// Recursively pack lower/upper subvectors, concat result and pack again.
20070
20067
assert(SrcSizeInBits >= 256 && "Expected 256-bit vector or greater");
20071
20068
20072
- EVT PackedVT = EVT::getVectorVT(Ctx, PackedSVT, NumElems);
20073
20069
if (PackedVT.is128BitVector()) {
20074
20070
// Avoid CONCAT_VECTORS on sub-128bit nodes as these can fail after
20075
20071
// type legalization.
@@ -50833,6 +50829,10 @@ static SDValue combineVectorSignBitsTruncation(SDNode *N, const SDLoc &DL,
50833
50829
if (SVT == MVT::i32 && VT.getSizeInBits() < 128)
50834
50830
return SDValue();
50835
50831
50832
+ // Truncation from sub-128bit to vXi8 can be better handled with PSHUFB.
50833
+ if (SVT == MVT::i8 && InVT.getSizeInBits() <= 128 && Subtarget.hasSSSE3())
50834
+ return SDValue();
50835
+
50836
50836
// AVX512 has fast truncate, but if the input is already going to be split,
50837
50837
// there's no harm in trying pack.
50838
50838
if (Subtarget.hasAVX512() &&
0 commit comments