Skip to content

Commit 064dea4

Browse files
committed
Changed the variables that are saved in checkpoint.
1 parent 0fb4123 commit 064dea4

File tree

3 files changed

+18
-19
lines changed

3 files changed

+18
-19
lines changed

neural_structured_learning/research/gam/trainer/trainer_agreement.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -622,10 +622,10 @@ def train(self, data, session=None, **kwargs):
622622
# data_iterator_train = self._train_iterator(
623623
# labeled_samples, neighbors_val, data, ratio_pos_to_neg=ratio_pos_to_neg)
624624

625-
labeled_nodes_train, labeled_nodes_val = self._select_val_set_v2(
625+
labeled_samples_train, labeled_nodes_val = self._select_val_samples(
626626
labeled_samples, self.ratio_val)
627-
data_iterator_train = self._pair_iterator_v2(labeled_nodes_train, data,
628-
ratio_pos_neg=ratio_pos_to_neg)
627+
data_iterator_train = self._pair_iterator(labeled_samples_train, data,
628+
ratio_pos_neg=ratio_pos_to_neg)
629629

630630
# Start training.
631631
best_val_acc = -1
@@ -667,7 +667,7 @@ def train(self, data, session=None, **kwargs):
667667
# shuffle=False,
668668
# allow_smaller_batch=True,
669669
# repeat=False)
670-
data_iterator_val = self._pair_iterator_v2(labeled_nodes_val, data)
670+
data_iterator_val = self._pair_iterator(labeled_nodes_val, data)
671671
feed_dict_val = self._construct_feed_dict(
672672
data_iterator_val, is_train=False)
673673
cummulative_val_acc = 0.0
@@ -883,7 +883,7 @@ def predict_label_by_agreement(self, session, indices, num_neighbors=100):
883883
logging.info('Majority vote accuracy: %.2f.', acc)
884884
return acc
885885

886-
def _pair_iterator_v2(self, labeled_nodes, data, ratio_pos_neg=None):
886+
def _pair_iterator(self, labeled_nodes, data, ratio_pos_neg=None):
887887
# TODO: add documentation and rename neighbors to samples.
888888
neighbors_batch = np.empty(shape=(self.batch_size, 2), dtype=np.int32)
889889
agreement_batch = np.empty(shape=(self.batch_size,), dtype=np.float32)
@@ -908,14 +908,14 @@ def _pair_iterator_v2(self, labeled_nodes, data, ratio_pos_neg=None):
908908
num_added += 1
909909
yield neighbors_batch, agreement_batch
910910

911-
def _select_val_set_v2(self, labeled_nodes, percent_val):
912-
# TODO: rename and add documentation.
913-
num_labeled_nodes = labeled_nodes.shape[0]
914-
num_labeled_nodes_val = int(num_labeled_nodes * percent_val)
915-
self.rng.shuffle(labeled_nodes)
916-
labeled_nodes_val = labeled_nodes[:num_labeled_nodes_val]
917-
labeled_nodes_train = labeled_nodes[num_labeled_nodes_val:]
918-
return labeled_nodes_train, labeled_nodes_val
911+
def _select_val_samples(self, labeled_samples, ratio_val):
912+
# TODO: add documentation.
913+
num_labeled_samples = labeled_samples.shape[0]
914+
num_labeled_samples_val = int(num_labeled_samples * ratio_val)
915+
self.rng.shuffle(labeled_samples)
916+
labeled_samples_val = labeled_samples[:num_labeled_samples_val]
917+
labeled_samples_train = labeled_samples[num_labeled_samples_val:]
918+
return labeled_samples_train, labeled_samples_val
919919

920920

921921
class TrainerPerfectAgreement(object):

neural_structured_learning/research/gam/trainer/trainer_classification.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,9 @@ def __init__(self,
287287

288288
# Put together all variables that need to be saved in case the process is
289289
# interrupted and needs to be restarted.
290-
self.vars_to_save = [weight_decay_var, iter_cls_total, self.global_step]
290+
self.vars_to_save = [iter_cls_total, self.global_step]
291+
if isinstance(weight_decay_var, tf.Variable):
292+
self.vars_to_save.append(weight_decay_var)
291293
if self.warm_start:
292294
self.vars_to_save.extend([v for v in variables])
293295

neural_structured_learning/research/gam/trainer/trainer_cotrain.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -505,11 +505,8 @@ def train(self, data, **kwargs):
505505

506506
# Create a saver which saves only the variables that we would need to
507507
# restore in case the training process is restarted.
508-
vars_to_save = [iter_cotrain]
509-
if self.warm_start_agr:
510-
vars_to_save.extend(trainer_agr.vars_to_save)
511-
if self.warm_start_cls:
512-
vars_to_save.extend(trainer_cls.vars_to_save)
508+
vars_to_save = [iter_cotrain] + trainer_agr.vars_to_save + \
509+
trainer_cls.vars_to_save
513510
saver = tf.train.Saver(vars_to_save)
514511

515512
# Create a TensorFlow session. We allow soft placement in order to place

0 commit comments

Comments
 (0)