@@ -45,194 +45,15 @@ INST_STATISTIC(FCmp);
45
45
46
46
namespace {
47
47
48
- inline AttributeSet getFnAttrs (const AttributeList &Attrs)
49
- {
50
- #if JL_LLVM_VERSION >= 140000
51
- return Attrs.getFnAttrs ();
52
- #else
53
- return Attrs.getFnAttributes ();
54
- #endif
55
- }
56
-
57
- inline AttributeSet getRetAttrs (const AttributeList &Attrs)
58
- {
59
- #if JL_LLVM_VERSION >= 140000
60
- return Attrs.getRetAttrs ();
61
- #else
62
- return Attrs.getRetAttributes ();
63
- #endif
64
- }
65
-
66
- static Instruction *replaceIntrinsicWith (IntrinsicInst *call, Type *RetTy, ArrayRef<Value*> args)
67
- {
68
- Intrinsic::ID ID = call->getIntrinsicID ();
69
- assert (ID);
70
- auto oldfType = call->getFunctionType ();
71
- auto nargs = oldfType->getNumParams ();
72
- assert (args.size () > nargs);
73
- SmallVector<Type*, 8 > argTys (nargs);
74
- for (unsigned i = 0 ; i < nargs; i++)
75
- argTys[i] = args[i]->getType ();
76
- auto newfType = FunctionType::get (RetTy, argTys, oldfType->isVarArg ());
77
-
78
- // Accumulate an array of overloaded types for the given intrinsic
79
- // and compute the new name mangling schema
80
- SmallVector<Type*, 4 > overloadTys;
81
- {
82
- SmallVector<Intrinsic::IITDescriptor, 8 > Table;
83
- getIntrinsicInfoTableEntries (ID, Table);
84
- ArrayRef<Intrinsic::IITDescriptor> TableRef = Table;
85
- auto res = Intrinsic::matchIntrinsicSignature (newfType, TableRef, overloadTys);
86
- assert (res == Intrinsic::MatchIntrinsicTypes_Match);
87
- (void )res;
88
- bool matchvararg = !Intrinsic::matchIntrinsicVarArg (newfType->isVarArg (), TableRef);
89
- assert (matchvararg);
90
- (void )matchvararg;
91
- }
92
- auto newF = Intrinsic::getDeclaration (call->getModule (), ID, overloadTys);
93
- assert (newF->getFunctionType () == newfType);
94
- newF->setCallingConv (call->getCallingConv ());
95
- assert (args.back () == call->getCalledFunction ());
96
- auto newCall = CallInst::Create (newF, args.drop_back (), " " , call);
97
- newCall->setTailCallKind (call->getTailCallKind ());
98
- auto old_attrs = call->getAttributes ();
99
- newCall->setAttributes (AttributeList::get (call->getContext (), getFnAttrs (old_attrs),
100
- getRetAttrs (old_attrs), {})); // drop parameter attributes
101
- return newCall;
102
- }
103
-
104
-
105
- static Value* CreateFPCast (Instruction::CastOps opcode, Value *V, Type *DestTy, IRBuilder<> &builder)
106
- {
107
- Type *SrcTy = V->getType ();
108
- Type *RetTy = DestTy;
109
- if (auto *VC = dyn_cast<Constant>(V)) {
110
- // The input IR often has things of the form
111
- // fcmp olt half %0, 0xH7C00
112
- // and we would like to avoid turning that constant into a call here
113
- // if we can simply constant fold it to the new type.
114
- VC = ConstantExpr::getCast (opcode, VC, DestTy, true );
115
- if (VC)
116
- return VC;
117
- }
118
- assert (SrcTy->isVectorTy () == DestTy->isVectorTy ());
119
- if (SrcTy->isVectorTy ()) {
120
- unsigned NumElems = cast<FixedVectorType>(SrcTy)->getNumElements ();
121
- assert (cast<FixedVectorType>(DestTy)->getNumElements () == NumElems && " Mismatched cast" );
122
- Value *NewV = UndefValue::get (DestTy);
123
- RetTy = RetTy->getScalarType ();
124
- for (unsigned i = 0 ; i < NumElems; ++i) {
125
- Value *I = builder.getInt32 (i);
126
- Value *Vi = builder.CreateExtractElement (V, I);
127
- Vi = CreateFPCast (opcode, Vi, RetTy, builder);
128
- NewV = builder.CreateInsertElement (NewV, Vi, I);
129
- }
130
- return NewV;
131
- }
132
- auto &M = *builder.GetInsertBlock ()->getModule ();
133
- auto &ctx = M.getContext ();
134
- // Pick the Function to call in the Julia runtime
135
- StringRef Name;
136
- switch (opcode) {
137
- case Instruction::FPExt:
138
- // this is exact, so we only need one conversion
139
- assert (SrcTy->isHalfTy ());
140
- Name = " julia__gnu_h2f_ieee" ;
141
- RetTy = Type::getFloatTy (ctx);
142
- break ;
143
- case Instruction::FPTrunc:
144
- assert (DestTy->isHalfTy ());
145
- if (SrcTy->isFloatTy ())
146
- Name = " julia__gnu_f2h_ieee" ;
147
- else if (SrcTy->isDoubleTy ())
148
- Name = " julia__truncdfhf2" ;
149
- break ;
150
- // All F16 fit exactly in Int32 (-65504 to 65504)
151
- case Instruction::FPToSI: JL_FALLTHROUGH;
152
- case Instruction::FPToUI:
153
- assert (SrcTy->isHalfTy ());
154
- Name = " julia__gnu_h2f_ieee" ;
155
- RetTy = Type::getFloatTy (ctx);
156
- break ;
157
- case Instruction::SIToFP: JL_FALLTHROUGH;
158
- case Instruction::UIToFP:
159
- assert (DestTy->isHalfTy ());
160
- Name = " julia__gnu_f2h_ieee" ;
161
- SrcTy = Type::getFloatTy (ctx);
162
- break ;
163
- default :
164
- errs () << Instruction::getOpcodeName (opcode) << ' ' ;
165
- V->getType ()->print (errs ());
166
- errs () << " to " ;
167
- DestTy->print (errs ());
168
- errs () << " is an " ;
169
- llvm_unreachable (" invalid cast" );
170
- }
171
- if (Name.empty ()) {
172
- errs () << Instruction::getOpcodeName (opcode) << ' ' ;
173
- V->getType ()->print (errs ());
174
- errs () << " to " ;
175
- DestTy->print (errs ());
176
- errs () << " is an " ;
177
- llvm_unreachable (" illegal cast" );
178
- }
179
- // Coerce the source to the required size and type
180
- auto T_int16 = Type::getInt16Ty (ctx);
181
- if (SrcTy->isHalfTy ())
182
- SrcTy = T_int16;
183
- if (opcode == Instruction::SIToFP)
184
- V = builder.CreateSIToFP (V, SrcTy);
185
- else if (opcode == Instruction::UIToFP)
186
- V = builder.CreateUIToFP (V, SrcTy);
187
- else
188
- V = builder.CreateBitCast (V, SrcTy);
189
- // Call our intrinsic
190
- if (RetTy->isHalfTy ())
191
- RetTy = T_int16;
192
- auto FT = FunctionType::get (RetTy, {SrcTy}, false );
193
- FunctionCallee F = M.getOrInsertFunction (Name, FT);
194
- Value *I = builder.CreateCall (F, {V});
195
- // Coerce the result to the expected type
196
- if (opcode == Instruction::FPToSI)
197
- I = builder.CreateFPToSI (I, DestTy);
198
- else if (opcode == Instruction::FPToUI)
199
- I = builder.CreateFPToUI (I, DestTy);
200
- else if (opcode == Instruction::FPExt)
201
- I = builder.CreateFPCast (I, DestTy);
202
- else
203
- I = builder.CreateBitCast (I, DestTy);
204
- return I;
205
- }
206
-
207
48
static bool demoteFloat16 (Function &F)
208
49
{
209
50
auto &ctx = F.getContext ();
51
+ auto T_float16 = Type::getHalfTy (ctx);
210
52
auto T_float32 = Type::getFloatTy (ctx);
211
53
212
54
SmallVector<Instruction *, 0 > erase;
213
55
for (auto &BB : F) {
214
56
for (auto &I : BB) {
215
- // extend Float16 operands to Float32
216
- bool Float16 = I.getType ()->getScalarType ()->isHalfTy ();
217
- for (size_t i = 0 ; !Float16 && i < I.getNumOperands (); i++) {
218
- Value *Op = I.getOperand (i);
219
- if (Op->getType ()->getScalarType ()->isHalfTy ())
220
- Float16 = true ;
221
- }
222
- if (!Float16)
223
- continue ;
224
-
225
- if (auto CI = dyn_cast<CastInst>(&I)) {
226
- if (CI->getOpcode () != Instruction::BitCast) { // aka !CI->isNoopCast(DL)
227
- ++TotalChanged;
228
- IRBuilder<> builder (&I);
229
- Value *NewI = CreateFPCast (CI->getOpcode (), I.getOperand (0 ), I.getType (), builder);
230
- I.replaceAllUsesWith (NewI);
231
- erase.push_back (&I);
232
- }
233
- continue ;
234
- }
235
-
236
57
switch (I.getOpcode ()) {
237
58
case Instruction::FNeg:
238
59
case Instruction::FAdd:
@@ -243,9 +64,6 @@ static bool demoteFloat16(Function &F)
243
64
case Instruction::FCmp:
244
65
break ;
245
66
default :
246
- if (auto intrinsic = dyn_cast<IntrinsicInst>(&I))
247
- if (intrinsic->getIntrinsicID ())
248
- break ;
249
67
continue ;
250
68
}
251
69
@@ -257,78 +75,72 @@ static bool demoteFloat16(Function &F)
257
75
IRBuilder<> builder (&I);
258
76
259
77
// extend Float16 operands to Float32
260
- // XXX: Calls to llvm.fma.f16 may need to go to f64 to be correct?
78
+ bool OperandsChanged = false ;
261
79
SmallVector<Value *, 2 > Operands (I.getNumOperands ());
262
80
for (size_t i = 0 ; i < I.getNumOperands (); i++) {
263
81
Value *Op = I.getOperand (i);
264
- if (Op->getType ()-> getScalarType ()-> isHalfTy () ) {
82
+ if (Op->getType () == T_float16 ) {
265
83
++TotalExt;
266
- Op = CreateFPCast (Instruction::FPExt, Op, Op->getType ()->getWithNewType (T_float32), builder);
84
+ Op = builder.CreateFPExt (Op, T_float32);
85
+ OperandsChanged = true ;
267
86
}
268
87
Operands[i] = (Op);
269
88
}
270
89
271
90
// recreate the instruction if any operands changed,
272
91
// truncating the result back to Float16
273
- Value *NewI;
274
- ++TotalChanged;
275
- switch (I.getOpcode ()) {
276
- case Instruction::FNeg:
277
- assert (Operands.size () == 1 );
278
- ++FNegChanged;
279
- NewI = builder.CreateFNeg (Operands[0 ]);
280
- break ;
281
- case Instruction::FAdd:
282
- assert (Operands.size () == 2 );
283
- ++FAddChanged;
284
- NewI = builder.CreateFAdd (Operands[0 ], Operands[1 ]);
285
- break ;
286
- case Instruction::FSub:
287
- assert (Operands.size () == 2 );
288
- ++FSubChanged;
289
- NewI = builder.CreateFSub (Operands[0 ], Operands[1 ]);
290
- break ;
291
- case Instruction::FMul:
292
- assert (Operands.size () == 2 );
293
- ++FMulChanged;
294
- NewI = builder.CreateFMul (Operands[0 ], Operands[1 ]);
295
- break ;
296
- case Instruction::FDiv:
297
- assert (Operands.size () == 2 );
298
- ++FDivChanged;
299
- NewI = builder.CreateFDiv (Operands[0 ], Operands[1 ]);
300
- break ;
301
- case Instruction::FRem:
302
- assert (Operands.size () == 2 );
303
- ++FRemChanged;
304
- NewI = builder.CreateFRem (Operands[0 ], Operands[1 ]);
305
- break ;
306
- case Instruction::FCmp:
307
- assert (Operands.size () == 2 );
308
- ++FCmpChanged;
309
- NewI = builder.CreateFCmp (cast<FCmpInst>(&I)->getPredicate (),
310
- Operands[0 ], Operands[1 ]);
311
- break ;
312
- default :
313
- if (auto intrinsic = dyn_cast<IntrinsicInst>(&I)) {
314
- // XXX: this is not correct in general
315
- // some obvious failures include llvm.convert.to.fp16.*, llvm.vp.*to*, llvm.experimental.constrained.*to*, llvm.masked.*
316
- Type *RetTy = I.getType ();
317
- if (RetTy->getScalarType ()->isHalfTy ())
318
- RetTy = RetTy->getWithNewType (T_float32);
319
- NewI = replaceIntrinsicWith (intrinsic, RetTy, Operands);
92
+ if (OperandsChanged) {
93
+ Value *NewI;
94
+ ++TotalChanged;
95
+ switch (I.getOpcode ()) {
96
+ case Instruction::FNeg:
97
+ assert (Operands.size () == 1 );
98
+ ++FNegChanged;
99
+ NewI = builder.CreateFNeg (Operands[0 ]);
100
+ break ;
101
+ case Instruction::FAdd:
102
+ assert (Operands.size () == 2 );
103
+ ++FAddChanged;
104
+ NewI = builder.CreateFAdd (Operands[0 ], Operands[1 ]);
105
+ break ;
106
+ case Instruction::FSub:
107
+ assert (Operands.size () == 2 );
108
+ ++FSubChanged;
109
+ NewI = builder.CreateFSub (Operands[0 ], Operands[1 ]);
320
110
break ;
111
+ case Instruction::FMul:
112
+ assert (Operands.size () == 2 );
113
+ ++FMulChanged;
114
+ NewI = builder.CreateFMul (Operands[0 ], Operands[1 ]);
115
+ break ;
116
+ case Instruction::FDiv:
117
+ assert (Operands.size () == 2 );
118
+ ++FDivChanged;
119
+ NewI = builder.CreateFDiv (Operands[0 ], Operands[1 ]);
120
+ break ;
121
+ case Instruction::FRem:
122
+ assert (Operands.size () == 2 );
123
+ ++FRemChanged;
124
+ NewI = builder.CreateFRem (Operands[0 ], Operands[1 ]);
125
+ break ;
126
+ case Instruction::FCmp:
127
+ assert (Operands.size () == 2 );
128
+ ++FCmpChanged;
129
+ NewI = builder.CreateFCmp (cast<FCmpInst>(&I)->getPredicate (),
130
+ Operands[0 ], Operands[1 ]);
131
+ break ;
132
+ default :
133
+ abort ();
321
134
}
322
- abort ();
323
- }
324
- cast<Instruction>(NewI)->copyMetadata (I);
325
- cast<Instruction>(NewI)->copyFastMathFlags (&I);
326
- if (NewI->getType () != I.getType ()) {
327
- ++TotalTrunc;
328
- NewI = CreateFPCast (Instruction::FPTrunc, NewI, I.getType (), builder);
135
+ cast<Instruction>(NewI)->copyMetadata (I);
136
+ cast<Instruction>(NewI)->copyFastMathFlags (&I);
137
+ if (NewI->getType () != I.getType ()) {
138
+ ++TotalTrunc;
139
+ NewI = builder.CreateFPTrunc (NewI, I.getType ());
140
+ }
141
+ I.replaceAllUsesWith (NewI);
142
+ erase.push_back (&I);
329
143
}
330
- I.replaceAllUsesWith (NewI);
331
- erase.push_back (&I);
332
144
}
333
145
}
334
146
0 commit comments