Skip to content

JIT: Improve local assertion prop throughput #94597

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Nov 13, 2023
Merged
124 changes: 83 additions & 41 deletions src/coreclr/jit/assertionprop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1617,7 +1617,7 @@ bool Compiler::optAssertionVnInvolvesNan(AssertionDsc* assertion)
*
* If it is already in the assertion table return the assertionIndex that
* we use to refer to this element.
* Otherwise add it to the assertion table ad return the assertionIndex that
* Otherwise add it to the assertion table and return the assertionIndex that
* we use to refer to this element.
* If we need to add to the table and the table is full return the value zero
*/
Expand All @@ -1633,13 +1633,57 @@ AssertionIndex Compiler::optAddAssertion(AssertionDsc* newAssertion)
return NO_ASSERTION_INDEX;
}

// Check if exists already, so we can skip adding new one. Search backwards.
for (AssertionIndex index = optAssertionCount; index >= 1; index--)
// See if we already have this assertion in the table.
//
// For local assertion prop we can speed things up by checking the dep vectors.
//
if (optLocalAssertionProp)
{
AssertionDsc* curAssertion = optGetAssertion(index);
if (curAssertion->Equals(newAssertion, !optLocalAssertionProp))
assert(newAssertion->op1.kind == O1K_LCLVAR);

unsigned lclNum = newAssertion->op1.lcl.lclNum;
BitVecOps::Iter iter(apTraits, GetAssertionDep(lclNum));
unsigned bvIndex = 0;
while (iter.NextElem(&bvIndex))
{
return index;
AssertionIndex const index = GetAssertionIndex(bvIndex);
AssertionDsc* const curAssertion = optGetAssertion(index);

if (curAssertion->Equals(newAssertion, !optLocalAssertionProp))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (curAssertion->Equals(newAssertion, !optLocalAssertionProp))
if (curAssertion->Equals(newAssertion, /* vnBased */ false))

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Also in the other cases)

{
return index;
}
}

