Skip to content

Commit 3562aea

Browse files
committed
accelerate Vector.Dot for all base types
1 parent f41b65f commit 3562aea

File tree

3 files changed

+99
-102
lines changed

3 files changed

+99
-102
lines changed

src/coreclr/jit/gentree.cpp

Lines changed: 69 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -21423,45 +21423,84 @@ GenTree* Compiler::gtNewSimdBinOpNode(
2142321423
}
2142421424
else if (varTypeIsLong(simdBaseType))
2142521425
{
21426-
assert((simdSize == 16) || (simdSize == 32) || (simdSize == 64));
21426+
if ((simdSize == 32) || compOpportunisticallyDependsOn(InstructionSet_SSE41))
21427+
{
21428+
assert((simdSize == 16) || compIsaSupportedDebugOnly(InstructionSet_AVX2));
2142721429

21428-
assert(((simdSize == 16) && compOpportunisticallyDependsOn(InstructionSet_SSE41)) ||
21429-
((simdSize == 32) && compOpportunisticallyDependsOn(InstructionSet_AVX2)));
21430+
// Make op1 and op2 multi-use:
21431+
GenTree* op1Dup = fgMakeMultiUse(&op1);
21432+
GenTree* op2Dup = fgMakeMultiUse(&op2);
2143021433

21431-
// Make op1 and op2 multi-use:
21432-
GenTree* op1Dup = fgMakeMultiUse(&op1);
21433-
GenTree* op2Dup = fgMakeMultiUse(&op2);
21434+
const bool is256 = simdSize == 32;
21435+
21436+
// Vector256<ulong> tmp0 = Avx2.Multiply(left, right);
21437+
GenTreeHWIntrinsic* tmp0 =
21438+
gtNewSimdHWIntrinsicNode(type, op1, op2, is256 ? NI_AVX2_Multiply : NI_SSE2_Multiply,
21439+
CORINFO_TYPE_ULONG, simdSize);
21440+
21441+
// Vector256<uint> tmp1 = Avx2.Shuffle(right.AsUInt32(), ZWXY);
21442+
GenTree* shuffleMask = gtNewIconNode(SHUFFLE_ZWXY, TYP_INT);
21443+
GenTreeHWIntrinsic* tmp1 =
21444+
gtNewSimdHWIntrinsicNode(type, op2Dup, shuffleMask, is256 ? NI_AVX2_Shuffle : NI_SSE2_Shuffle,
21445+
CORINFO_TYPE_UINT, simdSize);
21446+
21447+
// Vector256<uint> tmp2 = Avx2.MultiplyLow(left.AsUInt32(), tmp1);
21448+
GenTree* tmp2 = gtNewSimdBinOpNode(GT_MUL, type, op1Dup, tmp1, CORINFO_TYPE_UINT, simdSize);
2143421449

21435-
const bool is256 = simdSize == 32;
21450+
// Vector256<int> tmp3 = Avx2.HorizontalAdd(tmp2.AsInt32(), Vector256<int>.Zero);
21451+
GenTreeHWIntrinsic* tmp3 =
21452+
gtNewSimdHWIntrinsicNode(type, tmp2, gtNewZeroConNode(type),
21453+
is256 ? NI_AVX2_HorizontalAdd : NI_SSSE3_HorizontalAdd,
21454+
CORINFO_TYPE_UINT, simdSize);
2143621455

21437-
// Vector256<ulong> tmp0 = Avx2.Multiply(left, right);
21438-
GenTreeHWIntrinsic* tmp0 =
21439-
gtNewSimdHWIntrinsicNode(type, op1, op2, is256 ? NI_AVX2_Multiply : NI_SSE2_Multiply,
21440-
CORINFO_TYPE_ULONG, simdSize);
21456+
// Vector256<int> tmp4 = Avx2.Shuffle(tmp3, YWXW);
21457+
shuffleMask = gtNewIconNode(SHUFFLE_YWXW, TYP_INT);
21458+
GenTreeHWIntrinsic* tmp4 =
21459+
gtNewSimdHWIntrinsicNode(type, tmp3, shuffleMask, is256 ? NI_AVX2_Shuffle : NI_SSE2_Shuffle,
21460+
CORINFO_TYPE_UINT, simdSize);
2144121461

21442-
// Vector256<uint> tmp1 = Avx2.Shuffle(right.AsUInt32(), ZWXY);
21443-
GenTree* shuffleMask = gtNewIconNode(SHUFFLE_ZWXY, TYP_INT);
21444-
GenTreeHWIntrinsic* tmp1 =
21445-
gtNewSimdHWIntrinsicNode(type, op2Dup, shuffleMask, is256 ? NI_AVX2_Shuffle : NI_SSE2_Shuffle,
21446-
CORINFO_TYPE_UINT, simdSize);
21462+
// result = tmp0 + tmp4;
21463+
return gtNewSimdBinOpNode(GT_ADD, type, tmp0, tmp4, simdBaseJitType, simdSize);
21464+
}
21465+
else
21466+
{
21467+
// The SSE2 implementation is a simple decomposition using pmuludq,
21468+
// which multiplies two uint32s and returns a uint64 result.
21469+
// aLo * bLo + ((aLo * bHi + aHi * bLo) << 32)
21470+
GenTree* op1Dup1 = fgMakeMultiUse(&op1);
21471+
GenTree* op1Dup2 = gtCloneExpr(op1Dup1);
21472+
GenTree* op2Dup1 = fgMakeMultiUse(&op2);
21473+
GenTree* op2Dup2 = gtCloneExpr(op2Dup1);
21474+
21475+
// Vector128<ulong> low = Sse2.Multiply(left.AsUInt32(), right.AsUInt32());
21476+
GenTreeHWIntrinsic* low =
21477+
gtNewSimdHWIntrinsicNode(type, op1, op2, NI_SSE2_Multiply, CORINFO_TYPE_ULONG, simdSize);
2144721478

21448-
// Vector256<uint> tmp2 = Avx2.MultiplyLow(left.AsUInt32(), tmp1);
21449-
GenTree* tmp2 = gtNewSimdBinOpNode(GT_MUL, type, op1Dup, tmp1, CORINFO_TYPE_UINT, simdSize);
21479+
// Vector128<uint> rightHi = (right >>> 32).AsUInt32();
21480+
GenTree* rightHi =
21481+
gtNewSimdBinOpNode(GT_RSZ, type, op2Dup1, gtNewIconNode(32), simdBaseJitType, simdSize);
2145021482

21451-
// Vector256<int> tmp3 = Avx2.HorizontalAdd(tmp2.AsInt32(), Vector256<int>.Zero);
21452-
GenTreeHWIntrinsic* tmp3 =
21453-
gtNewSimdHWIntrinsicNode(type, tmp2, gtNewZeroConNode(type),
21454-
is256 ? NI_AVX2_HorizontalAdd : NI_SSSE3_HorizontalAdd, CORINFO_TYPE_UINT,
21455-
simdSize);
21483+
// Vector128<ulong> tmp0 = Sse2.Multiply(rightHi, left.AsUInt32());
21484+
GenTreeHWIntrinsic* tmp0 = gtNewSimdHWIntrinsicNode(type, rightHi, op1Dup1, NI_SSE2_Multiply,
21485+
CORINFO_TYPE_ULONG, simdSize);
2145621486

21457-
// Vector256<int> tmp4 = Avx2.Shuffle(tmp3, YWXW);
21458-
shuffleMask = gtNewIconNode(SHUFFLE_YWXW, TYP_INT);
21459-
GenTreeHWIntrinsic* tmp4 =
21460-
gtNewSimdHWIntrinsicNode(type, tmp3, shuffleMask, is256 ? NI_AVX2_Shuffle : NI_SSE2_Shuffle,
21461-
CORINFO_TYPE_UINT, simdSize);
21487+
// Vector128<uint> leftHi = (left >>> 32).AsUInt32();
21488+
GenTree* leftHi =
21489+
gtNewSimdBinOpNode(GT_RSZ, type, op1Dup2, gtNewIconNode(32), simdBaseJitType, simdSize);
2146221490

21463-
// result = tmp0 + tmp4;
21464-
return gtNewSimdBinOpNode(GT_ADD, type, tmp0, tmp4, simdBaseJitType, simdSize);
21491+
// Vector128<ulong> tmp1 = Sse2.Multiply(leftHi, right.AsUInt32());
21492+
GenTreeHWIntrinsic* tmp1 =
21493+
gtNewSimdHWIntrinsicNode(type, leftHi, op2Dup2, NI_SSE2_Multiply, CORINFO_TYPE_ULONG, simdSize);
21494+
21495+
// Vector128<ulong> tmp2 = tmp0 + tmp1;
21496+
GenTree* tmp2 = gtNewSimdBinOpNode(GT_ADD, type, tmp0, tmp1, simdBaseJitType, simdSize);
21497+
21498+
// Vector128<ulong> mid = tmp2 << 32;
21499+
GenTree* mid = gtNewSimdBinOpNode(GT_LSH, type, tmp2, gtNewIconNode(32), simdBaseJitType, simdSize);
21500+
21501+
// return low + mid;
21502+
return gtNewSimdBinOpNode(GT_ADD, type, low, mid, simdBaseJitType, simdSize);
21503+
}
2146521504
}
2146621505
#elif defined(TARGET_ARM64)
2146721506
if (varTypeIsLong(simdBaseType))

