Reapply [AMDGPU] Avoid resource propagation for recursion through multiple functions #112251


Merged · 8 commits · Nov 15, 2024
7 changes: 7 additions & 0 deletions llvm/include/llvm/MC/MCExpr.h
@@ -86,6 +86,10 @@ class MCExpr {
bool InParens = false) const;
void dump() const;

/// Returns whether the given symbol is used anywhere in the expression or
/// subexpressions.
bool isSymbolUsedInExpression(const MCSymbol *Sym) const;

/// @}
/// \name Expression Evaluation
/// @{
@@ -663,6 +667,9 @@ class MCTargetExpr : public MCExpr {
const MCFixup *Fixup) const = 0;
// allow Target Expressions to be checked for equality
virtual bool isEqualTo(const MCExpr *x) const { return false; }
virtual bool isSymbolUsedInExpression(const MCSymbol *Sym) const {
return false;
}
// This should be set when assigned expressions are not valid ".set"
// expressions, e.g. registers, and must be inlined.
virtual bool inlineAssignedExpr() const { return false; }
29 changes: 29 additions & 0 deletions llvm/lib/MC/MCExpr.cpp
@@ -177,6 +177,35 @@ LLVM_DUMP_METHOD void MCExpr::dump() const {
}
#endif

bool MCExpr::isSymbolUsedInExpression(const MCSymbol *Sym) const {
switch (getKind()) {
case MCExpr::Binary: {
const MCBinaryExpr *BE = static_cast<const MCBinaryExpr *>(this);
return BE->getLHS()->isSymbolUsedInExpression(Sym) ||
BE->getRHS()->isSymbolUsedInExpression(Sym);
}
case MCExpr::Target: {
const MCTargetExpr *TE = static_cast<const MCTargetExpr *>(this);
return TE->isSymbolUsedInExpression(Sym);
}
case MCExpr::Constant:
return false;
case MCExpr::SymbolRef: {
const MCSymbol &S = static_cast<const MCSymbolRefExpr *>(this)->getSymbol();
if (S.isVariable() && !S.isWeakExternal())
return S.getVariableValue()->isSymbolUsedInExpression(Sym);
return &S == Sym;
}
case MCExpr::Unary: {
const MCExpr *SubExpr =
static_cast<const MCUnaryExpr *>(this)->getSubExpr();
return SubExpr->isSymbolUsedInExpression(Sym);
}
}

llvm_unreachable("Unknown expr kind!");
}

/* *** */

const MCBinaryExpr *MCBinaryExpr::create(Opcode Opc, const MCExpr *LHS,
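The walk treats variable symbols as transparent: a reference to a variable symbol descends into its assigned value rather than comparing the symbol itself (weak externals excepted, since their value may be replaced at link time). A minimal sketch of the intended behavior, assuming an already-initialized MCContext (context and MCAsmInfo setup omitted):

```cpp
#include "llvm/MC/MCContext.h"
#include "llvm/MC/MCExpr.h"
#include "llvm/MC/MCSymbol.h"
#include <cassert>

using namespace llvm;

// "b" is a variable symbol whose value mentions "a", so an expression
// that only names "b" still reports "a" as used.
void sketch(MCContext &Ctx) {
  MCSymbol *A = Ctx.getOrCreateSymbol("a");
  MCSymbol *B = Ctx.getOrCreateSymbol("b");

  // b = a + 1
  B->setVariableValue(MCBinaryExpr::createAdd(
      MCSymbolRefExpr::create(A, Ctx), MCConstantExpr::create(1, Ctx), Ctx));

  // E = b + 2; "a" occurs only transitively, through b's value.
  const MCExpr *E = MCBinaryExpr::createAdd(
      MCSymbolRefExpr::create(B, Ctx), MCConstantExpr::create(2, Ctx), Ctx);

  assert(E->isSymbolUsedInExpression(A));                           // via b = a + 1
  assert(!E->isSymbolUsedInExpression(Ctx.getOrCreateSymbol("c"))); // never used
}
```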
29 changes: 1 addition & 28 deletions llvm/lib/MC/MCParser/AsmParser.cpp
@@ -6398,33 +6398,6 @@ bool HLASMAsmParser::parseStatement(ParseStatementInfo &Info,
namespace llvm {
namespace MCParserUtils {

/// Returns whether the given symbol is used anywhere in the given expression,
/// or subexpressions.
static bool isSymbolUsedInExpression(const MCSymbol *Sym, const MCExpr *Value) {
switch (Value->getKind()) {
case MCExpr::Binary: {
const MCBinaryExpr *BE = static_cast<const MCBinaryExpr *>(Value);
return isSymbolUsedInExpression(Sym, BE->getLHS()) ||
isSymbolUsedInExpression(Sym, BE->getRHS());
}
case MCExpr::Target:
case MCExpr::Constant:
return false;
case MCExpr::SymbolRef: {
const MCSymbol &S =
static_cast<const MCSymbolRefExpr *>(Value)->getSymbol();
if (S.isVariable() && !S.isWeakExternal())
return isSymbolUsedInExpression(Sym, S.getVariableValue());
return &S == Sym;
}
case MCExpr::Unary:
return isSymbolUsedInExpression(
Sym, static_cast<const MCUnaryExpr *>(Value)->getSubExpr());
}

llvm_unreachable("Unknown expr kind!");
}

bool parseAssignmentExpression(StringRef Name, bool allow_redef,
MCAsmParser &Parser, MCSymbol *&Sym,
const MCExpr *&Value) {
@@ -6449,7 +6422,7 @@ bool parseAssignmentExpression(StringRef Name, bool allow_redef,
//
// FIXME: Diagnostics. Note the location of the definition as a label.
// FIXME: Diagnose assignment to protected identifier (e.g., register name).
if (isSymbolUsedInExpression(Sym, Value))
if (Value->isSymbolUsedInExpression(Sym))
return Parser.Error(EqualLoc, "Recursive use of '" + Name + "'");
else if (Sym->isUndefined(/*SetUsed*/ false) && !Sym->isUsed() &&
!Sym->isVariable())
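Parser behavior is unchanged by the move; a self-referential assignment is still rejected, whether the cycle is direct or goes through another variable symbol. A hypothetical GAS-style input (both symbols previously undefined):

```asm
.set a, a + 1        # error: Recursive use of 'a'

.set x, y + 1
.set y, x + 1        # error: Recursive use of 'y' (found through x = y + 1)
```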
52 changes: 42 additions & 10 deletions llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.cpp
@@ -101,25 +101,50 @@ void MCResourceInfo::assignResourceInfoExpr(
const MCConstantExpr *LocalConstExpr =
MCConstantExpr::create(LocalValue, OutContext);
const MCExpr *SymVal = LocalConstExpr;
MCSymbol *Sym = getSymbol(FnSym->getName(), RIK, OutContext);
if (!Callees.empty()) {
SmallVector<const MCExpr *, 8> ArgExprs;
// Avoid recursive symbol assignment.
SmallPtrSet<const Function *, 8> Seen;
ArgExprs.push_back(LocalConstExpr);
const Function &F = MF.getFunction();
Seen.insert(&F);

for (const Function *Callee : Callees) {
if (!Seen.insert(Callee).second)
continue;

MCSymbol *CalleeFnSym = TM.getSymbol(&Callee->getFunction());
MCSymbol *CalleeValSym =
getSymbol(CalleeFnSym->getName(), RIK, OutContext);
ArgExprs.push_back(MCSymbolRefExpr::create(CalleeValSym, OutContext));

// Avoid constructing recursive definitions by detecting whether `Sym` is
// found transitively within any of its `CalleeValSym`.
if (!CalleeValSym->isVariable() ||
!CalleeValSym->getVariableValue(/*isUsed=*/false)
->isSymbolUsedInExpression(Sym)) {
--- inline review thread ---

Contributor: Can this miss adding other functions in the path to the callee that we still need to add to the expression? E.g., f1->f2->f3->f2: from f1, we see that callee f2 appears in a recursive expression, so it is skipped, but that means we also do not add f3 to the expression, even though it is needed.

Contributor Author: Yeah, sadly order matters here, in that it is possible for such a skip to occur. For now I've assumed the worst case and taken the module-level register maximums for recursion. Preferably I'd like to have cycle/SCC-scope register maximums computed and used, but I think that's out of scope for this PR.

Contributor: This should try to flatten the expression to remove the recursive edges.

Contributor Author: Agreed, will do so in a follow-up PR.

--- end review thread ---

ArgExprs.push_back(MCSymbolRefExpr::create(CalleeValSym, OutContext));
} else {
// In case of recursion: make sure to use conservative register counts
// (i.e., specifically for VGPR/SGPR/AGPR).
switch (RIK) {
default:
break;
case RIK_NumVGPR:
ArgExprs.push_back(MCSymbolRefExpr::create(
getMaxVGPRSymbol(OutContext), OutContext));
break;
case RIK_NumSGPR:
ArgExprs.push_back(MCSymbolRefExpr::create(
getMaxSGPRSymbol(OutContext), OutContext));
break;
case RIK_NumAGPR:
ArgExprs.push_back(MCSymbolRefExpr::create(
getMaxAGPRSymbol(OutContext), OutContext));
break;
}
}
}
SymVal = AMDGPUMCExpr::create(Kind, ArgExprs, OutContext);
if (ArgExprs.size() > 1)
SymVal = AMDGPUMCExpr::create(Kind, ArgExprs, OutContext);
}
MCSymbol *Sym = getSymbol(FnSym->getName(), RIK, OutContext);
Sym->setVariableValue(SymVal);
}
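
The effect on mutual recursion is visible in the test expectations added below: whichever function is emitted first may still reference its callee symbolically (the callee's resource symbol has no variable value yet, so the guard passes), while the function that would close the cycle falls back to the conservative module-level maximum:

```
.set multi_stage_recurse2.num_vgpr, max(43, multi_stage_recurse1.num_vgpr)
.set multi_stage_recurse1.num_vgpr, max(48, amdgpu.max_num_vgpr)
```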

@@ -163,6 +188,7 @@ void MCResourceInfo::gatherResourceInfo(
// The expression for private segment size should be: FRI.PrivateSegmentSize
// + max(FRI.Callees, FRI.CalleeSegmentSize)
SmallVector<const MCExpr *, 8> ArgExprs;
MCSymbol *Sym = getSymbol(FnSym->getName(), RIK_PrivateSegSize, OutContext);
if (FRI.CalleeSegmentSize)
ArgExprs.push_back(
MCConstantExpr::create(FRI.CalleeSegmentSize, OutContext));
@@ -174,9 +200,16 @@
continue;
if (!Callee->isDeclaration()) {
MCSymbol *CalleeFnSym = TM.getSymbol(&Callee->getFunction());
MCSymbol *calleeValSym =
MCSymbol *CalleeValSym =
getSymbol(CalleeFnSym->getName(), RIK_PrivateSegSize, OutContext);
ArgExprs.push_back(MCSymbolRefExpr::create(calleeValSym, OutContext));

// Avoid constructing recursive definitions by detecting whether `Sym`
// is found transitively within any of its `CalleeValSym`.
if (!CalleeValSym->isVariable() ||
!CalleeValSym->getVariableValue(/*isUsed=*/false)
->isSymbolUsedInExpression(Sym)) {
ArgExprs.push_back(MCSymbolRefExpr::create(CalleeValSym, OutContext));
}
}
}
const MCExpr *localConstExpr =
@@ -187,8 +220,7 @@
localConstExpr =
MCBinaryExpr::createAdd(localConstExpr, transitiveExpr, OutContext);
}
getSymbol(FnSym->getName(), RIK_PrivateSegSize, OutContext)
->setVariableValue(localConstExpr);
Sym->setVariableValue(localConstExpr);
}

auto SetToLocal = [&](int64_t LocalValue, ResourceInfoKind RIK) {
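For private_seg_size the same guard applies, but there is no conservative fallback symbol: a callee whose value would make the definition self-referential is simply left out of the `max(...)`. The test expectations below show the resulting shape, with the constant addend being the local frame size:

```
.set multi_stage_recurse2.private_seg_size, 16+(max(multi_stage_recurse1.private_seg_size))
.set multi_stage_recurse1.private_seg_size, 16
```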
8 changes: 8 additions & 0 deletions llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCExpr.cpp
@@ -306,6 +306,14 @@ const AMDGPUMCExpr *AMDGPUMCExpr::createOccupancy(unsigned InitOcc,
Ctx);
}

bool AMDGPUMCExpr::isSymbolUsedInExpression(const MCSymbol *Sym) const {
for (const MCExpr *E : getArgs()) {
if (E->isSymbolUsedInExpression(Sym))
return true;
}
return false;
}

static KnownBits fromOptionalToKnownBits(std::optional<bool> CompareResult) {
static constexpr unsigned BitWidth = 64;
const APInt True(BitWidth, 1);
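Since AMDGPUMCExpr is variadic (`max(...)`, `or(...)`, and friends over `getArgs()`), the override only needs to forward the query to each argument. A minimal sketch, assuming an initialized MCContext and assuming it is compiled inside the AMDGPU target so the MCTargetDesc header and its createMax helper are reachable:

```cpp
#include "MCTargetDesc/AMDGPUMCExpr.h"
#include "llvm/MC/MCContext.h"
#include "llvm/MC/MCExpr.h"
#include <cassert>

using namespace llvm;

void sketch(MCContext &Ctx) {
  MCSymbol *CalleeVGPRs = Ctx.getOrCreateSymbol("callee.num_vgpr");
  // max(43, callee.num_vgpr): the shape produced by assignResourceInfoExpr.
  const AMDGPUMCExpr *Max = AMDGPUMCExpr::createMax(
      {MCConstantExpr::create(43, Ctx),
       MCSymbolRefExpr::create(CalleeVGPRs, Ctx)},
      Ctx);
  // The target override walks every argument, so the symbol is found.
  assert(Max->isSymbolUsedInExpression(CalleeVGPRs));
}
```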
1 change: 1 addition & 0 deletions llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCExpr.h
@@ -97,6 +97,7 @@ class AMDGPUMCExpr : public MCTargetExpr {
void printImpl(raw_ostream &OS, const MCAsmInfo *MAI) const override;
bool evaluateAsRelocatableImpl(MCValue &Res, const MCAssembler *Asm,
const MCFixup *Fixup) const override;
bool isSymbolUsedInExpression(const MCSymbol *Sym) const override;
void visitUsedExpr(MCStreamer &Streamer) const override;
MCFragment *findAssociatedFragment() const override;
void fixELFSymbolsInTLSFixups(MCAssembler &) const override{};
130 changes: 130 additions & 0 deletions llvm/test/CodeGen/AMDGPU/function-resource-usage.ll
@@ -481,6 +481,136 @@ define amdgpu_kernel void @usage_direct_recursion(i32 %n) #0 {
ret void
}

; GCN-LABEL: {{^}}multi_stage_recurse2:
; GCN: .set multi_stage_recurse2.num_vgpr, max(43, multi_stage_recurse1.num_vgpr)
; GCN: .set multi_stage_recurse2.num_agpr, max(0, multi_stage_recurse1.num_agpr)
; GCN: .set multi_stage_recurse2.numbered_sgpr, max(34, multi_stage_recurse1.numbered_sgpr)
; GCN: .set multi_stage_recurse2.private_seg_size, 16+(max(multi_stage_recurse1.private_seg_size))
; GCN: .set multi_stage_recurse2.uses_vcc, or(1, multi_stage_recurse1.uses_vcc)
; GCN: .set multi_stage_recurse2.uses_flat_scratch, or(0, multi_stage_recurse1.uses_flat_scratch)
; GCN: .set multi_stage_recurse2.has_dyn_sized_stack, or(0, multi_stage_recurse1.has_dyn_sized_stack)
; GCN: .set multi_stage_recurse2.has_recursion, or(1, multi_stage_recurse1.has_recursion)
; GCN: .set multi_stage_recurse2.has_indirect_call, or(0, multi_stage_recurse1.has_indirect_call)
; GCN: TotalNumSgprs: multi_stage_recurse2.numbered_sgpr+(extrasgprs(multi_stage_recurse2.uses_vcc, multi_stage_recurse2.uses_flat_scratch, 1))
; GCN: NumVgprs: max(43, multi_stage_recurse1.num_vgpr)
; GCN: ScratchSize: 16+(max(multi_stage_recurse1.private_seg_size))
; GCN-LABEL: {{^}}multi_stage_recurse1:
; GCN: .set multi_stage_recurse1.num_vgpr, max(48, amdgpu.max_num_vgpr)
; GCN: .set multi_stage_recurse1.num_agpr, max(0, amdgpu.max_num_agpr)
; GCN: .set multi_stage_recurse1.numbered_sgpr, max(34, amdgpu.max_num_sgpr)
; GCN: .set multi_stage_recurse1.private_seg_size, 16
; GCN: .set multi_stage_recurse1.uses_vcc, 1
; GCN: .set multi_stage_recurse1.uses_flat_scratch, 0
; GCN: .set multi_stage_recurse1.has_dyn_sized_stack, 0
; GCN: .set multi_stage_recurse1.has_recursion, 1
; GCN: .set multi_stage_recurse1.has_indirect_call, 0
; GCN: TotalNumSgprs: multi_stage_recurse1.numbered_sgpr+4
; GCN: NumVgprs: max(48, amdgpu.max_num_vgpr)
; GCN: ScratchSize: 16
define void @multi_stage_recurse1(i32 %val) #2 {
call void @multi_stage_recurse2(i32 %val)
call void asm sideeffect "", "~{v47}"() #0
ret void
}
define void @multi_stage_recurse2(i32 %val) #2 {
call void @multi_stage_recurse1(i32 %val)
call void asm sideeffect "", "~{v42}"() #0
ret void
}

; GCN-LABEL: {{^}}usage_multi_stage_recurse:
; GCN: .set usage_multi_stage_recurse.num_vgpr, max(32, multi_stage_recurse1.num_vgpr)
; GCN: .set usage_multi_stage_recurse.num_agpr, max(0, multi_stage_recurse1.num_agpr)
; GCN: .set usage_multi_stage_recurse.numbered_sgpr, max(33, multi_stage_recurse1.numbered_sgpr)
; GCN: .set usage_multi_stage_recurse.private_seg_size, 0+(max(multi_stage_recurse1.private_seg_size))
; GCN: .set usage_multi_stage_recurse.uses_vcc, or(1, multi_stage_recurse1.uses_vcc)
; GCN: .set usage_multi_stage_recurse.uses_flat_scratch, or(1, multi_stage_recurse1.uses_flat_scratch)
; GCN: .set usage_multi_stage_recurse.has_dyn_sized_stack, or(0, multi_stage_recurse1.has_dyn_sized_stack)
; GCN: .set usage_multi_stage_recurse.has_recursion, or(1, multi_stage_recurse1.has_recursion)
; GCN: .set usage_multi_stage_recurse.has_indirect_call, or(0, multi_stage_recurse1.has_indirect_call)
; GCN: TotalNumSgprs: usage_multi_stage_recurse.numbered_sgpr+6
; GCN: NumVgprs: usage_multi_stage_recurse.num_vgpr
; GCN: ScratchSize: 16
define amdgpu_kernel void @usage_multi_stage_recurse(i32 %n) #0 {
call void @multi_stage_recurse1(i32 %n)
ret void
}

; GCN-LABEL: {{^}}multi_stage_recurse_noattr2:
; GCN: .set multi_stage_recurse_noattr2.num_vgpr, max(41, multi_stage_recurse_noattr1.num_vgpr)
; GCN: .set multi_stage_recurse_noattr2.num_agpr, max(0, multi_stage_recurse_noattr1.num_agpr)
; GCN: .set multi_stage_recurse_noattr2.numbered_sgpr, max(54, multi_stage_recurse_noattr1.numbered_sgpr)
; GCN: .set multi_stage_recurse_noattr2.private_seg_size, 16+(max(multi_stage_recurse_noattr1.private_seg_size))
; GCN: .set multi_stage_recurse_noattr2.uses_vcc, or(1, multi_stage_recurse_noattr1.uses_vcc)
; GCN: .set multi_stage_recurse_noattr2.uses_flat_scratch, or(0, multi_stage_recurse_noattr1.uses_flat_scratch)
; GCN: .set multi_stage_recurse_noattr2.has_dyn_sized_stack, or(0, multi_stage_recurse_noattr1.has_dyn_sized_stack)
; GCN: .set multi_stage_recurse_noattr2.has_recursion, or(0, multi_stage_recurse_noattr1.has_recursion)
; GCN: .set multi_stage_recurse_noattr2.has_indirect_call, or(0, multi_stage_recurse_noattr1.has_indirect_call)
; GCN: TotalNumSgprs: multi_stage_recurse_noattr2.numbered_sgpr+(extrasgprs(multi_stage_recurse_noattr2.uses_vcc, multi_stage_recurse_noattr2.uses_flat_scratch, 1))
; GCN: NumVgprs: max(41, multi_stage_recurse_noattr1.num_vgpr)
; GCN: ScratchSize: 16+(max(multi_stage_recurse_noattr1.private_seg_size))
; GCN-LABEL: {{^}}multi_stage_recurse_noattr1:
; GCN: .set multi_stage_recurse_noattr1.num_vgpr, max(41, amdgpu.max_num_vgpr)
; GCN: .set multi_stage_recurse_noattr1.num_agpr, max(0, amdgpu.max_num_agpr)
; GCN: .set multi_stage_recurse_noattr1.numbered_sgpr, max(57, amdgpu.max_num_sgpr)
; GCN: .set multi_stage_recurse_noattr1.private_seg_size, 16
; GCN: .set multi_stage_recurse_noattr1.uses_vcc, 1
; GCN: .set multi_stage_recurse_noattr1.uses_flat_scratch, 0
; GCN: .set multi_stage_recurse_noattr1.has_dyn_sized_stack, 0
; GCN: .set multi_stage_recurse_noattr1.has_recursion, 0
; GCN: .set multi_stage_recurse_noattr1.has_indirect_call, 0
; GCN: TotalNumSgprs: multi_stage_recurse_noattr1.numbered_sgpr+4
; GCN: NumVgprs: max(41, amdgpu.max_num_vgpr)
; GCN: ScratchSize: 16
define void @multi_stage_recurse_noattr1(i32 %val) #0 {
call void @multi_stage_recurse_noattr2(i32 %val)
call void asm sideeffect "", "~{s56}"() #0
ret void
}
define void @multi_stage_recurse_noattr2(i32 %val) #0 {
call void @multi_stage_recurse_noattr1(i32 %val)
call void asm sideeffect "", "~{s53}"() #0
ret void
}

; GCN-LABEL: {{^}}usage_multi_stage_recurse_noattrs:
; GCN: .set usage_multi_stage_recurse_noattrs.num_vgpr, max(32, multi_stage_recurse_noattr1.num_vgpr)
; GCN: .set usage_multi_stage_recurse_noattrs.num_agpr, max(0, multi_stage_recurse_noattr1.num_agpr)
; GCN: .set usage_multi_stage_recurse_noattrs.numbered_sgpr, max(33, multi_stage_recurse_noattr1.numbered_sgpr)
; GCN: .set usage_multi_stage_recurse_noattrs.private_seg_size, 0+(max(multi_stage_recurse_noattr1.private_seg_size))
; GCN: .set usage_multi_stage_recurse_noattrs.uses_vcc, or(1, multi_stage_recurse_noattr1.uses_vcc)
; GCN: .set usage_multi_stage_recurse_noattrs.uses_flat_scratch, or(1, multi_stage_recurse_noattr1.uses_flat_scratch)
; GCN: .set usage_multi_stage_recurse_noattrs.has_dyn_sized_stack, or(0, multi_stage_recurse_noattr1.has_dyn_sized_stack)
; GCN: .set usage_multi_stage_recurse_noattrs.has_recursion, or(0, multi_stage_recurse_noattr1.has_recursion)
; GCN: .set usage_multi_stage_recurse_noattrs.has_indirect_call, or(0, multi_stage_recurse_noattr1.has_indirect_call)
; GCN: TotalNumSgprs: usage_multi_stage_recurse_noattrs.numbered_sgpr+6
; GCN: NumVgprs: usage_multi_stage_recurse_noattrs.num_vgpr
; GCN: ScratchSize: 16
define amdgpu_kernel void @usage_multi_stage_recurse_noattrs(i32 %n) #0 {
call void @multi_stage_recurse_noattr1(i32 %n)
ret void
}

; GCN-LABEL: {{^}}multi_call_with_multi_stage_recurse:
; GCN: .set multi_call_with_multi_stage_recurse.num_vgpr, max(41, use_stack0.num_vgpr, use_stack1.num_vgpr, multi_stage_recurse1.num_vgpr)
; GCN: .set multi_call_with_multi_stage_recurse.num_agpr, max(0, use_stack0.num_agpr, use_stack1.num_agpr, multi_stage_recurse1.num_agpr)
; GCN: .set multi_call_with_multi_stage_recurse.numbered_sgpr, max(43, use_stack0.numbered_sgpr, use_stack1.numbered_sgpr, multi_stage_recurse1.numbered_sgpr)
; GCN: .set multi_call_with_multi_stage_recurse.private_seg_size, 0+(max(use_stack0.private_seg_size, use_stack1.private_seg_size, multi_stage_recurse1.private_seg_size))
; GCN: .set multi_call_with_multi_stage_recurse.uses_vcc, or(1, use_stack0.uses_vcc, use_stack1.uses_vcc, multi_stage_recurse1.uses_vcc)
; GCN: .set multi_call_with_multi_stage_recurse.uses_flat_scratch, or(1, use_stack0.uses_flat_scratch, use_stack1.uses_flat_scratch, multi_stage_recurse1.uses_flat_scratch)
; GCN: .set multi_call_with_multi_stage_recurse.has_dyn_sized_stack, or(0, use_stack0.has_dyn_sized_stack, use_stack1.has_dyn_sized_stack, multi_stage_recurse1.has_dyn_sized_stack)
; GCN: .set multi_call_with_multi_stage_recurse.has_recursion, or(1, use_stack0.has_recursion, use_stack1.has_recursion, multi_stage_recurse1.has_recursion)
; GCN: .set multi_call_with_multi_stage_recurse.has_indirect_call, or(0, use_stack0.has_indirect_call, use_stack1.has_indirect_call, multi_stage_recurse1.has_indirect_call)
; GCN: TotalNumSgprs: multi_call_with_multi_stage_recurse.numbered_sgpr+6
; GCN: NumVgprs: multi_call_with_multi_stage_recurse.num_vgpr
; GCN: ScratchSize: 2052
define amdgpu_kernel void @multi_call_with_multi_stage_recurse(i32 %n) #0 {
call void @use_stack0()
call void @use_stack1()
call void @multi_stage_recurse1(i32 %n)
ret void
}

; Make sure there's no assert when a sgpr96 is used.
; GCN-LABEL: {{^}}count_use_sgpr96_external_call
; GCN: .set count_use_sgpr96_external_call.num_vgpr, max(32, amdgpu.max_num_vgpr)