Skip to content

Commit 8b5b768

Browse files
committed
Update _Dist/NeuralNetworks
1 parent e616949 commit 8b5b768

File tree

4 files changed

+31
-5
lines changed

4 files changed

+31
-5
lines changed

_Dist/NeuralNetworks/Base.py

-1
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,6 @@ def __init__(self, name=None, model_param_settings=None, model_structure_setting
186186
self.log = {}
187187
self._name = name
188188
self._name_appendix = ""
189-
190189
self._settings_initialized = False
191190

192191
self._generator_base = Generator

_Dist/NeuralNetworks/NNUtil.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ def prune_w(self, w, w_abs, w_abs_mean, w_abs_std):
678678
self.cursor += 1
679679
with tf.name_scope("Prune"):
680680
if self.cond_placeholder is None:
681-
log_w = tf.log(tf.maximum(self.eps, tf.abs(w) / (w_abs_mean * self.gamma)))
681+
log_w = tf.log(tf.maximum(self.eps, w_abs / (w_abs_mean * self.gamma)))
682682
if self.max_ratio > 0:
683683
log_w = tf.minimum(self.max_ratio, self.beta * log_w)
684684
self.masks.append(tf.maximum(self.alpha / self.beta * log_w, log_w))

_Dist/NeuralNetworks/d_Traditional2NN/Toolbox.py

+25
Original file line numberDiff line numberDiff line change
@@ -172,3 +172,28 @@ def _transform(self):
172172
w2 *= max_route_length
173173
self._transform_ws = [w1, w2, w3]
174174
self._transform_bs = [b]
175+
176+
177+
if __name__ == '__main__':
178+
train, test = np.load("train.npy"), np.load("test.npy")
179+
x, y = train[..., :-1], train[..., -1]
180+
x_test, y_test = test[..., :-1], test[..., -1]
181+
# x_mean, x_std = x.mean(), x.std()
182+
# x -= x_mean; x /= x_std
183+
# x_test -= x_mean; x_test /= x_std
184+
from _Dist.NeuralNetworks.e_AdvancedNN.NN import Advanced
185+
Advanced(
186+
name="madelon",
187+
data_info={
188+
"numerical_idx": [True] * 500 + [False],
189+
"categorical_columns": []
190+
},
191+
model_param_settings={
192+
"lr": 1e-3,
193+
"activations": ["relu", "relu"]
194+
}, model_structure_settings={
195+
"use_pruner": False,
196+
"use_wide_network": False,
197+
"hidden_units": [152, 153]
198+
}
199+
).fit(x, y, x_test, y_test, snapshot_ratio=1)

_Dist/NeuralNetworks/e_AdvancedNN/NN.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,11 @@ def init_from_data(self, x, y, x_test, y_test, sample_weights, names):
6161

6262
def init_model_param_settings(self):
6363
super(Advanced, self).init_model_param_settings()
64-
self.dropout_keep_prob = self.model_param_settings.get("p_keep", 0.5)
65-
self.use_batch_norm = self.model_param_settings.get("use_batch_norm", False)
64+
self.dropout_keep_prob = self.model_param_settings.get("keep_prob", 0.5)
65+
self.use_batch_norm = self.model_param_settings.get("use_batch_norm", True)
6666

6767
def init_model_structure_settings(self):
68+
self.hidden_units = self.model_structure_settings.get("hidden_units", None)
6869
self._deep_input = self.model_structure_settings.get("deep_input", "embedding_concat")
6970
self._wide_input = self.model_structure_settings.get("wide_input", "continuous")
7071
self.embedding_size = self.model_structure_settings.get("embedding_size", 8)
@@ -210,7 +211,8 @@ def _define_input_and_placeholder(self):
210211
self._deep_input = self._tfx
211212
else:
212213
self._deep_input = getattr(self, "_" + self._deep_input)
213-
self._define_hidden_units()
214+
if self.hidden_units is None:
215+
self._define_hidden_units()
214216
self._settings = "{}_{}(dndf)_{}(prune)".format(
215217
self.hidden_units, self._dndf is not None, self._pruner is not None
216218
)

0 commit comments

Comments
 (0)