|
 //
 // This source code is licensed under the MIT license found in the
 // LICENSE file in the root directory of this source tree.
+// utils.h
+
+#pragma once
 
 #include <torch/extension.h>
 #include <torch/torch.h>
 
-#include <iostream>
-
-using namespace torch::autograd;
+torch::Tensor safetanh(torch::Tensor input, float eps = 1e-6);
+torch::Tensor safeatanh(torch::Tensor input, float eps = 1e-6);
 
-class SafeTanh : public Function<SafeTanh> {
+class SafeTanh : public torch::autograd::Function<SafeTanh> {
  public:
-  static torch::Tensor forward(AutogradContext* ctx, torch::Tensor input,
-                               float eps = 1e-6) {
-    auto out = torch::tanh(input);
-    auto lim = 1.0 - eps;
-    out = out.clamp(-lim, lim);
-    ctx->save_for_backward({out});
-    return out;
-  }
-
-  static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) {
-    auto saved = ctx->get_saved_variables();
-    auto out = saved[0];
-    auto go = grad_outputs[0];
-    auto grad = go * (1 - out * out);
-    return {grad, torch::Tensor()};
-  }
+  static torch::Tensor forward(torch::autograd::AutogradContext* ctx,
+                               torch::Tensor input, float eps);
+  static torch::autograd::tensor_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::tensor_list grad_outputs);
 };
 
-torch::Tensor safetanh(torch::Tensor input, float eps = 1e-6) {
-  return SafeTanh::apply(input, eps);
-}
-
-class SafeInvTanh : public Function<SafeInvTanh> {
+class SafeInvTanh : public torch::autograd::Function<SafeInvTanh> {
  public:
-  static torch::Tensor forward(AutogradContext* ctx, torch::Tensor input,
-                               float eps = 1e-6) {
-    auto lim = 1.0 - eps;
-    auto intermediate = input.clamp(-lim, lim);
-    ctx->save_for_backward({intermediate});
-    auto out = torch::atanh(intermediate);
-    return out;
-  }
-
-  static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) {
-    auto saved = ctx->get_saved_variables();
-    auto input = saved[0];
-    auto go = grad_outputs[0];
-    auto grad = go / (1 - input * input);
-    return {grad, torch::Tensor()};
-  }
+  static torch::Tensor forward(torch::autograd::AutogradContext* ctx,
+                               torch::Tensor input, float eps);
+  static torch::autograd::tensor_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::tensor_list grad_outputs);
 };
-
-torch::Tensor safeatanh(torch::Tensor input, float eps = 1e-6) {
-  return SafeInvTanh::apply(input, eps);
-}
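
The definitions removed from this header presumably move to a companion translation unit that is not shown in this diff. Below is a minimal sketch of what those out-of-line definitions could look like, matching the declarations kept above; the file name utils.cpp is an assumption, and the comments are added here for illustration only:

// utils.cpp (hypothetical companion file for the declarations above)
#include "utils.h"

#include <torch/extension.h>
#include <torch/torch.h>

// Forward: tanh clamped away from +/-1 so a later atanh stays finite.
torch::Tensor SafeTanh::forward(torch::autograd::AutogradContext* ctx,
                                torch::Tensor input, float eps) {
  auto out = torch::tanh(input);
  auto lim = 1.0 - eps;
  out = out.clamp(-lim, lim);
  ctx->save_for_backward({out});
  return out;
}

// Backward: d/dx tanh(x) = 1 - tanh(x)^2, using the saved (clamped) output.
torch::autograd::tensor_list SafeTanh::backward(
    torch::autograd::AutogradContext* ctx,
    torch::autograd::tensor_list grad_outputs) {
  auto saved = ctx->get_saved_variables();
  auto out = saved[0];
  auto go = grad_outputs[0];
  auto grad = go * (1 - out * out);
  return {grad, torch::Tensor()};  // no gradient w.r.t. eps
}

// Forward: clamp the input into (-1, 1) before atanh to avoid +/-inf.
torch::Tensor SafeInvTanh::forward(torch::autograd::AutogradContext* ctx,
                                   torch::Tensor input, float eps) {
  auto lim = 1.0 - eps;
  auto intermediate = input.clamp(-lim, lim);
  ctx->save_for_backward({intermediate});
  return torch::atanh(intermediate);
}

// Backward: d/dx atanh(x) = 1 / (1 - x^2), using the saved clamped input.
torch::autograd::tensor_list SafeInvTanh::backward(
    torch::autograd::AutogradContext* ctx,
    torch::autograd::tensor_list grad_outputs) {
  auto saved = ctx->get_saved_variables();
  auto input = saved[0];
  auto go = grad_outputs[0];
  auto grad = go / (1 - input * input);
  return {grad, torch::Tensor()};  // no gradient w.r.t. eps
}

// Default eps values live on the header declarations, so they are omitted here.
torch::Tensor safetanh(torch::Tensor input, float eps) {
  return SafeTanh::apply(input, eps);
}

torch::Tensor safeatanh(torch::Tensor input, float eps) {
  return SafeInvTanh::apply(input, eps);
}

Keeping only declarations in the header avoids multiple-definition link errors when utils.h is included from more than one translation unit, which fits the added #pragma once guard.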