if (newAssertion->op2.kind == O2K_LCLVAR_COPY)
{
lclNum = newAssertion->op2.lcl.lclNum;
BitVecOps::Iter iter(apTraits, GetAssertionDep(lclNum));
unsigned bvIndex = 0;
while (iter.NextElem(&bvIndex))
{
AssertionIndex const index = GetAssertionIndex(bvIndex);
AssertionDsc* const curAssertion = optGetAssertion(index);

if (curAssertion->Equals(newAssertion, !optLocalAssertionProp))
{
return index;
}
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this necessary? Shouldn't the previous case have found it if there is an equal assertion? Or will we only keep an assertion like v1 = v2 in one of the bit vectors?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, the assertion is kept in the dep vectors of both locals (see just below), so yes, that second loop is not needed.

}
else
{
// For global prop we search the entire table.
//
// Check if exists already, so we can skip adding new one. Search backwards.
for (AssertionIndex index = optAssertionCount; index >= 1; index--)
{
AssertionDsc* curAssertion = optGetAssertion(index);
if (curAssertion->Equals(newAssertion, !optLocalAssertionProp))
{
return index;
}
}
}

Expand Down Expand Up @@ -2490,15 +2534,19 @@ AssertionIndex Compiler::optFindComplementary(AssertionIndex assertIndex)
//
AssertionIndex Compiler::optAssertionIsSubrange(GenTree* tree, IntegralRange range, ASSERT_VALARG_TP assertions)
{
if ((!optLocalAssertionProp && BitVecOps::IsEmpty(apTraits, assertions)) || !optCanPropSubRange)
if (!optCanPropSubRange)
{
// (don't early out in checked, verify above)
return NO_ASSERTION_INDEX;
}

for (AssertionIndex index = 1; index <= optAssertionCount; index++)
BitVecOps::Iter iter(apTraits, assertions);
unsigned bvIndex = 0;
while (iter.NextElem(&bvIndex))
{
AssertionDsc* curAssertion = optGetAssertion(index);
if (BitVecOps::IsMember(apTraits, assertions, index - 1) && curAssertion->CanPropSubRange())
AssertionIndex const index = GetAssertionIndex(bvIndex);
AssertionDsc* const curAssertion = optGetAssertion(index);
if (curAssertion->CanPropSubRange())
{
// For local assertion prop use comparison on locals, and use comparison on vns for global prop.
bool isEqual = optLocalAssertionProp
Expand Down Expand Up @@ -2530,18 +2578,12 @@ AssertionIndex Compiler::optAssertionIsSubrange(GenTree* tree, IntegralRange ran
*/
AssertionIndex Compiler::optAssertionIsSubtype(GenTree* tree, GenTree* methodTableArg, ASSERT_VALARG_TP assertions)
{
if (BitVecOps::IsEmpty(apTraits, assertions))
{
return NO_ASSERTION_INDEX;
}
for (AssertionIndex index = 1; index <= optAssertionCount; index++)
BitVecOps::Iter iter(apTraits, assertions);
unsigned bvIndex = 0;
while (iter.NextElem(&bvIndex))
{
if (!BitVecOps::IsMember(apTraits, assertions, index - 1))
{
continue;
}

AssertionDsc* curAssertion = optGetAssertion(index);
AssertionIndex const index = GetAssertionIndex(bvIndex);
AssertionDsc* curAssertion = optGetAssertion(index);
if (curAssertion->assertionKind != OAK_EQUAL ||
(curAssertion->op1.kind != O1K_SUBTYPE && curAssertion->op1.kind != O1K_EXACT_TYPE))
{
Expand Down Expand Up @@ -3709,31 +3751,31 @@ AssertionIndex Compiler::optLocalAssertionIsEqualOrNotEqual(
{
noway_assert((op1Kind == O1K_LCLVAR) || (op1Kind == O1K_EXACT_TYPE) || (op1Kind == O1K_SUBTYPE));
noway_assert((op2Kind == O2K_CONST_INT) || (op2Kind == O2K_IND_CNS_INT) || (op2Kind == O2K_ZEROOBJ));
if (BitVecOps::IsEmpty(apTraits, assertions))
{
return NO_ASSERTION_INDEX;
}

for (AssertionIndex index = 1; index <= optAssertionCount; ++index)
assert(optLocalAssertionProp);
ASSERT_TP apDependent = BitVecOps::Intersection(apTraits, GetAssertionDep(lclNum), assertions);

BitVecOps::Iter iter(apTraits, apDependent);
unsigned bvIndex = 0;
while (iter.NextElem(&bvIndex))
{
AssertionDsc* curAssertion = optGetAssertion(index);
if (BitVecOps::IsMember(apTraits, assertions, index - 1))
AssertionIndex const index = GetAssertionIndex(bvIndex);
AssertionDsc* curAssertion = optGetAssertion(index);

if ((curAssertion->assertionKind != OAK_EQUAL) && (curAssertion->assertionKind != OAK_NOT_EQUAL))
{
if ((curAssertion->assertionKind != OAK_EQUAL) && (curAssertion->assertionKind != OAK_NOT_EQUAL))
{
continue;
}
continue;
}

if ((curAssertion->op1.kind == op1Kind) && (curAssertion->op1.lcl.lclNum == lclNum) &&
(curAssertion->op2.kind == op2Kind))
{
bool constantIsEqual = (curAssertion->op2.u1.iconVal == cnsVal);
bool assertionIsEqual = (curAssertion->assertionKind == OAK_EQUAL);
if ((curAssertion->op1.kind == op1Kind) && (curAssertion->op1.lcl.lclNum == lclNum) &&
(curAssertion->op2.kind == op2Kind))
{
bool constantIsEqual = (curAssertion->op2.u1.iconVal == cnsVal);
bool assertionIsEqual = (curAssertion->assertionKind == OAK_EQUAL);

if (constantIsEqual || assertionIsEqual)
{
return index;
}
if (constantIsEqual || assertionIsEqual)
{
return index;
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/coreclr/jit/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -4799,7 +4799,7 @@ class Compiler
FoldResult fgFoldConditional(BasicBlock* block);

PhaseStatus fgMorphBlocks();
void fgMorphBlock(BasicBlock* block);
void fgMorphBlock(BasicBlock* block, unsigned highestReachablePostorder = 0);
void fgMorphStmts(BasicBlock* block);

void fgMergeBlockReturn(BasicBlock* block);
Expand Down
18 changes: 15 additions & 3 deletions src/coreclr/jit/morph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13772,8 +13772,10 @@ void Compiler::fgMorphStmts(BasicBlock* block)
//
// Arguments:
// block - block in question
// highestReachablePostorder - maximum postorder number for a
// reachable block.
//
void Compiler::fgMorphBlock(BasicBlock* block)
void Compiler::fgMorphBlock(BasicBlock* block, unsigned highestReachablePostorder)
{
JITDUMP("\nMorphing " FMT_BB "\n", block->bbNum);

Expand All @@ -13788,6 +13790,8 @@ void Compiler::fgMorphBlock(BasicBlock* block)
}
else
{
assert(highestReachablePostorder > 0);

// Determine if this block can leverage assertions from its pred blocks.
//
// Some blocks are ineligible.
Expand Down Expand Up @@ -13818,6 +13822,14 @@ void Compiler::fgMorphBlock(BasicBlock* block)
break;
}

if (pred->bbPostorderNum > highestReachablePostorder)
{
// This pred was not reachable from the original DFS root set, so
// we can ignore its assertion information.
//
continue;
}

// Yes, pred assertions are available. If this is the first pred, copy.
// If this is a subsequent pred, intersect.
//
Expand Down Expand Up @@ -13939,7 +13951,7 @@ PhaseStatus Compiler::fgMorphBlocks()
// We are optimizing. Process in RPO.
//
fgRenumberBlocks();
fgDfsReversePostorder();
const unsigned highestReachablePostorder = fgDfsReversePostorder();

// Disallow general creation of new blocks or edges as it
// would invalidate RPO.
Expand Down Expand Up @@ -13971,7 +13983,7 @@ PhaseStatus Compiler::fgMorphBlocks()
for (unsigned i = 1; i <= bbNumMax; i++)
{
BasicBlock* const block = fgBBReversePostorder[i];
fgMorphBlock(block);
fgMorphBlock(block, highestReachablePostorder);
}
assert(bbNumMax == fgBBNumMax);

Expand Down