20
20
#include " SPIRVSubtarget.h"
21
21
#include " SPIRVTargetMachine.h"
22
22
#include " SPIRVUtils.h"
23
+ #include " llvm/ADT/APInt.h"
24
+ #include " llvm/IR/Constants.h"
25
+ #include " llvm/IR/Type.h"
23
26
#include " llvm/IR/TypedPointerType.h"
27
+ #include " llvm/Support/Casting.h"
28
+ #include < cassert>
24
29
25
30
using namespace llvm ;
26
31
SPIRVGlobalRegistry::SPIRVGlobalRegistry (unsigned PointerSize)
@@ -35,6 +40,15 @@ SPIRVType *SPIRVGlobalRegistry::assignIntTypeToVReg(unsigned BitWidth,
35
40
return SpirvType;
36
41
}
37
42
43
+ SPIRVType *
44
+ SPIRVGlobalRegistry::assignFloatTypeToVReg (unsigned BitWidth, Register VReg,
45
+ MachineInstr &I,
46
+ const SPIRVInstrInfo &TII) {
47
+ SPIRVType *SpirvType = getOrCreateSPIRVFloatType (BitWidth, I, TII);
48
+ assignSPIRVTypeToVReg (SpirvType, VReg, *CurMF);
49
+ return SpirvType;
50
+ }
51
+
38
52
SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg (
39
53
SPIRVType *BaseType, unsigned NumElements, Register VReg, MachineInstr &I,
40
54
const SPIRVInstrInfo &TII) {
@@ -151,6 +165,8 @@ SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
151
165
Register Res = DT.find (CI, CurMF);
152
166
if (!Res.isValid ()) {
153
167
unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth (SpvType) : 32 ;
168
+ // TODO: handle cases where the type is not 32bit wide
169
+ // TODO: https://github.com/llvm/llvm-project/issues/88129
154
170
LLT LLTy = LLT::scalar (32 );
155
171
Res = CurMF->getRegInfo ().createGenericVirtualRegister (LLTy);
156
172
CurMF->getRegInfo ().setRegClass (Res, &SPIRV::IDRegClass);
@@ -164,9 +180,83 @@ SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
164
180
return std::make_tuple (Res, CI, NewInstr);
165
181
}
166
182
183
+ std::tuple<Register, ConstantFP *, bool , unsigned >
184
+ SPIRVGlobalRegistry::getOrCreateConstFloatReg (APFloat Val, SPIRVType *SpvType,
185
+ MachineIRBuilder *MIRBuilder,
186
+ MachineInstr *I,
187
+ const SPIRVInstrInfo *TII) {
188
+ const Type *LLVMFloatTy;
189
+ LLVMContext &Ctx = CurMF->getFunction ().getContext ();
190
+ unsigned BitWidth = 32 ;
191
+ if (SpvType)
192
+ LLVMFloatTy = getTypeForSPIRVType (SpvType);
193
+ else {
194
+ LLVMFloatTy = Type::getFloatTy (Ctx);
195
+ if (MIRBuilder)
196
+ SpvType = getOrCreateSPIRVType (LLVMFloatTy, *MIRBuilder);
197
+ }
198
+ bool NewInstr = false ;
199
+ // Find a constant in DT or build a new one.
200
+ auto *const CI = ConstantFP::get (Ctx, Val);
201
+ Register Res = DT.find (CI, CurMF);
202
+ if (!Res.isValid ()) {
203
+ if (SpvType)
204
+ BitWidth = getScalarOrVectorBitWidth (SpvType);
205
+ // TODO: handle cases where the type is not 32bit wide
206
+ // TODO: https://github.com/llvm/llvm-project/issues/88129
207
+ LLT LLTy = LLT::scalar (32 );
208
+ Res = CurMF->getRegInfo ().createGenericVirtualRegister (LLTy);
209
+ CurMF->getRegInfo ().setRegClass (Res, &SPIRV::IDRegClass);
210
+ if (MIRBuilder)
211
+ assignTypeToVReg (LLVMFloatTy, Res, *MIRBuilder);
212
+ else
213
+ assignFloatTypeToVReg (BitWidth, Res, *I, *TII);
214
+ DT.add (CI, CurMF, Res);
215
+ NewInstr = true ;
216
+ }
217
+ return std::make_tuple (Res, CI, NewInstr, BitWidth);
218
+ }
219
+
220
+ Register SPIRVGlobalRegistry::getOrCreateConstFP (APFloat Val, MachineInstr &I,
221
+ SPIRVType *SpvType,
222
+ const SPIRVInstrInfo &TII,
223
+ bool ZeroAsNull) {
224
+ assert (SpvType);
225
+ ConstantFP *CI;
226
+ Register Res;
227
+ bool New;
228
+ unsigned BitWidth;
229
+ std::tie (Res, CI, New, BitWidth) =
230
+ getOrCreateConstFloatReg (Val, SpvType, nullptr , &I, &TII);
231
+ // If we have found Res register which is defined by the passed G_CONSTANT
232
+ // machine instruction, a new constant instruction should be created.
233
+ if (!New && (!I.getOperand (0 ).isReg () || Res != I.getOperand (0 ).getReg ()))
234
+ return Res;
235
+ MachineInstrBuilder MIB;
236
+ MachineBasicBlock &BB = *I.getParent ();
237
+ // In OpenCL OpConstantNull - Scalar floating point: +0.0 (all bits 0)
238
+ if (Val.isPosZero () && ZeroAsNull) {
239
+ MIB = BuildMI (BB, I, I.getDebugLoc (), TII.get (SPIRV::OpConstantNull))
240
+ .addDef (Res)
241
+ .addUse (getSPIRVTypeID (SpvType));
242
+ } else {
243
+ MIB = BuildMI (BB, I, I.getDebugLoc (), TII.get (SPIRV::OpConstantF))
244
+ .addDef (Res)
245
+ .addUse (getSPIRVTypeID (SpvType));
246
+ addNumImm (
247
+ APInt (BitWidth, CI->getValueAPF ().bitcastToAPInt ().getZExtValue ()),
248
+ MIB);
249
+ }
250
+ const auto &ST = CurMF->getSubtarget ();
251
+ constrainSelectedInstRegOperands (*MIB, *ST.getInstrInfo (),
252
+ *ST.getRegisterInfo (), *ST.getRegBankInfo ());
253
+ return Res;
254
+ }
255
+
167
256
Register SPIRVGlobalRegistry::getOrCreateConstInt (uint64_t Val, MachineInstr &I,
168
257
SPIRVType *SpvType,
169
- const SPIRVInstrInfo &TII) {
258
+ const SPIRVInstrInfo &TII,
259
+ bool ZeroAsNull) {
170
260
assert (SpvType);
171
261
ConstantInt *CI;
172
262
Register Res;
@@ -179,7 +269,7 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
179
269
return Res;
180
270
MachineInstrBuilder MIB;
181
271
MachineBasicBlock &BB = *I.getParent ();
182
- if (Val) {
272
+ if (Val || !ZeroAsNull ) {
183
273
MIB = BuildMI (BB, I, I.getDebugLoc (), TII.get (SPIRV::OpConstantI))
184
274
.addDef (Res)
185
275
.addUse (getSPIRVTypeID (SpvType));
@@ -270,21 +360,46 @@ Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
270
360
return Res;
271
361
}
272
362
273
- Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull (
274
- uint64_t Val, MachineInstr &I, SPIRVType *SpvType,
363
+ Register SPIRVGlobalRegistry::getOrCreateBaseRegister (Constant *Val,
364
+ MachineInstr &I,
365
+ SPIRVType *SpvType,
366
+ const SPIRVInstrInfo &TII,
367
+ unsigned BitWidth) {
368
+ SPIRVType *Type = SpvType;
369
+ if (SpvType->getOpcode () == SPIRV::OpTypeVector ||
370
+ SpvType->getOpcode () == SPIRV::OpTypeArray) {
371
+ auto EleTypeReg = SpvType->getOperand (1 ).getReg ();
372
+ Type = getSPIRVTypeForVReg (EleTypeReg);
373
+ }
374
+ if (Type->getOpcode () == SPIRV::OpTypeFloat) {
375
+ SPIRVType *SpvBaseType = getOrCreateSPIRVFloatType (BitWidth, I, TII);
376
+ return getOrCreateConstFP (dyn_cast<ConstantFP>(Val)->getValue (), I,
377
+ SpvBaseType, TII);
378
+ }
379
+ assert (Type->getOpcode () == SPIRV::OpTypeInt);
380
+ SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType (BitWidth, I, TII);
381
+ return getOrCreateConstInt (Val->getUniqueInteger ().getSExtValue (), I,
382
+ SpvBaseType, TII);
383
+ }
384
+
385
+ Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull (
386
+ Constant *Val, MachineInstr &I, SPIRVType *SpvType,
275
387
const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth,
276
- unsigned ElemCnt) {
388
+ unsigned ElemCnt, bool ZeroAsNull ) {
277
389
// Find a constant vector in DT or build a new one.
278
390
Register Res = DT.find (CA, CurMF);
391
+ // If no values are attached, the composite is null constant.
392
+ bool IsNull = Val->isNullValue () && ZeroAsNull;
279
393
if (!Res.isValid ()) {
280
- SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType (BitWidth, I, TII);
281
394
// SpvScalConst should be created before SpvVecConst to avoid undefined ID
282
395
// error on validation.
283
396
// TODO: can moved below once sorting of types/consts/defs is implemented.
284
397
Register SpvScalConst;
285
- if (Val)
286
- SpvScalConst = getOrCreateConstInt (Val, I, SpvBaseType, TII);
287
- // TODO: maybe use bitwidth of base type.
398
+ if (!IsNull)
399
+ SpvScalConst = getOrCreateBaseRegister (Val, I, SpvType, TII, BitWidth);
400
+
401
+ // TODO: handle cases where the type is not 32bit wide
402
+ // TODO: https://github.com/llvm/llvm-project/issues/88129
288
403
LLT LLTy = LLT::scalar (32 );
289
404
Register SpvVecConst =
290
405
CurMF->getRegInfo ().createGenericVirtualRegister (LLTy);
@@ -293,7 +408,7 @@ Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
293
408
DT.add (CA, CurMF, SpvVecConst);
294
409
MachineInstrBuilder MIB;
295
410
MachineBasicBlock &BB = *I.getParent ();
296
- if (Val ) {
411
+ if (!IsNull ) {
297
412
MIB = BuildMI (BB, I, I.getDebugLoc (), TII.get (SPIRV::OpConstantComposite))
298
413
.addDef (SpvVecConst)
299
414
.addUse (getSPIRVTypeID (SpvType));
@@ -313,20 +428,42 @@ Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
313
428
return Res;
314
429
}
315
430
316
- Register
317
- SPIRVGlobalRegistry::getOrCreateConsIntVector (uint64_t Val, MachineInstr &I,
318
- SPIRVType *SpvType,
319
- const SPIRVInstrInfo &TII) {
431
+ Register SPIRVGlobalRegistry::getOrCreateConstVector (uint64_t Val,
432
+ MachineInstr &I,
433
+ SPIRVType *SpvType,
434
+ const SPIRVInstrInfo &TII,
435
+ bool ZeroAsNull) {
320
436
const Type *LLVMTy = getTypeForSPIRVType (SpvType);
321
437
assert (LLVMTy->isVectorTy ());
322
438
const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
323
439
Type *LLVMBaseTy = LLVMVecTy->getElementType ();
324
- const auto ConstInt = ConstantInt::get (LLVMBaseTy, Val);
325
- auto ConstVec =
326
- ConstantVector::getSplat (LLVMVecTy->getElementCount (), ConstInt);
440
+ assert (LLVMBaseTy->isIntegerTy ());
441
+ auto *ConstVal = ConstantInt::get (LLVMBaseTy, Val);
442
+ auto *ConstVec =
443
+ ConstantVector::getSplat (LLVMVecTy->getElementCount (), ConstVal);
327
444
unsigned BW = getScalarOrVectorBitWidth (SpvType);
328
- return getOrCreateIntCompositeOrNull (Val, I, SpvType, TII, ConstVec, BW,
329
- SpvType->getOperand (2 ).getImm ());
445
+ return getOrCreateCompositeOrNull (ConstVal, I, SpvType, TII, ConstVec, BW,
446
+ SpvType->getOperand (2 ).getImm (),
447
+ ZeroAsNull);
448
+ }
449
+
450
+ Register SPIRVGlobalRegistry::getOrCreateConstVector (APFloat Val,
451
+ MachineInstr &I,
452
+ SPIRVType *SpvType,
453
+ const SPIRVInstrInfo &TII,
454
+ bool ZeroAsNull) {
455
+ const Type *LLVMTy = getTypeForSPIRVType (SpvType);
456
+ assert (LLVMTy->isVectorTy ());
457
+ const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
458
+ Type *LLVMBaseTy = LLVMVecTy->getElementType ();
459
+ assert (LLVMBaseTy->isFloatingPointTy ());
460
+ auto *ConstVal = ConstantFP::get (LLVMBaseTy, Val);
461
+ auto *ConstVec =
462
+ ConstantVector::getSplat (LLVMVecTy->getElementCount (), ConstVal);
463
+ unsigned BW = getScalarOrVectorBitWidth (SpvType);
464
+ return getOrCreateCompositeOrNull (ConstVal, I, SpvType, TII, ConstVec, BW,
465
+ SpvType->getOperand (2 ).getImm (),
466
+ ZeroAsNull);
330
467
}
331
468
332
469
Register
@@ -337,13 +474,13 @@ SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val, MachineInstr &I,
337
474
assert (LLVMTy->isArrayTy ());
338
475
const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
339
476
Type *LLVMBaseTy = LLVMArrTy->getElementType ();
340
- const auto ConstInt = ConstantInt::get (LLVMBaseTy, Val);
341
- auto ConstArr =
477
+ auto * ConstInt = ConstantInt::get (LLVMBaseTy, Val);
478
+ auto * ConstArr =
342
479
ConstantArray::get (const_cast <ArrayType *>(LLVMArrTy), {ConstInt});
343
480
SPIRVType *SpvBaseTy = getSPIRVTypeForVReg (SpvType->getOperand (1 ).getReg ());
344
481
unsigned BW = getScalarOrVectorBitWidth (SpvBaseTy);
345
- return getOrCreateIntCompositeOrNull (Val , I, SpvType, TII, ConstArr, BW,
346
- LLVMArrTy->getNumElements ());
482
+ return getOrCreateCompositeOrNull (ConstInt , I, SpvType, TII, ConstArr, BW,
483
+ LLVMArrTy->getNumElements ());
347
484
}
348
485
349
486
Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull (
@@ -1093,21 +1230,48 @@ SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
1093
1230
return SpirvType;
1094
1231
}
1095
1232
1096
- SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType (
1097
- unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
1098
- Type *LLVMTy = IntegerType::get (CurMF->getFunction ().getContext (), BitWidth);
1233
+ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType (unsigned BitWidth,
1234
+ MachineInstr &I,
1235
+ const SPIRVInstrInfo &TII,
1236
+ unsigned SPIRVOPcode,
1237
+ Type *LLVMTy) {
1099
1238
Register Reg = DT.find (LLVMTy, CurMF);
1100
1239
if (Reg.isValid ())
1101
1240
return getSPIRVTypeForVReg (Reg);
1102
1241
MachineBasicBlock &BB = *I.getParent ();
1103
- auto MIB = BuildMI (BB, I, I.getDebugLoc (), TII.get (SPIRV::OpTypeInt ))
1242
+ auto MIB = BuildMI (BB, I, I.getDebugLoc (), TII.get (SPIRVOPcode ))
1104
1243
.addDef (createTypeVReg (CurMF->getRegInfo ()))
1105
1244
.addImm (BitWidth)
1106
1245
.addImm (0 );
1107
1246
DT.add (LLVMTy, CurMF, getSPIRVTypeID (MIB));
1108
1247
return finishCreatingSPIRVType (LLVMTy, MIB);
1109
1248
}
1110
1249
1250
+ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType (
1251
+ unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
1252
+ Type *LLVMTy = IntegerType::get (CurMF->getFunction ().getContext (), BitWidth);
1253
+ return getOrCreateSPIRVType (BitWidth, I, TII, SPIRV::OpTypeInt, LLVMTy);
1254
+ }
1255
+ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType (
1256
+ unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
1257
+ LLVMContext &Ctx = CurMF->getFunction ().getContext ();
1258
+ Type *LLVMTy;
1259
+ switch (BitWidth) {
1260
+ case 16 :
1261
+ LLVMTy = Type::getHalfTy (Ctx);
1262
+ break ;
1263
+ case 32 :
1264
+ LLVMTy = Type::getFloatTy (Ctx);
1265
+ break ;
1266
+ case 64 :
1267
+ LLVMTy = Type::getDoubleTy (Ctx);
1268
+ break ;
1269
+ default :
1270
+ llvm_unreachable (" Bit width is of unexpected size." );
1271
+ }
1272
+ return getOrCreateSPIRVType (BitWidth, I, TII, SPIRV::OpTypeFloat, LLVMTy);
1273
+ }
1274
+
1111
1275
SPIRVType *
1112
1276
SPIRVGlobalRegistry::getOrCreateSPIRVBoolType (MachineIRBuilder &MIRBuilder) {
1113
1277
return getOrCreateSPIRVType (
0 commit comments