Commit 3a19ab7

addressing PR reviews
1 parent aaa8e8a

File tree: 3 files changed (+69, -7 lines)


neuralmonkey/runners/gradient_runner.py
Lines changed: 8 additions & 0 deletions

@@ -46,6 +46,14 @@ def collect_results(self, results: List[Dict]) -> None:
 class GradientRunner(BaseRunner[SupportedDecoder]):
+    """Runner for fetching gradients computed over the dataset.
+
+    The gradient runner applies the provided trainer to a given dataset
+    and uses it to compute gradients over the gold data. It is currently
+    used to gather gradients for Elastic Weight Consolidation.
+
+    (https://arxiv.org/pdf/1612.00796.pdf)
+    """
 
     def __init__(self,
                  output_series: str,
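
For context on what these gathered gradients are later used for (this comes from the cited paper, not from the diff itself): in EWC, squared gradients of the log-likelihood over the old task's data serve as an estimate of the diagonal of the Fisher information matrix, which then weights a quadratic penalty on the new task's loss. In the notation of Kirkpatrick et al. (2017), eq. (3):

% EWC-penalized objective; L_B is the new task's loss, F_i the diagonal
% Fisher estimate for parameter i, and \theta^{*}_{A,i} the parameter
% value learned on the previous task A.
L(\theta) = L_B(\theta)
    + \sum_i \frac{\lambda}{2} F_i \left(\theta_i - \theta^{*}_{A,i}\right)^2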

neuralmonkey/trainers/cross_entropy_trainer.py
Lines changed: 19 additions & 1 deletion

@@ -3,9 +3,11 @@
 import tensorflow as tf
 from typeguard import check_argument_types
 
+from neuralmonkey.logging import warn
 from neuralmonkey.trainers.generic_trainer import (
     GenericTrainer, Objective, ObjectiveWeight)
-from neuralmonkey.trainers.regularizers import BaseRegularizer
+from neuralmonkey.trainers.regularizers import (
+    BaseRegularizer, L1Regularizer, L2Regularizer)
 
 
 def xent_objective(decoder, weight=None) -> Objective:
@@ -29,13 +31,29 @@ def __init__(self,
                  clip_norm: float = None,
                  optimizer: tf.train.Optimizer = None,
                  regularizers: List[BaseRegularizer] = None,
+                 l1_weight: float = 0.,
+                 l2_weight: float = 0.,
                  var_scopes: List[str] = None,
                  var_collection: str = None) -> None:
         check_argument_types()
 
         if decoder_weights is None:
             decoder_weights = [None for _ in decoders]
 
+        if regularizers is None:
+            regularizers = []
+        if l1_weight > 0.:
+            if L1Regularizer in [type(r) for r in regularizers]:
+                warn("You specified both the trainer l1_weight "
+                     "and an L1Regularizer object in your config")
+            regularizers.append(L1Regularizer(weight=l1_weight))
+
+        if l2_weight > 0.:
+            if L2Regularizer in [type(r) for r in regularizers]:
+                warn("You specified both the trainer l2_weight "
+                     "and an L2Regularizer object in your config")
+            regularizers.append(L2Regularizer(weight=l2_weight))
+
         if len(decoder_weights) != len(decoders):
             raise ValueError(
                 "decoder_weights (length {}) do not match decoders (length {})"

neuralmonkey/trainers/regularizers.py
Lines changed: 42 additions & 6 deletions

@@ -20,6 +20,12 @@ class BaseRegularizer:
     def __init__(self,
                  name: str,
                  weight: float) -> None:
+        """Create the regularizer.
+
+        Arguments:
+            name: Regularizer name.
+            weight: Weight of the regularization term.
+        """
         check_argument_types()
 
         self._name = name
@@ -38,21 +44,35 @@ def value(self, variables) -> float:
 
 
 class L1Regularizer(BaseRegularizer):
+    """L1 regularizer."""
 
     def __init__(self,
                  name: str = "train_l1",
                  weight: float = 1.0e-8) -> None:
+        """Create the regularizer.
+
+        Arguments:
+            name: Regularizer name.
+            weight: Weight of the regularization term.
+        """
         BaseRegularizer.__init__(self, name, weight)
 
     def value(self, variables: List[tf.Tensor]) -> float:
         return sum(tf.reduce_sum(abs(v)) for v in variables)
 
 
 class L2Regularizer(BaseRegularizer):
+    """L2 regularizer."""
 
     def __init__(self,
                  name: str = "train_l2",
                  weight: float = 1.0e-8) -> None:
+        """Create the regularizer.
+
+        Arguments:
+            name: Regularizer name.
+            weight: Weight of the regularization term.
+        """
        BaseRegularizer.__init__(self, name, weight)
 
     def value(self, variables: List[tf.Tensor]) -> float:
@@ -62,14 +82,27 @@ def value(self, variables: List[tf.Tensor]) -> float:
 class EWCRegularizer(BaseRegularizer):
     """Regularizer based on the Elastic Weight Consolidation.
 
-    TODO description
+    Implements Elastic Weight Consolidation from the "Overcoming catastrophic
+    forgetting in neural networks" paper.
+
+    https://arxiv.org/pdf/1612.00796.pdf
     """
 
     def __init__(self,
                  name: str = "train_ewc",
                  weight: float = 0.,
                  gradients_file: str = None,
                  variables_file: str = None) -> None:
+        """Create the regularizer.
+
+        Arguments:
+            name: Regularizer name.
+            weight: Weight of the regularization term.
+            gradients_file: File containing the gradient estimates
+                from the previous task.
+            variables_file: File containing the variables learned
+                on the previous task.
+        """
         check_argument_types()
 
         BaseRegularizer.__init__(self, name, weight)
@@ -88,13 +121,16 @@ def __init__(self,
         log("Gradient estimates loaded")
 
     def value(self, variables: List[tf.Tensor]) -> float:
-        ewc_value = 0.0
+        ewc_value = tf.constant(0.0)
         for var in variables:
             var_name = var.name.split(":")[0]
-            init_var = self.init_vars.get_tensor(var_name)
-            gradient = self.gradients[var_name]
-            ewc_value += tf.reduce_sum(tf.multiply(
-                tf.square(gradient), tf.square(var - init_var)))
+            if (var_name in self.gradients.files
+                    and self.init_vars.has_tensor(var_name)):
+                init_var = self.init_vars.get_tensor(var_name)
+                gradient = tf.constant(
+                    self.gradients[var_name], name="ewc_gradients")
+                ewc_value += tf.reduce_sum(tf.multiply(
+                    tf.square(gradient), tf.square(var - init_var)))
 
         return ewc_value
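
As a worked example of what the rewritten value() computes, here is a minimal NumPy sketch (plain dicts of arrays stand in for tf.Tensors and for the loaded gradient/checkpoint files): squared gradients act as the diagonal Fisher estimate, and variables missing from either file are now skipped instead of raising a KeyError.

import numpy as np


def ewc_penalty(variables, init_vars, gradients):
    """Sum g_i**2 * (theta_i - theta*_i)**2 over variables in both files."""
    value = 0.0
    for name, var in variables.items():
        # Mirror the new guard: only penalize variables present in both
        # the gradient estimates and the previous task's checkpoint.
        if name in gradients and name in init_vars:
            value += np.sum(gradients[name] ** 2
                            * (var - init_vars[name]) ** 2)
    return value


theta = {"w": np.array([1.0, 2.0])}       # current task's weights
theta_star = {"w": np.array([0.5, 2.0])}  # weights from the previous task
grads = {"w": np.array([2.0, 0.1])}       # gradient estimates over old data

print(ewc_penalty(theta, theta_star, grads))  # 2.0**2 * 0.5**2 = 1.0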
