[AMDGPU] Fix canonicalization of truncated values. #83054

Merged 1 commit on Mar 13, 2024

24 changes: 21 additions & 3 deletions llvm/lib/Target/AMDGPU/AMDGPUInstructions.td
@@ -194,7 +194,25 @@ class HasOneUseTernaryOp<SDPatternOperator op> : PatFrag<
}];
}

class is_canonicalized<SDPatternOperator op> : PatFrag<
class is_canonicalized_1<SDPatternOperator op> : PatFrag<
(ops node:$src0),
(op $src0),
[{
const SITargetLowering &Lowering =
*static_cast<const SITargetLowering *>(getTargetLowering());

return Lowering.isCanonicalized(*CurDAG, N->getOperand(0));
}]> {

let GISelPredicateCode = [{
const SITargetLowering *TLI = static_cast<const SITargetLowering *>(
MF.getSubtarget().getTargetLowering());

return TLI->isCanonicalized(MI.getOperand(1).getReg(), MF);
}];
}

class is_canonicalized_2<SDPatternOperator op> : PatFrag<
(ops node:$src0, node:$src1),
(op $src0, $src1),
[{
@@ -210,8 +228,8 @@ class is_canonicalized<SDPatternOperator op> : PatFrag<
const SITargetLowering *TLI = static_cast<const SITargetLowering *>(
MF.getSubtarget().getTargetLowering());

return TLI->isCanonicalized(MI.getOperand(1).getReg(), const_cast<MachineFunction&>(MF)) &&
TLI->isCanonicalized(MI.getOperand(2).getReg(), const_cast<MachineFunction&>(MF));
return TLI->isCanonicalized(MI.getOperand(1).getReg(), MF) &&
TLI->isCanonicalized(MI.getOperand(2).getReg(), MF);
}];
}

65 changes: 35 additions & 30 deletions llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -12573,6 +12573,10 @@ bool SITargetLowering::isCanonicalized(SelectionDAG &DAG, SDValue Op,
case ISD::FREM:
case ISD::FP_ROUND:
case ISD::FP_EXTEND:
case ISD::FP16_TO_FP:
case ISD::FP_TO_FP16:
case ISD::BF16_TO_FP:
case ISD::FP_TO_BF16:
case ISD::FLDEXP:
case AMDGPUISD::FMUL_LEGACY:
case AMDGPUISD::FMAD_FTZ:
@@ -12592,6 +12596,9 @@ bool SITargetLowering::isCanonicalized(SelectionDAG &DAG, SDValue Op,
case AMDGPUISD::CVT_F32_UBYTE1:
case AMDGPUISD::CVT_F32_UBYTE2:
case AMDGPUISD::CVT_F32_UBYTE3:
case AMDGPUISD::FP_TO_FP16:
case AMDGPUISD::SIN_HW:
case AMDGPUISD::COS_HW:
return true;

// It can/will be lowered or combined as a bit operation.
@@ -12601,6 +12608,20 @@ bool SITargetLowering::isCanonicalized(SelectionDAG &DAG, SDValue Op,
case ISD::FCOPYSIGN:
return isCanonicalized(DAG, Op.getOperand(0), MaxDepth - 1);

case ISD::AND:
if (Op.getValueType() == MVT::i32) {
// Be careful as we only know it is a bitcast floating point type. It
// could be f32, v2f16, we have no way of knowing. Luckily the constant
// value that we optimize for, which comes up in fp32 to bf16 conversions,
// is valid to optimize for all types.
if (auto *RHS = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
if (RHS->getZExtValue() == 0xffff0000) {
return isCanonicalized(DAG, Op.getOperand(0), MaxDepth - 1);
}
}
}
break;
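
For context, a standalone sketch (not part of this diff; the helper names are ad hoc) of why the 0xffff0000 mask is safe to look through: it clears only the low 16 bits of the 32-bit value, so under the f32 interpretation the sign, the whole exponent field, and the quiet-NaN bit survive, while under the v2f16 interpretation the low element simply becomes +0.0.

```cpp
// Illustrative only -- not from the patch. Shows that applying the
// fp32->bf16 truncation mask to an already-canonical f32 value keeps it
// canonical: the exponent field (bits 30..23) and the quiet-NaN bit
// (bit 22) lie entirely in the preserved upper half.
#include <bit>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>

static bool isQuietNaNBitsF32(uint32_t Bits) {
  return (Bits & 0x7f800000u) == 0x7f800000u && // all-ones exponent
         (Bits & 0x007fffffu) != 0 &&           // non-zero mantissa => NaN
         (Bits & 0x00400000u) != 0;             // quiet bit still set
}

int main() {
  const float Inputs[] = {0.0f, -0.0f, 1.5f, 3.14159f,
                          std::numeric_limits<float>::infinity(),
                          std::nanf("")}; // quiet NaN
  for (float F : Inputs) {
    uint32_t Bits = std::bit_cast<uint32_t>(F);
    uint32_t Masked = Bits & 0xffff0000u; // the AND the DAG sees
    if (std::isnan(F))
      assert(isQuietNaNBitsF32(Masked)); // a qNaN stays a qNaN
    else
      assert(std::isinf(std::bit_cast<float>(Masked)) == std::isinf(F));
    printf("0x%08x -> 0x%08x\n", Bits, Masked);
  }
  return 0;
}
```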

case ISD::FSIN:
case ISD::FCOS:
case ISD::FSINCOS:
@@ -12666,6 +12687,9 @@ bool SITargetLowering::isCanonicalized(SelectionDAG &DAG, SDValue Op,
return false;

case ISD::BITCAST:
// TODO: This is incorrect as it loses track of the operand's type. We may
// end up effectively bitcasting from f32 to v2f16 or vice versa, and the
// same bits that are canonicalized in one type need not be in the other.
Comment on lines +12690 to +12692

Contributor: should give up on different element size bitcasts? (separate patch though)

Contributor Author: I think rather than give up, we should add a parameter specifying the type so that we would be able to handle things like v2f16 -> i32 -> truncate to i16 -> f16. But yes, separate patch, this is a pre-existing issue.
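
To make the concern above concrete, here is a standalone sketch (not part of this diff; the helpers are ad hoc, and "canonical" is simplified to "not a signaling NaN") showing that the same 32 bits can be canonical element-wise as v2f16 yet a signaling NaN when reinterpreted as f32:

```cpp
// Illustrative only -- not from the patch. The same bit pattern is
// element-wise canonical as v2f16 but a signaling NaN as f32, which is
// why looking through BITCAST without tracking the type is unsound.
#include <cstdint>
#include <cstdio>

static bool isCanonicalHalfBits(uint16_t Bits) {
  bool ExpAllOnes = (Bits & 0x7c00u) == 0x7c00u;
  bool MantNonZero = (Bits & 0x03ffu) != 0;
  bool QuietBit = (Bits & 0x0200u) != 0;
  // Simplification for this sketch: anything that is not an sNaN counts.
  return !(ExpAllOnes && MantNonZero && !QuietBit);
}

static bool isSignalingNaNFloatBits(uint32_t Bits) {
  bool ExpAllOnes = (Bits & 0x7f800000u) == 0x7f800000u;
  bool MantNonZero = (Bits & 0x007fffffu) != 0;
  bool QuietBit = (Bits & 0x00400000u) != 0;
  return ExpAllOnes && MantNonZero && !QuietBit;
}

int main() {
  uint16_t Lo = 0x3c00; // f16 1.0, canonical
  uint16_t Hi = 0x7f81; // f16 quiet NaN (quiet bit 9 set), canonical
  uint32_t Combined = (uint32_t(Hi) << 16) | Lo;

  printf("v2f16 elements canonical: %d %d\n",
         isCanonicalHalfBits(Lo), isCanonicalHalfBits(Hi));
  printf("same bits as f32 (0x%08x) signaling NaN: %d\n",
         Combined, isSignalingNaNFloatBits(Combined));
  return 0;
}
```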

return isCanonicalized(DAG, Op.getOperand(0), MaxDepth - 1);
case ISD::TRUNCATE: {
// Hack round the mess we make when legalizing extract_vector_elt
@@ -12695,25 +12719,26 @@ bool SITargetLowering::isCanonicalized(SelectionDAG &DAG, SDValue Op,
case Intrinsic::amdgcn_trig_preop:
case Intrinsic::amdgcn_log:
case Intrinsic::amdgcn_exp2:
case Intrinsic::amdgcn_sqrt:
Contributor: unrelated bonus change?

Contributor Author (@hvdijk, Mar 6, 2024): Not unrelated, this is one of the things that came up in LLVM's testsuite where previously we were able to optimize away a fcanonicalize because we saw fsqrt, but now we only check later and see amdgcn_sqrt.

return true;
default:
break;
}

[[fallthrough]];
break;
}
default:
// FIXME: denormalsEnabledForType is broken for dynamic
return denormalsEnabledForType(DAG, Op.getValueType()) &&
DAG.isKnownNeverSNaN(Op);
break;
}

llvm_unreachable("invalid operation");
// FIXME: denormalsEnabledForType is broken for dynamic
return denormalsEnabledForType(DAG, Op.getValueType()) &&
DAG.isKnownNeverSNaN(Op);
}

bool SITargetLowering::isCanonicalized(Register Reg, MachineFunction &MF,
bool SITargetLowering::isCanonicalized(Register Reg, const MachineFunction &MF,
unsigned MaxDepth) const {
MachineRegisterInfo &MRI = MF.getRegInfo();
const MachineRegisterInfo &MRI = MF.getRegInfo();
MachineInstr *MI = MRI.getVRegDef(Reg);
unsigned Opcode = MI->getOpcode();

@@ -12932,27 +12957,7 @@ SDValue SITargetLowering::performFCanonicalizeCombine(
}
}

unsigned SrcOpc = N0.getOpcode();

// If it's free to do so, push canonicalizes further up the source, which may
// find a canonical source.
//
// TODO: More opcodes. Note this is unsafe for the _ieee minnum/maxnum for
// sNaNs.
if (SrcOpc == ISD::FMINNUM || SrcOpc == ISD::FMAXNUM) {
auto *CRHS = dyn_cast<ConstantFPSDNode>(N0.getOperand(1));
if (CRHS && N0.hasOneUse()) {
SDLoc SL(N);
SDValue Canon0 = DAG.getNode(ISD::FCANONICALIZE, SL, VT,
N0.getOperand(0));
SDValue Canon1 = getCanonicalConstantFP(DAG, SL, VT, CRHS->getValueAPF());
DCI.AddToWorklist(Canon0.getNode());

return DAG.getNode(N0.getOpcode(), SL, VT, Canon0, Canon1);
}
}

return isCanonicalized(DAG, N0) ? N0 : SDValue();
return SDValue();
}

static unsigned minMaxOpcToMin3Max3Opc(unsigned Opc) {
@@ -15924,8 +15929,8 @@ bool SITargetLowering::denormalsEnabledForType(const SelectionDAG &DAG,
}
}

bool SITargetLowering::denormalsEnabledForType(LLT Ty,
MachineFunction &MF) const {
bool SITargetLowering::denormalsEnabledForType(
LLT Ty, const MachineFunction &MF) const {
switch (Ty.getScalarSizeInBits()) {
case 32:
return !denormalModeIsFlushAllF32(MF);
4 changes: 2 additions & 2 deletions llvm/lib/Target/AMDGPU/SIISelLowering.h
@@ -523,10 +523,10 @@ class SITargetLowering final : public AMDGPUTargetLowering {

bool isCanonicalized(SelectionDAG &DAG, SDValue Op,
unsigned MaxDepth = 5) const;
bool isCanonicalized(Register Reg, MachineFunction &MF,
bool isCanonicalized(Register Reg, const MachineFunction &MF,
unsigned MaxDepth = 5) const;
bool denormalsEnabledForType(const SelectionDAG &DAG, EVT VT) const;
bool denormalsEnabledForType(LLT Ty, MachineFunction &MF) const;
bool denormalsEnabledForType(LLT Ty, const MachineFunction &MF) const;

bool checkForPhysRegDependency(SDNode *Def, SDNode *User, unsigned Op,
const TargetRegisterInfo *TRI,
51 changes: 49 additions & 2 deletions llvm/lib/Target/AMDGPU/SIInstructions.td
@@ -2944,6 +2944,34 @@ def : GCNPat<
(V_BFREV_B32_e64 (i32 (EXTRACT_SUBREG VReg_64:$a, sub1))), sub0,
(V_BFREV_B32_e64 (i32 (EXTRACT_SUBREG VReg_64:$a, sub0))), sub1)>;

// If fcanonicalize's operand is implicitly canonicalized, we only need a copy.
let AddedComplexity = 1000 in {
def : GCNPat<
(is_canonicalized_1<fcanonicalize> f16:$src),
Contributor: This might be a dumb question (I have never understood DAG selection patterns) but why can't we write this more like (fcanonicalize (is_canonicalized f16:$src)), so the predicate applies to the operand instead of to the whole expression? Then we would not need separate _1 and _2 versions of is_canonicalized.

Contributor Author: Not a dumb question, one that I would like an answer to as well :) Variations of that were what I tried first, but I couldn't get it to work, and TableGen's error messages were not helpful in letting me figure out a way to get it working. I decided to just stick with what was already being done.

Contributor: The PatFrag is what owns the predicate; it might be possible to define something using srcvalue as a way of getting a catch-all matcher with a predicate.

Contributor Author: I had a look and couldn't find any custom predicate logic other than for instructions, which means we cannot do it on the operand level as the operand need not be an instruction. It's possible I missed something simple, but I think I would prefer to just leave it like this over making big changes to make TableGen support operand-level predicates.

(COPY f16:$src)
>;

def : GCNPat<
(is_canonicalized_1<fcanonicalize> v2f16:$src),
(COPY v2f16:$src)
>;

def : GCNPat<
(is_canonicalized_1<fcanonicalize> f32:$src),
(COPY f32:$src)
>;

def : GCNPat<
(is_canonicalized_1<fcanonicalize> v2f32:$src),
(COPY v2f32:$src)
>;

def : GCNPat<
(is_canonicalized_1<fcanonicalize> f64:$src),
(COPY f64:$src)
>;
}

// Prefer selecting to max when legal, but using mul is always valid.
let AddedComplexity = -5 in {

@@ -3277,8 +3305,8 @@ def : GCNPat <

let AddedComplexity = 5 in {
def : GCNPat <
(v2f16 (is_canonicalized<build_vector> (f16 (VOP3Mods (f16 VGPR_32:$src0), i32:$src0_mods)),
(f16 (VOP3Mods (f16 VGPR_32:$src1), i32:$src1_mods)))),
(v2f16 (is_canonicalized_2<build_vector> (f16 (VOP3Mods (f16 VGPR_32:$src0), i32:$src0_mods)),
(f16 (VOP3Mods (f16 VGPR_32:$src1), i32:$src1_mods)))),
(V_PACK_B32_F16_e64 $src0_mods, VGPR_32:$src0, $src1_mods, VGPR_32:$src1)
>;
}
@@ -3590,6 +3618,17 @@ FPMinMaxPat<Instruction minmaxInst, ValueType vt, SDPatternOperator min_or_max,
DSTCLAMP.NONE, DSTOMOD.NONE)
>;

class
FPMinCanonMaxPat<Instruction minmaxInst, ValueType vt, SDPatternOperator min_or_max,
SDPatternOperator max_or_min_oneuse> : GCNPat <
(min_or_max (is_canonicalized_1<fcanonicalize>
Contributor: I wonder how many more fp patterns there are that might need to skip over an fcanonicalize in the middle.

Contributor: It's really arbitrary operations. fmin/fmax are just of particular interest because we're stuck inserting canonicalizes in the lowering for them. The more I think about it, the more I think we'd be better off making this a separate transform (but just moving this to selection patterns is a net improvement for now, I think).

Contributor Author: There may well be more, I just went for the ones that I came across from the test results. But if there are more, it doesn't result in wrong code, only suboptimal code, which I think is acceptable if the overall code still looks better?

Contributor: I'd only consider the min/max legalize case important. There's only a handful of canonicalizes in the math library, they aren't common (and I think we could avoid most of those too if we rewrote the code slightly).

(max_or_min_oneuse (VOP3Mods vt:$src0, i32:$src0_mods),
(VOP3Mods vt:$src1, i32:$src1_mods))),
(vt (VOP3Mods vt:$src2, i32:$src2_mods))),
(minmaxInst $src0_mods, $src0, $src1_mods, $src1, $src2_mods, $src2,
DSTCLAMP.NONE, DSTOMOD.NONE)
>;

let OtherPredicates = [isGFX11Plus] in {
def : IntMinMaxPat<V_MAXMIN_I32_e64, smin, smax_oneuse>;
def : IntMinMaxPat<V_MINMAX_I32_e64, smax, smin_oneuse>;
@@ -3599,6 +3638,10 @@ def : FPMinMaxPat<V_MINMAX_F32_e64, f32, fmaxnum_like, fminnum_like_oneuse>;
def : FPMinMaxPat<V_MAXMIN_F32_e64, f32, fminnum_like, fmaxnum_like_oneuse>;
def : FPMinMaxPat<V_MINMAX_F16_e64, f16, fmaxnum_like, fminnum_like_oneuse>;
def : FPMinMaxPat<V_MAXMIN_F16_e64, f16, fminnum_like, fmaxnum_like_oneuse>;
def : FPMinCanonMaxPat<V_MINMAX_F32_e64, f32, fmaxnum_like, fminnum_like_oneuse>;
def : FPMinCanonMaxPat<V_MAXMIN_F32_e64, f32, fminnum_like, fmaxnum_like_oneuse>;
def : FPMinCanonMaxPat<V_MINMAX_F16_e64, f16, fmaxnum_like, fminnum_like_oneuse>;
def : FPMinCanonMaxPat<V_MAXMIN_F16_e64, f16, fminnum_like, fmaxnum_like_oneuse>;
}

let OtherPredicates = [isGFX9Plus] in {
@@ -3612,6 +3655,10 @@ def : FPMinMaxPat<V_MINIMUMMAXIMUM_F32_e64, f32, DivergentBinFrag<fmaximum>, fminimum_oneuse>;
def : FPMinMaxPat<V_MAXIMUMMINIMUM_F32_e64, f32, DivergentBinFrag<fminimum>, fmaximum_oneuse>;
def : FPMinMaxPat<V_MINIMUMMAXIMUM_F16_e64, f16, DivergentBinFrag<fmaximum>, fminimum_oneuse>;
def : FPMinMaxPat<V_MAXIMUMMINIMUM_F16_e64, f16, DivergentBinFrag<fminimum>, fmaximum_oneuse>;
def : FPMinCanonMaxPat<V_MINIMUMMAXIMUM_F32_e64, f32, DivergentBinFrag<fmaximum>, fminimum_oneuse>;
def : FPMinCanonMaxPat<V_MAXIMUMMINIMUM_F32_e64, f32, DivergentBinFrag<fminimum>, fmaximum_oneuse>;
def : FPMinCanonMaxPat<V_MINIMUMMAXIMUM_F16_e64, f16, DivergentBinFrag<fmaximum>, fminimum_oneuse>;
def : FPMinCanonMaxPat<V_MAXIMUMMINIMUM_F16_e64, f16, DivergentBinFrag<fminimum>, fmaximum_oneuse>;
}

// Convert a floating-point power of 2 to the integer exponent.