Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LoongArch] Pass 'half' in the lower 16 bits of an f32 value with F/D ABI #109368

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 140 additions & 6 deletions llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,8 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FSINCOS, MVT::f32, Expand);
setOperationAction(ISD::FPOW, MVT::f32, Expand);
setOperationAction(ISD::FREM, MVT::f32, Expand);
setOperationAction(ISD::FP16_TO_FP, MVT::f32, Expand);
setOperationAction(ISD::FP_TO_FP16, MVT::f32, Expand);
setOperationAction(ISD::FP16_TO_FP, MVT::f32, Custom);
setOperationAction(ISD::FP_TO_FP16, MVT::f32, Custom);

if (Subtarget.is64Bit())
setOperationAction(ISD::FRINT, MVT::f32, Legal);
Expand Down Expand Up @@ -219,7 +219,7 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FPOW, MVT::f64, Expand);
setOperationAction(ISD::FREM, MVT::f64, Expand);
setOperationAction(ISD::FP16_TO_FP, MVT::f64, Expand);
setOperationAction(ISD::FP_TO_FP16, MVT::f64, Expand);
setOperationAction(ISD::FP_TO_FP16, MVT::f64, Custom);

if (Subtarget.is64Bit())
setOperationAction(ISD::FRINT, MVT::f64, Legal);
Expand Down Expand Up @@ -427,6 +427,10 @@ SDValue LoongArchTargetLowering::LowerOperation(SDValue Op,
return lowerBUILD_VECTOR(Op, DAG);
case ISD::VECTOR_SHUFFLE:
return lowerVECTOR_SHUFFLE(Op, DAG);
case ISD::FP_TO_FP16:
return lowerFP_TO_FP16(Op, DAG);
case ISD::FP16_TO_FP:
return lowerFP16_TO_FP(Op, DAG);
}
return SDValue();
}
Expand Down Expand Up @@ -1354,6 +1358,40 @@ SDValue LoongArchTargetLowering::lowerVECTOR_SHUFFLE(SDValue Op,
return SDValue();
}

SDValue LoongArchTargetLowering::lowerFP_TO_FP16(SDValue Op,
SelectionDAG &DAG) const {
// Custom lower to ensure the libcall return is passed in an FPR on hard
// float ABIs.
SDLoc DL(Op);
MakeLibCallOptions CallOptions;
SDValue Op0 = Op.getOperand(0);
SDValue Chain = SDValue();
RTLIB::Libcall LC = RTLIB::getFPROUND(Op0.getValueType(), MVT::f16);
SDValue Res;
std::tie(Res, Chain) =
makeLibCall(DAG, LC, MVT::f32, Op0, CallOptions, DL, Chain);
if (Subtarget.is64Bit())
return DAG.getNode(LoongArchISD::MOVFR2GR_S_LA64, DL, MVT::i64, Res);
return DAG.getBitcast(MVT::i32, Res);
}

