Skip to content

Commit 4a06ac3

Browse files
authored
Arm64 SVE: Fix conditionalselect with constant arguments (#116852)
* Arm64 SVE: Fix conditionalselect with constant arguments Fixes #116847 When folding, allow arg1 to be a constant mask * Make masked EvaluateBinaryInPlace() Arm64 only * Check significantBit in EvaluateSimdVectorToPattern() * fix set checks in EvaluateSimdVectorToPattern * Use masks in EvalHWIntrinsicFunTernary() for SVE conditionalselect * Check all of a vector lane when converting to mask * Add testing for EvalHWIntrinsicFunTernary changes * whitespace * Revert "Check all of a vector lane when converting to mask" This reverts commit b923b28. * rename significantBit to leastSignificantBit * Use LSB of vector when converting from vector to mask * Add LowerCnsMask * Add testcase * Remove EvaluateSimdMaskToPattern * Revert "Use LSB of vector when converting from vector to mask" This reverts commit c96e38c. * formatting * fix assert check Change-Id: I7951b70aec9aaef5521e100d30737b5a4d332b38 * GenTree for gtNewSimdCvtVectorToMaskNode() * Split NI_Sve_ConditionalSelect into its own case * Remove mask version of EvaluateBinaryInPlace * remove assert * Check all bits in EvaluateSimdCvtVectorToMask * Add ConstantVectors test * No need for DOTNET_EnableHWIntrinsic in csproj * Use IsMaskZero * Remove EvaluateBinarySimdAndMask * In lowering, default the mask type to byte * In lowering, convert mask using byte basetype
1 parent e4b07f4 commit 4a06ac3

File tree

10 files changed

+392
-31
lines changed

10 files changed

+392
-31
lines changed

src/coreclr/jit/gentree.cpp

Lines changed: 59 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33525,28 +33525,63 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
3352533525
case NI_Vector512_ConditionalSelect:
3352633526
#elif defined(TARGET_ARM64)
3352733527
case NI_AdvSimd_BitwiseSelect:
33528-
case NI_Sve_ConditionalSelect:
3352933528
#endif
3353033529
{
3353133530
assert(!varTypeIsMask(retType));
33531+
assert(!varTypeIsMask(op1));
3353233532

3353333533
if (cnsNode != op1)
3353433534
{
3353533535
break;
3353633536
}
3353733537

33538-
#if defined(TARGET_ARM64)
33539-
if (ni == NI_Sve_ConditionalSelect)
33538+
if (op1->IsVectorAllBitsSet())
3354033539
{
33541-
assert(!op1->IsVectorAllBitsSet() && !op1->IsVectorZero());
33540+
if ((op3->gtFlags & GTF_SIDE_EFFECT) != 0)
33541+
{
33542+
// op3 has side effects, this would require us to append a new statement
33543+
// to ensure that it isn't lost, which isn't safe to do from the general
33544+
// purpose handler here. We'll recognize this and mark it in VN instead
33545+
break;
33546+
}
33547+
33548+
// op3 has no side effects, so we can return op2 directly
33549+
return op2;
3354233550
}
33543-
else
33551+
33552+
if (op1->IsVectorZero())
3354433553
{
33545-
assert(!op1->IsTrueMask(simdBaseType) && !op1->IsMaskZero());
33554+
return gtWrapWithSideEffects(op3, op2, GTF_ALL_EFFECT);
33555+
}
33556+
33557+
if (op2->IsCnsVec() && op3->IsCnsVec())
33558+
{
33559+
// op2 = op2 & op1
33560+
op2->AsVecCon()->EvaluateBinaryInPlace(GT_AND, false, simdBaseType, op1->AsVecCon());
33561+
33562+
// op3 = op3 & ~op1
33563+
op3->AsVecCon()->EvaluateBinaryInPlace(GT_AND_NOT, false, simdBaseType, op1->AsVecCon());
33564+
33565+
// op2 = op2 | op3
33566+
op2->AsVecCon()->EvaluateBinaryInPlace(GT_OR, false, simdBaseType, op3->AsVecCon());
33567+
33568+
resultNode = op2;
33569+
}
33570+
break;
33571+
}
33572+
33573+
#if defined(TARGET_ARM64)
33574+
case NI_Sve_ConditionalSelect:
33575+
{
33576+
assert(!varTypeIsMask(retType));
33577+
assert(varTypeIsMask(op1));
33578+
33579+
if (cnsNode != op1)
33580+
{
33581+
break;
3354633582
}
33547-
#endif
3354833583

33549-
if (op1->IsVectorAllBitsSet() || op1->IsTrueMask(simdBaseType))
33584+
if (op1->IsTrueMask(simdBaseType))
3355033585
{
3355133586
if ((op3->gtFlags & GTF_SIDE_EFFECT) != 0)
3355233587
{
@@ -33560,18 +33595,30 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
3356033595
return op2;
3356133596
}
3356233597

33563-
if (op1->IsVectorZero() || op1->IsMaskZero())
33598+
if (op1->IsMaskZero())
3356433599
{
3356533600
return gtWrapWithSideEffects(op3, op2, GTF_ALL_EFFECT);
3356633601
}
3356733602

3356833603
if (op2->IsCnsVec() && op3->IsCnsVec())
3356933604
{
33605+
assert(op2->gtType == TYP_SIMD16);
33606+
assert(op3->gtType == TYP_SIMD16);
33607+
33608+
simd16_t op1SimdVal;
33609+
EvaluateSimdCvtMaskToVector<simd16_t>(simdBaseType, &op1SimdVal, op1->AsMskCon()->gtSimdMaskVal);
33610+
3357033611
// op2 = op2 & op1
33571-
op2->AsVecCon()->EvaluateBinaryInPlace(GT_AND, false, simdBaseType, op1->AsVecCon());
33612+
simd16_t result = {};
33613+
EvaluateBinarySimd<simd16_t>(GT_AND, false, simdBaseType, &result, op2->AsVecCon()->gtSimd16Val,
33614+
op1SimdVal);
33615+
op2->AsVecCon()->gtSimd16Val = result;
3357233616

3357333617
// op3 = op3 & ~op1
33574-
op3->AsVecCon()->EvaluateBinaryInPlace(GT_AND_NOT, false, simdBaseType, op1->AsVecCon());
33618+
result = {};
33619+
EvaluateBinarySimd<simd16_t>(GT_AND_NOT, false, simdBaseType, &result, op3->AsVecCon()->gtSimd16Val,
33620+
op1SimdVal);
33621+
op3->AsVecCon()->gtSimd16Val = result;
3357533622

3357633623
// op2 = op2 | op3
3357733624
op2->AsVecCon()->EvaluateBinaryInPlace(GT_OR, false, simdBaseType, op3->AsVecCon());
@@ -33580,6 +33627,7 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
3358033627
}
3358133628
break;
3358233629
}
33630+
#endif // TARGET_ARM64
3358333631

3358433632
default:
3358533633
{

src/coreclr/jit/lower.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,11 @@ GenTree* Lowering::LowerNode(GenTree* node)
789789
LowerReturnSuspend(node);
790790
break;
791791

792+
#if defined(FEATURE_HW_INTRINSICS) && defined(TARGET_ARM64)
793+
case GT_CNS_MSK:
794+
return LowerCnsMask(node->AsMskCon());
795+
#endif // FEATURE_HW_INTRINSICS && TARGET_ARM64
796+
792797
default:
793798
break;
794799
}

src/coreclr/jit/lower.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -451,11 +451,12 @@ class Lowering final : public Phase
451451
GenTree* TryLowerXorOpToGetMaskUpToLowestSetBit(GenTreeOp* xorNode);
452452
void LowerBswapOp(GenTreeOp* node);
453453
#elif defined(TARGET_ARM64)
454-
bool IsValidConstForMovImm(GenTreeHWIntrinsic* node);
455-
void LowerHWIntrinsicFusedMultiplyAddScalar(GenTreeHWIntrinsic* node);
456-
void LowerModPow2(GenTree* node);
457-
bool TryLowerAddForPossibleContainment(GenTreeOp* node, GenTree** next);
458-
void StoreFFRValue(GenTreeHWIntrinsic* node);
454+
bool IsValidConstForMovImm(GenTreeHWIntrinsic* node);
455+
void LowerHWIntrinsicFusedMultiplyAddScalar(GenTreeHWIntrinsic* node);
456+
void LowerModPow2(GenTree* node);
457+
GenTree* LowerCnsMask(GenTreeMskCon* mask);
458+
bool TryLowerAddForPossibleContainment(GenTreeOp* node, GenTree** next);
459+
void StoreFFRValue(GenTreeHWIntrinsic* node);
459460
#endif // !TARGET_XARCH && !TARGET_ARM64
460461
GenTree* InsertNewSimdCreateScalarUnsafeNode(var_types type,
461462
GenTree* op1,

src/coreclr/jit/lowerarmarch.cpp

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1134,6 +1134,77 @@ void Lowering::LowerModPow2(GenTree* node)
11341134
ContainCheckNode(mod);
11351135
}
11361136

1137+
//------------------------------------------------------------------------
1138+
// LowerCnsMask: Lower GT_CNS_MSK. Ensure the mask matches a known pattern.
1139+
// If not then lower to a constant vector.
1140+
//
1141+
// Arguments:
1142+
// mask - the node to lower
1143+
//
// Return Value:
// The node following the mask when the mask is kept as-is (it is zero or
// matches a pattern for some element size); otherwise the node following
// the newly inserted constant vector that replaces the mask.
//
1144+
GenTree* Lowering::LowerCnsMask(GenTreeMskCon* mask)
1145+
{
1146+
// Try every type until a match is found
1147+
1148+
// A zero mask needs no pattern lookup; it is kept unchanged.
if (mask->IsZero())
1149+
{
1150+
return mask->gtNext;
1151+
}
1152+
1153+
// Check the mask against the known patterns for each element size in turn
// (byte, short, int, long). The first match means the mask is kept unchanged.
if (EvaluateSimdMaskToPattern<simd16_t>(TYP_BYTE, mask->gtSimdMaskVal) != SveMaskPatternNone)
1154+
{
1155+
return mask->gtNext;
1156+
}
1157+
1158+
if (EvaluateSimdMaskToPattern<simd16_t>(TYP_SHORT, mask->gtSimdMaskVal) != SveMaskPatternNone)
1159+
{
1160+
return mask->gtNext;
1161+
}
1162+
1163+
if (EvaluateSimdMaskToPattern<simd16_t>(TYP_INT, mask->gtSimdMaskVal) != SveMaskPatternNone)
1164+
{
1165+
return mask->gtNext;
1166+
}
1167+
1168+
if (EvaluateSimdMaskToPattern<simd16_t>(TYP_LONG, mask->gtSimdMaskVal) != SveMaskPatternNone)
1169+
{
1170+
return mask->gtNext;
1171+
}
1172+
1173+
// Not a valid pattern, so cannot be created using ptrue/pfalse. Instead the mask will require
1174+
// loading from memory. There is no way to load to a predicate from memory using a PC relative
1175+
// address, so instead use a constant vector plus conversion to mask. Using basetype byte will
1176+
// ensure every entry in the mask is converted.
1177+
1178+
LABELEDDISPTREERANGE("lowering cns mask to cns vector (before)", BlockRange(), mask);
1179+
1180+
// Create a vector constant
// The vector value is computed from the mask constant with byte elements, and the
// new node is placed immediately before the mask in the block's range.
1181+
GenTreeVecCon* vecCon = comp->gtNewVconNode(TYP_SIMD16);
1182+
EvaluateSimdCvtMaskToVector<simd16_t>(TYP_BYTE, &vecCon->gtSimdVal, mask->gtSimdMaskVal);
1183+
BlockRange().InsertBefore(mask, vecCon);
1184+
1185+
// Convert the vector constant to a mask
// NOTE(review): vecCon is already in the range, so Op(1) of the conversion is presumably
// an extra operand (e.g. a governing all-true mask) created by
// gtNewSimdCvtVectorToMaskNode that must also be inserted — confirm against that helper.
1186+
GenTree* convertedVec = comp->gtNewSimdCvtVectorToMaskNode(TYP_MASK, vecCon, CORINFO_TYPE_BYTE, 16);
1187+
BlockRange().InsertBefore(mask, convertedVec->AsHWIntrinsic()->Op(1));
1188+
BlockRange().InsertBefore(mask, convertedVec);
1189+
1190+
// Update use
// If some node consumes the mask, redirect it to the conversion; otherwise flag the
// conversion's value as unused.
1191+
LIR::Use use;
1192+
if (BlockRange().TryGetUse(mask, &use))
1193+
{
1194+
use.ReplaceWith(convertedVec);
1195+
}
1196+
else
1197+
{
1198+
convertedVec->SetUnusedValue();
1199+
}
1200+
1201+
// The original mask constant has been replaced, so drop it from the range.
BlockRange().Remove(mask);
1202+
1203+
LABELEDDISPTREERANGE("lowering cns mask to cns vector (after)", BlockRange(), vecCon);
1204+
1205+
// Return the node that follows the new constant vector.
return vecCon->gtNext;
1206+
}
1207+
11371208
const int POST_INDEXED_ADDRESSING_MAX_DISTANCE = 16;
11381209

11391210
//------------------------------------------------------------------------

src/coreclr/jit/simd.h

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1598,35 +1598,32 @@ void EvaluateSimdCvtVectorToMask(simdmask_t* result, TSimd arg0)
15981598
uint32_t count = sizeof(TSimd) / sizeof(TBase);
15991599
uint64_t mask = 0;
16001600

1601-
TBase significantBit = 1;
16021601
#if defined(TARGET_XARCH)
1603-
significantBit = static_cast<TBase>(1) << ((sizeof(TBase) * 8) - 1);
1602+
TBase MostSignificantBit = static_cast<TBase>(1) << ((sizeof(TBase) * 8) - 1);
16041603
#endif
16051604

16061605
for (uint32_t i = 0; i < count; i++)
16071606
{
16081607
TBase input0;
16091608
memcpy(&input0, &arg0.u8[i * sizeof(TBase)], sizeof(TBase));
16101609

1611-
if ((input0 & significantBit) != 0)
1612-
{
16131610
#if defined(TARGET_XARCH)
1614-
// For xarch we have count sequential bits to write
1615-
// depending on if the corresponding the input element
1616-
// has its most significant bit set
1617-
1611+
// For xarch we have count sequential bits to write depending on if the
1612+
// corresponding input element has its most significant bit set
1613+
if ((input0 & MostSignificantBit) != 0)
1614+
{
16181615
mask |= static_cast<uint64_t>(1) << i;
1616+
}
16191617
#elif defined(TARGET_ARM64)
1620-
// For Arm64 we have count total bits to write, but
1621-
// they are sizeof(TBase) bits apart. We set
1622-
// depending on if the corresponding input element
1623-
// has its least significant bit set
1624-
1618+
// For Arm64 we have count total bits to write, but they are sizeof(TBase) bits
1619+
// apart. We set depending on if the corresponding input element is non zero
1620+
if (input0 != 0)
1621+
{
16251622
mask |= static_cast<uint64_t>(1) << (i * sizeof(TBase));
1623+
}
16261624
#else
1627-
unreached();
1625+
unreached();
16281626
#endif
1629-
}
16301627
}
16311628

16321629
memcpy(&result->u8[0], &mask, sizeof(uint64_t));

src/coreclr/jit/valuenum.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9145,6 +9145,30 @@ ValueNum ValueNumStore::EvalHWIntrinsicFunTernary(
91459145
{
91469146
// (y & x) | (z & ~x)
91479147

9148+
#if defined(TARGET_ARM64)
9149+
if (ni == NI_Sve_ConditionalSelect)
9150+
{
9151+
assert(TypeOfVN(arg0VN) == TYP_MASK);
9152+
assert(type == TYP_SIMD16);
9153+
9154+
ValueNum maskVNSimd = EvaluateSimdCvtMaskToVector(this, type, baseType, arg0VN);
9155+
simd16_t maskVal = ::GetConstantSimd16(this, baseType, maskVNSimd);
9156+
9157+
simd16_t arg1 = ::GetConstantSimd16(this, baseType, arg1VN);
9158+
simd16_t arg2 = ::GetConstantSimd16(this, baseType, arg2VN);
9159+
9160+
simd16_t result = {};
9161+
EvaluateBinarySimd<simd16_t>(GT_AND, false, baseType, &result, arg1, maskVal);
9162+
ValueNum trueVN = VNForSimd16Con(result);
9163+
9164+
result = {};
9165+
EvaluateBinarySimd<simd16_t>(GT_AND_NOT, false, baseType, &result, arg2, maskVal);
9166+
ValueNum falseVN = VNForSimd16Con(result);
9167+
9168+
return EvaluateBinarySimd(this, GT_OR, false, type, baseType, trueVN, falseVN);
9169+
}
9170+
#endif // TARGET_ARM64
9171+
91489172
ValueNum trueVN = EvaluateBinarySimd(this, GT_AND, false, type, baseType, arg1VN, arg0VN);
91499173
ValueNum falseVN = EvaluateBinarySimd(this, GT_AND_NOT, false, type, baseType, arg2VN, arg0VN);
91509174

0 commit comments

Comments
 (0)