Skip to content

Add deactivation symbol operand to ConstantPtrAuth. #133537

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: users/pcc/spr/main.add-deactivation-symbol-operand-to-constantptrauth
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions clang/lib/CodeGen/CGPointerAuth.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -440,9 +440,9 @@ CodeGenModule::getConstantSignedPointer(llvm::Constant *Pointer, unsigned Key,
IntegerDiscriminator = llvm::ConstantInt::get(Int64Ty, 0);
}

return llvm::ConstantPtrAuth::get(Pointer,
llvm::ConstantInt::get(Int32Ty, Key),
IntegerDiscriminator, AddressDiscriminator);
return llvm::ConstantPtrAuth::get(
Pointer, llvm::ConstantInt::get(Int32Ty, Key), IntegerDiscriminator,
AddressDiscriminator, llvm::Constant::getNullValue(UnqualPtrTy));
}

/// Does a given PointerAuthScheme require us to sign a value
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/Bitcode/LLVMBitCodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ enum ConstantsCodes {
CST_CODE_CE_GEP_WITH_INRANGE = 31, // [opty, flags, range, n x operands]
CST_CODE_CE_GEP = 32, // [opty, flags, n x operands]
CST_CODE_PTRAUTH = 33, // [ptr, key, disc, addrdisc]
CST_CODE_PTRAUTH2 = 34, // [ptr, key, disc, addrdisc, DeactivationSymbol]
};

/// CastOpcodes - These are values used in the bitcode files to encode which
Expand Down
13 changes: 9 additions & 4 deletions llvm/include/llvm/IR/Constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -1031,10 +1031,10 @@ class ConstantPtrAuth final : public Constant {
friend struct ConstantPtrAuthKeyType;
friend class Constant;

constexpr static IntrusiveOperandsAllocMarker AllocMarker{4};
constexpr static IntrusiveOperandsAllocMarker AllocMarker{5};

ConstantPtrAuth(Constant *Ptr, ConstantInt *Key, ConstantInt *Disc,
Constant *AddrDisc);
Constant *AddrDisc, Constant *DeactivationSymbol);

void *operator new(size_t s) { return User::operator new(s, AllocMarker); }

Expand All @@ -1044,7 +1044,8 @@ class ConstantPtrAuth final : public Constant {
public:
/// Return a pointer signed with the specified parameters.
static ConstantPtrAuth *get(Constant *Ptr, ConstantInt *Key,
ConstantInt *Disc, Constant *AddrDisc);
ConstantInt *Disc, Constant *AddrDisc,
Constant *DeactivationSymbol);

/// Produce a new ptrauth expression signing the given value using
/// the same schema as is stored in one.
Expand Down Expand Up @@ -1076,6 +1077,10 @@ class ConstantPtrAuth final : public Constant {
return !getAddrDiscriminator()->isNullValue();
}

Constant *getDeactivationSymbol() const {
return cast<Constant>(Op<4>().get());
}

