Skip to content

Commit f4f8037

Browse files
committed
Upload _Dist/NeuralNetworks/h_RNN
1 parent 1820f38 commit f4f8037

File tree

2 files changed

+228
-0
lines changed

2 files changed

+228
-0
lines changed

_Dist/NeuralNetworks/h_RNN/Cell.py

+114
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import os
2+
import sys
3+
root_path = os.path.abspath("../../../")
4+
if root_path not in sys.path:
5+
sys.path.append(root_path)
6+
7+
import tensorflow as tf
8+
from tensorflow.contrib.rnn import LSTMStateTuple
9+
10+
from _Dist.NeuralNetworks.NNUtil import DNDF
11+
12+
13+
class LSTMCell(tf.contrib.rnn.LSTMCell):
14+
def __str__(self):
15+
return "LSTMCell"
16+
17+
18+
class BasicLSTMCell(tf.contrib.rnn.BasicLSTMCell):
19+
def __str__(self):
20+
return "BasicLSTMCell"
21+
22+
23+
class CustomLSTMCell(tf.contrib.rnn.BasicRNNCell):
24+
def __init__(self, *args, **kwargs):
25+
super(CustomLSTMCell, self).__init__(*args, **kwargs)
26+
self._n_batch_placeholder = tf.placeholder(tf.int32, [], "n_batch_placeholder")
27+
28+
def __str__(self):
29+
return "CustomLSTMCell"
30+
31+
def __call__(self, x, state, scope="LSTM"):
32+
with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
33+
s_old, h_old = state
34+
net = tf.concat([x, s_old], 1)
35+
36+
w = tf.get_variable(
37+
"W", [net.shape[1].value, 4 * self._num_units], tf.float32,
38+
tf.contrib.layers.xavier_initializer()
39+
)
40+
b = tf.get_variable(
41+
"b", [4 * self._num_units], tf.float32,
42+
tf.zeros_initializer()
43+
)
44+
gates = tf.nn.xw_plus_b(net, w, b)
45+
46+
r1, g1, g2, g3 = tf.split(gates, 4, 1)
47+
r1, g1, g3 = tf.nn.sigmoid(r1), tf.nn.sigmoid(g1), tf.nn.sigmoid(g3)
48+
g2 = tf.nn.tanh(g2)
49+
h_new = h_old * r1 + g1 * g2
50+
s_new = tf.nn.tanh(h_new) * g3
51+
return s_new, LSTMStateTuple(s_new, h_new)
52+
53+
@property
54+
def state_size(self):
55+
return LSTMStateTuple(self._num_units, self._num_units)
56+
57+
58+
class DNDFCell(tf.contrib.rnn.BasicRNNCell):
59+
def __init__(self, *args, **kwargs):
60+
self._dndf = DNDF(reuse=True)
61+
self.n_batch_placeholder = kwargs.pop("n_batch_placeholder", None)
62+
if self.n_batch_placeholder is None:
63+
self.n_batch_placeholder = tf.placeholder(tf.int32, name="n_batch_placeholder")
64+
super(DNDFCell, self).__init__(*args, **kwargs)
65+
66+
def __str__(self):
67+
return "DNDFCell"
68+
69+
def __call__(self, x, state, scope="DNDFCell"):
70+
with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
71+
s_old, h_old = state
72+
net = tf.concat([
73+
x,
74+
s_old,
75+
self._dndf(s_old, self.n_batch_placeholder, "feature"),
76+
], 1)
77+
78+
w = tf.get_variable(
79+
"W", [net.shape[1].value, 4 * self._num_units], tf.float32,
80+
tf.contrib.layers.xavier_initializer()
81+
)
82+
b = tf.get_variable(
83+
"b", [4 * self._num_units], tf.float32,
84+
tf.zeros_initializer()
85+
)
86+
gates = tf.nn.xw_plus_b(net, w, b)
87+
88+
r1, g1, g2, g3 = tf.split(gates, 4, 1)
89+
r1, g1, g3 = tf.nn.sigmoid(r1), tf.nn.sigmoid(g1), tf.nn.sigmoid(g3)
90+
g2 = tf.nn.tanh(g2)
91+
h_new = h_old * r1 + g1 * g2
92+
s_new = tf.nn.tanh(h_new) * g3
93+
94+
return s_new, LSTMStateTuple(s_new, h_new)
95+
96+
@property
97+
def state_size(self):
98+
return LSTMStateTuple(self._num_units, self._num_units)
99+
100+
101+
class CellFactory:
102+
@staticmethod
103+
def get_cell(name, n_hidden, **kwargs):
104+
if name == "LSTM" or name == "LSTMCell":
105+
cell = LSTMCell
106+
elif name == "BasicLSTM" or name == "BasicLSTMCell":
107+
cell = BasicLSTMCell
108+
elif name == "CustomLSTM" or name == "CustomLSTMCell":
109+
cell = CustomLSTMCell
110+
elif name == "DNDF" or name == "DNDFCell":
111+
cell = DNDFCell
112+
else:
113+
raise NotImplementedError("Cell '{}' not implemented".format(name))
114+
return cell(n_hidden, **kwargs)

_Dist/NeuralNetworks/h_RNN/RNN.py

