Skip to content

[SPIR-V] Add store legalization for ptrcast #135369

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,124 @@ class SPIRVLegalizePointerCast : public FunctionPass {
DeadInstructions.push_back(LI);
}

// Creates an spv_insertelt instruction (equivalent to llvm's insertelement).
Value *makeInsertElement(IRBuilder<> &B, Value *Vector, Value *Element,
unsigned Index) {
Type *Int32Ty = Type::getInt32Ty(B.getContext());
SmallVector<Type *, 4> Types = {Vector->getType(), Vector->getType(),
Element->getType(), Int32Ty};
SmallVector<Value *> Args = {Vector, Element, B.getInt32(Index)};
Instruction *NewI =
B.CreateIntrinsic(Intrinsic::spv_insertelt, {Types}, {Args});
buildAssignType(B, Vector->getType(), NewI);
return NewI;
}

// Creates an spv_extractelt instruction (equivalent to llvm's
// extractelement).
Value *makeExtractElement(IRBuilder<> &B, Type *ElementType, Value *Vector,
unsigned Index) {
Type *Int32Ty = Type::getInt32Ty(B.getContext());
SmallVector<Type *, 3> Types = {ElementType, Vector->getType(), Int32Ty};
SmallVector<Value *> Args = {Vector, B.getInt32(Index)};
Instruction *NewI =
B.CreateIntrinsic(Intrinsic::spv_extractelt, {Types}, {Args});
buildAssignType(B, ElementType, NewI);
return NewI;
}

// Stores the given Src vector operand into the Dst vector, adjusting the size
// if required.
Value *storeVectorFromVector(IRBuilder<> &B, Value *Src, Value *Dst,
Align Alignment) {
FixedVectorType *SrcType = cast<FixedVectorType>(Src->getType());
FixedVectorType *DstType =
cast<FixedVectorType>(GR->findDeducedElementType(Dst));
assert(DstType->getNumElements() >= SrcType->getNumElements());

LoadInst *LI = B.CreateLoad(DstType, Dst);
LI->setAlignment(Alignment);
Value *OldValues = LI;
buildAssignType(B, OldValues->getType(), OldValues);
Value *NewValues = Src;

for (unsigned I = 0; I < SrcType->getNumElements(); ++I) {
Value *Element =
makeExtractElement(B, SrcType->getElementType(), NewValues, I);
OldValues = makeInsertElement(B, OldValues, Element, I);
}

StoreInst *SI = B.CreateStore(OldValues, Dst);
SI->setAlignment(Alignment);
return SI;
}

void buildGEPIndexChain(IRBuilder<> &B, Type *Search, Type *Aggregate,
SmallVectorImpl<Value *> &Indices) {
Indices.push_back(B.getInt32(0));

if (Search == Aggregate)
return;

if (auto *ST = dyn_cast<StructType>(Aggregate))
buildGEPIndexChain(B, Search, ST->getTypeAtIndex(0u), Indices);
else if (auto *AT = dyn_cast<ArrayType>(Aggregate))
buildGEPIndexChain(B, Search, AT->getElementType(), Indices);
else if (auto *VT = dyn_cast<FixedVectorType>(Aggregate))
buildGEPIndexChain(B, Search, VT->getElementType(), Indices);
else
llvm_unreachable("Bad access chain?");
}

// Stores the given Src value into the first entry of the Dst aggregate.
Value *storeToFirstValueAggregate(IRBuilder<> &B, Value *Src, Value *Dst,
Type *DstPointeeType, Align Alignment) {
SmallVector<Type *, 2> Types = {Dst->getType(), Dst->getType()};
SmallVector<Value *, 3> Args{/* isInBounds= */ B.getInt1(true), Dst};
buildGEPIndexChain(B, Src->getType(), DstPointeeType, Args);
auto *GEP = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
GR->buildAssignPtr(B, Src->getType(), GEP);
StoreInst *SI = B.CreateStore(Src, GEP);
SI->setAlignment(Alignment);
return SI;
}

bool isTypeFirstElementAggregate(Type *Search, Type *Aggregate) {
if (Search == Aggregate)
return true;
if (auto *ST = dyn_cast<StructType>(Aggregate))
return isTypeFirstElementAggregate(Search, ST->getTypeAtIndex(0u));
if (auto *VT = dyn_cast<FixedVectorType>(Aggregate))
return isTypeFirstElementAggregate(Search, VT->getElementType());
if (auto *AT = dyn_cast<ArrayType>(Aggregate))
return isTypeFirstElementAggregate(Search, AT->getElementType());
return false;
}