/// A constant value for the address discriminator which has special
/// significance to ctors/dtors lowering. Regular address discrimination can't
/// be applied for them since uses of llvm.global_{c|d}tors are disallowed
Expand Down Expand Up @@ -1103,7 +1108,7 @@ class ConstantPtrAuth final : public Constant {

template <>
struct OperandTraits<ConstantPtrAuth>
: public FixedNumOperandTraits<ConstantPtrAuth, 4> {};
: public FixedNumOperandTraits<ConstantPtrAuth, 5> {};

DEFINE_TRANSPARENT_OPERAND_ACCESSORS(ConstantPtrAuth, Constant)

Expand Down
5 changes: 4 additions & 1 deletion llvm/include/llvm/SandboxIR/Constant.h
Original file line number Diff line number Diff line change
Expand Up @@ -1384,7 +1384,8 @@ class ConstantPtrAuth final : public Constant {
public:
/// Return a pointer signed with the specified parameters.
static ConstantPtrAuth *get(Constant *Ptr, ConstantInt *Key,
ConstantInt *Disc, Constant *AddrDisc);
ConstantInt *Disc, Constant *AddrDisc,
Constant *DeactivationSymbol);
/// The pointer that is signed in this ptrauth signed pointer.
Constant *getPointer() const;

Expand All @@ -1399,6 +1400,8 @@ class ConstantPtrAuth final : public Constant {
/// the only global-initializer user of the ptrauth signed pointer.
Constant *getAddrDiscriminator() const;

Constant *getDeactivationSymbol() const;

/// Whether there is any non-null address discriminator.
bool hasAddressDiscriminator() const {
return cast<llvm::ConstantPtrAuth>(Val)->hasAddressDiscriminator();
Expand Down
29 changes: 21 additions & 8 deletions llvm/lib/AsmParser/LLParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4218,11 +4218,12 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
}
case lltok::kw_ptrauth: {
// ValID ::= 'ptrauth' '(' ptr @foo ',' i32 <key>
// (',' i64 <disc> (',' ptr addrdisc)? )? ')'
// (',' i64 <disc> (',' ptr addrdisc (',' ptr ds)? )? )? ')'
Lex.Lex();

Constant *Ptr, *Key;
Constant *Disc = nullptr, *AddrDisc = nullptr;
Constant *Disc = nullptr, *AddrDisc = nullptr,
*DeactivationSymbol = nullptr;

if (parseToken(lltok::lparen,
"expected '(' in constant ptrauth expression") ||
Expand All @@ -4231,11 +4232,14 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
"expected comma in constant ptrauth expression") ||
parseGlobalTypeAndValue(Key))
return true;
// If present, parse the optional disc/addrdisc.
if (EatIfPresent(lltok::comma))
if (parseGlobalTypeAndValue(Disc) ||
(EatIfPresent(lltok::comma) && parseGlobalTypeAndValue(AddrDisc)))
return true;
// If present, parse the optional disc/addrdisc/ds.
if (EatIfPresent(lltok::comma) && parseGlobalTypeAndValue(Disc))
return true;
if (EatIfPresent(lltok::comma) && parseGlobalTypeAndValue(AddrDisc))
return true;
if (EatIfPresent(lltok::comma) &&
parseGlobalTypeAndValue(DeactivationSymbol))
return true;
if (parseToken(lltok::rparen,
"expected ')' in constant ptrauth expression"))
return true;
Expand Down Expand Up @@ -4266,7 +4270,16 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
AddrDisc = ConstantPointerNull::get(PointerType::get(Context, 0));
}

ID.ConstantVal = ConstantPtrAuth::get(Ptr, KeyC, DiscC, AddrDisc);
if (DeactivationSymbol) {
if (!DeactivationSymbol->getType()->isPointerTy())
return error(
ID.Loc, "constant ptrauth deactivation symbol must be a pointer");
} else {
DeactivationSymbol = ConstantPointerNull::get(PointerType::get(Context, 0));
}

ID.ConstantVal =
ConstantPtrAuth::get(Ptr, KeyC, DiscC, AddrDisc, DeactivationSymbol);
ID.Kind = ValID::t_Constant;
return false;
}
Expand Down
18 changes: 17 additions & 1 deletion llvm/lib/Bitcode/Reader/BitcodeReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1603,7 +1603,13 @@ Expected<Value *> BitcodeReader::materializeValue(unsigned StartValID,
if (!Disc)
return error("ptrauth disc operand must be ConstantInt");

C = ConstantPtrAuth::get(ConstOps[0], Key, Disc, ConstOps[3]);
auto *DeactivationSymbol =
ConstOps.size() > 4 ? ConstOps[4]
: ConstantPointerNull::get(cast<PointerType>(
ConstOps[3]->getType()));

C = ConstantPtrAuth::get(ConstOps[0], Key, Disc, ConstOps[3],
DeactivationSymbol);
break;
}
case BitcodeConstant::NoCFIOpcode: {
Expand Down Expand Up @@ -3801,6 +3807,16 @@ Error BitcodeReader::parseConstants() {
(unsigned)Record[2], (unsigned)Record[3]});
break;
}
case bitc::CST_CODE_PTRAUTH2: {
if (Record.size() < 4)
return error("Invalid ptrauth record");
// Ptr, Key, Disc, AddrDisc, DeactivationSymbol
V = BitcodeConstant::create(
Alloc, CurTy, BitcodeConstant::ConstantPtrAuthOpcode,
{(unsigned)Record[0], (unsigned)Record[1], (unsigned)Record[2],
(unsigned)Record[3], (unsigned)Record[4]});
break;
}
}

