Skip to content

[HLSL][SPIR-V] Add SV_DispatchThreadID semantic support #82536

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions clang/lib/CodeGen/CGHLSLRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "clang/AST/Decl.h"
#include "clang/Basic/TargetOptions.h"
#include "llvm/IR/IntrinsicsDirectX.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/FormatVariadic.h"
Expand Down Expand Up @@ -342,8 +343,19 @@ llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
return B.CreateCall(FunctionCallee(DxGroupIndex));
}
if (D.hasAttr<HLSLSV_DispatchThreadIDAttr>()) {
llvm::Function *DxThreadID = CGM.getIntrinsic(Intrinsic::dx_thread_id);
return buildVectorInput(B, DxThreadID, Ty);
llvm::Function *ThreadIDIntrinsic;
switch (CGM.getTarget().getTriple().getArch()) {
case llvm::Triple::dxil:
ThreadIDIntrinsic = CGM.getIntrinsic(Intrinsic::dx_thread_id);
break;
case llvm::Triple::spirv:
ThreadIDIntrinsic = CGM.getIntrinsic(Intrinsic::spv_thread_id);
break;
default:
llvm_unreachable("Input semantic not supported by target");
break;
}
return buildVectorInput(B, ThreadIDIntrinsic, Ty);
}
assert(false && "Unhandled parameter attribute");
return nullptr;
Expand Down
35 changes: 16 additions & 19 deletions clang/test/CodeGenHLSL/semantics/DispatchThreadID.hlsl
Original file line number Diff line number Diff line change
@@ -1,28 +1,25 @@
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
// RUN: %clang_cc1 -triple spirv-linux-vulkan-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV

// Make sure SV_DispatchThreadID translated into dx.thread.id.

const RWBuffer<float> In;
RWBuffer<float> Out;

// CHECK: define void @foo()
// CHECK: %[[ID:[0-9a-zA-Z]+]] = call i32 @llvm.dx.thread.id(i32 0)
// CHECK: call void @"?foo@@YAXH@Z"(i32 %[[ID]])
// CHECK: define void @foo()
// CHECK-DXIL: %[[#ID:]] = call i32 @llvm.dx.thread.id(i32 0)
// CHECK-SPIRV: %[[#ID:]] = call i32 @llvm.spv.thread.id(i32 0)
// CHECK: call void @{{.*}}foo{{.*}}(i32 %[[#ID]])
[shader("compute")]
[numthreads(8,8,1)]
void foo(uint Idx : SV_DispatchThreadID) {
Out[Idx] = In[Idx];
}
void foo(uint Idx : SV_DispatchThreadID) {}

// CHECK: define void @bar()
// CHECK: %[[ID_X:[0-9a-zA-Z]+]] = call i32 @llvm.dx.thread.id(i32 0)
// CHECK: %[[ID_X_:[0-9a-zA-Z]+]] = insertelement <2 x i32> poison, i32 %[[ID_X]], i64 0
// CHECK: %[[ID_Y:[0-9a-zA-Z]+]] = call i32 @llvm.dx.thread.id(i32 1)
// CHECK: %[[ID_XY:[0-9a-zA-Z]+]] = insertelement <2 x i32> %[[ID_X_]], i32 %[[ID_Y]], i64 1
// CHECK: call void @"?bar@@YAXT?$__vector@H$01@__clang@@@Z"(<2 x i32> %[[ID_XY]])
// CHECK: define void @bar()
// CHECK-DXIL: %[[#ID_X:]] = call i32 @llvm.dx.thread.id(i32 0)
// CHECK-SPIRV: %[[#ID_X:]] = call i32 @llvm.spv.thread.id(i32 0)
// CHECK: %[[#ID_X_:]] = insertelement <2 x i32> poison, i32 %[[#ID_X]], i64 0
// CHECK-DXIL: %[[#ID_Y:]] = call i32 @llvm.dx.thread.id(i32 1)
// CHECK-SPIRV: %[[#ID_Y:]] = call i32 @llvm.spv.thread.id(i32 1)
// CHECK: %[[#ID_XY:]] = insertelement <2 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1
// CHECK-DXIL: call void @{{.*}}bar{{.*}}(<2 x i32> %[[#ID_XY]])
[shader("compute")]
[numthreads(8,8,1)]
void bar(uint2 Idx : SV_DispatchThreadID) {
Out[Idx.y] = In[Idx.x];
}
void bar(uint2 Idx : SV_DispatchThreadID) {}

1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsSPIRV.td
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ let TargetPrefix = "spv" in {
def int_spv_expect : Intrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>]>;

