-
Notifications
You must be signed in to change notification settings - Fork 65
/
Copy pathtree_LSTM.h
130 lines (101 loc) · 2.45 KB
/
tree_LSTM.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
//Tree LSTM for nodes with exactly 2 children
#ifndef TREE_LSTM_H
#define TREE_LSTM_H
//Forward declaration only: tree_LSTM refers to its model solely through
//pointers, so the full encoder_multi_source definition is not needed here.
template <class dType>
class encoder_multi_source;
template<typename dType>
class tree_LSTM {
public:
int device_number;
cudaStream_t s0;
cublasHandle_t handle;
int LSTM_size;
int minibatch_size;
bool clip_gradients = false;
dType norm_clip;
//hidden and cell states
dType *d_child_ht_1;
dType *d_child_ht_2;
dType *d_child_ct_1;
dType *d_child_ct_2;
dType *d_ones_minibatch;
//parameters
//biases
dType *d_b_i;
dType *d_b_f; //initialize to one
dType *d_b_o;
dType *d_b_c;
//for hidden states
dType *d_M_i_1;
dType *d_M_f_1;
dType *d_M_o_1;
dType *d_M_c_1;
dType *d_M_i_2;
dType *d_M_f_2;
dType *d_M_o_2;
dType *d_M_c_2;
//biases
dType *d_b_i_grad;
dType *d_b_f_grad; //initialize to one
dType *d_b_o_grad;
dType *d_b_c_grad;
//for hidden states
dType *d_M_i_1_grad;
dType *d_M_f_1_grad;
dType *d_M_o_1_grad;
dType *d_M_c_1_grad;
dType *d_M_i_2_grad;
dType *d_M_f_2_grad;
dType *d_M_o_2_grad;
dType *d_M_c_2_grad;
//forward prop values
dType *d_i_t;
dType *d_f_t_1;
dType *d_f_t_2;
dType *d_c_prime_t_tanh;
dType *d_o_t;
dType *d_c_t;
dType *d_h_t;
//temp stuff
dType *d_temp1;
dType *d_temp2;
dType *d_temp3;
dType *d_temp4;
dType *d_temp5;
dType *d_temp6;
dType *d_temp7;
dType *d_temp8;
//backprop errors
dType *d_d_ERRnTOt_ht; //future h_t error stored here
dType *d_d_ERRnTOtp1_ct; //future c_t error stored here
dType *d_d_ERRt_ct; //cell error with tree LSTM stored here
dType *d_d_ERRnTOt_ct; //sum of the two cell errors
dType *d_d_ERRnTOt_it;
dType *d_d_ERRnTOt_ot;
dType *d_d_ERRnTOt_ft_1;
dType *d_d_ERRnTOt_ft_2;
dType *d_d_ERRnTOt_tanhcpt;
//for children hidden states
dType *d_d_ERRnTOt_h1;
dType *d_d_ERRnTOt_h2;
dType *d_d_ERRnTOt_c1;
dType *d_d_ERRnTOt_c2;
dType *d_temp_result;
dType *d_result;
encoder_multi_source<dType> *model;
//for training
tree_LSTM(global_params ¶ms,int device_number,encoder_multi_source<dType> *model);
//for decoding
tree_LSTM(int LSTM_size,int device_number,encoder_multi_source<dType> *model);
void forward();
void backward();
void clear_gradients();
void check_all_gradients(dType epsilon);
void update_weights();
void calculate_global_norm();
void dump_weights(std::ofstream &output);
void load_weights(std::ifstream &input);
void update_global_params();
void check_gradient_GPU(dType epsilon,dType *d_mat,dType *d_grad,int rows,int cols,int gpu_index);
};
#endif