[UniformAnalysis] Use Immediate postDom as last join #140013

Open: wants to merge 1 commit into main


jgu222
Contributor

@jgu222 jgu222 commented May 15, 2025

Given a divergent block, computeJoinPoints uses FloorIdx as its early-stopping criterion. This criterion is incorrect in some cases, as shown by the two new lit tests (phi_div_branch.ll and phi_div_loop.ll).

This change instead treats the immediate post-dominator of the divergent block as the last join to visit and stops propagating labels there. To look up the immediate post-dominator, it adds a post-dominator tree to the classes in GenericUniformityImpl.h.
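
As an illustration of what "immediate post-dominator" means for the diamond-shaped CFG in phi_div_branch.ll, here is a minimal standalone sketch. It does not use LLVM's DominatorTreeBase; the helper names (computeIPDoms, makeDiamond), the integer block numbering, and the single-exit assumption are all hypothetical and chosen only for this example.

```cpp
#include <bitset>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy CFG: blocks are indices, Succs[B] lists B's successors.
// Assumes a single exit block that every block can reach.
constexpr std::size_t MaxBlocks = 32;
using BlockSet = std::bitset<MaxBlocks>;

// Returns the immediate post-dominator of every block, or -1 for the exit.
std::vector<int> computeIPDoms(const std::vector<std::vector<int>> &Succs,
                               int Exit) {
  const std::size_t N = Succs.size();
  assert(N <= MaxBlocks);
  BlockSet All;
  for (std::size_t I = 0; I < N; ++I)
    All.set(I);

  // PDom[B] = {B} plus the intersection of PDom[S] over all successors S,
  // iterated until nothing changes.
  std::vector<BlockSet> PDom(N, All);
  PDom[Exit].reset();
  PDom[Exit].set(Exit);
  for (bool Changed = true; Changed;) {
    Changed = false;
    for (std::size_t B = 0; B < N; ++B) {
      if (static_cast<int>(B) == Exit)
        continue;
      BlockSet New = All;
      for (int S : Succs[B])
        New &= PDom[S];
      New.set(B);
      if (New != PDom[B]) {
        PDom[B] = New;
        Changed = true;
      }
    }
  }

  // The strict post-dominators of B form a chain; the closest one has
  // exactly one fewer post-dominator than B itself.
  std::vector<int> IPDom(N, -1);
  for (std::size_t B = 0; B < N; ++B)
    for (std::size_t P = 0; P < N; ++P)
      if (P != B && PDom[B].test(P) && PDom[P].count() + 1 == PDom[B].count())
        IPDom[B] = static_cast<int>(P);
  return IPDom;
}

// Diamond CFG from phi_div_branch.ll, numbered as:
// Entry=0, B0=1, B1=2, B2=3, B3=4, B4=5, B5=6, B6=7.
std::vector<std::vector<int>> makeDiamond() {
  return {{4, 1}, {2}, {3}, {7}, {5}, {6}, {7}, {}};
}
```

Calling computeIPDoms(makeDiamond(), 7) yields B6 (index 7) as the immediate post-dominator of Entry (index 0), which is the block where both sides of the divergent branch reconverge.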

@llvmbot
Member

llvmbot commented May 15, 2025

@llvm/pr-subscribers-llvm-adt

@llvm/pr-subscribers-backend-amdgpu

Author: Junjie Gu (jgu222)

Changes

Given a divergent block, computeJoinPoints uses FloorIdx to do early stopping. But it is incorrect for some cases (shown in the two new lit tests).

This change uses the immediate post-dominator as the last join to check for early stopping. It adds post-dominator to genericUniformityImpl in order to get immediate postDom.


Patch is 22.95 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140013.diff

8 Files Affected:

  • (modified) llvm/include/llvm/ADT/GenericSSAContext.h (+4)
  • (modified) llvm/include/llvm/ADT/GenericUniformityImpl.h (+44-47)
  • (modified) llvm/include/llvm/ADT/GenericUniformityInfo.h (+3-1)
  • (modified) llvm/include/llvm/CodeGen/MachineUniformityAnalysis.h (+3-1)
  • (modified) llvm/lib/Analysis/UniformityAnalysis.cpp (+8-2)
  • (modified) llvm/lib/CodeGen/MachineUniformityAnalysis.cpp (+11-4)
  • (added) llvm/test/Analysis/UniformityAnalysis/AMDGPU/phi_div_branch.ll (+78)
  • (added) llvm/test/Analysis/UniformityAnalysis/AMDGPU/phi_div_loop.ll (+82)
diff --git a/llvm/include/llvm/ADT/GenericSSAContext.h b/llvm/include/llvm/ADT/GenericSSAContext.h
index 6aa3a8b9b6e0b..e99d4b1c6dd45 100644
--- a/llvm/include/llvm/ADT/GenericSSAContext.h
+++ b/llvm/include/llvm/ADT/GenericSSAContext.h
@@ -77,6 +77,10 @@ template <typename _FunctionT> class GenericSSAContext {
   // a given funciton.
   using DominatorTreeT = DominatorTreeBase<BlockT, false>;
 
+  // A post-dominator tree provides the post-dominance relation between
+  // basic blocks in a given function.
+  using PostDominatorTreeT = DominatorTreeBase<BlockT, true>;
+
   GenericSSAContext() = default;
   GenericSSAContext(const FunctionT *F) : F(F) {}
 
diff --git a/llvm/include/llvm/ADT/GenericUniformityImpl.h b/llvm/include/llvm/ADT/GenericUniformityImpl.h
index d10355fff1bea..f404577bb7e56 100644
--- a/llvm/include/llvm/ADT/GenericUniformityImpl.h
+++ b/llvm/include/llvm/ADT/GenericUniformityImpl.h
@@ -263,6 +263,7 @@ template <typename ContextT> class GenericSyncDependenceAnalysis {
 public:
   using BlockT = typename ContextT::BlockT;
   using DominatorTreeT = typename ContextT::DominatorTreeT;
+  using PostDominatorTreeT = typename ContextT::PostDominatorTreeT;
   using FunctionT = typename ContextT::FunctionT;
   using ValueRefT = typename ContextT::ValueRefT;
   using InstructionT = typename ContextT::InstructionT;
@@ -296,7 +297,9 @@ template <typename ContextT> class GenericSyncDependenceAnalysis {
   using DivergencePropagatorT = DivergencePropagator<ContextT>;
 
   GenericSyncDependenceAnalysis(const ContextT &Context,
-                                const DominatorTreeT &DT, const CycleInfoT &CI);
+                                const DominatorTreeT &DT,
+                                const PostDominatorTreeT &PDT,
+                                const CycleInfoT &CI);
 
   /// \brief Computes divergent join points and cycle exits caused by branch
   /// divergence in \p Term.
@@ -315,6 +318,7 @@ template <typename ContextT> class GenericSyncDependenceAnalysis {
   ModifiedPO CyclePO;
 
   const DominatorTreeT &DT;
+  const PostDominatorTreeT &PDT;
   const CycleInfoT &CI;
 
   DenseMap<const BlockT *, std::unique_ptr<DivergenceDescriptor>>
@@ -336,6 +340,7 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
   using UseT = typename ContextT::UseT;
   using InstructionT = typename ContextT::InstructionT;
   using DominatorTreeT = typename ContextT::DominatorTreeT;
+  using PostDominatorTreeT = typename ContextT::PostDominatorTreeT;
 
   using CycleInfoT = GenericCycleInfo<ContextT>;
   using CycleT = typename CycleInfoT::CycleT;
@@ -348,10 +353,12 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
   using TemporalDivergenceTuple =
       std::tuple<ConstValueRefT, InstructionT *, const CycleT *>;
 
-  GenericUniformityAnalysisImpl(const DominatorTreeT &DT, const CycleInfoT &CI,
+  GenericUniformityAnalysisImpl(const DominatorTreeT &DT,
+                                const PostDominatorTreeT &PDT,
+                                const CycleInfoT &CI,
                                 const TargetTransformInfo *TTI)
       : Context(CI.getSSAContext()), F(*Context.getFunction()), CI(CI),
-        TTI(TTI), DT(DT), SDA(Context, DT, CI) {}
+        TTI(TTI), DT(DT), PDT(PDT), SDA(Context, DT, PDT, CI) {}
 
   void initialize();
 
@@ -435,6 +442,7 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
 
 private:
   const DominatorTreeT &DT;
+  const PostDominatorTreeT &PDT;
 
   // Recognized cycles with divergent exits.
   SmallPtrSet<const CycleT *, 16> DivergentExitCycles;
@@ -493,6 +501,7 @@ template <typename ContextT> class DivergencePropagator {
 public:
   using BlockT = typename ContextT::BlockT;
   using DominatorTreeT = typename ContextT::DominatorTreeT;
+  using PostDominatorTreeT = typename ContextT::PostDominatorTreeT;
   using FunctionT = typename ContextT::FunctionT;
   using ValueRefT = typename ContextT::ValueRefT;
 
@@ -507,6 +516,7 @@ template <typename ContextT> class DivergencePropagator {
 
   const ModifiedPO &CyclePOT;
   const DominatorTreeT &DT;
+  const PostDominatorTreeT &PDT;
   const CycleInfoT &CI;
   const BlockT &DivTermBlock;
   const ContextT &Context;
@@ -522,10 +532,11 @@ template <typename ContextT> class DivergencePropagator {
   BlockLabelMapT &BlockLabels;
 
   DivergencePropagator(const ModifiedPO &CyclePOT, const DominatorTreeT &DT,
-                       const CycleInfoT &CI, const BlockT &DivTermBlock)
-      : CyclePOT(CyclePOT), DT(DT), CI(CI), DivTermBlock(DivTermBlock),
-        Context(CI.getSSAContext()), DivDesc(new DivergenceDescriptorT),
-        BlockLabels(DivDesc->BlockLabels) {}
+                       const PostDominatorTreeT &PDT, const CycleInfoT &CI,
+                       const BlockT &DivTermBlock)
+      : CyclePOT(CyclePOT), DT(DT), PDT(PDT), CI(CI),
+        DivTermBlock(DivTermBlock), Context(CI.getSSAContext()),
+        DivDesc(new DivergenceDescriptorT), BlockLabels(DivDesc->BlockLabels) {}
 
   void printDefs(raw_ostream &Out) {
     Out << "Propagator::BlockLabels {\n";
@@ -542,6 +553,12 @@ template <typename ContextT> class DivergencePropagator {
     Out << "}\n";
   }
 
+  const BlockT *getIPDom(const BlockT *B) {
+    const auto *Node = PDT.getNode(B);
+    const auto *IPDomNode = Node->getIDom();
+    return IPDomNode->getBlock();
+  }
+
   // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this
   // causes a divergent join.
   bool computeJoin(const BlockT &SuccBlock, const BlockT &PushedLabel) {
@@ -610,10 +627,11 @@ template <typename ContextT> class DivergencePropagator {
     LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: "
                       << Context.print(&DivTermBlock) << "\n");
 
-    // Early stopping criterion
-    int FloorIdx = CyclePOT.size() - 1;
-    const BlockT *FloorLabel = nullptr;
-    int DivTermIdx = CyclePOT.getIndex(&DivTermBlock);
+    // Immediate Post-dominator of DivTermBlock is the last join
+    // to visit.
+    const auto *ImmPDom = getIPDom(&DivTermBlock);
+
+    LLVM_DEBUG(dbgs() << "Last join: " << Context.print(ImmPDom) << "\n");
 
     // Bootstrap with branch targets
     auto const *DivTermCycle = CI.getCycle(&DivTermBlock);
@@ -626,34 +644,29 @@ template <typename ContextT> class DivergencePropagator {
         LLVM_DEBUG(dbgs() << "\tImmediate divergent cycle exit: "
                           << Context.print(SuccBlock) << "\n");
       }
-      auto SuccIdx = CyclePOT.getIndex(SuccBlock);
       visitEdge(*SuccBlock, *SuccBlock);
-      FloorIdx = std::min<int>(FloorIdx, SuccIdx);
     }
 
     while (true) {
       auto BlockIdx = FreshLabels.find_last();
-      if (BlockIdx == -1 || BlockIdx < FloorIdx)
+      if (BlockIdx == -1)
         break;
 
       LLVM_DEBUG(dbgs() << "Current labels:\n"; printDefs(dbgs()));
 
       FreshLabels.reset(BlockIdx);
-      if (BlockIdx == DivTermIdx) {
-        LLVM_DEBUG(dbgs() << "Skipping DivTermBlock\n");
+      const auto *Block = CyclePOT[BlockIdx];
+      if (Block == ImmPDom) {
+        LLVM_DEBUG(dbgs() << "Skipping ImmPDom\n");
         continue;
       }
 
-      const auto *Block = CyclePOT[BlockIdx];
       LLVM_DEBUG(dbgs() << "visiting " << Context.print(Block) << " at index "
                         << BlockIdx << "\n");
 
       const auto *Label = BlockLabels[Block];
       assert(Label);
 
-      bool CausedJoin = false;
-      int LoweredFloorIdx = FloorIdx;
-
       // If the current block is the header of a reducible cycle that
       // contains the divergent branch, then the label should be
       // propagated to the cycle exits. Such a header is the "last
@@ -681,28 +694,11 @@ template <typename ContextT> class DivergencePropagator {
       if (const auto *BlockCycle = getReducibleParent(Block)) {
         SmallVector<BlockT *, 4> BlockCycleExits;
         BlockCycle->getExitBlocks(BlockCycleExits);
-        for (auto *BlockCycleExit : BlockCycleExits) {
-          CausedJoin |= visitCycleExitEdge(*BlockCycleExit, *Label);
-          LoweredFloorIdx =
-              std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(BlockCycleExit));
-        }
+        for (auto *BlockCycleExit : BlockCycleExits)
+          visitCycleExitEdge(*BlockCycleExit, *Label);
       } else {
-        for (const auto *SuccBlock : successors(Block)) {
-          CausedJoin |= visitEdge(*SuccBlock, *Label);
-          LoweredFloorIdx =
-              std::min<int>(LoweredFloorIdx, CyclePOT.getIndex(SuccBlock));
-        }
-      }
-
-      // Floor update
-      if (CausedJoin) {
-        // 1. Different labels pushed to successors
-        FloorIdx = LoweredFloorIdx;
-      } else if (FloorLabel != Label) {
-        // 2. No join caused BUT we pushed a label that is different than the
-        // last pushed label
-        FloorIdx = LoweredFloorIdx;
-        FloorLabel = Label;
+        for (const auto *SuccBlock : successors(Block))
+          visitEdge(*SuccBlock, *Label);
       }
     }
 
@@ -742,8 +738,9 @@ typename llvm::GenericSyncDependenceAnalysis<ContextT>::DivergenceDescriptor
 
 template <typename ContextT>
 llvm::GenericSyncDependenceAnalysis<ContextT>::GenericSyncDependenceAnalysis(
-    const ContextT &Context, const DominatorTreeT &DT, const CycleInfoT &CI)
-    : CyclePO(Context), DT(DT), CI(CI) {
+    const ContextT &Context, const DominatorTreeT &DT,
+    const PostDominatorTreeT &PDT, const CycleInfoT &CI)
+    : CyclePO(Context), DT(DT), PDT(PDT), CI(CI) {
   CyclePO.compute(CI);
 }
 
@@ -761,7 +758,7 @@ auto llvm::GenericSyncDependenceAnalysis<ContextT>::getJoinBlocks(
     return *ItCached->second;
 
   // compute all join points
-  DivergencePropagatorT Propagator(CyclePO, DT, CI, *DivTermBlock);
+  DivergencePropagatorT Propagator(CyclePO, DT, PDT, CI, *DivTermBlock);
   auto DivDesc = Propagator.computeJoinPoints();
 
   auto printBlockSet = [&](ConstBlockSet &Blocks) {
@@ -1155,9 +1152,9 @@ bool GenericUniformityAnalysisImpl<ContextT>::isAlwaysUniform(
 
 template <typename ContextT>
 GenericUniformityInfo<ContextT>::GenericUniformityInfo(
-    const DominatorTreeT &DT, const CycleInfoT &CI,
-    const TargetTransformInfo *TTI) {
-  DA.reset(new ImplT{DT, CI, TTI});
+    const DominatorTreeT &DT, const PostDominatorTreeT &PDT,
+    const CycleInfoT &CI, const TargetTransformInfo *TTI) {
+  DA.reset(new ImplT{DT, PDT, CI, TTI});
 }
 
 template <typename ContextT>
diff --git a/llvm/include/llvm/ADT/GenericUniformityInfo.h b/llvm/include/llvm/ADT/GenericUniformityInfo.h
index 9376fa6ee0bae..62d35582823dc 100644
--- a/llvm/include/llvm/ADT/GenericUniformityInfo.h
+++ b/llvm/include/llvm/ADT/GenericUniformityInfo.h
@@ -35,6 +35,7 @@ template <typename ContextT> class GenericUniformityInfo {
   using UseT = typename ContextT::UseT;
   using InstructionT = typename ContextT::InstructionT;
   using DominatorTreeT = typename ContextT::DominatorTreeT;
+  using PostDominatorTreeT = typename ContextT::PostDominatorTreeT;
   using ThisT = GenericUniformityInfo<ContextT>;
 
   using CycleInfoT = GenericCycleInfo<ContextT>;
@@ -43,7 +44,8 @@ template <typename ContextT> class GenericUniformityInfo {
   using TemporalDivergenceTuple =
       std::tuple<ConstValueRefT, InstructionT *, const CycleT *>;
 
-  GenericUniformityInfo(const DominatorTreeT &DT, const CycleInfoT &CI,
+  GenericUniformityInfo(const DominatorTreeT &DT, const PostDominatorTreeT &PDT,
+                        const CycleInfoT &CI,
                         const TargetTransformInfo *TTI = nullptr);
   GenericUniformityInfo() = default;
   GenericUniformityInfo(GenericUniformityInfo &&) = default;
diff --git a/llvm/include/llvm/CodeGen/MachineUniformityAnalysis.h b/llvm/include/llvm/CodeGen/MachineUniformityAnalysis.h
index e8c0dc9b43823..03fc9ebfcf442 100644
--- a/llvm/include/llvm/CodeGen/MachineUniformityAnalysis.h
+++ b/llvm/include/llvm/CodeGen/MachineUniformityAnalysis.h
@@ -18,6 +18,7 @@
 #include "llvm/CodeGen/MachineCycleAnalysis.h"
 #include "llvm/CodeGen/MachineDominators.h"
 #include "llvm/CodeGen/MachinePassManager.h"
+#include "llvm/CodeGen/MachinePostDominators.h"
 #include "llvm/CodeGen/MachineSSAContext.h"
 
 namespace llvm {
@@ -31,7 +32,8 @@ using MachineUniformityInfo = GenericUniformityInfo<MachineSSAContext>;
 /// everything is uniform.
 MachineUniformityInfo computeMachineUniformityInfo(
     MachineFunction &F, const MachineCycleInfo &cycleInfo,
-    const MachineDominatorTree &domTree, bool HasBranchDivergence);
+    const MachineDominatorTree &domTree,
+    const MachinePostDominatorTree &pdomTree, bool HasBranchDivergence);
 
 /// Legacy analysis pass which computes a \ref MachineUniformityInfo.
 class MachineUniformityAnalysisPass : public MachineFunctionPass {
diff --git a/llvm/lib/Analysis/UniformityAnalysis.cpp b/llvm/lib/Analysis/UniformityAnalysis.cpp
index 2101fdfacfc8f..a724a8c26d7db 100644
--- a/llvm/lib/Analysis/UniformityAnalysis.cpp
+++ b/llvm/lib/Analysis/UniformityAnalysis.cpp
@@ -9,6 +9,7 @@
 #include "llvm/Analysis/UniformityAnalysis.h"
 #include "llvm/ADT/GenericUniformityImpl.h"
 #include "llvm/Analysis/CycleAnalysis.h"
+#include "llvm/Analysis/PostDominators.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/InstIterator.h"
@@ -114,9 +115,10 @@ template struct llvm::GenericUniformityAnalysisImplDeleter<
 llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F,
                                                  FunctionAnalysisManager &FAM) {
   auto &DT = FAM.getResult<DominatorTreeAnalysis>(F);
+  auto &PDT = FAM.getResult<PostDominatorTreeAnalysis>(F);
   auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
   auto &CI = FAM.getResult<CycleAnalysis>(F);
-  UniformityInfo UI{DT, CI, &TTI};
+  UniformityInfo UI{DT, PDT, CI, &TTI};
   // Skip computation if we can assume everything is uniform.
   if (TTI.hasBranchDivergence(&F))
     UI.compute();
@@ -148,6 +150,7 @@ UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) {}
 INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniformity",
                       "Uniformity Analysis", true, true)
 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(CycleInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
 INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity",
@@ -156,6 +159,7 @@ INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity",
 void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.setPreservesAll();
   AU.addRequired<DominatorTreeWrapperPass>();
+  AU.addRequired<PostDominatorTreeWrapperPass>();
   AU.addRequiredTransitive<CycleInfoWrapperPass>();
   AU.addRequired<TargetTransformInfoWrapperPass>();
 }
@@ -163,11 +167,13 @@ void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
 bool UniformityInfoWrapperPass::runOnFunction(Function &F) {
   auto &cycleInfo = getAnalysis<CycleInfoWrapperPass>().getResult();
   auto &domTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+  auto &pdomTree = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
   auto &targetTransformInfo =
       getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
 
   m_function = &F;
-  m_uniformityInfo = UniformityInfo{domTree, cycleInfo, &targetTransformInfo};
+  m_uniformityInfo =
+      UniformityInfo{domTree, pdomTree, cycleInfo, &targetTransformInfo};
 
   // Skip computation if we can assume everything is uniform.
   if (targetTransformInfo.hasBranchDivergence(m_function))
diff --git a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
index e4b82ce83fda6..b87f8357ecfa8 100644
--- a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
+++ b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
@@ -11,6 +11,7 @@
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/CodeGen/MachineCycleAnalysis.h"
 #include "llvm/CodeGen/MachineDominators.h"
+#include "llvm/CodeGen/MachinePostDominators.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/MachineSSAContext.h"
 #include "llvm/CodeGen/TargetInstrInfo.h"
@@ -156,9 +157,10 @@ template struct llvm::GenericUniformityAnalysisImplDeleter<
 
 MachineUniformityInfo llvm::computeMachineUniformityInfo(
     MachineFunction &F, const MachineCycleInfo &cycleInfo,
-    const MachineDominatorTree &domTree, bool HasBranchDivergence) {
+    const MachineDominatorTree &domTree,
+    const MachinePostDominatorTree &pdomTree, bool HasBranchDivergence) {
   assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!");
-  MachineUniformityInfo UI(domTree, cycleInfo);
+  MachineUniformityInfo UI(domTree, pdomTree, cycleInfo);
   if (HasBranchDivergence)
     UI.compute();
   return UI;
@@ -184,12 +186,13 @@ MachineUniformityAnalysis::Result
 MachineUniformityAnalysis::run(MachineFunction &MF,
                                MachineFunctionAnalysisManager &MFAM) {
   auto &DomTree = MFAM.getResult<MachineDominatorTreeAnalysis>(MF);
+  auto &PDomTree = MFAM.getResult<MachinePostDominatorTreeAnalysis>(MF);
   auto &CI = MFAM.getResult<MachineCycleAnalysis>(MF);
   auto &FAM = MFAM.getResult<FunctionAnalysisManagerMachineFunctionProxy>(MF)
                   .getManager();
   auto &F = MF.getFunction();
   auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
-  return computeMachineUniformityInfo(MF, CI, DomTree,
+  return computeMachineUniformityInfo(MF, CI, DomTree, PDomTree,
                                       TTI.hasBranchDivergence(&F));
 }
 
@@ -215,6 +218,7 @@ INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity",
                       "Machine Uniformity Info Analysis", false, true)
 INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTreeWrapperPass)
 INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity",
                     "Machine Uniformity Info Analysis", false, true)
 
@@ -222,15 +226,18 @@ void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.setPreservesAll();
   AU.addRequiredTransitive<MachineCycleInfoWrapperPass>();
   AU.addRequired<MachineDominatorTreeWrapperPass>();
+  AU.addRequired<MachinePostDominatorTreeWrapperPass>();
   MachineFunctionPass::getAnalysisUsage(AU);
 }
 
 bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) {
   auto &DomTree = getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
+  auto &PDomTree =
+      getAnalysis<MachinePostDominatorTreeWrapperPass>().getPostDomTree();
   auto &CI = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo();
   // FIXME: Query TTI::hasBranchDivergence. -run-pass seems to end up with a
   // default NoTTI
-  UI = computeMachineUniformityInfo(MF, CI, DomTree, true);
+  UI = computeMachineUniformityInfo(MF, CI, DomTree, PDomTree, true);
   return false;
 }
 
diff --git a/llvm/test/Analysis/UniformityAnalysis/AMDGPU/phi_div_branch.ll b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/phi_div_branch.ll
new file mode 100644
index 0000000000000..df949a86635c4
--- /dev/null
+++ b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/phi_div_branch.ll
@@ -0,0 +1,78 @@
+;
+; RUN: opt -mtriple amdgcn-- -passes='print<uniformity>' -disable-output %s 2>&1 | FileCheck %s
+;
+; This is to test an if-then-else case with some unmerged basic blocks
+; (https://github.com/llvm/llvm-project/issues/137277)
+;
+;      Entry (div.cond)
+;      /   \
+;     B0   B3
+;     |    |
+;     B1   B4
+;     |    |
+;     B2   B5
+;      \  /
+;       B6 (phi: divergent)
+;
+
+
+; CHECK-LABEL:  'test_ctrl_divergence':
+; CHECK-LABEL:  BLOCK Entry
+; CHECK:  DIVERGENT:   %div.cond = icmp eq i32 %tid, 0
+; CHECK:  DIVERGENT:   br i1 %div.cond, label %B3, label %B0
+;
+; CHECK-LABEL:  BLOCK B6
+; CHECK:  DIVERGENT:   %div_a = phi i32 [ %a0, %B2 ], [ %a1, %B5 ]
+; CHECK:  DIVERGENT:   %div_b = phi i32 [ %b0, %B2 ], [ %b1, %B5 ]
+; CHECK:  DIVERGENT:   %div_c = phi i32 [ %c0, %B2 ], [ %c1, %B5 ]
+
+
+define amdgpu_kernel void @test_ctrl_divergence(i32 %a, i32 %b, i32 %c, i32 %d) {
+Entry:
+  %tid = call i32 @llvm.amdgcn.workitem.id.x()
+  %div.cond = icmp eq i32 %tid, 0
+  br i1 %div.cond, label %B3, label %B0 ; divergent branch
+
+B0:
+  %a0 = add i32 %a, 1
+  br label %B1
+
+B1:
+  %...
[truncated]
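
The propagation loop in the patch can be hard to follow in diff form, so the following is a simplified, acyclic-only sketch of the idea: labels are pushed from the divergent branch's successors, a block that receives two different labels is recorded as a join, and propagation stops at the immediate post-dominator. This is not the DivergencePropagator implementation; findJoins, the integer block numbering, and the precomputed topological order are assumptions made for illustration, and cycle handling is omitted entirely.

```cpp
#include <cassert>
#include <set>
#include <vector>

// Simplified, acyclic-only model of join-point discovery. All names are
// hypothetical; the real DivergencePropagator also handles cycles and
// cycle exits.
std::set<int> findJoins(const std::vector<std::vector<int>> &Succs,
                        const std::vector<int> &TopoOrder, int DivTerm,
                        int ImmPDom) {
  std::vector<int> Label(Succs.size(), -1); // -1: no label yet
  std::set<int> Joins;

  // Push label Pushed into Succ; a second, different label makes Succ a join.
  auto visitEdge = [&](int Succ, int Pushed) {
    if (Label[Succ] == -1) {
      Label[Succ] = Pushed;
    } else if (Label[Succ] != Pushed) {
      Joins.insert(Succ);
      Label[Succ] = Succ; // a join carries its own label from here on
    }
  };

  // Bootstrap: each branch target starts with itself as its label.
  for (int S : Succs[DivTerm])
    visitEdge(S, S);

  for (int B : TopoOrder) {
    if (Label[B] == -1)
      continue; // no label reached this block from the divergent branch
    if (B == ImmPDom)
      continue; // last join: do not propagate past the reconvergence point
    for (int S : Succs[B])
      visitEdge(S, Label[B]);
  }
  return Joins;
}
```

On the diamond from phi_div_branch.ll (Entry=0, B0..B2=1..3, B3..B5=4..6, B6=7), with B6 as the immediate post-dominator of Entry, the only join found is B6, and any blocks after B6 are never labeled.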

@llvmbot
Member

llvmbot commented May 15, 2025

@llvm/pr-subscribers-llvm-analysis

Author: Junjie Gu (jgu222)

@@ -1155,9 +1152,9 @@ bool GenericUniformityAnalysisImpl<ContextT>::isAlwaysUniform(
 
 template <typename ContextT>
 GenericUniformityInfo<ContextT>::GenericUniformityInfo(
-    const DominatorTreeT &DT, const CycleInfoT &CI,
-    const TargetTransformInfo *TTI) {
-  DA.reset(new ImplT{DT, CI, TTI});
+    const DominatorTreeT &DT, const PostDominatorTreeT &PDT,
+    const CycleInfoT &CI, const TargetTransformInfo *TTI) {
+  DA.reset(new ImplT{DT, PDT, CI, TTI});
 }
 
 template <typename ContextT>
diff --git a/llvm/include/llvm/ADT/GenericUniformityInfo.h b/llvm/include/llvm/ADT/GenericUniformityInfo.h
index 9376fa6ee0bae..62d35582823dc 100644
--- a/llvm/include/llvm/ADT/GenericUniformityInfo.h
+++ b/llvm/include/llvm/ADT/GenericUniformityInfo.h
@@ -35,6 +35,7 @@ template <typename ContextT> class GenericUniformityInfo {
   using UseT = typename ContextT::UseT;
   using InstructionT = typename ContextT::InstructionT;
   using DominatorTreeT = typename ContextT::DominatorTreeT;
+  using PostDominatorTreeT = typename ContextT::PostDominatorTreeT;
   using ThisT = GenericUniformityInfo<ContextT>;
 
   using CycleInfoT = GenericCycleInfo<ContextT>;
@@ -43,7 +44,8 @@ template <typename ContextT> class GenericUniformityInfo {
   using TemporalDivergenceTuple =
       std::tuple<ConstValueRefT, InstructionT *, const CycleT *>;
 
-  GenericUniformityInfo(const DominatorTreeT &DT, const CycleInfoT &CI,
+  GenericUniformityInfo(const DominatorTreeT &DT, const PostDominatorTreeT &PDT,
+                        const CycleInfoT &CI,
                         const TargetTransformInfo *TTI = nullptr);
   GenericUniformityInfo() = default;
   GenericUniformityInfo(GenericUniformityInfo &&) = default;
diff --git a/llvm/include/llvm/CodeGen/MachineUniformityAnalysis.h b/llvm/include/llvm/CodeGen/MachineUniformityAnalysis.h
index e8c0dc9b43823..03fc9ebfcf442 100644
--- a/llvm/include/llvm/CodeGen/MachineUniformityAnalysis.h
+++ b/llvm/include/llvm/CodeGen/MachineUniformityAnalysis.h
@@ -18,6 +18,7 @@
 #include "llvm/CodeGen/MachineCycleAnalysis.h"
 #include "llvm/CodeGen/MachineDominators.h"
 #include "llvm/CodeGen/MachinePassManager.h"
+#include "llvm/CodeGen/MachinePostDominators.h"
 #include "llvm/CodeGen/MachineSSAContext.h"
 
 namespace llvm {
@@ -31,7 +32,8 @@ using MachineUniformityInfo = GenericUniformityInfo<MachineSSAContext>;
 /// everything is uniform.
 MachineUniformityInfo computeMachineUniformityInfo(
     MachineFunction &F, const MachineCycleInfo &cycleInfo,
-    const MachineDominatorTree &domTree, bool HasBranchDivergence);
+    const MachineDominatorTree &domTree,
+    const MachinePostDominatorTree &pdomTree, bool HasBranchDivergence);
 
 /// Legacy analysis pass which computes a \ref MachineUniformityInfo.
 class MachineUniformityAnalysisPass : public MachineFunctionPass {
diff --git a/llvm/lib/Analysis/UniformityAnalysis.cpp b/llvm/lib/Analysis/UniformityAnalysis.cpp
index 2101fdfacfc8f..a724a8c26d7db 100644
--- a/llvm/lib/Analysis/UniformityAnalysis.cpp
+++ b/llvm/lib/Analysis/UniformityAnalysis.cpp
@@ -9,6 +9,7 @@
 #include "llvm/Analysis/UniformityAnalysis.h"
 #include "llvm/ADT/GenericUniformityImpl.h"
 #include "llvm/Analysis/CycleAnalysis.h"
+#include "llvm/Analysis/PostDominators.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/InstIterator.h"
@@ -114,9 +115,10 @@ template struct llvm::GenericUniformityAnalysisImplDeleter<
 llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F,
                                                  FunctionAnalysisManager &FAM) {
   auto &DT = FAM.getResult<DominatorTreeAnalysis>(F);
+  auto &PDT = FAM.getResult<PostDominatorTreeAnalysis>(F);
   auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
   auto &CI = FAM.getResult<CycleAnalysis>(F);
-  UniformityInfo UI{DT, CI, &TTI};
+  UniformityInfo UI{DT, PDT, CI, &TTI};
   // Skip computation if we can assume everything is uniform.
   if (TTI.hasBranchDivergence(&F))
     UI.compute();
@@ -148,6 +150,7 @@ UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) {}
 INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniformity",
                       "Uniformity Analysis", true, true)
 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(CycleInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
 INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity",
@@ -156,6 +159,7 @@ INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity",
 void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.setPreservesAll();
   AU.addRequired<DominatorTreeWrapperPass>();
+  AU.addRequired<PostDominatorTreeWrapperPass>();
   AU.addRequiredTransitive<CycleInfoWrapperPass>();
   AU.addRequired<TargetTransformInfoWrapperPass>();
 }
@@ -163,11 +167,13 @@ void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
 bool UniformityInfoWrapperPass::runOnFunction(Function &F) {
   auto &cycleInfo = getAnalysis<CycleInfoWrapperPass>().getResult();
   auto &domTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+  auto &pdomTree = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
   auto &targetTransformInfo =
       getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
 
   m_function = &F;
-  m_uniformityInfo = UniformityInfo{domTree, cycleInfo, &targetTransformInfo};
+  m_uniformityInfo =
+      UniformityInfo{domTree, pdomTree, cycleInfo, &targetTransformInfo};
 
   // Skip computation if we can assume everything is uniform.
   if (targetTransformInfo.hasBranchDivergence(m_function))
diff --git a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
index e4b82ce83fda6..b87f8357ecfa8 100644
--- a/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
+++ b/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
@@ -11,6 +11,7 @@
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/CodeGen/MachineCycleAnalysis.h"
 #include "llvm/CodeGen/MachineDominators.h"
+#include "llvm/CodeGen/MachinePostDominators.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/MachineSSAContext.h"
 #include "llvm/CodeGen/TargetInstrInfo.h"
@@ -156,9 +157,10 @@ template struct llvm::GenericUniformityAnalysisImplDeleter<
 
 MachineUniformityInfo llvm::computeMachineUniformityInfo(
     MachineFunction &F, const MachineCycleInfo &cycleInfo,
-    const MachineDominatorTree &domTree, bool HasBranchDivergence) {
+    const MachineDominatorTree &domTree,
+    const MachinePostDominatorTree &pdomTree, bool HasBranchDivergence) {
   assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!");
-  MachineUniformityInfo UI(domTree, cycleInfo);
+  MachineUniformityInfo UI(domTree, pdomTree, cycleInfo);
   if (HasBranchDivergence)
     UI.compute();
   return UI;
@@ -184,12 +186,13 @@ MachineUniformityAnalysis::Result
 MachineUniformityAnalysis::run(MachineFunction &MF,
                                MachineFunctionAnalysisManager &MFAM) {
   auto &DomTree = MFAM.getResult<MachineDominatorTreeAnalysis>(MF);
+  auto &PDomTree = MFAM.getResult<MachinePostDominatorTreeAnalysis>(MF);
   auto &CI = MFAM.getResult<MachineCycleAnalysis>(MF);
   auto &FAM = MFAM.getResult<FunctionAnalysisManagerMachineFunctionProxy>(MF)
                   .getManager();
   auto &F = MF.getFunction();
   auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
-  return computeMachineUniformityInfo(MF, CI, DomTree,
+  return computeMachineUniformityInfo(MF, CI, DomTree, PDomTree,
                                       TTI.hasBranchDivergence(&F));
 }
 
@@ -215,6 +218,7 @@ INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity",
                       "Machine Uniformity Info Analysis", false, true)
 INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTreeWrapperPass)
 INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity",
                     "Machine Uniformity Info Analysis", false, true)
 
@@ -222,15 +226,18 @@ void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.setPreservesAll();
   AU.addRequiredTransitive<MachineCycleInfoWrapperPass>();
   AU.addRequired<MachineDominatorTreeWrapperPass>();
+  AU.addRequired<MachinePostDominatorTreeWrapperPass>();
   MachineFunctionPass::getAnalysisUsage(AU);
 }
 
 bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) {
   auto &DomTree = getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
+  auto &PDomTree =
+      getAnalysis<MachinePostDominatorTreeWrapperPass>().getPostDomTree();
   auto &CI = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo();
   // FIXME: Query TTI::hasBranchDivergence. -run-pass seems to end up with a
   // default NoTTI
-  UI = computeMachineUniformityInfo(MF, CI, DomTree, true);
+  UI = computeMachineUniformityInfo(MF, CI, DomTree, PDomTree, true);
   return false;
 }
 
diff --git a/llvm/test/Analysis/UniformityAnalysis/AMDGPU/phi_div_branch.ll b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/phi_div_branch.ll
new file mode 100644
index 0000000000000..df949a86635c4
--- /dev/null
+++ b/llvm/test/Analysis/UniformityAnalysis/AMDGPU/phi_div_branch.ll
@@ -0,0 +1,78 @@
+;
+; RUN: opt -mtriple amdgcn-- -passes='print<uniformity>' -disable-output %s 2>&1 | FileCheck %s
+;
+; This is to test an if-then-else case with some unmerged basic blocks
+; (https://github.com/llvm/llvm-project/issues/137277)
+;
+;      Entry (div.cond)
+;      /   \
+;     B0   B3
+;     |    |
+;     B1   B4
+;     |    |
+;     B2   B5
+;      \  /
+;       B6 (phi: divergent)
+;
+
+
+; CHECK-LABEL:  'test_ctrl_divergence':
+; CHECK-LABEL:  BLOCK Entry
+; CHECK:  DIVERGENT:   %div.cond = icmp eq i32 %tid, 0
+; CHECK:  DIVERGENT:   br i1 %div.cond, label %B3, label %B0
+;
+; CHECK-LABEL:  BLOCK B6
+; CHECK:  DIVERGENT:   %div_a = phi i32 [ %a0, %B2 ], [ %a1, %B5 ]
+; CHECK:  DIVERGENT:   %div_b = phi i32 [ %b0, %B2 ], [ %b1, %B5 ]
+; CHECK:  DIVERGENT:   %div_c = phi i32 [ %c0, %B2 ], [ %c1, %B5 ]
+
+
+define amdgpu_kernel void @test_ctrl_divergence(i32 %a, i32 %b, i32 %c, i32 %d) {
+Entry:
+  %tid = call i32 @llvm.amdgcn.workitem.id.x()
+  %div.cond = icmp eq i32 %tid, 0
+  br i1 %div.cond, label %B3, label %B0 ; divergent branch
+
+B0:
+  %a0 = add i32 %a, 1
+  br label %B1
+
+B1:
+  %...
[truncated]
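
For readers following the test's CFG comment, here is a minimal standalone sketch of what "immediate post-dominator of the divergent block" means on that diamond. It does not use LLVM's post-dominator tree or the patch's `getIPDom`. The names `ToyCFG`, `buildDiamond`, `computePostDoms`, and `immediatePostDom` are hypothetical and exist only for illustration. The block numbering is likewise an assumption.

```cpp
#include <cassert>
#include <vector>

// Hypothetical toy CFG: Succs[B] lists the successors of block B.
using ToyCFG = std::vector<std::vector<int>>;

// Illustrative numbering of the blocks drawn in the test comment.
constexpr int Entry = 0, B0 = 1, B1 = 2, B2 = 3, B3 = 4, B4 = 5, B5 = 6,
              B6 = 7, NumBlocks = 8;

// Build the if-then-else diamond: Entry -> {B0, B3}, both chains end in B6.
ToyCFG buildDiamond() {
  ToyCFG Succs(NumBlocks);
  Succs[Entry] = {B0, B3};
  Succs[B0] = {B1};
  Succs[B1] = {B2};
  Succs[B2] = {B6};
  Succs[B3] = {B4};
  Succs[B4] = {B5};
  Succs[B5] = {B6};
  return Succs; // B6 has no successors, so it is the unique exit.
}

// PDom[B][X] is true when X post-dominates B. Naive iterative dataflow:
// PDom(B) = {B} united with the intersection of PDom(S) over successors S.
std::vector<std::vector<bool>> computePostDoms(const ToyCFG &Succs) {
  const int N = static_cast<int>(Succs.size());
  std::vector<std::vector<bool>> PDom(N, std::vector<bool>(N, true));
  for (int B = 0; B < N; ++B) {
    if (Succs[B].empty()) {
      PDom[B].assign(N, false); // An exit is post-dominated only by itself.
      PDom[B][B] = true;
    }
  }
  bool Changed = true;
  while (Changed) {
    Changed = false;
    for (int B = 0; B < N; ++B) {
      if (Succs[B].empty())
        continue;
      std::vector<bool> New(N, true);
      for (int S : Succs[B])
        for (int X = 0; X < N; ++X)
          New[X] = New[X] && PDom[S][X];
      New[B] = true;
      if (New != PDom[B]) {
        PDom[B] = New;
        Changed = true;
      }
    }
  }
  return PDom;
}

// Return the strict post-dominator of B that every other strict
// post-dominator of B also post-dominates, or -1 if B has none.
int immediatePostDom(const std::vector<std::vector<bool>> &PDom, int B) {
  const int N = static_cast<int>(PDom.size());
  for (int C = 0; C < N; ++C) {
    if (C == B || !PDom[B][C])
      continue; // C is not a strict post-dominator of B.
    bool Closest = true;
    for (int D = 0; D < N; ++D)
      if (D != B && PDom[B][D] && !PDom[C][D])
        Closest = false;
    if (Closest)
      return C;
  }
  return -1;
}
```

Under these assumptions, `immediatePostDom(computePostDoms(buildDiamond()), Entry)` yields `B6`, the block where the test expects the phis to be reported as divergent.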

@@ -0,0 +1,78 @@
;

Suggested change
;

@@ -0,0 +1,82 @@
;

Suggested change
;


declare i32 @llvm.amdgcn.workitem.id.x() #0

attributes #0 = {nounwind readnone }

These attributes are out of date for the intrinsic, but you can just drop them.

@ssahasra
Collaborator

@arsenm the potential solution is currently being discussed at #139667. There's a good chance that we will discard this PR.

@jgu222
Contributor Author

jgu222 commented May 22, 2025

Thanks for the comments. I will try to come up with a modified patch early next week.
