[InstCombine] Combine or-disjoint (and->mul), (and->mul) to and->mul #136013
@llvm/pr-subscribers-llvm-analysis @llvm/pr-subscribers-llvm-transforms

Author: Jeffrey Byrnes (jrbyrnes)

Changes

The canonical pattern for bitmasked mul is currently

%val = and %x, %bitMask // where %bitMask is some constant

In certain cases, where we are combining multiple of these bitmasked muls with common factors, we are able to optimize into and->mul (see #135274).

This optimization lends itself to further optimizations. This PR addresses one such optimization. In cases where we have

or-disjoint ( mul(and (X, C1), D) , mul (and (X, C2), D))

we can combine into

mul( and (X, (C1 + C2)), D)

provided C1 and C2 are disjoint.

Generalized proof: https://alive2.llvm.org/ce/z/MQYMui

Full diff: https://github.com/llvm/llvm-project/pull/136013.diff

2 Files Affected:
- llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
- llvm/test/Transforms/InstCombine/or.ll
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 6cc241781d112..206131ab4a6a7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3643,6 +3643,27 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
foldAddLikeCommutative(I.getOperand(1), I.getOperand(0),
/*NSW=*/true, /*NUW=*/true))
return R;
+
+ Value *LHSOp = nullptr, *RHSOp = nullptr;
+ const APInt *LHSConst = nullptr, *RHSConst = nullptr;
+
+  // ((X & C1) * D) + ((X & C2) * D) -> (X & (C1 + C2)) * D
+ if (match(I.getOperand(0), m_Mul(m_Value(LHSOp), m_APInt(LHSConst))) &&
+ match(I.getOperand(1), m_Mul(m_Value(RHSOp), m_APInt(RHSConst))) &&
+ LHSConst == RHSConst) {
+ Value *LHSBase = nullptr, *RHSBase = nullptr;
+ const APInt *LHSMask = nullptr, *RHSMask = nullptr;
+ if (match(LHSOp, m_And(m_Value(LHSBase), m_APInt(LHSMask))) &&
+ match(RHSOp, m_And(m_Value(RHSBase), m_APInt(RHSMask))) &&
+ LHSBase == RHSBase &&
+ ((*LHSMask & *RHSMask) == APInt::getZero(LHSMask->getBitWidth()))) {
+ auto NewAnd = Builder.CreateAnd(
+ LHSBase, ConstantInt::get(LHSOp->getType(), (*LHSMask + *RHSMask)));
+
+ return BinaryOperator::CreateMul(
+ NewAnd, ConstantInt::get(NewAnd->getType(), *LHSConst));
+ }
+ }
}
Value *X, *Y;
diff --git a/llvm/test/Transforms/InstCombine/or.ll b/llvm/test/Transforms/InstCombine/or.ll
index 95f89e4ce11cd..777387cc662d6 100644
--- a/llvm/test/Transforms/InstCombine/or.ll
+++ b/llvm/test/Transforms/InstCombine/or.ll
@@ -1281,10 +1281,10 @@ define <16 x i1> @test51(<16 x i1> %arg, <16 x i1> %arg1) {
; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <16 x i1> [[ARG:%.*]], <16 x i1> [[ARG1:%.*]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 20, i32 5, i32 6, i32 23, i32 24, i32 9, i32 10, i32 27, i32 28, i32 29, i32 30, i32 31>
; CHECK-NEXT: ret <16 x i1> [[TMP3]]
;
- %tmp = and <16 x i1> %arg, <i1 true, i1 true, i1 true, i1 true, i1 false, i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false>
- %tmp2 = and <16 x i1> %arg1, <i1 false, i1 false, i1 false, i1 false, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 true, i1 true, i1 true>
- %tmp3 = or <16 x i1> %tmp, %tmp2
- ret <16 x i1> %tmp3
+ %out = and <16 x i1> %arg, <i1 true, i1 true, i1 true, i1 true, i1 false, i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false>
+ %out2 = and <16 x i1> %arg1, <i1 false, i1 false, i1 false, i1 false, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 true, i1 true, i1 true>
+ %out3 = or <16 x i1> %out, %out2
+ ret <16 x i1> %out3
}
; This would infinite loop because it reaches a transform
@@ -2035,3 +2035,82 @@ define i32 @or_xor_and_commuted3(i32 %x, i32 %y, i32 %z) {
%or1 = or i32 %xor, %yy
ret i32 %or1
}
+
+define i32 @or_combine_mul_and1(i32 %in) {
+; CHECK-LABEL: @or_combine_mul_and1(
+; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 6
+; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT: ret i32 [[OUT]]
+;
+ %bitop0 = and i32 %in, 2
+ %out0 = mul i32 %bitop0, 72
+ %bitop1 = and i32 %in, 4
+ %out1 = mul i32 %bitop1, 72
+ %out = or disjoint i32 %out0, %out1
+ ret i32 %out
+}
+
+define i32 @or_combine_mul_and2(i32 %in) {
+; CHECK-LABEL: @or_combine_mul_and2(
+; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 10
+; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT: ret i32 [[OUT]]
+;
+ %bitop0 = and i32 %in, 2
+ %out0 = mul i32 %bitop0, 72
+ %bitop1 = and i32 %in, 8
+ %out1 = mul i32 %bitop1, 72
+ %out = or disjoint i32 %out0, %out1
+ ret i32 %out
+}
+
+define i32 @or_combine_mul_and_diff_factor(i32 %in) {
+; CHECK-LABEL: @or_combine_mul_and_diff_factor(
+; CHECK-NEXT: [[BITOP0:%.*]] = and i32 [[IN:%.*]], 2
+; CHECK-NEXT: [[TMP0:%.*]] = mul nuw nsw i32 [[BITOP0]], 36
+; CHECK-NEXT: [[BITOP1:%.*]] = and i32 [[IN]], 4
+; CHECK-NEXT: [[TMP1:%.*]] = mul nuw nsw i32 [[BITOP1]], 72
+; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TMP0]], [[TMP1]]
+; CHECK-NEXT: ret i32 [[OUT]]
+;
+ %bitop0 = and i32 %in, 2
+ %out0 = mul i32 %bitop0, 36
+ %bitop1 = and i32 %in, 4
+ %out1 = mul i32 %bitop1, 72
+ %out = or disjoint i32 %out0, %out1
+ ret i32 %out
+}
+
+define i32 @or_combine_mul_and_diff_base(i32 %in, i32 %in1) {
+; CHECK-LABEL: @or_combine_mul_and_diff_base(
+; CHECK-NEXT: [[BITOP0:%.*]] = and i32 [[IN:%.*]], 2
+; CHECK-NEXT: [[TMP0:%.*]] = mul nuw nsw i32 [[BITOP0]], 72
+; CHECK-NEXT: [[BITOP1:%.*]] = and i32 [[IN1:%.*]], 4
+; CHECK-NEXT: [[TMP1:%.*]] = mul nuw nsw i32 [[BITOP1]], 72
+; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TMP0]], [[TMP1]]
+; CHECK-NEXT: ret i32 [[OUT]]
+;
+ %bitop0 = and i32 %in, 2
+ %out0 = mul i32 %bitop0, 72
+ %bitop1 = and i32 %in1, 4
+ %out1 = mul i32 %bitop1, 72
+ %out = or disjoint i32 %out0, %out1
+ ret i32 %out
+}
+
+define i32 @or_combine_mul_and_decomposed(i32 %in) {
+; CHECK-LABEL: @or_combine_mul_and_decomposed(
+; CHECK-NEXT: [[TMP2:%.*]] = trunc i32 [[IN:%.*]] to i1
+; CHECK-NEXT: [[OUT0:%.*]] = select i1 [[TMP2]], i32 72, i32 0
+; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN]], 4
+; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT: [[OUT1:%.*]] = or disjoint i32 [[OUT0]], [[OUT]]
+; CHECK-NEXT: ret i32 [[OUT1]]
+;
+ %bitop0 = and i32 %in, 1
+ %out0 = mul i32 %bitop0, 72
+ %bitop1 = and i32 %in, 4
+ %out1 = mul i32 %bitop1, 72
+ %out = or disjoint i32 %out0, %out1
+ ret i32 %out
+}
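The tests above pin down the expected output. As an additional sanity check on the identity itself, independent of the Alive2 proof linked in the description, here is a small standalone C++ program that verifies it exhaustively at i8 width. This is illustrative only and not part of the patch; the sample factors mirror the constants used in the tests.

```c++
// Illustrative standalone check (not part of the patch): exhaustively verify
// the identity behind the fold at i8 width, for all X and all disjoint mask
// pairs (C1, C2), with a few sample common factors D:
//   ((X & C1) * D) + ((X & C2) * D) == (X & (C1 + C2)) * D   when C1 & C2 == 0
#include <cstdint>
#include <cstdio>
#include <initializer_list>

int main() {
  for (unsigned C1 = 0; C1 < 256; ++C1)
    for (unsigned C2 = 0; C2 < 256; ++C2) {
      if (C1 & C2)
        continue; // the fold requires disjoint masks
      for (unsigned D : {1u, 72u, 255u})
        for (unsigned X = 0; X < 256; ++X) {
          // Assignment to uint8_t reduces mod 2^8, matching wrapping
          // (no-flags) IR arithmetic.
          uint8_t LHS = (X & C1) * D + (X & C2) * D;
          uint8_t RHS = (X & (C1 + C2)) * D;
          if (LHS != RHS) {
            printf("mismatch: X=%u C1=%u C2=%u D=%u\n", X, C1, C2, D);
            return 1;
          }
        }
    }
  printf("identity holds for all i8 X and all disjoint C1, C2\n");
  return 0;
}
```

The check passes because disjoint masks make (X & C1) + (X & C2) carry-free, so it equals X & (C1 | C2), which equals X & (C1 + C2); multiplying both sides by D preserves the equality even modulo 2^n.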
I plan to have 3 or so extensions on #135274 -- the plan is to resolve merge conflicts by extracting this code into a separate function.
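For reference, a rough sketch of what that extraction might look like. The helper name and signature here are hypothetical, not from the patch; the body is just the matching logic from the diff above, repackaged with early exits.

```c++
// Hypothetical helper (illustrative sketch, assuming the usual InstCombine
// headers and the PatternMatch namespace from InstCombineAndOrXor.cpp):
// fold or-disjoint (mul (and X, C1), D), (mul (and X, C2), D)
//   -> mul (and X, (C1 + C2)), D
static Instruction *
foldDisjointOrOfBitmaskedMuls(BinaryOperator &I,
                              InstCombiner::BuilderTy &Builder) {
  Value *LHSOp = nullptr, *RHSOp = nullptr;
  const APInt *LHSConst = nullptr, *RHSConst = nullptr;
  // Both sides must be muls by the same constant factor; ConstantInts are
  // uniqued per context, so pointer equality suffices here.
  if (!match(I.getOperand(0), m_Mul(m_Value(LHSOp), m_APInt(LHSConst))) ||
      !match(I.getOperand(1), m_Mul(m_Value(RHSOp), m_APInt(RHSConst))) ||
      LHSConst != RHSConst)
    return nullptr;

  Value *LHSBase = nullptr, *RHSBase = nullptr;
  const APInt *LHSMask = nullptr, *RHSMask = nullptr;
  // Both mul operands must mask the same base value with disjoint masks.
  if (!match(LHSOp, m_And(m_Value(LHSBase), m_APInt(LHSMask))) ||
      !match(RHSOp, m_And(m_Value(RHSBase), m_APInt(RHSMask))) ||
      LHSBase != RHSBase || !(*LHSMask & *RHSMask).isZero())
    return nullptr;

  Value *NewAnd = Builder.CreateAnd(
      LHSBase, ConstantInt::get(LHSOp->getType(), *LHSMask + *RHSMask));
  return BinaryOperator::CreateMul(
      NewAnd, ConstantInt::get(NewAnd->getType(), *LHSConst));
}
```

visitOr would then replace the inlined block with something like `if (Instruction *R = foldDisjointOrOfBitmaskedMuls(I, Builder)) return R;`, which keeps the merge surface small for the follow-up extensions.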
Converted to draft / planned changes: on second thought, it makes more sense to integrate this with #135274.
Force-push to rebase for #136367.
Ping |
We may have cases where the […]. We can likely improve […].
@jrbyrnes Does this pattern exist in real-world applications? Isn't your motivating case (#133139 (comment)) solved by #135274?
@dtcxzyw The simplified example in #133139 (comment) is solved by #135274. However, that example is very simplified and does not cover all the cases in my real-world application. I need to implement 2 or 3 extensions to that PR (including this one) in order to resolve my issue.
Force-push to rebase for 2112379.
Ping -- this is the second in a chain of 4 PRs that are needed to solve an urgent issue. The solution has already been stalled for some time.
Could you please remind me what the larger context for these patches was?
AI workloads are bringing in a new feature called linear layout (https://arxiv.org/html/2505.23819v1). The effect of this feature is to rework address calculations such that we are using […]. The problem is that […].

The test in #137721 has a reduced example of this, and I've also included some IR in https://discourse.llvm.org/t/rfc-computeknownbits-recursion-depth/85962 . The problem is more general, and there are different common variants of address formulations that aren't included in these examples, but this gives a basic idea.

I think that, from a solution perspective, it would be best to change the way the recursion depth works such that we always do these conversions. However, I realize there may be some concerns with that approach, so I'm also working on an approach that optimizes the address calculation chains such that they are compatible with the recursion depth. That is where this stack of InstCombine patches comes in: they clean up the intermediate code in the address calculations so we can convert these […].