// The following intrinsic(s) are mirrored from IntrinsicsDirectX.td for HLSL support.
def int_spv_thread_id : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrWillReturn]>;
def int_spv_create_handle : ClangBuiltin<"__builtin_hlsl_create_handle">,
Intrinsic<[ llvm_ptr_ty ], [llvm_i8_ty], [IntrWillReturn]>;
}
4 changes: 3 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,9 @@ Register SPIRVGlobalRegistry::buildGlobalVariable(

// Output decorations for the GV.
// TODO: maybe move to GenerateDecorations pass.
if (IsConst)
const SPIRVSubtarget &ST =
cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
if (IsConst && ST.isOpenCLEnv())
buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {});

if (GVar && GVar->getAlign().valueOrOne().value() != 1) {
Expand Down
68 changes: 68 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectLog10(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

bool selectUnmergeValues(MachineInstr &I) const;

Register buildI32Constant(uint32_t Val, MachineInstr &I,
Expand Down Expand Up @@ -301,6 +304,7 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
case TargetOpcode::G_FREEZE:
return selectFreeze(ResVReg, ResType, I);

case TargetOpcode::G_INTRINSIC:
case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
case TargetOpcode::G_INTRINSIC_CONVERGENT_W_SIDE_EFFECTS:
return selectIntrinsic(ResVReg, ResType, I);
Expand Down Expand Up @@ -1614,6 +1618,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
.addUse(I.getOperand(2).getReg())
.addUse(I.getOperand(3).getReg());
break;
case Intrinsic::spv_thread_id:
return selectSpvThreadId(ResVReg, ResType, I);
default:
llvm_unreachable("Intrinsic selection not implemented");
}
Expand Down Expand Up @@ -1864,6 +1870,68 @@ bool SPIRVInstructionSelector::selectLog10(Register ResVReg,
return Result;
}

bool SPIRVInstructionSelector::selectSpvThreadId(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
// DX intrinsic: @llvm.dx.thread.id(i32)
// ID Name Description
// 93 ThreadId reads the thread ID

MachineIRBuilder MIRBuilder(I);
const SPIRVType *U32Type = GR.getOrCreateSPIRVIntegerType(32, MIRBuilder);
const SPIRVType *Vec3Ty =
GR.getOrCreateSPIRVVectorType(U32Type, 3, MIRBuilder);
const SPIRVType *PtrType = GR.getOrCreateSPIRVPointerType(
Vec3Ty, MIRBuilder, SPIRV::StorageClass::Input);

// Create new register for GlobalInvocationID builtin variable.
Register NewRegister =
MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
MIRBuilder.getMRI()->setType(NewRegister, LLT::pointer(0, 32));
GR.assignSPIRVTypeToVReg(PtrType, NewRegister, MIRBuilder.getMF());

// Build GlobalInvocationID global variable with the necessary decorations.
Register Variable = GR.buildGlobalVariable(
NewRegister, PtrType,
getLinkStringForBuiltIn(SPIRV::BuiltIn::GlobalInvocationId), nullptr,
SPIRV::StorageClass::Input, nullptr, true, true,
SPIRV::LinkageType::Import, MIRBuilder, false);

// Create new register for loading value.
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
Register LoadedRegister = MRI->createVirtualRegister(&SPIRV::IDRegClass);
MIRBuilder.getMRI()->setType(LoadedRegister, LLT::pointer(0, 32));
GR.assignSPIRVTypeToVReg(Vec3Ty, LoadedRegister, MIRBuilder.getMF());

// Load v3uint value from the global variable.
BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpLoad))
.addDef(LoadedRegister)
.addUse(GR.getSPIRVTypeID(Vec3Ty))
.addUse(Variable);

// Get Thread ID index. Expecting operand is a constant immediate value,
// wrapped in a type assignment.
assert(I.getOperand(2).isReg());
Register ThreadIdReg = I.getOperand(2).getReg();
SPIRVType *ConstTy = this->MRI->getVRegDef(ThreadIdReg);
assert(ConstTy && ConstTy->getOpcode() == SPIRV::ASSIGN_TYPE &&
ConstTy->getOperand(1).isReg());
Register ConstReg = ConstTy->getOperand(1).getReg();
const MachineInstr *Const = this->MRI->getVRegDef(ConstReg);
assert(Const && Const->getOpcode() == TargetOpcode::G_CONSTANT);
const llvm::APInt &Val = Const->getOperand(1).getCImm()->getValue();
const uint32_t ThreadId = Val.getZExtValue();

// Extract the thread ID from the loaded vector value.
MachineBasicBlock &BB = *I.getParent();
auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(LoadedRegister)
.addImm(ThreadId);
return MIB.constrainAllUses(TII, TRI, RBI);
}

