Commit 2e4683b

fixed pylints in generic_trainer, fixed typos
1 parent 9c94e76 commit 2e4683b

2 files changed: 76 additions and 59 deletions


neuralmonkey/trainers/generic_trainer.py

Lines changed: 64 additions & 45 deletions
@@ -7,7 +7,8 @@
 from neuralmonkey.model.model_part import ModelPart
 from neuralmonkey.runners.base_runner import (
     Executable, ExecutionResult, NextExecute)
-from neuralmonkey.trainers.regularizers import (Regularizer, L2Regularizer)
+from neuralmonkey.trainers.regularizers import (
+    Regularizer, L1Regularizer, L2Regularizer)
 
 # pylint: disable=invalid-name
 Gradients = List[Tuple[tf.Tensor, tf.Variable]]
@@ -38,8 +39,7 @@ class Objective(NamedTuple(
     """
 
 
-# pylint: disable=too-few-public-methods,too-many-locals,too-many-branches
-# pylint: disable=too-many-statements
+# pylint: disable=too-few-public-methods
 class GenericTrainer:
 
     def __init__(self,
@@ -51,25 +51,15 @@ def __init__(self,
                  var_collection: str = None) -> None:
         check_argument_types()
 
+        self.objectives = objectives
+
         self.regularizers = []  # type: List[Regularizer]
         if regularizers is not None:
             self.regularizers = regularizers
 
-        if var_collection is None:
-            var_collection = tf.GraphKeys.TRAINABLE_VARIABLES
-
-        if var_scopes is None:
-            var_lists = [tf.get_collection(var_collection)]
-        else:
-            var_lists = [tf.get_collection(var_collection, scope)
-                         for scope in var_scopes]
-
-        # Flatten the list of lists
-        self.var_list = [var for var_list in var_lists for var in var_list]
+        self.var_list = _get_var_list(var_scopes, var_collection)
 
         with tf.name_scope("trainer"):
-            step = tf.train.get_or_create_global_step()
-
             if optimizer:
                 self.optimizer = optimizer
             else:
@@ -85,51 +75,33 @@ def __init__(self,
                               collections=["summary_train"])
             # pylint: enable=protected-access
 
-            with tf.name_scope("regularization"):
-                regularizable = [v for v in tf.trainable_variables()
-                                 if not BIAS_REGEX.findall(v.name)
-                                 and not v.name.startswith("vgg")
-                                 and not v.name.startswith("Inception")
-                                 and not v.name.startswith("resnet")]
-                reg_values = [reg.value(regularizable)
-                              for reg in self.regularizers]
-                reg_costs = [
-                    reg.weight * reg_value
-                    for reg, reg_value in zip(self.regularizers, reg_values)]
-
             # unweighted losses for fetching
-            self.losses = [o.loss for o in objectives] + reg_values
-
-            # we always want to include l2 values in the summary
-            if L2Regularizer not in [type(r) for r in self.regularizers]:
-                l2_reg = L2Regularizer(name="train_l2", weight=0.)
-                tf.summary.scalar(l2_reg.name, l2_reg.value(regularizable),
-                                  collections=["summary_train"])
-            for reg, reg_value in zip(self.regularizers, reg_values):
-                tf.summary.scalar(reg.name, reg_value,
-                                  collections=["summary_train"])
+            self.losses = [o.loss for o in self.objectives]
 
             # log all objectives
-            for obj in objectives:
+            for obj in self.objectives:
                 tf.summary.scalar(
                     obj.name, obj.loss, collections=["summary_train"])
 
+            # compute regularization costs
+            reg_costs = self._compute_regularization()
+
             # if the objective does not have its own gradients,
             # just use TF to do the derivative
-            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
-            with tf.control_dependencies(update_ops):
+            with tf.control_dependencies(
+                    tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                 with tf.name_scope("gradient_collection"):
                     differentiable_loss_sum = sum(
                         [(o.weight if o.weight is not None else 1.) * o.loss
-                         for o in objectives if o.gradients is None])
+                         for o in self.objectives if o.gradients is None])
                     differentiable_loss_sum += sum(reg_costs)
                     implicit_gradients = self._get_gradients(
                         differentiable_loss_sum)
 
                     # objectives that have their gradients explictly computed
                     other_gradients = [
                         _scale_gradients(o.gradients, o.weight)
-                        for o in objectives if o.gradients is not None]
+                        for o in self.objectives if o.gradients is not None]
 
                     if other_gradients:
                         self.gradients = _sum_gradients(
@@ -148,10 +120,11 @@ def __init__(self,
                                   if grad is not None]
 
             self.all_coders = set.union(*(obj.decoder.get_dependencies()
-                                          for obj in objectives))
+                                          for obj in self.objectives))
 
             self.train_op = self.optimizer.apply_gradients(
-                self.gradients, global_step=step)
+                self.gradients,
+                global_step=tf.train.get_or_create_global_step())
 
             for grad, var in self.gradients:
                 if grad is not None:
@@ -164,6 +137,38 @@ def __init__(self,
         self.scalar_summaries = tf.summary.merge(
             tf.get_collection("summary_train"))
 
+    def _compute_regularization(self) -> List[tf.Tensor]:
+        with tf.name_scope("regularization"):
+            regularizable = [v for v in tf.trainable_variables()
+                             if not BIAS_REGEX.findall(v.name)
+                             and not v.name.startswith("vgg")
+                             and not v.name.startswith("Inception")
+                             and not v.name.startswith("resnet")]
+            reg_values = [reg.value(regularizable)
+                          for reg in self.regularizers]
+            reg_costs = [
+                reg.weight * reg_value
+                for reg, reg_value in zip(self.regularizers, reg_values)]
+
+            # add unweighted regularization values
+            self.losses += reg_values
+
+            # we always want to include l1 and l2 values in the summary
+            if L1Regularizer not in [type(r) for r in self.regularizers]:
+                l1_reg = L1Regularizer(name="train_l1", weight=0.)
+                tf.summary.scalar(l1_reg.name, l1_reg.value(regularizable),
+                                  collections=["summary_train"])
+            if L2Regularizer not in [type(r) for r in self.regularizers]:
+                l2_reg = L2Regularizer(name="train_l2", weight=0.)
+                tf.summary.scalar(l2_reg.name, l2_reg.value(regularizable),
+                                  collections=["summary_train"])
+
+            for reg, reg_value in zip(self.regularizers, reg_values):
+                tf.summary.scalar(reg.name, reg_value,
+                                  collections=["summary_train"])
+
+        return reg_costs
+
     def _get_gradients(self, tensor: tf.Tensor) -> Gradients:
         gradient_list = self.optimizer.compute_gradients(tensor, self.var_list)
         return gradient_list
@@ -181,6 +186,20 @@ def get_executable(
             self.histogram_summaries if summaries else None)
 
 
+def _get_var_list(var_scopes, var_collection) -> List[tf.Variable]:
+    if var_collection is None:
+        var_collection = tf.GraphKeys.TRAINABLE_VARIABLES
+
+    if var_scopes is None:
+        var_lists = [tf.get_collection(var_collection)]
+    else:
+        var_lists = [tf.get_collection(var_collection, scope)
+                     for scope in var_scopes]
+
+    # Flatten the list of lists
+    return [var for var_list in var_lists for var in var_list]
+
+
 def _sum_gradients(gradients_list: List[Gradients]) -> Gradients:
     summed_dict = {}  # type: Dict[tf.Variable, tf.Tensor]
     for gradients in gradients_list:
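
Note: after this refactoring the trainer still differentiates the same scalar as before, namely each objective's loss scaled by its weight (a missing weight counts as 1.0) plus the weighted regularization costs now returned by _compute_regularization. A minimal sketch of that composition with plain Python floats (all values below are illustrative, not taken from the commit):

# Illustrative only: how the trainer composes the scalar it differentiates.
objective_losses = [2.5, 0.8]      # o.loss for each objective
objective_weights = [1.0, None]    # o.weight; None is treated as 1.0
reg_costs = [0.01]                 # reg.weight * reg.value(regularizable)

differentiable_loss_sum = sum(
    (w if w is not None else 1.0) * loss
    for w, loss in zip(objective_weights, objective_losses))
differentiable_loss_sum += sum(reg_costs)

print(differentiable_loss_sum)  # 2.5 + 0.8 + 0.01 = 3.31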

neuralmonkey/trainers/regularizers.py

Lines changed: 12 additions & 14 deletions
@@ -15,11 +15,11 @@
 
 
 class Regularizer(metaclass=ABCMeta):
-    """Base clas s for regularizers.
+    """Base class for regularizers.
 
     Regularizer objects are used to introduce additional loss terms to
-    the trainerthus constraining the model variable during training. These
-    loss terms have an adjustable weight allowing to set the ``importance''
+    the trainer, thus constraining the model variable during training. These
+    loss terms have an adjustable weight allowing to set the "importance"
     of the term.
     """
 
@@ -31,7 +31,7 @@ def __init__(self,
         Arguments:
             name: Regularizer name.
             weight: Weight of the regularization term (usually expressed
-                as ``lambda'' in the literature).
+                as "lambda" in the literature).
         """
         self._name = name
         self._weight = weight
@@ -64,7 +64,7 @@ def __init__(self,
 
         Arguments:
             name: Regularizer name.
-            weight: Weight of the regularization term (default=1.0e-8.
+            weight: Weight of the regularization term.
         """
         Regularizer.__init__(self, name, weight)
 
@@ -95,9 +95,8 @@ class EWCRegularizer(Regularizer):
 
     Implements Elastic Weight Consolidation from the "Overcoming catastrophic
     forgetting in neural networks" paper.
-    The regularizer applies separate regularization weight to each trainable
-    variable based on how important the variable was for the previously
-    learned task.
+    The regularizer applies a separate regularization weight to each trainable
+    variable based on its importance for the previously learned task.
 
     https://arxiv.org/pdf/1612.00796.pdf
     """
@@ -120,8 +119,8 @@ def __init__(self,
         check_argument_types()
         Regularizer.__init__(self, name, weight)
 
-        log("Loading initial variables for EWC from "
-            "{}.".format(variables_file))
+        log("Loading initial variables for EWC from {}."
+            .format(variables_file))
         self.init_vars = tf.contrib.framework.load_checkpoint(variables_file)
         log("EWC initial variables loaded.")
 
@@ -132,15 +131,14 @@ def __init__(self,
     def value(self, variables: List[tf.Tensor]) -> tf.Tensor:
         ewc_value = tf.constant(0.0)
         for var in variables:
-            var_name = var.name
-            init_var_name = var_name.split(":")[0]
-            if (var_name in self.gradients.files
+            init_var_name = var.name.split(":")[0]
+            if (var.name in self.gradients.files
                     and self.init_vars.has_tensor(init_var_name)):
                 init_var = tf.constant(
                     self.init_vars.get_tensor(init_var_name),
                     name="{}_init_value".format(init_var_name))
                 grad_squared = tf.constant(
-                    np.square(self.gradients[var_name]),
+                    np.square(self.gradients[var.name]),
                     name="{}_ewc_weight".format(init_var_name))
                 ewc_value += tf.reduce_sum(tf.multiply(
                     grad_squared, tf.square(var - init_var)))
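
Note: the quantity accumulated by EWCRegularizer.value weights each variable's squared drift from its previous-task value by the squared stored gradient, which acts as a diagonal Fisher-information estimate. A minimal NumPy sketch of that per-variable term follows; the arrays are hypothetical, whereas the real regularizer loads the initial values from a TensorFlow checkpoint and the gradients from a NumPy archive (judging from self.gradients.files).

import numpy as np

# Hypothetical values for one variable.
var = np.array([0.9, -0.2, 1.5])       # current parameter values
init_var = np.array([1.0, 0.0, 1.0])   # values learned on the previous task
grad = np.array([2.0, 0.1, 0.5])       # stored gradients from the previous task

# Same expression as tf.reduce_sum(tf.multiply(grad_squared,
# tf.square(var - init_var))) in EWCRegularizer.value.
ewc_value = np.sum(np.square(grad) * np.square(var - init_var))
print(ewc_value)  # 0.1029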