// Transforms a store instruction (or SPV intrinsic) using a ptrcast as
// operand into a valid logical SPIR-V store with no ptrcast.
void transformStore(IRBuilder<> &B, Instruction *BadStore, Value *Src,
Value *Dst, Align Alignment) {
Type *ToTy = GR->findDeducedElementType(Dst);
Type *FromTy = Src->getType();

auto *S_VT = dyn_cast<FixedVectorType>(FromTy);
auto *D_ST = dyn_cast<StructType>(ToTy);
auto *D_VT = dyn_cast<FixedVectorType>(ToTy);

B.SetInsertPoint(BadStore);
if (D_ST && isTypeFirstElementAggregate(FromTy, D_ST))
storeToFirstValueAggregate(B, Src, Dst, D_ST, Alignment);
else if (D_VT && S_VT)
storeVectorFromVector(B, Src, Dst, Alignment);
else if (D_VT && !S_VT && FromTy == D_VT->getElementType())
storeToFirstValueAggregate(B, Src, Dst, D_VT, Alignment);
else
llvm_unreachable("Unsupported ptrcast use in store. Please fix.");

DeadInstructions.push_back(BadStore);
}

void legalizePointerCast(IntrinsicInst *II) {
Value *CastedOperand = II;
Value *OriginalOperand = II->getOperand(0);
Expand All @@ -165,6 +283,12 @@ class SPIRVLegalizePointerCast : public FunctionPass {
continue;
}

if (StoreInst *SI = dyn_cast<StoreInst>(User)) {
transformStore(B, SI, SI->getValueOperand(), OriginalOperand,
SI->getAlign());
continue;
}

if (IntrinsicInst *Intrin = dyn_cast<IntrinsicInst>(User)) {
if (Intrin->getIntrinsicID() == Intrinsic::spv_assign_ptr_type) {
DeadInstructions.push_back(Intrin);
Expand All @@ -176,6 +300,15 @@ class SPIRVLegalizePointerCast : public FunctionPass {
/* DeleteOld= */ false);
continue;
}

if (Intrin->getIntrinsicID() == Intrinsic::spv_store) {
Align Alignment;
if (ConstantInt *C = dyn_cast<ConstantInt>(Intrin->getOperand(3)))
Alignment = Align(C->getZExtValue());
transformStore(B, Intrin, Intrin->getArgOperand(0), OriginalOperand,
Alignment);
continue;
}
}

llvm_unreachable("Unsupported ptrcast user. Please fix.");
Expand Down
20 changes: 20 additions & 0 deletions llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-struct.ll
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,23 @@ entry:
%val = load i32, ptr addrspace(10) %ptr
ret i32 %val
}

