Skip to content

Commit 381c658

Browse files
committed
refactor: format code
1 parent 91c0567 commit 381c658

File tree

10 files changed

+490
-317
lines changed

10 files changed

+490
-317
lines changed

.clang-format

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
11
BasedOnStyle: Google
2+
ColumnLimit: 100
3+
BinPackArguments: false
4+
BinPackParameters: false

include/adam_op/adam_op.h

Lines changed: 49 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -21,40 +21,55 @@
2121
#include "include/common.h"
2222

2323
namespace torchopt {
24+
25+
namespace py = pybind11;
26+
2427
namespace adam_op {
25-
TensorArray<3> adamForwardInplace(const torch::Tensor& updates,
26-
const torch::Tensor& mu,
27-
const torch::Tensor& nu, const pyfloat_t& b1,
28-
const pyfloat_t& b2, const pyfloat_t& eps,
29-
const pyfloat_t& eps_root,
30-
const pyuint_t& count);
31-
32-
torch::Tensor adamForwardMu(const torch::Tensor& updates,
33-
const torch::Tensor& mu, const pyfloat_t& b1);
34-
35-
torch::Tensor adamForwardNu(const torch::Tensor& updates,
36-
const torch::Tensor& nu, const pyfloat_t& b2);
37-
38-
torch::Tensor adamForwardUpdates(const torch::Tensor& new_mu,
39-
const torch::Tensor& new_nu,
40-
const pyfloat_t& b1, const pyfloat_t& b2,
41-
const pyfloat_t& eps,
42-
const pyfloat_t& eps_root,
43-
const pyuint_t& count);
44-
45-
TensorArray<2> adamBackwardMu(const torch::Tensor& dmu,
46-
const torch::Tensor& updates,
47-
const torch::Tensor& mu, const pyfloat_t& b1);
48-
49-
TensorArray<2> adamBackwardNu(const torch::Tensor& dnu,
50-
const torch::Tensor& updates,
51-
const torch::Tensor& nu, const pyfloat_t& b2);
52-
53-
TensorArray<2> adamBackwardUpdates(const torch::Tensor& dupdates,
54-
const torch::Tensor& updates,
55-
const torch::Tensor& new_mu,
56-
const torch::Tensor& new_nu,
57-
const pyfloat_t& b1, const pyfloat_t& b2,
58-
const pyuint_t& count);
28+
29+
TensorArray<3> adamForwardInplace(const torch::Tensor &updates,
30+
const torch::Tensor &mu,
31+
const torch::Tensor &nu,
32+
const pyfloat_t &b1,
33+
const pyfloat_t &b2,
34+
const pyfloat_t &eps,
35+
const pyfloat_t &eps_root,
36+
const pyuint_t &count);
37+
38+
torch::Tensor adamForwardMu(const torch::Tensor &updates,
39+
const torch::Tensor &mu,
40+
const pyfloat_t &b1);
41+
42+
torch::Tensor adamForwardNu(const torch::Tensor &updates,
43+
const torch::Tensor &nu,
44+
const pyfloat_t &b2);
45+
46+
torch::Tensor adamForwardUpdates(const torch::Tensor &new_mu,
47+
const torch::Tensor &new_nu,
48+
const pyfloat_t &b1,
49+
const pyfloat_t &b2,
50+
const pyfloat_t &eps,
51+
const pyfloat_t &eps_root,
52+
const pyuint_t &count);
53+
54+
TensorArray<2> adamBackwardMu(const torch::Tensor &dmu,
55+
const torch::Tensor &updates,
56+
const torch::Tensor &mu,
57+
const pyfloat_t &b1);
58+
59+
TensorArray<2> adamBackwardNu(const torch::Tensor &dnu,
60+
const torch::Tensor &updates,
61+
const torch::Tensor &nu,
62+
const pyfloat_t &b2);
63+
64+
TensorArray<2> adamBackwardUpdates(const torch::Tensor &dupdates,
65+
const torch::Tensor &updates,
66+
const torch::Tensor &new_mu,
67+
const torch::Tensor &new_nu,
68+
const pyfloat_t &b1,
69+
const pyfloat_t &b2,
70+
const pyuint_t &count);
71+
72+
void buildSubmodule(pybind11::module &mod);
73+
5974
} // namespace adam_op
6075
} // namespace torchopt

