@@ -6515,72 +6515,75 @@ bool InstCombinerImpl::OptimizeOverflowCheck(Instruction::BinaryOps BinaryOp,
6515
6515
llvm_unreachable (" Unexpected overflow result" );
6516
6516
}
6517
6517
6518
- // / Recognize and process idiom involving test for multiplication
6518
+ // / Recognize and process idiom involving test for unsigned
6519
6519
// / overflow.
6520
6520
// /
6521
6521
// / The caller has matched a pattern of the form:
6522
+ // / I = cmp u (add(zext A, zext B), V
6522
6523
// / I = cmp u (mul(zext A, zext B), V
6523
6524
// / The function checks if this is a test for overflow and if so replaces
6524
- // / multiplication with call to 'mul.with.overflow' intrinsic.
6525
+ // / addition with call to the right intrinsic.
6525
6526
// /
6526
6527
// / \param I Compare instruction.
6527
- // / \param MulVal Result of 'mult' instruction. It is one of the arguments of
6528
+ // / \param Val Result of instruction. It is one of the arguments of
6528
6529
// / the compare instruction. Must be of integer type.
6529
6530
// / \param OtherVal The other argument of compare instruction.
6530
6531
// / \returns Instruction which must replace the compare instruction, NULL if no
6531
6532
// / replacement required.
6532
- static Instruction *processUMulZExtIdiom (ICmpInst &I, Value *MulVal ,
6533
- const APInt *OtherVal,
6534
- InstCombinerImpl &IC) {
6533
+ static Instruction *processUZExtIdiom (ICmpInst &I, Value *Val ,
6534
+ const APInt *OtherVal,
6535
+ InstCombinerImpl &IC) {
6535
6536
// Don't bother doing this transformation for pointers, don't do it for
6536
6537
// vectors.
6537
- if (!isa<IntegerType>(MulVal ->getType ()))
6538
+ if (!isa<IntegerType>(Val ->getType ()))
6538
6539
return nullptr ;
6539
6540
6540
- auto *MulInstr = dyn_cast<Instruction>(MulVal );
6541
- if (!MulInstr )
6541
+ auto *Instr = dyn_cast<Instruction>(Val );
6542
+ if (!Instr )
6542
6543
return nullptr ;
6543
- assert (MulInstr->getOpcode () == Instruction::Mul);
6544
6544
6545
- auto *LHS = cast<ZExtInst>(MulInstr->getOperand (0 )),
6546
- *RHS = cast<ZExtInst>(MulInstr->getOperand (1 ));
6545
+ unsigned Opcode = Instr->getOpcode ();
6546
+ assert (Opcode == Instruction::Add || Opcode == Instruction::Mul);
6547
+
6548
+ auto *LHS = cast<ZExtInst>(Instr->getOperand (0 )),
6549
+ *RHS = cast<ZExtInst>(Instr->getOperand (1 ));
6547
6550
assert (LHS->getOpcode () == Instruction::ZExt);
6548
6551
assert (RHS->getOpcode () == Instruction::ZExt);
6549
6552
Value *A = LHS->getOperand (0 ), *B = RHS->getOperand (0 );
6550
6553
6551
- // Calculate type and width of the result produced by mul.with.overflow.
6554
+ // Calculate type and width of the result produced by add/ mul.with.overflow.
6552
6555
Type *TyA = A->getType (), *TyB = B->getType ();
6553
6556
unsigned WidthA = TyA->getPrimitiveSizeInBits (),
6554
6557
WidthB = TyB->getPrimitiveSizeInBits ();
6555
- unsigned MulWidth ;
6556
- Type *MulType ;
6558
+ unsigned ResultWidth ;
6559
+ Type *ResultType ;
6557
6560
if (WidthB > WidthA) {
6558
- MulWidth = WidthB;
6559
- MulType = TyB;
6561
+ ResultWidth = WidthB;
6562
+ ResultType = TyB;
6560
6563
} else {
6561
- MulWidth = WidthA;
6562
- MulType = TyA;
6564
+ ResultWidth = WidthA;
6565
+ ResultType = TyA;
6563
6566
}
6564
6567
6565
- // In order to replace the original mul with a narrower mul.with.overflow,
6566
- // all uses must ignore upper bits of the product . The number of used low
6567
- // bits must be not greater than the width of mul.with.overflow.
6568
- if (MulVal ->hasNUsesOrMore (2 ))
6569
- for (User *U : MulVal ->users ()) {
6568
+ // In order to replace the original result with an add/ mul.with.overflow
6569
+ // intrinsic, all uses must ignore upper bits of the result . The number of
6570
+ // used low bits must be not greater than the width of add/ mul.with.overflow.
6571
+ if (Val ->hasNUsesOrMore (2 ))
6572
+ for (User *U : Val ->users ()) {
6570
6573
if (U == &I)
6571
6574
continue ;
6572
6575
if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
6573
- // Check if truncation ignores bits above MulWidth .
6576
+ // Check if truncation ignores bits above ResultWidth .
6574
6577
unsigned TruncWidth = TI->getType ()->getPrimitiveSizeInBits ();
6575
- if (TruncWidth > MulWidth )
6578
+ if (TruncWidth > ResultWidth )
6576
6579
return nullptr ;
6577
6580
} else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
6578
- // Check if AND ignores bits above MulWidth .
6581
+ // Check if AND ignores bits above ResultWidth .
6579
6582
if (BO->getOpcode () != Instruction::And)
6580
6583
return nullptr ;
6581
6584
if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand (1 ))) {
6582
6585
const APInt &CVal = CI->getValue ();
6583
- if (CVal.getBitWidth () - CVal.countl_zero () > MulWidth )
6586
+ if (CVal.getBitWidth () - CVal.countl_zero () > ResultWidth )
6584
6587
return nullptr ;
6585
6588
} else {
6586
6589
// In this case we could have the operand of the binary operation
@@ -6598,9 +6601,9 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
6598
6601
switch (I.getPredicate ()) {
6599
6602
case ICmpInst::ICMP_UGT: {
6600
6603
// Recognize pattern:
6601
- // mulval = mul(zext A, zext B)
6602
- // cmp ugt mulval , max
6603
- APInt MaxVal = APInt::getMaxValue (MulWidth );
6604
+ // val = add/ mul(zext A, zext B)
6605
+ // cmp ugt val , max
6606
+ APInt MaxVal = APInt::getMaxValue (ResultWidth );
6604
6607
MaxVal = MaxVal.zext (OtherVal->getBitWidth ());
6605
6608
if (MaxVal.eq (*OtherVal))
6606
6609
break ; // Recognized
@@ -6609,9 +6612,9 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
6609
6612
6610
6613
case ICmpInst::ICMP_ULT: {
6611
6614
// Recognize pattern:
6612
- // mulval = mul(zext A, zext B)
6613
- // cmp ule mulval , max + 1
6614
- APInt MaxVal = APInt::getOneBitSet (OtherVal->getBitWidth (), MulWidth );
6615
+ // val = add/ mul(zext A, zext B)
6616
+ // cmp ule val , max + 1
6617
+ APInt MaxVal = APInt::getOneBitSet (OtherVal->getBitWidth (), ResultWidth );
6615
6618
if (MaxVal.eq (*OtherVal))
6616
6619
break ; // Recognized
6617
6620
return nullptr ;
@@ -6622,38 +6625,42 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
6622
6625
}
6623
6626
6624
6627
InstCombiner::BuilderTy &Builder = IC.Builder ;
6625
- Builder.SetInsertPoint (MulInstr);
6626
-
6627
- // Replace: mul(zext A, zext B) --> mul.with.overflow(A, B)
6628
- Value *MulA = A, *MulB = B;
6629
- if (WidthA < MulWidth)
6630
- MulA = Builder.CreateZExt (A, MulType);
6631
- if (WidthB < MulWidth)
6632
- MulB = Builder.CreateZExt (B, MulType);
6633
- CallInst *Call =
6634
- Builder.CreateIntrinsic (Intrinsic::umul_with_overflow, MulType,
6635
- {MulA, MulB}, /* FMFSource=*/ nullptr , " umul" );
6636
- IC.addToWorklist (MulInstr);
6637
-
6638
- // If there are uses of mul result other than the comparison, we know that
6628
+ Builder.SetInsertPoint (Instr);
6629
+
6630
+ // Replace: add/mul(zext A, zext B) --> add/mul.with.overflow(A, B)
6631
+ Value *ResultA = A, *ResultB = B;
6632
+ if (WidthA < ResultWidth)
6633
+ ResultA = Builder.CreateZExt (A, ResultType);
6634
+ if (WidthB < ResultWidth)
6635
+ ResultB = Builder.CreateZExt (B, ResultType);
6636
+ CallInst *Call = Builder.CreateIntrinsic (
6637
+ Opcode == Instruction::Add ? Intrinsic::uadd_with_overflow
6638
+ : Intrinsic::umul_with_overflow,
6639
+ ResultType, {ResultA, ResultB}, /* FMFSource=*/ nullptr ,
6640
+ Intrinsic::uadd_with_overflow ? " uadd" : " umul" );
6641
+ IC.addToWorklist (Instr);
6642
+
6643
+ // If there are uses of add result other than the comparison, we know that
6639
6644
// they are truncation or binary AND. Change them to use result of
6640
- // mul.with.overflow and adjust properly mask/size.
6641
- if (MulVal->hasNUsesOrMore (2 )) {
6642
- Value *Mul = Builder.CreateExtractValue (Call, 0 , " umul.value" );
6643
- for (User *U : make_early_inc_range (MulVal->users ())) {
6645
+ // add/mul.with.overflow and adjust properly mask/size.
6646
+ if (Val->hasNUsesOrMore (2 )) {
6647
+ Value *Extract = Builder.CreateExtractValue (
6648
+ Call, 0 , Instruction::Add ? " uadd.value" : " umul.value" );
6649
+ for (User *U : make_early_inc_range (Val->users ())) {
6644
6650
if (U == &I)
6645
6651
continue ;
6646
6652
if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
6647
- if (TI->getType ()->getPrimitiveSizeInBits () == MulWidth )
6648
- IC.replaceInstUsesWith (*TI, Mul );
6653
+ if (TI->getType ()->getPrimitiveSizeInBits () == ResultWidth )
6654
+ IC.replaceInstUsesWith (*TI, Extract );
6649
6655
else
6650
- TI->setOperand (0 , Mul );
6656
+ TI->setOperand (0 , Extract );
6651
6657
} else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
6652
6658
assert (BO->getOpcode () == Instruction::And);
6653
- // Replace (mul & mask) --> zext (mul.with.overflow & short_mask)
6659
+ // Replace (Extract & mask) --> zext (add/mul.with.overflow &
6660
+ // short_mask)
6654
6661
ConstantInt *CI = cast<ConstantInt>(BO->getOperand (1 ));
6655
- APInt ShortMask = CI->getValue ().trunc (MulWidth );
6656
- Value *ShortAnd = Builder.CreateAnd (Mul , ShortMask);
6662
+ APInt ShortMask = CI->getValue ().trunc (ResultWidth );
6663
+ Value *ShortAnd = Builder.CreateAnd (Extract , ShortMask);
6657
6664
Value *Zext = Builder.CreateZExt (ShortAnd, BO->getType ());
6658
6665
IC.replaceInstUsesWith (*BO, Zext);
6659
6666
} else {
@@ -7078,7 +7085,7 @@ Instruction *InstCombinerImpl::foldICmpUsingBoolRange(ICmpInst &I) {
7078
7085
// icmp eq X, (zext (icmp ne X, 0)) --> X == 0 || X == 1
7079
7086
// icmp ne X, (zext (icmp ne X, 0)) --> X != 0 && X != 1
7080
7087
// icmp eq X, (sext (icmp ne X, 0)) --> X == 0 || X == -1
7081
- // icmp ne X, (sext (icmp ne X, 0)) --> X != 0 && X = = -1
7088
+ // icmp ne X, (sext (icmp ne X, 0)) --> X != 0 && X ! = -1
7082
7089
return CreateRangeCheck ();
7083
7090
}
7084
7091
} else if (IsSExt ? C->isAllOnes () : C->isOne ()) {
@@ -7791,10 +7798,12 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
7791
7798
}
7792
7799
}
7793
7800
7801
+ // (zext X) + (zext Y) --> llvm.uadd.with.overflow.
7794
7802
// (zext X) * (zext Y) --> llvm.umul.with.overflow.
7795
- if (match (Op0, m_NUWMul (m_ZExt (m_Value (X)), m_ZExt (m_Value (Y)))) &&
7803
+ if ((match (Op0, m_NUWAdd (m_ZExt (m_Value (X)), m_ZExt (m_Value (Y)))) ||
7804
+ match (Op0, m_NUWMul (m_ZExt (m_Value (X)), m_ZExt (m_Value (Y))))) &&
7796
7805
match (Op1, m_APInt (C))) {
7797
- if (Instruction *R = processUMulZExtIdiom (I, Op0, C, *this ))
7806
+ if (Instruction *R = processUZExtIdiom (I, Op0, C, *this ))
7798
7807
return R;
7799
7808
}
7800
7809
0 commit comments