Skip to content

Commit 20ad420

Browse files
authored
AMDGPU: Improve v4f16/v4bf16 copysign handling (#142174)
1 parent 4aa4005 commit 20ad420

File tree

3 files changed: 2010 additions and 2478 deletions.

3 files changed

+2010
-2478
lines changed

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 13 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -757,7 +757,9 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
757757
setOperationAction(ISD::FABS, MVT::v2f16, Legal);
758758

759759
// Can do this in one BFI plus a constant materialize.
760-
setOperationAction(ISD::FCOPYSIGN, {MVT::v2f16, MVT::v2bf16}, Custom);
760+
setOperationAction(ISD::FCOPYSIGN,
761+
{MVT::v2f16, MVT::v2bf16, MVT::v4f16, MVT::v4bf16},
762+
Custom);
761763

762764
setOperationAction({ISD::FMAXNUM, ISD::FMINNUM}, MVT::f16, Custom);
763765
setOperationAction({ISD::FMAXNUM_IEEE, ISD::FMINNUM_IEEE}, MVT::f16, Legal);
@@ -5936,10 +5938,11 @@ SDValue SITargetLowering::splitBinaryVectorOp(SDValue Op,
59365938
SelectionDAG &DAG) const {
59375939
unsigned Opc = Op.getOpcode();
59385940
EVT VT = Op.getValueType();
5939-
assert(VT == MVT::v4i16 || VT == MVT::v4f16 || VT == MVT::v4f32 ||
5940-
VT == MVT::v8i16 || VT == MVT::v8f16 || VT == MVT::v16i16 ||
5941-
VT == MVT::v16f16 || VT == MVT::v8f32 || VT == MVT::v16f32 ||
5942-
VT == MVT::v32f32 || VT == MVT::v32i16 || VT == MVT::v32f16);
5941+
assert(VT == MVT::v4i16 || VT == MVT::v4f16 || VT == MVT::v4bf16 ||
5942+
VT == MVT::v4f32 || VT == MVT::v8i16 || VT == MVT::v8f16 ||
5943+
VT == MVT::v16i16 || VT == MVT::v16f16 || VT == MVT::v8f32 ||
5944+
VT == MVT::v16f32 || VT == MVT::v32f32 || VT == MVT::v32i16 ||
5945+
VT == MVT::v32f16);
59435946

59445947
auto [Lo0, Hi0] = DAG.SplitVectorOperand(Op.getNode(), 0);
59455948
auto [Lo1, Hi1] = DAG.SplitVectorOperand(Op.getNode(), 1);
@@ -7122,18 +7125,17 @@ SDValue SITargetLowering::promoteUniformOpToI32(SDValue Op,
71227125

71237126
SDValue SITargetLowering::lowerFCOPYSIGN(SDValue Op, SelectionDAG &DAG) const {
71247127
SDValue Mag = Op.getOperand(0);
7125-
SDValue Sign = Op.getOperand(1);
7126-
71277128
EVT MagVT = Mag.getValueType();
7128-
EVT SignVT = Sign.getValueType();
71297129

7130-
assert(MagVT.isVector());
7130+
if (MagVT.getVectorNumElements() > 2)
7131+
return splitBinaryVectorOp(Op, DAG);
7132+
7133+
SDValue Sign = Op.getOperand(1);
7134+
EVT SignVT = Sign.getValueType();
71317135

71327136
if (MagVT == SignVT)
71337137
return Op;
71347138

7135-
assert(MagVT.getVectorNumElements() == 2);
7136-
71377139
// fcopysign v2f16:mag, v2f32:sign ->
71387140
// fcopysign v2f16:mag, bitcast (trunc (bitcast sign to v2i32) to v2i16)
71397141

0 commit comments

Comments
 (0)