assert(V->getType() == getTypeByID(CurTyID) && "Incorrect result type ID");
Expand Down
4 changes: 3 additions & 1 deletion llvm/lib/IR/AsmWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1658,12 +1658,14 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
if (const ConstantPtrAuth *CPA = dyn_cast<ConstantPtrAuth>(CV)) {
Out << "ptrauth (";

// ptrauth (ptr CST, i32 KEY[, i64 DISC[, ptr ADDRDISC]?]?)
// ptrauth (ptr CST, i32 KEY[, i64 DISC[, ptr ADDRDISC[, ptr DS]?]?]?)
unsigned NumOpsToWrite = 2;
if (!CPA->getOperand(2)->isNullValue())
NumOpsToWrite = 3;
if (!CPA->getOperand(3)->isNullValue())
NumOpsToWrite = 4;
if (!CPA->getOperand(4)->isNullValue())
NumOpsToWrite = 5;

ListSeparator LS;
for (unsigned i = 0, e = NumOpsToWrite; i != e; ++i) {
Expand Down
12 changes: 8 additions & 4 deletions llvm/lib/IR/Constants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2056,19 +2056,22 @@ Value *NoCFIValue::handleOperandChangeImpl(Value *From, Value *To) {
//

ConstantPtrAuth *ConstantPtrAuth::get(Constant *Ptr, ConstantInt *Key,
ConstantInt *Disc, Constant *AddrDisc) {
Constant *ArgVec[] = {Ptr, Key, Disc, AddrDisc};
ConstantInt *Disc, Constant *AddrDisc,
Constant *DeactivationSymbol) {
Constant *ArgVec[] = {Ptr, Key, Disc, AddrDisc, DeactivationSymbol};
ConstantPtrAuthKeyType MapKey(ArgVec);
LLVMContextImpl *pImpl = Ptr->getContext().pImpl;
return pImpl->ConstantPtrAuths.getOrCreate(Ptr->getType(), MapKey);
}

ConstantPtrAuth *ConstantPtrAuth::getWithSameSchema(Constant *Pointer) const {
return get(Pointer, getKey(), getDiscriminator(), getAddrDiscriminator());
return get(Pointer, getKey(), getDiscriminator(), getAddrDiscriminator(),
getDeactivationSymbol());
}

ConstantPtrAuth::ConstantPtrAuth(Constant *Ptr, ConstantInt *Key,
ConstantInt *Disc, Constant *AddrDisc)
ConstantInt *Disc, Constant *AddrDisc,
Constant *DeactivationSymbol)
: Constant(Ptr->getType(), Value::ConstantPtrAuthVal, AllocMarker) {
assert(Ptr->getType()->isPointerTy());
assert(Key->getBitWidth() == 32);
Expand All @@ -2078,6 +2081,7 @@ ConstantPtrAuth::ConstantPtrAuth(Constant *Ptr, ConstantInt *Key,
setOperand(1, Key);
setOperand(2, Disc);
setOperand(3, AddrDisc);
setOperand(4, DeactivationSymbol);
}

/// Remove the constant from the constant table.
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/IR/ConstantsContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,8 @@ struct ConstantPtrAuthKeyType {

ConstantPtrAuth *create(TypeClass *Ty) const {
return new ConstantPtrAuth(Operands[0], cast<ConstantInt>(Operands[1]),
cast<ConstantInt>(Operands[2]), Operands[3]);
cast<ConstantInt>(Operands[2]), Operands[3],
Operands[4]);
}
};

Expand Down
4 changes: 3 additions & 1 deletion llvm/lib/IR/Core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1699,7 +1699,9 @@ LLVMValueRef LLVMConstantPtrAuth(LLVMValueRef Ptr, LLVMValueRef Key,
LLVMValueRef Disc, LLVMValueRef AddrDisc) {
return wrap(ConstantPtrAuth::get(
unwrap<Constant>(Ptr), unwrap<ConstantInt>(Key),
unwrap<ConstantInt>(Disc), unwrap<Constant>(AddrDisc)));
unwrap<ConstantInt>(Disc), unwrap<Constant>(AddrDisc),
ConstantPointerNull::get(
cast<PointerType>(unwrap<Constant>(AddrDisc)->getType()))));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to extend the C API to give access to this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reckon that could be done in a followup if anyone needs it.

}

