@@ -12,7 +12,6 @@ class NNBase:
12
12
13
13
def __init__(self):
    """Initialize an empty network container.

    State set up here (filled in later by layer-adding / fitting code):
      _layers             -- ordered list of layer objects
      _optimizer          -- optimizer instance, created at fit time
      _current_dimension  -- output dimension of the most recently added layer
    """
    self._layers = []
    self._optimizer = None
    self._current_dimension = 0
18
17
@@ -53,16 +52,13 @@ def _add_weight(self, shape):
53
52
54
53
@NNTiming.timeit(level=1, prefix="[API] ")
def get_rs(self, x, y=None):
    """Run a forward pass through the network.

    Feeds x through every layer in order. When the final (cost) layer is
    reached:
      * if y is None (prediction mode), return the raw affine output of the
        last weight matrix instead of applying the cost layer;
      * otherwise (training mode), return the cost layer's activation
        computed against the targets y.

    :param x: input batch tensor.
    :param y: optional target tensor; None means "predict".
    :return: output tensor — raw prediction logits or cost-layer activation.
    """
    # First layer consumes the raw input.
    _cache = self._layers[0].activate(x, self._tf_weights[0], self._tf_bias[0])
    for i, layer in enumerate(self._layers[1:]):
        if i == len(self._layers) - 2:
            # Final layer reached.
            if y is None:
                # Prediction: emit the raw affine output. Guard against a
                # last layer built without a bias term — _tf_bias[-1] can
                # legitimately be None, and "matmul(...) + None" would raise.
                if self._tf_bias[-1] is not None:
                    return tf.matmul(_cache, self._tf_weights[-1]) + self._tf_bias[-1]
                return tf.matmul(_cache, self._tf_weights[-1])
            # Training: the cost layer consumes the targets directly.
            return layer.activate(_cache, self._tf_weights[i + 1], self._tf_bias[i + 1], y)
        _cache = layer.activate(_cache, self._tf_weights[i + 1], self._tf_bias[i + 1])
    return _cache
67
63
68
64
@NNTiming .timeit (level = 4 , prefix = "[API] " )
@@ -114,19 +110,7 @@ def _get_l2_loss(self, lb):
114
110
115
111
@NNTiming .timeit (level = 1 , prefix = "[API] " )
116
112
def fit (self , x = None , y = None , lr = 0.001 , lb = 0.001 , epoch = 10 , batch_size = 512 ):
117
-
118
- self ._lr = lr
119
- self ._optimizer = Adam (self ._lr )
120
- print ("Optimizer: " , self ._optimizer .name )
121
- print ("-" * 30 )
122
-
123
- if not self ._layers :
124
- raise BuildNetworkError ("Please provide layers before fitting data" )
125
-
126
- if y .shape [1 ] != self ._current_dimension :
127
- raise BuildNetworkError ("Output layer's shape should be {}, {} found" .format (
128
- self ._current_dimension , y .shape [1 ]))
129
-
113
+ self ._optimizer = Adam (lr )
130
114
train_len = len (x )
131
115
batch_size = min (batch_size , train_len )
132
116
do_random_batch = train_len >= batch_size
0 commit comments