Commit 4d13877

Update Zhihu NN
1 parent 3cbf47a commit 4d13877

File tree

3 files changed: +109 -86 lines changed


.idea/workspace.xml

+104 -65 (generated IDE file; diff not rendered by default)

Zhihu/NN/one/Network.py

+5 -21
@@ -12,7 +12,6 @@ class NNBase:
 
     def __init__(self):
         self._layers = []
-        self._lr = 0
         self._optimizer = None
         self._current_dimension = 0
 
@@ -53,16 +52,13 @@ def _add_weight(self, shape):
 
     @NNTiming.timeit(level=1, prefix="[API] ")
     def get_rs(self, x, y=None):
-        predict = True if y is None else False
-        _cache = self._layers[0].activate(x, self._tf_weights[0], self._tf_bias[0], predict)
+        _cache = self._layers[0].activate(x, self._tf_weights[0], self._tf_bias[0])
         for i, layer in enumerate(self._layers[1:]):
             if i == len(self._layers) - 2:
                 if y is None:
-                    if self._tf_bias[-1] is not None:
-                        return tf.matmul(_cache, self._tf_weights[-1]) + self._tf_bias[-1]
-                    return tf.matmul(_cache, self._tf_weights[-1])
-                predict = y
-            _cache = layer.activate(_cache, self._tf_weights[i + 1], self._tf_bias[i + 1], predict)
+                    return tf.matmul(_cache, self._tf_weights[-1]) + self._tf_bias[-1]
+                return layer.activate(_cache, self._tf_weights[i + 1], self._tf_bias[i + 1], y)
+            _cache = layer.activate(_cache, self._tf_weights[i + 1], self._tf_bias[i + 1])
         return _cache
 
     @NNTiming.timeit(level=4, prefix="[API] ")
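For readability, this is how get_rs reads after the commit, reconstructed only from the added and context lines of the hunk above; it is a sketch of what the diff shows, not a copy of the full file:

    @NNTiming.timeit(level=1, prefix="[API] ")
    def get_rs(self, x, y=None):
        # Forward pass through the first layer.
        _cache = self._layers[0].activate(x, self._tf_weights[0], self._tf_bias[0])
        for i, layer in enumerate(self._layers[1:]):
            if i == len(self._layers) - 2:
                # Final transition: without labels, return the raw affine output
                # (a prediction); with labels, hand y to the last layer's activate.
                if y is None:
                    return tf.matmul(_cache, self._tf_weights[-1]) + self._tf_bias[-1]
                return layer.activate(_cache, self._tf_weights[i + 1], self._tf_bias[i + 1], y)
            # Hidden layers simply propagate the cached activation.
            _cache = layer.activate(_cache, self._tf_weights[i + 1], self._tf_bias[i + 1])
        return _cache

The separate predict flag is gone: whether the call is a prediction is now inferred solely from whether y was passed.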
@@ -114,19 +110,7 @@ def _get_l2_loss(self, lb):
 
     @NNTiming.timeit(level=1, prefix="[API] ")
     def fit(self, x=None, y=None, lr=0.001, lb=0.001, epoch=10, batch_size=512):
-
-        self._lr = lr
-        self._optimizer = Adam(self._lr)
-        print("Optimizer: ", self._optimizer.name)
-        print("-" * 30)
-
-        if not self._layers:
-            raise BuildNetworkError("Please provide layers before fitting data")
-
-        if y.shape[1] != self._current_dimension:
-            raise BuildNetworkError("Output layer's shape should be {}, {} found".format(
-                self._current_dimension, y.shape[1]))
-
+        self._optimizer = Adam(lr)
         train_len = len(x)
         batch_size = min(batch_size, train_len)
         do_random_batch = train_len >= batch_size
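Similarly, a sketch of the start of fit after the commit, taken from the hunk above; the rest of the training loop sits outside this hunk and is unchanged:

    @NNTiming.timeit(level=1, prefix="[API] ")
    def fit(self, x=None, y=None, lr=0.001, lb=0.001, epoch=10, batch_size=512):
        # The learning rate now feeds the optimizer directly instead of being
        # stored on the instance; the banner prints and shape checks were dropped.
        self._optimizer = Adam(lr)
        train_len = len(x)
        batch_size = min(batch_size, train_len)
        do_random_batch = train_len >= batch_size
        # ... (remainder of fit is outside this hunk)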
-373 Bytes (binary file not shown)
