Commit 2bc36d4

[InstCombine] Canonicalize zext+overflow check to overflow check if zext's only purpose is to check overflow
Change processUMulZExtIdiom to also support adds, since the idiom is the same, except with add instead of mul. Alive2: https://alive2.llvm.org/ce/z/SsB4AK
1 parent 0c6c532 commit 2bc36d4
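
For illustration, a minimal before/after sketch of the add form of the idiom, adapted from the @uadd_with_zext test in llvm/test/Transforms/InstCombine/uadd-with-overflow.ll below (the exact final IR depends on which later folds fire):

    ; Before: the add was widened to i64 solely to test for 32-bit overflow.
    %conv  = zext i32 %x to i64
    %conv1 = zext i32 %y to i64
    %add   = add nuw nsw i64 %conv, %conv1
    %cmp   = icmp ugt i64 %add, 4294967295

    ; After: the check is expressed on the original i32 values; further folds
    ; reduce it to a compare against the complement of %x.
    %tmp = xor i32 %x, -1
    %cmp = icmp ugt i32 %y, %tmp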

File tree

4 files changed, +101 -96 lines changed
llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

Lines changed: 88 additions & 70 deletions
@@ -6533,72 +6533,76 @@ bool InstCombinerImpl::OptimizeOverflowCheck(Instruction::BinaryOps BinaryOp,
     llvm_unreachable("Unexpected overflow result");
 }
 
-/// Recognize and process idiom involving test for multiplication
+/// Recognize and process idiom involving test for unsigned
 /// overflow.
 ///
 /// The caller has matched a pattern of the form:
+///   I = cmp u (add(zext A, zext B), V
 ///   I = cmp u (mul(zext A, zext B), V
 /// The function checks if this is a test for overflow and if so replaces
-/// multiplication with call to 'mul.with.overflow' intrinsic.
+/// addition/multiplication with call to the umul intrinsic or the canonical
+/// form of uadd overflow.
 ///
 /// \param I Compare instruction.
-/// \param MulVal Result of 'mult' instruction. It is one of the arguments of
-///               the compare instruction. Must be of integer type.
+/// \param Val Result of add/mul instruction. It is one of the arguments of
+///            the compare instruction. Must be of integer type.
 /// \param OtherVal The other argument of compare instruction.
 /// \returns Instruction which must replace the compare instruction, NULL if no
 ///          replacement required.
-static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
-                                         const APInt *OtherVal,
-                                         InstCombinerImpl &IC) {
+static Instruction *processUZExtIdiom(ICmpInst &I, Value *Val,
+                                      const APInt *OtherVal,
+                                      InstCombinerImpl &IC) {
   // Don't bother doing this transformation for pointers, don't do it for
   // vectors.
-  if (!isa<IntegerType>(MulVal->getType()))
+  if (!isa<IntegerType>(Val->getType()))
     return nullptr;
 
-  auto *MulInstr = dyn_cast<Instruction>(MulVal);
-  if (!MulInstr)
+  auto *Instr = dyn_cast<Instruction>(Val);
+  if (!Instr)
     return nullptr;
-  assert(MulInstr->getOpcode() == Instruction::Mul);
 
-  auto *LHS = cast<ZExtInst>(MulInstr->getOperand(0)),
-       *RHS = cast<ZExtInst>(MulInstr->getOperand(1));
+  unsigned Opcode = Instr->getOpcode();
+  assert(Opcode == Instruction::Add || Opcode == Instruction::Mul);
+
+  auto *LHS = cast<ZExtInst>(Instr->getOperand(0)),
+       *RHS = cast<ZExtInst>(Instr->getOperand(1));
   assert(LHS->getOpcode() == Instruction::ZExt);
   assert(RHS->getOpcode() == Instruction::ZExt);
   Value *A = LHS->getOperand(0), *B = RHS->getOperand(0);
 
-  // Calculate type and width of the result produced by mul.with.overflow.
+  // Calculate type and width of the result produced by add/mul.with.overflow.
   Type *TyA = A->getType(), *TyB = B->getType();
   unsigned WidthA = TyA->getPrimitiveSizeInBits(),
            WidthB = TyB->getPrimitiveSizeInBits();
-  unsigned MulWidth;
-  Type *MulType;
+  unsigned ResultWidth;
+  Type *ResultType;
   if (WidthB > WidthA) {
-    MulWidth = WidthB;
-    MulType = TyB;
+    ResultWidth = WidthB;
+    ResultType = TyB;
   } else {
-    MulWidth = WidthA;
-    MulType = TyA;
+    ResultWidth = WidthA;
+    ResultType = TyA;
   }
 
-  // In order to replace the original mul with a narrower mul.with.overflow,
-  // all uses must ignore upper bits of the product. The number of used low
-  // bits must be not greater than the width of mul.with.overflow.
-  if (MulVal->hasNUsesOrMore(2))
-    for (User *U : MulVal->users()) {
+  // In order to replace the original result with a narrower one, all uses must
+  // ignore upper bits of the result. The number of used low bits must be not
+  // greater than the width of add or mul.with.overflow.
+  if (Val->hasNUsesOrMore(2))
+    for (User *U : Val->users()) {
      if (U == &I)
        continue;
      if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
-        // Check if truncation ignores bits above MulWidth.
+        // Check if truncation ignores bits above ResultWidth.
        unsigned TruncWidth = TI->getType()->getPrimitiveSizeInBits();
-        if (TruncWidth > MulWidth)
+        if (TruncWidth > ResultWidth)
          return nullptr;
      } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
-        // Check if AND ignores bits above MulWidth.
+        // Check if AND ignores bits above ResultWidth.
        if (BO->getOpcode() != Instruction::And)
          return nullptr;
        if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1))) {
          const APInt &CVal = CI->getValue();
-          if (CVal.getBitWidth() - CVal.countl_zero() > MulWidth)
+          if (CVal.getBitWidth() - CVal.countl_zero() > ResultWidth)
            return nullptr;
        } else {
          // In this case we could have the operand of the binary operation
@@ -6616,9 +6620,9 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
   switch (I.getPredicate()) {
   case ICmpInst::ICMP_UGT: {
     // Recognize pattern:
-    //   mulval = mul(zext A, zext B)
-    //   cmp ugt mulval, max
-    APInt MaxVal = APInt::getMaxValue(MulWidth);
+    //   val = add/mul(zext A, zext B)
+    //   cmp ugt val, max
+    APInt MaxVal = APInt::getMaxValue(ResultWidth);
     MaxVal = MaxVal.zext(OtherVal->getBitWidth());
     if (MaxVal.eq(*OtherVal))
       break; // Recognized
@@ -6627,9 +6631,9 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
 
   case ICmpInst::ICMP_ULT: {
     // Recognize pattern:
-    //   mulval = mul(zext A, zext B)
-    //   cmp ule mulval, max + 1
-    APInt MaxVal = APInt::getOneBitSet(OtherVal->getBitWidth(), MulWidth);
+    //   val = add/mul(zext A, zext B)
+    //   cmp ule val, max + 1
+    APInt MaxVal = APInt::getOneBitSet(OtherVal->getBitWidth(), ResultWidth);
     if (MaxVal.eq(*OtherVal))
       break; // Recognized
     return nullptr;
@@ -6640,38 +6644,57 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
   }
 
   InstCombiner::BuilderTy &Builder = IC.Builder;
-  Builder.SetInsertPoint(MulInstr);
-
-  // Replace: mul(zext A, zext B) --> mul.with.overflow(A, B)
-  Value *MulA = A, *MulB = B;
-  if (WidthA < MulWidth)
-    MulA = Builder.CreateZExt(A, MulType);
-  if (WidthB < MulWidth)
-    MulB = Builder.CreateZExt(B, MulType);
-  CallInst *Call =
-      Builder.CreateIntrinsic(Intrinsic::umul_with_overflow, MulType,
-                              {MulA, MulB}, /*FMFSource=*/nullptr, "umul");
-  IC.addToWorklist(MulInstr);
-
-  // If there are uses of mul result other than the comparison, we know that
-  // they are truncation or binary AND. Change them to use result of
-  // mul.with.overflow and adjust properly mask/size.
-  if (MulVal->hasNUsesOrMore(2)) {
-    Value *Mul = Builder.CreateExtractValue(Call, 0, "umul.value");
-    for (User *U : make_early_inc_range(MulVal->users())) {
+  Builder.SetInsertPoint(Instr);
+
+  // Replace: add/mul(zext A, zext B) --> canonical add/mul + overflow check
+  Value *ResultA = A, *ResultB = B;
+  if (WidthA < ResultWidth)
+    ResultA = Builder.CreateZExt(A, ResultType);
+  if (WidthB < ResultWidth)
+    ResultB = Builder.CreateZExt(B, ResultType);
+
+  Value *ArithResult;
+  Value *OverflowCheck;
+
+  if (Opcode == Instruction::Add) {
+    // Canonical add overflow check: add + compare
+    ArithResult = Builder.CreateAdd(ResultA, ResultB, "add");
+    // Overflow if result < either operand (for unsigned add)
+    if (I.getPredicate() == ICmpInst::ICMP_ULT)
+      OverflowCheck =
+          Builder.CreateICmpUGE(ArithResult, ResultA, "not.add.overflow");
+    else
+      OverflowCheck =
+          Builder.CreateICmpULT(ArithResult, ResultA, "add.overflow");
+  } else {
+    // For multiplication, the intrinsic is actually the canonical form
+    CallInst *Call = Builder.CreateIntrinsic(Intrinsic::umul_with_overflow,
+                                             ResultType, {ResultA, ResultB},
+                                             /*FMFSource=*/nullptr, "umul");
+    ArithResult = Builder.CreateExtractValue(Call, 0, "umul.value");
+    OverflowCheck = Builder.CreateExtractValue(Call, 1, "umul.overflow");
+    if (I.getPredicate() == ICmpInst::ICMP_ULT)
+      OverflowCheck = Builder.CreateNot(OverflowCheck);
+  }
+
+  IC.addToWorklist(Instr);
+
+  // Replace uses of the original add/mul result with the new arithmetic result
+  if (Val->hasNUsesOrMore(2)) {
+    for (User *U : make_early_inc_range(Val->users())) {
      if (U == &I)
        continue;
      if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
-        if (TI->getType()->getPrimitiveSizeInBits() == MulWidth)
-          IC.replaceInstUsesWith(*TI, Mul);
+        if (TI->getType()->getPrimitiveSizeInBits() == ResultWidth)
+          IC.replaceInstUsesWith(*TI, ArithResult);
        else
-          TI->setOperand(0, Mul);
+          TI->setOperand(0, ArithResult);
      } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) {
        assert(BO->getOpcode() == Instruction::And);
-        // Replace (mul & mask) --> zext (mul.with.overflow & short_mask)
+        // Replace (ArithResult & mask) --> zext (ArithResult & short_mask)
        ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1));
-        APInt ShortMask = CI->getValue().trunc(MulWidth);
-        Value *ShortAnd = Builder.CreateAnd(Mul, ShortMask);
+        APInt ShortMask = CI->getValue().trunc(ResultWidth);
+        Value *ShortAnd = Builder.CreateAnd(ArithResult, ShortMask);
        Value *Zext = Builder.CreateZExt(ShortAnd, BO->getType());
        IC.replaceInstUsesWith(*BO, Zext);
      } else {
@@ -6681,14 +6704,7 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
     }
   }
 
-  // The original icmp gets replaced with the overflow value, maybe inverted
-  // depending on predicate.
-  if (I.getPredicate() == ICmpInst::ICMP_ULT) {
-    Value *Res = Builder.CreateExtractValue(Call, 1);
-    return BinaryOperator::CreateNot(Res);
-  }
-
-  return ExtractValueInst::Create(Call, 1);
+  return IC.replaceInstUsesWith(I, OverflowCheck);
 }
 
 /// When performing a comparison against a constant, it is possible that not all
@@ -7832,10 +7848,12 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
     }
   }
 
+  // (zext X) + (zext Y) --> add + overflow check.
   // (zext X) * (zext Y) --> llvm.umul.with.overflow.
-  if (match(Op0, m_NUWMul(m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) &&
+  if ((match(Op0, m_NUWAdd(m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) ||
+       match(Op0, m_NUWMul(m_ZExt(m_Value(X)), m_ZExt(m_Value(Y))))) &&
      match(Op1, m_APInt(C))) {
-    if (Instruction *R = processUMulZExtIdiom(I, Op0, C, *this))
+    if (Instruction *R = processUZExtIdiom(I, Op0, C, *this))
      return R;
   }
 
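Note that the add path above emits no intrinsic: it produces the canonical unsigned-add overflow pattern, relying on the fact that an unsigned sum wraps below either operand exactly when the add overflows. A minimal sketch of what the Instruction::Add branch builds for the ICMP_UGT form, before any further folding (value names follow the "add"/"add.overflow" labels in the code; compare the uadd_with_zext_use_and test below):

    %add          = add i32 %x, %y
    %add.overflow = icmp ult i32 %add, %x
    ; For the ICMP_ULT form of the original compare, the inverted check
    ; (icmp uge, named "not.add.overflow") is emitted instead.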

llvm/test/Transforms/InstCombine/overflow-mul.ll

Lines changed: 2 additions & 2 deletions
@@ -286,8 +286,8 @@ define i32 @extra_and_use(i32 %x, i32 %y) {
 ; CHECK-LABEL: @extra_and_use(
 ; CHECK-NEXT:    [[UMUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 [[X:%.*]], i32 [[Y:%.*]])
 ; CHECK-NEXT:    [[UMUL_VALUE:%.*]] = extractvalue { i32, i1 } [[UMUL]], 0
-; CHECK-NEXT:    [[AND:%.*]] = zext i32 [[UMUL_VALUE]] to i64
 ; CHECK-NEXT:    [[OVERFLOW:%.*]] = extractvalue { i32, i1 } [[UMUL]], 1
+; CHECK-NEXT:    [[AND:%.*]] = zext i32 [[UMUL_VALUE]] to i64
 ; CHECK-NEXT:    call void @use.i64(i64 [[AND]])
 ; CHECK-NEXT:    [[RETVAL:%.*]] = zext i1 [[OVERFLOW]] to i32
 ; CHECK-NEXT:    ret i32 [[RETVAL]]
@@ -306,9 +306,9 @@ define i32 @extra_and_use_small_mask(i32 %x, i32 %y) {
 ; CHECK-LABEL: @extra_and_use_small_mask(
 ; CHECK-NEXT:    [[UMUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 [[X:%.*]], i32 [[Y:%.*]])
 ; CHECK-NEXT:    [[UMUL_VALUE:%.*]] = extractvalue { i32, i1 } [[UMUL]], 0
+; CHECK-NEXT:    [[OVERFLOW:%.*]] = extractvalue { i32, i1 } [[UMUL]], 1
 ; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[UMUL_VALUE]], 268435455
 ; CHECK-NEXT:    [[AND:%.*]] = zext nneg i32 [[TMP1]] to i64
-; CHECK-NEXT:    [[OVERFLOW:%.*]] = extractvalue { i32, i1 } [[UMUL]], 1
 ; CHECK-NEXT:    call void @use.i64(i64 [[AND]])
 ; CHECK-NEXT:    [[RETVAL:%.*]] = zext i1 [[OVERFLOW]] to i32
 ; CHECK-NEXT:    ret i32 [[RETVAL]]

llvm/test/Transforms/InstCombine/saturating-add-sub.ll

Lines changed: 3 additions & 11 deletions
@@ -2352,11 +2352,7 @@ define i8 @fold_add_umax_to_usub_multiuse(i8 %a) {
 
 define i32 @uadd_with_zext(i32 %x, i32 %y) {
 ; CHECK-LABEL: @uadd_with_zext(
-; CHECK-NEXT:    [[CONV:%.*]] = zext i32 [[X:%.*]] to i64
-; CHECK-NEXT:    [[CONV1:%.*]] = zext i32 [[Y:%.*]] to i64
-; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i64 [[CONV]], [[CONV1]]
-; CHECK-NEXT:    [[COND1:%.*]] = call i64 @llvm.umin.i64(i64 [[ADD]], i64 4294967295)
-; CHECK-NEXT:    [[COND:%.*]] = trunc nuw i64 [[COND1]] to i32
+; CHECK-NEXT:    [[COND:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[X:%.*]], i32 [[Y:%.*]])
 ; CHECK-NEXT:    ret i32 [[COND]]
 ;
   %conv = zext i32 %x to i64
@@ -2370,13 +2366,9 @@ define i32 @uadd_with_zext(i32 %x, i32 %y) {
 
 define i32 @uadd_with_zext_multi_use(i32 %x, i32 %y) {
 ; CHECK-LABEL: @uadd_with_zext_multi_use(
-; CHECK-NEXT:    [[CONV:%.*]] = zext i32 [[X:%.*]] to i64
-; CHECK-NEXT:    [[CONV1:%.*]] = zext i32 [[Y:%.*]] to i64
-; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i64 [[CONV]], [[CONV1]]
-; CHECK-NEXT:    [[TRUNCADD:%.*]] = trunc i64 [[ADD]] to i32
+; CHECK-NEXT:    [[TRUNCADD:%.*]] = add i32 [[X:%.*]], [[Y:%.*]]
 ; CHECK-NEXT:    call void @usei32(i32 [[TRUNCADD]])
-; CHECK-NEXT:    [[COND1:%.*]] = call i64 @llvm.umin.i64(i64 [[ADD]], i64 4294967295)
-; CHECK-NEXT:    [[COND:%.*]] = trunc nuw i64 [[COND1]] to i32
+; CHECK-NEXT:    [[COND:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[X]], i32 [[Y]])
 ; CHECK-NEXT:    ret i32 [[COND]]
 ;
   %conv = zext i32 %x to i64

llvm/test/Transforms/InstCombine/uadd-with-overflow.ll

Lines changed: 8 additions & 13 deletions
@@ -150,10 +150,8 @@ define { <2 x i32>, <2 x i1> } @fold_simple_splat_constant_with_or_fail(<2 x i32
 
 define i32 @uadd_with_zext(i32 %x, i32 %y) {
 ; CHECK-LABEL: @uadd_with_zext(
-; CHECK-NEXT:    [[CONV:%.*]] = zext i32 [[X:%.*]] to i64
-; CHECK-NEXT:    [[CONV1:%.*]] = zext i32 [[Y:%.*]] to i64
-; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i64 [[CONV]], [[CONV1]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp samesign ugt i64 [[ADD]], 4294967295
+; CHECK-NEXT:    [[TMP1:%.*]] = xor i32 [[X:%.*]], -1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i32 [[Y:%.*]], [[TMP1]]
 ; CHECK-NEXT:    [[COND:%.*]] = zext i1 [[CMP]] to i32
 ; CHECK-NEXT:    ret i32 [[COND]]
 ;
@@ -167,12 +165,11 @@ define i32 @uadd_with_zext(i32 %x, i32 %y) {
 
 define i32 @uadd_with_zext_use_and(i32 %x, i32 %y) {
 ; CHECK-LABEL: @uadd_with_zext_use_and(
-; CHECK-NEXT:    [[CONV:%.*]] = zext i32 [[X:%.*]] to i64
-; CHECK-NEXT:    [[CONV1:%.*]] = zext i32 [[Y:%.*]] to i64
-; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i64 [[CONV]], [[CONV1]]
-; CHECK-NEXT:    [[AND:%.*]] = and i64 [[ADD]], 65535
+; CHECK-NEXT:    [[UADD_VALUE:%.*]] = add i32 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[UADD_VALUE]], [[X]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[UADD_VALUE]], 65535
+; CHECK-NEXT:    [[AND:%.*]] = zext nneg i32 [[TMP1]] to i64
 ; CHECK-NEXT:    call void @usei64(i64 [[AND]])
-; CHECK-NEXT:    [[CMP:%.*]] = icmp samesign ugt i64 [[ADD]], 4294967295
 ; CHECK-NEXT:    [[COND:%.*]] = zext i1 [[CMP]] to i32
 ; CHECK-NEXT:    ret i32 [[COND]]
 ;
@@ -188,10 +185,8 @@ define i32 @uadd_with_zext_use_and(i32 %x, i32 %y) {
 
 define i32 @uadd_with_zext_inverse(i32 %x, i32 %y) {
 ; CHECK-LABEL: @uadd_with_zext_inverse(
-; CHECK-NEXT:    [[CONV:%.*]] = zext i32 [[X:%.*]] to i64
-; CHECK-NEXT:    [[CONV1:%.*]] = zext i32 [[Y:%.*]] to i64
-; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i64 [[CONV]], [[CONV1]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp samesign ult i64 [[ADD]], 4294967296
+; CHECK-NEXT:    [[TMP1:%.*]] = xor i32 [[X:%.*]], -1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ule i32 [[Y:%.*]], [[TMP1]]
 ; CHECK-NEXT:    [[COND:%.*]] = zext i1 [[CMP]] to i32
 ; CHECK-NEXT:    ret i32 [[COND]]
 ;
