Skip to content

[HLSL] Analyze updateCounter usage #135669

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 24, 2025

Conversation

V-FEXrt
Copy link
Contributor

@V-FEXrt V-FEXrt commented Apr 14, 2025

Fixes #135667

Analyze and annotate ResourceInfo with the derived direction of calls to updateCounter (if any).

This change only sets the value. Any diagnostics that should be raised must be done somewhere else.

@llvmbot
Copy link
Member

llvmbot commented Apr 14, 2025

@llvm/pr-subscribers-backend-directx

@llvm/pr-subscribers-llvm-analysis

Author: Ashley Coleman (V-FEXrt)

Changes

Fixes #135667

Analyze and annotate ResourceInfo with the derived direction of calls to updateCounter (if any).

This change only sets the value. Any diagnostics that should be raised must be done somewhere else.


Patch is 22.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/135669.diff

4 Files Affected:

  • (modified) llvm/include/llvm/Analysis/DXILResource.h (+1-1)
  • (modified) llvm/lib/Analysis/DXILResource.cpp (+45-6)
  • (modified) llvm/test/Analysis/DXILResource/buffer-frombinding.ll (+4-2)
  • (modified) llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp (+190-41)
diff --git a/llvm/include/llvm/Analysis/DXILResource.h b/llvm/include/llvm/Analysis/DXILResource.h
index 96e90e563e230..a8124caf64420 100644
--- a/llvm/include/llvm/Analysis/DXILResource.h
+++ b/llvm/include/llvm/Analysis/DXILResource.h
@@ -463,7 +463,7 @@ class DXILResourceMap {
   /// ambiguous so multiple creation instructions may be returned. The resulting
   /// ResourceInfo can be used to depuplicate unique handles that
   /// reference the same resource
-  SmallVector<dxil::ResourceInfo> findByUse(const Value *Key) const;
+  SmallVector<dxil::ResourceInfo *> findByUse(const Value *Key);
 
   const_iterator find(const CallInst *Key) const {
     auto Pos = CallMap.find(Key);
diff --git a/llvm/lib/Analysis/DXILResource.cpp b/llvm/lib/Analysis/DXILResource.cpp
index 1c4348321c1d0..d2392ff929611 100644
--- a/llvm/lib/Analysis/DXILResource.cpp
+++ b/llvm/lib/Analysis/DXILResource.cpp
@@ -697,6 +697,9 @@ bool DXILResourceTypeMap::invalidate(Module &M, const PreservedAnalyses &PA,
 }
 
 //===----------------------------------------------------------------------===//
+static bool isUpdateCounterIntrinsic(Function &F) {
+  return F.getIntrinsicID() == Intrinsic::dx_resource_updatecounter;
+}
 
 void DXILResourceMap::populate(Module &M, DXILResourceTypeMap &DRTM) {
   SmallVector<std::tuple<CallInst *, ResourceInfo, ResourceTypeInfo>> CIToInfos;
@@ -775,6 +778,42 @@ void DXILResourceMap::populate(Module &M, DXILResourceTypeMap &DRTM) {
     // Adjust the resource binding to use the next ID.
     RI.setBindingID(NextID++);
   }
+
+  for (Function &F : M.functions()) {
+    if (!isUpdateCounterIntrinsic(F))
+      continue;
+
+    LLVM_DEBUG(dbgs() << "Update Counter Function: " << F.getName() << "\n");
+
+    for (const User *U : F.users()) {
+      const CallInst *CI = dyn_cast<CallInst>(U);
+      assert(CI && "Users of dx_resource_updateCounter must be call instrs");
+
+      // Determine if the use is an increment or decrement
+      Value *CountArg = CI->getArgOperand(1);
+      ConstantInt *CountValue = cast<ConstantInt>(CountArg);
+      int64_t CountLiteral = CountValue->getSExtValue();
+
+      // 0 is an unknown direction and shouldn't result in an insert
+      if (CountLiteral == 0)
+        continue;
+
+      ResourceCounterDirection Direction = ResourceCounterDirection::Decrement;
+      if (CountLiteral > 0)
+        Direction = ResourceCounterDirection::Increment;
+
+      // Collect all potential creation points for the handle arg
+      Value *HandleArg = CI->getArgOperand(0);
+      SmallVector<ResourceInfo *> RBInfos = findByUse(HandleArg);
+      for (ResourceInfo *RBInfo : RBInfos) {
+        if (RBInfo->CounterDirection == ResourceCounterDirection::Unknown ||
+            RBInfo->CounterDirection == Direction)
+          RBInfo->CounterDirection = Direction;
+        else
+          RBInfo->CounterDirection = ResourceCounterDirection::Invalid;
+      }
+    }
+  }
 }
 
 void DXILResourceMap::print(raw_ostream &OS, DXILResourceTypeMap &DRTM,
@@ -793,10 +832,9 @@ void DXILResourceMap::print(raw_ostream &OS, DXILResourceTypeMap &DRTM,
   }
 }
 
-SmallVector<dxil::ResourceInfo>
-DXILResourceMap::findByUse(const Value *Key) const {
+SmallVector<dxil::ResourceInfo *> DXILResourceMap::findByUse(const Value *Key) {
   if (const PHINode *Phi = dyn_cast<PHINode>(Key)) {
-    SmallVector<dxil::ResourceInfo> Children;
+    SmallVector<dxil::ResourceInfo *> Children;
     for (const Value *V : Phi->operands()) {
       Children.append(findByUse(V));
     }
@@ -810,9 +848,10 @@ DXILResourceMap::findByUse(const Value *Key) const {
   switch (CI->getIntrinsicID()) {
   // Found the create, return the binding
   case Intrinsic::dx_resource_handlefrombinding: {
-    const auto *It = find(CI);
+    auto Pos = CallMap.find(CI);
+    ResourceInfo *It = &Infos[Pos->second];
     assert(It != Infos.end() && "HandleFromBinding must be in resource map");
-    return {*It};
+    return {It};
   }
   default:
     break;
@@ -821,7 +860,7 @@ DXILResourceMap::findByUse(const Value *Key) const {
   // Check if any of the parameters are the resource we are following. If so
   // keep searching. If none of them are return an empty list
   const Type *UseType = CI->getType();
-  SmallVector<dxil::ResourceInfo> Children;
+  SmallVector<dxil::ResourceInfo *> Children;
   for (const Value *V : CI->args()) {
     if (V->getType() != UseType)
       continue;
diff --git a/llvm/test/Analysis/DXILResource/buffer-frombinding.ll b/llvm/test/Analysis/DXILResource/buffer-frombinding.ll
index 81c8b5530afb6..ea683bb4e5783 100644
--- a/llvm/test/Analysis/DXILResource/buffer-frombinding.ll
+++ b/llvm/test/Analysis/DXILResource/buffer-frombinding.ll
@@ -69,6 +69,7 @@ define void @test_typedbuffer() {
   %uav1 = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
               @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_f32_1_0(
                   i32 3, i32 5, i32 1, i32 0, i1 false)
+  call i32 @llvm.dx.resource.updatecounter(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %uav1, i8 -1)
   ; CHECK: Resource [[UAV1:[0-9]+]]:
   ; CHECK:   Binding:
   ; CHECK:     Record ID: 1
@@ -76,7 +77,7 @@ define void @test_typedbuffer() {
   ; CHECK:     Lower Bound: 5
   ; CHECK:     Size: 1
   ; CHECK:   Globally Coherent: 0
-  ; CHECK:   Counter Direction: Unknown
+  ; CHECK:   Counter Direction: Decrement
   ; CHECK:   Class: UAV
   ; CHECK:   Kind: TypedBuffer
   ; CHECK:   IsROV: 0
@@ -92,6 +93,7 @@ define void @test_typedbuffer() {
   %uav2_2 = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
               @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_f32_1_0(
                   i32 4, i32 0, i32 10, i32 5, i1 false)
+  call i32 @llvm.dx.resource.updatecounter(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %uav2_2, i8 1)
   ; CHECK: Resource [[UAV2:[0-9]+]]:
   ; CHECK:   Binding:
   ; CHECK:     Record ID: 2
@@ -99,7 +101,7 @@ define void @test_typedbuffer() {
   ; CHECK:     Lower Bound: 0
   ; CHECK:     Size: 10
   ; CHECK:   Globally Coherent: 0
-  ; CHECK:   Counter Direction: Unknown
+  ; CHECK:   Counter Direction: Increment
   ; CHECK:   Class: UAV
   ; CHECK:   Kind: TypedBuffer
   ; CHECK:   IsROV: 0
diff --git a/llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp b/llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp
index 675a3dc19b912..d1ebfc3b1da41 100644
--- a/llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp
+++ b/llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp
@@ -11,6 +11,7 @@
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/CodeGen/CommandFlags.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicsDirectX.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Type.h"
@@ -28,8 +29,9 @@ class UniqueResourceFromUseTest : public testing::Test {
 protected:
   PassBuilder *PB;
   ModuleAnalysisManager *MAM;
-
+  LLVMContext *Context;
   virtual void SetUp() {
+    Context = new LLVMContext();
     MAM = new ModuleAnalysisManager();
     PB = new PassBuilder();
     PB->registerModuleAnalyses(*MAM);
@@ -37,9 +39,17 @@ class UniqueResourceFromUseTest : public testing::Test {
     MAM->registerPass([&] { return DXILResourceAnalysis(); });
   }
 
+  std::unique_ptr<Module> parseAsm(StringRef Asm) {
+    SMDiagnostic Error;
+    std::unique_ptr<Module> M = parseAssemblyString(Asm, Error, *Context);
+    EXPECT_TRUE(M) << "Bad assembly?: " << Error.getMessage();
+    return M;
+  }
+
   virtual void TearDown() {
     delete PB;
     delete MAM;
+    delete Context;
   }
 };
 
@@ -47,22 +57,18 @@ TEST_F(UniqueResourceFromUseTest, TestTrivialUse) {
   StringRef Assembly = R"(
 define void @main() {
 entry:
-  %handle = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 1, i32 2, i32 3, i32 4, i1 false)
+  %handle = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding(i32 1, i32 2, i32 3, i32 4, i1 false)
   call void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
   call void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
   ret void
 }
 
-declare target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32, i32, i32, i32, i1)
 declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
   )";
 
-  LLVMContext Context;
-  SMDiagnostic Error;
-  auto M = parseAssemblyString(Assembly, Error, Context);
-  ASSERT_TRUE(M) << "Bad assembly?";
+  auto M = parseAsm(Assembly);
 
-  const DXILResourceMap &DRM = MAM->getResult<DXILResourceAnalysis>(*M);
+  DXILResourceMap &DRM = MAM->getResult<DXILResourceAnalysis>(*M);
   for (const Function &F : M->functions()) {
     if (F.getName() != "a.func") {
       continue;
@@ -77,7 +83,7 @@ declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
       ASSERT_EQ(Bindings.size(), 1u)
           << "Handle should resolve into one resource";
 
-      auto Binding = Bindings[0].getBinding();
+      auto Binding = Bindings[0]->getBinding();
       EXPECT_EQ(0u, Binding.RecordID);
       EXPECT_EQ(1u, Binding.Space);
       EXPECT_EQ(2u, Binding.LowerBound);
@@ -94,7 +100,7 @@ declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
 TEST_F(UniqueResourceFromUseTest, TestIndirectUse) {
   StringRef Assembly = R"(
 define void @foo() {
-  %handle = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 1, i32 2, i32 3, i32 4, i1 false)
+  %handle = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding(i32 1, i32 2, i32 3, i32 4, i1 false)
   %handle2 = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %handle)
   %handle3 = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %handle2)
   %handle4 = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %handle3)
@@ -102,17 +108,13 @@ define void @foo() {
   ret void
 }
 
-declare target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32, i32, i32, i32, i1)
 declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
 declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %handle)
   )";
 
-  LLVMContext Context;
-  SMDiagnostic Error;
-  auto M = parseAssemblyString(Assembly, Error, Context);
-  ASSERT_TRUE(M) << "Bad assembly?";
+  auto M = parseAsm(Assembly);
 
-  const DXILResourceMap &DRM = MAM->getResult<DXILResourceAnalysis>(*M);
+  DXILResourceMap &DRM = MAM->getResult<DXILResourceAnalysis>(*M);
   for (const Function &F : M->functions()) {
     if (F.getName() != "a.func") {
       continue;
@@ -127,7 +129,7 @@ declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", flo
       ASSERT_EQ(Bindings.size(), 1u)
           << "Handle should resolve into one resource";
 
-      auto Binding = Bindings[0].getBinding();
+      auto Binding = Bindings[0]->getBinding();
       EXPECT_EQ(0u, Binding.RecordID);
       EXPECT_EQ(1u, Binding.Space);
       EXPECT_EQ(2u, Binding.LowerBound);
@@ -144,10 +146,10 @@ declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", flo
 TEST_F(UniqueResourceFromUseTest, TestAmbigousIndirectUse) {
   StringRef Assembly = R"(
 define void @foo() {
-  %foo = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 1, i32 1, i32 1, i32 1, i1 false)
-  %bar = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 2, i32 2, i32 2, i32 2, i1 false)
-  %baz = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 3, i32 3, i32 3, i32 3, i1 false)
-  %bat = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 4, i32 4, i32 4, i32 4, i1 false)
+  %foo = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding(i32 1, i32 1, i32 1, i32 1, i1 false)
+  %bar = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding(i32 2, i32 2, i32 2, i32 2, i1 false)
+  %baz = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding(i32 3, i32 3, i32 3, i32 3, i1 false)
+  %bat = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding(i32 4, i32 4, i32 4, i32 4, i1 false)
   %a = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %foo, target("dx.RawBuffer", float, 1, 0) %bar)
   %b = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %baz, target("dx.RawBuffer", float, 1, 0) %bat)
   %handle = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %a, target("dx.RawBuffer", float, 1, 0) %b)
@@ -155,17 +157,13 @@ define void @foo() {
   ret void
 }
 
-declare target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32, i32, i32, i32, i1)
 declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
 declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %x, target("dx.RawBuffer", float, 1, 0) %y)
   )";
 
-  LLVMContext Context;
-  SMDiagnostic Error;
-  auto M = parseAssemblyString(Assembly, Error, Context);
-  ASSERT_TRUE(M) << "Bad assembly?";
+  auto M = parseAsm(Assembly);
 
-  const DXILResourceMap &DRM = MAM->getResult<DXILResourceAnalysis>(*M);
+  DXILResourceMap &DRM = MAM->getResult<DXILResourceAnalysis>(*M);
   for (const Function &F : M->functions()) {
     if (F.getName() != "a.func") {
       continue;
@@ -180,25 +178,25 @@ declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", flo
       ASSERT_EQ(Bindings.size(), 4u)
           << "Handle should resolve into four resources";
 
-      auto Binding = Bindings[0].getBinding();
+      auto Binding = Bindings[0]->getBinding();
       EXPECT_EQ(0u, Binding.RecordID);
       EXPECT_EQ(1u, Binding.Space);
       EXPECT_EQ(1u, Binding.LowerBound);
       EXPECT_EQ(1u, Binding.Size);
 
-      Binding = Bindings[1].getBinding();
+      Binding = Bindings[1]->getBinding();
       EXPECT_EQ(1u, Binding.RecordID);
       EXPECT_EQ(2u, Binding.Space);
       EXPECT_EQ(2u, Binding.LowerBound);
       EXPECT_EQ(2u, Binding.Size);
 
-      Binding = Bindings[2].getBinding();
+      Binding = Bindings[2]->getBinding();
       EXPECT_EQ(2u, Binding.RecordID);
       EXPECT_EQ(3u, Binding.Space);
       EXPECT_EQ(3u, Binding.LowerBound);
       EXPECT_EQ(3u, Binding.Size);
 
-      Binding = Bindings[3].getBinding();
+      Binding = Bindings[3]->getBinding();
       EXPECT_EQ(3u, Binding.RecordID);
       EXPECT_EQ(4u, Binding.Space);
       EXPECT_EQ(4u, Binding.LowerBound);
@@ -216,8 +214,8 @@ TEST_F(UniqueResourceFromUseTest, TestConditionalUse) {
   StringRef Assembly = R"(
 define void @foo(i32 %n) {
 entry:
-  %x = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 1, i32 1, i32 1, i32 1, i1 false)
-  %y = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 4, i32 4, i32 4, i32 4, i1 false)
+  %x = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding(i32 1, i32 1, i32 1, i32 1, i1 false)
+  %y = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding(i32 4, i32 4, i32 4, i32 4, i1 false)
   %cond = icmp eq i32 %n, 0
   br i1 %cond, label %bb.true, label %bb.false
 
@@ -235,17 +233,13 @@ bb.exit:
   ret void
 }
 
-declare target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32, i32, i32, i32, i1)
 declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
 declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %x)
   )";
 
-  LLVMContext Context;
-  SMDiagnostic Error;
-  auto M = parseAssemblyString(Assembly, Error, Context);
-  ASSERT_TRUE(M) << "Bad assembly?";
+  auto M = parseAsm(Assembly);
 
-  const DXILResourceMap &DRM = MAM->getResult<DXILResourceAnalysis>(*M);
+  DXILResourceMap &DRM = MAM->getResult<DXILResourceAnalysis>(*M);
   for (const Function &F : M->functions()) {
     if (F.getName() != "a.func") {
       continue;
@@ -260,13 +254,13 @@ declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", flo
       ASSERT_EQ(Bindings.size(), 2u)
           << "Handle should resolve into four resources";
 
-      auto Binding = Bindings[0].getBinding();
+      auto Binding = Bindings[0]->getBinding();
       EXPECT_EQ(0u, Binding.RecordID);
       EXPECT_EQ(1u, Binding.Space);
       EXPECT_EQ(1u, Binding.LowerBound);
       EXPECT_EQ(1u, Binding.Size);
 
-      Binding = Bindings[1].getBinding();
+      Binding = Bindings[1]->getBinding();
       EXPECT_EQ(1u, Binding.RecordID);
       EXPECT_EQ(4u, Binding.Space);
       EXPECT_EQ(4u, Binding.LowerBound);
@@ -280,4 +274,159 @@ declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", flo
   }
 }
 
+// Test that several calls to decrement on the same resource don't raise a
+// Diagnositic and resolves to a single decrement entry
+TEST_F(UniqueResourceFromUseTest, TestResourceCounterDecrement) {
+  StringRef Assembly = R"(
+define void @main() {
+entry:
+  %handle = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding(i32 1, i32 2, i32 3, i32 4, i1 false)
+  call i32 @llvm.dx.resource.updatecounter(target("dx.RawBuffer", float, 1, 0) %handle, i8 -1)
+  call i32 @llvm.dx.resource.updatecounter(target("dx.RawBuffer", float, 1, 0) %handle, i8 -1)
+  call i32 @llvm.dx.resource.updatecounter(target("dx.RawBuffer", float, 1, 0) %handle, i8 -1)
+  ret void
+}
+  )";
+
+  auto M = parseAsm(Assembly);
+
+  DXILResourceMap &DRM = MAM->getResult<DXILResourceAnalysis>(*M);
+
+  for (const Function &F : M->functions()) {
+    if (F.getIntrinsicID() != Intrinsic::dx_resource_handlefrombinding) {
+      continue;
+    }
+
+    for (const User *U : F.users()) {
+      const CallInst *CI = cast<CallInst>(U);
+      const auto *const Binding = DRM.find(CI);
+      ASSERT_EQ(Binding->CounterDirection, ResourceCounterDirection::Decrement);
+    }
+  }
+}
+
+// Test that several calls to increment on the same resource don't raise a
+// Diagnositic and resolves to a single increment entry
+TEST_F(UniqueResourceFromUseTest, TestResourceCounterIncrement) {
+  StringRef Assembly = R"(
+define void @main() {
+entry:
+  %handle = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding(i32 1, i32 2, i32 3, i32 4, i1 false)
+  call i32 @llvm.dx.resource.updatecounter(target("dx.RawBuffer", float, 1, 0) %handle, i8 1)
+  call i32 @llvm.dx.resource.updatecounter(target("dx.RawBuffer", float, 1, 0) %handle, i8 1)
+  call i32 @llvm.dx.resource.updatecounter(target("dx.RawBuffer", float, 1, 0) %handle, i8 1)
+  ret void
+}
+  )";
+
+  auto M = parseAsm(Assembly);
+
+  DXILResourceMap &DRM = MAM->getResult<DXILResourceAnalysis>(*M);
+
+  for (const Function &F : M->functions()) {
+    if (F.getIntrinsicID() != Intrinsic::dx_resource_handlefrombinding) {
+      continue;
+    }
+
+    for (const User *U : F.users()) {
+      const CallInst *CI = cast<CallInst>(U);
+      const auto *const Binding = DRM.find(CI);
+      ASSERT_EQ(Binding->CounterDirection, ResourceCounterDirection::Increment);
+    }
+  }
+}
+
+// Test that looking up a resource that doesn't have the counter updated
+// resoves to unknown
+TEST_F(UniqueResourceFromUseTest, TestResourceCounterUnknown) {
+  StringRef Assembly = R"(
+define void @main() {
+entry:
+  %handle = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding(i32 1, i32 2, i32 3, i32 4, i1 false)
+  ret void
+}
+  )";
+
+  auto M = parseAsm(Assembly);
+
+  DXILResourceMap &DRM = MAM->getResult<DXILResourceAnalysis>(*M);
+
+  for (const Function &F : M->functions()) {
+    if (F.getIntrinsicID() != Intrinsic::dx_resource_handlefrombinding) {
+      continue;
+    }
+
+    for (const User *U : F.users()) {
+      const CallInst *CI = cast<CallInst>(U);
+      const auto *const Binding = DRM.find(CI);
+      ASSERT_EQ(Binding->CounterDirection, ResourceCounterDirection::Unknown);
+    }
+  }
+}
+
+// Test that multiple different resources with unique incs/decs aren't
+// marked...
[truncated]

RBInfo->CounterDirection == Direction)
RBInfo->CounterDirection = Direction;
else
RBInfo->CounterDirection = ResourceCounterDirection::Invalid;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this case mean? The CounterDirection was previously set and was wrong? Or there is conflicting information?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Conflicting information!

If CounterDirection is detected as both Increment and Decrement that is considered invalid (because only one direction is allowed)

Copy link
Member

@hekota hekota left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! A few nits and suggestions.

@V-FEXrt V-FEXrt merged commit f12fb2f into llvm:main Apr 24, 2025
12 checks passed
@V-FEXrt V-FEXrt deleted the 135667-analyze-updatecounter-usage branch April 24, 2025 19:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Status: Closed
Development

Successfully merging this pull request may close these issues.

[HLSL] Analyze and annotate updateCounter direction on a resource
4 participants