Skip to content

Reduce TP for targets with more than 64 Registers Part 2 : Make operations on fixedRegs less expensive #113713

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 88 additions & 24 deletions src/coreclr/jit/lsra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,30 +275,66 @@ SingleTypeRegSet LinearScan::lowSIMDRegs()
#endif
}

template <bool isLow>
void LinearScan::updateNextFixedRef(RegRecord* regRecord, RefPosition* nextRefPosition, RefPosition* nextKill)
{
LsraLocation nextLocation = nextRefPosition == nullptr ? MaxLocation : nextRefPosition->nodeLocation;

RefPosition* kill = nextKill;

#ifdef HAS_MORE_THAN_64_REGISTERS
SingleTypeRegSet regMask = isLow ? genSingleTypeRegMask(regRecord->regNum)
: genSingleTypeRegMask((regNumber)(regRecord->regNum - REG_HIGH_BASE));
#else
SingleTypeRegSet regMask = genSingleTypeRegMask(regRecord->regNum);
#endif
while ((kill != nullptr) && (kill->nodeLocation < nextLocation))
{
if (kill->killedRegisters.IsRegNumInMask(regRecord->regNum))
if (isLow)
{
nextLocation = kill->nodeLocation;
break;
if ((kill->killedRegisters.getLow() & regMask) != RBM_NONE)
{
nextLocation = kill->nodeLocation;
break;
}
}

#ifdef HAS_MORE_THAN_64_REGISTERS
else
{
if ((kill->killedRegisters.getHigh() & regMask) != RBM_NONE)
{
nextLocation = kill->nodeLocation;
break;
}
}
#endif
kill = kill->nextRefPosition;
}

if (nextLocation == MaxLocation)
if (isLow)
{
fixedRegs.RemoveRegNumFromMask(regRecord->regNum);
if (nextLocation == MaxLocation)
{
fixedRegsLow &= ~regMask;
}
else
{
fixedRegsLow |= regMask;
}
}
#ifdef HAS_MORE_THAN_64_REGISTERS
else
{
fixedRegs.AddRegNumInMask(regRecord->regNum);
if (nextLocation == MaxLocation)
{
fixedRegsHigh &= ~regMask;
}
else
{
fixedRegsHigh |= regMask;
}
}
#endif

nextFixedRef[regRecord->regNum] = nextLocation;
}
Expand Down Expand Up @@ -3857,10 +3893,10 @@ void LinearScan::processKills(RefPosition* killRefPosition)

regMaskTP killedRegs = killRefPosition->getKilledRegisters();

freeKilledRegs(killRefPosition, killedRegs.getLow(), nextKill, REG_LOW_BASE);
freeKilledRegs<true>(killRefPosition, killedRegs.getLow(), nextKill, REG_LOW_BASE);

#ifdef HAS_MORE_THAN_64_REGISTERS
freeKilledRegs(killRefPosition, killedRegs.getHigh(), nextKill, REG_HIGH_BASE);
freeKilledRegs<false>(killRefPosition, killedRegs.getHigh(), nextKill, REG_HIGH_BASE);
#endif

regsBusyUntilKill &= ~killRefPosition->getKilledRegisters();
Expand All @@ -3877,6 +3913,7 @@ void LinearScan::processKills(RefPosition* killRefPosition)
// nextKill - The RefPosition for next kill
// regBase - `0` or `64` based on the `killedRegs` being processed
//
template <bool isLow>
void LinearScan::freeKilledRegs(RefPosition* killRefPosition,
SingleTypeRegSet killedRegs,
RefPosition* nextKill,
Expand All @@ -3899,7 +3936,7 @@ void LinearScan::freeKilledRegs(RefPosition* killRefPosition,
RefPosition* regNextRefPos = regRecord->recentRefPosition == nullptr
? regRecord->firstRefPosition
: regRecord->recentRefPosition->nextRefPosition;
updateNextFixedRef(regRecord, regNextRefPos, nextKill);
updateNextFixedRef<isLow>(regRecord, regNextRefPos, nextKill);
}
}

Expand Down Expand Up @@ -4862,7 +4899,7 @@ void LinearScan::allocateRegistersMinimal()
{
RegRecord* physRegRecord = getRegisterRecord(reg);
physRegRecord->recentRefPosition = nullptr;
updateNextFixedRef(physRegRecord, physRegRecord->firstRefPosition, killHead);
updateNextFixedRefDispatch(physRegRecord, physRegRecord->firstRefPosition, killHead);
assert(physRegRecord->assignedInterval == nullptr);
}