src/coreclr/jit/hwintrinsiclistxarch.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,7 @@ HARDWARE_INTRINSIC(Vector512, Create,
302302
HARDWARE_INTRINSIC(Vector512, CreateScalar, 64, -1, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_SpecialImport|HW_Flag_NoCodeGen)
303303
HARDWARE_INTRINSIC(Vector512, CreateScalarUnsafe, 64, 1, {INS_movd, INS_movd, INS_movd, INS_movd, INS_movd, INS_movd, INS_movd, INS_movd, INS_movss, INS_movsd_simd}, HW_Category_SIMDScalar, HW_Flag_SpecialImport|HW_Flag_SpecialCodeGen)
304304
HARDWARE_INTRINSIC(Vector512, CreateSequence, 64, 2, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
305+
HARDWARE_INTRINSIC(Vector512, Dot, 64, 2, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_SpecialImport|HW_Flag_NoCodeGen|HW_Flag_BaseTypeFromFirstArg)
305306
HARDWARE_INTRINSIC(Vector512, Equals, 64, 2, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId|HW_Flag_BaseTypeFromFirstArg)
306307
HARDWARE_INTRINSIC(Vector512, EqualsAny, 64, 2, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId|HW_Flag_BaseTypeFromFirstArg)
307308
HARDWARE_INTRINSIC(Vector512, ExtractMostSignificantBits, 64, 1, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId|HW_Flag_BaseTypeFromFirstArg)

src/coreclr/jit/hwintrinsicxarch.cpp

Lines changed: 29 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -2446,40 +2446,40 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
24462446

24472447
case NI_Vector128_Dot:
24482448
case NI_Vector256_Dot:
2449+
case NI_Vector512_Dot:
24492450
{
24502451
assert(sig->numArgs == 2);
24512452
var_types simdType = getSIMDTypeForSize(simdSize);
24522453

2453-
if (varTypeIsByte(simdBaseType) || varTypeIsLong(simdBaseType))
2454+
if ((simdSize == 32) && !varTypeIsFloating(simdBaseType) &&
2455+
!compOpportunisticallyDependsOn(InstructionSet_AVX2))
24542456
{
2455-
// TODO-XARCH-CQ: We could support dot product for 8-bit and
2456-
// 64-bit integers if we support multiplication for the same
2457+
// We can't deal with TYP_SIMD32 for integral types if the compiler doesn't support AVX2
24572458
break;
24582459
}
24592460

2460-
if (simdSize == 32)
2461-
{
2462-
if (!varTypeIsFloating(simdBaseType) && !compOpportunisticallyDependsOn(InstructionSet_AVX2))
2463-
{
2464-
// We can't deal with TYP_SIMD32 for integral types if the compiler doesn't support AVX2
2465-
break;
2466-
}
2467-
}
2468-
else if ((simdBaseType == TYP_INT) || (simdBaseType == TYP_UINT))
2461+
#if defined(TARGET_X86)
2462+
if (varTypeIsLong(simdBaseType) && !compOpportunisticallyDependsOn(InstructionSet_SSE41))
24692463
{
2470-
if (!compOpportunisticallyDependsOn(InstructionSet_SSE41))
2471-
{
2472-
// TODO-XARCH-CQ: We can support 32-bit integers if we update multiplication
2473-
// to be lowered rather than imported as the relevant operations.
2474-
break;
2475-
}
2464+
// We need SSE41 to handle long, use software fallback
2465+
break;
24762466
}
2467+
#endif // TARGET_X86
24772468

24782469
op2 = impSIMDPopStack();
24792470
op1 = impSIMDPopStack();
24802471

2472+
if ((simdSize == 64) || varTypeIsByte(simdBaseType) || varTypeIsLong(simdBaseType) ||
2473+
(varTypeIsInt(simdBaseType) && !compOpportunisticallyDependsOn(InstructionSet_SSE41)))
2474+
{
2475+
// The lowering for Dot doesn't handle these cases, so import as Sum(left * right)
2476+
retNode = gtNewSimdBinOpNode(GT_MUL, simdType, op1, op2, simdBaseJitType, simdSize);
2477+
retNode = gtNewSimdSumNode(retType, retNode, simdBaseJitType, simdSize);
2478+
break;
2479+
}
2480+
24812481
retNode = gtNewSimdDotProdNode(simdType, op1, op2, simdBaseJitType, simdSize);
2482-
retNode = gtNewSimdGetElementNode(retType, retNode, gtNewIconNode(0), simdBaseJitType, simdSize);
2482+
retNode = gtNewSimdToScalarNode(retType, retNode, simdBaseJitType, simdSize);
24832483
break;
24842484
}
24852485

@@ -3345,30 +3345,14 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
33453345
break;
33463346
}
33473347

3348-
assert(simdSize != 64 || IsBaselineVector512IsaSupportedDebugOnly());
3349-
3348+
#if defined(TARGET_X86)
33503349
if (varTypeIsLong(simdBaseType))
33513350
{
3352-
if (TARGET_POINTER_SIZE == 4)
3353-
{
3354-
// TODO-XARCH-CQ: 32bit support
3355-
break;
3356-
}
3357-
3358-
if ((simdSize == 32) && compOpportunisticallyDependsOn(InstructionSet_AVX2))
3359-
{
3360-
// Emulate NI_AVX512DQ_VL_MultiplyLow with AVX2 for SIMD32
3361-
}
3362-
else if ((simdSize == 16) && compOpportunisticallyDependsOn(InstructionSet_SSE41))
3363-
{
3364-
// Emulate NI_AVX512DQ_VL_MultiplyLow with SSE41 for SIMD16
3365-
}
3366-
else if (simdSize != 64)
3367-
{
3368-
// Software fallback
3369-
break;
3370-
}
3351+
// TODO-XARCH-CQ: We can't handle long here, only because one of the args might
3352+
// be scalar, and gtNewSimdCreateBroadcastNode doesn't handle long on x86.
3353+
break;
33713354
}
3355+
#endif // TARGET_X86
33723356

33733357
CORINFO_ARG_LIST_HANDLE arg1 = sig->args;
33743358
CORINFO_ARG_LIST_HANDLE arg2 = info.compCompHnd->getArgNext(arg1);
@@ -3403,31 +3387,6 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
34033387
break;
34043388
}
34053389

