-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[AArch64][SME] Store SME attributes in AArch64FunctionInfo (NFC) #142362
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
The SMEAttrs class is tiny (simply a wrapper around a bitmask). Constructing SMEAttrs from a llvm::Function is relatively expensive (as we have to redo the checks for every SME attribute). So let's just construct the SMEAttrs as part of the AArch64FunctionInfo and reuse the parsed attributes where possible.
@llvm/pr-subscribers-backend-aarch64 Author: Benjamin Maxwell (MacDue) ChangesThe SMEAttrs class is tiny (simply a wrapper around a bitmask). Constructing SMEAttrs from a llvm::Function is relatively expensive (as we have to redo the checks for every SME attribute). So let's just construct the SMEAttrs as part of the AArch64FunctionInfo and reuse the parsed attributes where possible. Full diff: https://github.com/llvm/llvm-project/pull/142362.diff 8 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64FastISel.cpp b/llvm/lib/Target/AArch64/AArch64FastISel.cpp
index 5ddf83f45ac69..bb7e6b662f80e 100644
--- a/llvm/lib/Target/AArch64/AArch64FastISel.cpp
+++ b/llvm/lib/Target/AArch64/AArch64FastISel.cpp
@@ -5198,7 +5198,8 @@ bool AArch64FastISel::fastSelectInstruction(const Instruction *I) {
FastISel *AArch64::createFastISel(FunctionLoweringInfo &FuncInfo,
const TargetLibraryInfo *LibInfo) {
- SMEAttrs CallerAttrs(*FuncInfo.Fn);
+ SMEAttrs CallerAttrs =
+ FuncInfo.MF->getInfo<AArch64FunctionInfo>()->getSMEFnAttrs();
if (CallerAttrs.hasZAState() || CallerAttrs.hasZT0State() ||
CallerAttrs.hasStreamingInterfaceOrBody() ||
CallerAttrs.hasStreamingCompatibleInterface() ||
diff --git a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
index 0f33e77d4eecc..c22dbb9bf0067 100644
--- a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
@@ -595,7 +595,7 @@ void AArch64FrameLowering::emitCalleeSavedGPRLocations(
MachineFunction &MF = *MBB.getParent();
MachineFrameInfo &MFI = MF.getFrameInfo();
AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
- SMEAttrs Attrs(MF.getFunction());
+ SMEAttrs Attrs = AFI->getSMEFnAttrs();
bool LocallyStreaming =
Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface();
@@ -2887,7 +2887,7 @@ bool enableMultiVectorSpillFill(const AArch64Subtarget &Subtarget,
if (DisableMultiVectorSpillFill)
return false;
- SMEAttrs FuncAttrs(MF.getFunction());
+ SMEAttrs FuncAttrs = MF.getInfo<AArch64FunctionInfo>()->getSMEFnAttrs();
bool IsLocallyStreaming =
FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface();
@@ -3210,7 +3210,7 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
// Find an available register to store value of VG to.
Reg1 = findScratchNonCalleeSaveRegister(&MBB);
assert(Reg1 != AArch64::NoRegister);
- SMEAttrs Attrs(MF.getFunction());
+ SMEAttrs Attrs = AFI->getSMEFnAttrs();
if (Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface() &&
AFI->getStreamingVGIdx() == std::numeric_limits<int>::max()) {
@@ -3539,12 +3539,13 @@ static std::optional<int> getLdStFrameID(const MachineInstr &MI,
void AArch64FrameLowering::determineStackHazardSlot(
MachineFunction &MF, BitVector &SavedRegs) const {
unsigned StackHazardSize = getStackHazardSize(MF);
+ auto *AFI = MF.getInfo<AArch64FunctionInfo>();
if (StackHazardSize == 0 || StackHazardSize % 16 != 0 ||
- MF.getInfo<AArch64FunctionInfo>()->hasStackHazardSlotIndex())
+ AFI->hasStackHazardSlotIndex())
return;
// Stack hazards are only needed in streaming functions.
- SMEAttrs Attrs(MF.getFunction());
+ SMEAttrs Attrs = AFI->getSMEFnAttrs();
if (!StackHazardInNonStreaming && Attrs.hasNonStreamingInterfaceAndBody())
return;
@@ -3581,7 +3582,7 @@ void AArch64FrameLowering::determineStackHazardSlot(
int ID = MFI.CreateStackObject(StackHazardSize, Align(16), false);
LLVM_DEBUG(dbgs() << "Created Hazard slot at " << ID << " size "
<< StackHazardSize << "\n");
- MF.getInfo<AArch64FunctionInfo>()->setStackHazardSlotIndex(ID);
+ AFI->setStackHazardSlotIndex(ID);
}
}
@@ -3734,8 +3735,7 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,
// changes, as we will need to spill the value of the VG register.
// For locally streaming functions, we spill both the streaming and
// non-streaming VG value.
- const Function &F = MF.getFunction();
- SMEAttrs Attrs(F);
+ SMEAttrs Attrs = AFI->getSMEFnAttrs();
if (requiresSaveVG(MF)) {
if (Attrs.hasStreamingBody() && !Attrs.hasStreamingInterface())
CSStackSize += 16;
@@ -3892,7 +3892,7 @@ bool AArch64FrameLowering::assignCalleeSavedSpillSlots(
// Insert VG into the list of CSRs, immediately before LR if saved.
if (requiresSaveVG(MF)) {
std::vector<CalleeSavedInfo> VGSaves;
- SMEAttrs Attrs(MF.getFunction());
+ SMEAttrs Attrs = AFI->getSMEFnAttrs();
auto VGInfo = CalleeSavedInfo(AArch64::VG);
VGInfo.setRestored(false);
@@ -4909,10 +4909,10 @@ static void emitVGSaveRestore(MachineBasicBlock::iterator II,
MI.getOpcode() != AArch64::VGRestorePseudo)
return;
- SMEAttrs FuncAttrs(MF->getFunction());
+ auto *AFI = MF->getInfo<AArch64FunctionInfo>();
+ SMEAttrs FuncAttrs = AFI->getSMEFnAttrs();
bool LocallyStreaming =
FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface();
- const AArch64FunctionInfo *AFI = MF->getInfo<AArch64FunctionInfo>();
int64_t VGFrameIdx =
LocallyStreaming ? AFI->getStreamingVGIdx() : AFI->getVGIdx();
@@ -5402,8 +5402,8 @@ static inline raw_ostream &operator<<(raw_ostream &OS, const StackAccess &SA) {
void AArch64FrameLowering::emitRemarks(
const MachineFunction &MF, MachineOptimizationRemarkEmitter *ORE) const {
- SMEAttrs Attrs(MF.getFunction());
- if (Attrs.hasNonStreamingInterfaceAndBody())
+ auto *AFI = MF.getInfo<AArch64FunctionInfo>();
+ if (AFI->getSMEFnAttrs().hasNonStreamingInterfaceAndBody())
return;
unsigned StackHazardSize = getStackHazardSize(MF);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index ae34e6b7dcc3c..4dd9c513120bb 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -7751,7 +7751,7 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
(void)Res;
}
- SMEAttrs Attrs(MF.getFunction());
+ SMEAttrs Attrs = FuncInfo->getSMEFnAttrs();
bool IsLocallyStreaming =
!Attrs.hasStreamingInterface() && Attrs.hasStreamingBody();
assert(Chain.getOpcode() == ISD::EntryToken && "Unexpected Chain value");
@@ -8105,7 +8105,7 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
// Create a 16 Byte TPIDR2 object. The dynamic buffer
// will be expanded and stored in the static object later using a pseudonode.
- if (SMEAttrs(MF.getFunction()).hasZAState()) {
+ if (Attrs.hasZAState()) {
TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
TPIDR2.FrameIndex = MFI.CreateStackObject(16, Align(16), false);
SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
@@ -8125,7 +8125,7 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
Chain = DAG.getNode(
AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
{/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
- } else if (SMEAttrs(MF.getFunction()).hasAgnosticZAInterface()) {
+ } else if (Attrs.hasAgnosticZAInterface()) {
// Call __arm_sme_state_size().
SDValue BufferSize =
DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
@@ -9610,7 +9610,7 @@ AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
// Emit SMSTOP before returning from a locally streaming function
- SMEAttrs FuncAttrs(MF.getFunction());
+ SMEAttrs FuncAttrs = FuncInfo->getSMEFnAttrs();
if (FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface()) {
if (FuncAttrs.hasStreamingCompatibleInterface()) {
Register Reg = FuncInfo->getPStateSMReg();
diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp
index 5bcff61cef4b1..4b04b80121ffa 100644
--- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp
@@ -100,6 +100,9 @@ AArch64FunctionInfo::AArch64FunctionInfo(const Function &F,
BranchTargetEnforcement = F.hasFnAttribute("branch-target-enforcement");
BranchProtectionPAuthLR = F.hasFnAttribute("branch-protection-pauth-lr");
+ // Parse the SME function attributes.
+ SMEFnAttrs = SMEAttrs(F);
+
// The default stack probe size is 4096 if the function has no
// stack-probe-size attribute. This is a safe default because it is the
// smallest possible guard page size.
diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
index d3026ca45c349..361d5ec3f2b22 100644
--- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
@@ -14,6 +14,7 @@
#define LLVM_LIB_TARGET_AARCH64_AARCH64MACHINEFUNCTIONINFO_H
#include "AArch64Subtarget.h"
+#include "Utils/AArch64SMEAttributes.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
@@ -245,6 +246,9 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
int64_t VGIdx = std::numeric_limits<int>::max();
int64_t StreamingVGIdx = std::numeric_limits<int>::max();
+ // Holds the SME function attributes (streaming mode, ZA/ZT0 state).
+ SMEAttrs SMEFnAttrs;
+
public:
AArch64FunctionInfo(const Function &F, const AArch64Subtarget *STI);
@@ -449,6 +453,8 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
StackHazardCSRSlotIndex = Index;
}
+ SMEAttrs getSMEFnAttrs() const { return SMEFnAttrs; }
+
unsigned getSRetReturnReg() const { return SRetReturnReg; }
void setSRetReturnReg(unsigned Reg) { SRetReturnReg = Reg; }
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
index 1afe23e637e8d..2310c19356968 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
@@ -648,9 +648,8 @@ bool AArch64RegisterInfo::hasBasePointer(const MachineFunction &MF) const {
// Since hasBasePointer() is called before we know if we have hazard padding
// or an emergency spill slot we need to enable the basepointer
// conservatively.
- if (AFI->hasStackHazardSlotIndex() ||
- (ST.getStreamingHazardSize() &&
- !SMEAttrs(MF.getFunction()).hasNonStreamingInterfaceAndBody())) {
+ if (ST.getStreamingHazardSize() &&
+ !AFI->getSMEFnAttrs().hasNonStreamingInterfaceAndBody()) {
return true;
}
diff --git a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
index 0d368b7c280c8..90f6fc2ea664b 100644
--- a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp
@@ -11,8 +11,8 @@
//===----------------------------------------------------------------------===//
#include "AArch64SelectionDAGInfo.h"
+#include "AArch64MachineFunctionInfo.h"
#include "AArch64TargetMachine.h"
-#include "Utils/AArch64SMEAttributes.h"
#define GET_SDNODE_DESC
#include "AArch64GenSDNodeInfo.inc"
@@ -227,7 +227,8 @@ SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemcpy(
return EmitMOPS(AArch64::MOPSMemoryCopyPseudo, DAG, DL, Chain, Dst, Src,
Size, Alignment, isVolatile, DstPtrInfo, SrcPtrInfo);
- SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
+ auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
+ SMEAttrs Attrs = AFI->getSMEFnAttrs();
if (LowerToSMERoutines && !Attrs.hasNonStreamingInterfaceAndBody())
return EmitStreamingCompatibleMemLibCall(DAG, DL, Chain, Dst, Src, Size,
RTLIB::MEMCPY);
@@ -246,7 +247,8 @@ SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemset(
Size, Alignment, isVolatile, DstPtrInfo,
MachinePointerInfo{});
- SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
+ auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
+ SMEAttrs Attrs = AFI->getSMEFnAttrs();
if (LowerToSMERoutines && !Attrs.hasNonStreamingInterfaceAndBody())
return EmitStreamingCompatibleMemLibCall(DAG, dl, Chain, Dst, Src, Size,
RTLIB::MEMSET);
@@ -264,7 +266,8 @@ SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemmove(
return EmitMOPS(AArch64::MOPSMemoryMovePseudo, DAG, dl, Chain, Dst, Src,
Size, Alignment, isVolatile, DstPtrInfo, SrcPtrInfo);
- SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
+ auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
+ SMEAttrs Attrs = AFI->getSMEFnAttrs();
if (LowerToSMERoutines && !Attrs.hasNonStreamingInterfaceAndBody())
return EmitStreamingCompatibleMemLibCall(DAG, dl, Chain, Dst, Src, Size,
RTLIB::MEMMOVE);
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
index 9bef102e8abf1..fd77571fe1c52 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
@@ -539,7 +539,7 @@ bool AArch64CallLowering::fallBackToDAGISel(const MachineFunction &MF) const {
return true;
}
- SMEAttrs Attrs(F);
+ SMEAttrs Attrs = MF.getInfo<AArch64FunctionInfo>()->getSMEFnAttrs();
if (Attrs.hasZAState() || Attrs.hasZT0State() ||
Attrs.hasStreamingInterfaceOrBody() ||
Attrs.hasStreamingCompatibleInterface())
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
…m#142362) The SMEAttrs class is tiny (simply a wrapper around a bitmask). Constructing SMEAttrs from a llvm::Function is relatively expensive (as we have to redo the checks for every SME attribute). So let's just construct the SMEAttrs as part of the AArch64FunctionInfo and reuse the parsed attributes where possible.
The SMEAttrs class is tiny (simply a wrapper around a bitmask). Constructing SMEAttrs from a llvm::Function is relatively expensive (as we have to redo the checks for every SME attribute). So let's just construct the SMEAttrs as part of the AArch64FunctionInfo and reuse the parsed attributes where possible.