Skip to content

Add deactivation symbol operand to ConstantPtrAuth. #133537

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: users/pcc/spr/main.add-deactivation-symbol-operand-to-constantptrauth
Choose a base branch
from

Conversation

pcc
Copy link
Contributor

@pcc pcc commented Mar 28, 2025

Deactivation symbol operands are supported in the code generator by
building on the previously added support for IRELATIVE relocations.

TODO:

  • Fix broken test.
  • Add bitcode and IR writer support.
  • Add tests.

Created using spr 1.3.6-beta.1
@pcc pcc requested a review from nikic as a code owner March 28, 2025 22:34
@llvmbot llvmbot added clang Clang issues not falling into any other category backend:AArch64 clang:codegen IR generation bugs: mangling, exceptions, etc. llvm:instcombine llvm:ir llvm:transforms llvm:SandboxIR labels Mar 28, 2025
@llvmbot
Copy link
Member

llvmbot commented Mar 28, 2025

@llvm/pr-subscribers-backend-aarch64
@llvm/pr-subscribers-clang

@llvm/pr-subscribers-llvm-transforms

Author: Peter Collingbourne (pcc)

Changes

Deactivation symbol operands are supported in the code generator by
building on the previously added support for IRELATIVE relocations.

TODO:

  • Fix broken test.
  • Add bitcode and IR writer support.
  • Add tests.

Patch is 22.34 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/133537.diff

16 Files Affected:

  • (modified) clang/lib/CodeGen/CGPointerAuth.cpp (+3-3)
  • (modified) llvm/include/llvm/Bitcode/LLVMBitCodes.h (+1)
  • (modified) llvm/include/llvm/IR/Constants.h (+9-4)
  • (modified) llvm/include/llvm/SandboxIR/Constant.h (+4-1)
  • (modified) llvm/lib/AsmParser/LLParser.cpp (+21-8)
  • (modified) llvm/lib/Bitcode/Reader/BitcodeReader.cpp (+17-1)
  • (modified) llvm/lib/IR/AsmWriter.cpp (+3-1)
  • (modified) llvm/lib/IR/Constants.cpp (+8-4)
  • (modified) llvm/lib/IR/ConstantsContext.h (+2-1)
  • (modified) llvm/lib/IR/Core.cpp (+3-1)
  • (modified) llvm/lib/SandboxIR/Constant.cpp (+9-2)
  • (modified) llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp (+31-6)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (+2-2)
  • (modified) llvm/lib/Transforms/Utils/ValueMapper.cpp (+3-2)
  • (modified) llvm/unittests/SandboxIR/SandboxIRTest.cpp (+1-1)
  • (modified) llvm/unittests/Transforms/Utils/ValueMapperTest.cpp (+9-4)
diff --git a/clang/lib/CodeGen/CGPointerAuth.cpp b/clang/lib/CodeGen/CGPointerAuth.cpp
index 4b032306ead72..2d72fef470af6 100644
--- a/clang/lib/CodeGen/CGPointerAuth.cpp
+++ b/clang/lib/CodeGen/CGPointerAuth.cpp
@@ -308,9 +308,9 @@ CodeGenModule::getConstantSignedPointer(llvm::Constant *Pointer, unsigned Key,
     IntegerDiscriminator = llvm::ConstantInt::get(Int64Ty, 0);
   }
 
-  return llvm::ConstantPtrAuth::get(Pointer,
-                                    llvm::ConstantInt::get(Int32Ty, Key),
-                                    IntegerDiscriminator, AddressDiscriminator);
+  return llvm::ConstantPtrAuth::get(
+      Pointer, llvm::ConstantInt::get(Int32Ty, Key), IntegerDiscriminator,
+      AddressDiscriminator, llvm::Constant::getNullValue(UnqualPtrTy));
 }
 
 /// Does a given PointerAuthScheme require us to sign a value
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index ec2535ac85966..13521ba6cd00f 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -431,6 +431,7 @@ enum ConstantsCodes {
   CST_CODE_CE_GEP_WITH_INRANGE = 31,  // [opty, flags, range, n x operands]
   CST_CODE_CE_GEP = 32,               // [opty, flags, n x operands]
   CST_CODE_PTRAUTH = 33,              // [ptr, key, disc, addrdisc]
+  CST_CODE_PTRAUTH2 = 34,             // [ptr, key, disc, addrdisc, DeactivationSymbol]
 };
 
 /// CastOpcodes - These are values used in the bitcode files to encode which
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index a50217078d0ed..45d5352bf06a6 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -1022,10 +1022,10 @@ class ConstantPtrAuth final : public Constant {
   friend struct ConstantPtrAuthKeyType;
   friend class Constant;
 
-  constexpr static IntrusiveOperandsAllocMarker AllocMarker{4};
+  constexpr static IntrusiveOperandsAllocMarker AllocMarker{5};
 
   ConstantPtrAuth(Constant *Ptr, ConstantInt *Key, ConstantInt *Disc,
-                  Constant *AddrDisc);
+                  Constant *AddrDisc, Constant *DeactivationSymbol);
 
   void *operator new(size_t s) { return User::operator new(s, AllocMarker); }
 
@@ -1035,7 +1035,8 @@ class ConstantPtrAuth final : public Constant {
 public:
   /// Return a pointer signed with the specified parameters.
   static ConstantPtrAuth *get(Constant *Ptr, ConstantInt *Key,
-                              ConstantInt *Disc, Constant *AddrDisc);
+                              ConstantInt *Disc, Constant *AddrDisc,
+                              Constant *DeactivationSymbol);
 
   /// Produce a new ptrauth expression signing the given value using
   /// the same schema as is stored in one.
@@ -1067,6 +1068,10 @@ class ConstantPtrAuth final : public Constant {
     return !getAddrDiscriminator()->isNullValue();
   }
 
+  Constant *getDeactivationSymbol() const {
+    return cast<Constant>(Op<4>().get());
+  }
+
   /// A constant value for the address discriminator which has special
   /// significance to ctors/dtors lowering. Regular address discrimination can't
   /// be applied for them since uses of llvm.global_{c|d}tors are disallowed
