Skip to content

Commit 3a294ed

Browse files
authored
JIT: Added SVE APIs CreateMaskForFirstActiveElement and CreateMaskForNextActiveElement (#104002)
* Initial work * Added tests. Fixed parameter names. * Use delay free for op1 if the target preference is op2. Use sve_mov instead of mov. * Feedback * Feedback * Update Helpers.cs * Handle RMW for non-explicit masked operation * Remove handling as its already handled it looks like * Feedback * Feedback * Feedback
1 parent 0746dd3 commit 3a294ed

File tree

10 files changed

+454
-4
lines changed

10 files changed

+454
-4
lines changed

src/coreclr/jit/hwintrinsic.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1922,6 +1922,8 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
19221922
break;
19231923
}
19241924

1925+
case NI_Sve_CreateMaskForFirstActiveElement:
1926+
case NI_Sve_CreateMaskForNextActiveElement:
19251927
case NI_Sve_GetActiveElementCount:
19261928
case NI_Sve_TestAnyTrue:
19271929
case NI_Sve_TestFirstTrue:

src/coreclr/jit/hwintrinsiccodegenarm64.cpp

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -867,7 +867,23 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
867867
assert(!node->IsEmbMaskOp());
868868
if (HWIntrinsicInfo::IsExplicitMaskedOperation(intrin.id))
869869
{
870-
GetEmitter()->emitIns_R_R_R(ins, emitSize, targetReg, op1Reg, op2Reg, opt);
870+
if (isRMW)
871+
{
872+
if (targetReg != op2Reg)
873+
{
874+
assert(targetReg != op1Reg);
875+
876+
GetEmitter()->emitIns_Mov(ins_Move_Extend(intrin.op2->TypeGet(), false),
877+
emitTypeSize(node), targetReg, op2Reg,
878+
/* canSkip */ true);
879+
}
880+
881+
GetEmitter()->emitIns_R_R(ins, emitSize, targetReg, op1Reg, opt);
882+
}
883+
else
884+
{
885+
GetEmitter()->emitIns_R_R_R(ins, emitSize, targetReg, op1Reg, op2Reg, opt);
886+
}
871887
}
872888
else
873889
{
@@ -2211,6 +2227,21 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
22112227
break;
22122228
}
22132229

2230+
case NI_Sve_CreateMaskForFirstActiveElement:
2231+
{
2232+
assert(isRMW);
2233+
assert(HWIntrinsicInfo::IsExplicitMaskedOperation(intrin.id));
2234+
2235+
if (targetReg != op2Reg)
2236+
{
2237+
assert(targetReg != op1Reg);
2238+
GetEmitter()->emitIns_Mov(INS_sve_mov, emitTypeSize(node), targetReg, op2Reg, /* canSkip */ true);
2239+
}
2240+
2241+
GetEmitter()->emitIns_R_R(ins, emitSize, targetReg, op1Reg, INS_OPTS_SCALABLE_B);
2242+
break;
2243+
}
2244+
22142245
default:
22152246
unreached();
22162247
}

