Skip to content

Commit 473a983

Browse files
Updating Sum() implementation for Vector128 and Vector256 + adding lowering for Vector512 (#95568)
* Updating Sum() implementation. * Fixing codegen * Addressing review comments. * Fix Formatting * Enabling for long on x86. * Cleaning up ToScalar implementation
1 parent f5a97b2 commit 473a983

File tree

4 files changed

+152
-60
lines changed

4 files changed

+152
-60
lines changed

src/coreclr/jit/compiler.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3271,6 +3271,10 @@ class Compiler
32713271
GenTree* op4,
32723272
CorInfoType simdBaseJitType,
32733273
unsigned simdSize);
3274+
GenTree* gtNewSimdToScalarNode(var_types type,
3275+
GenTree* op1,
3276+
CorInfoType simdBaseJitType,
3277+
unsigned simdSize);
32743278
#endif // TARGET_XARCH
32753279

32763280
GenTree* gtNewSimdUnOpNode(genTreeOps op,

src/coreclr/jit/gentree.cpp

Lines changed: 137 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -24491,58 +24491,113 @@ GenTree* Compiler::gtNewSimdSumNode(var_types type, GenTree* op1, CorInfoType si
2449124491
GenTree* tmp = nullptr;
2449224492

2449324493
#if defined(TARGET_XARCH)
24494-
assert(!varTypeIsByte(simdBaseType) && !varTypeIsLong(simdBaseType));
24495-
assert(simdSize != 64);
2449624494

24497-
// HorizontalAdd combines pairs so we need log2(vectorLength) passes to sum all elements together.
24498-
unsigned vectorLength = getSIMDVectorLength(simdSize, simdBaseType);
24499-
int haddCount = genLog2(vectorLength);
24495+
if (simdSize == 64)
24496+
{
24497+
assert(IsBaselineVector512IsaSupportedDebugOnly());
24498+
GenTree* op1Dup = fgMakeMultiUse(&op1);
24499+
op1 = gtNewSimdGetUpperNode(TYP_SIMD32, op1, simdBaseJitType, simdSize);
24500+
op1Dup = gtNewSimdGetLowerNode(TYP_SIMD32, op1Dup, simdBaseJitType, simdSize);
24501+
simdSize = simdSize / 2;
24502+
op1 = gtNewSimdBinOpNode(GT_ADD, TYP_SIMD32, op1, op1Dup, simdBaseJitType, simdSize);
24503+
}
2450024504

2450124505
if (simdSize == 32)
2450224506
{
24503-
// Minus 1 because for the last pass we split the vector to low / high and add them together.
24504-
haddCount -= 1;
24507+
assert(compIsaSupportedDebugOnly(InstructionSet_AVX2));
24508+
GenTree* op1Dup = fgMakeMultiUse(&op1);
24509+
op1 = gtNewSimdGetUpperNode(TYP_SIMD16, op1, simdBaseJitType, simdSize);
24510+
op1Dup = gtNewSimdGetLowerNode(TYP_SIMD16, op1Dup, simdBaseJitType, simdSize);
24511+
simdSize = simdSize / 2;
24512+
op1 = gtNewSimdBinOpNode(GT_ADD, TYP_SIMD16, op1, op1Dup, simdBaseJitType, simdSize);
24513+
}
2450524514

24506-
if (varTypeIsFloating(simdBaseType))
24515+
assert(simdSize == 16);
24516+
24517+
if (varTypeIsFloating(simdBaseType))
24518+
{
24519+
if (simdBaseType == TYP_FLOAT)
2450724520
{
24508-
assert(compIsaSupportedDebugOnly(InstructionSet_AVX));
24509-
intrinsic = NI_AVX_HorizontalAdd;
24521+
assert(compIsaSupportedDebugOnly(InstructionSet_SSE2));
24522+
GenTree* op1Shuffled = fgMakeMultiUse(&op1);
24523+
if (compOpportunisticallyDependsOn(InstructionSet_AVX))
24524+
{
24525+
assert(compIsaSupportedDebugOnly(InstructionSet_AVX));
24526+
// The permute below gives us [0, 1, 2, 3] -> [1, 0, 3, 2]
24527+
op1 = gtNewSimdHWIntrinsicNode(type, op1, gtNewIconNode((int)0b10110001, TYP_INT), NI_AVX_Permute,
24528+
simdBaseJitType, simdSize);
24529+
// The add below now results in [0 + 1, 1 + 0, 2 + 3, 3 + 2]
24530+
op1 = gtNewSimdBinOpNode(GT_ADD, TYP_SIMD16, op1, op1Shuffled, simdBaseJitType, simdSize);
24531+
op1Shuffled = fgMakeMultiUse(&op1);
24532+
// The permute below gives us [0 + 1, 1 + 0, 2 + 3, 3 + 2] -> [2 + 3, 3 + 2, 0 + 1, 1 + 0]
24533+
op1 = gtNewSimdHWIntrinsicNode(type, op1, gtNewIconNode((int)0b01001110, TYP_INT), NI_AVX_Permute,
24534+
simdBaseJitType, simdSize);
24535+
}
24536+
else
24537+
{
24538+
assert(compIsaSupportedDebugOnly(InstructionSet_SSE));
24539+
// The shuffle below gives us [0, 1, 2, 3] -> [1, 0, 3, 2]
24540+
op1 = gtNewSimdHWIntrinsicNode(TYP_SIMD16, op1, op1Shuffled, gtNewIconNode((int)0b10110001, TYP_INT),
24541+
NI_SSE_Shuffle, simdBaseJitType, simdSize);
24542+
op1Shuffled = fgMakeMultiUse(&op1Shuffled);
24543+
// The add below now results in [0 + 1, 1 + 0, 2 + 3, 3 + 2]
24544+
op1 = gtNewSimdBinOpNode(GT_ADD, TYP_SIMD16, op1, op1Shuffled, simdBaseJitType, simdSize);
24545+
op1Shuffled = fgMakeMultiUse(&op1);
24546+
// The shuffle below gives us [0 + 1, 1 + 0, 2 + 3, 3 + 2] -> [2 + 3, 3 + 2, 0 + 1, 1 + 0]
24547+
op1 = gtNewSimdHWIntrinsicNode(TYP_SIMD16, op1, op1Shuffled, gtNewIconNode((int)0b01001110, TYP_INT),
24548+
NI_SSE_Shuffle, simdBaseJitType, simdSize);
24549+
op1Shuffled = fgMakeMultiUse(&op1Shuffled);
24550+
}
24551+
// Finally adding the results gets us [(0 + 1) + (2 + 3), (1 + 0) + (3 + 2), (2 + 3) + (0 + 1), (3 + 2) + (1
24552+
// + 0)]
24553+
op1 = gtNewSimdBinOpNode(GT_ADD, TYP_SIMD16, op1, op1Shuffled, simdBaseJitType, simdSize);
24554+
return gtNewSimdToScalarNode(type, op1, simdBaseJitType, simdSize);
2451024555
}
2451124556
else
2451224557
{
24513-
assert(compIsaSupportedDebugOnly(InstructionSet_AVX2));
24514-
intrinsic = NI_AVX2_HorizontalAdd;
24558+
assert(compIsaSupportedDebugOnly(InstructionSet_SSE2));
24559+
GenTree* op1Shuffled = fgMakeMultiUse(&op1);
24560+
if (compOpportunisticallyDependsOn(InstructionSet_AVX))
24561+
{
24562+
assert(compIsaSupportedDebugOnly(InstructionSet_AVX));
24563+
// The permute below gives us [0, 1] -> [1, 0]
24564+
op1 = gtNewSimdHWIntrinsicNode(type, op1, gtNewIconNode((int)0b0001, TYP_INT), NI_AVX_Permute,
24565+
simdBaseJitType, simdSize);
24566+
}
24567+
else
24568+
{
24569+
// The shuffle below gives us [0, 1] -> [1, 0]
24570+
op1 = gtNewSimdHWIntrinsicNode(TYP_SIMD16, op1, op1Shuffled, gtNewIconNode((int)0b0001, TYP_INT),
24571+
NI_SSE2_Shuffle, simdBaseJitType, simdSize);
24572+
op1Shuffled = fgMakeMultiUse(&op1Shuffled);
24573+
}
24574+
// Finally adding the results gets us [0 + 1, 1 + 0]
24575+
op1 = gtNewSimdBinOpNode(GT_ADD, TYP_SIMD16, op1, op1Shuffled, simdBaseJitType, simdSize);
24576+
return gtNewSimdToScalarNode(type, op1, simdBaseJitType, simdSize);
2451524577
}
2451624578
}
24517-
else if (varTypeIsFloating(simdBaseType))
24518-
{
24519-
assert(compIsaSupportedDebugOnly(InstructionSet_SSE3));
24520-
intrinsic = NI_SSE3_HorizontalAdd;
24521-
}
24522-
else
24523-
{
24524-
assert(compIsaSupportedDebugOnly(InstructionSet_SSSE3));
24525-
intrinsic = NI_SSSE3_HorizontalAdd;
24526-
}
2452724579

24528-
for (int i = 0; i < haddCount; i++)
24529-
{
24530-
tmp = fgMakeMultiUse(&op1);
24531-
op1 = gtNewSimdHWIntrinsicNode(simdType, op1, tmp, intrinsic, simdBaseJitType, simdSize);
24532-
}
24580+
unsigned vectorLength = getSIMDVectorLength(simdSize, simdBaseType);
24581+
int shiftCount = genLog2(vectorLength);
24582+
int typeSize = genTypeSize(simdBaseType);
24583+
int shiftVal = (typeSize * vectorLength) / 2;
2453324584

24534-
if (simdSize == 32)
24585+
// The reduced sum is calculated for integer values using a combination of shift + add
24586+
// For e.g. consider a 32 bit integer. This means we have 4 values in a XMM register
24587+
// After the first shift + add -> [(0 + 2), (1 + 3), ...]
24588+
// After the second shift + add -> [(0 + 2 + 1 + 3), ...]
24589+
GenTree* opShifted = nullptr;
24590+
while (shiftVal >= typeSize)
2453524591
{
24536-
intrinsic = (simdBaseType == TYP_FLOAT) ? NI_SSE_Add : NI_SSE2_Add;
24537-
24538-
tmp = fgMakeMultiUse(&op1);
24539-
op1 = gtNewSimdGetUpperNode(TYP_SIMD16, op1, simdBaseJitType, simdSize);
24540-
24541-
tmp = gtNewSimdGetLowerNode(TYP_SIMD16, tmp, simdBaseJitType, simdSize);
24542-
op1 = gtNewSimdHWIntrinsicNode(TYP_SIMD16, op1, tmp, intrinsic, simdBaseJitType, 16);
24592+
tmp = fgMakeMultiUse(&op1);
24593+
opShifted = gtNewSimdHWIntrinsicNode(TYP_SIMD16, op1, gtNewIconNode(shiftVal, TYP_INT),
24594+
NI_SSE2_ShiftRightLogical128BitLane, simdBaseJitType, simdSize);
24595+
op1 = gtNewSimdBinOpNode(GT_ADD, TYP_SIMD16, opShifted, tmp, simdBaseJitType, simdSize);
24596+
shiftVal = shiftVal / 2;
2454324597
}
2454424598

24545-
return gtNewSimdHWIntrinsicNode(type, op1, NI_Vector128_ToScalar, simdBaseJitType, 16);
24599+
return gtNewSimdToScalarNode(type, op1, simdBaseJitType, simdSize);
24600+
2454624601
#elif defined(TARGET_ARM64)
2454724602
switch (simdBaseType)
2454824603
{
@@ -24673,6 +24728,52 @@ GenTree* Compiler::gtNewSimdTernaryLogicNode(var_types type,
2467324728
}
2467424729
#endif // TARGET_XARCH
2467524730

24731+
#if defined(TARGET_XARCH)
24732+
//----------------------------------------------------------------------------------------------
24733+
// Compiler::gtNewSimdToScalarNode: Creates a new simd ToScalar node.
24734+
//
24735+
// Arguments:
24736+
// type - The return type of SIMD node being created.
24737+
// op1 - The SIMD operand.
24738+
// simdBaseJitType - The base JIT type of SIMD type of the intrinsic.
24739+
// simdSize - The size of the SIMD type of the intrinsic.
24740+
//
24741+
// Returns:
24742+
// The created node that has the ToScalar implementation.
24743+
//
24744+
GenTree* Compiler::gtNewSimdToScalarNode(var_types type, GenTree* op1, CorInfoType simdBaseJitType, unsigned simdSize)
24745+
{
24746+
24747+
#if defined(TARGET_X86)
24748+
var_types simdBaseType = JitType2PreciseVarType(simdBaseJitType);
24749+
if (varTypeIsLong(simdBaseType))
24750+
{
24751+
// We need SSE41 to handle long, use software fallback
24752+
assert(compIsaSupportedDebugOnly(InstructionSet_SSE41));
24753+
24754+
// Create a GetElement node which handles decomposition
24755+
GenTree* op2 = gtNewIconNode(0);
24756+
return gtNewSimdGetElementNode(type, op1, op2, simdBaseJitType, simdSize);
24757+
}
24758+
#endif // TARGET_X86
24759+
// Ensure MOVD/MOVQ support exists
24760+
assert(compIsaSupportedDebugOnly(InstructionSet_SSE2));
24761+
NamedIntrinsic intrinsic = NI_Vector128_ToScalar;
24762+
24763+
if (simdSize == 32)
24764+
{
24765+
assert(compIsaSupportedDebugOnly(InstructionSet_AVX));
24766+
intrinsic = NI_Vector256_ToScalar;
24767+
}
24768+
else if (simdSize == 64)
24769+
{
24770+
assert(IsBaselineVector512IsaSupportedDebugOnly());
24771+
intrinsic = NI_Vector512_ToScalar;
24772+
}
24773+
return gtNewSimdHWIntrinsicNode(type, op1, intrinsic, simdBaseJitType, simdSize);
24774+
}
24775+
#endif // TARGET_XARCH
24776+
2467624777
GenTree* Compiler::gtNewSimdUnOpNode(
2467724778
genTreeOps op, var_types type, GenTree* op1, CorInfoType simdBaseJitType, unsigned simdSize)
2467824779
{

src/coreclr/jit/hwintrinsiclistxarch.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,7 @@ HARDWARE_INTRINSIC(Vector512, StoreAligned,
330330
HARDWARE_INTRINSIC(Vector512, StoreAlignedNonTemporal, 64, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_SpecialImport|HW_Flag_BaseTypeFromFirstArg|HW_Flag_NoCodeGen)
331331
HARDWARE_INTRINSIC(Vector512, StoreUnsafe, 64, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_SpecialImport|HW_Flag_BaseTypeFromFirstArg|HW_Flag_NoCodeGen)
332332
HARDWARE_INTRINSIC(Vector512, Subtract, 64, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_SpecialImport|HW_Flag_NoCodeGen)
333+
HARDWARE_INTRINSIC(Vector512, Sum, 64, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_SpecialImport|HW_Flag_BaseTypeFromFirstArg|HW_Flag_NoCodeGen)
333334
HARDWARE_INTRINSIC(Vector512, ToScalar, 64, 1, true, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_movss, INS_movsd_simd}, HW_Category_SIMDScalar, HW_Flag_SpecialImport|HW_Flag_SpecialCodeGen|HW_Flag_BaseTypeFromFirstArg)
334335
HARDWARE_INTRINSIC(Vector512, WidenLower, 64, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_SpecialImport|HW_Flag_NoCodeGen|HW_Flag_BaseTypeFromFirstArg)
335336
HARDWARE_INTRINSIC(Vector512, WidenUpper, 64, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_SpecialImport|HW_Flag_NoCodeGen|HW_Flag_BaseTypeFromFirstArg)

src/coreclr/jit/hwintrinsicxarch.cpp

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2877,33 +2877,27 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
28772877

28782878
case NI_Vector128_Sum:
28792879
case NI_Vector256_Sum:
2880+
case NI_Vector512_Sum:
28802881
{
28812882
assert(sig->numArgs == 1);
28822883
var_types simdType = getSIMDTypeForSize(simdSize);
28832884

28842885
if ((simdSize == 32) && !compOpportunisticallyDependsOn(InstructionSet_AVX2))
28852886
{
2886-
// Vector256 for integer types requires AVX2
2887+
// Vector256 requires AVX2
28872888
break;
28882889
}
2889-
else if (varTypeIsFloating(simdBaseType))
2890+
else if ((simdSize == 16) && !compOpportunisticallyDependsOn(InstructionSet_SSE2))
28902891
{
2891-
if (!compOpportunisticallyDependsOn(InstructionSet_SSE3))
2892-
{
2893-
// Floating-point types require SSE3.HorizontalAdd
2894-
break;
2895-
}
2896-
}
2897-
else if (!compOpportunisticallyDependsOn(InstructionSet_SSSE3))
2898-
{
2899-
// Integral types require SSSE3.HorizontalAdd
29002892
break;
29012893
}
2902-
else if (varTypeIsByte(simdBaseType) || varTypeIsLong(simdBaseType))
2894+
#if defined(TARGET_X86)
2895+
else if (varTypeIsLong(simdBaseType) && !compOpportunisticallyDependsOn(InstructionSet_SSE41))
29032896
{
2904-
// byte, sbyte, long, and ulong all would require more work to support
2897+
// We need SSE41 to handle long, use software fallback
29052898
break;
29062899
}
2900+
#endif // TARGET_X86
29072901

29082902
op1 = impSIMDPopStack();
29092903
retNode = gtNewSimdSumNode(retType, op1, simdBaseJitType, simdSize);
@@ -2917,23 +2911,15 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
29172911
assert(sig->numArgs == 1);
29182912

29192913
#if defined(TARGET_X86)
2920-
if (varTypeIsLong(simdBaseType))
2914+
if (varTypeIsLong(simdBaseType) && !compOpportunisticallyDependsOn(InstructionSet_SSE41))
29212915
{
2922-
if (!compOpportunisticallyDependsOn(InstructionSet_SSE41))
2923-
{
2924-
// We need SSE41 to handle long, use software fallback
2925-
break;
2926-
}
2927-
// Create a GetElement node which handles decomposition
2928-
op1 = impSIMDPopStack();
2929-
op2 = gtNewIconNode(0);
2930-
retNode = gtNewSimdGetElementNode(retType, op1, op2, simdBaseJitType, simdSize);
2916+
// We need SSE41 to handle long, use software fallback
29312917
break;
29322918
}
29332919
#endif // TARGET_X86
29342920

29352921
op1 = impSIMDPopStack();
2936-
retNode = gtNewSimdHWIntrinsicNode(retType, op1, intrinsic, simdBaseJitType, simdSize);
2922+
retNode = gtNewSimdToScalarNode(retType, op1, simdBaseJitType, simdSize);
29372923
break;
29382924
}
29392925

0 commit comments

Comments
 (0)