@@ -1094,7 +1099,7 @@ class ConstantPtrAuth final : public Constant {
 
 template <>
 struct OperandTraits<ConstantPtrAuth>
-    : public FixedNumOperandTraits<ConstantPtrAuth, 4> {};
+    : public FixedNumOperandTraits<ConstantPtrAuth, 5> {};
 
 DEFINE_TRANSPARENT_OPERAND_ACCESSORS(ConstantPtrAuth, Constant)
 
diff --git a/llvm/include/llvm/SandboxIR/Constant.h b/llvm/include/llvm/SandboxIR/Constant.h
index 17f55e973cd76..5243a9476ac64 100644
--- a/llvm/include/llvm/SandboxIR/Constant.h
+++ b/llvm/include/llvm/SandboxIR/Constant.h
@@ -1096,7 +1096,8 @@ class ConstantPtrAuth final : public Constant {
 public:
   /// Return a pointer signed with the specified parameters.
   static ConstantPtrAuth *get(Constant *Ptr, ConstantInt *Key,
-                              ConstantInt *Disc, Constant *AddrDisc);
+                              ConstantInt *Disc, Constant *AddrDisc,
+                              Constant *DeactivationSymbol);
   /// The pointer that is signed in this ptrauth signed pointer.
   Constant *getPointer() const;
 
@@ -1111,6 +1112,8 @@ class ConstantPtrAuth final : public Constant {
   /// the only global-initializer user of the ptrauth signed pointer.
   Constant *getAddrDiscriminator() const;
 
+  Constant *getDeactivationSymbol() const;
+
   /// Whether there is any non-null address discriminator.
   bool hasAddressDiscriminator() const {
     return cast<llvm::ConstantPtrAuth>(Val)->hasAddressDiscriminator();
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 960119bab0933..dfa014aa0bd7d 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -4226,11 +4226,12 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
   }
   case lltok::kw_ptrauth: {
     // ValID ::= 'ptrauth' '(' ptr @foo ',' i32 <key>
-    //                         (',' i64 <disc> (',' ptr addrdisc)? )? ')'
+    //                         (',' i64 <disc> (',' ptr addrdisc (',' ptr ds)? )? )? ')'
     Lex.Lex();
 
     Constant *Ptr, *Key;
-    Constant *Disc = nullptr, *AddrDisc = nullptr;
+    Constant *Disc = nullptr, *AddrDisc = nullptr,
+             *DeactivationSymbol = nullptr;
 
     if (parseToken(lltok::lparen,
                    "expected '(' in constant ptrauth expression") ||
@@ -4239,11 +4240,14 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
                    "expected comma in constant ptrauth expression") ||
         parseGlobalTypeAndValue(Key))
       return true;
-    // If present, parse the optional disc/addrdisc.
-    if (EatIfPresent(lltok::comma))
-      if (parseGlobalTypeAndValue(Disc) ||
-          (EatIfPresent(lltok::comma) && parseGlobalTypeAndValue(AddrDisc)))
-        return true;
+    // If present, parse the optional disc/addrdisc/ds.
+    if (EatIfPresent(lltok::comma) && parseGlobalTypeAndValue(Disc))
+      return true;
+    if (EatIfPresent(lltok::comma) && parseGlobalTypeAndValue(AddrDisc))
+      return true;
+    if (EatIfPresent(lltok::comma) &&
+        parseGlobalTypeAndValue(DeactivationSymbol))
+      return true;
     if (parseToken(lltok::rparen,
                    "expected ')' in constant ptrauth expression"))
       return true;
@@ -4274,7 +4278,16 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
       AddrDisc = ConstantPointerNull::get(PointerType::get(Context, 0));
     }
 
-    ID.ConstantVal = ConstantPtrAuth::get(Ptr, KeyC, DiscC, AddrDisc);
+    if (DeactivationSymbol) {
+      if (!DeactivationSymbol->getType()->isPointerTy())
+        return error(
+            ID.Loc, "constant ptrauth deactivation symbol must be a pointer");
+    } else {
+      DeactivationSymbol = ConstantPointerNull::get(PointerType::get(Context, 0));
+    }
+
+    ID.ConstantVal =
+        ConstantPtrAuth::get(Ptr, KeyC, DiscC, AddrDisc, DeactivationSymbol);
     ID.Kind = ValID::t_Constant;
     return false;
   }
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index 40e755902b724..c09c3b4f7d38c 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -1611,7 +1611,13 @@ Expected<Value *> BitcodeReader::materializeValue(unsigned StartValID,
           if (!Disc)
             return error("ptrauth disc operand must be ConstantInt");
 
-          C = ConstantPtrAuth::get(ConstOps[0], Key, Disc, ConstOps[3]);
+          auto *DeactivationSymbol =
+              ConstOps.size() > 4 ? ConstOps[4]
+                                  : ConstantPointerNull::get(cast<PointerType>(
+                                        ConstOps[3]->getType()));
+
+          C = ConstantPtrAuth::get(ConstOps[0], Key, Disc, ConstOps[3],
+                                   DeactivationSymbol);
           break;
         }
         case BitcodeConstant::NoCFIOpcode: {
@@ -3811,6 +3817,16 @@ Error BitcodeReader::parseConstants() {
                                    (unsigned)Record[2], (unsigned)Record[3]});
       break;
     }