SDValue LoongArchTargetLowering::lowerFP16_TO_FP(SDValue Op,
SelectionDAG &DAG) const {
// Custom lower to ensure the libcall argument is passed in an FPR on hard
// float ABIs.
Comment on lines +1380 to +1381
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You shouldn't need to custom lower the casts to change the ABI. Are you trying to special case the ABI for this one call in this one instance? That seems bad

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're not trying to special case the ABI, but rather ensuring compliance with the ABI rules for floating-point operations. Specifically, the argument for f32 __gnu_h2f_ieee(f16) needs to be passed via the FPR, as per the floating-point ABI, rather than the GPR. Custom lowering ensures that the argument is correctly passed through the FPR in cases where the default behavior doesn't align with this requirement. Thanks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change does not accomplish this. The cast opcodes have nothing to do with the ABI, other than ABI code may result in inserting them

Copy link
Member Author

@heiher heiher Sep 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default behavior of softPromoteHalf does not align with the expectations of the architecture, the custom lowering as referenced by the approach used in RISC-V.

Ref: https://reviews.llvm.org/D151284

This legalisation produces ISD::FP_TO_FP16 and ISD::FP16_TO_FP nodes which (as described in ISDOpcodes.h) provide a "semi-softened interface for dealing with f16 (as an i16)". i.e. the return type of the FP_TO_FP16 is an integer rather than a float (and the arg of FP16_TO_FP is an integer). The remainder of the description focuses primarily on FP_TO_FP16 for ease of explanation.

In Rust's implementation, the argument of __gnu_h2f_ieee is f16, not i16.

https://github.com/rust-lang/compiler-builtins/blob/compiler_builtins-v0.1.126/src/float/extend.rs#L95-L97

pub extern "C" fn __gnu_h2f_ieee(a: f16) -> f32 {
    extend(a)
}

Is there a better approach to achieve this?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In other words, you are special casing the ABI for this one libcall that happens to be used for legalization of the conversions. If you are fixing the ABI, as the title suggests, you need to make changes to the calling convention lowering code and possibly use addRegisterClass for f16.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By default the half is going to be be pre-promoted to a legal f32 type and you need to intervene before that. I would start by overriding getRegisterTypeForCallingConv, and then see how that goes. You may need to just custom hack on the argument lists in each of the Lower* functions.

Yes SelectionDAG makes this more difficult than it should be. It would be much easier if calling convention code operated on the raw IR signature instead of going through type legalization first

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your guidance. I haven't quite managed to get it done yet. To pass fp16 in FPR (excluding libcall), splitValueIntoRegisterParts and joinRegisterPartsIntoValue insert all ones in the upper bits and extract the lower bits by casting to an integer type. I guess this is the key point where the fp16 value is promoted to an integer when it reaches the FP16_TO_FP operation, but I'm not sure how to bypass the integer type to achieve this.

Additionally, it seems that custom-lowering FP16_TO_FP and FP_TO_FP16 to generate a libcall while keeping it passed in FPR works quite well and is fairly easy to implement, RISC-V is already using this approach. Can we go ahead with this? Thanks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not really correct. It only happens to work out, and MOVFR2GR_S_LA64 is a hack

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In Rust's implementation, the argument of __gnu_h2f_ieee is f16, not i16.

https://github.com/rust-lang/compiler-builtins/blob/compiler_builtins-v0.1.126/src/float/extend.rs#L95-L97

pub extern "C" fn __gnu_h2f_ieee(a: f16) -> f32 {
    extend(a)
}

Just to address this (quite late) - aiui __gnu_h2f_ieee is only called on platforms where f16 is passed as an integer so that doesn't matter here. If there is a float ABI then __extendhfsf2/__truncsfhf2 is used instead.

There is something weird with the return value though, Rust's compiler-builtins and LLVM's compiler-rt both use f32 but GCC uses int. (I think maybe GCC just never emits that libcall except on ARM, LLVM seems to use it a lot more).

Copy link
Member Author

@heiher heiher Oct 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my view, __gnu_h2f_ieee is primarily designed for converting f16 values to f32, particularly in scenarios where hardware lacks native f16 support, requiring software emulation instead. This function does not inherently define any specific argument-passing conventions; rather, those are determined by the architecture's ABI. For instance, on RISC-V, the lower 16 bits of a floating point argument register are used (given that hardware support for f16 is not enabled), while some other architectures use integer registers. In terms of implementation, __extendhfsf2 serves as an alias for __gnu_h2f_ieee and are used by arm.

https://github.com/rust-lang/compiler-builtins/blob/compiler_builtins-v0.1.126/src/float/extend.rs#L87-L89

pub extern "C" fn __extendhfsf2(a: f16) -> f32 {
    extend(a)
}

https://github.com/rust-lang/compiler-builtins/blob/compiler_builtins-v0.1.126/src/float/extend.rs#L95-L97

pub extern "C" fn __gnu_h2f_ieee(a: f16) -> f32 {
    extend(a)
}

https://github.com/llvm/llvm-project/blob/llvmorg-19.1.2/compiler-rt/lib/builtins/extendhfsf2.c#L13-L19

// Use a forwarding definition and noinline to implement a poor man's alias,
// as there isn't a good cross-platform way of defining one.
COMPILER_RT_ABI NOINLINE float __extendhfsf2(src_t a) {
  return __extendXfYf2__(a);
}

COMPILER_RT_ABI float __gnu_h2f_ieee(src_t a) { return __extendhfsf2(a); }

SDLoc DL(Op);
MakeLibCallOptions CallOptions;
SDValue Op0 = Op.getOperand(0);
SDValue Chain = SDValue();
SDValue Arg = Subtarget.is64Bit() ? DAG.getNode(LoongArchISD::MOVGR2FR_W_LA64,
DL, MVT::f32, Op0)
: DAG.getBitcast(MVT::f32, Op0);
SDValue Res;
std::tie(Res, Chain) = makeLibCall(DAG, RTLIB::FPEXT_F16_F32, MVT::f32, Arg,
CallOptions, DL, Chain);
return Res;
}