3406-
assert(simdSize != 64 || IsBaselineVector512IsaSupportedDebugOnly());
3407-
3408-
if (varTypeIsLong(simdBaseType))
3409-
{
3410-
if (TARGET_POINTER_SIZE == 4)
3411-
{
3412-
// TODO-XARCH-CQ: 32bit support
3413-
break;
3414-
}
3415-
3416-
if ((simdSize == 32) && compOpportunisticallyDependsOn(InstructionSet_AVX2))
3417-
{
3418-
// Emulate NI_AVX512DQ_VL_MultiplyLow with AVX2 for SIMD32
3419-
}
3420-
else if ((simdSize == 16) && compOpportunisticallyDependsOn(InstructionSet_SSE41))
3421-
{
3422-
// Emulate NI_AVX512DQ_VL_MultiplyLow with SSE41 for SIMD16
3423-
}
3424-
else if (simdSize != 64)
3425-
{
3426-
// Software fallback
3427-
break;
3428-
}
3429-
}
3430-
34313390
op3 = impSIMDPopStack();
34323391
op2 = impSIMDPopStack();
34333392
op1 = impSIMDPopStack();
@@ -3835,17 +3794,15 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
38353794
{
38363795
assert(sig->numArgs == 1);
38373796

3838-
if ((simdSize == 32) && !compOpportunisticallyDependsOn(InstructionSet_AVX2))
3839-
{
3840-
// Vector256 requires AVX2
3841-
break;
3842-
}
3843-
else if ((simdSize == 16) && !compOpportunisticallyDependsOn(InstructionSet_SSE2))
3797+
if ((simdSize == 32) && !varTypeIsFloating(simdBaseType) &&
3798+
!compOpportunisticallyDependsOn(InstructionSet_AVX2))
38443799
{
3800+
// We can't deal with TYP_SIMD32 for integral types if the compiler doesn't support AVX2
38453801
break;
38463802
}
3803+
38473804
#if defined(TARGET_X86)
3848-
else if (varTypeIsLong(simdBaseType) && !compOpportunisticallyDependsOn(InstructionSet_SSE41))
3805+
if (varTypeIsLong(simdBaseType) && !compOpportunisticallyDependsOn(InstructionSet_SSE41))
38493806
{
38503807
// We need SSE41 to handle long, use software fallback
38513808
break;

0 commit comments

Comments
 (0)