Skip to content

Commit bcffebf

Browse files
authored
Arm64/SVE: Add support to handle predicate registers as callee-trash (#104065)
* Add MSK_CALLEE_TRASH and include it in CALLEE_TRASH * Assign correct registerType for predicate registers * Handle the save/restore of predicate registers * misc changes * jit format * Remove handling of Temps and use the same as locals * use GetPredicateRegSet() * Disable mask registers if on non-sve * small change in DbgEnc * jit format * Revert "jit format" This reverts commit 5535c69. * Revert "small change in DbgEnc" This reverts commit bb97d80. * Revert "Disable mask registers if on non-sve" This reverts commit bcfd8a8. * minor review feedback
1 parent e24ea1d commit bcffebf

File tree

6 files changed

+55
-16
lines changed

6 files changed

+55
-16
lines changed

src/coreclr/jit/emitarm64.cpp

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7884,7 +7884,22 @@ void emitter::emitIns_R_S(instruction ins, emitAttr attr, regNumber reg1, int va
78847884
isSimple = false;
78857885
size = EA_SCALABLE;
78867886
attr = size;
7887-
fmt = isVectorRegister(reg1) ? IF_SVE_IE_2A : IF_SVE_ID_2A;
7887+
if (isPredicateRegister(reg1))
7888+
{
7889+
assert(offs == 0);
7890+
// For predicate, generate based off rsGetRsvdReg()
7891+
regNumber rsvdReg = codeGen->rsGetRsvdReg();
7892+
7893+
// add rsvd, fp, #imm
7894+
emitIns_R_R_I(INS_add, EA_8BYTE, rsvdReg, reg2, imm);
7895+
// str p0, [rsvd, #0, mul vl]
7896+
emitIns_R_R_I(ins, attr, reg1, rsvdReg, 0);
7897+
7898+
return;
7899+
}
7900+
7901+
assert(isVectorRegister(reg1));
7902+
fmt = IF_SVE_IE_2A;
78887903

78897904
// TODO-SVE: Don't assume 128bit vectors
78907905
// Predicate size is vector length / 8
@@ -8138,7 +8153,24 @@ void emitter::emitIns_S_R(instruction ins, emitAttr attr, regNumber reg1, int va
81388153
isSimple = false;
81398154
size = EA_SCALABLE;
81408155
attr = size;
8141-
fmt = isVectorRegister(reg1) ? IF_SVE_JH_2A : IF_SVE_JG_2A;
8156+
8157+
if (isPredicateRegister(reg1))
8158+
{
8159+
assert(offs == 0);
8160+
8161+
// For predicate, generate based off rsGetRsvdReg()
8162+
regNumber rsvdReg = codeGen->rsGetRsvdReg();
8163+
8164+
// add rsvd, fp, #imm
8165+
emitIns_R_R_I(INS_add, EA_8BYTE, rsvdReg, reg2, imm);
8166+
// str p0, [rsvd, #0, mul vl]
8167+
emitIns_R_R_I(ins, attr, reg1, rsvdReg, 0);
8168+
8169+
return;
8170+
}
8171+
8172+
assert(isVectorRegister(reg1));
8173+
fmt = IF_SVE_JH_2A;
81428174

81438175
// TODO-SVE: Don't assume 128bit vectors
81448176
// Predicate size is vector length / 8

src/coreclr/jit/emitarm64.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,6 +1205,11 @@ inline static bool isHighPredicateRegister(regNumber reg)
12051205
return (reg >= REG_PREDICATE_HIGH_FIRST) && (reg <= REG_PREDICATE_HIGH_LAST);
12061206
}
12071207

1208+
inline static bool isMaskReg(regNumber reg)
1209+
{
1210+
return isPredicateRegister(reg);
1211+
}
1212+
12081213
inline static bool isEvenRegister(regNumber reg)
12091214
{
12101215
if (isGeneralRegister(reg))

src/coreclr/jit/lclvars.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5593,7 +5593,7 @@ unsigned Compiler::lvaGetMaxSpillTempSize()
55935593
* Doing this all in one pass is 'hard'. So instead we do it in 2 basic passes:
55945594
* 1. Assign all the offsets relative to the Virtual '0'. Offsets above (the
55955595
* incoming arguments) are positive. Offsets below (everything else) are
5596-
* negative. This pass also calcuates the total frame size (between Caller's
5596+
* negative. This pass also calculates the total frame size (between Caller's
55975597
* SP/return address and the Ambient SP).
55985598
* 2. Figure out where to place the frame pointer, and then adjust the offsets
55995599
* as needed for the final stack size and whether the offset is frame pointer

src/coreclr/jit/lsra.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -508,13 +508,13 @@ class RegRecord : public Referenceable
508508
{
509509
registerType = FloatRegisterType;
510510
}
511-
#if defined(TARGET_XARCH) && defined(FEATURE_SIMD)
511+
#if defined(FEATURE_MASKED_HW_INTRINSICS)
512512
else
513513
{
514514
assert(emitter::isMaskReg(reg));
515515
registerType = MaskRegisterType;
516516
}
517-
#endif
517+
#endif // FEATURE_MASKED_HW_INTRINSICS
518518
regNum = reg;
519519
isCalleeSave = ((RBM_CALLEE_SAVED & genRegMask(reg)) != 0);
520520
}

src/coreclr/jit/lsrabuild.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -855,6 +855,9 @@ regMaskTP LinearScan::getKillSetForCall(GenTreeCall* call)
855855

856856
#else
857857
killMask.RemoveRegsetForType(RBM_FLT_CALLEE_TRASH.GetFloatRegSet(), FloatRegisterType);
858+
#if defined(TARGET_ARM64)
859+
killMask.RemoveRegsetForType(RBM_MSK_CALLEE_TRASH.GetPredicateRegSet(), MaskRegisterType);
860+
#endif // TARGET_ARM64
858861
#endif // TARGET_XARCH
859862
}
860863
#ifdef TARGET_ARM
@@ -1148,8 +1151,8 @@ bool LinearScan::buildKillPositionsForNode(GenTree* tree, LsraLocation currentLo
11481151
{
11491152
continue;
11501153
}
1151-
Interval* interval = getIntervalForLocalVar(varIndex);
1152-
const bool isCallKill = ((killMask == RBM_INT_CALLEE_TRASH) || (killMask == RBM_CALLEE_TRASH));
1154+
Interval* interval = getIntervalForLocalVar(varIndex);
1155+
const bool isCallKill = ((killMask.getLow() == RBM_INT_CALLEE_TRASH) || (killMask == RBM_CALLEE_TRASH));
11531156
SingleTypeRegSet regsKillMask = killMask.GetRegSetForType(interval->registerType);
11541157

11551158
if (isCallKill)

src/coreclr/jit/targetarm64.h

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,15 @@
7575
#define RBM_FLT_CALLEE_SAVED (RBM_V8|RBM_V9|RBM_V10|RBM_V11|RBM_V12|RBM_V13|RBM_V14|RBM_V15)
7676
#define RBM_FLT_CALLEE_TRASH (RBM_V0|RBM_V1|RBM_V2|RBM_V3|RBM_V4|RBM_V5|RBM_V6|RBM_V7|RBM_V16|RBM_V17|RBM_V18|RBM_V19|RBM_V20|RBM_V21|RBM_V22|RBM_V23|RBM_V24|RBM_V25|RBM_V26|RBM_V27|RBM_V28|RBM_V29|RBM_V30|RBM_V31)
7777

78+
#define RBM_LOWMASK (RBM_P0|RBM_P1|RBM_P2|RBM_P3|RBM_P4|RBM_P5|RBM_P6|RBM_P7)
79+
#define RBM_HIGHMASK (RBM_P8|RBM_P9|RBM_P10| RBM_P11|RBM_P12|RBM_P13|RBM_P14|RBM_P15)
80+
#define RBM_ALLMASK (RBM_LOWMASK|RBM_HIGHMASK)
81+
82+
#define RBM_MSK_CALLEE_SAVED (0)
83+
#define RBM_MSK_CALLEE_TRASH RBM_ALLMASK
84+
7885
#define RBM_CALLEE_SAVED (RBM_INT_CALLEE_SAVED | RBM_FLT_CALLEE_SAVED)
79-
#define RBM_CALLEE_TRASH (RBM_INT_CALLEE_TRASH | RBM_FLT_CALLEE_TRASH)
86+
#define RBM_CALLEE_TRASH (RBM_INT_CALLEE_TRASH | RBM_FLT_CALLEE_TRASH | RBM_MSK_CALLEE_TRASH)
8087

8188
#define REG_DEFAULT_HELPER_CALL_TARGET REG_R12
8289
#define RBM_DEFAULT_HELPER_CALL_TARGET RBM_R12
@@ -146,14 +153,6 @@
146153
#define REG_JUMP_THUNK_PARAM REG_R12
147154
#define RBM_JUMP_THUNK_PARAM RBM_R12
148155

149-
#define RBM_LOWMASK (RBM_P0 | RBM_P1 | RBM_P2 | RBM_P3 | RBM_P4 | RBM_P5 | RBM_P6 | RBM_P7)
150-
#define RBM_HIGHMASK (RBM_P8 | RBM_P9 | RBM_P10 | RBM_P11 | RBM_P12 | RBM_P13 | RBM_P14 | RBM_P15)
151-
#define RBM_ALLMASK (RBM_LOWMASK | RBM_HIGHMASK)
152-
153-
// TODO-SVE: Fix when adding predicate register allocation
154-
#define RBM_MSK_CALLEE_SAVED (0)
155-
#define RBM_MSK_CALLEE_TRASH (0)
156-
157156
// ARM64 write barrier ABI (see vm\arm64\asmhelpers.asm, vm\arm64\asmhelpers.S):
158157
// CORINFO_HELP_ASSIGN_REF (JIT_WriteBarrier), CORINFO_HELP_CHECKED_ASSIGN_REF (JIT_CheckedWriteBarrier):
159158
// On entry:

0 commit comments

Comments
 (0)