@@ -6533,72 +6533,76 @@ bool InstCombinerImpl::OptimizeOverflowCheck(Instruction::BinaryOps BinaryOp,
65336533 llvm_unreachable (" Unexpected overflow result" );
65346534}
65356535
6536- // / Recognize and process idiom involving test for multiplication
6536+ // / Recognize and process idiom involving test for unsigned
65376537// / overflow.
65386538// /
65396539// / The caller has matched a pattern of the form:
6540+ // / I = cmp u (add(zext A, zext B)), V
65406541// / I = cmp u (mul(zext A, zext B)), V
65416542// / The function checks if this is a test for overflow and if so replaces
6542- // / multiplication with call to 'mul.with.overflow' intrinsic.
6543+ // / the multiplication with a call to the 'umul.with.overflow' intrinsic, or
6544+ // / the addition with a narrow add plus an unsigned overflow compare.
65436545// /
65446546// / \param I Compare instruction.
6545- // / \param MulVal Result of 'mult' instruction. It is one of the arguments of
6546- // / the compare instruction. Must be of integer type.
6547+ // / \param Val Result of add/mul instruction. It is one of the arguments of
6548+ // / the compare instruction. Must be of integer type.
65476549// / \param OtherVal The other argument of compare instruction.
65486550// / \returns Instruction which must replace the compare instruction, NULL if no
65496551// / replacement required.
6550- static Instruction *processUMulZExtIdiom (ICmpInst &I, Value *MulVal ,
6551- const APInt *OtherVal,
6552- InstCombinerImpl &IC) {
6552+ static Instruction *processUZExtIdiom (ICmpInst &I, Value *Val ,
6553+ const APInt *OtherVal,
6554+ InstCombinerImpl &IC) {
65536555 // Don't bother doing this transformation for pointers, don't do it for
65546556 // vectors.
6555- if (!isa<IntegerType>(MulVal ->getType ()))
6557+ if (!isa<IntegerType>(Val ->getType ()))
65566558 return nullptr ;
65576559
6558- auto *MulInstr = dyn_cast<Instruction>(MulVal );
6559- if (!MulInstr )
6560+ auto *Instr = dyn_cast<Instruction>(Val );
6561+ if (!Instr )
65606562 return nullptr ;
6561- assert (MulInstr->getOpcode () == Instruction::Mul);
65626563
6563- auto *LHS = cast<ZExtInst>(MulInstr->getOperand (0 )),
6564- *RHS = cast<ZExtInst>(MulInstr->getOperand (1 ));
6564+ unsigned Opcode = Instr->getOpcode ();
6565+ assert (Opcode == Instruction::Add || Opcode == Instruction::Mul);
6566+
6567+ auto *LHS = cast<ZExtInst>(Instr->getOperand (0 )),
6568+ *RHS = cast<ZExtInst>(Instr->getOperand (1 ));
65656569 assert (LHS->getOpcode () == Instruction::ZExt);
65666570 assert (RHS->getOpcode () == Instruction::ZExt);
65676571 Value *A = LHS->getOperand (0 ), *B = RHS->getOperand (0 );
65686572
6569- // Calculate type and width of the result produced by mul.with.overflow.
6573+ // Calculate type and width of the narrow add or mul.with.overflow result.
65706574 Type *TyA = A->getType (), *TyB = B->getType ();
65716575 unsigned WidthA = TyA->getPrimitiveSizeInBits (),
65726576 WidthB = TyB->getPrimitiveSizeInBits ();
6573- unsigned MulWidth ;
6574- Type *MulType ;
6577+ unsigned ResultWidth ;
6578+ Type *ResultType ;
65756579 if (WidthB > WidthA) {
6576- MulWidth = WidthB;
6577- MulType = TyB;
6580+ ResultWidth = WidthB;
6581+ ResultType = TyB;
65786582 } else {
6579- MulWidth = WidthA;
6580- MulType = TyA;
6583+ ResultWidth = WidthA;
6584+ ResultType = TyA;
65816585 }
65826586
6583- // In order to replace the original mul with a narrower mul.with.overflow,
6584- // all uses must ignore upper bits of the product. The number of used low
6585- // bits must be not greater than the width of mul.with.overflow.
6586- if (MulVal ->hasNUsesOrMore (2 ))
6587- for (User *U : MulVal ->users ()) {
6587+ // In order to replace the original result with a narrower one, all uses must
6588+ // ignore the upper bits of the result. The number of low bits used must not
6589+ // exceed the width of the narrow add or mul.with.overflow.
6590+ if (Val ->hasNUsesOrMore (2 ))
6591+ for (User *U : Val ->users ()) {
65886592 if (U == &I)
65896593 continue ;
65906594 if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
6591- // Check if truncation ignores bits above MulWidth .
6595+ // Check if truncation ignores bits above ResultWidth .
65926596 unsigned TruncWidth = TI->getType ()->getPrimitiveSizeInBits ();
6593- if (TruncWidth > MulWidth )
6597+ if (TruncWidth > ResultWidth )
65946598 return nullptr ;
65956599 } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
6596- // Check if AND ignores bits above MulWidth .
6600+ // Check if AND ignores bits above ResultWidth .
65976601 if (BO->getOpcode () != Instruction::And)
65986602 return nullptr ;
65996603 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand (1 ))) {
66006604 const APInt &CVal = CI->getValue ();
6601- if (CVal.getBitWidth () - CVal.countl_zero () > MulWidth )
6605+ if (CVal.getBitWidth () - CVal.countl_zero () > ResultWidth )
66026606 return nullptr ;
66036607 } else {
66046608 // In this case we could have the operand of the binary operation
@@ -6616,9 +6620,9 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
66166620 switch (I.getPredicate ()) {
66176621 case ICmpInst::ICMP_UGT: {
66186622 // Recognize pattern:
6619- // mulval = mul(zext A, zext B)
6620- // cmp ugt mulval , max
6621- APInt MaxVal = APInt::getMaxValue (MulWidth );
6623+ // val = add/ mul(zext A, zext B)
6624+ // cmp ugt val , max
6625+ APInt MaxVal = APInt::getMaxValue (ResultWidth );
66226626 MaxVal = MaxVal.zext (OtherVal->getBitWidth ());
66236627 if (MaxVal.eq (*OtherVal))
66246628 break ; // Recognized
@@ -6627,9 +6631,9 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
66276631
66286632 case ICmpInst::ICMP_ULT: {
66296633 // Recognize pattern:
6630- // mulval = mul(zext A, zext B)
6631- // cmp ule mulval , max + 1
6632- APInt MaxVal = APInt::getOneBitSet (OtherVal->getBitWidth (), MulWidth );
6634+ // val = add/ mul(zext A, zext B)
6635+ // cmp ult val, max + 1
6636+ APInt MaxVal = APInt::getOneBitSet (OtherVal->getBitWidth (), ResultWidth );
66336637 if (MaxVal.eq (*OtherVal))
66346638 break ; // Recognized
66356639 return nullptr ;
@@ -6640,38 +6644,57 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
66406644 }
66416645
66426646 InstCombiner::BuilderTy &Builder = IC.Builder ;
6643- Builder.SetInsertPoint (MulInstr);
6644-
6645- // Replace: mul(zext A, zext B) --> mul.with.overflow(A, B)
6646- Value *MulA = A, *MulB = B;
6647- if (WidthA < MulWidth)
6648- MulA = Builder.CreateZExt (A, MulType);
6649- if (WidthB < MulWidth)
6650- MulB = Builder.CreateZExt (B, MulType);
6651- CallInst *Call =
6652- Builder.CreateIntrinsic (Intrinsic::umul_with_overflow, MulType,
6653- {MulA, MulB}, /* FMFSource=*/ nullptr , " umul" );
6654- IC.addToWorklist (MulInstr);
6655-
6656- // If there are uses of mul result other than the comparison, we know that
6657- // they are truncation or binary AND. Change them to use result of
6658- // mul.with.overflow and adjust properly mask/size.
6659- if (MulVal->hasNUsesOrMore (2 )) {
6660- Value *Mul = Builder.CreateExtractValue (Call, 0 , " umul.value" );
6661- for (User *U : make_early_inc_range (MulVal->users ())) {
6647+ Builder.SetInsertPoint (Instr);
6648+
6649+ // Replace: add/mul(zext A, zext B) --> canonical add/mul + overflow check
6650+ Value *ResultA = A, *ResultB = B;
6651+ if (WidthA < ResultWidth)
6652+ ResultA = Builder.CreateZExt (A, ResultType);
6653+ if (WidthB < ResultWidth)
6654+ ResultB = Builder.CreateZExt (B, ResultType);
6655+
6656+ Value *ArithResult;
6657+ Value *OverflowCheck;
6658+
6659+ if (Opcode == Instruction::Add) {
6660+ // Canonical add overflow check: add + compare
6661+ ArithResult = Builder.CreateAdd (ResultA, ResultB, " add" );
6662+ // Overflow if result < either operand (for unsigned add)
6663+ if (I.getPredicate () == ICmpInst::ICMP_ULT)
6664+ OverflowCheck =
6665+ Builder.CreateICmpUGE (ArithResult, ResultA, " not.add.overflow" );
6666+ else
6667+ OverflowCheck =
6668+ Builder.CreateICmpULT (ArithResult, ResultA, " add.overflow" );
6669+ } else {
6670+ // For multiplication, the intrinsic is actually the canonical form
6671+ CallInst *Call = Builder.CreateIntrinsic (Intrinsic::umul_with_overflow,
6672+ ResultType, {ResultA, ResultB},
6673+ /* FMFSource=*/ nullptr , " umul" );
6674+ ArithResult = Builder.CreateExtractValue (Call, 0 , " umul.value" );
6675+ OverflowCheck = Builder.CreateExtractValue (Call, 1 , " umul.overflow" );
6676+ if (I.getPredicate () == ICmpInst::ICMP_ULT)
6677+ OverflowCheck = Builder.CreateNot (OverflowCheck);
6678+ }
6679+
6680+ IC.addToWorklist (Instr);
6681+
6682+ // Replace uses of the original add/mul result with the new arithmetic result
6683+ if (Val->hasNUsesOrMore (2 )) {
6684+ for (User *U : make_early_inc_range (Val->users ())) {
66626685 if (U == &I)
66636686 continue ;
66646687 if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
6665- if (TI->getType ()->getPrimitiveSizeInBits () == MulWidth )
6666- IC.replaceInstUsesWith (*TI, Mul );
6688+ if (TI->getType ()->getPrimitiveSizeInBits () == ResultWidth )
6689+ IC.replaceInstUsesWith (*TI, ArithResult );
66676690 else
6668- TI->setOperand (0 , Mul );
6691+ TI->setOperand (0 , ArithResult );
66696692 } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
66706693 assert (BO->getOpcode () == Instruction::And);
6671- // Replace (mul & mask) --> zext (mul.with.overflow & short_mask)
6694+ // Replace (Val & mask) --> zext (ArithResult & short_mask)
66726695 ConstantInt *CI = cast<ConstantInt>(BO->getOperand (1 ));
6673- APInt ShortMask = CI->getValue ().trunc (MulWidth );
6674- Value *ShortAnd = Builder.CreateAnd (Mul , ShortMask);
6696+ APInt ShortMask = CI->getValue ().trunc (ResultWidth );
6697+ Value *ShortAnd = Builder.CreateAnd (ArithResult , ShortMask);
66756698 Value *Zext = Builder.CreateZExt (ShortAnd, BO->getType ());
66766699 IC.replaceInstUsesWith (*BO, Zext);
66776700 } else {
@@ -6681,14 +6704,7 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
66816704 }
66826705 }
66836706
6684- // The original icmp gets replaced with the overflow value, maybe inverted
6685- // depending on predicate.
6686- if (I.getPredicate () == ICmpInst::ICMP_ULT) {
6687- Value *Res = Builder.CreateExtractValue (Call, 1 );
6688- return BinaryOperator::CreateNot (Res);
6689- }
6690-
6691- return ExtractValueInst::Create (Call, 1 );
6707+ return IC.replaceInstUsesWith (I, OverflowCheck);
66926708}
66936709
66946710// / When performing a comparison against a constant, it is possible that not all
@@ -7832,10 +7848,12 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
78327848 }
78337849 }
78347850
7851+ // (zext X) + (zext Y) --> add + overflow check.
78357852 // (zext X) * (zext Y) --> llvm.umul.with.overflow.
7836- if (match (Op0, m_NUWMul (m_ZExt (m_Value (X)), m_ZExt (m_Value (Y)))) &&
7853+ if ((match (Op0, m_NUWAdd (m_ZExt (m_Value (X)), m_ZExt (m_Value (Y)))) ||
7854+ match (Op0, m_NUWMul (m_ZExt (m_Value (X)), m_ZExt (m_Value (Y))))) &&
78377855 match (Op1, m_APInt (C))) {
7838- if (Instruction *R = processUMulZExtIdiom (I, Op0, C, *this ))
7856+ if (Instruction *R = processUZExtIdiom (I, Op0, C, *this ))
78397857 return R;
78407858 }
78417859
0 commit comments