include/adam_op/adam_op_impl_cpu.h

Lines changed: 42 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -22,37 +22,47 @@
2222

2323
namespace torchopt {
2424
namespace adam_op {
25-
TensorArray<3> adamForwardInplaceCPU(
26-
const torch::Tensor& updates, const torch::Tensor& mu,
27-
const torch::Tensor& nu, const pyfloat_t& b1, const pyfloat_t& b2,
28-
const pyfloat_t& eps, const pyfloat_t& eps_root, const pyuint_t& count);
29-
30-
torch::Tensor adamForwardMuCPU(const torch::Tensor& updates,
31-
const torch::Tensor& mu, const pyfloat_t& b1);
32-
33-
torch::Tensor adamForwardNuCPU(const torch::Tensor& updates,
34-
const torch::Tensor& nu, const pyfloat_t& b2);
35-
36-
torch::Tensor adamForwardUpdatesCPU(const torch::Tensor& new_mu,
37-
const torch::Tensor& new_nu,
38-
const pyfloat_t& b1, const pyfloat_t& b2,
39-
const pyfloat_t& eps,
40-
const pyfloat_t& eps_root,
41-
const pyuint_t& count);
42-
43-
TensorArray<2> adamBackwardMuCPU(const torch::Tensor& dmu,
44-
const torch::Tensor& updates,
45-
const torch::Tensor& mu, const pyfloat_t& b1);
46-
47-
TensorArray<2> adamBackwardNuCPU(const torch::Tensor& dnu,
48-
const torch::Tensor& updates,
49-
const torch::Tensor& nu, const pyfloat_t& b2);
50-
51-
TensorArray<2> adamBackwardUpdatesCPU(const torch::Tensor& dupdates,
52-
const torch::Tensor& updates,
53-
const torch::Tensor& new_mu,
54-
const torch::Tensor& new_nu,
55-
const pyfloat_t& b1, const pyfloat_t& b2,
56-
const pyuint_t& count);
25+
TensorArray<3> adamForwardInplaceCPU(const torch::Tensor &updates,
26+
const torch::Tensor &mu,
27+
const torch::Tensor &nu,
28+
const pyfloat_t &b1,
29+
const pyfloat_t &b2,
30+
const pyfloat_t &eps,
31+
const pyfloat_t &eps_root,
32+
const pyuint_t &count);
33+
34+
torch::Tensor adamForwardMuCPU(const torch::Tensor &updates,
35+
const torch::Tensor &mu,
36+
const pyfloat_t &b1);
37+
38+
torch::Tensor adamForwardNuCPU(const torch::Tensor &updates,
39+
const torch::Tensor &nu,
40+
const pyfloat_t &b2);
41+
42+
torch::Tensor adamForwardUpdatesCPU(const torch::Tensor &new_mu,
43+
const torch::Tensor &new_nu,
44+
const pyfloat_t &b1,
45+
const pyfloat_t &b2,
46+
const pyfloat_t &eps,
47+
const pyfloat_t &eps_root,
48+
const pyuint_t &count);
49+
50+
TensorArray<2> adamBackwardMuCPU(const torch::Tensor &dmu,
51+
const torch::Tensor &updates,
52+
const torch::Tensor &mu,
53+
const pyfloat_t &b1);
54+
55+
TensorArray<2> adamBackwardNuCPU(const torch::Tensor &dnu,
56+
const torch::Tensor &updates,
57+
const torch::Tensor &nu,
58+
const pyfloat_t &b2);
59+
60+
TensorArray<2> adamBackwardUpdatesCPU(const torch::Tensor &dupdates,
61+
const torch::Tensor &updates,
62+
const torch::Tensor &new_mu,
63+
const torch::Tensor &new_nu,
64+
const pyfloat_t &b1,
65+
const pyfloat_t &b2,
66+
const pyuint_t &count);
5767
} // namespace adam_op
5868
} // namespace torchopt

