Skip to content

Commit c52fd37

Browse files
Add support for Sve.Scatter() (#104555)
* Add support for Sve.Scatter()
* Fix XML formatting error
* Address review comments
1 parent e3d2dd1 commit c52fd37

File tree

10 files changed

+1332
-2
lines changed

10 files changed

+1332
-2
lines changed

src/coreclr/jit/gentree.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26845,6 +26845,10 @@ bool GenTreeHWIntrinsic::OperIsMemoryStore(GenTree** pAddr) const
2684526845
addr = Op(2);
2684626846
break;
2684726847

26848+
case NI_Sve_Scatter:
26849+
addr = Op(2);
26850+
break;
26851+
2684826852
#endif // TARGET_ARM64
2684926853

2685026854
default:
@@ -26886,7 +26890,11 @@ bool GenTreeHWIntrinsic::OperIsMemoryStore(GenTree** pAddr) const
2688626890

2688726891
if (addr != nullptr)
2688826892
{
26893+
#ifdef TARGET_ARM64
26894+
assert(varTypeIsI(addr) || (varTypeIsSIMD(addr) && ((intrinsicId >= NI_Sve_Scatter))));
26895+
#else
2688926896
assert(varTypeIsI(addr));
26897+
#endif
2689026898
return true;
2689126899
}
2689226900

src/coreclr/jit/hwintrinsic.cpp

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1869,7 +1869,11 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
18691869
case NI_Sve_GatherVectorUInt32ZeroExtend:
18701870
case NI_Sve_GatherVectorWithByteOffsets:
18711871
assert(varTypeIsSIMD(op3->TypeGet()));
1872-
retNode->AsHWIntrinsic()->SetAuxiliaryJitType(getBaseJitTypeOfSIMDType(sigReader.op3ClsHnd));
1872+
if (numArgs == 3)
1873+
{
1874+
retNode->AsHWIntrinsic()->SetAuxiliaryJitType(
1875+
getBaseJitTypeOfSIMDType(sigReader.op3ClsHnd));
1876+
}
18731877
break;
18741878
#endif
18751879

@@ -1885,6 +1889,23 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
18851889
assert(!isScalar);
18861890
retNode =
18871891
gtNewSimdHWIntrinsicNode(nodeRetType, op1, op2, op3, op4, intrinsic, simdBaseJitType, simdSize);
1892+
1893+
switch (intrinsic)
1894+
{
1895+
#if defined(TARGET_ARM64)
1896+
case NI_Sve_Scatter:
1897+
assert(varTypeIsSIMD(op3->TypeGet()));
1898+
if (numArgs == 4)
1899+
{
1900+
retNode->AsHWIntrinsic()->SetAuxiliaryJitType(
1901+
getBaseJitTypeOfSIMDType(sigReader.op3ClsHnd));
1902+
}
1903+
break;
1904+
#endif
1905+
1906+
default:
1907+
break;
1908+
}
18881909
break;
18891910
}
18901911

src/coreclr/jit/hwintrinsiccodegenarm64.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2052,6 +2052,39 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
20522052
GetEmitter()->emitIns_R_R(ins, emitSize, targetReg, op1Reg, opt);
20532053
break;
20542054

2055+
case NI_Sve_Scatter:
2056+
{
2057+
if (!varTypeIsSIMD(intrin.op2->gtType))
2058+
{
2059+
// Scatter(Vector<T1> mask, T1* address, Vector<T2> indicies, Vector<T> data)
2060+
assert(intrin.numOperands == 4);
2061+
emitAttr baseSize = emitActualTypeSize(intrin.baseType);
2062+
2063+
if (baseSize == EA_8BYTE)
2064+
{
2065+
// Index is multiplied by 8
2066+
GetEmitter()->emitIns_R_R_R_R(ins, emitSize, op4Reg, op1Reg, op2Reg, op3Reg, opt,
2067+
INS_SCALABLE_OPTS_LSL_N);
2068+
}
2069+
else
2070+
{
2071+
// Index is sign- or zero-extended to 64 bits, then multiplied by 4
2072+
assert(baseSize == EA_4BYTE);
2073+
opt = varTypeIsUnsigned(node->GetAuxiliaryType()) ? INS_OPTS_SCALABLE_S_UXTW
2074+
: INS_OPTS_SCALABLE_S_SXTW;
2075+
GetEmitter()->emitIns_R_R_R_R(ins, emitSize, op4Reg, op1Reg, op2Reg, op3Reg, opt,
2076+
INS_SCALABLE_OPTS_MOD_N);
2077+
}
2078+
}
2079+
else
2080+
{
2081+
// Scatter(Vector<T> mask, Vector<T> addresses, Vector<T> data)
2082+
assert(intrin.numOperands == 3);
2083+
GetEmitter()->emitIns_R_R_R_I(ins, emitSize, op3Reg, op1Reg, op2Reg, 0, opt);
2084+
}
2085+
break;
2086+
}
2087+
20552088
case NI_Sve_StoreNarrowing:
20562089
opt = emitter::optGetSveInsOpt(emitTypeSize(intrin.baseType));
20572090
GetEmitter()->emitIns_R_R_R_I(ins, emitSize, op3Reg, op1Reg, op2Reg, 0, opt);

