Skip to content

JIT: teach VN to fold type comparisons #72136

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jul 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/coreclr/jit/smallhash.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,26 @@ struct HashTableInfo<unsigned>
}
};

#ifdef HOST_64BIT
//------------------------------------------------------------------------
// HashTableInfo<ssize_t>: specialized version of HashTableInfo for ssize_t-
// typed keys.
template <>
struct HashTableInfo<ssize_t>
{
static bool Equals(ssize_t x, ssize_t y)
{
return x == y;
}

static unsigned GetHashCode(ssize_t key)
{
// Return the key itself
return (unsigned)key;
}
};
#endif

//------------------------------------------------------------------------
// HashTableBase: base type for HashTable and SmallHashTable. This class
// provides the vast majority of the implementation. The
Expand Down
128 changes: 126 additions & 2 deletions src/coreclr/jit/valuenum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,7 @@ ValueNumStore::ValueNumStore(Compiler* comp, CompAllocator alloc)
, m_intCnsMap(nullptr)
, m_longCnsMap(nullptr)
, m_handleMap(nullptr)
, m_embeddedToCompileTimeHandleMap(alloc)
, m_floatCnsMap(nullptr)
, m_doubleCnsMap(nullptr)
, m_byrefCnsMap(nullptr)
Expand Down Expand Up @@ -2135,6 +2136,22 @@ ValueNum ValueNumStore::VNForFunc(var_types typ, VNFunc func, ValueNum arg0VN, V

ValueNum resultVN = NoVN;

// Even if the argVNs differ, if both operands runtime types constructed from handles,
// we can sometimes also fold.
//
// The case where the arg VNs are equal is handled by EvalUsingMathIdentity below.
// This is the VN analog of gtFoldTypeCompare.
//
const genTreeOps oper = genTreeOps(func);
if ((arg0VN != arg1VN) && GenTree::StaticOperIs(oper, GT_EQ, GT_NE))
{
resultVN = VNEvalFoldTypeCompare(typ, func, arg0VN, arg1VN);
if (resultVN != NoVN)
{
return resultVN;
}
}

// We canonicalize commutative operations.
// (Perhaps should eventually handle associative/commutative [AC] ops -- but that gets complicated...)
if (VNFuncIsCommutative(func))
Expand Down Expand Up @@ -3651,6 +3668,108 @@ ValueNum ValueNumStore::EvalBitCastForConstantArgs(var_types dstType, ValueNum a
}
}

//------------------------------------------------------------------------
// VNEvalFoldTypeCompare:
//
// Arguments:
// type - The result type
// func - The function
// arg0VN - VN of the first argument
// arg1VN - VN of the second argument
//
// Return Value:
// NoVN if this is not a foldable type compare
// Simplified (perhaps constant) VN if it is foldable.
//
// Notes:
// Value number counterpart to gtFoldTypeCompare
// Doesn't handle all the cases (yet).
//
// (EQ/NE (TypeHandleToRuntimeType x) (TypeHandleToRuntimeType y)) == (EQ/NE x y)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this function is injective, e.g. all function pointers map to typeof(IntPtr).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose we should check with the runtime via compareTypesForEquality.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this function is injective, e.g. all function pointers map to typeof(IntPtr).

Worth to note that all function pointer typeofs currently compare as equal.

