Skip to content

Commit 72ffaa9

Browse files
authored
[IR][TRE] Support associative intrinsics (llvm#74226)
There is support for intrinsics in Instruction::isCommunative, but there is no equivalent implementation for isAssociative. This patch builds support for associative intrinsics with TRE as an application. TRE can now have associative intrinsics as an accumulator. For example: ``` struct Node { Node *next; unsigned val; } unsigned maxval(struct Node *n) { if (!n) return 0; return std::max(n->val, maxval(n->next)); } ``` Can be transformed into: ``` unsigned maxval(struct Node *n) { struct Node *head = n; unsigned max = 0; // Identity of unsigned std::max while (true) { if (!head) return max; max = std::max(max, head->val); head = head->next; } return max; } ``` This example results in about 5x speedup in local runs. We conservatively only consider min/max and as associative for this patch to limit testing scope. There are probably other intrinsics that could be considered associative. There are a few consumers of isAssociative() that could be impacted. Testing has only required to Reassociate pass be updated.
1 parent 1334030 commit 72ffaa9

File tree

8 files changed

+321
-66
lines changed

8 files changed

+321
-66
lines changed

llvm/include/llvm/IR/Constants.h

+15-8
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "llvm/ADT/StringRef.h"
2828
#include "llvm/IR/Constant.h"
2929
#include "llvm/IR/DerivedTypes.h"
30+
#include "llvm/IR/Intrinsics.h"
3031
#include "llvm/IR/OperandTraits.h"
3132
#include "llvm/IR/User.h"
3233
#include "llvm/IR/Value.h"
@@ -1095,18 +1096,24 @@ class ConstantExpr : public Constant {
10951096
static Constant *getExactLogBase2(Constant *C);
10961097

10971098
/// Return the identity constant for a binary opcode.
1098-
/// The identity constant C is defined as X op C = X and C op X = X for every
1099-
/// X when the binary operation is commutative. If the binop is not
1100-
/// commutative, callers can acquire the operand 1 identity constant by
1101-
/// setting AllowRHSConstant to true. For example, any shift has a zero
1102-
/// identity constant for operand 1: X shift 0 = X.
1103-
/// If this is a fadd/fsub operation and we don't care about signed zeros,
1104-
/// then setting NSZ to true returns the identity +0.0 instead of -0.0.
1105-
/// Return nullptr if the operator does not have an identity constant.
1099+
/// If the binop is not commutative, callers can acquire the operand 1
1100+
/// identity constant by setting AllowRHSConstant to true. For example, any
1101+
/// shift has a zero identity constant for operand 1: X shift 0 = X. If this
1102+
/// is a fadd/fsub operation and we don't care about signed zeros, then
1103+
/// setting NSZ to true returns the identity +0.0 instead of -0.0. Return
1104+
/// nullptr if the operator does not have an identity constant.
11061105
static Constant *getBinOpIdentity(unsigned Opcode, Type *Ty,
11071106
bool AllowRHSConstant = false,
11081107
bool NSZ = false);
11091108

1109+
static Constant *getIntrinsicIdentity(Intrinsic::ID, Type *Ty);
1110+
1111+
/// Return the identity constant for a binary or intrinsic Instruction.
1112+
/// The identity constant C is defined as X op C = X and C op X = X where C
1113+
/// and X are the first two operands, and the operation is commutative.
1114+
static Constant *getIdentity(Instruction *I, Type *Ty,
1115+
bool AllowRHSConstant = false, bool NSZ = false);
1116+
11101117
/// Return the absorbing element for the given binary
11111118
/// operation, i.e. a constant C such that X op C = C and C op X = C for
11121119
/// every X. For example, this returns zero for integer multiplication.

llvm/include/llvm/IR/IntrinsicInst.h

+12
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,18 @@ class IntrinsicInst : public CallInst {
5555
return getCalledFunction()->getIntrinsicID();
5656
}
5757

58+
bool isAssociative() const {
59+
switch (getIntrinsicID()) {
60+
case Intrinsic::smax:
61+
case Intrinsic::smin:
62+
case Intrinsic::umax:
63+
case Intrinsic::umin:
64+
return true;
65+
default:
66+
return false;
67+
}
68+
}
69+
5870
/// Return true if swapping the first two arguments to the intrinsic produces
5971
/// the same result.
6072
bool isCommutative() const {

llvm/lib/IR/Constants.cpp

+26
Original file line numberDiff line numberDiff line change
@@ -2556,6 +2556,32 @@ Constant *ConstantExpr::getBinOpIdentity(unsigned Opcode, Type *Ty,
25562556
}
25572557
}
25582558

2559+
Constant *ConstantExpr::getIntrinsicIdentity(Intrinsic::ID ID, Type *Ty) {
2560+
switch (ID) {
2561+
case Intrinsic::umax:
2562+
return Constant::getNullValue(Ty);
2563+
case Intrinsic::umin:
2564+
return Constant::getAllOnesValue(Ty);
2565+
case Intrinsic::smax:
2566+
return Constant::getIntegerValue(
2567+
Ty, APInt::getSignedMinValue(Ty->getIntegerBitWidth()));
2568+
case Intrinsic::smin:
2569+
return Constant::getIntegerValue(
2570+
Ty, APInt::getSignedMaxValue(Ty->getIntegerBitWidth()));
2571+
default:
2572+
return nullptr;
2573+
}
2574+
}
2575+
2576+
Constant *ConstantExpr::getIdentity(Instruction *I, Type *Ty,
2577+
bool AllowRHSConstant, bool NSZ) {
2578+
if (I->isBinaryOp())
2579+
return getBinOpIdentity(I->getOpcode(), Ty, AllowRHSConstant, NSZ);
2580+
if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I))
2581+
return getIntrinsicIdentity(II->getIntrinsicID(), Ty);
2582+
return nullptr;
2583+
}
2584+
25592585
Constant *ConstantExpr::getBinOpAbsorber(unsigned Opcode, Type *Ty) {
25602586
switch (Opcode) {
25612587
default:

llvm/lib/IR/Instruction.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -1091,6 +1091,8 @@ const DebugLoc &Instruction::getStableDebugLoc() const {
10911091
}
10921092

10931093
bool Instruction::isAssociative() const {
1094+
if (auto *II = dyn_cast<IntrinsicInst>(this))
1095+
return II->isAssociative();
10941096
unsigned Opcode = getOpcode();
10951097
if (isAssociative(Opcode))
10961098
return true;

llvm/lib/Transforms/Scalar/Reassociate.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -2554,7 +2554,7 @@ ReassociatePass::BuildPairMap(ReversePostOrderTraversal<Function *> &RPOT) {
25542554
// Make a "pairmap" of how often each operand pair occurs.
25552555
for (BasicBlock *BI : RPOT) {
25562556
for (Instruction &I : *BI) {
2557-
if (!I.isAssociative())
2557+
if (!I.isAssociative() || !I.isBinaryOp())
25582558
continue;
25592559

25602560
// Ignore nodes that aren't at the root of trees.

llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp

+10-4
Original file line numberDiff line numberDiff line change
@@ -369,8 +369,14 @@ static bool canTransformAccumulatorRecursion(Instruction *I, CallInst *CI) {
369369
if (!I->isAssociative() || !I->isCommutative())
370370
return false;
371371

372-
assert(I->getNumOperands() == 2 &&
373-
"Associative/commutative operations should have 2 args!");
372+
assert(I->getNumOperands() >= 2 &&
373+
"Associative/commutative operations should have at least 2 args!");
374+
375+
if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
376+
// Accumulators must have an identity.
377+
if (!ConstantExpr::getIntrinsicIdentity(II->getIntrinsicID(), I->getType()))
378+
return false;
379+
}
374380

375381
// Exactly one operand should be the result of the call instruction.
376382
if ((I->getOperand(0) == CI && I->getOperand(1) == CI) ||
@@ -569,8 +575,8 @@ void TailRecursionEliminator::insertAccumulator(Instruction *AccRecInstr) {
569575
for (pred_iterator PI = PB; PI != PE; ++PI) {
570576
BasicBlock *P = *PI;
571577
if (P == &F.getEntryBlock()) {
572-
Constant *Identity = ConstantExpr::getBinOpIdentity(
573-
AccRecInstr->getOpcode(), AccRecInstr->getType());
578+
Constant *Identity =
579+
ConstantExpr::getIdentity(AccRecInstr, AccRecInstr->getType());
574580
AccPN->addIncoming(Identity, P);
575581
} else {
576582
AccPN->addIncoming(AccPN, P);

llvm/test/Transforms/TailCallElim/accum_recursion.ll

+39-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ define i64 @test3_fib(i64 %n) nounwind readnone {
7878
; CHECK-NEXT: ]
7979
; CHECK: bb1:
8080
; CHECK-NEXT: [[TMP0:%.*]] = add i64 [[N_TR]], -1
81-
; CHECK-NEXT: [[RECURSE1:%.*]] = tail call i64 @test3_fib(i64 [[TMP0]]) #[[ATTR1:[0-9]+]]
81+
; CHECK-NEXT: [[RECURSE1:%.*]] = tail call i64 @test3_fib(i64 [[TMP0]]) #[[ATTR2:[0-9]+]]
8282
; CHECK-NEXT: [[TMP1]] = add i64 [[N_TR]], -2
8383
; CHECK-NEXT: [[ACCUMULATE]] = add nsw i64 [[ACCUMULATOR_TR]], [[RECURSE1]]
8484
; CHECK-NEXT: br label [[TAILRECURSE]]
@@ -290,3 +290,41 @@ return:
290290
%retval.0 = phi i32 [ %accumulate1, %if.then2 ], [ %accumulate2, %if.end3 ], [ 0, %entry ]
291291
ret i32 %retval.0
292292
}
293+
294+
%struct.ListNode = type { i32, ptr }
295+
296+
; We cannot TRE commutative, non-associative intrinsics
297+
define i32 @test_non_associative_sadd_sat(ptr %a) local_unnamed_addr {
298+
; CHECK-LABEL: define i32 @test_non_associative_sadd_sat(
299+
; CHECK-SAME: ptr [[A:%.*]]) local_unnamed_addr {
300+
; CHECK-NEXT: entry:
301+
; CHECK-NEXT: [[TOBOOL_NOT:%.*]] = icmp eq ptr [[A]], null
302+
; CHECK-NEXT: br i1 [[TOBOOL_NOT]], label [[COMMON_RET6:%.*]], label [[IF_END:%.*]]
303+
; CHECK: common.ret6:
304+
; CHECK-NEXT: ret i32 -1
305+
; CHECK: if.end:
306+
; CHECK-NEXT: [[TMP0:%.*]] = load i32, ptr [[A]], align 4
307+
; CHECK-NEXT: [[NEXT:%.*]] = getelementptr inbounds [[STRUCT_LISTNODE:%.*]], ptr [[A]], i64 0, i32 1
308+
; CHECK-NEXT: [[TMP1:%.*]] = load ptr, ptr [[NEXT]], align 8
309+
; CHECK-NEXT: [[CALL:%.*]] = tail call i32 @test_non_associative_sadd_sat(ptr [[TMP1]])
310+
; CHECK-NEXT: [[DOTSROA_SPECULATED:%.*]] = tail call i32 @llvm.sadd.sat.i32(i32 [[TMP0]], i32 [[CALL]])
311+
; CHECK-NEXT: ret i32 [[DOTSROA_SPECULATED]]
312+
;
313+
entry:
314+
%tobool.not = icmp eq ptr %a, null
315+
br i1 %tobool.not, label %common.ret6, label %if.end
316+
317+
common.ret6: ; preds = %entry, %if.end
318+
%common.ret6.op = phi i32 [ %.sroa.speculated, %if.end ], [ -1, %entry ]
319+
ret i32 %common.ret6.op
320+
321+
if.end: ; preds = %entry
322+
%0 = load i32, ptr %a
323+
%next = getelementptr inbounds %struct.ListNode, ptr %a, i64 0, i32 1
324+
%1 = load ptr, ptr %next
325+
%call = tail call i32 @test_non_associative_sadd_sat(ptr %1)
326+
%.sroa.speculated = tail call i32 @llvm.sadd.sat.i32(i32 %0, i32 %call)
327+
br label %common.ret6
328+
}
329+
330+
declare i32 @llvm.sadd.sat.i32(i32, i32)

0 commit comments

Comments
 (0)