Skip to content

Commit cadf652

Browse files
authored
[AArch64][SME] Split SMECallAttrs out of SMEAttrs (#137239)
SMECallAttrs is a new helper class that holds all the SMEAttrs for a call. The interfaces that query which actions are needed for the call (e.g. changing the streaming mode) have been moved to the SMECallAttrs class. The main motivation for this change is to make the split between the caller, callee, and callsite attributes more apparent. Before this change, we would always merge the callsite and callee attributes. The main reason for doing so was to handle indirect calls; however, we also occasionally used callsite attributes on direct calls in tests (mainly to avoid creating multiple function declarations). With this patch, we now handle indirect calls explicitly and disallow incompatible attributes on direct calls (so this patch is not entirely NFC, i.e. it is not a pure no-functional-change refactor).
1 parent 015093d commit cadf652

File tree

8 files changed

+269
-211
lines changed

8 files changed

+269
-211
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -8636,6 +8636,16 @@ static void analyzeCallOperands(const AArch64TargetLowering &TLI,
86368636
}
86378637
}
86388638

8639+
static SMECallAttrs
8640+
getSMECallAttrs(const Function &Function,
8641+
const TargetLowering::CallLoweringInfo &CLI) {
8642+
if (CLI.CB)
8643+
return SMECallAttrs(*CLI.CB);
8644+
if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
8645+
return SMECallAttrs(SMEAttrs(Function), SMEAttrs(ES->getSymbol()));
8646+
return SMECallAttrs(SMEAttrs(Function), SMEAttrs(SMEAttrs::Normal));
8647+
}
8648+
86398649
bool AArch64TargetLowering::isEligibleForTailCallOptimization(
86408650
const CallLoweringInfo &CLI) const {
86418651
CallingConv::ID CalleeCC = CLI.CallConv;
@@ -8654,12 +8664,10 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
86548664

86558665
// SME Streaming functions are not eligible for TCO as they may require
86568666
// the streaming mode or ZA to be restored after returning from the call.
8657-
SMEAttrs CallerAttrs(MF.getFunction());
8658-
auto CalleeAttrs = CLI.CB ? SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal);
8659-
if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
8660-
CallerAttrs.requiresLazySave(CalleeAttrs) ||
8661-
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs) ||
8662-
CallerAttrs.hasStreamingBody())
8667+
SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, CLI);
8668+
if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
8669+
CallAttrs.requiresPreservingAllZAState() ||
8670+
CallAttrs.caller().hasStreamingBody())
86638671
return false;
86648672

86658673
// Functions using the C or Fast calling convention that have an SVE signature
@@ -8951,14 +8959,13 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
89518959
return TLI.LowerCallTo(CLI).second;
89528960
}
89538961

8954-
static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
8955-
const SMEAttrs &CalleeAttrs) {
8956-
if (!CallerAttrs.hasStreamingCompatibleInterface() ||
8957-
CallerAttrs.hasStreamingBody())
8962+
static unsigned getSMCondition(const SMECallAttrs &CallAttrs) {
8963+
if (!CallAttrs.caller().hasStreamingCompatibleInterface() ||
8964+
CallAttrs.caller().hasStreamingBody())
89588965
return AArch64SME::Always;
8959-
if (CalleeAttrs.hasNonStreamingInterface())
8966+
if (CallAttrs.callee().hasNonStreamingInterface())
89608967
return AArch64SME::IfCallerIsStreaming;
8961-
if (CalleeAttrs.hasStreamingInterface())
8968+
if (CallAttrs.callee().hasStreamingInterface())
89628969
return AArch64SME::IfCallerIsNonStreaming;
89638970

89648971
llvm_unreachable("Unsupported attributes");
@@ -9091,11 +9098,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
90919098
}
90929099

90939100
// Determine whether we need any streaming mode changes.
9094-
SMEAttrs CalleeAttrs, CallerAttrs(MF.getFunction());
9095-
if (CLI.CB)
9096-
CalleeAttrs = SMEAttrs(*CLI.CB);
9097-
else if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
9098-
CalleeAttrs = SMEAttrs(ES->getSymbol());
9101+
SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI);
90999102

