Skip to content

[SDAG] Fix fmaximum legalization errors #142170

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions llvm/include/llvm/CodeGen/SelectionDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,20 @@ class SDDbgInfo {

LLVM_ABI void checkForCycles(const SelectionDAG *DAG, bool force = false);

/// Keeps track of state when getting the sign of a floating-point value as an
/// integer.
///
/// Filled in by getSignAsIntValue() and consumed by modifySignAsInt(). Two
/// shapes are possible: the value was bitcast to an integer (Chain stays
/// unset), or it was stored to a stack temporary and only the byte holding
/// the sign was loaded back (Chain and the pointer fields are set).
struct FloatSignAsInt {
/// Type of the floating-point value that was inspected.
EVT FloatVT;
/// Store of the value to the stack temporary. Unset on the bitcast path;
/// modifySignAsInt() tests this to decide which path to undo.
SDValue Chain;
/// Address of the stack temporary holding the whole float (stack path only).
SDValue FloatPtr;
/// Address the integer part is loaded from and stored back to (stack path
/// only).
SDValue IntPtr;
/// Pointer info for memory accesses through IntPtr.
MachinePointerInfo IntPointerInfo;
/// Pointer info for memory accesses through FloatPtr.
MachinePointerInfo FloatPointerInfo;
/// Integer carrying the sign bit: either the whole value bitcast to an
/// integer, or an extending load of the byte that contains the sign.
SDValue IntValue;
/// Mask with only the sign bit set, sized to match IntValue.
APInt SignMask;
/// Index of the sign bit within IntValue.
uint8_t SignBit;
};

/// This is used to represent a portion of an LLVM function in a low-level
/// Data Dependence DAG representation suitable for instruction selection.
/// This DAG is constructed as the first step of instruction selection in order
Expand Down Expand Up @@ -2017,6 +2031,16 @@ class SelectionDAG {
/// value types.
LLVM_ABI SDValue CreateStackTemporary(EVT VT1, EVT VT2);

/// Bitcast a floating-point value to an integer value. Only bitcast the part
/// containing the sign bit if the target has no integer value capable of
/// holding all bits of the floating-point value.
void getSignAsIntValue(FloatSignAsInt &State, const SDLoc &DL, SDValue Value);

/// Replace the integer value produced by getSignAsIntValue() with a new value
/// and cast the result back to a floating-point type.
SDValue modifySignAsInt(const FloatSignAsInt &State, const SDLoc &DL,
SDValue NewIntValue);

LLVM_ABI SDValue FoldSymbolOffset(unsigned Opcode, EVT VT,
const GlobalAddressSDNode *GA,
const SDNode *N2);
Expand Down
100 changes: 7 additions & 93 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,6 @@ using namespace llvm;

namespace {

/// Keeps track of state when getting the sign of a floating-point value as an
/// integer.
///
/// Filled in by getSignAsIntValue() and consumed by modifySignAsInt(). Two
/// shapes are possible: the value was bitcast to an integer (Chain stays
/// unset), or it was stored to a stack temporary and only the byte holding
/// the sign was loaded back (Chain and the pointer fields are set).
struct FloatSignAsInt {
/// Type of the floating-point value that was inspected.
EVT FloatVT;
/// Store of the value to the stack temporary. Unset on the bitcast path;
/// modifySignAsInt() tests this to decide which path to undo.
SDValue Chain;
/// Address of the stack temporary holding the whole float (stack path only).
SDValue FloatPtr;
/// Address the integer part is loaded from and stored back to (stack path
/// only).
SDValue IntPtr;
/// Pointer info for memory accesses through IntPtr.
MachinePointerInfo IntPointerInfo;
/// Pointer info for memory accesses through FloatPtr.
MachinePointerInfo FloatPointerInfo;
/// Integer carrying the sign bit: either the whole value bitcast to an
/// integer, or an extending load of the byte that contains the sign.
SDValue IntValue;
/// Mask with only the sign bit set, sized to match IntValue.
APInt SignMask;
/// Index of the sign bit within IntValue.
uint8_t SignBit;
};

//===----------------------------------------------------------------------===//
/// This takes an arbitrary SelectionDAG as input and
/// hacks on it until the target machine can handle it. This involves
Expand Down Expand Up @@ -166,10 +152,6 @@ class SelectionDAGLegalize {
SDValue ExpandSCALAR_TO_VECTOR(SDNode *Node);
void ExpandDYNAMIC_STACKALLOC(SDNode *Node,
SmallVectorImpl<SDValue> &Results);
void getSignAsIntValue(FloatSignAsInt &State, const SDLoc &DL,
SDValue Value) const;
SDValue modifySignAsInt(const FloatSignAsInt &State, const SDLoc &DL,
SDValue NewIntValue) const;
SDValue ExpandFCOPYSIGN(SDNode *Node) const;
SDValue ExpandFABS(SDNode *Node) const;
SDValue ExpandFNEG(SDNode *Node) const;
Expand Down Expand Up @@ -1620,82 +1602,14 @@ SDValue SelectionDAGLegalize::ExpandVectorBuildThroughStack(SDNode* Node) {
return DAG.getLoad(VT, dl, StoreChain, FIPtr, PtrInfo);
}

/// Expose the sign of \p Value as an integer and record how in \p State.
///
/// When an integer type as wide as the float is legal, the whole value is
/// bitcast. Otherwise the float is spilled to a stack temporary and only the
/// byte that contains the sign bit is loaded back.
void SelectionDAGLegalize::getSignAsIntValue(FloatSignAsInt &State,
                                             const SDLoc &DL,
                                             SDValue Value) const {
  EVT FPTy = Value.getValueType();
  const unsigned ScalarBits = FPTy.getScalarSizeInBits();
  State.FloatVT = FPTy;

  // Preferred route: a same-width integer is legal, so a bitcast is enough and
  // the sign is simply the top bit.
  EVT SameSizeIntTy = EVT::getIntegerVT(*DAG.getContext(), ScalarBits);
  if (TLI.isTypeLegal(SameSizeIntTy)) {
    State.IntValue = DAG.getNode(ISD::BITCAST, DL, SameSizeIntTy, Value);
    State.SignMask = APInt::getSignMask(ScalarBits);
    State.SignBit = ScalarBits - 1;
    return;
  }

  // Fallback route: go through memory. The slot is aligned for both the float
  // store and the integer load.
  MVT ByteLoadTy = TLI.getRegisterType(MVT::i8);
  SDValue Slot = DAG.CreateStackTemporary(FPTy, ByteLoadTy);
  const int FrameIdx = cast<FrameIndexSDNode>(Slot.getNode())->getIndex();
  MachineFunction &MF = DAG.getMachineFunction();

  State.FloatPtr = Slot;
  State.FloatPointerInfo = MachinePointerInfo::getFixedStack(MF, FrameIdx);
  State.Chain = DAG.getStore(DAG.getEntryNode(), DL, Value, State.FloatPtr,
                             State.FloatPointerInfo);

  // Locate the byte holding the sign bit within the slot.
  if (DAG.getDataLayout().isBigEndian()) {
    assert(FPTy.isByteSized() && "Unsupported floating point type!");
    // The sign lives in the first byte, i.e. at the slot address itself.
    State.IntPtr = Slot;
    State.IntPointerInfo = State.FloatPointerInfo;
  } else {
    // The sign lives in the last byte of the value.
    const unsigned LastByte = (ScalarBits / 8) - 1;
    State.IntPtr =
        DAG.getMemBasePlusOffset(Slot, TypeSize::getFixed(LastByte), DL);
    State.IntPointerInfo =
        MachinePointerInfo::getFixedStack(MF, FrameIdx, LastByte);
  }

  // Load that single byte, extended to a register-sized integer; the sign is
  // bit 7 of the loaded byte.
  State.IntValue = DAG.getExtLoad(ISD::EXTLOAD, DL, ByteLoadTy, State.Chain,
                                  State.IntPtr, State.IntPointerInfo, MVT::i8);
  State.SignMask = APInt::getOneBitSet(ByteLoadTy.getScalarSizeInBits(), 7);
  State.SignBit = 7;
}

/// Rebuild a floating-point value after the integer obtained from
/// getSignAsIntValue() has been replaced by \p NewIntValue.
SDValue SelectionDAGLegalize::modifySignAsInt(const FloatSignAsInt &State,
                                              const SDLoc &DL,
                                              SDValue NewIntValue) const {
  if (State.Chain) {
    // The float went through a stack slot: overwrite the byte containing the
    // sign bit, then reload the complete value.
    SDValue PatchedSlot =
        DAG.getTruncStore(State.Chain, DL, NewIntValue, State.IntPtr,
                          State.IntPointerInfo, MVT::i8);
    return DAG.getLoad(State.FloatVT, DL, PatchedSlot, State.FloatPtr,
                       State.FloatPointerInfo);
  }

  // No chain means the integer was a plain bitcast of the float.
  return DAG.getNode(ISD::BITCAST, DL, State.FloatVT, NewIntValue);
}

SDValue SelectionDAGLegalize::ExpandFCOPYSIGN(SDNode *Node) const {
SDLoc DL(Node);
SDValue Mag = Node->getOperand(0);
SDValue Sign = Node->getOperand(1);

// Get sign bit into an integer value.
FloatSignAsInt SignAsInt;
getSignAsIntValue(SignAsInt, DL, Sign);
DAG.getSignAsIntValue(SignAsInt, DL, Sign);

EVT IntVT = SignAsInt.IntValue.getValueType();
SDValue SignMask = DAG.getConstant(SignAsInt.SignMask, DL, IntVT);
Expand All @@ -1716,7 +1630,7 @@ SDValue SelectionDAGLegalize::ExpandFCOPYSIGN(SDNode *Node) const {

// Transform Mag value to integer, and clear the sign bit.
FloatSignAsInt MagAsInt;
getSignAsIntValue(MagAsInt, DL, Mag);
DAG.getSignAsIntValue(MagAsInt, DL, Mag);
EVT MagVT = MagAsInt.IntValue.getValueType();
SDValue ClearSignMask = DAG.getConstant(~MagAsInt.SignMask, DL, MagVT);
SDValue ClearedSign = DAG.getNode(ISD::AND, DL, MagVT, MagAsInt.IntValue,
Expand Down Expand Up @@ -1746,14 +1660,14 @@ SDValue SelectionDAGLegalize::ExpandFCOPYSIGN(SDNode *Node) const {
SDValue CopiedSign = DAG.getNode(ISD::OR, DL, MagVT, ClearedSign, SignBit,
SDNodeFlags::Disjoint);

return modifySignAsInt(MagAsInt, DL, CopiedSign);
return DAG.modifySignAsInt(MagAsInt, DL, CopiedSign);
}

SDValue SelectionDAGLegalize::ExpandFNEG(SDNode *Node) const {
// Get the sign bit as an integer.
SDLoc DL(Node);
FloatSignAsInt SignAsInt;
getSignAsIntValue(SignAsInt, DL, Node->getOperand(0));
DAG.getSignAsIntValue(SignAsInt, DL, Node->getOperand(0));
EVT IntVT = SignAsInt.IntValue.getValueType();

// Flip the sign.
Expand All @@ -1762,7 +1676,7 @@ SDValue SelectionDAGLegalize::ExpandFNEG(SDNode *Node) const {
DAG.getNode(ISD::XOR, DL, IntVT, SignAsInt.IntValue, SignMask);

// Convert back to float.
return modifySignAsInt(SignAsInt, DL, SignFlip);
return DAG.modifySignAsInt(SignAsInt, DL, SignFlip);
}

SDValue SelectionDAGLegalize::ExpandFABS(SDNode *Node) const {
Expand All @@ -1778,12 +1692,12 @@ SDValue SelectionDAGLegalize::ExpandFABS(SDNode *Node) const {

// Transform value to integer, clear the sign bit and transform back.
FloatSignAsInt ValueAsInt;
getSignAsIntValue(ValueAsInt, DL, Value);
DAG.getSignAsIntValue(ValueAsInt, DL, Value);
EVT IntVT = ValueAsInt.IntValue.getValueType();
SDValue ClearSignMask = DAG.getConstant(~ValueAsInt.SignMask, DL, IntVT);
SDValue ClearedSign = DAG.getNode(ISD::AND, DL, IntVT, ValueAsInt.IntValue,
ClearSignMask);
return modifySignAsInt(ValueAsInt, DL, ClearedSign);
return DAG.modifySignAsInt(ValueAsInt, DL, ClearedSign);
}

void SelectionDAGLegalize::ExpandDYNAMIC_STACKALLOC(SDNode* Node,
Expand Down
60 changes: 60 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2760,6 +2760,66 @@ SDValue SelectionDAG::CreateStackTemporary(EVT VT1, EVT VT2) {
return CreateStackTemporary(Bytes, Align);
}

/// Expose the sign of \p Value as an integer and record how in \p State.
///
/// When the integer counterpart of the float type is legal, the whole value is
/// bitcast. Otherwise the float is spilled to a stack temporary and only the
/// byte that contains the sign bit is loaded back.
void SelectionDAG::getSignAsIntValue(FloatSignAsInt &State, const SDLoc &DL,
                                     SDValue Value) {
  EVT FPTy = Value.getValueType();
  const unsigned ScalarBits = FPTy.getScalarSizeInBits();
  State.FloatVT = FPTy;

  // Preferred route: the matching integer type is legal, so a bitcast is
  // enough and the sign is simply the top bit.
  EVT SameSizeIntTy = FPTy.changeTypeToInteger();
  if (TLI->isTypeLegal(SameSizeIntTy)) {
    State.IntValue = getNode(ISD::BITCAST, DL, SameSizeIntTy, Value);
    State.SignMask = APInt::getSignMask(ScalarBits);
    State.SignBit = ScalarBits - 1;
    return;
  }

  // Fallback route: go through memory. The slot is aligned for both the float
  // store and the integer load.
  MVT ByteLoadTy = TLI->getRegisterType(MVT::i8);
  SDValue Slot = CreateStackTemporary(FPTy, ByteLoadTy);
  const int FrameIdx = cast<FrameIndexSDNode>(Slot.getNode())->getIndex();
  MachineFunction &MF = getMachineFunction();

  State.FloatPtr = Slot;
  State.FloatPointerInfo = MachinePointerInfo::getFixedStack(MF, FrameIdx);
  State.Chain = getStore(getEntryNode(), DL, Value, State.FloatPtr,
                         State.FloatPointerInfo);

  // Locate the byte holding the sign bit within the slot.
  if (getDataLayout().isBigEndian()) {
    assert(FPTy.isByteSized() && "Unsupported floating point type!");
    // The sign lives in the first byte, i.e. at the slot address itself.
    State.IntPtr = Slot;
    State.IntPointerInfo = State.FloatPointerInfo;
  } else {
    // The sign lives in the last byte of the value.
    const unsigned LastByte = (ScalarBits / 8) - 1;
    State.IntPtr = getMemBasePlusOffset(Slot, TypeSize::getFixed(LastByte), DL);
    State.IntPointerInfo =
        MachinePointerInfo::getFixedStack(MF, FrameIdx, LastByte);
  }

  // Load that single byte, extended to a register-sized integer; the sign is
  // bit 7 of the loaded byte.
  State.IntValue = getExtLoad(ISD::EXTLOAD, DL, ByteLoadTy, State.Chain,
                              State.IntPtr, State.IntPointerInfo, MVT::i8);
  State.SignMask = APInt::getOneBitSet(ByteLoadTy.getScalarSizeInBits(), 7);
  State.SignBit = 7;
}

/// Rebuild a floating-point value after the integer obtained from
/// getSignAsIntValue() has been replaced by \p NewIntValue.
SDValue SelectionDAG::modifySignAsInt(const FloatSignAsInt &State,
                                      const SDLoc &DL, SDValue NewIntValue) {
  if (State.Chain) {
    // The float went through a stack slot: overwrite the byte containing the
    // sign bit, then reload the complete value.
    SDValue PatchedSlot =
        getTruncStore(State.Chain, DL, NewIntValue, State.IntPtr,
                      State.IntPointerInfo, MVT::i8);
    return getLoad(State.FloatVT, DL, PatchedSlot, State.FloatPtr,
                   State.FloatPointerInfo);
  }

  // No chain means the integer was a plain bitcast of the float.
  return getNode(ISD::BITCAST, DL, State.FloatVT, NewIntValue);
}

SDValue SelectionDAG::FoldSetCC(EVT VT, SDValue N1, SDValue N2,
ISD::CondCode Cond, const SDLoc &dl) {
EVT OpVT = N1.getValueType();
Expand Down
18 changes: 10 additions & 8 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8610,16 +8610,18 @@ SDValue TargetLowering::expandFMINIMUM_FMAXIMUM(SDNode *N,
// fminimum/fmaximum requires -0.0 less than +0.0
if (!MinMaxMustRespectOrderedZero && !N->getFlags().hasNoSignedZeros() &&
!DAG.isKnownNeverZeroFloat(RHS) && !DAG.isKnownNeverZeroFloat(LHS)) {
auto IsSpecificZero = [&](SDValue F) {
FloatSignAsInt State;
DAG.getSignAsIntValue(State, DL, F);
return DAG.getSetCC(DL, CCVT, State.IntValue,
DAG.getConstant(0, DL, State.IntValue.getValueType()),
IsMax ? ISD::SETEQ : ISD::SETNE);
};
SDValue IsZero = DAG.getSetCC(DL, CCVT, MinMax,
DAG.getConstantFP(0.0, DL, VT), ISD::SETOEQ);
SDValue TestZero =
DAG.getTargetConstant(IsMax ? fcPosZero : fcNegZero, DL, MVT::i32);
SDValue LCmp = DAG.getSelect(
DL, VT, DAG.getNode(ISD::IS_FPCLASS, DL, CCVT, LHS, TestZero), LHS,
MinMax, Flags);
SDValue RCmp = DAG.getSelect(
DL, VT, DAG.getNode(ISD::IS_FPCLASS, DL, CCVT, RHS, TestZero), RHS,
LCmp, Flags);
SDValue LCmp =
DAG.getSelect(DL, VT, IsSpecificZero(LHS), LHS, MinMax, Flags);
SDValue RCmp = DAG.getSelect(DL, VT, IsSpecificZero(RHS), RHS, LCmp, Flags);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact, we need only one IsSpecificZero here.
When Max, if LHS is positive, we can select it.
When Min, if LHS is positive, we can just select MinMax, since we have already preferred RHS if LHS==RHS.

Copy link
Contributor Author

@nikic nikic Jun 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I submitted #142732 to drop one of the comparisons.

When Min, if LHS is positive, we can just select MinMax, since when have already preferred RHS if LHS==RHS.

This is only guaranteed if the main legalization used the select form, rather than one of the min/max opcodes. But we also don't need to rely on it as we can directly select RHS instead of MinMax.

MinMax = DAG.getSelect(DL, VT, IsZero, RCmp, MinMax, Flags);
}

Expand Down
56 changes: 56 additions & 0 deletions llvm/test/CodeGen/AArch64/fmaximum-legalization.ll
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,59 @@ define <4 x half> @fmaximum_v4f16(<4 x half> %x, <4 x half> %y) {
%r = call <4 x half> @llvm.maximum.v4f16(<4 x half> %x, <4 x half> %y)
ret <4 x half> %r
}

; llvm.maximum on fp128. In the expected output the comparisons are done
; through the __gttf2 / __unordtf2 / __eqtf2 libcalls, and a single byte of
; each operand's stored copy is read back with ldrb and compared against zero.
define fp128 @maximum_fp128(fp128 %x, fp128 %y) nounwind {
; CHECK-LABEL: maximum_fp128:
; CHECK: // %bb.0:
; CHECK-NEXT: sub sp, sp, #96
; CHECK-NEXT: str x30, [sp, #80] // 8-byte Folded Spill
; CHECK-NEXT: stp q0, q1, [sp] // 32-byte Folded Spill
; CHECK-NEXT: stp q1, q0, [sp, #48]
; CHECK-NEXT: bl __gttf2
; CHECK-NEXT: ldp q0, q1, [sp] // 32-byte Folded Reload
; CHECK-NEXT: cmp w0, #0
; CHECK-NEXT: b.le .LBB1_2
; CHECK-NEXT: // %bb.1:
; CHECK-NEXT: mov v1.16b, v0.16b
; CHECK-NEXT: .LBB1_2:
; CHECK-NEXT: str q1, [sp, #32] // 16-byte Folded Spill
; CHECK-NEXT: ldr q1, [sp, #16] // 16-byte Folded Reload
; CHECK-NEXT: bl __unordtf2
; CHECK-NEXT: ldr q0, [sp, #32] // 16-byte Folded Reload
; CHECK-NEXT: cmp w0, #0
; CHECK-NEXT: b.eq .LBB1_4
; CHECK-NEXT: // %bb.3:
; CHECK-NEXT: adrp x8, .LCPI1_0
; CHECK-NEXT: ldr q0, [x8, :lo12:.LCPI1_0]
; CHECK-NEXT: .LBB1_4:
; CHECK-NEXT: ldrb w8, [sp, #79]
; CHECK-NEXT: mov v1.16b, v0.16b
; CHECK-NEXT: cmp w8, #0
; CHECK-NEXT: b.ne .LBB1_6
; CHECK-NEXT: // %bb.5:
; CHECK-NEXT: ldr q1, [sp] // 16-byte Folded Reload
; CHECK-NEXT: .LBB1_6:
; CHECK-NEXT: ldrb w8, [sp, #63]
; CHECK-NEXT: cmp w8, #0
; CHECK-NEXT: b.ne .LBB1_8
; CHECK-NEXT: // %bb.7:
; CHECK-NEXT: ldr q1, [sp, #16] // 16-byte Folded Reload
; CHECK-NEXT: .LBB1_8:
; CHECK-NEXT: adrp x8, .LCPI1_1
; CHECK-NEXT: str q0, [sp, #32] // 16-byte Folded Spill
; CHECK-NEXT: str q1, [sp, #16] // 16-byte Folded Spill
; CHECK-NEXT: ldr q1, [x8, :lo12:.LCPI1_1]
; CHECK-NEXT: ldr q0, [sp, #32] // 16-byte Folded Reload
; CHECK-NEXT: bl __eqtf2
; CHECK-NEXT: ldr q0, [sp, #32] // 16-byte Folded Reload
; CHECK-NEXT: cmp w0, #0
; CHECK-NEXT: b.ne .LBB1_10
; CHECK-NEXT: // %bb.9:
; CHECK-NEXT: ldr q0, [sp, #16] // 16-byte Folded Reload
; CHECK-NEXT: .LBB1_10:
; CHECK-NEXT: ldr x30, [sp, #80] // 8-byte Folded Reload
; CHECK-NEXT: add sp, sp, #96
; CHECK-NEXT: ret
%res = call fp128 @llvm.maximum.f128(fp128 %x, fp128 %y)
ret fp128 %res
}
Loading
Loading