static bool isConstantOrUndef(const SDValue Op) {
if (Op->isUndef())
return true;
Expand Down Expand Up @@ -1656,16 +1694,19 @@ SDValue LoongArchTargetLowering::lowerFP_TO_SINT(SDValue Op,
SelectionDAG &DAG) const {

SDLoc DL(Op);
SDValue Op0 = Op.getOperand(0);

if (Op0.getValueType() == MVT::f16)
Op0 = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Op0);

if (Op.getValueSizeInBits() > 32 && Subtarget.hasBasicF() &&
!Subtarget.hasBasicD()) {
SDValue Dst =
DAG.getNode(LoongArchISD::FTINT, DL, MVT::f32, Op.getOperand(0));
SDValue Dst = DAG.getNode(LoongArchISD::FTINT, DL, MVT::f32, Op0);
return DAG.getNode(LoongArchISD::MOVFR2GR_S_LA64, DL, MVT::i64, Dst);
}

EVT FPTy = EVT::getFloatingPointVT(Op.getValueSizeInBits());
SDValue Trunc = DAG.getNode(LoongArchISD::FTINT, DL, FPTy, Op.getOperand(0));
SDValue Trunc = DAG.getNode(LoongArchISD::FTINT, DL, FPTy, Op0);
return DAG.getNode(ISD::BITCAST, DL, Op.getValueType(), Trunc);
}

Expand Down Expand Up @@ -2848,6 +2889,10 @@ void LoongArchTargetLowering::ReplaceNodeResults(
EVT FVT = EVT::getFloatingPointVT(N->getValueSizeInBits(0));
if (getTypeAction(*DAG.getContext(), Src.getValueType()) !=
TargetLowering::TypeSoftenFloat) {
if (!isTypeLegal(Src.getValueType()))
return;
if (Src.getValueType() == MVT::f16)
Src = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Src);
SDValue Dst = DAG.getNode(LoongArchISD::FTINT, DL, FVT, Src);
Results.push_back(DAG.getNode(ISD::BITCAST, DL, VT, Dst));
return;
Expand Down Expand Up @@ -4229,6 +4274,33 @@ performINTRINSIC_WO_CHAINCombine(SDNode *N, SelectionDAG &DAG,
return SDValue();
}

static SDValue performMOVGR2FR_WCombine(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const LoongArchSubtarget &Subtarget) {
// If the input to MOVGR2FR_W_LA64 is just MOVFR2GR_S_LA64 the the
// conversion is unnecessary and can be replaced with the
// MOVFR2GR_S_LA64 operand.
SDValue Op0 = N->getOperand(0);
if (Op0.getOpcode() == LoongArchISD::MOVFR2GR_S_LA64)
return Op0.getOperand(0);
return SDValue();
}

static SDValue performMOVFR2GR_SCombine(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const LoongArchSubtarget &Subtarget) {
// If the input to MOVFR2GR_S_LA64 is just MOVGR2FR_W_LA64 then the
// conversion is unnecessary and can be replaced with the MOVGR2FR_W_LA64
// operand.
SDValue Op0 = N->getOperand(0);
MVT VT = N->getSimpleValueType(0);
if (Op0->getOpcode() == LoongArchISD::MOVGR2FR_W_LA64) {
assert(Op0.getOperand(0).getValueType() == VT && "Unexpected value type!");
return Op0.getOperand(0);
}
return SDValue();
}

SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
SelectionDAG &DAG = DCI.DAG;
Expand All @@ -4247,6 +4319,10 @@ SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
return performBITREV_WCombine(N, DAG, DCI, Subtarget);
case ISD::INTRINSIC_WO_CHAIN:
return performINTRINSIC_WO_CHAINCombine(N, DAG, DCI, Subtarget);
case LoongArchISD::MOVGR2FR_W_LA64:
return performMOVGR2FR_WCombine(N, DAG, DCI, Subtarget);
case LoongArchISD::MOVFR2GR_S_LA64:
return performMOVFR2GR_SCombine(N, DAG, DCI, Subtarget);
}
return SDValue();
}
Expand Down Expand Up @@ -6260,3 +6336,61 @@ bool LoongArchTargetLowering::shouldAlignPointerArgs(CallInst *CI,

return true;
}

