@@ -8652,6 +8652,16 @@ static void analyzeCallOperands(const AArch64TargetLowering &TLI,
8652
8652
}
8653
8653
}
8654
8654
8655
// Builds the SMECallAttrs object (caller attributes paired with
// callee/call-site attributes) for the call described by CLI.
//
// Function: the visible call sites pass the function currently being lowered,
//           i.e. the caller of the call described by CLI.
// CLI:      call lowering info; only CLI.CB and CLI.Callee are read here.
//
// NOTE(review): the parameter name `Function` shadows the type name `Function`
// inside this body; consider renaming it (e.g. `Caller`).
+ static SMECallAttrs
8656
+ getSMECallAttrs(const Function &Function,
8657
+ const TargetLowering::CallLoweringInfo &CLI) {
8658
// A call-site instruction is available: construct directly from it.
+ if (CLI.CB)
8659
+ return SMECallAttrs(*CLI.CB);
8660
// No call-site instruction, but the callee node is an external symbol: pair
// the caller's attributes with attributes built from the symbol name.
// Presumably SMEAttrs(name) recognises known runtime routines by name; confirm
// against the SMEAttrs constructor taking a symbol name.
+ if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
8661
+ return SMECallAttrs(SMEAttrs(Function), SMEAttrs(ES->getSymbol()));
8662
// Otherwise nothing is known about the callee: use SMEAttrs::Normal for it.
+ return SMECallAttrs(SMEAttrs(Function), SMEAttrs(SMEAttrs::Normal));
8663
+ }
8664
+
8655
8665
bool AArch64TargetLowering::isEligibleForTailCallOptimization(
8656
8666
const CallLoweringInfo &CLI) const {
8657
8667
CallingConv::ID CalleeCC = CLI.CallConv;
@@ -8670,12 +8680,10 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
8670
8680
8671
8681
// SME Streaming functions are not eligible for TCO as they may require
8672
8682
// the streaming mode or ZA to be restored after returning from the call.
8673
- SMEAttrs CallerAttrs(MF.getFunction());
8674
- auto CalleeAttrs = CLI.CB ? SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal);
8675
- if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
8676
- CallerAttrs.requiresLazySave(CalleeAttrs) ||
8677
- CallerAttrs.requiresPreservingAllZAState(CalleeAttrs) ||
8678
- CallerAttrs.hasStreamingBody())
8683
+ SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, CLI);
8684
+ if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
8685
+ CallAttrs.requiresPreservingAllZAState() ||
8686
+ CallAttrs.caller().hasStreamingBody())
8679
8687
return false;
8680
8688
8681
8689
// Functions using the C or Fast calling convention that have an SVE signature
@@ -8967,14 +8975,13 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
8967
8975
return TLI.LowerCallTo(CLI).second;
8968
8976
}
8969
8977
8970
- static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
8971
- const SMEAttrs &CalleeAttrs) {
8972
- if (!CallerAttrs.hasStreamingCompatibleInterface() ||
8973
- CallerAttrs.hasStreamingBody())
8978
+ static unsigned getSMCondition(const SMECallAttrs &CallAttrs) {
8979
+ if (!CallAttrs.caller().hasStreamingCompatibleInterface() ||
8980
+ CallAttrs.caller().hasStreamingBody())
8974
8981
return AArch64SME::Always;
8975
- if (CalleeAttrs .hasNonStreamingInterface())
8982
+ if (CallAttrs.calleeOrCallsite() .hasNonStreamingInterface())
8976
8983
return AArch64SME::IfCallerIsStreaming;
8977
- if (CalleeAttrs .hasStreamingInterface())
8984
+ if (CallAttrs.calleeOrCallsite() .hasStreamingInterface())
8978
8985
return AArch64SME::IfCallerIsNonStreaming;
8979
8986
8980
8987
llvm_unreachable("Unsupported attributes");
@@ -9107,11 +9114,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9107
9114
}
9108
9115
9109
9116
// Determine whether we need any streaming mode changes.
9110
- SMEAttrs CalleeAttrs, CallerAttrs(MF.getFunction());
9111
- if (CLI.CB)
9112
- CalleeAttrs = SMEAttrs(*CLI.CB);
9113
- else if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
9114
- CalleeAttrs = SMEAttrs(ES->getSymbol());
9117
+ SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI);
9115
9118
9116
9119
auto DescribeCallsite =
9117
9120
[&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
@@ -9126,9 +9129,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9126
9129
return R;
9127
9130
};
9128
9131
9129
- bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
9130
- bool RequiresSaveAllZA =
9131
- CallerAttrs.requiresPreservingAllZAState(CalleeAttrs);
9132
+ bool RequiresLazySave = CallAttrs.requiresLazySave();
9133
+ bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
9132
9134
if (RequiresLazySave) {
9133
9135
const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
9134
9136
MachinePointerInfo MPI =
@@ -9156,18 +9158,18 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9156
9158
return DescribeCallsite(R) << " sets up a lazy save for ZA";
9157
9159
});
9158
9160
} else if (RequiresSaveAllZA) {
9159
- assert(!CalleeAttrs .hasSharedZAInterface() &&
9161
+ assert(!CallAttrs.calleeOrCallsite() .hasSharedZAInterface() &&
9160
9162
"Cannot share state that may not exist");
9161
9163
Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
9162
9164
/*IsSave=*/true);
9163
9165
}
9164
9166
9165
9167
SDValue PStateSM;
9166
- bool RequiresSMChange = CallerAttrs .requiresSMChange(CalleeAttrs );
9168
+ bool RequiresSMChange = CallAttrs .requiresSMChange();
9167
9169
if (RequiresSMChange) {
9168
- if (CallerAttrs .hasStreamingInterfaceOrBody())
9170
+ if (CallAttrs.caller() .hasStreamingInterfaceOrBody())
9169
9171
PStateSM = DAG.getConstant(1, DL, MVT::i64);
9170
- else if (CallerAttrs .hasNonStreamingInterface())
9172
+ else if (CallAttrs.caller() .hasNonStreamingInterface())
9171
9173
PStateSM = DAG.getConstant(0, DL, MVT::i64);
9172
9174
else
9173
9175
PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
@@ -9184,7 +9186,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9184
9186
9185
9187
SDValue ZTFrameIdx;
9186
9188
MachineFrameInfo &MFI = MF.getFrameInfo();
9187
- bool ShouldPreserveZT0 = CallerAttrs .requiresPreservingZT0(CalleeAttrs );
9189
+ bool ShouldPreserveZT0 = CallAttrs .requiresPreservingZT0();
9188
9190
9189
9191
// If the caller has ZT0 state which will not be preserved by the callee,
9190
9192
// spill ZT0 before the call.
@@ -9200,7 +9202,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9200
9202
9201
9203
// If caller shares ZT0 but the callee is not shared ZA, we need to stop
9202
9204
// PSTATE.ZA before the call if there is no lazy-save active.
9203
- bool DisableZA = CallerAttrs .requiresDisablingZABeforeCall(CalleeAttrs );
9205
+ bool DisableZA = CallAttrs .requiresDisablingZABeforeCall();
9204
9206
assert((!DisableZA || !RequiresLazySave) &&
9205
9207
"Lazy-save should have PSTATE.SM=1 on entry to the function");
9206
9208
@@ -9483,8 +9485,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9483
9485
}
9484
9486
9485
9487
SDValue NewChain = changeStreamingMode(
9486
- DAG, DL, CalleeAttrs. hasStreamingInterface(), Chain, InGlue ,
9487
- getSMCondition(CallerAttrs, CalleeAttrs ), PStateSM);
9488
+ DAG, DL, CallAttrs.calleeOrCallsite(). hasStreamingInterface(), Chain,
9489
+ InGlue, getSMCondition(CallAttrs ), PStateSM);
9488
9490
Chain = NewChain.getValue(0);
9489
9491
InGlue = NewChain.getValue(1);
9490
9492
}
@@ -9663,8 +9665,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9663
9665
if (RequiresSMChange) {
9664
9666
assert(PStateSM && "Expected a PStateSM to be set");
9665
9667
Result = changeStreamingMode(
9666
- DAG, DL, !CalleeAttrs. hasStreamingInterface(), Result, InGlue ,
9667
- getSMCondition(CallerAttrs, CalleeAttrs ), PStateSM);
9668
+ DAG, DL, !CallAttrs.calleeOrCallsite(). hasStreamingInterface(), Result,
9669
+ InGlue, getSMCondition(CallAttrs ), PStateSM);
9668
9670
9669
9671
if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
9670
9672
InGlue = Result.getValue(1);
@@ -9674,7 +9676,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9674
9676
}
9675
9677
}
9676
9678
9677
- if (CallerAttrs .requiresEnablingZAAfterCall(CalleeAttrs ))
9679
+ if (CallAttrs .requiresEnablingZAAfterCall())
9678
9680
// Unconditionally resume ZA.
9679
9681
Result = DAG.getNode(
9680
9682
AArch64ISD::SMSTART, DL, MVT::Other, Result,
@@ -28573,12 +28575,10 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
28573
28575
28574
28576
// Checks to allow the use of SME instructions
28575
28577
if (auto *Base = dyn_cast<CallBase>(&Inst)) {
28576
- auto CallerAttrs = SMEAttrs(*Inst.getFunction());
28577
- auto CalleeAttrs = SMEAttrs(*Base);
28578
- if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
28579
- CallerAttrs.requiresLazySave(CalleeAttrs) ||
28580
- CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
28581
- CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
28578
+ auto CallAttrs = SMECallAttrs(*Base);
28579
+ if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
28580
+ CallAttrs.requiresPreservingZT0() ||
28581
+ CallAttrs.requiresPreservingAllZAState())
28582
28582
return true;
28583
28583
}
28584
28584
return false;
0 commit comments