Skip to content

Constant folding for SIMD comparisons #85584

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/coreclr/jit/simd.h
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,12 @@ TBase EvaluateBinaryScalarRSZ(TBase arg0, TBase arg1)
return arg0 >> (arg1 & ((sizeof(TBase) * 8) - 1));
}

// Produces the TBase value whose every bit is 1. Comparison evaluation uses it
// as the "true" result for an element.
template <typename TBase>
TBase GetAllBitsSetScalar()
{
    const TBase zero = static_cast<TBase>(0);

    // Complementing a narrow integer promotes to int, so convert back explicitly.
    return static_cast<TBase>(~zero);
}

template <>
inline int8_t EvaluateBinaryScalarRSZ<int8_t>(int8_t arg0, int8_t arg1)
{
Expand Down Expand Up @@ -520,6 +526,26 @@ TBase EvaluateBinaryScalarSpecialized(genTreeOps oper, TBase arg0, TBase arg1)
return arg0 ^ arg1;
}

case GT_EQ:
{
// Equality of two elements: the result is an all-bits-set TBase when the operands
// match and zero otherwise.
#ifdef _MSC_VER
// Floating point is not supported
// NOTE(review): these asserts compare the addresses of type_info objects; the C++
// standard does not guarantee one unique type_info object per type, so this check is
// presumably relied upon only under MSVC (hence the guard) - confirm.
assert(&typeid(TBase) != &typeid(float));
assert(&typeid(TBase) != &typeid(double));
#endif // _MSC_VER
return arg0 == arg1 ? GetAllBitsSetScalar<TBase>() : 0;
}

case GT_NE:
{
// Inequality of two elements: the result is an all-bits-set TBase when the operands
// differ and zero otherwise.
#ifdef _MSC_VER
// Floating point is not supported
// NOTE(review): these asserts compare the addresses of type_info objects; the C++
// standard does not guarantee one unique type_info object per type, so this check is
// presumably relied upon only under MSVC (hence the guard) - confirm.
assert(&typeid(TBase) != &typeid(float));
assert(&typeid(TBase) != &typeid(double));
#endif // _MSC_VER
return arg0 != arg1 ? GetAllBitsSetScalar<TBase>() : 0;
}

