Skip to content

Commit 20fc9f1

Browse files
committed
[AArch64][SME] Split SMECallAttrs out of SMEAttrs (NFC)
SMECallAttrs is a new helper class that holds all the SMEAttrs for a call: those of the caller, the callee, and the callsite. The interfaces that query which actions a call needs (e.g. whether it requires a streaming-mode change or a lazy save) have been moved from SMEAttrs to the SMECallAttrs class. The main motivation for this change is to make the split between caller, callee, and callsite attributes more apparent. Places that previously checked callsite attributes implicitly have been updated to make these checks explicit. Similarly, places known to check only callee or callsite attributes have been updated to make this clear.
1 parent 2590140 commit 20fc9f1

File tree

5 files changed

+206
-161
lines changed

5 files changed

+206
-161
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8652,6 +8652,16 @@ static void analyzeCallOperands(const AArch64TargetLowering &TLI,
86528652
}
86538653
}
86548654

8655+
static SMECallAttrs
8656+
getSMECallAttrs(const Function &Function,
8657+
const TargetLowering::CallLoweringInfo &CLI) {
8658+
if (CLI.CB)
8659+
return SMECallAttrs(*CLI.CB);
8660+
if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
8661+
return SMECallAttrs(SMEAttrs(Function), SMEAttrs(ES->getSymbol()));
8662+
return SMECallAttrs(SMEAttrs(Function), SMEAttrs(SMEAttrs::Normal));
8663+
}
8664+
86558665
bool AArch64TargetLowering::isEligibleForTailCallOptimization(
86568666
const CallLoweringInfo &CLI) const {
86578667
CallingConv::ID CalleeCC = CLI.CallConv;
@@ -8670,12 +8680,10 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
86708680

86718681
// SME Streaming functions are not eligible for TCO as they may require
86728682
// the streaming mode or ZA to be restored after returning from the call.
8673-
SMEAttrs CallerAttrs(MF.getFunction());
8674-
auto CalleeAttrs = CLI.CB ? SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal);
8675-
if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
8676-
CallerAttrs.requiresLazySave(CalleeAttrs) ||
8677-
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs) ||
8678-
CallerAttrs.hasStreamingBody())
8683+
SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, CLI);
8684+
if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
8685+
CallAttrs.requiresPreservingAllZAState() ||
8686+
CallAttrs.caller().hasStreamingBody())
86798687
return false;
86808688

86818689
// Functions using the C or Fast calling convention that have an SVE signature
@@ -8967,14 +8975,13 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
89678975
return TLI.LowerCallTo(CLI).second;
89688976
}
89698977

8970-
static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
8971-
const SMEAttrs &CalleeAttrs) {
8972-
if (!CallerAttrs.hasStreamingCompatibleInterface() ||
8973-
CallerAttrs.hasStreamingBody())
8978+
static unsigned getSMCondition(const SMECallAttrs &CallAttrs) {
8979+
if (!CallAttrs.caller().hasStreamingCompatibleInterface() ||
8980+
CallAttrs.caller().hasStreamingBody())
89748981
return AArch64SME::Always;
8975-
if (CalleeAttrs.hasNonStreamingInterface())
8982+
if (CallAttrs.calleeOrCallsite().hasNonStreamingInterface())
89768983
return AArch64SME::IfCallerIsStreaming;
8977-
if (CalleeAttrs.hasStreamingInterface())
8984+
if (CallAttrs.calleeOrCallsite().hasStreamingInterface())
89788985
return AArch64SME::IfCallerIsNonStreaming;
89798986

89808987
llvm_unreachable("Unsupported attributes");
@@ -9107,11 +9114,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91079114
}
91089115

91099116
// Determine whether we need any streaming mode changes.
9110-
SMEAttrs CalleeAttrs, CallerAttrs(MF.getFunction());
9111-
if (CLI.CB)
9112-
CalleeAttrs = SMEAttrs(*CLI.CB);
9113-
else if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
9114-
CalleeAttrs = SMEAttrs(ES->getSymbol());
9117+
SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI);
91159118

