Skip to content

Commit 1820f38

Browse files
committed
Update _Dist/NeuralNetworks
1 parent 84073e8 commit 1820f38

File tree

3 files changed

+50
-17
lines changed

3 files changed

+50
-17
lines changed

_Dist/NeuralNetworks/Base.py

+38-4
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __init__(self, x, y, weights=None, name="Generator", shuffle=True):
3737
self.n_class = 1
3838
self._name = name
3939
self._do_shuffle = shuffle
40-
self._all_valid_data = np.hstack([self._x, self._y.reshape([-1, 1])])
40+
self._all_valid_data = self._generate_all_valid_data()
4141
self._valid_indices = np.arange(len(self._all_valid_data))
4242
self._random_indices = self._valid_indices.copy()
4343
np.random.shuffle(self._random_indices)
@@ -72,6 +72,9 @@ def n_dim(self):
7272
def shape(self):
7373
return self.n_valid, self.n_dim
7474

75+
def _generate_all_valid_data(self):
76+
return np.hstack([self._x, self._y.reshape([-1, 1])])
77+
7578
def _cache_current_status(self):
7679
self._cache["_valid_indices"] = self._valid_indices
7780
self._cache["_random_indices"] = self._random_indices
@@ -149,6 +152,33 @@ def get_all_data(self):
149152
return self._get_data(self._valid_indices)
150153

151154

