Skip to content

Commit

Permalink
[AArch64] Add some basic handling for bf16 constants.
Browse files Browse the repository at this point in the history
This adds some basic handling for bf16 constants, attempting to treat them a
lot like fp16 constants where it can. Zero immediates get lowered to FMOVH0,
others either get lowered to FMOVWHr(MOVi32imm) or use FMOVHi if they can.
Without fp16 they get expanded. This may not always be optimal, but fixes a gap
in our lowering. See llvm/test/CodeGen/AArch64/f16-imm.ll for the equivalent
fp16 test.

Differential Revision: https://reviews.llvm.org/D156649
  • Loading branch information
davemgreen committed Jul 31, 2023
1 parent e9bdf4a commit 778fa4e
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 4 deletions.
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/TargetLoweringBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -921,7 +921,7 @@ void TargetLoweringBase::initActions() {
// Legal, in which case all fp constants are legal, or use isFPImmLegal()
// to optimize expansions for certain constants.
setOperationAction(ISD::ConstantFP,
{MVT::f16, MVT::f32, MVT::f64, MVT::f80, MVT::f128},
{MVT::bf16, MVT::f16, MVT::f32, MVT::f64, MVT::f80, MVT::f128},
Expand);

// These library functions default to expand.
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1097,6 +1097,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,

if (Subtarget->hasFullFP16()) {
setOperationAction(ISD::ConstantFP, MVT::f16, Legal);
setOperationAction(ISD::ConstantFP, MVT::bf16, Legal);

setOperationAction(ISD::SINT_TO_FP, MVT::v8i8, Custom);
setOperationAction(ISD::UINT_TO_FP, MVT::v8i8, Custom);
Expand Down Expand Up @@ -9789,7 +9790,7 @@ bool AArch64TargetLowering::isFPImmLegal(const APFloat &Imm, EVT VT,
IsLegal = AArch64_AM::getFP64Imm(ImmInt) != -1 || Imm.isPosZero();
else if (VT == MVT::f32)
IsLegal = AArch64_AM::getFP32Imm(ImmInt) != -1 || Imm.isPosZero();
else if (VT == MVT::f16)
else if (VT == MVT::f16 || VT == MVT::bf16)
IsLegal =
(Subtarget->hasFullFP16() && AArch64_AM::getFP16Imm(ImmInt) != -1) ||
Imm.isPosZero();
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrFormats.td
Original file line number Diff line number Diff line change
Expand Up @@ -1306,6 +1306,11 @@ def fpimm16 : Operand<f16>,
let PrintMethod = "printFPImmOperand";
}

def fpimmbf16 : Operand<bf16>,
FPImmLeaf<bf16, [{
return AArch64_AM::getFP16Imm(Imm) != -1;
}], fpimm16XForm>;

def fpimm32 : Operand<f32>,
FPImmLeaf<f32, [{
return AArch64_AM::getFP32Imm(Imm) != -1;
Expand Down
16 changes: 14 additions & 2 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -4355,16 +4355,23 @@ def FMOVS0 : Pseudo<(outs FPR32:$Rd), (ins), [(set f32:$Rd, (fpimm0))]>,
def FMOVD0 : Pseudo<(outs FPR64:$Rd), (ins), [(set f64:$Rd, (fpimm0))]>,
Sched<[WriteF]>;
}

// Similarly add aliases
def : InstAlias<"fmov $Rd, #0.0", (FMOVWHr FPR16:$Rd, WZR), 0>,
Requires<[HasFullFP16]>;
def : InstAlias<"fmov $Rd, #0.0", (FMOVWSr FPR32:$Rd, WZR), 0>;
def : InstAlias<"fmov $Rd, #0.0", (FMOVXDr FPR64:$Rd, XZR), 0>;

// Pattern for FP16 immediates
def : Pat<(bf16 fpimm0),
(FMOVH0)>;

// Pattern for FP16 and BF16 immediates
let Predicates = [HasFullFP16] in {
def : Pat<(f16 fpimm:$in),
(FMOVWHr (MOVi32imm (bitcast_fpimm_to_i32 f16:$in)))>;
(FMOVWHr (MOVi32imm (bitcast_fpimm_to_i32 f16:$in)))>;

def : Pat<(bf16 fpimm:$in),
(FMOVWHr (MOVi32imm (bitcast_fpimm_to_i32 bf16:$in)))>;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -4617,6 +4624,11 @@ let isReMaterializable = 1, isAsCheapAsAMove = 1 in {
defm FMOV : FPMoveImmediate<"fmov">;
}

let Predicates = [HasFullFP16] in {
def : Pat<(bf16 fpimmbf16:$in),
(FMOVHi (fpimm16XForm bf16:$in))>;
}

//===----------------------------------------------------------------------===//
// Advanced SIMD two vector instructions.
//===----------------------------------------------------------------------===//
Expand Down
121 changes: 121 additions & 0 deletions llvm/test/CodeGen/AArch64/bf16-imm.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc < %s -mtriple=aarch64 -mattr=+fullfp16 | FileCheck %s --check-prefixes=CHECK,CHECK-FP16
; RUN: llc < %s -mtriple=aarch64 -mattr=-fullfp16 | FileCheck %s --check-prefixes=CHECK,CHECK-NOFP16

define bfloat @Const0() {
; CHECK-LABEL: Const0:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: movi d0, #0000000000000000
; CHECK-NEXT: ret
entry:
ret bfloat 0xR0000
}

define bfloat @Const1() {
; CHECK-FP16-LABEL: Const1:
; CHECK-FP16: // %bb.0: // %entry
; CHECK-FP16-NEXT: fmov h0, #1.00000000
; CHECK-FP16-NEXT: ret
;
; CHECK-NOFP16-LABEL: Const1:
; CHECK-NOFP16: // %bb.0: // %entry
; CHECK-NOFP16-NEXT: adrp x8, .LCPI1_0
; CHECK-NOFP16-NEXT: ldr h0, [x8, :lo12:.LCPI1_0]
; CHECK-NOFP16-NEXT: ret
entry:
ret bfloat 0xR3C00
}

define bfloat @Const2() {
; CHECK-FP16-LABEL: Const2:
; CHECK-FP16: // %bb.0: // %entry
; CHECK-FP16-NEXT: fmov h0, #0.12500000
; CHECK-FP16-NEXT: ret
;
; CHECK-NOFP16-LABEL: Const2:
; CHECK-NOFP16: // %bb.0: // %entry
; CHECK-NOFP16-NEXT: adrp x8, .LCPI2_0
; CHECK-NOFP16-NEXT: ldr h0, [x8, :lo12:.LCPI2_0]
; CHECK-NOFP16-NEXT: ret
entry:
ret bfloat 0xR3000
}

define bfloat @Const3() {
; CHECK-FP16-LABEL: Const3:
; CHECK-FP16: // %bb.0: // %entry
; CHECK-FP16-NEXT: fmov h0, #30.00000000
; CHECK-FP16-NEXT: ret
;
; CHECK-NOFP16-LABEL: Const3:
; CHECK-NOFP16: // %bb.0: // %entry
; CHECK-NOFP16-NEXT: adrp x8, .LCPI3_0
; CHECK-NOFP16-NEXT: ldr h0, [x8, :lo12:.LCPI3_0]
; CHECK-NOFP16-NEXT: ret
entry:
ret bfloat 0xR4F80
}

define bfloat @Const4() {
; CHECK-FP16-LABEL: Const4:
; CHECK-FP16: // %bb.0: // %entry
; CHECK-FP16-NEXT: fmov h0, #31.00000000
; CHECK-FP16-NEXT: ret
;
; CHECK-NOFP16-LABEL: Const4:
; CHECK-NOFP16: // %bb.0: // %entry
; CHECK-NOFP16-NEXT: adrp x8, .LCPI4_0
; CHECK-NOFP16-NEXT: ldr h0, [x8, :lo12:.LCPI4_0]
; CHECK-NOFP16-NEXT: ret
entry:
ret bfloat 0xR4FC0
}

define bfloat @Const5() {
; CHECK-FP16-LABEL: Const5:
; CHECK-FP16: // %bb.0: // %entry
; CHECK-FP16-NEXT: mov w8, #12272 // =0x2ff0
; CHECK-FP16-NEXT: fmov h0, w8
; CHECK-FP16-NEXT: ret
;
; CHECK-NOFP16-LABEL: Const5:
; CHECK-NOFP16: // %bb.0: // %entry
; CHECK-NOFP16-NEXT: adrp x8, .LCPI5_0
; CHECK-NOFP16-NEXT: ldr h0, [x8, :lo12:.LCPI5_0]
; CHECK-NOFP16-NEXT: ret
entry:
ret bfloat 0xR2FF0
}

define bfloat @Const6() {
; CHECK-FP16-LABEL: Const6:
; CHECK-FP16: // %bb.0: // %entry
; CHECK-FP16-NEXT: mov w8, #20417 // =0x4fc1
; CHECK-FP16-NEXT: fmov h0, w8
; CHECK-FP16-NEXT: ret
;
; CHECK-NOFP16-LABEL: Const6:
; CHECK-NOFP16: // %bb.0: // %entry
; CHECK-NOFP16-NEXT: adrp x8, .LCPI6_0
; CHECK-NOFP16-NEXT: ldr h0, [x8, :lo12:.LCPI6_0]
; CHECK-NOFP16-NEXT: ret
entry:
ret bfloat 0xR4FC1
}

define bfloat @Const7() {
; CHECK-FP16-LABEL: Const7:
; CHECK-FP16: // %bb.0: // %entry
; CHECK-FP16-NEXT: mov w8, #20480 // =0x5000
; CHECK-FP16-NEXT: fmov h0, w8
; CHECK-FP16-NEXT: ret
;
; CHECK-NOFP16-LABEL: Const7:
; CHECK-NOFP16: // %bb.0: // %entry
; CHECK-NOFP16-NEXT: adrp x8, .LCPI7_0
; CHECK-NOFP16-NEXT: ldr h0, [x8, :lo12:.LCPI7_0]
; CHECK-NOFP16-NEXT: ret
entry:
ret bfloat 0xR5000
}

0 comments on commit 778fa4e

Please sign in to comment.