[DirectX] Bug fix for Data Scalarization crash #118426

Merged (4 commits, Dec 18, 2024)
101 changes: 55 additions & 46 deletions llvm/lib/Target/DirectX/DXILDataScalarization.cpp
@@ -42,7 +42,7 @@ static bool findAndReplaceVectors(Module &M);
class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
public:
DataScalarizerVisitor() : GlobalMap() {}
bool visit(Function &F);
bool visit(Instruction &I);
// InstVisitor methods. They return true if the instruction was scalarized,
// false if nothing changed.
bool visitInstruction(Instruction &I) { return false; }
@@ -67,28 +67,11 @@ class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
private:
GlobalVariable *lookupReplacementGlobal(Value *CurrOperand);
DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;
bool finish();
};

bool DataScalarizerVisitor::visit(Function &F) {
bool DataScalarizerVisitor::visit(Instruction &I) {
assert(!GlobalMap.empty());
ReversePostOrderTraversal<BasicBlock *> RPOT(&F.getEntryBlock());
for (BasicBlock *BB : RPOT) {
for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
Instruction *I = &*II;
bool Done = InstVisitor::visit(I);
++II;
if (Done && I->getType()->isVoidTy())
I->eraseFromParent();
}
}
return finish();
}

bool DataScalarizerVisitor::finish() {
RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);
return true;
return InstVisitor::visit(I);
}

GlobalVariable *
@@ -106,6 +89,20 @@ bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) {
unsigned NumOperands = LI.getNumOperands();
for (unsigned I = 0; I < NumOperands; ++I) {
Value *CurrOpperand = LI.getOperand(I);
ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
GetElementPtrInst *OldGEP =
cast<GetElementPtrInst>(CE->getAsInstruction());
OldGEP->insertBefore(&LI);
IRBuilder<> Builder(&LI);
LoadInst *NewLoad =
Builder.CreateLoad(LI.getType(), OldGEP, LI.getName());
NewLoad->setAlignment(LI.getAlign());
LI.replaceAllUsesWith(NewLoad);
LI.eraseFromParent();
visitGetElementPtrInst(*OldGEP);
return true;
}
if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand))
LI.setOperand(I, NewGlobal);
}
@@ -116,32 +113,48 @@ bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
unsigned NumOperands = SI.getNumOperands();
for (unsigned I = 0; I < NumOperands; ++I) {
Value *CurrOpperand = SI.getOperand(I);
if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand)) {
SI.setOperand(I, NewGlobal);
ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
GetElementPtrInst *OldGEP =
cast<GetElementPtrInst>(CE->getAsInstruction());
OldGEP->insertBefore(&SI);
IRBuilder<> Builder(&SI);
StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP);
NewStore->setAlignment(SI.getAlign());
SI.replaceAllUsesWith(NewStore);
SI.eraseFromParent();
visitGetElementPtrInst(*OldGEP);
return true;
}
if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand))
SI.setOperand(I, NewGlobal);
}
return false;
}

bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {

unsigned NumOperands = GEPI.getNumOperands();
GlobalVariable *NewGlobal = nullptr;
for (unsigned I = 0; I < NumOperands; ++I) {
Value *CurrOpperand = GEPI.getOperand(I);
GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOpperand);
if (!NewGlobal)
continue;
IRBuilder<> Builder(&GEPI);

SmallVector<Value *, MaxVecSize> Indices;
for (auto &Index : GEPI.indices())
Indices.push_back(Index);

Value *NewGEP =
Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices);

GEPI.replaceAllUsesWith(NewGEP);
PotentiallyDeadInstrs.emplace_back(&GEPI);
NewGlobal = lookupReplacementGlobal(CurrOpperand);
if (NewGlobal)
break;
}
if (!NewGlobal)
return false;

IRBuilder<> Builder(&GEPI);
SmallVector<Value *, MaxVecSize> Indices;
for (auto &Index : GEPI.indices())
Indices.push_back(Index);

Value *NewGEP =
Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices,
GEPI.getName(), GEPI.getNoWrapFlags());
GEPI.replaceAllUsesWith(NewGEP);
GEPI.eraseFromParent();
return true;
}

@@ -247,17 +260,13 @@ static bool findAndReplaceVectors(Module &M) {
for (User *U : make_early_inc_range(G.users())) {
if (isa<ConstantExpr>(U) && isa<Operator>(U)) {
ConstantExpr *CE = cast<ConstantExpr>(U);
convertUsersOfConstantsToInstructions(CE,
/*RestrictToFunc=*/nullptr,
/*RemoveDeadConstants=*/false,
/*IncludeSelf=*/true);
}
if (isa<Instruction>(U)) {
Instruction *Inst = cast<Instruction>(U);
Function *F = Inst->getFunction();
if (F)
Impl.visit(*F);
for (User *UCE : make_early_inc_range(CE->users())) {
if (Instruction *Inst = dyn_cast<Instruction>(UCE))
Impl.visit(*Inst);
}
}
if (Instruction *Inst = dyn_cast<Instruction>(U))
Impl.visit(*Inst);
}
}
}
42 changes: 31 additions & 11 deletions llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
@@ -164,11 +164,18 @@ bool DXILFlattenArraysVisitor::visitLoadInst(LoadInst &LI) {
Value *CurrOpperand = LI.getOperand(I);
ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
convertUsersOfConstantsToInstructions(CE,
/*RestrictToFunc=*/nullptr,
/*RemoveDeadConstants=*/false,
/*IncludeSelf=*/true);
return false;
GetElementPtrInst *OldGEP =
cast<GetElementPtrInst>(CE->getAsInstruction());
OldGEP->insertBefore(&LI);

IRBuilder<> Builder(&LI);
LoadInst *NewLoad =
Builder.CreateLoad(LI.getType(), OldGEP, LI.getName());
NewLoad->setAlignment(LI.getAlign());
LI.replaceAllUsesWith(NewLoad);
LI.eraseFromParent();
visitGetElementPtrInst(*OldGEP);
return true;
}
}
return false;
@@ -180,11 +187,17 @@ bool DXILFlattenArraysVisitor::visitStoreInst(StoreInst &SI) {
Value *CurrOpperand = SI.getOperand(I);
ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand);
if (CE && CE->getOpcode() == Instruction::GetElementPtr) {
convertUsersOfConstantsToInstructions(CE,
/*RestrictToFunc=*/nullptr,
/*RemoveDeadConstants=*/false,
/*IncludeSelf=*/true);
return false;
GetElementPtrInst *OldGEP =
cast<GetElementPtrInst>(CE->getAsInstruction());
OldGEP->insertBefore(&SI);

IRBuilder<> Builder(&SI);
StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP);
NewStore->setAlignment(SI.getAlign());
SI.replaceAllUsesWith(NewStore);
SI.eraseFromParent();
visitGetElementPtrInst(*OldGEP);
return true;
}
}
return false;
@@ -317,10 +330,17 @@ bool DXILFlattenArraysVisitor::visit(Function &F) {
static void collectElements(Constant *Init,
SmallVectorImpl<Constant *> &Elements) {
// Base case: If Init is not an array, add it directly to the vector.
if (!isa<ArrayType>(Init->getType())) {
auto *ArrayTy = dyn_cast<ArrayType>(Init->getType());
if (!ArrayTy) {
Elements.push_back(Init);
return;
}
unsigned ArrSize = ArrayTy->getNumElements();
if (isa<ConstantAggregateZero>(Init)) {
for (unsigned I = 0; I < ArrSize; ++I)
Elements.push_back(Constant::getNullValue(ArrayTy->getElementType()));
return;
}

// Recursive case: Process each element in the array.
if (auto *ArrayConstant = dyn_cast<ConstantArray>(Init)) {
23 changes: 23 additions & 0 deletions llvm/test/CodeGen/DirectX/flatten-bug-117273.ll
@@ -0,0 +1,23 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -passes='dxil-flatten-arrays,dxil-op-lower' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s


@ZerroInitArr = internal constant [2 x [3 x float]] [[3 x float] zeroinitializer, [3 x float] [float 1.000000e+00, float 1.000000e+00, float 1.000000e+00]], align 16


define internal void @main() {
; CHECK-LABEL: define internal void @main() {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr [24 x float], ptr @ZerroInitArr.1dim, i32 1
; CHECK-NEXT: [[DOTI0:%.*]] = load float, ptr [[TMP0]], align 16
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr [24 x float], ptr @ZerroInitArr.1dim, i32 2
; CHECK-NEXT: [[DOTI03:%.*]] = load float, ptr [[TMP1]], align 16
; CHECK-NEXT: ret void
;
entry:
%0 = getelementptr [8 x [3 x float]], ptr @ZerroInitArr, i32 0, i32 1
%.i0 = load float, ptr %0, align 16
%1 = getelementptr [8 x [3 x float]], ptr @ZerroInitArr, i32 0, i32 2
%.i03 = load float, ptr %1, align 16
ret void
}
29 changes: 9 additions & 20 deletions llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll
@@ -21,7 +21,6 @@
; CHECK-NOT: @groushared2dArrayofVectors
; CHECK-NOT: @groushared2dArrayofVectors.scalarized


define <4 x i32> @load_array_vec_test() #0 {
; CHECK-LABEL: define <4 x i32> @load_array_vec_test(
; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
@@ -33,18 +32,13 @@ define <4 x i32> @load_array_vec_test() #0 {
; CHECK-NEXT: [[TMP6:%.*]] = load i32, ptr addrspace(3) [[TMP5]], align 4
; CHECK-NEXT: [[TMP7:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 3) to ptr addrspace(3)
; CHECK-NEXT: [[TMP8:%.*]] = load i32, ptr addrspace(3) [[TMP7]], align 4
; CHECK-NEXT: [[TMP9:%.*]] = bitcast ptr addrspace(3) @arrayofVecData.scalarized.1dim to ptr addrspace(3)
; CHECK-NEXT: [[TMP10:%.*]] = getelementptr [2 x [3 x float]], ptr addrspace(3) [[TMP9]], i32 0, i32 1
Review comment (PR author): This is one of the bugs from convertUsersOfConstantsToInstructions. It would split the bitcast from the GEP, but then we would never visit the GEP, so the GEP type never got flattened.
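For context, here is a minimal C++ sketch of the approach the patch takes instead (the function name rewriteLoadOfConstGEP and the VisitGEP callback are illustrative only, not LLVM or pass APIs): when a load's pointer operand is a constant GEP expression, materialize it as a real GetElementPtrInst, rebuild the load on top of it, and hand the new GEP back to the visitor so its source type gets flattened as well.

#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"

using namespace llvm;

// Sketch only: VisitGEP stands in for the pass's visitGetElementPtrInst.
template <typename GEPVisitorT>
static bool rewriteLoadOfConstGEP(LoadInst &LI, GEPVisitorT VisitGEP) {
  auto *CE = dyn_cast<ConstantExpr>(LI.getPointerOperand());
  if (!CE || CE->getOpcode() != Instruction::GetElementPtr)
    return false;

  // Turn the constant expression into an explicit GEP instruction placed
  // right before the load, so later visitation can rewrite its type.
  auto *OldGEP = cast<GetElementPtrInst>(CE->getAsInstruction());
  OldGEP->insertBefore(&LI);

  // Rebuild the load on the materialized GEP and drop the original.
  IRBuilder<> Builder(&LI);
  LoadInst *NewLoad = Builder.CreateLoad(LI.getType(), OldGEP, LI.getName());
  NewLoad->setAlignment(LI.getAlign());
  LI.replaceAllUsesWith(NewLoad);
  LI.eraseFromParent();

  // Crucially, visit the new GEP instead of leaving it behind unflattened.
  VisitGEP(*OldGEP);
  return true;
}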
; CHECK-NEXT: [[TMP11:%.*]] = bitcast ptr addrspace(3) [[TMP10]] to ptr addrspace(3)
; CHECK-NEXT: [[TMP11:%.*]] = bitcast ptr addrspace(3) getelementptr inbounds ([6 x float], ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 1) to ptr addrspace(3)
; CHECK-NEXT: [[TMP12:%.*]] = load i32, ptr addrspace(3) [[TMP11]], align 4
; CHECK-NEXT: [[TMP13:%.*]] = bitcast ptr addrspace(3) [[TMP10]] to ptr addrspace(3)
; CHECK-NEXT: [[DOTI12:%.*]] = getelementptr i32, ptr addrspace(3) [[TMP13]], i32 1
; CHECK-NEXT: [[DOTI12:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([6 x float], ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 1), i32 1) to ptr addrspace(3)
; CHECK-NEXT: [[DOTI13:%.*]] = load i32, ptr addrspace(3) [[DOTI12]], align 4
; CHECK-NEXT: [[TMP14:%.*]] = bitcast ptr addrspace(3) [[TMP10]] to ptr addrspace(3)
; CHECK-NEXT: [[DOTI24:%.*]] = getelementptr i32, ptr addrspace(3) [[TMP14]], i32 2
; CHECK-NEXT: [[DOTI24:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([6 x float], ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 1), i32 2) to ptr addrspace(3)
; CHECK-NEXT: [[DOTI25:%.*]] = load i32, ptr addrspace(3) [[DOTI24]], align 4
; CHECK-NEXT: [[TMP15:%.*]] = bitcast ptr addrspace(3) [[TMP10]] to ptr addrspace(3)
; CHECK-NEXT: [[DOTI36:%.*]] = getelementptr i32, ptr addrspace(3) [[TMP15]], i32 3
; CHECK-NEXT: [[DOTI36:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([6 x float], ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 1), i32 3) to ptr addrspace(3)
; CHECK-NEXT: [[DOTI37:%.*]] = load i32, ptr addrspace(3) [[DOTI36]], align 4
; CHECK-NEXT: [[DOTI08:%.*]] = add i32 [[TMP2]], [[TMP12]]
; CHECK-NEXT: [[DOTI19:%.*]] = add i32 [[TMP4]], [[DOTI13]]
@@ -87,7 +81,7 @@ define <4 x i32> @load_vec_test() #0 {
define <4 x i32> @load_static_array_of_vec_test(i32 %index) #0 {
; CHECK-LABEL: define <4 x i32> @load_static_array_of_vec_test(
; CHECK-SAME: i32 [[INDEX:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: [[DOTFLAT:%.*]] = getelementptr [12 x i32], ptr @staticArrayOfVecData.scalarized.1dim, i32 [[INDEX]]
; CHECK-NEXT: [[DOTFLAT:%.*]] = getelementptr inbounds [12 x i32], ptr @staticArrayOfVecData.scalarized.1dim, i32 [[INDEX]]
Review comment (PR author): Our GEP builders were not preserving inbounds. Now we are.
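As a minimal sketch of what "preserving inbounds" means here (assuming an LLVM version where IRBuilder::CreateGEP accepts GEPNoWrapFlags; the function name is illustrative): forward the original GEP's no-wrap flags when rebuilding it against the replacement global.

#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"

using namespace llvm;

static Value *rebuildGEPOnReplacementGlobal(GetElementPtrInst &GEPI,
                                            GlobalVariable *NewGlobal) {
  IRBuilder<> Builder(&GEPI);

  // Copy the original indices unchanged.
  SmallVector<Value *, 4> Indices;
  for (auto &Index : GEPI.indices())
    Indices.push_back(Index);

  // Passing GEPI.getNoWrapFlags() carries inbounds (and nuw/nusw) over to
  // the rebuilt GEP; dropping the flags was the issue noted above.
  return Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices,
                           GEPI.getName(), GEPI.getNoWrapFlags());
}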
; CHECK-NEXT: [[TMP1:%.*]] = bitcast ptr [[DOTFLAT]] to ptr
; CHECK-NEXT: [[TMP2:%.*]] = load i32, ptr [[TMP1]], align 4
; CHECK-NEXT: [[TMP3:%.*]] = bitcast ptr [[DOTFLAT]] to ptr
@@ -121,18 +115,13 @@ define <4 x i32> @multid_load_test() #0 {
; CHECK-NEXT: [[TMP6:%.*]] = load i32, ptr addrspace(3) [[TMP5]], align 4
; CHECK-NEXT: [[TMP7:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 3) to ptr addrspace(3)
; CHECK-NEXT: [[TMP8:%.*]] = load i32, ptr addrspace(3) [[TMP7]], align 4
; CHECK-NEXT: [[TMP9:%.*]] = bitcast ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim to ptr addrspace(3)
; CHECK-NEXT: [[TMP10:%.*]] = getelementptr [3 x [3 x [4 x i32]]], ptr addrspace(3) [[TMP9]], i32 0, i32 1, i32 1
; CHECK-NEXT: [[TMP11:%.*]] = bitcast ptr addrspace(3) [[TMP10]] to ptr addrspace(3)
; CHECK-NEXT: [[TMP11:%.*]] = bitcast ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 1) to ptr addrspace(3)
; CHECK-NEXT: [[TMP12:%.*]] = load i32, ptr addrspace(3) [[TMP11]], align 4
; CHECK-NEXT: [[TMP13:%.*]] = bitcast ptr addrspace(3) [[TMP10]] to ptr addrspace(3)
; CHECK-NEXT: [[DOTI12:%.*]] = getelementptr i32, ptr addrspace(3) [[TMP13]], i32 1
; CHECK-NEXT: [[DOTI12:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 1), i32 1) to ptr addrspace(3)
; CHECK-NEXT: [[DOTI13:%.*]] = load i32, ptr addrspace(3) [[DOTI12]], align 4
; CHECK-NEXT: [[TMP14:%.*]] = bitcast ptr addrspace(3) [[TMP10]] to ptr addrspace(3)
; CHECK-NEXT: [[DOTI24:%.*]] = getelementptr i32, ptr addrspace(3) [[TMP14]], i32 2
; CHECK-NEXT: [[DOTI24:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 1), i32 2) to ptr addrspace(3)
; CHECK-NEXT: [[DOTI25:%.*]] = load i32, ptr addrspace(3) [[DOTI24]], align 4
; CHECK-NEXT: [[TMP15:%.*]] = bitcast ptr addrspace(3) [[TMP10]] to ptr addrspace(3)
; CHECK-NEXT: [[DOTI36:%.*]] = getelementptr i32, ptr addrspace(3) [[TMP15]], i32 3
; CHECK-NEXT: [[DOTI36:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 1), i32 3) to ptr addrspace(3)
; CHECK-NEXT: [[DOTI37:%.*]] = load i32, ptr addrspace(3) [[DOTI36]], align 4
; CHECK-NEXT: [[DOTI08:%.*]] = add i32 [[TMP2]], [[TMP12]]
; CHECK-NEXT: [[DOTI19:%.*]] = add i32 [[TMP4]], [[DOTI13]]
25 changes: 25 additions & 0 deletions llvm/test/CodeGen/DirectX/scalar-bug-117273.ll
@@ -0,0 +1,25 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -passes='dxil-data-scalarization,dxil-flatten-arrays,function(scalarizer<load-store>),dxil-op-lower' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s


@StaticArr = internal constant [8 x <3 x float>] [<3 x float> zeroinitializer, <3 x float> splat (float 5.000000e-01), <3 x float> <float 1.000000e+00, float 5.000000e-01, float 5.000000e-01>, <3 x float> <float 5.000000e-01, float 1.000000e+00, float 5.000000e-01>, <3 x float> <float 5.000000e-01, float 5.000000e-01, float 1.000000e+00>, <3 x float> <float 5.000000e-01, float 1.000000e+00, float 1.000000e+00>, <3 x float> <float 1.000000e+00, float 5.000000e-01, float 1.000000e+00>, <3 x float> <float 1.000000e+00, float 1.000000e+00, float 5.000000e-01>], align 16

; Function Attrs: alwaysinline convergent mustprogress norecurse nounwind
define internal void @main() #1 {
; CHECK-LABEL: define internal void @main() {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[DOTI0:%.*]] = load float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 1), align 16
; CHECK-NEXT: [[DOTI1:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 1), i32 1), align 4
; CHECK-NEXT: [[DOTI2:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 1), i32 2), align 8
; CHECK-NEXT: [[DOTI01:%.*]] = load float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 2), align 16
; CHECK-NEXT: [[DOTI12:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 2), i32 1), align 4
; CHECK-NEXT: [[DOTI23:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 2), i32 2), align 8
; CHECK-NEXT: ret void
;
entry:
%arrayidx = getelementptr inbounds [8 x <3 x float>], ptr @StaticArr, i32 0, i32 1
%2 = load <3 x float>, ptr %arrayidx, align 16
%arrayidx2 = getelementptr inbounds [8 x <3 x float>], ptr @StaticArr, i32 0, i32 2
%3 = load <3 x float>, ptr %arrayidx2, align 16
ret void
}