Expand Down Expand Up @@ -5080,8 +5117,7 @@ void LinearScan::allocateRegistersMinimal()
{
RegRecord* regRecord = currentRefPosition.getReg();
Interval* assignedInterval = regRecord->assignedInterval;

updateNextFixedRef(regRecord, currentRefPosition.nextRefPosition, nextKill);
updateNextFixedRefDispatch(regRecord, currentRefPosition.nextRefPosition, nextKill);

// This is a FixedReg. Disassociate any inactive constant interval from this register.
if (assignedInterval != nullptr && !assignedInterval->isActive && assignedInterval->isConstant)
Expand Down Expand Up @@ -5526,7 +5562,7 @@ void LinearScan::allocateRegisters()
{
RegRecord* physRegRecord = getRegisterRecord(reg);
physRegRecord->recentRefPosition = nullptr;
updateNextFixedRef(physRegRecord, physRegRecord->firstRefPosition, killHead);
updateNextFixedRefDispatch(physRegRecord, physRegRecord->firstRefPosition, killHead);

// Is this an incoming arg register? (Note that we don't, currently, consider reassigning
// an incoming arg register as having spill cost.)
Expand Down Expand Up @@ -5794,8 +5830,7 @@ void LinearScan::allocateRegisters()
{
RegRecord* regRecord = currentRefPosition.getReg();
Interval* assignedInterval = regRecord->assignedInterval;

updateNextFixedRef(regRecord, currentRefPosition.nextRefPosition, nextKill);
updateNextFixedRefDispatch(regRecord, currentRefPosition.nextRefPosition, nextKill);

// This is a FixedReg. Disassociate any inactive constant interval from this register.
if (assignedInterval != nullptr && !assignedInterval->isActive && assignedInterval->isConstant)
Expand Down Expand Up @@ -13593,14 +13628,29 @@ SingleTypeRegSet LinearScan::RegisterSelection::select(Interval*
// Also eliminate as busy any register with a conflicting fixed reference at this or
// the next location.
// Note that this will eliminate the fixedReg, if any, but we'll add it back below.
SingleTypeRegSet checkConflictMask = candidates & linearScan->fixedRegs.GetRegSetForType(regType);
SingleTypeRegSet checkConflictMask = candidates;
int regBase = REG_LOW_BASE;
#ifdef HAS_MORE_THAN_64_REGISTERS
if (!varTypeIsMask(regType))
{
checkConflictMask &= linearScan->fixedRegsLow;
}
else
{
regBase = REG_HIGH_BASE;
checkConflictMask &= linearScan->fixedRegsHigh;
}
#else
checkConflictMask &= linearScan->fixedRegsLow;
#endif
while (checkConflictMask != RBM_NONE)
{
regNumber checkConflictReg = genFirstRegNumFromMask(checkConflictMask, regType);
SingleTypeRegSet checkConflictBit = genSingleTypeRegMask(checkConflictReg);
regNumber checkConflictRegSingle = (regNumber)BitOperations::BitScanForward(checkConflictMask);
SingleTypeRegSet checkConflictBit = genSingleTypeRegMask(checkConflictRegSingle);
checkConflictMask ^= checkConflictBit;

LsraLocation checkConflictLocation = linearScan->nextFixedRef[checkConflictReg];
LsraLocation checkConflictLocation =
linearScan->nextFixedRef[(regNumber)(checkConflictRegSingle + regBase)];

if ((checkConflictLocation == refPosition->nodeLocation) ||
(refPosition->delayRegFree && (checkConflictLocation == (refPosition->nodeLocation + 1))))
Expand Down Expand Up @@ -13912,14 +13962,28 @@ SingleTypeRegSet LinearScan::RegisterSelection::selectMinimal(
// Also eliminate as busy any register with a conflicting fixed reference at this or
// the next location.
// Note that this will eliminate the fixedReg, if any, but we'll add it back below.
SingleTypeRegSet checkConflictMask = candidates & linearScan->fixedRegs.GetRegSetForType(regType);
SingleTypeRegSet checkConflictMask = candidates;
int regBase = REG_LOW_BASE;
#ifdef HAS_MORE_THAN_64_REGISTERS
if (!varTypeIsMask(regType))
{
checkConflictMask &= linearScan->fixedRegsLow;
}
else
{
regBase = REG_HIGH_BASE;
checkConflictMask &= linearScan->fixedRegsHigh;
}
#else
checkConflictMask &= linearScan->fixedRegsLow;
#endif
while (checkConflictMask != RBM_NONE)
{
regNumber checkConflictReg = genFirstRegNumFromMask(checkConflictMask, regType);
SingleTypeRegSet checkConflictBit = genSingleTypeRegMask(checkConflictReg);
regNumber checkConflictRegSingle = (regNumber)BitOperations::BitScanForward(checkConflictMask);
SingleTypeRegSet checkConflictBit = genSingleTypeRegMask(checkConflictRegSingle);
checkConflictMask ^= checkConflictBit;

LsraLocation checkConflictLocation = linearScan->nextFixedRef[checkConflictReg];
LsraLocation checkConflictLocation = linearScan->nextFixedRef[(regNumber)(checkConflictRegSingle + regBase)];

if ((checkConflictLocation == refPosition->nodeLocation) ||
(refPosition->delayRegFree && (checkConflictLocation == (refPosition->nodeLocation + 1))))
Expand Down
26 changes: 23 additions & 3 deletions src/coreclr/jit/lsra.h
Original file line number Diff line number Diff line change
Expand Up @@ -1195,7 +1195,8 @@ class LinearScan : public LinearScanInterface
void setIntervalAsSplit(Interval* interval);
void spillInterval(Interval* interval, RefPosition* fromRefPosition DEBUGARG(RefPosition* toRefPosition));

void processKills(RefPosition* killRefPosition);
void processKills(RefPosition* killRefPosition);
template <bool isLow>
FORCEINLINE void freeKilledRegs(RefPosition* killRefPosition,
SingleTypeRegSet killedRegs,
RefPosition* nextKill,
Expand Down Expand Up @@ -1822,9 +1823,28 @@ class LinearScan : public LinearScanInterface
}
SingleTypeRegSet getMatchingConstants(SingleTypeRegSet mask, Interval* currentInterval, RefPosition* refPosition);

regMaskTP fixedRegs;
SingleTypeRegSet fixedRegsLow;
#ifdef HAS_MORE_THAN_64_REGISTERS
SingleTypeRegSet fixedRegsHigh;
#endif
LsraLocation nextFixedRef[REG_COUNT];
void updateNextFixedRef(RegRecord* regRecord, RefPosition* nextRefPosition, RefPosition* nextKill);
template <bool isLow>
void updateNextFixedRef(RegRecord* regRecord, RefPosition* nextRefPosition, RefPosition* nextKill);
//------------------------------------------------------------------------
// updateNextFixedRefDispatch: Forward to the updateNextFixedRef<isLow>
//    instantiation that matches the register number of `regRecord`.
//
// Arguments:
//    regRecord       - register record; its regNum selects the instantiation
//    nextRefPosition - forwarded unchanged to updateNextFixedRef
//    nextKill        - forwarded unchanged to updateNextFixedRef
//
// Notes:
//    When HAS_MORE_THAN_64_REGISTERS is defined, register numbers below
//    REG_HIGH_BASE take the <true> (low) instantiation and all others take
//    the <false> (high) one. Otherwise only the <true> instantiation is used.
//    The boundary is spelled REG_HIGH_BASE rather than a literal so that it
//    stays in step with the `regNum - REG_HIGH_BASE` rebasing performed by
//    updateNextFixedRef<false>.
//
void updateNextFixedRefDispatch(RegRecord* regRecord, RefPosition* nextRefPosition, RefPosition* nextKill)
{
#ifdef HAS_MORE_THAN_64_REGISTERS
    if (regRecord->regNum < REG_HIGH_BASE)
    {
        updateNextFixedRef<true>(regRecord, nextRefPosition, nextKill);
    }
    else
    {
        updateNextFixedRef<false>(regRecord, nextRefPosition, nextKill);
    }
#else
    updateNextFixedRef<true>(regRecord, nextRefPosition, nextKill);
#endif
}
LsraLocation getNextFixedRef(regNumber regNum, var_types regType)
{
LsraLocation loc = nextFixedRef[regNum];
Expand Down
Loading