91009103
auto DescribeCallsite =
91019104
[&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
@@ -9110,9 +9113,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91109113
return R;
91119114
};
91129115

9113-
bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
9114-
bool RequiresSaveAllZA =
9115-
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs);
9116+
bool RequiresLazySave = CallAttrs.requiresLazySave();
9117+
bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
91169118
if (RequiresLazySave) {
91179119
const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
91189120
MachinePointerInfo MPI =
@@ -9140,18 +9142,18 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91409142
return DescribeCallsite(R) << " sets up a lazy save for ZA";
91419143
});
91429144
} else if (RequiresSaveAllZA) {
9143-
assert(!CalleeAttrs.hasSharedZAInterface() &&
9145+
assert(!CallAttrs.callee().hasSharedZAInterface() &&
91449146
"Cannot share state that may not exist");
91459147
Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
91469148
/*IsSave=*/true);
91479149
}
91489150

91499151
SDValue PStateSM;
9150-
bool RequiresSMChange = CallerAttrs.requiresSMChange(CalleeAttrs);
9152+
bool RequiresSMChange = CallAttrs.requiresSMChange();
91519153
if (RequiresSMChange) {
9152-
if (CallerAttrs.hasStreamingInterfaceOrBody())
9154+
if (CallAttrs.caller().hasStreamingInterfaceOrBody())
91539155
PStateSM = DAG.getConstant(1, DL, MVT::i64);
9154-
else if (CallerAttrs.hasNonStreamingInterface())
9156+
else if (CallAttrs.caller().hasNonStreamingInterface())
91559157
PStateSM = DAG.getConstant(0, DL, MVT::i64);
91569158
else
91579159
PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
@@ -9168,7 +9170,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91689170

91699171
SDValue ZTFrameIdx;
91709172
MachineFrameInfo &MFI = MF.getFrameInfo();
9171-
bool ShouldPreserveZT0 = CallerAttrs.requiresPreservingZT0(CalleeAttrs);
9173+
bool ShouldPreserveZT0 = CallAttrs.requiresPreservingZT0();
91729174

91739175
// If the caller has ZT0 state which will not be preserved by the callee,
91749176
// spill ZT0 before the call.
@@ -9184,7 +9186,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91849186

91859187
// If caller shares ZT0 but the callee is not shared ZA, we need to stop
91869188
// PSTATE.ZA before the call if there is no lazy-save active.
9187-
bool DisableZA = CallerAttrs.requiresDisablingZABeforeCall(CalleeAttrs);
9189+
bool DisableZA = CallAttrs.requiresDisablingZABeforeCall();
91889190
assert((!DisableZA || !RequiresLazySave) &&
91899191
"Lazy-save should have PSTATE.SM=1 on entry to the function");
91909192

@@ -9466,9 +9468,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
94669468
InGlue = Chain.getValue(1);
94679469
}
94689470

9469-
SDValue NewChain = changeStreamingMode(
9470-
DAG, DL, CalleeAttrs.hasStreamingInterface(), Chain, InGlue,
9471-
getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
9471+
SDValue NewChain =
9472+
changeStreamingMode(DAG, DL, CallAttrs.callee().hasStreamingInterface(),
9473+
Chain, InGlue, getSMCondition(CallAttrs), PStateSM);
94729474
Chain = NewChain.getValue(0);
94739475
InGlue = NewChain.getValue(1);
94749476
}
@@ -9647,8 +9649,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96479649
if (RequiresSMChange) {
96489650
assert(PStateSM && "Expected a PStateSM to be set");
96499651
Result = changeStreamingMode(
9650-
DAG, DL, !CalleeAttrs.hasStreamingInterface(), Result, InGlue,
9651-
getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
9652+
DAG, DL, !CallAttrs.callee().hasStreamingInterface(), Result, InGlue,
9653+
getSMCondition(CallAttrs), PStateSM);
96529654

96539655
if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
96549656
InGlue = Result.getValue(1);
@@ -9658,7 +9660,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96589660
}
96599661
}
96609662