91169119
auto DescribeCallsite =
91179120
[&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
@@ -9126,9 +9129,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91269129
return R;
91279130
};
91289131

9129-
bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
9130-
bool RequiresSaveAllZA =
9131-
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs);
9132+
bool RequiresLazySave = CallAttrs.requiresLazySave();
9133+
bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
91329134
if (RequiresLazySave) {
91339135
const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
91349136
MachinePointerInfo MPI =
@@ -9156,18 +9158,18 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91569158
return DescribeCallsite(R) << " sets up a lazy save for ZA";
91579159
});
91589160
} else if (RequiresSaveAllZA) {
9159-
assert(!CalleeAttrs.hasSharedZAInterface() &&
9161+
assert(!CallAttrs.calleeOrCallsite().hasSharedZAInterface() &&
91609162
"Cannot share state that may not exist");
91619163
Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
91629164
/*IsSave=*/true);
91639165
}
91649166

91659167
SDValue PStateSM;
9166-
bool RequiresSMChange = CallerAttrs.requiresSMChange(CalleeAttrs);
9168+
bool RequiresSMChange = CallAttrs.requiresSMChange();
91679169
if (RequiresSMChange) {
9168-
if (CallerAttrs.hasStreamingInterfaceOrBody())
9170+
if (CallAttrs.caller().hasStreamingInterfaceOrBody())
91699171
PStateSM = DAG.getConstant(1, DL, MVT::i64);
9170-
else if (CallerAttrs.hasNonStreamingInterface())
9172+
else if (CallAttrs.caller().hasNonStreamingInterface())
91719173
PStateSM = DAG.getConstant(0, DL, MVT::i64);
91729174
else
91739175
PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
@@ -9184,7 +9186,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91849186

91859187
SDValue ZTFrameIdx;
91869188
MachineFrameInfo &MFI = MF.getFrameInfo();
9187-
bool ShouldPreserveZT0 = CallerAttrs.requiresPreservingZT0(CalleeAttrs);
9189+
bool ShouldPreserveZT0 = CallAttrs.requiresPreservingZT0();
91889190

91899191
// If the caller has ZT0 state which will not be preserved by the callee,
91909192
// spill ZT0 before the call.
@@ -9200,7 +9202,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
92009202

92019203
// If caller shares ZT0 but the callee is not shared ZA, we need to stop
92029204
// PSTATE.ZA before the call if there is no lazy-save active.
9203-
bool DisableZA = CallerAttrs.requiresDisablingZABeforeCall(CalleeAttrs);
9205+
bool DisableZA = CallAttrs.requiresDisablingZABeforeCall();
92049206
assert((!DisableZA || !RequiresLazySave) &&
92059207
"Lazy-save should have PSTATE.SM=1 on entry to the function");
92069208

@@ -9483,8 +9485,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
94839485
}
94849486

94859487
SDValue NewChain = changeStreamingMode(
9486-
DAG, DL, CalleeAttrs.hasStreamingInterface(), Chain, InGlue,
9487-
getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
9488+
DAG, DL, CallAttrs.calleeOrCallsite().hasStreamingInterface(), Chain,
9489+
InGlue, getSMCondition(CallAttrs), PStateSM);
94889490
Chain = NewChain.getValue(0);
94899491
InGlue = NewChain.getValue(1);
94909492
}
@@ -9663,8 +9665,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96639665
if (RequiresSMChange) {
96649666
assert(PStateSM && "Expected a PStateSM to be set");
96659667
Result = changeStreamingMode(
9666-
DAG, DL, !CalleeAttrs.hasStreamingInterface(), Result, InGlue,
9667-
getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
9668+
DAG, DL, !CallAttrs.calleeOrCallsite().hasStreamingInterface(), Result,
9669+
InGlue, getSMCondition(CallAttrs), PStateSM);
96689670

96699671
if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
96709672
InGlue = Result.getValue(1);
@@ -9674,7 +9676,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96749676
}
96759677
}
96769678