bool LoongArchTargetLowering::splitValueIntoRegisterParts(
SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
bool IsABIRegCopy = CC.has_value();
EVT ValueVT = Val.getValueType();

if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
// Cast the f16 to i16, extend to i32, pad with ones to make a float
// nan, and cast to f32.
Val = DAG.getNode(ISD::BITCAST, DL, MVT::i16, Val);
Val = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Val);
Val = DAG.getNode(ISD::OR, DL, MVT::i32, Val,
DAG.getConstant(0xFFFF0000, DL, MVT::i32));
Val = DAG.getNode(ISD::BITCAST, DL, MVT::f32, Val);
Parts[0] = Val;
return true;
}

return false;
}

SDValue LoongArchTargetLowering::joinRegisterPartsIntoValue(
SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts, unsigned NumParts,
MVT PartVT, EVT ValueVT, std::optional<CallingConv::ID> CC) const {
bool IsABIRegCopy = CC.has_value();

if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
SDValue Val = Parts[0];

// Cast the f32 to i32, truncate to i16, and cast back to f16.
Val = DAG.getNode(ISD::BITCAST, DL, MVT::i32, Val);
Val = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, Val);
Val = DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);
return Val;
}

return SDValue();
}

MVT LoongArchTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
CallingConv::ID CC,
EVT VT) const {
// Use f32 to pass f16.
if (VT == MVT::f16 && Subtarget.hasBasicF())
return MVT::f32;

return TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT);
}

unsigned LoongArchTargetLowering::getNumRegistersForCallingConv(
LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
// Use f32 to pass f16.
if (VT == MVT::f16 && Subtarget.hasBasicF())
return 1;

return TargetLowering::getNumRegistersForCallingConv(Context, CC, VT);
}
24 changes: 24 additions & 0 deletions llvm/lib/Target/LoongArch/LoongArchISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,8 @@ class LoongArchTargetLowering : public TargetLowering {
SDValue lowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerFP_TO_FP16(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerFP16_TO_FP(SDValue Op, SelectionDAG &DAG) const;

bool isFPImmLegal(const APFloat &Imm, EVT VT,
bool ForCodeSize) const override;
Expand All @@ -339,6 +341,28 @@ class LoongArchTargetLowering : public TargetLowering {
const SmallVectorImpl<CCValAssign> &ArgLocs) const;

bool softPromoteHalfType() const override { return true; }

bool
splitValueIntoRegisterParts(SelectionDAG &DAG, const SDLoc &DL, SDValue Val,
SDValue *Parts, unsigned NumParts, MVT PartVT,
std::optional<CallingConv::ID> CC) const override;

SDValue
joinRegisterPartsIntoValue(SelectionDAG &DAG, const SDLoc &DL,
const SDValue *Parts, unsigned NumParts,
MVT PartVT, EVT ValueVT,
std::optional<CallingConv::ID> CC) const override;

/// Return the register type for a given MVT, ensuring vectors are treated
/// as a series of gpr sized integers.
MVT getRegisterTypeForCallingConv(LLVMContext &Context, CallingConv::ID CC,
EVT VT) const override;

/// Return the number of registers for a given MVT, ensuring vectors are
/// treated as a series of gpr sized integers.
unsigned getNumRegistersForCallingConv(LLVMContext &Context,
CallingConv::ID CC,
EVT VT) const override;
};

} // end namespace llvm
Expand Down
Loading
Loading