Skip to content

Commit bcf4c99

Browse files
authored
Revert "codegen: explicitly handle Float16 intrinsics (#45249)"
This reverts commit f2c627e.
1 parent a9e3cc7 commit bcf4c99

File tree

5 files changed

+92
-294
lines changed

5 files changed

+92
-294
lines changed

src/APInt-C.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ void LLVMByteSwap(unsigned numbits, integerPart *pa, integerPart *pr) {
316316
void LLVMFPtoInt(unsigned numbits, void *pa, unsigned onumbits, integerPart *pr, bool isSigned, bool *isExact) {
317317
double Val;
318318
if (numbits == 16)
319-
Val = julia__gnu_h2f_ieee(*(uint16_t*)pa);
319+
Val = __gnu_h2f_ieee(*(uint16_t*)pa);
320320
else if (numbits == 32)
321321
Val = *(float*)pa;
322322
else if (numbits == 64)
@@ -391,7 +391,7 @@ void LLVMSItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar
391391
val = a.roundToDouble(true);
392392
}
393393
if (onumbits == 16)
394-
*(uint16_t*)pr = julia__gnu_f2h_ieee(val);
394+
*(uint16_t*)pr = __gnu_f2h_ieee(val);
395395
else if (onumbits == 32)
396396
*(float*)pr = val;
397397
else if (onumbits == 64)
@@ -408,7 +408,7 @@ void LLVMUItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar
408408
val = a.roundToDouble(false);
409409
}
410410
if (onumbits == 16)
411-
*(uint16_t*)pr = julia__gnu_f2h_ieee(val);
411+
*(uint16_t*)pr = __gnu_f2h_ieee(val);
412412
else if (onumbits == 32)
413413
*(float*)pr = val;
414414
else if (onumbits == 64)

src/julia.expmap

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@
3737
environ;
3838
__progname;
3939

40+
/* compiler run-time intrinsics */
41+
__gnu_h2f_ieee;
42+
__extendhfsf2;
43+
__gnu_f2h_ieee;
44+
__truncdfhf2;
45+
4046
local:
4147
*;
4248
};

src/julia_internal.h

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1544,18 +1544,8 @@ jl_sym_t *_jl_symbol(const char *str, size_t len) JL_NOTSAFEPOINT;
15441544
#define JL_GC_ASSERT_LIVE(x) (void)(x)
15451545
#endif
15461546

1547-
JL_DLLEXPORT float julia__gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT;
1548-
JL_DLLEXPORT uint16_t julia__gnu_f2h_ieee(float param) JL_NOTSAFEPOINT;
1549-
JL_DLLEXPORT uint16_t julia__truncdfhf2(double param) JL_NOTSAFEPOINT;
1550-
//JL_DLLEXPORT double julia__extendhfdf2(uint16_t n) JL_NOTSAFEPOINT;
1551-
//JL_DLLEXPORT int32_t julia__fixhfsi(uint16_t n) JL_NOTSAFEPOINT;
1552-
//JL_DLLEXPORT int64_t julia__fixhfdi(uint16_t n) JL_NOTSAFEPOINT;
1553-
//JL_DLLEXPORT uint32_t julia__fixunshfsi(uint16_t n) JL_NOTSAFEPOINT;
1554-
//JL_DLLEXPORT uint64_t julia__fixunshfdi(uint16_t n) JL_NOTSAFEPOINT;
1555-
//JL_DLLEXPORT uint16_t julia__floatsihf(int32_t n) JL_NOTSAFEPOINT;
1556-
//JL_DLLEXPORT uint16_t julia__floatdihf(int64_t n) JL_NOTSAFEPOINT;
1557-
//JL_DLLEXPORT uint16_t julia__floatunsihf(uint32_t n) JL_NOTSAFEPOINT;
1558-
//JL_DLLEXPORT uint16_t julia__floatundihf(uint64_t n) JL_NOTSAFEPOINT;
1547+
float __gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT;
1548+
uint16_t __gnu_f2h_ieee(float param) JL_NOTSAFEPOINT;
15591549

