Skip to content

JIT: Generalize strategy for finding addrecs #99048

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/coreclr/jit/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9717,6 +9717,7 @@ JITDBGAPI void __cdecl cScev(Compiler* comp, Scev* scev)
else
{
scev->Dump(comp);
printf("\n");
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/coreclr/jit/inductionvariableopts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ PhaseStatus Compiler::optInductionVariables()
initStmt->GetNextStmt());
}

JITDUMP(" Replacing in the loop; %d statements with appearences\n", ivUses.Height());
JITDUMP(" Replacing in the loop; %d statements with appearances\n", ivUses.Height());
for (int i = 0; i < ivUses.Height(); i++)
{
Statement* stmt = ivUses.Bottom(i);
Expand Down
206 changes: 172 additions & 34 deletions src/coreclr/jit/scev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,7 @@
// straightforward translation from JIT IR into the SCEV IR. Creating the add
// recurrences requires paying attention to the structure of PHIs, and
// disambiguating the values coming from outside the loop and the values coming
// from the backedges. Currently only simplistic add recurrences that do not
// require recursive analysis are supported. These simplistic add recurrences
// are always on the form i = i + k.
// from the backedges.
//

#include "jitpch.h"
Expand Down Expand Up @@ -208,7 +206,7 @@ void Scev::Dump(Compiler* comp)
// ResetForLoop.
//
ScalarEvolutionContext::ScalarEvolutionContext(Compiler* comp)
: m_comp(comp), m_cache(comp->getAllocator(CMK_LoopIVOpts))
: m_comp(comp), m_cache(comp->getAllocator(CMK_LoopIVOpts)), m_ephemeralCache(comp->getAllocator(CMK_LoopIVOpts))
{
}

Expand Down Expand Up @@ -471,34 +469,34 @@ Scev* ScalarEvolutionContext::AnalyzeNew(BasicBlock* block, GenTree* tree, int d

assert(ssaDsc->GetBlock() != nullptr);

// We currently do not handle complicated addrecs. We can do this
// by inserting a symbolic node in the cache and analyzing while it
// is part of the cache. It would allow us to model
//
// int i = 0;
// while (i < n)
// {
// int j = i + 1;
// ...
// i = j;
// }
// => <L, 0, 1>
//
// and chains of recurrences, such as
//
// int i = 0;
// int j = 0;
// while (i < n)
// {
// j++;
// i += j;
// }
// => <L, 0, <L, 1, 1>>
//
// The main issue is that it requires cache invalidation afterwards
// and turning the recursive result into an addrec.
//
return CreateSimpleAddRec(store, enterScev, ssaDsc->GetBlock(), ssaDsc->GetDefNode()->Data());
Scev* simpleAddRec = CreateSimpleAddRec(store, enterScev, ssaDsc->GetBlock(), ssaDsc->GetDefNode()->Data());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be useful to retain the prior comment here and augment so there's as a worked-out example of how these symbolic addrecs end up getting resolved.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense... let me add it as part of a future PR.

if (simpleAddRec != nullptr)
{
return simpleAddRec;
}

ScevConstant* symbolicAddRec = NewConstant(data->TypeGet(), 0xdeadbeef);
m_ephemeralCache.Emplace(store, symbolicAddRec);

Scev* result;
if (m_usingEphemeralCache)
{
result = Analyze(ssaDsc->GetBlock(), ssaDsc->GetDefNode()->Data(), depth + 1);
}
else
{
m_usingEphemeralCache = true;
result = Analyze(ssaDsc->GetBlock(), ssaDsc->GetDefNode()->Data(), depth + 1);
m_usingEphemeralCache = false;
m_ephemeralCache.RemoveAll();
}

if (result == nullptr)
{
return nullptr;
}

return MakeAddRecFromRecursiveScev(enterScev, result, symbolicAddRec);
}
case GT_CAST:
{
Expand Down Expand Up @@ -611,6 +609,138 @@ Scev* ScalarEvolutionContext::CreateSimpleAddRec(GenTreeLclVarCommon* headerStor
return NewAddRec(enterScev, stepScev);
}

//------------------------------------------------------------------------
// ExtractAddOperands: Extract all operands of potentially nested add
// operations.
//
// Parameters:
// binop - The binop representing an add
// operands - Array stack to add the operands to
//
void ScalarEvolutionContext::ExtractAddOperands(ScevBinop* binop, ArrayStack<Scev*>& operands)
{
assert(binop->OperIs(ScevOper::Add));

if (binop->Op1->OperIs(ScevOper::Add))
{
ExtractAddOperands(static_cast<ScevBinop*>(binop->Op1), operands);
}
else
{
operands.Push(binop->Op1);
}

if (binop->Op2->OperIs(ScevOper::Add))
{
ExtractAddOperands(static_cast<ScevBinop*>(binop->Op2), operands);
}
else
{
operands.Push(binop->Op2);
}
}

//------------------------------------------------------------------------
// MakeAddRecFromRecursiveScev: Given a recursive SCEV and a symbolic SCEV
// whose appearances represent an occurrence of the full SCEV, create a
// non-recursive add-rec from it.
//
// Parameters:
// startScev - The start value of the addrec
// scev - The scev
// recursiveScev - A symbolic node whose appearance represents the value of "scev"
//
// Returns:
// A non-recursive addrec
//
Scev* ScalarEvolutionContext::MakeAddRecFromRecursiveScev(Scev* startScev, Scev* scev, Scev* recursiveScev)
{
if (!scev->OperIs(ScevOper::Add))
{
return nullptr;
}

ArrayStack<Scev*> addOperands(m_comp->getAllocator(CMK_LoopIVOpts));
ExtractAddOperands(static_cast<ScevBinop*>(scev), addOperands);

assert(addOperands.Height() >= 2);

int numAppearances = 0;
for (int i = 0; i < addOperands.Height(); i++)
{
Scev* addOperand = addOperands.Bottom(i);
if (addOperand == recursiveScev)
{
numAppearances++;
}
else
{
ScevVisit result = addOperand->Visit([=](Scev* node) {
if (node == recursiveScev)
{
return ScevVisit::Abort;
}

return ScevVisit::Continue;
});

if (result == ScevVisit::Abort)
{
// We do not handle nested occurrences. Some of these may be representable, some won't.
return nullptr;
}
}
}

if (numAppearances == 0)
{
// TODO-CQ: We currently cannot handle cases like
// i = arr.Length;
// j = i - 1;
// i = j;
// while (true) { ...; j = i - 1; i = j; }
//
// These cases can arise from loop structures like "for (int i =
// arr.Length; --i >= 0;)" when Roslyn emits a "sub; dup; stloc"
// sequence, and local prop + loop inversion converts the duplicated
// local into a fully fledged IV.
// In this case we see that i = <L, [i from outside loop], -1>, but for
// j we will see <L, [i from outside loop], -1> + (-1) in this function
// as the value coming around the backedge, and we cannot reconcile
// this.
//
return nullptr;
}

if (numAppearances > 1)
{
// Multiple occurrences -- cannot be represented as an addrec
// (corresponds to a geometric progression).
return nullptr;
}

Scev* step = nullptr;
for (int i = 0; i < addOperands.Height(); i++)
{
Scev* addOperand = addOperands.Bottom(i);
if (addOperand == recursiveScev)
{
continue;
}

if (step == nullptr)
{
step = addOperand;
}
else
{
step = NewBinop(ScevOper::Add, step, addOperand);
}
}

return NewAddRec(startScev, step);
}

//------------------------------------------------------------------------
// Analyze: Analyze the specified tree in the specified block.
//
Expand Down Expand Up @@ -653,15 +783,23 @@ const int SCALAR_EVOLUTION_ANALYSIS_MAX_DEPTH = 64;
Scev* ScalarEvolutionContext::Analyze(BasicBlock* block, GenTree* tree, int depth)
{
Scev* result;
if (!m_cache.Lookup(tree, &result))
if (!m_cache.Lookup(tree, &result) && (!m_usingEphemeralCache || !m_ephemeralCache.Lookup(tree, &result)))
{
if (depth >= SCALAR_EVOLUTION_ANALYSIS_MAX_DEPTH)
{
return nullptr;
}

result = AnalyzeNew(block, tree, depth);
m_cache.Set(tree, result);

if (m_usingEphemeralCache)
{
m_ephemeralCache.Set(tree, result, ScalarEvolutionMap::Overwrite);
}
else
{
m_cache.Set(tree, result);
}
}

return result;
Expand Down
67 changes: 67 additions & 0 deletions src/coreclr/jit/scev.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ static bool ScevOperIs(ScevOper oper, ScevOper operFirst, Args... operTail)
return oper == operFirst || ScevOperIs(oper, operTail...);
}

enum class ScevVisit
{
Abort,
Continue,
};

struct Scev
{
const ScevOper Oper;
Expand All @@ -62,6 +68,8 @@ struct Scev
#ifdef DEBUG
void Dump(Compiler* comp);
#endif
template <typename TVisitor>
ScevVisit Visit(TVisitor visitor);
};

struct ScevConstant : Scev
Expand Down Expand Up @@ -119,6 +127,57 @@ struct ScevAddRec : Scev
INDEBUG(FlowGraphNaturalLoop* const Loop);
};

//------------------------------------------------------------------------
// Scev::Visit: Recursively visit all SCEV nodes in the SCEV tree.
//
// Parameters:
// visitor - Callback with signature Scev* -> ScevVisit.
//
// Returns:
// ScevVisit::Abort if "visitor" aborted, otherwise ScevVisit::Continue.
//
// Remarks:
// The visit is done in preorder.
//
template <typename TVisitor>
ScevVisit Scev::Visit(TVisitor visitor)
{
if (visitor(this) == ScevVisit::Abort)
return ScevVisit::Abort;

switch (Oper)
{
case ScevOper::Constant:
case ScevOper::Local:
break;
case ScevOper::ZeroExtend:
case ScevOper::SignExtend:
return static_cast<ScevUnop*>(this)->Op1->Visit(visitor);
case ScevOper::Add:
case ScevOper::Mul:
case ScevOper::Lsh:
{
ScevBinop* binop = static_cast<ScevBinop*>(this);
if (binop->Op1->Visit(visitor) == ScevVisit::Abort)
return ScevVisit::Abort;

return binop->Op2->Visit(visitor);
}
case ScevOper::AddRec:
{
ScevAddRec* addrec = static_cast<ScevAddRec*>(this);
if (addrec->Start->Visit(visitor) == ScevVisit::Abort)
return ScevVisit::Abort;

return addrec->Step->Visit(visitor);
}
default:
unreached();
}

return ScevVisit::Continue;
}

typedef JitHashTable<GenTree*, JitPtrKeyFuncs<GenTree>, Scev*> ScalarEvolutionMap;

// Scalar evolution is analyzed in the context of a single loop, and are
Expand All @@ -130,14 +189,22 @@ class ScalarEvolutionContext
FlowGraphNaturalLoop* m_loop = nullptr;
ScalarEvolutionMap m_cache;

// During analysis of PHIs we insert a symbolic node representing the
// "recurrence"; we use this cache to be able to invalidate things that end
// up depending on the symbolic node quickly.
ScalarEvolutionMap m_ephemeralCache;
bool m_usingEphemeralCache = false;

Scev* Analyze(BasicBlock* block, GenTree* tree, int depth);
Scev* AnalyzeNew(BasicBlock* block, GenTree* tree, int depth);
Scev* CreateSimpleAddRec(GenTreeLclVarCommon* headerStore,
ScevLocal* start,
BasicBlock* stepDefBlock,
GenTree* stepDefData);
Scev* MakeAddRecFromRecursiveScev(Scev* start, Scev* scev, Scev* recursiveScev);
Scev* CreateSimpleInvariantScev(GenTree* tree);
Scev* CreateScevForConstant(GenTreeIntConCommon* tree);
void ExtractAddOperands(ScevBinop* add, ArrayStack<Scev*>& operands);

public:
ScalarEvolutionContext(Compiler* comp);
Expand Down