[Analysis]: Allow inlining recursive call IF recursion depth is 1. #119677


Merged: 11 commits, May 6, 2025
81 changes: 81 additions & 0 deletions llvm/lib/Analysis/InlineCost.cpp
@@ -20,6 +20,7 @@
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/CodeMetrics.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/DomConditionCache.h"
#include "llvm/Analysis/EphemeralValuesCache.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/LoopInfo.h"
@@ -262,6 +263,8 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> {
// Cache the DataLayout since we use it a lot.
const DataLayout &DL;

DominatorTree DT;

/// The OptimizationRemarkEmitter available for this compilation.
OptimizationRemarkEmitter *ORE;

@@ -444,6 +447,7 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> {
bool canFoldInboundsGEP(GetElementPtrInst &I);
bool accumulateGEPOffset(GEPOperator &GEP, APInt &Offset);
bool simplifyCallSite(Function *F, CallBase &Call);
bool simplifyCmpInstForRecCall(CmpInst &Cmp);
bool simplifyInstruction(Instruction &I);
bool simplifyIntrinsicCallIsConstant(CallBase &CB);
bool simplifyIntrinsicCallObjectSize(CallBase &CB);
@@ -1676,6 +1680,79 @@ bool CallAnalyzer::visitGetElementPtr(GetElementPtrInst &I) {
return isGEPFree(I);
}

// Simplify \p Cmp if its RHS is constant and its LHS can be value-tracked.
// This handles only the case where the Cmp instruction guards a recursive
// call whose argument will cause the Cmp to fail or succeed on the next
// iteration.
bool CallAnalyzer::simplifyCmpInstForRecCall(CmpInst &Cmp) {
// Bail out if LHS is not a function argument or RHS is NOT const:
if (!isa<Argument>(Cmp.getOperand(0)) || !isa<Constant>(Cmp.getOperand(1)))
return false;
auto *CmpOp = Cmp.getOperand(0);
Function *F = Cmp.getFunction();
// Iterate over the users of the function to check if it's a recursive
// function:
for (auto *U : F->users()) {
Contributor:

Why is this iterating over the uses of the function? Shouldn't this be inspecting just the CandidateCall in particular?

It looks like this checks if there is any recursive call of the right form, even if it's not the call-site being analyzed.

Member Author:

I want to get the recursive call that is guarded by the cmp instr that I'm analyzing.

I can do that either top-down, by finding the branch on the icmp, walking to its successors, and iterating over their instructions to find the recursive call; or bottom-up, by finding any recursive call of the function and walking back through its predecessors until I reach the icmp.
So I think the bottom-up approach is better here.

Contributor:

My concern here is that there may be multiple calls of the function, only one of which is the recursive call you are interested in. We will separately compute the cost for each of these call-sites, but we will treat each of them as if it were the recursive call, and thus incorrectly assign them a lower cost. Instead, we should only do this when trying to inline the actual recursive call, as given by CandidateCall.

Member Author:

Yeah, I got your point, and I agree with you.
I think there should be one patch for the ValueTracking changes, and a follow-up patch for the two points you mentioned about the inliner.

Member Author:

Today, I will create a patch for the ValueTracking, and another patch for resolving your comments on InlineCost.

CallInst *Call = dyn_cast<CallInst>(U);
if (!Call || Call->getFunction() != F || Call->getCalledFunction() != F)
continue;
auto *CallBB = Call->getParent();
auto *Predecessor = CallBB->getSinglePredecessor();
// Only handle the case when the callsite has a single predecessor:
if (!Predecessor)
continue;

auto *Br = dyn_cast<BranchInst>(Predecessor->getTerminator());
if (!Br || Br->isUnconditional())
continue;
// Check if the Br condition is the same Cmp instr we are investigating:
if (Br->getCondition() != &Cmp)
continue;
// Check whether any argument of the recursive callsite affects the cmp
// instruction:
bool ArgFound = false;
Value *FuncArg = nullptr, *CallArg = nullptr;
for (unsigned ArgNum = 0;
ArgNum < F->arg_size() && ArgNum < Call->arg_size(); ArgNum++) {
FuncArg = F->getArg(ArgNum);
CallArg = Call->getArgOperand(ArgNum);
if (FuncArg == CmpOp && CallArg != CmpOp) {
ArgFound = true;
break;
}
}
if (!ArgFound)
continue;
// Now we have a recursive call that is guarded by a cmp instruction.
// Check if this cmp can be simplified:
SimplifyQuery SQ(DL, dyn_cast<Instruction>(CallArg));
DomConditionCache DC;
DC.registerBranch(Br);
SQ.DC = &DC;
if (DT.root_size() == 0) {
// Dominator tree was never constructed for any function yet.
DT.recalculate(*F);
} else if (DT.getRoot()->getParent() != F) {
// Dominator tree was constructed for a different function, recalculate
// it for the current function.
Contributor:

This looks fishy. Isn't the CallAnalyzer instantiated per call-site? How can we end up with different functions here?

Member Author:

Yes, I agree. The else-if statement should be removed.

DT.recalculate(*F);
}
SQ.DT = &DT;
Contributor:

Would it be possible to inject the condition via CondContext instead?

Member Author (@hassnaaHamdi, May 7, 2025):

Given the current logic in ValueTracking.cpp, injecting the condition would not be useful.
But if the logic of ValueTracking.cpp::computeKnownFPClassFromContext were changed to check the CondContext and use computeKnownFPClassFromCond directly, similarly to the logic in computeKnownBitsFromContext, then injection would be useful.
Maybe I'll create a patch for ValueTracking first, and then a follow-up patch applying your suggestion.

Value *SimplifiedInstruction = llvm::simplifyInstructionWithOperands(
cast<CmpInst>(&Cmp), {CallArg, Cmp.getOperand(1)}, SQ);
if (auto *ConstVal = dyn_cast_or_null<ConstantInt>(SimplifiedInstruction)) {
bool IsTrueSuccessor = CallBB == Br->getSuccessor(0);
// Make sure that the BB of the recursive call is NOT the next successor
// of the icmp. In other words, make sure that the recursion depth is 1.
if ((ConstVal->isOne() && !IsTrueSuccessor) ||
(ConstVal->isZero() && IsTrueSuccessor)) {
SimplifiedValues[&Cmp] = ConstVal;
Collaborator:

We were talking about this - I think it would be valid for this to always happen if we know that the condition simplifies but the recursion would always be taken, leading to an infinite recursion. It would just mean that we mark the function as IsRecursive=true and do not inline, so the end result should be the same; this is just being more careful (IIUC). Either way seems OK to me.

return true;
}
}
}
return false;
}

/// Simplify \p I if its operands are constants and update SimplifiedValues.
bool CallAnalyzer::simplifyInstruction(Instruction &I) {
SmallVector<Constant *> COps;
@@ -2060,6 +2137,10 @@ bool CallAnalyzer::visitCmpInst(CmpInst &I) {
if (simplifyInstruction(I))
return true;

// Try to handle comparison that can be simplified using ValueTracking.
if (simplifyCmpInstForRecCall(I))
return true;

if (I.getOpcode() == Instruction::FCmp)
return false;

193 changes: 193 additions & 0 deletions llvm/test/Transforms/Inline/inline-recursive-fn.ll
@@ -0,0 +1,193 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -passes='inline,instcombine' < %s | FileCheck %s

define float @inline_rec_true_successor(float %x, float %scale) {
; CHECK-LABEL: define float @inline_rec_true_successor(
; CHECK-SAME: float [[X:%.*]], float [[SCALE:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[CMP:%.*]] = fcmp olt float [[X]], 0.000000e+00
; CHECK-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_END:.*]]
; CHECK: [[COMMON_RET18:.*]]:
; CHECK-NEXT: [[COMMON_RET18_OP:%.*]] = phi float [ [[COMMON_RET18_OP_I:%.*]], %[[INLINE_REC_TRUE_SUCCESSOR_EXIT:.*]] ], [ [[MUL:%.*]], %[[IF_END]] ]
; CHECK-NEXT: ret float [[COMMON_RET18_OP]]
; CHECK: [[IF_THEN]]:
; CHECK-NEXT: br i1 false, label %[[IF_THEN_I:.*]], label %[[IF_END_I:.*]]
; CHECK: [[IF_THEN_I]]:
; CHECK-NEXT: br label %[[INLINE_REC_TRUE_SUCCESSOR_EXIT]]
; CHECK: [[IF_END_I]]:
; CHECK-NEXT: [[FNEG:%.*]] = fneg float [[X]]
; CHECK-NEXT: [[MUL_I:%.*]] = fmul float [[SCALE]], [[FNEG]]
; CHECK-NEXT: br label %[[INLINE_REC_TRUE_SUCCESSOR_EXIT]]
; CHECK: [[INLINE_REC_TRUE_SUCCESSOR_EXIT]]:
; CHECK-NEXT: [[COMMON_RET18_OP_I]] = phi float [ poison, %[[IF_THEN_I]] ], [ [[MUL_I]], %[[IF_END_I]] ]
; CHECK-NEXT: br label %[[COMMON_RET18]]
; CHECK: [[IF_END]]:
; CHECK-NEXT: [[MUL]] = fmul float [[X]], [[SCALE]]
; CHECK-NEXT: br label %[[COMMON_RET18]]
;
entry:
%cmp = fcmp olt float %x, 0.000000e+00
br i1 %cmp, label %if.then, label %if.end

common.ret18: ; preds = %if.then, %if.end
%common.ret18.op = phi float [ %call, %if.then ], [ %mul, %if.end ]
ret float %common.ret18.op

if.then: ; preds = %entry
%fneg = fneg float %x
%call = tail call float @inline_rec_true_successor(float %fneg, float %scale)
br label %common.ret18

if.end: ; preds = %entry
%mul = fmul float %x, %scale
br label %common.ret18
}

; Same as previous test except that the recursive callsite is in the false successor
define float @inline_rec_false_successor(float %x, float %scale) {
; CHECK-LABEL: define float @inline_rec_false_successor(
; CHECK-SAME: float [[Y:%.*]], float [[SCALE:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[CMP:%.*]] = fcmp uge float [[Y]], 0.000000e+00
; CHECK-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_END:.*]]
; CHECK: [[COMMON_RET18:.*]]:
; CHECK-NEXT: [[COMMON_RET18_OP:%.*]] = phi float [ [[MUL:%.*]], %[[IF_THEN]] ], [ [[COMMON_RET18_OP_I:%.*]], %[[INLINE_REC_FALSE_SUCCESSOR_EXIT:.*]] ]
; CHECK-NEXT: ret float [[COMMON_RET18_OP]]
; CHECK: [[IF_THEN]]:
; CHECK-NEXT: [[MUL]] = fmul float [[Y]], [[SCALE]]
; CHECK-NEXT: br label %[[COMMON_RET18]]
; CHECK: [[IF_END]]:
; CHECK-NEXT: br i1 true, label %[[IF_THEN_I:.*]], label %[[IF_END_I:.*]]
; CHECK: [[IF_THEN_I]]:
; CHECK-NEXT: [[FNEG:%.*]] = fneg float [[Y]]
; CHECK-NEXT: [[MUL_I:%.*]] = fmul float [[SCALE]], [[FNEG]]
; CHECK-NEXT: br label %[[INLINE_REC_FALSE_SUCCESSOR_EXIT]]
; CHECK: [[IF_END_I]]:
; CHECK-NEXT: br label %[[INLINE_REC_FALSE_SUCCESSOR_EXIT]]
; CHECK: [[INLINE_REC_FALSE_SUCCESSOR_EXIT]]:
; CHECK-NEXT: [[COMMON_RET18_OP_I]] = phi float [ [[MUL_I]], %[[IF_THEN_I]] ], [ poison, %[[IF_END_I]] ]
; CHECK-NEXT: br label %[[COMMON_RET18]]
;
entry:
%cmp = fcmp uge float %x, 0.000000e+00
br i1 %cmp, label %if.then, label %if.end

common.ret18: ; preds = %if.then, %if.end
%common.ret18.op = phi float [ %mul, %if.then ], [ %call, %if.end ]
ret float %common.ret18.op

if.then: ; preds = %entry
%mul = fmul float %x, %scale
br label %common.ret18

if.end: ; preds = %entry
%fneg = fneg float %x
%call = tail call float @inline_rec_false_successor(float %fneg, float %scale)
br label %common.ret18
}

; Test when the branch condition is a plain value, not a cmp instruction
define float @inline_rec_no_cmp(i1 %flag, float %scale) {
; CHECK-LABEL: define float @inline_rec_no_cmp(
; CHECK-SAME: i1 [[FLAG:%.*]], float [[SCALE:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: br i1 [[FLAG]], label %[[IF_THEN:.*]], label %[[IF_END:.*]]
; CHECK: [[IF_THEN]]:
; CHECK-NEXT: [[SUM:%.*]] = fadd float [[SCALE]], 5.000000e+00
; CHECK-NEXT: [[SUM1:%.*]] = fadd float [[SUM]], [[SCALE]]
; CHECK-NEXT: br label %[[COMMON_RET:.*]]
; CHECK: [[IF_END]]:
; CHECK-NEXT: [[SUM2:%.*]] = fadd float [[SCALE]], 5.000000e+00
; CHECK-NEXT: br label %[[COMMON_RET]]
; CHECK: [[COMMON_RET]]:
; CHECK-NEXT: [[COMMON_RET_RES:%.*]] = phi float [ [[SUM1]], %[[IF_THEN]] ], [ [[SUM2]], %[[IF_END]] ]
; CHECK-NEXT: ret float [[COMMON_RET_RES]]
;
entry:
br i1 %flag, label %if.then, label %if.end
if.then:
%res = tail call float @inline_rec_no_cmp(i1 false, float %scale)
%sum1 = fadd float %res, %scale
br label %common.ret
if.end:
%sum2 = fadd float %scale, 5.000000e+00
br label %common.ret
common.ret:
%common.ret.res = phi float [ %sum1, %if.then ], [ %sum2, %if.end ]
ret float %common.ret.res
}

define float @no_inline_rec(float %x, float %scale) {
; CHECK-LABEL: define float @no_inline_rec(
; CHECK-SAME: float [[Z:%.*]], float [[SCALE:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[CMP:%.*]] = fcmp olt float [[Z]], 5.000000e+00
; CHECK-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_END:.*]]
; CHECK: [[COMMON_RET18:.*]]:
; CHECK-NEXT: [[COMMON_RET18_OP:%.*]] = phi float [ [[FNEG1:%.*]], %[[IF_THEN]] ], [ [[MUL:%.*]], %[[IF_END]] ]
; CHECK-NEXT: ret float [[COMMON_RET18_OP]]
; CHECK: [[IF_THEN]]:
; CHECK-NEXT: [[FADD:%.*]] = fadd float [[Z]], 5.000000e+00
; CHECK-NEXT: [[CALL:%.*]] = tail call float @no_inline_rec(float [[FADD]], float [[SCALE]])
; CHECK-NEXT: [[FNEG1]] = fneg float [[CALL]]
; CHECK-NEXT: br label %[[COMMON_RET18]]
; CHECK: [[IF_END]]:
; CHECK-NEXT: [[MUL]] = fmul float [[Z]], [[SCALE]]
; CHECK-NEXT: br label %[[COMMON_RET18]]
;
entry:
%cmp = fcmp olt float %x, 5.000000e+00
br i1 %cmp, label %if.then, label %if.end

common.ret18: ; preds = %if.then, %if.end
%common.ret18.op = phi float [ %fneg1, %if.then ], [ %mul, %if.end ]
ret float %common.ret18.op

if.then: ; preds = %entry
%fadd = fadd float %x, 5.000000e+00
%call = tail call float @no_inline_rec(float %fadd, float %scale)
%fneg1 = fneg float %call
br label %common.ret18

if.end: ; preds = %entry
%mul = fmul float %x, %scale
br label %common.ret18
}

; Test when the cmp can be simplified but the recursion depth is NOT 1,
; so the recursive call will not be inlined.
define float @no_inline_rec_depth_not_1(float %x, float %scale) {
; CHECK-LABEL: define float @no_inline_rec_depth_not_1(
; CHECK-SAME: float [[X:%.*]], float [[SCALE:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[CMP:%.*]] = fcmp olt float [[X]], 0.000000e+00
; CHECK-NEXT: br i1 [[CMP]], label %[[IF_THEN:.*]], label %[[IF_END:.*]]
; CHECK: [[COMMON_RET18:.*]]:
; CHECK-NEXT: [[COMMON_RET18_OP:%.*]] = phi float [ [[CALL:%.*]], %[[IF_THEN]] ], [ [[MUL:%.*]], %[[IF_END]] ]
; CHECK-NEXT: ret float [[COMMON_RET18_OP]]
; CHECK: [[IF_THEN]]:
; CHECK-NEXT: [[CALL]] = tail call float @no_inline_rec_depth_not_1(float [[X]], float [[SCALE]])
; CHECK-NEXT: br label %[[COMMON_RET18]]
; CHECK: [[IF_END]]:
; CHECK-NEXT: [[MUL]] = fmul float [[X]], [[SCALE]]
; CHECK-NEXT: br label %[[COMMON_RET18]]
;
entry:
%cmp = fcmp olt float %x, 0.000000e+00
br i1 %cmp, label %if.then, label %if.end

common.ret18: ; preds = %if.then, %if.end
%common.ret18.op = phi float [ %call, %if.then ], [ %mul, %if.end ]
ret float %common.ret18.op

if.then: ; preds = %entry
%fneg1 = fneg float %x
%fneg = fneg float %fneg1
%call = tail call float @no_inline_rec_depth_not_1(float %fneg, float %scale)
br label %common.ret18

if.end: ; preds = %entry
%mul = fmul float %x, %scale
br label %common.ret18
}