Symintified layer_norm (pytorch#89466)
Summary: As titled.

Test Plan:
```
buck2 run mode/opt scripts/wwei6:test_executorch
```

Differential Revision: D41451390

Pull Request resolved: pytorch#89466
Approved by: https://github.com/frank-wei, https://github.com/ezyang
tissue3 authored and pytorchmergebot committed Nov 24, 2022
1 parent fdb2dd1 commit f0e5bc4
Showing 3 changed files with 7 additions and 6 deletions.
aten/src/ATen/functorch/BatchRulesDecompositions.cpp (1 addition, 1 deletion)

```diff
@@ -132,7 +132,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   OP_DECOMPOSE(instance_norm);
   OP_DECOMPOSE(kron);
   OP_DECOMPOSE(l1_loss);
-  OP_DECOMPOSE(layer_norm);
+  m.impl("layer_norm", native::layer_norm_symint);
   OP_DECOMPOSE2(ldexp, Tensor);
   OP_DECOMPOSE2(less_equal, Tensor );
   OP_DECOMPOSE2(less, Tensor );
```
aten/src/ATen/native/layer_norm.cpp (3 additions, 4 deletions)

```diff
@@ -175,9 +175,9 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_cpu(
   return std::make_tuple(std::move(dX), std::move(dgamma), std::move(dbeta));
 }
 
-Tensor layer_norm(
+Tensor layer_norm_symint(
     const Tensor& input,
-    IntArrayRef normalized_shape, const c10::optional<Tensor>& weight_opt /* optional */, const c10::optional<Tensor>& bias_opt /* optional */,
+    c10::SymIntArrayRef normalized_shape, const c10::optional<Tensor>& weight_opt /* optional */, const c10::optional<Tensor>& bias_opt /* optional */,
     double eps,
     bool /* cudnn_enable, deprecated */) {
   // See [Note: hacky wrapper removal for optional tensor]
@@ -186,8 +186,7 @@ Tensor layer_norm(
   c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
   const Tensor& bias = *bias_maybe_owned;
 
-
-  return std::get<0>(at::native_layer_norm(input, normalized_shape, weight, bias, eps));
+  return std::get<0>(at::native_layer_norm_symint(input, normalized_shape, weight, bias, eps));
 }
 
 DEFINE_DISPATCH(LayerNormKernel);
```
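The composite kernel above only forwards to the SymInt-aware native op. As a point of reference, here is a minimal C++ sketch (not part of this commit; the helper name and the choice to normalize over the last dimension are assumptions) of how a caller could reach the same entry point, `at::native_layer_norm_symint`, that the changed code invokes:

```cpp
#include <tuple>
#include <vector>

#include <ATen/ATen.h>

// Hypothetical helper (not from the commit): layer-normalize over the last
// dimension by calling the SymInt-aware native op directly. In eager mode the
// SymInt below is backed by a concrete integer, so this matches the old
// IntArrayRef path; under symbolic tracing the size can stay symbolic.
at::Tensor layer_norm_last_dim(const at::Tensor& input,
                               const at::Tensor& weight,
                               const at::Tensor& bias,
                               double eps = 1e-5) {
  // sym_sizes() yields c10::SymIntArrayRef; keep only the last dimension.
  std::vector<c10::SymInt> normalized_shape{input.sym_sizes().back()};
  // native_layer_norm returns (output, mean, rstd); keep the output only.
  return std::get<0>(at::native_layer_norm_symint(
      input, normalized_shape, weight, bias, eps));
}
```

In eager use this is equivalent to `at::layer_norm(input, {input.size(-1)}, weight, bias, eps)`; the point of the SymInt overload is that nothing forces the size down to a concrete integer.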
aten/src/ATen/native/native_functions.yaml (3 additions, 1 deletion)

```diff
@@ -2938,7 +2938,9 @@
 
 - func: kthvalue.dimname_out(Tensor self, int k, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
 
-- func: layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor
+- func: layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor
+  dispatch:
+    CompositeImplicitAutograd: layer_norm_symint
 
 - func: native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)
   dispatch:
```
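With the public schema moved from `int[]` to `SymInt[]` and dispatched to `layer_norm_symint`, existing integer callers keep working while symbolic sizes become representable. A hedged usage sketch follows (not from the commit; the `at::layer_norm_symint` overload name is assumed from the usual `*_symint` codegen convention, and in eager mode the SymInt is just a wrapped integer):

```cpp
#include <vector>

#include <ATen/ATen.h>

int main() {
  at::Tensor x = at::randn({2, 3, 8});
  at::Tensor w = at::ones({8});
  at::Tensor b = at::zeros({8});

  // Long-standing concrete-shape call; still resolves to the IntArrayRef
  // overload of the generated C++ API.
  at::Tensor y1 = at::layer_norm(x, {8}, w, b, /*eps=*/1e-5, /*cudnn_enable=*/true);

  // Same computation through the SymInt overload (name assumed); here the
  // SymInt wraps a plain integer, but under symbolic tracing it could stay
  // symbolic instead of being specialized.
  std::vector<c10::SymInt> shape{x.sym_sizes().back()};
  at::Tensor y2 = at::layer_norm_symint(x, shape, w, b, 1e-5, true);

  TORCH_CHECK(at::allclose(y1, y2));
  return 0;
}
```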
