[AMDGPU] Avoid resource propagation for recursion through multiple functions #111004

Merged: 6 commits, Oct 11, 2024
90 changes: 83 additions & 7 deletions llvm/lib/Target/AMDGPU/AMDGPUMCResourceInfo.cpp
@@ -91,13 +91,77 @@ MCSymbol *MCResourceInfo::getMaxSGPRSymbol(MCContext &OutContext) {
  return OutContext.getOrCreateSymbol("amdgpu.max_num_sgpr");
}

// The (partially complete) expression should have no recursion in it. After
// all, we're trying to avoid recursion using this codepath. Returns true if
// Sym is found within Expr without recursing on Expr, false otherwise.
static bool findSymbolInExpr(MCSymbol *Sym, const MCExpr *Expr,
                             SmallVectorImpl<const MCExpr *> &Exprs,
                             SmallPtrSetImpl<const MCExpr *> &Visited) {
  // Assert if any of the expressions is already visited (i.e., there is
  // existing recursion).
  if (!Visited.insert(Expr).second)
Contributor:
What's the advantage of this over assert? Is it because llvm_unreachable can be enabled separately?

Contributor:
Doesn't run into the annoying problem of a variable that's only used in asserts builds.

Contributor:
I guess the compiler will optimize out the unused variable if llvm_unreachable is a no-op. So it's

if (!Visited.insert(Expr).second)
  llvm_unreachable("already visited expression");

vs

[[maybe_unused]] bool AlreadyVisited = !Visited.insert(Expr).second;
assert(!AlreadyVisited && "already visited expression");

Roughly equivalent so I guess it's up to taste.

    llvm_unreachable("already visited expression");

  switch (Expr->getKind()) {
  default:
    return false;
  case MCExpr::ExprKind::SymbolRef: {
    const MCSymbolRefExpr *SymRefExpr = cast<MCSymbolRefExpr>(Expr);
    const MCSymbol &SymRef = SymRefExpr->getSymbol();
    if (Sym == &SymRef)
      return true;
    if (SymRef.isVariable())
      Exprs.push_back(SymRef.getVariableValue(/*isUsed=*/false));
    return false;
  }
  case MCExpr::ExprKind::Binary: {
    const MCBinaryExpr *BExpr = cast<MCBinaryExpr>(Expr);
    Exprs.push_back(BExpr->getLHS());
    Exprs.push_back(BExpr->getRHS());
    return false;
  }
  case MCExpr::ExprKind::Unary: {
    const MCUnaryExpr *UExpr = cast<MCUnaryExpr>(Expr);
    Exprs.push_back(UExpr->getSubExpr());
    return false;
  }
  case MCExpr::ExprKind::Target: {
    const AMDGPUMCExpr *AGVK = cast<AMDGPUMCExpr>(Expr);
    for (const MCExpr *E : AGVK->getArgs())
      Exprs.push_back(E);
    return false;
  }
  }
}

// Symbols whose values eventually are used through their defines (i.e.,
// recursive) must be avoided. Do a walk over Expr to see if Sym will occur in
// it. The Expr is an MCExpr given through a callee's equivalent MCSymbol so if
// no recursion is found Sym can be safely assigned to a (sub-)expr which
// contains the symbol Expr is associated with. Returns true if Sym exists
// in Expr or its sub-expressions, false otherwise.
static bool foundRecursiveSymbolDef(MCSymbol *Sym, const MCExpr *Expr) {
  SmallVector<const MCExpr *, 8> WorkList;
  SmallPtrSet<const MCExpr *, 8> Visited;
  WorkList.push_back(Expr);

  while (!WorkList.empty()) {
    const MCExpr *CurExpr = WorkList.pop_back_val();
    if (findSymbolInExpr(Sym, CurExpr, WorkList, Visited))
      return true;
  }

  return false;
}

void MCResourceInfo::assignResourceInfoExpr(
    int64_t LocalValue, ResourceInfoKind RIK, AMDGPUMCExpr::VariantKind Kind,
    const MachineFunction &MF, const SmallVectorImpl<const Function *> &Callees,
    MCContext &OutContext) {
  const MCConstantExpr *LocalConstExpr =
      MCConstantExpr::create(LocalValue, OutContext);
  const MCExpr *SymVal = LocalConstExpr;
  MCSymbol *Sym = getSymbol(MF.getName(), RIK, OutContext);
  if (!Callees.empty()) {
    SmallVector<const MCExpr *, 8> ArgExprs;
    // Avoid recursive symbol assignment.
@@ -110,11 +174,17 @@ void MCResourceInfo::assignResourceInfoExpr(
      if (!Seen.insert(Callee).second)
        continue;
      MCSymbol *CalleeValSym = getSymbol(Callee->getName(), RIK, OutContext);
      ArgExprs.push_back(MCSymbolRefExpr::create(CalleeValSym, OutContext));
      bool CalleeIsVar = CalleeValSym->isVariable();
      if (!CalleeIsVar ||
          (CalleeIsVar &&
           !foundRecursiveSymbolDef(
               Sym, CalleeValSym->getVariableValue(/*IsUsed=*/false)))) {
        ArgExprs.push_back(MCSymbolRefExpr::create(CalleeValSym, OutContext));
Contributor:
There are 2 paths to this push_back, so merge the push_back into one condition
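
For illustration only (not part of the patch), a minimal sketch of how that suggestion could look with the names from the hunk above, folding the redundant CalleeIsVar check into a single condition:

// Sketch: short-circuiting keeps getVariableValue() from being called on
// non-variable symbols, so the CalleeIsVar local is no longer needed and
// there is a single guarded push_back.
if (!CalleeValSym->isVariable() ||
    !foundRecursiveSymbolDef(
        Sym, CalleeValSym->getVariableValue(/*IsUsed=*/false)))
  ArgExprs.push_back(MCSymbolRefExpr::create(CalleeValSym, OutContext));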

      }
    }
    SymVal = AMDGPUMCExpr::create(Kind, ArgExprs, OutContext);
    if (ArgExprs.size() > 1)
      SymVal = AMDGPUMCExpr::create(Kind, ArgExprs, OutContext);
  }
  MCSymbol *Sym = getSymbol(MF.getName(), RIK, OutContext);
  Sym->setVariableValue(SymVal);
}

@@ -155,6 +225,7 @@ void MCResourceInfo::gatherResourceInfo(
    // The expression for private segment size should be: FRI.PrivateSegmentSize
    // + max(FRI.Callees, FRI.CalleeSegmentSize)
    SmallVector<const MCExpr *, 8> ArgExprs;
    MCSymbol *Sym = getSymbol(MF.getName(), RIK_PrivateSegSize, OutContext);
    if (FRI.CalleeSegmentSize)
      ArgExprs.push_back(
          MCConstantExpr::create(FRI.CalleeSegmentSize, OutContext));
@@ -165,9 +236,15 @@
      if (!Seen.insert(Callee).second)
        continue;
      if (!Callee->isDeclaration()) {
        MCSymbol *calleeValSym =
        MCSymbol *CalleeValSym =
            getSymbol(Callee->getName(), RIK_PrivateSegSize, OutContext);
        ArgExprs.push_back(MCSymbolRefExpr::create(calleeValSym, OutContext));
        bool CalleeIsVar = CalleeValSym->isVariable();
        if (!CalleeIsVar ||
            (CalleeIsVar &&
             !foundRecursiveSymbolDef(
                 Sym, CalleeValSym->getVariableValue(/*IsUsed=*/false)))) {
          ArgExprs.push_back(MCSymbolRefExpr::create(CalleeValSym, OutContext));
        }
      }
    }
    const MCExpr *localConstExpr =
@@ -178,8 +255,7 @@
      localConstExpr =
          MCBinaryExpr::createAdd(localConstExpr, transitiveExpr, OutContext);
    }
    getSymbol(MF.getName(), RIK_PrivateSegSize, OutContext)
        ->setVariableValue(localConstExpr);
    Sym->setVariableValue(localConstExpr);
  }

  auto SetToLocal = [&](int64_t LocalValue, ResourceInfoKind RIK) {
126 changes: 126 additions & 0 deletions llvm/test/CodeGen/AMDGPU/function-resource-usage.ll
@@ -481,6 +481,132 @@ define amdgpu_kernel void @usage_direct_recursion(i32 %n) #0 {
ret void
}

; GCN-LABEL: {{^}}multi_stage_recurse2:
; GCN: .set multi_stage_recurse2.num_vgpr, max(41, multi_stage_recurse1.num_vgpr)
; GCN: .set multi_stage_recurse2.num_agpr, max(0, multi_stage_recurse1.num_agpr)
; GCN: .set multi_stage_recurse2.numbered_sgpr, max(34, multi_stage_recurse1.numbered_sgpr)
; GCN: .set multi_stage_recurse2.private_seg_size, 16+(max(multi_stage_recurse1.private_seg_size))
; GCN: .set multi_stage_recurse2.uses_vcc, or(1, multi_stage_recurse1.uses_vcc)
; GCN: .set multi_stage_recurse2.uses_flat_scratch, or(0, multi_stage_recurse1.uses_flat_scratch)
; GCN: .set multi_stage_recurse2.has_dyn_sized_stack, or(0, multi_stage_recurse1.has_dyn_sized_stack)
; GCN: .set multi_stage_recurse2.has_recursion, or(1, multi_stage_recurse1.has_recursion)
; GCN: .set multi_stage_recurse2.has_indirect_call, or(0, multi_stage_recurse1.has_indirect_call)
; GCN: TotalNumSgprs: multi_stage_recurse2.numbered_sgpr+(extrasgprs(multi_stage_recurse2.uses_vcc, multi_stage_recurse2.uses_flat_scratch, 1))
; GCN: NumVgprs: max(41, multi_stage_recurse1.num_vgpr)
; GCN: ScratchSize: 16+(max(multi_stage_recurse1.private_seg_size))
; GCN-LABEL: {{^}}multi_stage_recurse1:
; GCN: .set multi_stage_recurse1.num_vgpr, 41
; GCN: .set multi_stage_recurse1.num_agpr, 0
; GCN: .set multi_stage_recurse1.numbered_sgpr, 34
; GCN: .set multi_stage_recurse1.private_seg_size, 16
; GCN: .set multi_stage_recurse1.uses_vcc, 1
; GCN: .set multi_stage_recurse1.uses_flat_scratch, 0
; GCN: .set multi_stage_recurse1.has_dyn_sized_stack, 0
; GCN: .set multi_stage_recurse1.has_recursion, 1
; GCN: .set multi_stage_recurse1.has_indirect_call, 0
; GCN: TotalNumSgprs: 38
; GCN: NumVgprs: 41
; GCN: ScratchSize: 16
define void @multi_stage_recurse1(i32 %val) #2 {
call void @multi_stage_recurse2(i32 %val)
ret void
}
define void @multi_stage_recurse2(i32 %val) #2 {
call void @multi_stage_recurse1(i32 %val)
ret void
}

; GCN-LABEL: {{^}}usage_multi_stage_recurse:
; GCN: .set usage_multi_stage_recurse.num_vgpr, max(32, multi_stage_recurse1.num_vgpr)
; GCN: .set usage_multi_stage_recurse.num_agpr, max(0, multi_stage_recurse1.num_agpr)
; GCN: .set usage_multi_stage_recurse.numbered_sgpr, max(33, multi_stage_recurse1.numbered_sgpr)
; GCN: .set usage_multi_stage_recurse.private_seg_size, 0+(max(multi_stage_recurse1.private_seg_size))
; GCN: .set usage_multi_stage_recurse.uses_vcc, or(1, multi_stage_recurse1.uses_vcc)
; GCN: .set usage_multi_stage_recurse.uses_flat_scratch, or(1, multi_stage_recurse1.uses_flat_scratch)
; GCN: .set usage_multi_stage_recurse.has_dyn_sized_stack, or(0, multi_stage_recurse1.has_dyn_sized_stack)
; GCN: .set usage_multi_stage_recurse.has_recursion, or(1, multi_stage_recurse1.has_recursion)
; GCN: .set usage_multi_stage_recurse.has_indirect_call, or(0, multi_stage_recurse1.has_indirect_call)
; GCN: TotalNumSgprs: 40
; GCN: NumVgprs: 41
; GCN: ScratchSize: 16
define amdgpu_kernel void @usage_multi_stage_recurse(i32 %n) #0 {
call void @multi_stage_recurse1(i32 %n)
ret void
}

; GCN-LABEL: {{^}}multi_stage_recurse_noattr2:
; GCN: .set multi_stage_recurse_noattr2.num_vgpr, max(41, multi_stage_recurse_noattr1.num_vgpr)
; GCN: .set multi_stage_recurse_noattr2.num_agpr, max(0, multi_stage_recurse_noattr1.num_agpr)
; GCN: .set multi_stage_recurse_noattr2.numbered_sgpr, max(34, multi_stage_recurse_noattr1.numbered_sgpr)
; GCN: .set multi_stage_recurse_noattr2.private_seg_size, 16+(max(multi_stage_recurse_noattr1.private_seg_size))
; GCN: .set multi_stage_recurse_noattr2.uses_vcc, or(1, multi_stage_recurse_noattr1.uses_vcc)
; GCN: .set multi_stage_recurse_noattr2.uses_flat_scratch, or(0, multi_stage_recurse_noattr1.uses_flat_scratch)
; GCN: .set multi_stage_recurse_noattr2.has_dyn_sized_stack, or(0, multi_stage_recurse_noattr1.has_dyn_sized_stack)
; GCN: .set multi_stage_recurse_noattr2.has_recursion, or(0, multi_stage_recurse_noattr1.has_recursion)
; GCN: .set multi_stage_recurse_noattr2.has_indirect_call, or(0, multi_stage_recurse_noattr1.has_indirect_call)
; GCN: TotalNumSgprs: multi_stage_recurse_noattr2.numbered_sgpr+(extrasgprs(multi_stage_recurse_noattr2.uses_vcc, multi_stage_recurse_noattr2.uses_flat_scratch, 1))
; GCN: NumVgprs: max(41, multi_stage_recurse_noattr1.num_vgpr)
; GCN: ScratchSize: 16+(max(multi_stage_recurse_noattr1.private_seg_size))
; GCN-LABEL: {{^}}multi_stage_recurse_noattr1:
; GCN: .set multi_stage_recurse_noattr1.num_vgpr, 41
; GCN: .set multi_stage_recurse_noattr1.num_agpr, 0
; GCN: .set multi_stage_recurse_noattr1.numbered_sgpr, 34
; GCN: .set multi_stage_recurse_noattr1.private_seg_size, 16
; GCN: .set multi_stage_recurse_noattr1.uses_vcc, 1
; GCN: .set multi_stage_recurse_noattr1.uses_flat_scratch, 0
; GCN: .set multi_stage_recurse_noattr1.has_dyn_sized_stack, 0
; GCN: .set multi_stage_recurse_noattr1.has_recursion, 0
; GCN: .set multi_stage_recurse_noattr1.has_indirect_call, 0
; GCN: TotalNumSgprs: 38
; GCN: NumVgprs: 41
; GCN: ScratchSize: 16
define void @multi_stage_recurse_noattr1(i32 %val) #0 {
call void @multi_stage_recurse_noattr2(i32 %val)
ret void
}
define void @multi_stage_recurse_noattr2(i32 %val) #0 {
call void @multi_stage_recurse_noattr1(i32 %val)
ret void
}

; GCN-LABEL: {{^}}usage_multi_stage_recurse_noattrs:
; GCN: .set usage_multi_stage_recurse_noattrs.num_vgpr, max(32, multi_stage_recurse_noattr1.num_vgpr)
; GCN: .set usage_multi_stage_recurse_noattrs.num_agpr, max(0, multi_stage_recurse_noattr1.num_agpr)
; GCN: .set usage_multi_stage_recurse_noattrs.numbered_sgpr, max(33, multi_stage_recurse_noattr1.numbered_sgpr)
; GCN: .set usage_multi_stage_recurse_noattrs.private_seg_size, 0+(max(multi_stage_recurse_noattr1.private_seg_size))
; GCN: .set usage_multi_stage_recurse_noattrs.uses_vcc, or(1, multi_stage_recurse_noattr1.uses_vcc)
; GCN: .set usage_multi_stage_recurse_noattrs.uses_flat_scratch, or(1, multi_stage_recurse_noattr1.uses_flat_scratch)
; GCN: .set usage_multi_stage_recurse_noattrs.has_dyn_sized_stack, or(0, multi_stage_recurse_noattr1.has_dyn_sized_stack)
; GCN: .set usage_multi_stage_recurse_noattrs.has_recursion, or(0, multi_stage_recurse_noattr1.has_recursion)
; GCN: .set usage_multi_stage_recurse_noattrs.has_indirect_call, or(0, multi_stage_recurse_noattr1.has_indirect_call)
; GCN: TotalNumSgprs: 40
; GCN: NumVgprs: 41
; GCN: ScratchSize: 16
define amdgpu_kernel void @usage_multi_stage_recurse_noattrs(i32 %n) #0 {
call void @multi_stage_recurse_noattr1(i32 %n)
ret void
}

; GCN-LABEL: {{^}}multi_call_with_multi_stage_recurse:
; GCN: .set multi_call_with_multi_stage_recurse.num_vgpr, max(41, use_stack0.num_vgpr, use_stack1.num_vgpr, multi_stage_recurse1.num_vgpr)
; GCN: .set multi_call_with_multi_stage_recurse.num_agpr, max(0, use_stack0.num_agpr, use_stack1.num_agpr, multi_stage_recurse1.num_agpr)
; GCN: .set multi_call_with_multi_stage_recurse.numbered_sgpr, max(43, use_stack0.numbered_sgpr, use_stack1.numbered_sgpr, multi_stage_recurse1.numbered_sgpr)
; GCN: .set multi_call_with_multi_stage_recurse.private_seg_size, 0+(max(use_stack0.private_seg_size, use_stack1.private_seg_size, multi_stage_recurse1.private_seg_size))
; GCN: .set multi_call_with_multi_stage_recurse.uses_vcc, or(1, use_stack0.uses_vcc, use_stack1.uses_vcc, multi_stage_recurse1.uses_vcc)
; GCN: .set multi_call_with_multi_stage_recurse.uses_flat_scratch, or(1, use_stack0.uses_flat_scratch, use_stack1.uses_flat_scratch, multi_stage_recurse1.uses_flat_scratch)
; GCN: .set multi_call_with_multi_stage_recurse.has_dyn_sized_stack, or(0, use_stack0.has_dyn_sized_stack, use_stack1.has_dyn_sized_stack, multi_stage_recurse1.has_dyn_sized_stack)
; GCN: .set multi_call_with_multi_stage_recurse.has_recursion, or(1, use_stack0.has_recursion, use_stack1.has_recursion, multi_stage_recurse1.has_recursion)
; GCN: .set multi_call_with_multi_stage_recurse.has_indirect_call, or(0, use_stack0.has_indirect_call, use_stack1.has_indirect_call, multi_stage_recurse1.has_indirect_call)
; GCN: TotalNumSgprs: 49
; GCN: NumVgprs: 41
; GCN: ScratchSize: 2052
define amdgpu_kernel void @multi_call_with_multi_stage_recurse(i32 %n) #0 {
call void @use_stack0()
call void @use_stack1()
call void @multi_stage_recurse1(i32 %n)
ret void
}

; Make sure there's no assert when a sgpr96 is used.
; GCN-LABEL: {{^}}count_use_sgpr96_external_call
; GCN: .set count_use_sgpr96_external_call.num_vgpr, max(32, amdgpu.max_num_vgpr)
85 changes: 85 additions & 0 deletions llvm/test/CodeGen/AMDGPU/recursive-resource-usage-mcexpr.ll
@@ -0,0 +1,85 @@
; RUN: llc -mtriple=amdgcn-amd-amdhsa -mcpu=gfx90a < %s | FileCheck %s

; CHECK-LABEL: {{^}}qux
; CHECK: .set qux.num_vgpr, max(41, foo.num_vgpr)
; CHECK: .set qux.num_agpr, max(0, foo.num_agpr)
; CHECK: .set qux.numbered_sgpr, max(34, foo.numbered_sgpr)
; CHECK: .set qux.private_seg_size, 16
; CHECK: .set qux.uses_vcc, or(1, foo.uses_vcc)
; CHECK: .set qux.uses_flat_scratch, or(0, foo.uses_flat_scratch)
; CHECK: .set qux.has_dyn_sized_stack, or(0, foo.has_dyn_sized_stack)
; CHECK: .set qux.has_recursion, or(1, foo.has_recursion)
; CHECK: .set qux.has_indirect_call, or(0, foo.has_indirect_call)

; CHECK-LABEL: {{^}}baz
; CHECK: .set baz.num_vgpr, max(42, qux.num_vgpr)
; CHECK: .set baz.num_agpr, max(0, qux.num_agpr)
; CHECK: .set baz.numbered_sgpr, max(34, qux.numbered_sgpr)
; CHECK: .set baz.private_seg_size, 16+(max(qux.private_seg_size))
; CHECK: .set baz.uses_vcc, or(1, qux.uses_vcc)
; CHECK: .set baz.uses_flat_scratch, or(0, qux.uses_flat_scratch)
; CHECK: .set baz.has_dyn_sized_stack, or(0, qux.has_dyn_sized_stack)
; CHECK: .set baz.has_recursion, or(1, qux.has_recursion)
; CHECK: .set baz.has_indirect_call, or(0, qux.has_indirect_call)

; CHECK-LABEL: {{^}}bar
; CHECK: .set bar.num_vgpr, max(42, baz.num_vgpr)
; CHECK: .set bar.num_agpr, max(0, baz.num_agpr)
; CHECK: .set bar.numbered_sgpr, max(34, baz.numbered_sgpr)
; CHECK: .set bar.private_seg_size, 16+(max(baz.private_seg_size))
; CHECK: .set bar.uses_vcc, or(1, baz.uses_vcc)
; CHECK: .set bar.uses_flat_scratch, or(0, baz.uses_flat_scratch)
; CHECK: .set bar.has_dyn_sized_stack, or(0, baz.has_dyn_sized_stack)
; CHECK: .set bar.has_recursion, or(1, baz.has_recursion)
; CHECK: .set bar.has_indirect_call, or(0, baz.has_indirect_call)

; CHECK-LABEL: {{^}}foo
; CHECK: .set foo.num_vgpr, 42
; CHECK: .set foo.num_agpr, 0
; CHECK: .set foo.numbered_sgpr, 34
; CHECK: .set foo.private_seg_size, 16
; CHECK: .set foo.uses_vcc, 1
; CHECK: .set foo.uses_flat_scratch, 0
; CHECK: .set foo.has_dyn_sized_stack, 0
; CHECK: .set foo.has_recursion, 1
; CHECK: .set foo.has_indirect_call, 0

define void @foo() {
entry:
call void @bar()
ret void
}

define void @bar() {
entry:
call void @baz()
ret void
}

define void @baz() {
entry:
call void @qux()
ret void
}

define void @qux() {
entry:
call void @foo()
ret void
}

; CHECK-LABEL: {{^}}usefoo
; CHECK: .set usefoo.num_vgpr, max(32, foo.num_vgpr)
; CHECK: .set usefoo.num_agpr, max(0, foo.num_agpr)
; CHECK: .set usefoo.numbered_sgpr, max(33, foo.numbered_sgpr)
; CHECK: .set usefoo.private_seg_size, 0+(max(foo.private_seg_size))
; CHECK: .set usefoo.uses_vcc, or(1, foo.uses_vcc)
; CHECK: .set usefoo.uses_flat_scratch, or(1, foo.uses_flat_scratch)
; CHECK: .set usefoo.has_dyn_sized_stack, or(0, foo.has_dyn_sized_stack)
; CHECK: .set usefoo.has_recursion, or(1, foo.has_recursion)
; CHECK: .set usefoo.has_indirect_call, or(0, foo.has_indirect_call)
define amdgpu_kernel void @usefoo() {
call void @foo()
ret void
}
