Skip to content

Commit 6017c7e

Browse files
committed
[AArch64][SME] Disallow SME attributes on direct function calls
This was only used in a handful of tests (mainly to avoid making multiple function declarations). These tests can easily be updated to use indirect calls or attributes on declarations. This allows us to remove checks that looked at both the "callee" and "callsite" attributes, which makes the API of SMECallAttrs clearer and less error-prone (as you can't accidentally use .callee() when you should have used .calleeOrCallsite()). Note: This currently still allows non-conflicting attributes on direct calls (as clang currently duplicates streaming mode attributes at each callsite).
1 parent 20fc9f1 commit 6017c7e

File tree

7 files changed

+101
-84
lines changed

7 files changed

+101
-84
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8979,9 +8979,9 @@ static unsigned getSMCondition(const SMECallAttrs &CallAttrs) {
89798979
if (!CallAttrs.caller().hasStreamingCompatibleInterface() ||
89808980
CallAttrs.caller().hasStreamingBody())
89818981
return AArch64SME::Always;
8982-
if (CallAttrs.calleeOrCallsite().hasNonStreamingInterface())
8982+
if (CallAttrs.callee().hasNonStreamingInterface())
89838983
return AArch64SME::IfCallerIsStreaming;
8984-
if (CallAttrs.calleeOrCallsite().hasStreamingInterface())
8984+
if (CallAttrs.callee().hasStreamingInterface())
89858985
return AArch64SME::IfCallerIsNonStreaming;
89868986

89878987
llvm_unreachable("Unsupported attributes");
@@ -9158,7 +9158,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91589158
return DescribeCallsite(R) << " sets up a lazy save for ZA";
91599159
});
91609160
} else if (RequiresSaveAllZA) {
9161-
assert(!CallAttrs.calleeOrCallsite().hasSharedZAInterface() &&
9161+
assert(!CallAttrs.callee().hasSharedZAInterface() &&
91629162
"Cannot share state that may not exist");
91639163
Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
91649164
/*IsSave=*/true);
@@ -9484,9 +9484,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
94849484
InGlue = Chain.getValue(1);
94859485
}
94869486

9487-
SDValue NewChain = changeStreamingMode(
9488-
DAG, DL, CallAttrs.calleeOrCallsite().hasStreamingInterface(), Chain,
9489-
InGlue, getSMCondition(CallAttrs), PStateSM);
9487+
SDValue NewChain =
9488+
changeStreamingMode(DAG, DL, CallAttrs.callee().hasStreamingInterface(),
9489+
Chain, InGlue, getSMCondition(CallAttrs), PStateSM);
94909490
Chain = NewChain.getValue(0);
94919491
InGlue = NewChain.getValue(1);
94929492
}
@@ -9665,8 +9665,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96659665
if (RequiresSMChange) {
96669666
assert(PStateSM && "Expected a PStateSM to be set");
96679667
Result = changeStreamingMode(
9668-
DAG, DL, !CallAttrs.calleeOrCallsite().hasStreamingInterface(), Result,
9669-
InGlue, getSMCondition(CallAttrs), PStateSM);
9668+
DAG, DL, !CallAttrs.callee().hasStreamingInterface(), Result, InGlue,
9669+
getSMCondition(CallAttrs), PStateSM);
96709670

96719671
if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
96729672
InGlue = Result.getValue(1);

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call,
352352
SMEAttrs FAttrs(*F);
353353
SMECallAttrs CallAttrs(Call);
354354