9677-
if (CallerAttrs.requiresEnablingZAAfterCall(CalleeAttrs))
9679+
if (CallAttrs.requiresEnablingZAAfterCall())
96789680
// Unconditionally resume ZA.
96799681
Result = DAG.getNode(
96809682
AArch64ISD::SMSTART, DL, MVT::Other, Result,
@@ -28573,12 +28575,10 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
2857328575

2857428576
// Checks to allow the use of SME instructions
2857528577
if (auto *Base = dyn_cast<CallBase>(&Inst)) {
28576-
auto CallerAttrs = SMEAttrs(*Inst.getFunction());
28577-
auto CalleeAttrs = SMEAttrs(*Base);
28578-
if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
28579-
CallerAttrs.requiresLazySave(CalleeAttrs) ||
28580-
CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
28581-
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
28578+
auto CallAttrs = SMECallAttrs(*Base);
28579+
if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
28580+
CallAttrs.requiresPreservingZT0() ||
28581+
CallAttrs.requiresPreservingAllZAState())
2858228582
return true;
2858328583
}
2858428584
return false;

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -268,22 +268,21 @@ const FeatureBitset AArch64TTIImpl::InlineInverseFeatures = {
268268

269269
bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
270270
const Function *Callee) const {
271-
SMEAttrs CallerAttrs(*Caller), CalleeAttrs(*Callee);
271+
SMECallAttrs CallAttrs(*Caller, *Callee);
272272

273273
// When inlining, we should consider the body of the function, not the
274274
// interface.
275-
if (CalleeAttrs.hasStreamingBody()) {
276-
CalleeAttrs.set(SMEAttrs::SM_Compatible, false);
277-
CalleeAttrs.set(SMEAttrs::SM_Enabled, true);
275+
if (CallAttrs.callee().hasStreamingBody()) {
276+
CallAttrs.callee().set(SMEAttrs::SM_Compatible, false);
277+
CallAttrs.callee().set(SMEAttrs::SM_Enabled, true);
278278
}
279279

280-
if (CalleeAttrs.isNewZA() || CalleeAttrs.isNewZT0())
280+
if (CallAttrs.callee().isNewZA() || CallAttrs.callee().isNewZT0())
281281
return false;
282282

283-
if (CallerAttrs.requiresLazySave(CalleeAttrs) ||
284-
CallerAttrs.requiresSMChange(CalleeAttrs) ||
285-
CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
286-
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs)) {
283+
if (CallAttrs.requiresLazySave() || CallAttrs.requiresSMChange() ||
284+
CallAttrs.requiresPreservingZT0() ||
285+
CallAttrs.requiresPreservingAllZAState()) {
287286
if (hasPossibleIncompatibleOps(Callee))
288287
return false;
289288
}
@@ -349,12 +348,14 @@ AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call,
349348
// streaming-mode change, and the call to G from F would also require a
350349
// streaming-mode change, then there is benefit to do the streaming-mode
351350
// change only once and avoid inlining of G into F.
351+
352352
SMEAttrs FAttrs(*F);
353-
SMEAttrs CalleeAttrs(Call);
354-
if (FAttrs.requiresSMChange(CalleeAttrs)) {
353+
SMECallAttrs CallAttrs(Call);
354+
355+
if (SMECallAttrs(FAttrs, CallAttrs.calleeOrCallsite()).requiresSMChange()) {
355356
if (F == Call.getCaller()) // (1)
356357
return CallPenaltyChangeSM * DefaultCallPenalty;
357-
if (FAttrs.requiresSMChange(SMEAttrs(*Call.getCaller()))) // (2)
358+
if (SMECallAttrs(FAttrs, CallAttrs.caller()).requiresSMChange()) // (2)
358359
return InlineCallPenaltyChangeSM * DefaultCallPenalty;
359360
}
360361

llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,14 @@ void SMEAttrs::set(unsigned M, bool Enable) {
2727
"ZA_New and SME_ABI_Routine are mutually exclusive");
2828

2929
assert(
30-
(!sharesZA() ||
31-
(isNewZA() ^ isInZA() ^ isInOutZA() ^ isOutZA() ^ isPreservesZA())) &&
30+
(isNewZA() + isInZA() + isOutZA() + isInOutZA() + isPreservesZA()) <= 1 &&
3231
"Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', "
3332
"'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive");
3433

3534
// ZT0 Attrs
3635
assert(
37-
(!sharesZT0() || (isNewZT0() ^ isInZT0() ^ isInOutZT0() ^ isOutZT0() ^
38-
isPreservesZT0())) &&
36+
(isNewZT0() + isInZT0() + isOutZT0() + isInOutZT0() + isPreservesZT0()) <=
37+
1 &&
3938
"Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', "
4039
"'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive");
4140

@@ -44,27 +43,6 @@ void SMEAttrs::set(unsigned M, bool Enable) {
4443
"interface");
4544
}
4645

47-
SMEAttrs::SMEAttrs(const CallBase &CB) {
48-
*this = SMEAttrs(CB.getAttributes());
49-
if (auto *F = CB.getCalledFunction()) {
50-
set(SMEAttrs(*F).Bitmask | SMEAttrs(F->getName()).Bitmask);
51-
}
52-
}
53-
54-
SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
55-
if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state")
56-
Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine);
57-
if (FuncName == "__arm_tpidr2_restore")
58-
Bitmask |= SMEAttrs::SM_Compatible | encodeZAState(StateValue::In) |
59-
SMEAttrs::SME_ABI_Routine;
60-
if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
61-
FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr")
62-
Bitmask |= SMEAttrs::SM_Compatible;
63-
if (FuncName == "__arm_sme_save" || FuncName == "__arm_sme_restore" ||
64-
FuncName == "__arm_sme_state_size")
65-
Bitmask |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine;
66-
}
67-
6846
SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
6947
Bitmask = 0;
7048
if (Attrs.hasFnAttr("aarch64_pstate_sm_enabled"))
@@ -99,17 +77,39 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
9977
Bitmask |= encodeZT0State(StateValue::New);
10078
}
10179

