Skip to content

Commit a12dbc0

Browse files
ThomasJannaudfacebook-github-bot
authored andcommitted
Adding RMSNorm support to arbitrary x and normalized_dim shapes (#9966)
Summary: In D72014553, we were adding initial support for RMS norm for an input in 3 or 4 dimensions, and a weight of dimension 1 (same size as x[:-1]) In this diff, we allow for: - input of arbitrary shape - shape broadcasting of w (w must have dim <= 1) Differential Revision: D72484196
1 parent f28b5db commit a12dbc0

File tree

1 file changed

+0
-4
lines changed

1 file changed

+0
-4
lines changed

backends/cadence/aot/ops_registrations.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,6 @@
139139
"int in_zero_point, bool channel_last=False) -> (Tensor out)"
140140
)
141141
lib.define("linalg_vector_norm(Tensor X) -> (Tensor Y)")
142-
lib.define("rms_norm(Tensor X, float eps, Tensor W) -> (Tensor Y)")
143142
lib.define(
144143
"transposed_im2row(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, "
145144
"int[2] output_padding, Tensor in_zero_point, bool channel_last=False) -> (Tensor out)"
@@ -211,9 +210,6 @@
211210
"fully_connected.out(Tensor input, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!)"
212211
)
213212
lib.define("linalg_vector_norm.out(Tensor X, *, Tensor(a!) out) -> Tensor(a!)")
214-
lib.define(
215-
"rms_norm.out(Tensor X, float eps, Tensor W, *, Tensor(a!) out) -> Tensor(a!)"
216-
)
217213
lib.define(
218214
"quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
219215
"Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"

0 commit comments

Comments
 (0)