define spir_func void @foos(i64 noundef %index) local_unnamed_addr {
; CHECK: %[[#index:]] = OpFunctionParameter %[[#uint64]]
entry:
; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#uint_pp]] %[[#global1]] %[[#uint_0]] %[[#index]]
%ptr = getelementptr inbounds %S1, ptr addrspace(10) @global1, i64 0, i32 0, i64 %index
; CHECK: OpStore %[[#ptr]] %[[#uint_0]] Aligned 4
store i32 0, ptr addrspace(10) %ptr
ret void
}

define spir_func void @bars(i64 noundef %index) local_unnamed_addr {
; CHECK: %[[#index:]] = OpFunctionParameter %[[#uint64]]
entry:
; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#uint_pp]] %[[#global2]] %[[#uint_0]] %[[#uint_0]] %[[#index]] %[[#uint_1]]
%ptr = getelementptr inbounds %S2, ptr addrspace(10) @global2, i64 0, i32 0, i32 0, i64 %index, i32 1
; CHECK: OpStore %[[#ptr]] %[[#uint_0]] Aligned 4
store i32 0, ptr addrspace(10) %ptr
ret void
}
110 changes: 110 additions & 0 deletions llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-vector.ll
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@
; CHECK-DAG: %[[#uint_pp:]] = OpTypePointer Private %[[#uint]]
; CHECK-DAG: %[[#uint_fp:]] = OpTypePointer Function %[[#uint]]
; CHECK-DAG: %[[#uint_0:]] = OpConstant %[[#uint]] 0
; CHECK-DAG: %[[#uint_1:]] = OpConstant %[[#uint]] 1
; CHECK-DAG: %[[#uint_2:]] = OpConstant %[[#uint]] 2
; CHECK-DAG: %[[#v2:]] = OpTypeVector %[[#uint]] 2
; CHECK-DAG: %[[#v3:]] = OpTypeVector %[[#uint]] 3
; CHECK-DAG: %[[#v4:]] = OpTypeVector %[[#uint]] 4
; CHECK-DAG: %[[#v2_01:]] = OpConstantComposite %[[#v2]] %[[#uint_0]] %[[#uint_1]]
; CHECK-DAG: %[[#v3_012:]] = OpConstantComposite %[[#v3]] %[[#uint_0]] %[[#uint_1]] %[[#uint_2]]
; CHECK-DAG: %[[#v4_pp:]] = OpTypePointer Private %[[#v4]]
; CHECK-DAG: %[[#v4_fp:]] = OpTypePointer Function %[[#v4]]

Expand Down Expand Up @@ -108,3 +112,109 @@ define internal spir_func i32 @bazBounds(ptr %a) {
ret i32 %2
; CHECK: OpReturnValue %[[#val]]
}

define internal spir_func void @foos(ptr addrspace(10) %a) {

%1 = getelementptr inbounds <4 x i32>, ptr addrspace(10) %a, i64 0
; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#v4_pp]] %[[#]]

store <3 x i32> <i32 0, i32 1, i32 2>, ptr addrspace(10) %1, align 16
; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 16
; CHECK: %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 0
; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
; CHECK: %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 1
; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
; CHECK: %[[#C:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 2
; CHECK: %[[#out3:]] = OpCompositeInsert %[[#v4]] %[[#C]] %[[#out2]] 2
; CHECK: OpStore %[[#ptr]] %[[#out3]] Aligned 16

ret void
}

define internal spir_func void @foosDefault(ptr %a) {

%1 = getelementptr inbounds <4 x i32>, ptr %a, i64 0
; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#v4_fp]] %[[#]]

store <3 x i32> <i32 0, i32 1, i32 2>, ptr %1, align 16
; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 16
; CHECK: %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 0
; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
; CHECK: %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 1
; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
; CHECK: %[[#C:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 2
; CHECK: %[[#out3:]] = OpCompositeInsert %[[#v4]] %[[#C]] %[[#out2]] 2
; CHECK: OpStore %[[#ptr]] %[[#out3]] Aligned 16

ret void
}

define internal spir_func void @foosBounds(ptr %a) {

%1 = getelementptr <4 x i32>, ptr %a, i64 0
; CHECK: %[[#ptr:]] = OpAccessChain %[[#v4_fp]] %[[#]]

store <3 x i32> <i32 0, i32 1, i32 2>, ptr %1, align 64
; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 64
; CHECK: %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 0
; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
; CHECK: %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 1
; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
; CHECK: %[[#C:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 2
; CHECK: %[[#out3:]] = OpCompositeInsert %[[#v4]] %[[#C]] %[[#out2]] 2
; CHECK: OpStore %[[#ptr]] %[[#out3]] Aligned 64

ret void
}

define internal spir_func void @bars(ptr addrspace(10) %a) {

%1 = getelementptr <4 x i32>, ptr addrspace(10) %a, i64 0
; CHECK: %[[#ptr:]] = OpAccessChain %[[#v4_pp]] %[[#]]

store <2 x i32> <i32 0, i32 1>, ptr addrspace(10) %1, align 16
; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 16
; CHECK: %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v2_01]] 0
; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
; CHECK: %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v2_01]] 1
; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
; CHECK: OpStore %[[#ptr]] %[[#out2]] Aligned 1

ret void
}

define internal spir_func void @bazs(ptr addrspace(10) %a) {

%1 = getelementptr <4 x i32>, ptr addrspace(10) %a, i64 0
; CHECK: %[[#ptr:]] = OpAccessChain %[[#v4_pp]] %[[#]]

store i32 0, ptr addrspace(10) %1, align 32
; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_pp]] %[[#ptr]] %[[#uint_0]]
; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 32

ret void
}

define internal spir_func void @bazsDefault(ptr %a) {

%1 = getelementptr inbounds <4 x i32>, ptr %a, i64 0
; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#v4_fp]] %[[#]]

store i32 0, ptr %1, align 16
; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_fp]] %[[#ptr]] %[[#uint_0]]
; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 16

ret void
}

define internal spir_func void @bazsBounds(ptr %a) {

%1 = getelementptr <4 x i32>, ptr %a, i64 0
; CHECK: %[[#ptr:]] = OpAccessChain %[[#v4_fp]] %[[#]]

store i32 0, ptr %1, align 16
; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_fp]] %[[#ptr]] %[[#uint_0]]
; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 16

ret void
}
Loading
Loading