include/adam_op/adam_op_impl_cuda.cuh

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,37 +22,47 @@
2222

2323
namespace torchopt {
2424
namespace adam_op {
25-
TensorArray<3> adamForwardInplaceCUDA(
26-
const torch::Tensor &updates, const torch::Tensor &mu,
27-
const torch::Tensor &nu, const pyfloat_t &b1, const pyfloat_t &b2,
28-
const pyfloat_t &eps, const pyfloat_t &eps_root, const pyuint_t &count);
25+
TensorArray<3> adamForwardInplaceCUDA(const torch::Tensor &updates,
26+
const torch::Tensor &mu,
27+
const torch::Tensor &nu,
28+
const pyfloat_t &b1,
29+
const pyfloat_t &b2,
30+
const pyfloat_t &eps,
31+
const pyfloat_t &eps_root,
32+
const pyuint_t &count);
2933

3034
torch::Tensor adamForwardMuCUDA(const torch::Tensor &updates,
31-
const torch::Tensor &mu, const pyfloat_t &b1);
35+
const torch::Tensor &mu,
36+
const pyfloat_t &b1);
3237

3338
torch::Tensor adamForwardNuCUDA(const torch::Tensor &updates,
34-
const torch::Tensor &nu, const pyfloat_t &b2);
39+
const torch::Tensor &nu,
40+
const pyfloat_t &b2);
3541

3642
torch::Tensor adamForwardUpdatesCUDA(const torch::Tensor &new_mu,
3743
const torch::Tensor &new_nu,
38-
const pyfloat_t &b1, const pyfloat_t &b2,
44+
const pyfloat_t &b1,
45+
const pyfloat_t &b2,
3946
const pyfloat_t &eps,
4047
const pyfloat_t &eps_root,
4148
const pyuint_t &count);
4249

4350
TensorArray<2> adamBackwardMuCUDA(const torch::Tensor &dmu,
4451
const torch::Tensor &updates,
45-
const torch::Tensor &mu, const pyfloat_t &b1);
52+
const torch::Tensor &mu,
53+
const pyfloat_t &b1);
4654

4755
TensorArray<2> adamBackwardNuCUDA(const torch::Tensor &dnu,
4856
const torch::Tensor &updates,
49-
const torch::Tensor &nu, const pyfloat_t &b2);
57+
const torch::Tensor &nu,
58+
const pyfloat_t &b2);
5059

5160
TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor &dupdates,
5261
const torch::Tensor &updates,
5362
const torch::Tensor &new_mu,
5463
const torch::Tensor &new_nu,
55-
const pyfloat_t &b1, const pyfloat_t &b2,
64+
const pyfloat_t &b1,
65+
const pyfloat_t &b2,
5666
const pyuint_t &count);
5767
} // namespace adam_op
5868
} // namespace torchopt

include/common.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,6 @@ using pyfloat_t = double;
2323
using pyuint_t = std::size_t;
2424

2525
namespace torchopt {
26-
template <size_t _Nm>
27-
using TensorArray = std::array<torch::Tensor, _Nm>;
26+
template <size_t N>
27+
using TensorArray = std::array<torch::Tensor, N>;
2828
}

include/utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
#endif
2424

2525
namespace torchopt {
26-
__forceinline__ size_t getTensorPlainSize(const torch::Tensor& tensor) {
26+
__forceinline__ size_t getTensorPlainSize(const torch::Tensor &tensor) {
2727
const auto dim = tensor.dim();
2828
size_t n = 1;
2929
for (std::decay_t<decltype(dim)> i = 0; i < dim; ++i) {

0 commit comments

Comments
 (0)