Conversation

AlexMaclean (Member)

Add custom lowering for BR_JT DAG nodes to the brx.idx PTX instruction (PTX ISA 9.7.13.4. Control Flow Instructions: brx.idx, https://docs.nvidia.com/cuda/parallel-thread-execution/#control-flow-instructions-brx-idx).
Depending on the heuristics in DAG selection, switch statements may now be lowered using brx.idx.

Note: this fixes the issue previously seen in #102400 by adding the isBarrier attribute to BRX_END.
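
For a concrete picture of the kind of control flow this affects, here is a minimal CUDA sketch (the kernel and its cases are illustrative, not taken from the patch): a dense, contiguous switch like this is a candidate for SelectionDAG's jump-table lowering, which on NVPTX can now be emitted as a .branchtargets table consumed by brx.idx instead of a chain of compares.

    // Illustrative only: a dense switch that SelectionDAG's jump-table
    // heuristics (case count, density) may now lower to brx.idx on NVPTX.
    __global__ void classify(const int *in, int *out, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i >= n)
        return;
      switch (in[i]) {        // contiguous cases 0..4 favor a jump table
      case 0: out[i] = 10; break;
      case 1: out[i] = 20; break;
      case 2: out[i] = 30; break;
      case 3: out[i] = 40; break;
      case 4: out[i] = 50; break;
      default: out[i] = -1; break;
      }
    }

Whether a particular switch actually takes this path still depends on the target-independent jump-table heuristics; small or sparse switches keep the existing compare-and-branch lowering.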

@AlexMaclean AlexMaclean requested a review from Artem-B August 8, 2024 23:20
@AlexMaclean AlexMaclean self-assigned this Aug 8, 2024
@llvmbot llvmbot added the backend:NVPTX and llvm:SelectionDAG labels Aug 8, 2024
llvmbot (Member) commented Aug 8, 2024

@llvm/pr-subscribers-llvm-selectiondag

Author: Alex MacLean (AlexMaclean)

Changes

Add custom lowering for BR_JT DAG nodes to the brx.idx PTX instruction (PTX ISA 9.7.13.4. Control Flow Instructions: brx.idx).
Depending on the heuristics in DAG selection, switch statements may now be lowered using brx.idx.

Note: this fixes the issue previously seen in #102400 by adding the isBarrier attribute to BRX_END.


Full diff: https://github.com/llvm/llvm-project/pull/102550.diff

6 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+4)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+6-5)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+42-3)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+10)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+40)
  • (added) llvm/test/CodeGen/NVPTX/jump-table.ll (+69)
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 9ccdbab008aec8..5b2214fa66c40b 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3843,6 +3843,10 @@ class TargetLowering : public TargetLoweringBase {
   /// returned value is a member of the MachineJumpTableInfo::JTEntryKind enum.
   virtual unsigned getJumpTableEncoding() const;
 
+  virtual MVT getJumpTableRegTy(const DataLayout &DL) const {
+    return getPointerTy(DL);
+  }
+
   virtual const MCExpr *
   LowerCustomJumpTableEntry(const MachineJumpTableInfo * /*MJTI*/,
                             const MachineBasicBlock * /*MBB*/, unsigned /*uid*/,
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 1f4436fb3a4966..37ba62911ec70b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -2977,7 +2977,7 @@ void SelectionDAGBuilder::visitJumpTable(SwitchCG::JumpTable &JT) {
   // Emit the code for the jump table
   assert(JT.SL && "Should set SDLoc for SelectionDAG!");
   assert(JT.Reg != -1U && "Should lower JT Header first!");
-  EVT PTy = DAG.getTargetLoweringInfo().getPointerTy(DAG.getDataLayout());
+  EVT PTy = DAG.getTargetLoweringInfo().getJumpTableRegTy(DAG.getDataLayout());
   SDValue Index = DAG.getCopyFromReg(getControlRoot(), *JT.SL, JT.Reg, PTy);
   SDValue Table = DAG.getJumpTable(JT.JTI, PTy);
   SDValue BrJumpTable = DAG.getNode(ISD::BR_JT, *JT.SL, MVT::Other,
@@ -3005,12 +3005,13 @@ void SelectionDAGBuilder::visitJumpTableHeader(SwitchCG::JumpTable &JT,
   // This value may be smaller or larger than the target's pointer type, and
   // therefore require extension or truncating.
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-  SwitchOp = DAG.getZExtOrTrunc(Sub, dl, TLI.getPointerTy(DAG.getDataLayout()));
+  SwitchOp =
+      DAG.getZExtOrTrunc(Sub, dl, TLI.getJumpTableRegTy(DAG.getDataLayout()));
 
   unsigned JumpTableReg =
-      FuncInfo.CreateReg(TLI.getPointerTy(DAG.getDataLayout()));
-  SDValue CopyTo = DAG.getCopyToReg(getControlRoot(), dl,
-                                    JumpTableReg, SwitchOp);
+      FuncInfo.CreateReg(TLI.getJumpTableRegTy(DAG.getDataLayout()));
+  SDValue CopyTo =
+      DAG.getCopyToReg(getControlRoot(), dl, JumpTableReg, SwitchOp);
   JT.Reg = JumpTableReg;
 
   if (!JTH.FallthroughUnreachable) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 516fc7339a4bf3..bf647c88f00e28 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -25,6 +25,7 @@
 #include "llvm/CodeGen/Analysis.h"
 #include "llvm/CodeGen/ISDOpcodes.h"
 #include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineJumpTableInfo.h"
 #include "llvm/CodeGen/MachineMemOperand.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/CodeGen/SelectionDAGNodes.h"
@@ -582,9 +583,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction(ISD::ROTR, MVT::i8, Expand);
   setOperationAction(ISD::BSWAP, MVT::i16, Expand);
 
-  // Indirect branch is not supported.
-  // This also disables Jump Table creation.
-  setOperationAction(ISD::BR_JT, MVT::Other, Expand);
+  setOperationAction(ISD::BR_JT, MVT::Other, Custom);
   setOperationAction(ISD::BRIND, MVT::Other, Expand);
 
   setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
@@ -945,6 +944,9 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::Dummy)
     MAKE_CASE(NVPTXISD::MUL_WIDE_SIGNED)
     MAKE_CASE(NVPTXISD::MUL_WIDE_UNSIGNED)
+    MAKE_CASE(NVPTXISD::BrxEnd)
+    MAKE_CASE(NVPTXISD::BrxItem)
+    MAKE_CASE(NVPTXISD::BrxStart)
     MAKE_CASE(NVPTXISD::Tex1DFloatS32)
     MAKE_CASE(NVPTXISD::Tex1DFloatFloat)
     MAKE_CASE(NVPTXISD::Tex1DFloatFloatLevel)
@@ -2785,6 +2787,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerFP_ROUND(Op, DAG);
   case ISD::FP_EXTEND:
     return LowerFP_EXTEND(Op, DAG);
+  case ISD::BR_JT:
+    return LowerBR_JT(Op, DAG);
   case ISD::VAARG:
     return LowerVAARG(Op, DAG);
   case ISD::VASTART:
@@ -2810,6 +2814,41 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   }
 }
 
+SDValue NVPTXTargetLowering::LowerBR_JT(SDValue Op, SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  SDValue Chain = Op.getOperand(0);
+  const auto *JT = cast<JumpTableSDNode>(Op.getOperand(1));
+  SDValue Index = Op.getOperand(2);
+
+  unsigned JId = JT->getIndex();
+  MachineJumpTableInfo *MJTI = DAG.getMachineFunction().getJumpTableInfo();
+  ArrayRef<MachineBasicBlock *> MBBs = MJTI->getJumpTables()[JId].MBBs;
+
+  SDValue IdV = DAG.getConstant(JId, DL, MVT::i32);
+
+  // Generate BrxStart node
+  SDVTList VTs = DAG.getVTList(MVT::Other, MVT::Glue);
+  Chain = DAG.getNode(NVPTXISD::BrxStart, DL, VTs, Chain, IdV);
+
+  // Generate BrxItem nodes
+  assert(!MBBs.empty());
+  for (MachineBasicBlock *MBB : MBBs.drop_back())
+    Chain = DAG.getNode(NVPTXISD::BrxItem, DL, VTs, Chain.getValue(0),
+                        DAG.getBasicBlock(MBB), Chain.getValue(1));
+
+  // Generate BrxEnd nodes
+  SDValue EndOps[] = {Chain.getValue(0), DAG.getBasicBlock(MBBs.back()), Index,
+                      IdV, Chain.getValue(1)};
+  SDValue BrxEnd = DAG.getNode(NVPTXISD::BrxEnd, DL, VTs, EndOps);
+
+  return BrxEnd;
+}
+
+// This will prevent AsmPrinter from trying to print the jump tables itself.
+unsigned NVPTXTargetLowering::getJumpTableEncoding() const {
+  return MachineJumpTableInfo::EK_Inline;
+}
+
 // This function is almost a copy of SelectionDAG::expandVAArg().
 // The only diff is that this one produces loads from local address space.
 SDValue NVPTXTargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 63262961b363ed..32e6b044b0de1f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -62,6 +62,9 @@ enum NodeType : unsigned {
   BFI,
   PRMT,
   DYNAMIC_STACKALLOC,
+  BrxStart,
+  BrxItem,
+  BrxEnd,
   Dummy,
 
   LoadV2 = ISD::FIRST_TARGET_MEMORY_OPCODE,
@@ -580,6 +583,11 @@ class NVPTXTargetLowering : public TargetLowering {
     return true;
   }
 
+  // The default is the same as pointer type, but brx.idx only accepts i32
+  MVT getJumpTableRegTy(const DataLayout &) const override { return MVT::i32; }
+
+  unsigned getJumpTableEncoding() const override;
+
   bool enableAggressiveFMAFusion(EVT VT) const override { return true; }
 
   // The default is to transform llvm.ctlz(x, false) (where false indicates that
@@ -637,6 +645,8 @@ class NVPTXTargetLowering : public TargetLowering {
 
   SDValue LowerSelect(SDValue Op, SelectionDAG &DAG) const;
 
+  SDValue LowerBR_JT(SDValue Op, SelectionDAG &DAG) const;
+
   SDValue LowerVAARG(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVASTART(SDValue Op, SelectionDAG &DAG) const;
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 6a096fa5acea7c..d75dc8781f7802 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -3880,6 +3880,46 @@ def DYNAMIC_STACKALLOC64 :
             [(set Int64Regs:$ptr, (dyn_alloca Int64Regs:$size, (i32 timm:$align)))]>,
             Requires<[hasPTX<73>, hasSM<52>]>;
 
+
+//
+// BRX
+//
+
+def SDTBrxStartProfile : SDTypeProfile<0, 1, [SDTCisInt<0>]>;
+def SDTBrxItemProfile : SDTypeProfile<0, 1, [SDTCisVT<0, OtherVT>]>;
+def SDTBrxEndProfile : SDTypeProfile<0, 3, [SDTCisVT<0, OtherVT>, SDTCisInt<1>, SDTCisInt<2>]>;
+
+def brx_start :
+  SDNode<"NVPTXISD::BrxStart", SDTBrxStartProfile,
+         [SDNPHasChain, SDNPOutGlue, SDNPSideEffect]>;
+def brx_item :
+  SDNode<"NVPTXISD::BrxItem", SDTBrxItemProfile,
+         [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
+def brx_end :
+  SDNode<"NVPTXISD::BrxEnd", SDTBrxEndProfile,
+         [SDNPHasChain, SDNPInGlue, SDNPSideEffect]>;
+
+let isTerminator = 1, isBranch = 1, isIndirectBranch = 1, isNotDuplicable = 1 in {
+
+  def BRX_START :
+    NVPTXInst<(outs), (ins i32imm:$id),
+              "$$L_brx_$id: .branchtargets",
+              [(brx_start (i32 imm:$id))]>;
+
+  def BRX_ITEM :
+    NVPTXInst<(outs), (ins brtarget:$target),
+              "\t$target,",
+              [(brx_item bb:$target)]>;
+
+  def BRX_END :
+    NVPTXInst<(outs), (ins brtarget:$target, Int32Regs:$val, i32imm:$id),
+              "\t$target;\n\tbrx.idx \t$val, $$L_brx_$id;",
+              [(brx_end bb:$target, (i32 Int32Regs:$val), (i32 imm:$id))]> {
+      let isBarrier = 1;
+    }
+}
+
+
 include "NVPTXIntrinsics.td"
 
 //-----------------------------------
diff --git a/llvm/test/CodeGen/NVPTX/jump-table.ll b/llvm/test/CodeGen/NVPTX/jump-table.ll
new file mode 100644
index 00000000000000..867e171a5840ae
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/jump-table.ll
@@ -0,0 +1,69 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s | FileCheck %s
+; RUN: %if ptxas %{ llc < %s | %ptxas-verify %}
+
+target triple = "nvptx64-nvidia-cuda"
+
+@out = addrspace(1) global i32 0, align 4
+
+define void @foo(i32 %i) {
+; CHECK-LABEL: foo(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<2>;
+; CHECK-NEXT:    .reg .b32 %r<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.u32 %r2, [foo_param_0];
+; CHECK-NEXT:    setp.gt.u32 %p1, %r2, 3;
+; CHECK-NEXT:    @%p1 bra $L__BB0_6;
+; CHECK-NEXT:  // %bb.1: // %entry
+; CHECK-NEXT:    $L_brx_0: .branchtargets
+; CHECK-NEXT:     $L__BB0_2,
+; CHECK-NEXT:     $L__BB0_3,
+; CHECK-NEXT:     $L__BB0_4,
+; CHECK-NEXT:     $L__BB0_5;
+; CHECK-NEXT:    brx.idx %r2, $L_brx_0;
+; CHECK-NEXT:  $L__BB0_2: // %case0
+; CHECK-NEXT:    mov.b32 %r6, 0;
+; CHECK-NEXT:    st.global.u32 [out], %r6;
+; CHECK-NEXT:    bra.uni $L__BB0_6;
+; CHECK-NEXT:  $L__BB0_4: // %case2
+; CHECK-NEXT:    mov.b32 %r4, 2;
+; CHECK-NEXT:    st.global.u32 [out], %r4;
+; CHECK-NEXT:    bra.uni $L__BB0_6;
+; CHECK-NEXT:  $L__BB0_5: // %case3
+; CHECK-NEXT:    mov.b32 %r3, 3;
+; CHECK-NEXT:    st.global.u32 [out], %r3;
+; CHECK-NEXT:    bra.uni $L__BB0_6;
+; CHECK-NEXT:  $L__BB0_3: // %case1
+; CHECK-NEXT:    mov.b32 %r5, 1;
+; CHECK-NEXT:    st.global.u32 [out], %r5;
+; CHECK-NEXT:  $L__BB0_6: // %end
+; CHECK-NEXT:    ret;
+entry:
+  switch i32 %i, label %end [
+    i32 0, label %case0
+    i32 1, label %case1
+    i32 2, label %case2
+    i32 3, label %case3
+  ]
+
+case0:
+  store i32 0, ptr addrspace(1) @out, align 4
+  br label %end
+
+case1:
+  store i32 1, ptr addrspace(1) @out, align 4
+  br label %end
+
+case2:
+  store i32 2, ptr addrspace(1) @out, align 4
+  br label %end
+
+case3:
+  store i32 3, ptr addrspace(1) @out, align 4
+  br label %end
+
+end:
+  ret void
+}

llvmbot (Member) commented Aug 8, 2024

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

(Same changes summary and full diff as in the comment above.)

@AlexMaclean AlexMaclean merged commit ccc3127 into llvm:main Aug 9, 2024
8 checks passed
kutemeikito added a commit to kutemeikito/llvm-project that referenced this pull request Aug 10, 2024
* 'main' of https://github.com/llvm/llvm-project: (700 commits)
  [SandboxIR][NFC] SingleLLVMInstructionImpl class (llvm#102687)
  [ThinLTO]Clean up 'import-assume-unique-local' flag. (llvm#102424)
  [nsan] Make #include more conventional
  [SandboxIR][NFC] Use Tracker.emplaceIfTracking()
  [libc]  Moved range_reduction_double ifdef statement (llvm#102659)
  [libc] Fix CFP long double and add tests (llvm#102660)
  [TargetLowering] Handle vector types in expandFixedPointMul (llvm#102635)
  [compiler-rt][NFC] Replace environment variable with %t (llvm#102197)
  [UnitTests] Convert a test to use opaque pointers (llvm#102668)
  [CodeGen][NFCI] Don't re-implement parts of ASTContext::getIntWidth (llvm#101765)
  [SandboxIR] Clean up tracking code with the help of emplaceIfTracking() (llvm#102406)
  [mlir][bazel] remove extra blanks in mlir-tblgen test
  [NVPTX][NFC] Update tests to use bfloat type (llvm#101493)
  [mlir] Add support for parsing nested PassPipelineOptions (llvm#101118)
  [mlir][bazel] add missing td dependency in mlir-tblgen test
  [flang][cuda] Fix lib dependency
  [libc] Clean up remaining use of *_WIDTH macros in printf (llvm#102679)
  [flang][cuda] Convert cuf.alloc for box to fir.alloca in device context (llvm#102662)
  [SandboxIR] Implement the InsertElementInst class (llvm#102404)
  [libc] Fix use of cpp::numeric_limits<...>::digits (llvm#102674)
  [mlir][ODS] Verify type constraints in Types and Attributes (llvm#102326)
  [LTO] enable `ObjCARCContractPass` only on optimized build  (llvm#101114)
  [mlir][ODS] Consistent `cppType` / `cppClassName` usage (llvm#102657)
  [lldb] Move definition of SBSaveCoreOptions dtor out of header (llvm#102539)
  [libc] Use cpp::numeric_limits in preference to C23 <limits.h> macros (llvm#102665)
  [clang] Implement -fptrauth-auth-traps. (llvm#102417)
  [LLVM][rtsan] rtsan transform to preserve CFGAnalyses (llvm#102651)
  Revert "[AMDGPU] Move `AMDGPUAttributorPass` to full LTO post link stage (llvm#102086)"
  [RISCV][GISel] Add missing tests for G_CTLZ/CTTZ instruction selection. NFC
  Return available function types for BindingDecls. (llvm#102196)
  [clang] Wire -fptrauth-returns to "ptrauth-returns" fn attribute. (llvm#102416)
  [RISCV] Remove riscv-experimental-rv64-legal-i32. (llvm#102509)
  [RISCV] Move PseudoVSET(I)VLI expansion to use PseudoInstExpansion. (llvm#102496)
  [NVPTX] support switch statement with brx.idx (reland) (llvm#102550)
  [libc][newhdrgen]sorted function names in yaml (llvm#102544)
  [GlobalIsel] Combine G_ADD and G_SUB with constants (llvm#97771)
  Suppress spurious warnings due to R_RISCV_SET_ULEB128
  [scudo] Separated committed and decommitted entries. (llvm#101409)
  [MIPS] Fix missing ANDI optimization (llvm#97689)
  [Clang] Add env var for nvptx-arch/amdgpu-arch timeout (llvm#102521)
  [asan] Switch allocator to dynamic base address (llvm#98511)
  [AMDGPU] Move `AMDGPUAttributorPass` to full LTO post link stage (llvm#102086)
  [libc][math][c23] Add fadd{l,f128} C23 math functions (llvm#102531)
  [mlir][bazel] revert bazel rule change for DLTITransformOps
  [msan] Support vst{2,3,4}_lane instructions (llvm#101215)
  Revert "[MLIR][DLTI][Transform] Introduce transform.dlti.query (llvm#101561)"
  [X86] pr57673.ll - generate MIR test checks
  [mlir][vector][test] Split tests from vector-transfer-flatten.mlir (llvm#102584)
  [mlir][bazel] add bazel rule for DLTITransformOps
  OpenMPOpt: Remove dead include
  [IR] Add method to GlobalVariable to change type of initializer. (llvm#102553)
  [flang][cuda] Force default allocator in device code (llvm#102238)
  [llvm] Construct SmallVector<SDValue> with ArrayRef (NFC) (llvm#102578)
  [MLIR][DLTI][Transform] Introduce transform.dlti.query (llvm#101561)
  [AMDGPU][AsmParser][NFC] Remove a misleading comment. (llvm#102604)
  [Arm][AArch64][Clang] Respect function's branch protection attributes. (llvm#101978)
  [mlir] Verifier: steal bit to track seen instead of set. (llvm#102626)
  [Clang] Fix Handling of Init Capture with Parameter Packs in LambdaScopeForCallOperatorInstantiationRAII (llvm#100766)
  [X86] Convert truncsat clamping patterns to use SDPatternMatch. NFC.
  [gn] Give two scripts argparse.RawDescriptionHelpFormatter
  [bazel] Add missing dep for the SPIRVToLLVM target
  [Clang] Simplify specifying passes via -Xoffload-linker (llvm#102483)
  [bazel] Port for d45de80
  [SelectionDAG] Use unaligned store/load to move AVX registers onto stack for `insertelement` (llvm#82130)
  [Clang][OMPX] Add the code generation for multi-dim `num_teams` (llvm#101407)
  [ARM] Regenerate big-endian-vmov.ll. NFC
  [AMDGPU][AsmParser][NFCI] All NamedIntOperands to be of the i32 type. (llvm#102616)
  [libc][math][c23] Add totalorderl function. (llvm#102564)
  [mlir][spirv] Support `memref` in `convert-to-spirv` pass (llvm#102534)
  [MLIR][GPU-LLVM] Convert `gpu.func` to `llvm.func` (llvm#101664)
  Fix a unit test input file (llvm#102567)
  [llvm-readobj][COFF] Dump hybrid objects for ARM64X files. (llvm#102245)
  AMDGPU/NewPM: Port SIFixSGPRCopies to new pass manager (llvm#102614)
  [MemoryBuiltins] Simplify getCalledFunction() helper (NFC)
  [AArch64] Add invalid 1 x vscale costs for reductions and reduction-operations. (llvm#102105)
  [MemoryBuiltins] Handle allocator attributes on call-site
  LSV/test/AArch64: add missing lit.local.cfg; fix build (llvm#102607)
  Revert "Enable logf128 constant folding for hosts with 128bit floats (llvm#96287)"
  [RISCV] Add Syntacore SCR5 RV32/64 processors definition (llvm#102285)
  [InstCombine] Remove unnecessary RUN line from test (NFC)
  [flang][OpenMP] Handle multiple ranges in `num_teams` clause (llvm#102535)
  [mlir][vector] Add tests for scalable vectors in one-shot-bufferize.mlir (llvm#102361)
  [mlir][vector] Disable `vector.matrix_multiply` for scalable vectors (llvm#102573)
  [clang] Implement CWG2627 Bit-fields and narrowing conversions (llvm#78112)
  [NFC] Use references to avoid copying (llvm#99863)
  Revert "[mlir][ArmSME] Pattern to swap shape_cast(tranpose) with transpose(shape_cast) (llvm#100731)" (llvm#102457)
  [IRBuilder] Generate nuw GEPs for struct member accesses (llvm#99538)
  [bazel] Port for 9b06e25
  [CodeGen][NewPM] Improve start/stop pass error message CodeGenPassBuilder (llvm#102591)
  [AArch64] Implement TRBMPAM_EL1 system register (llvm#102485)
  [InstCombine] Fixing wrong select folding in vectors with undef elements (llvm#102244)
  [AArch64] Sink operands to fmuladd. (llvm#102297)
  LSV: document hang reported in llvm#37865 (llvm#102479)
  Enable logf128 constant folding for hosts with 128bit floats (llvm#96287)
  [RISCV][clang] Remove bfloat base type in non-zvfbfmin vcreate (llvm#102146)
  [RISCV][clang] Add missing `zvfbfmin` to `vget_v` intrinsic (llvm#102149)
  [mlir][vector] Add mask elimination transform (llvm#99314)
  [Clang][Interp] Fix display of syntactically-invalid note for member function calls (llvm#102170)
  [bazel] Port for 3fffa6d
  [DebugInfo][RemoveDIs] Use iterator-inserters in clang (llvm#102006)
  ...

Signed-off-by: Edwiin Kusuma Jaya <kutemeikito0905@gmail.com>