src/coreclr/jit/hwintrinsiclistarm64sve.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ HARDWARE_INTRINSIC(Sve, CreateFalseMaskSingle,
4747
HARDWARE_INTRINSIC(Sve, CreateFalseMaskUInt16, -1, 0, false, {INS_invalid, INS_invalid, INS_invalid, INS_sve_pfalse, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_ReturnsPerElementMask)
4848
HARDWARE_INTRINSIC(Sve, CreateFalseMaskUInt32, -1, 0, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_pfalse, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_ReturnsPerElementMask)
4949
HARDWARE_INTRINSIC(Sve, CreateFalseMaskUInt64, -1, 0, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_pfalse, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_ReturnsPerElementMask)
50+
HARDWARE_INTRINSIC(Sve, CreateMaskForFirstActiveElement, -1, 2, true, {INS_sve_pfirst, INS_sve_pfirst, INS_sve_pfirst, INS_sve_pfirst, INS_sve_pfirst, INS_sve_pfirst, INS_sve_pfirst, INS_sve_pfirst, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_ExplicitMaskedOperation|HW_Flag_ReturnsPerElementMask|HW_Flag_SpecialCodeGen|HW_Flag_HasRMWSemantics)
51+
HARDWARE_INTRINSIC(Sve, CreateMaskForNextActiveElement, -1, 2, true, {INS_invalid, INS_sve_pnext, INS_invalid, INS_sve_pnext, INS_invalid, INS_sve_pnext, INS_invalid, INS_sve_pnext, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_ExplicitMaskedOperation|HW_Flag_ReturnsPerElementMask|HW_Flag_HasRMWSemantics)
5052
HARDWARE_INTRINSIC(Sve, CreateTrueMaskByte, -1, 1, false, {INS_invalid, INS_sve_ptrue, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_HasEnumOperand|HW_Flag_ReturnsPerElementMask)
5153
HARDWARE_INTRINSIC(Sve, CreateTrueMaskDouble, -1, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_ptrue}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_HasEnumOperand|HW_Flag_ReturnsPerElementMask)
5254
HARDWARE_INTRINSIC(Sve, CreateTrueMaskInt16, -1, 1, false, {INS_invalid, INS_invalid, INS_sve_ptrue, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_HasEnumOperand|HW_Flag_ReturnsPerElementMask)

src/coreclr/jit/instr.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1721,7 +1721,6 @@ instruction CodeGen::ins_Move_Extend(var_types srcType, bool srcInReg)
17211721
#if defined(TARGET_XARCH)
17221722
return INS_kmovq_msk;
17231723
#elif defined(TARGET_ARM64)
1724-
unreached(); // TODO-SVE: This needs testing
17251724
return INS_sve_mov;
17261725
#endif
17271726
}
@@ -2085,7 +2084,6 @@ instruction CodeGen::ins_Copy(regNumber srcReg, var_types dstType)
20852084
#if defined(TARGET_XARCH)
20862085
return INS_kmovq_gpr;
20872086
#elif defined(TARGET_ARM64)
2088-
unreached(); // TODO-SVE: This needs testing
20892087
return INS_sve_mov;
20902088
#endif
20912089
}

src/coreclr/jit/lsraarm64.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1627,7 +1627,14 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou
16271627
predMask = RBM_LOWMASK.GetPredicateRegSet();
16281628
}
16291629

1630-
srcCount += BuildOperandUses(intrin.op1, predMask);
1630+
if (tgtPrefOp2)
1631+
{
1632+
srcCount += BuildDelayFreeUses(intrin.op1, intrin.op2, predMask);
1633+
}
1634+
else
1635+
{
1636+
srcCount += BuildOperandUses(intrin.op1, predMask);
1637+
}
16311638
}
16321639
}
16331640
else if (intrinsicTree->OperIsMemoryLoadOrStore())

src/libraries/System.Private.CoreLib/src/System/Runtime/Intrinsics/Arm/Sve.PlatformNotSupported.cs

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,6 +1016,79 @@ internal Arm64() { }
10161016
public static unsafe Vector<ulong> CreateFalseMaskUInt64() { throw new PlatformNotSupportedException(); }
10171017

10181018

