
[AArch64][SelectionDAG] Add CodeGen support for scalar FEAT_CPA #105669

Open. Wants to merge 15 commits into base: main.
Conversation

rgwott
Contributor

@rgwott rgwott commented Aug 22, 2024

CPA stands for Checked Pointer Arithmetic and is part of the 2023 MTE architecture extensions for A-profile.
The new CPA instructions perform regular pointer arithmetic (such as base register + offset) but check for overflow in the most significant bits of the result, enhancing security by detecting address tampering.

In this patch we intend to capture the semantics of pointer arithmetic when it is not folded into loads/stores, then generate the appropriate scalar CPA instructions. In order to preserve pointer arithmetic semantics through the backend, we add the PTRADD SelectionDAG node type.
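As a conceptual illustration of what the checked arithmetic guards against (the exact bit position and fault behaviour are defined by the Arm architecture; the 56/8 bit split below is purely an assumption for the sketch):

```cpp
#include <cassert>
#include <cstdint>

// Conceptual model of a checked pointer add: the addition is performed as
// usual, but a change in the upper bits of the result relative to the base
// pointer is flagged. The mask below is an assumed split point, chosen only
// for illustration; it is not the architectural definition.
struct CheckedPtr {
  uint64_t value;
  bool tampered;
};

CheckedPtr checkedPtrAdd(uint64_t base, uint64_t offset) {
  const uint64_t upperMask = 0xFF00000000000000ULL; // assumed split point
  uint64_t result = base + offset;
  return {result, (result & upperMask) != (base & upperMask)};
}
```

In-range arithmetic leaves the upper bits untouched; an overflow that spills into them is the kind of address tampering the ADDPT/SUBPT family is intended to catch.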

Use -mcpa-codegen to enable CPA CodeGen (for a target with CPA enabled).

The PTRADD node and respective visitPTRADD() function are adapted from the CHERI/Morello LLVM tree.
Original authors: @davidchisnall, @jrtc27, @arichardson.

More details about the CPA extension can be found at:

- https://community.arm.com/arm-community-blogs/b/architectures-and-processors-blog/posts/arm-a-profile-architecture-developments-2023
- https://developer.arm.com/documentation/ddi0602/2023-09/

This PR follows #79569.
It does not address vector FEAT_CPA instructions.
@llvmbot
Member

llvmbot commented Aug 22, 2024

@llvm/pr-subscribers-llvm-globalisel
@llvm/pr-subscribers-backend-aarch64

@llvm/pr-subscribers-llvm-selectiondag

Author: Rodolfo Wottrich (rgwott)

Changes



Patch is 53.86 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/105669.diff

14 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/ISDOpcodes.h (+4)
  • (modified) llvm/include/llvm/Target/TargetMachine.h (+5)
  • (modified) llvm/include/llvm/Target/TargetSelectionDAG.td (+2-2)
  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+98-3)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp (+17-4)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+8-2)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+33-11)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp (+1)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.td (+20)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetMachine.cpp (+4)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetMachine.h (+4)
  • (modified) llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp (+4)
  • (added) llvm/test/CodeGen/AArch64/cpa-globalisel.ll (+455)
  • (added) llvm/test/CodeGen/AArch64/cpa-selectiondag.ll (+449)
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 86ff2628975942..305b3349307777 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1452,6 +1452,10 @@ enum NodeType {
   // Outputs: [rv], output chain, glue
   PATCHPOINT,
 
+  // PTRADD represents pointer arithmetic semantics, for those targets which
+  // benefit from that information.
+  PTRADD,
+
 // Vector Predication
 #define BEGIN_REGISTER_VP_SDNODE(VPSDID, ...) VPSDID,
 #include "llvm/IR/VPIntrinsics.def"
diff --git a/llvm/include/llvm/Target/TargetMachine.h b/llvm/include/llvm/Target/TargetMachine.h
index c3e9d41315f617..26425fced52528 100644
--- a/llvm/include/llvm/Target/TargetMachine.h
+++ b/llvm/include/llvm/Target/TargetMachine.h
@@ -434,6 +434,11 @@ class TargetMachine {
       function_ref<void(std::unique_ptr<Module> MPart)> ModuleCallback) {
     return false;
   }
+
+  /// True if target has some particular form of dealing with pointer arithmetic
+  /// semantics. False if pointer arithmetic should not be preserved for passes
+  /// such as instruction selection, and can fallback to regular arithmetic.
+  virtual bool shouldPreservePtrArith(const Function &F) const { return false; }
 };
 
 /// This class describes a target machine that is implemented with the LLVM
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 172deffbd31771..aeb27ccf921a4b 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -109,7 +109,7 @@ def SDTOther  : SDTypeProfile<1, 0, [SDTCisVT<0, OtherVT>]>; // for 'vt'.
 def SDTUNDEF  : SDTypeProfile<1, 0, []>;                     // for 'undef'.
 def SDTUnaryOp  : SDTypeProfile<1, 1, []>;                   // for bitconvert.
 
-def SDTPtrAddOp : SDTypeProfile<1, 2, [     // ptradd
+def SDTPtrAddOp : SDTypeProfile<1, 2, [  // ptradd
   SDTCisSameAs<0, 1>, SDTCisInt<2>, SDTCisPtrTy<1>
 ]>;
 def SDTIntBinOp : SDTypeProfile<1, 2, [     // add, and, or, xor, udiv, etc.
@@ -390,7 +390,7 @@ def tblockaddress: SDNode<"ISD::TargetBlockAddress",  SDTPtrLeaf, [],
 
 def add        : SDNode<"ISD::ADD"       , SDTIntBinOp   ,
                         [SDNPCommutative, SDNPAssociative]>;
-def ptradd     : SDNode<"ISD::ADD"       , SDTPtrAddOp, []>;
+def ptradd     : SDNode<"ISD::PTRADD"    , SDTPtrAddOp, []>;
 def sub        : SDNode<"ISD::SUB"       , SDTIntBinOp>;
 def mul        : SDNode<"ISD::MUL"       , SDTIntBinOp,
                         [SDNPCommutative, SDNPAssociative]>;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 11935cbc309f01..16a16e1d702c29 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -416,7 +416,9 @@ namespace {
     SDValue visitMERGE_VALUES(SDNode *N);
     SDValue visitADD(SDNode *N);
     SDValue visitADDLike(SDNode *N);
-    SDValue visitADDLikeCommutative(SDValue N0, SDValue N1, SDNode *LocReference);
+    SDValue visitADDLikeCommutative(SDValue N0, SDValue N1,
+                                    SDNode *LocReference);
+    SDValue visitPTRADD(SDNode *N);
     SDValue visitSUB(SDNode *N);
     SDValue visitADDSAT(SDNode *N);
     SDValue visitSUBSAT(SDNode *N);
@@ -1082,7 +1084,7 @@ bool DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc,
   // (load/store (add, (add, x, y), offset2)) ->
   // (load/store (add, (add, x, offset2), y)).
 
-  if (N0.getOpcode() != ISD::ADD)
+  if (N0.getOpcode() != ISD::ADD && N0.getOpcode() != ISD::PTRADD)
     return false;
 
   // Check for vscale addressing modes.
@@ -1833,6 +1835,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::TokenFactor:        return visitTokenFactor(N);
   case ISD::MERGE_VALUES:       return visitMERGE_VALUES(N);
   case ISD::ADD:                return visitADD(N);
+  case ISD::PTRADD:             return visitPTRADD(N);
   case ISD::SUB:                return visitSUB(N);
   case ISD::SADDSAT:
   case ISD::UADDSAT:            return visitADDSAT(N);
@@ -2349,7 +2352,7 @@ static bool canFoldInAddressingMode(SDNode *N, SDNode *Use, SelectionDAG &DAG,
   }
 
   TargetLowering::AddrMode AM;
-  if (N->getOpcode() == ISD::ADD) {
+  if (N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::PTRADD) {
     AM.HasBaseReg = true;
     ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(N->getOperand(1));
     if (Offset)
@@ -2578,6 +2581,98 @@ SDValue DAGCombiner::foldSubToAvg(SDNode *N, const SDLoc &DL) {
   return SDValue();
 }
 
+/// Try to fold a pointer arithmetic node.
+/// This needs to be done separately from normal addition, because pointer
+/// addition is not commutative.
+/// This function was adapted from DAGCombiner::visitPTRADD() from the Morello
+/// project, which is based on CHERI.
+SDValue DAGCombiner::visitPTRADD(SDNode *N) {
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  EVT PtrVT = N0.getValueType();
+  EVT IntVT = N1.getValueType();
+  SDLoc DL(N);
+
+  // fold (ptradd undef, y) -> undef
+  if (N0.isUndef())
+    return N0;
+
+  // fold (ptradd x, undef) -> undef
+  if (N1.isUndef())
+    return DAG.getUNDEF(PtrVT);
+
+  // fold (ptradd x, 0) -> x
+  if (isNullConstant(N1))
+    return N0;
+
+  if (N0.getOpcode() == ISD::PTRADD &&
+      !reassociationCanBreakAddressingModePattern(ISD::PTRADD, DL, N, N0, N1)) {
+    SDValue X = N0.getOperand(0);
+    SDValue Y = N0.getOperand(1);
+    SDValue Z = N1;
+    bool N0OneUse = N0.hasOneUse();
+    bool YIsConstant = DAG.isConstantIntBuildVectorOrConstantInt(Y);
+    bool ZIsConstant = DAG.isConstantIntBuildVectorOrConstantInt(Z);
+
+    // (ptradd (ptradd x, y), z) -> (ptradd (ptradd x, z), y) if:
+    //   * (ptradd x, y) has one use; and
+    //   * y is a constant; and
+    //   * z is not a constant.
+    // Serves to expose constant y for subsequent folding.
+    if (N0OneUse && YIsConstant && !ZIsConstant) {
+      SDValue Add = DAG.getNode(ISD::PTRADD, DL, IntVT, {X, Z});
+
+      // Calling visit() can replace the Add node with ISD::DELETED_NODE if
+      // there aren't any users, so keep a handle around whilst we visit it.
+      HandleSDNode ADDHandle(Add);
+
+      SDValue VisitedAdd = visit(Add.getNode());
+      if (VisitedAdd) {
+        // If visit() returns the same node, it means the SDNode was RAUW'd, and
+        // therefore we have to load the new value to perform the checks whether
+        // the reassociation fold is profitable.
+        if (VisitedAdd.getNode() == Add.getNode())
+          Add = ADDHandle.getValue();
+        else
+          Add = VisitedAdd;
+      }
+
+      return DAG.getMemBasePlusOffset(Add, Y, DL, SDNodeFlags());
+    }
+
+    bool ZOneUse = Z.hasOneUse();
+
+    // (ptradd (ptradd x, y), z) -> (ptradd x, (add y, z)) if:
+    //   * x is a null pointer; or
+    //   * y is a constant and z has one use; or
+    //   * y is a constant and (ptradd x, y) has one use; or
+    //   * (ptradd x, y) and z have one use and z is not a constant.
+    if (isNullConstant(X) || (YIsConstant && ZOneUse) ||
+        (YIsConstant && N0OneUse) || (N0OneUse && ZOneUse && !ZIsConstant)) {
+      SDValue Add = DAG.getNode(ISD::ADD, DL, IntVT, {Y, Z});
+
+      // Calling visit() can replace the Add node with ISD::DELETED_NODE if
+      // there aren't any users, so keep a handle around whilst we visit it.
+      HandleSDNode ADDHandle(Add);
+
+      SDValue VisitedAdd = visit(Add.getNode());
+      if (VisitedAdd) {
+        // If visit() returns the same node, it means the SDNode was RAUW'd, and
+        // therefore we have to load the new value to perform the checks whether
+        // the reassociation fold is profitable.
+        if (VisitedAdd.getNode() == Add.getNode())
+          Add = ADDHandle.getValue();
+        else
+          Add = VisitedAdd;
+      }
+
+      return DAG.getMemBasePlusOffset(X, Add, DL, SDNodeFlags());
+    }
+  }
+
+  return SDValue();
+}
+
 /// Try to fold a 'not' shifted sign-bit with add/sub with constant operand into
 /// a shift and add with a different constant.
 static SDValue foldAddSubOfSignBit(SDNode *N, const SDLoc &DL,
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 74e3a898569bea..28e0bdbb549c66 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -4069,8 +4069,14 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
     else
       Index = DAG.getNode(ISD::MUL, dl, Index.getValueType(), Index,
                           DAG.getConstant(EntrySize, dl, Index.getValueType()));
-    SDValue Addr = DAG.getNode(ISD::ADD, dl, Index.getValueType(),
-                               Index, Table);
+    SDValue Addr;
+    if (!DAG.getTarget().shouldPreservePtrArith(
+            DAG.getMachineFunction().getFunction())) {
+      Addr = DAG.getNode(ISD::ADD, dl, Index.getValueType(), Index, Table);
+    } else {
+      // PTRADD always takes the pointer first, so the operands are commuted
+      Addr = DAG.getNode(ISD::PTRADD, dl, Index.getValueType(), Table, Index);
+    }
 
     EVT MemVT = EVT::getIntegerVT(*DAG.getContext(), EntrySize * 8);
     SDValue LD = DAG.getExtLoad(
@@ -4081,8 +4087,15 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
       // For PIC, the sequence is:
       // BRIND(load(Jumptable + index) + RelocBase)
       // RelocBase can be JumpTable, GOT or some sort of global base.
-      Addr = DAG.getNode(ISD::ADD, dl, PTy, Addr,
-                          TLI.getPICJumpTableRelocBase(Table, DAG));
+      if (!DAG.getTarget().shouldPreservePtrArith(
+              DAG.getMachineFunction().getFunction())) {
+        Addr = DAG.getNode(ISD::ADD, dl, PTy, Addr,
+                           TLI.getPICJumpTableRelocBase(Table, DAG));
+      } else {
+        // PTRADD always takes the pointer first, so the operands are commuted
+        Addr = DAG.getNode(ISD::PTRADD, dl, PTy,
+                           TLI.getPICJumpTableRelocBase(Table, DAG), Addr);
+      }
     }
 
     Tmp1 = TLI.expandIndirectJTBranch(dl, LD.getValue(1), Addr, JTI, DAG);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 27675dce70c260..dd746234e6ad83 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -5387,7 +5387,8 @@ bool SelectionDAG::isADDLike(SDValue Op, bool NoWrap) const {
 
 bool SelectionDAG::isBaseWithConstantOffset(SDValue Op) const {
   return Op.getNumOperands() == 2 && isa<ConstantSDNode>(Op.getOperand(1)) &&
-         (Op.getOpcode() == ISD::ADD || isADDLike(Op));
+         (Op.getOpcode() == ISD::ADD || Op.getOpcode() == ISD::PTRADD ||
+          isADDLike(Op));
 }
 
 bool SelectionDAG::isKnownNeverNaN(SDValue Op, bool SNaN, unsigned Depth) const {
@@ -7785,7 +7786,12 @@ SDValue SelectionDAG::getMemBasePlusOffset(SDValue Ptr, SDValue Offset,
                                            const SDNodeFlags Flags) {
   assert(Offset.getValueType().isInteger());
   EVT BasePtrVT = Ptr.getValueType();
-  return getNode(ISD::ADD, DL, BasePtrVT, Ptr, Offset, Flags);
+  if (!this->getTarget().shouldPreservePtrArith(
+          this->getMachineFunction().getFunction())) {
+    return getNode(ISD::ADD, DL, BasePtrVT, Ptr, Offset, Flags);
+  } else {
+    return getNode(ISD::PTRADD, DL, BasePtrVT, Ptr, Offset, Flags);
+  }
 }
 
 /// Returns true if memcpy source is constant data.
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 60dcb118542785..f6e797dee395a4 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -4293,6 +4293,12 @@ void SelectionDAGBuilder::visitGetElementPtr(const User &I) {
   SDLoc dl = getCurSDLoc();
   auto &TLI = DAG.getTargetLoweringInfo();
   GEPNoWrapFlags NW = cast<GEPOperator>(I).getNoWrapFlags();
+  unsigned int AddOpcode = ISD::PTRADD;
+
+  if (!DAG.getTarget().shouldPreservePtrArith(
+          DAG.getMachineFunction().getFunction())) {
+    AddOpcode = ISD::ADD;
+  }
 
   // Normalize Vector GEP - all scalar operands should be converted to the
   // splat vector.
@@ -4324,7 +4330,7 @@ void SelectionDAGBuilder::visitGetElementPtr(const User &I) {
             (int64_t(Offset) >= 0 && NW.hasNoUnsignedSignedWrap()))
           Flags.setNoUnsignedWrap(true);
 
-        N = DAG.getNode(ISD::ADD, dl, N.getValueType(), N,
+        N = DAG.getNode(AddOpcode, dl, N.getValueType(), N,
                         DAG.getConstant(Offset, dl, N.getValueType()), Flags);
       }
     } else {
@@ -4368,7 +4374,7 @@ void SelectionDAGBuilder::visitGetElementPtr(const User &I) {
 
         OffsVal = DAG.getSExtOrTrunc(OffsVal, dl, N.getValueType());
 
-        N = DAG.getNode(ISD::ADD, dl, N.getValueType(), N, OffsVal, Flags);
+        N = DAG.getNode(AddOpcode, dl, N.getValueType(), N, OffsVal, Flags);
         continue;
       }
 
@@ -4411,8 +4417,7 @@ void SelectionDAGBuilder::visitGetElementPtr(const User &I) {
         }
       }
 
-      N = DAG.getNode(ISD::ADD, dl,
-                      N.getValueType(), N, IdxN);
+      N = DAG.getNode(AddOpcode, dl, N.getValueType(), N, IdxN);
     }
   }
 
@@ -4473,8 +4478,15 @@ void SelectionDAGBuilder::visitAlloca(const AllocaInst &I) {
   // an address inside an alloca.
   SDNodeFlags Flags;
   Flags.setNoUnsignedWrap(true);
-  AllocSize = DAG.getNode(ISD::ADD, dl, AllocSize.getValueType(), AllocSize,
-                          DAG.getConstant(StackAlignMask, dl, IntPtr), Flags);
+  if (DAG.getTarget().shouldPreservePtrArith(
+          DAG.getMachineFunction().getFunction())) {
+    AllocSize = DAG.getNode(ISD::PTRADD, dl, AllocSize.getValueType(),
+                            DAG.getConstant(StackAlignMask, dl, IntPtr),
+                            AllocSize, Flags);
+  } else {
+    AllocSize = DAG.getNode(ISD::ADD, dl, AllocSize.getValueType(), AllocSize,
+                            DAG.getConstant(StackAlignMask, dl, IntPtr), Flags);
+  }
 
   // Mask out the low bits for alignment purposes.
   AllocSize = DAG.getNode(ISD::AND, dl, AllocSize.getValueType(), AllocSize,
@@ -9071,8 +9083,13 @@ bool SelectionDAGBuilder::visitMemPCpyCall(const CallInst &I) {
   Size = DAG.getSExtOrTrunc(Size, sdl, Dst.getValueType());
 
   // Adjust return pointer to point just past the last dst byte.
-  SDValue DstPlusSize = DAG.getNode(ISD::ADD, sdl, Dst.getValueType(),
-                                    Dst, Size);
+  unsigned int AddOpcode = ISD::PTRADD;
+  if (!DAG.getTarget().shouldPreservePtrArith(
+          DAG.getMachineFunction().getFunction())) {
+    AddOpcode = ISD::ADD;
+  }
+  SDValue DstPlusSize =
+      DAG.getNode(AddOpcode, sdl, Dst.getValueType(), Dst, Size);
   setValue(&I, DstPlusSize);
   return true;
 }
@@ -11169,9 +11186,14 @@ TargetLowering::LowerCallTo(TargetLowering::CallLoweringInfo &CLI) const {
     MachineFunction &MF = CLI.DAG.getMachineFunction();
     Align HiddenSRetAlign = MF.getFrameInfo().getObjectAlign(DemoteStackIdx);
     for (unsigned i = 0; i < NumValues; ++i) {
-      SDValue Add = CLI.DAG.getNode(ISD::ADD, CLI.DL, PtrVT, DemoteStackSlot,
-                                    CLI.DAG.getConstant(Offsets[i], CLI.DL,
-                                                        PtrVT), Flags);
+      unsigned int AddOpcode = ISD::PTRADD;
+      if (!CLI.DAG.getTarget().shouldPreservePtrArith(
+              CLI.DAG.getMachineFunction().getFunction())) {
+        AddOpcode = ISD::ADD;
+      }
+      SDValue Add = CLI.DAG.getNode(
+          AddOpcode, CLI.DL, PtrVT, DemoteStackSlot,
+          CLI.DAG.getConstant(Offsets[i], CLI.DL, PtrVT), Flags);
       SDValue L = CLI.DAG.getLoad(
           RetTys[i], CLI.DL, CLI.Chain, Add,
           MachinePointerInfo::getFixedStack(CLI.DAG.getMachineFunction(),
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 001f782f209fdb..5e727df0fcab48 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -256,6 +256,7 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
 
   // Binary operators
   case ISD::ADD:                        return "add";
+  case ISD::PTRADD:                     return "ptradd";
   case ISD::SUB:                        return "sub";
   case ISD::MUL:                        return "mul";
   case ISD::MULHU:                      return "mulhu";
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index ec225a5b234a26..0fc31cf9120838 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -10223,6 +10223,26 @@ let Predicates = [HasCPA] in {
   // Scalar multiply-add/subtract
   def MADDPT : MulAccumCPA<0, "maddpt">;
   def MSUBPT : MulAccumCPA<1, "msubpt">;
+
+  // Rules to use CPA instructions in pointer arithmetic patterns which are not
+  // folded into loads/stores. The AddedComplexity serves to help supersede
+  // other simpler (non-CPA) patterns and make sure CPA is used instead.
+  let AddedComplexity = 20 in {
+    def : Pat<(ptradd GPR64sp:$Rn, GPR64sp:$Rm),
+              (ADDPT_shift GPR64sp:$Rn, GPR64sp:$Rm, (i32 0))>;
+    def : Pat<(ptradd GPR64sp:$Rn, (shl GPR64sp:$Rm, (i64 imm0_7:$imm))),
+              (ADDPT_shift GPR64sp:$Rn, GPR64sp:$Rm,
+                           (i32 (trunc_imm imm0_7:$imm)))>;
+    def : Pat<(ptradd GPR64sp:$Rn, (ineg GPR64sp:$Rm)),
+              (SUBPT_shift GPR64sp:$Rn, GPR64sp:$Rm, (i32 0))>;
+    def : Pat<(ptradd GPR64sp:$Rn, (ineg (shl GPR64sp:$Rm, (i64 imm0_7:$imm)))),
+              (SUBPT_shift GPR64sp:$Rn, GPR64sp:$Rm,
+                           (i32 (trunc_imm imm0_7:$imm)))>;
+    def : Pat<(ptradd GPR64:$Ra, (mul GPR64:$Rn, GPR64:$Rm)),
+              (MADDPT GPR64:$Rn, GPR64:$Rm, GPR64:$Ra)>;
+    def : Pat<(ptradd GPR64:$Ra, (mul GPR64:$Rn, (ineg GPR64:$Rm))),
+              (MSUBPT GPR64:$Rn, GPR64:$Rm, GPR64:$Ra)>;
+  }
 }
 
 def round_v4fp32_to_v4bf16 :
diff --git a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
index bd5684a287381a..3dfc90380ccb86 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
@@ -920,3 +920,7 @@ bool AArch64TargetMachine::parseMachineFunctionInfo(
   MF.getInfo<AArch64FunctionInfo>()->initializeBaseYamlFields(YamlMFI);
   return false;
 }
+
+bool AArch64TargetMachine::shouldPreservePtrArith(const Function &F) const {
+  return getSubtargetImpl(F)->hasCPA();
+}
\ No newline at end of file
diff --git a/llvm/lib/Target/AArch64/AArch64TargetMachine.h b/llvm/lib/Target/AArch64/AArch64TargetMachine.h
index 1a470ca87127ce..c161223fe7fc10 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetMachine.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetMachine.h
@@ -69,6 +69,10 @@ class AArch64TargetMachine : public LLVMTargetMachine {
     return true;
   }
 
+  /// In AArch64, true if FEAT_CPA is present. Allows pointer arithmetic
+  /// semantics to be preserved for instruction selection.
+  bool shouldPreservePtrArith(const Function &F) const override;
+
 private:
   bool isLittle;
 };
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
index e9e6b6cb68d0d1..158a4b2d0d1577 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
@@ -2091,6 +2091,10 @@ bool AArch64InstructionSelector::preISelLower(MachineInstr &I) {
     return Changed;
   }
   case TargetOpcode::G_PTR_ADD:
+    // If Checked Pointer Arithmetic (FEAT_CPA) is present, preserve the pointer
+    // arithmetic semantics instead of falling back to regular arithmetic.
+    if (TM.shouldPreservePtrArith(MF.getFunction()))
+      return false;
     return convertPtrAddToAdd(I, MRI);
   case TargetOpcode::G_LOAD: {
     // For scalar loads of pointers, we try to convert the dest type from p0
diff --git a/llvm/test/CodeGen/AArch64/cpa-globalisel.ll b/llvm/test/CodeGen/AArch64/cpa-globalisel.ll
ne...
[truncated]

/// True if target has some particular form of dealing with pointer arithmetic
/// semantics. False if pointer arithmetic should not be preserved for passes
/// such as instruction selection, and can fallback to regular arithmetic.
virtual bool shouldPreservePtrArith(const Function &F) const { return false; }
Contributor


I'd expect this to be a property of individual instructions?

Collaborator


For us it's a property of the type, so having the instruction would suffice

Collaborator


Though getting the instruction in getMemBasePlusOffset isn't doable, and we really do need that interface to do the right thing whatever you throw at it. We add new c<N> types for capabilities that make it easy to know what you want to do, but CPA doesn't really fit in with that as they're still integers.

Contributor Author


For FEAT_CPA this is a property of the target. Suggestions for improvement are welcome. =)

Member

@arichardson arichardson left a comment


Thanks for upstreaming this - I think we should use getMemBasePlusOffset in more places though instead of all the added checks.

I also think the commit message should include Co-authored-by: for @davidchisnall and @jrtc27 (I don't think I substantively touched any of the imported lines so no need to include me. EDIT: just checked the blame, and my only contribution here is an infinite combine fix so I don't mind either way).


// (ptradd (ptradd x, y), z) -> (ptradd x, (add y, z)) if:
// * x is a null pointer; or
// * y is a constant and z has one use; or
Collaborator


Could you explain the reasoning behind dropping a couple of the conditions present in CHERI LLVM? I spent quite a while at the time tweaking them to catch everything I could reasonably throw at it, but there could be redundancy in the final version.

Contributor Author


I appreciate the work put into figuring out all those conditions, some of which I am not even totally knowledgeable about. However, we did find test cases that broke some of them: new nodes could be introduced, then visit() called, and the reassociation ultimately cancelled after visit() had modified the DAG, leaving the DAG in an inconsistent state. My call was to remove the offending conditions, as I found no better solution.
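For reference, the guard conditions that remain in this PR's visitPTRADD (as spelled out in the comments in the diff above) can be restated as a small predicate; this is only a restatement of the patch, not new logic:

```cpp
// Restates the two reassociation guards from the patch's visitPTRADD for a
// node of the form (ptradd (ptradd x, y), z).

// (ptradd (ptradd x, y), z) -> (ptradd (ptradd x, z), y)
// Serves to expose constant y for subsequent folding.
bool canSwapOffsets(bool n0OneUse, bool yIsConstant, bool zIsConstant) {
  return n0OneUse && yIsConstant && !zIsConstant;
}

// (ptradd (ptradd x, y), z) -> (ptradd x, (add y, z))
bool canCombineOffsets(bool xIsNull, bool yIsConstant, bool zIsConstant,
                       bool n0OneUse, bool zOneUse) {
  return xIsNull || (yIsConstant && zOneUse) || (yIsConstant && n0OneUse) ||
         (n0OneUse && zOneUse && !zIsConstant);
}
```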

Collaborator


Could you please share these test cases?

Contributor Author

@rgwott rgwott Jan 23, 2025


Sure. Our fuzzer found several cases that are essentially the same, so I only have record of this:

#include <stdint.h>
#include <stdio.h>

uint8_t a[2][1][2] = {1, 1, 1, 1};
uint16_t b = 0;

void main() {
  if (a[1][b][b + 1])
    printf("hello\n");
}

The corresponding IR is in the test cases included in this PR (multidim).

Contributor Author


Of course I may be wrong and it could be a problem with my adaptation.

@arichardson
Member

ping? Is this still being worked on?

@rgwott
Contributor Author

rgwott commented Nov 14, 2024

ping? Is this still being worked on?

Hi. This fell down the priority list for me, but I should be able to reply to the comments and resolve the outstanding matters in the coming weeks.

Rodolfo.Wottrich@arm.com and others added 4 commits January 20, 2025 14:28
Reduce the complexity of the PTRADD usage logic by harnessing
getMemBasePlusOffset().
Not all uses are applicable, as ADD's operand order might be inverted
(PTRADD must always have the pointer first and the offset second).
By letting getMemBasePlusOffset() know when a regular-ADD use does not
take the pointer first and the offset second, PTRADD can be generated
correctly on enabled targets, with the arguments inverted.

This modification avoids changing the generation of some ADDs, which
would otherwise require rewriting several tests across several
architectures.
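The operand-order constraint these commits deal with can be sketched as follows; `Node` is a hypothetical stand-in for an SDValue, used only for illustration, and mirrors the jump-table lowering in the diff:

```cpp
#include <string>

// Hypothetical stand-in for a DAG node, for illustration only.
struct Node {
  std::string op, lhs, rhs;
};

// ADD is commutative, so the historical "index + table" operand order is
// fine; PTRADD is not commutative and must take the pointer (the jump
// table base) first and the integer offset second.
Node emitJumpTableAddr(bool preservePtrArith, const std::string &index,
                       const std::string &table) {
  if (!preservePtrArith)
    return {"add", index, table};
  return {"ptradd", table, index};
}
```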
@rgwott rgwott changed the title [AArch64] Add CodeGen support for FEAT_CPA [AArch64][SelectionDAG] Add CodeGen support for FEAT_CPA Jan 23, 2025

This comment was marked as off-topic.

@rgwott rgwott changed the title [AArch64][SelectionDAG] Add CodeGen support for FEAT_CPA [AArch64][SelectionDAG] Add CodeGen support for scalar FEAT_CPA Jan 28, 2025
Member

@arichardson arichardson left a comment


I think this looks mostly good now - I will work on submitting some changes shortly so that the "Inverted flag" is not needed anymore.

@arichardson
Member

It should be possible to drop the inverted flag once #125279 lands.

One more nit: maybe the commit message should link to https://developer.arm.com/documentation/ddi0602/2024-12/Base-Instructions/ADDPT--Add-checked-pointer- instead of the top-level instruction reference?

arichardson added a commit that referenced this pull request Jan 31, 2025
This is needed for architectures that actually use strict pointer
arithmetic instead of integers such as AArch64 with FEAT_CPA (see
#105669) or CHERI. Using an
index as the first operand of pointer arithmetic may result in an
invalid output.

While there are quite a few codegen changes here, these only change the
order of registers in add instructions. One MIPS combine had to be
updated to handle the new node order.

Reviewed By: topperc

Pull Request: #125279
Member

@arichardson arichardson left a comment


Now that #125279 has landed this LGTM once rebased and the "inverted flag" has been dropped and the ISD::ADD for alignment has been restored.

github-actions bot pushed a commit to arm/arm-toolchain that referenced this pull request Jan 31, 2025
…ump tables

@@ -0,0 +1,451 @@
; RUN: llc -mtriple=aarch64 -verify-machineinstrs --mattr=+cpa -O0 -global-isel=0 -fast-isel=0 %s -o - 2>&1 | FileCheck %s --check-prefixes=CHECK-CPA-O0
; RUN: llc -mtriple=aarch64 -verify-machineinstrs --mattr=+cpa -O3 -global-isel=0 -fast-isel=0 %s -o - 2>&1 | FileCheck %s --check-prefixes=CHECK-CPA-O3
Collaborator:

Does this need -O3 and -O0 run lines?

Contributor Author:

It is useful to have both because the instruction selected is sometimes different, letting us exercise both paths. This became somewhat less useful after a SelectionDAG fold opportunity was removed during the course of this PR, but I am in favour of keeping it.

@array2 = external dso_local global [10 x %struct.my_type2], align 8

define void @addpt1(i64 %index, i64 %arg) {
; CHECK-CPA-O0-LABEL: addpt1:
Collaborator:

It would be better to auto-generate these test checks with update_llc_test_checks.

Contributor Author:

I am not sure why, could you explain your rationale?

Collaborator:

It helps keep the tests regular and easier to maintain. (If you don't do it now, someone will likely regenerate them in the future if they change). You might be able to use --check-prefixes=CHECK-CPA,CHECK-CPA-O0 if we expect the O0 and O3 codegen to be the same in some tests.

@@ -401,7 +401,7 @@ def tblockaddress: SDNode<"ISD::TargetBlockAddress", SDTPtrLeaf, [],

def add : SDNode<"ISD::ADD" , SDTIntBinOp ,
[SDNPCommutative, SDNPAssociative]>;
def ptradd : SDNode<"ISD::ADD" , SDTPtrAddOp, []>;
def ptradd : SDNode<"ISD::PTRADD" , SDTPtrAddOp, []>;
Collaborator:

If you are redefining the meaning of this node then we should update the existing uses of it. I see only 1 in the AMDGPU backend, which is probably covered by add, but we should check it can be removed as the new meaning is different.

Contributor Author:

I investigated this AMDGPU occurrence a bit. When building the compiler, TableGen does emit a ptradd pattern to select a certain instruction. However: (1) there are two consecutive entries, one for add and one for ptradd, that do the same thing, so they are equivalent, and (2) at compile time the PTRADD node will never exist for that backend, so the ptradd TableGen entry will never be selected (the ADD one will instead). Thankfully this is not a problem.

Collaborator:

It would be good to update the existing uses. If the AMD ptradd line is unnecessary, then removing it should not be a problem. It looks like it might be used for gisel at the moment though?

Member:

The AMDGPU ptradd line has an effect for global ISel, it shouldn't be removed. This ptradd SDNode was introduced as an equivalent to global ISel's G_PTR_ADD (as declared in SelectionDAGCompat.td), to specify SDAG patterns that are auto-translated to global ISel patterns.
As far as I'm aware, it doesn't matter for that if ptradd uses ISD::ADD or ISD::PTRADD, so changing it as the PR currently does is fine.

@rgwott (Contributor Author) commented Feb 4, 2025:

maybe the commit message should link to https://developer.arm.com/documentation/ddi0602/2024-12/Base-Instructions/ADDPT--Add-checked-pointer- instead of the top-level instruction reference?

I mean, yes and no. This is just one of the CPA instructions (although the most important one). I will reference it alongside the top-level link.

@@ -5025,6 +5025,11 @@ def msve_vector_bits_EQ : Joined<["-"], "msve-vector-bits=">, Group<m_aarch64_Fe
Visibility<[ClangOption, FlangOption]>,
HelpText<"Specify the size in bits of an SVE vector register. Defaults to the"
" vector length agnostic value of \"scalable\". (AArch64 only)">;

def mcpa_codegen : Flag<["-"], "mcpa-codegen">,
Collaborator:

Can we make it just an internal option for the time being? Otherwise we need to agree to the interface with GCC, and I don't think that has been done yet. It might be something to add in a future patch when we are sure of the details.
