Skip to content

Commit e0ecd1f

Browse files
Improve the handling of SIMD comparisons (#104944)
* Ensure that we can constant fold op_Equality and op_Inequality for SIMD * Optimize comparisons against AllBitsSet on pre-AVX512 hardware
1 parent 39968e7 commit e0ecd1f

File tree

3 files changed

+397
-88
lines changed

3 files changed

+397
-88
lines changed

src/coreclr/jit/gentree.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30940,6 +30940,32 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
3094030940
}
3094130941
#endif
3094230942

30943+
case NI_Vector128_op_Equality:
30944+
#if defined(TARGET_ARM64)
30945+
case NI_Vector64_op_Equality:
30946+
#elif defined(TARGET_XARCH)
30947+
case NI_Vector256_op_Equality:
30948+
case NI_Vector512_op_Equality:
30949+
#endif // !TARGET_ARM64 && !TARGET_XARCH
30950+
{
30951+
cnsNode->AsVecCon()->EvaluateBinaryInPlace(GT_EQ, isScalar, simdBaseType, otherNode->AsVecCon());
30952+
resultNode = gtNewIconNode(cnsNode->AsVecCon()->IsAllBitsSet() ? 1 : 0, retType);
30953+
break;
30954+
}
30955+
30956+
case NI_Vector128_op_Inequality:
30957+
#if defined(TARGET_ARM64)
30958+
case NI_Vector64_op_Inequality:
30959+
#elif defined(TARGET_XARCH)
30960+
case NI_Vector256_op_Inequality:
30961+
case NI_Vector512_op_Inequality:
30962+
#endif // !TARGET_ARM64 && !TARGET_XARCH
30963+
{
30964+
cnsNode->AsVecCon()->EvaluateBinaryInPlace(GT_NE, isScalar, simdBaseType, otherNode->AsVecCon());
30965+
resultNode = gtNewIconNode(cnsNode->AsVecCon()->IsZero() ? 0 : 1, retType);
30966+
break;
30967+
}
30968+
3094330969
default:
3094430970
{
3094530971
break;
@@ -31380,6 +31406,48 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
3138031406
}
3138131407
#endif
3138231408