15601550
#ifdef __cplusplus
15611551
}

src/llvm-demote-float16.cpp

Lines changed: 54 additions & 242 deletions
Original file line numberDiff line numberDiff line change
@@ -45,194 +45,15 @@ INST_STATISTIC(FCmp);
4545

4646
namespace {
4747

48-
inline AttributeSet getFnAttrs(const AttributeList &Attrs)
49-
{
50-
#if JL_LLVM_VERSION >= 140000
51-
return Attrs.getFnAttrs();
52-
#else
53-
return Attrs.getFnAttributes();
54-
#endif
55-
}
56-
57-
inline AttributeSet getRetAttrs(const AttributeList &Attrs)
58-
{
59-
#if JL_LLVM_VERSION >= 140000
60-
return Attrs.getRetAttrs();
61-
#else
62-
return Attrs.getRetAttributes();
63-
#endif
64-
}
65-
66-
static Instruction *replaceIntrinsicWith(IntrinsicInst *call, Type *RetTy, ArrayRef<Value*> args)
67-
{
68-
Intrinsic::ID ID = call->getIntrinsicID();
69-
assert(ID);
70-
auto oldfType = call->getFunctionType();
71-
auto nargs = oldfType->getNumParams();
72-
assert(args.size() > nargs);
73-
SmallVector<Type*, 8> argTys(nargs);
74-
for (unsigned i = 0; i < nargs; i++)
75-
argTys[i] = args[i]->getType();
76-
auto newfType = FunctionType::get(RetTy, argTys, oldfType->isVarArg());
77-
78-
// Accumulate an array of overloaded types for the given intrinsic
79-
// and compute the new name mangling schema
80-
SmallVector<Type*, 4> overloadTys;
81-
{
82-
SmallVector<Intrinsic::IITDescriptor, 8> Table;
83-
getIntrinsicInfoTableEntries(ID, Table);
84-
ArrayRef<Intrinsic::IITDescriptor> TableRef = Table;
85-
auto res = Intrinsic::matchIntrinsicSignature(newfType, TableRef, overloadTys);
86-
assert(res == Intrinsic::MatchIntrinsicTypes_Match);
87-
(void)res;
88-
bool matchvararg = !Intrinsic::matchIntrinsicVarArg(newfType->isVarArg(), TableRef);
89-
assert(matchvararg);
90-
(void)matchvararg;
91-
}
92-
auto newF = Intrinsic::getDeclaration(call->getModule(), ID, overloadTys);
93-
assert(newF->getFunctionType() == newfType);
94-
newF->setCallingConv(call->getCallingConv());
95-
assert(args.back() == call->getCalledFunction());
96-
auto newCall = CallInst::Create(newF, args.drop_back(), "", call);
97-
newCall->setTailCallKind(call->getTailCallKind());
98-
auto old_attrs = call->getAttributes();
99-
newCall->setAttributes(AttributeList::get(call->getContext(), getFnAttrs(old_attrs),
100-
getRetAttrs(old_attrs), {})); // drop parameter attributes
101-
return newCall;
102-
}
103-
104-
105-
static Value* CreateFPCast(Instruction::CastOps opcode, Value *V, Type *DestTy, IRBuilder<> &builder)
106-
{
107-
Type *SrcTy = V->getType();
108-
Type *RetTy = DestTy;
109-
if (auto *VC = dyn_cast<Constant>(V)) {
110-
// The input IR often has things of the form
111-
// fcmp olt half %0, 0xH7C00
112-
// and we would like to avoid turning that constant into a call here
113-
// if we can simply constant fold it to the new type.
114-
VC = ConstantExpr::getCast(opcode, VC, DestTy, true);
115-
if (VC)
116-
return VC;
117-
}
118-
assert(SrcTy->isVectorTy() == DestTy->isVectorTy());
119-
if (SrcTy->isVectorTy()) {
120-
unsigned NumElems = cast<FixedVectorType>(SrcTy)->getNumElements();
121-
assert(cast<FixedVectorType>(DestTy)->getNumElements() == NumElems && "Mismatched cast");
122-
Value *NewV = UndefValue::get(DestTy);
123-
RetTy = RetTy->getScalarType();
124-
for (unsigned i = 0; i < NumElems; ++i) {
125-
Value *I = builder.getInt32(i);
126-
Value *Vi = builder.CreateExtractElement(V, I);
127-
Vi = CreateFPCast(opcode, Vi, RetTy, builder);
128-
NewV = builder.CreateInsertElement(NewV, Vi, I);
129-
}
130-
return NewV;
131-
}
132-
auto &M = *builder.GetInsertBlock()->getModule();
133-
auto &ctx = M.getContext();
134-
// Pick the Function to call in the Julia runtime
135-
StringRef Name;
136-
switch (opcode) {
137-
case Instruction::FPExt:
138-
// this is exact, so we only need one conversion
139-
assert(SrcTy->isHalfTy());
140-
Name = "julia__gnu_h2f_ieee";
141-
RetTy = Type::getFloatTy(ctx);
142-
break;
143-
case Instruction::FPTrunc:
144-
assert(DestTy->isHalfTy());
145-
if (SrcTy->isFloatTy())
146-
Name = "julia__gnu_f2h_ieee";
147-
else if (SrcTy->isDoubleTy())
148-
Name = "julia__truncdfhf2";
149-
break;
150-
// All F16 fit exactly in Int32 (-65504 to 65504)
151-
case Instruction::FPToSI: JL_FALLTHROUGH;
152-
case Instruction::FPToUI:
153-
assert(SrcTy->isHalfTy());
154-
Name = "julia__gnu_h2f_ieee";
155-
RetTy = Type::getFloatTy(ctx);
156-
break;
157-
case Instruction::SIToFP: JL_FALLTHROUGH;
158-
case Instruction::UIToFP:
159-
assert(DestTy->isHalfTy());
160-
Name = "julia__gnu_f2h_ieee";
161-
SrcTy = Type::getFloatTy(ctx);
162-
break;
163-
default:
164-
errs() << Instruction::getOpcodeName(opcode) << ' ';
165-
V->getType()->print(errs());
166-
errs() << " to ";
167-
DestTy->print(errs());
168-
errs() << " is an ";
169-
llvm_unreachable("invalid cast");
170-
}
171-
if (Name.empty()) {
172-
errs() << Instruction::getOpcodeName(opcode) << ' ';
173-
V->getType()->print(errs());
174-
errs() << " to ";
175-
DestTy->print(errs());
176-
errs() << " is an ";
177-
llvm_unreachable("illegal cast");
178-
}
179-
// Coerce the source to the required size and type
180-
auto T_int16 = Type::getInt16Ty(ctx);
181-
if (SrcTy->isHalfTy())
182-
SrcTy = T_int16;
183-
if (opcode == Instruction::SIToFP)
184-
V = builder.CreateSIToFP(V, SrcTy);
185-
else if (opcode == Instruction::UIToFP)
186-
V = builder.CreateUIToFP(V, SrcTy);
187-
else
188-
V = builder.CreateBitCast(V, SrcTy);
189-
// Call our intrinsic
190-
if (RetTy->isHalfTy())
191-
RetTy = T_int16;
192-
auto FT = FunctionType::get(RetTy, {SrcTy}, false);
193-
FunctionCallee F = M.getOrInsertFunction(Name, FT);
194-
Value *I = builder.CreateCall(F, {V});
195-
// Coerce the result to the expected type
196-
if (opcode == Instruction::FPToSI)
197-
I = builder.CreateFPToSI(I, DestTy);
198-
else if (opcode == Instruction::FPToUI)
199-
I = builder.CreateFPToUI(I, DestTy);
200-
else if (opcode == Instruction::FPExt)
201-
I = builder.CreateFPCast(I, DestTy);
202-
else
203-
I = builder.CreateBitCast(I, DestTy);
204-
return I;
205-
}
206-
20748
static bool demoteFloat16(Function &F)
20849
{
20950
auto &ctx = F.getContext();
51+
auto T_float16 = Type::getHalfTy(ctx);
21052
auto T_float32 = Type::getFloatTy(ctx);
21153

21254
SmallVector<Instruction *, 0> erase;
21355
for (auto &BB : F) {
21456
for (auto &I : BB) {
215-
// extend Float16 operands to Float32
216-
bool Float16 = I.getType()->getScalarType()->isHalfTy();
217-
for (size_t i = 0; !Float16 && i < I.getNumOperands(); i++) {
218-
Value *Op = I.getOperand(i);
219-
if (Op->getType()->getScalarType()->isHalfTy())
220-
Float16 = true;
221-
}
222-
if (!Float16)
223-
continue;
224-
225-
if (auto CI = dyn_cast<CastInst>(&I)) {
226-
if (CI->getOpcode() != Instruction::BitCast) { // aka !CI->isNoopCast(DL)
227-
++TotalChanged;
228-
IRBuilder<> builder(&I);
229-
Value *NewI = CreateFPCast(CI->getOpcode(), I.getOperand(0), I.getType(), builder);
230-
I.replaceAllUsesWith(NewI);
231-
erase.push_back(&I);
232-
}
233-
continue;
234-
}
235-
23657
switch (I.getOpcode()) {
23758
case Instruction::FNeg:
23859
case Instruction::FAdd:
@@ -243,9 +64,6 @@ static bool demoteFloat16(Function &F)
24364
case Instruction::FCmp:
24465
break;
24566
default:
246-
if (auto intrinsic = dyn_cast<IntrinsicInst>(&I))
247-
if (intrinsic->getIntrinsicID())
248-
break;
24967
continue;
25068
}
25169

@@ -257,78 +75,72 @@ static bool demoteFloat16(Function &F)
25775
IRBuilder<> builder(&I);
25876

25977
// extend Float16 operands to Float32
260-
// XXX: Calls to llvm.fma.f16 may need to go to f64 to be correct?
78+
bool OperandsChanged = false;
26179
SmallVector<Value *, 2> Operands(I.getNumOperands());
26280
for (size_t i = 0; i < I.getNumOperands(); i++) {
26381
Value *Op = I.getOperand(i);
264-
if (Op->getType()->getScalarType()->isHalfTy()) {
82+
if (Op->getType() == T_float16) {
26583
++TotalExt;
266-
Op = CreateFPCast(Instruction::FPExt, Op, Op->getType()->getWithNewType(T_float32), builder);
84+
Op = builder.CreateFPExt(Op, T_float32);
85+
OperandsChanged = true;
26786
}
26887
Operands[i] = (Op);
26988
}
27089

27190
// recreate the instruction if any operands changed,
27291
// truncating the result back to Float16
273-
Value *NewI;
274-
++TotalChanged;
275-
switch (I.getOpcode()) {
276-
case Instruction::FNeg:
277-
assert(Operands.size() == 1);
278-
++FNegChanged;
279-
NewI = builder.CreateFNeg(Operands[0]);
280-
break;
281-
case Instruction::FAdd:
282-
assert(Operands.size() == 2);
283-
++FAddChanged;
284-
NewI = builder.CreateFAdd(Operands[0], Operands[1]);
285-
break;
286-
case Instruction::FSub:
287-
assert(Operands.size() == 2);
288-
++FSubChanged;
289-
NewI = builder.CreateFSub(Operands[0], Operands[1]);
290-
break;
291-
case Instruction::FMul:
292-
assert(Operands.size() == 2);
293-
++FMulChanged;
294-
NewI = builder.CreateFMul(Operands[0], Operands[1]);
295-
break;
296-
case Instruction::FDiv:
297-
assert(Operands.size() == 2);
298-
++FDivChanged;
299-
NewI = builder.CreateFDiv(Operands[0], Operands[1]);
300-
break;
301-
case Instruction::FRem:
302-
assert(Operands.size() == 2);
303-
++FRemChanged;
304-
NewI = builder.CreateFRem(Operands[0], Operands[1]);
305-
break;
306-
case Instruction::FCmp:
307-
assert(Operands.size() == 2);
308-
++FCmpChanged;
309-
NewI = builder.CreateFCmp(cast<FCmpInst>(&I)->getPredicate(),
310-
Operands[0], Operands[1]);
311-
break;
312-
default:
313-
if (auto intrinsic = dyn_cast<IntrinsicInst>(&I)) {
314-
// XXX: this is not correct in general
315-
// some obvious failures include llvm.convert.to.fp16.*, llvm.vp.*to*, llvm.experimental.constrained.*to*, llvm.masked.*
316-
Type *RetTy = I.getType();
317-
if (RetTy->getScalarType()->isHalfTy())
318-
RetTy = RetTy->getWithNewType(T_float32);
319-
NewI = replaceIntrinsicWith(intrinsic, RetTy, Operands);
92+
if (OperandsChanged) {
93+
Value *NewI;
94+
++TotalChanged;
95+
switch (I.getOpcode()) {
96+
case Instruction::FNeg:
97+
assert(Operands.size() == 1);
98+
++FNegChanged;
99+
NewI = builder.CreateFNeg(Operands[0]);
100+
break;
101+
case Instruction::FAdd:
102+
assert(Operands.size() == 2);
103+
++FAddChanged;
104+
NewI = builder.CreateFAdd(Operands[0], Operands[1]);
105+
break;
106+
case Instruction::FSub:
107+
assert(Operands.size() == 2);
108+
++FSubChanged;
109+
NewI = builder.CreateFSub(Operands[0], Operands[1]);
320110
break;
111+
case Instruction::FMul:
112+
assert(Operands.size() == 2);
113+
++FMulChanged;
114+
NewI = builder.CreateFMul(Operands[0], Operands[1]);
115+
break;
116+
case Instruction::FDiv:
117+
assert(Operands.size() == 2);
118+
++FDivChanged;
119+
NewI = builder.CreateFDiv(Operands[0], Operands[1]);
120+
break;
121+
case Instruction::FRem:
122+
assert(Operands.size() == 2);
123+
++FRemChanged;
124+
NewI = builder.CreateFRem(Operands[0], Operands[1]);
125+
break;
126+
case Instruction::FCmp:
127+
assert(Operands.size() == 2);
128+
++FCmpChanged;
129+
NewI = builder.CreateFCmp(cast<FCmpInst>(&I)->getPredicate(),
130+
Operands[0], Operands[1]);
131+
break;
132+
default:
133+
abort();
321134
}
322-
abort();
323-
}
324-
cast<Instruction>(NewI)->copyMetadata(I);
325-
cast<Instruction>(NewI)->copyFastMathFlags(&I);
326-
if (NewI->getType() != I.getType()) {
327-
++TotalTrunc;
328-
NewI = CreateFPCast(Instruction::FPTrunc, NewI, I.getType(), builder);
135+
cast<Instruction>(NewI)->copyMetadata(I);
136+
cast<Instruction>(NewI)->copyFastMathFlags(&I);
137+
if (NewI->getType() != I.getType()) {
138+
++TotalTrunc;
139+
NewI = builder.CreateFPTrunc(NewI, I.getType());
140+
}
141+
I.replaceAllUsesWith(NewI);
142+
erase.push_back(&I);
329143
}
330-
I.replaceAllUsesWith(NewI);
331-
erase.push_back(&I);
332144
}
333145
}
334146

0 commit comments

Comments
 (0)