Skip to content

[GlobaISel] Allow expanding of sdiv -> mul by constant #146504

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,10 @@ class CombinerHelper {
/// Query is legal on the target.
bool isLegalOrBeforeLegalizer(const LegalityQuery &Query) const;

/// \return true if \p Query is legal on the target, or if \p Query will
/// perform WidenScalar action on the target.
bool isLegalorHasWidenScalar(const LegalityQuery &Query) const;

/// \return true if the combine is running prior to legalization, or if \p Ty
/// is a legal integer constant type on the target.
bool isConstantLegalOrBeforeLegalizer(const LLT Ty) const;
Expand Down
4 changes: 2 additions & 2 deletions llvm/include/llvm/Target/GlobalISel/Combine.td
Original file line number Diff line number Diff line change
Expand Up @@ -2046,9 +2046,9 @@ def all_combines : GICombineGroup<[integer_reassoc_combines, trivial_combines,
div_rem_to_divrem, funnel_shift_combines, bitreverse_shift, commute_shift,
form_bitfield_extract, constant_fold_binops, constant_fold_fma,
constant_fold_cast_op, fabs_fneg_fold,
intdiv_combines, mulh_combines, redundant_neg_operands,
mulh_combines, redundant_neg_operands,
and_or_disjoint_mask, fma_combines, fold_binop_into_select,
sub_add_reg, select_to_minmax,
intdiv_combines, sub_add_reg, select_to_minmax,
fsub_to_fneg, commute_constant_to_rhs, match_ands, match_ors,
simplify_neg_minmax, combine_concat_vector,
sext_trunc, zext_trunc, prefer_sign_combines, shuffle_combines,
Expand Down
134 changes: 114 additions & 20 deletions llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,11 @@ bool CombinerHelper::isLegalOrBeforeLegalizer(
return isPreLegalize() || isLegal(Query);
}

bool CombinerHelper::isLegalorHasWidenScalar(const LegalityQuery &Query) const {
return isLegal(Query) ||
LI->getAction(Query).Action == LegalizeActions::WidenScalar;
}

bool CombinerHelper::isConstantLegalOrBeforeLegalizer(const LLT Ty) const {
if (!Ty.isVector())
return isLegalOrBeforeLegalizer({TargetOpcode::G_CONSTANT, {Ty}});
Expand Down Expand Up @@ -5510,6 +5515,8 @@ bool CombinerHelper::matchSDivByConst(MachineInstr &MI) const {
Register Dst = MI.getOperand(0).getReg();
Register RHS = MI.getOperand(2).getReg();
LLT DstTy = MRI.getType(Dst);
auto SizeInBits = DstTy.getScalarSizeInBits();
LLT WideTy = DstTy.changeElementSize(SizeInBits * 2);

auto &MF = *MI.getMF();
AttributeList Attr = MF.getFunction().getAttributes();
Expand All @@ -5529,8 +5536,21 @@ bool CombinerHelper::matchSDivByConst(MachineInstr &MI) const {
MRI, RHS, [](const Constant *C) { return C && !C->isNullValue(); });
}

// Don't support the general case for now.
return false;
auto *RHSDef = MRI.getVRegDef(RHS);
if (!isConstantOrConstantVector(*RHSDef, MRI))
return false;

// Don't do this if the types are not going to be legal.
if (LI) {
if (!isLegalOrBeforeLegalizer({TargetOpcode::G_MUL, {DstTy, DstTy}}))
return false;
if (!isLegal({TargetOpcode::G_SMULH, {DstTy}}) &&
!isLegalorHasWidenScalar({TargetOpcode::G_MUL, {WideTy, WideTy}}))
return false;
}

return matchUnaryPredicate(
MRI, RHS, [](const Constant *C) { return C && !C->isNullValue(); });
}

void CombinerHelper::applySDivByConst(MachineInstr &MI) const {
Expand All @@ -5546,21 +5566,22 @@ MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) const {
Register RHS = SDiv.getReg(2);
LLT Ty = MRI.getType(Dst);
LLT ScalarTy = Ty.getScalarType();
const unsigned EltBits = ScalarTy.getScalarSizeInBits();
LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType();
auto &MIB = Builder;

bool UseSRA = false;
SmallVector<Register, 16> Shifts, Factors;
SmallVector<Register, 16> ExactShifts, ExactFactors;

auto *RHSDef = cast<GenericMachineInstr>(getDefIgnoringCopies(RHS, MRI));
bool IsSplat = getIConstantSplatVal(*RHSDef, MRI).has_value();
auto *RHSDefInstr = cast<GenericMachineInstr>(getDefIgnoringCopies(RHS, MRI));
bool IsSplat = getIConstantSplatVal(*RHSDefInstr, MRI).has_value();

auto BuildSDIVPattern = [&](const Constant *C) {
auto BuildExactSDIVPattern = [&](const Constant *C) {
// Don't recompute inverses for each splat element.
if (IsSplat && !Factors.empty()) {
Shifts.push_back(Shifts[0]);
Factors.push_back(Factors[0]);
if (IsSplat && !ExactFactors.empty()) {
ExactShifts.push_back(ExactShifts[0]);
ExactFactors.push_back(ExactFactors[0]);
return true;
}

Expand All @@ -5575,31 +5596,104 @@ MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) const {
// Calculate the multiplicative inverse modulo BW.
// 2^W requires W + 1 bits, so we have to extend and then truncate.
APInt Factor = Divisor.multiplicativeInverse();
Shifts.push_back(MIB.buildConstant(ScalarShiftAmtTy, Shift).getReg(0));
Factors.push_back(MIB.buildConstant(ScalarTy, Factor).getReg(0));
ExactShifts.push_back(MIB.buildConstant(ScalarShiftAmtTy, Shift).getReg(0));
ExactFactors.push_back(MIB.buildConstant(ScalarTy, Factor).getReg(0));
return true;
};

// Collect all magic values from the build vector.
if (MI.getFlag(MachineInstr::MIFlag::IsExact)) {
// Collect all magic values from the build vector.
bool Matched = matchUnaryPredicate(MRI, RHS, BuildExactSDIVPattern);
(void)Matched;
assert(Matched && "Expected unary predicate match to succeed");

Register Shift, Factor;
if (Ty.isVector()) {
Shift = MIB.buildBuildVector(ShiftAmtTy, ExactShifts).getReg(0);
Factor = MIB.buildBuildVector(Ty, ExactFactors).getReg(0);
} else {
Shift = ExactShifts[0];
Factor = ExactFactors[0];
}

Register Res = LHS;

if (UseSRA)
Res = MIB.buildAShr(Ty, Res, Shift, MachineInstr::IsExact).getReg(0);

return MIB.buildMul(Ty, Res, Factor);
}

SmallVector<Register, 16> MagicFactors, Factors, Shifts, ShiftMasks;

auto BuildSDIVPattern = [&](const Constant *C) {
auto *CI = cast<ConstantInt>(C);
const APInt &Divisor = CI->getValue();

SignedDivisionByConstantInfo magics =
SignedDivisionByConstantInfo::get(Divisor);
int NumeratorFactor = 0;
int ShiftMask = -1;

if (Divisor.isOne() || Divisor.isAllOnes()) {
// If d is +1/-1, we just multiply the numerator by +1/-1.
NumeratorFactor = Divisor.getSExtValue();
magics.Magic = 0;
magics.ShiftAmount = 0;
ShiftMask = 0;
} else if (Divisor.isStrictlyPositive() && magics.Magic.isNegative()) {
// If d > 0 and m < 0, add the numerator.
NumeratorFactor = 1;
} else if (Divisor.isNegative() && magics.Magic.isStrictlyPositive()) {
// If d < 0 and m > 0, subtract the numerator.
NumeratorFactor = -1;
}

MagicFactors.push_back(MIB.buildConstant(ScalarTy, magics.Magic).getReg(0));
Factors.push_back(MIB.buildConstant(ScalarTy, NumeratorFactor).getReg(0));
Shifts.push_back(
MIB.buildConstant(ScalarShiftAmtTy, magics.ShiftAmount).getReg(0));
ShiftMasks.push_back(MIB.buildConstant(ScalarTy, ShiftMask).getReg(0));

return true;
};

// Collect the shifts/magic values from each element.
bool Matched = matchUnaryPredicate(MRI, RHS, BuildSDIVPattern);
(void)Matched;
assert(Matched && "Expected unary predicate match to succeed");

Register Shift, Factor;
if (Ty.isVector()) {
Shift = MIB.buildBuildVector(ShiftAmtTy, Shifts).getReg(0);
Register MagicFactor, Factor, Shift, ShiftMask;
auto *RHSDef = getOpcodeDef<GBuildVector>(RHS, MRI);
if (RHSDef) {
MagicFactor = MIB.buildBuildVector(Ty, MagicFactors).getReg(0);
Factor = MIB.buildBuildVector(Ty, Factors).getReg(0);
Shift = MIB.buildBuildVector(ShiftAmtTy, Shifts).getReg(0);
ShiftMask = MIB.buildBuildVector(Ty, ShiftMasks).getReg(0);
} else {
Shift = Shifts[0];
assert(MRI.getType(RHS).isScalar() &&
"Non-build_vector operation should have been a scalar");
MagicFactor = MagicFactors[0];
Factor = Factors[0];
Shift = Shifts[0];
ShiftMask = ShiftMasks[0];
}

Register Res = LHS;
Register Q = LHS;
Q = MIB.buildSMulH(Ty, LHS, MagicFactor).getReg(0);

// (Optionally) Add/subtract the numerator using Factor.
Factor = MIB.buildMul(Ty, LHS, Factor).getReg(0);
Q = MIB.buildAdd(Ty, Q, Factor).getReg(0);

if (UseSRA)
Res = MIB.buildAShr(Ty, Res, Shift, MachineInstr::IsExact).getReg(0);
// Shift right algebraic by shift value.
Q = MIB.buildAShr(Ty, Q, Shift).getReg(0);

return MIB.buildMul(Ty, Res, Factor);
// Extract the sign bit, mask it and add it to the quotient.
auto SignShift = MIB.buildConstant(ShiftAmtTy, EltBits - 1);
auto T = MIB.buildLShr(Ty, Q, SignShift);
T = MIB.buildAnd(Ty, T, ShiftMask);
return MIB.buildAdd(Ty, Q, T);
}

bool CombinerHelper::matchDivByPow2(MachineInstr &MI, bool IsSigned) const {
Expand Down
Loading