Skip to content

Add IR and codegen support for deactivation symbols. #133536

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: users/pcc/spr/main.add-ir-and-codegen-support-for-deactivation-symbols
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ class CallLowering {

/// True if this call results in convergent operations.
bool IsConvergent = true;

GlobalValue *DeactivationSymbol = nullptr;
};

/// Argument handling is mostly uniform between the four places that
Expand Down
5 changes: 5 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ struct MachineIRBuilderState {
MDNode *PCSections = nullptr;
/// MMRA Metadata to be set on any instruction we create.
MDNode *MMRA = nullptr;
Value *DS = nullptr;

/// \name Fields describing the insertion point.
/// @{
Expand Down Expand Up @@ -368,6 +369,7 @@ class MachineIRBuilder {
State.II = MI.getIterator();
setPCSections(MI.getPCSections());
setMMRAMetadata(MI.getMMRAMetadata());
setDeactivationSymbol(MI.getDeactivationSymbol());
}
/// @}

Expand Down Expand Up @@ -404,6 +406,9 @@ class MachineIRBuilder {
/// Set the PC sections metadata to \p MD for all the next build instructions.
void setMMRAMetadata(MDNode *MMRA) { State.MMRA = MMRA; }

Value *getDeactivationSymbol() { return State.DS; }
void setDeactivationSymbol(Value *DS) { State.DS = DS; }

/// Get the current instruction's MMRA metadata.
MDNode *getMMRAMetadata() { return State.MMRA; }

Expand Down
2 changes: 2 additions & 0 deletions llvm/include/llvm/CodeGen/ISDOpcodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1547,6 +1547,8 @@ enum NodeType {
// Outputs: Output Chain
CLEAR_CACHE,

DEACTIVATION_SYMBOL,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment


/// BUILTIN_OP_END - This must be the last enum value in this list.
/// The target-specific pre-isel opcode values start here.
BUILTIN_OP_END
Expand Down
9 changes: 9 additions & 0 deletions llvm/include/llvm/CodeGen/MachineInstr.h
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,15 @@ class MachineInstr
return nullptr;
}

// FIXME: Move to Info.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should fix the fixme, also needs mir print/parse

Value *DeactivationSymbol = nullptr;
Value *getDeactivationSymbol() const {
return DeactivationSymbol;
}
void setDeactivationSymbol(MachineFunction &MF, Value *DeactivationSymbol) {
this->DeactivationSymbol = DeactivationSymbol;
}

/// Helper to extract a CFI type hash if one has been added.
uint32_t getCFIType() const {
if (!Info)
Expand Down
27 changes: 22 additions & 5 deletions llvm/include/llvm/CodeGen/MachineInstrBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,29 +69,44 @@ enum {
} // end namespace RegState

/// Set of metadata that should be preserved when using BuildMI(). This provides
/// a more convenient way of preserving DebugLoc, PCSections and MMRA.
/// a more convenient way of preserving certain data from the original
/// instruction.
class MIMetadata {
public:
MIMetadata() = default;
MIMetadata(DebugLoc DL, MDNode *PCSections = nullptr, MDNode *MMRA = nullptr)
: DL(std::move(DL)), PCSections(PCSections), MMRA(MMRA) {}
MIMetadata(DebugLoc DL, MDNode *PCSections = nullptr, MDNode *MMRA = nullptr,
Value *DeactivationSymbol = nullptr)
: DL(std::move(DL)), PCSections(PCSections), MMRA(MMRA),
DeactivationSymbol(DeactivationSymbol) {}
MIMetadata(const DILocation *DI, MDNode *PCSections = nullptr,
MDNode *MMRA = nullptr)
: DL(DI), PCSections(PCSections), MMRA(MMRA) {}
explicit MIMetadata(const Instruction &From)
: DL(From.getDebugLoc()),
PCSections(From.getMetadata(LLVMContext::MD_pcsections)) {}
PCSections(From.getMetadata(LLVMContext::MD_pcsections)),
DeactivationSymbol(getDeactivationSymbol(&From)) {}
explicit MIMetadata(const MachineInstr &From)
: DL(From.getDebugLoc()), PCSections(From.getPCSections()) {}
: DL(From.getDebugLoc()), PCSections(From.getPCSections()),
DeactivationSymbol(From.getDeactivationSymbol()) {}

const DebugLoc &getDL() const { return DL; }
MDNode *getPCSections() const { return PCSections; }
MDNode *getMMRAMetadata() const { return MMRA; }
Value *getDeactivationSymbol() const { return DeactivationSymbol; }

private:
DebugLoc DL;
MDNode *PCSections = nullptr;
MDNode *MMRA = nullptr;
Value *DeactivationSymbol = nullptr;

static inline Value *getDeactivationSymbol(const Instruction *I) {
if (auto *CB = dyn_cast<CallBase>(I))
if (auto Bundle =
CB->getOperandBundle(llvm::LLVMContext::OB_deactivation_symbol))
return Bundle->Inputs[0].get();
return nullptr;
}
};

class MachineInstrBuilder {
Expand Down Expand Up @@ -347,6 +362,8 @@ class MachineInstrBuilder {
MI->setPCSections(*MF, MIMD.getPCSections());
if (MIMD.getMMRAMetadata())
MI->setMMRAMetadata(*MF, MIMD.getMMRAMetadata());
if (MIMD.getDeactivationSymbol())
MI->setDeactivationSymbol(*MF, MIMD.getDeactivationSymbol());
return *this;
}

Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/CodeGen/SelectionDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,7 @@ class SelectionDAG {
int64_t offset = 0, unsigned TargetFlags = 0) {
return getGlobalAddress(GV, DL, VT, offset, true, TargetFlags);
}
SDValue getDeactivationSymbol(const GlobalValue *GV);
SDValue getFrameIndex(int FI, EVT VT, bool isTarget = false);
SDValue getTargetFrameIndex(int FI, EVT VT) {
return getFrameIndex(FI, VT, true);
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/CodeGen/SelectionDAGISel.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ class SelectionDAGISel {
OPC_RecordChild7,
OPC_RecordMemRef,
OPC_CaptureGlueInput,
OPC_CaptureDeactivationSymbol,
OPC_MoveChild,
OPC_MoveChild0,
OPC_MoveChild1,
Expand Down
17 changes: 17 additions & 0 deletions llvm/include/llvm/CodeGen/SelectionDAGNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1932,6 +1932,23 @@ class GlobalAddressSDNode : public SDNode {
}
};

class DeactivationSymbolSDNode : public SDNode {
friend class SelectionDAG;

const GlobalValue *TheGlobal;

DeactivationSymbolSDNode(const GlobalValue *GV, SDVTList VTs)
: SDNode(ISD::DEACTIVATION_SYMBOL, 0, DebugLoc(), VTs),
TheGlobal(GV) {}

public:
const GlobalValue *getGlobal() const { return TheGlobal; }

static bool classof(const SDNode *N) {
return N->getOpcode() == ISD::DEACTIVATION_SYMBOL;
}
};

class FrameIndexSDNode : public SDNode {
friend class SelectionDAG;

Expand Down
6 changes: 6 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -4690,6 +4690,7 @@ class TargetLowering : public TargetLoweringBase {
SmallVector<SDValue, 4> InVals;
const ConstantInt *CFIType = nullptr;
SDValue ConvergenceControlToken;
GlobalValue *DeactivationSymbol = nullptr;

std::optional<PtrAuthInfo> PAI;

Expand Down Expand Up @@ -4835,6 +4836,11 @@ class TargetLowering : public TargetLoweringBase {
return *this;
}

CallLoweringInfo &setDeactivationSymbol(GlobalValue *Sym) {
DeactivationSymbol = Sym;
return *this;
}

ArgListTy &getArgs() {
return Args;
}
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/LLVMContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class LLVMContext {
OB_ptrauth = 7, // "ptrauth"
OB_kcfi = 8, // "kcfi"
OB_convergencectrl = 9, // "convergencectrl"
OB_deactivation_symbol = 10, // "deactivation-symbol"
};

/// getMDKindID - Return a unique non-zero ID for the specified metadata kind.
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/Target/Target.td
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,7 @@ class Instruction : InstructionEncoding {
// If so, make sure to override
// TargetInstrInfo::getInsertSubregLikeInputs.
bit variadicOpsAreDefs = false; // Are variadic operands definitions?
bit supportsDeactivationSymbol = false;

// Does the instruction have side effects that are not captured by any
// operands of the instruction or other flags?
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/CodeGen/GlobalISel/CallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,10 @@ bool CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &CB,
assert(Info.CFIType->getType()->isIntegerTy(32) && "Invalid CFI type");
}

if (auto Bundle = CB.getOperandBundle(LLVMContext::OB_deactivation_symbol)) {
Info.DeactivationSymbol = cast<GlobalValue>(Bundle->Inputs[0]);
}

Info.CB = &CB;
Info.KnownCallees = CB.getMetadata(LLVMContext::MD_callees);
Info.CallConv = CallConv;
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2860,6 +2860,9 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
}
}

if (auto Bundle = CI.getOperandBundle(LLVMContext::OB_deactivation_symbol))
MIB->setDeactivationSymbol(*MF, Bundle->Inputs[0].get());

return true;
}

Expand Down
6 changes: 4 additions & 2 deletions llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ void MachineIRBuilder::setMF(MachineFunction &MF) {
//------------------------------------------------------------------------------

MachineInstrBuilder MachineIRBuilder::buildInstrNoInsert(unsigned Opcode) {
return BuildMI(getMF(), {getDL(), getPCSections(), getMMRAMetadata()},
getTII().get(Opcode));
return BuildMI(
getMF(),
{getDL(), getPCSections(), getMMRAMetadata(), getDeactivationSymbol()},
getTII().get(Opcode));
}

MachineInstrBuilder MachineIRBuilder::insertInstr(MachineInstrBuilder MIB) {
Expand Down
6 changes: 5 additions & 1 deletion llvm/lib/CodeGen/MachineInstr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ MachineInstr::MachineInstr(MachineFunction &MF, const MCInstrDesc &TID,
MachineInstr::MachineInstr(MachineFunction &MF, const MachineInstr &MI)
: MCID(&MI.getDesc()), NumOperands(0), Flags(0), AsmPrinterFlags(0),
Info(MI.Info), DbgLoc(MI.getDebugLoc()), DebugInstrNum(0),
Opcode(MI.getOpcode()) {
Opcode(MI.getOpcode()), DeactivationSymbol(MI.getDeactivationSymbol()) {
assert(DbgLoc.hasTrivialDestructor() && "Expected trivial destructor");

CapOperands = OperandCapacity::get(MI.getNumOperands());
Expand Down Expand Up @@ -728,6 +728,8 @@ bool MachineInstr::isIdenticalTo(const MachineInstr &Other,
// Call instructions with different CFI types are not identical.
if (isCall() && getCFIType() != Other.getCFIType())
return false;
if (getDeactivationSymbol() != Other.getDeactivationSymbol())
return false;

return true;
}
Expand Down Expand Up @@ -2029,6 +2031,8 @@ void MachineInstr::print(raw_ostream &OS, ModuleSlotTracker &MST,
OS << ',';
OS << " cfi-type " << CFIType;
}
if (getDeactivationSymbol())
OS << ", deactivation-symbol " << getDeactivationSymbol()->getName();

if (DebugInstrNum) {
if (!FirstOp)
Expand Down
33 changes: 23 additions & 10 deletions llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
#include "InstrEmitter.h"
#include "SDNodeDbgValue.h"
#include "llvm/BinaryFormat/Dwarf.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/MachineConstantPool.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/StackMaps.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
Expand Down Expand Up @@ -61,6 +63,8 @@ static unsigned countOperands(SDNode *Node, unsigned NumExpUses,
unsigned N = Node->getNumOperands();
while (N && Node->getOperand(N - 1).getValueType() == MVT::Glue)
--N;
if (N && Node->getOperand(N - 1).getOpcode() == ISD::DEACTIVATION_SYMBOL)
--N; // Ignore deactivation symbol if it exists.
if (N && Node->getOperand(N - 1).getValueType() == MVT::Other)
--N; // Ignore chain if it exists.

Expand Down Expand Up @@ -1218,15 +1222,23 @@ EmitMachineNode(SDNode *Node, bool IsClone, bool IsCloned,
}
}

if (SDNode *GluedNode = Node->getGluedNode()) {
// FIXME: Possibly iterate over multiple glue nodes?
if (GluedNode->getOpcode() ==
~(unsigned)TargetOpcode::CONVERGENCECTRL_GLUE) {
Register VReg = getVR(GluedNode->getOperand(0), VRBaseMap);
MachineOperand MO = MachineOperand::CreateReg(VReg, /*isDef=*/false,
/*isImp=*/true);
MIB->addOperand(MO);
}
unsigned Op = Node->getNumOperands();
if (Op != 0 && Node->getOperand(Op - 1)->getOpcode() ==
~(unsigned)TargetOpcode::CONVERGENCECTRL_GLUE) {
Register VReg = getVR(Node->getOperand(Op - 1)->getOperand(0), VRBaseMap);
MachineOperand MO = MachineOperand::CreateReg(VReg, /*isDef=*/false,
/*isImp=*/true);
MIB->addOperand(MO);
Op--;
}

if (Op != 0 &&
Node->getOperand(Op - 1)->getOpcode() == ISD::DEACTIVATION_SYMBOL) {
MI->setDeactivationSymbol(
*MF, const_cast<GlobalValue *>(
cast<DeactivationSymbolSDNode>(Node->getOperand(Op - 1))
->getGlobal()));
Op--;
}

// Run post-isel target hook to adjust this instruction if needed.
Expand All @@ -1247,7 +1259,8 @@ EmitSpecialNode(SDNode *Node, bool IsClone, bool IsCloned,
llvm_unreachable("This target-independent node should have been selected!");
case ISD::EntryToken:
case ISD::MERGE_VALUES:
case ISD::TokenFactor: // fall thru
case ISD::TokenFactor:
case ISD::DEACTIVATION_SYMBOL:
break;
case ISD::CopyToReg: {
Register DestReg = cast<RegisterSDNode>(Node->getOperand(1))->getReg();
Expand Down
15 changes: 15 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1916,6 +1916,21 @@ SDValue SelectionDAG::getGlobalAddress(const GlobalValue *GV, const SDLoc &DL,
return SDValue(N, 0);
}

SDValue SelectionDAG::getDeactivationSymbol(const GlobalValue *GV) {
SDVTList VTs = getVTList(MVT::Untyped);
FoldingSetNodeID ID;
AddNodeIDNode(ID, ISD::DEACTIVATION_SYMBOL, VTs, {});
ID.AddPointer(GV);
void *IP = nullptr;
if (SDNode *E = FindNodeOrInsertPos(ID, SDLoc(), IP))
return SDValue(E, 0);

auto *N = newSDNode<DeactivationSymbolSDNode>(GV, VTs);
CSEMap.InsertNode(N, IP);
InsertNode(N);
return SDValue(N, 0);
}

SDValue SelectionDAG::getFrameIndex(int FI, EVT VT, bool isTarget) {
unsigned Opc = isTarget ? ISD::TargetFrameIndex : ISD::FrameIndex;
SDVTList VTs = getVTList(VT);
Expand Down
19 changes: 17 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/SelectionDAGTargetInfo.h"
#include "llvm/CodeGen/StackMaps.h"
#include "llvm/CodeGen/SwiftErrorValueTracking.h"
Expand Down Expand Up @@ -5283,6 +5284,13 @@ void SelectionDAGBuilder::visitTargetIntrinsic(const CallInst &I,
// Create the node.
SDValue Result;

if (auto Bundle = I.getOperandBundle(LLVMContext::OB_deactivation_symbol)) {
auto *Sym = Bundle->Inputs[0].get();
SDValue SDSym = getValue(Sym);
SDSym = DAG.getDeactivationSymbol(cast<GlobalValue>(Sym));
Ops.push_back(SDSym);
}

if (auto Bundle = I.getOperandBundle(LLVMContext::OB_convergencectrl)) {
auto *Token = Bundle->Inputs[0].get();
SDValue ConvControlToken = getValue(Token);
Expand Down Expand Up @@ -8916,6 +8924,11 @@ void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee,
ConvControlToken = getValue(Token);
}

GlobalValue *DeactivationSymbol = nullptr;
if (auto Bundle = CB.getOperandBundle(LLVMContext::OB_deactivation_symbol)) {
DeactivationSymbol = cast<GlobalValue>(Bundle->Inputs[0].get());
}

TargetLowering::CallLoweringInfo CLI(DAG);
CLI.setDebugLoc(getCurSDLoc())
.setChain(getRoot())
Expand All @@ -8925,7 +8938,8 @@ void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee,
.setIsPreallocated(
CB.countOperandBundlesOfType(LLVMContext::OB_preallocated) != 0)
.setCFIType(CFIType)
.setConvergenceControlToken(ConvControlToken);
.setConvergenceControlToken(ConvControlToken)
.setDeactivationSymbol(DeactivationSymbol);

// Set the pointer authentication info if we have it.
if (PAI) {
Expand Down Expand Up @@ -9542,7 +9556,8 @@ void SelectionDAGBuilder::visitCall(const CallInst &I) {
{LLVMContext::OB_deopt, LLVMContext::OB_funclet,
LLVMContext::OB_cfguardtarget, LLVMContext::OB_preallocated,
LLVMContext::OB_clang_arc_attachedcall, LLVMContext::OB_kcfi,
LLVMContext::OB_convergencectrl}) &&
LLVMContext::OB_convergencectrl,
LLVMContext::OB_deactivation_symbol}) &&
"Cannot lower calls with arbitrary operand bundles!");

SDValue Callee = getValue(I.getCalledOperand());
Expand Down
Loading
Loading