default:
{
unreached();
Expand Down
111 changes: 105 additions & 6 deletions src/coreclr/jit/valuenum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6812,6 +6812,47 @@ ValueNum EvaluateBinarySimd(ValueNumStore* vns,
}
#endif // TARGET_XARCH

case TYP_BOOL:
{
assert((oper == GT_EQ) || (oper == GT_NE));

var_types vn1Type = vns->TypeOfVN(arg0VN);
var_types vn2Type = vns->TypeOfVN(arg1VN);

assert((vn1Type == vn2Type) && varTypeIsSIMD(vn1Type));
assert(!varTypeIsFloating(baseType));

ValueNum packed = EvaluateBinarySimd(vns, GT_EQ, scalar, vn1Type, baseType, arg0VN, arg1VN);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this be EvaluateBinarySimd(vns, oper, ...)?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it would be simpler to implement this as EvaluateVector and then simply check whether the result IsAllBitsSet or !IsZero

Which would also simplify the other relational comparisons

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this be EvaluateBinarySimd(vns, oper, ...)?

It makes it harder to check output, doesn't it?
Currently I just pass EQ and then return AllBitsSet or !AllBitsSet depending on source oper

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it? We have a simple property for any simd_T and it makes it easier to cover the other cases.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is, if you just evaluate per element you get a simd_T result and can then just check IsAllBitsSet or IsZero

// Read `packed` (the per-element GT_EQ result computed above) back as a SIMD constant
// of the operands' width and record whether every bit of it is set.
bool allBitsSet = false;
if (vn1Type == TYP_SIMD8)
{
allBitsSet = GetConstantSimd8(vns, baseType, packed).IsAllBitsSet();
}
else if (vn1Type == TYP_SIMD12)
{
allBitsSet = GetConstantSimd12(vns, baseType, packed).IsAllBitsSet();
}
else if (vn1Type == TYP_SIMD16)
{
allBitsSet = GetConstantSimd16(vns, baseType, packed).IsAllBitsSet();
}
// The 32- and 64-byte widths are only considered when building for xarch.
#if defined(TARGET_XARCH)
else if (vn1Type == TYP_SIMD32)
{
allBitsSet = GetConstantSimd32(vns, baseType, packed).IsAllBitsSet();
}
else if (vn1Type == TYP_SIMD64)
{
allBitsSet = GetConstantSimd64(vns, baseType, packed).IsAllBitsSet();
}
#endif
else
{
unreached();
}
// GT_EQ folds to allBitsSet; otherwise the negation is returned, so the result is true
// as soon as the packed comparison is not all-bits-set.
return vns->VNForIntCon(oper == GT_EQ ? allBitsSet : !allBitsSet);
}

default:
{
unreached();
Expand Down Expand Up @@ -7167,6 +7208,43 @@ ValueNum ValueNumStore::EvalHWIntrinsicFunBinary(var_types type,

switch (ni)
{
// Equality intrinsics (op_Equality / EqualsAll). The set of vector widths listed
// depends on whether TARGET_ARM64 is defined.
#ifdef TARGET_ARM64
case NI_Vector64_op_Equality:
case NI_Vector128_op_Equality:
case NI_Vector64_EqualsAll:
case NI_Vector128_EqualsAll:
#else
case NI_Vector128_op_Equality:
case NI_Vector256_op_Equality:
case NI_Vector512_op_Equality:
case NI_Vector128_EqualsAll:
case NI_Vector256_EqualsAll:
case NI_Vector512_EqualsAll:
#endif
{
// Only non-floating-point base types are evaluated here, via EvaluateBinarySimd with
// GT_EQ; floating-point base types break out of the switch instead.
if (!varTypeIsFloating(baseType))
{
return EvaluateBinarySimd(this, GT_EQ, /* scalar */ false, type, baseType, arg0VN, arg1VN);
}
break;
}

// Inequality intrinsics (op_Inequality). The set of vector widths listed depends on
// whether TARGET_ARM64 is defined.
#ifdef TARGET_ARM64
case NI_Vector64_op_Inequality:
case NI_Vector128_op_Inequality:
#else
case NI_Vector128_op_Inequality:
case NI_Vector256_op_Inequality:
case NI_Vector512_op_Inequality:
#endif
{
// Only non-floating-point base types are evaluated here, via EvaluateBinarySimd with
// GT_NE; floating-point base types break out of the switch instead.
if (!varTypeIsFloating(baseType))
{
return EvaluateBinarySimd(this, GT_NE, /* scalar */ false, type, baseType, arg0VN, arg1VN);
}
break;
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to handle GT, LT, GE, and LE simultaneously?

Likewise given you've added the support for computing the vector version in order to make bool work, should we just handle the intrinsics that produce a vector as well?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Those had no hits so I just didn't want to add more code and tests 🙂 Although, even EQ/NE don't have hits, I just found a use case outside.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is the type of scenario where it’s a core comparison on SIMD and so we want to generally handle it even if our own code doesn’t have any hits today.

It’s odd to cover some of the relational comparisons and not the others.


#ifdef TARGET_ARM64
case NI_AdvSimd_Add:
case NI_AdvSimd_Arm64_Add:
Expand Down Expand Up @@ -10269,14 +10347,35 @@ void Compiler::fgValueNumberSsaVarDef(GenTreeLclVarCommon* lcl)
static bool GetStaticFieldSeqAndAddress(ValueNumStore* vnStore, GenTree* tree, ssize_t* byteOffset, FieldSeq** pFseq)
{
VNFuncApp funcApp;
if (vnStore->GetVNFunc(tree->gtVNPair.GetLiberal(), &funcApp) && (funcApp.m_func == VNF_PtrToStatic))
if (vnStore->GetVNFunc(tree->gtVNPair.GetLiberal(), &funcApp))
{
FieldSeq* fseq = vnStore->FieldSeqVNToFieldSeq(funcApp.m_args[1]);
if (fseq->GetKind() == FieldSeq::FieldKind::SimpleStatic)
if (funcApp.m_func == VNF_PtrToStatic)
{
*byteOffset = vnStore->ConstantValue<ssize_t>(funcApp.m_args[2]);
*pFseq = fseq;
return true;
FieldSeq* fseq = vnStore->FieldSeqVNToFieldSeq(funcApp.m_args[1]);
if (fseq->GetKind() == FieldSeq::FieldKind::SimpleStatic)
{
*byteOffset = vnStore->ConstantValue<ssize_t>(funcApp.m_args[2]);
*pFseq = fseq;
return true;
}
}
else if (funcApp.m_func == VNFunc(GT_ADD))
{
// Handle ADD(STATIC_HDL, OFFSET) via VN (the logic in this method mostly works with plain tree nodes)
if (vnStore->IsVNHandle(funcApp.m_args[0]) &&
(vnStore->GetHandleFlags(funcApp.m_args[0]) == GTF_ICON_STATIC_HDL) &&
vnStore->IsVNConstant(funcApp.m_args[1]) && !vnStore->IsVNHandle(funcApp.m_args[1]))
{
FieldSeq* fldSeq = vnStore->GetFieldSeqFromAddress(funcApp.m_args[0]);
if (fldSeq != nullptr)
{
assert(fldSeq->GetKind() == FieldSeq::FieldKind::SimpleStaticKnownAddress);
*pFseq = fldSeq;
*byteOffset = vnStore->CoercedConstantValue<ssize_t>(funcApp.m_args[0]) - fldSeq->GetOffset() +
vnStore->CoercedConstantValue<ssize_t>(funcApp.m_args[1]);
return true;
}
}
}
}
ssize_t val = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -733,4 +733,34 @@ public static void XorTests()
^ Vector128.Create((double)(+1), +1)
);
}

// Checks Vector128 ==, !=, EqualsAll and Equals on operands built from literal values.
// NOTE(review): the operands are presumably written inline, rather than stored in locals,
// so that the JIT sees constant vectors it can fold - confirm before refactoring.
[Fact]
public static void EqualityTests()
{
// Identical byte vectors: == must be true.
Assert.True(
Vector128.Create((byte)(0x01), 0xFE, 0x03, 0xFC, 0x05, 0xFA, 0x07, 0xF8, 0x09, 0xF6, 0x0B, 0xF4, 0x0D, 0xF2, 0x0F, 0xF0) ==
Vector128.Create((byte)(0x01), 0xFE, 0x03, 0xFC, 0x05, 0xFA, 0x07, 0xF8, 0x09, 0xF6, 0x0B, 0xF4, 0x0D, 0xF2, 0x0F, 0xF0));
// Identical byte vectors: != must be false.
Assert.False(
Vector128.Create((byte)(0x01), 0xFE, 0x03, 0xFC, 0x05, 0xFA, 0x07, 0xF8, 0x09, 0xF6, 0x0B, 0xF4, 0x0D, 0xF2, 0x0F, 0xF0) !=
Vector128.Create((byte)(0x01), 0xFE, 0x03, 0xFC, 0x05, 0xFA, 0x07, 0xF8, 0x09, 0xF6, 0x0B, 0xF4, 0x0D, 0xF2, 0x0F, 0xF0));
// Only the last element differs (0xFF vs 0xF0): == must be false.
Assert.False(
Vector128.Create((byte)(0x01), 0xFE, 0x03, 0xFC, 0x05, 0xFA, 0x07, 0xF8, 0x09, 0xF6, 0x0B, 0xF4, 0x0D, 0xF2, 0x0F, 0xFF) ==
Vector128.Create((byte)(0x01), 0xFE, 0x03, 0xFC, 0x05, 0xFA, 0x07, 0xF8, 0x09, 0xF6, 0x0B, 0xF4, 0x0D, 0xF2, 0x0F, 0xF0));
// Only the last element differs (0xFF vs 0xF0): != must be true.
Assert.True(
Vector128.Create((byte)(0x01), 0xFE, 0x03, 0xFC, 0x05, 0xFA, 0x07, 0xF8, 0x09, 0xF6, 0x0B, 0xF4, 0x0D, 0xF2, 0x0F, 0xFF) !=
Vector128.Create((byte)(0x01), 0xFE, 0x03, 0xFC, 0x05, 0xFA, 0x07, 0xF8, 0x09, 0xF6, 0x0B, 0xF4, 0x0D, 0xF2, 0x0F, 0xF0));
// Identical ushort vectors: EqualsAll must be true.
Assert.True(
Vector128.EqualsAll(
Vector128.Create((ushort)(0x0001), 0xFFFE, 0x0003, 0xFFFC, 0x0005, 0xFFFA, 0x0007, 0xFFF8),
Vector128.Create((ushort)(0x0001), 0xFFFE, 0x0003, 0xFFFC, 0x0005, 0xFFFA, 0x0007, 0xFFF8)));
// Only the second element differs (0xFFFF vs 0xFFFE): EqualsAll must be false.
Assert.False(
Vector128.EqualsAll(
Vector128.Create((ushort)(0x0001), 0xFFFF, 0x0003, 0xFFFC, 0x0005, 0xFFFA, 0x0007, 0xFFF8),
Vector128.Create((ushort)(0x0001), 0xFFFE, 0x0003, 0xFFFC, 0x0005, 0xFFFA, 0x0007, 0xFFF8)));
// Element-wise Equals on uint: only the second elements differ, so the expected vector is
// 0 there and 4294967295 (0xFFFF_FFFF) in every other element.
Assert.Equal(
Vector128.Create(4294967295, 0, 4294967295, 4294967295),
Vector128.Equals(
Vector128.Create((uint)(0x0000_0001), 0xFFFF_FFFE, 0x0000_0003, 0xFFFF_FFFC),
Vector128.Create((uint)(0x0000_0001), 0x0000_0001, 0x0000_0003, 0xFFFF_FFFC)));
}
}