Skip to content

Commit 1d7d71d

Browse files
SS-JIAfacebook-github-bot
authored andcommitted
Allow expression of scalar tensor buffers, non string values in variants (#4292)
Summary: Pull Request resolved: #4292 Some simple improvements to the SPIR-V compilation script: 1. Allow `layout_declare_tensor` to create a scalar buffer instead of always creating a vectorized buffer 2. Allow handling of non-string (i.e. int) values in shader codegen YAML configurations. Reviewed By: jorgep31415 Differential Revision: D59877805 fbshipit-source-id: 579888fbc19d19a0d24f2fbd831e74f4ba32f033
1 parent e5687a4 commit 1d7d71d

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ def layout_declare_tensor(
231231
var_name: str,
232232
dtype: str,
233233
storage_type: str,
234+
is_scalar_array: bool = False,
234235
precision: str = "PRECISION",
235236
) -> str:
236237
assert storage_type.lower() in ["buffer", "texture3d", "texture2d"]
@@ -242,7 +243,12 @@ def layout_declare_tensor(
242243
# Create buffer binding
243244
if storage_type.lower() == "buffer":
244245
return layout_declare_buffer(
245-
slot, access_type, var_name, dtype, precision, is_scalar_array=False
246+
slot,
247+
access_type,
248+
var_name,
249+
dtype,
250+
precision,
251+
is_scalar_array=is_scalar_array,
246252
)
247253

248254
# Create image/sampler binding
@@ -533,7 +539,7 @@ def generateVariantCombinations(
533539
curr_suffix = (
534540
suffix + "_" + str(i) if suffix else str(i)
535541
)
536-
param_values.append((param_name, curr_suffix, str(i)))
542+
param_values.append((param_name, curr_suffix, i))
537543
else:
538544
raise ValueError(
539545
f"{value['RANGE']} is not a valid range. Must be in format [start, end] (inclusive)."
@@ -595,7 +601,7 @@ def parseTemplateYaml(self, yaml_file: str) -> None:
595601
variant_name = variant["NAME"]
596602
for param_value in combination:
597603
default_params_copy[param_value[0]] = param_value[2]
598-
if len(param_value[1]) > 0:
604+
if len(str(param_value[1])) > 0:
599605
variant_name = f"{variant_name}_{param_value[1]}"
600606

601607
default_params_copy["NAME"] = variant_name

0 commit comments

Comments
 (0)