Skip to content

Commit ceb744e

Browse files
authored
[AMDGPU] Fix canonicalization of truncated values. (#83054)
We were relying on rounding operations to implicitly canonicalize their results, which is generally safe, except when the rounding itself may be optimized away. Fixes #82937.
1 parent 7cd61f8 commit ceb744e

14 files changed

+1021
-1410
lines changed

llvm/lib/Target/AMDGPU/AMDGPUInstructions.td

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,25 @@ class HasOneUseTernaryOp<SDPatternOperator op> : PatFrag<
194194
}];
195195
}
196196

197-
class is_canonicalized<SDPatternOperator op> : PatFrag<
197+
class is_canonicalized_1<SDPatternOperator op> : PatFrag<
198+
(ops node:$src0),
199+
(op $src0),
200+
[{
201+
const SITargetLowering &Lowering =
202+
*static_cast<const SITargetLowering *>(getTargetLowering());
203+
204+
return Lowering.isCanonicalized(*CurDAG, N->getOperand(0));
205+
}]> {
206+
207+
let GISelPredicateCode = [{
208+
const SITargetLowering *TLI = static_cast<const SITargetLowering *>(
209+
MF.getSubtarget().getTargetLowering());
210+
211+
return TLI->isCanonicalized(MI.getOperand(1).getReg(), MF);
212+
}];
213+
}
214+
215+
class is_canonicalized_2<SDPatternOperator op> : PatFrag<
198216
(ops node:$src0, node:$src1),
199217
(op $src0, $src1),
200218
[{
@@ -210,8 +228,8 @@ class is_canonicalized<SDPatternOperator op> : PatFrag<
210228
const SITargetLowering *TLI = static_cast<const SITargetLowering *>(
211229
MF.getSubtarget().getTargetLowering());
212230

213-
return TLI->isCanonicalized(MI.getOperand(1).getReg(), const_cast<MachineFunction&>(MF)) &&
214-
TLI->isCanonicalized(MI.getOperand(2).getReg(), const_cast<MachineFunction&>(MF));
231+
return TLI->isCanonicalized(MI.getOperand(1).getReg(), MF) &&
232+
TLI->isCanonicalized(MI.getOperand(2).getReg(), MF);
215233
}];
216234
}
217235

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12572,6 +12572,10 @@ bool SITargetLowering::isCanonicalized(SelectionDAG &DAG, SDValue Op,
1257212572
case ISD::FREM:
1257312573
case ISD::FP_ROUND:
1257412574
case ISD::FP_EXTEND:
12575+
case ISD::FP16_TO_FP:
12576+
case ISD::FP_TO_FP16:
12577+
case ISD::BF16_TO_FP:
12578+
case ISD::FP_TO_BF16:
1257512579
case ISD::FLDEXP:
1257612580
case AMDGPUISD::FMUL_LEGACY:
1257712581
case AMDGPUISD::FMAD_FTZ:
@@ -12591,6 +12595,9 @@ bool SITargetLowering::isCanonicalized(SelectionDAG &DAG, SDValue Op,
1259112595
case AMDGPUISD::CVT_F32_UBYTE1:
1259212596
case AMDGPUISD::CVT_F32_UBYTE2:
1259312597
case AMDGPUISD::CVT_F32_UBYTE3:
12598+
case AMDGPUISD::FP_TO_FP16:
12599+
case AMDGPUISD::SIN_HW:
12600+
case AMDGPUISD::COS_HW:
1259412601
return true;
1259512602

1259612603
// It can/will be lowered or combined as a bit operation.
@@ -12600,6 +12607,20 @@ bool SITargetLowering::isCanonicalized(SelectionDAG &DAG, SDValue Op,
1260012607
case ISD::FCOPYSIGN:
1260112608
return isCanonicalized(DAG, Op.getOperand(0), MaxDepth - 1);
1260212609

12610+
case ISD::AND:
12611+
if (Op.getValueType() == MVT::i32) {
12612+
// Be careful as we only know it is a bitcast floating point type. It
12613+
// could be f32, v2f16, we have no way of knowing. Luckily the constant
12614+
// value that we optimize for, which comes up in fp32 to bf16 conversions,
12615+
// is valid to optimize for all types.
12616+
if (auto *RHS = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
12617+
if (RHS->getZExtValue() == 0xffff0000) {
12618+
return isCanonicalized(DAG, Op.getOperand(0), MaxDepth - 1);
12619+
}
12620+
}
12621+
}
12622+
break;
12623+
1260312624
case ISD::FSIN:
1260412625
case ISD::FCOS:
1260512626
case ISD::FSINCOS:
@@ -12665,6 +12686,9 @@ bool SITargetLowering::isCanonicalized(SelectionDAG &DAG, SDValue Op,
1266512686
return false;
1266612687

1266712688
case ISD::BITCAST:
12689+
// TODO: This is incorrect as it loses track of the operand's type. We may
12690+
// end up effectively bitcasting from f32 to v2f16 or vice versa, and the
12691+
// same bits that are canonicalized in one type need not be in the other.
1266812692
return isCanonicalized(DAG, Op.getOperand(0), MaxDepth - 1);
1266912693
case ISD::TRUNCATE: {
1267012694
// Hack round the mess we make when legalizing extract_vector_elt
@@ -12694,25 +12718,26 @@ bool SITargetLowering::isCanonicalized(SelectionDAG &DAG, SDValue Op,
1269412718
case Intrinsic::amdgcn_trig_preop:
1269512719
case Intrinsic::amdgcn_log:
1269612720
case Intrinsic::amdgcn_exp2:
12721+
case Intrinsic::amdgcn_sqrt:
1269712722
return true;
1269812723
default:
1269912724
break;
1270012725
}
1270112726

12702-
[[fallthrough]];
12727+
break;
1270312728
}
1270412729
default:
12705-
// FIXME: denormalsEnabledForType is broken for dynamic
12706-
return denormalsEnabledForType(DAG, Op.getValueType()) &&
12707-
DAG.isKnownNeverSNaN(Op);
12730+
break;
1270812731
}
1270912732

12710-
llvm_unreachable("invalid operation");
12733+
// FIXME: denormalsEnabledForType is broken for dynamic
12734+
return denormalsEnabledForType(DAG, Op.getValueType()) &&
12735+
DAG.isKnownNeverSNaN(Op);
1271112736
}
1271212737

12713-
bool SITargetLowering::isCanonicalized(Register Reg, MachineFunction &MF,
12738+
bool SITargetLowering::isCanonicalized(Register Reg, const MachineFunction &MF,
1271412739
unsigned MaxDepth) const {
12715-
MachineRegisterInfo &MRI = MF.getRegInfo();
12740+
const MachineRegisterInfo &MRI = MF.getRegInfo();
1271612741
MachineInstr *MI = MRI.getVRegDef(Reg);
1271712742
unsigned Opcode = MI->getOpcode();
1271812743

@@ -12931,27 +12956,7 @@ SDValue SITargetLowering::performFCanonicalizeCombine(
1293112956
}
1293212957
}
1293312958

12934-
unsigned SrcOpc = N0.getOpcode();
12935-
12936-
// If it's free to do so, push canonicalizes further up the source, which may
12937-
// find a canonical source.
12938-
//
12939-
// TODO: More opcodes. Note this is unsafe for the _ieee minnum/maxnum for
12940-
// sNaNs.
12941-
if (SrcOpc == ISD::FMINNUM || SrcOpc == ISD::FMAXNUM) {
12942-
auto *CRHS = dyn_cast<ConstantFPSDNode>(N0.getOperand(1));
12943-
if (CRHS && N0.hasOneUse()) {
12944-
SDLoc SL(N);
12945-
SDValue Canon0 = DAG.getNode(ISD::FCANONICALIZE, SL, VT,
12946-
N0.getOperand(0));
12947-
SDValue Canon1 = getCanonicalConstantFP(DAG, SL, VT, CRHS->getValueAPF());
12948-
DCI.AddToWorklist(Canon0.getNode());
12949-
12950-
return DAG.getNode(N0.getOpcode(), SL, VT, Canon0, Canon1);
12951-
}
12952-
}
12953-
12954-
return isCanonicalized(DAG, N0) ? N0 : SDValue();
12959+
return SDValue();
1295512960
}
1295612961

1295712962
static unsigned minMaxOpcToMin3Max3Opc(unsigned Opc) {
@@ -15939,8 +15944,8 @@ bool SITargetLowering::denormalsEnabledForType(const SelectionDAG &DAG,
1593915944
}
1594015945
}
1594115946

15942-
bool SITargetLowering::denormalsEnabledForType(LLT Ty,
15943-
MachineFunction &MF) const {
15947+
bool SITargetLowering::denormalsEnabledForType(
15948+
LLT Ty, const MachineFunction &MF) const {
1594415949
switch (Ty.getScalarSizeInBits()) {
1594515950
case 32:
1594615951
return !denormalModeIsFlushAllF32(MF);

llvm/lib/Target/AMDGPU/SIISelLowering.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -523,10 +523,10 @@ class SITargetLowering final : public AMDGPUTargetLowering {
523523

524524
bool isCanonicalized(SelectionDAG &DAG, SDValue Op,
525525
unsigned MaxDepth = 5) const;
526-
bool isCanonicalized(Register Reg, MachineFunction &MF,
526+
bool isCanonicalized(Register Reg, const MachineFunction &MF,
527527
unsigned MaxDepth = 5) const;
528528
bool denormalsEnabledForType(const SelectionDAG &DAG, EVT VT) const;
529-
bool denormalsEnabledForType(LLT Ty, MachineFunction &MF) const;
529+
bool denormalsEnabledForType(LLT Ty, const MachineFunction &MF) const;
530530

531531
bool checkForPhysRegDependency(SDNode *Def, SDNode *User, unsigned Op,
532532
const TargetRegisterInfo *TRI,

llvm/lib/Target/AMDGPU/SIInstructions.td

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2944,6 +2944,34 @@ def : GCNPat<
29442944
(V_BFREV_B32_e64 (i32 (EXTRACT_SUBREG VReg_64:$a, sub1))), sub0,
29452945
(V_BFREV_B32_e64 (i32 (EXTRACT_SUBREG VReg_64:$a, sub0))), sub1)>;
29462946

2947+
// If fcanonicalize's operand is implicitly canonicalized, we only need a copy.
2948+
let AddedComplexity = 1000 in {
2949+
def : GCNPat<
2950+
(is_canonicalized_1<fcanonicalize> f16:$src),
2951+
(COPY f16:$src)
2952+
>;
2953+
2954+
def : GCNPat<
2955+
(is_canonicalized_1<fcanonicalize> v2f16:$src),
2956+
(COPY v2f16:$src)
2957+
>;
2958+
2959+
def : GCNPat<
2960+
(is_canonicalized_1<fcanonicalize> f32:$src),
2961+
(COPY f32:$src)
2962+
>;
2963+
2964+
def : GCNPat<
2965+
(is_canonicalized_1<fcanonicalize> v2f32:$src),
2966+
(COPY v2f32:$src)
2967+
>;
2968+
2969+
def : GCNPat<
2970+
(is_canonicalized_1<fcanonicalize> f64:$src),
2971+
(COPY f64:$src)
2972+
>;
2973+
}
2974+
29472975
// Prefer selecting to max when legal, but using mul is always valid.
29482976
let AddedComplexity = -5 in {
29492977

@@ -3277,8 +3305,8 @@ def : GCNPat <
32773305

32783306
let AddedComplexity = 5 in {
32793307
def : GCNPat <
3280-
(v2f16 (is_canonicalized<build_vector> (f16 (VOP3Mods (f16 VGPR_32:$src0), i32:$src0_mods)),
3281-
(f16 (VOP3Mods (f16 VGPR_32:$src1), i32:$src1_mods)))),
3308+
(v2f16 (is_canonicalized_2<build_vector> (f16 (VOP3Mods (f16 VGPR_32:$src0), i32:$src0_mods)),
3309+
(f16 (VOP3Mods (f16 VGPR_32:$src1), i32:$src1_mods)))),
32823310
(V_PACK_B32_F16_e64 $src0_mods, VGPR_32:$src0, $src1_mods, VGPR_32:$src1)
32833311
>;
32843312
}
@@ -3590,6 +3618,17 @@ FPMinMaxPat<Instruction minmaxInst, ValueType vt, SDPatternOperator min_or_max,
35903618
DSTCLAMP.NONE, DSTOMOD.NONE)
35913619
>;
35923620

3621+
class
3622+
FPMinCanonMaxPat<Instruction minmaxInst, ValueType vt, SDPatternOperator min_or_max,
3623+
SDPatternOperator max_or_min_oneuse> : GCNPat <
3624+
(min_or_max (is_canonicalized_1<fcanonicalize>
3625+
(max_or_min_oneuse (VOP3Mods vt:$src0, i32:$src0_mods),
3626+
(VOP3Mods vt:$src1, i32:$src1_mods))),
3627+
(vt (VOP3Mods vt:$src2, i32:$src2_mods))),
3628+
(minmaxInst $src0_mods, $src0, $src1_mods, $src1, $src2_mods, $src2,
3629+
DSTCLAMP.NONE, DSTOMOD.NONE)
3630+
>;
3631+
35933632
let OtherPredicates = [isGFX11Plus] in {
35943633
def : IntMinMaxPat<V_MAXMIN_I32_e64, smin, smax_oneuse>;
35953634
def : IntMinMaxPat<V_MINMAX_I32_e64, smax, smin_oneuse>;
@@ -3599,6 +3638,10 @@ def : FPMinMaxPat<V_MINMAX_F32_e64, f32, fmaxnum_like, fminnum_like_oneuse>;
35993638
def : FPMinMaxPat<V_MAXMIN_F32_e64, f32, fminnum_like, fmaxnum_like_oneuse>;
36003639
def : FPMinMaxPat<V_MINMAX_F16_e64, f16, fmaxnum_like, fminnum_like_oneuse>;
36013640
def : FPMinMaxPat<V_MAXMIN_F16_e64, f16, fminnum_like, fmaxnum_like_oneuse>;
3641+
def : FPMinCanonMaxPat<V_MINMAX_F32_e64, f32, fmaxnum_like, fminnum_like_oneuse>;
3642+
def : FPMinCanonMaxPat<V_MAXMIN_F32_e64, f32, fminnum_like, fmaxnum_like_oneuse>;
3643+
def : FPMinCanonMaxPat<V_MINMAX_F16_e64, f16, fmaxnum_like, fminnum_like_oneuse>;
3644+
def : FPMinCanonMaxPat<V_MAXMIN_F16_e64, f16, fminnum_like, fmaxnum_like_oneuse>;
36023645
}
36033646

36043647
let OtherPredicates = [isGFX9Plus] in {
@@ -3612,6 +3655,10 @@ def : FPMinMaxPat<V_MINIMUMMAXIMUM_F32_e64, f32, DivergentBinFrag<fmaximum>, fmi
36123655
def : FPMinMaxPat<V_MAXIMUMMINIMUM_F32_e64, f32, DivergentBinFrag<fminimum>, fmaximum_oneuse>;
36133656
def : FPMinMaxPat<V_MINIMUMMAXIMUM_F16_e64, f16, DivergentBinFrag<fmaximum>, fminimum_oneuse>;
36143657
def : FPMinMaxPat<V_MAXIMUMMINIMUM_F16_e64, f16, DivergentBinFrag<fminimum>, fmaximum_oneuse>;
3658+
def : FPMinCanonMaxPat<V_MINIMUMMAXIMUM_F32_e64, f32, DivergentBinFrag<fmaximum>, fminimum_oneuse>;
3659+
def : FPMinCanonMaxPat<V_MAXIMUMMINIMUM_F32_e64, f32, DivergentBinFrag<fminimum>, fmaximum_oneuse>;
3660+
def : FPMinCanonMaxPat<V_MINIMUMMAXIMUM_F16_e64, f16, DivergentBinFrag<fmaximum>, fminimum_oneuse>;
3661+
def : FPMinCanonMaxPat<V_MAXIMUMMINIMUM_F16_e64, f16, DivergentBinFrag<fminimum>, fmaximum_oneuse>;
36153662
}
36163663

36173664
// Convert a floating-point power of 2 to the integer exponent.

0 commit comments

Comments
 (0)