Skip to content

Commit 3a2eb39

Browse files
authored
Revert "codegen: explicitly handle Float16 intrinsics (#45249)" (#45627)
This reverts commit eb82f18.
1 parent 43df1f4 commit 3a2eb39

File tree

5 files changed: 82 additions and 284 deletions.

5 files changed: 82 additions and 284 deletions.

src/APInt-C.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -316,7 +316,7 @@ void LLVMByteSwap(unsigned numbits, integerPart *pa, integerPart *pr) {
316316
void LLVMFPtoInt(unsigned numbits, void *pa, unsigned onumbits, integerPart *pr, bool isSigned, bool *isExact) {
317317
double Val;
318318
if (numbits == 16)
319-
Val = julia__gnu_h2f_ieee(*(uint16_t*)pa);
319+
Val = __gnu_h2f_ieee(*(uint16_t*)pa);
320320
else if (numbits == 32)
321321
Val = *(float*)pa;
322322
else if (numbits == 64)
@@ -391,7 +391,7 @@ void LLVMSItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar
391391
val = a.roundToDouble(true);
392392
}
393393
if (onumbits == 16)
394-
*(uint16_t*)pr = julia__gnu_f2h_ieee(val);
394+
*(uint16_t*)pr = __gnu_f2h_ieee(val);
395395
else if (onumbits == 32)
396396
*(float*)pr = val;
397397
else if (onumbits == 64)
@@ -408,7 +408,7 @@ void LLVMUItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar
408408
val = a.roundToDouble(false);
409409
}
410410
if (onumbits == 16)
411-
*(uint16_t*)pr = julia__gnu_f2h_ieee(val);
411+
*(uint16_t*)pr = __gnu_f2h_ieee(val);
412412
else if (onumbits == 32)
413413
*(float*)pr = val;
414414
else if (onumbits == 64)

src/julia.expmap

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,12 @@
3737
environ;
3838
__progname;
3939

40+
/* compiler run-time intrinsics */
41+
__gnu_h2f_ieee;
42+
__extendhfsf2;
43+
__gnu_f2h_ieee;
44+
__truncdfhf2;
45+
4046
local:
4147
*;
4248
};

src/julia_internal.h

Lines changed: 2 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1522,18 +1522,8 @@ jl_sym_t *_jl_symbol(const char *str, size_t len) JL_NOTSAFEPOINT;
15221522
#define JL_GC_ASSERT_LIVE(x) (void)(x)
15231523
#endif
15241524

1525-
JL_DLLEXPORT float julia__gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT;
1526-
JL_DLLEXPORT uint16_t julia__gnu_f2h_ieee(float param) JL_NOTSAFEPOINT;
1527-
JL_DLLEXPORT uint16_t julia__truncdfhf2(double param) JL_NOTSAFEPOINT;
1528-
//JL_DLLEXPORT double julia__extendhfdf2(uint16_t n) JL_NOTSAFEPOINT;
1529-
//JL_DLLEXPORT int32_t julia__fixhfsi(uint16_t n) JL_NOTSAFEPOINT;
1530-
//JL_DLLEXPORT int64_t julia__fixhfdi(uint16_t n) JL_NOTSAFEPOINT;
1531-
//JL_DLLEXPORT uint32_t julia__fixunshfsi(uint16_t n) JL_NOTSAFEPOINT;
1532-
//JL_DLLEXPORT uint64_t julia__fixunshfdi(uint16_t n) JL_NOTSAFEPOINT;
1533-
//JL_DLLEXPORT uint16_t julia__floatsihf(int32_t n) JL_NOTSAFEPOINT;
1534-
//JL_DLLEXPORT uint16_t julia__floatdihf(int64_t n) JL_NOTSAFEPOINT;
1535-
//JL_DLLEXPORT uint16_t julia__floatunsihf(uint32_t n) JL_NOTSAFEPOINT;
1536-
//JL_DLLEXPORT uint16_t julia__floatundihf(uint64_t n) JL_NOTSAFEPOINT;
1525+
float __gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT;
1526+
uint16_t __gnu_f2h_ieee(float param) JL_NOTSAFEPOINT;
15371527

15381528
#ifdef __cplusplus
15391529
}

src/llvm-demote-float16.cpp

Lines changed: 44 additions & 232 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,6 @@
1818

1919
#include "support/dtypes.h"
2020

21-
#include <llvm/Pass.h>
2221
#include <llvm/IR/IRBuilder.h>
2322
#include <llvm/IR/LegacyPassManager.h>
2423
#include <llvm/IR/PassManager.h>
@@ -29,193 +28,15 @@ using namespace llvm;
2928

