Skip to content

Commit f565711

Browse files
authored
JIT: Accelerate Vector.Dot for all base types (#111853)
1 parent 5ed086c commit f565711

File tree

3 files changed

+57
-102
lines changed

3 files changed

+57
-102
lines changed

src/coreclr/jit/gentree.cpp

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -21401,45 +21401,45 @@ GenTree* Compiler::gtNewSimdBinOpNode(
2140121401
}
2140221402
else if (varTypeIsLong(simdBaseType))
2140321403
{
21404-
assert((simdSize == 16) || (simdSize == 32) || (simdSize == 64));
21404+
// This fallback path will be used only if the vpmullq instruction is not available.
21405+
// The implementation is a simple decomposition using pmuludq, which multiplies
21406+
// two uint32s and returns a uint64 result.
21407+
//
21408+
// aLo * bLo + ((aLo * bHi + aHi * bLo) << 32)
2140521409

21406-
assert(((simdSize == 16) && compOpportunisticallyDependsOn(InstructionSet_SSE41)) ||
21407-
((simdSize == 32) && compOpportunisticallyDependsOn(InstructionSet_AVX2)));
21410+
assert(!canUseEvexEncodingDebugOnly());
21411+
assert((simdSize == 16) || compIsaSupportedDebugOnly(InstructionSet_AVX2));
2140821412

21409-
// Make op1 and op2 multi-use:
21410-
GenTree* op1Dup = fgMakeMultiUse(&op1);
21411-
GenTree* op2Dup = fgMakeMultiUse(&op2);
21413+
NamedIntrinsic muludq = (simdSize == 16) ? NI_SSE2_Multiply : NI_AVX2_Multiply;
21414+
21415+
GenTree* op1Dup1 = fgMakeMultiUse(&op1);
21416+
GenTree* op1Dup2 = gtCloneExpr(op1Dup1);
21417+
GenTree* op2Dup1 = fgMakeMultiUse(&op2);
21418+
GenTree* op2Dup2 = gtCloneExpr(op2Dup1);
2141221419

21413-
const bool is256 = simdSize == 32;
21420+
// Vector128<ulong> low = Sse2.Multiply(a.AsUInt32(), b.AsUInt32());
21421+
GenTree* low = gtNewSimdHWIntrinsicNode(type, op1, op2, muludq, CORINFO_TYPE_ULONG, simdSize);
2141421422

21415-
// Vector256<ulong> tmp0 = Avx2.Multiply(left, right);
21416-
GenTreeHWIntrinsic* tmp0 =
21417-
gtNewSimdHWIntrinsicNode(type, op1, op2, is256 ? NI_AVX2_Multiply : NI_SSE2_Multiply,
21418-
CORINFO_TYPE_ULONG, simdSize);
21423+
// Vector128<ulong> mid = (b >>> 32).AsUInt64();
21424+
GenTree* mid = gtNewSimdBinOpNode(GT_RSZ, type, op2Dup1, gtNewIconNode(32), simdBaseJitType, simdSize);
2141921425

21420-
// Vector256<uint> tmp1 = Avx2.Shuffle(right.AsUInt32(), ZWXY);
21421-
GenTree* shuffleMask = gtNewIconNode(SHUFFLE_ZWXY, TYP_INT);
21422-
GenTreeHWIntrinsic* tmp1 =
21423-
gtNewSimdHWIntrinsicNode(type, op2Dup, shuffleMask, is256 ? NI_AVX2_Shuffle : NI_SSE2_Shuffle,
21424-
CORINFO_TYPE_UINT, simdSize);
21426+
// mid = Sse2.Multiply(mid.AsUInt32(), a.AsUInt32());
21427+
mid = gtNewSimdHWIntrinsicNode(type, mid, op1Dup1, muludq, CORINFO_TYPE_ULONG, simdSize);
2142521428

21426-
// Vector256<uint> tmp2 = Avx2.MultiplyLow(left.AsUInt32(), tmp1);
21427-
GenTree* tmp2 = gtNewSimdBinOpNode(GT_MUL, type, op1Dup, tmp1, CORINFO_TYPE_UINT, simdSize);
21429+
// Vector128<ulong> tmp = (a >>> 32).AsUInt64();
21430+
GenTree* tmp = gtNewSimdBinOpNode(GT_RSZ, type, op1Dup2, gtNewIconNode(32), simdBaseJitType, simdSize);
2142821431

21429-
// Vector256<int> tmp3 = Avx2.HorizontalAdd(tmp2.AsInt32(), Vector256<int>.Zero);
21430-
GenTreeHWIntrinsic* tmp3 =
21431-
gtNewSimdHWIntrinsicNode(type, tmp2, gtNewZeroConNode(type),
21432-
is256 ? NI_AVX2_HorizontalAdd : NI_SSSE3_HorizontalAdd, CORINFO_TYPE_UINT,
21433-
simdSize);
21432+
// tmp = Sse2.Multiply(tmp.AsUInt32(), b.AsUInt32());
21433+
tmp = gtNewSimdHWIntrinsicNode(type, tmp, op2Dup2, muludq, CORINFO_TYPE_ULONG, simdSize);
2143421434

21435-
// Vector256<int> tmp4 = Avx2.Shuffle(tmp3, YWXW);
21436-
shuffleMask = gtNewIconNode(SHUFFLE_YWXW, TYP_INT);
21437-
GenTreeHWIntrinsic* tmp4 =
21438-
gtNewSimdHWIntrinsicNode(type, tmp3, shuffleMask, is256 ? NI_AVX2_Shuffle : NI_SSE2_Shuffle,
21439-
CORINFO_TYPE_UINT, simdSize);
21435+
// mid += tmp;
21436+
mid = gtNewSimdBinOpNode(GT_ADD, type, mid, tmp, simdBaseJitType, simdSize);
2144021437

21441-
// result = tmp0 + tmp4;
21442-
return gtNewSimdBinOpNode(GT_ADD, type, tmp0, tmp4, simdBaseJitType, simdSize);
21438+
// mid <<= 32;
21439+
mid = gtNewSimdBinOpNode(GT_LSH, type, mid, gtNewIconNode(32), simdBaseJitType, simdSize);
21440+
21441+
// return low + mid;
21442+
return gtNewSimdBinOpNode(GT_ADD, type, low, mid, simdBaseJitType, simdSize);
2144321443
}
2144421444
#elif defined(TARGET_ARM64)
2144521445
if (varTypeIsLong(simdBaseType))
@@ -26070,7 +26070,7 @@ GenTree* Compiler::gtNewSimdSumNode(var_types type, GenTree* op1, CorInfoType si
2607026070

2607126071
if (simdSize == 32)
2607226072
{
26073-
assert(compIsaSupportedDebugOnly(InstructionSet_AVX2));
26073+
assert(IsBaselineVector256IsaSupportedDebugOnly());
2607426074
GenTree* op1Dup = fgMakeMultiUse(&op1);
2607526075

2607626076
op1 = gtNewSimdGetLowerNode(TYP_SIMD16, op1, simdBaseJitType, simdSize);

src/coreclr/jit/hwintrinsiclistxarch.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,7 @@ HARDWARE_INTRINSIC(Vector512, Create,
302302
HARDWARE_INTRINSIC(Vector512, CreateScalar, 64, -1, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_SpecialImport|HW_Flag_NoCodeGen)
303303
HARDWARE_INTRINSIC(Vector512, CreateScalarUnsafe, 64, 1, {INS_movd, INS_movd, INS_movd, INS_movd, INS_movd, INS_movd, INS_movd, INS_movd, INS_movss, INS_movsd_simd}, HW_Category_SIMDScalar, HW_Flag_SpecialImport|HW_Flag_SpecialCodeGen)
304304
HARDWARE_INTRINSIC(Vector512, CreateSequence, 64, 2, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
305+
HARDWARE_INTRINSIC(Vector512, Dot, 64, 2, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_SpecialImport|HW_Flag_NoCodeGen|HW_Flag_BaseTypeFromFirstArg)
305306
HARDWARE_INTRINSIC(Vector512, Equals, 64, 2, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId|HW_Flag_BaseTypeFromFirstArg)
306307
HARDWARE_INTRINSIC(Vector512, EqualsAny, 64, 2, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId|HW_Flag_BaseTypeFromFirstArg)
307308
HARDWARE_INTRINSIC(Vector512, ExtractMostSignificantBits, 64, 1, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId|HW_Flag_BaseTypeFromFirstArg)

src/coreclr/jit/hwintrinsicxarch.cpp

Lines changed: 25 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -2450,40 +2450,40 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
24502450

24512451
case NI_Vector128_Dot:
24522452
case NI_Vector256_Dot:
2453+
case NI_Vector512_Dot:
24532454
{
24542455
assert(sig->numArgs == 2);
24552456
var_types simdType = getSIMDTypeForSize(simdSize);
24562457

2457-
if (varTypeIsByte(simdBaseType) || varTypeIsLong(simdBaseType))
2458+
if ((simdSize == 32) && !varTypeIsFloating(simdBaseType) &&
2459+
!compOpportunisticallyDependsOn(InstructionSet_AVX2))
24582460
{
2459-
// TODO-XARCH-CQ: We could support dot product for 8-bit and
2460-
// 64-bit integers if we support multiplication for the same
2461+
// We can't deal with TYP_SIMD32 for integral types if the compiler doesn't support AVX2
24612462
break;
24622463
}
24632464

2464-
if (simdSize == 32)
2465-
{
2466-
if (!varTypeIsFloating(simdBaseType) && !compOpportunisticallyDependsOn(InstructionSet_AVX2))
2467-
{
2468-
// We can't deal with TYP_SIMD32 for integral types if the compiler doesn't support AVX2
2469-
break;
2470-
}
2471-
}
2472-
else if ((simdBaseType == TYP_INT) || (simdBaseType == TYP_UINT))
2465+
#if defined(TARGET_X86)
2466+
if (varTypeIsLong(simdBaseType) && !compOpportunisticallyDependsOn(InstructionSet_SSE41))
24732467
{
2474-
if (!compOpportunisticallyDependsOn(InstructionSet_SSE41))
2475-
{
2476-
// TODO-XARCH-CQ: We can support 32-bit integers if we update multiplication
2477-
// to be lowered rather than imported as the relevant operations.
2478-
break;
2479-
}
2468+
// We need SSE41 to handle long; otherwise, use the software fallback
2469+
break;
24802470
}
2471+
#endif // TARGET_X86
24812472

24822473
op2 = impSIMDPopStack();
24832474
op1 = impSIMDPopStack();
24842475

2476+
if ((simdSize == 64) || varTypeIsByte(simdBaseType) || varTypeIsLong(simdBaseType) ||
2477+
(varTypeIsInt(simdBaseType) && !compOpportunisticallyDependsOn(InstructionSet_SSE41)))
2478+
{
2479+
// The lowering for Dot doesn't handle these cases, so import as Sum(left * right)
2480+
retNode = gtNewSimdBinOpNode(GT_MUL, simdType, op1, op2, simdBaseJitType, simdSize);
2481+
retNode = gtNewSimdSumNode(retType, retNode, simdBaseJitType, simdSize);
2482+
break;
2483+
}
2484+
24852485
retNode = gtNewSimdDotProdNode(simdType, op1, op2, simdBaseJitType, simdSize);
2486-
retNode = gtNewSimdGetElementNode(retType, retNode, gtNewIconNode(0), simdBaseJitType, simdSize);
2486+
retNode = gtNewSimdToScalarNode(retType, retNode, simdBaseJitType, simdSize);
24872487
break;
24882488
}
24892489

@@ -3349,28 +3349,14 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
33493349
break;
33503350
}
33513351

3352+
#if defined(TARGET_X86)
33523353
if (varTypeIsLong(simdBaseType))
33533354
{
3354-
if (TARGET_POINTER_SIZE == 4)
3355-
{
3356-
// TODO-XARCH-CQ: 32bit support
3357-
break;
3358-
}
3359-
3360-
if ((simdSize == 32) && compOpportunisticallyDependsOn(InstructionSet_AVX2))
3361-
{
3362-
// Emulate NI_AVX512DQ_VL_MultiplyLow with AVX2 for SIMD32
3363-
}
3364-
else if ((simdSize == 16) && compOpportunisticallyDependsOn(InstructionSet_SSE41))
3365-
{
3366-
// Emulate NI_AVX512DQ_VL_MultiplyLow with SSE41 for SIMD16
3367-
}
3368-
else if (simdSize != 64)
3369-
{
3370-
// Software fallback
3371-
break;
3372-
}
3355+
// TODO-XARCH-CQ: We can't handle long here, only because one of the args might
3356+
// be scalar, and gtNewSimdCreateBroadcastNode doesn't handle long on x86.
3357+
break;
33733358
}
3359+
#endif // TARGET_X86
33743360

33753361
CORINFO_ARG_LIST_HANDLE arg1 = sig->args;
33763362
CORINFO_ARG_LIST_HANDLE arg2 = info.compCompHnd->getArgNext(arg1);
@@ -3405,29 +3391,6 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
34053391
break;
34063392
}
34073393

3408-
if (varTypeIsLong(simdBaseType))
3409-
{
3410-
if (TARGET_POINTER_SIZE == 4)
3411-
{
3412-
// TODO-XARCH-CQ: 32bit support
3413-
break;
3414-
}
3415-
3416-
if ((simdSize == 32) && compOpportunisticallyDependsOn(InstructionSet_AVX2))
3417-
{
3418-
// Emulate NI_AVX512DQ_VL_MultiplyLow with AVX2 for SIMD32
3419-
}
3420-
else if ((simdSize == 16) && compOpportunisticallyDependsOn(InstructionSet_SSE41))
3421-
{
3422-
// Emulate NI_AVX512DQ_VL_MultiplyLow with SSE41 for SIMD16
3423-
}
3424-
else if (simdSize != 64)
3425-
{
3426-
// Software fallback
3427-
break;
3428-
}
3429-
}
3430-
34313394
op3 = impSIMDPopStack();
34323395
op2 = impSIMDPopStack();
34333396
op1 = impSIMDPopStack();
@@ -3818,17 +3781,8 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
38183781
{
38193782
assert(sig->numArgs == 1);
38203783

3821-
if ((simdSize == 32) && !compOpportunisticallyDependsOn(InstructionSet_AVX2))
3822-
{
3823-
// Vector256 requires AVX2
3824-
break;
3825-
}
3826-
else if ((simdSize == 16) && !compOpportunisticallyDependsOn(InstructionSet_SSE2))
3827-
{
3828-
break;
3829-
}
38303784
#if defined(TARGET_X86)
3831-
else if (varTypeIsLong(simdBaseType) && !compOpportunisticallyDependsOn(InstructionSet_SSE41))
3785+
if (varTypeIsLong(simdBaseType) && !compOpportunisticallyDependsOn(InstructionSet_SSE41))
38323786
{
38333787
// We need SSE41 to handle long; otherwise, use the software fallback
38343788
break;

0 commit comments

Comments
 (0)