Skip to content

[NVPTX] Consistently check fast-math flags when lowering div #136890

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion llvm/lib/Target/NVPTX/NVPTX.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,14 @@ enum PrmtMode {
RC16,
};
}
}

enum class DivPrecisionLevel : unsigned {
Approx = 0,
Full = 1,
IEEE754 = 2,
};

} // namespace NVPTX
void initializeNVPTXDAGToDAGISelLegacyPass(PassRegistry &);
} // namespace llvm

Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ bool NVPTXDAGToDAGISel::runOnMachineFunction(MachineFunction &MF) {
return SelectionDAGISel::runOnMachineFunction(MF);
}

int NVPTXDAGToDAGISel::getDivF32Level() const {
return Subtarget->getTargetLowering()->getDivF32Level();
NVPTX::DivPrecisionLevel
NVPTXDAGToDAGISel::getDivF32Level(const SDNode *N) const {
return Subtarget->getTargetLowering()->getDivF32Level(*MF, N);
}

bool NVPTXDAGToDAGISel::usePrecSqrtF32() const {
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
// If true, generate mul.wide from sext and mul
bool doMulWide;

int getDivF32Level() const;
NVPTX::DivPrecisionLevel getDivF32Level(const SDNode *N) const;
bool usePrecSqrtF32() const;
bool useF32FTZ() const;
bool allowFMA() const;
Expand Down
36 changes: 24 additions & 12 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,16 @@ static cl::opt<unsigned> FMAContractLevelOpt(
" 1: do it 2: do it aggressively"),
cl::init(2));

static cl::opt<int> UsePrecDivF32(
static cl::opt<NVPTX::DivPrecisionLevel> UsePrecDivF32(
"nvptx-prec-divf32", cl::Hidden,
cl::desc("NVPTX Specifies: 0 use div.approx, 1 use div.full, 2 use"
" IEEE Compliant F32 div.rnd if available."),
cl::init(2));
cl::values(clEnumValN(NVPTX::DivPrecisionLevel::Approx, "0",
"Use div.approx"),
clEnumValN(NVPTX::DivPrecisionLevel::Full, "1", "Use div.full"),
clEnumValN(NVPTX::DivPrecisionLevel::IEEE754, "2",
"Use IEEE Compliant F32 div.rnd if available")),
cl::init(NVPTX::DivPrecisionLevel::IEEE754));

static cl::opt<bool> UsePrecSqrtF32(
"nvptx-prec-sqrtf32", cl::Hidden,
Expand All @@ -109,17 +114,24 @@ static cl::opt<bool> ForceMinByValParamAlign(
" params of device functions."),
cl::init(false));

int NVPTXTargetLowering::getDivF32Level() const {
if (UsePrecDivF32.getNumOccurrences() > 0) {
// If nvptx-prec-div32=N is used on the command-line, always honor it
NVPTX::DivPrecisionLevel
NVPTXTargetLowering::getDivF32Level(const MachineFunction &MF,
const SDNode *N) const {
// If nvptx-prec-div32=N is used on the command-line, always honor it
if (UsePrecDivF32.getNumOccurrences() > 0)
return UsePrecDivF32;
} else {
// Otherwise, use div.approx if fast math is enabled
if (getTargetMachine().Options.UnsafeFPMath)
return 0;
else
return 2;

// Otherwise, use div.approx if fast math is enabled
if (allowUnsafeFPMath(MF))
return NVPTX::DivPrecisionLevel::Approx;

if (N) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we ever call it with a nullptr N ? Or, in other words, do we ever intend to use if for the MF alone?
If so, then it would probably make sense to declare N with the default value.

If we always expect to see a valid node, then I'd pass it as a constref.

const SDNodeFlags Flags = N->getFlags();
if (Flags.hasApproximateFuncs())
return NVPTX::DivPrecisionLevel::Approx;
}

return NVPTX::DivPrecisionLevel::IEEE754;
}

bool NVPTXTargetLowering::usePrecSqrtF32() const {
Expand Down Expand Up @@ -4947,7 +4959,7 @@ bool NVPTXTargetLowering::allowFMA(MachineFunction &MF,
return allowUnsafeFPMath(MF);
}

bool NVPTXTargetLowering::allowUnsafeFPMath(MachineFunction &MF) const {
bool NVPTXTargetLowering::allowUnsafeFPMath(const MachineFunction &MF) const {
// Honor TargetOptions flags that explicitly say unsafe math is okay.
if (MF.getTarget().Options.UnsafeFPMath)
return true;
Expand Down
9 changes: 3 additions & 6 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,8 @@ class NVPTXTargetLowering : public TargetLowering {

// Get the degree of precision we want from 32-bit floating point division
// operations.
//
// 0 - Use ptx div.approx
// 1 - Use ptx.div.full (approximate, but less so than div.approx)
// 2 - Use IEEE-compliant div instructions, if available.
int getDivF32Level() const;
NVPTX::DivPrecisionLevel getDivF32Level(const MachineFunction &MF,
const SDNode *N) const;

// Get whether we should use a precise or approximate 32-bit floating point
// sqrt instruction.
Expand All @@ -235,7 +232,7 @@ class NVPTXTargetLowering : public TargetLowering {
unsigned combineRepeatedFPDivisors() const override { return 2; }

bool allowFMA(MachineFunction &MF, CodeGenOptLevel OptLevel) const;
bool allowUnsafeFPMath(MachineFunction &MF) const;
bool allowUnsafeFPMath(const MachineFunction &MF) const;

bool isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
EVT) const override {
Expand Down
125 changes: 60 additions & 65 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,6 @@ def doRsqrtOpt : Predicate<"doRsqrtOpt()">;

def doMulWide : Predicate<"doMulWide">;

def do_DIVF32_APPROX : Predicate<"getDivF32Level()==0">;
def do_DIVF32_FULL : Predicate<"getDivF32Level()==1">;

def do_SQRTF32_APPROX : Predicate<"!usePrecSqrtF32()">;
def do_SQRTF32_RN : Predicate<"usePrecSqrtF32()">;

Expand Down Expand Up @@ -1108,26 +1105,19 @@ def INEG64 :
//-----------------------------------

// Constant 1.0f
def FloatConst1 : PatLeaf<(fpimm), [{
return &N->getValueAPF().getSemantics() == &llvm::APFloat::IEEEsingle() &&
N->getValueAPF().convertToFloat() == 1.0f;
def f32imm_1 : FPImmLeaf<f32, [{
return &Imm.getSemantics() == &llvm::APFloat::IEEEsingle() &&
Imm.convertToFloat() == 1.0f;
}]>;
// Constant 1.0 (double)
def DoubleConst1 : PatLeaf<(fpimm), [{
return &N->getValueAPF().getSemantics() == &llvm::APFloat::IEEEdouble() &&
N->getValueAPF().convertToDouble() == 1.0;
def f64imm_1 : FPImmLeaf<f64, [{
return &Imm.getSemantics() == &llvm::APFloat::IEEEdouble() &&
Imm.convertToDouble() == 1.0;
}]>;
// Constant -1.0 (double)
def DoubleConstNeg1 : PatLeaf<(fpimm), [{
return &N->getValueAPF().getSemantics() == &llvm::APFloat::IEEEdouble() &&
N->getValueAPF().convertToDouble() == -1.0;
}]>;


// Constant -X -> X (double)
def NegDoubleConst : SDNodeXForm<fpimm, [{
return CurDAG->getTargetConstantFP(-(N->getValueAPF()),
SDLoc(N), MVT::f64);
def f64imm_neg1 : FPImmLeaf<f64, [{
return &Imm.getSemantics() == &llvm::APFloat::IEEEdouble() &&
Imm.convertToDouble() == -1.0;
}]>;

defm FADD : F3_fma_component<"add", fadd>;
Expand Down Expand Up @@ -1178,11 +1168,11 @@ def BFNEG16x2 : FNEG_BF16_F16X2<"neg.bf16x2", v2bf16, Int32Regs, True>;
//
// F64 division
//
def FDIV641r :
def FRCP64r :
NVPTXInst<(outs Float64Regs:$dst),
(ins f64imm:$a, Float64Regs:$b),
(ins Float64Regs:$b),
"rcp.rn.f64 \t$dst, $b;",
[(set f64:$dst, (fdiv DoubleConst1:$a, f64:$b))]>;
[(set f64:$dst, (fdiv f64imm_1, f64:$b))]>;
def FDIV64rr :
NVPTXInst<(outs Float64Regs:$dst),
(ins Float64Regs:$a, Float64Regs:$b),
Expand All @@ -1196,109 +1186,114 @@ def FDIV64ri :

// fdiv will be converted to rcp
// fneg (fdiv 1.0, X) => fneg (rcp.rn X)
def : Pat<(fdiv DoubleConstNeg1:$a, f64:$b),
(FNEGf64 (FDIV641r (NegDoubleConst node:$a), $b))>;
def : Pat<(fdiv f64imm_neg1, f64:$b),
(FNEGf64 (FRCP64r $b))>;

//
// F32 Approximate reciprocal
//
def FDIV321r_ftz :

def fdiv_approx : PatFrag<(ops node:$a, node:$b),
(fdiv node:$a, node:$b), [{
return getDivF32Level(N) == NVPTX::DivPrecisionLevel::Approx;
}]>;


def FRCP32_approx_r_ftz :
NVPTXInst<(outs Float32Regs:$dst),
(ins f32imm:$a, Float32Regs:$b),
(ins Float32Regs:$b),
"rcp.approx.ftz.f32 \t$dst, $b;",
[(set f32:$dst, (fdiv FloatConst1:$a, f32:$b))]>,
Requires<[do_DIVF32_APPROX, doF32FTZ]>;
def FDIV321r :
[(set f32:$dst, (fdiv_approx f32imm_1, f32:$b))]>,
Requires<[doF32FTZ]>;
def FRCP32_approx_r :
NVPTXInst<(outs Float32Regs:$dst),
(ins f32imm:$a, Float32Regs:$b),
(ins Float32Regs:$b),
"rcp.approx.f32 \t$dst, $b;",
[(set f32:$dst, (fdiv FloatConst1:$a, f32:$b))]>,
Requires<[do_DIVF32_APPROX]>;
[(set f32:$dst, (fdiv_approx f32imm_1, f32:$b))]>;

//
// F32 Approximate division
//
def FDIV32approxrr_ftz :
NVPTXInst<(outs Float32Regs:$dst),
(ins Float32Regs:$a, Float32Regs:$b),
"div.approx.ftz.f32 \t$dst, $a, $b;",
[(set f32:$dst, (fdiv f32:$a, f32:$b))]>,
Requires<[do_DIVF32_APPROX, doF32FTZ]>;
[(set f32:$dst, (fdiv_approx f32:$a, f32:$b))]>,
Requires<[doF32FTZ]>;
def FDIV32approxri_ftz :
NVPTXInst<(outs Float32Regs:$dst),
(ins Float32Regs:$a, f32imm:$b),
"div.approx.ftz.f32 \t$dst, $a, $b;",
[(set f32:$dst, (fdiv f32:$a, fpimm:$b))]>,
Requires<[do_DIVF32_APPROX, doF32FTZ]>;
[(set f32:$dst, (fdiv_approx f32:$a, fpimm:$b))]>,
Requires<[doF32FTZ]>;
def FDIV32approxrr :
NVPTXInst<(outs Float32Regs:$dst),
(ins Float32Regs:$a, Float32Regs:$b),
"div.approx.f32 \t$dst, $a, $b;",
[(set f32:$dst, (fdiv f32:$a, f32:$b))]>,
Requires<[do_DIVF32_APPROX]>;
[(set f32:$dst, (fdiv_approx f32:$a, f32:$b))]>;
def FDIV32approxri :
NVPTXInst<(outs Float32Regs:$dst),
(ins Float32Regs:$a, f32imm:$b),
"div.approx.f32 \t$dst, $a, $b;",
[(set f32:$dst, (fdiv f32:$a, fpimm:$b))]>,
Requires<[do_DIVF32_APPROX]>;
[(set f32:$dst, (fdiv_approx f32:$a, fpimm:$b))]>;
//
// F32 Semi-accurate reciprocal
//
// rcp.approx gives the same result as div.full(1.0f, a) and is faster.
//
def FDIV321r_approx_ftz :
NVPTXInst<(outs Float32Regs:$dst),
(ins f32imm:$a, Float32Regs:$b),
"rcp.approx.ftz.f32 \t$dst, $b;",
[(set f32:$dst, (fdiv FloatConst1:$a, f32:$b))]>,
Requires<[do_DIVF32_FULL, doF32FTZ]>;
def FDIV321r_approx :
NVPTXInst<(outs Float32Regs:$dst),
(ins f32imm:$a, Float32Regs:$b),
"rcp.approx.f32 \t$dst, $b;",
[(set f32:$dst, (fdiv FloatConst1:$a, f32:$b))]>,
Requires<[do_DIVF32_FULL]>;

def fdiv_full : PatFrag<(ops node:$a, node:$b),
(fdiv node:$a, node:$b), [{
return getDivF32Level(N) == NVPTX::DivPrecisionLevel::Full;
}]>;


def : Pat<(fdiv_full f32imm_1, f32:$b),
(FRCP32_approx_r_ftz $b)>,
Requires<[doF32FTZ]>;

def : Pat<(fdiv_full f32imm_1, f32:$b),
(FRCP32_approx_r $b)>;

//
// F32 Semi-accurate division
//
def FDIV32rr_ftz :
NVPTXInst<(outs Float32Regs:$dst),
(ins Float32Regs:$a, Float32Regs:$b),
"div.full.ftz.f32 \t$dst, $a, $b;",
[(set f32:$dst, (fdiv Float32Regs:$a, f32:$b))]>,
Requires<[do_DIVF32_FULL, doF32FTZ]>;
[(set f32:$dst, (fdiv_full f32:$a, f32:$b))]>,
Requires<[doF32FTZ]>;
def FDIV32ri_ftz :
NVPTXInst<(outs Float32Regs:$dst),
(ins Float32Regs:$a, f32imm:$b),
"div.full.ftz.f32 \t$dst, $a, $b;",
[(set f32:$dst, (fdiv f32:$a, fpimm:$b))]>,
Requires<[do_DIVF32_FULL, doF32FTZ]>;
[(set f32:$dst, (fdiv_full f32:$a, fpimm:$b))]>,
Requires<[doF32FTZ]>;
def FDIV32rr :
NVPTXInst<(outs Float32Regs:$dst),
(ins Float32Regs:$a, Float32Regs:$b),
"div.full.f32 \t$dst, $a, $b;",
[(set f32:$dst, (fdiv f32:$a, f32:$b))]>,
Requires<[do_DIVF32_FULL]>;
[(set f32:$dst, (fdiv_full f32:$a, f32:$b))]>;
def FDIV32ri :
NVPTXInst<(outs Float32Regs:$dst),
(ins Float32Regs:$a, f32imm:$b),
"div.full.f32 \t$dst, $a, $b;",
[(set f32:$dst, (fdiv f32:$a, fpimm:$b))]>,
Requires<[do_DIVF32_FULL]>;
[(set f32:$dst, (fdiv_full f32:$a, fpimm:$b))]>;
//
// F32 Accurate reciprocal
//
def FDIV321r_prec_ftz :
NVPTXInst<(outs Float32Regs:$dst),
(ins f32imm:$a, Float32Regs:$b),
(ins Float32Regs:$b),
"rcp.rn.ftz.f32 \t$dst, $b;",
[(set f32:$dst, (fdiv FloatConst1:$a, f32:$b))]>,
[(set f32:$dst, (fdiv f32imm_1, f32:$b))]>,
Requires<[doF32FTZ]>;
def FDIV321r_prec :
def FRCP32r_prec :
NVPTXInst<(outs Float32Regs:$dst),
(ins f32imm:$a, Float32Regs:$b),
(ins Float32Regs:$b),
"rcp.rn.f32 \t$dst, $b;",
[(set f32:$dst, (fdiv FloatConst1:$a, f32:$b))]>;
[(set f32:$dst, (fdiv f32imm_1, f32:$b))]>;
//
// F32 Accurate division
//
Expand Down
12 changes: 6 additions & 6 deletions llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
Original file line number Diff line number Diff line change
Expand Up @@ -1606,24 +1606,24 @@ def INT_NVVM_RSQRT_APPROX_D : F_MATH_1<"rsqrt.approx.f64 \t$dst, $src0;",
F64RT, F64RT, int_nvvm_rsqrt_approx_d>;

// 1.0f / sqrt_approx -> rsqrt_approx
def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_approx_f f32:$a)),
def: Pat<(fdiv f32imm_1, (int_nvvm_sqrt_approx_f f32:$a)),
(INT_NVVM_RSQRT_APPROX_F $a)>,
Requires<[doRsqrtOpt]>;
def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_approx_ftz_f f32:$a)),
def: Pat<(fdiv f32imm_1, (int_nvvm_sqrt_approx_ftz_f f32:$a)),
(INT_NVVM_RSQRT_APPROX_FTZ_F $a)>,
Requires<[doRsqrtOpt]>;
// same for int_nvvm_sqrt_f when non-precision sqrt is requested
def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f f32:$a)),
def: Pat<(fdiv f32imm_1, (int_nvvm_sqrt_f f32:$a)),
(INT_NVVM_RSQRT_APPROX_F $a)>,
Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doNoF32FTZ]>;
def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f f32:$a)),
def: Pat<(fdiv f32imm_1, (int_nvvm_sqrt_f f32:$a)),
(INT_NVVM_RSQRT_APPROX_FTZ_F $a)>,
Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doF32FTZ]>;

def: Pat<(fdiv FloatConst1, (fsqrt f32:$a)),
def: Pat<(fdiv f32imm_1, (fsqrt f32:$a)),
(INT_NVVM_RSQRT_APPROX_F $a)>,
Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doNoF32FTZ]>;
def: Pat<(fdiv FloatConst1, (fsqrt f32:$a)),
def: Pat<(fdiv f32imm_1, (fsqrt f32:$a)),
(INT_NVVM_RSQRT_APPROX_FTZ_F $a)>,
Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doF32FTZ]>;
//
Expand Down
Loading
Loading