3029
namespace {
3130

32-
inline AttributeSet getFnAttrs(const AttributeList &Attrs)
33-
{
34-
#if JL_LLVM_VERSION >= 140000
35-
return Attrs.getFnAttrs();
36-
#else
37-
return Attrs.getFnAttributes();
38-
#endif
39-
}
40-
41-
inline AttributeSet getRetAttrs(const AttributeList &Attrs)
42-
{
43-
#if JL_LLVM_VERSION >= 140000
44-
return Attrs.getRetAttrs();
45-
#else
46-
return Attrs.getRetAttributes();
47-
#endif
48-
}
49-
50-
static Instruction *replaceIntrinsicWith(IntrinsicInst *call, Type *RetTy, ArrayRef<Value*> args)
51-
{
52-
Intrinsic::ID ID = call->getIntrinsicID();
53-
assert(ID);
54-
auto oldfType = call->getFunctionType();
55-
auto nargs = oldfType->getNumParams();
56-
assert(args.size() > nargs);
57-
SmallVector<Type*, 8> argTys(nargs);
58-
for (unsigned i = 0; i < nargs; i++)
59-
argTys[i] = args[i]->getType();
60-
auto newfType = FunctionType::get(RetTy, argTys, oldfType->isVarArg());
61-
62-
// Accumulate an array of overloaded types for the given intrinsic
63-
// and compute the new name mangling schema
64-
SmallVector<Type*, 4> overloadTys;
65-
{
66-
SmallVector<Intrinsic::IITDescriptor, 8> Table;
67-
getIntrinsicInfoTableEntries(ID, Table);
68-
ArrayRef<Intrinsic::IITDescriptor> TableRef = Table;
69-
auto res = Intrinsic::matchIntrinsicSignature(newfType, TableRef, overloadTys);
70-
assert(res == Intrinsic::MatchIntrinsicTypes_Match);
71-
(void)res;
72-
bool matchvararg = !Intrinsic::matchIntrinsicVarArg(newfType->isVarArg(), TableRef);
73-
assert(matchvararg);
74-
(void)matchvararg;
75-
}
76-
auto newF = Intrinsic::getDeclaration(call->getModule(), ID, overloadTys);
77-
assert(newF->getFunctionType() == newfType);
78-
newF->setCallingConv(call->getCallingConv());
79-
assert(args.back() == call->getCalledFunction());
80-
auto newCall = CallInst::Create(newF, args.drop_back(), "", call);
81-
newCall->setTailCallKind(call->getTailCallKind());
82-
auto old_attrs = call->getAttributes();
83-
newCall->setAttributes(AttributeList::get(call->getContext(), getFnAttrs(old_attrs),
84-
getRetAttrs(old_attrs), {})); // drop parameter attributes
85-
return newCall;
86-
}
87-
88-
89-
static Value* CreateFPCast(Instruction::CastOps opcode, Value *V, Type *DestTy, IRBuilder<> &builder)
90-
{
91-
Type *SrcTy = V->getType();
92-
Type *RetTy = DestTy;
93-
if (auto *VC = dyn_cast<Constant>(V)) {
94-
// The input IR often has things of the form
95-
// fcmp olt half %0, 0xH7C00
96-
// and we would like to avoid turning that constant into a call here
97-
// if we can simply constant fold it to the new type.
98-
VC = ConstantExpr::getCast(opcode, VC, DestTy, true);
99-
if (VC)
100-
return VC;
101-
}
102-
assert(SrcTy->isVectorTy() == DestTy->isVectorTy());
103-
if (SrcTy->isVectorTy()) {
104-
unsigned NumElems = cast<FixedVectorType>(SrcTy)->getNumElements();
105-
assert(cast<FixedVectorType>(DestTy)->getNumElements() == NumElems && "Mismatched cast");
106-
Value *NewV = UndefValue::get(DestTy);
107-
RetTy = RetTy->getScalarType();
108-
for (unsigned i = 0; i < NumElems; ++i) {
109-
Value *I = builder.getInt32(i);
110-
Value *Vi = builder.CreateExtractElement(V, I);
111-
Vi = CreateFPCast(opcode, Vi, RetTy, builder);
112-
NewV = builder.CreateInsertElement(NewV, Vi, I);
113-
}
114-
return NewV;
115-
}
116-
auto &M = *builder.GetInsertBlock()->getModule();
117-
auto &ctx = M.getContext();
118-
// Pick the Function to call in the Julia runtime
119-
StringRef Name;
120-
switch (opcode) {
121-
case Instruction::FPExt:
122-
// this is exact, so we only need one conversion
123-
assert(SrcTy->isHalfTy());
124-
Name = "julia__gnu_h2f_ieee";
125-
RetTy = Type::getFloatTy(ctx);
126-
break;
127-
case Instruction::FPTrunc:
128-
assert(DestTy->isHalfTy());
129-
if (SrcTy->isFloatTy())
130-
Name = "julia__gnu_f2h_ieee";
131-
else if (SrcTy->isDoubleTy())
132-
Name = "julia__truncdfhf2";
133-
break;
134-
// All F16 fit exactly in Int32 (-65504 to 65504)
135-
case Instruction::FPToSI: JL_FALLTHROUGH;
136-
case Instruction::FPToUI:
137-
assert(SrcTy->isHalfTy());
138-
Name = "julia__gnu_h2f_ieee";
139-
RetTy = Type::getFloatTy(ctx);
140-
break;
141-
case Instruction::SIToFP: JL_FALLTHROUGH;
142-
case Instruction::UIToFP:
143-
assert(DestTy->isHalfTy());
144-
Name = "julia__gnu_f2h_ieee";
145-
SrcTy = Type::getFloatTy(ctx);
146-
break;
147-
default:
148-
errs() << Instruction::getOpcodeName(opcode) << ' ';
149-
V->getType()->print(errs());
150-
errs() << " to ";
151-
DestTy->print(errs());
152-
errs() << " is an ";
153-
llvm_unreachable("invalid cast");
154-
}
155-
if (Name.empty()) {
156-
errs() << Instruction::getOpcodeName(opcode) << ' ';
157-
V->getType()->print(errs());
158-
errs() << " to ";
159-
DestTy->print(errs());
160-
errs() << " is an ";
161-
llvm_unreachable("illegal cast");
162-
}
163-
// Coerce the source to the required size and type
164-
auto T_int16 = Type::getInt16Ty(ctx);
165-
if (SrcTy->isHalfTy())
166-
SrcTy = T_int16;
167-
if (opcode == Instruction::SIToFP)
168-
V = builder.CreateSIToFP(V, SrcTy);
169-
else if (opcode == Instruction::UIToFP)
170-
V = builder.CreateUIToFP(V, SrcTy);
171-
else
172-
V = builder.CreateBitCast(V, SrcTy);
173-
// Call our intrinsic
174-
if (RetTy->isHalfTy())
175-
RetTy = T_int16;
176-
auto FT = FunctionType::get(RetTy, {SrcTy}, false);
177-
FunctionCallee F = M.getOrInsertFunction(Name, FT);
178-
Value *I = builder.CreateCall(F, {V});
179-
// Coerce the result to the expected type
180-
if (opcode == Instruction::FPToSI)
181-
I = builder.CreateFPToSI(I, DestTy);
182-
else if (opcode == Instruction::FPToUI)
183-
I = builder.CreateFPToUI(I, DestTy);
184-
else if (opcode == Instruction::FPExt)
185-
I = builder.CreateFPCast(I, DestTy);
186-
else
187-
I = builder.CreateBitCast(I, DestTy);
188-
return I;
189-
}
190-
19131
static bool demoteFloat16(Function &F)
19232
{
19333
auto &ctx = F.getContext();
34+
auto T_float16 = Type::getHalfTy(ctx);
19435
auto T_float32 = Type::getFloatTy(ctx);
19536

19637
SmallVector<Instruction *, 0> erase;
19738
for (auto &BB : F) {
19839
for (auto &I : BB) {
199-
// extend Float16 operands to Float32
200-
bool Float16 = I.getType()->getScalarType()->isHalfTy();
201-
for (size_t i = 0; !Float16 && i < I.getNumOperands(); i++) {
202-
Value *Op = I.getOperand(i);
203-
if (Op->getType()->getScalarType()->isHalfTy())
204-
Float16 = true;
205-
}
206-
if (!Float16)
207-
continue;
208-
209-
if (auto CI = dyn_cast<CastInst>(&I)) {
210-
if (CI->getOpcode() != Instruction::BitCast) { // aka !CI->isNoopCast(DL)
211-
IRBuilder<> builder(&I);
212-
Value *NewI = CreateFPCast(CI->getOpcode(), I.getOperand(0), I.getType(), builder);
213-
I.replaceAllUsesWith(NewI);
214-
erase.push_back(&I);
215-
}
216-
continue;
217-
}
218-
21940
switch (I.getOpcode()) {
22041
case Instruction::FNeg:
22142
case Instruction::FAdd:
@@ -226,9 +47,6 @@ static bool demoteFloat16(Function &F)
22647
case Instruction::FCmp:
22748
break;
22849
default:
229-
if (auto intrinsic = dyn_cast<IntrinsicInst>(&I))
230-
if (intrinsic->getIntrinsicID())
231-
break;
23250
continue;
23351
}
23452

@@ -240,67 +58,61 @@ static bool demoteFloat16(Function &F)
24058
IRBuilder<> builder(&I);
24159

24260
// extend Float16 operands to Float32
243-
// XXX: Calls to llvm.fma.f16 may need to go to f64 to be correct?
61+
bool OperandsChanged = false;
24462
SmallVector<Value *, 2> Operands(I.getNumOperands());
24563
for (size_t i = 0; i < I.getNumOperands(); i++) {
24664
Value *Op = I.getOperand(i);
247-
if (Op->getType()->getScalarType()->isHalfTy()) {
248-
Op = CreateFPCast(Instruction::FPExt, Op, Op->getType()->getWithNewType(T_float32), builder);
65+
if (Op->getType() == T_float16) {
66+
Op = builder.CreateFPExt(Op, T_float32);
67+
OperandsChanged = true;
24968
}
25069
Operands[i] = (Op);
25170
}
25271

25372
// recreate the instruction if any operands changed,
25473
// truncating the result back to Float16
255-
Value *NewI;
256-
switch (I.getOpcode()) {
257-
case Instruction::FNeg:
258-
assert(Operands.size() == 1);
259-
NewI = builder.CreateFNeg(Operands[0]);
260-
break;
261-
case Instruction::FAdd:
262-
assert(Operands.size() == 2);
263-
NewI = builder.CreateFAdd(Operands[0], Operands[1]);
264-
break;
265-
case Instruction::FSub:
266-
assert(Operands.size() == 2);
267-
NewI = builder.CreateFSub(Operands[0], Operands[1]);
268-
break;
269-
case Instruction::FMul:
270-
assert(Operands.size() == 2);
271-
NewI = builder.CreateFMul(Operands[0], Operands[1]);
272-
break;
273-
case Instruction::FDiv:
274-
assert(Operands.size() == 2);
275-
NewI = builder.CreateFDiv(Operands[0], Operands[1]);
276-
break;
277-
case Instruction::FRem:
278-
assert(Operands.size() == 2);
279-
NewI = builder.CreateFRem(Operands[0], Operands[1]);
280-
break;
281-
case Instruction::FCmp:
282-
assert(Operands.size() == 2);
283-
NewI = builder.CreateFCmp(cast<FCmpInst>(&I)->getPredicate(),
284-
Operands[0], Operands[1]);
285-
break;
286-
default:
287-
if (auto intrinsic = dyn_cast<IntrinsicInst>(&I)) {
288-
// XXX: this is not correct in general
289-
// some obvious failures include llvm.convert.to.fp16.*, llvm.vp.*to*, llvm.experimental.constrained.*to*, llvm.masked.*
290-
Type *RetTy = I.getType();
291-
if (RetTy->getScalarType()->isHalfTy())
292-
RetTy = RetTy->getWithNewType(T_float32);
293-
NewI = replaceIntrinsicWith(intrinsic, RetTy, Operands);
74+
if (OperandsChanged) {
75+
Value *NewI;
76+
switch (I.getOpcode()) {
77+
case Instruction::FNeg:
78+
assert(Operands.size() == 1);
79+
NewI = builder.CreateFNeg(Operands[0]);
80+
break;
81+
case Instruction::FAdd:
82+
assert(Operands.size() == 2);
83+
NewI = builder.CreateFAdd(Operands[0], Operands[1]);
84+
break;
85+
case Instruction::FSub:
86+
assert(Operands.size() == 2);
87+
NewI = builder.CreateFSub(Operands[0], Operands[1]);
88+
break;
89+
case Instruction::FMul:
90+
assert(Operands.size() == 2);
91+
NewI = builder.CreateFMul(Operands[0], Operands[1]);
92+
break;
93+
case Instruction::FDiv:
94+
assert(Operands.size() == 2);
95+
NewI = builder.CreateFDiv(Operands[0], Operands[1]);
96+
break;
97+
case Instruction::FRem:
98+
assert(Operands.size() == 2);
99+
NewI = builder.CreateFRem(Operands[0], Operands[1]);
100+
break;
101+
case Instruction::FCmp:
102+
assert(Operands.size() == 2);
103+
NewI = builder.CreateFCmp(cast<FCmpInst>(&I)->getPredicate(),
104+
Operands[0], Operands[1]);
294105
break;
106+
default:
107+
abort();
295108
}
296-
abort();
109+
cast<Instruction>(NewI)->copyMetadata(I);
110+
cast<Instruction>(NewI)->copyFastMathFlags(&I);
111+
if (NewI->getType() != I.getType())
112+
NewI = builder.CreateFPTrunc(NewI, I.getType());
113+
I.replaceAllUsesWith(NewI);
114+
erase.push_back(&I);
297115
}
298-
cast<Instruction>(NewI)->copyMetadata(I);
299-
cast<Instruction>(NewI)->copyFastMathFlags(&I);
300-
if (NewI->getType() != I.getType())
301-
NewI = CreateFPCast(Instruction::FPTrunc, NewI, I.getType(), builder);
302-
I.replaceAllUsesWith(NewI);
303-
erase.push_back(&I);
304116
}
305117
}
306118

0 commit comments

Comments
 (0)