Skip to content

Commit 340e46a

Browse files
committed
1. Address PR comments.
2. Remove use of convertUsersOfConstantsToInstructions (it was leaving use chains in an unstable state).
3. Remove the RPOT (ReversePostOrderTraversal) from DXILDataScalarization.cpp; we were visiting twice.
1 parent 3fa78a6 commit 340e46a

File tree

6 files changed

+135
-108
lines changed

6 files changed

+135
-108
lines changed

llvm/lib/Target/DirectX/DXILDataScalarization.cpp

Lines changed: 61 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ static bool findAndReplaceVectors(Module &M);
4242
class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
4343
public:
4444
DataScalarizerVisitor() : GlobalMap() {}
45-
bool visit(Function &F);
45+
bool visit(Instruction &I);
4646
// InstVisitor methods. They return true if the instruction was scalarized,
4747
// false if nothing changed.
4848
bool visitInstruction(Instruction &I) { return false; }
@@ -65,19 +65,14 @@ class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
6565
friend bool findAndReplaceVectors(llvm::Module &M);
6666

6767
private:
68+
Value *createNewGetElementPtr(GetElementPtrInst &GEPI);
6869
GlobalVariable *lookupReplacementGlobal(Value *CurrOperand);
6970
DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
7071
};
7172

72-
bool DataScalarizerVisitor::visit(Function &F) {
73+
bool DataScalarizerVisitor::visit(Instruction &I) {
7374
assert(!GlobalMap.empty());
74-
bool MadeChange = false;
75-
ReversePostOrderTraversal<Function *> RPOT(&F);
76-
for (BasicBlock *BB : make_early_inc_range(RPOT)) {
77-
for (Instruction &I : make_early_inc_range(*BB))
78-
MadeChange |= InstVisitor::visit(I);
79-
}
80-
return MadeChange;
75+
return InstVisitor::visit(I);
8176
}
8277

8378
GlobalVariable *
@@ -95,6 +90,21 @@ bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) {
9590
unsigned NumOperands = LI.getNumOperands();
9691
for (unsigned I = 0; I < NumOperands; ++I) {
9792
Value *CurrOpperand = LI.getOperand(I);
93+
ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
94+
if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
95+
GetElementPtrInst *OldGEP =
96+
cast<GetElementPtrInst>(CE->getAsInstruction());
97+
OldGEP->insertBefore(&LI);
98+
Value *NewGEP = createNewGetElementPtr(*OldGEP);
99+
IRBuilder<> Builder(&LI);
100+
LoadInst *NewLoad =
101+
Builder.CreateLoad(LI.getType(), NewGEP, LI.getName());
102+
NewLoad->setAlignment(LI.getAlign());
103+
LI.replaceAllUsesWith(NewLoad);
104+
LI.eraseFromParent();
105+
OldGEP->eraseFromParent();
106+
return true;
107+
}
98108
if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand))
99109
LI.setOperand(I, NewGlobal);
100110
}
@@ -105,32 +115,53 @@ bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
105115
unsigned NumOperands = SI.getNumOperands();
106116
for (unsigned I = 0; I < NumOperands; ++I) {
107117
Value *CurrOpperand = SI.getOperand(I);
108-
if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand)) {
109-
SI.setOperand(I, NewGlobal);
118+
ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
119+
if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
120+
GetElementPtrInst *OldGEP =
121+
cast<GetElementPtrInst>(CE->getAsInstruction());
122+
OldGEP->insertBefore(&SI);
123+
Value *NewGEP = createNewGetElementPtr(*OldGEP);
124+
IRBuilder<> Builder(&SI);
125+
StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), NewGEP);
126+
NewStore->setAlignment(SI.getAlign());
127+
SI.replaceAllUsesWith(NewStore);
128+
SI.eraseFromParent();
129+
OldGEP->eraseFromParent();
130+
return true;
110131
}
132+
if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand))
133+
SI.setOperand(I, NewGlobal);
111134
}
112135
return false;
113136
}
114137