//
ValueNum ValueNumStore::VNEvalFoldTypeCompare(var_types type, VNFunc func, ValueNum arg0VN, ValueNum arg1VN)
{
const genTreeOps oper = genTreeOps(func);
assert(GenTree::StaticOperIs(oper, GT_EQ, GT_NE));

VNFuncApp arg0Func;
const bool arg0IsFunc = GetVNFunc(arg0VN, &arg0Func);

if (!arg0IsFunc || (arg0Func.m_func != VNF_TypeHandleToRuntimeType))
{
return NoVN;
}

VNFuncApp arg1Func;
const bool arg1IsFunc = GetVNFunc(arg1VN, &arg1Func);

if (!arg1IsFunc || (arg1Func.m_func != VNF_TypeHandleToRuntimeType))
{
return NoVN;
}

// Only re-express as handle equality when we have known
// class handles and the VM agrees comparing these gives the same
// result as comparing the runtime types.
//
// Note that VN actually tracks the value of embedded handle;
// we need to pass the VM the associated the compile time handles,
// in case they differ (say for prejitting or AOT).
//
ValueNum handle0 = arg0Func.m_args[0];
if (!IsVNHandle(handle0))
{
return NoVN;
}

ValueNum handle1 = arg1Func.m_args[0];
if (!IsVNHandle(handle1))
{
return NoVN;
}

assert(GetHandleFlags(handle0) == GTF_ICON_CLASS_HDL);
assert(GetHandleFlags(handle1) == GTF_ICON_CLASS_HDL);

const ssize_t handleVal0 = ConstantValue<ssize_t>(handle0);
const ssize_t handleVal1 = ConstantValue<ssize_t>(handle1);
ssize_t compileTimeHandle0;
ssize_t compileTimeHandle1;

// These mappings should always exist.
//
const bool found0 = m_embeddedToCompileTimeHandleMap.TryGetValue(handleVal0, &compileTimeHandle0);
const bool found1 = m_embeddedToCompileTimeHandleMap.TryGetValue(handleVal1, &compileTimeHandle1);
assert(found0 && found1);

// We may see null compile time handles for some constructed class handle cases.
// We should fix the construction if possible. But just skip those cases for now.
//
if ((compileTimeHandle0 == 0) || (compileTimeHandle1 == 0))
{
return NoVN;
}

JITDUMP("Asking runtime to compare %p (%s) and %p (%s) for equality\n", dspPtr(compileTimeHandle0),
m_pComp->eeGetClassName(CORINFO_CLASS_HANDLE(compileTimeHandle0)), dspPtr(compileTimeHandle1),
m_pComp->eeGetClassName(CORINFO_CLASS_HANDLE(compileTimeHandle1)));

ValueNum result = NoVN;
const TypeCompareState s =
m_pComp->info.compCompHnd->compareTypesForEquality(CORINFO_CLASS_HANDLE(compileTimeHandle0),
CORINFO_CLASS_HANDLE(compileTimeHandle1));
if (s != TypeCompareState::May)
{
const bool typesAreEqual = (s == TypeCompareState::Must);
const bool operatorIsEQ = (oper == GT_EQ);
const int compareResult = operatorIsEQ ^ typesAreEqual ? 0 : 1;
JITDUMP("Runtime reports comparison is known at jit time: %u\n", compareResult);
result = VNForIntCon(compareResult);
}

return result;
}

