[DirectX][ShaderFlags] Add analysis for WaveOps flag #118140

Merged
merged 5 commits into llvm:main from inbelic/wave-flags
Feb 14, 2025

Conversation

inbelic
Contributor

@inbelic inbelic commented Nov 29, 2024

  • Check whether each call instruction invokes a WaveOp intrinsic and set the WaveOps flag if any of them does. Done in DXILShaderFlags.cpp.

Resolves #114565
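
A minimal, self-contained sketch of the analysis described above, assuming simplified stand-ins for the LLVM IR types. `Instr`, `IntrinsicID`, `isWaveIntrinsic`, `ShaderFlags`, and `computeFlags` here are hypothetical names for illustration and are not the LLVM API. It shows only the shape of the check: visit every instruction and OR the `WaveOps` bit when a call targets a wave intrinsic.

```cpp
#include <vector>

// Hypothetical stand-ins for LLVM IR concepts; not the real LLVM API.
enum class IntrinsicID { NotIntrinsic, Sqrt, WaveIsFirstLane, WaveReadLaneAt };

struct Instr {
  bool IsCall;        // true for call instructions
  IntrinsicID Callee; // the called intrinsic, if any
};

struct ShaderFlags {
  bool WaveOps = false;
};

// Returns true if the intrinsic is one of the (illustrative) wave operations.
static bool isWaveIntrinsic(IntrinsicID ID) {
  switch (ID) {
  case IntrinsicID::WaveIsFirstLane:
  case IntrinsicID::WaveReadLaneAt:
    return true;
  default:
    return false;
  }
}

// Visit each instruction. Only calls can set the flag, and once it is set it
// stays set because the result is OR-ed in.
static ShaderFlags computeFlags(const std::vector<Instr> &Body) {
  ShaderFlags Flags;
  for (const Instr &I : Body)
    if (I.IsCall)
      Flags.WaveOps |= isWaveIntrinsic(I.Callee);
  return Flags;
}
```

Because the flag is accumulated with `|=`, a single wave call anywhere in the body is enough to set it, and later non-wave calls cannot clear it.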

@inbelic inbelic marked this pull request as ready for review November 29, 2024 22:42
@llvmbot
Member

llvmbot commented Nov 29, 2024

@llvm/pr-subscribers-tablegen

@llvm/pr-subscribers-backend-directx

Author: Finn Plummer (inbelic)

Changes
  • updates DXILShaderFlags to check if any functions call a wave intrinsic
  • adds a testcase to iswave.ll

Resolves #114565


Patch is 119.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/118140.diff

56 Files Affected:

  • (modified) llvm/lib/Target/DirectX/DXIL.td (+36-14)
  • (modified) llvm/lib/Target/DirectX/DXILConstants.h (+10)
  • (modified) llvm/lib/Target/DirectX/DXILOpBuilder.cpp (+27-2)
  • (modified) llvm/lib/Target/DirectX/DXILShaderFlags.cpp (+21)
  • (modified) llvm/test/CodeGen/DirectX/BufferLoad.ll (+13-11)
  • (modified) llvm/test/CodeGen/DirectX/BufferStore.ll (+4-4)
  • (added) llvm/test/CodeGen/DirectX/ShaderFlags/iswave.ll (+37)
  • (modified) llvm/test/CodeGen/DirectX/WaveActiveAnyTrue.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/WaveActiveCountBits.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/WaveGetLaneIndex.ll (+3-1)
  • (modified) llvm/test/CodeGen/DirectX/WaveReadLaneAt-vec.ll (+9-9)
  • (modified) llvm/test/CodeGen/DirectX/WaveReadLaneAt.ll (+9-7)
  • (modified) llvm/test/CodeGen/DirectX/abs.ll (+5-3)
  • (modified) llvm/test/CodeGen/DirectX/acos.ll (+8-6)
  • (modified) llvm/test/CodeGen/DirectX/asin.ll (+8-6)
  • (modified) llvm/test/CodeGen/DirectX/atan.ll (+8-6)
  • (modified) llvm/test/CodeGen/DirectX/ceil.ll (+8-6)
  • (modified) llvm/test/CodeGen/DirectX/comput_ids.ll (+6-4)
  • (modified) llvm/test/CodeGen/DirectX/cos.ll (+8-6)
  • (modified) llvm/test/CodeGen/DirectX/cosh.ll (+8-6)
  • (modified) llvm/test/CodeGen/DirectX/countbits.ll (+12-10)
  • (modified) llvm/test/CodeGen/DirectX/dot4add_i8packed.ll (+3-1)
  • (modified) llvm/test/CodeGen/DirectX/dot4add_u8packed.ll (+3-1)
  • (modified) llvm/test/CodeGen/DirectX/exp.ll (+4-2)
  • (modified) llvm/test/CodeGen/DirectX/fdot.ll (+8-6)
  • (modified) llvm/test/CodeGen/DirectX/firstbithigh.ll (+16-14)
  • (modified) llvm/test/CodeGen/DirectX/floor.ll (+8-6)
  • (modified) llvm/test/CodeGen/DirectX/fmad.ll (+4-3)
  • (modified) llvm/test/CodeGen/DirectX/fmax.ll (+5-3)
  • (modified) llvm/test/CodeGen/DirectX/fmin.ll (+5-3)
  • (modified) llvm/test/CodeGen/DirectX/frac.ll (+8-6)
  • (modified) llvm/test/CodeGen/DirectX/idot.ll (+12-10)
  • (modified) llvm/test/CodeGen/DirectX/imad.ll (+5-3)
  • (modified) llvm/test/CodeGen/DirectX/isinf.ll (+3-2)
  • (modified) llvm/test/CodeGen/DirectX/log.ll (+4-2)
  • (modified) llvm/test/CodeGen/DirectX/log10.ll (+4-2)
  • (modified) llvm/test/CodeGen/DirectX/log2.ll (+4-2)
  • (modified) llvm/test/CodeGen/DirectX/reversebits.ll (+9-7)
  • (modified) llvm/test/CodeGen/DirectX/round.ll (+7-6)
  • (modified) llvm/test/CodeGen/DirectX/rsqrt.ll (+7-6)
  • (modified) llvm/test/CodeGen/DirectX/saturate.ll (+5-3)
  • (modified) llvm/test/CodeGen/DirectX/sin.ll (+8-6)
  • (modified) llvm/test/CodeGen/DirectX/sinh.ll (+8-6)
  • (modified) llvm/test/CodeGen/DirectX/smax.ll (+5-3)
  • (modified) llvm/test/CodeGen/DirectX/smin.ll (+5-3)
  • (modified) llvm/test/CodeGen/DirectX/splitdouble.ll (+8-6)
  • (modified) llvm/test/CodeGen/DirectX/sqrt.ll (+8-6)
  • (modified) llvm/test/CodeGen/DirectX/tan.ll (+8-6)
  • (modified) llvm/test/CodeGen/DirectX/tanh.ll (+8-6)
  • (modified) llvm/test/CodeGen/DirectX/trunc.ll (+8-6)
  • (modified) llvm/test/CodeGen/DirectX/umad.ll (+5-3)
  • (modified) llvm/test/CodeGen/DirectX/umax.ll (+5-3)
  • (modified) llvm/test/CodeGen/DirectX/umin.ll (+5-3)
  • (modified) llvm/test/CodeGen/DirectX/updateCounter.ll (+3-3)
  • (modified) llvm/test/CodeGen/DirectX/wave_is_first_lane.ll (+2)
  • (modified) llvm/utils/TableGen/DXILEmitter.cpp (+98-24)
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 7cc08b2fe7cc4b..4b285d0e8043e7 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -266,18 +266,30 @@ def miss : DXILShaderStage;
 def all_stages : DXILShaderStage;
 // Denote support for DXIL Op to have been removed
 def removed : DXILShaderStage;
+
 // DXIL Op attributes
 
+// A function attribute denotes that there is a corresponding LLVM function
+// attribute that will be set when building the DXIL op. The mapping for
+// non-trivial cases is defined by setDXILAttribute in DXILOpBuilder.cpp
 class DXILAttribute;
 
-def ReadOnly : DXILAttribute;
 def ReadNone : DXILAttribute;
-def IsDerivative : DXILAttribute;
-def IsGradient : DXILAttribute;
-def IsFeedback : DXILAttribute;
-def IsWave : DXILAttribute;
-def NeedsUniformInputs : DXILAttribute;
-def IsBarrier : DXILAttribute;
+def ReadOnly : DXILAttribute;
+def NoDuplicate : DXILAttribute;
+def NoReturn : DXILAttribute;
+
+// A property simply marks a DXIL op as belonging to a sub-group of
+// DXIL ops, and it is used to query whether a particular op holds this
+// property. This is used for static analysis of DXIL ops.
+class DXILProperty;
+
+def IsBarrier : DXILProperty;
+def IsDerivative : DXILProperty;
+def IsGradient : DXILProperty;
+def IsFeedback : DXILProperty;
+def IsWave : DXILProperty;
+def RequiresUniformInputs : DXILProperty;
 
 class Overloads<Version ver, list<DXILOpParamType> ols> {
   Version dxil_version = ver;
@@ -291,7 +303,7 @@ class Stages<Version ver, list<DXILShaderStage> st> {
 
 class Attributes<Version ver = DXIL1_0, list<DXILAttribute> attrs> {
   Version dxil_version = ver;
-  list<DXILAttribute> op_attrs = attrs;
+  list<DXILAttribute> fn_attrs = attrs;
 }
 
 // Abstraction DXIL Operation
@@ -322,6 +334,9 @@ class DXILOp<int opcode, DXILOpClass opclass> {
 
   // Versioned attributes of operation
   list<Attributes> attributes = [];
+
+  // List of properties. Default to no properties.
+  list<DXILProperty> properties = [];
 }
 
 // Concrete definitions of DXIL Operations
@@ -729,6 +744,7 @@ def CreateHandle : DXILOp<57, createHandle> {
   let arguments = [Int8Ty, Int32Ty, Int32Ty, Int1Ty];
   let result = HandleTy;
   let stages = [Stages<DXIL1_0, [all_stages]>, Stages<DXIL1_6, [removed]>];
+  let attributes = [Attributes<DXIL1_0, [ReadOnly]>];
 }
 
 def BufferLoad : DXILOp<68, bufferLoad> {
@@ -740,6 +756,7 @@ def BufferLoad : DXILOp<68, bufferLoad> {
       [Overloads<DXIL1_0,
                  [ResRetHalfTy, ResRetFloatTy, ResRetInt16Ty, ResRetInt32Ty]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
+  let attributes = [Attributes<DXIL1_0, [ReadOnly]>];
 }
 
 def BufferStore : DXILOp<69, bufferStore> {
@@ -768,6 +785,7 @@ def CheckAccessFullyMapped : DXILOp<71, checkAccessFullyMapped> {
   let result = Int1Ty;
   let overloads = [Overloads<DXIL1_0, [Int32Ty]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
+  let attributes = [Attributes<DXIL1_0, [ReadOnly]>];
 }
 
 def Discard : DXILOp<82, discard> {
@@ -842,8 +860,8 @@ def Dot4AddI8Packed : DXILOp<163, dot4AddPacked> {
   let LLVMIntrinsic = int_dx_dot4add_i8packed;
   let arguments = [Int32Ty, Int32Ty, Int32Ty];
   let result = Int32Ty;
-  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
+  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
 }
 
 def Dot4AddU8Packed : DXILOp<164, dot4AddPacked> {
@@ -852,8 +870,8 @@ def Dot4AddU8Packed : DXILOp<164, dot4AddPacked> {
   let LLVMIntrinsic = int_dx_dot4add_u8packed;
   let arguments = [Int32Ty, Int32Ty, Int32Ty];
   let result = Int32Ty;
-  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
+  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
 }
 
 def AnnotateHandle : DXILOp<216, annotateHandle> {
@@ -861,6 +879,7 @@ def AnnotateHandle : DXILOp<216, annotateHandle> {
   let arguments = [HandleTy, ResPropsTy];
   let result = HandleTy;
   let stages = [Stages<DXIL1_6, [all_stages]>];
+  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
 }
 
 def CreateHandleFromBinding : DXILOp<217, createHandleFromBinding> {
@@ -868,6 +887,7 @@ def CreateHandleFromBinding : DXILOp<217, createHandleFromBinding> {
   let arguments = [ResBindTy, Int32Ty, Int1Ty];
   let result = HandleTy;
   let stages = [Stages<DXIL1_6, [all_stages]>];
+  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
 }
 
 def WaveActiveAnyTrue : DXILOp<113, waveAnyTrue> {
@@ -876,6 +896,7 @@ def WaveActiveAnyTrue : DXILOp<113, waveAnyTrue> {
   let arguments = [Int1Ty];
   let result = Int1Ty;
   let stages = [Stages<DXIL1_0, [all_stages]>];
+  let properties = [IsWave];
 }
 
 def WaveIsFirstLane :  DXILOp<110, waveIsFirstLane> {
@@ -884,7 +905,7 @@ def WaveIsFirstLane :  DXILOp<110, waveIsFirstLane> {
   let arguments = [];
   let result = Int1Ty;
   let stages = [Stages<DXIL1_0, [all_stages]>];
-  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let properties = [IsWave];
 }
 
 def WaveReadLaneAt:  DXILOp<117, waveReadLaneAt> {
@@ -894,7 +915,7 @@ def WaveReadLaneAt:  DXILOp<117, waveReadLaneAt> {
   let result = OverloadTy;
   let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy, DoubleTy, Int1Ty, Int16Ty, Int32Ty, Int64Ty]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
-  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let properties = [IsWave];
 }
 
 def WaveGetLaneIndex : DXILOp<111, waveGetLaneIndex> {
@@ -903,7 +924,8 @@ def WaveGetLaneIndex : DXILOp<111, waveGetLaneIndex> {
   let arguments = [];
   let result = Int32Ty;
   let stages = [Stages<DXIL1_0, [all_stages]>];
-  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let attributes = [Attributes<DXIL1_0, [ReadOnly]>];
+  let properties = [IsWave];
 }
 
 def WaveAllBitCount : DXILOp<135, waveAllOp> {
@@ -912,5 +934,5 @@ def WaveAllBitCount : DXILOp<135, waveAllOp> {
   let arguments = [Int1Ty];
   let result = Int32Ty;
   let stages = [Stages<DXIL1_0, [all_stages]>];
-  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let properties = [IsWave];
 }
diff --git a/llvm/lib/Target/DirectX/DXILConstants.h b/llvm/lib/Target/DirectX/DXILConstants.h
index 022cd57795a063..229401d6b271aa 100644
--- a/llvm/lib/Target/DirectX/DXILConstants.h
+++ b/llvm/lib/Target/DirectX/DXILConstants.h
@@ -30,6 +30,16 @@ enum class OpParamType : unsigned {
 #include "DXILOperation.inc"
 };
 
+enum class Attribute : unsigned {
+#define DXIL_ATTRIBUTE(Name) Name,
+#include "DXILOperation.inc"
+};
+
+enum class Property : unsigned {
+#define DXIL_PROPERTY(Name) Name,
+#include "DXILOperation.inc"
+};
+
 } // namespace dxil
 } // namespace llvm
 
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 5d5bb3eacace25..cae3f2ea43bf8e 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -54,7 +54,7 @@ struct OpStage {
 
 struct OpAttribute {
   Version DXILVersion;
-  uint32_t ValidAttrs;
+  llvm::SmallVector<dxil::Attribute> ValidAttrs;
 };
 
 static const char *getOverloadTypeName(OverloadKind Kind) {
@@ -159,6 +159,7 @@ struct OpCodeProperty {
   llvm::SmallVector<OpOverload> Overloads;
   llvm::SmallVector<OpStage> Stages;
   llvm::SmallVector<OpAttribute> Attributes;
+  llvm::SmallVector<dxil::Property> Properties;
   int OverloadParamIndex; // parameter index which control the overload.
                           // When < 0, should be only 1 overload type.
 };
@@ -367,6 +368,20 @@ static std::optional<size_t> getPropIndex(ArrayRef<T> PropList,
   return std::nullopt;
 }
 
+static void setDXILAttribute(CallInst *CI, dxil::Attribute Attr) {
+  switch (Attr) {
+  case dxil::Attribute::ReadNone:
+    return CI->setDoesNotAccessMemory();
+  case dxil::Attribute::ReadOnly:
+    return CI->setOnlyReadsMemory();
+  case dxil::Attribute::NoReturn:
+    return CI->setDoesNotReturn();
+  case dxil::Attribute::NoDuplicate:
+    return CI->setCannotDuplicate();
+  }
+  llvm_unreachable("Invalid function attribute specified for DXIL operation");
+}
+
 namespace llvm {
 namespace dxil {
 
@@ -461,7 +476,17 @@ Expected<CallInst *> DXILOpBuilder::tryCreateOp(dxil::OpCode OpCode,
   OpArgs.push_back(IRB.getInt32(llvm::to_underlying(OpCode)));
   OpArgs.append(Args.begin(), Args.end());
 
-  return IRB.CreateCall(DXILFn, OpArgs, Name);
+  // Create the function call instruction
+  CallInst *CI = IRB.CreateCall(DXILFn, OpArgs, Name);
+
+  // We then need to attach available function attributes
+  for (auto OpAttr : Prop->Attributes)
+    if (VersionTuple(OpAttr.DXILVersion.Major, OpAttr.DXILVersion.Minor) <=
+        DXILVersion)
+      for (auto Attr : OpAttr.ValidAttrs)
+        setDXILAttribute(CI, Attr);
+
+  return CI;
 }
 
 CallInst *DXILOpBuilder::createOp(dxil::OpCode OpCode, ArrayRef<Value *> Args,
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index 9fa137b4c025e1..d720e75ce257b0 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -12,14 +12,32 @@
 //===----------------------------------------------------------------------===//
 
 #include "DXILShaderFlags.h"
+#include "DXILConstants.h"
 #include "DirectX.h"
 #include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicsDirectX.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Support/FormatVariadic.h"
 
 using namespace llvm;
 using namespace llvm::dxil;
 
+// Include hasProperty which is generated by TableGen
+#define DXIL_OP_PROPERTY_HELPER
+#include "DXILOperation.inc"
+#undef DXIL_OP_PROPERTY_HELPER
+
+static bool checkWaveOps(const CallInst *CI) {
+  switch (CI->getIntrinsicID()) {
+#define DXIL_OP_INTRINSIC(OpCode, Intrin)                                      \
+  case Intrin:                                                                 \
+    return hasProperty(OpCode, dxil::Property::IsWave);
+#include "DXILOperation.inc"
+  }
+  return false;
+}
+
 static void updateFlags(ComputedShaderFlags &Flags, const Instruction &I) {
   Type *Ty = I.getType();
   if (Ty->isDoubleTy()) {
@@ -34,6 +52,9 @@ static void updateFlags(ComputedShaderFlags &Flags, const Instruction &I) {
       break;
     }
   }
+
+  if (auto *CI = dyn_cast<CallInst>(&I))
+    Flags.WaveOps |= checkWaveOps(CI);
 }
 
 ComputedShaderFlags ComputedShaderFlags::computeFlags(Module &M) {
diff --git a/llvm/test/CodeGen/DirectX/BufferLoad.ll b/llvm/test/CodeGen/DirectX/BufferLoad.ll
index 24d65fe1648c15..874c81df29b64a 100644
--- a/llvm/test/CodeGen/DirectX/BufferLoad.ll
+++ b/llvm/test/CodeGen/DirectX/BufferLoad.ll
@@ -16,7 +16,7 @@ define void @loadv4f32() {
   ; The temporary casts should all have been cleaned up
   ; CHECK-NOT: %dx.cast_handle
 
-  ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+  ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef) #[[#ATTR:]]
   %data0 = call <4 x float> @llvm.dx.typedBufferLoad(
       target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 0)
 
@@ -33,7 +33,7 @@ define void @loadv4f32() {
   call void @scalar_user(float %data0_0)
   call void @scalar_user(float %data0_2)
 
-  ; CHECK: [[DATA4:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 4, i32 undef)
+  ; CHECK: [[DATA4:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 4, i32 undef) #[[#ATTR]]
   %data4 = call <4 x float> @llvm.dx.typedBufferLoad(
       target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 4)
 
@@ -47,7 +47,7 @@ define void @loadv4f32() {
   ; CHECK: insertelement <4 x float>
   call void @vector_user(<4 x float> %data4)
 
-  ; CHECK: [[DATA12:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 12, i32 undef)
+  ; CHECK: [[DATA12:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 12, i32 undef) #[[#ATTR]]
   %data12 = call <4 x float> @llvm.dx.typedBufferLoad(
       target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 12)
 
@@ -69,7 +69,7 @@ define void @index_dynamic(i32 %bufindex, i32 %elemindex) {
       @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_0_0_0(
           i32 0, i32 0, i32 1, i32 0, i1 false)
 
-  ; CHECK: [[LOAD:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 %bufindex, i32 undef)
+  ; CHECK: [[LOAD:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 %bufindex, i32 undef) #[[#ATTR]]
   %load = call <4 x float> @llvm.dx.typedBufferLoad(
       target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 %bufindex)
 
@@ -104,7 +104,7 @@ define void @loadf32() {
       @llvm.dx.handle.fromBinding.tdx.TypedBuffer_f32_0_0_0(
           i32 0, i32 0, i32 1, i32 0, i1 false)
 
-  ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+  ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef) #[[#ATTR]]
   %data0 = call float @llvm.dx.typedBufferLoad(
       target("dx.TypedBuffer", float, 0, 0, 0) %buffer, i32 0)
 
@@ -122,7 +122,7 @@ define void @loadv2f32() {
       @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v2f32_0_0_0(
           i32 0, i32 0, i32 1, i32 0, i1 false)
 
-  ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+  ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef) #[[#ATTR]]
   %data0 = call <2 x float> @llvm.dx.typedBufferLoad(
       target("dx.TypedBuffer", <2 x float>, 0, 0, 0) %buffer, i32 0)
 
@@ -136,12 +136,12 @@ define void @loadv4f32_checkbit() {
       @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_0_0_0(
           i32 0, i32 0, i32 1, i32 0, i1 false)
 
-  ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+  ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef) #[[#ATTR]]
   %data0 = call {<4 x float>, i1} @llvm.dx.typedBufferLoad.checkbit.f32(
       target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 0)
 
   ; CHECK: [[STATUS:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA0]], 4
-  ; CHECK: [[MAPPED:%.*]] = call i1 @dx.op.checkAccessFullyMapped.i32(i32 71, i32 [[STATUS]]
+  ; CHECK: [[MAPPED:%.*]] = call i1 @dx.op.checkAccessFullyMapped.i32(i32 71, i32 [[STATUS]]) #[[#ATTR]]
   %check = extractvalue {<4 x float>, i1} %data0, 1
 
   ; CHECK: call void @check_user(i1 [[MAPPED]])
@@ -157,7 +157,7 @@ define void @loadv4i32() {
       @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i32_0_0_0(
           i32 0, i32 0, i32 1, i32 0, i1 false)
 
-  ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+  ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef) #[[#ATTR]]
   %data0 = call <4 x i32> @llvm.dx.typedBufferLoad(
       target("dx.TypedBuffer", <4 x i32>, 0, 0, 0) %buffer, i32 0)
 
@@ -171,7 +171,7 @@ define void @loadv4f16() {
       @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f16_0_0_0(
           i32 0, i32 0, i32 1, i32 0, i1 false)
 
-  ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f16 @dx.op.bufferLoad.f16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+  ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f16 @dx.op.bufferLoad.f16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef) #[[#ATTR]]
   %data0 = call <4 x half> @llvm.dx.typedBufferLoad(
       target("dx.TypedBuffer", <4 x half>, 0, 0, 0) %buffer, i32 0)
 
@@ -185,9 +185,11 @@ define void @loadv4i16() {
       @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i16_0_0_0(
           i32 0, i32 0, i32 1, i32 0, i1 false)
 
-  ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i16 @dx.op.bufferLoad.i16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
+  ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i16 @dx.op.bufferLoad.i16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef) #[[#ATTR]]
   %data0 = call <4 x i16> @llvm.dx.typedBufferLoad(
       target("dx.TypedBuffer", <4 x i16>, 0, 0, 0) %buffer, i32 0)
 
   ret void
 }
+
+; CHECK: attributes #[[#ATTR]] = {{{.*}} memory(read) {{.*}}}
diff --git a/llvm/test/CodeGen/DirectX/BufferStore.ll b/llvm/test/CodeGen/DirectX/BufferStore.ll
index 9ea7735be59c81..68849bc71edd22 100644
--- a/llvm/test/CodeGen/DirectX/BufferStore.ll
+++ b/llvm/test/CodeGen/DirectX/BufferStore.ll
@@ -17,7 +17,7 @@ define void @storefloat(<4 x float> %data, i32 %index) {
   ; CHECK: [[DATA0_1:%.*]] = extractelement <4 x float> %data, i32 1
   ; CHECK: [[DATA0_2:%.*]] = extractelement <4 x float> %data, i32 2
   ; CHECK: [[DATA0_3:%.*]] = extractelement <4 x float> %data, i32 3
-  ; CHECK: call void @dx.op.bufferStore.f32(i32 69, %dx.types.Handle [[HANDLE]], i32 %index, i32 undef, float [[DATA0_0]], float [[DATA0_1]], float [[DATA0_2]], float [[DATA0_3]], i8 15)
+  ; CHECK: call void @dx.op.bufferStore.f32(i32 69, %dx.types.Handle [[HANDLE]], i32 %index, i32 undef, float [[DATA0_0]], float [[DATA0_1]], float [[DATA0_2]], float [[DATA0_3]], i8 15){{$}}
   call void @llvm.dx.typedBufferStore(
       target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer,
       i32 %index, <4 x float> %data)
@@ -37,7 +37,7 @@ define void @storeint(<4 x i32> %data, i32 %index) {
   ; CHECK: [[DATA0_1:%.*]] = extractelement <4 x i32> %data, i32 1
   ; CHECK: [[DATA0_2:%.*]] = extractelement <4 x i32> %data, i32 2
   ; CHECK: [[DATA0_3:%.*]] = extractelement <4 x i32> %data, i32 3
-  ; CHECK: call void @dx.op.bufferStore.i32(i32 69, %dx.types.Handle [[HANDLE]], i32 %index, i32 undef, i32 [[DATA0_0]], i32 [[DATA0_1]], i32 [[DATA0_2]], i32 [[DATA0_3]], i8 15)
+  ; CHECK: call void @dx.op.bufferStore.i32(i32 69, %dx.types.Handle [[HANDLE]], i32 %index, i32 undef, i32 [[DATA0_0]], i32 [[DATA0_1]], i32 [[DATA0_2]], i32 [[DATA0_3]], i8 15){{$}}
   call void @llvm.dx.typedBufferStore(
       target("dx.TypedBuffer", <4 x i32>, 1, 0, 0) %buffer,
       i32 %index, <4 x i32> %data)
@@ -60,7 +60,7 @@ define void @storehalf(<4 x half> %data, i32 %index) {
   ; CHECK: [[DATA0_1:%.*]] = extractelement <4 x half> %data, i32 1
   ; CHECK: [[DATA0_2:%.*]] = extractelement <4 x half> %data, i32 2
   ; CHECK: [[DATA0_3:%.*]] = extractelement <4 x half> %data, i32 3
-  ; CHECK: call void @dx.op.bufferStore.f16(i32 69, %dx.types.Handle [[HANDLE]], i32 %index, i32 undef, half [[DATA0_0]], half [[DATA0_1]], half [[DATA0_2]], half [[DATA0_3]], i8 15)
+  ; CHECK: call void @dx.op.bufferStore.f16(i32 69, %dx.types.Handle [[HANDLE]], i32 %index, i32 undef, half [[DATA0_0]], half [[DATA0_1]], half [[DATA0_2]], half [[DATA0_3]], i8 15){{$}}
   call void @llvm.dx.typedBufferStore(
       target("dx.TypedBuffer", <4 x half>, 1, 0, 0) %buffer,
       i32 %index, <4 x half> %data)
@@ -83,7 +83,7 @@ define void @storei16(<4 x i16> %data, i32 %index) {
   ; CHECK: [[DATA0_1:%.*]] = extractelement <4 x i16> %data, i32 1
   ; CHECK: [[DATA0_2:%.*]] = extractelement <4 x i16> %data, i32 2
   ; CHECK: [[DATA0_3:%.*]] = extractelement <4 x i16> %data, i32 3
-  ; CHECK: call void @dx.op.bufferStore.i16(i32 69, %dx.types.Handle [[HANDLE]], i32 %index, i32 undef, i16 [[DATA0_0]], i16 [[DATA0_1]], i16 [[DATA0_2]], i16 [[DATA0_3]], i8 15)
+  ; CHECK: call void @dx.op.bufferStore.i16(i32 69, %dx.types.Handle [[HANDLE]], i32 %index, i32 undef, i16 [[DATA0_0]], i16 [[DATA0_1]], i16 [[DATA0_2]], i16 [[DATA0_3]], i8 15){{$}}
   call void @llvm.dx.typedBufferStore(
       target("dx.TypedBuffer", <4 x i16>, 1, 0, 0) %buffer,
       i32 %index, <4 x i16> %data)
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/iswave.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/iswave.ll
new file mode 100644
index 00000000000000..2c1164a163a37f
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/iswave.ll
@@ -0,0 +1,37 @@
+; RUN: opt -S --passes="print-dx-shader-flags" 2>&1 %s | FileCheck %s
+
+ta...
[truncated]
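
The patch above includes the generated `DXILOperation.inc` several times, each time with a different macro defined (`DXIL_ATTRIBUTE`, `DXIL_PROPERTY`, `DXIL_OP_INTRINSIC`). The generated file itself is not shown in the truncated diff, so the following is only a sketch of that X-macro technique with a hand-written list. `SKETCH_DXIL_OPS`, `SKETCH_ENUM`, `SKETCH_CASE`, and `isWaveOp` are hypothetical names, and the op list is illustrative rather than the real table.

```cpp
// Hypothetical, hand-written stand-in for a generated .inc file.
// Each entry has the form X(OpName, IsWaveProperty).
#define SKETCH_DXIL_OPS(X)                                                     \
  X(Sqrt, false)                                                               \
  X(WaveIsFirstLane, true)                                                     \
  X(WaveGetLaneIndex, true)

// First expansion: generate one enumerator per entry.
enum class OpCode : unsigned {
#define SKETCH_ENUM(Name, IsWave) Name,
  SKETCH_DXIL_OPS(SKETCH_ENUM)
#undef SKETCH_ENUM
};

// Second expansion: generate one switch case per entry from the same list,
// so the enum and the property query cannot drift apart.
static bool isWaveOp(OpCode Op) {
  switch (Op) {
#define SKETCH_CASE(Name, IsWave)                                              \
  case OpCode::Name:                                                           \
    return IsWave;
    SKETCH_DXIL_OPS(SKETCH_CASE)
#undef SKETCH_CASE
  }
  return false;
}
```

The design choice this illustrates is a single source of truth: adding an op to the list updates every expansion site at once, which is what generating the list from DXIL.td is meant to achieve.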

@inbelic inbelic marked this pull request as draft December 9, 2024 17:17
@inbelic inbelic force-pushed the inbelic/wave-flags branch from 9d1fe4f to 898854e Compare January 22, 2025 23:12
@inbelic inbelic changed the title [DXIL][ShaderFlags] Add analysis for WaveOps flag [DirectX][ShaderFlags] Add analysis for WaveOps flag Jan 22, 2025

github-actions bot commented Jan 22, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@inbelic inbelic marked this pull request as ready for review January 24, 2025 01:22
@@ -1091,4 +1092,4 @@ def CreateHandleFromBinding : DXILOp<217, createHandleFromBinding> {
  let result = HandleTy;
  let stages = [Stages<DXIL1_6, [all_stages]>];
  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}
Contributor Author

Artifact from github ui resolve conflicts. Will update with other review comments

Comment on lines 34 to 64
static dxil::Properties getOpCodeProperties(dxil::OpCode OpCode) {
  dxil::Properties Props;
  switch (OpCode) {
#define DXIL_OP_PROPERTIES(OpCode, ...)                                        \
  case OpCode:                                                                 \
    Props = dxil::Properties{__VA_ARGS__};                                     \
    break;
#include "DXILOperation.inc"
  }
  return Props;
}

static bool checkWaveOps(Intrinsic::ID IID) {
  switch (IID) {
#define DXIL_OP_INTRINSIC(OpCode, IntrinsicID, ...)                            \
  case IntrinsicID:                                                            \
    return getOpCodeProperties(OpCode).IsWave;
#include "DXILOperation.inc"
  }
  return false;
}
Contributor

There might be an argument for simply listing the 30 to 40 relevant intrinsic IDs in a switch statement here, rather than having the whole infrastructure to map in reverse from the DXIL opcode to the DirectX intrinsic.

Collaborator

We are probably going to want to be able to map DXIL opcodes to intrinsic functions for the DXIL loading path though, right? So it is infrastructure we need one way or another.

Contributor Author

I think in either case, it is more a question of whether we want to continue using the IsWave property (and DXILProperty definitions in general) in DXIL.td to model sub-groups of ops for shader flag analysis.

The removal of DXILProperty is outlined here.
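
To make the trade-off in this thread concrete, here is a small sketch that places the two options side by side. It assumes a hypothetical `Intrin` enum in place of `Intrinsic::ID`, and the helper names `toOpCode`, `checkWaveOpsViaOpCode`, and `checkWaveOpsDirect` are invented for illustration. The opcode values are the ones listed in DXIL.td in this patch. In the reviewed code the intrinsic-to-opcode step is produced by macro expansion of `DXILOperation.inc` rather than written by hand.

```cpp
// Hypothetical stand-in for Intrinsic::ID; enumerator names are illustrative.
enum class Intrin { BufferLoad, WaveIsFirstLane, WaveReadLaneAt };

// Opcode values as they appear in DXIL.td in this patch.
enum class OpCode : unsigned {
  BufferLoad = 68,
  WaveIsFirstLane = 110,
  WaveReadLaneAt = 117,
};

struct Properties {
  bool IsWave;
};

// Option A, step 1: opcode -> properties.
static Properties getOpCodeProperties(OpCode Op) {
  switch (Op) {
  case OpCode::WaveIsFirstLane:
  case OpCode::WaveReadLaneAt:
    return Properties{true};
  case OpCode::BufferLoad:
    return Properties{false};
  }
  return Properties{false};
}

// Option A, step 2: intrinsic -> opcode (the reverse mapping under discussion).
static OpCode toOpCode(Intrin ID) {
  switch (ID) {
  case Intrin::BufferLoad:
    return OpCode::BufferLoad;
  case Intrin::WaveIsFirstLane:
    return OpCode::WaveIsFirstLane;
  case Intrin::WaveReadLaneAt:
    return OpCode::WaveReadLaneAt;
  }
  return OpCode::BufferLoad; // not reached for the enumerators above
}

static bool checkWaveOpsViaOpCode(Intrin ID) {
  return getOpCodeProperties(toOpCode(ID)).IsWave;
}

// Option B: list the wave intrinsics directly in one switch.
static bool checkWaveOpsDirect(Intrin ID) {
  switch (ID) {
  case Intrin::WaveIsFirstLane:
  case Intrin::WaveReadLaneAt:
    return true;
  default:
    return false;
  }
}
```

Both functions give the same answer for every intrinsic in this toy table. Option B needs no mapping infrastructure but must be updated by hand whenever a wave intrinsic is added, while option A reuses the opcode table at the cost of the extra lookup.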

- add Shader Flag analysis for the `WaveOps` flag in DXILShaderFlags.cpp
- add testing of all currently supported wave ops in wave-ops.ll
@inbelic
Copy link
Contributor Author

inbelic commented Feb 7, 2025

As outlined here, we will switch to an approach that avoids the use of Property

; CHECK: ; Note: shader requires additional functionality:
; CHECK-NEXT: ; Wave level operations
; CHECK-NEXT: ; Note: extra DXIL module flags:
; CHECK-NEXT: {{^;$}}
Contributor

Can you explain what this check does?
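
One reading of the last check line, which the thread as captured here does not confirm: in FileCheck, `{{...}}` delimits a regular expression, so `{{^;$}}` matches a line consisting of a single `;` and nothing else. Combined with `CHECK-NEXT`, it would require the line immediately after "; Note: extra DXIL module flags:" to be a bare comment marker, meaning no extra module flags were printed. The sketch below demonstrates only the regex semantics using `std::regex`. FileCheck uses its own regex engine, but the two should agree for this simple pattern. `isBareCommentLine` is a hypothetical helper name.

```cpp
#include <regex>
#include <string>

// Returns true if Line consists of exactly one ';' character, which is what
// the anchored pattern ^;$ requires of the line it is matched against.
static bool isBareCommentLine(const std::string &Line) {
  static const std::regex Pattern("^;$");
  return std::regex_search(Line, Pattern);
}
```

Under this reading, a line such as `; Wave level operations` appearing in that position would fail the check, because the anchors leave no room for any text after the semicolon.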

Contributor

@bogner bogner left a comment

LGTM!

Comment on lines 49 to 62
// Currently unsupported intrinsics
// case Intrinsic::dx_WaveGetLaneCount:
// case Intrinsic::dx_WaveActiveAllEqual:
// case Intrinsic::dx_WaveActiveBallot:
// case Intrinsic::dx_WaveReadLaneFirst:
// case Intrinsic::dx_WaveActiveBit:
// case Intrinsic::dx_WavePrefixOp:
// case Intrinsic::dx_QuadReadLaneAt:
// case Intrinsic::dx_QuadOp:
// case Intrinsic::dx_WavePrefixBitCount:
// case Intrinsic::dx_WaveMatch:
// case Intrinsic::dx_WaveMultiPrefixOp:
// case Intrinsic::dx_WaveMultiPrefixBitCount:
// case Intrinsic::dx_QuadVote:
Contributor

We have proposed spellings for all of these, so you could name these to match those if you like: https://github.com/llvm/wg-hlsl/blob/main/proposals/0014-consistent-naming-for-dx-intrinsics.md#wave-ops

Also clang-format does something kind of awkward here - might be better to put this comment outside of the switch.

@inbelic inbelic merged commit 5742dc4 into llvm:main Feb 14, 2025
6 of 9 checks passed
joaosaffran pushed a commit to joaosaffran/llvm-project that referenced this pull request Feb 14, 2025
- Check each call instruction for a `WaveOp` intrinsic and set the
`WaveOps` flag if this is true for any intrinsic, Done in
DXILShaderFlags.cpp

Resolves llvm#114565
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025
- Check each call instruction for a `WaveOp` intrinsic and set the
`WaveOps` flag if this is true for any intrinsic, Done in
DXILShaderFlags.cpp

Resolves llvm#114565
@inbelic inbelic deleted the inbelic/wave-flags branch April 2, 2025 18:24
@damyanp damyanp moved this to Closed in HLSL Support Apr 25, 2025
Projects
Status: Closed
Development

Successfully merging this pull request may close these issues.

[DirectX] Implement Shader Flags Analysis for WaveOps
5 participants