Skip to content

Commit 4df00ae

Browse files
YuriPlyakhinigcbot
authored andcommitted
Fix bugs, add more lit tests for predicated memory optimizations
Adds more lit tests to verify the functionality of the predicated memory optimizations. Fixes a few bugs found during the testing.
1 parent 39945e0 commit 4df00ae

9 files changed

+655
-108
lines changed

IGC/Compiler/CISACodeGen/MemOpt.cpp

Lines changed: 103 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ namespace {
8888
// A list of memory references (within a BB) with the distance to the beginning of the BB.
8989
typedef std::vector<std::pair<Instruction*, unsigned> > MemRefListTy;
9090
typedef std::vector<Instruction*> TrivialMemRefListTy;
91+
// Tuple fields: ALoadInst, Offset, MemRefListTy::iterator
92+
typedef SmallVector<std::tuple<Instruction *, int64_t, MemRefListTy::iterator>, 8> MergeVector;
9193

9294
public:
9395
static char ID;
@@ -134,6 +136,10 @@ namespace {
134136
Value* getShuffle(Value* ShflId, Instruction* BlockReadToOptimize,
135137
Value* SgId, llvm::IRBuilder<>& Builder, unsigned& ToOptSize);
136138

139+
// Returns the element count of Ty: the number of elements when Ty is a
// vector type (obtained through IGCLLVM::FixedVectorType), or 1 for any
// non-vector type. The helper only inspects its argument, so it is const
// like the neighbouring getVectorElementType().
unsigned getNumElements(Type* Ty) const {
    if (!Ty->isVectorTy())
        return 1;
    return static_cast<unsigned>(cast<IGCLLVM::FixedVectorType>(Ty)->getNumElements());
}
142+
137143
Type* getVectorElementType(Type* Ty) const {
138144
return isa<VectorType>(Ty) ? cast<VectorType>(Ty)->getElementType() : Ty;
139145
}
@@ -145,6 +151,95 @@ namespace {
145151
return DL->getTypeStoreSize(A) == DL->getTypeStoreSize(B);
146152
}
147153

154+
// Converts V to DestTy, choosing the cast from the source/destination kinds:
//   - identical types: V is returned unchanged;
//   - pointer -> pointer in a different address space: addrspacecast;
//   - pointer -> integer: ptrtoint;
//   - pointer -> floating point: ptrtoint to an integer of DestTy's bit
//     width, then bitcast;
//   - integer -> pointer: inttoptr;
//   - floating point -> pointer: bitcast to an integer of V's bit width,
//     then inttoptr;
//   - everything else (including pointer -> pointer within one address
//     space): plain bitcast.
Value* createBitOrPointerCast(Value* V, Type* DestTy,
    IGCIRBuilder<>& Builder) {
    Type* const SrcTy = V->getType();
    if (SrcTy == DestTy)
        return V;

    const bool SrcIsPtr = SrcTy->isPointerTy();
    const bool DstIsPtr = DestTy->isPointerTy();

    // Only a change of address space needs addrspacecast; same-space
    // pointer conversions fall through to the final bitcast.
    if (SrcIsPtr && DstIsPtr &&
        SrcTy->getPointerAddressSpace() != DestTy->getPointerAddressSpace())
        return Builder.CreateAddrSpaceCast(V, DestTy);

    if (SrcIsPtr) {
        if (DestTy->isIntegerTy())
            return Builder.CreatePtrToInt(V, DestTy);

        if (DestTy->isFloatingPointTy()) {
            // Go through an integer that is exactly as wide as the
            // destination floating-point type.
            const uint32_t BitWidth = static_cast<uint32_t>(DestTy->getPrimitiveSizeInBits());
            Value* AsInt = Builder.CreatePtrToInt(V, Builder.getIntNTy(BitWidth));
            return Builder.CreateBitCast(AsInt, DestTy);
        }
    }

    if (DstIsPtr) {
        if (SrcTy->isIntegerTy())
            return Builder.CreateIntToPtr(V, DestTy);

        if (SrcTy->isFloatingPointTy()) {
            // Reinterpret the floating-point bits as an integer of the
            // same width before converting to a pointer.
            const uint32_t BitWidth = static_cast<uint32_t>(SrcTy->getPrimitiveSizeInBits());
            Value* AsInt = Builder.CreateBitCast(V, Builder.getIntNTy(BitWidth));
            return Builder.CreateIntToPtr(AsInt, DestTy);
        }
    }

    return Builder.CreateBitCast(V, DestTy);
}
193+
194+
/**
195+
* @brief Creates a new merge value for merged load from a set of predicated loads' merge values.
196+
*
197+
* This function constructs a new combined merge value by merging the merge values of multiple predicated load intrinsics.
198+
* Merge value from each input predicated load is inserted into the appropriate position in the resulting merge vector value,
199+
* based on its offset and the scalar size. The function handles both scalar and vector merge input values.
200+
*
201+
* @param MergeValTy The type of the merged value to be created.
202+
* @param LoadsToMerge A vector of tuples, each containing a load instruction and its associated offset.
203+
* @param LdScalarSize The size (in bytes) of the scalar element being loaded in the combined load.
204+
* @param NumElts Number of elements in the merged value vector.
205+
* @return Value* The newly created merged value, or nullptr if we are merging generic loads, not predicated.
206+
*/
207+
Value* CreateNewMergeValue(IGCIRBuilder<>& Builder, Type* MergeValTy,
208+
const MergeVector& LoadsToMerge, unsigned LdScalarSize,
209+
unsigned& NumElts) {
210+
Value* NewMergeValue = UndefValue::get(MergeValTy);
211+
unsigned Pos = 0;
212+
int64_t FirstOffset = std::get<1>(LoadsToMerge.front());
213+
214+
for (auto& I : LoadsToMerge) {
215+
PredicatedLoadIntrinsic* PLI = ALoadInst::get(std::get<0>(I))->getPredicatedLoadIntrinsic();
216+
if (!PLI)
217+
return nullptr;
218+
219+
Value* MergeValue = PLI->getMergeValue();
220+
unsigned MergeValNumElements = getNumElements(MergeValue->getType());
221+
Type* MergeValScalarTy = MergeValTy->getScalarType();
222+
Pos = unsigned((std::get<1>(I) - FirstOffset) / LdScalarSize);
223+
224+
if (MergeValNumElements == 1) {
225+
IGC_ASSERT_MESSAGE(Pos < NumElts, "Index is larger than the number of elements, we cannot update merge value.");
226+
MergeValue = createBitOrPointerCast(MergeValue, MergeValScalarTy, Builder);
227+
NewMergeValue = Builder.CreateInsertElement(NewMergeValue, MergeValue, Builder.getInt32(Pos));
228+
continue;
229+
}
230+
231+
IGC_ASSERT_MESSAGE(Pos + MergeValNumElements <= NumElts,
232+
"Index is larger than the number of elements, we cannot update merge value.");
233+
234+
for (unsigned i = 0; i < MergeValNumElements; ++i) {
235+
Value* ExtractValue = Builder.CreateExtractElement(MergeValue, Builder.getInt32(i));
236+
ExtractValue = createBitOrPointerCast(ExtractValue, MergeValScalarTy, Builder);
237+
NewMergeValue = Builder.CreateInsertElement(NewMergeValue, ExtractValue, Builder.getInt32(Pos + i));
238+
}
239+
}
240+
return NewMergeValue;
241+
}
242+
148243
bool isSafeToMergeLoad(const ALoadInst& Ld,
149244
const SmallVectorImpl<Instruction*>& checkList) const;
150245
bool isSafeToMergeStores(
@@ -644,76 +739,6 @@ bool MemOpt::removeRedBlockRead(GenIntrinsicInst* LeadingBlockRead,
644739
return true;
645740
}
646741

647-
namespace {
648-
unsigned getNumElements(Type* Ty) {
649-
return Ty->isVectorTy() ? (unsigned)cast<IGCLLVM::FixedVectorType>(Ty)->getNumElements() : 1;
650-
}
651-
652-
template <typename T>
653-
Value* createBitOrPointerCast(Value* V, Type* DestTy,
654-
IGCIRBuilder<T>& Builder) {
655-
if (V->getType() == DestTy)
656-
return V;
657-
if (V->getType()->isPointerTy() && DestTy->isPointerTy()) {
658-
PointerType* SrcPtrTy = cast<PointerType>(V->getType());
659-
PointerType* DstPtrTy = cast<PointerType>(DestTy);
660-
if (SrcPtrTy->getPointerAddressSpace() !=
661-
DstPtrTy->getPointerAddressSpace())
662-
return Builder.CreateAddrSpaceCast(V, DestTy);
663-
}
664-
if (V->getType()->isPointerTy()) {
665-
if (DestTy->isIntegerTy()) {
666-
return Builder.CreatePtrToInt(V, DestTy);
667-
}
668-
else if (DestTy->isFloatingPointTy()) {
669-
uint32_t Size = (uint32_t)DestTy->getPrimitiveSizeInBits();
670-
Value* Cast = Builder.CreatePtrToInt(
671-
V, Builder.getIntNTy(Size));
672-
return Builder.CreateBitCast(Cast, DestTy);
673-
}
674-
}
675-
if (DestTy->isPointerTy()) {
676-
if (V->getType()->isIntegerTy()) {
677-
return Builder.CreateIntToPtr(V, DestTy);
678-
}
679-
else if (V->getType()->isFloatingPointTy()) {
680-
uint32_t Size = (uint32_t)V->getType()->getPrimitiveSizeInBits();
681-
Value* Cast = Builder.CreateBitCast(
682-
V, Builder.getIntNTy(Size));
683-
return Builder.CreateIntToPtr(Cast, DestTy);
684-
}
685-
}
686-
return Builder.CreateBitCast(V, DestTy);
687-
}
688-
689-
template <typename T>
690-
Value* CreateNewMergeValue(IGCIRBuilder<T>& Builder, Type* Ty,
691-
const SmallVector<Instruction *>& LoadsToMerge) {
692-
Value* NewMergeValue = UndefValue::get(Ty);
693-
unsigned idx = 0;
694-
for (auto* Inst : LoadsToMerge) {
695-
PredicatedLoadIntrinsic* PLI = ALoadInst::get(Inst)->getPredicatedLoadIntrinsic();
696-
if (!PLI)
697-
return nullptr;
698-
Value* MergeValue = PLI->getMergeValue();
699-
unsigned NumElements = getNumElements(MergeValue->getType());
700-
Type* ScalarTy = Ty->getScalarType();
701-
if (NumElements == 1) {
702-
MergeValue = createBitOrPointerCast(MergeValue, ScalarTy, Builder);
703-
NewMergeValue = Builder.CreateInsertElement(NewMergeValue, MergeValue, Builder.getInt32(idx++));
704-
continue;
705-
}
706-
707-
for (unsigned i = 0; i < NumElements; ++i) {
708-
Value* ExtractValue = Builder.CreateExtractElement(MergeValue, Builder.getInt32(i));
709-
ExtractValue = createBitOrPointerCast(ExtractValue, ScalarTy, Builder);
710-
NewMergeValue = Builder.CreateInsertElement(NewMergeValue, ExtractValue, Builder.getInt32(idx++));
711-
}
712-
}
713-
return NewMergeValue;
714-
}
715-
}
716-
717742
//Removes redundant blockread if both blockreads are scalar.
718743
void MemOpt::removeScalarBlockRead(Instruction* BlockReadToOptimize,
719744
Instruction* BlockReadToRemove, Value* SgId,
@@ -1151,8 +1176,7 @@ bool MemOpt::mergeLoad(ALoadInst& LeadingLoad,
11511176
}
11521177

11531178
// Tuple fields: ALoadInst, Offset, MemRefListTy::iterator
1154-
SmallVector<std::tuple<Instruction *, int64_t, MemRefListTy::iterator>, 8>
1155-
LoadsToMerge;
1179+
MergeVector LoadsToMerge;
11561180
LoadsToMerge.push_back(std::make_tuple(LeadingLoad.inst(), 0, MI));
11571181

11581182
// Loads to be merged is scanned in the program order and will be merged into
@@ -1375,10 +1399,8 @@ bool MemOpt::mergeLoad(ALoadInst& LeadingLoad,
13751399
Value* NewPointer = Builder.CreateBitCast(Ptr, NewPointerType);
13761400

13771401
// Prepare Merge Value if needed:
1378-
SmallVector<Instruction*> LoadsToMergeInsts;
1379-
for (auto& I : LoadsToMerge)
1380-
LoadsToMergeInsts.push_back(std::get<0>(I));
1381-
Value* NewMergeValue = CreateNewMergeValue(Builder, NewLoadType, LoadsToMergeInsts);
1402+
Value* NewMergeValue = CreateNewMergeValue(Builder, NewLoadType, LoadsToMerge,
1403+
LdScalarSize, NumElts);
13821404

13831405
Instruction* NewLoad =
13841406
FirstLoad.CreateAlignedLoad(Builder, NewLoadType, NewPointer, NewMergeValue);
@@ -3486,7 +3508,7 @@ void LdStCombine::combineLoads()
34863508
++numInsts;
34873509

34883510
// cannot merge beyond fence or window limit
3489-
if ((I->isFenceLike() && !isa<PredicatedLoadIntrinsic>(I)) || numInsts > LDWINDOWSIZE) {
3511+
if ((I->isFenceLike() && !isa<PredicatedLoadIntrinsic>(I) && !isa<PredicatedStoreIntrinsic>(I)) || numInsts > LDWINDOWSIZE) {
34903512
LLVM_DEBUG(dbgs() << "- - Stop at fence or window limit\n");
34913513
break;
34923514
}
@@ -3772,6 +3794,7 @@ void LdStCombine::createBundles(BasicBlock* BB, InstAndOffsetPairs& LoadStores)
37723794
{
37733795
// 1. The first one is the leading store.
37743796
const LdStInfo* leadLSI = &LoadStores[i];
3797+
LLVM_DEBUG(llvm::dbgs() << "Try leading LdSt: " << *leadLSI->getInst() << "\n");
37753798
if (isBundled(leadLSI, m_combinedInsts) ||
37763799
(i+1) == SZ) /* skip for last one */ {
37773800
++i;
@@ -3780,7 +3803,7 @@ void LdStCombine::createBundles(BasicBlock* BB, InstAndOffsetPairs& LoadStores)
37803803

37813804
if (m_WI && isUniform &&
37823805
!m_WI->isUniform(leadLSI->getValueOperand())) {
3783-
// no combining for *uniform-ptr = non-uniform value
3806+
LLVM_DEBUG(llvm::dbgs() << "No combining for *uniform-ptr = non-uniform value\n");
37843807
++i;
37853808
continue;
37863809
}
@@ -3816,6 +3839,7 @@ void LdStCombine::createBundles(BasicBlock* BB, InstAndOffsetPairs& LoadStores)
38163839
uint32_t vecSize = -1;
38173840
for (int j = i + 1; j < SZ; ++j) {
38183841
const LdStInfo* LSI = &LoadStores[j];
3842+
LLVM_DEBUG(llvm::dbgs() << "Try to make bundle with: " << *LSI->getInst() << "\n");
38193843
if (isBundled(LSI, m_combinedInsts) ||
38203844
(leadLSI->getByteOffset() + totalBytes) != LSI->getByteOffset())
38213845
{
@@ -3824,7 +3848,7 @@ void LdStCombine::createBundles(BasicBlock* BB, InstAndOffsetPairs& LoadStores)
38243848
}
38253849
if (m_WI && isUniform &&
38263850
!m_WI->isUniform(LSI->getValueOperand())) {
3827-
// no combining for *uniform-ptr = non-uniform value
3851+
LLVM_DEBUG(llvm::dbgs() << "No combining for *uniform-ptr = non-uniform value\n");
38283852
break;
38293853
}
38303854

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
;=========================== begin_copyright_notice ============================
2+
;
3+
; Copyright (C) 2025 Intel Corporation
4+
;
5+
; SPDX-License-Identifier: MIT
6+
;
7+
;============================ end_copyright_notice =============================
8+
9+
; REQUIRES: regkeys
10+
; RUN: igc_opt --typed-pointers %s -S -inputocl -igc-ldstcombine -regkey=EnableLdStCombine=5 -platformbmg -instcombine | FileCheck %s
11+
12+
target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f16:16:16-f32:32:32-f64:64:64-f80:128:128-v16:16:16-v24:32:32-v32:32:32-v48:64:64-v64:64:64-v96:128:128-v128:128:128-v192:256:256-v256:256:256-v512:512:512-v1024:1024:1024-a:64:64-f80:128:128-n8:16:32:64"
13+
14+
; f0: three predicated i32 loads from %src (elements 0, 1, 2; merge values
; 42, 43, 44) feed three predicated i32 stores to %dst (elements 0, 1, 2).
; Every load and every store carries !nontemporal. The CHECK lines for @f0
; expect one <3 x i32> predicated load and one <3 x i32> predicated store,
; both with !nontemporal.
define void @f0(i32* %dst, i32* %src) {
15+
entry:
16+
%0 = call i32 @llvm.genx.GenISA.PredicatedLoad.i32.p0i32.i32(i32* %src, i64 4, i1 true, i32 42), !nontemporal !5
17+
%arrayidx1 = getelementptr inbounds i32, i32* %src, i64 1
18+
%1 = call i32 @llvm.genx.GenISA.PredicatedLoad.i32.p0i32.i32(i32* %arrayidx1, i64 4, i1 true, i32 43), !nontemporal !5
19+
%arrayidx2 = getelementptr inbounds i32, i32* %src, i64 2
20+
%2 = call i32 @llvm.genx.GenISA.PredicatedLoad.i32.p0i32.i32(i32* %arrayidx2, i64 4, i1 true, i32 44), !nontemporal !5
21+
call void @llvm.genx.GenISA.PredicatedStore.p0i32.i32(i32* %dst, i32 %0, i64 4, i1 true), !nontemporal !5
22+
%arrayidx4 = getelementptr inbounds i32, i32* %dst, i64 1
23+
call void @llvm.genx.GenISA.PredicatedStore.p0i32.i32(i32* %arrayidx4, i32 %1, i64 4, i1 true), !nontemporal !5
24+
%arrayidx5 = getelementptr inbounds i32, i32* %dst, i64 2
25+
call void @llvm.genx.GenISA.PredicatedStore.p0i32.i32(i32* %arrayidx5, i32 %2, i64 4, i1 true), !nontemporal !5
26+
ret void
27+
}
28+
29+
; CHECK-LABEL: @f0(
30+
; CHECK-NEXT: entry:
31+
; CHECK-NEXT: [[TMP0:%.*]] = bitcast i32* [[SRC:%.*]] to <3 x i32>*
32+
; CHECK-NEXT: [[TMP1:%.*]] = call <3 x i32> @llvm.genx.GenISA.PredicatedLoad.v3i32.p0v3i32.v3i32(<3 x i32>* [[TMP0]], i64 4, i1 true, <3 x i32> <i32 42, i32 43, i32 44>)
33+
; CHECK-SAME: !nontemporal ![[NONTEMPORAL:[0-9]+]]
34+
; CHECK-NEXT: [[TMP2:%.*]] = bitcast i32* [[DST:%.*]] to <3 x i32>*
35+
; CHECK-NEXT: call void @llvm.genx.GenISA.PredicatedStore.p0v3i32.v3i32(<3 x i32>* [[TMP2]], <3 x i32> [[TMP1]], i64 4, i1 true)
36+
; CHECK-SAME: !nontemporal ![[NONTEMPORAL:[0-9]+]]
37+
; CHECK-NEXT: ret void
38+
39+
40+
; f1: same load/store pattern as f0, but only the second load (%1) and the
; last store carry !nontemporal. The CHECK lines for @f1 still expect both
; the combined <3 x i32> load and the combined <3 x i32> store to carry
; !nontemporal.
define void @f1(i32* %dst, i32* %src) {
41+
entry:
42+
%0 = call i32 @llvm.genx.GenISA.PredicatedLoad.i32.p0i32.i32(i32* %src, i64 4, i1 true, i32 42)
43+
%arrayidx1 = getelementptr inbounds i32, i32* %src, i64 1
44+
%1 = call i32 @llvm.genx.GenISA.PredicatedLoad.i32.p0i32.i32(i32* %arrayidx1, i64 4, i1 true, i32 43), !nontemporal !5
45+
%arrayidx2 = getelementptr inbounds i32, i32* %src, i64 2
46+
%2 = call i32 @llvm.genx.GenISA.PredicatedLoad.i32.p0i32.i32(i32* %arrayidx2, i64 4, i1 true, i32 44)
47+
call void @llvm.genx.GenISA.PredicatedStore.p0i32.i32(i32* %dst, i32 %0, i64 4, i1 true)
48+
%arrayidx4 = getelementptr inbounds i32, i32* %dst, i64 1
49+
call void @llvm.genx.GenISA.PredicatedStore.p0i32.i32(i32* %arrayidx4, i32 %1, i64 4, i1 true)
50+
%arrayidx5 = getelementptr inbounds i32, i32* %dst, i64 2
51+
call void @llvm.genx.GenISA.PredicatedStore.p0i32.i32(i32* %arrayidx5, i32 %2, i64 4, i1 true), !nontemporal !5
52+
ret void
53+
}
54+
55+
; CHECK-LABEL: @f1(
56+
; CHECK-NEXT: entry:
57+
; CHECK-NEXT: [[TMP0:%.*]] = bitcast i32* [[SRC:%.*]] to <3 x i32>*
58+
; CHECK-NEXT: [[TMP1:%.*]] = call <3 x i32> @llvm.genx.GenISA.PredicatedLoad.v3i32.p0v3i32.v3i32(<3 x i32>* [[TMP0]], i64 4, i1 true, <3 x i32> <i32 42, i32 43, i32 44>)
59+
; CHECK-SAME: !nontemporal ![[NONTEMPORAL:[0-9]+]]
60+
; CHECK-NEXT: [[TMP2:%.*]] = bitcast i32* [[DST:%.*]] to <3 x i32>*
61+
; CHECK-NEXT: call void @llvm.genx.GenISA.PredicatedStore.p0v3i32.v3i32(<3 x i32>* [[TMP2]], <3 x i32> [[TMP1]], i64 4, i1 true)
62+
; CHECK-SAME: !nontemporal ![[NONTEMPORAL:[0-9]+]]
63+
; CHECK-NEXT: ret void
64+
65+
; f2: same load/store pattern as f0, with no !nontemporal on any load or
; store. The CHECK lines for @f2 expect the combined <3 x i32> load and
; store, and use CHECK-NOT to require that neither carries !nontemporal.
define void @f2(i32* %dst, i32* %src) {
66+
entry:
67+
%0 = call i32 @llvm.genx.GenISA.PredicatedLoad.i32.p0i32.i32(i32* %src, i64 4, i1 true, i32 42)
68+
%arrayidx1 = getelementptr inbounds i32, i32* %src, i64 1
69+
%1 = call i32 @llvm.genx.GenISA.PredicatedLoad.i32.p0i32.i32(i32* %arrayidx1, i64 4, i1 true, i32 43)
70+
%arrayidx2 = getelementptr inbounds i32, i32* %src, i64 2
71+
%2 = call i32 @llvm.genx.GenISA.PredicatedLoad.i32.p0i32.i32(i32* %arrayidx2, i64 4, i1 true, i32 44)
72+
call void @llvm.genx.GenISA.PredicatedStore.p0i32.i32(i32* %dst, i32 %0, i64 4, i1 true)
73+
%arrayidx4 = getelementptr inbounds i32, i32* %dst, i64 1
74+
call void @llvm.genx.GenISA.PredicatedStore.p0i32.i32(i32* %arrayidx4, i32 %1, i64 4, i1 true)
75+
%arrayidx5 = getelementptr inbounds i32, i32* %dst, i64 2
76+
call void @llvm.genx.GenISA.PredicatedStore.p0i32.i32(i32* %arrayidx5, i32 %2, i64 4, i1 true)
77+
ret void
78+
}
79+
80+
; CHECK-LABEL: @f2(
81+
; CHECK-NEXT: entry:
82+
; CHECK-NEXT: [[TMP0:%.*]] = bitcast i32* [[SRC:%.*]] to <3 x i32>*
83+
; CHECK-NEXT: [[TMP1:%.*]] = call <3 x i32> @llvm.genx.GenISA.PredicatedLoad.v3i32.p0v3i32.v3i32(<3 x i32>* [[TMP0]], i64 4, i1 true, <3 x i32> <i32 42, i32 43, i32 44>)
84+
; CHECK-NOT: !nontemporal ![[NONTEMPORAL:[0-9]+]]
85+
; CHECK-NEXT: [[TMP2:%.*]] = bitcast i32* [[DST:%.*]] to <3 x i32>*
86+
; CHECK-NEXT: call void @llvm.genx.GenISA.PredicatedStore.p0v3i32.v3i32(<3 x i32>* [[TMP2]], <3 x i32> [[TMP1]], i64 4, i1 true)
87+
; CHECK-NOT: !nontemporal ![[NONTEMPORAL:[0-9]+]]
88+
; CHECK-NEXT: ret void
89+
90+
; CHECK: ![[NONTEMPORAL]] = !{i32 1}
91+
92+
; Function Attrs: nounwind readonly
93+
declare i32 @llvm.genx.GenISA.PredicatedLoad.i32.p0i32.i32(i32*, i64, i1, i32) #0
94+
95+
declare void @llvm.genx.GenISA.PredicatedStore.p0i32.i32(i32*, i32, i64, i1)
96+
97+
attributes #0 = { nounwind readonly }
98+
99+
!igc.functions = !{!0, !3, !4}
100+
101+
!0 = !{void (i32*, i32*)* @f0, !1}
102+
!1 = !{!2}
103+
!2 = !{!"function_type", i32 0}
104+
!3 = !{void (i32*, i32*)* @f1, !1}
105+
!4 = !{void (i32*, i32*)* @f2, !1}
106+
!5 = !{i32 1}

0 commit comments

Comments
 (0)