Skip to content

Commit c7ec63a

Browse files
authored
Fix SPIR-V global to function replacement for differing load types (#2160)
In some cases, we will see IR like the following:

    @__spirv_BuiltInGlobalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32
    ...
    %0 = load <6 x i32>, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32
    %1 = extractelement <6 x i32> %0, i64 0

Note that the global type and the load type are different. Change the handling of vector loads from vector globals to reconstruct the global vector type and then bitcast to the load type. Thanks to @jcranmer-intel for helping me find the simplest solution.
1 parent 3d28e52 commit c7ec63a

File tree

2 files changed

+51
-5
lines changed

2 files changed

+51
-5
lines changed

lib/SPIRV/SPIRVUtil.cpp

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1989,19 +1989,20 @@ bool isSPIRVBuiltinVariable(GlobalVariable *GV,
19891989
/// are accumulated in the AccumulatedOffset parameter, which will eventually be
19901990
/// used to figure out which index of a variable is being used.
19911991
static void replaceUsesOfBuiltinVar(Value *V, const APInt &AccumulatedOffset,
1992-
Function *ReplacementFunc) {
1992+
Function *ReplacementFunc,
1993+
GlobalVariable *GV) {
19931994
const DataLayout &DL = ReplacementFunc->getParent()->getDataLayout();
19941995
SmallVector<Instruction *, 4> InstsToRemove;
19951996
for (User *U : V->users()) {
19961997
if (auto *Cast = dyn_cast<CastInst>(U)) {
1997-
replaceUsesOfBuiltinVar(Cast, AccumulatedOffset, ReplacementFunc);
1998+
replaceUsesOfBuiltinVar(Cast, AccumulatedOffset, ReplacementFunc, GV);
19981999
InstsToRemove.push_back(Cast);
19992000
} else if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) {
20002001
APInt NewOffset = AccumulatedOffset.sextOrTrunc(
20012002
DL.getIndexSizeInBits(GEP->getPointerAddressSpace()));
20022003
if (!GEP->accumulateConstantOffset(DL, NewOffset))
20032004
llvm_unreachable("Illegal GEP of a SPIR-V builtin variable");
2004-
replaceUsesOfBuiltinVar(GEP, NewOffset, ReplacementFunc);
2005+
replaceUsesOfBuiltinVar(GEP, NewOffset, ReplacementFunc, GV);
20052006
InstsToRemove.push_back(GEP);
20062007
} else if (auto *Load = dyn_cast<LoadInst>(U)) {
20072008
// Figure out which index the accumulated offset corresponds to. If we
@@ -2024,7 +2025,12 @@ static void replaceUsesOfBuiltinVar(Value *V, const APInt &AccumulatedOffset,
20242025
} else {
20252026
// The function has an index parameter.
20262027
if (auto *VecTy = dyn_cast<FixedVectorType>(Load->getType())) {
2027-
if (!Index.isZero())
2028+
// Reconstruct the original global variable vector because
2029+
// the load type may not match.
2030+
// global <3 x i64>, load <6 x i32>
2031+
VecTy = cast<FixedVectorType>(GV->getValueType());
2032+
if (!Index.isZero() || DL.getTypeSizeInBits(VecTy) !=
2033+
DL.getTypeSizeInBits(Load->getType()))
20282034
llvm_unreachable("Illegal use of a SPIR-V builtin variable");
20292035
Replacement = UndefValue::get(VecTy);
20302036
for (unsigned I = 0; I < VecTy->getNumElements(); I++) {
@@ -2034,6 +2040,19 @@ static void replaceUsesOfBuiltinVar(Value *V, const APInt &AccumulatedOffset,
20342040
Builder.CreateCall(ReplacementFunc, {Builder.getInt32(I)})),
20352041
Builder.getInt32(I));
20362042
}
2043+
// Insert a bitcast from the reconstructed vector to the load vector
2044+
// type in case they are different.
2045+
// Input:
2046+
// %1 = load <6 x i32>, ptr addrspace(1) %0, align 32
2047+
// %2 = extractelement <6 x i32> %1, i32 0
2048+
// %3 = add i32 5, %2
2049+
// Modified:
2050+
// < reconstruct global vector elements 0 and 1 >
2051+
// %2 = insertelement <3 x i64> %0, i64 %1, i32 2
2052+
// %3 = bitcast <3 x i64> %2 to <6 x i32>
2053+
// %4 = extractelement <6 x i32> %3, i32 0
2054+
// %5 = add i32 5, %4
2055+
Replacement = Builder.CreateBitCast(Replacement, Load->getType());
20372056
} else if (Load->getType() == ScalarTy) {
20382057
Replacement = setAttrByCalledFunc(Builder.CreateCall(
20392058
ReplacementFunc, {Builder.getInt32(Index.getZExtValue())}));
@@ -2087,7 +2106,7 @@ bool lowerBuiltinVariableToCall(GlobalVariable *GV,
20872106
Func->setDoesNotAccessMemory();
20882107
}
20892108

2090-
replaceUsesOfBuiltinVar(GV, APInt(64, 0), Func);
2109+
replaceUsesOfBuiltinVar(GV, APInt(64, 0), Func, GV);
20912110
return true;
20922111
}
20932112

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc -o %t.spv -spirv-ext=+SPV_INTEL_vector_compute
3+
; RUN: llvm-spirv -r %t.spv --spirv-target-env=SPV-IR -o %t.out.bc
4+
; RUN: llvm-dis %t.out.bc -o - | FileCheck %s --check-prefix=CHECK-SPV-IR
5+
6+
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
7+
target triple = "spir-unknown-unknown"
8+
9+
@__spirv_BuiltInGlobalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32
10+
11+
; Function Attrs: nounwind readnone
12+
define spir_kernel void @f() {
13+
entry:
14+
%0 = load <6 x i32>, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32
15+
%1 = extractelement <6 x i32> %0, i64 0
16+
%2 = add i32 5, %1
17+
ret void
18+
; CHECK-SPV-IR: %[[#ID0:]] = call spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32 0) #1
19+
; CHECK-SPV-IR: %[[#ID1:]] = insertelement <3 x i64> undef, i64 %[[#ID0]], i32 0
20+
; CHECK-SPV-IR: %[[#ID2:]] = call spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32 1) #1
21+
; CHECK-SPV-IR: %[[#ID3:]] = insertelement <3 x i64> %[[#ID1]], i64 %[[#ID2]], i32 1
22+
; CHECK-SPV-IR: %[[#ID4:]] = call spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32 2) #1
23+
; CHECK-SPV-IR: %[[#ID5:]] = insertelement <3 x i64> %[[#ID3]], i64 %[[#ID4]], i32 2
24+
; CHECK-SPV-IR: %[[#ID6:]] = bitcast <3 x i64> %[[#ID5]] to <6 x i32>
25+
; CHECK-SPV-IR: %[[#ID7:]] = extractelement <6 x i32> %[[#ID6]], i32 0
26+
; CHECK-SPV-IR: = add i32 5, %[[#ID7]]
27+
}

0 commit comments

Comments
 (0)