Skip to content

Conversation

@FranklandJack
Copy link
Contributor

Add support for generating shader arguments as global variables in the SPIR-V module when the argument in question is a SPIR-V image.

Add lit tests to execute the new logic and check global variables are being generated.

Add support for generating shader arguments as global variables in the
SPIR-V module when the argument in question is a SPIR-V image.

Add lit tests to execute the new logic and check global variables are
being generated.

Signed-off-by: Jack Frankland <jack.frankland@arm.com>
@llvmbot
Copy link
Member

llvmbot commented Jul 28, 2025

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: Jack Frankland (FranklandJack)

Changes

Add support for generating shader arguments as global variables in the SPIR-V module when the argument in question is a SPIR-V image.

Add lit tests to execute the new logic and check global variables are being generated.


Full diff: https://github.com/llvm/llvm-project/pull/150996.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp (+11-1)
  • (modified) mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir (+24)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 85525a5a02fa2..e447e4bfae9dc 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -58,7 +58,17 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
         spirv::PointerType::get(spirv::StructType::get(varType), *storageClass);
   }
   auto varPtrType = cast<spirv::PointerType>(varType);
-  auto varPointeeType = cast<spirv::StructType>(varPtrType.getPointeeType());
+  auto pointeeType = varPtrType.getPointeeType();
+
+  // Images are an opaque type and so we can just return a pointer to an image.
+  // Note that currently only sampled images are supported in the SPIR-V
+  // lowering.
+  if (isa<spirv::SampledImageType>(pointeeType))
+    return builder.create<spirv::GlobalVariableOp>(
+        funcOp.getLoc(), varType, varName, abiInfo.getDescriptorSet(),
+        abiInfo.getBinding());
+
+  auto varPointeeType = cast<spirv::StructType>(pointeeType);
 
   // Set the offset information.
   varPointeeType =
diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
index bd51a07843652..f3a3218e5aec0 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
@@ -66,3 +66,27 @@ spirv.module Logical GLSL450 attributes {spirv.target_env = #spirv.target_env<#s
   // CHECK: spirv.EntryPoint "GLCompute" [[FN]], [[VAR0]], [[VAR1]]
   // CHECK: spirv.ExecutionMode [[FN]] "LocalSize", 32, 1, 1
 } // end spirv.module
+
+// -----
+
+module {
+  spirv.module Logical GLSL450 attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Sampled1D], []>, #spirv.resource_limits<>>} {
+    // CHECK-DAG: spirv.GlobalVariable @[[IMAGE_GV:.*]] bind(0, 0) : !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
+    // CHECK: spirv.func @read_image
+    spirv.func @read_image(%arg0: !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}, %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) "None" attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>} {
+      // CHECK: %[[IMAGE_ADDR:.*]] = spirv.mlir.addressof @[[IMAGE_GV]] : !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>, UniformConstant>
+      %cst0_i32 = spirv.Constant 0 : i32
+      // CHECK: spirv.Load "UniformConstant" %[[IMAGE_ADDR]]
+      %0 = spirv.Load "UniformConstant" %arg0 : !spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
+      %1 = spirv.Image %0 : !spirv.sampled_image<!spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>>
+      %2 = spirv.ImageFetch %1, %cst0_i32  : !spirv.image<f32, Dim1D, DepthUnknown, NonArrayed, SingleSampled, NeedSampler, R32f>, i32 -> vector<4xf32>
+      %3 = spirv.CompositeExtract %2[0 : i32] : vector<4xf32>
+      %cst0_i32_0 = spirv.Constant 0 : i32
+      %cst0_i32_1 = spirv.Constant 0 : i32
+      %cst1_i32 = spirv.Constant 1 : i32
+      %4 = spirv.AccessChain %arg1[%cst0_i32_0, %cst0_i32] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
+      spirv.Store "StorageBuffer" %4, %3 : f32
+      spirv.Return
+    }
+  }
+}

@kuhar kuhar requested a review from IgWod-IMG July 28, 2025 18:37
Make variable type explicit in definition.

Signed-off-by: Jack Frankland <jack.frankland@arm.com>
Copy link
Contributor

@IgWod-IMG IgWod-IMG left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with a small nit

Use new builder APIs over builder methods.

Signed-off-by: Jack Frankland <jack.frankland@arm.com>
@FranklandJack FranklandJack merged commit 8bb3095 into llvm:main Jul 29, 2025
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants