[InstCombine] Combine or-disjoint (and->mul), (and->mul) to and->mul #136013


Open. jrbyrnes wants to merge 4 commits into main from ICAndSelExtendCase0.

Conversation

jrbyrnes
Contributor

@jrbyrnes jrbyrnes commented Apr 16, 2025

The canonical pattern for a bitmasked mul is currently:

%val = and %x, %bitMask // where %bitMask is some constant
%cmp = icmp eq %val, 0
%sel = select %cmp, 0, %C // where %C is some constant = C' * %bitMask

In certain cases, when combining multiple of these bitmasked muls that share a common factor, we can optimize them into and->mul (see #135274 ).
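
For illustration (constants chosen for the example, not taken from the patch), two such selects over disjoint masks of %x with the common factor 72

%val0 = and i32 %x, 2
%cmp0 = icmp eq i32 %val0, 0
%sel0 = select i1 %cmp0, i32 0, i32 144 ; 144 = 72 * 2
%val1 = and i32 %x, 4
%cmp1 = icmp eq i32 %val1, 0
%sel1 = select i1 %cmp1, i32 0, i32 288 ; 288 = 72 * 4
%or = or disjoint i32 %sel0, %sel1

become a single and->mul:

%masked = and i32 %x, 6
%or = mul i32 %masked, 72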

This optimization lends itself to further optimizations; this PR addresses one of them.

In cases where we have

or-disjoint ( mul(and (X, C1), D) , mul (and (X, C2), D))

we can combine into

mul( and (X, (C1 + C2)), D)

provided C1 and C2 are disjoint.

Generalized proof: https://alive2.llvm.org/ce/z/MQYMui
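
Concretely, using the constants from the or_combine_mul_and1 test below (C1 = 2, C2 = 4, D = 72):

%bitop0 = and i32 %in, 2
%out0 = mul i32 %bitop0, 72
%bitop1 = and i32 %in, 4
%out1 = mul i32 %bitop1, 72
%out = or disjoint i32 %out0, %out1

; 2 and 4 share no set bits, so the masks merge into 6:
%masked = and i32 %in, 6
%out = mul i32 %masked, 72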

@llvmbot
Member

llvmbot commented Apr 16, 2025

@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-llvm-transforms

Author: Jeffrey Byrnes (jrbyrnes)


Full diff: https://github.com/llvm/llvm-project/pull/136013.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp (+21)
  • (modified) llvm/test/Transforms/InstCombine/or.ll (+83-4)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 6cc241781d112..206131ab4a6a7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3643,6 +3643,27 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
             foldAddLikeCommutative(I.getOperand(1), I.getOperand(0),
                                    /*NSW=*/true, /*NUW=*/true))
       return R;
+
+    Value *LHSOp = nullptr, *RHSOp = nullptr;
+    const APInt *LHSConst = nullptr, *RHSConst = nullptr;
+
+    // ((X & C1) * D) + ((X & C2) * D) -> (X & (C1 + C2)) * D
+    if (match(I.getOperand(0), m_Mul(m_Value(LHSOp), m_APInt(LHSConst))) &&
+        match(I.getOperand(1), m_Mul(m_Value(RHSOp), m_APInt(RHSConst))) &&
+        *LHSConst == *RHSConst) {
+      Value *LHSBase = nullptr, *RHSBase = nullptr;
+      const APInt *LHSMask = nullptr, *RHSMask = nullptr;
+      if (match(LHSOp, m_And(m_Value(LHSBase), m_APInt(LHSMask))) &&
+          match(RHSOp, m_And(m_Value(RHSBase), m_APInt(RHSMask))) &&
+          LHSBase == RHSBase &&
+          (*LHSMask & *RHSMask).isZero()) {
+        auto NewAnd = Builder.CreateAnd(
+            LHSBase, ConstantInt::get(LHSOp->getType(), (*LHSMask + *RHSMask)));
+
+        return BinaryOperator::CreateMul(
+            NewAnd, ConstantInt::get(NewAnd->getType(), *LHSConst));
+      }
+    }
   }
 
   Value *X, *Y;
diff --git a/llvm/test/Transforms/InstCombine/or.ll b/llvm/test/Transforms/InstCombine/or.ll
index 95f89e4ce11cd..777387cc662d6 100644
--- a/llvm/test/Transforms/InstCombine/or.ll
+++ b/llvm/test/Transforms/InstCombine/or.ll
@@ -1281,10 +1281,10 @@ define <16 x i1> @test51(<16 x i1> %arg, <16 x i1> %arg1) {
 ; CHECK-NEXT:    [[TMP3:%.*]] = shufflevector <16 x i1> [[ARG:%.*]], <16 x i1> [[ARG1:%.*]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 20, i32 5, i32 6, i32 23, i32 24, i32 9, i32 10, i32 27, i32 28, i32 29, i32 30, i32 31>
 ; CHECK-NEXT:    ret <16 x i1> [[TMP3]]
 ;
-  %tmp = and <16 x i1> %arg, <i1 true, i1 true, i1 true, i1 true, i1 false, i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false>
-  %tmp2 = and <16 x i1> %arg1, <i1 false, i1 false, i1 false, i1 false, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 true, i1 true, i1 true>
-  %tmp3 = or <16 x i1> %tmp, %tmp2
-  ret <16 x i1> %tmp3
+  %out = and <16 x i1> %arg, <i1 true, i1 true, i1 true, i1 true, i1 false, i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false>
+  %out2 = and <16 x i1> %arg1, <i1 false, i1 false, i1 false, i1 false, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 true, i1 true, i1 true>
+  %out3 = or <16 x i1> %out, %out2
+  ret <16 x i1> %out3
 }
 
 ; This would infinite loop because it reaches a transform
@@ -2035,3 +2035,82 @@ define i32 @or_xor_and_commuted3(i32 %x, i32 %y, i32 %z) {
   %or1 = or i32 %xor, %yy
   ret i32 %or1
 }
+
+define i32 @or_combine_mul_and1(i32 %in) {
+; CHECK-LABEL: @or_combine_mul_and1(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 6
+; CHECK-NEXT:    [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %bitop0 = and i32 %in, 2
+  %out0 = mul i32 %bitop0, 72
+  %bitop1 = and i32 %in, 4
+  %out1 = mul i32 %bitop1, 72
+  %out = or disjoint i32 %out0, %out1
+  ret i32 %out
+}
+
+define i32 @or_combine_mul_and2(i32 %in) {
+; CHECK-LABEL: @or_combine_mul_and2(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 10
+; CHECK-NEXT:    [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %bitop0 = and i32 %in, 2
+  %out0 = mul i32 %bitop0, 72
+  %bitop1 = and i32 %in, 8
+  %out1 = mul i32 %bitop1, 72
+  %out = or disjoint i32 %out0, %out1
+  ret i32 %out
+}
+
+define i32 @or_combine_mul_and_diff_factor(i32 %in) {
+; CHECK-LABEL: @or_combine_mul_and_diff_factor(
+; CHECK-NEXT:    [[BITOP0:%.*]] = and i32 [[IN:%.*]], 2
+; CHECK-NEXT:    [[TMP0:%.*]] = mul nuw nsw i32 [[BITOP0]], 36
+; CHECK-NEXT:    [[BITOP1:%.*]] = and i32 [[IN]], 4
+; CHECK-NEXT:    [[TMP1:%.*]] = mul nuw nsw i32 [[BITOP1]], 72
+; CHECK-NEXT:    [[OUT:%.*]] = or disjoint i32 [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %bitop0 = and i32 %in, 2
+  %out0 = mul i32 %bitop0, 36
+  %bitop1 = and i32 %in, 4
+  %out1 = mul i32 %bitop1, 72
+  %out = or disjoint i32 %out0, %out1
+  ret i32 %out
+}
+
+define i32 @or_combine_mul_and_diff_base(i32 %in, i32 %in1) {
+; CHECK-LABEL: @or_combine_mul_and_diff_base(
+; CHECK-NEXT:    [[BITOP0:%.*]] = and i32 [[IN:%.*]], 2
+; CHECK-NEXT:    [[TMP0:%.*]] = mul nuw nsw i32 [[BITOP0]], 72
+; CHECK-NEXT:    [[BITOP1:%.*]] = and i32 [[IN1:%.*]], 4
+; CHECK-NEXT:    [[TMP1:%.*]] = mul nuw nsw i32 [[BITOP1]], 72
+; CHECK-NEXT:    [[OUT:%.*]] = or disjoint i32 [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %bitop0 = and i32 %in, 2
+  %out0 = mul i32 %bitop0, 72
+  %bitop1 = and i32 %in1, 4
+  %out1 = mul i32 %bitop1, 72
+  %out = or disjoint i32 %out0, %out1
+  ret i32 %out
+}
+
+define i32 @or_combine_mul_and_decomposed(i32 %in) {
+; CHECK-LABEL: @or_combine_mul_and_decomposed(
+; CHECK-NEXT:    [[TMP2:%.*]] = trunc i32 [[IN:%.*]] to i1
+; CHECK-NEXT:    [[OUT0:%.*]] = select i1 [[TMP2]], i32 72, i32 0
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN]], 4
+; CHECK-NEXT:    [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    [[OUT1:%.*]] = or disjoint i32 [[OUT0]], [[OUT]]
+; CHECK-NEXT:    ret i32 [[OUT1]]
+;
+  %bitop0 = and i32 %in, 1
+  %out0 = mul i32 %bitop0, 72
+  %bitop1 = and i32 %in, 4
+  %out1 = mul i32 %bitop1, 72
+  %out = or disjoint i32 %out0, %out1
+  ret i32 %out
+}

@jrbyrnes
Contributor Author

I plan to have 3 or so extensions to #135274 -- the plan is to resolve merge conflicts by extracting this code into a separate function.

@jrbyrnes jrbyrnes marked this pull request as draft April 16, 2025 22:37
@jrbyrnes
Contributor Author

Converted to draft / planned changes: on second thought, it makes more sense to integrate this with #135274.

@jrbyrnes jrbyrnes force-pushed the ICAndSelExtendCase0 branch from d040ee7 to c74714c on April 17, 2025 at 20:27
@jrbyrnes jrbyrnes marked this pull request as ready for review April 17, 2025 20:27
@jrbyrnes
Contributor Author

Extended #135274 to match the new and->mul sequences, making the merging of the PRs explicit.

This PR is meant for the commits starting at c74714c.

@jrbyrnes
Contributor Author

Force-push to rebase for #136367.

@jrbyrnes
Contributor Author

Ping

@jrbyrnes
Contributor Author

jrbyrnes commented May 27, 2025

We may have cases where the select APInts have bitwidths different from the mask's. These cases were unhandled and caused assertion failures.

We can likely improve matchBitmaskMul to handle this kind of case; however, since it isn't the base case, I will leave it as a possible extension.

@dtcxzyw
Member

dtcxzyw commented May 30, 2025

@jrbyrnes Does this pattern exist in real-world applications? Isn't your motivating case (#133139 (comment)) solved by #135274?

@jrbyrnes
Contributor Author

@dtcxzyw The simplified example in #133139 (comment) is solved by #135274. However, it is very simplified and does not cover all the cases in my real-world application. I need to implement 2 or 3 extensions to that PR (including this one) in order to resolve my issue.

jrbyrnes added 2 commits June 2, 2025 11:05
Change-Id: I1cc2acd3804dde50636518f3ef2c9581848ae9f6
Change-Id: I4b71adfd8bffdda4d2b0d1cba85a3fd73a105a28
Change-Id: I12f77aedbf1a2edfe63e4d03cd1e5c1c601365a7
@jrbyrnes jrbyrnes force-pushed the ICAndSelExtendCase0 branch from 083eb16 to 5fa229b on June 2, 2025 at 18:14
@jrbyrnes
Contributor Author

jrbyrnes commented Jun 2, 2025

Force-push to rebase for 2112379

@jrbyrnes
Contributor Author

jrbyrnes commented Jun 4, 2025

Ping -- this is the second in a chain of 4 PRs needed to solve an urgent issue. The solution has already been stalled for some time.

@nikic
Contributor

nikic commented Jun 5, 2025

Could you please remind me what the larger context for these patches was?

@jrbyrnes
Contributor Author

jrbyrnes commented Jun 5, 2025

Could you please remind me what the larger context for these patches was?

AI workloads are bringing in a new feature called linear layout (https://arxiv.org/html/2505.23819v1).

The effect of this feature is to rework address calculations such that we use xor and and in places where we would more typically use add and mul. There are important clients of this feature (e.g. Triton) that are MLIR-based and produce IR directly instead of going through clang.

The problem is that separateConstOffsetFromGEP doesn't extract constant offsets from xor. Moreover, since many of these xors are equivalent to or disjoint, it would be awesome to convert them to or disjoint: doing so provides a very significant performance uplift, as it greatly reduces register pressure and avoids spilling in some cases. Unfortunately, most of these address-computation chains are deeper than the knownBits recursion depth limit.
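
For example (an illustrative sketch, not code from this PR): when knownBits can prove that the two operands of an xor share no set bits, the xor is equivalent to an or disjoint:

%lo = and i32 %x, 15 ; only bits 0-3 can be set
%hi = shl i32 %y, 4 ; only bits 4-31 can be set
%addr = xor i32 %lo, %hi
; the operands are provably disjoint, so this is equivalent to:
%addr = or disjoint i32 %lo, %hi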

The test in #137721 has a reduced example of this, and I've also included some IR in https://discourse.llvm.org/t/rfc-computeknownbits-recursion-depth/85962. The problem is more general, and there are other common variants of address formulations not included in these examples, but this gives the basic idea.

I think that, from a solution perspective, it would be best to change the way the recursion depth works such that we always do these conversions. However, I realize there may be some concerns with that approach, so I'm also working on an approach that optimizes the address calculation chains such that they are compatible with the recursion depth.

That is where this stack of instcombine patches comes in: they clean up the intermediate code in the address calculations so that we can convert these xors to or disjoint within the depth limit. This solution is a bit less stable, but it will resolve the current performance issues that arise when adopting this technology.
