 from neuralmonkey.model.model_part import ModelPart
 from neuralmonkey.runners.base_runner import (
     Executable, ExecutionResult, NextExecute)
-from neuralmonkey.trainers.regularizers import (Regularizer, L2Regularizer)
+from neuralmonkey.trainers.regularizers import (
+    Regularizer, L1Regularizer, L2Regularizer)
 
 # pylint: disable=invalid-name
 Gradients = List[Tuple[tf.Tensor, tf.Variable]]
@@ -38,8 +39,7 @@ class Objective(NamedTuple(
     """
 
 
-# pylint: disable=too-few-public-methods,too-many-locals,too-many-branches
-# pylint: disable=too-many-statements
+# pylint: disable=too-few-public-methods
 class GenericTrainer:
 
     def __init__(self,
@@ -51,25 +51,15 @@ def __init__(self,
                  var_collection: str = None) -> None:
         check_argument_types()
 
+        self.objectives = objectives
+
         self.regularizers = []  # type: List[Regularizer]
         if regularizers is not None:
             self.regularizers = regularizers
 
-        if var_collection is None:
-            var_collection = tf.GraphKeys.TRAINABLE_VARIABLES
-
-        if var_scopes is None:
-            var_lists = [tf.get_collection(var_collection)]
-        else:
-            var_lists = [tf.get_collection(var_collection, scope)
-                         for scope in var_scopes]
-
-        # Flatten the list of lists
-        self.var_list = [var for var_list in var_lists for var in var_list]
+        self.var_list = _get_var_list(var_scopes, var_collection)
 
         with tf.name_scope("trainer"):
-            step = tf.train.get_or_create_global_step()
-
             if optimizer:
                 self.optimizer = optimizer
             else:
@@ -85,51 +75,33 @@ def __init__(self,
                               collections=["summary_train"])
             # pylint: enable=protected-access
 
-            with tf.name_scope("regularization"):
-                regularizable = [v for v in tf.trainable_variables()
-                                 if not BIAS_REGEX.findall(v.name)
-                                 and not v.name.startswith("vgg")
-                                 and not v.name.startswith("Inception")
-                                 and not v.name.startswith("resnet")]
-                reg_values = [reg.value(regularizable)
-                              for reg in self.regularizers]
-                reg_costs = [
-                    reg.weight * reg_value
-                    for reg, reg_value in zip(self.regularizers, reg_values)]
-
             # unweighted losses for fetching
-            self.losses = [o.loss for o in objectives] + reg_values
-
-            # we always want to include l2 values in the summary
-            if L2Regularizer not in [type(r) for r in self.regularizers]:
-                l2_reg = L2Regularizer(name="train_l2", weight=0.)
-                tf.summary.scalar(l2_reg.name, l2_reg.value(regularizable),
-                                  collections=["summary_train"])
-            for reg, reg_value in zip(self.regularizers, reg_values):
-                tf.summary.scalar(reg.name, reg_value,
-                                  collections=["summary_train"])
+            self.losses = [o.loss for o in self.objectives]
 
             # log all objectives
-            for obj in objectives:
+            for obj in self.objectives:
                 tf.summary.scalar(
                     obj.name, obj.loss, collections=["summary_train"])
 
+            # compute regularization costs
+            reg_costs = self._compute_regularization()
+
             # if the objective does not have its own gradients,
             # just use TF to do the derivative
-            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
-            with tf.control_dependencies(update_ops):
+            with tf.control_dependencies(
+                    tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                 with tf.name_scope("gradient_collection"):
                     differentiable_loss_sum = sum(
                         [(o.weight if o.weight is not None else 1.) * o.loss
-                         for o in objectives if o.gradients is None])
+                         for o in self.objectives if o.gradients is None])
                     differentiable_loss_sum += sum(reg_costs)
                     implicit_gradients = self._get_gradients(
                         differentiable_loss_sum)
 
                     # objectives that have their gradients explicitly computed
                     other_gradients = [
                         _scale_gradients(o.gradients, o.weight)
-                        for o in objectives if o.gradients is not None]
+                        for o in self.objectives if o.gradients is not None]
 
                     if other_gradients:
                         self.gradients = _sum_gradients(
@@ -148,10 +120,11 @@ def __init__(self,
                               if grad is not None]
 
             self.all_coders = set.union(*(obj.decoder.get_dependencies()
-                                          for obj in objectives))
+                                          for obj in self.objectives))
 
             self.train_op = self.optimizer.apply_gradients(
-                self.gradients, global_step=step)
+                self.gradients,
+                global_step=tf.train.get_or_create_global_step())
 
             for grad, var in self.gradients:
                 if grad is not None:
@@ -164,6 +137,38 @@ def __init__(self,
             self.scalar_summaries = tf.summary.merge(
                 tf.get_collection("summary_train"))
 
+    def _compute_regularization(self) -> List[tf.Tensor]:
+        with tf.name_scope("regularization"):
+            regularizable = [v for v in tf.trainable_variables()
+                             if not BIAS_REGEX.findall(v.name)
+                             and not v.name.startswith("vgg")
+                             and not v.name.startswith("Inception")
+                             and not v.name.startswith("resnet")]
+            reg_values = [reg.value(regularizable)
+                          for reg in self.regularizers]
+            reg_costs = [
+                reg.weight * reg_value
+                for reg, reg_value in zip(self.regularizers, reg_values)]
+
+            # add unweighted regularization values
+            self.losses += reg_values
+
+            # we always want to include l1 and l2 values in the summary
+            if L1Regularizer not in [type(r) for r in self.regularizers]:
+                l1_reg = L1Regularizer(name="train_l1", weight=0.)
+                tf.summary.scalar(l1_reg.name, l1_reg.value(regularizable),
+                                  collections=["summary_train"])
+            if L2Regularizer not in [type(r) for r in self.regularizers]:
+                l2_reg = L2Regularizer(name="train_l2", weight=0.)
+                tf.summary.scalar(l2_reg.name, l2_reg.value(regularizable),
+                                  collections=["summary_train"])
+
+            for reg, reg_value in zip(self.regularizers, reg_values):
+                tf.summary.scalar(reg.name, reg_value,
+                                  collections=["summary_train"])
+
+        return reg_costs
+
     def _get_gradients(self, tensor: tf.Tensor) -> Gradients:
         gradient_list = self.optimizer.compute_gradients(tensor, self.var_list)
         return gradient_list
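The extracted `_compute_regularization` method leans on a small interface from `neuralmonkey.trainers.regularizers`: each regularizer exposes a `name`, a `weight`, and a `value(variables)` tensor. A minimal sketch of that interface, inferred from the calls above rather than copied from the module, might look like this:

# Sketch of the Regularizer interface that _compute_regularization()
# relies on; inferred from the calls above (reg.name, reg.weight,
# reg.value(variables)), not taken verbatim from
# neuralmonkey.trainers.regularizers.
from typing import List

import tensorflow as tf


class Regularizer:
    def __init__(self, name: str, weight: float) -> None:
        self.name = name
        self.weight = weight

    def value(self, variables: List[tf.Variable]) -> tf.Tensor:
        raise NotImplementedError()


class L1Regularizer(Regularizer):
    def value(self, variables: List[tf.Variable]) -> tf.Tensor:
        # sum of absolute values of all regularizable weights
        return tf.reduce_sum([tf.reduce_sum(tf.abs(v)) for v in variables])


class L2Regularizer(Regularizer):
    def value(self, variables: List[tf.Variable]) -> tf.Tensor:
        # sum of squares of all regularizable weights
        return tf.reduce_sum([tf.reduce_sum(tf.square(v)) for v in variables])

Under this reading, the trainer multiplies each unweighted `value()` tensor by its `weight` to form the costs added to the differentiable loss, while the unweighted values are kept in `self.losses` for fetching and summaries.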
@@ -181,6 +186,20 @@ def get_executable(
                                self.histogram_summaries if summaries else None)
 
 
+def _get_var_list(var_scopes, var_collection) -> List[tf.Variable]:
+    if var_collection is None:
+        var_collection = tf.GraphKeys.TRAINABLE_VARIABLES
+
+    if var_scopes is None:
+        var_lists = [tf.get_collection(var_collection)]
+    else:
+        var_lists = [tf.get_collection(var_collection, scope)
+                     for scope in var_scopes]
+
+    # Flatten the list of lists
+    return [var for var_list in var_lists for var in var_list]
+
+
 def _sum_gradients(gradients_list: List[Gradients]) -> Gradients:
     summed_dict = {}  # type: Dict[tf.Variable, tf.Tensor]
     for gradients in gradients_list:
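The extracted `_get_var_list` helper preserves the constructor's previous behavior: with no scopes it returns every variable in the chosen collection (defaulting to `tf.GraphKeys.TRAINABLE_VARIABLES`), otherwise only the variables under the given scopes. A brief usage sketch; the scope name is hypothetical:

# Hypothetical usage of the extracted helper. "decoder" is an example
# scope name, not one the library necessarily defines.
all_vars = _get_var_list(None, None)         # every trainable variable
dec_vars = _get_var_list(["decoder"], None)  # only variables under "decoder/"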