Skip to content

Commit 0029c32

Browse files
author
Vincent Moens
authored
[Quality] Split utils.h and utils.cpp (#2348)
1 parent 59d2ae1 commit 0029c32

File tree

2 files changed

+65
-45
lines changed

2 files changed

+65
-45
lines changed

torchrl/csrc/utils.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
//
3+
// This source code is licensed under the MIT license found in the
4+
// LICENSE file in the root directory of this source tree.
5+
// utils.cpp
6+
#include "utils.h"
7+
8+
#include <iostream>
9+
// Numerically safe tanh: routes through the SafeTanh autograd function,
// which keeps the output strictly inside (-1, 1) by an `eps` margin so a
// later atanh cannot blow up.
torch::Tensor safetanh(torch::Tensor input, float eps) {
  auto result = SafeTanh::apply(input, eps);
  return result;
}
12+
// Numerically safe inverse tanh: routes through the SafeInvTanh autograd
// function, which clamps the input away from +/-1 by an `eps` margin
// before applying atanh.
torch::Tensor safeatanh(torch::Tensor input, float eps) {
  auto result = SafeInvTanh::apply(input, eps);
  return result;
}
15+
// Forward pass: tanh followed by a clamp into [-(1 - eps), 1 - eps], so the
// result never reaches exactly +/-1.
torch::Tensor SafeTanh::forward(torch::autograd::AutogradContext* ctx,
                                torch::Tensor input, float eps) {
  const auto bound = 1.0 - eps;
  auto out = torch::tanh(input).clamp(-bound, bound);
  // The clamped activation is what the backward pass differentiates through.
  ctx->save_for_backward({out});
  return out;
}
23+
// Backward pass: d/dx tanh(x) = 1 - tanh(x)^2, evaluated at the saved
// (clamped) forward output.
torch::autograd::tensor_list SafeTanh::backward(
    torch::autograd::AutogradContext* ctx,
    torch::autograd::tensor_list grad_outputs) {
  auto out = ctx->get_saved_variables()[0];
  auto grad_input = grad_outputs[0] * (1 - out * out);
  // Second slot corresponds to the non-tensor `eps` argument: no gradient.
  return {grad_input, torch::Tensor()};
}
32+
// Forward pass: clamp the input into [-(1 - eps), 1 - eps] before atanh,
// so atanh never sees exactly +/-1 (where it diverges).
torch::Tensor SafeInvTanh::forward(torch::autograd::AutogradContext* ctx,
                                   torch::Tensor input, float eps) {
  const auto bound = 1.0 - eps;
  auto clamped = input.clamp(-bound, bound);
  // Backward differentiates at the clamped point, not the raw input.
  ctx->save_for_backward({clamped});
  return torch::atanh(clamped);
}
40+
// Backward pass: d/dx atanh(x) = 1 / (1 - x^2), evaluated at the saved
// (clamped) input, which keeps the denominator bounded away from zero.
torch::autograd::tensor_list SafeInvTanh::backward(
    torch::autograd::AutogradContext* ctx,
    torch::autograd::tensor_list grad_outputs) {
  auto saved_input = ctx->get_saved_variables()[0];
  auto grad_input = grad_outputs[0] / (1 - saved_input * saved_input);
  // Second slot corresponds to the non-tensor `eps` argument: no gradient.
  return {grad_input, torch::Tensor()};
}

torchrl/csrc/utils.h

Lines changed: 17 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,58 +2,30 @@
22
//
33
// This source code is licensed under the MIT license found in the
44
// LICENSE file in the root directory of this source tree.
5+
// utils.h
6+
7+
#pragma once
58

69
#include <torch/extension.h>
710
#include <torch/torch.h>
811

9-
#include <iostream>
10-
11-
using namespace torch::autograd;
12+
torch::Tensor safetanh(torch::Tensor input, float eps = 1e-6);
13+
torch::Tensor safeatanh(torch::Tensor input, float eps = 1e-6);
1214

13-
/// Custom autograd function implementing a numerically safe tanh.
/// The forward output is clamped into [-(1 - eps), 1 - eps] before being
/// saved and returned; backward applies the tanh derivative at the clamped
/// value. Definitions live in utils.cpp.
class SafeTanh : public torch::autograd::Function<SafeTanh> {
 public:
  static torch::Tensor forward(torch::autograd::AutogradContext* ctx,
                               torch::Tensor input, float eps);
  static torch::autograd::tensor_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::tensor_list grad_outputs);
};
3223

33-
torch::Tensor safetanh(torch::Tensor input, float eps = 1e-6) {
34-
return SafeTanh::apply(input, eps);
35-
}
36-
37-
/// Custom autograd function implementing a numerically safe inverse tanh.
/// The input is clamped into [-(1 - eps), 1 - eps] before atanh, so the
/// forward value and the 1/(1 - x^2) backward factor stay finite.
/// Definitions live in utils.cpp.
class SafeInvTanh : public torch::autograd::Function<SafeInvTanh> {
 public:
  static torch::Tensor forward(torch::autograd::AutogradContext* ctx,
                               torch::Tensor input, float eps);
  static torch::autograd::tensor_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::tensor_list grad_outputs);
};
56-
57-
torch::Tensor safeatanh(torch::Tensor input, float eps = 1e-6) {
58-
return SafeInvTanh::apply(input, eps);
59-
}

0 commit comments

Comments
 (0)