Skip to content

Commit cad9057

Browse files
committed
Add IR and codegen support for deactivation symbols.
Deactivation symbols are a mechanism for allowing object files to disable specific instructions in other object files at link time. The initial use case is for pointer field protection. For more information, see the RFC: https://discourse.llvm.org/t/rfc-deactivation-symbols/85556 TODO: - Add tests. Pull Request: llvm#133536
1 parent 85a1228 commit cad9057

25 files changed

+197
-24
lines changed

llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@ class CallLowering {
161161

162162
/// True if this call results in convergent operations.
163163
bool IsConvergent = true;
164+
165+
GlobalValue *DeactivationSymbol = nullptr;
164166
};
165167

166168
/// Argument handling is mostly uniform between the four places that

llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ struct MachineIRBuilderState {
5555
MDNode *PCSections = nullptr;
5656
/// MMRA Metadata to be set on any instruction we create.
5757
MDNode *MMRA = nullptr;
58+
Value *DS = nullptr;
5859

5960
/// \name Fields describing the insertion point.
6061
/// @{
@@ -368,6 +369,7 @@ class MachineIRBuilder {
368369
State.II = MI.getIterator();
369370
setPCSections(MI.getPCSections());
370371
setMMRAMetadata(MI.getMMRAMetadata());
372+
setDeactivationSymbol(MI.getDeactivationSymbol());
371373
}
372374
/// @}
373375

@@ -404,6 +406,9 @@ class MachineIRBuilder {
404406
/// Set the PC sections metadata to \p MD for all the next build instructions.
405407
void setMMRAMetadata(MDNode *MMRA) { State.MMRA = MMRA; }
406408

409+
Value *getDeactivationSymbol() { return State.DS; }
410+
void setDeactivationSymbol(Value *DS) { State.DS = DS; }
411+
407412
/// Get the current instruction's MMRA metadata.
408413
MDNode *getMMRAMetadata() { return State.MMRA; }
409414

llvm/include/llvm/CodeGen/ISDOpcodes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1524,6 +1524,8 @@ enum NodeType {
15241524
// Outputs: Output Chain
15251525
CLEAR_CACHE,
15261526

1527+
DEACTIVATION_SYMBOL,
1528+
15271529
/// BUILTIN_OP_END - This must be the last enum value in this list.
15281530
/// The target-specific pre-isel opcode values start here.
15291531
BUILTIN_OP_END

llvm/include/llvm/CodeGen/MachineInstr.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -878,6 +878,15 @@ class MachineInstr
878878
return nullptr;
879879
}
880880

881+
// FIXME: Move to Info.
882+
Value *DeactivationSymbol = nullptr;
883+
Value *getDeactivationSymbol() const {
884+
return DeactivationSymbol;
885+
}
886+
void setDeactivationSymbol(MachineFunction &MF, Value *DeactivationSymbol) {
887+
this->DeactivationSymbol = DeactivationSymbol;
888+
}
889+
881890
/// Helper to extract a CFI type hash if one has been added.
882891
uint32_t getCFIType() const {
883892
if (!Info)

llvm/include/llvm/CodeGen/MachineInstrBuilder.h

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,29 +69,44 @@ enum {
6969
} // end namespace RegState
7070

7171
/// Set of metadata that should be preserved when using BuildMI(). This provides
72-
/// a more convenient way of preserving DebugLoc, PCSections and MMRA.
72+
/// a more convenient way of preserving certain data from the original
73+
/// instruction.
7374
class MIMetadata {
7475
public:
7576
MIMetadata() = default;
76-
MIMetadata(DebugLoc DL, MDNode *PCSections = nullptr, MDNode *MMRA = nullptr)
77-
: DL(std::move(DL)), PCSections(PCSections), MMRA(MMRA) {}
77+
MIMetadata(DebugLoc DL, MDNode *PCSections = nullptr, MDNode *MMRA = nullptr,
78+
Value *DeactivationSymbol = nullptr)
79+
: DL(std::move(DL)), PCSections(PCSections), MMRA(MMRA),
80+
DeactivationSymbol(DeactivationSymbol) {}
7881
MIMetadata(const DILocation *DI, MDNode *PCSections = nullptr,
7982
MDNode *MMRA = nullptr)
8083
: DL(DI), PCSections(PCSections), MMRA(MMRA) {}
8184
explicit MIMetadata(const Instruction &From)
8285
: DL(From.getDebugLoc()),
83-
PCSections(From.getMetadata(LLVMContext::MD_pcsections)) {}
86+
PCSections(From.getMetadata(LLVMContext::MD_pcsections)),
87+
DeactivationSymbol(getDeactivationSymbol(&From)) {}
8488
explicit MIMetadata(const MachineInstr &From)
85-
: DL(From.getDebugLoc()), PCSections(From.getPCSections()) {}
89+
: DL(From.getDebugLoc()), PCSections(From.getPCSections()),
90+
DeactivationSymbol(From.getDeactivationSymbol()) {}
8691

8792
const DebugLoc &getDL() const { return DL; }
8893
MDNode *getPCSections() const { return PCSections; }
8994
MDNode *getMMRAMetadata() const { return MMRA; }
95+
Value *getDeactivationSymbol() const { return DeactivationSymbol; }
9096

9197
private:
9298
DebugLoc DL;
9399
MDNode *PCSections = nullptr;
94100
MDNode *MMRA = nullptr;
101+
Value *DeactivationSymbol = nullptr;
102+
103+
static inline Value *getDeactivationSymbol(const Instruction *I) {
104+
if (auto *CB = dyn_cast<CallBase>(I))
105+
if (auto Bundle =
106+
CB->getOperandBundle(llvm::LLVMContext::OB_deactivation_symbol))
107+
return Bundle->Inputs[0].get();
108+
return nullptr;
109+
}
95110
};
96111

97112
class MachineInstrBuilder {
@@ -348,6 +363,8 @@ class MachineInstrBuilder {
348363
MI->setPCSections(*MF, MIMD.getPCSections());
349364
if (MIMD.getMMRAMetadata())
350365
MI->setMMRAMetadata(*MF, MIMD.getMMRAMetadata());
366+
if (MIMD.getDeactivationSymbol())
367+
MI->setDeactivationSymbol(*MF, MIMD.getDeactivationSymbol());
351368
return *this;
352369
}
353370

llvm/include/llvm/CodeGen/SelectionDAG.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -752,6 +752,7 @@ class SelectionDAG {
752752
int64_t offset = 0, unsigned TargetFlags = 0) {
753753
return getGlobalAddress(GV, DL, VT, offset, true, TargetFlags);
754754
}
755+
SDValue getDeactivationSymbol(const GlobalValue *GV);
755756
SDValue getFrameIndex(int FI, EVT VT, bool isTarget = false);
756757
SDValue getTargetFrameIndex(int FI, EVT VT) {
757758
return getFrameIndex(FI, VT, true);

llvm/include/llvm/CodeGen/SelectionDAGISel.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ class SelectionDAGISel {
152152
OPC_RecordChild7,
153153
OPC_RecordMemRef,
154154
OPC_CaptureGlueInput,
155+
OPC_CaptureDeactivationSymbol,
155156
OPC_MoveChild,
156157
OPC_MoveChild0,
157158
OPC_MoveChild1,

llvm/include/llvm/CodeGen/SelectionDAGNodes.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1928,6 +1928,23 @@ class GlobalAddressSDNode : public SDNode {
19281928
}
19291929
};
19301930

1931+
class DeactivationSymbolSDNode : public SDNode {
1932+
friend class SelectionDAG;
1933+
1934+
const GlobalValue *TheGlobal;
1935+
1936+
DeactivationSymbolSDNode(const GlobalValue *GV, SDVTList VTs)
1937+
: SDNode(ISD::DEACTIVATION_SYMBOL, 0, DebugLoc(), VTs),
1938+
TheGlobal(GV) {}
1939+
1940+
public:
1941+
const GlobalValue *getGlobal() const { return TheGlobal; }
1942+
1943+
static bool classof(const SDNode *N) {
1944+
return N->getOpcode() == ISD::DEACTIVATION_SYMBOL;
1945+
}
1946+
};
1947+
19311948
class FrameIndexSDNode : public SDNode {
19321949
friend class SelectionDAG;
19331950

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4612,6 +4612,7 @@ class TargetLowering : public TargetLoweringBase {
46124612
SmallVector<SDValue, 4> InVals;
46134613
const ConstantInt *CFIType = nullptr;
46144614
SDValue ConvergenceControlToken;
4615+
GlobalValue *DeactivationSymbol = nullptr;
46154616

46164617
std::optional<PtrAuthInfo> PAI;
46174618

@@ -4757,6 +4758,11 @@ class TargetLowering : public TargetLoweringBase {
47574758
return *this;
47584759
}
47594760

4761+
CallLoweringInfo &setDeactivationSymbol(GlobalValue *Sym) {
4762+
DeactivationSymbol = Sym;
4763+
return *this;
4764+
}
4765+
47604766
ArgListTy &getArgs() {
47614767
return Args;
47624768
}

llvm/include/llvm/IR/LLVMContext.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ class LLVMContext {
9696
OB_ptrauth = 7, // "ptrauth"
9797
OB_kcfi = 8, // "kcfi"
9898
OB_convergencectrl = 9, // "convergencectrl"
99+
OB_deactivation_symbol = 10, // "deactivation-symbol"
99100
};
100101

101102
/// getMDKindID - Return a unique non-zero ID for the specified metadata kind.

llvm/include/llvm/Target/Target.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,7 @@ class Instruction : InstructionEncoding {
682682
// If so, make sure to override
683683
// TargetInstrInfo::getInsertSubregLikeInputs.
684684
bit variadicOpsAreDefs = false; // Are variadic operands definitions?
685+
bit supportsDeactivationSymbol = false;
685686

686687
// Does the instruction have side effects that are not captured by any
687688
// operands of the instruction or other flags?

llvm/lib/CodeGen/GlobalISel/CallLowering.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,10 @@ bool CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &CB,
195195
assert(Info.CFIType->getType()->isIntegerTy(32) && "Invalid CFI type");
196196
}
197197

198+
if (auto Bundle = CB.getOperandBundle(LLVMContext::OB_deactivation_symbol)) {
199+
Info.DeactivationSymbol = cast<GlobalValue>(Bundle->Inputs[0]);
200+
}
201+
198202
Info.CB = &CB;
199203
Info.KnownCallees = CB.getMetadata(LLVMContext::MD_callees);
200204
Info.CallConv = CallConv;

llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2861,6 +2861,9 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
28612861
}
28622862
}
28632863

2864+
if (auto Bundle = CI.getOperandBundle(LLVMContext::OB_deactivation_symbol))
2865+
MIB->setDeactivationSymbol(*MF, Bundle->Inputs[0].get());
2866+
28642867
return true;
28652868
}
28662869

llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,10 @@ void MachineIRBuilder::setMF(MachineFunction &MF) {
3838
//------------------------------------------------------------------------------
3939

4040
MachineInstrBuilder MachineIRBuilder::buildInstrNoInsert(unsigned Opcode) {
41-
return BuildMI(getMF(), {getDL(), getPCSections(), getMMRAMetadata()},
42-
getTII().get(Opcode));
41+
return BuildMI(
42+
getMF(),
43+
{getDL(), getPCSections(), getMMRAMetadata(), getDeactivationSymbol()},
44+
getTII().get(Opcode));
4345
}
4446

4547
MachineInstrBuilder MachineIRBuilder::insertInstr(MachineInstrBuilder MIB) {

llvm/lib/CodeGen/MachineInstr.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ MachineInstr::MachineInstr(MachineFunction &MF, const MCInstrDesc &TID,
120120
MachineInstr::MachineInstr(MachineFunction &MF, const MachineInstr &MI)
121121
: MCID(&MI.getDesc()), NumOperands(0), Flags(0), AsmPrinterFlags(0),
122122
Info(MI.Info), DbgLoc(MI.getDebugLoc()), DebugInstrNum(0),
123-
Opcode(MI.getOpcode()) {
123+
Opcode(MI.getOpcode()), DeactivationSymbol(MI.getDeactivationSymbol()) {
124124
assert(DbgLoc.hasTrivialDestructor() && "Expected trivial destructor");
125125

126126
CapOperands = OperandCapacity::get(MI.getNumOperands());
@@ -728,6 +728,8 @@ bool MachineInstr::isIdenticalTo(const MachineInstr &Other,
728728
// Call instructions with different CFI types are not identical.
729729
if (isCall() && getCFIType() != Other.getCFIType())
730730
return false;
731+
if (getDeactivationSymbol() != Other.getDeactivationSymbol())
732+
return false;
731733

732734
return true;
733735
}
@@ -2009,6 +2011,8 @@ void MachineInstr::print(raw_ostream &OS, ModuleSlotTracker &MST,
20092011
OS << ',';
20102012
OS << " cfi-type " << CFIType;
20112013
}
2014+
if (getDeactivationSymbol())
2015+
OS << ", deactivation-symbol " << getDeactivationSymbol()->getName();
20122016

20132017
if (DebugInstrNum) {
20142018
if (!FirstOp)

llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@
1515
#include "InstrEmitter.h"
1616
#include "SDNodeDbgValue.h"
1717
#include "llvm/BinaryFormat/Dwarf.h"
18+
#include "llvm/CodeGen/ISDOpcodes.h"
1819
#include "llvm/CodeGen/MachineConstantPool.h"
1920
#include "llvm/CodeGen/MachineFunction.h"
2021
#include "llvm/CodeGen/MachineInstrBuilder.h"
2122
#include "llvm/CodeGen/MachineRegisterInfo.h"
23+
#include "llvm/CodeGen/SelectionDAGNodes.h"
2224
#include "llvm/CodeGen/StackMaps.h"
2325
#include "llvm/CodeGen/TargetInstrInfo.h"
2426
#include "llvm/CodeGen/TargetLowering.h"
@@ -61,6 +63,8 @@ static unsigned countOperands(SDNode *Node, unsigned NumExpUses,
6163
unsigned N = Node->getNumOperands();
6264
while (N && Node->getOperand(N - 1).getValueType() == MVT::Glue)
6365
--N;
66+
if (N && Node->getOperand(N - 1).getOpcode() == ISD::DEACTIVATION_SYMBOL)
67+
--N; // Ignore deactivation symbol if it exists.
6468
if (N && Node->getOperand(N - 1).getValueType() == MVT::Other)
6569
--N; // Ignore chain if it exists.
6670

@@ -1219,15 +1223,23 @@ EmitMachineNode(SDNode *Node, bool IsClone, bool IsCloned,
12191223
}
12201224
}
12211225

1222-
if (SDNode *GluedNode = Node->getGluedNode()) {
1223-
// FIXME: Possibly iterate over multiple glue nodes?
1224-
if (GluedNode->getOpcode() ==
1225-
~(unsigned)TargetOpcode::CONVERGENCECTRL_GLUE) {
1226-
Register VReg = getVR(GluedNode->getOperand(0), VRBaseMap);
1227-
MachineOperand MO = MachineOperand::CreateReg(VReg, /*isDef=*/false,
1228-
/*isImp=*/true);
1229-
MIB->addOperand(MO);
1230-
}
1226+
unsigned Op = Node->getNumOperands();
1227+
if (Op != 0 && Node->getOperand(Op - 1)->getOpcode() ==
1228+
~(unsigned)TargetOpcode::CONVERGENCECTRL_GLUE) {
1229+
Register VReg = getVR(Node->getOperand(Op - 1)->getOperand(0), VRBaseMap);
1230+
MachineOperand MO = MachineOperand::CreateReg(VReg, /*isDef=*/false,
1231+
/*isImp=*/true);
1232+
MIB->addOperand(MO);
1233+
Op--;
1234+
}
1235+
1236+
if (Op != 0 &&
1237+
Node->getOperand(Op - 1)->getOpcode() == ISD::DEACTIVATION_SYMBOL) {
1238+
MI->setDeactivationSymbol(
1239+
*MF, const_cast<GlobalValue *>(
1240+
cast<DeactivationSymbolSDNode>(Node->getOperand(Op - 1))
1241+
->getGlobal()));
1242+
Op--;
12311243
}
12321244

12331245
// Run post-isel target hook to adjust this instruction if needed.
@@ -1248,7 +1260,8 @@ EmitSpecialNode(SDNode *Node, bool IsClone, bool IsCloned,
12481260
llvm_unreachable("This target-independent node should have been selected!");
12491261
case ISD::EntryToken:
12501262
case ISD::MERGE_VALUES:
1251-
case ISD::TokenFactor: // fall thru
1263+
case ISD::TokenFactor:
1264+
case ISD::DEACTIVATION_SYMBOL:
12521265
break;
12531266
case ISD::CopyToReg: {
12541267
Register DestReg = cast<RegisterSDNode>(Node->getOperand(1))->getReg();

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1913,6 +1913,21 @@ SDValue SelectionDAG::getGlobalAddress(const GlobalValue *GV, const SDLoc &DL,
19131913
return SDValue(N, 0);
19141914
}
19151915

1916+
SDValue SelectionDAG::getDeactivationSymbol(const GlobalValue *GV) {
1917+
SDVTList VTs = getVTList(MVT::Untyped);
1918+
FoldingSetNodeID ID;
1919+
AddNodeIDNode(ID, ISD::DEACTIVATION_SYMBOL, VTs, {});
1920+
ID.AddPointer(GV);
1921+
void *IP = nullptr;
1922+
if (SDNode *E = FindNodeOrInsertPos(ID, SDLoc(), IP))
1923+
return SDValue(E, 0);
1924+
1925+
auto *N = newSDNode<DeactivationSymbolSDNode>(GV, VTs);
1926+
CSEMap.InsertNode(N, IP);
1927+
InsertNode(N);
1928+
return SDValue(N, 0);
1929+
}
1930+
19161931
SDValue SelectionDAG::getFrameIndex(int FI, EVT VT, bool isTarget) {
19171932
unsigned Opc = isTarget ? ISD::TargetFrameIndex : ISD::FrameIndex;
19181933
SDVTList VTs = getVTList(VT);

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
#include "llvm/CodeGen/MachineOperand.h"
4646
#include "llvm/CodeGen/MachineRegisterInfo.h"
4747
#include "llvm/CodeGen/SelectionDAG.h"
48+
#include "llvm/CodeGen/SelectionDAGNodes.h"
4849
#include "llvm/CodeGen/SelectionDAGTargetInfo.h"
4950
#include "llvm/CodeGen/StackMaps.h"
5051
#include "llvm/CodeGen/SwiftErrorValueTracking.h"
@@ -5280,6 +5281,13 @@ void SelectionDAGBuilder::visitTargetIntrinsic(const CallInst &I,
52805281
// Create the node.
52815282
SDValue Result;
52825283

5284+
if (auto Bundle = I.getOperandBundle(LLVMContext::OB_deactivation_symbol)) {
5285+
auto *Sym = Bundle->Inputs[0].get();
5286+
SDValue SDSym = getValue(Sym);
5287+
SDSym = DAG.getDeactivationSymbol(cast<GlobalValue>(Sym));
5288+
Ops.push_back(SDSym);
5289+
}
5290+
52835291
if (auto Bundle = I.getOperandBundle(LLVMContext::OB_convergencectrl)) {
52845292
auto *Token = Bundle->Inputs[0].get();
52855293
SDValue ConvControlToken = getValue(Token);
@@ -8928,6 +8936,11 @@ void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee,
89288936
ConvControlToken = getValue(Token);
89298937
}
89308938

8939+
GlobalValue *DeactivationSymbol = nullptr;
8940+
if (auto Bundle = CB.getOperandBundle(LLVMContext::OB_deactivation_symbol)) {
8941+
DeactivationSymbol = cast<GlobalValue>(Bundle->Inputs[0].get());
8942+
}
8943+
89318944
TargetLowering::CallLoweringInfo CLI(DAG);
89328945
CLI.setDebugLoc(getCurSDLoc())
89338946
.setChain(getRoot())
@@ -8937,7 +8950,8 @@ void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee,
89378950
.setIsPreallocated(
89388951
CB.countOperandBundlesOfType(LLVMContext::OB_preallocated) != 0)
89398952
.setCFIType(CFIType)
8940-
.setConvergenceControlToken(ConvControlToken);
8953+
.setConvergenceControlToken(ConvControlToken)
8954+
.setDeactivationSymbol(DeactivationSymbol);
89418955

89428956
// Set the pointer authentication info if we have it.
89438957
if (PAI) {
@@ -9554,7 +9568,8 @@ void SelectionDAGBuilder::visitCall(const CallInst &I) {
95549568
{LLVMContext::OB_deopt, LLVMContext::OB_funclet,
95559569
LLVMContext::OB_cfguardtarget, LLVMContext::OB_preallocated,
95569570
LLVMContext::OB_clang_arc_attachedcall, LLVMContext::OB_kcfi,
9557-
LLVMContext::OB_convergencectrl}) &&
9571+
LLVMContext::OB_convergencectrl,
9572+
LLVMContext::OB_deactivation_symbol}) &&
95589573
"Cannot lower calls with arbitrary operand bundles!");
95599574

95609575
SDValue Callee = getValue(I.getCalledOperand());

0 commit comments

Comments
 (0)