115-
bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
138+
Value *DataScalarizerVisitor::createNewGetElementPtr(GetElementPtrInst &GEPI) {
116139
unsigned NumOperands = GEPI.getNumOperands();
140+
GlobalVariable *NewGlobal = nullptr;
117141
for (unsigned I = 0; I < NumOperands; ++I) {
118142
Value *CurrOpperand = GEPI.getOperand(I);
119-
GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand);
120-
if (!NewGlobal)
121-
continue;
122-
IRBuilder<> Builder(&GEPI);
143+
NewGlobal = lookupReplacementGlobal(CurrOpperand);
144+
if (NewGlobal)
145+
break;
146+
}
147+
if (!NewGlobal)
148+
return nullptr;
123149

124-
SmallVector<Value *, MaxVecSize> Indices;
125-
for (auto &Index : GEPI.indices())
126-
Indices.push_back(Index);
150+
IRBuilder<> Builder(&GEPI);
151+
SmallVector<Value *, MaxVecSize> Indices;
152+
for (auto &Index : GEPI.indices())
153+
Indices.push_back(Index);
127154

128-
Value *NewGEP =
129-
Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices);
155+
return Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices,
156+
GEPI.getName(), GEPI.getNoWrapFlags());
157+
}
130158

131-
GEPI.replaceAllUsesWith(NewGEP);
132-
GEPI.eraseFromParent();
133-
}
159+
bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
160+
Value *NewGEP = createNewGetElementPtr(GEPI);
161+
if (!NewGEP)
162+
return false;
163+
GEPI.replaceAllUsesWith(NewGEP);
164+
GEPI.eraseFromParent();
134165
return true;
135166
}
136167

@@ -236,16 +267,13 @@ static bool findAndReplaceVectors(Module &M) {
236267
for (User *U : make_early_inc_range(G.users())) {
237268
if (isa<ConstantExpr>(U) && isa<Operator>(U)) {
238269
ConstantExpr *CE = cast<ConstantExpr>(U);
239-
convertUsersOfConstantsToInstructions(CE,
240-
/*RestrictToFunc=*/nullptr,
241-
/*RemoveDeadConstants=*/false,
242-
/*IncludeSelf=*/true);
243-
}
244-
if (Instruction *Inst = dyn_cast<Instruction>(U)) {
245-
Function *F = Inst->getFunction();
246-
if (F)
247-
Impl.visit(*F);
270+
for (User *UCE : make_early_inc_range(CE->users())) {
271+
if (Instruction *Inst = dyn_cast<Instruction>(UCE))
272+
Impl.visit(*Inst);
273+
}
248274
}
275+
if (Instruction *Inst = dyn_cast<Instruction>(U))
276+
Impl.visit(*Inst);
249277
}
250278
}
251279
}