/*-- Opcode mapping */
Expand Down
11 changes: 9 additions & 2 deletions llvm/lib/SandboxIR/Constant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,10 +443,12 @@ PointerType *NoCFIValue::getType() const {
}

ConstantPtrAuth *ConstantPtrAuth::get(Constant *Ptr, ConstantInt *Key,
ConstantInt *Disc, Constant *AddrDisc) {
ConstantInt *Disc, Constant *AddrDisc,
Constant *DeactivationSymbol) {
auto *LLVMC = llvm::ConstantPtrAuth::get(
cast<llvm::Constant>(Ptr->Val), cast<llvm::ConstantInt>(Key->Val),
cast<llvm::ConstantInt>(Disc->Val), cast<llvm::Constant>(AddrDisc->Val));
cast<llvm::ConstantInt>(Disc->Val), cast<llvm::Constant>(AddrDisc->Val),
cast<llvm::Constant>(DeactivationSymbol->Val));
return cast<ConstantPtrAuth>(Ptr->getContext().getOrCreateConstant(LLVMC));
}

Expand All @@ -470,6 +472,11 @@ Constant *ConstantPtrAuth::getAddrDiscriminator() const {
cast<llvm::ConstantPtrAuth>(Val)->getAddrDiscriminator());
}

Constant *ConstantPtrAuth::getDeactivationSymbol() const {
return Ctx.getOrCreateConstant(
cast<llvm::ConstantPtrAuth>(Val)->getDeactivationSymbol());
}

