1- #include < torch/torch .h>
1+ #include < torch/extension .h>
22
33#include < vector>
44
55// s'(z) = (1 - s(z)) * s(z)
6- at ::Tensor d_sigmoid (at ::Tensor z) {
7- auto s = at ::sigmoid (z);
6+ torch ::Tensor d_sigmoid (torch ::Tensor z) {
7+ auto s = torch ::sigmoid (z);
88 return (1 - s) * s;
99}
1010
1111// tanh'(z) = 1 - tanh^2(z)
12- at ::Tensor d_tanh (at ::Tensor z) {
12+ torch ::Tensor d_tanh (torch ::Tensor z) {
1313 return 1 - z.tanh ().pow (2 );
1414}
1515
1616// elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0}
17- at ::Tensor d_elu (at ::Tensor z, at ::Scalar alpha = 1.0 ) {
17+ torch ::Tensor d_elu (torch ::Tensor z, torch ::Scalar alpha = 1.0 ) {
1818 auto e = z.exp ();
1919 auto mask = (alpha * (e - 1 )) < 0 ;
2020 return (z > 0 ).type_as (z) + mask.type_as (z) * (alpha * e);
2121}
2222
23- std::vector<at ::Tensor> lltm_forward (
24- at ::Tensor input,
25- at ::Tensor weights,
26- at ::Tensor bias,
27- at ::Tensor old_h,
28- at ::Tensor old_cell) {
29- auto X = at ::cat ({old_h, input}, /* dim=*/ 1 );
23+ std::vector<torch ::Tensor> lltm_forward (
24+ torch ::Tensor input,
25+ torch ::Tensor weights,
26+ torch ::Tensor bias,
27+ torch ::Tensor old_h,
28+ torch ::Tensor old_cell) {
29+ auto X = torch ::cat ({old_h, input}, /* dim=*/ 1 );
3030
31- auto gate_weights = at ::addmm (bias, X, weights.transpose (0 , 1 ));
31+ auto gate_weights = torch ::addmm (bias, X, weights.transpose (0 , 1 ));
3232 auto gates = gate_weights.chunk (3 , /* dim=*/ 1 );
3333
34- auto input_gate = at ::sigmoid (gates[0 ]);
35- auto output_gate = at ::sigmoid (gates[1 ]);
36- auto candidate_cell = at ::elu (gates[2 ], /* alpha=*/ 1.0 );
34+ auto input_gate = torch ::sigmoid (gates[0 ]);
35+ auto output_gate = torch ::sigmoid (gates[1 ]);
36+ auto candidate_cell = torch ::elu (gates[2 ], /* alpha=*/ 1.0 );
3737
3838 auto new_cell = old_cell + candidate_cell * input_gate;
39- auto new_h = at ::tanh (new_cell) * output_gate;
39+ auto new_h = torch ::tanh (new_cell) * output_gate;
4040
4141 return {new_h,
4242 new_cell,
@@ -47,17 +47,17 @@ std::vector<at::Tensor> lltm_forward(
4747 gate_weights};
4848}
4949
50- std::vector<at ::Tensor> lltm_backward (
51- at ::Tensor grad_h,
52- at ::Tensor grad_cell,
53- at ::Tensor new_cell,
54- at ::Tensor input_gate,
55- at ::Tensor output_gate,
56- at ::Tensor candidate_cell,
57- at ::Tensor X,
58- at ::Tensor gate_weights,
59- at ::Tensor weights) {
60- auto d_output_gate = at ::tanh (new_cell) * grad_h;
50+ std::vector<torch ::Tensor> lltm_backward (
51+ torch ::Tensor grad_h,
52+ torch ::Tensor grad_cell,
53+ torch ::Tensor new_cell,
54+ torch ::Tensor input_gate,
55+ torch ::Tensor output_gate,
56+ torch ::Tensor candidate_cell,
57+ torch ::Tensor X,
58+ torch ::Tensor gate_weights,
59+ torch ::Tensor weights) {
60+ auto d_output_gate = torch ::tanh (new_cell) * grad_h;
6161 auto d_tanh_new_cell = output_gate * grad_h;
6262 auto d_new_cell = d_tanh (new_cell) * d_tanh_new_cell + grad_cell;
6363
@@ -71,7 +71,7 @@ std::vector<at::Tensor> lltm_backward(
7171 d_candidate_cell *= d_elu (gates[2 ]);
7272
7373 auto d_gates =
74- at ::cat ({d_input_gate, d_output_gate, d_candidate_cell}, /* dim=*/ 1 );
74+ torch ::cat ({d_input_gate, d_output_gate, d_candidate_cell}, /* dim=*/ 1 );
7575
7676 auto d_weights = d_gates.t ().mm (X);
7777 auto d_bias = d_gates.sum (/* dim=*/ 0 , /* keepdim=*/ true );
0 commit comments