@@ -622,10 +622,10 @@ def train(self, data, session=None, **kwargs):
622
622
# data_iterator_train = self._train_iterator(
623
623
# labeled_samples, neighbors_val, data, ratio_pos_to_neg=ratio_pos_to_neg)
624
624
625
- labeled_nodes_train , labeled_nodes_val = self ._select_val_set_v2 (
625
+ labeled_samples_train , labeled_nodes_val = self ._select_val_samples (
626
626
labeled_samples , self .ratio_val )
627
- data_iterator_train = self ._pair_iterator_v2 ( labeled_nodes_train , data ,
628
- ratio_pos_neg = ratio_pos_to_neg )
627
+ data_iterator_train = self ._pair_iterator ( labeled_samples_train , data ,
628
+ ratio_pos_neg = ratio_pos_to_neg )
629
629
630
630
# Start training.
631
631
best_val_acc = - 1
@@ -667,7 +667,7 @@ def train(self, data, session=None, **kwargs):
667
667
# shuffle=False,
668
668
# allow_smaller_batch=True,
669
669
# repeat=False)
670
- data_iterator_val = self ._pair_iterator_v2 (labeled_nodes_val , data )
670
+ data_iterator_val = self ._pair_iterator (labeled_nodes_val , data )
671
671
feed_dict_val = self ._construct_feed_dict (
672
672
data_iterator_val , is_train = False )
673
673
cummulative_val_acc = 0.0
@@ -883,7 +883,7 @@ def predict_label_by_agreement(self, session, indices, num_neighbors=100):
883
883
logging .info ('Majority vote accuracy: %.2f.' , acc )
884
884
return acc
885
885
886
- def _pair_iterator_v2 (self , labeled_nodes , data , ratio_pos_neg = None ):
886
+ def _pair_iterator (self , labeled_nodes , data , ratio_pos_neg = None ):
887
887
# TODO: add documentation and rename neighbors to samples.
888
888
neighbors_batch = np .empty (shape = (self .batch_size , 2 ), dtype = np .int32 )
889
889
agreement_batch = np .empty (shape = (self .batch_size ,), dtype = np .float32 )
@@ -908,14 +908,14 @@ def _pair_iterator_v2(self, labeled_nodes, data, ratio_pos_neg=None):
908
908
num_added += 1
909
909
yield neighbors_batch , agreement_batch
910
910
911
- def _select_val_set_v2 (self , labeled_nodes , percent_val ):
912
- # TODO: rename and add documentation.
913
- num_labeled_nodes = labeled_nodes .shape [0 ]
914
- num_labeled_nodes_val = int (num_labeled_nodes * percent_val )
915
- self .rng .shuffle (labeled_nodes )
916
- labeled_nodes_val = labeled_nodes [: num_labeled_nodes_val ]
917
- labeled_nodes_train = labeled_nodes [ num_labeled_nodes_val :]
918
- return labeled_nodes_train , labeled_nodes_val
911
+ def _select_val_samples (self , labeled_samples , ratio_val ):
912
+ # TODO: add documentation.
913
+ num_labeled_samples = labeled_samples .shape [0 ]
914
+ num_labeled_samples_val = int (num_labeled_samples * ratio_val )
915
+ self .rng .shuffle (labeled_samples )
916
+ labeled_samples_val = labeled_samples [: num_labeled_samples_val ]
917
+ labeled_samples_train = labeled_samples [ num_labeled_samples_val :]
918
+ return labeled_samples_train , labeled_samples_val
919
919
920
920
921
921
class TrainerPerfectAgreement (object ):
0 commit comments