1019+
/// <summary>
1020+
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
1021+
/// PFIRST Ptied.B, Pg, Ptied.B
1022+
/// </summary>
1023+
public static unsafe Vector<byte> CreateMaskForFirstActiveElement(Vector<byte> mask, Vector<byte> srcMask) { throw new PlatformNotSupportedException(); }
1024+
1025+
/// <summary>
1026+
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
1027+
/// PFIRST Ptied.B, Pg, Ptied.B
1028+
/// </summary>
1029+
public static unsafe Vector<short> CreateMaskForFirstActiveElement(Vector<short> mask, Vector<short> srcMask) { throw new PlatformNotSupportedException(); }
1030+
1031+
/// <summary>
1032+
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
1033+
/// PFIRST Ptied.B, Pg, Ptied.B
1034+
/// </summary>
1035+
public static unsafe Vector<int> CreateMaskForFirstActiveElement(Vector<int> mask, Vector<int> srcMask) { throw new PlatformNotSupportedException(); }
1036+
1037+
/// <summary>
1038+
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
1039+
/// PFIRST Ptied.B, Pg, Ptied.B
1040+
/// </summary>
1041+
public static unsafe Vector<long> CreateMaskForFirstActiveElement(Vector<long> mask, Vector<long> srcMask) { throw new PlatformNotSupportedException(); }
1042+
1043+
/// <summary>
1044+
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
1045+
/// PFIRST Ptied.B, Pg, Ptied.B
1046+
/// </summary>
1047+
public static unsafe Vector<sbyte> CreateMaskForFirstActiveElement(Vector<sbyte> mask, Vector<sbyte> srcMask) { throw new PlatformNotSupportedException(); }
1048+
1049+
/// <summary>
1050+
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
1051+
/// PFIRST Ptied.B, Pg, Ptied.B
1052+
/// </summary>
1053+
public static unsafe Vector<ushort> CreateMaskForFirstActiveElement(Vector<ushort> mask, Vector<ushort> srcMask) { throw new PlatformNotSupportedException(); }
1054+
1055+
/// <summary>
1056+
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
1057+
/// PFIRST Ptied.B, Pg, Ptied.B
1058+
/// </summary>
1059+
public static unsafe Vector<uint> CreateMaskForFirstActiveElement(Vector<uint> mask, Vector<uint> srcMask) { throw new PlatformNotSupportedException(); }
1060+
1061+
/// <summary>
1062+
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
1063+
/// PFIRST Ptied.B, Pg, Ptied.B
1064+
/// </summary>
1065+
public static unsafe Vector<ulong> CreateMaskForFirstActiveElement(Vector<ulong> mask, Vector<ulong> srcMask) { throw new PlatformNotSupportedException(); }
1066+
1067+
/// <summary>
1068+
/// svbool_t svpnext_b8(svbool_t pg, svbool_t op)
1069+
/// PNEXT Ptied.B, Pg, Ptied.B
1070+
/// </summary>
1071+
public static unsafe Vector<byte> CreateMaskForNextActiveElement(Vector<byte> mask, Vector<byte> srcMask) { throw new PlatformNotSupportedException(); }
1072+
1073+
/// <summary>
1074+
/// svbool_t svpnext_b16(svbool_t pg, svbool_t op)
1075+
/// PNEXT Ptied.H, Pg, Ptied.H
1076+
/// </summary>
1077+
public static unsafe Vector<ushort> CreateMaskForNextActiveElement(Vector<ushort> mask, Vector<ushort> srcMask) { throw new PlatformNotSupportedException(); }
1078+
1079+
/// <summary>
1080+
/// svbool_t svpnext_b32(svbool_t pg, svbool_t op)
1081+
/// PNEXT Ptied.S, Pg, Ptied.S
1082+
/// </summary>
1083+
public static unsafe Vector<uint> CreateMaskForNextActiveElement(Vector<uint> mask, Vector<uint> srcMask) { throw new PlatformNotSupportedException(); }
1084+
1085+
/// <summary>
1086+
/// svbool_t svpnext_b64(svbool_t pg, svbool_t op)
1087+
/// PNEXT Ptied.D, Pg, Ptied.D
1088+
/// </summary>
1089+
public static unsafe Vector<ulong> CreateMaskForNextActiveElement(Vector<ulong> mask, Vector<ulong> srcMask) { throw new PlatformNotSupportedException(); }
1090+
1091+
10191092
/// CreateTrueMaskByte : Set predicate elements to true
10201093

10211094
/// <summary>

src/libraries/System.Private.CoreLib/src/System/Runtime/Intrinsics/Arm/Sve.cs

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,6 +1073,79 @@ internal Arm64() { }
10731073
public static unsafe Vector<ulong> CreateFalseMaskUInt64() => CreateFalseMaskUInt64();
10741074

10751075

