Skip to content

Commit 89d63f9

Browse files
ARM64-SVE: Add ShiftRightArithmeticForDivide (#104279)
1 parent 4f96b8f commit 89d63f9

File tree

8 files changed

+142
-13
lines changed

8 files changed

+142
-13
lines changed

src/coreclr/jit/hwintrinsiccodegenarm64.cpp

Lines changed: 28 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -588,7 +588,8 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
588588
{
589589
assert(instrIsRMW);
590590

591-
insScalableOpts sopt;
591+
insScalableOpts sopt = INS_SCALABLE_OPTS_NONE;
592+
bool hasShift = false;
592593

593594
switch (intrinEmbMask.id)
594595
{
@@ -601,17 +602,34 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
601602
{
602603
assert(emitter::optGetSveInsOpt(op2Size) == INS_OPTS_SCALABLE_D);
603604
sopt = INS_SCALABLE_OPTS_WIDE;
604-
break;
605605
}
606-
607-
FALLTHROUGH;
606+
break;
608607
}
609608

609+
case NI_Sve_ShiftRightArithmeticForDivide:
610+
hasShift = true;
611+
break;
612+
610613
default:
611-
sopt = INS_SCALABLE_OPTS_NONE;
612614
break;
613615
}
614616

617+
auto emitInsHelper = [&](regNumber reg1, regNumber reg2, regNumber reg3) {
618+
if (hasShift)
619+
{
620+
HWIntrinsicImmOpHelper helper(this, intrinEmbMask.op2, op2->AsHWIntrinsic());
621+
for (helper.EmitBegin(); !helper.Done(); helper.EmitCaseEnd())
622+
{
623+
GetEmitter()->emitInsSve_R_R_I(insEmbMask, emitSize, reg1, reg2, helper.ImmValue(), opt,
624+
sopt);
625+
}
626+
}
627+
else
628+
{
629+
GetEmitter()->emitIns_R_R_R(insEmbMask, emitSize, reg1, reg2, reg3, opt, sopt);
630+
}
631+
};
632+
615633
if (intrin.op3->IsVectorZero())
616634
{
617635
// If `falseReg` is zero, then move the first operand of `intrinEmbMask` in the
@@ -622,7 +640,7 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
622640

623641
// Finally, perform the actual "predicated" operation so that `targetReg` is the first operand
624642
// and `embMaskOp2Reg` is the second operand.
625-
GetEmitter()->emitIns_R_R_R(insEmbMask, emitSize, targetReg, maskReg, embMaskOp2Reg, opt, sopt);
643+
emitInsHelper(targetReg, maskReg, embMaskOp2Reg);
626644
}
627645
else if (targetReg != falseReg)
628646
{
@@ -636,8 +654,7 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
636654
{
637655
// If the embedded instruction supports optional mask operation, use the "unpredicated"
638656
// version of the instruction, followed by "sel" to select the active lanes.
639-
GetEmitter()->emitIns_R_R_R(insEmbMask, emitSize, targetReg, embMaskOp1Reg,
640-
embMaskOp2Reg, opt, sopt);
657+
emitInsHelper(targetReg, embMaskOp1Reg, embMaskOp2Reg);
641658
}
642659
else
643660
{
@@ -651,8 +668,7 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
651668

652669
GetEmitter()->emitIns_R_R(INS_sve_movprfx, EA_SCALABLE, targetReg, embMaskOp1Reg);
653670

654-
GetEmitter()->emitIns_R_R_R(insEmbMask, emitSize, targetReg, maskReg, embMaskOp2Reg,
655-
opt, sopt);
671+
emitInsHelper(targetReg, maskReg, embMaskOp2Reg);
656672
}
657673

658674
GetEmitter()->emitIns_R_R_R_R(INS_sve_sel, emitSize, targetReg, maskReg, targetReg,
@@ -669,13 +685,13 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
669685

670686
// Finally, perform the actual "predicated" operation so that `targetReg` is the first operand
671687
// and `embMaskOp2Reg` is the second operand.
672-
GetEmitter()->emitIns_R_R_R(insEmbMask, emitSize, targetReg, maskReg, embMaskOp2Reg, opt, sopt);
688+
emitInsHelper(targetReg, maskReg, embMaskOp2Reg);
673689
}
674690
else
675691
{
676692
// Just perform the actual "predicated" operation so that `targetReg` is the first operand
677693
// and `embMaskOp2Reg` is the second operand.
678-
GetEmitter()->emitIns_R_R_R(insEmbMask, emitSize, targetReg, maskReg, embMaskOp2Reg, opt, sopt);
694+
emitInsHelper(targetReg, maskReg, embMaskOp2Reg);
679695
}
680696

681697
break;

src/coreclr/jit/hwintrinsiclistarm64sve.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -209,6 +209,7 @@ HARDWARE_INTRINSIC(Sve, SaturatingIncrementByActiveElementCount,
209209
HARDWARE_INTRINSIC(Sve, Scale, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fscale, INS_sve_fscale}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_HasRMWSemantics)
210210
HARDWARE_INTRINSIC(Sve, ShiftLeftLogical, -1, -1, false, {INS_sve_lsl, INS_sve_lsl, INS_sve_lsl, INS_sve_lsl, INS_sve_lsl, INS_sve_lsl, INS_sve_lsl, INS_sve_lsl, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_HasRMWSemantics)
211211
HARDWARE_INTRINSIC(Sve, ShiftRightArithmetic, -1, -1, false, {INS_sve_asr, INS_invalid, INS_sve_asr, INS_invalid, INS_sve_asr, INS_invalid, INS_sve_asr, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_HasRMWSemantics)
212+
HARDWARE_INTRINSIC(Sve, ShiftRightArithmeticForDivide, -1, -1, false, {INS_sve_asrd, INS_invalid, INS_sve_asrd, INS_invalid, INS_sve_asrd, INS_invalid, INS_sve_asrd, INS_invalid, INS_invalid, INS_invalid}, HW_Category_ShiftRightByImmediate, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_HasImmediateOperand)
212213
HARDWARE_INTRINSIC(Sve, ShiftRightLogical, -1, -1, false, {INS_invalid, INS_sve_lsr, INS_invalid, INS_sve_lsr, INS_invalid, INS_sve_lsr, INS_invalid, INS_sve_lsr, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_HasRMWSemantics)
213214
HARDWARE_INTRINSIC(Sve, SignExtend16, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_sxth, INS_invalid, INS_sve_sxth, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation)
214215
HARDWARE_INTRINSIC(Sve, SignExtend32, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_sxtw, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation)

src/coreclr/jit/lowerarmarch.cpp

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3371,6 +3371,8 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node)
33713371
// Handle op2
33723372
if (op2->OperIsHWIntrinsic())
33733373
{
3374+
const GenTreeHWIntrinsic* embOp = op2->AsHWIntrinsic();
3375+
33743376
if (IsInvariantInRange(op2, node) && op2->isEmbeddedMaskingCompatibleHWIntrinsic())
33753377
{
33763378
uint32_t maskSize = genTypeSize(node->GetSimdBaseType());
@@ -3386,7 +3388,6 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node)
33863388
{
33873389
// Else check if this operation has an auxiliary type that matches the
33883390
// mask size.
3389-
GenTreeHWIntrinsic* embOp = op2->AsHWIntrinsic();
33903391

33913392
// For now, make sure that we get here only for intrinsics that we are
33923393
// sure about to rely on auxiliary type's size.
@@ -3403,6 +3404,17 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node)
34033404
}
34043405
}
34053406
}
3407+
3408+
// Handle intrinsics with embedded masks and immediate operands
3409+
// (For now, just handle ShiftRightArithmeticForDivide specifically)
3410+
if (embOp->GetHWIntrinsicId() == NI_Sve_ShiftRightArithmeticForDivide)
3411+
{
3412+
assert(embOp->GetOperandCount() == 2);
3413+
if (embOp->Op(2)->IsCnsIntOrI())
3414+
{
3415+
MakeSrcContained(op2, embOp->Op(2));
3416+
}
3417+
}
34063418
}
34073419

34083420
// Handle op3

src/coreclr/jit/lsraarm64.cpp

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1920,6 +1920,19 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou
19201920
else
19211921
{
19221922
assert((numArgs == 1) || (numArgs == 2) || (numArgs == 3));
1923+
1924+
// Special handling for ShiftRightArithmeticForDivide:
1925+
// We might need an additional register to hold branch targets into the switch table
1926+
// that encodes the immediate
1927+
if (intrinEmb.id == NI_Sve_ShiftRightArithmeticForDivide)
1928+
{
1929+
assert(embOp2Node->GetOperandCount() == 2);
1930+
if (!embOp2Node->Op(2)->isContainedIntOrIImmed())
1931+
{
1932+
buildInternalIntRegisterDefForNode(embOp2Node);
1933+
}
1934+
}
1935+
19231936
tgtPrefUse = BuildUse(embOp2Node->Op(1));
19241937
srcCount += 1;
19251938

src/libraries/System.Private.CoreLib/src/System/Runtime/Intrinsics/Arm/Sve.PlatformNotSupported.cs

Lines changed: 39 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6351,6 +6351,45 @@ internal Arm64() { }
63516351
public static unsafe Vector<sbyte> ShiftRightArithmetic(Vector<sbyte> left, Vector<ulong> right) { throw new PlatformNotSupportedException(); }
63526352

63536353

6354+
/// Arithmetic shift right for divide by immediate
6355+
6356+
/// <summary>
6357+
/// svint16_t svasrd[_n_s16]_m(svbool_t pg, svint16_t op1, uint64_t imm2)
6358+
/// ASRD Ztied1.H, Pg/M, Ztied1.H, #imm2
6359+
/// svint16_t svasrd[_n_s16]_x(svbool_t pg, svint16_t op1, uint64_t imm2)
6360+
/// ASRD Ztied1.H, Pg/M, Ztied1.H, #imm2
6361+
/// svint16_t svasrd[_n_s16]_z(svbool_t pg, svint16_t op1, uint64_t imm2)
6362+
/// </summary>
6363+
public static unsafe Vector<short> ShiftRightArithmeticForDivide(Vector<short> value, [ConstantExpected(Min = 1, Max = (byte)(16))] byte control) { throw new PlatformNotSupportedException(); }
6364+
6365+
/// <summary>
6366+
/// svint32_t svasrd[_n_s32]_m(svbool_t pg, svint32_t op1, uint64_t imm2)
6367+
/// ASRD Ztied1.S, Pg/M, Ztied1.S, #imm2
6368+
/// svint32_t svasrd[_n_s32]_x(svbool_t pg, svint32_t op1, uint64_t imm2)
6369+
/// ASRD Ztied1.S, Pg/M, Ztied1.S, #imm2
6370+
/// svint32_t svasrd[_n_s32]_z(svbool_t pg, svint32_t op1, uint64_t imm2)
6371+
/// </summary>
6372+
public static unsafe Vector<int> ShiftRightArithmeticForDivide(Vector<int> value, [ConstantExpected(Min = 1, Max = (byte)(32))] byte control) { throw new PlatformNotSupportedException(); }
6373+
6374+
/// <summary>
6375+
/// svint64_t svasrd[_n_s64]_m(svbool_t pg, svint64_t op1, uint64_t imm2)
6376+
/// ASRD Ztied1.D, Pg/M, Ztied1.D, #imm2
6377+
/// svint64_t svasrd[_n_s64]_x(svbool_t pg, svint64_t op1, uint64_t imm2)
6378+
/// ASRD Ztied1.D, Pg/M, Ztied1.D, #imm2
6379+
/// svint64_t svasrd[_n_s64]_z(svbool_t pg, svint64_t op1, uint64_t imm2)
6380+
/// </summary>
6381+
public static unsafe Vector<long> ShiftRightArithmeticForDivide(Vector<long> value, [ConstantExpected(Min = 1, Max = (byte)(64))] byte control) { throw new PlatformNotSupportedException(); }
6382+
6383+
/// <summary>
6384+
/// svint8_t svasrd[_n_s8]_m(svbool_t pg, svint8_t op1, uint64_t imm2)
6385+
/// ASRD Ztied1.B, Pg/M, Ztied1.B, #imm2
6386+
/// svint8_t svasrd[_n_s8]_x(svbool_t pg, svint8_t op1, uint64_t imm2)
6387+
/// ASRD Ztied1.B, Pg/M, Ztied1.B, #imm2
6388+
/// svint8_t svasrd[_n_s8]_z(svbool_t pg, svint8_t op1, uint64_t imm2)
6389+
/// </summary>
6390+
public static unsafe Vector<sbyte> ShiftRightArithmeticForDivide(Vector<sbyte> value, [ConstantExpected(Min = 1, Max = (byte)(8))] byte control) { throw new PlatformNotSupportedException(); }
6391+
6392+
63546393
/// Logical shift right
63556394

63566395
/// <summary>

src/libraries/System.Private.CoreLib/src/System/Runtime/Intrinsics/Arm/Sve.cs

Lines changed: 39 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6395,6 +6395,45 @@ internal Arm64() { }
63956395
public static unsafe Vector<sbyte> ShiftRightArithmetic(Vector<sbyte> left, Vector<ulong> right) => ShiftRightArithmetic(left, right);
63966396

63976397

6398+
/// Arithmetic shift right for divide by immediate
6399+
6400+
/// <summary>
6401+
/// svint16_t svasrd[_n_s16]_m(svbool_t pg, svint16_t op1, uint64_t imm2)
6402+
/// ASRD Ztied1.H, Pg/M, Ztied1.H, #imm2
6403+
/// svint16_t svasrd[_n_s16]_x(svbool_t pg, svint16_t op1, uint64_t imm2)
6404+
/// ASRD Ztied1.H, Pg/M, Ztied1.H, #imm2
6405+
/// svint16_t svasrd[_n_s16]_z(svbool_t pg, svint16_t op1, uint64_t imm2)
6406+
/// </summary>
6407+
public static unsafe Vector<short> ShiftRightArithmeticForDivide(Vector<short> value, [ConstantExpected(Min = 1, Max = (byte)(16))] byte control) => ShiftRightArithmeticForDivide(value, control);
6408+
6409+
/// <summary>
6410+
/// svint32_t svasrd[_n_s32]_m(svbool_t pg, svint32_t op1, uint64_t imm2)
6411+
/// ASRD Ztied1.S, Pg/M, Ztied1.S, #imm2
6412+
/// svint32_t svasrd[_n_s32]_x(svbool_t pg, svint32_t op1, uint64_t imm2)
6413+
/// ASRD Ztied1.S, Pg/M, Ztied1.S, #imm2
6414+
/// svint32_t svasrd[_n_s32]_z(svbool_t pg, svint32_t op1, uint64_t imm2)
6415+
/// </summary>
6416+
public static unsafe Vector<int> ShiftRightArithmeticForDivide(Vector<int> value, [ConstantExpected(Min = 1, Max = (byte)(32))] byte control) => ShiftRightArithmeticForDivide(value, control);
6417+
6418+
/// <summary>
6419+
/// svint64_t svasrd[_n_s64]_m(svbool_t pg, svint64_t op1, uint64_t imm2)
6420+
/// ASRD Ztied1.D, Pg/M, Ztied1.D, #imm2
6421+
/// svint64_t svasrd[_n_s64]_x(svbool_t pg, svint64_t op1, uint64_t imm2)
6422+
/// ASRD Ztied1.D, Pg/M, Ztied1.D, #imm2
6423+
/// svint64_t svasrd[_n_s64]_z(svbool_t pg, svint64_t op1, uint64_t imm2)
6424+
/// </summary>
6425+
public static unsafe Vector<long> ShiftRightArithmeticForDivide(Vector<long> value, [ConstantExpected(Min = 1, Max = (byte)(64))] byte control) => ShiftRightArithmeticForDivide(value, control);
6426+
6427+
/// <summary>
6428+
/// svint8_t svasrd[_n_s8]_m(svbool_t pg, svint8_t op1, uint64_t imm2)
6429+
/// ASRD Ztied1.B, Pg/M, Ztied1.B, #imm2
6430+
/// svint8_t svasrd[_n_s8]_x(svbool_t pg, svint8_t op1, uint64_t imm2)
6431+
/// ASRD Ztied1.B, Pg/M, Ztied1.B, #imm2
6432+
/// svint8_t svasrd[_n_s8]_z(svbool_t pg, svint8_t op1, uint64_t imm2)
6433+
/// </summary>
6434+
public static unsafe Vector<sbyte> ShiftRightArithmeticForDivide(Vector<sbyte> value, [ConstantExpected(Min = 1, Max = (byte)(8))] byte control) => ShiftRightArithmeticForDivide(value, control);
6435+
6436+
63986437
/// Logical shift right
63996438

64006439
/// <summary>

src/libraries/System.Runtime.Intrinsics/ref/System.Runtime.Intrinsics.cs

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5118,6 +5118,10 @@ internal Arm64() { }
51185118
public static System.Numerics.Vector<long> ShiftRightArithmetic(System.Numerics.Vector<long> left, System.Numerics.Vector<ulong> right) { throw null; }
51195119
public static System.Numerics.Vector<sbyte> ShiftRightArithmetic(System.Numerics.Vector<sbyte> left, System.Numerics.Vector<byte> right) { throw null; }
51205120
public static System.Numerics.Vector<sbyte> ShiftRightArithmetic(System.Numerics.Vector<sbyte> left, System.Numerics.Vector<ulong> right) { throw null; }
5121+
public static System.Numerics.Vector<short> ShiftRightArithmeticForDivide(System.Numerics.Vector<short> value, [ConstantExpected(Min = 1, Max = (byte)(16))] byte control) { throw null; }
5122+
public static System.Numerics.Vector<int> ShiftRightArithmeticForDivide(System.Numerics.Vector<int> value, [ConstantExpected(Min = 1, Max = (byte)(32))] byte control) { throw null; }
5123+
public static System.Numerics.Vector<long> ShiftRightArithmeticForDivide(System.Numerics.Vector<long> value, [ConstantExpected(Min = 1, Max = (byte)(64))] byte control) { throw null; }
5124+
public static System.Numerics.Vector<sbyte> ShiftRightArithmeticForDivide(System.Numerics.Vector<sbyte> value, [ConstantExpected(Min = 1, Max = (byte)(8))] byte control) { throw null; }
51215125
public static System.Numerics.Vector<byte> ShiftRightLogical(System.Numerics.Vector<byte> left, System.Numerics.Vector<byte> right) { throw null; }
51225126
public static System.Numerics.Vector<byte> ShiftRightLogical(System.Numerics.Vector<byte> left, System.Numerics.Vector<ulong> right) { throw null; }
51235127
public static System.Numerics.Vector<ushort> ShiftRightLogical(System.Numerics.Vector<ushort> left, System.Numerics.Vector<ushort> right) { throw null; }

0 commit comments

Comments
 (0)