102-
bool SMEAttrs::requiresSMChange(const SMEAttrs &Callee) const {
103-
if (Callee.hasStreamingCompatibleInterface())
80+
void SMEAttrs::addKnownFunctionAttrs(StringRef FuncName) {
81+
unsigned KnownAttrs = SMEAttrs::Normal;
82+
if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state")
83+
KnownAttrs |= (SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine);
84+
if (FuncName == "__arm_tpidr2_restore")
85+
KnownAttrs |= SMEAttrs::SM_Compatible | encodeZAState(StateValue::In) |
86+
SMEAttrs::SME_ABI_Routine;
87+
if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
88+
FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr")
89+
KnownAttrs |= SMEAttrs::SM_Compatible;
90+
if (FuncName == "__arm_sme_save" || FuncName == "__arm_sme_restore" ||
91+
FuncName == "__arm_sme_state_size")
92+
KnownAttrs |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine;
93+
set(KnownAttrs, /*Enable=*/true);
94+
}
95+
96+
bool SMECallAttrs::requiresSMChange() const {
97+
if ((Callsite | Callee).hasStreamingCompatibleInterface())
10498
return false;
10599

106100
// Both non-streaming
107-
if (hasNonStreamingInterfaceAndBody() && Callee.hasNonStreamingInterface())
101+
if (Caller.hasNonStreamingInterfaceAndBody() &&
102+
(Callsite | Callee).hasNonStreamingInterface())
108103
return false;
109104

110105
// Both streaming
111-
if (hasStreamingInterfaceOrBody() && Callee.hasStreamingInterface())
106+
if (Caller.hasStreamingInterfaceOrBody() &&
107+
(Callsite | Callee).hasStreamingInterface())
112108
return false;
113109

114110
return true;
115111
}
112+
113+
SMECallAttrs::SMECallAttrs(const CallBase &CB)
114+
: SMECallAttrs(*CB.getFunction(), CB.getCalledFunction(),
115+
CB.getAttributes()) {}

0 commit comments

Comments
 (0)