@@ -18,24 +18,34 @@ def unflatten(vector, shapes):
18
18
19
19
20
20
class TensorFlowVariables (object ):
21
- """An object used to extract variables from a loss function.
22
-
23
- This object also provides methods for getting and setting the weights of
24
- the relevant variables.
21
+ """A class used to set and get weights for Tensorflow networks.
25
22
26
23
Attributes:
27
24
sess (tf.Session): The tensorflow session used to run assignment.
28
- loss: The loss function passed in by the user.
29
- variables (List[tf.Variable]): Extracted variables from the loss.
30
- assignment_placeholders (List[tf.placeholders]): The nodes that weights
31
- get passed to.
32
- assignment _nodes (List[tf.Tensor]): The nodes that assign the weights.
25
+ variables (Dict[str, tf.Variable]): Extracted variables from the loss
26
+ or additional variables that are passed in.
27
+ placeholders (Dict[str, tf.placeholders]): Placeholders for weights.
28
+ assignment_nodes (Dict[str, tf.Tensor]): Nodes that assign weights.
33
29
"""
34
- def __init__ (self , loss , sess = None ):
35
- """Creates a TensorFlowVariables instance."""
30
+ def __init__ (self , loss , sess = None , input_variables = None ):
31
+ """Creates TensorFlowVariables containing extracted variables.
32
+
33
+ The variables are extracted by performing a BFS search on the
34
+ dependency graph with loss as the root node. After the tree is
35
+ traversed and those variables are collected, we append input_variables
36
+ to the collected variables. For each variable in the list, the
37
+ variable has a placeholder and assignment operation created for it.
38
+
39
+ Args:
40
+ loss (tf.Operation): The tensorflow operation to extract all
41
+ variables from.
42
+ sess (tf.Session): Session used for running the get and set
43
+ methods.
44
+ input_variables (List[tf.Variables]): Variables to include in the
45
+ list.
46
+ """
36
47
import tensorflow as tf
37
48
self .sess = sess
38
- self .loss = loss
39
49
queue = deque ([loss ])
40
50
variable_names = []
41
51
explored_inputs = set ([loss ])
@@ -44,9 +54,10 @@ def __init__(self, loss, sess=None):
44
54
# the variables.
45
55
while len (queue ) != 0 :
46
56
tf_obj = queue .popleft ()
47
-
48
- # The object put into the queue is not necessarily an operation, so
49
- # we want the op attribute to get the operation underlying the
57
+ if tf_obj is None :
58
+ continue
59
+ # The object put into the queue is not necessarily an operation,
60
+ # so we want the op attribute to get the operation underlying the
50
61
# object. Only operations contain the inputs that we can explore.
51
62
if hasattr (tf_obj , "op" ):
52
63
tf_obj = tf_obj .op
@@ -63,23 +74,37 @@ def __init__(self, loss, sess=None):
63
74
if "Variable" in tf_obj .node_def .op :
64
75
variable_names .append (tf_obj .node_def .name )
65
76
self .variables = OrderedDict ()
66
- for v in [v for v in tf .global_variables ()
67
- if v .op .node_def .name in variable_names ]:
77
+ variable_list = [v for v in tf .global_variables ()
78
+ if v .op .node_def .name in variable_names ]
79
+ if input_variables is not None :
80
+ variable_list += input_variables
81
+ for v in variable_list :
68
82
self .variables [v .op .node_def .name ] = v
83
+
69
84
self .placeholders = dict ()
70
- self .assignment_nodes = []
85
+ self .assignment_nodes = dict ()
71
86
72
87
# Create new placeholders to put in custom weights.
73
88
for k , var in self .variables .items ():
74
89
self .placeholders [k ] = tf .placeholder (var .value ().dtype ,
75
- var .get_shape ().as_list ())
76
- self .assignment_nodes .append (var .assign (self .placeholders [k ]))
90
+ var .get_shape ().as_list (),
91
+ name = "Placeholder_" + k )
92
+ self .assignment_nodes [k ] = var .assign (self .placeholders [k ])
77
93
78
94
def set_session (self , sess ):
79
- """Modifies the current session used by the class."""
95
+ """Sets the current session used by the class.
96
+
97
+ Args:
98
+ sess (tf.Session): Session to set the attribute with.
99
+ """
80
100
self .sess = sess
81
101
82
102
def get_flat_size (self ):
103
+ """Returns the total length of all of the flattened variables.
104
+
105
+ Returns:
106
+ The length of all flattened variables concatenated.
107
+ """
83
108
return sum ([np .prod (v .get_shape ().as_list ())
84
109
for v in self .variables .values ()])
85
110
@@ -91,31 +116,64 @@ def _check_sess(self):
91
116
"calling set_session(sess)." )
92
117
93
118
def get_flat (self ):
94
- """Gets the weights and returns them as a flat array."""
119
+ """Gets the weights and returns them as a flat array.
120
+
121
+ Returns:
122
+ 1D Array containing the flattened weights.
123
+ """
95
124
self ._check_sess ()
96
125
return np .concatenate ([v .eval (session = self .sess ).flatten ()
97
126
for v in self .variables .values ()])
98
127
99
128
def set_flat (self , new_weights ):
100
- """Sets the weights to new_weights, converting from a flat array."""
129
+ """Sets the weights to new_weights, converting from a flat array.
130
+
131
+ Note:
132
+ You can only set all weights in the network using this function,
133
+ i.e., the length of the array must match get_flat_size.
134
+
135
+ Args:
136
+ new_weights (np.ndarray): Flat array containing weights.
137
+ """
101
138
self ._check_sess ()
102
139
shapes = [v .get_shape ().as_list () for v in self .variables .values ()]
103
140
arrays = unflatten (new_weights , shapes )
104
- placeholders = [self .placeholders [k ]
105
- for k , v in self .variables .items ()]
106
- self .sess .run (self .assignment_nodes ,
141
+ placeholders = [self .placeholders [k ] for k , v
142
+ in self .variables .items ()]
143
+ self .sess .run (list ( self .assignment_nodes . values ()) ,
107
144
feed_dict = dict (zip (placeholders , arrays )))
108
145
109
146
def get_weights (self ):
110
- """Returns a list of the weights of the loss function variables."""
147
+ """Returns a dictionary containing the weights of the network.
148
+
149
+ Returns:
150
+ Dictionary mapping variable names to their weights.
151
+ """
111
152
self ._check_sess ()
112
- return {k : v .eval (session = self .sess )
113
- for k , v in self .variables .items ()}
153
+ return {k : v .eval (session = self .sess ) for k , v
154
+ in self .variables .items ()}
114
155
115
156
def set_weights (self , new_weights ):
116
- """Sets the weights to new_weights."""
157
+ """Sets the weights to new_weights.
158
+
159
+ Note:
160
+ Can set subsets of variables as well, by only passing in the
161
+ variables you want to be set.
162
+
163
+ Args:
164
+ new_weights (Dict): Dictionary mapping variable names to their
165
+ weights.
166
+ """
117
167
self ._check_sess ()
118
- self .sess .run (self .assignment_nodes ,
168
+ assign_list = [self .assignment_nodes [name ]
169
+ for name in new_weights .keys ()
170
+ if name in self .assignment_nodes ]
171
+ assert assign_list , ("No variables in the input matched those in the "
172
+ "network. Possible cause: Two networks were "
173
+ "defined in the same TensorFlow graph. To fix "
174
+ "this, place each network definition in its own "
175
+ "tf.Graph." )
176
+ self .sess .run (assign_list ,
119
177
feed_dict = {self .placeholders [name ]: value
120
178
for (name , value ) in new_weights .items ()
121
179
if name in self .placeholders })
0 commit comments