@@ -20,6 +20,12 @@ class BaseRegularizer:
2020 def __init__ (self ,
2121 name : str ,
2222 weight : float ) -> None :
23+ """Create the regularizer.
24+
25+ Arguments:
26+ name: Regularizer name.
27+ weight: Weight of the regularization term.
28+ """
2329 check_argument_types ()
2430
2531 self ._name = name
@@ -38,21 +44,35 @@ def value(self, variables) -> float:
3844
3945
4046class L1Regularizer (BaseRegularizer ):
47+ """L1 regularizer."""
4148
4249 def __init__ (self ,
4350 name : str = "train_l1" ,
4451 weight : float = 1.0e-8 ) -> None :
52+ """Create the regularizer.
53+
54+ Arguments:
55+ name: Regularizer name.
56+ weight: Weight of the regularization term.
57+ """
4558 BaseRegularizer .__init__ (self , name , weight )
4659
4760 def value (self , variables : List [tf .Tensor ]) -> float :
4861 return sum (tf .reduce_sum (abs (v )) for v in variables )
4962
5063
5164class L2Regularizer (BaseRegularizer ):
65+ """L2 regularizer."""
5266
5367 def __init__ (self ,
5468 name : str = "train_l2" ,
5569 weight : float = 1.0e-8 ) -> None :
70+ """Create the regularizer.
71+
72+ Arguments:
73+ name: Regularizer name.
74+ weight: Weight of the regularization term.
75+ """
5676 BaseRegularizer .__init__ (self , name , weight )
5777
5878 def value (self , variables : List [tf .Tensor ]) -> float :
@@ -62,14 +82,27 @@ def value(self, variables: List[tf.Tensor]) -> float:
6282class EWCRegularizer (BaseRegularizer ):
6383 """Regularizer based on the Elastic Weight Consolidation.
6484
65- TODO description
85+ Implements Elastic Weight Consolidation from the "Overcoming catastrophic
86+ forgetting in neural networks" paper.
87+
88+ https://arxiv.org/pdf/1612.00796.pdf
6689 """
6790
6891 def __init__ (self ,
6992 name : str = "train_ewc" ,
7093 weight : float = 0. ,
7194 gradients_file : str = None ,
7295 variables_file : str = None ) -> None :
96+ """Create the regularizer.
97+
98+ Arguments:
99+ name: Regularizer name.
100+ weight: Weight of the regularization term.
101+ gradients_file: File containing the gradient estimates
102+ from the previous task.
103+ variables_files: File containing the variables learned
104+ on the previous task.
105+ """
73106 check_argument_types ()
74107
75108 BaseRegularizer .__init__ (self , name , weight )
@@ -88,13 +121,16 @@ def __init__(self,
88121 log ("Gradient estimates loaded" )
89122
90123 def value (self , variables : List [tf .Tensor ]) -> float :
91- ewc_value = 0.0
124+ ewc_value = tf . constant ( 0.0 )
92125 for var in variables :
93126 var_name = var .name .split (":" )[0 ]
94- init_var = self .init_vars .get_tensor (var_name )
95- gradient = self .gradients [var_name ]
96- ewc_value += tf .reduce_sum (tf .multiply (
97- tf .square (gradient ), tf .square (var - init_var )))
127+ if (var_name in self .gradients .files
128+ and self .init_vars .has_tensor (var_name )):
129+ init_var = self .init_vars .get_tensor (var_name )
130+ gradient = tf .constant (
131+ self .gradients [var_name ], name = "ewc_gradients" )
132+ ewc_value += tf .reduce_sum (tf .multiply (
133+ tf .square (gradient ), tf .square (var - init_var )))
98134
99135 return ewc_value
100136
0 commit comments