9661-
if (CallerAttrs.requiresEnablingZAAfterCall(CalleeAttrs))
9663+
if (CallAttrs.requiresEnablingZAAfterCall())
96629664
// Unconditionally resume ZA.
96639665
Result = DAG.getNode(
96649666
AArch64ISD::SMSTART, DL, MVT::Other, Result,
@@ -28518,12 +28520,10 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
2851828520

2851928521
// Checks to allow the use of SME instructions
2852028522
if (auto *Base = dyn_cast<CallBase>(&Inst)) {
28521-
auto CallerAttrs = SMEAttrs(*Inst.getFunction());
28522-
auto CalleeAttrs = SMEAttrs(*Base);
28523-
if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
28524-
CallerAttrs.requiresLazySave(CalleeAttrs) ||
28525-
CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
28526-
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
28523+
auto CallAttrs = SMECallAttrs(*Base);
28524+
if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
28525+
CallAttrs.requiresPreservingZT0() ||
28526+
CallAttrs.requiresPreservingAllZAState())
2852728527
return true;
2852828528
}
2852928529
return false;

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -268,22 +268,21 @@ const FeatureBitset AArch64TTIImpl::InlineInverseFeatures = {
268268

269269
bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
270270
const Function *Callee) const {
271-
SMEAttrs CallerAttrs(*Caller), CalleeAttrs(*Callee);
271+
SMECallAttrs CallAttrs(*Caller, *Callee);
272272

273273
// When inlining, we should consider the body of the function, not the
274274
// interface.
275-
if (CalleeAttrs.hasStreamingBody()) {
276-
CalleeAttrs.set(SMEAttrs::SM_Compatible, false);
277-
CalleeAttrs.set(SMEAttrs::SM_Enabled, true);
275+
if (CallAttrs.callee().hasStreamingBody()) {
276+
CallAttrs.callee().set(SMEAttrs::SM_Compatible, false);
277+
CallAttrs.callee().set(SMEAttrs::SM_Enabled, true);
278278
}
279279

280-
if (CalleeAttrs.isNewZA() || CalleeAttrs.isNewZT0())
280+
if (CallAttrs.callee().isNewZA() || CallAttrs.callee().isNewZT0())
281281
return false;
282282

283-
if (CallerAttrs.requiresLazySave(CalleeAttrs) ||
284-
CallerAttrs.requiresSMChange(CalleeAttrs) ||
285-
CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
286-
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs)) {
283+
if (CallAttrs.requiresLazySave() || CallAttrs.requiresSMChange() ||
284+
CallAttrs.requiresPreservingZT0() ||
285+
CallAttrs.requiresPreservingAllZAState()) {
287286
if (hasPossibleIncompatibleOps(Callee))
288287
return false;
289288
}
@@ -349,12 +348,14 @@ AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call,
349348
// streaming-mode change, and the call to G from F would also require a
350349
// streaming-mode change, then there is benefit to do the streaming-mode
351350
// change only once and avoid inlining of G into F.
351+
352352
SMEAttrs FAttrs(*F);
353-
SMEAttrs CalleeAttrs(Call);
354-
if (FAttrs.requiresSMChange(CalleeAttrs)) {
353+
SMECallAttrs CallAttrs(Call);
354+
355+
if (SMECallAttrs(FAttrs, CallAttrs.callee()).requiresSMChange()) {
355356
if (F == Call.getCaller()) // (1)
356357
return CallPenaltyChangeSM * DefaultCallPenalty;
357-
if (FAttrs.requiresSMChange(SMEAttrs(*Call.getCaller()))) // (2)
358+
if (SMECallAttrs(FAttrs, CallAttrs.caller()).requiresSMChange()) // (2)
358359
return InlineCallPenaltyChangeSM * DefaultCallPenalty;
359360
}
360361

llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,14 @@ void SMEAttrs::set(unsigned M, bool Enable) {
2727
"ZA_New and SME_ABI_Routine are mutually exclusive");
2828

2929
assert(
30-
(!sharesZA() ||
31-
(isNewZA() ^ isInZA() ^ isInOutZA() ^ isOutZA() ^ isPreservesZA())) &&
30+
(isNewZA() + isInZA() + isOutZA() + isInOutZA() + isPreservesZA()) <= 1 &&
3231
"Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', "
3332
"'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive");
3433

3534
// ZT0 Attrs
3635
assert(
37-
(!sharesZT0() || (isNewZT0() ^ isInZT0() ^ isInOutZT0() ^ isOutZT0() ^
38-
isPreservesZT0())) &&
36+
(isNewZT0() + isInZT0() + isOutZT0() + isInOutZT0() + isPreservesZT0()) <=
37+
1 &&
3938
"Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', "
4039
"'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive");
4140

@@ -44,27 +43,6 @@ void SMEAttrs::set(unsigned M, bool Enable) {
4443
"interface");
4544
}
4645

47-
SMEAttrs::SMEAttrs(const CallBase &CB) {
48-
*this = SMEAttrs(CB.getAttributes());
49-
if (auto *F = CB.getCalledFunction()) {
50-
set(SMEAttrs(*F).Bitmask | SMEAttrs(F->getName()).Bitmask);
51-
}
52-
}
53-
54-
SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
55-
if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state")
56-
Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine);
57-
if (FuncName == "__arm_tpidr2_restore")
58-
Bitmask |= SMEAttrs::SM_Compatible | encodeZAState(StateValue::In) |
59-
SMEAttrs::SME_ABI_Routine;
60-
if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
61-
FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr")
62-
Bitmask |= SMEAttrs::SM_Compatible;
63-
if (FuncName == "__arm_sme_save" || FuncName == "__arm_sme_restore" ||
64-
FuncName == "__arm_sme_state_size")
65-
Bitmask |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine;
66-
}
67-
6846
SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
6947
Bitmask = 0;
7048
if (Attrs.hasFnAttr("aarch64_pstate_sm_enabled"))
@@ -99,17 +77,45 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
9977
Bitmask |= encodeZT0State(StateValue::New);
10078
}
10179