155+
class Generator3d(Generator):
156+
@property
157+
def n_time_step(self):
158+
return self._x.shape[1]
159+
160+
@property
161+
def shape(self):
162+
return self.n_valid, self.n_time_step, self.n_dim
163+
164+
def _generate_all_valid_data(self):
165+
return np.array([(x, y) for x, y in zip(self._x, self._y)])
166+
167+
168+
class Generator4d(Generator3d):
169+
@property
170+
def height(self):
171+
return self._x.shape[1]
172+
173+
@property
174+
def width(self):
175+
return self._x.shape[2]
176+
177+
@property
178+
def shape(self):
179+
return self.n_valid, self.height, self.width, self.n_dim
180+
181+
152182
class Base:
153183
def __init__(self, name=None, model_param_settings=None, model_structure_settings=None):
154184
tf.reset_default_graph()
@@ -159,6 +189,7 @@ def __init__(self, name=None, model_param_settings=None, model_structure_setting
159189

160190
self._settings_initialized = False
161191

192+
self._generator_base = Generator
162193
self._train_generator = self._test_generator = None
163194
self._sample_weights = self._tf_sample_weights = None
164195
self.n_dim = self.n_class = None
@@ -221,9 +252,9 @@ def init_from_data(self, x, y, x_test, y_test, sample_weights, names):
221252
else:
222253
self._tf_sample_weights = tf.placeholder(tf.float32, name="sample_weights")
223254

224-
self._train_generator = Generator(x, y, self._sample_weights, name="TrainGenerator")
255+
self._train_generator = self._generator_base(x, y, self._sample_weights, name="TrainGenerator")
225256
if x_test is not None and y_test is not None:
226-
self._test_generator = Generator(x_test, y_test, name="TestGenerator")
257+
self._test_generator = self._generator_base(x_test, y_test, name="TestGenerator")
227258
else:
228259
self._test_generator = None
229260
self.n_random_train_subset = int(len(self._train_generator) * 0.1)
@@ -431,7 +462,8 @@ def _define_tf_collections(self):
431462
def add_tf_collections(self):
432463
for tensor in self.tf_collections:
433464
target = getattr(self, tensor)
434-
tf.add_to_collection(tensor, target)
465+
if target is not None:
466+
tf.add_to_collection(tensor, target)
435467

436468
def clear_tf_collections(self):
437469
for key in self.tf_collections:
@@ -450,6 +482,8 @@ def restore_collections(self, folder):
450482
setattr(self, name, value)
451483
for tensor in self.tf_collections:
452484
target = tf.get_collection(tensor)
485+
if target is None:
486+
continue
453487
assert len(target) == 1, "{} available '{}' found".format(len(target), tensor)
454488
setattr(self, tensor, target[0])
455489
self.clear_tf_collections()

_Dist/NeuralNetworks/e_AdvancedNN/NN.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,7 @@ def __init__(self, name=None, data_info=None, model_param_settings=None, model_s
3737
self._use_wide_network = self._dndf = self._pruner = None
3838

3939
self._tf_p_keep = None
40-
self._n_batch_placeholder = tf.placeholder(tf.int32, name="n_batch")
41-
42-
def init_all_settings(self):
43-
super(Advanced, self).init_all_settings()
44-
self.tf_collections.append("_n_batch_placeholder")
40+
self._n_batch_placeholder = None
4541

4642
def init_data_info(self):
4743
if self._data_info_initialized:
@@ -222,6 +218,7 @@ def _define_input_and_placeholder(self):
222218
self._is_training, lambda: self.dropout_keep_prob, lambda: 1.,
223219
name="p_keep"
224220
)
221+
self._n_batch_placeholder = tf.placeholder(tf.int32, name="n_batch")
225222

226223
def _define_py_collections(self):
227224
super(Advanced, self)._define_py_collections()
@@ -230,7 +227,7 @@ def _define_py_collections(self):
230227
def _define_tf_collections(self):
231228
super(Advanced, self)._define_tf_collections()
232229
self.tf_collections += [
233-
"_deep_input", "_wide_input",
230+
"_deep_input", "_wide_input", "_n_batch_placeholder",
234231
"_embedding", "_one_hot", "_embedding_with_one_hot",
235232
"_embedding_concat", "_one_hot_concat", "_embedding_with_one_hot_concat",
236233
]
@@ -240,12 +237,15 @@ def add_tf_collections(self):
240237
super(Advanced, self).add_tf_collections()
241238
for tf_list in self.tf_list_collections:
242239
target_list = getattr(self, tf_list)
240+
if target_list is None:
241+
continue
243242
for tensor in target_list:
244243
tf.add_to_collection(tf_list, tensor)
245244

246245
def restore_collections(self, folder):
247246
for tf_list in self.tf_list_collections:
248-
setattr(self, tf_list, tf.get_collection(tf_list))
247+
if tf_list is not None:
248+
setattr(self, tf_list, tf.get_collection(tf_list))
249249
super(Advanced, self).restore_collections(folder)
250250

251251
def print_settings(self):

_Dist/NeuralNetworks/g_DistNN/NN.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import numpy as np
88
import tensorflow as tf
99

10-
from _Dist.NeuralNetworks.Base import Generator
1110
from _Dist.NeuralNetworks.f_AutoNN.NN import Auto
1211
from _Dist.NeuralNetworks.NNUtil import Toolbox, PreProcessor
1312

@@ -56,7 +55,7 @@ def rolling_fit(self, train_rate=0.8, cv_rate=0.1, sample_weights=None, **kwargs
5655
if self._test_generator is None:
5756
test_data, _ = self._train_generator.get_range(train_cursor, test_cursor)
5857
x_test, y_test = test_data[..., :-1], test_data[..., -1]
59-
self._test_generator = Generator(x_test, y_test, name="TestGenerator")
58+
self._test_generator = self._generator_base(x_test, y_test, name="TestGenerator")
6059
self._train_generator.set_range(cursor, train_cursor)
6160
kwargs["print_settings"] = print_settings
6261
self.fit(**kwargs)
@@ -96,9 +95,9 @@ def increment_fit(self, x=None, y=None, x_test=None, y_test=None, sample_weights
9695
self._handle_unbalance(y)
9796
self._handle_sparsity()
9897
if data is not None:
99-
self._train_generator = Generator(x, y, self._sample_weights, name="Generator")
98+
self._train_generator = self._generator_base(x, y, self._sample_weights, name="Generator")
10099
if x_test is not None and y_test is not None:
101-
self._test_generator = Generator(x_test, y_test, name="TestGenerator")
100+
self._test_generator = self._generator_base(x_test, y_test, name="TestGenerator")
102101
self.fit(**kwargs)
103102
x, y, _ = self._gen_batch(self._train_generator, self.n_random_train_subset, True)
104103
print(" - Performance of increment fit", end=" | ")
@@ -170,7 +169,7 @@ def k_fold(self, k=10, data=None, test_rate=0., sample_weights=None, **kwargs):
170169
self._merge_preprocessors_from_k_series(names)
171170
self._sample_weights = sample_weights_store
172171
if x_test is not None and y_test is not None:
173-
self._test_generator = Generator(x_test, y_test, name="TestGenerator")
172+
self._test_generator = self._generator_base(x_test, y_test, name="TestGenerator")
174173
return self
175174

176175
def k_random(self, k=3, data=None, cv_rate=0.1, test_rate=0., sample_weights=None, **kwargs):
@@ -201,7 +200,7 @@ def k_random(self, k=3, data=None, cv_rate=0.1, test_rate=0., sample_weights=Non
201200
self._merge_preprocessors_from_k_series(names)
202201
self._sample_weights = sample_weights_store
203202
if x_test is not None and y_test is not None:
204-
self._test_generator = Generator(x_test, y_test, name="TestGenerator")
203+
self._test_generator = self._generator_base(x_test, y_test, name="TestGenerator")
205204
return self
206205

207206

0 commit comments

Comments
 (0)