@@ -37,7 +37,7 @@ def __init__(self, x, y, weights=None, name="Generator", shuffle=True):
37
37
self .n_class = 1
38
38
self ._name = name
39
39
self ._do_shuffle = shuffle
40
- self ._all_valid_data = np . hstack ([ self ._x , self . _y . reshape ([ - 1 , 1 ])] )
40
+ self ._all_valid_data = self ._generate_all_valid_data ( )
41
41
self ._valid_indices = np .arange (len (self ._all_valid_data ))
42
42
self ._random_indices = self ._valid_indices .copy ()
43
43
np .random .shuffle (self ._random_indices )
@@ -72,6 +72,9 @@ def n_dim(self):
72
72
def shape (self ):
73
73
return self .n_valid , self .n_dim
74
74
75
+ def _generate_all_valid_data (self ):
76
+ return np .hstack ([self ._x , self ._y .reshape ([- 1 , 1 ])])
77
+
75
78
def _cache_current_status (self ):
76
79
self ._cache ["_valid_indices" ] = self ._valid_indices
77
80
self ._cache ["_random_indices" ] = self ._random_indices
@@ -149,6 +152,33 @@ def get_all_data(self):
149
152
return self ._get_data (self ._valid_indices )
150
153
151
154
155
+ class Generator3d (Generator ):
156
+ @property
157
+ def n_time_step (self ):
158
+ return self ._x .shape [1 ]
159
+
160
+ @property
161
+ def shape (self ):
162
+ return self .n_valid , self .n_time_step , self .n_dim
163
+
164
+ def _generate_all_valid_data (self ):
165
+ return np .array ([(x , y ) for x , y in zip (self ._x , self ._y )])
166
+
167
+
168
+ class Generator4d (Generator3d ):
169
+ @property
170
+ def height (self ):
171
+ return self ._x .shape [1 ]
172
+
173
+ @property
174
+ def width (self ):
175
+ return self ._x .shape [2 ]
176
+
177
+ @property
178
+ def shape (self ):
179
+ return self .n_valid , self .height , self .width , self .n_dim
180
+
181
+
152
182
class Base :
153
183
def __init__ (self , name = None , model_param_settings = None , model_structure_settings = None ):
154
184
tf .reset_default_graph ()
@@ -159,6 +189,7 @@ def __init__(self, name=None, model_param_settings=None, model_structure_setting
159
189
160
190
self ._settings_initialized = False
161
191
192
+ self ._generator_base = Generator
162
193
self ._train_generator = self ._test_generator = None
163
194
self ._sample_weights = self ._tf_sample_weights = None
164
195
self .n_dim = self .n_class = None
@@ -221,9 +252,9 @@ def init_from_data(self, x, y, x_test, y_test, sample_weights, names):
221
252
else :
222
253
self ._tf_sample_weights = tf .placeholder (tf .float32 , name = "sample_weights" )
223
254
224
- self ._train_generator = Generator (x , y , self ._sample_weights , name = "TrainGenerator" )
255
+ self ._train_generator = self . _generator_base (x , y , self ._sample_weights , name = "TrainGenerator" )
225
256
if x_test is not None and y_test is not None :
226
- self ._test_generator = Generator (x_test , y_test , name = "TestGenerator" )
257
+ self ._test_generator = self . _generator_base (x_test , y_test , name = "TestGenerator" )
227
258
else :
228
259
self ._test_generator = None
229
260
self .n_random_train_subset = int (len (self ._train_generator ) * 0.1 )
@@ -431,7 +462,8 @@ def _define_tf_collections(self):
431
462
def add_tf_collections (self ):
432
463
for tensor in self .tf_collections :
433
464
target = getattr (self , tensor )
434
- tf .add_to_collection (tensor , target )
465
+ if target is not None :
466
+ tf .add_to_collection (tensor , target )
435
467
436
468
def clear_tf_collections (self ):
437
469
for key in self .tf_collections :
@@ -450,6 +482,8 @@ def restore_collections(self, folder):
450
482
setattr (self , name , value )
451
483
for tensor in self .tf_collections :
452
484
target = tf .get_collection (tensor )
485
+ if target is None :
486
+ continue
453
487
assert len (target ) == 1 , "{} available '{}' found" .format (len (target ), tensor )
454
488
setattr (self , tensor , target [0 ])
455
489
self .clear_tf_collections ()
0 commit comments