102-
bool SMEAttrs::requiresSMChange(const SMEAttrs &Callee) const {
103-
if (Callee.hasStreamingCompatibleInterface())
80+
void SMEAttrs::addKnownFunctionAttrs(StringRef FuncName) {
81+
unsigned KnownAttrs = SMEAttrs::Normal;
82+
if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state")
83+
KnownAttrs |= (SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine);
84+
if (FuncName == "__arm_tpidr2_restore")
85+
KnownAttrs |= SMEAttrs::SM_Compatible | encodeZAState(StateValue::In) |
86+
SMEAttrs::SME_ABI_Routine;
87+
if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
88+
FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr")
89+
KnownAttrs |= SMEAttrs::SM_Compatible;
90+
if (FuncName == "__arm_sme_save" || FuncName == "__arm_sme_restore" ||
91+
FuncName == "__arm_sme_state_size")
92+
KnownAttrs |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine;
93+
set(KnownAttrs, /*Enable=*/true);
94+
}
95+
96+
bool SMECallAttrs::requiresSMChange() const {
97+
if (callee().hasStreamingCompatibleInterface())
10498
return false;
10599

106100
// Both non-streaming
107-
if (hasNonStreamingInterfaceAndBody() && Callee.hasNonStreamingInterface())
101+
if (caller().hasNonStreamingInterfaceAndBody() &&
102+
callee().hasNonStreamingInterface())
108103
return false;
109104

110105
// Both streaming
111-
if (hasStreamingInterfaceOrBody() && Callee.hasStreamingInterface())
106+
if (caller().hasStreamingInterfaceOrBody() &&
107+
callee().hasStreamingInterface())
112108
return false;
113109

114110
return true;
115111
}
112+
113+
SMECallAttrs::SMECallAttrs(const CallBase &CB)
114+
: CallerFn(*CB.getFunction()), CalledFn(CB.getCalledFunction()),
115+
Callsite(CB.getAttributes()), IsIndirect(CB.isIndirectCall()) {
116+
// FIXME: We probably should not allow SME attributes on direct calls but
117+
// clang duplicates streaming mode attributes at each callsite.
118+
assert((IsIndirect ||
119+
((Callsite.withoutPerCallsiteFlags() | CalledFn) == CalledFn)) &&
120+
"SME attributes at callsite do not match declaration");
121+
}

0 commit comments

Comments
 (0)