//------------------------------------------------------------------------
// VNEvalCanFoldBinaryFunc: Can the given binary function be constant-folded?
//
Expand Down Expand Up @@ -7929,8 +8048,13 @@ void Compiler::fgValueNumberTreeConst(GenTree* tree)
case TYP_BOOL:
if (tree->IsIconHandle())
{
tree->gtVNPair.SetBoth(
vnStore->VNForHandle(ssize_t(tree->AsIntConCommon()->IconValue()), tree->GetIconHandleFlag()));
const ssize_t embeddedHandle = tree->AsIntCon()->IconValue();
tree->gtVNPair.SetBoth(vnStore->VNForHandle(embeddedHandle, tree->GetIconHandleFlag()));
if (tree->GetIconHandleFlag() == GTF_ICON_CLASS_HDL)
{
const ssize_t compileTimeHandle = tree->AsIntCon()->gtCompileTimeHandle;
vnStore->AddToEmbeddedHandleMap(embeddedHandle, compileTimeHandle);
}
}
else if ((typ == TYP_LONG) || (typ == TYP_ULONG))
{
Expand Down
11 changes: 11 additions & 0 deletions src/coreclr/jit/valuenum.h
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,9 @@ class ValueNumStore
// Returns "true" iff "vnf" should be folded by evaluating the func with constant arguments.
bool VNEvalShouldFold(var_types typ, VNFunc func, ValueNum arg0VN, ValueNum arg1VN);

// Value number a type comparison
ValueNum VNEvalFoldTypeCompare(var_types type, VNFunc func, ValueNum arg0VN, ValueNum arg1VN);

// return vnf(v0)
template <typename T>
static T EvalOp(VNFunc vnf, T v0);
Expand Down Expand Up @@ -458,6 +461,11 @@ class ValueNumStore
// that happens to be the same...
ValueNum VNForHandle(ssize_t cnsVal, GenTreeFlags iconFlags);

void AddToEmbeddedHandleMap(ssize_t embeddedHandle, ssize_t compileTimeHandle)
{
m_embeddedToCompileTimeHandleMap.AddOrUpdate(embeddedHandle, compileTimeHandle);
}

// And the single constant for an object reference type.
static ValueNum VNForNull()
{
Expand Down Expand Up @@ -1380,6 +1388,9 @@ class ValueNumStore
return m_handleMap;
}

typedef SmallHashTable<ssize_t, ssize_t> EmbeddedToCompileTimeHandleMap;
EmbeddedToCompileTimeHandleMap m_embeddedToCompileTimeHandleMap;

struct LargePrimitiveKeyFuncsFloat : public JitLargePrimitiveKeyFuncs<float>
{
static bool Equals(float x, float y)
Expand Down
170 changes: 170 additions & 0 deletions src/tests/JIT/opt/ValueNumbering/TypeTestFolding.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;

public enum Enum1 : int { A }
public enum Enum2 : uint { A }

class TypeTestFolding
{
[MethodImpl(MethodImplOptions.NoInlining)]
static void SideEffect() { }

//static bool True0() => typeof(delegate*<int, double>) == typeof(delegate* unmanaged<float, void*, void>);
//static bool True1()
//{
// var t0 = typeof(delegate*<int, double>);
// SideEffect();
// var t1 = typeof(delegate* unmanaged<float, void*, void>);
// return t0 == t1;
//}

static bool True2() => typeof(TypeTestFolding) == typeof(TypeTestFolding);
static bool True3()
{
var t0 = typeof(TypeTestFolding);
SideEffect();
var t1 = typeof(TypeTestFolding);
return t0 == t1;
}

static bool True4() => typeof(ValueTuple<TypeTestFolding>) == typeof(ValueTuple<TypeTestFolding>);
static bool True5()
{
var t0 = typeof(ValueTuple<TypeTestFolding>);
SideEffect();
var t1 = typeof(ValueTuple<TypeTestFolding>);
return t0 == t1;
}

//static bool True6() => typeof(delegate*<int>) == typeof(nint);
//static bool True7()
//{
// var t0 = typeof(delegate*<int>);
// SideEffect();
// var t1 = typeof(nint);
// return t0 == t1;
//}

static bool False0() => typeof(List<object>) == typeof(List<string>);
static bool False1()
{
var t0 = typeof(List<object>);
SideEffect();
var t1 = typeof(List<string>);
return t0 == t1;
}

static bool False2() => typeof(int) == typeof(Enum1);
static bool False3()
{
var t0 = typeof(int);
SideEffect();
var t1 = typeof(Enum1);
return t0 == t1;
}

static bool False4() => typeof(Enum1) == typeof(Enum2);
static bool False5()
{
var t0 = typeof(Enum1);
SideEffect();
var t1 = typeof(Enum2);
return t0 == t1;
}

static bool False6() => typeof(int?) == typeof(uint?);
static bool False7()
{
var t0 = typeof(int?);
SideEffect();
var t1 = typeof(uint?);
return t0 == t1;
}

static bool False8() => typeof(int?) == typeof(Enum1?);
static bool False9()
{
var t0 = typeof(int?);
SideEffect();
var t1 = typeof(Enum1?);
return t0 == t1;
}

static bool False10() => typeof(ValueTuple<TypeTestFolding>) == typeof(ValueTuple<string>);
static bool False11()
{
var t0 = typeof(ValueTuple<TypeTestFolding>);
SideEffect();
var t1 = typeof(ValueTuple<string>);
return t0 == t1;
}

//static bool False12() => typeof(delegate*<int>[]) == typeof(delegate*<float>[]);
//static bool False13()
//{
// var t0 = typeof(delegate*<int>[]);
// SideEffect();
// var t1 = typeof(delegate*<float>[]);
// return t0 == t1;
//}

static bool False14() => typeof(int[]) == typeof(uint[]);
static bool False15()
{
var t0 = typeof(int[]);
SideEffect();
var t1 = typeof(uint[]);
return t0 == t1;
}

//static bool False16() => typeof(delegate*<int>) == typeof(IntPtr);
//static bool False17()
//{
// var t0 = typeof(delegate*<int>);
// SideEffect();
// var t1 = typeof(UIntPtr);
// return t0 == t1;
//}

unsafe static int Main()
{
delegate*<bool>[] trueFuncs = new delegate*<bool>[] { &True2, &True3, &True4, &True5 };
delegate*<bool>[] falseFuncs = new delegate*<bool>[] { &False0, &False1, &False2, &False3, &False4, &False5,
&False6, &False7, &False8, &False9, &False10, &False11,
&False14, &False15 };

int result = 100;
int trueCount = 0;
int falseCount = 0;

foreach (var tf in trueFuncs)
{
if (!tf())
{
Console.WriteLine($"True{trueCount} failed");
result++;
}
trueCount++;
}

foreach (var ff in falseFuncs)
{
if (ff())
{
Console.WriteLine($"False{falseCount} failed");
result++;
}
falseCount++;
}

Console.WriteLine($"Ran {trueCount + falseCount} tests; result {result}");
return result;
}
}



13 changes: 13 additions & 0 deletions src/tests/JIT/opt/ValueNumbering/TypeTestFolding.csproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<OutputType>Exe</OutputType>
</PropertyGroup>
<PropertyGroup>
<DebugType>None</DebugType>
<Optimize>True</Optimize>
<AllowUnsafeBlocks>True</AllowUnsafeBlocks>
</PropertyGroup>
<ItemGroup>
<Compile Include="$(MSBuildProjectName).cs" />
</ItemGroup>
</Project>
Loading