+    case bitc::CST_CODE_PTRAUTH2: {
+      if (Record.size() < 4)
+        return error("Invalid ptrauth record");
+      // Ptr, Key, Disc, AddrDisc, DeactivationSymbol
+      V = BitcodeConstant::create(
+          Alloc, CurTy, BitcodeConstant::ConstantPtrAuthOpcode,
+          {(unsigned)Record[0], (unsigned)Record[1], (unsigned)Record[2],
+           (unsigned)Record[3], (unsigned)Record[4]});
+      break;
+    }
     }
 
     assert(V->getType() == getTypeByID(CurTyID) && "Incorrect result type ID");
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index 79547b299a903..5efb321967008 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -1630,12 +1630,14 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
   if (const ConstantPtrAuth *CPA = dyn_cast<ConstantPtrAuth>(CV)) {
     Out << "ptrauth (";
 
-    // ptrauth (ptr CST, i32 KEY[, i64 DISC[, ptr ADDRDISC]?]?)
+    // ptrauth (ptr CST, i32 KEY[, i64 DISC[, ptr ADDRDISC[, ptr DS]?]?]?)
     unsigned NumOpsToWrite = 2;
     if (!CPA->getOperand(2)->isNullValue())
       NumOpsToWrite = 3;
     if (!CPA->getOperand(3)->isNullValue())
       NumOpsToWrite = 4;
+    if (!CPA->getOperand(4)->isNullValue())
+      NumOpsToWrite = 5;
 
     ListSeparator LS;
     for (unsigned i = 0, e = NumOpsToWrite; i != e; ++i) {
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index fb659450bfeeb..007d36d19f373 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -2072,19 +2072,22 @@ Value *NoCFIValue::handleOperandChangeImpl(Value *From, Value *To) {
 //
 
 ConstantPtrAuth *ConstantPtrAuth::get(Constant *Ptr, ConstantInt *Key,
-                                      ConstantInt *Disc, Constant *AddrDisc) {
-  Constant *ArgVec[] = {Ptr, Key, Disc, AddrDisc};
+                                      ConstantInt *Disc, Constant *AddrDisc,
+                                      Constant *DeactivationSymbol) {
+  Constant *ArgVec[] = {Ptr, Key, Disc, AddrDisc, DeactivationSymbol};
   ConstantPtrAuthKeyType MapKey(ArgVec);
   LLVMContextImpl *pImpl = Ptr->getContext().pImpl;
   return pImpl->ConstantPtrAuths.getOrCreate(Ptr->getType(), MapKey);
 }
 
 ConstantPtrAuth *ConstantPtrAuth::getWithSameSchema(Constant *Pointer) const {
-  return get(Pointer, getKey(), getDiscriminator(), getAddrDiscriminator());
+  return get(Pointer, getKey(), getDiscriminator(), getAddrDiscriminator(),
+             getDeactivationSymbol());
 }
 
 ConstantPtrAuth::ConstantPtrAuth(Constant *Ptr, ConstantInt *Key,
-                                 ConstantInt *Disc, Constant *AddrDisc)
+                                 ConstantInt *Disc, Constant *AddrDisc,
+                                 Constant *DeactivationSymbol)
     : Constant(Ptr->getType(), Value::ConstantPtrAuthVal, AllocMarker) {
   assert(Ptr->getType()->isPointerTy());
   assert(Key->getBitWidth() == 32);
@@ -2094,6 +2097,7 @@ ConstantPtrAuth::ConstantPtrAuth(Constant *Ptr, ConstantInt *Key,
   setOperand(1, Key);
   setOperand(2, Disc);
   setOperand(3, AddrDisc);
+  setOperand(4, DeactivationSymbol);
 }
 
 /// Remove the constant from the constant table.
diff --git a/llvm/lib/IR/ConstantsContext.h b/llvm/lib/IR/ConstantsContext.h
index e5c9622e09927..bf9d8ab952271 100644
--- a/llvm/lib/IR/ConstantsContext.h
+++ b/llvm/lib/IR/ConstantsContext.h
@@ -545,7 +545,8 @@ struct ConstantPtrAuthKeyType {
 
   ConstantPtrAuth *create(TypeClass *Ty) const {
     return new ConstantPtrAuth(Operands[0], cast<ConstantInt>(Operands[1]),
-                               cast<ConstantInt>(Operands[2]), Operands[3]);
+                               cast<ConstantInt>(Operands[2]), Operands[3],
+                               Operands[4]);
   }
 };
 
diff --git a/llvm/lib/IR/Core.cpp b/llvm/lib/IR/Core.cpp
index f4b03e8cb8aa3..6190ebdac16d4 100644
--- a/llvm/lib/IR/Core.cpp
+++ b/llvm/lib/IR/Core.cpp
@@ -1687,7 +1687,9 @@ LLVMValueRef LLVMConstantPtrAuth(LLVMValueRef Ptr, LLVMValueRef Key,
                                  LLVMValueRef Disc, LLVMValueRef AddrDisc) {
   return wrap(ConstantPtrAuth::get(
       unwrap<Constant>(Ptr), unwrap<ConstantInt>(Key),
-      unwrap<ConstantInt>(Disc), unwrap<Constant>(AddrDisc)));
+      unwrap<ConstantInt>(Disc), unwrap<Constant>(AddrDisc),
+      ConstantPointerNull::get(
+          cast<PointerType>(unwrap<Constant>(AddrDisc)->getType()))));
 }
 
 /*-- Opcode mapping */
diff --git a/llvm/lib/SandboxIR/Constant.cpp b/llvm/lib/SandboxIR/Constant.cpp
index 3e13c935c4281..0a28cf9feeb4d 100644
--- a/llvm/lib/SandboxIR/Constant.cpp
+++ b/llvm/lib/SandboxIR/Constant.cpp
@@ -421,10 +421,12 @@ PointerType *NoCFIValue::getType() const {
 }
 
 ConstantPtrAuth *ConstantPtrAuth::get(Constant *Ptr, ConstantInt *Key,
-                                      ConstantInt *Disc, Constant *AddrDisc) {
+                                      ConstantInt *Disc, Constant *AddrDisc,
+                                      Constant *DeactivationSymbol) {
   auto *LLVMC = llvm::ConstantPtrAuth::get(
       cast<llvm::Constant>(Ptr->Val), cast<llvm::ConstantInt>(Key->Val),
-      cast<llvm::ConstantInt>(Disc->Val), cast<llvm::Constant>(AddrDisc->Val));
+      cast<llvm::ConstantInt>(Disc->Val), cast<llvm::Constant>(AddrDisc->Val),
+      cast<llvm::Constant>(DeactivationSymbol->Val));
   return cast<ConstantPtrAuth>(Ptr->getContext().getOrCreateConstant(LLVMC));
 }
 
@@ -448,6 +450,11 @@ Constant *ConstantPtrAuth::getAddrDiscriminator() const {
       cast<llvm::ConstantPtrAuth>(Val)->getAddrDiscriminator());
 }
 