src/coreclr/jit/hwintrinsiclistarm64sve.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ HARDWARE_INTRINSIC(Sve, SaturatingIncrementBy64BitElementCount,
222222
HARDWARE_INTRINSIC(Sve, SaturatingIncrementBy8BitElementCount, 0, 3, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_sqincb, INS_sve_uqincb, INS_sve_sqincb, INS_sve_uqincb, INS_invalid, INS_invalid}, HW_Category_Scalar, HW_Flag_Scalable|HW_Flag_HasImmediateOperand|HW_Flag_HasEnumOperand|HW_Flag_SpecialCodeGen|HW_Flag_SpecialImport|HW_Flag_HasRMWSemantics)
223223
HARDWARE_INTRINSIC(Sve, SaturatingIncrementByActiveElementCount, -1, 2, true, {INS_invalid, INS_sve_sqincp, INS_sve_sqincp, INS_sve_sqincp, INS_sve_sqincp, INS_sve_sqincp, INS_sve_sqincp, INS_sve_sqincp, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialCodeGen|HW_Flag_SpecialImport|HW_Flag_BaseTypeFromSecondArg|HW_Flag_HasRMWSemantics)
224224
HARDWARE_INTRINSIC(Sve, Scale, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fscale, INS_sve_fscale}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_HasRMWSemantics)
225+
HARDWARE_INTRINSIC(Sve, Scatter, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_st1w, INS_sve_st1w, INS_sve_st1d, INS_sve_st1d, INS_sve_st1w, INS_sve_st1d}, HW_Category_MemoryStore, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_SpecialCodeGen|HW_Flag_ExplicitMaskedOperation|HW_Flag_LowMaskedOperation)
225226
HARDWARE_INTRINSIC(Sve, ShiftLeftLogical, -1, -1, false, {INS_sve_lsl, INS_sve_lsl, INS_sve_lsl, INS_sve_lsl, INS_sve_lsl, INS_sve_lsl, INS_sve_lsl, INS_sve_lsl, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_HasRMWSemantics)
226227
HARDWARE_INTRINSIC(Sve, ShiftRightArithmetic, -1, -1, false, {INS_sve_asr, INS_invalid, INS_sve_asr, INS_invalid, INS_sve_asr, INS_invalid, INS_sve_asr, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_HasRMWSemantics)
227228
HARDWARE_INTRINSIC(Sve, ShiftRightArithmeticForDivide, -1, -1, false, {INS_sve_asrd, INS_invalid, INS_sve_asrd, INS_invalid, INS_sve_asrd, INS_invalid, INS_sve_asrd, INS_invalid, INS_invalid, INS_invalid}, HW_Category_ShiftRightByImmediate, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_HasImmediateOperand)

src/libraries/System.Private.CoreLib/src/System/Runtime/Intrinsics/Arm/Sve.PlatformNotSupported.cs

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4175,7 +4175,7 @@ internal Arm64() { }
41754175
/// svuint8_t svcls[_s8]_z(svbool_t pg, svint8_t op)
41764176
/// CLS Ztied.B, Pg/M, Zop.B
41774177
/// </summary>
4178-
public static unsafe Vector<byte> LeadingSignCount(Vector<sbyte> value){ throw new PlatformNotSupportedException(); }
4178+
public static unsafe Vector<byte> LeadingSignCount(Vector<sbyte> value) { throw new PlatformNotSupportedException(); }
41794179

41804180
/// <summary>
41814181
/// svuint16_t svcls[_s16]_m(svuint16_t inactive, svbool_t pg, svint16_t op)
@@ -7144,6 +7144,120 @@ internal Arm64() { }
71447144
public static unsafe Vector<float> Scale(Vector<float> left, Vector<int> right) { throw new PlatformNotSupportedException(); }
71457145

71467146

7147+
// Non-truncating store
7148+
7149+
// <summary>
7150+
// void svst1_scatter_[s64]offset[_f64](svbool_t pg, float64_t *base, svint64_t offsets, svfloat64_t data)
7151+
// ST1D Zdata.D, Pg, [Xbase, Zoffsets.D]
7152+
// </summary>
7153+
public static unsafe void Scatter(Vector<double> mask, double* address, Vector<long> indicies, Vector<double> data) { throw new PlatformNotSupportedException(); }
7154+
7155+
// <summary>
7156+
// void svst1_scatter[_u64base_f64](svbool_t pg, svuint64_t bases, svfloat64_t data)
7157+
// ST1D Zdata.D, Pg, [Zbases.D, #0]
7158+
// </summary>
7159+
public static unsafe void Scatter(Vector<double> mask, Vector<ulong> addresses, Vector<double> data) { throw new PlatformNotSupportedException(); }
7160+
7161+
// <summary>
7162+
// void svst1_scatter_[u64]offset[_f64](svbool_t pg, float64_t *base, svuint64_t offsets, svfloat64_t data)
7163+
// ST1D Zdata.D, Pg, [Xbase, Zoffsets.D]
7164+
// </summary>
7165+
public static unsafe void Scatter(Vector<double> mask, double* address, Vector<ulong> indicies, Vector<double> data) { throw new PlatformNotSupportedException(); }
7166+
7167+
// <summary>
7168+
// void svst1_scatter_[s32]offset[_s32](svbool_t pg, int32_t *base, svint32_t offsets, svint32_t data)
7169+
// ST1W Zdata.S, Pg, [Xbase, Zoffsets.S, SXTW]
7170+
// </summary>
7171+
public static unsafe void Scatter(Vector<int> mask, int* address, Vector<int> indicies, Vector<int> data) { throw new PlatformNotSupportedException(); }
7172+
7173+
// <summary>
7174+
// void svst1_scatter[_u32base_s32](svbool_t pg, svuint32_t bases, svint32_t data)
7175+
// ST1W Zdata.S, Pg, [Zbases.S, #0]
7176+
// </summary>
7177+
// Removed as per #103297
7178+
// public static unsafe void Scatter(Vector<int> mask, Vector<uint> addresses, Vector<int> data) { throw new PlatformNotSupportedException(); }
7179+
7180+
// <summary>
7181+
// void svst1_scatter_[u32]offset[_s32](svbool_t pg, int32_t *base, svuint32_t offsets, svint32_t data)
7182+
// ST1W Zdata.S, Pg, [Xbase, Zoffsets.S, UXTW]
7183+
// </summary>
7184+
public static unsafe void Scatter(Vector<int> mask, int* address, Vector<uint> indicies, Vector<int> data) { throw new PlatformNotSupportedException(); }
7185+
7186+
// <summary>
7187+
// void svst1_scatter_[s64]offset[_s64](svbool_t pg, int64_t *base, svint64_t offsets, svint64_t data)
7188+
// ST1D Zdata.D, Pg, [Xbase, Zoffsets.D]
7189+
// </summary>
7190+
public static unsafe void Scatter(Vector<long> mask, long* address, Vector<long> indicies, Vector<long> data) { throw new PlatformNotSupportedException(); }
7191+
7192+
// <summary>
7193+
// void svst1_scatter[_u64base_s64](svbool_t pg, svuint64_t bases, svint64_t data)
7194+
// ST1D Zdata.D, Pg, [Zbases.D, #0]
7195+
// </summary>
7196+
public static unsafe void Scatter(Vector<long> mask, Vector<ulong> addresses, Vector<long> data) { throw new PlatformNotSupportedException(); }
7197+
7198+
// <summary>
7199+
// void svst1_scatter_[u64]offset[_s64](svbool_t pg, int64_t *base, svuint64_t offsets, svint64_t data)
7200+
// ST1D Zdata.D, Pg, [Xbase, Zoffsets.D]
7201+
// </summary>
7202+
public static unsafe void Scatter(Vector<long> mask, long* address, Vector<ulong> indicies, Vector<long> data) { throw new PlatformNotSupportedException(); }
7203+
7204+
// <summary>
7205+
// void svst1_scatter_[s32]offset[_f32](svbool_t pg, float32_t *base, svint32_t offsets, svfloat32_t data)
7206+
// ST1W Zdata.S, Pg, [Xbase, Zoffsets.S, SXTW]
7207+
// </summary>
7208+
public static unsafe void Scatter(Vector<float> mask, float* address, Vector<int> indicies, Vector<float> data) { throw new PlatformNotSupportedException(); }
7209+
7210+
// <summary>
7211+
// void svst1_scatter[_u32base_f32](svbool_t pg, svuint32_t bases, svfloat32_t data)
7212+
// ST1W Zdata.S, Pg, [Zbases.S, #0]
7213+
// </summary>
7214+
// Removed as per #103297
7215+
// public static unsafe void Scatter(Vector<float> mask, Vector<uint> addresses, Vector<float> data) { throw new PlatformNotSupportedException(); }
7216+
7217+
// <summary>
7218+
// void svst1_scatter_[u32]offset[_f32](svbool_t pg, float32_t *base, svuint32_t offsets, svfloat32_t data)
7219+
// ST1W Zdata.S, Pg, [Xbase, Zoffsets.S, UXTW]
7220+
// </summary>
7221+
public static unsafe void Scatter(Vector<float> mask, float* address, Vector<uint> indicies, Vector<float> data) { throw new PlatformNotSupportedException(); }
7222+
7223+
// <summary>
7224+
// void svst1_scatter_[s32]offset[_u32](svbool_t pg, uint32_t *base, svint32_t offsets, svuint32_t data)
7225+
// ST1W Zdata.S, Pg, [Xbase, Zoffsets.S, SXTW]
7226+
// </summary>
7227+
public static unsafe void Scatter(Vector<uint> mask, uint* address, Vector<int> indicies, Vector<uint> data) { throw new PlatformNotSupportedException(); }
7228+
7229+
// <summary>
7230+
// void svst1_scatter[_u32base_u32](svbool_t pg, svuint32_t bases, svuint32_t data)
7231+
// ST1W Zdata.S, Pg, [Zbases.S, #0]
7232+
// </summary>
7233+
// Removed as per #103297
7234+
// public static unsafe void Scatter(Vector<uint> mask, Vector<uint> addresses, Vector<uint> data) { throw new PlatformNotSupportedException(); }
7235+
7236+
// <summary>
7237+
// void svst1_scatter_[u32]offset[_u32](svbool_t pg, uint32_t *base, svuint32_t offsets, svuint32_t data)
7238+
// ST1W Zdata.S, Pg, [Xbase, Zoffsets.S, UXTW]
7239+
// </summary>
7240+
public static unsafe void Scatter(Vector<uint> mask, uint* address, Vector<uint> indicies, Vector<uint> data) { throw new PlatformNotSupportedException(); }
7241+
7242+
// <summary>
7243+
// void svst1_scatter_[s64]offset[_u64](svbool_t pg, uint64_t *base, svint64_t offsets, svuint64_t data)
7244+
// ST1D Zdata.D, Pg, [Xbase, Zoffsets.D]
7245+
// </summary>
7246+
public static unsafe void Scatter(Vector<ulong> mask, ulong* address, Vector<long> indicies, Vector<ulong> data) { throw new PlatformNotSupportedException(); }
7247+
7248+
// <summary>
7249+
// void svst1_scatter[_u64base_u64](svbool_t pg, svuint64_t bases, svuint64_t data)
7250+
// ST1D Zdata.D, Pg, [Zbases.D, #0]
7251+
// </summary>
7252+
public static unsafe void Scatter(Vector<ulong> mask, Vector<ulong> addresses, Vector<ulong> data) { throw new PlatformNotSupportedException(); }
7253+
7254+
// <summary>
7255+
// void svst1_scatter_[u64]offset[_u64](svbool_t pg, uint64_t *base, svuint64_t offsets, svuint64_t data)
7256+
// ST1D Zdata.D, Pg, [Xbase, Zoffsets.D]
7257+
// </summary>
7258+
public static unsafe void Scatter(Vector<ulong> mask, ulong* address, Vector<ulong> indicies, Vector<ulong> data) { throw new PlatformNotSupportedException(); }
7259+
7260+
71477261
/// Logical shift left
71487262

71497263
/// <summary>

0 commit comments

Comments
 (0)