31409+
case NI_Vector128_op_Equality:
31410+
#if defined(TARGET_ARM64)
31411+
case NI_Vector64_op_Equality:
31412+
#elif defined(TARGET_XARCH)
31413+
case NI_Vector256_op_Equality:
31414+
case NI_Vector512_op_Equality:
31415+
#endif // !TARGET_ARM64 && !TARGET_XARCH
31416+
{
31417+
if (varTypeIsFloating(simdBaseType))
31418+
{
31419+
// Handle `(x == NaN) == false` and `(NaN == x) == false` for floating-point types
31420+
if (cnsNode->IsVectorNaN(simdBaseType))
31421+
{
31422+
resultNode = gtNewIconNode(0, retType);
31423+
resultNode = gtWrapWithSideEffects(resultNode, otherNode, GTF_ALL_EFFECT);
31424+
break;
31425+
}
31426+
}
31427+
break;
31428+
}
31429+
31430+
case NI_Vector128_op_Inequality:
31431+
#if defined(TARGET_ARM64)
31432+
case NI_Vector64_op_Inequality:
31433+
#elif defined(TARGET_XARCH)
31434+
case NI_Vector256_op_Inequality:
31435+
case NI_Vector512_op_Inequality:
31436+
#endif // !TARGET_ARM64 && !TARGET_XARCH
31437+
{
31438+
if (varTypeIsFloating(simdBaseType))
31439+
{
31440+
// Handle `(x != NaN) == true` and `(NaN != x) == true` for floating-point types
31441+
if (cnsNode->IsVectorNaN(simdBaseType))
31442+
{
31443+
resultNode = gtNewIconNode(1, retType);
31444+
resultNode = gtWrapWithSideEffects(resultNode, otherNode, GTF_ALL_EFFECT);
31445+
break;
31446+
}
31447+
}
31448+
break;
31449+
}
31450+
3138331451
default:
3138431452
{
3138531453
break;

src/coreclr/jit/lowerxarch.cpp

Lines changed: 97 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -2577,7 +2577,7 @@ GenTree* Lowering::LowerHWIntrinsicCmpOp(GenTreeHWIntrinsic* node, genTreeOps cm
25772577
CorInfoType maskBaseJitType = simdBaseJitType;
25782578
var_types maskBaseType = simdBaseType;
25792579

2580-
if (op1Msk->OperIsHWIntrinsic(NI_EVEX_ConvertMaskToVector))
2580+
if (op1Msk->OperIsConvertMaskToVector())
25812581
{
25822582
GenTreeHWIntrinsic* cvtMaskToVector = op1Msk->AsHWIntrinsic();
25832583

@@ -2588,122 +2588,131 @@ GenTree* Lowering::LowerHWIntrinsicCmpOp(GenTreeHWIntrinsic* node, genTreeOps cm
25882588
maskBaseType = cvtMaskToVector->GetSimdBaseType();
25892589
}
25902590

2591-
if (!varTypeIsFloating(simdBaseType) && (simdSize != 64) && op2->IsVectorZero() &&
2592-
comp->compOpportunisticallyDependsOn(InstructionSet_SSE41) && !varTypeIsMask(op1Msk))
2591+
if (!varTypeIsFloating(simdBaseType) && (simdSize != 64) && !varTypeIsMask(op1Msk))
25932592
{
2594-
// On SSE4.1 or higher we can optimize comparisons against zero to
2595-
// just use PTEST. We can't support it for floating-point, however,
2596-
// as it has both +0.0 and -0.0 where +0.0 == -0.0
2593+
bool isOp2VectorZero = op2->IsVectorZero();
25972594

2598-
bool skipReplaceOperands = false;
2599-
2600-
if (op1->OperIsHWIntrinsic())
2595+
if ((isOp2VectorZero || op2->IsVectorAllBitsSet()) &&
2596+
comp->compOpportunisticallyDependsOn(InstructionSet_SSE41))
26012597
{
2602-
GenTreeHWIntrinsic* op1Intrinsic = op1->AsHWIntrinsic();
2603-
NamedIntrinsic op1IntrinsicId = op1Intrinsic->GetHWIntrinsicId();
2598+
// On SSE4.1 or higher we can optimize comparisons against Zero or AllBitsSet to
2599+
// just use PTEST. We can't support it for floating-point, however, as it has
2600+
// both +0.0 and -0.0 where +0.0 == -0.0
26042601

2605-
GenTree* nestedOp1 = nullptr;
2606-
GenTree* nestedOp2 = nullptr;
2607-
bool isEmbeddedBroadcast = false;
2602+
bool skipReplaceOperands = false;
26082603

2609-
if (op1Intrinsic->GetOperandCount() == 2)
2604+
if (!isOp2VectorZero)
26102605
{
2611-
nestedOp1 = op1Intrinsic->Op(1);
2612-
nestedOp2 = op1Intrinsic->Op(2);
2606+
// We can optimize to TestC(op1, allbitsset)
2607+
//
2608+
// This works out because TestC sets CF if (~x & y) == 0, so:
2609+
// ~00 & 11 = 11; 11 & 11 = 11; NC
2610+
// ~01 & 11 = 01; 10 & 11 = 10; NC
2611+
// ~10 & 11 = 10; 01 & 11 = 01; NC
2612+
// ~11 & 11 = 11; 00 & 11 = 00; C
26132613

2614-
assert(!nestedOp1->isContained());
2615-
isEmbeddedBroadcast = nestedOp2->isContained() && nestedOp2->OperIsHWIntrinsic();
2616-
}
2614+
assert(op2->IsVectorAllBitsSet());
2615+
cmpCnd = (cmpOp == GT_EQ) ? GenCondition::C : GenCondition::NC;
26172616

2618-
switch (op1IntrinsicId)
2617+
skipReplaceOperands = true;
2618+
}
2619+
else if (op1->OperIsHWIntrinsic())
26192620
{
2620-
case NI_SSE_And:
2621-
case NI_SSE2_And:
2622-
case NI_AVX_And:
2623-
case NI_AVX2_And:
2621+
assert(op2->IsVectorZero());
2622+
2623+
GenTreeHWIntrinsic* op1Intrinsic = op1->AsHWIntrinsic();
2624+
2625+
if (op1Intrinsic->GetOperandCount() == 2)
26242626
{
2625-
// We can optimize to TestZ(op1.op1, op1.op2)
2627+
GenTree* nestedOp1 = op1Intrinsic->Op(1);
2628+
GenTree* nestedOp2 = op1Intrinsic->Op(2);
2629+
2630+
assert(!nestedOp1->isContained());
2631+
bool isEmbeddedBroadcast = nestedOp2->isContained() && nestedOp2->OperIsHWIntrinsic();
26262632

2627-
if (isEmbeddedBroadcast)
2633+
bool isScalar = false;
2634+
genTreeOps oper = op1Intrinsic->GetOperForHWIntrinsicId(&isScalar);
2635+
2636+
switch (oper)
26282637
{
2629-
// PTEST doesn't support embedded broadcast
2630-
break;
2631-
}
2638+
case GT_AND:
2639+
{
2640+
// We can optimize to TestZ(op1.op1, op1.op2)
26322641

2633-
node->Op(1) = nestedOp1;
2634-
node->Op(2) = nestedOp2;
2642+
if (isEmbeddedBroadcast)
2643+
{
2644+
// PTEST doesn't support embedded broadcast
2645+
break;
2646+
}
26352647

2636-
BlockRange().Remove(op1);
2637-
BlockRange().Remove(op2);
2648+
node->Op(1) = nestedOp1;
2649+
node->Op(2) = nestedOp2;
26382650

2639-
skipReplaceOperands = true;
2640-
break;
2641-
}
2651+
BlockRange().Remove(op1);
2652+
BlockRange().Remove(op2);
26422653

2643-
case NI_SSE_AndNot:
2644-
case NI_SSE2_AndNot:
2645-
case NI_AVX_AndNot:
2646-
case NI_AVX2_AndNot:
2647-
{
2648-
// We can optimize to TestC(op1.op1, op1.op2)
2654+
skipReplaceOperands = true;
2655+
break;
2656+
}
26492657

2650-
if (isEmbeddedBroadcast)
2651-
{
2652-
// PTEST doesn't support embedded broadcast
2653-
break;
2654-
}
2658+
case GT_AND_NOT:
2659+
{
2660+
// We can optimize to TestC(op1.op1, op1.op2)
2661+
2662+
if (isEmbeddedBroadcast)
2663+
{
2664+
// PTEST doesn't support embedded broadcast
2665+
break;
2666+
}
26552667

2656-
cmpCnd = (cmpOp == GT_EQ) ? GenCondition::C : GenCondition::NC;
2668+
cmpCnd = (cmpOp == GT_EQ) ? GenCondition::C : GenCondition::NC;
26572669

2658-
node->Op(1) = nestedOp1;
2659-
node->Op(2) = nestedOp2;
2670+
node->Op(1) = nestedOp1;
2671+
node->Op(2) = nestedOp2;
26602672

2661-
BlockRange().Remove(op1);
2662-
BlockRange().Remove(op2);
2673+
BlockRange().Remove(op1);
2674+
BlockRange().Remove(op2);
26632675

2664-
skipReplaceOperands = true;
2665-
break;
2666-
}
2676+
skipReplaceOperands = true;
2677+
break;
2678+
}
26672679

2668-
default:
2669-
{
2670-
break;
2680+
default:
2681+
{
2682+
break;
2683+
}
2684+
}
26712685
}
26722686
}
2673-
}
2674-
2675-
if (!skipReplaceOperands)
2676-
{
2677-
// Default handler, emit a TestZ(op1, op1)
26782687

2679-
node->Op(1) = op1;
2680-
BlockRange().Remove(op2);
2688+
if (!skipReplaceOperands)
2689+
{
2690+
// Default handler, emit a TestZ(op1, op1)
2691+
assert(op2->IsVectorZero());
26812692

2682-
LIR::Use op1Use(BlockRange(), &node->Op(1), node);
2683-
ReplaceWithLclVar(op1Use);
2684-
op1 = node->Op(1);
2693+
node->Op(1) = op1;
2694+
BlockRange().Remove(op2);
26852695

2686-
op2 = comp->gtClone(op1);
2687-
BlockRange().InsertAfter(op1, op2);
2688-
node->Op(2) = op2;
2689-
}
2696+
LIR::Use op1Use(BlockRange(), &node->Op(1), node);
2697+
ReplaceWithLclVar(op1Use);
2698+
op1 = node->Op(1);
26902699

2691-
if (simdSize == 32)
2692-
{
2693-
// TODO-Review: LowerHWIntrinsicCC resets the id again, so why is this needed?
2694-
node->ChangeHWIntrinsicId(NI_AVX_TestZ);
2695-
LowerHWIntrinsicCC(node, NI_AVX_PTEST, cmpCnd);
2696-
}
2697-
else
2698-
{
2699-
assert(simdSize == 16);
2700+
op2 = comp->gtClone(op1);
2701+
BlockRange().InsertAfter(op1, op2);
2702+
node->Op(2) = op2;
2703+
}
27002704

2701-
// TODO-Review: LowerHWIntrinsicCC resets the id again, so why is this needed?
2702-
node->ChangeHWIntrinsicId(NI_SSE41_TestZ);
2703-
LowerHWIntrinsicCC(node, NI_SSE41_PTEST, cmpCnd);
2705+
if (simdSize == 32)
2706+
{
2707+
LowerHWIntrinsicCC(node, NI_AVX_PTEST, cmpCnd);
2708+
}
2709+
else
2710+
{
2711+
assert(simdSize == 16);
2712+
LowerHWIntrinsicCC(node, NI_SSE41_PTEST, cmpCnd);
2713+
}
2714+
return LowerNode(node);
27042715
}
2705-
2706-
return LowerNode(node);
27072716
}
27082717

27092718
// TODO-XARCH-AVX512: We should handle TYP_SIMD12 here under the EVEX path, but doing
@@ -3579,7 +3588,7 @@ GenTree* Lowering::LowerHWIntrinsicTernaryLogic(GenTreeHWIntrinsic* node)
35793588
}
35803589
}
35813590

3582-
if (condition->OperIsHWIntrinsic(NI_EVEX_ConvertMaskToVector))
3591+
if (condition->OperIsConvertMaskToVector())
35833592
{
35843593
GenTree* tmp = condition->AsHWIntrinsic()->Op(1);
35853594
BlockRange().Remove(condition);

0 commit comments

Comments
 (0)