+Constant *ConstantPtrAuth::getDeactivationSymbol() const {
+  return Ctx.getOrCreateConstant(
+      cast<llvm::ConstantPtrAuth>(Val)->getDeactivationSymbol());
+}
+
 ConstantPtrAuth *ConstantPtrAuth::getWithSameSchema(Constant *Pointer) const {
   auto *LLVMC = cast<llvm::ConstantPtrAuth>(Val)->getWithSameSchema(
       cast<llvm::Constant>(Pointer->Val));
diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
index 135f6cff0f78b..283493408699e 100644
--- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
+++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
@@ -195,7 +195,7 @@ class AArch64AsmPrinter : public AsmPrinter {
 
   const MCExpr *emitPAuthRelocationAsIRelative(
       const MCExpr *Target, uint16_t Disc, AArch64PACKey::ID KeyID,
-      bool HasAddressDiversity, bool IsDSOLocal);
+      bool HasAddressDiversity, bool IsDSOLocal, const MCExpr *DSExpr);
 
   /// tblgen'erated driver function for lowering simple MI->MC
   /// pseudo instructions.
@@ -2270,15 +2270,17 @@ static void emitAddress(MCStreamer &Streamer, MCRegister Reg,
 }
 
 static bool targetSupportsPAuthRelocation(const Triple &TT,
-                                          const MCExpr *Target) {
+                                          const MCExpr *Target,
+                                          const MCExpr *DSExpr) {
   // No released version of glibc supports PAuth relocations.
   if (TT.isOSGlibc())
     return false;
 
   // We emit PAuth constants as IRELATIVE relocations in cases where the
   // constant cannot be represented as a PAuth relocation:
-  // 1) The signed value is not a symbol.
-  return !isa<MCConstantExpr>(Target);
+  // 1) There is a deactivation symbol.
+  // 2) The signed value is not a symbol.
+  return !DSExpr && !isa<MCConstantExpr>(Target);
 }
 
 static bool targetSupportsIRelativeRelocation(const Triple &TT) {
@@ -2295,7 +2297,7 @@ static bool targetSupportsIRelativeRelocation(const Triple &TT) {
 
 const MCExpr *AArch64AsmPrinter::emitPAuthRelocationAsIRelative(
     const MCExpr *Target, uint16_t Disc, AArch64PACKey::ID KeyID,
-    bool HasAddressDiversity, bool IsDSOLocal) {
+    bool HasAddressDiversity, bool IsDSOLocal, const MCExpr *DSExpr) {
   const Triple &TT = TM.getTargetTriple();
 
   // We only emit an IRELATIVE relocation if the target supports IRELATIVE and
@@ -2358,6 +2360,18 @@ const MCExpr *AArch64AsmPrinter::emitPAuthRelocationAsIRelative(
       MCSymbolRefExpr::create(EmuPAC, OutStreamer->getContext());
   OutStreamer->emitInstruction(MCInstBuilder(AArch64::B).addExpr(EmuPACRef),
                                *STI);
+
+  if (DSExpr) {
+    auto *PrePACInstExpr =
+        MCSymbolRefExpr::create(PrePACInst, OutStreamer->getContext());
+    OutStreamer->emitRelocDirective(*PrePACInstExpr, "R_AARCH64_INST32", DSExpr,
+                                    SMLoc(), *STI);
+  }
+
+  // We need a RET despite the above tail call because the deactivation symbol
+  // may replace it with a NOP.
+  OutStreamer->emitInstruction(MCInstBuilder(AArch64::RET).addReg(AArch64::LR),
+                               *STI);
   OutStreamer->popSection();
 
   return MCSymbolRefExpr::create(IFuncSym, OutStreamer->getContext());
@@ -2388,6 +2402,13 @@ AArch64AsmPrinter::lowerConstantPtrAuth(const ConstantPtrAuth &CPA) {
     Sym = MCConstantExpr::create(Offset.getSExtValue(), Ctx);
   }
 
+  const MCExpr *DSExpr = nullptr;
+  if (auto *DS = dyn_cast<GlobalValue>(CPA.getDeactivationSymbol())) {
+    if (isa<GlobalAlias>(DS))
+      return Sym;
+    DSExpr = MCSymbolRefExpr::create(getSymbol(DS), Ctx);
+  }
+
   uint64_t KeyID = CPA.getKey()->getZExtValue();
   // We later rely on valid KeyID value in AArch64PACKeyIDToString call from
   // AArch64AuthMCExpr::printImpl, so fail fast.
@@ -2404,9 +2425,13 @@ AArch64AsmPrinter::lowerConstantPtrAuth(const ConstantPtrAuth &CPA) {
   // Check if we need to represent this with an IRELATIVE and emit it if so.
   if (auto *IFuncSym = emitPAuthRelocationAsIRelative(
           Sym, Disc, AArch64PACKey::ID(KeyID), CPA.hasAddressDiscriminator(),
-          BaseGVB && BaseGVB->isDSOLocal()))
+          BaseGVB && BaseGVB->isDSOLocal(), DSExpr))
     return IFuncSym;
 
+  if (DSExpr)
+    report_fatal_error("deactivation symbols unsupported in constant "
+                       "expressions on this target");
+
   // Finally build the complete @AUTH expr.
   return AArch64AuthMCExpr::create(Sym, Disc, AArch64PACKey::ID(KeyID),
                                    CPA.hasAddressDiscriminator(), Ctx);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 12dd4cec85f59..58b98d8d93464 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -2946,9 +2946,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       if (NeedSign && isa<ConstantInt>(II->getArgOperand(4))) {
         auto *SignKey = cast<ConstantInt>(II->getArgOperand(3));
         auto *SignDisc = cast<ConstantInt>(II->getArgOperand(4));
-        auto *SignAddrDisc = ConstantPointerNull::get(Builder.getPtrTy());
+        auto *Null = ConstantPointerNull::get(Builder.getPtrTy());
         auto *NewCPA = ConstantPtrAuth::get(CPA->getPointer(), SignKey,
-                                            SignDisc, SignAddrDisc);
+                                            SignDisc, Null, Null);
         replaceInstUsesWith(
             *II, ConstantExpr::getPointerCast(NewCPA, II->getType()));
         return eraseInstFromFunction(*II);
diff --git a/llvm/lib/Transforms/Utils/ValueMapper.cpp b/llvm/lib/Transforms/Utils/ValueMapper.cpp
index 5e50536a99206..320bef6c8f240 100644
--- a/llvm/lib/Transforms/Utils/ValueMapper.cpp
+++ b/llvm/lib/Transforms/Utils/ValueMapper.cpp
@@ -526,8 +526,9 @@ Value *Mapper::mapValue(const Value *V) {
   if (isa<ConstantVector>(C))
     return getVM()[V] = ConstantVector::get(Ops);
   if (isa<ConstantPtrAuth>(C))
-    return getVM()[V] = ConstantPtrAuth::get(Ops[0], cast<ConstantInt>(Ops[1]),
-                                             cast<ConstantInt>(Ops[2]), Ops[3]);
+    return getVM()[V] =
+               ConstantPtrAuth::get(Ops[0], cast<ConstantInt>(Ops[1]),
+                                    cast<ConstantInt>(Ops[2]), Ops[3], Ops[4]);
   // If this is a no-operand constant, it must be because the type was remapped.
   if (isa<PoisonValue>(C))
     return getVM()[V] = PoisonValue::get(NewTy);
diff --git a/llvm/unittests/SandboxIR/SandboxIR...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Mar 28, 2025

@llvm/pr-subscribers-llvm-ir

Author: Peter Collingbourne (pcc)

Changes

Deactivation symbol operands are supported in the code generator by
building on the previously added support for IRELATIVE relocations.

TODO:

  • Fix broken test.
  • Add bitcode and IR writer support.
  • Add tests.

Patch is 22.34 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/133537.diff

16 Files Affected:

  • (modified) clang/lib/CodeGen/CGPointerAuth.cpp (+3-3)
  • (modified) llvm/include/llvm/Bitcode/LLVMBitCodes.h (+1)
  • (modified) llvm/include/llvm/IR/Constants.h (+9-4)
  • (modified) llvm/include/llvm/SandboxIR/Constant.h (+4-1)
  • (modified) llvm/lib/AsmParser/LLParser.cpp (+21-8)
  • (modified) llvm/lib/Bitcode/Reader/BitcodeReader.cpp (+17-1)
  • (modified) llvm/lib/IR/AsmWriter.cpp (+3-1)
  • (modified) llvm/lib/IR/Constants.cpp (+8-4)
  • (modified) llvm/lib/IR/ConstantsContext.h (+2-1)
  • (modified) llvm/lib/IR/Core.cpp (+3-1)
  • (modified) llvm/lib/SandboxIR/Constant.cpp (+9-2)
  • (modified) llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp (+31-6)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (+2-2)
  • (modified) llvm/lib/Transforms/Utils/ValueMapper.cpp (+3-2)
  • (modified) llvm/unittests/SandboxIR/SandboxIRTest.cpp (+1-1)
  • (modified) llvm/unittests/Transforms/Utils/ValueMapperTest.cpp (+9-4)
diff --git a/clang/lib/CodeGen/CGPointerAuth.cpp b/clang/lib/CodeGen/CGPointerAuth.cpp
index 4b032306ead72..2d72fef470af6 100644
--- a/clang/lib/CodeGen/CGPointerAuth.cpp
+++ b/clang/lib/CodeGen/CGPointerAuth.cpp
@@ -308,9 +308,9 @@ CodeGenModule::getConstantSignedPointer(llvm::Constant *Pointer, unsigned Key,
     IntegerDiscriminator = llvm::ConstantInt::get(Int64Ty, 0);
   }
 
-  return llvm::ConstantPtrAuth::get(Pointer,
-                                    llvm::ConstantInt::get(Int32Ty, Key),
-                                    IntegerDiscriminator, AddressDiscriminator);
+  return llvm::ConstantPtrAuth::get(
+      Pointer, llvm::ConstantInt::get(Int32Ty, Key), IntegerDiscriminator,
+      AddressDiscriminator, llvm::Constant::getNullValue(UnqualPtrTy));
 }
 
 /// Does a given PointerAuthScheme require us to sign a value
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index ec2535ac85966..13521ba6cd00f 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -431,6 +431,7 @@ enum ConstantsCodes {
   CST_CODE_CE_GEP_WITH_INRANGE = 31,  // [opty, flags, range, n x operands]
   CST_CODE_CE_GEP = 32,               // [opty, flags, n x operands]
   CST_CODE_PTRAUTH = 33,              // [ptr, key, disc, addrdisc]
+  CST_CODE_PTRAUTH2 = 34,             // [ptr, key, disc, addrdisc, DeactivationSymbol]
 };
 
 /// CastOpcodes - These are values used in the bitcode files to encode which
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index a50217078d0ed..45d5352bf06a6 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -1022,10 +1022,10 @@ class ConstantPtrAuth final : public Constant {
   friend struct ConstantPtrAuthKeyType;
   friend class Constant;
 
-  constexpr static IntrusiveOperandsAllocMarker AllocMarker{4};
+  constexpr static IntrusiveOperandsAllocMarker AllocMarker{5};
 
   ConstantPtrAuth(Constant *Ptr, ConstantInt *Key, ConstantInt *Disc,
-                  Constant *AddrDisc);
+                  Constant *AddrDisc, Constant *DeactivationSymbol);
 
   void *operator new(size_t s) { return User::operator new(s, AllocMarker); }
 
@@ -1035,7 +1035,8 @@ class ConstantPtrAuth final : public Constant {
 public:
   /// Return a pointer signed with the specified parameters.
   static ConstantPtrAuth *get(Constant *Ptr, ConstantInt *Key,
-                              ConstantInt *Disc, Constant *AddrDisc);
+                              ConstantInt *Disc, Constant *AddrDisc,
+                              Constant *DeactivationSymbol);
 
   /// Produce a new ptrauth expression signing the given value using
   /// the same schema as is stored in one.
@@ -1067,6 +1068,10 @@ class ConstantPtrAuth final : public Constant {
     return !getAddrDiscriminator()->isNullValue();
   }
 
+  Constant *getDeactivationSymbol() const {
+    return cast<Constant>(Op<4>().get());
+  }
+
   /// A constant value for the address discriminator which has special
   /// significance to ctors/dtors lowering. Regular address discrimination can't
   /// be applied for them since uses of llvm.global_{c|d}tors are disallowed
@@ -1094,7 +1099,7 @@ class ConstantPtrAuth final : public Constant {
 
 template <>
 struct OperandTraits<ConstantPtrAuth>
-    : public FixedNumOperandTraits<ConstantPtrAuth, 4> {};
+    : public FixedNumOperandTraits<ConstantPtrAuth, 5> {};
 
 DEFINE_TRANSPARENT_OPERAND_ACCESSORS(ConstantPtrAuth, Constant)
 
diff --git a/llvm/include/llvm/SandboxIR/Constant.h b/llvm/include/llvm/SandboxIR/Constant.h
index 17f55e973cd76..5243a9476ac64 100644
--- a/llvm/include/llvm/SandboxIR/Constant.h
+++ b/llvm/include/llvm/SandboxIR/Constant.h
@@ -1096,7 +1096,8 @@ class ConstantPtrAuth final : public Constant {
 public:
   /// Return a pointer signed with the specified parameters.
   static ConstantPtrAuth *get(Constant *Ptr, ConstantInt *Key,
-                              ConstantInt *Disc, Constant *AddrDisc);
+                              ConstantInt *Disc, Constant *AddrDisc,
+                              Constant *DeactivationSymbol);
   /// The pointer that is signed in this ptrauth signed pointer.
   Constant *getPointer() const;
 
@@ -1111,6 +1112,8 @@ class ConstantPtrAuth final : public Constant {
   /// the only global-initializer user of the ptrauth signed pointer.
   Constant *getAddrDiscriminator() const;
 
+  Constant *getDeactivationSymbol() const;
+
   /// Whether there is any non-null address discriminator.
   bool hasAddressDiscriminator() const {
     return cast<llvm::ConstantPtrAuth>(Val)->hasAddressDiscriminator();
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 960119bab0933..dfa014aa0bd7d 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -4226,11 +4226,12 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
   }
   case lltok::kw_ptrauth: {
     // ValID ::= 'ptrauth' '(' ptr @foo ',' i32 <key>
-    //                         (',' i64 <disc> (',' ptr addrdisc)? )? ')'
+    //                         (',' i64 <disc> (',' ptr addrdisc (',' ptr ds)? )? )? ')'
     Lex.Lex();
 
     Constant *Ptr, *Key;
-    Constant *Disc = nullptr, *AddrDisc = nullptr;
+    Constant *Disc = nullptr, *AddrDisc = nullptr,
+             *DeactivationSymbol = nullptr;
 
     if (parseToken(lltok::lparen,
                    "expected '(' in constant ptrauth expression") ||
@@ -4239,11 +4240,14 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
                    "expected comma in constant ptrauth expression") ||
         parseGlobalTypeAndValue(Key))
       return true;
-    // If present, parse the optional disc/addrdisc.
-    if (EatIfPresent(lltok::comma))
-      if (parseGlobalTypeAndValue(Disc) ||
-          (EatIfPresent(lltok::comma) && parseGlobalTypeAndValue(AddrDisc)))
-        return true;
+    // If present, parse the optional disc/addrdisc/ds.
+    if (EatIfPresent(lltok::comma) && parseGlobalTypeAndValue(Disc))
+      return true;
+    if (EatIfPresent(lltok::comma) && parseGlobalTypeAndValue(AddrDisc))
+      return true;
+    if (EatIfPresent(lltok::comma) &&
+        parseGlobalTypeAndValue(DeactivationSymbol))
+      return true;
     if (parseToken(lltok::rparen,
                    "expected ')' in constant ptrauth expression"))
       return true;
@@ -4274,7 +4278,16 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
       AddrDisc = ConstantPointerNull::get(PointerType::get(Context, 0));
     }
 
-    ID.ConstantVal = ConstantPtrAuth::get(Ptr, KeyC, DiscC, AddrDisc);
+    if (DeactivationSymbol) {
+      if (!DeactivationSymbol->getType()->isPointerTy())
+        return error(
+            ID.Loc, "constant ptrauth deactivation symbol must be a pointer");
+    } else {
+      DeactivationSymbol = ConstantPointerNull::get(PointerType::get(Context, 0));
+    }
+
+    ID.ConstantVal =
+        ConstantPtrAuth::get(Ptr, KeyC, DiscC, AddrDisc, DeactivationSymbol);
     ID.Kind = ValID::t_Constant;
     return false;
   }
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index 40e755902b724..c09c3b4f7d38c 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -1611,7 +1611,13 @@ Expected<Value *> BitcodeReader::materializeValue(unsigned StartValID,
           if (!Disc)
             return error("ptrauth disc operand must be ConstantInt");
 
-          C = ConstantPtrAuth::get(ConstOps[0], Key, Disc, ConstOps[3]);
+          auto *DeactivationSymbol =
+              ConstOps.size() > 4 ? ConstOps[4]
+                                  : ConstantPointerNull::get(cast<PointerType>(
+                                        ConstOps[3]->getType()));
+
+          C = ConstantPtrAuth::get(ConstOps[0], Key, Disc, ConstOps[3],
+                                   DeactivationSymbol);
           break;
         }
         case BitcodeConstant::NoCFIOpcode: {
@@ -3811,6 +3817,16 @@ Error BitcodeReader::parseConstants() {
                                    (unsigned)Record[2], (unsigned)Record[3]});
       break;
     }
+    case bitc::CST_CODE_PTRAUTH2: {
+      if (Record.size() < 4)
+        return error("Invalid ptrauth record");
+      // Ptr, Key, Disc, AddrDisc, DeactivationSymbol
+      V = BitcodeConstant::create(
+          Alloc, CurTy, BitcodeConstant::ConstantPtrAuthOpcode,
+          {(unsigned)Record[0], (unsigned)Record[1], (unsigned)Record[2],
+           (unsigned)Record[3], (unsigned)Record[4]});
+      break;
+    }
     }
 
     assert(V->getType() == getTypeByID(CurTyID) && "Incorrect result type ID");
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index 79547b299a903..5efb321967008 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -1630,12 +1630,14 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
   if (const ConstantPtrAuth *CPA = dyn_cast<ConstantPtrAuth>(CV)) {
     Out << "ptrauth (";
 
-    // ptrauth (ptr CST, i32 KEY[, i64 DISC[, ptr ADDRDISC]?]?)
+    // ptrauth (ptr CST, i32 KEY[, i64 DISC[, ptr ADDRDISC[, ptr DS]?]?]?)
     unsigned NumOpsToWrite = 2;
     if (!CPA->getOperand(2)->isNullValue())
       NumOpsToWrite = 3;
     if (!CPA->getOperand(3)->isNullValue())
       NumOpsToWrite = 4;
+    if (!CPA->getOperand(4)->isNullValue())
+      NumOpsToWrite = 5;
 
     ListSeparator LS;
     for (unsigned i = 0, e = NumOpsToWrite; i != e; ++i) {
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index fb659450bfeeb..007d36d19f373 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -2072,19 +2072,22 @@ Value *NoCFIValue::handleOperandChangeImpl(Value *From, Value *To) {
 //
 
 ConstantPtrAuth *ConstantPtrAuth::get(Constant *Ptr, ConstantInt *Key,
-                                      ConstantInt *Disc, Constant *AddrDisc) {
-  Constant *ArgVec[] = {Ptr, Key, Disc, AddrDisc};
+                                      ConstantInt *Disc, Constant *AddrDisc,
+                                      Constant *DeactivationSymbol) {
+  Constant *ArgVec[] = {Ptr, Key, Disc, AddrDisc, DeactivationSymbol};
   ConstantPtrAuthKeyType MapKey(ArgVec);
   LLVMContextImpl *pImpl = Ptr->getContext().pImpl;
   return pImpl->ConstantPtrAuths.getOrCreate(Ptr->getType(), MapKey);
 }
 
 ConstantPtrAuth *ConstantPtrAuth::getWithSameSchema(Constant *Pointer) const {
-  return get(Pointer, getKey(), getDiscriminator(), getAddrDiscriminator());
+  return get(Pointer, getKey(), getDiscriminator(), getAddrDiscriminator(),
+             getDeactivationSymbol());
 }
 
 ConstantPtrAuth::ConstantPtrAuth(Constant *Ptr, ConstantInt *Key,
-                                 ConstantInt *Disc, Constant *AddrDisc)
+                                 ConstantInt *Disc, Constant *AddrDisc,
+                                 Constant *DeactivationSymbol)
     : Constant(Ptr->getType(), Value::ConstantPtrAuthVal, AllocMarker) {
   assert(Ptr->getType()->isPointerTy());
   assert(Key->getBitWidth() == 32);
@@ -2094,6 +2097,7 @@ ConstantPtrAuth::ConstantPtrAuth(Constant *Ptr, ConstantInt *Key,
   setOperand(1, Key);
   setOperand(2, Disc);
   setOperand(3, AddrDisc);
+  setOperand(4, DeactivationSymbol);
 }
 
 /// Remove the constant from the constant table.
diff --git a/llvm/lib/IR/ConstantsContext.h b/llvm/lib/IR/ConstantsContext.h
index e5c9622e09927..bf9d8ab952271 100644
--- a/llvm/lib/IR/ConstantsContext.h
+++ b/llvm/lib/IR/ConstantsContext.h
@@ -545,7 +545,8 @@ struct ConstantPtrAuthKeyType {
 
   ConstantPtrAuth *create(TypeClass *Ty) const {
     return new ConstantPtrAuth(Operands[0], cast<ConstantInt>(Operands[1]),
-                               cast<ConstantInt>(Operands[2]), Operands[3]);
+                               cast<ConstantInt>(Operands[2]), Operands[3],
+                               Operands[4]);
   }
 };
 
diff --git a/llvm/lib/IR/Core.cpp b/llvm/lib/IR/Core.cpp
index f4b03e8cb8aa3..6190ebdac16d4 100644
--- a/llvm/lib/IR/Core.cpp
+++ b/llvm/lib/IR/Core.cpp
@@ -1687,7 +1687,9 @@ LLVMValueRef LLVMConstantPtrAuth(LLVMValueRef Ptr, LLVMValueRef Key,
                                  LLVMValueRef Disc, LLVMValueRef AddrDisc) {
   return wrap(ConstantPtrAuth::get(
       unwrap<Constant>(Ptr), unwrap<ConstantInt>(Key),
-      unwrap<ConstantInt>(Disc), unwrap<Constant>(AddrDisc)));
+      unwrap<ConstantInt>(Disc), unwrap<Constant>(AddrDisc),
+      ConstantPointerNull::get(
+          cast<PointerType>(unwrap<Constant>(AddrDisc)->getType()))));
 }
 
 /*-- Opcode mapping */
diff --git a/llvm/lib/SandboxIR/Constant.cpp b/llvm/lib/SandboxIR/Constant.cpp
index 3e13c935c4281..0a28cf9feeb4d 100644
--- a/llvm/lib/SandboxIR/Constant.cpp
+++ b/llvm/lib/SandboxIR/Constant.cpp
@@ -421,10 +421,12 @@ PointerType *NoCFIValue::getType() const {
 }
 
 ConstantPtrAuth *ConstantPtrAuth::get(Constant *Ptr, ConstantInt *Key,
-                                      ConstantInt *Disc, Constant *AddrDisc) {
+                                      ConstantInt *Disc, Constant *AddrDisc,
+                                      Constant *DeactivationSymbol) {
   auto *LLVMC = llvm::ConstantPtrAuth::get(
       cast<llvm::Constant>(Ptr->Val), cast<llvm::ConstantInt>(Key->Val),
-      cast<llvm::ConstantInt>(Disc->Val), cast<llvm::Constant>(AddrDisc->Val));
+      cast<llvm::ConstantInt>(Disc->Val), cast<llvm::Constant>(AddrDisc->Val),
+      cast<llvm::Constant>(DeactivationSymbol->Val));
   return cast<ConstantPtrAuth>(Ptr->getContext().getOrCreateConstant(LLVMC));
 }
 
@@ -448,6 +450,11 @@ Constant *ConstantPtrAuth::getAddrDiscriminator() const {
       cast<llvm::ConstantPtrAuth>(Val)->getAddrDiscriminator());
 }
 
