Skip to content

[AArch64][SME] Merge back-to-back SME call regions #142111

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
576 changes: 368 additions & 208 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Large diffs are not rendered by default.

20 changes: 15 additions & 5 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ enum NodeType : unsigned {
RESTORE_ZT,
SAVE_ZT,

SME_CALL_START,
SME_CALL_SM_CHANGE,
SME_CALL_END,

// A call with the callee in x16, i.e. "blr x16".
CALL_ARM64EC_TO_X64,

Expand Down Expand Up @@ -823,6 +827,9 @@ class AArch64TargetLowering : public TargetLowering {
TargetLoweringBase::LegalizeTypeAction
getPreferredVectorAction(MVT VT) const override;

TargetLoweringBase::LegalizeAction
getCustomOperationAction(SDNode &) const override;

/// If the target has a standard location for the stack protector cookie,
/// returns the address of that location. Otherwise, returns nullptr.
Value *getIRStackGuard(IRBuilderBase &IRB) const override;
Expand Down Expand Up @@ -1028,6 +1035,11 @@ class AArch64TargetLowering : public TargetLowering {
/// True if stack clash protection is enabled for this functions.
bool hasInlineStackProbe(const MachineFunction &MF) const override;

// Returns the runtime value for PSTATE.SM by generating a call to
// __arm_sme_state.
SDValue getRuntimePStateSM(SelectionDAG &DAG, SDValue Chain, SDLoc DL,
EVT VT) const;

private:
/// Keep a pointer to the AArch64Subtarget around so that we can
/// make the right decision when generating code for different targets.
Expand Down Expand Up @@ -1211,6 +1223,9 @@ class AArch64TargetLowering : public TargetLowering {
SDValue LowerWindowsDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerInlineDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerSME_CALL_START(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerSME_CALL_SM_CHANGE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerSME_CALL_END(SDValue Op, SelectionDAG &DAG) const;

SDValue LowerAVG(SDValue Op, SelectionDAG &DAG, unsigned NewOp) const;

Expand Down Expand Up @@ -1347,11 +1362,6 @@ class AArch64TargetLowering : public TargetLowering {
// This function does not handle predicate bitcasts.
SDValue getSVESafeBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;

// Returns the runtime value for PSTATE.SM by generating a call to
// __arm_sme_state.
SDValue getRuntimePStateSM(SelectionDAG &DAG, SDValue Chain, SDLoc DL,
EVT VT) const;

bool preferScalarizeSplat(SDNode *N) const override;

unsigned getMinimumJumpTableEntries() const override;
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,9 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
int64_t VGIdx = std::numeric_limits<int>::max();
int64_t StreamingVGIdx = std::numeric_limits<int>::max();

// The stack slot where ZT0 is stored.
int64_t ZT0Idx = std::numeric_limits<int>::max();

public:
AArch64FunctionInfo(const Function &F, const AArch64Subtarget *STI);

Expand Down Expand Up @@ -275,6 +278,9 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
int64_t getStreamingVGIdx() const { return StreamingVGIdx; };
void setStreamingVGIdx(unsigned FrameIdx) { StreamingVGIdx = FrameIdx; };

int64_t getZT0Idx() const { return ZT0Idx; };
void setZT0Idx(unsigned FrameIdx) { ZT0Idx = FrameIdx; };

bool isSVECC() const { return IsSVECC; };
void setIsSVECC(bool s) { IsSVECC = s; };

Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ let usesCustomInserter = 1, Defs = [SP], Uses = [SP] in {
def : Pat<(i64 (AArch64AllocateZABuffer GPR64:$size)),
(AllocateZABuffer $size)>;

def AArch64InitTPIDR2Obj : SDNode<"AArch64ISD::INIT_TPIDR2OBJ", SDTypeProfile<0, 1,
[SDTCisInt<0>]>, [SDNPHasChain, SDNPMayStore]>;
def AArch64InitTPIDR2Obj : SDNode<"AArch64ISD::INIT_TPIDR2OBJ", SDTypeProfile<0, 2,
[SDTCisInt<0>, SDTCisInt<1>]>, [SDNPHasChain, SDNPMayStore]>;
let usesCustomInserter = 1 in {
def InitTPIDR2Obj : Pseudo<(outs), (ins GPR64:$buffer), [(AArch64InitTPIDR2Obj GPR64:$buffer)]>, Sched<[WriteI]> {}
def InitTPIDR2Obj : Pseudo<(outs), (ins GPR64:$buffer, GPR64:$save_slices), [(AArch64InitTPIDR2Obj GPR64:$buffer, GPR64:$save_slices)]>, Sched<[WriteI]> {}
}

// Nodes to allocate a save buffer for SME.
Expand Down
9 changes: 9 additions & 0 deletions llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ class SMEAttrs {
return Bitmask == Other.Bitmask;
}

explicit operator unsigned() const { return Bitmask; }

private:
void addKnownFunctionAttrs(StringRef FuncName);
};
Expand Down Expand Up @@ -201,6 +203,13 @@ class SMECallAttrs {
return caller().hasAgnosticZAInterface() &&
!callee().hasAgnosticZAInterface() && !callee().isSMEABIRoutine();
}

bool operator==(SMECallAttrs const &Other) const {
return caller() == Other.caller() && callee() == Other.callee() &&
callsite() == Other.callsite();
}

bool operator!=(SMECallAttrs const &Other) const { return !(*this == Other); }
};

} // namespace llvm
Expand Down
23 changes: 7 additions & 16 deletions llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll
Original file line number Diff line number Diff line change
Expand Up @@ -268,10 +268,7 @@ define double @za_shared_caller_to_za_none_callee(double %x) nounwind noinline
; CHECK-COMMON-NEXT: mov x9, sp
; CHECK-COMMON-NEXT: msub x9, x8, x8, x9
; CHECK-COMMON-NEXT: mov sp, x9
; CHECK-COMMON-NEXT: stur x9, [x29, #-16]
; CHECK-COMMON-NEXT: sturh wzr, [x29, #-6]
; CHECK-COMMON-NEXT: stur wzr, [x29, #-4]
; CHECK-COMMON-NEXT: sturh w8, [x29, #-8]
; CHECK-COMMON-NEXT: stp x9, x8, [x29, #-16]
; CHECK-COMMON-NEXT: sub x8, x29, #16
; CHECK-COMMON-NEXT: msr TPIDR2_EL0, x8
; CHECK-COMMON-NEXT: bl normal_callee
Expand Down Expand Up @@ -310,12 +307,9 @@ define fp128 @f128_call_za(fp128 %a, fp128 %b) "aarch64_inout_za" nounwind {
; CHECK-COMMON-NEXT: mov x9, sp
; CHECK-COMMON-NEXT: msub x9, x8, x8, x9
; CHECK-COMMON-NEXT: mov sp, x9
; CHECK-COMMON-NEXT: stur x9, [x29, #-16]
; CHECK-COMMON-NEXT: sub x9, x29, #16
; CHECK-COMMON-NEXT: sturh wzr, [x29, #-6]
; CHECK-COMMON-NEXT: stur wzr, [x29, #-4]
; CHECK-COMMON-NEXT: sturh w8, [x29, #-8]
; CHECK-COMMON-NEXT: msr TPIDR2_EL0, x9
; CHECK-COMMON-NEXT: sub x10, x29, #16
; CHECK-COMMON-NEXT: stp x9, x8, [x29, #-16]
; CHECK-COMMON-NEXT: msr TPIDR2_EL0, x10
; CHECK-COMMON-NEXT: bl __addtf3
; CHECK-COMMON-NEXT: smstart za
; CHECK-COMMON-NEXT: mrs x8, TPIDR2_EL0
Expand Down Expand Up @@ -375,12 +369,9 @@ define double @frem_call_za(double %a, double %b) "aarch64_inout_za" nounwind {
; CHECK-COMMON-NEXT: mov x9, sp
; CHECK-COMMON-NEXT: msub x9, x8, x8, x9
; CHECK-COMMON-NEXT: mov sp, x9
; CHECK-COMMON-NEXT: stur x9, [x29, #-16]
; CHECK-COMMON-NEXT: sub x9, x29, #16
; CHECK-COMMON-NEXT: sturh wzr, [x29, #-6]
; CHECK-COMMON-NEXT: stur wzr, [x29, #-4]
; CHECK-COMMON-NEXT: sturh w8, [x29, #-8]
; CHECK-COMMON-NEXT: msr TPIDR2_EL0, x9
; CHECK-COMMON-NEXT: sub x10, x29, #16
; CHECK-COMMON-NEXT: stp x9, x8, [x29, #-16]
; CHECK-COMMON-NEXT: msr TPIDR2_EL0, x10
; CHECK-COMMON-NEXT: bl fmod
; CHECK-COMMON-NEXT: smstart za
; CHECK-COMMON-NEXT: mrs x8, TPIDR2_EL0
Expand Down
66 changes: 21 additions & 45 deletions llvm/test/CodeGen/AArch64/sme-lazy-save-call.ll
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,9 @@ define void @test_lazy_save_1_callee() nounwind "aarch64_inout_za" {
; CHECK-NEXT: mov x9, sp
; CHECK-NEXT: msub x9, x8, x8, x9
; CHECK-NEXT: mov sp, x9
; CHECK-NEXT: stur x9, [x29, #-16]
; CHECK-NEXT: sub x9, x29, #16
; CHECK-NEXT: sturh wzr, [x29, #-6]
; CHECK-NEXT: stur wzr, [x29, #-4]
; CHECK-NEXT: sturh w8, [x29, #-8]
; CHECK-NEXT: msr TPIDR2_EL0, x9
; CHECK-NEXT: sub x10, x29, #16
; CHECK-NEXT: stp x9, x8, [x29, #-16]
; CHECK-NEXT: msr TPIDR2_EL0, x10
; CHECK-NEXT: bl private_za_callee
; CHECK-NEXT: smstart za
; CHECK-NEXT: mrs x8, TPIDR2_EL0
Expand All @@ -43,21 +40,18 @@ define void @test_lazy_save_1_callee() nounwind "aarch64_inout_za" {
define void @test_lazy_save_2_callees() nounwind "aarch64_inout_za" {
; CHECK-LABEL: test_lazy_save_2_callees:
; CHECK: // %bb.0:
; CHECK-NEXT: stp x29, x30, [sp, #-48]! // 16-byte Folded Spill
; CHECK-NEXT: str x21, [sp, #16] // 8-byte Folded Spill
; CHECK-NEXT: stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
; CHECK-NEXT: str x19, [sp, #16] // 8-byte Folded Spill
; CHECK-NEXT: mov x29, sp
; CHECK-NEXT: stp x20, x19, [sp, #32] // 16-byte Folded Spill
; CHECK-NEXT: sub sp, sp, #16
; CHECK-NEXT: rdsvl x20, #1
; CHECK-NEXT: mov x8, sp
; CHECK-NEXT: msub x8, x20, x20, x8
; CHECK-NEXT: mov sp, x8
; CHECK-NEXT: sub x21, x29, #16
; CHECK-NEXT: stur x8, [x29, #-16]
; CHECK-NEXT: sturh wzr, [x29, #-6]
; CHECK-NEXT: stur wzr, [x29, #-4]
; CHECK-NEXT: sturh w20, [x29, #-8]
; CHECK-NEXT: msr TPIDR2_EL0, x21
; CHECK-NEXT: rdsvl x8, #1
; CHECK-NEXT: mov x9, sp
; CHECK-NEXT: msub x9, x8, x8, x9
; CHECK-NEXT: mov sp, x9
; CHECK-NEXT: sub x10, x29, #16
; CHECK-NEXT: stp x9, x8, [x29, #-16]
; CHECK-NEXT: msr TPIDR2_EL0, x10
; CHECK-NEXT: bl private_za_callee
; CHECK-NEXT: bl private_za_callee
; CHECK-NEXT: smstart za
; CHECK-NEXT: mrs x8, TPIDR2_EL0
Expand All @@ -67,21 +61,9 @@ define void @test_lazy_save_2_callees() nounwind "aarch64_inout_za" {
; CHECK-NEXT: bl __arm_tpidr2_restore
; CHECK-NEXT: .LBB1_2:
; CHECK-NEXT: msr TPIDR2_EL0, xzr
; CHECK-NEXT: sturh w20, [x29, #-8]
; CHECK-NEXT: msr TPIDR2_EL0, x21
; CHECK-NEXT: bl private_za_callee
; CHECK-NEXT: smstart za
; CHECK-NEXT: mrs x8, TPIDR2_EL0
; CHECK-NEXT: sub x0, x29, #16
; CHECK-NEXT: cbnz x8, .LBB1_4
; CHECK-NEXT: // %bb.3:
; CHECK-NEXT: bl __arm_tpidr2_restore
; CHECK-NEXT: .LBB1_4:
; CHECK-NEXT: msr TPIDR2_EL0, xzr
; CHECK-NEXT: mov sp, x29
; CHECK-NEXT: ldp x20, x19, [sp, #32] // 16-byte Folded Reload
; CHECK-NEXT: ldr x21, [sp, #16] // 8-byte Folded Reload
; CHECK-NEXT: ldp x29, x30, [sp], #48 // 16-byte Folded Reload
; CHECK-NEXT: ldr x19, [sp, #16] // 8-byte Folded Reload
; CHECK-NEXT: ldp x29, x30, [sp], #32 // 16-byte Folded Reload
; CHECK-NEXT: ret
call void @private_za_callee()
call void @private_za_callee()
Expand All @@ -100,12 +82,9 @@ define float @test_lazy_save_expanded_intrinsic(float %a) nounwind "aarch64_inou
; CHECK-NEXT: mov x9, sp
; CHECK-NEXT: msub x9, x8, x8, x9
; CHECK-NEXT: mov sp, x9
; CHECK-NEXT: stur x9, [x29, #-16]
; CHECK-NEXT: sub x9, x29, #16
; CHECK-NEXT: sturh wzr, [x29, #-6]
; CHECK-NEXT: stur wzr, [x29, #-4]
; CHECK-NEXT: sturh w8, [x29, #-8]
; CHECK-NEXT: msr TPIDR2_EL0, x9
; CHECK-NEXT: sub x10, x29, #16
; CHECK-NEXT: stp x9, x8, [x29, #-16]
; CHECK-NEXT: msr TPIDR2_EL0, x10
; CHECK-NEXT: bl cosf
; CHECK-NEXT: smstart za
; CHECK-NEXT: mrs x8, TPIDR2_EL0
Expand Down Expand Up @@ -141,12 +120,9 @@ define void @test_lazy_save_and_conditional_smstart() nounwind "aarch64_inout_za
; CHECK-NEXT: mov x9, sp
; CHECK-NEXT: msub x9, x8, x8, x9
; CHECK-NEXT: mov sp, x9
; CHECK-NEXT: stur x9, [x29, #-80]
; CHECK-NEXT: sub x9, x29, #80
; CHECK-NEXT: sturh wzr, [x29, #-70]
; CHECK-NEXT: stur wzr, [x29, #-68]
; CHECK-NEXT: sturh w8, [x29, #-72]
; CHECK-NEXT: msr TPIDR2_EL0, x9
; CHECK-NEXT: sub x10, x29, #80
; CHECK-NEXT: stp x9, x8, [x29, #-80]
; CHECK-NEXT: msr TPIDR2_EL0, x10
; CHECK-NEXT: bl __arm_sme_state
; CHECK-NEXT: and x20, x0, #0x1
; CHECK-NEXT: tbz w20, #0, .LBB3_2
Expand Down
27 changes: 6 additions & 21 deletions llvm/test/CodeGen/AArch64/sme-peephole-opts.ll
Original file line number Diff line number Diff line change
Expand Up @@ -75,21 +75,11 @@ define void @test2() nounwind "aarch64_pstate_sm_compatible" {
; CHECK-NEXT: smstop sm
; CHECK-NEXT: .LBB2_2:
; CHECK-NEXT: bl callee
; CHECK-NEXT: bl callee
; CHECK-NEXT: tbz w19, #0, .LBB2_4
; CHECK-NEXT: // %bb.3:
; CHECK-NEXT: smstart sm
; CHECK-NEXT: .LBB2_4:
; CHECK-NEXT: bl __arm_sme_state
; CHECK-NEXT: and x19, x0, #0x1
; CHECK-NEXT: tbz w19, #0, .LBB2_6
; CHECK-NEXT: // %bb.5:
; CHECK-NEXT: smstop sm
; CHECK-NEXT: .LBB2_6:
; CHECK-NEXT: bl callee
; CHECK-NEXT: tbz w19, #0, .LBB2_8
; CHECK-NEXT: // %bb.7:
; CHECK-NEXT: smstart sm
; CHECK-NEXT: .LBB2_8:
; CHECK-NEXT: ldp d9, d8, [sp, #48] // 16-byte Folded Reload
; CHECK-NEXT: ldr x19, [sp, #80] // 8-byte Folded Reload
; CHECK-NEXT: ldp d11, d10, [sp, #32] // 16-byte Folded Reload
Expand Down Expand Up @@ -252,22 +242,17 @@ define float @test6(float %f) nounwind "aarch64_pstate_sm_enabled" {
define void @test7() nounwind "aarch64_inout_zt0" {
; CHECK-LABEL: test7:
; CHECK: // %bb.0:
; CHECK-NEXT: sub sp, sp, #144
; CHECK-NEXT: stp x30, x19, [sp, #128] // 16-byte Folded Spill
; CHECK-NEXT: add x19, sp, #64
; CHECK-NEXT: str zt0, [x19]
; CHECK-NEXT: smstop za
; CHECK-NEXT: bl callee
; CHECK-NEXT: smstart za
; CHECK-NEXT: ldr zt0, [x19]
; CHECK-NEXT: sub sp, sp, #80
; CHECK-NEXT: stp x30, x19, [sp, #64] // 16-byte Folded Spill
; CHECK-NEXT: mov x19, sp
; CHECK-NEXT: str zt0, [x19]
; CHECK-NEXT: smstop za
; CHECK-NEXT: bl callee
; CHECK-NEXT: bl callee
; CHECK-NEXT: smstart za
; CHECK-NEXT: ldr zt0, [x19]
; CHECK-NEXT: ldp x30, x19, [sp, #128] // 16-byte Folded Reload
; CHECK-NEXT: add sp, sp, #144
; CHECK-NEXT: ldp x30, x19, [sp, #64] // 16-byte Folded Reload
; CHECK-NEXT: add sp, sp, #80
; CHECK-NEXT: ret
call void @callee()
call void @callee()
Expand Down
18 changes: 6 additions & 12 deletions llvm/test/CodeGen/AArch64/sme-shared-za-interface.ll
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,9 @@ define void @disable_tailcallopt() "aarch64_inout_za" nounwind {
; CHECK-NEXT: mov x9, sp
; CHECK-NEXT: msub x9, x8, x8, x9
; CHECK-NEXT: mov sp, x9
; CHECK-NEXT: stur x9, [x29, #-16]
; CHECK-NEXT: sub x9, x29, #16
; CHECK-NEXT: sturh wzr, [x29, #-6]
; CHECK-NEXT: stur wzr, [x29, #-4]
; CHECK-NEXT: sturh w8, [x29, #-8]
; CHECK-NEXT: msr TPIDR2_EL0, x9
; CHECK-NEXT: sub x10, x29, #16
; CHECK-NEXT: stp x9, x8, [x29, #-16]
; CHECK-NEXT: msr TPIDR2_EL0, x10
; CHECK-NEXT: bl private_za_callee
; CHECK-NEXT: smstart za
; CHECK-NEXT: mrs x8, TPIDR2_EL0
Expand Down Expand Up @@ -47,12 +44,9 @@ define fp128 @f128_call_za(fp128 %a, fp128 %b) "aarch64_inout_za" nounwind {
; CHECK-NEXT: mov x9, sp
; CHECK-NEXT: msub x9, x8, x8, x9
; CHECK-NEXT: mov sp, x9
; CHECK-NEXT: stur x9, [x29, #-16]
; CHECK-NEXT: sub x9, x29, #16
; CHECK-NEXT: sturh wzr, [x29, #-6]
; CHECK-NEXT: stur wzr, [x29, #-4]
; CHECK-NEXT: sturh w8, [x29, #-8]
; CHECK-NEXT: msr TPIDR2_EL0, x9
; CHECK-NEXT: sub x10, x29, #16
; CHECK-NEXT: stp x9, x8, [x29, #-16]
; CHECK-NEXT: msr TPIDR2_EL0, x10
; CHECK-NEXT: bl __addtf3
; CHECK-NEXT: smstart za
; CHECK-NEXT: mrs x8, TPIDR2_EL0
Expand Down
Loading