18
18
19
19
#include " support/dtypes.h"
20
20
21
- #include < llvm/Pass.h>
22
21
#include < llvm/IR/IRBuilder.h>
23
22
#include < llvm/IR/LegacyPassManager.h>
24
23
#include < llvm/IR/PassManager.h>
@@ -29,193 +28,15 @@ using namespace llvm;
29
28
30
29
namespace {
31
30
32
- inline AttributeSet getFnAttrs (const AttributeList &Attrs)
33
- {
34
- #if JL_LLVM_VERSION >= 140000
35
- return Attrs.getFnAttrs ();
36
- #else
37
- return Attrs.getFnAttributes ();
38
- #endif
39
- }
40
-
41
- inline AttributeSet getRetAttrs (const AttributeList &Attrs)
42
- {
43
- #if JL_LLVM_VERSION >= 140000
44
- return Attrs.getRetAttrs ();
45
- #else
46
- return Attrs.getRetAttributes ();
47
- #endif
48
- }
49
-
50
- static Instruction *replaceIntrinsicWith (IntrinsicInst *call, Type *RetTy, ArrayRef<Value*> args)
51
- {
52
- Intrinsic::ID ID = call->getIntrinsicID ();
53
- assert (ID);
54
- auto oldfType = call->getFunctionType ();
55
- auto nargs = oldfType->getNumParams ();
56
- assert (args.size () > nargs);
57
- SmallVector<Type*, 8 > argTys (nargs);
58
- for (unsigned i = 0 ; i < nargs; i++)
59
- argTys[i] = args[i]->getType ();
60
- auto newfType = FunctionType::get (RetTy, argTys, oldfType->isVarArg ());
61
-
62
- // Accumulate an array of overloaded types for the given intrinsic
63
- // and compute the new name mangling schema
64
- SmallVector<Type*, 4 > overloadTys;
65
- {
66
- SmallVector<Intrinsic::IITDescriptor, 8 > Table;
67
- getIntrinsicInfoTableEntries (ID, Table);
68
- ArrayRef<Intrinsic::IITDescriptor> TableRef = Table;
69
- auto res = Intrinsic::matchIntrinsicSignature (newfType, TableRef, overloadTys);
70
- assert (res == Intrinsic::MatchIntrinsicTypes_Match);
71
- (void )res;
72
- bool matchvararg = !Intrinsic::matchIntrinsicVarArg (newfType->isVarArg (), TableRef);
73
- assert (matchvararg);
74
- (void )matchvararg;
75
- }
76
- auto newF = Intrinsic::getDeclaration (call->getModule (), ID, overloadTys);
77
- assert (newF->getFunctionType () == newfType);
78
- newF->setCallingConv (call->getCallingConv ());
79
- assert (args.back () == call->getCalledFunction ());
80
- auto newCall = CallInst::Create (newF, args.drop_back (), " " , call);
81
- newCall->setTailCallKind (call->getTailCallKind ());
82
- auto old_attrs = call->getAttributes ();
83
- newCall->setAttributes (AttributeList::get (call->getContext (), getFnAttrs (old_attrs),
84
- getRetAttrs (old_attrs), {})); // drop parameter attributes
85
- return newCall;
86
- }
87
-
88
-
89
- static Value* CreateFPCast (Instruction::CastOps opcode, Value *V, Type *DestTy, IRBuilder<> &builder)
90
- {
91
- Type *SrcTy = V->getType ();
92
- Type *RetTy = DestTy;
93
- if (auto *VC = dyn_cast<Constant>(V)) {
94
- // The input IR often has things of the form
95
- // fcmp olt half %0, 0xH7C00
96
- // and we would like to avoid turning that constant into a call here
97
- // if we can simply constant fold it to the new type.
98
- VC = ConstantExpr::getCast (opcode, VC, DestTy, true );
99
- if (VC)
100
- return VC;
101
- }
102
- assert (SrcTy->isVectorTy () == DestTy->isVectorTy ());
103
- if (SrcTy->isVectorTy ()) {
104
- unsigned NumElems = cast<FixedVectorType>(SrcTy)->getNumElements ();
105
- assert (cast<FixedVectorType>(DestTy)->getNumElements () == NumElems && " Mismatched cast" );
106
- Value *NewV = UndefValue::get (DestTy);
107
- RetTy = RetTy->getScalarType ();
108
- for (unsigned i = 0 ; i < NumElems; ++i) {
109
- Value *I = builder.getInt32 (i);
110
- Value *Vi = builder.CreateExtractElement (V, I);
111
- Vi = CreateFPCast (opcode, Vi, RetTy, builder);
112
- NewV = builder.CreateInsertElement (NewV, Vi, I);
113
- }
114
- return NewV;
115
- }
116
- auto &M = *builder.GetInsertBlock ()->getModule ();
117
- auto &ctx = M.getContext ();
118
- // Pick the Function to call in the Julia runtime
119
- StringRef Name;
120
- switch (opcode) {
121
- case Instruction::FPExt:
122
- // this is exact, so we only need one conversion
123
- assert (SrcTy->isHalfTy ());
124
- Name = " julia__gnu_h2f_ieee" ;
125
- RetTy = Type::getFloatTy (ctx);
126
- break ;
127
- case Instruction::FPTrunc:
128
- assert (DestTy->isHalfTy ());
129
- if (SrcTy->isFloatTy ())
130
- Name = " julia__gnu_f2h_ieee" ;
131
- else if (SrcTy->isDoubleTy ())
132
- Name = " julia__truncdfhf2" ;
133
- break ;
134
- // All F16 fit exactly in Int32 (-65504 to 65504)
135
- case Instruction::FPToSI: JL_FALLTHROUGH;
136
- case Instruction::FPToUI:
137
- assert (SrcTy->isHalfTy ());
138
- Name = " julia__gnu_h2f_ieee" ;
139
- RetTy = Type::getFloatTy (ctx);
140
- break ;
141
- case Instruction::SIToFP: JL_FALLTHROUGH;
142
- case Instruction::UIToFP:
143
- assert (DestTy->isHalfTy ());
144
- Name = " julia__gnu_f2h_ieee" ;
145
- SrcTy = Type::getFloatTy (ctx);
146
- break ;
147
- default :
148
- errs () << Instruction::getOpcodeName (opcode) << ' ' ;
149
- V->getType ()->print (errs ());
150
- errs () << " to " ;
151
- DestTy->print (errs ());
152
- errs () << " is an " ;
153
- llvm_unreachable (" invalid cast" );
154
- }
155
- if (Name.empty ()) {
156
- errs () << Instruction::getOpcodeName (opcode) << ' ' ;
157
- V->getType ()->print (errs ());
158
- errs () << " to " ;
159
- DestTy->print (errs ());
160
- errs () << " is an " ;
161
- llvm_unreachable (" illegal cast" );
162
- }
163
- // Coerce the source to the required size and type
164
- auto T_int16 = Type::getInt16Ty (ctx);
165
- if (SrcTy->isHalfTy ())
166
- SrcTy = T_int16;
167
- if (opcode == Instruction::SIToFP)
168
- V = builder.CreateSIToFP (V, SrcTy);
169
- else if (opcode == Instruction::UIToFP)
170
- V = builder.CreateUIToFP (V, SrcTy);
171
- else
172
- V = builder.CreateBitCast (V, SrcTy);
173
- // Call our intrinsic
174
- if (RetTy->isHalfTy ())
175
- RetTy = T_int16;
176
- auto FT = FunctionType::get (RetTy, {SrcTy}, false );
177
- FunctionCallee F = M.getOrInsertFunction (Name, FT);
178
- Value *I = builder.CreateCall (F, {V});
179
- // Coerce the result to the expected type
180
- if (opcode == Instruction::FPToSI)
181
- I = builder.CreateFPToSI (I, DestTy);
182
- else if (opcode == Instruction::FPToUI)
183
- I = builder.CreateFPToUI (I, DestTy);
184
- else if (opcode == Instruction::FPExt)
185
- I = builder.CreateFPCast (I, DestTy);
186
- else
187
- I = builder.CreateBitCast (I, DestTy);
188
- return I;
189
- }
190
-
191
31
static bool demoteFloat16 (Function &F)
192
32
{
193
33
auto &ctx = F.getContext ();
34
+ auto T_float16 = Type::getHalfTy (ctx);
194
35
auto T_float32 = Type::getFloatTy (ctx);
195
36
196
37
SmallVector<Instruction *, 0 > erase;
197
38
for (auto &BB : F) {
198
39
for (auto &I : BB) {
199
- // extend Float16 operands to Float32
200
- bool Float16 = I.getType ()->getScalarType ()->isHalfTy ();
201
- for (size_t i = 0 ; !Float16 && i < I.getNumOperands (); i++) {
202
- Value *Op = I.getOperand (i);
203
- if (Op->getType ()->getScalarType ()->isHalfTy ())
204
- Float16 = true ;
205
- }
206
- if (!Float16)
207
- continue ;
208
-
209
- if (auto CI = dyn_cast<CastInst>(&I)) {
210
- if (CI->getOpcode () != Instruction::BitCast) { // aka !CI->isNoopCast(DL)
211
- IRBuilder<> builder (&I);
212
- Value *NewI = CreateFPCast (CI->getOpcode (), I.getOperand (0 ), I.getType (), builder);
213
- I.replaceAllUsesWith (NewI);
214
- erase.push_back (&I);
215
- }
216
- continue ;
217
- }
218
-
219
40
switch (I.getOpcode ()) {
220
41
case Instruction::FNeg:
221
42
case Instruction::FAdd:
@@ -226,9 +47,6 @@ static bool demoteFloat16(Function &F)
226
47
case Instruction::FCmp:
227
48
break ;
228
49
default :
229
- if (auto intrinsic = dyn_cast<IntrinsicInst>(&I))
230
- if (intrinsic->getIntrinsicID ())
231
- break ;
232
50
continue ;
233
51
}
234
52
@@ -240,67 +58,61 @@ static bool demoteFloat16(Function &F)
240
58
IRBuilder<> builder (&I);
241
59
242
60
// extend Float16 operands to Float32
243
- // XXX: Calls to llvm.fma.f16 may need to go to f64 to be correct?
61
+ bool OperandsChanged = false ;
244
62
SmallVector<Value *, 2 > Operands (I.getNumOperands ());
245
63
for (size_t i = 0 ; i < I.getNumOperands (); i++) {
246
64
Value *Op = I.getOperand (i);
247
- if (Op->getType ()->getScalarType ()->isHalfTy ()) {
248
- Op = CreateFPCast (Instruction::FPExt, Op, Op->getType ()->getWithNewType (T_float32), builder);
65
+ if (Op->getType () == T_float16) {
66
+ Op = builder.CreateFPExt (Op, T_float32);
67
+ OperandsChanged = true ;
249
68
}
250
69
Operands[i] = (Op);
251
70
}
252
71
253
72
// recreate the instruction if any operands changed,
254
73
// truncating the result back to Float16
255
- Value *NewI;
256
- switch (I.getOpcode ()) {
257
- case Instruction::FNeg:
258
- assert (Operands.size () == 1 );
259
- NewI = builder.CreateFNeg (Operands[0 ]);
260
- break ;
261
- case Instruction::FAdd:
262
- assert (Operands.size () == 2 );
263
- NewI = builder.CreateFAdd (Operands[0 ], Operands[1 ]);
264
- break ;
265
- case Instruction::FSub:
266
- assert (Operands.size () == 2 );
267
- NewI = builder.CreateFSub (Operands[0 ], Operands[1 ]);
268
- break ;
269
- case Instruction::FMul:
270
- assert (Operands.size () == 2 );
271
- NewI = builder.CreateFMul (Operands[0 ], Operands[1 ]);
272
- break ;
273
- case Instruction::FDiv:
274
- assert (Operands.size () == 2 );
275
- NewI = builder.CreateFDiv (Operands[0 ], Operands[1 ]);
276
- break ;
277
- case Instruction::FRem:
278
- assert (Operands.size () == 2 );
279
- NewI = builder.CreateFRem (Operands[0 ], Operands[1 ]);
280
- break ;
281
- case Instruction::FCmp:
282
- assert (Operands.size () == 2 );
283
- NewI = builder.CreateFCmp (cast<FCmpInst>(&I)->getPredicate (),
284
- Operands[0 ], Operands[1 ]);
285
- break ;
286
- default :
287
- if (auto intrinsic = dyn_cast<IntrinsicInst>(&I)) {
288
- // XXX: this is not correct in general
289
- // some obvious failures include llvm.convert.to.fp16.*, llvm.vp.*to*, llvm.experimental.constrained.*to*, llvm.masked.*
290
- Type *RetTy = I.getType ();
291
- if (RetTy->getScalarType ()->isHalfTy ())
292
- RetTy = RetTy->getWithNewType (T_float32);
293
- NewI = replaceIntrinsicWith (intrinsic, RetTy, Operands);
74
+ if (OperandsChanged) {
75
+ Value *NewI;
76
+ switch (I.getOpcode ()) {
77
+ case Instruction::FNeg:
78
+ assert (Operands.size () == 1 );
79
+ NewI = builder.CreateFNeg (Operands[0 ]);
80
+ break ;
81
+ case Instruction::FAdd:
82
+ assert (Operands.size () == 2 );
83
+ NewI = builder.CreateFAdd (Operands[0 ], Operands[1 ]);
84
+ break ;
85
+ case Instruction::FSub:
86
+ assert (Operands.size () == 2 );
87
+ NewI = builder.CreateFSub (Operands[0 ], Operands[1 ]);
88
+ break ;
89
+ case Instruction::FMul:
90
+ assert (Operands.size () == 2 );
91
+ NewI = builder.CreateFMul (Operands[0 ], Operands[1 ]);
92
+ break ;
93
+ case Instruction::FDiv:
94
+ assert (Operands.size () == 2 );
95
+ NewI = builder.CreateFDiv (Operands[0 ], Operands[1 ]);
96
+ break ;
97
+ case Instruction::FRem:
98
+ assert (Operands.size () == 2 );
99
+ NewI = builder.CreateFRem (Operands[0 ], Operands[1 ]);
100
+ break ;
101
+ case Instruction::FCmp:
102
+ assert (Operands.size () == 2 );
103
+ NewI = builder.CreateFCmp (cast<FCmpInst>(&I)->getPredicate (),
104
+ Operands[0 ], Operands[1 ]);
294
105
break ;
106
+ default :
107
+ abort ();
295
108
}
296
- abort ();
109
+ cast<Instruction>(NewI)->copyMetadata (I);
110
+ cast<Instruction>(NewI)->copyFastMathFlags (&I);
111
+ if (NewI->getType () != I.getType ())
112
+ NewI = builder.CreateFPTrunc (NewI, I.getType ());
113
+ I.replaceAllUsesWith (NewI);
114
+ erase.push_back (&I);
297
115
}
298
- cast<Instruction>(NewI)->copyMetadata (I);
299
- cast<Instruction>(NewI)->copyFastMathFlags (&I);
300
- if (NewI->getType () != I.getType ())
301
- NewI = CreateFPCast (Instruction::FPTrunc, NewI, I.getType (), builder);
302
- I.replaceAllUsesWith (NewI);
303
- erase.push_back (&I);
304
116
}
305
117
}
306
118
0 commit comments