1076+
/// <summary>
1077+
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
1078+
/// PFIRST Ptied.B, Pg, Ptied.B
1079+
/// </summary>
1080+
public static unsafe Vector<byte> CreateMaskForFirstActiveElement(Vector<byte> mask, Vector<byte> srcMask) => CreateMaskForFirstActiveElement(mask, srcMask);
1081+
1082+
/// <summary>
1083+
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
1084+
/// PFIRST Ptied.B, Pg, Ptied.B
1085+
/// </summary>
1086+
public static unsafe Vector<short> CreateMaskForFirstActiveElement(Vector<short> mask, Vector<short> srcMask) => CreateMaskForFirstActiveElement(mask, srcMask);
1087+
1088+
/// <summary>
1089+
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
1090+
/// PFIRST Ptied.B, Pg, Ptied.B
1091+
/// </summary>
1092+
public static unsafe Vector<int> CreateMaskForFirstActiveElement(Vector<int> mask, Vector<int> srcMask) => CreateMaskForFirstActiveElement(mask, srcMask);
1093+
1094+
/// <summary>
1095+
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
1096+
/// PFIRST Ptied.B, Pg, Ptied.B
1097+
/// </summary>
1098+
public static unsafe Vector<long> CreateMaskForFirstActiveElement(Vector<long> mask, Vector<long> srcMask) => CreateMaskForFirstActiveElement(mask, srcMask);
1099+
1100+
/// <summary>
1101+
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
1102+
/// PFIRST Ptied.B, Pg, Ptied.B
1103+
/// </summary>
1104+
public static unsafe Vector<sbyte> CreateMaskForFirstActiveElement(Vector<sbyte> mask, Vector<sbyte> srcMask) => CreateMaskForFirstActiveElement(mask, srcMask);
1105+
1106+
/// <summary>
1107+
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
1108+
/// PFIRST Ptied.B, Pg, Ptied.B
1109+
/// </summary>
1110+
public static unsafe Vector<ushort> CreateMaskForFirstActiveElement(Vector<ushort> mask, Vector<ushort> srcMask) => CreateMaskForFirstActiveElement(mask, srcMask);
1111+
1112+
/// <summary>
1113+
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
1114+
/// PFIRST Ptied.B, Pg, Ptied.B
1115+
/// </summary>
1116+
public static unsafe Vector<uint> CreateMaskForFirstActiveElement(Vector<uint> mask, Vector<uint> srcMask) => CreateMaskForFirstActiveElement(mask, srcMask);
1117+
1118+
/// <summary>
1119+
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
1120+
/// PFIRST Ptied.B, Pg, Ptied.B
1121+
/// </summary>
1122+
public static unsafe Vector<ulong> CreateMaskForFirstActiveElement(Vector<ulong> mask, Vector<ulong> srcMask) => CreateMaskForFirstActiveElement(mask, srcMask);
1123+
1124+
/// <summary>
1125+
/// svbool_t svpnext_b8(svbool_t pg, svbool_t op)
1126+
/// PNEXT Ptied.B, Pg, Ptied.B
1127+
/// </summary>
1128+
public static unsafe Vector<byte> CreateMaskForNextActiveElement(Vector<byte> mask, Vector<byte> srcMask) => CreateMaskForNextActiveElement(mask, srcMask);
1129+
1130+
/// <summary>
1131+
/// svbool_t svpnext_b16(svbool_t pg, svbool_t op)
1132+
/// PNEXT Ptied.H, Pg, Ptied.H
1133+
/// </summary>
1134+
public static unsafe Vector<ushort> CreateMaskForNextActiveElement(Vector<ushort> mask, Vector<ushort> srcMask) => CreateMaskForNextActiveElement(mask, srcMask);
1135+
1136+
/// <summary>
1137+
/// svbool_t svpnext_b32(svbool_t pg, svbool_t op)
1138+
/// PNEXT Ptied.S, Pg, Ptied.S
1139+
/// </summary>
1140+
public static unsafe Vector<uint> CreateMaskForNextActiveElement(Vector<uint> mask, Vector<uint> srcMask) => CreateMaskForNextActiveElement(mask, srcMask);
1141+
1142+
/// <summary>
1143+
/// svbool_t svpnext_b64(svbool_t pg, svbool_t op)
1144+
/// PNEXT Ptied.D, Pg, Ptied.D
1145+
/// </summary>
1146+
public static unsafe Vector<ulong> CreateMaskForNextActiveElement(Vector<ulong> mask, Vector<ulong> srcMask) => CreateMaskForNextActiveElement(mask, srcMask);
1147+
1148+
10761149
/// CreateTrueMaskByte : Set predicate elements to true
10771150

