Skip to content

Commit 9a60fa3

Browse files
committed
Update Zhihu NN
1 parent 35020ad commit 9a60fa3

File tree

3 files changed

+47
-79
lines changed

3 files changed

+47
-79
lines changed

.idea/workspace.xml

+43-49
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Zhihu/NN/one/Network.py

+4-30
Original file line numberDiff line numberDiff line change
@@ -89,17 +89,6 @@ def _get_prediction(self, x, out_of_sess=False):
8989
with self._sess.as_default():
9090
return self.get_rs(x).eval(feed_dict={self._tfx: x})
9191

92-
@NNTiming.timeit(level=4)
93-
def _get_activations(self, x):
94-
_activations = [self._layers[0].activate(x, self._tf_weights[0], self._tf_bias[0], True)]
95-
for i, layer in enumerate(self._layers[1:]):
96-
if i == len(self._layers) - 2:
97-
_activations.append(tf.matmul(_activations[-1], self._tf_weights[-1]) + self._tf_bias[-1])
98-
else:
99-
_activations.append(layer.activate(
100-
_activations[-1], self._tf_weights[i + 1], self._tf_bias[i + 1], True))
101-
return _activations
102-
10392
@NNTiming.timeit(level=4)
10493
def _get_l2_loss(self, lb):
10594
if lb <= 0:
@@ -109,34 +98,19 @@ def _get_l2_loss(self, lb):
10998
# API
11099

111100
@NNTiming.timeit(level=1, prefix="[API] ")
112-
def fit(self, x=None, y=None, lr=0.001, lb=0.001, epoch=10, batch_size=512):
101+
def fit(self, x=None, y=None, lr=0.001, lb=0.001, epoch=10):
113102
self._optimizer = Adam(lr)
114-
train_len = len(x)
115-
batch_size = min(batch_size, train_len)
116-
do_random_batch = train_len >= batch_size
117-
train_repeat = int(train_len / batch_size) + 1
118-
119103
self._tfx = tf.placeholder(tf.float32, shape=[None, *x.shape[1:]])
120104
self._tfy = tf.placeholder(tf.float32, shape=[None, y.shape[1]])
121-
122105
with self._sess.as_default() as sess:
123-
124-
# Session
106+
# Define session
125107
self._cost = self.get_rs(self._tfx, self._tfy) + self._get_l2_loss(lb)
126108
self._y_pred = self.get_rs(self._tfx)
127-
self._activations = self._get_activations(self._tfx)
128109
self._train_step = self._optimizer.minimize(self._cost)
129110
sess.run(tf.global_variables_initializer())
130-
111+
# Train
131112
for counter in range(epoch):
132-
for _i in range(train_repeat):
133-
if do_random_batch:
134-
batch = np.random.choice(train_len, batch_size)
135-
x_batch, y_batch = x[batch], y[batch]
136-
else:
137-
x_batch, y_batch = x, y
138-
139-
self._train_step.run(feed_dict={self._tfx: x_batch, self._tfy: y_batch})
113+
self._train_step.run(feed_dict={self._tfx: x, self._tfy: y})
140114

141115
@NNTiming.timeit(level=4, prefix="[API] ")
142116
def predict_classes(self, x, flatten=True):
-825 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)