ConstantPtrAuth *ConstantPtrAuth::getWithSameSchema(Constant *Pointer) const {
auto *LLVMC = cast<llvm::ConstantPtrAuth>(Val)->getWithSameSchema(
cast<llvm::Constant>(Pointer->Val));
Expand Down
37 changes: 31 additions & 6 deletions llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ class AArch64AsmPrinter : public AsmPrinter {

const MCExpr *emitPAuthRelocationAsIRelative(
const MCExpr *Target, uint16_t Disc, AArch64PACKey::ID KeyID,
bool HasAddressDiversity, bool IsDSOLocal);
bool HasAddressDiversity, bool IsDSOLocal, const MCExpr *DSExpr);

/// tblgen'erated driver function for lowering simple MI->MC
/// pseudo instructions.
Expand Down Expand Up @@ -2301,15 +2301,17 @@ static void emitAddress(MCStreamer &Streamer, MCRegister Reg,
}

static bool targetSupportsPAuthRelocation(const Triple &TT,
const MCExpr *Target) {
const MCExpr *Target,
const MCExpr *DSExpr) {
// No released version of glibc supports PAuth relocations.
if (TT.isOSGlibc())
return false;

// We emit PAuth constants as IRELATIVE relocations in cases where the
// constant cannot be represented as a PAuth relocation:
// 1) The signed value is not a symbol.
return !isa<MCConstantExpr>(Target);
// 1) There is a deactivation symbol.
// 2) The signed value is not a symbol.
return !DSExpr && !isa<MCConstantExpr>(Target);
}

static bool targetSupportsIRelativeRelocation(const Triple &TT) {
Expand All @@ -2326,7 +2328,7 @@ static bool targetSupportsIRelativeRelocation(const Triple &TT) {

const MCExpr *AArch64AsmPrinter::emitPAuthRelocationAsIRelative(
const MCExpr *Target, uint16_t Disc, AArch64PACKey::ID KeyID,
bool HasAddressDiversity, bool IsDSOLocal) {
bool HasAddressDiversity, bool IsDSOLocal, const MCExpr *DSExpr) {
const Triple &TT = TM.getTargetTriple();

// We only emit an IRELATIVE relocation if the target supports IRELATIVE and
Expand Down Expand Up @@ -2388,6 +2390,18 @@ const MCExpr *AArch64AsmPrinter::emitPAuthRelocationAsIRelative(
MCSymbolRefExpr::create(EmuPAC, OutStreamer->getContext());
OutStreamer->emitInstruction(MCInstBuilder(AArch64::B).addExpr(EmuPACRef),
*STI);

if (DSExpr) {
auto *PrePACInstExpr =
MCSymbolRefExpr::create(PrePACInst, OutStreamer->getContext());
OutStreamer->emitRelocDirective(*PrePACInstExpr, "R_AARCH64_INST32", DSExpr,
SMLoc(), *STI);
}

// We need a RET despite the above tail call because the deactivation symbol
// may replace it with a NOP.
OutStreamer->emitInstruction(MCInstBuilder(AArch64::RET).addReg(AArch64::LR),
*STI);
OutStreamer->popSection();

return MCSymbolRefExpr::create(IRelativeSym, AArch64MCExpr::VK_FUNCINIT,
Expand Down Expand Up @@ -2419,6 +2433,13 @@ AArch64AsmPrinter::lowerConstantPtrAuth(const ConstantPtrAuth &CPA) {
Sym = MCConstantExpr::create(Offset.getSExtValue(), Ctx);
}

const MCExpr *DSExpr = nullptr;
if (auto *DS = dyn_cast<GlobalValue>(CPA.getDeactivationSymbol())) {
if (isa<GlobalAlias>(DS))
return Sym;
DSExpr = MCSymbolRefExpr::create(getSymbol(DS), Ctx);
}

uint64_t KeyID = CPA.getKey()->getZExtValue();
// We later rely on valid KeyID value in AArch64PACKeyIDToString call from
// AArch64AuthMCExpr::printImpl, so fail fast.
Expand All @@ -2435,9 +2456,13 @@ AArch64AsmPrinter::lowerConstantPtrAuth(const ConstantPtrAuth &CPA) {
// Check if we need to represent this with an IRELATIVE and emit it if so.
if (auto *IFuncSym = emitPAuthRelocationAsIRelative(
Sym, Disc, AArch64PACKey::ID(KeyID), CPA.hasAddressDiscriminator(),
BaseGVB && BaseGVB->isDSOLocal()))
BaseGVB && BaseGVB->isDSOLocal(), DSExpr))
return IFuncSym;

if (DSExpr)
report_fatal_error("deactivation symbols unsupported in constant "
"expressions on this target");

// Finally build the complete @AUTH expr.
return AArch64AuthMCExpr::create(Sym, Disc, AArch64PACKey::ID(KeyID),
CPA.hasAddressDiscriminator(), Ctx);
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2979,9 +2979,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
if (NeedSign && isa<ConstantInt>(II->getArgOperand(4))) {
auto *SignKey = cast<ConstantInt>(II->getArgOperand(3));
auto *SignDisc = cast<ConstantInt>(II->getArgOperand(4));
auto *SignAddrDisc = ConstantPointerNull::get(Builder.getPtrTy());
auto *Null = ConstantPointerNull::get(Builder.getPtrTy());
auto *NewCPA = ConstantPtrAuth::get(CPA->getPointer(), SignKey,
SignDisc, SignAddrDisc);
SignDisc, Null, Null);
replaceInstUsesWith(
*II, ConstantExpr::getPointerCast(NewCPA, II->getType()));
return eraseInstFromFunction(*II);
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Transforms/Utils/ValueMapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -527,8 +527,9 @@ Value *Mapper::mapValue(const Value *V) {
if (isa<ConstantVector>(C))
return getVM()[V] = ConstantVector::get(Ops);
if (isa<ConstantPtrAuth>(C))
return getVM()[V] = ConstantPtrAuth::get(Ops[0], cast<ConstantInt>(Ops[1]),
cast<ConstantInt>(Ops[2]), Ops[3]);
return getVM()[V] =
ConstantPtrAuth::get(Ops[0], cast<ConstantInt>(Ops[1]),
cast<ConstantInt>(Ops[2]), Ops[3], Ops[4]);
// If this is a no-operand constant, it must be because the type was remapped.
if (isa<PoisonValue>(C))
return getVM()[V] = PoisonValue::get(NewTy);
Expand Down
2 changes: 1 addition & 1 deletion llvm/unittests/SandboxIR/SandboxIRTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1398,7 +1398,7 @@ define ptr @foo() {
// Check get(), getKey(), getDiscriminator(), getAddrDiscriminator().
auto *NewPtrAuth = sandboxir::ConstantPtrAuth::get(
&F, PtrAuth->getKey(), PtrAuth->getDiscriminator(),
PtrAuth->getAddrDiscriminator());
PtrAuth->getAddrDiscriminator(), PtrAuth->getDeactivationSymbol());
EXPECT_EQ(NewPtrAuth, PtrAuth);
// Check hasAddressDiscriminator().
EXPECT_EQ(PtrAuth->hasAddressDiscriminator(),
Expand Down
Loading
Loading