Skip to content

Commit 2f52336

Browse files
pytorchbottrivedivivek
authored andcommitted
[ET-VK] Replace Uniform buffers with push constants for native layer norm op (#9872)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #9831 by @trivedivivek ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/trivedivivek/73/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/73/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/trivedivivek/72/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/73/orig @diff-train-skip-merge --------- Co-authored-by: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com>
1 parent fa6cddb commit 2f52336

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl

+5-3
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,11 @@ ${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
2727
${layout_declare_tensor(B, "r", "t_weight", DTYPE, STORAGE)}
2828
${layout_declare_tensor(B, "r", "t_bias", DTYPE, STORAGE)}
2929

30-
${layout_declare_ubo(B, "ivec3", "out_limits")}
31-
${layout_declare_ubo(B, "ivec4", "sizes")}
32-
${layout_declare_ubo(B, "float", "epsilon")}
30+
layout(push_constant) uniform PRECISION restrict Block {
31+
ivec3 out_limits;
32+
ivec4 sizes;
33+
float epsilon;
34+
};
3335

3436
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3537

backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp

+7-6
Original file line numberDiff line numberDiff line change
@@ -101,19 +101,20 @@ void add_native_layer_norm_node(
101101
vkapi::MemoryAccessType::WRITE},
102102
{{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
103103
// Shader params buffers
104-
{
105-
t_out->logical_limits_ubo(),
106-
t_out->sizes_ubo(),
107-
graph.create_params_buffer(epsilon),
108-
},
104+
{},
109105
// Specialization Constants
110106
{
111107
t_input->hashed_layout(),
112108
t_out->hashed_layout(),
113109
},
114110
// Resizing Logic
115111
resize_native_layer_norm_node,
116-
{normalized_shape}));
112+
{normalized_shape},
113+
{
114+
graph.logical_limits_pc_of(out_val->at(0)),
115+
graph.sizes_pc_of(out_val->at(0)),
116+
PushConstantDataInfo(&epsilon, sizeof(epsilon)),
117+
}));
117118
}
118119

119120
void native_layer_norm(ComputeGraph& graph, const std::vector<ValueRef>& args) {

0 commit comments

Comments
 (0)