-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[UniformAnalysis] Use Immediate postDom as last join #140013
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Given a divergent block, computeJoinPoints uses FloorIdx to do early stopping. But it is incorrect for some cases (shown in the two new lit tests). This change uses the immediate post-dominator as the last join to check for early stopping. It adds post-dominator to genericUniformityImpl in order to get immediate postDom.
@llvm/pr-subscribers-llvm-adt @llvm/pr-subscribers-backend-amdgpu Author: Junjie Gu (jgu222) ChangesGiven a divergent block, computeJoinPoints uses FloorIdx to do early stopping. But it is incorrect for some cases (shown in the two new lit tests). This change uses the immediate post-dominator as the last join to check for early stopping. It adds post-dominator to genericUniformityImpl in order to get immediate postDom. Patch is 22.95 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140013.diff 8 Files Affected:
diff --git a/llvm/include/llvm/ADT/GenericSSAContext.h b/llvm/include/llvm/ADT/GenericSSAContext.h
index 6aa3a8b9b6e0b..e99d4b1c6dd45 100644
--- a/llvm/include/llvm/ADT/GenericSSAContext.h
+++ b/llvm/include/llvm/ADT/GenericSSAContext.h
@@ -77,6 +77,10 @@ template <typename _FunctionT> class GenericSSAContext {
// a given funciton.
using DominatorTreeT = DominatorTreeBase<BlockT, false>;
+ // A post-dominator tree provides the post-dominance relation between
+ // basic blocks in a given funciton.
+ using PostDominatorTreeT = DominatorTreeBase<BlockT, true>;
+
GenericSSAContext() = default;
GenericSSAContext(const FunctionT *F) : F(F) {}
diff --git a/llvm/include/llvm/ADT/GenericUniformityImpl.h b/llvm/include/llvm/ADT/GenericUniformityImpl.h
index d10355fff1bea..f404577bb7e56 100644
--- a/llvm/include/llvm/ADT/GenericUniformityImpl.h
+++ b/llvm/include/llvm/ADT/GenericUniformityImpl.h
@@ -263,6 +263,7 @@ template <typename ContextT> class GenericSyncDependenceAnalysis {
public:
using BlockT = typename ContextT::BlockT;
using DominatorTreeT = typename ContextT::DominatorTreeT;
+ using PostDominatorTreeT = typename ContextT::PostDominatorTreeT;
using FunctionT = typename ContextT::FunctionT;
using ValueRefT = typename ContextT::ValueRefT;
using InstructionT = typename ContextT::InstructionT;
@@ -296,7 +297,9 @@ template <typename ContextT> class GenericSyncDependenceAnalysis {
using DivergencePropagatorT = DivergencePropagator<ContextT>;
GenericSyncDependenceAnalysis(const ContextT &Context,
- const DominatorTreeT &DT, const CycleInfoT &CI);
+ const DominatorTreeT &DT,
+ const PostDominatorTreeT &PDT,
+ const CycleInfoT &CI);
/// \brief Computes divergent join points and cycle exits caused by branch
/// divergence in \p Term.
@@ -315,6 +318,7 @@ template <typename ContextT> class GenericSyncDependenceAnalysis {
ModifiedPO CyclePO;
const DominatorTreeT &DT;
+ const PostDominatorTreeT &PDT;
const CycleInfoT &CI;
DenseMap<const BlockT *, std::unique_ptr<DivergenceDescriptor>>
@@ -336,6 +340,7 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
using UseT = typename ContextT::UseT;
using InstructionT = typename ContextT::InstructionT;
using DominatorTreeT = typename ContextT::DominatorTreeT;
+ using PostDominatorTreeT = typename ContextT::PostDominatorTreeT;
using CycleInfoT = GenericCycleInfo<ContextT>;
using CycleT = typename CycleInfoT::CycleT;
@@ -348,10 +353,12 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
using TemporalDivergenceTuple =
std::tuple<ConstValueRefT, InstructionT *, const CycleT *>;
- GenericUniformityAnalysisImpl(const DominatorTreeT &DT, const CycleInfoT &CI,
+ GenericUniformityAnalysisImpl(const DominatorTreeT &DT,
+ const PostDominatorTreeT &PDT,
+ const CycleInfoT &CI,
const TargetTransformInfo *TTI)
: Context(CI.getSSAContext()), F(*Context.getFunction()), CI(CI),
- TTI(TTI), DT(DT), SDA(Context, DT, CI) {}
+ TTI(TTI), DT(DT), PDT(PDT), SDA(Context, DT, PDT, CI) {}
void initialize();
@@ -435,6 +442,7 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
private:
const DominatorTreeT &DT;
+ const PostDominatorTreeT &PDT;
// Recognized cycles with divergent exits.
SmallPtrSet<const CycleT *, 16> DivergentExitCycles;
@@ -493,6 +501,7 @@ template <typename ContextT> class DivergencePropagator {
public:
using BlockT = typename ContextT::BlockT;
using DominatorTreeT = typename ContextT::DominatorTreeT;
+ using PostDominatorTreeT = typename ContextT::PostDominatorTreeT;
using FunctionT = typename ContextT::FunctionT;
using ValueRefT = typename ContextT::ValueRefT;
@@ -507,6 +516,7 @@ template <typename ContextT> class DivergencePropagator {
const ModifiedPO &CyclePOT;
const DominatorTreeT &DT;
+ const PostDominatorTreeT &PDT;
const CycleInfoT &CI;
const BlockT &DivTermBlock;
const ContextT &Context;
@@ -522,10 +532,11 @@ template <typename ContextT> class DivergencePropagator {
BlockLabelMapT &BlockLabels;
DivergencePropagator(const ModifiedPO &CyclePOT, const DominatorTreeT &DT,
- const CycleInfoT &CI, const BlockT &DivTermBlock)
- : CyclePOT(CyclePOT), DT(DT), CI(CI), DivTermBlock(DivTermBlock),
- Context(CI.getSSAContext()), DivDesc(new DivergenceDescriptorT),
- BlockLabels(DivDesc->BlockLabels) {}
+ const PostDominatorTreeT &PDT, const CycleInfoT &CI,
+ const BlockT &DivTermBlock)
+ : CyclePOT(CyclePOT), DT(DT), PDT(PDT), CI(CI),
+ DivTermBlock(DivTermBlock), Context(CI.getSSAContext()),
+ DivDesc(new DivergenceDescriptorT), BlockLabels(DivDesc->BlockLabels) {}
void printDefs(raw_ostream &Out) {
Out << "Propagator::BlockLabels {\n";
@@ -542,6 +553,12 @@ template <typename ContextT> class DivergencePropagator {
Out << "}\n";
}
+ const BlockT *getIPDom(const BlockT *B) {
+ const auto *Node = PDT.getNode(B);
+ const auto *IPDomNode = Node->getIDom();
+ return IPDomNode->getBlock();
+ }
+
// Push a definition (\p PushedLabel) to \p SuccBlock and return whether this
// causes a divergent join.
bool computeJoin(const BlockT &SuccBlock, const BlockT &PushedLabel) {
@@ -610,10 +627,11 @@ template <typename ContextT> class DivergencePropagator {
LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: "
<< Context.print(&DivTermBlock) << "\n");
- // Early stopping criterion
- int FloorIdx = CyclePOT.size() - 1;
- const BlockT *FloorLabel = nullptr;
- int DivTermIdx = CyclePOT.getIndex(&DivTermBlock);
+ // Immediate Post-dominator of DivTermBlock is the last join
+ // to visit.
+ const auto *ImmPDom = getIPDom(&DivTermBlock);
+
+ LLVM_DEBUG(dbgs() << "Last join: " << Context.print(ImmPDom) << "\n");
// Bootstrap with branch targets
auto const *DivTermCycle = CI.getCycle(&DivTermBlock);
@@ -626,34 +644,29 @@ template <typename ContextT> class DivergencePropagator {
LLVM_DEBUG(dbgs() << "\tImmediate divergent cycle exit: "
<< Context.print(SuccBlock) << "\n");
}
- auto SuccIdx = CyclePOT.getIndex(SuccBlock);
visitEdge(*SuccBlock, *SuccBlock);
- FloorIdx = std::min<int>(FloorIdx, SuccIdx);
}
while (true) {
auto BlockIdx = FreshLabels.find_last();
- if (BlockIdx == -1 || BlockIdx < FloorIdx)
+ if (BlockIdx == -1)
break;
LLVM_DEBUG(dbgs() << "Current labels:\n"; printDefs(dbgs()));
FreshLabels.reset(BlockIdx);
- if (BlockIdx == DivTermIdx) {
- LLVM_DEBUG(dbgs() << "Skipping DivTermBlock\n");
+ const auto *Block = CyclePOT[BlockIdx];
+ if (Block == ImmPDom) {
+ LLVM_DEBUG(dbgs() << "Skipping ImmPDom\n");
continue;
}
- const auto *Block = CyclePOT[BlockIdx];
LLVM_DEBUG(dbgs() << "visiting " << Context.print(Block) << " at index "
<< BlockIdx << "\n");
const auto *Label = BlockLabels[Block];
assert(Label);
- bool CausedJoin = false;
- int LoweredFloorIdx = FloorIdx;
-
// If the current block is the header of a reducible cycle that
// contains the divergent branch, then the label should be
// propagated to the cycle exits. Such a header is the "last
@@ -681,28 +694,11 @@ template <typename ContextT> class DivergencePropagator {
if (const auto *BlockCycle = getReducibleParent(Block)) {
SmallVector<BlockT *, 4> BlockCycleExits;
BlockCycle->getExitBlocks(BlockCycleExits);
- for (auto *BlockCycleExit : BlockCycleExits) {
- CausedJoin |= visitCycleExitEdge(*BlockCycleExit, *Label);
- LoweredFloorIdx =
- std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(BlockCycleExit));
- }
+ for (auto *BlockCycleExit : BlockCycleExits)
+ visitCycleExitEdge(*BlockCycleExit, *Label);
} else {
- for (const auto *SuccBlock : successors(Block)) {
- CausedJoin |= visitEdge(*SuccBlock, *Label);
- LoweredFloorIdx =
- std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(SuccBlock));
- }
- }
-
- // Floor update
- if (CausedJoin) {
- // 1. Different labels pushed to successors
- FloorIdx = LoweredFloorIdx;
- } else if (FloorLabel != Label) {
- // 2. No join caused BUT we pushed a label that is different than the
- // last pushed label
- FloorIdx = LoweredFloorIdx;
- FloorLabel = Label;
+ for (const auto *SuccBlock : successors(Block))
+ visitEdge(*SuccBlock, *Label);
}
}
@@ -742,8 +738,9 @@ typename llvm::GenericSyncDependenceAnalysis<ContextT>::DivergenceDescriptor
template <typename ContextT>
llvm::GenericSyncDependenceAnalysis<ContextT>::GenericSyncDependenceAnalysis(
- const ContextT &Context, const DominatorTreeT &DT, const CycleInfoT &CI)
- : CyclePO(Context), DT(DT), CI(CI) {
+ const ContextT &Context, const DominatorTreeT &DT,
+ const PostDominatorTreeT &PDT, const CycleInfoT &CI)
+ : CyclePO(Context), DT(DT), PDT(PDT), CI(CI) {
CyclePO.compute(CI);
}
@@ -761,7 +758,7 @@ auto llvm::GenericSyncDependenceAnalysis<ContextT>::getJoinBlocks(
return *ItCached->second;
// compute all join points
- DivergencePropagatorT Propagator(CyclePO, DT, CI, *DivTermBlock);
+ DivergencePropagatorT Propagator(CyclePO, DT, PDT, CI, *DivTermBlock);
auto DivDesc = Propagator.computeJoinPoints();
auto printBlockSet = [&](ConstBlockSet &Blocks) {
@@ -1155,9 +1152,9 @@ bool GenericUniformityAnalysisImpl<ContextT>::isAlwaysUniform(
template <typename ContextT>
GenericUniformityInfo<ContextT>::GenericUniformityInfo(
- const DominatorTreeT &DT, const CycleInfoT &CI,
- const TargetTransformInfo *TTI) {
- DA.reset(new ImplT{DT, CI, TTI});
+ const DominatorTreeT &DT, const PostDominatorTreeT &PDT,
+ const CycleInfoT &CI, const TargetTransformInfo *TTI) {
+ DA.reset(new ImplT{DT, PDT, CI, TTI});
}
template <typename ContextT>
diff --git a/llvm/include/llvm/ADT/GenericUniformityInfo.h b/llvm/include/llvm/ADT/GenericUniformityInfo.h
index 9376fa6ee0bae..62d35582823dc 100644
--- a/llvm/include/llvm/ADT/GenericUniformityInfo.h
+++ b/llvm/include/llvm/ADT/GenericUniformityInfo.h
@@ -35,6 +35,7 @@ template <typename ContextT> class GenericUniformityInfo {
using UseT = typename ContextT::UseT;
using InstructionT = typename ContextT::InstructionT;
using DominatorTreeT = typename ContextT::DominatorTreeT;
+ using PostDominatorTreeT = typename ContextT::PostDominatorTreeT;
using ThisT = GenericUniformityInfo<ContextT>;
using CycleInfoT = GenericCycleInfo<ContextT>;
@@ -43,7 +44,8 @@ template <typename ContextT> class GenericUniformityInfo {
using TemporalDivergenceTuple =
std::tuple<ConstValueRefT, InstructionT *, const CycleT *>;
- GenericUniformityInfo(const DominatorTreeT &DT, const CycleInfoT &CI,
+ GenericUniformityInfo(const DominatorTreeT &DT, const PostDominatorTreeT &PDT,
+ const CycleInfoT &CI,
const TargetTransformInfo *TTI = nullptr);
GenericUniformityInfo() = default;
GenericUniformityInfo(GenericUniformityInfo &&) = default;
diff --git a/llvm/include/llvm/CodeGen/MachineUniformityAnalysis.h b/llvm/include/llvm/CodeGen/MachineUniformityAnalysis.h
index e8c0dc9b43823..03fc9ebfcf442 100644
--- a/llvm/include/llvm/CodeGen/MachineUniformityAnalysis.h
+++ b/llvm/include/llvm/CodeGen/MachineUniformityAnalysis.h
@@ -18,6 +18,7 @@
#include "llvm/CodeGen/MachineCycleAnalysis.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachinePassManager.h"
+#include "llvm/CodeGen/MachinePostDominators.h"
#include "llvm/CodeGen/MachineSSAContext.h"
namespace llvm {
@@ -31,7 +32,8 @@ using MachineUniformityInfo = GenericUniformityInfo<MachineSSAContext>;
/// everything is uniform.
MachineUniformityInfo computeMachineUniformityInfo(
MachineFunction &F, const MachineCycleInfo &cycleInfo,
- const MachineDominatorTree &domTree, bool HasBranchDivergence);
+ const MachineDominatorTree &domTree,
+ const MachinePostDominatorTree &pdomTree, bool HasBranchDivergence);
/// Legacy analysis pass which computes a \ref MachineUniformityInfo.
class MachineUniformityAnalysisPass : public MachineFunctionPass {
diff --git a/llvm/lib/Analysis/UniformityAnalysis.cpp b/llvm/lib/Analysis/UniformityAnalysis.cpp
index 2101fdfacfc8f..a724a8c26d7db 100644
--- a/llvm/lib/Analysis/UniformityAnalysis.cpp
+++ b/llvm/lib/Analysis/UniformityAnalysis.cpp
@@ -9,6 +9,7 @@
#include "llvm/Analysis/UniformityAnalysis.h"
#include "llvm/ADT/GenericUniformityImpl.h"
#include "llvm/Analysis/CycleAnalysis.h"
+#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/InstIterator.h"
@@ -114,9 +115,10 @@ template struct llvm::GenericUniformityAnalysisImplDeleter<
llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F,
FunctionAnalysisManager &FAM) {
auto &DT = FAM.getResult<DominatorTreeAnalysis>(F);
+ auto &PDT = FAM.getResult<PostDominatorTreeAnalysis>(F);
auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
auto &CI = FAM.getResult<CycleAnalysis>(F);
- UniformityInfo UI{DT, CI, &TTI};
+ UniformityInfo UI{DT, PDT, CI, &TTI};
// Skip computation if we can assume everything is uniform.
if (TTI.hasBranchDivergence(&F))
UI.compute();
@@ -148,6 +150,7 @@ UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) {}
INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniformity",
"Uniformity Analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(CycleInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity",
@@ -156,6 +159,7 @@ INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity",
void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
AU.setPreservesAll();
AU.addRequired<DominatorTreeWrapperPass>();
+ AU.addRequired<PostDominatorTreeWrapperPass>();
AU.addRequiredTransitive<CycleInfoWrapperPass>();
AU.addRequired<TargetTransformInfoWrapperPass>();
}
@@ -163,11 +167,13 @@ void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
bool UniformityInfoWrapperPass::runOnFunction(Function &F) {
auto &cycleInfo = getAnalysis<CycleInfoWrapperPass>().getResult();
auto &domTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+ auto &pdomTree = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
auto &targetTransformInfo =
getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
m_function = &F;
- m_uniformityInfo = UniformityInfo{domTree, cycleInfo, &targetTransformInfo};
+ m_uniformityInfo =
+ UniformityInfo{domTree, pdomTree, cycleInfo, &targetTransformInfo};
// Skip computation if we can assume everything is uniform.
if (targetTransformInfo.hasBranchDivergence(m_function))
diff --git a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
index e4b82ce83fda6..b87f8357ecfa8 100644
--- a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
+++ b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
@@ -11,6 +11,7 @@
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/MachineCycleAnalysis.h"
#include "llvm/CodeGen/MachineDominators.h"
+#include "llvm/CodeGen/MachinePostDominators.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/MachineSSAContext.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
@@ -156,9 +157,10 @@ template struct llvm::GenericUniformityAnalysisImplDeleter<
MachineUniformityInfo llvm::computeMachineUniformityInfo(
MachineFunction &F, const MachineCycleInfo &cycleInfo,
- const MachineDominatorTree &domTree, bool HasBranchDivergence) {
+ const MachineDominatorTree &domTree,
+ const MachinePostDominatorTree &pdomTree, bool HasBranchDivergence) {
assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!");
- MachineUniformityInfo UI(domTree, cycleInfo);
+ MachineUniformityInfo UI(domTree, pdomTree, cycleInfo);
if (HasBranchDivergence)
UI.compute();
return UI;
@@ -184,12 +186,13 @@ MachineUniformityAnalysis::Result
MachineUniformityAnalysis::run(MachineFunction &MF,
MachineFunctionAnalysisManager &MFAM) {
auto &DomTree = MFAM.getResult<MachineDominatorTreeAnalysis>(MF);
+ auto &PDomTree = MFAM.getResult<MachinePostDominatorTreeAnalysis>(MF);
auto &CI = MFAM.getResult<MachineCycleAnalysis>(MF);
auto &FAM = MFAM.getResult<FunctionAnalysisManagerMachineFunctionProxy>(MF)
.getManager();
auto &F = MF.getFunction();
auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
- return computeMachineUniformityInfo(MF, CI, DomTree,
+ return computeMachineUniformityInfo(MF, CI, DomTree, PDomTree,
TTI.hasBranchDivergence(&F));
}
@@ -215,6 +218,7 @@ INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity",
"Machine Uniformity Info Analysis", false, true)
INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTreeWrapperPass)
INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity",
"Machine Uniformity Info Analysis", false, true)
@@ -222,15 +226,18 @@ void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const {
AU.setPreservesAll();
AU.addRequiredTransitive<MachineCycleInfoWrapperPass>();
AU.addRequired<MachineDominatorTreeWrapperPass>();
+ AU.addRequired<MachinePostDominatorTreeWrapperPass>();
MachineFunctionPass::getAnalysisUsage(AU);
}
bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) {
auto &DomTree = getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
+ auto &PDomTree =
+ getAnalysis<MachinePostDominatorTreeWrapperPass>().getPostDomTree();
auto &CI = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo();
// FIXME: Query TTI::hasBranchDivergence. -run-pass seems to end up with a
// default NoTTI
- UI = computeMachineUniformityInfo(MF, CI, DomTree, true);
+ UI = computeMachineUniformityInfo(MF, CI, DomTree, PDomTree, true);
return false;
}
diff --git a/llvm/test/Analysis/UniformityAnalysis/AMDGPU/phi_div_branch.ll b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/phi_div_branch.ll
new file mode 100644
index 0000000000000..df949a86635c4
--- /dev/null
+++ b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/phi_div_branch.ll
@@ -0,0 +1,78 @@
+;
+; RUN: opt -mtriple amdgcn-- -passes='print<uniformity>' -disable-output %s 2>&1 | FileCheck %s
+;
+; This is to test an if-then-else case with some unmerged basic blocks
+; (https://github.com/llvm/llvm-project/issues/137277)
+;
+; Entry (div.cond)
+; / \
+; B0 B3
+; | |
+; B1 B4
+; | |
+; B2 B5
+; \ /
+; B6 (phi: divergent)
+;
+
+
+; CHECK-LABEL: 'test_ctrl_divergence':
+; CHECK-LABEL: BLOCK Entry
+; CHECK: DIVERGENT: %div.cond = icmp eq i32 %tid, 0
+; CHECK: DIVERGENT: br i1 %div.cond, label %B3, label %B0
+;
+; CHECK-LABEL: BLOCK B6
+; CHECK: DIVERGENT: %div_a = phi i32 [ %a0, %B2 ], [ %a1, %B5 ]
+; CHECK: DIVERGENT: %div_b = phi i32 [ %b0, %B2 ], [ %b1, %B5 ]
+; CHECK: DIVERGENT: %div_c = phi i32 [ %c0, %B2 ], [ %c1, %B5 ]
+
+
+define amdgpu_kernel void @test_ctrl_divergence(i32 %a, i32 %b, i32 %c, i32 %d) {
+Entry:
+ %tid = call i32 @llvm.amdgcn.workitem.id.x()
+ %div.cond = icmp eq i32 %tid, 0
+ br i1 %div.cond, label %B3, label %B0 ; divergent branch
+
+B0:
+ %a0 = add i32 %a, 1
+ br label %B1
+
+B1:
+ %...
[truncated]
|
@llvm/pr-subscribers-llvm-analysis Author: Junjie Gu (jgu222) ChangesGiven a divergent block, computeJoinPoints uses FloorIdx to do early stopping. But it is incorrect for some cases (shown in the two new lit tests). This change uses the immediate post-dominator as the last join to check for early stopping. It adds post-dominator to genericUniformityImpl in order to get immediate postDom. Patch is 22.95 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140013.diff 8 Files Affected:
diff --git a/llvm/include/llvm/ADT/GenericSSAContext.h b/llvm/include/llvm/ADT/GenericSSAContext.h
index 6aa3a8b9b6e0b..e99d4b1c6dd45 100644
--- a/llvm/include/llvm/ADT/GenericSSAContext.h
+++ b/llvm/include/llvm/ADT/GenericSSAContext.h
@@ -77,6 +77,10 @@ template <typename _FunctionT> class GenericSSAContext {
// a given funciton.
using DominatorTreeT = DominatorTreeBase<BlockT, false>;
+ // A post-dominator tree provides the post-dominance relation between
+ // basic blocks in a given funciton.
+ using PostDominatorTreeT = DominatorTreeBase<BlockT, true>;
+
GenericSSAContext() = default;
GenericSSAContext(const FunctionT *F) : F(F) {}
diff --git a/llvm/include/llvm/ADT/GenericUniformityImpl.h b/llvm/include/llvm/ADT/GenericUniformityImpl.h
index d10355fff1bea..f404577bb7e56 100644
--- a/llvm/include/llvm/ADT/GenericUniformityImpl.h
+++ b/llvm/include/llvm/ADT/GenericUniformityImpl.h
@@ -263,6 +263,7 @@ template <typename ContextT> class GenericSyncDependenceAnalysis {
public:
using BlockT = typename ContextT::BlockT;
using DominatorTreeT = typename ContextT::DominatorTreeT;
+ using PostDominatorTreeT = typename ContextT::PostDominatorTreeT;
using FunctionT = typename ContextT::FunctionT;
using ValueRefT = typename ContextT::ValueRefT;
using InstructionT = typename ContextT::InstructionT;
@@ -296,7 +297,9 @@ template <typename ContextT> class GenericSyncDependenceAnalysis {
using DivergencePropagatorT = DivergencePropagator<ContextT>;
GenericSyncDependenceAnalysis(const ContextT &Context,
- const DominatorTreeT &DT, const CycleInfoT &CI);
+ const DominatorTreeT &DT,
+ const PostDominatorTreeT &PDT,
+ const CycleInfoT &CI);
/// \brief Computes divergent join points and cycle exits caused by branch
/// divergence in \p Term.
@@ -315,6 +318,7 @@ template <typename ContextT> class GenericSyncDependenceAnalysis {
ModifiedPO CyclePO;
const DominatorTreeT &DT;
+ const PostDominatorTreeT &PDT;
const CycleInfoT &CI;
DenseMap<const BlockT *, std::unique_ptr<DivergenceDescriptor>>
@@ -336,6 +340,7 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
using UseT = typename ContextT::UseT;
using InstructionT = typename ContextT::InstructionT;
using DominatorTreeT = typename ContextT::DominatorTreeT;
+ using PostDominatorTreeT = typename ContextT::PostDominatorTreeT;
using CycleInfoT = GenericCycleInfo<ContextT>;
using CycleT = typename CycleInfoT::CycleT;
@@ -348,10 +353,12 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
using TemporalDivergenceTuple =
std::tuple<ConstValueRefT, InstructionT *, const CycleT *>;
- GenericUniformityAnalysisImpl(const DominatorTreeT &DT, const CycleInfoT &CI,
+ GenericUniformityAnalysisImpl(const DominatorTreeT &DT,
+ const PostDominatorTreeT &PDT,
+ const CycleInfoT &CI,
const TargetTransformInfo *TTI)
: Context(CI.getSSAContext()), F(*Context.getFunction()), CI(CI),
- TTI(TTI), DT(DT), SDA(Context, DT, CI) {}
+ TTI(TTI), DT(DT), PDT(PDT), SDA(Context, DT, PDT, CI) {}
void initialize();
@@ -435,6 +442,7 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
private:
const DominatorTreeT &DT;
+ const PostDominatorTreeT &PDT;
// Recognized cycles with divergent exits.
SmallPtrSet<const CycleT *, 16> DivergentExitCycles;
@@ -493,6 +501,7 @@ template <typename ContextT> class DivergencePropagator {
public:
using BlockT = typename ContextT::BlockT;
using DominatorTreeT = typename ContextT::DominatorTreeT;
+ using PostDominatorTreeT = typename ContextT::PostDominatorTreeT;
using FunctionT = typename ContextT::FunctionT;
using ValueRefT = typename ContextT::ValueRefT;
@@ -507,6 +516,7 @@ template <typename ContextT> class DivergencePropagator {
const ModifiedPO &CyclePOT;
const DominatorTreeT &DT;
+ const PostDominatorTreeT &PDT;
const CycleInfoT &CI;
const BlockT &DivTermBlock;
const ContextT &Context;
@@ -522,10 +532,11 @@ template <typename ContextT> class DivergencePropagator {
BlockLabelMapT &BlockLabels;
DivergencePropagator(const ModifiedPO &CyclePOT, const DominatorTreeT &DT,
- const CycleInfoT &CI, const BlockT &DivTermBlock)
- : CyclePOT(CyclePOT), DT(DT), CI(CI), DivTermBlock(DivTermBlock),
- Context(CI.getSSAContext()), DivDesc(new DivergenceDescriptorT),
- BlockLabels(DivDesc->BlockLabels) {}
+ const PostDominatorTreeT &PDT, const CycleInfoT &CI,
+ const BlockT &DivTermBlock)
+ : CyclePOT(CyclePOT), DT(DT), PDT(PDT), CI(CI),
+ DivTermBlock(DivTermBlock), Context(CI.getSSAContext()),
+ DivDesc(new DivergenceDescriptorT), BlockLabels(DivDesc->BlockLabels) {}
void printDefs(raw_ostream &Out) {
Out << "Propagator::BlockLabels {\n";
@@ -542,6 +553,12 @@ template <typename ContextT> class DivergencePropagator {
Out << "}\n";
}
+ const BlockT *getIPDom(const BlockT *B) {
+ const auto *Node = PDT.getNode(B);
+ const auto *IPDomNode = Node->getIDom();
+ return IPDomNode->getBlock();
+ }
+
// Push a definition (\p PushedLabel) to \p SuccBlock and return whether this
// causes a divergent join.
bool computeJoin(const BlockT &SuccBlock, const BlockT &PushedLabel) {
@@ -610,10 +627,11 @@ template <typename ContextT> class DivergencePropagator {
LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: "
<< Context.print(&DivTermBlock) << "\n");
- // Early stopping criterion
- int FloorIdx = CyclePOT.size() - 1;
- const BlockT *FloorLabel = nullptr;
- int DivTermIdx = CyclePOT.getIndex(&DivTermBlock);
+ // Immediate Post-dominator of DivTermBlock is the last join
+ // to visit.
+ const auto *ImmPDom = getIPDom(&DivTermBlock);
+
+ LLVM_DEBUG(dbgs() << "Last join: " << Context.print(ImmPDom) << "\n");
// Bootstrap with branch targets
auto const *DivTermCycle = CI.getCycle(&DivTermBlock);
@@ -626,34 +644,29 @@ template <typename ContextT> class DivergencePropagator {
LLVM_DEBUG(dbgs() << "\tImmediate divergent cycle exit: "
<< Context.print(SuccBlock) << "\n");
}
- auto SuccIdx = CyclePOT.getIndex(SuccBlock);
visitEdge(*SuccBlock, *SuccBlock);
- FloorIdx = std::min<int>(FloorIdx, SuccIdx);
}
while (true) {
auto BlockIdx = FreshLabels.find_last();
- if (BlockIdx == -1 || BlockIdx < FloorIdx)
+ if (BlockIdx == -1)
break;
LLVM_DEBUG(dbgs() << "Current labels:\n"; printDefs(dbgs()));
FreshLabels.reset(BlockIdx);
- if (BlockIdx == DivTermIdx) {
- LLVM_DEBUG(dbgs() << "Skipping DivTermBlock\n");
+ const auto *Block = CyclePOT[BlockIdx];
+ if (Block == ImmPDom) {
+ LLVM_DEBUG(dbgs() << "Skipping ImmPDom\n");
continue;
}
- const auto *Block = CyclePOT[BlockIdx];
LLVM_DEBUG(dbgs() << "visiting " << Context.print(Block) << " at index "
<< BlockIdx << "\n");
const auto *Label = BlockLabels[Block];
assert(Label);
- bool CausedJoin = false;
- int LoweredFloorIdx = FloorIdx;
-
// If the current block is the header of a reducible cycle that
// contains the divergent branch, then the label should be
// propagated to the cycle exits. Such a header is the "last
@@ -681,28 +694,11 @@ template <typename ContextT> class DivergencePropagator {
if (const auto *BlockCycle = getReducibleParent(Block)) {
SmallVector<BlockT *, 4> BlockCycleExits;
BlockCycle->getExitBlocks(BlockCycleExits);
- for (auto *BlockCycleExit : BlockCycleExits) {
- CausedJoin |= visitCycleExitEdge(*BlockCycleExit, *Label);
- LoweredFloorIdx =
- std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(BlockCycleExit));
- }
+ for (auto *BlockCycleExit : BlockCycleExits)
+ visitCycleExitEdge(*BlockCycleExit, *Label);
} else {
- for (const auto *SuccBlock : successors(Block)) {
- CausedJoin |= visitEdge(*SuccBlock, *Label);
- LoweredFloorIdx =
- std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(SuccBlock));
- }
- }
-
- // Floor update
- if (CausedJoin) {
- // 1. Different labels pushed to successors
- FloorIdx = LoweredFloorIdx;
- } else if (FloorLabel != Label) {
- // 2. No join caused BUT we pushed a label that is different than the
- // last pushed label
- FloorIdx = LoweredFloorIdx;
- FloorLabel = Label;
+ for (const auto *SuccBlock : successors(Block))
+ visitEdge(*SuccBlock, *Label);
}
}
@@ -742,8 +738,9 @@ typename llvm::GenericSyncDependenceAnalysis<ContextT>::DivergenceDescriptor
template <typename ContextT>
llvm::GenericSyncDependenceAnalysis<ContextT>::GenericSyncDependenceAnalysis(
- const ContextT &Context, const DominatorTreeT &DT, const CycleInfoT &CI)
- : CyclePO(Context), DT(DT), CI(CI) {
+ const ContextT &Context, const DominatorTreeT &DT,
+ const PostDominatorTreeT &PDT, const CycleInfoT &CI)
+ : CyclePO(Context), DT(DT), PDT(PDT), CI(CI) {
CyclePO.compute(CI);
}
@@ -761,7 +758,7 @@ auto llvm::GenericSyncDependenceAnalysis<ContextT>::getJoinBlocks(
return *ItCached->second;
// compute all join points
- DivergencePropagatorT Propagator(CyclePO, DT, CI, *DivTermBlock);
+ DivergencePropagatorT Propagator(CyclePO, DT, PDT, CI, *DivTermBlock);
auto DivDesc = Propagator.computeJoinPoints();
auto printBlockSet = [&](ConstBlockSet &Blocks) {
@@ -1155,9 +1152,9 @@ bool GenericUniformityAnalysisImpl<ContextT>::isAlwaysUniform(
template <typename ContextT>
GenericUniformityInfo<ContextT>::GenericUniformityInfo(
- const DominatorTreeT &DT, const CycleInfoT &CI,
- const TargetTransformInfo *TTI) {
- DA.reset(new ImplT{DT, CI, TTI});
+ const DominatorTreeT &DT, const PostDominatorTreeT &PDT,
+ const CycleInfoT &CI, const TargetTransformInfo *TTI) {
+ DA.reset(new ImplT{DT, PDT, CI, TTI});
}
template <typename ContextT>
diff --git a/llvm/include/llvm/ADT/GenericUniformityInfo.h b/llvm/include/llvm/ADT/GenericUniformityInfo.h
index 9376fa6ee0bae..62d35582823dc 100644
--- a/llvm/include/llvm/ADT/GenericUniformityInfo.h
+++ b/llvm/include/llvm/ADT/GenericUniformityInfo.h
@@ -35,6 +35,7 @@ template <typename ContextT> class GenericUniformityInfo {
using UseT = typename ContextT::UseT;
using InstructionT = typename ContextT::InstructionT;
using DominatorTreeT = typename ContextT::DominatorTreeT;
+ using PostDominatorTreeT = typename ContextT::PostDominatorTreeT;
using ThisT = GenericUniformityInfo<ContextT>;
using CycleInfoT = GenericCycleInfo<ContextT>;
@@ -43,7 +44,8 @@ template <typename ContextT> class GenericUniformityInfo {
using TemporalDivergenceTuple =
std::tuple<ConstValueRefT, InstructionT *, const CycleT *>;
- GenericUniformityInfo(const DominatorTreeT &DT, const CycleInfoT &CI,
+ GenericUniformityInfo(const DominatorTreeT &DT, const PostDominatorTreeT &PDT,
+ const CycleInfoT &CI,
const TargetTransformInfo *TTI = nullptr);
GenericUniformityInfo() = default;
GenericUniformityInfo(GenericUniformityInfo &&) = default;
diff --git a/llvm/include/llvm/CodeGen/MachineUniformityAnalysis.h b/llvm/include/llvm/CodeGen/MachineUniformityAnalysis.h
index e8c0dc9b43823..03fc9ebfcf442 100644
--- a/llvm/include/llvm/CodeGen/MachineUniformityAnalysis.h
+++ b/llvm/include/llvm/CodeGen/MachineUniformityAnalysis.h
@@ -18,6 +18,7 @@
#include "llvm/CodeGen/MachineCycleAnalysis.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachinePassManager.h"
+#include "llvm/CodeGen/MachinePostDominators.h"
#include "llvm/CodeGen/MachineSSAContext.h"
namespace llvm {
@@ -31,7 +32,8 @@ using MachineUniformityInfo = GenericUniformityInfo<MachineSSAContext>;
/// everything is uniform.
MachineUniformityInfo computeMachineUniformityInfo(
MachineFunction &F, const MachineCycleInfo &cycleInfo,
- const MachineDominatorTree &domTree, bool HasBranchDivergence);
+ const MachineDominatorTree &domTree,
+ const MachinePostDominatorTree &pdomTree, bool HasBranchDivergence);
/// Legacy analysis pass which computes a \ref MachineUniformityInfo.
class MachineUniformityAnalysisPass : public MachineFunctionPass {
diff --git a/llvm/lib/Analysis/UniformityAnalysis.cpp b/llvm/lib/Analysis/UniformityAnalysis.cpp
index 2101fdfacfc8f..a724a8c26d7db 100644
--- a/llvm/lib/Analysis/UniformityAnalysis.cpp
+++ b/llvm/lib/Analysis/UniformityAnalysis.cpp
@@ -9,6 +9,7 @@
#include "llvm/Analysis/UniformityAnalysis.h"
#include "llvm/ADT/GenericUniformityImpl.h"
#include "llvm/Analysis/CycleAnalysis.h"
+#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/InstIterator.h"
@@ -114,9 +115,10 @@ template struct llvm::GenericUniformityAnalysisImplDeleter<
llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F,
FunctionAnalysisManager &FAM) {
auto &DT = FAM.getResult<DominatorTreeAnalysis>(F);
+ auto &PDT = FAM.getResult<PostDominatorTreeAnalysis>(F);
auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
auto &CI = FAM.getResult<CycleAnalysis>(F);
- UniformityInfo UI{DT, CI, &TTI};
+ UniformityInfo UI{DT, PDT, CI, &TTI};
// Skip computation if we can assume everything is uniform.
if (TTI.hasBranchDivergence(&F))
UI.compute();
@@ -148,6 +150,7 @@ UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) {}
INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniformity",
"Uniformity Analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(CycleInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity",
@@ -156,6 +159,7 @@ INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity",
void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
AU.setPreservesAll();
AU.addRequired<DominatorTreeWrapperPass>();
+ AU.addRequired<PostDominatorTreeWrapperPass>();
AU.addRequiredTransitive<CycleInfoWrapperPass>();
AU.addRequired<TargetTransformInfoWrapperPass>();
}
@@ -163,11 +167,13 @@ void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
bool UniformityInfoWrapperPass::runOnFunction(Function &F) {
auto &cycleInfo = getAnalysis<CycleInfoWrapperPass>().getResult();
auto &domTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+ auto &pdomTree = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
auto &targetTransformInfo =
getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
m_function = &F;
- m_uniformityInfo = UniformityInfo{domTree, cycleInfo, &targetTransformInfo};
+ m_uniformityInfo =
+ UniformityInfo{domTree, pdomTree, cycleInfo, &targetTransformInfo};
// Skip computation if we can assume everything is uniform.
if (targetTransformInfo.hasBranchDivergence(m_function))
diff --git a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
index e4b82ce83fda6..b87f8357ecfa8 100644
--- a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
+++ b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
@@ -11,6 +11,7 @@
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/MachineCycleAnalysis.h"
#include "llvm/CodeGen/MachineDominators.h"
+#include "llvm/CodeGen/MachinePostDominators.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/MachineSSAContext.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
@@ -156,9 +157,10 @@ template struct llvm::GenericUniformityAnalysisImplDeleter<
MachineUniformityInfo llvm::computeMachineUniformityInfo(
MachineFunction &F, const MachineCycleInfo &cycleInfo,
- const MachineDominatorTree &domTree, bool HasBranchDivergence) {
+ const MachineDominatorTree &domTree,
+ const MachinePostDominatorTree &pdomTree, bool HasBranchDivergence) {
assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!");
- MachineUniformityInfo UI(domTree, cycleInfo);
+ MachineUniformityInfo UI(domTree, pdomTree, cycleInfo);
if (HasBranchDivergence)
UI.compute();
return UI;
@@ -184,12 +186,13 @@ MachineUniformityAnalysis::Result
MachineUniformityAnalysis::run(MachineFunction &MF,
MachineFunctionAnalysisManager &MFAM) {
auto &DomTree = MFAM.getResult<MachineDominatorTreeAnalysis>(MF);
+ auto &PDomTree = MFAM.getResult<MachinePostDominatorTreeAnalysis>(MF);
auto &CI = MFAM.getResult<MachineCycleAnalysis>(MF);
auto &FAM = MFAM.getResult<FunctionAnalysisManagerMachineFunctionProxy>(MF)
.getManager();
auto &F = MF.getFunction();
auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
- return computeMachineUniformityInfo(MF, CI, DomTree,
+ return computeMachineUniformityInfo(MF, CI, DomTree, PDomTree,
TTI.hasBranchDivergence(&F));
}
@@ -215,6 +218,7 @@ INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity",
"Machine Uniformity Info Analysis", false, true)
INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTreeWrapperPass)
INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity",
"Machine Uniformity Info Analysis", false, true)
@@ -222,15 +226,18 @@ void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const {
AU.setPreservesAll();
AU.addRequiredTransitive<MachineCycleInfoWrapperPass>();
AU.addRequired<MachineDominatorTreeWrapperPass>();
+ AU.addRequired<MachinePostDominatorTreeWrapperPass>();
MachineFunctionPass::getAnalysisUsage(AU);
}
bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) {
auto &DomTree = getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
+ auto &PDomTree =
+ getAnalysis<MachinePostDominatorTreeWrapperPass>().getPostDomTree();
auto &CI = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo();
// FIXME: Query TTI::hasBranchDivergence. -run-pass seems to end up with a
// default NoTTI
- UI = computeMachineUniformityInfo(MF, CI, DomTree, true);
+ UI = computeMachineUniformityInfo(MF, CI, DomTree, PDomTree, true);
return false;
}
diff --git a/llvm/test/Analysis/UniformityAnalysis/AMDGPU/phi_div_branch.ll b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/phi_div_branch.ll
new file mode 100644
index 0000000000000..df949a86635c4
--- /dev/null
+++ b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/phi_div_branch.ll
@@ -0,0 +1,78 @@
+;
+; RUN: opt -mtriple amdgcn-- -passes='print<uniformity>' -disable-output %s 2>&1 | FileCheck %s
+;
+; This is to test an if-then-else case with some unmerged basic blocks
+; (https://github.com/llvm/llvm-project/issues/137277)
+;
+; Entry (div.cond)
+; / \
+; B0 B3
+; | |
+; B1 B4
+; | |
+; B2 B5
+; \ /
+; B6 (phi: divergent)
+;
+
+
+; CHECK-LABEL: 'test_ctrl_divergence':
+; CHECK-LABEL: BLOCK Entry
+; CHECK: DIVERGENT: %div.cond = icmp eq i32 %tid, 0
+; CHECK: DIVERGENT: br i1 %div.cond, label %B3, label %B0
+;
+; CHECK-LABEL: BLOCK B6
+; CHECK: DIVERGENT: %div_a = phi i32 [ %a0, %B2 ], [ %a1, %B5 ]
+; CHECK: DIVERGENT: %div_b = phi i32 [ %b0, %B2 ], [ %b1, %B5 ]
+; CHECK: DIVERGENT: %div_c = phi i32 [ %c0, %B2 ], [ %c1, %B5 ]
+
+
+define amdgpu_kernel void @test_ctrl_divergence(i32 %a, i32 %b, i32 %c, i32 %d) {
+Entry:
+ %tid = call i32 @llvm.amdgcn.workitem.id.x()
+ %div.cond = icmp eq i32 %tid, 0
+ br i1 %div.cond, label %B3, label %B0 ; divergent branch
+
+B0:
+ %a0 = add i32 %a, 1
+ br label %B1
+
+B1:
+ %...
[truncated]
|
@@ -0,0 +1,78 @@ | |||
; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
; |
@@ -0,0 +1,82 @@ | |||
; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
; |
|
||
declare i32 @llvm.amdgcn.workitem.id.x() #0 | ||
|
||
attributes #0 = {nounwind readnone } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Out of date attributes for the intrinsic but you can just drop them
Thanks for comments. I will try to come out a modified patch early next week. |
Given a divergent block, computeJoinPoints uses FloorIdx to do early stopping. But it is incorrect for some cases (shown in the two new lit tests).
This change uses the immediate post-dominator as the last join to check for early stopping. It adds post-dominator to genericUniformityImpl in order to get immediate postDom.