-
Notifications
You must be signed in to change notification settings - Fork 13.3k
[SPIR-V] Add store legalization for ptrcast #135369
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-backend-spir-v Author: Nathan Gauër (Keenuts) Changes: This commit adds handling for a spv.ptrcast result being used in a store instruction, modifying the store to operate on the source type. Full diff: https://github.com/llvm/llvm-project/pull/135369.diff 4 Files Affected:
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 5ba4fbb02560d..f3f1558265d4a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -150,6 +150,95 @@ class SPIRVLegalizePointerCast : public FunctionPass {
DeadInstructions.push_back(LI);
}
+ // Creates an spv_insertelt instruction (equivalent to llvm's insertelement).
+ Value *makeInsertElement(IRBuilder<> &B, Value *Vector, Value *Element,
+ unsigned Index) {
+ Type *Int32Ty = Type::getInt32Ty(B.getContext());
+ SmallVector<Type *, 4> Types = {Vector->getType(), Vector->getType(),
+ Element->getType(), Int32Ty};
+ SmallVector<Value *> Args = {Vector, Element, B.getInt32(Index)};
+ Instruction *NewI =
+ B.CreateIntrinsic(Intrinsic::spv_insertelt, {Types}, {Args});
+ buildAssignType(B, Vector->getType(), NewI);
+ return NewI;
+ }
+
+ // Creates an spv_extractelt instruction (equivalent to llvm's
+ // extractelement).
+ Value *makeExtractElement(IRBuilder<> &B, Type *ElementType, Value *Vector,
+ unsigned Index) {
+ Type *Int32Ty = Type::getInt32Ty(B.getContext());
+ SmallVector<Type *, 3> Types = {ElementType, Vector->getType(), Int32Ty};
+ SmallVector<Value *> Args = {Vector, B.getInt32(Index)};
+ Instruction *NewI =
+ B.CreateIntrinsic(Intrinsic::spv_extractelt, {Types}, {Args});
+ buildAssignType(B, ElementType, NewI);
+ return NewI;
+ }
+
+ // Stores the given Src vector operand into the Dst vector, adjusting the size
+ // if required.
+ Value *storeVectorFromVector(IRBuilder<> &B, Value *Src, Value *Dst,
+ Align Alignment) {
+ FixedVectorType *SrcType = cast<FixedVectorType>(Src->getType());
+ FixedVectorType *DstType =
+ cast<FixedVectorType>(GR->findDeducedElementType(Dst));
+ assert(DstType->getNumElements() >= SrcType->getNumElements());
+
+ LoadInst *LI = B.CreateLoad(DstType, Dst);
+ LI->setAlignment(Alignment);
+ Value *OldValues = LI;
+ buildAssignType(B, OldValues->getType(), OldValues);
+ Value *NewValues = Src;
+
+ for (unsigned I = 0; I < SrcType->getNumElements(); ++I) {
+ Value *Element =
+ makeExtractElement(B, SrcType->getElementType(), NewValues, I);
+ OldValues = makeInsertElement(B, OldValues, Element, I);
+ }
+
+ StoreInst *SI = B.CreateStore(OldValues, Dst);
+ SI->setAlignment(Alignment);
+ return SI;
+ }
+
+ // Stores the given Src value into the first entry of the Dst aggregate.
+ Value *storeToFirstValueAggregate(IRBuilder<> &B, Value *Src, Value *Dst,
+ Align Alignment) {
+ SmallVector<Type *, 2> Types = {Dst->getType(), Dst->getType()};
+ SmallVector<Value *, 3> Args{/* isInBounds= */ B.getInt1(true), Dst,
+ B.getInt32(0), B.getInt32(0)};
+ auto *GEP = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
+ GR->buildAssignPtr(B, Src->getType(), GEP);
+ StoreInst *SI = B.CreateStore(Src, GEP);
+ SI->setAlignment(Alignment);
+ return SI;
+ }
+
+ // Transforms a store instruction (or SPV intrinsic) using a ptrcast as
+ // operand into a valid logical SPIR-V store with no ptrcast.
+ void transformStore(IRBuilder<> &B, Instruction *BadStore, Value *Src,
+ Value *Dst, Align Alignment) {
+ Type *ToTy = GR->findDeducedElementType(Dst);
+ Type *FromTy = Src->getType();
+
+ auto *SVT = dyn_cast<FixedVectorType>(FromTy);
+ auto *DST = dyn_cast<StructType>(ToTy);
+ auto *DVT = dyn_cast<FixedVectorType>(ToTy);
+
+ B.SetInsertPoint(BadStore);
+ if (DST && DST->getTypeAtIndex(0u) == FromTy)
+ storeToFirstValueAggregate(B, Src, Dst, Alignment);
+ else if (DVT && SVT)
+ storeVectorFromVector(B, Src, Dst, Alignment);
+ else if (DVT && !SVT && FromTy == DVT->getElementType())
+ storeToFirstValueAggregate(B, Src, Dst, Alignment);
+ else
+ llvm_unreachable("Unsupported ptrcast use in store. Please fix.");
+
+ DeadInstructions.push_back(BadStore);
+ }
+
void legalizePointerCast(IntrinsicInst *II) {
Value *CastedOperand = II;
Value *OriginalOperand = II->getOperand(0);
@@ -165,6 +254,12 @@ class SPIRVLegalizePointerCast : public FunctionPass {
continue;
}
+ if (StoreInst *SI = dyn_cast<StoreInst>(User)) {
+ transformStore(B, SI, SI->getValueOperand(), OriginalOperand,
+ SI->getAlign());
+ continue;
+ }
+
if (IntrinsicInst *Intrin = dyn_cast<IntrinsicInst>(User)) {
if (Intrin->getIntrinsicID() == Intrinsic::spv_assign_ptr_type) {
DeadInstructions.push_back(Intrin);
@@ -176,6 +271,15 @@ class SPIRVLegalizePointerCast : public FunctionPass {
/* DeleteOld= */ false);
continue;
}
+
+ if (Intrin->getIntrinsicID() == Intrinsic::spv_store) {
+ Align Alignment;
+ if (ConstantInt *C = dyn_cast<ConstantInt>(Intrin->getOperand(3)))
+ Alignment = Align(C->getZExtValue());
+ transformStore(B, Intrin, Intrin->getArgOperand(0), OriginalOperand,
+ Alignment);
+ continue;
+ }
}
llvm_unreachable("Unsupported ptrcast user. Please fix.");
diff --git a/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-struct.ll b/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-struct.ll
index b0a68a30e29be..35e5880881e5c 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-struct.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-struct.ll
@@ -45,3 +45,23 @@ entry:
%val = load i32, ptr addrspace(10) %ptr
ret i32 %val
}
+
+define spir_func void @foos(i64 noundef %index) local_unnamed_addr {
+; CHECK: %[[#index:]] = OpFunctionParameter %[[#uint64]]
+entry:
+; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#uint_pp]] %[[#global1]] %[[#uint_0]] %[[#index]]
+ %ptr = getelementptr inbounds %S1, ptr addrspace(10) @global1, i64 0, i32 0, i64 %index
+; CHECK: OpStore %[[#ptr]] %[[#uint_0]] Aligned 4
+ store i32 0, ptr addrspace(10) %ptr
+ ret void
+}
+
+define spir_func void @bars(i64 noundef %index) local_unnamed_addr {
+; CHECK: %[[#index:]] = OpFunctionParameter %[[#uint64]]
+entry:
+; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#uint_pp]] %[[#global2]] %[[#uint_0]] %[[#uint_0]] %[[#index]] %[[#uint_1]]
+ %ptr = getelementptr inbounds %S2, ptr addrspace(10) @global2, i64 0, i32 0, i32 0, i64 %index, i32 1
+; CHECK: OpStore %[[#ptr]] %[[#uint_0]] Aligned 4
+ store i32 0, ptr addrspace(10) %ptr
+ ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-vector.ll b/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-vector.ll
index d4131fa8a2658..be9e2a23365cc 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-vector.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/getelementptr-downcast-vector.ll
@@ -5,9 +5,13 @@
; CHECK-DAG: %[[#uint_pp:]] = OpTypePointer Private %[[#uint]]
; CHECK-DAG: %[[#uint_fp:]] = OpTypePointer Function %[[#uint]]
; CHECK-DAG: %[[#uint_0:]] = OpConstant %[[#uint]] 0
+; CHECK-DAG: %[[#uint_1:]] = OpConstant %[[#uint]] 1
+; CHECK-DAG: %[[#uint_2:]] = OpConstant %[[#uint]] 2
; CHECK-DAG: %[[#v2:]] = OpTypeVector %[[#uint]] 2
; CHECK-DAG: %[[#v3:]] = OpTypeVector %[[#uint]] 3
; CHECK-DAG: %[[#v4:]] = OpTypeVector %[[#uint]] 4
+; CHECK-DAG: %[[#v2_01:]] = OpConstantComposite %[[#v2]] %[[#uint_0]] %[[#uint_1]]
+; CHECK-DAG: %[[#v3_012:]] = OpConstantComposite %[[#v3]] %[[#uint_0]] %[[#uint_1]] %[[#uint_2]]
; CHECK-DAG: %[[#v4_pp:]] = OpTypePointer Private %[[#v4]]
; CHECK-DAG: %[[#v4_fp:]] = OpTypePointer Function %[[#v4]]
@@ -108,3 +112,109 @@ define internal spir_func i32 @bazBounds(ptr %a) {
ret i32 %2
; CHECK: OpReturnValue %[[#val]]
}
+
+define internal spir_func void @foos(ptr addrspace(10) %a) {
+
+ %1 = getelementptr inbounds <4 x i32>, ptr addrspace(10) %a, i64 0
+; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#v4_pp]] %[[#]]
+
+ store <3 x i32> <i32 0, i32 1, i32 2>, ptr addrspace(10) %1, align 16
+; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 16
+; CHECK: %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 0
+; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
+; CHECK: %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 1
+; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
+; CHECK: %[[#C:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 2
+; CHECK: %[[#out3:]] = OpCompositeInsert %[[#v4]] %[[#C]] %[[#out2]] 2
+; CHECK: OpStore %[[#ptr]] %[[#out3]] Aligned 16
+
+ ret void
+}
+
+define internal spir_func void @foosDefault(ptr %a) {
+
+ %1 = getelementptr inbounds <4 x i32>, ptr %a, i64 0
+; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#v4_fp]] %[[#]]
+
+ store <3 x i32> <i32 0, i32 1, i32 2>, ptr %1, align 16
+; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 16
+; CHECK: %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 0
+; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
+; CHECK: %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 1
+; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
+; CHECK: %[[#C:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 2
+; CHECK: %[[#out3:]] = OpCompositeInsert %[[#v4]] %[[#C]] %[[#out2]] 2
+; CHECK: OpStore %[[#ptr]] %[[#out3]] Aligned 16
+
+ ret void
+}
+
+define internal spir_func void @foosBounds(ptr %a) {
+
+ %1 = getelementptr <4 x i32>, ptr %a, i64 0
+; CHECK: %[[#ptr:]] = OpAccessChain %[[#v4_fp]] %[[#]]
+
+ store <3 x i32> <i32 0, i32 1, i32 2>, ptr %1, align 64
+; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 64
+; CHECK: %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 0
+; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
+; CHECK: %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 1
+; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
+; CHECK: %[[#C:]] = OpCompositeExtract %[[#uint]] %[[#v3_012]] 2
+; CHECK: %[[#out3:]] = OpCompositeInsert %[[#v4]] %[[#C]] %[[#out2]] 2
+; CHECK: OpStore %[[#ptr]] %[[#out3]] Aligned 64
+
+ ret void
+}
+
+define internal spir_func void @bars(ptr addrspace(10) %a) {
+
+ %1 = getelementptr <4 x i32>, ptr addrspace(10) %a, i64 0
+; CHECK: %[[#ptr:]] = OpAccessChain %[[#v4_pp]] %[[#]]
+
+ store <2 x i32> <i32 0, i32 1>, ptr addrspace(10) %1, align 16
+; CHECK: %[[#out0:]] = OpLoad %[[#v4]] %[[#ptr]] Aligned 16
+; CHECK: %[[#A:]] = OpCompositeExtract %[[#uint]] %[[#v2_01]] 0
+; CHECK: %[[#out1:]] = OpCompositeInsert %[[#v4]] %[[#A]] %[[#out0]] 0
+; CHECK: %[[#B:]] = OpCompositeExtract %[[#uint]] %[[#v2_01]] 1
+; CHECK: %[[#out2:]] = OpCompositeInsert %[[#v4]] %[[#B]] %[[#out1]] 1
+; CHECK: OpStore %[[#ptr]] %[[#out2]] Aligned 1
+
+ ret void
+}
+
+define internal spir_func void @bazs(ptr addrspace(10) %a) {
+
+ %1 = getelementptr <4 x i32>, ptr addrspace(10) %a, i64 0
+; CHECK: %[[#ptr:]] = OpAccessChain %[[#v4_pp]] %[[#]]
+
+ store i32 0, ptr addrspace(10) %1, align 32
+; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_pp]] %[[#ptr]] %[[#uint_0]]
+; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 32
+
+ ret void
+}
+
+define internal spir_func void @bazsDefault(ptr %a) {
+
+ %1 = getelementptr inbounds <4 x i32>, ptr %a, i64 0
+; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#v4_fp]] %[[#]]
+
+ store i32 0, ptr %1, align 16
+; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_fp]] %[[#ptr]] %[[#uint_0]]
+; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 16
+
+ ret void
+}
+
+define internal spir_func void @bazsBounds(ptr %a) {
+
+ %1 = getelementptr <4 x i32>, ptr %a, i64 0
+; CHECK: %[[#ptr:]] = OpAccessChain %[[#v4_fp]] %[[#]]
+
+ store i32 0, ptr %1, align 16
+; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_fp]] %[[#ptr]] %[[#uint_0]]
+; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 16
+
+ ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/pointers/store-struct.ll b/llvm/test/CodeGen/SPIRV/pointers/store-struct.ll
new file mode 100644
index 0000000000000..7d2c1093f0a71
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/store-struct.ll
@@ -0,0 +1,66 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-vulkan-compute %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#float:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#float_fp:]] = OpTypePointer Function %[[#float]]
+; CHECK-DAG: %[[#float_pp:]] = OpTypePointer Private %[[#float]]
+; CHECK-DAG: %[[#uint_fp:]] = OpTypePointer Function %[[#uint]]
+; CHECK-DAG: %[[#uint_0:]] = OpConstant %[[#uint]] 0
+; CHECK-DAG: %[[#float_0:]] = OpConstant %[[#float]] 0
+; CHECK-DAG: %[[#sf:]] = OpTypeStruct %[[#float]]
+; CHECK-DAG: %[[#su:]] = OpTypeStruct %[[#uint]]
+; CHECK-DAG: %[[#sfuf:]] = OpTypeStruct %[[#float]] %[[#uint]] %[[#float]]
+; CHECK-DAG: %[[#sf_fp:]] = OpTypePointer Function %[[#sf]]
+; CHECK-DAG: %[[#su_fp:]] = OpTypePointer Function %[[#su]]
+; CHECK-DAG: %[[#sfuf_fp:]] = OpTypePointer Function %[[#sfuf]]
+; CHECK-DAG: %[[#sfuf_pp:]] = OpTypePointer Private %[[#sfuf]]
+
+%struct.SF = type { float }
+%struct.SU = type { i32 }
+%struct.SFUF = type { float, i32, float }
+
+@gsfuf = external addrspace(10) global %struct.SFUF
+; CHECK: %[[#gsfuf:]] = OpVariable %[[#sfuf_pp]] Private
+
+define internal spir_func void @foo() {
+ %1 = alloca %struct.SF, align 4
+; CHECK: %[[#var:]] = OpVariable %[[#sf_fp]] Function
+
+ store float 0.0, ptr %1, align 4
+; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#float_fp]] %[[#var]] %[[#uint_0]]
+; CHECK: OpStore %[[#tmp]] %[[#float_0]] Aligned 4
+
+ ret void
+}
+
+define internal spir_func void @bar() {
+ %1 = alloca %struct.SU, align 4
+; CHECK: %[[#var:]] = OpVariable %[[#su_fp]] Function
+
+ store i32 0, ptr %1, align 4
+; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#uint_fp]] %[[#var]] %[[#uint_0]]
+; CHECK: OpStore %[[#tmp]] %[[#uint_0]] Aligned 4
+
+ ret void
+}
+
+define internal spir_func void @baz() {
+ %1 = alloca %struct.SFUF, align 4
+; CHECK: %[[#var:]] = OpVariable %[[#sfuf_fp]] Function
+
+ store float 0.0, ptr %1, align 4
+; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#float_fp]] %[[#var]] %[[#uint_0]]
+; CHECK: OpStore %[[#tmp]] %[[#float_0]] Aligned 4
+
+ ret void
+}
+
+define internal spir_func void @biz() {
+ store float 0.0, ptr addrspace(10) @gsfuf, align 4
+; CHECK: %[[#tmp:]] = OpInBoundsAccessChain %[[#float_pp]] %[[#gsfuf]] %[[#uint_0]]
+; CHECK: OpStore %[[#tmp]] %[[#float_0]] Aligned 4
+
+ ret void
+}
+
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we please use OpVectorShuffle? Also, I noticed that the existing OpVectorShuffle tests look wrong. In @fooBounds
You have:
; CHECK: %[[#val:]] = OpVectorShuffle %[[#v3]] %[[#load]] %[[#load]] 0 0 0
I believe this should be:
; CHECK: %[[#val:]] = OpVectorShuffle %[[#v3]] %[[#load]] %[[#load]] 0 1 2
You want the first three elements of the first input vector to be the first three elements of the output vector.
Good catch, I'll send another PR to fix the load case. |
Thanks, fixed the recursive part, PTAL |
Build failure seems unrelated. We got 2 broken SPIR-V tests: smoothstep & SV_GroupIndex; those are caused by a recently added validation in Vulkan 1.3 which disallows the Linkage capability. |
This commit adds handling for a spv.ptrcast result being used in a store instruction, modifying the store to operate on the source type.
rebased on main, tests are passing locally (2 previously broken tests have been marked as XFAIL since) |
This commit adds handling for a spv.ptrcast result being used in a store instruction, modifying the store to operate on the source type.