+114
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import os
2+
import sys
3+
root_path = os.path.abspath("../../../")
4+
if root_path not in sys.path:
5+
sys.path.append(root_path)
6+
7+
import numpy as np
8+
import tensorflow as tf
9+
10+
from _Dist.NeuralNetworks.NNUtil import Toolbox
11+
from _Dist.NeuralNetworks.Base import Generator3d
12+
from _Dist.NeuralNetworks.c_BasicNN.NN import Basic
13+
from _Dist.NeuralNetworks.h_RNN.Cell import CellFactory
14+
15+
16+
class Basic3d(Basic):
17+
def _gen_batch(self, generator, n_batch, gen_random_subset=False, one_hot=False):
18+
if gen_random_subset:
19+
data, weights = generator.gen_random_subset(n_batch)
20+
else:
21+
data, weights = generator.gen_batch(n_batch)
22+
x = np.array([d[0] for d in data], np.float32)
23+
y = np.array([d[1] for d in data], np.float32)
24+
if not one_hot:
25+
return x, y, weights
26+
if self.n_class == 1:
27+
y = y.reshape([-1, 1])
28+
else:
29+
y = Toolbox.get_one_hot(y, self.n_class)
30+
return x, y, weights
31+
32+
33+
class RNN(Basic3d):
34+
def __init__(self, *args, **kwargs):
35+
self.n_time_step = kwargs.pop("n_time_step", None)
36+
37+
super(RNN, self).__init__(*args, **kwargs)
38+
self._name_appendix = "RNN"
39+
self._generator_base = Generator3d
40+
41+
self._using_dndf_cell = False
42+
self._n_batch_placeholder = None
43+
self._cell = self._cell_name = None
44+
self.n_hidden = self.n_history = self.use_final_state = None
45+
46+
def init_model_param_settings(self):
47+
super(RNN, self).init_model_param_settings()
48+
self._cell_name = self.model_param_settings.get("cell", "CustomLSTM")
49+
50+
def init_model_structure_settings(self):
51+
super(RNN, self).init_model_structure_settings()
52+
self.n_hidden = self.model_structure_settings.get("n_hidden", 128)
53+
self.n_history = self.model_structure_settings.get("n_history", 0)
54+
self.use_final_state = self.model_structure_settings.get("use_final_state", True)
55+
56+
def init_from_data(self, x, y, x_test, y_test, sample_weights, names):
57+
if self.n_time_step is None:
58+
assert len(x.shape) == 3, "n_time_step is not provided, hence len(x.shape) should be 3"
59+
self.n_time_step = x.shape[1]
60+
if len(x.shape) == 2:
61+
x = x.reshape(len(x), self.n_time_step, -1)
62+
else:
63+
assert self.n_time_step == x.shape[1], "n_time_step is set to be {}, but {} found".format(
64+
self.n_time_step, x.shape[1]
65+
)
66+
if len(x_test.shape) == 2:
67+
x_test = x_test.reshape(len(x_test), self.n_time_step, -1)
68+
super(RNN, self).init_from_data(x, y, x_test, y_test, sample_weights, names)
69+
70+
def _define_input_and_placeholder(self):
71+
self._is_training = tf.placeholder(tf.bool, name="is_training")
72+
self._tfx = tf.placeholder(tf.float32, [None, self.n_time_step, self.n_dim], name="X")
73+
self._tfy = tf.placeholder(tf.float32, [None, self.n_class], name="Y")
74+
75+
def _build_model(self, net=None):
76+
self._model_built = True
77+
if net is None:
78+
net = self._tfx
79+
80+
self._cell = CellFactory.get_cell(self._cell_name, self.n_hidden)
81+
if "DNDF" in self._cell_name:
82+
self._using_dndf_cell = True
83+
self._n_batch_placeholder = self._cell.n_batch_placeholder
84+
85+
initial_state = self._cell.zero_state(tf.shape(net)[0], tf.float32)
86+
rnn_outputs, rnn_final_state = tf.nn.dynamic_rnn(self._cell, net, initial_state=initial_state)
87+
88+
if self.n_history == 0:
89+
net = None
90+
elif self.n_history == 1:
91+
net = rnn_outputs[..., -1, :]
92+
else:
93+
net = rnn_outputs[..., -self.n_history:, :]
94+
net = tf.reshape(net, [-1, self.n_history * int(net.shape[2].value)])
95+
if self.use_final_state:
96+
if net is None:
97+
net = rnn_final_state[1]
98+
else:
99+
net = tf.concat([net, rnn_final_state[1]], axis=1)
100+
return super(RNN, self)._build_model(net)
101+
102+
def _get_feed_dict(self, x, y=None, weights=None, is_training=False):
103+
feed_dict = super(RNN, self)._get_feed_dict(x, y, weights, is_training)
104+
if self._using_dndf_cell:
105+
feed_dict[self._n_batch_placeholder] = len(x)
106+
return feed_dict
107+
108+
def _define_py_collections(self):
109+
super(RNN, self)._define_py_collections()
110+
self.py_collections.append("_using_dndf_cell")
111+
112+
def _define_tf_collections(self):
113+
super(RNN, self)._define_tf_collections()
114+
self.tf_collections.append("_n_batch_placeholder")

0 commit comments

Comments
 (0)