@@ -164,6 +164,7 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
164164 addRegisterClass(MVT::v2bf16, &AMDGPU::SReg_32RegClass);
165165 addRegisterClass(MVT::v4i16, &AMDGPU::SReg_64RegClass);
166166 addRegisterClass(MVT::v4f16, &AMDGPU::SReg_64RegClass);
167+ addRegisterClass(MVT::v4bf16, &AMDGPU::SReg_64RegClass);
167168 addRegisterClass(MVT::v8i16, &AMDGPU::SGPR_128RegClass);
168169 addRegisterClass(MVT::v8f16, &AMDGPU::SGPR_128RegClass);
169170 addRegisterClass(MVT::v16i16, &AMDGPU::SGPR_256RegClass);
@@ -312,10 +313,10 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
312313 {MVT::v8i32, MVT::v8f32, MVT::v9i32, MVT::v9f32, MVT::v10i32,
313314 MVT::v10f32, MVT::v11i32, MVT::v11f32, MVT::v12i32, MVT::v12f32,
314315 MVT::v16i32, MVT::v16f32, MVT::v2i64, MVT::v2f64, MVT::v4i16,
315- MVT::v4f16, MVT::v3i64, MVT::v3f64 , MVT::v6i32 , MVT::v6f32 ,
316- MVT::v4i64 , MVT::v4f64 , MVT::v8i64 , MVT::v8f64 , MVT::v8i16 ,
317- MVT::v8f16 , MVT::v16i16, MVT::v16f16 , MVT::v16i64 , MVT::v16f64 ,
318- MVT::v32i32, MVT::v32f32, MVT::v32i16, MVT::v32f16}) {
316+ MVT::v4f16, MVT::v4bf16, MVT::v3i64 , MVT::v3f64 , MVT::v6i32 ,
317+ MVT::v6f32 , MVT::v4i64 , MVT::v4f64 , MVT::v8i64 , MVT::v8f64 ,
318+ MVT::v8i16 , MVT::v8f16, MVT::v16i16 , MVT::v16f16 , MVT::v16i64 ,
319+ MVT::v16f64, MVT:: v32i32, MVT::v32f32, MVT::v32i16, MVT::v32f16}) {
319320 for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op) {
320321 switch (Op) {
321322 case ISD::LOAD:
@@ -421,13 +422,14 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
421422 {MVT::v8i32, MVT::v8f32, MVT::v16i32, MVT::v16f32},
422423 Expand);
423424
424- setOperationAction(ISD::BUILD_VECTOR, {MVT::v4f16, MVT::v4i16}, Custom);
425+ setOperationAction(ISD::BUILD_VECTOR, {MVT::v4f16, MVT::v4i16, MVT::v4bf16},
426+ Custom);
425427
426428 // Avoid stack access for these.
427429 // TODO: Generalize to more vector types.
428430 setOperationAction({ISD::EXTRACT_VECTOR_ELT, ISD::INSERT_VECTOR_ELT},
429431 {MVT::v2i16, MVT::v2f16, MVT::v2bf16, MVT::v2i8, MVT::v4i8,
430- MVT::v8i8, MVT::v4i16, MVT::v4f16},
432+ MVT::v8i8, MVT::v4i16, MVT::v4f16, MVT::v4bf16 },
431433 Custom);
432434
433435 // Deal with vec3 vector operations when widened to vec4.
@@ -667,11 +669,15 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
667669 AddPromotedToType(ISD::LOAD, MVT::v4i16, MVT::v2i32);
668670 setOperationAction(ISD::LOAD, MVT::v4f16, Promote);
669671 AddPromotedToType(ISD::LOAD, MVT::v4f16, MVT::v2i32);
672+ setOperationAction(ISD::LOAD, MVT::v4bf16, Promote);
673+ AddPromotedToType(ISD::LOAD, MVT::v4bf16, MVT::v2i32);
670674
671675 setOperationAction(ISD::STORE, MVT::v4i16, Promote);
672676 AddPromotedToType(ISD::STORE, MVT::v4i16, MVT::v2i32);
673677 setOperationAction(ISD::STORE, MVT::v4f16, Promote);
674678 AddPromotedToType(ISD::STORE, MVT::v4f16, MVT::v2i32);
679+ setOperationAction(ISD::STORE, MVT::v4bf16, Promote);
680+ AddPromotedToType(ISD::STORE, MVT::v4bf16, MVT::v2i32);
675681
676682 setOperationAction(ISD::LOAD, MVT::v8i16, Promote);
677683 AddPromotedToType(ISD::LOAD, MVT::v8i16, MVT::v4i32);
@@ -781,7 +787,8 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
781787 Custom);
782788
783789 setOperationAction(ISD::FEXP, MVT::v2f16, Custom);
784- setOperationAction(ISD::SELECT, {MVT::v4i16, MVT::v4f16}, Custom);
790+ setOperationAction(ISD::SELECT, {MVT::v4i16, MVT::v4f16, MVT::v4bf16},
791+ Custom);
785792
786793 if (Subtarget->hasPackedFP32Ops()) {
787794 setOperationAction({ISD::FADD, ISD::FMUL, ISD::FMA, ISD::FNEG},
@@ -6805,7 +6812,7 @@ SDValue SITargetLowering::lowerBUILD_VECTOR(SDValue Op,
68056812 SDLoc SL(Op);
68066813 EVT VT = Op.getValueType();
68076814
6808- if (VT == MVT::v4i16 || VT == MVT::v4f16 ||
6815+ if (VT == MVT::v4i16 || VT == MVT::v4f16 || VT == MVT::v4bf16 ||
68096816 VT == MVT::v8i16 || VT == MVT::v8f16) {
68106817 EVT HalfVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(),
68116818 VT.getVectorNumElements() / 2);
@@ -6871,7 +6878,7 @@ SDValue SITargetLowering::lowerBUILD_VECTOR(SDValue Op,
68716878 return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
68726879 }
68736880
6874- assert(VT == MVT::v2f16 || VT == MVT::v2i16);
6881+ assert(VT == MVT::v2f16 || VT == MVT::v2i16 || VT == MVT::v2bf16 );
68756882 assert(!Subtarget->hasVOP3PInsts() && "this should be legal");
68766883
68776884 SDValue Lo = Op.getOperand(0);
0 commit comments