llvm/lib/Target/DirectX/DXILFlattenArrays.cpp

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -164,11 +164,18 @@ bool DXILFlattenArraysVisitor::visitLoadInst(LoadInst &LI) {
164164
Value *CurrOpperand = LI.getOperand(I);
165165
ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
166166
if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
167-
convertUsersOfConstantsToInstructions(CE,
168-
/*RestrictToFunc=*/nullptr,
169-
/*RemoveDeadConstants=*/false,
170-
/*IncludeSelf=*/true);
171-
return false;
167+
GetElementPtrInst *OldGEP =
168+
cast<GetElementPtrInst>(CE->getAsInstruction());
169+
OldGEP->insertBefore(&LI);
170+
171+
IRBuilder<> Builder(&LI);
172+
LoadInst *NewLoad =
173+
Builder.CreateLoad(LI.getType(), OldGEP, LI.getName());
174+
NewLoad->setAlignment(LI.getAlign());
175+
LI.replaceAllUsesWith(NewLoad);
176+
LI.eraseFromParent();
177+
visitGetElementPtrInst(*OldGEP);
178+
return true;
172179
}
173180
}
174181
return false;
@@ -180,11 +187,17 @@ bool DXILFlattenArraysVisitor::visitStoreInst(StoreInst &SI) {
180187
Value *CurrOpperand = SI.getOperand(I);
181188
ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
182189
if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
183-
convertUsersOfConstantsToInstructions(CE,
184-
/*RestrictToFunc=*/nullptr,
185-
/*RemoveDeadConstants=*/false,
186-
/*IncludeSelf=*/true);
187-
return false;
190+
GetElementPtrInst *OldGEP =
191+
cast<GetElementPtrInst>(CE->getAsInstruction());
192+
OldGEP->insertBefore(&SI);
193+
194+
IRBuilder<> Builder(&SI);
195+
StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP);
196+
NewStore->setAlignment(SI.getAlign());
197+
SI.replaceAllUsesWith(NewStore);
198+
SI.eraseFromParent();
199+
visitGetElementPtrInst(*OldGEP);
200+
return true;
188201
}
189202
}
190203
return false;
@@ -317,20 +330,20 @@ bool DXILFlattenArraysVisitor::visit(Function &F) {
317330
static void collectElements(Constant *Init,
318331
SmallVectorImpl<Constant *> &Elements) {
319332
// Base case: If Init is not an array, add it directly to the vector.
320-
if (!isa<ArrayType>(Init->getType())) {
333+
auto *ArrayTy = dyn_cast<ArrayType>(Init->getType());
334+
if (!ArrayTy) {
321335
Elements.push_back(Init);
322336
return;
323337
}
324-
auto *LLVMArrayType = dyn_cast<ArrayType>(Init->getType());
338+
unsigned ArrSize = ArrayTy->getNumElements();
325339
if (isa<ConstantAggregateZero>(Init)) {
326-
for (unsigned I = 0; I < LLVMArrayType->getNumElements(); ++I)
327-
Elements.push_back(
328-
Constant::getNullValue(LLVMArrayType->getElementType()));
340+
for (unsigned I = 0; I < ArrSize; ++I)
341+
Elements.push_back(Constant::getNullValue(ArrayTy->getElementType()));
329342
return;
330343
}
331344
if (isa<UndefValue>(Init)) {
332-
for (unsigned I = 0; I < LLVMArrayType->getNumElements(); ++I)
333-
Elements.push_back(UndefValue::get(LLVMArrayType->getElementType()));
345+
for (unsigned I = 0; I < ArrSize; ++I)
346+
Elements.push_back(UndefValue::get(ArrayTy->getElementType()));
334347
return;
335348
}
336349

llvm/test/CodeGen/DirectX/flatten-bug-117273.ll

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,22 @@
22
; RUN: opt -S -passes='dxil-flatten-arrays,dxil-op-lower' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
33

44

5-
@_ZL7Palette = internal constant [2 x [3 x float]] [[3 x float] zeroinitializer, [3 x float] undef], align 16
5+
@ZerroInitAndUndefArr = internal constant [2 x [3 x float]] [[3 x float] zeroinitializer, [3 x float] undef], align 16
66

7-
; CHECK: @_ZL7Palette.1dim = internal constant [6 x float] [float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float undef, float undef, float undef], align 16
87

9-
define internal void @_Z4mainDv3_j(<3 x i32> noundef %DID) {
10-
; CHECK-LABEL: define internal void @_Z4mainDv3_j(
11-
; CHECK-SAME: <3 x i32> noundef [[DID:%.*]]) {
8+
define internal void @main() {
9+
; CHECK-LABEL: define internal void @main() {
1210
; CHECK-NEXT: [[ENTRY:.*:]]
13-
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr [24 x float], ptr @_ZL7Palette.1dim, i32 1
11+
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr [24 x float], ptr @ZerroInitAndUndefArr.1dim, i32 1
1412
; CHECK-NEXT: [[DOTI0:%.*]] = load float, ptr [[TMP0]], align 16
15-
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr [24 x float], ptr @_ZL7Palette.1dim, i32 2
13+
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr [24 x float], ptr @ZerroInitAndUndefArr.1dim, i32 2
1614
; CHECK-NEXT: [[DOTI03:%.*]] = load float, ptr [[TMP1]], align 16
1715
; CHECK-NEXT: ret void
1816
;
1917
entry:
20-
%0 = getelementptr [8 x [3 x float]], ptr @_ZL7Palette, i32 0, i32 1
18+
%0 = getelementptr [8 x [3 x float]], ptr @ZerroInitAndUndefArr, i32 0, i32 1
2119
%.i0 = load float, ptr %0, align 16
22-
%1 = getelementptr [8 x [3 x float]], ptr @_ZL7Palette, i32 0, i32 2
20+
%1 = getelementptr [8 x [3 x float]], ptr @ZerroInitAndUndefArr, i32 0, i32 2
2321
%.i03 = load float, ptr %1, align 16
2422
ret void
2523
}

llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
; CHECK-NOT: @groushared2dArrayofVectors
2222
; CHECK-NOT: @groushared2dArrayofVectors.scalarized
2323

24-
2524
define <4 x i32> @load_array_vec_test() #0 {
2625
; CHECK-LABEL: define <4 x i32> @load_array_vec_test(
2726
; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
@@ -33,18 +32,13 @@ define <4 x i32> @load_array_vec_test() #0 {
3332
; CHECK-NEXT: [[TMP6:%.*]] = load i32, ptr addrspace(3) [[TMP5]], align 4
3433
; CHECK-NEXT: [[TMP7:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 3) to ptr addrspace(3)
3534
; CHECK-NEXT: [[TMP8:%.*]] = load i32, ptr addrspace(3) [[TMP7]], align 4
36-
; CHECK-NEXT: [[TMP9:%.*]] = bitcast ptr addrspace(3) @arrayofVecData.scalarized.1dim to ptr addrspace(3)
37-
; CHECK-NEXT: [[TMP10:%.*]] = getelementptr [2 x [3 x float]], ptr addrspace(3) [[TMP9]], i32 0, i32 1
38-
; CHECK-NEXT: [[TMP11:%.*]] = bitcast ptr addrspace(3) [[TMP10]] to ptr addrspace(3)
35+
; CHECK-NEXT: [[TMP11:%.*]] = bitcast ptr addrspace(3) getelementptr inbounds ([6 x float], ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 1) to ptr addrspace(3)
3936
; CHECK-NEXT: [[TMP12:%.*]] = load i32, ptr addrspace(3) [[TMP11]], align 4
40-
; CHECK-NEXT: [[TMP13:%.*]] = bitcast ptr addrspace(3) [[TMP10]] to ptr addrspace(3)
41-
; CHECK-NEXT: [[DOTI12:%.*]] = getelementptr i32, ptr addrspace(3) [[TMP13]], i32 1
37+
; CHECK-NEXT: [[DOTI12:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([6 x float], ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 1), i32 1) to ptr addrspace(3)
4238
; CHECK-NEXT: [[DOTI13:%.*]] = load i32, ptr addrspace(3) [[DOTI12]], align 4
43-
; CHECK-NEXT: [[TMP14:%.*]] = bitcast ptr addrspace(3) [[TMP10]] to ptr addrspace(3)
44-
; CHECK-NEXT: [[DOTI24:%.*]] = getelementptr i32, ptr addrspace(3) [[TMP14]], i32 2
39+
; CHECK-NEXT: [[DOTI24:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([6 x float], ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 1), i32 2) to ptr addrspace(3)
4540
; CHECK-NEXT: [[DOTI25:%.*]] = load i32, ptr addrspace(3) [[DOTI24]], align 4
46-
; CHECK-NEXT: [[TMP15:%.*]] = bitcast ptr addrspace(3) [[TMP10]] to ptr addrspace(3)
47-
; CHECK-NEXT: [[DOTI36:%.*]] = getelementptr i32, ptr addrspace(3) [[TMP15]], i32 3
41+
; CHECK-NEXT: [[DOTI36:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([6 x float], ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 1), i32 3) to ptr addrspace(3)
4842
; CHECK-NEXT: [[DOTI37:%.*]] = load i32, ptr addrspace(3) [[DOTI36]], align 4
4943
; CHECK-NEXT: [[DOTI08:%.*]] = add i32 [[TMP2]], [[TMP12]]
5044
; CHECK-NEXT: [[DOTI19:%.*]] = add i32 [[TMP4]], [[DOTI13]]
@@ -87,7 +81,7 @@ define <4 x i32> @load_vec_test() #0 {
8781
define <4 x i32> @load_static_array_of_vec_test(i32 %index) #0 {
8882
; CHECK-LABEL: define <4 x i32> @load_static_array_of_vec_test(
8983
; CHECK-SAME: i32 [[INDEX:%.*]]) #[[ATTR0]] {
90-
; CHECK-NEXT: [[DOTFLAT:%.*]] = getelementptr [12 x i32], ptr @staticArrayOfVecData.scalarized.1dim, i32 [[INDEX]]
84+
; CHECK-NEXT: [[DOTFLAT:%.*]] = getelementptr inbounds [12 x i32], ptr @staticArrayOfVecData.scalarized.1dim, i32 [[INDEX]]
9185
; CHECK-NEXT: [[TMP1:%.*]] = bitcast ptr [[DOTFLAT]] to ptr
9286
; CHECK-NEXT: [[TMP2:%.*]] = load i32, ptr [[TMP1]], align 4
9387
; CHECK-NEXT: [[TMP3:%.*]] = bitcast ptr [[DOTFLAT]] to ptr
@@ -121,18 +115,13 @@ define <4 x i32> @multid_load_test() #0 {
121115
; CHECK-NEXT: [[TMP6:%.*]] = load i32, ptr addrspace(3) [[TMP5]], align 4
122116
; CHECK-NEXT: [[TMP7:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 3) to ptr addrspace(3)
123117
; CHECK-NEXT: [[TMP8:%.*]] = load i32, ptr addrspace(3) [[TMP7]], align 4
124-
; CHECK-NEXT: [[TMP9:%.*]] = bitcast ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim to ptr addrspace(3)
125-
; CHECK-NEXT: [[TMP10:%.*]] = getelementptr [3 x [3 x [4 x i32]]], ptr addrspace(3) [[TMP9]], i32 0, i32 1, i32 1
126-
; CHECK-NEXT: [[TMP11:%.*]] = bitcast ptr addrspace(3) [[TMP10]] to ptr addrspace(3)
118+
; CHECK-NEXT: [[TMP11:%.*]] = bitcast ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 1) to ptr addrspace(3)
127119
; CHECK-NEXT: [[TMP12:%.*]] = load i32, ptr addrspace(3) [[TMP11]], align 4
128-
; CHECK-NEXT: [[TMP13:%.*]] = bitcast ptr addrspace(3) [[TMP10]] to ptr addrspace(3)
129-
; CHECK-NEXT: [[DOTI12:%.*]] = getelementptr i32, ptr addrspace(3) [[TMP13]], i32 1
120+
; CHECK-NEXT: [[DOTI12:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 1), i32 1) to ptr addrspace(3)
130121
; CHECK-NEXT: [[DOTI13:%.*]] = load i32, ptr addrspace(3) [[DOTI12]], align 4
131-
; CHECK-NEXT: [[TMP14:%.*]] = bitcast ptr addrspace(3) [[TMP10]] to ptr addrspace(3)
132-
; CHECK-NEXT: [[DOTI24:%.*]] = getelementptr i32, ptr addrspace(3) [[TMP14]], i32 2
122+
; CHECK-NEXT: [[DOTI24:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 1), i32 2) to ptr addrspace(3)
133123
; CHECK-NEXT: [[DOTI25:%.*]] = load i32, ptr addrspace(3) [[DOTI24]], align 4
134-
; CHECK-NEXT: [[TMP15:%.*]] = bitcast ptr addrspace(3) [[TMP10]] to ptr addrspace(3)
135-
; CHECK-NEXT: [[DOTI36:%.*]] = getelementptr i32, ptr addrspace(3) [[TMP15]], i32 3
124+
; CHECK-NEXT: [[DOTI36:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 1), i32 3) to ptr addrspace(3)
136125
; CHECK-NEXT: [[DOTI37:%.*]] = load i32, ptr addrspace(3) [[DOTI36]], align 4
137126
; CHECK-NEXT: [[DOTI08:%.*]] = add i32 [[TMP2]], [[TMP12]]
138127
; CHECK-NEXT: [[DOTI19:%.*]] = add i32 [[TMP4]], [[DOTI13]]

0 commit comments

Comments
 (0)