355-
if (SMECallAttrs(FAttrs, CallAttrs.calleeOrCallsite()).requiresSMChange()) {
355+
if (SMECallAttrs(FAttrs, CallAttrs.callee()).requiresSMChange()) {
356356
if (F == Call.getCaller()) // (1)
357357
return CallPenaltyChangeSM * DefaultCallPenalty;
358358
if (SMECallAttrs(FAttrs, CallAttrs.caller()).requiresSMChange()) // (2)

llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,22 +94,28 @@ void SMEAttrs::addKnownFunctionAttrs(StringRef FuncName) {
9494
}
9595

9696
bool SMECallAttrs::requiresSMChange() const {
97-
if ((Callsite | Callee).hasStreamingCompatibleInterface())
97+
if (callee().hasStreamingCompatibleInterface())
9898
return false;
9999

100100
// Both non-streaming
101-
if (Caller.hasNonStreamingInterfaceAndBody() &&
102-
(Callsite | Callee).hasNonStreamingInterface())
101+
if (caller().hasNonStreamingInterfaceAndBody() &&
102+
callee().hasNonStreamingInterface())
103103
return false;
104104

105105
// Both streaming
106-
if (Caller.hasStreamingInterfaceOrBody() &&
107-
(Callsite | Callee).hasStreamingInterface())
106+
if (caller().hasStreamingInterfaceOrBody() &&
107+
callee().hasStreamingInterface())
108108
return false;
109109

110110
return true;
111111
}
112112

113113
SMECallAttrs::SMECallAttrs(const CallBase &CB)
114-
: SMECallAttrs(*CB.getFunction(), CB.getCalledFunction(),
115-
CB.getAttributes()) {}
114+
: CallerFn(*CB.getFunction()), CalledFn(CB.getCalledFunction()),
115+
Callsite(CB.getAttributes()), IsIndirect(CB.isIndirectCall()) {
116+
// FIXME: We probably should not allow SME attributes on direct calls but
117+
// clang duplicates streaming mode attributes at each callsite.
118+
assert((IsIndirect ||
119+
((Callsite.withoutPerCallsiteFlags() | CalledFn) == CalledFn)) &&
120+
"SME attributes at callsite do not match declaration");
121+
}

llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ class SMEAttrs {
4444
ZA_Shift = 6,
4545
ZA_Mask = 0b111 << ZA_Shift,
4646
ZT0_Shift = 9,
47-
ZT0_Mask = 0b111 << ZT0_Shift
47+
ZT0_Mask = 0b111 << ZT0_Shift,
48+
Callsite_Flags = ZT0_Undef
4849
};
4950

5051
SMEAttrs() = default;
@@ -135,6 +136,14 @@ class SMEAttrs {
135136
return Merged;
136137
}
137138

139+
SMEAttrs withoutPerCallsiteFlags() const {
140+
return (Bitmask & ~Callsite_Flags);
141+
}
142+
143+
bool operator==(SMEAttrs const &Other) const {
144+
return Bitmask == Other.Bitmask;
145+
}
146+
138147
private:
139148
void addKnownFunctionAttrs(StringRef FuncName);
140149
};
@@ -143,54 +152,57 @@ class SMEAttrs {
143152
/// interfaces to query whether a streaming mode change or lazy-save mechanism
144153
/// is required when going from one function to another (e.g. through a call).
145154
class SMECallAttrs {
146-
SMEAttrs Caller;
147-
SMEAttrs Callee;
155+
SMEAttrs CallerFn;
156+
SMEAttrs CalledFn;
148157
SMEAttrs Callsite;
158+
bool IsIndirect = false;
149159

150160
public:
151161
SMECallAttrs(SMEAttrs Caller, SMEAttrs Callee,
152162
SMEAttrs Callsite = SMEAttrs::Normal)
153-
: Caller(Caller), Callee(Callee), Callsite(Callsite) {}
163+
: CallerFn(Caller), CalledFn(Callee), Callsite(Callsite) {}
154164

155165
SMECallAttrs(const CallBase &CB);
156166

157-
SMEAttrs &caller() { return Caller; }
158-
SMEAttrs &callee() { return Callee; }
167+
SMEAttrs &caller() { return CallerFn; }
168+
SMEAttrs &callee() {
169+
if (IsIndirect)
170+
return Callsite;
171+
return CalledFn;
172+
}
159173
SMEAttrs &callsite() { return Callsite; }
160-
SMEAttrs const &caller() const { return Caller; }
161-
SMEAttrs const &callee() const { return Callee; }
174+
SMEAttrs const &caller() const { return CallerFn; }
175+
SMEAttrs const &callee() const {
176+
return const_cast<SMECallAttrs *>(this)->callee();
177+
}
162178
SMEAttrs const &callsite() const { return Callsite; }
163-
SMEAttrs calleeOrCallsite() const { return Callsite | Callee; }
164179

165180
/// \return true if a call from Caller -> Callee requires a change in
166181
/// streaming mode.
167182
bool requiresSMChange() const;
168183

169184
bool requiresLazySave() const {
170-
return Caller.hasZAState() && (Callsite | Callee).hasPrivateZAInterface() &&
171-
!Callee.isSMEABIRoutine();
185+
return caller().hasZAState() && callee().hasPrivateZAInterface() &&
186+
!callee().isSMEABIRoutine();
172187
}
173188

174189
bool requiresPreservingZT0() const {
175-
return Caller.hasZT0State() && !Callsite.hasUndefZT0() &&
176-
!(Callsite | Callee).sharesZT0() &&
177-
!(Callsite | Callee).hasAgnosticZAInterface();
190+
return caller().hasZT0State() && !callsite().hasUndefZT0() &&
191+
!callee().sharesZT0() && !callee().hasAgnosticZAInterface();
178192
}
179193

180194
bool requiresDisablingZABeforeCall() const {
181-
return Caller.hasZT0State() && !Caller.hasZAState() &&
182-
(Callsite | Callee).hasPrivateZAInterface() &&
183-
!Callee.isSMEABIRoutine();
195+
return caller().hasZT0State() && !caller().hasZAState() &&
196+
callee().hasPrivateZAInterface() && !callee().isSMEABIRoutine();
184197
}
185198

186199
bool requiresEnablingZAAfterCall() const {
187200
return requiresLazySave() || requiresDisablingZABeforeCall();
188201
}
189202

190203
bool requiresPreservingAllZAState() const {
191-
return Caller.hasAgnosticZAInterface() &&
192-
!(Callsite | Callee).hasAgnosticZAInterface() &&
193-
!Callee.isSMEABIRoutine();
204+
return caller().hasAgnosticZAInterface() &&
205+
!callee().hasAgnosticZAInterface() && !callee().isSMEABIRoutine();
194206
}
195207
};
196208

llvm/test/CodeGen/AArch64/sme-peephole-opts.ll

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
; RUN: llc -mtriple=aarch64-linux-gnu -aarch64-streaming-hazard-size=0 -mattr=+sve,+sme2 < %s | FileCheck %s
33

44
declare void @callee()
5+
declare void @callee_sm() "aarch64_pstate_sm_enabled"
56
declare void @callee_farg(float)
67
declare float @callee_farg_fret(float)
78

89
; normal caller -> streaming callees
9-
define void @test0() nounwind {
10+
define void @test0(ptr %callee) nounwind {
1011
; CHECK-LABEL: test0:
1112
; CHECK: // %bb.0:
1213
; CHECK-NEXT: stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
@@ -16,17 +17,17 @@ define void @test0() nounwind {
1617
; CHECK-NEXT: stp d9, d8, [sp, #48] // 16-byte Folded Spill
1718
; CHECK-NEXT: stp x30, x9, [sp, #64] // 16-byte Folded Spill
1819
; CHECK-NEXT: smstart sm
19-
; CHECK-NEXT: bl callee
20-
; CHECK-NEXT: bl callee
20+
; CHECK-NEXT: bl callee_sm
21+
; CHECK-NEXT: bl callee_sm
2122
; CHECK-NEXT: smstop sm
2223
; CHECK-NEXT: ldp d9, d8, [sp, #48] // 16-byte Folded Reload
2324
; CHECK-NEXT: ldr x30, [sp, #64] // 8-byte Folded Reload
2425
; CHECK-NEXT: ldp d11, d10, [sp, #32] // 16-byte Folded Reload
2526
; CHECK-NEXT: ldp d13, d12, [sp, #16] // 16-byte Folded Reload
2627
; CHECK-NEXT: ldp d15, d14, [sp], #80 // 16-byte Folded Reload
2728
; CHECK-NEXT: ret
28-
call void @callee() "aarch64_pstate_sm_enabled"
29-
call void @callee() "aarch64_pstate_sm_enabled"
29+
call void @callee_sm()
30+
call void @callee_sm()
3031
ret void
3132
}
3233

@@ -118,7 +119,7 @@ define void @test3() nounwind "aarch64_pstate_sm_compatible" {
118119
; CHECK-NEXT: // %bb.1:
119120
; CHECK-NEXT: smstart sm
120121
; CHECK-NEXT: .LBB3_2:
121-
; CHECK-NEXT: bl callee
122+
; CHECK-NEXT: bl callee_sm
122123
; CHECK-NEXT: tbnz w19, #0, .LBB3_4
123124
; CHECK-NEXT: // %bb.3:
124125
; CHECK-NEXT: smstop sm
@@ -140,7 +141,7 @@ define void @test3() nounwind "aarch64_pstate_sm_compatible" {
140141
; CHECK-NEXT: // %bb.9:
141142
; CHECK-NEXT: smstart sm
142143
; CHECK-NEXT: .LBB3_10:
143-
; CHECK-NEXT: bl callee
144+
; CHECK-NEXT: bl callee_sm
144145
; CHECK-NEXT: tbnz w19, #0, .LBB3_12
145146
; CHECK-NEXT: // %bb.11:
146147
; CHECK-NEXT: smstop sm
@@ -152,9 +153,9 @@ define void @test3() nounwind "aarch64_pstate_sm_compatible" {
152153
; CHECK-NEXT: ldp d13, d12, [sp, #16] // 16-byte Folded Reload
153154
; CHECK-NEXT: ldp d15, d14, [sp], #96 // 16-byte Folded Reload
154155
; CHECK-NEXT: ret
155-
call void @callee() "aarch64_pstate_sm_enabled"
156+
call void @callee_sm()
156157
call void @callee()
157-
call void @callee() "aarch64_pstate_sm_enabled"
158+
call void @callee_sm()
158159
ret void
159160
}
160161

@@ -342,7 +343,7 @@ define void @test10() "aarch64_pstate_sm_body" {
342343
; CHECK-NEXT: bl callee
343344
; CHECK-NEXT: smstart sm
344345
; CHECK-NEXT: .cfi_restore vg
345-
; CHECK-NEXT: bl callee
346+
; CHECK-NEXT: bl callee_sm
346347
; CHECK-NEXT: .cfi_offset vg, -24
347348
; CHECK-NEXT: smstop sm
348349
; CHECK-NEXT: bl callee
@@ -363,7 +364,7 @@ define void @test10() "aarch64_pstate_sm_body" {
363364
; CHECK-NEXT: .cfi_restore b15
364365
; CHECK-NEXT: ret
365366
call void @callee()
366-
call void @callee() "aarch64_pstate_sm_enabled"
367+
call void @callee_sm()
367368
call void @callee()
368369
ret void
369370
}

llvm/test/CodeGen/AArch64/sme-vg-to-stack.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,11 +1098,11 @@ define void @test_rdsvl_right_after_prologue(i64 %x0) nounwind {
10981098
; NO-SVE-CHECK-NEXT: ret
10991099
%some_alloc = alloca i64, align 8
11001100
%rdsvl = tail call i64 @llvm.aarch64.sme.cntsd()
1101-
call void @bar(i64 %rdsvl, i64 %x0) "aarch64_pstate_sm_enabled"
1101+
call void @bar(i64 %rdsvl, i64 %x0)
11021102
ret void
11031103
}
11041104

1105-
declare void @bar(i64, i64)
1105+
declare void @bar(i64, i64) "aarch64_pstate_sm_enabled"
11061106

11071107
; Ensure we still emit async unwind information with -fno-asynchronous-unwind-tables
11081108
; if the function contains a streaming-mode change.

0 commit comments

Comments
 (0)