@@ -88,6 +88,8 @@ namespace {
88
88
// A list of memory references (within a BB) with the distance to the beginning of the BB.
89
89
typedef std::vector<std::pair<Instruction*, unsigned > > MemRefListTy;
90
90
typedef std::vector<Instruction*> TrivialMemRefListTy;
91
+ // ALoadInst, Offset, MemRefListTy::iterator
92
+ typedef SmallVector<std::tuple<Instruction *, int64_t , MemRefListTy::iterator>, 8 > MergeVector;
91
93
92
94
public:
93
95
static char ID;
@@ -134,6 +136,10 @@ namespace {
134
136
Value* getShuffle (Value* ShflId, Instruction* BlockReadToOptimize,
135
137
Value* SgId, llvm::IRBuilder<>& Builder, unsigned & ToOptSize);
136
138
139
+ unsigned getNumElements (Type* Ty) {
140
+ return Ty->isVectorTy () ? (unsigned )cast<IGCLLVM::FixedVectorType>(Ty)->getNumElements () : 1 ;
141
+ }
142
+
137
143
// Yields the element type when Ty is a vector type; any other type is
// handed back unchanged.
Type* getVectorElementType(Type* Ty) const {
    if (!isa<VectorType>(Ty))
        return Ty;
    return cast<VectorType>(Ty)->getElementType();
}
@@ -145,6 +151,95 @@ namespace {
145
151
return DL->getTypeStoreSize (A) == DL->getTypeStoreSize (B);
146
152
}
147
153
154
+ Value* createBitOrPointerCast (Value* V, Type* DestTy,
155
+ IGCIRBuilder<>& Builder) {
156
+ if (V->getType () == DestTy)
157
+ return V;
158
+
159
+ if (V->getType ()->isPointerTy () && DestTy->isPointerTy ()) {
160
+ PointerType* SrcPtrTy = cast<PointerType>(V->getType ());
161
+ PointerType* DstPtrTy = cast<PointerType>(DestTy);
162
+ if (SrcPtrTy->getPointerAddressSpace () !=
163
+ DstPtrTy->getPointerAddressSpace ())
164
+ return Builder.CreateAddrSpaceCast (V, DestTy);
165
+ }
166
+
167
+ if (V->getType ()->isPointerTy ()) {
168
+ if (DestTy->isIntegerTy ()) {
169
+ return Builder.CreatePtrToInt (V, DestTy);
170
+ }
171
+ else if (DestTy->isFloatingPointTy ()) {
172
+ uint32_t Size = (uint32_t )DestTy->getPrimitiveSizeInBits ();
173
+ Value* Cast = Builder.CreatePtrToInt (
174
+ V, Builder.getIntNTy (Size));
175
+ return Builder.CreateBitCast (Cast, DestTy);
176
+ }
177
+ }
178
+
179
+ if (DestTy->isPointerTy ()) {
180
+ if (V->getType ()->isIntegerTy ()) {
181
+ return Builder.CreateIntToPtr (V, DestTy);
182
+ }
183
+ else if (V->getType ()->isFloatingPointTy ()) {
184
+ uint32_t Size = (uint32_t )V->getType ()->getPrimitiveSizeInBits ();
185
+ Value* Cast = Builder.CreateBitCast (
186
+ V, Builder.getIntNTy (Size));
187
+ return Builder.CreateIntToPtr (Cast, DestTy);
188
+ }
189
+ }
190
+
191
+ return Builder.CreateBitCast (V, DestTy);
192
+ }
193
+
194
+ /* *
195
+ * @brief Creates a new merge value for merged load from a set of predicated loads' merge values.
196
+ *
197
+ * This function constructs a new combined merge value by merging the merge values of multiple predicated load intrinsics.
198
+ * Merge value from each input predicated load is inserted into the appropriate position in the resulting merge vector value,
199
+ * based on its offset and the scalar size. The function handles both scalar and vector merge input values.
200
+ *
201
+ * @param MergeValTy The type of the merged value to be created.
202
+ * @param LoadsToMerge A vector of tuples, each containing a load instruction and its associated offset.
203
+ * @param LdScalarSize The size (in bytes) of the scalar element being loaded in the combined load.
204
+ * @param NumElts Number of elements in the merged value vector.
205
+ * @return Value* The newly created merged value, or nullptr if we are merging generic loads, not predicated.
206
+ */
207
+ Value* CreateNewMergeValue (IGCIRBuilder<>& Builder, Type* MergeValTy,
208
+ const MergeVector& LoadsToMerge, unsigned LdScalarSize,
209
+ unsigned & NumElts) {
210
+ Value* NewMergeValue = UndefValue::get (MergeValTy);
211
+ unsigned Pos = 0 ;
212
+ int64_t FirstOffset = std::get<1 >(LoadsToMerge.front ());
213
+
214
+ for (auto & I : LoadsToMerge) {
215
+ PredicatedLoadIntrinsic* PLI = ALoadInst::get (std::get<0 >(I))->getPredicatedLoadIntrinsic ();
216
+ if (!PLI)
217
+ return nullptr ;
218
+
219
+ Value* MergeValue = PLI->getMergeValue ();
220
+ unsigned MergeValNumElements = getNumElements (MergeValue->getType ());
221
+ Type* MergeValScalarTy = MergeValTy->getScalarType ();
222
+ Pos = unsigned ((std::get<1 >(I) - FirstOffset) / LdScalarSize);
223
+
224
+ if (MergeValNumElements == 1 ) {
225
+ IGC_ASSERT_MESSAGE (Pos < NumElts, " Index is larger than the number of elements, we cannot update merge value." );
226
+ MergeValue = createBitOrPointerCast (MergeValue, MergeValScalarTy, Builder);
227
+ NewMergeValue = Builder.CreateInsertElement (NewMergeValue, MergeValue, Builder.getInt32 (Pos));
228
+ continue ;
229
+ }
230
+
231
+ IGC_ASSERT_MESSAGE (Pos + MergeValNumElements <= NumElts,
232
+ " Index is larger than the number of elements, we cannot update merge value." );
233
+
234
+ for (unsigned i = 0 ; i < MergeValNumElements; ++i) {
235
+ Value* ExtractValue = Builder.CreateExtractElement (MergeValue, Builder.getInt32 (i));
236
+ ExtractValue = createBitOrPointerCast (ExtractValue, MergeValScalarTy, Builder);
237
+ NewMergeValue = Builder.CreateInsertElement (NewMergeValue, ExtractValue, Builder.getInt32 (Pos + i));
238
+ }
239
+ }
240
+ return NewMergeValue;
241
+ }
242
+
148
243
bool isSafeToMergeLoad (const ALoadInst& Ld,
149
244
const SmallVectorImpl<Instruction*>& checkList) const ;
150
245
bool isSafeToMergeStores (
@@ -644,76 +739,6 @@ bool MemOpt::removeRedBlockRead(GenIntrinsicInst* LeadingBlockRead,
644
739
return true ;
645
740
}
646
741
647
- namespace {
648
- unsigned getNumElements (Type* Ty) {
649
- return Ty->isVectorTy () ? (unsigned )cast<IGCLLVM::FixedVectorType>(Ty)->getNumElements () : 1 ;
650
- }
651
-
652
- template <typename T>
653
- Value* createBitOrPointerCast (Value* V, Type* DestTy,
654
- IGCIRBuilder<T>& Builder) {
655
- if (V->getType () == DestTy)
656
- return V;
657
- if (V->getType ()->isPointerTy () && DestTy->isPointerTy ()) {
658
- PointerType* SrcPtrTy = cast<PointerType>(V->getType ());
659
- PointerType* DstPtrTy = cast<PointerType>(DestTy);
660
- if (SrcPtrTy->getPointerAddressSpace () !=
661
- DstPtrTy->getPointerAddressSpace ())
662
- return Builder.CreateAddrSpaceCast (V, DestTy);
663
- }
664
- if (V->getType ()->isPointerTy ()) {
665
- if (DestTy->isIntegerTy ()) {
666
- return Builder.CreatePtrToInt (V, DestTy);
667
- }
668
- else if (DestTy->isFloatingPointTy ()) {
669
- uint32_t Size = (uint32_t )DestTy->getPrimitiveSizeInBits ();
670
- Value* Cast = Builder.CreatePtrToInt (
671
- V, Builder.getIntNTy (Size));
672
- return Builder.CreateBitCast (Cast, DestTy);
673
- }
674
- }
675
- if (DestTy->isPointerTy ()) {
676
- if (V->getType ()->isIntegerTy ()) {
677
- return Builder.CreateIntToPtr (V, DestTy);
678
- }
679
- else if (V->getType ()->isFloatingPointTy ()) {
680
- uint32_t Size = (uint32_t )V->getType ()->getPrimitiveSizeInBits ();
681
- Value* Cast = Builder.CreateBitCast (
682
- V, Builder.getIntNTy (Size));
683
- return Builder.CreateIntToPtr (Cast, DestTy);
684
- }
685
- }
686
- return Builder.CreateBitCast (V, DestTy);
687
- }
688
-
689
- template <typename T>
690
- Value* CreateNewMergeValue (IGCIRBuilder<T>& Builder, Type* Ty,
691
- const SmallVector<Instruction *>& LoadsToMerge) {
692
- Value* NewMergeValue = UndefValue::get (Ty);
693
- unsigned idx = 0 ;
694
- for (auto * Inst : LoadsToMerge) {
695
- PredicatedLoadIntrinsic* PLI = ALoadInst::get (Inst)->getPredicatedLoadIntrinsic ();
696
- if (!PLI)
697
- return nullptr ;
698
- Value* MergeValue = PLI->getMergeValue ();
699
- unsigned NumElements = getNumElements (MergeValue->getType ());
700
- Type* ScalarTy = Ty->getScalarType ();
701
- if (NumElements == 1 ) {
702
- MergeValue = createBitOrPointerCast (MergeValue, ScalarTy, Builder);
703
- NewMergeValue = Builder.CreateInsertElement (NewMergeValue, MergeValue, Builder.getInt32 (idx++));
704
- continue ;
705
- }
706
-
707
- for (unsigned i = 0 ; i < NumElements; ++i) {
708
- Value* ExtractValue = Builder.CreateExtractElement (MergeValue, Builder.getInt32 (i));
709
- ExtractValue = createBitOrPointerCast (ExtractValue, ScalarTy, Builder);
710
- NewMergeValue = Builder.CreateInsertElement (NewMergeValue, ExtractValue, Builder.getInt32 (idx++));
711
- }
712
- }
713
- return NewMergeValue;
714
- }
715
- }
716
-
717
742
// Removes redundant blockread if both blockreads are scalar.
718
743
void MemOpt::removeScalarBlockRead (Instruction* BlockReadToOptimize,
719
744
Instruction* BlockReadToRemove, Value* SgId,
@@ -1151,8 +1176,7 @@ bool MemOpt::mergeLoad(ALoadInst& LeadingLoad,
1151
1176
}
1152
1177
1153
1178
// ALoadInst, Offset, MemRefListTy::iterator
1154
- SmallVector<std::tuple<Instruction *, int64_t , MemRefListTy::iterator>, 8 >
1155
- LoadsToMerge;
1179
+ MergeVector LoadsToMerge;
1156
1180
LoadsToMerge.push_back (std::make_tuple (LeadingLoad.inst (), 0 , MI));
1157
1181
1158
1182
// Loads to be merged are scanned in the program order and will be merged into
@@ -1375,10 +1399,8 @@ bool MemOpt::mergeLoad(ALoadInst& LeadingLoad,
1375
1399
Value* NewPointer = Builder.CreateBitCast (Ptr, NewPointerType);
1376
1400
1377
1401
// Prepare Merge Value if needed:
1378
- SmallVector<Instruction*> LoadsToMergeInsts;
1379
- for (auto & I : LoadsToMerge)
1380
- LoadsToMergeInsts.push_back (std::get<0 >(I));
1381
- Value* NewMergeValue = CreateNewMergeValue (Builder, NewLoadType, LoadsToMergeInsts);
1402
+ Value* NewMergeValue = CreateNewMergeValue (Builder, NewLoadType, LoadsToMerge,
1403
+ LdScalarSize, NumElts);
1382
1404
1383
1405
Instruction* NewLoad =
1384
1406
FirstLoad.CreateAlignedLoad (Builder, NewLoadType, NewPointer, NewMergeValue);
@@ -3486,7 +3508,7 @@ void LdStCombine::combineLoads()
3486
3508
++numInsts;
3487
3509
3488
3510
// cannot merge beyond fence or window limit
3489
- if ((I->isFenceLike () && !isa<PredicatedLoadIntrinsic>(I)) || numInsts > LDWINDOWSIZE) {
3511
+ if ((I->isFenceLike () && !isa<PredicatedLoadIntrinsic>(I) && !isa<PredicatedStoreIntrinsic>(I) ) || numInsts > LDWINDOWSIZE) {
3490
3512
LLVM_DEBUG (dbgs () << " - - Stop at fence or window limit\n " );
3491
3513
break ;
3492
3514
}
@@ -3772,6 +3794,7 @@ void LdStCombine::createBundles(BasicBlock* BB, InstAndOffsetPairs& LoadStores)
3772
3794
{
3773
3795
// 1. The first one is the leading store.
3774
3796
const LdStInfo* leadLSI = &LoadStores[i];
3797
+ LLVM_DEBUG (llvm::dbgs () << " Try leading LdSt: " << *leadLSI->getInst () << " \n " );
3775
3798
if (isBundled (leadLSI, m_combinedInsts) ||
3776
3799
(i+1 ) == SZ) /* skip for last one */ {
3777
3800
++i;
@@ -3780,7 +3803,7 @@ void LdStCombine::createBundles(BasicBlock* BB, InstAndOffsetPairs& LoadStores)
3780
3803
3781
3804
if (m_WI && isUniform &&
3782
3805
!m_WI->isUniform (leadLSI->getValueOperand ())) {
3783
- // no combining for *uniform-ptr = non-uniform value
3806
+ LLVM_DEBUG ( llvm::dbgs () << " No combining for *uniform-ptr = non-uniform value\n " );
3784
3807
++i;
3785
3808
continue ;
3786
3809
}
@@ -3816,6 +3839,7 @@ void LdStCombine::createBundles(BasicBlock* BB, InstAndOffsetPairs& LoadStores)
3816
3839
uint32_t vecSize = -1 ;
3817
3840
for (int j = i + 1 ; j < SZ; ++j) {
3818
3841
const LdStInfo* LSI = &LoadStores[j];
3842
+ LLVM_DEBUG (llvm::dbgs () << " Try to make bundle with: " << *LSI->getInst () << " \n " );
3819
3843
if (isBundled (LSI, m_combinedInsts) ||
3820
3844
(leadLSI->getByteOffset () + totalBytes) != LSI->getByteOffset ())
3821
3845
{
@@ -3824,7 +3848,7 @@ void LdStCombine::createBundles(BasicBlock* BB, InstAndOffsetPairs& LoadStores)
3824
3848
}
3825
3849
if (m_WI && isUniform &&
3826
3850
!m_WI->isUniform (LSI->getValueOperand ())) {
3827
- // no combining for *uniform-ptr = non-uniform value
3851
+ LLVM_DEBUG ( llvm::dbgs () << " No combining for *uniform-ptr = non-uniform value\n " );
3828
3852
break ;
3829
3853
}
3830
3854
0 commit comments