@@ -757,7 +757,9 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
757
757
setOperationAction(ISD::FABS, MVT::v2f16, Legal);
758
758
759
759
// Can do this in one BFI plus a constant materialize.
760
- setOperationAction(ISD::FCOPYSIGN, {MVT::v2f16, MVT::v2bf16}, Custom);
760
+ setOperationAction(ISD::FCOPYSIGN,
761
+ {MVT::v2f16, MVT::v2bf16, MVT::v4f16, MVT::v4bf16},
762
+ Custom);
761
763
762
764
setOperationAction({ISD::FMAXNUM, ISD::FMINNUM}, MVT::f16, Custom);
763
765
setOperationAction({ISD::FMAXNUM_IEEE, ISD::FMINNUM_IEEE}, MVT::f16, Legal);
@@ -5936,10 +5938,11 @@ SDValue SITargetLowering::splitBinaryVectorOp(SDValue Op,
5936
5938
SelectionDAG &DAG) const {
5937
5939
unsigned Opc = Op.getOpcode();
5938
5940
EVT VT = Op.getValueType();
5939
- assert(VT == MVT::v4i16 || VT == MVT::v4f16 || VT == MVT::v4f32 ||
5940
- VT == MVT::v8i16 || VT == MVT::v8f16 || VT == MVT::v16i16 ||
5941
- VT == MVT::v16f16 || VT == MVT::v8f32 || VT == MVT::v16f32 ||
5942
- VT == MVT::v32f32 || VT == MVT::v32i16 || VT == MVT::v32f16);
5941
+ assert(VT == MVT::v4i16 || VT == MVT::v4f16 || VT == MVT::v4bf16 ||
5942
+ VT == MVT::v4f32 || VT == MVT::v8i16 || VT == MVT::v8f16 ||
5943
+ VT == MVT::v16i16 || VT == MVT::v16f16 || VT == MVT::v8f32 ||
5944
+ VT == MVT::v16f32 || VT == MVT::v32f32 || VT == MVT::v32i16 ||
5945
+ VT == MVT::v32f16);
5943
5946
5944
5947
auto [Lo0, Hi0] = DAG.SplitVectorOperand(Op.getNode(), 0);
5945
5948
auto [Lo1, Hi1] = DAG.SplitVectorOperand(Op.getNode(), 1);
@@ -7122,18 +7125,17 @@ SDValue SITargetLowering::promoteUniformOpToI32(SDValue Op,
7122
7125
7123
7126
SDValue SITargetLowering::lowerFCOPYSIGN(SDValue Op, SelectionDAG &DAG) const {
7124
7127
SDValue Mag = Op.getOperand(0);
7125
- SDValue Sign = Op.getOperand(1);
7126
-
7127
7128
EVT MagVT = Mag.getValueType();
7128
- EVT SignVT = Sign.getValueType();
7129
7129
7130
- assert(MagVT.isVector());
7130
+ if (MagVT.getVectorNumElements() > 2)
7131
+ return splitBinaryVectorOp(Op, DAG);
7132
+
7133
+ SDValue Sign = Op.getOperand(1);
7134
+ EVT SignVT = Sign.getValueType();
7131
7135
7132
7136
if (MagVT == SignVT)
7133
7137
return Op;
7134
7138
7135
- assert(MagVT.getVectorNumElements() == 2);
7136
-
7137
7139
// fcopysign v2f16:mag, v2f32:sign ->
7138
7140
// fcopysign v2f16:mag, bitcast (trunc (bitcast sign to v2i32) to v2i16)
7139
7141
0 commit comments