10781151
/// <summary>

src/libraries/System.Runtime.Intrinsics/ref/System.Runtime.Intrinsics.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4335,6 +4335,20 @@ internal Arm64() { }
43354335
public static System.Numerics.Vector<ushort> CreateFalseMaskUInt16() { throw null; }
43364336
public static System.Numerics.Vector<uint> CreateFalseMaskUInt32() { throw null; }
43374337
public static System.Numerics.Vector<ulong> CreateFalseMaskUInt64() { throw null; }
4338+
4339+
public static unsafe System.Numerics.Vector<byte> CreateMaskForFirstActiveElement(System.Numerics.Vector<byte> mask, System.Numerics.Vector<byte> srcMask) { throw null; }
4340+
public static unsafe System.Numerics.Vector<short> CreateMaskForFirstActiveElement(System.Numerics.Vector<short> mask, System.Numerics.Vector<short> srcMask) { throw null; }
4341+
public static unsafe System.Numerics.Vector<int> CreateMaskForFirstActiveElement(System.Numerics.Vector<int> mask, System.Numerics.Vector<int> srcMask) { throw null; }
4342+
public static unsafe System.Numerics.Vector<long> CreateMaskForFirstActiveElement(System.Numerics.Vector<long> mask, System.Numerics.Vector<long> srcMask) { throw null; }
4343+
public static unsafe System.Numerics.Vector<sbyte> CreateMaskForFirstActiveElement(System.Numerics.Vector<sbyte> mask, System.Numerics.Vector<sbyte> srcMask) { throw null; }
4344+
public static unsafe System.Numerics.Vector<ushort> CreateMaskForFirstActiveElement(System.Numerics.Vector<ushort> mask, System.Numerics.Vector<ushort> srcMask) { throw null; }
4345+
public static unsafe System.Numerics.Vector<uint> CreateMaskForFirstActiveElement(System.Numerics.Vector<uint> mask, System.Numerics.Vector<uint> srcMask) { throw null; }
4346+
public static unsafe System.Numerics.Vector<ulong> CreateMaskForFirstActiveElement(System.Numerics.Vector<ulong> mask, System.Numerics.Vector<ulong> srcMask) { throw null; }
4347+
public static unsafe System.Numerics.Vector<byte> CreateMaskForNextActiveElement(System.Numerics.Vector<byte> mask, System.Numerics.Vector<byte> srcMask) { throw null; }
4348+
public static unsafe System.Numerics.Vector<ushort> CreateMaskForNextActiveElement(System.Numerics.Vector<ushort> mask, System.Numerics.Vector<ushort> srcMask) { throw null; }
4349+
public static unsafe System.Numerics.Vector<uint> CreateMaskForNextActiveElement(System.Numerics.Vector<uint> mask, System.Numerics.Vector<uint> srcMask) { throw null; }
4350+
public static unsafe System.Numerics.Vector<ulong> CreateMaskForNextActiveElement(System.Numerics.Vector<ulong> mask, System.Numerics.Vector<ulong> srcMask) { throw null; }
4351+
43384352
public static System.Numerics.Vector<byte> CreateTrueMaskByte([ConstantExpected] SveMaskPattern pattern = SveMaskPattern.All) { throw null; }
43394353
public static System.Numerics.Vector<double> CreateTrueMaskDouble([ConstantExpected] SveMaskPattern pattern = SveMaskPattern.All) { throw null; }
43404354
public static System.Numerics.Vector<short> CreateTrueMaskInt16([ConstantExpected] SveMaskPattern pattern = SveMaskPattern.All) { throw null; }

0 commit comments

Comments
 (0)