Skip to content

[ET-VK] Adding all tensor packing support for native layer norm. #9870

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,6 @@ def register_ported_op_all_packed_dims(features: OpFeatures):
[
exir_ops.edge.aten.embedding.default,
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
exir_ops.edge.aten.native_layer_norm.default,
]
)
def register_ported_ops_with_prepacking(features: OpFeatures):
Expand All @@ -587,6 +586,20 @@ def register_ported_ops_with_prepacking(features: OpFeatures):
return features


# Ported ops that handle their own prepacking and support every packed dim.
@update_features([exir_ops.edge.aten.native_layer_norm.default])
def register_ported_ops_with_prepacking_all_dims(features: OpFeatures):
    """Register ops that prepack their own weights and accept width-,
    height-, and channels-packed texture layouts."""
    texture_features = TextureImplFeatures(valid_packed_dims=all_packed_dims)
    features.texture_impl = texture_features
    features.handles_own_prepacking = True
    return features


#######################
## Utility functions ##
#######################
Expand Down
126 changes: 94 additions & 32 deletions backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

#define VEC4_T ${texel_type(DTYPE)}

#define T ${texel_component_type(DTYPE)}

layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
Expand Down Expand Up @@ -48,37 +50,97 @@ void main() {

const int width = int(sizes.x);

VEC4_T mean = VEC4_T(0);
VEC4_T delta = VEC4_T(0);
VEC4_T delta2 = VEC4_T(0);
VEC4_T M2 = VEC4_T(0);

// Use Welford's online algorithm to compute mean and variance in one pass
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
for (int w = 0; w < width; ++w) {
in_pos[in_axis_map.x] = w;
VEC4_T v = load_texel(t_in, in_pos);
delta = v - mean;
mean += delta / (w + 1);
delta2 = v - mean;
M2 += delta * delta2;
}

VEC4_T var = M2 / width;
VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5));
VEC4_T offset = -rstd * mean;

for (int w = 0; w < width; ++w) {
in_pos[in_axis_map.x] = w;
VEC4_T v = load_texel(t_in, in_pos);
// broadcasting
VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0)).xxxx;
VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0)).xxxx;
VEC4_T outtex = (v * rstd + offset) * weight + bias;
write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map);
if (in_packed_dim != W_DIM) {
VEC4_T mean = VEC4_T(0);
VEC4_T delta = VEC4_T(0);
VEC4_T delta2 = VEC4_T(0);
VEC4_T M2 = VEC4_T(0);

// Use Welford's online algorithm to compute mean and variance in one pass
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
for (int w = 0; w < width; ++w) {
in_pos[in_axis_map.x] = w;
VEC4_T v = load_texel(t_in, in_pos);
delta = v - mean;
mean += delta / (w + 1);
delta2 = v - mean;
M2 += delta * delta2;
}

VEC4_T var = M2 / width;
VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5));
VEC4_T offset = -rstd * mean;

for (int w = 0; w < width; ++w) {
in_pos[in_axis_map.x] = w;
VEC4_T v = load_texel(t_in, in_pos);
// broadcasting
VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0)).xxxx;
VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0)).xxxx;
VEC4_T outtex = (v * rstd + offset) * weight + bias;
write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map);
}

write_texel(t_mean, lpos, mean);
write_texel(t_rstd, lpos, rstd);
} else {
const int packed_width = divup4(width);

T mean = T(0);
T delta = T(0);
T delta2 = T(0);
T M2 = T(0);
// Use Welford's online algorithm to compute mean and variance in one pass
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
T width_counter = T(1);

const bool has_unaligned_width = (width & 0x3) != 0;
const int fully_packed_4_comp_count = packed_width - mix(0, 1, has_unaligned_width);

// iterate through texels that are fully packed ie. has 4 components
for (int w = 0; w < fully_packed_4_comp_count; ++w) {
in_pos[in_axis_map.x] = w;
VEC4_T v = load_texel(t_in, in_pos);
for (int i=0; i<4; i++) {
delta = v[i] - mean;
mean += delta / width_counter;
delta2 = v[i] - mean;
M2 += delta * delta2;
width_counter++;
}
}

// handle last texel if its not 4 aligned
if (has_unaligned_width) {
in_pos[in_axis_map.x] = fully_packed_4_comp_count;
const int remaining_width = width & 0x3;

VEC4_T v = load_texel(t_in, in_pos);
for (int i=0; i<remaining_width; i++) {
delta = v[i] - mean;
mean += delta / width_counter;
delta2 = v[i] - mean;
M2 += delta * delta2;
width_counter++;
}
}

T var = M2 / (width_counter - 1);
T rstd = inversesqrt(var + epsilon);
T offset = -rstd * mean;

for (int w = 0; w < packed_width; ++w) {
in_pos[in_axis_map.x] = w;
VEC4_T v = load_texel(t_in, in_pos);
VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0));
VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0));
VEC4_T outtex = (v * rstd + offset) * weight + bias;
write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map);
}

write_texel(t_mean, lpos, VEC4_T(mean));
write_texel(t_rstd, lpos, VEC4_T(rstd));
}

write_texel(t_mean, lpos, mean);
write_texel(t_rstd, lpos, rstd);
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,6 @@ void resize_native_layer_norm_node(
rstd->virtual_resize(mean_size);
}

void check_args(const api::vTensor& in, const api::vTensor& out) {
VK_CHECK_COND(check_packed_dim_is(in, WHCN::kChannelsDim));
VK_CHECK_COND(check_packed_dim_is(out, WHCN::kChannelsDim));
}

void add_native_layer_norm_node(
ComputeGraph& graph,
const ValueRef in,
Expand Down Expand Up @@ -84,7 +79,7 @@ void add_native_layer_norm_node(
vTensorPtr t_input = graph.get_tensor(in);
float epsilon = graph.extract_scalar<float>(eps);

check_args(*t_input, *t_out);
VK_CHECK_COND(check_same_packed_dim(*t_input, *t_out));

std::vector<int64_t> in_sizes = t_input->sizes();

Expand Down
5 changes: 5 additions & 0 deletions backends/vulkan/test/op_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,11 @@ def get_native_layer_norm_inputs():
((S, XL, M1, M2), [M2], (M2), (M2), 0.001),
]
)
test_suite.layouts = [
"utils::kWidthPacked",
"utils::kHeightPacked",
"utils::kChannelsPacked",
]
return test_suite


Expand Down
Loading