+Constant *ConstantPtrAuth::getDeactivationSymbol() const {
+  return Ctx.getOrCreateConstant(
+      cast<llvm::ConstantPtrAuth>(Val)->getDeactivationSymbol());
+}
+
 ConstantPtrAuth *ConstantPtrAuth::getWithSameSchema(Constant *Pointer) const {
   auto *LLVMC = cast<llvm::ConstantPtrAuth>(Val)->getWithSameSchema(
       cast<llvm::Constant>(Pointer->Val));
diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
index 135f6cff0f78b..283493408699e 100644
--- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
+++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
@@ -195,7 +195,7 @@ class AArch64AsmPrinter : public AsmPrinter {
 
   const MCExpr *emitPAuthRelocationAsIRelative(
       const MCExpr *Target, uint16_t Disc, AArch64PACKey::ID KeyID,
-      bool HasAddressDiversity, bool IsDSOLocal);
+      bool HasAddressDiversity, bool IsDSOLocal, const MCExpr *DSExpr);
 
   /// tblgen'erated driver function for lowering simple MI->MC
   /// pseudo instructions.
@@ -2270,15 +2270,17 @@ static void emitAddress(MCStreamer &Streamer, MCRegister Reg,
 }
 
 static bool targetSupportsPAuthRelocation(const Triple &TT,
-                                          const MCExpr *Target) {
+                                          const MCExpr *Target,
+                                          const MCExpr *DSExpr) {
   // No released version of glibc supports PAuth relocations.
   if (TT.isOSGlibc())
     return false;
 
   // We emit PAuth constants as IRELATIVE relocations in cases where the
   // constant cannot be represented as a PAuth relocation:
-  // 1) The signed value is not a symbol.
-  return !isa<MCConstantExpr>(Target);
+  // 1) There is a deactivation symbol.
+  // 2) The signed value is not a symbol.
+  return !DSExpr && !isa<MCConstantExpr>(Target);
 }
 
 static bool targetSupportsIRelativeRelocation(const Triple &TT) {
@@ -2295,7 +2297,7 @@ static bool targetSupportsIRelativeRelocation(const Triple &TT) {
 
 const MCExpr *AArch64AsmPrinter::emitPAuthRelocationAsIRelative(
     const MCExpr *Target, uint16_t Disc, AArch64PACKey::ID KeyID,
-    bool HasAddressDiversity, bool IsDSOLocal) {
+    bool HasAddressDiversity, bool IsDSOLocal, const MCExpr *DSExpr) {
   const Triple &TT = TM.getTargetTriple();
 
   // We only emit an IRELATIVE relocation if the target supports IRELATIVE and
@@ -2358,6 +2360,18 @@ const MCExpr *AArch64AsmPrinter::emitPAuthRelocationAsIRelative(
       MCSymbolRefExpr::create(EmuPAC, OutStreamer->getContext());
   OutStreamer->emitInstruction(MCInstBuilder(AArch64::B).addExpr(EmuPACRef),
                                *STI);
+
+  if (DSExpr) {
+    auto *PrePACInstExpr =
+        MCSymbolRefExpr::create(PrePACInst, OutStreamer->getContext());
+    OutStreamer->emitRelocDirective(*PrePACInstExpr, "R_AARCH64_INST32", DSExpr,
+                                    SMLoc(), *STI);
+  }
+
+  // We need a RET despite the above tail call because the deactivation symbol
+  // may replace it with a NOP.
+  OutStreamer->emitInstruction(MCInstBuilder(AArch64::RET).addReg(AArch64::LR),
+                               *STI);
   OutStreamer->popSection();
 
   return MCSymbolRefExpr::create(IFuncSym, OutStreamer->getContext());
@@ -2388,6 +2402,13 @@ AArch64AsmPrinter::lowerConstantPtrAuth(const ConstantPtrAuth &CPA) {
     Sym = MCConstantExpr::create(Offset.getSExtValue(), Ctx);
   }
 
+  const MCExpr *DSExpr = nullptr;
+  if (auto *DS = dyn_cast<GlobalValue>(CPA.getDeactivationSymbol())) {
+    if (isa<GlobalAlias>(DS))
+      return Sym;
+    DSExpr = MCSymbolRefExpr::create(getSymbol(DS), Ctx);
+  }
+
   uint64_t KeyID = CPA.getKey()->getZExtValue();
   // We later rely on valid KeyID value in AArch64PACKeyIDToString call from
   // AArch64AuthMCExpr::printImpl, so fail fast.
@@ -2404,9 +2425,13 @@ AArch64AsmPrinter::lowerConstantPtrAuth(const ConstantPtrAuth &CPA) {
   // Check if we need to represent this with an IRELATIVE and emit it if so.
   if (auto *IFuncSym = emitPAuthRelocationAsIRelative(
           Sym, Disc, AArch64PACKey::ID(KeyID), CPA.hasAddressDiscriminator(),
-          BaseGVB && BaseGVB->isDSOLocal()))
+          BaseGVB && BaseGVB->isDSOLocal(), DSExpr))
     return IFuncSym;
 
+  if (DSExpr)
+    report_fatal_error("deactivation symbols unsupported in constant "
+                       "expressions on this target");
+
   // Finally build the complete @AUTH expr.
   return AArch64AuthMCExpr::create(Sym, Disc, AArch64PACKey::ID(KeyID),
                                    CPA.hasAddressDiscriminator(), Ctx);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 12dd4cec85f59..58b98d8d93464 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -2946,9 +2946,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       if (NeedSign && isa<ConstantInt>(II->getArgOperand(4))) {
         auto *SignKey = cast<ConstantInt>(II->getArgOperand(3));
         auto *SignDisc = cast<ConstantInt>(II->getArgOperand(4));
-        auto *SignAddrDisc = ConstantPointerNull::get(Builder.getPtrTy());
+        auto *Null = ConstantPointerNull::get(Builder.getPtrTy());
         auto *NewCPA = ConstantPtrAuth::get(CPA->getPointer(), SignKey,
-                                            SignDisc, SignAddrDisc);
+                                            SignDisc, Null, Null);
         replaceInstUsesWith(
             *II, ConstantExpr::getPointerCast(NewCPA, II->getType()));
         return eraseInstFromFunction(*II);
diff --git a/llvm/lib/Transforms/Utils/ValueMapper.cpp b/llvm/lib/Transforms/Utils/ValueMapper.cpp
index 5e50536a99206..320bef6c8f240 100644
--- a/llvm/lib/Transforms/Utils/ValueMapper.cpp
+++ b/llvm/lib/Transforms/Utils/ValueMapper.cpp
@@ -526,8 +526,9 @@ Value *Mapper::mapValue(const Value *V) {
   if (isa<ConstantVector>(C))
     return getVM()[V] = ConstantVector::get(Ops);
   if (isa<ConstantPtrAuth>(C))
-    return getVM()[V] = ConstantPtrAuth::get(Ops[0], cast<ConstantInt>(Ops[1]),
-                                             cast<ConstantInt>(Ops[2]), Ops[3]);
+    return getVM()[V] =
+               ConstantPtrAuth::get(Ops[0], cast<ConstantInt>(Ops[1]),
+                                    cast<ConstantInt>(Ops[2]), Ops[3], Ops[4]);
   // If this is a no-operand constant, it must be because the type was remapped.
   if (isa<PoisonValue>(C))
     return getVM()[V] = PoisonValue::get(NewTy);
diff --git a/llvm/unittests/SandboxIR/SandboxIR...
[truncated]

Copy link

github-actions bot commented Mar 28, 2025

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff HEAD~1 HEAD --extensions cpp,h -- clang/lib/CodeGen/CGPointerAuth.cpp llvm/include/llvm/Bitcode/LLVMBitCodes.h llvm/include/llvm/IR/Constants.h llvm/include/llvm/SandboxIR/Constant.h llvm/lib/AsmParser/LLParser.cpp llvm/lib/Bitcode/Reader/BitcodeReader.cpp llvm/lib/IR/AsmWriter.cpp llvm/lib/IR/Constants.cpp llvm/lib/IR/ConstantsContext.h llvm/lib/IR/Core.cpp llvm/lib/SandboxIR/Constant.cpp llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp llvm/lib/Transforms/Utils/ValueMapper.cpp llvm/unittests/SandboxIR/SandboxIRTest.cpp llvm/unittests/Transforms/Utils/ValueMapperTest.cpp
View the diff from clang-format here.
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index c414ed4e3..f6e1976b5 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -437,7 +437,7 @@ enum ConstantsCodes {
   CST_CODE_CE_GEP_WITH_INRANGE = 31,  // [opty, flags, range, n x operands]
   CST_CODE_CE_GEP = 32,               // [opty, flags, n x operands]
   CST_CODE_PTRAUTH = 33,              // [ptr, key, disc, addrdisc]
-  CST_CODE_PTRAUTH2 = 34,             // [ptr, key, disc, addrdisc, DeactivationSymbol]
+  CST_CODE_PTRAUTH2 = 34, // [ptr, key, disc, addrdisc, DeactivationSymbol]
 };
 
 /// CastOpcodes - These are values used in the bitcode files to encode which
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 55b89ae9f..0891efef8 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -4218,7 +4218,8 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
   }
   case lltok::kw_ptrauth: {
     // ValID ::= 'ptrauth' '(' ptr @foo ',' i32 <key>
-    //                         (',' i64 <disc> (',' ptr addrdisc (',' ptr ds)? )? )? ')'
+    //                         (',' i64 <disc> (',' ptr addrdisc (',' ptr ds)?
+    //                         )? )? ')'
     Lex.Lex();
 
     Constant *Ptr, *Key;
@@ -4272,10 +4273,11 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
 
     if (DeactivationSymbol) {
       if (!DeactivationSymbol->getType()->isPointerTy())
-        return error(
-            ID.Loc, "constant ptrauth deactivation symbol must be a pointer");
+        return error(ID.Loc,
+                     "constant ptrauth deactivation symbol must be a pointer");
     } else {
-      DeactivationSymbol = ConstantPointerNull::get(PointerType::get(Context, 0));
+      DeactivationSymbol =
+          ConstantPointerNull::get(PointerType::get(Context, 0));
     }
 
     ID.ConstantVal =

pcc added a commit to pcc/llvm-project that referenced this pull request Apr 3, 2025
Deactivation symbol operands are supported in the code generator by
building on the previously added support for IRELATIVE relocations.

TODO:
- Fix broken test.
- Add bitcode and IR writer support.
- Add tests.

Pull Request: llvm#133537
pcc added a commit to pcc/llvm-project that referenced this pull request Apr 4, 2025
Deactivation symbol operands are supported in the code generator by
building on the previously added support for IRELATIVE relocations.

TODO:
- Fix broken test.
- Add bitcode and IR writer support.
- Add tests.

Pull Request: llvm#133537
pcc added 2 commits May 12, 2025 21:38
Created using spr 1.3.6-beta.1
Created using spr 1.3.6-beta.1
pcc added a commit to pcc/llvm-project that referenced this pull request May 24, 2025
Deactivation symbol operands are supported in the code generator by
building on the previously added support for IRELATIVE relocations.

TODO:
- Fix broken test.
- Add bitcode and IR writer support.
- Add tests.

Pull Request: llvm#133537
Copy link
Collaborator

@efriedma-quic efriedma-quic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing verifier checks?

unwrap<ConstantInt>(Disc), unwrap<Constant>(AddrDisc)));
unwrap<ConstantInt>(Disc), unwrap<Constant>(AddrDisc),
ConstantPointerNull::get(
cast<PointerType>(unwrap<Constant>(AddrDisc)->getType()))));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to extend the C API to give access to this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reckon that could be done in a followup if anyone needs it.

@pcc
Copy link
Contributor Author

pcc commented May 28, 2025

Missing verifier checks?

Right, I guess the new operand can either be null (no deactivation symbol) or a globalvariable.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:AArch64 clang:codegen IR generation bugs: mangling, exceptions, etc. clang Clang issues not falling into any other category llvm:instcombine llvm:ir llvm:SandboxIR llvm:transforms
Projects
Status: No status
Development

Successfully merging this pull request may close these issues.

3 participants