namespace llvm {
InstructionSelector *
createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
Expand Down
76 changes: 76 additions & 0 deletions llvm/test/CodeGen/SPIRV/hlsl-intrinsics/SV_DispatchThreadID.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
; RUN: llc -O0 -mtriple=spirv-vulkan-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %}

; This file generated from the following command:
; clang -cc1 -triple spirv-vulkan-library -x hlsl -emit-llvm -disable-llvm-passes -finclude-default-header - -o - <<EOF
; [shader("compute")]
; [numthreads(1,1,1)]
; void main(uint3 ID : SV_DispatchThreadID) {}
; EOF

; CHECK-DAG: %[[#int:]] = OpTypeInt 32 0
; CHECK-DAG: %[[#v3int:]] = OpTypeVector %[[#int]] 3
; CHECK-DAG: %[[#ptr_Input_v3int:]] = OpTypePointer Input %[[#v3int]]
; CHECK-DAG: %[[#tempvar:]] = OpUndef %[[#v3int]]
; CHECK-DAG: %[[#GlobalInvocationId:]] = OpVariable %[[#ptr_Input_v3int]] Input

; CHECK-DAG: OpEntryPoint GLCompute {{.*}} %[[#GlobalInvocationId]]
; CHECK-DAG: OpName %[[#GlobalInvocationId]] "__spirv_BuiltInGlobalInvocationId"
; CHECK-DAG: OpDecorate %[[#GlobalInvocationId]] LinkageAttributes "__spirv_BuiltInGlobalInvocationId" Import
; CHECK-DAG: OpDecorate %[[#GlobalInvocationId]] BuiltIn GlobalInvocationId

; ModuleID = '-'
source_filename = "-"
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
target triple = "spirv-unknown-vulkan-library"

; Function Attrs: noinline norecurse nounwind optnone
define internal spir_func void @main(<3 x i32> noundef %ID) #0 {
entry:
%ID.addr = alloca <3 x i32>, align 16
store <3 x i32> %ID, ptr %ID.addr, align 16
ret void
}

; Function Attrs: norecurse
define void @main.1() #1 {
entry:

; CHECK: %[[#load:]] = OpLoad %[[#v3int]] %[[#GlobalInvocationId]]
; CHECK: %[[#load0:]] = OpCompositeExtract %[[#int]] %[[#load]] 0
%0 = call i32 @llvm.spv.thread.id(i32 0)

; CHECK: %[[#tempvar:]] = OpCompositeInsert %[[#v3int]] %[[#load0]] %[[#tempvar]] 0
%1 = insertelement <3 x i32> poison, i32 %0, i64 0

; CHECK: %[[#load:]] = OpLoad %[[#v3int]] %[[#GlobalInvocationId]]
; CHECK: %[[#load1:]] = OpCompositeExtract %[[#int]] %[[#load]] 1
%2 = call i32 @llvm.spv.thread.id(i32 1)

; CHECK: %[[#tempvar:]] = OpCompositeInsert %[[#v3int]] %[[#load1]] %[[#tempvar]] 1
%3 = insertelement <3 x i32> %1, i32 %2, i64 1

; CHECK: %[[#load:]] = OpLoad %[[#v3int]] %[[#GlobalInvocationId]]
; CHECK: %[[#load2:]] = OpCompositeExtract %[[#int]] %[[#load]] 2
%4 = call i32 @llvm.spv.thread.id(i32 2)

; CHECK: %[[#tempvar:]] = OpCompositeInsert %[[#v3int]] %[[#load2]] %[[#tempvar]] 2
%5 = insertelement <3 x i32> %3, i32 %4, i64 2

call void @main(<3 x i32> %5)
ret void
}

; Function Attrs: nounwind willreturn memory(none)
declare i32 @llvm.spv.thread.id(i32) #2

attributes #0 = { noinline norecurse nounwind optnone "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
attributes #1 = { norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
attributes #2 = { nounwind willreturn memory(none) }

!llvm.module.flags = !{!0, !1}
!llvm.ident = !{!2}

!0 = !{i32 1, !"wchar_size", i32 4}
!1 = !{i32 4, !"dx.disable_optimizations", i32 1}
!2 = !{!"clang version 19.0.0git (git@github.com:llvm/llvm-project.git 91600507765679e92434ec7c5edb883bf01f847f)"}