@@ -54214,22 +54214,31 @@ static SDValue combineLRINT_LLRINT(SDNode *N, SelectionDAG &DAG,
54214
54214
// cases.
54215
54215
static SDValue combinei64TruncSrlConstant(SDValue N, EVT VT, SelectionDAG &DAG,
54216
54216
const SDLoc &DL) {
54217
+ assert(N.getOpcode() == ISD::SRL && "Unknown shift opcode");
54218
+ std::optional<uint64_t> ValidSrlConst = DAG.getValidShiftAmount(N);
54219
+ if (!ValidSrlConst)
54220
+ return SDValue();
54221
+ uint64_t SrlConstVal = *ValidSrlConst;
54217
54222
54218
54223
SDValue Op = N.getOperand(0);
54219
- APInt OpConst = Op.getConstantOperandAPInt(1);
54220
- APInt SrlConst = N.getConstantOperandAPInt(1);
54221
- uint64_t SrlConstVal = SrlConst.getZExtValue();
54222
54224
unsigned Opcode = Op.getOpcode();
54225
+ assert(VT == MVT::i32 && Op.getValueType() == MVT::i64 &&
54226
+ "Illegal truncation types");
54227
+
54228
+ if ((Opcode != ISD::ADD && Opcode != ISD::OR && Opcode != ISD::XOR) ||
54229
+ !isa<ConstantSDNode>(Op.getOperand(1)))
54230
+ return SDValue();
54231
+ const APInt &OpConst = Op.getConstantOperandAPInt(1);
54223
54232
54224
- if (SrlConst.ule(32) ||
54233
+ if (SrlConstVal <= 32 ||
54225
54234
(Opcode == ISD::ADD && OpConst.countr_zero() < SrlConstVal))
54226
54235
return SDValue();
54227
54236
54228
54237
SDValue OpLhsSrl =
54229
54238
DAG.getNode(ISD::SRL, DL, MVT::i64, Op.getOperand(0), N.getOperand(1));
54230
54239
SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, VT, OpLhsSrl);
54231
54240
54232
- APInt NewOpConstVal = OpConst.lshr(SrlConst ).trunc(VT.getSizeInBits());
54241
+ APInt NewOpConstVal = OpConst.lshr(SrlConstVal ).trunc(VT.getSizeInBits());
54233
54242
SDValue NewOpConst = DAG.getConstant(NewOpConstVal, DL, VT);
54234
54243
SDValue NewOpNode = DAG.getNode(Opcode, DL, VT, Trunc, NewOpConst);
54235
54244
@@ -54285,20 +54294,8 @@ static SDValue combineTruncatedArithmetic(SDNode *N, SelectionDAG &DAG,
54285
54294
if (!Src.hasOneUse())
54286
54295
return SDValue();
54287
54296
54288
- if (VT == MVT::i32 && SrcVT == MVT::i64 && SrcOpcode == ISD::SRL &&
54289
- isa<ConstantSDNode>(Src.getOperand(1))) {
54290
-
54291
- unsigned SrcOpOpcode = Src.getOperand(0).getOpcode();
54292
- if ((SrcOpOpcode != ISD::ADD && SrcOpOpcode != ISD::OR &&
54293
- SrcOpOpcode != ISD::XOR) ||
54294
- !isa<ConstantSDNode>(Src.getOperand(0).getOperand(1)))
54295
- return SDValue();
54296
-
54297
- if (SDValue R = combinei64TruncSrlConstant(Src, VT, DAG, DL))
54298
- return R;
54299
-
54300
- return SDValue();
54301
- }
54297
+ if (VT == MVT::i32 && SrcVT == MVT::i64 && SrcOpcode == ISD::SRL)
54298
+ return combinei64TruncSrlConstant(Src, VT, DAG, DL);
54302
54299
54303
54300
if (!VT.isVector())
54304
54301
return SDValue();
0 commit comments